未验证 提交 c084a7b1 编写于 作者: S sneaxiy 提交者: GitHub

Fix UpdateLossScalingKernel to prevent data transform error (#45809)

* fix amp kernel

* update to remove PADDLE_WITH_XPU macro
上级 68f99b78
...@@ -2479,9 +2479,11 @@ void UpdateLossScalingInferMeta(const std::vector<const MetaTensor*>& xs, ...@@ -2479,9 +2479,11 @@ void UpdateLossScalingInferMeta(const std::vector<const MetaTensor*>& xs,
xs.size(), xs.size(),
outs.size())); outs.size()));
for (size_t i = 0; i < xs.size(); ++i) { for (size_t i = 0; i < xs.size(); ++i) {
if (xs[i] != nullptr && outs[i] != nullptr) {
outs[i]->set_dims(xs[i]->dims()); outs[i]->set_dims(xs[i]->dims());
outs[i]->set_dtype(xs[i]->dtype()); outs[i]->set_dtype(xs[i]->dtype());
} }
}
loss_scaling->set_dims({1}); loss_scaling->set_dims({1});
out_good_steps->set_dims({1}); out_good_steps->set_dims({1});
out_good_steps->set_dtype(DataType::INT32); out_good_steps->set_dtype(DataType::INT32);
......
...@@ -365,4 +365,6 @@ PD_REGISTER_KERNEL(update_loss_scaling, ...@@ -365,4 +365,6 @@ PD_REGISTER_KERNEL(update_loss_scaling,
phi::UpdateLossScalingKernel, phi::UpdateLossScalingKernel,
float, float,
double, double,
phi::dtype::float16) {} phi::dtype::float16) {
kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册