未验证 提交 03f46e35 编写于 作者: Y yaoxuefeng 提交者: GitHub

fix truncated_gaussian op cuda seed setting (#28678)

上级 60a5eb68
...@@ -71,7 +71,7 @@ struct TruncatedNormalOffset { ...@@ -71,7 +71,7 @@ struct TruncatedNormalOffset {
thrust::minstd_rand rng; thrust::minstd_rand rng;
rng.seed(seed); rng.seed(seed);
thrust::uniform_real_distribution<T> dist(numeric_min, 1); thrust::uniform_real_distribution<T> dist(numeric_min, 1);
rng.discard(n); rng.discard(n + offset_);
T value = dist(rng); T value = dist(rng);
auto p = a_normal_cdf + (b_normal_cdf - a_normal_cdf) * value; auto p = a_normal_cdf + (b_normal_cdf - a_normal_cdf) * value;
return std::sqrt(2.0) * erfinvf(2 * p - 1) * std + mean; return std::sqrt(2.0) * erfinvf(2 * p - 1) * std + mean;
...@@ -108,7 +108,7 @@ class GPUTruncatedGaussianRandomKernel : public framework::OpKernel<T> { ...@@ -108,7 +108,7 @@ class GPUTruncatedGaussianRandomKernel : public framework::OpKernel<T> {
index_sequence_begin, index_sequence_begin + size, index_sequence_begin, index_sequence_begin + size,
thrust::device_ptr<T>(data), thrust::device_ptr<T>(data),
TruncatedNormalOffset<T>(mean, std, std::numeric_limits<T>::min(), TruncatedNormalOffset<T>(mean, std, std::numeric_limits<T>::min(),
seed_offset.first, seed_offset.second)); seed_offset.first, gen_offset));
} }
thrust::transform( thrust::transform(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册