diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index fb8c9ab96d372bde1fb4e1d86488cd5b831b93e0..221186d780ec907d10dc3d9025d1abe969be6eb9 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -78,3 +78,4 @@ cc_library(init SRCS init.cc DEPS gflags device_context place stringpiece) cc_test(init_test SRCS init_test.cc DEPS init) cc_test(op_kernel_type_test SRCS op_kernel_type_test.cc DEPS place device_context framework_proto) +cc_test(cow_ptr_tests SRCS details/cow_ptr_test.cc) diff --git a/paddle/framework/details/cow_ptr.h b/paddle/framework/details/cow_ptr.h new file mode 100644 index 0000000000000000000000000000000000000000..6f1dcab40b5bfe59c4b3c8f789228942475e0129 --- /dev/null +++ b/paddle/framework/details/cow_ptr.h @@ -0,0 +1,94 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once +#include +#include + +namespace paddle { +namespace framework { +namespace details { + +// Change it to thread safe flags if needed. +class ThreadUnsafeOwnershipFlags { + public: + ThreadUnsafeOwnershipFlags(bool flag) : flag_(flag) {} + + ThreadUnsafeOwnershipFlags(const ThreadUnsafeOwnershipFlags& o) = delete; + ThreadUnsafeOwnershipFlags& operator=(const ThreadUnsafeOwnershipFlags& o) = + delete; + ThreadUnsafeOwnershipFlags(ThreadUnsafeOwnershipFlags&& o) = default; + + void SetOwnership(bool flag) { flag_ = flag; } + + template + void AcquireOwnershipOnce(Callback acquire) { + if (!flag_) { + acquire(); + flag_ = true; + } + } + + private: + bool flag_; +}; + +// Copy On Write pointer. +// It will hold a T* pointer, and only copy once when `MutableData` is invoked. +// +// The template parameter OwnershipFlags should have: +// * a constructor takes a bool. True if own. +// * SetOwnership(bool flag). +// * AcquireOwnershipOnce(Callback). It will invoke the callback if it is not +// owned. +template +class COWPtr { + public: + // Ctor from raw pointer. + explicit COWPtr(T* ptr) : payload_(ptr), ownership_{true} {} + + // Move methods. Steal ownership from origin + COWPtr(COWPtr&& o) + : payload_(o.payload_), ownership_{std::move(o.ownership_)} {} + COWPtr& operator=(COWPtr&& origin) = default; + + // Copy methods. Not own payload + COWPtr(const COWPtr& o) : payload_(o.payload_), ownership_{false} {} + COWPtr& operator=(const COWPtr& o) { + payload_ = o.payload_; + ownership_.SetOwnership(false); + return *this; + } + + const T& Data() const { return *payload_; } + + T* MutableData() { + ownership_.AcquireOwnershipOnce( + [this] { payload_.reset(new T(*payload_)); }); + return payload_.get(); + } + + void Reset() { + ownership_.AcquireOwnershipOnce([this] { payload_.reset(); }); + payload_.reset(new T()); + } + + private: + std::shared_ptr payload_; + OwnershipFlags ownership_; +}; + +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/details/cow_ptr_test.cc b/paddle/framework/details/cow_ptr_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..080a0a0a448c16c0eb6e8ca63c006707fd177374 --- /dev/null +++ b/paddle/framework/details/cow_ptr_test.cc @@ -0,0 +1,39 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/framework/details/cow_ptr.h" +#include "gtest/gtest.h" + +namespace paddle { +namespace framework { +namespace details { + +TEST(COWPtr, all) { + COWPtr ptr(new int{0}); + ASSERT_EQ(ptr.Data(), 0); + COWPtr ptr2 = ptr; + ASSERT_EQ(ptr2.Data(), 0); + ASSERT_EQ(&ptr2.Data(), &ptr.Data()); + *ptr2.MutableData() = 10; + ASSERT_EQ(ptr.Data(), 0); + ASSERT_EQ(ptr2.Data(), 10); + + auto ptr_before = ptr2.MutableData(); + ptr2.Reset(); + ASSERT_NE(ptr2.MutableData(), ptr_before); +} + +} // namespace details +} // namespace framework +} // namespace paddle