# npairsloss.py
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle


class NpairsLoss(paddle.nn.Layer):
    """Npair_loss_
    paper [Improved deep metric learning with multi-class N-pair loss objective](https://dl.acm.org/doi/10.5555/3157096.3157304)
    code reference: https://www.tensorflow.org/versions/r1.15/api_docs/python/tf/contrib/losses/metric_learning/npairs_loss
    """

    def __init__(self, reg_lambda=0.01):
        super(NpairsLoss, self).__init__()
        self.reg_lambda = reg_lambda

    def forward(self, input, target=None):
        """
        anchor and positive(should include label)
        """
        features = input["features"]
        reg_lambda = self.reg_lambda
        batch_size = features.shape[0]
        fea_dim = features.shape[1]
        num_class = batch_size // 2

        # reshape to (num_pairs, 2, fea_dim): each row holds one anchor/positive pair
        out_feas = paddle.reshape(features, shape=[-1, 2, fea_dim])
        anc_feas, pos_feas = paddle.split(out_feas, num_or_sections=2, axis=1)
        anc_feas = paddle.squeeze(anc_feas, axis=1)
        pos_feas = paddle.squeeze(pos_feas, axis=1)

        # similarity matrix: dot products between every anchor and every positive
        similarity_matrix = paddle.matmul(anc_feas, pos_feas, transpose_y=True)
        # label i marks the i-th positive as the match for the i-th anchor
        sparse_labels = paddle.arange(0, num_class, dtype='int64')
        xentloss = paddle.nn.CrossEntropyLoss()(
            similarity_matrix, sparse_labels)  # reduction is 'mean' by default

        # L2 regularization on the embedding norms
        reg = paddle.mean(paddle.sum(paddle.square(features), axis=1))
        l2loss = 0.5 * reg_lambda * reg
        return {"npairsloss": xentloss + l2loss}
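

# Minimal usage sketch (an assumption for illustration, not part of the
# original file): feed a dict with a "features" tensor of 2 * num_pairs
# alternating anchor/positive embeddings and read the loss out of the
# returned dict under the "npairsloss" key.
if __name__ == "__main__":
    paddle.seed(0)
    num_pairs, fea_dim = 4, 8
    feats = paddle.randn([num_pairs * 2, fea_dim])  # 4 pairs -> batch of 8
    loss_fn = NpairsLoss(reg_lambda=0.01)
    out = loss_fn({"features": feats})
    print(out["npairsloss"])  # scalar tensor: cross-entropy + L2 term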