未验证 提交 eb27d8b7 编写于 作者: A Aurelius84 提交者: GitHub

[Dy2Stat]Add build_strategy in @to_static to support open pass (#34347)

* Add build_strategy in @to_static to support open pass

* fix os.environ

* add timeout

* disable test_build_strategy on openblas
上级 cf12ea51
...@@ -131,12 +131,16 @@ class PartialProgramLayer: ...@@ -131,12 +131,16 @@ class PartialProgramLayer:
Layer: A Layer object that run all ops internally in static mode. Layer: A Layer object that run all ops internally in static mode.
""" """
def __init__(self, main_program, inputs, outputs, parameters=None): def __init__(self, main_program, inputs, outputs, parameters=None,
**kwargs):
super(PartialProgramLayer, self).__init__() super(PartialProgramLayer, self).__init__()
self._inputs = NestSequence(inputs) self._inputs = NestSequence(inputs)
self._outputs = NestSequence(outputs, need_check=True) self._outputs = NestSequence(outputs, need_check=True)
self._params = parameters if parameters is not None else [] self._params = parameters if parameters is not None else []
self._build_strategy = kwargs.get('build_strategy', BuildStrategy())
assert isinstance(self._build_strategy, BuildStrategy)
self._origin_main_program = self._verify_program(main_program) self._origin_main_program = self._verify_program(main_program)
self._tmp_scope_vec = self._create_scope_vec() self._tmp_scope_vec = self._create_scope_vec()
# A fake_var to handle empty input or output # A fake_var to handle empty input or output
...@@ -170,7 +174,11 @@ class PartialProgramLayer: ...@@ -170,7 +174,11 @@ class PartialProgramLayer:
@LazyInitialized @LazyInitialized
def _train_program_id(self): def _train_program_id(self):
return _hash_with_id(self._train_program, self) program_id = _hash_with_id(self._train_program, self)
core._set_cached_executor_build_strategy(program_id,
self._build_strategy)
return program_id
def _verify_program(self, main_program): def _verify_program(self, main_program):
""" """
...@@ -451,6 +459,6 @@ def partial_program_from(concrete_program): ...@@ -451,6 +459,6 @@ def partial_program_from(concrete_program):
if inputs and isinstance(inputs[0], layers.Layer): if inputs and isinstance(inputs[0], layers.Layer):
inputs = inputs[1:] inputs = inputs[1:]
return PartialProgramLayer(concrete_program.main_program, inputs, return PartialProgramLayer(
concrete_program.outputs, concrete_program.main_program, inputs, concrete_program.outputs,
concrete_program.parameters) concrete_program.parameters, **concrete_program.kwargs)
...@@ -145,14 +145,13 @@ class CacheKey(object): ...@@ -145,14 +145,13 @@ class CacheKey(object):
""" """
Cached key for ProgramCache. Cached key for ProgramCache.
""" """
__slots__ = [ __slots__ = [
'function_spec', 'input_args_with_spec', 'input_kwargs_with_spec', 'function_spec', 'input_args_with_spec', 'input_kwargs_with_spec',
'class_instance' 'class_instance', 'kwargs'
] ]
def __init__(self, function_spec, input_args_with_spec, def __init__(self, function_spec, input_args_with_spec,
input_kwargs_with_spec, class_instance): input_kwargs_with_spec, class_instance, **kwargs):
""" """
Initializes a cache key. Initializes a cache key.
...@@ -161,11 +160,14 @@ class CacheKey(object): ...@@ -161,11 +160,14 @@ class CacheKey(object):
input_args_with_spec(list[InputSpec]): actual input args with some arguments replaced by InputSpec. input_args_with_spec(list[InputSpec]): actual input args with some arguments replaced by InputSpec.
input_kwargs_with_spec(list[{string:InputSpec}]): actual input kwargs with some arguments replaced by InputSpec. input_kwargs_with_spec(list[{string:InputSpec}]): actual input kwargs with some arguments replaced by InputSpec.
class_instance(object): an instance of class `Layer`. class_instance(object): an instance of class `Layer`.
**kwargs(dict): manage other arguments used for better scalability
""" """
self.function_spec = function_spec self.function_spec = function_spec
self.input_args_with_spec = input_args_with_spec self.input_args_with_spec = input_args_with_spec
self.input_kwargs_with_spec = input_kwargs_with_spec self.input_kwargs_with_spec = input_kwargs_with_spec
self.class_instance = class_instance self.class_instance = class_instance
# NOTE: `kwargs` is usually not considered as basic member for `__hash__`
self.kwargs = kwargs
@classmethod @classmethod
def from_func_and_args(cls, function_spec, args, kwargs, class_instance): def from_func_and_args(cls, function_spec, args, kwargs, class_instance):
...@@ -235,13 +237,14 @@ class StaticFunction(object): ...@@ -235,13 +237,14 @@ class StaticFunction(object):
""" """
def __init__(self, function, input_spec=None): def __init__(self, function, input_spec=None, **kwargs):
""" """
Initializes a `StaticFunction`. Initializes a `StaticFunction`.
Args: Args:
function(callable): A function or method that will be converted into static program. function(callable): A function or method that will be converted into static program.
input_spec(list[InputSpec]): list of InputSpec to specify the `shape/dtype/name` information for each input argument, default None. input_spec(list[InputSpec]): list of InputSpec to specify the `shape/dtype/name` information for each input argument, default None.
**kwargs(dict): other arguments such as `build_strategy`, etc.
""" """
# save the instance `self` while decorating a method of class. # save the instance `self` while decorating a method of class.
if inspect.ismethod(function): if inspect.ismethod(function):
...@@ -257,6 +260,7 @@ class StaticFunction(object): ...@@ -257,6 +260,7 @@ class StaticFunction(object):
self._descriptor_cache = weakref.WeakKeyDictionary() self._descriptor_cache = weakref.WeakKeyDictionary()
# Note: Hold a reference to ProgramTranslator for switching `enable_to_static`. # Note: Hold a reference to ProgramTranslator for switching `enable_to_static`.
self._program_trans = ProgramTranslator() self._program_trans = ProgramTranslator()
self._kwargs = kwargs
def __get__(self, instance, owner): def __get__(self, instance, owner):
""" """
...@@ -395,7 +399,8 @@ class StaticFunction(object): ...@@ -395,7 +399,8 @@ class StaticFunction(object):
# 2. generate cache key # 2. generate cache key
cache_key = CacheKey(self._function_spec, input_args_with_spec, cache_key = CacheKey(self._function_spec, input_args_with_spec,
input_kwargs_with_spec, self._class_instance) input_kwargs_with_spec, self._class_instance,
**self._kwargs)
# 3. check whether hit the cache or build a new program for the input arguments # 3. check whether hit the cache or build a new program for the input arguments
concrete_program, partial_program_layer = self._program_cache[cache_key] concrete_program, partial_program_layer = self._program_cache[cache_key]
...@@ -586,7 +591,7 @@ class ConcreteProgram(object): ...@@ -586,7 +591,7 @@ class ConcreteProgram(object):
__slots__ = [ __slots__ = [
'inputs', 'outputs', 'main_program', "startup_program", "parameters", 'inputs', 'outputs', 'main_program', "startup_program", "parameters",
"function" "function", 'kwargs'
] ]
def __init__(self, def __init__(self,
...@@ -595,18 +600,20 @@ class ConcreteProgram(object): ...@@ -595,18 +600,20 @@ class ConcreteProgram(object):
parameters, parameters,
function, function,
main_program, main_program,
startup_program=None): startup_program=None,
**kwargs):
self.inputs = inputs self.inputs = inputs
self.outputs = outputs self.outputs = outputs
self.main_program = main_program self.main_program = main_program
self.startup_program = startup_program self.startup_program = startup_program
self.parameters = parameters self.parameters = parameters
self.function = function self.function = function
self.kwargs = kwargs
@staticmethod @staticmethod
@switch_to_static_graph @switch_to_static_graph
def from_func_spec(func_spec, input_spec, input_kwargs_spec, def from_func_spec(func_spec, input_spec, input_kwargs_spec, class_instance,
class_instance): **kwargs):
""" """
Builds the main_program with specialized inputs and returns outputs Builds the main_program with specialized inputs and returns outputs
of program as fetch_list. of program as fetch_list.
...@@ -635,8 +642,8 @@ class ConcreteProgram(object): ...@@ -635,8 +642,8 @@ class ConcreteProgram(object):
# 1. Adds `fluid.data` layers for input if needed # 1. Adds `fluid.data` layers for input if needed
inputs = func_spec.to_static_inputs_with_spec(input_spec, inputs = func_spec.to_static_inputs_with_spec(input_spec,
main_program) main_program)
kwargs = func_spec.to_static_inputs_with_spec(input_kwargs_spec, _kwargs = func_spec.to_static_inputs_with_spec(
main_program) input_kwargs_spec, main_program)
if class_instance: if class_instance:
inputs = tuple([class_instance] + list(inputs)) inputs = tuple([class_instance] + list(inputs))
...@@ -649,8 +656,8 @@ class ConcreteProgram(object): ...@@ -649,8 +656,8 @@ class ConcreteProgram(object):
class_instance, False)), param_guard( class_instance, False)), param_guard(
get_buffers(class_instance, False)): get_buffers(class_instance, False)):
try: try:
if kwargs: if _kwargs:
outputs = static_func(*inputs, **kwargs) outputs = static_func(*inputs, **_kwargs)
else: else:
outputs = static_func(*inputs) outputs = static_func(*inputs)
except BaseException as e: except BaseException as e:
...@@ -675,7 +682,8 @@ class ConcreteProgram(object): ...@@ -675,7 +682,8 @@ class ConcreteProgram(object):
parameters=all_parameters_and_buffers, parameters=all_parameters_and_buffers,
function=dygraph_function, function=dygraph_function,
main_program=main_program, main_program=main_program,
startup_program=startup_program) startup_program=startup_program,
**kwargs)
def _extract_indeed_params_buffers(class_instance): def _extract_indeed_params_buffers(class_instance):
...@@ -702,7 +710,8 @@ class ProgramCache(object): ...@@ -702,7 +710,8 @@ class ProgramCache(object):
func_spec=cache_key.function_spec, func_spec=cache_key.function_spec,
input_spec=cache_key.input_args_with_spec, input_spec=cache_key.input_args_with_spec,
input_kwargs_spec=cache_key.input_kwargs_with_spec, input_kwargs_spec=cache_key.input_kwargs_with_spec,
class_instance=cache_key.class_instance) class_instance=cache_key.class_instance,
**cache_key.kwargs)
return concrete_program, partial_program_from(concrete_program) return concrete_program, partial_program_from(concrete_program)
def __getitem__(self, item): def __getitem__(self, item):
......
...@@ -158,7 +158,7 @@ def copy_decorator_attrs(original_func, decorated_obj): ...@@ -158,7 +158,7 @@ def copy_decorator_attrs(original_func, decorated_obj):
return decorated_obj return decorated_obj
def declarative(function=None, input_spec=None): def declarative(function=None, input_spec=None, build_strategy=None):
""" """
Converts imperative dygraph APIs into declarative function APIs. Decorator Converts imperative dygraph APIs into declarative function APIs. Decorator
@declarative handles the Program and Executor of static mode and returns @declarative handles the Program and Executor of static mode and returns
...@@ -171,6 +171,12 @@ def declarative(function=None, input_spec=None): ...@@ -171,6 +171,12 @@ def declarative(function=None, input_spec=None):
function (callable): callable imperative function. function (callable): callable imperative function.
input_spec(list[InputSpec]|tuple[InputSpec]): list/tuple of InputSpec to specify the shape/dtype/name input_spec(list[InputSpec]|tuple[InputSpec]): list/tuple of InputSpec to specify the shape/dtype/name
information of each input Tensor. information of each input Tensor.
build_strategy(BuildStrategy|None): This argument is used to compile the
converted program with the specified options, such as operators' fusion
in the computational graph and memory optimization during the execution
of the computational graph. For more information about build_strategy,
please refer to :code:`paddle.static.BuildStrategy`. The default is None.
Returns: Returns:
Tensor(s): containing the numerical result. Tensor(s): containing the numerical result.
...@@ -206,10 +212,18 @@ def declarative(function=None, input_spec=None): ...@@ -206,10 +212,18 @@ def declarative(function=None, input_spec=None):
static_layer = copy_decorator_attrs( static_layer = copy_decorator_attrs(
original_func=python_func, original_func=python_func,
decorated_obj=StaticFunction( decorated_obj=StaticFunction(
function=python_func, input_spec=input_spec)) function=python_func,
input_spec=input_spec,
build_strategy=build_strategy))
return static_layer return static_layer
build_strategy = build_strategy or BuildStrategy()
if not isinstance(build_strategy, BuildStrategy):
raise TypeError(
"Required type(build_strategy) shall be `paddle.static.BuildStrategy`, but received {}".
format(type(build_strategy).__name__))
# for usage: `declarative(foo, ...)` # for usage: `declarative(foo, ...)`
if function is not None: if function is not None:
if isinstance(function, Layer): if isinstance(function, Layer):
......
...@@ -25,6 +25,7 @@ set_tests_properties(test_reinforcement_learning PROPERTIES TIMEOUT 120) ...@@ -25,6 +25,7 @@ set_tests_properties(test_reinforcement_learning PROPERTIES TIMEOUT 120)
set_tests_properties(test_transformer PROPERTIES TIMEOUT 200) set_tests_properties(test_transformer PROPERTIES TIMEOUT 200)
set_tests_properties(test_bmn PROPERTIES TIMEOUT 120) set_tests_properties(test_bmn PROPERTIES TIMEOUT 120)
#set_tests_properties(test_mnist PROPERTIES TIMEOUT 120) #set_tests_properties(test_mnist PROPERTIES TIMEOUT 120)
set_tests_properties(test_build_strategy PROPERTIES TIMEOUT 120)
if(NOT WIN32) if(NOT WIN32)
set_tests_properties(test_resnet_v2 PROPERTIES TIMEOUT 120) set_tests_properties(test_resnet_v2 PROPERTIES TIMEOUT 120)
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import paddle
import unittest
import numpy as np
from paddle.jit import ProgramTranslator
from test_resnet import ResNet, train, predict_dygraph_jit
from test_resnet import predict_dygraph, predict_static, predict_analysis_inference
program_translator = ProgramTranslator()
class TestResnetWithPass(unittest.TestCase):
    """Runs the ResNet helpers from `test_resnet` with a pass-enabled BuildStrategy."""

    def setUp(self):
        # Build a strategy with the fusion / addto switches turned on and
        # keep it on the test case for the individual tests to use.
        strategy = paddle.static.BuildStrategy()
        self.build_strategy = strategy
        strategy.fuse_elewise_add_act_ops = True
        strategy.fuse_bn_act_ops = True
        strategy.fuse_bn_add_act_ops = True
        strategy.enable_addto = True
        # NOTE: for enable_addto
        paddle.fluid.set_flags({"FLAGS_max_inplace_grad_add": 8})

    def train(self, to_static):
        """Toggle the program translator, then run the shared `train` helper."""
        program_translator.enable(to_static)
        result = train(to_static, self.build_strategy)
        return result

    def verify_predict(self):
        """Compare the static prediction against the other three predictors."""
        image = np.random.random([1, 3, 224, 224]).astype('float32')

        dy_pre = predict_dygraph(image)
        st_pre = predict_static(image)
        dy_jit_pre = predict_dygraph_jit(image)
        predictor_pre = predict_analysis_inference(image)

        dy_matches = np.allclose(dy_pre, st_pre)
        dy_msg = "dy_pre:\n {}\n, st_pre: \n{}.".format(dy_pre, st_pre)
        self.assertTrue(dy_matches, msg=dy_msg)

        jit_matches = np.allclose(dy_jit_pre, st_pre)
        jit_msg = "dy_jit_pre:\n {}\n, st_pre: \n{}.".format(dy_jit_pre,
                                                             st_pre)
        self.assertTrue(jit_matches, msg=jit_msg)

        predictor_matches = np.allclose(predictor_pre, st_pre)
        predictor_msg = "predictor_pre:\n {}\n, st_pre: \n{}.".format(
            predictor_pre, st_pre)
        self.assertTrue(predictor_matches, msg=predictor_msg)

    def test_resnet(self):
        # Train once with the translator enabled and once disabled; the
        # two results are expected to agree within np.allclose tolerance.
        static_loss = self.train(to_static=True)
        dygraph_loss = self.train(to_static=False)

        losses_match = np.allclose(static_loss, dygraph_loss)
        failure_msg = "static_loss: {} \n dygraph_loss: {}".format(
            static_loss, dygraph_loss)
        self.assertTrue(losses_match, msg=failure_msg)

        self.verify_predict()

    def test_in_static_mode_mkldnn(self):
        set_flags = paddle.fluid.set_flags
        set_flags({'FLAGS_use_mkldnn': True})
        try:
            # Nothing to train when the build has no MKL-DNN support; the
            # `finally` clause still resets the flag on this early exit.
            if not paddle.fluid.core.is_compiled_with_mkldnn():
                return
            train(True, self.build_strategy)
        finally:
            set_flags({'FLAGS_use_mkldnn': False})
