diff --git a/doc/doc_ch/models_list.md b/doc/doc_ch/models_list.md index 318d5874f5e01390976723ccdb98012b95a6eb7f..c36cac037e1f90a41da24bc64cacbbb860e04c6b 100644 --- a/doc/doc_ch/models_list.md +++ b/doc/doc_ch/models_list.md @@ -22,7 +22,7 @@ PaddleOCR提供的可下载模型包括`推理模型`、`训练模型`、`预训 |模型类型|模型格式|简介| |--- | --- | --- | -|推理模型|inference.pdmodel、inference.pdiparams|用于预测引擎推理,[详情](./inference.md)| +|推理模型|inference.pdmodel、inference.pdiparams|用于预测引擎推理,[详情](./inference_ppocr.md)| |训练模型、预训练模型|\*.pdparams、\*.pdopt、\*.states |训练过程中保存的模型的参数、优化器状态和训练中间信息,多用于模型指标评估和恢复训练| |nb模型|\*.nb|经过飞桨Paddle-Lite工具优化后的模型,适用于移动端/IoT端等端侧部署场景(需使用飞桨Paddle Lite部署)。| @@ -114,7 +114,7 @@ PaddleOCR提供的可下载模型包括`推理模型`、`训练模型`、`预训 | ka_PP-OCRv3_rec | ppocr/utils/dict/ka_dict.txt |卡纳达文识别|[ka_PP-OCRv3_rec.yml](../../configs/rec/PP-OCRv3/multi_language/ka_PP-OCRv3_rec.yml)|9.9M|[推理模型](https://paddleocr.bj.bcebos.com/PP-OCRv3/multilingual/ka_PP-OCRv3_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/PP-OCRv3/multilingual/ka_PP-OCRv3_rec_train.tar) | | ta_PP-OCRv3_rec | ppocr/utils/dict/ta_dict.txt |泰米尔文识别|[ta_PP-OCRv3_rec.yml](../../configs/rec/PP-OCRv3/multi_language/ta_PP-OCRv3_rec.yml)|9.6M|[推理模型](https://paddleocr.bj.bcebos.com/PP-OCRv3/multilingual/ta_PP-OCRv3_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/PP-OCRv3/multilingual/ta_PP-OCRv3_rec_train.tar) | | latin_PP-OCRv3_rec | ppocr/utils/dict/latin_dict.txt | 拉丁文识别 | [latin_PP-OCRv3_rec.yml](../../configs/rec/PP-OCRv3/multi_language/latin_PP-OCRv3_rec.yml) |9.7M|[推理模型](https://paddleocr.bj.bcebos.com/PP-OCRv3/multilingual/latin_PP-OCRv3_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/PP-OCRv3/multilingual/latin_PP-OCRv3_rec_train.tar) | -| arabic_PP-OCRv3_rec | ppocr/utils/dict/arabic_dict.txt | 阿拉伯字母 | [arabic_PP-OCRv3_rec.yml](../../configs/rec/PP-OCRv3/multi_language/rec_arabic_lite_train.yml) |9.6M|[推理模型](https://paddleocr.bj.bcebos.com/PP-OCRv3/multilingual/arabic_PP-OCRv3_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/PP-OCRv3/multilingual/arabic_PP-OCRv3_rec_train.tar) | +| arabic_PP-OCRv3_rec | ppocr/utils/dict/arabic_dict.txt | 阿拉伯字母 | [arabic_PP-OCRv3_rec.yml](../../configs/rec/PP-OCRv3/multi_language/arabic_PP-OCRv3_rec.yml) |9.6M|[推理模型](https://paddleocr.bj.bcebos.com/PP-OCRv3/multilingual/arabic_PP-OCRv3_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/PP-OCRv3/multilingual/arabic_PP-OCRv3_rec_train.tar) | | cyrillic_PP-OCRv3_rec | ppocr/utils/dict/cyrillic_dict.txt | 斯拉夫字母 | [cyrillic_PP-OCRv3_rec.yml](../../configs/rec/PP-OCRv3/multi_language/cyrillic_PP-OCRv3_rec.yml) |9.6M|[推理模型](https://paddleocr.bj.bcebos.com/PP-OCRv3/multilingual/cyrillic_PP-OCRv3_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/PP-OCRv3/multilingual/cyrillic_PP-OCRv3_rec_train.tar) | | devanagari_PP-OCRv3_rec | ppocr/utils/dict/devanagari_dict.txt |梵文字母 | [devanagari_PP-OCRv3_rec.yml](../../configs/rec/PP-OCRv3/multi_language/devanagari_PP-OCRv3_rec.yml) |9.9M|[推理模型](https://paddleocr.bj.bcebos.com/PP-OCRv3/multilingual/devanagari_PP-OCRv3_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/PP-OCRv3/multilingual/devanagari_PP-OCRv3_rec_train.tar) | diff --git a/doc/doc_en/models_list_en.md b/doc/doc_en/models_list_en.md index 8e8c1f2fe11bcd0748d556d34fd184fed4b3a86f..c52f71dfe4124302b8cb308980a6228a89589bd6 100644 --- a/doc/doc_en/models_list_en.md +++ b/doc/doc_en/models_list_en.md @@ -20,7 +20,7 @@ The downloadable models provided by PaddleOCR include `inference model`, `traine |model type|model format|description| |--- | --- | --- | -|inference model|inference.pdmodel、inference.pdiparams|Used for inference based on Paddle inference engine,[detail](./inference_en.md)| +|inference model|inference.pdmodel、inference.pdiparams|Used for inference based on Paddle inference engine,[detail](./inference_ppocr_en.md)| |trained model, pre-trained model|\*.pdparams、\*.pdopt、\*.states |The checkpoints model saved in the training process, which stores the parameters of the model, mostly used for model evaluation and continuous training.| |nb model|\*.nb| Model optimized by Paddle-Lite, which is suitable for mobile-side deployment scenarios (Paddle-Lite is needed for nb model deployment). | @@ -37,7 +37,7 @@ Relationship of the above models is as follows. |model name|description|config|model size|download| | --- | --- | --- | --- | --- | -|ch_PP-OCRv3_det_slim| [New] slim quantization with distillation lightweight model, supporting Chinese, English, multilingual text detection |[ch_PP-OCRv3_det_cml.yml](../../configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml)| 1.1M |[inference model](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_slim_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/PP-OCRv3/ch/ch_PP-OCRv3_det_slim_distill_train.tar) / [nb model](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_slim_infer.nb)| +|ch_PP-OCRv3_det_slim| [New] slim quantization with distillation lightweight model, supporting Chinese, English, multilingual text detection |[ch_PP-OCRv3_det_cml.yml](../../configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml)| 1.1M |[inference model](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_slim_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_slim_distill_train.tar) / [nb model](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_slim_infer.nb)| |ch_PP-OCRv3_det| [New] Original lightweight model, supporting Chinese, English, multilingual text detection |[ch_PP-OCRv3_det_cml.yml](../../configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml)| 3.8M |[inference model](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_distill_train.tar)| |ch_PP-OCRv2_det_slim| [New] slim quantization with distillation lightweight model, supporting Chinese, English, multilingual text detection|[ch_PP-OCRv2_det_cml.yml](../../configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml)| 3M |[inference model](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_slim_quant_infer.tar)| |ch_PP-OCRv2_det| [New] Original lightweight model, supporting Chinese, English, multilingual text detection|[ch_PP-OCRv2_det_cml.yml](../../configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml)|3M|[inference model](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_distill_train.tar)| @@ -75,7 +75,7 @@ Relationship of the above models is as follows. |model name|description|config|model size|download| | --- | --- | --- | --- | --- | -|ch_PP-OCRv3_rec_slim | [New] Slim qunatization with distillation lightweight model, supporting Chinese, English text recognition |[ch_PP-OCRv3_rec_distillation.yml](../../configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml)| 4.9M |[inference model](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_slim_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/PP-OCRv3/ch/ch_PP-OCRv3_rec_slim_train.tar) / [nb model](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_slim_infer.nb) | +|ch_PP-OCRv3_rec_slim | [New] Slim qunatization with distillation lightweight model, supporting Chinese, English text recognition |[ch_PP-OCRv3_rec_distillation.yml](../../configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml)| 4.9M |[inference model](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_slim_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_slim_train.tar) / [nb model](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_slim_infer.nb) | |ch_PP-OCRv3_rec| [New] Original lightweight model, supporting Chinese, English, multilingual text recognition |[ch_PP-OCRv3_rec_distillation.yml](../../configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml)| 12.4M |[inference model](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_train.tar) | |ch_PP-OCRv2_rec_slim| Slim qunatization with distillation lightweight model, supporting Chinese, English text recognition|[ch_PP-OCRv2_rec.yml](../../configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec.yml)| 9M |[inference model](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_slim_quant_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_slim_quant_train.tar) | |ch_PP-OCRv2_rec| Original lightweight model, supporting Chinese, English, multilingual text recognition |[ch_PP-OCRv2_rec_distillation.yml](../../configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec_distillation.yml)|8.5M|[inference model](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_train.tar) | @@ -91,7 +91,7 @@ Relationship of the above models is as follows. |model name|description|config|model size|download| | --- | --- | --- | --- | --- | -|en_PP-OCRv3_rec_slim | [New] Slim qunatization with distillation lightweight model, supporting english, English text recognition |[en_PP-OCRv3_rec.yml](../../configs/rec/PP-OCRv3/en_PP-OCRv3_rec.yml)| 3.2M |[inference model](https://paddleocr.bj.bcebos.com/PP-OCRv3/english/PP-OCRv3_rec_slim_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/PP-OCRv3/english/en_PP-OCRv3_rec_slim_train.tar) / [nb model](https://paddleocr.bj.bcebos.com/PP-OCRv3/english/en_PP-OCRv3_rec_slim_infer.nb) | +|en_PP-OCRv3_rec_slim | [New] Slim qunatization with distillation lightweight model, supporting english, English text recognition |[en_PP-OCRv3_rec.yml](../../configs/rec/PP-OCRv3/en_PP-OCRv3_rec.yml)| 3.2M |[inference model](https://paddleocr.bj.bcebos.com/PP-OCRv3/english/en_PP-OCRv3_rec_slim_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/PP-OCRv3/english/en_PP-OCRv3_rec_slim_train.tar) / [nb model](https://paddleocr.bj.bcebos.com/PP-OCRv3/english/en_PP-OCRv3_rec_slim_infer.nb) | |en_PP-OCRv3_rec| [New] Original lightweight model, supporting english, English, multilingual text recognition |[en_PP-OCRv3_rec.yml](../../configs/rec/PP-OCRv3/en_PP-OCRv3_rec.yml)| 9.6M |[inference model](https://paddleocr.bj.bcebos.com/PP-OCRv3/english/en_PP-OCRv3_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/PP-OCRv3/english/en_PP-OCRv3_rec_train.tar) | |en_number_mobile_slim_v2.0_rec|Slim pruned and quantized lightweight model, supporting English and number recognition|[rec_en_number_lite_train.yml](../../configs/rec/multi_language/rec_en_number_lite_train.yml)| 2.7M | [inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/en_number_mobile_v2.0_rec_slim_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/en_number_mobile_v2.0_rec_slim_train.tar) | |en_number_mobile_v2.0_rec|Original lightweight model, supporting English and number recognition|[rec_en_number_lite_train.yml](../../configs/rec/multi_language/rec_en_number_lite_train.yml)|2.6M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/en_number_mobile_v2.0_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/en_number_mobile_v2.0_rec_train.tar) | @@ -108,7 +108,7 @@ Relationship of the above models is as follows. | ka_PP-OCRv3_rec | ppocr/utils/dict/ka_dict.txt | Lightweight model for Kannada recognition |[ka_PP-OCRv3_rec.yml](../../configs/rec/PP-OCRv3/multi_language/ka_PP-OCRv3_rec.yml)|9.9M|[inference model](https://paddleocr.bj.bcebos.com/PP-OCRv3/multilingual/ka_PP-OCRv3_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/PP-OCRv3/multilingual/ka_PP-OCRv3_rec_train.tar) | | ta_PP-OCRv3_rec | ppocr/utils/dict/ta_dict.txt |Lightweight model for Tamil recognition|[ta_PP-OCRv3_rec.yml](../../configs/rec/PP-OCRv3/multi_language/ta_PP-OCRv3_rec.yml)|9.6M|[inference model](https://paddleocr.bj.bcebos.com/PP-OCRv3/multilingual/ta_PP-OCRv3_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/PP-OCRv3/multilingual/ta_PP-OCRv3_rec_train.tar) | | latin_PP-OCRv3_rec | ppocr/utils/dict/latin_dict.txt | Lightweight model for latin recognition | [latin_PP-OCRv3_rec.yml](../../configs/rec/PP-OCRv3/multi_language/latin_PP-OCRv3_rec.yml) |9.7M|[inference model](https://paddleocr.bj.bcebos.com/PP-OCRv3/multilingual/latin_PP-OCRv3_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/PP-OCRv3/multilingual/latin_PP-OCRv3_rec_train.tar) | -| arabic_PP-OCRv3_rec | ppocr/utils/dict/arabic_dict.txt | Lightweight model for arabic recognition | [arabic_PP-OCRv3_rec.yml](../../configs/rec/PP-OCRv3/multi_language/rec_arabic_lite_train.yml) |9.6M|[inference model](https://paddleocr.bj.bcebos.com/PP-OCRv3/multilingual/arabic_PP-OCRv3_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/PP-OCRv3/multilingual/arabic_PP-OCRv3_rec_train.tar) | +| arabic_PP-OCRv3_rec | ppocr/utils/dict/arabic_dict.txt | Lightweight model for arabic recognition | [arabic_PP-OCRv3_rec.yml](../../configs/rec/PP-OCRv3/multi_language/arabic_PP-OCRv3_rec.yml) |9.6M|[inference model](https://paddleocr.bj.bcebos.com/PP-OCRv3/multilingual/arabic_PP-OCRv3_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/PP-OCRv3/multilingual/arabic_PP-OCRv3_rec_train.tar) | | cyrillic_PP-OCRv3_rec | ppocr/utils/dict/cyrillic_dict.txt | Lightweight model for cyrillic recognition | [cyrillic_PP-OCRv3_rec.yml](../../configs/rec/PP-OCRv3/multi_language/cyrillic_PP-OCRv3_rec.yml) |9.6M|[inference model](https://paddleocr.bj.bcebos.com/PP-OCRv3/multilingual/cyrillic_PP-OCRv3_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/PP-OCRv3/multilingual/cyrillic_PP-OCRv3_rec_train.tar) | | devanagari_PP-OCRv3_rec | ppocr/utils/dict/devanagari_dict.txt | Lightweight model for devanagari recognition | [devanagari_PP-OCRv3_rec.yml](../../configs/rec/PP-OCRv3/multi_language/devanagari_PP-OCRv3_rec.yml) |9.9M|[inference model](https://paddleocr.bj.bcebos.com/PP-OCRv3/multilingual/devanagari_PP-OCRv3_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/PP-OCRv3/multilingual/devanagari_PP-OCRv3_rec_train.tar) | diff --git a/ppocr/modeling/architectures/__init__.py b/ppocr/modeling/architectures/__init__.py index e9a01cf0281b91d29f2cce88375be3aaf43feb2e..1c955ef3abe9c38e816616cc9b5399c6832aa5f1 100755 --- a/ppocr/modeling/architectures/__init__.py +++ b/ppocr/modeling/architectures/__init__.py @@ -15,10 +15,13 @@ import copy import importlib +from paddle.jit import to_static +from paddle.static import InputSpec + from .base_model import BaseModel from .distillation_model import DistillationModel -__all__ = ['build_model'] +__all__ = ["build_model", "apply_to_static"] def build_model(config): @@ -30,3 +33,36 @@ def build_model(config): mod = importlib.import_module(__name__) arch = getattr(mod, name)(config) return arch + + +def apply_to_static(model, config, logger): + if config["Global"].get("to_static", False) is not True: + return model + assert "image_shape" in config[ + "Global"], "image_shape must be assigned for static training mode..." + supported_list = ["DB", "SVTR"] + if config["Architecture"]["algorithm"] in ["Distillation"]: + algo = list(config["Architecture"]["Models"].values())[0]["algorithm"] + else: + algo = config["Architecture"]["algorithm"] + assert algo in supported_list, f"algorithms that supports static training must in in {supported_list} but got {algo}" + + specs = [ + InputSpec( + [None] + config["Global"]["image_shape"], dtype='float32') + ] + + if algo == "SVTR": + specs.append([ + InputSpec( + [None, config["Global"]["max_text_length"]], + dtype='int64'), InputSpec( + [None, config["Global"]["max_text_length"]], dtype='int64'), + InputSpec( + [None], dtype='int64'), InputSpec( + [None], dtype='float64') + ]) + + model = to_static(model, input_spec=specs) + logger.info("Successfully to apply @to_static with specs: {}".format(specs)) + return model diff --git a/ppocr/modeling/heads/rec_sar_head.py b/ppocr/modeling/heads/rec_sar_head.py index 0e6b34404b61b44bebcbc7d67ddfd0a95382c39b..5e64cae85afafc555f2519ed6dd3f05eafff7ea2 100644 --- a/ppocr/modeling/heads/rec_sar_head.py +++ b/ppocr/modeling/heads/rec_sar_head.py @@ -83,7 +83,7 @@ class SAREncoder(nn.Layer): def forward(self, feat, img_metas=None): if img_metas is not None: - assert len(img_metas[0]) == feat.shape[0] + assert len(img_metas[0]) == paddle.shape(feat)[0] valid_ratios = None if img_metas is not None and self.mask: @@ -98,9 +98,10 @@ class SAREncoder(nn.Layer): if valid_ratios is not None: valid_hf = [] - T = holistic_feat.shape[1] - for i in range(len(valid_ratios)): - valid_step = min(T, math.ceil(T * valid_ratios[i])) - 1 + T = paddle.shape(holistic_feat)[1] + for i in range(paddle.shape(valid_ratios)[0]): + valid_step = paddle.minimum( + T, paddle.ceil(valid_ratios[i] * T).astype('int32')) - 1 valid_hf.append(holistic_feat[i, valid_step, :]) valid_hf = paddle.stack(valid_hf, axis=0) else: @@ -247,13 +248,14 @@ class ParallelSARDecoder(BaseDecoder): # bsz * (seq_len + 1) * h * w * attn_size attn_weight = self.conv1x1_2(attn_weight) # bsz * (seq_len + 1) * h * w * 1 - bsz, T, h, w, c = attn_weight.shape + bsz, T, h, w, c = paddle.shape(attn_weight) assert c == 1 if valid_ratios is not None: # cal mask of attention weight - for i in range(len(valid_ratios)): - valid_width = min(w, math.ceil(w * valid_ratios[i])) + for i in range(paddle.shape(valid_ratios)[0]): + valid_width = paddle.minimum( + w, paddle.ceil(valid_ratios[i] * w).astype("int32")) if valid_width < w: attn_weight[i, :, :, valid_width:, :] = float('-inf') @@ -288,7 +290,7 @@ class ParallelSARDecoder(BaseDecoder): img_metas: [label, valid_ratio] ''' if img_metas is not None: - assert len(img_metas[0]) == feat.shape[0] + assert paddle.shape(img_metas[0])[0] == paddle.shape(feat)[0] valid_ratios = None if img_metas is not None and self.mask: @@ -302,7 +304,6 @@ class ParallelSARDecoder(BaseDecoder): # bsz * (seq_len + 1) * C out_dec = self._2d_attention( in_dec, feat, out_enc, valid_ratios=valid_ratios) - # bsz * (seq_len + 1) * num_classes return out_dec[:, 1:, :] # bsz * seq_len * num_classes @@ -395,7 +396,6 @@ class SARHead(nn.Layer): if self.training: label = targets[0] # label - label = paddle.to_tensor(label, dtype='int64') final_out = self.decoder( feat, holistic_feat, label, img_metas=targets) else: diff --git a/tools/train.py b/tools/train.py index 42aba548d6bf5fc35f033ef2baca0fb54d79e75a..b7c25e34231fb650fd2c7c89dc17320f561962f9 100755 --- a/tools/train.py +++ b/tools/train.py @@ -35,6 +35,7 @@ from ppocr.postprocess import build_post_process from ppocr.metrics import build_metric from ppocr.utils.save_load import load_model from ppocr.utils.utility import set_seed +from ppocr.modeling.architectures import apply_to_static import tools.program as program dist.get_world_size() @@ -121,6 +122,8 @@ def main(config, device, logger, vdl_writer): if config['Global']['distributed']: model = paddle.DataParallel(model) + model = apply_to_static(model, config, logger) + # build loss loss_class = build_loss(config['Loss'])