截断模型与双层dense网络的融合方法

来源:12-5 迁移学习下的模型预测

zozo_zuo

2020-02-10

老师,之前提过一个问题(https://coding.imooc.com/learn/questiondetail/163183.html) ,现在找到一个解决方法。

融合方法

将截断模型当成一个层,代码如下:

import * as tf from '@tensorflow/tfjs';
import * as tfvis from '@tensorflow/tfjs-vis';
import { getInputs } from './data';
import { img2x, file2img } from './utils';

const MOBILENET_MODEL_PATH = 'http://127.0.0.1:5500/data/mobilenet/web_model/model.json';
const NUM_CLASSES = 3;
const BRAND_CLASSES = ['android', 'apple', 'windows'];

window.onload = async () => {
    const { inputs, labels } = await getInputs();
    const surface = tfvis.visor().surface({ name: '输入示例', styles: { height: 250 } });
    inputs.forEach(img => {
        surface.drawArea.appendChild(img);
    });

    const mobilenet = await tf.loadLayersModel(MOBILENET_MODEL_PATH);
    // mobilenet.summary();
    const layer = mobilenet.getLayer('conv_pw_13_relu');
    const truncatedMobilenet = tf.model({
        inputs: mobilenet.inputs,
        outputs: layer.output
    });
	// 锁层
    for (const layer of truncatedMobilenet.layers) {   
        layer.trainable = false;
    }

    const model = tf.sequential();
    model.add(truncatedMobilenet);  // 这里把截断模型当成一个层直接添加到model中
    model.add(tf.layers.flatten({
        // inputShape: layer.outputShape.slice(1)
    }));
    model.add(tf.layers.dense({
        units: 10,
        activation: 'relu'
    }));
    model.add(tf.layers.dense({
        units: NUM_CLASSES,
        activation: 'softmax'
    }));

    model.summary();
    model.compile({ loss: 'categoricalCrossentropy', optimizer: tf.train.adam() });

    const { xs, ys } = tf.tidy(() => {
        const xs = tf.concat(inputs.map(imgEl => img2x(imgEl)));   // 删除掉原来truncatedMobilenet.predict的内容
        const ys = tf.tensor(labels);
        return { xs, ys };
    });

    await model.fit(xs, ys, {
        epochs: 20,
        callbacks: tfvis.show.fitCallbacks(
            { name: '训练效果' },
            ['loss'],
            { callbacks: ['onEpochEnd'] }
        )
    });

    window.predict = async (file) => {
        const img = await file2img(file);
        document.body.appendChild(img);
        const pred = tf.tidy(() => {
            const x = img2x(img);
            // const input = truncatedMobilenet.predict(x);
            return model.predict(x);
        });

        const index = pred.argMax(1).dataSync()[0];
        setTimeout(() => {
            alert(`预测结果:${BRAND_CLASSES[index]}`);
        }, 0);
    };

    window.download = async () => {
        await model.save('downloads://model');
    };
};

这里直接将truncatedMobilenet当成一个层,添加到model中,我这边测试一下可以使用,模型可以保存下来。

写回答

2回答

广州_小彭

2021-02-20

这个方法不行,损失率将不下来

0
0

lewis

2020-02-10

很不错!给你点个大大的赞~

0
0

JavaScript玩转机器学习-Tensorflow.js项目实战

机器学习理论知识+Tensorflow.js实战开发

644 学习 · 189 问题

查看课程