From 2e1fd1de38496d927e34bafceeb941649601352a Mon Sep 17 00:00:00 2001 From: pangyoki Date: Wed, 14 Apr 2021 16:00:24 +0800 Subject: [PATCH] fix adam bug again (#32246) --- paddle/fluid/operators/optimizers/adam_op_npu.cc | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/operators/optimizers/adam_op_npu.cc b/paddle/fluid/operators/optimizers/adam_op_npu.cc index 6592022711e..b024aca21c3 100644 --- a/paddle/fluid/operators/optimizers/adam_op_npu.cc +++ b/paddle/fluid/operators/optimizers/adam_op_npu.cc @@ -61,17 +61,22 @@ class AdamNPUKernel : public framework::OpKernel { param_out->mutable_data(ctx.GetPlace()); mom1_out->mutable_data(ctx.GetPlace()); mom2_out->mutable_data(ctx.GetPlace()); - beta1_pow_out->mutable_data(ctx.GetPlace()); - beta2_pow_out->mutable_data(ctx.GetPlace()); // NOTE(zhiqiu): beta1_pow and beta2_pow may on CPU and not transform place. if (beta1_pow->place() == platform::CPUPlace()) { T beta1 = *beta1_pow->data(); + // `mutable_data` operation needs to be done after getting data + beta1_pow_out->mutable_data(ctx.GetPlace()); FillNpuTensorWithConstant(beta1_pow_out, beta1); + } else { + beta1_pow_out->mutable_data(ctx.GetPlace()); } if (beta2_pow->place() == platform::CPUPlace()) { T beta2 = *beta2_pow->data(); + beta2_pow_out->mutable_data(ctx.GetPlace()); FillNpuTensorWithConstant(beta2_pow_out, beta2); + } else { + beta2_pow_out->mutable_data(ctx.GetPlace()); } T beta1 = static_cast(ctx.Attr("beta1")); -- GitLab