diff --git a/paddle/phi/kernels/funcs/fused_gemm_epilogue.h b/paddle/phi/kernels/funcs/fused_gemm_epilogue.h index 8a2309ba26000fd79b726e8632e775cf4441247a..ab0758e2e3ff4792dac3c2c46de7fb0b7b21bcd0 100644 --- a/paddle/phi/kernels/funcs/fused_gemm_epilogue.h +++ b/paddle/phi/kernels/funcs/fused_gemm_epilogue.h @@ -535,10 +535,11 @@ void ComputeFusedGemmEpilogueBackwardImpl(const phi::GPUContext& dev_ctx, bool use_addto_dx, bool use_addto_dy) { using MT = typename phi::dtype::MPTypeTrait::Type; - static_assert(std::is_same::value || std::is_same::value, - ""); - static_assert(std::is_same::value || std::is_same::value, - ""); + constexpr bool kIsValidDataType = + (std::is_same::value || std::is_same::value) && + (std::is_same::value || std::is_same::value); + static_assert(kIsValidDataType, "Invalid data type"); + using Trait = FusedGEMMGradTrait; if (dx) {