提交 08c987d7 编写于 作者: Q qijun

use dynload curand

上级 2f47f35b
...@@ -34,5 +34,5 @@ class GaussianRandomKernel : public framework::OpKernel { ...@@ -34,5 +34,5 @@ class GaussianRandomKernel : public framework::OpKernel {
math::RandGaussian<Place, T>(n, mean, std, data, device_context); math::RandGaussian<Place, T>(n, mean, std, data, device_context);
} }
}; };
} } // namespace operators
} } // namespace paddle
...@@ -149,8 +149,8 @@ void RandUniform<platform::GPUPlace, float>(const int n, const float min, ...@@ -149,8 +149,8 @@ void RandUniform<platform::GPUPlace, float>(const int n, const float min,
const float max, float* output, const float max, float* output,
platform::DeviceContext* context) { platform::DeviceContext* context) {
auto* cuda_context = reinterpret_cast<platform::CUDADeviceContext*>(context); auto* cuda_context = reinterpret_cast<platform::CUDADeviceContext*>(context);
PADDLE_ENFORCE( PADDLE_ENFORCE(platform::dynload::curandGenerateUniform(
curandGenerateUniform(cuda_context->curand_generator(), output, n)); cuda_context->curand_generator(), output, n));
int block = 512; int block = 512;
int grid = (n + block - 1) / block; int grid = (n + block - 1) / block;
UniformShift<float><<<grid, block, 0, cuda_context->stream()>>>(n, min, max, UniformShift<float><<<grid, block, 0, cuda_context->stream()>>>(n, min, max,
...@@ -179,8 +179,8 @@ void RandGaussian<platform::GPUPlace, float>(const int n, const float mean, ...@@ -179,8 +179,8 @@ void RandGaussian<platform::GPUPlace, float>(const int n, const float mean,
const int even_n = const int even_n =
HandleOddLengthRandGaussian<float>(n, mean, std, output, cuda_context); HandleOddLengthRandGaussian<float>(n, mean, std, output, cuda_context);
PADDLE_ENFORCE(curandGenerateNormal(cuda_context->curand_generator(), output, PADDLE_ENFORCE(platform::dynload::curandGenerateNormal(
even_n, mean, std)); cuda_context->curand_generator(), output, even_n, mean, std));
} }
} // namespace math } // namespace math
......
...@@ -34,5 +34,5 @@ class UniformRandomKernel : public framework::OpKernel { ...@@ -34,5 +34,5 @@ class UniformRandomKernel : public framework::OpKernel {
math::RandUniform<Place, T>(n, min, max, data, device_context); math::RandUniform<Place, T>(n, min, max, data, device_context);
} }
}; };
} } // namespace operators
} } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册