Commit 3a99409f authored by Bin Lu

Merge branch 'PaddlePaddle:develop_reg' into develop_reg

```diff
@@ -43,19 +43,23 @@
 ## Documentation Tutorials
 - [Quick installation](./docs/zh_CN/tutorials/install.md)
-- [Image recognition quick start]
+- Image recognition quick start (若愚)
+- Image classification quick start (崔程, adapted from the 30-minute beginner tutorial)
 - Algorithm introduction
-  - [Image recognition system]
+  - Image recognition system (胜禹)
   - [Model zoo overview and pretrained models](./docs/zh_CN/models/models_intro.md)
-  - [Image classification]
+  - ImageNet classification task (崔程, adapted from the 30-minute advanced tutorial)
+  - [Multi-label classification task]()
   - [Feature learning]
-    - Product recognition
-    - Vehicle recognition
-    - Logo recognition
-    - Cartoon character recognition
-  - [Vector search]
+    - [Product recognition]()
+    - [Vehicle recognition]()
+    - [Logo recognition]()
+    - [Cartoon character recognition]()
+  - [Vector search]()
 - Model training/evaluation
-  - Image classification task
-  - Feature learning task
+  - Image classification task (崔程, based on the existing training docs)
+  - Feature learning task (陆彬)
 - Model prediction
   - [Prediction with the training engine](./docs/zh_CN/tutorials/getting_started.md)
   - [Prediction with the Python inference engine](./docs/zh_CN/tutorials/getting_started.md)
```
# Fine-Grained Vehicle Classification

Fine-grained classification assigns images within one base category to sub-categories, for example distinguishing species of birds, varieties of flowers, or types of minerals. Fine-grained vehicle classification, as the name suggests, classifies vehicles into their sub-categories. Compared with vehicle ReID, its training differs in two respects:

- a different dataset;
- a different Loss setup.

For everything else, see [Vehicle ReID](./vehicle_reid.md).

Full configuration file: [ResNet50.yaml](../../../ppcls/configs/Vehicle/ResNet50.yaml)
## Dataset

This demo uses [CompCars](http://mmlab.ie.cuhk.edu.hk/datasets/comp_cars/index.html) as the training dataset.

<img src="../../images/recognotion/vehicle/CompCars.png" style="zoom:50%;" />

The images come from the web and from surveillance footage. The web data covers 163 car makers and 1,716 car models, with **136,726** whole-car images and **27,618** part-level images; it also provides bounding boxes, viewpoints, and five attributes (top speed, displacement, number of doors, number of seats, car type). The surveillance data contains **50,000** front-view images.

Note that labels have to be generated from this dataset according to your own needs. In this demo, vehicles of the same model produced in different years are treated as one class, giving 431 classes in total.
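
The dataset entry of the training config therefore points at CompCars. The excerpt below is assembled from the `ResNet50.yaml` diff further down this page; `bbox_crop: True` crops each training image to its annotated vehicle bounding box before augmentation:

```yaml
DataLoader:
  Train:
    dataset:
      name: "CompCars"
      image_root: "./dataset/CompCars/image/"
      label_root: "./dataset/CompCars/label/"
      # crop to the annotated vehicle bounding box before augmentation
      bbox_crop: True
      cls_label_path: "./dataset/CompCars/train_test_split/classification/train_label.txt"
```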
## Loss Setup

Unlike vehicle ReID, this classification demo uses [Triplet Loss](../../../ppcls/loss/triplet.py) + [ArcLoss](../../../ppcls/arch/gears/arcmargin.py), with a weight ratio of 1:1.
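
For orientation, here is a sketch of what the matching Loss section looks like. As in the ReID configuration below, ArcLoss enters through the `ArcMargin` head, so it appears as a plain `CELoss` on the head output; the triplet-loss key (`TripletLossV2`) and its `margin` value are assumptions, so check [ResNet50.yaml](../../../ppcls/configs/Vehicle/ResNet50.yaml) for the authoritative settings:

```yaml
Loss:
  Train:
    # cross-entropy over the ArcMargin head output, i.e. ArcLoss
    - CELoss:
        weight: 1.0
    # assumed key for the triplet loss defined in ppcls/loss/triplet.py
    - TripletLossV2:
        weight: 1.0
        margin: 0.5   # assumed value
  Eval:
    - CELoss:
        weight: 1.0
```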
# Vehicle ReID

ReID, short for re-identification, is the technique of finding a queried target in an image gallery by algorithm, so it is a sub-problem of image retrieval. Vehicle ReID, given one vehicle image, finds images of the same vehicle taken by the same camera at other times or by different cameras. Extracting robust features is crucial here, so this document focuses on training the feature-extraction network for vehicle ReID and covers:

- the dataset and its preprocessing;
- the Backbone configuration;
- the Loss configuration.

All hyperparameters and the full configuration: [ResNet50_ReID.yaml](../../../ppcls/configs/Vehicle/ResNet50_ReID.yaml)
## Dataset and Preprocessing

### The VERI-Wild Dataset

<img src="../../images/recognotion/vehicle/cars.JPG" style="zoom:50%;" />

This dataset was captured by a large CCTV surveillance system over one month (30×24 hours) in unconstrained scenes. The system consists of 174 cameras spread over an area of more than 200 km². The raw collection contains 12 million vehicle images; after cleaning and annotation, 416,314 images of 40,671 distinct vehicles were kept. [See the paper for details](https://github.com/PKU-IMRE/VERI-Wild).
### Data Preprocessing

The images in this dataset are already vehicle crops produced by a detector, so no cropping step is needed as in `ImageNet` training. The augmentation pipeline, in order, is:

- `Resize` to 224;
- random horizontal flip;
- [AugMix](https://arxiv.org/abs/1912.02781v1);
- Normalize: scale values to the 0~1 range (multiply by 1/255, i.e. `scale: 0.00392157`);
- [RandomErasing](https://arxiv.org/pdf/1708.04896v2.pdf).

The corresponding configuration is shown below; see the `transform_ops` section:
```yaml
DataLoader:
  Train:
    dataset:
      # name of the Dataset class to use
      name: "VeriWild"
      # arguments for this dataset
      image_root: "/work/dataset/VeRI-Wild/images/"
      cls_label_path: "/work/dataset/VeRI-Wild/train_test_split/train_list_start0.txt"
      # augmentation strategy: ResizeImage, RandFlipImage, etc.
      transform_ops:
        - ResizeImage:
            size: 224
        - RandFlipImage:
            flip_code: 1
        - AugMix:
            prob: 0.5
        - NormalizeImage:
            scale: 0.00392157
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
        - RandomErasing:
            EPSILON: 0.5
            sl: 0.02
            sh: 0.4
            r1: 0.3
            mean: [0., 0., 0.]
    sampler:
      # draws num_instances images per identity, so each batch covers
      # batch_size / num_instances = 64 identities, as pair-based losses require
      name: DistributedRandomIdentitySampler
      batch_size: 128
      num_instances: 2
      drop_last: False
      shuffle: True
    loader:
      num_workers: 6
      use_shared_memory: False
```
## Backbone Configuration

`ResNet50` is used as the backbone, with the following modifications:

- the last stage (stage 4) does not downsample, so its feature map has the same size as that of stage 3, i.e. 14×14;
- an embedding layer, a 1×1 convolution with 512 output channels, is appended at the end.

Code: [ResNet50_last_stage_stride1](../../../ppcls/arch/backbone/variant_models/resnet_variant.py)

The Backbone section of the configuration file:
```yaml
Arch:
  # train with the RecModel architecture; plain ImageNet models and RecModel are both supported
  name: "RecModel"
  # settings for exporting the inference model
  infer_output_key: "features"
  infer_add_softmax: False
  # the Backbone to use
  Backbone:
    name: "ResNet50_last_stage_stride1"
    pretrained: True
  # take this layer's output as the Backbone feature; name is the layer's full_name
  BackboneStopLayer:
    name: "adaptive_avg_pool2d_0"
  # extra layers on top of the Backbone; here a 1x1 convolution (the embedding layer)
  Neck:
    name: "VehicleNeck"
    in_channels: 2048
    out_channels: 512
  # add ArcMargin, the concrete implementation of ArcLoss
  Head:
    name: "ArcMargin"
    embedding_size: 512
    class_num: 30671
    margin: 0.15
    scale: 32
```
## Loss Configuration

Vehicle ReID uses [SupConLoss](https://arxiv.org/abs/2004.11362) + [ArcLoss](https://arxiv.org/abs/1801.07698), with a weight ratio of 1:1.

Code: [SupConLoss](../../../ppcls/loss/supconloss.py), [ArcLoss](../../../ppcls/arch/gears/arcmargin.py)

Since ArcLoss is realized by the `ArcMargin` head defined above, it appears in the Loss section as a plain `CELoss` on the head output. The configuration is as follows:
```yaml
Loss:
  Train:
    - CELoss:
        weight: 1.0
    - SupConLoss:
        weight: 1.0
        # SupConLoss-specific parameter
        views: 2
  Eval:
    - CELoss:
        weight: 1.0
```
## Other Settings

### Optimizer Settings

```yaml
Optimizer:
  # optimizer to use
  name: Momentum
  # optimizer parameters
  momentum: 0.9
  lr:
    # learning-rate schedule to use
    name: MultiStepDecay
    # schedule parameters
    learning_rate: 0.01
    milestones: [30, 60, 70, 80, 90, 100, 120, 140]
    gamma: 0.5
    verbose: False
    last_epoch: -1
  regularizer:
    name: 'L2'
    coeff: 0.0005
```
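
Concretely, `MultiStepDecay` multiplies the learning rate by `gamma` at every milestone, so the rate stays at 0.01 until epoch 30, drops to 0.005, halves again to 0.0025 at epoch 60, and after the last milestone (epoch 140) ends at 0.01 × 0.5⁸ ≈ 3.9 × 10⁻⁵.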
### Eval Metric Settings

```yaml
Metric:
  Eval:
    # two retrieval metrics: Recallk (share of queries with at least one
    # correct match in the top-k results) and mAP (mean average precision)
    - Recallk:
        topk: [1, 5]
    - mAP: {}
```
### Other Hyperparameter Settings

```yaml
Global:
  # if null, train from scratch; set it to a saved training state to resume
  checkpoints: null
  pretrained_model: null
  output_dir: "./output/"
  device: "gpu"
  class_num: 30671
  # checkpoint granularity: save once per epoch
  save_interval: 1
  eval_during_train: True
  eval_interval: 1
  # number of training epochs
  epochs: 160
  # frequency of log output, in batches
  print_batch_step: 10
  # whether to use the VisualDL library
  use_visualdl: False
  # used for static mode and model export
  image_shape: [3, 224, 224]
  save_inference_dir: "./inference"
  # evaluate in retrieval mode
  eval_mode: "retrieval"
```
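
Because `eval_mode` is `"retrieval"`, evaluation builds `Query` and `Gallery` dataloaders instead of a plain `Eval` one; these are exactly the modes accepted by `build_dataloader` in the diff below. A sketch of the two entries, with paths taken from the `ResNet50_ReID.yaml` diff on this page:

```yaml
DataLoader:
  Query:
    dataset:
      name: "VeriWild"
      image_root: "./dataset/VeRI-Wild/images"
      cls_label_path: "./dataset/VeRI-Wild/train_test_split/test_3000_id_query.txt"
  Gallery:
    dataset:
      name: "VeriWild"
      image_root: "./dataset/VeRI-Wild/images"
      cls_label_path: "./dataset/VeRI-Wild/train_test_split/test_3000_id.txt"
```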
```diff
@@ -18,6 +18,8 @@ Global:
 # model architecture
 Arch:
   name: "RecModel"
+  infer_output_key: "features"
+  infer_add_softmax: False
   Backbone:
     name: "ResNet50_last_stage_stride1"
     pretrained: True
@@ -66,10 +68,10 @@ DataLoader:
   Train:
     dataset:
       name: "CompCars"
-      image_root: "/work/dataset/CompCars/image/"
-      label_root: "/work/dataset/CompCars/label/"
+      image_root: "./dataset/CompCars/image/"
+      label_root: "./dataset/CompCars/label/"
       bbox_crop: True
-      cls_label_path: "/work/dataset/CompCars/train_test_split/classification/train_label.txt"
+      cls_label_path: "./dataset/CompCars/train_test_split/classification/train_label.txt"
       transform_ops:
         - DecodeImage:
             to_rgb: True
@@ -106,7 +108,7 @@ DataLoader:
     # TOTO: modify to the latest trainer
     dataset:
       name: "CompCars"
-      image_root: ".dataset/CompCars/image/"
+      image_root: "./dataset/CompCars/image/"
       label_root: "./dataset/CompCars/label/"
       cls_label_path: "./dataset/CompCars/train_test_split/classification/test_label.txt"
       bbox_crop: True
```
```diff
@@ -19,6 +19,8 @@ Global:
 # model architecture
 Arch:
   name: "RecModel"
+  infer_output_key: "features"
+  infer_add_softmax: False
   Backbone:
     name: "ResNet50_last_stage_stride1"
     pretrained: True
@@ -31,7 +33,7 @@ Arch:
   Head:
     name: "ArcMargin"
     embedding_size: 512
-    class_num: 431
+    class_num: 30671
     margin: 0.15
     scale: 32
@@ -66,8 +68,8 @@ DataLoader:
   Train:
     dataset:
       name: "VeriWild"
-      image_root: "/work/dataset/VeRI-Wild/images/"
-      cls_label_path: "/work/dataset/VeRI-Wild/train_test_split/train_list_start0.txt"
+      image_root: "./dataset/VeRI-Wild/images/"
+      cls_label_path: "./dataset/VeRI-Wild/train_test_split/train_list_start0.txt"
       transform_ops:
         - DecodeImage:
             to_rgb: True
@@ -104,8 +106,8 @@ DataLoader:
     # TOTO: modify to the latest trainer
     dataset:
       name: "VeriWild"
-      image_root: "/work/dataset/VeRI-Wild/images"
-      cls_label_path: "/work/dataset/VeRI-Wild/train_test_split/test_3000_id_query.txt"
+      image_root: "./dataset/VeRI-Wild/images"
+      cls_label_path: "./dataset/VeRI-Wild/train_test_split/test_3000_id_query.txt"
       transform_ops:
         - DecodeImage:
             to_rgb: True
@@ -130,8 +132,8 @@ DataLoader:
     # TOTO: modify to the latest trainer
     dataset:
       name: "VeriWild"
-      image_root: "/work/dataset/VeRI-Wild/images"
-      cls_label_path: "/work/dataset/VeRI-Wild/train_test_split/test_3000_id.txt"
+      image_root: "./dataset/VeRI-Wild/images"
+      cls_label_path: "./dataset/VeRI-Wild/train_test_split/test_3000_id.txt"
       transform_ops:
         - DecodeImage:
             to_rgb: True
```
```diff
@@ -54,13 +54,8 @@ def create_operators(params):
 def build_dataloader(config, mode, device, seed=None):
-    assert mode in [
-        'Train',
-        'Eval',
-        'Test',
-        'Gallery',
-        'Query'
-    ], "Mode should be Train, Eval, Test, Gallery, Query"
+    assert mode in ['Train', 'Eval', 'Test', 'Gallery', 'Query'
+                    ], "Mode should be Train, Eval, Test, Gallery, Query"
     # build dataset
     config_dataset = config[mode]['dataset']
     config_dataset = copy.deepcopy(config_dataset)
@@ -72,7 +67,7 @@ def build_dataloader(config, mode, device, seed=None):
     dataset = eval(dataset_name)(**config_dataset)

-    logger.info("build dataset({}) success...".format(dataset))
+    logger.debug("build dataset({}) success...".format(dataset))

     # build sampler
     config_sampler = config[mode]['sampler']
@@ -85,7 +80,7 @@ def build_dataloader(config, mode, device, seed=None):
         sampler_name = config_sampler.pop("name")
         batch_sampler = eval(sampler_name)(dataset, **config_sampler)

-    logger.info("build batch_sampler({}) success...".format(batch_sampler))
+    logger.debug("build batch_sampler({}) success...".format(batch_sampler))

     # build batch operator
     def mix_collate_fn(batch):
@@ -132,5 +127,5 @@ def build_dataloader(config, mode, device, seed=None):
         batch_sampler=batch_sampler,
         collate_fn=batch_collate_fn)

-    logger.info("build data_loader({}) success...".format(data_loader))
+    logger.debug("build data_loader({}) success...".format(data_loader))

     return data_loader
```
```diff
@@ -30,6 +30,8 @@ import paddle.distributed as dist
 from ppcls.utils.check import check_gpu
 from ppcls.utils.misc import AverageMeter
 from ppcls.utils import logger
+from ppcls.utils.logger import init_logger
+from ppcls.utils.config import print_config
 from ppcls.data import build_dataloader
 from ppcls.arch import build_model
 from ppcls.loss import build_loss
@@ -49,6 +51,11 @@ class Trainer(object):
         self.mode = mode
         self.config = config
         self.output_dir = self.config['Global']['output_dir']
+
+        log_file = os.path.join(self.output_dir, self.config["Arch"]["name"],
+                                f"{mode}.log")
+        init_logger(name='root', log_file=log_file)
+        print_config(config)
         # set device
         assert self.config["Global"]["device"] in ["cpu", "gpu", "xpu"]
         self.device = paddle.set_device(self.config["Global"]["device"])
@@ -153,8 +160,8 @@ class Trainer(object):
                     time_info[key].reset()
             time_info["reader_cost"].update(time.time() - tic)
             batch_size = batch[0].shape[0]
-            batch[1] = paddle.to_tensor(batch[1].numpy().astype("int64")
-                                        .reshape([-1, 1]))
+            batch[1] = batch[1].reshape([-1, 1]).astype("int64")
             global_step += 1
             # image input
             if not self.is_rec:
@@ -206,8 +213,9 @@ class Trainer(object):
                 eta_msg = "eta: {:s}".format(
                     str(datetime.timedelta(seconds=int(eta_sec))))
                 logger.info(
-                    "[Train][Epoch {}][Iter: {}/{}]{}, {}, {}, {}, {}".
-                    format(epoch_id, iter_id,
+                    "[Train][Epoch {}/{}][Iter: {}/{}]{}, {}, {}, {}, {}".
+                    format(epoch_id, self.config["Global"][
+                        "epochs"], iter_id,
                            len(self.train_dataloader), lr_msg, metric_msg,
                            time_msg, ips_msg, eta_msg))
             tic = time.time()
@@ -216,8 +224,8 @@ class Trainer(object):
                 "{}: {:.5f}".format(key, output_info[key].avg)
                 for key in output_info
             ])
-            logger.info("[Train][Epoch {}][Avg]{}".format(epoch_id,
-                                                          metric_msg))
+            logger.info("[Train][Epoch {}/{}][Avg]{}".format(
+                epoch_id, self.config["Global"]["epochs"], metric_msg))
             output_info.clear()

             # eval model and save model if possible
@@ -327,7 +335,7 @@ class Trainer(object):
             time_info["reader_cost"].update(time.time() - tic)
             batch_size = batch[0].shape[0]
             batch[0] = paddle.to_tensor(batch[0]).astype("float32")
-            batch[1] = paddle.to_tensor(batch[1]).reshape([-1, 1])
+            batch[1] = batch[1].reshape([-1, 1]).astype("int64")
             # image input
             if self.is_rec:
                 out = self.model(batch[0], batch[1])
@@ -438,9 +446,11 @@ class Trainer(object):
                 for key in metric_tmp:
                     if key not in metric_dict:
-                        metric_dict[key] = metric_tmp[key] * block_fea.shape[0] / len(query_feas)
+                        metric_dict[key] = metric_tmp[key] * block_fea.shape[
+                            0] / len(query_feas)
                     else:
-                        metric_dict[key] += metric_tmp[key] * block_fea.shape[0] / len(query_feas)
+                        metric_dict[key] += metric_tmp[key] * block_fea.shape[
+                            0] / len(query_feas)

         metric_info_list = []
         for key in metric_dict:
@@ -467,10 +477,10 @@ class Trainer(object):
             for idx, batch in enumerate(dataloader(
             )):  # load is very time-consuming
                 batch = [paddle.to_tensor(x) for x in batch]
-                batch[1] = batch[1].reshape([-1, 1])
+                batch[1] = batch[1].reshape([-1, 1]).astype("int64")
                 if len(batch) == 3:
                     has_unique_id = True
-                    batch[2] = batch[2].reshape([-1, 1])
+                    batch[2] = batch[2].reshape([-1, 1]).astype("int64")
                 out = self.model(batch[0], batch[1])
                 batch_feas = out["features"]
```
```diff
@@ -52,5 +52,5 @@ class CombinedLoss(nn.Layer):
 def build_loss(config):
     module_class = CombinedLoss(copy.deepcopy(config))
-    logger.info("build loss {} success.".format(module_class))
+    logger.debug("build loss {} success.".format(module_class))
     return module_class
```
```diff
@@ -45,7 +45,7 @@ def build_optimizer(config, epochs, step_each_epoch, parameters):
     config = copy.deepcopy(config)
     # step1 build lr
     lr = build_lr_scheduler(config.pop('lr'), epochs, step_each_epoch)
-    logger.info("build lr ({}) success..".format(lr))
+    logger.debug("build lr ({}) success..".format(lr))
     # step2 build regularization
     if 'regularizer' in config and config['regularizer'] is not None:
         reg_config = config.pop('regularizer')
@@ -53,7 +53,7 @@ def build_optimizer(config, epochs, step_each_epoch, parameters):
         reg = getattr(paddle.regularizer, reg_name)(**reg_config)
     else:
         reg = None
-    logger.info("build regularizer ({}) success..".format(reg))
+    logger.debug("build regularizer ({}) success..".format(reg))
     # step3 build optimizer
     optim_name = config.pop('name')
     if 'clip_norm' in config:
@@ -65,5 +65,5 @@ def build_optimizer(config, epochs, step_each_epoch, parameters):
         weight_decay=reg,
         grad_clip=grad_clip,
         **config)(parameters=parameters)
-    logger.info("build optimizer ({}) success..".format(optim))
+    logger.debug("build optimizer ({}) success..".format(optim))
     return optim, lr
```
```diff
@@ -67,18 +67,14 @@ def print_dict(d, delimiter=0):
     placeholder = "-" * 60
     for k, v in sorted(d.items()):
         if isinstance(v, dict):
-            logger.info("{}{} : ".format(delimiter * " ",
-                                         logger.coloring(k, "HEADER")))
+            logger.info("{}{} : ".format(delimiter * " ", k))
             print_dict(v, delimiter + 4)
         elif isinstance(v, list) and len(v) >= 1 and isinstance(v[0], dict):
-            logger.info("{}{} : ".format(delimiter * " ",
-                                         logger.coloring(str(k), "HEADER")))
+            logger.info("{}{} : ".format(delimiter * " ", k))
             for value in v:
                 print_dict(value, delimiter + 4)
         else:
-            logger.info("{}{} : {}".format(delimiter * " ",
-                                           logger.coloring(k, "HEADER"),
-                                           logger.coloring(v, "OKGREEN")))
+            logger.info("{}{} : {}".format(delimiter * " ", k, v))
         if k.isupper():
             logger.info(placeholder)
@@ -141,7 +137,7 @@ def override(dl, ks, v):
     if len(ks) == 1:
         # assert ks[0] in dl, ('{} is not exist in {}'.format(ks[0], dl))
         if not ks[0] in dl:
-            logger.warning('A new filed ({}) detected!'.format(ks[0], dl))
+            print('A new filed ({}) detected!'.format(ks[0], dl))
         dl[ks[0]] = str2num(v)
     else:
         override(dl[ks[0]], ks[1:], v)
@@ -175,7 +171,7 @@ def override_config(config, options=None):
     return config

-def get_config(fname, overrides=None, show=True):
+def get_config(fname, overrides=None, show=False):
     """
     Read config from file
     """
```
```diff
@@ -12,70 +12,86 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import logging
 import os
+import sys
+
+import logging
 import datetime
+import paddle.distributed as dist

-logging.basicConfig(
-    level=logging.INFO,
-    format="%(asctime)s %(levelname)s: %(message)s",
-    datefmt="%Y-%m-%d %H:%M:%S")
-
-
-def time_zone(sec, fmt):
-    real_time = datetime.datetime.now()
-    return real_time.timetuple()
-
-
-logging.Formatter.converter = time_zone
-_logger = logging.getLogger(__name__)
-
-Color = {
-    'RED': '\033[31m',
-    'HEADER': '\033[35m',  # deep purple
-    'PURPLE': '\033[95m',  # purple
-    'OKBLUE': '\033[94m',
-    'OKGREEN': '\033[92m',
-    'WARNING': '\033[93m',
-    'FAIL': '\033[91m',
-    'ENDC': '\033[0m'
-}
-
-
-def coloring(message, color="OKGREEN"):
-    assert color in Color.keys()
-    if os.environ.get('PADDLECLAS_COLORING', False):
-        return Color[color] + str(message) + Color["ENDC"]
-    else:
-        return message
+_logger = None


-def anti_fleet(log):
+def init_logger(name='root', log_file=None, log_level=logging.INFO):
+    """Initialize and get a logger by name.
+    If the logger has not been initialized, this method will initialize the
+    logger by adding one or two handlers, otherwise the initialized logger will
+    be directly returned. During initialization, a StreamHandler will always be
+    added. If `log_file` is specified a FileHandler will also be added.
+    Args:
+        name (str): Logger name.
+        log_file (str | None): The log filename. If specified, a FileHandler
+            will be added to the logger.
+        log_level (int): The logger level. Note that only the process of
+            rank 0 is affected, and other processes will set the level to
+            "Error" thus be silent most of the time.
+    Returns:
+        logging.Logger: The expected logger.
+    """
+    global _logger
+    assert _logger is None, "logger should not be initialized twice or more."
+    _logger = logging.getLogger(name)
+
+    formatter = logging.Formatter(
+        '[%(asctime)s] %(name)s %(levelname)s: %(message)s',
+        datefmt="%Y/%m/%d %H:%M:%S")
+
+    stream_handler = logging.StreamHandler(stream=sys.stdout)
+    stream_handler.setFormatter(formatter)
+    _logger.addHandler(stream_handler)
+    if log_file is not None and dist.get_rank() == 0:
+        log_file_folder = os.path.split(log_file)[0]
+        os.makedirs(log_file_folder, exist_ok=True)
+        file_handler = logging.FileHandler(log_file, 'a')
+        file_handler.setFormatter(formatter)
+        _logger.addHandler(file_handler)
+    if dist.get_rank() == 0:
+        _logger.setLevel(log_level)
+    else:
+        _logger.setLevel(logging.ERROR)
+
+
+def log_at_trainer0(log):
     """
     logs will print multi-times when calling Fleet API.
     Only display single log and ignore the others.
     """

     def wrapper(fmt, *args):
-        if int(os.getenv("PADDLE_TRAINER_ID", 0)) == 0:
+        if dist.get_rank() == 0:
             log(fmt, *args)

     return wrapper


-@anti_fleet
+@log_at_trainer0
 def info(fmt, *args):
     _logger.info(fmt, *args)


-@anti_fleet
+@log_at_trainer0
+def debug(fmt, *args):
+    _logger.debug(fmt, *args)
+
+
+@log_at_trainer0
 def warning(fmt, *args):
-    _logger.warning(coloring(fmt, "RED"), *args)
+    _logger.warning(fmt, *args)


-@anti_fleet
+@log_at_trainer0
 def error(fmt, *args):
-    _logger.error(coloring(fmt, "FAIL"), *args)
+    _logger.error(fmt, *args)


 def scaler(name, value, step, writer):
@@ -108,13 +124,12 @@ def advertise():
     website = "https://github.com/PaddlePaddle/PaddleClas"
     AD_LEN = 6 + len(max([copyright, ad, website], key=len))

-    info(
-        coloring("\n{0}\n{1}\n{2}\n{3}\n{4}\n{5}\n{6}\n{7}\n".format(
-            "=" * (AD_LEN + 4),
-            "=={}==".format(copyright.center(AD_LEN)),
-            "=" * (AD_LEN + 4),
-            "=={}==".format(' ' * AD_LEN),
-            "=={}==".format(ad.center(AD_LEN)),
-            "=={}==".format(' ' * AD_LEN),
-            "=={}==".format(website.center(AD_LEN)),
-            "=" * (AD_LEN + 4), ), "RED"))
+    info("\n{0}\n{1}\n{2}\n{3}\n{4}\n{5}\n{6}\n{7}\n".format(
+        "=" * (AD_LEN + 4),
+        "=={}==".format(copyright.center(AD_LEN)),
+        "=" * (AD_LEN + 4),
+        "=={}==".format(' ' * AD_LEN),
+        "=={}==".format(ad.center(AD_LEN)),
+        "=={}==".format(' ' * AD_LEN),
+        "=={}==".format(website.center(AD_LEN)),
+        "=" * (AD_LEN + 4), ))
```
```diff
@@ -115,19 +115,6 @@ def init_model(config, net, optimizer=None):
             pretrained_model), "HEADER"))


-def _save_student_model(net, model_prefix):
-    """
-    save student model if the net is the network contains student
-    """
-    student_model_prefix = model_prefix + "_student.pdparams"
-    if hasattr(net, "_layers"):
-        net = net._layers
-    if hasattr(net, "student"):
-        paddle.save(net.student.state_dict(), student_model_prefix)
-        logger.info("Already save student model in {}".format(
-            student_model_prefix))
-
-
 def save_model(net,
                optimizer,
                metric_info,
@@ -141,11 +128,9 @@ def save_model(net,
         return
     model_path = os.path.join(model_path, model_name)
     _mkdir_if_not_exist(model_path)
-    model_prefix = os.path.join(model_path, prefix)
-    _save_student_model(net, model_prefix)
+    model_path = os.path.join(model_path, prefix)

-    paddle.save(net.state_dict(), model_prefix + ".pdparams")
-    paddle.save(optimizer.state_dict(), model_prefix + ".pdopt")
-    paddle.save(metric_info, model_prefix + ".pdstates")
+    paddle.save(net.state_dict(), model_path + ".pdparams")
+    paddle.save(optimizer.state_dict(), model_path + ".pdopt")
+    paddle.save(metric_info, model_path + ".pdstates")
     logger.info("Already save model in {}".format(model_path))
```
```diff
@@ -25,6 +25,7 @@ from ppcls.engine.trainer import Trainer

 if __name__ == "__main__":
     args = config.parse_args()
-    config = config.get_config(args.config, overrides=args.override, show=True)
+    config = config.get_config(
+        args.config, overrides=args.override, show=False)
     trainer = Trainer(config, mode="eval")
     trainer.eval()
```

```diff
@@ -25,7 +25,8 @@ from ppcls.engine.trainer import Trainer

 if __name__ == "__main__":
     args = config.parse_args()
-    config = config.get_config(args.config, overrides=args.override, show=True)
+    config = config.get_config(
+        args.config, overrides=args.override, show=False)
     trainer = Trainer(config, mode="infer")
     trainer.infer()
```

```diff
@@ -25,6 +25,7 @@ from ppcls.engine.trainer import Trainer

 if __name__ == "__main__":
     args = config.parse_args()
-    config = config.get_config(args.config, overrides=args.override, show=True)
+    config = config.get_config(
+        args.config, overrides=args.override, show=False)
     trainer = Trainer(config, mode="train")
     trainer.train()
```