From abc2cc5704178dfff968aae81ffaed1e2b67b992 Mon Sep 17 00:00:00 2001 From: Roc <30228238+sljlp@users.noreply.github.com> Date: Wed, 13 Jul 2022 10:32:15 +0800 Subject: [PATCH] fix transform data (#44266) * fix transform data * fix dropout kernel * Revert "fix transform data" This reverts commit ada75ecd169ea194ce43f7ed75dcc968f5ed2fb9. --- paddle/phi/kernels/gpu/dropout_kernel.cu | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/paddle/phi/kernels/gpu/dropout_kernel.cu b/paddle/phi/kernels/gpu/dropout_kernel.cu index 2fa3c7639e..f973bb8e15 100644 --- a/paddle/phi/kernels/gpu/dropout_kernel.cu +++ b/paddle/phi/kernels/gpu/dropout_kernel.cu @@ -84,7 +84,9 @@ PD_REGISTER_KERNEL(dropout, float, double, phi::dtype::bfloat16, - phi::dtype::float16) {} + phi::dtype::float16) { + kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND); +} PD_REGISTER_KERNEL(dropout_nd, GPU, @@ -93,4 +95,6 @@ PD_REGISTER_KERNEL(dropout_nd, float, double, phi::dtype::bfloat16, - phi::dtype::float16) {} + phi::dtype::float16) { + kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND); +} -- GitLab