@tf.function自定义train_step训练速度缓慢
来源:3-7 tf.function函数转换
OliverSong
2020-08-12
请教老师,个人训练skip-gram模型时,为什么用@tf.function加速train_step速度仅是1.5s/it?(不加速是3s/it)
模型向前计算速度正常极速,但是整个train_step极慢一个iter要1.5s。
@tf.function
def train_step(inp_w_id, inp_v_id, inp_neg_v_ids):
with tf.GradientTape() as tape:
loss = skip_gram_model(inp_w_id, inp_v_id, inp_neg_v_ids) ***# 这步模型向前计算运行时间极快,正常。***
loss_ = tf.reduce_mean(loss) #
variables = skip_gram_model.trainable_variables
gradients = tape.gradient(loss, variables)
optimizer.apply_gradients(zip(gradients, variables))
return loss_
调用整个train_step,本行就需要1.5s。
train_step(inp_w_id, inp_v_id, inp_neg_v_ids)
是啥原因呢?apply_gradients的问题吗?s/it
写回答
1回答
-
正十七
2020-08-15
慢的原因有很多,比如你的batch_size是多少,词表有多大(词表太大会明显降低速度)。
按照你的问题描述,加上@tf.function把速度从3s降到1.5s,是有提升啊。继续提升速度恐怕要上GPU?
112020-08-17
相似问题