未验证 提交 8b17207c 编写于 作者: W WangZhen 提交者: GitHub

Gen all Apis (#56526)

上级 8fe86ebb
......@@ -64,7 +64,7 @@ API_IMPL_TEMPLATE = """
{ret_type} {api_name}({args}){{
{in_combine}
{compute_op}
{out_slice}
{out_split}
{return_result}
}}
......@@ -73,34 +73,15 @@ API_IMPL_TEMPLATE = """
COMBINE_OP_TEMPLATE = """
auto {op_name} = APIBuilder::Instance().GetBuilder()->Build<ir::CombineOp>({in_name});"""
SLICE_OP_TEMPLATE = """
auto {op_name} = APIBuilder::Instance().GetBuilder()->Build<ir::SliceOp>({in_name});"""
SPLIT_OP_TEMPLATE = """
auto {op_name} = APIBuilder::Instance().GetBuilder()->Build<ir::SplitOp>({in_name});"""
COMPUTE_OP_TEMPLATE = """
paddle::dialect::{op_class_name} {op_inst_name} = APIBuilder::Instance().GetBuilder()->Build<paddle::dialect::{op_class_name}>({args});"""
API_LIST = [
'add_n',
'mean',
'sum',
'divide',
'full',
'tanh_grad',
'mean_grad',
'concat',
'add',
'multiply',
'elementwise_pow',
'scale',
'reshape',
'expand',
'tile',
'add_grad',
'divide_grad',
'sum_grad',
]
OP_RESULT = 'ir::OpResult'
VECTOR_TYPE = 'ir::VectorType'
PD_MANUAL_OP_LIST = ['add_n']
def get_op_class_name(op_name):
......@@ -142,56 +123,70 @@ class CodeGen:
ret.append(f'{self._type_map[type]} {name}')
return ', '.join(ret)
def _gen_api_attrs(self, op_info, with_default):
def _gen_api_attrs(self, op_info, with_default, is_mutable_attr):
name_list = op_info.attribute_name_list
type_list = op_info.attribute_build_arg_type_list
default_value_list = op_info.attribute_default_value_list
mutable_name_list = op_info.mutable_attribute_name_list
assert len(name_list) == len(type_list) == len(default_value_list)
ret = []
no_mutable_attr = []
mutable_attr = []
for name, type, default_value in zip(
name_list, type_list, default_value_list
):
if is_mutable_attr and name in mutable_name_list:
mutable_attr.append(f'{OP_RESULT} {name}')
continue
if with_default and default_value is not None:
if type in ['float', 'double']:
default_value = default_value.strip('"')
ret.append(
no_mutable_attr.append(
'{type} {name} = {default_value}'.format(
type=type, name=name, default_value=default_value
)
)
else:
ret.append(f'{type} {name}')
return ', '.join(ret)
no_mutable_attr.append(f'{type} {name}')
return ', '.join(mutable_attr + no_mutable_attr)
def _gen_api_args(self, op_info, with_default_attr):
def _gen_api_args(self, op_info, with_default_attr, is_mutable_attr):
inputs = self._gen_api_inputs(op_info)
attrs = self._gen_api_attrs(op_info, with_default_attr)
attrs = self._gen_api_attrs(op_info, with_default_attr, is_mutable_attr)
return (inputs + ', ' + attrs).strip(', ')
def _gen_ret_type(self, op_info):
type_list = op_info.output_type_list
assert len(type_list) >= 1
if len(type_list) > 1:
return 'std::tuple<{}>'.format(
', '.join([self._type_map[type] for type in type_list])
)
elif len(type_list) == 1:
return self._type_map[type_list[0]]
elif len(type_list) == 0:
return 'void'
def _gen_one_declare(self, op_info, op_name):
def _gen_one_declare(self, op_info, op_name, is_mutable_attr):
return API_DECLARE_TEMPLATE.format(
ret_type=self._gen_ret_type(op_info),
api_name=op_name,
args=self._gen_api_args(op_info, True),
args=self._gen_api_args(op_info, True, is_mutable_attr),
)
def _gen_h_file(self, op_info_items, namespaces, h_file_path):
declare_str = ''
for op_info in op_info_items:
for op_name in op_info.op_phi_name:
if op_name not in API_LIST:
# NOTE:When infer_meta_func is None, the Build() function generated in pd_op
# is wrong, so temporarily skip the automatic generation of these APIs
if (
op_info.infer_meta_func is None
and op_name not in PD_MANUAL_OP_LIST
):
continue
declare_str += self._gen_one_declare(op_info, op_name)
declare_str += self._gen_one_declare(op_info, op_name, False)
if len(op_info.mutable_attribute_name_list) > 0:
declare_str += self._gen_one_declare(op_info, op_name, True)
body = declare_str
for namespace in reversed(namespaces):
body = NAMESPACE_TEMPLATE.format(namespace=namespace, body=body)
......@@ -218,9 +213,13 @@ class CodeGen:
combine_op_list.append(None)
return combine_op, combine_op_list
def _gen_compute_op_args(self, op_info, in_combine_op_list):
def _gen_compute_op_args(
self, op_info, in_combine_op_list, is_mutable_attr
):
input_name_list = op_info.input_name_list
attribute_name_list = op_info.attribute_name_list
all_attr_list = op_info.attribute_name_list
no_mutable_attr_list = op_info.non_mutable_attribute_name_list
mutable_attr_list = op_info.mutable_attribute_name_list
assert len(input_name_list) == len(in_combine_op_list)
ret = []
for input_name, combine_op in zip(input_name_list, in_combine_op_list):
......@@ -228,61 +227,69 @@ class CodeGen:
ret.append(input_name)
else:
ret.append(f'{combine_op}.out()')
ret += list(attribute_name_list)
if is_mutable_attr:
ret += list(mutable_attr_list + no_mutable_attr_list)
else:
ret += list(all_attr_list)
return ', '.join(ret)
def _gen_compute_op(self, op_info, op_name, in_combine_op_list):
def _gen_compute_op(
self, op_info, op_name, in_combine_op_list, is_mutable_attr
):
op_class_name = to_pascal_case(op_name) + 'Op'
op_inst_name = op_name + '_op'
return (
COMPUTE_OP_TEMPLATE.format(
op_class_name=op_class_name,
op_inst_name=op_inst_name,
args=self._gen_compute_op_args(op_info, in_combine_op_list),
args=self._gen_compute_op_args(
op_info, in_combine_op_list, is_mutable_attr
),
),
op_inst_name,
)
def _gen_out_slice_and_ret_list(self, op_info, op_inst_name):
def _gen_out_split_and_ret_list(self, op_info, op_inst_name):
name_list = op_info.output_name_list
type_list = op_info.output_type_list
slice_op_str = ''
split_op_str = ''
ret_list = []
for i, (name, type) in enumerate(zip(name_list, type_list)):
if VECTOR_TYPE in type:
slice_op_name = f'{name}_slice_op'
slice_op_str += SLICE_OP_TEMPLATE.format(
op_name=slice_op_name, in_name=f'{op_inst_name}.result({i})'
split_op_name = f'{name}_split_op'
split_op_str += SPLIT_OP_TEMPLATE.format(
op_name=split_op_name, in_name=f'{op_inst_name}.result({i})'
)
ret_list.append(f'{slice_op_name}.outputs()')
ret_list.append(f'{split_op_name}.outputs()')
else:
ret_list.append(f'{op_inst_name}.result({i})')
return slice_op_str, ret_list
return split_op_str, ret_list
def _gen_return_result(self, ret_list):
assert len(ret_list) >= 1
if len(ret_list) > 1:
return 'return std::make_tuple({});'.format(', '.join(ret_list))
else:
elif len(ret_list) == 1:
return f'return {ret_list[0]};'
elif len(ret_list) == 0:
return 'return;'
def _gen_one_impl(self, op_info, op_name):
def _gen_one_impl(self, op_info, op_name, is_mutable_attr):
in_combine, in_combine_op_list = self._gen_in_combine(op_info)
compute_op, op_inst_name = self._gen_compute_op(
op_info, op_name, in_combine_op_list
op_info, op_name, in_combine_op_list, is_mutable_attr
)
out_slice, ret_list = self._gen_out_slice_and_ret_list(
out_split, ret_list = self._gen_out_split_and_ret_list(
op_info, op_inst_name
)
ret = API_IMPL_TEMPLATE.format(
ret_type=self._gen_ret_type(op_info),
api_name=op_name,
args=self._gen_api_args(op_info, False),
args=self._gen_api_args(op_info, False, is_mutable_attr),
in_combine=in_combine,
compute_op=compute_op,
out_slice=out_slice,
out_split=out_split,
return_result=self._gen_return_result(ret_list),
)
......@@ -293,9 +300,16 @@ class CodeGen:
impl_str = ''
for op_info in op_info_items:
for op_name in op_info.op_phi_name:
if op_name not in API_LIST:
# NOTE:When infer_meta_func is None, the Build() function generated in pd_op
# is wrong, so temporarily skip the automatic generation of these APIs
if (
op_info.infer_meta_func is None
and op_name not in PD_MANUAL_OP_LIST
):
continue
impl_str += self._gen_one_impl(op_info, op_name)
impl_str += self._gen_one_impl(op_info, op_name, False)
if len(op_info.mutable_attribute_name_list) > 0:
impl_str += self._gen_one_impl(op_info, op_name, True)
body = impl_str
for namespace in reversed(namespaces):
body = NAMESPACE_TEMPLATE.format(namespace=namespace, body=body)
......
......@@ -18,18 +18,5 @@
#include "paddle/ir/core/builtin_op.h"
namespace paddle {
namespace dialect {
std::vector<ir::OpResult> concat_grad(std::vector<ir::OpResult> x,
ir::OpResult out_grad,
ir::OpResult axis) {
auto combine_op =
APIBuilder::Instance().GetBuilder()->Build<ir::CombineOp>(x);
paddle::dialect::ConcatGradOp concat_grad_op =
APIBuilder::Instance().GetBuilder()->Build<paddle::dialect::ConcatGradOp>(
combine_op.out(), out_grad, axis);
auto split_op = APIBuilder::Instance().GetBuilder()->Build<ir::SplitOp>(
concat_grad_op.result(0));
return split_op.outputs();
}
} // namespace dialect
namespace dialect {} // namespace dialect
} // namespace paddle
......@@ -21,10 +21,5 @@
#include "paddle/phi/common/place.h"
namespace paddle {
namespace dialect {
std::vector<ir::OpResult> concat_grad(std::vector<ir::OpResult> x,
ir::OpResult out_grad,
ir::OpResult axis);
} // namespace dialect
namespace dialect {} // namespace dialect
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册