未验证 提交 5924458b 编写于 作者: Z zyfncg 提交者: GitHub

[Phi] Refactor format of inplace C++ api (#42735)

* update code

* change the return type for inplace dygraph api

* change the tuple construct
上级 566ccfef
......@@ -902,7 +902,10 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
else:
function_name = GetIntermediateAPIFunctionName(function_name)
forward_call_str = f"{indent}auto api_result = paddle::experimental::{namespace}{function_name}({inputs_call_args_str});"
api_out_type = "auto"
if is_inplaced and len(forward_outputs_position_map) == 1:
api_out_type = "auto&"
forward_call_str = f"{indent}{api_out_type} api_result = paddle::experimental::{namespace}{function_name}({inputs_call_args_str});"
num_outputs = len(forward_outputs_position_map.keys()) - len(
intermediate_outputs)
......@@ -923,9 +926,16 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
returns_list[pos] = f"{name}"
if IsPlainTensorType(rtype):
if is_inplaced and inplace_map and name in inplace_map.values():
returns_type_list[pos] = "paddle::experimental::Tensor&"
else:
returns_type_list[pos] = "paddle::experimental::Tensor"
else:
assert IsVectorTensorType(rtype)
if is_inplaced and inplace_map and name in inplace_map.values():
returns_type_list[
pos] = "std::vector<paddle::experimental::Tensor>&"
else:
returns_type_list[
pos] = "std::vector<paddle::experimental::Tensor>"
......@@ -936,7 +946,7 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
returns_type_str = ", ".join(returns_type_list)
returns_type_str = f"std::tuple<{returns_type_str}>"
returns_str = ", ".join(returns_list)
returns_str = f"std::make_tuple({returns_str})"
returns_str = f"{returns_type_str}{{{returns_str}}}"
# Node Creation Pre-Processing
# 1. Get Input AutoGradMeta
......
......@@ -100,7 +100,7 @@ static PyObject * eager_final_state_api_{}(PyObject *self, PyObject *args, PyObj
// Set Device ID
{}
auto out = {}({});
decltype({}({})) out = {}({});
PyEval_RestoreThread(tstate);
tstate = nullptr;
......@@ -344,7 +344,8 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase):
self.python_c_function_str = PYTHON_C_FUNCTION_TEMPLATE.format(
forward_api_name, pythonc_record_event_str, forward_api_name,
get_eager_tensor_str, parse_attributes_str, set_device_str,
fwd_function_name, dygraph_function_call_str, return_str)
fwd_function_name, dygraph_function_call_str, fwd_function_name,
dygraph_function_call_str, return_str)
# Set prefix of forward_api_name to avoid conflicts
prefix = self.namespace.strip("::")
......@@ -380,6 +381,7 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase):
inplaced_forward_api_name, get_eager_tensor_str,
parse_attributes_str, set_device_str,
inplaced_fwd_function_name, dygraph_function_call_str,
inplaced_fwd_function_name, dygraph_function_call_str,
return_str)
# Generate Python-C Function Registration
......
......@@ -714,6 +714,7 @@
backend : x
inplace : (x -> out)
view : (x -> out)
# intermediate : xshape
backward : flatten_grad
# flip
......
......@@ -32,11 +32,7 @@ class BaseAPI(object):
# names : [], list of output names
# types : [], list of output types
# out_size_expr : [], expression for getting size of vector<Tensor>
# return_type : Tensor, vector<Tensor>, ..., the return type of api
# args_str:
# args_declare : "str" // str of function params with default value. Example: (..., bool flag=false)
# args_define : "str" // str of function params without default value. Example: (..., bool flag)
self.inputs, self.attrs, self.outputs, self.args_str, self.optional_vars = self.parse_args(
self.inputs, self.attrs, self.outputs, self.optional_vars = self.parse_args(
self.api, api_item_yaml)
self.is_base_api = True
......@@ -60,11 +56,38 @@ class BaseAPI(object):
def get_api_func_name(self):
return self.api
def get_declare_args(self):
return self.args_str['args_declare']
def get_input_tensor_args(self, inplace_flag=False):
input_args = []
inplace_type_map = {
"const Tensor&": "Tensor&",
"const std::vector<Tensor>&": "std::vector<Tensor>&"
}
for name in self.inputs['names']:
name = name.split('@')[0]
if inplace_flag and name in self.inplace_map.values():
input_args.append(inplace_type_map[self.inputs['input_info'][
name]] + ' ' + name)
else:
input_args.append(self.inputs['input_info'][name] + ' ' + name)
return input_args
def get_declare_args(self, inplace_flag=False):
declare_args = self.get_input_tensor_args(inplace_flag)
for name in self.attrs['names']:
default_value = ''
if self.attrs['attr_info'][name][1] is not None:
default_value = ' = ' + self.attrs['attr_info'][name][1]
declare_args.append(self.attrs['attr_info'][name][0] + ' ' + name +
default_value)
def get_define_args(self):
return self.args_str["args_define"]
return ", ".join(declare_args)
def get_define_args(self, inplace_flag=False):
define_args = self.get_input_tensor_args(inplace_flag)
for name in self.attrs['names']:
define_args.append(self.attrs['attr_info'][name][0] + ' ' + name)
return ", ".join(define_args)
def parse_args(self, api_name, api_item_yaml):
optional_vars = []
......@@ -72,16 +95,15 @@ class BaseAPI(object):
optional_vars = [
item.strip() for item in api_item_yaml['optional'].split(',')
]
inputs, attrs, args_str = self.parse_input_and_attr(
inputs, attrs = self.parse_input_and_attr(
api_name, api_item_yaml['args'], optional_vars)
output_type_list, output_names, out_size_expr, return_type = self.parse_output(
output_type_list, output_names, out_size_expr = self.parse_output(
api_name, api_item_yaml['output'])
return inputs, attrs, {
'names': output_names,
'types': output_type_list,
'out_size_expr': out_size_expr,
'return_type': return_type
}, args_str, optional_vars
'out_size_expr': out_size_expr
}, optional_vars
def parse_input_and_attr(self, api_name, args_config, optional_vars=[]):
inputs = {'names': [], 'input_info': {}}
......@@ -131,9 +153,6 @@ class BaseAPI(object):
'DataType': 'paddle::optional<DataType>'
}
args_declare_str = ""
args_define_str = ""
for item in args_list:
item = item.strip()
type_and_name = item.split(' ')
......@@ -152,8 +171,6 @@ class BaseAPI(object):
inputs['names'].append(input_name)
inputs['input_info'][input_name] = in_type
args_declare_str = args_declare_str + in_type + ' ' + input_name + ', '
args_define_str = args_define_str + in_type + ' ' + input_name + ', '
has_input = True
break
if has_input:
......@@ -175,16 +192,11 @@ class BaseAPI(object):
attr_type = optional_types_trans[attr_type_symbol]
default_value_str = "" if default_value is None else '=' + default_value
args_declare_str = args_declare_str + attr_type + ' ' + attr_name + default_value_str + ', '
args_define_str = args_define_str + attr_type + ' ' + attr_name + ', '
attrs['names'].append(attr_name)
attrs['attr_info'][attr_name] = (attr_type, default_value)
break
return inputs, attrs, {
'args_declare': args_declare_str[:-2],
'args_define': args_define_str[:-2]
}
return inputs, attrs
def parse_output(self, api_name, output_config):
def parse_output_item(output_item):
......@@ -211,8 +223,7 @@ class BaseAPI(object):
if len(temp_list) == 1:
out_type, out_name, size_expr = parse_output_item(temp_list[0])
return [out_type], [out_name], size_expr, self.get_return_type(
[out_type])
return [out_type], [out_name], size_expr
else:
out_type_list = []
out_name_list = []
......@@ -221,8 +232,7 @@ class BaseAPI(object):
out_type_list.append(out_type)
out_name_list.append(out_name)
return out_type_list, out_name_list, size_expr, self.get_return_type(
out_type_list)
return out_type_list, out_name_list, size_expr
def parse_infer_meta(self, infer_meta_config):
infer_meta = infer_meta_config
......@@ -285,7 +295,7 @@ class BaseAPI(object):
return data_transform
def parse_inplace_and_view(self, api_item_yaml):
inplace_map, view_map = None, None
inplace_map, view_map = {}, {}
for mode in ['inplace', 'view']:
if mode in api_item_yaml:
if mode == 'inplace':
......@@ -310,17 +320,22 @@ class BaseAPI(object):
return inplace_map, view_map
# Override by child class
def get_return_type(self, out_type_list):
def get_return_type(self, inplace_flag=False):
return None
def gene_api_declaration(self):
api_declaration = ""
api_func_name = self.get_api_func_name()
if api_func_name[-1] != '_':
api_declaration = f"""
PADDLE_API {self.gene_return_type_code()} {self.get_api_func_name()}({self.get_declare_args()});
PADDLE_API {self.get_return_type()} {api_func_name}({self.get_declare_args()});
"""
if self.is_base_api and self.inplace_map is not None:
if self.is_base_api and len(self.inplace_map) > 0:
if api_func_name[-1] != '_':
api_func_name += '_'
api_declaration = api_declaration + f"""
PADDLE_API {self.gene_return_type_code()} {self.get_api_func_name() + '_'}({self.get_declare_args()});
PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_declare_args(inplace_flag=True)});
"""
return api_declaration
......@@ -714,10 +729,6 @@ PADDLE_API {self.gene_return_type_code()} {self.get_api_func_name() + '_'}({self
return input_tensor_code, kernel_args[:-2], kernel_signature
# Override by child class
def gene_return_type_code(self):
return self.outputs['return_type']
# Override by child class
def gene_return_code(self):
return "return api_output;"
......@@ -786,9 +797,11 @@ PADDLE_API {self.gene_return_type_code()} {self.get_api_func_name() + '_'}({self
{code_indent} {self.gene_return_code()}"""
def gene_base_api_code(self, inplace_flag=False):
api_func_name = self.get_api_func_name() + ('_' if inplace_flag else '')
api_func_name = self.get_api_func_name()
if inplace_flag and api_func_name[-1] != '_':
api_func_name += '_'
api_code = f"""
PADDLE_API {self.gene_return_type_code()} {api_func_name}({self.get_define_args()}) {{
PADDLE_API {self.get_return_type(inplace_flag)} {api_func_name}({self.get_define_args(inplace_flag)}) {{
{self.gene_kernel_select()}
"""
......@@ -812,14 +825,14 @@ PADDLE_API {self.gene_return_type_code()} {api_func_name}({self.get_define_args(
def gene_invoke_code(self, invoke_code, params_code):
return f"""
PADDLE_API {self.outputs['return_type']} {self.api}({params_code}) {{
PADDLE_API {self.get_return_type()} {self.api}({params_code}) {{
return {invoke_code};
}}"""
def gene_api_code(self):
if self.is_base_api:
api_code = self.gene_base_api_code()
if self.inplace_map is not None:
if len(self.inplace_map) > 0:
api_code = api_code + self.gene_base_api_code(inplace_flag=True)
return api_code
......
......@@ -19,6 +19,11 @@ import re
from api_base import BaseAPI, PREFIX_TENSOR_NAME
inplace_out_type_map = {
"Tensor": "Tensor&",
"std::vector<Tensor>": "std::vector<Tensor>&"
}
class ForwardAPI(BaseAPI):
def __init__(self, api_item_yaml):
......@@ -42,22 +47,33 @@ class ForwardAPI(BaseAPI):
else:
return False, []
def get_return_type(self, out_type_list):
return out_type_list[0] if len(
out_type_list) == 1 else "std::tuple<" + ",".join(
out_type_list) + ">"
def get_return_type_with_intermediate(self, inplace_flag=False):
out_type_list = []
for i, out_type in enumerate(self.outputs['types']):
out_name = self.outputs['names'][i].split('@')[0]
if inplace_flag and out_name in self.inplace_map:
out_type_list.append(inplace_out_type_map[out_type])
else:
out_type_list.append(out_type)
def gene_return_type_code(self):
if self.is_dygraph_api or len(self.intermediate_outs) == 0:
return self.outputs['return_type']
if len(out_type_list) == 1:
return out_type_list[0]
else:
return_out_list = []
for i, name in enumerate(self.outputs['names']):
if name.split('@')[0] not in self.intermediate_outs:
return_out_list.append(self.outputs['types'][i])
return return_out_list[0] if len(
return_out_list) == 1 else "std::tuple<" + ",".join(
return_out_list) + ">"
return "std::tuple<" + ", ".join(out_type_list) + ">"
def get_return_type(self, inplace_flag=False):
out_type_list = []
for i, out_type in enumerate(self.outputs['types']):
out_name = self.outputs['names'][i].split('@')[0]
if inplace_flag and out_name in self.inplace_map:
out_type_list.append(inplace_out_type_map[out_type])
elif self.is_dygraph_api or out_name not in self.intermediate_outs:
out_type_list.append(out_type)
if len(out_type_list) == 1:
return out_type_list[0]
else:
return "std::tuple<" + ", ".join(out_type_list) + ">"
def gene_return_code(self):
if self.is_dygraph_api or len(self.intermediate_outs) == 0:
......@@ -83,17 +99,18 @@ class ForwardAPI(BaseAPI):
kernel_output = ""
output_names = []
output_create = ""
return_type = self.get_return_type_with_intermediate(inplace_flag)
if len(output_type_list) == 1:
kernel_output = 'kernel_out'
output_names.append('kernel_out')
inplace_assign = " = " + self.inplace_map[self.outputs['names'][
0]] if inplace_flag and self.inplace_map is not None and self.outputs[
'names'][0] in self.inplace_map else ""
0]] if inplace_flag and self.outputs['names'][
0] in self.inplace_map else ""
output_create = f"""
{code_indent} {self.outputs['return_type']} api_output{inplace_assign};"""
{code_indent} {return_type} api_output{inplace_assign};"""
if self.outputs['return_type'] == 'std::vector<Tensor>':
if return_type == 'std::vector<Tensor>':
assert self.outputs['out_size_expr'] is not None, \
f"{api_name}: The out size expr : '{{expr}}' should be set when output has Tensor[]. You can refer 'split' api."
output_create = output_create + f"""
......@@ -112,15 +129,23 @@ class ForwardAPI(BaseAPI):
elif len(output_type_list) > 1:
output_create = f"""
{code_indent} {self.outputs['return_type']} api_output;"""
{code_indent} {return_type} api_output;"""
if inplace_flag:
output_create = f"""
{code_indent} {return_type} api_output{{"""
for out_name in self.outputs['names']:
if out_name in self.inplace_map:
output_create = output_create + self.inplace_map[
out_name] + ', '
else:
output_create += 'Tensor(), '
output_create = output_create[:-2] + '};'
for i in range(len(output_type_list)):
kernel_output = kernel_output + f'kernel_out_{i}, '
output_names.append(f'kernel_out_{i}')
if inplace_flag and self.inplace_map is not None and self.outputs[
'names'][i] in self.inplace_map:
output_create = output_create + f"""
{code_indent} std::get<{i}>(api_output) = {self.inplace_map[self.outputs['names'][i]]};"""
if output_type_list[i] == 'std::vector<Tensor>':
assert self.outputs['out_size_expr'][i] is not None, \
......
......@@ -35,10 +35,10 @@ class BackwardAPI(BaseAPI):
r"(?P<api>[a-z][a-z0-9_]+)\s*(?P<args>\([^\)]+\))\s*->\s*(?P<outputs>.+)",
forward_config)
api = result.group('api')
_, outputs, _, _ = self.parse_output(self.api, result.group('outputs'))
_, outputs, _, = self.parse_output(self.api, result.group('outputs'))
outputs = [item.split('@')[0] for item in outputs]
fw_inputs, fw_attrs, _, = self.parse_input_and_attr(
api, result.group('args'))
fw_inputs, fw_attrs = self.parse_input_and_attr(api,
result.group('args'))
return api, fw_inputs, fw_attrs, outputs
......@@ -77,15 +77,15 @@ class BackwardAPI(BaseAPI):
f"{self.api} : Output error: The number of outputs should be less then the number of inputs of forward api. \
Please check the output of {self.api} in yaml."
def get_declare_args(self):
def get_declare_args(self, inplace_flag=False):
return self.get_define_args()
def get_define_args(self):
def get_define_args(self, inplace_flag=False):
out_type_map = {
'Tensor': 'Tensor*',
'std::vector<Tensor>': 'std::vector<Tensor*>'
}
intputs_and_attrs = self.args_str['args_define']
intputs_and_attrs = super(BackwardAPI, self).get_define_args()
outs = []
for i, name in enumerate(self.outputs['names']):
outs.append(out_type_map[self.outputs['types'][i]] + ' ' +
......@@ -109,7 +109,7 @@ class BackwardAPI(BaseAPI):
else:
return super().gene_kernel_backend_select()
def get_return_type(self, out_type_list):
def get_return_type(self, inplace_flag=False):
return 'void'
def gene_output(self,
......@@ -176,13 +176,13 @@ class BackwardAPI(BaseAPI):
if inveke_func_name.endswith('_grad') or inveke_func_name.endswith(
'_grad_impl'):
return f"""
PADDLE_API {self.outputs['return_type']} {self.api}({params_code}) {{
PADDLE_API {self.get_return_type()} {self.api}({params_code}) {{
{invoke_code};
}}"""
else:
return f"""
PADDLE_API {self.outputs['return_type']} {self.api}({params_code}) {{
PADDLE_API {self.get_return_type()} {self.api}({params_code}) {{
*{self.outputs['names'][0].split('@')[0]} = {invoke_code};
}}"""
......
......@@ -25,9 +25,10 @@ class SparseAPI(ForwardAPI):
super(SparseAPI, self).__init__(api_item_yaml)
def gene_api_declaration(self):
api_declaration = "// " + ', '.join(self.outputs['names'])
return api_declaration + super(SparseAPI,
self).gene_api_declaration() + '\n'
return f"""
// {", ".join(self.outputs['names'])}
{super(SparseAPI, self).gene_api_declaration()}
"""
def get_kernel_tensor_out_type(self, output_name):
sparse_type = 'TensorType::DENSE_TENSOR'
......@@ -45,6 +46,7 @@ class SparseAPI(ForwardAPI):
kernel_output = ""
output_names = []
output_create = ""
return_type = self.get_return_type_with_intermediate(inplace_flag)
if len(output_type_list) == 1:
kernel_output = 'kernel_out'
......@@ -53,21 +55,29 @@ class SparseAPI(ForwardAPI):
0]] if inplace_flag and self.inplace_map is not None and self.outputs[
'names'][0] in self.inplace_map else ""
output_create = f"""
{self.outputs['return_type']} api_output{inplace_assign};
{return_type} api_output{inplace_assign};
auto* kernel_out = {set_out_func}(&api_output, {self.get_kernel_tensor_out_type(self.outputs['names'][0])});"""
elif len(output_type_list) > 1:
output_create = f"""
{self.outputs['return_type']} api_output;"""
{return_type} api_output;"""
if inplace_flag:
output_create = f"""
{return_type} api_output{{"""
for out_name in self.outputs['names']:
out_name = out_name.split('@')[0]
if out_name in self.inplace_map:
output_create = output_create + self.inplace_map[
out_name] + ', '
else:
output_create += 'Tensor(), '
output_create = output_create[:-2] + '};'
for i in range(len(output_type_list)):
kernel_output = kernel_output + f'kernel_out_{i}, '
output_names.append(f'kernel_out_{i}')
if inplace_flag and self.inplace_map is not None and self.outputs[
'names'][i] in self.inplace_map:
output_create = output_create + f"""
std::get<{i}>(api_output) = {self.inplace_map[self.outputs['names'][i]]};"""
output_create = output_create + f"""
auto* kernel_out_{i} = {set_out_func}(&std::get<{i}>(api_output), {self.get_kernel_tensor_out_type(self.outputs['names'][i])});"""
......@@ -151,8 +161,11 @@ class SparseAPI(ForwardAPI):
{return_code}"""
def gene_base_api_code(self, inplace_flag=False):
api_func_name = self.get_api_func_name()
if inplace_flag and api_func_name[-1] != '_':
api_func_name += '_'
return f"""
PADDLE_API {self.gene_return_type_code()} {self.get_api_func_name()}({self.get_define_args()}) {{
PADDLE_API {self.get_return_type()} {api_func_name}({self.get_define_args()}) {{
{self.gene_kernel_select()}
{self.gen_sparse_kernel_code(inplace_flag)}
}}
......
......@@ -31,11 +31,8 @@ class SparseBackwardAPI(SparseAPI, BackwardAPI):
def gene_kernel_backend_select(self):
return BackwardAPI.gene_kernel_backend_select(self)
def get_return_type(self, out_type_list):
return BackwardAPI.get_return_type(self, out_type_list)
def gene_return_type_code(self):
return self.outputs['return_type']
def get_return_type(self, inplace_flag=False):
return BackwardAPI.get_return_type(self)
def gene_return_code(self):
return ""
......@@ -43,10 +40,10 @@ class SparseBackwardAPI(SparseAPI, BackwardAPI):
def gene_api_declaration(self):
return SparseAPI.gene_api_declaration(self)
def get_declare_args(self):
def get_declare_args(self, inplace_flag=False):
return BackwardAPI.get_declare_args(self)
def get_define_args(self):
def get_define_args(self, inplace_flag=False):
return BackwardAPI.get_define_args(self)
def gene_output(self,
......
......@@ -32,7 +32,7 @@ class StringsAPI(ForwardAPI):
def gene_api_declaration(self):
return f"""
// {", ".join(self.outputs['names'])}
PADDLE_API {self.outputs['return_type']} {self.get_api_func_name()}({self.args_str['args_declare']});
{super(StringsAPI, self).gene_api_declaration()}
"""
def get_kernel_tensor_out_type(self, output_name):
......@@ -56,6 +56,7 @@ PADDLE_API {self.outputs['return_type']} {self.get_api_func_name()}({self.args_s
kernel_output = ""
output_names = []
output_create = ""
return_type = self.get_return_type(inplace_flag)
if len(output_type_list) == 1:
kernel_output = 'kernel_out'
......@@ -67,13 +68,12 @@ PADDLE_API {self.outputs['return_type']} {self.get_api_func_name()}({self.args_s
0]] if inplace_flag and self.inplace_map is not None and self.outputs[
'names'][0] in self.inplace_map else ""
output_create = f"""
{self.outputs['return_type']} api_output{inplace_assign};
{return_type} api_output{inplace_assign};
{tensor_type}* kernel_out = dynamic_cast<{tensor_type}*>({set_out_func}(kernel_backend, &api_output, {kernel_tensor_out_type}));"""
elif len(output_type_list) > 1:
output_create = f"""
{self.outputs['return_type']} api_output;"""
{return_type} api_output;"""
for i in range(len(output_type_list)):
kernel_output = kernel_output + f'kernel_out_{i}, '
......@@ -264,7 +264,7 @@ PADDLE_API {self.outputs['return_type']} {self.get_api_func_name()}({self.args_s
def gene_base_api_code(self, inplace_flag=False):
api_func_name = self.get_api_func_name()
return f"""
PADDLE_API {self.outputs['return_type']} {api_func_name}({self.args_str["args_define"]}) {{
PADDLE_API {self.get_return_type(inplace_flag)} {api_func_name}({self.get_define_args(inplace_flag)}) {{
{self.gene_kernel_select()}
{self.gen_string_tensor_kernel_code(inplace_flag)}
}}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册