关于stateful = True的几个问题
来源:7-8 文本生成实战之采样生成文本
wxz123
2020-01-04
老师针对
keras.layers.SimpleRNN(units = rnn_units,
stateful = True,
recurrent_initializer = 'glorot_uniform',
return_sequences = True),
的stateful = True有几个问题:
1、stateful=true时,batch的顺序不能被打乱,那么
seq_dataset = seq_dataset.shuffle(buffer_size).batch(
batch_size, drop_remainder=True)
的.shuffle(buffer_size)是否应该去掉
2、
history = model.fit(seq_dataset, epochs = epochs,
callbacks = [checkpoint_callback])
fit函数是不是应该加上shuffle=False参数
3、在预测阶段
model.reset_states()
temperature = 2
for _ in range(num_generate):
predictions = model(input_eval)
predictions = predictions / temperature
predictions = tf.squeeze(predictions, 0)
predicted_id = tf.random.categorical(
predictions, num_samples = 1)[-1, 0].numpy()
text_generated.append(idx2char[predicted_id])
input_eval = tf.expand_dims([predicted_id], 0)
由于batch_size=1,timestep=1
,对于循环for _ in range(num_generate)
中每次进行predictions = model(input_eval)
前向传播得到的状态之所以能传到下一次循环作为初始状态,是否是依赖于stateful = True
这个机制呢,若stateful = False
,则循环中每次调用predictions = model(input_eval)
是都将自动调用model.reset_states()
将状态初始化,而和上一次循环完全无关,不知道这么理解的是否正确
写回答
1回答
-
正十七
2020-01-09
同学你好,关于1和2,你的理解是对的,数据不应该shuffle: https://stackoverflow.com/questions/44788946/shuffling-training-data-with-lstm-rnn
关于问题3, 在inference中因为我们执行的是单步,所以不会收到stateful的影响。
00
相似问题