diff --git a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py index 7e7114111c4e1d1ca6f7a4cbafa183284248b854..a67356d380b9e0b5ce03b92ce11f28a030d20d18 100644 --- a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py +++ b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py @@ -127,6 +127,10 @@ def ReadBwdFile(filepath): ###################### ### Yaml Parsers ### ###################### +def ParseIntermediate(string): + return [v.strip() for v in string.split(",")] + + def ParseYamlArgs(string): # Example: const Tensor& x, const Tensor& y, bool transpose_x, bool transpose_y @@ -728,11 +732,11 @@ def GenerateNodeCreationCodes(fwd_api_name, bwd_api_name, return node_creation_str -def GenerateForwardDefinition(fwd_api_name, bwd_api_name, - forward_inputs_position_map, - forward_outputs_position_map, forward_attrs_list, - backward_fwd_input_map, backward_grad_input_map, - backward_grad_output_map, backward_attrs_list): +def GenerateForwardDefinition( + fwd_api_name, bwd_api_name, forward_inputs_position_map, + forward_outputs_position_map, forward_attrs_list, + backward_fwd_input_map, backward_grad_input_map, + backward_grad_output_map, backward_attrs_list, intermediate_outputs): # fwd_api_name = "" # forward_inputs_position_map = { "name" : [type, fwd_position] } # forward_outputs_position_map = { "name" : [type, fwd_position] } @@ -776,13 +780,20 @@ def GenerateForwardDefinition(fwd_api_name, bwd_api_name, inputs_call_args_str = ", ".join(inputs_call_list) # Forward Full Logic - forward_call_str = f"auto api_result = paddle::experimental::{fwd_api_name}({inputs_call_args_str});" + if len(intermediate_outputs) == 0: + function_name = fwd_api_name + else: + function_name = fwd_api_name + "_intermediate" + forward_call_str = f"auto api_result = paddle::experimental::{function_name}({inputs_call_args_str});" # Get return type list & outputs - num_outputs = len(forward_outputs_position_map.keys()) + num_outputs = len(forward_outputs_position_map.keys()) - len( + intermediate_outputs) returns_type_list = ["" for i in range(num_outputs)] returns_list = ["" for i in range(num_outputs)] for name, (rtype, pos) in forward_outputs_position_map.items(): + if name in intermediate_outputs: + continue if num_outputs == 1: returns_list[0] = f"api_result" else: @@ -1001,6 +1012,10 @@ if __name__ == "__main__": fwd_args_str = fwd_api['args'] fwd_returns_str = fwd_api['output'] + intermediate_outputs = [] + if 'intermediate' in fwd_api.keys(): + intermediate_outputs = ParseIntermediate(fwd_api['intermediate']) + bwd_api_name = fwd_api['backward'] assert bwd_api_name in grad_api_dict.keys() bwd_api = grad_api_dict[bwd_api_name] @@ -1076,7 +1091,7 @@ if __name__ == "__main__": fwd_api_name, bwd_api_name, forward_inputs_position_map, forward_outputs_position_map, forward_attrs_list, backward_fwd_input_map, backward_grad_input_map, - backward_grad_output_map, backward_attrs_list) + backward_grad_output_map, backward_attrs_list, intermediate_outputs) print("Generated Forward Definition: ", forward_definition_str) print("Generated Forward Declaration: ", forward_declaration_str) forward_definition_str += definition_declaration_pair[0]