MTCNN模型的交叉熵损失函数
Created by: yeyupiaoling
环境
- PaddlePaddle
- Ubuntu 16.04
- Python 3.5
问题
我想用PaddlePaddle复现MTCNN模型的交叉熵损失函数,以下的tensorflow的代码
def cls_ohem(cls_prob, label):
'''计算类别损失
参数:
cls_prob:预测类别,是否有人
label:真实值
返回值:
损失
'''
zeros = tf.zeros_like(label)
# 只把pos的label设定为1,其余都为0
label_filter_invalid = tf.where(tf.less(label, 0), zeros, label)
# 类别size[2*batch]
num_cls_prob = tf.size(cls_prob)
cls_prob_reshpae = tf.reshape(cls_prob, [num_cls_prob, -1])
label_int = tf.cast(label_filter_invalid, tf.int32)
# 获取batch数
num_row = tf.to_int32(cls_prob.get_shape()[0])
# 对应某一batch而言,batch*2为非人类别概率,batch*2+1为人概率类别,indices为对应 cls_prob_reshpae
# 应该的真实值,后续用交叉熵计算损失
row = tf.range(num_row) * 2
indices_ = row + label_int
# 真实标签对应的概率
label_prob = tf.squeeze(tf.gather(cls_prob_reshpae, indices_))
loss = -tf.log(label_prob + 1e-10)
zeros = tf.zeros_like(label_prob, dtype=tf.float32)
ones = tf.ones_like(label_prob, dtype=tf.float32)
# 统计neg和pos的数量
valid_inds = tf.where(label < zeros, zeros, ones)
num_valid = tf.reduce_sum(valid_inds)
# 选取70%的数据
keep_num = tf.cast(num_valid * num_keep_radio, dtype=tf.int32)
# 只选取neg,pos的70%损失
loss = loss * valid_inds
loss, _ = tf.nn.top_k(loss, k=keep_num)
return tf.reduce_mean(loss)
我用PaddlePaddle大概复现了一下,逻辑也不知道对不对的。
def cls_ohem(cls_prob, label):
'''计算类别损失
参数:
cls_prob:预测类别,是否有人
label:真实值
返回值:
损失
'''
# 只把pos的label设定为1,其余都为0
zeros = fluid.layers.zeros(shape=[int(i) for i in label.shape], dtype=label.dtype)
cond = fluid.layers.less_than(x=label, y=fluid.layers.fill_constant(shape=[],dtype='float32', value=0))
ie = fluid.layers.IfElse(cond)
with ie.true_block():
ie.output(zeros)
with ie.false_block():
ie.output(label)
label_filter_invalid = ie()
# 类别size[2*batch]
num_cls_prob = sum(cls_prob.shape)
cls_prob_reshpae = fluid.layers.reshape(cls_prob, [num_cls_prob, -1])
label_int = fluid.layers.cast(label_filter_invalid[0], dtype='int32')
# 获取batch数
num_row = fluid.layers.fill_constant(shape=[],dtype='int32', value=cls_prob.shape[0])
# 对应某一batch而言,batch*2为非人类别概率,batch*2+1为人概率类别,indices为对应 cls_prob_reshpae
# 应该的真实值,后续用交叉熵计算损失
indices_ = []
i = fluid.layers.fill_constant(shape=[], dtype='int32', value=0)
cond = fluid.layers.less_than(x=i, y=num_row)
print(cond)
while_op = fluid.layers.While(cond=cond)
with while_op.block():
indices_.append(i + label_int)
i = i + 1
fluid.layers.less_than(x=i, y=num_row, cond=cond)
# 真实标签对应的概率
indices_ = fluid.layers.cast(indices_, dtype='int32')
label_prob = fluid.layers.squeeze(fluid.layers.gather(cls_prob_reshpae, indices_))
loss = -fluid.layers.log(label_prob + 1e-10)
zeros = fluid.layers.zeros(shape=[int(i) for i in label_prob.shape], dtype='float32')
ones = fluid.layers.ones(shape=[int(i) for i in label_prob.shape], dtype='float32')
# 统计neg和pos的数量
cond2 = fluid.layers.less_than(x=label, y=zeros)
ie2 = fluid.layers.IfElse(cond2)
with ie2.true_block():
ie2.output(zeros)
with ie2.false_block():
ie2.output(ones)
valid_inds = ie2()
num_valid = fluid.layers.reduce_sum(valid_inds)
# 选取70%的数据
keep_num = fluid.layers.cast(num_valid * num_keep_radio, dtype='int32')
# 只选取neg,pos的70%损失
loss = loss * valid_inds
loss, _ = fluid.layers.nn.topk(loss, k=keep_num)
return fluid.layers.reduce_mean(loss)
在使用的时候报错:
Traceback (most recent call last):
File "/media/test/5C283BCA283BA1C6/yeyupiaoling/PyCharm/PaddlePaddle-MTCNN/train_PNet/train_PNet.py", line 12, in <module>
image, label, bbox_target, landmark_target, label_cost, bbox_loss, landmark_loss, conv4_1, conv4_2, conv4_3 = P_Net()
File "/media/test/5C283BCA283BA1C6/yeyupiaoling/PyCharm/PaddlePaddle-MTCNN/train_PNet/model.py", line 67, in P_Net
label_cost = cls_ohem(cls_prob=cls_prob, label=label)
File "/media/test/5C283BCA283BA1C6/yeyupiaoling/PyCharm/PaddlePaddle-MTCNN/train_PNet/model.py", line 114, in cls_ohem
while_op = fluid.layers.While(cond=cond)
File "/usr/local/lib/python3.5/dist-packages/paddle/fluid/layers/control_flow.py", line 693, in __init__
raise TypeError("condition should be a bool scalar")
TypeError: condition should be a bool scalar
恳请各位老师帮忙复现一下,谢谢。