mnist训练没有变化

来源:10-4 训练模型

Hank_桁

2019-12-24

老师您好
在训练mnist识别手写数字的模型里,我按照老师您的代码编写,运行的时候发现loss和acc没有变化。感觉像没有数据喂给模型一样
这里是运行效果:
图片描述
这里是复现的代码:

import * as tf from '@tensorflow/tfjs'
import * as tfvis from '@tensorflow/tfjs-vis'
import {MnistData} from './data'

window.onload = async () => {
    const data = new MnistData()
    await data.load()
    const examples = data.nextTestBatch(20)
    const surface = tfvis.visor().surface({ name: '输入示例' })
    for (let i = 0; i < 20; i++) {
        const imageTensor = tf.tidy(() => {
            return examples.xs.slice([i, 0], [1, 784]).reshape([28, 28, 1])
        }) 

        const canvas = document.createElement('canvas')
        canvas.width = 28
        canvas.height = 28
        canvas.style = 'margin: 4px;'
        await tf.browser.toPixels(imageTensor, canvas)
        surface.drawArea.appendChild(canvas)
    }

    const model = tf.sequential();
    model.add(tf.layers.conv2d({
        inputShape: [28, 28, 1],
        kernelSize: 5,
        filters: 8,
        strides: 1,
        activation: 'relu',
        kernelInitializer: 'varianceScaling'
    }));
    model.add(tf.layers.maxPool2d({
        poolSize: [2, 2],
        strides: [2, 2]
    }));
    model.add(tf.layers.conv2d({
        kernelSize: 5,
        filters: 16,
        strides: 1,
        activation: 'relu',
        kernelInitializer: 'varianceScaling'
    }));
    model.add(tf.layers.maxPool2d({
        poolSize: [2, 2],
        strides: [2, 2]
    }));
    model.add(tf.layers.flatten());
    model.add(tf.layers.dense({
        units: 10,
        activation: 'softmax',
        kernelInitializer: 'varianceScaling'
    }));
    model.compile({
        loss: 'categoricalCrossentropy',
        optimizer: tf.train.adam(),
        metrics: ['accuracy']
    });

    const [trainXs, trainYs] = tf.tidy(() => {
        const d = data.nextTrainBatch(1000);
        return [
            d.xs.reshape([1000, 28, 28, 1]),
            d.labels
        ];
    });

    const [testXs, testYs] = tf.tidy(() => {
        const d = data.nextTestBatch(200);
        return [
            d.xs.reshape([200, 28, 28, 1]),
            d.labels
        ];
    });

    await model.fit(trainXs, trainYs, {
        validationData: [testXs, testYs],
        batchSize: 500,
        epochs: 20,
        callbacks: tfvis.show.fitCallbacks(
            { name: '训练效果' },
            ['loss', 'val_loss', 'acc', 'val_acc'],
            { callbacks: ['onEpochEnd'] }
        )
    });
}

PS:我把mnist的图片和二进制文件放到打包完的dist目录里,没有用http-server来启动静态服务器,不知道是否跟这个有关系
图片描述

写回答

1回答

lewis

2019-12-24

当然有关,需要开启静态文件服务器

0
1
Hank_桁
非常感谢!
2019-12-25
共1条回复

JavaScript玩转机器学习-Tensorflow.js项目实战

机器学习理论知识+Tensorflow.js实战开发

647 学习 · 189 问题

查看课程