Commit 69f563d2 authored by: W weishengyu

rename losses -> loss

Parent 51f0b78b
@@ -30,7 +30,7 @@ from ppcls.utils.misc import AverageMeter
 from ppcls.utils import logger
 from ppcls.data import build_dataloader
 from ppcls.arch import build_model
-from ppcls.losses import build_loss
+from ppcls.loss import build_loss
 from ppcls.arch.loss_metrics import build_metrics
 from ppcls.optimizer import build_optimizer
 from ppcls.utils.save_load import load_dygraph_pretrain
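The only functional change in this file is the package path, so downstream scripts that import the loss builder need the matching one-word update. A minimal sketch; the `config["Loss"]` key below is a guess at usage, not something shown in this diff:

```python
# Before this commit:
#   from ppcls.losses import build_loss
# After it:
from ppcls.loss import build_loss

# Hypothetical usage (config schema not shown in this diff):
# loss_func = build_loss(config["Loss"])
```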
@@ -5,12 +5,15 @@ import paddle
 import paddle.nn as nn
 import paddle.nn.functional as F
 
+
 class CenterLoss(nn.Layer):
     def __init__(self, num_classes=5013, feat_dim=2048):
         super(CenterLoss, self).__init__()
         self.num_classes = num_classes
         self.feat_dim = feat_dim
-        self.centers = paddle.randn(shape=[self.num_classes, self.feat_dim]).astype("float64") #random center
+        self.centers = paddle.randn(
+            shape=[self.num_classes, self.feat_dim]).astype(
+                "float64")  #random center
 
     def __call__(self, input, target):
         """
@@ -26,8 +29,10 @@ class CenterLoss(nn.Layer):
         dist1 = paddle.expand(dist1, [batch_size, self.num_classes])
 
         #dist2 of centers
-        dist2 = paddle.sum(paddle.square(self.centers), axis=1, keepdim=True) #num_classes
-        dist2 = paddle.expand(dist2, [self.num_classes, batch_size]).astype("float64")
+        dist2 = paddle.sum(paddle.square(self.centers), axis=1,
+                           keepdim=True)  #num_classes
+        dist2 = paddle.expand(dist2,
+                              [self.num_classes, batch_size]).astype("float64")
         dist2 = paddle.transpose(dist2, [1, 0])
 
         #first x * x + y * y
@@ -37,11 +42,13 @@ class CenterLoss(nn.Layer):
 
         #generate the mask
         classes = paddle.arange(self.num_classes).astype("int64")
-        labels = paddle.expand(paddle.unsqueeze(labels, 1), (batch_size, self.num_classes))
-        mask = paddle.equal(paddle.expand(classes, [batch_size, self.num_classes]), labels).astype("float64") #get mask
+        labels = paddle.expand(
+            paddle.unsqueeze(labels, 1), (batch_size, self.num_classes))
+        mask = paddle.equal(
+            paddle.expand(classes, [batch_size, self.num_classes]),
+            labels).astype("float64")  #get mask
 
         dist = paddle.multiply(distmat, mask)
         loss = paddle.sum(paddle.clip(dist, min=1e-12, max=1e+12)) / batch_size
         return {'CenterLoss': loss}
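A note on the hunks above: CenterLoss assembles the squared feature-to-center distance matrix from the expansion ||x − c||² = ||x||² + ||c||² − 2·x·c. The −2·x·c cross term is collapsed out of this diff, so the sketch below completes it under that assumption, mirroring the shapes and float64 casts in the code:

```python
import paddle

batch_size, num_classes, feat_dim = 4, 6, 8
x = paddle.randn([batch_size, feat_dim]).astype("float64")
centers = paddle.randn([num_classes, feat_dim]).astype("float64")

# ||x||^2 broadcast across classes (dist1 above)
dist1 = paddle.expand(
    paddle.sum(paddle.square(x), axis=1, keepdim=True),
    [batch_size, num_classes])
# ||c||^2 broadcast across the batch, then transposed (dist2 above)
dist2 = paddle.transpose(
    paddle.expand(
        paddle.sum(paddle.square(centers), axis=1, keepdim=True),
        [num_classes, batch_size]), [1, 0])
# The cross term is an assumption: that part of the file is collapsed here.
distmat = dist1 + dist2 - 2.0 * paddle.matmul(x, centers, transpose_y=True)
print(distmat.shape)  # [4, 6]: squared distance of each sample to each center
```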
@@ -18,6 +18,7 @@ from __future__ import print_function
 
 import numpy as np
 
+
 def rerange_index(batch_size, samples_each_class):
     tmp = np.arange(0, batch_size * batch_size)
     tmp = tmp.reshape(-1, batch_size)
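Most of `rerange_index` is collapsed in this diff. What the loss classes below rely on is that it returns a permutation of `range(batch_size * batch_size)` that reorders each row of the flattened similarity matrix into self, positives, negatives. A toy sketch of that plumbing, with an identity permutation standing in for the real index:

```python
import numpy as np
import paddle

batch_size = 4
# Identity stand-in; the real rerange_index(batch_size, samples_each_class)
# returns a non-trivial permutation whose body is collapsed in this diff.
perm = paddle.to_tensor(np.arange(batch_size * batch_size))

sim = paddle.rand([batch_size, batch_size])
flat = paddle.reshape(sim, shape=[-1, 1])  # flatten to one column
reordered = paddle.reshape(
    paddle.gather(flat, index=perm), shape=[-1, batch_size])
```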
@@ -21,10 +21,11 @@ import paddle
 import numpy as np
 from .comfunc import rerange_index
 
+
 class EmlLoss(paddle.nn.Layer):
-    def __init__(self, batch_size = 40, samples_each_class = 2):
+    def __init__(self, batch_size=40, samples_each_class=2):
         super(EmlLoss, self).__init__()
-        assert(batch_size % samples_each_class == 0)
+        assert (batch_size % samples_each_class == 0)
         self.samples_each_class = samples_each_class
         self.batch_size = batch_size
         self.rerange_index = rerange_index(batch_size, samples_each_class)
@@ -37,7 +38,8 @@ class EmlLoss(paddle.nn.Layer):
         return output
 
     def surrogate_function_approximate(self, beta, theta, bias):
-        output = (paddle.log(theta) + bias + math.log(beta)) / math.log(1+beta)
+        output = (
+            paddle.log(theta) + bias + math.log(beta)) / math.log(1 + beta)
         return output
 
     def surrogate_function_stable(self, beta, theta, target, thresh):
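Only the approximate branch is visible above. Assuming the exact surrogate is log(1 + β·θ·e^bias) / log(1 + β), as in comparable embedding losses (an inference, since that method is collapsed in this diff), the approximation simply drops the leading 1 for large arguments, avoiding the overflow that e^bias can cause:

```python
import math
import paddle

beta, bias = 1e5, 2.0
theta = paddle.to_tensor([0.5])

# Assumed exact form (not shown in this diff):
exact = paddle.log(1 + beta * theta * math.exp(bias)) / math.log(1 + beta)
# Approximate form from the hunk above:
# log(beta * theta * e^bias) = log(theta) + bias + log(beta)
approx = (paddle.log(theta) + bias + math.log(beta)) / math.log(1 + beta)
print(float(exact), float(approx))  # agree to ~4 decimals for large arguments
```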
@@ -60,16 +62,22 @@ class EmlLoss(paddle.nn.Layer):
         rerange_index = self.rerange_index
 
         #calc distance
-        diffs = paddle.unsqueeze(features, axis=1) - paddle.unsqueeze(features, axis=0)
+        diffs = paddle.unsqueeze(
+            features, axis=1) - paddle.unsqueeze(
+                features, axis=0)
         similary_matrix = paddle.sum(paddle.square(diffs), axis=-1)
 
-        tmp = paddle.reshape(similary_matrix, shape = [-1, 1])
+        tmp = paddle.reshape(similary_matrix, shape=[-1, 1])
         rerange_index = paddle.to_tensor(rerange_index)
         tmp = paddle.gather(tmp, index=rerange_index)
         similary_matrix = paddle.reshape(tmp, shape=[-1, batch_size])
 
-        ignore, pos, neg = paddle.split(similary_matrix, num_or_sections= [1,
-            samples_each_class - 1, batch_size - samples_each_class], axis = 1)
+        ignore, pos, neg = paddle.split(
+            similary_matrix,
+            num_or_sections=[
+                1, samples_each_class - 1, batch_size - samples_each_class
+            ],
+            axis=1)
         ignore.stop_gradient = True
 
         pos_max = paddle.max(pos, axis=1, keepdim=True)
@@ -83,7 +91,7 @@ class EmlLoss(paddle.nn.Layer):
         bias = pos_max - neg_min
         theta = paddle.multiply(neg_mean, pos_mean)
 
-        loss = self.surrogate_function_stable(self.beta, theta, bias, self.thresh)
+        loss = self.surrogate_function_stable(self.beta, theta, bias,
+                                              self.thresh)
         loss = paddle.mean(loss)
         return {"emlloss": loss}
@@ -18,6 +18,7 @@ from __future__ import print_function
 import paddle
 from .comfunc import rerange_index
 
+
 class MSMLoss(paddle.nn.Layer):
     """
     MSMLoss Loss, based on triplet loss. USE P * K samples.
@@ -31,7 +32,8 @@ class MSMLoss(paddle.nn.Layer):
     ]
     only consider samples_each_class = 2
     """
-    def __init__(self, batch_size = 120, samples_each_class=2, margin=0.1):
+
+    def __init__(self, batch_size=120, samples_each_class=2, margin=0.1):
         super(MSMLoss, self).__init__()
         self.margin = margin
         self.samples_each_class = samples_each_class
@@ -46,17 +48,21 @@ class MSMLoss(paddle.nn.Layer):
         rerange_index = paddle.to_tensor(self.rerange_index)
 
         #calc sm
-        diffs = paddle.unsqueeze(features, axis=1) - paddle.unsqueeze(features, axis=0)
+        diffs = paddle.unsqueeze(
+            features, axis=1) - paddle.unsqueeze(
+                features, axis=0)
         similary_matrix = paddle.sum(paddle.square(diffs), axis=-1)
 
         #rerange
-        tmp = paddle.reshape(similary_matrix, shape = [-1, 1])
+        tmp = paddle.reshape(similary_matrix, shape=[-1, 1])
         tmp = paddle.gather(tmp, index=rerange_index)
         similary_matrix = paddle.reshape(tmp, shape=[-1, self.batch_size])
 
         #split
-        ignore, pos, neg = paddle.split(similary_matrix, num_or_sections= [1,
-            samples_each_class - 1, -1], axis = 1)
+        ignore, pos, neg = paddle.split(
+            similary_matrix,
+            num_or_sections=[1, samples_each_class - 1, -1],
+            axis=1)
         ignore.stop_gradient = True
 
         hard_pos = paddle.max(pos)
@@ -67,6 +73,6 @@ class MSMLoss(paddle.nn.Layer):
         return {"msmloss": loss}
 
     def _nomalize(self, input):
-        input_norm = paddle.sqrt(paddle.sum(paddle.square(input), axis=1, keepdim=True))
+        input_norm = paddle.sqrt(
+            paddle.sum(paddle.square(input), axis=1, keepdim=True))
         return paddle.divide(input, input_norm)
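The `_nomalize` helper above (identifier spelled as in the source) is plain row-wise L2 normalization, so the squared distances fed into the mining step depend only on feature direction, not magnitude. A short check:

```python
import paddle

x = paddle.randn([4, 8])
input_norm = paddle.sqrt(paddle.sum(paddle.square(x), axis=1, keepdim=True))
x_unit = paddle.divide(x, input_norm)
print(paddle.sum(paddle.square(x_unit), axis=1))  # ~1.0 for every row
```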
@@ -3,8 +3,8 @@ from __future__ import division
 from __future__ import print_function
 
 import paddle
 
-class NpairsLoss(paddle.nn.Layer):
+class NpairsLoss(paddle.nn.Layer):
     def __init__(self, reg_lambda=0.01):
         super(NpairsLoss, self).__init__()
         self.reg_lambda = reg_lambda
@@ -21,17 +21,18 @@ class NpairsLoss(paddle.nn.Layer):
 
         #reshape
         out_feas = paddle.reshape(features, shape=[-1, 2, fea_dim])
-        anc_feas, pos_feas = paddle.split(out_feas, num_or_sections = 2, axis = 1)
+        anc_feas, pos_feas = paddle.split(out_feas, num_or_sections=2, axis=1)
         anc_feas = paddle.squeeze(anc_feas, axis=1)
         pos_feas = paddle.squeeze(pos_feas, axis=1)
 
         #get simi matrix
-        similarity_matrix = paddle.matmul(anc_feas, pos_feas, transpose_y=True) #get similarity matrix
+        similarity_matrix = paddle.matmul(
+            anc_feas, pos_feas, transpose_y=True)  #get similarity matrix
         sparse_labels = paddle.arange(0, num_class, dtype='int64')
-        xentloss = paddle.nn.CrossEntropyLoss()(similarity_matrix, sparse_labels) #by default: mean
+        xentloss = paddle.nn.CrossEntropyLoss()(
+            similarity_matrix, sparse_labels)  #by default: mean
 
         #l2 norm
         reg = paddle.mean(paddle.sum(paddle.square(features), axis=1))
         l2loss = 0.5 * reg_lambda * reg
         return {"npairsloss": xentloss + l2loss}
@@ -19,6 +19,7 @@ from __future__ import print_function
 import paddle
 from .comfunc import rerange_index
 
+
 class TriHardLoss(paddle.nn.Layer):
     """
     TriHard Loss, based on triplet loss. USE P * K samples.
@@ -32,7 +33,8 @@ class TriHardLoss(paddle.nn.Layer):
     ]
     only consider samples_each_class = 2
     """
-    def __init__(self, batch_size = 120, samples_each_class=2, margin=0.1):
+
+    def __init__(self, batch_size=120, samples_each_class=2, margin=0.1):
         super(TriHardLoss, self).__init__()
         self.margin = margin
         self.samples_each_class = samples_each_class
@@ -49,17 +51,21 @@ class TriHardLoss(paddle.nn.Layer):
         rerange_index = paddle.to_tensor(self.rerange_index)
 
         #calc sm
-        diffs = paddle.unsqueeze(features, axis=1) - paddle.unsqueeze(features, axis=0)
+        diffs = paddle.unsqueeze(
+            features, axis=1) - paddle.unsqueeze(
+                features, axis=0)
         similary_matrix = paddle.sum(paddle.square(diffs), axis=-1)
 
         #rerange
-        tmp = paddle.reshape(similary_matrix, shape = [-1, 1])
+        tmp = paddle.reshape(similary_matrix, shape=[-1, 1])
         tmp = paddle.gather(tmp, index=rerange_index)
         similary_matrix = paddle.reshape(tmp, shape=[-1, self.batch_size])
 
         #split
-        ignore, pos, neg = paddle.split(similary_matrix, num_or_sections= [1,
-            samples_each_class - 1, -1], axis = 1)
+        ignore, pos, neg = paddle.split(
+            similary_matrix,
+            num_or_sections=[1, samples_each_class - 1, -1],
+            axis=1)
         ignore.stop_gradient = True
 
         hard_pos = paddle.max(pos, axis=1)
@@ -71,6 +77,6 @@ class TriHardLoss(paddle.nn.Layer):
         return {"trihardloss": loss}
 
     def _nomalize(self, input):
-        input_norm = paddle.sqrt(paddle.sum(paddle.square(input), axis=1, keepdim=True))
+        input_norm = paddle.sqrt(
+            paddle.sum(paddle.square(input), axis=1, keepdim=True))
        return paddle.divide(input, input_norm)
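The practical difference between TriHardLoss and MSMLoss in this commit is the mining axis: MSMLoss reduces globally (`paddle.max(pos)` is a scalar, the single hardest pair in the whole batch), while TriHardLoss mines per anchor along `axis=1`. MSMLoss's negative side is collapsed out of this diff, so the symmetric `paddle.min(neg)` below is an assumption:

```python
import paddle

pos = paddle.rand([4, 1])  # positive distances, one row per anchor
neg = paddle.rand([4, 2])  # negative distances

hard_pos_msm = paddle.max(pos)  # scalar: hardest positive overall
hard_neg_msm = paddle.min(neg)  # assumed symmetric global reduce
hard_pos_trihard = paddle.max(pos, axis=1)  # [4]: hardest positive per anchor
hard_neg_trihard = paddle.min(neg, axis=1)  # [4]: hardest negative per anchor
```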