diff --git a/paddle/fluid/operators/optimizers/adam_op_npu.cc b/paddle/fluid/operators/optimizers/adam_op_npu.cc
index 134544c2f65bc397acc3cb6451990e6cee3b0990..d212ce5959291c5b75b8d24aa10d6dbd43bb2d8f 100644
--- a/paddle/fluid/operators/optimizers/adam_op_npu.cc
+++ b/paddle/fluid/operators/optimizers/adam_op_npu.cc
@@ -61,8 +61,20 @@ class AdamNPUKernel : public framework::OpKernel<T> {
     param_out->mutable_data<T>(ctx.GetPlace());
     mom1_out->mutable_data<T>(ctx.GetPlace());
     mom2_out->mutable_data<T>(ctx.GetPlace());
-    beta1_pow_out->mutable_data<T>(ctx.GetPlace());
-    beta2_pow_out->mutable_data<T>(ctx.GetPlace());
+
+    // NOTE(zhiqiu): beta1_pow and beta2_pow may on CPU and not transform place.
+    if (beta1_pow->place() == platform::CPUPlace()) {
+      float beta1 = *beta1_pow->data<float>();
+      beta1_pow_out->mutable_data<float>(ctx.GetPlace());
+      TensorFromVector(std::vector<float>{beta1}, ctx.device_context(),
+                       beta1_pow_out);
+    }
+    if (beta2_pow->place() == platform::CPUPlace()) {
+      float beta2 = *beta2_pow->data<float>();
+      beta2_pow_out->mutable_data<float>(ctx.GetPlace());
+      TensorFromVector(std::vector<float>{beta2}, ctx.device_context(),
+                       beta2_pow_out);
+    }
 
     T beta1 = static_cast<T>(ctx.Attr<float>("beta1"));
     if (ctx.HasInput("Beta1Tensor")) {