未验证 提交 71c1cd14 编写于 作者: Y yaoxuefeng 提交者: GitHub

fix truncated_gaussian seed (#28777)

上级 de528981
...@@ -109,13 +109,13 @@ class GPUTruncatedGaussianRandomKernel : public framework::OpKernel<T> { ...@@ -109,13 +109,13 @@ class GPUTruncatedGaussianRandomKernel : public framework::OpKernel<T> {
thrust::device_ptr<T>(data), thrust::device_ptr<T>(data),
TruncatedNormalOffset<T>(mean, std, std::numeric_limits<T>::min(), TruncatedNormalOffset<T>(mean, std, std::numeric_limits<T>::min(),
seed_offset.first, gen_offset)); seed_offset.first, gen_offset));
} } else {
thrust::transform( thrust::transform(
index_sequence_begin, index_sequence_begin + size, index_sequence_begin, index_sequence_begin + size,
thrust::device_ptr<T>(data), thrust::device_ptr<T>(data),
TruncatedNormal<T>(mean, std, std::numeric_limits<T>::min(), seed)); TruncatedNormal<T>(mean, std, std::numeric_limits<T>::min(), seed));
} }
}
}; };
} // namespace operators } // namespace operators
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册