From 09696d5df86e6163c7becb2922bb42b04f04066a Mon Sep 17 00:00:00 2001
From: Zeng Jinle <32832641+sneaxiy@users.noreply.github.com>
Date: Thu, 28 Nov 2019 20:27:15 +0800
Subject: [PATCH] Use system allocator in OpTest (#21335)

* use system allocator in unittests, test=develop

* fix op bugs, test=develop

* fix tensor copy bug when src and dst are the same, test=develop
---
 paddle/fluid/framework/CMakeLists.txt         |  2 +
 .../fluid/framework/copy_same_tensor_test.cc  | 96 +++++++++++++++++++
 paddle/fluid/framework/tensor_util.cc         | 13 ++-
 .../memory/allocation/allocator_facade.cc     | 46 +++++++--
 .../operators/detection/yolov3_loss_op.h      |  4 +-
 .../fluid/operators/fused/fusion_lstm_op.cc   |  2 +-
 paddle/fluid/operators/lod_reset_op.cc        |  8 +-
 paddle/fluid/operators/uniform_random_op.h    |  4 +-
 paddle/fluid/operators/utils.h                |  4 +-
 .../pybind/global_value_getter_setter.cc      | 11 ++-
 python/paddle/fluid/__init__.py               |  2 +-
 .../paddle/fluid/tests/unittests/op_test.py   | 12 +++
 12 files changed, 179 insertions(+), 25 deletions(-)
 create mode 100644 paddle/fluid/framework/copy_same_tensor_test.cc

diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt
index 79af44edfe..c5d8f92ad5 100644
--- a/paddle/fluid/framework/CMakeLists.txt
+++ b/paddle/fluid/framework/CMakeLists.txt
@@ -56,6 +56,8 @@ else()
   cc_test(tensor_util_test SRCS tensor_util_test.cc DEPS tensor dlpack_tensor)
 endif()
 
+cc_test(copy_same_tensor_test SRCS copy_same_tensor_test.cc DEPS tensor)
+
 cc_test(eigen_test SRCS eigen_test.cc DEPS tensor)
 
 if(WITH_GPU)
diff --git a/paddle/fluid/framework/copy_same_tensor_test.cc b/paddle/fluid/framework/copy_same_tensor_test.cc
new file mode 100644
index 0000000000..9350c387a6
--- /dev/null
+++ b/paddle/fluid/framework/copy_same_tensor_test.cc
@@ -0,0 +1,96 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <cstring>
+#include <random>
+#include "gflags/gflags.h"
+#include "gtest/gtest.h"
+#include "paddle/fluid/framework/tensor.h"
+#include "paddle/fluid/framework/tensor_util.h"
+
+DECLARE_bool(use_system_allocator);
+
+namespace paddle {
+namespace framework {
+
+static std::vector<platform::Place> CreatePlaceList() {
+  std::vector<platform::Place> places;
+  places.emplace_back(platform::CPUPlace());
+#ifdef PADDLE_WITH_CUDA
+  places.emplace_back(platform::CUDAPlace(0));
+#endif
+  return places;
+}
+
+template <typename T>
+static bool CopySameTensorTestMain(const DDim &dims,
+                                   const platform::Place &src_place,
+                                   const platform::Place &dst_place,
+                                   bool sync_copy) {
+  FLAGS_use_system_allocator = true;  // force to use system allocator
+
+  // Step 1: create a cpu tensor and initialize it with random value;
+  Tensor src_cpu_tensor;
+  {
+    src_cpu_tensor.Resize(dims);
+    auto *src_ptr_cpu = src_cpu_tensor.mutable_data<T>(platform::CPUPlace());
+    int64_t num = src_cpu_tensor.numel();
+    std::random_device rd;
+    std::mt19937 gen(rd());
+    std::uniform_real_distribution<T> dist(-1000, 1000);
+    for (int64_t i = 0; i < num; ++i) {
+      src_ptr_cpu[i] = dist(gen);
+    }
+  }
+
+  // Step 2: copy the source tensor to dst place
+  Tensor dst_cpu_tensor;
+  {
+    Tensor src_tensor;
+    TensorCopySync(src_cpu_tensor, src_place, &src_tensor);
+
+    // The source tensor and dst_tensor is the same
+    if (sync_copy) {
+      TensorCopySync(src_tensor, dst_place, &src_tensor);
+    } else {
+      TensorCopy(src_tensor, dst_place, &src_tensor);
+      platform::DeviceContextPool::Instance().Get(src_place)->Wait();
+      platform::DeviceContextPool::Instance().Get(dst_place)->Wait();
+    }
+
+    // Get the result cpu tensor
+    TensorCopySync(src_tensor, platform::CPUPlace(), &dst_cpu_tensor);
+  }
+
+  const void *ground_truth_ptr = src_cpu_tensor.data<void>();
+  const void *result_ptr = dst_cpu_tensor.data<void>();
+  size_t byte_num = product(dims) * sizeof(T);
+  return std::memcmp(ground_truth_ptr, result_ptr, byte_num) == 0;
+}
+
+TEST(test_tensor_copy, test_copy_same_tensor) {
+  using DataType = float;
+  auto dims = make_ddim({3, 4, 5});
+
+  auto places = CreatePlaceList();
+  for (auto &src_p : places) {
+    for (auto &dst_p : places) {
+      ASSERT_TRUE(CopySameTensorTestMain<DataType>(dims, src_p, dst_p, true));
+      ASSERT_TRUE(CopySameTensorTestMain<DataType>(dims, src_p, dst_p, false));
+    }
+  }
+}
+
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/framework/tensor_util.cc b/paddle/fluid/framework/tensor_util.cc
index 16f0a4c6ff..2fdac0dc0e 100644
--- a/paddle/fluid/framework/tensor_util.cc
+++ b/paddle/fluid/framework/tensor_util.cc
@@ -25,6 +25,12 @@ namespace framework {
 
 void TensorCopy(const Tensor& src, const platform::Place& dst_place,
                 const platform::DeviceContext& ctx, Tensor* dst) {
+  if (&src == dst) {
+    auto src_copy = src;
+    TensorCopy(src_copy, dst_place, ctx, dst);
+    return;
+  }
+
   VLOG(3) << "TensorCopy " << src.dims() << " from " << src.place() << " to "
           << dst_place;
   src.check_memory_size();
@@ -33,7 +39,6 @@ void TensorCopy(const Tensor& src, const platform::Place& dst_place,
   dst->set_layout(src.layout());
   auto src_place = src.place();
   auto src_ptr = src.data<void>();
-
   auto dst_ptr = dst->mutable_data(dst_place, src.type());
 
   if (src_ptr == dst_ptr && src_place == dst_place) {
@@ -115,6 +120,12 @@ void TensorCopy(const Tensor& src, const platform::Place& dst_place,
 
 void TensorCopySync(const Tensor& src, const platform::Place& dst_place,
                     Tensor* dst) {
+  if (&src == dst) {
+    auto src_copy = src;
+    TensorCopySync(src_copy, dst_place, dst);
+    return;
+  }
+
   VLOG(3) << "TensorCopySync " << src.dims() << " from " << src.place()
           << " to " << dst_place;
   src.check_memory_size();
diff --git a/paddle/fluid/memory/allocation/allocator_facade.cc b/paddle/fluid/memory/allocation/allocator_facade.cc
index 220b50b1cc..63763acb64 100644
--- a/paddle/fluid/memory/allocation/allocator_facade.cc
+++ b/paddle/fluid/memory/allocation/allocator_facade.cc
@@ -41,12 +41,18 @@ DEFINE_int64(
     "The retry time (milliseconds) when allocator fails "
     "to allocate memory. No retry if this value is not greater than 0");
 
+DEFINE_bool(use_system_allocator, false,
+            "Whether to use system allocator to allocate CPU and GPU memory. "
+            "Only used for unittests.");
+
 namespace paddle {
 namespace memory {
 namespace allocation {
 
 class AllocatorFacadePrivate {
  public:
+  using AllocatorMap = std::map<platform::Place, std::shared_ptr<Allocator>>;
+
   AllocatorFacadePrivate() {
     auto strategy = GetAllocatorStrategy();
     switch (strategy) {
@@ -80,6 +86,7 @@ class AllocatorFacadePrivate {
       }
     }
     InitZeroSizeAllocators();
+    InitSystemAllocators();
 
     if (FLAGS_gpu_allocator_retry_time > 0) {
       WrapCUDARetryAllocator(FLAGS_gpu_allocator_retry_time);
@@ -90,7 +97,10 @@ class AllocatorFacadePrivate {
 
   inline const std::shared_ptr<Allocator>& GetAllocator(
      const platform::Place& place, size_t size) {
-    const auto& allocators = (size > 0 ? allocators_ : zero_size_allocators_);
+    const auto& allocators =
+        (size > 0 ? (UNLIKELY(FLAGS_use_system_allocator) ? system_allocators_
+                                                          : allocators_)
+                  : zero_size_allocators_);
     auto iter = allocators.find(place);
     PADDLE_ENFORCE(iter != allocators.end(),
                    "No such allocator for the place, %s", place);
@@ -98,6 +108,19 @@ class AllocatorFacadePrivate {
   }
 
  private:
+  void InitSystemAllocators() {
+    system_allocators_[platform::CPUPlace()] = std::make_shared<CPUAllocator>();
+#ifdef PADDLE_WITH_CUDA
+    system_allocators_[platform::CUDAPinnedPlace()] =
+        std::make_shared<CPUPinnedAllocator>();
+    int device_count = platform::GetCUDADeviceCount();
+    for (int i = 0; i < device_count; ++i) {
+      platform::CUDAPlace p(i);
+      system_allocators_[p] = std::make_shared<CUDAAllocator>(p);
+    }
+#endif
+  }
+
   void InitNaiveBestFitCPUAllocator() {
     allocators_[platform::CPUPlace()] =
         std::make_shared<NaiveBestFitAllocator>(platform::CPUPlace());
@@ -153,14 +176,18 @@ class AllocatorFacadePrivate {
     }
   }
 
-  void CheckAllocThreadSafe() const {
-    for (auto& pair : allocators_) {
-      PADDLE_ENFORCE_EQ(pair.second->IsAllocThreadSafe(), true);
+  static void CheckAllocThreadSafe(const AllocatorMap& allocators) {
+    for (auto& pair : allocators) {
+      PADDLE_ENFORCE_EQ(pair.second->IsAllocThreadSafe(), true,
+                        platform::errors::InvalidArgument(
+                            "Public allocators must be thread safe"));
     }
+  }
 
-    for (auto& pair : zero_size_allocators_) {
-      PADDLE_ENFORCE_EQ(pair.second->IsAllocThreadSafe(), true);
-    }
+  void CheckAllocThreadSafe() const {
+    CheckAllocThreadSafe(allocators_);
+    CheckAllocThreadSafe(zero_size_allocators_);
+    CheckAllocThreadSafe(system_allocators_);
   }
 
   void WrapCUDARetryAllocator(size_t retry_time) {
@@ -173,8 +200,9 @@ class AllocatorFacadePrivate {
   }
 
  private:
-  std::map<platform::Place, std::shared_ptr<Allocator>> allocators_;
-  std::map<platform::Place, std::shared_ptr<Allocator>> zero_size_allocators_;
+  AllocatorMap allocators_;
+  AllocatorMap zero_size_allocators_;
+  AllocatorMap system_allocators_;
 };
 
 // Pimpl. Make interface clean.
diff --git a/paddle/fluid/operators/detection/yolov3_loss_op.h b/paddle/fluid/operators/detection/yolov3_loss_op.h
index f8d49960c7..b29ad1c920 100644
--- a/paddle/fluid/operators/detection/yolov3_loss_op.h
+++ b/paddle/fluid/operators/detection/yolov3_loss_op.h
@@ -299,8 +299,8 @@ class Yolov3LossKernel : public framework::OpKernel<T> {
     gt_match_mask->mutable_data<int>({n, b}, ctx.GetPlace());
 
     const T* gt_score_data;
+    Tensor gtscore;
     if (!gt_score) {
-      Tensor gtscore;
       gtscore.mutable_data<T>({n, b}, ctx.GetPlace());
       math::SetConstant<platform::CPUDeviceContext, T>()(
           ctx.template device_context<platform::CPUDeviceContext>(), &gtscore,
@@ -454,8 +454,8 @@ class Yolov3LossGradKernel : public framework::OpKernel<T> {
     memset(input_grad_data, 0, input_grad->numel() * sizeof(T));
 
     const T* gt_score_data;
+    Tensor gtscore;
     if (!gt_score) {
-      Tensor gtscore;
       gtscore.mutable_data<T>({n, b}, ctx.GetPlace());
       math::SetConstant<platform::CPUDeviceContext, T>()(
          ctx.template device_context<platform::CPUDeviceContext>(), &gtscore,
diff --git a/paddle/fluid/operators/fused/fusion_lstm_op.cc b/paddle/fluid/operators/fused/fusion_lstm_op.cc
index c256e581ee..1cc5df8600 100644
--- a/paddle/fluid/operators/fused/fusion_lstm_op.cc
+++ b/paddle/fluid/operators/fused/fusion_lstm_op.cc
@@ -389,7 +389,7 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
       const T* c0_data = c0->data<T>();
       prev_h_data = reordered_h0_data;
       prev_c_data = reordered_c0_data;
-      size_t sz = sizeof(T) * D;
+      size_t sz = D;
       for (int i = 0; i < max_bs; ++i) {
         blas.VCOPY(sz, h0_data + seq_order[i] * D, reordered_h0_data);
         blas.VCOPY(sz, c0_data + seq_order[i] * D, reordered_c0_data);
diff --git a/paddle/fluid/operators/lod_reset_op.cc b/paddle/fluid/operators/lod_reset_op.cc
index c7cd230a45..e003c7f150 100644
--- a/paddle/fluid/operators/lod_reset_op.cc
+++ b/paddle/fluid/operators/lod_reset_op.cc
@@ -205,8 +205,8 @@ class LoDResetGradMaker : public framework::SingleGradOpMaker<T> {
   }
 };
 
-DECLARE_INPLACE_OP_INFERER(LodResetInplaceInferer, {"X", "Out"});
-DECLARE_INPLACE_OP_INFERER(LodResetGradInplaceInferer,
+DECLARE_INPLACE_OP_INFERER(LoDResetInplaceInferer, {"X", "Out"});
+DECLARE_INPLACE_OP_INFERER(LoDResetGradInplaceInferer,
                            {framework::GradVarName("Out"),
                             framework::GradVarName("X")});
 
@@ -220,10 +220,10 @@ namespace ops = paddle::operators;
 REGISTER_OPERATOR(lod_reset, ops::LoDResetOp, ops::LoDResetOpMaker,
                   ops::LoDResetGradMaker<paddle::framework::OpDesc>,
                   ops::LoDResetGradMaker<paddle::imperative::OpBase>,
-                  ops::LoDResetOpVarTypeInference, ops::LodResetInplaceInferer);
+                  ops::LoDResetOpVarTypeInference, ops::LoDResetInplaceInferer);
 REGISTER_OPERATOR(lod_reset_grad, ops::LoDResetGradOp,
                   ops::LoDResetGradNoNeedBufferVarInference,
-                  ops::LodResetGradInplaceInferer);
+                  ops::LoDResetGradInplaceInferer);
 
 REGISTER_OP_CPU_KERNEL(
     lod_reset, ops::LoDResetKernel<paddle::platform::CPUDeviceContext, float>,
diff --git a/paddle/fluid/operators/uniform_random_op.h b/paddle/fluid/operators/uniform_random_op.h
index 32e3b034fd..649437aded 100644
--- a/paddle/fluid/operators/uniform_random_op.h
+++ b/paddle/fluid/operators/uniform_random_op.h
@@ -26,8 +26,8 @@ inline std::vector<int64_t> GetNewDataFromShapeTensor(
     const Tensor *new_data_tensor) {
   if (new_data_tensor->type() == framework::proto::VarType::INT64) {
     auto *new_data = new_data_tensor->data<int64_t>();
+    framework::Tensor cpu_starts_tensor;
     if (platform::is_gpu_place(new_data_tensor->place())) {
-      framework::Tensor cpu_starts_tensor;
       TensorCopySync(*new_data_tensor, platform::CPUPlace(),
                      &cpu_starts_tensor);
       new_data = cpu_starts_tensor.data<int64_t>();
@@ -38,8 +38,8 @@ inline std::vector<int64_t> GetNewDataFromShapeTensor(
   } else if (new_data_tensor->type() == framework::proto::VarType::INT32) {
     auto *new_data = new_data_tensor->data<int>();
     std::vector<int64_t> vec_new_data;
+    framework::Tensor cpu_starts_tensor;
     if (platform::is_gpu_place(new_data_tensor->place())) {
-      framework::Tensor cpu_starts_tensor;
       TensorCopySync(*new_data_tensor, platform::CPUPlace(),
                      &cpu_starts_tensor);
       new_data = cpu_starts_tensor.data<int>();
diff --git a/paddle/fluid/operators/utils.h b/paddle/fluid/operators/utils.h
index 72fb17a34a..89141f1ff0 100644
--- a/paddle/fluid/operators/utils.h
+++ b/paddle/fluid/operators/utils.h
@@ -25,16 +25,16 @@ inline std::vector<T> GetDataFromTensor(const framework::Tensor* x) {
   std::vector<T> vec_new_data;
   if (x->type() == framework::proto::VarType::INT32) {
     auto* data = x->data<int>();
+    framework::Tensor cpu_attr_tensor;
     if (platform::is_gpu_place(x->place())) {
-      framework::Tensor cpu_attr_tensor;
       TensorCopySync(*x, platform::CPUPlace(), &cpu_attr_tensor);
       data = cpu_attr_tensor.data<int>();
     }
     vec_new_data = std::vector<T>(data, data + x->numel());
   } else if (x->type() == framework::proto::VarType::INT64) {
     auto* data = x->data<int64_t>();
+    framework::Tensor cpu_attr_tensor;
     if (platform::is_gpu_place(x->place())) {
-      framework::Tensor cpu_attr_tensor;
       TensorCopySync(*x, platform::CPUPlace(), &cpu_attr_tensor);
       data = cpu_attr_tensor.data<int64_t>();
     }
diff --git a/paddle/fluid/pybind/global_value_getter_setter.cc b/paddle/fluid/pybind/global_value_getter_setter.cc
index 10c9fb8bdb..d84e0d94a6 100644
--- a/paddle/fluid/pybind/global_value_getter_setter.cc
+++ b/paddle/fluid/pybind/global_value_getter_setter.cc
@@ -29,6 +29,7 @@ DECLARE_double(eager_delete_tensor_gb);
 DECLARE_bool(use_mkldnn);
 DECLARE_bool(use_ngraph);
+DECLARE_bool(use_system_allocator);
 
 namespace paddle {
 namespace pybind {
@@ -150,9 +151,12 @@ void BindGlobalValueGetterSetter(pybind11::module *module) {
   GlobalVarGetterSetterRegistry::MutableInstance()->RegisterGetter(          \
       #var, []() -> py::object { return py::cast(var); })
 
-#define REGISTER_GLOBAL_VAR_SETTER_ONLY(var)                                 \
-  GlobalVarGetterSetterRegistry::MutableInstance()->RegisterSetter(          \
-      #var, [](const py::object &obj) { var = py::cast(obj); })
+#define REGISTER_GLOBAL_VAR_SETTER_ONLY(var)                                 \
+  GlobalVarGetterSetterRegistry::MutableInstance()->RegisterSetter(          \
+      #var, [](const py::object &obj) {                                      \
+        using ValueType = std::remove_reference<decltype(var)>::type;        \
+        var = py::cast<ValueType>(obj);                                      \
+      })
 
 #define REGISTER_GLOBAL_VAR_GETTER_SETTER(var) \
   REGISTER_GLOBAL_VAR_GETTER_ONLY(var);        \
@@ -162,6 +166,7 @@ static void RegisterGlobalVarGetterSetter() {
   REGISTER_GLOBAL_VAR_GETTER_ONLY(FLAGS_use_mkldnn);
   REGISTER_GLOBAL_VAR_GETTER_ONLY(FLAGS_use_ngraph);
   REGISTER_GLOBAL_VAR_GETTER_SETTER(FLAGS_eager_delete_tensor_gb);
+  REGISTER_GLOBAL_VAR_GETTER_SETTER(FLAGS_use_system_allocator);
 }
 
 }  // namespace pybind
diff --git a/python/paddle/fluid/__init__.py b/python/paddle/fluid/__init__.py
index 6336911da9..1afa69174f 100644
--- a/python/paddle/fluid/__init__.py
+++ b/python/paddle/fluid/__init__.py
@@ -168,7 +168,7 @@ def __bootstrap__():
         'print_sub_graph_dir', 'pe_profile_fname', 'inner_op_parallelism',
         'enable_parallel_graph', 'fuse_parameter_groups_size',
         'multiple_of_cupti_buffer_size', 'fuse_parameter_memory_size',
-        'tracer_profile_fname', 'dygraph_debug'
+        'tracer_profile_fname', 'dygraph_debug', 'use_system_allocator'
     ]
     if 'Darwin' not in sysstr:
         read_env_flags.append('use_pinned_memory')
diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py
index 723fd48f28..c0dd0809c4 100644
--- a/python/paddle/fluid/tests/unittests/op_test.py
+++ b/python/paddle/fluid/tests/unittests/op_test.py
@@ -35,6 +35,14 @@ from testsuite import create_op, set_input, append_input_output, append_loss_ops
 from paddle.fluid import unique_name
 
 
+def _set_use_system_allocator(value=None):
+    USE_SYSTEM_ALLOCATOR_FLAG = "FLAGS_use_system_allocator"
+    old_value = core.globals()[USE_SYSTEM_ALLOCATOR_FLAG]
+    value = old_value if value is None else value
+    core.globals()[USE_SYSTEM_ALLOCATOR_FLAG] = value
+    return old_value
+
+
 def randomize_probability(batch_size, class_num, dtype='float32'):
     prob = np.random.uniform(
         0.1, 1.0, size=(batch_size, class_num)).astype(dtype)
@@ -146,12 +154,16 @@ class OpTest(unittest.TestCase):
         np.random.seed(123)
         random.seed(124)
 
+        cls._use_system_allocator = _set_use_system_allocator(True)
+
     @classmethod
     def tearDownClass(cls):
         """Restore random seeds"""
         np.random.set_state(cls._np_rand_state)
         random.setstate(cls._py_rand_state)
 
+        _set_use_system_allocator(cls._use_system_allocator)
+
     def try_call_once(self, data_type):
        if not self.call_once:
            self.call_once = True
-- 
GitLab
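
Usage sketch (not part of the patch): the flag registered above is reachable from Python through core.globals(), so a standalone test can opt into the system allocator the same way OpTest now does in setUpClass/tearDownClass. The test class, the fill_constant program, and the shapes below are illustrative assumptions; only the flag name and the core.globals() accessor come from the diff.

# Illustrative sketch, not part of the patch.
import unittest

import numpy as np
import paddle.fluid as fluid
import paddle.fluid.core as core


class SystemAllocatorFlagTest(unittest.TestCase):
    def setUp(self):
        # Save the current value so it can be restored afterwards, mirroring
        # what OpTest.setUpClass/tearDownClass do in this patch.
        self._old_flag = core.globals()["FLAGS_use_system_allocator"]
        core.globals()["FLAGS_use_system_allocator"] = True

    def tearDown(self):
        core.globals()["FLAGS_use_system_allocator"] = self._old_flag

    def test_fill_constant_with_system_allocator(self):
        # Any small program works; while the flag is on, tensor memory comes
        # from the system allocators instead of the pooled best-fit ones.
        main = fluid.Program()
        startup = fluid.Program()
        with fluid.program_guard(main, startup):
            out = fluid.layers.fill_constant(
                shape=[2, 3], dtype='float32', value=1.0)
        exe = fluid.Executor(fluid.CPUPlace())
        exe.run(startup)
        result, = exe.run(main, fetch_list=[out])
        self.assertTrue(
            np.array_equal(result, np.ones([2, 3], dtype='float32')))


if __name__ == '__main__':
    unittest.main()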