/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once

13
#include <future>  // NOLINT
D
dzhwinter 已提交
14
#include <memory>
Y
yuyang18 已提交
15
#include <mutex>  // NOLINT
16
#include <string>
D
dzhwinter 已提交
17
#include <unordered_map>
18
#include <utility>
19
#include <vector>
Y
Yu Yang 已提交
20
#include "paddle/fluid/memory/malloc.h"
21
#include "paddle/fluid/platform/temporary_allocator.h"
22
#ifdef PADDLE_WITH_CUDA
23
#include "paddle/fluid/platform/cuda_helper.h"
Y
Yi Wang 已提交
24 25
#include "paddle/fluid/platform/dynload/cublas.h"
#include "paddle/fluid/platform/dynload/cudnn.h"
W
Wu Yi 已提交
26
#if !defined(__APPLE__) && !defined(_WIN32)
W
Wu Yi 已提交
27
#include "paddle/fluid/platform/dynload/nccl.h"
W
Wu Yi 已提交
28
#endif
Y
Yi Wang 已提交
29
#include "paddle/fluid/platform/gpu_info.h"
Q
QI JUN 已提交
30
#endif
D
dzhwinter 已提交
31

T
tensor-tang 已提交
32
#ifdef PADDLE_WITH_MKLDNN
L
luotao1 已提交
33
#include "mkldnn.hpp"
T
tensor-tang 已提交
34 35
#endif

36 37
#include <map>
#include "glog/logging.h"
Y
Yi Wang 已提交
38 39
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/place.h"
S
sneaxiy 已提交
40 41 42
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/stream_callback_manager.h"
#endif
Q
qijun 已提交
43
#include "unsupported/Eigen/CXX11/Tensor"
Q
QI JUN 已提交
44 45 46 47

namespace paddle {
namespace platform {

C
chengduo 已提交
48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66
/*! \brief device temporary allocator singleton.
 *
 * Some operator needs temporary memory during computation, for example,
 * conv_gemm, which needs use col to store the result of im2col. If we
 * create a stack memory which is used by CUDA Kernel, before the
 * Computation(...) returns, we should add ctx->Wait(), because the
 * execution of CUDA is async, if there doesn't have ctx->Wait(),
 * the temporary memory will be released before the CUDA Kernel uses
 * it.
 *
 * DeviceTemporaryAllocator is a singleton, which contains a
 * `TemporaryAllocator` for each <Place, Stream>. And the TemporaryAllocator
 * contains a temp_allocation_queue which is used to store the temporary
 * allocations. The allocation, which is allocated by TemporaryAllocator,
 * is a unique_ptr,  and when it is not held by any variable, it will be
 * pushed into the temp_allocation_queue. There are two opportunities to free
 * the allocations of temp_allocation_queue:
 *  - when the Stream calls cudaStreamSynchronize;
 *  - when the allocation size of opportunities exceeds a certain threshold
67
 *    (defined by FLAGS_limit_of_tmp_allocation).
C
chengduo 已提交
68 69
 *
 * */
70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112
class DeviceTemporaryAllocator {
 public:
  static DeviceTemporaryAllocator& Instance() {
    PADDLE_ENFORCE_NOT_NULL(allocators,
                            "Need to Create DeviceTemporaryAllocator first!");
    return *allocators;
  }

  static DeviceTemporaryAllocator& Init() {
    if (allocators == nullptr) {
      allocators = new DeviceTemporaryAllocator();
    }
    return *allocators;
  }

/*! \brief  Return handle of single temporary allocator. */
#ifdef PADDLE_WITH_CUDA
  platform::TemporaryAllocator& Get(const platform::Place& place,
                                    const cudaStream_t& stream);
#endif
  template <typename DeviceContext>
  platform::TemporaryAllocator& Get(const DeviceContext& dev_ctx);

  platform::TemporaryAllocator& Get(const platform::Place& place);

