embed_input 进行reshape疑惑

来源:7-17 LSTM单元内部结构实现

慕数据4013138

2019-07-26

卢老师您好,在手写lstm的时候,
embed_input为什么需要进行reshape?

            embed_input = tf.reshape(embed_input,   #去掉1这个维度,变为二维矩阵  ?
                                     [batch_size, hps.num_embedding_size])

可是为什么需要变为二维矩阵?这个变为矩阵后,拿到的是每个batch 中,当前timestep=t时的输入值?

写回答

1回答

正十七

2019-07-31

变成二维矩阵的原因是做了embed_input = embed_inputs[:, i, :] 之后,embed_input的大小就变成了[batch_size, 1, num_embedding_size], 变成矩阵就把中间的维度给消掉了。

不是变为矩阵后,而是变为矩阵前,拿到的就已经是单步的数据了。

0
0

深度学习之神经网络(CNN/RNN/GAN)算法原理+实战

深度学习算法工程师必学,深入理解深度学习核心算法CNN RNN GAN

2617 学习 · 935 问题

查看课程