未验证 提交 fdeec8c3 编写于 作者: A Aurelius84 提交者: GitHub

[Phi]Fix beta1_pow/beta2_pow/skip_update data transform problem in adam/adamw (#41641)

* [Phi]Fix beta1_pow/beta2_pow/skip_update data transform problem in adam/adamw

* fix xpu unittest failed
上级 b68bb428
...@@ -272,4 +272,9 @@ PD_REGISTER_KERNEL(adam, ...@@ -272,4 +272,9 @@ PD_REGISTER_KERNEL(adam,
phi::AdamDenseKernel, phi::AdamDenseKernel,
float, float,
double, double,
phi::dtype::float16) {} phi::dtype::float16) {
// Skip beta1_pow, beta2_pow, skip_update data transform
kernel->InputAt(5).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(6).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(8).SetBackend(phi::Backend::ALL_BACKEND);
}
...@@ -299,4 +299,9 @@ PD_REGISTER_KERNEL(adamw, ...@@ -299,4 +299,9 @@ PD_REGISTER_KERNEL(adamw,
phi::AdamwDenseKernel, phi::AdamwDenseKernel,
float, float,
double, double,
phi::dtype::float16) {} phi::dtype::float16) {
// Skip beta1_pow, beta2_pow, skip_update data transform
kernel->InputAt(5).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(6).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(8).SetBackend(phi::Backend::ALL_BACKEND);
}
...@@ -284,4 +284,9 @@ PD_REGISTER_KERNEL(adam_dense_param_sparse_grad, ...@@ -284,4 +284,9 @@ PD_REGISTER_KERNEL(adam_dense_param_sparse_grad,
phi::sr::AdamDenseParamSparseGradKernel, phi::sr::AdamDenseParamSparseGradKernel,
float, float,
double, double,
phi::dtype::float16) {} phi::dtype::float16) {
// Skip beta1_pow, beta2_pow, skip_update data transform
kernel->InputAt(5).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(6).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(8).SetBackend(phi::Backend::ALL_BACKEND);
}
...@@ -310,4 +310,9 @@ PD_REGISTER_KERNEL(adamw_dense_param_sparse_grad, ...@@ -310,4 +310,9 @@ PD_REGISTER_KERNEL(adamw_dense_param_sparse_grad,
phi::sr::AdamwDenseParamSparseGradKernel, phi::sr::AdamwDenseParamSparseGradKernel,
float, float,
double, double,
phi::dtype::float16) {} phi::dtype::float16) {
// Skip beta1_pow, beta2_pow, skip_update data transform
kernel->InputAt(5).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(6).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(8).SetBackend(phi::Backend::ALL_BACKEND);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册