 private:
  DeviceTemporaryAllocator() : cpu_allocator_(platform::CPUPlace()) {}

  static DeviceTemporaryAllocator* allocators;

  platform::TemporaryAllocator cpu_allocator_;

#ifdef PADDLE_WITH_CUDA
  std::map<std::pair<platform::Place, cudaStream_t>,
           std::unique_ptr<platform::TemporaryAllocator>>
      device_allocator_;
#endif

  std::mutex mtx_;

  DISABLE_COPY_AND_ASSIGN(DeviceTemporaryAllocator);
};

Q
QI JUN 已提交
113 114 115
class DeviceContext {
 public:
  virtual ~DeviceContext() {}
L
liaogang 已提交
116
  virtual Place GetPlace() const = 0;
Q
QI JUN 已提交
117

118
  virtual void Wait() const {}
Q
QI JUN 已提交
119 120
};

Q
qijun 已提交
121 122
class CPUDeviceContext : public DeviceContext {
 public:
123
  CPUDeviceContext();
Q
qijun 已提交
124
  explicit CPUDeviceContext(CPUPlace place);
Q
qijun 已提交
125

126
  Eigen::DefaultDevice* eigen_device() const;
Q
qijun 已提交
127

L
liaogang 已提交
128
  Place GetPlace() const override;
Y
Yu Yang 已提交
129

Q
qijun 已提交
130
 private:
D
dzhwinter 已提交
131
  CPUPlace place_;
Q
qijun 已提交
132
  std::unique_ptr<Eigen::DefaultDevice> eigen_device_;
Q
QI JUN 已提交
133 134
};

Y
Yang Yu 已提交
135 136 137 138 139 140 141 142
template <typename Place>
struct DefaultDeviceContextType;

template <>
struct DefaultDeviceContextType<platform::CPUPlace> {
  using TYPE = CPUDeviceContext;
};

143
#ifdef PADDLE_WITH_CUDA
144

Q
qijun 已提交
145
class EigenCudaStreamDevice;
S
sneaxiy 已提交
146 147 148 149 150 151 152 153 154 155 156 157
class CudnnHolder {
 public:
  CudnnHolder(const cudaStream_t* stream, const CUDAPlace& place);
  ~CudnnHolder();
  cudnnHandle_t cudnn_handle() const { return cudnn_handle_; }

 private:
  friend class CudnnWorkspaceHandle;
  void ReallocateWorkspace(size_t required_workspace_len);

  template <typename Callback>
  void RunFuncImpl(Callback&& cudnn_func, size_t required_workspace_len) {
Y
Yu Yang 已提交
158
    if (required_workspace_len > WorkspaceSize()) {
S
sneaxiy 已提交
159 160
      ReallocateWorkspace(required_workspace_len);
    }
Z
Zeng Jinle 已提交
161 162
    VLOG(2) << "Cudnn workspace size: "
            << static_cast<double>(WorkspaceSize()) / (1 << 20) << " MB";
Y
Yu Yang 已提交
163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179
    cudnn_func(WorkspacePtr());
  }

  inline void* WorkspacePtr() {
    if (workspace_) {
      return workspace_->ptr();
    } else {
      return nullptr;
    }
  }

  inline size_t WorkspaceSize() {
    if (workspace_) {
      return workspace_->size();
    } else {
      return 0;
    }
S
sneaxiy 已提交
180 181 182 183 184
  }

  std::mutex& Mutex() { return mtx_; }

  cudnnHandle_t cudnn_handle_;
Y
Yu Yang 已提交
185
  memory::AllocationPtr workspace_;
S
sneaxiy 已提交
186 187 188 189 190 191

  const cudaStream_t* stream_;  // not owned;
  const CUDAPlace place_;

