From f62604de7213a6b9d5566710a3f4cda1f244fd3f Mon Sep 17 00:00:00 2001
From: Bin Lu
Date: Mon, 31 May 2021 14:20:48 +0800
Subject: [PATCH] Add files via upload

add losses
---
 ppcls/losses/celoss.py      | 114 ++++++++++++++++++++++++++++
 ppcls/losses/centerloss.py  |  46 ++++++++++++
 ppcls/losses/comfunc.py     |  44 +++++++++++
 ppcls/losses/emlloss.py     | 100 +++++++++++++++++++++++++
 ppcls/losses/msmloss.py     |  87 ++++++++++++++++++++++
 ppcls/losses/npairsloss.py  |  53 +++++++++++++
 ppcls/losses/trihardloss.py |  93 +++++++++++++++++++++++
 ppcls/losses/triplet.py     | 143 ++++++++++++++++++++++++++++++++++++
 8 files changed, 680 insertions(+)
 create mode 100644 ppcls/losses/celoss.py
 create mode 100644 ppcls/losses/centerloss.py
 create mode 100644 ppcls/losses/comfunc.py
 create mode 100644 ppcls/losses/emlloss.py
 create mode 100644 ppcls/losses/msmloss.py
 create mode 100644 ppcls/losses/npairsloss.py
 create mode 100644 ppcls/losses/trihardloss.py
 create mode 100644 ppcls/losses/triplet.py

