未验证 提交 08e75731 编写于 作者: J jerrywgz 提交者: GitHub

Merge pull request #16145 from ceci3/npair_loss0

fix npair loss
...@@ -10704,8 +10704,9 @@ def npair_loss(anchor, positive, labels, l2_reg=0.002): ...@@ -10704,8 +10704,9 @@ def npair_loss(anchor, positive, labels, l2_reg=0.002):
similarity_matrix = matmul( similarity_matrix = matmul(
anchor, positive, transpose_x=False, transpose_y=True) anchor, positive, transpose_x=False, transpose_y=True)
softmax_value = softmax(similarity_matrix) softmax_ce = softmax_with_cross_entropy(
cross_entropy = -1 * reduce_sum(labels * log(softmax_value), 0) logits=similarity_matrix, label=labels, soft_label=True)
cross_entropy = reduce_sum(labels * softmax_ce, 0)
celoss = reduce_mean(cross_entropy) celoss = reduce_mean(cross_entropy)
return l2loss + celoss return l2loss + celoss
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册