device_context.h 5.5 KB
Newer Older
Q
QI JUN 已提交
1 2 3 4 5 6 7 8 9 10 11 12
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
Q
qijun 已提交
13

Q
QI JUN 已提交
14
#include "paddle/framework/enforce.h"
Q
qijun 已提交
15 16 17 18 19 20 21 22 23
#ifndef PADDLE_ONLY_CPU
#include "paddle/platform/cuda.h"
#include "paddle/platform/dynload/cublas.h"
#include "paddle/platform/dynload/cudnn.h"
#include "paddle/platform/dynload/curand.h"
#define EIGEN_USE_GPU
#endif
#include <paddle/platform/place.h>
#include <unsupported/Eigen/CXX11/Tensor>
Q
qijun 已提交
24

Q
QI JUN 已提交
25 26 27
namespace paddle {
namespace platform {

Q
qijun 已提交
28 29 30
class DeviceContext {
 public:
  virtual ~DeviceContext() {}
Q
qijun 已提交
31
  virtual Place GetPlace() const = 0;
Q
QI JUN 已提交
32

Q
qijun 已提交
33
  template <typename DeviceType>
Q
qijun 已提交
34
  DeviceType get_eigen_device();
Q
qijun 已提交
35 36 37 38
};

class CPUDeviceContext : public DeviceContext {
 public:
Q
qijun 已提交
39 40 41
  Eigen::DefaultDevice eigen_device() {
    if (!eigen_device_) {
      eigen_device_ = new Eigen::DefaultDevice();
Q
qijun 已提交
42
    }
Q
qijun 已提交
43
    return *eigen_device_;
Q
QI JUN 已提交
44
  }
Q
qijun 已提交
45

Y
Yu Yang 已提交
46
  Place GetPlace() const override {
Q
qijun 已提交
47
    Place retv = CPUPlace();
Y
Yu Yang 已提交
48 49 50
    return retv;
  }

Q
qijun 已提交
51
 private:
Q
qijun 已提交
52
  Eigen::DefaultDevice* eigen_device_{nullptr};
Q
QI JUN 已提交
53
};
Q
qijun 已提交
54

Q
qijun 已提交
55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182
#ifndef PADDLE_ONLY_CPU

/// Scoped guard that switches the current CUDA device to `new_place` (if it
/// differs from the device reported by GetCurrentDeviceId() at construction)
/// and switches back to that previous device when the guard is destroyed.
///
/// Not copyable: a copy would perform the restore a second time, possibly
/// long after the original scope has ended.
class GPUPlaceGuard {
 public:
  explicit GPUPlaceGuard(GPUPlace new_place) : previous_(GetCurrentDeviceId()) {
    // Skip the device switch when we are already on the requested device.
    if (previous_ != new_place) {
      paddle::platform::SetDeviceId(new_place.device);
    }
  }

  GPUPlaceGuard(const GPUPlaceGuard&) = delete;
  GPUPlaceGuard& operator=(const GPUPlaceGuard&) = delete;

  // Restores unconditionally, so the caller's device is reinstated even if
  // code inside the guarded scope changed the current device itself.
  ~GPUPlaceGuard() { paddle::platform::SetDeviceId(previous_.device); }

 private:
  GPUPlace previous_;  // Device that was current when the guard was created.
};

/// Device context for a single CUDA device.
///
/// It owns the following, all bound to one stream:
/// - a CUDA stream;
/// - the Eigen GPU device bound to that stream;
/// - cuBLAS, cuDNN and cuRAND handles, created lazily on first request.
///
/// Not copyable: the destructor releases these raw handles, so a copy would
/// destroy them twice.
class CUDADeviceContext : public DeviceContext {
 public:
  explicit CUDADeviceContext(const GPUPlace gpu_place) : gpu_place_(gpu_place) {
    // Create the stream (and the Eigen device wrapping it) on the target GPU.
    GPUPlaceGuard guard(gpu_place_);
    paddle::platform::throw_on_error(cudaStreamCreate(&stream_),
                                     "cudaStreamCreate failed");
    eigen_stream_ = new Eigen::CudaStreamDevice(&stream_);
    eigen_device_ = new Eigen::GpuDevice(eigen_stream_);
  }

  CUDADeviceContext(const CUDADeviceContext&) = delete;
  CUDADeviceContext& operator=(const CUDADeviceContext&) = delete;

  /// Returns the GPU place this context was constructed for.
  Place GetPlace() const override {
    // Report the stored place: a freshly constructed GPUPlace() would not
    // carry this context's device id.
    Place retv = gpu_place_;
    return retv;
  }

  /// Blocks until all work queued on this context's stream has completed.
  void Wait() {
    paddle::platform::throw_on_error(cudaStreamSynchronize(stream_),
                                     "cudaStreamSynchronize failed");
  }

  /// Returns the CUDA stream owned by this context.
  cudaStream_t stream() { return stream_; }

