未验证 提交 98debaa8 编写于 作者: G GGBond8488 提交者: GitHub

【Inplace】Add copy for inplace (#54683)

* add clone for inplace

* fix name

* add inplace pow

* fix typo

* add note

* fix typo

* fix typo

* fix bug

* fix test error

* add type error test

* adjust indentation
上级 86858a5a
......@@ -226,7 +226,17 @@ FORWARD_FUNCTION_TEMPLATE = """
VLOG(5) << \"Running C++ API: \" << \"{}\";
// Before log info
{}
// Forward API Call
bool trace_backward = egr::Controller::Instance().HasGrad();
bool require_any_grad = egr::EagerUtils::ComputeRequireGrad({});
// Node Declaration
std::shared_ptr<{}> grad_node;
// Set grad_node before API Call
{}
// Forward API Call
{}
// Check NaN and Inf if needed
{}
......@@ -234,12 +244,9 @@ FORWARD_FUNCTION_TEMPLATE = """
{}
// Get Output AutoGradMeta
{}
bool trace_backward = egr::Controller::Instance().HasGrad();
bool require_any_grad = egr::EagerUtils::ComputeRequireGrad({});
// Check Inplace if needed
{}{}
// Node Creation
// Set grad_node after API call
{}
VLOG(4) << \"Finish AD API: {}";
......@@ -296,10 +303,8 @@ FORWARD_ONLY_FUNCTION_TEMPLATE = """
}}
"""
FORWARD_BODY_TEMPLATE = """ if(require_any_grad) {{
FORWARD_BODY_BEFORE_API_CALL_TEMPLATE = """ if(require_any_grad) {{
{}
egr::EagerUtils::PassStopGradient({});
// Node Construction
{}
// Set for forward trace
......@@ -310,6 +315,13 @@ FORWARD_BODY_TEMPLATE = """ if(require_any_grad) {{
{}
// Set TensorWrappers for Forward Inputs if needed
{}
}}
"""
FORWARD_BODY_AFTER_API_CALL_TEMPLATE = """ if(require_any_grad) {{
egr::EagerUtils::PassStopGradient({});
// SetGradOutMeta & SetEdges
{}
// SetOutRank & SetHistory & SetGradInMeta
......@@ -913,7 +925,7 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
pass_stop_gradient_args_str = ",".join(pass_stop_gradient_args_list)
return pass_stop_gradient_args_str
def GenerateNodeCreationCodes(self, for_backward=False):
def GenerateNodeCreationCodes(self, for_backward=False, is_inplaced=False):
forward_api_name = self.forward_api_name
forward_inputs_position_map = self.forward_inputs_position_map
forward_outputs_position_map = self.forward_outputs_position_map
......@@ -936,6 +948,7 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
num_backward_inputs = len(forward_outputs_position_map.keys())
num_backward_outputs = len(forward_inputs_position_map.keys())
grad_node_name = GetGradNodeName(self.backward_api_name)
self.grad_node_name = grad_node_name
# Helper
indent = GetIndent(2)
......@@ -945,6 +958,7 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
# See https://stackoverflow.com/questions/31228656/how-can-shared-ptr-disrupt-alignment
# and https://github.com/MRtrix3/mrtrix3/issues/957
node_construction_str = f"{indent}auto grad_node = std::shared_ptr<{grad_node_name}>(new {grad_node_name}({num_backward_inputs}, {num_backward_outputs}));"
node_assignment_str = f"{indent}grad_node = std::shared_ptr<{grad_node_name}>(new {grad_node_name}({num_backward_inputs}, {num_backward_outputs}));"
# SetAttributes
set_attributes_list = []
......@@ -972,14 +986,25 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
pos,
) in backward_forward_inputs_map.items():
is_optional = name in optional_inputs
is_inplace_input = (
is_inplaced and name in self.forward_inplace_map.keys()
)
if is_fwd_input:
if is_optional:
set_tensor_wrappers = f"{indent}if({name}) grad_node->SetTensorWrapper{name}(*{name});"
if is_inplace_input:
set_tensor_wrappers = """{indent}if({name}) {
auto {name}_clone = paddle::experimental::assign({name});
grad_node->SetTensorWrapper{name}(*{name}_clone);}""".format_map(
{"indent": indent, "name": name}
)
else:
set_tensor_wrappers = f"{indent}if({name}) grad_node->SetTensorWrapper{name}(*{name});"
else:
set_tensor_wrappers = (
f"{indent}grad_node->SetTensorWrapper{name}({name});"
)
if is_inplace_input:
set_tensor_wrappers = f"{indent}auto {name}_clone = paddle::experimental::assign({name});\n{indent}grad_node->SetTensorWrapper{name}({name}_clone);"
else:
set_tensor_wrappers = f"{indent}grad_node->SetTensorWrapper{name}({name});"
set_input_tensor_wrappers_list.append(set_tensor_wrappers)
else: # Forwad's output as backward's input
if num_fwd_outputs > 1:
......@@ -1073,18 +1098,25 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
node_event_name = forward_api_name + " node_creation"
node_creation_event_str = f"{indent}paddle::platform::RecordEvent node_creation_record_event(\"{node_event_name}\", paddle::platform::TracerEventType::OperatorInner, 1);\n"
self.node_creation_str = ""
if not for_backward:
self.node_creation_str = FORWARD_BODY_TEMPLATE.format(
node_creation_event_str,
pass_stop_gradient_args_str,
node_construction_str,
set_attributes_str,
set_input_tensor_wrappers_str,
set_grad_out_meta_str,
set_out_rank_str,
set_history_str,
set_grad_in_meta_str,
set_output_tensor_wrappers_str,
self.node_creation_before_call_str = (
FORWARD_BODY_BEFORE_API_CALL_TEMPLATE.format(
node_creation_event_str,
node_assignment_str,
set_attributes_str,
set_input_tensor_wrappers_str,
)
)
self.node_creation_after_call_str = (
FORWARD_BODY_AFTER_API_CALL_TEMPLATE.format(
pass_stop_gradient_args_str,
set_grad_out_meta_str,
set_out_rank_str,
set_history_str,
set_grad_in_meta_str,
set_output_tensor_wrappers_str,
)
)
else:
self.node_creation_str = (
......@@ -1614,8 +1646,10 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
outputs_autograd_meta_str = "\n".join(outputs_autograd_meta_list)
# Node Creation
self.GenerateNodeCreationCodes()
self.GenerateNodeCreationCodes(is_inplaced=is_inplaced)
node_creation_str = self.node_creation_str
node_creation_before_call_str = self.node_creation_before_call_str
node_creation_after_call_str = self.node_creation_after_call_str
dygraph_event_str = f"{indent}paddle::platform::RecordEvent dygraph_entrance_record_event(\"{forward_api_name} dygraph\", paddle::platform::TracerEventType::Operator, 1);\n"
forward_ad_function_name = GetDygraphForwardFunctionName(
......@@ -1725,14 +1759,16 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
inputs_autograd_meta_str,
forward_api_name,
before_log_str,
compute_require_grad_args_str,
self.grad_node_name,
node_creation_before_call_str,
forward_call_str,
check_nan_inf_str,
get_outputs_str,
outputs_autograd_meta_str,
compute_require_grad_args_str,
check_inplace_str,
bump_inplace_version_str,
node_creation_str,
node_creation_after_call_str,
forward_api_name,
log_str,
returns_str,
......@@ -1881,7 +1917,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
namespace,
)
next_node_generator.run()
next_node_generator.GenerateNodeCreationCodes(True)
next_node_generator.GenerateNodeCreationCodes(for_backward=True)
next_grad_node_creation_str = next_node_generator.node_creation_str
next_grad_node_out_list = next_node_generator.grad_node_out_list
......
......@@ -254,6 +254,7 @@
func : ElementwiseInferMeta
kernel :
func : elementwise_pow
inplace: (x -> out)
backward : elementwise_pow_grad
- op : embedding
......
......@@ -1822,6 +1822,7 @@
kernel :
func : pow
data_type : x
inplace: (x -> out)
backward : pow_grad
- op : prelu
......
......@@ -226,6 +226,7 @@ from .tensor.math import log2 # noqa: F401
from .tensor.math import log10 # noqa: F401
from .tensor.math import multiplex # noqa: F401
from .tensor.math import pow # noqa: F401
from .tensor.math import pow_ # noqa: F401
from .tensor.math import reciprocal # noqa: F401
from .tensor.math import all # noqa: F401
from .tensor.math import any # noqa: F401
......@@ -561,6 +562,7 @@ __all__ = [ # noqa
'abs',
'tril',
'pow',
'pow_',
'zeros_like',
'maximum',
'topk',
......
......@@ -164,6 +164,7 @@ from .math import increment # noqa: F401
from .math import log # noqa: F401
from .math import multiplex # noqa: F401
from .math import pow # noqa: F401
from .math import pow_ # noqa: F401
from .math import reciprocal # noqa: F401
from .math import reciprocal_ # noqa: F401
from .math import round # noqa: F401
......@@ -366,6 +367,7 @@ tensor_method_func = [ # noqa
'logsumexp',
'multiplex',
'pow',
'pow_',
'prod',
'reciprocal',
'reciprocal_',
......
......@@ -474,6 +474,22 @@ def pow(x, y, name=None):
)
@inplace_apis_in_dygraph_only
def pow_(x, y, name=None):
    """
    Inplace version of ``pow`` API, the output Tensor will be inplaced with input ``x``.
    Please refer to :ref:`api_tensor_pow`.
    """
    # Dispatch on the exponent type: tensor exponents go through the
    # elementwise kernel, python scalars through the scalar kernel.
    if isinstance(y, (paddle.Tensor, Variable)):
        return _C_ops.elementwise_pow_(x, y)
    if isinstance(y, (int, float)):
        return _C_ops.pow_(x, y)
    # Any other exponent type is rejected; this exact message is asserted
    # in the unit tests, so it must not change.
    raise TypeError(
        'y must be scalar or tensor type, but received: %s ' % (type(y))
    )
OP_NAMEMAPPING = {
'elementwise_max': 'maximum',
'elementwise_min': 'minimum',
......
......@@ -22,6 +22,8 @@ from paddle.framework import in_dynamic_mode
# NOTE(pangyoki): The Inplace APIs with underline(`_`) is only valid for the method of calling `_C_ops`
# in dygraph mode. If static graph mode is used, the inplace mechanism will not be used, and the static method
# of the original API will be called.
# NOTE(GGBond8488): Simply running the original version of the API under static graph mode has a low
# probability of producing results that are inconsistent with the dynamic graph.
def _inplace_apis_in_dygraph_only_(func):
def __impl__(*args, **kwargs):
if not in_dynamic_mode():
......
......@@ -15,6 +15,7 @@
import unittest
import numpy as np
from test_inplace import TestDygraphInplace
import paddle
from paddle.fluid import core
......@@ -213,5 +214,40 @@ class TestPowerError(unittest.TestCase):
self.assertRaises(TypeError, paddle.pow, x, str(y))
class TestInplacePowerScalar(TestDygraphInplace):
    """Check inplace ``pow_`` against out-of-place ``pow`` for a scalar exponent."""

    # Exponent used by both the inplace and the reference call.
    exponent = 2

    def set_np_compare_func(self):
        # Power produces floats; compare with a tolerance instead of exact equality.
        self.np_compare = np.allclose

    def inplace_api_processing(self, var):
        return paddle.pow_(var, self.exponent)

    def non_inplace_api_processing(self, var):
        return paddle.pow(var, self.exponent)
class TestInplacePowerTensor(TestDygraphInplace):
    """Check inplace ``pow_`` against out-of-place ``pow`` for a tensor exponent."""

    def init_data(self):
        shape = [10, 20, 1]
        self.input_var_numpy = np.random.uniform(-5, 5, shape)
        self.dtype = "float32"
        # Exponent tensor filled with 2.0, same shape as the input.
        self.y = paddle.ones(shape, dtype="float32") * 2

    def set_np_compare_func(self):
        # Power produces floats; compare with a tolerance instead of exact equality.
        self.np_compare = np.allclose

    def inplace_api_processing(self, var):
        return paddle.pow_(var, self.y)

    def non_inplace_api_processing(self, var):
        return paddle.pow(var, self.y)

    def test_type_error(self):
        # pow_ must reject exponents that are neither scalars nor tensors.
        x = paddle.to_tensor(self.input_var_numpy, dtype=self.dtype)
        expected_msg = 'y must be scalar or tensor type, but received: %s ' % (
            type([2])
        )
        with self.assertRaisesRegex(TypeError, expected_msg):
            paddle.pow_(x, [2])
# Standard unittest entry point: run all test cases in this file when
# executed directly (e.g. `python test_pow.py`).
if __name__ == '__main__':
    unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册