diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index 46439c2d2775d9de1e2bcc0cddd27b75d8bc9eb6..af4079875a50ffe6eb627492f834fb601bbee716 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -79,6 +79,6 @@ cc_library(init SRCS init.cc DEPS gflags device_context place stringpiece operat cc_test(init_test SRCS init_test.cc DEPS init) cc_test(op_kernel_type_test SRCS op_kernel_type_test.cc DEPS place device_context framework_proto) - +cc_test(cow_ptr_tests SRCS details/cow_ptr_test.cc) nv_test(device_data_transform_test SRCS device_data_transform_test.cu DEPS operator op_registry init math_function) diff --git a/paddle/framework/details/cow_ptr.h b/paddle/framework/details/cow_ptr.h new file mode 100644 index 0000000000000000000000000000000000000000..7e308ffb5a49876aa2c1833b3b7e2a2c7eb137aa --- /dev/null +++ b/paddle/framework/details/cow_ptr.h @@ -0,0 +1,98 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once +#include +#include + +namespace paddle { +namespace framework { +namespace details { + +// Change it to thread safe flags if needed. +class ThreadUnsafeOwnershipFlags { + public: + ThreadUnsafeOwnershipFlags(bool flag) : flag_(flag) {} + + ThreadUnsafeOwnershipFlags(const ThreadUnsafeOwnershipFlags& other) = delete; + ThreadUnsafeOwnershipFlags& operator=( + const ThreadUnsafeOwnershipFlags& other) = delete; + ThreadUnsafeOwnershipFlags(ThreadUnsafeOwnershipFlags&& other) = default; + + void SetOwnership(bool flag) { flag_ = flag; } + + // Invoke the callback if it is not owned. + template + void AcquireOwnershipOnce(Callback acquire) { + if (!flag_) { + acquire(); + flag_ = true; + } + } + + private: + bool flag_; +}; + +// Copy-On-Write pointer. +// It will hold a T* pointer, and only copy once when `MutableData` is invoked. +// +// The template parameter OwnershipFlags should have: +// * a constructor takes a bool. True if own. +// * SetOwnership(bool flag). +// * AcquireOwnershipOnce(Callback). It will invoke the callback if it is not +// owned. +// +// https://en.wikipedia.org/wiki/Copy-on-write +template +class COWPtr { + public: + // Ctor from raw pointer. + explicit COWPtr(T* ptr) : payload_(ptr), ownership_{true} {} + + // Move methods. Steal ownership from origin + COWPtr(COWPtr&& other) + : payload_(other.payload_), ownership_{std::move(other.ownership_)} {} + COWPtr& operator=(COWPtr&& origin) = default; + + // Copy methods. Not own payload + COWPtr(const COWPtr& other) : payload_(other.payload_), ownership_{false} {} + COWPtr& operator=(const COWPtr& other) { + payload_ = other.payload_; + ownership_.SetOwnership(false); + return *this; + } + + // Access read only data. + const T& Data() const { return *payload_; } + + // Access mutable data. If the data is not owned, the data will be copied + // before. + T* MutableData() { + ownership_.AcquireOwnershipOnce( + [this] { payload_.reset(new T(*payload_)); }); + return payload_.get(); + } + + private: + // Actual data pointer. + std::shared_ptr payload_; + + // Ownership flag. + OwnershipFlags ownership_; +}; + +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/details/cow_ptr_test.cc b/paddle/framework/details/cow_ptr_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..936954a2333e7e5d2a932abad641279db9ef7b9f --- /dev/null +++ b/paddle/framework/details/cow_ptr_test.cc @@ -0,0 +1,35 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/framework/details/cow_ptr.h" +#include "gtest/gtest.h" + +namespace paddle { +namespace framework { +namespace details { + +TEST(COWPtr, all) { + COWPtr ptr(new int{0}); + ASSERT_EQ(ptr.Data(), 0); + COWPtr ptr2 = ptr; + ASSERT_EQ(ptr2.Data(), 0); + ASSERT_EQ(&ptr2.Data(), &ptr.Data()); + *ptr2.MutableData() = 10; + ASSERT_EQ(ptr.Data(), 0); + ASSERT_EQ(ptr2.Data(), 10); +} + +} // namespace details +} // namespace framework +} // namespace paddle