未验证 提交 d658940a 编写于 作者: Z zhangbo9674 提交者: GitHub

[IR] Support mutable attribute as input for paddle dialect OP build method (#54563)

* support mutable attr as input for build

* add ut

* solve conflict
上级 a96c6dc7
...@@ -991,6 +991,9 @@ def GenBuildOutputs( ...@@ -991,6 +991,9 @@ def GenBuildOutputs(
}} }}
""" """
CREATE_INTARRAY_MUTABLE_ATTRIBUE_TEMPLATE = """ std::vector<int64_t> {name} = {name}_.owner()->dyn_cast<paddle::dialect::FullIntArrayOp>().operation()->attributes().at("value").dyn_cast<paddle::dialect::IntArrayAttribute>().data().GetData(); (void){name};\n"""
CREATE_SCALAR_MUTABLE_ATTRIBUE_TEMPLATE = """ {dtype} {name} = {name}_.owner()->dyn_cast<paddle::dialect::FullOp>().operation()->attributes().at("value").dyn_cast<paddle::dialect::ScalarAttribute>().data().to<{dtype}>(); (void){name};\n"""
CREATE_OUTPUT_METATENSOR_TEMPLATE = """ phi::DenseTensor dense_{name}; CREATE_OUTPUT_METATENSOR_TEMPLATE = """ phi::DenseTensor dense_{name};
phi::MetaTensor meta_{name}(&dense_{name}); phi::MetaTensor meta_{name}(&dense_{name});
""" """
...@@ -1017,6 +1020,32 @@ def GenBuildOutputs( ...@@ -1017,6 +1020,32 @@ def GenBuildOutputs(
name=op_input_name_list[idx] name=op_input_name_list[idx]
) )
# Prepare mutable attributes: when a mutable attribute is passed as an input
# OpResult, unwrap the constant value from its producing Full*Op so it can be
# fed to the InferMeta call.
if mutable_attr_is_input:
    for idx in range(len(op_mutable_attribute_name_list)):
        attr_dtype = op_mutable_attribute_type_list[idx]
        # int_array: read the vector<int64_t> out of FullIntArrayOp.
        if attr_dtype[0] == "paddle::dialect::IntArrayAttribute":
            build_output_str += (
                CREATE_INTARRAY_MUTABLE_ATTRIBUE_TEMPLATE.format(
                    name=op_mutable_attribute_name_list[idx]
                )
            )
        # scalar: read the typed scalar value out of FullOp.
        elif attr_dtype[0] == "paddle::dialect::ScalarAttribute":
            build_output_str += (
                CREATE_SCALAR_MUTABLE_ATTRIBUE_TEMPLATE.format(
                    name=op_mutable_attribute_name_list[idx],
                    dtype=attr_dtype[1],
                )
            )
        # string: no unwrapping code is needed.
        elif attr_dtype[0] == "ir::StrAttribute":
            build_output_str += ""
        else:
            # BUG FIX: `assert "<msg>"` asserts a truthy string literal and can
            # never fail; use `assert False, msg` so unsupported mutable
            # attribute types are actually rejected. (Also fixes the
            # "attribtue" typo in the message.)
            assert False, "mutable attribute type is not right."
    build_output_str += "\n"
