Commit 973d2a0a authored by Megvii Engine Team

feat(dnn/cuda): add cutlass matmul using split k parallel

GitOrigin-RevId: 650209e35f813e8eb8373d2ddc1671d3abb1759e
Parent 03c921f7
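
Background: a split-k parallel GEMM cuts the reduction dimension K into split_k_slices independent slices, computes the partial products in parallel across extra threadblocks, and then sums them with a reduction kernel. This pays off when K is much larger than M and N, where ordinary tiling launches too few threadblocks to fill the GPU. A minimal sketch of the CUTLASS primitive this commit wraps (assuming the CUTLASS 2.x device-level API; the run_split_k helper is illustrative, not part of the patch):

#include <algorithm>
#include <cuda_runtime.h>
#include "cutlass/gemm/device/gemm_splitk_parallel.h"

// Row-major FP32 SIMT GEMM with parallel split-k reduction.
using Gemm = cutlass::gemm::device::GemmSplitKParallel<
        float, cutlass::layout::RowMajor,   // A
        float, cutlass::layout::RowMajor,   // B
        float, cutlass::layout::RowMajor>;  // C / D

cutlass::Status run_split_k(const float* A, int lda, const float* B, int ldb,
                            float* C, int ldc, int m, int n, int k,
                            void* workspace, cudaStream_t stream) {
    // Same heuristic as the new algorithm below: one K-slice per max(m, n).
    int split_k_slices = k / std::max(m, n);
    typename Gemm::Arguments args({m, n, k},        // problem size
                                  {A, lda}, {B, ldb},
                                  {C, ldc},         // beta input (beta == 0 here)
                                  {C, ldc},         // output D
                                  {1.f, 0.f},       // alpha, beta
                                  split_k_slices);
    // The caller must provide Gemm::get_workspace_size(args) bytes.
    Gemm gemm_op;
    cutlass::Status status = gemm_op.initialize(args, workspace);
    if (status != cutlass::Status::kSuccess) return status;
    return gemm_op(stream);
}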
......@@ -9,9 +9,9 @@ ELEMWISE_IMPL := ../src/cuda/cond_take/kimpl \
../src/cuda/elemwise_multi_type/kimpl
CUDA_CONV_IMPL := ../src/cuda/conv_bias/int8/kimpl ../src/cuda/conv_bias/int8_imma/kimpl ../src/cuda/batch_conv_bias/int8/kimpl
CUDA_MATMUL_KIMPL := ../src/cuda/matrix_mul/fp32_simt/kimpl
CUDA_MATMUL_IMPL := ../src/cuda/matrix_mul/fp32_simt/kimpl
all: ${PARAM_DEFS} ${ELEMWISE_IMPL} ${CUDA_CONV_IMPL} $(CUDA_MATMUL_KIMPL)
all: ${PARAM_DEFS} ${ELEMWISE_IMPL} ${CUDA_CONV_IMPL} $(CUDA_MATMUL_IMPL)
../src/common/elemwise/each_mode.inl: gen_elemwise_each_mode.py
./$^ $@
......
......@@ -37,6 +37,9 @@ MatrixMulForwardImpl::AlgoPack::AlgoPack() {
for (auto&& algo : simt_float32) {
all_algos.push_back(&algo);
}
for (auto&& algo : simt_float32_split_k) {
all_algos.push_back(&algo);
}
for (auto&& algo : all_algos) {
m_all_algos_map.emplace(algo->info().desc, algo);
......@@ -62,6 +65,23 @@ void MatrixMulForwardImpl::AlgoPack::fill_cutlass_algos() {
simt_float32.emplace_back(AlgoParam{16, 32, 8, 16, 32, 8});
simt_float32.emplace_back(AlgoParam{16, 64, 8, 16, 64, 8});
simt_float32.emplace_back(AlgoParam{16, 128, 8, 16, 64, 8});
simt_float32_split_k.emplace_back(AlgoParam{64, 256, 8, 32, 64, 8});
simt_float32_split_k.emplace_back(AlgoParam{256, 64, 8, 64, 32, 8});
simt_float32_split_k.emplace_back(AlgoParam{32, 256, 8, 16, 64, 8});
simt_float32_split_k.emplace_back(AlgoParam{256, 32, 8, 64, 16, 8});
simt_float32_split_k.emplace_back(AlgoParam{128, 128, 8, 32, 64, 8});
simt_float32_split_k.emplace_back(AlgoParam{128, 64, 8, 64, 32, 8});
simt_float32_split_k.emplace_back(AlgoParam{64, 128, 8, 32, 64, 8});
simt_float32_split_k.emplace_back(AlgoParam{128, 32, 8, 64, 32, 8});
simt_float32_split_k.emplace_back(AlgoParam{32, 128, 8, 32, 64, 8});
simt_float32_split_k.emplace_back(AlgoParam{64, 64, 8, 32, 64, 8});
simt_float32_split_k.emplace_back(AlgoParam{32, 64, 8, 32, 64, 8});
simt_float32_split_k.emplace_back(AlgoParam{64, 32, 8, 64, 32, 8});
simt_float32_split_k.emplace_back(AlgoParam{32, 32, 8, 32, 32, 8});
simt_float32_split_k.emplace_back(AlgoParam{8, 32, 8, 8, 32, 8});
simt_float32_split_k.emplace_back(AlgoParam{16, 32, 8, 16, 32, 8});
simt_float32_split_k.emplace_back(AlgoParam{16, 64, 8, 16, 64, 8});
simt_float32_split_k.emplace_back(AlgoParam{16, 128, 8, 16, 64, 8});
}
MatrixMulForwardImpl::AlgoPack MatrixMulForwardImpl::sm_algo_pack;
......
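For reference, the six integers in each AlgoParam correspond to the threadblock and warp tile shapes consumed further down; the field names below are inferred from the accessors used in get_workspace_in_bytes() and exec():

// {threadblock_m, threadblock_n, threadblock_k, warp_m, warp_n, warp_k}
// e.g. AlgoParam{128, 128, 8, 32, 64, 8}: each threadblock computes a
// 128x128 tile of C, stepping through K 8 elements at a time, with each
// warp inside it owning a 32x64 sub-tile.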
......@@ -43,6 +43,7 @@ public:
CUDA_NAIVE,
CUDA_BFLOAT16,
CUDA_FLOAT32_SIMT,
CUDA_FLOAT32_SIMT_SPLIT_K,
};
using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>;
......@@ -198,6 +199,31 @@ private:
std::string m_name;
};
class MatrixMulForwardImpl::AlgoFloat32SIMTSplitK final : public AlgoBase {
public:
using AlgoParam = MatrixMulForwardImpl::AlgoFloat32SIMT::AlgoParam;
AlgoFloat32SIMTSplitK(AlgoParam algo_param)
: m_algo_param{algo_param},
m_name{ssprintf("CUTLASS_FLOAT32_SIMT_SPLIT_K_%s",
m_algo_param.to_string().c_str())} {}
bool is_available(const SizeArgs& args) const override;
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
const char* name() const override { return m_name.c_str(); }
void exec(const ExecArgs& args) const override;
bool is_reproducible() const override { return true; }
MEGDNN_DECL_ALGO_TYPE(CUDA_FLOAT32_SIMT_SPLIT_K)
std::string param() const override {
std::string ret;
serialize_write_pod(m_algo_param, ret);
return ret;
}
private:
AlgoParam m_algo_param;
std::string m_name;
};
class MatrixMulForwardImpl::AlgoPack : NonCopyableObj {
private:
AlgoBase::Mapper m_all_algos_map;
......@@ -216,6 +242,7 @@ public:
AlgoBFloat16 bfloat16;
#endif
std::vector<AlgoFloat32SIMT> simt_float32;
std::vector<AlgoFloat32SIMTSplitK> simt_float32_split_k;
std::vector<AlgoBase*> all_algos;
const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; }
......
/**
* \file dnn/src/cuda/matrix_mul/cutlass_float32_simt_split_k.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/cuda/handle.h"
#include "src/cuda/matrix_mul/algos.h"
#include "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cuh"
#include "src/cuda/utils.h"
using namespace megdnn;
using namespace cuda;
using namespace cutlass_wrapper;
bool MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::is_available(
const SizeArgs& args) const {
auto&& param = args.opr->param();
int m = args.layout_c.shape[0], n = args.layout_c.shape[1],
k = args.layout_a.shape[param.transposeA ? 0 : 1];
return args.opr->param().format == param::MatrixMul::Format::DEFAULT &&
args.layout_a.dtype == dtype::Float32() &&
args.layout_b.dtype == dtype::Float32() &&
args.layout_c.dtype == dtype::Float32() && k > std::max(m, n);
}
size_t MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::get_workspace_in_bytes(
const SizeArgs& args) const {
size_t lda = args.layout_a.stride[0], ldb = args.layout_b.stride[0],
ldc = args.layout_c.stride[0];
auto&& param = args.opr->param();
int m = args.layout_c.shape[0], n = args.layout_c.shape[1],
k = args.layout_a.shape[param.transposeA ? 0 : 1];
GemmCoord problem_size{m, n, k};
int split_k_slices = k / std::max(m, n);
return cutlass_matrix_mul_float32_simt_get_workspace_size(
param.transposeA, lda, param.transposeB, ldb, ldc, problem_size,
1.f, 0.f,
GemmCoord{m_algo_param.threadblock_m, m_algo_param.threadblock_n,
m_algo_param.threadblock_k},
GemmCoord{m_algo_param.warp_m, m_algo_param.warp_n,
m_algo_param.warp_k},
split_k_slices);
}
void MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::exec(
const ExecArgs& args) const {
size_t lda = args.tensor_a.layout.stride[0],
ldb = args.tensor_b.layout.stride[0],
ldc = args.tensor_c.layout.stride[0];
auto&& param = args.opr->param();
int m = args.tensor_c.layout.shape[0], n = args.tensor_c.layout.shape[1],
k = args.tensor_a.layout.shape[param.transposeA ? 0 : 1];
GemmCoord problem_size{m, n, k};
int split_k_slices = k / std::max(m, n);
auto&& stream = cuda_stream(args.opr->handle());
int* workspace = reinterpret_cast<int*>(args.workspace.raw_ptr);
return cutlass_matrix_mul_float32_simt(
args.tensor_a.ptr<dt_float32>(), param.transposeA, lda,
args.tensor_b.ptr<dt_float32>(), param.transposeB, ldb,
args.tensor_c.ptr<dt_float32>(), ldc, workspace, problem_size, 1.f,
0.f,
GemmCoord{m_algo_param.threadblock_m, m_algo_param.threadblock_n,
m_algo_param.threadblock_k},
GemmCoord{m_algo_param.warp_m, m_algo_param.warp_n,
m_algo_param.warp_k},
stream, split_k_slices);
}
// vim: syntax=cpp.doxygen
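A worked example of the availability rule and slice heuristic above (plain integer division, exactly as written):

// m = 32, n = 64, k = 1024
// is_available():  k > max(m, n)   ->  1024 > 64  ->  true
// split_k_slices = k / max(m, n)   =   1024 / 64  =   16
// so 16 groups of threadblocks each reduce a 64-deep slice of K, and the
// workspace sized above buffers their partial sums until the final
// reduction kernel combines them into C.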
......@@ -18,6 +18,7 @@
#if __CUDACC_VER_MAJOR__ > 9 || \
(__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/gemm/device/gemm_splitk_parallel.h"
#endif
#include "src/common/opr_param_defs_enumv.cuh"
#include "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cuh"
......@@ -62,14 +63,20 @@ void megdnn::cuda::cutlass_wrapper::cutlass_matrix_mul_float32_simt(
float* /* d_C */, size_t /* ldc */, int* /* workspace */,
GemmCoord const& /* problem_size */, float /* alpha */,
float /* beta */, const GemmCoord& /* threadblock_shape */,
const GemmCoord& /* warp_shape */, cudaStream_t /* stream */) {}
const GemmCoord& /* warp_shape */, cudaStream_t /* stream */,
int /* split_k_slices */) {}
#else
void megdnn::cuda::cutlass_wrapper::cutlass_matrix_mul_float32_simt(
const float* d_A, bool transpose_A, size_t lda, const float* d_B,
bool transpose_B, size_t ldb, float* d_C, size_t ldc, int* workspace,
GemmCoord const& problem_size, float alpha, float beta,
const GemmCoord& threadblock_shape, const GemmCoord& warp_shape,
cudaStream_t stream) {
cudaStream_t stream, int split_k_slices) {
static constexpr int kEpilogueElementsPerAccess = 1;
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
float, kEpilogueElementsPerAccess, float, float>;
typename EpilogueOp::Params epilogue{alpha, beta};
if (split_k_slices == 1) {
#define cb(threadblock_m_, threadblock_n_, threadblock_k_, warp_m_, warp_n_, \
warp_k_) \
if (threadblock_shape.m() == threadblock_m_ && \
......@@ -93,29 +100,67 @@ void megdnn::cuda::cutlass_wrapper::cutlass_matrix_mul_float32_simt(
workspace, problem_size, \
epilogue, stream); \
}
static constexpr int kEpilogueElementsPerAccess = 1;
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
float, kEpilogueElementsPerAccess, float, float>;
typename EpilogueOp::Params epilogue{alpha, beta};
if (!transpose_A && !transpose_B) {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::RowMajor;
DISPATCH(cb)
} else if (!transpose_A && transpose_B) {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
DISPATCH(cb)
} else if (transpose_A && !transpose_B) {
using LayoutA = cutlass::layout::ColumnMajor;
using LayoutB = cutlass::layout::RowMajor;
DISPATCH(cb)
if (!transpose_A && !transpose_B) {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::RowMajor;
DISPATCH(cb)
} else if (!transpose_A && transpose_B) {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
DISPATCH(cb)
} else if (transpose_A && !transpose_B) {
using LayoutA = cutlass::layout::ColumnMajor;
using LayoutB = cutlass::layout::RowMajor;
DISPATCH(cb)
} else {
megdnn_assert(transpose_A && transpose_B);
using LayoutA = cutlass::layout::ColumnMajor;
using LayoutB = cutlass::layout::ColumnMajor;
DISPATCH(cb)
}
#undef cb
} else {
megdnn_assert(transpose_A && transpose_B);
using LayoutA = cutlass::layout::ColumnMajor;
using LayoutB = cutlass::layout::ColumnMajor;
DISPATCH(cb)
#define cb(threadblock_m_, threadblock_n_, threadblock_k_, warp_m_, warp_n_, \
warp_k_) \
if (threadblock_shape.m() == threadblock_m_ && \
threadblock_shape.n() == threadblock_n_ && \
threadblock_shape.k() == threadblock_k_ && \
warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \
warp_shape.k() == warp_k_) { \
using ThreadBlockShape = \
cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \
threadblock_k_>; \
using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \
using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; \
using Gemm = cutlass::gemm::device::GemmSplitKParallel< \
float, LayoutA, float, LayoutB, float, \
cutlass::layout::RowMajor, float, cutlass::arch::OpClassSimt, \
cutlass::arch::Sm50, ThreadBlockShape, WarpShape, \
InstructionShape, EpilogueOp>; \
return cutlass_matrix_mul_wrapper<Gemm>( \
d_A, lda, d_B, ldb, d_C, ldc, workspace, problem_size, \
epilogue, stream, split_k_slices); \
}
if (!transpose_A && !transpose_B) {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::RowMajor;
DISPATCH(cb)
} else if (!transpose_A && transpose_B) {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
DISPATCH(cb)
} else if (transpose_A && !transpose_B) {
using LayoutA = cutlass::layout::ColumnMajor;
using LayoutB = cutlass::layout::RowMajor;
DISPATCH(cb)
} else {
megdnn_assert(transpose_A && transpose_B);
using LayoutA = cutlass::layout::ColumnMajor;
using LayoutB = cutlass::layout::ColumnMajor;
DISPATCH(cb)
}
#undef cb
}
}
#endif
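The DISPATCH macro used by both branches is defined earlier in this file and does not appear in the hunk; judging from how cb is written, it plausibly stamps the callback out once per registered tile shape and fails loudly when nothing matches. A hypothetical reconstruction:

#define DISPATCH(cb)                                                      \
    cb(64, 256, 8, 32, 64, 8);                                            \
    cb(256, 64, 8, 64, 32, 8);                                            \
    /* ... one cb(...) line per AlgoParam registered in algos.cpp ... */  \
    cb(16, 128, 8, 16, 64, 8);                                            \
    megdnn_throw("unsupported threadblock/warp shape for cutlass matmul");

Each cb(...) expands to an `if (...) { ... return ...; }`, so the first matching shape wins and the trailing error path only runs when no candidate matched.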
......@@ -127,7 +172,7 @@ size_t megdnn::cuda::cutlass_wrapper::
bool /* transpose_B */, size_t /* ldb */, size_t /* ldc */,
GemmCoord const& /* problem_size */, float /* alpha */,
float /* beta */, const GemmCoord& /* threadblock_shape */,
const GemmCoord& /* warp_shape */) {
const GemmCoord& /* warp_shape */, int /* split_k_slices */) {
return 0;
}
#else
......@@ -136,7 +181,12 @@ size_t megdnn::cuda::cutlass_wrapper::
bool transpose_A, size_t lda, bool transpose_B, size_t ldb,
size_t ldc, GemmCoord const& problem_size, float alpha,
float beta, const GemmCoord& threadblock_shape,
const GemmCoord& warp_shape) {
const GemmCoord& warp_shape, int split_k_slices) {
static constexpr int kEpilogueElementsPerAccess = 1;
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
float, kEpilogueElementsPerAccess, float, float>;
typename EpilogueOp::Params epilogue{alpha, beta};
if (split_k_slices == 1) {
#define cb(threadblock_m_, threadblock_n_, threadblock_k_, warp_m_, warp_n_, \
warp_k_) \
if (threadblock_shape.m() == threadblock_m_ && \
......@@ -169,30 +219,80 @@ size_t megdnn::cuda::cutlass_wrapper::
split_k_slices}; \
return Gemm::get_workspace_size(arguments); \
}
static constexpr int kEpilogueElementsPerAccess = 1;
static constexpr int split_k_slices = 1;
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
float, kEpilogueElementsPerAccess, float, float>;
typename EpilogueOp::Params epilogue{alpha, beta};
if (!transpose_A && !transpose_B) {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::RowMajor;
DISPATCH(cb)
} else if (!transpose_A && transpose_B) {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
DISPATCH(cb)
} else if (transpose_A && !transpose_B) {
using LayoutA = cutlass::layout::ColumnMajor;
using LayoutB = cutlass::layout::RowMajor;
DISPATCH(cb)
if (!transpose_A && !transpose_B) {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::RowMajor;
DISPATCH(cb)
} else if (!transpose_A && transpose_B) {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
DISPATCH(cb)
} else if (transpose_A && !transpose_B) {
using LayoutA = cutlass::layout::ColumnMajor;
using LayoutB = cutlass::layout::RowMajor;
DISPATCH(cb)
} else {
megdnn_assert(transpose_A && transpose_B);
using LayoutA = cutlass::layout::ColumnMajor;
using LayoutB = cutlass::layout::ColumnMajor;
DISPATCH(cb)
}
#undef cb
} else {
megdnn_assert(transpose_A && transpose_B);
using LayoutA = cutlass::layout::ColumnMajor;
using LayoutB = cutlass::layout::ColumnMajor;
DISPATCH(cb)
#define cb(threadblock_m_, threadblock_n_, threadblock_k_, warp_m_, warp_n_, \
warp_k_) \
if (threadblock_shape.m() == threadblock_m_ && \
threadblock_shape.n() == threadblock_n_ && \
threadblock_shape.k() == threadblock_k_ && \
warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \
warp_shape.k() == warp_k_) { \
using ThreadBlockShape = \
cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \
threadblock_k_>; \
using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \
using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; \
using Gemm = cutlass::gemm::device::GemmSplitKParallel< \
float, LayoutA, float, LayoutB, float, \
cutlass::layout::RowMajor, float, cutlass::arch::OpClassSimt, \
cutlass::arch::Sm50, ThreadBlockShape, WarpShape, \
InstructionShape, EpilogueOp>; \
using TensorRefA = cutlass::TensorRef<typename Gemm::ElementA const, \
typename Gemm::LayoutA>; \
using TensorRefB = cutlass::TensorRef<typename Gemm::ElementB const, \
typename Gemm::LayoutB>; \
using TensorRefC = cutlass::TensorRef<typename Gemm::ElementC const, \
typename Gemm::LayoutC>; \
using TensorRefD = cutlass::TensorRef<typename Gemm::ElementC, \
typename Gemm::LayoutC>; \
TensorRefA tensor_A{nullptr, Gemm::LayoutA{static_cast<int>(lda)}}; \
TensorRefB tensor_B{nullptr, Gemm::LayoutB{static_cast<int>(ldb)}}; \
TensorRefC tensor_C{nullptr, Gemm::LayoutC{static_cast<int>(ldc)}}; \
TensorRefD tensor_D{nullptr, Gemm::LayoutC{static_cast<int>(ldc)}}; \
typename Gemm::Arguments arguments{problem_size, tensor_A, tensor_B, \
tensor_C, tensor_D, epilogue, \
split_k_slices}; \
return Gemm::get_workspace_size(arguments); \
}
if (!transpose_A && !transpose_B) {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::RowMajor;
DISPATCH(cb)
} else if (!transpose_A && transpose_B) {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
DISPATCH(cb)
} else if (transpose_A && !transpose_B) {
using LayoutA = cutlass::layout::ColumnMajor;
using LayoutB = cutlass::layout::RowMajor;
DISPATCH(cb)
} else {
megdnn_assert(transpose_A && transpose_B);
using LayoutA = cutlass::layout::ColumnMajor;
using LayoutB = cutlass::layout::ColumnMajor;
DISPATCH(cb)
}
#undef cb
}
}
#endif
......
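On the size being computed: for GemmSplitKParallel the workspace buffers the per-slice partial accumulators (an assumption about CUTLASS 2.x internals, not something the patch spells out), so it grows roughly as

// workspace_bytes ≈ sizeof(ElementAccumulator) * m * n * split_k_slices

which is also why the nullptr-based TensorRefs in the macro above are safe: only the problem size, the layouts, and split_k_slices feed into Gemm::get_workspace_size().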
......@@ -26,19 +26,19 @@ void cutlass_matrix_mul_wrapper(
typename Gemm::ElementC* d_C, size_t ldc, int* workspace,
GemmCoord const& problem_size,
typename Gemm::EpilogueOutputOp::Params const& epilogue,
cudaStream_t stream);
cudaStream_t stream, int split_k_slices = 1);
void cutlass_matrix_mul_float32_simt(
const float* d_A, bool transpose_A, size_t lda, const float* d_B,
bool transpose_B, size_t ldb, float* d_C, size_t ldc, int* workspace,
GemmCoord const& problem_size, float alpha, float beta,
const GemmCoord& threadblock_shape, const GemmCoord& warp_shape,
cudaStream_t stream);
cudaStream_t stream, int split_k_slices = 1);
size_t cutlass_matrix_mul_float32_simt_get_workspace_size(
bool transpose_A, size_t lda, bool transpose_B, size_t ldb, size_t ldc,
GemmCoord const& problem_size, float alpha, float beta,
const GemmCoord& threadblock_shape, const GemmCoord& warp_shape);
const GemmCoord& threadblock_shape, const GemmCoord& warp_shape,
int split_k_slices = 1);
} // namespace cutlass_wrapper
} // namespace cuda
......
......@@ -11,6 +11,7 @@
* implied.
*/
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/gemm/device/gemm_splitk_parallel.h"
#include "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cuh"
using namespace megdnn;
......@@ -24,17 +25,21 @@ void megdnn::cuda::cutlass_wrapper::cutlass_matrix_mul_wrapper(
typename Gemm::ElementC* d_C, size_t ldc, int* workspace,
GemmCoord const& problem_size,
typename Gemm::EpilogueOutputOp::Params const& epilogue,
cudaStream_t stream) {
typename Gemm::TensorRefA tensor_a{
const_cast<typename Gemm::ElementA*>(d_A),
typename Gemm::LayoutA{static_cast<int>(lda)}};
typename Gemm::TensorRefB tensor_b{
const_cast<typename Gemm::ElementB*>(d_B),
typename Gemm::LayoutB{static_cast<int>(ldb)}};
typename Gemm::TensorRefC tensor_c{
nullptr, typename Gemm::LayoutC{static_cast<int>(ldc)}};
typename Gemm::TensorRefD tensor_d{
d_C, typename Gemm::LayoutC{static_cast<int>(ldc)}};
cudaStream_t stream, int split_k_slices) {
using TensorRefA = cutlass::TensorRef<typename Gemm::ElementA const,
typename Gemm::LayoutA>;
using TensorRefB = cutlass::TensorRef<typename Gemm::ElementB const,
typename Gemm::LayoutB>;
using TensorRefC = cutlass::TensorRef<typename Gemm::ElementC const,
typename Gemm::LayoutC>;
using TensorRefD =
cutlass::TensorRef<typename Gemm::ElementC, typename Gemm::LayoutC>;
TensorRefA tensor_a{const_cast<typename Gemm::ElementA*>(d_A),
typename Gemm::LayoutA{static_cast<int>(lda)}};
TensorRefB tensor_b{const_cast<typename Gemm::ElementB*>(d_B),
typename Gemm::LayoutB{static_cast<int>(ldb)}};
TensorRefC tensor_c{nullptr, typename Gemm::LayoutC{static_cast<int>(ldc)}};
TensorRefD tensor_d{d_C, typename Gemm::LayoutC{static_cast<int>(ldc)}};
typename Gemm::Arguments arguments{problem_size,
tensor_a,
......@@ -42,7 +47,7 @@ void megdnn::cuda::cutlass_wrapper::cutlass_matrix_mul_wrapper(
tensor_c,
tensor_d.non_const_ref(),
epilogue,
1};
split_k_slices};
Gemm gemm_op;
cutlass_check(gemm_op.initialize(arguments, workspace));
cutlass_check(gemm_op(stream));
......
......@@ -42,6 +42,7 @@ public:
class AlgoBFloat16;
#endif
class AlgoFloat32SIMT;
class AlgoFloat32SIMTSplitK;
class AlgoPack;
static const AlgoPack& algo_pack() {
......
......@@ -117,6 +117,18 @@ std::vector<matrix_mul::TestArg> matrix_mul::get_matmul_args() {
return args;
}
std::vector<matrix_mul::TestArg> matrix_mul::get_matmul_args_split_k() {
std::vector<TestArg> args = get_matmul_args();
for (auto iter = args.begin(); iter < args.end();) {
if (iter->k <= iter->n) {
iter = args.erase(iter);
} else {
iter++;
}
}
return args;
}
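// An equivalent erase-remove formulation of the filter above (a stylistic
// alternative to the explicit loop, requiring <algorithm>; it keeps exactly
// the k > n cases, i.e. the K-dominant shapes split-k targets):
//
//     args.erase(std::remove_if(args.begin(), args.end(),
//                               [](const TestArg& arg) {
//                                   return arg.k <= arg.n;
//                               }),
//                args.end());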
std::vector<matrix_mul::TestArg> matrix_mul::get_batched_matmul_args_mask(
uint8_t mask) {
std::vector<TestArg> args;
......
......@@ -53,6 +53,7 @@ struct TestArg {
std::vector<TestArg> get_matmul_args_no_mask();
std::vector<TestArg> get_matmul_args_mask(uint8_t mask);
std::vector<TestArg> get_matmul_args();
std::vector<TestArg> get_matmul_args_split_k();
std::vector<TestArg> get_batched_matmul_args_mask(uint8_t mask);
std::vector<TestArg> get_batched_matmul_args();
std::vector<TestArg> get_batched_matmul_broadcast_args();
......
......@@ -21,7 +21,6 @@
#include "test/cuda/fixture.h"
#include "test/cuda/utils.h"
#if CUDA_VERSION >= 9020
namespace megdnn {
namespace test {
......@@ -284,6 +283,15 @@ TEST_F(CUDA, CUTLASS_GEMM_MULTI_BATCHSIZE) {
param::MatrixMul::Format::DEFAULT);
}
TEST_F(CUDA, CUTLASS_GEMM_SPLIT_K_MULTI_BATCHSIZE) {
auto args = matrix_mul::get_matmul_args_no_mask();
test_multibatchsize(
handle_cuda(), dtype::Float32(), dtype::Float32(), dtype::Float32(),
"CUTLASS_FLOAT32_SIMT_SPLIT_K_128X128X8_32X64X8", args,
param::MatrixMul::Format::DEFAULT,
[](const matrix_mul::TestArg& arg) { return arg.k <= arg.n; });
}
#define MEGDNN_FOREACH_CUTLASS_KERNEL(cb) \
cb(1, 64, 256, 8, 32, 64, 8); \
cb(2, 256, 64, 8, 64, 32, 8); \
......@@ -314,6 +322,21 @@ TEST_F(CUDA, CUTLASS_GEMM_MULTI_BATCHSIZE) {
MEGDNN_FOREACH_CUTLASS_KERNEL(cb)
#undef cb
#define cb(name, tbm, tbn, tbk, wm, wn, wk) \
TEST_F(CUDA, CUTLASS_GEMM_SPLIT_K_##name) { \
matrix_mul::check_matrix_mul<MatrixMulForward>( \
dtype::Float32(), dtype::Float32(), dtype::Float32(), \
handle_cuda(), \
"CUTLASS_FLOAT32_SIMT_SPLIT_K_" #tbm "X" #tbn "X" #tbk "_" #wm \
"X" #wn "X" #wk, \
param::MatrixMul::Format::DEFAULT, 8, 1e-3, \
matrix_mul::get_matmul_args_split_k()); \
}
MEGDNN_FOREACH_CUTLASS_KERNEL(cb)
#undef cb
#undef MEGDNN_FOREACH_CUTLASS_KERNEL
......
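For concreteness, the first MEGDNN_FOREACH_CUTLASS_KERNEL entry, cb(1, 64, 256, 8, 32, 64, 8), expands the split-k test macro above by hand to:

TEST_F(CUDA, CUTLASS_GEMM_SPLIT_K_1) {
    matrix_mul::check_matrix_mul<MatrixMulForward>(
            dtype::Float32(), dtype::Float32(), dtype::Float32(),
            handle_cuda(), "CUTLASS_FLOAT32_SIMT_SPLIT_K_64X256X8_32X64X8",
            param::MatrixMul::Format::DEFAULT, 8, 1e-3,
            matrix_mul::get_matmul_args_split_k());
}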