From 2f47f35b3efec36189a4c6757490b897130d3028 Mon Sep 17 00:00:00 2001 From: qijun Date: Mon, 21 Aug 2017 09:12:25 +0000 Subject: [PATCH] fix gpu build error --- paddle/operators/math/CMakeLists.txt | 4 ++-- paddle/operators/math/math_function.cc | 10 +++++----- paddle/operators/math/math_function.cu | 15 ++++++++------- paddle/operators/math/math_function.h | 7 ++----- paddle/operators/uniform_random_op.cu | 9 +++------ paddle/platform/device_context.cc | 10 +++++----- paddle/platform/device_context.h | 6 +++--- 7 files changed, 28 insertions(+), 33 deletions(-) diff --git a/paddle/operators/math/CMakeLists.txt b/paddle/operators/math/CMakeLists.txt index ed51d416ed9..228f463f2bd 100644 --- a/paddle/operators/math/CMakeLists.txt +++ b/paddle/operators/math/CMakeLists.txt @@ -1,8 +1,8 @@ if(WITH_GPU) - nv_library(math_function SRCS math_function.cc math_function.cu DEPS cblas device_context) + nv_library(math_function SRCS math_function.cc math_function.cu DEPS cblas device_context eigen3) else() - cc_library(math_function SRCS math_function.cc DEPS cblas device_context) + cc_library(math_function SRCS math_function.cc DEPS cblas device_context eigen3) endif() nv_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor) diff --git a/paddle/operators/math/math_function.cc b/paddle/operators/math/math_function.cc index d0b1f8ee48f..a098e02f95d 100644 --- a/paddle/operators/math/math_function.cc +++ b/paddle/operators/math/math_function.cc @@ -110,12 +110,12 @@ void matmul(const framework::Tensor& matrix_a, } template <> -void Set(const int n, const float alpha, - float* output, - platform::DeviceContext* context) { +void Set(const int n, const float alpha, + float* output, + platform::DeviceContext* context) { auto* cpu_context = reinterpret_cast(context); - framework::EigenVector::Type out(output, n); - out.device(*(cpu_context->eigen_device())) = t.constant(T(alpha)); + framework::EigenVector::Type out(output, n); + out.device(*(cpu_context->eigen_device())) = out.constant(float(alpha)); } template <> diff --git a/paddle/operators/math/math_function.cu b/paddle/operators/math/math_function.cu index 76bbf790db3..3ff622f3082 100644 --- a/paddle/operators/math/math_function.cu +++ b/paddle/operators/math/math_function.cu @@ -127,12 +127,12 @@ void matmul(const framework::Tensor& matrix_a, } template <> -void Set(const int n, const float alpha, - float* output, - platform::DeviceContext* context) { +void Set(const int n, const float alpha, + float* output, + platform::DeviceContext* context) { auto* cuda_context = reinterpret_cast(context); - framework::EigenVector::Type out(output, n); - out.device(*(cuda_context->eigen_device())) = t.constant(T(alpha)); + framework::EigenVector::Type out(output, n); + out.device(*(cuda_context->eigen_device())) = out.constant(float(alpha)); } template @@ -159,12 +159,13 @@ void RandUniform(const int n, const float min, template int HandleOddLengthRandGaussian(const int n, const T mean, const T std, - T* output, CUDADeviceContext* context) { + T* output, + platform::CUDADeviceContext* context) { if (n % 2 == 1) { std::default_random_engine generator; std::normal_distribution distribution(mean, std); const T random_value = distribution(generator); - Set(1, random_value, output + (n - 1), context); + Set(1, random_value, output + (n - 1), context); return n - 1; } return n; diff --git a/paddle/operators/math/math_function.h b/paddle/operators/math/math_function.h index afe6de7483d..6543a1b515e 100644 --- a/paddle/operators/math/math_function.h +++ b/paddle/operators/math/math_function.h @@ -52,9 +52,9 @@ int LAPACKE_dgetri(int matrix_layout, int n, double* a, int lda, #include +#include "paddle/framework/eigen.h" #include "paddle/framework/tensor.h" #include "paddle/platform/device_context.h" -#include "paddle/platform/eigen.h" #include "paddle/platform/enforce.h" namespace paddle { @@ -80,10 +80,7 @@ void matmul(const framework::Tensor& matrix_a, bool trans_a, template void Set(const int n, const T alpha, T* output, - platform::DeviceContext* context) { - framework::EigenVector::Type out(output, n); - out.device(*(context->eigen_device())) = t.constant(T(alpha)); -} + platform::DeviceContext* context); template void RandUniform(const int n, const T min, const T max, T* output, diff --git a/paddle/operators/uniform_random_op.cu b/paddle/operators/uniform_random_op.cu index 91368fa73e9..1bfffc47783 100644 --- a/paddle/operators/uniform_random_op.cu +++ b/paddle/operators/uniform_random_op.cu @@ -14,9 +14,6 @@ #include "paddle/operators/uniform_random_op.h" -namespace paddle { -namespace operators { - -REGISTER_OP_GPU_KERNEL(uniform_random, - paddle::operators::GPUUniformRandomKernel< - paddle::platform::GPUPlace, float>); +REGISTER_OP_GPU_KERNEL( + uniform_random, + paddle::operators::UniformRandomKernel); diff --git a/paddle/platform/device_context.cc b/paddle/platform/device_context.cc index 5fd93555a51..ad9b4e42f33 100644 --- a/paddle/platform/device_context.cc +++ b/paddle/platform/device_context.cc @@ -25,9 +25,9 @@ CPUDeviceContext::CPUDeviceContext() { eigen_device_.reset(new Eigen::DefaultDevice()); } -CPUDeviceContext::CPUDeviceContext(CPUPlace place, int rand_seed) { +CPUDeviceContext::CPUDeviceContext(CPUPlace place, int seed) { eigen_device_.reset(new Eigen::DefaultDevice()); - rand_seed_ = rand_seed; + rand_seed_ = seed; } std::minstd_rand& CPUDeviceContext::rand_engine() { @@ -105,7 +105,7 @@ Eigen::GpuDevice* DeviceContext::get_eigen_device() const { } CUDADeviceContext::CUDADeviceContext(GPUPlace place, uint64_t seed) - : place_(place), seed_(seed) { + : place_(place), rand_seed_(seed) { SetDeviceId(place_.device); PADDLE_ENFORCE(cudaStreamCreate(&stream_)); eigen_stream_.reset(new EigenCudaStreamDevice()); @@ -162,8 +162,8 @@ curandGenerator_t CUDADeviceContext::curand_generator() { SetDeviceId(place_.device); PADDLE_ENFORCE(dynload::curandCreateGenerator(&curand_generator_, CURAND_RNG_PSEUDO_DEFAULT)); - PADDLE_ENFORCE( - dynload::curandSetPseudoRandomGeneratorSeed(curand_generator_, seed_)); + PADDLE_ENFORCE(dynload::curandSetPseudoRandomGeneratorSeed( + curand_generator_, rand_seed_)); PADDLE_ENFORCE(dynload::curandSetStream(curand_generator_, stream_)); } diff --git a/paddle/platform/device_context.h b/paddle/platform/device_context.h index 7013343a8de..e18f48fef59 100644 --- a/paddle/platform/device_context.h +++ b/paddle/platform/device_context.h @@ -40,7 +40,7 @@ class DeviceContext { class CPUDeviceContext : public DeviceContext { public: CPUDeviceContext(); - explicit CPUDeviceContext(CPUPlace place, int rand_seed = 0); + explicit CPUDeviceContext(CPUPlace place, int seed = 0); virtual ~CPUDeviceContext() {} Eigen::DefaultDevice* eigen_device() const; @@ -60,7 +60,7 @@ class EigenCudaStreamDevice; class CUDADeviceContext : public DeviceContext { public: - explicit CUDADeviceContext(GPUPlace place, uint64_t rand_seed = 0); + explicit CUDADeviceContext(GPUPlace place, uint64_t seed = 0); virtual ~CUDADeviceContext(); /*! \brief Wait for all operations completion in the stream. */ @@ -93,12 +93,12 @@ class CUDADeviceContext : public DeviceContext { std::unique_ptr eigen_stream_; uint64_t rand_seed_; - std::unique_ptr rand_engine_; // clang-format off cudaStream_t stream_{nullptr}; cudnnHandle_t cudnn_handle_{nullptr}; cublasHandle_t cublas_handle_{nullptr}; + curandGenerator_t curand_generator_{nullptr}; // clang-format on }; -- GitLab