未验证 提交 9ded0692 编写于 作者: Z zhangbo9674 提交者: GitHub

[IR] Refine operation dyn_cast (#53996)

* refine op dyn_cast

* fix bug

* refine code

* refine code

* refine code

* refine code
上级 a9b1e887
...@@ -34,7 +34,7 @@ class Builder { ...@@ -34,7 +34,7 @@ class Builder {
OperationArgument argument(context_->GetRegisteredOpInfo(OpTy::name())); OperationArgument argument(context_->GetRegisteredOpInfo(OpTy::name()));
OpTy::build(*this, argument, std::forward<Args>(args)...); OpTy::build(*this, argument, std::forward<Args>(args)...);
Operation *op = Operation::create(argument); Operation *op = Operation::create(argument);
return dyn_cast<OpTy>(op); return op->dyn_cast<OpTy>();
} }
private: private:
......
...@@ -89,6 +89,13 @@ class OpTraitBase : public OpBase { ...@@ -89,6 +89,13 @@ class OpTraitBase : public OpBase {
explicit OpTraitBase(const Operation *op) : OpBase(op) {} explicit OpTraitBase(const Operation *op) : OpBase(op) {}
static TypeId GetTraitId() { return TypeId::get<ConcreteTrait>(); } static TypeId GetTraitId() { return TypeId::get<ConcreteTrait>(); }
static ConcreteTrait dyn_cast(const Operation *op) {
if (op->HasTrait<ConcreteTrait>()) {
return ConcreteTrait(op);
}
return ConcreteTrait(nullptr);
}
}; };
/// ///
...@@ -102,6 +109,14 @@ class OpInterfaceBase : public OpBase { ...@@ -102,6 +109,14 @@ class OpInterfaceBase : public OpBase {
explicit OpInterfaceBase(const Operation *op) : OpBase(op) {} explicit OpInterfaceBase(const Operation *op) : OpBase(op) {}
static TypeId GetInterfaceId() { return TypeId::get<ConcreteInterface>(); } static TypeId GetInterfaceId() { return TypeId::get<ConcreteInterface>(); }
static ConcreteInterface dyn_cast(const Operation *op) {
if (op->HasInterface<ConcreteInterface>()) {
return ConcreteInterface(
op, op->op_info().GetInterfaceImpl<ConcreteInterface>());
}
return ConcreteInterface(nullptr, nullptr);
}
}; };
template <typename ConcreteOp, typename... Args> template <typename ConcreteOp, typename... Args>
...@@ -168,6 +183,13 @@ class Op : public OpBase { ...@@ -168,6 +183,13 @@ class Op : public OpBase {
using InterfaceList = using InterfaceList =
typename Filter<OpInterfaceBase, std::tuple<TraitOrInterface...>>::Type; typename Filter<OpInterfaceBase, std::tuple<TraitOrInterface...>>::Type;
static ConcreteOp dyn_cast(const Operation *op) {
if (op->op_info().id() == TypeId::get<ConcreteOp>()) {
return ConcreteOp(op);
}
return ConcreteOp(nullptr);
}
static std::vector<InterfaceValue> GetInterfaceMap() { static std::vector<InterfaceValue> GetInterfaceMap() {
constexpr size_t interfaces_num = std::tuple_size<InterfaceList>::value; constexpr size_t interfaces_num = std::tuple_size<InterfaceList>::value;
std::vector<InterfaceValue> interfaces_map(interfaces_num); std::vector<InterfaceValue> interfaces_map(interfaces_num);
......
...@@ -32,6 +32,8 @@ IrContext *OpInfo::ir_context() const { ...@@ -32,6 +32,8 @@ IrContext *OpInfo::ir_context() const {
const char *OpInfo::name() const { return impl_ ? impl_->name() : nullptr; } const char *OpInfo::name() const { return impl_ ? impl_->name() : nullptr; }
TypeId OpInfo::id() const { return impl_ ? impl_->id() : TypeId(); }
void *OpInfo::GetInterfaceImpl(TypeId interface_id) const { void *OpInfo::GetInterfaceImpl(TypeId interface_id) const {
return impl_ ? impl_->interface_impl(interface_id) : nullptr; return impl_ ? impl_->interface_impl(interface_id) : nullptr;
} }
......
...@@ -42,6 +42,8 @@ class OpInfo { ...@@ -42,6 +42,8 @@ class OpInfo {
const char *name() const; const char *name() const;
TypeId id() const;
template <typename Trait> template <typename Trait>
bool HasTrait() const { bool HasTrait() const {
return HasTrait(TypeId::get<Trait>()); return HasTrait(TypeId::get<Trait>());
......
...@@ -21,10 +21,7 @@ ...@@ -21,10 +21,7 @@
#include "paddle/ir/value_impl.h" #include "paddle/ir/value_impl.h"
namespace ir { namespace ir {
template <class ConcreteTrait> class OpBase;
class OpTraitBase;
template <typename ConcreteInterface>
class OpInterfaceBase;
class Program; class Program;
class alignas(8) Operation final { class alignas(8) Operation final {
...@@ -94,25 +91,15 @@ class alignas(8) Operation final { ...@@ -94,25 +91,15 @@ class alignas(8) Operation final {
template <typename T, typename Enabler = void> template <typename T, typename Enabler = void>
struct CastUtil { struct CastUtil {
static T call(const Operation *op) { static T call(const Operation *op) {
throw("Can't dyn_cast to T, T should be a Trait or Interface"); throw("Can't dyn_cast to T, T should be a Op or Trait or Interface");
}
};
template <typename T>
struct CastUtil<T,
typename std::enable_if<
std::is_base_of<OpTraitBase<T>, T>::value>::type> {
static T call(const Operation *op) {
return T(op->HasTrait<T>() ? op : nullptr);
} }
}; };
template <typename T> template <typename T>
struct CastUtil<T, struct CastUtil<
typename std::enable_if< T,
std::is_base_of<OpInterfaceBase<T>, T>::value>::type> { typename std::enable_if<std::is_base_of<OpBase, T>::value>::type> {
static T call(const Operation *op) { static T call(const Operation *op) { return T::dyn_cast(op); }
typename T::Concept *interface_impl = op->op_info().GetInterfaceImpl<T>();
return interface_impl ? T(op, interface_impl) : T(nullptr, nullptr);
}
}; };
AttributeMap attribute_; AttributeMap attribute_;
......
...@@ -138,14 +138,13 @@ TEST(op_test, op_test) { ...@@ -138,14 +138,13 @@ TEST(op_test, op_test) {
CreateAttributeMap("op1_name", "op1_attr"), CreateAttributeMap("op1_name", "op1_attr"),
op2_info); op2_info);
if (op->HasTrait<ReadOnlyTrait>()) { ReadOnlyTrait trait = op->dyn_cast<ReadOnlyTrait>();
ReadOnlyTrait trait = op->dyn_cast<ReadOnlyTrait>(); EXPECT_EQ(trait.operation(), op);
EXPECT_EQ(trait.operation(), op); InferShapeInterface interface = op->dyn_cast<InferShapeInterface>();
} interface.InferShape();
if (op->HasInterface<InferShapeInterface>()) {
InferShapeInterface interface = op->dyn_cast<InferShapeInterface>(); Operation2 Op2 = op->dyn_cast<Operation2>();
interface.InferShape(); EXPECT_EQ(Op2.operation(), op);
}
op->destroy(); op->destroy();
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册