未验证 提交 86554d91 编写于 作者: Z Zhanlue Yang 提交者: GitHub

[DoubleGrad PR #1] Decoupled code generation logics for Dygraph...

[DoubleGrad PR #1] Decoupled code generation logics for Dygraph ForwardFunctions and GradNodes (#40937)

* [Refactor] refactored eager_gen.py PR #2

* [DoubleGrad PR #1] Decoupled code generation logics for Dygraph ForwardFunctions and GradNodes

* Fixed minor issue
上级 0c024cb9
......@@ -172,7 +172,7 @@ FORWARD_FUNCTION_TEMPLATE = \
"""
NODE_CREATION_TEMPLATE = \
FORWARD_BODY_TEMPLATE = \
"""
// Get AutoGradMeta
{}
......@@ -305,7 +305,6 @@ BUMP_INPLACE_VERSION_TEMPLATE = \
VLOG(3) << \"Tensor(\" << {}.name() << \") uses Inplace Strategy.\";\n
"""
AMP_LOGIC_TEMPLATE = \
"""
if (egr::Controller::Instance().GetAMPLevel() != paddle::imperative::AmpLevel::O0) {{
......@@ -363,7 +362,7 @@ def GenerateCoreOpInfoDefinition():
#####################
## Generator Class ##
#####################
class DygraphSingleFunctionGenerator(FunctionGeneratorBase):
class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
def __init__(self, forward_api_contents, grad_api_contents, namespace):
self.forward_api_contents = forward_api_contents
# Members from Parent:
......@@ -409,12 +408,6 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase):
self.backward_grad_outputs_map = {
} #{ "name" : [type, fwd_position, orig_position] ...}
# Generated Results
self.forward_definition_str = ""
self.forward_declaration_str = ""
self.node_declaration_str = ""
self.node_definition_str = ""
def DygraphYamlValidationCheck(self):
forward_api_contents = self.forward_api_contents
grad_api_contents = self.grad_api_contents
......@@ -632,139 +625,237 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase):
f"Generated Backward Grad Output Map: {self.backward_grad_outputs_map}"
)
def GenerateNodeDeclaration(self):
forward_op_name = self.forward_api_name
def GenerateNodeCreationCodes(self, forward_call_str, is_inplaced):
forward_api_name = self.forward_api_name
forward_inputs_position_map = self.forward_inputs_position_map
forward_outputs_position_map = self.forward_outputs_position_map
forward_attrs_list = self.forward_attrs_list
backward_forward_inputs_map = self.backward_forward_inputs_map
backward_grad_inputs_map = self.backward_grad_inputs_map
backward_grad_outputs_map = self.backward_grad_outputs_map
backward_attrs_list = self.backward_attrs_list
no_need_buffers = self.no_need_buffers
optional_inputs = self.optional_inputs
inplace_map = self.inplace_map if is_inplaced else {}
# SetTensorWrapper Methods & TensorWrapper Members
set_tensor_wrapper_methods_str = ""
tensor_wrapper_members_str = ""
clear_tensor_wrapper_str = ""
for tname, (ttype, is_fwd_input,
_) in backward_forward_inputs_map.items():
no_need_buffer = "true" if tname in no_need_buffers else "false"
tensor_wrapper_name = GetSavedName(tname)
# Get Input AutoGradMeta
inputs_autograd_meta_list = []
compute_require_grad_args_list = ["trace_backward"]
for name, (ttype, pos) in forward_inputs_position_map.items():
input_autograd_meta_name = GetAutoGradMetaName(name)
if IsPlainTensorType(ttype):
set_tensor_wrapper_methods_str += SET_PLAIN_TENSOR_WRAPPER_TEMPLATE.format(
tname, tname, tensor_wrapper_name, tname, no_need_buffer)
tensor_wrapper_members_str += PLAIN_TENSOR_MEMBER_TEMPLATE.format(
tensor_wrapper_name)
input_autograd_meta = f" egr::AutogradMeta* {input_autograd_meta_name} = egr::EagerUtils::nullable_autograd_meta({name});"
else:
assert IsVectorTensorType(ttype)
input_autograd_meta_vec_name = GetAutoGradMetaVectorName(name)
input_autograd_meta = f" std::vector<egr::AutogradMeta*> {input_autograd_meta_vec_name} = egr::EagerUtils::nullable_autograd_meta({name});\n"
input_autograd_meta += f" std::vector<egr::AutogradMeta*>* {input_autograd_meta_name} = &{input_autograd_meta_vec_name};"
clear_tensor_wrapper_str += CLEAR_TENSOR_WRAPPER_TEMPLATE.format(
tensor_wrapper_name)
inputs_autograd_meta_list.append(input_autograd_meta)
compute_require_grad_args_list.append(input_autograd_meta_name)
inputs_autograd_meta_str = "\n".join(inputs_autograd_meta_list)
compute_require_grad_args_str = ",".join(compute_require_grad_args_list)
# Get Output AutoGradMeta
outputs_autograd_meta_list = []
pass_stop_gradient_args_list = ["false"]
num_fwd_outputs = len(forward_outputs_position_map.keys())
for name, (rtype, pos) in forward_outputs_position_map.items():
output_autograd_meta_name = GetAutoGradMetaName(name)
output_autograd_meta_vec_name = GetAutoGradMetaVectorName(name)
if num_fwd_outputs == 1:
if IsPlainTensorType(rtype):
output_autograd_meta = f" egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&api_result);"
else:
assert IsVectorTensorType(rtype)
output_autograd_meta = f" std::vector<egr::AutogradMeta*> {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&api_result);\n"
output_autograd_meta += f" std::vector<egr::AutogradMeta*>* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};"
else:
assert IsVectorTensorType(ttype)
set_tensor_wrapper_methods_str += SET_VECTOR_TENSOR_WRAPPER_TEMPLATE.format(
tname, tname, tname, tensor_wrapper_name, no_need_buffer)
# Tuple api_result
if IsPlainTensorType(rtype):
output_autograd_meta = f" egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&std::get<{pos}>(api_result));"
else:
assert IsVectorTensorType(rtype)
output_autograd_meta = f" std::vector<egr::AutogradMeta*> {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&std::get<{pos}>(api_result));\n"
output_autograd_meta += f" std::vector<egr::AutogradMeta*>* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};"
tensor_wrapper_members_str += VECTOR_TENSOR_MEMBER_TEMPLATE.format(
tensor_wrapper_name)
outputs_autograd_meta_list.append(output_autograd_meta)
pass_stop_gradient_args_list.append(output_autograd_meta_name)
clear_tensor_wrapper_str += CLEAR_VECTOR_TENSOR_WRAPPERS_TEMPLATE.format(
tensor_wrapper_name)
# ComputeRequireGrad & PassStopGradient
outputs_autograd_meta_str = "\n".join(outputs_autograd_meta_list)
pass_stop_gradient_args_str = ",".join(pass_stop_gradient_args_list)
# SetAttributes & Attribute Members
set_attribute_methods_str = ""
attribute_members_str = ""
for aname, atype, default_val, _ in backward_attrs_list:
saved_attr_name = GetSavedName(aname)
set_attribute_methods_str += SET_ATTR_METHOD_TEMPLATE.format(
aname, GetConstReference(atype), aname, saved_attr_name, aname)
# Check Inplace
check_inplace_str = ""
bump_inplace_version_str = ""
if is_inplaced:
for inplace_name in inplace_map.keys():
inplace_autograd_meta_name = GetAutoGradMetaName(inplace_name)
check_inplace_str += CHECK_INPLACE_TEMPLATE.format(
inplace_name, inplace_autograd_meta_name)
bump_inplace_version_str += BUMP_INPLACE_VERSION_TEMPLATE.format(
inplace_name, inplace_name)
if default_val:
attribute_members_str += ATTRIBUTE_MEMBER_WITH_DEFAULT_TEMPLATE.format(
RemoveConstAndReference(atype), saved_attr_name,
default_val)
# Node Construction
num_backward_inputs = len(forward_outputs_position_map.keys())
num_backward_outputs = len(forward_inputs_position_map.keys())
grad_node_name = GetGradNodeName(forward_api_name)
node_construction_str = f" auto grad_node = std::make_shared<{grad_node_name}>({num_backward_inputs}, {num_backward_outputs});"
# SetAttributes
set_attributes_list = []
forward_attrs_name_set = set()
for name, _, _, _ in forward_attrs_list:
forward_attrs_name_set.add(name)
for name, _, default_val_attr, _ in backward_attrs_list:
if name in forward_attrs_name_set:
set_attributes = f" grad_node->SetAttribute{name}({name});"
else:
attribute_members_str += ATTRIBUTE_MEMBER_TEMPLATE.format(
RemoveConstAndReference(atype), saved_attr_name)
set_attributes = f" grad_node->SetAttribute{name}({default_val_attr});"
set_attributes_list.append(set_attributes)
set_attributes_str = "\n".join(set_attributes_list)
grad_node_name = GetGradNodeName(forward_op_name)
self.node_declaration_str = NODE_DECLARATION_TEMPLATE.format(
grad_node_name, grad_node_name, grad_node_name, grad_node_name,
grad_node_name, clear_tensor_wrapper_str,
set_tensor_wrapper_methods_str, set_attribute_methods_str,
tensor_wrapper_members_str, attribute_members_str)
# SetTensorWrappers
set_tensor_wrappers_list = []
for name, (atype, is_fwd_input,
pos) in backward_forward_inputs_map.items():
is_optional = (name in optional_inputs)
logging.info(f"Generated Node Declaration: {self.node_declaration_str}")
if is_fwd_input:
if is_optional:
set_tensor_wrappers = f" if({name}.get_ptr() != nullptr) grad_node->SetTensorWrapper{name}(*({name}.get_ptr()), true);"
else:
set_tensor_wrappers = f" grad_node->SetTensorWrapper{name}({name}, true);"
else:
if num_fwd_outputs > 1:
# Aligned with forward output position
assert name in forward_outputs_position_map.keys(
), AssertMessage(name, forward_outputs_position_map.keys())
fwd_output_pos = forward_outputs_position_map[name][1]
tw_name = f"std::get<{fwd_output_pos}>(api_result)"
else:
tw_name = f"api_result"
def GenerateNodeDefinition(self):
namespace = self.namespace
forward_api_name = self.forward_api_name
backward_api_name = self.backward_api_name
backward_forward_inputs_map = self.backward_forward_inputs_map
backward_grad_inputs_map = self.backward_grad_inputs_map
backward_grad_outputs_map = self.backward_grad_outputs_map
backward_attrs_list = self.backward_attrs_list
if is_optional:
set_tensor_wrappers = f" if({tw_name}.get_ptr() != nullptr) grad_node->SetTensorWrapper{name}(*({tw_name}.get_ptr()), false);"
else:
set_tensor_wrappers = f" grad_node->SetTensorWrapper{name}({tw_name}, false);"
set_tensor_wrappers_list.append(set_tensor_wrappers)
set_tensor_wrappers_str = "\n".join(set_tensor_wrappers_list)
# Construct grad_api function args
# Order: TensorWrappers, GradTensors, Attributes
grad_api_args_len = len(backward_forward_inputs_map.keys()) + len(
backward_grad_inputs_map.keys()) + len(backward_attrs_list)
grad_api_args = ["" for i in range(grad_api_args_len)]
for name, (_, is_fwd_input,
grad_api_position), in backward_forward_inputs_map.items():
tensor_wrapper_name = GetSavedName(name)
# SetGradOutMeta & SetEdges
set_grad_out_meta_list = []
set_edges_list = []
for name, (_, pos) in forward_inputs_position_map.items():
input_autograd_meta_name = GetAutoGradMetaName(name)
is_optional = (name in self.optional_inputs)
if is_optional:
grad_api_args[
grad_api_position] = f"egr::EagerUtils::RecoverOptionalTensorWrapper(&this->{tensor_wrapper_name}, nullptr)"
set_grad_out_meta = f" if({name}.get_ptr() != nullptr) grad_node->SetGradOutMeta(*({name}.get_ptr()), {pos});"
set_edges = f" if({name}.get_ptr() != nullptr) grad_node->AddEdges({input_autograd_meta_name}, {pos});"
else:
grad_api_args[
grad_api_position] = f"egr::EagerUtils::RecoverTensorWrapper(&this->{tensor_wrapper_name}, nullptr)"
set_grad_out_meta = f" grad_node->SetGradOutMeta({name}, {pos});"
set_edges = f" grad_node->AddEdges({input_autograd_meta_name}, {pos});"
for _, (ttype, fwd_position,
grad_api_position) in backward_grad_inputs_map.items():
if IsPlainTensorType(ttype):
grad_api_args[
grad_api_position] = f"hooked_grads[{fwd_position}][0]"
else:
assert IsVectorTensorType(ttype)
grad_api_args[
grad_api_position] = f"hooked_grads[{fwd_position}]"
set_grad_out_meta_list.append(set_grad_out_meta)
set_edges_list.append(set_edges)
set_grad_out_meta_str = "\n".join(set_grad_out_meta_list)
set_edges_str = "\n".join(set_edges_list)
for name, _, _, grad_api_position in backward_attrs_list:
saved_attribute_name = GetSavedName(name)
grad_api_args[grad_api_position] = f"this->{saved_attribute_name}"
grad_api_args_str = ", ".join(grad_api_args)
# SetOutRank & SetHistory & SetGradInMeta
set_out_rank_list = []
set_history_list = []
set_grad_in_meta_list = []
set_retain_grad_list = []
num_outputs = len(forward_outputs_position_map.keys())
for name, (_, pos) in forward_outputs_position_map.items():
output_autograd_meta_name = GetAutoGradMetaName(name)
set_out_rank = f" egr::EagerUtils::SetOutRankWithSlot({output_autograd_meta_name}, {pos});"
set_history = f" egr::EagerUtils::SetHistory({output_autograd_meta_name}, grad_node);"
# Construct grad_api returns
num_bwd_outputs = len(backward_grad_outputs_map.keys())
slot_num_bwd_outputs = len(self.forward_inputs_position_map.keys())
returns_str = f"std::vector<std::vector<paddle::experimental::Tensor>> returns({slot_num_bwd_outputs});\n"
for _, (ttype, fwd_position,
grad_api_position) in backward_grad_outputs_map.items():
# Infer Grad API Return Type
if num_bwd_outputs == 1:
# Single tensor output, return as is
if IsPlainTensorType(ttype):
returns_str += "returns[0] = { grad_api_returns };\n"
else:
assert IsVectorTensorType(ttype)
returns_str += "returns[0] = grad_api_returns;\n"
if num_outputs == 1:
set_retain_grad = f" egr::EagerUtils::CheckAndRetainGrad(api_result);"
set_grad_in_meta = f" grad_node->SetGradInMeta(api_result, {pos});"
else:
# Rearrange output order accordingly
returns_str += f"returns[{fwd_position}] = grad_api_returns[{grad_api_position}];\n"
returns_str += f"if(NeedComplexToRealConversion()) HandleComplexGradToRealGrad(&returns);\n"
returns_str += f"return returns;\n"
set_retain_grad = f" egr::EagerUtils::CheckAndRetainGrad(std::get<{pos}>(api_result));"
set_grad_in_meta = f" grad_node->SetGradInMeta(std::get<{pos}>(api_result), {pos});"
set_out_rank_list.append(set_out_rank)
set_history_list.append(set_history)
set_grad_in_meta_list.append(set_grad_in_meta)
set_retain_grad_list.append(set_retain_grad)
grad_node_name = GetGradNodeName(forward_api_name)
set_out_rank_str = "\n".join(set_out_rank_list)
set_history_str = "\n".join(set_history_list)
set_grad_in_meta_str = "\n".join(set_grad_in_meta_list)
set_retain_grad_str = "\n".join(set_retain_grad_list)
fill_zero_str = ""
if forward_api_name in ops_to_fill_zero_for_empty_grads:
fill_zero_str = "egr::EagerUtils::FillZeroForEmptyGradInputs(&grads, this->InputMeta());\n"
node_event_name = forward_api_name + " node_creation"
node_creation_event_str = f"paddle::platform::RecordEvent node_creation_record_event(\"{node_event_name}\", paddle::platform::TracerEventType::Operator, 1);\n"
grad_api_namespace = f"paddle::experimental::{namespace}"
self.node_creation_str = FORWARD_BODY_TEMPLATE.format(
inputs_autograd_meta_str, compute_require_grad_args_str,
check_inplace_str, forward_call_str, bump_inplace_version_str,
node_creation_event_str, outputs_autograd_meta_str,
pass_stop_gradient_args_str, node_construction_str,
set_attributes_str, set_tensor_wrappers_str, set_grad_out_meta_str,
set_edges_str, set_out_rank_str, set_history_str,
set_grad_in_meta_str, set_retain_grad_str)
self.node_definition_str = FUNCTION_TEMPLATE.format(
grad_node_name, fill_zero_str, grad_node_name, grad_api_namespace,
backward_api_name, grad_api_args_str, returns_str)
def run(self):
# Basic Validation Check
self.DygraphYamlValidationCheck()
logging.info(f"Generated Node Definition: {self.node_definition_str}")
##########################
## Parsing Raw Contents ##
##########################
# Parse inplace_map
self.ParseInplaceInfo()
# Parse no_need_buffer
self.ParseNoNeedBuffer()
# Parse optional_inputs
self.ParseDispensable()
# Parse intermediate_outputs
self.ParseIntermediate()
self.IntermediateValidationCheck()
# Initialize backward_forward_str, backward_inputs_list, backward_attrs_list, backward_returns_list
self.CollectBackwardInfo()
# Initialize forward_inputs_list, forward_attrs_list, forward_returns_list
self.CollectForwardInfoFromBackwardContents()
# Initialize orig_forward_inputs_list, orig_forward_attrs_list, orig_forward_returns_list
self.CollectOriginalForwardInfo()
# Forwards Validation Check
self.ForwardsValidationCheck()
#############################
## Process Parsed Contents ##
#############################
# Initialize forward_inputs_position_map, forward_outputs_position_map
self.DetermineForwardPositionMap(self.forward_inputs_list,
self.forward_returns_list)
# Initialize forward_inputs_position_map, forward_outputs_position_map
self.SlotNameMatching()
# Backward Validation Check
self.BackwardValidationCheck()
class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
def __init__(self, forward_api_contents, grad_api_contents, namespace):
DygraphFunctionGeneratorBase.__init__(self, forward_api_contents,
grad_api_contents, namespace)
# Generated Results
self.forward_definition_str = ""
self.forward_declaration_str = ""
def GenerateForwardDefinition(self, is_inplaced):
namespace = self.namespace
......@@ -780,7 +871,7 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase):
backward_attrs_list = self.backward_attrs_list
optional_inputs = self.optional_inputs
intermediate_outputs = self.intermediate_outputs
inplace_map = self.inplace_map
inplace_map = self.inplace_map if is_inplaced else {}
# Get Function Args
num_inputs = len(forward_attrs_list) + len(
......@@ -918,182 +1009,6 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase):
logging.info(
f"Generated Forward Declaration: {self.forward_declaration_str}")
def GenerateNodeCreationCodes(self, forward_call_str, is_inplaced):
forward_api_name = self.forward_api_name
forward_inputs_position_map = self.forward_inputs_position_map
forward_outputs_position_map = self.forward_outputs_position_map
forward_attrs_list = self.forward_attrs_list
backward_forward_inputs_map = self.backward_forward_inputs_map
backward_grad_inputs_map = self.backward_grad_inputs_map
backward_grad_outputs_map = self.backward_grad_outputs_map
backward_attrs_list = self.backward_attrs_list
optional_inputs = self.optional_inputs
inplace_map = self.inplace_map
# Get Input AutoGradMeta
inputs_autograd_meta_list = []
compute_require_grad_args_list = ["trace_backward"]
for name, (ttype, pos) in forward_inputs_position_map.items():
input_autograd_meta_name = GetAutoGradMetaName(name)
if IsPlainTensorType(ttype):
input_autograd_meta = f" egr::AutogradMeta* {input_autograd_meta_name} = egr::EagerUtils::nullable_autograd_meta({name});"
else:
assert IsVectorTensorType(ttype)
input_autograd_meta_vec_name = GetAutoGradMetaVectorName(name)
input_autograd_meta = f" std::vector<egr::AutogradMeta*> {input_autograd_meta_vec_name} = egr::EagerUtils::nullable_autograd_meta({name});\n"
input_autograd_meta += f" std::vector<egr::AutogradMeta*>* {input_autograd_meta_name} = &{input_autograd_meta_vec_name};"
inputs_autograd_meta_list.append(input_autograd_meta)
compute_require_grad_args_list.append(input_autograd_meta_name)
inputs_autograd_meta_str = "\n".join(inputs_autograd_meta_list)
compute_require_grad_args_str = ",".join(compute_require_grad_args_list)
# Get Output AutoGradMeta
outputs_autograd_meta_list = []
pass_stop_gradient_args_list = ["false"]
num_fwd_outputs = len(forward_outputs_position_map.keys())
for name, (rtype, pos) in forward_outputs_position_map.items():
output_autograd_meta_name = GetAutoGradMetaName(name)
output_autograd_meta_vec_name = GetAutoGradMetaVectorName(name)
if num_fwd_outputs == 1:
if IsPlainTensorType(rtype):
output_autograd_meta = f" egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&api_result);"
else:
assert IsVectorTensorType(rtype)
output_autograd_meta = f" std::vector<egr::AutogradMeta*> {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&api_result);\n"
output_autograd_meta += f" std::vector<egr::AutogradMeta*>* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};"
else:
# Tuple api_result
if IsPlainTensorType(rtype):
output_autograd_meta = f" egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&std::get<{pos}>(api_result));"
else:
assert IsVectorTensorType(rtype)
output_autograd_meta = f" std::vector<egr::AutogradMeta*> {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&std::get<{pos}>(api_result));\n"
output_autograd_meta += f" std::vector<egr::AutogradMeta*>* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};"
outputs_autograd_meta_list.append(output_autograd_meta)
pass_stop_gradient_args_list.append(output_autograd_meta_name)
# ComputeRequireGrad & PassStopGradient
outputs_autograd_meta_str = "\n".join(outputs_autograd_meta_list)
pass_stop_gradient_args_str = ",".join(pass_stop_gradient_args_list)
# Check Inplace
check_inplace_str = ""
bump_inplace_version_str = ""
if is_inplaced:
for inplace_name in inplace_map.keys():
inplace_autograd_meta_name = GetAutoGradMetaName(inplace_name)
check_inplace_str += CHECK_INPLACE_TEMPLATE.format(
inplace_name, inplace_autograd_meta_name)
bump_inplace_version_str += BUMP_INPLACE_VERSION_TEMPLATE.format(
inplace_name, inplace_name)
# Node Construction
num_backward_inputs = len(forward_outputs_position_map.keys())
num_backward_outputs = len(forward_inputs_position_map.keys())
grad_node_name = GetGradNodeName(forward_api_name)
node_construction_str = f" auto grad_node = std::make_shared<{grad_node_name}>({num_backward_inputs}, {num_backward_outputs});"
# SetAttributes
set_attributes_list = []
forward_attrs_name_set = set()
for name, _, _, _ in forward_attrs_list:
forward_attrs_name_set.add(name)
for name, _, default_val_attr, _ in backward_attrs_list:
if name in forward_attrs_name_set:
set_attributes = f" grad_node->SetAttribute{name}({name});"
else:
set_attributes = f" grad_node->SetAttribute{name}({default_val_attr});"
set_attributes_list.append(set_attributes)
set_attributes_str = "\n".join(set_attributes_list)
# SetTensorWrappers
set_tensor_wrappers_list = []
for name, (atype, is_fwd_input,
pos) in backward_forward_inputs_map.items():
is_optional = (name in optional_inputs)
if is_fwd_input:
if is_optional:
set_tensor_wrappers = f" if({name}.get_ptr() != nullptr) grad_node->SetTensorWrapper{name}(*({name}.get_ptr()), true);"
else:
set_tensor_wrappers = f" grad_node->SetTensorWrapper{name}({name}, true);"
else:
if num_fwd_outputs > 1:
# Aligned with forward output position
assert name in forward_outputs_position_map.keys(
), AssertMessage(name, forward_outputs_position_map.keys())
fwd_output_pos = forward_outputs_position_map[name][1]
tw_name = f"std::get<{fwd_output_pos}>(api_result)"
else:
tw_name = f"api_result"
if is_optional:
set_tensor_wrappers = f" if({tw_name}.get_ptr() != nullptr) grad_node->SetTensorWrapper{name}(*({tw_name}.get_ptr()), false);"
else:
set_tensor_wrappers = f" grad_node->SetTensorWrapper{name}({tw_name}, false);"
set_tensor_wrappers_list.append(set_tensor_wrappers)
set_tensor_wrappers_str = "\n".join(set_tensor_wrappers_list)
# SetGradOutMeta & SetEdges
set_grad_out_meta_list = []
set_edges_list = []
for name, (_, pos) in forward_inputs_position_map.items():
input_autograd_meta_name = GetAutoGradMetaName(name)
is_optional = (name in self.optional_inputs)
if is_optional:
set_grad_out_meta = f" if({name}.get_ptr() != nullptr) grad_node->SetGradOutMeta(*({name}.get_ptr()), {pos});"
set_edges = f" if({name}.get_ptr() != nullptr) grad_node->AddEdges({input_autograd_meta_name}, {pos});"
else:
set_grad_out_meta = f" grad_node->SetGradOutMeta({name}, {pos});"
set_edges = f" grad_node->AddEdges({input_autograd_meta_name}, {pos});"
set_grad_out_meta_list.append(set_grad_out_meta)
set_edges_list.append(set_edges)
set_grad_out_meta_str = "\n".join(set_grad_out_meta_list)
set_edges_str = "\n".join(set_edges_list)
# SetOutRank & SetHistory & SetGradInMeta
set_out_rank_list = []
set_history_list = []
set_grad_in_meta_list = []
set_retain_grad_list = []
num_outputs = len(forward_outputs_position_map.keys())
for name, (_, pos) in forward_outputs_position_map.items():
output_autograd_meta_name = GetAutoGradMetaName(name)
set_out_rank = f" egr::EagerUtils::SetOutRankWithSlot({output_autograd_meta_name}, {pos});"
set_history = f" egr::EagerUtils::SetHistory({output_autograd_meta_name}, grad_node);"
if num_outputs == 1:
set_retain_grad = f" egr::EagerUtils::CheckAndRetainGrad(api_result);"
set_grad_in_meta = f" grad_node->SetGradInMeta(api_result, {pos});"
else:
set_retain_grad = f" egr::EagerUtils::CheckAndRetainGrad(std::get<{pos}>(api_result));"
set_grad_in_meta = f" grad_node->SetGradInMeta(std::get<{pos}>(api_result), {pos});"
set_out_rank_list.append(set_out_rank)
set_history_list.append(set_history)
set_grad_in_meta_list.append(set_grad_in_meta)
set_retain_grad_list.append(set_retain_grad)
set_out_rank_str = "\n".join(set_out_rank_list)
set_history_str = "\n".join(set_history_list)
set_grad_in_meta_str = "\n".join(set_grad_in_meta_list)
set_retain_grad_str = "\n".join(set_retain_grad_list)
node_event_name = forward_api_name + " node_creation"
node_creation_event_str = f"paddle::platform::RecordEvent node_creation_record_event(\"{node_event_name}\", paddle::platform::TracerEventType::Operator, 1);\n"
self.node_creation_str = NODE_CREATION_TEMPLATE.format(
inputs_autograd_meta_str, compute_require_grad_args_str,
check_inplace_str, forward_call_str, bump_inplace_version_str,
node_creation_event_str, outputs_autograd_meta_str,
pass_stop_gradient_args_str, node_construction_str,
set_attributes_str, set_tensor_wrappers_str, set_grad_out_meta_str,
set_edges_str, set_out_rank_str, set_history_str,
set_grad_in_meta_str, set_retain_grad_str)
def GenerateInplacedForwardDygraphFunctions(self):
# Inplaced Version Dygraph Function Generation
forward_api_name = self.forward_api_name
......@@ -1139,60 +1054,168 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase):
core_ops_returns_info[final_state_fwd_api_name][pos] = name
def run(self):
# Basic Validation Check
self.DygraphYamlValidationCheck()
super().run()
##########################
## Parsing Raw Contents ##
##########################
# Parse inplace_map
self.ParseInplaceInfo()
#####################
## Code Generation ##
#####################
self.GenerateForwardDefinition(is_inplaced=False)
# Parse no_need_buffer
self.ParseNoNeedBuffer()
self.UpdateCoreOpsInformation(is_inplaced=False)
# Parse optional_inputs
self.ParseDispensable()
self.GenerateInplacedForwardDygraphFunctions()
# Parse intermediate_outputs
self.ParseIntermediate()
self.IntermediateValidationCheck()
# Initialize backward_forward_str, backward_inputs_list, backward_attrs_list, backward_returns_list
self.CollectBackwardInfo()
class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
def __init__(self, forward_api_contents, grad_api_contents, namespace):
DygraphFunctionGeneratorBase.__init__(self, forward_api_contents,
grad_api_contents, namespace)
# Initialize forward_inputs_list, forward_attrs_list, forward_returns_list
self.CollectForwardInfoFromBackwardContents()
# Generated Results
self.node_declaration_str = ""
self.node_definition_str = ""
# Initialize orig_forward_inputs_list, orig_forward_attrs_list, orig_forward_returns_list
self.CollectOriginalForwardInfo()
def GenerateNodeDeclaration(self):
forward_op_name = self.forward_api_name
backward_forward_inputs_map = self.backward_forward_inputs_map
backward_attrs_list = self.backward_attrs_list
no_need_buffers = self.no_need_buffers
# Forwards Validation Check
self.ForwardsValidationCheck()
# SetTensorWrapper Methods & TensorWrapper Members
set_tensor_wrapper_methods_str = ""
tensor_wrapper_members_str = ""
clear_tensor_wrapper_str = ""
for tname, (ttype, is_fwd_input,
_) in backward_forward_inputs_map.items():
no_need_buffer = "true" if tname in no_need_buffers else "false"
tensor_wrapper_name = GetSavedName(tname)
if IsPlainTensorType(ttype):
set_tensor_wrapper_methods_str += SET_PLAIN_TENSOR_WRAPPER_TEMPLATE.format(
tname, tname, tensor_wrapper_name, tname, no_need_buffer)
#############################
## Process Parsed Contents ##
#############################
# Initialize forward_inputs_position_map, forward_outputs_position_map
self.DetermineForwardPositionMap(self.forward_inputs_list,
self.forward_returns_list)
tensor_wrapper_members_str += PLAIN_TENSOR_MEMBER_TEMPLATE.format(
tensor_wrapper_name)
# Initialize forward_inputs_position_map, forward_outputs_position_map
self.SlotNameMatching()
clear_tensor_wrapper_str += CLEAR_TENSOR_WRAPPER_TEMPLATE.format(
tensor_wrapper_name)
# Backward Validation Check
self.BackwardValidationCheck()
else:
assert IsVectorTensorType(ttype)
set_tensor_wrapper_methods_str += SET_VECTOR_TENSOR_WRAPPER_TEMPLATE.format(
tname, tname, tname, tensor_wrapper_name, no_need_buffer)
tensor_wrapper_members_str += VECTOR_TENSOR_MEMBER_TEMPLATE.format(
tensor_wrapper_name)
clear_tensor_wrapper_str += CLEAR_VECTOR_TENSOR_WRAPPERS_TEMPLATE.format(
tensor_wrapper_name)
# SetAttributes & Attribute Members
set_attribute_methods_str = ""
attribute_members_str = ""
for aname, atype, default_val, _ in backward_attrs_list:
saved_attr_name = GetSavedName(aname)
set_attribute_methods_str += SET_ATTR_METHOD_TEMPLATE.format(
aname, GetConstReference(atype), aname, saved_attr_name, aname)
if default_val:
attribute_members_str += ATTRIBUTE_MEMBER_WITH_DEFAULT_TEMPLATE.format(
RemoveConstAndReference(atype), saved_attr_name,
default_val)
else:
attribute_members_str += ATTRIBUTE_MEMBER_TEMPLATE.format(
RemoveConstAndReference(atype), saved_attr_name)
grad_node_name = GetGradNodeName(forward_op_name)
self.node_declaration_str = NODE_DECLARATION_TEMPLATE.format(
grad_node_name, grad_node_name, grad_node_name, grad_node_name,
grad_node_name, clear_tensor_wrapper_str,
set_tensor_wrapper_methods_str, set_attribute_methods_str,
tensor_wrapper_members_str, attribute_members_str)
logging.info(f"Generated Node Declaration: {self.node_declaration_str}")
def GenerateNodeDefinition(self):
namespace = self.namespace
forward_api_name = self.forward_api_name
backward_api_name = self.backward_api_name
backward_forward_inputs_map = self.backward_forward_inputs_map
backward_grad_inputs_map = self.backward_grad_inputs_map
backward_grad_outputs_map = self.backward_grad_outputs_map
backward_attrs_list = self.backward_attrs_list
# Construct grad_api function args
# Order: TensorWrappers, GradTensors, Attributes
grad_api_args_len = len(backward_forward_inputs_map.keys()) + len(
backward_grad_inputs_map.keys()) + len(backward_attrs_list)
grad_api_args = ["" for i in range(grad_api_args_len)]
for name, (_, is_fwd_input,
grad_api_position), in backward_forward_inputs_map.items():
tensor_wrapper_name = GetSavedName(name)
is_optional = (name in self.optional_inputs)
if is_optional:
grad_api_args[
grad_api_position] = f"egr::EagerUtils::RecoverOptionalTensorWrapper(&this->{tensor_wrapper_name}, nullptr)"
else:
grad_api_args[
grad_api_position] = f"egr::EagerUtils::RecoverTensorWrapper(&this->{tensor_wrapper_name}, nullptr)"
for _, (ttype, fwd_position,
grad_api_position) in backward_grad_inputs_map.items():
if IsPlainTensorType(ttype):
grad_api_args[
grad_api_position] = f"hooked_grads[{fwd_position}][0]"
else:
assert IsVectorTensorType(ttype)
grad_api_args[
grad_api_position] = f"hooked_grads[{fwd_position}]"
for name, _, _, grad_api_position in backward_attrs_list:
saved_attribute_name = GetSavedName(name)
grad_api_args[grad_api_position] = f"this->{saved_attribute_name}"
grad_api_args_str = ", ".join(grad_api_args)
# Construct grad_api returns
num_bwd_outputs = len(backward_grad_outputs_map.keys())
slot_num_bwd_outputs = len(self.forward_inputs_position_map.keys())
returns_str = f"std::vector<std::vector<paddle::experimental::Tensor>> returns({slot_num_bwd_outputs});\n"
for _, (ttype, fwd_position,
grad_api_position) in backward_grad_outputs_map.items():
# Infer Grad API Return Type
if num_bwd_outputs == 1:
# Single tensor output, return as is
if IsPlainTensorType(ttype):
returns_str += "returns[0] = { grad_api_returns };\n"
else:
assert IsVectorTensorType(ttype)
returns_str += "returns[0] = grad_api_returns;\n"
else:
# Rearrange output order accordingly
returns_str += f"returns[{fwd_position}] = grad_api_returns[{grad_api_position}];\n"
returns_str += f"if(NeedComplexToRealConversion()) HandleComplexGradToRealGrad(&returns);\n"
returns_str += f"return returns;\n"
grad_node_name = GetGradNodeName(forward_api_name)
fill_zero_str = ""
if forward_api_name in ops_to_fill_zero_for_empty_grads:
fill_zero_str = "egr::EagerUtils::FillZeroForEmptyGradInputs(&grads, this->InputMeta());\n"
grad_api_namespace = f"paddle::experimental::{namespace}"
self.node_definition_str = FUNCTION_TEMPLATE.format(
grad_node_name, fill_zero_str, grad_node_name, grad_api_namespace,
backward_api_name, grad_api_args_str, returns_str)
logging.info(f"Generated Node Definition: {self.node_definition_str}")
def run(self):
super().run()
#####################
## Code Generation ##
#####################
self.GenerateNodeDeclaration()
self.GenerateNodeDefinition()
self.GenerateForwardDefinition(is_inplaced=False)
self.UpdateCoreOpsInformation(is_inplaced=False)
self.GenerateInplacedForwardDygraphFunctions()
class DygraphYamlGenerator(YamlGeneratorBase):
......@@ -1239,14 +1262,18 @@ class DygraphYamlGenerator(YamlGeneratorBase):
forward_api_contents)
if backward_api_contents is None: continue
d_generator = DygraphSingleFunctionGenerator(
function_generator = DygraphForwardFunctionGenerator(
forward_api_contents, backward_api_contents, namespace)
function_generator.run()
node_generator = DygraphNodeGenerator(
forward_api_contents, backward_api_contents, namespace)
d_generator.run()
node_generator.run()
self.forward_definition_str += d_generator.forward_definition_str + "\n"
self.forward_declaration_str += d_generator.forward_declaration_str + "\n"
self.node_declaration_str += d_generator.node_declaration_str + "\n"
self.node_definition_str += d_generator.node_definition_str + "\n"
self.forward_definition_str += function_generator.forward_definition_str + "\n"
self.forward_declaration_str += function_generator.forward_declaration_str + "\n"
self.node_declaration_str += node_generator.node_declaration_str + "\n"
self.node_definition_str += node_generator.node_definition_str + "\n"
if len(namespace) > 0:
if namespace.endswith("::"):
......
......@@ -311,7 +311,7 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase):
dygraph_function_call_list[pos] = f"{name}"
dygraph_function_call_str = ",".join(dygraph_function_call_list)
# Generate Python-C Function Definitions
# Generate Python-C Function Definitions
if is_forward_only:
fwd_function_name = FUNCTION_NAME_TEMPLATE.format(
"paddle::experimental::", namespace, forward_api_name)
......@@ -337,21 +337,21 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase):
forward_api_name_prefix, forward_api_name, namespace,
forward_api_name, forward_api_name)
if len(inplace_map) > 0:
assert len(
inplace_map
) == 1, f"size of inplace_map must be 1, but inplace_map of \"{forward_api_name}\" op got {len(inplace_map)}"
if inplace_map:
inplaced_forward_api_name = GetInplacedFunctionName(
self.forward_api_name)
# Generate Python-C Function Definitions
if is_forward_only:
fwd_function_name = FUNCTION_NAME_TEMPLATE.format(
inplaced_fwd_function_name = FUNCTION_NAME_TEMPLATE.format(
"paddle::experimental::", namespace,
inplaced_forward_api_name)
elif len(inplace_map) > 0:
fwd_function_name = FUNCTION_NAME_TEMPLATE.format(
else:
inplaced_fwd_function_name = FUNCTION_NAME_TEMPLATE.format(
"::", namespace,
GetForwardFunctionName(inplaced_forward_api_name))
assert len(
inplace_map
) == 1, f"size of inplace_map must be 1, but inplace_map of \"{forward_api_name}\" op got {len(inplace_map)}"
for inplace_input, inplace_output in inplace_map.items():
return_str = RETURN_INPLACE_PYOBJECT_TEMPLATE.format(
inplaced_forward_api_name, inplace_input,
......@@ -361,7 +361,7 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase):
self.python_c_function_str += PYTHON_C_FUNCTION_TEMPLATE.format(
inplaced_forward_api_name, pythonc_record_event_str,
inplaced_forward_api_name, get_eager_tensor_str,
parse_attributes_str, fwd_function_name,
parse_attributes_str, inplaced_fwd_function_name,
dygraph_function_call_str, return_str)
# Generate Python-C Function Registration
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册