未验证 提交 63fb0347 编写于 作者: Z zyfncg 提交者: GitHub

[PHI] Fix some bug of code auto-gen in C++ API (#40262)

* support code auto-gene for sparse backward api

* fix bug of intermediate api and name of return var
上级 f40ed5f4
......@@ -671,7 +671,7 @@ def GenerateNodeCreationCodes(
else:
# Tuple api_result
if IsPlainTensorType(rtype):
outputs_autograd_meta = f" egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&api_result[{pos}]);"
output_autograd_meta = f" egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&api_result[{pos}]);"
else:
assert IsVectorTensorType(rtype)
output_autograd_meta = f" std::vector<egr::AutogradMeta*> {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&api_result[{pos}]);\n"
......
add_subdirectory(lib)
cc_library(phi_api SRCS all.cc DEPS phi_function_api phi_bw_function_api sparse_api)
cc_library(phi_api SRCS all.cc DEPS phi_function_api phi_bw_function_api sparse_api sparse_bw_api)
......@@ -37,8 +37,16 @@ set(sparse_api_gen_file ${CMAKE_SOURCE_DIR}/python/paddle/utils/code_gen/sparse_
set(sparse_api_yaml_file ${CMAKE_SOURCE_DIR}/python/paddle/utils/code_gen/sparse_api.yaml)
set(sparse_api_header_file ${CMAKE_SOURCE_DIR}/paddle/phi/api/include/sparse_api.h)
set(sparse_api_source_file ${CMAKE_SOURCE_DIR}/paddle/phi/api/lib/sparse_api.cc)
set(sparse_api_header_file_tmp ${api_header_file}.tmp)
set(sparse_api_source_file_tmp ${api_source_file}.tmp)
set(sparse_api_header_file_tmp ${sparse_api_header_file}.tmp)
set(sparse_api_source_file_tmp ${sparse_api_source_file}.tmp)
# sparse bw api file
set(sparse_bw_api_gen_file ${CMAKE_SOURCE_DIR}/python/paddle/utils/code_gen/sparse_bw_api_gen.py)
set(sparse_bw_api_yaml_file ${CMAKE_SOURCE_DIR}/python/paddle/utils/code_gen/sparse_bw_api.yaml)
set(sparse_bw_api_header_file ${CMAKE_SOURCE_DIR}/paddle/phi/api/backward/sparse_bw_api.h)
set(sparse_bw_api_source_file ${CMAKE_SOURCE_DIR}/paddle/phi/api/lib/sparse_bw_api.cc)
set(sparse_bw_api_header_file_tmp ${sparse_bw_api_header_file}.tmp)
set(sparse_bw_api_source_file_tmp ${sparse_bw_api_source_file}.tmp)
# sparse bw api file
set(sparse_bw_api_gen_file ${CMAKE_SOURCE_DIR}/python/paddle/utils/code_gen/sparse_bw_api_gen.py)
......
......@@ -301,12 +301,12 @@ class BaseAPI(object):
def gene_api_declaration(self):
api_declaration = f"""
PADDLE_API {self.outputs['return_type']} {self.get_api_func_name()}({self.args_str['args_declare']});
PADDLE_API {self.gene_return_type_code()} {self.get_api_func_name()}({self.args_str['args_declare']});
"""
if self.is_base_api and self.inplace_map is not None:
api_declaration = api_declaration + f"""
PADDLE_API {self.outputs['return_type']} {self.get_api_func_name() + '_'}({self.args_str['args_declare']});
PADDLE_API {self.gene_return_type_code()} {self.get_api_func_name() + '_'}({self.args_str['args_declare']});
"""
return api_declaration
......@@ -675,6 +675,14 @@ PADDLE_API {self.outputs['return_type']} {self.get_api_func_name() + '_'}({self.
return input_tensor_code, kernel_args[:-2], kernel_signature
# Override by child class
def gene_return_type_code(self):
return self.outputs['return_type']
# Override by child class
def gene_return_code(self):
return "api_output"
# Override by child class
def gene_output(self,
output_type_list,
......@@ -703,7 +711,7 @@ PADDLE_API {self.outputs['return_type']} {self.get_api_func_name() + '_'}({self.
{code_indent} auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
{code_indent} (*kernel_fn)({kernel_args}, {outputs_args});
{code_indent} return out;"""
{code_indent} return {self.gene_return_code()};"""
def gen_selected_rows_kernel_code(self, code_indent, inplace_flag=False):
input_tensors, kernel_args, kernel_signature = self.get_selected_rows_kernel_args(
......@@ -726,12 +734,12 @@ PADDLE_API {self.outputs['return_type']} {self.get_api_func_name() + '_'}({self.
{code_indent} auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
{code_indent} (*kernel_fn)({kernel_args}, {outputs_args});
{code_indent} return out;"""
{code_indent} return {self.gene_return_code()};"""
def gene_base_api_code(self, inplace_flag=False):
api_func_name = self.get_api_func_name() + ('_' if inplace_flag else '')
api_code = f"""
PADDLE_API {self.outputs['return_type']} {api_func_name}({self.args_str["args_define"]}) {{
PADDLE_API {self.gene_return_type_code()} {api_func_name}({self.args_str["args_define"]}) {{
{self.gene_kernel_select()}
"""
......
......@@ -23,7 +23,8 @@ from api_base import BaseAPI
class ForwardAPI(BaseAPI):
def __init__(self, api_item_yaml):
super(ForwardAPI, self).__init__(api_item_yaml)
self.is_dygraph_api = self.parse_intermediate(api_item_yaml)
self.is_dygraph_api, self.intermediate_outs = self.parse_intermediate(
api_item_yaml)
def get_api_func_name(self):
if self.is_dygraph_api:
......@@ -33,15 +34,47 @@ class ForwardAPI(BaseAPI):
def parse_intermediate(self, api_item_yaml):
if 'intermediate' in api_item_yaml:
return True
intermediate_outs = [
item.strip()
for item in api_item_yaml['intermediate'].split(',')
]
return True, intermediate_outs
else:
return False
return False, []
def get_return_type(self, out_type_list):
return out_type_list[0] if len(
out_type_list) == 1 else "std::tuple<" + ",".join(
out_type_list) + ">"
def gene_return_type_code(self):
if self.is_dygraph_api or len(self.intermediate_outs) == 0:
return self.outputs['return_type']
else:
return_out_list = []
for i, name in enumerate(self.outputs['names']):
if name not in self.intermediate_outs:
return_out_list.append(self.outputs['types'][i])
return return_out_list[0] if len(
return_out_list) == 1 else "std::tuple<" + ",".join(
return_out_list) + ">"
def gene_return_code(self):
if self.is_dygraph_api or len(self.intermediate_outs) == 0:
return "api_output"
else:
return_out_list = []
for i, name in enumerate(self.outputs['names']):
if name not in self.intermediate_outs:
return_out_list.append(i)
if len(return_out_list) == 1:
return f"std::get<{return_out_list[0]}>(api_output)"
else:
selected_code = [
f"std::get<{i}>(api_output)" for i in return_out_list
]
return '{' + ", ".join(selected_code) + '}'
def gene_output(self,
output_type_list,
set_out_func,
......@@ -58,12 +91,12 @@ class ForwardAPI(BaseAPI):
0]] if inplace_flag and self.inplace_map is not None and self.outputs[
'names'][0] in self.inplace_map else ""
output_create = f"""
{code_indent} {self.outputs['return_type']} out{inplace_assign};
{code_indent} auto kernel_out = {set_out_func}(kernel_backend, &out);"""
{code_indent} {self.outputs['return_type']} api_output{inplace_assign};
{code_indent} auto kernel_out = {set_out_func}(kernel_backend, &api_output);"""
elif len(output_type_list) > 1:
output_create = f"""
{code_indent} {self.outputs['return_type']} out;"""
{code_indent} {self.outputs['return_type']} api_output;"""
for i in range(len(output_type_list)):
kernel_output = kernel_output + f'kernel_out_{i}, '
......@@ -71,10 +104,10 @@ class ForwardAPI(BaseAPI):
if inplace_flag and self.inplace_map is not None and self.outputs[
'names'][i] in self.inplace_map:
output_create = output_create + f"""
{code_indent} std::get<{i}>(out) = {self.inplace_map[self.outputs['names'][i]]};"""
{code_indent} std::get<{i}>(api_output) = {self.inplace_map[self.outputs['names'][i]]};"""
output_create = output_create + f"""
{code_indent} auto kernel_out_{i} = {set_out_func}(kernel_backend, &std::get<{i}>(out));"""
{code_indent} auto kernel_out_{i} = {set_out_func}(kernel_backend, &std::get<{i}>(api_output));"""
kernel_output = kernel_output[:-2]
else:
......@@ -169,6 +202,10 @@ def generate_api(api_yaml_path, header_file_path, source_file_path,
if foward_api.is_dygraph_api:
dygraph_header_file.write(foward_api.gene_api_declaration())
dygraph_source_file.write(foward_api.gene_api_code())
foward_api.is_dygraph_api = False
header_file.write(foward_api.gene_api_declaration())
source_file.write(foward_api.gene_api_code())
else:
header_file.write(foward_api.gene_api_declaration())
source_file.write(foward_api.gene_api_code())
......
......@@ -87,33 +87,33 @@ class BackwardAPI(BaseAPI):
0]] if inplace_flag and self.inplace_map is not None and self.outputs[
'names'][0] in self.inplace_map else ""
output_create = f"""
{code_indent} {self.outputs['return_type']} out{inplace_assign};
{code_indent} auto kernel_out = {set_out_func}(kernel_backend, &out);"""
{code_indent} {self.outputs['return_type']} api_output{inplace_assign};
{code_indent} auto kernel_out = {set_out_func}(kernel_backend, &api_output);"""
elif len(output_type_list) > 1:
output_create = f"""
{code_indent} {self.outputs['return_type']} out({len(output_type_list)});"""
{code_indent} {self.outputs['return_type']} api_output({len(output_type_list)});"""
for i, out_type_item in enumerate(output_type_list):
kernel_output = kernel_output + f'kernel_out_{i}, '
output_names.append(f'kernel_out_{i}')
if out_type_item == 'Tensor':
get_out_code = f'&out[{i}][0]'
get_out_code = f'&api_output[{i}][0]'
if inplace_flag and self.inplace_map is not None and self.outputs[
'names'][i] in self.inplace_map:
output_create = output_create + f"""
{code_indent} out[{i}].emplace_back({self.inplace_map[self.outputs['names'][i]]});"""
{code_indent} api_output[{i}].emplace_back({self.inplace_map[self.outputs['names'][i]]});"""
else:
output_create = output_create + f"""
{code_indent} out[{i}].emplace_back();"""
{code_indent} api_output[{i}].emplace_back();"""
else:
get_out_code = f'&out[{i}]'
get_out_code = f'&api_output[{i}]'
if inplace_flag and self.inplace_map is not None and self.outputs[
'names'][i] in self.inplace_map:
output_create = output_create + f"""
{code_indent} out[{i}] = {self.inplace_map[self.outputs['names'][i]]};"""
{code_indent} api_output[{i}] = {self.inplace_map[self.outputs['names'][i]]};"""
output_create = output_create + f"""
{code_indent} auto kernel_out_{i} = {set_out_func}(kernel_backend, {get_out_code});"""
......
......@@ -60,12 +60,12 @@ PADDLE_API {self.outputs['return_type']} {self.get_api_func_name()}({self.args_s
0]] if inplace_flag and self.inplace_map is not None and self.outputs[
'names'][0] in self.inplace_map else ""
output_create = f"""
{self.outputs['return_type']} out{inplace_assign};
auto* kernel_out = {set_out_func}(&out, {self.get_kernel_tensor_out_type(self.outputs['names'][0])});"""
{self.outputs['return_type']} api_output{inplace_assign};
auto* kernel_out = {set_out_func}(&api_output, {self.get_kernel_tensor_out_type(self.outputs['names'][0])});"""
elif len(output_type_list) > 1:
output_create = f"""
{self.outputs['return_type']} out;"""
{self.outputs['return_type']} api_output;"""
for i in range(len(output_type_list)):
kernel_output = kernel_output + f'kernel_out_{i}, '
......@@ -73,10 +73,10 @@ PADDLE_API {self.outputs['return_type']} {self.get_api_func_name()}({self.args_s
if inplace_flag and self.inplace_map is not None and self.outputs[
'names'][i] in self.inplace_map:
output_create = output_create + f"""
std::get<{i}>(out) = {self.inplace_map[self.outputs['names'][i]]};"""
std::get<{i}>(api_output) = {self.inplace_map[self.outputs['names'][i]]};"""
output_create = output_create + f"""
auto* kernel_out_{i} = {set_out_func}(&std::get<{i}>(out), {self.get_kernel_tensor_out_type(self.outputs['names'][i])});"""
auto* kernel_out_{i} = {set_out_func}(&std::get<{i}>(api_output), {self.get_kernel_tensor_out_type(self.outputs['names'][i])});"""
kernel_output = kernel_output[:-2]
else:
......@@ -155,7 +155,7 @@ PADDLE_API {self.outputs['return_type']} {self.get_api_func_name()}({self.args_s
{kernel_context_code}
phi_kernel(&kernel_context);
return out;"""
return api_output;"""
def gene_base_api_code(self, inplace_flag=False):
api_func_name = self.get_api_func_name()
......
......@@ -53,33 +53,33 @@ class SparseBackwardAPI(SparseAPI, BackwardAPI):
0]] if inplace_flag and self.inplace_map is not None and self.outputs[
'names'][0] in self.inplace_map else ""
output_create = f"""
{self.outputs['return_type']} out{inplace_assign};
auto kernel_out = {set_out_func}(&out, {self.get_kernel_tensor_out_type(self.outputs['names'][0])});"""
{self.outputs['return_type']} api_output{inplace_assign};
auto kernel_out = {set_out_func}(&api_output, {self.get_kernel_tensor_out_type(self.outputs['names'][0])});"""
elif len(output_type_list) > 1:
output_create = f"""
{self.outputs['return_type']} out({len(output_type_list)});"""
{self.outputs['return_type']} api_output({len(output_type_list)});"""
for i, out_type_item in enumerate(output_type_list):
kernel_output = kernel_output + f'kernel_out_{i}, '
output_names.append(f'kernel_out_{i}')
if out_type_item == 'Tensor':
get_out_code = f'&out[{i}][0]'
get_out_code = f'&api_output[{i}][0]'
if inplace_flag and self.inplace_map is not None and self.outputs[
'names'][i] in self.inplace_map:
output_create = output_create + f"""
out[{i}].emplace_back({self.inplace_map[self.outputs['names'][i]]});"""
api_output[{i}].emplace_back({self.inplace_map[self.outputs['names'][i]]});"""
else:
output_create = output_create + f"""
out[{i}].emplace_back();"""
api_output[{i}].emplace_back();"""
else:
get_out_code = f'&out[{i}]'
get_out_code = f'&api_output[{i}]'
if inplace_flag and self.inplace_map is not None and self.outputs[
'names'][i] in self.inplace_map:
output_create = output_create + f"""
out[{i}] = {self.inplace_map[self.outputs['names'][i]]};"""
api_output[{i}] = {self.inplace_map[self.outputs['names'][i]]};"""
output_create = output_create + f"""
auto kernel_out_{i} = {set_out_func}({get_out_code}, {self.get_kernel_tensor_out_type(self.outputs['names'][i])});"""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册