Commit 674447f6 authored by HydrogenSulfate

refine code

Parent 97e8abc3
......@@ -48,6 +48,7 @@ Loss:
weight: 1.0
margin: 0.3
normalize_feature: false
feat_from: "backbone"
Eval:
- CELoss:
weight: 1.0
......
......@@ -61,6 +61,7 @@ Loss:
weight: 1.0
margin: 0.3
normalize_feature: false
feat_from: "backbone"
Eval:
- CELoss:
weight: 1.0
......
......@@ -40,7 +40,7 @@ Arch:
initializer:
name: Constant
value: 0.0
learning_rate: 1.0e-20 # TODO: Temporarily set lr small enough to freeze the bias
learning_rate: 1.0e-20 # NOTE: Temporarily set lr small enough to freeze the bias
Head:
name: "FC"
embedding_size: *feat_dim
......@@ -57,14 +57,16 @@ Loss:
- CELoss:
weight: 1.0
epsilon: 0.1
- TripletLossV3:
- TripletLossV2:
weight: 1.0
margin: 0.3
normalize_feature: false
feat_from: "backbone"
- CenterLoss:
weight: 0.0005
num_classes: *class_num
feat_dim: *feat_dim
feat_from: "backbone"
Eval:
- CELoss:
weight: 1.0
......
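The `feat_from` field added to the loss entries above is passed, together with every key except `weight`, into the corresponding loss constructor, so each loss can select which feature (backbone or neck output) it consumes. A minimal sketch of that mapping, assuming a kwargs-based builder like the one in ppcls/loss/__init__.py (the builder itself is not part of this diff):

```python
# Hypothetical, simplified view of how one Loss entry from the YAML above becomes
# a loss instance; the real builder in ppcls/loss/__init__.py may differ in detail.
loss_cfg = {"TripletLossV2": {"weight": 1.0,
                              "margin": 0.3,
                              "normalize_feature": False,
                              "feat_from": "backbone"}}

name, params = next(iter(loss_cfg.items()))
weight = params.pop("weight")  # used when summing losses, not by the loss class itself
# the remaining keys (margin, normalize_feature, feat_from) become constructor kwargs:
# loss_fn = getattr(ppcls.loss, name)(**params)
```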
......@@ -28,12 +28,17 @@ class CenterLoss(nn.Layer):
Args:
num_classes (int): number of classes.
feat_dim (int): number of feature dimensions.
feat_from (str): features from backbone or neck
"""
def __init__(self, num_classes: int, feat_dim: int):
def __init__(self,
num_classes: int,
feat_dim: int,
feat_from: str='backbone'):
super(CenterLoss, self).__init__()
self.num_classes = num_classes
self.feat_dim = feat_dim
self.feat_from = feat_from
random_init_centers = paddle.randn(
shape=[self.num_classes, self.feat_dim])
self.centers = self.create_parameter(
......@@ -52,7 +57,7 @@ class CenterLoss(nn.Layer):
Returns:
Dict[str, paddle.Tensor]: {'CenterLoss': loss}.
"""
feats = input['backbone']
feats = input[self.feat_from]
labels = target
# squeeze labels to shape (batch_size, )
......
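With this change `CenterLoss` looks up its features under the key named by `feat_from` instead of the hard-coded `'backbone'`. A minimal usage sketch; the module path, shapes, and class/feature counts are illustrative assumptions, not taken from this commit:

```python
import paddle
from ppcls.loss.centerloss import CenterLoss  # module path assumed

batch_size, num_classes, feat_dim = 8, 751, 512  # illustrative values
loss_fn = CenterLoss(num_classes=num_classes, feat_dim=feat_dim, feat_from="backbone")

# the trainer passes a dict of intermediate outputs; CenterLoss reads input[feat_from]
features = {"backbone": paddle.randn([batch_size, feat_dim]),
            "neck": paddle.randn([batch_size, feat_dim])}
labels = paddle.randint(0, num_classes, [batch_size])

out = loss_fn(features, labels)  # {'CenterLoss': scalar loss tensor}
```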
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from typing import Tuple
import paddle
import paddle.nn as nn
......@@ -13,9 +26,13 @@ class TripletLossV2(nn.Layer):
margin (float): margin for triplet.
"""
def __init__(self, margin=0.5, normalize_feature=True):
def __init__(self,
margin=0.5,
normalize_feature=True,
feat_from='backbone'):
super(TripletLossV2, self).__init__()
self.margin = margin
self.feat_from = feat_from
self.ranking_loss = paddle.nn.loss.MarginRankingLoss(margin=margin)
self.normalize_feature = normalize_feature
......@@ -25,7 +42,7 @@ class TripletLossV2(nn.Layer):
inputs: feature matrix with shape (batch_size, feat_dim)
            target: ground truth labels with shape (batch_size, )
"""
inputs = input["backbone"]
inputs = input[self.feat_from]
if self.normalize_feature:
inputs = 1. * inputs / (paddle.expand_as(
......@@ -136,122 +153,3 @@ class TripletLoss(nn.Layer):
y = paddle.ones_like(dist_an)
loss = self.ranking_loss(dist_an, dist_ap, y)
return {"TripletLoss": loss}
class TripletLossV3(nn.Layer):
"""Modified from Tong Xiao's open-reid (https://github.com/Cysu/open-reid).
Related Triplet Loss theory can be found in paper 'In Defense of the Triplet
Loss for Person Re-Identification'."""
def __init__(self, margin=None, normalize_feature=False):
super(TripletLossV3, self).__init__()
self.normalize_feature = normalize_feature
self.margin = margin
if margin is not None:
self.ranking_loss = nn.MarginRankingLoss(margin=margin)
else:
self.ranking_loss = nn.SoftMarginLoss()
def forward(self, input, target):
global_feat = input["backbone"]
if self.normalize_feature:
global_feat = self._normalize(global_feat, axis=-1)
dist_mat = self._euclidean_dist(global_feat, global_feat)
dist_ap, dist_an = self._hard_example_mining(dist_mat, target)
y = paddle.ones_like(dist_an)
if self.margin is not None:
loss = self.ranking_loss(dist_an, dist_ap, y)
return {"TripletLossV3": loss}
def _normalize(self, x: paddle.Tensor, axis: int=-1) -> paddle.Tensor:
"""Normalizing to unit length along the specified dimension.
Args:
x (paddle.Tensor): (batch_size, feature_dim)
axis (int, optional): normalization dim. Defaults to -1.
Returns:
paddle.Tensor: (batch_size, feature_dim)
"""
x = 1. * x / (paddle.norm(
x, 2, axis, keepdim=True).expand_as(x) + 1e-12)
return x
def _euclidean_dist(self, x: paddle.Tensor,
y: paddle.Tensor) -> paddle.Tensor:
"""compute euclidean distance between two batched vectors
Args:
x (paddle.Tensor): (N, feature_dim)
y (paddle.Tensor): (M, feature_dim)
Returns:
paddle.Tensor: (N, M)
"""
m, n = x.shape[0], y.shape[0]
d = x.shape[1]
xx = paddle.pow(x, 2).sum(1, keepdim=True).expand([m, n])
yy = paddle.pow(y, 2).sum(1, keepdim=True).expand([n, m]).t()
dist = xx + yy
dist = dist.addmm(x, y.t(), alpha=-2, beta=1)
# dist = dist - 2*(x@y.t())
dist = dist.clip(min=1e-12).sqrt() # for numerical stability
return dist
def _hard_example_mining(
self,
dist_mat: paddle.Tensor,
labels: paddle.Tensor,
return_inds: bool=False) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""For each anchor, find the hardest positive and negative sample.
Args:
dist_mat (paddle.Tensor): pair wise distance between samples, [N, N]
labels (paddle.Tensor): labels, [N, ]
return_inds (bool, optional): whether to return the indices . Defaults to False.
Returns:
Tuple[paddle.Tensor, paddle.Tensor]: [(N, ), (N, )]
NOTE: Only consider the case in which all labels have same num of samples,
thus we can cope with all anchors in parallel.
"""
assert len(dist_mat.shape) == 2
assert dist_mat.shape[0] == dist_mat.shape[1]
N = dist_mat.shape[0]
# shape [N, N]
is_pos = labels.expand([N, N]).equal(labels.expand([N, N]).t())
is_neg = labels.expand([N, N]).not_equal(labels.expand([N, N]).t())
# `dist_ap` means distance(anchor, positive)
# both `dist_ap` and `relative_p_inds` with shape [N, 1]
dist_ap = paddle.max(dist_mat[is_pos].reshape([N, -1]),
1,
keepdim=True)
# `dist_an` means distance(anchor, negative)
# both `dist_an` and `relative_n_inds` with shape [N, 1]
dist_an = paddle.min(dist_mat[is_neg].reshape([N, -1]),
1,
keepdim=True)
# shape [N]
dist_ap = dist_ap.squeeze(1)
dist_an = dist_an.squeeze(1)
if return_inds:
# shape [N, N]
ind = (labels.new().resize_as_(labels)
.copy_(paddle.arange(0, N).long())
.unsqueeze(0).expand(N, N))
# shape [N, 1]
p_inds = paddle.gather(ind[is_pos].reshape([N, -1]), 1,
relative_p_inds.data)
n_inds = paddle.gather(ind[is_neg].reshape([N, -1]), 1,
relative_n_inds.data)
# shape [N]
p_inds = p_inds.squeeze(1)
n_inds = n_inds.squeeze(1)
return dist_ap, dist_an, p_inds, n_inds
return dist_ap, dist_an
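After this commit the configs use `TripletLossV2` (the `TripletLossV3` class above is removed), and its new `feat_from` argument works the same way as in `CenterLoss`. A minimal calling sketch, assuming the usual re-ID sampling where every identity contributes the same number of samples per batch; the module path and feature dimension are illustrative:

```python
import paddle
from ppcls.loss.triplet import TripletLossV2  # module path assumed

loss_fn = TripletLossV2(margin=0.3, normalize_feature=False, feat_from="backbone")

# 4 identities x 4 samples each, so each anchor has both positives and negatives
labels = paddle.to_tensor([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3])
features = {"backbone": paddle.randn([16, 512])}  # illustrative feature dim

out = loss_fn(features, labels)  # expected to return {'TripletLossV2': scalar loss}
```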
......@@ -46,7 +46,7 @@ def build_lr_scheduler(lr_config, epochs, step_each_epoch):
def build_optimizer(config, epochs, step_each_epoch, model_list=None):
optim_config = copy.deepcopy(config)
if isinstance(optim_config, dict):
# convert {'name': xxx, **optim_cfg} to [{'name': {'scope': xxx, **optim_cfg}}]
# convert {'name': xxx, **optim_cfg} to [{name: {scope: xxx, **optim_cfg}}]
optim_name = optim_config.pop("name")
optim_config: List[Dict[str, Dict]] = [{
optim_name: {
......@@ -60,20 +60,19 @@ def build_optimizer(config, epochs, step_each_epoch, model_list=None):
"""NOTE:
    Currently only supports the optim objects below.
1. single optimizer config.
2. model(entire Arch), backbone, neck, head.
3. loss(entire Loss), specific loss listed in ppcls/loss/__init__.py.
    2. next level under Arch, such as Arch.backbone, Arch.neck, Arch.head.
3. loss which has parameters, such as CenterLoss.
"""
for optim_item in optim_config:
# optim_cfg = {optim_name: {'scope': xxx, **optim_cfg}}
# optim_cfg = {optim_name: {scope: xxx, **optim_cfg}}
# step1 build lr
optim_name = list(optim_item.keys())[0] # get optim_name
optim_scope_list = optim_item[optim_name].pop('scope').split(
' ') # get optim_scope list
optim_scope = optim_item[optim_name].pop('scope') # get optim_scope
optim_cfg = optim_item[optim_name] # get optim_cfg
lr = build_lr_scheduler(optim_cfg.pop('lr'), epochs, step_each_epoch)
logger.info("build lr ({}) for scope ({}) success..".format(
lr.__class__.__name__, optim_scope_list))
logger.debug("build lr ({}) for scope ({}) success..".format(
lr, optim_scope))
# step2 build regularization
if 'regularizer' in optim_cfg and optim_cfg['regularizer'] is not None:
if 'weight_decay' in optim_cfg:
......@@ -84,14 +83,12 @@ def build_optimizer(config, epochs, step_each_epoch, model_list=None):
reg_name = reg_config.pop('name') + 'Decay'
reg = getattr(paddle.regularizer, reg_name)(**reg_config)
optim_cfg["weight_decay"] = reg
logger.info("build regularizer ({}) for scope ({}) success..".
format(reg.__class__.__name__, optim_scope_list))
logger.debug("build regularizer ({}) for scope ({}) success..".
format(reg, optim_scope))
# step3 build optimizer
if 'clip_norm' in optim_cfg:
clip_norm = optim_cfg.pop('clip_norm')
grad_clip = paddle.nn.ClipGradByNorm(clip_norm=clip_norm)
logger.info("build gradclip ({}) for scope ({}) success..".format(
grad_clip.__class__.__name__, optim_scope_list))
else:
grad_clip = None
optim_model = []
......@@ -104,34 +101,30 @@ def build_optimizer(config, epochs, step_each_epoch, model_list=None):
return optim, lr
# for dynamic graph
for scope in optim_scope_list:
if scope == "all":
optim_model += model_list
elif scope == "model":
optim_model += [model_list[0], ]
elif scope in ["backbone", "neck", "head"]:
optim_model += [getattr(model_list[0], scope, None), ]
elif scope == "loss":
optim_model += [model_list[1], ]
for i in range(len(model_list)):
if len(model_list[i].parameters()) == 0:
continue
if optim_scope == "all":
# optimizer for all
optim_model.append(model_list[i])
else:
optim_model += [
model_list[1].loss_func[i]
for i in range(len(model_list[1].loss_func))
if model_list[1].loss_func[i].__class__.__name__ == scope
]
# remove invalid items
optim_model = [
optim_model[i] for i in range(len(optim_model))
if (optim_model[i] is not None
) and (len(optim_model[i].parameters()) > 0)
]
assert len(optim_model) > 0, \
f"optim_model is empty for optim_scope({optim_scope_list})"
if optim_scope.endswith("Loss"):
# optimizer for loss
for m in model_list[i].sublayers(True):
if m.__class__.__name__ == optim_scope:
optim_model.append(m)
else:
                # optimizer for a module in the model, such as backbone, neck, head...
if hasattr(model_list[i], optim_scope):
optim_model.append(getattr(model_list[i], optim_scope))
assert len(optim_model) == 1, \
"Invalid optim model for optim scope({}), number of optim_model={}".format(optim_scope, len(optim_model))
optim = getattr(optimizer, optim_name)(
learning_rate=lr, grad_clip=grad_clip,
**optim_cfg)(model_list=optim_model)
logger.info("build optimizer ({}) for scope ({}) success..".format(
optim.__class__.__name__, optim_scope_list))
logger.debug("build optimizer ({}) for scope ({}) success..".format(
optim, optim_scope))
optim_list.append(optim)
lr_list.append(lr)
return optim_list, lr_list
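With the reworked loop, each optimizer entry carries a single `scope` string rather than a space-separated list: `all`, a sub-module name under Arch (e.g. `backbone`, `neck`, `head`), or a loss class name ending in `Loss`. A hypothetical two-optimizer config in Python-dict form (field names mirror the YAML configs above; the lr and regularizer values are placeholders, not from this commit):

```python
# Hypothetical input to build_optimizer(); each scope must resolve to exactly one
# module, otherwise the assert in the scope-resolution loop above fires.
optim_config = [
    {"Momentum": {
        "scope": "backbone",    # resolved via getattr(model, "backbone")
        "lr": {"name": "Cosine", "learning_rate": 0.04},
        "momentum": 0.9,
        "regularizer": {"name": "L2", "coeff": 0.0001},
    }},
    {"SGD": {
        "scope": "CenterLoss",  # ends with "Loss" -> matched against loss sublayers
        "lr": {"name": "Constant", "learning_rate": 0.5},
    }},
]
```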