From 69f563d234a81de2da841f8bf267ca452a7599e1 Mon Sep 17 00:00:00 2001 From: weishengyu Date: Thu, 3 Jun 2021 15:17:49 +0800 Subject: [PATCH] rename losses -> loss --- ppcls/engine/trainer.py | 2 +- ppcls/{losses => loss}/__init__.py | 0 ppcls/{losses => loss}/celoss.py | 0 ppcls/{losses => loss}/centerloss.py | 27 ++++++----- ppcls/{losses => loss}/comfunc.py | 15 ++++--- ppcls/{losses => loss}/emlloss.py | 64 +++++++++++++++------------ ppcls/{losses => loss}/msmloss.py | 42 ++++++++++-------- ppcls/{losses => loss}/npairsloss.py | 23 +++++----- ppcls/{losses => loss}/trihardloss.py | 46 ++++++++++--------- ppcls/{losses => loss}/triplet.py | 0 10 files changed, 124 insertions(+), 95 deletions(-) rename ppcls/{losses => loss}/__init__.py (100%) rename ppcls/{losses => loss}/celoss.py (100%) rename ppcls/{losses => loss}/centerloss.py (59%) rename ppcls/{losses => loss}/comfunc.py (82%) rename ppcls/{losses => loss}/emlloss.py (67%) rename ppcls/{losses => loss}/msmloss.py (66%) rename ppcls/{losses => loss}/npairsloss.py (73%) rename ppcls/{losses => loss}/trihardloss.py (66%) rename ppcls/{losses => loss}/triplet.py (100%) diff --git a/ppcls/engine/trainer.py b/ppcls/engine/trainer.py index 9dda5352..6ac3b88b 100644 --- a/ppcls/engine/trainer.py +++ b/ppcls/engine/trainer.py @@ -30,7 +30,7 @@ from ppcls.utils.misc import AverageMeter from ppcls.utils import logger from ppcls.data import build_dataloader from ppcls.arch import build_model -from ppcls.losses import build_loss +from ppcls.loss import build_loss from ppcls.arch.loss_metrics import build_metrics from ppcls.optimizer import build_optimizer from ppcls.utils.save_load import load_dygraph_pretrain diff --git a/ppcls/losses/__init__.py b/ppcls/loss/__init__.py similarity index 100% rename from ppcls/losses/__init__.py rename to ppcls/loss/__init__.py diff --git a/ppcls/losses/celoss.py b/ppcls/loss/celoss.py similarity index 100% rename from ppcls/losses/celoss.py rename to ppcls/loss/celoss.py diff --git a/ppcls/losses/centerloss.py b/ppcls/loss/centerloss.py similarity index 59% rename from ppcls/losses/centerloss.py rename to ppcls/loss/centerloss.py index 7b158b91..d85b3f2a 100644 --- a/ppcls/losses/centerloss.py +++ b/ppcls/loss/centerloss.py @@ -5,12 +5,15 @@ import paddle import paddle.nn as nn import paddle.nn.functional as F + class CenterLoss(nn.Layer): def __init__(self, num_classes=5013, feat_dim=2048): super(CenterLoss, self).__init__() self.num_classes = num_classes self.feat_dim = feat_dim - self.centers = paddle.randn(shape=[self.num_classes, self.feat_dim]).astype("float64") #random center + self.centers = paddle.randn( + shape=[self.num_classes, self.feat_dim]).astype( + "float64") #random center def __call__(self, input, target): """ @@ -23,25 +26,29 @@ class CenterLoss(nn.Layer): #calc feat * feat dist1 = paddle.sum(paddle.square(feats), axis=1, keepdim=True) - dist1 = paddle.expand(dist1, [batch_size, self.num_classes]) + dist1 = paddle.expand(dist1, [batch_size, self.num_classes]) #dist2 of centers - dist2 = paddle.sum(paddle.square(self.centers), axis=1, keepdim=True) #num_classes - dist2 = paddle.expand(dist2, [self.num_classes, batch_size]).astype("float64") + dist2 = paddle.sum(paddle.square(self.centers), axis=1, + keepdim=True) #num_classes + dist2 = paddle.expand(dist2, + [self.num_classes, batch_size]).astype("float64") dist2 = paddle.transpose(dist2, [1, 0]) #first x * x + y * y distmat = paddle.add(dist1, dist2) - tmp = paddle.matmul(feats, paddle.transpose(self.centers, [1, 0])) - distmat = distmat - 2.0 * tmp + tmp = paddle.matmul(feats, paddle.transpose(self.centers, [1, 0])) + distmat = distmat - 2.0 * tmp #generate the mask classes = paddle.arange(self.num_classes).astype("int64") - labels = paddle.expand(paddle.unsqueeze(labels, 1), (batch_size, self.num_classes)) - mask = paddle.equal(paddle.expand(classes, [batch_size, self.num_classes]), labels).astype("float64") #get mask + labels = paddle.expand( + paddle.unsqueeze(labels, 1), (batch_size, self.num_classes)) + mask = paddle.equal( + paddle.expand(classes, [batch_size, self.num_classes]), + labels).astype("float64") #get mask - dist = paddle.multiply(distmat, mask) + dist = paddle.multiply(distmat, mask) loss = paddle.sum(paddle.clip(dist, min=1e-12, max=1e+12)) / batch_size return {'CenterLoss': loss} - diff --git a/ppcls/losses/comfunc.py b/ppcls/loss/comfunc.py similarity index 82% rename from ppcls/losses/comfunc.py rename to ppcls/loss/comfunc.py index 234b820f..277bdd6b 100644 --- a/ppcls/losses/comfunc.py +++ b/ppcls/loss/comfunc.py @@ -18,26 +18,27 @@ from __future__ import print_function import numpy as np + def rerange_index(batch_size, samples_each_class): - tmp = np.arange(0, batch_size * batch_size) - tmp = tmp.reshape(-1, batch_size) + tmp = np.arange(0, batch_size * batch_size) + tmp = tmp.reshape(-1, batch_size) rerange_index = [] for i in range(batch_size): step = i // samples_each_class start = step * samples_each_class - end = (step + 1) * samples_each_class + end = (step + 1) * samples_each_class - pos_idx = [] - neg_idx = [] + pos_idx = [] + neg_idx = [] for j, k in enumerate(tmp[i]): if j >= start and j < end: if j == i: pos_idx.insert(0, k) else: - pos_idx.append(k) + pos_idx.append(k) else: - neg_idx.append(k) + neg_idx.append(k) rerange_index += (pos_idx + neg_idx) rerange_index = np.array(rerange_index).astype(np.int32) diff --git a/ppcls/losses/emlloss.py b/ppcls/loss/emlloss.py similarity index 67% rename from ppcls/losses/emlloss.py rename to ppcls/loss/emlloss.py index 410e3478..97357038 100644 --- a/ppcls/losses/emlloss.py +++ b/ppcls/loss/emlloss.py @@ -21,56 +21,64 @@ import paddle import numpy as np from .comfunc import rerange_index + class EmlLoss(paddle.nn.Layer): - def __init__(self, batch_size = 40, samples_each_class = 2): + def __init__(self, batch_size=40, samples_each_class=2): super(EmlLoss, self).__init__() - assert(batch_size % samples_each_class == 0) + assert (batch_size % samples_each_class == 0) self.samples_each_class = samples_each_class - self.batch_size = batch_size - self.rerange_index = rerange_index(batch_size, samples_each_class) + self.batch_size = batch_size + self.rerange_index = rerange_index(batch_size, samples_each_class) self.thresh = 20.0 - self.beta = 100000 - + self.beta = 100000 + def surrogate_function(self, beta, theta, bias): - x = theta * paddle.exp(bias) + x = theta * paddle.exp(bias) output = paddle.log(1 + beta * x) / math.log(1 + beta) return output def surrogate_function_approximate(self, beta, theta, bias): - output = (paddle.log(theta) + bias + math.log(beta)) / math.log(1+beta) + output = ( + paddle.log(theta) + bias + math.log(beta)) / math.log(1 + beta) return output def surrogate_function_stable(self, beta, theta, target, thresh): max_gap = paddle.to_tensor(thresh, dtype='float32') max_gap.stop_gradient = True - + target_max = paddle.maximum(target, max_gap) target_min = paddle.minimum(target, max_gap) - + loss1 = self.surrogate_function(beta, theta, target_min) loss2 = self.surrogate_function_approximate(beta, theta, target_max) - bias = self.surrogate_function(beta, theta, max_gap) - loss = loss1 + loss2 - bias + bias = self.surrogate_function(beta, theta, max_gap) + loss = loss1 + loss2 - bias return loss def forward(self, input, target=None): features = input["features"] samples_each_class = self.samples_each_class - batch_size = self.batch_size - rerange_index = self.rerange_index - + batch_size = self.batch_size + rerange_index = self.rerange_index + #calc distance - diffs = paddle.unsqueeze(features, axis=1) - paddle.unsqueeze(features, axis=0) - similary_matrix = paddle.sum(paddle.square(diffs), axis=-1) - - tmp = paddle.reshape(similary_matrix, shape = [-1, 1]) + diffs = paddle.unsqueeze( + features, axis=1) - paddle.unsqueeze( + features, axis=0) + similary_matrix = paddle.sum(paddle.square(diffs), axis=-1) + + tmp = paddle.reshape(similary_matrix, shape=[-1, 1]) rerange_index = paddle.to_tensor(rerange_index) - tmp = paddle.gather(tmp, index=rerange_index) - similary_matrix = paddle.reshape(tmp, shape=[-1, batch_size]) - - ignore, pos, neg = paddle.split(similary_matrix, num_or_sections= [1, - samples_each_class - 1, batch_size - samples_each_class], axis = 1) - ignore.stop_gradient = True + tmp = paddle.gather(tmp, index=rerange_index) + similary_matrix = paddle.reshape(tmp, shape=[-1, batch_size]) + + ignore, pos, neg = paddle.split( + similary_matrix, + num_or_sections=[ + 1, samples_each_class - 1, batch_size - samples_each_class + ], + axis=1) + ignore.stop_gradient = True pos_max = paddle.max(pos, axis=1, keepdim=True) pos = paddle.exp(pos - pos_max) @@ -79,11 +87,11 @@ class EmlLoss(paddle.nn.Layer): neg_min = paddle.min(neg, axis=1, keepdim=True) neg = paddle.exp(neg_min - neg) neg_mean = paddle.mean(neg, axis=1, keepdim=True) - + bias = pos_max - neg_min theta = paddle.multiply(neg_mean, pos_mean) - loss = self.surrogate_function_stable(self.beta, theta, bias, self.thresh) + loss = self.surrogate_function_stable(self.beta, theta, bias, + self.thresh) loss = paddle.mean(loss) return {"emlloss": loss} - diff --git a/ppcls/losses/msmloss.py b/ppcls/loss/msmloss.py similarity index 66% rename from ppcls/losses/msmloss.py rename to ppcls/loss/msmloss.py index 2585d95e..3aa0dd8b 100644 --- a/ppcls/losses/msmloss.py +++ b/ppcls/loss/msmloss.py @@ -18,6 +18,7 @@ from __future__ import print_function import paddle from .comfunc import rerange_index + class MSMLoss(paddle.nn.Layer): """ MSMLoss Loss, based on triplet loss. USE P * K samples. @@ -31,42 +32,47 @@ class MSMLoss(paddle.nn.Layer): ] only consider samples_each_class = 2 """ - def __init__(self, batch_size = 120, samples_each_class=2, margin=0.1): + + def __init__(self, batch_size=120, samples_each_class=2, margin=0.1): super(MSMLoss, self).__init__() self.margin = margin self.samples_each_class = samples_each_class - self.batch_size = batch_size - self.rerange_index = rerange_index(batch_size, samples_each_class) + self.batch_size = batch_size + self.rerange_index = rerange_index(batch_size, samples_each_class) def forward(self, input, target=None): #normalization features = input["features"] features = self._nomalize(features) samples_each_class = self.samples_each_class - rerange_index = paddle.to_tensor(self.rerange_index) + rerange_index = paddle.to_tensor(self.rerange_index) #calc sm - diffs = paddle.unsqueeze(features, axis=1) - paddle.unsqueeze(features, axis=0) - similary_matrix = paddle.sum(paddle.square(diffs), axis=-1) - + diffs = paddle.unsqueeze( + features, axis=1) - paddle.unsqueeze( + features, axis=0) + similary_matrix = paddle.sum(paddle.square(diffs), axis=-1) + #rerange - tmp = paddle.reshape(similary_matrix, shape = [-1, 1]) - tmp = paddle.gather(tmp, index=rerange_index) - similary_matrix = paddle.reshape(tmp, shape=[-1, self.batch_size]) - + tmp = paddle.reshape(similary_matrix, shape=[-1, 1]) + tmp = paddle.gather(tmp, index=rerange_index) + similary_matrix = paddle.reshape(tmp, shape=[-1, self.batch_size]) + #split - ignore, pos, neg = paddle.split(similary_matrix, num_or_sections= [1, - samples_each_class - 1, -1], axis = 1) - ignore.stop_gradient = True + ignore, pos, neg = paddle.split( + similary_matrix, + num_or_sections=[1, samples_each_class - 1, -1], + axis=1) + ignore.stop_gradient = True - hard_pos = paddle.max(pos) + hard_pos = paddle.max(pos) hard_neg = paddle.min(neg) loss = hard_pos + self.margin - hard_neg - loss = paddle.nn.ReLU()(loss) + loss = paddle.nn.ReLU()(loss) return {"msmloss": loss} def _nomalize(self, input): - input_norm = paddle.sqrt(paddle.sum(paddle.square(input), axis=1, keepdim=True)) + input_norm = paddle.sqrt( + paddle.sum(paddle.square(input), axis=1, keepdim=True)) return paddle.divide(input, input_norm) - diff --git a/ppcls/losses/npairsloss.py b/ppcls/loss/npairsloss.py similarity index 73% rename from ppcls/losses/npairsloss.py rename to ppcls/loss/npairsloss.py index ecd23349..d4b359e8 100644 --- a/ppcls/losses/npairsloss.py +++ b/ppcls/loss/npairsloss.py @@ -3,12 +3,12 @@ from __future__ import division from __future__ import print_function import paddle + class NpairsLoss(paddle.nn.Layer): - def __init__(self, reg_lambda=0.01): super(NpairsLoss, self).__init__() self.reg_lambda = reg_lambda - + def forward(self, input, target=None): """ anchor and positive(should include label) @@ -16,22 +16,23 @@ class NpairsLoss(paddle.nn.Layer): features = input["features"] reg_lambda = self.reg_lambda batch_size = features.shape[0] - fea_dim = features.shape[1] + fea_dim = features.shape[1] num_class = batch_size // 2 - + #reshape out_feas = paddle.reshape(features, shape=[-1, 2, fea_dim]) - anc_feas, pos_feas = paddle.split(out_feas, num_or_sections = 2, axis = 1) - anc_feas = paddle.squeeze(anc_feas, axis=1) + anc_feas, pos_feas = paddle.split(out_feas, num_or_sections=2, axis=1) + anc_feas = paddle.squeeze(anc_feas, axis=1) pos_feas = paddle.squeeze(pos_feas, axis=1) - + #get simi matrix - similarity_matrix = paddle.matmul(anc_feas, pos_feas, transpose_y=True) #get similarity matrix + similarity_matrix = paddle.matmul( + anc_feas, pos_feas, transpose_y=True) #get similarity matrix sparse_labels = paddle.arange(0, num_class, dtype='int64') - xentloss = paddle.nn.CrossEntropyLoss()(similarity_matrix, sparse_labels) #by default: mean - + xentloss = paddle.nn.CrossEntropyLoss()( + similarity_matrix, sparse_labels) #by default: mean + #l2 norm reg = paddle.mean(paddle.sum(paddle.square(features), axis=1)) l2loss = 0.5 * reg_lambda * reg return {"npairsloss": xentloss + l2loss} - diff --git a/ppcls/losses/trihardloss.py b/ppcls/loss/trihardloss.py similarity index 66% rename from ppcls/losses/trihardloss.py rename to ppcls/loss/trihardloss.py index a122c4f1..132c604d 100644 --- a/ppcls/losses/trihardloss.py +++ b/ppcls/loss/trihardloss.py @@ -19,6 +19,7 @@ from __future__ import print_function import paddle from .comfunc import rerange_index + class TriHardLoss(paddle.nn.Layer): """ TriHard Loss, based on triplet loss. USE P * K samples. @@ -32,45 +33,50 @@ class TriHardLoss(paddle.nn.Layer): ] only consider samples_each_class = 2 """ - def __init__(self, batch_size = 120, samples_each_class=2, margin=0.1): + + def __init__(self, batch_size=120, samples_each_class=2, margin=0.1): super(TriHardLoss, self).__init__() self.margin = margin self.samples_each_class = samples_each_class - self.batch_size = batch_size - self.rerange_index = rerange_index(batch_size, samples_each_class) + self.batch_size = batch_size + self.rerange_index = rerange_index(batch_size, samples_each_class) def forward(self, input, target=None): features = input["features"] assert (self.batch_size == features.shape[0]) - + #normalization features = self._nomalize(features) samples_each_class = self.samples_each_class - rerange_index = paddle.to_tensor(self.rerange_index) + rerange_index = paddle.to_tensor(self.rerange_index) #calc sm - diffs = paddle.unsqueeze(features, axis=1) - paddle.unsqueeze(features, axis=0) - similary_matrix = paddle.sum(paddle.square(diffs), axis=-1) - + diffs = paddle.unsqueeze( + features, axis=1) - paddle.unsqueeze( + features, axis=0) + similary_matrix = paddle.sum(paddle.square(diffs), axis=-1) + #rerange - tmp = paddle.reshape(similary_matrix, shape = [-1, 1]) - tmp = paddle.gather(tmp, index=rerange_index) - similary_matrix = paddle.reshape(tmp, shape=[-1, self.batch_size]) - + tmp = paddle.reshape(similary_matrix, shape=[-1, 1]) + tmp = paddle.gather(tmp, index=rerange_index) + similary_matrix = paddle.reshape(tmp, shape=[-1, self.batch_size]) + #split - ignore, pos, neg = paddle.split(similary_matrix, num_or_sections= [1, - samples_each_class - 1, -1], axis = 1) - - ignore.stop_gradient = True - hard_pos = paddle.max(pos, axis=1) + ignore, pos, neg = paddle.split( + similary_matrix, + num_or_sections=[1, samples_each_class - 1, -1], + axis=1) + + ignore.stop_gradient = True + hard_pos = paddle.max(pos, axis=1) hard_neg = paddle.min(neg, axis=1) loss = hard_pos + self.margin - hard_neg - loss = paddle.nn.ReLU()(loss) + loss = paddle.nn.ReLU()(loss) loss = paddle.mean(loss) return {"trihardloss": loss} def _nomalize(self, input): - input_norm = paddle.sqrt(paddle.sum(paddle.square(input), axis=1, keepdim=True)) + input_norm = paddle.sqrt( + paddle.sum(paddle.square(input), axis=1, keepdim=True)) return paddle.divide(input, input_norm) - diff --git a/ppcls/losses/triplet.py b/ppcls/loss/triplet.py similarity index 100% rename from ppcls/losses/triplet.py rename to ppcls/loss/triplet.py -- GitLab