Commit dfd77498 authored by H HydrogenSulfate

refine hard code

Parent 41e1a86c
......@@ -276,6 +276,7 @@ class ResNet(TheseusLayer):
config,
stages_pattern,
version="vb",
stem_act="relu",
class_num=1000,
lr_mult_list=[1.0, 1.0, 1.0, 1.0, 1.0],
data_format="NCHW",
......@@ -309,13 +310,13 @@ class ResNet(TheseusLayer):
[[input_image_channel, 32, 3, 2], [32, 32, 3, 1], [32, 64, 3, 1]]
}
self.stem = nn.Sequential(* [
self.stem = nn.Sequential(*[
ConvBNLayer(
num_channels=in_c,
num_filters=out_c,
filter_size=k,
stride=s,
act="relu",
act=stem_act,
lr_mult=self.lr_mult_list[0],
data_format=data_format)
for in_c, out_c, k, s in self.stem_cfg[version]
......
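The hunk above threads a new `stem_act` argument through to the stem's `ConvBNLayer`s instead of the hard-coded `"relu"`. A minimal sketch of how a variant constructor could use it; the import path, the extra keyword arguments, and the chosen activation name are assumptions for illustration, not part of this diff:

```python
# Hedged sketch: a ResNet variant whose stem activation differs from the
# default "relu". Only config/stages_pattern/version/stem_act appear in the
# diff above; everything else here is illustrative.
from ppcls.arch.backbone.legendary_models.resnet import ResNet  # assumed path


def ResNet50_vd_custom_stem(config, stages_pattern, **kwargs):
    return ResNet(
        config=config,
        stages_pattern=stages_pattern,
        version="vd",
        stem_act="hardswish",  # was fixed to "relu" before this commit;
                               # assumes ConvBNLayer accepts this name
        **kwargs)
```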
......@@ -32,6 +32,7 @@ class BNNeck(nn.Layer):
epsilon=1e-05,
weight_attr=weight_attr,
bias_attr=bias_attr)
# TODO: set bnneck.bias learnable=False
self.flatten = nn.Flatten()
def forward(self, x):
......
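The new TODO notes that the BNNeck bias should eventually be made non-learnable (the usual ReID "BNNeck" trick of freezing the BatchNorm bias at zero). A hedged sketch of one way to do that in Paddle via a non-trainable `bias_attr`; the layer and attribute names below are illustrative rather than the repo's:

```python
import paddle
import paddle.nn as nn


class BNNeckSketch(nn.Layer):
    """Illustrative BNNeck with a frozen bias, per the TODO above."""

    def __init__(self, num_features):
        super().__init__()
        self.feat_bn = nn.BatchNorm1D(
            num_features,
            epsilon=1e-05,
            bias_attr=paddle.ParamAttr(
                initializer=nn.initializer.Constant(0.0),
                trainable=False))  # bias stays at 0 and receives no updates
        self.flatten = nn.Flatten()

    def forward(self, x):
        return self.feat_bn(self.flatten(x))
```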
......@@ -31,11 +31,11 @@ class FC(nn.Layer):
weight_attr = paddle.ParamAttr(
initializer=paddle.nn.initializer.XavierNormal())
if 'weight_attr' in kwargs:
weight_attr = get_param_attr_dict(kwargs['weight_attr'], None)
weight_attr = get_param_attr_dict(kwargs['weight_attr'])
bias_attr = None
if 'bias_attr' in kwargs:
bias_attr = get_param_attr_dict(kwargs['bias_attr'], None)
bias_attr = get_param_attr_dict(kwargs['bias_attr'])
self.fc = nn.Linear(
self.embedding_size,
......
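`get_param_attr_dict` is now called with just the config dict. As a rough illustration of what such a dict presumably maps to, a `weight_attr`/`bias_attr` entry from the YAML ends up as a `paddle.ParamAttr`; the exact keys the repo helper accepts are assumptions here:

```python
import paddle


# Hypothetical mapping from a YAML attr dict to paddle.ParamAttr; the real
# get_param_attr_dict in ppcls may support different or additional keys.
def param_attr_from_dict(cfg: dict) -> paddle.ParamAttr:
    return paddle.ParamAttr(
        learning_rate=cfg.get("learning_rate", 1.0),
        trainable=cfg.get("learnable", True))


bias_attr = param_attr_from_dict({"learning_rate": 1.0, "learnable": False})
```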
......@@ -61,31 +61,30 @@ Loss:
Optimizer:
- Adam:
scope: model
lr:
name: Piecewise
decay_epochs: [30, 60]
values: [0.00035, 0.000035, 0.0000035]
warmup_epoch: 10
warmup_start_lr: 0.0000035
warmup_epoch_by_epoch: True
regularizer:
name: 'L2'
coeff: 0.0005
scope: model
lr:
name: Piecewise
decay_epochs: [30, 60]
values: [0.00035, 0.000035, 0.0000035]
warmup_epoch: 10
warmup_start_lr: 0.0000035
warmup_epoch_by_epoch: True
regularizer:
name: 'L2'
coeff: 0.0005
- SGD:
sope: TripletLossV3
lr:
name: Constant
learning_rate: 0.5
scope: CenterLoss
lr:
name: Constant
learning_rate: 1000.0
# data loader for train and eval
DataLoader:
Train:
dataset:
name: "VeriWild"
image_root: "./dataset/market1501/bounding_box_train"
cls_label_path: "./dataset/market1501/bounding_box_train.txt"
relabel: True
name: "Market1501"
image_root: "./dataset/Market-1501-v15.09.15"
cls_label_path: "bounding_box_train"
transform_ops:
- DecodeImage:
to_rgb: True
......@@ -123,9 +122,9 @@ DataLoader:
Eval:
Query:
dataset:
name: "VeriWild"
image_root: "./dataset/market1501/query"
cls_label_path: "./dataset/market1501/query.txt"
name: "Market1501"
image_root: "./dataset/Market-1501-v15.09.15"
cls_label_path: "query"
transform_ops:
- DecodeImage:
to_rgb: True
......@@ -148,9 +147,9 @@ DataLoader:
Gallery:
dataset:
name: "VeriWild"
image_root: "./dataset/market1501/bounding_box_test"
cls_label_path: "./dataset/market1501/bounding_box_test.txt"
name: "Market1501"
image_root: "./dataset/Market-1501-v15.09.15"
cls_label_path: "bounding_box_test"
transform_ops:
- DecodeImage:
to_rgb: True
......
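After YAML parsing, the Optimizer section above reaches `build_optimizer` as a list with one dict per optimizer, each keyed by the optimizer name and carrying its own `scope`, `lr`, and (optionally) `regularizer`. The equivalent Python structure, shown for reference:

```python
# Python view of the Optimizer list above after YAML parsing: each item is
# {optim_name: {"scope": ..., **optim_cfg}}, which is the shape that
# build_optimizer() iterates over (see the diff further below).
optim_config = [
    {"Adam": {
        "scope": "model",
        "lr": {"name": "Piecewise",
               "decay_epochs": [30, 60],
               "values": [0.00035, 0.000035, 0.0000035],
               "warmup_epoch": 10,
               "warmup_start_lr": 0.0000035,
               "warmup_epoch_by_epoch": True},
        "regularizer": {"name": "L2", "coeff": 0.0005}}},
    {"SGD": {
        "scope": "CenterLoss",
        "lr": {"name": "Constant", "learning_rate": 1000.0}}},
]
```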
......@@ -29,6 +29,7 @@ from ppcls.data.preprocess.ops.operators import RandFlipImage
from ppcls.data.preprocess.ops.operators import NormalizeImage
from ppcls.data.preprocess.ops.operators import ToCHWImage
from ppcls.data.preprocess.ops.operators import AugMix
from ppcls.data.preprocess.ops.operators import Pad
from ppcls.data.preprocess.batch_ops.batch_operators import MixupOperator, CutmixOperator, OpSampler, FmixOperator
......
......@@ -25,6 +25,7 @@ import cv2
import numpy as np
from PIL import Image
from paddle.vision.transforms import ColorJitter as RawColorJitter
from paddle.vision.transforms import Pad
from .autoaugment import ImageNetPolicy
from .functional import augmentations
......@@ -81,6 +82,8 @@ class UnifiedResize(object):
self.resize_func = cv2.resize
def __call__(self, src, size):
if isinstance(size, list):
size = tuple(size)
return self.resize_func(src, size)
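The added conversion guards against YAML-parsed sizes arriving as a list: `cv2.resize` expects its `dsize` as a tuple, and some OpenCV builds reject a plain list. A small standalone illustration:

```python
import cv2
import numpy as np

img = np.zeros((32, 32, 3), dtype=np.uint8)
size = [224, 224]                        # sizes parsed from YAML arrive as a list
resized = cv2.resize(img, tuple(size))   # cv2 wants a tuple (width, height)
print(resized.shape)                     # (224, 224, 3)
```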
......@@ -99,14 +102,15 @@ class DecodeImage(object):
self.channel_first = channel_first # only enabled when to_np is True
def __call__(self, img):
if six.PY2:
assert type(img) is str and len(
img) > 0, "invalid input 'img' in DecodeImage"
else:
assert type(img) is bytes and len(
img) > 0, "invalid input 'img' in DecodeImage"
data = np.frombuffer(img, dtype='uint8')
img = cv2.imdecode(data, 1)
if not isinstance(img, np.ndarray):
if six.PY2:
assert type(img) is str and len(
img) > 0, "invalid input 'img' in DecodeImage"
else:
assert type(img) is bytes and len(
img) > 0, "invalid input 'img' in DecodeImage"
data = np.frombuffer(img, dtype='uint8')
img = cv2.imdecode(data, 1)
if self.to_rgb:
assert img.shape[2] == 3, 'invalid shape of image[%s]' % (
img.shape)
......
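With the `isinstance(img, np.ndarray)` guard, `DecodeImage` only decodes when the input is still raw bytes, so it can safely follow operators (such as ones returning already-decoded arrays, e.g. the newly imported `Pad`). A stripped-down sketch of the new control flow, not the repo's exact code:

```python
import cv2
import numpy as np


def decode_if_needed(img):
    """Decode raw bytes; pass already-decoded ndarrays through unchanged."""
    if not isinstance(img, np.ndarray):
        data = np.frombuffer(img, dtype="uint8")
        img = cv2.imdecode(data, 1)
    return img


already_decoded = np.zeros((8, 8, 3), dtype=np.uint8)
assert decode_if_needed(already_decoded) is already_decoded
```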
......@@ -70,18 +70,6 @@ def train_epoch(engine, epoch_id, print_batch_step):
# clear grad
for i in range(len(engine.optimizer)):
# manually scale up grad of center_loss
if i == 1:
for j in range(len(engine.train_loss_func.loss_func)):
if len(engine.train_loss_func.loss_func[j].parameters(
)) == 0:
continue
for param in engine.train_loss_func.loss_func[
j].parameters():
if hasattr(param, 'grad') and param.grad is not None:
param.grad.set_value(param.grad * (
1.0 / engine.train_loss_func.loss_weight[j]))
engine.optimizer[i].clear_grad()
# step lr
......
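The deleted block rescaled CenterLoss gradients by `1 / loss_weight` by hand inside the training loop. With the per-scope optimizers from the config above, the same effect can be folded into the CenterLoss optimizer's learning rate instead; the arithmetic below uses an assumed loss weight of 0.0005 purely to illustrate how an lr of 1000.0 can arise:

```python
# For plain SGD, scaling grads by 1/w before a step at lr is equivalent to a
# step at lr/w on the unscaled grads. The loss weight here is an assumed
# value for illustration, not read from this commit.
loss_weight = 0.0005
desired_center_lr = 0.5
configured_lr = desired_center_lr / loss_weight
print(configured_lr)  # 1000.0, matching the Constant lr in the YAML above
```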
......@@ -47,7 +47,7 @@ def build_optimizer(config, epochs, step_each_epoch, model_list=None):
config = copy.deepcopy(config)
optim_config = config["Optimizer"]
if isinstance(optim_config, dict):
# convert {'name': xxx, **optim_cfg} to [{name: {scope: xxx, **optim_cfg}}]
# convert {'name': xxx, **optim_cfg} to [{'name': {'scope': xxx, **optim_cfg}}]
optim_name = optim_config.pop("name")
optim_config: List[Dict[str, Dict]] = [{
optim_name: {
......@@ -65,15 +65,15 @@ def build_optimizer(config, epochs, step_each_epoch, model_list=None):
3. loss which has parameters, such as CenterLoss.
"""
for optim_item in optim_config:
# optim_cfg = {optim_name: {scope: xxx, **optim_cfg}}
# optim_cfg = {optim_name: {'scope': xxx, **optim_cfg}}
# step1 build lr
optim_name = list(optim_item.keys())[0] # get optim_name
optim_scope = optim_item[optim_name].pop('scope') # get optim_scope
optim_cfg = optim_item[optim_name] # get optim_cfg
lr = build_lr_scheduler(optim_cfg.pop('lr'), epochs, step_each_epoch)
logger.debug("build lr ({}) for scope ({}) success..".format(
lr, optim_scope))
logger.info("build lr ({}) for scope ({}) success..".format(
lr.__class__.__name__, optim_scope))
# step2 build regularization
if 'regularizer' in optim_cfg and optim_cfg['regularizer'] is not None:
if 'weight_decay' in optim_cfg:
......@@ -84,8 +84,8 @@ def build_optimizer(config, epochs, step_each_epoch, model_list=None):
reg_name = reg_config.pop('name') + 'Decay'
reg = getattr(paddle.regularizer, reg_name)(**reg_config)
optim_cfg["weight_decay"] = reg
logger.debug("build regularizer ({}) for scope ({}) success..".
format(reg, optim_scope))
logger.info("build regularizer ({}) for scope ({}) success..".
format(reg.__class__.__name__, optim_scope))
# step3 build optimizer
if 'clip_norm' in optim_cfg:
clip_norm = optim_cfg.pop('clip_norm')
......@@ -100,26 +100,31 @@ def build_optimizer(config, epochs, step_each_epoch, model_list=None):
# optimizer for all
optim_model.append(model_list[i])
else:
if optim_scope.endswith("Loss"):
if "Loss" in optim_scope:
# optimizer for loss
for m in model_list[i].sublayers(True):
if m.__class__.__name__ == optim_scope:
optim_model.append(m)
if hasattr(model_list[i], 'loss_func'):
for j in range(len(model_list[i].loss_func)):
if model_list[i].loss_func[
j].__class__.__name__ == optim_scope:
optim_model.append(model_list[i].loss_func[j])
elif optim_scope == "model":
# opmizer for entire model
optim_model.append(model_list[i])
if not model_list[i].__class__.__name__.lower().endswith(
"loss"):
optim_model.append(model_list[i])
else:
# opmizer for module in model, such as backbone, neck, head...
if hasattr(model_list[i], optim_scope):
optim_model.append(getattr(model_list[i], optim_scope))
assert len(optim_model) == 1, \
"Invalid optim model for optim scope({}), number of optim_model={}".format(optim_scope, len(optim_model))
"Invalid optim model for optim scope({}), number of optim_model={}".\
format(optim_scope, [m.__class__.__name__ for m in optim_model])
optim = getattr(optimizer, optim_name)(
learning_rate=lr, grad_clip=grad_clip,
**optim_cfg)(model_list=optim_model)
logger.debug("build optimizer ({}) for scope ({}) success..".format(
optim, optim_scope))
logger.info("build optimizer ({}) for scope ({}) success..".format(
optim.__class__.__name__, optim_scope))
optim_list.append(optim)
lr_list.append(lr)
return optim_list, lr_list
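To make the revised scope matching concrete, here is a small self-contained mock of the selection logic (simplified: no `sublayers` walk), showing how `"model"`, a loss-class name, and a module name each resolve:

```python
# Not the repo's code: a minimal mock of the scope resolution above.
class Backbone: pass
class CenterLoss: pass

class CombinedLoss:
    def __init__(self):
        self.loss_func = [CenterLoss()]

class Model:
    def __init__(self):
        self.backbone = Backbone()


def resolve_scope(scope, model_list):
    picked = []
    for m in model_list:
        if "Loss" in scope:
            # optimizer for a specific loss term inside the combined loss
            for lf in getattr(m, "loss_func", []):
                if lf.__class__.__name__ == scope:
                    picked.append(lf)
        elif scope == "model":
            # optimizer for the whole model, excluding loss containers
            if not m.__class__.__name__.lower().endswith("loss"):
                picked.append(m)
        elif hasattr(m, scope):
            # optimizer for a sub-module such as backbone, neck, head
            picked.append(getattr(m, scope))
    return picked


model_list = [Model(), CombinedLoss()]
print([m.__class__.__name__ for m in resolve_scope("model", model_list)])       # ['Model']
print([m.__class__.__name__ for m in resolve_scope("CenterLoss", model_list)])  # ['CenterLoss']
print([m.__class__.__name__ for m in resolve_scope("backbone", model_list)])    # ['Backbone']
```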
......@@ -198,6 +198,7 @@ class Piecewise(object):
epochs,
warmup_epoch=0,
warmup_start_lr=0.0,
warmup_by_epoch=False,
last_epoch=-1,
**kwargs):
super().__init__()
......@@ -205,27 +206,61 @@ class Piecewise(object):
msg = f"When using warm up, the value of \"Global.epochs\" must be greater than value of \"Optimizer.lr.warmup_epoch\". The value of \"Optimizer.lr.warmup_epoch\" has been set to {epochs}."
logger.warning(msg)
warmup_epoch = epochs
self.boundaries = [step_each_epoch * e for e in decay_epochs]
self.boundaries_steps = [step_each_epoch * e for e in decay_epochs]
self.boundaries_epoch = decay_epochs
self.values = values
self.last_epoch = last_epoch
self.warmup_steps = round(warmup_epoch * step_each_epoch)
self.warmup_epoch = warmup_epoch
self.warmup_start_lr = warmup_start_lr
self.warmup_by_epoch = warmup_by_epoch
def __call__(self):
learning_rate = lr.PiecewiseDecay(
boundaries=self.boundaries,
values=self.values,
last_epoch=self.last_epoch)
if self.warmup_steps > 0:
learning_rate = lr.LinearWarmup(
learning_rate=learning_rate,
warmup_steps=self.warmup_steps,
start_lr=self.warmup_start_lr,
end_lr=self.values[0],
if self.warmup_by_epoch is False:
learning_rate = lr.PiecewiseDecay(
boundaries=self.boundaries_steps,
values=self.values,
last_epoch=self.last_epoch)
if self.warmup_steps > 0:
learning_rate = lr.LinearWarmup(
learning_rate=learning_rate,
warmup_steps=self.warmup_steps,
start_lr=self.warmup_start_lr,
end_lr=self.values[0],
last_epoch=self.last_epoch)
else:
learning_rate = lr.PiecewiseDecay(
boundaries=self.boundaries_epoch,
values=self.values,
last_epoch=self.last_epoch)
if self.warmup_epoch > 0:
learning_rate = lr.LinearWarmup(
learning_rate=learning_rate,
warmup_steps=self.warmup_epoch,
start_lr=self.warmup_start_lr,
end_lr=self.values[0],
last_epoch=self.last_epoch)
return learning_rate
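A hedged construction example for the two `Piecewise` modes; the constructor arguments not visible in this diff (e.g. `step_each_epoch`, `decay_epochs`, `values`) are inferred from the attribute names used above and are passed as keywords so their order does not matter:

```python
# Assumed keyword names based on the attributes in the diff; the returned
# objects are paddle.optimizer.lr schedulers built by __call__().
step_sched = Piecewise(step_each_epoch=100, decay_epochs=[30, 60],
                       values=[3.5e-4, 3.5e-5, 3.5e-6], epochs=120,
                       warmup_epoch=10, warmup_start_lr=3.5e-6,
                       warmup_by_epoch=False)()  # boundaries and warmup in steps

epoch_sched = Piecewise(step_each_epoch=100, decay_epochs=[30, 60],
                        values=[3.5e-4, 3.5e-5, 3.5e-6], epochs=120,
                        warmup_epoch=10, warmup_start_lr=3.5e-6,
                        warmup_by_epoch=True)()  # boundaries and warmup in epochs
```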
class Constant(LRScheduler):
"""
Constant learning rate
Args:
lr (float): The initial learning rate. It is a python float number.
last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
"""
def __init__(self, learning_rate, last_epoch=-1, by_epoch=False, **kwargs):
self.learning_rate = learning_rate
self.last_epoch = last_epoch
self.by_epoch = by_epoch
super().__init__()
def get_lr(self):
return self.learning_rate
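A minimal usage sketch for the new `Constant` scheduler: it exposes the `LRScheduler` interface but always returns the same rate, which is how the CenterLoss optimizer in the YAML above gets its fixed lr of 1000.0:

```python
sched = Constant(learning_rate=1000.0)
print(sched.get_lr())  # 1000.0
sched.step()           # advancing steps/epochs never changes the value
print(sched.get_lr())  # still 1000.0

# It can then be handed to a Paddle optimizer as its learning_rate, e.g.
# paddle.optimizer.SGD(learning_rate=sched, parameters=...).
```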
class MultiStepDecay(LRScheduler):
"""
Update the learning rate by ``gamma`` once ``epoch`` reaches one of the milestones.
......