Unverified · Commit 98debaa8 · authored by GGBond8488 · committed by GitHub

【Inplace】Add copy for inplace (#54683)

* add clone for inpalce

* fix name

* add inplace pow

* fix typro

* add note

* fix typro

* fix typro

* fix bug

* fix test error

* add type error test

* adjust indentation
Parent 86858a5a
@@ -226,7 +226,17 @@ FORWARD_FUNCTION_TEMPLATE = """
   VLOG(5) << \"Running C++ API: \" << \"{}\";
   // Before log info
 {}
+
+  bool trace_backward = egr::Controller::Instance().HasGrad();
+  bool require_any_grad = egr::EagerUtils::ComputeRequireGrad({});
+
+  // Node Declaration
+  std::shared_ptr<{}> grad_node;
+
+  // Set grad_node before API Call
+{}
+
   // Forward API Call
 {}
   // Check NaN and Inf if needed
 {}
@@ -234,12 +244,9 @@ FORWARD_FUNCTION_TEMPLATE = """
 {}
   // Get Output AutoGradMeta
 {}
-
-  bool trace_backward = egr::Controller::Instance().HasGrad();
-  bool require_any_grad = egr::EagerUtils::ComputeRequireGrad({});
   // Check Inplace if needed
 {}{}
-  // Node Creation
+  // Set grad_node after API call
 {}
 
   VLOG(4) << \"Finish AD API: {}\";
@@ -296,10 +303,8 @@ FORWARD_ONLY_FUNCTION_TEMPLATE = """
 }}
 """

-FORWARD_BODY_TEMPLATE = """  if(require_any_grad) {{
+FORWARD_BODY_BEFORE_API_CALL_TEMPLATE = """  if(require_any_grad) {{
 {}
-    egr::EagerUtils::PassStopGradient({});
-
     // Node Construction
 {}
     // Set for forward trace
@@ -310,6 +315,13 @@ FORWARD_BODY_TEMPLATE = """  if(require_any_grad) {{
 {}
     // Set TensorWrappers for Forward Inputs if needed
 {}
+  }}
+"""
+
+FORWARD_BODY_AFTER_API_CALL_TEMPLATE = """  if(require_any_grad) {{
+    egr::EagerUtils::PassStopGradient({});
+
     // SetGradOutMeta & SetEdges
 {}
     // SetOutRank & SetHistory & SetGradInMeta
@@ -913,7 +925,7 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
         pass_stop_gradient_args_str = ",".join(pass_stop_gradient_args_list)
         return pass_stop_gradient_args_str

-    def GenerateNodeCreationCodes(self, for_backward=False):
+    def GenerateNodeCreationCodes(self, for_backward=False, is_inplaced=False):
         forward_api_name = self.forward_api_name
         forward_inputs_position_map = self.forward_inputs_position_map
         forward_outputs_position_map = self.forward_outputs_position_map
@@ -936,6 +948,7 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
         num_backward_inputs = len(forward_outputs_position_map.keys())
         num_backward_outputs = len(forward_inputs_position_map.keys())
         grad_node_name = GetGradNodeName(self.backward_api_name)
+        self.grad_node_name = grad_node_name

         # Helper
         indent = GetIndent(2)
@@ -945,6 +958,7 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
         # See https://stackoverflow.com/questions/31228656/how-can-shared-ptr-disrupt-alignment
         # and https://github.com/MRtrix3/mrtrix3/issues/957
         node_construction_str = f"{indent}auto grad_node = std::shared_ptr<{grad_node_name}>(new {grad_node_name}({num_backward_inputs}, {num_backward_outputs}));"
+        node_assignment_str = f"{indent}grad_node = std::shared_ptr<{grad_node_name}>(new {grad_node_name}({num_backward_inputs}, {num_backward_outputs}));"

         # SetAttributes
         set_attributes_list = []
@@ -972,14 +986,25 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
             pos,
         ) in backward_forward_inputs_map.items():
             is_optional = name in optional_inputs
+            is_inplace_input = (
+                is_inplaced and name in self.forward_inplace_map.keys()
+            )
             if is_fwd_input:
                 if is_optional:
-                    set_tensor_wrappers = f"{indent}if({name}) grad_node->SetTensorWrapper{name}(*{name});"
+                    if is_inplace_input:
+                        set_tensor_wrappers = """{indent}if({name}) {{
+    auto {name}_clone = paddle::experimental::assign({name});
+    grad_node->SetTensorWrapper{name}(*{name}_clone);}}""".format_map(
+                            {"indent": indent, "name": name}
+                        )
+                    else:
+                        set_tensor_wrappers = f"{indent}if({name}) grad_node->SetTensorWrapper{name}(*{name});"
                 else:
-                    set_tensor_wrappers = (
-                        f"{indent}grad_node->SetTensorWrapper{name}({name});"
-                    )
+                    if is_inplace_input:
+                        set_tensor_wrappers = f"{indent}auto {name}_clone = paddle::experimental::assign({name});\n{indent}grad_node->SetTensorWrapper{name}({name}_clone);"
+                    else:
+                        set_tensor_wrappers = f"{indent}grad_node->SetTensorWrapper{name}({name});"
                 set_input_tensor_wrappers_list.append(set_tensor_wrappers)
             else:  # Forwad's output as backward's input
                 if num_fwd_outputs > 1:
@@ -1073,18 +1098,25 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
         node_event_name = forward_api_name + " node_creation"
         node_creation_event_str = f"{indent}paddle::platform::RecordEvent node_creation_record_event(\"{node_event_name}\", paddle::platform::TracerEventType::OperatorInner, 1);\n"

+        self.node_creation_str = ""
         if not for_backward:
-            self.node_creation_str = FORWARD_BODY_TEMPLATE.format(
-                node_creation_event_str,
-                pass_stop_gradient_args_str,
-                node_construction_str,
-                set_attributes_str,
-                set_input_tensor_wrappers_str,
-                set_grad_out_meta_str,
-                set_out_rank_str,
-                set_history_str,
-                set_grad_in_meta_str,
-                set_output_tensor_wrappers_str,
-            )
+            self.node_creation_before_call_str = (
+                FORWARD_BODY_BEFORE_API_CALL_TEMPLATE.format(
+                    node_creation_event_str,
+                    node_assignment_str,
+                    set_attributes_str,
+                    set_input_tensor_wrappers_str,
+                )
+            )
+            self.node_creation_after_call_str = (
+                FORWARD_BODY_AFTER_API_CALL_TEMPLATE.format(
+                    pass_stop_gradient_args_str,
+                    set_grad_out_meta_str,
+                    set_out_rank_str,
+                    set_history_str,
+                    set_grad_in_meta_str,
+                    set_output_tensor_wrappers_str,
+                )
+            )
         else:
             self.node_creation_str = (
@@ -1614,8 +1646,10 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
         outputs_autograd_meta_str = "\n".join(outputs_autograd_meta_list)

         # Node Creation
-        self.GenerateNodeCreationCodes()
+        self.GenerateNodeCreationCodes(is_inplaced=is_inplaced)
         node_creation_str = self.node_creation_str
+        node_creation_before_call_str = self.node_creation_before_call_str
+        node_creation_after_call_str = self.node_creation_after_call_str

         dygraph_event_str = f"{indent}paddle::platform::RecordEvent dygraph_entrance_record_event(\"{forward_api_name} dygraph\", paddle::platform::TracerEventType::Operator, 1);\n"

         forward_ad_function_name = GetDygraphForwardFunctionName(
@@ -1725,14 +1759,16 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
             inputs_autograd_meta_str,
             forward_api_name,
             before_log_str,
+            compute_require_grad_args_str,
+            self.grad_node_name,
+            node_creation_before_call_str,
             forward_call_str,
             check_nan_inf_str,
             get_outputs_str,
             outputs_autograd_meta_str,
-            compute_require_grad_args_str,
             check_inplace_str,
             bump_inplace_version_str,
-            node_creation_str,
+            node_creation_after_call_str,
             forward_api_name,
             log_str,
             returns_str,
@@ -1881,7 +1917,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
                 namespace,
             )
             next_node_generator.run()
-            next_node_generator.GenerateNodeCreationCodes(True)
+            next_node_generator.GenerateNodeCreationCodes(for_backward=True)

             next_grad_node_creation_str = next_node_generator.node_creation_str
             next_grad_node_out_list = next_node_generator.grad_node_out_list
......
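The restructuring above moves grad-node construction and input TensorWrapper capture to before the forward API call: an inplace kernel overwrites its input, so the snapshot (here done via paddle::experimental::assign) must be taken first. A minimal sketch of the ordering, not part of the commit's diff — plain Python stand-ins, where run_inplace_kernel and the saved dict are hypothetical substitutes for the generated C++ kernel call and the grad node's TensorWrappers:

```python
# Schematic only: plain-Python stand-ins for the generated C++ forward body.
def run_inplace_kernel(x):        # hypothetical kernel: squares x in place
    for i in range(len(x)):
        x[i] = x[i] ** 2
    return x

def forward_old(x, saved):
    out = run_inplace_kernel(x)   # x is overwritten by the kernel
    saved['x'] = x                # too late: x already holds the output
    return out

def forward_new(x, saved):
    saved['x'] = list(x)          # snapshot taken before the API call
    out = run_inplace_kernel(x)   # backward can still see the real input
    return out

saved_old, saved_new = {}, {}
forward_old([1.0, 2.0], saved_old)
forward_new([1.0, 2.0], saved_new)
print(saved_old['x'], saved_new['x'])  # [1.0, 4.0] vs. [1.0, 2.0]
```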
@@ -254,6 +254,7 @@
     func : ElementwiseInferMeta
   kernel :
     func : elementwise_pow
+  inplace: (x -> out)
   backward : elementwise_pow_grad

 - op : embedding
......
@@ -1822,6 +1822,7 @@
   kernel :
     func : pow
     data_type : x
+  inplace: (x -> out)
   backward : pow_grad

 - op : prelu
......
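The `inplace: (x -> out)` entries mark the output as reusing the input's storage, which is what makes the code generator emit the `pow_` / `elementwise_pow_` inplace API variants alongside the out-of-place ones. A quick dygraph check of the resulting behavior, assuming a Paddle build that includes this commit:

```python
import paddle

x = paddle.to_tensor([1.0, 2.0, 3.0])
paddle.pow_(x, 2)    # generated because of the `inplace: (x -> out)` entry
print(x.numpy())     # [1. 4. 9.] -- x's own buffer now holds the result
```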
@@ -226,6 +226,7 @@ from .tensor.math import log2  # noqa: F401
 from .tensor.math import log10  # noqa: F401
 from .tensor.math import multiplex  # noqa: F401
 from .tensor.math import pow  # noqa: F401
+from .tensor.math import pow_  # noqa: F401
 from .tensor.math import reciprocal  # noqa: F401
 from .tensor.math import all  # noqa: F401
 from .tensor.math import any  # noqa: F401
@@ -561,6 +562,7 @@ __all__ = [  # noqa
     'abs',
     'tril',
     'pow',
+    'pow_',
     'zeros_like',
     'maximum',
     'topk',
......
@@ -164,6 +164,7 @@ from .math import increment  # noqa: F401
 from .math import log  # noqa: F401
 from .math import multiplex  # noqa: F401
 from .math import pow  # noqa: F401
+from .math import pow_  # noqa: F401
 from .math import reciprocal  # noqa: F401
 from .math import reciprocal_  # noqa: F401
 from .math import round  # noqa: F401
@@ -366,6 +367,7 @@ tensor_method_func = [  # noqa
     'logsumexp',
     'multiplex',
     'pow',
+    'pow_',
     'prod',
     'reciprocal',
     'reciprocal_',
......
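Registering `pow_` in `tensor_method_func` exposes it as a Tensor method in dygraph, in addition to the package-level function added above; both forms below should be equivalent (a sketch, assuming a build with this commit):

```python
import paddle

x = paddle.to_tensor([2.0, 3.0])
x.pow_(2)            # Tensor-method form, via the tensor_method_func entry
paddle.pow_(x, 2)    # free-function form, via the package-level import
```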
@@ -474,6 +474,22 @@ def pow(x, y, name=None):
         )


+@inplace_apis_in_dygraph_only
+def pow_(x, y, name=None):
+    """
+    Inplace version of ``pow`` API, the output Tensor will be inplaced with input ``x``.
+    Please refer to :ref:`api_tensor_pow`.
+    """
+    if isinstance(y, (int, float)):
+        return _C_ops.pow_(x, y)
+    elif isinstance(y, (paddle.Tensor, Variable)):
+        return _C_ops.elementwise_pow_(x, y)
+    else:
+        raise TypeError(
+            'y must be scalar or tensor type, but received: %s ' % (type(y))
+        )
+
+
 OP_NAMEMAPPING = {
     'elementwise_max': 'maximum',
     'elementwise_min': 'minimum',
......
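`pow_` mirrors the dispatch in `pow`: a Python scalar goes to the `pow` kernel, a Tensor exponent to `elementwise_pow`, and anything else raises `TypeError`. A short dygraph sketch of the three paths, assuming a build with this commit:

```python
import paddle

x = paddle.to_tensor([2.0, 3.0])
paddle.pow_(x, 3)                      # int/float -> _C_ops.pow_
y = paddle.to_tensor([2.0, 2.0])
paddle.pow_(x, y)                      # Tensor -> _C_ops.elementwise_pow_
try:
    paddle.pow_(x, [2])                # anything else -> TypeError
except TypeError as exc:
    print(exc)
```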
@@ -22,6 +22,8 @@ from paddle.framework import in_dynamic_mode
 # NOTE(pangyoki): The Inplace APIs with underline(`_`) is only valid for the method of calling `_C_ops`
 # in dygraph mode. If static graph mode is used, the inplace mechanism will not be used, and the static method
 # of the original API will be called.
+# NOTE(GGBond8488): Simply run the original version of the API under the static graph mode has a low
+# probability that the result is inconsistent with the dynamic graph.
 def _inplace_apis_in_dygraph_only_(func):
     def __impl__(*args, **kwargs):
         if not in_dynamic_mode():
......
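This decorator is what keeps the true inplace path dygraph-only: under static graph mode it warns and falls back to the non-inplace original, which the new NOTE flags as very occasionally diverging from dygraph results. A condensed, self-contained sketch of that pattern — simplified from `_inplace_apis_in_dygraph_only_`, with a stub `in_dynamic_mode` and the fallback call elided:

```python
import warnings

def in_dynamic_mode():
    # stub standing in for paddle.framework.in_dynamic_mode
    return False

def inplace_apis_in_dygraph_only(func):
    def __impl__(*args, **kwargs):
        if not in_dynamic_mode():
            origin_api_name = func.__name__[:-1]      # e.g. pow_ -> pow
            warnings.warn(
                f"In static mode, {func.__name__}() falls back to "
                f"{origin_api_name}(), which may rarely differ from dygraph."
            )
            return None  # the real decorator invokes the original API here
        return func(*args, **kwargs)
    return __impl__
```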
@@ -15,6 +15,7 @@
 import unittest

 import numpy as np
+from test_inplace import TestDygraphInplace

 import paddle
 from paddle.fluid import core
@@ -213,5 +214,40 @@ class TestPowerError(unittest.TestCase):
         self.assertRaises(TypeError, paddle.pow, x, str(y))


+class TestInplacePowerScalar(TestDygraphInplace):
+    def set_np_compare_func(self):
+        self.np_compare = np.allclose
+
+    def inplace_api_processing(self, var):
+        return paddle.pow_(var, 2)
+
+    def non_inplace_api_processing(self, var):
+        return paddle.pow(var, 2)
+
+
+class TestInplacePowerTensor(TestDygraphInplace):
+    def init_data(self):
+        self.input_var_numpy = np.random.uniform(-5, 5, [10, 20, 1])
+        self.dtype = "float32"
+        self.y = paddle.ones([10, 20, 1], dtype="float32") * 2
+
+    def set_np_compare_func(self):
+        self.np_compare = np.allclose
+
+    def inplace_api_processing(self, var):
+        return paddle.pow_(var, self.y)
+
+    def non_inplace_api_processing(self, var):
+        return paddle.pow(var, self.y)
+
+    def test_type_error(self):
+        var = paddle.to_tensor(self.input_var_numpy, dtype=self.dtype)
+        with self.assertRaisesRegex(
+            TypeError,
+            'y must be scalar or tensor type, but received: %s ' % (type([2])),
+        ):
+            paddle.pow_(var, [2])
+
+
 if __name__ == '__main__':
     unittest.main()