提交 be2c1a3b 编写于 作者: Q qijun

follow comments

上级 a07deac9
......@@ -15,13 +15,13 @@ namespace paddle {
namespace platform {
template <>
Eigen::DefaultDevice DeviceContext::get_eigen_device<Eigen::DefaultDevice>() {
Eigen::DefaultDevice* DeviceContext::get_eigen_device<Eigen::DefaultDevice>() {
return reinterpret_cast<CPUDeviceContext*>(this)->eigen_device();
}
#ifndef PADDLE_ONLY_CPU
template <>
Eigen::GpuDevice DeviceContext::get_eigen_device<Eigen::GpuDevice>() {
Eigen::GpuDevice* DeviceContext::get_eigen_device<Eigen::GpuDevice>() {
return reinterpret_cast<CUDADeviceContext*>(this)->eigen_device();
}
#endif
......
......@@ -31,16 +31,16 @@ class DeviceContext {
virtual Place GetPlace() const = 0;
template <typename DeviceType>
DeviceType get_eigen_device();
DeviceType* get_eigen_device();
};
class CPUDeviceContext : public DeviceContext {
public:
Eigen::DefaultDevice eigen_device() {
Eigen::DefaultDevice* eigen_device() {
if (!eigen_device_) {
eigen_device_ = new Eigen::DefaultDevice();
eigen_device_.reset(new Eigen::DefaultDevice());
}
return *eigen_device_;
return eigen_device_.get();
}
Place GetPlace() const override {
......@@ -49,7 +49,7 @@ class CPUDeviceContext : public DeviceContext {
}
private:
Eigen::DefaultDevice* eigen_device_{nullptr};
std::unique_ptr<Eigen::DefaultDevice> eigen_device_;
};
#ifndef PADDLE_ONLY_CPU
......@@ -74,8 +74,8 @@ class CUDADeviceContext : public DeviceContext {
GPUPlaceGuard guard(gpu_place_);
paddle::platform::throw_on_error(cudaStreamCreate(&stream_),
"cudaStreamCreate failed");
eigen_stream_ = new Eigen::CudaStreamDevice(&stream_);
eigen_device_ = new Eigen::GpuDevice(eigen_stream_);
eigen_stream_.reset(new Eigen::CudaStreamDevice(&stream_));
eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get()));
}
Place GetPlace() const override {
......@@ -90,7 +90,7 @@ class CUDADeviceContext : public DeviceContext {
cudaStream_t stream() { return stream_; }
Eigen::GpuDevice eigen_device() { return *eigen_device_; }
Eigen::GpuDevice* eigen_device() { return eigen_device_.get(); }
cublasHandle_t cublas_handle() {
if (!blas_handle_) {
......@@ -155,10 +155,8 @@ class CUDADeviceContext : public DeviceContext {
rand_generator_) == CURAND_STATUS_SUCCESS,
"curandDestroyGenerator failed");
}
delete eigen_stream_;
delete eigen_device_;
eigen_stream_.reset();
eigen_device_.reset();
paddle::platform::throw_on_error(cudaStreamDestroy(stream_),
"cudaStreamDestroy failed");
}
......@@ -167,8 +165,8 @@ class CUDADeviceContext : public DeviceContext {
GPUPlace gpu_place_;
cudaStream_t stream_;
Eigen::CudaStreamDevice* eigen_stream_;
Eigen::GpuDevice* eigen_device_;
std::unique_ptr<Eigen::CudaStreamDevice> eigen_stream_;
std::unique_ptr<Eigen::GpuDevice> eigen_device_;
cublasHandle_t blas_handle_{nullptr};
......
......@@ -21,9 +21,9 @@ TEST(Device, Init) {
for (int i = 0; i < count; i++) {
paddle::platform::DeviceContext* device_context =
new paddle::platform::CUDADeviceContext(i);
Eigen::GpuDevice gpu_device =
Eigen::GpuDevice* gpu_device =
device_context->template get_eigen_device<DEVICE_GPU>();
ASSERT_NE(nullptr, gpu_device.stream());
ASSERT_NE(nullptr, gpu_device);
delete device_context;
}
}
......@@ -33,8 +33,8 @@ TEST(Device, CUDADeviceContext) {
for (int i = 0; i < count; i++) {
paddle::platform::CUDADeviceContext* device_context =
new paddle::platform::CUDADeviceContext(i);
Eigen::GpuDevice gpu_device = device_context->eigen_device();
ASSERT_NE(nullptr, gpu_device.stream());
Eigen::GpuDevice* gpu_device = device_context->eigen_device();
ASSERT_NE(nullptr, gpu_device);
cudnnHandle_t cudnn_handle = device_context->cudnn_handle();
ASSERT_NE(nullptr, cudnn_handle);
cublasHandle_t cublas_handle = device_context->cublas_handle();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册