Unverified commit b9342a80, authored by Weilong Wu and committed by GitHub

[Eager] Polish eager code generation (#42822)

* [Eager] Polish eager code generation

* Remove useless code in codegen
Parent 570d0322
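The hunks below lean on one simple pattern: the generators build C++ source by filling Python template strings with str.format. A minimal sketch of that pattern, using the two templates that this commit regroups; the member name "x_" is an illustrative assumption, the real names come from the parsed YAML specs:

PLAIN_TENSOR_MEMBER_TEMPLATE = \
"""  egr::TensorWrapper {};
"""

CLEAR_TENSOR_WRAPPER_TEMPLATE = \
"""  {}.clear();
"""

# Hypothetical TensorWrapper member name for illustration only.
tensor_wrapper_name = "x_"
member_decl = PLAIN_TENSOR_MEMBER_TEMPLATE.format(tensor_wrapper_name)  # "  egr::TensorWrapper x_;\n"
clear_stmt = CLEAR_TENSOR_WRAPPER_TEMPLATE.format(tensor_wrapper_name)  # "  x_.clear();\n"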
......@@ -418,7 +418,7 @@ class FunctionGeneratorBase:
return_name] = [return_type, return_pos]
class YamlGeneratorBase:
class GeneratorBase:
def __init__(self, api_yaml_path):
self.namespace = ""
self.api_yaml_path = api_yaml_path
......
......@@ -29,7 +29,7 @@ from codegen_utils import RemoveSpecialSymbolsInName, RecoverBaseNameOfInplaceFu
from codegen_utils import GetInplacedFunctionName
from codegen_utils import ParseYamlArgs, ParseYamlReturns, ParseYamlForwardFromBackward
from codegen_utils import ParseYamlForward, ParseYamlBackward
from codegen_utils import FunctionGeneratorBase, YamlGeneratorBase
from codegen_utils import FunctionGeneratorBase, GeneratorBase
from codegen_utils import ops_to_fill_zero_for_empty_grads
from codegen_utils import AssertMessage, GetIndent
......@@ -60,14 +60,6 @@ SET_PLAIN_TENSOR_WRAPPER_TEMPLATE = \
}}
"""
PLAIN_TENSOR_MEMBER_TEMPLATE = \
""" egr::TensorWrapper {};
"""
CLEAR_TENSOR_WRAPPER_TEMPLATE = \
""" {}.clear();
"""
SET_VECTOR_TENSOR_WRAPPER_TEMPLATE = \
""" void SetTensorWrapper{}(const std::vector<paddle::experimental::Tensor>& {}) {{
for(const auto& eager_tensor : {}) {{
......@@ -76,10 +68,18 @@ SET_VECTOR_TENSOR_WRAPPER_TEMPLATE = \
}}
"""
PLAIN_TENSOR_MEMBER_TEMPLATE = \
""" egr::TensorWrapper {};
"""
VECTOR_TENSOR_MEMBER_TEMPLATE = \
""" std::vector<egr::TensorWrapper> {};
"""
CLEAR_TENSOR_WRAPPER_TEMPLATE = \
""" {}.clear();
"""
CLEAR_VECTOR_TENSOR_WRAPPERS_TEMPLATE = \
""" for (auto& tw : {}) {{
tw.clear();
......@@ -423,9 +423,9 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
self.forward_returns_list = [
] #[ [ret_name, ret_type, orig_position], ...]
self.backward_inputs_list = [
] #[ [attr_name, attr_type, default_value, orig_position], ...]
self.backward_attrs_list = [
] #[ [attr_name, attr_type, default_value, orig_position], ...]
self.backward_inputs_list = [
] #[ [arg_name, arg_type, orig_position], ...]
self.backward_returns_list = [
] #[ [ret_name, ret_type, orig_position], ...]
......@@ -504,11 +504,9 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
for _, _, pos in forward_inputs_list:
max_input_position = max(max_input_position, pos)
max_attr_position = -1
for _, _, _, pos in forward_attrs_list:
assert pos > max_input_position, AssertMessage(pos,
max_input_position)
max_attr_position = max(max_attr_position, pos)
def BackwardValidationCheck(self):
backward_forward_inputs_map = self.backward_forward_inputs_map
......@@ -692,12 +690,11 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
else:
set_tensor_wrappers = f"{indent}grad_node->SetTensorWrapper{name}({name});"
set_input_tensor_wrappers_list.append(set_tensor_wrappers)
else:
else: # Forward's output as backward's input
if num_fwd_outputs > 1:
# Aligned with forward output position
assert name in forward_outputs_position_map.keys(
), AssertMessage(name, forward_outputs_position_map.keys())
fwd_output_pos = forward_outputs_position_map[name][1]
if is_optional:
set_tensor_wrappers = f"{indent}if({name}.get_ptr() != nullptr) grad_node->SetTensorWrapper{name}(*({name}.get_ptr()));"
......@@ -733,7 +730,7 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
set_grad_out_meta_list.append(set_grad_out_meta)
set_grad_out_meta_str = "\n".join(set_grad_out_meta_list)
# SetOutRank & SetHistory & SetGradInMeta
# SetOutRank & SetHistory & SetGradInMeta & CheckAndRetainGrad
set_out_rank_list = []
set_history_list = []
set_grad_in_meta_list = []
......@@ -741,11 +738,12 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
num_outputs = len(forward_outputs_position_map.keys())
for name, (_, pos) in forward_outputs_position_map.items():
output_autograd_meta_name = GetAutoGradMetaName(name)
set_out_rank = f"{indent}egr::EagerUtils::SetOutRankWithSlot({output_autograd_meta_name}, {pos});"
set_history = f"{indent}egr::EagerUtils::SetHistory({output_autograd_meta_name}, grad_node);"
set_retain_grad = f"{indent}egr::EagerUtils::CheckAndRetainGrad({name});"
set_grad_in_meta = f"{indent}grad_node->SetGradInMeta({name}, {pos});"
set_retain_grad = f"{indent}egr::EagerUtils::CheckAndRetainGrad({name});"
set_out_rank_list.append(set_out_rank)
set_history_list.append(set_history)
set_grad_in_meta_list.append(set_grad_in_meta)
......@@ -806,7 +804,7 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
self.DetermineForwardPositionMap(self.forward_inputs_list,
self.forward_returns_list)
# Initialize forward_inputs_position_map, forward_outputs_position_map
# Initialize backward_forward_inputs_map, backward_grad_inputs_map, backward_grad_outputs_map
self.SlotNameMatching()
# Backward Validation Check
......@@ -822,18 +820,16 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
self.forward_definition_str = ""
self.forward_declaration_str = ""
def GenerateForwardDefinition(self, is_inplaced):
def GenerateForwardDefinitionAndDeclaration(self, is_inplaced):
namespace = self.namespace
forward_api_name = GetInplacedFunctionName(
self.forward_api_name) if is_inplaced else self.forward_api_name
backward_api_name = self.backward_api_name
forward_inputs_position_map = self.forward_inputs_position_map
forward_outputs_position_map = self.forward_outputs_position_map
forward_attrs_list = self.forward_attrs_list
backward_forward_inputs_map = self.backward_forward_inputs_map
backward_grad_inputs_map = self.backward_grad_inputs_map
backward_grad_outputs_map = self.backward_grad_outputs_map
backward_attrs_list = self.backward_attrs_list
optional_inputs = self.optional_inputs
intermediate_outputs = self.intermediate_outputs
inplace_map = self.inplace_map if is_inplaced else {}
......@@ -845,6 +841,7 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
inputs_args_definition_list = ["" for i in range(num_inputs)]
inputs_args_declaration_list = ["" for i in range(num_inputs)]
inputs_call_list = ["" for i in range(num_inputs)]
amp_inputs_call_list = ["" for i in range(num_inputs)]
amp_tensors_vector_list = []
amp_tensors_vector_optional_list = []
......@@ -1019,9 +1016,10 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
bump_inplace_version_str += BUMP_INPLACE_VERSION_TEMPLATE.format(
inplace_name, inplace_name)
# Node Creation
self.GenerateNodeCreationCodes()
node_creation_str = self.node_creation_str
dygraph_event_str = f"{indent}paddle::platform::RecordEvent dygraph_entrance_record_event(\"{forward_api_name} dygraph\", paddle::platform::TracerEventType::Operator, 1);\n"
forward_function_name = GetDygraphForwardFunctionName(forward_api_name)
......@@ -1045,6 +1043,7 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
amp_tensors_vector_optional_list_str, amp_get_dst_dtype_str,
amp_autocast_list_str, amp_call_str)
# Generate forward_definition_str and forward_declaration_str
self.forward_definition_str += FORWARD_FUNCTION_TEMPLATE.format(
returns_type_str, forward_function_name, inputs_args_definition_str,
dygraph_event_str, amp_logic_str, inputs_autograd_meta_str,
......@@ -1061,8 +1060,8 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
if forward_api_name != "sum" and "inplace" in forward_api_contents.keys(
):
# Node Definition Generation
self.GenerateForwardDefinition(is_inplaced=True)
# Function Definition and Declaration Generation
self.GenerateForwardDefinitionAndDeclaration(is_inplaced=True)
self.UpdateCoreOpsInformation(is_inplaced=True)
def UpdateCoreOpsInformation(self, is_inplaced):
......@@ -1083,6 +1082,7 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
final_state_fwd_api_name] = ["" for i in range(num_args)]
core_ops_args_type_info[
final_state_fwd_api_name] = ["" for i in range(num_args)]
for name, (ttype, pos) in forward_inputs_position_map.items():
core_ops_args_info[final_state_fwd_api_name][pos] = name
if IsPlainTensorType(ttype):
......@@ -1104,7 +1104,9 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
#####################
## Code Generation ##
#####################
self.GenerateForwardDefinition(is_inplaced=False)
# Definition And Declaration
self.GenerateForwardDefinitionAndDeclaration(is_inplaced=False)
self.UpdateCoreOpsInformation(is_inplaced=False)
......@@ -1164,9 +1166,10 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
grad_api_contents = self.grad_api_contents
next_grad_api_contents = self.next_grad_api_contents
grad_node_creation_str = ""
grad_node_out_list = []
next_grad_node_creation_str = ""
next_grad_node_out_list = []
if next_grad_api_contents:
# Fake forward_api_contents and backward_api_contents
forward_api_contents = grad_api_contents
forward_api_contents['api'] = forward_api_contents['backward_api']
backward_api_contents = next_grad_api_contents
......@@ -1175,12 +1178,12 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
forward_api_contents, backward_api_contents, namespace)
next_node_generator.run()
next_node_generator.GenerateNodeCreationCodes()
grad_node_creation_str = next_node_generator.node_creation_str
grad_node_out_list = next_node_generator.grad_node_out_list
next_grad_node_creation_str = next_node_generator.node_creation_str
next_grad_node_out_list = next_node_generator.grad_node_out_list
self.RecordGrad2NextGradNameMapping(next_node_generator)
return grad_node_creation_str, grad_node_out_list
return next_grad_node_creation_str, next_grad_node_out_list
def GenerateNodeDeclaration(self):
forward_op_name = self.forward_api_name
......@@ -1188,7 +1191,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
backward_attrs_list = self.backward_attrs_list
no_need_buffers = self.no_need_buffers
# SetTensorWrapper Methods & TensorWrapper Members
# SetTensorWrapper Methods & TensorWrapper Members & ClearTensorWrappers
set_tensor_wrapper_methods_str = ""
tensor_wrapper_members_str = ""
clear_tensor_wrapper_str = ""
......@@ -1241,8 +1244,8 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
set_attribute_methods_str, tensor_wrapper_members_str,
attribute_members_str)
def GenerateNodeDefinition(self, grad_node_creation_str,
grad_node_out_list):
def GenerateNodeDefinition(self, next_grad_node_creation_str,
next_grad_node_out_list):
namespace = self.namespace
forward_api_name = self.forward_api_name
backward_api_name = self.backward_api_name
......@@ -1362,14 +1365,14 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
inputs_autograd_meta_str = ""
outputs_autograd_meta_str = ""
compute_require_grad_str = ""
if len(grad_node_creation_str) > 0:
# 1. Get Input AutoGradMeta
if len(next_grad_node_creation_str) > 0:
# 1. Get Grad Input AutoGradMeta
inputs_autograd_meta_list = []
compute_require_grad_args_list = ["trace_backward"]
for name, (ttype, pos,
grad_api_position) in backward_grad_inputs_map.items():
transformed_tensor_name = self.TransformToNextGradName(name)
if transformed_tensor_name in grad_node_out_list:
if transformed_tensor_name in next_grad_node_out_list:
input_autograd_meta_name = GetAutoGradMetaName(
transformed_tensor_name)
if IsPlainTensorType(ttype):
......@@ -1388,7 +1391,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
# 2. Get TensorWrapper AutoGradMeta
for name, (ttype, _, pos), in backward_forward_inputs_map.items():
transformed_tensor_name = self.TransformToNextGradName(name)
if transformed_tensor_name in grad_node_out_list:
if transformed_tensor_name in next_grad_node_out_list:
input_autograd_meta_name = GetAutoGradMetaName(
transformed_tensor_name)
if IsPlainTensorType(ttype):
......@@ -1447,7 +1450,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
grad_node_name, fill_zero_str, get_grad_in_args_str, grad_node_name,
grad_function_call_str, check_nan_inf_str, inputs_autograd_meta_str,
outputs_autograd_meta_str, compute_require_grad_str,
grad_node_creation_str, returns_str)
next_grad_node_creation_str, returns_str)
def run(self):
super().run()
......@@ -1458,27 +1461,29 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
## Code Generation ##
#####################
# Higher-order GradNode generation
grad_node_creation_str, grad_node_out_list = self.GenerateHigherOrderNodeCreationCode(
next_grad_node_creation_str, next_grad_node_out_list = self.GenerateHigherOrderNodeCreationCode(
)
self.GenerateNodeDeclaration()
self.GenerateNodeDefinition(grad_node_creation_str, grad_node_out_list)
self.GenerateNodeDefinition(next_grad_node_creation_str,
next_grad_node_out_list)
class DygraphYamlGenerator(YamlGeneratorBase):
class DygraphForwardAndNodesGenerator(GeneratorBase):
def __init__(self, api_yaml_path, backward_yaml_path):
# Parent members:
# self.namespace
# self.api_yaml_path
# self.forward_api_list
YamlGeneratorBase.__init__(self, api_yaml_path)
GeneratorBase.__init__(self, api_yaml_path)
self.backward_yaml_path = backward_yaml_path
self.grad_api_dict = {}
self.forward_definition_str = ""
self.forward_declaration_str = ""
self.forward_definition_str = ""
self.node_declaration_str = ""
self.node_definition_str = ""
......@@ -1518,6 +1523,7 @@ class DygraphYamlGenerator(YamlGeneratorBase):
self.forward_definition_str += function_generator.forward_definition_str + "\n"
self.forward_declaration_str += function_generator.forward_declaration_str + "\n"
# Generate Dygraph GradNode Function
while True:
next_grad_api_contents = self.GetBackwardAPIContents(
backward_api_contents)
......@@ -1611,20 +1617,23 @@ if __name__ == "__main__":
# Generate per Dygraph API
node_declaration_str = ""
node_definition_str = ""
forward_definition_str = ""
forward_declaration_str = ""
forward_definition_str = ""
for i in range(len(api_yaml_paths)):
api_yaml_path = api_yaml_paths[i]
backward_yaml_path = backward_yaml_paths[i]
generator = DygraphYamlGenerator(api_yaml_path, backward_yaml_path)
generator = DygraphForwardAndNodesGenerator(api_yaml_path,
backward_yaml_path)
generator.run()
node_declaration_str += generator.node_declaration_str + "\n"
node_definition_str += generator.node_definition_str + "\n"
forward_definition_str += generator.forward_definition_str + "\n"
forward_declaration_str += generator.forward_declaration_str + "\n"
forward_definition_str += generator.forward_definition_str + "\n"
# Generate Files
nodes_h_path = args.nodes_h_path
......
......@@ -15,7 +15,7 @@
import os
import argparse
import logging
from codegen_utils import FunctionGeneratorBase, YamlGeneratorBase
from codegen_utils import FunctionGeneratorBase, GeneratorBase
from codegen_utils import yaml_types_mapping
from codegen_utils import ReadFwdFile, IsVectorTensorType, GetForwardFunctionName
from codegen_utils import ParseYamlForward, GetInplacedFunctionName
......@@ -100,6 +100,7 @@ static PyObject * eager_final_state_api_{}(PyObject *self, PyObject *args, PyObj
// Set Device ID
{}
// Call dygraph function
decltype({}({})) out = {}({});
PyEval_RestoreThread(tstate);
......@@ -341,6 +342,8 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase):
# Generate Record Event for performance profiling
pythonc_record_event_str = RECORD_EVENT_TEMPLATE.format(
"pythonc_record_event", forward_api_name, "pybind_imperative_func")
# Generate Python-C Function Definition
self.python_c_function_str = PYTHON_C_FUNCTION_TEMPLATE.format(
forward_api_name, pythonc_record_event_str, forward_api_name,
get_eager_tensor_str, parse_attributes_str, set_device_str,
......@@ -350,6 +353,7 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase):
# Set prefix of forward_api_name to avoid conflicts
prefix = self.namespace.strip("::")
forward_api_name_prefix = "" if prefix == "" else prefix + "_"
# Generate Python-C Function Registration
self.python_c_function_reg_str = PYTHON_C_FUNCTION_REG_TEMPLATE.format(
forward_api_name_prefix, forward_api_name, namespace,
......@@ -376,6 +380,7 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase):
inplaced_forward_api_name, inplace_output)
break
# Generate Python-C Function Definition
self.python_c_function_str += PYTHON_C_FUNCTION_TEMPLATE.format(
inplaced_forward_api_name, pythonc_record_event_str,
inplaced_forward_api_name, get_eager_tensor_str,
......@@ -414,17 +419,17 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase):
return True
class PythonCYamlGenerator(YamlGeneratorBase):
class PythonCGenerator(GeneratorBase):
def __init__(self, path):
# Parent members:
# self.namespace
# self.api_yaml_path
# self.forward_api_list
YamlGeneratorBase.__init__(self, api_yaml_path)
GeneratorBase.__init__(self, api_yaml_path)
# Generated Result
self.python_c_functions_reg_str = ""
self.python_c_functions_str = ""
self.python_c_functions_reg_str = ""
def GeneratePythonCFunctions(self):
namespace = self.namespace
......@@ -436,8 +441,8 @@ class PythonCYamlGenerator(YamlGeneratorBase):
status = f_generator.run()
if status == True:
self.python_c_functions_reg_str += f_generator.python_c_function_reg_str + ",\n"
self.python_c_functions_str += f_generator.python_c_function_str + "\n"
self.python_c_functions_reg_str += f_generator.python_c_function_reg_str + ",\n"
def AttachNamespace(self):
namespace = self.namespace
......@@ -509,11 +514,11 @@ if __name__ == "__main__":
for i in range(len(api_yaml_paths)):
api_yaml_path = api_yaml_paths[i]
y_generator = PythonCYamlGenerator(api_yaml_path)
y_generator.run()
py_c_generator = PythonCGenerator(api_yaml_path)
py_c_generator.run()
generated_python_c_functions += y_generator.python_c_functions_str + "\n"
generated_python_c_registration += y_generator.python_c_functions_reg_str + "\n"
generated_python_c_functions += py_c_generator.python_c_functions_str + "\n"
generated_python_c_registration += py_c_generator.python_c_functions_reg_str + "\n"
python_c_str = GeneratePythonCWrappers(generated_python_c_functions,
generated_python_c_registration)
......
......@@ -434,7 +434,7 @@ PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_d
vars_list = kernel['data_type'].split(',')
assert len(
vars_list
) == 1, f"{api} api: The number of params to set data_type only allows 2, but received {len(vars_list)}."
) == 1, f"{api} api: The number of params to set data_type only allows 1, but received {len(vars_list)}."
kernel_select_code = kernel_select_code + f"""
kernel_data_type = ParseDataType({vars_list[0].strip()});
"""
......@@ -837,10 +837,10 @@ PADDLE_API {self.get_return_type()} {self.api}({params_code}) {{
return api_code
else:
inveke_func_name = self.invoke.split('(')[0].strip()
if inveke_func_name in self.attrs['names']:
invoke_func_name = self.invoke.split('(')[0].strip()
if invoke_func_name in self.attrs['names']:
# Adjust the param whose name is same with api invoked.
pattern = r'\W' + inveke_func_name + '[^A-Za-z0-9_(]'
pattern = r'\W' + invoke_func_name + '[^A-Za-z0-9_(]'
def adjust_name(matched):
matched_str = matched.group()
......
......@@ -172,8 +172,8 @@ class BackwardAPI(BaseAPI):
return kernel_output, output_names, output_create
def gene_invoke_code(self, invoke_code, params_code):
inveke_func_name = invoke_code.split('(')[0].strip()
if inveke_func_name.endswith('_grad') or inveke_func_name.endswith(
invoke_func_name = invoke_code.split('(')[0].strip()
if invoke_func_name.endswith('_grad') or invoke_func_name.endswith(
'_grad_impl'):
return f"""
PADDLE_API {self.get_return_type()} {self.api}({params_code}) {{
......
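For context on the invoke_func_name handling touched in the last two hunks: when an invoked API shares its name with one of its own attributes, the generator rewrites the attribute occurrences through a regex so the function call itself is left untouched. A minimal, self-contained sketch of that behavior; the sample invoke string, the attribute list, and the adjust_name replacement body are illustrative assumptions, not code from this commit:

import re

invoke = "scale(x, scale, 0.0, true)"   # hypothetical invoke entry; an attr is also named "scale"
attr_names = ["scale"]                   # hypothetical parsed attribute names
invoke_func_name = invoke.split('(')[0].strip()

if invoke_func_name in attr_names:
    # Matches "scale" only when used as an argument (preceded by a non-word char
    # and not followed by "("), so the leading "scale(" call is never rewritten.
    pattern = r'\W' + invoke_func_name + '[^A-Za-z0-9_(]'

    def adjust_name(matched):
        matched_str = matched.group()
        # Illustrative rename; the actual replacement in Paddle's codegen may differ.
        return matched_str.replace(invoke_func_name, invoke_func_name + '_val')

    invoke = re.sub(pattern, adjust_name, invoke)
    # invoke is now "scale(x, scale_val, 0.0, true)"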