  /// Returns (by value) the Eigen GPU device bound to this context's stream.
  Eigen::GpuDevice eigen_device() { return *eigen_device_; }

  /// Returns the cuBLAS handle, creating it on first use and binding it to
  /// this context's stream. Throws via PADDLE_ENFORCE on failure.
  cublasHandle_t cublas_handle() {
    if (!blas_handle_) {
      GPUPlaceGuard guard(gpu_place_);
      PADDLE_ENFORCE(paddle::platform::dynload::cublasCreate(&blas_handle_) ==
                         CUBLAS_STATUS_SUCCESS,
                     "cublasCreate failed");
      PADDLE_ENFORCE(paddle::platform::dynload::cublasSetStream(
                         blas_handle_, stream_) == CUBLAS_STATUS_SUCCESS,
                     "cublasSetStream failed");
    }
    return blas_handle_;
  }

  /// Returns the cuDNN handle, creating it on first use and binding it to
  /// this context's stream. Throws via PADDLE_ENFORCE on failure.
  cudnnHandle_t cudnn_handle() {
    if (!dnn_handle_) {
      GPUPlaceGuard guard(gpu_place_);
      PADDLE_ENFORCE(paddle::platform::dynload::cudnnCreate(&dnn_handle_) ==
                         CUDNN_STATUS_SUCCESS,
                     "cudnnCreate failed");
      PADDLE_ENFORCE(paddle::platform::dynload::cudnnSetStream(
                         dnn_handle_, stream_) == CUDNN_STATUS_SUCCESS,
                     "cudnnSetStream failed");
    }
    return dnn_handle_;
  }

  /// Returns the cuRAND pseudo-random generator, creating it on first use,
  /// seeding it with `random_seed_` and binding it to this context's stream.
  /// Throws via PADDLE_ENFORCE on failure.
  curandGenerator_t curand_generator() {
    if (!rand_generator_) {
      GPUPlaceGuard guard(gpu_place_);
      PADDLE_ENFORCE(paddle::platform::dynload::curandCreateGenerator(
                         &rand_generator_, CURAND_RNG_PSEUDO_DEFAULT) ==
                         CURAND_STATUS_SUCCESS,
                     "curandCreateGenerator failed");
      PADDLE_ENFORCE(
          paddle::platform::dynload::curandSetPseudoRandomGeneratorSeed(
              rand_generator_, random_seed_) == CURAND_STATUS_SUCCESS,
          "curandSetPseudoRandomGeneratorSeed failed");
      PADDLE_ENFORCE(paddle::platform::dynload::curandSetStream(
                         rand_generator_, stream_) == CURAND_STATUS_SUCCESS,
                     "curandSetStream failed");
    }
    return rand_generator_;
  }

  /// Drains the stream, then destroys any library handles that were created,
  /// the Eigen wrappers and finally the stream itself.
  /// NOTE: the checks below throw on failure; because destructors are
  /// implicitly noexcept, such a throw terminates the program.
  ~CUDADeviceContext() {
    // Release everything on the device it was created on; the guard restores
    // the caller's device when the destructor returns.
    GPUPlaceGuard guard(gpu_place_);
    Wait();
    if (blas_handle_) {
      PADDLE_ENFORCE(paddle::platform::dynload::cublasDestroy(blas_handle_) ==
                         CUBLAS_STATUS_SUCCESS,
                     "cublasDestroy failed");
    }

    if (dnn_handle_) {
      PADDLE_ENFORCE(paddle::platform::dynload::cudnnDestroy(dnn_handle_) ==
                         CUDNN_STATUS_SUCCESS,
                     "cudnnDestroy failed");
    }

    if (rand_generator_) {
      PADDLE_ENFORCE(paddle::platform::dynload::curandDestroyGenerator(
                         rand_generator_) == CURAND_STATUS_SUCCESS,
                     "curandDestroyGenerator failed");
    }

    // The GpuDevice holds a pointer to the stream wrapper, so free it first.
    delete eigen_device_;
    delete eigen_stream_;

    paddle::platform::throw_on_error(cudaStreamDestroy(stream_),
                                     "cudaStreamDestroy failed");
  }

 private:
  GPUPlace gpu_place_;
  cudaStream_t stream_{nullptr};

  Eigen::CudaStreamDevice* eigen_stream_{nullptr};
  Eigen::GpuDevice* eigen_device_{nullptr};

  cublasHandle_t blas_handle_{nullptr};

  cudnnHandle_t dnn_handle_{nullptr};

  // Seed for the lazily created cuRAND generator. It must be initialized:
  // curand_generator() passes it to curandSetPseudoRandomGeneratorSeed.
  int random_seed_{0};
  curandGenerator_t rand_generator_{nullptr};
};

#endif

Q
QI JUN 已提交
183 184
}  // namespace platform
}  // namespace paddle