提交 d951e9c7 编写于 作者: L liaogang

Fix: refine device context and fix place()

上级 89512dff
...@@ -20,12 +20,96 @@ Eigen::DefaultDevice* DeviceContext::get_eigen_device<Eigen::DefaultDevice>() ...@@ -20,12 +20,96 @@ Eigen::DefaultDevice* DeviceContext::get_eigen_device<Eigen::DefaultDevice>()
return reinterpret_cast<const CPUDeviceContext*>(this)->eigen_device(); return reinterpret_cast<const CPUDeviceContext*>(this)->eigen_device();
} }
CPUDeviceContext::CPUDeviceContext() {
eigen_device_.reset(new Eigen::DefaultDevice());
}
CPUDeviceContext::CPUDeviceContext(CPUPlace place) {
eigen_device_.reset(new Eigen::DefaultDevice());
}
Eigen::DefaultDevice* CPUDeviceContext::eigen_device() const {
return eigen_device_.get();
}
Place CPUDeviceContext::place() const { return CPUPlace(); }
#ifndef PADDLE_ONLY_CPU #ifndef PADDLE_ONLY_CPU
template <> template <>
Eigen::GpuDevice* DeviceContext::get_eigen_device<Eigen::GpuDevice>() const { Eigen::GpuDevice* DeviceContext::get_eigen_device<Eigen::GpuDevice>() const {
return reinterpret_cast<const CUDADeviceContext*>(this)->eigen_device(); return reinterpret_cast<const CUDADeviceContext*>(this)->eigen_device();
} }
#endif
CUDADeviceContext::CUDADeviceContext(GPUPlace place) : place_(place) {
SetDeviceId(place_.device);
PADDLE_ENFORCE(cudaStreamCreate(&stream_));
eigen_stream_.reset(new Eigen::CudaStreamDevice(&stream_));
eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get()));
}
CUDADeviceContext::~CUDADeviceContext() {
SetDeviceId(place_.device);
wait();
if (cublas_handle_) {
PADDLE_ENFORCE(dynload::cublasDestroy(cublas_handle_));
}
if (cudnn_handle_) {
PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_));
}
if (curand_generator_) {
PADDLE_ENFORCE(dynload::curandDestroyGenerator(curand_generator_));
}
eigen_stream_.reset();
eigen_device_.reset();
PADDLE_ENFORCE(cudaStreamDestroy(stream_));
}
Place CUDADeviceContext::place() const { return place_; }
cudaStream_t CUDADeviceContext::stream() const { return stream_; }
void CUDADeviceContext::wait() const {
PADDLE_ENFORCE(cudaStreamSynchronize(stream_));
}
Eigen::GpuDevice* CUDADeviceContext::eigen_device() const {
return eigen_device_.get();
}
cublasHandle_t CUDADeviceContext::cublas_handle() {
if (!cublas_handle_) {
SetDeviceId(place_.device);
PADDLE_ENFORCE(dynload::cublasCreate(&cublas_handle_));
PADDLE_ENFORCE(dynload::cublasSetStream(cublas_handle_, stream_));
}
return cublas_handle_;
}
cudnnHandle_t CUDADeviceContext::cudnn_handle() {
if (!cudnn_handle_) {
SetDeviceId(place_.device);
PADDLE_ENFORCE(dynload::cudnnCreate(&cudnn_handle_));
PADDLE_ENFORCE(dynload::cudnnSetStream(cudnn_handle_, stream_));
}
return cudnn_handle_;
}
curandGenerator_t CUDADeviceContext::curand_generator() {
if (!curand_generator_) {
SetDeviceId(place_.device);
PADDLE_ENFORCE(dynload::curandCreateGenerator(&curand_generator_,
CURAND_RNG_PSEUDO_DEFAULT));
PADDLE_ENFORCE(
dynload::curandSetPseudoRandomGeneratorSeed(curand_generator_, seed_));
PADDLE_ENFORCE(dynload::curandSetStream(curand_generator_, stream_));
}
return curand_generator_;
}
#endif // PADDLE_ONLY_CPU
} // namespace platform } // namespace platform
} // namespace paddle } // namespace paddle
...@@ -31,7 +31,7 @@ namespace platform { ...@@ -31,7 +31,7 @@ namespace platform {
class DeviceContext { class DeviceContext {
public: public:
virtual ~DeviceContext() {} virtual ~DeviceContext() {}
virtual Place GetPlace() const = 0; virtual Place place() const = 0;
template <typename DeviceType> template <typename DeviceType>
DeviceType* get_eigen_device() const; DeviceType* get_eigen_device() const;
...@@ -39,14 +39,13 @@ class DeviceContext { ...@@ -39,14 +39,13 @@ class DeviceContext {
class CPUDeviceContext : public DeviceContext { class CPUDeviceContext : public DeviceContext {
public: public:
CPUDeviceContext() { eigen_device_.reset(new Eigen::DefaultDevice()); } CPUDeviceContext();
CPUDeviceContext(CPUPlace);
virtual ~CPUDeviceContext() {}
Eigen::DefaultDevice* eigen_device() const { return eigen_device_.get(); } Eigen::DefaultDevice* eigen_device() const;
Place GetPlace() const override { Place place() const override;
Place retv = CPUPlace();
return retv;
}
private: private:
std::unique_ptr<Eigen::DefaultDevice> eigen_device_; std::unique_ptr<Eigen::DefaultDevice> eigen_device_;
...@@ -54,119 +53,51 @@ class CPUDeviceContext : public DeviceContext { ...@@ -54,119 +53,51 @@ class CPUDeviceContext : public DeviceContext {
#ifndef PADDLE_ONLY_CPU #ifndef PADDLE_ONLY_CPU
class GPUPlaceGuard { class CUDADeviceContext : public DeviceContext {
public: public:
explicit GPUPlaceGuard(GPUPlace new_place) : previous_(GetCurrentDeviceId()) { explicit CUDADeviceContext(GPUPlace);
if (previous_ != new_place) { virtual ~CUDADeviceContext();
paddle::platform::SetDeviceId(new_place.device);
}
}
~GPUPlaceGuard() { paddle::platform::SetDeviceId(previous_.device); } /*! \brief Wait for all operations completion in the stream. */
void wait() const;
private: /*! \brief Return CUDA stream in the device context. */
GPUPlace previous_; cudaStream_t stream() const;
};
class CUDADeviceContext : public DeviceContext { /*! \brief Return place in the device context. */
public: Place place() const override;
explicit CUDADeviceContext(const GPUPlace gpu_place) : gpu_place_(gpu_place) {
GPUPlaceGuard guard(gpu_place_); /*! \brief Return eigen device in the device context. */
PADDLE_ENFORCE(cudaStreamCreate(&stream_), "cudaStreamCreate failed"); Eigen::GpuDevice* eigen_device() const;
eigen_stream_.reset(new Eigen::CudaStreamDevice(&stream_));
eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get())); // clang-format off
} /*! \brief Return cublas handle in the device context. */
cublasHandle_t cublas_handle ();
Place GetPlace() const override {
Place retv = GPUPlace(); /*! \brief Return cudnn handle in the device context. */
return retv; cudnnHandle_t cudnn_handle ();
}
/*! \brief Return curand handle in the device context. */
void Wait() { curandGenerator_t curand_generator();
PADDLE_ENFORCE(cudaStreamSynchronize(stream_), // clang-format on
"cudaStreamSynchronize failed");
}
cudaStream_t stream() { return stream_; }
Eigen::GpuDevice* eigen_device() const { return eigen_device_.get(); }
cublasHandle_t cublas_handle() {
if (!blas_handle_) {
GPUPlaceGuard guard(gpu_place_);
PADDLE_ENFORCE(paddle::platform::dynload::cublasCreate(&blas_handle_),
"cublasCreate failed");
PADDLE_ENFORCE(
paddle::platform::dynload::cublasSetStream(blas_handle_, stream_),
"cublasSetStream failed");
}
return blas_handle_;
}
cudnnHandle_t cudnn_handle() {
if (!dnn_handle_) {
GPUPlaceGuard guard(gpu_place_);
PADDLE_ENFORCE(paddle::platform::dynload::cudnnCreate(&dnn_handle_),
"cudnnCreate failed");
PADDLE_ENFORCE(
paddle::platform::dynload::cudnnSetStream(dnn_handle_, stream_),
"cudnnSetStream failed");
}
return dnn_handle_;
}
curandGenerator_t curand_generator() {
if (!rand_generator_) {
GPUPlaceGuard guard(gpu_place_);
PADDLE_ENFORCE(paddle::platform::dynload::curandCreateGenerator(
&rand_generator_, CURAND_RNG_PSEUDO_DEFAULT),
"curandCreateGenerator failed");
PADDLE_ENFORCE(
paddle::platform::dynload::curandSetPseudoRandomGeneratorSeed(
rand_generator_, random_seed_),
"curandSetPseudoRandomGeneratorSeed failed");
PADDLE_ENFORCE(
paddle::platform::dynload::curandSetStream(rand_generator_, stream_),
"curandSetStream failed");
}
return rand_generator_;
}
~CUDADeviceContext() {
Wait();
if (blas_handle_) {
PADDLE_ENFORCE(paddle::platform::dynload::cublasDestroy(blas_handle_),
"cublasDestroy failed");
}
if (dnn_handle_) {
PADDLE_ENFORCE(paddle::platform::dynload::cudnnDestroy(dnn_handle_),
"cudnnDestroy failed");
}
if (rand_generator_) {
PADDLE_ENFORCE(
paddle::platform::dynload::curandDestroyGenerator(rand_generator_),
"curandDestroyGenerator failed");
}
eigen_stream_.reset();
eigen_device_.reset();
PADDLE_ENFORCE(cudaStreamDestroy(stream_), "cudaStreamDestroy failed");
}
private: private:
GPUPlace gpu_place_; GPUPlace place_;
cudaStream_t stream_;
std::unique_ptr<Eigen::CudaStreamDevice> eigen_stream_; private:
std::unique_ptr<Eigen::GpuDevice> eigen_device_; std::unique_ptr<Eigen::GpuDevice> eigen_device_;
std::unique_ptr<Eigen::CudaStreamDevice> eigen_stream_;
cublasHandle_t blas_handle_{nullptr}; private:
uint64_t seed_;
cudnnHandle_t dnn_handle_{nullptr}; cudaStream_t stream_;
int random_seed_; // clang-format off
curandGenerator_t rand_generator_{nullptr}; cudnnHandle_t cudnn_handle_ = nullptr;
cublasHandle_t cublas_handle_ = nullptr;
curandGenerator_t curand_generator_ = nullptr;
// clang-format on
}; };
#endif #endif
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册