RNN为什么能解决输入不定长的问题

来源:7-1 序列式问题

慕瓜7596423

2018-11-02

老师好,我看了RNN这几节课,但还是不明白为什么它可以解决不定长问题。假设对于一个文本分类问题,如果用mini batch,那么对于一个文本就是一个样本,一个batch就是多个样本,这些样本的输入长度不同(因为每个文本词语数量不同),那我该怎么训练模型呢?

写回答

1回答

正十七

2018-11-07

理论层面,RNN的网络结构决定了它可以接受序列输入得到输出。

而在实现层面,可以有两种方式实现处理变长模型,因为是batch输入,所以batch中的样本仍然需要你做padding或者裁剪形成等长的形式。但是对于padding的样本,其实是可以控制它们输入到Lstm中的长度的。

https://www.tensorflow.org/api_docs/python/tf/nn/dynamic_rnn

这个API中有一个length的输入就是做这个的。比如你的batch中有三个样本:

样本1:[5, 3, 6, 3, 5, -1, -1, -1, -1, -1]

样本2:[4, 3, 5, 2, 7, 9, 9, 9, 0, -1]

样本3:[5, 4, 6, 7, 2, 8, 3, 7, 6, 0]

那么对应的length就是[5, 9, 10],这样Lstm就知道你的batch内的每个样本的输入长度是多少了。


另外一种方法是加mask,同样是对于上面三个样本,你可以做一个这样的mask

[1,1,1,1,1,0,0,0,0,0]

[1,1,1,1,1,1,1,1,1,0]

[1,1,1,1,1,1,1,1,1,1]

然后将这个mask矩阵乘到损失函数上去。这样,padding部分的内容就不会有梯度传下来。也能达到处理变长的效果。


1
0

深度学习之神经网络(CNN/RNN/GAN)算法原理+实战

深度学习算法工程师必学,深入理解深度学习核心算法CNN RNN GAN

2617 学习 · 935 问题

查看课程