diff --git a/paddle/framework/details/cow_ptr.h b/paddle/framework/details/cow_ptr.h index 6f1dcab40b5bfe59c4b3c8f789228942475e0129..7e308ffb5a49876aa2c1833b3b7e2a2c7eb137aa 100644 --- a/paddle/framework/details/cow_ptr.h +++ b/paddle/framework/details/cow_ptr.h @@ -25,13 +25,14 @@ class ThreadUnsafeOwnershipFlags { public: ThreadUnsafeOwnershipFlags(bool flag) : flag_(flag) {} - ThreadUnsafeOwnershipFlags(const ThreadUnsafeOwnershipFlags& o) = delete; - ThreadUnsafeOwnershipFlags& operator=(const ThreadUnsafeOwnershipFlags& o) = - delete; - ThreadUnsafeOwnershipFlags(ThreadUnsafeOwnershipFlags&& o) = default; + ThreadUnsafeOwnershipFlags(const ThreadUnsafeOwnershipFlags& other) = delete; + ThreadUnsafeOwnershipFlags& operator=( + const ThreadUnsafeOwnershipFlags& other) = delete; + ThreadUnsafeOwnershipFlags(ThreadUnsafeOwnershipFlags&& other) = default; void SetOwnership(bool flag) { flag_ = flag; } + // Invoke the callback if it is not owned. template void AcquireOwnershipOnce(Callback acquire) { if (!flag_) { @@ -44,7 +45,7 @@ class ThreadUnsafeOwnershipFlags { bool flag_; }; -// Copy On Write pointer. +// Copy-On-Write pointer. // It will hold a T* pointer, and only copy once when `MutableData` is invoked. // // The template parameter OwnershipFlags should have: @@ -52,6 +53,8 @@ class ThreadUnsafeOwnershipFlags { // * SetOwnership(bool flag). // * AcquireOwnershipOnce(Callback). It will invoke the callback if it is not // owned. +// +// https://en.wikipedia.org/wiki/Copy-on-write template class COWPtr { public: @@ -59,33 +62,34 @@ class COWPtr { explicit COWPtr(T* ptr) : payload_(ptr), ownership_{true} {} // Move methods. Steal ownership from origin - COWPtr(COWPtr&& o) - : payload_(o.payload_), ownership_{std::move(o.ownership_)} {} + COWPtr(COWPtr&& other) + : payload_(other.payload_), ownership_{std::move(other.ownership_)} {} COWPtr& operator=(COWPtr&& origin) = default; // Copy methods. Not own payload - COWPtr(const COWPtr& o) : payload_(o.payload_), ownership_{false} {} - COWPtr& operator=(const COWPtr& o) { - payload_ = o.payload_; + COWPtr(const COWPtr& other) : payload_(other.payload_), ownership_{false} {} + COWPtr& operator=(const COWPtr& other) { + payload_ = other.payload_; ownership_.SetOwnership(false); return *this; } + // Access read only data. const T& Data() const { return *payload_; } + // Access mutable data. If the data is not owned, the data will be copied + // before. T* MutableData() { ownership_.AcquireOwnershipOnce( [this] { payload_.reset(new T(*payload_)); }); return payload_.get(); } - void Reset() { - ownership_.AcquireOwnershipOnce([this] { payload_.reset(); }); - payload_.reset(new T()); - } - private: + // Actual data pointer. std::shared_ptr payload_; + + // Ownership flag. OwnershipFlags ownership_; }; diff --git a/paddle/framework/details/cow_ptr_test.cc b/paddle/framework/details/cow_ptr_test.cc index 080a0a0a448c16c0eb6e8ca63c006707fd177374..936954a2333e7e5d2a932abad641279db9ef7b9f 100644 --- a/paddle/framework/details/cow_ptr_test.cc +++ b/paddle/framework/details/cow_ptr_test.cc @@ -28,10 +28,6 @@ TEST(COWPtr, all) { *ptr2.MutableData() = 10; ASSERT_EQ(ptr.Data(), 0); ASSERT_EQ(ptr2.Data(), 10); - - auto ptr_before = ptr2.MutableData(); - ptr2.Reset(); - ASSERT_NE(ptr2.MutableData(), ptr_before); } } // namespace details