Unverified commit 01724b1a, authored by Zhanlue Yang, committed by GitHub

[DoubleGrad #4] Bug Fixes to Double Grad Node Generation (#41121)

* [Refactor] refactored eager_gen.py PR #2

* [DoubleGrad PR #1] Decoupled code generation logics for Dygraph ForwardFunctions and GradNodes

* Fixed minor issue

* Adjusted logics of GenerateNodeCreationCodes and GenerateForwardDefinition

* Fixed issues

* Supported higher-order grad node generation

* [DoubleGrad PR #4] Supported higher-order GradNode generation

* [DoubleGrad #4] Bug Fixes to Double Grad Node Generation

* Fixed yaml typo

* Fixed yaml typo

* fixed minor issues

* Fixed minor issue
Parent 4da4265a
......@@ -21,7 +21,8 @@ import os
########################
### Global Variables ###
########################
ops_to_fill_zero_for_empty_grads = set(["split", "rnn"])
ops_to_fill_zero_for_empty_grads = set(
["split_grad", "rnn_grad", "matmul_double_grad"])
# For API dispatch used at python-level
# { op_name : [arg_name, ...] }
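For reference, these dispatch tables map an op name to an ordered list of argument names. An illustrative entry is sketched below; the table variable name is assumed (the diff only shows the comment), and the argument list follows the matmul signature from the YAML change further down.

    # Illustrative only; the generator populates these tables automatically.
    core_ops_args_info = {}  # { op_name : [arg_name, ...] }
    core_ops_args_info["matmul"] = ["x", "y", "transpose_x", "transpose_y"]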
......@@ -176,6 +177,11 @@ def TransformGradVarNameForDoubleGradGeneration(string):
return string
def GetIndent(num):
tab = " "
return "".join([tab for i in range(num)])
######################
### Yaml Parsers ###
######################
......
......@@ -32,7 +32,7 @@ from codegen_utils import ParseYamlForward, ParseYamlBackward
from codegen_utils import FunctionGeneratorBase, YamlGeneratorBase
from codegen_utils import ops_to_fill_zero_for_empty_grads
from codegen_utils import TransformGradVarNameForDoubleGradGeneration
from codegen_utils import AssertMessage
from codegen_utils import AssertMessage, GetIndent
###########
......@@ -112,80 +112,81 @@ ATTRIBUTE_MEMBER_TEMPLATE = \
NODE_DECLARATION_TEMPLATE = \
"""
class {} : public egr::GradNodeBase {{
public:
{}() : egr::GradNodeBase() {{}}
{}(size_t bwd_in_slot_num, size_t bwd_out_slot_num) :
egr::GradNodeBase(bwd_in_slot_num, bwd_out_slot_num) {{}}
~{}() override = default;
virtual std::vector<std::vector<paddle::experimental::Tensor>> operator()(
std::vector<std::vector<paddle::experimental::Tensor>>& grads, bool create_graph = false) override;
std::string name() override {{ return \" {} \"; }}
void ClearTensorWrappers() override {{
{}
is_tensor_wrappers_cleared = true;
}}
// SetTensorWrapperX, SetTensorWrapperY, ...
class {} : public egr::GradNodeBase {{
public:
{}() : egr::GradNodeBase() {{}}
{}(size_t bwd_in_slot_num, size_t bwd_out_slot_num) :
egr::GradNodeBase(bwd_in_slot_num, bwd_out_slot_num) {{}}
~{}() override = default;
virtual std::vector<std::vector<paddle::experimental::Tensor>> operator()(
std::vector<std::vector<paddle::experimental::Tensor>>& grads, bool create_graph = false) override;
std::string name() override {{ return \" {} \"; }}
void ClearTensorWrappers() override {{
{}
// SetAttributes
{}
bool IsTensorWrappersCleared() override {{
return is_tensor_wrappers_cleared;
}}
private:
// TensorWrappers
{}
bool is_tensor_wrappers_cleared = false;
// Attributes
{}
}};
is_tensor_wrappers_cleared = true;
}}
// SetTensorWrapperX, SetTensorWrapperY, ...
{}
// SetAttributes
{}
bool IsTensorWrappersCleared() override {{
return is_tensor_wrappers_cleared;
}}
private:
// TensorWrappers
{}
bool is_tensor_wrappers_cleared = false;
// Attributes
{}
}};
"""
GRAD_FUNCTION_TEMPLATE = \
"""
std::vector<std::vector<paddle::experimental::Tensor>> {}::operator()(std::vector<std::vector<paddle::experimental::Tensor>>& grads, bool create_graph) {{
// Fill Zero For GradIn Tensors
{}
std::vector<std::vector<paddle::experimental::Tensor>> {}::operator()(std::vector<std::vector<paddle::experimental::Tensor>>& grads, bool create_graph) {{
// Fill Zero For GradIn Tensors
{}
// Apply Gradient Hooks
auto hooked_grads = ApplyGradientHooks(grads);
// Collect GradIn Tensors, Attrs and Recovered TensorWrappers
{}
// Apply Gradient Hooks
auto hooked_grads = ApplyGradientHooks(grads);
// Collect GradIn Tensors, Attrs and Recovered TensorWrappers
{}
// Call grad_api function
VLOG(3) << \"Final State Running: \" << \"{}\";
{}
// Call grad_api function
VLOG(3) << \"Final State Running: \" << \"{}\";
{}
// Get Output
{}
// Get Output
{}
// Get GradIn autograd_meta
{}
// Get GradIn autograd_meta
{}
// Get GradOut autograd_meta
{}
// Compute Require Grad
{}
// Create Grad Node
{}
// Get GradOut autograd_meta
{}
// Compute Require Grad
{}
// Create Grad Node
{}
// Return
{}
}}
// Return
{}
}}
"""
FORWARD_FUNCTION_TEMPLATE = \
"""
{} {}({}) {{
{} {}({}) {{
// Dygraph Record Event
{}
// AMP Logic
......@@ -208,33 +209,33 @@ FORWARD_FUNCTION_TEMPLATE = \
// Node Creation
{}
// Returns
return {};
}}
// Returns
return {};
}}
"""
FORWARD_BODY_TEMPLATE = \
"""
if(require_any_grad) {{
if(require_any_grad) {{
{}
egr::EagerUtils::PassStopGradient({});
// Node Construction
egr::EagerUtils::PassStopGradient({});
// Node Construction
{}
// SetAttributes
// SetAttributes
{}
// SetTensorWrappers
// SetTensorWrappers
{}
// SetGradOutMeta & SetEdges
// SetGradOutMeta & SetEdges
{}
{}
// SetOutRank & SetHistory & SetGradInMeta & RetainGrad
// SetOutRank & SetHistory & SetGradInMeta & RetainGrad
{}
{}
{}
{}
}}
}}
"""
NAMESPACE_WRAPPER_TEMPLATE = \
......@@ -318,9 +319,9 @@ std::unordered_map<std::string, std::vector<std::string>> core_ops_final_state_r
CORE_OPS_DECLARATION_TEMPLATE = \
"""
extern std::unordered_map<std::string, std::vector<std::string>> core_ops_final_state_args_info;
extern std::unordered_map<std::string, std::vector<std::string>> core_ops_final_state_args_type_info;
extern std::unordered_map<std::string, std::vector<std::string>> core_ops_final_state_returns_info;
extern std::unordered_map<std::string, std::vector<std::string>> core_ops_final_state_args_info;
extern std::unordered_map<std::string, std::vector<std::string>> core_ops_final_state_args_type_info;
extern std::unordered_map<std::string, std::vector<std::string>> core_ops_final_state_returns_info;
"""
......@@ -352,6 +353,12 @@ AMP_LOGIC_TEMPLATE = \
}}
"""
CREATE_PLAIN_OPTIONAL_TENSOR_TEMPLATE = \
"""
paddle::optional<const paddle::experimental::Tensor&> {}_optional = paddle::none;
if({}.initialized()) {}_optional = paddle::make_optional<const paddle::experimental::Tensor&>({});
"""
#######################
## Generator Helpers ##
......@@ -678,12 +685,15 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
num_backward_inputs = len(forward_outputs_position_map.keys())
num_backward_outputs = len(forward_inputs_position_map.keys())
grad_node_name = GetGradNodeName(forward_api_name)
# Helper
indent = GetIndent(2)
# NOTE(Aurelius74): DO NOT use make_shared here. Because some Node contains experimental::Scalar
# which contains "complex128" as data. "complex128" is memory-aligned manually. But make_shared
# requests MEMALIGN for allocation (maybe).
# See https://stackoverflow.com/questions/31228656/how-can-shared-ptr-disrupt-alignment
# and https://github.com/MRtrix3/mrtrix3/issues/957
node_construction_str = f" auto grad_node = std::shared_ptr<{grad_node_name}>(new {grad_node_name}({num_backward_inputs}, {num_backward_outputs}));"
node_construction_str = f"{indent}auto grad_node = std::shared_ptr<{grad_node_name}>(new {grad_node_name}({num_backward_inputs}, {num_backward_outputs}));"
# SetAttributes
set_attributes_list = []
......@@ -693,9 +703,9 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
for name, _, default_val_attr, _ in backward_attrs_list:
if name in forward_attrs_name_set:
set_attributes = f" grad_node->SetAttribute{name}({name});"
set_attributes = f"{indent}grad_node->SetAttribute{name}({name});"
else:
set_attributes = f" grad_node->SetAttribute{name}({default_val_attr});"
set_attributes = f"{indent}grad_node->SetAttribute{name}({default_val_attr});"
set_attributes_list.append(set_attributes)
set_attributes_str = "\n".join(set_attributes_list)
......@@ -708,9 +718,9 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
if is_fwd_input:
if is_optional:
set_tensor_wrappers = f" if({name}.get_ptr() != nullptr) grad_node->SetTensorWrapper{name}(*({name}.get_ptr()), true);"
set_tensor_wrappers = f"{indent}if({name}.get_ptr() != nullptr) grad_node->SetTensorWrapper{name}(*({name}.get_ptr()), true);"
else:
set_tensor_wrappers = f" grad_node->SetTensorWrapper{name}({name}, true);"
set_tensor_wrappers = f"{indent}grad_node->SetTensorWrapper{name}({name}, true);"
else:
if num_fwd_outputs > 1:
# Aligned with forward output position
......@@ -719,9 +729,9 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
fwd_output_pos = forward_outputs_position_map[name][1]
if is_optional:
set_tensor_wrappers = f" if({name}.get_ptr() != nullptr) grad_node->SetTensorWrapper{name}(*({name}.get_ptr()), false);"
set_tensor_wrappers = f"{indent}if({name}.get_ptr() != nullptr) grad_node->SetTensorWrapper{name}(*({name}.get_ptr()), false);"
else:
set_tensor_wrappers = f" grad_node->SetTensorWrapper{name}({name}, false);"
set_tensor_wrappers = f"{indent}grad_node->SetTensorWrapper{name}({name}, false);"
set_tensor_wrappers_list.append(set_tensor_wrappers)
set_tensor_wrappers_str = "\n".join(set_tensor_wrappers_list)
......@@ -732,11 +742,11 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
input_autograd_meta_name = GetAutoGradMetaName(name)
is_optional = (name in self.optional_inputs)
if is_optional:
set_grad_out_meta = f" if({name}.get_ptr() != nullptr) grad_node->SetGradOutMeta(*({name}.get_ptr()), {pos});"
set_edges = f" if({name}.get_ptr() != nullptr) grad_node->AddEdges({input_autograd_meta_name}, {pos});"
set_grad_out_meta = f"{indent}if({name}.get_ptr() != nullptr) grad_node->SetGradOutMeta(*({name}.get_ptr()), {pos});"
set_edges = f"{indent}if({name}.get_ptr() != nullptr) grad_node->AddEdges({input_autograd_meta_name}, {pos});"
else:
set_grad_out_meta = f" grad_node->SetGradOutMeta({name}, {pos});"
set_edges = f" grad_node->AddEdges({input_autograd_meta_name}, {pos});"
set_grad_out_meta = f"{indent}grad_node->SetGradOutMeta({name}, {pos});"
set_edges = f"{indent}grad_node->AddEdges({input_autograd_meta_name}, {pos});"
set_grad_out_meta_list.append(set_grad_out_meta)
set_edges_list.append(set_edges)
......@@ -751,11 +761,11 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
num_outputs = len(forward_outputs_position_map.keys())
for name, (_, pos) in forward_outputs_position_map.items():
output_autograd_meta_name = GetAutoGradMetaName(name)
set_out_rank = f" egr::EagerUtils::SetOutRankWithSlot({output_autograd_meta_name}, {pos});"
set_history = f" egr::EagerUtils::SetHistory({output_autograd_meta_name}, grad_node);"
set_out_rank = f"{indent}egr::EagerUtils::SetOutRankWithSlot({output_autograd_meta_name}, {pos});"
set_history = f"{indent}egr::EagerUtils::SetHistory({output_autograd_meta_name}, grad_node);"
set_retain_grad = f" egr::EagerUtils::CheckAndRetainGrad({name});"
set_grad_in_meta = f" grad_node->SetGradInMeta({name}, {pos});"
set_retain_grad = f"{indent}egr::EagerUtils::CheckAndRetainGrad({name});"
set_grad_in_meta = f"{indent}grad_node->SetGradInMeta({name}, {pos});"
set_out_rank_list.append(set_out_rank)
set_history_list.append(set_history)
set_grad_in_meta_list.append(set_grad_in_meta)
......@@ -767,7 +777,7 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
set_retain_grad_str = "\n".join(set_retain_grad_list)
node_event_name = forward_api_name + " node_creation"
node_creation_event_str = f"paddle::platform::RecordEvent node_creation_record_event(\"{node_event_name}\", paddle::platform::TracerEventType::Operator, 1);\n"
node_creation_event_str = f"{indent}paddle::platform::RecordEvent node_creation_record_event(\"{node_event_name}\", paddle::platform::TracerEventType::Operator, 1);\n"
self.node_creation_str = FORWARD_BODY_TEMPLATE.format(
node_creation_event_str, pass_stop_gradient_args_str,
......@@ -845,6 +855,7 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
optional_inputs = self.optional_inputs
intermediate_outputs = self.intermediate_outputs
inplace_map = self.inplace_map if is_inplaced else {}
indent = GetIndent(1)
# Get Function Args
num_inputs = len(forward_attrs_list) + len(
......@@ -918,7 +929,7 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
else:
function_name = GetIntermediateAPIFunctionName(function_name)
forward_call_str = f"auto api_result = paddle::experimental::{namespace}{function_name}({inputs_call_args_str});"
forward_call_str = f"{indent}auto api_result = paddle::experimental::{namespace}{function_name}({inputs_call_args_str});"
num_outputs = len(forward_outputs_position_map.keys()) - len(
intermediate_outputs)
......@@ -926,9 +937,9 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
get_outputs_str = ""
for name, (rtype, pos) in forward_outputs_position_map.items():
if num_outputs == 1 and len(intermediate_outputs) == 0:
get_outputs_str += f"auto& {name} = api_result;\n"
get_outputs_str += f"{indent}auto& {name} = api_result;\n"
else:
get_outputs_str += f"auto& {name} = std::get<{pos}>(api_result);\n"
get_outputs_str += f"{indent}auto& {name} = std::get<{pos}>(api_result);\n"
# Get return type list & outputs
returns_type_list = ["" for i in range(num_outputs)]
......@@ -961,12 +972,12 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
for name, (ttype, pos) in forward_inputs_position_map.items():
input_autograd_meta_name = GetAutoGradMetaName(name)
if IsPlainTensorType(ttype):
input_autograd_meta = f" egr::AutogradMeta* {input_autograd_meta_name} = egr::EagerUtils::nullable_autograd_meta({name});"
input_autograd_meta = f"{indent}egr::AutogradMeta* {input_autograd_meta_name} = egr::EagerUtils::nullable_autograd_meta({name});"
else:
assert IsVectorTensorType(ttype)
input_autograd_meta_vec_name = GetAutoGradMetaVectorName(name)
input_autograd_meta = f" std::vector<egr::AutogradMeta*> {input_autograd_meta_vec_name} = egr::EagerUtils::nullable_autograd_meta({name});\n"
input_autograd_meta += f" std::vector<egr::AutogradMeta*>* {input_autograd_meta_name} = &{input_autograd_meta_vec_name};"
input_autograd_meta = f"{indent}std::vector<egr::AutogradMeta*> {input_autograd_meta_vec_name} = egr::EagerUtils::nullable_autograd_meta({name});\n"
input_autograd_meta += f"{indent}std::vector<egr::AutogradMeta*>* {input_autograd_meta_name} = &{input_autograd_meta_vec_name};"
inputs_autograd_meta_list.append(input_autograd_meta)
compute_require_grad_args_list.append(input_autograd_meta_name)
......@@ -981,19 +992,19 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
output_autograd_meta_vec_name = GetAutoGradMetaVectorName(name)
if num_fwd_outputs == 1:
if IsPlainTensorType(rtype):
output_autograd_meta = f" egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&{name});"
output_autograd_meta = f"{indent}egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&{name});"
else:
assert IsVectorTensorType(rtype)
output_autograd_meta = f" std::vector<egr::AutogradMeta*> {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&{name});\n"
output_autograd_meta += f" std::vector<egr::AutogradMeta*>* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};"
output_autograd_meta = f"{indent}std::vector<egr::AutogradMeta*> {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&{name});\n"
output_autograd_meta += f"{indent}std::vector<egr::AutogradMeta*>* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};"
else:
# Tuple api_result
if IsPlainTensorType(rtype):
output_autograd_meta = f" egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&{name});"
output_autograd_meta = f"{indent}egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&{name});"
else:
assert IsVectorTensorType(rtype)
output_autograd_meta = f" std::vector<egr::AutogradMeta*> {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&{name});\n"
output_autograd_meta += f" std::vector<egr::AutogradMeta*>* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};"
output_autograd_meta = f"{indent}std::vector<egr::AutogradMeta*> {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&{name});\n"
output_autograd_meta += f"{indent}std::vector<egr::AutogradMeta*>* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};"
outputs_autograd_meta_list.append(output_autograd_meta)
outputs_autograd_meta_str = "\n".join(outputs_autograd_meta_list)
......@@ -1012,7 +1023,7 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
self.GenerateNodeCreationCodes()
node_creation_str = self.node_creation_str
dygraph_event_str = f"paddle::platform::RecordEvent dygraph_entrance_record_event(\"{forward_api_name} dygraph\", paddle::platform::TracerEventType::Operator, 1);"
dygraph_event_str = f"{indent}paddle::platform::RecordEvent dygraph_entrance_record_event(\"{forward_api_name} dygraph\", paddle::platform::TracerEventType::Operator, 1);"
forward_function_name = GetDygraphForwardFunctionName(forward_api_name)
# Forward amp logic
......@@ -1119,6 +1130,34 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
self.node_definition_str = ""
self.next_grad_api_contents = next_grad_api_contents
def ResetOptionalInputs(self):
namespace = self.namespace
grad_api_contents = self.grad_api_contents
base_generator = FunctionGeneratorBase(grad_api_contents, namespace)
base_generator.ParseDispensable()
self.optional_inputs = base_generator.optional_inputs
def GenerateHigherOrderNodeCreationCode(self):
namespace = self.namespace
grad_api_contents = self.grad_api_contents
next_grad_api_contents = self.next_grad_api_contents
grad_node_creation_str = ""
if next_grad_api_contents:
forward_api_contents = grad_api_contents
forward_api_contents['api'] = forward_api_contents['backward_api']
backward_api_contents = next_grad_api_contents
next_node_generator = DygraphFunctionGeneratorBase(
forward_api_contents, backward_api_contents, namespace)
next_node_generator.run()
next_node_generator.GenerateNodeCreationCodes()
grad_node_creation_str = next_node_generator.node_creation_str
return grad_node_creation_str
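The key idea of GenerateHigherOrderNodeCreationCode is to reinterpret the current backward API as the "forward" API of the next-order node. A minimal sketch of that data flow is given below, using the matmul entries from the YAML change at the end of this diff; the dict contents are abbreviated and partly illustrative.

    # Sketch only: abbreviated YAML contents for matmul_grad / matmul_double_grad.
    grad_api_contents = {
        "backward_api": "matmul_grad",
        "backward": "matmul_double_grad",   # points at the next-order API
    }
    next_grad_api_contents = {
        "backward_api": "matmul_double_grad",
    }

    forward_api_contents = grad_api_contents
    # matmul_grad now plays the role of a forward op ...
    forward_api_contents["api"] = forward_api_contents["backward_api"]
    # ... and matmul_double_grad is treated as its backward, so the node-creation
    # code for the double-grad node is emitted inside matmul_grad's definition.
    backward_api_contents = next_grad_api_contents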
def GenerateNodeDeclaration(self):
forward_op_name = self.forward_api_name
backward_forward_inputs_map = self.backward_forward_inputs_map
......@@ -1187,6 +1226,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
backward_grad_inputs_map = self.backward_grad_inputs_map
backward_grad_outputs_map = self.backward_grad_outputs_map
backward_attrs_list = self.backward_attrs_list
indent = GetIndent(1)
# Construct grad_api function args
# Order: TensorWrappers, GradTensors, Attributes
......@@ -1197,8 +1237,8 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
# Fill Grad Ins with Zero
fill_zero_str = ""
if forward_api_name in ops_to_fill_zero_for_empty_grads:
fill_zero_str = "egr::EagerUtils::FillZeroForEmptyGradInputs(&grads, this->InputMeta());\n"
if backward_api_name in ops_to_fill_zero_for_empty_grads:
fill_zero_str = f"{indent}egr::EagerUtils::FillZeroForEmptyGradInputs(&grads, this->InputMeta());\n"
# Grad Ins from TensorWrappers
for name, (_, is_fwd_input,
......@@ -1209,9 +1249,9 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
is_optional = (name in self.optional_inputs)
if is_optional:
tensor_wrapper_recover_str = f"auto {transformed_tensor_name} = egr::EagerUtils::RecoverOptionalTensorWrapper(&this->{tensor_wrapper_name}, nullptr);"
tensor_wrapper_recover_str = f"{indent}auto {transformed_tensor_name} = egr::EagerUtils::RecoverOptionalTensorWrapper(&this->{tensor_wrapper_name}, nullptr);"
else:
tensor_wrapper_recover_str = f"auto {transformed_tensor_name} = egr::EagerUtils::RecoverTensorWrapper(&this->{tensor_wrapper_name}, nullptr);"
tensor_wrapper_recover_str = f"{indent}auto {transformed_tensor_name} = egr::EagerUtils::RecoverTensorWrapper(&this->{tensor_wrapper_name}, nullptr);"
grad_api_args[grad_api_position] = transformed_tensor_name
get_grad_in_args_list.append(tensor_wrapper_recover_str)
......@@ -1221,18 +1261,29 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
transformed_tensor_name = TransformGradVarNameForDoubleGradGeneration(
name)
is_optional = (name in self.optional_inputs)
if IsPlainTensorType(ttype):
get_tensor_str = f"auto& {transformed_tensor_name} = hooked_grads[{fwd_position}][0];"
get_tensor_str = f"{indent}auto& {transformed_tensor_name} = hooked_grads[{fwd_position}][0];"
if is_optional:
get_tensor_str += "\n" + CREATE_PLAIN_OPTIONAL_TENSOR_TEMPLATE.format(
transformed_tensor_name, transformed_tensor_name,
transformed_tensor_name, transformed_tensor_name)
grad_api_args[
grad_api_position] = f"{transformed_tensor_name}_optional"
else:
grad_api_args[grad_api_position] = transformed_tensor_name
else:
assert IsVectorTensorType(ttype)
get_tensor_str = f"auto& {transformed_tensor_name} = hooked_grads[{fwd_position}];"
grad_api_args[grad_api_position] = transformed_tensor_name
get_tensor_str = f"{indent}auto& {transformed_tensor_name} = hooked_grads[{fwd_position}];"
grad_api_args[grad_api_position] = transformed_tensor_name
get_grad_in_args_list.append(get_tensor_str)
# Grad Attrs
for name, _, _, grad_api_position in backward_attrs_list:
saved_attribute_name = GetSavedName(name)
get_attr_str = f"auto& {name} = this->{saved_attribute_name};"
get_attr_str = f"{indent}auto& {name} = this->{saved_attribute_name};"
grad_api_args[grad_api_position] = name
get_grad_in_args_list.append(get_attr_str)
......@@ -1242,7 +1293,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
# Grad Function Call String
grad_api_namespace = f"paddle::experimental::{namespace}"
grad_function_call_str = f"auto grad_api_result = {grad_api_namespace}{backward_api_name}({grad_api_args_str});"
grad_function_call_str = f"{indent}auto grad_api_result = {grad_api_namespace}{backward_api_name}({grad_api_args_str});"
# Get Grad Outputs
get_outputs_str = ""
......@@ -1253,9 +1304,13 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
name)
if num_outputs == 1:
get_tensor_str = f"auto& {transformed_tensor_name} = grad_api_result;"
get_tensor_str = f"{indent}auto& {transformed_tensor_name} = grad_api_result;"
else:
get_tensor_str = f"auto& {transformed_tensor_name} = grad_api_result[{grad_api_position}];"
if IsPlainTensorType(ttype):
get_tensor_str = f"{indent}auto& {transformed_tensor_name} = grad_api_result[{grad_api_position}][0];"
else:
assert IsVectorTensorType(ttype)
get_tensor_str = f"{indent}auto& {transformed_tensor_name} = grad_api_result[{grad_api_position}];"
get_outputs_str += get_tensor_str + "\n"
# Prepare for Node Creation if Necessary
......@@ -1274,13 +1329,13 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
input_autograd_meta_name = GetAutoGradMetaName(
transformed_tensor_name)
if IsPlainTensorType(ttype):
input_autograd_meta = f" egr::AutogradMeta* {input_autograd_meta_name} = egr::EagerUtils::nullable_autograd_meta({transformed_tensor_name});"
input_autograd_meta = f"{indent}egr::AutogradMeta* {input_autograd_meta_name} = egr::EagerUtils::nullable_autograd_meta({transformed_tensor_name});"
else:
assert IsVectorTensorType(ttype)
input_autograd_meta_vec_name = GetAutoGradMetaVectorName(
transformed_tensor_name)
input_autograd_meta = f" std::vector<egr::AutogradMeta*> {input_autograd_meta_vec_name} = egr::EagerUtils::nullable_autograd_meta({transformed_tensor_name});\n"
input_autograd_meta += f" std::vector<egr::AutogradMeta*>* {input_autograd_meta_name} = &{input_autograd_meta_vec_name};"
input_autograd_meta = f"{indent}std::vector<egr::AutogradMeta*> {input_autograd_meta_vec_name} = egr::EagerUtils::nullable_autograd_meta({transformed_tensor_name});\n"
input_autograd_meta += f"{indent}std::vector<egr::AutogradMeta*>* {input_autograd_meta_name} = &{input_autograd_meta_vec_name};"
inputs_autograd_meta_list.append(input_autograd_meta)
compute_require_grad_args_list.append(input_autograd_meta_name)
......@@ -1293,13 +1348,13 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
input_autograd_meta_name = GetAutoGradMetaName(
transformed_tensor_name)
if IsPlainTensorType(ttype):
input_autograd_meta = f" egr::AutogradMeta* {input_autograd_meta_name} = egr::EagerUtils::nullable_autograd_meta({transformed_tensor_name});"
input_autograd_meta = f"{indent}egr::AutogradMeta* {input_autograd_meta_name} = egr::EagerUtils::nullable_autograd_meta({transformed_tensor_name});"
else:
assert IsVectorTensorType(ttype)
input_autograd_meta_vec_name = GetAutoGradMetaVectorName(
transformed_tensor_name)
input_autograd_meta = f" std::vector<egr::AutogradMeta*> {input_autograd_meta_vec_name} = egr::EagerUtils::nullable_autograd_meta({transformed_tensor_name});\n"
input_autograd_meta += f" std::vector<egr::AutogradMeta*>* {input_autograd_meta_name} = &{input_autograd_meta_vec_name};"
input_autograd_meta = f"{indent}std::vector<egr::AutogradMeta*> {input_autograd_meta_vec_name} = egr::EagerUtils::nullable_autograd_meta({transformed_tensor_name});\n"
input_autograd_meta += f"{indent}std::vector<egr::AutogradMeta*>* {input_autograd_meta_name} = &{input_autograd_meta_vec_name};"
inputs_autograd_meta_list.append(input_autograd_meta)
compute_require_grad_args_list.append(input_autograd_meta_name)
......@@ -1320,30 +1375,30 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
transformed_tensor_name)
if num_fwd_outputs == 1:
if IsPlainTensorType(rtype):
output_autograd_meta = f" egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&{transformed_tensor_name});"
output_autograd_meta = f"{indent}egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&{transformed_tensor_name});"
else:
assert IsVectorTensorType(rtype)
output_autograd_meta = f" std::vector<egr::AutogradMeta*> {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&{transformed_tensor_name});\n"
output_autograd_meta += f" std::vector<egr::AutogradMeta*>* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};"
output_autograd_meta = f"{indent}std::vector<egr::AutogradMeta*> {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&{transformed_tensor_name});\n"
output_autograd_meta += f"{indent}std::vector<egr::AutogradMeta*>* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};"
else:
# Tuple api_result
if IsPlainTensorType(rtype):
output_autograd_meta = f" egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&{transformed_tensor_name});"
output_autograd_meta = f"{indent}egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&{transformed_tensor_name});"
else:
assert IsVectorTensorType(rtype)
output_autograd_meta = f" std::vector<egr::AutogradMeta*> {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&{transformed_tensor_name});\n"
output_autograd_meta += f" std::vector<egr::AutogradMeta*>* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};"
output_autograd_meta = f"{indent}std::vector<egr::AutogradMeta*> {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&{transformed_tensor_name});\n"
output_autograd_meta += f"{indent}std::vector<egr::AutogradMeta*>* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};"
outputs_autograd_meta_list.append(output_autograd_meta)
outputs_autograd_meta_str = "\n".join(outputs_autograd_meta_list)
compute_require_grad_str = "bool trace_backward = egr::Controller::Instance().HasGrad() && create_graph;\n"
compute_require_grad_str += f"bool require_any_grad = egr::EagerUtils::ComputeRequireGrad({compute_require_grad_args_str});"
compute_require_grad_str = f"{indent}bool trace_backward = egr::Controller::Instance().HasGrad() && create_graph;\n"
compute_require_grad_str += f"{indent}bool require_any_grad = egr::EagerUtils::ComputeRequireGrad({compute_require_grad_args_str});"
# Construct grad_api returns
num_bwd_outputs = len(backward_grad_outputs_map.keys())
slot_num_bwd_outputs = len(self.forward_inputs_position_map.keys())
returns_str = f"std::vector<std::vector<paddle::experimental::Tensor>> returns({slot_num_bwd_outputs});\n"
returns_str = f"{indent}std::vector<std::vector<paddle::experimental::Tensor>> returns({slot_num_bwd_outputs});\n"
for name, (ttype, fwd_position,
grad_api_position) in backward_grad_outputs_map.items():
transformed_tensor_name = TransformGradVarNameForDoubleGradGeneration(
......@@ -1353,15 +1408,20 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
if num_bwd_outputs == 1:
# Single tensor output, return as is
if IsPlainTensorType(ttype):
returns_str += f"returns[0] = {{ {transformed_tensor_name} }};\n"
returns_str += f"{indent}returns[0] = {{ {transformed_tensor_name} }};\n"
else:
assert IsVectorTensorType(ttype)
returns_str += f"returns[0] = {transformed_tensor_name};\n"
returns_str += f"{indent}returns[0] = {transformed_tensor_name};\n"
else:
# Rearrange output order accordingly
returns_str += f"returns[{fwd_position}] = {transformed_tensor_name};\n"
returns_str += f"if(NeedComplexToRealConversion()) HandleComplexGradToRealGrad(&returns);\n"
returns_str += f"return returns;\n"
if IsPlainTensorType(ttype):
returns_str += f"{indent}returns[{fwd_position}] = {{ {transformed_tensor_name} }};\n"
else:
assert IsVectorTensorType(ttype)
returns_str += f"{indent}returns[{fwd_position}] = {transformed_tensor_name};\n"
returns_str += f"{indent}if(NeedComplexToRealConversion()) HandleComplexGradToRealGrad(&returns);\n"
returns_str += f"{indent}return returns;\n"
grad_node_name = GetGradNodeName(forward_api_name)
......@@ -1376,24 +1436,15 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
def run(self):
super().run()
self.ResetOptionalInputs()
#####################
## Code Generation ##
#####################
self.GenerateNodeDeclaration()
namespace = self.namespace
grad_node_creation_str = ""
next_grad_api_contents = self.next_grad_api_contents
if next_grad_api_contents:
forward_api_contents = self.grad_api_contents
forward_api_contents['api'] = forward_api_contents['backward_api']
backward_api_contents = next_grad_api_contents
next_node_generator = DygraphFunctionGeneratorBase(
forward_api_contents, backward_api_contents, namespace)
next_node_generator.run()
next_node_generator.GenerateNodeCreationCodes()
grad_node_creation_str = next_node_generator.node_creation_str
# Higher-order GradNode generation
grad_node_creation_str = self.GenerateHigherOrderNodeCreationCode()
self.GenerateNodeDefinition(grad_node_creation_str)
......
......@@ -14,6 +14,8 @@
#pragma once
#include <memory>
#include "paddle/fluid/eager/eager_tensor.h"
#include "paddle/fluid/eager/hooks.h"
#include "paddle/phi/api/all.h"
......
......@@ -23,6 +23,7 @@ from ...tensor.math import multiply
import warnings
from ...fluid.layer_helper import LayerHelper
from ...fluid.framework import convert_np_dtype_to_dtype_
from ...fluid.framework import _in_legacy_dygraph, in_dygraph_mode
from ...fluid.data_feeder import check_variable_and_dtype, check_dtype
import paddle
from paddle import _C_ops, in_dynamic_mode
......@@ -560,9 +561,10 @@ def relu(x, name=None):
out = F.relu(x) # [0., 0., 1.]
"""
if in_dynamic_mode():
if in_dygraph_mode():
return _C_ops.final_state_relu(x)
if _in_legacy_dygraph():
return _C_ops.relu(x)
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], 'relu')
helper = LayerHelper('relu', **locals())
out = helper.create_variable_for_type_inference(x.dtype)
......
......@@ -377,15 +377,15 @@
data_type : x
- backward_api : matmul_double_grad
forward : matmul_grad (Tensor x, Tensor y, Tensor out_grad, bool transpose_x, bool transpose_y) -> Tensor(dx), Tensor(dy)
args : (Tensor x, Tensor y, Tensor out_grad, Tensor dx_grad, Tensor dy_grad, bool transpose_x, bool transpose_y)
output : Tensor(d2x), Tensor(d2y), Tensor(dout_grad)
forward : matmul_grad (Tensor x, Tensor y, Tensor grad_out, bool transpose_x=false, bool transpose_y=false) -> Tensor(grad_x), Tensor(grad_y)
args : (Tensor x, Tensor y, Tensor grad_out, Tensor grad_x_grad, Tensor grad_y_grad, bool transpose_x=false, bool transpose_y=false)
output : Tensor(x_grad), Tensor(y_grad), Tensor(grad_out_grad)
infer_meta :
func : GeneralTernaryGradInferMeta
param : [x, y, out_grad]
param : [x, y, grad_out]
kernel :
func : matmul_double_grad
optional : dx_grad, dy_grad
optional : grad_x_grad, grad_y_grad
- backward_api : matmul_grad
forward : matmul (Tensor x, Tensor y, bool transpose_x=false, bool transpose_y=false) -> Tensor(out)
......@@ -396,6 +396,7 @@
param : [x, y]
kernel :
func : matmul_grad
backward : matmul_double_grad
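With backward : matmul_double_grad now registered under matmul_grad, second-order differentiation through matmul becomes available in dygraph. A minimal usage sketch follows; it is not part of this diff and assumes the eager final-state mode is enabled.

    import paddle

    x = paddle.randn([2, 3])
    y = paddle.randn([3, 4])
    x.stop_gradient = False
    y.stop_gradient = False

    out = paddle.matmul(x, y)
    # create_graph=True records grad nodes for the backward computation itself,
    # which is where the double-grad node gets created.
    dx = paddle.grad(outputs=[out], inputs=[x], create_graph=True)[0]
    # Differentiating the first-order grad again goes through matmul_double_grad.
    ddy = paddle.grad(outputs=[dx.sum()], inputs=[y])[0]
    print(ddy.shape)  # [3, 4], same as y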
- backward_api : matrix_power_grad
forward : matrix_power (Tensor x, int n) -> Tensor(out)
......