diff --git a/paddle/phi/kernels/gpu/dropout_kernel.cu b/paddle/phi/kernels/gpu/dropout_kernel.cu index 2fa3c7639e3960c8a9c3f9f4dc4f7b000c2ec2bb..f973bb8e15fc75b19e98d8a8116f380699119fe9 100644 --- a/paddle/phi/kernels/gpu/dropout_kernel.cu +++ b/paddle/phi/kernels/gpu/dropout_kernel.cu @@ -84,7 +84,9 @@ PD_REGISTER_KERNEL(dropout, float, double, phi::dtype::bfloat16, - phi::dtype::float16) {} + phi::dtype::float16) { + kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND); +} PD_REGISTER_KERNEL(dropout_nd, GPU, @@ -93,4 +95,6 @@ PD_REGISTER_KERNEL(dropout_nd, float, double, phi::dtype::bfloat16, - phi::dtype::float16) {} + phi::dtype::float16) { + kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND); +}