提交 a30754b0 编写于 作者: Q qijun

test device_context

上级 3ba7a738
...@@ -2,3 +2,6 @@ nv_test(cuda_test SRCS cuda_test.cu) ...@@ -2,3 +2,6 @@ nv_test(cuda_test SRCS cuda_test.cu)
cc_library(place SRCS place.cc) cc_library(place SRCS place.cc)
cc_test(place_test SRCS place_test.cc DEPS place glog gflags) cc_test(place_test SRCS place_test.cc DEPS place glog gflags)
cc_library(dynamic_loader SRCS dynamic_loader.cc)
nv_test(device_context_test SRCS device_context_test.cu DEPS place dynamic_loader glog gflags)
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#ifndef PADDLE_ONLY_CPU
#include "paddle/platform/cublas.h"
#include "paddle/platform/cuda.h"
#include "paddle/platform/cudnn.h"
#include "paddle/platform/curand.h"
#define EIGEN_USE_GPU
#endif
#include "paddle/framework/enforce.h"
#include "paddle/platform/place.h"
#include "unsupported/Eigen/CXX11/Tensor"
namespace paddle {
namespace platform {
class DeviceContext {
public:
virtual ~DeviceContext() {}
};
class CpuDeviceContext : public DeviceContext {
Eigen::DefaultDevice eigen_device() {
if (!eigen_device_) {
eigen_device_ = new Eigen::DefaultDevice();
}
return *eigen_device_;
}
private:
Eigen::DefaultDevice* eigen_device_{nullptr};
};
#ifndef PADDLE_ONLY_CPU
class DeviceGuard {
public:
explicit DeviceGuard(GPUPlace new_place) : previous_(GetCurrentDeviceId()) {
if (previous_ != new_place) {
paddle::platform::SetDeviceId(new_place.device);
}
}
~DeviceGuard() { paddle::platform::SetDeviceId(previous_.device); }
private:
GPUPlace previous_;
};
class CudaDeviceContext : public DeviceContext {
public:
explicit CudaDeviceContext(const GPUPlace gpu_place) : gpu_place_(gpu_place) {
DeviceGuard guard(gpu_place_);
paddle::platform::throw_on_error(cudaStreamCreate(&stream_),
"cudaStreamCreate failed");
eigen_stream_ = new Eigen::CudaStreamDevice(&stream_);
eigen_device_ = new Eigen::GpuDevice(eigen_stream_);
}
void Wait() {
paddle::platform::throw_on_error(cudaStreamSynchronize(stream_),
"cudaStreamSynchronize failed");
}
cudaStream_t stream() { return stream_; }
Eigen::GpuDevice eigen_device() { return *eigen_device_; }
cublasHandle_t cublas_handle() {
if (!blas_handle_) {
DeviceGuard guard(gpu_place_);
PADDLE_ENFORCE(cublasCreate(&blas_handle_) == CUBLAS_STATUS_SUCCESS,
"cublasCreate failed");
PADDLE_ENFORCE(
cublasSetStream(blas_handle_, stream_) == CUBLAS_STATUS_SUCCESS,
"cublasSetStream failed");
}
return blas_handle_;
}
cudnnHandle_t cudnn_handle() {
if (!dnn_handle_) {
DeviceGuard guard(gpu_place_);
PADDLE_ENFORCE(cudnnCreate(&dnn_handle_) == CUDNN_STATUS_SUCCESS,
"cudnnCreate failed");
PADDLE_ENFORCE(
cudnnSetStream(dnn_handle_, stream_) == CUDNN_STATUS_SUCCESS,
"cudnnSetStream failed");
}
return dnn_handle_;
}
curandGenerator_t curand_generator() {
if (!rand_generator_) {
DeviceGuard guard(gpu_place_);
PADDLE_ENFORCE(
curandCreateGenerator(&rand_generator_, CURAND_RNG_PSEUDO_DEFAULT) ==
CURAND_STATUS_SUCCESS,
"curandCreateGenerator failed");
PADDLE_ENFORCE(
curandSetPseudoRandomGeneratorSeed(rand_generator_, random_seed_) ==
CURAND_STATUS_SUCCESS,
"curandSetPseudoRandomGeneratorSeed failed");
PADDLE_ENFORCE(
curandSetStream(rand_generator_, stream_) == CURAND_STATUS_SUCCESS,
"curandSetStream failed");
}
return rand_generator_;
}
~CudaDeviceContext() {
Wait();
if (blas_handle_) {
PADDLE_ENFORCE(cublasDestroy(blas_handle_) == CUBLAS_STATUS_SUCCESS,
"cublasDestroy failed");
}
if (dnn_handle_) {
PADDLE_ENFORCE(cudnnDestroy(dnn_handle_) == CUDNN_STATUS_SUCCESS,
"cudnnDestroy failed");
}
if (rand_generator_) {
PADDLE_ENFORCE(
curandDestroyGenerator(rand_generator_) == CURAND_STATUS_SUCCESS,
"curandDestroyGenerator failed");
}
delete eigen_stream_;
delete eigen_device_;
paddle::platform::throw_on_error(cudaStreamDestroy(stream_),
"cudaStreamDestroy failed");
}
private:
GPUPlace gpu_place_;
cudaStream_t stream_;
Eigen::CudaStreamDevice* eigen_stream_;
Eigen::GpuDevice* eigen_device_;
cublasHandle_t blas_handle_{nullptr};
cudnnHandle_t dnn_handle_{nullptr};
int random_seed_;
curandGenerator_t rand_generator_{nullptr};
};
#endif
} // namespace platform
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/platform/device_context.h"
#include "gtest/gtest.h"
TEST(DeviceContext, CudaDevice) {
int count = paddle::platform::GetDeviceCount();
for (int i = 0; i < count; i++) {
paddle::platform::CudaDeviceContext* device_context = new paddle::platform::CudaDeviceContext(i);
__attribute__((unused)) Eigen::GpuDevice gpu_device = device_context->eigen_device();
__attribute__((unused)) cudnnHandle_t cudnn_handle = device_context->cudnn_handle();
__attribute__((unused)) cublasHandle_t cublas_handle = device_context->cublas_handle();
__attribute__((unused)) curandGenerator_t curand_handle = device_context->curand_generator();
delete device_context;
}
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册