未验证 提交 2e1fd1de 编写于 作者: P pangyoki 提交者: GitHub

fix adam bug again (#32246)

上级 260ef770
......@@ -61,17 +61,22 @@ class AdamNPUKernel : public framework::OpKernel<T> {
param_out->mutable_data<T>(ctx.GetPlace());
mom1_out->mutable_data<T>(ctx.GetPlace());
mom2_out->mutable_data<T>(ctx.GetPlace());
beta1_pow_out->mutable_data<T>(ctx.GetPlace());
beta2_pow_out->mutable_data<T>(ctx.GetPlace());
// NOTE(zhiqiu): beta1_pow and beta2_pow may on CPU and not transform place.
if (beta1_pow->place() == platform::CPUPlace()) {
T beta1 = *beta1_pow->data<T>();
// `mutable_data` operation needs to be done after getting data
beta1_pow_out->mutable_data<T>(ctx.GetPlace());
FillNpuTensorWithConstant<T>(beta1_pow_out, beta1);
} else {
beta1_pow_out->mutable_data<T>(ctx.GetPlace());
}
if (beta2_pow->place() == platform::CPUPlace()) {
T beta2 = *beta2_pow->data<T>();
beta2_pow_out->mutable_data<T>(ctx.GetPlace());
FillNpuTensorWithConstant<T>(beta2_pow_out, beta2);
} else {
beta2_pow_out->mutable_data<T>(ctx.GetPlace());
}
T beta1 = static_cast<T>(ctx.Attr<float>("beta1"));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册