From efa85f8c7782a5bf3ac3a93e350677663796d5c4 Mon Sep 17 00:00:00 2001 From: Leo Chen Date: Thu, 1 Apr 2021 15:59:01 +0800 Subject: [PATCH] fix adam (#32016) --- paddle/fluid/operators/optimizers/adam_op_npu.cc | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/paddle/fluid/operators/optimizers/adam_op_npu.cc b/paddle/fluid/operators/optimizers/adam_op_npu.cc index d212ce59592..e2d262ff97d 100644 --- a/paddle/fluid/operators/optimizers/adam_op_npu.cc +++ b/paddle/fluid/operators/optimizers/adam_op_npu.cc @@ -68,12 +68,16 @@ class AdamNPUKernel : public framework::OpKernel { beta1_pow_out->mutable_data(ctx.GetPlace()); TensorFromVector(std::vector{beta1}, ctx.device_context(), beta1_pow_out); + } else { + beta1_pow_out->mutable_data(ctx.GetPlace()); } if (beta2_pow->place() == platform::CPUPlace()) { float beta2 = *beta2_pow->data(); beta2_pow_out->mutable_data(ctx.GetPlace()); TensorFromVector(std::vector{beta2}, ctx.device_context(), beta2_pow_out); + } else { + beta2_pow_out->mutable_data(ctx.GetPlace()); } T beta1 = static_cast(ctx.Attr("beta1")); -- GitLab