提交 8ee50a35 作者: qijun

Fix GPU build error

上级 ca23d861
...@@ -5,5 +5,4 @@ nv_test(cuda_test SRCS cuda_test.cu) ...@@ -5,5 +5,4 @@ nv_test(cuda_test SRCS cuda_test.cu)
cc_library(place SRCS place.cc) cc_library(place SRCS place.cc)
cc_test(place_test SRCS place_test.cc DEPS place glog gflags) cc_test(place_test SRCS place_test.cc DEPS place glog gflags)
nv_test(cuda_device_test SRCS cuda_device_test.cc DEPS dynload_cuda dynamic_loader eigen3 place)
nv_test(device_context_test SRCS device_context_test.cc DEPS dynload_cuda dynamic_loader eigen3 place) nv_test(device_context_test SRCS device_context_test.cc DEPS dynload_cuda dynamic_loader eigen3 place)
...@@ -20,19 +20,13 @@ limitations under the License. */ ...@@ -20,19 +20,13 @@ limitations under the License. */
#include "paddle/platform/dynload/cudnn.h" #include "paddle/platform/dynload/cudnn.h"
#include "paddle/platform/dynload/curand.h" #include "paddle/platform/dynload/curand.h"
#define EIGEN_USE_GPU #define EIGEN_USE_GPU
#include "paddle/platform/device_context.h"
#include "paddle/platform/place.h" #include "paddle/platform/place.h"
#include "unsupported/Eigen/CXX11/Tensor" #include "unsupported/Eigen/CXX11/Tensor"
namespace paddle { namespace paddle {
namespace platform { namespace platform {
class CUDADeviceContext;
template <>
Eigen::GpuDevice DeviceContext::get_eigen_device<Eigen::GpuDevice>() {
return static_cast<CUDADeviceContext*>(this)->eigen_handle();
}
class GPUPlaceGuard { class GPUPlaceGuard {
public: public:
explicit GPUPlaceGuard(GPUPlace new_place) : previous_(GetCurrentDeviceId()) { explicit GPUPlaceGuard(GPUPlace new_place) : previous_(GetCurrentDeviceId()) {
...@@ -49,7 +43,7 @@ class GPUPlaceGuard { ...@@ -49,7 +43,7 @@ class GPUPlaceGuard {
class CUDADeviceContext : public DeviceContext { class CUDADeviceContext : public DeviceContext {
public: public:
explicit Device(const GPUPlace gpu_place) : gpu_place_(gpu_place) { explicit CUDADeviceContext(const GPUPlace gpu_place) : gpu_place_(gpu_place) {
GPUPlaceGuard guard(gpu_place_); GPUPlaceGuard guard(gpu_place_);
paddle::platform::throw_on_error(cudaStreamCreate(&stream_), paddle::platform::throw_on_error(cudaStreamCreate(&stream_),
"cudaStreamCreate failed"); "cudaStreamCreate failed");
...@@ -156,5 +150,10 @@ class CUDADeviceContext : public DeviceContext { ...@@ -156,5 +150,10 @@ class CUDADeviceContext : public DeviceContext {
int random_seed_; int random_seed_;
curandGenerator_t rand_generator_{nullptr}; curandGenerator_t rand_generator_{nullptr};
}; };
template <>
Eigen::GpuDevice DeviceContext::get_eigen_device<Eigen::GpuDevice>() {
return dynamic_cast<CUDADeviceContext*>(this)->eigen_device();
}
} // namespace platform } // namespace platform
} // namespace paddle } // namespace paddle
...@@ -20,30 +20,23 @@ limitations under the License. */ ...@@ -20,30 +20,23 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace platform { namespace platform {
class CPUDeviceContext;
class DeviceContext { class DeviceContext {
public: public:
virtual ~DeviceContext() {} virtual ~DeviceContext() {}
template <typename DeviceType> template <typename DeviceType>
DeviceType get_eigen_device(); inline DeviceType get_eigen_device();
virtual Place GetPlace() const = 0; virtual Place GetPlace() const = 0;
}; };
template <>
Eigen::DefaultDevice DeviceContext::get_eigen_device<Eigen::DefaultDevice>() {
return static_cast<CPUDeviceContext*>(this)->eigen_handle();
}
class CPUDeviceContext : public DeviceContext { class CPUDeviceContext : public DeviceContext {
public: public:
Eigen::DefaultDevice eigen_handle() { Eigen::DefaultDevice eigen_device() {
if (!eigen_handle_) { if (!eigen_device_) {
eigen_handle_ = new Eigen::DefaultDevice(); eigen_device_ = new Eigen::DefaultDevice();
} }
return *eigen_handle_; return *eigen_device_;
} }
Place GetPlace() const override { Place GetPlace() const override {
...@@ -52,7 +45,12 @@ class CPUDeviceContext : public DeviceContext { ...@@ -52,7 +45,12 @@ class CPUDeviceContext : public DeviceContext {
} }
private: private:
Eigen::DefaultDevice* eigen_handle_{nullptr}; Eigen::DefaultDevice* eigen_device_{nullptr};
}; };
template <>
Eigen::DefaultDevice DeviceContext::get_eigen_device<Eigen::DefaultDevice>() {
return dynamic_cast<CPUDeviceContext*>(this)->eigen_device();
}
} // namespace platform } // namespace platform
} // namespace paddle } // namespace paddle
...@@ -13,15 +13,16 @@ See the License for the specific language governing permissions and ...@@ -13,15 +13,16 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "paddle/platform/cuda_device.h" #include "paddle/platform/cuda_device_context.h"
using DEVICE_GPU = Eigen::GpuDevice;
TEST(Device, Init) { TEST(Device, Init) {
int count = paddle::platform::GetDeviceCount(); int count = paddle::platform::GetDeviceCount();
for (int i = 0; i < count; i++) { for (int i = 0; i < count; i++) {
paddle::platform::DeviceContext* device_context = paddle::platform::DeviceContext* device_context =
new paddle::platform::CUDADeviceContext(i); new paddle::platform::CUDADeviceContext(i);
Eigen::GpuDevice gpu_device = Eigen::GpuDevice gpu_device =
device_context->get_eigen_device<DEVICE_GPU>(); device_context->template get_eigen_device<DEVICE_GPU>();
ASSERT_NE(nullptr, gpu_device.stream()); ASSERT_NE(nullptr, gpu_device.stream());
delete device_context; delete device_context;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请注册