未验证 提交 83740655 编写于 作者: F furnace 提交者: GitHub

[Phi] move gaussian_random, fix fp16 (#40122)

[Phi] move gaussian_random, fix fp16
上级 b7bbe39c
......@@ -81,22 +81,25 @@ void GaussianRandomKernel(const Context& dev_ctx,
int device_id = dev_ctx.GetPlace().GetDeviceId();
auto gen_cuda = paddle::framework::GetDefaultCUDAGenerator(device_id);
using MT = typename phi::kps::details::MPTypeTrait<T>::Type;
if (gen_cuda->GetIsInitPy() && seed_flag) {
if (FLAGS_use_curand) {
using MT = typename phi::kps::details::MPTypeTrait<T>::Type;
funcs::normal_distribution<MT> dist;
funcs::normal_transform<MT> trans(mean, std);
funcs::distribution_and_transform<T>(dev_ctx, tensor, dist, trans);
} else {
auto seed_offset = gen_cuda->IncrementOffset(1);
int64_t gen_offset = size * seed_offset.second;
auto func =
GaussianGenerator<MT>(mean, std, seed_offset.first, gen_offset);
IndexKernel<T, GaussianGenerator<MT>>(dev_ctx, tensor, func);
auto func = GaussianGenerator<T>(static_cast<T>(mean),
static_cast<T>(std),
seed_offset.first,
gen_offset);
IndexKernel<T, GaussianGenerator<T>>(dev_ctx, tensor, func);
}
} else {
auto func = GaussianGenerator<MT>(mean, std, seed);
IndexKernel<T, GaussianGenerator<MT>>(dev_ctx, tensor, func);
auto func =
GaussianGenerator<T>(static_cast<T>(mean), static_cast<T>(std), seed);
IndexKernel<T, GaussianGenerator<T>>(dev_ctx, tensor, func);
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册