from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import paddle
import paddle.nn as nn

class TripletLossV2(nn.Layer):
    """Triplet loss with hard positive/negative mining.
    Args:
        margin (float): margin for triplet.
    """
    def __init__(self, margin=0.5):
        super(TripletLossV2, self).__init__()
        self.margin = margin
        self.ranking_loss = paddle.nn.loss.MarginRankingLoss(margin=margin)

    def forward(self, input, target, normalize_feature=True):
        """
        Args:
            input (dict): dict holding the feature matrix under the key
                "features", with shape (batch_size, feat_dim)
            target: ground truth labels with shape (batch_size, )
            normalize_feature (bool): whether to L2-normalize the features
                before computing pairwise distances
        """
        inputs = input["features"]

        if normalize_feature:
            inputs = 1. * inputs / (paddle.expand_as(
                paddle.norm(inputs, p=2, axis=-1, keepdim=True), inputs) +
                                    1e-12)

        bs = inputs.shape[0]

        # compute distance
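        # squared Euclidean distances via the expansion
        #   ||x_i - x_j||^2 = ||x_i||^2 + ||x_j||^2 - 2 * x_i . x_j
        # the two expand/transpose lines build the first two terms; addmm with
        # beta=1.0, alpha=-2.0 folds in the -2 * inputs @ inputs.T term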
        dist = paddle.pow(inputs, 2).sum(axis=1, keepdim=True).expand([bs, bs])
        dist = dist + dist.t()
        dist = paddle.addmm(input=dist,
                            x=inputs,
                            y=inputs.t(),
                            alpha=-2.0,
                            beta=1.0)
        dist = paddle.clip(dist, min=1e-12).sqrt()
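        # clipping before sqrt keeps gradients finite when a squared distance
        # underflows to ~0 (e.g. an anchor paired with itself)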

        # hard negative mining
        is_pos = paddle.expand(target, (bs, bs)).equal(
            paddle.expand(target, (bs, bs)).t())
        is_neg = paddle.expand(target, (bs, bs)).not_equal(
            paddle.expand(target, (bs, bs)).t())
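        # is_pos[i][j] is True when samples i and j share a label (the anchor
        # itself included); is_neg marks every cross-identity pair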

        # `dist_ap` means distance(anchor, hardest positive), shape [N, 1]
        dist_ap = paddle.max(paddle.reshape(paddle.masked_select(dist, is_pos),
                                            (bs, -1)),
                             axis=1,
                             keepdim=True)
        # `dist_an` means distance(anchor, hardest negative), shape [N, 1]
        dist_an = paddle.min(paddle.reshape(paddle.masked_select(dist, is_neg),
                                            (bs, -1)),
                             axis=1,
                             keepdim=True)
        # shape [N]
        dist_ap = paddle.squeeze(dist_ap, axis=1)
        dist_an = paddle.squeeze(dist_an, axis=1)

        # Compute ranking hinge loss
        y = paddle.ones_like(dist_an)
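        # with label y = 1, MarginRankingLoss reduces to the triplet hinge:
        # loss = mean(max(0, dist_ap - dist_an + margin))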
        loss = self.ranking_loss(dist_an, dist_ap, y)
        return {"TripletLossV2": loss}


class TripletLoss(nn.Layer):
    """Triplet loss with hard positive/negative mining.
    Reference:
    Hermans et al. In Defense of the Triplet Loss for Person Re-Identification. arXiv:1703.07737.
    Code imported from https://github.com/Cysu/open-reid/blob/master/reid/loss/triplet.py.
    Args:
        margin (float): margin for triplet.
    """
    def __init__(self, margin=1.0):
        super(TripletLoss, self).__init__()
        self.margin = margin
        self.ranking_loss = paddle.nn.loss.MarginRankingLoss(margin=margin)

    def forward(self, input, target):
        """
        Args:
            input (dict): dict holding the feature matrix under the key
                "features", with shape (batch_size, feat_dim)
            target: ground truth labels with shape (batch_size, )
        """
        inputs = input["features"]

        bs = inputs.shape[0]
        # Compute the pairwise distance matrix (to be replaced by the official
        # op once it is merged)
        dist = paddle.pow(inputs, 2).sum(axis=1, keepdim=True).expand([bs, bs])
        dist = dist + dist.t()
        dist = paddle.addmm(input=dist,
                            x=inputs,
                            y=inputs.t(),
                            alpha=-2.0,
                            beta=1.0)
        dist = paddle.clip(dist, min=1e-12).sqrt()

        mask = paddle.equal(target.expand([bs, bs]),
                            target.expand([bs, bs]).t())
        mask_numpy_idx = mask.numpy()
        dist_ap, dist_an = [], []
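        # for each anchor i, take the hardest positive (largest same-label
        # distance) and the hardest negative (smallest different-label
        # distance); the numpy mask is only used for selection, so gradients
        # still flow through the chosen dist[i][j] entries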
        for i in range(bs):
            dist_ap.append(
                max([
                    dist[i][j]
                    if mask_numpy_idx[i][j] == True else float("-inf")
                    for j in range(bs)
                ]).unsqueeze(0))
            dist_an.append(
                min([
                    dist[i][k]
                    if mask_numpy_idx[i][k] == False else float("inf")
                    for k in range(bs)
                ]).unsqueeze(0))

        dist_ap = paddle.concat(dist_ap, axis=0)
        dist_an = paddle.concat(dist_an, axis=0)

        # Compute ranking hinge loss
        y = paddle.ones_like(dist_an)
        loss = self.ranking_loss(dist_an, dist_ap, y)
        return {"TripletLoss": loss}