Commit af90cd7c authored by: H HydrogenSulfate

update center loss config and related code

Parent 9de22673
......@@ -6,8 +6,8 @@ class BNNeck(paddle.nn.Layer):
super(BNNeck, self).__init__()
self.num_filters = num_filters
self.bn = paddle.nn.BatchNorm1D(
self.num_filters)
self.bn = paddle.nn.BatchNorm1D(self.num_filters)
# TODO: freeze bn.bias
# if not trainable:
# self.bn.bias.trainable = False
......
......@@ -25,10 +25,14 @@ class FC(nn.Layer):
super(FC, self).__init__()
self.embedding_size = embedding_size
self.class_num = class_num
# TODO: hard code for initializer
weight_attr = paddle.ParamAttr(
initializer=paddle.nn.initializer.XavierNormal())
initializer=paddle.nn.initializer.Normal(std=0.001))
self.fc = paddle.nn.Linear(
self.embedding_size, self.class_num, weight_attr=weight_attr, bias_attr=bias_attr)
self.embedding_size,
self.class_num,
weight_attr=weight_attr,
bias_attr=bias_attr)
def forward(self, input, label=None):
out = self.fc(input)
......
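The neck and head changes above are small but easy to misread in the interleaved diff: BNNeck now builds its BatchNorm1D in a single line, and the FC head switches its weight initializer from XavierNormal to Normal(std=0.001). A minimal sketch of the resulting neck plus head pair, using the 2048-dim features, 751 classes, and bias_attr: false from this commit's config (everything else is illustrative):

```python
import paddle
import paddle.nn as nn

# BNNeck: a BatchNorm1D over the backbone feature.
# FC head: Linear with Normal(std=0.001) weight init and no bias.
feat = paddle.randn([64, 2048])                      # a batch of backbone features
bnneck = nn.BatchNorm1D(2048)
weight_attr = paddle.ParamAttr(initializer=nn.initializer.Normal(std=0.001))
fc = nn.Linear(2048, 751, weight_attr=weight_attr, bias_attr=False)
logits = fc(bnneck(feat))                            # shape [64, 751]
```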
......@@ -8,12 +8,13 @@ Global:
eval_during_train: True
eval_interval: 10
epochs: 120
print_batch_step: 10
print_batch_step: 20
use_visualdl: False
# used for static mode and model export
image_shape: [3, 256, 128]
save_inference_dir: "./inference"
eval_mode: "retrieval"
feat_from: "neck" # 'backbone' or 'neck'
# model architecture
Arch:
......@@ -29,6 +30,7 @@ Arch:
Neck:
name: BNNeck
num_filters: 2048
# trainable: False # TODO: freeze bn.bias
Head:
name: "FC"
embedding_size: 2048
......
# global configs
Global:
checkpoints: null
pretrained_model: null
output_dir: "./output/"
device: "gpu"
save_interval: 10
eval_during_train: True
eval_interval: 10
epochs: 120
print_batch_step: 20
use_visualdl: False
# used for static mode and model export
image_shape: [3, 256, 128]
save_inference_dir: "./inference"
eval_mode: "retrieval"
feat_from: "neck" # 'backbone' or 'neck'
# model architecture
Arch:
name: "RecModel"
infer_output_key: "features"
infer_add_softmax: False
Backbone:
name: "ResNet50_last_stage_stride1"
pretrained: True
stem_act: null
BackboneStopLayer:
name: "flatten"
Neck:
name: BNNeck
num_filters: 2048
# trainable: False # TODO: freeze bn.bias
Head:
name: "FC"
embedding_size: 2048
class_num: 751
bias_attr: false
# loss function config for training/eval process
Loss:
Train:
- CELoss:
weight: 1.0
epsilon: 0.1
- TripletLossV2:
weight: 1.0
margin: 0.3
normalize_feature: false
- CenterLoss:
weight: 0.0005
num_classes: 751
feat_dim: 2048
Eval:
- CELoss:
weight: 1.0
Optimizer:
model:
name: Adam
lr:
name: Piecewise
decay_epochs: [30, 60]
values: [0.00035, 0.000035, 0.0000035]
warmup_epoch: 10
warmup_start_lr: 0.0000035
regularizer:
name: 'L2'
coeff: 0.0005
loss:
name: SGD
lr:
name: Constant
learning_rate: 0.5
# data loader for train and eval
DataLoader:
Train:
dataset:
name: "VeriWild"
image_root: "./dataset/market1501"
cls_label_path: "./dataset/market1501/bounding_box_train.txt"
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
size: [128, 256]
- RandFlipImage:
flip_code: 1
- Pad:
padding: 10
- RandCropImage:
size: [128, 256]
scale: [ 0.8022, 0.8022 ]
ratio: [ 0.5, 0.5 ]
- NormalizeImage:
scale: 0.00392157
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
- RandomErasing:
EPSILON: 0.5
sl: 0.02
sh: 0.4
r1: 0.3
mean: [0.4914, 0.4822, 0.4465]
sampler:
name: DistributedRandomIdentitySampler
batch_size: 64
num_instances: 4
drop_last: True
shuffle: True
loader:
num_workers: 6
use_shared_memory: True
Eval:
Query:
dataset:
name: "VeriWild"
image_root: "./dataset/market1501"
cls_label_path: "./dataset/market1501/query.txt"
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
size: [128, 256]
- NormalizeImage:
scale: 0.00392157
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
sampler:
name: DistributedBatchSampler
batch_size: 128
drop_last: False
shuffle: False
loader:
num_workers: 6
use_shared_memory: True
Gallery:
dataset:
name: "VeriWild"
image_root: "./dataset/market1501"
cls_label_path: "./dataset/market1501/bounding_box_test.txt"
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
size: [128, 256]
- NormalizeImage:
scale: 0.00392157
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
sampler:
name: DistributedBatchSampler
batch_size: 128
drop_last: False
shuffle: False
loader:
num_workers: 6
use_shared_memory: True
Metric:
Eval:
- Recallk:
topk: [1, 5]
- mAP: {}
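The key change in this config is the Optimizer section: it now holds two sub-configs, `model` and `loss`, so the CenterLoss centers get their own plain SGD with a constant learning rate of 0.5 while the network keeps Adam with a piecewise schedule and L2 regularization. A simplified, standalone sketch of what those two entries translate to (warmup omitted; the Linear and Embedding layers only stand in for RecModel and the CenterLoss centers):

```python
import paddle

model = paddle.nn.Linear(2048, 751)            # stand-in for RecModel
centers = paddle.nn.Embedding(751, 2048)       # stand-in for CenterLoss centers

# Optimizer.model -> Adam + Piecewise lr + L2 regularizer over model parameters
lr_model = paddle.optimizer.lr.PiecewiseDecay(
    boundaries=[30, 60], values=[0.00035, 0.000035, 0.0000035])
opt_model = paddle.optimizer.Adam(
    learning_rate=lr_model,
    weight_decay=paddle.regularizer.L2Decay(0.0005),
    parameters=model.parameters())

# Optimizer.loss -> SGD with a constant lr of 0.5 over the loss parameters only
opt_loss = paddle.optimizer.SGD(
    learning_rate=0.5, parameters=centers.parameters())
```

The overall training objective is unchanged in form: the weighted sum 1.0 * CELoss + 1.0 * TripletLossV2 + 0.0005 * CenterLoss from the Loss.Train section.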
......@@ -223,7 +223,12 @@ class Engine(object):
if self.mode == 'train':
self.optimizer, self.lr_sch = build_optimizer(
self.config["Optimizer"], self.config["Global"]["epochs"],
len(self.train_dataloader), [self.model])
len(self.train_dataloader), [
self.model, * [
m for m in self.train_loss_func.loss_func
if len(m.parameters()) > 0
]
])
# for amp training
if self.amp:
......@@ -251,6 +256,11 @@ class Engine(object):
if self.config["Global"]["distributed"]:
dist.init_parallel_env()
self.model = paddle.DataParallel(self.model)
# NOTE: parallelize loss which has parameters, such as CenterLoss
for i in range(len(self.train_loss_func.loss_func)):
if len(self.train_loss_func.loss_func[i].parameters()) > 0:
self.train_loss_func.loss_func[i] = paddle.DataParallel(
self.train_loss_func.loss_func[i])
# build postprocess for infer
if self.mode == 'infer':
......
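The two engine.py hunks above implement the same idea from both ends: any loss layer that owns parameters (here, CenterLoss and its centers) joins the model in the optimizer list, and gets wrapped in paddle.DataParallel under distributed training. A small self-contained helper that captures the filtering rule (the name split_trainable_losses is illustrative, not part of the repo):

```python
from typing import List, Tuple
import paddle


def split_trainable_losses(loss_layers: List[paddle.nn.Layer]
                           ) -> Tuple[List[paddle.nn.Layer], List[paddle.nn.Layer]]:
    """Split loss layers into those that own parameters (and therefore need an
    optimizer entry and DataParallel wrapping, e.g. CenterLoss) and purely
    functional ones (e.g. CELoss, TripletLossV2)."""
    with_params = [m for m in loss_layers if len(m.parameters()) > 0]
    without_params = [m for m in loss_layers if len(m.parameters()) == 0]
    return with_params, without_params
```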
......@@ -125,7 +125,14 @@ def cal_feature(engine, name='gallery'):
out = engine.model(batch[0], batch[1])
if "Student" in out:
out = out["Student"]
batch_feas = out["backbone"]
# get features
if engine.config["Global"].get("feat_from", 'backbone') == 'backbone':
# use backbone's output as features
batch_feas = out["backbone"]
else:
# use neck's output as features
batch_feas = out["neck"]
# do norm
if engine.config["Global"].get("feature_normalize", True):
......
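A compact restatement of the new feature-selection rule (a sketch only; `out` is the RecModel output dict from the hunk above): with `feat_from: "neck"` the retrieval features come from the BNNeck output, otherwise from the raw backbone output.

```python
def select_features(out: dict, feat_from: str = "backbone"):
    # "backbone" (default) -> pre-BNNeck features, "neck" -> post-BNNeck features
    return out["backbone"] if feat_from == "backbone" else out["neck"]
```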
......@@ -54,16 +54,24 @@ def train_epoch(engine, epoch_id, print_batch_step):
out = forward(engine, batch)
loss_dict = engine.train_loss_func(out, batch[1])
# step opt and lr
# step opt
if engine.amp:
scaled = engine.scaler.scale(loss_dict["loss"])
scaled.backward()
engine.scaler.minimize(engine.optimizer, scaled)
for i in range(len(engine.optimizer)):
engine.scaler.minimize(engine.optimizer[i], scaled)
else:
loss_dict["loss"].backward()
engine.optimizer.step()
engine.optimizer.clear_grad()
engine.lr_sch.step()
for i in range(len(engine.optimizer)):
engine.optimizer[i].step()
# clear grad
for i in range(len(engine.optimizer)):
engine.optimizer[i].clear_grad()
# step lr
for i in range(len(engine.lr_sch)):
engine.lr_sch[i].step()
# below code just for logging
# update metric_for_logger
......
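Since build_optimizer now returns lists, the training step loops over every optimizer (and every LR scheduler) instead of stepping a single one. A toy dygraph sketch of that pattern; the layers and shapes are illustrative only:

```python
import paddle

model = paddle.nn.Linear(4, 2)
centers = paddle.nn.Embedding(2, 4)
optimizers = [
    paddle.optimizer.Adam(parameters=model.parameters()),
    paddle.optimizer.SGD(learning_rate=0.5, parameters=centers.parameters()),
]

x = paddle.randn([8, 4])
loss = model(x).mean() + centers.weight.mean()
loss.backward()                       # one backward pass for the combined loss
for opt in optimizers:                # then step every optimizer ...
    opt.step()
for opt in optimizers:                # ... and clear every gradient
    opt.clear_grad()
```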
......@@ -38,7 +38,12 @@ def update_loss(trainer, loss_dict, batch_size):
def log_info(trainer, batch_size, epoch_id, iter_id):
lr_msg = "lr: {:.5f}".format(trainer.lr_sch.get_lr())
if len(trainer.lr_sch) <= 1:
lr_msg = "lr: {:.8f}".format(trainer.lr_sch[0].get_lr())
else:
lr_msg = "lr_model: {:.8f}".format(trainer.lr_sch[0].get_lr())
lr_msg += ", lr_loss: {:.8f}".format(trainer.lr_sch[1].get_lr())
metric_msg = ", ".join([
"{}: {:.5f}".format(key, trainer.output_info[key].avg)
for key in trainer.output_info
......@@ -58,12 +63,23 @@ def log_info(trainer, batch_size, epoch_id, iter_id):
epoch_id, trainer.config["Global"]["epochs"], iter_id,
len(trainer.train_dataloader), lr_msg, metric_msg, time_msg, ips_msg,
eta_msg))
logger.scaler(
name="lr",
value=trainer.lr_sch.get_lr(),
step=trainer.global_step,
writer=trainer.vdl_writer)
if len(trainer.lr_sch) <= 1:
logger.scaler(
name="lr",
value=trainer.lr_sch[0].get_lr(),
step=trainer.global_step,
writer=trainer.vdl_writer)
else:
logger.scaler(
name="lr_model",
value=trainer.lr_sch[0].get_lr(),
step=trainer.global_step,
writer=trainer.vdl_writer)
logger.scaler(
name="lr_loss",
value=trainer.lr_sch[1].get_lr(),
step=trainer.global_step,
writer=trainer.vdl_writer)
for key in trainer.output_info:
logger.scaler(
name="train_{}".format(key),
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from typing import Dict
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import Tensor
class CenterLoss(nn.Layer):
def __init__(self, num_classes=5013, feat_dim=2048):
"""Center loss class
Args:
num_classes (int): number of classes.
feat_dim (int): number of feature dimensions.
"""
def __init__(self, num_classes: int, feat_dim: int):
super(CenterLoss, self).__init__()
self.num_classes = num_classes
self.feat_dim = feat_dim
self.centers = paddle.randn(
shape=[self.num_classes, self.feat_dim]).astype(
"float64") #random center
random_init_centers = paddle.randn(
shape=[self.num_classes, self.feat_dim])
self.centers = self.create_parameter(
shape=(self.num_classes, self.feat_dim),
default_initializer=nn.initializer.Assign(random_init_centers))
self.add_parameter("centers", self.centers)
def __call__(self, input, target):
"""
inputs: network output: {"features: xxx", "logits": xxxx}
target: image label
def __call__(self, input: Dict[str, Tensor],
target: Tensor) -> Dict[str, Tensor]:
"""compute center loss.
Args:
input (Dict[str, Tensor]): {'features': (batch_size, feature_dim), ...}.
target (Tensor): ground truth label with shape (batch_size, ).
Returns:
Dict[str, Tensor]: {'CenterLoss': loss}.
"""
feats = input["features"]
feats = input['backbone']
labels = target
batch_size = feats.shape[0]
#calc feat * feat
# squeeze labels to shape (batch_size, )
if labels.ndim >= 2 and labels.shape[-1] == 1:
labels = paddle.squeeze(labels, axis=[-1])
batch_size = feats.shape[0]
# calc feat * feat
dist1 = paddle.sum(paddle.square(feats), axis=1, keepdim=True)
dist1 = paddle.expand(dist1, [batch_size, self.num_classes])
#dist2 of centers
# dist2 of centers
dist2 = paddle.sum(paddle.square(self.centers), axis=1,
keepdim=True) #num_classes
dist2 = paddle.expand(dist2,
[self.num_classes, batch_size]).astype("float64")
keepdim=True) # num_classes
dist2 = paddle.expand(dist2, [self.num_classes, batch_size])
dist2 = paddle.transpose(dist2, [1, 0])
#first x * x + y * y
# first x * x + y * y
distmat = paddle.add(dist1, dist2)
tmp = paddle.matmul(feats, paddle.transpose(self.centers, [1, 0]))
distmat = distmat - 2.0 * tmp
#generate the mask
classes = paddle.arange(self.num_classes).astype("int64")
# generate the mask
classes = paddle.arange(self.num_classes)
labels = paddle.expand(
paddle.unsqueeze(labels, 1), (batch_size, self.num_classes))
mask = paddle.equal(
paddle.expand(classes, [batch_size, self.num_classes]),
labels).astype("float64") #get mask
labels).astype("float32") # get mask
dist = paddle.multiply(distmat, mask)
loss = paddle.sum(paddle.clip(dist, min=1e-12, max=1e+12)) / batch_size
# return loss
return {'CenterLoss': loss}
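The rewritten CenterLoss registers its centers via create_parameter/add_parameter, so they appear in parameters() and can be trained by the dedicated SGD optimizer. The distance term is the usual expansion ||x - c||^2 = ||x||^2 + ||c||^2 - 2 * x @ c^T, masked so each sample only pulls on its own class center. A minimal usage sketch, assuming the CenterLoss class above is in scope:

```python
import paddle

num_classes, feat_dim, batch_size = 751, 2048, 64
center_loss = CenterLoss(num_classes=num_classes, feat_dim=feat_dim)

feats = paddle.randn([batch_size, feat_dim])                    # e.g. RecModel "backbone" output
labels = paddle.randint(0, num_classes, shape=[batch_size, 1])  # (N, 1) labels are squeezed inside

loss_dict = center_loss({"backbone": feats}, labels)
print(loss_dict["CenterLoss"])        # scalar; shrinks as feats move toward their class centers
print(len(center_loss.parameters()))  # 1, i.e. the trainable centers
```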
......@@ -44,29 +44,97 @@ def build_lr_scheduler(lr_config, epochs, step_each_epoch):
# model_list is None in static graph
def build_optimizer(config, epochs, step_each_epoch, model_list=None):
config = copy.deepcopy(config)
# step1 build lr
lr = build_lr_scheduler(config.pop('lr'), epochs, step_each_epoch)
logger.debug("build lr ({}) success..".format(lr))
# step2 build regularization
if 'regularizer' in config and config['regularizer'] is not None:
if 'weight_decay' in config:
logger.warning(
"ConfigError: Only one of regularizer and weight_decay can be set in Optimizer Config. \"weight_decay\" has been ignored."
)
reg_config = config.pop('regularizer')
reg_name = reg_config.pop('name') + 'Decay'
reg = getattr(paddle.regularizer, reg_name)(**reg_config)
config["weight_decay"] = reg
logger.debug("build regularizer ({}) success..".format(reg))
# step3 build optimizer
optim_name = config.pop('name')
if 'clip_norm' in config:
clip_norm = config.pop('clip_norm')
grad_clip = paddle.nn.ClipGradByNorm(clip_norm=clip_norm)
if 'name' in config:
# NOTE: build optimizer and lr for model only.
# step1 build lr
lr = build_lr_scheduler(config.pop('lr'), epochs, step_each_epoch)
logger.debug("build model's lr ({}) success..".format(lr))
# step2 build regularization
if 'regularizer' in config and config['regularizer'] is not None:
if 'weight_decay' in config:
logger.warning(
"ConfigError: Only one of regularizer and weight_decay can be set in Optimizer Config. \"weight_decay\" has been ignored."
)
reg_config = config.pop('regularizer')
reg_name = reg_config.pop('name') + 'Decay'
reg = getattr(paddle.regularizer, reg_name)(**reg_config)
config["weight_decay"] = reg
logger.debug("build model's regularizer ({}) success..".format(
reg))
# step3 build optimizer
optim_name = config.pop('name')
if 'clip_norm' in config:
clip_norm = config.pop('clip_norm')
grad_clip = paddle.nn.ClipGradByNorm(clip_norm=clip_norm)
else:
grad_clip = None
optim = getattr(optimizer, optim_name)(
learning_rate=lr, grad_clip=grad_clip,
**config)(model_list=model_list[0:1])
optim = [optim, ]
lr = [lr, ]
logger.debug("build model's optimizer ({}) success..".format(optim))
else:
grad_clip = None
optim = getattr(optimizer, optim_name)(learning_rate=lr,
grad_clip=grad_clip,
**config)(model_list=model_list)
logger.debug("build optimizer ({}) success..".format(optim))
# NOTE: build optimizer and lr for model and loss.
config_model = config['model']
config_loss = config['loss']
# step1 build lr
lr_model = build_lr_scheduler(
config_model.pop('lr'), epochs, step_each_epoch)
logger.debug("build model's lr ({}) success..".format(lr_model))
# step2 build regularization
if 'regularizer' in config_model and config_model[
'regularizer'] is not None:
if 'weight_decay' in config_model:
logger.warning(
"ConfigError: Only one of regularizer and weight_decay can be set in Optimizer Config. \"weight_decay\" has been ignored."
)
reg_config = config_model.pop('regularizer')
reg_name = reg_config.pop('name') + 'Decay'
reg_model = getattr(paddle.regularizer, reg_name)(**reg_config)
config_model["weight_decay"] = reg_model
logger.debug("build model's regularizer ({}) success..".format(
reg_model))
# step3 build optimizer
optim_name = config_model.pop('name')
if 'clip_norm' in config_model:
clip_norm = config_model.pop('clip_norm')
grad_clip_model = paddle.nn.ClipGradByNorm(clip_norm=clip_norm)
else:
grad_clip_model = None
optim_model = getattr(optimizer, optim_name)(
learning_rate=lr_model, grad_clip=grad_clip_model,
**config_model)(model_list=model_list[0:1])
# step4 build lr for loss
lr_loss = build_lr_scheduler(
config_loss.pop('lr'), epochs, step_each_epoch)
logger.debug("build loss's lr ({}) success..".format(lr_loss))
# step5 build regularization for loss
if 'regularizer' in config_loss and config_loss[
'regularizer'] is not None:
if 'weight_decay' in config_loss:
logger.warning(
"ConfigError: Only one of regularizer and weight_decay can be set in Optimizer Config. \"weight_decay\" has been ignored."
)
reg_config = config_loss.pop('regularizer')
reg_name = reg_config.pop('name') + 'Decay'
reg_loss = getattr(paddle.regularizer, reg_name)(**reg_config)
config_loss["weight_decay"] = reg_loss
logger.debug("build loss's regularizer ({}) success..".format(
reg_loss))
# step6 build optimizer for loss
optim_name = config_loss.pop('name')
if 'clip_norm' in config_loss:
clip_norm = config_loss.pop('clip_norm')
grad_clip_loss = paddle.nn.ClipGradByNorm(clip_norm=clip_norm)
else:
grad_clip_loss = None
optim_loss = getattr(optimizer, optim_name)(
learning_rate=lr_loss, grad_clip=grad_clip_loss,
**config_loss)(model_list=model_list[1:2])
optim = [optim_model, optim_loss]
lr = [lr_model, lr_loss]
logger.debug("build loss's optimizer ({}) success..".format(optim))
return optim, lr
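build_optimizer now accepts two config shapes and always returns lists, which is why the engine, training loop, and checkpoint code all iterate over optim and lr. A sketch of the two shapes, with values taken from this commit's YAML (the dict literals below are illustrative; in the repo they come from the parsed config):

```python
# Shape 1: a top-level "name" -> a single optimizer for the model only,
# still wrapped as one-element lists ([optim], [lr]).
single_opt_cfg = {
    "name": "Adam",
    "lr": {"name": "Piecewise", "decay_epochs": [30, 60],
           "values": [0.00035, 0.000035, 0.0000035]},
    "regularizer": {"name": "L2", "coeff": 0.0005},
}

# Shape 2: "model"/"loss" sub-configs -> [optim_model, optim_loss], [lr_model, lr_loss];
# model_list[0:1] is the model, model_list[1:2] the parameterized loss (CenterLoss).
dual_opt_cfg = {
    "model": {"name": "Adam",
              "lr": {"name": "Piecewise", "decay_epochs": [30, 60],
                     "values": [0.00035, 0.000035, 0.0000035],
                     "warmup_epoch": 10, "warmup_start_lr": 0.0000035},
              "regularizer": {"name": "L2", "coeff": 0.0005}},
    "loss": {"name": "SGD",
             "lr": {"name": "Constant", "learning_rate": 0.5}},
}
```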
......@@ -75,6 +75,23 @@ class Linear(object):
return learning_rate
class Constant(LRScheduler):
"""
Constant learning rate
Args:
lr (float): The initial learning rate. It is a python float number.
last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
"""
def __init__(self, learning_rate, last_epoch=-1, **kwargs):
self.learning_rate = learning_rate
self.last_epoch = last_epoch
super().__init__()
def get_lr(self):
return self.learning_rate
class Cosine(object):
"""
Cosine learning rate decay
......
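The new Constant scheduler exists so the loss optimizer can be configured through the same lr machinery as the model optimizer; get_lr simply returns the configured value forever. A quick usage sketch, assuming the Constant class above is in scope:

```python
sched = Constant(learning_rate=0.5)
for _ in range(3):
    sched.step()
    print(sched.get_lr())   # 0.5 every step
```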
......@@ -48,7 +48,7 @@ def _mkdir_if_not_exist(path):
def load_dygraph_pretrain(model, path=None):
if not (os.path.isdir(path) or os.path.exists(path + '.pdparams')):
raise ValueError("Model pretrain path {} does not "
"exists.".format(path))
"exists.".format(path + '.pdparams'))
param_state_dict = paddle.load(path + ".pdparams")
model.set_dict(param_state_dict)
return
......@@ -99,7 +99,8 @@ def init_model(config, net, optimizer=None):
opti_dict = paddle.load(checkpoints + ".pdopt")
metric_dict = paddle.load(checkpoints + ".pdstates")
net.set_dict(para_dict)
optimizer.set_state_dict(opti_dict)
for i in range(len(optimizer)):
optimizer[i].set_state_dict(opti_dict)
logger.info("Finish load checkpoints from {}".format(checkpoints))
return metric_dict
......@@ -131,6 +132,6 @@ def save_model(net,
model_path = os.path.join(model_path, prefix)
paddle.save(net.state_dict(), model_path + ".pdparams")
paddle.save(optimizer.state_dict(), model_path + ".pdopt")
paddle.save([opt.state_dict() for opt in optimizer], model_path + ".pdopt")
paddle.save(metric_info, model_path + ".pdstates")
logger.info("Already save model in {}".format(model_path))