diff --git a/paddle/fluid/ir/dialect/pd_attribute.cc b/paddle/fluid/ir/dialect/pd_attribute.cc index 687e836dc70f30753b36a5d152b99ae23fa553de..8ccad0cf98185cfa13f20cf1e1d36aab6b670651 100644 --- a/paddle/fluid/ir/dialect/pd_attribute.cc +++ b/paddle/fluid/ir/dialect/pd_attribute.cc @@ -39,6 +39,8 @@ phi::Scalar ScalarAttribute::data() { return phi::Scalar(dyn_cast().data()); } else if (isa()) { return phi::Scalar(dyn_cast().data()); + } else if (isa()) { + return phi::Scalar(dyn_cast().AsString()); } else { PADDLE_THROW(phi::errors::Unimplemented( "Unsupported ir attribute when casting it into " diff --git a/paddle/fluid/ir/dialect/pd_attribute.h b/paddle/fluid/ir/dialect/pd_attribute.h index 5af73b2c0f48b5e78d762ebad0ec69f630e69735..05514705efe14cfc9cb0cb1e143465202c99ddb1 100644 --- a/paddle/fluid/ir/dialect/pd_attribute.h +++ b/paddle/fluid/ir/dialect/pd_attribute.h @@ -45,7 +45,8 @@ class ScalarAttribute : public ir::Attribute { (val.type_id() == ir::FloatAttribute::type_id()) || (val.type_id() == ir::DoubleAttribute::type_id()) || (val.type_id() == ir::Int32Attribute::type_id()) || - (val.type_id() == ir::Int64Attribute::type_id()); + (val.type_id() == ir::Int64Attribute::type_id()) || + (val.type_id() == ir::StrAttribute::type_id()); } phi::Scalar data(); diff --git a/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.h b/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.h index f3021ad47653213d17582c3e96ff4440e8a00135..91997dd341c1d5c0c00188a6b631d6f5286fbe54 100644 --- a/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.h +++ b/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.h @@ -160,10 +160,10 @@ void BuildPhiContext(ir::Operation* op, tensor_attr_type)); } } else if (tensor_attr_type == "paddle::dialect::ScalarAttribute") { - phi::Attribute r1 = phi::TensorRef( + phi::Attribute attr = phi::TensorRef( &(inner_scope->FindVar(in_var_name)->Get())); - ctx->EmplaceBackAttr(r1); + ctx->EmplaceBackAttr(attr); } else { PADDLE_THROW(phi::errors::Unimplemented("attr type not support [%s] ", tensor_attr_type)); diff --git a/paddle/fluid/ir_adaptor/translator/attribute_translator.cc b/paddle/fluid/ir_adaptor/translator/attribute_translator.cc index 2cd001cbcafa4ea0b27c01dcaae947effd9b8bd2..49845754b50278b29b4f98c1188e210e9180e93e 100644 --- a/paddle/fluid/ir_adaptor/translator/attribute_translator.cc +++ b/paddle/fluid/ir_adaptor/translator/attribute_translator.cc @@ -62,7 +62,7 @@ class AttributeVisitor { return ir::DoubleAttribute::get(ctx, d); } - virtual ir::Attribute operator()(std::string str) { + virtual ir::Attribute operator()(const std::string& str) { VLOG(10) << "translating string"; return ir::StrAttribute::get(ctx, str); } diff --git a/paddle/fluid/ir_adaptor/translator/op_translator.cc b/paddle/fluid/ir_adaptor/translator/op_translator.cc index f454811b08ff39573acfd56072b4a0073793c159..aa650d60ac20af390a2d1ce7200c9dfd5b027527 100644 --- a/paddle/fluid/ir_adaptor/translator/op_translator.cc +++ b/paddle/fluid/ir_adaptor/translator/op_translator.cc @@ -190,6 +190,11 @@ inline ir::Operation* InsertFullOperationForAttributeInput(ir::IrContext* ctx, } else if (attr.isa()) { data = static_cast(attr.dyn_cast().data()); dtype = phi::DataType::BOOL; + } else if (attr.isa()) { + // TODO(phlrain) : need update here, downcast from double to float + data = static_cast( + attr.dyn_cast().data().to()); + dtype = phi::DataType::FLOAT64; } ir::Builder builder(ctx, program->block()); dialect::FullOp full_op = builder.Build( diff --git a/test/white_list/new_ir_op_test_white_list b/test/white_list/new_ir_op_test_white_list index e8cd359de3afb7108f2f49d5769a7ebe20d4c420..7bd1c73c485912bd4a2437c6be3c735d399024e9 100644 --- a/test/white_list/new_ir_op_test_white_list +++ b/test/white_list/new_ir_op_test_white_list @@ -94,6 +94,7 @@ test_i0_op test_i1e_op test_i1_op test_index_add_op +test_isclose_op test_index_sample_op test_instance_norm_op_v2 test_instance_norm_op_v2_new_ir