未验证 提交 03f46e35 编写于 作者: Y yaoxuefeng 提交者: GitHub

fix truncated_gaussian op cuda seed setting (#28678)

上级 60a5eb68
......@@ -71,7 +71,7 @@ struct TruncatedNormalOffset {
thrust::minstd_rand rng;
rng.seed(seed);
thrust::uniform_real_distribution<T> dist(numeric_min, 1);
rng.discard(n);
rng.discard(n + offset_);
T value = dist(rng);
auto p = a_normal_cdf + (b_normal_cdf - a_normal_cdf) * value;
return std::sqrt(2.0) * erfinvf(2 * p - 1) * std + mean;
......@@ -108,7 +108,7 @@ class GPUTruncatedGaussianRandomKernel : public framework::OpKernel<T> {
index_sequence_begin, index_sequence_begin + size,
thrust::device_ptr<T>(data),
TruncatedNormalOffset<T>(mean, std, std::numeric_limits<T>::min(),
seed_offset.first, seed_offset.second));
seed_offset.first, gen_offset));
}
thrust::transform(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册