diff --git a/ppcls/losses/celoss.py b/ppcls/losses/celoss.py
new file mode 100644
index 00000000..e588bb8f
--- /dev/null
+++ b/ppcls/losses/celoss.py
@@ -0,0 +1,114 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+import paddle.nn.functional as F
+
+__all__ = ['CELoss', 'JSDivLoss', 'KLDivLoss']
+
+
+class Loss(object):
+    """
+    Loss
+    """
+    def __init__(self, class_dim=1000, epsilon=None):
+        assert class_dim > 1, "class_dim=%d is not larger than 1" % (class_dim)
+        self._class_dim = class_dim
+        if epsilon is not None and epsilon >= 0.0 and epsilon <= 1.0:
+            self._epsilon = epsilon
+            self._label_smoothing = True  #use label smoothing (i.e. soft labels)
+        else:
+            self._epsilon = None
+            self._label_smoothing = False
+
+    #do label smoothing
+    def _labelsmoothing(self, target):
+        if target.shape[-1] != self._class_dim:
+            one_hot_target = F.one_hot(target, self._class_dim)  #one-hot encode, e.g. (23, 34, 46) -> 3 x _class_dim
+        else:
+            one_hot_target = target
+
+        #do label smoothing
+        soft_target = F.label_smooth(one_hot_target, epsilon=self._epsilon)  #(1 - epsilon) * input + epsilon / K
+        soft_target = paddle.reshape(soft_target, shape=[-1, self._class_dim])
+        return soft_target
+
+    def _crossentropy(self, input, target, use_pure_fp16=False):
+        if self._label_smoothing:
+            target = self._labelsmoothing(target)
+            input = -F.log_softmax(input, axis=-1)  #apply log-softmax
+            cost = paddle.sum(target * input, axis=-1)  #sum over classes
+        else:
+            cost = F.cross_entropy(input=input, label=target)
+
+        if use_pure_fp16:
+            avg_cost = paddle.sum(cost)
+        else:
+            avg_cost = paddle.mean(cost)
+        return avg_cost
+
+    def _kldiv(self, input, target, name=None):
+        eps = 1.0e-10
+        cost = target * paddle.log(
+            (target + eps) / (input + eps)) * self._class_dim
+        return cost
+
+    def _jsdiv(self, input, target):  #input and target are fc outputs; softmax has not been applied yet
+        input = F.softmax(input)
+        target = F.softmax(target)
+
+        #symmetrise the two KL divergences
+        cost = self._kldiv(input, target) + self._kldiv(target, input)
+        cost = cost / 2
+        avg_cost = paddle.mean(cost)
+        return avg_cost
+
+    def __call__(self, input, target):
+        pass
+
+
+class CELoss(Loss):
+    """
+    Cross entropy loss
+    """
+
+    def __init__(self, class_dim=1000, epsilon=None):
+        super(CELoss, self).__init__(class_dim, epsilon)
+
+    def __call__(self, input, target, use_pure_fp16=False):
+        logits = input["logits"]
+        cost = self._crossentropy(logits, target, use_pure_fp16)
+        return {"CELoss": cost}
+
+class JSDivLoss(Loss):
+    """
+    JSDiv loss
+    """
+    def __init__(self, class_dim=1000, epsilon=None):
+        super(JSDivLoss, self).__init__(class_dim, epsilon)
+
+    def __call__(self, input, target):
+        cost = self._jsdiv(input, target)
+        return cost
+
+
+class KLDivLoss(paddle.nn.Layer):
+    def __init__(self):
+        super(KLDivLoss, self).__init__()
+
+    def __call__(self, p, q, is_logit=True):
+        if is_logit:
+            p = paddle.nn.functional.softmax(p)
+            q = paddle.nn.functional.softmax(q)
+        return -(p * paddle.log(q + 1e-8)).sum(1).mean()
\ No newline at end of file
diff --git a/ppcls/losses/centerloss.py b/ppcls/losses/centerloss.py
new file mode 100644
index 00000000..759f0b11
--- /dev/null
+++ b/ppcls/losses/centerloss.py
@@ -0,0 +1,46 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+class CenterLoss(nn.Layer):
+    def __init__(self, num_classes=5013, feat_dim=2048):
+        super(CenterLoss, self).__init__()
+        self.num_classes = num_classes
+        self.feat_dim = feat_dim
+        self.centers = paddle.randn(shape=[self.num_classes, self.feat_dim]).astype("float64")  #random initial centers
+
+    def __call__(self, input, target):
+        """
+        input: network output, e.g. {"features": xxx, "logits": xxx}
+        target: image label
+        """
+        feats = input["features"]
+        labels = target
+        batch_size = feats.shape[0]
+
+        #squared norm of the features
+        dist1 = paddle.sum(paddle.square(feats), axis=1, keepdim=True)
+        dist1 = paddle.expand(dist1, [batch_size, self.num_classes])
+
+        #squared norm of the centers
+        dist2 = paddle.sum(paddle.square(self.centers), axis=1, keepdim=True)  #num_classes
+        dist2 = paddle.expand(dist2, [self.num_classes, batch_size]).astype("float64")
+        dist2 = paddle.transpose(dist2, [1, 0])
+
+        #first x * x + y * y
+        distmat = paddle.add(dist1, dist2)
+        tmp = paddle.matmul(feats, paddle.transpose(self.centers, [1, 0]))
+        distmat = distmat - 2.0 * tmp
+
+        #generate the mask
+        classes = paddle.arange(self.num_classes).astype("int64")
+        labels = paddle.expand(paddle.unsqueeze(labels, 1), (batch_size, self.num_classes))
+        mask = paddle.equal(paddle.expand(classes, [batch_size, self.num_classes]), labels).astype("float64")  #get mask
+
+        dist = paddle.multiply(distmat, mask)
+        loss = paddle.sum(paddle.clip(dist, min=1e-12, max=1e+12)) / batch_size
+
+        return {'CenterLoss': loss}
\ No newline at end of file
diff --git a/ppcls/losses/comfunc.py b/ppcls/losses/comfunc.py
new file mode 100644
index 00000000..88f61af3
--- /dev/null
+++ b/ppcls/losses/comfunc.py
@@ -0,0 +1,44 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+def rerange_index(batch_size, samples_each_class):
+    tmp = np.arange(0, batch_size * batch_size)
+    tmp = tmp.reshape(-1, batch_size)
+    rerange_index = []
+
+    for i in range(batch_size):
+        step = i // samples_each_class
+        start = step * samples_each_class
+        end = (step + 1) * samples_each_class
+
+        pos_idx = []
+        neg_idx = []
+        for j, k in enumerate(tmp[i]):
+            if j >= start and j < end:
+                if j == i:
+                    pos_idx.insert(0, k)
+                else:
+                    pos_idx.append(k)
+            else:
+                neg_idx.append(k)
+        rerange_index += (pos_idx + neg_idx)
+
+    rerange_index = np.array(rerange_index).astype(np.int32)
+    return rerange_index
\ No newline at end of file
diff --git a/ppcls/losses/emlloss.py b/ppcls/losses/emlloss.py
new file mode 100644
index 00000000..2ce93457
--- /dev/null
+++ b/ppcls/losses/emlloss.py
@@ -0,0 +1,100 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+import paddle
+import numpy as np
+from .comfunc import rerange_index
+
+class EmlLoss(paddle.nn.Layer):
+    def __init__(self, batch_size=40, samples_each_class=2):
+        super(EmlLoss, self).__init__()
+        assert (batch_size % samples_each_class == 0)
+        self.samples_each_class = samples_each_class
+        self.batch_size = batch_size
+        self.rerange_index = rerange_index(batch_size, samples_each_class)
+        self.thresh = 20.0
+        self.beta = 100000
+
+    def surrogate_function(self, beta, theta, bias):
+        x = theta * paddle.exp(bias)
+        output = paddle.log(1 + beta * x) / math.log(1 + beta)
+        return output
+
+    def surrogate_function_approximate(self, beta, theta, bias):
+        output = (paddle.log(theta) + bias + math.log(beta)) / math.log(1 + beta)
+        return output
+
+    def surrogate_function_stable(self, beta, theta, target, thresh):
+        max_gap = paddle.to_tensor(thresh, dtype='float32')
+        max_gap.stop_gradient = True
+
+        target_max = paddle.maximum(target, max_gap)
+        target_min = paddle.minimum(target, max_gap)
+
+        loss1 = self.surrogate_function(beta, theta, target_min)
+        loss2 = self.surrogate_function_approximate(beta, theta, target_max)
+        bias = self.surrogate_function(beta, theta, max_gap)
+        loss = loss1 + loss2 - bias
+        return loss
+
+    def forward(self, input, target=None):
+        features = input["features"]
+        samples_each_class = self.samples_each_class
+        batch_size = self.batch_size
+        rerange_index = self.rerange_index
+
+        #calc pairwise squared distances
+        diffs = paddle.unsqueeze(features, axis=1) - paddle.unsqueeze(features, axis=0)
+        similarity_matrix = paddle.sum(paddle.square(diffs), axis=-1)
+
+        tmp = paddle.reshape(similarity_matrix, shape=[-1, 1])
+        rerange_index = paddle.to_tensor(rerange_index)
+        tmp = paddle.gather(tmp, index=rerange_index)
+        similarity_matrix = paddle.reshape(tmp, shape=[-1, batch_size])
+
+        ignore, pos, neg = paddle.split(similarity_matrix, num_or_sections=[1,
+            samples_each_class - 1, batch_size - samples_each_class], axis=1)
+        ignore.stop_gradient = True
+
+        pos_max = paddle.max(pos, axis=1, keepdim=True)
+        pos = paddle.exp(pos - pos_max)
+        pos_mean = paddle.mean(pos, axis=1, keepdim=True)
+
+        neg_min = paddle.min(neg, axis=1, keepdim=True)
+        neg = paddle.exp(neg_min - neg)
+        neg_mean = paddle.mean(neg, axis=1, keepdim=True)
+
+        bias = pos_max - neg_min
+        theta = paddle.multiply(neg_mean, pos_mean)
+
+        loss = self.surrogate_function_stable(self.beta, theta, bias, self.thresh)
+        loss = paddle.mean(loss)
+        return {"emlloss": loss}
+
+if __name__ == "__main__":
+
+    metric = EmlLoss()
+
+    np.random.seed(1)
+    features = np.random.randn(40, 32)
+    features = paddle.to_tensor(features, dtype="float32")
+    print(features)
+
+    loss = metric({"features": features})
+    print(loss)
\ No newline at end of file
diff --git a/ppcls/losses/msmloss.py b/ppcls/losses/msmloss.py
new file mode 100644
index 00000000..efe5ddc7
--- /dev/null
+++ b/ppcls/losses/msmloss.py
@@ -0,0 +1,87 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+import paddle
+from .comfunc import rerange_index
+
+class MSMLoss(paddle.nn.Layer):
+    """
+    MSM loss, based on the triplet loss. Uses P * K samples.
+    The batch size is fixed: batch_size = P * K, but K may vary between batches.
+    Samples with the same label are grouped together.
+
+    supported_metrics = [
+        'euclidean',
+        'sqeuclidean',
+        'cityblock',
+    ]
+    Only samples_each_class = 2 is considered.
+    """
+    def __init__(self, batch_size=120, samples_each_class=2, margin=0.1):
+        super(MSMLoss, self).__init__()
+        self.margin = margin
+        self.samples_each_class = samples_each_class
+        self.batch_size = batch_size
+        self.rerange_index = rerange_index(batch_size, samples_each_class)
+
+    def forward(self, input, target=None):
+        #normalization
+        features = input["features"]
+        features = self._normalize(features)
+        samples_each_class = self.samples_each_class
+        rerange_index = paddle.to_tensor(self.rerange_index)
+
+        #calc pairwise squared distances
+        diffs = paddle.unsqueeze(features, axis=1) - paddle.unsqueeze(features, axis=0)
+        similarity_matrix = paddle.sum(paddle.square(diffs), axis=-1)
+
+        #rerange
+        tmp = paddle.reshape(similarity_matrix, shape=[-1, 1])
+        tmp = paddle.gather(tmp, index=rerange_index)
+        similarity_matrix = paddle.reshape(tmp, shape=[-1, self.batch_size])
+
+        #split
+        ignore, pos, neg = paddle.split(similarity_matrix, num_or_sections=[1,
+            samples_each_class - 1, -1], axis=1)
+        ignore.stop_gradient = True
+
+        hard_pos = paddle.max(pos)
+        hard_neg = paddle.min(neg)
+
+        loss = hard_pos + self.margin - hard_neg
+        loss = paddle.nn.ReLU()(loss)
+        return {"msmloss": loss}
+
+    def _normalize(self, input):
+        input_norm = paddle.sqrt(paddle.sum(paddle.square(input), axis=1, keepdim=True))
+        return paddle.divide(input, input_norm)
+
+if __name__ == "__main__":
+
+    import numpy as np
+    metric = MSMLoss(48)
+
+    #prepare data
+    np.random.seed(1)
+    features = np.random.randn(48, 32)
+    #print(features)
+
+    #do inference
+    features = paddle.to_tensor(features)
+    loss = metric({"features": features})
+    print(loss)
+
\ No newline at end of file
diff --git a/ppcls/losses/npairsloss.py b/ppcls/losses/npairsloss.py
new file mode 100644
index 00000000..fc20d2b8
--- /dev/null
+++ b/ppcls/losses/npairsloss.py
@@ -0,0 +1,53 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+import paddle
+
+class NpairsLoss(paddle.nn.Layer):
+
+    def __init__(self, reg_lambda=0.01):
+        super(NpairsLoss, self).__init__()
+        self.reg_lambda = reg_lambda
+
+    def forward(self, input, target=None):
+        """
+        features hold anchor and positive pairs; labels are implied by the pair order
+        """
+        features = input["features"]
+        reg_lambda = self.reg_lambda
+        batch_size = features.shape[0]
+        fea_dim = features.shape[1]
+        num_class = batch_size // 2
+
+        #reshape
+        out_feas = paddle.reshape(features, shape=[-1, 2, fea_dim])
+        anc_feas, pos_feas = paddle.split(out_feas, num_or_sections=2, axis=1)
+        anc_feas = paddle.squeeze(anc_feas, axis=1)
+        pos_feas = paddle.squeeze(pos_feas, axis=1)
+
+        #similarity matrix
+        similarity_matrix = paddle.matmul(anc_feas, pos_feas, transpose_y=True)  #get similarity matrix
+        sparse_labels = paddle.arange(0, num_class, dtype='int64')
+        xentloss = paddle.nn.CrossEntropyLoss()(similarity_matrix, sparse_labels)  #by default: mean
+
+        #l2 regularization
+        reg = paddle.mean(paddle.sum(paddle.square(features), axis=1))
+        l2loss = 0.5 * reg_lambda * reg
+        return {"npairsloss": xentloss + l2loss}
+
+if __name__ == "__main__":
+
+    import numpy as np
+    metric = NpairsLoss()
+
+    #prepare data
+    np.random.seed(1)
+    features = np.random.randn(160, 32)
+    #print(features)
+
+    #do inference
+    features = paddle.to_tensor(features)
+    loss = metric({"features": features})
+    print(loss)
+
+
diff --git a/ppcls/losses/trihardloss.py b/ppcls/losses/trihardloss.py
new file mode 100644
index 00000000..0439e020
--- /dev/null
+++ b/ppcls/losses/trihardloss.py
@@ -0,0 +1,93 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+from .comfunc import rerange_index
+
+class TriHardLoss(paddle.nn.Layer):
+    """
+    TriHard loss, based on the triplet loss. Uses P * K samples.
+    The batch size is fixed: batch_size = P * K, but K may vary between batches.
+    Samples with the same label are grouped together.
+
+    supported_metrics = [
+        'euclidean',
+        'sqeuclidean',
+        'cityblock',
+    ]
+    Only samples_each_class = 2 is considered.
+    """
+    def __init__(self, batch_size=120, samples_each_class=2, margin=0.1):
+        super(TriHardLoss, self).__init__()
+        self.margin = margin
+        self.samples_each_class = samples_each_class
+        self.batch_size = batch_size
+        self.rerange_index = rerange_index(batch_size, samples_each_class)
+
+    def forward(self, input, target=None):
+        features = input["features"]
+        assert (self.batch_size == features.shape[0])
+
+        #normalization
+        features = self._normalize(features)
+        samples_each_class = self.samples_each_class
+        rerange_index = paddle.to_tensor(self.rerange_index)
+
+        #calc pairwise squared distances
+        diffs = paddle.unsqueeze(features, axis=1) - paddle.unsqueeze(features, axis=0)
+        similarity_matrix = paddle.sum(paddle.square(diffs), axis=-1)
+
+        #rerange
+        tmp = paddle.reshape(similarity_matrix, shape=[-1, 1])
+        tmp = paddle.gather(tmp, index=rerange_index)
+        similarity_matrix = paddle.reshape(tmp, shape=[-1, self.batch_size])
+
+        #split
+        ignore, pos, neg = paddle.split(similarity_matrix, num_or_sections=[1,
+            samples_each_class - 1, -1], axis=1)
+
+        ignore.stop_gradient = True
+        hard_pos = paddle.max(pos, axis=1)
+        hard_neg = paddle.min(neg, axis=1)
+
+        loss = hard_pos + self.margin - hard_neg
+        loss = paddle.nn.ReLU()(loss)
+        loss = paddle.mean(loss)
+        return {"trihardloss": loss}
+
+    def _normalize(self, input):
+        input_norm = paddle.sqrt(paddle.sum(paddle.square(input), axis=1, keepdim=True))
+        return paddle.divide(input, input_norm)
+
+if __name__ == "__main__":
+
+    import numpy as np
+    metric = TriHardLoss(48)
+
+    #prepare data
+    np.random.seed(1)
+    features = np.random.randn(48, 32)
+    #print(features)
+
+    #do inference
+    features = paddle.to_tensor(features)
+    loss = metric({"features": features})
+    print(loss)
+
+
diff --git a/ppcls/losses/triplet.py b/ppcls/losses/triplet.py
new file mode 100644
index 00000000..a8b1a17c
--- /dev/null
+++ b/ppcls/losses/triplet.py
@@ -0,0 +1,143 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+import paddle.nn as nn
+
+class TripletLossV2(nn.Layer):
+    """Triplet loss with hard positive/negative mining.
+    Args:
+        margin (float): margin for triplet.
+    """
+    def __init__(self, margin=0.5):
+        super(TripletLossV2, self).__init__()
+        self.margin = margin
+        self.ranking_loss = paddle.nn.loss.MarginRankingLoss(margin=margin)
+
+    def forward(self, input, target, normalize_feature=True):
+        """
+        Args:
+            input: dict holding a feature matrix with shape (batch_size, feat_dim) under "features"
+            target: ground truth labels with shape (batch_size)
+        """
+        inputs = input["features"]
+
+        if normalize_feature:
+            inputs = 1. * inputs / (paddle.expand_as(
+                paddle.norm(inputs, p=2, axis=-1, keepdim=True), inputs) +
+                1e-12)
+
+        bs = inputs.shape[0]
+
+        # compute pairwise Euclidean distances
+        dist = paddle.pow(inputs, 2).sum(axis=1, keepdim=True).expand([bs, bs])
+        dist = dist + dist.t()
+        dist = paddle.addmm(input=dist,
+                            x=inputs,
+                            y=inputs.t(),
+                            alpha=-2.0,
+                            beta=1.0)
+        dist = paddle.clip(dist, min=1e-12).sqrt()
+
+        # hard positive/negative mining
+        is_pos = paddle.expand(target, (bs, bs)).equal(
+            paddle.expand(target, (bs, bs)).t())
+        is_neg = paddle.expand(target, (bs, bs)).not_equal(
+            paddle.expand(target, (bs, bs)).t())
+
+        # `dist_ap` means distance(anchor, positive)
+        # both `dist_ap` and `relative_p_inds` with shape [N, 1]
+        #print(is_pos.shape, dist.shape, type(is_pos), type(dist), paddle.reshape(paddle.masked_select(dist, is_pos), (bs, -1)))
+        '''
+        dist_ap, relative_p_inds = paddle.max(
+            paddle.reshape(dist[is_pos], (bs, -1)), axis=1, keepdim=True)
+        # `dist_an` means distance(anchor, negative)
+        # both `dist_an` and `relative_n_inds` with shape [N, 1]
+        dist_an, relative_n_inds = paddle.min(
+            paddle.reshape(dist[is_neg], (bs, -1)), axis=1, keepdim=True)
+        '''
+        dist_ap = paddle.max(paddle.reshape(paddle.masked_select(dist, is_pos),
+                                            (bs, -1)),
+                             axis=1,
+                             keepdim=True)
+        # `dist_an` means distance(anchor, negative)
+        # both `dist_an` and `relative_n_inds` with shape [N, 1]
+        dist_an = paddle.min(paddle.reshape(paddle.masked_select(dist, is_neg),
+                                            (bs, -1)),
+                             axis=1,
+                             keepdim=True)
+        # shape [N]
+        dist_ap = paddle.squeeze(dist_ap, axis=1)
+        dist_an = paddle.squeeze(dist_an, axis=1)
+
+        # Compute ranking hinge loss
+        y = paddle.ones_like(dist_an)
+        loss = self.ranking_loss(dist_an, dist_ap, y)
+        return {"TripletLossV2": loss}
+
+
+class TripletLoss(nn.Layer):
+    """Triplet loss with hard positive/negative mining.
+    Reference:
+        Hermans et al. In Defense of the Triplet Loss for Person Re-Identification. arXiv:1703.07737.
+    Code imported from https://github.com/Cysu/open-reid/blob/master/reid/loss/triplet.py.
+    Args:
+        margin (float): margin for triplet.
+ """ + def __init__(self, margin=1.0): + super(TripletLoss, self).__init__() + self.margin = margin + self.ranking_loss = paddle.nn.loss.MarginRankingLoss(margin=margin) + + def forward(self, input, target): + """ + Args: + inputs: feature matrix with shape (batch_size, feat_dim) + target: ground truth labels with shape (num_classes) + """ + inputs = input["features"] + + #print(inputs.shape, targets.shape) + bs = inputs.shape[0] + # Compute pairwise distance, replace by the official when merged + dist = paddle.pow(inputs, 2).sum(axis=1, keepdim=True).expand([bs, bs]) + dist = dist + dist.t() + dist = paddle.addmm(input=dist, + x=inputs, + y=inputs.t(), + alpha=-2.0, + beta=1.0) + dist = paddle.clip(dist, min=1e-12).sqrt() + + mask = paddle.equal(target.expand([bs, bs]), + target.expand([bs, bs]).t()) + mask_numpy_idx = mask.numpy() + dist_ap, dist_an = [], [] + for i in range(bs): + # dist_ap_i = paddle.to_tensor(dist[i].numpy()[mask_numpy_idx[i]].max(),dtype='float64').unsqueeze(0) + # dist_ap_i.stop_gradient = False + # dist_ap.append(dist_ap_i) + dist_ap.append( + max([ + dist[i][j] + if mask_numpy_idx[i][j] == True else float("-inf") + for j in range(bs) + ]).unsqueeze(0)) + # dist_an_i = paddle.to_tensor(dist[i].numpy()[mask_numpy_idx[i] == False].min(), dtype='float64').unsqueeze(0) + # dist_an_i.stop_gradient = False + # dist_an.append(dist_an_i) + dist_an.append( + min([ + dist[i][k] + if mask_numpy_idx[i][k] == False else float("inf") + for k in range(bs) + ]).unsqueeze(0)) + + dist_ap = paddle.concat(dist_ap, axis=0) + dist_an = paddle.concat(dist_an, axis=0) + + # Compute ranking hinge loss + y = paddle.ones_like(dist_an) + loss = self.ranking_loss(dist_an, dist_ap, y) + return {"TripletLoss": loss} -- GitLab