# Prepare inputs_meta_tensor & attributes for infer meta # Prepare inputs_meta_tensor & attributes for infer meta
infer_meta_args = [] infer_meta_args = []
for idx in range(len(op_infer_meta_map['param'])): for idx in range(len(op_infer_meta_map['param'])):
...@@ -1181,7 +1210,7 @@ def GenBuild( ...@@ -1181,7 +1210,7 @@ def GenBuild(
op_output_type_list, op_output_type_list,
op_output_size_list, op_output_size_list,
op_infer_meta_map, op_infer_meta_map,
False, muta_attr_is_input,
) )
build_func = OP_BUILD_TEMPLATE.format( build_func = OP_BUILD_TEMPLATE.format(
...@@ -1346,16 +1375,7 @@ def OpGenerator( ...@@ -1346,16 +1375,7 @@ def OpGenerator(
op_infer_meta_map, op_infer_meta_map,
muta_attr_is_input=False, muta_attr_is_input=False,
) )
op_infer_meta_args = op_infer_meta_map['param'] if len(op_mutable_attribute_name_list) > 0:
if (len(op_mutable_attribute_name_list) > 0) and (
len(
list(
set(op_infer_meta_args)
& set(op_mutable_attribute_name_list)
)
)
== 0
):
( (
build_args_with_muta_attr_is_input_for_declare, build_args_with_muta_attr_is_input_for_declare,
build_func_with_muta_attr_is_input, build_func_with_muta_attr_is_input,
......
...@@ -109,3 +109,72 @@ TEST(program_test, program) { ...@@ -109,3 +109,72 @@ TEST(program_test, program) {
EXPECT_EQ(res2, true); EXPECT_EQ(res2, true);
EXPECT_EQ(res3, true); EXPECT_EQ(res3, true);
} }
TEST(program_test, mutable_attribute) {
  // Set up the IR environment: context, dialect registration, program, and a
  // builder that appends into the program's top-level block.
  ir::IrContext* ctx = ir::IrContext::Instance();
  ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();
  ir::Program program(ctx);
  ir::Builder builder = ir::Builder(ctx, program.block());
  ir::Block* block = program.block();

  // Mutable "shape" attribute supplied as an input via a FullIntArrayOp.
  auto shape_op = builder.Build<paddle::dialect::FullIntArrayOp>(
      std::vector<int64_t>{2, 2}, phi::DataType::INT64, phi::CPUPlace());
  ir::OpResult shape_value = shape_op->result(0);

  // Mutable scalar attribute "min" supplied via a FullOp.
  auto min_op = builder.Build<paddle::dialect::FullOp>(
      std::vector<int64_t>{1}, 0.0, phi::DataType::FLOAT32, phi::CPUPlace());
  ir::OpResult min_value = min_op->result(0);

  // Mutable scalar attribute "max" supplied via a FullOp.
  auto max_op = builder.Build<paddle::dialect::FullOp>(
      std::vector<int64_t>{1}, 1.0, phi::DataType::FLOAT32, phi::CPUPlace());
  ir::OpResult max_value = max_op->result(0);

  // Exercises the generated build overload:
  //   static void Build(ir::Builder &builder, ir::OperationArgument &argument,
  //                     ir::OpResult shape_, ir::OpResult min_,
  //                     ir::OpResult max_, phi::DataType dtype, int seed,
  //                     phi::Place place={});
  auto uniform1 = builder.Build<paddle::dialect::UniformOp>(
      shape_value, min_value, max_value, phi::DataType::FLOAT32, 2,
      phi::CPUPlace());
  EXPECT_EQ(uniform1->result(0).type().isa<paddle::dialect::DenseTensorType>(),
            true);
  EXPECT_EQ(block->size(), 4u);

  // Second uniform op built from the same mutable-attribute inputs.
  auto uniform2 = builder.Build<paddle::dialect::UniformOp>(
      shape_value, min_value, max_value, phi::DataType::FLOAT32, 2,
      phi::CPUPlace());
  EXPECT_EQ(uniform2->result(0).type().isa<paddle::dialect::DenseTensorType>(),
            true);
  EXPECT_EQ(block->size(), 5u);

  // C = A + B, consuming both uniform outputs.
  auto add = builder.Build<paddle::dialect::AddOp>(uniform1->result(0),
                                                   uniform2->result(0));
  EXPECT_EQ(add->result(0).type().isa<paddle::dialect::DenseTensorType>(),
            true);
  EXPECT_EQ(block->size(), 6u);

  // Execute the program and check the output against the expected values for
  // this fixed seed.
  paddle::framework::Scope scope;
  PhiKernelAdaptor phi_kernel_adaptor(&scope);
  phi_kernel_adaptor.run(&program);
  auto out_tensor =
      scope.Var(phi_kernel_adaptor.out_name)->Get<phi::DenseTensor>();
  bool ok0 = simple_cmp(out_tensor.data<float>()[0], 1.80721);
  bool ok1 = simple_cmp(out_tensor.data<float>()[1], 1.70047);
  bool ok2 = simple_cmp(out_tensor.data<float>()[2], 1.56764);
  bool ok3 = simple_cmp(out_tensor.data<float>()[3], 1.85063);
  std::cerr << out_tensor.data<float>()[0] << "\t"
            << out_tensor.data<float>()[1] << "\t"
            << out_tensor.data<float>()[2] << "\t"
            << out_tensor.data<float>()[3] << std::endl;
  EXPECT_EQ(ok0, true);
  EXPECT_EQ(ok1, true);
  EXPECT_EQ(ok2, true);
  EXPECT_EQ(ok3, true);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册