提交 08c987d7 编写于 作者: Q qijun

use dynload curand

上级 2f47f35b
......@@ -34,5 +34,5 @@ class GaussianRandomKernel : public framework::OpKernel {
math::RandGaussian<Place, T>(n, mean, std, data, device_context);
}
};
}
}
} // namespace operators
} // namespace paddle
......@@ -149,8 +149,8 @@ void RandUniform<platform::GPUPlace, float>(const int n, const float min,
const float max, float* output,
platform::DeviceContext* context) {
auto* cuda_context = reinterpret_cast<platform::CUDADeviceContext*>(context);
PADDLE_ENFORCE(
curandGenerateUniform(cuda_context->curand_generator(), output, n));
PADDLE_ENFORCE(platform::dynload::curandGenerateUniform(
cuda_context->curand_generator(), output, n));
int block = 512;
int grid = (n + block - 1) / block;
UniformShift<float><<<grid, block, 0, cuda_context->stream()>>>(n, min, max,
......@@ -179,8 +179,8 @@ void RandGaussian<platform::GPUPlace, float>(const int n, const float mean,
const int even_n =
HandleOddLengthRandGaussian<float>(n, mean, std, output, cuda_context);
PADDLE_ENFORCE(curandGenerateNormal(cuda_context->curand_generator(), output,
even_n, mean, std));
PADDLE_ENFORCE(platform::dynload::curandGenerateNormal(
cuda_context->curand_generator(), output, even_n, mean, std));
}
} // namespace math
......
......@@ -34,5 +34,5 @@ class UniformRandomKernel : public framework::OpKernel {
math::RandUniform<Place, T>(n, min, max, data, device_context);
}
};
}
}
} // namespace operators
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册