提交 be2c1a3b 编写于 作者: Q qijun

follow comments

上级 a07deac9
...@@ -15,13 +15,13 @@ namespace paddle { ...@@ -15,13 +15,13 @@ namespace paddle {
namespace platform { namespace platform {
template <> template <>
Eigen::DefaultDevice DeviceContext::get_eigen_device<Eigen::DefaultDevice>() { Eigen::DefaultDevice* DeviceContext::get_eigen_device<Eigen::DefaultDevice>() {
return reinterpret_cast<CPUDeviceContext*>(this)->eigen_device(); return reinterpret_cast<CPUDeviceContext*>(this)->eigen_device();
} }
#ifndef PADDLE_ONLY_CPU #ifndef PADDLE_ONLY_CPU
template <> template <>
Eigen::GpuDevice DeviceContext::get_eigen_device<Eigen::GpuDevice>() { Eigen::GpuDevice* DeviceContext::get_eigen_device<Eigen::GpuDevice>() {
return reinterpret_cast<CUDADeviceContext*>(this)->eigen_device(); return reinterpret_cast<CUDADeviceContext*>(this)->eigen_device();
} }
#endif #endif
......
...@@ -31,16 +31,16 @@ class DeviceContext { ...@@ -31,16 +31,16 @@ class DeviceContext {
virtual Place GetPlace() const = 0; virtual Place GetPlace() const = 0;
template <typename DeviceType> template <typename DeviceType>
DeviceType get_eigen_device(); DeviceType* get_eigen_device();
}; };
class CPUDeviceContext : public DeviceContext { class CPUDeviceContext : public DeviceContext {
public: public:
Eigen::DefaultDevice eigen_device() { Eigen::DefaultDevice* eigen_device() {
if (!eigen_device_) { if (!eigen_device_) {
eigen_device_ = new Eigen::DefaultDevice(); eigen_device_.reset(new Eigen::DefaultDevice());
} }
return *eigen_device_; return eigen_device_.get();
} }
Place GetPlace() const override { Place GetPlace() const override {
...@@ -49,7 +49,7 @@ class CPUDeviceContext : public DeviceContext { ...@@ -49,7 +49,7 @@ class CPUDeviceContext : public DeviceContext {
} }
private: private:
Eigen::DefaultDevice* eigen_device_{nullptr}; std::unique_ptr<Eigen::DefaultDevice> eigen_device_;
}; };
#ifndef PADDLE_ONLY_CPU #ifndef PADDLE_ONLY_CPU
...@@ -74,8 +74,8 @@ class CUDADeviceContext : public DeviceContext { ...@@ -74,8 +74,8 @@ class CUDADeviceContext : public DeviceContext {
GPUPlaceGuard guard(gpu_place_); GPUPlaceGuard guard(gpu_place_);
paddle::platform::throw_on_error(cudaStreamCreate(&stream_), paddle::platform::throw_on_error(cudaStreamCreate(&stream_),
"cudaStreamCreate failed"); "cudaStreamCreate failed");
eigen_stream_ = new Eigen::CudaStreamDevice(&stream_); eigen_stream_.reset(new Eigen::CudaStreamDevice(&stream_));
eigen_device_ = new Eigen::GpuDevice(eigen_stream_); eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get()));
} }
Place GetPlace() const override { Place GetPlace() const override {
...@@ -90,7 +90,7 @@ class CUDADeviceContext : public DeviceContext { ...@@ -90,7 +90,7 @@ class CUDADeviceContext : public DeviceContext {
cudaStream_t stream() { return stream_; } cudaStream_t stream() { return stream_; }
Eigen::GpuDevice eigen_device() { return *eigen_device_; } Eigen::GpuDevice* eigen_device() { return eigen_device_.get(); }
cublasHandle_t cublas_handle() { cublasHandle_t cublas_handle() {
if (!blas_handle_) { if (!blas_handle_) {
...@@ -155,10 +155,8 @@ class CUDADeviceContext : public DeviceContext { ...@@ -155,10 +155,8 @@ class CUDADeviceContext : public DeviceContext {
rand_generator_) == CURAND_STATUS_SUCCESS, rand_generator_) == CURAND_STATUS_SUCCESS,
"curandDestroyGenerator failed"); "curandDestroyGenerator failed");
} }
eigen_stream_.reset();
delete eigen_stream_; eigen_device_.reset();
delete eigen_device_;
paddle::platform::throw_on_error(cudaStreamDestroy(stream_), paddle::platform::throw_on_error(cudaStreamDestroy(stream_),
"cudaStreamDestroy failed"); "cudaStreamDestroy failed");
} }
...@@ -167,8 +165,8 @@ class CUDADeviceContext : public DeviceContext { ...@@ -167,8 +165,8 @@ class CUDADeviceContext : public DeviceContext {
GPUPlace gpu_place_; GPUPlace gpu_place_;
cudaStream_t stream_; cudaStream_t stream_;
Eigen::CudaStreamDevice* eigen_stream_; std::unique_ptr<Eigen::CudaStreamDevice> eigen_stream_;
Eigen::GpuDevice* eigen_device_; std::unique_ptr<Eigen::GpuDevice> eigen_device_;
cublasHandle_t blas_handle_{nullptr}; cublasHandle_t blas_handle_{nullptr};
......
...@@ -21,9 +21,9 @@ TEST(Device, Init) { ...@@ -21,9 +21,9 @@ TEST(Device, Init) {
for (int i = 0; i < count; i++) { for (int i = 0; i < count; i++) {
paddle::platform::DeviceContext* device_context = paddle::platform::DeviceContext* device_context =
new paddle::platform::CUDADeviceContext(i); new paddle::platform::CUDADeviceContext(i);
Eigen::GpuDevice gpu_device = Eigen::GpuDevice* gpu_device =
device_context->template get_eigen_device<DEVICE_GPU>(); device_context->template get_eigen_device<DEVICE_GPU>();
ASSERT_NE(nullptr, gpu_device.stream()); ASSERT_NE(nullptr, gpu_device);
delete device_context; delete device_context;
} }
} }
...@@ -33,8 +33,8 @@ TEST(Device, CUDADeviceContext) { ...@@ -33,8 +33,8 @@ TEST(Device, CUDADeviceContext) {
for (int i = 0; i < count; i++) { for (int i = 0; i < count; i++) {
paddle::platform::CUDADeviceContext* device_context = paddle::platform::CUDADeviceContext* device_context =
new paddle::platform::CUDADeviceContext(i); new paddle::platform::CUDADeviceContext(i);
Eigen::GpuDevice gpu_device = device_context->eigen_device(); Eigen::GpuDevice* gpu_device = device_context->eigen_device();
ASSERT_NE(nullptr, gpu_device.stream()); ASSERT_NE(nullptr, gpu_device);
cudnnHandle_t cudnn_handle = device_context->cudnn_handle(); cudnnHandle_t cudnn_handle = device_context->cudnn_handle();
ASSERT_NE(nullptr, cudnn_handle); ASSERT_NE(nullptr, cudnn_handle);
cublasHandle_t cublas_handle = device_context->cublas_handle(); cublasHandle_t cublas_handle = device_context->cublas_handle();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册