From 837406551260414ab18689251e3b2422a10faf69 Mon Sep 17 00:00:00 2001 From: furnace <34057289+windstamp@users.noreply.github.com> Date: Fri, 4 Mar 2022 15:47:46 +0800 Subject: [PATCH] [Phi] move gaussian_random, fix fp16 (#40122) [Phi] move gaussian_random, fix fp16 --- paddle/phi/kernels/gpu/gaussian_random_kernel.cu | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/paddle/phi/kernels/gpu/gaussian_random_kernel.cu b/paddle/phi/kernels/gpu/gaussian_random_kernel.cu index d5acc60a36..da16800ad0 100644 --- a/paddle/phi/kernels/gpu/gaussian_random_kernel.cu +++ b/paddle/phi/kernels/gpu/gaussian_random_kernel.cu @@ -81,22 +81,25 @@ void GaussianRandomKernel(const Context& dev_ctx, int device_id = dev_ctx.GetPlace().GetDeviceId(); auto gen_cuda = paddle::framework::GetDefaultCUDAGenerator(device_id); - using MT = typename phi::kps::details::MPTypeTrait::Type; if (gen_cuda->GetIsInitPy() && seed_flag) { if (FLAGS_use_curand) { + using MT = typename phi::kps::details::MPTypeTrait::Type; funcs::normal_distribution dist; funcs::normal_transform trans(mean, std); funcs::distribution_and_transform(dev_ctx, tensor, dist, trans); } else { auto seed_offset = gen_cuda->IncrementOffset(1); int64_t gen_offset = size * seed_offset.second; - auto func = - GaussianGenerator(mean, std, seed_offset.first, gen_offset); - IndexKernel>(dev_ctx, tensor, func); + auto func = GaussianGenerator(static_cast(mean), + static_cast(std), + seed_offset.first, + gen_offset); + IndexKernel>(dev_ctx, tensor, func); } } else { - auto func = GaussianGenerator(mean, std, seed); - IndexKernel>(dev_ctx, tensor, func); + auto func = + GaussianGenerator(static_cast(mean), static_cast(std), seed); + IndexKernel>(dev_ctx, tensor, func); } } -- GitLab