未验证 提交 cc7d1f34 编写于 作者: W Wilber 提交者: GitHub

[IR] Support TypeAttribute. (#54984)

上级 89feae07
...@@ -35,6 +35,8 @@ std::vector<Attribute> ArrayAttribute::data() const { ...@@ -35,6 +35,8 @@ std::vector<Attribute> ArrayAttribute::data() const {
void* PointerAttribute::data() const { return storage()->GetAsKey(); } void* PointerAttribute::data() const { return storage()->GetAsKey(); }
Type TypeAttribute::data() const { return storage()->GetAsKey(); }
} // namespace ir } // namespace ir
IR_DEFINE_EXPLICIT_TYPE_ID(ir::StrAttribute) IR_DEFINE_EXPLICIT_TYPE_ID(ir::StrAttribute)
...@@ -45,3 +47,4 @@ IR_DEFINE_EXPLICIT_TYPE_ID(ir::Int32Attribute) ...@@ -45,3 +47,4 @@ IR_DEFINE_EXPLICIT_TYPE_ID(ir::Int32Attribute)
IR_DEFINE_EXPLICIT_TYPE_ID(ir::Int64Attribute) IR_DEFINE_EXPLICIT_TYPE_ID(ir::Int64Attribute)
IR_DEFINE_EXPLICIT_TYPE_ID(ir::ArrayAttribute) IR_DEFINE_EXPLICIT_TYPE_ID(ir::ArrayAttribute)
IR_DEFINE_EXPLICIT_TYPE_ID(ir::PointerAttribute) IR_DEFINE_EXPLICIT_TYPE_ID(ir::PointerAttribute)
IR_DEFINE_EXPLICIT_TYPE_ID(ir::TypeAttribute)
...@@ -103,6 +103,15 @@ class IR_API PointerAttribute : public Attribute { ...@@ -103,6 +103,15 @@ class IR_API PointerAttribute : public Attribute {
void* data() const; void* data() const;
}; };
class IR_API TypeAttribute : public Attribute {
public:
using Attribute::Attribute;
DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(TypeAttribute, TypeAttributeStorage);
Type data() const;
};
} // namespace ir } // namespace ir
IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::StrAttribute) IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::StrAttribute)
...@@ -113,3 +122,4 @@ IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::Int32Attribute) ...@@ -113,3 +122,4 @@ IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::Int32Attribute)
IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::Int64Attribute) IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::Int64Attribute)
IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::ArrayAttribute) IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::ArrayAttribute)
IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::PointerAttribute) IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::PointerAttribute)
IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::TypeAttribute)
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
#include "paddle/ir/core/attribute.h" #include "paddle/ir/core/attribute.h"
#include "paddle/ir/core/attribute_base.h" #include "paddle/ir/core/attribute_base.h"
#include "paddle/ir/core/type.h"
#include "paddle/ir/core/utils.h" #include "paddle/ir/core/utils.h"
namespace ir { namespace ir {
...@@ -131,4 +132,25 @@ struct ArrayAttributeStorage : public AttributeStorage { ...@@ -131,4 +132,25 @@ struct ArrayAttributeStorage : public AttributeStorage {
size_t length_ = 0; size_t length_ = 0;
}; };
struct TypeAttributeStorage : public AttributeStorage {
using ParamKey = Type;
explicit TypeAttributeStorage(const ParamKey &key) : value_(key) {}
static TypeAttributeStorage *Construct(ParamKey key) {
return new TypeAttributeStorage(key);
}
static std::size_t HashValue(const ParamKey &key) {
return std::hash<Type>()(key);
}
bool operator==(const ParamKey &key) const { return value_ == key; }
ParamKey GetAsKey() const { return value_; }
private:
Type value_;
};
} // namespace ir } // namespace ir
...@@ -46,7 +46,8 @@ void BuiltinDialect::initialize() { ...@@ -46,7 +46,8 @@ void BuiltinDialect::initialize() {
PointerAttribute, PointerAttribute,
Int32Attribute, Int32Attribute,
Int64Attribute, Int64Attribute,
ArrayAttribute>(); ArrayAttribute,
TypeAttribute>();
RegisterOps<ModuleOp, RegisterOps<ModuleOp,
GetParameterOp, GetParameterOp,
......
...@@ -107,6 +107,8 @@ void BasicIrPrinter::PrintAttribute(const Attribute& attr) { ...@@ -107,6 +107,8 @@ void BasicIrPrinter::PrintAttribute(const Attribute& attr) {
[this](Attribute v) { this->PrintAttribute(v); }, [this](Attribute v) { this->PrintAttribute(v); },
[this]() { this->os << ","; }); [this]() { this->os << ","; });
os << "]"; os << "]";
} else if (auto type = attr.dyn_cast<TypeAttribute>()) {
os << type.data();
} else { } else {
auto& dialect = attr.dialect(); auto& dialect = attr.dialect();
dialect.PrintAttribute(attr, os); dialect.PrintAttribute(attr, os);
......
...@@ -22,13 +22,13 @@ namespace { ...@@ -22,13 +22,13 @@ namespace {
// TODO(wilber): After support SideEffectTrait, Only NoSideEffectTrait op can be // TODO(wilber): After support SideEffectTrait, Only NoSideEffectTrait op can be
// removed by dce pass. // removed by dce pass.
// Now just a naive implementation. // Now just a naive implementation.
class DCEPass : public ir::Pass { class DcePass : public ir::Pass {
public: public:
DCEPass() : ir::Pass("DCEPass", 0) {} DcePass() : ir::Pass("DcePass", 0) {}
void Run(ir::Operation *op) override { void Run(ir::Operation *op) override {
auto module_op = op->dyn_cast<ir::ModuleOp>(); auto module_op = op->dyn_cast<ir::ModuleOp>();
IR_ENFORCE(module_op, "DCEPass should run on module op."); IR_ENFORCE(module_op, "DcePass should run on module op.");
auto *block = module_op.block(); auto *block = module_op.block();
std::vector<ir::Operation> erased_op; std::vector<ir::Operation> erased_op;
for (auto it = block->begin(); it != block->end(); ++it) { for (auto it = block->begin(); it != block->end(); ++it) {
...@@ -39,6 +39,7 @@ class DCEPass : public ir::Pass { ...@@ -39,6 +39,7 @@ class DCEPass : public ir::Pass {
for (uint32_t i = 0; i < (*it)->num_results(); ++i) { for (uint32_t i = 0; i < (*it)->num_results(); ++i) {
use_empty &= (*it)->result(i).use_empty(); use_empty &= (*it)->result(i).use_empty();
} }
// TODO(wilber): Support Terminator trait.
if (use_empty && (*it)->name() != "pd.fetch") { if (use_empty && (*it)->name() != "pd.fetch") {
erased_op.push_back(**it); erased_op.push_back(**it);
} }
...@@ -56,6 +57,6 @@ class DCEPass : public ir::Pass { ...@@ -56,6 +57,6 @@ class DCEPass : public ir::Pass {
namespace ir { namespace ir {
std::unique_ptr<Pass> CreateDCEPass() { return std::make_unique<DCEPass>(); } std::unique_ptr<Pass> CreateDcePass() { return std::make_unique<DcePass>(); }
} // namespace ir } // namespace ir
...@@ -20,6 +20,6 @@ ...@@ -20,6 +20,6 @@
namespace ir { namespace ir {
class Pass; class Pass;
IR_API std::unique_ptr<Pass> CreateDCEPass(); IR_API std::unique_ptr<Pass> CreateDcePass();
} // namespace ir } // namespace ir
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
#include "paddle/ir/core/attribute_base.h" #include "paddle/ir/core/attribute_base.h"
#include "paddle/ir/core/builtin_attribute.h" #include "paddle/ir/core/builtin_attribute.h"
#include "paddle/ir/core/builtin_dialect.h" #include "paddle/ir/core/builtin_dialect.h"
#include "paddle/ir/core/builtin_type.h"
#include "paddle/ir/core/dialect.h" #include "paddle/ir/core/dialect.h"
#include "paddle/ir/core/ir_context.h" #include "paddle/ir/core/ir_context.h"
...@@ -63,4 +64,10 @@ TEST(attribute_test, built_in_attribute) { ...@@ -63,4 +64,10 @@ TEST(attribute_test, built_in_attribute) {
string_attr_1.dyn_cast<ir::StrAttribute>(); string_attr_1.dyn_cast<ir::StrAttribute>();
EXPECT_EQ(string_attr_cast_1.isa<ir::StrAttribute>(), true); EXPECT_EQ(string_attr_cast_1.isa<ir::StrAttribute>(), true);
EXPECT_EQ(string_attr_cast_1.size() == 8, 1); EXPECT_EQ(string_attr_cast_1.size() == 8, 1);
ir::Int32Type i32_type = ir::Int32Type::get(ctx);
ir::Attribute type_attr = ir::TypeAttribute::get(ctx, i32_type);
EXPECT_TRUE(type_attr.isa<ir::TypeAttribute>());
EXPECT_EQ(type_attr.dyn_cast<ir::TypeAttribute>().data().type_id(),
i32_type.type_id());
} }
...@@ -429,7 +429,7 @@ TEST(pattern_rewrite, Patterns) { ...@@ -429,7 +429,7 @@ TEST(pattern_rewrite, Patterns) {
ir::PassManager pm(ctx); ir::PassManager pm(ctx);
pm.AddPass(std::make_unique<TestPass>()); pm.AddPass(std::make_unique<TestPass>());
pm.AddPass(ir::CreateDCEPass()); pm.AddPass(ir::CreateDcePass());
program.Print(std::cout); program.Print(std::cout);
std::cout << std::endl; std::cout << std::endl;
pm.Run(&program); pm.Run(&program);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册