# npairsloss.py — N-pairs metric-learning loss (PaddlePaddle implementation).
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle

class NpairsLoss(paddle.nn.Layer):
    """N-pairs metric-learning loss with L2 embedding regularization.

    Expects the batch to be ordered as interleaved (anchor, positive)
    pairs, i.e. rows [a0, p0, a1, p1, ...], so batch_size must be even.
    """

    def __init__(self, reg_lambda=0.01):
        """
        Args:
            reg_lambda (float): weight of the L2 regularization term
                applied to the raw (unnormalized) embeddings.
        """
        super(NpairsLoss, self).__init__()
        self.reg_lambda = reg_lambda
        # Build the criterion once instead of on every forward call;
        # reduction is 'mean' by default.
        self.celoss = paddle.nn.CrossEntropyLoss()

    def forward(self, input, target=None):
        """Compute the n-pairs loss.

        Args:
            input (dict): must contain key "features", a float tensor of
                shape [batch_size, feature_dim] holding interleaved
                anchor/positive embeddings.
            target: unused; kept for loss-interface compatibility.

        Returns:
            dict: {"npairsloss": scalar loss tensor}.

        Raises:
            ValueError: if the batch size is odd (cannot form pairs).
        """
        features = input["features"]
        batch_size = features.shape[0]
        fea_dim = features.shape[1]
        if batch_size % 2 != 0:
            raise ValueError(
                "NpairsLoss expects an even batch size (anchor/positive "
                "pairs), got {}".format(batch_size))
        num_class = batch_size // 2

        # [B, D] -> [B//2, 2, D]; column 0 holds anchors, column 1 positives.
        out_feas = paddle.reshape(features, shape=[-1, 2, fea_dim])
        anc_feas, pos_feas = paddle.split(out_feas, num_or_sections=2, axis=1)
        anc_feas = paddle.squeeze(anc_feas, axis=1)
        pos_feas = paddle.squeeze(pos_feas, axis=1)

        # Similarity of every anchor against every positive; the matching
        # positive of anchor i sits on the diagonal, so row i's label is i.
        similarity_matrix = paddle.matmul(anc_feas, pos_feas, transpose_y=True)
        sparse_labels = paddle.arange(0, num_class, dtype='int64')
        xentloss = self.celoss(similarity_matrix, sparse_labels)

        # L2 regularization on all embeddings (anchors and positives).
        reg = paddle.mean(paddle.sum(paddle.square(features), axis=1))
        l2loss = 0.5 * self.reg_lambda * reg
        return {"npairsloss": xentloss + l2loss}
    
if __name__ == "__main__":

    import numpy as np

    metric = NpairsLoss()

    # Prepare interleaved (anchor, positive) pair data; cast to float32,
    # paddle's default floating dtype.
    np.random.seed(1)
    features = np.random.randn(160, 32).astype("float32")

    # forward() expects a dict with key "features", not a bare tensor —
    # passing the tensor directly would fail on input["features"].
    features = paddle.to_tensor(features)
    loss = metric({"features": features})
    print(loss)