garbage_collector.h 6.9 KB
Newer Older
S
sneaxiy 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <deque>
#include <functional>
#include <memory>
#include <mutex>  // NOLINT
S
sneaxiy 已提交
21
#include <type_traits>
#include <utility>
W
wanghuancoder 已提交
22

S
sneaxiy 已提交
23
#include "gflags/gflags.h"
S
sneaxiy 已提交
24
#include "paddle/fluid/platform/device_context.h"
F
fwenguang 已提交
25 26 27
#ifdef PADDLE_WITH_MLU
#include "paddle/fluid/platform/device/mlu/device_context.h"
#endif
S
sneaxiy 已提交
28

W
wanghuancoder 已提交
29 30 31 32 33 34
// Forward declaration of platform::DeviceContext; GarbageCollector below only
// stores a pointer to it (dev_ctx_).
// NOTE(review): device_context.h is also included above, so this is likely
// redundant — kept as-is.
namespace paddle {
namespace platform {
class DeviceContext;
}  // namespace platform
}  // namespace paddle

S
sneaxiy 已提交
35 36 37 38 39
namespace paddle {
namespace framework {

class GarbageCollector {
 public:
S
sneaxiy 已提交
40
  using GarbageQueue = std::deque<std::shared_ptr<memory::Allocation>>;
S
sneaxiy 已提交
41

S
sneaxiy 已提交
42
  GarbageCollector(const platform::Place &place, size_t max_memory_size);
S
sneaxiy 已提交
43

Z
Zeng Jinle 已提交
44
  virtual ~GarbageCollector() PADDLE_MAY_THROW {}
S
fix bug  
sneaxiy 已提交
45

S
sneaxiy 已提交
46
  virtual void Wait() const {}
S
sneaxiy 已提交
47 48

  template <typename Container>
S
sneaxiy 已提交
49
  void Add(Container &&objs);
S
sneaxiy 已提交
50 51

  template <typename Container, typename Callback>
S
sneaxiy 已提交
52
  void Add(Container &&objs, Callback &&callback);
S
sneaxiy 已提交
53

L
Leo Chen 已提交
54 55 56 57
  void DirectClearCallback(const std::function<void()> &callback) {
    ClearCallback(callback);
  }

S
sneaxiy 已提交
58 59 60 61
 protected:
  virtual void ClearCallback(const std::function<void()> &callback) = 0;

  platform::DeviceContext *dev_ctx_;
S
sneaxiy 已提交
62
  std::unique_ptr<GarbageQueue> garbages_;
Z
Zeng Jinle 已提交
63
  mutable std::unique_ptr<std::mutex> mutex_;
S
sneaxiy 已提交
64
  const size_t max_memory_size_;
S
sneaxiy 已提交
65
  size_t cur_memory_size_{0};
S
sneaxiy 已提交
66 67
};

S
sneaxiy 已提交
68
class CPUGarbageCollector : public GarbageCollector {
S
sneaxiy 已提交
69
 public:
S
sneaxiy 已提交
70
  CPUGarbageCollector(const platform::CPUPlace &place, size_t max_memory_size);
S
sneaxiy 已提交
71 72

 protected:
S
sneaxiy 已提交
73
  void ClearCallback(const std::function<void()> &callback) override;
S
sneaxiy 已提交
74 75
};

76 77 78 79 80 81 82 83 84 85
#ifdef PADDLE_WITH_XPU
/// Garbage collector for XPUPlace (built only with PADDLE_WITH_XPU).
/// How buffered garbage is released is decided by the ClearCallback()
/// override defined outside this header.
class XPUGarbageCollector : public GarbageCollector {
 public:
  XPUGarbageCollector(const platform::XPUPlace &place,
                      size_t max_memory_size);

 protected:
  void ClearCallback(std::function<void()> const &callback) override;
};
#endif

J
jianghaicheng 已提交
86 87 88 89 90 91 92 93 94 95
#ifdef PADDLE_WITH_IPU
/// Garbage collector for IPUPlace (built only with PADDLE_WITH_IPU).
/// How buffered garbage is released is decided by the ClearCallback()
/// override defined outside this header.
class IPUGarbageCollector : public GarbageCollector {
 public:
  IPUGarbageCollector(const platform::IPUPlace &place,
                      size_t max_memory_size);

