提交 bcc6e5ce 编写于 作者: G gangliao 提交者: GitHub

Merge pull request #3084 from gangliao/device_ctx

Refine device context and fix GetPlace()
add_subdirectory(detail) add_subdirectory(detail)
cc_library(memory SRCS memory.cc) cc_library(memory SRCS memory.cc)
cc_library(memcpy SRCS memcpy.cc DEPS device_context) cc_library(memcpy SRCS memcpy.cc)
cc_library(paddle_memory cc_library(paddle_memory
DEPS DEPS
......
...@@ -35,7 +35,7 @@ void Copy<platform::CPUPlace, platform::GPUPlace>(platform::CPUPlace dst_place, ...@@ -35,7 +35,7 @@ void Copy<platform::CPUPlace, platform::GPUPlace>(platform::CPUPlace dst_place,
platform::GPUPlace src_place, platform::GPUPlace src_place,
const void* src, size_t num, const void* src, size_t num,
cudaStream_t stream) { cudaStream_t stream) {
platform::GPUPlaceGuard g(src_place.device); platform::SetDeviceId(src_place.device);
platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToHost, stream); platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToHost, stream);
} }
...@@ -45,7 +45,7 @@ void Copy<platform::GPUPlace, platform::CPUPlace>(platform::GPUPlace dst_place, ...@@ -45,7 +45,7 @@ void Copy<platform::GPUPlace, platform::CPUPlace>(platform::GPUPlace dst_place,
platform::CPUPlace src_place, platform::CPUPlace src_place,
const void* src, size_t num, const void* src, size_t num,
cudaStream_t stream) { cudaStream_t stream) {
platform::GPUPlaceGuard g(dst_place.device); platform::SetDeviceId(dst_place.device);
platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyHostToDevice, stream); platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyHostToDevice, stream);
} }
...@@ -56,7 +56,7 @@ void Copy<platform::GPUPlace, platform::GPUPlace>(platform::GPUPlace dst_place, ...@@ -56,7 +56,7 @@ void Copy<platform::GPUPlace, platform::GPUPlace>(platform::GPUPlace dst_place,
const void* src, size_t num, const void* src, size_t num,
cudaStream_t stream) { cudaStream_t stream) {
if (dst_place == src_place) { if (dst_place == src_place) {
platform::GPUPlaceGuard g(src_place.device); platform::SetDeviceId(src_place.device);
platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToDevice, stream); platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToDevice, stream);
} else { } else {
platform::GpuMemcpyPeer(dst, dst_place.device, src, src_place.device, num, platform::GpuMemcpyPeer(dst, dst_place.device, src, src_place.device, num,
......
...@@ -20,12 +20,96 @@ Eigen::DefaultDevice* DeviceContext::get_eigen_device<Eigen::DefaultDevice>() ...@@ -20,12 +20,96 @@ Eigen::DefaultDevice* DeviceContext::get_eigen_device<Eigen::DefaultDevice>()
return reinterpret_cast<const CPUDeviceContext*>(this)->eigen_device(); return reinterpret_cast<const CPUDeviceContext*>(this)->eigen_device();
} }
CPUDeviceContext::CPUDeviceContext() {
eigen_device_.reset(new Eigen::DefaultDevice());
}
CPUDeviceContext::CPUDeviceContext(CPUPlace place) {
eigen_device_.reset(new Eigen::DefaultDevice());
}
Eigen::DefaultDevice* CPUDeviceContext::eigen_device() const {
return eigen_device_.get();
}
Place CPUDeviceContext::GetPlace() const { return CPUPlace(); }
#ifndef PADDLE_ONLY_CPU #ifndef PADDLE_ONLY_CPU
template <> template <>
Eigen::GpuDevice* DeviceContext::get_eigen_device<Eigen::GpuDevice>() const { Eigen::GpuDevice* DeviceContext::get_eigen_device<Eigen::GpuDevice>() const {
return reinterpret_cast<const CUDADeviceContext*>(this)->eigen_device(); return reinterpret_cast<const CUDADeviceContext*>(this)->eigen_device();
} }
#endif
CUDADeviceContext::CUDADeviceContext(GPUPlace place) : place_(place) {
SetDeviceId(place_.device);
PADDLE_ENFORCE(cudaStreamCreate(&stream_));
eigen_stream_.reset(new Eigen::CudaStreamDevice(&stream_));
eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get()));
}
CUDADeviceContext::~CUDADeviceContext() {
SetDeviceId(place_.device);
Wait();
if (cublas_handle_) {
PADDLE_ENFORCE(dynload::cublasDestroy(cublas_handle_));
}
if (cudnn_handle_) {
PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_));
}
if (curand_generator_) {
PADDLE_ENFORCE(dynload::curandDestroyGenerator(curand_generator_));
}
eigen_stream_.reset();
eigen_device_.reset();
PADDLE_ENFORCE(cudaStreamDestroy(stream_));
}
Place CUDADeviceContext::GetPlace() const { return place_; }
cudaStream_t CUDADeviceContext::stream() const { return stream_; }
void CUDADeviceContext::Wait() const {
PADDLE_ENFORCE(cudaStreamSynchronize(stream_));
}
Eigen::GpuDevice* CUDADeviceContext::eigen_device() const {
return eigen_device_.get();
}
cublasHandle_t CUDADeviceContext::cublas_handle() {
if (!cublas_handle_) {
SetDeviceId(place_.device);
PADDLE_ENFORCE(dynload::cublasCreate(&cublas_handle_));
PADDLE_ENFORCE(dynload::cublasSetStream(cublas_handle_, stream_));
}
return cublas_handle_;
}
cudnnHandle_t CUDADeviceContext::cudnn_handle() {
if (!cudnn_handle_) {
SetDeviceId(place_.device);
PADDLE_ENFORCE(dynload::cudnnCreate(&cudnn_handle_));
PADDLE_ENFORCE(dynload::cudnnSetStream(cudnn_handle_, stream_));
}
return cudnn_handle_;
}
curandGenerator_t CUDADeviceContext::curand_generator() {
if (!curand_generator_) {
SetDeviceId(place_.device);
PADDLE_ENFORCE(dynload::curandCreateGenerator(&curand_generator_,
CURAND_RNG_PSEUDO_DEFAULT));
PADDLE_ENFORCE(
dynload::curandSetPseudoRandomGeneratorSeed(curand_generator_, seed_));
PADDLE_ENFORCE(dynload::curandSetStream(curand_generator_, stream_));
}
return curand_generator_;
}
#endif // PADDLE_ONLY_CPU
} // namespace platform } // namespace platform
} // namespace paddle } // namespace paddle
...@@ -39,14 +39,13 @@ class DeviceContext { ...@@ -39,14 +39,13 @@ class DeviceContext {
class CPUDeviceContext : public DeviceContext { class CPUDeviceContext : public DeviceContext {
public: public:
CPUDeviceContext() { eigen_device_.reset(new Eigen::DefaultDevice()); } CPUDeviceContext();
CPUDeviceContext(CPUPlace);
virtual ~CPUDeviceContext() {}
Eigen::DefaultDevice* eigen_device() const { return eigen_device_.get(); } Eigen::DefaultDevice* eigen_device() const;
Place GetPlace() const override { Place GetPlace() const override;
Place retv = CPUPlace();
return retv;
}
private: private:
std::unique_ptr<Eigen::DefaultDevice> eigen_device_; std::unique_ptr<Eigen::DefaultDevice> eigen_device_;
...@@ -54,119 +53,51 @@ class CPUDeviceContext : public DeviceContext { ...@@ -54,119 +53,51 @@ class CPUDeviceContext : public DeviceContext {
#ifndef PADDLE_ONLY_CPU #ifndef PADDLE_ONLY_CPU
class GPUPlaceGuard { class CUDADeviceContext : public DeviceContext {
public: public:
explicit GPUPlaceGuard(GPUPlace new_place) : previous_(GetCurrentDeviceId()) { explicit CUDADeviceContext(GPUPlace);
if (previous_ != new_place) { virtual ~CUDADeviceContext();
paddle::platform::SetDeviceId(new_place.device);
}
}
~GPUPlaceGuard() { paddle::platform::SetDeviceId(previous_.device); } /*! \brief Wait for all operations completion in the stream. */
void Wait() const;
private: /*! \brief Return CUDA stream in the device context. */
GPUPlace previous_; cudaStream_t stream() const;
};
class CUDADeviceContext : public DeviceContext { /*! \brief Return place in the device context. */
public: Place GetPlace() const override;
explicit CUDADeviceContext(const GPUPlace gpu_place) : gpu_place_(gpu_place) {
GPUPlaceGuard guard(gpu_place_); /*! \brief Return eigen device in the device context. */
PADDLE_ENFORCE(cudaStreamCreate(&stream_), "cudaStreamCreate failed"); Eigen::GpuDevice* eigen_device() const;
eigen_stream_.reset(new Eigen::CudaStreamDevice(&stream_));
eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get())); // clang-format off
} /*! \brief Return cublas handle in the device context. */
cublasHandle_t cublas_handle ();
Place GetPlace() const override {
Place retv = GPUPlace(); /*! \brief Return cudnn handle in the device context. */
return retv; cudnnHandle_t cudnn_handle ();
}
/*! \brief Return curand handle in the device context. */
void Wait() { curandGenerator_t curand_generator();
PADDLE_ENFORCE(cudaStreamSynchronize(stream_), // clang-format on
"cudaStreamSynchronize failed");
}
cudaStream_t stream() const { return stream_; }
Eigen::GpuDevice* eigen_device() const { return eigen_device_.get(); }
cublasHandle_t cublas_handle() {
if (!blas_handle_) {
GPUPlaceGuard guard(gpu_place_);
PADDLE_ENFORCE(paddle::platform::dynload::cublasCreate(&blas_handle_),
"cublasCreate failed");
PADDLE_ENFORCE(
paddle::platform::dynload::cublasSetStream(blas_handle_, stream_),
"cublasSetStream failed");
}
return blas_handle_;
}
cudnnHandle_t cudnn_handle() {
if (!dnn_handle_) {
GPUPlaceGuard guard(gpu_place_);
PADDLE_ENFORCE(paddle::platform::dynload::cudnnCreate(&dnn_handle_),
"cudnnCreate failed");
PADDLE_ENFORCE(
paddle::platform::dynload::cudnnSetStream(dnn_handle_, stream_),
"cudnnSetStream failed");
}
return dnn_handle_;
}
curandGenerator_t curand_generator() {
if (!rand_generator_) {
GPUPlaceGuard guard(gpu_place_);
PADDLE_ENFORCE(paddle::platform::dynload::curandCreateGenerator(
&rand_generator_, CURAND_RNG_PSEUDO_DEFAULT),
"curandCreateGenerator failed");
PADDLE_ENFORCE(
paddle::platform::dynload::curandSetPseudoRandomGeneratorSeed(
rand_generator_, random_seed_),
"curandSetPseudoRandomGeneratorSeed failed");
PADDLE_ENFORCE(
paddle::platform::dynload::curandSetStream(rand_generator_, stream_),
"curandSetStream failed");
}
return rand_generator_;
}
~CUDADeviceContext() {
Wait();
if (blas_handle_) {
PADDLE_ENFORCE(paddle::platform::dynload::cublasDestroy(blas_handle_),
"cublasDestroy failed");
}
if (dnn_handle_) {
PADDLE_ENFORCE(paddle::platform::dynload::cudnnDestroy(dnn_handle_),
"cudnnDestroy failed");
}
if (rand_generator_) {
PADDLE_ENFORCE(
paddle::platform::dynload::curandDestroyGenerator(rand_generator_),
"curandDestroyGenerator failed");
}
eigen_stream_.reset();
eigen_device_.reset();
PADDLE_ENFORCE(cudaStreamDestroy(stream_), "cudaStreamDestroy failed");
}
private: private:
GPUPlace gpu_place_; GPUPlace place_;
cudaStream_t stream_;
std::unique_ptr<Eigen::CudaStreamDevice> eigen_stream_; private:
std::unique_ptr<Eigen::GpuDevice> eigen_device_; std::unique_ptr<Eigen::GpuDevice> eigen_device_;
std::unique_ptr<Eigen::CudaStreamDevice> eigen_stream_;
cublasHandle_t blas_handle_{nullptr}; private:
uint64_t seed_;
cudnnHandle_t dnn_handle_{nullptr}; cudaStream_t stream_;
int random_seed_; // clang-format off
curandGenerator_t rand_generator_{nullptr}; cudnnHandle_t cudnn_handle_ = nullptr;
cublasHandle_t cublas_handle_ = nullptr;
curandGenerator_t curand_generator_ = nullptr;
// clang-format on
}; };
#endif #endif
......
...@@ -58,11 +58,6 @@ struct EnforceNotMet : public std::exception { ...@@ -58,11 +58,6 @@ struct EnforceNotMet : public std::exception {
// For more details, please check https://stackoverflow.com/a/43870188/724872. // For more details, please check https://stackoverflow.com/a/43870188/724872.
#define UNLIKELY(condition) __builtin_expect(static_cast<bool>(condition), 0) #define UNLIKELY(condition) __builtin_expect(static_cast<bool>(condition), 0)
template <typename T>
inline void throw_on_error(T e) {
throw_on_error(e, "");
}
template <typename... Args> template <typename... Args>
inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error( inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(
int stat, const Args&... args) { int stat, const Args&... args) {
...@@ -132,6 +127,11 @@ inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error( ...@@ -132,6 +127,11 @@ inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(
#endif // PADDLE_ONLY_CPU #endif // PADDLE_ONLY_CPU
template <typename T>
inline void throw_on_error(T e) {
throw_on_error(e, "");
}
#define PADDLE_THROW(...) \ #define PADDLE_THROW(...) \
do { \ do { \
throw ::paddle::platform::EnforceNotMet( \ throw ::paddle::platform::EnforceNotMet( \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册