From 5296e91a3e860b8790cc957d3ebb2da39aeb97c7 Mon Sep 17 00:00:00 2001 From: hong <43953930+phlrain@users.noreply.github.com> Date: Tue, 1 Aug 2023 11:56:10 +0800 Subject: [PATCH] [NewIR]Fix new ir unique and isclose op (#55213) * add ir output check in OpTest * add ir grad check in op test * fix legacy name converter bug * add more unittest * fix * fix warprnn op bug * add whit list * polish code * polish code * fix unique and close op bug * fix bug * update * fix new ir unique is close bug * remove useless code * use new stringattr api --------- Co-authored-by: kangguangli Co-authored-by: zhangbo9674 --- paddle/fluid/ir/dialect/pd_attribute.cc | 2 ++ paddle/fluid/ir/dialect/pd_attribute.h | 3 ++- paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.h | 4 ++-- paddle/fluid/ir_adaptor/translator/attribute_translator.cc | 2 +- paddle/fluid/ir_adaptor/translator/op_translator.cc | 5 +++++ test/white_list/new_ir_op_test_white_list | 1 + 6 files changed, 13 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/ir/dialect/pd_attribute.cc b/paddle/fluid/ir/dialect/pd_attribute.cc index 687e836dc70..8ccad0cf981 100644 --- a/paddle/fluid/ir/dialect/pd_attribute.cc +++ b/paddle/fluid/ir/dialect/pd_attribute.cc @@ -39,6 +39,8 @@ phi::Scalar ScalarAttribute::data() { return phi::Scalar(dyn_cast().data()); } else if (isa()) { return phi::Scalar(dyn_cast().data()); + } else if (isa()) { + return phi::Scalar(dyn_cast().AsString()); } else { PADDLE_THROW(phi::errors::Unimplemented( "Unsupported ir attribute when casting it into " diff --git a/paddle/fluid/ir/dialect/pd_attribute.h b/paddle/fluid/ir/dialect/pd_attribute.h index 5af73b2c0f4..05514705efe 100644 --- a/paddle/fluid/ir/dialect/pd_attribute.h +++ b/paddle/fluid/ir/dialect/pd_attribute.h @@ -45,7 +45,8 @@ class ScalarAttribute : public ir::Attribute { (val.type_id() == ir::FloatAttribute::type_id()) || (val.type_id() == ir::DoubleAttribute::type_id()) || (val.type_id() == ir::Int32Attribute::type_id()) || - (val.type_id() == ir::Int64Attribute::type_id()); + (val.type_id() == ir::Int64Attribute::type_id()) || + (val.type_id() == ir::StrAttribute::type_id()); } phi::Scalar data(); diff --git a/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.h b/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.h index f3021ad4765..91997dd341c 100644 --- a/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.h +++ b/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.h @@ -160,10 +160,10 @@ void BuildPhiContext(ir::Operation* op, tensor_attr_type)); } } else if (tensor_attr_type == "paddle::dialect::ScalarAttribute") { - phi::Attribute r1 = phi::TensorRef( + phi::Attribute attr = phi::TensorRef( &(inner_scope->FindVar(in_var_name)->Get())); - ctx->EmplaceBackAttr(r1); + ctx->EmplaceBackAttr(attr); } else { PADDLE_THROW(phi::errors::Unimplemented("attr type not support [%s] ", tensor_attr_type)); diff --git a/paddle/fluid/ir_adaptor/translator/attribute_translator.cc b/paddle/fluid/ir_adaptor/translator/attribute_translator.cc index 2cd001cbcaf..49845754b50 100644 --- a/paddle/fluid/ir_adaptor/translator/attribute_translator.cc +++ b/paddle/fluid/ir_adaptor/translator/attribute_translator.cc @@ -62,7 +62,7 @@ class AttributeVisitor { return ir::DoubleAttribute::get(ctx, d); } - virtual ir::Attribute operator()(std::string str) { + virtual ir::Attribute operator()(const std::string& str) { VLOG(10) << "translating string"; return ir::StrAttribute::get(ctx, str); } diff --git a/paddle/fluid/ir_adaptor/translator/op_translator.cc b/paddle/fluid/ir_adaptor/translator/op_translator.cc index f454811b08f..aa650d60ac2 100644 --- a/paddle/fluid/ir_adaptor/translator/op_translator.cc +++ b/paddle/fluid/ir_adaptor/translator/op_translator.cc @@ -190,6 +190,11 @@ inline ir::Operation* InsertFullOperationForAttributeInput(ir::IrContext* ctx, } else if (attr.isa()) { data = static_cast(attr.dyn_cast().data()); dtype = phi::DataType::BOOL; + } else if (attr.isa()) { + // TODO(phlrain) : need update here, downcast from double to float + data = static_cast( + attr.dyn_cast().data().to()); + dtype = phi::DataType::FLOAT64; } ir::Builder builder(ctx, program->block()); dialect::FullOp full_op = builder.Build( diff --git a/test/white_list/new_ir_op_test_white_list b/test/white_list/new_ir_op_test_white_list index e8cd359de3a..7bd1c73c485 100644 --- a/test/white_list/new_ir_op_test_white_list +++ b/test/white_list/new_ir_op_test_white_list @@ -94,6 +94,7 @@ test_i0_op test_i1e_op test_i1_op test_index_add_op +test_isclose_op test_index_sample_op test_instance_norm_op_v2 test_instance_norm_op_v2_new_ir -- GitLab