未验证 提交 97ec57fe 编写于 作者: S Sławomir Siwek 提交者: GitHub

add seed check (#46747)

上级 ff0171e4
...@@ -28,8 +28,13 @@ void GaussianRandomKernel(const Context& ctx, ...@@ -28,8 +28,13 @@ void GaussianRandomKernel(const Context& ctx,
DataType dtype, DataType dtype,
DenseTensor* out) { DenseTensor* out) {
std::normal_distribution<T> dist(mean, std); std::normal_distribution<T> dist(mean, std);
auto engine = std::make_shared<std::mt19937_64>(); std::shared_ptr<std::mt19937_64> engine;
if (seed) {
engine = std::make_shared<std::mt19937_64>();
engine->seed(seed); engine->seed(seed);
} else {
engine = ctx.GetGenerator()->GetCPUEngine();
}
T* data = ctx.template Alloc<T>(out); T* data = ctx.template Alloc<T>(out);
for (int64_t i = 0; i < out->numel(); ++i) { for (int64_t i = 0; i < out->numel(); ++i) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册