From 4ae76d2179cf9812f76ea91ab8eb6007a5098ec7 Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Thu, 14 Apr 2022 12:03:39 +0800 Subject: [PATCH] [Op]Fix adam/adamw beta1_pow/beta2_pow place while copying (#41732) --- paddle/phi/kernels/gpu/adamw_kernel.cu | 4 ++-- paddle/phi/kernels/selected_rows/gpu/adam_kernel.cu | 4 ++-- paddle/phi/kernels/selected_rows/gpu/adamw_kernel.cu | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/paddle/phi/kernels/gpu/adamw_kernel.cu b/paddle/phi/kernels/gpu/adamw_kernel.cu index 3555df11b5..4873ba9c13 100644 --- a/paddle/phi/kernels/gpu/adamw_kernel.cu +++ b/paddle/phi/kernels/gpu/adamw_kernel.cu @@ -190,8 +190,8 @@ void AdamwDenseKernel(const Context& dev_ctx, phi::Copy(dev_ctx, param, dev_ctx.GetPlace(), false, param_out); phi::Copy(dev_ctx, moment1, dev_ctx.GetPlace(), false, moment1_out); phi::Copy(dev_ctx, moment2, dev_ctx.GetPlace(), false, moment2_out); - phi::Copy(dev_ctx, beta1_pow, dev_ctx.GetPlace(), false, beta1_pow_out); - phi::Copy(dev_ctx, beta2_pow, dev_ctx.GetPlace(), false, beta2_pow_out); + phi::Copy(dev_ctx, beta1_pow, beta1_pow.place(), false, beta1_pow_out); + phi::Copy(dev_ctx, beta2_pow, beta2_pow.place(), false, beta2_pow_out); return; } diff --git a/paddle/phi/kernels/selected_rows/gpu/adam_kernel.cu b/paddle/phi/kernels/selected_rows/gpu/adam_kernel.cu index 2cb0865032..31abac1499 100644 --- a/paddle/phi/kernels/selected_rows/gpu/adam_kernel.cu +++ b/paddle/phi/kernels/selected_rows/gpu/adam_kernel.cu @@ -139,8 +139,8 @@ void AdamDenseParamSparseGradKernel( phi::Copy(dev_ctx, param, dev_ctx.GetPlace(), false, param_out); phi::Copy(dev_ctx, moment1, dev_ctx.GetPlace(), false, moment1_out); phi::Copy(dev_ctx, moment2, dev_ctx.GetPlace(), false, moment2_out); - phi::Copy(dev_ctx, beta1_pow, dev_ctx.GetPlace(), false, beta1_pow_out); - phi::Copy(dev_ctx, beta2_pow, dev_ctx.GetPlace(), false, beta2_pow_out); + phi::Copy(dev_ctx, beta1_pow, beta1_pow.place(), false, beta1_pow_out); + phi::Copy(dev_ctx, beta2_pow, beta2_pow.place(), false, beta2_pow_out); return; } diff --git a/paddle/phi/kernels/selected_rows/gpu/adamw_kernel.cu b/paddle/phi/kernels/selected_rows/gpu/adamw_kernel.cu index 0fc223e081..b847f48d12 100644 --- a/paddle/phi/kernels/selected_rows/gpu/adamw_kernel.cu +++ b/paddle/phi/kernels/selected_rows/gpu/adamw_kernel.cu @@ -156,8 +156,8 @@ void AdamwDenseParamSparseGradKernel( phi::Copy(dev_ctx, param, dev_ctx.GetPlace(), false, param_out); phi::Copy(dev_ctx, moment1, dev_ctx.GetPlace(), false, moment1_out); phi::Copy(dev_ctx, moment2, dev_ctx.GetPlace(), false, moment2_out); - phi::Copy(dev_ctx, beta1_pow, dev_ctx.GetPlace(), false, beta1_pow_out); - phi::Copy(dev_ctx, beta2_pow, dev_ctx.GetPlace(), false, beta2_pow_out); + phi::Copy(dev_ctx, beta1_pow, beta1_pow.place(), false, beta1_pow_out); + phi::Copy(dev_ctx, beta2_pow, beta2_pow.place(), false, beta2_pow_out); return; } -- GitLab