Commit 3a99409f authored by Bin Lu

Merge branch 'PaddlePaddle:develop_reg' into develop_reg

```diff
@@ -43,19 +43,23 @@
 ## Documentation Tutorials
 - [Quick Installation](./docs/zh_CN/tutorials/install.md)
-- [Quick Start for Image Recognition]
+- Quick Start for Image Recognition (若愚)
+- Quick Start for Image Classification (崔程, adapted from the 30-minute beginner tutorial)
 - Algorithm Introduction
-  - [Image Recognition System]
+  - Image Recognition System (胜禹)
   - [Model Zoo Overview and Pretrained Models](./docs/zh_CN/models/models_intro.md)
-  - [Image Classification]
+  - Image Classification
+    - ImageNet Classification Task (崔程, adapted from the 30-minute advanced tutorial)
+    - [Multi-label Classification Task]()
-  - [Feature Learning]
-    - Product Recognition
-    - Vehicle Recognition
-    - Logo Recognition
-    - Cartoon Character Recognition
-  - [Vector Search]
+  - Feature Learning
+    - [Product Recognition]()
+    - [Vehicle Recognition]()
+    - [Logo Recognition]()
+    - [Cartoon Character Recognition]()
+  - [Vector Search]()
 - Model Training/Evaluation
-  - Image Classification Task
-  - Feature Learning Task
+  - Image Classification Task (崔程, reorganized from the existing training docs)
+  - Feature Learning Task (陆彬)
 - Model Inference
   - [Inference Based on the Training Engine](./docs/zh_CN/tutorials/getting_started.md)
   - [Inference Based on the Python Prediction Engine](./docs/zh_CN/tutorials/getting_started.md)
```
# Vehicle Fine-Grained Classification

Fine-grained classification assigns images within a single base category to finer sub-categories, e.g., telling apart different species of birds, kinds of flowers, or types of minerals. Vehicle fine-grained classification, as the name suggests, classifies vehicles into their sub-categories.

Compared with vehicle ReID, its training differs in:
- the dataset
- the loss configuration

For everything else, see [Vehicle ReID](./vehicle_reid.md).

Full configuration file: [ResNet50.yaml](../../../ppcls/configs/Vehicle/ResNet50.yaml)

## Dataset

This demo uses [CompCars](http://mmlab.ie.cuhk.edu.hk/datasets/comp_cars/index.html) as the training dataset.

<img src="../../images/recognotion/vehicle/CompCars.png" style="zoom:50%;" />

The images come mainly from the web and from surveillance footage. The web data covers 163 car makers and 1,716 car models, with **136,726** whole-car images and **27,618** part-level images, and additionally provides bounding boxes, viewpoints, and five attributes (maximum speed, displacement, number of doors, number of seats, and car type). The surveillance data contains **50,000** front-view images.

Note that labels must be generated from this dataset according to your own needs. In this demo, vehicles of the same model produced in different years are treated as one class, which yields 431 classes in total.
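To make the relabeling concrete, here is a hypothetical sketch; the annotation format and names such as `raw_labels` are invented for illustration and are not the dataset's actual format:

```python
# Hypothetical illustration of year-agnostic relabeling: annotations that
# carry (model, production_year) collapse onto one class id per model.
raw_labels = [("audi_a4", 2009), ("audi_a4", 2012), ("bmw_x5", 2011)]

class_ids = {}
for model, _year in raw_labels:
    # the year is ignored, so every production year of a model shares a class
    class_ids.setdefault(model, len(class_ids))

print(class_ids)  # {'audi_a4': 0, 'bmw_x5': 1}
```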
## Loss Configuration

Unlike vehicle ReID, this classification task uses [Triplet Loss](../../../ppcls/loss/triplet.py) + [ArcLoss](../../../ppcls/arch/gears/arcmargin.py), with a weight ratio of 1:1.
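For intuition, below is a minimal sketch of a margin-based triplet loss and the 1:1 weighted sum; it is not the ppcls implementation, and the margin value is illustrative only:

```python
import paddle
import paddle.nn.functional as F

def margin_triplet_loss(anchor, positive, negative, margin=0.5):
    # plain margin-based triplet loss over L2 distances (illustrative margin)
    d_ap = paddle.sqrt(paddle.sum((anchor - positive) ** 2, axis=1))
    d_an = paddle.sqrt(paddle.sum((anchor - negative) ** 2, axis=1))
    return F.relu(d_ap - d_an + margin).mean()

# total objective: 1:1 weighted sum with ArcLoss (cross entropy over
# ArcMargin logits), matching the weight ratio stated above
# total = 1.0 * margin_triplet_loss(a, p, n) + 1.0 * arc_loss
```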
# Vehicle ReID

ReID (re-identification) uses an algorithm to find a queried target in an image gallery, making it a sub-problem of image retrieval. Vehicle ReID, given one image of a vehicle, retrieves other images of the same vehicle taken by the same camera at different times or by different cameras. Extracting robust features is crucial in this process, so this document focuses on training the feature extraction network for vehicle ReID and covers:

- the dataset and preprocessing
- the backbone configuration
- the loss configuration

All hyperparameters and the full configuration: [ResNet50_ReID.yaml](../../../ppcls/configs/Vehicle/ResNet50_ReID.yaml)

## Dataset and Preprocessing

### VERI-Wild Dataset

<img src="../../images/recognotion/vehicle/cars.JPG" style="zoom:50%;" />

This dataset was captured over one month (30*24 hours) in an unconstrained setting by a large CCTV surveillance system of 174 cameras distributed across an area of more than 200 km². The raw collection contains 12 million vehicle images; after cleaning and annotation, 416,314 images of 40,671 distinct vehicles were kept. See the [paper](https://github.com/PKU-IMRE/VERI-Wild) for details.
## Data Preprocessing

Since the vehicle images in the raw dataset were already cropped out by a detector, the crop operation used for `ImageNet` training is unnecessary. The augmentation pipeline, in order, is:

- `Resize` the image to 224
- random horizontal flip
- [AugMix](https://arxiv.org/abs/1912.02781v1)
- Normalize: scale pixel values to 0~1
- [RandomErasing](https://arxiv.org/pdf/1708.04896v2.pdf)

This is configured as follows (see the `transform_ops` section):
```yaml
DataLoader:
  Train:
    dataset:
      # name of the Dataset class to use
      name: "VeriWild"
      # dataset-specific arguments
      image_root: "/work/dataset/VeRI-Wild/images/"
      cls_label_path: "/work/dataset/VeRI-Wild/train_test_split/train_list_start0.txt"
      # augmentation ops: ResizeImage, RandFlipImage, etc.
      transform_ops:
        - ResizeImage:
            size: 224
        - RandFlipImage:
            flip_code: 1
        - AugMix:
            prob: 0.5
        - NormalizeImage:
            scale: 0.00392157
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
        - RandomErasing:
            EPSILON: 0.5
            sl: 0.02
            sh: 0.4
            r1: 0.3
            mean: [0., 0., 0.]
    sampler:
      name: DistributedRandomIdentitySampler
      batch_size: 128
      num_instances: 2
      drop_last: False
      shuffle: True
    loader:
      num_workers: 6
      use_shared_memory: False
```
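Note the sampler: with `batch_size: 128` and `num_instances: 2`, every batch holds 64 identities with 2 images each, guaranteeing each anchor a positive pair for the metric losses described later. A hypothetical sketch of that P×K grouping logic, not the actual DistributedRandomIdentitySampler:

```python
import random
from collections import defaultdict

def pk_batches(labels, batch_size=128, num_instances=2):
    # group sample indices by identity, then build batches of
    # (batch_size // num_instances) identities x num_instances images each
    index_by_id = defaultdict(list)
    for idx, pid in enumerate(labels):
        index_by_id[pid].append(idx)
    ids = [pid for pid, idxs in index_by_id.items()
           if len(idxs) >= num_instances]
    random.shuffle(ids)
    ids_per_batch = batch_size // num_instances
    for i in range(0, len(ids) - ids_per_batch + 1, ids_per_batch):
        batch = []
        for pid in ids[i:i + ids_per_batch]:
            batch.extend(random.sample(index_by_id[pid], num_instances))
        yield batch
```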
## Backbone Configuration

`ResNet50` is used as the backbone, with two modifications:

- The last stage (stage 4) performs no downsampling, so its feature map is the same size as stage 3's, i.e., 14x14.
- An embedding layer, a 1x1 convolution with a feature dimension of 512, is appended at the end.

Code: [ResNet50_last_stage_stride1](../../../ppcls/arch/backbone/variant_models/resnet_variant.py)
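Conceptually, the variant only forces the first block of stage 4 to stride 1. A rough sketch of that idea on paddle's stock ResNet50 follows; the attribute paths are assumptions for illustration, and the actual patch lives in `resnet_variant.py`:

```python
from paddle.vision.models import resnet50

# Conceptual sketch: force stage 4 of a stock ResNet50 to stride 1 so its
# feature map stays 14x14. Attribute paths here are illustrative only.
net = resnet50(pretrained=False)
block = net.layer4[0]                      # first block of the last stage
block.conv2._stride = [1, 1]               # assumed private attr: stride 2 -> 1
if block.downsample is not None:
    block.downsample[0]._stride = [1, 1]   # keep the residual shortcut aligned
```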
The backbone is configured as follows:
```yaml
Arch:
  # train with RecModel; plain ImageNet models and RecModel are currently supported
  name: "RecModel"
  # options for exporting the inference model
  infer_output_key: "features"
  infer_add_softmax: False
  # backbone to use
  Backbone:
    name: "ResNet50_last_stage_stride1"
    pretrained: True
  # take this layer as the backbone's feature output; name is the layer's full_name
  BackboneStopLayer:
    name: "adaptive_avg_pool2d_0"
  # extra layers on top of the backbone; this model adds a 1x1 convolution (embedding)
  Neck:
    name: "VehicleNeck"
    in_channels: 2048
    out_channels: 512
  # add ArcMargin, the implementation of ArcLoss
  Head:
    name: "ArcMargin"
    embedding_size: 512
    class_num: 431
    margin: 0.15
    scale: 32
```
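For intuition, ArcMargin normalizes features and class weights so their dot product is cos(θ), adds the margin m to the target class angle, and scales the result by s before cross entropy. A hedged numpy sketch of that computation, not the ppcls ArcMargin layer (`label` is assumed to be an integer class array):

```python
import numpy as np

def arc_margin_logits(feat, weight, label, margin=0.15, scale=32.0):
    # normalize embeddings (rows) and class weights (columns) so that
    # feat @ weight gives cos(theta) against each class center
    feat = feat / np.linalg.norm(feat, axis=1, keepdims=True)
    weight = weight / np.linalg.norm(weight, axis=0, keepdims=True)
    cos = feat @ weight                           # (batch, class_num)
    rows = np.arange(len(label))
    theta = np.arccos(np.clip(cos[rows, label], -1.0, 1.0))
    logits = cos.copy()
    logits[rows, label] = np.cos(theta + margin)  # margin on target class only
    return scale * logits                         # feed into cross entropy
```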
## Loss Configuration

Vehicle ReID uses [SupConLoss](https://arxiv.org/abs/2004.11362) + [ArcLoss](https://arxiv.org/abs/1801.07698), with a weight ratio of 1:1.

Code: [SupConLoss](../../../ppcls/loss/supconloss.py), [ArcLoss](../../../ppcls/arch/gears/arcmargin.py)

This is configured as follows:
```yaml
Loss:
  Train:
    - CELoss:
        weight: 1.0
    - SupConLoss:
        weight: 1.0
        # SupConLoss-specific parameter
        views: 2
  Eval:
    - CELoss:
        weight: 1.0
```
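SupConLoss treats, for each anchor, all same-label samples in the batch as positives; `views: 2` means two augmented views of each image enter the batch. A hedged numpy sketch of the loss (Khosla et al., 2020), not ppcls/loss/supconloss.py:

```python
import numpy as np

def supcon_loss(feats, labels, temperature=0.07):
    # feats: (n, d) embeddings; with views=2, n = 2 * batch_size
    feats = feats / np.linalg.norm(feats, axis=1, keepdims=True)
    sim = feats @ feats.T / temperature
    np.fill_diagonal(sim, -np.inf)                # exclude self-contrast
    sim = sim - sim.max(axis=1, keepdims=True)    # numerical stability
    log_prob = sim - np.log(np.exp(sim).sum(axis=1, keepdims=True))
    pos = labels[:, None] == labels[None, :]
    np.fill_diagonal(pos, False)
    # mean log-probability over each anchor's positives, then negate
    per_anchor = np.where(pos, log_prob, 0.0).sum(axis=1)
    return float(-(per_anchor / np.maximum(pos.sum(axis=1), 1)).mean())
```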
## Other Settings

### Optimizer

```yaml
Optimizer:
  # optimizer to use
  name: Momentum
  # optimizer-specific parameters
  momentum: 0.9
  lr:
    # learning-rate schedule to use
    name: MultiStepDecay
    # schedule-specific parameters
    learning_rate: 0.01
    milestones: [30, 60, 70, 80, 90, 100, 120, 140]
    gamma: 0.5
    verbose: False
    last_epoch: -1
  regularizer:
    name: 'L2'
    coeff: 0.0005
```
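Concretely, the learning rate starts at 0.01 and halves at every milestone epoch, ending at 0.01 × 0.5^8 ≈ 3.9e-5 after epoch 140. A quick sanity check of the decay rule:

```python
# Reproduce the MultiStepDecay rule above: multiply by gamma once for
# every milestone the current epoch has passed.
milestones = [30, 60, 70, 80, 90, 100, 120, 140]
base_lr, gamma = 0.01, 0.5

def lr_at(epoch):
    return base_lr * gamma ** sum(epoch >= m for m in milestones)

print(lr_at(0), lr_at(30), lr_at(65), lr_at(159))
# 0.01 0.005 0.0025 3.90625e-05
```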
### Eval Metrics

```yaml
Metric:
  Eval:
    # evaluate with Recallk and mAP
    - Recallk:
        topk: [1, 5]
    - mAP: {}
```
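Recall@K scores a query as correct if any of its top-K gallery matches shares its identity; mAP averages precision over the whole ranking. A small hedged sketch of Recall@K, not the ppcls metric implementation:

```python
import numpy as np

def recall_at_k(sim, query_ids, gallery_ids, topk=(1, 5)):
    # sim: (num_query, num_gallery) similarity matrix
    query_ids = np.asarray(query_ids)
    gallery_ids = np.asarray(gallery_ids)
    rank = np.argsort(-sim, axis=1)             # best match first
    out = {}
    for k in topk:
        hits = [(gallery_ids[rank[i, :k]] == query_ids[i]).any()
                for i in range(len(query_ids))]
        out[f"recall@{k}"] = float(np.mean(hits))
    return out
```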
### Other Hyperparameters

```yaml
Global:
  # if null, train from scratch; if set to a saved training checkpoint, resume from it
  checkpoints: null
  pretrained_model: null
  output_dir: "./output/"
  device: "gpu"
  class_num: 30671
  # checkpoint granularity: save once per epoch
  save_interval: 1
  eval_during_train: True
  eval_interval: 1
  # number of training epochs
  epochs: 160
  # logging frequency
  print_batch_step: 10
  # whether to use the visualdl library
  use_visualdl: False
  # used for static mode and model export
  image_shape: [3, 224, 224]
  save_inference_dir: "./inference"
  # evaluate in retrieval mode
  eval_mode: "retrieval"
```
```diff
@@ -18,6 +18,8 @@ Global:
 # model architecture
 Arch:
   name: "RecModel"
+  infer_output_key: "features"
+  infer_add_softmax: False
   Backbone:
     name: "ResNet50_last_stage_stride1"
     pretrained: True
@@ -66,10 +68,10 @@ DataLoader:
   Train:
     dataset:
       name: "CompCars"
-      image_root: "/work/dataset/CompCars/image/"
-      label_root: "/work/dataset/CompCars/label/"
+      image_root: "./dataset/CompCars/image/"
+      label_root: "./dataset/CompCars/label/"
       bbox_crop: True
-      cls_label_path: "/work/dataset/CompCars/train_test_split/classification/train_label.txt"
+      cls_label_path: "./dataset/CompCars/train_test_split/classification/train_label.txt"
       transform_ops:
         - DecodeImage:
             to_rgb: True
@@ -106,7 +108,7 @@ DataLoader:
     # TODO: modify to the latest trainer
     dataset:
       name: "CompCars"
-      image_root: ".dataset/CompCars/image/"
+      image_root: "./dataset/CompCars/image/"
       label_root: "./dataset/CompCars/label/"
       cls_label_path: "./dataset/CompCars/train_test_split/classification/test_label.txt"
       bbox_crop: True
```
```diff
@@ -19,6 +19,8 @@ Global:
 # model architecture
 Arch:
   name: "RecModel"
+  infer_output_key: "features"
+  infer_add_softmax: False
   Backbone:
     name: "ResNet50_last_stage_stride1"
     pretrained: True
@@ -31,7 +33,7 @@ Arch:
   Head:
     name: "ArcMargin"
     embedding_size: 512
-    class_num: 431
+    class_num: 30671
     margin: 0.15
     scale: 32
@@ -66,8 +68,8 @@ DataLoader:
   Train:
     dataset:
       name: "VeriWild"
-      image_root: "/work/dataset/VeRI-Wild/images/"
-      cls_label_path: "/work/dataset/VeRI-Wild/train_test_split/train_list_start0.txt"
+      image_root: "./dataset/VeRI-Wild/images/"
+      cls_label_path: "./dataset/VeRI-Wild/train_test_split/train_list_start0.txt"
       transform_ops:
         - DecodeImage:
             to_rgb: True
@@ -104,8 +106,8 @@ DataLoader:
     # TODO: modify to the latest trainer
     dataset:
       name: "VeriWild"
-      image_root: "/work/dataset/VeRI-Wild/images"
-      cls_label_path: "/work/dataset/VeRI-Wild/train_test_split/test_3000_id_query.txt"
+      image_root: "./dataset/VeRI-Wild/images"
+      cls_label_path: "./dataset/VeRI-Wild/train_test_split/test_3000_id_query.txt"
       transform_ops:
         - DecodeImage:
             to_rgb: True
@@ -130,8 +132,8 @@ DataLoader:
     # TODO: modify to the latest trainer
     dataset:
       name: "VeriWild"
-      image_root: "/work/dataset/VeRI-Wild/images"
-      cls_label_path: "/work/dataset/VeRI-Wild/train_test_split/test_3000_id.txt"
+      image_root: "./dataset/VeRI-Wild/images"
+      cls_label_path: "./dataset/VeRI-Wild/train_test_split/test_3000_id.txt"
       transform_ops:
         - DecodeImage:
             to_rgb: True
```
```diff
@@ -54,13 +54,8 @@ def create_operators(params):

 def build_dataloader(config, mode, device, seed=None):
-    assert mode in [
-        'Train',
-        'Eval',
-        'Test',
-        'Gallery',
-        'Query'
-    ], "Mode should be Train, Eval, Test, Gallery, Query"
+    assert mode in ['Train', 'Eval', 'Test', 'Gallery', 'Query'
+                    ], "Mode should be Train, Eval, Test, Gallery, Query"
     # build dataset
     config_dataset = config[mode]['dataset']
     config_dataset = copy.deepcopy(config_dataset)
@@ -72,7 +67,7 @@ def build_dataloader(config, mode, device, seed=None):
     dataset = eval(dataset_name)(**config_dataset)

-    logger.info("build dataset({}) success...".format(dataset))
+    logger.debug("build dataset({}) success...".format(dataset))

     # build sampler
     config_sampler = config[mode]['sampler']
@@ -85,7 +80,7 @@ def build_dataloader(config, mode, device, seed=None):
     sampler_name = config_sampler.pop("name")
     batch_sampler = eval(sampler_name)(dataset, **config_sampler)

-    logger.info("build batch_sampler({}) success...".format(batch_sampler))
+    logger.debug("build batch_sampler({}) success...".format(batch_sampler))

     # build batch operator
     def mix_collate_fn(batch):
@@ -132,5 +127,5 @@ def build_dataloader(config, mode, device, seed=None):
         batch_sampler=batch_sampler,
         collate_fn=batch_collate_fn)

-    logger.info("build data_loader({}) success...".format(data_loader))
+    logger.debug("build data_loader({}) success...".format(data_loader))

     return data_loader
```
```diff
@@ -30,6 +30,8 @@ import paddle.distributed as dist
 from ppcls.utils.check import check_gpu
 from ppcls.utils.misc import AverageMeter
 from ppcls.utils import logger
+from ppcls.utils.logger import init_logger
+from ppcls.utils.config import print_config
 from ppcls.data import build_dataloader
 from ppcls.arch import build_model
 from ppcls.loss import build_loss
@@ -49,6 +51,11 @@ class Trainer(object):
         self.mode = mode
         self.config = config
         self.output_dir = self.config['Global']['output_dir']
+
+        log_file = os.path.join(self.output_dir, self.config["Arch"]["name"],
+                                f"{mode}.log")
+        init_logger(name='root', log_file=log_file)
+        print_config(config)
         # set device
         assert self.config["Global"]["device"] in ["cpu", "gpu", "xpu"]
         self.device = paddle.set_device(self.config["Global"]["device"])
@@ -153,8 +160,8 @@ class Trainer(object):
                     time_info[key].reset()
             time_info["reader_cost"].update(time.time() - tic)
             batch_size = batch[0].shape[0]
-            batch[1] = paddle.to_tensor(batch[1].numpy().astype("int64")
-                                        .reshape([-1, 1]))
+            batch[1] = batch[1].reshape([-1, 1]).astype("int64")
             global_step += 1
             # image input
             if not self.is_rec:
@@ -206,8 +213,9 @@ class Trainer(object):
                 eta_msg = "eta: {:s}".format(
                     str(datetime.timedelta(seconds=int(eta_sec))))
                 logger.info(
-                    "[Train][Epoch {}][Iter: {}/{}]{}, {}, {}, {}, {}".
-                    format(epoch_id, iter_id,
+                    "[Train][Epoch {}/{}][Iter: {}/{}]{}, {}, {}, {}, {}".
+                    format(epoch_id, self.config["Global"][
+                        "epochs"], iter_id,
                            len(self.train_dataloader), lr_msg, metric_msg,
                            time_msg, ips_msg, eta_msg))
             tic = time.time()
@@ -216,8 +224,8 @@ class Trainer(object):
             "{}: {:.5f}".format(key, output_info[key].avg)
             for key in output_info
         ])
-        logger.info("[Train][Epoch {}][Avg]{}".format(epoch_id,
-                                                      metric_msg))
+        logger.info("[Train][Epoch {}/{}][Avg]{}".format(
+            epoch_id, self.config["Global"]["epochs"], metric_msg))
         output_info.clear()

         # eval model and save model if possible
@@ -327,7 +335,7 @@ class Trainer(object):
             time_info["reader_cost"].update(time.time() - tic)
             batch_size = batch[0].shape[0]
             batch[0] = paddle.to_tensor(batch[0]).astype("float32")
-            batch[1] = paddle.to_tensor(batch[1]).reshape([-1, 1])
+            batch[1] = batch[1].reshape([-1, 1]).astype("int64")
             # image input
             if self.is_rec:
                 out = self.model(batch[0], batch[1])
@@ -438,9 +446,11 @@ class Trainer(object):
             for key in metric_tmp:
                 if key not in metric_dict:
-                    metric_dict[key] = metric_tmp[key] * block_fea.shape[0] / len(query_feas)
+                    metric_dict[key] = metric_tmp[key] * block_fea.shape[
+                        0] / len(query_feas)
                 else:
-                    metric_dict[key] += metric_tmp[key] * block_fea.shape[0] / len(query_feas)
+                    metric_dict[key] += metric_tmp[key] * block_fea.shape[
+                        0] / len(query_feas)

         metric_info_list = []
         for key in metric_dict:
@@ -467,10 +477,10 @@ class Trainer(object):
             for idx, batch in enumerate(dataloader(
             )):  # load is very time-consuming
                 batch = [paddle.to_tensor(x) for x in batch]
-                batch[1] = batch[1].reshape([-1, 1])
+                batch[1] = batch[1].reshape([-1, 1]).astype("int64")
                 if len(batch) == 3:
                     has_unique_id = True
-                    batch[2] = batch[2].reshape([-1, 1])
+                    batch[2] = batch[2].reshape([-1, 1]).astype("int64")
                 out = self.model(batch[0], batch[1])
                 batch_feas = out["features"]
```
```diff
@@ -52,5 +52,5 @@ class CombinedLoss(nn.Layer):

 def build_loss(config):
     module_class = CombinedLoss(copy.deepcopy(config))
-    logger.info("build loss {} success.".format(module_class))
+    logger.debug("build loss {} success.".format(module_class))
     return module_class
```
```diff
@@ -45,7 +45,7 @@ def build_optimizer(config, epochs, step_each_epoch, parameters):
     config = copy.deepcopy(config)
     # step1 build lr
     lr = build_lr_scheduler(config.pop('lr'), epochs, step_each_epoch)
-    logger.info("build lr ({}) success..".format(lr))
+    logger.debug("build lr ({}) success..".format(lr))
     # step2 build regularization
     if 'regularizer' in config and config['regularizer'] is not None:
         reg_config = config.pop('regularizer')
@@ -53,7 +53,7 @@ def build_optimizer(config, epochs, step_each_epoch, parameters):
         reg = getattr(paddle.regularizer, reg_name)(**reg_config)
     else:
         reg = None
-    logger.info("build regularizer ({}) success..".format(reg))
+    logger.debug("build regularizer ({}) success..".format(reg))
     # step3 build optimizer
     optim_name = config.pop('name')
     if 'clip_norm' in config:
@@ -65,5 +65,5 @@ def build_optimizer(config, epochs, step_each_epoch, parameters):
         weight_decay=reg,
         grad_clip=grad_clip,
         **config)(parameters=parameters)
-    logger.info("build optimizer ({}) success..".format(optim))
+    logger.debug("build optimizer ({}) success..".format(optim))
     return optim, lr
```
```diff
@@ -67,18 +67,14 @@ def print_dict(d, delimiter=0):
     placeholder = "-" * 60
     for k, v in sorted(d.items()):
         if isinstance(v, dict):
-            logger.info("{}{} : ".format(delimiter * " ",
-                                         logger.coloring(k, "HEADER")))
+            logger.info("{}{} : ".format(delimiter * " ", k))
             print_dict(v, delimiter + 4)
         elif isinstance(v, list) and len(v) >= 1 and isinstance(v[0], dict):
-            logger.info("{}{} : ".format(delimiter * " ",
-                                         logger.coloring(str(k), "HEADER")))
+            logger.info("{}{} : ".format(delimiter * " ", k))
             for value in v:
                 print_dict(value, delimiter + 4)
         else:
-            logger.info("{}{} : {}".format(delimiter * " ",
-                                           logger.coloring(k, "HEADER"),
-                                           logger.coloring(v, "OKGREEN")))
+            logger.info("{}{} : {}".format(delimiter * " ", k, v))

         if k.isupper():
             logger.info(placeholder)
@@ -141,7 +137,7 @@ def override(dl, ks, v):
         if len(ks) == 1:
             # assert ks[0] in dl, ('{} is not exist in {}'.format(ks[0], dl))
             if not ks[0] in dl:
-                logger.warning('A new field ({}) detected!'.format(ks[0], dl))
+                print('A new field ({}) detected!'.format(ks[0], dl))
             dl[ks[0]] = str2num(v)
         else:
             override(dl[ks[0]], ks[1:], v)
@@ -175,7 +171,7 @@ def override_config(config, options=None):
     return config


-def get_config(fname, overrides=None, show=True):
+def get_config(fname, overrides=None, show=False):
     """
     Read config from file
     """
```
```diff
@@ -12,70 +12,86 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
 import os
 import datetime
-
-logging.basicConfig(
-    level=logging.INFO,
-    format="%(asctime)s %(levelname)s: %(message)s",
-    datefmt="%Y-%m-%d %H:%M:%S")
-
-
-def time_zone(sec, fmt):
-    real_time = datetime.datetime.now()
-    return real_time.timetuple()
-
-
-logging.Formatter.converter = time_zone
-_logger = logging.getLogger(__name__)
-
-Color = {
-    'RED': '\033[31m',
-    'HEADER': '\033[35m',  # deep purple
-    'PURPLE': '\033[95m',  # purple
-    'OKBLUE': '\033[94m',
-    'OKGREEN': '\033[92m',
-    'WARNING': '\033[93m',
-    'FAIL': '\033[91m',
-    'ENDC': '\033[0m'
-}
-
-
-def coloring(message, color="OKGREEN"):
-    assert color in Color.keys()
-    if os.environ.get('PADDLECLAS_COLORING', False):
-        return Color[color] + str(message) + Color["ENDC"]
-    else:
-        return message
+import sys
+
+import paddle.distributed as dist
+
+_logger = None
+
+
+def init_logger(name='root', log_file=None, log_level=logging.INFO):
+    """Initialize and get a logger by name.
+    If the logger has not been initialized, this method will initialize the
+    logger by adding one or two handlers, otherwise the initialized logger will
+    be directly returned. During initialization, a StreamHandler will always be
+    added. If `log_file` is specified a FileHandler will also be added.
+    Args:
+        name (str): Logger name.
+        log_file (str | None): The log filename. If specified, a FileHandler
+            will be added to the logger.
+        log_level (int): The logger level. Note that only the process of
+            rank 0 is affected, and other processes will set the level to
+            "Error" thus be silent most of the time.
+    Returns:
+        logging.Logger: The expected logger.
+    """
+    global _logger
+    assert _logger is None, "logger should not be initialized twice or more."
+    _logger = logging.getLogger(name)
+
+    formatter = logging.Formatter(
+        '[%(asctime)s] %(name)s %(levelname)s: %(message)s',
+        datefmt="%Y/%m/%d %H:%M:%S")
+
+    stream_handler = logging.StreamHandler(stream=sys.stdout)
+    stream_handler.setFormatter(formatter)
+    _logger.addHandler(stream_handler)
+    if log_file is not None and dist.get_rank() == 0:
+        log_file_folder = os.path.split(log_file)[0]
+        os.makedirs(log_file_folder, exist_ok=True)
+        file_handler = logging.FileHandler(log_file, 'a')
+        file_handler.setFormatter(formatter)
+        _logger.addHandler(file_handler)
+    if dist.get_rank() == 0:
+        _logger.setLevel(log_level)
+    else:
+        _logger.setLevel(logging.ERROR)


-def anti_fleet(log):
+def log_at_trainer0(log):
     """
     logs will print multi-times when calling Fleet API.
     Only display single log and ignore the others.
     """

     def wrapper(fmt, *args):
-        if int(os.getenv("PADDLE_TRAINER_ID", 0)) == 0:
+        if dist.get_rank() == 0:
             log(fmt, *args)

     return wrapper


-@anti_fleet
+@log_at_trainer0
 def info(fmt, *args):
     _logger.info(fmt, *args)


-@anti_fleet
+@log_at_trainer0
 def debug(fmt, *args):
     _logger.debug(fmt, *args)


+@log_at_trainer0
 def warning(fmt, *args):
-    _logger.warning(coloring(fmt, "RED"), *args)
+    _logger.warning(fmt, *args)


-@anti_fleet
+@log_at_trainer0
 def error(fmt, *args):
-    _logger.error(coloring(fmt, "FAIL"), *args)
+    _logger.error(fmt, *args)


 def scaler(name, value, step, writer):
@@ -108,13 +124,12 @@ def advertise():
     website = "https://github.com/PaddlePaddle/PaddleClas"
     AD_LEN = 6 + len(max([copyright, ad, website], key=len))

-    info(
-        coloring("\n{0}\n{1}\n{2}\n{3}\n{4}\n{5}\n{6}\n{7}\n".format(
-            "=" * (AD_LEN + 4),
-            "=={}==".format(copyright.center(AD_LEN)),
-            "=" * (AD_LEN + 4),
-            "=={}==".format(' ' * AD_LEN),
-            "=={}==".format(ad.center(AD_LEN)),
-            "=={}==".format(' ' * AD_LEN),
-            "=={}==".format(website.center(AD_LEN)),
-            "=" * (AD_LEN + 4), ), "RED"))
+    info("\n{0}\n{1}\n{2}\n{3}\n{4}\n{5}\n{6}\n{7}\n".format(
+        "=" * (AD_LEN + 4),
+        "=={}==".format(copyright.center(AD_LEN)),
+        "=" * (AD_LEN + 4),
+        "=={}==".format(' ' * AD_LEN),
+        "=={}==".format(ad.center(AD_LEN)),
+        "=={}==".format(' ' * AD_LEN),
+        "=={}==".format(website.center(AD_LEN)),
+        "=" * (AD_LEN + 4), ))
```
```diff
@@ -115,19 +115,6 @@ def init_model(config, net, optimizer=None):
             pretrained_model), "HEADER"))


-def _save_student_model(net, model_prefix):
-    """
-    save student model if the net is the network contains student
-    """
-    student_model_prefix = model_prefix + "_student.pdparams"
-    if hasattr(net, "_layers"):
-        net = net._layers
-    if hasattr(net, "student"):
-        paddle.save(net.student.state_dict(), student_model_prefix)
-        logger.info("Already save student model in {}".format(
-            student_model_prefix))
-
-
 def save_model(net,
                optimizer,
                metric_info,
@@ -141,11 +128,9 @@ def save_model(net,
         return
     model_path = os.path.join(model_path, model_name)
     _mkdir_if_not_exist(model_path)
-    model_prefix = os.path.join(model_path, prefix)
-    _save_student_model(net, model_prefix)
+    model_path = os.path.join(model_path, prefix)

-    paddle.save(net.state_dict(), model_prefix + ".pdparams")
-    paddle.save(optimizer.state_dict(), model_prefix + ".pdopt")
-    paddle.save(metric_info, model_prefix + ".pdstates")
+    paddle.save(net.state_dict(), model_path + ".pdparams")
+    paddle.save(optimizer.state_dict(), model_path + ".pdopt")
+    paddle.save(metric_info, model_path + ".pdstates")
     logger.info("Already save model in {}".format(model_path))
```
```diff
@@ -25,6 +25,7 @@ from ppcls.engine.trainer import Trainer

 if __name__ == "__main__":
     args = config.parse_args()
-    config = config.get_config(args.config, overrides=args.override, show=True)
+    config = config.get_config(
+        args.config, overrides=args.override, show=False)
     trainer = Trainer(config, mode="eval")
     trainer.eval()
```
```diff
@@ -25,7 +25,8 @@ from ppcls.engine.trainer import Trainer

 if __name__ == "__main__":
     args = config.parse_args()
-    config = config.get_config(args.config, overrides=args.override, show=True)
+    config = config.get_config(
+        args.config, overrides=args.override, show=False)
     trainer = Trainer(config, mode="infer")
     trainer.infer()
```
```diff
@@ -25,6 +25,7 @@ from ppcls.engine.trainer import Trainer

 if __name__ == "__main__":
     args = config.parse_args()
-    config = config.get_config(args.config, overrides=args.override, show=True)
+    config = config.get_config(
+        args.config, overrides=args.override, show=False)
     trainer = Trainer(config, mode="train")
     trainer.train()
```