RNN为什么能解决输入不定长的问题
来源:7-1 序列式问题
慕瓜7596423
2018-11-02
老师好,我看了RNN这几节课,但还是不明白为什么它可以解决不定长问题。假设对于一个文本分类问题,如果用mini batch,那么对于一个文本就是一个样本,一个batch就是多个样本,这些样本的输入长度不同(因为每个文本词语数量不同),那我该怎么训练模型呢?
1回答
-
正十七
2018-11-07
理论层面,RNN的网络结构决定了它可以接受序列输入得到输出。
而在实现层面,可以有两种方式实现处理变长模型,因为是batch输入,所以batch中的样本仍然需要你做padding或者裁剪形成等长的形式。但是对于padding的样本,其实是可以控制它们输入到Lstm中的长度的。
https://www.tensorflow.org/api_docs/python/tf/nn/dynamic_rnn
这个API中有一个length的输入就是做这个的。比如你的batch中有三个样本:
样本1:[5, 3, 6, 3, 5, -1, -1, -1, -1, -1]
样本2:[4, 3, 5, 2, 7, 9, 9, 9, 0, -1]
样本3:[5, 4, 6, 7, 2, 8, 3, 7, 6, 0]
那么对应的length就是[5, 9, 10],这样Lstm就知道你的batch内的每个样本的输入长度是多少了。
另外一种方法是加mask,同样是对于上面三个样本,你可以做一个这样的mask
[1,1,1,1,1,0,0,0,0,0]
[1,1,1,1,1,1,1,1,1,0]
[1,1,1,1,1,1,1,1,1,1]
然后将这个mask矩阵乘到损失函数上去。这样,padding部分的内容就不会有梯度传下来。也能达到处理变长的效果。
10