lstm cell实现的过程中h_size与h 含义的困惑?

来源:7-17 LSTM单元内部结构实现

慕数据4013138

2019-07-26

老师您好,真的纠结好久了,反复看课程还是很困惑

    ###对应7-17的课
    # 定义输入门
    with tf.variable_scope('inputs'):
        ix, ih, ib = _generate_params_for_lstm_cell(
            x_size = [hps.num_embedding_size, hps.num_lstm_nodes[0]], #x_size的大小为一个矩阵【embedding_size的大小,第一层lstm的隐含状态Ct维 度的大小】 
            ## 注意此node只是作者命名时没注意,应该是Ct的维度
            h_size = [hps.num_lstm_nodes[0], hps.num_lstm_nodes[0]], #???
            bias_size = [1, hps.num_lstm_nodes[0]]
        )
          #  ?中间的隐含状态C't,   
        state = tf.Variable( 
            tf.zeros([batch_size, hps.num_lstm_nodes[0]]),
            trainable = False
        )
        # ??# h 上一步的输出  
        h = tf.Variable( 
            tf.zeros([batch_size, hps.num_lstm_nodes[0]]),
            trainable = False

1)这个h_size 的定义h_size = [hps.num_lstm_nodes[0], ps.num_lstm_nodes[0]], #???是
2)h 定义的时候,也说的是上一步的输出,
3)state 对应的是中间的隐含状态大小?中间隐含状态大小不是 h_size 吗?

写回答

1回答

正十七

2019-07-31

Lstm中隐含状态和lstm的输出是两个概念,对应到代码中是state和h

画问好的h_size代表的是上一步的state做全连接计算得到新的值,新的值的size也是num_lstm_nodes[0]。你可以看代码中下面的那些门的计算就知道它的作用了。h和oh,ih,fh做矩阵乘法,其中oh,ih,fh三个参数的size是h_size.

for i in range(num_timesteps):            
            # [batch_size, 1, embed_size]
            embed_input = embed_inputs[:, i, :]
            embed_input = tf.reshape(embed_input,
                                     [batch_size, hps.num_embedding_size])
            forget_gate = tf.sigmoid(
                tf.matmul(embed_input, fx) + tf.matmul(h, fh) + fb)
            input_gate = tf.sigmoid(
                tf.matmul(embed_input, ix) + tf.matmul(h, ih) + ib)
            output_gate = tf.sigmoid(
                tf.matmul(embed_input, ox) + tf.matmul(h, oh) + ob)
            mid_state = tf.tanh(
                tf.matmul(embed_input, cx) + tf.matmul(h, ch) + cb)
            state = mid_state * input_gate + state * forget_gate
            h = output_gate * tf.tanh(state)
        last = h


1
1
慕数据4013138
明白了,谢谢老师
2019-07-31
共1条回复

深度学习之神经网络(CNN/RNN/GAN)算法原理+实战

深度学习算法工程师必学,深入理解深度学习核心算法CNN RNN GAN

2617 学习 · 935 问题

查看课程