From 98debaa89da18129c7bf62aa62be7a024dca313d Mon Sep 17 00:00:00 2001
From: GGBond8488 <33050871+GGBond8488@users.noreply.github.com>
Date: Wed, 28 Jun 2023 16:10:03 +0800
Subject: [PATCH] 【Inplace】Add copy for inplace (#54683)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* add clone for inplace

* fix name

* add inplace pow

* fix typo

* add note

* fix typo

* fix typo

* fix bug

* fix test error

* add type error test

* adjust indentation
---
 .../generator/eager_gen.py           | 92 +++++++++++++------
 paddle/phi/api/yaml/legacy_ops.yaml  |  1 +
 paddle/phi/api/yaml/ops.yaml         |  1 +
 python/paddle/__init__.py            |  2 +
 python/paddle/tensor/__init__.py     |  2 +
 python/paddle/tensor/math.py         | 16 ++++
 python/paddle/utils/inplace_utils.py |  2 +
 test/legacy_test/test_pow.py         | 36 ++++++++
 8 files changed, 124 insertions(+), 28 deletions(-)

diff --git a/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py b/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py
index fe3f73c845e..b90cb5bce70 100644
--- a/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py
+++ b/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py
@@ -226,7 +226,17 @@ FORWARD_FUNCTION_TEMPLATE = """
   VLOG(5) << \"Running C++ API: \" << \"{}\";
   // Before log info
 {}
-  // Forward API Call
+
+  bool trace_backward = egr::Controller::Instance().HasGrad();
+  bool require_any_grad = egr::EagerUtils::ComputeRequireGrad({});
+
+  // Node Declaration
+  std::shared_ptr<{}> grad_node;
+
+  // Set grad_node before API Call
+{}
+
+  // Forward API Call
 {}
   // Check NaN and Inf if needed
 {}
@@ -234,12 +244,9 @@ FORWARD_FUNCTION_TEMPLATE = """
 {}
   // Get Output AutoGradMeta
 {}
-  bool trace_backward = egr::Controller::Instance().HasGrad();
-  bool require_any_grad = egr::EagerUtils::ComputeRequireGrad({});
-
   // Check Inplace if needed
 {}{}
-  // Node Creation
+  // Set grad_node after API call
 {}

   VLOG(4) << \"Finish AD API: {}\";
@@ -296,10 +303,8 @@ FORWARD_ONLY_FUNCTION_TEMPLATE = """
 }}
 """

-FORWARD_BODY_TEMPLATE = """  if(require_any_grad) {{
+FORWARD_BODY_BEFORE_API_CALL_TEMPLATE = """  if(require_any_grad) {{
 {}
-    egr::EagerUtils::PassStopGradient({});
-
     // Node Construction
 {}
     // Set for forward trace
@@ -310,6 +315,13 @@ FORWARD_BODY_TEMPLATE = """  if(require_any_grad) {{
 {}
     // Set TensorWrappers for Forward Inputs if needed
 {}
+  }}
+"""
+
+FORWARD_BODY_AFTER_API_CALL_TEMPLATE = """  if(require_any_grad) {{
+
+    egr::EagerUtils::PassStopGradient({});
+
     // SetGradOutMeta & SetEdges
 {}
     // SetOutRank & SetHistory & SetGradInMeta
@@ -913,7 +925,7 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
         pass_stop_gradient_args_str = ",".join(pass_stop_gradient_args_list)
         return pass_stop_gradient_args_str

-    def GenerateNodeCreationCodes(self, for_backward=False):
+    def GenerateNodeCreationCodes(self, for_backward=False, is_inplaced=False):
         forward_api_name = self.forward_api_name
         forward_inputs_position_map = self.forward_inputs_position_map
         forward_outputs_position_map = self.forward_outputs_position_map
@@ -936,6 +948,7 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
         num_backward_inputs = len(forward_outputs_position_map.keys())
         num_backward_outputs = len(forward_inputs_position_map.keys())
         grad_node_name = GetGradNodeName(self.backward_api_name)
+        self.grad_node_name = grad_node_name

         # Helper
         indent = GetIndent(2)
@@ -945,6 +958,7 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
         # See https://stackoverflow.com/questions/31228656/how-can-shared-ptr-disrupt-alignment
         # and https://github.com/MRtrix3/mrtrix3/issues/957
         node_construction_str = f"{indent}auto grad_node = std::shared_ptr<{grad_node_name}>(new {grad_node_name}({num_backward_inputs}, {num_backward_outputs}));"
+        node_assignment_str = f"{indent}grad_node = std::shared_ptr<{grad_node_name}>(new {grad_node_name}({num_backward_inputs}, {num_backward_outputs}));"

         # SetAttributes
         set_attributes_list = []
@@ -972,14 +986,25 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
             is_fwd_input,
             pos,
         ) in backward_forward_inputs_map.items():
             is_optional = name in optional_inputs
+            is_inplace_input = (
+                is_inplaced and name in self.forward_inplace_map.keys()
+            )
             if is_fwd_input:
                 if is_optional:
-                    set_tensor_wrappers = f"{indent}if({name}) grad_node->SetTensorWrapper{name}(*{name});"
+                    if is_inplace_input:
+                        set_tensor_wrappers = """{indent}if({name}) {{
+  auto {name}_clone = paddle::experimental::assign({name});
+  grad_node->SetTensorWrapper{name}(*{name}_clone);}}""".format_map(
+                            {"indent": indent, "name": name}
+                        )
+                    else:
+                        set_tensor_wrappers = f"{indent}if({name}) grad_node->SetTensorWrapper{name}(*{name});"
                 else:
-                    set_tensor_wrappers = (
-                        f"{indent}grad_node->SetTensorWrapper{name}({name});"
-                    )
+                    if is_inplace_input:
+                        set_tensor_wrappers = f"{indent}auto {name}_clone = paddle::experimental::assign({name});\n{indent}grad_node->SetTensorWrapper{name}({name}_clone);"
+                    else:
+                        set_tensor_wrappers = f"{indent}grad_node->SetTensorWrapper{name}({name});"
                 set_input_tensor_wrappers_list.append(set_tensor_wrappers)
             else:  # Forwad's output as backward's input
                 if num_fwd_outputs > 1:
@@ -1073,18 +1098,25 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
         node_event_name = forward_api_name + " node_creation"
         node_creation_event_str = f"{indent}paddle::platform::RecordEvent node_creation_record_event(\"{node_event_name}\", paddle::platform::TracerEventType::OperatorInner, 1);\n"

+        self.node_creation_str = ""
         if not for_backward:
-            self.node_creation_str = FORWARD_BODY_TEMPLATE.format(
-                node_creation_event_str,
-                pass_stop_gradient_args_str,
-                node_construction_str,
-                set_attributes_str,
-                set_input_tensor_wrappers_str,
-                set_grad_out_meta_str,
-                set_out_rank_str,
-                set_history_str,
-                set_grad_in_meta_str,
-                set_output_tensor_wrappers_str,
+            self.node_creation_before_call_str = (
+                FORWARD_BODY_BEFORE_API_CALL_TEMPLATE.format(
+                    node_creation_event_str,
+                    node_assignment_str,
+                    set_attributes_str,
+                    set_input_tensor_wrappers_str,
+                )
+            )
+            self.node_creation_after_call_str = (
+                FORWARD_BODY_AFTER_API_CALL_TEMPLATE.format(
+                    pass_stop_gradient_args_str,
+                    set_grad_out_meta_str,
+                    set_out_rank_str,
+                    set_history_str,
+                    set_grad_in_meta_str,
+                    set_output_tensor_wrappers_str,
+                )
             )
         else:
             self.node_creation_str = (
@@ -1614,8 +1646,10 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
         outputs_autograd_meta_str = "\n".join(outputs_autograd_meta_list)

         # Node Creation
-        self.GenerateNodeCreationCodes()
+        self.GenerateNodeCreationCodes(is_inplaced=is_inplaced)
         node_creation_str = self.node_creation_str
+        node_creation_before_call_str = self.node_creation_before_call_str
+        node_creation_after_call_str = self.node_creation_after_call_str

         dygraph_event_str = f"{indent}paddle::platform::RecordEvent dygraph_entrance_record_event(\"{forward_api_name} dygraph\", paddle::platform::TracerEventType::Operator, 1);\n"
         forward_ad_function_name = GetDygraphForwardFunctionName(
@@ -1725,14 +1759,16 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
             inputs_autograd_meta_str,
             forward_api_name,
             before_log_str,
+            compute_require_grad_args_str,
+            self.grad_node_name,
+            node_creation_before_call_str,
             forward_call_str,
             check_nan_inf_str,
             get_outputs_str,
             outputs_autograd_meta_str,
-            compute_require_grad_args_str,
             check_inplace_str,
             bump_inplace_version_str,
-            node_creation_str,
+            node_creation_after_call_str,
             forward_api_name,
             log_str,
             returns_str,
@@ -1881,7 +1917,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
                 namespace,
             )
             next_node_generator.run()
-            next_node_generator.GenerateNodeCreationCodes(True)
+            next_node_generator.GenerateNodeCreationCodes(for_backward=True)

             next_grad_node_creation_str = next_node_generator.node_creation_str
             next_grad_node_out_list = next_node_generator.grad_node_out_list
diff --git a/paddle/phi/api/yaml/legacy_ops.yaml b/paddle/phi/api/yaml/legacy_ops.yaml
index 9d660f4be9a..fd0d4c1c520 100755
--- a/paddle/phi/api/yaml/legacy_ops.yaml
+++ b/paddle/phi/api/yaml/legacy_ops.yaml
@@ -254,6 +254,7 @@
     func : ElementwiseInferMeta
   kernel :
     func : elementwise_pow
+  inplace: (x -> out)
   backward : elementwise_pow_grad

 - op : embedding
diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml
index 3b2d4623f9a..34c41c1d0a2 100644
--- a/paddle/phi/api/yaml/ops.yaml
+++ b/paddle/phi/api/yaml/ops.yaml
@@ -1822,6 +1822,7 @@
   kernel :
     func : pow
     data_type : x
+  inplace: (x -> out)
   backward : pow_grad

 - op : prelu
diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py
index f4b262573e9..4963ad8b511 100644
--- a/python/paddle/__init__.py
+++ b/python/paddle/__init__.py
@@ -226,6 +226,7 @@ from .tensor.math import log2  # noqa: F401
 from .tensor.math import log10  # noqa: F401
 from .tensor.math import multiplex  # noqa: F401
 from .tensor.math import pow  # noqa: F401
+from .tensor.math import pow_  # noqa: F401
 from .tensor.math import reciprocal  # noqa: F401
 from .tensor.math import all  # noqa: F401
 from .tensor.math import any  # noqa: F401
@@ -561,6 +562,7 @@ __all__ = [  # noqa
     'abs',
     'tril',
     'pow',
+    'pow_',
     'zeros_like',
     'maximum',
     'topk',
diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py
index 819a731067b..95623f145b6 100644
--- a/python/paddle/tensor/__init__.py
+++ b/python/paddle/tensor/__init__.py
@@ -164,6 +164,7 @@ from .math import increment  # noqa: F401
 from .math import log  # noqa: F401
 from .math import multiplex  # noqa: F401
 from .math import pow  # noqa: F401
+from .math import pow_  # noqa: F401
 from .math import reciprocal  # noqa: F401
 from .math import reciprocal_  # noqa: F401
 from .math import round  # noqa: F401
@@ -366,6 +367,7 @@ tensor_method_func = [  # noqa
     'logsumexp',
     'multiplex',
     'pow',
+    'pow_',
     'prod',
     'reciprocal',
     'reciprocal_',
diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py
index 9aa77730262..8b5af17b86f 100644
--- a/python/paddle/tensor/math.py
+++ b/python/paddle/tensor/math.py
@@ -474,6 +474,22 @@ def pow(x, y, name=None):
         )


+@inplace_apis_in_dygraph_only
+def pow_(x, y, name=None):
+    """
+    Inplace version of ``pow`` API, the output Tensor will be inplaced with input ``x``.
+    Please refer to :ref:`api_tensor_pow`.
+ """ + if isinstance(y, (int, float)): + return _C_ops.pow_(x, y) + elif isinstance(y, (paddle.Tensor, Variable)): + return _C_ops.elementwise_pow_(x, y) + else: + raise TypeError( + 'y must be scalar or tensor type, but received: %s ' % (type(y)) + ) + + OP_NAMEMAPPING = { 'elementwise_max': 'maximum', 'elementwise_min': 'minimum', diff --git a/python/paddle/utils/inplace_utils.py b/python/paddle/utils/inplace_utils.py index e02ddbeb758..934dd314a35 100644 --- a/python/paddle/utils/inplace_utils.py +++ b/python/paddle/utils/inplace_utils.py @@ -22,6 +22,8 @@ from paddle.framework import in_dynamic_mode # NOTE(pangyoki): The Inplace APIs with underline(`_`) is only valid for the method of calling `_C_ops` # in dygraph mode. If static graph mode is used, the inplace mechanism will not be used, and the static method # of the original API will be called. +# NOTE(GGBond8488): Simply run the original version of the API under the static graph mode has a low +# probability that the result is inconsistent with the dynamic graph. def _inplace_apis_in_dygraph_only_(func): def __impl__(*args, **kwargs): if not in_dynamic_mode(): diff --git a/test/legacy_test/test_pow.py b/test/legacy_test/test_pow.py index 011593b3e87..e829230492e 100755 --- a/test/legacy_test/test_pow.py +++ b/test/legacy_test/test_pow.py @@ -15,6 +15,7 @@ import unittest import numpy as np +from test_inplace import TestDygraphInplace import paddle from paddle.fluid import core @@ -213,5 +214,40 @@ class TestPowerError(unittest.TestCase): self.assertRaises(TypeError, paddle.pow, x, str(y)) +class TestInplacePowerScalar(TestDygraphInplace): + def set_np_compare_func(self): + self.np_compare = np.allclose + + def inplace_api_processing(self, var): + return paddle.pow_(var, 2) + + def non_inplace_api_processing(self, var): + return paddle.pow(var, 2) + + +class TestInplacePowerTensor(TestDygraphInplace): + def init_data(self): + self.input_var_numpy = np.random.uniform(-5, 5, [10, 20, 1]) + self.dtype = "float32" + self.y = paddle.ones([10, 20, 1], dtype="float32") * 2 + + def set_np_compare_func(self): + self.np_compare = np.allclose + + def inplace_api_processing(self, var): + return paddle.pow_(var, self.y) + + def non_inplace_api_processing(self, var): + return paddle.pow(var, self.y) + + def test_type_error(self): + var = paddle.to_tensor(self.input_var_numpy, dtype=self.dtype) + with self.assertRaisesRegex( + TypeError, + 'y must be scalar or tensor type, but received: %s ' % (type([2])), + ): + paddle.pow_(var, [2]) + + if __name__ == '__main__': unittest.main() -- GitLab