提交 d5c1700f 编写于 作者: A Aurelius84

support to_static for benchmark

上级 0703bb14
...@@ -16,11 +16,14 @@ import copy ...@@ -16,11 +16,14 @@ import copy
import importlib import importlib
import paddle.nn as nn import paddle.nn as nn
from paddle.jit import to_static
from paddle.static import InputSpec
from . import backbone, gears from . import backbone, gears
from .backbone import * from .backbone import *
from .gears import build_gear from .gears import build_gear
from .utils import * from .utils import *
from ppcls.utils import logger
from ppcls.utils.save_load import load_dygraph_pretrain from ppcls.utils.save_load import load_dygraph_pretrain
__all__ = ["build_model", "RecModel", "DistillationModel"] __all__ = ["build_model", "RecModel", "DistillationModel"]
...@@ -34,6 +37,19 @@ def build_model(config): ...@@ -34,6 +37,19 @@ def build_model(config):
return arch return arch
def apply_to_static(config, model):
support_to_static = config['Global'].get('to_static', False)
if support_to_static:
specs = None
if 'image_shape' in config['Global']:
specs = [InputSpec([None] + config['Global']['image_shape'])]
model = to_static(model, input_spec=specs)
logger.info("Successfully to apply @to_static with specs: {}".format(
specs))
return model
class RecModel(nn.Layer): class RecModel(nn.Layer):
def __init__(self, **config): def __init__(self, **config):
super().__init__() super().__init__()
......
...@@ -14,6 +14,8 @@ Global: ...@@ -14,6 +14,8 @@ Global:
# used for static mode and model export # used for static mode and model export
image_shape: [3, 224, 224] image_shape: [3, 224, 224]
save_inference_dir: ./inference save_inference_dir: ./inference
# training model under @to_static
to_static: False
# model architecture # model architecture
Arch: Arch:
......
...@@ -14,6 +14,8 @@ Global: ...@@ -14,6 +14,8 @@ Global:
# used for static mode and model export # used for static mode and model export
image_shape: [3, 224, 224] image_shape: [3, 224, 224]
save_inference_dir: ./inference save_inference_dir: ./inference
# training model under @to_static
to_static: False
# model architecture # model architecture
Arch: Arch:
......
...@@ -14,6 +14,8 @@ Global: ...@@ -14,6 +14,8 @@ Global:
# used for static mode and model export # used for static mode and model export
image_shape: [3, 224, 224] image_shape: [3, 224, 224]
save_inference_dir: ./inference save_inference_dir: ./inference
# training model under @to_static
to_static: False
# model architecture # model architecture
Arch: Arch:
......
...@@ -14,6 +14,8 @@ Global: ...@@ -14,6 +14,8 @@ Global:
# used for static mode and model export # used for static mode and model export
image_shape: [3, 224, 224] image_shape: [3, 224, 224]
save_inference_dir: ./inference save_inference_dir: ./inference
# training model under @to_static
to_static: False
# model architecture # model architecture
Arch: Arch:
......
...@@ -14,6 +14,8 @@ Global: ...@@ -14,6 +14,8 @@ Global:
# used for static mode and model export # used for static mode and model export
image_shape: [3, 224, 224] image_shape: [3, 224, 224]
save_inference_dir: ./inference save_inference_dir: ./inference
# training model under @to_static
to_static: False
# model architecture # model architecture
Arch: Arch:
......
...@@ -34,6 +34,7 @@ from ppcls.utils.logger import init_logger ...@@ -34,6 +34,7 @@ from ppcls.utils.logger import init_logger
from ppcls.utils.config import print_config from ppcls.utils.config import print_config
from ppcls.data import build_dataloader from ppcls.data import build_dataloader
from ppcls.arch import build_model from ppcls.arch import build_model
from ppcls.arch import apply_to_static
from ppcls.loss import build_loss from ppcls.loss import build_loss
from ppcls.metric import build_metrics from ppcls.metric import build_metrics
from ppcls.optimizer import build_optimizer from ppcls.optimizer import build_optimizer
...@@ -73,6 +74,8 @@ class Trainer(object): ...@@ -73,6 +74,8 @@ class Trainer(object):
self.is_rec = False self.is_rec = False
self.model = build_model(self.config["Arch"]) self.model = build_model(self.config["Arch"])
# set @to_static for benchmark, skip this by default.
apply_to_static(self.config, self.model)
if self.config["Global"]["pretrained_model"] is not None: if self.config["Global"]["pretrained_model"] is not None:
load_dygraph_pretrain(self.model, load_dygraph_pretrain(self.model,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册