From 44ed8f2d119dbe145af49ab8a442634ee6a59c82 Mon Sep 17 00:00:00 2001 From: Leo Chen Date: Wed, 31 Mar 2021 20:09:16 +0800 Subject: [PATCH] fix bug when beta1_pow on cpu (#31995) --- paddle/fluid/operators/optimizers/adam_op_npu.cc | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/operators/optimizers/adam_op_npu.cc b/paddle/fluid/operators/optimizers/adam_op_npu.cc index 134544c2f65..d212ce59592 100644 --- a/paddle/fluid/operators/optimizers/adam_op_npu.cc +++ b/paddle/fluid/operators/optimizers/adam_op_npu.cc @@ -61,8 +61,20 @@ class AdamNPUKernel : public framework::OpKernel { param_out->mutable_data(ctx.GetPlace()); mom1_out->mutable_data(ctx.GetPlace()); mom2_out->mutable_data(ctx.GetPlace()); - beta1_pow_out->mutable_data(ctx.GetPlace()); - beta2_pow_out->mutable_data(ctx.GetPlace()); + + // NOTE(zhiqiu): beta1_pow and beta2_pow may on CPU and not transform place. + if (beta1_pow->place() == platform::CPUPlace()) { + float beta1 = *beta1_pow->data(); + beta1_pow_out->mutable_data(ctx.GetPlace()); + TensorFromVector(std::vector{beta1}, ctx.device_context(), + beta1_pow_out); + } + if (beta2_pow->place() == platform::CPUPlace()) { + float beta2 = *beta2_pow->data(); + beta2_pow_out->mutable_data(ctx.GetPlace()); + TensorFromVector(std::vector{beta2}, ctx.device_context(), + beta2_pow_out); + } T beta1 = static_cast(ctx.Attr("beta1")); if (ctx.HasInput("Beta1Tensor")) { -- GitLab