Unverified · Commit 625b64bc, authored by Wei Shengyu, committed by GitHub

Merge pull request #776 from weisy11/develop_reg

modify format, remove useless files
@@ -34,5 +34,6 @@ d['valid'] = np.array(setid['valid'][0])
 d['test'] = np.array(setid['tstid'][0])
 for id in d[sys.argv[2]]:
-    message = str(data_path) + "/image_" + str(id).zfill(5) + ".jpg " + str(labels[id - 1] - 1)
+    message = str(data_path) + "/image_" + str(id).zfill(5) + ".jpg " + str(
+        labels[id - 1] - 1)
     print(message)
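
For context, each line this script prints is one list-file entry: an image path followed by a zero-based class label (the .mat labels are one-based, hence the trailing - 1). A hypothetical output line, with invented path and label:

    /data/flowers102/jpg/image_00001.jpg 76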
@@ -18,10 +18,10 @@ import importlib
 import paddle.nn as nn
 from . import backbone
-from . import head
+from . import gears
 from .backbone import *
-from .head import *
+from .gears import *
 from .utils import *
 __all__ = ["build_model", "RecModel"]
...
@@ -19,10 +19,11 @@ from .fc import FC
 __all__ = ['build_head']
 def build_head(config):
     support_dict = ['ArcMargin', 'CosMargin', 'CircleMargin', 'FC']
     module_name = config.pop('name')
-    assert module_name in support_dict, Exception('head only support {}'.format(
-        support_dict))
+    assert module_name in support_dict, Exception(
+        'head only support {}'.format(support_dict))
     module_class = eval(module_name)(**config)
     return module_class
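
For context, build_head pops the 'name' key and forwards the remaining config entries as keyword arguments to the selected head class. A minimal sketch using the FC head defined later in this diff (the config values are hypothetical):

    # Hypothetical config; 'name' selects the class, the rest become kwargs.
    head_config = {'name': 'FC', 'embedding_size': 512, 'class_num': 1000}
    head = build_head(head_config)  # same as FC(embedding_size=512, class_num=1000)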
@@ -16,30 +16,32 @@ import math
 import paddle
 import paddle.nn as nn
 import paddle.nn.functional as F
 class CircleMargin(nn.Layer):
-    def __init__(self, embedding_size,
-                 class_num,
-                 margin,
-                 scale):
+    def __init__(self, embedding_size, class_num, margin, scale):
         super(CircleMargin, self).__init__()
         self.scale = scale
         self.margin = margin
         self.embedding_size = embedding_size
         self.class_num = class_num
-        weight_attr = paddle.ParamAttr(initializer = paddle.nn.initializer.XavierNormal())
-        self.fc0 = paddle.nn.Linear(self.embedding_size, self.class_num, weight_attr=weight_attr)
+        weight_attr = paddle.ParamAttr(
+            initializer=paddle.nn.initializer.XavierNormal())
+        self.fc0 = paddle.nn.Linear(
+            self.embedding_size, self.class_num, weight_attr=weight_attr)
     def forward(self, input, label):
-        feat_norm = paddle.sqrt(paddle.sum(paddle.square(input), axis=1, keepdim=True))
+        feat_norm = paddle.sqrt(
+            paddle.sum(paddle.square(input), axis=1, keepdim=True))
         input = paddle.divide(input, feat_norm)
         weight = self.fc0.weight
-        weight_norm = paddle.sqrt(paddle.sum(paddle.square(weight), axis=0, keepdim=True))
+        weight_norm = paddle.sqrt(
+            paddle.sum(paddle.square(weight), axis=0, keepdim=True))
         weight = paddle.divide(weight, weight_norm)
         logits = paddle.matmul(input, weight)
         alpha_p = paddle.clip(-logits.detach() + 1 + self.margin, min=0.)
         alpha_n = paddle.clip(logits.detach() + self.margin, min=0.)
@@ -51,5 +53,5 @@ class CircleMargin(nn.Layer):
         logits_n = alpha_n * (logits - delta_n)
         pre_logits = logits_p * m_hot + logits_n * (1 - m_hot)
         pre_logits = self.scale * pre_logits
         return pre_logits
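
The lines elided between the two hunks presumably build m_hot from label and set the decision margins; in the Circle Loss formulation (Sun et al., CVPR 2020) these are delta_p = 1 - margin and delta_n = margin, which is consistent with the alpha_p and alpha_n clipping shown above. A minimal usage sketch (shapes and values are hypothetical):

    import paddle
    head = CircleMargin(embedding_size=512, class_num=100, margin=0.25, scale=64)
    feats = paddle.randn([8, 512])           # a batch of 8 embeddings
    labels = paddle.randint(0, 100, [8, 1])  # one class id per sample
    logits = head(feats, labels)             # scaled logits, shape [8, 100]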
@@ -16,35 +16,41 @@ import paddle
 import math
 import paddle.nn as nn
 class CosMargin(paddle.nn.Layer):
-    def __init__(self, embedding_size,
-                 class_num,
-                 margin=0.35,
-                 scale=64.0):
+    def __init__(self, embedding_size, class_num, margin=0.35, scale=64.0):
         super(CosMargin, self).__init__()
         self.scale = scale
         self.margin = margin
         self.embedding_size = embedding_size
         self.class_num = class_num
-        weight_attr = paddle.ParamAttr(initializer = paddle.nn.initializer.XavierNormal())
-        self.fc = nn.Linear(self.embedding_size, self.class_num, weight_attr=weight_attr, bias_attr=False)
+        weight_attr = paddle.ParamAttr(
+            initializer=paddle.nn.initializer.XavierNormal())
+        self.fc = nn.Linear(
+            self.embedding_size,
+            self.class_num,
+            weight_attr=weight_attr,
+            bias_attr=False)
     def forward(self, input, label):
         label.stop_gradient = True
-        input_norm = paddle.sqrt(paddle.sum(paddle.square(input), axis=1, keepdim=True))
-        input = paddle.divide(input, x_norm)
+        input_norm = paddle.sqrt(
+            paddle.sum(paddle.square(input), axis=1, keepdim=True))
+        input = paddle.divide(input, input_norm)
         weight = self.fc.weight
-        weight_norm = paddle.sqrt(paddle.sum(paddle.square(weight), axis=0, keepdim=True))
+        weight_norm = paddle.sqrt(
+            paddle.sum(paddle.square(weight), axis=0, keepdim=True))
         weight = paddle.divide(weight, weight_norm)
         cos = paddle.matmul(input, weight)
         cos_m = cos - self.margin
         one_hot = paddle.nn.functional.one_hot(label, self.class_num)
         one_hot = paddle.squeeze(one_hot, axis=[1])
-        output = paddle.multiply(one_hot, cos_m) + paddle.multiply((1.0 - one_hot), cos)
+        output = paddle.multiply(one_hot, cos_m) + paddle.multiply(
+            (1.0 - one_hot), cos)
         output = output * self.scale
         return output
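
For reference, this head applies the CosFace (large margin cosine) rule: with features and weight columns normalized, the matmul yields cos(theta), and

    output[i][j] = scale * (cos(theta[i][j]) - margin)   if j == label[i]
    output[i][j] = scale * cos(theta[i][j])              otherwise

so the margin is subtracted only from the target-class logit before scaling.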
@@ -19,14 +19,16 @@ from __future__ import print_function
 import paddle
 import paddle.nn as nn
 class FC(nn.Layer):
-    def __init__(self, embedding_size,
-                 class_num):
+    def __init__(self, embedding_size, class_num):
         super(FC, self).__init__()
         self.embedding_size = embedding_size
         self.class_num = class_num
-        weight_attr = paddle.ParamAttr(initializer = paddle.nn.initializer.XavierNormal())
-        self.fc = paddle.nn.Linear(self.embedding_size, self.class_num, weight_attr=weight_attr)
+        weight_attr = paddle.ParamAttr(
+            initializer=paddle.nn.initializer.XavierNormal())
+        self.fc = paddle.nn.Linear(
+            self.embedding_size, self.class_num, weight_attr=weight_attr)
     def forward(self, input, label):
         out = self.fc(input)
...
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn.functional as F

__all__ = ['CELoss', 'MixCELoss', 'GoogLeNetLoss', 'JSDivLoss', 'MultiLabelLoss']


class Loss(object):
    """
    Loss
    """

    def __init__(self, class_dim=1000, epsilon=None):
        assert class_dim > 1, "class_dim=%d is not larger than 1" % (class_dim)
        self._class_dim = class_dim
        if epsilon is not None and epsilon >= 0.0 and epsilon <= 1.0:
            self._epsilon = epsilon
            self._label_smoothing = True
        else:
            self._epsilon = None
            self._label_smoothing = False

    def _labelsmoothing(self, target):
        if target.shape[-1] != self._class_dim:
            one_hot_target = F.one_hot(target, self._class_dim)
        else:
            one_hot_target = target
        soft_target = F.label_smooth(one_hot_target, epsilon=self._epsilon)
        soft_target = paddle.reshape(soft_target, shape=[-1, self._class_dim])
        return soft_target

    def _binary_crossentropy(self, input, target):
        if self._label_smoothing:
            target = self._labelsmoothing(target)
        cost = F.binary_cross_entropy_with_logits(logit=input, label=target)
        avg_cost = paddle.mean(cost)
        return avg_cost

    def _crossentropy(self, input, target):
        if self._label_smoothing:
            target = self._labelsmoothing(target)
            input = -F.log_softmax(input, axis=-1)
            cost = paddle.sum(target * input, axis=-1)
        else:
            cost = F.cross_entropy(input=input, label=target)
        avg_cost = paddle.mean(cost)
        return avg_cost

    def _kldiv(self, input, target, name=None):
        eps = 1.0e-10
        cost = target * paddle.log(
            (target + eps) / (input + eps)) * self._class_dim
        return cost

    def _jsdiv(self, input, target):
        input = F.softmax(input)
        target = F.softmax(target)
        cost = self._kldiv(input, target) + self._kldiv(target, input)
        cost = cost / 2
        avg_cost = paddle.mean(cost)
        return avg_cost

    def __call__(self, input, target):
        pass


class MultiLabelLoss(Loss):
    """
    Multi-label loss based on binary cross entropy
    """

    def __init__(self, class_dim=1000, epsilon=None):
        super(MultiLabelLoss, self).__init__(class_dim, epsilon)

    def __call__(self, input, target):
        cost = self._binary_crossentropy(input, target)
        return cost


class CELoss(Loss):
    """
    Cross entropy loss
    """

    def __init__(self, class_dim=1000, epsilon=None):
        super(CELoss, self).__init__(class_dim, epsilon)

    def __call__(self, input, target):
        cost = self._crossentropy(input, target)
        return cost


class MixCELoss(Loss):
    """
    Cross entropy loss with mix (mixup, cutmix, fmix)
    """

    def __init__(self, class_dim=1000, epsilon=None):
        super(MixCELoss, self).__init__(class_dim, epsilon)

    def __call__(self, input, target0, target1, lam):
        cost0 = self._crossentropy(input, target0)
        cost1 = self._crossentropy(input, target1)
        cost = lam * cost0 + (1.0 - lam) * cost1
        avg_cost = paddle.mean(cost)
        return avg_cost


class GoogLeNetLoss(Loss):
    """
    Cross entropy loss for GoogLeNet, combining the main and two auxiliary outputs
    """

    def __init__(self, class_dim=1000, epsilon=None):
        super(GoogLeNetLoss, self).__init__(class_dim, epsilon)

    def __call__(self, input0, input1, input2, target):
        cost0 = self._crossentropy(input0, target)
        cost1 = self._crossentropy(input1, target)
        cost2 = self._crossentropy(input2, target)
        cost = cost0 + 0.3 * cost1 + 0.3 * cost2
        avg_cost = paddle.mean(cost)
        return avg_cost


class JSDivLoss(Loss):
    """
    Jensen-Shannon divergence loss
    """

    def __init__(self, class_dim=1000, epsilon=None):
        super(JSDivLoss, self).__init__(class_dim, epsilon)

    def __call__(self, input, target):
        cost = self._jsdiv(input, target)
        return cost
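
For reference, F.label_smooth computes soft_target = (1 - epsilon) * one_hot + epsilon / class_dim, so _labelsmoothing turns hard labels into slightly flattened distributions before the soft cross entropy branch. A minimal sanity check of CELoss (values hypothetical):

    import paddle
    loss_fn = CELoss(class_dim=10, epsilon=0.1)  # label smoothing enabled
    logits = paddle.randn([4, 10])
    labels = paddle.randint(0, 10, [4])
    print(loss_fn(logits, labels))               # scalar mean loss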
@@ -11,26 +11,25 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import copy
 import paddle
 import numpy as np
 from paddle.io import DistributedBatchSampler, BatchSampler, DataLoader
 from ppcls.utils import logger
-from . import dataset
-from . import imaug
-from . import samplers
+from ppcls.data import dataloader
+from ppcls.data import imaug
 # dataset
-from .dataset.imagenet_dataset import ImageNetDataset
-from .dataset.multilabel_dataset import MultiLabelDataset
-from .dataset.common_dataset import create_operators
-from .dataset.vehicle_dataset import CompCars, VeriWild
+from ppcls.data.dataloader.imagenet_dataset import ImageNetDataset
+from ppcls.data.dataloader.multilabel_dataset import MultiLabelDataset
+from ppcls.data.dataloader.common_dataset import create_operators
+from ppcls.data.dataloader.vehicle_dataset import CompCars, VeriWild
 # sampler
-from .samplers import DistributedRandomIdentitySampler
-from .preprocess import transform
+from ppcls.data.dataloader import DistributedRandomIdentitySampler
+from ppcls.data.preprocess import transform
 def build_dataloader(config, mode, device, seed=None):
...
@@ -14,22 +14,13 @@
 from __future__ import print_function
-import io
-import tarfile
 import numpy as np
-from PIL import Image
+#all use default backend
-import paddle
-from paddle.io import Dataset
-import pickle
 import os
-import cv2
-import random
 from .common_dataset import CommonDataset
 class ImageNetDataset(CommonDataset):
     def _load_anno(self, seed=None):
         assert os.path.exists(self._cls_path)
         assert os.path.exists(self._img_root)
@@ -47,5 +38,3 @@ class ImageNetDataset(CommonDataset):
         self.images.append(os.path.join(self._img_root, l[0]))
         self.labels.append(int(l[1]))
         assert os.path.exists(self.images[-1])
@@ -14,26 +14,17 @@
 from __future__ import print_function
-import io
-import tarfile
 import numpy as np
-from PIL import Image
+#all use default backend
-import paddle
-from paddle.io import Dataset
-import pickle
 import os
 import cv2
-import random
-from ppcls.data import preprocess
 from ppcls.data.preprocess import transform
 from ppcls.utils import logger
 from .common_dataset import CommonDataset
 class MultiLabelDataset(CommonDataset):
     def _load_anno(self):
         assert os.path.exists(self._cls_path)
         assert os.path.exists(self._img_root)
...
from .DistributedRandomIdentitySampler import DistributedRandomIdentitySampler
@@ -30,7 +30,7 @@ from ppcls.utils.misc import AverageMeter
 from ppcls.utils import logger
 from ppcls.data import build_dataloader
 from ppcls.arch import build_model
-from ppcls.losses import build_loss
+from ppcls.loss import build_loss
 from ppcls.arch.loss_metrics import build_metrics
 from ppcls.optimizer import build_optimizer
 from ppcls.utils.save_load import load_dygraph_pretrain
...
@@ -5,12 +5,15 @@ import paddle
 import paddle.nn as nn
 import paddle.nn.functional as F
 class CenterLoss(nn.Layer):
     def __init__(self, num_classes=5013, feat_dim=2048):
         super(CenterLoss, self).__init__()
         self.num_classes = num_classes
         self.feat_dim = feat_dim
-        self.centers = paddle.randn(shape=[self.num_classes, self.feat_dim]).astype("float64") #random center
+        self.centers = paddle.randn(
+            shape=[self.num_classes, self.feat_dim]).astype(
+                "float64") #random center
     def __call__(self, input, target):
         """
@@ -23,25 +26,29 @@ class CenterLoss(nn.Layer):
         #calc feat * feat
         dist1 = paddle.sum(paddle.square(feats), axis=1, keepdim=True)
         dist1 = paddle.expand(dist1, [batch_size, self.num_classes])
         #dist2 of centers
-        dist2 = paddle.sum(paddle.square(self.centers), axis=1, keepdim=True) #num_classes
-        dist2 = paddle.expand(dist2, [self.num_classes, batch_size]).astype("float64")
+        dist2 = paddle.sum(paddle.square(self.centers), axis=1,
+                           keepdim=True) #num_classes
+        dist2 = paddle.expand(dist2,
+                              [self.num_classes, batch_size]).astype("float64")
         dist2 = paddle.transpose(dist2, [1, 0])
         #first x * x + y * y
         distmat = paddle.add(dist1, dist2)
         tmp = paddle.matmul(feats, paddle.transpose(self.centers, [1, 0]))
         distmat = distmat - 2.0 * tmp
         #generate the mask
         classes = paddle.arange(self.num_classes).astype("int64")
-        labels = paddle.expand(paddle.unsqueeze(labels, 1), (batch_size, self.num_classes))
-        mask = paddle.equal(paddle.expand(classes, [batch_size, self.num_classes]), labels).astype("float64") #get mask
+        labels = paddle.expand(
+            paddle.unsqueeze(labels, 1), (batch_size, self.num_classes))
+        mask = paddle.equal(
+            paddle.expand(classes, [batch_size, self.num_classes]),
+            labels).astype("float64") #get mask
         dist = paddle.multiply(distmat, mask)
         loss = paddle.sum(paddle.clip(dist, min=1e-12, max=1e+12)) / batch_size
         return {'CenterLoss': loss}
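
The distmat computation is the standard expansion of the squared Euclidean distance between each feature and each class center:

    ||x_i - c_j||^2 = ||x_i||^2 + ||c_j||^2 - 2 * (x_i . c_j)

dist1 supplies the first term, dist2 (after the transpose) the second, and the matmul supplies the cross term.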
@@ -18,26 +18,27 @@ from __future__ import print_function
 import numpy as np
 def rerange_index(batch_size, samples_each_class):
     tmp = np.arange(0, batch_size * batch_size)
     tmp = tmp.reshape(-1, batch_size)
     rerange_index = []
     for i in range(batch_size):
         step = i // samples_each_class
         start = step * samples_each_class
         end = (step + 1) * samples_each_class
         pos_idx = []
         neg_idx = []
         for j, k in enumerate(tmp[i]):
             if j >= start and j < end:
                 if j == i:
                     pos_idx.insert(0, k)
                 else:
                     pos_idx.append(k)
             else:
                 neg_idx.append(k)
         rerange_index += (pos_idx + neg_idx)
     rerange_index = np.array(rerange_index).astype(np.int32)
...
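
To make the reordering concrete, a small worked example (computed by hand for illustration):

    # batch_size=4, samples_each_class=2: indices of the flattened 4x4
    # distance matrix are reordered so each row lists the self-distance first,
    # then the same-class sample, then the rest of the batch.
    rerange_index(4, 2)
    # -> [ 0  1  2  3  5  4  6  7 10 11  8  9 15 14 12 13]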
@@ -21,56 +21,64 @@ import paddle
 import numpy as np
 from .comfunc import rerange_index
 class EmlLoss(paddle.nn.Layer):
-    def __init__(self, batch_size = 40, samples_each_class = 2):
+    def __init__(self, batch_size=40, samples_each_class=2):
         super(EmlLoss, self).__init__()
-        assert(batch_size % samples_each_class == 0)
+        assert (batch_size % samples_each_class == 0)
         self.samples_each_class = samples_each_class
         self.batch_size = batch_size
         self.rerange_index = rerange_index(batch_size, samples_each_class)
         self.thresh = 20.0
         self.beta = 100000
     def surrogate_function(self, beta, theta, bias):
         x = theta * paddle.exp(bias)
         output = paddle.log(1 + beta * x) / math.log(1 + beta)
         return output
     def surrogate_function_approximate(self, beta, theta, bias):
-        output = (paddle.log(theta) + bias + math.log(beta)) / math.log(1+beta)
+        output = (
+            paddle.log(theta) + bias + math.log(beta)) / math.log(1 + beta)
         return output
     def surrogate_function_stable(self, beta, theta, target, thresh):
         max_gap = paddle.to_tensor(thresh, dtype='float32')
         max_gap.stop_gradient = True
         target_max = paddle.maximum(target, max_gap)
         target_min = paddle.minimum(target, max_gap)
         loss1 = self.surrogate_function(beta, theta, target_min)
         loss2 = self.surrogate_function_approximate(beta, theta, target_max)
         bias = self.surrogate_function(beta, theta, max_gap)
         loss = loss1 + loss2 - bias
         return loss
     def forward(self, input, target=None):
         features = input["features"]
         samples_each_class = self.samples_each_class
         batch_size = self.batch_size
         rerange_index = self.rerange_index
         #calc distance
-        diffs = paddle.unsqueeze(features, axis=1) - paddle.unsqueeze(features, axis=0)
+        diffs = paddle.unsqueeze(
+            features, axis=1) - paddle.unsqueeze(
+                features, axis=0)
         similary_matrix = paddle.sum(paddle.square(diffs), axis=-1)
-        tmp = paddle.reshape(similary_matrix, shape = [-1, 1])
+        tmp = paddle.reshape(similary_matrix, shape=[-1, 1])
         rerange_index = paddle.to_tensor(rerange_index)
         tmp = paddle.gather(tmp, index=rerange_index)
         similary_matrix = paddle.reshape(tmp, shape=[-1, batch_size])
-        ignore, pos, neg = paddle.split(similary_matrix, num_or_sections= [1,
-            samples_each_class - 1, batch_size - samples_each_class], axis = 1)
+        ignore, pos, neg = paddle.split(
+            similary_matrix,
+            num_or_sections=[
+                1, samples_each_class - 1, batch_size - samples_each_class
+            ],
+            axis=1)
         ignore.stop_gradient = True
         pos_max = paddle.max(pos, axis=1, keepdim=True)
         pos = paddle.exp(pos - pos_max)
@@ -79,11 +87,11 @@ class EmlLoss(paddle.nn.Layer):
         neg_min = paddle.min(neg, axis=1, keepdim=True)
         neg = paddle.exp(neg_min - neg)
         neg_mean = paddle.mean(neg, axis=1, keepdim=True)
         bias = pos_max - neg_min
         theta = paddle.multiply(neg_mean, pos_mean)
-        loss = self.surrogate_function_stable(self.beta, theta, bias, self.thresh)
+        loss = self.surrogate_function_stable(self.beta, theta, bias,
+                                              self.thresh)
         loss = paddle.mean(loss)
         return {"emlloss": loss}
@@ -18,6 +18,7 @@ from __future__ import print_function
 import paddle
 from .comfunc import rerange_index
 class MSMLoss(paddle.nn.Layer):
     """
     MSMLoss Loss, based on triplet loss. USE P * K samples.
@@ -31,42 +32,47 @@ class MSMLoss(paddle.nn.Layer):
     ]
     only consider samples_each_class = 2
     """
-    def __init__(self, batch_size = 120, samples_each_class=2, margin=0.1):
+    def __init__(self, batch_size=120, samples_each_class=2, margin=0.1):
         super(MSMLoss, self).__init__()
         self.margin = margin
         self.samples_each_class = samples_each_class
         self.batch_size = batch_size
         self.rerange_index = rerange_index(batch_size, samples_each_class)
     def forward(self, input, target=None):
         #normalization
         features = input["features"]
         features = self._nomalize(features)
         samples_each_class = self.samples_each_class
         rerange_index = paddle.to_tensor(self.rerange_index)
         #calc sm
-        diffs = paddle.unsqueeze(features, axis=1) - paddle.unsqueeze(features, axis=0)
+        diffs = paddle.unsqueeze(
+            features, axis=1) - paddle.unsqueeze(
+                features, axis=0)
         similary_matrix = paddle.sum(paddle.square(diffs), axis=-1)
         #rerange
-        tmp = paddle.reshape(similary_matrix, shape = [-1, 1])
+        tmp = paddle.reshape(similary_matrix, shape=[-1, 1])
         tmp = paddle.gather(tmp, index=rerange_index)
         similary_matrix = paddle.reshape(tmp, shape=[-1, self.batch_size])
         #split
-        ignore, pos, neg = paddle.split(similary_matrix, num_or_sections= [1,
-            samples_each_class - 1, -1], axis = 1)
+        ignore, pos, neg = paddle.split(
+            similary_matrix,
+            num_or_sections=[1, samples_each_class - 1, -1],
+            axis=1)
         ignore.stop_gradient = True
         hard_pos = paddle.max(pos)
         hard_neg = paddle.min(neg)
         loss = hard_pos + self.margin - hard_neg
         loss = paddle.nn.ReLU()(loss)
         return {"msmloss": loss}
     def _nomalize(self, input):
-        input_norm = paddle.sqrt(paddle.sum(paddle.square(input), axis=1, keepdim=True))
+        input_norm = paddle.sqrt(
+            paddle.sum(paddle.square(input), axis=1, keepdim=True))
         return paddle.divide(input, input_norm)
@@ -3,12 +3,12 @@ from __future__ import division
 from __future__ import print_function
 import paddle
 class NpairsLoss(paddle.nn.Layer):
     def __init__(self, reg_lambda=0.01):
         super(NpairsLoss, self).__init__()
         self.reg_lambda = reg_lambda
     def forward(self, input, target=None):
         """
         anchor and positive(should include label)
@@ -16,22 +16,23 @@ class NpairsLoss(paddle.nn.Layer):
         features = input["features"]
         reg_lambda = self.reg_lambda
         batch_size = features.shape[0]
         fea_dim = features.shape[1]
         num_class = batch_size // 2
         #reshape
         out_feas = paddle.reshape(features, shape=[-1, 2, fea_dim])
-        anc_feas, pos_feas = paddle.split(out_feas, num_or_sections = 2, axis = 1)
+        anc_feas, pos_feas = paddle.split(out_feas, num_or_sections=2, axis=1)
         anc_feas = paddle.squeeze(anc_feas, axis=1)
         pos_feas = paddle.squeeze(pos_feas, axis=1)
         #get simi matrix
-        similarity_matrix = paddle.matmul(anc_feas, pos_feas, transpose_y=True) #get similarity matrix
+        similarity_matrix = paddle.matmul(
+            anc_feas, pos_feas, transpose_y=True) #get similarity matrix
         sparse_labels = paddle.arange(0, num_class, dtype='int64')
-        xentloss = paddle.nn.CrossEntropyLoss()(similarity_matrix, sparse_labels) #by default: mean
+        xentloss = paddle.nn.CrossEntropyLoss()(
+            similarity_matrix, sparse_labels) #by default: mean
         #l2 norm
         reg = paddle.mean(paddle.sum(paddle.square(features), axis=1))
         l2loss = 0.5 * reg_lambda * reg
         return {"npairsloss": xentloss + l2loss}
@@ -19,6 +19,7 @@ from __future__ import print_function
 import paddle
 from .comfunc import rerange_index
 class TriHardLoss(paddle.nn.Layer):
     """
     TriHard Loss, based on triplet loss. USE P * K samples.
@@ -32,45 +33,50 @@ class TriHardLoss(paddle.nn.Layer):
     ]
     only consider samples_each_class = 2
     """
-    def __init__(self, batch_size = 120, samples_each_class=2, margin=0.1):
+    def __init__(self, batch_size=120, samples_each_class=2, margin=0.1):
         super(TriHardLoss, self).__init__()
         self.margin = margin
         self.samples_each_class = samples_each_class
         self.batch_size = batch_size
         self.rerange_index = rerange_index(batch_size, samples_each_class)
     def forward(self, input, target=None):
         features = input["features"]
         assert (self.batch_size == features.shape[0])
         #normalization
         features = self._nomalize(features)
         samples_each_class = self.samples_each_class
         rerange_index = paddle.to_tensor(self.rerange_index)
         #calc sm
-        diffs = paddle.unsqueeze(features, axis=1) - paddle.unsqueeze(features, axis=0)
+        diffs = paddle.unsqueeze(
+            features, axis=1) - paddle.unsqueeze(
+                features, axis=0)
         similary_matrix = paddle.sum(paddle.square(diffs), axis=-1)
         #rerange
-        tmp = paddle.reshape(similary_matrix, shape = [-1, 1])
+        tmp = paddle.reshape(similary_matrix, shape=[-1, 1])
         tmp = paddle.gather(tmp, index=rerange_index)
         similary_matrix = paddle.reshape(tmp, shape=[-1, self.batch_size])
         #split
-        ignore, pos, neg = paddle.split(similary_matrix, num_or_sections= [1,
-            samples_each_class - 1, -1], axis = 1)
+        ignore, pos, neg = paddle.split(
+            similary_matrix,
+            num_or_sections=[1, samples_each_class - 1, -1],
+            axis=1)
         ignore.stop_gradient = True
         hard_pos = paddle.max(pos, axis=1)
         hard_neg = paddle.min(neg, axis=1)
         loss = hard_pos + self.margin - hard_neg
         loss = paddle.nn.ReLU()(loss)
         loss = paddle.mean(loss)
         return {"trihardloss": loss}
     def _nomalize(self, input):
-        input_norm = paddle.sqrt(paddle.sum(paddle.square(input), axis=1, keepdim=True))
+        input_norm = paddle.sqrt(
+            paddle.sum(paddle.square(input), axis=1, keepdim=True))
         return paddle.divide(input, input_norm)
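
Worth noting how the mining differs from MSMLoss above; in comment form:

    # MSMLoss:     loss = relu(max(pos) + margin - min(neg))
    #              -> one hardest positive/negative pair for the whole batch
    # TriHardLoss: loss = mean(relu(max(pos, axis=1) + margin - min(neg, axis=1)))
    #              -> hardest positive/negative per anchor row, then averaged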