Unverified · Commit dd79f81f authored by littletomatodonkey, committed by GitHub

[WIP]add arch init (#744)

* polish trainer
Parent 83056d44
......@@ -12,8 +12,54 @@
#See the License for the specific language governing permissions and
#limitations under the License.
import copy
import importlib
import paddle.nn as nn
from . import backbone
from .backbone import *
from ppcls.arch.loss_metrics.loss import *
from .utils import *
def build_model(config):
config = copy.deepcopy(config)
model_type = config.pop("name")
mod = importlib.import_module(__name__)
arch = getattr(mod, model_type)(**config)
return arch
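A minimal usage sketch of build_model (hedged: the dict mirrors the "Arch" section of the YAML config later in this commit, and "ResNet50" is assumed to be exported by the wildcard backbone imports above):

arch_config = {
    "name": "RecModel",
    "Backbone": {"name": "ResNet50"},
}
model = build_model(arch_config)  # "name" resolves RecModel in this module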
class RecModel(nn.Layer):
def __init__(self, **config):
super().__init__()
backbone_config = config["Backbone"]
backbone_name = backbone_config.pop("name")
        self.backbone = getattr(backbone, backbone_name)(**backbone_config)
if "backbone_stop_layer" in config:
backbone_stop_layer = config["backbone_stop_layer"]
self.backbone.stop_layer(backbone_stop_layer)
if "Neck" in config:
neck_config = config["Neck"]
neck_name = neck_config.pop("name")
            # NOTE: assumes the neck class is importable from this module's
            # namespace (e.g. via the wildcard imports above)
            self.neck = getattr(importlib.import_module(__name__),
                                neck_name)(**neck_config)
else:
self.neck = None
if "Head" in config:
head_config = config["Head"]
head_name = head_config.pop("name")
            # NOTE: same assumption as for the neck above
            self.head = getattr(importlib.import_module(__name__),
                                head_name)(**head_config)
else:
self.head = None
def forward(self, x):
y = self.backbone(x)
if self.neck is not None:
y = self.neck(y)
if self.head is not None:
y = self.head(y)
return y
......@@ -17,11 +17,11 @@ from __future__ import absolute_import, division, print_function
import paddle
from paddle import ParamAttr
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn import Conv2D, BatchNorm, Linear, Dropout
from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D
from paddle.nn import MaxPool2D
from ppcls.arch.backbone.base.theseus_layer import TheseusLayer
from ppcls.utils.save_load import load_dygraph_pretrain
__all__ = ["VGG11", "VGG13", "VGG16", "VGG19"]
......@@ -149,7 +149,12 @@ class ConvBlock(TheseusLayer):
class VGGNet(TheseusLayer):
def __init__(self, config, stop_grad_layers=0, class_num=1000):
def __init__(self,
config,
stop_grad_layers=0,
class_num=1000,
pretrained=False,
**args):
super().__init__()
self.stop_grad_layers = stop_grad_layers
......@@ -176,6 +181,9 @@ class VGGNet(TheseusLayer):
self._fc2 = Linear(4096, 4096)
self._out = Linear(4096, class_num)
        # `pretrained` defaults to False, so test truthiness rather than None
        if pretrained:
load_dygraph_pretrain(self, pretrained)
def forward(self, inputs):
x = self._conv_block_1(inputs)
x = self._conv_block_2(x)
......
#copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import sys
import copy
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
# TODO: fix the format
class CELoss(nn.Layer):
"""
"""
def __init__(self, name="loss", epsilon=None):
super().__init__()
self.name = name
if epsilon is not None and (epsilon <= 0 or epsilon >= 1):
epsilon = None
self.epsilon = epsilon
def _labelsmoothing(self, target, class_num):
if target.shape[-1] != class_num:
one_hot_target = F.one_hot(target, class_num)
else:
one_hot_target = target
soft_target = F.label_smooth(one_hot_target, epsilon=self.epsilon)
soft_target = paddle.reshape(soft_target, shape=[-1, class_num])
return soft_target
def forward(self, logits, label, mode="train"):
loss_dict = {}
if self.epsilon is not None:
class_num = logits.shape[-1]
label = self._labelsmoothing(label, class_num)
            x = -F.log_softmax(logits, axis=-1)
loss = paddle.sum(x * label, axis=-1)
else:
if label.shape[-1] == logits.shape[-1]:
label = F.softmax(label, axis=-1)
soft_label = True
else:
soft_label = False
loss = F.cross_entropy(logits, label=label, soft_label=soft_label)
loss_dict[self.name] = paddle.mean(loss)
return loss_dict
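A hedged usage sketch of CELoss; with epsilon=0.1 over 4 classes, F.label_smooth maps a one-hot row [0, 1, 0, 0] to [0.025, 0.925, 0.025, 0.025]:

loss_fn = CELoss(epsilon=0.1)
logits = paddle.rand([8, 4])        # batch of 8, 4 classes
labels = paddle.randint(0, 4, [8])  # integer ids, one-hot encoded internally
print(loss_fn(logits, labels))      # {"loss": mean smoothed cross entropy}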
# TODO: fix the format
class Topk(nn.Layer):
def __init__(self, topk=[1, 5]):
super().__init__()
assert isinstance(topk, (int, list))
if isinstance(topk, int):
topk = [topk]
self.topk = topk
def forward(self, x, label):
metric_dict = dict()
for k in self.topk:
metric_dict["top{}".format(k)] = paddle.metric.accuracy(
x, label, k=k)
return metric_dict
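Usage sketch for Topk (paddle.metric.accuracy expects int64 labels of shape [N, 1]):

metric_fn = Topk(topk=[1, 5])
logits = paddle.rand([8, 1000])
labels = paddle.randint(0, 1000, [8, 1])
print(metric_fn(logits, labels))  # {"top1": ..., "top5": ...}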
# TODO: fix the format
def build_loss(config):
loss_func = CELoss()
return loss_func
# TODO: fix the format
def build_metrics(config):
metrics_func = Topk()
return metrics_func
# global configs
Global:
pretrained_model: ""
output_dir: "./output/"
device: "gpu"
class_num: 1000
save_interval: 1
eval_during_train: True
eval_interval: 1
epochs: 90
print_batch_step: 10
use_visualdl: False
image_shape: [3, 224, 224]
infer_imgs:
# model architecture
Arch:
name: "ResNet50"
# loss function config for traing/eval process
Loss:
Train:
- CELoss:
weight: 1.0
Eval:
- CELoss:
weight: 1.0
Optimizer:
name: Momentum
momentum: 0.9
lr:
name: Piecewise
learning_rate: 0.1
decay_epochs: [30, 60, 90]
values: [0.1, 0.01, 0.001, 0.0001]
regularizer:
name: 'L2'
coeff: 0.0001
# data loader for train and eval
DataLoader:
Train:
# Dataset:
# Sampler:
# Loader:
batch_size: 256
num_workers: 4
file_list: "./dataset/ILSVRC2012/train_list.txt"
data_dir: "./dataset/ILSVRC2012/"
shuffle_seed: 0
transforms:
- DecodeImage:
to_rgb: True
channel_first: False
- RandCropImage:
size: 224
- RandFlipImage:
flip_code: 1
- NormalizeImage:
scale: 1./255.
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
- ToCHWImage:
Eval:
    # TODO: modify to the latest trainer
# Dataset:
# Sampler:
# Loader:
batch_size: 128
num_workers: 4
file_list: "./dataset/ILSVRC2012/val_list.txt"
data_dir: "./dataset/ILSVRC2012/"
shuffle_seed: 0
transforms:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
resize_short: 256
- CropImage:
size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
- ToCHWImage:
Metric:
Train:
- Topk:
k: [1, 5]
Eval:
- Topk:
k: [1, 5]
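This file is consumed through ppcls.utils.config.get_config (changed later in this diff); a minimal sketch, assuming the YAML above is saved under the hypothetical name ResNet50.yaml:

from ppcls.utils import config

cfg = config.get_config("ResNet50.yaml",
                        overrides=["Global.epochs=1", "Global.device=cpu"],
                        show=False)
print(cfg["Arch"]["name"])  # ResNet50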
......@@ -12,4 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .reader import Reader
import copy
import paddle
import os
from paddle.io import DistributedBatchSampler, BatchSampler, DataLoader
from ppcls.utils import logger
# TODO: fix the format
def build_dataloader(config, mode, device, seed=None):
from . import reader
from .reader import Reader
dataloader = Reader(config, mode=mode, places=device)()
return dataloader
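Usage sketch (hedged): Reader capitalizes mode to select the "Train"/"Eval" sub-config, and the returned object is called again to iterate, matching the trainer below:

# cfg as returned by ppcls.utils.config.get_config
train_loader = build_dataloader(cfg["DataLoader"], mode="train",
                                device=paddle.set_device("gpu"))
for iter_id, batch in enumerate(train_loader()):
    images, labels = batch[0], batch[1]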
......@@ -250,13 +250,14 @@ class Reader:
def __init__(self, config, mode='train', places=None):
try:
self.params = config[mode.upper()]
self.params = config[mode.capitalize()]
except KeyError:
raise ModeException(mode=mode)
use_mix = config.get('use_mix')
self.params['mode'] = mode
self.shuffle = mode == "train"
self.is_train = mode == "train"
self.collate_fn = None
self.batch_ops = []
......@@ -298,7 +299,7 @@ class Reader:
shuffle=False,
num_workers=self.params["num_workers"])
else:
is_train = self.params['mode'] == "train"
is_train = self.is_train
batch_sampler = DistributedBatchSampler(
dataset,
batch_size=batch_size,
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
import numpy as np
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.abspath(os.path.join(__dir__, '../../')))
import argparse
import paddle
import paddle.nn as nn
import paddle.distributed as dist
from ppcls.utils import config
from ppcls.utils.check import check_gpu
from ppcls.utils.misc import AverageMeter
from ppcls.utils import logger
from ppcls.data import build_dataloader
from ppcls.arch import build_model
from ppcls.arch.loss_metrics import build_loss
from ppcls.arch.loss_metrics import build_metrics
from ppcls.optimizer import build_optimizer
from ppcls.utils import save_load
class Trainer(object):
def __init__(self, mode="train"):
args = config.parse_args()
self.config = config.get_config(
args.config, overrides=args.override, show=True)
self.mode = mode
self.output_dir = self.config['Global']['output_dir']
# set device
assert self.config["Global"]["device"] in ["cpu", "gpu", "xpu"]
self.device = paddle.set_device(self.config["Global"]["device"])
# set dist
self.config["Global"][
"distributed"] = paddle.distributed.get_world_size() != 1
if self.config["Global"]["distributed"]:
dist.init_parallel_env()
self.model = build_model(self.config["Arch"])
if self.config["Global"]["distributed"]:
self.model = paddle.DataParallel(self.model)
self.vdl_writer = None
if self.config['Global']['use_visualdl']:
from visualdl import LogWriter
vdl_writer_path = os.path.join(self.output_dir, "vdl")
if not os.path.exists(vdl_writer_path):
os.makedirs(vdl_writer_path)
self.vdl_writer = LogWriter(logdir=vdl_writer_path)
logger.info('train with paddle {} and device {}'.format(
paddle.__version__, self.device))
def _build_metric_info(self, metric_config, mode="train"):
"""
_build_metric_info: build metrics according to current mode
Return:
            metric: the metric module built for the given mode, or None
"""
metric = None
mode = mode.capitalize()
if mode in metric_config and metric_config[mode] is not None:
metric = build_metrics(metric_config[mode])
return metric
def _build_loss_info(self, loss_config, mode="train"):
"""
_build_loss_info: build loss according to current mode
Return:
            loss: the loss module built for the given mode, or None
"""
loss = None
mode = mode.capitalize()
if mode in loss_config and loss_config[mode] is not None:
loss = build_loss(loss_config[mode])
return loss
def train(self):
# build train loss and metric info
loss_func = self._build_loss_info(self.config["Loss"])
metric_func = self._build_metric_info(self.config["Metric"])
train_dataloader = build_dataloader(self.config["DataLoader"], "train",
self.device)
step_each_epoch = len(train_dataloader)
optimizer, lr_sch = build_optimizer(self.config["Optimizer"],
self.config["Global"]["epochs"],
step_each_epoch,
self.model.parameters())
print_batch_step = self.config['Global']['print_batch_step']
save_interval = self.config["Global"]["save_interval"]
best_metric = {
"metric": 0.0,
"epoch": 0,
}
        # key: metric name, value: AverageMeter for that metric
output_info = dict()
# global iter counter
global_step = 0
for epoch_id in range(1, self.config["Global"]["epochs"] + 1):
self.model.train()
for iter_id, batch in enumerate(train_dataloader()):
batch_size = batch[0].shape[0]
batch[1] = paddle.to_tensor(batch[1].numpy().astype("int64")
.reshape([-1, 1]))
global_step += 1
# image input
out = self.model(batch[0])
# calc loss
loss_dict = loss_func(out, batch[-1])
for key in loss_dict:
                    if key not in output_info:
output_info[key] = AverageMeter(key, '7.5f')
output_info[key].update(loss_dict[key].numpy()[0],
batch_size)
# calc metric
if metric_func is not None:
metric_dict = metric_func(out, batch[-1])
for key in metric_dict:
                        if key not in output_info:
output_info[key] = AverageMeter(key, '7.5f')
output_info[key].update(metric_dict[key].numpy()[0],
batch_size)
if iter_id % print_batch_step == 0:
lr_msg = "lr: {:.5f}".format(lr_sch.get_lr())
metric_msg = ", ".join([
"{}: {:.5f}".format(key, output_info[key].avg)
for key in output_info
])
logger.info("[Train][Epoch {}][Iter: {}/{}]{}, {}".format(
epoch_id, iter_id,
len(train_dataloader), lr_msg, metric_msg))
# step opt and lr
loss_dict["loss"].backward()
optimizer.step()
optimizer.clear_grad()
lr_sch.step()
metric_msg = ", ".join([
"{}: {:.5f}".format(key, output_info[key].avg)
for key in output_info
])
logger.info("[Train][Epoch {}][Avg]{}".format(epoch_id,
metric_msg))
output_info.clear()
# eval model and save model if possible
if self.config["Global"][
"eval_during_train"] and epoch_id % self.config["Global"][
"eval_during_train"] == 0:
acc = self.eval(epoch_id)
if acc >= best_metric["metric"]:
best_metric["metric"] = acc
best_metric["epoch"] = epoch_id
save_load.save_model(
self.model,
optimizer,
self.output_dir,
model_name=self.config["Arch"]["name"],
prefix="best_model")
# save model
if epoch_id % save_interval == 0:
save_load.save_model(
self.model,
optimizer,
self.output_dir,
model_name=self.config["Arch"]["name"],
prefix="ppcls_epoch_{}".format(epoch_id))
def build_avg_metrics(self, info_dict):
return {key: AverageMeter(key, '7.5f') for key in info_dict}
@paddle.no_grad()
def eval(self, epoch_id=0):
output_info = dict()
eval_dataloader = build_dataloader(self.config["DataLoader"], "eval",
self.device)
self.model.eval()
print_batch_step = self.config["Global"]["print_batch_step"]
        # build eval loss and metric info
loss_func = self._build_loss_info(self.config["Loss"], "eval")
metric_func = self._build_metric_info(self.config["Metric"], "eval")
metric_key = None
for iter_id, batch in enumerate(eval_dataloader()):
batch_size = batch[0].shape[0]
batch[0] = paddle.to_tensor(batch[0]).astype("float32")
batch[1] = paddle.to_tensor(batch[1]).reshape([-1, 1])
# image input
out = self.model(batch[0])
            # calc loss
if loss_func is not None:
loss_dict = loss_func(out, batch[-1])
for key in loss_dict:
                    if key not in output_info:
output_info[key] = AverageMeter(key, '7.5f')
output_info[key].update(loss_dict[key].numpy()[0],
batch_size)
# calc metric
if metric_func is not None:
metric_dict = metric_func(out, batch[-1])
if paddle.distributed.get_world_size() > 1:
for key in metric_dict:
paddle.distributed.all_reduce(
metric_dict[key],
op=paddle.distributed.ReduceOp.SUM)
metric_dict[key] = metric_dict[
key] / paddle.distributed.get_world_size()
for key in metric_dict:
if metric_key is None:
metric_key = key
                    if key not in output_info:
output_info[key] = AverageMeter(key, '7.5f')
output_info[key].update(metric_dict[key].numpy()[0],
batch_size)
if iter_id % print_batch_step == 0:
metric_msg = ", ".join([
"{}: {:.5f}".format(key, output_info[key].val)
for key in output_info
])
logger.info("[Eval][Epoch {}][Iter: {}/{}]{}".format(
epoch_id, iter_id, len(eval_dataloader), metric_msg))
metric_msg = ", ".join([
"{}: {:.5f}".format(key, output_info[key].avg)
for key in output_info
])
logger.info("[Eval][Epoch {}][Avg]{}".format(epoch_id, metric_msg))
self.model.train()
# do not try to save best model
if metric_func is None:
return -1
# return 1st metric in the dict
return output_info[metric_key].avg
def main():
trainer = Trainer()
trainer.train()
if __name__ == "__main__":
main()
......@@ -12,8 +12,55 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
import paddle
from ppcls.utils import logger
from . import optimizer
from . import learning_rate
from .optimizer import OptimizerBuilder
from .learning_rate import LearningRateBuilder
__all__ = ['build_optimizer']
def build_lr_scheduler(lr_config, epochs, step_each_epoch):
from . import learning_rate
lr_config.update({'epochs': epochs, 'step_each_epoch': step_each_epoch})
if 'name' in lr_config:
lr_name = lr_config.pop('name')
lr = getattr(learning_rate, lr_name)(**lr_config)()
else:
lr = lr_config['learning_rate']
return lr
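Sketch of building the Piecewise schedule from the "Optimizer.lr" block of the YAML above (step_each_epoch=5005 is a hypothetical ImageNet value for batch size 256):

lr_cfg = {"name": "Piecewise", "learning_rate": 0.1,
          "decay_epochs": [30, 60, 90],
          "values": [0.1, 0.01, 0.001, 0.0001]}
lr_sch = build_lr_scheduler(lr_cfg, epochs=90, step_each_epoch=5005)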
def build_optimizer(config, epochs, step_each_epoch, parameters):
config = copy.deepcopy(config)
# step1 build lr
lr = build_lr_scheduler(config.pop('lr'), epochs, step_each_epoch)
logger.info("build lr ({}) success..".format(lr))
# step2 build regularization
if 'regularizer' in config and config['regularizer'] is not None:
reg_config = config.pop('regularizer')
reg_name = reg_config.pop('name') + 'Decay'
reg = getattr(paddle.regularizer, reg_name)(**reg_config)
else:
reg = None
logger.info("build regularizer ({}) success..".format(reg))
# step3 build optimizer
optim_name = config.pop('name')
if 'clip_norm' in config:
clip_norm = config.pop('clip_norm')
grad_clip = paddle.nn.ClipGradByNorm(clip_norm=clip_norm)
else:
grad_clip = None
optim = getattr(optimizer, optim_name)(learning_rate=lr,
weight_decay=reg,
grad_clip=grad_clip,
parameter_list=parameters,
**config)()
logger.info("build optimizer ({}) success..".format(optim))
return optim, lr
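End-to-end sketch mirroring the trainer's call (hedged: model and train_dataloader come from the surrounding training script, and the Momentum class lives in ppcls/optimizer/optimizer.py, which this diff does not show):

optim_cfg = {
    "name": "Momentum", "momentum": 0.9,
    "lr": {"name": "Piecewise", "learning_rate": 0.1,
           "decay_epochs": [30, 60, 90], "values": [0.1, 0.01, 0.001, 0.0001]},
    "regularizer": {"name": "L2", "coeff": 0.0001},
}
optim, lr_sch = build_optimizer(optim_cfg, epochs=90,
                                step_each_epoch=len(train_dataloader),
                                parameters=model.parameters())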
......@@ -11,149 +11,173 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from paddle.optimizer import lr
import sys
import math
from paddle.optimizer.lr import LinearWarmup
from paddle.optimizer.lr import PiecewiseDecay
from paddle.optimizer.lr import CosineAnnealingDecay
from paddle.optimizer.lr import ExponentialDecay
__all__ = ['LearningRateBuilder']
class Cosine(CosineAnnealingDecay):
"""
Cosine learning rate decay
    lr_t = 0.5 * lr * (math.cos(step * math.pi / T_max) + 1)
Args:
lr(float): initial learning rate
step_each_epoch(int): steps each epoch
epochs(int): total training epochs
"""
def __init__(self, lr, step_each_epoch, epochs, **kwargs):
super(Cosine, self).__init__(
learning_rate=lr,
T_max=step_each_epoch * epochs, )
self.update_specified = False
class Piecewise(PiecewiseDecay):
class Linear(object):
"""
Piecewise learning rate decay
Linear learning rate decay
Args:
lr(float): initial learning rate
step_each_epoch(int): steps each epoch
decay_epochs(list): piecewise decay epochs
gamma(float): decay factor
        learning_rate (float): The initial learning rate. It is a python float number.
        epochs (int): Total training epochs; with step_each_epoch this sets the decay steps.
        end_lr (float, optional): The minimum final learning rate. Default: 0.0.
power(float, optional): Power of polynomial. Default: 1.0.
last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
"""
def __init__(self, lr, step_each_epoch, decay_epochs, gamma=0.1, **kwargs):
boundaries = [step_each_epoch * e for e in decay_epochs]
lr_values = [lr * (gamma**i) for i in range(len(boundaries) + 1)]
super(Piecewise, self).__init__(
boundaries=boundaries, values=lr_values)
self.update_specified = False
def __init__(self,
learning_rate,
epochs,
step_each_epoch,
end_lr=0.0,
power=1.0,
warmup_epoch=0,
last_epoch=-1,
**kwargs):
super(Linear, self).__init__()
self.learning_rate = learning_rate
self.epochs = epochs * step_each_epoch
self.end_lr = end_lr
self.power = power
self.last_epoch = last_epoch
self.warmup_epoch = round(warmup_epoch * step_each_epoch)
class CosineWarmup(LinearWarmup):
def __call__(self):
learning_rate = lr.PolynomialDecay(
learning_rate=self.learning_rate,
decay_steps=self.epochs,
end_lr=self.end_lr,
power=self.power,
last_epoch=self.last_epoch)
if self.warmup_epoch > 0:
learning_rate = lr.LinearWarmup(
learning_rate=learning_rate,
warmup_steps=self.warmup_epoch,
start_lr=0.0,
end_lr=self.learning_rate,
last_epoch=self.last_epoch)
return learning_rate
class Cosine(object):
"""
Cosine learning rate decay with warmup
[0, warmup_epoch): linear warmup
[warmup_epoch, epochs): cosine decay
Cosine learning rate decay
    lr_t = 0.5 * lr * (math.cos(step * math.pi / T_max) + 1)
Args:
lr(float): initial learning rate
step_each_epoch(int): steps each epoch
epochs(int): total training epochs
warmup_epoch(int): epoch num of warmup
last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
"""
def __init__(self, lr, step_each_epoch, epochs, warmup_epoch=5, **kwargs):
assert epochs > warmup_epoch, "total epoch({}) should be larger than warmup_epoch({}) in CosineWarmup.".format(
epochs, warmup_epoch)
warmup_step = warmup_epoch * step_each_epoch
start_lr = 0.0
end_lr = lr
lr_sch = Cosine(lr, step_each_epoch, epochs - warmup_epoch)
super(CosineWarmup, self).__init__(
learning_rate=lr_sch,
warmup_steps=warmup_step,
start_lr=start_lr,
end_lr=end_lr)
self.update_specified = False
def __init__(self,
learning_rate,
step_each_epoch,
epochs,
warmup_epoch=0,
last_epoch=-1,
**kwargs):
super(Cosine, self).__init__()
self.learning_rate = learning_rate
self.T_max = step_each_epoch * epochs
self.last_epoch = last_epoch
self.warmup_epoch = round(warmup_epoch * step_each_epoch)
class ExponentialWarmup(LinearWarmup):
def __call__(self):
learning_rate = lr.CosineAnnealingDecay(
learning_rate=self.learning_rate,
T_max=self.T_max,
last_epoch=self.last_epoch)
if self.warmup_epoch > 0:
learning_rate = lr.LinearWarmup(
learning_rate=learning_rate,
warmup_steps=self.warmup_epoch,
start_lr=0.0,
end_lr=self.learning_rate,
last_epoch=self.last_epoch)
return learning_rate
class Step(object):
"""
Exponential learning rate decay with warmup
[0, warmup_epoch): linear warmup
[warmup_epoch, epochs): Exponential decay
Piecewise learning rate decay
Args:
lr(float): initial learning rate
step_each_epoch(int): steps each epoch
decay_epochs(float): decay epochs
decay_rate(float): decay rate
warmup_epoch(int): epoch num of warmup
learning_rate (float): The initial learning rate. It is a python float number.
step_size (int): the interval to update.
        gamma (float, optional): The ratio by which the learning rate is reduced: ``new_lr = origin_lr * gamma``.
            It should be less than 1.0. Default: 0.1.
last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
"""
def __init__(self,
lr,
learning_rate,
step_size,
step_each_epoch,
decay_epochs=2.4,
decay_rate=0.97,
warmup_epoch=5,
gamma,
warmup_epoch=0,
last_epoch=-1,
**kwargs):
warmup_step = warmup_epoch * step_each_epoch
start_lr = 0.0
end_lr = lr
lr_sch = ExponentialDecay(lr, decay_rate)
super(ExponentialWarmup, self).__init__(
learning_rate=lr_sch,
warmup_steps=warmup_step,
start_lr=start_lr,
end_lr=end_lr)
        # NOTE: hack to update the exponential lr scheduler
self.update_specified = True
self.update_start_step = warmup_step
self.update_step_interval = int(decay_epochs * step_each_epoch)
self.step_each_epoch = step_each_epoch
super(Step, self).__init__()
self.step_size = step_each_epoch * step_size
self.learning_rate = learning_rate
self.gamma = gamma
self.last_epoch = last_epoch
self.warmup_epoch = round(warmup_epoch * step_each_epoch)
class LearningRateBuilder():
def __call__(self):
learning_rate = lr.StepDecay(
learning_rate=self.learning_rate,
step_size=self.step_size,
gamma=self.gamma,
last_epoch=self.last_epoch)
if self.warmup_epoch > 0:
learning_rate = lr.LinearWarmup(
learning_rate=learning_rate,
warmup_steps=self.warmup_epoch,
start_lr=0.0,
end_lr=self.learning_rate,
last_epoch=self.last_epoch)
return learning_rate
class Piecewise(object):
"""
Build learning rate variable
https://www.paddlepaddle.org.cn/documentation/docs/zh/api_cn/layers_cn.html
Piecewise learning rate decay
Args:
function(str): class name of learning rate
params(dict): parameters used for init the class
        decay_epochs (list): A list of epoch numbers at which the learning rate decays. The elements are python ints.
        values (list): A list of learning rate values picked at the different epoch boundaries.
            The elements are python floats.
last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
"""
def __init__(self,
function='Linear',
params={'lr': 0.1,
'steps': 100,
'end_lr': 0.0}):
self.function = function
self.params = params
step_each_epoch,
decay_epochs,
values,
warmup_epoch=0,
last_epoch=-1,
**kwargs):
super(Piecewise, self).__init__()
self.boundaries = [step_each_epoch * e for e in decay_epochs]
self.values = values
self.last_epoch = last_epoch
self.warmup_epoch = round(warmup_epoch * step_each_epoch)
def __call__(self):
mod = sys.modules[__name__]
lr = getattr(mod, self.function)(**self.params)
return lr
learning_rate = lr.PiecewiseDecay(
boundaries=self.boundaries,
values=self.values,
last_epoch=self.last_epoch)
if self.warmup_epoch > 0:
learning_rate = lr.LinearWarmup(
learning_rate=learning_rate,
warmup_steps=self.warmup_epoch,
start_lr=0.0,
end_lr=self.values[0],
last_epoch=self.last_epoch)
return learning_rate
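Each class above is a factory: calling the instance yields a paddle LRScheduler, wrapped in lr.LinearWarmup when warmup_epoch > 0. A quick hedged sketch:

sched = Cosine(learning_rate=0.1, step_each_epoch=100, epochs=90,
               warmup_epoch=5)()  # 500 warmup steps, then cosine decay
for _ in range(100):
    sched.step()
print(sched.get_lr())  # ~0.02, i.e. 100/500 of the way to 0.1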
......@@ -11,13 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import os
import copy
import argparse
import yaml
from ppcls.utils import check
from ppcls.utils import logger
from ppcls.utils import check
__all__ = ['get_config']
......@@ -31,6 +31,9 @@ class AttrDict(dict):
else:
self[key] = value
def __deepcopy__(self, content):
return copy.deepcopy(dict(self))
def create_attr_dict(yaml_config):
from ast import literal_eval
......@@ -76,7 +79,6 @@ def print_dict(d, delimiter=0):
logger.info("{}{} : {}".format(delimiter * " ",
logger.coloring(k, "HEADER"),
logger.coloring(v, "OKGREEN")))
if k.isupper():
logger.info(placeholder)
......@@ -84,7 +86,6 @@ def print_dict(d, delimiter=0):
def print_config(config):
"""
visualize configs
Arguments:
config: configs
"""
......@@ -97,21 +98,15 @@ def check_config(config):
Check config
"""
check.check_version()
use_gpu = config.get('use_gpu', True)
if use_gpu:
check.check_gpu()
architecture = config.get('ARCHITECTURE')
check.check_architecture(architecture)
check.check_model_with_running_mode(architecture)
#check.check_architecture(architecture)
use_mix = config.get('use_mix', False)
check.check_mix(architecture, use_mix)
classes_num = config.get('classes_num')
check.check_classes_num(classes_num)
mode = config.get('mode', 'train')
if mode.lower() == 'train':
check.check_function_params(config, 'LEARNING_RATE')
......@@ -121,7 +116,6 @@ def check_config(config):
def override(dl, ks, v):
"""
Recursively replace dict of list
Args:
dl(dict or list): dict or list to be replaced
ks(list): list of keys
......@@ -147,19 +141,15 @@ def override(dl, ks, v):
if len(ks) == 1:
# assert ks[0] in dl, ('{} is not exist in {}'.format(ks[0], dl))
        if ks[0] not in dl:
            logger.warning('A new field ({}) detected!'.format(ks[0]))
            logger.warning('A new field ({}) detected!'.format(ks[0], dl))
dl[ks[0]] = str2num(v)
else:
        if ks[0] not in dl:
            logger.warning('A new field ({}) detected!'.format(ks[0]))
dl[ks[0]] = {}
override(dl[ks[0]], ks[1:], v)
def override_config(config, options=None):
"""
Recursively override the config
Args:
config(dict): dict to be replaced
options(list): list of pairs(key0.key1.idx.key2=value)
......@@ -167,7 +157,6 @@ def override_config(config, options=None):
'topk=2',
'VALID.transforms.1.ResizeImage.resize_short=300'
]
Returns:
config(dict): replaced config
"""
......@@ -183,7 +172,6 @@ def override_config(config, options=None):
key, value = pair
keys = key.split('.')
override(config, keys, value)
return config
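Sketch of the override path used by the -o flag (keys are dotted paths; list indices are plain integers, as the docstring example shows):

cfg = {"Optimizer": {"lr": {"learning_rate": 0.1}}}
override_config(cfg, ["Optimizer.lr.learning_rate=0.01"])
print(cfg["Optimizer"]["lr"]["learning_rate"])  # 0.01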
......@@ -197,5 +185,23 @@ def get_config(fname, overrides=None, show=True):
override_config(config, overrides)
if show:
print_config(config)
check_config(config)
# check_config(config)
return config
def parse_args():
parser = argparse.ArgumentParser("generic-image-rec train script")
parser.add_argument(
'-c',
'--config',
type=str,
default='configs/config.yaml',
help='config file path')
parser.add_argument(
'-o',
'--override',
action='append',
default=[],
help='config options to be overridden')
args = parser.parse_args()
return args
......@@ -146,13 +146,13 @@ def _save_student_model(net, model_prefix):
student_model_prefix))
def save_model(net, optimizer, model_path, epoch_id, prefix='ppcls'):
def save_model(net, optimizer, model_path, model_name="", prefix='ppcls'):
"""
save model to the target path
"""
if paddle.distributed.get_rank() != 0:
return
model_path = os.path.join(model_path, str(epoch_id))
model_path = os.path.join(model_path, model_name)
_mkdir_if_not_exist(model_path)
model_prefix = os.path.join(model_path, prefix)
......