Unverified commit 3e66845f, authored by WangZhen, committed by GitHub

[Dy2St]Add backend for to_static API (#52596)

* Add backend for to_static API
Parent 2a420036
```diff
@@ -349,7 +349,7 @@ class TestNegSpecWithPrim(unittest.TestCase):
         )
         x = paddle.randn([2, 10])
         out = net(x)
-        np.testing.assert_equal(out.shape, [2, 5])
+        np.testing.assert_equal(net.forward._input_spec, None)
 
 
 if __name__ == '__main__':
```
...
```diff
@@ -218,8 +218,23 @@ def ignore_module(modules: list[Any]):
     add_ignore_module(modules)
 
 
+def _check_and_set_backend(backend, build_strategy):
+    if backend not in ['CINN', None]:
+        raise ValueError(
+            "The backend of to_static should be 'CINN' or None, but received {}.".format(
+                backend
+            )
+        )
+    if backend == 'CINN':
+        build_strategy.build_cinn_pass = True
+
+
 def to_static(
-    function=None, input_spec=None, build_strategy=None, property=False
+    function=None,
+    input_spec=None,
+    build_strategy=None,
+    backend=None,
+    **kwargs,
 ):
     """
     Converts imperative dygraph APIs into declarative function APIs. Decorator
@@ -228,7 +243,6 @@ def to_static(
     Tensor(s) to do imperative training, inference, or other operations. If the
     decorated function calls other imperative function, the called one will be
     converted into declarative function as well.
-
     Args:
         function (callable): callable imperative function.
         input_spec(list[InputSpec]|tuple[InputSpec]): list/tuple of InputSpec to specific the shape/dtype/name
@@ -238,7 +252,8 @@ def to_static(
             in the computational graph and memory optimization during the execution
             of the computational graph. For more information about build_strategy,
             please refer to :code:`paddle.static.BuildStrategy`. The default is None.
-        property(bool, Optional): whether the fucntion is python property. The default is False.
+        backend(str, Optional): Specifies compilation backend, which can be `CINN` or None. When backend is `CINN`, CINN compiler will be used to speed up training and inference.
+        kwargs: Support keys including `property`, set `property` to True if the fucntion is python property.
 
     Returns:
@@ -263,6 +278,7 @@ def to_static(
             print(x_v) # [[2. 2.]]
 
     """
+    property = kwargs.get("property", False)
 
     def decorated(python_func):
         """
@@ -279,6 +295,7 @@ def to_static(
                 input_spec=input_spec,
                 build_strategy=build_strategy,
                 property=property,
+                backend=backend,
             ),
         )
 
@@ -291,6 +308,7 @@ def to_static(
                 type(build_strategy).__name__
             )
         )
+    _check_and_set_backend(backend, build_strategy)
 
     # for usage: `to_static(foo, ...)`
     if function is not None:
```
...
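For orientation, the new argument is intended to be used roughly as in the sketch below. This is not part of the patch: the `SimpleNet` layer, shapes, and dtype are illustrative, and `backend='CINN'` only takes effect in a Paddle build compiled with CINN support.

```python
import paddle
from paddle.static import InputSpec


class SimpleNet(paddle.nn.Layer):  # illustrative layer, not from the patch
    def __init__(self):
        super().__init__()
        self.fc = paddle.nn.Linear(10, 5)

    def forward(self, x):
        return self.fc(x)


net = SimpleNet()
# backend='CINN' makes the decorator enable build_strategy.build_cinn_pass;
# backend=None keeps the previous behaviour; any other value raises ValueError.
static_net = paddle.jit.to_static(
    net,
    input_spec=[InputSpec(shape=[None, 10], dtype='float32')],
    backend='CINN',
)
out = static_net(paddle.randn([2, 10]))
```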
```diff
@@ -27,7 +27,12 @@ from paddle.fluid.framework import _apply_pass
 from paddle.optimizer.lr import LRScheduler
 
 from . import logging_utils
-from .utils import RETURN_NO_VALUE_MAGIC_NUM, _out_grad_names, _param_grad_names
+from .utils import (
+    RETURN_NO_VALUE_MAGIC_NUM,
+    _out_grad_names,
+    _param_grad_names,
+    backend_guard,
+)
 
 __all__ = []
@@ -197,6 +202,7 @@ class PartialProgramLayer:
         # program_id -> list(scope)
         self._scope_cache = {}
         self._hooker = None
+        self._backend = kwargs.get('backend', None)
 
     def __call__(self, inputs):
         """
@@ -636,10 +642,9 @@ class PartialProgramLayer:
         start_idx = len(program.block(0).ops) + len(self._outputs.tolist())
 
         if targets:
-            # TODO(CZ): later when use cinn, set_prim_all_enabled and check_and_set_prim_all_enabled will be set at else branch.
-            core.check_and_set_prim_all_enabled()
             start_idx = len(program.block(0).ops) + len(self._outputs.tolist())
-            backward.gradients(targets=targets, inputs=[])
+            with backend_guard(self._backend):
+                backward.gradients(targets=targets, inputs=[])
 
         if self._hooker:
             program, start_idx = self._hooker.after_append_backward(
```
...
```diff
@@ -48,6 +48,7 @@ from .utils import (
     NO_SHAPE_VAR_TYPE,
     ast_to_func,
     ast_to_source_code,
+    backend_guard,
     func_to_source_code,
     input_specs_compatible,
     is_paddle_func,
@@ -334,7 +335,7 @@ class StaticFunction:
             self._class_instance = None
 
         if input_spec is not None and prim_or_cinn_is_enabled(
-            kwargs.get("build_strategy", None)
+            kwargs.get("build_strategy", None), kwargs.get("backend", None)
         ):
 
             from paddle.static import InputSpec
@@ -1184,11 +1185,9 @@ class ProgramCache:
     def _build_once(self, cache_key):
         # TODO(Aurelius84): Need a gloabl FLAGS to enable/disable to_prim
         enable_prim = cache_key.kwargs['build_strategy'].build_cinn_pass
-        # TODO(CZ): later when use cinn, set_prim_all_enabled and check_and_set_prim_all_enabled will be set at else branch.
 
         # NOTE(xiongkun): Need a global FLAGS to enable/disable fallback
         enable_fallback = enable_prim
-        core.check_and_set_prim_all_enabled()
         try:
             concrete_program = ConcreteProgram.from_func_spec(
                 func_spec=cache_key.function_spec,
@@ -1216,7 +1215,8 @@ class ProgramCache:
             else:
                 raise
 
-        if prim_or_cinn_is_enabled(cache_key.kwargs['build_strategy']):
+        backend = cache_key.kwargs['backend']
+        if prim_or_cinn_is_enabled(cache_key.kwargs['build_strategy'], backend):
             for var in concrete_program.main_program.list_vars():
                 if var.type not in NO_SHAPE_VAR_TYPE and -1 in var.shape:
                     warnings.warn(
@@ -1228,10 +1228,11 @@ class ProgramCache:
         partial_program = partial_program_from(
             concrete_program, cache_key.class_instance is not None
         )
-        if core._is_fwd_prim_enabled():
-            partial_program.set_hooker(
-                PrimHooker(concrete_program.main_program)
-            )
+        with backend_guard(backend):
+            if core._is_fwd_prim_enabled():
+                partial_program.set_hooker(
+                    PrimHooker(concrete_program.main_program, backend)
+                )
         return concrete_program, partial_program
 
     def __getitem__(self, item):
@@ -1291,39 +1292,46 @@ class ProgramCache:
 class PrimHooker(PartialProgramLayerHook):
-    def __init__(self, original_program):
+    def __init__(self, original_program, backend):
         if len(original_program.blocks) > 1:
             raise ValueError(
                 'The primitive mode only support one block currently.'
             )
+        self.backend = backend
         self.custom_vjps = set()
-        if core._is_all_prim_enabled():
-            self.custom_vjps = {
-                op.type
-                for op in original_program.block(0).ops
-                if core.has_comp_grad_op_maker(op.type)
-            }
+        with backend_guard(self.backend):
+            if core._is_all_prim_enabled():
+                self.custom_vjps = {
+                    op.type
+                    for op in original_program.block(0).ops
+                    if core.has_comp_grad_op_maker(op.type)
+                }
 
     def before_append_backward(self, forward_program):
-        if core._is_fwd_prim_enabled():
-            _to_prim(forward_program.blocks, blacklist=self.custom_vjps)
-        return forward_program
+        with backend_guard(self.backend):
+            if core._is_fwd_prim_enabled():
+                _to_prim(forward_program.blocks, blacklist=self.custom_vjps)
+            return forward_program
 
     def after_append_backward(self, whole_program, backward_start_idx):
-        backward_length = len(whole_program.block(0).ops) - backward_start_idx
-        if core._is_fwd_prim_enabled() and len(self.custom_vjps) != 0:
-            # only process backward part of block
-            _to_prim(whole_program.blocks, backward_length=backward_length)
-        new_start_index = len(whole_program.block(0).ops) - backward_length
-        if backward_length > 0:
-            # only process forward part of block
-            _to_prim(whole_program.blocks, start_idx=new_start_index)
-        return whole_program, new_start_index
+        with backend_guard(self.backend):
+            backward_length = (
+                len(whole_program.block(0).ops) - backward_start_idx
+            )
+            if core._is_fwd_prim_enabled() and len(self.custom_vjps) != 0:
+                # only process backward part of block
+                _to_prim(whole_program.blocks, backward_length=backward_length)
+            new_start_index = len(whole_program.block(0).ops) - backward_length
+            if backward_length > 0:
+                # only process forward part of block
+                _to_prim(whole_program.blocks, start_idx=new_start_index)
+            return whole_program, new_start_index
 
     def after_infer(self, infer_program):
-        if core._is_fwd_prim_enabled():
-            _to_prim(infer_program.block(0))
-        return infer_program
+        with backend_guard(self.backend):
+            if core._is_fwd_prim_enabled():
+                _to_prim(infer_program.block(0))
+            return infer_program
 
 
 class ProgramTranslator:
```
...
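The structural change above may be easier to see in isolation: every `PrimHooker` callback now re-derives the prim switches inside `backend_guard(self.backend)` instead of reading whatever the global flags happen to be. The toy hook below mirrors that wrapping pattern only; the class and the guard are hypothetical stand-ins, and only the three callback names come from the code above.

```python
from contextlib import contextmanager


@contextmanager
def fake_backend_guard(backend):
    # Stand-in for the real backend_guard: derive switches for `backend`,
    # run the block, then restore the previous state on exit.
    print(f"prim switches scoped for backend={backend!r}")
    yield


class ToyHooker:  # hypothetical hook following the same callback protocol
    def __init__(self, backend):
        self.backend = backend

    def before_append_backward(self, forward_program):
        with fake_backend_guard(self.backend):
            # PrimHooker lowers forward ops to primitive ops here.
            return forward_program

    def after_append_backward(self, whole_program, backward_start_idx):
        with fake_backend_guard(self.backend):
            # PrimHooker lowers the freshly appended backward ops here.
            return whole_program, backward_start_idx

    def after_infer(self, infer_program):
        with fake_backend_guard(self.backend):
            return infer_program
```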
```diff
@@ -35,6 +35,7 @@ from paddle import fluid  # noqa: F401
 from paddle.fluid import core, unique_name
 from paddle.fluid.data_feeder import convert_dtype
 from paddle.fluid.layer_helper import LayerHelper
+from paddle.fluid.wrapped_decorator import signature_safe_contextmanager
 from paddle.utils import gast
 
 from .ast_utils import ast_to_source_code
@@ -1498,7 +1499,10 @@ def _out_grad_names(program_desc, fwd_end_op_index, out_size):
     return names
 
 
-def prim_or_cinn_is_enabled(build_strategy):
+def prim_or_cinn_is_enabled(build_strategy, backend):
+    if backend == 'CINN':
+        return True
+
     if build_strategy is not None and build_strategy.build_cinn_pass:
         return True
@@ -1534,3 +1538,18 @@ def is_builtin(func, name=None):
         return True
     else:
         return False
+
+
+@signature_safe_contextmanager
+def backend_guard(backend):
+    core.check_and_set_prim_all_enabled()
+    orign_fwd = core._is_fwd_prim_enabled()
+    orign_bwd = core._is_bwd_prim_enabled()
+
+    if backend == 'CINN':
+        core._set_prim_all_enabled(True)
+    try:
+        yield
+    finally:
+        core._set_prim_forward_enabled(orign_fwd)
+        core._set_prim_backward_enabled(orign_bwd)
```
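Taken together, `prim_or_cinn_is_enabled` now short-circuits to True whenever `backend == 'CINN'`, and `backend_guard` temporarily forces full primitive lowering for CINN and then restores the original forward/backward switches. A pure-Python toy model of that save-and-restore behaviour (no Paddle internals; all names here are made up):

```python
from contextlib import contextmanager

# Toy stand-ins for the prim switches that the real guard reads from core.
_flags = {'fwd_prim': False, 'bwd_prim': False}


@contextmanager
def toy_backend_guard(backend):
    saved = dict(_flags)            # remember the current switches
    if backend == 'CINN':
        _flags['fwd_prim'] = True   # force full prim lowering for CINN
        _flags['bwd_prim'] = True
    try:
        yield
    finally:
        _flags.update(saved)        # always restore the original switches


with toy_backend_guard('CINN'):
    assert _flags == {'fwd_prim': True, 'bwd_prim': True}
assert _flags == {'fwd_prim': False, 'bwd_prim': False}
```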
```diff
@@ -163,5 +163,20 @@ class TestPrimForwardAndBackward(unittest.TestCase):
         )
 
 
+class TestBackend(unittest.TestCase):
+    def test_backend(self):
+        x = paddle.randn([2, 4])
+        out1 = self.forward(x, 'CINN')
+        out2 = self.forward(x, None)
+        np.testing.assert_allclose(out1, out2, rtol=1e-6)
+
+    def forward(self, x, beckend=None):
+        paddle.seed(2022)
+        net = PrimeNet()
+        net = paddle.jit.to_static(net, backend=beckend)
+        out = net(x)
+        return out
+
+
 if __name__ == '__main__':
     unittest.main()
```
```diff
@@ -44,7 +44,7 @@ class TestPrimHook(unittest.TestCase):
             f
         ).get_concrete_program()
         self._hook = program_translator.PrimHooker(
-            concrete_program.main_program
+            concrete_program.main_program, None
         )
         self._forward = partial_program.forward_program
         self._whole = partial_program._train_program
```
...