From 185d1e1f929642e4bae576c3888b9500771a84fe Mon Sep 17 00:00:00 2001 From: LDOUBLEV Date: Wed, 7 Jul 2021 01:54:03 +0000 Subject: [PATCH] fix bug --- .../ch_det_lite_train_distill_v2.1.yml | 18 +++++--- ppocr/losses/distillation_loss.py | 9 +++- .../architectures/distillation_model.py | 4 +- ppocr/postprocess/__init__.py | 4 +- ppocr/postprocess/db_postprocess.py | 41 +++++++++++++++++++ ppocr/utils/save_load.py | 20 +++++++++ tools/program.py | 5 ++- 7 files changed, 89 insertions(+), 12 deletions(-) diff --git a/configs/det/ch_ppocr_v2.1/ch_det_lite_train_distill_v2.1.yml b/configs/det/ch_ppocr_v2.1/ch_det_lite_train_distill_v2.1.yml index 54ef12a1..b27eb2f9 100644 --- a/configs/det/ch_ppocr_v2.1/ch_det_lite_train_distill_v2.1.yml +++ b/configs/det/ch_ppocr_v2.1/ch_det_lite_train_distill_v2.1.yml @@ -20,7 +20,7 @@ Architecture: algorithm: Distillation Models: Student: - pretrained: + pretrained: ./pretrain_models/MobileNetV3_large_x0_5_pretrained freeze_params: false return_all_feats: false model_type: det @@ -37,7 +37,7 @@ Architecture: name: DBHead k: 50 Student2: - pretrained: + pretrained: ./pretrain_models/MobileNetV3_large_x0_5_pretrained freeze_params: false return_all_feats: false model_type: det @@ -55,6 +55,9 @@ Architecture: name: DBHead k: 50 Teacher: + pretrained: ./pretrain_models/ch_ppocr_server_v2.0_det_train/best_accuracy + freeze_params: true + return_all_feats: false model_type: det algorithm: DB Transform: @@ -73,7 +76,9 @@ Loss: loss_config_list: - DistillationDilaDBLoss: weight: 1.0 - model_name_list: ["Student", "Student2", "Teacher"] + model_name_pairs: + - ["Student", "Teacher"] + - ["Student2", "Teacher"] key: maps balance_loss: true main_loss_type: DiceLoss @@ -81,13 +86,16 @@ Loss: beta: 10 ohem_ratio: 3 - DistillationDMLLoss: + model_name_pairs: + - ["Student", "Student2"] maps_name: ["thrink_maps"] weight: 1.0 act: "softmax" model_name_pairs: ["Student", "Student2"] key: maps - DistillationDBLoss: - model_name_list: ["Student", "Teacher"] + weight: 1.0 + model_name_list: ["Student", "Student2"] key: maps name: DBLoss balance_loss: true @@ -110,7 +118,7 @@ Optimizer: factor: 0 PostProcess: - name: DistillationCTDBPostProcessCLabelDecode + name: DistillationDBPostProcess model_name: ["Student", "Student2"] key: head_out thresh: 0.3 diff --git a/ppocr/losses/distillation_loss.py b/ppocr/losses/distillation_loss.py index 421bbaba..d4e4a8a2 100644 --- a/ppocr/losses/distillation_loss.py +++ b/ppocr/losses/distillation_loss.py @@ -14,6 +14,8 @@ import paddle import paddle.nn as nn +import numpy as np +import cv2 from .rec_ctc_loss import CTCLoss from .basic_loss import DMLLoss @@ -22,6 +24,7 @@ from .det_db_loss import DBLoss from .det_basic_loss import BalanceLoss, MaskL1Loss, DiceLoss + def _sum_loss(loss_dict): if "loss" in loss_dict.keys(): return loss_dict @@ -50,7 +53,7 @@ class DistillationDMLLoss(DMLLoss): self.key = key self.model_name_pairs = model_name_pairs self.name = name - self.maps_name = self.maps_name + self.maps_name = maps_name def _check_maps_name(self, maps_name): if maps_name is None: @@ -172,6 +175,7 @@ class DistillationDBLoss(DBLoss): class DistillationDilaDBLoss(DBLoss): def __init__(self, model_name_pairs=[], + key=None, balance_loss=True, main_loss_type='DiceLoss', alpha=5, @@ -182,6 +186,7 @@ class DistillationDilaDBLoss(DBLoss): super().__init__() self.model_name_pairs = model_name_pairs self.name = name + self.key = key def forward(self, predicts, batch): loss_dict = dict() @@ -219,7 +224,7 @@ class DistillationDilaDBLoss(DBLoss): loss_dict[k] = bce_loss + loss_binary_maps loss_dict = _sum_loss(loss_dict) - return loss + return loss_dict class DistillationDistanceLoss(DistanceLoss): diff --git a/ppocr/modeling/architectures/distillation_model.py b/ppocr/modeling/architectures/distillation_model.py index 2e512331..1e95fe57 100644 --- a/ppocr/modeling/architectures/distillation_model.py +++ b/ppocr/modeling/architectures/distillation_model.py @@ -21,7 +21,7 @@ from ppocr.modeling.backbones import build_backbone from ppocr.modeling.necks import build_neck from ppocr.modeling.heads import build_head from .base_model import BaseModel -from ppocr.utils.save_load import init_model +from ppocr.utils.save_load import init_model, load_pretrained_params __all__ = ['DistillationModel'] @@ -46,7 +46,7 @@ class DistillationModel(nn.Layer): pretrained = model_config.pop("pretrained") model = BaseModel(model_config) if pretrained is not None: - init_model(model, path=pretrained) + load_pretrained_params(model, pretrained) if freeze_params: for param in model.parameters(): param.trainable = False diff --git a/ppocr/postprocess/__init__.py b/ppocr/postprocess/__init__.py index 2f5bdc3b..f2ac65c4 100644 --- a/ppocr/postprocess/__init__.py +++ b/ppocr/postprocess/__init__.py @@ -21,7 +21,7 @@ import copy __all__ = ['build_post_process'] -from .db_postprocess import DBPostProcess +from .db_postprocess import DBPostProcess, DistillationDBPostProcess from .east_postprocess import EASTPostProcess from .sast_postprocess import SASTPostProcess from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, DistillationCTCLabelDecode, \ @@ -34,7 +34,7 @@ def build_post_process(config, global_config=None): support_dict = [ 'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode', 'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess', - 'DistillationCTCLabelDecode', 'TableLabelDecode' + 'DistillationCTCLabelDecode', 'TableLabelDecode', 'DistillationDBPostProcess' ] config = copy.deepcopy(config) diff --git a/ppocr/postprocess/db_postprocess.py b/ppocr/postprocess/db_postprocess.py index 769ddbe2..4561b464 100755 --- a/ppocr/postprocess/db_postprocess.py +++ b/ppocr/postprocess/db_postprocess.py @@ -187,3 +187,44 @@ class DBPostProcess(object): boxes_batch.append({'points': boxes}) return boxes_batch + + +class DistillationDBPostProcess(DBPostProcess): + def __init__(self, + model_name=["student"], + key=None, + thresh=0.3, + box_thresh=0.7, + max_candidates=1000, + unclip_ratio=2.0, + use_dilation=False, + score_mode="fast", + **kwargs): + super(DistillationDBPostProcess, self).__init__(thresh, + box_thresh, + max_candidates, + unclip_ratio, + use_dilation, + score_mode) + if not isinstance(model_name, list): + model_name = [model_name] + self.model_name = model_name + + self.key = key + + def forward(self, predicts, shape_list): + results = {} + for name in self.model_name: + pred = predicts[name] + if self.key is not None: + pred = pred[self.key] + results[name] = super().__call__(pred, shape_list=label) + + return results + + + + + + + diff --git a/ppocr/utils/save_load.py b/ppocr/utils/save_load.py index 76420abb..732f9e20 100644 --- a/ppocr/utils/save_load.py +++ b/ppocr/utils/save_load.py @@ -116,6 +116,26 @@ def load_dygraph_params(config, model, logger, optimizer): logger.info(f"loaded pretrained_model successful from {pm}") return {} +def load_pretrained_params(model, path): + if path is None: + return False + if not os.path.exists(path) and not os.path.exists(path + ".pdparams"): + print(f"The pretrained_model {path} does not exists!") + return False + + path = path if path.endswith('.pdparams') else path + '.pdparams' + params = paddle.load(path) + state_dict = model.state_dict() + new_state_dict = {} + for k1, k2 in zip(state_dict.keys(), params.keys()): + if list(state_dict[k1].shape) == list(params[k2].shape): + new_state_dict[k1] = params[k2] + else: + print( + f"The shape of model params {k1} {state_dict[k1].shape} not matched with loaded params {k2} {params[k2].shape} !" + ) + model.set_state_dict(new_state_dict) + return True def save_model(model, optimizer, diff --git a/tools/program.py b/tools/program.py index 2d99f296..595fe4cb 100755 --- a/tools/program.py +++ b/tools/program.py @@ -186,7 +186,10 @@ def train(config, model.train() use_srn = config['Architecture']['algorithm'] == "SRN" - model_type = config['Architecture']['model_type'] + try: + model_type = config['Architecture']['model_type'] + except: + model_type = None if 'start_epoch' in best_model_dict: start_epoch = best_model_dict['start_epoch'] -- GitLab