device_context.h 5.0 KB
Newer Older
Q
QI JUN 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

L
liaogang 已提交
14
#include "paddle/platform/enforce.h"
Q
QI JUN 已提交
15 16 17 18
#ifndef PADDLE_ONLY_CPU
#include "paddle/platform/dynload/cublas.h"
#include "paddle/platform/dynload/cudnn.h"
#include "paddle/platform/dynload/curand.h"
L
liaogang 已提交
19 20
#include "paddle/platform/error.h"
#include "paddle/platform/gpu_info.h"
Q
QI JUN 已提交
21 22
#define EIGEN_USE_GPU
#endif
Q
qijun 已提交
23
#include <paddle/platform/place.h>
Q
qijun 已提交
24
#include <memory>
Q
qijun 已提交
25
#include <unsupported/Eigen/CXX11/Tensor>
Q
QI JUN 已提交
26 27 28 29 30 31 32

namespace paddle {
namespace platform {

class DeviceContext {
 public:
  virtual ~DeviceContext() {}
Q
qijun 已提交
33
  virtual Place GetPlace() const = 0;
Q
QI JUN 已提交
34

Q
qijun 已提交
35
  template <typename DeviceType>
Q
qijun 已提交
36
  DeviceType* get_eigen_device();
Q
QI JUN 已提交
37 38
};

Q
qijun 已提交
39 40
class CPUDeviceContext : public DeviceContext {
 public:
Q
qijun 已提交
41
  Eigen::DefaultDevice* eigen_device() {
Q
qijun 已提交
42
    if (!eigen_device_) {
Q
qijun 已提交
43
      eigen_device_.reset(new Eigen::DefaultDevice());
Q
qijun 已提交
44
    }
Q
qijun 已提交
45
    return eigen_device_.get();
Q
QI JUN 已提交
46
  }
Q
qijun 已提交
47

Y
Yu Yang 已提交
48
  Place GetPlace() const override {
Q
qijun 已提交
49
    Place retv = CPUPlace();
Y
Yu Yang 已提交
50 51 52
    return retv;
  }

Q
qijun 已提交
53
 private:
Q
qijun 已提交
54
  std::unique_ptr<Eigen::DefaultDevice> eigen_device_;
Q
QI JUN 已提交
55 56 57
};

#ifndef PADDLE_ONLY_CPU
D
dongzhihong 已提交
58

Q
QI JUN 已提交
59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76
// Scoped guard: makes `new_place` the current CUDA device while alive and,
// on destruction, restores the device reported by GetCurrentDeviceId() at
// construction time.
class GPUPlaceGuard {
 public:
  explicit GPUPlaceGuard(GPUPlace new_place) : previous_(GetCurrentDeviceId()) {
    // Skip the device switch when the requested device is already current.
    const bool already_current = !(previous_ != new_place);
    if (already_current) {
      return;
    }
    paddle::platform::SetDeviceId(new_place.device);
  }

  // Restores the saved device unconditionally, even when the constructor
  // did not switch.
  ~GPUPlaceGuard() {
    const int saved_device = previous_.device;
    paddle::platform::SetDeviceId(saved_device);
  }

 private:
  // Device that was current when the guard was constructed.
  GPUPlace previous_;
};

class CUDADeviceContext : public DeviceContext {
 public:
  explicit CUDADeviceContext(const GPUPlace gpu_place) : gpu_place_(gpu_place) {
    GPUPlaceGuard guard(gpu_place_);
L
liaogang 已提交
77
    PADDLE_ENFORCE(cudaStreamCreate(&stream_), "cudaStreamCreate failed");
Q
qijun 已提交
78 79
    eigen_stream_.reset(new Eigen::CudaStreamDevice(&stream_));
    eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get()));
Q
qijun 已提交
80 81 82 83 84
  }

  Place GetPlace() const override {
    Place retv = GPUPlace();
    return retv;
Q
QI JUN 已提交
85 86 87
  }

  void Wait() {
L
liaogang 已提交
88 89
    PADDLE_ENFORCE(cudaStreamSynchronize(stream_),
                   "cudaStreamSynchronize failed");
Q
QI JUN 已提交
90 91 92 93
  }

  cudaStream_t stream() { return stream_; }

Q
qijun 已提交
94
  Eigen::GpuDevice* eigen_device() { return eigen_device_.get(); }
Q
QI JUN 已提交
95 96 97 98

