提交 d951e9c7 编写于 作者: L liaogang

Fix: refine device context and fix place()

上级 89512dff
......@@ -20,12 +20,96 @@ Eigen::DefaultDevice* DeviceContext::get_eigen_device<Eigen::DefaultDevice>()
return reinterpret_cast<const CPUDeviceContext*>(this)->eigen_device();
}
CPUDeviceContext::CPUDeviceContext() {
eigen_device_.reset(new Eigen::DefaultDevice());
}
CPUDeviceContext::CPUDeviceContext(CPUPlace place) {
eigen_device_.reset(new Eigen::DefaultDevice());
}
Eigen::DefaultDevice* CPUDeviceContext::eigen_device() const {
return eigen_device_.get();
}
Place CPUDeviceContext::place() const { return CPUPlace(); }
#ifndef PADDLE_ONLY_CPU
template <>
Eigen::GpuDevice* DeviceContext::get_eigen_device<Eigen::GpuDevice>() const {
return reinterpret_cast<const CUDADeviceContext*>(this)->eigen_device();
}
#endif
CUDADeviceContext::CUDADeviceContext(GPUPlace place) : place_(place) {
SetDeviceId(place_.device);
PADDLE_ENFORCE(cudaStreamCreate(&stream_));
eigen_stream_.reset(new Eigen::CudaStreamDevice(&stream_));
eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get()));
}
CUDADeviceContext::~CUDADeviceContext() {
SetDeviceId(place_.device);
wait();
if (cublas_handle_) {
PADDLE_ENFORCE(dynload::cublasDestroy(cublas_handle_));
}
if (cudnn_handle_) {
PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_));
}
if (curand_generator_) {
PADDLE_ENFORCE(dynload::curandDestroyGenerator(curand_generator_));
}
eigen_stream_.reset();
eigen_device_.reset();
PADDLE_ENFORCE(cudaStreamDestroy(stream_));
}
Place CUDADeviceContext::place() const { return place_; }
cudaStream_t CUDADeviceContext::stream() const { return stream_; }
void CUDADeviceContext::wait() const {
PADDLE_ENFORCE(cudaStreamSynchronize(stream_));
}
Eigen::GpuDevice* CUDADeviceContext::eigen_device() const {
return eigen_device_.get();
}
cublasHandle_t CUDADeviceContext::cublas_handle() {
if (!cublas_handle_) {
SetDeviceId(place_.device);
PADDLE_ENFORCE(dynload::cublasCreate(&cublas_handle_));
PADDLE_ENFORCE(dynload::cublasSetStream(cublas_handle_, stream_));
}
return cublas_handle_;
}
cudnnHandle_t CUDADeviceContext::cudnn_handle() {
if (!cudnn_handle_) {
SetDeviceId(place_.device);
PADDLE_ENFORCE(dynload::cudnnCreate(&cudnn_handle_));
PADDLE_ENFORCE(dynload::cudnnSetStream(cudnn_handle_, stream_));
}
return cudnn_handle_;
}
curandGenerator_t CUDADeviceContext::curand_generator() {
if (!curand_generator_) {
SetDeviceId(place_.device);
PADDLE_ENFORCE(dynload::curandCreateGenerator(&curand_generator_,
CURAND_RNG_PSEUDO_DEFAULT));
PADDLE_ENFORCE(
dynload::curandSetPseudoRandomGeneratorSeed(curand_generator_, seed_));
PADDLE_ENFORCE(dynload::curandSetStream(curand_generator_, stream_));
}
return curand_generator_;
}
#endif // PADDLE_ONLY_CPU
} // namespace platform
} // namespace paddle
......@@ -31,7 +31,7 @@ namespace platform {
class DeviceContext {
public:
virtual ~DeviceContext() {}
virtual Place GetPlace() const = 0;
virtual Place place() const = 0;
template <typename DeviceType>
DeviceType* get_eigen_device() const;
......@@ -39,14 +39,13 @@ class DeviceContext {
class CPUDeviceContext : public DeviceContext {
public:
CPUDeviceContext() { eigen_device_.reset(new Eigen::DefaultDevice()); }
CPUDeviceContext();
CPUDeviceContext(CPUPlace);
virtual ~CPUDeviceContext() {}
Eigen::DefaultDevice* eigen_device() const { return eigen_device_.get(); }
Eigen::DefaultDevice* eigen_device() const;
Place GetPlace() const override {
Place retv = CPUPlace();
return retv;
}
Place place() const override;
private:
std::unique_ptr<Eigen::DefaultDevice> eigen_device_;
......@@ -54,119 +53,51 @@ class CPUDeviceContext : public DeviceContext {
#ifndef PADDLE_ONLY_CPU
class GPUPlaceGuard {
class CUDADeviceContext : public DeviceContext {
public:
explicit GPUPlaceGuard(GPUPlace new_place) : previous_(GetCurrentDeviceId()) {
if (previous_ != new_place) {
paddle::platform::SetDeviceId(new_place.device);
}
}
explicit CUDADeviceContext(GPUPlace);
virtual ~CUDADeviceContext();
~GPUPlaceGuard() { paddle::platform::SetDeviceId(previous_.device); }
/*! \brief Wait for all operations completion in the stream. */
void wait() const;
private:
GPUPlace previous_;
};
/*! \brief Return CUDA stream in the device context. */
cudaStream_t stream() const;
class CUDADeviceContext : public DeviceContext {
public:
explicit CUDADeviceContext(const GPUPlace gpu_place) : gpu_place_(gpu_place) {
GPUPlaceGuard guard(gpu_place_);
PADDLE_ENFORCE(cudaStreamCreate(&stream_), "cudaStreamCreate failed");
eigen_stream_.reset(new Eigen::CudaStreamDevice(&stream_));
eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get()));
}
Place GetPlace() const override {
Place retv = GPUPlace();
return retv;
}
void Wait() {
PADDLE_ENFORCE(cudaStreamSynchronize(stream_),
"cudaStreamSynchronize failed");
}
cudaStream_t stream() { return stream_; }
Eigen::GpuDevice* eigen_device() const { return eigen_device_.get(); }
cublasHandle_t cublas_handle() {
if (!blas_handle_) {
GPUPlaceGuard guard(gpu_place_);
PADDLE_ENFORCE(paddle::platform::dynload::cublasCreate(&blas_handle_),
"cublasCreate failed");
PADDLE_ENFORCE(
paddle::platform::dynload::cublasSetStream(blas_handle_, stream_),
"cublasSetStream failed");
}
return blas_handle_;
}
cudnnHandle_t cudnn_handle() {
if (!dnn_handle_) {
GPUPlaceGuard guard(gpu_place_);
PADDLE_ENFORCE(paddle::platform::dynload::cudnnCreate(&dnn_handle_),
"cudnnCreate failed");
PADDLE_ENFORCE(
paddle::platform::dynload::cudnnSetStream(dnn_handle_, stream_),
"cudnnSetStream failed");
}
return dnn_handle_;
}
curandGenerator_t curand_generator() {
if (!rand_generator_) {
GPUPlaceGuard guard(gpu_place_);
PADDLE_ENFORCE(paddle::platform::dynload::curandCreateGenerator(
&rand_generator_, CURAND_RNG_PSEUDO_DEFAULT),
"curandCreateGenerator failed");
PADDLE_ENFORCE(
paddle::platform::dynload::curandSetPseudoRandomGeneratorSeed(
rand_generator_, random_seed_),
"curandSetPseudoRandomGeneratorSeed failed");
PADDLE_ENFORCE(
paddle::platform::dynload::curandSetStream(rand_generator_, stream_),
"curandSetStream failed");
}
return rand_generator_;
}
~CUDADeviceContext() {
Wait();
if (blas_handle_) {
PADDLE_ENFORCE(paddle::platform::dynload::cublasDestroy(blas_handle_),
"cublasDestroy failed");
}
if (dnn_handle_) {
PADDLE_ENFORCE(paddle::platform::dynload::cudnnDestroy(dnn_handle_),
"cudnnDestroy failed");
}
if (rand_generator_) {
PADDLE_ENFORCE(
paddle::platform::dynload::curandDestroyGenerator(rand_generator_),
"curandDestroyGenerator failed");
}
eigen_stream_.reset();
eigen_device_.reset();
PADDLE_ENFORCE(cudaStreamDestroy(stream_), "cudaStreamDestroy failed");
}
/*! \brief Return place in the device context. */
Place place() const override;
/*! \brief Return eigen device in the device context. */
Eigen::GpuDevice* eigen_device() const;
// clang-format off
/*! \brief Return cublas handle in the device context. */
cublasHandle_t cublas_handle ();
/*! \brief Return cudnn handle in the device context. */
cudnnHandle_t cudnn_handle ();
/*! \brief Return curand handle in the device context. */
curandGenerator_t curand_generator();
// clang-format on
private:
GPUPlace gpu_place_;
cudaStream_t stream_;
GPUPlace place_;
std::unique_ptr<Eigen::CudaStreamDevice> eigen_stream_;
private:
std::unique_ptr<Eigen::GpuDevice> eigen_device_;
std::unique_ptr<Eigen::CudaStreamDevice> eigen_stream_;
cublasHandle_t blas_handle_{nullptr};
private:
uint64_t seed_;
cudnnHandle_t dnn_handle_{nullptr};
cudaStream_t stream_;
int random_seed_;
curandGenerator_t rand_generator_{nullptr};
// clang-format off
cudnnHandle_t cudnn_handle_ = nullptr;
cublasHandle_t cublas_handle_ = nullptr;
curandGenerator_t curand_generator_ = nullptr;
// clang-format on
};
#endif
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册