未验证 提交 cc7d1f34 编写于 作者: W Wilber 提交者: GitHub

[IR] Support TypeAttribute. (#54984)

上级 89feae07
......@@ -35,6 +35,8 @@ std::vector<Attribute> ArrayAttribute::data() const {
void* PointerAttribute::data() const { return storage()->GetAsKey(); }
Type TypeAttribute::data() const { return storage()->GetAsKey(); }
} // namespace ir
IR_DEFINE_EXPLICIT_TYPE_ID(ir::StrAttribute)
......@@ -45,3 +47,4 @@ IR_DEFINE_EXPLICIT_TYPE_ID(ir::Int32Attribute)
IR_DEFINE_EXPLICIT_TYPE_ID(ir::Int64Attribute)
IR_DEFINE_EXPLICIT_TYPE_ID(ir::ArrayAttribute)
IR_DEFINE_EXPLICIT_TYPE_ID(ir::PointerAttribute)
IR_DEFINE_EXPLICIT_TYPE_ID(ir::TypeAttribute)
......@@ -103,6 +103,15 @@ class IR_API PointerAttribute : public Attribute {
void* data() const;
};
class IR_API TypeAttribute : public Attribute {
public:
using Attribute::Attribute;
DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(TypeAttribute, TypeAttributeStorage);
Type data() const;
};
} // namespace ir
IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::StrAttribute)
......@@ -113,3 +122,4 @@ IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::Int32Attribute)
IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::Int64Attribute)
IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::ArrayAttribute)
IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::PointerAttribute)
IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::TypeAttribute)
......@@ -20,6 +20,7 @@
#include "paddle/ir/core/attribute.h"
#include "paddle/ir/core/attribute_base.h"
#include "paddle/ir/core/type.h"
#include "paddle/ir/core/utils.h"
namespace ir {
......@@ -131,4 +132,25 @@ struct ArrayAttributeStorage : public AttributeStorage {
size_t length_ = 0;
};
struct TypeAttributeStorage : public AttributeStorage {
using ParamKey = Type;
explicit TypeAttributeStorage(const ParamKey &key) : value_(key) {}
static TypeAttributeStorage *Construct(ParamKey key) {
return new TypeAttributeStorage(key);
}
static std::size_t HashValue(const ParamKey &key) {
return std::hash<Type>()(key);
}
bool operator==(const ParamKey &key) const { return value_ == key; }
ParamKey GetAsKey() const { return value_; }
private:
Type value_;
};
} // namespace ir
......@@ -46,7 +46,8 @@ void BuiltinDialect::initialize() {
PointerAttribute,
Int32Attribute,
Int64Attribute,
ArrayAttribute>();
ArrayAttribute,
TypeAttribute>();
RegisterOps<ModuleOp,
GetParameterOp,
......
......@@ -107,6 +107,8 @@ void BasicIrPrinter::PrintAttribute(const Attribute& attr) {
[this](Attribute v) { this->PrintAttribute(v); },
[this]() { this->os << ","; });
os << "]";
} else if (auto type = attr.dyn_cast<TypeAttribute>()) {
os << type.data();
} else {
auto& dialect = attr.dialect();
dialect.PrintAttribute(attr, os);
......
......@@ -22,13 +22,13 @@ namespace {
// TODO(wilber): After support SideEffectTrait, Only NoSideEffectTrait op can be
// removed by dce pass.
// Now just a naive implementation.
class DCEPass : public ir::Pass {
class DcePass : public ir::Pass {
public:
DCEPass() : ir::Pass("DCEPass", 0) {}
DcePass() : ir::Pass("DcePass", 0) {}
void Run(ir::Operation *op) override {
auto module_op = op->dyn_cast<ir::ModuleOp>();
IR_ENFORCE(module_op, "DCEPass should run on module op.");
IR_ENFORCE(module_op, "DcePass should run on module op.");
auto *block = module_op.block();
std::vector<ir::Operation> erased_op;
for (auto it = block->begin(); it != block->end(); ++it) {
......@@ -39,6 +39,7 @@ class DCEPass : public ir::Pass {
for (uint32_t i = 0; i < (*it)->num_results(); ++i) {
use_empty &= (*it)->result(i).use_empty();
}
// TODO(wilber): Support Terminator trait.
if (use_empty && (*it)->name() != "pd.fetch") {
erased_op.push_back(**it);
}
......@@ -56,6 +57,6 @@ class DCEPass : public ir::Pass {
namespace ir {
std::unique_ptr<Pass> CreateDCEPass() { return std::make_unique<DCEPass>(); }
std::unique_ptr<Pass> CreateDcePass() { return std::make_unique<DcePass>(); }
} // namespace ir
......@@ -20,6 +20,6 @@
namespace ir {
class Pass;
IR_API std::unique_ptr<Pass> CreateDCEPass();
IR_API std::unique_ptr<Pass> CreateDcePass();
} // namespace ir
......@@ -19,6 +19,7 @@
#include "paddle/ir/core/attribute_base.h"
#include "paddle/ir/core/builtin_attribute.h"
#include "paddle/ir/core/builtin_dialect.h"
#include "paddle/ir/core/builtin_type.h"
#include "paddle/ir/core/dialect.h"
#include "paddle/ir/core/ir_context.h"
......@@ -63,4 +64,10 @@ TEST(attribute_test, built_in_attribute) {
string_attr_1.dyn_cast<ir::StrAttribute>();
EXPECT_EQ(string_attr_cast_1.isa<ir::StrAttribute>(), true);
EXPECT_EQ(string_attr_cast_1.size() == 8, 1);
ir::Int32Type i32_type = ir::Int32Type::get(ctx);
ir::Attribute type_attr = ir::TypeAttribute::get(ctx, i32_type);
EXPECT_TRUE(type_attr.isa<ir::TypeAttribute>());
EXPECT_EQ(type_attr.dyn_cast<ir::TypeAttribute>().data().type_id(),
i32_type.type_id());
}
......@@ -429,7 +429,7 @@ TEST(pattern_rewrite, Patterns) {
ir::PassManager pm(ctx);
pm.AddPass(std::make_unique<TestPass>());
pm.AddPass(ir::CreateDCEPass());
pm.AddPass(ir::CreateDcePass());
program.Print(std::cout);
std::cout << std::endl;
pm.Run(&program);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册