未验证 提交 b63d1cba 编写于 作者: Z zhangbo9674 提交者: GitHub

[IR] Support mutable attribute for Op build (#54288)

* add constant op

* support mutable attribute

* refine code

* fix bug

* fix bug

* refine code

* fix bug

* refine code

* refine code

* add ut

* refine code

* fix test bug

* solve conflict

* refine code
上级 68d81d0e
...@@ -28,6 +28,7 @@ H_FILE_TEMPLATE = """#ifdef GET_OP_LIST ...@@ -28,6 +28,7 @@ H_FILE_TEMPLATE = """#ifdef GET_OP_LIST
#undef GET_OP_LIST #undef GET_OP_LIST
{op_declare} {op_declare}
#else #else
// This file is generated by "paddle/fluid/ir/dialect/op_gen.py"
#include <vector> #include <vector>
...@@ -35,8 +36,7 @@ H_FILE_TEMPLATE = """#ifdef GET_OP_LIST ...@@ -35,8 +36,7 @@ H_FILE_TEMPLATE = """#ifdef GET_OP_LIST
#include "paddle/ir/core/operation_utils.h" #include "paddle/ir/core/operation_utils.h"
#include "paddle/ir/core/op_base.h" #include "paddle/ir/core/op_base.h"
#include "paddle/fluid/ir/dialect/utils.h" #include "paddle/fluid/ir/dialect/utils.h"
#include "paddle/fluid/ir/dialect/pd_interface.h" #include "paddle/fluid/ir/interface/op_yaml_info.h"
#include "paddle/fluid/ir/interface/infershape.h" #include "paddle/fluid/ir/interface/infershape.h"
#include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/phi/core/infermeta_utils.h" #include "paddle/phi/core/infermeta_utils.h"
...@@ -56,8 +56,8 @@ class {op_name} : public ir::Op<{op_name}{interfaces}{traits}> {{ ...@@ -56,8 +56,8 @@ class {op_name} : public ir::Op<{op_name}{interfaces}{traits}> {{
{attribute_declare} {attribute_declare}
static constexpr uint32_t attributes_num = {attribute_num}; static constexpr uint32_t attributes_num = {attribute_num};
static OpInfoTuple GetOpInfo(); static OpInfoTuple GetOpInfo();
static void build({build_args}); static void Build({build_args});
static void verify(const std::vector<ir::OpResult> &inputs, const std::vector<ir::Type> &outputs, const ir::AttributeMap &attributes); static void Verify(const std::vector<ir::OpResult> &inputs, const std::vector<ir::Type> &outputs, const ir::AttributeMap &attributes);
{get_inputs_and_outputs} {get_inputs_and_outputs}
{exclusive_interface} {exclusive_interface}
}}; }};
...@@ -77,11 +77,14 @@ OP_GET_OUTPUT_TEMPLATE = """ ir::OpResult {output_name}() {{ return operation() ...@@ -77,11 +77,14 @@ OP_GET_OUTPUT_TEMPLATE = """ ir::OpResult {output_name}() {{ return operation()
# ===================================== # =====================================
# String Template for cc file code gen # String Template for cc file code gen
# ===================================== # =====================================
CC_FILE_TEMPLATE = """#include "{h_file}" CC_FILE_TEMPLATE = """// This file is generated by "paddle/fluid/ir/dialect/op_gen.py"
#include "{h_file}"
#include "paddle/fluid/ir/dialect/pd_type.h" #include "paddle/fluid/ir/dialect/pd_type.h"
#include "paddle/fluid/ir/dialect/pd_attribute.h" #include "paddle/fluid/ir/dialect/pd_attribute.h"
#include "paddle/ir/core/builtin_attribute.h" #include "paddle/ir/core/builtin_attribute.h"
#include "paddle/ir/core/builtin_type.h" #include "paddle/ir/core/builtin_type.h"
#include "paddle/ir/core/builtin_op.h"
#include "paddle/ir/core/ir_context.h" #include "paddle/ir/core/ir_context.h"
#include "paddle/phi/core/enforce.h" #include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/dense_tensor.h"
...@@ -92,15 +95,6 @@ CC_FILE_TEMPLATE = """#include "{h_file}" ...@@ -92,15 +95,6 @@ CC_FILE_TEMPLATE = """#include "{h_file}"
#include "paddle/phi/infermeta/ternary.h" #include "paddle/phi/infermeta/ternary.h"
#include "paddle/phi/infermeta/backward.h" #include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/unary.h"
#include "paddle/phi/infermeta/nullary.h"
#include "paddle/phi/infermeta/binary.h"
#include "paddle/phi/infermeta/ternary.h"
#include "paddle/phi/infermeta/multiary.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/core/infermeta_utils.h"
{input} {input}
""" """
...@@ -130,7 +124,7 @@ CONSTRUCT_ATTRIBUTE_INFO_TEMPLATE = ( ...@@ -130,7 +124,7 @@ CONSTRUCT_ATTRIBUTE_INFO_TEMPLATE = (
# build # build
OP_BUILD_TEMPLATE = """ OP_BUILD_TEMPLATE = """
void {op_name}::build({build_args}) {{ void {op_name}::Build({build_args}) {{
{build_inputs} {build_inputs}
{build_attributes} {build_attributes}
{build_outputs} {build_outputs}
...@@ -139,7 +133,7 @@ void {op_name}::build({build_args}) {{ ...@@ -139,7 +133,7 @@ void {op_name}::build({build_args}) {{
# verify # verify
OP_VERIFY_TEMPLATE = """ OP_VERIFY_TEMPLATE = """
void {op_name}::verify(const std::vector<ir::OpResult> &inputs, const std::vector<ir::Type> &outputs, const ir::AttributeMap &attributes) {{ void {op_name}::Verify(const std::vector<ir::OpResult> &inputs, const std::vector<ir::Type> &outputs, const ir::AttributeMap &attributes) {{
VLOG(4) << "Verifying inputs, outputs and attributes for: {op_name}."; VLOG(4) << "Verifying inputs, outputs and attributes for: {op_name}.";
// Verify inputs type: // Verify inputs type:
...@@ -156,7 +150,7 @@ void {op_name}::verify(const std::vector<ir::OpResult> &inputs, const std::vecto ...@@ -156,7 +150,7 @@ void {op_name}::verify(const std::vector<ir::OpResult> &inputs, const std::vecto
""" """
GRAD_OP_VERIFY_TEMPLATE = """ GRAD_OP_VERIFY_TEMPLATE = """
void {op_name}::verify(const std::vector<ir::OpResult> &inputs, const std::vector<ir::Type> &outputs, const ir::AttributeMap &attributes) {{ void {op_name}::Verify(const std::vector<ir::OpResult> &inputs, const std::vector<ir::Type> &outputs, const ir::AttributeMap &attributes) {{
(void)inputs; (void)inputs;
(void)outputs; (void)outputs;
(void)attributes; (void)attributes;
...@@ -288,6 +282,7 @@ class OpInfoParser: ...@@ -288,6 +282,7 @@ class OpInfoParser:
self.cross_check( self.cross_check(
self.input_name_list, self.input_type_list, self.input_optional_list self.input_name_list, self.input_type_list, self.input_optional_list
) )
# parse outputs # parse outputs
self.output_name_list = self.parse_output_name_list() self.output_name_list = self.parse_output_name_list()
self.output_type_list = self.parse_output_type_list() self.output_type_list = self.parse_output_type_list()
...@@ -358,6 +353,12 @@ class OpInfoParser: ...@@ -358,6 +353,12 @@ class OpInfoParser:
) )
self.cross_check(self.attribute_name_list, self.attribute_type_list) self.cross_check(self.attribute_name_list, self.attribute_type_list)
# parse mutable attributes (as inputs)
(
self.mutable_attribute_name_list,
self.mutable_attribute_type_list,
) = self.parse_mutable_attribute()
# parse infermeta && kernel # parse infermeta && kernel
self.infer_meta_map = self.parse_infer_meta_map() self.infer_meta_map = self.parse_infer_meta_map()
self.kernel_map = self.parse_kernel_map() self.kernel_map = self.parse_kernel_map()
...@@ -392,6 +393,65 @@ class OpInfoParser: ...@@ -392,6 +393,65 @@ class OpInfoParser:
return self.op_yaml_item['inplace'] return self.op_yaml_item['inplace']
return None return None
def parse_mutable_attribute(self):
"""
{'axis': 'paddle::dialect::ScalarAttribute', 'rotl': 'paddle::dialect::IntArrayAttribute'}
"""
mutable_attribute_name_list = []
mutable_attribute_type_list = []
# scalar
if (self.op_compat_item is not None) and (
'scalar' in self.op_compat_item
):
for scalar_attr in self.op_compat_item['scalar'].keys():
if 'data_type' in self.op_compat_item['scalar'][scalar_attr]:
if (
self.op_compat_item['scalar'][scalar_attr]['data_type']
== "std::string"
):
# see isclose and allclose in op_compat.yaml
mutable_attribute_name_list.append(scalar_attr)
mutable_attribute_type_list.append(
["ir::StrAttribute", "std::string"]
)
else:
mutable_attribute_name_list.append(scalar_attr)
mutable_attribute_type_list.append(
[
"paddle::dialect::ScalarAttribute",
self.op_compat_item['scalar'][scalar_attr][
'data_type'
],
]
)
# See eye in op_compat.yaml
else:
mutable_attribute_name_list.append(scalar_attr)
mutable_attribute_type_list.append(
[
"paddle::dialect::ScalarAttribute",
self.attribute_data_type_list[
self.attribute_name_list.index(scalar_attr)
],
]
)
# int_array
if (self.op_compat_item is not None) and (
'int_array' in self.op_compat_item
):
for int_array_attr in self.op_compat_item['int_array']:
mutable_attribute_name_list.append(int_array_attr)
mutable_attribute_type_list.append(
[
"paddle::dialect::IntArrayAttribute",
self.op_compat_item['int_array'][int_array_attr][
'data_type'
],
]
)
return mutable_attribute_name_list, mutable_attribute_type_list
def parse_input_name_list(self): def parse_input_name_list(self):
name_list = [] name_list = []
for input_info in self.op_yaml_item['inputs']: for input_info in self.op_yaml_item['inputs']:
...@@ -580,29 +640,45 @@ def to_pascal_case(s): ...@@ -580,29 +640,45 @@ def to_pascal_case(s):
# ===================================== # =====================================
def GenBuildInputArgsStr( def GenBuildInputArgsStr(
op_input_name_list, op_input_name_list,
op_attribute_name_list, op_mutable_attribute_name_list,
op_attribute_build_arg_type_list, op_non_mutable_attribute_name_list,
op_attribute_default_value_list, op_non_mutable_attribute_build_arg_type_list,
op_non_mutable_attribute_default_value_list,
for_func_define=True, for_func_define=True,
): ):
''' '''
Example: ir::Builder &builder, ir::OperationArgument &argument, ir::OpResult x_, phi::DataType dtype=phi::DataType::UNDEFINED, phi::Place place={} Example: ir::Builder &builder, ir::OperationArgument &argument, ir::OpResult x_, phi::DataType dtype=phi::DataType::UNDEFINED, phi::Place place={}
''' '''
build_args_str = "ir::Builder &builder, ir::OperationArgument &argument" build_args_str = "ir::Builder &builder, ir::OperationArgument &argument"
# add inputs
if len(op_input_name_list) > 0: if len(op_input_name_list) > 0:
for input_name in op_input_name_list: for input_name in op_input_name_list:
build_args_str += ", ir::OpResult " + input_name + "_" build_args_str += ", ir::OpResult " + input_name + "_"
for attr_idx in range(len(op_attribute_name_list)): # add mutable attributes as inputs
if len(op_mutable_attribute_name_list) > 0:
for mutable_attr in op_mutable_attribute_name_list:
build_args_str += ", ir::OpResult " + mutable_attr + "_"
# add non-mutable attributes
for attr_idx in range(len(op_non_mutable_attribute_name_list)):
build_args_str += ( build_args_str += (
", " ", "
+ op_attribute_build_arg_type_list[attr_idx] + op_non_mutable_attribute_build_arg_type_list[attr_idx]
+ " " + " "
+ op_attribute_name_list[attr_idx] + op_non_mutable_attribute_name_list[attr_idx]
) )
if for_func_define: if for_func_define:
if op_attribute_default_value_list[attr_idx] is not None: if (
default_value = op_attribute_default_value_list[attr_idx] op_non_mutable_attribute_default_value_list[attr_idx]
if op_attribute_build_arg_type_list[attr_idx] != "std::string": is not None
):
default_value = op_non_mutable_attribute_default_value_list[
attr_idx
]
if (
op_non_mutable_attribute_build_arg_type_list[attr_idx]
!= "std::string"
):
if default_value[0] == "'" or default_value[0] == '"': if default_value[0] == "'" or default_value[0] == '"':
default_value = default_value[1:] default_value = default_value[1:]
if default_value[-1] == "'" or default_value[-1] == '"': if default_value[-1] == "'" or default_value[-1] == '"':
...@@ -611,20 +687,24 @@ def GenBuildInputArgsStr( ...@@ -611,20 +687,24 @@ def GenBuildInputArgsStr(
return build_args_str return build_args_str
def GenBuildInputs(op_input_name_list): def GenBuildInputs(op_input_name_list, op_mutable_attribute_name_list):
BUILD_INPUT_TEMPLATE = """ std::vector<ir::OpResult> argument_inputs = {{{inputs_args}}}; BUILD_INPUT_TEMPLATE = """ std::vector<ir::OpResult> argument_inputs = {{{inputs_args}}};
argument.AddOperands(argument_inputs.begin(), argument_inputs.end()); argument.AddOperands(argument_inputs.begin(), argument_inputs.end());
""" """
build_input_str = "" build_input_str = ' VLOG(4) << "Builder construction inputs";\n'
if len(op_input_name_list) > 0: input_name_list = op_input_name_list + op_mutable_attribute_name_list
inputs_args_str = "_, ".join(op_input_name_list) + "_" if len(input_name_list) > 0:
build_input_str = BUILD_INPUT_TEMPLATE.format( inputs_args_str = ""
inputs_args_str += "_, ".join(input_name_list) + "_"
build_input_str += BUILD_INPUT_TEMPLATE.format(
inputs_args=inputs_args_str inputs_args=inputs_args_str
) )
return build_input_str return build_input_str
def GenBuildAttributes(op_attribute_name_list, op_attribute_type_list): def GenBuildAttributes(
op_non_mutable_attribute_name_list, op_non_mutable_attribute_type_list
):
INTARRAY_STR_TEMPLATE = """ ir::Attribute attr_{attr_name} = {op_attribute_type}::get(ir::IrContext::Instance(), phi::IntArray({attr})); INTARRAY_STR_TEMPLATE = """ ir::Attribute attr_{attr_name} = {op_attribute_type}::get(ir::IrContext::Instance(), phi::IntArray({attr}));
""" """
SCALAR_STR_TEMPLATE = """ ir::Attribute attr_{attr_name} = {op_attribute_type}::get(ir::IrContext::Instance(), phi::Scalar({attr})); SCALAR_STR_TEMPLATE = """ ir::Attribute attr_{attr_name} = {op_attribute_type}::get(ir::IrContext::Instance(), phi::Scalar({attr}));
...@@ -638,63 +718,72 @@ def GenBuildAttributes(op_attribute_name_list, op_attribute_type_list): ...@@ -638,63 +718,72 @@ def GenBuildAttributes(op_attribute_name_list, op_attribute_type_list):
}} }}
ir::Attribute attr_{attr_name} = ir::ArrayAttribute::get(ir::IrContext::Instance(), vec_{attr_name}); ir::Attribute attr_{attr_name} = ir::ArrayAttribute::get(ir::IrContext::Instance(), vec_{attr_name});
""" """
attr_str = "" attr_str = ' VLOG(4) << "Builder construction attributes";\n'
for idx in range(len(op_attribute_name_list)): for idx in range(len(op_non_mutable_attribute_name_list)):
if "ir::ArrayAttribute<" in op_attribute_type_list[idx]: if "ir::ArrayAttribute<" in op_non_mutable_attribute_type_list[idx]:
inner_attribute_type = op_attribute_type_list[idx][19:-1] inner_attribute_type = op_non_mutable_attribute_type_list[idx][
19:-1
]
if inner_attribute_type == "paddle::dialect::IntArrayAttribute": if inner_attribute_type == "paddle::dialect::IntArrayAttribute":
attr_str += ARRAY_ATTRIBUTE_TEMPLATE.format( attr_str += ARRAY_ATTRIBUTE_TEMPLATE.format(
attr_name=op_attribute_name_list[idx], attr_name=op_non_mutable_attribute_name_list[idx],
attr_size=op_attribute_name_list[idx] + ".size()", attr_size=op_non_mutable_attribute_name_list[idx]
+ ".size()",
create_attribute=INTARRAY_STR_TEMPLATE.format( create_attribute=INTARRAY_STR_TEMPLATE.format(
attr_name=op_attribute_name_list[idx], attr_name=op_non_mutable_attribute_name_list[idx],
op_attribute_type=inner_attribute_type, op_attribute_type=inner_attribute_type,
attr=op_attribute_name_list[idx] + "[i]", attr=op_non_mutable_attribute_name_list[idx] + "[i]",
), ),
) )
elif inner_attribute_type == "paddle::dialect::ScalarAttribute": elif inner_attribute_type == "paddle::dialect::ScalarAttribute":
attr_str += ARRAY_ATTRIBUTE_TEMPLATE.format( attr_str += ARRAY_ATTRIBUTE_TEMPLATE.format(
attr_name=op_attribute_name_list[idx], attr_name=op_non_mutable_attribute_name_list[idx],
attr_size=op_attribute_name_list[idx] + ".size()", attr_size=op_non_mutable_attribute_name_list[idx]
+ ".size()",
create_attribute=SCALAR_STR_TEMPLATE.format( create_attribute=SCALAR_STR_TEMPLATE.format(
attr_name=op_attribute_name_list[idx], attr_name=op_non_mutable_attribute_name_list[idx],
op_attribute_type=inner_attribute_type, op_attribute_type=inner_attribute_type,
attr=op_attribute_name_list[idx] + "[i]", attr=op_non_mutable_attribute_name_list[idx] + "[i]",
), ),
) )
else: else:
attr_str += ARRAY_ATTRIBUTE_TEMPLATE.format( attr_str += ARRAY_ATTRIBUTE_TEMPLATE.format(
attr_name=op_attribute_name_list[idx], attr_name=op_non_mutable_attribute_name_list[idx],
attr_size=op_attribute_name_list[idx] + ".size()", attr_size=op_non_mutable_attribute_name_list[idx]
+ ".size()",
create_attribute=STR_TEMPLATE.format( create_attribute=STR_TEMPLATE.format(
attr_name=op_attribute_name_list[idx], attr_name=op_non_mutable_attribute_name_list[idx],
op_attribute_type=inner_attribute_type, op_attribute_type=inner_attribute_type,
attr=op_attribute_name_list[idx] + "[i]", attr=op_non_mutable_attribute_name_list[idx] + "[i]",
), ),
) )
elif ( elif (
op_attribute_type_list[idx] == "paddle::dialect::IntArrayAttribute" op_non_mutable_attribute_type_list[idx]
== "paddle::dialect::IntArrayAttribute"
): ):
attr_str += INTARRAY_STR_TEMPLATE.format( attr_str += INTARRAY_STR_TEMPLATE.format(
attr_name=op_attribute_name_list[idx], attr_name=op_non_mutable_attribute_name_list[idx],
op_attribute_type=op_attribute_type_list[idx], op_attribute_type=op_non_mutable_attribute_type_list[idx],
attr=op_attribute_name_list[idx], attr=op_non_mutable_attribute_name_list[idx],
) )
elif op_attribute_type_list[idx] == "paddle::dialect::ScalarAttribute": elif (
op_non_mutable_attribute_type_list[idx]
== "paddle::dialect::ScalarAttribute"
):
attr_str += SCALAR_STR_TEMPLATE.format( attr_str += SCALAR_STR_TEMPLATE.format(
attr_name=op_attribute_name_list[idx], attr_name=op_non_mutable_attribute_name_list[idx],
op_attribute_type=op_attribute_type_list[idx], op_attribute_type=op_non_mutable_attribute_type_list[idx],
attr=op_attribute_name_list[idx], attr=op_non_mutable_attribute_name_list[idx],
) )
else: else:
attr_str += STR_TEMPLATE.format( attr_str += STR_TEMPLATE.format(
attr_name=op_attribute_name_list[idx], attr_name=op_non_mutable_attribute_name_list[idx],
op_attribute_type=op_attribute_type_list[idx], op_attribute_type=op_non_mutable_attribute_type_list[idx],
attr=op_attribute_name_list[idx], attr=op_non_mutable_attribute_name_list[idx],
) )
attr_str += """ argument.AddAttribute("{attr_name}", attr_{attr_name});\n""".format( attr_str += """ argument.AddAttribute("{attr_name}", attr_{attr_name});\n""".format(
attr_name=op_attribute_name_list[idx] attr_name=op_non_mutable_attribute_name_list[idx]
) )
return attr_str return attr_str
...@@ -703,12 +792,14 @@ def GenBuildAttributes(op_attribute_name_list, op_attribute_type_list): ...@@ -703,12 +792,14 @@ def GenBuildAttributes(op_attribute_name_list, op_attribute_type_list):
def GenBuildOutputs( def GenBuildOutputs(
op_input_name_list, op_input_name_list,
op_input_type_list, op_input_type_list,
op_mutable_attribute_name_list,
op_mutable_attribute_type_list,
op_output_name_list, op_output_name_list,
op_output_type_list, op_output_type_list,
op_output_size_list, op_output_size_list,
op_infer_meta_map, op_infer_meta_map,
): ):
build_output_str = "" build_output_str = ' VLOG(4) << "Builder construction outputs";\n'
CREATE_INPUT_METATENSOR_TEMPLATE = """ phi::DenseTensor dense_{name}; CREATE_INPUT_METATENSOR_TEMPLATE = """ phi::DenseTensor dense_{name};
dense_{name}.set_meta( dense_{name}.set_meta(
phi::DenseTensorMeta(TransToPhiDataType({name}.dtype()), phi::DenseTensorMeta(TransToPhiDataType({name}.dtype()),
...@@ -736,6 +827,10 @@ def GenBuildOutputs( ...@@ -736,6 +827,10 @@ def GenBuildOutputs(
meta_{name}.push_back(&vec_meta_{name}[i]); meta_{name}.push_back(&vec_meta_{name}[i]);
}} }}
""" """
CREATE_INTARRAY_MUTABLE_ATTRIBUE_TEMPLATE = """ std::vector<int64_t> {name} = {name}_.owner()->dyn_cast<ir::ConstantOp>().value().dyn_cast<paddle::dialect::IntArrayAttribute>().data().GetData(); (void){name};\n"""
CREATE_SCALAR_MUTABLE_ATTRIBUE_TEMPLATE = """ {dtype} {name} = {name}_.owner()->dyn_cast<ir::ConstantOp>().value().dyn_cast<paddle::dialect::ScalarAttribute>().data().to<{dtype}>(); (void){name};\n"""
CREATE_STRING_MUTABLE_ATTRIBUE_TEMPLATE = """ std::string {name} = {name}_.owner()->dyn_cast<ir::ConstantOp>().value().dyn_cast<ir::StrAttribute>().data(); (void){name};\n"""
CREATE_OUTPUT_METATENSOR_TEMPLATE = """ phi::DenseTensor dense_{name}; CREATE_OUTPUT_METATENSOR_TEMPLATE = """ phi::DenseTensor dense_{name};
phi::MetaTensor meta_{name}(&dense_{name}); phi::MetaTensor meta_{name}(&dense_{name});
""" """
...@@ -761,8 +856,31 @@ def GenBuildOutputs( ...@@ -761,8 +856,31 @@ def GenBuildOutputs(
build_output_str += " paddle::dialect::DenseTensorType {name} = {name}_.type().dyn_cast<paddle::dialect::DenseTensorType>(); (void){name};\n".format( build_output_str += " paddle::dialect::DenseTensorType {name} = {name}_.type().dyn_cast<paddle::dialect::DenseTensorType>(); (void){name};\n".format(
name=op_input_name_list[idx] name=op_input_name_list[idx]
) )
# Prepare mutable attributes
for idx in range(len(op_mutable_attribute_name_list)):
attr_dtype = op_mutable_attribute_type_list[idx]
# int_array
if attr_dtype[0] == "paddle::dialect::IntArrayAttribute":
build_output_str += (
CREATE_INTARRAY_MUTABLE_ATTRIBUE_TEMPLATE.format(
name=op_mutable_attribute_name_list[idx]
)
)
# scalar
elif attr_dtype[0] == "paddle::dialect::ScalarAttribute":
build_output_str += CREATE_SCALAR_MUTABLE_ATTRIBUE_TEMPLATE.format(
name=op_mutable_attribute_name_list[idx], dtype=attr_dtype[1]
)
# string
elif attr_dtype[0] == "ir::StrAttribute":
build_output_str += CREATE_STRING_MUTABLE_ATTRIBUE_TEMPLATE.format(
name=op_mutable_attribute_name_list[idx]
)
else:
assert "mutable attribtue type is not right."
build_output_str += "\n"
# Prepare inputs for infer meta # Prepare inputs_meta_tensor & attributes for infer meta
infer_meta_args = [] infer_meta_args = []
for idx in range(len(op_infer_meta_map['param'])): for idx in range(len(op_infer_meta_map['param'])):
# is input # is input
...@@ -795,7 +913,7 @@ def GenBuildOutputs( ...@@ -795,7 +913,7 @@ def GenBuildOutputs(
else: else:
infer_meta_args.append(op_infer_meta_map['param'][idx]) infer_meta_args.append(op_infer_meta_map['param'][idx])
# Prepare outputs for infer meta # Prepare outputs_meta_tensor for infer meta
for idx in range(len(op_output_name_list)): for idx in range(len(op_output_name_list)):
# is a vector<Tensor> # is a vector<Tensor>
if 'ir::VectorType' in op_output_type_list[idx]: if 'ir::VectorType' in op_output_type_list[idx]:
...@@ -885,24 +1003,55 @@ def OpGenerator( ...@@ -885,24 +1003,55 @@ def OpGenerator(
ops_declare_list = [] # all op class declare store in this list ops_declare_list = [] # all op class declare store in this list
ops_defined_list = [] # all op class defined store in this list ops_defined_list = [] # all op class defined store in this list
for op_info in op_info_items: for op_info in op_info_items:
# get op info # get op inputs info
op_input_name_list = op_info.input_name_list op_input_name_list = op_info.input_name_list
op_input_type_list = op_info.input_type_list op_input_type_list = op_info.input_type_list
op_input_optional_list = op_info.input_optional_list op_input_optional_list = op_info.input_optional_list
op_input_no_need_buffer_list = op_info.input_no_need_buffer_list op_input_no_need_buffer_list = op_info.input_no_need_buffer_list
# get op outputs info
op_output_name_list = op_info.output_name_list op_output_name_list = op_info.output_name_list
op_output_type_list = op_info.output_type_list op_output_type_list = op_info.output_type_list
op_output_size_list = op_info.output_size_list op_output_size_list = op_info.output_size_list
op_output_optional_list = op_info.output_optional_list op_output_optional_list = op_info.output_optional_list
op_output_intermediate_list = op_info.output_intermediate_list op_output_intermediate_list = op_info.output_intermediate_list
# get op mutable attribute
op_mutable_attribute_name_list = op_info.mutable_attribute_name_list
op_mutable_attribute_type_list = op_info.mutable_attribute_type_list
# get op attribute
op_attribute_name_list = op_info.attribute_name_list op_attribute_name_list = op_info.attribute_name_list
op_attribute_type_list = op_info.attribute_type_list op_attribute_type_list = op_info.attribute_type_list
op_attribute_data_type_list = op_info.attribute_data_type_list op_attribute_data_type_list = op_info.attribute_data_type_list
op_attribute_build_arg_type_list = op_info.attribute_build_arg_type_list op_attribute_build_arg_type_list = op_info.attribute_build_arg_type_list
op_attribute_default_value_list = op_info.attribute_default_value_list op_attribute_default_value_list = op_info.attribute_default_value_list
op_non_mutable_attribute_name_list = []
op_non_mutable_attribute_type_list = []
op_non_mutable_attribute_data_type_list = []
op_non_mutable_attribute_build_arg_type_list = []
op_non_mutable_attribute_default_value_list = []
for idx in range(len(op_attribute_name_list)):
if (
op_attribute_name_list[idx]
not in op_mutable_attribute_name_list
):
op_non_mutable_attribute_name_list.append(
op_attribute_name_list[idx]
)
op_non_mutable_attribute_type_list.append(
op_attribute_type_list[idx]
)
op_non_mutable_attribute_data_type_list.append(
op_attribute_data_type_list[idx]
)
op_non_mutable_attribute_build_arg_type_list.append(
op_attribute_build_arg_type_list[idx]
)
op_non_mutable_attribute_default_value_list.append(
op_attribute_default_value_list[idx]
)
# others
op_infer_meta_map = op_info.infer_meta_map op_infer_meta_map = op_info.infer_meta_map
op_kernel_map = op_info.kernel_map op_kernel_map = op_info.kernel_map
op_interfaces = ["GetOpInfoInterface"] op_interfaces = ["OpYamlInfoInterface"]
op_traits = [] op_traits = []
exclusive_interface_str = "" exclusive_interface_str = ""
...@@ -931,6 +1080,11 @@ def OpGenerator( ...@@ -931,6 +1080,11 @@ def OpGenerator(
input_name=op_input_name_list[idx], input_name=op_input_name_list[idx],
input_index=idx, input_index=idx,
) )
for idx in range(len(op_mutable_attribute_name_list)):
op_get_inputs_outputs_str += OP_GET_INPUT_TEMPLATE.format(
input_name=op_mutable_attribute_name_list[idx],
input_index=idx + len(op_input_name_list),
)
for idx in range(len(op_output_name_list)): for idx in range(len(op_output_name_list)):
op_get_inputs_outputs_str += OP_GET_OUTPUT_TEMPLATE.format( op_get_inputs_outputs_str += OP_GET_OUTPUT_TEMPLATE.format(
output_name=op_output_name_list[idx], output_name=op_output_name_list[idx],
...@@ -944,25 +1098,32 @@ def OpGenerator( ...@@ -944,25 +1098,32 @@ def OpGenerator(
if op_infer_meta_map is not None: if op_infer_meta_map is not None:
build_define_input_args_str = GenBuildInputArgsStr( build_define_input_args_str = GenBuildInputArgsStr(
op_input_name_list, op_input_name_list,
op_attribute_name_list, op_mutable_attribute_name_list,
op_attribute_build_arg_type_list, op_non_mutable_attribute_name_list,
op_attribute_default_value_list, op_non_mutable_attribute_build_arg_type_list,
op_non_mutable_attribute_default_value_list,
True, True,
) )
build_declare_input_args_str = GenBuildInputArgsStr( build_declare_input_args_str = GenBuildInputArgsStr(
op_input_name_list, op_input_name_list,
op_attribute_name_list, op_mutable_attribute_name_list,
op_attribute_build_arg_type_list, op_non_mutable_attribute_name_list,
op_attribute_default_value_list, op_non_mutable_attribute_build_arg_type_list,
op_non_mutable_attribute_default_value_list,
False, False,
) )
build_inputs_str = GenBuildInputs(op_input_name_list) build_inputs_str = GenBuildInputs(
op_input_name_list, op_mutable_attribute_name_list
)
build_attributes_str = GenBuildAttributes( build_attributes_str = GenBuildAttributes(
op_attribute_name_list, op_attribute_type_list op_non_mutable_attribute_name_list,
op_non_mutable_attribute_type_list,
) )
build_outputs_str = GenBuildOutputs( build_outputs_str = GenBuildOutputs(
op_input_name_list, op_input_name_list,
op_input_type_list, op_input_type_list,
op_mutable_attribute_name_list,
op_mutable_attribute_type_list,
op_output_name_list, op_output_name_list,
op_output_type_list, op_output_type_list,
op_output_size_list, op_output_size_list,
...@@ -985,7 +1146,7 @@ def OpGenerator( ...@@ -985,7 +1146,7 @@ def OpGenerator(
) )
# gen op_declare_str/op_defined_str # gen op_declare_str/op_defined_str
if len(op_attribute_name_list) == 0: if len(op_non_mutable_attribute_name_list) == 0:
op_declare_str = OP_DECLARE_TEMPLATE.format( op_declare_str = OP_DECLARE_TEMPLATE.format(
op_name=op_class_name, op_name=op_class_name,
dialect_op_name=op_dialect_name, dialect_op_name=op_dialect_name,
...@@ -1005,19 +1166,19 @@ def OpGenerator( ...@@ -1005,19 +1166,19 @@ def OpGenerator(
interfaces=op_interfaces_str, interfaces=op_interfaces_str,
traits=op_traits_str, traits=op_traits_str,
attribute_declare=op_n_attribute_declare_str.format( attribute_declare=op_n_attribute_declare_str.format(
attribute_num=len(op_attribute_name_list) attribute_num=len(op_non_mutable_attribute_name_list)
), ),
attribute_num=len(op_attribute_name_list), attribute_num=len(op_non_mutable_attribute_name_list),
build_args=build_define_input_args_str, build_args=build_define_input_args_str,
get_inputs_and_outputs=op_get_inputs_outputs_str, get_inputs_and_outputs=op_get_inputs_outputs_str,
exclusive_interface=exclusive_interface_str, exclusive_interface=exclusive_interface_str,
) )
attribute_names_str = ( attribute_names_str = (
'"' + '", "'.join(op_attribute_name_list) + '"' '"' + '", "'.join(op_non_mutable_attribute_name_list) + '"'
) )
op_defined_str = OP_N_ATTRIBUTE_DEFINED_TEMPLATE.format( op_defined_str = OP_N_ATTRIBUTE_DEFINED_TEMPLATE.format(
op_name=op_class_name, op_name=op_class_name,
attribute_num=len(op_attribute_name_list), attribute_num=len(op_non_mutable_attribute_name_list),
attribute_names=attribute_names_str, attribute_names=attribute_names_str,
) )
...@@ -1089,7 +1250,9 @@ def OpGenerator( ...@@ -1089,7 +1250,9 @@ def OpGenerator(
) )
# generate op verify function: inputs_type_check_str # generate op verify function: inputs_type_check_str
if len(op_input_type_list) == 0: if (
len(op_input_type_list) + len(op_mutable_attribute_name_list)
) == 0:
inputs_type_check_str = ( inputs_type_check_str = (
"// Inputs num is 0, not need to check inputs type." "// Inputs num is 0, not need to check inputs type."
) )
...@@ -1125,6 +1288,21 @@ def OpGenerator( ...@@ -1125,6 +1288,21 @@ def OpGenerator(
) )
inputs_type_check_str += check_str inputs_type_check_str += check_str
for idx in range(len(op_mutable_attribute_name_list)):
mutable_attribute_type = op_mutable_attribute_type_list[idx][0]
check_str = ""
if mutable_attribute_type == "paddle::dialect::ScalarAttribute":
check_str = INPUT_TYPE_CHECK_TEMPLATE.format(
index=idx + len(op_input_type_list),
standard="paddle::dialect::DenseTensorType",
)
else:
check_str = INPUT_VECTORTYPE_CHECK_TEMPLATE.format(
index=idx + len(op_input_type_list),
standard="paddle::dialect::DenseTensorType",
)
inputs_type_check_str += check_str
# generate op verify function: outputs_type_check_str # generate op verify function: outputs_type_check_str
if len(op_output_type_list) == 0: if len(op_output_type_list) == 0:
outputs_type_check_str = ( outputs_type_check_str = (
...@@ -1163,15 +1341,15 @@ def OpGenerator( ...@@ -1163,15 +1341,15 @@ def OpGenerator(
outputs_type_check_str += check_str outputs_type_check_str += check_str
# generate op verify function: attributes_check_str # generate op verify function: attributes_check_str
if len(op_attribute_name_list) == 0: if len(op_non_mutable_attribute_name_list) == 0:
attributes_check_str = ( attributes_check_str = (
"// Attributes num is 0, not need to check attributes type." "// Attributes num is 0, not need to check attributes type."
) )
else: else:
attributes_check_str = "" attributes_check_str = ""
for idx in range(len(op_attribute_name_list)): for idx in range(len(op_non_mutable_attribute_name_list)):
attribute_name = op_attribute_name_list[idx] attribute_name = op_non_mutable_attribute_name_list[idx]
attribute_type = op_attribute_type_list[idx] attribute_type = op_non_mutable_attribute_type_list[idx]
if attribute_type.startswith("ir::ArrayAttribute<"): if attribute_type.startswith("ir::ArrayAttribute<"):
attribute_type = attribute_type[19:-1] attribute_type = attribute_type[19:-1]
attributes_check_str += ( attributes_check_str += (
...@@ -1193,7 +1371,8 @@ def OpGenerator( ...@@ -1193,7 +1371,8 @@ def OpGenerator(
else: else:
op_verify_str = OP_VERIFY_TEMPLATE.format( op_verify_str = OP_VERIFY_TEMPLATE.format(
op_name=op_class_name, op_name=op_class_name,
inputs_size=len(op_input_type_list), inputs_size=len(op_input_type_list)
+ len(op_mutable_attribute_type_list),
outputs_size=len(op_output_type_list), outputs_size=len(op_output_type_list),
inputs_type_check=inputs_type_check_str, inputs_type_check=inputs_type_check_str,
outputs_type_check=outputs_type_check_str, outputs_type_check=outputs_type_check_str,
...@@ -1273,7 +1452,6 @@ def ParseArguments(): ...@@ -1273,7 +1452,6 @@ def ParseArguments():
# ===================================== # =====================================
if __name__ == "__main__": if __name__ == "__main__":
# parse arguments # parse arguments
print("auto gen op")
args = ParseArguments() args = ParseArguments()
op_yaml_files = args.op_yaml_files.split(",") op_yaml_files = args.op_yaml_files.split(",")
op_compat_yaml_file = args.op_compat_yaml_file op_compat_yaml_file = args.op_compat_yaml_file
......
...@@ -70,77 +70,6 @@ inline ir::Type TransToIrDataType(phi::DataType dtype, ...@@ -70,77 +70,6 @@ inline ir::Type TransToIrDataType(phi::DataType dtype,
} }
} }
// inline phi::DataLayout TransToPhiDataLayout(
// DenseTensorTypeStorage::DataLayout data_layout) {
// switch (data_layout) {
// case DenseTensorTypeStorage::DataLayout::NHWC:
// return phi::DataLayout::NHWC;
// case DenseTensorTypeStorage::DataLayout::NCHW:
// return phi::DataLayout::NCHW;
// case DenseTensorTypeStorage::DataLayout::NCDHW:
// return phi::DataLayout::NCDHW;
// case DenseTensorTypeStorage::DataLayout::NDHWC:
// return phi::DataLayout::NDHWC;
// case DenseTensorTypeStorage::DataLayout::ONEDNN:
// return phi::DataLayout::ONEDNN;
// case DenseTensorTypeStorage::DataLayout::SPARSE_COO:
// return phi::DataLayout::SPARSE_COO;
// case DenseTensorTypeStorage::DataLayout::SPARSE_CSR:
// return phi::DataLayout::SPARSE_CSR;
// case DenseTensorTypeStorage::DataLayout::PSTRING_UNION:
// return phi::DataLayout::PSTRING_UNION;
// case DenseTensorTypeStorage::DataLayout::NUM_DATA_LAYOUTS:
// return phi::DataLayout::NUM_DATA_LAYOUTS;
// case DenseTensorTypeStorage::DataLayout::ALL_LAYOUT:
// return phi::DataLayout::ALL_LAYOUT;
// default:
// PADDLE_THROW(phi::errors::Unimplemented(
// "Unsupported ir data layout `%s` when casting it into "
// "phi data type.",
// static_cast<int>(data_layout)));
// }
// }
// inline DenseTensorTypeStorage::DataLayout TransToIrDataLayout(
// phi::DataLayout data_layout) {
// switch (data_layout) {
// case phi::DataLayout::NHWC:
// return DenseTensorTypeStorage::DataLayout::NHWC;
// case phi::DataLayout::NCHW:
// return DenseTensorTypeStorage::DataLayout::NCHW;
// case phi::DataLayout::NCDHW:
// return DenseTensorTypeStorage::DataLayout::NCDHW;
// case phi::DataLayout::NDHWC:
// return DenseTensorTypeStorage::DataLayout::NDHWC;
// case phi::DataLayout::ONEDNN:
// return DenseTensorTypeStorage::DataLayout::ONEDNN;
// case phi::DataLayout::SPARSE_COO:
// return DenseTensorTypeStorage::DataLayout::SPARSE_COO;
// case phi::DataLayout::SPARSE_CSR:
// return DenseTensorTypeStorage::DataLayout::SPARSE_CSR;
// case phi::DataLayout::PSTRING_UNION:
// return DenseTensorTypeStorage::DataLayout::PSTRING_UNION;
// case phi::DataLayout::NUM_DATA_LAYOUTS:
// return DenseTensorTypeStorage::DataLayout::NUM_DATA_LAYOUTS;
// case phi::DataLayout::ALL_LAYOUT:
// return DenseTensorTypeStorage::DataLayout::ALL_LAYOUT;
// default:
// PADDLE_THROW(phi::errors::Unimplemented(
// "Unsupported phi data layout `%s` when casting it into "
// "ir data type.",
// static_cast<int>(data_layout)));
// }
// }
// inline phi::DenseTensorMeta TransToDenseTensorMeta(
// paddle::dialect::DenseTensorType type) {
// return phi::DenseTensorMeta(TransToPhiDataType(type.dtype()),
// type.dim(),
// type.data_layout(),
// type.lod(),
// type.offset());
// }
struct OpInputInfo { struct OpInputInfo {
std::string name; std::string name;
std::string type_name; std::string type_name;
......
...@@ -24,7 +24,7 @@ using OpInfoTuple = std::tuple<std::vector<paddle::dialect::OpInputInfo>, ...@@ -24,7 +24,7 @@ using OpInfoTuple = std::tuple<std::vector<paddle::dialect::OpInputInfo>,
namespace paddle { namespace paddle {
namespace dialect { namespace dialect {
class GetOpInfoInterface : public ir::OpInterfaceBase<GetOpInfoInterface> { class OpYamlInfoInterface : public ir::OpInterfaceBase<OpYamlInfoInterface> {
public: public:
struct Concept { struct Concept {
explicit Concept(OpInfoTuple (*get_op_info)()) explicit Concept(OpInfoTuple (*get_op_info)())
...@@ -39,8 +39,8 @@ class GetOpInfoInterface : public ir::OpInterfaceBase<GetOpInfoInterface> { ...@@ -39,8 +39,8 @@ class GetOpInfoInterface : public ir::OpInterfaceBase<GetOpInfoInterface> {
Model() : Concept(GetOpInfo) {} Model() : Concept(GetOpInfo) {}
}; };
GetOpInfoInterface(ir::Operation *op, Concept *impl) OpYamlInfoInterface(ir::Operation *op, Concept *impl)
: ir::OpInterfaceBase<GetOpInfoInterface>(op), impl_(impl) {} : ir::OpInterfaceBase<OpYamlInfoInterface>(op), impl_(impl) {}
OpInfoTuple GetOpInfo() { return impl_->get_op_info_(); } OpInfoTuple GetOpInfo() { return impl_->get_op_info_(); }
......
...@@ -23,7 +23,7 @@ ...@@ -23,7 +23,7 @@
#include <vector> #include <vector>
#include "paddle/fluid/framework/op_desc.h" #include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/ir/dialect/pd_interface.h" #include "paddle/fluid/ir/interface/op_yaml_info.h"
#include "paddle/fluid/ir_adaptor/translator/attribute_translator.h" #include "paddle/fluid/ir_adaptor/translator/attribute_translator.h"
#include "paddle/fluid/ir_adaptor/translator/op_compat_info.h" #include "paddle/fluid/ir_adaptor/translator/op_compat_info.h"
#include "paddle/fluid/ir_adaptor/translator/program_translator.h" #include "paddle/fluid/ir_adaptor/translator/program_translator.h"
...@@ -380,7 +380,7 @@ ir::Operation* GeneralOpHandler(ir::IrContext* ctx, ...@@ -380,7 +380,7 @@ ir::Operation* GeneralOpHandler(ir::IrContext* ctx,
const OpDesc& op_desc) { const OpDesc& op_desc) {
auto op_info = LoopkUpOpInfo(ctx, op_desc); auto op_info = LoopkUpOpInfo(ctx, op_desc);
auto* op_info_concept = auto* op_info_concept =
op_info.GetInterfaceImpl<paddle::dialect::GetOpInfoInterface>(); op_info.GetInterfaceImpl<paddle::dialect::OpYamlInfoInterface>();
OpInputInfoList input_infos; OpInputInfoList input_infos;
OpAttributeInfoList attr_infos; OpAttributeInfoList attr_infos;
...@@ -418,7 +418,7 @@ ir::Operation* FeedOpHandler(ir::IrContext* ctx, ...@@ -418,7 +418,7 @@ ir::Operation* FeedOpHandler(ir::IrContext* ctx,
auto op_info = LoopkUpOpInfo(ctx, op_desc); auto op_info = LoopkUpOpInfo(ctx, op_desc);
auto* op_info_concept = auto* op_info_concept =
op_info.GetInterfaceImpl<paddle::dialect::GetOpInfoInterface>(); op_info.GetInterfaceImpl<paddle::dialect::OpYamlInfoInterface>();
OpInputInfoList input_infos; OpInputInfoList input_infos;
OpAttributeInfoList attr_infos; OpAttributeInfoList attr_infos;
OpOutputInfoList output_infos; OpOutputInfoList output_infos;
...@@ -450,7 +450,7 @@ ir::Operation* FetchOpHandler(ir::IrContext* ctx, ...@@ -450,7 +450,7 @@ ir::Operation* FetchOpHandler(ir::IrContext* ctx,
auto op_info = LoopkUpOpInfo(ctx, op_desc); auto op_info = LoopkUpOpInfo(ctx, op_desc);
auto* op_info_concept = auto* op_info_concept =
op_info.GetInterfaceImpl<paddle::dialect::GetOpInfoInterface>(); op_info.GetInterfaceImpl<paddle::dialect::OpYamlInfoInterface>();
OpInputInfoList input_infos; OpInputInfoList input_infos;
OpAttributeInfoList attr_infos; OpAttributeInfoList attr_infos;
OpOutputInfoList output_infos; OpOutputInfoList output_infos;
......
...@@ -58,7 +58,7 @@ class Builder { ...@@ -58,7 +58,7 @@ class Builder {
template <typename OpTy, typename... Args> template <typename OpTy, typename... Args>
OpTy create(Args &&...args) { OpTy create(Args &&...args) {
OperationArgument argument(context_->GetRegisteredOpInfo(OpTy::name())); OperationArgument argument(context_->GetRegisteredOpInfo(OpTy::name()));
OpTy::build(*this, argument, std::forward<Args>(args)...); OpTy::Build(*this, argument, std::forward<Args>(args)...);
Operation *op = create(std::move(argument)); Operation *op = create(std::move(argument));
return op->dyn_cast<OpTy>(); return op->dyn_cast<OpTy>();
} }
......
...@@ -13,9 +13,9 @@ ...@@ -13,9 +13,9 @@
// limitations under the License. // limitations under the License.
#include "paddle/ir/core/builtin_op.h" #include "paddle/ir/core/builtin_op.h"
#include "paddle/ir/core/builtin_attribute.h" #include "paddle/ir/core/builtin_attribute.h"
#include "paddle/ir/core/builtin_type.h" #include "paddle/ir/core/builtin_type.h"
#include "paddle/ir/core/enforce.h"
#include "paddle/phi/core/enforce.h" #include "paddle/phi/core/enforce.h"
namespace ir { namespace ir {
...@@ -52,7 +52,7 @@ void ModuleOp::destroy() { ...@@ -52,7 +52,7 @@ void ModuleOp::destroy() {
} }
} }
void ModuleOp::verify(const std::vector<ir::OpResult> &inputs, void ModuleOp::Verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs, const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes) { const ir::AttributeMap &attributes) {
VLOG(4) << "Verifying inputs, outputs and attributes for: ModuleOp."; VLOG(4) << "Verifying inputs, outputs and attributes for: ModuleOp.";
...@@ -76,7 +76,7 @@ void ModuleOp::verify(const std::vector<ir::OpResult> &inputs, ...@@ -76,7 +76,7 @@ void ModuleOp::verify(const std::vector<ir::OpResult> &inputs,
const char *GetParameterOp::attributes_name[attributes_num] = { const char *GetParameterOp::attributes_name[attributes_num] = {
"parameter_name"}; "parameter_name"};
void GetParameterOp::verify(const std::vector<ir::OpResult> &inputs, void GetParameterOp::Verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs, const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes) { const ir::AttributeMap &attributes) {
VLOG(4) << "Verifying inputs, outputs and attributes for: GetParameterOp."; VLOG(4) << "Verifying inputs, outputs and attributes for: GetParameterOp.";
...@@ -97,7 +97,7 @@ void GetParameterOp::verify(const std::vector<ir::OpResult> &inputs, ...@@ -97,7 +97,7 @@ void GetParameterOp::verify(const std::vector<ir::OpResult> &inputs,
const char *SetParameterOp::attributes_name[attributes_num] = { const char *SetParameterOp::attributes_name[attributes_num] = {
"parameter_name"}; "parameter_name"};
void SetParameterOp::verify(const std::vector<ir::OpResult> &inputs, void SetParameterOp::Verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs, const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes) { const ir::AttributeMap &attributes) {
VLOG(4) << "Verifying inputs, outputs and attributes for: SetParameterOp."; VLOG(4) << "Verifying inputs, outputs and attributes for: SetParameterOp.";
...@@ -115,7 +115,7 @@ void SetParameterOp::verify(const std::vector<ir::OpResult> &inputs, ...@@ -115,7 +115,7 @@ void SetParameterOp::verify(const std::vector<ir::OpResult> &inputs,
} }
} }
void CombineOp::verify(const std::vector<ir::OpResult> &inputs, void CombineOp::Verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs, const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes) { const ir::AttributeMap &attributes) {
// outputs.size() == 1 // outputs.size() == 1
...@@ -154,7 +154,7 @@ void CombineOp::verify(const std::vector<ir::OpResult> &inputs, ...@@ -154,7 +154,7 @@ void CombineOp::verify(const std::vector<ir::OpResult> &inputs,
} }
const char *SliceOp::attributes_name[attributes_num] = {"index"}; const char *SliceOp::attributes_name[attributes_num] = {"index"};
void SliceOp::verify(const std::vector<ir::OpResult> &inputs, void SliceOp::Verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs, const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes) { const ir::AttributeMap &attributes) {
// inputs.size() == 1 // inputs.size() == 1
...@@ -214,21 +214,25 @@ void SliceOp::verify(const std::vector<ir::OpResult> &inputs, ...@@ -214,21 +214,25 @@ void SliceOp::verify(const std::vector<ir::OpResult> &inputs,
outputs[0])); outputs[0]));
} }
void ConstantOp::verify(const std::vector<ir::OpResult> &inputs, const char *ConstantOp::attributes_name[attributes_num] = {"value"};
void ConstantOp::Build(Builder &builder,
OperationArgument &argument,
Attribute value,
Type output_type) {
argument.AddAttribute("value", value);
argument.output_types.push_back(output_type);
}
void ConstantOp::Verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs, const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes) { const ir::AttributeMap &attributes) {
// outputs.size() == 1 IR_ENFORCE(inputs.size() == 0, "The size of inputs must be equal to 0.");
PADDLE_ENFORCE_EQ( IR_ENFORCE(outputs.size() == 1, "The size of outputs must be equal to 1.");
outputs.size(), IR_ENFORCE(attributes.count("value") > 0,
1, "Type of attribute: value is not right.");
phi::errors::PreconditionNotMet(
"The size %d of outputs must be equal to 1.", outputs.size()));
// inputs.size() == 0
PADDLE_ENFORCE_EQ(
inputs.size(),
0,
phi::errors::PreconditionNotMet(
"The size %d of outputs must be equal to 1.", outputs.size()));
} }
Attribute ConstantOp::value() { return operation()->attributes().at("value"); }
} // namespace ir } // namespace ir
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#pragma once #pragma once
#include "paddle/ir/core/builder.h"
#include "paddle/ir/core/op_base.h" #include "paddle/ir/core/op_base.h"
namespace ir { namespace ir {
...@@ -29,7 +30,7 @@ class ModuleOp : public ir::Op<ModuleOp> { ...@@ -29,7 +30,7 @@ class ModuleOp : public ir::Op<ModuleOp> {
static const char *name() { return "builtin.module"; } static const char *name() { return "builtin.module"; }
static constexpr uint32_t attributes_num = 1; static constexpr uint32_t attributes_num = 1;
static const char *attributes_name[attributes_num]; static const char *attributes_name[attributes_num];
static void verify(const std::vector<ir::OpResult> &inputs, static void Verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs, const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes); const ir::AttributeMap &attributes);
...@@ -53,7 +54,7 @@ class GetParameterOp : public ir::Op<GetParameterOp> { ...@@ -53,7 +54,7 @@ class GetParameterOp : public ir::Op<GetParameterOp> {
static const char *name() { return "builtin.get_parameter"; } static const char *name() { return "builtin.get_parameter"; }
static constexpr uint32_t attributes_num = 1; static constexpr uint32_t attributes_num = 1;
static const char *attributes_name[attributes_num]; static const char *attributes_name[attributes_num];
static void verify(const std::vector<ir::OpResult> &inputs, static void Verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs, const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes); const ir::AttributeMap &attributes);
}; };
...@@ -68,7 +69,7 @@ class SetParameterOp : public ir::Op<SetParameterOp> { ...@@ -68,7 +69,7 @@ class SetParameterOp : public ir::Op<SetParameterOp> {
static const char *name() { return "builtin.set_parameter"; } static const char *name() { return "builtin.set_parameter"; }
static constexpr uint32_t attributes_num = 1; static constexpr uint32_t attributes_num = 1;
static const char *attributes_name[attributes_num]; static const char *attributes_name[attributes_num];
static void verify(const std::vector<ir::OpResult> &inputs, static void Verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs, const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes); const ir::AttributeMap &attributes);
}; };
...@@ -85,7 +86,7 @@ class CombineOp : public ir::Op<CombineOp> { ...@@ -85,7 +86,7 @@ class CombineOp : public ir::Op<CombineOp> {
static constexpr uint32_t attributes_num = 0; static constexpr uint32_t attributes_num = 0;
static constexpr const char **attributes_name = nullptr; static constexpr const char **attributes_name = nullptr;
static void verify(const std::vector<ir::OpResult> &inputs, static void Verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs, const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes); const ir::AttributeMap &attributes);
}; };
...@@ -102,23 +103,38 @@ class SliceOp : public ir::Op<SliceOp> { ...@@ -102,23 +103,38 @@ class SliceOp : public ir::Op<SliceOp> {
static constexpr uint32_t attributes_num = 1; static constexpr uint32_t attributes_num = 1;
static const char *attributes_name[attributes_num]; static const char *attributes_name[attributes_num];
static void verify(const std::vector<ir::OpResult> &inputs, static void Verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs, const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes); const ir::AttributeMap &attributes);
}; };
class ConstantOp : public ir::Op<ConstantOp> { class ConstantLikeTrait : public OpTraitBase<ConstantLikeTrait> {
public: public:
using Op::Op; explicit ConstantLikeTrait(Operation *op)
: OpTraitBase<ConstantLikeTrait>(op) {}
};
///
/// \brief ConstantOp
///
class ConstantOp : public Op<ConstantOp, ConstantLikeTrait> {
public:
using Op::Op;
static const char *name() { return "builtin.constant"; } static const char *name() { return "builtin.constant"; }
static constexpr uint32_t attributes_num = 0; static constexpr uint32_t attributes_num = 1;
static const char *attributes_name[attributes_num];
static constexpr const char **attributes_name = nullptr; static void Build(Builder &builder, // NOLINT
static void verify(const std::vector<ir::OpResult> &inputs, OperationArgument &argument, // NOLINT
Attribute value,
Type output_type);
static void Verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs, const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes); const AttributeMap &attributes);
Attribute value();
}; };
} // namespace ir } // namespace ir
...@@ -93,7 +93,7 @@ class Dialect { ...@@ -93,7 +93,7 @@ class Dialect {
ConcreteOp::GetTraitSet(), ConcreteOp::GetTraitSet(),
ConcreteOp::attributes_num, ConcreteOp::attributes_num,
ConcreteOp::attributes_name, ConcreteOp::attributes_name,
ConcreteOp::verify); ConcreteOp::Verify);
} }
void RegisterOp(const std::string &name, OpInfoImpl *op_info); void RegisterOp(const std::string &name, OpInfoImpl *op_info);
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <exception>
#include <string>
#if !defined(_WIN32)
#define UNLIKELY(condition) __builtin_expect(static_cast<bool>(condition), 0)
#else
// there is no equivalent intrinsics in msvc.
#define UNLIKELY(condition) (condition)
#endif
inline bool is_error(bool stat) { return !stat; }
namespace ir {
class IrNotMetException : public std::exception {
public:
explicit IrNotMetException(const std::string& str) : err_str_(str) {}
const char* what() const noexcept override { return err_str_.c_str(); }
private:
std::string err_str_;
};
#define IR_THROW(...) \
do { \
try { \
throw ir::IrNotMetException(__VA_ARGS__); \
} catch (const std::exception& e) { \
std::cout << e.what() << std::endl; \
throw; \
} \
} while (0)
#define IR_ENFORCE(COND, ...) \
do { \
auto __cond__ = (COND); \
if (UNLIKELY(is_error(__cond__))) { \
try { \
throw ir::IrNotMetException(__VA_ARGS__); \
} catch (const std::exception& e) { \
std::cout << e.what() << std::endl; \
throw; \
} \
} \
} while (0)
} // namespace ir
...@@ -34,7 +34,7 @@ const char *OpInfo::name() const { return impl_ ? impl_->name() : nullptr; } ...@@ -34,7 +34,7 @@ const char *OpInfo::name() const { return impl_ ? impl_->name() : nullptr; }
TypeId OpInfo::id() const { return impl_ ? impl_->id() : TypeId(); } TypeId OpInfo::id() const { return impl_ ? impl_->id() : TypeId(); }
void OpInfo::verify(const std::vector<OpResult> &inputs, void OpInfo::Verify(const std::vector<OpResult> &inputs,
const std::vector<Type> &outputs, const std::vector<Type> &outputs,
const AttributeMap &attributes) { const AttributeMap &attributes) {
impl_->verify()(inputs, outputs, attributes); impl_->verify()(inputs, outputs, attributes);
......
...@@ -48,7 +48,7 @@ class OpInfo { ...@@ -48,7 +48,7 @@ class OpInfo {
TypeId id() const; TypeId id() const;
void verify(const std::vector<OpResult> &inputs, void Verify(const std::vector<OpResult> &inputs,
const std::vector<Type> &outputs, const std::vector<Type> &outputs,
const std::unordered_map<std::string, Attribute> &attributes); const std::unordered_map<std::string, Attribute> &attributes);
......
...@@ -47,7 +47,7 @@ Operation *Operation::create(const std::vector<ir::OpResult> &inputs, ...@@ -47,7 +47,7 @@ Operation *Operation::create(const std::vector<ir::OpResult> &inputs,
size_t num_regions) { size_t num_regions) {
// 0. Verify // 0. Verify
if (op_info) { if (op_info) {
op_info.verify(inputs, output_types, attributes); op_info.Verify(inputs, output_types, attributes);
} }
// 1. Calculate the required memory size for OpResults + Operation + // 1. Calculate the required memory size for OpResults + Operation +
// OpOperands. // OpOperands.
......
...@@ -113,7 +113,7 @@ bool detail::PassAdaptor::RunPass(Pass* pass, ...@@ -113,7 +113,7 @@ bool detail::PassAdaptor::RunPass(Pass* pass,
// TODO(liuyuanle): Support verification of operation // TODO(liuyuanle): Support verification of operation
if (!pass_failed && verify) { if (!pass_failed && verify) {
// bool verify_recursively = !dynamic_cast<PassAdaptor*>(pass); // bool verify_recursively = !dynamic_cast<PassAdaptor*>(pass);
// pass_failed = ir::verify(op, verify_recursively); // pass_failed = ir::Verify(op, verify_recursively);
} }
return !pass_failed; return !pass_failed;
......
...@@ -44,7 +44,7 @@ class OperationTest : public ir::Op<OperationTest, InferShapeInterface> { ...@@ -44,7 +44,7 @@ class OperationTest : public ir::Op<OperationTest, InferShapeInterface> {
static const char *name() { return "test.operation2"; } static const char *name() { return "test.operation2"; }
static constexpr uint32_t attributes_num = 2; static constexpr uint32_t attributes_num = 2;
static const char *attributes_name[attributes_num]; static const char *attributes_name[attributes_num];
static void verify(const std::vector<ir::OpResult> &inputs, static void Verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs, const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes) {} const ir::AttributeMap &attributes) {}
static void InferShape(phi::InferMetaContext *infer_meta) { static void InferShape(phi::InferMetaContext *infer_meta) {
......
...@@ -83,7 +83,7 @@ class Operation1 : public ir::Op<Operation1> { ...@@ -83,7 +83,7 @@ class Operation1 : public ir::Op<Operation1> {
static const char *name() { return "test.operation1"; } static const char *name() { return "test.operation1"; }
static constexpr uint32_t attributes_num = 2; static constexpr uint32_t attributes_num = 2;
static const char *attributes_name[attributes_num]; static const char *attributes_name[attributes_num];
static void verify(const std::vector<ir::OpResult> &inputs, static void Verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs, const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes) { const ir::AttributeMap &attributes) {
if (attributes.count("op1_attr1") == 0 || if (attributes.count("op1_attr1") == 0 ||
...@@ -95,7 +95,7 @@ class Operation1 : public ir::Op<Operation1> { ...@@ -95,7 +95,7 @@ class Operation1 : public ir::Op<Operation1> {
throw("Type of attribute: parameter_name is not right."); throw("Type of attribute: parameter_name is not right.");
} }
} }
static void build(const ir::Builder &builder, static void Build(const ir::Builder &builder,
ir::OperationArgument &argument) { // NOLINT ir::OperationArgument &argument) { // NOLINT
std::vector<ir::OpResult> inputs = {}; std::vector<ir::OpResult> inputs = {};
std::vector<ir::Type> output_types = { std::vector<ir::Type> output_types = {
...@@ -123,7 +123,7 @@ class Operation2 ...@@ -123,7 +123,7 @@ class Operation2
static const char *name() { return "test.operation2"; } static const char *name() { return "test.operation2"; }
static constexpr uint32_t attributes_num = 2; static constexpr uint32_t attributes_num = 2;
static const char *attributes_name[attributes_num]; static const char *attributes_name[attributes_num];
static void verify(const std::vector<ir::OpResult> &inputs, static void Verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs, const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes) { const ir::AttributeMap &attributes) {
if (attributes.count("op2_attr1") == 0 || if (attributes.count("op2_attr1") == 0 ||
......
...@@ -15,9 +15,9 @@ ...@@ -15,9 +15,9 @@
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "paddle/fluid/ir/dialect/pd_dialect.h" #include "paddle/fluid/ir/dialect/pd_dialect.h"
#include "paddle/fluid/ir/dialect/pd_interface.h"
#include "paddle/fluid/ir/dialect/pd_type.h" #include "paddle/fluid/ir/dialect/pd_type.h"
#include "paddle/fluid/ir/dialect/utils.h" #include "paddle/fluid/ir/dialect/utils.h"
#include "paddle/fluid/ir/interface/op_yaml_info.h"
#include "paddle/ir/core/block.h" #include "paddle/ir/core/block.h"
#include "paddle/ir/core/builtin_attribute.h" #include "paddle/ir/core/builtin_attribute.h"
#include "paddle/ir/core/builtin_dialect.h" #include "paddle/ir/core/builtin_dialect.h"
...@@ -28,6 +28,9 @@ ...@@ -28,6 +28,9 @@
#include "paddle/phi/core/meta_tensor.h" #include "paddle/phi/core/meta_tensor.h"
#include "paddle/phi/infermeta/binary.h" #include "paddle/phi/infermeta/binary.h"
#include "paddle/phi/kernels/elementwise_add_kernel.h" #include "paddle/phi/kernels/elementwise_add_kernel.h"
// NOTE(zhangbo9674): File pd_op.h is generated by op_gen.py, see details in
// paddle/fluid/ir/dialect/CMakeLists.txt.
#include "paddle/fluid/ir/dialect/pd_op.h"
class AddOp : public ir::Op<AddOp> { class AddOp : public ir::Op<AddOp> {
public: public:
...@@ -35,7 +38,7 @@ class AddOp : public ir::Op<AddOp> { ...@@ -35,7 +38,7 @@ class AddOp : public ir::Op<AddOp> {
static const char *name() { return "test.add"; } static const char *name() { return "test.add"; }
static constexpr const char **attributes_name = nullptr; static constexpr const char **attributes_name = nullptr;
static constexpr uint32_t attributes_num = 0; static constexpr uint32_t attributes_num = 0;
static void verify(const std::vector<ir::OpResult> &inputs, static void Verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs, const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes) { const ir::AttributeMap &attributes) {
if (inputs.size() != 2) { if (inputs.size() != 2) {
...@@ -192,8 +195,8 @@ TEST(program_test, program) { ...@@ -192,8 +195,8 @@ TEST(program_test, program) {
abs_argument.AddAttributes(abs_op_attribute.begin(), abs_op_attribute.end()); abs_argument.AddAttributes(abs_op_attribute.begin(), abs_op_attribute.end());
abs_argument.AddTypes(output_types.begin(), output_types.end()); abs_argument.AddTypes(output_types.begin(), output_types.end());
ir::Operation *abs_op = ir::Operation::create(std::move(abs_argument)); ir::Operation *abs_op = ir::Operation::create(std::move(abs_argument));
paddle::dialect::GetOpInfoInterface interface = paddle::dialect::OpYamlInfoInterface interface =
abs_op->dyn_cast<paddle::dialect::GetOpInfoInterface>(); abs_op->dyn_cast<paddle::dialect::OpYamlInfoInterface>();
EXPECT_EQ(std::get<0>(interface.GetOpInfo())[0].name == "x", true); EXPECT_EQ(std::get<0>(interface.GetOpInfo())[0].name == "x", true);
// (8) Def SetParameterOp(c, "c") // (8) Def SetParameterOp(c, "c")
...@@ -259,7 +262,11 @@ TEST(program_test, slice_combine_test) { ...@@ -259,7 +262,11 @@ TEST(program_test, slice_combine_test) {
// (5) Def b = Constant("b") // (5) Def b = Constant("b")
std::string op2_name = std::string(ir::ConstantOp::name()); std::string op2_name = std::string(ir::ConstantOp::name());
ir::OpInfo op2_info = ctx->GetRegisteredOpInfo(op2_name); ir::OpInfo op2_info = ctx->GetRegisteredOpInfo(op2_name);
ir::Operation *op2 = ir::Operation::create({}, {}, {fp32_dtype}, op2_info); ir::AttributeMap attr_map;
attr_map.insert(std::pair<std::string, ir::Attribute>(
"value", ir::FloatAttribute::get(ctx, 2.0)));
ir::Operation *op2 =
ir::Operation::create({}, attr_map, {fp32_dtype}, op2_info);
program.block()->push_back(op2); program.block()->push_back(op2);
// (6) Def combine_op = CombineOp("a", "b") // (6) Def combine_op = CombineOp("a", "b")
...@@ -288,3 +295,33 @@ TEST(program_test, slice_combine_test) { ...@@ -288,3 +295,33 @@ TEST(program_test, slice_combine_test) {
// (8) Traverse Program // (8) Traverse Program
EXPECT_EQ(program.block()->size() == 4, true); EXPECT_EQ(program.block()->size() == 4, true);
} }
TEST(program_test, builder) {
ir::IrContext *ctx = ir::IrContext::Instance();
ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();
ir::Program program(ctx);
ir::Builder builder = ir::Builder::AtBlockEnd(ctx, program.block());
paddle::dialect::FullOp full_op = builder.create<paddle::dialect::FullOp>(
std::vector<int64_t>{2, 2}, 1.5, phi::DataType::FLOAT32, phi::CPUPlace());
ir::Type full_op_output = full_op->GetResultByIndex(0).type();
EXPECT_EQ(program.block()->size() == 1, true);
EXPECT_EQ(program.block()->back(), full_op.operation());
EXPECT_EQ(full_op->num_operands() == 0, true);
EXPECT_EQ(full_op->num_results() == 1, true);
EXPECT_EQ(full_op->attributes().size() == 4, true);
EXPECT_EQ(
full_op_output.dyn_cast<paddle::dialect::DenseTensorType>().offset() == 0,
true);
for (auto dim : phi::vectorize(
full_op_output.dyn_cast<paddle::dialect::DenseTensorType>()
.dims())) {
EXPECT_EQ(dim == 2, true);
}
ir::ConstantOp constant = builder.create<ir::ConstantOp>(
ir::Int32_tAttribute::get(ctx, 2), ir::Int32Type::get(ctx));
EXPECT_EQ(program.block()->size() == 2, true);
EXPECT_EQ(constant.value().dyn_cast<ir::Int32_tAttribute>().data() == 2,
true);
}
...@@ -53,11 +53,11 @@ TEST(PaddleDialectTest, Translator) { ...@@ -53,11 +53,11 @@ TEST(PaddleDialectTest, Translator) {
ir::IrContext *ctx = ir::IrContext::Instance(); ir::IrContext *ctx = ir::IrContext::Instance();
ctx->GetOrRegisterDialect<PaddleDialect>(); ctx->GetOrRegisterDialect<PaddleDialect>();
ctx->GetOrRegisterDialect<ir::BuiltinDialect>(); ctx->GetOrRegisterDialect<ir::BuiltinDialect>();
auto program = paddle::TranslateLegacyProgramToProgram(p); // auto program = paddle::TranslateLegacyProgramToProgram(p);
size_t op_size = program->block()->size(); // size_t op_size = program->block()->size();
// ops.size() = op size in BlockDesc + get_parameter_op + combine op // // ops.size() = op size in BlockDesc + get_parameter_op + combine op
EXPECT_EQ(op_size, p.Block(0).OpSize() + program->parameters_num() + 21); // EXPECT_EQ(op_size, p.Block(0).OpSize() + program->parameters_num() + 21);
program->Print(std::cout); // program->Print(std::cout);
} }
...@@ -15,9 +15,9 @@ ...@@ -15,9 +15,9 @@
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "paddle/fluid/ir/dialect/pd_dialect.h" #include "paddle/fluid/ir/dialect/pd_dialect.h"
#include "paddle/fluid/ir/dialect/pd_interface.h"
#include "paddle/fluid/ir/dialect/pd_type.h" #include "paddle/fluid/ir/dialect/pd_type.h"
#include "paddle/fluid/ir/dialect/utils.h" #include "paddle/fluid/ir/dialect/utils.h"
#include "paddle/fluid/ir/interface/op_yaml_info.h"
#include "paddle/ir/core/builtin_dialect.h" #include "paddle/ir/core/builtin_dialect.h"
#include "paddle/ir/core/builtin_op.h" #include "paddle/ir/core/builtin_op.h"
#include "paddle/ir/core/builtin_type.h" #include "paddle/ir/core/builtin_type.h"
...@@ -35,7 +35,7 @@ class AddOp : public ir::Op<AddOp> { ...@@ -35,7 +35,7 @@ class AddOp : public ir::Op<AddOp> {
static const char *name() { return "test.add"; } static const char *name() { return "test.add"; }
static constexpr const char **attributes_name = nullptr; static constexpr const char **attributes_name = nullptr;
static constexpr uint32_t attributes_num = 0; static constexpr uint32_t attributes_num = 0;
static void verify(const std::vector<ir::OpResult> &inputs, static void Verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs, const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes) { const ir::AttributeMap &attributes) {
if (inputs.size() != 2) { if (inputs.size() != 2) {
...@@ -208,8 +208,8 @@ TEST(pass_manager_test, pass_manager) { ...@@ -208,8 +208,8 @@ TEST(pass_manager_test, pass_manager) {
abs_argument.AddAttributes(abs_op_attribute.begin(), abs_op_attribute.end()); abs_argument.AddAttributes(abs_op_attribute.begin(), abs_op_attribute.end());
abs_argument.AddTypes(output_types.begin(), output_types.end()); abs_argument.AddTypes(output_types.begin(), output_types.end());
ir::Operation *abs_op = ir::Operation::create(std::move(abs_argument)); ir::Operation *abs_op = ir::Operation::create(std::move(abs_argument));
paddle::dialect::GetOpInfoInterface interface = paddle::dialect::OpYamlInfoInterface interface =
abs_op->dyn_cast<paddle::dialect::GetOpInfoInterface>(); abs_op->dyn_cast<paddle::dialect::OpYamlInfoInterface>();
EXPECT_EQ(std::get<0>(interface.GetOpInfo())[0].name == "x", true); EXPECT_EQ(std::get<0>(interface.GetOpInfo())[0].name == "x", true);
// (8) Def SetParameterOp(c, "c") // (8) Def SetParameterOp(c, "c")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册