diff --git a/paddle/ir/builder.h b/paddle/ir/builder.h index a9b67582b7d150dd43f1c5415d5e4e30d04c394a..cfca24cd5dde643135dd7252077ffe90f1fd7849 100644 --- a/paddle/ir/builder.h +++ b/paddle/ir/builder.h @@ -34,7 +34,7 @@ class Builder { OperationArgument argument(context_->GetRegisteredOpInfo(OpTy::name())); OpTy::build(*this, argument, std::forward(args)...); Operation *op = Operation::create(argument); - return dyn_cast(op); + return op->dyn_cast(); } private: diff --git a/paddle/ir/op_base.h b/paddle/ir/op_base.h index 9a42afa0c67a8479cb4613ecfc268333404d7eb3..b245cb7850f463febbb4bc340a1337ac28f59ea0 100644 --- a/paddle/ir/op_base.h +++ b/paddle/ir/op_base.h @@ -89,6 +89,13 @@ class OpTraitBase : public OpBase { explicit OpTraitBase(const Operation *op) : OpBase(op) {} static TypeId GetTraitId() { return TypeId::get(); } + + static ConcreteTrait dyn_cast(const Operation *op) { + if (op->HasTrait()) { + return ConcreteTrait(op); + } + return ConcreteTrait(nullptr); + } }; /// @@ -102,6 +109,14 @@ class OpInterfaceBase : public OpBase { explicit OpInterfaceBase(const Operation *op) : OpBase(op) {} static TypeId GetInterfaceId() { return TypeId::get(); } + + static ConcreteInterface dyn_cast(const Operation *op) { + if (op->HasInterface()) { + return ConcreteInterface( + op, op->op_info().GetInterfaceImpl()); + } + return ConcreteInterface(nullptr, nullptr); + } }; template @@ -168,6 +183,13 @@ class Op : public OpBase { using InterfaceList = typename Filter>::Type; + static ConcreteOp dyn_cast(const Operation *op) { + if (op->op_info().id() == TypeId::get()) { + return ConcreteOp(op); + } + return ConcreteOp(nullptr); + } + static std::vector GetInterfaceMap() { constexpr size_t interfaces_num = std::tuple_size::value; std::vector interfaces_map(interfaces_num); diff --git a/paddle/ir/op_info.cc b/paddle/ir/op_info.cc index e68839f937dd66cf2581c2163d28d1d6ffaae885..9aed5754daa294f397ad767499ce7649ce156423 100644 --- a/paddle/ir/op_info.cc +++ b/paddle/ir/op_info.cc @@ -32,6 +32,8 @@ IrContext *OpInfo::ir_context() const { const char *OpInfo::name() const { return impl_ ? impl_->name() : nullptr; } +TypeId OpInfo::id() const { return impl_ ? impl_->id() : TypeId(); } + void *OpInfo::GetInterfaceImpl(TypeId interface_id) const { return impl_ ? impl_->interface_impl(interface_id) : nullptr; } diff --git a/paddle/ir/op_info.h b/paddle/ir/op_info.h index 14526c091cd8cd5316652da98bd34d9b71ce8a94..6c8b6c1f0d6baf430560c5dfdacc19e2ea5bc809 100644 --- a/paddle/ir/op_info.h +++ b/paddle/ir/op_info.h @@ -42,6 +42,8 @@ class OpInfo { const char *name() const; + TypeId id() const; + template bool HasTrait() const { return HasTrait(TypeId::get()); diff --git a/paddle/ir/operation.h b/paddle/ir/operation.h index c2346bf180ad11bca047e005f7c8683fae72f9a0..0b7da942d6aa4369f09271092ffef186179d51ad 100644 --- a/paddle/ir/operation.h +++ b/paddle/ir/operation.h @@ -21,10 +21,7 @@ #include "paddle/ir/value_impl.h" namespace ir { -template -class OpTraitBase; -template -class OpInterfaceBase; +class OpBase; class Program; class alignas(8) Operation final { @@ -94,25 +91,15 @@ class alignas(8) Operation final { template struct CastUtil { static T call(const Operation *op) { - throw("Can't dyn_cast to T, T should be a Trait or Interface"); - } - }; - template - struct CastUtil, T>::value>::type> { - static T call(const Operation *op) { - return T(op->HasTrait() ? op : nullptr); + throw("Can't dyn_cast to T, T should be a Op or Trait or Interface"); } }; + template - struct CastUtil, T>::value>::type> { - static T call(const Operation *op) { - typename T::Concept *interface_impl = op->op_info().GetInterfaceImpl(); - return interface_impl ? T(op, interface_impl) : T(nullptr, nullptr); - } + struct CastUtil< + T, + typename std::enable_if::value>::type> { + static T call(const Operation *op) { return T::dyn_cast(op); } }; AttributeMap attribute_; diff --git a/test/cpp/ir/ir_op_test.cc b/test/cpp/ir/ir_op_test.cc index fc23166fa41e71e9bc250586019ffc3c15ffc02e..bcdbfb2f8f420d2ec4c4e7fb05dbca9e5a2dbec5 100644 --- a/test/cpp/ir/ir_op_test.cc +++ b/test/cpp/ir/ir_op_test.cc @@ -138,14 +138,13 @@ TEST(op_test, op_test) { CreateAttributeMap("op1_name", "op1_attr"), op2_info); - if (op->HasTrait()) { - ReadOnlyTrait trait = op->dyn_cast(); - EXPECT_EQ(trait.operation(), op); - } - if (op->HasInterface()) { - InferShapeInterface interface = op->dyn_cast(); - interface.InferShape(); - } + ReadOnlyTrait trait = op->dyn_cast(); + EXPECT_EQ(trait.operation(), op); + InferShapeInterface interface = op->dyn_cast(); + interface.InferShape(); + + Operation2 Op2 = op->dyn_cast(); + EXPECT_EQ(Op2.operation(), op); op->destroy(); }