From 247ef47772738c195eaabe86b067bc7fb196651b Mon Sep 17 00:00:00 2001
From: xiongkun
Date: Wed, 19 Oct 2022 17:03:34 +0800
Subject: [PATCH] [CherryPick] Support TypeHint for function decorated by
 @to_static (#47147)

* [Dy2Static] Support TypeHint for function decorated by @to_static (#47121)

* Add TypeHint Transformer

* add unittest for typehint transformer

* [Dy2Static] Remove GradTransformer (#47063)

* [Dy2Static] Remove GradTransformer

1. fix einsum infershape bugs.
2. remove grad_transformer and unify paddle.grad and paddle.static.gradients.
3. add dygraph_and_dy2static_only decorator for dy2static.

* fix bugs

* rename
---
 paddle/phi/kernels/impl/einsum_impl.h         |  2 +-
 python/paddle/fluid/dygraph/base.py           | 31 +++++++-
 .../dygraph_to_static/ast_transformer.py      |  4 +-
 .../dygraph_to_static/typehint_transformer.py | 47 +++++++++++
 python/paddle/fluid/framework.py              | 12 +++
 .../dygraph_to_static/test_typehint.py        | 79 +++++++++++++++++++
 6 files changed, 172 insertions(+), 3 deletions(-)
 create mode 100644 python/paddle/fluid/dygraph/dygraph_to_static/typehint_transformer.py
 create mode 100644 python/paddle/fluid/tests/unittests/dygraph_to_static/test_typehint.py

diff --git a/paddle/phi/kernels/impl/einsum_impl.h b/paddle/phi/kernels/impl/einsum_impl.h
index 80529c8b66..dafb967ae8 100644
--- a/paddle/phi/kernels/impl/einsum_impl.h
+++ b/paddle/phi/kernels/impl/einsum_impl.h
@@ -241,7 +241,7 @@ inline static void InferLabelShape(const std::vector<std::string>& op_labels,
     } else if (labelshape->is_default(c) || (*labelshape)[c] == -1) {
       (*labelshape)[c] = op_dim[dim_ptr];
       dim_ptr++;
-    } else {
+    } else if (op_dim[dim_ptr] != -1) {
       PADDLE_ENFORCE_EQ(
           (*labelshape)[c],
           op_dim[dim_ptr],
diff --git a/python/paddle/fluid/dygraph/base.py b/python/paddle/fluid/dygraph/base.py
index b30e3ff1d8..cca6a5fc1c 100644
--- a/python/paddle/fluid/dygraph/base.py
+++ b/python/paddle/fluid/dygraph/base.py
@@ -27,6 +27,7 @@ from ..data_feeder import convert_dtype
 import warnings
 from ..framework import _get_paddle_place, _in_legacy_dygraph, _in_eager_without_dygraph_check
 import paddle
+import warnings
 
 __all__ = [
     'no_grad', 'no_grad_', 'grad', 'guard', 'enable_dygraph', 'disable_dygraph',
@@ -45,6 +46,20 @@ def in_declarative_mode():
     return _in_declarative_mode_
 
 
+def declarative_unsupport_argument_warning(func_name, input_names, inputs,
+                                           support_values):
+    """
+    Warn if `inputs` are not elementwise equal to `support_values`.
+    It is a dy2static utility, used when a dygraph interface has more
+    inputs than its static counterpart, such as paddle.grad.
+
+    """
+    for name, inp, sup in zip(input_names, inputs, support_values):
+        if inp != sup:
+            warnings.warn(f"{func_name} has an unsupported parameter in jit: " +
+                          f"{name}; jit will discard it")
+
+
 def _switch_to_static_graph_(func):
 
     def __impl__(*args, **kwargs):
@@ -290,6 +305,10 @@ def no_grad(func=None):
         test_layer()
 
     """
+    if in_declarative_mode():
+        warnings.warn(
+            "paddle.no_grad is only supported for inference models; it is not supported for training under @to_static."
+        )
     if func is None:
         return _switch_tracer_mode_guard_(is_train=False)
     else:
@@ -428,7 +447,7 @@ def guard(place=None):
         yield
 
 
-@framework.dygraph_only
+@framework.non_static_only
 def grad(outputs,
          inputs,
          grad_outputs=None,
@@ -563,6 +582,16 @@ def grad(outputs,
             grad_y1 = paddle.to_tensor(3.0)
             print(test_dygraph_grad([grad_y1, grad_value])) # [24.]
     '''
+    if in_declarative_mode():
+        # In dy2static context, we call the static interface `gradients`
+        # to calculate grads.
+        from paddle.static import gradients
+        declarative_unsupport_argument_warning(
+            "paddle.grad",
+            ["retain_graph", "create_graph", "only_inputs", "allow_unused"],
+            [retain_graph, create_graph, only_inputs, allow_unused],
+            [None, False, True, False])
+        return gradients(outputs, inputs, grad_outputs, no_grad_vars)
 
     def check_in_out(in_out_list, name):
         assert in_out_list is not None, "{} should not be None".format(name)
diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py
index e045348e6c..e946969f1a 100644
--- a/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py
+++ b/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py
@@ -30,6 +30,7 @@ from paddle.fluid.dygraph.dygraph_to_static.break_continue_transformer import Br
 from paddle.fluid.dygraph.dygraph_to_static.call_transformer import CallTransformer
 from paddle.fluid.dygraph.dygraph_to_static.cast_transformer import CastTransformer
 from paddle.fluid.dygraph.dygraph_to_static.grad_transformer import GradTransformer
+from paddle.fluid.dygraph.dygraph_to_static.typehint_transformer import TypeHintTransformer
 from paddle.fluid.dygraph.dygraph_to_static.ifelse_transformer import IfElseTransformer
 from paddle.fluid.dygraph.dygraph_to_static.list_transformer import ListTransformer
 from paddle.fluid.dygraph.dygraph_to_static.logical_transformer import LogicalTransformer
@@ -104,8 +105,9 @@ class DygraphToStaticAst(BaseTransformer):
             PrintTransformer,  # print statement
             CallTransformer,  # transform call recursively
             CastTransformer,  # type casting statement
-            GradTransformer,  # transform paddle.grad to paddle.gradients
+            # GradTransformer,  # transform paddle.grad to paddle.gradients
             DecoratorTransformer,  # transform decorators to function call
+            TypeHintTransformer,  # remove all type hints in gast.Name nodes
         ]
 
         apply_optimization(transformers)
diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/typehint_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/typehint_transformer.py
new file mode 100644
index 0000000000..f258b98b50
--- /dev/null
+++ b/python/paddle/fluid/dygraph/dygraph_to_static/typehint_transformer.py
@@ -0,0 +1,47 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from paddle.utils import gast
+import warnings
+
+from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper
+from paddle.fluid.dygraph.dygraph_to_static import utils
+from paddle.fluid.dygraph.dygraph_to_static.base_transformer import BaseTransformer
+
+
+class TypeHintTransformer(BaseTransformer):
+    """
+    A transformer that removes all type hints from gast.Name nodes (annotation).
+    Put it after the other transformers, because they may rely on the type hints.
+    """
+
+    def __init__(self, wrapper_root):
+        assert isinstance(
+            wrapper_root, AstNodeWrapper
+        ), "Input non-AstNodeWrapper node for the initialization of TypeHintTransformer."
+        self.wrapper_root = wrapper_root
+        self.root = wrapper_root.node
+
+    def transform(self):
+        self.visit(self.root)
+
+    def visit_FunctionDef(self, node):
+        node.returns = None
+        self.generic_visit(node)
+        return node
+
+    def visit_Name(self, node):
+        node.annotation = None
+        self.generic_visit(node)
+        return node
diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py
index c4b145df5d..9321068de7 100644
--- a/python/paddle/fluid/framework.py
+++ b/python/paddle/fluid/framework.py
@@ -511,6 +511,17 @@ def _dygraph_only_(func):
     return __impl__
 
 
+def _non_static_only_(func):
+
+    def __impl__(*args, **kwargs):
+        from .dygraph.base import in_declarative_mode
+        assert _non_static_mode() or in_declarative_mode(
+        ), "We only support '%s()' in dynamic graph mode or under @to_static, please call 'paddle.disable_static()' to enter dynamic graph mode." % func.__name__
+        return func(*args, **kwargs)
+
+    return __impl__
+
+
 def _static_only_(func):
 
     def __impl__(*args, **kwargs):
@@ -570,6 +581,7 @@ dygraph_not_support = wrap_decorator(_dygraph_not_support_)
 dygraph_only = wrap_decorator(_dygraph_only_)
 static_only = wrap_decorator(_static_only_)
 fake_interface_only = wrap_decorator(_fake_interface_only_)
+non_static_only = wrap_decorator(_non_static_only_)
 
 
 def _dygraph_tracer():
diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_typehint.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_typehint.py
new file mode 100644
index 0000000000..b8addd53d5
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_typehint.py
@@ -0,0 +1,79 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import paddle.fluid as fluid
+import unittest
+
+from paddle.fluid.dygraph.jit import declarative
+
+SEED = 2020
+np.random.seed(SEED)
+
+
+class A:
+    pass
+
+
+def function(x: A) -> A:
+    t: A = A()
+    return 2 * x
+
+
+class TestTransformWhileLoop(unittest.TestCase):
+
+    def setUp(self):
+        self.place = fluid.CUDAPlace(
+            0) if fluid.is_compiled_with_cuda() else fluid.CPUPlace()
+        self.x = np.zeros(shape=(1), dtype=np.int32)
+        self._init_dyfunc()
+
+    def _init_dyfunc(self):
+        self.dyfunc = function
+
+    def _run_static(self):
+        return self._run(to_static=True)
+
+    def _run_dygraph(self):
+        return self._run(to_static=False)
+
+    def _run(self, to_static):
+        with fluid.dygraph.guard(self.place):
+            # Set the input of dyfunc to VarBase
+            tensor_x = fluid.dygraph.to_variable(self.x, zero_copy=False)
+            if to_static:
+                ret = declarative(self.dyfunc)(tensor_x)
+            else:
+                ret = self.dyfunc(tensor_x)
+            if hasattr(ret, "numpy"):
+                return ret.numpy()
+            else:
+                return ret
+
+    def test_ast_to_func(self):
+        static_numpy = self._run_static()
+        dygraph_numpy = self._run_dygraph()
+        print(static_numpy, dygraph_numpy)
+        np.testing.assert_allclose(dygraph_numpy, static_numpy, rtol=1e-05)
+
+
+class TestTypeHint(TestTransformWhileLoop):
+
+    def _init_dyfunc(self):
+        self.dyfunc = function
+
+
+if __name__ == '__main__':
+    with fluid.framework._test_eager_guard():
+        unittest.main()
-- 
GitLab
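
A minimal sketch of what the TypeHint support enables, mirroring the new test
(the concrete function and values are illustrative, not part of the patch): a
fully annotated function can now be decorated for dy2static, because
TypeHintTransformer clears argument and return annotations (gast keeps
argument annotations on gast.Name nodes) before the other passes run.

import paddle

class A:
    pass

@paddle.jit.to_static
def function(x: A) -> A:
    # The annotations above and the annotated assignment below are
    # stripped from the transformed AST by TypeHintTransformer.
    t: A = A()
    return 2 * x

print(function(paddle.to_tensor([1.0, 2.0])))  # Tensor [2., 4.]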
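
A sketch of the paddle.grad unification (again illustrative): inside a
dy2static function, paddle.grad now dispatches at runtime to
paddle.static.gradients instead of being rewritten by the removed
GradTransformer, and the new @framework.non_static_only decorator allows the
call to proceed under @to_static. Dygraph-only arguments such as retain_graph
are reported by declarative_unsupport_argument_warning and discarded.

import paddle

@paddle.jit.to_static
def square_grad(x):
    x.stop_gradient = False
    y = x * x
    # Under @to_static this call returns paddle.static.gradients(...);
    # in plain dygraph mode it remains the imperative paddle.grad.
    return paddle.grad(outputs=[y], inputs=[x])[0]

print(square_grad(paddle.to_tensor([3.0])))  # dy/dx = 2 * x = [6.]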
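
Finally, a sketch of why the einsum_impl.h change matters (the shapes and
InputSpec below are illustrative): with dynamic dimensions, static shape
inference sees -1, and the old else branch enforced equality between a known
label extent and -1. The patched branch skips the check whenever
op_dim[dim_ptr] is -1.

import paddle
from paddle.static import InputSpec

@paddle.jit.to_static(input_spec=[
    InputSpec(shape=[None, 8], dtype='float32'),     # label 'j' known: 8
    InputSpec(shape=[None, None], dtype='float32'),  # label 'j' unknown: -1
])
def contract(x, y):
    # Before the fix, InferLabelShape could raise PADDLE_ENFORCE_EQ(8, -1)
    # for label 'j' while building the program; now -1 dims are skipped.
    return paddle.einsum('ij,jk->ik', x, y)

out = contract(paddle.randn([4, 8]), paddle.randn([8, 5]))
print(out.shape)  # [4, 5]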