Unverified commit f62604de authored by Bin Lu, committed by GitHub

Add files via upload

add losses
Parent 90418d79
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn.functional as F
__all__ = ['CELoss', 'JSDivLoss', 'KLDivLoss']
class Loss(object):
"""
Loss
"""
def __init__(self, class_dim=1000, epsilon=None):
assert class_dim > 1, "class_dim=%d is not larger than 1" % (class_dim)
self._class_dim = class_dim
        if epsilon is not None and 0.0 <= epsilon <= 1.0:
            self._epsilon = epsilon
            self._label_smoothing = True  # use label smoothing (train against a softened one-hot target)
else:
self._epsilon = None
self._label_smoothing = False
    # apply label smoothing
    def _labelsmoothing(self, target):
        if target.shape[-1] != self._class_dim:
            # convert sparse labels, e.g. (23, 34, 46), into a 3 x _class_dim one-hot matrix
            one_hot_target = F.one_hot(target, self._class_dim)
else:
one_hot_target = target
        # soft_target = (1 - epsilon) * one_hot_target + epsilon / class_dim
        soft_target = F.label_smooth(one_hot_target, epsilon=self._epsilon)
soft_target = paddle.reshape(soft_target, shape=[-1, self._class_dim])
return soft_target
def _crossentropy(self, input, target, use_pure_fp16=False):
if self._label_smoothing:
target = self._labelsmoothing(target)
            input = -F.log_softmax(input, axis=-1)  # negative log-softmax
            cost = paddle.sum(target * input, axis=-1)  # per-sample cross entropy
else:
cost = F.cross_entropy(input=input, label=target)
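        # pure fp16 mode sums the per-sample costs instead of averaging them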
if use_pure_fp16:
avg_cost = paddle.sum(cost)
else:
avg_cost = paddle.mean(cost)
return avg_cost
def _kldiv(self, input, target, name=None):
eps = 1.0e-10
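        # elementwise KL contribution, scaled by the number of classes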
cost = target * paddle.log(
(target + eps) / (input + eps)) * self._class_dim
return cost
    def _jsdiv(self, input, target):  # input and target are raw FC outputs (logits); softmax is applied here
        input = F.softmax(input)
        target = F.softmax(target)
        # symmetrized KL (used here as the JS-style divergence)
        cost = self._kldiv(input, target) + self._kldiv(target, input)
cost = cost / 2
avg_cost = paddle.mean(cost)
return avg_cost
def __call__(self, input, target):
pass
class CELoss(Loss):
"""
Cross entropy loss
"""
def __init__(self, class_dim=1000, epsilon=None):
super(CELoss, self).__init__(class_dim, epsilon)
def __call__(self, input, target, use_pure_fp16=False):
logits = input["logits"]
cost = self._crossentropy(logits, target, use_pure_fp16)
return {"CELoss": cost}
class JSDivLoss(Loss):
"""
JSDiv loss
"""
def __init__(self, class_dim=1000, epsilon=None):
super(JSDivLoss, self).__init__(class_dim, epsilon)
def __call__(self, input, target):
cost = self._jsdiv(input, target)
return cost
class KLDivLoss(paddle.nn.Layer):
def __init__(self):
super(KLDivLoss, self).__init__()
    def __call__(self, p, q, is_logit=True):
        if is_logit:
            p = F.softmax(p)
            q = F.softmax(q)
        # note: this computes cross entropy -sum(p * log q), which differs
        # from KL(p || q) only by the entropy of p
        return -(p * paddle.log(q + 1e-8)).sum(1).mean()
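# Usage sketch (added for illustration, mirroring the self-test convention of
# the other loss files in this commit): CELoss reads logits from
# input["logits"]; a non-None epsilon in [0, 1] enables label smoothing.
if __name__ == "__main__":
    import numpy as np
    np.random.seed(1)
    logits = paddle.to_tensor(np.random.randn(8, 1000), dtype="float32")
    labels = paddle.to_tensor(np.random.randint(0, 1000, (8,)), dtype="int64")
    loss_fn = CELoss(class_dim=1000, epsilon=0.1)
    print(loss_fn({"logits": logits}, labels))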
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
class CenterLoss(nn.Layer):
def __init__(self, num_classes=5013, feat_dim=2048):
super(CenterLoss, self).__init__()
self.num_classes = num_classes
self.feat_dim = feat_dim
        # randomly initialized class centers, kept in float64
        self.centers = paddle.randn(shape=[self.num_classes, self.feat_dim]).astype("float64")
def __call__(self, input, target):
"""
inputs: network output: {"features: xxx", "logits": xxxx}
target: image label
"""
feats = input["features"]
labels = target
batch_size = feats.shape[0]
        # squared L2 norm of each feature, expanded to [batch_size, num_classes]
        dist1 = paddle.sum(paddle.square(feats), axis=1, keepdim=True)
        dist1 = paddle.expand(dist1, [batch_size, self.num_classes])
        # squared L2 norm of each center, expanded to [batch_size, num_classes]
        dist2 = paddle.sum(paddle.square(self.centers), axis=1, keepdim=True)
        dist2 = paddle.expand(dist2, [self.num_classes, batch_size]).astype("float64")
        dist2 = paddle.transpose(dist2, [1, 0])
        # squared distance: ||x||^2 + ||c||^2 - 2 * x . c
        distmat = paddle.add(dist1, dist2)
        tmp = paddle.matmul(feats, paddle.transpose(self.centers, [1, 0]))
        distmat = distmat - 2.0 * tmp
        # mask out distances to centers of other classes
        classes = paddle.arange(self.num_classes).astype("int64")
        labels = paddle.expand(paddle.unsqueeze(labels, 1), (batch_size, self.num_classes))
        mask = paddle.equal(paddle.expand(classes, [batch_size, self.num_classes]), labels).astype("float64")
dist = paddle.multiply(distmat, mask)
loss = paddle.sum(paddle.clip(dist, min=1e-12, max=1e+12)) / batch_size
return {'CenterLoss': loss}
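# Usage sketch (added for illustration; shapes are arbitrary): CenterLoss reads
# embeddings from input["features"] together with integer class labels.
if __name__ == "__main__":
    import numpy as np
    np.random.seed(1)
    feats = paddle.to_tensor(np.random.randn(4, 2048), dtype="float64")
    labels = paddle.to_tensor(np.random.randint(0, 5013, (4,)), dtype="int64")
    loss_fn = CenterLoss(num_classes=5013, feat_dim=2048)
    print(loss_fn({"features": feats}, labels))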
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
def rerange_index(batch_size, samples_each_class):
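    """Build gather indices that reorder each row of a flattened
    batch_size x batch_size distance matrix so that the anchor's own entry
    comes first, followed by the other samples of its class (the positives),
    then all remaining samples (the negatives)."""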
tmp = np.arange(0, batch_size * batch_size)
tmp = tmp.reshape(-1, batch_size)
rerange_index = []
for i in range(batch_size):
step = i // samples_each_class
start = step * samples_each_class
end = (step + 1) * samples_each_class
pos_idx = []
neg_idx = []
for j, k in enumerate(tmp[i]):
if j >= start and j < end:
if j == i:
pos_idx.insert(0, k)
else:
pos_idx.append(k)
else:
neg_idx.append(k)
rerange_index += (pos_idx + neg_idx)
rerange_index = np.array(rerange_index).astype(np.int32)
return rerange_index
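# Worked example (added for illustration): with batch_size=4 and
# samples_each_class=2, each row of the flattened 4x4 matrix is reordered
# to [self, positive, negatives...]:
if __name__ == "__main__":
    print(rerange_index(4, 2).reshape(4, 4))
    # [[ 0  1  2  3]
    #  [ 5  4  6  7]
    #  [10 11  8  9]
    #  [15 14 12 13]]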
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import paddle
import numpy as np
from .comfunc import rerange_index
class EmlLoss(paddle.nn.Layer):
    def __init__(self, batch_size=40, samples_each_class=2):
super(EmlLoss, self).__init__()
assert(batch_size % samples_each_class == 0)
self.samples_each_class = samples_each_class
self.batch_size = batch_size
self.rerange_index = rerange_index(batch_size, samples_each_class)
self.thresh = 20.0
self.beta = 100000
def surrogate_function(self, beta, theta, bias):
x = theta * paddle.exp(bias)
output = paddle.log(1 + beta * x) / math.log(1 + beta)
return output
def surrogate_function_approximate(self, beta, theta, bias):
output = (paddle.log(theta) + bias + math.log(beta)) / math.log(1+beta)
return output
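    # Numerically stable evaluation of the surrogate: for bias values below
    # `thresh` the exact form log(1 + beta * theta * exp(bias)) / log(1 + beta)
    # is used; above it exp(bias) would overflow, so the asymptotic form
    # (log(theta) + bias + log(beta)) / log(1 + beta) is used instead, and the
    # exact value at the threshold is subtracted so the inactive branch cancels.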
def surrogate_function_stable(self, beta, theta, target, thresh):
max_gap = paddle.to_tensor(thresh, dtype='float32')
max_gap.stop_gradient = True
target_max = paddle.maximum(target, max_gap)
target_min = paddle.minimum(target, max_gap)
loss1 = self.surrogate_function(beta, theta, target_min)
loss2 = self.surrogate_function_approximate(beta, theta, target_max)
bias = self.surrogate_function(beta, theta, max_gap)
loss = loss1 + loss2 - bias
return loss
def forward(self, input, target=None):
features = input["features"]
samples_each_class = self.samples_each_class
batch_size = self.batch_size
rerange_index = self.rerange_index
        # pairwise squared Euclidean distances between all samples
        diffs = paddle.unsqueeze(features, axis=1) - paddle.unsqueeze(features, axis=0)
        similarity_matrix = paddle.sum(paddle.square(diffs), axis=-1)
        # rerange so each row reads: self, positives, negatives
        tmp = paddle.reshape(similarity_matrix, shape=[-1, 1])
        rerange_index = paddle.to_tensor(rerange_index)
        tmp = paddle.gather(tmp, index=rerange_index)
        similarity_matrix = paddle.reshape(tmp, shape=[-1, batch_size])
        ignore, pos, neg = paddle.split(
            similarity_matrix,
            num_or_sections=[1, samples_each_class - 1, batch_size - samples_each_class],
            axis=1)
ignore.stop_gradient = True
pos_max = paddle.max(pos, axis=1, keepdim=True)
pos = paddle.exp(pos - pos_max)
pos_mean = paddle.mean(pos, axis=1, keepdim=True)
neg_min = paddle.min(neg, axis=1, keepdim=True)
neg = paddle.exp(neg_min - neg)
neg_mean = paddle.mean(neg, axis=1, keepdim=True)
bias = pos_max - neg_min
theta = paddle.multiply(neg_mean, pos_mean)
loss = self.surrogate_function_stable(self.beta, theta, bias, self.thresh)
loss = paddle.mean(loss)
return {"emlloss": loss}
if __name__=="__main__":
metric = EmlLoss()
np.random.seed(1)
features = np.random.randn(40, 32)
features = paddle.to_tensor(features, dtype="float32")
print(features)
loss = metric(features)
print(loss)
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
from .comfunc import rerange_index
class MSMLoss(paddle.nn.Layer):
"""
MSMLoss Loss, based on triplet loss. USE P * K samples.
the batch size is fixed. Batch_size = P * K; but the K may vary between batches.
same label gather together
supported_metrics = [
'euclidean',
'sqeuclidean',
'cityblock',
]
only consider samples_each_class = 2
"""
    def __init__(self, batch_size=120, samples_each_class=2, margin=0.1):
super(MSMLoss, self).__init__()
self.margin = margin
self.samples_each_class = samples_each_class
self.batch_size = batch_size
self.rerange_index = rerange_index(batch_size, samples_each_class)
    def forward(self, input, target=None):
        features = input["features"]
        # L2-normalize the features
        features = self._normalize(features)
        samples_each_class = self.samples_each_class
        rerange_index = paddle.to_tensor(self.rerange_index)
        # pairwise squared Euclidean distances
        diffs = paddle.unsqueeze(features, axis=1) - paddle.unsqueeze(features, axis=0)
        similarity_matrix = paddle.sum(paddle.square(diffs), axis=-1)
        # rerange so each row reads: self, positives, negatives
        tmp = paddle.reshape(similarity_matrix, shape=[-1, 1])
        tmp = paddle.gather(tmp, index=rerange_index)
        similarity_matrix = paddle.reshape(tmp, shape=[-1, self.batch_size])
        # split off the self-distance column, then positives and negatives
        ignore, pos, neg = paddle.split(
            similarity_matrix,
            num_or_sections=[1, samples_each_class - 1, -1],
            axis=1)
        ignore.stop_gradient = True
        hard_pos = paddle.max(pos)  # hardest positive over the whole batch
        hard_neg = paddle.min(neg)  # hardest negative over the whole batch
        loss = hard_pos + self.margin - hard_neg
        loss = paddle.nn.functional.relu(loss)
        return {"msmloss": loss}
    def _normalize(self, input):
        input_norm = paddle.sqrt(paddle.sum(paddle.square(input), axis=1, keepdim=True))
        return paddle.divide(input, input_norm)
if __name__ == "__main__":
import numpy as np
metric = MSMLoss(48)
#prepare data
np.random.seed(1)
features = np.random.randn(48, 32)
#print(features)
#do inference
features = paddle.to_tensor(features)
loss = metric(features)
print(loss)
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
class NpairsLoss(paddle.nn.Layer):
def __init__(self, reg_lambda=0.01):
super(NpairsLoss, self).__init__()
self.reg_lambda = reg_lambda
    def forward(self, input, target=None):
        """
        The batch is arranged as (anchor, positive) pairs; the implicit class
        label of pair i is i, so no explicit target is needed.
        """
features = input["features"]
reg_lambda = self.reg_lambda
batch_size = features.shape[0]
fea_dim = features.shape[1]
num_class = batch_size // 2
        # reshape into (anchor, positive) pairs
        out_feas = paddle.reshape(features, shape=[-1, 2, fea_dim])
        anc_feas, pos_feas = paddle.split(out_feas, num_or_sections=2, axis=1)
        anc_feas = paddle.squeeze(anc_feas, axis=1)
        pos_feas = paddle.squeeze(pos_feas, axis=1)
        # similarity matrix between anchors and positives
        similarity_matrix = paddle.matmul(anc_feas, pos_feas, transpose_y=True)
        sparse_labels = paddle.arange(0, num_class, dtype='int64')
        xentloss = paddle.nn.CrossEntropyLoss()(similarity_matrix, sparse_labels)  # mean reduction by default
        # L2 regularization on the embeddings
        reg = paddle.mean(paddle.sum(paddle.square(features), axis=1))
        l2loss = 0.5 * reg_lambda * reg
return {"npairsloss": xentloss + l2loss}
if __name__ == "__main__":
import numpy as np
metric = NpairsLoss()
#prepare data
np.random.seed(1)
features = np.random.randn(160, 32)
#print(features)
#do inference
features = paddle.to_tensor(features)
loss = metric(features)
print(loss)
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
from .comfunc import rerange_index
class TriHardLoss(paddle.nn.Layer):
"""
TriHard Loss, based on triplet loss. USE P * K samples.
the batch size is fixed. Batch_size = P * K; but the K may vary between batches.
same label gather together
supported_metrics = [
'euclidean',
'sqeuclidean',
'cityblock',
]
only consider samples_each_class = 2
"""
    def __init__(self, batch_size=120, samples_each_class=2, margin=0.1):
super(TriHardLoss, self).__init__()
self.margin = margin
self.samples_each_class = samples_each_class
self.batch_size = batch_size
self.rerange_index = rerange_index(batch_size, samples_each_class)
    def forward(self, input, target=None):
        features = input["features"]
        assert (self.batch_size == features.shape[0])
        # L2-normalize the features
        features = self._normalize(features)
        samples_each_class = self.samples_each_class
        rerange_index = paddle.to_tensor(self.rerange_index)
        # pairwise squared Euclidean distances
        diffs = paddle.unsqueeze(features, axis=1) - paddle.unsqueeze(features, axis=0)
        similarity_matrix = paddle.sum(paddle.square(diffs), axis=-1)
        # rerange so each row reads: self, positives, negatives
        tmp = paddle.reshape(similarity_matrix, shape=[-1, 1])
        tmp = paddle.gather(tmp, index=rerange_index)
        similarity_matrix = paddle.reshape(tmp, shape=[-1, self.batch_size])
        # split off the self-distance column, then positives and negatives
        ignore, pos, neg = paddle.split(
            similarity_matrix,
            num_or_sections=[1, samples_each_class - 1, -1],
            axis=1)
        ignore.stop_gradient = True
        hard_pos = paddle.max(pos, axis=1)  # hardest positive per anchor
        hard_neg = paddle.min(neg, axis=1)  # hardest negative per anchor
        loss = hard_pos + self.margin - hard_neg
        loss = paddle.nn.functional.relu(loss)
        loss = paddle.mean(loss)
        return {"trihardloss": loss}
    def _normalize(self, input):
        input_norm = paddle.sqrt(paddle.sum(paddle.square(input), axis=1, keepdim=True))
        return paddle.divide(input, input_norm)
if __name__ == "__main__":
import numpy as np
metric = TriHardLoss(48)
#prepare data
np.random.seed(1)
features = np.random.randn(48, 32)
#print(features)
#do inference
features = paddle.to_tensor(features)
loss = metric(features)
print(loss)
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
import paddle.nn as nn
class TripletLossV2(nn.Layer):
"""Triplet loss with hard positive/negative mining.
Args:
margin (float): margin for triplet.
"""
def __init__(self, margin=0.5):
super(TripletLossV2, self).__init__()
self.margin = margin
self.ranking_loss = paddle.nn.loss.MarginRankingLoss(margin=margin)
def forward(self, input, target, normalize_feature=True):
"""
Args:
inputs: feature matrix with shape (batch_size, feat_dim)
target: ground truth labels with shape (num_classes)
"""
inputs = input["features"]
if normalize_feature:
inputs = 1. * inputs / (paddle.expand_as(
paddle.norm(inputs, p=2, axis=-1, keepdim=True), inputs) +
1e-12)
bs = inputs.shape[0]
# compute distance
dist = paddle.pow(inputs, 2).sum(axis=1, keepdim=True).expand([bs, bs])
dist = dist + dist.t()
dist = paddle.addmm(input=dist,
x=inputs,
y=inputs.t(),
alpha=-2.0,
beta=1.0)
dist = paddle.clip(dist, min=1e-12).sqrt()
        # hard positive/negative mining
is_pos = paddle.expand(target, (bs, bs)).equal(
paddle.expand(target, (bs, bs)).t())
is_neg = paddle.expand(target, (bs, bs)).not_equal(
paddle.expand(target, (bs, bs)).t())
        # `dist_ap` means distance(anchor, positive); shape [N, 1]
dist_ap = paddle.max(paddle.reshape(paddle.masked_select(dist, is_pos),
(bs, -1)),
axis=1,
keepdim=True)
# `dist_an` means distance(anchor, negative)
# both `dist_an` and `relative_n_inds` with shape [N, 1]
dist_an = paddle.min(paddle.reshape(paddle.masked_select(dist, is_neg),
(bs, -1)),
axis=1,
keepdim=True)
# shape [N]
dist_ap = paddle.squeeze(dist_ap, axis=1)
dist_an = paddle.squeeze(dist_an, axis=1)
# Compute ranking hinge loss
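        # MarginRankingLoss(margin)(x1, x2, y) is max(0, -y * (x1 - x2) + margin);
        # with y = 1, x1 = dist_an, x2 = dist_ap this gives the triplet hinge
        # max(0, dist_ap - dist_an + margin)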
y = paddle.ones_like(dist_an)
loss = self.ranking_loss(dist_an, dist_ap, y)
return {"TripletLossV2": loss}
class TripletLoss(nn.Layer):
"""Triplet loss with hard positive/negative mining.
Reference:
Hermans et al. In Defense of the Triplet Loss for Person Re-Identification. arXiv:1703.07737.
Code imported from https://github.com/Cysu/open-reid/blob/master/reid/loss/triplet.py.
Args:
margin (float): margin for triplet.
"""
def __init__(self, margin=1.0):
super(TripletLoss, self).__init__()
self.margin = margin
self.ranking_loss = paddle.nn.loss.MarginRankingLoss(margin=margin)
def forward(self, input, target):
"""
Args:
inputs: feature matrix with shape (batch_size, feat_dim)
target: ground truth labels with shape (num_classes)
"""
inputs = input["features"]
#print(inputs.shape, targets.shape)
bs = inputs.shape[0]
# Compute pairwise distance, replace by the official when merged
dist = paddle.pow(inputs, 2).sum(axis=1, keepdim=True).expand([bs, bs])
dist = dist + dist.t()
dist = paddle.addmm(input=dist,
x=inputs,
y=inputs.t(),
alpha=-2.0,
beta=1.0)
dist = paddle.clip(dist, min=1e-12).sqrt()
mask = paddle.equal(target.expand([bs, bs]),
target.expand([bs, bs]).t())
mask_numpy_idx = mask.numpy()
dist_ap, dist_an = [], []
        for i in range(bs):
            # hardest positive for anchor i: max distance among same-label samples
            dist_ap.append(
                max([
                    dist[i][j]
                    if mask_numpy_idx[i][j] == True else float("-inf")
                    for j in range(bs)
                ]).unsqueeze(0))
            # hardest negative for anchor i: min distance among different-label samples
            dist_an.append(
                min([
                    dist[i][k]
                    if mask_numpy_idx[i][k] == False else float("inf")
                    for k in range(bs)
                ]).unsqueeze(0))
dist_ap = paddle.concat(dist_ap, axis=0)
dist_an = paddle.concat(dist_an, axis=0)
# Compute ranking hinge loss
y = paddle.ones_like(dist_an)
loss = self.ranking_loss(dist_an, dist_ap, y)
return {"TripletLoss": loss}