  std::mutex mtx_;
};
D
dongzhihong 已提交
192

S
sneaxiy 已提交
193 194 195 196
class CudnnWorkspaceHandle {
 public:
  /*! \brief The lock would not be acquired when constructor calls.
   *  The lock would be acquired when RunFunc() is called first time. */
S
sneaxiy 已提交
197
  inline explicit CudnnWorkspaceHandle(CudnnHolder* holder) : holder_(holder) {}
S
sneaxiy 已提交
198 199 200

  /*! \brief Thread which call RunFunc() would acquire the lock first
   *  before invoking cudnn functions. */
S
sneaxiy 已提交
201 202 203 204 205 206 207 208
  template <typename Callback>
  inline void RunFunc(Callback&& cudnn_func, size_t required_workspace_len) {
    if (!guard_) {
      guard_.reset(new std::lock_guard<std::mutex>(holder_->Mutex()));
    }
    holder_->RunFuncImpl(std::forward<Callback>(cudnn_func),
                         required_workspace_len);
  }
S
sneaxiy 已提交
209

S
sneaxiy 已提交
210 211
  CudnnWorkspaceHandle(CudnnWorkspaceHandle&&) = default;
  CudnnWorkspaceHandle& operator=(CudnnWorkspaceHandle&&) = delete;
S
sneaxiy 已提交
212 213 214 215 216 217

 private:
  CudnnHolder* holder_;  // not own
  std::unique_ptr<std::lock_guard<std::mutex>> guard_;
};

218
class CUDADeviceContext : public DeviceContext {
Q
QI JUN 已提交
219
 public:
D
dzhwinter 已提交
220
  explicit CUDADeviceContext(CUDAPlace place);
221
  virtual ~CUDADeviceContext();
Q
QI JUN 已提交
222

223
  /*! \brief  Wait for all operations completion in the stream. */
224
  void Wait() const override;
Q
QI JUN 已提交
225

226
  /*! \brief  Return place in the device context. */
L
liaogang 已提交
227
  Place GetPlace() const override;
228

K
Kexin Zhao 已提交
229
  /*! \brief  Return compute capability in the device context. */
K
Kexin Zhao 已提交
230 231
  int GetComputeCapability() const;

232 233 234
  /*! \brief  Return the max physical thread count in the device context */
  int GetMaxPhysicalThreadCount() const;

235 236 237
  /*! \brief  Return eigen device in the device context. */
  Eigen::GpuDevice* eigen_device() const;

238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256
  /*! \brief  Call cublas function safely. */
  template <typename Callback>
  inline void CublasCall(Callback&& callback) const {
    cublas_handle_->Call(std::forward<Callback>(callback));
  }

  /*! \brief  Check whether tensor core is supported */
  bool tensor_core_available() const;

  /*! \brief  Call cublas function with Tensor Core safely. If
      Tensor Core is not available, use DEFAULT_MATH instead. */
  template <typename Callback>
  inline void TensorCoreCublasCallIfAvailable(Callback&& callback) const {
    if (cublas_tensor_core_handle_) {
      cublas_tensor_core_handle_->Call(std::forward<Callback>(callback));
    } else {
      cublas_handle_->Call(std::forward<Callback>(callback));
    }
  }
S
sneaxiy 已提交
257

258
  /*! \brief  Return cudnn  handle in the device context. */
259
  cudnnHandle_t cudnn_handle() const;
260

S
sneaxiy 已提交
261 262 263 264 265 266 267 268 269
  /*! \brief  Return a cudnn workspace handle to call multiple cudnn
   *  functions without interrupting by other threads.
   *  Once the first cudnn function is called by the handle, a lock
   *  would be acquired to prevent other threads from accessing the
   *  workspace. Once the handle is destructed, the lock would be released.
   *  CudnnWorkspaceHandle is an RAII object to implement thread-safe
   *  sequential cudnn function calls. */
  CudnnWorkspaceHandle cudnn_workspace_handle() const;

Q
init  
qijun 已提交
270
  /*! \brief  Return cuda stream in the device context. */
271
  cudaStream_t stream() const;
Q
QI JUN 已提交
272

Q
qingqing01 已提交
273
#if !defined(_WIN32)
Q
qingqing01 已提交
274 275 276 277 278
  /*! \brief  Return nccl communicators. */
  ncclComm_t nccl_comm() const { return nccl_comm_; }

