triplet.py 5.2 KB
Newer Older
B
Bin Lu 已提交
1 2 3 4 5 6 7
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import paddle
import paddle.nn as nn

D
dongshuilong 已提交
8

B
Bin Lu 已提交
9 10 11 12 13
class TripletLossV2(nn.Layer):
    """Triplet loss with batch-hard positive/negative mining.

    For every anchor in the batch, the farthest sample sharing its label
    (hardest positive) and the closest sample with a different label
    (hardest negative) are selected. A margin ranking loss is then applied
    to that pair of distances.

    Args:
        margin (float): margin for triplet.
        normalize_feature (bool): if True, L2-normalize each feature vector
            before computing pairwise distances.
    """

    def __init__(self, margin=0.5, normalize_feature=True):
        super(TripletLossV2, self).__init__()
        self.margin = margin
        self.ranking_loss = paddle.nn.loss.MarginRankingLoss(margin=margin)
        self.normalize_feature = normalize_feature

    def forward(self, input, target):
        """
        Args:
            input (dict): must contain "features", a feature matrix with
                shape (batch_size, feat_dim).
            target: ground-truth labels, one per sample in the batch.

        Returns:
            dict: {"TripletLossV2": loss}, where loss is the output of
            MarginRankingLoss over the mined (dist_an, dist_ap) pairs.
        """
        inputs = input["features"]

        if self.normalize_feature:
            # Row-wise L2 normalization; the epsilon guards zero vectors.
            inputs = 1. * inputs / (paddle.expand_as(
                paddle.norm(
                    inputs, p=2, axis=-1, keepdim=True), inputs) + 1e-12)

        bs = inputs.shape[0]

        # Pairwise Euclidean distance via ||a||^2 + ||b||^2 - 2 * a.b
        dist = paddle.pow(inputs, 2).sum(axis=1, keepdim=True).expand([bs, bs])
        dist = dist + dist.t()
        dist = paddle.addmm(
            input=dist, x=inputs, y=inputs.t(), alpha=-2.0, beta=1.0)
        # Clip before sqrt: rounding can make the diagonal slightly negative,
        # and sqrt(0) has an infinite gradient.
        dist = paddle.clip(dist, min=1e-12).sqrt()

        # is_pos[i][j] / is_neg[i][j]: whether samples i and j share a label.
        labels = paddle.expand(target, (bs, bs))
        is_pos = labels.equal(labels.t())
        is_neg = labels.not_equal(labels.t())

        # Hard mining with per-row masks instead of masked_select + reshape.
        # The reshape only kept rows aligned when every anchor had the same
        # number of positives/negatives. Masked-out entries become -inf /
        # +inf so they can never be selected by max / min.
        neg_inf = paddle.full_like(dist, float("-inf"))
        pos_inf = paddle.full_like(dist, float("inf"))
        # `dist_ap`: distance(anchor, hardest positive), shape [N].
        # Never empty: each anchor is its own positive.
        dist_ap = paddle.max(paddle.where(is_pos, dist, neg_inf), axis=1)
        # `dist_an`: distance(anchor, hardest negative), shape [N].
        # A row without negatives yields +inf, which makes its hinge term 0.
        dist_an = paddle.min(paddle.where(is_neg, dist, pos_inf), axis=1)

        # Compute ranking hinge loss
        y = paddle.ones_like(dist_an)
        loss = self.ranking_loss(dist_an, dist_ap, y)
        return {"TripletLossV2": loss}


class TripletLoss(nn.Layer):
    """Triplet loss with hard positive/negative mining.
    Reference:
    Hermans et al. In Defense of the Triplet Loss for Person Re-Identification. arXiv:1703.07737.
    Code imported from https://github.com/Cysu/open-reid/blob/master/reid/loss/triplet.py.
    Args:
        margin (float): margin for triplet.
    """

    def __init__(self, margin=1.0):
        super(TripletLoss, self).__init__()
        self.margin = margin
        self.ranking_loss = paddle.nn.loss.MarginRankingLoss(margin=margin)

    def forward(self, input, target):
        """
        Args:
            input (dict): must contain "features", a feature matrix with
                shape (batch_size, feat_dim).
            target: ground-truth labels, one per sample in the batch.

        Returns:
            dict: {"TripletLoss": loss}, where loss is the output of
            MarginRankingLoss over the mined (dist_an, dist_ap) pairs.
        """
        inputs = input["features"]

        bs = inputs.shape[0]
        # Pairwise Euclidean distance via ||a||^2 + ||b||^2 - 2 * a.b
        dist = paddle.pow(inputs, 2).sum(axis=1, keepdim=True).expand([bs, bs])
        dist = dist + dist.t()
        dist = paddle.addmm(
            input=dist, x=inputs, y=inputs.t(), alpha=-2.0, beta=1.0)
        # Clip before sqrt for numerical stability.
        dist = paddle.clip(dist, min=1e-12).sqrt()

        # mask[i][j] is True when samples i and j share a label.
        mask = paddle.equal(
            target.expand([bs, bs]), target.expand([bs, bs]).t())

        # Vectorized hard mining. This replaces a per-element Python loop that
        # synced the mask to the host via .numpy(). Excluded entries are
        # filled with -inf / +inf so they never win the max / min.
        # Hardest positive: never empty, since each anchor matches itself.
        dist_ap = paddle.max(
            paddle.where(mask, dist, paddle.full_like(dist, float("-inf"))),
            axis=1)
        # Hardest negative: a row with no negatives yields +inf, so its hinge
        # term is 0. The loop version raised AttributeError in that case.
        dist_an = paddle.min(
            paddle.where(mask, paddle.full_like(dist, float("inf")), dist),
            axis=1)

        # Compute ranking hinge loss
        y = paddle.ones_like(dist_an)
        loss = self.ranking_loss(dist_an, dist_ap, y)
        return {"TripletLoss": loss}