From be3908a3326748b3da1ddceba2393516d93f38b3 Mon Sep 17 00:00:00 2001 From: xiongkun Date: Wed, 19 Oct 2022 10:57:42 +0800 Subject: [PATCH] [Dy2Static] Remove GradTransformer (#47063) * [Dy2Static] Remove GradTransformer 1. fix einsum infershape bugs. 2. remove grad_transformer and unify paddle.grad and paddle.static.gradient. 3. add dygraph_and_dy2static_only decorator for dy2static. * fix bugs * rename --- paddle/phi/kernels/impl/einsum_impl.h | 2 +- python/paddle/fluid/dygraph/base.py | 31 ++++++++++++++++++- .../dygraph_to_static/ast_transformer.py | 2 +- python/paddle/fluid/framework.py | 12 +++++++ 4 files changed, 44 insertions(+), 3 deletions(-) diff --git a/paddle/phi/kernels/impl/einsum_impl.h b/paddle/phi/kernels/impl/einsum_impl.h index 80529c8b66..dafb967ae8 100644 --- a/paddle/phi/kernels/impl/einsum_impl.h +++ b/paddle/phi/kernels/impl/einsum_impl.h @@ -241,7 +241,7 @@ inline static void InferLabelShape(const std::vector& op_labels, } else if (labelshape->is_default(c) || (*labelshape)[c] == -1) { (*labelshape)[c] = op_dim[dim_ptr]; dim_ptr++; - } else { + } else if (op_dim[dim_ptr] != -1) { PADDLE_ENFORCE_EQ( (*labelshape)[c], op_dim[dim_ptr], diff --git a/python/paddle/fluid/dygraph/base.py b/python/paddle/fluid/dygraph/base.py index 5101483858..82ef806583 100644 --- a/python/paddle/fluid/dygraph/base.py +++ b/python/paddle/fluid/dygraph/base.py @@ -27,6 +27,7 @@ from ..data_feeder import convert_dtype import warnings from ..framework import _get_paddle_place, _in_legacy_dygraph, _in_eager_without_dygraph_check import paddle +import warnings __all__ = [ 'no_grad', 'no_grad_', 'grad', 'guard', 'enable_dygraph', 'disable_dygraph', @@ -45,6 +46,20 @@ def in_declarative_mode(): return _in_declarative_mode_ +def declarative_unsupport_argument_warning(func_name, input_names, inputs, + support_values): + """ + Warning if inputs do not elementwisely equals to support_values. + It's a utility function for dy2static when dygraph interface have + more inputs than static interface such as paddle.grad. + + """ + for name, inp, sup in zip(input_names, inputs, support_values): + if inp != sup: + warnings.warn(f"{func_name} has unsupported parameter in jit: " + + f"{name}, jit will discard it") + + def _switch_to_static_graph_(func): def __impl__(*args, **kwargs): @@ -290,6 +305,10 @@ def no_grad(func=None): test_layer() """ + if in_declarative_mode(): + warnings.warn( + "paddle.no_grad is only supported for inference model, and not supported for training under @to_static." + ) if func is None: return _switch_tracer_mode_guard_(is_train=False) else: @@ -428,7 +447,7 @@ def guard(place=None): yield -@framework.dygraph_only +@framework.non_static_only def grad(outputs, inputs, grad_outputs=None, @@ -563,6 +582,16 @@ def grad(outputs, grad_y1 = paddle.to_tensor(3.0) print(test_dygraph_grad([grad_y1, grad_value])) # [24.] ''' + if in_declarative_mode(): + # In dy2static context, we call static interface `gradients` + # to calculate grads. + from paddle.static import gradients + declarative_unsupport_argument_warning( + "paddle.grad", + ["retain_graph", "create_grad", "only_inputs", "allow_unused"], + [retain_graph, create_graph, only_inputs, allow_unused], + [None, False, True, False]) + return gradients(outputs, inputs, grad_outputs, no_grad_vars) def check_in_out(in_out_list, name): assert in_out_list is not None, "{} should not be None".format(name) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py index 0853b15e68..5ec1dbea50 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py @@ -102,7 +102,7 @@ class DygraphToStaticAst(BaseTransformer): PrintTransformer, # print statement CallTransformer, # transform call recursively CastTransformer, # type casting statement - GradTransformer, # transform paddle.grad to paddle.gradients + #GradTransformer, # transform paddle.grad to paddle.gradients DecoratorTransformer, # transform decorators to function call ] diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 40223d4bcc..85d525ab5f 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -513,6 +513,17 @@ def _dygraph_only_(func): return __impl__ +def _non_static_only_(func): + + def __impl__(*args, **kwargs): + from .dygraph.base import in_declarative_mode + assert _non_static_mode() or in_declarative_mode( + ), "We only support '%s()' in dynamic graph mode, please call 'paddle.disable_static()' to enter dynamic graph mode." % func.__name__ + return func(*args, **kwargs) + + return __impl__ + + def _static_only_(func): def __impl__(*args, **kwargs): @@ -572,6 +583,7 @@ dygraph_not_support = wrap_decorator(_dygraph_not_support_) dygraph_only = wrap_decorator(_dygraph_only_) static_only = wrap_decorator(_static_only_) fake_interface_only = wrap_decorator(_fake_interface_only_) +non_static_only = wrap_decorator(_non_static_only_) def _dygraph_tracer(): -- GitLab