From fdeec8c37e6a4d53557eb9715e39b6ff04ced5bc Mon Sep 17 00:00:00 2001
From: Aurelius84
Date: Tue, 12 Apr 2022 10:30:19 +0800
Subject: [PATCH] [Phi]Fix beta1_pow/beta2_pow/skip_update data transform problem in adam/adamw (#41641)

* [Phi]Fix beta1_pow/beta2_pow/skip_update data transform problem in adam/adamw

* fix xpu unittest failed
---
 paddle/phi/kernels/gpu/adam_kernel.cu                | 7 ++++++-
 paddle/phi/kernels/gpu/adamw_kernel.cu               | 7 ++++++-
 paddle/phi/kernels/selected_rows/gpu/adam_kernel.cu  | 7 ++++++-
 paddle/phi/kernels/selected_rows/gpu/adamw_kernel.cu | 7 ++++++-
 4 files changed, 24 insertions(+), 4 deletions(-)

diff --git a/paddle/phi/kernels/gpu/adam_kernel.cu b/paddle/phi/kernels/gpu/adam_kernel.cu
index d3317e258e..33b6f3a5a1 100644
--- a/paddle/phi/kernels/gpu/adam_kernel.cu
+++ b/paddle/phi/kernels/gpu/adam_kernel.cu
@@ -272,4 +272,9 @@ PD_REGISTER_KERNEL(adam,
                    phi::AdamDenseKernel,
                    float,
                    double,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16) {
+  // Skip beta1_pow, beta2_pow, skip_update data transform
+  kernel->InputAt(5).SetBackend(phi::Backend::ALL_BACKEND);
+  kernel->InputAt(6).SetBackend(phi::Backend::ALL_BACKEND);
+  kernel->InputAt(8).SetBackend(phi::Backend::ALL_BACKEND);
+}
diff --git a/paddle/phi/kernels/gpu/adamw_kernel.cu b/paddle/phi/kernels/gpu/adamw_kernel.cu
index 8fef101383..3555df11b5 100644
--- a/paddle/phi/kernels/gpu/adamw_kernel.cu
+++ b/paddle/phi/kernels/gpu/adamw_kernel.cu
@@ -299,4 +299,9 @@ PD_REGISTER_KERNEL(adamw,
                    phi::AdamwDenseKernel,
                    float,
                    double,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16) {
+  // Skip beta1_pow, beta2_pow, skip_update data transform
+  kernel->InputAt(5).SetBackend(phi::Backend::ALL_BACKEND);
+  kernel->InputAt(6).SetBackend(phi::Backend::ALL_BACKEND);
+  kernel->InputAt(8).SetBackend(phi::Backend::ALL_BACKEND);
+}
diff --git a/paddle/phi/kernels/selected_rows/gpu/adam_kernel.cu b/paddle/phi/kernels/selected_rows/gpu/adam_kernel.cu
index 32c05765a9..2cb0865032 100644
--- a/paddle/phi/kernels/selected_rows/gpu/adam_kernel.cu
+++ b/paddle/phi/kernels/selected_rows/gpu/adam_kernel.cu
@@ -284,4 +284,9 @@ PD_REGISTER_KERNEL(adam_dense_param_sparse_grad,
                    phi::sr::AdamDenseParamSparseGradKernel,
                    float,
                    double,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16) {
+  // Skip beta1_pow, beta2_pow, skip_update data transform
+  kernel->InputAt(5).SetBackend(phi::Backend::ALL_BACKEND);
+  kernel->InputAt(6).SetBackend(phi::Backend::ALL_BACKEND);
+  kernel->InputAt(8).SetBackend(phi::Backend::ALL_BACKEND);
+}
diff --git a/paddle/phi/kernels/selected_rows/gpu/adamw_kernel.cu b/paddle/phi/kernels/selected_rows/gpu/adamw_kernel.cu
index 2e48b8235e..0fc223e081 100644
--- a/paddle/phi/kernels/selected_rows/gpu/adamw_kernel.cu
+++ b/paddle/phi/kernels/selected_rows/gpu/adamw_kernel.cu
@@ -310,4 +310,9 @@ PD_REGISTER_KERNEL(adamw_dense_param_sparse_grad,
                    phi::sr::AdamwDenseParamSparseGradKernel,
                    float,
                    double,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16) {
+  // Skip beta1_pow, beta2_pow, skip_update data transform
+  kernel->InputAt(5).SetBackend(phi::Backend::ALL_BACKEND);
+  kernel->InputAt(6).SetBackend(phi::Backend::ALL_BACKEND);
+  kernel->InputAt(8).SetBackend(phi::Backend::ALL_BACKEND);
+}
-- 
GitLab