未验证 提交 f6d18678 编写于 作者: C ceci3 提交者: GitHub

test=develop

上级 6bce9861
......@@ -10567,17 +10567,17 @@ def npair_loss(anchor, positive, labels, l2_reg=0.002):
'''
**Npair Loss Layer**
see http://www.nec-labs.com/uploads/images/Department-Images/MediaAnalytics/papers/nips16_npairmetriclearning.pdf
Read `Improved Deep Metric Learning with Multi class N pair Loss Objective <http://www.nec-labs.com/uploads/images/Department-Images/MediaAnalytics/papers/nips16_npairmetriclearning.pdf>`_ .
Npair loss requires paired data. Npair loss has two parts, the first part is L2
regularizer on the embedding vector, the second part is cross entropy loss which
Npair loss requires paired data. Npair loss has two parts: the first part is L2
regularizer on the embedding vector; the second part is cross entropy loss which
takes the similarity matrix of anchor and positive as logits.
Args:
anchor(Variable): embedding vector for the anchor image. shape=[batch_size, embedding_dims]
positive(Variable): embedding vector for the positive image. shape=[batch_size, embedding_dims]
labels(Varieble): 1-D tensor. shape=[batch_size]
l2_res(float32): L2 regularization term on embedding vector, default: 0.02
labels(Variable): 1-D tensor. shape=[batch_size]
l2_reg(float32): L2 regularization term on embedding vector, default: 0.002
Returns:
npair loss(Variable): return npair loss, shape=[1]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册