diff --git a/paddle/framework/variable.h b/paddle/framework/variable.h index 249b9b1b37f923188f505b0a6ec19de29e009619..b21c95a1a6df513d4839b91fb475f1517817098e 100644 --- a/paddle/framework/variable.h +++ b/paddle/framework/variable.h @@ -12,6 +12,8 @@ */ #pragma once +#include +#include #include namespace paddle { @@ -26,24 +28,14 @@ class Variable { template T* GetMutable() { - if (holder_ != nullptr && typeid(T) == holder_->Type()) { - return static_cast(holder_->Ptr()); - } else { - return Reset(new T(), DefaultDeleter()); + if (holder_ == nullptr || + std::type_index(typeid(T)) != std::type_index(holder_->Type())) { + holder_.reset(new PlaceholderImpl(new T())); } - } - - ~Variable() { - if (holder_ != nullptr) delete holder_; + return static_cast(holder_->Ptr()); } private: - // DefaultDeleter is functor which uses C++'s delete(T*). - template - struct DefaultDeleter { - void operator()(T* ptr) { delete ptr; } - }; - struct Placeholder { virtual ~Placeholder() {} virtual const std::type_info& Type() const = 0; @@ -54,34 +46,17 @@ class Variable { // parameter of Variable. template struct PlaceholderImpl : public Placeholder { - typedef std::function Deleter; - PlaceholderImpl(T* ptr) : ptr_(ptr), type_(typeid(T)) {} - PlaceholderImpl(T* ptr, Deleter d) - : ptr_(ptr), type_(typeid(T)), deleter_(d) {} - virtual ~PlaceholderImpl() { - deleter_(ptr_); - ptr_ = nullptr; - } virtual const std::type_info& Type() const { return type_; } - virtual void* Ptr() const { return ptr_; } + virtual void* Ptr() const { return static_cast(ptr_.get()); } - T* ptr_ = nullptr; + std::unique_ptr ptr_; const std::type_info& type_; - std::function deleter_ = DefaultDeleter(); }; - template - T* Reset(T* allocated, typename PlaceholderImpl::Deleter deleter) { - if (holder_ != nullptr) { - delete holder_; - } - holder_ = new PlaceholderImpl(allocated, deleter); - return allocated; - } - - Placeholder* holder_; // pointers to a PlaceholderImpl object indeed. + std::unique_ptr + holder_; // pointers to a PlaceholderImpl object indeed. }; } // namespace framework