未验证 提交 44ed8f2d 编写于 作者: L Leo Chen 提交者: GitHub

fix bug when beta1_pow on cpu (#31995)

上级 bc7a3afa
...@@ -61,8 +61,20 @@ class AdamNPUKernel : public framework::OpKernel<T> { ...@@ -61,8 +61,20 @@ class AdamNPUKernel : public framework::OpKernel<T> {
param_out->mutable_data<T>(ctx.GetPlace()); param_out->mutable_data<T>(ctx.GetPlace());
mom1_out->mutable_data<T>(ctx.GetPlace()); mom1_out->mutable_data<T>(ctx.GetPlace());
mom2_out->mutable_data<T>(ctx.GetPlace()); mom2_out->mutable_data<T>(ctx.GetPlace());
beta1_pow_out->mutable_data<T>(ctx.GetPlace());
beta2_pow_out->mutable_data<T>(ctx.GetPlace()); // NOTE(zhiqiu): beta1_pow and beta2_pow may on CPU and not transform place.
if (beta1_pow->place() == platform::CPUPlace()) {
float beta1 = *beta1_pow->data<float>();
beta1_pow_out->mutable_data<T>(ctx.GetPlace());
TensorFromVector(std::vector<float>{beta1}, ctx.device_context(),
beta1_pow_out);
}
if (beta2_pow->place() == platform::CPUPlace()) {
float beta2 = *beta2_pow->data<float>();
beta2_pow_out->mutable_data<T>(ctx.GetPlace());
TensorFromVector(std::vector<float>{beta2}, ctx.device_context(),
beta2_pow_out);
}
T beta1 = static_cast<T>(ctx.Attr<float>("beta1")); T beta1 = static_cast<T>(ctx.Attr<float>("beta1"));
if (ctx.HasInput("Beta1Tensor")) { if (ctx.HasInput("Beta1Tensor")) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册