提交 d5c1700f 编写于 作者: A Aurelius84

support to_static for benchmark

上级 0703bb14
......@@ -16,11 +16,14 @@ import copy
import importlib
import paddle.nn as nn
from paddle.jit import to_static
from paddle.static import InputSpec
from . import backbone, gears
from .backbone import *
from .gears import build_gear
from .utils import *
from ppcls.utils import logger
from ppcls.utils.save_load import load_dygraph_pretrain
__all__ = ["build_model", "RecModel", "DistillationModel"]
......@@ -34,6 +37,19 @@ def build_model(config):
return arch
def apply_to_static(config, model):
support_to_static = config['Global'].get('to_static', False)
if support_to_static:
specs = None
if 'image_shape' in config['Global']:
specs = [InputSpec([None] + config['Global']['image_shape'])]
model = to_static(model, input_spec=specs)
logger.info("Successfully to apply @to_static with specs: {}".format(
specs))
return model
class RecModel(nn.Layer):
def __init__(self, **config):
super().__init__()
......
......@@ -14,6 +14,8 @@ Global:
# used for static mode and model export
image_shape: [3, 224, 224]
save_inference_dir: ./inference
# training model under @to_static
to_static: False
# model architecture
Arch:
......
......@@ -14,6 +14,8 @@ Global:
# used for static mode and model export
image_shape: [3, 224, 224]
save_inference_dir: ./inference
# training model under @to_static
to_static: False
# model architecture
Arch:
......
......@@ -14,6 +14,8 @@ Global:
# used for static mode and model export
image_shape: [3, 224, 224]
save_inference_dir: ./inference
# training model under @to_static
to_static: False
# model architecture
Arch:
......
......@@ -14,6 +14,8 @@ Global:
# used for static mode and model export
image_shape: [3, 224, 224]
save_inference_dir: ./inference
# training model under @to_static
to_static: False
# model architecture
Arch:
......
......@@ -14,6 +14,8 @@ Global:
# used for static mode and model export
image_shape: [3, 224, 224]
save_inference_dir: ./inference
# training model under @to_static
to_static: False
# model architecture
Arch:
......
......@@ -34,6 +34,7 @@ from ppcls.utils.logger import init_logger
from ppcls.utils.config import print_config
from ppcls.data import build_dataloader
from ppcls.arch import build_model
from ppcls.arch import apply_to_static
from ppcls.loss import build_loss
from ppcls.metric import build_metrics
from ppcls.optimizer import build_optimizer
......@@ -73,6 +74,8 @@ class Trainer(object):
self.is_rec = False
self.model = build_model(self.config["Arch"])
# set @to_static for benchmark, skip this by default.
apply_to_static(self.config, self.model)
if self.config["Global"]["pretrained_model"] is not None:
load_dygraph_pretrain(self.model,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册