未验证 提交 3a452e4e 编写于 作者: Z zhangbo9674 提交者: GitHub

[IR] Refine IR builder and throw methods (#54396)

* refine code

* refine code

* refine code

* refine code

* refine code

* refine code

* refine code

* fix bug

* refine code

* refine code

* refine code

* refine code

* refine code

* delete unused code

* delete unused code

* refine code
上级 b62b384b
......@@ -57,6 +57,7 @@ class {op_name} : public ir::Op<{op_name}{interfaces}{traits}> {{
static constexpr uint32_t attributes_num = {attribute_num};
static OpInfoTuple GetOpInfo();
static void Build({build_args});
{build_mutable_attr_is_input}
static void Verify(const std::vector<ir::OpResult> &inputs, const std::vector<ir::Type> &outputs, const ir::AttributeMap &attributes);
{get_inputs_and_outputs}
{exclusive_interface}
......@@ -94,6 +95,7 @@ CC_FILE_TEMPLATE = """// This file is generated by "paddle/fluid/ir/dialect/op_g
#include "paddle/phi/infermeta/unary.h"
#include "paddle/phi/infermeta/ternary.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/api/lib/utils/allocator.h"
{input}
"""
......@@ -112,9 +114,7 @@ OpInfoTuple {op_name}::GetOpInfo() {{
return std::make_tuple(inputs, attributes, outputs, run_time_info);
}}
"""
CONSTRUCT_INPUT_INFO_TEMPLATE = (
"""OpInputInfo("{name}", "{typename}", {optional}, {no_need_buffer})"""
)
CONSTRUCT_INPUT_INFO_TEMPLATE = """OpInputInfo("{name}", "{typename}", {optional}, {no_need_buffer}, {is_mutable_attribute})"""
CONSTRUCT_OUTPUT_INFO_TEMPLATE = (
"""OpOutputInfo("{name}", "{typename}", {optional}, {intermediate})"""
)
......@@ -125,6 +125,7 @@ CONSTRUCT_ATTRIBUTE_INFO_TEMPLATE = (
# build
OP_BUILD_TEMPLATE = """
void {op_name}::Build({build_args}) {{
{build_mutable_attributes}
{build_inputs}
{build_attributes}
{build_outputs}
......@@ -303,6 +304,7 @@ class OpInfoParser:
self.output_type_list,
self.output_optional_list,
)
# parse attributes
self.attr_types_map = {
'IntArray': ['paddle::dialect::IntArrayAttribute', 'IntArray'],
......@@ -313,7 +315,7 @@ class OpInfoParser:
'Scalar(dobule)': ['ir::DoubleAttribute', 'dobule'],
'Scalar[]': [
'ir::ArrayAttribute<paddle::dialect::ScalarAttribute>',
'std::vector<Scalar>',
'const std::vector<Scalar>&',
],
'int': ['ir::Int32_tAttribute', 'int'],
'int32_t': ['ir::Int32_tAttribute', 'int32_t'],
......@@ -323,18 +325,18 @@ class OpInfoParser:
'float': ['ir::FloatAttribute', 'float'],
'float[]': [
'ir::ArrayAttribute<ir::FloatAttribute>',
'std::vector<float>',
'const std::vector<float>&',
],
'double': ['ir::DoubleAttribute', 'double'],
'bool': ['ir::BoolAttribute', 'bool'],
'bool[]': [
'ir::ArrayAttribute<ir::BoolAttribute>',
'std::vecot<bool>',
'const std::vecot<bool>&',
],
'str': ['ir::StrAttribute', 'std::string'],
'str[]': [
'ir::ArrayAttribute<ir::StrAttribute>',
'std::vector<std::string>',
'const std::vector<std::string>&',
],
'Place': ['paddle::dialect::PlaceAttribute', 'Place'],
'DataLayout': [
......@@ -344,11 +346,11 @@ class OpInfoParser:
'DataType': ['paddle::dialect::DataTypeAttribute', 'DataType'],
'int64_t[]': [
'ir::ArrayAttribute<ir::Int64_tAttribute>',
'std::vector<int64_t>',
'const std::vector<int64_t>&',
],
'int[]': [
'ir::ArrayAttribute<ir::Int32_tAttribute>',
'std::vector<int>',
'const std::vector<int>&',
],
}
self.attribute_name_list = self.parse_attribute_name_list()
......@@ -368,6 +370,14 @@ class OpInfoParser:
self.mutable_attribute_type_list,
) = self.parse_mutable_attribute()
(
self.non_mutable_attribute_name_list,
self.non_mutable_attribute_type_list,
self.non_mutable_attribute_data_type_list,
self.non_mutable_attribute_build_arg_type_list,
self.non_mutable_attribute_default_value_list,
) = self.parse_non_nutable_attribute()
# parse infermeta && kernel
self.infer_meta_map = self.parse_infer_meta_map()
self.kernel_map = self.parse_kernel_map()
......@@ -423,6 +433,12 @@ class OpInfoParser:
mutable_attribute_type_list.append(
["ir::StrAttribute", "std::string"]
)
else:
if (
scalar_attr == "depth"
and self.op_phi_name[0] == "one_hot"
):
mutable_attribute_name_list.append("num_classes")
else:
mutable_attribute_name_list.append(scalar_attr)
mutable_attribute_type_list.append(
......@@ -459,7 +475,55 @@ class OpInfoParser:
],
]
)
return mutable_attribute_name_list, mutable_attribute_type_list
sorted_mutable_attribute_name_list = []
sorted_mutable_attribute_type_list = []
for attr_name in self.attribute_name_list:
if attr_name in mutable_attribute_name_list:
sorted_mutable_attribute_name_list.append(attr_name)
sorted_mutable_attribute_type_list.append(
mutable_attribute_type_list[
mutable_attribute_name_list.index(attr_name)
]
)
return (
sorted_mutable_attribute_name_list,
sorted_mutable_attribute_type_list,
)
def parse_non_nutable_attribute(self):
op_non_mutable_attribute_name_list = []
op_non_mutable_attribute_type_list = []
op_non_mutable_attribute_data_type_list = []
op_non_mutable_attribute_build_arg_type_list = []
op_non_mutable_attribute_default_value_list = []
for idx in range(len(self.attribute_name_list)):
if (
self.attribute_name_list[idx]
not in self.mutable_attribute_name_list
):
op_non_mutable_attribute_name_list.append(
self.attribute_name_list[idx]
)
op_non_mutable_attribute_type_list.append(
self.attribute_type_list[idx]
)
op_non_mutable_attribute_data_type_list.append(
self.attribute_data_type_list[idx]
)
op_non_mutable_attribute_build_arg_type_list.append(
self.attribute_build_arg_type_list[idx]
)
op_non_mutable_attribute_default_value_list.append(
self.attribute_default_value_list[idx]
)
return (
op_non_mutable_attribute_name_list,
op_non_mutable_attribute_type_list,
op_non_mutable_attribute_data_type_list,
op_non_mutable_attribute_build_arg_type_list,
op_non_mutable_attribute_default_value_list,
)
def parse_input_name_list(self):
name_list = []
......@@ -649,20 +713,47 @@ def to_pascal_case(s):
# =====================================
def GenBuildInputArgsStr(
op_input_name_list,
op_attribute_name_list,
op_attribute_build_arg_type_list,
op_attribute_default_value_list,
op_mutable_attribute_name_list,
op_non_mutable_attribute_name_list,
op_non_mutable_attribute_build_arg_type_list,
op_non_mutable_attribute_default_value_list,
for_func_define=True,
mutable_attr_is_input=False,
):
'''
Example: ir::OperationArgument &argument, ir::OpResult x_, phi::DataType dtype=phi::DataType::UNDEFINED, phi::Place place={}
Example: ir::Builder &builder, ir::OperationArgument &argument, ir::OpResult x_, phi::DataType dtype=phi::DataType::UNDEFINED, phi::Place place={}
'''
# add inputs
build_args_str = "ir::OperationArgument &argument"
build_args_str = "ir::Builder &builder, ir::OperationArgument &argument"
if len(op_input_name_list) > 0:
for input_name in op_input_name_list:
build_args_str += ", ir::OpResult " + input_name + "_"
if not mutable_attr_is_input:
# add attributes
for attr_idx in range(len(op_attribute_name_list)):
build_args_str += (
", "
+ op_attribute_build_arg_type_list[attr_idx]
+ " "
+ op_attribute_name_list[attr_idx]
)
if for_func_define:
if op_attribute_default_value_list[attr_idx] is not None:
default_value = op_attribute_default_value_list[attr_idx]
if (
op_attribute_build_arg_type_list[attr_idx]
!= "std::string"
):
if default_value[0] == "'" or default_value[0] == '"':
default_value = default_value[1:]
if default_value[-1] == "'" or default_value[-1] == '"':
default_value = default_value[0:-1]
build_args_str += "=" + default_value
else:
# add mutable attributes as inputs
if len(op_mutable_attribute_name_list) > 0:
for mutable_attr in op_mutable_attribute_name_list:
......@@ -693,9 +784,58 @@ def GenBuildInputArgsStr(
if default_value[-1] == "'" or default_value[-1] == '"':
default_value = default_value[0:-1]
build_args_str += "=" + default_value
return build_args_str
mutable_attribute_phi_type_maps = {
'int': 'phi::DataType::INT32',
'int64_t': 'phi::DataType::INT64',
'float': 'phi::DataType::FLOAT32',
'std::vector<int64_t>': 'phi::DataType::INT64',
'const std::vector<int64_t>&': 'phi::DataType::INT64',
}
def GenBuildInserFullForMutableAttribute(
op_attribute_name_list,
op_attribute_build_arg_type_list,
op_mutable_attribute_name_list,
op_mutable_attribute_type_list,
):
build_mutable_attribute = ""
BUILD_INTARRAY_ATTRIBUTE_TEMPLATE = """ // Generate int_array mutable attribute: {attr_name}
paddle::dialect::FullIntArrayOp full_{attr_name}_op = builder.Build<paddle::dialect::FullIntArrayOp>({attr_name}, {phi_dtype}, phi::CPUPlace());
ir::OpResult {attr_name}_ = full_{attr_name}_op->GetResultByIndex(0);
"""
BUILD_SCALAR_ATTRIBUTE_TEMPLATE = """ // Generate scalar mutable attribute: {attr_name}
paddle::dialect::FullOp full_{attr_name}_op = builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{{1}}, {attr_name}, {phi_dtype}, phi::CPUPlace());
ir::OpResult {attr_name}_ = full_{attr_name}_op->GetResultByIndex(0);
"""
for idx in range(len(op_mutable_attribute_name_list)):
attr_name = op_mutable_attribute_name_list[idx]
attr_type = op_mutable_attribute_type_list[idx][0]
if attr_name in op_attribute_name_list:
phi_dtype = mutable_attribute_phi_type_maps[
op_attribute_build_arg_type_list[
op_attribute_name_list.index(attr_name)
]
]
else:
phi_dtype = mutable_attribute_phi_type_maps[
op_mutable_attribute_type_list[idx][1]
]
if attr_type == "paddle::dialect::IntArrayAttribute":
build_mutable_attribute += BUILD_INTARRAY_ATTRIBUTE_TEMPLATE.format(
attr_name=attr_name, phi_dtype=phi_dtype
)
else:
build_mutable_attribute += BUILD_SCALAR_ATTRIBUTE_TEMPLATE.format(
attr_name=attr_name, phi_dtype=phi_dtype
)
return build_mutable_attribute
def GenBuildInputs(op_input_name_list, op_mutable_attribute_name_list):
BUILD_INPUT_TEMPLATE = """ std::vector<ir::OpResult> argument_inputs = {{{inputs_args}}};
argument.AddOperands(argument_inputs.begin(), argument_inputs.end());
......@@ -805,38 +945,39 @@ def GenBuildOutputs(
op_output_type_list,
op_output_size_list,
op_infer_meta_map,
mutable_attr_is_input=False,
):
build_output_str = ' VLOG(4) << "Builder construction outputs";\n'
CREATE_INPUT_METATENSOR_TEMPLATE = """ phi::DenseTensor dense_{name};
dense_{name}.set_meta(
CREATE_INPUT_METATENSOR_TEMPLATE = """
VLOG(4) << "Builder construction dense_{name}";
phi::DenseTensor dense_{name}(std::make_unique<paddle::experimental::DefaultAllocator>(paddle::platform::CPUPlace()).get(),
phi::DenseTensorMeta(TransToPhiDataType({name}.dtype()),
{name}.dims(),
{name}.data_layout(),
{name}.lod(),
{name}.offset())
);
{name}.offset()));
VLOG(4) << "Builder construction meta_{name}";
phi::MetaTensor meta_{name}(&dense_{name});
"""
CREATE_INPUT_VEC_METATENSOR_TEMPLATE = """ std::vector<phi::DenseTensor> vec_dense_{name}({name}.size(), phi::DenseTensor());
std::vector<phi::MetaTensor> vec_meta_{name};
CREATE_INPUT_VEC_METATENSOR_TEMPLATE = """ std::vector<phi::DenseTensor> vec_dense_{name};
for (size_t i=0; i < static_cast<size_t>({name}.size()); i++) {{
vec_dense_{name}[i].set_meta(
vec_dense_{name}.push_back(phi::DenseTensor(std::make_unique<paddle::experimental::DefaultAllocator>(paddle::platform::CPUPlace()).get(),
phi::DenseTensorMeta(TransToPhiDataType({name}[i].dyn_cast<paddle::dialect::DenseTensorType>().dtype()),
{name}[i].dyn_cast<paddle::dialect::DenseTensorType>().dims(),
{name}[i].dyn_cast<paddle::dialect::DenseTensorType>().data_layout(),
{name}[i].dyn_cast<paddle::dialect::DenseTensorType>().lod(),
{name}[i].dyn_cast<paddle::dialect::DenseTensorType>().offset())
);
{name}[i].dyn_cast<paddle::dialect::DenseTensorType>().offset())));
}}
std::vector<phi::MetaTensor> vec_meta_{name};
for (size_t i=0; i < vec_dense_{name}.size(); i++) {{
vec_meta_{name}.push_back(phi::MetaTensor(&vec_dense_{name}[i]));
}}
std::vector<const phi::MetaTensor*> meta_{name};
for (size_t i=0; i < static_cast<size_t>(vec_meta_{name}.size()); i++) {{
meta_{name}.push_back(&vec_meta_{name}[i]);
}}
"""
CREATE_INTARRAY_MUTABLE_ATTRIBUE_TEMPLATE = """ std::vector<int64_t> {name} = {name}_.owner()->dyn_cast<ir::ConstantOp>().value().dyn_cast<paddle::dialect::IntArrayAttribute>().data().GetData(); (void){name};\n"""
CREATE_SCALAR_MUTABLE_ATTRIBUE_TEMPLATE = """ {dtype} {name} = {name}_.owner()->dyn_cast<ir::ConstantOp>().value().dyn_cast<{ir_type}>().data(); (void){name};\n"""
CREATE_STRING_MUTABLE_ATTRIBUE_TEMPLATE = """ std::string {name} = {name}_.owner()->dyn_cast<ir::ConstantOp>().value().dyn_cast<ir::StrAttribute>().data(); (void){name};\n"""
CREATE_OUTPUT_METATENSOR_TEMPLATE = """ phi::DenseTensor dense_{name};
phi::MetaTensor meta_{name}(&dense_{name});
......@@ -863,31 +1004,6 @@ def GenBuildOutputs(
build_output_str += " paddle::dialect::DenseTensorType {name} = {name}_.type().dyn_cast<paddle::dialect::DenseTensorType>(); (void){name};\n".format(
name=op_input_name_list[idx]
)
# Prepare mutable attributes
for idx in range(len(op_mutable_attribute_name_list)):
attr_dtype = op_mutable_attribute_type_list[idx]
# int_array
if attr_dtype[0] == "paddle::dialect::IntArrayAttribute":
build_output_str += (
CREATE_INTARRAY_MUTABLE_ATTRIBUE_TEMPLATE.format(
name=op_mutable_attribute_name_list[idx]
)
)
# scalar
elif attr_dtype[0] == "paddle::dialect::ScalarAttribute":
build_output_str += CREATE_SCALAR_MUTABLE_ATTRIBUE_TEMPLATE.format(
name=op_mutable_attribute_name_list[idx],
dtype=attr_dtype[1],
ir_type=scalar_type_maps[attr_dtype[1]],
)
# string
elif attr_dtype[0] == "ir::StrAttribute":
build_output_str += CREATE_STRING_MUTABLE_ATTRIBUE_TEMPLATE.format(
name=op_mutable_attribute_name_list[idx]
)
else:
assert "mutable attribtue type is not right."
build_output_str += "\n"
# Prepare inputs_meta_tensor & attributes for infer meta
infer_meta_args = []
......@@ -979,6 +1095,95 @@ def GenBuildOutputs(
return build_output_str
def GenBuild(
op_class_name,
op_input_name_list,
op_input_type_list,
op_attribute_name_list,
op_attribute_build_arg_type_list,
op_attribute_default_value_list,
op_mutable_attribute_name_list,
op_mutable_attribute_type_list,
op_non_mutable_attribute_name_list,
op_non_mutable_attribute_type_list,
op_non_mutable_attribute_build_arg_type_list,
op_non_mutable_attribute_default_value_list,
op_output_name_list,
op_output_type_list,
op_output_size_list,
op_infer_meta_map,
muta_attr_is_input=False,
):
build_args_for_declare = ""
build_func = ""
build_args_for_declare = GenBuildInputArgsStr(
op_input_name_list,
op_attribute_name_list,
op_attribute_build_arg_type_list,
op_attribute_default_value_list,
op_mutable_attribute_name_list,
op_non_mutable_attribute_name_list,
op_non_mutable_attribute_build_arg_type_list,
op_non_mutable_attribute_default_value_list,
True,
muta_attr_is_input,
)
build_args_for_define = GenBuildInputArgsStr(
op_input_name_list,
op_attribute_name_list,
op_attribute_build_arg_type_list,
op_attribute_default_value_list,
op_mutable_attribute_name_list,
op_non_mutable_attribute_name_list,
op_non_mutable_attribute_build_arg_type_list,
op_non_mutable_attribute_default_value_list,
False,
muta_attr_is_input,
)
inset_full_for_mutable_attributes_str = ""
if not muta_attr_is_input:
inset_full_for_mutable_attributes_str = (
GenBuildInserFullForMutableAttribute(
op_attribute_name_list,
op_attribute_build_arg_type_list,
op_mutable_attribute_name_list,
op_mutable_attribute_type_list,
)
)
build_inputs_str = GenBuildInputs(
op_input_name_list, op_mutable_attribute_name_list
)
build_attributes_str = GenBuildAttributes(
op_non_mutable_attribute_name_list,
op_non_mutable_attribute_type_list,
)
build_outputs_str = GenBuildOutputs(
op_input_name_list,
op_input_type_list,
op_mutable_attribute_name_list,
op_mutable_attribute_type_list,
op_output_name_list,
op_output_type_list,
op_output_size_list,
op_infer_meta_map,
False,
)
build_func = OP_BUILD_TEMPLATE.format(
op_name=op_class_name,
build_args=build_args_for_define,
build_mutable_attributes=inset_full_for_mutable_attributes_str,
build_inputs=build_inputs_str,
build_attributes=build_attributes_str,
build_outputs=build_outputs_str,
)
return (build_args_for_declare, build_func)
def OpGenerator(
op_yaml_files,
op_compat_yaml_file,
......@@ -1032,31 +1237,22 @@ def OpGenerator(
op_attribute_data_type_list = op_info.attribute_data_type_list
op_attribute_build_arg_type_list = op_info.attribute_build_arg_type_list
op_attribute_default_value_list = op_info.attribute_default_value_list
op_non_mutable_attribute_name_list = []
op_non_mutable_attribute_type_list = []
op_non_mutable_attribute_data_type_list = []
op_non_mutable_attribute_build_arg_type_list = []
op_non_mutable_attribute_default_value_list = []
for idx in range(len(op_attribute_name_list)):
if (
op_attribute_name_list[idx]
not in op_mutable_attribute_name_list
):
op_non_mutable_attribute_name_list.append(
op_attribute_name_list[idx]
op_non_mutable_attribute_name_list = (
op_info.non_mutable_attribute_name_list
)
op_non_mutable_attribute_type_list.append(
op_attribute_type_list[idx]
op_non_mutable_attribute_type_list = (
op_info.non_mutable_attribute_type_list
)
op_non_mutable_attribute_data_type_list.append(
op_attribute_data_type_list[idx]
op_non_mutable_attribute_data_type_list = (
op_info.non_mutable_attribute_data_type_list
)
op_non_mutable_attribute_build_arg_type_list.append(
op_attribute_build_arg_type_list[idx]
op_non_mutable_attribute_build_arg_type_list = (
op_info.non_mutable_attribute_build_arg_type_list
)
op_non_mutable_attribute_default_value_list.append(
op_attribute_default_value_list[idx]
op_non_mutable_attribute_default_value_list = (
op_info.non_mutable_attribute_default_value_list
)
# others
op_infer_meta_map = op_info.infer_meta_map
op_kernel_map = op_info.kernel_map
......@@ -1075,7 +1271,9 @@ def OpGenerator(
op_class_name = to_pascal_case(op_name) + "Op"
op_dialect_name = dialect_name + "." + op_name
# gen interface/trait str
# =================================== #
# gen interface/trait list str #
# =================================== #
op_interfaces_str = ""
if len(op_interfaces) > 0:
op_interfaces_str = "," + ",".join(op_interfaces)
......@@ -1083,6 +1281,9 @@ def OpGenerator(
if len(op_traits) > 0:
op_traits_str = "," + ",".join(op_traits)
# =================================== #
# gen get input/output methods str #
# =================================== #
op_get_inputs_outputs_str = ""
for idx in range(len(op_input_name_list)):
op_get_inputs_outputs_str += OP_GET_INPUT_TEMPLATE.format(
......@@ -1100,58 +1301,72 @@ def OpGenerator(
output_index=idx,
)
# gen build str
build_define_input_args_str = ""
build_declare_input_args_str = ""
build_func_declare_str = ""
# =================================== #
# gen Build methods str #
# =================================== #
build_args_with_muta_attr_not_input_for_declare = ""
build_func_with_muta_attr_not_input = ""
build_mutable_attr_is_input = ""
build_func_with_muta_attr_is_input = ""
if op_infer_meta_map is not None:
build_define_input_args_str = GenBuildInputArgsStr(
op_input_name_list,
op_mutable_attribute_name_list,
op_non_mutable_attribute_name_list,
op_non_mutable_attribute_build_arg_type_list,
op_non_mutable_attribute_default_value_list,
True,
)
build_declare_input_args_str = GenBuildInputArgsStr(
(
build_args_with_muta_attr_not_input_for_declare,
build_func_with_muta_attr_not_input,
) = GenBuild(
op_class_name,
op_input_name_list,
op_input_type_list,
op_attribute_name_list,
op_attribute_build_arg_type_list,
op_attribute_default_value_list,
op_mutable_attribute_name_list,
op_mutable_attribute_type_list,
op_non_mutable_attribute_name_list,
op_non_mutable_attribute_type_list,
op_non_mutable_attribute_build_arg_type_list,
op_non_mutable_attribute_default_value_list,
False,
op_output_name_list,
op_output_type_list,
op_output_size_list,
op_infer_meta_map,
muta_attr_is_input=False,
)
build_inputs_str = GenBuildInputs(
op_input_name_list, op_mutable_attribute_name_list
op_infer_meta_args = op_infer_meta_map['param']
if (len(op_mutable_attribute_name_list) > 0) and (
len(
list(
set(op_infer_meta_args)
& set(op_mutable_attribute_name_list)
)
build_attributes_str = GenBuildAttributes(
op_non_mutable_attribute_name_list,
op_non_mutable_attribute_type_list,
)
build_outputs_str = GenBuildOutputs(
== 0
):
(
build_args_with_muta_attr_is_input_for_declare,
build_func_with_muta_attr_is_input,
) = GenBuild(
op_class_name,
op_input_name_list,
op_input_type_list,
op_attribute_name_list,
op_attribute_build_arg_type_list,
op_attribute_default_value_list,
op_mutable_attribute_name_list,
op_mutable_attribute_type_list,
op_non_mutable_attribute_name_list,
op_non_mutable_attribute_type_list,
op_non_mutable_attribute_build_arg_type_list,
op_non_mutable_attribute_default_value_list,
op_output_name_list,
op_output_type_list,
op_output_size_list,
op_infer_meta_map,
muta_attr_is_input=True,
)
build_func_declare_str = OP_BUILD_TEMPLATE.format(
op_name=op_class_name,
build_args=build_declare_input_args_str,
build_inputs=build_inputs_str,
build_attributes=build_attributes_str,
build_outputs=build_outputs_str,
)
else:
build_func_declare_str = OP_BUILD_TEMPLATE.format(
op_name=op_class_name,
build_args=build_declare_input_args_str,
build_inputs="",
build_attributes="",
build_outputs="",
build_mutable_attr_is_input = "static void Build({build_args});".format(
build_args=build_args_with_muta_attr_is_input_for_declare
)
# gen op_declare_str/op_defined_str
......@@ -1163,7 +1378,8 @@ def OpGenerator(
traits=op_traits_str,
attribute_declare=op_0_attribute_declare_str,
attribute_num=0,
build_args=build_define_input_args_str,
build_args=build_args_with_muta_attr_not_input_for_declare,
build_mutable_attr_is_input=build_mutable_attr_is_input,
get_inputs_and_outputs=op_get_inputs_outputs_str,
exclusive_interface=exclusive_interface_str,
)
......@@ -1178,7 +1394,8 @@ def OpGenerator(
attribute_num=len(op_non_mutable_attribute_name_list)
),
attribute_num=len(op_non_mutable_attribute_name_list),
build_args=build_define_input_args_str,
build_args=build_args_with_muta_attr_not_input_for_declare,
build_mutable_attr_is_input=build_mutable_attr_is_input,
get_inputs_and_outputs=op_get_inputs_outputs_str,
exclusive_interface=exclusive_interface_str,
)
......@@ -1191,10 +1408,11 @@ def OpGenerator(
attribute_names=attribute_names_str,
)
# =================================== #
# gen GetOpInfo func str #
# =================================== #
# generate get op info funciton: inputs
inputs_info_str = ""
input_info_list = []
if len(op_input_name_list) > 0:
for idx in range(len(op_input_name_list)):
input_info_list.append(
CONSTRUCT_INPUT_INFO_TEMPLATE.format(
......@@ -1202,22 +1420,23 @@ def OpGenerator(
typename=op_input_type_list[idx],
optional=op_input_optional_list[idx],
no_need_buffer=op_input_no_need_buffer_list[idx],
is_mutable_attribute='false',
)
)
# add mutable attribute as input
if len(op_mutable_attribute_name_list) > 0:
for idx in range(len(op_mutable_attribute_name_list)):
input_info_list.append(
CONSTRUCT_INPUT_INFO_TEMPLATE.format(
name=op_mutable_attribute_name_list[idx],
typename=op_mutable_attribute_type_list[idx],
typename=op_mutable_attribute_type_list[idx][0],
optional='false',
no_need_buffer='false',
is_mutable_attribute='true',
)
)
if len(input_info_list) > 0:
inputs_info_str = ", ".join(input_info_list)
else:
inputs_info_str = ""
# generate get op info funciton: outputs
outputs_info_str = ""
if len(op_output_name_list) > 0:
......@@ -1232,25 +1451,21 @@ def OpGenerator(
)
)
outputs_info_str = ", ".join(output_info_list)
# generate get op info funciton: attributes
attribute_info_str = ""
op_mutable_attribute_name_set = set(op_mutable_attribute_name_list)
if len(op_attribute_name_list) > 0:
if len(op_non_mutable_attribute_name_list) > 0:
attribute_info_list = []
for idx in range(len(op_attribute_name_list)):
attribute_name = op_attribute_name_list[idx]
if attribute_name in op_mutable_attribute_name_set:
continue
for idx in range(len(op_non_mutable_attribute_name_list)):
attribute_info_list.append(
CONSTRUCT_ATTRIBUTE_INFO_TEMPLATE.format(
name=attribute_name,
typename=op_attribute_type_list[idx],
data_type=op_attribute_data_type_list[idx],
name=op_non_mutable_attribute_name_list[idx],
typename=op_non_mutable_attribute_type_list[idx],
data_type=op_non_mutable_attribute_data_type_list[
idx
],
)
)
attribute_info_str = ", ".join(attribute_info_list)
# generate runtiem info
infer_meta_func_str = ""
infer_meta_param_str = ""
......@@ -1274,6 +1489,9 @@ def OpGenerator(
kernel_param=kernel_param_str,
)
# =================================== #
# gen Verify func str #
# =================================== #
# generate op verify function: inputs_type_check_str
if (
len(op_input_type_list) + len(op_mutable_attribute_name_list)
......@@ -1327,7 +1545,6 @@ def OpGenerator(
standard="paddle::dialect::DenseTensorType",
)
inputs_type_check_str += check_str
# generate op verify function: outputs_type_check_str
if len(op_output_type_list) == 0:
outputs_type_check_str = (
......@@ -1364,7 +1581,6 @@ def OpGenerator(
index=idx, standard=output_type
)
outputs_type_check_str += check_str
# generate op verify function: attributes_check_str
if len(op_non_mutable_attribute_name_list) == 0:
attributes_check_str = (
......@@ -1387,7 +1603,6 @@ def OpGenerator(
attributes_check_str += ATTRIBUTE_CHECK_TEMPLATE.format(
attribute_name=attribute_name, standard=attribute_type
)
# generate op verify function
if "GradOp" in op_class_name or "Grad_Op" in op_class_name:
op_verify_str = GRAD_OP_VERIFY_TEMPLATE.format(
......@@ -1415,7 +1630,9 @@ def OpGenerator(
ops_declare_list.append(op_declare_str)
ops_defined_list.append(op_defined_str)
ops_defined_list.append(op_info_func_str)
ops_defined_list.append(build_func_declare_str)
ops_defined_list.append(build_func_with_muta_attr_not_input)
if len(op_mutable_attribute_name_list) > 0:
ops_defined_list.append(build_func_with_muta_attr_is_input)
ops_defined_list.append(op_verify_str)
ops_defined_list.append(op_infer_shape_str)
......
......@@ -26,5 +26,23 @@ phi::DataLayout DataLayoutAttribute::data() const {
return storage()->GetAsKey();
}
phi::Scalar ScalarAttribute::data() {
if (isa<ir::FloatAttribute>()) {
return phi::Scalar(dyn_cast<ir::FloatAttribute>().data());
} else if (isa<ir::DoubleAttribute>()) {
return phi::Scalar(dyn_cast<ir::DoubleAttribute>().data());
} else if (isa<ir::Int32_tAttribute>()) {
return phi::Scalar(dyn_cast<ir::Int32_tAttribute>().data());
} else if (isa<ir::Int64_tAttribute>()) {
return phi::Scalar(dyn_cast<ir::Int64_tAttribute>().data());
} else if (isa<ir::BoolAttribute>()) {
return phi::Scalar(dyn_cast<ir::BoolAttribute>().data());
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"Unsupported ir attribute when casting it into "
"phi scalar."));
}
}
} // namespace dialect
} // namespace paddle
......@@ -17,6 +17,8 @@
#include "paddle/fluid/ir/dialect/pd_attribute_storage.h"
#include "paddle/ir/core/attribute.h"
#include "paddle/ir/core/builtin_attribute.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/enforce.h"
namespace paddle {
namespace dialect {
......@@ -45,6 +47,8 @@ class ScalarAttribute : public ir::Attribute {
(val.type_id() == ir::Int32_tAttribute::type_id()) ||
(val.type_id() == ir::Int64_tAttribute::type_id());
}
phi::Scalar data();
};
class DataTypeAttribute : public ir::Attribute {
......
......@@ -101,14 +101,17 @@ struct OpInputInfo {
std::string type_name;
bool optional = false;
bool no_need_buffer = false;
bool is_mutable_attribute = false;
OpInputInfo(std::string name,
std::string type_name,
bool optional,
bool no_need_buffer)
bool no_need_buffer,
bool is_mutable_attribute)
: name(name),
type_name(type_name),
optional(optional),
no_need_buffer(no_need_buffer) {}
no_need_buffer(no_need_buffer),
is_mutable_attribute(is_mutable_attribute) {}
};
struct OpOutputInfo {
......
......@@ -56,7 +56,7 @@ class Builder {
template <typename OpTy, typename... Args>
OpTy Build(Args &&...args) {
OperationArgument argument(context_->GetRegisteredOpInfo(OpTy::name()));
OpTy::Build(argument, std::forward<Args>(args)...);
OpTy::Build(*this, argument, std::forward<Args>(args)...);
Operation *op = Build(std::move(argument));
return op->dyn_cast<OpTy>();
}
......
......@@ -57,20 +57,15 @@ void ModuleOp::Verify(const std::vector<ir::OpResult> &inputs,
const ir::AttributeMap &attributes) {
VLOG(4) << "Verifying inputs, outputs and attributes for: ModuleOp.";
// Verify inputs type:
if (inputs.size() != 0) {
throw("The size of inputs must be equal to 0.");
}
IR_ENFORCE(inputs.size() == 0, "The size of inputs must be equal to 0.");
// Verify if attributes contain attribute name in attributes_name:
auto iter = attributes.find("program");
if (iter == attributes.end() || !iter->second.isa<PointerAttribute>()) {
throw("Type of attribute: program is not right.");
}
IR_ENFORCE(iter != attributes.end() && iter->second.isa<PointerAttribute>(),
"Type of attribute: program is not right.");
// Verify outputs type:
if (outputs.size() != 0) {
throw("The size of outputs must be equal to 0.");
}
IR_ENFORCE(outputs.size() == 0, "The size of outputs must be equal to 0.");
}
const char *GetParameterOp::attributes_name[attributes_num] = {
......@@ -81,17 +76,15 @@ void GetParameterOp::Verify(const std::vector<ir::OpResult> &inputs,
const ir::AttributeMap &attributes) {
VLOG(4) << "Verifying inputs, outputs and attributes for: GetParameterOp.";
// Verify inputs type:
if (inputs.size() != 0) {
throw("The size of inputs must be equal to 0.");
}
// Verify outputs type:
if (outputs.size() != 1) {
throw("The size of outputs must be equal to 1.");
}
IR_ENFORCE(inputs.size() == 0, "The size of inputs must be equal to 0.");
// Verify if attributes contain attribute name in attributes_name:
if (!attributes.at("parameter_name").isa<StrAttribute>()) {
throw("Type of attribute: parameter_name is not right.");
}
auto iter = attributes.find("parameter_name");
IR_ENFORCE(iter != attributes.end() && iter->second.isa<StrAttribute>(),
"Type of attribute: parameter_name is not right.");
// Verify outputs type:
IR_ENFORCE(outputs.size() == 1, "The size of outputs must be equal to 1.");
}
const char *SetParameterOp::attributes_name[attributes_num] = {
......@@ -102,54 +95,45 @@ void SetParameterOp::Verify(const std::vector<ir::OpResult> &inputs,
const ir::AttributeMap &attributes) {
VLOG(4) << "Verifying inputs, outputs and attributes for: SetParameterOp.";
// Verify inputs type:
if (inputs.size() != 1) {
throw("The size of inputs must be equal to 1.");
}
// Verify outputs type:
if (outputs.size() != 0) {
throw("The size of outputs must be equal to 0.");
}
IR_ENFORCE(inputs.size() == 1, "The size of outputs must be equal to 1.");
// Verify if attributes contain attribute name in attributes_name:
if (!attributes.at("parameter_name").isa<StrAttribute>()) {
throw("Type of attribute: parameter_name is not right.");
}
auto iter = attributes.find("parameter_name");
IR_ENFORCE(iter != attributes.end() && iter->second.isa<StrAttribute>(),
"Type of attribute: parameter_name is not right.");
// Verify outputs type:
IR_ENFORCE(outputs.size() == 0, "The size of outputs must be equal to 0.");
}
void CombineOp::Verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes) {
// outputs.size() == 1
PADDLE_ENFORCE_EQ(
outputs.size(),
1,
phi::errors::PreconditionNotMet(
"The size %d of outputs must be equal to 1.", outputs.size()));
IR_ENFORCE(outputs.size() == 1,
"The size %d of outputs must be equal to 1.",
outputs.size());
// outputs[0].type == Vector<Type>
PADDLE_ENFORCE(outputs[0].isa<ir::VectorType>(),
phi::errors::PreconditionNotMet(
IR_ENFORCE(outputs[0].isa<ir::VectorType>(),
"The type %s of outputs[0] must be equal to VectorType.",
outputs[0]));
outputs[0]);
ir::VectorType output_type = outputs[0].dyn_cast<ir::VectorType>();
// inputs.size() == outputs[0].size()
PADDLE_ENFORCE_EQ(
output_type.size(),
inputs.size(),
phi::errors::PreconditionNotMet(
IR_ENFORCE(output_type.size() == inputs.size(),
"The size %d of outputs[0] must be equal to size %d of inputs.",
output_type.size(),
inputs.size()));
inputs.size());
// forall i in inputs.size(): inputs[i].type == outputs[0][i].type
for (size_t i = 0; i < inputs.size(); i++) {
PADDLE_ENFORCE_EQ(
output_type[i],
inputs[i].type(),
phi::errors::PreconditionNotMet("The type %s of outputs[0][%d] must be "
IR_ENFORCE(output_type[i] == inputs[i].type(),
"The type %s of outputs[0][%d] must be "
"equal to type %s of inputs[%d].",
output_type[i],
i,
inputs[i].type(),
i));
i);
}
}
......@@ -158,65 +142,50 @@ void SliceOp::Verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes) {
// inputs.size() == 1
PADDLE_ENFORCE_EQ(
inputs.size(),
1,
phi::errors::PreconditionNotMet(
"The size %d of inputs must be equal to 1.", inputs.size()));
IR_ENFORCE(inputs.size() == 1,
"The size %d of inputs must be equal to 1.",
inputs.size());
// inputs[0].type == Vector<Type>
PADDLE_ENFORCE(inputs[0].type().isa<ir::VectorType>(),
phi::errors::PreconditionNotMet(
IR_ENFORCE(inputs[0].type().isa<ir::VectorType>(),
"The type %s of inputs[0] must be equal to VectorType.",
inputs[0].type()));
inputs[0].type());
ir::VectorType input_type = inputs[0].type().dyn_cast<ir::VectorType>();
// outputs.size() == 1
PADDLE_ENFORCE_EQ(
outputs.size(),
1,
phi::errors::PreconditionNotMet(
"The size %d of outputs must be equal to 1.", outputs.size()));
IR_ENFORCE(outputs.size() == 1,
"The size %d of outputs must be equal to 1.",
outputs.size());
// attributes contains index: Int32
PADDLE_ENFORCE_NE(
attributes.count("index"),
0,
phi::errors::PreconditionNotMet("The attributes must contains index."));
IR_ENFORCE(attributes.count("index") != 0,
"The attributes must contains index.");
const ir::Attribute &attr = attributes.at("index");
PADDLE_ENFORCE(
attr.isa<ir::Int32_tAttribute>(),
phi::errors::PreconditionNotMet("The attribute index must be INT32."));
IR_ENFORCE(attr.isa<ir::Int32_tAttribute>(),
"The attribute index must be INT32.");
auto index = attr.dyn_cast<ir::Int32_tAttribute>().data();
// index >= 0 and < inputs[0].size()
PADDLE_ENFORCE_GE(
index,
0,
phi::errors::PreconditionNotMet(
"The index %d must be greater or equal than 0.", index));
PADDLE_ENFORCE_LT(
index,
input_type.size(),
phi::errors::PreconditionNotMet(
IR_ENFORCE(
index >= 0, "The index %d must be greater or equal than 0.", index);
IR_ENFORCE(static_cast<size_t>(index) < input_type.size(),
"The index %d must be less or equal than size %d of inputs[0].",
index,
input_type.size()));
input_type.size());
// inputs[index].type == outputs[0].type
PADDLE_ENFORCE_EQ(
input_type[index],
outputs[0],
phi::errors::PreconditionNotMet(
IR_ENFORCE(
input_type[index] == outputs[0],
"The type %s of inputs[%d] must be equal to type %s of outputs[0].",
input_type[index],
index,
outputs[0]));
outputs[0]);
}
const char *ConstantOp::attributes_name[attributes_num] = {"value"};
void ConstantOp::Build(OperationArgument &argument,
void ConstantOp::Build(Builder &builder,
OperationArgument &argument,
Attribute value,
Type output_type) {
argument.AddAttribute("value", value);
......
......@@ -86,6 +86,7 @@ class CombineOp : public ir::Op<CombineOp> {
static constexpr uint32_t attributes_num = 0;
static constexpr const char **attributes_name = nullptr;
static void Verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes);
......@@ -125,7 +126,8 @@ class ConstantOp : public Op<ConstantOp, ConstantLikeTrait> {
static constexpr uint32_t attributes_num = 1;
static const char *attributes_name[attributes_num];
static void Build(OperationArgument &argument, // NOLINT
static void Build(Builder &builder, // NOLINT
OperationArgument &argument, // NOLINT
Attribute value,
Type output_type);
......
......@@ -16,6 +16,7 @@
#include "paddle/ir/core/block.h"
#include "paddle/ir/core/dialect.h"
#include "paddle/ir/core/enforce.h"
#include "paddle/ir/core/op_info.h"
#include "paddle/ir/core/operation.h"
#include "paddle/ir/core/program.h"
......@@ -85,7 +86,7 @@ Operation *Operation::Create(const std::vector<ir::OpResult> &inputs,
base_ptr += sizeof(Operation);
// 3.3. Construct OpOperands.
if ((reinterpret_cast<uintptr_t>(base_ptr) & 0x7) != 0) {
throw("The address of OpOperandImpl must be divisible by 8.");
IR_THROW("The address of OpOperandImpl must be divisible by 8.");
}
for (size_t idx = 0; idx < num_operands; idx++) {
new (base_ptr) detail::OpOperandImpl(inputs[idx].impl_, op);
......@@ -147,7 +148,7 @@ void Operation::Destroy() {
// 2.2. Deconstruct Operation.
if (reinterpret_cast<uintptr_t>(base_ptr) !=
reinterpret_cast<uintptr_t>(this)) {
throw("Operation address error");
IR_THROW("Operation address error");
}
reinterpret_cast<Operation *>(base_ptr)->~Operation();
base_ptr += sizeof(Operation);
......@@ -178,7 +179,7 @@ Operation::Operation(const AttributeMap &attributes,
ir::OpResult Operation::GetResultByIndex(uint32_t index) const {
if (index >= num_results_) {
throw("index exceeds OP output range.");
IR_THROW("index exceeds OP output range.");
}
uint32_t max_inline_idx = detail::OpResultImpl::GetMaxInlineResultIndex();
const char *ptr =
......@@ -199,7 +200,7 @@ ir::OpResult Operation::GetResultByIndex(uint32_t index) const {
ir::OpOperand Operation::GetOperandByIndex(uint32_t index) const {
if (index >= num_operands_) {
throw("index exceeds OP input range.");
IR_THROW("index exceeds OP input range.");
}
const char *ptr = reinterpret_cast<const char *>(this) + sizeof(Operation) +
(index) * sizeof(detail::OpOperandImpl);
......
......@@ -17,6 +17,8 @@
#include <memory>
#include <unordered_map>
#include "paddle/ir/core/enforce.h"
namespace ir {
// This is a structure for creating, caching, and looking up Storage of
// parametric types.
......@@ -76,7 +78,7 @@ StorageManager::StorageBase *StorageManager::GetParametricStorageImpl(
<< std::hash<ir::TypeId>()(type_id) << ", param_hash=" << hash_value
<< "].";
if (parametric_instance_.find(type_id) == parametric_instance_.end()) {
throw("The input data pointer is null.");
IR_THROW("The input data pointer is null.");
}
ParametricStorageManager &parametric_storage = *parametric_instance_[type_id];
return parametric_storage.GetOrCreate(hash_value, equal_func, constructor);
......@@ -88,7 +90,7 @@ StorageManager::StorageBase *StorageManager::GetParameterlessStorageImpl(
VLOG(4) << "Try to get a parameterless storage of: [TypeId_hash="
<< std::hash<ir::TypeId>()(type_id) << "].";
if (parameterless_instance_.find(type_id) == parameterless_instance_.end())
throw("TypeId not found in IrContext.");
IR_THROW("TypeId not found in IrContext.");
StorageBase *parameterless_instance = parameterless_instance_[type_id];
return parameterless_instance;
}
......@@ -107,7 +109,7 @@ void StorageManager::RegisterParameterlessStorageImpl(
VLOG(4) << "Register a parameterless storage of: [TypeId_hash="
<< std::hash<ir::TypeId>()(type_id) << "].";
if (parameterless_instance_.find(type_id) != parameterless_instance_.end())
throw("storage class already registered");
IR_THROW("storage class already registered");
parameterless_instance_.emplace(type_id, constructor());
}
......
......@@ -427,6 +427,18 @@
data_type : dtype
backend : place
- op : full_int_array
args : (IntArray value, DataType dtype=DataType::FLOAT32, Place place=CPUPlace())
output: Tensor(out)
infer_meta :
func : CreateIntArrayInferMeta
param : [value, dtype]
kernel :
func : full_int_array
param : [value, dtype]
data_type : dtype
backend : place
- op : full_like
args : (Tensor x, Scalar value, DataType dtype = DataType::UNDEFINED, Place place = {})
output: Tensor(out)
......
......@@ -62,9 +62,9 @@
beta2 :
data_type : float
tensor_name : Beta2Tensor
episilon :
epsilon :
data_type : float
tensor_name : EpisilonTensor
tensor_name : EpsilonTensor
manual_signature : [adam_]
- op : adamax_
......@@ -85,9 +85,9 @@
beta2 :
data_type : float
tensor_name : Beta2Tensor
episilon :
epsilon :
data_type : float
tensor_name : EpisilonTensor
tensor_name : EpsilonTensor
- op : add (elementwise_add)
backward : add_grad (elementwise_add_grad)
......@@ -1970,7 +1970,7 @@
outputs:
out : Out
int_array:
axis :
dims :
data_type : int
extra :
attrs : [bool use_mkldnn = false]
......
......@@ -41,6 +41,15 @@ void CreateInferMeta(const IntArray& shape, DataType dtype, MetaTensor* out) {
CreateInferMetaBase(shape.GetData(), dtype, DataLayout::NCHW, out);
}
void CreateIntArrayInferMeta(const IntArray& data,
DataType dtype,
MetaTensor* out) {
CreateInferMetaBase({static_cast<int64_t>(data.GetData().size())},
dtype,
DataLayout::NCHW,
out);
}
void CreateInferMetaBase(const std::vector<int64_t>& shape,
DataType dtype,
DataLayout layout,
......
......@@ -35,6 +35,10 @@ void AssignValueInferMeta(const std::vector<int>& shape,
DataType dtype,
MetaTensor* out);
void CreateIntArrayInferMeta(const IntArray& data,
DataType dtype,
MetaTensor* out);
void CreateInferMeta(const IntArray& shape, DataType dtype, MetaTensor* out);
void CreateInferMetaBase(const std::vector<int64_t>& shape,
......
......@@ -80,6 +80,18 @@ void FullLikeKernel(const Context& dev_ctx,
FullValue<T>(dev_ctx, out, value);
}
template <typename T, typename Context>
void FullIntArrayKernel(const Context& dev_ctx,
const IntArray& val,
DataType dtype UNUSED,
DenseTensor* out) {
out->Resize(phi::make_ddim({static_cast<int64_t>(val.GetData().size())}));
T* out_data = dev_ctx.template Alloc<T>(out);
for (size_t i = 0; i < val.GetData().size(); ++i) {
out_data[i] = static_cast<T>(val.GetData()[i]);
}
}
} // namespace phi
PD_REGISTER_KERNEL(full,
......@@ -115,3 +127,6 @@ PD_REGISTER_KERNEL(full_like,
phi::dtype::complex<double>) {
kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
}
PD_REGISTER_KERNEL(
full_int_array, CPU, ALL_LAYOUT, phi::FullIntArrayKernel, int, int64_t) {}
......@@ -83,4 +83,10 @@ DenseTensor FullLike(const Context& dev_ctx,
return dense_out;
}
template <typename T, typename Context>
void FullIntArrayKernel(const Context& dev_ctx,
const IntArray& val,
DataType dtype,
DenseTensor* out);
} // namespace phi
......@@ -44,76 +44,61 @@
#include "paddle/phi/core/kernel_registry.h"
#include "test/cpp/ir/core/phi_kernel_adaptor.h"
PD_DECLARE_KERNEL(full, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(full_int_array, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(uniform, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(add, CPU, ALL_LAYOUT);
bool simple_cmp(float a, float b) { return std::abs((a - b) / a) < 1e-5; }
TEST(program_test, program) {
// Prepare ir env
ir::IrContext* ctx = ir::IrContext::Instance();
ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();
ir::Program program(ctx);
ir::Builder builder = ir::Builder::AtBlockEnd(ctx, program.block());
ir::Block* block = program.block();
ir::Type fp32_dtype = ir::Float32Type::get(ctx);
paddle::dialect::DenseTensorTypeStorage::Dim dims = {2, 2};
paddle::dialect::DenseTensorTypeStorage::DataLayout data_layout =
paddle::dialect::DenseTensorTypeStorage::DataLayout::NCHW;
paddle::dialect::DenseTensorTypeStorage::LoD lod = {};
size_t offset = 0;
ir::Type dense_tensor_dtype = paddle::dialect::DenseTensorType::get(
ctx, fp32_dtype, dims, data_layout, lod, offset);
// (1) Def a = GetParameterOp("a")
std::string op1_name = std::string(paddle::dialect::UniformOp::name());
ir::OpInfo op1_info = ctx->GetRegisteredOpInfo(op1_name);
// ir::Attribute shape_1 = ir::ArrayAttribute::get(ctx, {ten} );
ir::Attribute shape_1 = paddle::dialect::IntArrayAttribute::get(
ctx, std::vector<int64_t>({2, 2}));
ir::Attribute data_type =
paddle::dialect::DataTypeAttribute::get(ctx, phi::DataType::FLOAT32);
ir::Attribute min = ir::FloatAttribute::get(ctx, 0.0);
ir::Attribute max = ir::FloatAttribute::get(ctx, 1.0);
ir::Attribute seed = ir::Int32_tAttribute::get(ctx, 2);
ir::Attribute uni_place = paddle::dialect::PlaceAttribute::get(
ctx, phi::Place(phi::AllocationType::CPU));
std::unordered_map<std::string, ir::Attribute> op1_attribute{
{"shape", shape_1},
{"dtype", data_type},
{"min", min},
{"max", max},
{"seed", seed},
{"place", uni_place}};
ir::Operation* op1 =
ir::Operation::Create({}, op1_attribute, {dense_tensor_dtype}, op1_info);
block->push_back(op1);
// (2) Def b = GetParameterOp("b")
std::string op2_name = std::string(paddle::dialect::UniformOp::name());
ir::OpInfo op2_info = ctx->GetRegisteredOpInfo(op2_name);
ir::Attribute ten2 = ir::Int32_tAttribute::get(ctx, 3);
std::unordered_map<std::string, ir::Attribute> op2_attribute{{"shape", ten2}};
ir::Operation* op2 =
ir::Operation::Create({}, op1_attribute, {dense_tensor_dtype}, op2_info);
block->push_back(op2);
// (3) Def out = AddOp(a, b)
std::string add_op_name = std::string(paddle::dialect::AddOp::name());
ir::OpInfo add_op_info = ctx->GetRegisteredOpInfo(add_op_name);
ir::Operation* add_op = ir::Operation::Create(
{op1->GetResultByIndex(0), op2->GetResultByIndex(0)},
{},
{dense_tensor_dtype},
add_op_info);
block->push_back(add_op);
// Def: A = paddle::dialect::UniformOp(std::vector<int64_t> shape,
// phi::DataType dtype, float min, float max, int seed, phi::Place place)
paddle::dialect::UniformOp uniform1 =
builder.Build<paddle::dialect::UniformOp>(std::vector<int64_t>{2, 2},
phi::DataType::FLOAT32,
0.0,
1.0,
2,
phi::CPUPlace());
EXPECT_EQ(uniform1->GetResultByIndex(0)
.type()
.isa<paddle::dialect::DenseTensorType>(),
true);
EXPECT_EQ(block->size(), 4u);
// Def: B = paddle::dialect::UniformOp(...)
paddle::dialect::UniformOp uniform2 =
builder.Build<paddle::dialect::UniformOp>(std::vector<int64_t>{2, 2},
phi::DataType::FLOAT32,
0.0,
1.0,
2,
phi::CPUPlace());
EXPECT_EQ(uniform2->GetResultByIndex(0)
.type()
.isa<paddle::dialect::DenseTensorType>(),
true);
EXPECT_EQ(block->size(), 8u);
// Def: C = paddle::dialect::AddOp(ir::OpResult x_, ir::OpResult y_)
paddle::dialect::AddOp add = builder.Build<paddle::dialect::AddOp>(
uniform1->GetResultByIndex(0), uniform2->GetResultByIndex(0));
EXPECT_EQ(
add->GetResultByIndex(0).type().isa<paddle::dialect::DenseTensorType>(),
true);
EXPECT_EQ(block->size(), 9u);
// Execute program
paddle::framework::Scope scope;
PhiKernelAdaptor phi_kernel_adaptor(&scope);
phi_kernel_adaptor.run(&program);
auto out_tensor =
......
......@@ -20,6 +20,7 @@
#include "paddle/ir/core/builtin_op.h"
#include "paddle/ir/core/builtin_type.h"
#include "paddle/ir/core/dialect.h"
#include "paddle/ir/core/enforce.h"
#include "paddle/ir/core/ir_context.h"
#include "paddle/ir/core/op_base.h"
#include "paddle/ir/core/program.h"
......@@ -240,10 +241,12 @@ TEST(op_test, module_op_death) {
ir::AttributeMap attrs{{"program", ir::Int32_tAttribute::get(ctx, 1)}};
std::vector<ir::Type> output_types = {ir::Float32Type::get(ctx)};
EXPECT_THROW(ir::Operation::Create(inputs, {}, {}, op_info), const char *);
EXPECT_THROW(ir::Operation::Create({}, attrs, {}, op_info), const char *);
EXPECT_THROW(ir::Operation::Create(inputs, {}, {}, op_info),
ir::IrNotMetException);
EXPECT_THROW(ir::Operation::Create({}, attrs, {}, op_info),
ir::IrNotMetException);
EXPECT_THROW(ir::Operation::Create({}, {}, output_types, op_info),
const char *);
ir::IrNotMetException);
ir::Program program(ctx);
......
......@@ -98,27 +98,30 @@ void build_context(ir::Operation* op,
op->dyn_cast<paddle::dialect::OpYamlInfoInterface>();
auto op_info_res = op_info_interface.GetOpInfo();
// inputs include input and mutable attributes
auto input_info = std::get<0>(op_info_res);
std::set<std::string> input_set;
std::map<std::string, size_t> input_index_map;
std::map<std::string, std::string> mutable_attr_type_map;
int input_index = 0;
for (auto& t : input_info) {
VLOG(6) << t.name << "\t" << t.type_name;
input_set.insert(t.name);
input_index_map[t.name] = input_index++;
if (t.is_mutable_attribute) {
mutable_attr_type_map[t.name] = t.type_name;
}
}
auto attr_map = op->attributes();
std::map<std::string, std::string> attr_type_map;
auto attr_info = std::get<1>(op_info_res);
std::map<std::string, std::string> attr_type_map;
for (auto& t : attr_info) {
VLOG(6) << t.name << "\t" << t.type_name;
attr_type_map[t.name] = t.type_name;
}
auto attr_map = op->attributes();
auto runtime_info = std::get<3>(op_info_res);
int input_index = 0;
// int input_index = 0;
std::vector<std::string> vec_param_list;
if (is_infer_meta) {
vec_param_list = runtime_info.infer_meta_param;
......@@ -126,14 +129,32 @@ void build_context(ir::Operation* op,
vec_param_list = runtime_info.kernel_param;
}
for (auto& t : vec_param_list) {
if (input_set.count(t)) {
if (input_index_map.count(t)) {
// get information from input
ir::Value ptr = op->GetOperandByIndex(input_index++).source();
ir::Value ptr = op->GetOperandByIndex(input_index_map[t]).source();
auto in_var_name = name_map.at(ptr);
if (mutable_attr_type_map.count(t)) {
VLOG(6) << "ctx->EmplaceBack mutable attr: " << t << "\t"
<< in_var_name;
if (mutable_attr_type_map[t] == "paddle::dialect::IntArrayAttribute") {
ctx->EmplaceBackAttr(phi::IntArray(
*(scope->Var(in_var_name)->GetMutable<phi::DenseTensor>())));
} else if (mutable_attr_type_map[t] ==
"paddle::dialect::ScalarAttribute") {
ctx->EmplaceBackAttr(phi::Scalar(
*(scope->Var(in_var_name)->GetMutable<phi::DenseTensor>())));
} else {
PADDLE_THROW(phi::errors::Unimplemented("attr type not support [%s] ",
mutable_attr_type_map[t]));
}
} else {
VLOG(6) << "ctx->EmplaceBackInput: " << t << "\t" << in_var_name;
ctx->EmplaceBackInput(
scope->Var(in_var_name)->GetMutable<phi::DenseTensor>());
}
}
if (attr_type_map.count(t)) {
auto type_name = attr_type_map[t];
......@@ -149,10 +170,14 @@ void build_context(ir::Operation* op,
} else if (type_name == "paddle::dialect::PlaceAttribute") {
ctx->EmplaceBackAttr(
attr_map[t].dyn_cast<paddle::dialect::PlaceAttribute>().data());
} else if (type_name == "paddle::dialect::ScalarAttribute") {
ctx->EmplaceBackAttr(
attr_map[t].dyn_cast<paddle::dialect::ScalarAttribute>().data());
} else {
PADDLE_THROW(phi::errors::Unimplemented("attr type not support [%s] ",
type_name));
}
VLOG(6) << "ctx->EmplaceBackAttr: " << t;
}
}
......@@ -197,6 +222,9 @@ class PhiKernelAdaptor {
phi::KernelKey kernel_key(phi::TransToPhiBackend(cpu_place),
phi::DataLayout::ANY,
phi::DataType::FLOAT32);
if (runtime_info.kernel_func[0] == "full_int_array") {
kernel_key.set_dtype(phi::DataType::INT64);
}
auto found_it = phi_kernels.find(kernel_key);
if (found_it == phi_kernels.end()) {
std::cerr << "kernel name " << runtime_info.kernel_func[0] << std::endl;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册