// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include #include #include // NOLINT #include "paddle/fluid/platform/device_context.h" namespace paddle { namespace framework { class GarbageCollector { public: using GarbageQueue = std::deque>; GarbageCollector(const platform::Place &place, size_t max_memory_size); virtual ~GarbageCollector() = default; virtual void Wait() const {} template void Add(Container &&objs); template void Add(Container &&objs, Callback &&callback); protected: virtual void ClearCallback(const std::function &callback) = 0; platform::DeviceContext *dev_ctx_; std::unique_ptr garbages_; mutable std::mutex mutex_; const size_t max_memory_size_; size_t cur_memory_size_{0}; }; class CPUGarbageCollector : public GarbageCollector { public: CPUGarbageCollector(const platform::CPUPlace &place, size_t max_memory_size); protected: void ClearCallback(const std::function &callback) override; }; #ifdef PADDLE_WITH_CUDA class UnsafeFastGPUGarbageCollector : public GarbageCollector { public: UnsafeFastGPUGarbageCollector(const platform::CUDAPlace &place, size_t max_memory_size); protected: void ClearCallback(const std::function &callback) override; }; class DefaultStreamGarbageCollector : public GarbageCollector { public: DefaultStreamGarbageCollector(const platform::CUDAPlace &place, size_t max_memory_size); void Wait() const override; protected: void ClearCallback(const std::function &callback) override; }; class StreamGarbageCollector : public GarbageCollector { public: StreamGarbageCollector(const platform::CUDAPlace &place, size_t max_memory_size); ~StreamGarbageCollector(); void Wait() const override; cudaStream_t stream() const; protected: void ClearCallback(const std::function &callback) override; private: cudaStream_t stream_; std::unique_ptr callback_manager_; }; #endif template void GarbageCollector::Add(Container &&objs) { Add(std::forward(objs), []() {}); } template void GarbageCollector::Add(Container &&objs, Callback &&callback) { GarbageQueue *garbage_queue = nullptr; { std::lock_guard guard(mutex_); for (auto &obj : objs) { if (!obj) continue; cur_memory_size_ += obj->size(); garbages_->push_back(std::move(obj)); } if (cur_memory_size_ >= max_memory_size_) { cur_memory_size_ = 0; garbage_queue = garbages_.release(); garbages_.reset(new GarbageQueue()); } } if (garbage_queue) { callback(); ClearCallback([garbage_queue]() { delete garbage_queue; }); } } } // namespace framework } // namespace paddle