From c084a7b12063dd398b8db628671d88b84bec199f Mon Sep 17 00:00:00 2001 From: sneaxiy <32832641+sneaxiy@users.noreply.github.com> Date: Wed, 7 Sep 2022 09:26:33 +0800 Subject: [PATCH] Fix UpdateLossScalingKernel to prevent data transform error (#45809) * fix amp kernel * update to remove PADDLE_WITH_XPU macro --- paddle/phi/infermeta/multiary.cc | 6 ++++-- paddle/phi/kernels/gpu/amp_kernel.cu | 4 +++- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index f43660bfdb..e5835dc56d 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -2479,8 +2479,10 @@ void UpdateLossScalingInferMeta(const std::vector& xs, xs.size(), outs.size())); for (size_t i = 0; i < xs.size(); ++i) { - outs[i]->set_dims(xs[i]->dims()); - outs[i]->set_dtype(xs[i]->dtype()); + if (xs[i] != nullptr && outs[i] != nullptr) { + outs[i]->set_dims(xs[i]->dims()); + outs[i]->set_dtype(xs[i]->dtype()); + } } loss_scaling->set_dims({1}); out_good_steps->set_dims({1}); diff --git a/paddle/phi/kernels/gpu/amp_kernel.cu b/paddle/phi/kernels/gpu/amp_kernel.cu index 51e11cc44b..230eb801d2 100644 --- a/paddle/phi/kernels/gpu/amp_kernel.cu +++ b/paddle/phi/kernels/gpu/amp_kernel.cu @@ -365,4 +365,6 @@ PD_REGISTER_KERNEL(update_loss_scaling, phi::UpdateLossScalingKernel, float, double, - phi::dtype::float16) {} + phi::dtype::float16) { + kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND); +} -- GitLab