 protected:
  void ClearCallback(std::function<void()> const &callback) override;
};
#endif

96
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
S
sneaxiy 已提交
97
class UnsafeFastGPUGarbageCollector : public GarbageCollector {
S
fix bug  
sneaxiy 已提交
98 99
 public:
  UnsafeFastGPUGarbageCollector(const platform::CUDAPlace &place,
S
sneaxiy 已提交
100
                                size_t max_memory_size);
S
fix bug  
sneaxiy 已提交
101 102

 protected:
S
sneaxiy 已提交
103
  void ClearCallback(const std::function<void()> &callback) override;
S
fix bug  
sneaxiy 已提交
104 105
};

S
sneaxiy 已提交
106
class DefaultStreamGarbageCollector : public GarbageCollector {
S
sneaxiy 已提交
107 108
 public:
  DefaultStreamGarbageCollector(const platform::CUDAPlace &place,
S
sneaxiy 已提交
109
                                size_t max_memory_size);
S
sneaxiy 已提交
110

S
sneaxiy 已提交
111
  void Wait() const override;
S
sneaxiy 已提交
112 113

 protected:
S
sneaxiy 已提交
114
  void ClearCallback(const std::function<void()> &callback) override;
S
sneaxiy 已提交
115 116
};

S
sneaxiy 已提交
117
class StreamGarbageCollector : public GarbageCollector {
S
sneaxiy 已提交
118 119
 public:
  StreamGarbageCollector(const platform::CUDAPlace &place,
S
sneaxiy 已提交
120
                         size_t max_memory_size);
S
sneaxiy 已提交
121

S
sneaxiy 已提交
122
  ~StreamGarbageCollector();
S
sneaxiy 已提交
123

S
sneaxiy 已提交
124
  void Wait() const override;
S
sneaxiy 已提交
125

126
  gpuStream_t stream() const;
S
sneaxiy 已提交
127 128

 protected:
S
sneaxiy 已提交
129
  void ClearCallback(const std::function<void()> &callback) override;
S
sneaxiy 已提交
130 131

 private:
132
  gpuStream_t stream_;
133 134
  std::unique_ptr<platform::StreamCallbackManager<gpuStream_t>>
      callback_manager_;
S
sneaxiy 已提交
135
};
136 137 138 139 140 141 142 143 144

/// Garbage collector for CUDAPinnedPlace (CUDA/HIP builds). The release
/// strategy lives in the ClearCallback() override defined outside this
/// header.
class CUDAPinnedGarbageCollector : public GarbageCollector {
 public:
  CUDAPinnedGarbageCollector(const platform::CUDAPinnedPlace &place,
                             size_t max_memory_size);

 protected:
  void ClearCallback(std::function<void()> const &callback) override;
};
S
sneaxiy 已提交
145 146
#endif

147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168
#ifdef PADDLE_WITH_ASCEND_CL
/// NPU garbage collector (PADDLE_WITH_ASCEND_CL builds) that overrides Wait().
/// NOTE(review): presumably the NPU counterpart of
/// DefaultStreamGarbageCollector — confirm in the out-of-header definitions.
class NPUDefaultStreamGarbageCollector : public GarbageCollector {
 public:
  NPUDefaultStreamGarbageCollector(const platform::NPUPlace &place,
                                   size_t max_memory_size);

  void Wait() const override;

 protected:
  void ClearCallback(std::function<void()> const &callback) override;
};

/// NPU garbage collector (PADDLE_WITH_ASCEND_CL builds) without a Wait()
/// override.
/// NOTE(review): presumably the NPU counterpart of
/// UnsafeFastGPUGarbageCollector — confirm in the out-of-header definitions.
class NPUUnsafeFastGarbageCollector : public GarbageCollector {
 public:
  NPUUnsafeFastGarbageCollector(const platform::NPUPlace &place,
                                size_t max_memory_size);