  cublasHandle_t cublas_handle() {
    if (!blas_handle_) {
      GPUPlaceGuard guard(gpu_place_);
L
liaogang 已提交
99
      PADDLE_ENFORCE(paddle::platform::dynload::cublasCreate(&blas_handle_),
Q
QI JUN 已提交
100
                     "cublasCreate failed");
L
liaogang 已提交
101 102 103
      PADDLE_ENFORCE(
          paddle::platform::dynload::cublasSetStream(blas_handle_, stream_),
          "cublasSetStream failed");
Q
QI JUN 已提交
104 105 106 107 108 109 110
    }
    return blas_handle_;
  }

  cudnnHandle_t cudnn_handle() {
    if (!dnn_handle_) {
      GPUPlaceGuard guard(gpu_place_);
L
liaogang 已提交
111
      PADDLE_ENFORCE(paddle::platform::dynload::cudnnCreate(&dnn_handle_),
Q
QI JUN 已提交
112
                     "cudnnCreate failed");
L
liaogang 已提交
113 114 115
      PADDLE_ENFORCE(
          paddle::platform::dynload::cudnnSetStream(dnn_handle_, stream_),
          "cudnnSetStream failed");
Q
QI JUN 已提交
116 117 118 119 120 121 122 123
    }
    return dnn_handle_;
  }

  curandGenerator_t curand_generator() {
    if (!rand_generator_) {
      GPUPlaceGuard guard(gpu_place_);
      PADDLE_ENFORCE(paddle::platform::dynload::curandCreateGenerator(
L
liaogang 已提交
124
                         &rand_generator_, CURAND_RNG_PSEUDO_DEFAULT),
Q
QI JUN 已提交
125 126 127
                     "curandCreateGenerator failed");
      PADDLE_ENFORCE(
          paddle::platform::dynload::curandSetPseudoRandomGeneratorSeed(
L
liaogang 已提交
128
              rand_generator_, random_seed_),
Q
QI JUN 已提交
129
          "curandSetPseudoRandomGeneratorSeed failed");
L
liaogang 已提交
130 131 132
      PADDLE_ENFORCE(
          paddle::platform::dynload::curandSetStream(rand_generator_, stream_),
          "curandSetStream failed");
Q
QI JUN 已提交
133 134 135 136 137 138 139
    }
    return rand_generator_;
  }

  ~CUDADeviceContext() {
    Wait();
    if (blas_handle_) {
L
liaogang 已提交
140
      PADDLE_ENFORCE(paddle::platform::dynload::cublasDestroy(blas_handle_),
Q
QI JUN 已提交
141 142 143 144
                     "cublasDestroy failed");
    }

    if (dnn_handle_) {
L
liaogang 已提交
145
      PADDLE_ENFORCE(paddle::platform::dynload::cudnnDestroy(dnn_handle_),
Q
QI JUN 已提交
146 147 148 149
                     "cudnnDestroy failed");
    }

    if (rand_generator_) {
L
liaogang 已提交
150 151 152
      PADDLE_ENFORCE(
          paddle::platform::dynload::curandDestroyGenerator(rand_generator_),
          "curandDestroyGenerator failed");
Q
QI JUN 已提交
153
    }
Q
qijun 已提交
154 155
    eigen_stream_.reset();
    eigen_device_.reset();
L
liaogang 已提交
156
    PADDLE_ENFORCE(cudaStreamDestroy(stream_), "cudaStreamDestroy failed");
Q
QI JUN 已提交
157 158 159 160 161 162
  }

 private:
  GPUPlace gpu_place_;
  cudaStream_t stream_;

Q
qijun 已提交
163 164
  std::unique_ptr<Eigen::CudaStreamDevice> eigen_stream_;
  std::unique_ptr<Eigen::GpuDevice> eigen_device_;
Q
QI JUN 已提交
165 166 167 168 169 170 171 172

  cublasHandle_t blas_handle_{nullptr};

  cudnnHandle_t dnn_handle_{nullptr};

  int random_seed_;
  curandGenerator_t rand_generator_{nullptr};
};
Q
qijun 已提交
173

Q
QI JUN 已提交
174
#endif
Q
qijun 已提交
175

Q
QI JUN 已提交
176 177
}  // namespace platform
}  // namespace paddle