diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index f43660bfdbe8ce14da2785e38d9ece7255e9a8f5..e5835dc56d4718267e4715f77b35b862023591b3 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -2479,8 +2479,10 @@ void UpdateLossScalingInferMeta(const std::vector& xs, xs.size(), outs.size())); for (size_t i = 0; i < xs.size(); ++i) { - outs[i]->set_dims(xs[i]->dims()); - outs[i]->set_dtype(xs[i]->dtype()); + if (xs[i] != nullptr && outs[i] != nullptr) { + outs[i]->set_dims(xs[i]->dims()); + outs[i]->set_dtype(xs[i]->dtype()); + } } loss_scaling->set_dims({1}); out_good_steps->set_dims({1}); diff --git a/paddle/phi/kernels/gpu/amp_kernel.cu b/paddle/phi/kernels/gpu/amp_kernel.cu index 51e11cc44b8563b77b67b70b43e46a958f04e9ae..230eb801d20d51473799faa99942e665b6ef4fbd 100644 --- a/paddle/phi/kernels/gpu/amp_kernel.cu +++ b/paddle/phi/kernels/gpu/amp_kernel.cu @@ -365,4 +365,6 @@ PD_REGISTER_KERNEL(update_loss_scaling, phi::UpdateLossScalingKernel, float, double, - phi::dtype::float16) {} + phi::dtype::float16) { + kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND); +}