Unverified commit 8b17207c · authored by WangZhen · committed by GitHub

Gen all Apis (#56526)

Parent 8fe86ebb
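For orientation: the generator changed below stitches a handful of C++ templates (API_IMPL_TEMPLATE, COMBINE_OP_TEMPLATE, SPLIT_OP_TEMPLATE, COMPUTE_OP_TEMPLATE) into one free function per operator. A minimal sketch of the emitted shape, modeled on the hand-written concat_grad that this commit removes from the manual API files at the end of the diff (all names are taken from that removed code; the generator would name the split op after the output, e.g. {name}_split_op, but is otherwise equivalent):

// Sketch of generator output for concat_grad. A std::vector<ir::OpResult>
// input is packed with ir::CombineOp (COMBINE_OP_TEMPLATE), the op itself is
// built via COMPUTE_OP_TEMPLATE, and a vector output is unpacked with
// ir::SplitOp (SPLIT_OP_TEMPLATE).
std::vector<ir::OpResult> concat_grad(std::vector<ir::OpResult> x,
                                      ir::OpResult out_grad,
                                      ir::OpResult axis) {
  auto combine_op =
      APIBuilder::Instance().GetBuilder()->Build<ir::CombineOp>(x);
  paddle::dialect::ConcatGradOp concat_grad_op =
      APIBuilder::Instance().GetBuilder()->Build<paddle::dialect::ConcatGradOp>(
          combine_op.out(), out_grad, axis);
  auto split_op = APIBuilder::Instance().GetBuilder()->Build<ir::SplitOp>(
      concat_grad_op.result(0));
  return split_op.outputs();
}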
@@ -64,7 +64,7 @@ API_IMPL_TEMPLATE = """
 {ret_type} {api_name}({args}){{
     {in_combine}
     {compute_op}
-    {out_slice}
+    {out_split}
     {return_result}
 }}
@@ -73,34 +73,15 @@ API_IMPL_TEMPLATE = """
 COMBINE_OP_TEMPLATE = """
     auto {op_name} = APIBuilder::Instance().GetBuilder()->Build<ir::CombineOp>({in_name});"""
 
-SLICE_OP_TEMPLATE = """
-    auto {op_name} = APIBuilder::Instance().GetBuilder()->Build<ir::SliceOp>({in_name});"""
+SPLIT_OP_TEMPLATE = """
+    auto {op_name} = APIBuilder::Instance().GetBuilder()->Build<ir::SplitOp>({in_name});"""
 
 COMPUTE_OP_TEMPLATE = """
     paddle::dialect::{op_class_name} {op_inst_name} = APIBuilder::Instance().GetBuilder()->Build<paddle::dialect::{op_class_name}>({args});"""
 
-API_LIST = [
-    'add_n',
-    'mean',
-    'sum',
-    'divide',
-    'full',
-    'tanh_grad',
-    'mean_grad',
-    'concat',
-    'add',
-    'multiply',
-    'elementwise_pow',
-    'scale',
-    'reshape',
-    'expand',
-    'tile',
-    'add_grad',
-    'divide_grad',
-    'sum_grad',
-]
 OP_RESULT = 'ir::OpResult'
 VECTOR_TYPE = 'ir::VectorType'
+PD_MANUAL_OP_LIST = ['add_n']
 
 
 def get_op_class_name(op_name):
@@ -142,56 +123,70 @@ class CodeGen:
             ret.append(f'{self._type_map[type]} {name}')
         return ', '.join(ret)
 
-    def _gen_api_attrs(self, op_info, with_default):
+    def _gen_api_attrs(self, op_info, with_default, is_mutable_attr):
         name_list = op_info.attribute_name_list
         type_list = op_info.attribute_build_arg_type_list
         default_value_list = op_info.attribute_default_value_list
+        mutable_name_list = op_info.mutable_attribute_name_list
         assert len(name_list) == len(type_list) == len(default_value_list)
 
-        ret = []
+        no_mutable_attr = []
+        mutable_attr = []
         for name, type, default_value in zip(
             name_list, type_list, default_value_list
         ):
+            if is_mutable_attr and name in mutable_name_list:
+                mutable_attr.append(f'{OP_RESULT} {name}')
+                continue
             if with_default and default_value is not None:
                 if type in ['float', 'double']:
                     default_value = default_value.strip('"')
-                ret.append(
+                no_mutable_attr.append(
                     '{type} {name} = {default_value}'.format(
                         type=type, name=name, default_value=default_value
                     )
                 )
             else:
-                ret.append(f'{type} {name}')
-        return ', '.join(ret)
+                no_mutable_attr.append(f'{type} {name}')
+        return ', '.join(mutable_attr + no_mutable_attr)
 
-    def _gen_api_args(self, op_info, with_default_attr):
+    def _gen_api_args(self, op_info, with_default_attr, is_mutable_attr):
         inputs = self._gen_api_inputs(op_info)
-        attrs = self._gen_api_attrs(op_info, with_default_attr)
+        attrs = self._gen_api_attrs(op_info, with_default_attr, is_mutable_attr)
         return (inputs + ', ' + attrs).strip(', ')
 
     def _gen_ret_type(self, op_info):
         type_list = op_info.output_type_list
-        assert len(type_list) >= 1
+
         if len(type_list) > 1:
             return 'std::tuple<{}>'.format(
                 ', '.join([self._type_map[type] for type in type_list])
             )
         elif len(type_list) == 1:
             return self._type_map[type_list[0]]
+        elif len(type_list) == 0:
+            return 'void'
 
-    def _gen_one_declare(self, op_info, op_name):
+    def _gen_one_declare(self, op_info, op_name, is_mutable_attr):
         return API_DECLARE_TEMPLATE.format(
             ret_type=self._gen_ret_type(op_info),
             api_name=op_name,
-            args=self._gen_api_args(op_info, True),
+            args=self._gen_api_args(op_info, True, is_mutable_attr),
         )
 
     def _gen_h_file(self, op_info_items, namespaces, h_file_path):
         declare_str = ''
         for op_info in op_info_items:
             for op_name in op_info.op_phi_name:
-                if op_name not in API_LIST:
+                # NOTE:When infer_meta_func is None, the Build() function generated in pd_op
+                # is wrong, so temporarily skip the automatic generation of these APIs
+                if (
+                    op_info.infer_meta_func is None
+                    and op_name not in PD_MANUAL_OP_LIST
+                ):
                     continue
-                declare_str += self._gen_one_declare(op_info, op_name)
+                declare_str += self._gen_one_declare(op_info, op_name, False)
+                if len(op_info.mutable_attribute_name_list) > 0:
+                    declare_str += self._gen_one_declare(op_info, op_name, True)
 
         body = declare_str
         for namespace in reversed(namespaces):
             body = NAMESPACE_TEMPLATE.format(namespace=namespace, body=body)
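The hunk above is where the new is_mutable_attr switch takes effect: for every op that declares mutable attributes, _gen_h_file now emits a second overload in which those attributes are accepted as ir::OpResult values and ordered ahead of the remaining attributes. A hypothetical sketch of the resulting declarations, assuming an op foo with one input x, a mutable attribute axis and a non-mutable attribute keepdim (op name, attribute names and defaults are illustrative only, not taken from any real op definition):

// Overload from _gen_one_declare(op_info, op_name, False): plain attribute values.
ir::OpResult foo(ir::OpResult x, int axis = -1, bool keepdim = false);

// Overload from _gen_one_declare(op_info, op_name, True): the mutable attribute
// is passed as an ir::OpResult and placed before the non-mutable attributes.
ir::OpResult foo(ir::OpResult x, ir::OpResult axis, bool keepdim = false);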
@@ -218,9 +213,13 @@ class CodeGen:
                 combine_op_list.append(None)
         return combine_op, combine_op_list
 
-    def _gen_compute_op_args(self, op_info, in_combine_op_list):
+    def _gen_compute_op_args(
+        self, op_info, in_combine_op_list, is_mutable_attr
+    ):
         input_name_list = op_info.input_name_list
-        attribute_name_list = op_info.attribute_name_list
+        all_attr_list = op_info.attribute_name_list
+        no_mutable_attr_list = op_info.non_mutable_attribute_name_list
+        mutable_attr_list = op_info.mutable_attribute_name_list
         assert len(input_name_list) == len(in_combine_op_list)
         ret = []
         for input_name, combine_op in zip(input_name_list, in_combine_op_list):
@@ -228,61 +227,69 @@ class CodeGen:
                 ret.append(input_name)
             else:
                 ret.append(f'{combine_op}.out()')
-        ret += list(attribute_name_list)
+        if is_mutable_attr:
+            ret += list(mutable_attr_list + no_mutable_attr_list)
+        else:
+            ret += list(all_attr_list)
         return ', '.join(ret)
 
-    def _gen_compute_op(self, op_info, op_name, in_combine_op_list):
+    def _gen_compute_op(
+        self, op_info, op_name, in_combine_op_list, is_mutable_attr
+    ):
         op_class_name = to_pascal_case(op_name) + 'Op'
         op_inst_name = op_name + '_op'
         return (
             COMPUTE_OP_TEMPLATE.format(
                 op_class_name=op_class_name,
                 op_inst_name=op_inst_name,
-                args=self._gen_compute_op_args(op_info, in_combine_op_list),
+                args=self._gen_compute_op_args(
+                    op_info, in_combine_op_list, is_mutable_attr
+                ),
             ),
             op_inst_name,
         )
 
-    def _gen_out_slice_and_ret_list(self, op_info, op_inst_name):
+    def _gen_out_split_and_ret_list(self, op_info, op_inst_name):
         name_list = op_info.output_name_list
         type_list = op_info.output_type_list
 
-        slice_op_str = ''
+        split_op_str = ''
         ret_list = []
         for i, (name, type) in enumerate(zip(name_list, type_list)):
             if VECTOR_TYPE in type:
-                slice_op_name = f'{name}_slice_op'
-                slice_op_str += SLICE_OP_TEMPLATE.format(
-                    op_name=slice_op_name, in_name=f'{op_inst_name}.result({i})'
+                split_op_name = f'{name}_split_op'
+                split_op_str += SPLIT_OP_TEMPLATE.format(
+                    op_name=split_op_name, in_name=f'{op_inst_name}.result({i})'
                 )
-                ret_list.append(f'{slice_op_name}.outputs()')
+                ret_list.append(f'{split_op_name}.outputs()')
             else:
                 ret_list.append(f'{op_inst_name}.result({i})')
-        return slice_op_str, ret_list
+        return split_op_str, ret_list
 
     def _gen_return_result(self, ret_list):
-        assert len(ret_list) >= 1
         if len(ret_list) > 1:
             return 'return std::make_tuple({});'.format(', '.join(ret_list))
-        else:
+        elif len(ret_list) == 1:
             return f'return {ret_list[0]};'
+        elif len(ret_list) == 0:
+            return 'return;'
 
-    def _gen_one_impl(self, op_info, op_name):
+    def _gen_one_impl(self, op_info, op_name, is_mutable_attr):
         in_combine, in_combine_op_list = self._gen_in_combine(op_info)
         compute_op, op_inst_name = self._gen_compute_op(
-            op_info, op_name, in_combine_op_list
+            op_info, op_name, in_combine_op_list, is_mutable_attr
         )
-        out_slice, ret_list = self._gen_out_slice_and_ret_list(
+        out_split, ret_list = self._gen_out_split_and_ret_list(
            op_info, op_inst_name
         )
 
         ret = API_IMPL_TEMPLATE.format(
             ret_type=self._gen_ret_type(op_info),
             api_name=op_name,
-            args=self._gen_api_args(op_info, False),
+            args=self._gen_api_args(op_info, False, is_mutable_attr),
             in_combine=in_combine,
             compute_op=compute_op,
-            out_slice=out_slice,
+            out_split=out_split,
             return_result=self._gen_return_result(ret_list),
         )
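Besides the SliceOp-to-SplitOp rename, the hunk above extends result handling: zero-output ops now map to a void return ('return;'), a single output is returned directly, and multiple outputs are packed into a std::tuple via std::make_tuple. A hypothetical sketch of the emitted implementation for a two-result op bar_grad (op, input and result names are illustrative only):

// Two results -> std::tuple, assembled by _gen_return_result and placed into
// API_IMPL_TEMPLATE around the COMPUTE_OP_TEMPLATE expansion.
std::tuple<ir::OpResult, ir::OpResult> bar_grad(ir::OpResult x,
                                                ir::OpResult y,
                                                ir::OpResult out_grad) {
    paddle::dialect::BarGradOp bar_grad_op =
        APIBuilder::Instance().GetBuilder()->Build<paddle::dialect::BarGradOp>(
            x, y, out_grad);
    return std::make_tuple(bar_grad_op.result(0), bar_grad_op.result(1));
}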
@@ -293,9 +300,16 @@ class CodeGen:
         impl_str = ''
         for op_info in op_info_items:
             for op_name in op_info.op_phi_name:
-                if op_name not in API_LIST:
+                # NOTE:When infer_meta_func is None, the Build() function generated in pd_op
+                # is wrong, so temporarily skip the automatic generation of these APIs
+                if (
+                    op_info.infer_meta_func is None
+                    and op_name not in PD_MANUAL_OP_LIST
+                ):
                     continue
-                impl_str += self._gen_one_impl(op_info, op_name)
+                impl_str += self._gen_one_impl(op_info, op_name, False)
+                if len(op_info.mutable_attribute_name_list) > 0:
+                    impl_str += self._gen_one_impl(op_info, op_name, True)
         body = impl_str
         for namespace in reversed(namespaces):
             body = NAMESPACE_TEMPLATE.format(namespace=namespace, body=body)
...
@@ -18,18 +18,5 @@
 #include "paddle/ir/core/builtin_op.h"
 
 namespace paddle {
-namespace dialect {
-std::vector<ir::OpResult> concat_grad(std::vector<ir::OpResult> x,
-                                      ir::OpResult out_grad,
-                                      ir::OpResult axis) {
-  auto combine_op =
-      APIBuilder::Instance().GetBuilder()->Build<ir::CombineOp>(x);
-  paddle::dialect::ConcatGradOp concat_grad_op =
-      APIBuilder::Instance().GetBuilder()->Build<paddle::dialect::ConcatGradOp>(
-          combine_op.out(), out_grad, axis);
-  auto split_op = APIBuilder::Instance().GetBuilder()->Build<ir::SplitOp>(
-      concat_grad_op.result(0));
-  return split_op.outputs();
-}
-} // namespace dialect
+namespace dialect {} // namespace dialect
 } // namespace paddle
@@ -21,10 +21,5 @@
 #include "paddle/phi/common/place.h"
 
 namespace paddle {
-namespace dialect {
-
-std::vector<ir::OpResult> concat_grad(std::vector<ir::OpResult> x,
-                                      ir::OpResult out_grad,
-                                      ir::OpResult axis);
-} // namespace dialect
+namespace dialect {} // namespace dialect
 } // namespace paddle