Commit 69f563d2 authored by W weishengyu

rename losses -> loss

Parent 51f0b78b
@@ -30,7 +30,7 @@ from ppcls.utils.misc import AverageMeter
from ppcls.utils import logger
from ppcls.data import build_dataloader
from ppcls.arch import build_model
-from ppcls.losses import build_loss
+from ppcls.loss import build_loss
from ppcls.arch.loss_metrics import build_metrics
from ppcls.optimizer import build_optimizer
from ppcls.utils.save_load import load_dygraph_pretrain
...
@@ -5,12 +5,15 @@ import paddle
import paddle.nn as nn
import paddle.nn.functional as F


class CenterLoss(nn.Layer):
    def __init__(self, num_classes=5013, feat_dim=2048):
        super(CenterLoss, self).__init__()
        self.num_classes = num_classes
        self.feat_dim = feat_dim
        self.centers = paddle.randn(
            shape=[self.num_classes, self.feat_dim]).astype(
                "float64")  #random center

    def __call__(self, input, target):
        """
@@ -23,25 +26,29 @@ class CenterLoss(nn.Layer):
        #calc feat * feat
        dist1 = paddle.sum(paddle.square(feats), axis=1, keepdim=True)
        dist1 = paddle.expand(dist1, [batch_size, self.num_classes])

        #dist2 of centers
        dist2 = paddle.sum(paddle.square(self.centers), axis=1,
                           keepdim=True)  #num_classes
        dist2 = paddle.expand(dist2,
                              [self.num_classes, batch_size]).astype("float64")
        dist2 = paddle.transpose(dist2, [1, 0])

        #first x * x + y * y
        distmat = paddle.add(dist1, dist2)
        tmp = paddle.matmul(feats, paddle.transpose(self.centers, [1, 0]))
        distmat = distmat - 2.0 * tmp

        #generate the mask
        classes = paddle.arange(self.num_classes).astype("int64")
        labels = paddle.expand(
            paddle.unsqueeze(labels, 1), (batch_size, self.num_classes))
        mask = paddle.equal(
            paddle.expand(classes, [batch_size, self.num_classes]),
            labels).astype("float64")  #get mask

        dist = paddle.multiply(distmat, mask)
        loss = paddle.sum(paddle.clip(dist, min=1e-12, max=1e+12)) / batch_size
        return {'CenterLoss': loss}
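
Editor's note: the distmat computation above avoids materializing all pairwise feature-center differences. It uses the identity ||f - c||^2 = ||f||^2 + ||c||^2 - 2 f·c, expanding the two squared norms and subtracting twice the inner product. A minimal standalone check of that identity (illustration only, not part of this commit):

import paddle

feats = paddle.randn([4, 16], dtype="float64")
centers = paddle.randn([10, 16], dtype="float64")

# expansion trick used by CenterLoss: ||f||^2 + ||c||^2 - 2 f.c
d1 = paddle.sum(paddle.square(feats), axis=1, keepdim=True)    # [4, 1]
d2 = paddle.sum(paddle.square(centers), axis=1, keepdim=True)  # [10, 1]
distmat = d1 + paddle.transpose(d2, [1, 0]) \
    - 2.0 * paddle.matmul(feats, centers, transpose_y=True)    # [4, 10]

# direct pairwise distances for comparison
ref = paddle.sum(
    paddle.square(
        paddle.unsqueeze(feats, 1) - paddle.unsqueeze(centers, 0)), axis=-1)
print(paddle.allclose(distmat, ref))  # True up to float error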
@@ -18,26 +18,27 @@ from __future__ import print_function
import numpy as np


def rerange_index(batch_size, samples_each_class):
    tmp = np.arange(0, batch_size * batch_size)
    tmp = tmp.reshape(-1, batch_size)
    rerange_index = []

    for i in range(batch_size):
        step = i // samples_each_class
        start = step * samples_each_class
        end = (step + 1) * samples_each_class
        pos_idx = []
        neg_idx = []
        for j, k in enumerate(tmp[i]):
            if j >= start and j < end:
                if j == i:
                    pos_idx.insert(0, k)
                else:
                    pos_idx.append(k)
            else:
                neg_idx.append(k)
        rerange_index += (pos_idx + neg_idx)

    rerange_index = np.array(rerange_index).astype(np.int32)
...
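
Editor's note: rerange_index builds a permutation of the flattened batch_size x batch_size distance matrix so that, after a gather, each row reads [self, same-class positives, negatives]. The losses below then split each row into ignore/pos/neg by column. A toy illustration, assuming the elided tail of the function returns the array:

import numpy as np

idx = rerange_index(batch_size=4, samples_each_class=2)
print(idx.reshape(4, 4))
# [[ 0  1  2  3]
#  [ 5  4  6  7]
#  [10 11  8  9]
#  [15 14 12 13]]
# column 0: the sample itself (the diagonal of the distance matrix),
# column 1: its positive, columns 2-3: negatives.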
@@ -21,56 +21,64 @@ import paddle
import numpy as np
from .comfunc import rerange_index


class EmlLoss(paddle.nn.Layer):
    def __init__(self, batch_size=40, samples_each_class=2):
        super(EmlLoss, self).__init__()
        assert (batch_size % samples_each_class == 0)
        self.samples_each_class = samples_each_class
        self.batch_size = batch_size
        self.rerange_index = rerange_index(batch_size, samples_each_class)
        self.thresh = 20.0
        self.beta = 100000

    def surrogate_function(self, beta, theta, bias):
        x = theta * paddle.exp(bias)
        output = paddle.log(1 + beta * x) / math.log(1 + beta)
        return output

    def surrogate_function_approximate(self, beta, theta, bias):
        output = (
            paddle.log(theta) + bias + math.log(beta)) / math.log(1 + beta)
        return output

    def surrogate_function_stable(self, beta, theta, target, thresh):
        max_gap = paddle.to_tensor(thresh, dtype='float32')
        max_gap.stop_gradient = True

        target_max = paddle.maximum(target, max_gap)
        target_min = paddle.minimum(target, max_gap)

        loss1 = self.surrogate_function(beta, theta, target_min)
        loss2 = self.surrogate_function_approximate(beta, theta, target_max)
        bias = self.surrogate_function(beta, theta, max_gap)
        loss = loss1 + loss2 - bias
        return loss

    def forward(self, input, target=None):
        features = input["features"]
        samples_each_class = self.samples_each_class
        batch_size = self.batch_size
        rerange_index = self.rerange_index

        #calc distance
        diffs = paddle.unsqueeze(
            features, axis=1) - paddle.unsqueeze(
                features, axis=0)
        similary_matrix = paddle.sum(paddle.square(diffs), axis=-1)

        tmp = paddle.reshape(similary_matrix, shape=[-1, 1])
        rerange_index = paddle.to_tensor(rerange_index)
        tmp = paddle.gather(tmp, index=rerange_index)
        similary_matrix = paddle.reshape(tmp, shape=[-1, batch_size])

        ignore, pos, neg = paddle.split(
            similary_matrix,
            num_or_sections=[
                1, samples_each_class - 1, batch_size - samples_each_class
            ],
            axis=1)
        ignore.stop_gradient = True

        pos_max = paddle.max(pos, axis=1, keepdim=True)
        pos = paddle.exp(pos - pos_max)
@@ -79,11 +87,11 @@ class EmlLoss(paddle.nn.Layer):
        neg_min = paddle.min(neg, axis=1, keepdim=True)
        neg = paddle.exp(neg_min - neg)
        neg_mean = paddle.mean(neg, axis=1, keepdim=True)

        bias = pos_max - neg_min
        theta = paddle.multiply(neg_mean, pos_mean)

        loss = self.surrogate_function_stable(self.beta, theta, bias,
                                              self.thresh)
        loss = paddle.mean(loss)
        return {"emlloss": loss}
@@ -18,6 +18,7 @@ from __future__ import print_function
import paddle
from .comfunc import rerange_index


class MSMLoss(paddle.nn.Layer):
    """
    MSMLoss Loss, based on triplet loss. USE P * K samples.
@@ -31,42 +32,47 @@ class MSMLoss(paddle.nn.Layer):
        ]
        only consider samples_each_class = 2
    """

    def __init__(self, batch_size=120, samples_each_class=2, margin=0.1):
        super(MSMLoss, self).__init__()
        self.margin = margin
        self.samples_each_class = samples_each_class
        self.batch_size = batch_size
        self.rerange_index = rerange_index(batch_size, samples_each_class)

    def forward(self, input, target=None):
        #normalization
        features = input["features"]
        features = self._nomalize(features)
        samples_each_class = self.samples_each_class
        rerange_index = paddle.to_tensor(self.rerange_index)

        #calc sm
        diffs = paddle.unsqueeze(
            features, axis=1) - paddle.unsqueeze(
                features, axis=0)
        similary_matrix = paddle.sum(paddle.square(diffs), axis=-1)

        #rerange
        tmp = paddle.reshape(similary_matrix, shape=[-1, 1])
        tmp = paddle.gather(tmp, index=rerange_index)
        similary_matrix = paddle.reshape(tmp, shape=[-1, self.batch_size])

        #split
        ignore, pos, neg = paddle.split(
            similary_matrix,
            num_or_sections=[1, samples_each_class - 1, -1],
            axis=1)
        ignore.stop_gradient = True

        hard_pos = paddle.max(pos)
        hard_neg = paddle.min(neg)

        loss = hard_pos + self.margin - hard_neg
        loss = paddle.nn.ReLU()(loss)
        return {"msmloss": loss}

    def _nomalize(self, input):
        input_norm = paddle.sqrt(
            paddle.sum(paddle.square(input), axis=1, keepdim=True))
        return paddle.divide(input, input_norm)
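
Editor's note: MSMLoss mines a single hardest positive and hardest negative across the whole rearranged matrix (paddle.max(pos) / paddle.min(neg) with no axis), so the hinge is one scalar with no per-anchor averaging. A hypothetical smoke test with toy sizes (the "features" key matches forward above; the shapes are assumptions, not repo defaults):

import paddle

loss_fn = MSMLoss(batch_size=4, samples_each_class=2, margin=0.1)
feats = paddle.randn([4, 8])  # 2 classes x 2 samples, grouped by class
out = loss_fn({"features": feats})
print(out["msmloss"])  # scalar tensor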
@@ -3,12 +3,12 @@ from __future__ import division
from __future__ import print_function

import paddle


class NpairsLoss(paddle.nn.Layer):
    def __init__(self, reg_lambda=0.01):
        super(NpairsLoss, self).__init__()
        self.reg_lambda = reg_lambda

    def forward(self, input, target=None):
        """
        anchor and positive(should include label)
@@ -16,22 +16,23 @@ class NpairsLoss(paddle.nn.Layer):
        features = input["features"]
        reg_lambda = self.reg_lambda
        batch_size = features.shape[0]
        fea_dim = features.shape[1]
        num_class = batch_size // 2

        #reshape
        out_feas = paddle.reshape(features, shape=[-1, 2, fea_dim])
        anc_feas, pos_feas = paddle.split(out_feas, num_or_sections=2, axis=1)
        anc_feas = paddle.squeeze(anc_feas, axis=1)
        pos_feas = paddle.squeeze(pos_feas, axis=1)

        #get simi matrix
        similarity_matrix = paddle.matmul(
            anc_feas, pos_feas, transpose_y=True)  #get similarity matrix
        sparse_labels = paddle.arange(0, num_class, dtype='int64')
        xentloss = paddle.nn.CrossEntropyLoss()(
            similarity_matrix, sparse_labels)  #by default: mean
        #l2 norm
        reg = paddle.mean(paddle.sum(paddle.square(features), axis=1))
        l2loss = 0.5 * reg_lambda * reg
        return {"npairsloss": xentloss + l2loss}
@@ -19,6 +19,7 @@ from __future__ import print_function
import paddle
from .comfunc import rerange_index


class TriHardLoss(paddle.nn.Layer):
    """
    TriHard Loss, based on triplet loss. USE P * K samples.
@@ -32,45 +33,50 @@ class TriHardLoss(paddle.nn.Layer):
        ]
        only consider samples_each_class = 2
    """

    def __init__(self, batch_size=120, samples_each_class=2, margin=0.1):
        super(TriHardLoss, self).__init__()
        self.margin = margin
        self.samples_each_class = samples_each_class
        self.batch_size = batch_size
        self.rerange_index = rerange_index(batch_size, samples_each_class)

    def forward(self, input, target=None):
        features = input["features"]
        assert (self.batch_size == features.shape[0])

        #normalization
        features = self._nomalize(features)
        samples_each_class = self.samples_each_class
        rerange_index = paddle.to_tensor(self.rerange_index)

        #calc sm
        diffs = paddle.unsqueeze(
            features, axis=1) - paddle.unsqueeze(
                features, axis=0)
        similary_matrix = paddle.sum(paddle.square(diffs), axis=-1)

        #rerange
        tmp = paddle.reshape(similary_matrix, shape=[-1, 1])
        tmp = paddle.gather(tmp, index=rerange_index)
        similary_matrix = paddle.reshape(tmp, shape=[-1, self.batch_size])

        #split
        ignore, pos, neg = paddle.split(
            similary_matrix,
            num_or_sections=[1, samples_each_class - 1, -1],
            axis=1)

        ignore.stop_gradient = True
        hard_pos = paddle.max(pos, axis=1)
        hard_neg = paddle.min(neg, axis=1)

        loss = hard_pos + self.margin - hard_neg
        loss = paddle.nn.ReLU()(loss)
        loss = paddle.mean(loss)
        return {"trihardloss": loss}

    def _nomalize(self, input):
        input_norm = paddle.sqrt(
            paddle.sum(paddle.square(input), axis=1, keepdim=True))
        return paddle.divide(input, input_norm)
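
Editor's note: unlike MSMLoss, TriHardLoss mines per anchor: paddle.max(pos, axis=1) and paddle.min(neg, axis=1) pick one hardest positive/negative per row, and the per-row hinges are averaged. The toy comparison below (hand-made distances, not repo code) shows the two reductions disagreeing on the same rearranged matrix:

import paddle

# rows: two anchors; col 0 = self, col 1 = positive, cols 2-3 = negatives
sim = paddle.to_tensor([[0.0, 0.9, 0.5, 1.1],
                        [0.0, 0.2, 1.2, 1.5]])
pos, neg = sim[:, 1:2], sim[:, 2:]
margin = 0.1

trihard = paddle.mean(paddle.nn.ReLU()(
    paddle.max(pos, axis=1) + margin - paddle.min(neg, axis=1)))
msm = paddle.nn.ReLU()(paddle.max(pos) + margin - paddle.min(neg))
print(float(trihard), float(msm))  # 0.25 vs 0.5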