From 7c274dc0a16b77fae0faf527ef02a1f72abad593 Mon Sep 17 00:00:00 2001 From: qijun Date: Mon, 21 Aug 2017 16:41:22 +0800 Subject: [PATCH] use curand --- paddle/operators/math/math_function.cc | 9 +++++ paddle/operators/math/math_function.cu | 56 ++++++++++++++++++-------- paddle/operators/math/math_function.h | 8 ++++ paddle/platform/device_context.cc | 15 ++++--- paddle/platform/device_context.h | 6 +-- 5 files changed, 70 insertions(+), 24 deletions(-) diff --git a/paddle/operators/math/math_function.cc b/paddle/operators/math/math_function.cc index da59044899..d0b1f8ee48 100644 --- a/paddle/operators/math/math_function.cc +++ b/paddle/operators/math/math_function.cc @@ -109,6 +109,15 @@ void matmul(const framework::Tensor& matrix_a, matrix_b.data(), beta, matrix_out->data(), context); } +template <> +void Set(const int n, const float alpha, + float* output, + platform::DeviceContext* context) { + auto* cpu_context = reinterpret_cast(context); + framework::EigenVector::Type out(output, n); + out.device(*(cpu_context->eigen_device())) = t.constant(T(alpha)); +} + template <> void RandUniform(const int n, const float min, const float max, float* output, diff --git a/paddle/operators/math/math_function.cu b/paddle/operators/math/math_function.cu index 5a400d4445..76bbf790db 100644 --- a/paddle/operators/math/math_function.cu +++ b/paddle/operators/math/math_function.cu @@ -126,20 +126,48 @@ void matmul(const framework::Tensor& matrix_a, matrix_b.data(), beta, matrix_out->data(), context); } +template <> +void Set(const int n, const float alpha, + float* output, + platform::DeviceContext* context) { + auto* cuda_context = reinterpret_cast(context); + framework::EigenVector::Type out(output, n); + out.device(*(cuda_context->eigen_device())) = t.constant(T(alpha)); +} + +template +__global__ void UniformShift(const int n, const T min, const T max, T* x) { + float scale = max - min; + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; + i += blockDim.x * gridDim.x) { + x[i] = x[i] * scale + min; + } +} + template <> void RandUniform(const int n, const float min, const float max, float* output, platform::DeviceContext* context) { auto* cuda_context = reinterpret_cast(context); - thrust::uniform_real_distribution distribution(min, max); - thrust::minstd_rand engine = cuda_context->rand_enigne(); - engine->discard(n); - - thrust::counting_iterator index_sequence_begin(0); + PADDLE_ENFORCE( + curandGenerateUniform(cuda_context->curand_generator(), output, n)); + int block = 512; + int grid = (n + block - 1) / block; + UniformShift<<stream()>>>(n, min, max, + output); +} - thrust::transform(thrust::cuda::par.on(cuda_context->stream()), - index_sequence_begin, index_sequence_begin + n, - thrust::device_ptr(output), distribution(engine)); +template +int HandleOddLengthRandGaussian(const int n, const T mean, const T std, + T* output, CUDADeviceContext* context) { + if (n % 2 == 1) { + std::default_random_engine generator; + std::normal_distribution distribution(mean, std); + const T random_value = distribution(generator); + Set(1, random_value, output + (n - 1), context); + return n - 1; + } + return n; } template <> @@ -147,15 +175,11 @@ void RandGaussian(const int n, const float mean, const float std, float* output, platform::DeviceContext* context) { auto* cuda_context = reinterpret_cast(context); - thrust::normal_distribution distribution(mean, std); - thrust::minstd_rand engine = cuda_context->rand_enigne(); - engine->discard(n); - - thrust::counting_iterator index_sequence_begin(0); - thrust::transform(thrust::cuda::par.on(cuda_context->stream()), - index_sequence_begin, index_sequence_begin + n, - thrust::device_ptr(output), distribution(engine)); + const int even_n = + HandleOddLengthRandGaussian(n, mean, std, output, cuda_context); + PADDLE_ENFORCE(curandGenerateNormal(cuda_context->curand_generator(), output, + even_n, mean, std)); } } // namespace math diff --git a/paddle/operators/math/math_function.h b/paddle/operators/math/math_function.h index ea15e8fd2b..afe6de7483 100644 --- a/paddle/operators/math/math_function.h +++ b/paddle/operators/math/math_function.h @@ -54,6 +54,7 @@ int LAPACKE_dgetri(int matrix_layout, int n, double* a, int lda, #include "paddle/framework/tensor.h" #include "paddle/platform/device_context.h" +#include "paddle/platform/eigen.h" #include "paddle/platform/enforce.h" namespace paddle { @@ -77,6 +78,13 @@ void matmul(const framework::Tensor& matrix_a, bool trans_a, framework::Tensor* matrix_out, T beta, platform::DeviceContext* context); +template +void Set(const int n, const T alpha, T* output, + platform::DeviceContext* context) { + framework::EigenVector::Type out(output, n); + out.device(*(context->eigen_device())) = t.constant(T(alpha)); +} + template void RandUniform(const int n, const T min, const T max, T* output, platform::DeviceContext* context); diff --git a/paddle/platform/device_context.cc b/paddle/platform/device_context.cc index fabbb55443..5fd93555a5 100644 --- a/paddle/platform/device_context.cc +++ b/paddle/platform/device_context.cc @@ -157,12 +157,17 @@ cudnnHandle_t CUDADeviceContext::cudnn_handle() { return cudnn_handle_; } -thrust::minstd_rand& CPUDeviceContext::rand_engine() { - if (!rand_engine_) { - rand_engine_.reset(new thrust::minstd_rand()); - rand_engine_->seed(rand_seed_); +curandGenerator_t CUDADeviceContext::curand_generator() { + if (!curand_generator_) { + SetDeviceId(place_.device); + PADDLE_ENFORCE(dynload::curandCreateGenerator(&curand_generator_, + CURAND_RNG_PSEUDO_DEFAULT)); + PADDLE_ENFORCE( + dynload::curandSetPseudoRandomGeneratorSeed(curand_generator_, seed_)); + + PADDLE_ENFORCE(dynload::curandSetStream(curand_generator_, stream_)); } - return *(rand_engine_.get()); + return curand_generator_; } cudaStream_t CUDADeviceContext::stream() { return stream_; } diff --git a/paddle/platform/device_context.h b/paddle/platform/device_context.h index e4de3807cd..7013343a8d 100644 --- a/paddle/platform/device_context.h +++ b/paddle/platform/device_context.h @@ -15,10 +15,9 @@ limitations under the License. */ #include "paddle/platform/place.h" #ifndef PADDLE_ONLY_CPU -#include -#include #include "paddle/platform/dynload/cublas.h" #include "paddle/platform/dynload/cudnn.h" +#include "paddle/platform/dynload/curand.h" #include "paddle/platform/gpu_info.h" #define EIGEN_USE_GPU #endif @@ -80,7 +79,8 @@ class CUDADeviceContext : public DeviceContext { /*! \brief Return cudnn handle in the device context. */ cudnnHandle_t cudnn_handle(); - thrust::minstd_rand& CPUDeviceContext::rand_engine(); + /*! \brief Return curand handle in the device context. */ + curandGenerator_t curand_generator(); /*! \brief Return cuda stream in the device context. */ cudaStream_t stream(); -- GitLab