diff --git a/paddle/phi/kernels/gpu/adam_kernel.cu b/paddle/phi/kernels/gpu/adam_kernel.cu index d3317e258e5382d5d2ca49916da056e8f8506527..33b6f3a5a1bee92f802b673b0982fc94dbe9406e 100644 --- a/paddle/phi/kernels/gpu/adam_kernel.cu +++ b/paddle/phi/kernels/gpu/adam_kernel.cu @@ -272,4 +272,9 @@ PD_REGISTER_KERNEL(adam, phi::AdamDenseKernel, float, double, - phi::dtype::float16) {} + phi::dtype::float16) { + // Skip beta1_pow, beta2_pow, skip_update data transform + kernel->InputAt(5).SetBackend(phi::Backend::ALL_BACKEND); + kernel->InputAt(6).SetBackend(phi::Backend::ALL_BACKEND); + kernel->InputAt(8).SetBackend(phi::Backend::ALL_BACKEND); +} diff --git a/paddle/phi/kernels/gpu/adamw_kernel.cu b/paddle/phi/kernels/gpu/adamw_kernel.cu index 8fef101383bb09c49c88a6cc36ddf8af46e8be65..3555df11b5e1f06ffee85c5a73387705568ec5f4 100644 --- a/paddle/phi/kernels/gpu/adamw_kernel.cu +++ b/paddle/phi/kernels/gpu/adamw_kernel.cu @@ -299,4 +299,9 @@ PD_REGISTER_KERNEL(adamw, phi::AdamwDenseKernel, float, double, - phi::dtype::float16) {} + phi::dtype::float16) { + // Skip beta1_pow, beta2_pow, skip_update data transform + kernel->InputAt(5).SetBackend(phi::Backend::ALL_BACKEND); + kernel->InputAt(6).SetBackend(phi::Backend::ALL_BACKEND); + kernel->InputAt(8).SetBackend(phi::Backend::ALL_BACKEND); +} diff --git a/paddle/phi/kernels/selected_rows/gpu/adam_kernel.cu b/paddle/phi/kernels/selected_rows/gpu/adam_kernel.cu index 32c05765a9ab0fa239d459dc28b4573ee29eb7cb..2cb086503283b5e45193e43ca01b5ae14ef8c175 100644 --- a/paddle/phi/kernels/selected_rows/gpu/adam_kernel.cu +++ b/paddle/phi/kernels/selected_rows/gpu/adam_kernel.cu @@ -284,4 +284,9 @@ PD_REGISTER_KERNEL(adam_dense_param_sparse_grad, phi::sr::AdamDenseParamSparseGradKernel, float, double, - phi::dtype::float16) {} + phi::dtype::float16) { + // Skip beta1_pow, beta2_pow, skip_update data transform + kernel->InputAt(5).SetBackend(phi::Backend::ALL_BACKEND); + kernel->InputAt(6).SetBackend(phi::Backend::ALL_BACKEND); + kernel->InputAt(8).SetBackend(phi::Backend::ALL_BACKEND); +} diff --git a/paddle/phi/kernels/selected_rows/gpu/adamw_kernel.cu b/paddle/phi/kernels/selected_rows/gpu/adamw_kernel.cu index 2e48b8235ed72ab377879767341808944732743e..0fc223e081506ef78dd359e6b6103f8fcd19a834 100644 --- a/paddle/phi/kernels/selected_rows/gpu/adamw_kernel.cu +++ b/paddle/phi/kernels/selected_rows/gpu/adamw_kernel.cu @@ -310,4 +310,9 @@ PD_REGISTER_KERNEL(adamw_dense_param_sparse_grad, phi::sr::AdamwDenseParamSparseGradKernel, float, double, - phi::dtype::float16) {} + phi::dtype::float16) { + // Skip beta1_pow, beta2_pow, skip_update data transform + kernel->InputAt(5).SetBackend(phi::Backend::ALL_BACKEND); + kernel->InputAt(6).SetBackend(phi::Backend::ALL_BACKEND); + kernel->InputAt(8).SetBackend(phi::Backend::ALL_BACKEND); +}