未验证 提交 e73ddd6c 编写于 作者: Z zhangbo9674 提交者: GitHub

refine attribute name (#54516)

上级 b04689e9
......@@ -251,8 +251,8 @@ def to_phi_and_fluid_op_name(op_item):
scalar_type_maps = {
'int': 'ir::Int32_tAttribute',
'int64_t': 'ir::Int64_tAttribute',
'int': 'ir::Int32Attribute',
'int64_t': 'ir::Int64Attribute',
'float': 'ir::FloatAttribute',
'dobule': 'ir::DoubleAttribute',
'bool': 'ir::BoolAttribute',
......@@ -309,17 +309,17 @@ class OpInfoParser:
self.attr_types_map = {
'IntArray': ['paddle::dialect::IntArrayAttribute', 'IntArray'],
'Scalar': ['paddle::dialect::ScalarAttribute', 'Scalar'],
'Scalar(int)': ['ir::Int32_tAttribute', 'int'],
'Scalar(int64_t)': ['ir::Int64_tAttribute', 'int64_t'],
'Scalar(int)': ['ir::Int32Attribute', 'int'],
'Scalar(int64_t)': ['ir::Int64Attribute', 'int64_t'],
'Scalar(float)': ['ir::FloatAttribute', 'float'],
'Scalar(dobule)': ['ir::DoubleAttribute', 'dobule'],
'Scalar[]': [
'ir::ArrayAttribute<paddle::dialect::ScalarAttribute>',
'const std::vector<Scalar>&',
],
'int': ['ir::Int32_tAttribute', 'int'],
'int32_t': ['ir::Int32_tAttribute', 'int32_t'],
'int64_t': ['ir::Int64_tAttribute', 'int64_t'],
'int': ['ir::Int32Attribute', 'int'],
'int32_t': ['ir::Int32Attribute', 'int32_t'],
'int64_t': ['ir::Int64Attribute', 'int64_t'],
'long': ['ir::LongAttribute', 'long'],
'size_t': ['ir::Size_tAttribute', 'size_t'],
'float': ['ir::FloatAttribute', 'float'],
......@@ -345,11 +345,11 @@ class OpInfoParser:
],
'DataType': ['paddle::dialect::DataTypeAttribute', 'DataType'],
'int64_t[]': [
'ir::ArrayAttribute<ir::Int64_tAttribute>',
'ir::ArrayAttribute<ir::Int64Attribute>',
'const std::vector<int64_t>&',
],
'int[]': [
'ir::ArrayAttribute<ir::Int32_tAttribute>',
'ir::ArrayAttribute<ir::Int32Attribute>',
'const std::vector<int>&',
],
}
......
......@@ -31,10 +31,10 @@ phi::Scalar ScalarAttribute::data() {
return phi::Scalar(dyn_cast<ir::FloatAttribute>().data());
} else if (isa<ir::DoubleAttribute>()) {
return phi::Scalar(dyn_cast<ir::DoubleAttribute>().data());
} else if (isa<ir::Int32_tAttribute>()) {
return phi::Scalar(dyn_cast<ir::Int32_tAttribute>().data());
} else if (isa<ir::Int64_tAttribute>()) {
return phi::Scalar(dyn_cast<ir::Int64_tAttribute>().data());
} else if (isa<ir::Int32Attribute>()) {
return phi::Scalar(dyn_cast<ir::Int32Attribute>().data());
} else if (isa<ir::Int64Attribute>()) {
return phi::Scalar(dyn_cast<ir::Int64Attribute>().data());
} else if (isa<ir::BoolAttribute>()) {
return phi::Scalar(dyn_cast<ir::BoolAttribute>().data());
} else {
......
......@@ -44,8 +44,8 @@ class ScalarAttribute : public ir::Attribute {
return (val.type_id() == ir::BoolAttribute::type_id()) ||
(val.type_id() == ir::FloatAttribute::type_id()) ||
(val.type_id() == ir::DoubleAttribute::type_id()) ||
(val.type_id() == ir::Int32_tAttribute::type_id()) ||
(val.type_id() == ir::Int64_tAttribute::type_id());
(val.type_id() == ir::Int32Attribute::type_id()) ||
(val.type_id() == ir::Int64Attribute::type_id());
}
phi::Scalar data();
......
......@@ -83,9 +83,9 @@ static inline ir::Attribute TransToIrAttribute(phi::Scalar scalar,
case phi::DataType::FLOAT64:
return ir::DoubleAttribute::get(ctx, scalar.to<double>());
case phi::DataType::INT32:
return ir::Int32_tAttribute::get(ctx, scalar.to<int32_t>());
return ir::Int32Attribute::get(ctx, scalar.to<int32_t>());
case phi::DataType::INT64:
return ir::Int64_tAttribute::get(ctx, scalar.to<int64_t>());
return ir::Int64Attribute::get(ctx, scalar.to<int64_t>());
case phi::DataType::BOOL:
return ir::BoolAttribute::get(ctx, scalar.to<bool>());
default:
......
......@@ -38,7 +38,7 @@ class AttributeVisitor {
public:
virtual ir::Attribute operator()(int i) {
VLOG(10) << "translating int";
return ir::Int32_tAttribute::get(ctx, i);
return ir::Int32Attribute::get(ctx, i);
}
virtual ir::Attribute operator()(float f) {
......@@ -91,7 +91,7 @@ class AttributeVisitor {
std::vector<ir::Attribute> attrs;
attrs.reserve(is.size());
for (const auto& v : is) {
attrs.push_back(ir::Int32_tAttribute::get(ctx, v));
attrs.push_back(ir::Int32Attribute::get(ctx, v));
}
return ir::ArrayAttribute::get(ctx, attrs);
}
......@@ -111,7 +111,7 @@ class AttributeVisitor {
std::vector<ir::Attribute> attrs;
attrs.reserve(i64s.size());
for (const auto& v : i64s) {
attrs.push_back(ir::Int64_tAttribute::get(ctx, v));
attrs.push_back(ir::Int64Attribute::get(ctx, v));
}
return ir::ArrayAttribute::get(ctx, attrs);
}
......
......@@ -131,7 +131,7 @@ inline ir::Operation* InsertSliceOperationForTarget(
std::string slice_op_name(ir::SliceOp::name());
ir::OpInfo op_info = ctx->GetRegisteredOpInfo(slice_op_name);
std::unordered_map<std::string, ir::Attribute> op_attribute_map = {
{"index", ir::Int32_tAttribute::get(ctx, defining_info.idx_in_vector)},
{"index", ir::Int32Attribute::get(ctx, defining_info.idx_in_vector)},
};
ir::VectorType src_vec_type =
defining_info.value.type().dyn_cast<ir::VectorType>();
......@@ -179,11 +179,11 @@ inline ir::Operation* InsertFullOperationForAttributeInput(ir::IrContext* ctx,
} else if (attr.isa<ir::DoubleAttribute>()) {
data = static_cast<float>(attr.dyn_cast<ir::DoubleAttribute>().data());
dtype = phi::DataType::FLOAT64;
} else if (attr.isa<ir::Int32_tAttribute>()) {
data = static_cast<float>(attr.dyn_cast<ir::Int32_tAttribute>().data());
} else if (attr.isa<ir::Int32Attribute>()) {
data = static_cast<float>(attr.dyn_cast<ir::Int32Attribute>().data());
dtype = phi::DataType::INT32;
} else if (attr.isa<ir::Int64_tAttribute>()) {
data = static_cast<float>(attr.dyn_cast<ir::Int64_tAttribute>().data());
} else if (attr.isa<ir::Int64Attribute>()) {
data = static_cast<float>(attr.dyn_cast<ir::Int64Attribute>().data());
dtype = phi::DataType::INT64;
} else if (attr.isa<ir::BoolAttribute>()) {
data = static_cast<float>(attr.dyn_cast<ir::BoolAttribute>().data());
......
......@@ -25,9 +25,9 @@ float FloatAttribute::data() const { return storage()->GetAsKey(); }
double DoubleAttribute::data() const { return storage()->GetAsKey(); }
int32_t Int32_tAttribute::data() const { return storage()->GetAsKey(); }
int32_t Int32Attribute::data() const { return storage()->GetAsKey(); }
int64_t Int64_tAttribute::data() const { return storage()->GetAsKey(); }
int64_t Int64Attribute::data() const { return storage()->GetAsKey(); }
std::vector<Attribute> ArrayAttribute::data() const {
return storage()->GetAsKey();
......
......@@ -61,20 +61,20 @@ class DoubleAttribute : public Attribute {
double data() const;
};
class Int32_tAttribute : public Attribute {
class Int32Attribute : public Attribute {
public:
using Attribute::Attribute;
DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(Int32_tAttribute, Int32_tAttributeStorage);
DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(Int32Attribute, Int32AttributeStorage);
int32_t data() const;
};
class Int64_tAttribute : public Attribute {
class Int64Attribute : public Attribute {
public:
using Attribute::Attribute;
DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(Int64_tAttribute, Int64_tAttributeStorage);
DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(Int64Attribute, Int64AttributeStorage);
int64_t data() const;
};
......
......@@ -81,8 +81,8 @@ struct StrAttributeStorage : public AttributeStorage {
DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(BoolAttributeStorage, bool);
DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(FloatAttributeStorage, float);
DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(DoubleAttributeStorage, double);
DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(Int32_tAttributeStorage, int32_t);
DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(Int64_tAttributeStorage, int64_t);
DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(Int32AttributeStorage, int32_t);
DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(Int64AttributeStorage, int64_t);
DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(PointerAttributeStorage, void *);
struct ArrayAttributeStorage : public AttributeStorage {
......
......@@ -41,8 +41,8 @@ void BuiltinDialect::initialize() {
FloatAttribute,
DoubleAttribute,
PointerAttribute,
Int32_tAttribute,
Int64_tAttribute,
Int32Attribute,
Int64Attribute,
ArrayAttribute>();
RegisterOps<ModuleOp,
......
......@@ -161,9 +161,9 @@ void SliceOp::Verify(const std::vector<ir::OpResult> &inputs,
IR_ENFORCE(attributes.count("index") != 0,
"The attributes must contains index.");
const ir::Attribute &attr = attributes.at("index");
IR_ENFORCE(attr.isa<ir::Int32_tAttribute>(),
IR_ENFORCE(attr.isa<ir::Int32Attribute>(),
"The attribute index must be INT32.");
auto index = attr.dyn_cast<ir::Int32_tAttribute>().data();
auto index = attr.dyn_cast<ir::Int32Attribute>().data();
// index >= 0 and < inputs[0].size()
IR_ENFORCE(
......
......@@ -83,9 +83,9 @@ class BasicIRPrinter {
os << f.data();
} else if (auto d = attr.dyn_cast<DoubleAttribute>()) {
os << d.data();
} else if (auto i = attr.dyn_cast<Int32_tAttribute>()) {
} else if (auto i = attr.dyn_cast<Int32Attribute>()) {
os << i.data();
} else if (auto i = attr.dyn_cast<Int64_tAttribute>()) {
} else if (auto i = attr.dyn_cast<Int64Attribute>()) {
os << i.data();
} else if (auto arr = attr.dyn_cast<ArrayAttribute>()) {
const auto& vec = arr.data();
......
......@@ -238,7 +238,7 @@ TEST(op_test, module_op_death) {
// (3) Test uses for op.
std::vector<ir::OpResult> inputs{ir::OpResult()};
ir::AttributeMap attrs{{"program", ir::Int32_tAttribute::get(ctx, 1)}};
ir::AttributeMap attrs{{"program", ir::Int32Attribute::get(ctx, 1)}};
std::vector<ir::Type> output_types = {ir::Float32Type::get(ctx)};
EXPECT_THROW(ir::Operation::Create(inputs, {}, {}, op_info),
......
......@@ -283,7 +283,7 @@ TEST(program_test, slice_combine_test) {
// (7) Def slice_op = SliceOp(combine_op, 0)
std::string slice_op_name = std::string(ir::SliceOp::name());
ir::OpInfo slice_op_info = ctx->GetRegisteredOpInfo(slice_op_name);
ir::Attribute index_attr = ir::Int32_tAttribute::get(ctx, 0);
ir::Attribute index_attr = ir::Int32Attribute::get(ctx, 0);
ir::Operation *slice_op =
ir::Operation::Create({combine_op->GetResultByIndex(0)},
{{"index", index_attr}},
......@@ -319,8 +319,7 @@ TEST(program_test, builder) {
}
ir::ConstantOp constant = builder.Build<ir::ConstantOp>(
ir::Int32_tAttribute::get(ctx, 2), ir::Int32Type::get(ctx));
ir::Int32Attribute::get(ctx, 2), ir::Int32Type::get(ctx));
EXPECT_EQ(program.block()->size() == 2, true);
EXPECT_EQ(constant.value().dyn_cast<ir::Int32_tAttribute>().data() == 2,
true);
EXPECT_EQ(constant.value().dyn_cast<ir::Int32Attribute>().data() == 2, true);
}
......@@ -28,7 +28,7 @@ TEST(ir_op_info_test, op_op_info_test) {
ir::Block* block = program.block();
ir::Builder builder(context, block);
builder.Build<ir::ConstantOp>(ir::Int32_tAttribute::get(context, 5),
builder.Build<ir::ConstantOp>(ir::Int32Attribute::get(context, 5),
ir::Int32Type::get(context));
ir::Operation* op = block->back();
......
......@@ -164,9 +164,8 @@ void build_context(ir::Operation* op,
} else if (type_name == "paddle::dialect::DataTypeAttribute") {
ctx->EmplaceBackAttr(
attr_map[t].dyn_cast<paddle::dialect::DataTypeAttribute>().data());
} else if (type_name == "ir::Int32_tAttribute") {
ctx->EmplaceBackAttr(
attr_map[t].dyn_cast<ir::Int32_tAttribute>().data());
} else if (type_name == "ir::Int32Attribute") {
ctx->EmplaceBackAttr(attr_map[t].dyn_cast<ir::Int32Attribute>().data());
} else if (type_name == "paddle::dialect::PlaceAttribute") {
ctx->EmplaceBackAttr(
attr_map[t].dyn_cast<paddle::dialect::PlaceAttribute>().data());
......
......@@ -47,9 +47,9 @@ TEST(ScalarTest, test_classof) {
ir::Attribute double_scalar = ir::DoubleAttribute::get(ctx, 1.0);
EXPECT_TRUE(double_scalar.isa<ScalarAttribute>());
ir::Attribute int32_scalar = ir::Int32_tAttribute::get(ctx, 1);
ir::Attribute int32_scalar = ir::Int32Attribute::get(ctx, 1);
EXPECT_TRUE(int32_scalar.isa<ScalarAttribute>());
ir::Attribute int64_scalar = ir::Int64_tAttribute::get(ctx, 1l);
ir::Attribute int64_scalar = ir::Int64Attribute::get(ctx, 1l);
EXPECT_TRUE(int64_scalar.isa<ScalarAttribute>());
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册