未验证 提交 c084a7b1 编写于 作者: S sneaxiy 提交者: GitHub

Fix UpdateLossScalingKernel to prevent data transform error (#45809)

* fix amp kernel

* update to remove PADDLE_WITH_XPU macro
上级 68f99b78
......@@ -2479,8 +2479,10 @@ void UpdateLossScalingInferMeta(const std::vector<const MetaTensor*>& xs,
xs.size(),
outs.size()));
for (size_t i = 0; i < xs.size(); ++i) {
outs[i]->set_dims(xs[i]->dims());
outs[i]->set_dtype(xs[i]->dtype());
if (xs[i] != nullptr && outs[i] != nullptr) {
outs[i]->set_dims(xs[i]->dims());
outs[i]->set_dtype(xs[i]->dtype());
}
}
loss_scaling->set_dims({1});
out_good_steps->set_dims({1});
......
......@@ -365,4 +365,6 @@ PD_REGISTER_KERNEL(update_loss_scaling,
phi::UpdateLossScalingKernel,
float,
double,
phi::dtype::float16) {}
phi::dtype::float16) {
kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册