Commit 7c274dc0 authored by qijun

use curand

Parent d525abed
@@ -109,6 +109,15 @@ void matmul<platform::CPUPlace, double>(const framework::Tensor& matrix_a,
       matrix_b.data<double>(), beta, matrix_out->data<double>(), context);
 }
 
+template <>
+void Set<platform::CPUPlace, float>(const int n, const float alpha,
+                                    float* output,
+                                    platform::DeviceContext* context) {
+  auto* cpu_context = reinterpret_cast<platform::CPUDeviceContext*>(context);
+  framework::EigenVector<float>::Type out(output, n);
+  out.device(*(cpu_context->eigen_device())) = out.constant(alpha);
+}
+
 template <>
 void RandUniform<platform::CPUPlace, float>(const int n, const float min,
                                             const float max, float* output,
...
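Note: the Set specialization above writes through an Eigen TensorMap, so the constant fill is dispatched to the context's Eigen device rather than a hand-written loop. A minimal sketch of the same pattern in plain Eigen, detached from Paddle's wrappers:

#include <unsupported/Eigen/CXX11/Tensor>

int main() {
  float buf[4];
  // Map the raw buffer as a rank-1 tensor and fill it with a constant.
  Eigen::TensorMap<Eigen::Tensor<float, 1>> out(buf, 4);
  Eigen::DefaultDevice dev;
  out.device(dev) = out.constant(2.5f);  // buf is now {2.5, 2.5, 2.5, 2.5}
  return 0;
}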
@@ -126,20 +126,48 @@ void matmul<platform::GPUPlace, double>(const framework::Tensor& matrix_a,
       matrix_b.data<double>(), beta, matrix_out->data<double>(), context);
 }
 
+template <>
+void Set<platform::GPUPlace, float>(const int n, const float alpha,
+                                    float* output,
+                                    platform::DeviceContext* context) {
+  auto* cuda_context = reinterpret_cast<platform::CUDADeviceContext*>(context);
+  framework::EigenVector<float>::Type out(output, n);
+  out.device(*(cuda_context->eigen_device())) = out.constant(alpha);
+}
+
+template <typename T>
+__global__ void UniformShift(const int n, const T min, const T max, T* x) {
+  T scale = max - min;
+  // Grid-stride loop: rescale samples from (0, 1] into (min, max].
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
+       i += blockDim.x * gridDim.x) {
+    x[i] = x[i] * scale + min;
+  }
+}
 
 template <>
 void RandUniform<platform::GPUPlace, float>(const int n, const float min,
                                             const float max, float* output,
                                             platform::DeviceContext* context) {
   auto* cuda_context = reinterpret_cast<platform::CUDADeviceContext*>(context);
-  thrust::uniform_real_distribution<float> distribution(min, max);
-  thrust::minstd_rand engine = cuda_context->rand_enigne();
-  engine->discard(n);
-
-  thrust::counting_iterator<unsigned int> index_sequence_begin(0);
-
-  thrust::transform(thrust::cuda::par.on(cuda_context->stream()),
-                    index_sequence_begin, index_sequence_begin + n,
-                    thrust::device_ptr<float>(output), distribution(engine));
+  PADDLE_ENFORCE(
+      curandGenerateUniform(cuda_context->curand_generator(), output, n));
+  int block = 512;
+  int grid = (n + block - 1) / block;
+  UniformShift<float><<<grid, block, 0, cuda_context->stream()>>>(n, min, max,
+                                                                  output);
 }
+
+template <typename T>
+int HandleOddLengthRandGaussian(const int n, const T mean, const T std,
+                                T* output,
+                                platform::CUDADeviceContext* context) {
+  if (n % 2 == 1) {
+    // Draw the odd trailing element on the host so the remaining
+    // even-length prefix can be handed to cuRAND.
+    std::default_random_engine generator;
+    std::normal_distribution<T> distribution(mean, std);
+    const T random_value = distribution(generator);
+    Set<platform::GPUPlace, T>(1, random_value, output + (n - 1), context);
+    return n - 1;
+  }
+  return n;
+}
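Note: cuRAND's pseudorandom normal generation produces samples in pairs (Box-Muller), so curandGenerateNormal requires an even count; the helper above draws the odd trailing element on the host and returns the even prefix length. Also note that the host-side std::default_random_engine is default-seeded, so that trailing element is identical across runs. A standalone sketch of the even-count constraint (buffer size and seed are illustrative):

#include <curand.h>
#include <cuda_runtime.h>

int main() {
  const int n = 7;                 // odd on purpose
  const int even_n = n - (n % 2);  // curandGenerateNormal needs an even count
  float* d = nullptr;
  cudaMalloc(&d, n * sizeof(float));
  curandGenerator_t gen;
  curandCreateGenerator(&gen, CURAND_RNG_PSEUDO_DEFAULT);
  curandSetPseudoRandomGeneratorSeed(gen, 0);
  curandGenerateNormal(gen, d, even_n, 0.0f, 1.0f);  // fills d[0..even_n)
  // d[n - 1] still has to be produced separately, as the helper above does.
  curandDestroyGenerator(gen);
  cudaFree(d);
  return 0;
}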
 
@@ -147,15 +175,11 @@
 template <>
 void RandGaussian<platform::GPUPlace, float>(const int n, const float mean,
                                              const float std, float* output,
                                              platform::DeviceContext* context) {
   auto* cuda_context = reinterpret_cast<platform::CUDADeviceContext*>(context);
-  thrust::normal_distribution<float> distribution(mean, std);
-  thrust::minstd_rand engine = cuda_context->rand_enigne();
-  engine->discard(n);
-  thrust::counting_iterator<unsigned int> index_sequence_begin(0);
-  thrust::transform(thrust::cuda::par.on(cuda_context->stream()),
-                    index_sequence_begin, index_sequence_begin + n,
-                    thrust::device_ptr<float>(output), distribution(engine));
+
+  const int even_n =
+      HandleOddLengthRandGaussian<float>(n, mean, std, output, cuda_context);
+  PADDLE_ENFORCE(curandGenerateNormal(cuda_context->curand_generator(), output,
+                                      even_n, mean, std));
 }
 
 }  // namespace math
...
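Note: for reference, the new uniform path in a self-contained form, outside Paddle's wrappers (all names below are illustrative): curandGenerateUniform fills the buffer with values in (0, 1], and a grid-stride kernel rescales them in place, which stays correct even when the grid is capped below (n + block - 1) / block.

#include <curand.h>
#include <cuda_runtime.h>

__global__ void Shift(int n, float lo, float hi, float* x) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += blockDim.x * gridDim.x) {
    x[i] = x[i] * (hi - lo) + lo;  // map (0, 1] to (lo, hi]
  }
}

int main() {
  const int n = 1024;
  float* d = nullptr;
  cudaMalloc(&d, n * sizeof(float));
  curandGenerator_t gen;
  curandCreateGenerator(&gen, CURAND_RNG_PSEUDO_DEFAULT);
  curandSetPseudoRandomGeneratorSeed(gen, 42);
  curandGenerateUniform(gen, d, n);  // samples in (0, 1]
  Shift<<<(n + 511) / 512, 512>>>(n, -1.0f, 1.0f, d);
  cudaDeviceSynchronize();
  curandDestroyGenerator(gen);
  cudaFree(d);
  return 0;
}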
@@ -54,6 +54,7 @@ int LAPACKE_dgetri(int matrix_layout, int n, double* a, int lda,
#include "paddle/framework/tensor.h" #include "paddle/framework/tensor.h"
#include "paddle/platform/device_context.h" #include "paddle/platform/device_context.h"
#include "paddle/platform/eigen.h"
#include "paddle/platform/enforce.h" #include "paddle/platform/enforce.h"
 
 namespace paddle {
@@ -77,6 +78,13 @@ void matmul(const framework::Tensor& matrix_a, bool trans_a,
             framework::Tensor* matrix_out, T beta,
             platform::DeviceContext* context);
 
+// Fill output[0..n) with the constant alpha on the given place.
+// Per-place specializations are defined in the .cc and .cu files above.
+template <typename Place, typename T>
+void Set(const int n, const T alpha, T* output,
+         platform::DeviceContext* context);
+
 template <typename Place, typename T>
 void RandUniform(const int n, const T min, const T max, T* output,
                  platform::DeviceContext* context);
...
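Note: a hypothetical usage sketch for the declarations above. The header path and the paddle::operators::math namespace are assumptions inferred from the "} // namespace math" closing brace earlier in this diff, not confirmed by the patch:

#include "paddle/operators/math/math_function.h"  // assumed header path

int main() {
  paddle::platform::CPUDeviceContext ctx;
  float buf[8];
  // Fill with a constant, then overwrite with uniform samples.
  paddle::operators::math::Set<paddle::platform::CPUPlace, float>(
      8, 1.0f, buf, &ctx);
  paddle::operators::math::RandUniform<paddle::platform::CPUPlace, float>(
      8, 0.0f, 1.0f, buf, &ctx);
  return 0;
}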
@@ -157,12 +157,17 @@ cudnnHandle_t CUDADeviceContext::cudnn_handle() {
   return cudnn_handle_;
 }
 
-thrust::minstd_rand& CPUDeviceContext::rand_engine() {
-  if (!rand_engine_) {
-    rand_engine_.reset(new thrust::minstd_rand());
-    rand_engine_->seed(rand_seed_);
-  }
-  return *(rand_engine_.get());
-}
+curandGenerator_t CUDADeviceContext::curand_generator() {
+  if (!curand_generator_) {
+    SetDeviceId(place_.device);
+    PADDLE_ENFORCE(dynload::curandCreateGenerator(&curand_generator_,
+                                                  CURAND_RNG_PSEUDO_DEFAULT));
+    PADDLE_ENFORCE(
+        dynload::curandSetPseudoRandomGeneratorSeed(curand_generator_, seed_));
+    PADDLE_ENFORCE(dynload::curandSetStream(curand_generator_, stream_));
+  }
+  return curand_generator_;
+}
 
 cudaStream_t CUDADeviceContext::stream() { return stream_; }
...
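Note: the accessor above builds the generator lazily on first use: it pins the device, creates a default pseudorandom generator, seeds it once, and binds it to the context's stream so generation is ordered with the rest of the stream's work. The same lifecycle with raw cuRAND calls (illustrative, without Paddle's dynload wrappers):

#include <curand.h>
#include <cuda_runtime.h>

int main() {
  cudaStream_t stream;
  cudaStreamCreate(&stream);
  curandGenerator_t gen;
  curandCreateGenerator(&gen, CURAND_RNG_PSEUDO_DEFAULT);
  curandSetPseudoRandomGeneratorSeed(gen, 2017);
  curandSetStream(gen, stream);        // generation now enqueues on 'stream'
  float* d = nullptr;
  cudaMalloc(&d, 256 * sizeof(float));
  curandGenerateUniform(gen, d, 256);  // asynchronous with respect to host
  cudaStreamSynchronize(stream);
  curandDestroyGenerator(gen);
  cudaFree(d);
  cudaStreamDestroy(stream);
  return 0;
}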
@@ -15,10 +15,9 @@ limitations under the License. */
 #include "paddle/platform/place.h"
 
 #ifndef PADDLE_ONLY_CPU
-#include <thrust/device_ptr.h>
-#include <thrust/random.h>
 #include "paddle/platform/dynload/cublas.h"
 #include "paddle/platform/dynload/cudnn.h"
+#include "paddle/platform/dynload/curand.h"
 #include "paddle/platform/gpu_info.h"
 #define EIGEN_USE_GPU
 #endif
@@ -80,7 +79,8 @@ class CUDADeviceContext : public DeviceContext {
   /*! \brief Return cudnn handle in the device context. */
   cudnnHandle_t cudnn_handle();
-  thrust::minstd_rand& CPUDeviceContext::rand_engine();
+  /*! \brief Return curand handle in the device context. */
+  curandGenerator_t curand_generator();
 
   /*! \brief Return cuda stream in the device context. */
   cudaStream_t stream();
...