From 0cfb5465cdcb25e821ce690e57c77a68d7e6fc54 Mon Sep 17 00:00:00 2001 From: Yang Yu Date: Fri, 5 Jan 2018 16:40:45 +0800 Subject: [PATCH] Add COWPtr and its unittest It will be used for LoD information in LoDTensor since LoD is a copy on write field. It is pretty slow for copying LoD information between operators. For resnet it will cost roughly 10% time of whole time, including reading data. --- paddle/framework/CMakeLists.txt | 1 + paddle/framework/details/cow_ptr.h | 94 ++++++++++++++++++++++++ paddle/framework/details/cow_ptr_test.cc | 39 ++++++++++ 3 files changed, 134 insertions(+) create mode 100644 paddle/framework/details/cow_ptr.h create mode 100644 paddle/framework/details/cow_ptr_test.cc diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index fb8c9ab96d3..221186d780e 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -78,3 +78,4 @@ cc_library(init SRCS init.cc DEPS gflags device_context place stringpiece) cc_test(init_test SRCS init_test.cc DEPS init) cc_test(op_kernel_type_test SRCS op_kernel_type_test.cc DEPS place device_context framework_proto) +cc_test(cow_ptr_tests SRCS details/cow_ptr_test.cc) diff --git a/paddle/framework/details/cow_ptr.h b/paddle/framework/details/cow_ptr.h new file mode 100644 index 00000000000..6f1dcab40b5 --- /dev/null +++ b/paddle/framework/details/cow_ptr.h @@ -0,0 +1,94 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once +#include +#include + +namespace paddle { +namespace framework { +namespace details { + +// Change it to thread safe flags if needed. +class ThreadUnsafeOwnershipFlags { + public: + ThreadUnsafeOwnershipFlags(bool flag) : flag_(flag) {} + + ThreadUnsafeOwnershipFlags(const ThreadUnsafeOwnershipFlags& o) = delete; + ThreadUnsafeOwnershipFlags& operator=(const ThreadUnsafeOwnershipFlags& o) = + delete; + ThreadUnsafeOwnershipFlags(ThreadUnsafeOwnershipFlags&& o) = default; + + void SetOwnership(bool flag) { flag_ = flag; } + + template + void AcquireOwnershipOnce(Callback acquire) { + if (!flag_) { + acquire(); + flag_ = true; + } + } + + private: + bool flag_; +}; + +// Copy On Write pointer. +// It will hold a T* pointer, and only copy once when `MutableData` is invoked. +// +// The template parameter OwnershipFlags should have: +// * a constructor takes a bool. True if own. +// * SetOwnership(bool flag). +// * AcquireOwnershipOnce(Callback). It will invoke the callback if it is not +// owned. +template +class COWPtr { + public: + // Ctor from raw pointer. + explicit COWPtr(T* ptr) : payload_(ptr), ownership_{true} {} + + // Move methods. Steal ownership from origin + COWPtr(COWPtr&& o) + : payload_(o.payload_), ownership_{std::move(o.ownership_)} {} + COWPtr& operator=(COWPtr&& origin) = default; + + // Copy methods. Not own payload + COWPtr(const COWPtr& o) : payload_(o.payload_), ownership_{false} {} + COWPtr& operator=(const COWPtr& o) { + payload_ = o.payload_; + ownership_.SetOwnership(false); + return *this; + } + + const T& Data() const { return *payload_; } + + T* MutableData() { + ownership_.AcquireOwnershipOnce( + [this] { payload_.reset(new T(*payload_)); }); + return payload_.get(); + } + + void Reset() { + ownership_.AcquireOwnershipOnce([this] { payload_.reset(); }); + payload_.reset(new T()); + } + + private: + std::shared_ptr payload_; + OwnershipFlags ownership_; +}; + +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/details/cow_ptr_test.cc b/paddle/framework/details/cow_ptr_test.cc new file mode 100644 index 00000000000..080a0a0a448 --- /dev/null +++ b/paddle/framework/details/cow_ptr_test.cc @@ -0,0 +1,39 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/framework/details/cow_ptr.h" +#include "gtest/gtest.h" + +namespace paddle { +namespace framework { +namespace details { + +TEST(COWPtr, all) { + COWPtr ptr(new int{0}); + ASSERT_EQ(ptr.Data(), 0); + COWPtr ptr2 = ptr; + ASSERT_EQ(ptr2.Data(), 0); + ASSERT_EQ(&ptr2.Data(), &ptr.Data()); + *ptr2.MutableData() = 10; + ASSERT_EQ(ptr.Data(), 0); + ASSERT_EQ(ptr2.Data(), 10); + + auto ptr_before = ptr2.MutableData(); + ptr2.Reset(); + ASSERT_NE(ptr2.MutableData(), ptr_before); +} + +} // namespace details +} // namespace framework +} // namespace paddle -- GitLab