未验证 提交 5924458b 编写于 作者: Z zyfncg 提交者: GitHub

[Phi] Refactor format of inplace C++ api (#42735)

* update code

* change the return type for inplace dygraph api

* change the tuple construct
上级 566ccfef
...@@ -902,7 +902,10 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase): ...@@ -902,7 +902,10 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
else: else:
function_name = GetIntermediateAPIFunctionName(function_name) function_name = GetIntermediateAPIFunctionName(function_name)
forward_call_str = f"{indent}auto api_result = paddle::experimental::{namespace}{function_name}({inputs_call_args_str});" api_out_type = "auto"
if is_inplaced and len(forward_outputs_position_map) == 1:
api_out_type = "auto&"
forward_call_str = f"{indent}{api_out_type} api_result = paddle::experimental::{namespace}{function_name}({inputs_call_args_str});"
num_outputs = len(forward_outputs_position_map.keys()) - len( num_outputs = len(forward_outputs_position_map.keys()) - len(
intermediate_outputs) intermediate_outputs)
...@@ -923,9 +926,16 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase): ...@@ -923,9 +926,16 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
returns_list[pos] = f"{name}" returns_list[pos] = f"{name}"
if IsPlainTensorType(rtype): if IsPlainTensorType(rtype):
if is_inplaced and inplace_map and name in inplace_map.values():
returns_type_list[pos] = "paddle::experimental::Tensor&"
else:
returns_type_list[pos] = "paddle::experimental::Tensor" returns_type_list[pos] = "paddle::experimental::Tensor"
else: else:
assert IsVectorTensorType(rtype) assert IsVectorTensorType(rtype)
if is_inplaced and inplace_map and name in inplace_map.values():
returns_type_list[
pos] = "std::vector<paddle::experimental::Tensor>&"
else:
returns_type_list[ returns_type_list[
pos] = "std::vector<paddle::experimental::Tensor>" pos] = "std::vector<paddle::experimental::Tensor>"
...@@ -936,7 +946,7 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase): ...@@ -936,7 +946,7 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
returns_type_str = ", ".join(returns_type_list) returns_type_str = ", ".join(returns_type_list)
returns_type_str = f"std::tuple<{returns_type_str}>" returns_type_str = f"std::tuple<{returns_type_str}>"
returns_str = ", ".join(returns_list) returns_str = ", ".join(returns_list)
returns_str = f"std::make_tuple({returns_str})" returns_str = f"{returns_type_str}{{{returns_str}}}"
# Node Creation Pre-Processing # Node Creation Pre-Processing
# 1. Get Input AutoGradMeta # 1. Get Input AutoGradMeta
......
...@@ -100,7 +100,7 @@ static PyObject * eager_final_state_api_{}(PyObject *self, PyObject *args, PyObj ...@@ -100,7 +100,7 @@ static PyObject * eager_final_state_api_{}(PyObject *self, PyObject *args, PyObj
// Set Device ID // Set Device ID
{} {}
auto out = {}({}); decltype({}({})) out = {}({});
PyEval_RestoreThread(tstate); PyEval_RestoreThread(tstate);
tstate = nullptr; tstate = nullptr;
...@@ -344,7 +344,8 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase): ...@@ -344,7 +344,8 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase):
self.python_c_function_str = PYTHON_C_FUNCTION_TEMPLATE.format( self.python_c_function_str = PYTHON_C_FUNCTION_TEMPLATE.format(
forward_api_name, pythonc_record_event_str, forward_api_name, forward_api_name, pythonc_record_event_str, forward_api_name,
get_eager_tensor_str, parse_attributes_str, set_device_str, get_eager_tensor_str, parse_attributes_str, set_device_str,
fwd_function_name, dygraph_function_call_str, return_str) fwd_function_name, dygraph_function_call_str, fwd_function_name,
dygraph_function_call_str, return_str)
# Set prefix of forward_api_name to avoid conflicts # Set prefix of forward_api_name to avoid conflicts
prefix = self.namespace.strip("::") prefix = self.namespace.strip("::")
...@@ -380,6 +381,7 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase): ...@@ -380,6 +381,7 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase):
inplaced_forward_api_name, get_eager_tensor_str, inplaced_forward_api_name, get_eager_tensor_str,
parse_attributes_str, set_device_str, parse_attributes_str, set_device_str,
inplaced_fwd_function_name, dygraph_function_call_str, inplaced_fwd_function_name, dygraph_function_call_str,
inplaced_fwd_function_name, dygraph_function_call_str,
return_str) return_str)
# Generate Python-C Function Registration # Generate Python-C Function Registration
......
...@@ -714,6 +714,7 @@ ...@@ -714,6 +714,7 @@
backend : x backend : x
inplace : (x -> out) inplace : (x -> out)
view : (x -> out) view : (x -> out)
# intermediate : xshape
backward : flatten_grad backward : flatten_grad
# flip # flip
......
...@@ -32,11 +32,7 @@ class BaseAPI(object): ...@@ -32,11 +32,7 @@ class BaseAPI(object):
# names : [], list of output names # names : [], list of output names
# types : [], list of output types # types : [], list of output types
# out_size_expr : [], expression for getting size of vector<Tensor> # out_size_expr : [], expression for getting size of vector<Tensor>
# return_type : Tensor, vector<Tensor>, ..., the return type of api self.inputs, self.attrs, self.outputs, self.optional_vars = self.parse_args(
# args_str:
# args_declare : "str" // str of function params with default value. Example: (..., bool flag=false)
# args_define : "str" // str of function params without default value. Example: (..., bool flag)
self.inputs, self.attrs, self.outputs, self.args_str, self.optional_vars = self.parse_args(
self.api, api_item_yaml) self.api, api_item_yaml)
self.is_base_api = True self.is_base_api = True
...@@ -60,11 +56,38 @@ class BaseAPI(object): ...@@ -60,11 +56,38 @@ class BaseAPI(object):
def get_api_func_name(self): def get_api_func_name(self):
return self.api return self.api
def get_declare_args(self): def get_input_tensor_args(self, inplace_flag=False):
return self.args_str['args_declare'] input_args = []
inplace_type_map = {
"const Tensor&": "Tensor&",
"const std::vector<Tensor>&": "std::vector<Tensor>&"
}
for name in self.inputs['names']:
name = name.split('@')[0]
if inplace_flag and name in self.inplace_map.values():
input_args.append(inplace_type_map[self.inputs['input_info'][
name]] + ' ' + name)
else:
input_args.append(self.inputs['input_info'][name] + ' ' + name)
return input_args
def get_declare_args(self, inplace_flag=False):
declare_args = self.get_input_tensor_args(inplace_flag)
for name in self.attrs['names']:
default_value = ''
if self.attrs['attr_info'][name][1] is not None:
default_value = ' = ' + self.attrs['attr_info'][name][1]
declare_args.append(self.attrs['attr_info'][name][0] + ' ' + name +
default_value)
def get_define_args(self): return ", ".join(declare_args)
return self.args_str["args_define"]
def get_define_args(self, inplace_flag=False):
define_args = self.get_input_tensor_args(inplace_flag)
for name in self.attrs['names']:
define_args.append(self.attrs['attr_info'][name][0] + ' ' + name)
return ", ".join(define_args)
def parse_args(self, api_name, api_item_yaml): def parse_args(self, api_name, api_item_yaml):
optional_vars = [] optional_vars = []
...@@ -72,16 +95,15 @@ class BaseAPI(object): ...@@ -72,16 +95,15 @@ class BaseAPI(object):
optional_vars = [ optional_vars = [
item.strip() for item in api_item_yaml['optional'].split(',') item.strip() for item in api_item_yaml['optional'].split(',')
] ]
inputs, attrs, args_str = self.parse_input_and_attr( inputs, attrs = self.parse_input_and_attr(
api_name, api_item_yaml['args'], optional_vars) api_name, api_item_yaml['args'], optional_vars)
output_type_list, output_names, out_size_expr, return_type = self.parse_output( output_type_list, output_names, out_size_expr = self.parse_output(
api_name, api_item_yaml['output']) api_name, api_item_yaml['output'])
return inputs, attrs, { return inputs, attrs, {
'names': output_names, 'names': output_names,
'types': output_type_list, 'types': output_type_list,
'out_size_expr': out_size_expr, 'out_size_expr': out_size_expr
'return_type': return_type }, optional_vars
}, args_str, optional_vars
def parse_input_and_attr(self, api_name, args_config, optional_vars=[]): def parse_input_and_attr(self, api_name, args_config, optional_vars=[]):
inputs = {'names': [], 'input_info': {}} inputs = {'names': [], 'input_info': {}}
...@@ -131,9 +153,6 @@ class BaseAPI(object): ...@@ -131,9 +153,6 @@ class BaseAPI(object):
'DataType': 'paddle::optional<DataType>' 'DataType': 'paddle::optional<DataType>'
} }
args_declare_str = ""
args_define_str = ""
for item in args_list: for item in args_list:
item = item.strip() item = item.strip()
type_and_name = item.split(' ') type_and_name = item.split(' ')
...@@ -152,8 +171,6 @@ class BaseAPI(object): ...@@ -152,8 +171,6 @@ class BaseAPI(object):
inputs['names'].append(input_name) inputs['names'].append(input_name)
inputs['input_info'][input_name] = in_type inputs['input_info'][input_name] = in_type
args_declare_str = args_declare_str + in_type + ' ' + input_name + ', '
args_define_str = args_define_str + in_type + ' ' + input_name + ', '
has_input = True has_input = True
break break
if has_input: if has_input:
...@@ -175,16 +192,11 @@ class BaseAPI(object): ...@@ -175,16 +192,11 @@ class BaseAPI(object):
attr_type = optional_types_trans[attr_type_symbol] attr_type = optional_types_trans[attr_type_symbol]
default_value_str = "" if default_value is None else '=' + default_value default_value_str = "" if default_value is None else '=' + default_value
args_declare_str = args_declare_str + attr_type + ' ' + attr_name + default_value_str + ', '
args_define_str = args_define_str + attr_type + ' ' + attr_name + ', '
attrs['names'].append(attr_name) attrs['names'].append(attr_name)
attrs['attr_info'][attr_name] = (attr_type, default_value) attrs['attr_info'][attr_name] = (attr_type, default_value)
break break
return inputs, attrs, { return inputs, attrs
'args_declare': args_declare_str[:-2],
'args_define': args_define_str[:-2]
}
def parse_output(self, api_name, output_config): def parse_output(self, api_name, output_config):
def parse_output_item(output_item): def parse_output_item(output_item):
...@@ -211,8 +223,7 @@ class BaseAPI(object): ...@@ -211,8 +223,7 @@ class BaseAPI(object):
if len(temp_list) == 1: if len(temp_list) == 1:
out_type, out_name, size_expr = parse_output_item(temp_list[0]) out_type, out_name, size_expr = parse_output_item(temp_list[0])
return [out_type], [out_name], size_expr, self.get_return_type( return [out_type], [out_name], size_expr
[out_type])
else: else:
out_type_list = [] out_type_list = []
out_name_list = [] out_name_list = []
...@@ -221,8 +232,7 @@ class BaseAPI(object): ...@@ -221,8 +232,7 @@ class BaseAPI(object):
out_type_list.append(out_type) out_type_list.append(out_type)
out_name_list.append(out_name) out_name_list.append(out_name)
return out_type_list, out_name_list, size_expr, self.get_return_type( return out_type_list, out_name_list, size_expr
out_type_list)
def parse_infer_meta(self, infer_meta_config): def parse_infer_meta(self, infer_meta_config):
infer_meta = infer_meta_config infer_meta = infer_meta_config
...@@ -285,7 +295,7 @@ class BaseAPI(object): ...@@ -285,7 +295,7 @@ class BaseAPI(object):
return data_transform return data_transform
def parse_inplace_and_view(self, api_item_yaml): def parse_inplace_and_view(self, api_item_yaml):
inplace_map, view_map = None, None inplace_map, view_map = {}, {}
for mode in ['inplace', 'view']: for mode in ['inplace', 'view']:
if mode in api_item_yaml: if mode in api_item_yaml:
if mode == 'inplace': if mode == 'inplace':
...@@ -310,17 +320,22 @@ class BaseAPI(object): ...@@ -310,17 +320,22 @@ class BaseAPI(object):
return inplace_map, view_map return inplace_map, view_map
# Override by child class # Override by child class
def get_return_type(self, out_type_list): def get_return_type(self, inplace_flag=False):
return None return None
def gene_api_declaration(self): def gene_api_declaration(self):
api_declaration = ""
api_func_name = self.get_api_func_name()
if api_func_name[-1] != '_':
api_declaration = f""" api_declaration = f"""
PADDLE_API {self.gene_return_type_code()} {self.get_api_func_name()}({self.get_declare_args()}); PADDLE_API {self.get_return_type()} {api_func_name}({self.get_declare_args()});
""" """
if self.is_base_api and self.inplace_map is not None: if self.is_base_api and len(self.inplace_map) > 0:
if api_func_name[-1] != '_':
api_func_name += '_'
api_declaration = api_declaration + f""" api_declaration = api_declaration + f"""
PADDLE_API {self.gene_return_type_code()} {self.get_api_func_name() + '_'}({self.get_declare_args()}); PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_declare_args(inplace_flag=True)});
""" """
return api_declaration return api_declaration
...@@ -714,10 +729,6 @@ PADDLE_API {self.gene_return_type_code()} {self.get_api_func_name() + '_'}({self ...@@ -714,10 +729,6 @@ PADDLE_API {self.gene_return_type_code()} {self.get_api_func_name() + '_'}({self
return input_tensor_code, kernel_args[:-2], kernel_signature return input_tensor_code, kernel_args[:-2], kernel_signature
# Override by child class
def gene_return_type_code(self):
return self.outputs['return_type']
# Override by child class # Override by child class
def gene_return_code(self): def gene_return_code(self):
return "return api_output;" return "return api_output;"
...@@ -786,9 +797,11 @@ PADDLE_API {self.gene_return_type_code()} {self.get_api_func_name() + '_'}({self ...@@ -786,9 +797,11 @@ PADDLE_API {self.gene_return_type_code()} {self.get_api_func_name() + '_'}({self
{code_indent} {self.gene_return_code()}""" {code_indent} {self.gene_return_code()}"""
def gene_base_api_code(self, inplace_flag=False): def gene_base_api_code(self, inplace_flag=False):
api_func_name = self.get_api_func_name() + ('_' if inplace_flag else '') api_func_name = self.get_api_func_name()
if inplace_flag and api_func_name[-1] != '_':
api_func_name += '_'
api_code = f""" api_code = f"""
PADDLE_API {self.gene_return_type_code()} {api_func_name}({self.get_define_args()}) {{ PADDLE_API {self.get_return_type(inplace_flag)} {api_func_name}({self.get_define_args(inplace_flag)}) {{
{self.gene_kernel_select()} {self.gene_kernel_select()}
""" """
...@@ -812,14 +825,14 @@ PADDLE_API {self.gene_return_type_code()} {api_func_name}({self.get_define_args( ...@@ -812,14 +825,14 @@ PADDLE_API {self.gene_return_type_code()} {api_func_name}({self.get_define_args(
def gene_invoke_code(self, invoke_code, params_code): def gene_invoke_code(self, invoke_code, params_code):
return f""" return f"""
PADDLE_API {self.outputs['return_type']} {self.api}({params_code}) {{ PADDLE_API {self.get_return_type()} {self.api}({params_code}) {{
return {invoke_code}; return {invoke_code};
}}""" }}"""
def gene_api_code(self): def gene_api_code(self):
if self.is_base_api: if self.is_base_api:
api_code = self.gene_base_api_code() api_code = self.gene_base_api_code()
if self.inplace_map is not None: if len(self.inplace_map) > 0:
api_code = api_code + self.gene_base_api_code(inplace_flag=True) api_code = api_code + self.gene_base_api_code(inplace_flag=True)
return api_code return api_code
......
...@@ -19,6 +19,11 @@ import re ...@@ -19,6 +19,11 @@ import re
from api_base import BaseAPI, PREFIX_TENSOR_NAME from api_base import BaseAPI, PREFIX_TENSOR_NAME
inplace_out_type_map = {
"Tensor": "Tensor&",
"std::vector<Tensor>": "std::vector<Tensor>&"
}
class ForwardAPI(BaseAPI): class ForwardAPI(BaseAPI):
def __init__(self, api_item_yaml): def __init__(self, api_item_yaml):
...@@ -42,22 +47,33 @@ class ForwardAPI(BaseAPI): ...@@ -42,22 +47,33 @@ class ForwardAPI(BaseAPI):
else: else:
return False, [] return False, []
def get_return_type(self, out_type_list): def get_return_type_with_intermediate(self, inplace_flag=False):
return out_type_list[0] if len( out_type_list = []
out_type_list) == 1 else "std::tuple<" + ",".join( for i, out_type in enumerate(self.outputs['types']):
out_type_list) + ">" out_name = self.outputs['names'][i].split('@')[0]
if inplace_flag and out_name in self.inplace_map:
out_type_list.append(inplace_out_type_map[out_type])
else:
out_type_list.append(out_type)
def gene_return_type_code(self): if len(out_type_list) == 1:
if self.is_dygraph_api or len(self.intermediate_outs) == 0: return out_type_list[0]
return self.outputs['return_type']
else: else:
return_out_list = [] return "std::tuple<" + ", ".join(out_type_list) + ">"
for i, name in enumerate(self.outputs['names']):
if name.split('@')[0] not in self.intermediate_outs: def get_return_type(self, inplace_flag=False):
return_out_list.append(self.outputs['types'][i]) out_type_list = []
return return_out_list[0] if len( for i, out_type in enumerate(self.outputs['types']):
return_out_list) == 1 else "std::tuple<" + ",".join( out_name = self.outputs['names'][i].split('@')[0]
return_out_list) + ">" if inplace_flag and out_name in self.inplace_map:
out_type_list.append(inplace_out_type_map[out_type])
elif self.is_dygraph_api or out_name not in self.intermediate_outs:
out_type_list.append(out_type)
if len(out_type_list) == 1:
return out_type_list[0]
else:
return "std::tuple<" + ", ".join(out_type_list) + ">"
def gene_return_code(self): def gene_return_code(self):
if self.is_dygraph_api or len(self.intermediate_outs) == 0: if self.is_dygraph_api or len(self.intermediate_outs) == 0:
...@@ -83,17 +99,18 @@ class ForwardAPI(BaseAPI): ...@@ -83,17 +99,18 @@ class ForwardAPI(BaseAPI):
kernel_output = "" kernel_output = ""
output_names = [] output_names = []
output_create = "" output_create = ""
return_type = self.get_return_type_with_intermediate(inplace_flag)
if len(output_type_list) == 1: if len(output_type_list) == 1:
kernel_output = 'kernel_out' kernel_output = 'kernel_out'
output_names.append('kernel_out') output_names.append('kernel_out')
inplace_assign = " = " + self.inplace_map[self.outputs['names'][ inplace_assign = " = " + self.inplace_map[self.outputs['names'][
0]] if inplace_flag and self.inplace_map is not None and self.outputs[ 0]] if inplace_flag and self.outputs['names'][
'names'][0] in self.inplace_map else "" 0] in self.inplace_map else ""
output_create = f""" output_create = f"""
{code_indent} {self.outputs['return_type']} api_output{inplace_assign};""" {code_indent} {return_type} api_output{inplace_assign};"""
if self.outputs['return_type'] == 'std::vector<Tensor>': if return_type == 'std::vector<Tensor>':
assert self.outputs['out_size_expr'] is not None, \ assert self.outputs['out_size_expr'] is not None, \
f"{api_name}: The out size expr : '{{expr}}' should be set when output has Tensor[]. You can refer 'split' api." f"{api_name}: The out size expr : '{{expr}}' should be set when output has Tensor[]. You can refer 'split' api."
output_create = output_create + f""" output_create = output_create + f"""
...@@ -112,15 +129,23 @@ class ForwardAPI(BaseAPI): ...@@ -112,15 +129,23 @@ class ForwardAPI(BaseAPI):
elif len(output_type_list) > 1: elif len(output_type_list) > 1:
output_create = f""" output_create = f"""
{code_indent} {self.outputs['return_type']} api_output;""" {code_indent} {return_type} api_output;"""
if inplace_flag:
output_create = f"""
{code_indent} {return_type} api_output{{"""
for out_name in self.outputs['names']:
if out_name in self.inplace_map:
output_create = output_create + self.inplace_map[
out_name] + ', '
else:
output_create += 'Tensor(), '
output_create = output_create[:-2] + '};'
for i in range(len(output_type_list)): for i in range(len(output_type_list)):
kernel_output = kernel_output + f'kernel_out_{i}, ' kernel_output = kernel_output + f'kernel_out_{i}, '
output_names.append(f'kernel_out_{i}') output_names.append(f'kernel_out_{i}')
if inplace_flag and self.inplace_map is not None and self.outputs[
'names'][i] in self.inplace_map:
output_create = output_create + f"""
{code_indent} std::get<{i}>(api_output) = {self.inplace_map[self.outputs['names'][i]]};"""
if output_type_list[i] == 'std::vector<Tensor>': if output_type_list[i] == 'std::vector<Tensor>':
assert self.outputs['out_size_expr'][i] is not None, \ assert self.outputs['out_size_expr'][i] is not None, \
......
...@@ -35,10 +35,10 @@ class BackwardAPI(BaseAPI): ...@@ -35,10 +35,10 @@ class BackwardAPI(BaseAPI):
r"(?P<api>[a-z][a-z0-9_]+)\s*(?P<args>\([^\)]+\))\s*->\s*(?P<outputs>.+)", r"(?P<api>[a-z][a-z0-9_]+)\s*(?P<args>\([^\)]+\))\s*->\s*(?P<outputs>.+)",
forward_config) forward_config)
api = result.group('api') api = result.group('api')
_, outputs, _, _ = self.parse_output(self.api, result.group('outputs')) _, outputs, _, = self.parse_output(self.api, result.group('outputs'))
outputs = [item.split('@')[0] for item in outputs] outputs = [item.split('@')[0] for item in outputs]
fw_inputs, fw_attrs, _, = self.parse_input_and_attr( fw_inputs, fw_attrs = self.parse_input_and_attr(api,
api, result.group('args')) result.group('args'))
return api, fw_inputs, fw_attrs, outputs return api, fw_inputs, fw_attrs, outputs
...@@ -77,15 +77,15 @@ class BackwardAPI(BaseAPI): ...@@ -77,15 +77,15 @@ class BackwardAPI(BaseAPI):
f"{self.api} : Output error: The number of outputs should be less then the number of inputs of forward api. \ f"{self.api} : Output error: The number of outputs should be less then the number of inputs of forward api. \
Please check the output of {self.api} in yaml." Please check the output of {self.api} in yaml."
def get_declare_args(self): def get_declare_args(self, inplace_flag=False):
return self.get_define_args() return self.get_define_args()
def get_define_args(self): def get_define_args(self, inplace_flag=False):
out_type_map = { out_type_map = {
'Tensor': 'Tensor*', 'Tensor': 'Tensor*',
'std::vector<Tensor>': 'std::vector<Tensor*>' 'std::vector<Tensor>': 'std::vector<Tensor*>'
} }
intputs_and_attrs = self.args_str['args_define'] intputs_and_attrs = super(BackwardAPI, self).get_define_args()
outs = [] outs = []
for i, name in enumerate(self.outputs['names']): for i, name in enumerate(self.outputs['names']):
outs.append(out_type_map[self.outputs['types'][i]] + ' ' + outs.append(out_type_map[self.outputs['types'][i]] + ' ' +
...@@ -109,7 +109,7 @@ class BackwardAPI(BaseAPI): ...@@ -109,7 +109,7 @@ class BackwardAPI(BaseAPI):
else: else:
return super().gene_kernel_backend_select() return super().gene_kernel_backend_select()
def get_return_type(self, out_type_list): def get_return_type(self, inplace_flag=False):
return 'void' return 'void'
def gene_output(self, def gene_output(self,
...@@ -176,13 +176,13 @@ class BackwardAPI(BaseAPI): ...@@ -176,13 +176,13 @@ class BackwardAPI(BaseAPI):
if inveke_func_name.endswith('_grad') or inveke_func_name.endswith( if inveke_func_name.endswith('_grad') or inveke_func_name.endswith(
'_grad_impl'): '_grad_impl'):
return f""" return f"""
PADDLE_API {self.outputs['return_type']} {self.api}({params_code}) {{ PADDLE_API {self.get_return_type()} {self.api}({params_code}) {{
{invoke_code}; {invoke_code};
}}""" }}"""
else: else:
return f""" return f"""
PADDLE_API {self.outputs['return_type']} {self.api}({params_code}) {{ PADDLE_API {self.get_return_type()} {self.api}({params_code}) {{
*{self.outputs['names'][0].split('@')[0]} = {invoke_code}; *{self.outputs['names'][0].split('@')[0]} = {invoke_code};
}}""" }}"""
......
...@@ -25,9 +25,10 @@ class SparseAPI(ForwardAPI): ...@@ -25,9 +25,10 @@ class SparseAPI(ForwardAPI):
super(SparseAPI, self).__init__(api_item_yaml) super(SparseAPI, self).__init__(api_item_yaml)
def gene_api_declaration(self): def gene_api_declaration(self):
api_declaration = "// " + ', '.join(self.outputs['names']) return f"""
return api_declaration + super(SparseAPI, // {", ".join(self.outputs['names'])}
self).gene_api_declaration() + '\n' {super(SparseAPI, self).gene_api_declaration()}
"""
def get_kernel_tensor_out_type(self, output_name): def get_kernel_tensor_out_type(self, output_name):
sparse_type = 'TensorType::DENSE_TENSOR' sparse_type = 'TensorType::DENSE_TENSOR'
...@@ -45,6 +46,7 @@ class SparseAPI(ForwardAPI): ...@@ -45,6 +46,7 @@ class SparseAPI(ForwardAPI):
kernel_output = "" kernel_output = ""
output_names = [] output_names = []
output_create = "" output_create = ""
return_type = self.get_return_type_with_intermediate(inplace_flag)
if len(output_type_list) == 1: if len(output_type_list) == 1:
kernel_output = 'kernel_out' kernel_output = 'kernel_out'
...@@ -53,21 +55,29 @@ class SparseAPI(ForwardAPI): ...@@ -53,21 +55,29 @@ class SparseAPI(ForwardAPI):
0]] if inplace_flag and self.inplace_map is not None and self.outputs[ 0]] if inplace_flag and self.inplace_map is not None and self.outputs[
'names'][0] in self.inplace_map else "" 'names'][0] in self.inplace_map else ""
output_create = f""" output_create = f"""
{self.outputs['return_type']} api_output{inplace_assign}; {return_type} api_output{inplace_assign};
auto* kernel_out = {set_out_func}(&api_output, {self.get_kernel_tensor_out_type(self.outputs['names'][0])});""" auto* kernel_out = {set_out_func}(&api_output, {self.get_kernel_tensor_out_type(self.outputs['names'][0])});"""
elif len(output_type_list) > 1: elif len(output_type_list) > 1:
output_create = f""" output_create = f"""
{self.outputs['return_type']} api_output;""" {return_type} api_output;"""
if inplace_flag:
output_create = f"""
{return_type} api_output{{"""
for out_name in self.outputs['names']:
out_name = out_name.split('@')[0]
if out_name in self.inplace_map:
output_create = output_create + self.inplace_map[
out_name] + ', '
else:
output_create += 'Tensor(), '
output_create = output_create[:-2] + '};'
for i in range(len(output_type_list)): for i in range(len(output_type_list)):
kernel_output = kernel_output + f'kernel_out_{i}, ' kernel_output = kernel_output + f'kernel_out_{i}, '
output_names.append(f'kernel_out_{i}') output_names.append(f'kernel_out_{i}')
if inplace_flag and self.inplace_map is not None and self.outputs[
'names'][i] in self.inplace_map:
output_create = output_create + f"""
std::get<{i}>(api_output) = {self.inplace_map[self.outputs['names'][i]]};"""
output_create = output_create + f""" output_create = output_create + f"""
auto* kernel_out_{i} = {set_out_func}(&std::get<{i}>(api_output), {self.get_kernel_tensor_out_type(self.outputs['names'][i])});""" auto* kernel_out_{i} = {set_out_func}(&std::get<{i}>(api_output), {self.get_kernel_tensor_out_type(self.outputs['names'][i])});"""
...@@ -151,8 +161,11 @@ class SparseAPI(ForwardAPI): ...@@ -151,8 +161,11 @@ class SparseAPI(ForwardAPI):
{return_code}""" {return_code}"""
def gene_base_api_code(self, inplace_flag=False): def gene_base_api_code(self, inplace_flag=False):
api_func_name = self.get_api_func_name()
if inplace_flag and api_func_name[-1] != '_':
api_func_name += '_'
return f""" return f"""
PADDLE_API {self.gene_return_type_code()} {self.get_api_func_name()}({self.get_define_args()}) {{ PADDLE_API {self.get_return_type()} {api_func_name}({self.get_define_args()}) {{
{self.gene_kernel_select()} {self.gene_kernel_select()}
{self.gen_sparse_kernel_code(inplace_flag)} {self.gen_sparse_kernel_code(inplace_flag)}
}} }}
......
...@@ -31,11 +31,8 @@ class SparseBackwardAPI(SparseAPI, BackwardAPI): ...@@ -31,11 +31,8 @@ class SparseBackwardAPI(SparseAPI, BackwardAPI):
def gene_kernel_backend_select(self): def gene_kernel_backend_select(self):
return BackwardAPI.gene_kernel_backend_select(self) return BackwardAPI.gene_kernel_backend_select(self)
def get_return_type(self, out_type_list): def get_return_type(self, inplace_flag=False):
return BackwardAPI.get_return_type(self, out_type_list) return BackwardAPI.get_return_type(self)
def gene_return_type_code(self):
return self.outputs['return_type']
def gene_return_code(self): def gene_return_code(self):
return "" return ""
...@@ -43,10 +40,10 @@ class SparseBackwardAPI(SparseAPI, BackwardAPI): ...@@ -43,10 +40,10 @@ class SparseBackwardAPI(SparseAPI, BackwardAPI):
def gene_api_declaration(self): def gene_api_declaration(self):
return SparseAPI.gene_api_declaration(self) return SparseAPI.gene_api_declaration(self)
def get_declare_args(self): def get_declare_args(self, inplace_flag=False):
return BackwardAPI.get_declare_args(self) return BackwardAPI.get_declare_args(self)
def get_define_args(self): def get_define_args(self, inplace_flag=False):
return BackwardAPI.get_define_args(self) return BackwardAPI.get_define_args(self)
def gene_output(self, def gene_output(self,
......
...@@ -32,7 +32,7 @@ class StringsAPI(ForwardAPI): ...@@ -32,7 +32,7 @@ class StringsAPI(ForwardAPI):
def gene_api_declaration(self): def gene_api_declaration(self):
return f""" return f"""
// {", ".join(self.outputs['names'])} // {", ".join(self.outputs['names'])}
PADDLE_API {self.outputs['return_type']} {self.get_api_func_name()}({self.args_str['args_declare']}); {super(StringsAPI, self).gene_api_declaration()}
""" """
def get_kernel_tensor_out_type(self, output_name): def get_kernel_tensor_out_type(self, output_name):
...@@ -56,6 +56,7 @@ PADDLE_API {self.outputs['return_type']} {self.get_api_func_name()}({self.args_s ...@@ -56,6 +56,7 @@ PADDLE_API {self.outputs['return_type']} {self.get_api_func_name()}({self.args_s
kernel_output = "" kernel_output = ""
output_names = [] output_names = []
output_create = "" output_create = ""
return_type = self.get_return_type(inplace_flag)
if len(output_type_list) == 1: if len(output_type_list) == 1:
kernel_output = 'kernel_out' kernel_output = 'kernel_out'
...@@ -67,13 +68,12 @@ PADDLE_API {self.outputs['return_type']} {self.get_api_func_name()}({self.args_s ...@@ -67,13 +68,12 @@ PADDLE_API {self.outputs['return_type']} {self.get_api_func_name()}({self.args_s
0]] if inplace_flag and self.inplace_map is not None and self.outputs[ 0]] if inplace_flag and self.inplace_map is not None and self.outputs[
'names'][0] in self.inplace_map else "" 'names'][0] in self.inplace_map else ""
output_create = f""" output_create = f"""
{self.outputs['return_type']} api_output{inplace_assign}; {return_type} api_output{inplace_assign};
{tensor_type}* kernel_out = dynamic_cast<{tensor_type}*>({set_out_func}(kernel_backend, &api_output, {kernel_tensor_out_type}));""" {tensor_type}* kernel_out = dynamic_cast<{tensor_type}*>({set_out_func}(kernel_backend, &api_output, {kernel_tensor_out_type}));"""
elif len(output_type_list) > 1: elif len(output_type_list) > 1:
output_create = f""" output_create = f"""
{self.outputs['return_type']} api_output;""" {return_type} api_output;"""
for i in range(len(output_type_list)): for i in range(len(output_type_list)):
kernel_output = kernel_output + f'kernel_out_{i}, ' kernel_output = kernel_output + f'kernel_out_{i}, '
...@@ -264,7 +264,7 @@ PADDLE_API {self.outputs['return_type']} {self.get_api_func_name()}({self.args_s ...@@ -264,7 +264,7 @@ PADDLE_API {self.outputs['return_type']} {self.get_api_func_name()}({self.args_s
def gene_base_api_code(self, inplace_flag=False): def gene_base_api_code(self, inplace_flag=False):
api_func_name = self.get_api_func_name() api_func_name = self.get_api_func_name()
return f""" return f"""
PADDLE_API {self.outputs['return_type']} {api_func_name}({self.args_str["args_define"]}) {{ PADDLE_API {self.get_return_type(inplace_flag)} {api_func_name}({self.get_define_args(inplace_flag)}) {{
{self.gene_kernel_select()} {self.gene_kernel_select()}
{self.gen_string_tensor_kernel_code(inplace_flag)} {self.gen_string_tensor_kernel_code(inplace_flag)}
}} }}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册