Unverified commit 01724b1a, authored by Zhanlue Yang, committed by GitHub

[DoubleGrad #4] Bug Fixes to Double Grad Node Generation (#41121)

* [Refactor] refactored eager_gen.py PR #2

* [DoubleGrad PR #1] Decoupled code generation logic for Dygraph ForwardFunctions and GradNodes

* Fixed minor issue

* Adjusted logic of GenerateNodeCreationCodes and GenerateForwardDefinition

* Fixed issues

* Supported higher-order grad node generation

* [DoubleGrad PR #4] Supported higher-order GradNode generation

* [DoubleGrad #4] Bug Fixes to Double Grad Node Generation

* Fixed yaml typo

* Fixed yaml typo

* fixed minor issues

* Fixed minor issue
Parent 4da4265a
@@ -21,7 +21,8 @@ import os
 ########################
 ### Global Variables ###
 ########################
-ops_to_fill_zero_for_empty_grads = set(["split", "rnn"])
+ops_to_fill_zero_for_empty_grads = set(
+    ["split_grad", "rnn_grad", "matmul_double_grad"])

 # For API dispatch used at python-level
 # { op_name : [arg_name, ...] }
@@ -176,6 +177,11 @@ def TransformGradVarNameForDoubleGradGeneration(string):
     return string


+def GetIndent(num):
+    tab = " "
+    return "".join([tab for i in range(num)])


 ######################
 ### Yaml Parsers ###
 ######################
...
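Two small but load-bearing details in this hunk: the zero-fill set is now keyed by backward op names rather than forward op names, and GetIndent is the helper whose result gets prefixed to every generated C++ line. A brief, illustrative sketch of how the two are meant to be used downstream (the op and variable names are examples, not part of the patch):

# Membership is now checked with the backward API name, so this is True after the change:
backward_api_name = "matmul_double_grad"
needs_fill_zero = backward_api_name in ops_to_fill_zero_for_empty_grads

# GetIndent(num) simply repeats its tab string num times, producing the indentation
# prefix that the generator prepends to each emitted C++ statement:
indent = GetIndent(2)
emitted = f'{indent}VLOG(3) << "Final State Running";'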
@@ -32,7 +32,7 @@ from codegen_utils import ParseYamlForward, ParseYamlBackward
 from codegen_utils import FunctionGeneratorBase, YamlGeneratorBase
 from codegen_utils import ops_to_fill_zero_for_empty_grads
 from codegen_utils import TransformGradVarNameForDoubleGradGeneration
-from codegen_utils import AssertMessage
+from codegen_utils import AssertMessage, GetIndent

 ###########
@@ -112,80 +112,81 @@ ATTRIBUTE_MEMBER_TEMPLATE = \
 NODE_DECLARATION_TEMPLATE = \
 """
 class {} : public egr::GradNodeBase {{
  public:
   {}() : egr::GradNodeBase() {{}}
   {}(size_t bwd_in_slot_num, size_t bwd_out_slot_num) :
       egr::GradNodeBase(bwd_in_slot_num, bwd_out_slot_num) {{}}
   ~{}() override = default;

   virtual std::vector<std::vector<paddle::experimental::Tensor>> operator()(
       std::vector<std::vector<paddle::experimental::Tensor>>& grads, bool create_graph = false) override;

   std::string name() override {{ return \" {} \"; }}

   void ClearTensorWrappers() override {{
     {}
     is_tensor_wrappers_cleared = true;
   }}

   // SetTensorWrapperX, SetTensorWrapperY, ...
   {}

   // SetAttributes
   {}

   bool IsTensorWrappersCleared() override {{
     return is_tensor_wrappers_cleared;
   }}

  private:
   // TensorWrappers
   {}

   bool is_tensor_wrappers_cleared = false;

   // Attributes
   {}
 }};
 """
 GRAD_FUNCTION_TEMPLATE = \
 """
 std::vector<std::vector<paddle::experimental::Tensor>> {}::operator()(std::vector<std::vector<paddle::experimental::Tensor>>& grads, bool create_graph) {{
   // Fill Zero For GradIn Tensors
   {}

   // Apply Gradient Hooks
   auto hooked_grads = ApplyGradientHooks(grads);

   // Collect GradIn Tensors, Attrs and Recovered TensorWrappers
   {}

   // Call grad_api function
   VLOG(3) << \"Final State Running: \" << \"{}\";
   {}

   // Get Output
   {}

   // Get GradIn autograd_meta
   {}

   // Get GradOut autograd_meta
   {}

   // Compute Require Grad
   {}

   // Create Grad Node
   {}

   // Return
   {}
 }}
 """
 FORWARD_FUNCTION_TEMPLATE = \
 """
 {} {}({}) {{
   // Dygraph Record Event
   {}
   // AMP Logic
@@ -208,33 +209,33 @@ FORWARD_FUNCTION_TEMPLATE = \
   // Node Creation
   {}

   // Returns
   return {};
 }}
 """
 FORWARD_BODY_TEMPLATE = \
 """
   if(require_any_grad) {{
 {}
     egr::EagerUtils::PassStopGradient({});

     // Node Construction
 {}
     // SetAttributes
 {}
     // SetTensorWrappers
 {}
     // SetGradOutMeta & SetEdges
 {}
 {}
     // SetOutRank & SetHistory & SetGradInMeta & RetainGrad
 {}
 {}
 {}
 {}
   }}
 """
 NAMESPACE_WRAPPER_TEMPLATE = \
@@ -318,9 +319,9 @@ std::unordered_map<std::string, std::vector<std::string>> core_ops_final_state_r
 CORE_OPS_DECLARATION_TEMPLATE = \
 """
 extern std::unordered_map<std::string, std::vector<std::string>> core_ops_final_state_args_info;
 extern std::unordered_map<std::string, std::vector<std::string>> core_ops_final_state_args_type_info;
 extern std::unordered_map<std::string, std::vector<std::string>> core_ops_final_state_returns_info;
 """

@@ -352,6 +353,12 @@ AMP_LOGIC_TEMPLATE = \
 }}
 """
+CREATE_PLAIN_OPTIONAL_TENSOR_TEMPLATE = \
+"""
+  paddle::optional<const paddle::experimental::Tensor&> {}_optional = paddle::none;
+  if({}.initialized()) {}_optional = paddle::make_optional<const paddle::experimental::Tensor&>({});
+"""
 #######################
 ## Generator Helpers ##
 #######################
@@ -678,12 +685,15 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
         num_backward_inputs = len(forward_outputs_position_map.keys())
         num_backward_outputs = len(forward_inputs_position_map.keys())
         grad_node_name = GetGradNodeName(forward_api_name)

+        # Helper
+        indent = GetIndent(2)
+
         # NOTE(Aurelius74): DO NOT use make_shared here. Because some Node contains experimental::Scalar
         # which contains "complex128" as data. "complex128" is memory-aligned manually. But make_shared
         # request MEMALIGN for allocation (Maybe).
         # See https://stackoverflow.com/questions/31228656/how-can-shared-ptr-disrupt-alignment
         # and https://github.com/MRtrix3/mrtrix3/issues/957
-        node_construction_str = f" auto grad_node = std::shared_ptr<{grad_node_name}>(new {grad_node_name}({num_backward_inputs}, {num_backward_outputs}));"
+        node_construction_str = f"{indent}auto grad_node = std::shared_ptr<{grad_node_name}>(new {grad_node_name}({num_backward_inputs}, {num_backward_outputs}));"

         # SetAttributes
         set_attributes_list = []
@@ -693,9 +703,9 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
         for name, _, default_val_attr, _ in backward_attrs_list:
             if name in forward_attrs_name_set:
-                set_attributes = f" grad_node->SetAttribute{name}({name});"
+                set_attributes = f"{indent}grad_node->SetAttribute{name}({name});"
             else:
-                set_attributes = f" grad_node->SetAttribute{name}({default_val_attr});"
+                set_attributes = f"{indent}grad_node->SetAttribute{name}({default_val_attr});"
             set_attributes_list.append(set_attributes)
         set_attributes_str = "\n".join(set_attributes_list)
@@ -708,9 +718,9 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
             if is_fwd_input:
                 if is_optional:
-                    set_tensor_wrappers = f" if({name}.get_ptr() != nullptr) grad_node->SetTensorWrapper{name}(*({name}.get_ptr()), true);"
+                    set_tensor_wrappers = f"{indent}if({name}.get_ptr() != nullptr) grad_node->SetTensorWrapper{name}(*({name}.get_ptr()), true);"
                 else:
-                    set_tensor_wrappers = f" grad_node->SetTensorWrapper{name}({name}, true);"
+                    set_tensor_wrappers = f"{indent}grad_node->SetTensorWrapper{name}({name}, true);"
             else:
                 if num_fwd_outputs > 1:
                     # Aligned with forward output position
@@ -719,9 +729,9 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
                     fwd_output_pos = forward_outputs_position_map[name][1]
                 if is_optional:
-                    set_tensor_wrappers = f" if({name}.get_ptr() != nullptr) grad_node->SetTensorWrapper{name}(*({name}.get_ptr()), false);"
+                    set_tensor_wrappers = f"{indent}if({name}.get_ptr() != nullptr) grad_node->SetTensorWrapper{name}(*({name}.get_ptr()), false);"
                 else:
-                    set_tensor_wrappers = f" grad_node->SetTensorWrapper{name}({name}, false);"
+                    set_tensor_wrappers = f"{indent}grad_node->SetTensorWrapper{name}({name}, false);"
             set_tensor_wrappers_list.append(set_tensor_wrappers)
         set_tensor_wrappers_str = "\n".join(set_tensor_wrappers_list)
@@ -732,11 +742,11 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
             input_autograd_meta_name = GetAutoGradMetaName(name)
             is_optional = (name in self.optional_inputs)
             if is_optional:
-                set_grad_out_meta = f" if({name}.get_ptr() != nullptr) grad_node->SetGradOutMeta(*({name}.get_ptr()), {pos});"
-                set_edges = f" if({name}.get_ptr() != nullptr) grad_node->AddEdges({input_autograd_meta_name}, {pos});"
+                set_grad_out_meta = f"{indent}if({name}.get_ptr() != nullptr) grad_node->SetGradOutMeta(*({name}.get_ptr()), {pos});"
+                set_edges = f"{indent}if({name}.get_ptr() != nullptr) grad_node->AddEdges({input_autograd_meta_name}, {pos});"
             else:
-                set_grad_out_meta = f" grad_node->SetGradOutMeta({name}, {pos});"
-                set_edges = f" grad_node->AddEdges({input_autograd_meta_name}, {pos});"
+                set_grad_out_meta = f"{indent}grad_node->SetGradOutMeta({name}, {pos});"
+                set_edges = f"{indent}grad_node->AddEdges({input_autograd_meta_name}, {pos});"
             set_grad_out_meta_list.append(set_grad_out_meta)
             set_edges_list.append(set_edges)
@@ -751,11 +761,11 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
         num_outputs = len(forward_outputs_position_map.keys())
         for name, (_, pos) in forward_outputs_position_map.items():
             output_autograd_meta_name = GetAutoGradMetaName(name)
-            set_out_rank = f" egr::EagerUtils::SetOutRankWithSlot({output_autograd_meta_name}, {pos});"
-            set_history = f" egr::EagerUtils::SetHistory({output_autograd_meta_name}, grad_node);"
-            set_retain_grad = f" egr::EagerUtils::CheckAndRetainGrad({name});"
-            set_grad_in_meta = f" grad_node->SetGradInMeta({name}, {pos});"
+            set_out_rank = f"{indent}egr::EagerUtils::SetOutRankWithSlot({output_autograd_meta_name}, {pos});"
+            set_history = f"{indent}egr::EagerUtils::SetHistory({output_autograd_meta_name}, grad_node);"
+            set_retain_grad = f"{indent}egr::EagerUtils::CheckAndRetainGrad({name});"
+            set_grad_in_meta = f"{indent}grad_node->SetGradInMeta({name}, {pos});"
             set_out_rank_list.append(set_out_rank)
             set_history_list.append(set_history)
             set_grad_in_meta_list.append(set_grad_in_meta)
@@ -767,7 +777,7 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
         set_retain_grad_str = "\n".join(set_retain_grad_list)

         node_event_name = forward_api_name + " node_creation"
-        node_creation_event_str = f"paddle::platform::RecordEvent node_creation_record_event(\"{node_event_name}\", paddle::platform::TracerEventType::Operator, 1);\n"
+        node_creation_event_str = f"{indent}paddle::platform::RecordEvent node_creation_record_event(\"{node_event_name}\", paddle::platform::TracerEventType::Operator, 1);\n"

         self.node_creation_str = FORWARD_BODY_TEMPLATE.format(
             node_creation_event_str, pass_stop_gradient_args_str,
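The pattern behind every edit in this hunk is the same: the hard-coded leading spaces inside each generated C++ statement are replaced by a shared {indent} prefix computed once from GetIndent. A minimal sketch, assuming a two-level indent and an illustrative tensor name x (neither is taken verbatim from the patch):

# Minimal sketch of the indent refactor applied throughout GenerateNodeCreationCodes.
indent = GetIndent(2)
set_grad_out_meta = f"{indent}grad_node->SetGradOutMeta(x, 0);"
set_edges = f"{indent}grad_node->AddEdges(x_autograd_meta, 0);"
# Every generated statement now shares one indentation definition instead of
# embedding its own run of literal spaces.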
@@ -845,6 +855,7 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
         optional_inputs = self.optional_inputs
         intermediate_outputs = self.intermediate_outputs
         inplace_map = self.inplace_map if is_inplaced else {}
+        indent = GetIndent(1)

         # Get Function Args
         num_inputs = len(forward_attrs_list) + len(
@@ -918,7 +929,7 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
         else:
             function_name = GetIntermediateAPIFunctionName(function_name)

-        forward_call_str = f"auto api_result = paddle::experimental::{namespace}{function_name}({inputs_call_args_str});"
+        forward_call_str = f"{indent}auto api_result = paddle::experimental::{namespace}{function_name}({inputs_call_args_str});"

         num_outputs = len(forward_outputs_position_map.keys()) - len(
             intermediate_outputs)
@@ -926,9 +937,9 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
         get_outputs_str = ""
         for name, (rtype, pos) in forward_outputs_position_map.items():
             if num_outputs == 1 and len(intermediate_outputs) == 0:
-                get_outputs_str += f"auto& {name} = api_result;\n"
+                get_outputs_str += f"{indent}auto& {name} = api_result;\n"
             else:
-                get_outputs_str += f"auto& {name} = std::get<{pos}>(api_result);\n"
+                get_outputs_str += f"{indent}auto& {name} = std::get<{pos}>(api_result);\n"

         # Get return type list & outputs
         returns_type_list = ["" for i in range(num_outputs)]
@@ -961,12 +972,12 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
         for name, (ttype, pos) in forward_inputs_position_map.items():
             input_autograd_meta_name = GetAutoGradMetaName(name)
             if IsPlainTensorType(ttype):
-                input_autograd_meta = f" egr::AutogradMeta* {input_autograd_meta_name} = egr::EagerUtils::nullable_autograd_meta({name});"
+                input_autograd_meta = f"{indent}egr::AutogradMeta* {input_autograd_meta_name} = egr::EagerUtils::nullable_autograd_meta({name});"
             else:
                 assert IsVectorTensorType(ttype)
                 input_autograd_meta_vec_name = GetAutoGradMetaVectorName(name)
-                input_autograd_meta = f" std::vector<egr::AutogradMeta*> {input_autograd_meta_vec_name} = egr::EagerUtils::nullable_autograd_meta({name});\n"
-                input_autograd_meta += f" std::vector<egr::AutogradMeta*>* {input_autograd_meta_name} = &{input_autograd_meta_vec_name};"
+                input_autograd_meta = f"{indent}std::vector<egr::AutogradMeta*> {input_autograd_meta_vec_name} = egr::EagerUtils::nullable_autograd_meta({name});\n"
+                input_autograd_meta += f"{indent}std::vector<egr::AutogradMeta*>* {input_autograd_meta_name} = &{input_autograd_meta_vec_name};"

             inputs_autograd_meta_list.append(input_autograd_meta)
             compute_require_grad_args_list.append(input_autograd_meta_name)
@@ -981,19 +992,19 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
             output_autograd_meta_vec_name = GetAutoGradMetaVectorName(name)
             if num_fwd_outputs == 1:
                 if IsPlainTensorType(rtype):
-                    output_autograd_meta = f" egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&{name});"
+                    output_autograd_meta = f"{indent}egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&{name});"
                 else:
                     assert IsVectorTensorType(rtype)
-                    output_autograd_meta = f" std::vector<egr::AutogradMeta*> {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&{name});\n"
-                    output_autograd_meta += f" std::vector<egr::AutogradMeta*>* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};"
+                    output_autograd_meta = f"{indent}std::vector<egr::AutogradMeta*> {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&{name});\n"
+                    output_autograd_meta += f"{indent}std::vector<egr::AutogradMeta*>* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};"
             else:
                 # Tuple api_result
                 if IsPlainTensorType(rtype):
-                    output_autograd_meta = f" egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&{name});"
+                    output_autograd_meta = f"{indent}egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&{name});"
                 else:
                     assert IsVectorTensorType(rtype)
-                    output_autograd_meta = f" std::vector<egr::AutogradMeta*> {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&{name});\n"
-                    output_autograd_meta += f" std::vector<egr::AutogradMeta*>* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};"
+                    output_autograd_meta = f"{indent}std::vector<egr::AutogradMeta*> {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&{name});\n"
+                    output_autograd_meta += f"{indent}std::vector<egr::AutogradMeta*>* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};"

             outputs_autograd_meta_list.append(output_autograd_meta)
         outputs_autograd_meta_str = "\n".join(outputs_autograd_meta_list)
@@ -1012,7 +1023,7 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
         self.GenerateNodeCreationCodes()
         node_creation_str = self.node_creation_str

-        dygraph_event_str = f"paddle::platform::RecordEvent dygraph_entrance_record_event(\"{forward_api_name} dygraph\", paddle::platform::TracerEventType::Operator, 1);"
+        dygraph_event_str = f"{indent}paddle::platform::RecordEvent dygraph_entrance_record_event(\"{forward_api_name} dygraph\", paddle::platform::TracerEventType::Operator, 1);"
         forward_function_name = GetDygraphForwardFunctionName(forward_api_name)

         # Forward amp logic
@@ -1119,6 +1130,34 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
         self.node_definition_str = ""
         self.next_grad_api_contents = next_grad_api_contents

+    def ResetOptionalInputs(self):
+        namespace = self.namespace
+        grad_api_contents = self.grad_api_contents
+
+        base_generator = FunctionGeneratorBase(grad_api_contents, namespace)
+        base_generator.ParseDispensable()
+
+        self.optional_inputs = base_generator.optional_inputs
+
+    def GenerateHigherOrderNodeCreationCode(self):
+        namespace = self.namespace
+        grad_api_contents = self.grad_api_contents
+        next_grad_api_contents = self.next_grad_api_contents
+
+        grad_node_creation_str = ""
+        if next_grad_api_contents:
+            forward_api_contents = grad_api_contents
+            forward_api_contents['api'] = forward_api_contents['backward_api']
+            backward_api_contents = next_grad_api_contents
+
+            next_node_generator = DygraphFunctionGeneratorBase(
+                forward_api_contents, backward_api_contents, namespace)
+            next_node_generator.run()
+            next_node_generator.GenerateNodeCreationCodes()
+            grad_node_creation_str = next_node_generator.node_creation_str
+
+        return grad_node_creation_str
+
     def GenerateNodeDeclaration(self):
         forward_op_name = self.forward_api_name
         backward_forward_inputs_map = self.backward_forward_inputs_map
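GenerateHigherOrderNodeCreationCode is the piece that makes double-grad nodes possible: it relabels the current backward entry as a "forward" API and pairs it with the next-order backward entry, so the ordinary node-creation codegen can run one level deeper. A rough sketch of that pairing with abbreviated, hypothetical yaml contents (the real dicts carry the full args/output/kernel fields):

# Rough sketch of the relabeling trick, not the actual call site. Dict contents are
# abbreviated and hypothetical; only the 'api'/'backward_api' handling mirrors the patch.
grad_api_contents = {"backward_api": "matmul_grad"}              # first-order grad entry
next_grad_api_contents = {"backward_api": "matmul_double_grad"}  # its own backward entry

forward_api_contents = grad_api_contents
forward_api_contents['api'] = forward_api_contents['backward_api']  # treat matmul_grad as a forward op
# DygraphFunctionGeneratorBase(forward_api_contents, next_grad_api_contents, namespace)
# can then emit the node-creation block that wires a double-grad node into the
# generated matmul_grad GradNode whenever create_graph is true.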
@@ -1187,6 +1226,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
         backward_grad_inputs_map = self.backward_grad_inputs_map
         backward_grad_outputs_map = self.backward_grad_outputs_map
         backward_attrs_list = self.backward_attrs_list
+        indent = GetIndent(1)

         # Construct grad_api function args
         # Order: TensorWrappers, GradTensors, Attributes
@@ -1197,8 +1237,8 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):

         # Fill Grad Ins with Zero
         fill_zero_str = ""
-        if forward_api_name in ops_to_fill_zero_for_empty_grads:
-            fill_zero_str = "egr::EagerUtils::FillZeroForEmptyGradInputs(&grads, this->InputMeta());\n"
+        if backward_api_name in ops_to_fill_zero_for_empty_grads:
+            fill_zero_str = f"{indent}egr::EagerUtils::FillZeroForEmptyGradInputs(&grads, this->InputMeta());\n"

         # Grad Ins from TensorWrappers
         for name, (_, is_fwd_input,
@@ -1209,9 +1249,9 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
             is_optional = (name in self.optional_inputs)
             if is_optional:
-                tensor_wrapper_recover_str = f"auto {transformed_tensor_name} = egr::EagerUtils::RecoverOptionalTensorWrapper(&this->{tensor_wrapper_name}, nullptr);"
+                tensor_wrapper_recover_str = f"{indent}auto {transformed_tensor_name} = egr::EagerUtils::RecoverOptionalTensorWrapper(&this->{tensor_wrapper_name}, nullptr);"
             else:
-                tensor_wrapper_recover_str = f"auto {transformed_tensor_name} = egr::EagerUtils::RecoverTensorWrapper(&this->{tensor_wrapper_name}, nullptr);"
+                tensor_wrapper_recover_str = f"{indent}auto {transformed_tensor_name} = egr::EagerUtils::RecoverTensorWrapper(&this->{tensor_wrapper_name}, nullptr);"
             grad_api_args[grad_api_position] = transformed_tensor_name
             get_grad_in_args_list.append(tensor_wrapper_recover_str)
@@ -1221,18 +1261,29 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
             transformed_tensor_name = TransformGradVarNameForDoubleGradGeneration(
                 name)

+            is_optional = (name in self.optional_inputs)
             if IsPlainTensorType(ttype):
-                get_tensor_str = f"auto& {transformed_tensor_name} = hooked_grads[{fwd_position}][0];"
+                get_tensor_str = f"{indent}auto& {transformed_tensor_name} = hooked_grads[{fwd_position}][0];"
+
+                if is_optional:
+                    get_tensor_str += "\n" + CREATE_PLAIN_OPTIONAL_TENSOR_TEMPLATE.format(
+                        transformed_tensor_name, transformed_tensor_name,
+                        transformed_tensor_name, transformed_tensor_name)
+                    grad_api_args[
+                        grad_api_position] = f"{transformed_tensor_name}_optional"
+                else:
+                    grad_api_args[grad_api_position] = transformed_tensor_name
             else:
                 assert IsVectorTensorType(ttype)
-                get_tensor_str = f"auto& {transformed_tensor_name} = hooked_grads[{fwd_position}];"
+                get_tensor_str = f"{indent}auto& {transformed_tensor_name} = hooked_grads[{fwd_position}];"
                 grad_api_args[grad_api_position] = transformed_tensor_name
             get_grad_in_args_list.append(get_tensor_str)

         # Grad Attrs
         for name, _, _, grad_api_position in backward_attrs_list:
             saved_attribute_name = GetSavedName(name)
-            get_attr_str = f"auto& {name} = this->{saved_attribute_name};"
+            get_attr_str = f"{indent}auto& {name} = this->{saved_attribute_name};"
             grad_api_args[grad_api_position] = name
             get_grad_in_args_list.append(get_attr_str)
@@ -1242,7 +1293,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):

         # Grad Function Call String
         grad_api_namespace = f"paddle::experimental::{namespace}"
-        grad_function_call_str = f"auto grad_api_result = {grad_api_namespace}{backward_api_name}({grad_api_args_str});"
+        grad_function_call_str = f"{indent}auto grad_api_result = {grad_api_namespace}{backward_api_name}({grad_api_args_str});"

         # Get Grad Outputs
         get_outputs_str = ""
@@ -1253,9 +1304,13 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
                 name)

             if num_outputs == 1:
-                get_tensor_str = f"auto& {transformed_tensor_name} = grad_api_result;"
+                get_tensor_str = f"{indent}auto& {transformed_tensor_name} = grad_api_result;"
             else:
-                get_tensor_str = f"auto& {transformed_tensor_name} = grad_api_result[{grad_api_position}];"
+                if IsPlainTensorType(ttype):
+                    get_tensor_str = f"{indent}auto& {transformed_tensor_name} = grad_api_result[{grad_api_position}][0];"
+                else:
+                    assert IsVectorTensorType(ttype)
+                    get_tensor_str = f"{indent}auto& {transformed_tensor_name} = grad_api_result[{grad_api_position}];"
             get_outputs_str += get_tensor_str + "\n"

         # Prepare for Node Creation if Necessary
@@ -1274,13 +1329,13 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
             input_autograd_meta_name = GetAutoGradMetaName(
                 transformed_tensor_name)
             if IsPlainTensorType(ttype):
-                input_autograd_meta = f" egr::AutogradMeta* {input_autograd_meta_name} = egr::EagerUtils::nullable_autograd_meta({transformed_tensor_name});"
+                input_autograd_meta = f"{indent}egr::AutogradMeta* {input_autograd_meta_name} = egr::EagerUtils::nullable_autograd_meta({transformed_tensor_name});"
             else:
                 assert IsVectorTensorType(ttype)
                 input_autograd_meta_vec_name = GetAutoGradMetaVectorName(
                     transformed_tensor_name)
-                input_autograd_meta = f" std::vector<egr::AutogradMeta*> {input_autograd_meta_vec_name} = egr::EagerUtils::nullable_autograd_meta({transformed_tensor_name});\n"
-                input_autograd_meta += f" std::vector<egr::AutogradMeta*>* {input_autograd_meta_name} = &{input_autograd_meta_vec_name};"
+                input_autograd_meta = f"{indent}std::vector<egr::AutogradMeta*> {input_autograd_meta_vec_name} = egr::EagerUtils::nullable_autograd_meta({transformed_tensor_name});\n"
+                input_autograd_meta += f"{indent}std::vector<egr::AutogradMeta*>* {input_autograd_meta_name} = &{input_autograd_meta_vec_name};"

             inputs_autograd_meta_list.append(input_autograd_meta)
             compute_require_grad_args_list.append(input_autograd_meta_name)
@@ -1293,13 +1348,13 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
             input_autograd_meta_name = GetAutoGradMetaName(
                 transformed_tensor_name)
             if IsPlainTensorType(ttype):
-                input_autograd_meta = f" egr::AutogradMeta* {input_autograd_meta_name} = egr::EagerUtils::nullable_autograd_meta({transformed_tensor_name});"
+                input_autograd_meta = f"{indent}egr::AutogradMeta* {input_autograd_meta_name} = egr::EagerUtils::nullable_autograd_meta({transformed_tensor_name});"
             else:
                 assert IsVectorTensorType(ttype)
                 input_autograd_meta_vec_name = GetAutoGradMetaVectorName(
                     transformed_tensor_name)
-                input_autograd_meta = f" std::vector<egr::AutogradMeta*> {input_autograd_meta_vec_name} = egr::EagerUtils::nullable_autograd_meta({transformed_tensor_name});\n"
-                input_autograd_meta += f" std::vector<egr::AutogradMeta*>* {input_autograd_meta_name} = &{input_autograd_meta_vec_name};"
+                input_autograd_meta = f"{indent}std::vector<egr::AutogradMeta*> {input_autograd_meta_vec_name} = egr::EagerUtils::nullable_autograd_meta({transformed_tensor_name});\n"
+                input_autograd_meta += f"{indent}std::vector<egr::AutogradMeta*>* {input_autograd_meta_name} = &{input_autograd_meta_vec_name};"

             inputs_autograd_meta_list.append(input_autograd_meta)
             compute_require_grad_args_list.append(input_autograd_meta_name)
@@ -1320,30 +1375,30 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
                 transformed_tensor_name)
             if num_fwd_outputs == 1:
                 if IsPlainTensorType(rtype):
-                    output_autograd_meta = f" egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&{transformed_tensor_name});"
+                    output_autograd_meta = f"{indent}egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&{transformed_tensor_name});"
                 else:
                     assert IsVectorTensorType(rtype)
-                    output_autograd_meta = f" std::vector<egr::AutogradMeta*> {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&{transformed_tensor_name});\n"
-                    output_autograd_meta += f" std::vector<egr::AutogradMeta*>* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};"
+                    output_autograd_meta = f"{indent}std::vector<egr::AutogradMeta*> {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&{transformed_tensor_name});\n"
+                    output_autograd_meta += f"{indent}std::vector<egr::AutogradMeta*>* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};"
             else:
                 # Tuple api_result
                 if IsPlainTensorType(rtype):
-                    output_autograd_meta = f" egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&{transformed_tensor_name});"
+                    output_autograd_meta = f"{indent}egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&{transformed_tensor_name});"
                 else:
                     assert IsVectorTensorType(rtype)
-                    output_autograd_meta = f" std::vector<egr::AutogradMeta*> {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&{transformed_tensor_name});\n"
-                    output_autograd_meta += f" std::vector<egr::AutogradMeta*>* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};"
+                    output_autograd_meta = f"{indent}std::vector<egr::AutogradMeta*> {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&{transformed_tensor_name});\n"
+                    output_autograd_meta += f"{indent}std::vector<egr::AutogradMeta*>* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};"

             outputs_autograd_meta_list.append(output_autograd_meta)
         outputs_autograd_meta_str = "\n".join(outputs_autograd_meta_list)

-        compute_require_grad_str = "bool trace_backward = egr::Controller::Instance().HasGrad() && create_graph;\n"
-        compute_require_grad_str += f"bool require_any_grad = egr::EagerUtils::ComputeRequireGrad({compute_require_grad_args_str});"
+        compute_require_grad_str = f"{indent}bool trace_backward = egr::Controller::Instance().HasGrad() && create_graph;\n"
+        compute_require_grad_str += f"{indent}bool require_any_grad = egr::EagerUtils::ComputeRequireGrad({compute_require_grad_args_str});"

         # Construct grad_api returns
         num_bwd_outputs = len(backward_grad_outputs_map.keys())
         slot_num_bwd_outputs = len(self.forward_inputs_position_map.keys())
-        returns_str = f"std::vector<std::vector<paddle::experimental::Tensor>> returns({slot_num_bwd_outputs});\n"
+        returns_str = f"{indent}std::vector<std::vector<paddle::experimental::Tensor>> returns({slot_num_bwd_outputs});\n"
         for name, (ttype, fwd_position,
                    grad_api_position) in backward_grad_outputs_map.items():
             transformed_tensor_name = TransformGradVarNameForDoubleGradGeneration(
@@ -1353,15 +1408,20 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
             if num_bwd_outputs == 1:
                 # Single tensor output, return as is
                 if IsPlainTensorType(ttype):
-                    returns_str += f"returns[0] = {{ {transformed_tensor_name} }};\n"
+                    returns_str += f"{indent}returns[0] = {{ {transformed_tensor_name} }};\n"
                 else:
                     assert IsVectorTensorType(ttype)
-                    returns_str += f"returns[0] = {transformed_tensor_name};\n"
+                    returns_str += f"{indent}returns[0] = {transformed_tensor_name};\n"
             else:
                 # Rearrange output order accordingly
-                returns_str += f"returns[{fwd_position}] = {transformed_tensor_name};\n"
-        returns_str += f"if(NeedComplexToRealConversion()) HandleComplexGradToRealGrad(&returns);\n"
-        returns_str += f"return returns;\n"
+                if IsPlainTensorType(ttype):
+                    returns_str += f"{indent}returns[{fwd_position}] = {{ {transformed_tensor_name} }};\n"
+                else:
+                    assert IsVectorTensorType(ttype)
+                    returns_str += f"{indent}returns[{fwd_position}] = {transformed_tensor_name};\n"
+
+        returns_str += f"{indent}if(NeedComplexToRealConversion()) HandleComplexGradToRealGrad(&returns);\n"
+        returns_str += f"{indent}return returns;\n"

         grad_node_name = GetGradNodeName(forward_api_name)
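The fix above matters because `returns` is a vector-of-vectors with one slot per forward input: a plain Tensor grad now gets wrapped into a single-element slot at its forward position, while a vector<Tensor> grad fills the slot directly. A minimal, self-contained sketch of the corrected construction (the names and positions are hypothetical, not taken from the patch):

# Minimal sketch of the corrected multi-output return construction; "x_grad", "y_grads"
# and their positions are illustrative.
indent = " "
returns_str = f"{indent}std::vector<std::vector<paddle::experimental::Tensor>> returns(2);\n"
for name, fwd_position, is_plain in [("x_grad", 0, True), ("y_grads", 1, False)]:
    if is_plain:
        # plain Tensor -> wrap into a one-element slot
        returns_str += f"{indent}returns[{fwd_position}] = {{ {name} }};\n"
    else:
        # vector<Tensor> -> assign the whole slot
        returns_str += f"{indent}returns[{fwd_position}] = {name};\n"
returns_str += f"{indent}return returns;\n"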
@@ -1376,24 +1436,15 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
     def run(self):
         super().run()
+        self.ResetOptionalInputs()

         #####################
         ## Code Generation ##
         #####################
         self.GenerateNodeDeclaration()

-        namespace = self.namespace
-        grad_node_creation_str = ""
-        next_grad_api_contents = self.next_grad_api_contents
-        if next_grad_api_contents:
-            forward_api_contents = self.grad_api_contents
-            forward_api_contents['api'] = forward_api_contents['backward_api']
-            backward_api_contents = next_grad_api_contents
-
-            next_node_generator = DygraphFunctionGeneratorBase(
-                forward_api_contents, backward_api_contents, namespace)
-            next_node_generator.run()
-            next_node_generator.GenerateNodeCreationCodes()
-            grad_node_creation_str = next_node_generator.node_creation_str
+        # Higher-order GradNode generation
+        grad_node_creation_str = self.GenerateHigherOrderNodeCreationCode()

         self.GenerateNodeDefinition(grad_node_creation_str)
...
@@ -14,6 +14,8 @@
 #pragma once

+#include <memory>
+
 #include "paddle/fluid/eager/eager_tensor.h"
 #include "paddle/fluid/eager/hooks.h"
 #include "paddle/phi/api/all.h"
...
@@ -23,6 +23,7 @@ from ...tensor.math import multiply
 import warnings
 from ...fluid.layer_helper import LayerHelper
 from ...fluid.framework import convert_np_dtype_to_dtype_
+from ...fluid.framework import _in_legacy_dygraph, in_dygraph_mode
 from ...fluid.data_feeder import check_variable_and_dtype, check_dtype
 import paddle
 from paddle import _C_ops, in_dynamic_mode
@@ -560,9 +561,10 @@ def relu(x, name=None):
             out = F.relu(x) # [0., 0., 1.]
     """

-    if in_dynamic_mode():
+    if in_dygraph_mode():
+        return _C_ops.final_state_relu(x)
+    if _in_legacy_dygraph():
         return _C_ops.relu(x)

     check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], 'relu')
     helper = LayerHelper('relu', **locals())
     out = helper.create_variable_for_type_inference(x.dtype)
...
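The rewritten relu shows the dispatch pattern this PR relies on: the new final-state (eager) binding is tried first, the legacy dygraph op second, and only then does the static-graph LayerHelper path run. A sketch of the same pattern for a hypothetical op (final_state_my_op and my_op are illustrative names, not existing bindings):

# Three-way dispatch pattern used by the rewritten relu, applied to a hypothetical my_op.
def my_op(x, name=None):
    if in_dygraph_mode():
        # new final-state eager mode: call the auto-generated final-state binding
        return _C_ops.final_state_my_op(x)
    if _in_legacy_dygraph():
        # old dygraph mode: call the legacy C++ op binding
        return _C_ops.my_op(x)
    # otherwise fall through to the static-graph LayerHelper construction
    ...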
@@ -377,15 +377,15 @@
     data_type : x

 - backward_api : matmul_double_grad
-  forward : matmul_grad (Tensor x, Tensor y, Tensor out_grad, bool transpose_x, bool transpose_y) -> Tensor(dx), Tensor(dy)
-  args : (Tensor x, Tensor y, Tensor out_grad, Tensor dx_grad, Tensor dy_grad, bool transpose_x, bool transpose_y)
-  output : Tensor(d2x), Tensor(d2y), Tensor(dout_grad)
+  forward : matmul_grad (Tensor x, Tensor y, Tensor grad_out, bool transpose_x=false, bool transpose_y=false) -> Tensor(grad_x), Tensor(grad_y)
+  args : (Tensor x, Tensor y, Tensor grad_out, Tensor grad_x_grad, Tensor grad_y_grad, bool transpose_x=false, bool transpose_y=false)
+  output : Tensor(x_grad), Tensor(y_grad), Tensor(grad_out_grad)
   infer_meta :
     func : GeneralTernaryGradInferMeta
-    param : [x, y, out_grad]
+    param : [x, y, grad_out]
   kernel :
     func : matmul_double_grad
-  optional : dx_grad, dy_grad
+  optional : grad_x_grad, grad_y_grad

 - backward_api : matmul_grad
   forward : matmul (Tensor x, Tensor y, bool transpose_x=false, bool transpose_y=false) -> Tensor(out)
@@ -396,6 +396,7 @@
     param : [x, y]
   kernel :
     func : matmul_grad
+  backward : matmul_double_grad

 - backward_api : matrix_power_grad
   forward : matrix_power (Tensor x, int n) -> Tensor(out)
...
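The new `backward : matmul_double_grad` entry is what feeds `next_grad_api_contents` to the node generator: a backward entry can now name its own backward, and the generator walks that link to emit a higher-order GradNode. A rough, hypothetical sketch of how such a chain could be resolved from backward.yaml (this is not the generator's actual parsing code; the file path is illustrative):

# Hypothetical illustration of following the new `backward` link in backward.yaml.
import yaml

with open("backward.yaml") as f:
    entries = {e["backward_api"]: e for e in yaml.safe_load(f)}

matmul_grad = entries["matmul_grad"]
next_name = matmul_grad.get("backward")          # -> "matmul_double_grad"
next_grad_api_contents = entries.get(next_name)
# next_grad_api_contents is what DygraphNodeGenerator receives, enabling
# GenerateHigherOrderNodeCreationCode to wire the double-grad node.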