缩放点积注意力计算拆分embedding dim问题
来源:10-23 DecoderLayer实现
qq_慕前端4252840
2021-08-13
老师,我注意到在计算缩放点积注意力得时候,是把embeding dim拆分成numhead,depth,再交换用seq_len,depth去做注意力的计算。我想问一下这样做得依据是什么呢?为什么是有效的?好像所有的transformer介绍里都不讲为什么要拆分。
如果我不做拆分,直接用seq_len,embedding dim去做计算,是不是就是单头注意力。
写回答
1回答
-
正十七
2021-08-19
embedding dim拆分是为了不增加参数量,相当于每个head计算注意力的时候size都是depth。你理解的对,如果直接用embedding dim去计算,就是单头。
而为什么拆分会有效?则是一个比较open的问题。我的看法是,多头注意力提供了更多样的注意力组合。比如在翻译问题上,source输入是 A B C D E, target输出是U V W X Y。 那么对于W来说,第一个头可能主要倾向attend A,第二个头则是attend A和B,第三个头主要attend C。以此类推。当输入输出序列都较长的时候,这种组合能捕捉到更多的信息。
当然,还有很多其他的讨论,可以参考这里:https://www.zhihu.com/question/341222779
00
相似问题