Unverified commit 3a452e4e, authored by zhangbo9674, committed by GitHub

[IR] Refine IR builder and throw methods (#54396)

* refine code

* refine code

* refine code

* refine code

* refine code

* refine code

* refine code

* fix bug

* refine code

* refine code

* refine code

* refine code

* refine code

* delete unused code

* delete unused code

* refine code
Parent: b62b384b
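In short: every op's static Build now receives the ir::Builder as its first parameter (so Build can itself insert auxiliary full/full_int_array ops for mutable attributes), and bare throw(...) calls are replaced by the IR_ENFORCE/IR_THROW macros. A before/after sketch of the generated signature (argument lists abbreviated; the real templates are in the diff below):

    // before: Build only received the argument being assembled
    static void Build(ir::OperationArgument &argument, ir::OpResult x_, ...);
    // after: Build also receives the builder, so it can create helper ops
    static void Build(ir::Builder &builder, ir::OperationArgument &argument,
                      ir::OpResult x_, ...);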
@@ -57,6 +57,7 @@ class {op_name} : public ir::Op<{op_name}{interfaces}{traits}> {{
   static constexpr uint32_t attributes_num = {attribute_num};
   static OpInfoTuple GetOpInfo();
   static void Build({build_args});
+  {build_mutable_attr_is_input}
   static void Verify(const std::vector<ir::OpResult> &inputs, const std::vector<ir::Type> &outputs, const ir::AttributeMap &attributes);
   {get_inputs_and_outputs}
   {exclusive_interface}
@@ -94,6 +95,7 @@ CC_FILE_TEMPLATE = """// This file is generated by "paddle/fluid/ir/dialect/op_g
 #include "paddle/phi/infermeta/unary.h"
 #include "paddle/phi/infermeta/ternary.h"
 #include "paddle/phi/infermeta/backward.h"
+#include "paddle/phi/api/lib/utils/allocator.h"

 {input}
 """
@@ -112,9 +114,7 @@ OpInfoTuple {op_name}::GetOpInfo() {{
   return std::make_tuple(inputs, attributes, outputs, run_time_info);
 }}
 """
-CONSTRUCT_INPUT_INFO_TEMPLATE = (
-    """OpInputInfo("{name}", "{typename}", {optional}, {no_need_buffer})"""
-)
+CONSTRUCT_INPUT_INFO_TEMPLATE = """OpInputInfo("{name}", "{typename}", {optional}, {no_need_buffer}, {is_mutable_attribute})"""
 CONSTRUCT_OUTPUT_INFO_TEMPLATE = (
     """OpOutputInfo("{name}", "{typename}", {optional}, {intermediate})"""
 )
@@ -125,6 +125,7 @@ CONSTRUCT_ATTRIBUTE_INFO_TEMPLATE = (
 # build
 OP_BUILD_TEMPLATE = """
 void {op_name}::Build({build_args}) {{
+{build_mutable_attributes}
 {build_inputs}
 {build_attributes}
 {build_outputs}
@@ -303,6 +304,7 @@ class OpInfoParser:
             self.output_type_list,
             self.output_optional_list,
         )
+
         # parse attributes
         self.attr_types_map = {
             'IntArray': ['paddle::dialect::IntArrayAttribute', 'IntArray'],
@@ -313,7 +315,7 @@ class OpInfoParser:
             'Scalar(dobule)': ['ir::DoubleAttribute', 'dobule'],
             'Scalar[]': [
                 'ir::ArrayAttribute<paddle::dialect::ScalarAttribute>',
-                'std::vector<Scalar>',
+                'const std::vector<Scalar>&',
             ],
             'int': ['ir::Int32_tAttribute', 'int'],
             'int32_t': ['ir::Int32_tAttribute', 'int32_t'],
@@ -323,18 +325,18 @@ class OpInfoParser:
             'float': ['ir::FloatAttribute', 'float'],
             'float[]': [
                 'ir::ArrayAttribute<ir::FloatAttribute>',
-                'std::vector<float>',
+                'const std::vector<float>&',
             ],
             'double': ['ir::DoubleAttribute', 'double'],
             'bool': ['ir::BoolAttribute', 'bool'],
             'bool[]': [
                 'ir::ArrayAttribute<ir::BoolAttribute>',
-                'std::vector<bool>',
+                'const std::vector<bool>&',
             ],
             'str': ['ir::StrAttribute', 'std::string'],
             'str[]': [
                 'ir::ArrayAttribute<ir::StrAttribute>',
-                'std::vector<std::string>',
+                'const std::vector<std::string>&',
             ],
             'Place': ['paddle::dialect::PlaceAttribute', 'Place'],
             'DataLayout': [
@@ -344,11 +346,11 @@ class OpInfoParser:
             'DataType': ['paddle::dialect::DataTypeAttribute', 'DataType'],
             'int64_t[]': [
                 'ir::ArrayAttribute<ir::Int64_tAttribute>',
-                'std::vector<int64_t>',
+                'const std::vector<int64_t>&',
             ],
             'int[]': [
                 'ir::ArrayAttribute<ir::Int32_tAttribute>',
-                'std::vector<int>',
+                'const std::vector<int>&',
             ],
         }
         self.attribute_name_list = self.parse_attribute_name_list()
@@ -368,6 +370,14 @@ class OpInfoParser:
             self.mutable_attribute_type_list,
         ) = self.parse_mutable_attribute()

+        (
+            self.non_mutable_attribute_name_list,
+            self.non_mutable_attribute_type_list,
+            self.non_mutable_attribute_data_type_list,
+            self.non_mutable_attribute_build_arg_type_list,
+            self.non_mutable_attribute_default_value_list,
+        ) = self.parse_non_mutable_attribute()
+
         # parse infermeta && kernel
         self.infer_meta_map = self.parse_infer_meta_map()
         self.kernel_map = self.parse_kernel_map()
@@ -423,6 +433,12 @@ class OpInfoParser:
                     mutable_attribute_type_list.append(
                         ["ir::StrAttribute", "std::string"]
                     )
+                else:
+                    if (
+                        scalar_attr == "depth"
+                        and self.op_phi_name[0] == "one_hot"
+                    ):
+                        mutable_attribute_name_list.append("num_classes")
                     else:
                         mutable_attribute_name_list.append(scalar_attr)
                         mutable_attribute_type_list.append(
@@ -459,7 +475,55 @@ class OpInfoParser:
                     ],
                 ]
             )
-        return mutable_attribute_name_list, mutable_attribute_type_list
+        sorted_mutable_attribute_name_list = []
+        sorted_mutable_attribute_type_list = []
+        for attr_name in self.attribute_name_list:
+            if attr_name in mutable_attribute_name_list:
+                sorted_mutable_attribute_name_list.append(attr_name)
+                sorted_mutable_attribute_type_list.append(
+                    mutable_attribute_type_list[
+                        mutable_attribute_name_list.index(attr_name)
+                    ]
+                )
+        return (
+            sorted_mutable_attribute_name_list,
+            sorted_mutable_attribute_type_list,
+        )
+
+    def parse_non_mutable_attribute(self):
+        op_non_mutable_attribute_name_list = []
+        op_non_mutable_attribute_type_list = []
+        op_non_mutable_attribute_data_type_list = []
+        op_non_mutable_attribute_build_arg_type_list = []
+        op_non_mutable_attribute_default_value_list = []
+        for idx in range(len(self.attribute_name_list)):
+            if (
+                self.attribute_name_list[idx]
+                not in self.mutable_attribute_name_list
+            ):
+                op_non_mutable_attribute_name_list.append(
+                    self.attribute_name_list[idx]
+                )
+                op_non_mutable_attribute_type_list.append(
+                    self.attribute_type_list[idx]
+                )
+                op_non_mutable_attribute_data_type_list.append(
+                    self.attribute_data_type_list[idx]
+                )
+                op_non_mutable_attribute_build_arg_type_list.append(
+                    self.attribute_build_arg_type_list[idx]
+                )
+                op_non_mutable_attribute_default_value_list.append(
+                    self.attribute_default_value_list[idx]
+                )
+        return (
+            op_non_mutable_attribute_name_list,
+            op_non_mutable_attribute_type_list,
+            op_non_mutable_attribute_data_type_list,
+            op_non_mutable_attribute_build_arg_type_list,
+            op_non_mutable_attribute_default_value_list,
+        )
+
     def parse_input_name_list(self):
         name_list = []
@@ -649,20 +713,47 @@ def to_pascal_case(s):
 # =====================================
 def GenBuildInputArgsStr(
     op_input_name_list,
+    op_attribute_name_list,
+    op_attribute_build_arg_type_list,
+    op_attribute_default_value_list,
     op_mutable_attribute_name_list,
     op_non_mutable_attribute_name_list,
     op_non_mutable_attribute_build_arg_type_list,
     op_non_mutable_attribute_default_value_list,
     for_func_define=True,
+    mutable_attr_is_input=False,
 ):
     '''
-    Example: ir::OperationArgument &argument, ir::OpResult x_, phi::DataType dtype=phi::DataType::UNDEFINED, phi::Place place={}
+    Example: ir::Builder &builder, ir::OperationArgument &argument, ir::OpResult x_, phi::DataType dtype=phi::DataType::UNDEFINED, phi::Place place={}
     '''
     # add inputs
-    build_args_str = "ir::OperationArgument &argument"
+    build_args_str = "ir::Builder &builder, ir::OperationArgument &argument"
     if len(op_input_name_list) > 0:
         for input_name in op_input_name_list:
            build_args_str += ", ir::OpResult " + input_name + "_"
+
+    if not mutable_attr_is_input:
+        # add attributes
+        for attr_idx in range(len(op_attribute_name_list)):
+            build_args_str += (
+                ", "
+                + op_attribute_build_arg_type_list[attr_idx]
+                + " "
+                + op_attribute_name_list[attr_idx]
+            )
+            if for_func_define:
+                if op_attribute_default_value_list[attr_idx] is not None:
+                    default_value = op_attribute_default_value_list[attr_idx]
+                    if (
+                        op_attribute_build_arg_type_list[attr_idx]
+                        != "std::string"
+                    ):
+                        if default_value[0] == "'" or default_value[0] == '"':
+                            default_value = default_value[1:]
+                        if default_value[-1] == "'" or default_value[-1] == '"':
+                            default_value = default_value[0:-1]
+                    build_args_str += "=" + default_value
+    else:
         # add mutable attributes as inputs
         if len(op_mutable_attribute_name_list) > 0:
             for mutable_attr in op_mutable_attribute_name_list:
@@ -693,9 +784,58 @@ def GenBuildInputArgsStr(
                     if default_value[-1] == "'" or default_value[-1] == '"':
                         default_value = default_value[0:-1]
                     build_args_str += "=" + default_value
     return build_args_str

+
+mutable_attribute_phi_type_maps = {
+    'int': 'phi::DataType::INT32',
+    'int64_t': 'phi::DataType::INT64',
+    'float': 'phi::DataType::FLOAT32',
+    'std::vector<int64_t>': 'phi::DataType::INT64',
+    'const std::vector<int64_t>&': 'phi::DataType::INT64',
+}
+
+
+def GenBuildInsertFullForMutableAttribute(
+    op_attribute_name_list,
+    op_attribute_build_arg_type_list,
+    op_mutable_attribute_name_list,
+    op_mutable_attribute_type_list,
+):
+    build_mutable_attribute = ""
+    BUILD_INTARRAY_ATTRIBUTE_TEMPLATE = """  // Generate int_array mutable attribute: {attr_name}
+  paddle::dialect::FullIntArrayOp full_{attr_name}_op = builder.Build<paddle::dialect::FullIntArrayOp>({attr_name}, {phi_dtype}, phi::CPUPlace());
+  ir::OpResult {attr_name}_ = full_{attr_name}_op->GetResultByIndex(0);
+"""
+    BUILD_SCALAR_ATTRIBUTE_TEMPLATE = """  // Generate scalar mutable attribute: {attr_name}
+  paddle::dialect::FullOp full_{attr_name}_op = builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{{1}}, {attr_name}, {phi_dtype}, phi::CPUPlace());
+  ir::OpResult {attr_name}_ = full_{attr_name}_op->GetResultByIndex(0);
+"""
+    for idx in range(len(op_mutable_attribute_name_list)):
+        attr_name = op_mutable_attribute_name_list[idx]
+        attr_type = op_mutable_attribute_type_list[idx][0]
+        if attr_name in op_attribute_name_list:
+            phi_dtype = mutable_attribute_phi_type_maps[
+                op_attribute_build_arg_type_list[
+                    op_attribute_name_list.index(attr_name)
+                ]
+            ]
+        else:
+            phi_dtype = mutable_attribute_phi_type_maps[
+                op_mutable_attribute_type_list[idx][1]
+            ]
+        if attr_type == "paddle::dialect::IntArrayAttribute":
+            build_mutable_attribute += BUILD_INTARRAY_ATTRIBUTE_TEMPLATE.format(
+                attr_name=attr_name, phi_dtype=phi_dtype
+            )
+        else:
+            build_mutable_attribute += BUILD_SCALAR_ATTRIBUTE_TEMPLATE.format(
+                attr_name=attr_name, phi_dtype=phi_dtype
+            )
+    return build_mutable_attribute
+
+
 def GenBuildInputs(op_input_name_list, op_mutable_attribute_name_list):
     BUILD_INPUT_TEMPLATE = """  std::vector<ir::OpResult> argument_inputs = {{{inputs_args}}};
   argument.AddOperands(argument_inputs.begin(), argument_inputs.end());
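For illustration, with a hypothetical IntArray mutable attribute named shape, BUILD_INTARRAY_ATTRIBUTE_TEMPLATE above expands into C++ along these lines:

    // Generate int_array mutable attribute: shape
    paddle::dialect::FullIntArrayOp full_shape_op =
        builder.Build<paddle::dialect::FullIntArrayOp>(shape, phi::DataType::INT64, phi::CPUPlace());
    ir::OpResult shape_ = full_shape_op->GetResultByIndex(0);

The freshly produced shape_ result then feeds the op being built as a regular operand.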
@@ -805,38 +945,39 @@ def GenBuildOutputs(
     op_output_type_list,
     op_output_size_list,
     op_infer_meta_map,
+    mutable_attr_is_input=False,
 ):
     build_output_str = '  VLOG(4) << "Builder construction outputs";\n'
-    CREATE_INPUT_METATENSOR_TEMPLATE = """  phi::DenseTensor dense_{name};
-  dense_{name}.set_meta(
-    phi::DenseTensorMeta(TransToPhiDataType({name}.dtype()),
-                         {name}.dims(),
-                         {name}.data_layout(),
-                         {name}.lod(),
-                         {name}.offset())
-    );
+    CREATE_INPUT_METATENSOR_TEMPLATE = """
+  VLOG(4) << "Builder construction dense_{name}";
+  phi::DenseTensor dense_{name}(std::make_unique<paddle::experimental::DefaultAllocator>(paddle::platform::CPUPlace()).get(),
+                                phi::DenseTensorMeta(TransToPhiDataType({name}.dtype()),
+                                                     {name}.dims(),
+                                                     {name}.data_layout(),
+                                                     {name}.lod(),
+                                                     {name}.offset()));
+  VLOG(4) << "Builder construction meta_{name}";
   phi::MetaTensor meta_{name}(&dense_{name});
 """
-    CREATE_INPUT_VEC_METATENSOR_TEMPLATE = """  std::vector<phi::DenseTensor> vec_dense_{name}({name}.size(), phi::DenseTensor());
-  for (size_t i=0; i < static_cast<size_t>({name}.size()); i++) {{
-    vec_dense_{name}[i].set_meta(
-      phi::DenseTensorMeta(TransToPhiDataType({name}[i].dyn_cast<paddle::dialect::DenseTensorType>().dtype()),
-                           {name}[i].dyn_cast<paddle::dialect::DenseTensorType>().dims(),
-                           {name}[i].dyn_cast<paddle::dialect::DenseTensorType>().data_layout(),
-                           {name}[i].dyn_cast<paddle::dialect::DenseTensorType>().lod(),
-                           {name}[i].dyn_cast<paddle::dialect::DenseTensorType>().offset())
-      );
+    CREATE_INPUT_VEC_METATENSOR_TEMPLATE = """  std::vector<phi::DenseTensor> vec_dense_{name};
+  for (size_t i=0; i < static_cast<size_t>({name}.size()); i++) {{
+    vec_dense_{name}.push_back(phi::DenseTensor(std::make_unique<paddle::experimental::DefaultAllocator>(paddle::platform::CPUPlace()).get(),
+                                                phi::DenseTensorMeta(TransToPhiDataType({name}[i].dyn_cast<paddle::dialect::DenseTensorType>().dtype()),
+                                                                     {name}[i].dyn_cast<paddle::dialect::DenseTensorType>().dims(),
+                                                                     {name}[i].dyn_cast<paddle::dialect::DenseTensorType>().data_layout(),
+                                                                     {name}[i].dyn_cast<paddle::dialect::DenseTensorType>().lod(),
+                                                                     {name}[i].dyn_cast<paddle::dialect::DenseTensorType>().offset())));
   }}
   std::vector<phi::MetaTensor> vec_meta_{name};
   for (size_t i=0; i < vec_dense_{name}.size(); i++) {{
     vec_meta_{name}.push_back(phi::MetaTensor(&vec_dense_{name}[i]));
   }}
   std::vector<const phi::MetaTensor*> meta_{name};
   for (size_t i=0; i < static_cast<size_t>(vec_meta_{name}.size()); i++) {{
     meta_{name}.push_back(&vec_meta_{name}[i]);
   }}
 """
-    CREATE_INTARRAY_MUTABLE_ATTRIBUE_TEMPLATE = """  std::vector<int64_t> {name} = {name}_.owner()->dyn_cast<ir::ConstantOp>().value().dyn_cast<paddle::dialect::IntArrayAttribute>().data().GetData(); (void){name};\n"""
-    CREATE_SCALAR_MUTABLE_ATTRIBUE_TEMPLATE = """  {dtype} {name} = {name}_.owner()->dyn_cast<ir::ConstantOp>().value().dyn_cast<{ir_type}>().data(); (void){name};\n"""
-    CREATE_STRING_MUTABLE_ATTRIBUE_TEMPLATE = """  std::string {name} = {name}_.owner()->dyn_cast<ir::ConstantOp>().value().dyn_cast<ir::StrAttribute>().data(); (void){name};\n"""
     CREATE_OUTPUT_METATENSOR_TEMPLATE = """  phi::DenseTensor dense_{name};
   phi::MetaTensor meta_{name}(&dense_{name});
@@ -863,31 +1004,6 @@ def GenBuildOutputs(
         build_output_str += "  paddle::dialect::DenseTensorType {name} = {name}_.type().dyn_cast<paddle::dialect::DenseTensorType>(); (void){name};\n".format(
             name=op_input_name_list[idx]
         )
-    # Prepare mutable attributes
-    for idx in range(len(op_mutable_attribute_name_list)):
-        attr_dtype = op_mutable_attribute_type_list[idx]
-        # int_array
-        if attr_dtype[0] == "paddle::dialect::IntArrayAttribute":
-            build_output_str += (
-                CREATE_INTARRAY_MUTABLE_ATTRIBUE_TEMPLATE.format(
-                    name=op_mutable_attribute_name_list[idx]
-                )
-            )
-        # scalar
-        elif attr_dtype[0] == "paddle::dialect::ScalarAttribute":
-            build_output_str += CREATE_SCALAR_MUTABLE_ATTRIBUE_TEMPLATE.format(
-                name=op_mutable_attribute_name_list[idx],
-                dtype=attr_dtype[1],
-                ir_type=scalar_type_maps[attr_dtype[1]],
-            )
-        # string
-        elif attr_dtype[0] == "ir::StrAttribute":
-            build_output_str += CREATE_STRING_MUTABLE_ATTRIBUE_TEMPLATE.format(
-                name=op_mutable_attribute_name_list[idx]
-            )
-        else:
-            assert "mutable attribtue type is not right."
-        build_output_str += "\n"
     # Prepare inputs_meta_tensor & attributes for infer meta
     infer_meta_args = []
@@ -979,6 +1095,95 @@ def GenBuildOutputs(
     return build_output_str


+def GenBuild(
+    op_class_name,
+    op_input_name_list,
+    op_input_type_list,
+    op_attribute_name_list,
+    op_attribute_build_arg_type_list,
+    op_attribute_default_value_list,
+    op_mutable_attribute_name_list,
+    op_mutable_attribute_type_list,
+    op_non_mutable_attribute_name_list,
+    op_non_mutable_attribute_type_list,
+    op_non_mutable_attribute_build_arg_type_list,
+    op_non_mutable_attribute_default_value_list,
+    op_output_name_list,
+    op_output_type_list,
+    op_output_size_list,
+    op_infer_meta_map,
+    muta_attr_is_input=False,
+):
+    build_args_for_declare = ""
+    build_func = ""
+
+    build_args_for_declare = GenBuildInputArgsStr(
+        op_input_name_list,
+        op_attribute_name_list,
+        op_attribute_build_arg_type_list,
+        op_attribute_default_value_list,
+        op_mutable_attribute_name_list,
+        op_non_mutable_attribute_name_list,
+        op_non_mutable_attribute_build_arg_type_list,
+        op_non_mutable_attribute_default_value_list,
+        True,
+        muta_attr_is_input,
+    )
+
+    build_args_for_define = GenBuildInputArgsStr(
+        op_input_name_list,
+        op_attribute_name_list,
+        op_attribute_build_arg_type_list,
+        op_attribute_default_value_list,
+        op_mutable_attribute_name_list,
+        op_non_mutable_attribute_name_list,
+        op_non_mutable_attribute_build_arg_type_list,
+        op_non_mutable_attribute_default_value_list,
+        False,
+        muta_attr_is_input,
+    )
+    insert_full_for_mutable_attributes_str = ""
+    if not muta_attr_is_input:
+        insert_full_for_mutable_attributes_str = (
+            GenBuildInsertFullForMutableAttribute(
+                op_attribute_name_list,
+                op_attribute_build_arg_type_list,
+                op_mutable_attribute_name_list,
+                op_mutable_attribute_type_list,
+            )
+        )
+
+    build_inputs_str = GenBuildInputs(
+        op_input_name_list, op_mutable_attribute_name_list
+    )
+    build_attributes_str = GenBuildAttributes(
+        op_non_mutable_attribute_name_list,
+        op_non_mutable_attribute_type_list,
+    )
+    build_outputs_str = GenBuildOutputs(
+        op_input_name_list,
+        op_input_type_list,
+        op_mutable_attribute_name_list,
+        op_mutable_attribute_type_list,
+        op_output_name_list,
+        op_output_type_list,
+        op_output_size_list,
+        op_infer_meta_map,
+        False,
+    )
+
+    build_func = OP_BUILD_TEMPLATE.format(
+        op_name=op_class_name,
+        build_args=build_args_for_define,
+        build_mutable_attributes=insert_full_for_mutable_attributes_str,
+        build_inputs=build_inputs_str,
+        build_attributes=build_attributes_str,
+        build_outputs=build_outputs_str,
+    )
+    return (build_args_for_declare, build_func)
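GenBuild therefore returns the declaration argument list plus one Build definition; OpGenerator (below) may call it twice per op, so a generated op can end up with two overloads, roughly like this sketch (names abbreviated; a hypothetical op with input x and IntArray attribute shape):

    // attributes passed by value; Build inserts full/full_int_array ops for them
    static void Build(ir::Builder &builder, ir::OperationArgument &argument,
                      ir::OpResult x_, std::vector<int64_t> shape, ...);
    // mutable attributes passed directly as operands
    static void Build(ir::Builder &builder, ir::OperationArgument &argument,
                      ir::OpResult x_, ir::OpResult shape_, ...);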

 def OpGenerator(
     op_yaml_files,
     op_compat_yaml_file,
@@ -1032,31 +1237,22 @@ def OpGenerator(
         op_attribute_data_type_list = op_info.attribute_data_type_list
         op_attribute_build_arg_type_list = op_info.attribute_build_arg_type_list
         op_attribute_default_value_list = op_info.attribute_default_value_list
-        op_non_mutable_attribute_name_list = []
-        op_non_mutable_attribute_type_list = []
-        op_non_mutable_attribute_data_type_list = []
-        op_non_mutable_attribute_build_arg_type_list = []
-        op_non_mutable_attribute_default_value_list = []
-        for idx in range(len(op_attribute_name_list)):
-            if (
-                op_attribute_name_list[idx]
-                not in op_mutable_attribute_name_list
-            ):
-                op_non_mutable_attribute_name_list.append(
-                    op_attribute_name_list[idx]
-                )
-                op_non_mutable_attribute_type_list.append(
-                    op_attribute_type_list[idx]
-                )
-                op_non_mutable_attribute_data_type_list.append(
-                    op_attribute_data_type_list[idx]
-                )
-                op_non_mutable_attribute_build_arg_type_list.append(
-                    op_attribute_build_arg_type_list[idx]
-                )
-                op_non_mutable_attribute_default_value_list.append(
-                    op_attribute_default_value_list[idx]
-                )
+        op_non_mutable_attribute_name_list = (
+            op_info.non_mutable_attribute_name_list
+        )
+        op_non_mutable_attribute_type_list = (
+            op_info.non_mutable_attribute_type_list
+        )
+        op_non_mutable_attribute_data_type_list = (
+            op_info.non_mutable_attribute_data_type_list
+        )
+        op_non_mutable_attribute_build_arg_type_list = (
+            op_info.non_mutable_attribute_build_arg_type_list
+        )
+        op_non_mutable_attribute_default_value_list = (
+            op_info.non_mutable_attribute_default_value_list
+        )

         # others
         op_infer_meta_map = op_info.infer_meta_map
         op_kernel_map = op_info.kernel_map
@@ -1075,7 +1271,9 @@ def OpGenerator(
         op_class_name = to_pascal_case(op_name) + "Op"
         op_dialect_name = dialect_name + "." + op_name

-        # gen interface/trait str
+        # =================================== #
+        #    gen interface/trait list str     #
+        # =================================== #
         op_interfaces_str = ""
         if len(op_interfaces) > 0:
             op_interfaces_str = "," + ",".join(op_interfaces)
@@ -1083,6 +1281,9 @@ def OpGenerator(
         if len(op_traits) > 0:
             op_traits_str = "," + ",".join(op_traits)

+        # =================================== #
+        #  gen get input/output methods str   #
+        # =================================== #
         op_get_inputs_outputs_str = ""
         for idx in range(len(op_input_name_list)):
             op_get_inputs_outputs_str += OP_GET_INPUT_TEMPLATE.format(
@@ -1100,58 +1301,72 @@ def OpGenerator(
                 output_index=idx,
             )

-        # gen build str
-        build_define_input_args_str = ""
-        build_declare_input_args_str = ""
-        build_func_declare_str = ""
+        # =================================== #
+        #        gen Build methods str        #
+        # =================================== #
+        build_args_with_muta_attr_not_input_for_declare = ""
+        build_func_with_muta_attr_not_input = ""
+        build_mutable_attr_is_input = ""
+        build_func_with_muta_attr_is_input = ""
         if op_infer_meta_map is not None:
-            build_define_input_args_str = GenBuildInputArgsStr(
-                op_input_name_list,
-                op_mutable_attribute_name_list,
-                op_non_mutable_attribute_name_list,
-                op_non_mutable_attribute_build_arg_type_list,
-                op_non_mutable_attribute_default_value_list,
-                True,
-            )
-            build_declare_input_args_str = GenBuildInputArgsStr(
-                op_input_name_list,
-                op_mutable_attribute_name_list,
-                op_non_mutable_attribute_name_list,
-                op_non_mutable_attribute_build_arg_type_list,
-                op_non_mutable_attribute_default_value_list,
-                False,
-            )
-            build_inputs_str = GenBuildInputs(
-                op_input_name_list, op_mutable_attribute_name_list
-            )
-            build_attributes_str = GenBuildAttributes(
-                op_non_mutable_attribute_name_list,
-                op_non_mutable_attribute_type_list,
-            )
-            build_outputs_str = GenBuildOutputs(
-                op_input_name_list,
-                op_input_type_list,
-                op_mutable_attribute_name_list,
-                op_mutable_attribute_type_list,
-                op_output_name_list,
-                op_output_type_list,
-                op_output_size_list,
-                op_infer_meta_map,
-            )
-            build_func_declare_str = OP_BUILD_TEMPLATE.format(
-                op_name=op_class_name,
-                build_args=build_declare_input_args_str,
-                build_inputs=build_inputs_str,
-                build_attributes=build_attributes_str,
-                build_outputs=build_outputs_str,
-            )
-        else:
-            build_func_declare_str = OP_BUILD_TEMPLATE.format(
-                op_name=op_class_name,
-                build_args=build_declare_input_args_str,
-                build_inputs="",
-                build_attributes="",
-                build_outputs="",
-            )
+            (
+                build_args_with_muta_attr_not_input_for_declare,
+                build_func_with_muta_attr_not_input,
+            ) = GenBuild(
+                op_class_name,
+                op_input_name_list,
+                op_input_type_list,
+                op_attribute_name_list,
+                op_attribute_build_arg_type_list,
+                op_attribute_default_value_list,
+                op_mutable_attribute_name_list,
+                op_mutable_attribute_type_list,
+                op_non_mutable_attribute_name_list,
+                op_non_mutable_attribute_type_list,
+                op_non_mutable_attribute_build_arg_type_list,
+                op_non_mutable_attribute_default_value_list,
+                op_output_name_list,
+                op_output_type_list,
+                op_output_size_list,
+                op_infer_meta_map,
+                muta_attr_is_input=False,
+            )
+
+            op_infer_meta_args = op_infer_meta_map['param']
+            if (len(op_mutable_attribute_name_list) > 0) and (
+                len(
+                    list(
+                        set(op_infer_meta_args)
+                        & set(op_mutable_attribute_name_list)
+                    )
+                )
+                == 0
+            ):
+                (
+                    build_args_with_muta_attr_is_input_for_declare,
+                    build_func_with_muta_attr_is_input,
+                ) = GenBuild(
+                    op_class_name,
+                    op_input_name_list,
+                    op_input_type_list,
+                    op_attribute_name_list,
+                    op_attribute_build_arg_type_list,
+                    op_attribute_default_value_list,
+                    op_mutable_attribute_name_list,
+                    op_mutable_attribute_type_list,
+                    op_non_mutable_attribute_name_list,
+                    op_non_mutable_attribute_type_list,
+                    op_non_mutable_attribute_build_arg_type_list,
+                    op_non_mutable_attribute_default_value_list,
+                    op_output_name_list,
+                    op_output_type_list,
+                    op_output_size_list,
+                    op_infer_meta_map,
+                    muta_attr_is_input=True,
+                )
+
+                build_mutable_attr_is_input = "static void Build({build_args});".format(
+                    build_args=build_args_with_muta_attr_is_input_for_declare
+                )

         # gen op_declare_str/op_defined_str
@@ -1163,7 +1378,8 @@ def OpGenerator(
             traits=op_traits_str,
             attribute_declare=op_0_attribute_declare_str,
             attribute_num=0,
-            build_args=build_define_input_args_str,
+            build_args=build_args_with_muta_attr_not_input_for_declare,
+            build_mutable_attr_is_input=build_mutable_attr_is_input,
             get_inputs_and_outputs=op_get_inputs_outputs_str,
             exclusive_interface=exclusive_interface_str,
         )
@@ -1178,7 +1394,8 @@ def OpGenerator(
                 attribute_num=len(op_non_mutable_attribute_name_list)
             ),
             attribute_num=len(op_non_mutable_attribute_name_list),
-            build_args=build_define_input_args_str,
+            build_args=build_args_with_muta_attr_not_input_for_declare,
+            build_mutable_attr_is_input=build_mutable_attr_is_input,
             get_inputs_and_outputs=op_get_inputs_outputs_str,
             exclusive_interface=exclusive_interface_str,
         )
@@ -1191,10 +1408,11 @@ def OpGenerator(
             attribute_names=attribute_names_str,
         )

+        # =================================== #
+        #      gen GetOpInfo func str         #
+        # =================================== #
         # generate get op info function: inputs
-        inputs_info_str = ""
         input_info_list = []
-        if len(op_input_name_list) > 0:
         for idx in range(len(op_input_name_list)):
             input_info_list.append(
                 CONSTRUCT_INPUT_INFO_TEMPLATE.format(
@@ -1202,22 +1420,23 @@ def OpGenerator(
                     typename=op_input_type_list[idx],
                     optional=op_input_optional_list[idx],
                     no_need_buffer=op_input_no_need_buffer_list[idx],
+                    is_mutable_attribute='false',
                 )
             )
+        # add mutable attribute as input
+        if len(op_mutable_attribute_name_list) > 0:
             for idx in range(len(op_mutable_attribute_name_list)):
                 input_info_list.append(
                     CONSTRUCT_INPUT_INFO_TEMPLATE.format(
                         name=op_mutable_attribute_name_list[idx],
-                        typename=op_mutable_attribute_type_list[idx],
+                        typename=op_mutable_attribute_type_list[idx][0],
                         optional='false',
                         no_need_buffer='false',
+                        is_mutable_attribute='true',
                     )
                 )
+        if len(input_info_list) > 0:
             inputs_info_str = ", ".join(input_info_list)
+        else:
+            inputs_info_str = ""

         # generate get op info function: outputs
         outputs_info_str = ""
         if len(op_output_name_list) > 0:
@@ -1232,25 +1451,21 @@ def OpGenerator(
                 )
             )
             outputs_info_str = ", ".join(output_info_list)

         # generate get op info function: attributes
         attribute_info_str = ""
-        op_mutable_attribute_name_set = set(op_mutable_attribute_name_list)
-        if len(op_attribute_name_list) > 0:
+        if len(op_non_mutable_attribute_name_list) > 0:
             attribute_info_list = []
-            for idx in range(len(op_attribute_name_list)):
-                attribute_name = op_attribute_name_list[idx]
-                if attribute_name in op_mutable_attribute_name_set:
-                    continue
+            for idx in range(len(op_non_mutable_attribute_name_list)):
                 attribute_info_list.append(
                     CONSTRUCT_ATTRIBUTE_INFO_TEMPLATE.format(
-                        name=attribute_name,
-                        typename=op_attribute_type_list[idx],
-                        data_type=op_attribute_data_type_list[idx],
+                        name=op_non_mutable_attribute_name_list[idx],
+                        typename=op_non_mutable_attribute_type_list[idx],
+                        data_type=op_non_mutable_attribute_data_type_list[
+                            idx
+                        ],
                     )
                 )
             attribute_info_str = ", ".join(attribute_info_list)

         # generate runtime info
         infer_meta_func_str = ""
         infer_meta_param_str = ""
@@ -1274,6 +1489,9 @@ def OpGenerator(
             kernel_param=kernel_param_str,
         )

+        # =================================== #
+        #        gen Verify func str          #
+        # =================================== #
         # generate op verify function: inputs_type_check_str
         if (
             len(op_input_type_list) + len(op_mutable_attribute_name_list)
@@ -1327,7 +1545,6 @@ def OpGenerator(
                     standard="paddle::dialect::DenseTensorType",
                 )
                 inputs_type_check_str += check_str
-
         # generate op verify function: outputs_type_check_str
         if len(op_output_type_list) == 0:
             outputs_type_check_str = (
@@ -1364,7 +1581,6 @@ def OpGenerator(
                     index=idx, standard=output_type
                 )
                 outputs_type_check_str += check_str
-
         # generate op verify function: attributes_check_str
         if len(op_non_mutable_attribute_name_list) == 0:
             attributes_check_str = (
@@ -1387,7 +1603,6 @@ def OpGenerator(
                 attributes_check_str += ATTRIBUTE_CHECK_TEMPLATE.format(
                     attribute_name=attribute_name, standard=attribute_type
                 )
-
         # generate op verify function
         if "GradOp" in op_class_name or "Grad_Op" in op_class_name:
             op_verify_str = GRAD_OP_VERIFY_TEMPLATE.format(
@@ -1415,7 +1630,9 @@ def OpGenerator(
         ops_declare_list.append(op_declare_str)
         ops_defined_list.append(op_defined_str)
         ops_defined_list.append(op_info_func_str)
-        ops_defined_list.append(build_func_declare_str)
+        ops_defined_list.append(build_func_with_muta_attr_not_input)
+        if len(op_mutable_attribute_name_list) > 0:
+            ops_defined_list.append(build_func_with_muta_attr_is_input)
         ops_defined_list.append(op_verify_str)
         ops_defined_list.append(op_infer_shape_str)
...
@@ -26,5 +26,23 @@ phi::DataLayout DataLayoutAttribute::data() const {
   return storage()->GetAsKey();
 }

+phi::Scalar ScalarAttribute::data() {
+  if (isa<ir::FloatAttribute>()) {
+    return phi::Scalar(dyn_cast<ir::FloatAttribute>().data());
+  } else if (isa<ir::DoubleAttribute>()) {
+    return phi::Scalar(dyn_cast<ir::DoubleAttribute>().data());
+  } else if (isa<ir::Int32_tAttribute>()) {
+    return phi::Scalar(dyn_cast<ir::Int32_tAttribute>().data());
+  } else if (isa<ir::Int64_tAttribute>()) {
+    return phi::Scalar(dyn_cast<ir::Int64_tAttribute>().data());
+  } else if (isa<ir::BoolAttribute>()) {
+    return phi::Scalar(dyn_cast<ir::BoolAttribute>().data());
+  } else {
+    PADDLE_THROW(phi::errors::Unimplemented(
+        "Unsupported ir attribute when casting it into "
+        "phi scalar."));
+  }
+}
+
 }  // namespace dialect
 }  // namespace paddle
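A usage sketch of the new ScalarAttribute::data() accessor. The Int32_tAttribute::get factory used here is assumed from the builtin-attribute pattern and is not part of this diff:

    // Hypothetical round-trip: wrap an int in an ir attribute, read it back as a phi::Scalar.
    ir::Attribute attr = ir::Int32_tAttribute::get(ir::IrContext::Instance(), 42);
    // ScalarAttribute::classof accepts bool/float/double/int32/int64 attributes,
    // so the dyn_cast succeeds and data() wraps the stored value into a phi::Scalar.
    phi::Scalar s = attr.dyn_cast<paddle::dialect::ScalarAttribute>().data();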
@@ -17,6 +17,8 @@
 #include "paddle/fluid/ir/dialect/pd_attribute_storage.h"
 #include "paddle/ir/core/attribute.h"
 #include "paddle/ir/core/builtin_attribute.h"
+#include "paddle/phi/common/scalar.h"
+#include "paddle/phi/core/enforce.h"

 namespace paddle {
 namespace dialect {
@@ -45,6 +47,8 @@ class ScalarAttribute : public ir::Attribute {
            (val.type_id() == ir::Int32_tAttribute::type_id()) ||
            (val.type_id() == ir::Int64_tAttribute::type_id());
   }
+
+  phi::Scalar data();
 };

 class DataTypeAttribute : public ir::Attribute {
...
@@ -101,14 +101,17 @@ struct OpInputInfo {
   std::string type_name;
   bool optional = false;
   bool no_need_buffer = false;
+  bool is_mutable_attribute = false;
   OpInputInfo(std::string name,
               std::string type_name,
               bool optional,
-              bool no_need_buffer)
+              bool no_need_buffer,
+              bool is_mutable_attribute)
       : name(name),
         type_name(type_name),
         optional(optional),
-        no_need_buffer(no_need_buffer) {}
+        no_need_buffer(no_need_buffer),
+        is_mutable_attribute(is_mutable_attribute) {}
 };

 struct OpOutputInfo {
...
@@ -56,7 +56,7 @@ class Builder {
   template <typename OpTy, typename... Args>
   OpTy Build(Args &&...args) {
     OperationArgument argument(context_->GetRegisteredOpInfo(OpTy::name()));
-    OpTy::Build(argument, std::forward<Args>(args)...);
+    OpTy::Build(*this, argument, std::forward<Args>(args)...);
     Operation *op = Build(std::move(argument));
     return op->dyn_cast<OpTy>();
   }
...
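With this change, Builder::Build<OpTy>(...) forwards itself into OpTy::Build before constructing the operation. A call-site sketch, with the FullOp argument order taken from the generator templates above:

    // The builder passes itself as the first Build argument, so FullOp::Build
    // could in turn insert further ops through the same builder if needed.
    paddle::dialect::FullOp full_op = builder.Build<paddle::dialect::FullOp>(
        std::vector<int64_t>{1}, 1.5, phi::DataType::FLOAT32, phi::CPUPlace());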
@@ -57,20 +57,15 @@ void ModuleOp::Verify(const std::vector<ir::OpResult> &inputs,
                       const ir::AttributeMap &attributes) {
   VLOG(4) << "Verifying inputs, outputs and attributes for: ModuleOp.";
   // Verify inputs type:
-  if (inputs.size() != 0) {
-    throw("The size of inputs must be equal to 0.");
-  }
+  IR_ENFORCE(inputs.size() == 0, "The size of inputs must be equal to 0.");

   // Verify if attributes contain attribute name in attributes_name:
   auto iter = attributes.find("program");
-  if (iter == attributes.end() || !iter->second.isa<PointerAttribute>()) {
-    throw("Type of attribute: program is not right.");
-  }
+  IR_ENFORCE(iter != attributes.end() && iter->second.isa<PointerAttribute>(),
+             "Type of attribute: program is not right.");

   // Verify outputs type:
-  if (outputs.size() != 0) {
-    throw("The size of outputs must be equal to 0.");
-  }
+  IR_ENFORCE(outputs.size() == 0, "The size of outputs must be equal to 0.");
 }

 const char *GetParameterOp::attributes_name[attributes_num] = {
@@ -81,17 +76,15 @@ void GetParameterOp::Verify(const std::vector<ir::OpResult> &inputs,
                             const ir::AttributeMap &attributes) {
   VLOG(4) << "Verifying inputs, outputs and attributes for: GetParameterOp.";
   // Verify inputs type:
-  if (inputs.size() != 0) {
-    throw("The size of inputs must be equal to 0.");
-  }
-  // Verify outputs type:
-  if (outputs.size() != 1) {
-    throw("The size of outputs must be equal to 1.");
-  }
+  IR_ENFORCE(inputs.size() == 0, "The size of inputs must be equal to 0.");
+
   // Verify if attributes contain attribute name in attributes_name:
-  if (!attributes.at("parameter_name").isa<StrAttribute>()) {
-    throw("Type of attribute: parameter_name is not right.");
-  }
+  auto iter = attributes.find("parameter_name");
+  IR_ENFORCE(iter != attributes.end() && iter->second.isa<StrAttribute>(),
+             "Type of attribute: parameter_name is not right.");
+
+  // Verify outputs type:
+  IR_ENFORCE(outputs.size() == 1, "The size of outputs must be equal to 1.");
 }

 const char *SetParameterOp::attributes_name[attributes_num] = {
@@ -102,54 +95,45 @@ void SetParameterOp::Verify(const std::vector<ir::OpResult> &inputs,
                             const ir::AttributeMap &attributes) {
   VLOG(4) << "Verifying inputs, outputs and attributes for: SetParameterOp.";
   // Verify inputs type:
-  if (inputs.size() != 1) {
-    throw("The size of inputs must be equal to 1.");
-  }
-  // Verify outputs type:
-  if (outputs.size() != 0) {
-    throw("The size of outputs must be equal to 0.");
-  }
+  IR_ENFORCE(inputs.size() == 1, "The size of inputs must be equal to 1.");
+
   // Verify if attributes contain attribute name in attributes_name:
-  if (!attributes.at("parameter_name").isa<StrAttribute>()) {
-    throw("Type of attribute: parameter_name is not right.");
-  }
+  auto iter = attributes.find("parameter_name");
+  IR_ENFORCE(iter != attributes.end() && iter->second.isa<StrAttribute>(),
+             "Type of attribute: parameter_name is not right.");
+
+  // Verify outputs type:
+  IR_ENFORCE(outputs.size() == 0, "The size of outputs must be equal to 0.");
 }

 void CombineOp::Verify(const std::vector<ir::OpResult> &inputs,
                        const std::vector<ir::Type> &outputs,
                        const ir::AttributeMap &attributes) {
   // outputs.size() == 1
-  PADDLE_ENFORCE_EQ(
-      outputs.size(),
-      1,
-      phi::errors::PreconditionNotMet(
-          "The size %d of outputs must be equal to 1.", outputs.size()));
+  IR_ENFORCE(outputs.size() == 1,
+             "The size %d of outputs must be equal to 1.",
+             outputs.size());

   // outputs[0].type == Vector<Type>
-  PADDLE_ENFORCE(outputs[0].isa<ir::VectorType>(),
-                 phi::errors::PreconditionNotMet(
-                     "The type %s of outputs[0] must be equal to VectorType.",
-                     outputs[0]));
+  IR_ENFORCE(outputs[0].isa<ir::VectorType>(),
+             "The type %s of outputs[0] must be equal to VectorType.",
+             outputs[0]);
   ir::VectorType output_type = outputs[0].dyn_cast<ir::VectorType>();

   // inputs.size() == outputs[0].size()
-  PADDLE_ENFORCE_EQ(
-      output_type.size(),
-      inputs.size(),
-      phi::errors::PreconditionNotMet(
-          "The size %d of outputs[0] must be equal to size %d of inputs.",
-          output_type.size(),
-          inputs.size()));
+  IR_ENFORCE(output_type.size() == inputs.size(),
+             "The size %d of outputs[0] must be equal to size %d of inputs.",
+             output_type.size(),
+             inputs.size());

   // forall i in inputs.size(): inputs[i].type == outputs[0][i].type
   for (size_t i = 0; i < inputs.size(); i++) {
-    PADDLE_ENFORCE_EQ(
-        output_type[i],
-        inputs[i].type(),
-        phi::errors::PreconditionNotMet("The type %s of outputs[0][%d] must be "
-                                        "equal to type %s of inputs[%d].",
-                                        output_type[i],
-                                        i,
-                                        inputs[i].type(),
-                                        i));
+    IR_ENFORCE(output_type[i] == inputs[i].type(),
+               "The type %s of outputs[0][%d] must be "
+               "equal to type %s of inputs[%d].",
+               output_type[i],
+               i,
+               inputs[i].type(),
+               i);
   }
 }
@@ -158,65 +142,50 @@ void SliceOp::Verify(const std::vector<ir::OpResult> &inputs,
                      const std::vector<ir::Type> &outputs,
                      const ir::AttributeMap &attributes) {
   // inputs.size() == 1
-  PADDLE_ENFORCE_EQ(
-      inputs.size(),
-      1,
-      phi::errors::PreconditionNotMet(
-          "The size %d of inputs must be equal to 1.", inputs.size()));
+  IR_ENFORCE(inputs.size() == 1,
+             "The size %d of inputs must be equal to 1.",
+             inputs.size());

   // inputs[0].type == Vector<Type>
-  PADDLE_ENFORCE(inputs[0].type().isa<ir::VectorType>(),
-                 phi::errors::PreconditionNotMet(
-                     "The type %s of inputs[0] must be equal to VectorType.",
-                     inputs[0].type()));
+  IR_ENFORCE(inputs[0].type().isa<ir::VectorType>(),
+             "The type %s of inputs[0] must be equal to VectorType.",
+             inputs[0].type());
   ir::VectorType input_type = inputs[0].type().dyn_cast<ir::VectorType>();

   // outputs.size() == 1
-  PADDLE_ENFORCE_EQ(
-      outputs.size(),
-      1,
-      phi::errors::PreconditionNotMet(
-          "The size %d of outputs must be equal to 1.", outputs.size()));
+  IR_ENFORCE(outputs.size() == 1,
+             "The size %d of outputs must be equal to 1.",
+             outputs.size());

   // attributes contains index: Int32
-  PADDLE_ENFORCE_NE(
-      attributes.count("index"),
-      0,
-      phi::errors::PreconditionNotMet("The attributes must contains index."));
+  IR_ENFORCE(attributes.count("index") != 0,
+             "The attributes must contains index.");
   const ir::Attribute &attr = attributes.at("index");
-  PADDLE_ENFORCE(
-      attr.isa<ir::Int32_tAttribute>(),
-      phi::errors::PreconditionNotMet("The attribute index must be INT32."));
+  IR_ENFORCE(attr.isa<ir::Int32_tAttribute>(),
+             "The attribute index must be INT32.");
   auto index = attr.dyn_cast<ir::Int32_tAttribute>().data();

   // index >= 0 and < inputs[0].size()
-  PADDLE_ENFORCE_GE(
-      index,
-      0,
-      phi::errors::PreconditionNotMet(
-          "The index %d must be greater or equal than 0.", index));
-  PADDLE_ENFORCE_LT(
-      index,
-      input_type.size(),
-      phi::errors::PreconditionNotMet(
-          "The index %d must be less or equal than size %d of inputs[0].",
-          index,
-          input_type.size()));
+  IR_ENFORCE(
+      index >= 0, "The index %d must be greater or equal than 0.", index);
+  IR_ENFORCE(static_cast<size_t>(index) < input_type.size(),
+             "The index %d must be less or equal than size %d of inputs[0].",
+             index,
+             input_type.size());

   // inputs[index].type == outputs[0].type
-  PADDLE_ENFORCE_EQ(
-      input_type[index],
-      outputs[0],
-      phi::errors::PreconditionNotMet(
-          "The type %s of inputs[%d] must be equal to type %s of outputs[0].",
-          input_type[index],
-          index,
-          outputs[0]));
+  IR_ENFORCE(
+      input_type[index] == outputs[0],
+      "The type %s of inputs[%d] must be equal to type %s of outputs[0].",
+      input_type[index],
+      index,
+      outputs[0]);
 }

 const char *ConstantOp::attributes_name[attributes_num] = {"value"};

-void ConstantOp::Build(OperationArgument &argument,
+void ConstantOp::Build(Builder &builder,
+                       OperationArgument &argument,
                        Attribute value,
                        Type output_type) {
   argument.AddAttribute("value", value);
...
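IR_ENFORCE(cond, fmt, ...) checks the condition and raises an IR exception carrying the printf-style message when it fails (semantics inferred from its uses here; the macro itself lives in paddle/ir/core/enforce.h and is not shown in this diff). The rewrite pattern throughout is:

    // before: a bare throw of a string literal, which is hard to catch and carries no context
    if (inputs.size() != 1) {
      throw("The size of inputs must be equal to 1.");
    }
    // after: the condition is stated positively and the message can carry arguments
    IR_ENFORCE(inputs.size() == 1,
               "The size %d of inputs must be equal to 1.",
               inputs.size());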
@@ -86,6 +86,7 @@ class CombineOp : public ir::Op<CombineOp> {
   static constexpr uint32_t attributes_num = 0;
   static constexpr const char **attributes_name = nullptr;
+
   static void Verify(const std::vector<ir::OpResult> &inputs,
                      const std::vector<ir::Type> &outputs,
                      const ir::AttributeMap &attributes);
@@ -125,7 +126,8 @@ class ConstantOp : public Op<ConstantOp, ConstantLikeTrait> {
   static constexpr uint32_t attributes_num = 1;
   static const char *attributes_name[attributes_num];

-  static void Build(OperationArgument &argument,  // NOLINT
+  static void Build(Builder &builder,             // NOLINT
+                    OperationArgument &argument,  // NOLINT
                     Attribute value,
                     Type output_type);
...
@@ -16,6 +16,7 @@
 #include "paddle/ir/core/block.h"
 #include "paddle/ir/core/dialect.h"
+#include "paddle/ir/core/enforce.h"
 #include "paddle/ir/core/op_info.h"
 #include "paddle/ir/core/operation.h"
 #include "paddle/ir/core/program.h"
@@ -85,7 +86,7 @@ Operation *Operation::Create(const std::vector<ir::OpResult> &inputs,
   base_ptr += sizeof(Operation);
   // 3.3. Construct OpOperands.
   if ((reinterpret_cast<uintptr_t>(base_ptr) & 0x7) != 0) {
-    throw("The address of OpOperandImpl must be divisible by 8.");
+    IR_THROW("The address of OpOperandImpl must be divisible by 8.");
   }
   for (size_t idx = 0; idx < num_operands; idx++) {
     new (base_ptr) detail::OpOperandImpl(inputs[idx].impl_, op);
@@ -147,7 +148,7 @@ void Operation::Destroy() {
   // 2.2. Deconstruct Operation.
   if (reinterpret_cast<uintptr_t>(base_ptr) !=
       reinterpret_cast<uintptr_t>(this)) {
-    throw("Operation address error");
+    IR_THROW("Operation address error");
   }
   reinterpret_cast<Operation *>(base_ptr)->~Operation();
   base_ptr += sizeof(Operation);
@@ -178,7 +179,7 @@ Operation::Operation(const AttributeMap &attributes,

 ir::OpResult Operation::GetResultByIndex(uint32_t index) const {
   if (index >= num_results_) {
-    throw("index exceeds OP output range.");
+    IR_THROW("index exceeds OP output range.");
   }
   uint32_t max_inline_idx = detail::OpResultImpl::GetMaxInlineResultIndex();
   const char *ptr =
@@ -199,7 +200,7 @@ ir::OpResult Operation::GetResultByIndex(uint32_t index) const {

 ir::OpOperand Operation::GetOperandByIndex(uint32_t index) const {
   if (index >= num_operands_) {
-    throw("index exceeds OP input range.");
+    IR_THROW("index exceeds OP input range.");
   }
   const char *ptr = reinterpret_cast<const char *>(this) + sizeof(Operation) +
                     (index) * sizeof(detail::OpOperandImpl);
...
@@ -17,6 +17,8 @@
 #include <memory>
 #include <unordered_map>

+#include "paddle/ir/core/enforce.h"
+
 namespace ir {
 // This is a structure for creating, caching, and looking up Storage of
 // parametric types.
@@ -76,7 +78,7 @@ StorageManager::StorageBase *StorageManager::GetParametricStorageImpl(
           << std::hash<ir::TypeId>()(type_id) << ", param_hash=" << hash_value
           << "].";
   if (parametric_instance_.find(type_id) == parametric_instance_.end()) {
-    throw("The input data pointer is null.");
+    IR_THROW("The input data pointer is null.");
   }
   ParametricStorageManager &parametric_storage = *parametric_instance_[type_id];
   return parametric_storage.GetOrCreate(hash_value, equal_func, constructor);
@@ -88,7 +90,7 @@ StorageManager::StorageBase *StorageManager::GetParameterlessStorageImpl(
   VLOG(4) << "Try to get a parameterless storage of: [TypeId_hash="
           << std::hash<ir::TypeId>()(type_id) << "].";
   if (parameterless_instance_.find(type_id) == parameterless_instance_.end())
-    throw("TypeId not found in IrContext.");
+    IR_THROW("TypeId not found in IrContext.");
   StorageBase *parameterless_instance = parameterless_instance_[type_id];
   return parameterless_instance;
 }
@@ -107,7 +109,7 @@ void StorageManager::RegisterParameterlessStorageImpl(
   VLOG(4) << "Register a parameterless storage of: [TypeId_hash="
           << std::hash<ir::TypeId>()(type_id) << "].";
   if (parameterless_instance_.find(type_id) != parameterless_instance_.end())
-    throw("storage class already registered");
+    IR_THROW("storage class already registered");
   parameterless_instance_.emplace(type_id, constructor());
 }
...
@@ -427,6 +427,18 @@
     data_type : dtype
     backend : place

+- op : full_int_array
+  args : (IntArray value, DataType dtype=DataType::FLOAT32, Place place=CPUPlace())
+  output: Tensor(out)
+  infer_meta :
+    func : CreateIntArrayInferMeta
+    param : [value, dtype]
+  kernel :
+    func : full_int_array
+    param : [value, dtype]
+    data_type : dtype
+    backend : place
+
 - op : full_like
   args : (Tensor x, Scalar value, DataType dtype = DataType::UNDEFINED, Place place = {})
   output: Tensor(out)
...
@@ -62,9 +62,9 @@
   beta2 :
     data_type : float
     tensor_name : Beta2Tensor
-  episilon :
+  epsilon :
     data_type : float
-    tensor_name : EpisilonTensor
+    tensor_name : EpsilonTensor
   manual_signature : [adam_]

 - op : adamax_
@@ -85,9 +85,9 @@
   beta2 :
     data_type : float
     tensor_name : Beta2Tensor
-  episilon :
+  epsilon :
     data_type : float
-    tensor_name : EpisilonTensor
+    tensor_name : EpsilonTensor

 - op : add (elementwise_add)
   backward : add_grad (elementwise_add_grad)
@@ -1970,7 +1970,7 @@
   outputs:
     out : Out
   int_array:
-    axis :
+    dims :
       data_type : int
   extra :
     attrs : [bool use_mkldnn = false]
...
@@ -41,6 +41,15 @@ void CreateInferMeta(const IntArray& shape, DataType dtype, MetaTensor* out) {
   CreateInferMetaBase(shape.GetData(), dtype, DataLayout::NCHW, out);
 }

+void CreateIntArrayInferMeta(const IntArray& data,
+                             DataType dtype,
+                             MetaTensor* out) {
+  CreateInferMetaBase({static_cast<int64_t>(data.GetData().size())},
+                      dtype,
+                      DataLayout::NCHW,
+                      out);
+}
+
 void CreateInferMetaBase(const std::vector<int64_t>& shape,
                          DataType dtype,
                          DataLayout layout,
...
...@@ -35,6 +35,10 @@ void AssignValueInferMeta(const std::vector<int>& shape, ...@@ -35,6 +35,10 @@ void AssignValueInferMeta(const std::vector<int>& shape,
DataType dtype, DataType dtype,
MetaTensor* out); MetaTensor* out);
void CreateIntArrayInferMeta(const IntArray& data,
DataType dtype,
MetaTensor* out);
void CreateInferMeta(const IntArray& shape, DataType dtype, MetaTensor* out); void CreateInferMeta(const IntArray& shape, DataType dtype, MetaTensor* out);
void CreateInferMetaBase(const std::vector<int64_t>& shape, void CreateInferMetaBase(const std::vector<int64_t>& shape,
......
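The new rule is deliberately simple: the output is a 1-D tensor whose single dimension equals the number of IntArray elements, with the requested dtype and NCHW layout. A standalone illustration (sketch, not part of the diff):

phi::DenseTensor holder;
phi::MetaTensor out(&holder);
phi::CreateIntArrayInferMeta(phi::IntArray({2, 4, 8}), phi::DataType::INT64, &out);
// out.dims() == {3}: one slot per element. Values are filled at run time by
// the matching full_int_array kernel, never by infer-meta.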
...@@ -80,6 +80,18 @@ void FullLikeKernel(const Context& dev_ctx, ...@@ -80,6 +80,18 @@ void FullLikeKernel(const Context& dev_ctx,
FullValue<T>(dev_ctx, out, value); FullValue<T>(dev_ctx, out, value);
} }
template <typename T, typename Context>
void FullIntArrayKernel(const Context& dev_ctx,
const IntArray& val,
DataType dtype UNUSED,
DenseTensor* out) {
out->Resize(phi::make_ddim({static_cast<int64_t>(val.GetData().size())}));
T* out_data = dev_ctx.template Alloc<T>(out);
for (size_t i = 0; i < val.GetData().size(); ++i) {
out_data[i] = static_cast<T>(val.GetData()[i]);
}
}
} // namespace phi } // namespace phi
PD_REGISTER_KERNEL(full, PD_REGISTER_KERNEL(full,
...@@ -115,3 +127,6 @@ PD_REGISTER_KERNEL(full_like, ...@@ -115,3 +127,6 @@ PD_REGISTER_KERNEL(full_like,
phi::dtype::complex<double>) { phi::dtype::complex<double>) {
kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
} }
PD_REGISTER_KERNEL(
full_int_array, CPU, ALL_LAYOUT, phi::FullIntArrayKernel, int, int64_t) {}
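The registration covers CPU with int and int64_t only, which is why the kernel adaptor below pins this op's kernel key to INT64. A hedged direct-call sketch (the DefaultAllocator wiring is an assumption; any allocator valid for CPUPlace would do):

phi::CPUContext dev_ctx;
paddle::experimental::DefaultAllocator alloc(phi::CPUPlace());
dev_ctx.SetAllocator(&alloc);  // Alloc<T> inside the kernel needs this

phi::DenseTensor out;
phi::FullIntArrayKernel<int64_t>(dev_ctx, phi::IntArray({2, 2}),
                                 phi::DataType::INT64, &out);
// out.dims() == {2}; out holds the values [2, 2].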
...@@ -83,4 +83,10 @@ DenseTensor FullLike(const Context& dev_ctx, ...@@ -83,4 +83,10 @@ DenseTensor FullLike(const Context& dev_ctx,
return dense_out; return dense_out;
} }
template <typename T, typename Context>
void FullIntArrayKernel(const Context& dev_ctx,
const IntArray& val,
DataType dtype,
DenseTensor* out);
} // namespace phi } // namespace phi
...@@ -44,76 +44,61 @@ ...@@ -44,76 +44,61 @@
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "test/cpp/ir/core/phi_kernel_adaptor.h" #include "test/cpp/ir/core/phi_kernel_adaptor.h"
PD_DECLARE_KERNEL(full, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(full_int_array, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(uniform, CPU, ALL_LAYOUT); PD_DECLARE_KERNEL(uniform, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(add, CPU, ALL_LAYOUT); PD_DECLARE_KERNEL(add, CPU, ALL_LAYOUT);
bool simple_cmp(float a, float b) { return std::abs((a - b) / a) < 1e-5; } bool simple_cmp(float a, float b) { return std::abs((a - b) / a) < 1e-5; }
TEST(program_test, program) { TEST(program_test, program) {
// Prepare ir env
ir::IrContext* ctx = ir::IrContext::Instance(); ir::IrContext* ctx = ir::IrContext::Instance();
ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>(); ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();
ir::Program program(ctx); ir::Program program(ctx);
ir::Builder builder = ir::Builder::AtBlockEnd(ctx, program.block());
ir::Block* block = program.block(); ir::Block* block = program.block();
ir::Type fp32_dtype = ir::Float32Type::get(ctx);
paddle::dialect::DenseTensorTypeStorage::Dim dims = {2, 2};
paddle::dialect::DenseTensorTypeStorage::DataLayout data_layout =
paddle::dialect::DenseTensorTypeStorage::DataLayout::NCHW;
paddle::dialect::DenseTensorTypeStorage::LoD lod = {};
size_t offset = 0;
ir::Type dense_tensor_dtype = paddle::dialect::DenseTensorType::get(
ctx, fp32_dtype, dims, data_layout, lod, offset);
// (1) Def a = GetParameterOp("a")
std::string op1_name = std::string(paddle::dialect::UniformOp::name());
ir::OpInfo op1_info = ctx->GetRegisteredOpInfo(op1_name);
// ir::Attribute shape_1 = ir::ArrayAttribute::get(ctx, {ten} );
ir::Attribute shape_1 = paddle::dialect::IntArrayAttribute::get(
ctx, std::vector<int64_t>({2, 2}));
ir::Attribute data_type =
paddle::dialect::DataTypeAttribute::get(ctx, phi::DataType::FLOAT32);
ir::Attribute min = ir::FloatAttribute::get(ctx, 0.0);
ir::Attribute max = ir::FloatAttribute::get(ctx, 1.0);
ir::Attribute seed = ir::Int32_tAttribute::get(ctx, 2);
ir::Attribute uni_place = paddle::dialect::PlaceAttribute::get(
ctx, phi::Place(phi::AllocationType::CPU));
std::unordered_map<std::string, ir::Attribute> op1_attribute{
{"shape", shape_1},
{"dtype", data_type},
{"min", min},
{"max", max},
{"seed", seed},
{"place", uni_place}};
ir::Operation* op1 =
ir::Operation::Create({}, op1_attribute, {dense_tensor_dtype}, op1_info);
block->push_back(op1);
// (2) Def b = GetParameterOp("b")
std::string op2_name = std::string(paddle::dialect::UniformOp::name());
ir::OpInfo op2_info = ctx->GetRegisteredOpInfo(op2_name);
ir::Attribute ten2 = ir::Int32_tAttribute::get(ctx, 3);
std::unordered_map<std::string, ir::Attribute> op2_attribute{{"shape", ten2}};
ir::Operation* op2 =
ir::Operation::Create({}, op1_attribute, {dense_tensor_dtype}, op2_info);
block->push_back(op2);
// (3) Def out = AddOp(a, b)
std::string add_op_name = std::string(paddle::dialect::AddOp::name());
ir::OpInfo add_op_info = ctx->GetRegisteredOpInfo(add_op_name);
ir::Operation* add_op = ir::Operation::Create(
{op1->GetResultByIndex(0), op2->GetResultByIndex(0)},
{},
{dense_tensor_dtype},
add_op_info);
block->push_back(add_op);
// Def: A = paddle::dialect::UniformOp(std::vector<int64_t> shape,
// phi::DataType dtype, float min, float max, int seed, phi::Place place)
paddle::dialect::UniformOp uniform1 =
builder.Build<paddle::dialect::UniformOp>(std::vector<int64_t>{2, 2},
phi::DataType::FLOAT32,
0.0,
1.0,
2,
phi::CPUPlace());
EXPECT_EQ(uniform1->GetResultByIndex(0)
.type()
.isa<paddle::dialect::DenseTensorType>(),
true);
EXPECT_EQ(block->size(), 4u);
// Def: B = paddle::dialect::UniformOp(...)
paddle::dialect::UniformOp uniform2 =
builder.Build<paddle::dialect::UniformOp>(std::vector<int64_t>{2, 2},
phi::DataType::FLOAT32,
0.0,
1.0,
2,
phi::CPUPlace());
EXPECT_EQ(uniform2->GetResultByIndex(0)
.type()
.isa<paddle::dialect::DenseTensorType>(),
true);
EXPECT_EQ(block->size(), 8u);
// Def: C = paddle::dialect::AddOp(ir::OpResult x_, ir::OpResult y_)
paddle::dialect::AddOp add = builder.Build<paddle::dialect::AddOp>(
uniform1->GetResultByIndex(0), uniform2->GetResultByIndex(0));
EXPECT_EQ(
add->GetResultByIndex(0).type().isa<paddle::dialect::DenseTensorType>(),
true);
EXPECT_EQ(block->size(), 9u);
// Execute program
paddle::framework::Scope scope; paddle::framework::Scope scope;
PhiKernelAdaptor phi_kernel_adaptor(&scope); PhiKernelAdaptor phi_kernel_adaptor(&scope);
phi_kernel_adaptor.run(&program); phi_kernel_adaptor.run(&program);
auto out_tensor = auto out_tensor =
......
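The size assertions read oddly until you recall that Build lowers mutable attributes into ops of their own. Plausibly (inferred from the mutable-attribute handling elsewhere in this PR; the exact lowering is not spelled out here), one Build<paddle::dialect::UniformOp> call inserts four ops:

// %0 = "pd.full_int_array" ()        <- shape, an IntArray mutable attribute
// %1 = "pd.full" ()                  <- min, a Scalar mutable attribute
// %2 = "pd.full" ()                  <- max, a Scalar mutable attribute
// %3 = "pd.uniform" (%0, %1, %2)
// Hence block->size() is 4 after the first uniform, 8 after the second, and
// 9 once AddOp joins them; it also explains why the test must PD_DECLARE
// both the full and full_int_array kernels for the adaptor to run.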
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
#include "paddle/ir/core/builtin_op.h" #include "paddle/ir/core/builtin_op.h"
#include "paddle/ir/core/builtin_type.h" #include "paddle/ir/core/builtin_type.h"
#include "paddle/ir/core/dialect.h" #include "paddle/ir/core/dialect.h"
#include "paddle/ir/core/enforce.h"
#include "paddle/ir/core/ir_context.h" #include "paddle/ir/core/ir_context.h"
#include "paddle/ir/core/op_base.h" #include "paddle/ir/core/op_base.h"
#include "paddle/ir/core/program.h" #include "paddle/ir/core/program.h"
...@@ -240,10 +241,12 @@ TEST(op_test, module_op_death) { ...@@ -240,10 +241,12 @@ TEST(op_test, module_op_death) {
ir::AttributeMap attrs{{"program", ir::Int32_tAttribute::get(ctx, 1)}}; ir::AttributeMap attrs{{"program", ir::Int32_tAttribute::get(ctx, 1)}};
std::vector<ir::Type> output_types = {ir::Float32Type::get(ctx)}; std::vector<ir::Type> output_types = {ir::Float32Type::get(ctx)};
EXPECT_THROW(ir::Operation::Create(inputs, {}, {}, op_info), const char *); EXPECT_THROW(ir::Operation::Create(inputs, {}, {}, op_info),
EXPECT_THROW(ir::Operation::Create({}, attrs, {}, op_info), const char *); ir::IrNotMetException);
EXPECT_THROW(ir::Operation::Create({}, attrs, {}, op_info),
ir::IrNotMetException);
EXPECT_THROW(ir::Operation::Create({}, {}, output_types, op_info), EXPECT_THROW(ir::Operation::Create({}, {}, output_types, op_info),
const char *); ir::IrNotMetException);
ir::Program program(ctx); ir::Program program(ctx);
......
...@@ -98,27 +98,30 @@ void build_context(ir::Operation* op, ...@@ -98,27 +98,30 @@ void build_context(ir::Operation* op,
op->dyn_cast<paddle::dialect::OpYamlInfoInterface>(); op->dyn_cast<paddle::dialect::OpYamlInfoInterface>();
auto op_info_res = op_info_interface.GetOpInfo(); auto op_info_res = op_info_interface.GetOpInfo();
// inputs include input and mutable attributes
auto input_info = std::get<0>(op_info_res); auto input_info = std::get<0>(op_info_res);
std::map<std::string, size_t> input_index_map;
std::set<std::string> input_set; std::map<std::string, std::string> mutable_attr_type_map;
int input_index = 0;
for (auto& t : input_info) { for (auto& t : input_info) {
VLOG(6) << t.name << "\t" << t.type_name; VLOG(6) << t.name << "\t" << t.type_name;
input_index_map[t.name] = input_index++;
input_set.insert(t.name); if (t.is_mutable_attribute) {
mutable_attr_type_map[t.name] = t.type_name;
}
} }
auto attr_map = op->attributes();
std::map<std::string, std::string> attr_type_map;
auto attr_info = std::get<1>(op_info_res); auto attr_info = std::get<1>(op_info_res);
std::map<std::string, std::string> attr_type_map;
for (auto& t : attr_info) { for (auto& t : attr_info) {
VLOG(6) << t.name << "\t" << t.type_name; VLOG(6) << t.name << "\t" << t.type_name;
attr_type_map[t.name] = t.type_name; attr_type_map[t.name] = t.type_name;
} }
auto attr_map = op->attributes();
auto runtime_info = std::get<3>(op_info_res); auto runtime_info = std::get<3>(op_info_res);
int input_index = 0; // int input_index = 0;
std::vector<std::string> vec_param_list; std::vector<std::string> vec_param_list;
if (is_infer_meta) { if (is_infer_meta) {
vec_param_list = runtime_info.infer_meta_param; vec_param_list = runtime_info.infer_meta_param;
...@@ -126,14 +129,32 @@ void build_context(ir::Operation* op, ...@@ -126,14 +129,32 @@ void build_context(ir::Operation* op,
vec_param_list = runtime_info.kernel_param; vec_param_list = runtime_info.kernel_param;
} }
for (auto& t : vec_param_list) { for (auto& t : vec_param_list) {
if (input_set.count(t)) { if (input_index_map.count(t)) {
// get information from input // get information from input
ir::Value ptr = op->GetOperandByIndex(input_index++).source(); ir::Value ptr = op->GetOperandByIndex(input_index_map[t]).source();
auto in_var_name = name_map.at(ptr); auto in_var_name = name_map.at(ptr);
if (mutable_attr_type_map.count(t)) {
VLOG(6) << "ctx->EmplaceBack mutable attr: " << t << "\t"
<< in_var_name;
if (mutable_attr_type_map[t] == "paddle::dialect::IntArrayAttribute") {
ctx->EmplaceBackAttr(phi::IntArray(
*(scope->Var(in_var_name)->GetMutable<phi::DenseTensor>())));
} else if (mutable_attr_type_map[t] ==
"paddle::dialect::ScalarAttribute") {
ctx->EmplaceBackAttr(phi::Scalar(
*(scope->Var(in_var_name)->GetMutable<phi::DenseTensor>())));
} else {
PADDLE_THROW(phi::errors::Unimplemented("attr type not support [%s] ",
mutable_attr_type_map[t]));
}
} else {
VLOG(6) << "ctx->EmplaceBackInput: " << t << "\t" << in_var_name;
ctx->EmplaceBackInput( ctx->EmplaceBackInput(
scope->Var(in_var_name)->GetMutable<phi::DenseTensor>()); scope->Var(in_var_name)->GetMutable<phi::DenseTensor>());
} }
}
if (attr_type_map.count(t)) { if (attr_type_map.count(t)) {
auto type_name = attr_type_map[t]; auto type_name = attr_type_map[t];
...@@ -149,10 +170,14 @@ void build_context(ir::Operation* op, ...@@ -149,10 +170,14 @@ void build_context(ir::Operation* op,
} else if (type_name == "paddle::dialect::PlaceAttribute") { } else if (type_name == "paddle::dialect::PlaceAttribute") {
ctx->EmplaceBackAttr( ctx->EmplaceBackAttr(
attr_map[t].dyn_cast<paddle::dialect::PlaceAttribute>().data()); attr_map[t].dyn_cast<paddle::dialect::PlaceAttribute>().data());
} else if (type_name == "paddle::dialect::ScalarAttribute") {
ctx->EmplaceBackAttr(
attr_map[t].dyn_cast<paddle::dialect::ScalarAttribute>().data());
} else { } else {
PADDLE_THROW(phi::errors::Unimplemented("attr type not support [%s] ", PADDLE_THROW(phi::errors::Unimplemented("attr type not support [%s] ",
type_name)); type_name));
} }
VLOG(6) << "ctx->EmplaceBackAttr: " << t;
} }
} }
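Two changes work in concert here. Operands are now resolved through input_index_map (name to operand index) rather than a running input_index, so the kernel_param order no longer has to match the operand order. And when a parameter turns out to be a mutable attribute, its value is read back out of the scope as a DenseTensor and re-wrapped as the phi attribute type the kernel signature expects. Condensed, the dispatch above is:

// For each parameter name t in vec_param_list:
//   operand?            look up input_index_map[t], map the SSA value to its
//                       scope variable, then either...
//   mutable attribute?  re-wrap its tensor as phi::IntArray(tensor) or
//                       phi::Scalar(tensor) and EmplaceBackAttr it
//   plain input?        EmplaceBackInput(the DenseTensor itself)
//   ordinary attribute? dyn_cast the stored ir::Attribute, EmplaceBackAttr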
...@@ -197,6 +222,9 @@ class PhiKernelAdaptor { ...@@ -197,6 +222,9 @@ class PhiKernelAdaptor {
phi::KernelKey kernel_key(phi::TransToPhiBackend(cpu_place), phi::KernelKey kernel_key(phi::TransToPhiBackend(cpu_place),
phi::DataLayout::ANY, phi::DataLayout::ANY,
phi::DataType::FLOAT32); phi::DataType::FLOAT32);
if (runtime_info.kernel_func[0] == "full_int_array") {
kernel_key.set_dtype(phi::DataType::INT64);
}
auto found_it = phi_kernels.find(kernel_key); auto found_it = phi_kernels.find(kernel_key);
if (found_it == phi_kernels.end()) { if (found_it == phi_kernels.end()) {
std::cerr << "kernel name " << runtime_info.kernel_func[0] << std::endl; std::cerr << "kernel name " << runtime_info.kernel_func[0] << std::endl;
......
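This special case exists because full_int_array was registered above for int and int64_t only; the adaptor's hard-coded FLOAT32 kernel key would never find it. Presumably a dtype-aware kernel-key selection would subsume this, but for a test-only adaptor the pinned dtype is the pragmatic fix:

// PD_REGISTER_KERNEL(full_int_array, CPU, ALL_LAYOUT, ..., int, int64_t)
// => no FLOAT32 variant exists, so the key must be INT64 to be found.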