diff --git a/paddle/pten/api/lib/utils/allocator.h b/paddle/pten/api/lib/utils/allocator.h index 8a8569c73edaeaced7bd1569922ae695424ba8e8..ab7a0fcef289f59abd7f8c1429432ca719c4e7bf 100644 --- a/paddle/pten/api/lib/utils/allocator.h +++ b/paddle/pten/api/lib/utils/allocator.h @@ -28,8 +28,8 @@ class DefaultAllocator : public pten::Allocator { explicit DefaultAllocator(const paddle::platform::Place& place) : place_(place) {} - static void Delete(void* data) { - deleter_(static_cast(data)); + static void Delete(Allocation* allocation) { + deleter_(allocation->CastContextWithoutCheck()); } Allocation Allocate(size_t bytes_size) override { diff --git a/paddle/pten/core/allocator.h b/paddle/pten/core/allocator.h index 9c6f749609a48fdf0125bb48ac820c5da8bc15b4..f03c591b1db2061dc19ed7c40d6a45325e134309 100644 --- a/paddle/pten/core/allocator.h +++ b/paddle/pten/core/allocator.h @@ -55,29 +55,47 @@ class RawAllocator { class Allocation final { public: using Place = paddle::platform::Place; - using DeleterFnPtr = void (*)(void*); + using DeleterFnPtr = void (*)(Allocation*); Allocation() = default; - Allocation(Allocation&&) = default; - Allocation& operator=(Allocation&&) = default; + // Don't own resources, only provide access. Allocation(void* data, const Place& place) : data_(data), place_(place) {} - Allocation(void* data, - void* ctx, - DeleterFnPtr ctx_deleter, - const Place& place) - : data_(data), ctx_(ctx, ctx_deleter), place_(place) {} + // Own resources. + Allocation(void* data, void* ctx, DeleterFnPtr deleter, const Place& place) + : data_(data), ctx_(ctx), deleter_(deleter), place_(place) {} + Allocation(Allocation&& other) { swap(*this, other); } + Allocation& operator=(Allocation&& other) { + // Exchange them explicitly to avoid moving is equivalent + // to copying. + swap(*this, other); + return *this; + } + ~Allocation() { Clear(); } + + void* ptr() const noexcept { return data_; } void* operator->() const noexcept { return data_; } - operator bool() const noexcept { return data_ || ctx_.Get(); } + operator bool() const noexcept { return data_ || ctx_; } const Place& place() const noexcept { return place_; } void Clear() { - ctx_.Clear(); + if (deleter_) { + deleter_(this); + } + ctx_ = nullptr; + deleter_ = nullptr; data_ = nullptr; } + DeleterFnPtr deleter() const noexcept { return deleter_; } + + template + T* CastContextWithoutCheck() const noexcept { + return static_cast(ctx_); + } + /// \brief Statically cast the void pointer of the context object to /// the primitive type. Conversion of any pointer to void* and back /// to pointer to the original cv type preserves its original value. @@ -85,60 +103,31 @@ class Allocation final { /// \param expected_deleter The destructor passed in to enhance type /// safety checking. template - T* CastContext(DeleterFnPtr expected_deleter) const noexcept { - if (ctx_.deleter() != expected_deleter) { - return nullptr; - } - return static_cast(ctx_.Get()); + T* CastContext(DeleterFnPtr expected_deleter) const { + PADDLE_ENFORCE_EQ( + deleter_ == expected_deleter, + true, + paddle::platform::errors::InvalidArgument( + "The deleter of the allocation does not match, so the pointer " + "cannot be safely removed.")); + return CastContextWithoutCheck(); } - public: - class Context { - public: - Context() = default; - Context(void* ctx, DeleterFnPtr deleter) noexcept : ctx_(ctx), - deleter_(deleter) {} - Context(Context&& other) noexcept { - // Exchange them explicitly to avoid moving is equivalent - // to copying. - swap(*this, other); - } - Context& operator=(Context&& other) noexcept { - swap(*this, other); - return *this; - } - ~Context() { Clear(); } - void Clear() { - if (deleter_) { - deleter_(ctx_); - } - ctx_ = nullptr; - deleter_ = nullptr; - } - void* Get() const noexcept { return ctx_; } - DeleterFnPtr deleter() const noexcept { return deleter_; } - void* Release() noexcept { - deleter_ = nullptr; - return ctx_; - } - friend void swap(Context& a, Context& b) noexcept; - - private: - void* ctx_{nullptr}; - DeleterFnPtr deleter_{nullptr}; - }; - private: + friend void swap(Allocation& a, Allocation& b) noexcept; void* data_{nullptr}; - Context ctx_; + void* ctx_{nullptr}; + DeleterFnPtr deleter_{nullptr}; // TODO(Shixiaowei02): Enum needs to be used instead to reduce // the construction overhead by more than 50%. Place place_; }; -inline void swap(Allocation::Context& a, Allocation::Context& b) noexcept { +inline void swap(Allocation& a, Allocation& b) noexcept { + ::std::swap(a.data_, b.data_); ::std::swap(a.ctx_, b.ctx_); ::std::swap(a.deleter_, b.deleter_); + ::std::swap(a.place_, b.place_); } /// \brief Context compatible allocator interface. This allocator is diff --git a/paddle/pten/tests/core/allocator.h b/paddle/pten/tests/core/allocator.h index 053e8ba7b382b6527d2148c45ab928e92f1a3f8e..4a5a9b1690923d65f199de94ed7284284a7db15b 100644 --- a/paddle/pten/tests/core/allocator.h +++ b/paddle/pten/tests/core/allocator.h @@ -38,7 +38,9 @@ class HostAllocatorSample : public pten::RawAllocator { class FancyAllocator : public pten::Allocator { public: - static void Delete(void* data) { ::operator delete(data); } + static void Delete(Allocation* allocation) { + ::operator delete(allocation->ptr()); + } Allocation Allocate(size_t bytes_size) override { void* data = ::operator new(bytes_size);