diff --git a/configs/rec/ch_ppocr_v2.1/rec_chinese_lite_train_distillation_v2.1.yml b/configs/rec/ch_ppocr_v2.1/rec_chinese_lite_train_distillation_v2.1.yml new file mode 100644 index 0000000000000000000000000000000000000000..6b60ae0860959405fd512913d022c02e2e2dae05 --- /dev/null +++ b/configs/rec/ch_ppocr_v2.1/rec_chinese_lite_train_distillation_v2.1.yml @@ -0,0 +1,158 @@ +Global: + debug: false + use_gpu: true + epoch_num: 800 + log_smooth_window: 20 + print_batch_step: 10 + save_model_dir: ./output/rec_chinese_lite_distillation_v2.1 + save_epoch_step: 3 + eval_batch_step: [0, 2000] + cal_metric_during_train: true + pretrained_model: + checkpoints: + save_inference_dir: + use_visualdl: false + infer_img: doc/imgs_words/ch/word_1.jpg + character_dict_path: ppocr/utils/ppocr_keys_v1.txt + character_type: ch + max_text_length: 25 + infer_mode: false + use_space_char: false + distributed: true + save_res_path: ./output/rec/predicts_chinese_lite_distillation_v2.1.txt + + +Optimizer: + name: Adam + beta1: 0.9 + beta2: 0.999 + lr: + name: Cosine + learning_rate: 0.0005 + warmup_epoch: 5 + regularizer: + name: L2 + factor: 1.0e-05 +Architecture: + name: DistillationModel + algorithm: Distillation + Models: + Student: + pretrained: + freeze_params: false + return_all_feats: true + model_type: rec + algorithm: CRNN + Transform: + Backbone: + name: MobileNetV3 + scale: 0.5 + model_name: small + small_stride: [1, 2, 2, 2] + Neck: + name: SequenceEncoder + encoder_type: rnn + hidden_size: 48 + Head: + name: CTCHead + fc_decay: 0.00001 + Teacher: + pretrained: + freeze_params: false + return_all_feats: true + model_type: rec + algorithm: CRNN + Transform: + Backbone: + name: MobileNetV3 + scale: 0.5 + model_name: small + small_stride: [1, 2, 2, 2] + Neck: + name: SequenceEncoder + encoder_type: rnn + hidden_size: 48 + Head: + name: CTCHead + fc_decay: 0.00001 + + +Loss: + name: CombinedLoss + loss_config_list: + - DistillationCTCLoss: + weight: 1.0 + model_name_list: ["Student", "Teacher"] + key: head_out + - DistillationDMLLoss: + weight: 1.0 + act: "softmax" + model_name_pairs: + - ["Student", "Teacher"] + key: head_out + - DistillationDistanceLoss: + weight: 1.0 + mode: "l2" + model_name_pairs: + - ["Student", "Teacher"] + key: backbone_out + +PostProcess: + name: DistillationCTCLabelDecode + model_name: ["Student", "Teacher"] + key: head_out + +Metric: + name: DistillationMetric + base_metric_name: RecMetric + main_indicator: acc + key: "Student" + +Train: + dataset: + name: SimpleDataSet + data_dir: ./train_data/ + label_file_list: + - ./train_data/train_list.txt + transforms: + - DecodeImage: + img_mode: BGR + channel_first: false + - RecAug: + - CTCLabelEncode: + - RecResizeImg: + image_shape: [3, 32, 320] + - KeepKeys: + keep_keys: + - image + - label + - length + loader: + shuffle: true + batch_size_per_card: 128 + drop_last: true + num_sections: 1 + num_workers: 8 +Eval: + dataset: + name: SimpleDataSet + data_dir: ./train_data + label_file_list: + - ./train_data/val_list.txt + transforms: + - DecodeImage: + img_mode: BGR + channel_first: false + - CTCLabelEncode: + - RecResizeImg: + image_shape: [3, 32, 320] + - KeepKeys: + keep_keys: + - image + - label + - length + loader: + shuffle: false + drop_last: false + batch_size_per_card: 128 + num_workers: 8 diff --git a/ppocr/losses/__init__.py b/ppocr/losses/__init__.py index 223ae6b1da996478ac607e29dd37173ca51d9903..bf10d2982dcdd36021a7385ab8828398b51af3d3 100755 --- a/ppocr/losses/__init__.py +++ b/ppocr/losses/__init__.py @@ -13,28 +13,37 
@@ # limitations under the License. import copy +import paddle +import paddle.nn as nn +# det loss +from .det_db_loss import DBLoss +from .det_east_loss import EASTLoss +from .det_sast_loss import SASTLoss -def build_loss(config): - # det loss - from .det_db_loss import DBLoss - from .det_east_loss import EASTLoss - from .det_sast_loss import SASTLoss +# rec loss +from .rec_ctc_loss import CTCLoss +from .rec_att_loss import AttentionLoss +from .rec_srn_loss import SRNLoss + +# cls loss +from .cls_loss import ClsLoss + +# e2e loss +from .e2e_pg_loss import PGLoss - # rec loss - from .rec_ctc_loss import CTCLoss - from .rec_att_loss import AttentionLoss - from .rec_srn_loss import SRNLoss +# basic loss function +from .basic_loss import DistanceLoss - # cls loss - from .cls_loss import ClsLoss +# combined loss function +from .combined_loss import CombinedLoss - # e2e loss - from .e2e_pg_loss import PGLoss + +def build_loss(config): support_dict = [ 'DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'AttentionLoss', - 'SRNLoss', 'PGLoss'] - + 'SRNLoss', 'PGLoss', 'CombinedLoss' + ] config = copy.deepcopy(config) module_name = config.pop('name') assert module_name in support_dict, Exception('loss only support {}'.format( diff --git a/ppocr/losses/basic_loss.py b/ppocr/losses/basic_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..fa3ceda1b747aad3c4b275611b1257bf6950f013 --- /dev/null +++ b/ppocr/losses/basic_loss.py @@ -0,0 +1,103 @@ +#copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. 
+ +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + +from paddle.nn import L1Loss +from paddle.nn import MSELoss as L2Loss +from paddle.nn import SmoothL1Loss + + +class CELoss(nn.Layer): + def __init__(self, epsilon=None): + super().__init__() + if epsilon is not None and (epsilon <= 0 or epsilon >= 1): + epsilon = None + self.epsilon = epsilon + + def _labelsmoothing(self, target, class_num): + if target.shape[-1] != class_num: + one_hot_target = F.one_hot(target, class_num) + else: + one_hot_target = target + soft_target = F.label_smooth(one_hot_target, epsilon=self.epsilon) + soft_target = paddle.reshape(soft_target, shape=[-1, class_num]) + return soft_target + + def forward(self, x, label): + loss_dict = {} + if self.epsilon is not None: + class_num = x.shape[-1] + label = self._labelsmoothing(label, class_num) + x = -F.log_softmax(x, axis=-1) + loss = paddle.sum(x * label, axis=-1) + else: + if label.shape[-1] == x.shape[-1]: + label = F.softmax(label, axis=-1) + soft_label = True + else: + soft_label = False + loss = F.cross_entropy(x, label=label, soft_label=soft_label) + return loss + + +class DMLLoss(nn.Layer): + """ + DMLLoss + """ + + def __init__(self, act=None): + super().__init__() + if act is not None: + assert act in ["softmax", "sigmoid"] + if act == "softmax": + self.act = nn.Softmax(axis=-1) + elif act == "sigmoid": + self.act = nn.Sigmoid() + else: + self.act = None + + def forward(self, out1, out2): + if self.act is not None: + out1 = self.act(out1) + out2 = self.act(out2) + + log_out1 = paddle.log(out1) + log_out2 = paddle.log(out2) + loss = (F.kl_div( + log_out1, out2, reduction='batchmean') + F.kl_div( + log_out2, out1, reduction='batchmean')) / 2.0 + return loss + + +class DistanceLoss(nn.Layer): + """ + DistanceLoss: + mode: loss mode + """ + + def __init__(self, mode="l2", **kargs): + super().__init__() + assert mode in ["l1", "l2", "smooth_l1"] + if mode == "l1": + self.loss_func = nn.L1Loss(**kargs) + elif mode == "l2": + self.loss_func = nn.MSELoss(**kargs) + elif mode == "smooth_l1": + self.loss_func = nn.SmoothL1Loss(**kargs) + + def forward(self, x, y): + return self.loss_func(x, y) diff --git a/ppocr/losses/cls_loss.py b/ppocr/losses/cls_loss.py index 41c7db02446549064ffa8896c2c6861d0d9803c5..ecca5d2e1739631716123d4a793f5ece09d7f9ab 100755 --- a/ppocr/losses/cls_loss.py +++ b/ppocr/losses/cls_loss.py @@ -24,7 +24,7 @@ class ClsLoss(nn.Layer): super(ClsLoss, self).__init__() self.loss_func = nn.CrossEntropyLoss(reduction='mean') - def __call__(self, predicts, batch): + def forward(self, predicts, batch): label = batch[1] loss = self.loss_func(input=predicts, label=label) return {'loss': loss} diff --git a/ppocr/losses/combined_loss.py b/ppocr/losses/combined_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..54da70174cba7bf5ca35e8fbf5aa137a437ae29c --- /dev/null +++ b/ppocr/losses/combined_loss.py @@ -0,0 +1,58 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.nn as nn + +from .distillation_loss import DistillationCTCLoss +from .distillation_loss import DistillationDMLLoss +from .distillation_loss import DistillationDistanceLoss + + +class CombinedLoss(nn.Layer): + """ + CombinedLoss: + a combination of loss functions + """ + + def __init__(self, loss_config_list=None): + super().__init__() + self.loss_func = [] + self.loss_weight = [] + assert isinstance(loss_config_list, list), ( + 'operator config should be a list') + for config in loss_config_list: + assert isinstance(config, + dict) and len(config) == 1, "yaml format error" + name = list(config)[0] + param = config[name] + assert "weight" in param, "weight must be in param, but param only contains {}".format( + param.keys()) + self.loss_weight.append(param.pop("weight")) + self.loss_func.append(eval(name)(**param)) + + def forward(self, input, batch, **kargs): + loss_dict = {} + for idx, loss_func in enumerate(self.loss_func): + loss = loss_func(input, batch, **kargs) + if isinstance(loss, paddle.Tensor): + loss = {loss_func.__class__.__name__: loss} + weight = self.loss_weight[idx] + loss = { + "{}_{}".format(key, idx): loss[key] * weight + for key in loss + } + loss_dict.update(loss) + loss_dict["loss"] = paddle.add_n(list(loss_dict.values())) + return loss_dict diff --git a/ppocr/losses/distillation_loss.py b/ppocr/losses/distillation_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..1e8aa0d8602e3ddd49913e6a572914859377ca42 --- /dev/null +++ b/ppocr/losses/distillation_loss.py @@ -0,0 +1,110 @@ +#copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License.
+ +import paddle +import paddle.nn as nn + +from .rec_ctc_loss import CTCLoss +from .basic_loss import DMLLoss +from .basic_loss import DistanceLoss + + +class DistillationDMLLoss(DMLLoss): + """ + DMLLoss applied to each pair of sub-model outputs. + """ + + def __init__(self, model_name_pairs=[], act=None, key=None, + name="loss_dml"): + super().__init__(act=act) + assert isinstance(model_name_pairs, list) + self.key = key + self.model_name_pairs = model_name_pairs + self.name = name + + def forward(self, predicts, batch): + loss_dict = dict() + for idx, pair in enumerate(self.model_name_pairs): + out1 = predicts[pair[0]] + out2 = predicts[pair[1]] + if self.key is not None: + out1 = out1[self.key] + out2 = out2[self.key] + loss = super().forward(out1, out2) + if isinstance(loss, dict): + for key in loss: + loss_dict["{}_{}_{}_{}".format(key, pair[0], pair[1], + idx)] = loss[key] + else: + loss_dict["{}_{}".format(self.name, idx)] = loss + return loss_dict + + +class DistillationCTCLoss(CTCLoss): + def __init__(self, model_name_list=[], key=None, name="loss_ctc"): + super().__init__() + self.model_name_list = model_name_list + self.key = key + self.name = name + + def forward(self, predicts, batch): + loss_dict = dict() + for idx, model_name in enumerate(self.model_name_list): + out = predicts[model_name] + if self.key is not None: + out = out[self.key] + loss = super().forward(out, batch) + if isinstance(loss, dict): + for key in loss: + loss_dict["{}_{}_{}".format(self.name, model_name, + idx)] = loss[key] + else: + loss_dict["{}_{}".format(self.name, model_name)] = loss + return loss_dict + + +class DistillationDistanceLoss(DistanceLoss): + """ + DistanceLoss applied to each pair of sub-model outputs. + """ + + def __init__(self, + mode="l2", + model_name_pairs=[], + key=None, + name="loss_distance", + **kargs): + super().__init__(mode=mode, **kargs) + assert isinstance(model_name_pairs, list) + self.key = key + self.model_name_pairs = model_name_pairs + self.name = name + "_" + mode + + def forward(self, predicts, batch): + loss_dict = dict() + for idx, pair in enumerate(self.model_name_pairs): + out1 = predicts[pair[0]] + out2 = predicts[pair[1]] + if self.key is not None: + out1 = out1[self.key] + out2 = out2[self.key] + loss = super().forward(out1, out2) + if isinstance(loss, dict): + for key in loss: + loss_dict["{}_{}_{}".format(self.name, key, idx)] = loss[ + key] + else: + loss_dict["{}_{}_{}_{}".format(self.name, pair[0], pair[1], + idx)] = loss + return loss_dict diff --git a/ppocr/losses/rec_ctc_loss.py b/ppocr/losses/rec_ctc_loss.py index 425de58710a61fde2034a88707a3032e02007d13..6c0b56ff84db4ff23786fb781d461bf9fbc86ef2 100755 --- a/ppocr/losses/rec_ctc_loss.py +++ b/ppocr/losses/rec_ctc_loss.py @@ -25,7 +25,7 @@ class CTCLoss(nn.Layer): super(CTCLoss, self).__init__() self.loss_func = nn.CTCLoss(blank=0, reduction='none') - def __call__(self, predicts, batch): + def forward(self, predicts, batch): predicts = predicts.transpose((1, 0, 2)) N, B, _ = predicts.shape preds_lengths = paddle.to_tensor([N] * B, dtype='int64') diff --git a/ppocr/metrics/__init__.py b/ppocr/metrics/__init__.py index f913010dbd994633d3df1cf996abb994d246a11a..9e9060fa999bd3175c31dfc0797cd293d4e7afec 100644 --- a/ppocr/metrics/__init__.py +++ b/ppocr/metrics/__init__.py @@ -19,20 +19,23 @@ from __future__ import unicode_literals import copy -__all__ = ['build_metric'] +__all__ = ["build_metric"] +from .det_metric import DetMetric +from .rec_metric import RecMetric +from .cls_metric import ClsMetric +from .e2e_metric import E2EMetric +from .distillation_metric import DistillationMetric -def build_metric(config): -
from .det_metric import DetMetric - from .rec_metric import RecMetric - from .cls_metric import ClsMetric - from .e2e_metric import E2EMetric - support_dict = ['DetMetric', 'RecMetric', 'ClsMetric', 'E2EMetric'] +def build_metric(config): + support_dict = [ + "DetMetric", "RecMetric", "ClsMetric", "E2EMetric", "DistillationMetric" + ] config = copy.deepcopy(config) - module_name = config.pop('name') + module_name = config.pop("name") assert module_name in support_dict, Exception( - 'metric only support {}'.format(support_dict)) + "metric only support {}".format(support_dict)) module_class = eval(module_name)(**config) return module_class diff --git a/ppocr/metrics/distillation_metric.py b/ppocr/metrics/distillation_metric.py new file mode 100644 index 0000000000000000000000000000000000000000..a7d3d095a7d384bf8cdc69b97f8109c359ac2b5b --- /dev/null +++ b/ppocr/metrics/distillation_metric.py @@ -0,0 +1,75 @@ +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib +import copy + +from .rec_metric import RecMetric +from .det_metric import DetMetric +from .e2e_metric import E2EMetric +from .cls_metric import ClsMetric + + +class DistillationMetric(object): + def __init__(self, + key=None, + base_metric_name="RecMetric", + main_indicator='acc', + **kwargs): + self.main_indicator = main_indicator + self.key = key + self.base_metric_name = base_metric_name + self.kwargs = kwargs + self.metrics = None + + def _init_metrics(self, preds): + self.metrics = dict() + mod = importlib.import_module(__name__) + for key in preds: + self.metrics[key] = getattr(mod, self.base_metric_name)( + main_indicator=self.main_indicator, **self.kwargs) + self.metrics[key].reset() + + def __call__(self, preds, *args, **kwargs): + assert isinstance(preds, dict) + if self.metrics is None: + self._init_metrics(preds) + output = dict() + for key in preds: + metric = self.metrics[key](preds[key], *args, **kwargs) + for sub_key in metric: + output["{}_{}".format(key, sub_key)] = metric[sub_key] + return output + + def get_metric(self): + """ + return metrics { + 'acc': 0, + 'norm_edit_dis': 0, + } + """ + output = dict() + for key in self.metrics: + metric = self.metrics[key].get_metric() + # main indicator + if key == self.key: + output.update(metric) + else: + for sub_key in metric: + output["{}_{}".format(key, sub_key)] = metric[sub_key] + return output + + def reset(self): + for key in self.metrics: + self.metrics[key].reset() diff --git a/ppocr/modeling/architectures/__init__.py b/ppocr/modeling/architectures/__init__.py index 86eaf7c9fb3c1147f60c7652243184121c62bcea..e9a01cf0281b91d29f2cce88375be3aaf43feb2e 100755 --- a/ppocr/modeling/architectures/__init__.py +++ b/ppocr/modeling/architectures/__init__.py @@ -13,12 +13,21 @@ # limitations under the License.
import copy +import importlib + +from .base_model import BaseModel +from .distillation_model import DistillationModel __all__ = ['build_model'] + def build_model(config): - from .base_model import BaseModel config = copy.deepcopy(config) - module_class = BaseModel(config) - return module_class \ No newline at end of file + if "name" not in config: + arch = BaseModel(config) + else: + name = config.pop("name") + mod = importlib.import_module(__name__) + arch = getattr(mod, name)(config) + return arch diff --git a/ppocr/modeling/architectures/base_model.py b/ppocr/modeling/architectures/base_model.py index 09b6e0346d998e3b90762e6163e8a34b48daff36..4c941fcf65573d9314c0badda49895d0b6b5c4f9 100644 --- a/ppocr/modeling/architectures/base_model.py +++ b/ppocr/modeling/architectures/base_model.py @@ -68,14 +68,23 @@ config["Head"]['in_channels'] = in_channels self.head = build_head(config["Head"]) + self.return_all_feats = config.get("return_all_feats", False) + def forward(self, x, data=None): + y = dict() if self.use_transform: x = self.transform(x) x = self.backbone(x) + y["backbone_out"] = x if self.use_neck: x = self.neck(x) + y["neck_out"] = x if data is None: x = self.head(x) else: x = self.head(x, data) - return x + y["head_out"] = x + if self.return_all_feats: + return y + else: + return x diff --git a/ppocr/modeling/architectures/distillation_model.py b/ppocr/modeling/architectures/distillation_model.py new file mode 100644 index 0000000000000000000000000000000000000000..2e512331afcfc20e422dbef4ba1a4acd581df9e7 --- /dev/null +++ b/ppocr/modeling/architectures/distillation_model.py @@ -0,0 +1,61 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle +from paddle import nn +from ppocr.modeling.transforms import build_transform +from ppocr.modeling.backbones import build_backbone +from ppocr.modeling.necks import build_neck +from ppocr.modeling.heads import build_head +from .base_model import BaseModel + +__all__ = ['DistillationModel'] + + +class DistillationModel(nn.Layer): + def __init__(self, config): + """ + the module for OCR distillation. + args: + config (dict): the super parameters for module.
+ """ + super().__init__() + self.model_list = [] + self.model_name_list = [] + for key in config["Models"]: + model_config = config["Models"][key] + freeze_params = False + pretrained = None + if "freeze_params" in model_config: + freeze_params = model_config.pop("freeze_params") + if "pretrained" in model_config: + pretrained = model_config.pop("pretrained") + model = BaseModel(model_config) + if pretrained is not None: + init_model(model, path=pretrained) + if freeze_params: + for param in model.parameters(): + param.trainable = False + self.model_list.append(self.add_sublayer(key, model)) + self.model_name_list.append(key) + + def forward(self, x): + result_dict = dict() + for idx, model_name in enumerate(self.model_name_list): + result_dict[model_name] = self.model_list[idx](x) + return result_dict diff --git a/ppocr/modeling/backbones/det_mobilenet_v3.py b/ppocr/modeling/backbones/det_mobilenet_v3.py index bb451bbec9327e2624ab0d501a7adf4355dc3407..05113ea8419aa302c952adfd74e9083055c35dca 100755 --- a/ppocr/modeling/backbones/det_mobilenet_v3.py +++ b/ppocr/modeling/backbones/det_mobilenet_v3.py @@ -102,8 +102,7 @@ class MobileNetV3(nn.Layer): padding=1, groups=1, if_act=True, - act='hardswish', - name='conv1') + act='hardswish') self.stages = [] self.out_channels = [] @@ -125,8 +124,7 @@ class MobileNetV3(nn.Layer): kernel_size=k, stride=s, use_se=se, - act=nl, - name="conv" + str(i + 2))) + act=nl)) inplanes = make_divisible(scale * c) i += 1 block_list.append( @@ -138,8 +136,7 @@ class MobileNetV3(nn.Layer): padding=0, groups=1, if_act=True, - act='hardswish', - name='conv_last')) + act='hardswish')) self.stages.append(nn.Sequential(*block_list)) self.out_channels.append(make_divisible(scale * cls_ch_squeeze)) for i, stage in enumerate(self.stages): @@ -163,8 +160,7 @@ class ConvBNLayer(nn.Layer): padding, groups=1, if_act=True, - act=None, - name=None): + act=None): super(ConvBNLayer, self).__init__() self.if_act = if_act self.act = act @@ -175,16 +171,9 @@ class ConvBNLayer(nn.Layer): stride=stride, padding=padding, groups=groups, - weight_attr=ParamAttr(name=name + '_weights'), bias_attr=False) - self.bn = nn.BatchNorm( - num_channels=out_channels, - act=None, - param_attr=ParamAttr(name=name + "_bn_scale"), - bias_attr=ParamAttr(name=name + "_bn_offset"), - moving_mean_name=name + "_bn_mean", - moving_variance_name=name + "_bn_variance") + self.bn = nn.BatchNorm(num_channels=out_channels, act=None) def forward(self, x): x = self.conv(x) @@ -209,8 +198,7 @@ class ResidualUnit(nn.Layer): kernel_size, stride, use_se, - act=None, - name=''): + act=None): super(ResidualUnit, self).__init__() self.if_shortcut = stride == 1 and in_channels == out_channels self.if_se = use_se @@ -222,8 +210,7 @@ class ResidualUnit(nn.Layer): stride=1, padding=0, if_act=True, - act=act, - name=name + "_expand") + act=act) self.bottleneck_conv = ConvBNLayer( in_channels=mid_channels, out_channels=mid_channels, @@ -232,10 +219,9 @@ class ResidualUnit(nn.Layer): padding=int((kernel_size - 1) // 2), groups=mid_channels, if_act=True, - act=act, - name=name + "_depthwise") + act=act) if self.if_se: - self.mid_se = SEModule(mid_channels, name=name + "_se") + self.mid_se = SEModule(mid_channels) self.linear_conv = ConvBNLayer( in_channels=mid_channels, out_channels=out_channels, @@ -243,8 +229,7 @@ class ResidualUnit(nn.Layer): stride=1, padding=0, if_act=False, - act=None, - name=name + "_linear") + act=None) def forward(self, inputs): x = self.expand_conv(inputs) @@ -258,7 +243,7 @@ class 
ResidualUnit(nn.Layer): class SEModule(nn.Layer): - def __init__(self, in_channels, reduction=4, name=""): + def __init__(self, in_channels, reduction=4): super(SEModule, self).__init__() self.avg_pool = nn.AdaptiveAvgPool2D(1) self.conv1 = nn.Conv2D( @@ -266,17 +251,13 @@ class SEModule(nn.Layer): out_channels=in_channels // reduction, kernel_size=1, stride=1, - padding=0, - weight_attr=ParamAttr(name=name + "_1_weights"), - bias_attr=ParamAttr(name=name + "_1_offset")) + padding=0) self.conv2 = nn.Conv2D( in_channels=in_channels // reduction, out_channels=in_channels, kernel_size=1, stride=1, - padding=0, - weight_attr=ParamAttr(name + "_2_weights"), - bias_attr=ParamAttr(name=name + "_2_offset")) + padding=0) def forward(self, inputs): outputs = self.avg_pool(inputs) diff --git a/ppocr/modeling/backbones/rec_mobilenet_v3.py b/ppocr/modeling/backbones/rec_mobilenet_v3.py index 1ff17159680372b00e6943e180e5fb638b39ec58..c5dcfdd5a3ad1f2c356f488a89e0f1e660a4a832 100644 --- a/ppocr/modeling/backbones/rec_mobilenet_v3.py +++ b/ppocr/modeling/backbones/rec_mobilenet_v3.py @@ -96,8 +96,7 @@ class MobileNetV3(nn.Layer): padding=1, groups=1, if_act=True, - act='hardswish', - name='conv1') + act='hardswish') i = 0 block_list = [] inplanes = make_divisible(inplanes * scale) @@ -110,8 +109,7 @@ class MobileNetV3(nn.Layer): kernel_size=k, stride=s, use_se=se, - act=nl, - name='conv' + str(i + 2))) + act=nl)) inplanes = make_divisible(scale * c) i += 1 self.blocks = nn.Sequential(*block_list) @@ -124,8 +122,7 @@ class MobileNetV3(nn.Layer): padding=0, groups=1, if_act=True, - act='hardswish', - name='conv_last') + act='hardswish') self.pool = nn.MaxPool2D(kernel_size=2, stride=2, padding=0) self.out_channels = make_divisible(scale * cls_ch_squeeze) diff --git a/ppocr/modeling/heads/det_db_head.py b/ppocr/modeling/heads/det_db_head.py index ca18d74a68f7b17ee6383d4a0c995a4c46a16187..83e7a5ebfe131ed209b7dd2d4b5a324605be8370 100644 --- a/ppocr/modeling/heads/det_db_head.py +++ b/ppocr/modeling/heads/det_db_head.py @@ -23,10 +23,10 @@ import paddle.nn.functional as F from paddle import ParamAttr -def get_bias_attr(k, name): +def get_bias_attr(k): stdv = 1.0 / math.sqrt(k * 1.0) initializer = paddle.nn.initializer.Uniform(-stdv, stdv) - bias_attr = ParamAttr(initializer=initializer, name=name + "_b_attr") + bias_attr = ParamAttr(initializer=initializer) return bias_attr @@ -38,18 +38,14 @@ class Head(nn.Layer): out_channels=in_channels // 4, kernel_size=3, padding=1, - weight_attr=ParamAttr(name=name_list[0] + '.w_0'), + weight_attr=ParamAttr(), bias_attr=False) self.conv_bn1 = nn.BatchNorm( num_channels=in_channels // 4, param_attr=ParamAttr( - name=name_list[1] + '.w_0', initializer=paddle.nn.initializer.Constant(value=1.0)), bias_attr=ParamAttr( - name=name_list[1] + '.b_0', initializer=paddle.nn.initializer.Constant(value=1e-4)), - moving_mean_name=name_list[1] + '.w_1', - moving_variance_name=name_list[1] + '.w_2', act='relu') self.conv2 = nn.Conv2DTranspose( in_channels=in_channels // 4, @@ -57,19 +53,14 @@ class Head(nn.Layer): kernel_size=2, stride=2, weight_attr=ParamAttr( - name=name_list[2] + '.w_0', initializer=paddle.nn.initializer.KaimingUniform()), - bias_attr=get_bias_attr(in_channels // 4, name_list[-1] + "conv2")) + bias_attr=get_bias_attr(in_channels // 4)) self.conv_bn2 = nn.BatchNorm( num_channels=in_channels // 4, param_attr=ParamAttr( - name=name_list[3] + '.w_0', initializer=paddle.nn.initializer.Constant(value=1.0)), bias_attr=ParamAttr( - name=name_list[3] + '.b_0', 
initializer=paddle.nn.initializer.Constant(value=1e-4)), - moving_mean_name=name_list[3] + '.w_1', - moving_variance_name=name_list[3] + '.w_2', act="relu") self.conv3 = nn.Conv2DTranspose( in_channels=in_channels // 4, @@ -77,10 +68,8 @@ class Head(nn.Layer): kernel_size=2, stride=2, weight_attr=ParamAttr( - name=name_list[4] + '.w_0', initializer=paddle.nn.initializer.KaimingUniform()), - bias_attr=get_bias_attr(in_channels // 4, name_list[-1] + "conv3"), - ) + bias_attr=get_bias_attr(in_channels // 4), ) def forward(self, x): x = self.conv1(x) diff --git a/ppocr/modeling/heads/rec_ctc_head.py b/ppocr/modeling/heads/rec_ctc_head.py index 69d4ef50b648c0251b9b8d0b4c1e731a6f236105..481f93e47e58f8267b23e632df1a1e80733d5944 100755 --- a/ppocr/modeling/heads/rec_ctc_head.py +++ b/ppocr/modeling/heads/rec_ctc_head.py @@ -23,14 +23,12 @@ from paddle import ParamAttr, nn from paddle.nn import functional as F -def get_para_bias_attr(l2_decay, k, name): +def get_para_bias_attr(l2_decay, k): regularizer = paddle.regularizer.L2Decay(l2_decay) stdv = 1.0 / math.sqrt(k * 1.0) initializer = nn.initializer.Uniform(-stdv, stdv) - weight_attr = ParamAttr( - regularizer=regularizer, initializer=initializer, name=name + "_w_attr") - bias_attr = ParamAttr( - regularizer=regularizer, initializer=initializer, name=name + "_b_attr") + weight_attr = ParamAttr(regularizer=regularizer, initializer=initializer) + bias_attr = ParamAttr(regularizer=regularizer, initializer=initializer) return [weight_attr, bias_attr] @@ -38,13 +36,12 @@ class CTCHead(nn.Layer): def __init__(self, in_channels, out_channels, fc_decay=0.0004, **kwargs): super(CTCHead, self).__init__() weight_attr, bias_attr = get_para_bias_attr( - l2_decay=fc_decay, k=in_channels, name='ctc_fc') + l2_decay=fc_decay, k=in_channels) self.fc = nn.Linear( in_channels, out_channels, weight_attr=weight_attr, - bias_attr=bias_attr, - name='ctc_fc') + bias_attr=bias_attr) self.out_channels = out_channels def forward(self, x, labels=None): diff --git a/ppocr/modeling/necks/db_fpn.py b/ppocr/modeling/necks/db_fpn.py index 710023f30cdda90322b731c4bd3465d0dc06a139..1cf30cedd5b23e8a7ba243726a6d7eea7924750c 100644 --- a/ppocr/modeling/necks/db_fpn.py +++ b/ppocr/modeling/necks/db_fpn.py @@ -32,61 +32,53 @@ class DBFPN(nn.Layer): in_channels=in_channels[0], out_channels=self.out_channels, kernel_size=1, - weight_attr=ParamAttr( - name='conv2d_51.w_0', initializer=weight_attr), + weight_attr=ParamAttr(initializer=weight_attr), bias_attr=False) self.in3_conv = nn.Conv2D( in_channels=in_channels[1], out_channels=self.out_channels, kernel_size=1, - weight_attr=ParamAttr( - name='conv2d_50.w_0', initializer=weight_attr), + weight_attr=ParamAttr(initializer=weight_attr), bias_attr=False) self.in4_conv = nn.Conv2D( in_channels=in_channels[2], out_channels=self.out_channels, kernel_size=1, - weight_attr=ParamAttr( - name='conv2d_49.w_0', initializer=weight_attr), + weight_attr=ParamAttr(initializer=weight_attr), bias_attr=False) self.in5_conv = nn.Conv2D( in_channels=in_channels[3], out_channels=self.out_channels, kernel_size=1, - weight_attr=ParamAttr( - name='conv2d_48.w_0', initializer=weight_attr), + weight_attr=ParamAttr(initializer=weight_attr), bias_attr=False) self.p5_conv = nn.Conv2D( in_channels=self.out_channels, out_channels=self.out_channels // 4, kernel_size=3, padding=1, - weight_attr=ParamAttr( - name='conv2d_52.w_0', initializer=weight_attr), + weight_attr=ParamAttr(initializer=weight_attr), bias_attr=False) self.p4_conv = nn.Conv2D( 
in_channels=self.out_channels, out_channels=self.out_channels // 4, kernel_size=3, padding=1, - weight_attr=ParamAttr( - name='conv2d_53.w_0', initializer=weight_attr), + weight_attr=ParamAttr(initializer=weight_attr), bias_attr=False) self.p3_conv = nn.Conv2D( in_channels=self.out_channels, out_channels=self.out_channels // 4, kernel_size=3, padding=1, - weight_attr=ParamAttr( - name='conv2d_54.w_0', initializer=weight_attr), + weight_attr=ParamAttr(initializer=weight_attr), bias_attr=False) self.p2_conv = nn.Conv2D( in_channels=self.out_channels, out_channels=self.out_channels // 4, kernel_size=3, padding=1, - weight_attr=ParamAttr( - name='conv2d_55.w_0', initializer=weight_attr), + weight_attr=ParamAttr(initializer=weight_attr), bias_attr=False) def forward(self, x): diff --git a/ppocr/postprocess/__init__.py b/ppocr/postprocess/__init__.py index 042654a19d2d2d2f1363fedbb9ac3530696e6903..cd2b7ea745b83f6ac99a6dd15bf5bf68c34ffd35 100644 --- a/ppocr/postprocess/__init__.py +++ b/ppocr/postprocess/__init__.py @@ -21,18 +21,19 @@ import copy __all__ = ['build_post_process'] +from .db_postprocess import DBPostProcess +from .east_postprocess import EASTPostProcess +from .sast_postprocess import SASTPostProcess +from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, DistillationCTCLabelDecode +from .cls_postprocess import ClsPostProcess +from .pg_postprocess import PGPostProcess -def build_post_process(config, global_config=None): - from .db_postprocess import DBPostProcess - from .east_postprocess import EASTPostProcess - from .sast_postprocess import SASTPostProcess - from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode - from .cls_postprocess import ClsPostProcess - from .pg_postprocess import PGPostProcess +def build_post_process(config, global_config=None): support_dict = [ 'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode', - 'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess' + 'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess', + 'DistillationCTCLabelDecode' ] config = copy.deepcopy(config) diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py index 85ce580f95b13539c6aeea32b188bfd3b435d140..164dec557a672842df868e2f6a01fc0fbc3e4946 100644 --- a/ppocr/postprocess/rec_postprocess.py +++ b/ppocr/postprocess/rec_postprocess.py @@ -125,6 +125,36 @@ class CTCLabelDecode(BaseRecLabelDecode): return dict_character +class DistillationCTCLabelDecode(CTCLabelDecode): + """ + Convert between text-label and text-index + """ + + def __init__(self, + character_dict_path=None, + character_type='ch', + use_space_char=False, + model_name=["student"], + key=None, + **kwargs): + super(DistillationCTCLabelDecode, self).__init__( + character_dict_path, character_type, use_space_char) + if not isinstance(model_name, list): + model_name = [model_name] + self.model_name = model_name + + self.key = key + + def __call__(self, preds, label=None, *args, **kwargs): + output = dict() + for name in self.model_name: + pred = preds[name] + if self.key is not None: + pred = pred[self.key] + output[name] = super().__call__(pred, label=label, *args, **kwargs) + return output + + class AttnLabelDecode(BaseRecLabelDecode): """ Convert between text-label and text-index """ diff --git a/ppocr/utils/save_load.py b/ppocr/utils/save_load.py index 3d1c5c356c9510dd701048aee8cbb3e73e8a059a..23f5401bb71a2ef50ff2ff2c3c27275d7e10b3c0 100644 --- a/ppocr/utils/save_load.py +++
b/ppocr/utils/save_load.py @@ -23,6 +23,8 @@ import six import paddle +from ppocr.utils.logging import get_logger + -__all__ = ['init_model', 'save_model', 'load_dygraph_pretrain'] +__all__ = ['init_model', 'save_model'] @@ -42,44 +44,11 @@ raise OSError('Failed to mkdir {}'.format(path)) -def load_dygraph_pretrain(model, logger, path=None, load_static_weights=False): - if not (os.path.isdir(path) or os.path.exists(path + '.pdparams')): - raise ValueError("Model pretrain path {} does not " - "exists.".format(path)) - if load_static_weights: - pre_state_dict = paddle.static.load_program_state(path) - param_state_dict = {} - model_dict = model.state_dict() - for key in model_dict.keys(): - weight_name = model_dict[key].name - weight_name = weight_name.replace('binarize', '').replace( - 'thresh', '') # for DB - if weight_name in pre_state_dict.keys(): - # logger.info('Load weight: {}, shape: {}'.format( - # weight_name, pre_state_dict[weight_name].shape)) - if 'encoder_rnn' in key: - # delete axis which is 1 - pre_state_dict[weight_name] = pre_state_dict[ - weight_name].squeeze() - # change axis - if len(pre_state_dict[weight_name].shape) > 1: - pre_state_dict[weight_name] = pre_state_dict[ - weight_name].transpose((1, 0)) - param_state_dict[key] = pre_state_dict[weight_name] - else: - param_state_dict[key] = model_dict[key] - model.set_state_dict(param_state_dict) - return - - param_state_dict = paddle.load(path + '.pdparams') - model.set_state_dict(param_state_dict) - return - - -def init_model(config, model, logger, optimizer=None, lr_scheduler=None): +def init_model(config, model, optimizer=None, lr_scheduler=None): """ load model from checkpoint or pretrained_model """ + logger = get_logger() global_config = config['Global'] checkpoints = global_config.get('checkpoints') pretrained_model = global_config.get('pretrained_model') @@ -102,18 +71,17 @@ best_model_dict = states_dict.get('best_model_dict', {}) if 'epoch' in states_dict: best_model_dict['start_epoch'] = states_dict['epoch'] + 1 - logger.info("resume from {}".format(checkpoints)) elif pretrained_model: - load_static_weights = global_config.get('load_static_weights', False) if not isinstance(pretrained_model, list): pretrained_model = [pretrained_model] - if not isinstance(load_static_weights, list): - load_static_weights = [load_static_weights] * len(pretrained_model) - for idx, pretrained in enumerate(pretrained_model): - load_static = load_static_weights[idx] - load_dygraph_pretrain( - model, logger, path=pretrained, load_static_weights=load_static) + for pretrained in pretrained_model: + if not (os.path.isdir(pretrained) or + os.path.exists(pretrained + '.pdparams')): + raise ValueError("Model pretrain path {} does not " + "exist.".format(pretrained)) + param_state_dict = paddle.load(pretrained + '.pdparams') + model.set_state_dict(param_state_dict) logger.info("load pretrained model from {}".format( pretrained_model)) else: diff --git a/tools/eval.py b/tools/eval.py index 9817fa75093dd5127e3d11501ebc0473c9b53365..66eb315f9b37ed681f6a899613fa43c1313bc654 100755 --- a/tools/eval.py +++ b/tools/eval.py @@ -49,7 +49,7 @@ def main(): model = build_model(config['Architecture']) use_srn = config['Architecture']['algorithm'] == "SRN" - best_model_dict = init_model(config, model, logger) + best_model_dict = init_model(config, model) if len(best_model_dict): logger.info('metric in ckpt ***************') for k, v in best_model_dict.items(): diff --git
a/tools/export_model.py b/tools/export_model.py index bdff89f755d465742f1c2a810f8ae76153a558c6..625c82468edff7c3eeb787422bdef07b4b274460 100755 --- a/tools/export_model.py +++ b/tools/export_model.py @@ -17,7 +17,7 @@ import sys __dir__ = os.path.dirname(os.path.abspath(__file__)) sys.path.append(__dir__) -sys.path.append(os.path.abspath(os.path.join(__dir__, '..'))) +sys.path.append(os.path.abspath(os.path.join(__dir__, ".."))) import argparse @@ -31,32 +31,12 @@ from ppocr.utils.logging import get_logger from tools.program import load_config, merge_config, ArgsParser -def main(): - FLAGS = ArgsParser().parse_args() - config = load_config(FLAGS.config) - merge_config(FLAGS.opt) - logger = get_logger() - # build post process - - post_process_class = build_post_process(config['PostProcess'], - config['Global']) - - # build model - # for rec algorithm - if hasattr(post_process_class, 'character'): - char_num = len(getattr(post_process_class, 'character')) - config['Architecture']["Head"]['out_channels'] = char_num - model = build_model(config['Architecture']) - init_model(config, model, logger) - model.eval() - - save_path = '{}/inference'.format(config['Global']['save_inference_dir']) - - if config['Architecture']['algorithm'] == "SRN": - max_text_length = config['Architecture']['Head']['max_text_length'] +def export_single_model(model, arch_config, save_path, logger): + if arch_config["algorithm"] == "SRN": + max_text_length = arch_config["Head"]["max_text_length"] other_shape = [ paddle.static.InputSpec( - shape=[None, 1, 64, 256], dtype='float32'), [ + shape=[None, 1, 64, 256], dtype="float32"), [ paddle.static.InputSpec( shape=[None, 256, 1], dtype="int64"), paddle.static.InputSpec( @@ -71,24 +51,66 @@ def main(): model = to_static(model, input_spec=other_shape) else: infer_shape = [3, -1, -1] - if config['Architecture']['model_type'] == "rec": + if arch_config["model_type"] == "rec": infer_shape = [3, 32, -1] # for rec model, H must be 32 - if 'Transform' in config['Architecture'] and config['Architecture'][ - 'Transform'] is not None and config['Architecture'][ - 'Transform']['name'] == 'TPS': + if "Transform" in arch_config and arch_config[ + "Transform"] is not None and arch_config["Transform"][ + "name"] == "TPS": logger.info( - 'When there is tps in the network, variable length input is not supported, and the input size needs to be the same as during training' + "When there is tps in the network, variable length input is not supported, and the input size needs to be the same as during training" ) infer_shape[-1] = 100 + model = to_static( model, input_spec=[ paddle.static.InputSpec( - shape=[None] + infer_shape, dtype='float32') + shape=[None] + infer_shape, dtype="float32") ]) paddle.jit.save(model, save_path) - logger.info('inference model is saved to {}'.format(save_path)) + logger.info("inference model is saved to {}".format(save_path)) + return + + +def main(): + FLAGS = ArgsParser().parse_args() + config = load_config(FLAGS.config) + merge_config(FLAGS.opt) + logger = get_logger() + # build post process + + post_process_class = build_post_process(config["PostProcess"], + config["Global"]) + + # build model + # for rec algorithm + if hasattr(post_process_class, "character"): + char_num = len(getattr(post_process_class, "character")) + if config["Architecture"]["algorithm"] in ["Distillation", + ]: # distillation model + for key in config["Architecture"]["Models"]: + config["Architecture"]["Models"][key]["Head"][ + "out_channels"] = char_num + else: # base rec model + 
config["Architecture"]["Head"]["out_channels"] = char_num + model = build_model(config["Architecture"]) + init_model(config, model) + model.eval() + + save_path = config["Global"]["save_inference_dir"] + + arch_config = config["Architecture"] + + if arch_config["algorithm"] in ["Distillation", ]: # distillation model + archs = list(arch_config["Models"].values()) + for idx, name in enumerate(model.model_name_list): + sub_model_save_path = os.path.join(save_path, name, "inference") + export_single_model(model.model_list[idx], archs[idx], + sub_model_save_path, logger) + else: + save_path = os.path.join(save_path, "inference") + export_single_model(model, arch_config, save_path, logger) if __name__ == "__main__": diff --git a/tools/infer_cls.py b/tools/infer_cls.py index 496964826b0b063f9f937c31342932c6cd95502f..a588cab433442695e3bd395da63e35a2052de501 100755 --- a/tools/infer_cls.py +++ b/tools/infer_cls.py @@ -47,7 +47,7 @@ def main(): # build model model = build_model(config['Architecture']) - init_model(config, model, logger) + init_model(config, model) # create data ops transforms = [] diff --git a/tools/infer_det.py b/tools/infer_det.py index 913d617defea18fe881e6fd2212b1df20f7d26d3..674f52ee35aab25356ccdbf371f8bac5b52b871a 100755 --- a/tools/infer_det.py +++ b/tools/infer_det.py @@ -61,7 +61,7 @@ def main(): # build model model = build_model(config['Architecture']) - init_model(config, model, logger) + init_model(config, model) # build post process post_process_class = build_post_process(config['PostProcess']) diff --git a/tools/infer_e2e.py b/tools/infer_e2e.py index 9c079f6074f088ef0298cab839f74faefad82abb..1cd468b8e552237af31d985b8b68ddbeecba9c96 100755 --- a/tools/infer_e2e.py +++ b/tools/infer_e2e.py @@ -68,7 +68,7 @@ def main(): # build model model = build_model(config['Architecture']) - init_model(config, model, logger) + init_model(config, model) # build post process post_process_class = build_post_process(config['PostProcess'], diff --git a/tools/infer_rec.py b/tools/infer_rec.py index 2563f5a8197ed39b1b5d44c7cfee32797e760758..09f5a0c767b15c312cdfbe8ed695ea06bdc8cdc4 100755 --- a/tools/infer_rec.py +++ b/tools/infer_rec.py @@ -20,6 +20,7 @@ import numpy as np import os import sys +import json __dir__ = os.path.dirname(os.path.abspath(__file__)) sys.path.append(__dir__) @@ -46,12 +47,18 @@ def main(): # build model if hasattr(post_process_class, 'character'): - config['Architecture']["Head"]['out_channels'] = len( - getattr(post_process_class, 'character')) + char_num = len(getattr(post_process_class, 'character')) + if config['Architecture']["algorithm"] in ["Distillation", + ]: # distillation model + for key in config['Architecture']["Models"]: + config['Architecture']["Models"][key]["Head"][ + 'out_channels'] = char_num + else: # base rec model + config['Architecture']["Head"]['out_channels'] = char_num model = build_model(config['Architecture']) - init_model(config, model, logger) + init_model(config, model) # create data ops transforms = [] @@ -107,11 +114,23 @@ def main(): else: preds = model(images) post_result = post_process_class(preds) - for rec_reuslt in post_result: - logger.info('\t result: {}'.format(rec_reuslt)) - if len(rec_reuslt) >= 2: - fout.write(file + "\t" + rec_reuslt[0] + "\t" + str( - rec_reuslt[1]) + "\n") + info = None + if isinstance(post_result, dict): + rec_info = dict() + for key in post_result: + if len(post_result[key][0]) >= 2: + rec_info[key] = { + "label": post_result[key][0][0], + "score": post_result[key][0][1], + } + info = 
json.dumps(rec_info) + else: + if len(post_result[0]) >= 2: + info = post_result[0][0] + "\t" + str(post_result[0][1]) + + if info is not None: + logger.info("\t result: {}".format(info)) + fout.write(file + "\t" + info + "\n") logger.info("success!") diff --git a/tools/program.py b/tools/program.py index 7e54a2f8c2f1db8881aa476a309c8a8c563fcae5..7641bed749ff4bb0d58712a9f50c6a119a4f25ee 100755 --- a/tools/program.py +++ b/tools/program.py @@ -386,7 +386,7 @@ def preprocess(is_train=False): alg = config['Architecture']['algorithm'] assert alg in [ 'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN', - 'CLS', 'PGNet' + 'CLS', 'PGNet', 'Distillation' ] device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu' diff --git a/tools/train.py b/tools/train.py index 47358ca43da46b7eb6a04cd1f7fe284efd7e96f7..b024240b4d5d4973645336c62d3762087ec7bbeb 100755 --- a/tools/train.py +++ b/tools/train.py @@ -72,7 +72,14 @@ def main(config, device, logger, vdl_writer): # for rec algorithm if hasattr(post_process_class, 'character'): char_num = len(getattr(post_process_class, 'character')) - config['Architecture']["Head"]['out_channels'] = char_num + if config['Architecture']["algorithm"] in ["Distillation", + ]: # distillation model + for key in config['Architecture']["Models"]: + config['Architecture']["Models"][key]["Head"][ + 'out_channels'] = char_num + else: # base rec model + config['Architecture']["Head"]['out_channels'] = char_num + model = build_model(config['Architecture']) if config['Global']['distributed']: model = paddle.DataParallel(model) @@ -90,7 +97,7 @@ def main(config, device, logger, vdl_writer): # build metric eval_class = build_metric(config['Metric']) # load pretrain model - pre_best_model_dict = init_model(config, model, optimizer) + pre_best_model_dict = init_model(config, model, optimizer) logger.info('train dataloader has {} iters'.format(len(train_dataloader))) if valid_dataloader is not None:
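
Reviewer note (not part of the patch): the pieces above compose as follows. `build_model` sees `name: DistillationModel` and builds one `BaseModel` per entry under `Models`; because `return_all_feats` is true, each sub-model returns a dict with `backbone_out`, `neck_out` and `head_out`; `CombinedLoss` then routes those dicts into the `Distillation*` losses via each loss's `key`. Below is a minimal smoke-test sketch of that flow. It is a hypothetical script (the file name, `char_num = 100`, and the random images/labels are stand-ins I introduce for illustration; the real `out_channels` comes from the character dict via `tools/train.py`), and it assumes it is run from the repo root with `ppocr` importable.

# smoke_test_distillation.py -- hypothetical reviewer sketch, not part of this PR
import paddle
from ppocr.modeling.architectures import build_model
from ppocr.losses import build_loss

char_num = 100  # stand-in; train.py derives this from the post-process character list

def sub_model_config():
    # mirrors one Student/Teacher entry of the YAML config above
    return {
        "model_type": "rec",
        "algorithm": "CRNN",
        "return_all_feats": True,
        "Backbone": {"name": "MobileNetV3", "scale": 0.5,
                     "model_name": "small", "small_stride": [1, 2, 2, 2]},
        "Neck": {"name": "SequenceEncoder", "encoder_type": "rnn",
                 "hidden_size": 48},
        "Head": {"name": "CTCHead", "fc_decay": 0.00001,
                 "out_channels": char_num},
    }

model = build_model({
    "name": "DistillationModel",
    "algorithm": "Distillation",
    "Models": {"Student": sub_model_config(), "Teacher": sub_model_config()},
})

loss_fn = build_loss({
    "name": "CombinedLoss",
    "loss_config_list": [
        {"DistillationCTCLoss": {"weight": 1.0, "key": "head_out",
                                 "model_name_list": ["Student", "Teacher"]}},
        {"DistillationDMLLoss": {"weight": 1.0, "act": "softmax", "key": "head_out",
                                 "model_name_pairs": [["Student", "Teacher"]]}},
        {"DistillationDistanceLoss": {"weight": 1.0, "mode": "l2", "key": "backbone_out",
                                      "model_name_pairs": [["Student", "Teacher"]]}},
    ],
})

images = paddle.rand([2, 3, 32, 320])                # NCHW; rec models expect H == 32
labels = paddle.randint(1, char_num, shape=[2, 25])  # index 0 is reserved for the CTC blank
lengths = paddle.to_tensor([25, 25], dtype="int64")

preds = model(images)  # {"Student": {"backbone_out": ..., "neck_out": ..., "head_out": ...}, "Teacher": {...}}
loss_dict = loss_fn(preds, [images, labels, lengths])
print({k: float(v) for k, v in loss_dict.items()})  # per-term losses plus the summed "loss"

The design point worth noting for review: `return_all_feats` plus a per-loss `key` lets one config distill at several depths at once (CTC supervision on `head_out`, mutual DML on softened `head_out`, L2 feature matching on `backbone_out`) without any loss needing to know the model internals.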