未验证 提交 0b6290ea 编写于 作者: littletomatodonkey's avatar littletomatodonkey 提交者: GitHub

add support for static training (#6297)

* add support for static training

* fix assert info
上级 0571a41d
...@@ -15,10 +15,13 @@ ...@@ -15,10 +15,13 @@
import copy import copy
import importlib import importlib
from paddle.jit import to_static
from paddle.static import InputSpec
from .base_model import BaseModel from .base_model import BaseModel
from .distillation_model import DistillationModel from .distillation_model import DistillationModel
__all__ = ['build_model'] __all__ = ["build_model", "apply_to_static"]
def build_model(config): def build_model(config):
...@@ -30,3 +33,18 @@ def build_model(config): ...@@ -30,3 +33,18 @@ def build_model(config):
mod = importlib.import_module(__name__) mod = importlib.import_module(__name__)
arch = getattr(mod, name)(config) arch = getattr(mod, name)(config)
return arch return arch
def apply_to_static(model, config, logger):
if config["Global"].get("to_static", False) is not True:
return model
assert "image_shape" in config[
"Global"], "image_shape must be assigned for static training mode..."
supported_list = ["DB"]
assert config["Architecture"][
"algorithm"] in supported_list, f"algorithms that supports static training must in in {supported_list} but got {config['Architecture']['algorithm']}"
specs = [InputSpec([None] + config["Global"]["image_shape"])]
model = to_static(model, input_spec=specs)
logger.info("Successfully to apply @to_static with specs: {}".format(specs))
return model
...@@ -35,6 +35,7 @@ from ppocr.postprocess import build_post_process ...@@ -35,6 +35,7 @@ from ppocr.postprocess import build_post_process
from ppocr.metrics import build_metric from ppocr.metrics import build_metric
from ppocr.utils.save_load import load_model from ppocr.utils.save_load import load_model
from ppocr.utils.utility import set_seed from ppocr.utils.utility import set_seed
from ppocr.modeling.architectures import apply_to_static
import tools.program as program import tools.program as program
dist.get_world_size() dist.get_world_size()
...@@ -121,6 +122,8 @@ def main(config, device, logger, vdl_writer): ...@@ -121,6 +122,8 @@ def main(config, device, logger, vdl_writer):
if config['Global']['distributed']: if config['Global']['distributed']:
model = paddle.DataParallel(model) model = paddle.DataParallel(model)
model = apply_to_static(model, config, logger)
# build loss # build loss
loss_class = build_loss(config['Loss']) loss_class = build_loss(config['Loss'])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册