未验证 提交 9ded0692 编写于 作者: Z zhangbo9674 提交者: GitHub

[IR] Refine operation dyn_cast (#53996)

* refine op dyn_cast

* fix bug

* refine code

* refine code

* refine code

* refine code
上级 a9b1e887
......@@ -34,7 +34,7 @@ class Builder {
OperationArgument argument(context_->GetRegisteredOpInfo(OpTy::name()));
OpTy::build(*this, argument, std::forward<Args>(args)...);
Operation *op = Operation::create(argument);
return dyn_cast<OpTy>(op);
return op->dyn_cast<OpTy>();
}
private:
......
......@@ -89,6 +89,13 @@ class OpTraitBase : public OpBase {
explicit OpTraitBase(const Operation *op) : OpBase(op) {}
static TypeId GetTraitId() { return TypeId::get<ConcreteTrait>(); }
static ConcreteTrait dyn_cast(const Operation *op) {
if (op->HasTrait<ConcreteTrait>()) {
return ConcreteTrait(op);
}
return ConcreteTrait(nullptr);
}
};
///
......@@ -102,6 +109,14 @@ class OpInterfaceBase : public OpBase {
explicit OpInterfaceBase(const Operation *op) : OpBase(op) {}
static TypeId GetInterfaceId() { return TypeId::get<ConcreteInterface>(); }
static ConcreteInterface dyn_cast(const Operation *op) {
if (op->HasInterface<ConcreteInterface>()) {
return ConcreteInterface(
op, op->op_info().GetInterfaceImpl<ConcreteInterface>());
}
return ConcreteInterface(nullptr, nullptr);
}
};
template <typename ConcreteOp, typename... Args>
......@@ -168,6 +183,13 @@ class Op : public OpBase {
using InterfaceList =
typename Filter<OpInterfaceBase, std::tuple<TraitOrInterface...>>::Type;
static ConcreteOp dyn_cast(const Operation *op) {
if (op->op_info().id() == TypeId::get<ConcreteOp>()) {
return ConcreteOp(op);
}
return ConcreteOp(nullptr);
}
static std::vector<InterfaceValue> GetInterfaceMap() {
constexpr size_t interfaces_num = std::tuple_size<InterfaceList>::value;
std::vector<InterfaceValue> interfaces_map(interfaces_num);
......
......@@ -32,6 +32,8 @@ IrContext *OpInfo::ir_context() const {
const char *OpInfo::name() const { return impl_ ? impl_->name() : nullptr; }
TypeId OpInfo::id() const { return impl_ ? impl_->id() : TypeId(); }
void *OpInfo::GetInterfaceImpl(TypeId interface_id) const {
return impl_ ? impl_->interface_impl(interface_id) : nullptr;
}
......
......@@ -42,6 +42,8 @@ class OpInfo {
const char *name() const;
TypeId id() const;
template <typename Trait>
bool HasTrait() const {
return HasTrait(TypeId::get<Trait>());
......
......@@ -21,10 +21,7 @@
#include "paddle/ir/value_impl.h"
namespace ir {
template <class ConcreteTrait>
class OpTraitBase;
template <typename ConcreteInterface>
class OpInterfaceBase;
class OpBase;
class Program;
class alignas(8) Operation final {
......@@ -94,25 +91,15 @@ class alignas(8) Operation final {
template <typename T, typename Enabler = void>
struct CastUtil {
static T call(const Operation *op) {
throw("Can't dyn_cast to T, T should be a Trait or Interface");
}
};
template <typename T>
struct CastUtil<T,
typename std::enable_if<
std::is_base_of<OpTraitBase<T>, T>::value>::type> {
static T call(const Operation *op) {
return T(op->HasTrait<T>() ? op : nullptr);
throw("Can't dyn_cast to T, T should be a Op or Trait or Interface");
}
};
template <typename T>
struct CastUtil<T,
typename std::enable_if<
std::is_base_of<OpInterfaceBase<T>, T>::value>::type> {
static T call(const Operation *op) {
typename T::Concept *interface_impl = op->op_info().GetInterfaceImpl<T>();
return interface_impl ? T(op, interface_impl) : T(nullptr, nullptr);
}
struct CastUtil<
T,
typename std::enable_if<std::is_base_of<OpBase, T>::value>::type> {
static T call(const Operation *op) { return T::dyn_cast(op); }
};
AttributeMap attribute_;
......
......@@ -138,14 +138,13 @@ TEST(op_test, op_test) {
CreateAttributeMap("op1_name", "op1_attr"),
op2_info);
if (op->HasTrait<ReadOnlyTrait>()) {
ReadOnlyTrait trait = op->dyn_cast<ReadOnlyTrait>();
EXPECT_EQ(trait.operation(), op);
}
if (op->HasInterface<InferShapeInterface>()) {
InferShapeInterface interface = op->dyn_cast<InferShapeInterface>();
interface.InferShape();
}
ReadOnlyTrait trait = op->dyn_cast<ReadOnlyTrait>();
EXPECT_EQ(trait.operation(), op);
InferShapeInterface interface = op->dyn_cast<InferShapeInterface>();
interface.InferShape();
Operation2 Op2 = op->dyn_cast<Operation2>();
EXPECT_EQ(Op2.operation(), op);
op->destroy();
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册