mask是干什么用的?

来源:8-19 计算图构建-rnn结构实现、损失函数与训练算子实现

qq_书山压力大EE_0

2019-02-27

with tf.variable_scope('loss'):
        sentence_flatten = tf.reshape(sentence, [-1])
        mask_flatten = tf.reshape(mask, [-1])
        mask_sum = tf.reduce_sum(mask_flatten)
        softmax_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits, labels=sentence_flatten)
        weighted_softmax_loss = tf.multiply(softmax_loss,
                                            tf.cast(mask_flatten, tf.float32))
        prediction = tf.argmax(logits, 1, output_type = tf.int32)
        correct_prediction = tf.equal(prediction, sentence_flatten)
        correct_prediction_with_mask = tf.multiply(
            tf.cast(correct_prediction, tf.float32),
            mask_flatten)
        accuracy = tf.reduce_sum(correct_prediction_with_mask) / mask_sum
        loss = tf.reduce_sum(weighted_softmax_loss) / mask_sum
        tf.summary.scalar('loss', loss)

这里loss函数 是 谁- 谁 的损失值? 其中 mask 做什么用的? 我不明白, 能不能麻烦老师 给我讲讲

写回答

1回答

正十七

2019-02-27

因为我们在输入端做了对齐,所以很多样本的末尾其实不是真实数据而是补全的默认值,这些默认值参与训练的话会产生干扰。mask就是排除这个干扰的。比如,每个样本的长度为10,假设样本为:

[1,2,3,4,5,6,-1,-1,-1,-1],其中1-6代表的都是正常词语,-1代表的默认补全值。那么假设这个序列经过lstm后在各个位置上得到的输出值与真实值的损失为[1.5, 1.6, 1.7, 1.8, 1.9, 2.0, 3.0, 3.0, 3.0, 3.0]

假设没有mask,那么总的损失函数就是上面那个向量的平均值,这样,四个-1也参与了训练。那么加上一个[1,1,1,1,1,1,0,0,0,0]的mask,让这个mask和损失的那个向量去做点积。得到

[1.5, 1.6, 1.7, 1.8, 1.9, 2.0, 0,0,0,0], 然后只让1.5~2.0这6个数去做平均,得到的值就和那四个-1没有关系了,也就是那四个-1没有残余训练。

2
1
qq_书山压力大EE_0
非常感谢!
2019-02-27
共1条回复

深度学习之神经网络(CNN/RNN/GAN)算法原理+实战

深度学习算法工程师必学,深入理解深度学习核心算法CNN RNN GAN

2617 学习 · 935 问题

查看课程