From 0b6290eaf2c00bbca9fe318fb494e739ca60eec0 Mon Sep 17 00:00:00 2001 From: littletomatodonkey Date: Mon, 16 May 2022 16:57:13 +0800 Subject: [PATCH] add support for static training (#6297) * add support for static training * fix assert info --- ppocr/modeling/architectures/__init__.py | 20 +++++++++++++++++++- tools/train.py | 3 +++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/ppocr/modeling/architectures/__init__.py b/ppocr/modeling/architectures/__init__.py index e9a01cf0..3f47f64a 100755 --- a/ppocr/modeling/architectures/__init__.py +++ b/ppocr/modeling/architectures/__init__.py @@ -15,10 +15,13 @@ import copy import importlib +from paddle.jit import to_static +from paddle.static import InputSpec + from .base_model import BaseModel from .distillation_model import DistillationModel -__all__ = ['build_model'] +__all__ = ["build_model", "apply_to_static"] def build_model(config): @@ -30,3 +33,18 @@ def build_model(config): mod = importlib.import_module(__name__) arch = getattr(mod, name)(config) return arch + + +def apply_to_static(model, config, logger): + if config["Global"].get("to_static", False) is not True: + return model + assert "image_shape" in config[ + "Global"], "image_shape must be assigned for static training mode..." 
+    supported_list = ["DB"]
+    assert config["Architecture"][
+        "algorithm"] in supported_list, f"algorithms that support static training must be in {supported_list} but got {config['Architecture']['algorithm']}"
+
+    specs = [InputSpec([None] + config["Global"]["image_shape"])]
+    model = to_static(model, input_spec=specs)
+    logger.info("Successfully applied @to_static with specs: {}".format(specs))
+    return model
diff --git a/tools/train.py b/tools/train.py
index 42aba548..b7c25e34 100755
--- a/tools/train.py
+++ b/tools/train.py
@@ -35,6 +35,7 @@ from ppocr.postprocess import build_post_process
 from ppocr.metrics import build_metric
 from ppocr.utils.save_load import load_model
 from ppocr.utils.utility import set_seed
+from ppocr.modeling.architectures import apply_to_static
 import tools.program as program
 dist.get_world_size()
@@ -121,6 +122,8 @@ def main(config, device, logger, vdl_writer):
     if config['Global']['distributed']:
         model = paddle.DataParallel(model)
 
+    model = apply_to_static(model, config, logger)
+
     # build loss
     loss_class = build_loss(config['Loss'])
 
-- 
GitLab