diff --git a/ppocr/modeling/architectures/__init__.py b/ppocr/modeling/architectures/__init__.py index e9a01cf0281b91d29f2cce88375be3aaf43feb2e..3f47f64a0f01c2267c7ff4aecb3815915b24dadb 100755 --- a/ppocr/modeling/architectures/__init__.py +++ b/ppocr/modeling/architectures/__init__.py @@ -15,10 +15,13 @@ import copy import importlib +from paddle.jit import to_static +from paddle.static import InputSpec + from .base_model import BaseModel from .distillation_model import DistillationModel -__all__ = ['build_model'] +__all__ = ["build_model", "apply_to_static"] def build_model(config): @@ -30,3 +33,18 @@ def build_model(config): mod = importlib.import_module(__name__) arch = getattr(mod, name)(config) return arch + + +def apply_to_static(model, config, logger): + if config["Global"].get("to_static", False) is not True: + return model + assert "image_shape" in config[ + "Global"], "image_shape must be assigned for static training mode..." + supported_list = ["DB"] + assert config["Architecture"][ + "algorithm"] in supported_list, f"algorithms that supports static training must in in {supported_list} but got {config['Architecture']['algorithm']}" + + specs = [InputSpec([None] + config["Global"]["image_shape"])] + model = to_static(model, input_spec=specs) + logger.info("Successfully to apply @to_static with specs: {}".format(specs)) + return model diff --git a/tools/train.py b/tools/train.py index 42aba548d6bf5fc35f033ef2baca0fb54d79e75a..b7c25e34231fb650fd2c7c89dc17320f561962f9 100755 --- a/tools/train.py +++ b/tools/train.py @@ -35,6 +35,7 @@ from ppocr.postprocess import build_post_process from ppocr.metrics import build_metric from ppocr.utils.save_load import load_model from ppocr.utils.utility import set_seed +from ppocr.modeling.architectures import apply_to_static import tools.program as program dist.get_world_size() @@ -121,6 +122,8 @@ def main(config, device, logger, vdl_writer): if config['Global']['distributed']: model = paddle.DataParallel(model) + model = apply_to_static(model, config, logger) + # build loss loss_class = build_loss(config['Loss'])