关于stateful = True的几个问题

来源:7-8 文本生成实战之采样生成文本

wxz123

2020-01-04

老师针对

keras.layers.SimpleRNN(units = rnn_units,
                               stateful = True,
                               recurrent_initializer = 'glorot_uniform',
                               return_sequences = True),

的stateful = True有几个问题:
1、stateful=true时,batch的顺序不能被打乱,那么

seq_dataset = seq_dataset.shuffle(buffer_size).batch(
    batch_size, drop_remainder=True)

的.shuffle(buffer_size)是否应该去掉
2、

history = model.fit(seq_dataset, epochs = epochs,
                    callbacks = [checkpoint_callback])

fit函数是不是应该加上shuffle=False参数
3、在预测阶段

model.reset_states()
    temperature = 2
    for _ in range(num_generate):
        predictions = model(input_eval)
        predictions = predictions / temperature
        predictions = tf.squeeze(predictions, 0)
        predicted_id = tf.random.categorical(
            predictions, num_samples = 1)[-1, 0].numpy()
        text_generated.append(idx2char[predicted_id])
        input_eval = tf.expand_dims([predicted_id], 0)

由于batch_size=1,timestep=1,对于循环for _ in range(num_generate)中每次进行predictions = model(input_eval)前向传播得到的状态之所以能传到下一次循环作为初始状态,是否是依赖于stateful = True这个机制呢,若stateful = False,则循环中每次调用predictions = model(input_eval)是都将自动调用model.reset_states()将状态初始化,而和上一次循环完全无关,不知道这么理解的是否正确

写回答

1回答

正十七

2020-01-09

同学你好,关于1和2,你的理解是对的,数据不应该shuffle: https://stackoverflow.com/questions/44788946/shuffling-training-data-with-lstm-rnn

关于问题3, 在inference中因为我们执行的是单步,所以不会收到stateful的影响。

0
0

Google老师亲授 TensorFlow2.0 入门到进阶

Tensorflow2.0实战—以实战促理论的方式学习深度学习

1849 学习 · 896 问题

查看课程