mask是干什么用的?
来源:8-19 计算图构建-rnn结构实现、损失函数与训练算子实现
qq_书山压力大EE_0
2019-02-27
with tf.variable_scope('loss'):
sentence_flatten = tf.reshape(sentence, [-1])
mask_flatten = tf.reshape(mask, [-1])
mask_sum = tf.reduce_sum(mask_flatten)
softmax_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
logits=logits, labels=sentence_flatten)
weighted_softmax_loss = tf.multiply(softmax_loss,
tf.cast(mask_flatten, tf.float32))
prediction = tf.argmax(logits, 1, output_type = tf.int32)
correct_prediction = tf.equal(prediction, sentence_flatten)
correct_prediction_with_mask = tf.multiply(
tf.cast(correct_prediction, tf.float32),
mask_flatten)
accuracy = tf.reduce_sum(correct_prediction_with_mask) / mask_sum
loss = tf.reduce_sum(weighted_softmax_loss) / mask_sum
tf.summary.scalar('loss', loss)
这里loss函数 是 谁- 谁 的损失值? 其中 mask 做什么用的? 我不明白, 能不能麻烦老师 给我讲讲
写回答
1回答
-
因为我们在输入端做了对齐,所以很多样本的末尾其实不是真实数据而是补全的默认值,这些默认值参与训练的话会产生干扰。mask就是排除这个干扰的。比如,每个样本的长度为10,假设样本为:
[1,2,3,4,5,6,-1,-1,-1,-1],其中1-6代表的都是正常词语,-1代表的默认补全值。那么假设这个序列经过lstm后在各个位置上得到的输出值与真实值的损失为[1.5, 1.6, 1.7, 1.8, 1.9, 2.0, 3.0, 3.0, 3.0, 3.0]
假设没有mask,那么总的损失函数就是上面那个向量的平均值,这样,四个-1也参与了训练。那么加上一个[1,1,1,1,1,1,0,0,0,0]的mask,让这个mask和损失的那个向量去做点积。得到
[1.5, 1.6, 1.7, 1.8, 1.9, 2.0, 0,0,0,0], 然后只让1.5~2.0这6个数去做平均,得到的值就和那四个-1没有关系了,也就是那四个-1没有残余训练。
212019-02-27
相似问题