截断模型与双层dense网络的融合方法
来源:12-5 迁移学习下的模型预测

zozo_zuo
2020-02-10
老师,之前提过一个问题(https://coding.imooc.com/learn/questiondetail/163183.html) ,现在找到一个解决方法。
融合方法
将截断模型当成一个层,代码如下:
import * as tf from '@tensorflow/tfjs';
import * as tfvis from '@tensorflow/tfjs-vis';
import { getInputs } from './data';
import { img2x, file2img } from './utils';
const MOBILENET_MODEL_PATH = 'http://127.0.0.1:5500/data/mobilenet/web_model/model.json';
const NUM_CLASSES = 3;
const BRAND_CLASSES = ['android', 'apple', 'windows'];
window.onload = async () => {
const { inputs, labels } = await getInputs();
const surface = tfvis.visor().surface({ name: '输入示例', styles: { height: 250 } });
inputs.forEach(img => {
surface.drawArea.appendChild(img);
});
const mobilenet = await tf.loadLayersModel(MOBILENET_MODEL_PATH);
// mobilenet.summary();
const layer = mobilenet.getLayer('conv_pw_13_relu');
const truncatedMobilenet = tf.model({
inputs: mobilenet.inputs,
outputs: layer.output
});
// 锁层
for (const layer of truncatedMobilenet.layers) {
layer.trainable = false;
}
const model = tf.sequential();
model.add(truncatedMobilenet); // 这里把截断模型当成一个层直接添加到model中
model.add(tf.layers.flatten({
// inputShape: layer.outputShape.slice(1)
}));
model.add(tf.layers.dense({
units: 10,
activation: 'relu'
}));
model.add(tf.layers.dense({
units: NUM_CLASSES,
activation: 'softmax'
}));
model.summary();
model.compile({ loss: 'categoricalCrossentropy', optimizer: tf.train.adam() });
const { xs, ys } = tf.tidy(() => {
const xs = tf.concat(inputs.map(imgEl => img2x(imgEl))); // 删除掉原来truncatedMobilenet.predict的内容
const ys = tf.tensor(labels);
return { xs, ys };
});
await model.fit(xs, ys, {
epochs: 20,
callbacks: tfvis.show.fitCallbacks(
{ name: '训练效果' },
['loss'],
{ callbacks: ['onEpochEnd'] }
)
});
window.predict = async (file) => {
const img = await file2img(file);
document.body.appendChild(img);
const pred = tf.tidy(() => {
const x = img2x(img);
// const input = truncatedMobilenet.predict(x);
return model.predict(x);
});
const index = pred.argMax(1).dataSync()[0];
setTimeout(() => {
alert(`预测结果:${BRAND_CLASSES[index]}`);
}, 0);
};
window.download = async () => {
await model.save('downloads://model');
};
};
这里直接将truncatedMobilenet
当成一个层,添加到model中,我这边测试一下可以使用,模型可以保存下来。
写回答
2回答
-
广州_小彭
2021-02-20
这个方法不行,损失率将不下来
00 -
lewis
2020-02-10
很不错!给你点个大大的赞~
00
相似问题