tf.keras.layers.Embedding的batch_input_shape和input_length参数

来源:7-7 文本生成实战之构建模型

wxz123

2019-12-23

图片描述
图片描述
问题1:老师这个tf.keras.layers.Embedding的batch_input_shape这个参数在tf2.0的官方文档中并没有提及,这是怎么回事呢(难道是官方文档不全吗)
问题2:还有batch_input_shape和之前用过的input_length不都是控制输入的shape(一个是[batch_size,None],另一个是timestep),那为什么一个函数要设置两个一样功能的参数呢?

写回答

1回答

正十七

2020-02-05

同学你好,抱歉这个问题耽误了很长时间。通过查看源码,我发现对于这个batch_input_shape参数,它是这样的,虽然embedding这个layer的参数中没有这个参数,但是可以通过**kwargs参数将其传进去。

https://github.com/tensorflow/tensorflow/blob/r2.1/tensorflow/python/keras/layers/embeddings.py#L91-L123

而传进去后,这个参数继续传递给embedding的父类Layer,https://github.com/tensorflow/tensorflow/blob/r2.1/tensorflow/python/keras/engine/base_layer.py

在这个Layer中才被使用去设置input_shape.


而对于问题2,两个参数确实有重叠,batch_input_shape更全,可以把batch_size也设置了。input_length只能设置timestamp长度。正因为有重叠所以batch_input_shape被放在了**kwargs里而不是显示为人所见。

推荐使用官方文档里显示出现的参数。而至于batch_size,应该可以像其他程序那样在其他地方设置。

1
0

Google老师亲授 TensorFlow2.0 入门到进阶

Tensorflow2.0实战—以实战促理论的方式学习深度学习

1849 学习 · 896 问题

查看课程