未验证 提交 3e66845f 编写于 作者: W WangZhen 提交者: GitHub

[Dy2St]Add backend for to_static API (#52596)

* Add backend for to_static API
上级 2a420036
......@@ -349,7 +349,7 @@ class TestNegSpecWithPrim(unittest.TestCase):
)
x = paddle.randn([2, 10])
out = net(x)
np.testing.assert_equal(out.shape, [2, 5])
np.testing.assert_equal(net.forward._input_spec, None)
if __name__ == '__main__':
......
......@@ -218,8 +218,23 @@ def ignore_module(modules: list[Any]):
add_ignore_module(modules)
def _check_and_set_backend(backend, build_strategy):
if backend not in ['CINN', None]:
raise ValueError(
"The backend of to_static should be 'CINN' or None, but received {}.".format(
backend
)
)
if backend == 'CINN':
build_strategy.build_cinn_pass = True
def to_static(
function=None, input_spec=None, build_strategy=None, property=False
function=None,
input_spec=None,
build_strategy=None,
backend=None,
**kwargs,
):
"""
Converts imperative dygraph APIs into declarative function APIs. Decorator
......@@ -228,7 +243,6 @@ def to_static(
Tensor(s) to do imperative training, inference, or other operations. If the
decorated function calls other imperative function, the called one will be
converted into declarative function as well.
Args:
function (callable): callable imperative function.
input_spec(list[InputSpec]|tuple[InputSpec]): list/tuple of InputSpec to specific the shape/dtype/name
......@@ -238,7 +252,8 @@ def to_static(
in the computational graph and memory optimization during the execution
of the computational graph. For more information about build_strategy,
please refer to :code:`paddle.static.BuildStrategy`. The default is None.
property(bool, Optional): whether the fucntion is python property. The default is False.
backend(str, Optional): Specifies compilation backend, which can be `CINN` or None. When backend is `CINN`, CINN compiler will be used to speed up training and inference.
kwargs: Support keys including `property`, set `property` to True if the fucntion is python property.
Returns:
......@@ -263,6 +278,7 @@ def to_static(
print(x_v) # [[2. 2.]]
"""
property = kwargs.get("property", False)
def decorated(python_func):
"""
......@@ -279,6 +295,7 @@ def to_static(
input_spec=input_spec,
build_strategy=build_strategy,
property=property,
backend=backend,
),
)
......@@ -291,6 +308,7 @@ def to_static(
type(build_strategy).__name__
)
)
_check_and_set_backend(backend, build_strategy)
# for usage: `to_static(foo, ...)`
if function is not None:
......
......@@ -27,7 +27,12 @@ from paddle.fluid.framework import _apply_pass
from paddle.optimizer.lr import LRScheduler
from . import logging_utils
from .utils import RETURN_NO_VALUE_MAGIC_NUM, _out_grad_names, _param_grad_names
from .utils import (
RETURN_NO_VALUE_MAGIC_NUM,
_out_grad_names,
_param_grad_names,
backend_guard,
)
__all__ = []
......@@ -197,6 +202,7 @@ class PartialProgramLayer:
# program_id -> list(scope)
self._scope_cache = {}
self._hooker = None
self._backend = kwargs.get('backend', None)
def __call__(self, inputs):
"""
......@@ -636,9 +642,8 @@ class PartialProgramLayer:
start_idx = len(program.block(0).ops) + len(self._outputs.tolist())
if targets:
# TODO(CZ): later when use cinn, set_prim_all_enabled and check_and_set_prim_all_enabled will be set at else branch.
core.check_and_set_prim_all_enabled()
start_idx = len(program.block(0).ops) + len(self._outputs.tolist())
with backend_guard(self._backend):
backward.gradients(targets=targets, inputs=[])
if self._hooker:
......
......@@ -48,6 +48,7 @@ from .utils import (
NO_SHAPE_VAR_TYPE,
ast_to_func,
ast_to_source_code,
backend_guard,
func_to_source_code,
input_specs_compatible,
is_paddle_func,
......@@ -334,7 +335,7 @@ class StaticFunction:
self._class_instance = None
if input_spec is not None and prim_or_cinn_is_enabled(
kwargs.get("build_strategy", None)
kwargs.get("build_strategy", None), kwargs.get("backend", None)
):
from paddle.static import InputSpec
......@@ -1184,11 +1185,9 @@ class ProgramCache:
def _build_once(self, cache_key):
# TODO(Aurelius84): Need a gloabl FLAGS to enable/disable to_prim
enable_prim = cache_key.kwargs['build_strategy'].build_cinn_pass
# TODO(CZ): later when use cinn, set_prim_all_enabled and check_and_set_prim_all_enabled will be set at else branch.
# NOTE(xiongkun): Need a global FLAGS to enable/disable fallback
enable_fallback = enable_prim
core.check_and_set_prim_all_enabled()
try:
concrete_program = ConcreteProgram.from_func_spec(
func_spec=cache_key.function_spec,
......@@ -1216,7 +1215,8 @@ class ProgramCache:
else:
raise
if prim_or_cinn_is_enabled(cache_key.kwargs['build_strategy']):
backend = cache_key.kwargs['backend']
if prim_or_cinn_is_enabled(cache_key.kwargs['build_strategy'], backend):
for var in concrete_program.main_program.list_vars():
if var.type not in NO_SHAPE_VAR_TYPE and -1 in var.shape:
warnings.warn(
......@@ -1228,9 +1228,10 @@ class ProgramCache:
partial_program = partial_program_from(
concrete_program, cache_key.class_instance is not None
)
with backend_guard(backend):
if core._is_fwd_prim_enabled():
partial_program.set_hooker(
PrimHooker(concrete_program.main_program)
PrimHooker(concrete_program.main_program, backend)
)
return concrete_program, partial_program
......@@ -1291,12 +1292,14 @@ class ProgramCache:
class PrimHooker(PartialProgramLayerHook):
def __init__(self, original_program):
def __init__(self, original_program, backend):
if len(original_program.blocks) > 1:
raise ValueError(
'The primitive mode only support one block currently.'
)
self.backend = backend
self.custom_vjps = set()
with backend_guard(self.backend):
if core._is_all_prim_enabled():
self.custom_vjps = {
op.type
......@@ -1305,12 +1308,16 @@ class PrimHooker(PartialProgramLayerHook):
}
def before_append_backward(self, forward_program):
with backend_guard(self.backend):
if core._is_fwd_prim_enabled():
_to_prim(forward_program.blocks, blacklist=self.custom_vjps)
return forward_program
def after_append_backward(self, whole_program, backward_start_idx):
backward_length = len(whole_program.block(0).ops) - backward_start_idx
with backend_guard(self.backend):
backward_length = (
len(whole_program.block(0).ops) - backward_start_idx
)
if core._is_fwd_prim_enabled() and len(self.custom_vjps) != 0:
# only process backward part of block
_to_prim(whole_program.blocks, backward_length=backward_length)
......@@ -1321,6 +1328,7 @@ class PrimHooker(PartialProgramLayerHook):
return whole_program, new_start_index
def after_infer(self, infer_program):
with backend_guard(self.backend):
if core._is_fwd_prim_enabled():
_to_prim(infer_program.block(0))
return infer_program
......
......@@ -35,6 +35,7 @@ from paddle import fluid # noqa: F401
from paddle.fluid import core, unique_name
from paddle.fluid.data_feeder import convert_dtype
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.wrapped_decorator import signature_safe_contextmanager
from paddle.utils import gast
from .ast_utils import ast_to_source_code
......@@ -1498,7 +1499,10 @@ def _out_grad_names(program_desc, fwd_end_op_index, out_size):
return names
def prim_or_cinn_is_enabled(build_strategy):
def prim_or_cinn_is_enabled(build_strategy, backend):
if backend == 'CINN':
return True
if build_strategy is not None and build_strategy.build_cinn_pass:
return True
......@@ -1534,3 +1538,18 @@ def is_builtin(func, name=None):
return True
else:
return False
@signature_safe_contextmanager
def backend_guard(backend):
core.check_and_set_prim_all_enabled()
orign_fwd = core._is_fwd_prim_enabled()
orign_bwd = core._is_bwd_prim_enabled()
if backend == 'CINN':
core._set_prim_all_enabled(True)
try:
yield
finally:
core._set_prim_forward_enabled(orign_fwd)
core._set_prim_backward_enabled(orign_bwd)
......@@ -163,5 +163,20 @@ class TestPrimForwardAndBackward(unittest.TestCase):
)
class TestBackend(unittest.TestCase):
def test_backend(self):
x = paddle.randn([2, 4])
out1 = self.forward(x, 'CINN')
out2 = self.forward(x, None)
np.testing.assert_allclose(out1, out2, rtol=1e-6)
def forward(self, x, beckend=None):
paddle.seed(2022)
net = PrimeNet()
net = paddle.jit.to_static(net, backend=beckend)
out = net(x)
return out
if __name__ == '__main__':
unittest.main()
......@@ -44,7 +44,7 @@ class TestPrimHook(unittest.TestCase):
f
).get_concrete_program()
self._hook = program_translator.PrimHooker(
concrete_program.main_program
concrete_program.main_program, None
)
self._forward = partial_program.forward_program
self._whole = partial_program._train_program
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册