diff --git a/paddle/phi/kernels/cpu/where_grad_kernel.cc b/paddle/phi/kernels/cpu/where_grad_kernel.cc index 67c8cee1038c7a990e5961a3fcd17e8d7c591207..a9cdbd7ad77ccf2a6867d343c75ccd32a2e5055e 100644 --- a/paddle/phi/kernels/cpu/where_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/where_grad_kernel.cc @@ -14,6 +14,8 @@ #include "paddle/phi/kernels/where_grad_kernel.h" +#include "paddle/phi/core/kernel_registry.h" + namespace phi { template diff --git a/paddle/phi/kernels/cpu/where_kernel.cc b/paddle/phi/kernels/cpu/where_kernel.cc index f624c13c262296964cef6b98f7d5d26dfc0b7d56..353d11c93c1cc22e681fe65f6715d79c0da24281 100644 --- a/paddle/phi/kernels/cpu/where_kernel.cc +++ b/paddle/phi/kernels/cpu/where_kernel.cc @@ -14,6 +14,8 @@ #include "paddle/phi/kernels/where_kernel.h" +#include "paddle/phi/core/kernel_registry.h" + namespace phi { template diff --git a/paddle/phi/kernels/funcs/activation_functor.h b/paddle/phi/kernels/funcs/activation_functor.h index 84da69ed5da027b3ba10e4702b061fdf4bc6d2c6..b75477a1af98271350fead6dff04ce7f399054bc 100644 --- a/paddle/phi/kernels/funcs/activation_functor.h +++ b/paddle/phi/kernels/funcs/activation_functor.h @@ -13,6 +13,7 @@ // limitations under the License. #pragma once + #include #include #include @@ -33,7 +34,6 @@ #include "paddle/phi/common/float16.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/enforce.h" -#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/eigen/common.h" #include "paddle/phi/kernels/funcs/eigen/extensions.h" diff --git a/paddle/phi/kernels/gpu/where_grad_kernel.cu b/paddle/phi/kernels/gpu/where_grad_kernel.cu index f21aca80e21b30de8931b4fcd4ae3922be959958..14cc1d311321dd1e96448e89ce9348c6bf489394 100644 --- a/paddle/phi/kernels/gpu/where_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/where_grad_kernel.cu @@ -14,6 +14,9 @@ #include "paddle/phi/kernels/where_grad_kernel.h" +#include "paddle/phi/backends/gpu/gpu_launch_config.h" +#include "paddle/phi/core/kernel_registry.h" + namespace phi { template diff --git a/paddle/phi/kernels/gpu/where_kernel.cu b/paddle/phi/kernels/gpu/where_kernel.cu index 03c24eea3a95af1ed57f5c8df42b01fd09af1fa2..a0be388065f4bc5b5ef5a7e0778c50d2243263cd 100644 --- a/paddle/phi/kernels/gpu/where_kernel.cu +++ b/paddle/phi/kernels/gpu/where_kernel.cu @@ -14,6 +14,8 @@ #include "paddle/phi/kernels/where_kernel.h" +#include "paddle/phi/backends/gpu/gpu_launch_config.h" +#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/broadcast_function.h" #include "paddle/phi/kernels/funcs/elementwise_functor.h" diff --git a/paddle/phi/kernels/where_grad_kernel.h b/paddle/phi/kernels/where_grad_kernel.h index 1a3c66ee6ed8403d0b453ed38d21e4beed02661c..5f596da93e9c2e9027fe49f68746f71d46021e9e 100644 --- a/paddle/phi/kernels/where_grad_kernel.h +++ b/paddle/phi/kernels/where_grad_kernel.h @@ -14,10 +14,7 @@ #pragma once -#include "paddle/phi/backends/all_context.h" -#include "paddle/phi/backends/gpu/gpu_launch_config.h" #include "paddle/phi/core/dense_tensor.h" -#include "paddle/phi/core/kernel_registry.h" namespace phi { diff --git a/paddle/phi/kernels/where_kernel.h b/paddle/phi/kernels/where_kernel.h index 254271ac9c7238c66d09ffe41d12e29fe8f23237..6348177e697647756e50934784e76df889b18194 100644 --- a/paddle/phi/kernels/where_kernel.h +++ b/paddle/phi/kernels/where_kernel.h @@ -14,10 +14,7 @@ #pragma once -#include "paddle/phi/backends/all_context.h" -#include "paddle/phi/backends/gpu/gpu_launch_config.h" #include "paddle/phi/core/dense_tensor.h" -#include "paddle/phi/core/kernel_registry.h" namespace phi { diff --git a/paddle/utils/variant.h b/paddle/utils/variant.h index b856fa8f7a1d77fe7483e9798ed786d9e111ab2c..a7546d094c2ffe28d057d47e202a4ea9f75c904b 100644 --- a/paddle/utils/variant.h +++ b/paddle/utils/variant.h @@ -2691,7 +2691,8 @@ inline constexpr bool all(std::initializer_list bs) { template inline constexpr decltype(auto) visit(Visitor &&visitor, Vs &&... vs) { - return (detail::all({!vs.valueless_by_exception()...}) + return (detail::all( + lib::array{!vs.valueless_by_exception()...}) ? (void)0 : throw_bad_variant_access()), detail::visitation::variant::visit_value(