diff --git a/cmake/external/cutlass.cmake b/cmake/external/cutlass.cmake index c96631206dfd7a7b24399581395f84f6837ece31..eee868900b585445721195d13af06a5640e43b5b 100644 --- a/cmake/external/cutlass.cmake +++ b/cmake/external/cutlass.cmake @@ -17,7 +17,7 @@ include(ExternalProject) set(CUTLASS_PREFIX_DIR ${THIRD_PARTY_PATH}/cutlass) set(CUTLASS_REPOSITORY https://github.com/NVIDIA/cutlass.git) -set(CUTLASS_TAG v2.10.0) +set(CUTLASS_TAG v2.11.0) include_directories("${THIRD_PARTY_PATH}/cutlass/src/extern_cutlass/") include_directories("${THIRD_PARTY_PATH}/cutlass/src/extern_cutlass/include/") diff --git a/paddle/phi/kernels/fusion/cutlass/default_moe_fc_traits.h b/paddle/phi/kernels/fusion/cutlass/moe/default_moe_fc_traits.h similarity index 100% rename from paddle/phi/kernels/fusion/cutlass/default_moe_fc_traits.h rename to paddle/phi/kernels/fusion/cutlass/moe/default_moe_fc_traits.h diff --git a/paddle/phi/kernels/fusion/cutlass/linear_combination_ft_gelu.h b/paddle/phi/kernels/fusion/cutlass/moe/linear_combination_ft_gelu.h similarity index 100% rename from paddle/phi/kernels/fusion/cutlass/linear_combination_ft_gelu.h rename to paddle/phi/kernels/fusion/cutlass/moe/linear_combination_ft_gelu.h diff --git a/paddle/phi/kernels/fusion/cutlass/moe_cutlass_kernel.h b/paddle/phi/kernels/fusion/cutlass/moe/moe_cutlass_kernel.h similarity index 98% rename from paddle/phi/kernels/fusion/cutlass/moe_cutlass_kernel.h rename to paddle/phi/kernels/fusion/cutlass/moe/moe_cutlass_kernel.h index f037f4e01b143a4db1ffdd24ff3595e55c9781e2..f0fcafba453c40673ec0a2f667d0164a01027664 100644 --- a/paddle/phi/kernels/fusion/cutlass/moe_cutlass_kernel.h +++ b/paddle/phi/kernels/fusion/cutlass/moe/moe_cutlass_kernel.h @@ -42,6 +42,7 @@ #include "cutlass/gemm/kernel/grouped_problem_visitor.h" #include "cutlass/layout/matrix.h" #include "cutlass/trace.h" + ///////////////////////////////////////////////////////////////////////////////////////////////// namespace cutlass { @@ -350,14 +351,16 @@ template struct GemmMoeProblemVisitor - : public MoeProblemVisitor, - ThreadblockShape, - GroupScheduleMode_, - PrefetchTileCount, - ThreadCount> { + : public MoeProblemVisitor< + detail::GemmGroupedProblemSizeHelper, + ThreadblockShape, + GroupScheduleMode_, + PrefetchTileCount, + ThreadCount> { static bool const kTransposed = Transposed; - using ProblemSizeHelper = detail::GemmGroupedProblemSizeHelper; + using ProblemSizeHelper = + detail::GemmGroupedProblemSizeHelper; using Base = MoeProblemVisitor