From 33b3e28a6b060cd75de41d94a10ca7fb10a69d58 Mon Sep 17 00:00:00 2001 From: zyfncg Date: Wed, 26 Jan 2022 14:47:18 +0800 Subject: [PATCH] change output of backward_api (#39229) --- python/paddle/utils/code_gen/api_gen.py | 43 ++++++++++++--- .../paddle/utils/code_gen/backward_api_gen.py | 53 +++++++++++++++---- python/paddle/utils/code_gen/gen_utils.py | 28 ++-------- 3 files changed, 82 insertions(+), 42 deletions(-) diff --git a/python/paddle/utils/code_gen/api_gen.py b/python/paddle/utils/code_gen/api_gen.py index 6bb02ab9d4..09182768f2 100644 --- a/python/paddle/utils/code_gen/api_gen.py +++ b/python/paddle/utils/code_gen/api_gen.py @@ -31,7 +31,12 @@ class API: # names : [], list of attribute names # attr_info : { attr_name : (type, default_values)} self.args = gen_utils.parse_args(self.api, api_item_yaml['args']) - self.output = api_item_yaml['output'] + self.out_type_list, _ = gen_utils.parse_output(self.api, + api_item_yaml['output']) + self.return_type = self.out_type_list[0] if len( + self.out_type_list) == 1 else "std::tuple<" + ",".join( + self.out_type_list) + ">" + self.is_base_api = True if 'invoke' in api_item_yaml: self.is_base_api = False @@ -54,18 +59,44 @@ class API: def gene_api_declaration(self): return f""" -PADDLE_API {self.output} {self.api}({self.args['args_declare']}); +PADDLE_API {self.return_type} {self.api}({self.args['args_declare']}); """ + def gene_output(self, output_type_list): + kernel_output = "" + output_create = "" + + if len(output_type_list) == 1: + kernel_output = 'dense_out' + output_create = f""" + {self.return_type} out; + auto dense_out = SetKernelOutput(out_meta, kernel_backend, &out);""" + + elif len(output_type_list) > 1: + output_create = f""" + {self.return_type} out;""" + + for i in range(len(output_type_list)): + kernel_output = kernel_output + f'dense_out_{i}, ' + output_create = output_create + f""" + auto dense_out_{i} = SetKernelOutput(std::get<{i}>(out_meta), kernel_backend, &std::get<{i}>(out));""" + + kernel_output = kernel_output[:-2] + else: + raise ValueError( + "{} : Output error: the output should not be empty.".format( + self.api)) + + return kernel_output, output_create + def gene_api_code(self): if self.is_base_api: input_tensors, kernel_args = gen_utils.get_kernel_args( self.args['inputs']['names'], self.args['attrs'], self.kernel['param']) - out_type, _ = gen_utils.parse_output(self.api, self.output) - outputs_args, output_create = gen_utils.gene_output(out_type) + outputs_args, output_create = self.gene_output(self.out_type_list) return f""" -PADDLE_API {self.output} {self.api}({self.args["args_define"]}) {{ +PADDLE_API {self.return_type} {self.api}({self.args["args_define"]}) {{ {gen_utils.gene_kernel_select(self.api, self.args['inputs']['names'], self.args['attrs'], self.kernel)} auto* dev_ctx = GetDeviceContextByBackend(kernel_backend); @@ -82,7 +113,7 @@ PADDLE_API {self.output} {self.api}({self.args["args_define"]}) {{ else: return f""" -PADDLE_API {self.output} {self.api}({self.args["args_define"]}) {{ +PADDLE_API {self.return_type} {self.api}({self.args["args_define"]}) {{ return {self.invoke}; }} """ diff --git a/python/paddle/utils/code_gen/backward_api_gen.py b/python/paddle/utils/code_gen/backward_api_gen.py index 0cb14327f6..d55759b51c 100644 --- a/python/paddle/utils/code_gen/backward_api_gen.py +++ b/python/paddle/utils/code_gen/backward_api_gen.py @@ -23,9 +23,11 @@ import gen_utils class BackwardAPI: def __init__(self, backward_item_yaml): self.backward_api = backward_item_yaml['backward_api'] - self.args, self.output_type, self.return_comment = self.parse_and_check_args( + self.args, self.output_type_list, self.return_comment = self.parse_and_check_args( backward_item_yaml['forward'], backward_item_yaml['args'], backward_item_yaml['output']) + self.return_type = self.output_type_list[0] if len( + self.output_type_list) == 1 else "std::vector>" self.is_base_api = True if 'invoke' in backward_item_yaml: @@ -81,36 +83,65 @@ class BackwardAPI: Please check the args of {self.backward_api} in yaml." # check the output of backward - output_type, return_comment = gen_utils.parse_output(self.backward_api, - output_config) - assert output_type.count('Tensor') <= len(fw_inputs['names']), \ + out_type_list, return_comment = gen_utils.parse_output( + self.backward_api, output_config) + assert len(out_type_list) <= len(fw_inputs['names']), \ f"{self.backward_api} : Output error: The number of ouputs should be less then the number of inputs of forward api. \ Please check the output of {self.backward_api} in yaml." - return bw_args, output_type, return_comment + return bw_args, out_type_list, return_comment def gene_api_declaration(self): if self.return_comment: return f""" // {self.return_comment} -{self.output_type} {self.backward_api}({self.args['args_declare']}); +{self.return_type} {self.backward_api}({self.args['args_declare']}); """ else: return f""" -{self.output_type} {self.backward_api}({self.args['args_declare']}); +{self.return_type} {self.backward_api}({self.args['args_declare']}); """ + def gene_output(self, output_type_list): + kernel_output = "" + output_create = "" + + if len(output_type_list) == 1: + return_type = output_type_list[0] + kernel_output = 'dense_out' + output_create = f""" + {self.return_type} out; + auto dense_out = SetKernelOutput(out_meta, kernel_backend, &out);""" + + elif len(output_type_list) > 1: + output_create = f""" + {self.return_type} out;""" + + for i, out_type_item in enumerate(output_type_list): + kernel_output = kernel_output + f'dense_out_{i}, ' + get_out_code = f'&out[{i}][0]' if out_type_item == 'Tensor' else f'&out[{i}]' + output_create = output_create + f""" + auto dense_out_{i} = SetKernelOutput(std::get<{i}>(out_meta), kernel_backend, {get_out_code});""" + + kernel_output = kernel_output[:-2] + else: + raise ValueError( + "{} : Output error: the output should not be empty.".format( + self.backward_api)) + + return kernel_output, output_create + def gene_api_code(self): if self.is_base_api: input_tensors, kernel_args = gen_utils.get_kernel_args( self.args['inputs']['names'], self.args['attrs'], self.kernel['param']) - outputs_args, output_create = gen_utils.gene_output( - self.output_type) + outputs_args, output_create = self.gene_output( + self.output_type_list) return f""" // {self.return_comment} -{self.output_type} {self.backward_api}({self.args["args_define"]}) {{ +{self.return_type} {self.backward_api}({self.args["args_define"]}) {{ {gen_utils.gene_kernel_select(self.backward_api, self.args['inputs']['names'], self.args['attrs'], self.kernel)} auto* dev_ctx = GetDeviceContextByBackend(kernel_backend); @@ -143,7 +174,7 @@ class BackwardAPI: params_code = self.args["args_define"] return f""" // {self.return_comment} -{self.output_type} {self.backward_api}({params_code}) {{ +{self.return_type} {self.backward_api}({params_code}) {{ return {invoke_code}; }} """ diff --git a/python/paddle/utils/code_gen/gen_utils.py b/python/paddle/utils/code_gen/gen_utils.py index 9d368c292b..bdc2942055 100644 --- a/python/paddle/utils/code_gen/gen_utils.py +++ b/python/paddle/utils/code_gen/gen_utils.py @@ -124,7 +124,7 @@ def parse_output(api_name, output_config): if len(temp_list) == 1: out_type, out_name = parse_output_item(temp_list[0]) - return out_type, out_name + return [out_type], out_name else: out_type_list = [] out_name_list = [] @@ -133,8 +133,7 @@ def parse_output(api_name, output_config): out_type_list.append(out_type) out_name_list.append(out_name) - return "std::tuple<" + ",".join(out_type_list) + ">", ", ".join( - out_name_list) + return out_type_list, ", ".join(out_name_list) def gene_kernel_select(api, input_names, attrs, kernel) -> str: @@ -241,7 +240,7 @@ def gene_kernel_select(api, input_names, attrs, kernel) -> str: if len(input_names) > 0: kernel_select_code = kernel_select_code + f""" - if (kernel_backend == Backend::UNDEFINED + if (kernel_backend == Backend::UNDEFINED || kernel_layout == DataLayout::UNDEFINED || kernel_data_type == DataType::UNDEFINED ) {{ auto kernel_key_set = ParseKernelKeyByInputArgs({kernel_select_args}); @@ -315,24 +314,3 @@ def get_kernel_args(input_names, attrs, kernel_param): else: kernel_args = kernel_args + str(param) + ", " return input_tensor_code, kernel_args[:-2] - - -def gene_output(output_type): - kernel_output = "" - output_create = f""" - {output_type} out;""" - - if output_type == 'Tensor' or output_type == 'std::vector': - kernel_output = 'dense_out' - output_create = output_create + """ - auto dense_out = SetKernelOutput(out_meta, kernel_backend, &out);""" - elif re.match(r'std::tuple<.*>$', output_type): - out_num = output_type.count('Tensor') - for i in range(out_num): - kernel_output = kernel_output + f'dense_out_{i}, ' - output_create = output_create + f""" - auto dense_out_{i} = SetKernelOutput(std::get<{i}>(out_meta), kernel_backend, &std::get<{i}>(out));""" - - kernel_output = kernel_output[:-2] - - return kernel_output, output_create -- GitLab