mnist训练没有变化
来源:10-4 训练模型
Hank_桁
2019-12-24
老师您好
在训练mnist识别手写数字的模型里,我按照老师您的代码编写,运行的时候发现loss和acc没有变化。感觉像没有数据喂给模型一样
这里是运行效果:
这里是复现的代码:
import * as tf from '@tensorflow/tfjs'
import * as tfvis from '@tensorflow/tfjs-vis'
import {MnistData} from './data'
window.onload = async () => {
const data = new MnistData()
await data.load()
const examples = data.nextTestBatch(20)
const surface = tfvis.visor().surface({ name: '输入示例' })
for (let i = 0; i < 20; i++) {
const imageTensor = tf.tidy(() => {
return examples.xs.slice([i, 0], [1, 784]).reshape([28, 28, 1])
})
const canvas = document.createElement('canvas')
canvas.width = 28
canvas.height = 28
canvas.style = 'margin: 4px;'
await tf.browser.toPixels(imageTensor, canvas)
surface.drawArea.appendChild(canvas)
}
const model = tf.sequential();
model.add(tf.layers.conv2d({
inputShape: [28, 28, 1],
kernelSize: 5,
filters: 8,
strides: 1,
activation: 'relu',
kernelInitializer: 'varianceScaling'
}));
model.add(tf.layers.maxPool2d({
poolSize: [2, 2],
strides: [2, 2]
}));
model.add(tf.layers.conv2d({
kernelSize: 5,
filters: 16,
strides: 1,
activation: 'relu',
kernelInitializer: 'varianceScaling'
}));
model.add(tf.layers.maxPool2d({
poolSize: [2, 2],
strides: [2, 2]
}));
model.add(tf.layers.flatten());
model.add(tf.layers.dense({
units: 10,
activation: 'softmax',
kernelInitializer: 'varianceScaling'
}));
model.compile({
loss: 'categoricalCrossentropy',
optimizer: tf.train.adam(),
metrics: ['accuracy']
});
const [trainXs, trainYs] = tf.tidy(() => {
const d = data.nextTrainBatch(1000);
return [
d.xs.reshape([1000, 28, 28, 1]),
d.labels
];
});
const [testXs, testYs] = tf.tidy(() => {
const d = data.nextTestBatch(200);
return [
d.xs.reshape([200, 28, 28, 1]),
d.labels
];
});
await model.fit(trainXs, trainYs, {
validationData: [testXs, testYs],
batchSize: 500,
epochs: 20,
callbacks: tfvis.show.fitCallbacks(
{ name: '训练效果' },
['loss', 'val_loss', 'acc', 'val_acc'],
{ callbacks: ['onEpochEnd'] }
)
});
}
PS:我把mnist的图片和二进制文件放到打包完的dist目录里,没有用http-server来启动静态服务器,不知道是否跟这个有关系
写回答
1回答
-
当然有关,需要开启静态文件服务器
012019-12-25