diff --git a/paddle/platform/CMakeLists.txt b/paddle/platform/CMakeLists.txt index d40e49b546efe43869da138693f2d78ad0c29831..e93592cc4cbacca9ee0202a24b3c813cc23c2202 100644 --- a/paddle/platform/CMakeLists.txt +++ b/paddle/platform/CMakeLists.txt @@ -5,5 +5,4 @@ nv_test(cuda_test SRCS cuda_test.cu) cc_library(place SRCS place.cc) cc_test(place_test SRCS place_test.cc DEPS place glog gflags) -nv_test(cuda_device_test SRCS cuda_device_test.cc DEPS dynload_cuda dynamic_loader eigen3 place) nv_test(device_context_test SRCS device_context_test.cc DEPS dynload_cuda dynamic_loader eigen3 place) diff --git a/paddle/platform/cuda_device_context.h b/paddle/platform/cuda_device_context.h index c38dcd5a6158df92da2f7e19599ab3247df604f3..8a9d15e8a8410c7f268ad32d798c408bfbe949bc 100644 --- a/paddle/platform/cuda_device_context.h +++ b/paddle/platform/cuda_device_context.h @@ -20,19 +20,13 @@ limitations under the License. */ #include "paddle/platform/dynload/cudnn.h" #include "paddle/platform/dynload/curand.h" #define EIGEN_USE_GPU +#include "paddle/platform/device_context.h" #include "paddle/platform/place.h" #include "unsupported/Eigen/CXX11/Tensor" namespace paddle { namespace platform { -class CUDADeviceContext; - -template <> -Eigen::GpuDevice DeviceContext::get_eigen_device() { - return static_cast(this)->eigen_handle(); -} - class GPUPlaceGuard { public: explicit GPUPlaceGuard(GPUPlace new_place) : previous_(GetCurrentDeviceId()) { @@ -49,7 +43,7 @@ class GPUPlaceGuard { class CUDADeviceContext : public DeviceContext { public: - explicit Device(const GPUPlace gpu_place) : gpu_place_(gpu_place) { + explicit CUDADeviceContext(const GPUPlace gpu_place) : gpu_place_(gpu_place) { GPUPlaceGuard guard(gpu_place_); paddle::platform::throw_on_error(cudaStreamCreate(&stream_), "cudaStreamCreate failed"); @@ -156,5 +150,10 @@ class CUDADeviceContext : public DeviceContext { int random_seed_; curandGenerator_t rand_generator_{nullptr}; }; + +template <> +Eigen::GpuDevice DeviceContext::get_eigen_device() { + return dynamic_cast(this)->eigen_device(); +} } // namespace platform } // namespace paddle diff --git a/paddle/platform/device_context.h b/paddle/platform/device_context.h index d2a516999170efd5f5679960670db5b534de9d2f..d2f7cf62166c287e987d28c6ad1a191dd0ae9a84 100644 --- a/paddle/platform/device_context.h +++ b/paddle/platform/device_context.h @@ -20,30 +20,23 @@ limitations under the License. */ namespace paddle { namespace platform { -class CPUDeviceContext; - class DeviceContext { public: virtual ~DeviceContext() {} template - DeviceType get_eigen_device(); + inline DeviceType get_eigen_device(); virtual Place GetPlace() const = 0; }; -template <> -Eigen::DefaultDevice DeviceContext::get_eigen_device() { - return static_cast(this)->eigen_handle(); -} - class CPUDeviceContext : public DeviceContext { public: - Eigen::DefaultDevice eigen_handle() { - if (!eigen_handle_) { - eigen_handle_ = new Eigen::DefaultDevice(); + Eigen::DefaultDevice eigen_device() { + if (!eigen_device_) { + eigen_device_ = new Eigen::DefaultDevice(); } - return *eigen_handle_; + return *eigen_device_; } Place GetPlace() const override { @@ -52,7 +45,12 @@ class CPUDeviceContext : public DeviceContext { } private: - Eigen::DefaultDevice* eigen_handle_{nullptr}; + Eigen::DefaultDevice* eigen_device_{nullptr}; }; + +template <> +Eigen::DefaultDevice DeviceContext::get_eigen_device() { + return dynamic_cast(this)->eigen_device(); +} } // namespace platform } // namespace paddle diff --git a/paddle/platform/device_context_test.cc b/paddle/platform/device_context_test.cc index 8390e97b15f174bf7e101d73ce97e575f08a1ac3..abaaaececf5b0d623be90d7c9e67b5b6da05f803 100644 --- a/paddle/platform/device_context_test.cc +++ b/paddle/platform/device_context_test.cc @@ -13,15 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "gtest/gtest.h" -#include "paddle/platform/cuda_device.h" +#include "paddle/platform/cuda_device_context.h" +using DEVICE_GPU = Eigen::GpuDevice; TEST(Device, Init) { int count = paddle::platform::GetDeviceCount(); for (int i = 0; i < count; i++) { paddle::platform::DeviceContext* device_context = new paddle::platform::CUDADeviceContext(i); Eigen::GpuDevice gpu_device = - device_context->get_eigen_device(); + device_context->template get_eigen_device(); ASSERT_NE(nullptr, gpu_device.stream()); delete device_context; }