未验证 提交 e73ddd6c 编写于 作者: Z zhangbo9674 提交者: GitHub

refine attribute name (#54516)

上级 b04689e9
...@@ -251,8 +251,8 @@ def to_phi_and_fluid_op_name(op_item): ...@@ -251,8 +251,8 @@ def to_phi_and_fluid_op_name(op_item):
scalar_type_maps = { scalar_type_maps = {
'int': 'ir::Int32_tAttribute', 'int': 'ir::Int32Attribute',
'int64_t': 'ir::Int64_tAttribute', 'int64_t': 'ir::Int64Attribute',
'float': 'ir::FloatAttribute', 'float': 'ir::FloatAttribute',
'dobule': 'ir::DoubleAttribute', 'dobule': 'ir::DoubleAttribute',
'bool': 'ir::BoolAttribute', 'bool': 'ir::BoolAttribute',
...@@ -309,17 +309,17 @@ class OpInfoParser: ...@@ -309,17 +309,17 @@ class OpInfoParser:
self.attr_types_map = { self.attr_types_map = {
'IntArray': ['paddle::dialect::IntArrayAttribute', 'IntArray'], 'IntArray': ['paddle::dialect::IntArrayAttribute', 'IntArray'],
'Scalar': ['paddle::dialect::ScalarAttribute', 'Scalar'], 'Scalar': ['paddle::dialect::ScalarAttribute', 'Scalar'],
'Scalar(int)': ['ir::Int32_tAttribute', 'int'], 'Scalar(int)': ['ir::Int32Attribute', 'int'],
'Scalar(int64_t)': ['ir::Int64_tAttribute', 'int64_t'], 'Scalar(int64_t)': ['ir::Int64Attribute', 'int64_t'],
'Scalar(float)': ['ir::FloatAttribute', 'float'], 'Scalar(float)': ['ir::FloatAttribute', 'float'],
'Scalar(dobule)': ['ir::DoubleAttribute', 'dobule'], 'Scalar(dobule)': ['ir::DoubleAttribute', 'dobule'],
'Scalar[]': [ 'Scalar[]': [
'ir::ArrayAttribute<paddle::dialect::ScalarAttribute>', 'ir::ArrayAttribute<paddle::dialect::ScalarAttribute>',
'const std::vector<Scalar>&', 'const std::vector<Scalar>&',
], ],
'int': ['ir::Int32_tAttribute', 'int'], 'int': ['ir::Int32Attribute', 'int'],
'int32_t': ['ir::Int32_tAttribute', 'int32_t'], 'int32_t': ['ir::Int32Attribute', 'int32_t'],
'int64_t': ['ir::Int64_tAttribute', 'int64_t'], 'int64_t': ['ir::Int64Attribute', 'int64_t'],
'long': ['ir::LongAttribute', 'long'], 'long': ['ir::LongAttribute', 'long'],
'size_t': ['ir::Size_tAttribute', 'size_t'], 'size_t': ['ir::Size_tAttribute', 'size_t'],
'float': ['ir::FloatAttribute', 'float'], 'float': ['ir::FloatAttribute', 'float'],
...@@ -345,11 +345,11 @@ class OpInfoParser: ...@@ -345,11 +345,11 @@ class OpInfoParser:
], ],
'DataType': ['paddle::dialect::DataTypeAttribute', 'DataType'], 'DataType': ['paddle::dialect::DataTypeAttribute', 'DataType'],
'int64_t[]': [ 'int64_t[]': [
'ir::ArrayAttribute<ir::Int64_tAttribute>', 'ir::ArrayAttribute<ir::Int64Attribute>',
'const std::vector<int64_t>&', 'const std::vector<int64_t>&',
], ],
'int[]': [ 'int[]': [
'ir::ArrayAttribute<ir::Int32_tAttribute>', 'ir::ArrayAttribute<ir::Int32Attribute>',
'const std::vector<int>&', 'const std::vector<int>&',
], ],
} }
......
...@@ -31,10 +31,10 @@ phi::Scalar ScalarAttribute::data() { ...@@ -31,10 +31,10 @@ phi::Scalar ScalarAttribute::data() {
return phi::Scalar(dyn_cast<ir::FloatAttribute>().data()); return phi::Scalar(dyn_cast<ir::FloatAttribute>().data());
} else if (isa<ir::DoubleAttribute>()) { } else if (isa<ir::DoubleAttribute>()) {
return phi::Scalar(dyn_cast<ir::DoubleAttribute>().data()); return phi::Scalar(dyn_cast<ir::DoubleAttribute>().data());
} else if (isa<ir::Int32_tAttribute>()) { } else if (isa<ir::Int32Attribute>()) {
return phi::Scalar(dyn_cast<ir::Int32_tAttribute>().data()); return phi::Scalar(dyn_cast<ir::Int32Attribute>().data());
} else if (isa<ir::Int64_tAttribute>()) { } else if (isa<ir::Int64Attribute>()) {
return phi::Scalar(dyn_cast<ir::Int64_tAttribute>().data()); return phi::Scalar(dyn_cast<ir::Int64Attribute>().data());
} else if (isa<ir::BoolAttribute>()) { } else if (isa<ir::BoolAttribute>()) {
return phi::Scalar(dyn_cast<ir::BoolAttribute>().data()); return phi::Scalar(dyn_cast<ir::BoolAttribute>().data());
} else { } else {
......
...@@ -44,8 +44,8 @@ class ScalarAttribute : public ir::Attribute { ...@@ -44,8 +44,8 @@ class ScalarAttribute : public ir::Attribute {
return (val.type_id() == ir::BoolAttribute::type_id()) || return (val.type_id() == ir::BoolAttribute::type_id()) ||
(val.type_id() == ir::FloatAttribute::type_id()) || (val.type_id() == ir::FloatAttribute::type_id()) ||
(val.type_id() == ir::DoubleAttribute::type_id()) || (val.type_id() == ir::DoubleAttribute::type_id()) ||
(val.type_id() == ir::Int32_tAttribute::type_id()) || (val.type_id() == ir::Int32Attribute::type_id()) ||
(val.type_id() == ir::Int64_tAttribute::type_id()); (val.type_id() == ir::Int64Attribute::type_id());
} }
phi::Scalar data(); phi::Scalar data();
......
...@@ -83,9 +83,9 @@ static inline ir::Attribute TransToIrAttribute(phi::Scalar scalar, ...@@ -83,9 +83,9 @@ static inline ir::Attribute TransToIrAttribute(phi::Scalar scalar,
case phi::DataType::FLOAT64: case phi::DataType::FLOAT64:
return ir::DoubleAttribute::get(ctx, scalar.to<double>()); return ir::DoubleAttribute::get(ctx, scalar.to<double>());
case phi::DataType::INT32: case phi::DataType::INT32:
return ir::Int32_tAttribute::get(ctx, scalar.to<int32_t>()); return ir::Int32Attribute::get(ctx, scalar.to<int32_t>());
case phi::DataType::INT64: case phi::DataType::INT64:
return ir::Int64_tAttribute::get(ctx, scalar.to<int64_t>()); return ir::Int64Attribute::get(ctx, scalar.to<int64_t>());
case phi::DataType::BOOL: case phi::DataType::BOOL:
return ir::BoolAttribute::get(ctx, scalar.to<bool>()); return ir::BoolAttribute::get(ctx, scalar.to<bool>());
default: default:
......
...@@ -38,7 +38,7 @@ class AttributeVisitor { ...@@ -38,7 +38,7 @@ class AttributeVisitor {
public: public:
virtual ir::Attribute operator()(int i) { virtual ir::Attribute operator()(int i) {
VLOG(10) << "translating int"; VLOG(10) << "translating int";
return ir::Int32_tAttribute::get(ctx, i); return ir::Int32Attribute::get(ctx, i);
} }
virtual ir::Attribute operator()(float f) { virtual ir::Attribute operator()(float f) {
...@@ -91,7 +91,7 @@ class AttributeVisitor { ...@@ -91,7 +91,7 @@ class AttributeVisitor {
std::vector<ir::Attribute> attrs; std::vector<ir::Attribute> attrs;
attrs.reserve(is.size()); attrs.reserve(is.size());
for (const auto& v : is) { for (const auto& v : is) {
attrs.push_back(ir::Int32_tAttribute::get(ctx, v)); attrs.push_back(ir::Int32Attribute::get(ctx, v));
} }
return ir::ArrayAttribute::get(ctx, attrs); return ir::ArrayAttribute::get(ctx, attrs);
} }
...@@ -111,7 +111,7 @@ class AttributeVisitor { ...@@ -111,7 +111,7 @@ class AttributeVisitor {
std::vector<ir::Attribute> attrs; std::vector<ir::Attribute> attrs;
attrs.reserve(i64s.size()); attrs.reserve(i64s.size());
for (const auto& v : i64s) { for (const auto& v : i64s) {
attrs.push_back(ir::Int64_tAttribute::get(ctx, v)); attrs.push_back(ir::Int64Attribute::get(ctx, v));
} }
return ir::ArrayAttribute::get(ctx, attrs); return ir::ArrayAttribute::get(ctx, attrs);
} }
......
...@@ -131,7 +131,7 @@ inline ir::Operation* InsertSliceOperationForTarget( ...@@ -131,7 +131,7 @@ inline ir::Operation* InsertSliceOperationForTarget(
std::string slice_op_name(ir::SliceOp::name()); std::string slice_op_name(ir::SliceOp::name());
ir::OpInfo op_info = ctx->GetRegisteredOpInfo(slice_op_name); ir::OpInfo op_info = ctx->GetRegisteredOpInfo(slice_op_name);
std::unordered_map<std::string, ir::Attribute> op_attribute_map = { std::unordered_map<std::string, ir::Attribute> op_attribute_map = {
{"index", ir::Int32_tAttribute::get(ctx, defining_info.idx_in_vector)}, {"index", ir::Int32Attribute::get(ctx, defining_info.idx_in_vector)},
}; };
ir::VectorType src_vec_type = ir::VectorType src_vec_type =
defining_info.value.type().dyn_cast<ir::VectorType>(); defining_info.value.type().dyn_cast<ir::VectorType>();
...@@ -179,11 +179,11 @@ inline ir::Operation* InsertFullOperationForAttributeInput(ir::IrContext* ctx, ...@@ -179,11 +179,11 @@ inline ir::Operation* InsertFullOperationForAttributeInput(ir::IrContext* ctx,
} else if (attr.isa<ir::DoubleAttribute>()) { } else if (attr.isa<ir::DoubleAttribute>()) {
data = static_cast<float>(attr.dyn_cast<ir::DoubleAttribute>().data()); data = static_cast<float>(attr.dyn_cast<ir::DoubleAttribute>().data());
dtype = phi::DataType::FLOAT64; dtype = phi::DataType::FLOAT64;
} else if (attr.isa<ir::Int32_tAttribute>()) { } else if (attr.isa<ir::Int32Attribute>()) {
data = static_cast<float>(attr.dyn_cast<ir::Int32_tAttribute>().data()); data = static_cast<float>(attr.dyn_cast<ir::Int32Attribute>().data());
dtype = phi::DataType::INT32; dtype = phi::DataType::INT32;
} else if (attr.isa<ir::Int64_tAttribute>()) { } else if (attr.isa<ir::Int64Attribute>()) {
data = static_cast<float>(attr.dyn_cast<ir::Int64_tAttribute>().data()); data = static_cast<float>(attr.dyn_cast<ir::Int64Attribute>().data());
dtype = phi::DataType::INT64; dtype = phi::DataType::INT64;
} else if (attr.isa<ir::BoolAttribute>()) { } else if (attr.isa<ir::BoolAttribute>()) {
data = static_cast<float>(attr.dyn_cast<ir::BoolAttribute>().data()); data = static_cast<float>(attr.dyn_cast<ir::BoolAttribute>().data());
......
...@@ -25,9 +25,9 @@ float FloatAttribute::data() const { return storage()->GetAsKey(); } ...@@ -25,9 +25,9 @@ float FloatAttribute::data() const { return storage()->GetAsKey(); }
double DoubleAttribute::data() const { return storage()->GetAsKey(); } double DoubleAttribute::data() const { return storage()->GetAsKey(); }
int32_t Int32_tAttribute::data() const { return storage()->GetAsKey(); } int32_t Int32Attribute::data() const { return storage()->GetAsKey(); }
int64_t Int64_tAttribute::data() const { return storage()->GetAsKey(); } int64_t Int64Attribute::data() const { return storage()->GetAsKey(); }
std::vector<Attribute> ArrayAttribute::data() const { std::vector<Attribute> ArrayAttribute::data() const {
return storage()->GetAsKey(); return storage()->GetAsKey();
......
...@@ -61,20 +61,20 @@ class DoubleAttribute : public Attribute { ...@@ -61,20 +61,20 @@ class DoubleAttribute : public Attribute {
double data() const; double data() const;
}; };
class Int32_tAttribute : public Attribute { class Int32Attribute : public Attribute {
public: public:
using Attribute::Attribute; using Attribute::Attribute;
DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(Int32_tAttribute, Int32_tAttributeStorage); DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(Int32Attribute, Int32AttributeStorage);
int32_t data() const; int32_t data() const;
}; };
class Int64_tAttribute : public Attribute { class Int64Attribute : public Attribute {
public: public:
using Attribute::Attribute; using Attribute::Attribute;
DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(Int64_tAttribute, Int64_tAttributeStorage); DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(Int64Attribute, Int64AttributeStorage);
int64_t data() const; int64_t data() const;
}; };
......
...@@ -81,8 +81,8 @@ struct StrAttributeStorage : public AttributeStorage { ...@@ -81,8 +81,8 @@ struct StrAttributeStorage : public AttributeStorage {
DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(BoolAttributeStorage, bool); DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(BoolAttributeStorage, bool);
DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(FloatAttributeStorage, float); DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(FloatAttributeStorage, float);
DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(DoubleAttributeStorage, double); DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(DoubleAttributeStorage, double);
DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(Int32_tAttributeStorage, int32_t); DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(Int32AttributeStorage, int32_t);
DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(Int64_tAttributeStorage, int64_t); DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(Int64AttributeStorage, int64_t);
DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(PointerAttributeStorage, void *); DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(PointerAttributeStorage, void *);
struct ArrayAttributeStorage : public AttributeStorage { struct ArrayAttributeStorage : public AttributeStorage {
......
...@@ -41,8 +41,8 @@ void BuiltinDialect::initialize() { ...@@ -41,8 +41,8 @@ void BuiltinDialect::initialize() {
FloatAttribute, FloatAttribute,
DoubleAttribute, DoubleAttribute,
PointerAttribute, PointerAttribute,
Int32_tAttribute, Int32Attribute,
Int64_tAttribute, Int64Attribute,
ArrayAttribute>(); ArrayAttribute>();
RegisterOps<ModuleOp, RegisterOps<ModuleOp,
......
...@@ -161,9 +161,9 @@ void SliceOp::Verify(const std::vector<ir::OpResult> &inputs, ...@@ -161,9 +161,9 @@ void SliceOp::Verify(const std::vector<ir::OpResult> &inputs,
IR_ENFORCE(attributes.count("index") != 0, IR_ENFORCE(attributes.count("index") != 0,
"The attributes must contains index."); "The attributes must contains index.");
const ir::Attribute &attr = attributes.at("index"); const ir::Attribute &attr = attributes.at("index");
IR_ENFORCE(attr.isa<ir::Int32_tAttribute>(), IR_ENFORCE(attr.isa<ir::Int32Attribute>(),
"The attribute index must be INT32."); "The attribute index must be INT32.");
auto index = attr.dyn_cast<ir::Int32_tAttribute>().data(); auto index = attr.dyn_cast<ir::Int32Attribute>().data();
// index >= 0 and < inputs[0].size() // index >= 0 and < inputs[0].size()
IR_ENFORCE( IR_ENFORCE(
......
...@@ -83,9 +83,9 @@ class BasicIRPrinter { ...@@ -83,9 +83,9 @@ class BasicIRPrinter {
os << f.data(); os << f.data();
} else if (auto d = attr.dyn_cast<DoubleAttribute>()) { } else if (auto d = attr.dyn_cast<DoubleAttribute>()) {
os << d.data(); os << d.data();
} else if (auto i = attr.dyn_cast<Int32_tAttribute>()) { } else if (auto i = attr.dyn_cast<Int32Attribute>()) {
os << i.data(); os << i.data();
} else if (auto i = attr.dyn_cast<Int64_tAttribute>()) { } else if (auto i = attr.dyn_cast<Int64Attribute>()) {
os << i.data(); os << i.data();
} else if (auto arr = attr.dyn_cast<ArrayAttribute>()) { } else if (auto arr = attr.dyn_cast<ArrayAttribute>()) {
const auto& vec = arr.data(); const auto& vec = arr.data();
......
...@@ -238,7 +238,7 @@ TEST(op_test, module_op_death) { ...@@ -238,7 +238,7 @@ TEST(op_test, module_op_death) {
// (3) Test uses for op. // (3) Test uses for op.
std::vector<ir::OpResult> inputs{ir::OpResult()}; std::vector<ir::OpResult> inputs{ir::OpResult()};
ir::AttributeMap attrs{{"program", ir::Int32_tAttribute::get(ctx, 1)}}; ir::AttributeMap attrs{{"program", ir::Int32Attribute::get(ctx, 1)}};
std::vector<ir::Type> output_types = {ir::Float32Type::get(ctx)}; std::vector<ir::Type> output_types = {ir::Float32Type::get(ctx)};
EXPECT_THROW(ir::Operation::Create(inputs, {}, {}, op_info), EXPECT_THROW(ir::Operation::Create(inputs, {}, {}, op_info),
......
...@@ -283,7 +283,7 @@ TEST(program_test, slice_combine_test) { ...@@ -283,7 +283,7 @@ TEST(program_test, slice_combine_test) {
// (7) Def slice_op = SliceOp(combine_op, 0) // (7) Def slice_op = SliceOp(combine_op, 0)
std::string slice_op_name = std::string(ir::SliceOp::name()); std::string slice_op_name = std::string(ir::SliceOp::name());
ir::OpInfo slice_op_info = ctx->GetRegisteredOpInfo(slice_op_name); ir::OpInfo slice_op_info = ctx->GetRegisteredOpInfo(slice_op_name);
ir::Attribute index_attr = ir::Int32_tAttribute::get(ctx, 0); ir::Attribute index_attr = ir::Int32Attribute::get(ctx, 0);
ir::Operation *slice_op = ir::Operation *slice_op =
ir::Operation::Create({combine_op->GetResultByIndex(0)}, ir::Operation::Create({combine_op->GetResultByIndex(0)},
{{"index", index_attr}}, {{"index", index_attr}},
...@@ -319,8 +319,7 @@ TEST(program_test, builder) { ...@@ -319,8 +319,7 @@ TEST(program_test, builder) {
} }
ir::ConstantOp constant = builder.Build<ir::ConstantOp>( ir::ConstantOp constant = builder.Build<ir::ConstantOp>(
ir::Int32_tAttribute::get(ctx, 2), ir::Int32Type::get(ctx)); ir::Int32Attribute::get(ctx, 2), ir::Int32Type::get(ctx));
EXPECT_EQ(program.block()->size() == 2, true); EXPECT_EQ(program.block()->size() == 2, true);
EXPECT_EQ(constant.value().dyn_cast<ir::Int32_tAttribute>().data() == 2, EXPECT_EQ(constant.value().dyn_cast<ir::Int32Attribute>().data() == 2, true);
true);
} }
...@@ -28,7 +28,7 @@ TEST(ir_op_info_test, op_op_info_test) { ...@@ -28,7 +28,7 @@ TEST(ir_op_info_test, op_op_info_test) {
ir::Block* block = program.block(); ir::Block* block = program.block();
ir::Builder builder(context, block); ir::Builder builder(context, block);
builder.Build<ir::ConstantOp>(ir::Int32_tAttribute::get(context, 5), builder.Build<ir::ConstantOp>(ir::Int32Attribute::get(context, 5),
ir::Int32Type::get(context)); ir::Int32Type::get(context));
ir::Operation* op = block->back(); ir::Operation* op = block->back();
......
...@@ -164,9 +164,8 @@ void build_context(ir::Operation* op, ...@@ -164,9 +164,8 @@ void build_context(ir::Operation* op,
} else if (type_name == "paddle::dialect::DataTypeAttribute") { } else if (type_name == "paddle::dialect::DataTypeAttribute") {
ctx->EmplaceBackAttr( ctx->EmplaceBackAttr(
attr_map[t].dyn_cast<paddle::dialect::DataTypeAttribute>().data()); attr_map[t].dyn_cast<paddle::dialect::DataTypeAttribute>().data());
} else if (type_name == "ir::Int32_tAttribute") { } else if (type_name == "ir::Int32Attribute") {
ctx->EmplaceBackAttr( ctx->EmplaceBackAttr(attr_map[t].dyn_cast<ir::Int32Attribute>().data());
attr_map[t].dyn_cast<ir::Int32_tAttribute>().data());
} else if (type_name == "paddle::dialect::PlaceAttribute") { } else if (type_name == "paddle::dialect::PlaceAttribute") {
ctx->EmplaceBackAttr( ctx->EmplaceBackAttr(
attr_map[t].dyn_cast<paddle::dialect::PlaceAttribute>().data()); attr_map[t].dyn_cast<paddle::dialect::PlaceAttribute>().data());
......
...@@ -47,9 +47,9 @@ TEST(ScalarTest, test_classof) { ...@@ -47,9 +47,9 @@ TEST(ScalarTest, test_classof) {
ir::Attribute double_scalar = ir::DoubleAttribute::get(ctx, 1.0); ir::Attribute double_scalar = ir::DoubleAttribute::get(ctx, 1.0);
EXPECT_TRUE(double_scalar.isa<ScalarAttribute>()); EXPECT_TRUE(double_scalar.isa<ScalarAttribute>());
ir::Attribute int32_scalar = ir::Int32_tAttribute::get(ctx, 1); ir::Attribute int32_scalar = ir::Int32Attribute::get(ctx, 1);
EXPECT_TRUE(int32_scalar.isa<ScalarAttribute>()); EXPECT_TRUE(int32_scalar.isa<ScalarAttribute>());
ir::Attribute int64_scalar = ir::Int64_tAttribute::get(ctx, 1l); ir::Attribute int64_scalar = ir::Int64Attribute::get(ctx, 1l);
EXPECT_TRUE(int64_scalar.isa<ScalarAttribute>()); EXPECT_TRUE(int64_scalar.isa<ScalarAttribute>());
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册