MTCNN人脸检测损失函数

来源:12-2 MTCNN数据打包(PNet、RNet、ONet)实操(1)

慕村4171945

2019-03-28

老师,你好,能分享一下MTCNN中损失函数的计算吗?
def cls_ohem(cls_prob, label):
zeros = tf.zeros_like(label)
#label=-1 --> label=0net_factory

#pos -> 1, neg -> 0, others -> 0
label_filter_invalid = tf.where(tf.less(label,0), zeros, label)
num_cls_prob = tf.size(cls_prob)
cls_prob_reshape = tf.reshape(cls_prob,[num_cls_prob,-1])
label_int = tf.cast(label_filter_invalid,tf.int32)
# get the number of rows of class_prob
num_row = tf.to_int32(cls_prob.get_shape()[0])
#row = [0,2,4.....]
row = tf.range(num_row)*2
indices_ = row + label_int
label_prob = tf.squeeze(tf.gather(cls_prob_reshape, indices_))
loss = -tf.log(label_prob+1e-10)
zeros = tf.zeros_like(label_prob, dtype=tf.float32)
ones = tf.ones_like(label_prob,dtype=tf.float32)
# set pos and neg to be 1, rest to be 0
valid_inds = tf.where(label < zeros,zeros,ones)
# get the number of POS and NEG examples
num_valid = tf.reduce_sum(valid_inds)

keep_num = tf.cast(num_valid*num_keep_radio,dtype=tf.int32)
#FILTER OUT PART AND LANDMARK DATA
loss = loss * valid_inds
loss,_ = tf.nn.top_k(loss, k=keep_num)
return tf.reduce_mean(loss)

这段计算正负样本的损失函数有点不理解,请你解释一下哈

写回答

1回答

会写代码的好厨师

2019-09-09

loss = -tf.log(label_prob+1e-10)这一行是在全部计算loss

valid_inds = tf.where(label < zeros,zeros,ones) 过滤掉label<0的样本

tf.nn.top_k(loss, k=keep_num)这里是选择top-k的loss来作为实际计算的loss,也就是体现OHEM的部分.

0
0

深度学习之目标检测常用算法原理+实践精讲

从原理到场景实战,掌握目标检测核心技术

878 学习 · 221 问题

查看课程