From 94243789828f38e5220ae4b7e97553701e148000 Mon Sep 17 00:00:00 2001 From: Zhanlue Yang Date: Wed, 23 Feb 2022 09:36:57 +0800 Subject: [PATCH] Supported intermediate outputs for eager final state codegen (#39767) * Supported intermediate outputs for eager final state codegen * Added validation check for intermediate tensors --- .../final_state_generator/eager_gen.py | 49 ++++++++++++++++--- 1 file changed, 41 insertions(+), 8 deletions(-) diff --git a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py index ca02a3d3977..0578f930679 100644 --- a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py +++ b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py @@ -127,6 +127,26 @@ def ReadBwdFile(filepath): ###################### ### Yaml Parsers ### ###################### +def IntermediateValidationCheck(intermediate_outputs, forward_returns_list): + # intermediate_outputs : [name0, name1, ...] + # forward_returns_list : [[ret_name, type, orig_pos], ...] + """ + Check whether intermediate_outputs are positioned + at the very end of forward_returns_list + """ + + intermediate_positions = range( + len(forward_returns_list) - len(intermediate_outputs), + len(forward_returns_list)) + for ret_name, _, pos in forward_returns_list: + if ret_name in intermediate_outputs: + assert pos in intermediate_positions + + +def ParseIntermediate(string): + return [v.strip() for v in string.split(",")] + + def ParseNoNeedBuffer(string): # string: "x, y" no_need_buffer_set = set() @@ -742,11 +762,11 @@ def GenerateNodeCreationCodes(fwd_api_name, bwd_api_name, return node_creation_str -def GenerateForwardDefinition(fwd_api_name, bwd_api_name, - forward_inputs_position_map, - forward_outputs_position_map, forward_attrs_list, - backward_fwd_input_map, backward_grad_input_map, - backward_grad_output_map, backward_attrs_list): +def GenerateForwardDefinition( + fwd_api_name, bwd_api_name, forward_inputs_position_map, + forward_outputs_position_map, forward_attrs_list, + backward_fwd_input_map, backward_grad_input_map, + backward_grad_output_map, backward_attrs_list, intermediate_outputs): # fwd_api_name = "" # forward_inputs_position_map = { "name" : [type, fwd_position] } # forward_outputs_position_map = { "name" : [type, fwd_position] } @@ -790,13 +810,20 @@ def GenerateForwardDefinition(fwd_api_name, bwd_api_name, inputs_call_args_str = ", ".join(inputs_call_list) # Forward Full Logic - forward_call_str = f"auto api_result = paddle::experimental::{fwd_api_name}({inputs_call_args_str});" + if len(intermediate_outputs) == 0: + function_name = fwd_api_name + else: + function_name = fwd_api_name + "_intermediate" + forward_call_str = f"auto api_result = paddle::experimental::{function_name}({inputs_call_args_str});" # Get return type list & outputs - num_outputs = len(forward_outputs_position_map.keys()) + num_outputs = len(forward_outputs_position_map.keys()) - len( + intermediate_outputs) returns_type_list = ["" for i in range(num_outputs)] returns_list = ["" for i in range(num_outputs)] for name, (rtype, pos) in forward_outputs_position_map.items(): + if name in intermediate_outputs: + continue if num_outputs == 1: returns_list[0] = f"api_result" else: @@ -1037,6 +1064,12 @@ if __name__ == "__main__": print("Prased Forward Attrs List: ", forward_attrs_list) print("Parsed Forward Returns List: ", forward_returns_list) + intermediate_outputs = [] + if 'intermediate' in fwd_api.keys(): + intermediate_outputs = ParseIntermediate(fwd_api['intermediate']) + + IntermediateValidationCheck(intermediate_outputs, forward_returns_list) + # Collect Original Forward Inputs/Outputs and then perform validation checks orig_forward_inputs_list, orig_forward_attrs_list, orig_forward_returns_list = ParseYamlForward( fwd_args_str, fwd_returns_str) @@ -1095,7 +1128,7 @@ if __name__ == "__main__": fwd_api_name, bwd_api_name, forward_inputs_position_map, forward_outputs_position_map, forward_attrs_list, backward_fwd_input_map, backward_grad_input_map, - backward_grad_output_map, backward_attrs_list) + backward_grad_output_map, backward_attrs_list, intermediate_outputs) print("Generated Forward Definition: ", forward_definition_str) print("Generated Forward Declaration: ", forward_declaration_str) forward_definition_str += definition_declaration_pair[0] -- GitLab