diff --git a/paddle/fluid/memory/allocation/CMakeLists.txt b/paddle/fluid/memory/allocation/CMakeLists.txt index 71cf12ebf02df6b2217c239385584421042f0340..94dc13ad5f0184da96c2241a2b912b03ff6d1a93 100644 --- a/paddle/fluid/memory/allocation/CMakeLists.txt +++ b/paddle/fluid/memory/allocation/CMakeLists.txt @@ -35,6 +35,7 @@ endif() cc_library(aligned_allocator SRCS aligned_allocator.cc DEPS allocator) cc_library(auto_increment_allocator SRCS auto_increment_allocator.cc DEPS allocator) cc_library(zero_size_allocator SRCS zero_size_allocator.cc DEPS allocator) +cc_library(conditional_allocator SRCS conditional_allocator.cc DEPS allocator) cc_library(allocator_facade SRCS allocator_facade.cc DEPS ${AllocatorFacadeDeps} cpu_allocator @@ -44,6 +45,7 @@ cc_library(allocator_facade SRCS allocator_facade.cc DEPS aligned_allocator auto_increment_allocator zero_size_allocator + conditional_allocator cuda_device_guard) nv_test(allocation_and_eigen_test SRCS allocation_and_eigen_test.cu DEPS allocator_facade) diff --git a/paddle/fluid/memory/allocation/allocator_facade.cc b/paddle/fluid/memory/allocation/allocator_facade.cc index 971e7d02c58ed6482b758a48a803d06b5e9fe692..7816aec8f788e536293f77f21cbc22ba7bcbebd5 100644 --- a/paddle/fluid/memory/allocation/allocator_facade.cc +++ b/paddle/fluid/memory/allocation/allocator_facade.cc @@ -19,6 +19,7 @@ #include "paddle/fluid/memory/allocation/allocator_facade.h" #include "paddle/fluid/memory/allocation/auto_increment_allocator.h" #include "paddle/fluid/memory/allocation/best_fit_allocator.h" +#include "paddle/fluid/memory/allocation/conditional_allocator.h" #include "paddle/fluid/memory/allocation/cpu_allocator.h" #include "paddle/fluid/memory/allocation/locked_allocator.h" #include "paddle/fluid/memory/allocation/naive_managed_allocator.h" @@ -77,6 +78,18 @@ class CUDAManagedAllocator : public ManagedAllocator { new CUDAAllocator(platform::CUDAPlace(dev_id)))); default_allocator_ = std::make_shared( [this] { return std::move(BestFitAllocatorCreator()); }); + + auto* cond_allocator = new ConditionalAllocator(); + cond_allocator + ->AddAllocator( + [this](size_t size, Attr attr) { return size < max_chunk_size_; }, + default_allocator_) + .AddAllocator( + [](size_t size, Attr attr) { + return true; // default case + }, + raw_allocator_); + default_allocator_.reset(cond_allocator); } ~CUDAManagedAllocator() { diff --git a/paddle/fluid/memory/allocation/conditional_allocator.cc b/paddle/fluid/memory/allocation/conditional_allocator.cc new file mode 100644 index 0000000000000000000000000000000000000000..2df10a89bc2d94920b286e4842f4b6a3d4918375 --- /dev/null +++ b/paddle/fluid/memory/allocation/conditional_allocator.cc @@ -0,0 +1,43 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/memory/allocation/conditional_allocator.h" + +namespace paddle { +namespace memory { +namespace allocation { + +ConditionalAllocator& ConditionalAllocator::AddAllocator( + std::function func, + std::shared_ptr allocator) { + underlying_allocators_.emplace_back(std::move(func), std::move(allocator)); + return *this; +} +std::unique_ptr ConditionalAllocator::Allocate( + size_t size, Allocator::Attr attr) { + return SelectAndInvoke(size, attr, [&](ManagedAllocator& allocator) { + return allocator.Allocate(size, attr); + }); +} +std::shared_ptr ConditionalAllocator::AllocateShared( + size_t size, Allocator::Attr attr) { + return SelectAndInvoke(size, attr, [&](ManagedAllocator& allocator) { + return allocator.AllocateShared(size, attr); + }); +} +bool ConditionalAllocator::IsAllocThreadSafe() const { return true; } + +} // namespace allocation +} // namespace memory +} // namespace paddle diff --git a/paddle/fluid/memory/allocation/conditional_allocator.h b/paddle/fluid/memory/allocation/conditional_allocator.h new file mode 100644 index 0000000000000000000000000000000000000000..f993857c79400115ecef29af164877ba830c50b3 --- /dev/null +++ b/paddle/fluid/memory/allocation/conditional_allocator.h @@ -0,0 +1,55 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include +#include "paddle/fluid/memory/allocation/allocator.h" + +namespace paddle { +namespace memory { +namespace allocation { + +class ConditionalAllocator : public ManagedAllocator { + public: + ConditionalAllocator() = default; + + ConditionalAllocator& AddAllocator( + std::function func, + std::shared_ptr allocator); + std::unique_ptr Allocate(size_t size, Attr attr) override; + std::shared_ptr AllocateShared(size_t size, Attr attr) override; + bool IsAllocThreadSafe() const override; + + private: + template + inline typename std::result_of::type + SelectAndInvoke(size_t size, Attr attr, Callback callback) { + for (auto& pair : underlying_allocators_) { + if (pair.first(size, attr)) { + return callback(*pair.second); + } + } + PADDLE_THROW("No suitable allocator"); + } + + std::vector, + std::shared_ptr>> + underlying_allocators_; +}; + +} // namespace allocation +} // namespace memory +} // namespace paddle