diff --git a/paddle/platform/device_context.cc b/paddle/platform/device_context.cc
index 25ff352e8cf2c890f8aaa15a3c84965524ca1555..960ef0a5955bfe5f7d33b7c8e4524176b0dbfda6 100644
--- a/paddle/platform/device_context.cc
+++ b/paddle/platform/device_context.cc
@@ -15,13 +15,13 @@ namespace paddle {
 namespace platform {
 
 template <>
-Eigen::DefaultDevice DeviceContext::get_eigen_device<Eigen::DefaultDevice>() {
+Eigen::DefaultDevice* DeviceContext::get_eigen_device<Eigen::DefaultDevice>() {
   return reinterpret_cast<CPUDeviceContext*>(this)->eigen_device();
 }
 
 #ifndef PADDLE_ONLY_CPU
 template <>
-Eigen::GpuDevice DeviceContext::get_eigen_device<Eigen::GpuDevice>() {
+Eigen::GpuDevice* DeviceContext::get_eigen_device<Eigen::GpuDevice>() {
   return reinterpret_cast<CUDADeviceContext*>(this)->eigen_device();
 }
 #endif
diff --git a/paddle/platform/device_context.h b/paddle/platform/device_context.h
index d6cf114216d60be9d1e2353595bd2625e0f748ac..94f54d705d1aa18097383f579da1a9cdd62adc48 100644
--- a/paddle/platform/device_context.h
+++ b/paddle/platform/device_context.h
@@ -31,16 +31,16 @@ class DeviceContext {
   virtual Place GetPlace() const = 0;
 
   template <typename DeviceType>
-  DeviceType get_eigen_device();
+  DeviceType* get_eigen_device();
 };
 
 class CPUDeviceContext : public DeviceContext {
  public:
-  Eigen::DefaultDevice eigen_device() {
+  Eigen::DefaultDevice* eigen_device() {
     if (!eigen_device_) {
-      eigen_device_ = new Eigen::DefaultDevice();
+      eigen_device_.reset(new Eigen::DefaultDevice());
     }
-    return *eigen_device_;
+    return eigen_device_.get();
   }
 
   Place GetPlace() const override {
@@ -49,7 +49,7 @@ class CPUDeviceContext : public DeviceContext {
   }
 
  private:
-  Eigen::DefaultDevice* eigen_device_{nullptr};
+  std::unique_ptr<Eigen::DefaultDevice> eigen_device_;
 };
 
 #ifndef PADDLE_ONLY_CPU
@@ -74,8 +74,8 @@ class CUDADeviceContext : public DeviceContext {
     GPUPlaceGuard guard(gpu_place_);
     paddle::platform::throw_on_error(cudaStreamCreate(&stream_),
                                      "cudaStreamCreate failed");
-    eigen_stream_ = new Eigen::CudaStreamDevice(&stream_);
-    eigen_device_ = new Eigen::GpuDevice(eigen_stream_);
+    eigen_stream_.reset(new Eigen::CudaStreamDevice(&stream_));
+    eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get()));
   }
 
   Place GetPlace() const override {
@@ -90,7 +90,7 @@ class CUDADeviceContext : public DeviceContext {
 
   cudaStream_t stream() { return stream_; }
 
-  Eigen::GpuDevice eigen_device() { return *eigen_device_; }
+  Eigen::GpuDevice* eigen_device() { return eigen_device_.get(); }
 
   cublasHandle_t cublas_handle() {
     if (!blas_handle_) {
@@ -155,10 +155,8 @@ class CUDADeviceContext : public DeviceContext {
                          rand_generator_) == CURAND_STATUS_SUCCESS,
                      "curandDestroyGenerator failed");
     }
-
-    delete eigen_stream_;
-    delete eigen_device_;
-
+    eigen_stream_.reset();
+    eigen_device_.reset();
     paddle::platform::throw_on_error(cudaStreamDestroy(stream_),
                                      "cudaStreamDestroy failed");
   }
@@ -167,8 +165,8 @@ class CUDADeviceContext : public DeviceContext {
   GPUPlace gpu_place_;
   cudaStream_t stream_;
 
-  Eigen::CudaStreamDevice* eigen_stream_;
-  Eigen::GpuDevice* eigen_device_;
+  std::unique_ptr<Eigen::CudaStreamDevice> eigen_stream_;
+  std::unique_ptr<Eigen::GpuDevice> eigen_device_;
 
   cublasHandle_t blas_handle_{nullptr};
 
diff --git a/paddle/platform/device_context_test.cc b/paddle/platform/device_context_test.cc
index 913e3c0aa96e150cd4f6192d2e6f6fc2227b4311..af2ce17fc2238dda62e9888ebe9426edcd55d2bc 100644
--- a/paddle/platform/device_context_test.cc
+++ b/paddle/platform/device_context_test.cc
@@ -21,9 +21,9 @@ TEST(Device, Init) {
   for (int i = 0; i < count; i++) {
     paddle::platform::DeviceContext* device_context =
         new paddle::platform::CUDADeviceContext(i);
-    Eigen::GpuDevice gpu_device =
+    Eigen::GpuDevice* gpu_device =
         device_context->template get_eigen_device<Eigen::GpuDevice>();
-    ASSERT_NE(nullptr, gpu_device.stream());
+    ASSERT_NE(nullptr, gpu_device);
     delete device_context;
   }
 }
@@ -33,8 +33,8 @@ TEST(Device, CUDADeviceContext) {
   for (int i = 0; i < count; i++) {
     paddle::platform::CUDADeviceContext* device_context =
         new paddle::platform::CUDADeviceContext(i);
-    Eigen::GpuDevice gpu_device = device_context->eigen_device();
-    ASSERT_NE(nullptr, gpu_device.stream());
+    Eigen::GpuDevice* gpu_device = device_context->eigen_device();
+    ASSERT_NE(nullptr, gpu_device);
     cudnnHandle_t cudnn_handle = device_context->cudnn_handle();
     ASSERT_NE(nullptr, cudnn_handle);
     cublasHandle_t cublas_handle = device_context->cublas_handle();