未验证 提交 94243789 编写于 作者: Z Zhanlue Yang 提交者: GitHub

Supported intermediate outputs for eager final state codegen (#39767)

* Supported intermediate outputs for eager final state codegen

* Added validation check for intermediate tensors
上级 69e9e9d5
...@@ -127,6 +127,26 @@ def ReadBwdFile(filepath): ...@@ -127,6 +127,26 @@ def ReadBwdFile(filepath):
###################### ######################
### Yaml Parsers ### ### Yaml Parsers ###
###################### ######################
def IntermediateValidationCheck(intermediate_outputs, forward_returns_list):
# intermediate_outputs : [name0, name1, ...]
# forward_returns_list : [[ret_name, type, orig_pos], ...]
"""
Check whether intermediate_outputs are positioned
at the very end of forward_returns_list
"""
intermediate_positions = range(
len(forward_returns_list) - len(intermediate_outputs),
len(forward_returns_list))
for ret_name, _, pos in forward_returns_list:
if ret_name in intermediate_outputs:
assert pos in intermediate_positions
def ParseIntermediate(string):
return [v.strip() for v in string.split(",")]
def ParseNoNeedBuffer(string): def ParseNoNeedBuffer(string):
# string: "x, y" # string: "x, y"
no_need_buffer_set = set() no_need_buffer_set = set()
...@@ -742,11 +762,11 @@ def GenerateNodeCreationCodes(fwd_api_name, bwd_api_name, ...@@ -742,11 +762,11 @@ def GenerateNodeCreationCodes(fwd_api_name, bwd_api_name,
return node_creation_str return node_creation_str
def GenerateForwardDefinition(fwd_api_name, bwd_api_name, def GenerateForwardDefinition(
forward_inputs_position_map, fwd_api_name, bwd_api_name, forward_inputs_position_map,
forward_outputs_position_map, forward_attrs_list, forward_outputs_position_map, forward_attrs_list,
backward_fwd_input_map, backward_grad_input_map, backward_fwd_input_map, backward_grad_input_map,
backward_grad_output_map, backward_attrs_list): backward_grad_output_map, backward_attrs_list, intermediate_outputs):
# fwd_api_name = "" # fwd_api_name = ""
# forward_inputs_position_map = { "name" : [type, fwd_position] } # forward_inputs_position_map = { "name" : [type, fwd_position] }
# forward_outputs_position_map = { "name" : [type, fwd_position] } # forward_outputs_position_map = { "name" : [type, fwd_position] }
...@@ -790,13 +810,20 @@ def GenerateForwardDefinition(fwd_api_name, bwd_api_name, ...@@ -790,13 +810,20 @@ def GenerateForwardDefinition(fwd_api_name, bwd_api_name,
inputs_call_args_str = ", ".join(inputs_call_list) inputs_call_args_str = ", ".join(inputs_call_list)
# Forward Full Logic # Forward Full Logic
forward_call_str = f"auto api_result = paddle::experimental::{fwd_api_name}({inputs_call_args_str});" if len(intermediate_outputs) == 0:
function_name = fwd_api_name
else:
function_name = fwd_api_name + "_intermediate"
forward_call_str = f"auto api_result = paddle::experimental::{function_name}({inputs_call_args_str});"
# Get return type list & outputs # Get return type list & outputs
num_outputs = len(forward_outputs_position_map.keys()) num_outputs = len(forward_outputs_position_map.keys()) - len(
intermediate_outputs)
returns_type_list = ["" for i in range(num_outputs)] returns_type_list = ["" for i in range(num_outputs)]
returns_list = ["" for i in range(num_outputs)] returns_list = ["" for i in range(num_outputs)]
for name, (rtype, pos) in forward_outputs_position_map.items(): for name, (rtype, pos) in forward_outputs_position_map.items():
if name in intermediate_outputs:
continue
if num_outputs == 1: if num_outputs == 1:
returns_list[0] = f"api_result" returns_list[0] = f"api_result"
else: else:
...@@ -1037,6 +1064,12 @@ if __name__ == "__main__": ...@@ -1037,6 +1064,12 @@ if __name__ == "__main__":
print("Prased Forward Attrs List: ", forward_attrs_list) print("Prased Forward Attrs List: ", forward_attrs_list)
print("Parsed Forward Returns List: ", forward_returns_list) print("Parsed Forward Returns List: ", forward_returns_list)
intermediate_outputs = []
if 'intermediate' in fwd_api.keys():
intermediate_outputs = ParseIntermediate(fwd_api['intermediate'])
IntermediateValidationCheck(intermediate_outputs, forward_returns_list)
# Collect Original Forward Inputs/Outputs and then perform validation checks # Collect Original Forward Inputs/Outputs and then perform validation checks
orig_forward_inputs_list, orig_forward_attrs_list, orig_forward_returns_list = ParseYamlForward( orig_forward_inputs_list, orig_forward_attrs_list, orig_forward_returns_list = ParseYamlForward(
fwd_args_str, fwd_returns_str) fwd_args_str, fwd_returns_str)
...@@ -1095,7 +1128,7 @@ if __name__ == "__main__": ...@@ -1095,7 +1128,7 @@ if __name__ == "__main__":
fwd_api_name, bwd_api_name, forward_inputs_position_map, fwd_api_name, bwd_api_name, forward_inputs_position_map,
forward_outputs_position_map, forward_attrs_list, forward_outputs_position_map, forward_attrs_list,
backward_fwd_input_map, backward_grad_input_map, backward_fwd_input_map, backward_grad_input_map,
backward_grad_output_map, backward_attrs_list) backward_grad_output_map, backward_attrs_list, intermediate_outputs)
print("Generated Forward Definition: ", forward_definition_str) print("Generated Forward Definition: ", forward_definition_str)
print("Generated Forward Declaration: ", forward_declaration_str) print("Generated Forward Declaration: ", forward_declaration_str)
forward_definition_str += definition_declaration_pair[0] forward_definition_str += definition_declaration_pair[0]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册