Unverified commit 625b64bc, authored by Wei Shengyu, committed by GitHub

Merge pull request #776 from weisy11/develop_reg

modify format, remove useless files
@@ -34,5 +34,6 @@ d['valid'] = np.array(setid['valid'][0])
d['test'] = np.array(setid['tstid'][0])

for id in d[sys.argv[2]]:
-    message = str(data_path) + "/image_" + str(id).zfill(5) + ".jpg " + str(labels[id - 1] - 1)
-    print(message)
+    message = str(data_path) + "/image_" + str(id).zfill(5) + ".jpg " + str(
+        labels[id - 1] - 1)
+    print(message)
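For readers skimming the list-generation hunk above, a minimal runnable sketch of how one output line is assembled; the folder name and label value below are assumptions, not taken from the PR (Flowers102 labels in the .mat file are 1-based, hence the minus one):

# Minimal sketch with assumed values, mirroring the loop above.
data_path = "jpg"        # assumed image folder
image_id, label = 1, 77  # assumed id and 1-based label
line = str(data_path) + "/image_" + str(image_id).zfill(5) + ".jpg " + str(label - 1)
print(line)              # -> jpg/image_00001.jpg 76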
@@ -18,10 +18,10 @@ import importlib
import paddle.nn as nn

from . import backbone
-from . import head
+from . import gears
from .backbone import *
-from .head import *
+from .gears import *
from .utils import *

__all__ = ["build_model", "RecModel"]
......
@@ -19,10 +19,11 @@ from .fc import FC
__all__ = ['build_head']


def build_head(config):
    support_dict = ['ArcMargin', 'CosMargin', 'CircleMargin', 'FC']

    module_name = config.pop('name')
-    assert module_name in support_dict, Exception('head only support {}'.format(
-        support_dict))
+    assert module_name in support_dict, Exception(
+        'head only support {}'.format(support_dict))
    module_class = eval(module_name)(**config)
    return module_class
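For orientation, a hedged sketch of the config dict build_head expects; the key names mirror the FC constructor arguments visible later in this diff, and the values are illustrative assumptions:

# Hypothetical head config, not taken from the PR.
config = {"name": "FC", "embedding_size": 512, "class_num": 1000}
# build_head pops "name", checks it against support_dict, then instantiates the
# matching class with the remaining keys as keyword arguments:
# head = build_head(config)  # roughly FC(embedding_size=512, class_num=1000)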
@@ -16,30 +16,32 @@ import math
import paddle
import paddle.nn as nn
import paddle.nn.functional as F


class CircleMargin(nn.Layer):
-    def __init__(self, embedding_size,
-                 class_num,
-                 margin,
-                 scale):
+    def __init__(self, embedding_size, class_num, margin, scale):
        super(CircleSoftmax, self).__init__()
        self.scale = scale
        self.margin = margin
        self.embedding_size = embedding_size
        self.class_num = class_num

-        weight_attr = paddle.ParamAttr(initializer = paddle.nn.initializer.XavierNormal())
-        self.fc0 = paddle.nn.Linear(self.embedding_size, self.class_num, weight_attr=weight_attr)
+        weight_attr = paddle.ParamAttr(
+            initializer=paddle.nn.initializer.XavierNormal())
+        self.fc0 = paddle.nn.Linear(
+            self.embedding_size, self.class_num, weight_attr=weight_attr)

    def forward(self, input, label):
-        feat_norm = paddle.sqrt(paddle.sum(paddle.square(input), axis=1, keepdim=True))
+        feat_norm = paddle.sqrt(
+            paddle.sum(paddle.square(input), axis=1, keepdim=True))
        input = paddle.divide(input, feat_norm)

        weight = self.fc0.weight
-        weight_norm = paddle.sqrt(paddle.sum(paddle.square(weight), axis=0, keepdim=True))
+        weight_norm = paddle.sqrt(
+            paddle.sum(paddle.square(weight), axis=0, keepdim=True))
        weight = paddle.divide(weight, weight_norm)

        logits = paddle.matmul(input, weight)
        alpha_p = paddle.clip(-logits.detach() + 1 + self.margin, min=0.)
        alpha_n = paddle.clip(logits.detach() + self.margin, min=0.)
@@ -51,5 +53,5 @@ class CircleMargin(nn.Layer):
        logits_n = alpha_n * (logits - delta_n)
        pre_logits = logits_p * m_hot + logits_n * (1 - m_hot)
        pre_logits = self.scale * pre_logits

        return pre_logits
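As a reading aid for the margin terms above, a small hedged sketch of the circle-margin re-weighting applied to plain cosine logits; the margin and logit values are assumptions, and the layer itself also detaches the logits before computing the coefficients:

import paddle

margin = 0.25
logits = paddle.to_tensor([[0.9, 0.1, -0.3]])        # assumed cosine similarities
alpha_p = paddle.clip(-logits + 1 + margin, min=0.)  # grows as a positive falls short of 1
alpha_n = paddle.clip(logits + margin, min=0.)       # grows as a negative rises above 0
print(alpha_p.numpy(), alpha_n.numpy())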
@@ -16,35 +16,41 @@ import paddle
import math
import paddle.nn as nn


class CosMargin(paddle.nn.Layer):
-    def __init__(self, embedding_size,
-                 class_num,
-                 margin=0.35,
-                 scale=64.0):
+    def __init__(self, embedding_size, class_num, margin=0.35, scale=64.0):
        super(CosMargin, self).__init__()
        self.scale = scale
        self.margin = margin
        self.embedding_size = embedding_size
        self.class_num = class_num

-        weight_attr = paddle.ParamAttr(initializer = paddle.nn.initializer.XavierNormal())
-        self.fc = nn.Linear(self.embedding_size, self.class_num, weight_attr=weight_attr, bias_attr=False)
+        weight_attr = paddle.ParamAttr(
+            initializer=paddle.nn.initializer.XavierNormal())
+        self.fc = nn.Linear(
+            self.embedding_size,
+            self.class_num,
+            weight_attr=weight_attr,
+            bias_attr=False)

    def forward(self, input, label):
        label.stop_gradient = True

-        input_norm = paddle.sqrt(paddle.sum(paddle.square(input), axis=1, keepdim=True))
+        input_norm = paddle.sqrt(
+            paddle.sum(paddle.square(input), axis=1, keepdim=True))
        input = paddle.divide(input, x_norm)

        weight = self.fc.weight
-        weight_norm = paddle.sqrt(paddle.sum(paddle.square(weight), axis=0, keepdim=True))
+        weight_norm = paddle.sqrt(
+            paddle.sum(paddle.square(weight), axis=0, keepdim=True))
        weight = paddle.divide(weight, weight_norm)

        cos = paddle.matmul(input, weight)
        cos_m = cos - self.margin

        one_hot = paddle.nn.functional.one_hot(label, self.class_num)
        one_hot = paddle.squeeze(one_hot, axis=[1])
-        output = paddle.multiply(one_hot, cos_m) + paddle.multiply((1.0 - one_hot), cos)
+        output = paddle.multiply(one_hot, cos_m) + paddle.multiply(
+            (1.0 - one_hot), cos)
        output = output * self.scale
        return output
@@ -19,14 +19,16 @@ from __future__ import print_function
import paddle
import paddle.nn as nn


class FC(nn.Layer):
-    def __init__(self, embedding_size,
-                 class_num):
+    def __init__(self, embedding_size, class_num):
        super(FC, self).__init__()
        self.embedding_size = embedding_size
        self.class_num = class_num
-        weight_attr = paddle.ParamAttr(initializer = paddle.nn.initializer.XavierNormal())
-        self.fc = paddle.nn.Linear(self.embedding_size, self.class_num, weight_attr=weight_attr)
+        weight_attr = paddle.ParamAttr(
+            initializer=paddle.nn.initializer.XavierNormal())
+        self.fc = paddle.nn.Linear(
+            self.embedding_size, self.class_num, weight_attr=weight_attr)

    def forward(self, input, label):
        out = self.fc(input)
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn.functional as F
__all__ = ['CELoss', 'MixCELoss', 'GoogLeNetLoss', 'JSDivLoss', 'MultiLabelLoss']
class Loss(object):
    """
    Loss
    """

    def __init__(self, class_dim=1000, epsilon=None):
        assert class_dim > 1, "class_dim=%d is not larger than 1" % (class_dim)
        self._class_dim = class_dim
        if epsilon is not None and epsilon >= 0.0 and epsilon <= 1.0:
            self._epsilon = epsilon
            self._label_smoothing = True
        else:
            self._epsilon = None
            self._label_smoothing = False

    def _labelsmoothing(self, target):
        if target.shape[-1] != self._class_dim:
            one_hot_target = F.one_hot(target, self._class_dim)
        else:
            one_hot_target = target
        soft_target = F.label_smooth(one_hot_target, epsilon=self._epsilon)
        soft_target = paddle.reshape(soft_target, shape=[-1, self._class_dim])
        return soft_target

    def _binary_crossentropy(self, input, target):
        if self._label_smoothing:
            target = self._labelsmoothing(target)
            cost = F.binary_cross_entropy_with_logits(logit=input, label=target)
        else:
            cost = F.binary_cross_entropy_with_logits(logit=input, label=target)
        avg_cost = paddle.mean(cost)
        return avg_cost

    def _crossentropy(self, input, target):
        if self._label_smoothing:
            target = self._labelsmoothing(target)
            input = -F.log_softmax(input, axis=-1)
            cost = paddle.sum(target * input, axis=-1)
        else:
            cost = F.cross_entropy(input=input, label=target)
        avg_cost = paddle.mean(cost)
        return avg_cost

    def _kldiv(self, input, target, name=None):
        eps = 1.0e-10
        cost = target * paddle.log(
            (target + eps) / (input + eps)) * self._class_dim
        return cost

    def _jsdiv(self, input, target):
        input = F.softmax(input)
        target = F.softmax(target)
        cost = self._kldiv(input, target) + self._kldiv(target, input)
        cost = cost / 2
        avg_cost = paddle.mean(cost)
        return avg_cost

    def __call__(self, input, target):
        pass


class MultiLabelLoss(Loss):
    """
    Multilabel loss based binary cross entropy
    """

    def __init__(self, class_dim=1000, epsilon=None):
        super(MultiLabelLoss, self).__init__(class_dim, epsilon)

    def __call__(self, input, target):
        cost = self._binary_crossentropy(input, target)
        return cost


class CELoss(Loss):
    """
    Cross entropy loss
    """

    def __init__(self, class_dim=1000, epsilon=None):
        super(CELoss, self).__init__(class_dim, epsilon)

    def __call__(self, input, target):
        cost = self._crossentropy(input, target)
        return cost


class MixCELoss(Loss):
    """
    Cross entropy loss with mix(mixup, cutmix, fixmix)
    """

    def __init__(self, class_dim=1000, epsilon=None):
        super(MixCELoss, self).__init__(class_dim, epsilon)

    def __call__(self, input, target0, target1, lam):
        cost0 = self._crossentropy(input, target0)
        cost1 = self._crossentropy(input, target1)
        cost = lam * cost0 + (1.0 - lam) * cost1
        avg_cost = paddle.mean(cost)
        return avg_cost


class GoogLeNetLoss(Loss):
    """
    Cross entropy loss used after googlenet
    """

    def __init__(self, class_dim=1000, epsilon=None):
        super(GoogLeNetLoss, self).__init__(class_dim, epsilon)

    def __call__(self, input0, input1, input2, target):
        cost0 = self._crossentropy(input0, target)
        cost1 = self._crossentropy(input1, target)
        cost2 = self._crossentropy(input2, target)
        cost = cost0 + 0.3 * cost1 + 0.3 * cost2
        avg_cost = paddle.mean(cost)
        return avg_cost


class JSDivLoss(Loss):
    """
    JSDiv loss
    """

    def __init__(self, class_dim=1000, epsilon=None):
        super(JSDivLoss, self).__init__(class_dim, epsilon)

    def __call__(self, input, target):
        cost = self._jsdiv(input, target)
        return cost
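For context on the _crossentropy label-smoothing branch above, a short hedged illustration with assumed values (4 classes, epsilon 0.1), using the same Paddle primitives the class relies on:

import paddle
import paddle.nn.functional as F

target = paddle.to_tensor([2])               # assumed label
one_hot = F.one_hot(target, 4)               # [[0., 0., 1., 0.]]
soft = F.label_smooth(one_hot, epsilon=0.1)  # [[0.025, 0.025, 0.925, 0.025]]

logits = paddle.to_tensor([[0.1, 0.2, 2.0, -1.0]])
cost = paddle.sum(soft * -F.log_softmax(logits, axis=-1), axis=-1)
print(float(paddle.mean(cost)))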
@@ -11,26 +11,25 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import paddle
import numpy as np
from paddle.io import DistributedBatchSampler, BatchSampler, DataLoader
from ppcls.utils import logger

-from . import dataset
-from . import imaug
-from . import samplers
+from ppcls.data import dataloader
+from ppcls.data import imaug

# dataset
-from .dataset.imagenet_dataset import ImageNetDataset
-from .dataset.multilabel_dataset import MultiLabelDataset
-from .dataset.common_dataset import create_operators
-from .dataset.vehicle_dataset import CompCars, VeriWild
+from ppcls.data.dataloader.imagenet_dataset import ImageNetDataset
+from ppcls.data.dataloader.multilabel_dataset import MultiLabelDataset
+from ppcls.data.dataloader.common_dataset import create_operators
+from ppcls.data.dataloader.vehicle_dataset import CompCars, VeriWild

# sampler
-from .samplers import DistributedRandomIdentitySampler
+from ppcls.data.dataloader import DistributedRandomIdentitySampler
-from .preprocess import transform
+from ppcls.data.preprocess import transform


def build_dataloader(config, mode, device, seed=None):
......
@@ -14,22 +14,13 @@
from __future__ import print_function

-import io
-import tarfile
import numpy as np
-from PIL import Image #all use default backend
-
-import paddle
-from paddle.io import Dataset
-import pickle
import os
-import cv2
-import random

from .common_dataset import CommonDataset


class ImageNetDataset(CommonDataset):
    def _load_anno(self, seed=None):
        assert os.path.exists(self._cls_path)
        assert os.path.exists(self._img_root)
@@ -47,5 +38,3 @@ class ImageNetDataset(CommonDataset):
                self.images.append(os.path.join(self._img_root, l[0]))
                self.labels.append(int(l[1]))
                assert os.path.exists(self.images[-1])
@@ -14,26 +14,17 @@
from __future__ import print_function

-import io
-import tarfile
import numpy as np
-from PIL import Image #all use default backend
-
-import paddle
-from paddle.io import Dataset
-import pickle
import os
-import cv2
-import random

from ppcls.data import preprocess
from ppcls.data.preprocess import transform
from ppcls.utils import logger

from .common_dataset import CommonDataset


class MultiLabelDataset(CommonDataset):
    def _load_anno(self):
        assert os.path.exists(self._cls_path)
        assert os.path.exists(self._img_root)
......
from .DistributedRandomIdentitySampler import DistributedRandomIdentitySampler
@@ -30,7 +30,7 @@ from ppcls.utils.misc import AverageMeter
from ppcls.utils import logger
from ppcls.data import build_dataloader
from ppcls.arch import build_model
-from ppcls.losses import build_loss
+from ppcls.loss import build_loss
from ppcls.arch.loss_metrics import build_metrics
from ppcls.optimizer import build_optimizer
from ppcls.utils.save_load import load_dygraph_pretrain
......
@@ -5,12 +5,15 @@ import paddle
import paddle.nn as nn
import paddle.nn.functional as F


class CenterLoss(nn.Layer):
    def __init__(self, num_classes=5013, feat_dim=2048):
        super(CenterLoss, self).__init__()
        self.num_classes = num_classes
        self.feat_dim = feat_dim
-        self.centers = paddle.randn(shape=[self.num_classes, self.feat_dim]).astype("float64") #random center
+        self.centers = paddle.randn(
+            shape=[self.num_classes, self.feat_dim]).astype(
+                "float64")  #random center

    def __call__(self, input, target):
        """
@@ -23,25 +26,29 @@ class CenterLoss(nn.Layer):
        #calc feat * feat
        dist1 = paddle.sum(paddle.square(feats), axis=1, keepdim=True)
        dist1 = paddle.expand(dist1, [batch_size, self.num_classes])

        #dist2 of centers
-        dist2 = paddle.sum(paddle.square(self.centers), axis=1, keepdim=True) #num_classes
-        dist2 = paddle.expand(dist2, [self.num_classes, batch_size]).astype("float64")
+        dist2 = paddle.sum(paddle.square(self.centers), axis=1,
+                           keepdim=True)  #num_classes
+        dist2 = paddle.expand(dist2,
+                              [self.num_classes, batch_size]).astype("float64")
        dist2 = paddle.transpose(dist2, [1, 0])

        #first x * x + y * y
        distmat = paddle.add(dist1, dist2)
        tmp = paddle.matmul(feats, paddle.transpose(self.centers, [1, 0]))
        distmat = distmat - 2.0 * tmp

        #generate the mask
        classes = paddle.arange(self.num_classes).astype("int64")
-        labels = paddle.expand(paddle.unsqueeze(labels, 1), (batch_size, self.num_classes))
-        mask = paddle.equal(paddle.expand(classes, [batch_size, self.num_classes]), labels).astype("float64") #get mask
+        labels = paddle.expand(
+            paddle.unsqueeze(labels, 1), (batch_size, self.num_classes))
+        mask = paddle.equal(
+            paddle.expand(classes, [batch_size, self.num_classes]),
+            labels).astype("float64")  #get mask

        dist = paddle.multiply(distmat, mask)
        loss = paddle.sum(paddle.clip(dist, min=1e-12, max=1e+12)) / batch_size

        return {'CenterLoss': loss}
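The rearranged lines above compute squared feature-to-center distances through the expansion ||x - c||^2 = ||x||^2 + ||c||^2 - 2 x.c; a small NumPy check of that identity with assumed shapes may help when reading the broadcasting:

import numpy as np

feats = np.random.randn(3, 4)     # assumed: 3 features of dim 4
centers = np.random.randn(5, 4)   # assumed: 5 class centers

dist1 = np.sum(feats ** 2, axis=1, keepdims=True)      # [3, 1], broadcasts to [3, 5]
dist2 = np.sum(centers ** 2, axis=1, keepdims=True).T  # [1, 5]
distmat = dist1 + dist2 - 2.0 * feats @ centers.T      # [3, 5] squared distances

ref = ((feats[:, None, :] - centers[None, :, :]) ** 2).sum(-1)
assert np.allclose(distmat, ref)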
@@ -18,26 +18,27 @@ from __future__ import print_function
import numpy as np


def rerange_index(batch_size, samples_each_class):
    tmp = np.arange(0, batch_size * batch_size)
    tmp = tmp.reshape(-1, batch_size)
    rerange_index = []

    for i in range(batch_size):
        step = i // samples_each_class
        start = step * samples_each_class
        end = (step + 1) * samples_each_class

        pos_idx = []
        neg_idx = []
        for j, k in enumerate(tmp[i]):
            if j >= start and j < end:
                if j == i:
                    pos_idx.insert(0, k)
                else:
                    pos_idx.append(k)
            else:
                neg_idx.append(k)

        rerange_index += (pos_idx + neg_idx)

    rerange_index = np.array(rerange_index).astype(np.int32)
......
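To make the reordering concrete, a small worked example for rerange_index with batch_size=4 and samples_each_class=2; the values below were traced by hand from the loop above and should be treated as a sketch:

import numpy as np

# The 4x4 similarity matrix is flattened row-major to indices 0..15; each row is
# reordered to: own index first, then same-class columns, then the remaining ones.
expected = np.array([0, 1, 2, 3,
                     5, 4, 6, 7,
                     10, 11, 8, 9,
                     15, 14, 12, 13], dtype=np.int32)
# print(rerange_index(4, 2))  # expected to match `expected`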
@@ -21,56 +21,64 @@ import paddle
import numpy as np
from .comfunc import rerange_index


class EmlLoss(paddle.nn.Layer):
-    def __init__(self, batch_size = 40, samples_each_class = 2):
+    def __init__(self, batch_size=40, samples_each_class=2):
        super(EmlLoss, self).__init__()
-        assert(batch_size % samples_each_class == 0)
+        assert (batch_size % samples_each_class == 0)
        self.samples_each_class = samples_each_class
        self.batch_size = batch_size
        self.rerange_index = rerange_index(batch_size, samples_each_class)
        self.thresh = 20.0
        self.beta = 100000

    def surrogate_function(self, beta, theta, bias):
        x = theta * paddle.exp(bias)
        output = paddle.log(1 + beta * x) / math.log(1 + beta)
        return output

    def surrogate_function_approximate(self, beta, theta, bias):
-        output = (paddle.log(theta) + bias + math.log(beta)) / math.log(1+beta)
+        output = (
+            paddle.log(theta) + bias + math.log(beta)) / math.log(1 + beta)
        return output

    def surrogate_function_stable(self, beta, theta, target, thresh):
        max_gap = paddle.to_tensor(thresh, dtype='float32')
        max_gap.stop_gradient = True

        target_max = paddle.maximum(target, max_gap)
        target_min = paddle.minimum(target, max_gap)

        loss1 = self.surrogate_function(beta, theta, target_min)
        loss2 = self.surrogate_function_approximate(beta, theta, target_max)
        bias = self.surrogate_function(beta, theta, max_gap)
        loss = loss1 + loss2 - bias
        return loss

    def forward(self, input, target=None):
        features = input["features"]
        samples_each_class = self.samples_each_class
        batch_size = self.batch_size
        rerange_index = self.rerange_index

        #calc distance
-        diffs = paddle.unsqueeze(features, axis=1) - paddle.unsqueeze(features, axis=0)
-        similary_matrix = paddle.sum(paddle.square(diffs), axis=-1)
-        tmp = paddle.reshape(similary_matrix, shape = [-1, 1])
+        diffs = paddle.unsqueeze(
+            features, axis=1) - paddle.unsqueeze(
+                features, axis=0)
+        similary_matrix = paddle.sum(paddle.square(diffs), axis=-1)
+        tmp = paddle.reshape(similary_matrix, shape=[-1, 1])

        rerange_index = paddle.to_tensor(rerange_index)
        tmp = paddle.gather(tmp, index=rerange_index)
        similary_matrix = paddle.reshape(tmp, shape=[-1, batch_size])

-        ignore, pos, neg = paddle.split(similary_matrix, num_or_sections= [1,
-            samples_each_class - 1, batch_size - samples_each_class], axis = 1)
+        ignore, pos, neg = paddle.split(
+            similary_matrix,
+            num_or_sections=[
+                1, samples_each_class - 1, batch_size - samples_each_class
+            ],
+            axis=1)
        ignore.stop_gradient = True

        pos_max = paddle.max(pos, axis=1, keepdim=True)
        pos = paddle.exp(pos - pos_max)
@@ -79,11 +87,11 @@ class EmlLoss(paddle.nn.Layer):
        neg_min = paddle.min(neg, axis=1, keepdim=True)
        neg = paddle.exp(neg_min - neg)
        neg_mean = paddle.mean(neg, axis=1, keepdim=True)

        bias = pos_max - neg_min
        theta = paddle.multiply(neg_mean, pos_mean)

-        loss = self.surrogate_function_stable(self.beta, theta, bias, self.thresh)
+        loss = self.surrogate_function_stable(self.beta, theta, bias,
+                                              self.thresh)
        loss = paddle.mean(loss)

        return {"emlloss": loss}
@@ -18,6 +18,7 @@ from __future__ import print_function
import paddle
from .comfunc import rerange_index


class MSMLoss(paddle.nn.Layer):
    """
    MSMLoss Loss, based on triplet loss. USE P * K samples.
@@ -31,42 +32,47 @@ class MSMLoss(paddle.nn.Layer):
        ]
    only consider samples_each_class = 2
    """

-    def __init__(self, batch_size = 120, samples_each_class=2, margin=0.1):
+    def __init__(self, batch_size=120, samples_each_class=2, margin=0.1):
        super(MSMLoss, self).__init__()
        self.margin = margin
        self.samples_each_class = samples_each_class
        self.batch_size = batch_size
        self.rerange_index = rerange_index(batch_size, samples_each_class)

    def forward(self, input, target=None):
        #normalization
        features = input["features"]
        features = self._nomalize(features)
        samples_each_class = self.samples_each_class
        rerange_index = paddle.to_tensor(self.rerange_index)

        #calc sm
-        diffs = paddle.unsqueeze(features, axis=1) - paddle.unsqueeze(features, axis=0)
-        similary_matrix = paddle.sum(paddle.square(diffs), axis=-1)
+        diffs = paddle.unsqueeze(
+            features, axis=1) - paddle.unsqueeze(
+                features, axis=0)
+        similary_matrix = paddle.sum(paddle.square(diffs), axis=-1)

        #rerange
-        tmp = paddle.reshape(similary_matrix, shape = [-1, 1])
+        tmp = paddle.reshape(similary_matrix, shape=[-1, 1])
        tmp = paddle.gather(tmp, index=rerange_index)
        similary_matrix = paddle.reshape(tmp, shape=[-1, self.batch_size])

        #split
-        ignore, pos, neg = paddle.split(similary_matrix, num_or_sections= [1,
-            samples_each_class - 1, -1], axis = 1)
+        ignore, pos, neg = paddle.split(
+            similary_matrix,
+            num_or_sections=[1, samples_each_class - 1, -1],
+            axis=1)
        ignore.stop_gradient = True

        hard_pos = paddle.max(pos)
        hard_neg = paddle.min(neg)

        loss = hard_pos + self.margin - hard_neg
        loss = paddle.nn.ReLU()(loss)
        return {"msmloss": loss}

    def _nomalize(self, input):
-        input_norm = paddle.sqrt(paddle.sum(paddle.square(input), axis=1, keepdim=True))
+        input_norm = paddle.sqrt(
+            paddle.sum(paddle.square(input), axis=1, keepdim=True))
        return paddle.divide(input, input_norm)
@@ -3,12 +3,12 @@ from __future__ import division
from __future__ import print_function

import paddle


class NpairsLoss(paddle.nn.Layer):
    def __init__(self, reg_lambda=0.01):
        super(NpairsLoss, self).__init__()
        self.reg_lambda = reg_lambda

    def forward(self, input, target=None):
        """
        anchor and positive(should include label)
@@ -16,22 +16,23 @@ class NpairsLoss(paddle.nn.Layer):
        features = input["features"]
        reg_lambda = self.reg_lambda
        batch_size = features.shape[0]
        fea_dim = features.shape[1]
        num_class = batch_size // 2

        #reshape
        out_feas = paddle.reshape(features, shape=[-1, 2, fea_dim])
-        anc_feas, pos_feas = paddle.split(out_feas, num_or_sections = 2, axis = 1)
-        anc_feas = paddle.squeeze(anc_feas, axis=1)
+        anc_feas, pos_feas = paddle.split(out_feas, num_or_sections=2, axis=1)
+        anc_feas = paddle.squeeze(anc_feas, axis=1)
        pos_feas = paddle.squeeze(pos_feas, axis=1)

        #get simi matrix
-        similarity_matrix = paddle.matmul(anc_feas, pos_feas, transpose_y=True) #get similarity matrix
+        similarity_matrix = paddle.matmul(
+            anc_feas, pos_feas, transpose_y=True)  #get similarity matrix
        sparse_labels = paddle.arange(0, num_class, dtype='int64')
-        xentloss = paddle.nn.CrossEntropyLoss()(similarity_matrix, sparse_labels) #by default: mean
+        xentloss = paddle.nn.CrossEntropyLoss()(
+            similarity_matrix, sparse_labels)  #by default: mean

        #l2 norm
        reg = paddle.mean(paddle.sum(paddle.square(features), axis=1))
        l2loss = 0.5 * reg_lambda * reg
        return {"npairsloss": xentloss + l2loss}
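The reshape above assumes the batch is laid out as consecutive (anchor, positive) pairs; a hedged sketch of that layout with assumed sizes:

import paddle

features = paddle.randn([6, 4])                        # assumed: 3 pairs, feature dim 4
out_feas = paddle.reshape(features, shape=[-1, 2, 4])  # [3, 2, 4]
anc, pos = paddle.split(out_feas, num_or_sections=2, axis=1)
anc = paddle.squeeze(anc, axis=1)                      # rows 0, 2, 4 (anchors)
pos = paddle.squeeze(pos, axis=1)                      # rows 1, 3, 5 (positives)
sim = paddle.matmul(anc, pos, transpose_y=True)        # [3, 3]; diagonal holds matching pairs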
@@ -19,6 +19,7 @@ from __future__ import print_function
import paddle
from .comfunc import rerange_index


class TriHardLoss(paddle.nn.Layer):
    """
    TriHard Loss, based on triplet loss. USE P * K samples.
@@ -32,45 +33,50 @@ class TriHardLoss(paddle.nn.Layer):
        ]
    only consider samples_each_class = 2
    """

-    def __init__(self, batch_size = 120, samples_each_class=2, margin=0.1):
+    def __init__(self, batch_size=120, samples_each_class=2, margin=0.1):
        super(TriHardLoss, self).__init__()
        self.margin = margin
        self.samples_each_class = samples_each_class
        self.batch_size = batch_size
        self.rerange_index = rerange_index(batch_size, samples_each_class)

    def forward(self, input, target=None):
        features = input["features"]
        assert (self.batch_size == features.shape[0])

        #normalization
        features = self._nomalize(features)
        samples_each_class = self.samples_each_class
        rerange_index = paddle.to_tensor(self.rerange_index)

        #calc sm
-        diffs = paddle.unsqueeze(features, axis=1) - paddle.unsqueeze(features, axis=0)
-        similary_matrix = paddle.sum(paddle.square(diffs), axis=-1)
+        diffs = paddle.unsqueeze(
+            features, axis=1) - paddle.unsqueeze(
+                features, axis=0)
+        similary_matrix = paddle.sum(paddle.square(diffs), axis=-1)

        #rerange
-        tmp = paddle.reshape(similary_matrix, shape = [-1, 1])
+        tmp = paddle.reshape(similary_matrix, shape=[-1, 1])
        tmp = paddle.gather(tmp, index=rerange_index)
        similary_matrix = paddle.reshape(tmp, shape=[-1, self.batch_size])

        #split
-        ignore, pos, neg = paddle.split(similary_matrix, num_or_sections= [1,
-            samples_each_class - 1, -1], axis = 1)
+        ignore, pos, neg = paddle.split(
+            similary_matrix,
+            num_or_sections=[1, samples_each_class - 1, -1],
+            axis=1)
        ignore.stop_gradient = True

        hard_pos = paddle.max(pos, axis=1)
        hard_neg = paddle.min(neg, axis=1)

        loss = hard_pos + self.margin - hard_neg
        loss = paddle.nn.ReLU()(loss)
        loss = paddle.mean(loss)
        return {"trihardloss": loss}

    def _nomalize(self, input):
-        input_norm = paddle.sqrt(paddle.sum(paddle.square(input), axis=1, keepdim=True))
+        input_norm = paddle.sqrt(
+            paddle.sum(paddle.square(input), axis=1, keepdim=True))
        return paddle.divide(input, input_norm)
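As a sanity check on the hinge at the end of forward, a tiny numeric example with assumed distances and margin:

import paddle

# Assumed: hardest positive distance 0.6, hardest negative distance 0.5, margin 0.1.
hard_pos = paddle.to_tensor(0.6)
hard_neg = paddle.to_tensor(0.5)
loss = paddle.nn.ReLU()(hard_pos + 0.1 - hard_neg)
print(float(loss))  # 0.2; reaches 0 once negatives sit at least `margin` farther away than positives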