class TestError(unittest.TestCase):
    """Error-path checks for the `build_strategy` argument of `paddle.jit.to_static`."""

    def test_type_error(self):
        """A `build_strategy` that is a plain string must raise `TypeError`."""

        def foo(x):
            out = x + 1
            return out

        # Only the raised exception matters here, so the return value of
        # `to_static` is intentionally not bound to a name.
        with self.assertRaises(TypeError):
            paddle.jit.to_static(foo, build_strategy="x")
# Run the test cases in this module when it is executed as a script.
if __name__ == "__main__":
    unittest.main()
...@@ -190,7 +190,6 @@ class ResNet(fluid.dygraph.Layer): ...@@ -190,7 +190,6 @@ class ResNet(fluid.dygraph.Layer):
param_attr=fluid.param_attr.ParamAttr( param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.Uniform(-stdv, stdv))) initializer=fluid.initializer.Uniform(-stdv, stdv)))
@declarative
def forward(self, inputs): def forward(self, inputs):
y = self.conv(inputs) y = self.conv(inputs)
y = self.pool2d_max(y) y = self.pool2d_max(y)
...@@ -213,7 +212,7 @@ def reader_decorator(reader): ...@@ -213,7 +212,7 @@ def reader_decorator(reader):
return __reader__ return __reader__
def train(to_static): def train(to_static, build_strategy=None):
""" """
Tests model decorated by `dygraph_to_static_output` in static mode. For users, the model is defined in dygraph mode and trained in static mode. Tests model decorated by `dygraph_to_static_output` in static mode. For users, the model is defined in dygraph mode and trained in static mode.
""" """
...@@ -231,6 +230,8 @@ def train(to_static): ...@@ -231,6 +230,8 @@ def train(to_static):
data_loader.set_sample_list_generator(train_reader) data_loader.set_sample_list_generator(train_reader)
resnet = ResNet() resnet = ResNet()
if to_static:
resnet = paddle.jit.to_static(resnet, build_strategy=build_strategy)
optimizer = optimizer_setting(parameter_list=resnet.parameters()) optimizer = optimizer_setting(parameter_list=resnet.parameters())
for epoch in range(epoch_num): for epoch in range(epoch_num):
......
...@@ -96,6 +96,7 @@ disable_wincpu_test="^jit_kernel_test$|\ ...@@ -96,6 +96,7 @@ disable_wincpu_test="^jit_kernel_test$|\
^test_bmn$|\ ^test_bmn$|\
^test_mobile_net$|\ ^test_mobile_net$|\
^test_resnet_v2$|\ ^test_resnet_v2$|\
^test_build_strategy$|\
^test_se_resnet$|\ ^test_se_resnet$|\
^disable_wincpu_test$" ^disable_wincpu_test$"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册