未验证 提交 83740655 编写于 作者: F furnace 提交者: GitHub

[Phi] move gaussian_random, fix fp16 (#40122)

[Phi] move gaussian_random, fix fp16
上级 b7bbe39c
@@ -81,22 +81,25 @@ void GaussianRandomKernel(const Context& dev_ctx,
   int device_id = dev_ctx.GetPlace().GetDeviceId();
   auto gen_cuda = paddle::framework::GetDefaultCUDAGenerator(device_id);
-  using MT = typename phi::kps::details::MPTypeTrait<T>::Type;
   if (gen_cuda->GetIsInitPy() && seed_flag) {
     if (FLAGS_use_curand) {
+      using MT = typename phi::kps::details::MPTypeTrait<T>::Type;
       funcs::normal_distribution<MT> dist;
       funcs::normal_transform<MT> trans(mean, std);
       funcs::distribution_and_transform<T>(dev_ctx, tensor, dist, trans);
     } else {
       auto seed_offset = gen_cuda->IncrementOffset(1);
       int64_t gen_offset = size * seed_offset.second;
-      auto func =
-          GaussianGenerator<MT>(mean, std, seed_offset.first, gen_offset);
-      IndexKernel<T, GaussianGenerator<MT>>(dev_ctx, tensor, func);
+      auto func = GaussianGenerator<T>(static_cast<T>(mean),
+                                       static_cast<T>(std),
+                                       seed_offset.first,
+                                       gen_offset);
+      IndexKernel<T, GaussianGenerator<T>>(dev_ctx, tensor, func);
     }
   } else {
-    auto func = GaussianGenerator<MT>(mean, std, seed);
-    IndexKernel<T, GaussianGenerator<MT>>(dev_ctx, tensor, func);
+    auto func =
+        GaussianGenerator<T>(static_cast<T>(mean), static_cast<T>(std), seed);
+    IndexKernel<T, GaussianGenerator<T>>(dev_ctx, tensor, func);
   }
 }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册