diff --git a/paddle/fluid/operators/optimizers/adam_op_xpu.cc b/paddle/fluid/operators/optimizers/adam_op_xpu.cc
index 05b4544c02a1231f7f6f275f13a978e66705819b..2abc690fc51b26ee1b538a1d9e6b8b0e104fc0f1 100644
--- a/paddle/fluid/operators/optimizers/adam_op_xpu.cc
+++ b/paddle/fluid/operators/optimizers/adam_op_xpu.cc
@@ -74,7 +74,7 @@ class AdamOpXPUKernel : public framework::OpKernel<T> {
                           "output size is 1, but received "
                           "value is:%d.",
                           beta2_pow_out->numel()));
-
+
     T beta1 = static_cast<T>(ctx.Attr<float>("beta1"));
     if (ctx.HasInput("Beta1Tensor")) {
       auto* beta1_tensor = ctx.Input<framework::Tensor>("Beta1Tensor");
@@ -88,30 +88,53 @@ class AdamOpXPUKernel : public framework::OpKernel<T> {
     if (grad_var->IsType<framework::LoDTensor>()) {
       auto& grad = GET_DATA_SAFELY(ctx.Input<LoDTensor>("Grad"), "Input",
                                    "Grad", "Adam");
-
       auto& dev_ctx = ctx.template device_context<DeviceContext>();
+      const T* beta1_pow_ptr = beta1_pow.template data<T>();
+      const T* beta2_pow_ptr = beta2_pow.template data<T>();
+      Tensor xpu_beta1_pow;
+      Tensor xpu_beta2_pow;
+      if (beta1_pow.place() == platform::CPUPlace() &&
+          beta2_pow.place() == platform::CPUPlace()) {
+        TensorCopy(beta1_pow, ctx.GetPlace(), dev_ctx, &xpu_beta1_pow);
+        TensorCopy(beta2_pow, ctx.GetPlace(), dev_ctx, &xpu_beta2_pow);
+        dev_ctx.Wait();
+        beta1_pow_ptr = xpu_beta1_pow.template data<T>();
+        beta2_pow_ptr = xpu_beta2_pow.template data<T>();
+      }
       int r = xpu::adam(
           dev_ctx.x_context(), grad.template data<T>(), mom1.template data<T>(),
-          mom2.template data<T>(), param.template data<T>(),
-          beta1_pow.template data<T>(), beta2_pow.template data<T>(), beta1,
-          beta2, epsilon, lr.template data<T>(),
+          mom2.template data<T>(), param.template data<T>(), beta1_pow_ptr,
+          beta2_pow_ptr, beta1, beta2, epsilon, lr.template data<T>(),
           mom1_out.template mutable_data<T>(ctx.GetPlace()),
           mom2_out.template mutable_data<T>(ctx.GetPlace()),
           param_out.template mutable_data<T>(ctx.GetPlace()), param.numel());
 
-      const float* ptr0 = beta1_pow.template data<float>();
-      float* ptr1 = beta1_pow_out->mutable_data<float>(ctx.GetPlace());
-      float cpudata;
-      xpu_memcpy(&cpudata, ptr0, sizeof(float), XPU_DEVICE_TO_HOST);
-      cpudata = cpudata * beta1;
-      xpu_memcpy(ptr1, &cpudata, sizeof(float), XPU_HOST_TO_DEVICE);
-
-      const float* ptr2 = beta2_pow.template data<float>();
-      float* ptr3 = beta2_pow_out->mutable_data<float>(ctx.GetPlace());
-      float cpudata1;
-      xpu_memcpy(&cpudata1, ptr2, sizeof(float), XPU_DEVICE_TO_HOST);
-      cpudata1 = cpudata1 * beta2;
-      xpu_memcpy(ptr3, &cpudata1, sizeof(float), XPU_HOST_TO_DEVICE);
+      // update in cpu and then copy to xpu
+      if (beta1_pow.place() == platform::CPUPlace() &&
+          beta2_pow.place() == platform::CPUPlace()) {
+        const T* beta1_pow_p = beta1_pow.template data<T>();
+        beta1_pow_out->mutable_data<T>(platform::CPUPlace())[0] =
+            beta1 * beta1_pow_p[0];
+        const T* beta2_pow_p = beta2_pow.template data<T>();
+        beta2_pow_out->mutable_data<T>(platform::CPUPlace())[0] =
+            beta2 * beta2_pow_p[0];
+      } else {
+        T cpu_beta1_pow_out_data;
+        T cpu_beta2_pow_out_data;
+        xpu_memcpy(&cpu_beta1_pow_out_data, beta1_pow_ptr, sizeof(T),
+                   XPU_DEVICE_TO_HOST);
+        cpu_beta1_pow_out_data = cpu_beta1_pow_out_data * beta1;
+        xpu_memcpy(&cpu_beta2_pow_out_data, beta2_pow_ptr, sizeof(T),
+                   XPU_DEVICE_TO_HOST);
+        cpu_beta2_pow_out_data = cpu_beta2_pow_out_data * beta2;
+
+        T* beta1_pow_out_p = beta1_pow_out->mutable_data<T>(ctx.GetPlace());
+        T* beta2_pow_out_p = beta2_pow_out->mutable_data<T>(ctx.GetPlace());
+        xpu_memcpy(beta1_pow_out_p, &cpu_beta1_pow_out_data, sizeof(T),
+                   XPU_HOST_TO_DEVICE);
+        xpu_memcpy(beta2_pow_out_p, &cpu_beta2_pow_out_data, sizeof(T),
+                   XPU_HOST_TO_DEVICE);
+      }
 
       PADDLE_ENFORCE_EQ(r == xpu::Error_t::SUCCESS, true,
                         platform::errors::External(
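
Note on the change above: Adam's bias-correction state (Beta1Pow/Beta2Pow) is a
single element per tensor, updated each step by the recurrence
beta_pow *= beta. The patch keeps that update on the host whenever both tensors
live on CPUPlace, so the common path costs two scalar multiplies instead of the
old unconditional pair of device round-trips per beta. A minimal standalone
sketch of that recurrence (plain C++ with hypothetical values; not the Paddle
or XPU runtime API):

    #include <cstdio>

    int main() {
      // Hypothetical attribute values; in the op these come from the
      // "beta1"/"beta2" attributes or the Beta1Tensor/Beta2Tensor inputs.
      float beta1 = 0.9f, beta2 = 0.999f;
      float beta1_pow = beta1, beta2_pow = beta2;  // state after step 1
      for (int step = 2; step <= 4; ++step) {
        // The one-element update the patch performs on the CPU branch:
        //   beta_pow_out[0] = beta * beta_pow[0]
        beta1_pow *= beta1;
        beta2_pow *= beta2;
        printf("step %d: beta1_pow=%g beta2_pow=%g\n", step, beta1_pow,
               beta2_pow);
      }
      return 0;
    }

On the device branch the same two multiplies still happen on the host, but the
operands are fetched and the results stored with xpu_memcpy, matching the
XPU_DEVICE_TO_HOST/XPU_HOST_TO_DEVICE pattern the old code used unconditionally.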