  /*! \brief  Set nccl communicators. */
  void set_nccl_comm(ncclComm_t comm) { nccl_comm_ = comm; }
Q
qingqing01 已提交
279
#endif
Q
qingqing01 已提交
280

Y
Yu Yang 已提交
281 282 283 284 285 286
  template <typename Callback>
  void RecordEvent(cudaEvent_t ev, Callback callback) {
    callback();
    PADDLE_ENFORCE(cudaEventRecord(ev, stream_));
  }

S
sneaxiy 已提交
287 288 289 290 291
  template <typename Callback>
  void AddStreamCallback(Callback&& callback) const {
    callback_manager_->AddCallback(callback);
  }

S
fix bug  
sneaxiy 已提交
292
  void WaitStreamCallback() const { callback_manager_->Wait(); }
S
sneaxiy 已提交
293

Q
QI JUN 已提交
294
 private:
D
dzhwinter 已提交
295
  CUDAPlace place_;
Q
QI JUN 已提交
296

N
nhzlx 已提交
297
  mutable std::once_flag init_cudnn_;
298

Q
qijun 已提交
299
  std::unique_ptr<Eigen::GpuDevice> eigen_device_;
Q
init  
qijun 已提交
300
  std::unique_ptr<EigenCudaStreamDevice> eigen_stream_;
301
  mutable std::unique_ptr<CudnnHolder> cudnn_holder_;
302
  cudaStream_t stream_;
303 304 305

  std::unique_ptr<CublasHandleHolder> cublas_handle_;
  std::unique_ptr<CublasHandleHolder> cublas_tensor_core_handle_;
306

Q
qingqing01 已提交
307
#if !defined(_WIN32)
Q
qingqing01 已提交
308 309 310 311 312 313
  // NCCL communicator (single process version) for NCCL collective operations.
  // NCCL collective operations provides fast collectives over multiple GPUs
  // both within and across nodes.
  // But, this collectives is used for collectives over multiple GPUs within
  // nodes.
  ncclComm_t nccl_comm_{nullptr};
Q
qingqing01 已提交
314
#endif
Q
qingqing01 已提交
315

C
chengduo 已提交
316 317 318 319 320
  int compute_capability_;
  int runtime_version_;
  int driver_version_;
  int multi_process_;
  int max_threads_per_mp_;
Y
yuyang18 已提交
321

S
fix bug  
sneaxiy 已提交
322
  // StreamCallbackManager is thread-safe
S
sneaxiy 已提交
323
  std::unique_ptr<StreamCallbackManager> callback_manager_;
324
  CudnnHolder* cudnn_holder() const;
325

326
  DISABLE_COPY_AND_ASSIGN(CUDADeviceContext);
Q
QI JUN 已提交
327
};
Q
qijun 已提交
328

Y
Yang Yu 已提交
329 330
template <>
struct DefaultDeviceContextType<platform::CUDAPlace> {
Y
Yang Yu 已提交
331
  using TYPE = CUDADeviceContext;
Y
Yang Yu 已提交
332 333
};

C
chengduoZH 已提交
334
// Currently, CUDAPinnedDeviceContext is only used to data copying.
C
chengduoZH 已提交
335 336 337 338 339 340
class CUDAPinnedDeviceContext : public DeviceContext {
 public:
  CUDAPinnedDeviceContext();
  explicit CUDAPinnedDeviceContext(CUDAPinnedPlace place);

  Place GetPlace() const override;
C
chengduoZH 已提交
341

C
chengduoZH 已提交
342 343 344 345 346 347 348 349 350 351 352
  Eigen::DefaultDevice* eigen_device() const;

