From d0096eaf556cccc454d554720dfc34d1e42ad014 Mon Sep 17 00:00:00 2001 From: xiongkun Date: Fri, 9 Sep 2022 10:45:36 +0800 Subject: [PATCH] [ Dy2Static ] convert_call support staticmethod for class. (#44983) * convert_call support staticmethod for class. * while support for python container. It is convenient to convert more dynamic graph codes into static graphs. * cond support python container * add unittest for staticmethod convert_call * fix bugs * add unittest for item interface * fix bugs * change to np.testing.assert_allclose * code format * fix comments. * code format --- .../dygraph_to_static/ast_transformer.py | 3 +++ .../dygraph_to_static/convert_call_func.py | 5 ++++ .../dygraph_to_static/convert_operators.py | 8 ++++-- .../dygraph_to_static/program_translator.py | 1 - python/paddle/fluid/layers/math_op_patch.py | 25 ++++++++++++++++--- .../dygraph_to_static/test_convert_call.py | 23 +++++++++++++++++ .../dygraph_to_static/test_tensor_methods.py | 23 +++++++++++++++++ 7 files changed, 82 insertions(+), 6 deletions(-) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py index e689797bde4..fd146d77632 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py @@ -123,7 +123,10 @@ class DygraphToStaticAst(BaseTransformer): # Remove the decorated name of dygraph_to_static if hasattr(node, 'decorator_list'): decorator_list = [] + ignore_list = ["staticmethod"] for d in node.decorator_list: + if isinstance(d, gast.Name) and d.id in ignore_list: + continue if isinstance(d, gast.Name) and d.id not in DECORATOR_NAMES: raise NotImplementedError( "ProgramTranslator hasn't implemented multiple decorators. Please remove " diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/convert_call_func.py b/python/paddle/fluid/dygraph/dygraph_to_static/convert_call_func.py index 5bb75bda8de..fda668dc745 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/convert_call_func.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/convert_call_func.py @@ -213,6 +213,11 @@ def convert_call(func): elif isinstance(fn, StaticFunction): _, fn = unwrap_decorators(fn) global_functions.add(fn) + elif inspect.isclass(fn): + if isinstance(fn.__dict__.get(func.__name__, None), + staticmethod): + global_functions.add( + func) # Add func to ensure that we will convert if func in global_functions: converted_call = convert_to_static(func) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py b/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py index 938cf9c3228..e22d83d56f3 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py @@ -87,7 +87,10 @@ def convert_while_loop(cond, Args: cond(Callable): A callable object that returns a boolean variable to control whether to execute the loop body. It takes ``loop_vars`` as arguments. body(Callable): A callable object that returns a tuple or list of variables with the same arguments ``loops_vars`` as ``cond`` . - loop_vars(list|tuple): A list or tuple of variables passed to ``cond`` and ``body`` . + get_args(callable): Get all arguments that needed in true_fn and false_fn. + set_args(callable): Update arguments that modified in trure_fn and false_fn. + return_name_ids(list[string], optional): the returned names. + push_pop_names(list[string], optional): the names on which called .append() or .pop(). Returns: A list or tuple of variables which returned by ``body``. @@ -306,7 +309,8 @@ def convert_ifelse(pred, false_fn(callable): A callable to be performed if ``pred`` is false. get_args(callable): Get all arguments that needed in true_fn and false_fn. set_args(callable): Update arguments that modified in trure_fn and false_fn. - return_name_ids(list[string]): the returned names. + return_name_ids(list[string], optional): the returned names. + push_pop_names(list[string], optional): the names on which called .append() or .pop(). Returns: ``true_fn()`` if the predicate ``pred`` is true else ``false_fn()`` . diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py b/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py index 77b55f35e2e..2a098947413 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py @@ -380,7 +380,6 @@ class StaticFunction(object): try: concrete_program, partial_program_layer = self.get_concrete_program( *args, **kwargs, is_train=self._is_train_mode()) - # 3. synchronize self.training attribute. if isinstance(self._class_instance, layers.Layer): partial_program_layer.training = self._class_instance.training diff --git a/python/paddle/fluid/layers/math_op_patch.py b/python/paddle/fluid/layers/math_op_patch.py index 3721b97368a..7c66e9736ea 100644 --- a/python/paddle/fluid/layers/math_op_patch.py +++ b/python/paddle/fluid/layers/math_op_patch.py @@ -227,12 +227,30 @@ def monkey_patch_variable(): .format(self.type)) array_write(x=var, i=array_length(self), array=self) + @static_only + def _item(self): + """ + In order to be compatible with the item interface introduced by the dynamic graph, it does nothing but returns self. + It will check that the shape must be a 1-D tensor + """ + if len(self.shape) > 1: + raise TypeError( + "Required input var should be 1-D Variable, but received {}". + format(self.shape)) + return self + @static_only def pop(self, *args): """ - **Notes**: - **The type variable must be LoD Tensor Array. - + The type variable must be LoD Tensor Array. + When self is LoDTensorArray, calling pop is similar to Python's pop on list. + This interface is used to simplify dygraph to static graph operations. + + Args: + self(Variable): The source variable, which must be LOD_TENSOR_ARRAY + *args: optional, a int means index. + Returns: + Variable: self[index] """ from paddle.fluid.dygraph.dygraph_to_static.convert_operators import _run_paddle_pop if self.type != core.VarDesc.VarType.LOD_TENSOR_ARRAY: @@ -410,6 +428,7 @@ def monkey_patch_variable(): ('cpu', cpu), ('cuda', cuda), ('append', append), + ('item', _item), ('pop', pop), ('dim', lambda x: len(x.shape)), ('ndimension', lambda x: len(x.shape)), diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_convert_call.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_convert_call.py index b3cdde63639..4f4d42c8092 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_convert_call.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_convert_call.py @@ -65,6 +65,23 @@ def dyfunc_with_third_library_logging(x_v): return x_v +class A: + + @staticmethod + def add(a, b): + """ + dygraph mode, return a numpy object. + static mode, return a variable object. + """ + return paddle.to_tensor(a.numpy() + b.numpy()) + + +@paddle.jit.to_static +def dyfunc_with_staticmethod(x_v): + a = A() + return a.add(x_v, x_v) + + class TestRecursiveCall1(unittest.TestCase): def setUp(self): @@ -188,6 +205,12 @@ class TestThirdPartyLibrary(TestRecursiveCall2): self.dygraph_func = dyfunc_with_third_library_logging +class TestStaticMethod(TestRecursiveCall2): + + def set_func(self): + self.dygraph_func = dyfunc_with_staticmethod + + # Situation 2 : test not_to_static diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_methods.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_methods.py index 2ad9153fbaa..a21a155d600 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_methods.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_methods.py @@ -62,6 +62,29 @@ class TestTensorDygraphOnlyMethodError(unittest.TestCase): static_res = self._run(to_static=True) +@paddle.jit.to_static +def tensor_item(x): + x = paddle.to_tensor(x) + y = x.clone() + return y.item() + + +class TestTensorItem(unittest.TestCase): + + def _run(self, to_static): + prog_trans = paddle.jit.ProgramTranslator() + prog_trans.enable(to_static) + x = paddle.ones([1]) + if to_static: + return tensor_item(x).numpy() + return tensor_item(x) + + def test_tensor_clone(self): + dygraph_res = self._run(to_static=False) + static_res = self._run(to_static=True) + np.testing.assert_allclose(dygraph_res, static_res) + + @paddle.jit.to_static def tensor_size(x): x = paddle.to_tensor(x) -- GitLab