diff --git a/paddle/fluid/operators/optimizers/adam_op_npu.cc b/paddle/fluid/operators/optimizers/adam_op_npu.cc
index 6592022711ebf03b5cf145e0b0f804258c742dfa..b024aca21c38214f115c80c3e034fbcea8bfeef9 100644
--- a/paddle/fluid/operators/optimizers/adam_op_npu.cc
+++ b/paddle/fluid/operators/optimizers/adam_op_npu.cc
@@ -61,17 +61,22 @@ class AdamNPUKernel : public framework::OpKernel<T> {
     param_out->mutable_data<T>(ctx.GetPlace());
     mom1_out->mutable_data<T>(ctx.GetPlace());
     mom2_out->mutable_data<T>(ctx.GetPlace());
-    beta1_pow_out->mutable_data<T>(ctx.GetPlace());
-    beta2_pow_out->mutable_data<T>(ctx.GetPlace());
 
     // NOTE(zhiqiu): beta1_pow and beta2_pow may on CPU and not transform place.
     if (beta1_pow->place() == platform::CPUPlace()) {
       T beta1 = *beta1_pow->data<T>();
+      // `mutable_data` operation needs to be done after getting data
+      beta1_pow_out->mutable_data<T>(ctx.GetPlace());
       FillNpuTensorWithConstant<T>(beta1_pow_out, beta1);
+    } else {
+      beta1_pow_out->mutable_data<T>(ctx.GetPlace());
     }
     if (beta2_pow->place() == platform::CPUPlace()) {
       T beta2 = *beta2_pow->data<T>();
+      beta2_pow_out->mutable_data<T>(ctx.GetPlace());
       FillNpuTensorWithConstant<T>(beta2_pow_out, beta2);
+    } else {
+      beta2_pow_out->mutable_data<T>(ctx.GetPlace());
     }
 
     T beta1 = static_cast<T>(ctx.Attr<float>("beta1"));