未验证 提交 5296e91a 编写于 作者: H hong 提交者: GitHub

[NewIR]Fix new ir unique and isclose op (#55213)

* add ir output check in OpTest

* add ir grad check in op test

* fix legacy name converter bug

* add more unittest

* fix

* fix warprnn op bug

* add whit list

* polish code

* polish code

* fix unique and close op bug

* fix bug

* update

* fix new ir unique is close bug

* remove useless code

* use new stringattr api

---------
Co-authored-by: Nkangguangli <kangguangli@hotmail.com>
Co-authored-by: Nzhangbo9674 <zhangbo54@baidu.com>
上级 a1475914
......@@ -39,6 +39,8 @@ phi::Scalar ScalarAttribute::data() {
return phi::Scalar(dyn_cast<ir::Int64Attribute>().data());
} else if (isa<ir::BoolAttribute>()) {
return phi::Scalar(dyn_cast<ir::BoolAttribute>().data());
} else if (isa<ir::StrAttribute>()) {
return phi::Scalar(dyn_cast<ir::StrAttribute>().AsString());
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"Unsupported ir attribute when casting it into "
......
......@@ -45,7 +45,8 @@ class ScalarAttribute : public ir::Attribute {
(val.type_id() == ir::FloatAttribute::type_id()) ||
(val.type_id() == ir::DoubleAttribute::type_id()) ||
(val.type_id() == ir::Int32Attribute::type_id()) ||
(val.type_id() == ir::Int64Attribute::type_id());
(val.type_id() == ir::Int64Attribute::type_id()) ||
(val.type_id() == ir::StrAttribute::type_id());
}
phi::Scalar data();
......
......@@ -160,10 +160,10 @@ void BuildPhiContext(ir::Operation* op,
tensor_attr_type));
}
} else if (tensor_attr_type == "paddle::dialect::ScalarAttribute") {
phi::Attribute r1 = phi::TensorRef(
phi::Attribute attr = phi::TensorRef(
&(inner_scope->FindVar(in_var_name)->Get<phi::DenseTensor>()));
ctx->EmplaceBackAttr(r1);
ctx->EmplaceBackAttr(attr);
} else {
PADDLE_THROW(phi::errors::Unimplemented("attr type not support [%s] ",
tensor_attr_type));
......
......@@ -62,7 +62,7 @@ class AttributeVisitor {
return ir::DoubleAttribute::get(ctx, d);
}
virtual ir::Attribute operator()(std::string str) {
virtual ir::Attribute operator()(const std::string& str) {
VLOG(10) << "translating string";
return ir::StrAttribute::get(ctx, str);
}
......
......@@ -190,6 +190,11 @@ inline ir::Operation* InsertFullOperationForAttributeInput(ir::IrContext* ctx,
} else if (attr.isa<ir::BoolAttribute>()) {
data = static_cast<float>(attr.dyn_cast<ir::BoolAttribute>().data());
dtype = phi::DataType::BOOL;
} else if (attr.isa<dialect::ScalarAttribute>()) {
// TODO(phlrain) : need update here, downcast from double to float
data = static_cast<float>(
attr.dyn_cast<dialect::ScalarAttribute>().data().to<double>());
dtype = phi::DataType::FLOAT64;
}
ir::Builder builder(ctx, program->block());
dialect::FullOp full_op = builder.Build<dialect::FullOp>(
......
......@@ -94,6 +94,7 @@ test_i0_op
test_i1e_op
test_i1_op
test_index_add_op
test_isclose_op
test_index_sample_op
test_instance_norm_op_v2
test_instance_norm_op_v2_new_ir
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册