未验证 提交 33b3e28a 编写于 作者: Z zyfncg 提交者: GitHub

change output of backward_api (#39229)

上级 30470853
...@@ -31,7 +31,12 @@ class API: ...@@ -31,7 +31,12 @@ class API:
# names : [], list of attribute names # names : [], list of attribute names
# attr_info : { attr_name : (type, default_values)} # attr_info : { attr_name : (type, default_values)}
self.args = gen_utils.parse_args(self.api, api_item_yaml['args']) self.args = gen_utils.parse_args(self.api, api_item_yaml['args'])
self.output = api_item_yaml['output'] self.out_type_list, _ = gen_utils.parse_output(self.api,
api_item_yaml['output'])
self.return_type = self.out_type_list[0] if len(
self.out_type_list) == 1 else "std::tuple<" + ",".join(
self.out_type_list) + ">"
self.is_base_api = True self.is_base_api = True
if 'invoke' in api_item_yaml: if 'invoke' in api_item_yaml:
self.is_base_api = False self.is_base_api = False
...@@ -54,18 +59,44 @@ class API: ...@@ -54,18 +59,44 @@ class API:
def gene_api_declaration(self): def gene_api_declaration(self):
return f""" return f"""
PADDLE_API {self.output} {self.api}({self.args['args_declare']}); PADDLE_API {self.return_type} {self.api}({self.args['args_declare']});
""" """
def gene_output(self, output_type_list):
kernel_output = ""
output_create = ""
if len(output_type_list) == 1:
kernel_output = 'dense_out'
output_create = f"""
{self.return_type} out;
auto dense_out = SetKernelOutput(out_meta, kernel_backend, &out);"""
elif len(output_type_list) > 1:
output_create = f"""
{self.return_type} out;"""
for i in range(len(output_type_list)):
kernel_output = kernel_output + f'dense_out_{i}, '
output_create = output_create + f"""
auto dense_out_{i} = SetKernelOutput(std::get<{i}>(out_meta), kernel_backend, &std::get<{i}>(out));"""
kernel_output = kernel_output[:-2]
else:
raise ValueError(
"{} : Output error: the output should not be empty.".format(
self.api))
return kernel_output, output_create
def gene_api_code(self): def gene_api_code(self):
if self.is_base_api: if self.is_base_api:
input_tensors, kernel_args = gen_utils.get_kernel_args( input_tensors, kernel_args = gen_utils.get_kernel_args(
self.args['inputs']['names'], self.args['attrs'], self.args['inputs']['names'], self.args['attrs'],
self.kernel['param']) self.kernel['param'])
out_type, _ = gen_utils.parse_output(self.api, self.output) outputs_args, output_create = self.gene_output(self.out_type_list)
outputs_args, output_create = gen_utils.gene_output(out_type)
return f""" return f"""
PADDLE_API {self.output} {self.api}({self.args["args_define"]}) {{ PADDLE_API {self.return_type} {self.api}({self.args["args_define"]}) {{
{gen_utils.gene_kernel_select(self.api, self.args['inputs']['names'], self.args['attrs'], self.kernel)} {gen_utils.gene_kernel_select(self.api, self.args['inputs']['names'], self.args['attrs'], self.kernel)}
auto* dev_ctx = GetDeviceContextByBackend(kernel_backend); auto* dev_ctx = GetDeviceContextByBackend(kernel_backend);
...@@ -82,7 +113,7 @@ PADDLE_API {self.output} {self.api}({self.args["args_define"]}) {{ ...@@ -82,7 +113,7 @@ PADDLE_API {self.output} {self.api}({self.args["args_define"]}) {{
else: else:
return f""" return f"""
PADDLE_API {self.output} {self.api}({self.args["args_define"]}) {{ PADDLE_API {self.return_type} {self.api}({self.args["args_define"]}) {{
return {self.invoke}; return {self.invoke};
}} }}
""" """
......
...@@ -23,9 +23,11 @@ import gen_utils ...@@ -23,9 +23,11 @@ import gen_utils
class BackwardAPI: class BackwardAPI:
def __init__(self, backward_item_yaml): def __init__(self, backward_item_yaml):
self.backward_api = backward_item_yaml['backward_api'] self.backward_api = backward_item_yaml['backward_api']
self.args, self.output_type, self.return_comment = self.parse_and_check_args( self.args, self.output_type_list, self.return_comment = self.parse_and_check_args(
backward_item_yaml['forward'], backward_item_yaml['args'], backward_item_yaml['forward'], backward_item_yaml['args'],
backward_item_yaml['output']) backward_item_yaml['output'])
self.return_type = self.output_type_list[0] if len(
self.output_type_list) == 1 else "std::vector<std::vector<Tensor>>"
self.is_base_api = True self.is_base_api = True
if 'invoke' in backward_item_yaml: if 'invoke' in backward_item_yaml:
...@@ -81,36 +83,65 @@ class BackwardAPI: ...@@ -81,36 +83,65 @@ class BackwardAPI:
Please check the args of {self.backward_api} in yaml." Please check the args of {self.backward_api} in yaml."
# check the output of backward # check the output of backward
output_type, return_comment = gen_utils.parse_output(self.backward_api, out_type_list, return_comment = gen_utils.parse_output(
output_config) self.backward_api, output_config)
assert output_type.count('Tensor') <= len(fw_inputs['names']), \ assert len(out_type_list) <= len(fw_inputs['names']), \
f"{self.backward_api} : Output error: The number of ouputs should be less then the number of inputs of forward api. \ f"{self.backward_api} : Output error: The number of ouputs should be less then the number of inputs of forward api. \
Please check the output of {self.backward_api} in yaml." Please check the output of {self.backward_api} in yaml."
return bw_args, output_type, return_comment return bw_args, out_type_list, return_comment
def gene_api_declaration(self): def gene_api_declaration(self):
if self.return_comment: if self.return_comment:
return f""" return f"""
// {self.return_comment} // {self.return_comment}
{self.output_type} {self.backward_api}({self.args['args_declare']}); {self.return_type} {self.backward_api}({self.args['args_declare']});
""" """
else: else:
return f""" return f"""
{self.output_type} {self.backward_api}({self.args['args_declare']}); {self.return_type} {self.backward_api}({self.args['args_declare']});
""" """
def gene_output(self, output_type_list):
kernel_output = ""
output_create = ""
if len(output_type_list) == 1:
return_type = output_type_list[0]
kernel_output = 'dense_out'
output_create = f"""
{self.return_type} out;
auto dense_out = SetKernelOutput(out_meta, kernel_backend, &out);"""
elif len(output_type_list) > 1:
output_create = f"""
{self.return_type} out;"""
for i, out_type_item in enumerate(output_type_list):
kernel_output = kernel_output + f'dense_out_{i}, '
get_out_code = f'&out[{i}][0]' if out_type_item == 'Tensor' else f'&out[{i}]'
output_create = output_create + f"""
auto dense_out_{i} = SetKernelOutput(std::get<{i}>(out_meta), kernel_backend, {get_out_code});"""
kernel_output = kernel_output[:-2]
else:
raise ValueError(
"{} : Output error: the output should not be empty.".format(
self.backward_api))
return kernel_output, output_create
def gene_api_code(self): def gene_api_code(self):
if self.is_base_api: if self.is_base_api:
input_tensors, kernel_args = gen_utils.get_kernel_args( input_tensors, kernel_args = gen_utils.get_kernel_args(
self.args['inputs']['names'], self.args['attrs'], self.args['inputs']['names'], self.args['attrs'],
self.kernel['param']) self.kernel['param'])
outputs_args, output_create = gen_utils.gene_output( outputs_args, output_create = self.gene_output(
self.output_type) self.output_type_list)
return f""" return f"""
// {self.return_comment} // {self.return_comment}
{self.output_type} {self.backward_api}({self.args["args_define"]}) {{ {self.return_type} {self.backward_api}({self.args["args_define"]}) {{
{gen_utils.gene_kernel_select(self.backward_api, self.args['inputs']['names'], self.args['attrs'], self.kernel)} {gen_utils.gene_kernel_select(self.backward_api, self.args['inputs']['names'], self.args['attrs'], self.kernel)}
auto* dev_ctx = GetDeviceContextByBackend(kernel_backend); auto* dev_ctx = GetDeviceContextByBackend(kernel_backend);
...@@ -143,7 +174,7 @@ class BackwardAPI: ...@@ -143,7 +174,7 @@ class BackwardAPI:
params_code = self.args["args_define"] params_code = self.args["args_define"]
return f""" return f"""
// {self.return_comment} // {self.return_comment}
{self.output_type} {self.backward_api}({params_code}) {{ {self.return_type} {self.backward_api}({params_code}) {{
return {invoke_code}; return {invoke_code};
}} }}
""" """
......
...@@ -124,7 +124,7 @@ def parse_output(api_name, output_config): ...@@ -124,7 +124,7 @@ def parse_output(api_name, output_config):
if len(temp_list) == 1: if len(temp_list) == 1:
out_type, out_name = parse_output_item(temp_list[0]) out_type, out_name = parse_output_item(temp_list[0])
return out_type, out_name return [out_type], out_name
else: else:
out_type_list = [] out_type_list = []
out_name_list = [] out_name_list = []
...@@ -133,8 +133,7 @@ def parse_output(api_name, output_config): ...@@ -133,8 +133,7 @@ def parse_output(api_name, output_config):
out_type_list.append(out_type) out_type_list.append(out_type)
out_name_list.append(out_name) out_name_list.append(out_name)
return "std::tuple<" + ",".join(out_type_list) + ">", ", ".join( return out_type_list, ", ".join(out_name_list)
out_name_list)
def gene_kernel_select(api, input_names, attrs, kernel) -> str: def gene_kernel_select(api, input_names, attrs, kernel) -> str:
...@@ -315,24 +314,3 @@ def get_kernel_args(input_names, attrs, kernel_param): ...@@ -315,24 +314,3 @@ def get_kernel_args(input_names, attrs, kernel_param):
else: else:
kernel_args = kernel_args + str(param) + ", " kernel_args = kernel_args + str(param) + ", "
return input_tensor_code, kernel_args[:-2] return input_tensor_code, kernel_args[:-2]
def gene_output(output_type):
kernel_output = ""
output_create = f"""
{output_type} out;"""
if output_type == 'Tensor' or output_type == 'std::vector<Tensor>':
kernel_output = 'dense_out'
output_create = output_create + """
auto dense_out = SetKernelOutput(out_meta, kernel_backend, &out);"""
elif re.match(r'std::tuple<.*>$', output_type):
out_num = output_type.count('Tensor')
for i in range(out_num):
kernel_output = kernel_output + f'dense_out_{i}, '
output_create = output_create + f"""
auto dense_out_{i} = SetKernelOutput(std::get<{i}>(out_meta), kernel_backend, &std::get<{i}>(out));"""
kernel_output = kernel_output[:-2]
return kernel_output, output_create
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册