 protected:
  void ClearCallback(std::function<void()> const &callback) override;
};
#endif

F
fwenguang 已提交
169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208
#ifdef PADDLE_WITH_MLU
/// MLU garbage collector (PADDLE_WITH_MLU builds) that overrides Wait().
/// NOTE(review): presumably the MLU counterpart of
/// DefaultStreamGarbageCollector — confirm in the out-of-header definitions.
class MLUDefaultStreamGarbageCollector : public GarbageCollector {
 public:
  MLUDefaultStreamGarbageCollector(const platform::MLUPlace &place,
                                   size_t max_memory_size);

  void Wait() const override;

 protected:
  void ClearCallback(std::function<void()> const &callback) override;
};

/// MLU garbage collector (PADDLE_WITH_MLU builds) without a Wait() override.
/// NOTE(review): presumably the MLU counterpart of
/// UnsafeFastGPUGarbageCollector — confirm in the out-of-header definitions.
class MLUUnsafeFastGarbageCollector : public GarbageCollector {
 public:
  MLUUnsafeFastGarbageCollector(const platform::MLUPlace &place,
                                size_t max_memory_size);

 protected:
  void ClearCallback(std::function<void()> const &callback) override;
};
class MLUStreamGarbageCollector : public GarbageCollector {
 public:
  MLUStreamGarbageCollector(const platform::MLUPlace &place,
                            size_t max_memory_size);

  ~MLUStreamGarbageCollector();

  void Wait() const override;

  mluStream stream() const;

 protected:
  void ClearCallback(const std::function<void()> &callback) override;

 private:
  mluStream stream_;
  std::unique_ptr<platform::StreamCallbackManager<mluStream>> callback_manager_;
};
#endif

S
sneaxiy 已提交
209 210 211 212 213 214 215
// Convenience overload: behaves exactly like Add(objs, callback) with a
// callback that does nothing.
template <typename Container>
void GarbageCollector::Add(Container &&objs) {
  auto do_nothing = []() {};
  Add(std::forward<Container>(objs), do_nothing);
}

template <typename Container, typename Callback>
void GarbageCollector::Add(Container &&objs, Callback &&callback) {
Z
Zeng Jinle 已提交
216 217 218 219 220 221 222 223 224
  // Special case when FLAGS_eager_delete_tensor_gb=0.0
  // It speeds up GC about 2~3%.
  if (max_memory_size_ <= 1) {
    callback();
    auto *container = new Container(std::move(objs));
    ClearCallback([container] { delete container; });
    return;
  }

S
sneaxiy 已提交
225 226
  GarbageQueue *garbage_queue = nullptr;
  {
Z
Zeng Jinle 已提交
227
    std::lock_guard<std::mutex> guard(*mutex_);
S
sneaxiy 已提交
228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245
    for (auto &obj : objs) {
      if (!obj) continue;
      cur_memory_size_ += obj->size();
      garbages_->push_back(std::move(obj));
    }
    if (cur_memory_size_ >= max_memory_size_) {
      cur_memory_size_ = 0;
      garbage_queue = garbages_.release();
      garbages_.reset(new GarbageQueue());
    }
  }

  if (garbage_queue) {
    callback();
    ClearCallback([garbage_queue]() { delete garbage_queue; });
  }
}

S
sneaxiy 已提交
246 247 248
// Eager-deletion configuration accessors; definitions live outside this
// header. NOTE(review): the descriptions below are inferred from the names
// and signatures only — confirm against the implementation.

// Current eager-deletion threshold (presumably a memory size).
int64_t GetEagerDeletionThreshold();
// Whether the "fast" eager-deletion mode is enabled.
bool IsFastEagerDeletionModeEnabled();

// Configures eager deletion: threshold, memory fraction and fast-mode flag.
void SetEagerDeletionMode(double threshold, double fraction, bool fast_mode);

// Memory fraction configured for eager deletion (see SetEagerDeletionMode).
double GetEagerDeletionMemoryFraction();
S
sneaxiy 已提交
252

S
sneaxiy 已提交
253 254
}  // namespace framework
}  // namespace paddle