From 3e66845f017033669bccca6a385cde64937544bb Mon Sep 17 00:00:00 2001
From: WangZhen <23097963+0x45f@users.noreply.github.com>
Date: Tue, 11 Apr 2023 11:16:53 +0800
Subject: [PATCH] [Dy2St]Add backend for to_static API (#52596)

* Add backend for to_static API
---
 .../fluid/tests/unittests/test_input_spec.py |  2 +-
 python/paddle/jit/api.py                     | 24 ++++++-
 .../paddle/jit/dy2static/partial_program.py  | 13 ++--
 .../jit/dy2static/program_translator.py      | 68 +++++++++++--------
 python/paddle/jit/dy2static/utils.py         | 21 +++++-
 test/dygraph_to_static/test_cinn_prim.py     | 15 ++++
 .../test_partial_program_hook.py             |  2 +-
 7 files changed, 105 insertions(+), 40 deletions(-)

diff --git a/python/paddle/fluid/tests/unittests/test_input_spec.py b/python/paddle/fluid/tests/unittests/test_input_spec.py
index dad821438af..2bdce8b4b58 100644
--- a/python/paddle/fluid/tests/unittests/test_input_spec.py
+++ b/python/paddle/fluid/tests/unittests/test_input_spec.py
@@ -349,7 +349,7 @@ class TestNegSpecWithPrim(unittest.TestCase):
         )
         x = paddle.randn([2, 10])
         out = net(x)
-        np.testing.assert_equal(out.shape, [2, 5])
+        np.testing.assert_equal(net.forward._input_spec, None)
 
 
 if __name__ == '__main__':
diff --git a/python/paddle/jit/api.py b/python/paddle/jit/api.py
index bc07609a111..bde75f6ad73 100644
--- a/python/paddle/jit/api.py
+++ b/python/paddle/jit/api.py
@@ -218,8 +218,23 @@ def ignore_module(modules: list[Any]):
     add_ignore_module(modules)
 
 
+def _check_and_set_backend(backend, build_strategy):
+    if backend not in ['CINN', None]:
+        raise ValueError(
+            "The backend of to_static should be 'CINN' or None, but received {}.".format(
+                backend
+            )
+        )
+    if backend == 'CINN':
+        build_strategy.build_cinn_pass = True
+
+
 def to_static(
-    function=None, input_spec=None, build_strategy=None, property=False
+    function=None,
+    input_spec=None,
+    build_strategy=None,
+    backend=None,
+    **kwargs,
 ):
     """
     Converts imperative dygraph APIs into declarative function APIs. Decorator
@@ -228,7 +243,6 @@ def to_static(
     @to_static handles the Program and Executor of static graph mode and returns
     the result as dygraph Tensor(s) to do imperative training, inference, or
     other operations. If the decorated function calls other imperative function,
     the called one will be converted into declarative function as well.
-
     Args:
         function (callable): callable imperative function.
         input_spec(list[InputSpec]|tuple[InputSpec]): list/tuple of InputSpec to specific the shape/dtype/name
@@ -238,7 +252,8 @@ def to_static(
         in the computational graph and memory optimization during the execution
         of the computational graph. For more information about build_strategy,
         please refer to :code:`paddle.static.BuildStrategy`. The default is None.
-        property(bool, Optional): whether the fucntion is python property. The default is False.
+        backend(str, Optional): Specifies the compilation backend, which can be `CINN` or None. When backend is `CINN`, the CINN compiler will be used to speed up training and inference.
+        kwargs: Supported keys include `property`. Set `property` to True if the function is a Python property.
 
 
     Returns:
@@ -263,6 +278,7 @@ def to_static(
             print(x_v) # [[2. 2.]]
     """
+    property = kwargs.get("property", False)
 
     def decorated(python_func):
         """
@@ -279,6 +295,7 @@ def to_static(
                 input_spec=input_spec,
                 build_strategy=build_strategy,
                 property=property,
+                backend=backend,
             ),
         )
 
@@ -291,6 +308,7 @@ def to_static(
                     type(build_strategy).__name__
                 )
             )
+    _check_and_set_backend(backend, build_strategy)
 
     # for usage: `to_static(foo, ...)`
     if function is not None:
diff --git a/python/paddle/jit/dy2static/partial_program.py b/python/paddle/jit/dy2static/partial_program.py
index 9538bb93007..7a6afc82b1b 100644
--- a/python/paddle/jit/dy2static/partial_program.py
+++ b/python/paddle/jit/dy2static/partial_program.py
@@ -27,7 +27,12 @@ from paddle.fluid.framework import _apply_pass
 from paddle.optimizer.lr import LRScheduler
 
 from . import logging_utils
-from .utils import RETURN_NO_VALUE_MAGIC_NUM, _out_grad_names, _param_grad_names
+from .utils import (
+    RETURN_NO_VALUE_MAGIC_NUM,
+    _out_grad_names,
+    _param_grad_names,
+    backend_guard,
+)
 
 __all__ = []
@@ -197,6 +202,7 @@ class PartialProgramLayer:
         # program_id -> list(scope)
         self._scope_cache = {}
         self._hooker = None
+        self._backend = kwargs.get('backend', None)
 
     def __call__(self, inputs):
         """
@@ -636,10 +642,9 @@ class PartialProgramLayer:
             start_idx = len(program.block(0).ops) + len(self._outputs.tolist())
 
         if targets:
-            # TODO(CZ): later when use cinn, set_prim_all_enabled and check_and_set_prim_all_enabled will be set at else branch.
-            core.check_and_set_prim_all_enabled()
             start_idx = len(program.block(0).ops) + len(self._outputs.tolist())
-            backward.gradients(targets=targets, inputs=[])
+            with backend_guard(self._backend):
+                backward.gradients(targets=targets, inputs=[])
 
         if self._hooker:
             program, start_idx = self._hooker.after_append_backward(
diff --git a/python/paddle/jit/dy2static/program_translator.py b/python/paddle/jit/dy2static/program_translator.py
index 3777af8879d..a8be1abb2a1 100644
--- a/python/paddle/jit/dy2static/program_translator.py
+++ b/python/paddle/jit/dy2static/program_translator.py
@@ -48,6 +48,7 @@ from .utils import (
     NO_SHAPE_VAR_TYPE,
     ast_to_func,
     ast_to_source_code,
+    backend_guard,
     func_to_source_code,
     input_specs_compatible,
     is_paddle_func,
@@ -334,7 +335,7 @@ class StaticFunction:
             self._class_instance = None
 
         if input_spec is not None and prim_or_cinn_is_enabled(
-            kwargs.get("build_strategy", None)
+            kwargs.get("build_strategy", None), kwargs.get("backend", None)
         ):
             from paddle.static import InputSpec
@@ -1184,11 +1185,9 @@ class ProgramCache:
     def _build_once(self, cache_key):
         # TODO(Aurelius84): Need a gloabl FLAGS to enable/disable to_prim
         enable_prim = cache_key.kwargs['build_strategy'].build_cinn_pass
-        # TODO(CZ): later when use cinn, set_prim_all_enabled and check_and_set_prim_all_enabled will be set at else branch.
 
         # NOTE(xiongkun): Need a global FLAGS to enable/disable fallback
         enable_fallback = enable_prim
-        core.check_and_set_prim_all_enabled()
         try:
             concrete_program = ConcreteProgram.from_func_spec(
                 func_spec=cache_key.function_spec,
@@ -1216,7 +1215,8 @@ class ProgramCache:
             else:
                 raise
 
-        if prim_or_cinn_is_enabled(cache_key.kwargs['build_strategy']):
+        backend = cache_key.kwargs['backend']
+        if prim_or_cinn_is_enabled(cache_key.kwargs['build_strategy'], backend):
             for var in concrete_program.main_program.list_vars():
                 if var.type not in NO_SHAPE_VAR_TYPE and -1 in var.shape:
                     warnings.warn(
@@ -1228,10 +1228,11 @@ class ProgramCache:
         partial_program = partial_program_from(
             concrete_program, cache_key.class_instance is not None
         )
-        if core._is_fwd_prim_enabled():
-            partial_program.set_hooker(
-                PrimHooker(concrete_program.main_program)
-            )
+        with backend_guard(backend):
+            if core._is_fwd_prim_enabled():
+                partial_program.set_hooker(
+                    PrimHooker(concrete_program.main_program, backend)
+                )
         return concrete_program, partial_program
 
     def __getitem__(self, item):
@@ -1291,39 +1292,46 @@ class ProgramCache:
 
 
 class PrimHooker(PartialProgramLayerHook):
-    def __init__(self, original_program):
+    def __init__(self, original_program, backend):
         if len(original_program.blocks) > 1:
             raise ValueError(
                 'The primitive mode only support one block currently.'
             )
+        self.backend = backend
         self.custom_vjps = set()
-        if core._is_all_prim_enabled():
-            self.custom_vjps = {
-                op.type
-                for op in original_program.block(0).ops
-                if core.has_comp_grad_op_maker(op.type)
-            }
+        with backend_guard(self.backend):
+            if core._is_all_prim_enabled():
+                self.custom_vjps = {
+                    op.type
+                    for op in original_program.block(0).ops
+                    if core.has_comp_grad_op_maker(op.type)
+                }
 
     def before_append_backward(self, forward_program):
-        if core._is_fwd_prim_enabled():
-            _to_prim(forward_program.blocks, blacklist=self.custom_vjps)
-        return forward_program
+        with backend_guard(self.backend):
+            if core._is_fwd_prim_enabled():
+                _to_prim(forward_program.blocks, blacklist=self.custom_vjps)
+            return forward_program
 
     def after_append_backward(self, whole_program, backward_start_idx):
-        backward_length = len(whole_program.block(0).ops) - backward_start_idx
-        if core._is_fwd_prim_enabled() and len(self.custom_vjps) != 0:
-            # only process backward part of block
-            _to_prim(whole_program.blocks, backward_length=backward_length)
-        new_start_index = len(whole_program.block(0).ops) - backward_length
-        if backward_length > 0:
-            # only process forward part of block
-            _to_prim(whole_program.blocks, start_idx=new_start_index)
-        return whole_program, new_start_index
+        with backend_guard(self.backend):
+            backward_length = (
+                len(whole_program.block(0).ops) - backward_start_idx
+            )
+            if core._is_fwd_prim_enabled() and len(self.custom_vjps) != 0:
+                # only process backward part of block
+                _to_prim(whole_program.blocks, backward_length=backward_length)
+            new_start_index = len(whole_program.block(0).ops) - backward_length
+            if backward_length > 0:
+                # only process forward part of block
+                _to_prim(whole_program.blocks, start_idx=new_start_index)
+            return whole_program, new_start_index
 
     def after_infer(self, infer_program):
-        if core._is_fwd_prim_enabled():
-            _to_prim(infer_program.block(0))
-        return infer_program
+        with backend_guard(self.backend):
+            if core._is_fwd_prim_enabled():
+                _to_prim(infer_program.block(0))
+            return infer_program
 
 
 class ProgramTranslator:
diff --git a/python/paddle/jit/dy2static/utils.py b/python/paddle/jit/dy2static/utils.py
index 3608b8d0641..28c8c739f2e 100644
--- a/python/paddle/jit/dy2static/utils.py
+++ b/python/paddle/jit/dy2static/utils.py
@@ -35,6 +35,7 @@ from paddle import fluid  # noqa: F401
 from paddle.fluid import core, unique_name
 from paddle.fluid.data_feeder import convert_dtype
 from paddle.fluid.layer_helper import LayerHelper
+from paddle.fluid.wrapped_decorator import signature_safe_contextmanager
 from paddle.utils import gast
 
 from .ast_utils import ast_to_source_code
@@ -1498,7 +1499,10 @@ def _out_grad_names(program_desc, fwd_end_op_index, out_size):
     return names
 
 
-def prim_or_cinn_is_enabled(build_strategy):
+def prim_or_cinn_is_enabled(build_strategy, backend):
+    if backend == 'CINN':
+        return True
+
     if build_strategy is not None and build_strategy.build_cinn_pass:
         return True
 
@@ -1534,3 +1538,18 @@ def is_builtin(func, name=None):
         return True
     else:
         return False
+
+
+@signature_safe_contextmanager
+def backend_guard(backend):
+    core.check_and_set_prim_all_enabled()
+    origin_fwd = core._is_fwd_prim_enabled()
+    origin_bwd = core._is_bwd_prim_enabled()
+
+    if backend == 'CINN':
+        core._set_prim_all_enabled(True)
+    try:
+        yield
+    finally:
+        core._set_prim_forward_enabled(origin_fwd)
+        core._set_prim_backward_enabled(origin_bwd)
diff --git a/test/dygraph_to_static/test_cinn_prim.py b/test/dygraph_to_static/test_cinn_prim.py
index 6ace7696c38..c5527e85238 100644
--- a/test/dygraph_to_static/test_cinn_prim.py
+++ b/test/dygraph_to_static/test_cinn_prim.py
@@ -163,5 +163,20 @@ class TestPrimForwardAndBackward(unittest.TestCase):
         )
 
 
+class TestBackend(unittest.TestCase):
+    def test_backend(self):
+        x = paddle.randn([2, 4])
+        out1 = self.forward(x, 'CINN')
+        out2 = self.forward(x, None)
+        np.testing.assert_allclose(out1, out2, rtol=1e-6)
+
+    def forward(self, x, backend=None):
+        paddle.seed(2022)
+        net = PrimeNet()
+        net = paddle.jit.to_static(net, backend=backend)
+        out = net(x)
+        return out
+
+
 if __name__ == '__main__':
     unittest.main()
diff --git a/test/dygraph_to_static/test_partial_program_hook.py b/test/dygraph_to_static/test_partial_program_hook.py
index 896dde419bf..b9a64d3d099 100644
--- a/test/dygraph_to_static/test_partial_program_hook.py
+++ b/test/dygraph_to_static/test_partial_program_hook.py
@@ -44,7 +44,7 @@ class TestPrimHook(unittest.TestCase):
             f
         ).get_concrete_program()
         self._hook = program_translator.PrimHooker(
-            concrete_program.main_program
+            concrete_program.main_program, None
         )
         self._forward = partial_program.forward_program
         self._whole = partial_program._train_program
-- 
GitLab
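
Usage sketch (not part of the patch): how the new `backend` argument is meant to be called from user code, mirroring the TestBackend case above. `SimpleNet` is a hypothetical layer used only for illustration, and backend='CINN' assumes a CINN-enabled Paddle build.

    import paddle

    class SimpleNet(paddle.nn.Layer):  # hypothetical module, illustration only
        def __init__(self):
            super().__init__()
            self.linear = paddle.nn.Linear(10, 3)

        def forward(self, x):
            return self.linear(x)

    net = SimpleNet()
    # backend=None (the default) keeps the ordinary executor path;
    # backend='CINN' turns on build_strategy.build_cinn_pass via
    # _check_and_set_backend; any other value raises ValueError.
    static_net = paddle.jit.to_static(net, backend='CINN')
    out = static_net(paddle.randn([4, 10]))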
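
A second sketch of the `backend_guard` contract added in utils.py: inside the guard a 'CINN' backend forces prim lowering on, and on exit the recorded pre-guard forward/backward flag values are restored. The assertions assume no FLAGS_prim_* overrides in the environment (so the flags start out disabled); they illustrate intent and are not part of the patch.

    from paddle.fluid import core
    from paddle.jit.dy2static.utils import backend_guard

    core._set_prim_all_enabled(False)  # start from a known state
    with backend_guard('CINN'):
        # backend='CINN' forces both forward and backward prim lowering on
        assert core._is_fwd_prim_enabled() and core._is_bwd_prim_enabled()
    # on exit, the pre-guard values are restored
    assert not core._is_fwd_prim_enabled() and not core._is_bwd_prim_enabled()

    with backend_guard(None):
        # backend=None is a pass-through: flags keep whatever
        # core.check_and_set_prim_all_enabled() derived from FLAGS
        pass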