未验证 提交 63fb0347 编写于 作者: Z zyfncg 提交者: GitHub

[PHI] Fix some bug of code auto-gen in C++ API (#40262)

* support code auto-gene for sparse backward api

* fix bug of intermediate api and name of return var
上级 f40ed5f4
...@@ -671,7 +671,7 @@ def GenerateNodeCreationCodes( ...@@ -671,7 +671,7 @@ def GenerateNodeCreationCodes(
else: else:
# Tuple api_result # Tuple api_result
if IsPlainTensorType(rtype): if IsPlainTensorType(rtype):
outputs_autograd_meta = f" egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&api_result[{pos}]);" output_autograd_meta = f" egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&api_result[{pos}]);"
else: else:
assert IsVectorTensorType(rtype) assert IsVectorTensorType(rtype)
output_autograd_meta = f" std::vector<egr::AutogradMeta*> {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&api_result[{pos}]);\n" output_autograd_meta = f" std::vector<egr::AutogradMeta*> {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&api_result[{pos}]);\n"
......
add_subdirectory(lib) add_subdirectory(lib)
cc_library(phi_api SRCS all.cc DEPS phi_function_api phi_bw_function_api sparse_api) cc_library(phi_api SRCS all.cc DEPS phi_function_api phi_bw_function_api sparse_api sparse_bw_api)
...@@ -37,8 +37,16 @@ set(sparse_api_gen_file ${CMAKE_SOURCE_DIR}/python/paddle/utils/code_gen/sparse_ ...@@ -37,8 +37,16 @@ set(sparse_api_gen_file ${CMAKE_SOURCE_DIR}/python/paddle/utils/code_gen/sparse_
set(sparse_api_yaml_file ${CMAKE_SOURCE_DIR}/python/paddle/utils/code_gen/sparse_api.yaml) set(sparse_api_yaml_file ${CMAKE_SOURCE_DIR}/python/paddle/utils/code_gen/sparse_api.yaml)
set(sparse_api_header_file ${CMAKE_SOURCE_DIR}/paddle/phi/api/include/sparse_api.h) set(sparse_api_header_file ${CMAKE_SOURCE_DIR}/paddle/phi/api/include/sparse_api.h)
set(sparse_api_source_file ${CMAKE_SOURCE_DIR}/paddle/phi/api/lib/sparse_api.cc) set(sparse_api_source_file ${CMAKE_SOURCE_DIR}/paddle/phi/api/lib/sparse_api.cc)
set(sparse_api_header_file_tmp ${api_header_file}.tmp) set(sparse_api_header_file_tmp ${sparse_api_header_file}.tmp)
set(sparse_api_source_file_tmp ${api_source_file}.tmp) set(sparse_api_source_file_tmp ${sparse_api_source_file}.tmp)
# sparse bw api file
set(sparse_bw_api_gen_file ${CMAKE_SOURCE_DIR}/python/paddle/utils/code_gen/sparse_bw_api_gen.py)
set(sparse_bw_api_yaml_file ${CMAKE_SOURCE_DIR}/python/paddle/utils/code_gen/sparse_bw_api.yaml)
set(sparse_bw_api_header_file ${CMAKE_SOURCE_DIR}/paddle/phi/api/backward/sparse_bw_api.h)
set(sparse_bw_api_source_file ${CMAKE_SOURCE_DIR}/paddle/phi/api/lib/sparse_bw_api.cc)
set(sparse_bw_api_header_file_tmp ${sparse_bw_api_header_file}.tmp)
set(sparse_bw_api_source_file_tmp ${sparse_bw_api_source_file}.tmp)
# sparse bw api file # sparse bw api file
set(sparse_bw_api_gen_file ${CMAKE_SOURCE_DIR}/python/paddle/utils/code_gen/sparse_bw_api_gen.py) set(sparse_bw_api_gen_file ${CMAKE_SOURCE_DIR}/python/paddle/utils/code_gen/sparse_bw_api_gen.py)
......
...@@ -301,12 +301,12 @@ class BaseAPI(object): ...@@ -301,12 +301,12 @@ class BaseAPI(object):
def gene_api_declaration(self): def gene_api_declaration(self):
api_declaration = f""" api_declaration = f"""
PADDLE_API {self.outputs['return_type']} {self.get_api_func_name()}({self.args_str['args_declare']}); PADDLE_API {self.gene_return_type_code()} {self.get_api_func_name()}({self.args_str['args_declare']});
""" """
if self.is_base_api and self.inplace_map is not None: if self.is_base_api and self.inplace_map is not None:
api_declaration = api_declaration + f""" api_declaration = api_declaration + f"""
PADDLE_API {self.outputs['return_type']} {self.get_api_func_name() + '_'}({self.args_str['args_declare']}); PADDLE_API {self.gene_return_type_code()} {self.get_api_func_name() + '_'}({self.args_str['args_declare']});
""" """
return api_declaration return api_declaration
...@@ -675,6 +675,14 @@ PADDLE_API {self.outputs['return_type']} {self.get_api_func_name() + '_'}({self. ...@@ -675,6 +675,14 @@ PADDLE_API {self.outputs['return_type']} {self.get_api_func_name() + '_'}({self.
return input_tensor_code, kernel_args[:-2], kernel_signature return input_tensor_code, kernel_args[:-2], kernel_signature
# Override by child class
def gene_return_type_code(self):
return self.outputs['return_type']
# Override by child class
def gene_return_code(self):
return "api_output"
# Override by child class # Override by child class
def gene_output(self, def gene_output(self,
output_type_list, output_type_list,
...@@ -703,7 +711,7 @@ PADDLE_API {self.outputs['return_type']} {self.get_api_func_name() + '_'}({self. ...@@ -703,7 +711,7 @@ PADDLE_API {self.outputs['return_type']} {self.get_api_func_name() + '_'}({self.
{code_indent} auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>(); {code_indent} auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
{code_indent} (*kernel_fn)({kernel_args}, {outputs_args}); {code_indent} (*kernel_fn)({kernel_args}, {outputs_args});
{code_indent} return out;""" {code_indent} return {self.gene_return_code()};"""
def gen_selected_rows_kernel_code(self, code_indent, inplace_flag=False): def gen_selected_rows_kernel_code(self, code_indent, inplace_flag=False):
input_tensors, kernel_args, kernel_signature = self.get_selected_rows_kernel_args( input_tensors, kernel_args, kernel_signature = self.get_selected_rows_kernel_args(
...@@ -726,12 +734,12 @@ PADDLE_API {self.outputs['return_type']} {self.get_api_func_name() + '_'}({self. ...@@ -726,12 +734,12 @@ PADDLE_API {self.outputs['return_type']} {self.get_api_func_name() + '_'}({self.
{code_indent} auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>(); {code_indent} auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
{code_indent} (*kernel_fn)({kernel_args}, {outputs_args}); {code_indent} (*kernel_fn)({kernel_args}, {outputs_args});
{code_indent} return out;""" {code_indent} return {self.gene_return_code()};"""
def gene_base_api_code(self, inplace_flag=False): def gene_base_api_code(self, inplace_flag=False):
api_func_name = self.get_api_func_name() + ('_' if inplace_flag else '') api_func_name = self.get_api_func_name() + ('_' if inplace_flag else '')
api_code = f""" api_code = f"""
PADDLE_API {self.outputs['return_type']} {api_func_name}({self.args_str["args_define"]}) {{ PADDLE_API {self.gene_return_type_code()} {api_func_name}({self.args_str["args_define"]}) {{
{self.gene_kernel_select()} {self.gene_kernel_select()}
""" """
......
...@@ -23,7 +23,8 @@ from api_base import BaseAPI ...@@ -23,7 +23,8 @@ from api_base import BaseAPI
class ForwardAPI(BaseAPI): class ForwardAPI(BaseAPI):
def __init__(self, api_item_yaml): def __init__(self, api_item_yaml):
super(ForwardAPI, self).__init__(api_item_yaml) super(ForwardAPI, self).__init__(api_item_yaml)
self.is_dygraph_api = self.parse_intermediate(api_item_yaml) self.is_dygraph_api, self.intermediate_outs = self.parse_intermediate(
api_item_yaml)
def get_api_func_name(self): def get_api_func_name(self):
if self.is_dygraph_api: if self.is_dygraph_api:
...@@ -33,15 +34,47 @@ class ForwardAPI(BaseAPI): ...@@ -33,15 +34,47 @@ class ForwardAPI(BaseAPI):
def parse_intermediate(self, api_item_yaml): def parse_intermediate(self, api_item_yaml):
if 'intermediate' in api_item_yaml: if 'intermediate' in api_item_yaml:
return True intermediate_outs = [
item.strip()
for item in api_item_yaml['intermediate'].split(',')
]
return True, intermediate_outs
else: else:
return False return False, []
def get_return_type(self, out_type_list): def get_return_type(self, out_type_list):
return out_type_list[0] if len( return out_type_list[0] if len(
out_type_list) == 1 else "std::tuple<" + ",".join( out_type_list) == 1 else "std::tuple<" + ",".join(
out_type_list) + ">" out_type_list) + ">"
def gene_return_type_code(self):
if self.is_dygraph_api or len(self.intermediate_outs) == 0:
return self.outputs['return_type']
else:
return_out_list = []
for i, name in enumerate(self.outputs['names']):
if name not in self.intermediate_outs:
return_out_list.append(self.outputs['types'][i])
return return_out_list[0] if len(
return_out_list) == 1 else "std::tuple<" + ",".join(
return_out_list) + ">"
def gene_return_code(self):
if self.is_dygraph_api or len(self.intermediate_outs) == 0:
return "api_output"
else:
return_out_list = []
for i, name in enumerate(self.outputs['names']):
if name not in self.intermediate_outs:
return_out_list.append(i)
if len(return_out_list) == 1:
return f"std::get<{return_out_list[0]}>(api_output)"
else:
selected_code = [
f"std::get<{i}>(api_output)" for i in return_out_list
]
return '{' + ", ".join(selected_code) + '}'
def gene_output(self, def gene_output(self,
output_type_list, output_type_list,
set_out_func, set_out_func,
...@@ -58,12 +91,12 @@ class ForwardAPI(BaseAPI): ...@@ -58,12 +91,12 @@ class ForwardAPI(BaseAPI):
0]] if inplace_flag and self.inplace_map is not None and self.outputs[ 0]] if inplace_flag and self.inplace_map is not None and self.outputs[
'names'][0] in self.inplace_map else "" 'names'][0] in self.inplace_map else ""
output_create = f""" output_create = f"""
{code_indent} {self.outputs['return_type']} out{inplace_assign}; {code_indent} {self.outputs['return_type']} api_output{inplace_assign};
{code_indent} auto kernel_out = {set_out_func}(kernel_backend, &out);""" {code_indent} auto kernel_out = {set_out_func}(kernel_backend, &api_output);"""
elif len(output_type_list) > 1: elif len(output_type_list) > 1:
output_create = f""" output_create = f"""
{code_indent} {self.outputs['return_type']} out;""" {code_indent} {self.outputs['return_type']} api_output;"""
for i in range(len(output_type_list)): for i in range(len(output_type_list)):
kernel_output = kernel_output + f'kernel_out_{i}, ' kernel_output = kernel_output + f'kernel_out_{i}, '
...@@ -71,10 +104,10 @@ class ForwardAPI(BaseAPI): ...@@ -71,10 +104,10 @@ class ForwardAPI(BaseAPI):
if inplace_flag and self.inplace_map is not None and self.outputs[ if inplace_flag and self.inplace_map is not None and self.outputs[
'names'][i] in self.inplace_map: 'names'][i] in self.inplace_map:
output_create = output_create + f""" output_create = output_create + f"""
{code_indent} std::get<{i}>(out) = {self.inplace_map[self.outputs['names'][i]]};""" {code_indent} std::get<{i}>(api_output) = {self.inplace_map[self.outputs['names'][i]]};"""
output_create = output_create + f""" output_create = output_create + f"""
{code_indent} auto kernel_out_{i} = {set_out_func}(kernel_backend, &std::get<{i}>(out));""" {code_indent} auto kernel_out_{i} = {set_out_func}(kernel_backend, &std::get<{i}>(api_output));"""
kernel_output = kernel_output[:-2] kernel_output = kernel_output[:-2]
else: else:
...@@ -169,6 +202,10 @@ def generate_api(api_yaml_path, header_file_path, source_file_path, ...@@ -169,6 +202,10 @@ def generate_api(api_yaml_path, header_file_path, source_file_path,
if foward_api.is_dygraph_api: if foward_api.is_dygraph_api:
dygraph_header_file.write(foward_api.gene_api_declaration()) dygraph_header_file.write(foward_api.gene_api_declaration())
dygraph_source_file.write(foward_api.gene_api_code()) dygraph_source_file.write(foward_api.gene_api_code())
foward_api.is_dygraph_api = False
header_file.write(foward_api.gene_api_declaration())
source_file.write(foward_api.gene_api_code())
else: else:
header_file.write(foward_api.gene_api_declaration()) header_file.write(foward_api.gene_api_declaration())
source_file.write(foward_api.gene_api_code()) source_file.write(foward_api.gene_api_code())
......
...@@ -87,33 +87,33 @@ class BackwardAPI(BaseAPI): ...@@ -87,33 +87,33 @@ class BackwardAPI(BaseAPI):
0]] if inplace_flag and self.inplace_map is not None and self.outputs[ 0]] if inplace_flag and self.inplace_map is not None and self.outputs[
'names'][0] in self.inplace_map else "" 'names'][0] in self.inplace_map else ""
output_create = f""" output_create = f"""
{code_indent} {self.outputs['return_type']} out{inplace_assign}; {code_indent} {self.outputs['return_type']} api_output{inplace_assign};
{code_indent} auto kernel_out = {set_out_func}(kernel_backend, &out);""" {code_indent} auto kernel_out = {set_out_func}(kernel_backend, &api_output);"""
elif len(output_type_list) > 1: elif len(output_type_list) > 1:
output_create = f""" output_create = f"""
{code_indent} {self.outputs['return_type']} out({len(output_type_list)});""" {code_indent} {self.outputs['return_type']} api_output({len(output_type_list)});"""
for i, out_type_item in enumerate(output_type_list): for i, out_type_item in enumerate(output_type_list):
kernel_output = kernel_output + f'kernel_out_{i}, ' kernel_output = kernel_output + f'kernel_out_{i}, '
output_names.append(f'kernel_out_{i}') output_names.append(f'kernel_out_{i}')
if out_type_item == 'Tensor': if out_type_item == 'Tensor':
get_out_code = f'&out[{i}][0]' get_out_code = f'&api_output[{i}][0]'
if inplace_flag and self.inplace_map is not None and self.outputs[ if inplace_flag and self.inplace_map is not None and self.outputs[
'names'][i] in self.inplace_map: 'names'][i] in self.inplace_map:
output_create = output_create + f""" output_create = output_create + f"""
{code_indent} out[{i}].emplace_back({self.inplace_map[self.outputs['names'][i]]});""" {code_indent} api_output[{i}].emplace_back({self.inplace_map[self.outputs['names'][i]]});"""
else: else:
output_create = output_create + f""" output_create = output_create + f"""
{code_indent} out[{i}].emplace_back();""" {code_indent} api_output[{i}].emplace_back();"""
else: else:
get_out_code = f'&out[{i}]' get_out_code = f'&api_output[{i}]'
if inplace_flag and self.inplace_map is not None and self.outputs[ if inplace_flag and self.inplace_map is not None and self.outputs[
'names'][i] in self.inplace_map: 'names'][i] in self.inplace_map:
output_create = output_create + f""" output_create = output_create + f"""
{code_indent} out[{i}] = {self.inplace_map[self.outputs['names'][i]]};""" {code_indent} api_output[{i}] = {self.inplace_map[self.outputs['names'][i]]};"""
output_create = output_create + f""" output_create = output_create + f"""
{code_indent} auto kernel_out_{i} = {set_out_func}(kernel_backend, {get_out_code});""" {code_indent} auto kernel_out_{i} = {set_out_func}(kernel_backend, {get_out_code});"""
......
...@@ -60,12 +60,12 @@ PADDLE_API {self.outputs['return_type']} {self.get_api_func_name()}({self.args_s ...@@ -60,12 +60,12 @@ PADDLE_API {self.outputs['return_type']} {self.get_api_func_name()}({self.args_s
0]] if inplace_flag and self.inplace_map is not None and self.outputs[ 0]] if inplace_flag and self.inplace_map is not None and self.outputs[
'names'][0] in self.inplace_map else "" 'names'][0] in self.inplace_map else ""
output_create = f""" output_create = f"""
{self.outputs['return_type']} out{inplace_assign}; {self.outputs['return_type']} api_output{inplace_assign};
auto* kernel_out = {set_out_func}(&out, {self.get_kernel_tensor_out_type(self.outputs['names'][0])});""" auto* kernel_out = {set_out_func}(&api_output, {self.get_kernel_tensor_out_type(self.outputs['names'][0])});"""
elif len(output_type_list) > 1: elif len(output_type_list) > 1:
output_create = f""" output_create = f"""
{self.outputs['return_type']} out;""" {self.outputs['return_type']} api_output;"""
for i in range(len(output_type_list)): for i in range(len(output_type_list)):
kernel_output = kernel_output + f'kernel_out_{i}, ' kernel_output = kernel_output + f'kernel_out_{i}, '
...@@ -73,10 +73,10 @@ PADDLE_API {self.outputs['return_type']} {self.get_api_func_name()}({self.args_s ...@@ -73,10 +73,10 @@ PADDLE_API {self.outputs['return_type']} {self.get_api_func_name()}({self.args_s
if inplace_flag and self.inplace_map is not None and self.outputs[ if inplace_flag and self.inplace_map is not None and self.outputs[
'names'][i] in self.inplace_map: 'names'][i] in self.inplace_map:
output_create = output_create + f""" output_create = output_create + f"""
std::get<{i}>(out) = {self.inplace_map[self.outputs['names'][i]]};""" std::get<{i}>(api_output) = {self.inplace_map[self.outputs['names'][i]]};"""
output_create = output_create + f""" output_create = output_create + f"""
auto* kernel_out_{i} = {set_out_func}(&std::get<{i}>(out), {self.get_kernel_tensor_out_type(self.outputs['names'][i])});""" auto* kernel_out_{i} = {set_out_func}(&std::get<{i}>(api_output), {self.get_kernel_tensor_out_type(self.outputs['names'][i])});"""
kernel_output = kernel_output[:-2] kernel_output = kernel_output[:-2]
else: else:
...@@ -155,7 +155,7 @@ PADDLE_API {self.outputs['return_type']} {self.get_api_func_name()}({self.args_s ...@@ -155,7 +155,7 @@ PADDLE_API {self.outputs['return_type']} {self.get_api_func_name()}({self.args_s
{kernel_context_code} {kernel_context_code}
phi_kernel(&kernel_context); phi_kernel(&kernel_context);
return out;""" return api_output;"""
def gene_base_api_code(self, inplace_flag=False): def gene_base_api_code(self, inplace_flag=False):
api_func_name = self.get_api_func_name() api_func_name = self.get_api_func_name()
......
...@@ -53,33 +53,33 @@ class SparseBackwardAPI(SparseAPI, BackwardAPI): ...@@ -53,33 +53,33 @@ class SparseBackwardAPI(SparseAPI, BackwardAPI):
0]] if inplace_flag and self.inplace_map is not None and self.outputs[ 0]] if inplace_flag and self.inplace_map is not None and self.outputs[
'names'][0] in self.inplace_map else "" 'names'][0] in self.inplace_map else ""
output_create = f""" output_create = f"""
{self.outputs['return_type']} out{inplace_assign}; {self.outputs['return_type']} api_output{inplace_assign};
auto kernel_out = {set_out_func}(&out, {self.get_kernel_tensor_out_type(self.outputs['names'][0])});""" auto kernel_out = {set_out_func}(&api_output, {self.get_kernel_tensor_out_type(self.outputs['names'][0])});"""
elif len(output_type_list) > 1: elif len(output_type_list) > 1:
output_create = f""" output_create = f"""
{self.outputs['return_type']} out({len(output_type_list)});""" {self.outputs['return_type']} api_output({len(output_type_list)});"""
for i, out_type_item in enumerate(output_type_list): for i, out_type_item in enumerate(output_type_list):
kernel_output = kernel_output + f'kernel_out_{i}, ' kernel_output = kernel_output + f'kernel_out_{i}, '
output_names.append(f'kernel_out_{i}') output_names.append(f'kernel_out_{i}')
if out_type_item == 'Tensor': if out_type_item == 'Tensor':
get_out_code = f'&out[{i}][0]' get_out_code = f'&api_output[{i}][0]'
if inplace_flag and self.inplace_map is not None and self.outputs[ if inplace_flag and self.inplace_map is not None and self.outputs[
'names'][i] in self.inplace_map: 'names'][i] in self.inplace_map:
output_create = output_create + f""" output_create = output_create + f"""
out[{i}].emplace_back({self.inplace_map[self.outputs['names'][i]]});""" api_output[{i}].emplace_back({self.inplace_map[self.outputs['names'][i]]});"""
else: else:
output_create = output_create + f""" output_create = output_create + f"""
out[{i}].emplace_back();""" api_output[{i}].emplace_back();"""
else: else:
get_out_code = f'&out[{i}]' get_out_code = f'&api_output[{i}]'
if inplace_flag and self.inplace_map is not None and self.outputs[ if inplace_flag and self.inplace_map is not None and self.outputs[
'names'][i] in self.inplace_map: 'names'][i] in self.inplace_map:
output_create = output_create + f""" output_create = output_create + f"""
out[{i}] = {self.inplace_map[self.outputs['names'][i]]};""" api_output[{i}] = {self.inplace_map[self.outputs['names'][i]]};"""
output_create = output_create + f""" output_create = output_create + f"""
auto kernel_out_{i} = {set_out_func}({get_out_code}, {self.get_kernel_tensor_out_type(self.outputs['names'][i])});""" auto kernel_out_{i} = {set_out_func}({get_out_code}, {self.get_kernel_tensor_out_type(self.outputs['names'][i])});"""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册