diff --git a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py index 4e2c6db1a44a4b58cf0adbd502157f463cde96d3..73c8027645eb88c068110d808e70816c483a4615 100644 --- a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py +++ b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py @@ -902,7 +902,10 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase): else: function_name = GetIntermediateAPIFunctionName(function_name) - forward_call_str = f"{indent}auto api_result = paddle::experimental::{namespace}{function_name}({inputs_call_args_str});" + api_out_type = "auto" + if is_inplaced and len(forward_outputs_position_map) == 1: + api_out_type = "auto&" + forward_call_str = f"{indent}{api_out_type} api_result = paddle::experimental::{namespace}{function_name}({inputs_call_args_str});" num_outputs = len(forward_outputs_position_map.keys()) - len( intermediate_outputs) @@ -923,11 +926,18 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase): returns_list[pos] = f"{name}" if IsPlainTensorType(rtype): - returns_type_list[pos] = "paddle::experimental::Tensor" + if is_inplaced and inplace_map and name in inplace_map.values(): + returns_type_list[pos] = "paddle::experimental::Tensor&" + else: + returns_type_list[pos] = "paddle::experimental::Tensor" else: assert IsVectorTensorType(rtype) - returns_type_list[ - pos] = "std::vector" + if is_inplaced and inplace_map and name in inplace_map.values(): + returns_type_list[ + pos] = "std::vector&" + else: + returns_type_list[ + pos] = "std::vector" if num_outputs == 1: returns_str = returns_list[0] @@ -936,7 +946,7 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase): returns_type_str = ", ".join(returns_type_list) returns_type_str = f"std::tuple<{returns_type_str}>" returns_str = ", ".join(returns_list) - returns_str = f"std::make_tuple({returns_str})" + returns_str = f"{returns_type_str}{{{returns_str}}}" # Node Creation Pre-Processing # 1. Get Input AutoGradMeta diff --git a/paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py b/paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py index b86685c205a5cc3afd6f350983ed6c957f23116f..45e4665bd297c285db7649c2724daa4ae8346443 100644 --- a/paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py +++ b/paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py @@ -100,7 +100,7 @@ static PyObject * eager_final_state_api_{}(PyObject *self, PyObject *args, PyObj // Set Device ID {} - auto out = {}({}); + decltype({}({})) out = {}({}); PyEval_RestoreThread(tstate); tstate = nullptr; @@ -328,7 +328,7 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase): dygraph_function_call_list[pos] = f"{name}" dygraph_function_call_str = ",".join(dygraph_function_call_list) - # Generate Python-C Function Definitions + # Generate Python-C Function Definitions if is_forward_only: fwd_function_name = FUNCTION_NAME_TEMPLATE.format( "paddle::experimental::", namespace, forward_api_name) @@ -344,7 +344,8 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase): self.python_c_function_str = PYTHON_C_FUNCTION_TEMPLATE.format( forward_api_name, pythonc_record_event_str, forward_api_name, get_eager_tensor_str, parse_attributes_str, set_device_str, - fwd_function_name, dygraph_function_call_str, return_str) + fwd_function_name, dygraph_function_call_str, fwd_function_name, + dygraph_function_call_str, return_str) # Set prefix of forward_api_name to avoid conflicts prefix = self.namespace.strip("::") @@ -380,6 +381,7 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase): inplaced_forward_api_name, get_eager_tensor_str, parse_attributes_str, set_device_str, inplaced_fwd_function_name, dygraph_function_call_str, + inplaced_fwd_function_name, dygraph_function_call_str, return_str) # Generate Python-C Function Registration diff --git a/python/paddle/utils/code_gen/api.yaml b/python/paddle/utils/code_gen/api.yaml index f078aae9bb6b163c616a03c789ba743e69713ebb..5f533d241f1f328ab4992b82f7c6a47367f88986 100644 --- a/python/paddle/utils/code_gen/api.yaml +++ b/python/paddle/utils/code_gen/api.yaml @@ -714,6 +714,7 @@ backend : x inplace : (x -> out) view : (x -> out) + # intermediate : xshape backward : flatten_grad # flip diff --git a/python/paddle/utils/code_gen/api_base.py b/python/paddle/utils/code_gen/api_base.py index af870fcc8e54d17f0338f174ddb476d1d9b97a3c..8483325221eb42c35ff2a4f51f9c8e9b9c3bf36d 100644 --- a/python/paddle/utils/code_gen/api_base.py +++ b/python/paddle/utils/code_gen/api_base.py @@ -32,11 +32,7 @@ class BaseAPI(object): # names : [], list of output names # types : [], list of output types # out_size_expr : [], expression for getting size of vector - # return_type : Tensor, vector, ..., the return type of api - # args_str: - # args_declare : "str" // str of function params with default value. Example: (..., bool flag=false) - # args_define : "str" // str of function params without default value. Example: (..., bool flag) - self.inputs, self.attrs, self.outputs, self.args_str, self.optional_vars = self.parse_args( + self.inputs, self.attrs, self.outputs, self.optional_vars = self.parse_args( self.api, api_item_yaml) self.is_base_api = True @@ -60,11 +56,38 @@ class BaseAPI(object): def get_api_func_name(self): return self.api - def get_declare_args(self): - return self.args_str['args_declare'] + def get_input_tensor_args(self, inplace_flag=False): + input_args = [] + inplace_type_map = { + "const Tensor&": "Tensor&", + "const std::vector&": "std::vector&" + } + for name in self.inputs['names']: + name = name.split('@')[0] + if inplace_flag and name in self.inplace_map.values(): + input_args.append(inplace_type_map[self.inputs['input_info'][ + name]] + ' ' + name) + else: + input_args.append(self.inputs['input_info'][name] + ' ' + name) + return input_args + + def get_declare_args(self, inplace_flag=False): + declare_args = self.get_input_tensor_args(inplace_flag) + for name in self.attrs['names']: + default_value = '' + if self.attrs['attr_info'][name][1] is not None: + default_value = ' = ' + self.attrs['attr_info'][name][1] + declare_args.append(self.attrs['attr_info'][name][0] + ' ' + name + + default_value) - def get_define_args(self): - return self.args_str["args_define"] + return ", ".join(declare_args) + + def get_define_args(self, inplace_flag=False): + define_args = self.get_input_tensor_args(inplace_flag) + for name in self.attrs['names']: + define_args.append(self.attrs['attr_info'][name][0] + ' ' + name) + + return ", ".join(define_args) def parse_args(self, api_name, api_item_yaml): optional_vars = [] @@ -72,16 +95,15 @@ class BaseAPI(object): optional_vars = [ item.strip() for item in api_item_yaml['optional'].split(',') ] - inputs, attrs, args_str = self.parse_input_and_attr( + inputs, attrs = self.parse_input_and_attr( api_name, api_item_yaml['args'], optional_vars) - output_type_list, output_names, out_size_expr, return_type = self.parse_output( + output_type_list, output_names, out_size_expr = self.parse_output( api_name, api_item_yaml['output']) return inputs, attrs, { 'names': output_names, 'types': output_type_list, - 'out_size_expr': out_size_expr, - 'return_type': return_type - }, args_str, optional_vars + 'out_size_expr': out_size_expr + }, optional_vars def parse_input_and_attr(self, api_name, args_config, optional_vars=[]): inputs = {'names': [], 'input_info': {}} @@ -131,9 +153,6 @@ class BaseAPI(object): 'DataType': 'paddle::optional' } - args_declare_str = "" - args_define_str = "" - for item in args_list: item = item.strip() type_and_name = item.split(' ') @@ -152,8 +171,6 @@ class BaseAPI(object): inputs['names'].append(input_name) inputs['input_info'][input_name] = in_type - args_declare_str = args_declare_str + in_type + ' ' + input_name + ', ' - args_define_str = args_define_str + in_type + ' ' + input_name + ', ' has_input = True break if has_input: @@ -175,16 +192,11 @@ class BaseAPI(object): attr_type = optional_types_trans[attr_type_symbol] default_value_str = "" if default_value is None else '=' + default_value - args_declare_str = args_declare_str + attr_type + ' ' + attr_name + default_value_str + ', ' - args_define_str = args_define_str + attr_type + ' ' + attr_name + ', ' attrs['names'].append(attr_name) attrs['attr_info'][attr_name] = (attr_type, default_value) break - return inputs, attrs, { - 'args_declare': args_declare_str[:-2], - 'args_define': args_define_str[:-2] - } + return inputs, attrs def parse_output(self, api_name, output_config): def parse_output_item(output_item): @@ -211,8 +223,7 @@ class BaseAPI(object): if len(temp_list) == 1: out_type, out_name, size_expr = parse_output_item(temp_list[0]) - return [out_type], [out_name], size_expr, self.get_return_type( - [out_type]) + return [out_type], [out_name], size_expr else: out_type_list = [] out_name_list = [] @@ -221,8 +232,7 @@ class BaseAPI(object): out_type_list.append(out_type) out_name_list.append(out_name) - return out_type_list, out_name_list, size_expr, self.get_return_type( - out_type_list) + return out_type_list, out_name_list, size_expr def parse_infer_meta(self, infer_meta_config): infer_meta = infer_meta_config @@ -285,7 +295,7 @@ class BaseAPI(object): return data_transform def parse_inplace_and_view(self, api_item_yaml): - inplace_map, view_map = None, None + inplace_map, view_map = {}, {} for mode in ['inplace', 'view']: if mode in api_item_yaml: if mode == 'inplace': @@ -310,17 +320,22 @@ class BaseAPI(object): return inplace_map, view_map # Override by child class - def get_return_type(self, out_type_list): + def get_return_type(self, inplace_flag=False): return None def gene_api_declaration(self): - api_declaration = f""" -PADDLE_API {self.gene_return_type_code()} {self.get_api_func_name()}({self.get_declare_args()}); + api_declaration = "" + api_func_name = self.get_api_func_name() + if api_func_name[-1] != '_': + api_declaration = f""" +PADDLE_API {self.get_return_type()} {api_func_name}({self.get_declare_args()}); """ - if self.is_base_api and self.inplace_map is not None: + if self.is_base_api and len(self.inplace_map) > 0: + if api_func_name[-1] != '_': + api_func_name += '_' api_declaration = api_declaration + f""" -PADDLE_API {self.gene_return_type_code()} {self.get_api_func_name() + '_'}({self.get_declare_args()}); +PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_declare_args(inplace_flag=True)}); """ return api_declaration @@ -714,10 +729,6 @@ PADDLE_API {self.gene_return_type_code()} {self.get_api_func_name() + '_'}({self return input_tensor_code, kernel_args[:-2], kernel_signature - # Override by child class - def gene_return_type_code(self): - return self.outputs['return_type'] - # Override by child class def gene_return_code(self): return "return api_output;" @@ -786,9 +797,11 @@ PADDLE_API {self.gene_return_type_code()} {self.get_api_func_name() + '_'}({self {code_indent} {self.gene_return_code()}""" def gene_base_api_code(self, inplace_flag=False): - api_func_name = self.get_api_func_name() + ('_' if inplace_flag else '') + api_func_name = self.get_api_func_name() + if inplace_flag and api_func_name[-1] != '_': + api_func_name += '_' api_code = f""" -PADDLE_API {self.gene_return_type_code()} {api_func_name}({self.get_define_args()}) {{ +PADDLE_API {self.get_return_type(inplace_flag)} {api_func_name}({self.get_define_args(inplace_flag)}) {{ {self.gene_kernel_select()} """ @@ -812,14 +825,14 @@ PADDLE_API {self.gene_return_type_code()} {api_func_name}({self.get_define_args( def gene_invoke_code(self, invoke_code, params_code): return f""" -PADDLE_API {self.outputs['return_type']} {self.api}({params_code}) {{ +PADDLE_API {self.get_return_type()} {self.api}({params_code}) {{ return {invoke_code}; }}""" def gene_api_code(self): if self.is_base_api: api_code = self.gene_base_api_code() - if self.inplace_map is not None: + if len(self.inplace_map) > 0: api_code = api_code + self.gene_base_api_code(inplace_flag=True) return api_code diff --git a/python/paddle/utils/code_gen/api_gen.py b/python/paddle/utils/code_gen/api_gen.py index 8fd95f9a191c34dc36b536dd5304332a1acba0dd..fa9128252fccaf73807eda6a93fea74a77b61f69 100644 --- a/python/paddle/utils/code_gen/api_gen.py +++ b/python/paddle/utils/code_gen/api_gen.py @@ -19,6 +19,11 @@ import re from api_base import BaseAPI, PREFIX_TENSOR_NAME +inplace_out_type_map = { + "Tensor": "Tensor&", + "std::vector": "std::vector&" +} + class ForwardAPI(BaseAPI): def __init__(self, api_item_yaml): @@ -42,22 +47,33 @@ class ForwardAPI(BaseAPI): else: return False, [] - def get_return_type(self, out_type_list): - return out_type_list[0] if len( - out_type_list) == 1 else "std::tuple<" + ",".join( - out_type_list) + ">" + def get_return_type_with_intermediate(self, inplace_flag=False): + out_type_list = [] + for i, out_type in enumerate(self.outputs['types']): + out_name = self.outputs['names'][i].split('@')[0] + if inplace_flag and out_name in self.inplace_map: + out_type_list.append(inplace_out_type_map[out_type]) + else: + out_type_list.append(out_type) - def gene_return_type_code(self): - if self.is_dygraph_api or len(self.intermediate_outs) == 0: - return self.outputs['return_type'] + if len(out_type_list) == 1: + return out_type_list[0] else: - return_out_list = [] - for i, name in enumerate(self.outputs['names']): - if name.split('@')[0] not in self.intermediate_outs: - return_out_list.append(self.outputs['types'][i]) - return return_out_list[0] if len( - return_out_list) == 1 else "std::tuple<" + ",".join( - return_out_list) + ">" + return "std::tuple<" + ", ".join(out_type_list) + ">" + + def get_return_type(self, inplace_flag=False): + out_type_list = [] + for i, out_type in enumerate(self.outputs['types']): + out_name = self.outputs['names'][i].split('@')[0] + if inplace_flag and out_name in self.inplace_map: + out_type_list.append(inplace_out_type_map[out_type]) + elif self.is_dygraph_api or out_name not in self.intermediate_outs: + out_type_list.append(out_type) + + if len(out_type_list) == 1: + return out_type_list[0] + else: + return "std::tuple<" + ", ".join(out_type_list) + ">" def gene_return_code(self): if self.is_dygraph_api or len(self.intermediate_outs) == 0: @@ -83,17 +99,18 @@ class ForwardAPI(BaseAPI): kernel_output = "" output_names = [] output_create = "" + return_type = self.get_return_type_with_intermediate(inplace_flag) if len(output_type_list) == 1: kernel_output = 'kernel_out' output_names.append('kernel_out') inplace_assign = " = " + self.inplace_map[self.outputs['names'][ - 0]] if inplace_flag and self.inplace_map is not None and self.outputs[ - 'names'][0] in self.inplace_map else "" + 0]] if inplace_flag and self.outputs['names'][ + 0] in self.inplace_map else "" output_create = f""" -{code_indent} {self.outputs['return_type']} api_output{inplace_assign};""" +{code_indent} {return_type} api_output{inplace_assign};""" - if self.outputs['return_type'] == 'std::vector': + if return_type == 'std::vector': assert self.outputs['out_size_expr'] is not None, \ f"{api_name}: The out size expr : '{{expr}}' should be set when output has Tensor[]. You can refer 'split' api." output_create = output_create + f""" @@ -112,15 +129,23 @@ class ForwardAPI(BaseAPI): elif len(output_type_list) > 1: output_create = f""" -{code_indent} {self.outputs['return_type']} api_output;""" +{code_indent} {return_type} api_output;""" + + if inplace_flag: + output_create = f""" +{code_indent} {return_type} api_output{{""" + + for out_name in self.outputs['names']: + if out_name in self.inplace_map: + output_create = output_create + self.inplace_map[ + out_name] + ', ' + else: + output_create += 'Tensor(), ' + output_create = output_create[:-2] + '};' for i in range(len(output_type_list)): kernel_output = kernel_output + f'kernel_out_{i}, ' output_names.append(f'kernel_out_{i}') - if inplace_flag and self.inplace_map is not None and self.outputs[ - 'names'][i] in self.inplace_map: - output_create = output_create + f""" -{code_indent} std::get<{i}>(api_output) = {self.inplace_map[self.outputs['names'][i]]};""" if output_type_list[i] == 'std::vector': assert self.outputs['out_size_expr'][i] is not None, \ diff --git a/python/paddle/utils/code_gen/backward_api_gen.py b/python/paddle/utils/code_gen/backward_api_gen.py index a155a2c3d6c9f761990b17b22d14eda2789ab07d..b918336e43b46dc5e37bb3d74ed13def6b3cae05 100644 --- a/python/paddle/utils/code_gen/backward_api_gen.py +++ b/python/paddle/utils/code_gen/backward_api_gen.py @@ -35,10 +35,10 @@ class BackwardAPI(BaseAPI): r"(?P[a-z][a-z0-9_]+)\s*(?P\([^\)]+\))\s*->\s*(?P.+)", forward_config) api = result.group('api') - _, outputs, _, _ = self.parse_output(self.api, result.group('outputs')) + _, outputs, _, = self.parse_output(self.api, result.group('outputs')) outputs = [item.split('@')[0] for item in outputs] - fw_inputs, fw_attrs, _, = self.parse_input_and_attr( - api, result.group('args')) + fw_inputs, fw_attrs = self.parse_input_and_attr(api, + result.group('args')) return api, fw_inputs, fw_attrs, outputs @@ -77,15 +77,15 @@ class BackwardAPI(BaseAPI): f"{self.api} : Output error: The number of outputs should be less then the number of inputs of forward api. \ Please check the output of {self.api} in yaml." - def get_declare_args(self): + def get_declare_args(self, inplace_flag=False): return self.get_define_args() - def get_define_args(self): + def get_define_args(self, inplace_flag=False): out_type_map = { 'Tensor': 'Tensor*', 'std::vector': 'std::vector' } - intputs_and_attrs = self.args_str['args_define'] + intputs_and_attrs = super(BackwardAPI, self).get_define_args() outs = [] for i, name in enumerate(self.outputs['names']): outs.append(out_type_map[self.outputs['types'][i]] + ' ' + @@ -109,7 +109,7 @@ class BackwardAPI(BaseAPI): else: return super().gene_kernel_backend_select() - def get_return_type(self, out_type_list): + def get_return_type(self, inplace_flag=False): return 'void' def gene_output(self, @@ -176,13 +176,13 @@ class BackwardAPI(BaseAPI): if inveke_func_name.endswith('_grad') or inveke_func_name.endswith( '_grad_impl'): return f""" -PADDLE_API {self.outputs['return_type']} {self.api}({params_code}) {{ +PADDLE_API {self.get_return_type()} {self.api}({params_code}) {{ {invoke_code}; }}""" else: return f""" -PADDLE_API {self.outputs['return_type']} {self.api}({params_code}) {{ +PADDLE_API {self.get_return_type()} {self.api}({params_code}) {{ *{self.outputs['names'][0].split('@')[0]} = {invoke_code}; }}""" diff --git a/python/paddle/utils/code_gen/sparse_api_gen.py b/python/paddle/utils/code_gen/sparse_api_gen.py index eb9bca2eca7b7b78d665826eeb7d2cb6546fc59f..509858d339f69675a55722818768b612231ed868 100644 --- a/python/paddle/utils/code_gen/sparse_api_gen.py +++ b/python/paddle/utils/code_gen/sparse_api_gen.py @@ -25,9 +25,10 @@ class SparseAPI(ForwardAPI): super(SparseAPI, self).__init__(api_item_yaml) def gene_api_declaration(self): - api_declaration = "// " + ', '.join(self.outputs['names']) - return api_declaration + super(SparseAPI, - self).gene_api_declaration() + '\n' + return f""" +// {", ".join(self.outputs['names'])} +{super(SparseAPI, self).gene_api_declaration()} +""" def get_kernel_tensor_out_type(self, output_name): sparse_type = 'TensorType::DENSE_TENSOR' @@ -45,6 +46,7 @@ class SparseAPI(ForwardAPI): kernel_output = "" output_names = [] output_create = "" + return_type = self.get_return_type_with_intermediate(inplace_flag) if len(output_type_list) == 1: kernel_output = 'kernel_out' @@ -53,21 +55,29 @@ class SparseAPI(ForwardAPI): 0]] if inplace_flag and self.inplace_map is not None and self.outputs[ 'names'][0] in self.inplace_map else "" output_create = f""" - {self.outputs['return_type']} api_output{inplace_assign}; + {return_type} api_output{inplace_assign}; auto* kernel_out = {set_out_func}(&api_output, {self.get_kernel_tensor_out_type(self.outputs['names'][0])});""" elif len(output_type_list) > 1: output_create = f""" - {self.outputs['return_type']} api_output;""" + {return_type} api_output;""" + + if inplace_flag: + output_create = f""" + {return_type} api_output{{""" + + for out_name in self.outputs['names']: + out_name = out_name.split('@')[0] + if out_name in self.inplace_map: + output_create = output_create + self.inplace_map[ + out_name] + ', ' + else: + output_create += 'Tensor(), ' + output_create = output_create[:-2] + '};' for i in range(len(output_type_list)): kernel_output = kernel_output + f'kernel_out_{i}, ' output_names.append(f'kernel_out_{i}') - if inplace_flag and self.inplace_map is not None and self.outputs[ - 'names'][i] in self.inplace_map: - output_create = output_create + f""" - std::get<{i}>(api_output) = {self.inplace_map[self.outputs['names'][i]]};""" - output_create = output_create + f""" auto* kernel_out_{i} = {set_out_func}(&std::get<{i}>(api_output), {self.get_kernel_tensor_out_type(self.outputs['names'][i])});""" @@ -151,8 +161,11 @@ class SparseAPI(ForwardAPI): {return_code}""" def gene_base_api_code(self, inplace_flag=False): + api_func_name = self.get_api_func_name() + if inplace_flag and api_func_name[-1] != '_': + api_func_name += '_' return f""" -PADDLE_API {self.gene_return_type_code()} {self.get_api_func_name()}({self.get_define_args()}) {{ +PADDLE_API {self.get_return_type()} {api_func_name}({self.get_define_args()}) {{ {self.gene_kernel_select()} {self.gen_sparse_kernel_code(inplace_flag)} }} diff --git a/python/paddle/utils/code_gen/sparse_bw_api_gen.py b/python/paddle/utils/code_gen/sparse_bw_api_gen.py index 6dc4a2668ebb9a5d96eb134af8c2a96cf52ab5ca..53a99d798118e66e9d64eef5b2f283721b131a92 100644 --- a/python/paddle/utils/code_gen/sparse_bw_api_gen.py +++ b/python/paddle/utils/code_gen/sparse_bw_api_gen.py @@ -31,11 +31,8 @@ class SparseBackwardAPI(SparseAPI, BackwardAPI): def gene_kernel_backend_select(self): return BackwardAPI.gene_kernel_backend_select(self) - def get_return_type(self, out_type_list): - return BackwardAPI.get_return_type(self, out_type_list) - - def gene_return_type_code(self): - return self.outputs['return_type'] + def get_return_type(self, inplace_flag=False): + return BackwardAPI.get_return_type(self) def gene_return_code(self): return "" @@ -43,10 +40,10 @@ class SparseBackwardAPI(SparseAPI, BackwardAPI): def gene_api_declaration(self): return SparseAPI.gene_api_declaration(self) - def get_declare_args(self): + def get_declare_args(self, inplace_flag=False): return BackwardAPI.get_declare_args(self) - def get_define_args(self): + def get_define_args(self, inplace_flag=False): return BackwardAPI.get_define_args(self) def gene_output(self, diff --git a/python/paddle/utils/code_gen/strings_api_gen.py b/python/paddle/utils/code_gen/strings_api_gen.py index 815b9176cd22cb2ccd1bb8fe7d5b66d9ee151ee7..d697ce393570823e0ddc73d53c71d57a2ff39717 100644 --- a/python/paddle/utils/code_gen/strings_api_gen.py +++ b/python/paddle/utils/code_gen/strings_api_gen.py @@ -32,7 +32,7 @@ class StringsAPI(ForwardAPI): def gene_api_declaration(self): return f""" // {", ".join(self.outputs['names'])} -PADDLE_API {self.outputs['return_type']} {self.get_api_func_name()}({self.args_str['args_declare']}); +{super(StringsAPI, self).gene_api_declaration()} """ def get_kernel_tensor_out_type(self, output_name): @@ -56,6 +56,7 @@ PADDLE_API {self.outputs['return_type']} {self.get_api_func_name()}({self.args_s kernel_output = "" output_names = [] output_create = "" + return_type = self.get_return_type(inplace_flag) if len(output_type_list) == 1: kernel_output = 'kernel_out' @@ -67,13 +68,12 @@ PADDLE_API {self.outputs['return_type']} {self.get_api_func_name()}({self.args_s 0]] if inplace_flag and self.inplace_map is not None and self.outputs[ 'names'][0] in self.inplace_map else "" output_create = f""" - {self.outputs['return_type']} api_output{inplace_assign}; - + {return_type} api_output{inplace_assign}; {tensor_type}* kernel_out = dynamic_cast<{tensor_type}*>({set_out_func}(kernel_backend, &api_output, {kernel_tensor_out_type}));""" elif len(output_type_list) > 1: output_create = f""" - {self.outputs['return_type']} api_output;""" + {return_type} api_output;""" for i in range(len(output_type_list)): kernel_output = kernel_output + f'kernel_out_{i}, ' @@ -264,7 +264,7 @@ PADDLE_API {self.outputs['return_type']} {self.get_api_func_name()}({self.args_s def gene_base_api_code(self, inplace_flag=False): api_func_name = self.get_api_func_name() return f""" -PADDLE_API {self.outputs['return_type']} {api_func_name}({self.args_str["args_define"]}) {{ +PADDLE_API {self.get_return_type(inplace_flag)} {api_func_name}({self.get_define_args(inplace_flag)}) {{ {self.gene_kernel_select()} {self.gen_string_tensor_kernel_code(inplace_flag)} }}