未验证 提交 a31cf3d2 编写于 作者: L Leo Chen 提交者: GitHub

skip data_transfer for save op (#56775)

上级 2ea7a6a3
...@@ -101,7 +101,9 @@ PD_REGISTER_KERNEL(save, ...@@ -101,7 +101,9 @@ PD_REGISTER_KERNEL(save,
int16_t, int16_t,
int64_t, int64_t,
phi::dtype::float16, phi::dtype::float16,
phi::dtype::bfloat16) {} phi::dtype::bfloat16) {
kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
}
PD_REGISTER_KERNEL(save_sr, PD_REGISTER_KERNEL(save_sr,
CPU, CPU,
...@@ -115,7 +117,9 @@ PD_REGISTER_KERNEL(save_sr, ...@@ -115,7 +117,9 @@ PD_REGISTER_KERNEL(save_sr,
int16_t, int16_t,
int64_t, int64_t,
phi::dtype::float16, phi::dtype::float16,
phi::dtype::bfloat16) {} phi::dtype::bfloat16) {
kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL(save, PD_REGISTER_KERNEL(save,
...@@ -130,7 +134,9 @@ PD_REGISTER_KERNEL(save, ...@@ -130,7 +134,9 @@ PD_REGISTER_KERNEL(save,
int16_t, int16_t,
int64_t, int64_t,
phi::dtype::float16, phi::dtype::float16,
phi::dtype::bfloat16) {} phi::dtype::bfloat16) {
kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
}
PD_REGISTER_KERNEL(save_sr, PD_REGISTER_KERNEL(save_sr,
GPU, GPU,
...@@ -144,5 +150,7 @@ PD_REGISTER_KERNEL(save_sr, ...@@ -144,5 +150,7 @@ PD_REGISTER_KERNEL(save_sr,
int16_t, int16_t,
int64_t, int64_t,
phi::dtype::float16, phi::dtype::float16,
phi::dtype::bfloat16) {} phi::dtype::bfloat16) {
kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
}
#endif #endif
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册