未验证 提交 b9342a80 编写于 作者: W Weilong Wu 提交者: GitHub

[Eager] Polish eager code generation (#42822)

* [Eager] Polish eager code generation

* Remove useless code in codegen
上级 570d0322
...@@ -418,7 +418,7 @@ class FunctionGeneratorBase: ...@@ -418,7 +418,7 @@ class FunctionGeneratorBase:
return_name] = [return_type, return_pos] return_name] = [return_type, return_pos]
class YamlGeneratorBase: class GeneratorBase:
def __init__(self, api_yaml_path): def __init__(self, api_yaml_path):
self.namespace = "" self.namespace = ""
self.api_yaml_path = api_yaml_path self.api_yaml_path = api_yaml_path
......
...@@ -29,7 +29,7 @@ from codegen_utils import RemoveSpecialSymbolsInName, RecoverBaseNameOfInplaceFu ...@@ -29,7 +29,7 @@ from codegen_utils import RemoveSpecialSymbolsInName, RecoverBaseNameOfInplaceFu
from codegen_utils import GetInplacedFunctionName from codegen_utils import GetInplacedFunctionName
from codegen_utils import ParseYamlArgs, ParseYamlReturns, ParseYamlForwardFromBackward from codegen_utils import ParseYamlArgs, ParseYamlReturns, ParseYamlForwardFromBackward
from codegen_utils import ParseYamlForward, ParseYamlBackward from codegen_utils import ParseYamlForward, ParseYamlBackward
from codegen_utils import FunctionGeneratorBase, YamlGeneratorBase from codegen_utils import FunctionGeneratorBase, GeneratorBase
from codegen_utils import ops_to_fill_zero_for_empty_grads from codegen_utils import ops_to_fill_zero_for_empty_grads
from codegen_utils import AssertMessage, GetIndent from codegen_utils import AssertMessage, GetIndent
...@@ -60,14 +60,6 @@ SET_PLAIN_TENSOR_WRAPPER_TEMPLATE = \ ...@@ -60,14 +60,6 @@ SET_PLAIN_TENSOR_WRAPPER_TEMPLATE = \
}} }}
""" """
PLAIN_TENSOR_MEMBER_TEMPLATE = \
""" egr::TensorWrapper {};
"""
CLEAR_TENSOR_WRAPPER_TEMPLATE = \
""" {}.clear();
"""
SET_VECTOR_TENSOR_WRAPPER_TEMPLATE = \ SET_VECTOR_TENSOR_WRAPPER_TEMPLATE = \
""" void SetTensorWrapper{}(const std::vector<paddle::experimental::Tensor>& {}) {{ """ void SetTensorWrapper{}(const std::vector<paddle::experimental::Tensor>& {}) {{
for(const auto& eager_tensor : {}) {{ for(const auto& eager_tensor : {}) {{
...@@ -76,10 +68,18 @@ SET_VECTOR_TENSOR_WRAPPER_TEMPLATE = \ ...@@ -76,10 +68,18 @@ SET_VECTOR_TENSOR_WRAPPER_TEMPLATE = \
}} }}
""" """
PLAIN_TENSOR_MEMBER_TEMPLATE = \
""" egr::TensorWrapper {};
"""
VECTOR_TENSOR_MEMBER_TEMPLATE = \ VECTOR_TENSOR_MEMBER_TEMPLATE = \
""" std::vector<egr::TensorWrapper> {}; """ std::vector<egr::TensorWrapper> {};
""" """
CLEAR_TENSOR_WRAPPER_TEMPLATE = \
""" {}.clear();
"""
CLEAR_VECTOR_TENSOR_WRAPPERS_TEMPLATE = \ CLEAR_VECTOR_TENSOR_WRAPPERS_TEMPLATE = \
""" for (auto& tw : {}) {{ """ for (auto& tw : {}) {{
tw.clear(); tw.clear();
...@@ -423,9 +423,9 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase): ...@@ -423,9 +423,9 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
self.forward_returns_list = [ self.forward_returns_list = [
] #[ [ret_name, ret_type, orig_position], ...] ] #[ [ret_name, ret_type, orig_position], ...]
self.backward_inputs_list = [
] #[ [attr_name, attr_type, default_value, orig_position], ...]
self.backward_attrs_list = [ self.backward_attrs_list = [
] #[ [attr_name, attr_type, default_value, orig_position], ...]
self.backward_inputs_list = [
] #[ [arg_name, arg_type, orig_position], ...] ] #[ [arg_name, arg_type, orig_position], ...]
self.backward_returns_list = [ self.backward_returns_list = [
] #[ [ret_name, ret_type, orig_position], ...] ] #[ [ret_name, ret_type, orig_position], ...]
...@@ -504,11 +504,9 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase): ...@@ -504,11 +504,9 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
for _, _, pos in forward_inputs_list: for _, _, pos in forward_inputs_list:
max_input_position = max(max_input_position, pos) max_input_position = max(max_input_position, pos)
max_attr_position = -1
for _, _, _, pos in forward_attrs_list: for _, _, _, pos in forward_attrs_list:
assert pos > max_input_position, AssertMessage(pos, assert pos > max_input_position, AssertMessage(pos,
max_input_position) max_input_position)
max_attr_position = max(max_attr_position, pos)
def BackwardValidationCheck(self): def BackwardValidationCheck(self):
backward_forward_inputs_map = self.backward_forward_inputs_map backward_forward_inputs_map = self.backward_forward_inputs_map
...@@ -692,12 +690,11 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase): ...@@ -692,12 +690,11 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
else: else:
set_tensor_wrappers = f"{indent}grad_node->SetTensorWrapper{name}({name});" set_tensor_wrappers = f"{indent}grad_node->SetTensorWrapper{name}({name});"
set_input_tensor_wrappers_list.append(set_tensor_wrappers) set_input_tensor_wrappers_list.append(set_tensor_wrappers)
else: else: # Forwad's output as backward's input
if num_fwd_outputs > 1: if num_fwd_outputs > 1:
# Aligned with forward output position # Aligned with forward output position
assert name in forward_outputs_position_map.keys( assert name in forward_outputs_position_map.keys(
), AssertMessage(name, forward_outputs_position_map.keys()) ), AssertMessage(name, forward_outputs_position_map.keys())
fwd_output_pos = forward_outputs_position_map[name][1]
if is_optional: if is_optional:
set_tensor_wrappers = f"{indent}if({name}.get_ptr() != nullptr) grad_node->SetTensorWrapper{name}(*({name}.get_ptr()));" set_tensor_wrappers = f"{indent}if({name}.get_ptr() != nullptr) grad_node->SetTensorWrapper{name}(*({name}.get_ptr()));"
...@@ -733,7 +730,7 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase): ...@@ -733,7 +730,7 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
set_grad_out_meta_list.append(set_grad_out_meta) set_grad_out_meta_list.append(set_grad_out_meta)
set_grad_out_meta_str = "\n".join(set_grad_out_meta_list) set_grad_out_meta_str = "\n".join(set_grad_out_meta_list)
# SetOutRank & SetHistory & SetGradInMeta # SetOutRank & SetHistory & SetGradInMeta & CheckAndRetainGrad
set_out_rank_list = [] set_out_rank_list = []
set_history_list = [] set_history_list = []
set_grad_in_meta_list = [] set_grad_in_meta_list = []
...@@ -741,11 +738,12 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase): ...@@ -741,11 +738,12 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
num_outputs = len(forward_outputs_position_map.keys()) num_outputs = len(forward_outputs_position_map.keys())
for name, (_, pos) in forward_outputs_position_map.items(): for name, (_, pos) in forward_outputs_position_map.items():
output_autograd_meta_name = GetAutoGradMetaName(name) output_autograd_meta_name = GetAutoGradMetaName(name)
set_out_rank = f"{indent}egr::EagerUtils::SetOutRankWithSlot({output_autograd_meta_name}, {pos});" set_out_rank = f"{indent}egr::EagerUtils::SetOutRankWithSlot({output_autograd_meta_name}, {pos});"
set_history = f"{indent}egr::EagerUtils::SetHistory({output_autograd_meta_name}, grad_node);" set_history = f"{indent}egr::EagerUtils::SetHistory({output_autograd_meta_name}, grad_node);"
set_retain_grad = f"{indent}egr::EagerUtils::CheckAndRetainGrad({name});"
set_grad_in_meta = f"{indent}grad_node->SetGradInMeta({name}, {pos});" set_grad_in_meta = f"{indent}grad_node->SetGradInMeta({name}, {pos});"
set_retain_grad = f"{indent}egr::EagerUtils::CheckAndRetainGrad({name});"
set_out_rank_list.append(set_out_rank) set_out_rank_list.append(set_out_rank)
set_history_list.append(set_history) set_history_list.append(set_history)
set_grad_in_meta_list.append(set_grad_in_meta) set_grad_in_meta_list.append(set_grad_in_meta)
...@@ -806,7 +804,7 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase): ...@@ -806,7 +804,7 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
self.DetermineForwardPositionMap(self.forward_inputs_list, self.DetermineForwardPositionMap(self.forward_inputs_list,
self.forward_returns_list) self.forward_returns_list)
# Initialize forward_inputs_position_map, forward_outputs_position_map # Initialize backward_forward_inputs_map, backward_grad_inputs_map, backward_grad_outputs_map
self.SlotNameMatching() self.SlotNameMatching()
# Backward Validation Check # Backward Validation Check
...@@ -822,18 +820,16 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase): ...@@ -822,18 +820,16 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
self.forward_definition_str = "" self.forward_definition_str = ""
self.forward_declaration_str = "" self.forward_declaration_str = ""
def GenerateForwardDefinition(self, is_inplaced): def GenerateForwardDefinitionAndDeclaration(self, is_inplaced):
namespace = self.namespace namespace = self.namespace
forward_api_name = GetInplacedFunctionName( forward_api_name = GetInplacedFunctionName(
self.forward_api_name) if is_inplaced else self.forward_api_name self.forward_api_name) if is_inplaced else self.forward_api_name
backward_api_name = self.backward_api_name
forward_inputs_position_map = self.forward_inputs_position_map forward_inputs_position_map = self.forward_inputs_position_map
forward_outputs_position_map = self.forward_outputs_position_map forward_outputs_position_map = self.forward_outputs_position_map
forward_attrs_list = self.forward_attrs_list forward_attrs_list = self.forward_attrs_list
backward_forward_inputs_map = self.backward_forward_inputs_map
backward_grad_inputs_map = self.backward_grad_inputs_map
backward_grad_outputs_map = self.backward_grad_outputs_map
backward_attrs_list = self.backward_attrs_list
optional_inputs = self.optional_inputs optional_inputs = self.optional_inputs
intermediate_outputs = self.intermediate_outputs intermediate_outputs = self.intermediate_outputs
inplace_map = self.inplace_map if is_inplaced else {} inplace_map = self.inplace_map if is_inplaced else {}
...@@ -845,6 +841,7 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase): ...@@ -845,6 +841,7 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
inputs_args_definition_list = ["" for i in range(num_inputs)] inputs_args_definition_list = ["" for i in range(num_inputs)]
inputs_args_declaration_list = ["" for i in range(num_inputs)] inputs_args_declaration_list = ["" for i in range(num_inputs)]
inputs_call_list = ["" for i in range(num_inputs)] inputs_call_list = ["" for i in range(num_inputs)]
amp_inputs_call_list = ["" for i in range(num_inputs)] amp_inputs_call_list = ["" for i in range(num_inputs)]
amp_tensors_vector_list = [] amp_tensors_vector_list = []
amp_tensors_vector_optional_list = [] amp_tensors_vector_optional_list = []
...@@ -1019,9 +1016,10 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase): ...@@ -1019,9 +1016,10 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
bump_inplace_version_str += BUMP_INPLACE_VERSION_TEMPLATE.format( bump_inplace_version_str += BUMP_INPLACE_VERSION_TEMPLATE.format(
inplace_name, inplace_name) inplace_name, inplace_name)
# Node Creation
self.GenerateNodeCreationCodes() self.GenerateNodeCreationCodes()
node_creation_str = self.node_creation_str node_creation_str = self.node_creation_str
dygraph_event_str = f"{indent}paddle::platform::RecordEvent dygraph_entrance_record_event(\"{forward_api_name} dygraph\", paddle::platform::TracerEventType::Operator, 1);\n" dygraph_event_str = f"{indent}paddle::platform::RecordEvent dygraph_entrance_record_event(\"{forward_api_name} dygraph\", paddle::platform::TracerEventType::Operator, 1);\n"
forward_function_name = GetDygraphForwardFunctionName(forward_api_name) forward_function_name = GetDygraphForwardFunctionName(forward_api_name)
...@@ -1045,6 +1043,7 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase): ...@@ -1045,6 +1043,7 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
amp_tensors_vector_optional_list_str, amp_get_dst_dtype_str, amp_tensors_vector_optional_list_str, amp_get_dst_dtype_str,
amp_autocast_list_str, amp_call_str) amp_autocast_list_str, amp_call_str)
# Generate forward_definition_str and forward_declaration_str
self.forward_definition_str += FORWARD_FUNCTION_TEMPLATE.format( self.forward_definition_str += FORWARD_FUNCTION_TEMPLATE.format(
returns_type_str, forward_function_name, inputs_args_definition_str, returns_type_str, forward_function_name, inputs_args_definition_str,
dygraph_event_str, amp_logic_str, inputs_autograd_meta_str, dygraph_event_str, amp_logic_str, inputs_autograd_meta_str,
...@@ -1061,8 +1060,8 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase): ...@@ -1061,8 +1060,8 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
if forward_api_name != "sum" and "inplace" in forward_api_contents.keys( if forward_api_name != "sum" and "inplace" in forward_api_contents.keys(
): ):
# Node Definition Generation # Function Definition and Declaration Generation
self.GenerateForwardDefinition(is_inplaced=True) self.GenerateForwardDefinitionAndDeclaration(is_inplaced=True)
self.UpdateCoreOpsInformation(is_inplaced=True) self.UpdateCoreOpsInformation(is_inplaced=True)
def UpdateCoreOpsInformation(self, is_inplaced): def UpdateCoreOpsInformation(self, is_inplaced):
...@@ -1083,6 +1082,7 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase): ...@@ -1083,6 +1082,7 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
final_state_fwd_api_name] = ["" for i in range(num_args)] final_state_fwd_api_name] = ["" for i in range(num_args)]
core_ops_args_type_info[ core_ops_args_type_info[
final_state_fwd_api_name] = ["" for i in range(num_args)] final_state_fwd_api_name] = ["" for i in range(num_args)]
for name, (ttype, pos) in forward_inputs_position_map.items(): for name, (ttype, pos) in forward_inputs_position_map.items():
core_ops_args_info[final_state_fwd_api_name][pos] = name core_ops_args_info[final_state_fwd_api_name][pos] = name
if IsPlainTensorType(ttype): if IsPlainTensorType(ttype):
...@@ -1104,7 +1104,9 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase): ...@@ -1104,7 +1104,9 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
##################### #####################
## Code Generation ## ## Code Generation ##
##################### #####################
self.GenerateForwardDefinition(is_inplaced=False)
# Definition And Declaration
self.GenerateForwardDefinitionAndDeclaration(is_inplaced=False)
self.UpdateCoreOpsInformation(is_inplaced=False) self.UpdateCoreOpsInformation(is_inplaced=False)
...@@ -1164,9 +1166,10 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase): ...@@ -1164,9 +1166,10 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
grad_api_contents = self.grad_api_contents grad_api_contents = self.grad_api_contents
next_grad_api_contents = self.next_grad_api_contents next_grad_api_contents = self.next_grad_api_contents
grad_node_creation_str = "" next_grad_node_creation_str = ""
grad_node_out_list = [] next_grad_node_out_list = []
if next_grad_api_contents: if next_grad_api_contents:
# Fake forward_api_contents and backward_api_contents
forward_api_contents = grad_api_contents forward_api_contents = grad_api_contents
forward_api_contents['api'] = forward_api_contents['backward_api'] forward_api_contents['api'] = forward_api_contents['backward_api']
backward_api_contents = next_grad_api_contents backward_api_contents = next_grad_api_contents
...@@ -1175,12 +1178,12 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase): ...@@ -1175,12 +1178,12 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
forward_api_contents, backward_api_contents, namespace) forward_api_contents, backward_api_contents, namespace)
next_node_generator.run() next_node_generator.run()
next_node_generator.GenerateNodeCreationCodes() next_node_generator.GenerateNodeCreationCodes()
grad_node_creation_str = next_node_generator.node_creation_str next_grad_node_creation_str = next_node_generator.node_creation_str
grad_node_out_list = next_node_generator.grad_node_out_list next_grad_node_out_list = next_node_generator.grad_node_out_list
self.RecordGrad2NextGradNameMapping(next_node_generator) self.RecordGrad2NextGradNameMapping(next_node_generator)
return grad_node_creation_str, grad_node_out_list return next_grad_node_creation_str, next_grad_node_out_list
def GenerateNodeDeclaration(self): def GenerateNodeDeclaration(self):
forward_op_name = self.forward_api_name forward_op_name = self.forward_api_name
...@@ -1188,7 +1191,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase): ...@@ -1188,7 +1191,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
backward_attrs_list = self.backward_attrs_list backward_attrs_list = self.backward_attrs_list
no_need_buffers = self.no_need_buffers no_need_buffers = self.no_need_buffers
# SetTensorWrapper Methods & TensorWrapper Members # SetTensorWrapper Methods & TensorWrapper Members & ClearTensorWrappers
set_tensor_wrapper_methods_str = "" set_tensor_wrapper_methods_str = ""
tensor_wrapper_members_str = "" tensor_wrapper_members_str = ""
clear_tensor_wrapper_str = "" clear_tensor_wrapper_str = ""
...@@ -1241,8 +1244,8 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase): ...@@ -1241,8 +1244,8 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
set_attribute_methods_str, tensor_wrapper_members_str, set_attribute_methods_str, tensor_wrapper_members_str,
attribute_members_str) attribute_members_str)
def GenerateNodeDefinition(self, grad_node_creation_str, def GenerateNodeDefinition(self, next_grad_node_creation_str,
grad_node_out_list): next_grad_node_out_list):
namespace = self.namespace namespace = self.namespace
forward_api_name = self.forward_api_name forward_api_name = self.forward_api_name
backward_api_name = self.backward_api_name backward_api_name = self.backward_api_name
...@@ -1362,14 +1365,14 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase): ...@@ -1362,14 +1365,14 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
inputs_autograd_meta_str = "" inputs_autograd_meta_str = ""
outputs_autograd_meta_str = "" outputs_autograd_meta_str = ""
compute_require_grad_str = "" compute_require_grad_str = ""
if len(grad_node_creation_str) > 0: if len(next_grad_node_creation_str) > 0:
# 1. Get Input AutoGradMeta # 1. Get Grad Input AutoGradMeta
inputs_autograd_meta_list = [] inputs_autograd_meta_list = []
compute_require_grad_args_list = ["trace_backward"] compute_require_grad_args_list = ["trace_backward"]
for name, (ttype, pos, for name, (ttype, pos,
grad_api_position) in backward_grad_inputs_map.items(): grad_api_position) in backward_grad_inputs_map.items():
transformed_tensor_name = self.TransformToNextGradName(name) transformed_tensor_name = self.TransformToNextGradName(name)
if transformed_tensor_name in grad_node_out_list: if transformed_tensor_name in next_grad_node_out_list:
input_autograd_meta_name = GetAutoGradMetaName( input_autograd_meta_name = GetAutoGradMetaName(
transformed_tensor_name) transformed_tensor_name)
if IsPlainTensorType(ttype): if IsPlainTensorType(ttype):
...@@ -1388,7 +1391,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase): ...@@ -1388,7 +1391,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
# 2. Get TensorWrapper AutoGradMeta # 2. Get TensorWrapper AutoGradMeta
for name, (ttype, _, pos), in backward_forward_inputs_map.items(): for name, (ttype, _, pos), in backward_forward_inputs_map.items():
transformed_tensor_name = self.TransformToNextGradName(name) transformed_tensor_name = self.TransformToNextGradName(name)
if transformed_tensor_name in grad_node_out_list: if transformed_tensor_name in next_grad_node_out_list:
input_autograd_meta_name = GetAutoGradMetaName( input_autograd_meta_name = GetAutoGradMetaName(
transformed_tensor_name) transformed_tensor_name)
if IsPlainTensorType(ttype): if IsPlainTensorType(ttype):
...@@ -1447,7 +1450,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase): ...@@ -1447,7 +1450,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
grad_node_name, fill_zero_str, get_grad_in_args_str, grad_node_name, grad_node_name, fill_zero_str, get_grad_in_args_str, grad_node_name,
grad_function_call_str, check_nan_inf_str, inputs_autograd_meta_str, grad_function_call_str, check_nan_inf_str, inputs_autograd_meta_str,
outputs_autograd_meta_str, compute_require_grad_str, outputs_autograd_meta_str, compute_require_grad_str,
grad_node_creation_str, returns_str) next_grad_node_creation_str, returns_str)
def run(self): def run(self):
super().run() super().run()
...@@ -1458,27 +1461,29 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase): ...@@ -1458,27 +1461,29 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
## Code Generation ## ## Code Generation ##
##################### #####################
# Higher-order GradNode generation # Higher-order GradNode generation
grad_node_creation_str, grad_node_out_list = self.GenerateHigherOrderNodeCreationCode( next_grad_node_creation_str, next_grad_node_out_list = self.GenerateHigherOrderNodeCreationCode(
) )
self.GenerateNodeDeclaration() self.GenerateNodeDeclaration()
self.GenerateNodeDefinition(grad_node_creation_str, grad_node_out_list) self.GenerateNodeDefinition(next_grad_node_creation_str,
next_grad_node_out_list)
class DygraphYamlGenerator(YamlGeneratorBase): class DygraphForwardAndNodesGenerator(GeneratorBase):
def __init__(self, api_yaml_path, backward_yaml_path): def __init__(self, api_yaml_path, backward_yaml_path):
# Parent members: # Parent members:
# self.namespace # self.namespace
# self.api_yaml_path # self.api_yaml_path
# self.forward_api_list # self.forward_api_list
YamlGeneratorBase.__init__(self, api_yaml_path) GeneratorBase.__init__(self, api_yaml_path)
self.backward_yaml_path = backward_yaml_path self.backward_yaml_path = backward_yaml_path
self.grad_api_dict = {} self.grad_api_dict = {}
self.forward_definition_str = ""
self.forward_declaration_str = "" self.forward_declaration_str = ""
self.forward_definition_str = ""
self.node_declaration_str = "" self.node_declaration_str = ""
self.node_definition_str = "" self.node_definition_str = ""
...@@ -1518,6 +1523,7 @@ class DygraphYamlGenerator(YamlGeneratorBase): ...@@ -1518,6 +1523,7 @@ class DygraphYamlGenerator(YamlGeneratorBase):
self.forward_definition_str += function_generator.forward_definition_str + "\n" self.forward_definition_str += function_generator.forward_definition_str + "\n"
self.forward_declaration_str += function_generator.forward_declaration_str + "\n" self.forward_declaration_str += function_generator.forward_declaration_str + "\n"
# Generate Dygraph GradNode Function
while True: while True:
next_grad_api_contents = self.GetBackwardAPIContents( next_grad_api_contents = self.GetBackwardAPIContents(
backward_api_contents) backward_api_contents)
...@@ -1611,20 +1617,23 @@ if __name__ == "__main__": ...@@ -1611,20 +1617,23 @@ if __name__ == "__main__":
# Generate per Dygraph API # Generate per Dygraph API
node_declaration_str = "" node_declaration_str = ""
node_definition_str = "" node_definition_str = ""
forward_definition_str = ""
forward_declaration_str = "" forward_declaration_str = ""
forward_definition_str = ""
for i in range(len(api_yaml_paths)): for i in range(len(api_yaml_paths)):
api_yaml_path = api_yaml_paths[i] api_yaml_path = api_yaml_paths[i]
backward_yaml_path = backward_yaml_paths[i] backward_yaml_path = backward_yaml_paths[i]
generator = DygraphYamlGenerator(api_yaml_path, backward_yaml_path) generator = DygraphForwardAndNodesGenerator(api_yaml_path,
backward_yaml_path)
generator.run() generator.run()
node_declaration_str += generator.node_declaration_str + "\n" node_declaration_str += generator.node_declaration_str + "\n"
node_definition_str += generator.node_definition_str + "\n" node_definition_str += generator.node_definition_str + "\n"
forward_definition_str += generator.forward_definition_str + "\n"
forward_declaration_str += generator.forward_declaration_str + "\n" forward_declaration_str += generator.forward_declaration_str + "\n"
forward_definition_str += generator.forward_definition_str + "\n"
# Generate Files # Generate Files
nodes_h_path = args.nodes_h_path nodes_h_path = args.nodes_h_path
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
import os import os
import argparse import argparse
import logging import logging
from codegen_utils import FunctionGeneratorBase, YamlGeneratorBase from codegen_utils import FunctionGeneratorBase, GeneratorBase
from codegen_utils import yaml_types_mapping from codegen_utils import yaml_types_mapping
from codegen_utils import ReadFwdFile, IsVectorTensorType, GetForwardFunctionName from codegen_utils import ReadFwdFile, IsVectorTensorType, GetForwardFunctionName
from codegen_utils import ParseYamlForward, GetInplacedFunctionName from codegen_utils import ParseYamlForward, GetInplacedFunctionName
...@@ -100,6 +100,7 @@ static PyObject * eager_final_state_api_{}(PyObject *self, PyObject *args, PyObj ...@@ -100,6 +100,7 @@ static PyObject * eager_final_state_api_{}(PyObject *self, PyObject *args, PyObj
// Set Device ID // Set Device ID
{} {}
// Call dygraph function
decltype({}({})) out = {}({}); decltype({}({})) out = {}({});
PyEval_RestoreThread(tstate); PyEval_RestoreThread(tstate);
...@@ -341,6 +342,8 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase): ...@@ -341,6 +342,8 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase):
# Generate Record Event for performance profiling # Generate Record Event for performance profiling
pythonc_record_event_str = RECORD_EVENT_TEMPLATE.format( pythonc_record_event_str = RECORD_EVENT_TEMPLATE.format(
"pythonc_record_event", forward_api_name, "pybind_imperative_func") "pythonc_record_event", forward_api_name, "pybind_imperative_func")
# Generate Python-C Function Definetion
self.python_c_function_str = PYTHON_C_FUNCTION_TEMPLATE.format( self.python_c_function_str = PYTHON_C_FUNCTION_TEMPLATE.format(
forward_api_name, pythonc_record_event_str, forward_api_name, forward_api_name, pythonc_record_event_str, forward_api_name,
get_eager_tensor_str, parse_attributes_str, set_device_str, get_eager_tensor_str, parse_attributes_str, set_device_str,
...@@ -350,6 +353,7 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase): ...@@ -350,6 +353,7 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase):
# Set prefix of forward_api_name to avoid conflicts # Set prefix of forward_api_name to avoid conflicts
prefix = self.namespace.strip("::") prefix = self.namespace.strip("::")
forward_api_name_prefix = "" if prefix == "" else prefix + "_" forward_api_name_prefix = "" if prefix == "" else prefix + "_"
# Generate Python-C Function Registration # Generate Python-C Function Registration
self.python_c_function_reg_str = PYTHON_C_FUNCTION_REG_TEMPLATE.format( self.python_c_function_reg_str = PYTHON_C_FUNCTION_REG_TEMPLATE.format(
forward_api_name_prefix, forward_api_name, namespace, forward_api_name_prefix, forward_api_name, namespace,
...@@ -376,6 +380,7 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase): ...@@ -376,6 +380,7 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase):
inplaced_forward_api_name, inplace_output) inplaced_forward_api_name, inplace_output)
break break
# Generate Python-C Function Definetion
self.python_c_function_str += PYTHON_C_FUNCTION_TEMPLATE.format( self.python_c_function_str += PYTHON_C_FUNCTION_TEMPLATE.format(
inplaced_forward_api_name, pythonc_record_event_str, inplaced_forward_api_name, pythonc_record_event_str,
inplaced_forward_api_name, get_eager_tensor_str, inplaced_forward_api_name, get_eager_tensor_str,
...@@ -414,17 +419,17 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase): ...@@ -414,17 +419,17 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase):
return True return True
class PythonCYamlGenerator(YamlGeneratorBase): class PythonCGenerator(GeneratorBase):
def __init__(self, path): def __init__(self, path):
# Parent members: # Parent members:
# self.namespace # self.namespace
# self.api_yaml_path # self.api_yaml_path
# self.forward_api_list # self.forward_api_list
YamlGeneratorBase.__init__(self, api_yaml_path) GeneratorBase.__init__(self, api_yaml_path)
# Generated Result # Generated Result
self.python_c_functions_reg_str = ""
self.python_c_functions_str = "" self.python_c_functions_str = ""
self.python_c_functions_reg_str = ""
def GeneratePythonCFunctions(self): def GeneratePythonCFunctions(self):
namespace = self.namespace namespace = self.namespace
...@@ -436,8 +441,8 @@ class PythonCYamlGenerator(YamlGeneratorBase): ...@@ -436,8 +441,8 @@ class PythonCYamlGenerator(YamlGeneratorBase):
status = f_generator.run() status = f_generator.run()
if status == True: if status == True:
self.python_c_functions_reg_str += f_generator.python_c_function_reg_str + ",\n"
self.python_c_functions_str += f_generator.python_c_function_str + "\n" self.python_c_functions_str += f_generator.python_c_function_str + "\n"
self.python_c_functions_reg_str += f_generator.python_c_function_reg_str + ",\n"
def AttachNamespace(self): def AttachNamespace(self):
namespace = self.namespace namespace = self.namespace
...@@ -509,11 +514,11 @@ if __name__ == "__main__": ...@@ -509,11 +514,11 @@ if __name__ == "__main__":
for i in range(len(api_yaml_paths)): for i in range(len(api_yaml_paths)):
api_yaml_path = api_yaml_paths[i] api_yaml_path = api_yaml_paths[i]
y_generator = PythonCYamlGenerator(api_yaml_path) py_c_generator = PythonCGenerator(api_yaml_path)
y_generator.run() py_c_generator.run()
generated_python_c_functions += y_generator.python_c_functions_str + "\n" generated_python_c_functions += py_c_generator.python_c_functions_str + "\n"
generated_python_c_registration += y_generator.python_c_functions_reg_str + "\n" generated_python_c_registration += py_c_generator.python_c_functions_reg_str + "\n"
python_c_str = GeneratePythonCWrappers(generated_python_c_functions, python_c_str = GeneratePythonCWrappers(generated_python_c_functions,
generated_python_c_registration) generated_python_c_registration)
......
...@@ -434,7 +434,7 @@ PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_d ...@@ -434,7 +434,7 @@ PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_d
vars_list = kernel['data_type'].split(',') vars_list = kernel['data_type'].split(',')
assert len( assert len(
vars_list vars_list
) == 1, f"{api} api: The number of params to set data_type only allows 2, but received {len(vars_list)}." ) == 1, f"{api} api: The number of params to set data_type only allows 1, but received {len(vars_list)}."
kernel_select_code = kernel_select_code + f""" kernel_select_code = kernel_select_code + f"""
kernel_data_type = ParseDataType({vars_list[0].strip()}); kernel_data_type = ParseDataType({vars_list[0].strip()});
""" """
...@@ -837,10 +837,10 @@ PADDLE_API {self.get_return_type()} {self.api}({params_code}) {{ ...@@ -837,10 +837,10 @@ PADDLE_API {self.get_return_type()} {self.api}({params_code}) {{
return api_code return api_code
else: else:
inveke_func_name = self.invoke.split('(')[0].strip() invoke_func_name = self.invoke.split('(')[0].strip()
if inveke_func_name in self.attrs['names']: if invoke_func_name in self.attrs['names']:
# Adjust the param whose name is same with api invoked. # Adjust the param whose name is same with api invoked.
pattern = r'\W' + inveke_func_name + '[^A-Za-z0-9_(]' pattern = r'\W' + invoke_func_name + '[^A-Za-z0-9_(]'
def adjust_name(matched): def adjust_name(matched):
matched_str = matched.group() matched_str = matched.group()
......
...@@ -172,8 +172,8 @@ class BackwardAPI(BaseAPI): ...@@ -172,8 +172,8 @@ class BackwardAPI(BaseAPI):
return kernel_output, output_names, output_create return kernel_output, output_names, output_create
def gene_invoke_code(self, invoke_code, params_code): def gene_invoke_code(self, invoke_code, params_code):
inveke_func_name = invoke_code.split('(')[0].strip() invoke_func_name = invoke_code.split('(')[0].strip()
if inveke_func_name.endswith('_grad') or inveke_func_name.endswith( if invoke_func_name.endswith('_grad') or invoke_func_name.endswith(
'_grad_impl'): '_grad_impl'):
return f""" return f"""
PADDLE_API {self.get_return_type()} {self.api}({params_code}) {{ PADDLE_API {self.get_return_type()} {self.api}({params_code}) {{
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册