提交 03c921f7 编写于 作者: M Megvii Engine Team

feat(dnn/cuda): add cutlass matmul impls

GitOrigin-RevId: 619c8c299ca5a4b43806861f42756eae7120c9ff
上级 886e7c6e
......@@ -3,6 +3,7 @@
dnn/src/cuda/conv_bias/int8/kimpl/* binary
dnn/src/cuda/conv_bias/int8_imma/kimpl/* binary
dnn/src/cuda/batch_conv_bias/int8/kimpl/* binary
dnn/src/cuda/matrix_mul/fp32_simt/kimpl/* binary
dnn/src/cuda/sass/prebuilt/map_defs.cpp binary
tools/mlir/mlir-tblgen filter=lfs diff=lfs merge=lfs -text
*.caffemodel filter=lfs diff=lfs merge=lfs -text
......
......@@ -8,9 +8,10 @@ ELEMWISE_IMPL := ../src/cuda/cond_take/kimpl \
../src/naive/elemwise/kimpl \
../src/cuda/elemwise_multi_type/kimpl
CUDA_CONV_IMPL := ../src/cuda/conv_bias/int8/kimpl ../src/cuda/conv_bias/int8_imma/kimpl ../src/cuda/batch_conv_bias/int8/kimpl
CUDA_CONV_IMPL := ../src/cuda/conv_bias/int8/kimpl ../src/cuda/conv_bias/int8_imma/kimpl ../src/cuda/batch_conv_bias/int8/kimpl
CUDA_MATMUL_KIMPL := ../src/cuda/matrix_mul/fp32_simt/kimpl
all: ${PARAM_DEFS} ${ELEMWISE_IMPL} ${CUDA_CONV_IMPL}
all: ${PARAM_DEFS} ${ELEMWISE_IMPL} ${CUDA_CONV_IMPL} $(CUDA_MATMUL_KIMPL)
../src/common/elemwise/each_mode.inl: gen_elemwise_each_mode.py
./$^ $@
......@@ -47,4 +48,7 @@ all: ${PARAM_DEFS} ${ELEMWISE_IMPL} ${CUDA_CONV_IMPL}
../src/cuda/batch_conv_bias/int8/kimpl: gen_cuda_batch_conv_bias_kern_impls.py
./$^ --type dp4a $@
../src/cuda/matrix_mul/fp32_simt/kimpl: gen_cutlass_matmul_kern_impls.py
./$^ $@
.PHONY: all
......@@ -33,12 +33,37 @@ MatrixMulForwardImpl::AlgoPack::AlgoPack() {
#if !MEGDNN_DISABLE_FLOAT16
all_algos.push_back(&bfloat16);
#endif
fill_cutlass_algos();
for (auto&& algo : simt_float32) {
all_algos.push_back(&algo);
}
for (auto&& algo : all_algos) {
m_all_algos_map.emplace(algo->info().desc, algo);
}
}
void MatrixMulForwardImpl::AlgoPack::fill_cutlass_algos() {
using AlgoParam = AlgoFloat32SIMT::AlgoParam;
simt_float32.emplace_back(AlgoParam{64, 256, 8, 32, 64, 8});
simt_float32.emplace_back(AlgoParam{256, 64, 8, 64, 32, 8});
simt_float32.emplace_back(AlgoParam{32, 256, 8, 16, 64, 8});
simt_float32.emplace_back(AlgoParam{256, 32, 8, 64, 16, 8});
simt_float32.emplace_back(AlgoParam{128, 128, 8, 32, 64, 8});
simt_float32.emplace_back(AlgoParam{128, 64, 8, 64, 32, 8});
simt_float32.emplace_back(AlgoParam{64, 128, 8, 32, 64, 8});
simt_float32.emplace_back(AlgoParam{128, 32, 8, 64, 32, 8});
simt_float32.emplace_back(AlgoParam{32, 128, 8, 32, 64, 8});
simt_float32.emplace_back(AlgoParam{64, 64, 8, 32, 64, 8});
simt_float32.emplace_back(AlgoParam{32, 64, 8, 32, 64, 8});
simt_float32.emplace_back(AlgoParam{64, 32, 8, 64, 32, 8});
simt_float32.emplace_back(AlgoParam{32, 32, 8, 32, 32, 8});
simt_float32.emplace_back(AlgoParam{8, 32, 8, 8, 32, 8});
simt_float32.emplace_back(AlgoParam{16, 32, 8, 16, 32, 8});
simt_float32.emplace_back(AlgoParam{16, 64, 8, 16, 64, 8});
simt_float32.emplace_back(AlgoParam{16, 128, 8, 16, 64, 8});
}
MatrixMulForwardImpl::AlgoPack MatrixMulForwardImpl::sm_algo_pack;
MEGDNN_DEF_GET_ALGO_FROM_DESC(MatrixMulForwardImpl)
......
......@@ -41,7 +41,8 @@ public:
CUDA_WMMA_UINT4X4X32,
CUDA_CUBLASLT,
CUDA_NAIVE,
CUDA_BFLOAT16
CUDA_BFLOAT16,
CUDA_FLOAT32_SIMT,
};
using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>;
......@@ -165,6 +166,38 @@ private:
};
#endif
class MatrixMulForwardImpl::AlgoFloat32SIMT final : public AlgoBase {
public:
struct AlgoParam {
int threadblock_m, threadblock_n, threadblock_k;
int warp_m, warp_n, warp_k;
std::string to_string() {
return ssprintf("%dX%dX%d_%dX%dX%d", threadblock_m, threadblock_n,
threadblock_k, warp_m, warp_n, warp_k);
}
};
AlgoFloat32SIMT(AlgoParam algo_param)
: m_algo_param{algo_param},
m_name{ssprintf("CUTLASS_FLOAT32_SIMT_%s",
m_algo_param.to_string().c_str())} {}
bool is_available(const SizeArgs& args) const override;
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
const char* name() const override { return m_name.c_str(); }
void exec(const ExecArgs& args) const override;
bool is_reproducible() const override { return true; }
MEGDNN_DECL_ALGO_TYPE(CUDA_FLOAT32_SIMT)
std::string param() const override {
std::string ret;
serialize_write_pod(m_algo_param, ret);
return ret;
}
private:
AlgoParam m_algo_param;
std::string m_name;
};
class MatrixMulForwardImpl::AlgoPack : NonCopyableObj {
private:
AlgoBase::Mapper m_all_algos_map;
......@@ -182,9 +215,11 @@ public:
#if !MEGDNN_DISABLE_FLOAT16
AlgoBFloat16 bfloat16;
#endif
std::vector<AlgoFloat32SIMT> simt_float32;
std::vector<AlgoBase*> all_algos;
const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; }
void fill_cutlass_algos();
};
} // namespace cuda
......
/**
* \file dnn/src/cuda/matrix_mul/cutlass_float32_simt.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/cuda/handle.h"
#include "src/cuda/matrix_mul/algos.h"
#include "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cuh"
#include "src/cuda/utils.h"
using namespace megdnn;
using namespace cuda;
using namespace cutlass_wrapper;
bool MatrixMulForwardImpl::AlgoFloat32SIMT::is_available(
const SizeArgs& args) const {
#if CUDA_VERSION >= 9200
return args.opr->param().format == param::MatrixMul::Format::DEFAULT &&
args.layout_a.dtype == dtype::Float32() &&
args.layout_b.dtype == dtype::Float32() &&
args.layout_c.dtype == dtype::Float32();
#else
return false;
#endif
}
size_t MatrixMulForwardImpl::AlgoFloat32SIMT::get_workspace_in_bytes(
const SizeArgs& args) const {
size_t lda = args.layout_a.stride[0], ldb = args.layout_b.stride[0],
ldc = args.layout_c.stride[0];
auto&& param = args.opr->param();
int m = args.layout_c.shape[0], n = args.layout_c.shape[1],
k = args.layout_a.shape[param.transposeA ? 0 : 1];
GemmCoord problem_size{m, n, k};
return cutlass_matrix_mul_float32_simt_get_workspace_size(
param.transposeA, lda, param.transposeB, ldb, ldc, problem_size,
1.f, 0.f,
GemmCoord{m_algo_param.threadblock_m, m_algo_param.threadblock_n,
m_algo_param.threadblock_k},
GemmCoord{m_algo_param.warp_m, m_algo_param.warp_n,
m_algo_param.warp_k});
}
void MatrixMulForwardImpl::AlgoFloat32SIMT::exec(const ExecArgs& args) const {
size_t lda = args.tensor_a.layout.stride[0],
ldb = args.tensor_b.layout.stride[0],
ldc = args.tensor_c.layout.stride[0];
auto&& param = args.opr->param();
int m = args.tensor_c.layout.shape[0], n = args.tensor_c.layout.shape[1],
k = args.tensor_a.layout.shape[param.transposeA ? 0 : 1];
GemmCoord problem_size{m, n, k};
auto&& stream = cuda_stream(args.opr->handle());
int* workspace = reinterpret_cast<int*>(args.workspace.raw_ptr);
return cutlass_matrix_mul_float32_simt(
args.tensor_a.ptr<dt_float32>(), param.transposeA, lda,
args.tensor_b.ptr<dt_float32>(), param.transposeB, ldb,
args.tensor_c.ptr<dt_float32>(), ldc, workspace, problem_size, 1.f,
0.f,
GemmCoord{m_algo_param.threadblock_m, m_algo_param.threadblock_n,
m_algo_param.threadblock_k},
GemmCoord{m_algo_param.warp_m, m_algo_param.warp_n,
m_algo_param.warp_k},
stream);
}
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "cuda.h"
#if __CUDACC_VER_MAJOR__ > 9 || \
(__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
#include "cutlass/gemm/device/gemm.h"
#endif
#include "src/common/opr_param_defs_enumv.cuh"
#include "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cuh"
#pragma GCC diagnostic pop
using namespace megdnn;
using namespace cuda;
using namespace cutlass_wrapper;
/* ================= cutlass kernel wrapper for f32 matrix mul ================
*/
#define DISPATCH(cb) \
cb(64, 256, 8, 32, 64, 8); \
cb(256, 64, 8, 64, 32, 8); \
cb(32, 256, 8, 16, 64, 8); \
cb(256, 32, 8, 64, 16, 8); \
cb(128, 128, 8, 32, 64, 8); \
cb(128, 64, 8, 64, 32, 8); \
cb(64, 128, 8, 32, 64, 8); \
cb(128, 32, 8, 64, 32, 8); \
cb(32, 128, 8, 32, 64, 8); \
cb(64, 64, 8, 32, 64, 8); \
cb(32, 64, 8, 32, 64, 8); \
cb(64, 32, 8, 64, 32, 8); \
cb(32, 32, 8, 32, 32, 8); \
cb(8, 32, 8, 8, 32, 8); \
cb(16, 32, 8, 16, 32, 8); \
cb(16, 64, 8, 16, 64, 8); \
cb(16, 128, 8, 16, 64, 8); \
megdnn_assert(false, \
"unsupported threadblock shape (%dx%dx%d) and warp shape " \
"(%dx%dx%d)", \
threadblock_shape.m(), threadblock_shape.n(), \
threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \
warp_shape.k());
#if __CUDACC_VER_MAJOR__ < 9 || \
(__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ <= 2)
void megdnn::cuda::cutlass_wrapper::cutlass_matrix_mul_float32_simt(
const float* /* d_A */, bool /* transpose_A */, size_t /* lda */,
const float* /* d_B */, bool /* transpose_B */, size_t /* ldb */,
float* /* d_C */, size_t /* ldc */, int* /* workspace */,
GemmCoord const& /* problem_size */, float /* alpha */,
float /* beta */, const GemmCoord& /* threadblock_shape */,
const GemmCoord& /* warp_shape */, cudaStream_t /* stream */) {}
#else
void megdnn::cuda::cutlass_wrapper::cutlass_matrix_mul_float32_simt(
const float* d_A, bool transpose_A, size_t lda, const float* d_B,
bool transpose_B, size_t ldb, float* d_C, size_t ldc, int* workspace,
GemmCoord const& problem_size, float alpha, float beta,
const GemmCoord& threadblock_shape, const GemmCoord& warp_shape,
cudaStream_t stream) {
#define cb(threadblock_m_, threadblock_n_, threadblock_k_, warp_m_, warp_n_, \
warp_k_) \
if (threadblock_shape.m() == threadblock_m_ && \
threadblock_shape.n() == threadblock_n_ && \
threadblock_shape.k() == threadblock_k_ && \
warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \
warp_shape.k() == warp_k_) { \
using ThreadBlockShape = \
cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \
threadblock_k_>; \
using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \
using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; \
using Gemm = cutlass::gemm::device::Gemm< \
float, LayoutA, float, LayoutB, float, \
cutlass::layout::RowMajor, float, cutlass::arch::OpClassSimt, \
cutlass::arch::Sm50, ThreadBlockShape, WarpShape, \
InstructionShape, EpilogueOp, \
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, \
2>; \
return cutlass_matrix_mul_wrapper<Gemm>(d_A, lda, d_B, ldb, d_C, ldc, \
workspace, problem_size, \
epilogue, stream); \
}
static constexpr int kEpilogueElementsPerAccess = 1;
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
float, kEpilogueElementsPerAccess, float, float>;
typename EpilogueOp::Params epilogue{alpha, beta};
if (!transpose_A && !transpose_B) {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::RowMajor;
DISPATCH(cb)
} else if (!transpose_A && transpose_B) {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
DISPATCH(cb)
} else if (transpose_A && !transpose_B) {
using LayoutA = cutlass::layout::ColumnMajor;
using LayoutB = cutlass::layout::RowMajor;
DISPATCH(cb)
} else {
megdnn_assert(transpose_A && transpose_B);
using LayoutA = cutlass::layout::ColumnMajor;
using LayoutB = cutlass::layout::ColumnMajor;
DISPATCH(cb)
}
#undef cb
}
#endif
#if __CUDACC_VER_MAJOR__ < 9 || \
(__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ <= 2)
size_t megdnn::cuda::cutlass_wrapper::
cutlass_matrix_mul_float32_simt_get_workspace_size(
bool /* transpose_A */, size_t /* lda */,
bool /* transpose_B */, size_t /* ldb */, size_t /* ldc */,
GemmCoord const& /* problem_size */, float /* alpha */,
float /* beta */, const GemmCoord& /* threadblock_shape */,
const GemmCoord& /* warp_shape */) {
return 0;
}
#else
size_t megdnn::cuda::cutlass_wrapper::
cutlass_matrix_mul_float32_simt_get_workspace_size(
bool transpose_A, size_t lda, bool transpose_B, size_t ldb,
size_t ldc, GemmCoord const& problem_size, float alpha,
float beta, const GemmCoord& threadblock_shape,
const GemmCoord& warp_shape) {
#define cb(threadblock_m_, threadblock_n_, threadblock_k_, warp_m_, warp_n_, \
warp_k_) \
if (threadblock_shape.m() == threadblock_m_ && \
threadblock_shape.n() == threadblock_n_ && \
threadblock_shape.k() == threadblock_k_ && \
warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \
warp_shape.k() == warp_k_) { \
using ThreadBlockShape = \
cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \
threadblock_k_>; \
using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \
using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; \
using Gemm = cutlass::gemm::device::Gemm< \
float, LayoutA, float, LayoutB, float, \
cutlass::layout::RowMajor, float, cutlass::arch::OpClassSimt, \
cutlass::arch::Sm50, ThreadBlockShape, WarpShape, \
InstructionShape, EpilogueOp, \
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, \
2>; \
typename Gemm::TensorRefA tensor_A{ \
nullptr, Gemm::LayoutA{static_cast<int>(lda)}}; \
typename Gemm::TensorRefB tensor_B{ \
nullptr, Gemm::LayoutB{static_cast<int>(ldb)}}; \
typename Gemm::TensorRefC tensor_C{ \
nullptr, Gemm::LayoutC{static_cast<int>(ldc)}}; \
typename Gemm::TensorRefD tensor_D{ \
nullptr, Gemm::LayoutC{static_cast<int>(ldc)}}; \
typename Gemm::Arguments arguments{problem_size, tensor_A, tensor_B, \
tensor_C, tensor_D, epilogue, \
split_k_slices}; \
return Gemm::get_workspace_size(arguments); \
}
static constexpr int kEpilogueElementsPerAccess = 1;
static constexpr int split_k_slices = 1;
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
float, kEpilogueElementsPerAccess, float, float>;
typename EpilogueOp::Params epilogue{alpha, beta};
if (!transpose_A && !transpose_B) {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::RowMajor;
DISPATCH(cb)
} else if (!transpose_A && transpose_B) {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
DISPATCH(cb)
} else if (transpose_A && !transpose_B) {
using LayoutA = cutlass::layout::ColumnMajor;
using LayoutB = cutlass::layout::RowMajor;
DISPATCH(cb)
} else {
megdnn_assert(transpose_A && transpose_B);
using LayoutA = cutlass::layout::ColumnMajor;
using LayoutB = cutlass::layout::ColumnMajor;
DISPATCH(cb)
}
#undef cb
}
#endif
#undef DISPATCH
// vim: syntax=cuda.doxygen
/**
* \file dnn/src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cuh
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "cutlass/gemm/gemm.h"
#include "src/cuda/utils.cuh"
namespace megdnn {
namespace cuda {
namespace cutlass_wrapper {
using GemmCoord = cutlass::gemm::GemmCoord;
template <typename Gemm>
void cutlass_matrix_mul_wrapper(
const typename Gemm::ElementA* d_A, size_t lda,
const typename Gemm::ElementB* d_B, size_t ldb,
typename Gemm::ElementC* d_C, size_t ldc, int* workspace,
GemmCoord const& problem_size,
typename Gemm::EpilogueOutputOp::Params const& epilogue,
cudaStream_t stream);
void cutlass_matrix_mul_float32_simt(
const float* d_A, bool transpose_A, size_t lda, const float* d_B,
bool transpose_B, size_t ldb, float* d_C, size_t ldc, int* workspace,
GemmCoord const& problem_size, float alpha, float beta,
const GemmCoord& threadblock_shape, const GemmCoord& warp_shape,
cudaStream_t stream);
size_t cutlass_matrix_mul_float32_simt_get_workspace_size(
bool transpose_A, size_t lda, bool transpose_B, size_t ldb, size_t ldc,
GemmCoord const& problem_size, float alpha, float beta,
const GemmCoord& threadblock_shape, const GemmCoord& warp_shape);
} // namespace cutlass_wrapper
} // namespace cuda
} // namespace megdnn
// vim: syntax=cuda.doxygen
/**
* \file
* dnn/src/cuda/matrix_mul/matrix_mul_float_simt_cutlass_wrapper.cuinl
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "cutlass/gemm/device/gemm.h"
#include "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cuh"
using namespace megdnn;
using namespace cuda;
using namespace cutlass_wrapper;
template <typename Gemm>
void megdnn::cuda::cutlass_wrapper::cutlass_matrix_mul_wrapper(
const typename Gemm::ElementA* d_A, size_t lda,
const typename Gemm::ElementB* d_B, size_t ldb,
typename Gemm::ElementC* d_C, size_t ldc, int* workspace,
GemmCoord const& problem_size,
typename Gemm::EpilogueOutputOp::Params const& epilogue,
cudaStream_t stream) {
typename Gemm::TensorRefA tensor_a{
const_cast<typename Gemm::ElementA*>(d_A),
typename Gemm::LayoutA{static_cast<int>(lda)}};
typename Gemm::TensorRefB tensor_b{
const_cast<typename Gemm::ElementB*>(d_B),
typename Gemm::LayoutB{static_cast<int>(ldb)}};
typename Gemm::TensorRefC tensor_c{
nullptr, typename Gemm::LayoutC{static_cast<int>(ldc)}};
typename Gemm::TensorRefD tensor_d{
d_C, typename Gemm::LayoutC{static_cast<int>(ldc)}};
typename Gemm::Arguments arguments{problem_size,
tensor_a,
tensor_b,
tensor_c,
tensor_d.non_const_ref(),
epilogue,
1};
Gemm gemm_op;
cutlass_check(gemm_op.initialize(arguments, workspace));
cutlass_check(gemm_op(stream));
after_kernel_launch();
}
// vim: syntax=cuda.doxygen
......@@ -41,6 +41,7 @@ public:
#if !MEGDNN_DISABLE_FLOAT16
class AlgoBFloat16;
#endif
class AlgoFloat32SIMT;
class AlgoPack;
static const AlgoPack& algo_pack() {
......
/**
* \file dnn/test/cuda/cutlass_matmul.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include <cuda.h>
#include "megdnn/oprs/linalg.h"
#include "src/common/utils.h"
#include "test/common/checker.h"
#include "test/common/matrix_mul.h"
#include "test/common/tensor.h"
#include "test/common/workspace_wrapper.h"
#include "test/cuda/benchmark.h"
#include "test/cuda/fixture.h"
#include "test/cuda/utils.h"
#if CUDA_VERSION >= 9020
namespace megdnn {
namespace test {
namespace {
void test_multibatchsize(
Handle* handle_cuda, DType A_dtype, DType B_dtype, DType C_dtype,
const char* algo, const std::vector<matrix_mul::TestArg>& args,
param::MatrixMul::Format format = param::MatrixMul::Format::DEFAULT,
const std::function<bool(const matrix_mul::TestArg&)>& filter = {}) {
Checker<MatrixMulForward> checker(handle_cuda, false);
if (algo) {
checker.set_before_exec_callback(AlgoChecker<MatrixMulForward>(algo));
}
std::unique_ptr<RNG> rng;
if (A_dtype.enumv() == DTypeEnum::Float32) {
rng = std::make_unique<UniformFloatRNG>(-1, 1);
megdnn_assert(B_dtype.enumv() == DTypeEnum::Float32 &&
C_dtype.enumv() == DTypeEnum::Float32);
}
megdnn_assert(rng != nullptr);
struct Compare {
bool is_same(dt_float32 expected, dt_float32 actual) const {
return expected == actual;
}
};
// copy rhs->lhs, lhs is 8 times of rhs
auto copy = [](SyncedTensor<dt_float32, Compare>& lhs,
SyncedTensor<dt_float32, Compare>& rhs) {
size_t chunk = rhs.layout().span().dist_byte();
size_t tot = lhs.layout().span().dist_byte();
megdnn_assert(tot % chunk == 0);
char* pointer_lhs = reinterpret_cast<char*>(lhs.ptr_mutable_host());
const char* pointer_rhs = reinterpret_cast<const char*>(rhs.ptr_host());
for (size_t i = 0; i < tot; i += chunk) {
std::memcpy(pointer_lhs + i, pointer_rhs, chunk);
}
};
using Param = param::MatrixMul;
megdnn_assert(format == Param::Format::DEFAULT);
for (auto&& arg : args) {
megdnn_assert(arg.mask == 0x0);
// make m, n, k big enough
size_t m = arg.m, n = (arg.n << 3), k = (arg.k << 3);
size_t m_prime = (m << 3);
if (filter && filter(arg))
continue;
TensorShape A{m, k}, B{k, n}, C{m, n};
TensorShape A_prime{m_prime, k}, C_prime{m_prime, n};
SyncedTensor<dt_float32, Compare> A_tensor{handle_cuda, {A, A_dtype}},
B_tensor{handle_cuda, {B, B_dtype}},
C_tensor{handle_cuda, {C, C_dtype}},
A_tensor_prime{handle_cuda, {A_prime, A_dtype}},
C_tensor_prime{handle_cuda, {C_prime, C_dtype}},
C_tensor_batch{handle_cuda, {C_prime, C_dtype}};
rng->gen(A_tensor.tensornd_host());
rng->gen(B_tensor.tensornd_host());
copy(A_tensor_prime, A_tensor);
auto opr_reference = handle_cuda->create_operator<MatrixMulForward>();
{
opr_reference->execution_policy().algo.reset();
for (auto i : opr_reference->get_all_algorithms_info(
A_tensor.layout(), B_tensor.layout(),
C_tensor.layout())) {
if (std::regex_match(
i.name.c_str(),
std::regex("(" + std::string(algo) + ")(.*)"))) {
opr_reference->execution_policy().algo = i;
break;
}
}
megdnn_assert(opr_reference->execution_policy().algo.valid());
size_t ws_size = opr_reference->get_workspace_in_bytes(
A_tensor.layout(), B_tensor.layout(), C_tensor.layout());
WorkspaceWrapper ws_reference(handle_cuda, ws_size);
opr_reference->exec(
A_tensor.tensornd_dev(), B_tensor.tensornd_dev(),
C_tensor.tensornd_dev(), ws_reference.workspace());
}
copy(C_tensor_prime, C_tensor);
checker.set_dtype(0, A_dtype)
.set_dtype(1, B_dtype)
.set_dtype(2, C_dtype)
.set_epsilon(1e-6)
.exect({A_tensor_prime.tensornd_host(),
B_tensor.tensornd_host(),
{}},
{{}, {}, C_tensor_prime.tensornd_host()});
{
opr_reference->execution_policy().algo.reset();
for (auto i : opr_reference->get_all_algorithms_info(
A_tensor_prime.layout(), B_tensor.layout(),
C_tensor_batch.layout())) {
if (std::regex_match(
i.name.c_str(),
std::regex("(" + std::string(algo) + ")(.*)"))) {
opr_reference->execution_policy().algo = i;
break;
}
}
megdnn_assert(opr_reference->execution_policy().algo.valid());
size_t ws_size = opr_reference->get_workspace_in_bytes(
A_tensor_prime.layout(), B_tensor.layout(),
C_tensor_batch.layout());
WorkspaceWrapper ws_reference(handle_cuda, ws_size);
opr_reference->exec(
A_tensor_prime.tensornd_dev(), B_tensor.tensornd_dev(),
C_tensor_batch.tensornd_dev(), ws_reference.workspace());
}
C_tensor_batch.check_with(C_tensor_prime);
}
}
#if MEGDNN_WITH_BENCHMARK
struct BenchArgs {
size_t m, n, k, mask = 0x0;
};
std::vector<BenchArgs> get_square_matmul_args() {
std::vector<BenchArgs> args;
args.emplace_back(BenchArgs{128, 128, 128});
args.emplace_back(BenchArgs{256, 256, 256});
args.emplace_back(BenchArgs{512, 512, 512});
args.emplace_back(BenchArgs{1024, 1024, 1024});
args.emplace_back(BenchArgs{2048, 2048, 2048});
args.emplace_back(BenchArgs{4096, 4096, 4096});
return args;
}
std::vector<BenchArgs> get_feat_model_args() {
std::vector<BenchArgs> args;
args.emplace_back(BenchArgs{2, 4096, 4096});
args.emplace_back(BenchArgs{2, 1024, 6912});
args.emplace_back(BenchArgs{2, 3456, 3456});
args.emplace_back(BenchArgs{2, 2304, 2304});
args.emplace_back(BenchArgs{1, 256, 8192});
args.emplace_back(BenchArgs{2, 864, 864});
args.emplace_back(BenchArgs{2, 9, 64});
args.emplace_back(BenchArgs{4, 4096, 4096});
args.emplace_back(BenchArgs{4, 1024, 6912});
args.emplace_back(BenchArgs{4, 3456, 3456});
args.emplace_back(BenchArgs{4, 2304, 2304});
args.emplace_back(BenchArgs{2, 256, 8192});
args.emplace_back(BenchArgs{4, 864, 864});
args.emplace_back(BenchArgs{4, 9, 64});
args.emplace_back(BenchArgs{8, 4096, 4096});
args.emplace_back(BenchArgs{8, 1024, 6912});
args.emplace_back(BenchArgs{8, 3456, 3456});
args.emplace_back(BenchArgs{8, 2304, 2304});
args.emplace_back(BenchArgs{4, 256, 8192});
args.emplace_back(BenchArgs{8, 864, 864});
args.emplace_back(BenchArgs{4, 9, 64});
args.emplace_back(BenchArgs{16, 4096, 4096});
args.emplace_back(BenchArgs{16, 1024, 6912});
args.emplace_back(BenchArgs{16, 3456, 3456});
args.emplace_back(BenchArgs{16, 2304, 2304});
args.emplace_back(BenchArgs{8, 256, 8192});
args.emplace_back(BenchArgs{16, 864, 864});
args.emplace_back(BenchArgs{8, 9, 64});
args.emplace_back(BenchArgs{32, 4096, 4096});
args.emplace_back(BenchArgs{32, 1024, 6912});
args.emplace_back(BenchArgs{32, 3456, 3456});
args.emplace_back(BenchArgs{32, 2304, 2304});
args.emplace_back(BenchArgs{16, 256, 8192});
args.emplace_back(BenchArgs{32, 864, 864});
args.emplace_back(BenchArgs{32, 9, 64});
args.emplace_back(BenchArgs{64, 4096, 4096});
args.emplace_back(BenchArgs{64, 1024, 6912});
args.emplace_back(BenchArgs{64, 3456, 3456});
args.emplace_back(BenchArgs{64, 2304, 2304});
args.emplace_back(BenchArgs{32, 256, 8192});
args.emplace_back(BenchArgs{64, 864, 864});
args.emplace_back(BenchArgs{64, 9, 64});
args.emplace_back(BenchArgs{128, 4096, 4096});
args.emplace_back(BenchArgs{128, 1024, 6912});
args.emplace_back(BenchArgs{128, 3456, 3456});
args.emplace_back(BenchArgs{128, 2304, 2304});
args.emplace_back(BenchArgs{64, 256, 8192});
args.emplace_back(BenchArgs{128, 864, 864});
args.emplace_back(BenchArgs{128, 9, 64});
return args;
}
void benchmark_matrix_mul(
Handle* handle, const std::vector<BenchArgs>& args, DType A_dtype,
DType B_dtype, DType C_dtype, const char* algo = nullptr,
param::MatrixMul::Format format = param::MatrixMul::Format::DEFAULT) {
megdnn_assert(A_dtype.enumv() == B_dtype.enumv());
CUBenchmarker<MatrixMulForward> benchmarker(handle);
CUBenchmarker<MatrixMulForward> benchmarker_cublas(handle);
size_t RUNS = 1000;
benchmarker.set_display(false).set_times(RUNS);
benchmarker_cublas.set_display(false).set_times(RUNS);
benchmarker_cublas.set_before_exec_callback(
AlgoChecker<MatrixMulForward>("CUBLAS"));
benchmarker.set_dtype(0, A_dtype)
.set_dtype(1, B_dtype)
.set_dtype(2, C_dtype);
benchmarker_cublas.set_dtype(0, A_dtype)
.set_dtype(1, B_dtype)
.set_dtype(2, C_dtype);
using Param = MatrixMul::Param;
for (auto&& arg : args) {
size_t m = arg.m, n = arg.n, k = arg.k;
Param param;
param.transposeA = arg.mask & 0x1;
param.transposeB = arg.mask & 0x2;
param.format = format;
size_t A0 = m, A1 = k, B0 = k, B1 = n;
if (param.transposeA) {
std::swap(A0, A1);
}
if (param.transposeB) {
std::swap(B0, B1);
}
benchmarker.set_param(param);
TensorShape A{A0, A1}, B{B0, B1}, C{m, n};
float time_in_ms = 0.f;
if (algo) {
time_in_ms =
algo_benchmark<MatrixMulForward, OprProxy<MatrixMulForward>,
CUTimer>(benchmarker, {A, B, C}, algo) /
RUNS;
} else {
time_in_ms = benchmarker.execs({A, B, C}) / RUNS;
}
benchmarker_cublas.set_param(param);
auto time_in_ms_cublas = benchmarker_cublas.execs({A, B, C}) / RUNS;
float flo = 2.0 * m * n * k / (1e12);
printf("A=%s, B=%s, C=%s, time(algo=%s)=%.2f %.2fTops, "
"time(cublas)=%.2f %.2fTops, "
"perf(algo=%s)/perf(cublas)=%.2f\n",
A.to_string().c_str(), B.to_string().c_str(),
C.to_string().c_str(), algo, time_in_ms,
(flo / (time_in_ms * 1e-3)), time_in_ms_cublas,
(flo / (time_in_ms_cublas * 1e-3)), algo,
time_in_ms_cublas / time_in_ms);
}
}
#endif
} // namespace
TEST_F(CUDA, CUTLASS_GEMM_MULTI_BATCHSIZE) {
auto args = matrix_mul::get_matmul_args_no_mask();
test_multibatchsize(handle_cuda(), dtype::Float32(), dtype::Float32(),
dtype::Float32(),
"CUTLASS_FLOAT32_SIMT_128X128X8_32X64X8", args,
param::MatrixMul::Format::DEFAULT);
}
#define MEGDNN_FOREACH_CUTLASS_KERNEL(cb) \
cb(1, 64, 256, 8, 32, 64, 8); \
cb(2, 256, 64, 8, 64, 32, 8); \
cb(3, 32, 256, 8, 16, 64, 8); \
cb(4, 256, 32, 8, 64, 16, 8); \
cb(5, 128, 128, 8, 32, 64, 8); \
cb(6, 128, 64, 8, 64, 32, 8); \
cb(7, 64, 128, 8, 32, 64, 8); \
cb(8, 128, 32, 8, 64, 32, 8); \
cb(9, 32, 128, 8, 32, 64, 8); \
cb(10, 64, 64, 8, 32, 64, 8); \
cb(11, 32, 64, 8, 32, 64, 8); \
cb(12, 64, 32, 8, 64, 32, 8); \
cb(13, 32, 32, 8, 32, 32, 8); \
cb(14, 8, 32, 8, 8, 32, 8); \
cb(15, 16, 32, 8, 16, 32, 8); \
cb(16, 16, 64, 8, 16, 64, 8); \
cb(17, 16, 128, 8, 16, 64, 8);
#define cb(name, tbm, tbn, tbk, wm, wn, wk) \
TEST_F(CUDA, CUTLASS_GEMM_##name) { \
matrix_mul::check_matrix_mul<MatrixMulForward>( \
dtype::Float32(), dtype::Float32(), dtype::Float32(), \
handle_cuda(), \
"CUTLASS_FLOAT32_SIMT_" #tbm "X" #tbn "X" #tbk "_" #wm "X" #wn \
"X" #wk); \
}
MEGDNN_FOREACH_CUTLASS_KERNEL(cb)
#undef cb
#undef MEGDNN_FOREACH_CUTLASS_KERNEL
#if MEGDNN_WITH_BENCHMARK
TEST_F(CUDA, BENCHMARK_CUTLASS_MATMUL) {
benchmark_matrix_mul(handle_cuda(), get_square_matmul_args(),
dtype::Float32(), dtype::Float32(), dtype::Float32(),
"CUTLASS_FLOAT32_SIMT");
}
TEST_F(CUDA, BENCHMARK_CUTLASS_MATMUL_FEAT) {
benchmark_matrix_mul(handle_cuda(), get_feat_model_args(), dtype::Float32(),
dtype::Float32(), dtype::Float32(),
"CUTLASS_FLOAT32_SIMT");
}
#endif
} // namespace test
} // namespace megdnn
#endif
// vim: syntax=cpp.doxygen
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册