Unverified commit abd2df4c, authored by Zhanlue Yang, committed by GitHub

[DoubleGrad PR #3] Supported higher-order GradNode generation (#41051)

* [Refactor] refactored eager_gen.py PR #2

* [DoubleGrad PR #1] Decoupled code generation logic for Dygraph ForwardFunctions and GradNodes

* Fixed minor issue

* Adjusted logic of GenerateNodeCreationCodes and GenerateForwardDefinition

* Fixed issues

* Supported higher-order grad node generation

* [DoubleGrad PR #4] Supported higher-order GradNode generation

* Fixed yaml typo
Parent 489a64ef
@@ -89,6 +89,10 @@ def FindForwardName(string):
return string[:-5]
def IsGradName(string):
return string.endswith("_grad")
def IsPlainTensorType(string):
plain_tensor_types = ['Tensor&', 'Tensor', 'const Tensor&', 'const Tensor']
if string in plain_tensor_types:
@@ -166,6 +170,12 @@ def GetForwardFunctionName(string):
return f"{string}_final_state_dygraph_function"
def TransformGradVarNameForDoubleGradGeneration(string):
if IsGradName(string):
string = "grad_" + string[:-5]
return string
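A minimal usage sketch (illustrative only, not part of the diff), assuming the two helpers above are imported from codegen_utils:
# For names ending in "_grad", the "_grad" suffix is stripped and a "grad_" prefix is added;
# other names are returned unchanged.
assert TransformGradVarNameForDoubleGradGeneration("x_grad") == "grad_x"
assert TransformGradVarNameForDoubleGradGeneration("out") == "out"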
######################
### Yaml Parsers ###
######################
@@ -31,6 +31,7 @@ from codegen_utils import ParseYamlArgs, ParseYamlReturns, ParseYamlForwardFromB
from codegen_utils import ParseYamlForward, ParseYamlBackward
from codegen_utils import FunctionGeneratorBase, YamlGeneratorBase
from codegen_utils import ops_to_fill_zero_for_empty_grads
from codegen_utils import TransformGradVarNameForDoubleGradGeneration
from codegen_utils import AssertMessage
@@ -146,15 +147,38 @@ NODE_DECLARATION_TEMPLATE = \
}};
"""
-FUNCTION_TEMPLATE = \
+GRAD_FUNCTION_TEMPLATE = \
"""
std::vector<std::vector<paddle::experimental::Tensor>> {}::operator()(std::vector<std::vector<paddle::experimental::Tensor>>& grads, bool create_graph) {{
// Fill Zero For GradIn Tensors
{}
// Apply Gradient Hooks
auto hooked_grads = ApplyGradientHooks(grads);
// Collect GradIn Tensors, Attrs and Recovered TensorWrappers
{}
// Call grad_api function
VLOG(3) << \"Final State Running: \" << \"{}\";
-auto grad_api_returns = {}{}({});
+{}
// Get Output
{}
// Get GradIn autograd_meta
{}
// Get GradOut autograd_meta
{}
// Compute Require Grad
{}
// Create Grad Node
{}
// Return
{}
}}
"""
@@ -170,11 +194,14 @@ FORWARD_FUNCTION_TEMPLATE = \
// Get Input AutoGradMeta
{}
// Forward API Call
{}
// Get Outputs
{}
// Get Output AutoGradMeta
{}
bool trace_backward = egr::Controller::Instance().HasGrad();
bool require_any_grad = egr::EagerUtils::ComputeRequireGrad({});
// Check Inplace & Bump Inplace Version
{}
{}
@@ -225,6 +252,7 @@ NODE_CC_FILE_TEMPLATE = \
#include "paddle/phi/api/backward/sparse_bw_api.h"
#include "paddle/fluid/imperative/tracer.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/profiler/event_tracing.h"
#include "paddle/fluid/eager/utils.h" #include "paddle/fluid/eager/utils.h"
#include "paddle/fluid/eager/api/utils/global_utils.h" #include "paddle/fluid/eager/api/utils/global_utils.h"
#include "paddle/fluid/eager/api/generated/eager_generated/backwards/nodes.h" #include "paddle/fluid/eager/api/generated/eager_generated/backwards/nodes.h"
...@@ -689,14 +717,11 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase): ...@@ -689,14 +717,11 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
assert name in forward_outputs_position_map.keys( assert name in forward_outputs_position_map.keys(
), AssertMessage(name, forward_outputs_position_map.keys()) ), AssertMessage(name, forward_outputs_position_map.keys())
fwd_output_pos = forward_outputs_position_map[name][1] fwd_output_pos = forward_outputs_position_map[name][1]
-tw_name = f"std::get<{fwd_output_pos}>(api_result)"
-else:
-tw_name = f"api_result"
if is_optional:
-set_tensor_wrappers = f" if({tw_name}.get_ptr() != nullptr) grad_node->SetTensorWrapper{name}(*({tw_name}.get_ptr()), false);"
+set_tensor_wrappers = f" if({name}.get_ptr() != nullptr) grad_node->SetTensorWrapper{name}(*({name}.get_ptr()), false);"
else:
-set_tensor_wrappers = f" grad_node->SetTensorWrapper{name}({tw_name}, false);"
+set_tensor_wrappers = f" grad_node->SetTensorWrapper{name}({name}, false);"
set_tensor_wrappers_list.append(set_tensor_wrappers)
set_tensor_wrappers_str = "\n".join(set_tensor_wrappers_list)
@@ -729,12 +754,8 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
set_out_rank = f" egr::EagerUtils::SetOutRankWithSlot({output_autograd_meta_name}, {pos});"
set_history = f" egr::EagerUtils::SetHistory({output_autograd_meta_name}, grad_node);"
-if num_outputs == 1:
-set_retain_grad = f" egr::EagerUtils::CheckAndRetainGrad(api_result);"
-set_grad_in_meta = f" grad_node->SetGradInMeta(api_result, {pos});"
-else:
-set_retain_grad = f" egr::EagerUtils::CheckAndRetainGrad(std::get<{pos}>(api_result));"
-set_grad_in_meta = f" grad_node->SetGradInMeta(std::get<{pos}>(api_result), {pos});"
+set_retain_grad = f" egr::EagerUtils::CheckAndRetainGrad({name});"
+set_grad_in_meta = f" grad_node->SetGradInMeta({name}, {pos});"
set_out_rank_list.append(set_out_rank)
set_history_list.append(set_history)
set_grad_in_meta_list.append(set_grad_in_meta)
@@ -898,20 +919,24 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
function_name = GetIntermediateAPIFunctionName(function_name)
forward_call_str = f"auto api_result = paddle::experimental::{namespace}{function_name}({inputs_call_args_str});"
-# Get return type list & outputs
num_outputs = len(forward_outputs_position_map.keys()) - len(
intermediate_outputs)
# Get Outputs
get_outputs_str = ""
for name, (rtype, pos) in forward_outputs_position_map.items():
if num_outputs == 1 and len(intermediate_outputs) == 0:
get_outputs_str += f"auto& {name} = api_result;\n"
else:
get_outputs_str += f"auto& {name} = std::get<{pos}>(api_result);\n"
# Get return type list & outputs
returns_type_list = ["" for i in range(num_outputs)]
returns_list = ["" for i in range(num_outputs)]
for name, (rtype, pos) in forward_outputs_position_map.items():
if name in intermediate_outputs:
continue
-if num_outputs == 1 and len(intermediate_outputs) == 0:
-returns_list[0] = f"api_result"
-else:
-# Tuple api_result
-returns_list[pos] = f"std::get<{pos}>(api_result)"
+returns_list[pos] = f"{name}"
if IsPlainTensorType(rtype):
returns_type_list[pos] = "paddle::experimental::Tensor"
@@ -956,26 +981,24 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
output_autograd_meta_vec_name = GetAutoGradMetaVectorName(name)
if num_fwd_outputs == 1:
if IsPlainTensorType(rtype):
-output_autograd_meta = f" egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&api_result);"
+output_autograd_meta = f" egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&{name});"
else:
assert IsVectorTensorType(rtype)
-output_autograd_meta = f" std::vector<egr::AutogradMeta*> {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&api_result);\n"
+output_autograd_meta = f" std::vector<egr::AutogradMeta*> {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&{name});\n"
output_autograd_meta += f" std::vector<egr::AutogradMeta*>* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};"
else:
# Tuple api_result
if IsPlainTensorType(rtype):
-output_autograd_meta = f" egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&std::get<{pos}>(api_result));"
+output_autograd_meta = f" egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&{name});"
else:
assert IsVectorTensorType(rtype)
-output_autograd_meta = f" std::vector<egr::AutogradMeta*> {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&std::get<{pos}>(api_result));\n"
+output_autograd_meta = f" std::vector<egr::AutogradMeta*> {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&{name});\n"
output_autograd_meta += f" std::vector<egr::AutogradMeta*>* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};"
outputs_autograd_meta_list.append(output_autograd_meta)
-# 3. ComputeRequireGrad & PassStopGradient
outputs_autograd_meta_str = "\n".join(outputs_autograd_meta_list)
-# 4. Check Inplace
+# 3. Check Inplace
check_inplace_str = ""
bump_inplace_version_str = ""
if is_inplaced:
@@ -1015,7 +1038,7 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
self.forward_definition_str += FORWARD_FUNCTION_TEMPLATE.format(
returns_type_str, forward_function_name, inputs_args_definition_str,
dygraph_event_str, amp_logic_str, inputs_autograd_meta_str,
-forward_call_str, outputs_autograd_meta_str,
+forward_call_str, get_outputs_str, outputs_autograd_meta_str,
compute_require_grad_args_str, check_inplace_str,
bump_inplace_version_str, node_creation_str, returns_str)
self.forward_declaration_str += f"{returns_type_str} {forward_function_name}({inputs_args_declaration_str});\n"
@@ -1083,13 +1106,18 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
-def __init__(self, forward_api_contents, grad_api_contents, namespace):
+def __init__(self,
forward_api_contents,
grad_api_contents,
namespace,
next_grad_api_contents=None):
DygraphFunctionGeneratorBase.__init__(self, forward_api_contents,
grad_api_contents, namespace)
# Generated Results
self.node_declaration_str = ""
self.node_definition_str = ""
self.next_grad_api_contents = next_grad_api_contents
def GenerateNodeDeclaration(self):
forward_op_name = self.forward_api_name
@@ -1151,7 +1179,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
logging.info(f"Generated Node Declaration: {self.node_declaration_str}")
-def GenerateNodeDefinition(self):
+def GenerateNodeDefinition(self, grad_node_creation_str):
namespace = self.namespace
forward_api_name = self.forward_api_name
backward_api_name = self.backward_api_name
@@ -1165,62 +1193,183 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
grad_api_args_len = len(backward_forward_inputs_map.keys()) + len(
backward_grad_inputs_map.keys()) + len(backward_attrs_list)
grad_api_args = ["" for i in range(grad_api_args_len)]
get_grad_in_args_list = []
# Fill Grad Ins with Zero
fill_zero_str = ""
if forward_api_name in ops_to_fill_zero_for_empty_grads:
fill_zero_str = "egr::EagerUtils::FillZeroForEmptyGradInputs(&grads, this->InputMeta());\n"
# Grad Ins from TensorWrappers
for name, (_, is_fwd_input,
grad_api_position), in backward_forward_inputs_map.items():
tensor_wrapper_name = GetSavedName(name)
transformed_tensor_name = TransformGradVarNameForDoubleGradGeneration(
name)
is_optional = (name in self.optional_inputs)
if is_optional:
-grad_api_args[
-grad_api_position] = f"egr::EagerUtils::RecoverOptionalTensorWrapper(&this->{tensor_wrapper_name}, nullptr)"
+tensor_wrapper_recover_str = f"auto {transformed_tensor_name} = egr::EagerUtils::RecoverOptionalTensorWrapper(&this->{tensor_wrapper_name}, nullptr);"
else:
-grad_api_args[
-grad_api_position] = f"egr::EagerUtils::RecoverTensorWrapper(&this->{tensor_wrapper_name}, nullptr)"
+tensor_wrapper_recover_str = f"auto {transformed_tensor_name} = egr::EagerUtils::RecoverTensorWrapper(&this->{tensor_wrapper_name}, nullptr);"
+grad_api_args[grad_api_position] = transformed_tensor_name
+get_grad_in_args_list.append(tensor_wrapper_recover_str)
-for _, (ttype, fwd_position,
-grad_api_position) in backward_grad_inputs_map.items():
# Grad Ins from grads
for name, (ttype, fwd_position,
grad_api_position) in backward_grad_inputs_map.items():
transformed_tensor_name = TransformGradVarNameForDoubleGradGeneration(
name)
if IsPlainTensorType(ttype):
-grad_api_args[
-grad_api_position] = f"hooked_grads[{fwd_position}][0]"
+get_tensor_str = f"auto& {transformed_tensor_name} = hooked_grads[{fwd_position}][0];"
else:
assert IsVectorTensorType(ttype)
-grad_api_args[
-grad_api_position] = f"hooked_grads[{fwd_position}]"
+get_tensor_str = f"auto& {transformed_tensor_name} = hooked_grads[{fwd_position}];"
+grad_api_args[grad_api_position] = transformed_tensor_name
get_grad_in_args_list.append(get_tensor_str)
# Grad Attrs
for name, _, _, grad_api_position in backward_attrs_list:
saved_attribute_name = GetSavedName(name)
-grad_api_args[grad_api_position] = f"this->{saved_attribute_name}"
+get_attr_str = f"auto& {name} = this->{saved_attribute_name};"
grad_api_args[grad_api_position] = name
get_grad_in_args_list.append(get_attr_str)
get_grad_in_args_str = "\n".join(get_grad_in_args_list)
grad_api_args_str = ", ".join(grad_api_args)
# Grad Function Call String
grad_api_namespace = f"paddle::experimental::{namespace}"
grad_function_call_str = f"auto grad_api_result = {grad_api_namespace}{backward_api_name}({grad_api_args_str});"
# Get Grad Outputs
get_outputs_str = ""
num_outputs = len(backward_grad_outputs_map.keys())
for name, (ttype, fwd_position,
grad_api_position) in backward_grad_outputs_map.items():
transformed_tensor_name = TransformGradVarNameForDoubleGradGeneration(
name)
if num_outputs == 1:
get_tensor_str = f"auto& {transformed_tensor_name} = grad_api_result;"
else:
get_tensor_str = f"auto& {transformed_tensor_name} = grad_api_result[{fwd_position}];"
get_outputs_str += get_tensor_str + "\n"
# Prepare for Node Creation if Necessary
inputs_autograd_meta_str = ""
outputs_autograd_meta_str = ""
compute_require_grad_str = ""
if len(grad_node_creation_str) > 0:
# 1. Get Input AutoGradMeta
inputs_autograd_meta_list = []
compute_require_grad_args_list = ["trace_backward"]
for name, (ttype, pos,
grad_api_position) in backward_grad_inputs_map.items():
transformed_tensor_name = TransformGradVarNameForDoubleGradGeneration(
name)
input_autograd_meta_name = GetAutoGradMetaName(
transformed_tensor_name)
if IsPlainTensorType(ttype):
input_autograd_meta = f" egr::AutogradMeta* {input_autograd_meta_name} = egr::EagerUtils::nullable_autograd_meta({transformed_tensor_name});"
else:
assert IsVectorTensorType(ttype)
input_autograd_meta_vec_name = GetAutoGradMetaVectorName(
transformed_tensor_name)
input_autograd_meta = f" std::vector<egr::AutogradMeta*> {input_autograd_meta_vec_name} = egr::EagerUtils::nullable_autograd_meta({transformed_tensor_name});\n"
input_autograd_meta += f" std::vector<egr::AutogradMeta*>* {input_autograd_meta_name} = &{input_autograd_meta_vec_name};"
inputs_autograd_meta_list.append(input_autograd_meta)
compute_require_grad_args_list.append(input_autograd_meta_name)
# 2. Get TensorWrapper AutoGradMeta
for name, (ttype, _, pos), in backward_forward_inputs_map.items():
transformed_tensor_name = TransformGradVarNameForDoubleGradGeneration(
name)
input_autograd_meta_name = GetAutoGradMetaName(
transformed_tensor_name)
if IsPlainTensorType(ttype):
input_autograd_meta = f" egr::AutogradMeta* {input_autograd_meta_name} = egr::EagerUtils::nullable_autograd_meta({transformed_tensor_name});"
else:
assert IsVectorTensorType(ttype)
input_autograd_meta_vec_name = GetAutoGradMetaVectorName(
transformed_tensor_name)
input_autograd_meta = f" std::vector<egr::AutogradMeta*> {input_autograd_meta_vec_name} = egr::EagerUtils::nullable_autograd_meta({transformed_tensor_name});\n"
input_autograd_meta += f" std::vector<egr::AutogradMeta*>* {input_autograd_meta_name} = &{input_autograd_meta_vec_name};"
inputs_autograd_meta_list.append(input_autograd_meta)
compute_require_grad_args_list.append(input_autograd_meta_name)
inputs_autograd_meta_str = "\n".join(inputs_autograd_meta_list)
compute_require_grad_args_str = ",".join(
compute_require_grad_args_list)
# 3. Get Output AutoGradMeta
outputs_autograd_meta_list = []
num_fwd_outputs = len(backward_grad_outputs_map.keys())
for name, (rtype, pos, _) in backward_grad_outputs_map.items():
transformed_tensor_name = TransformGradVarNameForDoubleGradGeneration(
name)
output_autograd_meta_name = GetAutoGradMetaName(
transformed_tensor_name)
output_autograd_meta_vec_name = GetAutoGradMetaVectorName(
transformed_tensor_name)
if num_fwd_outputs == 1:
if IsPlainTensorType(rtype):
output_autograd_meta = f" egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&{transformed_tensor_name});"
else:
assert IsVectorTensorType(rtype)
output_autograd_meta = f" std::vector<egr::AutogradMeta*> {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&{transformed_tensor_name});\n"
output_autograd_meta += f" std::vector<egr::AutogradMeta*>* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};"
else:
# Tuple api_result
if IsPlainTensorType(rtype):
output_autograd_meta = f" egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&{transformed_tensor_name});"
else:
assert IsVectorTensorType(rtype)
output_autograd_meta = f" std::vector<egr::AutogradMeta*> {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&{transformed_tensor_name});\n"
output_autograd_meta += f" std::vector<egr::AutogradMeta*>* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};"
outputs_autograd_meta_list.append(output_autograd_meta)
outputs_autograd_meta_str = "\n".join(outputs_autograd_meta_list)
compute_require_grad_str = "bool trace_backward = egr::Controller::Instance().HasGrad() && create_graph;\n"
compute_require_grad_str += f"bool require_any_grad = egr::EagerUtils::ComputeRequireGrad({compute_require_grad_args_str});"
# Construct grad_api returns
num_bwd_outputs = len(backward_grad_outputs_map.keys())
slot_num_bwd_outputs = len(self.forward_inputs_position_map.keys())
returns_str = f"std::vector<std::vector<paddle::experimental::Tensor>> returns({slot_num_bwd_outputs});\n"
-for _, (ttype, fwd_position,
+for name, (ttype, fwd_position,
grad_api_position) in backward_grad_outputs_map.items():
transformed_tensor_name = TransformGradVarNameForDoubleGradGeneration(
name)
# Infer Grad API Return Type
if num_bwd_outputs == 1:
# Single tensor output, return as is
if IsPlainTensorType(ttype):
-returns_str += "returns[0] = { grad_api_returns };\n"
+returns_str += f"returns[0] = {{ {transformed_tensor_name} }};\n"
else:
assert IsVectorTensorType(ttype)
-returns_str += "returns[0] = grad_api_returns;\n"
+returns_str += f"returns[0] = {transformed_tensor_name};\n"
else:
# Rearrange output order accordingly
-returns_str += f"returns[{fwd_position}] = grad_api_returns[{grad_api_position}];\n"
+returns_str += f"returns[{fwd_position}] = {transformed_tensor_name};\n"
returns_str += f"if(NeedComplexToRealConversion()) HandleComplexGradToRealGrad(&returns);\n"
returns_str += f"return returns;\n"
grad_node_name = GetGradNodeName(forward_api_name)
-fill_zero_str = ""
-if forward_api_name in ops_to_fill_zero_for_empty_grads:
-fill_zero_str = "egr::EagerUtils::FillZeroForEmptyGradInputs(&grads, this->InputMeta());\n"
-grad_api_namespace = f"paddle::experimental::{namespace}"
-self.node_definition_str = FUNCTION_TEMPLATE.format(
-grad_node_name, fill_zero_str, grad_node_name, grad_api_namespace,
-backward_api_name, grad_api_args_str, returns_str)
+self.node_definition_str = GRAD_FUNCTION_TEMPLATE.format(
+grad_node_name, fill_zero_str, get_grad_in_args_str, grad_node_name,
+grad_function_call_str, get_outputs_str, inputs_autograd_meta_str,
+outputs_autograd_meta_str, compute_require_grad_str,
+grad_node_creation_str, returns_str)
logging.info(f"Generated Node Definition: {self.node_definition_str}")
@@ -1231,7 +1380,22 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
## Code Generation ##
#####################
self.GenerateNodeDeclaration()
-self.GenerateNodeDefinition()
namespace = self.namespace
grad_node_creation_str = ""
next_grad_api_contents = self.next_grad_api_contents
if next_grad_api_contents:
forward_api_contents = self.grad_api_contents
forward_api_contents['api'] = forward_api_contents['backward_api']
backward_api_contents = next_grad_api_contents
next_node_generator = DygraphFunctionGeneratorBase(
forward_api_contents, backward_api_contents, namespace)
next_node_generator.run()
next_node_generator.GenerateNodeCreationCodes()
grad_node_creation_str = next_node_generator.node_creation_str
self.GenerateNodeDefinition(grad_node_creation_str)
class DygraphYamlGenerator(YamlGeneratorBase):
@@ -1278,18 +1442,34 @@ class DygraphYamlGenerator(YamlGeneratorBase):
forward_api_contents)
if backward_api_contents is None: continue
# Generate Dygraph Forward Function
function_generator = DygraphForwardFunctionGenerator(
forward_api_contents, backward_api_contents, namespace)
function_generator.run()
-node_generator = DygraphNodeGenerator(
-forward_api_contents, backward_api_contents, namespace)
-node_generator.run()
self.forward_definition_str += function_generator.forward_definition_str + "\n"
self.forward_declaration_str += function_generator.forward_declaration_str + "\n"
-self.node_declaration_str += node_generator.node_declaration_str + "\n"
-self.node_definition_str += node_generator.node_definition_str + "\n"
+while True:
next_grad_api_contents = self.GetBackwardAPIContents(
backward_api_contents)
node_generator = DygraphNodeGenerator(
forward_api_contents, backward_api_contents, namespace,
next_grad_api_contents)
node_generator.run()
self.node_declaration_str += node_generator.node_declaration_str + "\n"
self.node_definition_str += node_generator.node_definition_str + "\n"
if next_grad_api_contents is None: break
# Detect if there exists higher-order GradNode
forward_api_contents = backward_api_contents
# Fake forward_api_content
forward_api_contents['api'] = forward_api_contents[
'backward_api']
backward_api_contents = next_grad_api_contents
if len(namespace) > 0:
if namespace.endswith("::"):
@@ -649,6 +649,16 @@
kernel :
func : put_along_axis_grad
- backward_api : relu_double_grad
forward : relu_grad (Tensor out, Tensor grad_out) -> Tensor(grad_x)
args : (Tensor out, Tensor grad_x_grad)
output : Tensor(out_grad), Tensor(grad_out_grad)
infer_meta :
func : GeneralBinaryGradInferMeta
param : [out, out]
kernel :
func : relu_double_grad
- backward_api : relu_grad
forward : relu (Tensor x) -> Tensor(out)
args : (Tensor out, Tensor out_grad)
@@ -658,6 +668,7 @@
param : [out]
kernel :
func : relu_grad
backward: relu_double_grad
- backward_api : reshape_grad
forward : reshape_with_xshape (Tensor x, ScalarArray shape) -> Tensor(out), Tensor(xshape)