Commit 674447f6 authored by HydrogenSulfate

refine code

Parent 97e8abc3
@@ -48,6 +48,7 @@ Loss:
        weight: 1.0
        margin: 0.3
        normalize_feature: false
+       feat_from: "backbone"
  Eval:
    - CELoss:
        weight: 1.0
......
@@ -61,6 +61,7 @@ Loss:
        weight: 1.0
        margin: 0.3
        normalize_feature: false
+       feat_from: "backbone"
  Eval:
    - CELoss:
        weight: 1.0
......
@@ -40,7 +40,7 @@ Arch:
      initializer:
        name: Constant
        value: 0.0
-     learning_rate: 1.0e-20 # TODO: Temporarily set lr small enough to freeze the bias
+     learning_rate: 1.0e-20 # NOTE: Temporarily set lr small enough to freeze the bias
  Head:
    name: "FC"
    embedding_size: *feat_dim
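The bias_attr change above relies on Paddle's per-parameter learning-rate multiplier: with a near-zero multiplier the bias stays at its constant initialization, which is the intended "freeze" effect. A minimal sketch in plain Paddle (the layer type and feature size are illustrative, not taken from this commit):

import paddle
import paddle.nn as nn

# bias initialized to 0 and effectively frozen by a tiny per-parameter lr multiplier
bias_attr = paddle.ParamAttr(
    initializer=nn.initializer.Constant(value=0.0),
    learning_rate=1.0e-20)

bn = nn.BatchNorm1D(512, bias_attr=bias_attr)  # e.g. a BNNeck-style layer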
@@ -57,14 +57,16 @@ Loss:
    - CELoss:
        weight: 1.0
        epsilon: 0.1
-   - TripletLossV3:
+   - TripletLossV2:
        weight: 1.0
        margin: 0.3
        normalize_feature: false
+       feat_from: "backbone"
    - CenterLoss:
        weight: 0.0005
        num_classes: *class_num
        feat_dim: *feat_dim
+       feat_from: "backbone"
  Eval:
    - CELoss:
        weight: 1.0
......
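Each entry under Loss.Train above is instantiated by class name: weight is consumed by the combined-loss wrapper and the remaining keys are forwarded to the loss constructor, so the new feat_from key becomes a constructor argument. A minimal sketch of that mapping, assuming the wiring follows ppcls/loss/__init__.py (the dict literal and variable names are illustrative):

from ppcls.loss.triplet import TripletLossV2  # assumed module path

# config entry as a plain dict, mirroring the YAML above
item = {"TripletLossV2": {"weight": 1.0, "margin": 0.3,
                          "normalize_feature": False, "feat_from": "backbone"}}

name = list(item.keys())[0]
params = dict(item[name])
weight = params.pop("weight")      # kept by the wrapper for the weighted sum
loss_fn = TripletLossV2(**params)  # feat_from ends up here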
@@ -28,12 +28,17 @@ class CenterLoss(nn.Layer):
    Args:
        num_classes (int): number of classes.
        feat_dim (int): number of feature dimensions.
+       feat_from (str): features from backbone or neck
    """

-    def __init__(self, num_classes: int, feat_dim: int):
+    def __init__(self,
+                 num_classes: int,
+                 feat_dim: int,
+                 feat_from: str='backbone'):
        super(CenterLoss, self).__init__()
        self.num_classes = num_classes
        self.feat_dim = feat_dim
+       self.feat_from = feat_from
        random_init_centers = paddle.randn(
            shape=[self.num_classes, self.feat_dim])
        self.centers = self.create_parameter(
@@ -52,7 +57,7 @@ class CenterLoss(nn.Layer):
        Returns:
            Dict[str, paddle.Tensor]: {'CenterLoss': loss}.
        """
-       feats = input['backbone']
+       feats = input[self.feat_from]
        labels = target

        # squeeze labels to shape (batch_size, )
......
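A short usage sketch of the updated CenterLoss, assuming the dict-style input shown above; the import path and tensor sizes are illustrative:

import paddle
from ppcls.loss.centerloss import CenterLoss  # assumed module path

num_classes, feat_dim, batch_size = 751, 512, 16

loss_fn = CenterLoss(num_classes=num_classes,
                     feat_dim=feat_dim,
                     feat_from="backbone")  # read features from input["backbone"]

feats = paddle.randn([batch_size, feat_dim])
labels = paddle.randint(0, num_classes, shape=[batch_size])

# forward looks up input[self.feat_from] and returns {"CenterLoss": loss}
out = loss_fn({"backbone": feats}, labels)
print(out["CenterLoss"])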
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from typing import Tuple
import paddle
import paddle.nn as nn
@@ -13,9 +26,13 @@ class TripletLossV2(nn.Layer):
        margin (float): margin for triplet.
    """

-    def __init__(self, margin=0.5, normalize_feature=True):
+    def __init__(self,
+                 margin=0.5,
+                 normalize_feature=True,
+                 feat_from='backbone'):
        super(TripletLossV2, self).__init__()
        self.margin = margin
+       self.feat_from = feat_from
        self.ranking_loss = paddle.nn.loss.MarginRankingLoss(margin=margin)
        self.normalize_feature = normalize_feature

@@ -25,7 +42,7 @@ class TripletLossV2(nn.Layer):
            inputs: feature matrix with shape (batch_size, feat_dim)
            target: ground truth labels with shape (num_classes)
        """
-       inputs = input["backbone"]
+       inputs = input[self.feat_from]

        if self.normalize_feature:
            inputs = 1. * inputs / (paddle.expand_as(
@@ -136,122 +153,3 @@ class TripletLoss(nn.Layer):
        y = paddle.ones_like(dist_an)
        loss = self.ranking_loss(dist_an, dist_ap, y)
        return {"TripletLoss": loss}
class TripletLossV3(nn.Layer):
"""Modified from Tong Xiao's open-reid (https://github.com/Cysu/open-reid).
Related Triplet Loss theory can be found in paper 'In Defense of the Triplet
Loss for Person Re-Identification'."""
def __init__(self, margin=None, normalize_feature=False):
super(TripletLossV3, self).__init__()
self.normalize_feature = normalize_feature
self.margin = margin
if margin is not None:
self.ranking_loss = nn.MarginRankingLoss(margin=margin)
else:
self.ranking_loss = nn.SoftMarginLoss()
def forward(self, input, target):
global_feat = input["backbone"]
if self.normalize_feature:
global_feat = self._normalize(global_feat, axis=-1)
dist_mat = self._euclidean_dist(global_feat, global_feat)
dist_ap, dist_an = self._hard_example_mining(dist_mat, target)
y = paddle.ones_like(dist_an)
if self.margin is not None:
loss = self.ranking_loss(dist_an, dist_ap, y)
return {"TripletLossV3": loss}
def _normalize(self, x: paddle.Tensor, axis: int=-1) -> paddle.Tensor:
"""Normalizing to unit length along the specified dimension.
Args:
x (paddle.Tensor): (batch_size, feature_dim)
axis (int, optional): normalization dim. Defaults to -1.
Returns:
paddle.Tensor: (batch_size, feature_dim)
"""
x = 1. * x / (paddle.norm(
x, 2, axis, keepdim=True).expand_as(x) + 1e-12)
return x
def _euclidean_dist(self, x: paddle.Tensor,
y: paddle.Tensor) -> paddle.Tensor:
"""compute euclidean distance between two batched vectors
Args:
x (paddle.Tensor): (N, feature_dim)
y (paddle.Tensor): (M, feature_dim)
Returns:
paddle.Tensor: (N, M)
"""
m, n = x.shape[0], y.shape[0]
d = x.shape[1]
xx = paddle.pow(x, 2).sum(1, keepdim=True).expand([m, n])
yy = paddle.pow(y, 2).sum(1, keepdim=True).expand([n, m]).t()
dist = xx + yy
dist = dist.addmm(x, y.t(), alpha=-2, beta=1)
# dist = dist - 2*(x@y.t())
dist = dist.clip(min=1e-12).sqrt() # for numerical stability
return dist
def _hard_example_mining(
self,
dist_mat: paddle.Tensor,
labels: paddle.Tensor,
return_inds: bool=False) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""For each anchor, find the hardest positive and negative sample.
Args:
dist_mat (paddle.Tensor): pair wise distance between samples, [N, N]
labels (paddle.Tensor): labels, [N, ]
return_inds (bool, optional): whether to return the indices . Defaults to False.
Returns:
Tuple[paddle.Tensor, paddle.Tensor]: [(N, ), (N, )]
NOTE: Only consider the case in which all labels have same num of samples,
thus we can cope with all anchors in parallel.
"""
assert len(dist_mat.shape) == 2
assert dist_mat.shape[0] == dist_mat.shape[1]
N = dist_mat.shape[0]
# shape [N, N]
is_pos = labels.expand([N, N]).equal(labels.expand([N, N]).t())
is_neg = labels.expand([N, N]).not_equal(labels.expand([N, N]).t())
# `dist_ap` means distance(anchor, positive)
# both `dist_ap` and `relative_p_inds` with shape [N, 1]
dist_ap = paddle.max(dist_mat[is_pos].reshape([N, -1]),
1,
keepdim=True)
# `dist_an` means distance(anchor, negative)
# both `dist_an` and `relative_n_inds` with shape [N, 1]
dist_an = paddle.min(dist_mat[is_neg].reshape([N, -1]),
1,
keepdim=True)
# shape [N]
dist_ap = dist_ap.squeeze(1)
dist_an = dist_an.squeeze(1)
if return_inds:
# shape [N, N]
ind = (labels.new().resize_as_(labels)
.copy_(paddle.arange(0, N).long())
.unsqueeze(0).expand(N, N))
# shape [N, 1]
p_inds = paddle.gather(ind[is_pos].reshape([N, -1]), 1,
relative_p_inds.data)
n_inds = paddle.gather(ind[is_neg].reshape([N, -1]), 1,
relative_n_inds.data)
# shape [N]
p_inds = p_inds.squeeze(1)
n_inds = n_inds.squeeze(1)
return dist_ap, dist_an, p_inds, n_inds
return dist_ap, dist_an
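For the retained TripletLossV2, a similar hedged usage sketch with the new feat_from argument; the import path and sizes are illustrative, and the batch is arranged PK-style so every anchor has at least one positive:

import paddle
from ppcls.loss.triplet import TripletLossV2  # assumed module path

loss_fn = TripletLossV2(margin=0.3,
                        normalize_feature=False,
                        feat_from="backbone")

# 8 samples: 4 identities x 2 images each, so hard positive/negative mining is well defined
feats = paddle.randn([8, 256])
labels = paddle.to_tensor([0, 0, 1, 1, 2, 2, 3, 3])

out = loss_fn({"backbone": feats}, labels)
print(out["TripletLossV2"])  # the returned key is assumed to follow the class name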
@@ -46,7 +46,7 @@ def build_lr_scheduler(lr_config, epochs, step_each_epoch):
def build_optimizer(config, epochs, step_each_epoch, model_list=None):
    optim_config = copy.deepcopy(config)
    if isinstance(optim_config, dict):
-        # convert {'name': xxx, **optim_cfg} to [{'name': {'scope': xxx, **optim_cfg}}]
+        # convert {'name': xxx, **optim_cfg} to [{name: {scope: xxx, **optim_cfg}}]
        optim_name = optim_config.pop("name")
        optim_config: List[Dict[str, Dict]] = [{
            optim_name: {
@@ -60,20 +60,19 @@ def build_optimizer(config, epochs, step_each_epoch, model_list=None):
    """NOTE:
    Currently only support optim objects below.
    1. single optimizer config.
-    2. model(entire Arch), backbone, neck, head.
-    3. loss(entire Loss), specific loss listed in ppcls/loss/__init__.py.
+    2. next level under Arch, such as Arch.backbone, Arch.neck, Arch.head.
+    3. loss which has parameters, such as CenterLoss.
    """
    for optim_item in optim_config:
-        # optim_cfg = {optim_name: {'scope': xxx, **optim_cfg}}
+        # optim_cfg = {optim_name: {scope: xxx, **optim_cfg}}
        # step1 build lr
        optim_name = list(optim_item.keys())[0]  # get optim_name
-        optim_scope_list = optim_item[optim_name].pop('scope').split(
-            ' ')  # get optim_scope list
+        optim_scope = optim_item[optim_name].pop('scope')  # get optim_scope
        optim_cfg = optim_item[optim_name]  # get optim_cfg
        lr = build_lr_scheduler(optim_cfg.pop('lr'), epochs, step_each_epoch)
-        logger.info("build lr ({}) for scope ({}) success..".format(
-            lr.__class__.__name__, optim_scope_list))
+        logger.debug("build lr ({}) for scope ({}) success..".format(
+            lr, optim_scope))
        # step2 build regularization
        if 'regularizer' in optim_cfg and optim_cfg['regularizer'] is not None:
            if 'weight_decay' in optim_cfg:
@@ -84,14 +83,12 @@ def build_optimizer(config, epochs, step_each_epoch, model_list=None):
            reg_name = reg_config.pop('name') + 'Decay'
            reg = getattr(paddle.regularizer, reg_name)(**reg_config)
            optim_cfg["weight_decay"] = reg
-            logger.info("build regularizer ({}) for scope ({}) success..".
-                        format(reg.__class__.__name__, optim_scope_list))
+            logger.debug("build regularizer ({}) for scope ({}) success..".
+                         format(reg, optim_scope))
        # step3 build optimizer
        if 'clip_norm' in optim_cfg:
            clip_norm = optim_cfg.pop('clip_norm')
            grad_clip = paddle.nn.ClipGradByNorm(clip_norm=clip_norm)
-            logger.info("build gradclip ({}) for scope ({}) success..".format(
-                grad_clip.__class__.__name__, optim_scope_list))
        else:
            grad_clip = None
        optim_model = []
@@ -104,34 +101,30 @@ def build_optimizer(config, epochs, step_each_epoch, model_list=None):
            return optim, lr

        # for dynamic graph
-        for scope in optim_scope_list:
-            if scope == "all":
-                optim_model += model_list
-            elif scope == "model":
-                optim_model += [model_list[0], ]
-            elif scope in ["backbone", "neck", "head"]:
-                optim_model += [getattr(model_list[0], scope, None), ]
-            elif scope == "loss":
-                optim_model += [model_list[1], ]
+        for i in range(len(model_list)):
+            if len(model_list[i].parameters()) == 0:
+                continue
+            if optim_scope == "all":
+                # optimizer for all
+                optim_model.append(model_list[i])
            else:
-                optim_model += [
-                    model_list[1].loss_func[i]
-                    for i in range(len(model_list[1].loss_func))
-                    if model_list[1].loss_func[i].__class__.__name__ == scope
-                ]
-        # remove invalid items
-        optim_model = [
-            optim_model[i] for i in range(len(optim_model))
-            if (optim_model[i] is not None
-                ) and (len(optim_model[i].parameters()) > 0)
-        ]
-        assert len(optim_model) > 0, \
-            f"optim_model is empty for optim_scope({optim_scope_list})"
+                if optim_scope.endswith("Loss"):
+                    # optimizer for loss
+                    for m in model_list[i].sublayers(True):
+                        if m.__class__.__name__ == optim_scope:
+                            optim_model.append(m)
+                else:
+                    # optimizer for module in model, such as backbone, neck, head...
+                    if hasattr(model_list[i], optim_scope):
+                        optim_model.append(getattr(model_list[i], optim_scope))
+
+        assert len(optim_model) == 1, \
+            "Invalid optim model for optim scope({}), number of optim_model={}".format(
+                optim_scope, len(optim_model))

        optim = getattr(optimizer, optim_name)(
            learning_rate=lr, grad_clip=grad_clip,
            **optim_cfg)(model_list=optim_model)
-        logger.info("build optimizer ({}) for scope ({}) success..".format(
-            optim.__class__.__name__, optim_scope_list))
+        logger.debug("build optimizer ({}) for scope ({}) success..".format(
+            optim, optim_scope))
        optim_list.append(optim)
        lr_list.append(lr)
    return optim_list, lr_list
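To make the new scope semantics concrete, a hedged sketch of the per-scope optimizer config that the refactored loop consumes; optimizer and scheduler names and all hyper-parameters are illustrative and must match what is registered under ppcls.optimizer:

# one item per scope, each shaped {optim_name: {"scope": ..., "lr": ..., **optim_cfg}}
optim_config = [
    {"Momentum": {
        "scope": "backbone",    # resolved via getattr(model, "backbone")
        "lr": {"name": "Constant", "learning_rate": 0.01},
        "momentum": 0.9,
    }},
    {"SGD": {
        "scope": "CenterLoss",  # ends with "Loss": matched against sublayer class names
        "lr": {"name": "Constant", "learning_rate": 0.5},
    }},
]

# would be consumed as (assuming model_list = [arch_model, combined_loss]):
# optim_list, lr_list = build_optimizer(optim_config, epochs, step_each_epoch,
#                                       model_list=[arch_model, combined_loss])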