提交 e3b208ce 编写于 作者: J jim19930609

Supported intermediate outputs for eager final state codegen

上级 b29c05c7
......@@ -127,6 +127,10 @@ def ReadBwdFile(filepath):
######################
### Yaml Parsers ###
######################
def ParseIntermediate(string):
return [v.strip() for v in string.split(",")]
def ParseYamlArgs(string):
# Example: const Tensor& x, const Tensor& y, bool transpose_x, bool transpose_y
......@@ -728,11 +732,11 @@ def GenerateNodeCreationCodes(fwd_api_name, bwd_api_name,
return node_creation_str
def GenerateForwardDefinition(fwd_api_name, bwd_api_name,
forward_inputs_position_map,
forward_outputs_position_map, forward_attrs_list,
backward_fwd_input_map, backward_grad_input_map,
backward_grad_output_map, backward_attrs_list):
def GenerateForwardDefinition(
fwd_api_name, bwd_api_name, forward_inputs_position_map,
forward_outputs_position_map, forward_attrs_list,
backward_fwd_input_map, backward_grad_input_map,
backward_grad_output_map, backward_attrs_list, intermediate_outputs):
# fwd_api_name = ""
# forward_inputs_position_map = { "name" : [type, fwd_position] }
# forward_outputs_position_map = { "name" : [type, fwd_position] }
......@@ -776,13 +780,20 @@ def GenerateForwardDefinition(fwd_api_name, bwd_api_name,
inputs_call_args_str = ", ".join(inputs_call_list)
# Forward Full Logic
forward_call_str = f"auto api_result = paddle::experimental::{fwd_api_name}({inputs_call_args_str});"
if len(intermediate_outputs) == 0:
function_name = fwd_api_name
else:
function_name = fwd_api_name + "_intermediate"
forward_call_str = f"auto api_result = paddle::experimental::{function_name}({inputs_call_args_str});"
# Get return type list & outputs
num_outputs = len(forward_outputs_position_map.keys())
num_outputs = len(forward_outputs_position_map.keys()) - len(
intermediate_outputs)
returns_type_list = ["" for i in range(num_outputs)]
returns_list = ["" for i in range(num_outputs)]
for name, (rtype, pos) in forward_outputs_position_map.items():
if name in intermediate_outputs:
continue
if num_outputs == 1:
returns_list[0] = f"api_result"
else:
......@@ -1001,6 +1012,10 @@ if __name__ == "__main__":
fwd_args_str = fwd_api['args']
fwd_returns_str = fwd_api['output']
intermediate_outputs = []
if 'intermediate' in fwd_api.keys():
intermediate_outputs = ParseIntermediate(fwd_api['intermediate'])
bwd_api_name = fwd_api['backward']
assert bwd_api_name in grad_api_dict.keys()
bwd_api = grad_api_dict[bwd_api_name]
......@@ -1076,7 +1091,7 @@ if __name__ == "__main__":
fwd_api_name, bwd_api_name, forward_inputs_position_map,
forward_outputs_position_map, forward_attrs_list,
backward_fwd_input_map, backward_grad_input_map,
backward_grad_output_map, backward_attrs_list)
backward_grad_output_map, backward_attrs_list, intermediate_outputs)
print("Generated Forward Definition: ", forward_definition_str)
print("Generated Forward Declaration: ", forward_declaration_str)
forward_definition_str += definition_declaration_pair[0]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册