From e5d3a2d88012daf3ed89a46e79555244de7f2b43 Mon Sep 17 00:00:00 2001 From: littletomatodonkey Date: Thu, 3 Jun 2021 05:30:43 +0000 Subject: [PATCH] fix distillation arch and model init --- ...c_chinese_lite_train_distillation_v2.1.yml | 32 ++++---- ppocr/losses/basic_loss.py | 26 ++++-- ppocr/losses/distillation_loss.py | 41 +++++----- .../architectures/distillation_model.py | 21 ++--- ppocr/utils/save_load.py | 39 +-------- ppstructure/layout/README.md | 80 +++++++++++++++++++ 6 files changed, 141 insertions(+), 98 deletions(-) diff --git a/configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_distillation_v2.1.yml b/configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_distillation_v2.1.yml index e2b97a7b..f3e75341 100644 --- a/configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_distillation_v2.1.yml +++ b/configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_distillation_v2.1.yml @@ -4,11 +4,9 @@ Global: epoch_num: 800 log_smooth_window: 20 print_batch_step: 10 - save_model_dir: ./output/rec_D081 + save_model_dir: ./output/rec_chinese_lite_distillation_v2.1 save_epoch_step: 3 - eval_batch_step: - - 0 - - 2000 + eval_batch_step: [0, 2000] cal_metric_during_train: true pretrained_model: null checkpoints: null @@ -37,12 +35,10 @@ Optimizer: Architecture: name: DistillationModel algorithm: Distillation - freeze_params: - - false - - false - pretrained: null Models: Student: + pretrained: null + freeze_params: false model_type: rec algorithm: CRNN Transform: @@ -59,6 +55,8 @@ Architecture: name: CTCHead fc_decay: 0.00001 Teacher: + pretrained: null + freeze_params: false model_type: rec algorithm: CRNN Transform: @@ -85,16 +83,20 @@ Loss: key: null - DistillationDMLLoss: weight: 1.0 - model_name_list1: ["Student"] - model_name_list2: ["Teacher"] + act: "softmax" + model_name_pairs: + - ["Student", "Teacher"] + key: null PostProcess: name: DistillationCTCLabelDecode model_name: "Student" key_out: null + Metric: name: RecMetric main_indicator: acc + Train: dataset: name: SimpleDataSet @@ -108,10 +110,7 @@ Train: - RecAug: null - CTCLabelEncode: null - RecResizeImg: - image_shape: - - 3 - - 32 - - 320 + image_shape: [3, 32, 320] - KeepKeys: keep_keys: - image @@ -135,10 +134,7 @@ Eval: channel_first: false - CTCLabelEncode: null - RecResizeImg: - image_shape: - - 3 - - 32 - - 320 + image_shape: [3, 32, 320] - KeepKeys: keep_keys: - image diff --git a/ppocr/losses/basic_loss.py b/ppocr/losses/basic_loss.py index 3321827b..153bf690 100644 --- a/ppocr/losses/basic_loss.py +++ b/ppocr/losses/basic_loss.py @@ -62,19 +62,29 @@ class DMLLoss(nn.Layer): DMLLoss """ - def __init__(self, name="loss_dml"): + def __init__(self, act=None, name="loss_dml"): super().__init__() + if act is not None: + assert act in ["softmax", "sigmoid"] self.name = name + if act == "softmax": + self.act = nn.Softmax(axis=-1) + elif act == "sigmoid": + self.act = nn.Sigmoid() + else: + self.act = None def forward(self, out1, out2): loss_dict = {} - soft_out1 = F.softmax(out1, axis=-1) - log_soft_out1 = paddle.log(soft_out1) - soft_out2 = F.softmax(out2, axis=-1) - log_soft_out2 = paddle.log(soft_out2) + if self.act is not None: + out1 = self.act(out1) + out2 = self.act(out2) + + log_out1 = paddle.log(out1) + log_out2 = paddle.log(out2) loss = (F.kl_div( - log_soft_out1, soft_out2, reduction='batchmean') + F.kl_div( - log_soft_out2, soft_out1, reduction='batchmean')) / 2.0 + log_out1, out2, reduction='batchmean') + F.kl_div( + log_out2, log_out1, reduction='batchmean')) / 2.0 loss_dict[self.name] = loss return loss_dict @@ -90,7 +100,7 @@ class DistanceLoss(nn.Layer): assert mode in ["l1", "l2", "smooth_l1"] if mode == "l1": self.loss_func = nn.L1Loss(**kargs) - elif mode == "l1": + elif mode == "l2": self.loss_func = nn.MSELoss(**kargs) elif mode == "smooth_l1": self.loss_func = nn.SmoothL1Loss(**kargs) diff --git a/ppocr/losses/distillation_loss.py b/ppocr/losses/distillation_loss.py index cc6d7d5a..40a8da77 100644 --- a/ppocr/losses/distillation_loss.py +++ b/ppocr/losses/distillation_loss.py @@ -23,35 +23,28 @@ class DistillationDMLLoss(DMLLoss): """ """ - def __init__(self, - model_name_list1=[], - model_name_list2=[], - key=None, + def __init__(self, model_name_pairs=[], act=None, key=None, name="loss_dml"): - super().__init__(name=name) - if not isinstance(model_name_list1, (list, )): - model_name_list1 = [model_name_list1] - if not isinstance(model_name_list2, (list, )): - model_name_list2 = [model_name_list2] - - assert len(model_name_list1) == len(model_name_list2) - self.model_name_list1 = model_name_list1 - self.model_name_list2 = model_name_list2 + super().__init__(act=act, name=name) + assert isinstance(model_name_pairs, list) self.key = key + self.model_name_pairs = model_name_pairs def forward(self, predicts, batch): loss_dict = dict() - for idx in range(len(self.model_name_list1)): - out1 = predicts[self.model_name_list1[idx]] - out2 = predicts[self.model_name_list2[idx]] + for idx, pair in enumerate(self.model_name_pairs): + out1 = predicts[pair[0]] + out2 = predicts[pair[1]] if self.key is not None: out1 = out1[self.key] out2 = out2[self.key] loss = super().forward(out1, out2) if isinstance(loss, dict): - assert len(loss) == 1 - loss = list(loss.values())[0] - loss_dict["{}_{}".format(self.name, idx)] = loss + for key in loss: + loss_dict["{}_{}_{}".format(self.name, key, idx)] = loss[ + key] + else: + loss_dict["{}_{}".format(self.name, idx)] = loss return loss_dict @@ -64,13 +57,15 @@ class DistillationCTCLoss(CTCLoss): def forward(self, predicts, batch): loss_dict = dict() - for model_name in self.model_name_list: + for idx, model_name in enumerate(self.model_name_list): out = predicts[model_name] if self.key is not None: out = out[self.key] loss = super().forward(out, batch) if isinstance(loss, dict): - assert len(loss) == 1 - loss = list(loss.values())[0] - loss_dict["{}_{}".format(self.name, model_name)] = loss + for key in loss: + loss_dict["{}_{}_{}".format(self.name, model_name, + idx)] = loss[key] + else: + loss_dict["{}_{}".format(self.name, model_name)] = loss return loss_dict diff --git a/ppocr/modeling/architectures/distillation_model.py b/ppocr/modeling/architectures/distillation_model.py index cc3f2405..bbb9dceb 100644 --- a/ppocr/modeling/architectures/distillation_model.py +++ b/ppocr/modeling/architectures/distillation_model.py @@ -34,25 +34,20 @@ class DistillationModel(nn.Layer): config (dict): the super parameters for module. """ super().__init__() - - freeze_params = config["freeze_params"] - pretrained = config["pretrained"] - if not isinstance(freeze_params, list): - freeze_params = [freeze_params] - assert len(config["Models"]) == len(freeze_params) - - if not isinstance(pretrained, list): - pretrained = [pretrained] * len(config["Models"]) - assert len(config["Models"]) == len(pretrained) - self.model_dict = dict() index = 0 for key in config["Models"]: model_config = config["Models"][key] + freeze_params = False + pretrained = None + if "freeze_params" in model_config: + freeze_params = model_config.pop("freeze_params") + if "pretrained" in model_config: + pretrained = model_config.pop("pretrained") model = BaseModel(model_config) - if pretrained[index] is not None: + if pretrained is not None: load_dygraph_pretrain(model, path=pretrained[index]) - if freeze_params[index]: + if freeze_params: for param in model.parameters(): param.trainable = False self.model_dict[key] = self.add_sublayer(key, model) diff --git a/ppocr/utils/save_load.py b/ppocr/utils/save_load.py index c730d1ab..951132c3 100644 --- a/ppocr/utils/save_load.py +++ b/ppocr/utils/save_load.py @@ -42,38 +42,10 @@ def _mkdir_if_not_exist(path, logger): raise OSError('Failed to mkdir {}'.format(path)) -def load_dygraph_pretrain(model, - logger=None, - path=None, - load_static_weights=False): +def load_dygraph_pretrain(model, logger=None, path=None): if not (os.path.isdir(path) or os.path.exists(path + '.pdparams')): raise ValueError("Model pretrain path {} does not " "exists.".format(path)) - if load_static_weights: - pre_state_dict = paddle.static.load_program_state(path) - param_state_dict = {} - model_dict = model.state_dict() - for key in model_dict.keys(): - weight_name = model_dict[key].name - weight_name = weight_name.replace('binarize', '').replace( - 'thresh', '') # for DB - if weight_name in pre_state_dict.keys(): - # logger.info('Load weight: {}, shape: {}'.format( - # weight_name, pre_state_dict[weight_name].shape)) - if 'encoder_rnn' in key: - # delete axis which is 1 - pre_state_dict[weight_name] = pre_state_dict[ - weight_name].squeeze() - # change axis - if len(pre_state_dict[weight_name].shape) > 1: - pre_state_dict[weight_name] = pre_state_dict[ - weight_name].transpose((1, 0)) - param_state_dict[key] = pre_state_dict[weight_name] - else: - param_state_dict[key] = model_dict[key] - model.set_state_dict(param_state_dict) - return - param_state_dict = paddle.load(path + '.pdparams') model.set_state_dict(param_state_dict) return @@ -108,15 +80,10 @@ def init_model(config, model, logger, optimizer=None, lr_scheduler=None): logger.info("resume from {}".format(checkpoints)) elif pretrained_model: - load_static_weights = global_config.get('load_static_weights', False) if not isinstance(pretrained_model, list): pretrained_model = [pretrained_model] - if not isinstance(load_static_weights, list): - load_static_weights = [load_static_weights] * len(pretrained_model) - for idx, pretrained in enumerate(pretrained_model): - load_static = load_static_weights[idx] - load_dygraph_pretrain( - model, logger, path=pretrained, load_static_weights=load_static) + for pretrained in pretrained_model: + load_dygraph_pretrain(model, logger, path=pretrained) logger.info("load pretrained model from {}".format( pretrained_model)) else: diff --git a/ppstructure/layout/README.md b/ppstructure/layout/README.md index e69de29b..e0a5a32b 100644 --- a/ppstructure/layout/README.md +++ b/ppstructure/layout/README.md @@ -0,0 +1,80 @@ +# Python端预测部署 + +Python预测可以使用`tools/infer.py`,此种方式依赖PaddleDetection源码;也可以使用本篇教程预测方式,先将模型导出,使用一个独立的文件进行预测。 + + +本篇教程使用AnalysisPredictor对[导出模型](https://github.com/PaddlePaddle/PaddleDetection/blob/develop/deploy/EXPORT_MODEL.md)进行高性能预测。 + +在PaddlePaddle中预测引擎和训练引擎底层有着不同的优化方法, 预测引擎使用了AnalysisPredictor,专门针对推理进行了优化,是基于[C++预测库](https://www.paddlepaddle.org.cn/documentation/docs/zh/advanced_guide/inference_deployment/inference/native_infer.html)的Python接口,该引擎可以对模型进行多项图优化,减少不必要的内存拷贝。如果用户在部署已训练模型的过程中对性能有较高的要求,我们提供了独立于PaddleDetection的预测脚本,方便用户直接集成部署。 + + +主要包含两个步骤: + +- 导出预测模型 +- 基于Python的预测 + +## 1. 导出预测模型 + +PaddleDetection在训练过程包括网络的前向和优化器相关参数,而在部署过程中,我们只需要前向参数,具体参考:[导出模型](https://github.com/PaddlePaddle/PaddleDetection/blob/develop/deploy/EXPORT_MODEL.md) + +导出后目录下,包括`infer_cfg.yml`, `model.pdiparams`, `model.pdiparams.info`, `model.pdmodel`四个文件。 + +## 2. 基于python的预测 + +### 2.1 安装依赖 + - `PaddlePaddle`的安装: + 请点击[官方安装文档](https://paddlepaddle.org.cn/install/quick) 选择适合的方式,版本为2.0rc1以上即可 + - 切换到`PaddleDetection`代码库根目录,执行`pip install -r requirements.txt`安装其它依赖 + +### 2.2 执行预测程序 +在终端输入以下命令进行预测: + +```bash +python deploy/python/infer.py --model_dir=/path/to/models --image_file=/path/to/image +--use_gpu=(False/True) +``` + +参数说明如下: + +| 参数 | 是否必须|含义 | +|-------|-------|----------| +| --model_dir | Yes|上述导出的模型路径 | +| --image_file | Option |需要预测的图片 | +| --video_file | Option |需要预测的视频 | +| --camera_id | Option | 用来预测的摄像头ID,默认为-1(表示不使用摄像头预测,可设置为:0 - (摄像头数目-1) ),预测过程中在可视化界面按`q`退出输出预测结果到:output/output.mp4| +| --use_gpu |No|是否GPU,默认为False| +| --run_mode |No|使用GPU时,默认为fluid, 可选(fluid/trt_fp32/trt_fp16/trt_int8)| +| --threshold |No|预测得分的阈值,默认为0.5| +| --output_dir |No|可视化结果保存的根目录,默认为output/| +| --run_benchmark |No|是否运行benchmark,同时需指定--image_file| + +说明: + +- run_mode:fluid代表使用AnalysisPredictor,精度float32来推理,其他参数指用AnalysisPredictor,TensorRT不同精度来推理。 +- PaddlePaddle默认的GPU安装包(<=1.7),不支持基于TensorRT进行预测,如果想基于TensorRT加速预测,需要自行编译,详细可参考[预测库编译教程](https://www.paddlepaddle.org.cn/documentation/docs/zh/advanced_usage/deploy/inference/paddle_tensorrt_infer.html)。 + +## 3. 部署性能对比测试 +对比AnalysisPredictor相对Executor的推理速度 + +### 3.1 测试环境: + +- CUDA 9.0 +- CUDNN 7.5 +- PaddlePaddle 1.71 +- GPU: Tesla P40 + +### 3.2 测试方式: + +- Batch Size=1 +- 去掉前100轮warmup时间,测试100轮的平均时间,单位ms/image,只计算模型运行时间,不包括数据的处理和拷贝。 + + +### 3.3 测试结果 + +|模型 | AnalysisPredictor | Executor | 输入| +|---|----|---|---| +| YOLOv3-MobileNetv1 | 15.20 | 19.54 | 608*608 +| faster_rcnn_r50_fpn_1x | 50.05 | 69.58 |800*1088 +| faster_rcnn_r50_1x | 326.11 | 347.22 | 800*1067 +| mask_rcnn_r50_fpn_1x | 67.49 | 91.02 | 800*1088 +| mask_rcnn_r50_1x | 326.11 | 350.94 | 800*1067 -- GitLab