 private:
  CUDAPinnedPlace place_;
  std::unique_ptr<Eigen::DefaultDevice> eigen_device_;
};

template <>
struct DefaultDeviceContextType<platform::CUDAPinnedPlace> {
  using TYPE = CUDAPinnedDeviceContext;
};
Q
QI JUN 已提交
353
#endif
Q
qijun 已提交
354

T
tensor-tang 已提交
355
#ifdef PADDLE_WITH_MKLDNN
S
Sylwester Fraczek 已提交
356 357 358 359 360 361
// Blob name -> type-erased blob data.
using KeyBlob = std::unordered_map<std::string, std::shared_ptr<void>>;
// Integer key -> KeyBlob.
// NOTE(review): the key is presumably the id passed to set_cur_thread_id();
// confirm in the .cc.
using BlobMap = std::unordered_map<int, std::shared_ptr<KeyBlob>>;

// Store and read back an integer id for the current thread (defined in the
// .cc; storage is presumably thread-local — confirm).
void set_cur_thread_id(int thread_id);
int get_cur_thread_id();

T
tensor-tang 已提交
362 363 364 365 366
class MKLDNNDeviceContext : public CPUDeviceContext {
 public:
  explicit MKLDNNDeviceContext(CPUPlace place);

  /* \brief  Get the active engine */
367
  const mkldnn::engine& GetEngine() const { return engine_; }
T
tensor-tang 已提交
368

369 370
  // Set data to blob (i.e. name/data pair). Create blob if not existing
  void SetBlob(const std::string& name, std::shared_ptr<void> data) const;
T
tensor-tang 已提交
371

372 373
  // Find a saved blob. Return nullptr if not found
  std::shared_ptr<void> GetBlob(const std::string& name) const;
T
tensor-tang 已提交
374 375

 private:
376
  mkldnn::engine engine_;
377 378
  std::shared_ptr<BlobMap> p_blobmap_;
  std::shared_ptr<std::mutex> p_mutex_;
T
tensor-tang 已提交
379 380 381
};
#endif

D
dzhwinter 已提交
382 383 384 385 386
/*! \brief device context pool singleton */
class DeviceContextPool {
 public:
  explicit DeviceContextPool(const std::vector<platform::Place>& places);

Y
Yang Yu 已提交
387
  static DeviceContextPool& Instance() {
D
dzhwinter 已提交
388 389 390 391 392
    PADDLE_ENFORCE_NOT_NULL(pool, "Need to Create DeviceContextPool first!");
    return *pool;
  }

  /*! \brief  Create should only called by Init function */
Y
Yang Yu 已提交
393
  static DeviceContextPool& Init(const std::vector<platform::Place>& places) {
D
dzhwinter 已提交
394 395 396 397 398 399 400
    if (pool == nullptr) {
      pool = new DeviceContextPool(places);
    }
    return *pool;
  }

  /*! \brief  Return handle of single device context. */
Y
Yu Yang 已提交
401
  platform::DeviceContext* Get(const platform::Place& place);
D
dzhwinter 已提交
402

Y
Yang Yu 已提交
403 404 405 406 407 408 409
  template <typename Place>
  const typename DefaultDeviceContextType<Place>::TYPE* GetByPlace(
      const Place& place) {
    return reinterpret_cast<
        const typename DefaultDeviceContextType<Place>::TYPE*>(Get(place));
  }

410 411
  size_t size() const { return device_contexts_.size(); }

D
dzhwinter 已提交
412 413
 private:
  static DeviceContextPool* pool;
414 415
  std::map<Place, std::shared_future<std::unique_ptr<DeviceContext>>>
      device_contexts_;
D
dzhwinter 已提交
416 417 418
  DISABLE_COPY_AND_ASSIGN(DeviceContextPool);
};

Q
QI JUN 已提交
419 420
}  // namespace platform
}  // namespace paddle