diff --git a/paddle/framework/variable.h b/paddle/framework/variable.h index b33e10e6820129a874f5355d14d8a3e990186025..72c4a7a2a1d1cf93a784f24e687727ee8481484c 100644 --- a/paddle/framework/variable.h +++ b/paddle/framework/variable.h @@ -25,21 +25,24 @@ class Variable { public: template const T& Get() const { - PADDLE_ASSERT(holder_ != nullptr); - PADDLE_ASSERT(std::type_index(typeid(T)) == - std::type_index(holder_->Type())); + PADDLE_ASSERT(IsType()); return *static_cast(holder_->Ptr()); } template T* GetMutable() { - if (holder_ == nullptr || - std::type_index(typeid(T)) != std::type_index(holder_->Type())) { + if (!IsType()) { holder_.reset(new PlaceholderImpl(new T())); } return static_cast(holder_->Ptr()); } + template + bool IsType() const { + return holder_ != nullptr && + std::type_index(typeid(T)) == std::type_index(holder_->Type()); + } + private: struct Placeholder { virtual ~Placeholder() {}