diff --git a/paddle/phi/kernels/cpu/adam_kernel.cc b/paddle/phi/kernels/cpu/adam_kernel.cc index f2fec564e133d92121941dc48c53758d0c6ed778..bf8296f8418fcc6330bf4b5579a6b88a7d486bc9 100644 --- a/paddle/phi/kernels/cpu/adam_kernel.cc +++ b/paddle/phi/kernels/cpu/adam_kernel.cc @@ -70,9 +70,10 @@ void AdamDenseKernel(const Context& dev_ctx, phi::Copy(dev_ctx, param, dev_ctx.GetPlace(), false, param_out); phi::Copy(dev_ctx, moment1, dev_ctx.GetPlace(), false, moment1_out); phi::Copy(dev_ctx, moment2, dev_ctx.GetPlace(), false, moment2_out); - phi::Copy(dev_ctx, beta1_pow, beta1_pow.place(), false, beta1_pow_out); - phi::Copy(dev_ctx, beta2_pow, beta2_pow.place(), false, beta2_pow_out); - + if (!use_global_beta_pow) { + phi::Copy(dev_ctx, beta1_pow, beta1_pow.place(), false, beta1_pow_out); + phi::Copy(dev_ctx, beta2_pow, beta2_pow.place(), false, beta2_pow_out); + } return; } diff --git a/paddle/phi/kernels/gpu/adam_kernel.cu b/paddle/phi/kernels/gpu/adam_kernel.cu index 7da864de8b3c6907cbf80fe05f26347262ef03b1..53d31b153d5f8a4fc61bf976e98a5f3c56f0d5f8 100644 --- a/paddle/phi/kernels/gpu/adam_kernel.cu +++ b/paddle/phi/kernels/gpu/adam_kernel.cu @@ -171,8 +171,10 @@ void AdamDenseKernel(const Context& dev_ctx, phi::Copy(dev_ctx, param, dev_ctx.GetPlace(), false, param_out); phi::Copy(dev_ctx, moment1, dev_ctx.GetPlace(), false, moment1_out); phi::Copy(dev_ctx, moment2, dev_ctx.GetPlace(), false, moment2_out); - phi::Copy(dev_ctx, beta1_pow, beta1_pow.place(), false, beta1_pow_out); - phi::Copy(dev_ctx, beta2_pow, beta2_pow.place(), false, beta2_pow_out); + if (!use_global_beta_pow) { + phi::Copy(dev_ctx, beta1_pow, beta1_pow.place(), false, beta1_pow_out); + phi::Copy(dev_ctx, beta2_pow, beta2_pow.place(), false, beta2_pow_out); + } return; } @@ -377,6 +379,17 @@ PD_REGISTER_KERNEL(adam, kernel->InputAt(5).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(6).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(8).SetBackend(phi::Backend::ALL_BACKEND); + + if (kernel_key.dtype() == phi::DataType::FLOAT16 || + kernel_key.dtype() == phi::DataType::BFLOAT16) { + kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32); + kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32); + kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32); + kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32); + kernel->OutputAt(5).SetDataType(phi::DataType::FLOAT32); + } + kernel->OutputAt(3).SetBackend(phi::Backend::UNDEFINED); + kernel->OutputAt(4).SetBackend(phi::Backend::UNDEFINED); } PD_REGISTER_KERNEL(merged_adam, @@ -390,4 +403,15 @@ PD_REGISTER_KERNEL(merged_adam, // Skip beta1_pow, beta2_pow data transform kernel->InputAt(5).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(6).SetBackend(phi::Backend::ALL_BACKEND); + + if (kernel_key.dtype() == phi::DataType::FLOAT16 || + kernel_key.dtype() == phi::DataType::BFLOAT16) { + kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32); + kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32); + kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32); + kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32); + kernel->OutputAt(5).SetDataType(phi::DataType::FLOAT32); + } + kernel->OutputAt(3).SetBackend(phi::Backend::UNDEFINED); + kernel->OutputAt(4).SetBackend(phi::Backend::UNDEFINED); } diff --git a/paddle/phi/kernels/gpu/adamw_kernel.cu b/paddle/phi/kernels/gpu/adamw_kernel.cu index 29e9b984e35e09062f02cc2675d87fecf0ebaf02..8a27df719565052044aea3d90845592c3feb550b 100644 --- a/paddle/phi/kernels/gpu/adamw_kernel.cu +++ b/paddle/phi/kernels/gpu/adamw_kernel.cu @@ -191,8 +191,10 @@ void AdamwDenseKernel(const Context& dev_ctx, phi::Copy(dev_ctx, param, dev_ctx.GetPlace(), false, param_out); phi::Copy(dev_ctx, moment1, dev_ctx.GetPlace(), false, moment1_out); phi::Copy(dev_ctx, moment2, dev_ctx.GetPlace(), false, moment2_out); - phi::Copy(dev_ctx, beta1_pow, beta1_pow.place(), false, beta1_pow_out); - phi::Copy(dev_ctx, beta2_pow, beta2_pow.place(), false, beta2_pow_out); + if (!use_global_beta_pow) { + phi::Copy(dev_ctx, beta1_pow, beta1_pow.place(), false, beta1_pow_out); + phi::Copy(dev_ctx, beta2_pow, beta2_pow.place(), false, beta2_pow_out); + } return; } @@ -306,4 +308,15 @@ PD_REGISTER_KERNEL(adamw, kernel->InputAt(5).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(6).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(8).SetBackend(phi::Backend::ALL_BACKEND); + + if (kernel_key.dtype() == phi::DataType::FLOAT16 || + kernel_key.dtype() == phi::DataType::BFLOAT16) { + kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32); + kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32); + kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32); + kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32); + kernel->OutputAt(5).SetDataType(phi::DataType::FLOAT32); + } + kernel->OutputAt(3).SetBackend(phi::Backend::UNDEFINED); + kernel->OutputAt(4).SetBackend(phi::Backend::UNDEFINED); } diff --git a/paddle/phi/kernels/selected_rows/cpu/adam_kernel.cc b/paddle/phi/kernels/selected_rows/cpu/adam_kernel.cc index f99c10d8ad8ba2b2d446bbe38c95fb7f2fd70983..cec2c01702ffe49a93e25ca9b3ca8fe9beb3f2b1 100644 --- a/paddle/phi/kernels/selected_rows/cpu/adam_kernel.cc +++ b/paddle/phi/kernels/selected_rows/cpu/adam_kernel.cc @@ -71,8 +71,10 @@ void AdamDenseParamSparseGradKernel( phi::Copy(dev_ctx, param, dev_ctx.GetPlace(), false, param_out); phi::Copy(dev_ctx, moment1, dev_ctx.GetPlace(), false, moment1_out); phi::Copy(dev_ctx, moment2, dev_ctx.GetPlace(), false, moment2_out); - phi::Copy(dev_ctx, beta1_pow, dev_ctx.GetPlace(), false, beta1_pow_out); - phi::Copy(dev_ctx, beta2_pow, dev_ctx.GetPlace(), false, beta2_pow_out); + if (!use_global_beta_pow) { + phi::Copy(dev_ctx, beta1_pow, dev_ctx.GetPlace(), false, beta1_pow_out); + phi::Copy(dev_ctx, beta2_pow, dev_ctx.GetPlace(), false, beta2_pow_out); + } return; } diff --git a/paddle/phi/kernels/selected_rows/gpu/adam_kernel.cu b/paddle/phi/kernels/selected_rows/gpu/adam_kernel.cu index 81e3a33c359d1a73817ba0e3e88cf6b227a2fffb..555562cfac5837d6cc8aa1d8aab4eac8e85bfdbe 100644 --- a/paddle/phi/kernels/selected_rows/gpu/adam_kernel.cu +++ b/paddle/phi/kernels/selected_rows/gpu/adam_kernel.cu @@ -138,8 +138,10 @@ void AdamDenseParamSparseGradKernel( phi::Copy(dev_ctx, param, dev_ctx.GetPlace(), false, param_out); phi::Copy(dev_ctx, moment1, dev_ctx.GetPlace(), false, moment1_out); phi::Copy(dev_ctx, moment2, dev_ctx.GetPlace(), false, moment2_out); - phi::Copy(dev_ctx, beta1_pow, beta1_pow.place(), false, beta1_pow_out); - phi::Copy(dev_ctx, beta2_pow, beta2_pow.place(), false, beta2_pow_out); + if (!use_global_beta_pow) { + phi::Copy(dev_ctx, beta1_pow, beta1_pow.place(), false, beta1_pow_out); + phi::Copy(dev_ctx, beta2_pow, beta2_pow.place(), false, beta2_pow_out); + } return; } @@ -288,4 +290,14 @@ PD_REGISTER_KERNEL(adam_dense_param_sparse_grad, kernel->InputAt(5).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(6).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(8).SetBackend(phi::Backend::ALL_BACKEND); + + if (kernel_key.dtype() == phi::DataType::FLOAT16) { + kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32); + kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32); + kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32); + kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32); + kernel->OutputAt(5).SetDataType(phi::DataType::FLOAT32); + } + kernel->OutputAt(3).SetBackend(phi::Backend::UNDEFINED); + kernel->OutputAt(4).SetBackend(phi::Backend::UNDEFINED); } diff --git a/paddle/phi/kernels/selected_rows/gpu/adamw_kernel.cu b/paddle/phi/kernels/selected_rows/gpu/adamw_kernel.cu index c405061adbf5a829ebb450ddf01bd5450c614e06..0e24b5f71ed2ac5a205fbbfb08b05340fad452dc 100644 --- a/paddle/phi/kernels/selected_rows/gpu/adamw_kernel.cu +++ b/paddle/phi/kernels/selected_rows/gpu/adamw_kernel.cu @@ -156,8 +156,10 @@ void AdamwDenseParamSparseGradKernel( phi::Copy(dev_ctx, param, dev_ctx.GetPlace(), false, param_out); phi::Copy(dev_ctx, moment1, dev_ctx.GetPlace(), false, moment1_out); phi::Copy(dev_ctx, moment2, dev_ctx.GetPlace(), false, moment2_out); - phi::Copy(dev_ctx, beta1_pow, beta1_pow.place(), false, beta1_pow_out); - phi::Copy(dev_ctx, beta2_pow, beta2_pow.place(), false, beta2_pow_out); + if (!use_global_beta_pow) { + phi::Copy(dev_ctx, beta1_pow, beta1_pow.place(), false, beta1_pow_out); + phi::Copy(dev_ctx, beta2_pow, beta2_pow.place(), false, beta2_pow_out); + } return; } @@ -315,4 +317,14 @@ PD_REGISTER_KERNEL(adamw_dense_param_sparse_grad, kernel->InputAt(5).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(6).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(8).SetBackend(phi::Backend::ALL_BACKEND); + + if (kernel_key.dtype() == phi::DataType::FLOAT16) { + kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32); + kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32); + kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32); + kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32); + kernel->OutputAt(5).SetDataType(phi::DataType::FLOAT32); + } + kernel->OutputAt(3).SetBackend(phi::Backend::UNDEFINED); + kernel->OutputAt(4).SetBackend(phi::Backend::UNDEFINED); } diff --git a/paddle/phi/kernels/selected_rows/xpu/adam_kernel.cc b/paddle/phi/kernels/selected_rows/xpu/adam_kernel.cc index 29bfbd5954a7e81886e4e3e20d9dd73924d2331d..f9fb73b07694586034a27752f39cda33517b9662 100644 --- a/paddle/phi/kernels/selected_rows/xpu/adam_kernel.cc +++ b/paddle/phi/kernels/selected_rows/xpu/adam_kernel.cc @@ -165,8 +165,10 @@ void AdamDenseParamSparseGradKernel( phi::Copy(dev_ctx, param, dev_ctx.GetPlace(), false, param_out); phi::Copy(dev_ctx, moment1, dev_ctx.GetPlace(), false, moment1_out); phi::Copy(dev_ctx, moment2, dev_ctx.GetPlace(), false, moment2_out); - phi::Copy(dev_ctx, beta1_pow, beta1_pow.place(), false, beta1_pow_out); - phi::Copy(dev_ctx, beta2_pow, beta2_pow.place(), false, beta2_pow_out); + if (!use_global_beta_pow) { + phi::Copy(dev_ctx, beta1_pow, beta1_pow.place(), false, beta1_pow_out); + phi::Copy(dev_ctx, beta2_pow, beta2_pow.place(), false, beta2_pow_out); + } return; } @@ -335,4 +337,6 @@ PD_REGISTER_KERNEL(adam_dense_param_sparse_grad, kernel->InputAt(5).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(6).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(8).SetBackend(phi::Backend::ALL_BACKEND); + kernel->OutputAt(3).SetBackend(phi::Backend::UNDEFINED); + kernel->OutputAt(4).SetBackend(phi::Backend::UNDEFINED); } diff --git a/paddle/phi/kernels/xpu/adam_kernel.cc b/paddle/phi/kernels/xpu/adam_kernel.cc index 3c86e820eb03a31edd96dac29fa5a1f7acdae9e5..4a64ce929fab43603e490fbaa306101bcd4791b9 100644 --- a/paddle/phi/kernels/xpu/adam_kernel.cc +++ b/paddle/phi/kernels/xpu/adam_kernel.cc @@ -133,8 +133,10 @@ void AdamDenseKernel(const Context& dev_ctx, phi::Copy(dev_ctx, param, dev_ctx.GetPlace(), false, param_out); phi::Copy(dev_ctx, moment1, dev_ctx.GetPlace(), false, moment1_out); phi::Copy(dev_ctx, moment2, dev_ctx.GetPlace(), false, moment2_out); - phi::Copy(dev_ctx, beta1_pow, beta1_pow.place(), false, beta1_pow_out); - phi::Copy(dev_ctx, beta2_pow, beta2_pow.place(), false, beta2_pow_out); + if (!use_global_beta_pow) { + phi::Copy(dev_ctx, beta1_pow, beta1_pow.place(), false, beta1_pow_out); + phi::Copy(dev_ctx, beta2_pow, beta2_pow.place(), false, beta2_pow_out); + } return; } @@ -246,4 +248,7 @@ PD_REGISTER_KERNEL( kernel->InputAt(5).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(6).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(8).SetBackend(phi::Backend::ALL_BACKEND); + + kernel->OutputAt(3).SetBackend(phi::Backend::UNDEFINED); + kernel->OutputAt(4).SetBackend(phi::Backend::UNDEFINED); } diff --git a/paddle/phi/kernels/xpu/adamw_kernel.cc b/paddle/phi/kernels/xpu/adamw_kernel.cc index c2f0f66b3436487fefb232e0a8d0b29a3aa217d5..9258348117786cfd2825101e7564eda9b669b55c 100644 --- a/paddle/phi/kernels/xpu/adamw_kernel.cc +++ b/paddle/phi/kernels/xpu/adamw_kernel.cc @@ -67,8 +67,10 @@ void AdamwDenseKernel(const Context& dev_ctx, phi::Copy(dev_ctx, param, dev_ctx.GetPlace(), false, param_out); phi::Copy(dev_ctx, moment1, dev_ctx.GetPlace(), false, moment1_out); phi::Copy(dev_ctx, moment2, dev_ctx.GetPlace(), false, moment2_out); - phi::Copy(dev_ctx, beta1_pow, beta1_pow.place(), false, beta1_pow_out); - phi::Copy(dev_ctx, beta2_pow, beta2_pow.place(), false, beta2_pow_out); + if (!use_global_beta_pow) { + phi::Copy(dev_ctx, beta1_pow, beta1_pow.place(), false, beta1_pow_out); + phi::Copy(dev_ctx, beta2_pow, beta2_pow.place(), false, beta2_pow_out); + } return; } @@ -163,4 +165,11 @@ PD_REGISTER_KERNEL( kernel->InputAt(5).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(6).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(8).SetBackend(phi::Backend::ALL_BACKEND); + // Skip beta1_pow, beta2_pow, skip_update data transform + kernel->OutputAt(3) + .SetBackend(phi::Backend::UNDEFINED) + .SetDataType(phi::DataType::FLOAT32); + kernel->OutputAt(4) + .SetBackend(phi::Backend::UNDEFINED) + .SetDataType(phi::DataType::FLOAT32); }