Commit 33676125 authored by M Megvii Engine Team

feat(dnn/cuda): add tensorcore matmul for fp16 data type

GitOrigin-RevId: 025c591f75afcef8fd58034a9cdd1ae8528bbda1
Parent 12cdbddd
@@ -5,6 +5,8 @@ genrule(
outs = cutlass_gen_list,
cmd = """GEN=$(location //brain/megbrain/dnn/scripts/cutlass_generator:generator.py)
CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations gemm --type simt $(@D)
CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations gemm --type tensorop884 $(@D)
CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations gemm --type tensorop1688 $(@D)
CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations gemv --type simt $(@D)
CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations deconv --type simt $(@D)
CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations conv2d --type simt $(@D)
@@ -252,7 +252,8 @@ def GeneratesGemm(tile, data_type, layout_a, layout_b, layout_c, min_cc, align_a
if tile.math_instruction.element_accumulator == DataType.s32:
epilogues = [EpilogueFunctor.LinearCombinationClamp]
else:
assert tile.math_instruction.element_accumulator == DataType.f32
assert tile.math_instruction.element_accumulator == DataType.f32 or \
tile.math_instruction.element_accumulator == DataType.f16
epilogues = [EpilogueFunctor.LinearCombination]
for epilogue in epilogues:
@@ -799,7 +800,22 @@ class EmitGemmSplitKParallelInstance:
${epilogue_vector_length},
${element_accumulator},
${element_epilogue}
>
>,
cutlass::epilogue::thread::Convert<
${element_accumulator},
${epilogue_vector_length},
${element_accumulator}
>,
cutlass::reduction::thread::ReduceAdd<
${element_accumulator},
${element_accumulator},
${epilogue_vector_length}
>,
cutlass::gemm::threadblock::GemmSplitKHorizontalThreadblockSwizzle,
${stages},
${align_a},
${align_b},
${math_operation}
>;
"""
def emit(self, operation):
@@ -831,7 +847,10 @@ class EmitGemmSplitKParallelInstance:
'epilogue_vector_length': str(epilogue_vector_length),
'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor],
'stages': str(operation.tile_description.stages),
'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation],
'align_a': str(operation.A.alignment),
'align_b': str(operation.B.alignment),
}
return SubstituteTemplate(self.template, values)
@@ -32,6 +32,8 @@ if __name__ == "__main__":
f.write("# Generated by dnn/scripts/cutlass_generator/gen_list.py\n\n")
f.write("cutlass_gen_list = [\n")
write_op_list(f, "gemm", "simt")
write_op_list(f, "gemm", "tensorop1688")
write_op_list(f, "gemm", "tensorop884")
write_op_list(f, "gemv", "simt")
write_op_list(f, "deconv", "simt")
write_op_list(f, "conv2d", "simt")
@@ -596,6 +596,131 @@ def GenerateGemv_Simt(args):
align_b))
return operations
#
def GeneratesGemm_TensorOp_1688(args):
layouts = [
(LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.RowMajor), # nn
(LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.RowMajor), # nt
(LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor), # tn
(LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.RowMajor), # tt
]
math_instructions = [
MathInstruction( \
[16, 8, 8], \
DataType.f16, DataType.f16, DataType.f32, \
OpcodeClass.TensorOp, \
MathOperation.multiply_add),
MathInstruction( \
[16, 8, 8], \
DataType.f16, DataType.f16, DataType.f16, \
OpcodeClass.TensorOp, \
MathOperation.multiply_add),
]
min_cc = 75
max_cc = 1024
alignment_constraints = [8, 4, 2,
#1
]
operations = []
for math_inst in math_instructions:
for layout in layouts:
for align in alignment_constraints:
tile_descriptions = [
TileDescription([256, 128, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 256, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc),
TileDescription([128, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
## comment some configuration to reduce compilation time and binary size
# TileDescription([ 64, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
# TileDescription([128, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
# TileDescription([ 64, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
]
data_type = [
math_inst.element_a,
math_inst.element_b,
math_inst.element_a,
math_inst.element_accumulator,
]
for tile in tile_descriptions:
operations += GeneratesGemm(tile, \
data_type, \
layout[0], \
layout[1], \
layout[2], \
min_cc, \
align * 16, \
align * 16, \
align * 16)
return operations
#
def GeneratesGemm_TensorOp_884(args):
layouts = [
(LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.RowMajor), # nn
(LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.RowMajor), # nt
(LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor), # tn
(LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.RowMajor), # tt
]
math_instructions = [
MathInstruction( \
[8, 8, 4], \
DataType.f16, DataType.f16, DataType.f32, \
OpcodeClass.TensorOp, \
MathOperation.multiply_add),
MathInstruction( \
[8, 8, 4], \
DataType.f16, DataType.f16, DataType.f16, \
OpcodeClass.TensorOp, \
MathOperation.multiply_add),
]
min_cc = 70
max_cc = 75
alignment_constraints = [8, 4, 2,
# 1
]
operations = []
for math_inst in math_instructions:
for layout in layouts:
for align in alignment_constraints:
tile_descriptions = [
TileDescription([256, 128, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 256, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc),
TileDescription([128, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
## comment some configuration to reduce compilation time and binary size
# TileDescription([ 64, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
# TileDescription([128, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
# TileDescription([ 64, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
]
data_type = [
math_inst.element_a,
math_inst.element_b,
math_inst.element_a,
math_inst.element_accumulator,
]
for tile in tile_descriptions:
operations += GeneratesGemm(tile, \
data_type, \
layout[0], \
layout[1], \
layout[2], \
min_cc, \
align * 16, \
align * 16, \
align * 16)
return operations
#
def GenerateConv2dOperations(args):
if args.type == "simt":
@@ -613,9 +738,14 @@ def GenerateDeconvOperations(args):
return GenerateDeconv_Simt(args)
def GenerateGemmOperations(args):
assert args.type == "simt", "operation gemm only support" \
"simt. (got:{})".format(args.type)
return GenerateGemm_Simt(args)
if args.type == "tensorop884":
return GeneratesGemm_TensorOp_884(args)
elif args.type == "tensorop1688":
return GeneratesGemm_TensorOp_1688(args)
else:
assert args.type == "simt", "operation gemm only supports " \
"simt. (got:{})".format(args.type)
return GenerateGemm_Simt(args)
def GenerateGemvOperations(args):
assert args.type == "simt", "operation gemv only support" \
@@ -631,7 +761,7 @@ if __name__ == "__main__":
parser.add_argument("--operations", type=str, choices=['gemm', 'gemv', 'conv2d', 'deconv'],
required=True, help="Specifies the operation to generate (gemm, gemv, conv2d, deconv)")
parser.add_argument("output", type=str, help="output directory for CUTLASS kernel files")
parser.add_argument("--type", type=str, choices=['simt', 'tensorop8816', 'tensorop8832'],
parser.add_argument("--type", type=str, choices=['simt', 'tensorop8816', 'tensorop8832', 'tensorop884', 'tensorop1688'],
default='simt', help="kernel type of CUTLASS kernel generator")
gemv_wrapper_path = "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper_batched_gemv_strided.cuinl"
@@ -151,6 +151,8 @@ if(MGE_WITH_CUDA)
set(${gen_files} "${${gen_files}}" PARENT_SCOPE)
endfunction()
gen_cutlass_kimpl(gemm simt CUTLASS_SOURCES)
gen_cutlass_kimpl(gemm tensorop884 CUTLASS_SOURCES)
gen_cutlass_kimpl(gemm tensorop1688 CUTLASS_SOURCES)
gen_cutlass_kimpl(gemv simt CUTLASS_SOURCES)
gen_cutlass_kimpl(deconv simt CUTLASS_SOURCES)
gen_cutlass_kimpl(conv2d simt CUTLASS_SOURCES)
@@ -49,6 +49,8 @@ namespace library {
(__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
void initialize_all_gemm_simt_operations(Manifest& manifest);
void initialize_all_gemm_tensorop884_operations(Manifest& manifest);
void initialize_all_gemm_tensorop1688_operations(Manifest& manifest);
void initialize_all_conv2d_simt_operations(Manifest& manifest);
void initialize_all_conv2d_tensorop8816_operations(Manifest& manifest);
void initialize_all_conv2d_tensorop8832_operations(Manifest& manifest);
@@ -56,6 +58,8 @@ void initialize_all_deconv_simt_operations(Manifest& manifest);
void initialize_all(Manifest& manifest) {
initialize_all_gemm_simt_operations(manifest);
initialize_all_gemm_tensorop884_operations(manifest);
initialize_all_gemm_tensorop1688_operations(manifest);
initialize_all_conv2d_simt_operations(manifest);
initialize_all_conv2d_tensorop8816_operations(manifest);
initialize_all_conv2d_tensorop8832_operations(manifest);
@@ -55,6 +55,8 @@ GemmKey get_gemm_key_from_desc(const GemmDescription& desc) {
key.layout_B = desc.B.layout;
key.element_C = desc.C.element;
key.layout_C = desc.C.layout;
key.element_accumulator =
desc.tile_description.math_instruction.element_accumulator;
key.threadblock_shape_m = desc.tile_description.threadblock_shape.m();
key.threadblock_shape_n = desc.tile_description.threadblock_shape.n();
@@ -75,6 +77,8 @@ GemmKey get_gemm_key_from_desc(const GemmDescription& desc) {
desc.tile_description.math_instruction.instruction_shape.k();
key.stages = desc.stages;
key.alignment_A = desc.A.alignment;
key.alignment_B = desc.B.alignment;
key.split_k_mode = desc.split_k_mode;
return key;
@@ -77,6 +77,7 @@ struct GemmKey {
LayoutTypeID layout_B;
NumericTypeID element_C;
LayoutTypeID layout_C;
NumericTypeID element_accumulator;
int threadblock_shape_m;
int threadblock_shape_n;
@@ -91,12 +92,15 @@ struct GemmKey {
int instruction_shape_k;
int stages;
int alignment_A;
int alignment_B;
SplitKMode split_k_mode;
inline bool operator==(GemmKey const& rhs) const {
return (element_A == rhs.element_A) && (layout_A == rhs.layout_A) &&
(element_B == rhs.element_B) && (layout_B == rhs.layout_B) &&
(element_C == rhs.element_C) && (layout_C == rhs.layout_C) &&
(element_accumulator == rhs.element_accumulator) &&
(threadblock_shape_m == rhs.threadblock_shape_m) &&
(threadblock_shape_n == rhs.threadblock_shape_n) &&
(threadblock_shape_k == rhs.threadblock_shape_k) &&
@@ -106,7 +110,9 @@ struct GemmKey {
(instruction_shape_m == rhs.instruction_shape_m) &&
(instruction_shape_n == rhs.instruction_shape_n) &&
(instruction_shape_k == rhs.instruction_shape_k) &&
(stages == rhs.stages) && (split_k_mode == rhs.split_k_mode);
(stages == rhs.stages) && (alignment_A == rhs.alignment_A) &&
(alignment_B == rhs.alignment_B) &&
(split_k_mode == rhs.split_k_mode);
}
inline bool operator!=(GemmKey const& rhs) const { return !(*this == rhs); }
@@ -130,10 +136,13 @@ struct GemmKey {
"\n layout_B: " + to_string(layout_B) +
"\n element_C: " + to_string(element_C) +
"\n layout_C: " + to_string(layout_C) +
"\n element_accumulator: " + to_string(element_accumulator) +
"\n threadblock_shape: " + threadblock_shape_str +
"\n warp_shape: " + warp_shape_str +
"\n instruction_shape: " + instruction_shape_str +
"\n stages: " + std::to_string(stages) +
"\n alignment_A: " + std::to_string(alignment_A) +
"\n alignment_B: " + std::to_string(alignment_B) +
"\n split_k_mode: " + to_string(split_k_mode) + "\n}";
}
};
@@ -147,6 +156,8 @@ struct GemmKeyHasher {
.update(&key.layout_B, sizeof(key.layout_B))
.update(&key.element_C, sizeof(key.element_C))
.update(&key.layout_C, sizeof(key.layout_C))
.update(&key.element_accumulator,
sizeof(key.element_accumulator))
.update(&key.threadblock_shape_m,
sizeof(key.threadblock_shape_m))
.update(&key.threadblock_shape_n,
@@ -157,6 +168,8 @@ struct GemmKeyHasher {
.update(&key.warp_shape_n, sizeof(key.warp_shape_n))
.update(&key.warp_shape_k, sizeof(key.warp_shape_k))
.update(&key.stages, sizeof(key.stages))
.update(&key.alignment_A, sizeof(key.alignment_A))
.update(&key.alignment_B, sizeof(key.alignment_B))
.update(&key.split_k_mode, sizeof(key.split_k_mode))
.digest();
}
@@ -43,6 +43,12 @@ MatrixMulForwardImpl::AlgoPack::AlgoPack() {
for (auto&& algo : simt_float32_gemv_batched_strided) {
all_algos.push_back(&algo);
}
for (auto&& algo : tensorop_float16) {
all_algos.push_back(&algo);
}
for (auto&& algo : tensorop_float16_split_k) {
all_algos.push_back(&algo);
}
#endif
all_algos.push_back(&naive);
@@ -53,7 +59,7 @@ MatrixMulForwardImpl::AlgoPack::AlgoPack() {
#if CUDA_VERSION >= 9020
void MatrixMulForwardImpl::AlgoPack::fill_cutlass_algos() {
using AlgoParam = AlgoFloat32SIMT::AlgoParam;
using AlgoParam = AlgoCutlassMatrixMulBase::AlgoParam;
simt_float32.emplace_back(AlgoParam{64, 256, 8, 32, 64, 8});
simt_float32.emplace_back(AlgoParam{256, 64, 8, 64, 32, 8});
simt_float32.emplace_back(AlgoParam{32, 256, 8, 16, 64, 8});
@@ -91,6 +97,19 @@ void MatrixMulForwardImpl::AlgoPack::fill_cutlass_algos() {
simt_float32_gemv_batched_strided.emplace_back(128);
simt_float32_gemv_batched_strided.emplace_back(64);
simt_float32_gemv_batched_strided.emplace_back(32);
#define FOREACH_CUTLASS_MATMUL_F16_SHAPES(cb) \
cb(256, 128, 32, 64, 64, 32, 8, 8, 4); \
cb(128, 256, 32, 64, 64, 32, 8, 8, 4); \
cb(128, 128, 32, 64, 64, 32, 8, 8, 4); \
cb(256, 128, 32, 64, 64, 32, 16, 8, 8); \
cb(128, 256, 32, 64, 64, 32, 16, 8, 8); \
cb(128, 128, 32, 64, 64, 32, 16, 8, 8);
#define cb(...) \
tensorop_float16.emplace_back(AlgoParam{__VA_ARGS__}); \
tensorop_float16_split_k.emplace_back(AlgoParam{__VA_ARGS__});
FOREACH_CUTLASS_MATMUL_F16_SHAPES(cb)
#undef cb
#undef FOREACH_CUTLASS_MATMUL_F16_SHAPES
}
#endif
@@ -41,11 +41,13 @@ public:
CUDA_WMMA_UINT4X4X32,
CUDA_CUBLASLT,
CUDA_NAIVE,
CUDA_BFLOAT16,
CUDA_BFLOAT16,
#if CUDA_VERSION >= 9020
CUDA_FLOAT32_SIMT,
CUDA_FLOAT32_SIMT_SPLIT_K,
CUDA_FLOAT32_SIMT_GEMV_BATCHED_STRIDED,
CUDA_FLOAT32_SIMT,
CUDA_FLOAT32_SIMT_SPLIT_K,
CUDA_FLOAT32_SIMT_GEMV_BATCHED_STRIDED,
CUDA_FLOAT16_TENSOR_OP,
CUDA_FLOAT16_TENSOR_OP_SPLIT_K,
#endif
};
using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>;
@@ -188,65 +190,83 @@ private:
#endif
#if CUDA_VERSION >= 9020
class MatrixMulForwardImpl::AlgoFloat32SIMT final : public AlgoBase {
class MatrixMulForwardImpl::AlgoCutlassMatrixMulBase : public AlgoBase {
public:
struct AlgoParam {
int threadblock_m, threadblock_n, threadblock_k;
int warp_m, warp_n, warp_k;
std::string to_string() {
return ssprintf("%dX%dX%d_%dX%dX%d", threadblock_m, threadblock_n,
threadblock_k, warp_m, warp_n, warp_k);
}
int instruction_m, instruction_n, instruction_k;
AlgoParam(int threadblock_m_, int threadblock_n_, int threadblock_k_,
int warp_m_, int warp_n_, int warp_k_, int instruction_m_ = 1,
int instruction_n_ = 1, int instruction_k_ = 1)
: threadblock_m{threadblock_m_},
threadblock_n{threadblock_n_},
threadblock_k{threadblock_k_},
warp_m{warp_m_},
warp_n{warp_n_},
warp_k{warp_k_},
instruction_m{instruction_m_},
instruction_n{instruction_n_},
instruction_k{instruction_k_} {}
std::string to_string() const;
};
AlgoCutlassMatrixMulBase(AlgoParam algo_param) : m_algo_param{algo_param} {}
void exec(const ExecArgs& args) const override;
std::string param() const override {
std::string ret;
serialize_write_pod(m_algo_param, ret);
return ret;
}
protected:
virtual int min_alignment_requirement() const = 0;
virtual void do_exec(const ExecArgs& args) const = 0;
std::pair<bool, TensorLayoutArray> construct_aligned_layouts(
const SizeArgs& args) const;
int max_alignment(const SizeArgs& args) const;
AlgoParam m_algo_param;
};
class MatrixMulForwardImpl::AlgoFloat32SIMT final
: public AlgoCutlassMatrixMulBase {
public:
AlgoFloat32SIMT(AlgoParam algo_param)
: m_algo_param{algo_param},
: AlgoCutlassMatrixMulBase{algo_param},
m_name{ssprintf("CUTLASS_FLOAT32_SIMT_%s",
m_algo_param.to_string().c_str())} {}
bool is_available(const SizeArgs& args) const override;
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
const char* name() const override { return m_name.c_str(); }
void exec(const ExecArgs& args) const override;
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
MEGDNN_DECL_ALGO_TYPE(CUDA_FLOAT32_SIMT)
std::string param() const override {
std::string ret;
serialize_write_pod(m_algo_param, ret);
return ret;
}
private:
AlgoParam m_algo_param;
void do_exec(const ExecArgs& args) const override;
int min_alignment_requirement() const override { return 1; }
std::string m_name;
};
class MatrixMulForwardImpl::AlgoFloat32SIMTSplitK final : public AlgoBase {
class MatrixMulForwardImpl::AlgoFloat32SIMTSplitK final
: public AlgoCutlassMatrixMulBase {
public:
using AlgoParam = MatrixMulForwardImpl::AlgoFloat32SIMT::AlgoParam;
AlgoFloat32SIMTSplitK(AlgoParam algo_param)
: m_algo_param{algo_param},
: AlgoCutlassMatrixMulBase{algo_param},
m_name{ssprintf("CUTLASS_FLOAT32_SIMT_SPLIT_K_%s",
m_algo_param.to_string().c_str())} {}
bool is_available(const SizeArgs& args) const override;
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
const char* name() const override { return m_name.c_str(); }
void exec(const ExecArgs& args) const override;
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE |
AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
}
MEGDNN_DECL_ALGO_TYPE(CUDA_FLOAT32_SIMT_SPLIT_K)
std::string param() const override {
std::string ret;
serialize_write_pod(m_algo_param, ret);
return ret;
}
private:
AlgoParam m_algo_param;
void do_exec(const ExecArgs& args) const override;
int min_alignment_requirement() const override { return 1; }
std::string m_name;
};
@@ -276,6 +296,56 @@ private:
int m_threadblock_n;
std::string m_name;
};
class MatrixMulForwardImpl::AlgoFloat16TensorOp final
: public AlgoCutlassMatrixMulBase {
public:
AlgoFloat16TensorOp(AlgoParam algo_param)
: AlgoCutlassMatrixMulBase{algo_param},
m_name{ssprintf("CUTLASS_FLOAT16_TENSOR_OP_h%d%d%d_%s",
m_algo_param.instruction_m,
m_algo_param.instruction_n,
m_algo_param.instruction_k,
m_algo_param.to_string().c_str())} {}
bool is_available(const SizeArgs& args) const override;
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
const char* name() const override { return m_name.c_str(); }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
MEGDNN_DECL_ALGO_TYPE(CUDA_FLOAT16_TENSOR_OP)
private:
void do_exec(const ExecArgs& args) const override;
int min_alignment_requirement() const override { return 2; }
std::string m_name;
};
class MatrixMulForwardImpl::AlgoFloat16TensorOpSplitK final
: public AlgoCutlassMatrixMulBase {
public:
AlgoFloat16TensorOpSplitK(AlgoParam algo_param)
: AlgoCutlassMatrixMulBase{algo_param},
m_name{ssprintf("CUTLASS_FLOAT16_TENSOR_OP_SPLIT_K_h%d%d%d_%s",
m_algo_param.instruction_m,
m_algo_param.instruction_n,
m_algo_param.instruction_k,
m_algo_param.to_string().c_str())} {}
bool is_available(const SizeArgs& args) const override;
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
const char* name() const override { return m_name.c_str(); }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE |
AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
}
MEGDNN_DECL_ALGO_TYPE(CUDA_FLOAT16_TENSOR_OP_SPLIT_K)
private:
void do_exec(const ExecArgs& args) const override;
int min_alignment_requirement() const override { return 2; }
std::string m_name;
};
#endif
class MatrixMulForwardImpl::AlgoPack : NonCopyableObj {
@@ -300,6 +370,8 @@ public:
std::vector<AlgoFloat32SIMTSplitK> simt_float32_split_k;
std::vector<AlgoFloat32SIMTGemvBatchedStrided>
simt_float32_gemv_batched_strided;
std::vector<AlgoFloat16TensorOp> tensorop_float16;
std::vector<AlgoFloat16TensorOpSplitK> tensorop_float16_split_k;
#endif
std::vector<AlgoBase*> all_algos;
/**
* \file dnn/src/cuda/matrix_mul/cutlass_float16_tensorop.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/cuda/cutlass/singleton.h"
#include "src/cuda/handle.h"
#include "src/cuda/matrix_mul/algos.h"
#include "src/cuda/utils.h"
#if CUDA_VERSION >= 9020
using namespace megdnn;
using namespace cuda;
bool MatrixMulForwardImpl::AlgoFloat16TensorOp::is_available(
const SizeArgs& args) const {
bool available =
args.opr->param().format == param::MatrixMul::Format::DEFAULT &&
args.layout_a.dtype == dtype::Float16() &&
args.layout_b.dtype == dtype::Float16() &&
args.layout_c.dtype == dtype::Float16();
int n = args.layout_c.shape[1];
auto&& device_prop = cuda::current_device_prop();
int y_grid_limit = device_prop.maxGridSize[1];
// limit y grid
available &= ((n + m_algo_param.threadblock_n - 1) /
m_algo_param.threadblock_n <=
y_grid_limit);
if (m_algo_param.instruction_m == 8 && m_algo_param.instruction_n == 8 &&
m_algo_param.instruction_k == 4) {
available &= is_compute_capability_required(7, 0);
} else {
megdnn_assert(m_algo_param.instruction_m == 16 &&
m_algo_param.instruction_n == 8 &&
m_algo_param.instruction_k == 8);
available &= is_compute_capability_required(7, 5);
}
return available;
}
size_t MatrixMulForwardImpl::AlgoFloat16TensorOp::get_workspace_in_bytes(
const SizeArgs& args) const {
auto aligned = construct_aligned_layouts(args);
if (!aligned.first)
return 0_z;
const auto& layouts = aligned.second;
size_t ws_size = 0;
for (auto&& ly : layouts) {
ws_size += ly.span().dist_byte();
}
return ws_size;
}
void MatrixMulForwardImpl::AlgoFloat16TensorOp::do_exec(
const ExecArgs& args) const {
int64_t lda = args.tensor_a.layout.stride[0],
ldb = args.tensor_b.layout.stride[0],
ldc = args.tensor_c.layout.stride[0];
int alignment = max_alignment(args);
int min_alignment = min_alignment_requirement();
auto&& param = args.opr->param();
int m = args.tensor_c.layout.shape[0], n = args.tensor_c.layout.shape[1],
k = args.tensor_a.layout.shape[param.transposeA ? 0 : 1];
megdnn_assert(lda % alignment == 0 && ldb % alignment == 0 &&
ldc % alignment == 0 && m % alignment == 0 &&
n % alignment == 0 && k % alignment == 0 &&
alignment >= min_alignment);
cutlass::gemm::GemmCoord problem_size{m, n, k};
auto&& stream = cuda_stream(args.opr->handle());
int* workspace = reinterpret_cast<int*>(args.workspace.raw_ptr);
// \note these constants (i.e. one and zero) of cutlass epilogue will be
// passed by pointers and interpreted as ElementCompute*, which will be used
// to initialize kernel parameters. So the arguments' type on the host side
// should be the same as the ElementCompute of the kernel instance; otherwise
// undefined kernel behaviors will occur, caused by incorrect interpretation
// of these pointers.
float one = 1.f, zero = 0.f;
dt_float16 one_f16 = static_cast<dt_float16>(one),
zero_f16 = static_cast<dt_float16>(zero);
using namespace cutlass::library;
auto layoutA = param.transposeA ? LayoutTypeID::kColumnMajor
: LayoutTypeID::kRowMajor;
auto layoutB = param.transposeB ? LayoutTypeID::kColumnMajor
: LayoutTypeID::kRowMajor;
void *host_one, *host_zero;
NumericTypeID element_accumulator;
if (param.compute_mode == param::MatrixMul::ComputeMode::DEFAULT) {
element_accumulator = NumericTypeID::kF16;
host_one = &one_f16;
host_zero = &zero_f16;
} else {
megdnn_assert(param.compute_mode ==
param::MatrixMul::ComputeMode::FLOAT32);
element_accumulator = NumericTypeID::kF32;
host_one = &one;
host_zero = &zero;
}
GemmKey key{NumericTypeID::kF16,
layoutA,
NumericTypeID::kF16,
layoutB,
NumericTypeID::kF16,
LayoutTypeID::kRowMajor,
element_accumulator,
m_algo_param.threadblock_m,
m_algo_param.threadblock_n,
m_algo_param.threadblock_k,
m_algo_param.warp_m,
m_algo_param.warp_n,
m_algo_param.warp_k,
m_algo_param.instruction_m,
m_algo_param.instruction_n,
m_algo_param.instruction_k,
2,
alignment,
alignment,
SplitKMode::kNone};
const auto& table = Singleton::get().operation_table;
megdnn_assert(table.gemm_operations.count(key) > 0,
"key not found in cutlass operation table");
const auto& ops = table.gemm_operations.at(key);
megdnn_assert(ops.size() == 1, "exactly one kernel expected, got %zu",
ops.size());
GemmArguments gemm_args{problem_size,
args.tensor_a.raw_ptr,
args.tensor_b.raw_ptr,
args.tensor_c.raw_ptr,
args.tensor_c.raw_ptr,
lda,
ldb,
ldc,
ldc,
1,
host_one,
host_zero};
cutlass_check(ops[0]->run(&gemm_args, workspace, stream));
}
#endif
// vim: syntax=cpp.doxygen
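The \note comment in do_exec above explains why one/zero are materialized on the host in both float and dt_float16: the CUTLASS operation receives them as type-erased pointers and reinterprets them as ElementCompute*. The following standalone sketch illustrates the failure mode that guard avoids; it is an editorial illustration only (not part of this commit) and assumes the host-callable cuda_fp16.h conversions compiled with nvcc.

#include <cuda_fp16.h>
#include <cstdio>
#include <cstring>

// Sketch: a host-side float read through a pointer that the kernel treats as
// half yields a garbage scalar -- the "incorrect interpretation" warned about.
int main() {
    float one_f32 = 1.f;
    __half one_f16 = __float2half(1.f);
    __half from_f16, from_f32;
    std::memcpy(&from_f16, &one_f16, sizeof(__half));  // correct: buffer really holds a half
    std::memcpy(&from_f32, &one_f32, sizeof(__half));  // wrong: low 16 bits of the float 1.0f
    std::printf("half-typed one: %f, float misread as half: %f\n",
                __half2float(from_f16), __half2float(from_f32));  // 1.000000 vs 0.000000
    return 0;
}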
/**
* \file dnn/src/cuda/matrix_mul/cutlass_float16_tensorop_split_k.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/cuda/cutlass/singleton.h"
#include "src/cuda/handle.h"
#include "src/cuda/matrix_mul/algos.h"
#include "src/cuda/utils.h"
#if CUDA_VERSION >= 9020
using namespace megdnn;
using namespace cuda;
bool MatrixMulForwardImpl::AlgoFloat16TensorOpSplitK::is_available(
const SizeArgs& args) const {
auto&& param = args.opr->param();
int n = args.layout_c.shape[1],
k = args.layout_a.shape[param.transposeA ? 0 : 1];
bool available =
args.opr->param().format == param::MatrixMul::Format::DEFAULT &&
args.layout_a.dtype == dtype::Float16() &&
args.layout_b.dtype == dtype::Float16() &&
args.layout_c.dtype == dtype::Float16() && k > n;
auto&& device_prop = cuda::current_device_prop();
int y_grid_limit = device_prop.maxGridSize[1];
// limit y grid
available &= ((n + m_algo_param.threadblock_n - 1) /
m_algo_param.threadblock_n <=
y_grid_limit);
if (m_algo_param.instruction_m == 8 && m_algo_param.instruction_n == 8 &&
m_algo_param.instruction_k == 4) {
available &= is_compute_capability_required(7, 0);
} else {
megdnn_assert(m_algo_param.instruction_m == 16 &&
m_algo_param.instruction_n == 8 &&
m_algo_param.instruction_k == 8);
available &= is_compute_capability_required(7, 5);
}
return available;
}
size_t MatrixMulForwardImpl::AlgoFloat16TensorOpSplitK::get_workspace_in_bytes(
const SizeArgs& args) const {
auto aligned = construct_aligned_layouts(args);
auto&& param = args.opr->param();
int m = args.layout_c.shape[0], n = args.layout_c.shape[1],
k = args.layout_a.shape[param.transposeA ? 0 : 1];
int split_k_slices = std::max(1, k / n);
if (!aligned.first)
return args.layout_c.dtype.size(m * n * split_k_slices);
const auto& layouts = aligned.second;
int align_m = layouts[2].shape[0], align_n = layouts[2].shape[1],
align_k = layouts[0].shape[1];
split_k_slices = std::max(1, align_k / align_n);
size_t ws_size =
args.layout_c.dtype.size(align_m * align_n * split_k_slices);
for (auto&& ly : layouts)
ws_size += ly.span().dist_byte();
return ws_size;
}
void MatrixMulForwardImpl::AlgoFloat16TensorOpSplitK::do_exec(
const ExecArgs& args) const {
int64_t lda = args.tensor_a.layout.stride[0],
ldb = args.tensor_b.layout.stride[0],
ldc = args.tensor_c.layout.stride[0];
int alignment = max_alignment(args);
int min_alignment = min_alignment_requirement();
auto&& param = args.opr->param();
int m = args.tensor_c.layout.shape[0], n = args.tensor_c.layout.shape[1],
k = args.tensor_a.layout.shape[param.transposeA ? 0 : 1];
megdnn_assert(lda % alignment == 0 && ldb % alignment == 0 &&
ldc % alignment == 0 && m % alignment == 0 &&
n % alignment == 0 && k % alignment == 0 &&
alignment >= min_alignment);
cutlass::gemm::GemmCoord problem_size{m, n, k};
int split_k_slices = std::max(1, k / n);
auto&& stream = cuda_stream(args.opr->handle());
int* workspace = reinterpret_cast<int*>(args.workspace.raw_ptr);
// \note these constants (i.e. one and zero) of cutlass epilogue will be
// passed by pointers and interpreted as ElementCompute*, which will be used
// to initialize kernel parameters. So the arguments' type on the host side
// should be the same as the ElementCompute of the kernel instance; otherwise
// undefined kernel behaviors will occur, caused by incorrect interpretation
// of these pointers.
float one = 1.f, zero = 0.f;
dt_float16 one_f16 = static_cast<dt_float16>(one),
zero_f16 = static_cast<dt_float16>(zero);
using namespace cutlass::library;
auto layoutA = param.transposeA ? LayoutTypeID::kColumnMajor
: LayoutTypeID::kRowMajor;
auto layoutB = param.transposeB ? LayoutTypeID::kColumnMajor
: LayoutTypeID::kRowMajor;
void *host_one, *host_zero;
NumericTypeID element_accumulator;
if (param.compute_mode == param::MatrixMul::ComputeMode::DEFAULT) {
element_accumulator = NumericTypeID::kF16;
host_one = &one_f16;
host_zero = &zero_f16;
} else {
megdnn_assert(param.compute_mode ==
param::MatrixMul::ComputeMode::FLOAT32);
element_accumulator = NumericTypeID::kF32;
host_one = &one;
host_zero = &zero;
}
GemmKey key{NumericTypeID::kF16,
layoutA,
NumericTypeID::kF16,
layoutB,
NumericTypeID::kF16,
LayoutTypeID::kRowMajor,
element_accumulator,
m_algo_param.threadblock_m,
m_algo_param.threadblock_n,
m_algo_param.threadblock_k,
m_algo_param.warp_m,
m_algo_param.warp_n,
m_algo_param.warp_k,
m_algo_param.instruction_m,
m_algo_param.instruction_n,
m_algo_param.instruction_k,
2,
alignment,
alignment,
SplitKMode::kParallel};
const auto& table = Singleton::get().operation_table;
megdnn_assert(table.gemm_operations.count(key) > 0,
"key not found in cutlass operation table");
const auto& ops = table.gemm_operations.at(key);
megdnn_assert(ops.size() == 1, "exactly one kernel expected, got %zu",
ops.size());
GemmArguments gemm_args{problem_size,
args.tensor_a.raw_ptr,
args.tensor_b.raw_ptr,
args.tensor_c.raw_ptr,
args.tensor_c.raw_ptr,
lda,
ldb,
ldc,
ldc,
split_k_slices,
host_one,
host_zero};
cutlass_check(ops[0]->run(&gemm_args, workspace, stream));
}
#endif
// vim: syntax=cpp.doxygen
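The split-K algorithm above partitions the reduction (K) dimension into split_k_slices = max(1, k / n) slices; each slice writes its own m x n partial product into the workspace (hence the m * n * split_k_slices workspace term), and a parallel reduction (ReduceAdd in the generated kernels) sums the partials into C. The CPU reference below shows the same scheme; it is an editorial sketch, not part of this commit, and the function name is made up for illustration.

#include <algorithm>
#include <cstddef>
#include <vector>

// Editorial sketch of parallel split-K: slice K, compute independent partial
// products, then reduce. A (m x k), B (k x n) and C (m x n) are row-major;
// C must be pre-sized to m * n by the caller.
void splitk_gemm_reference(const std::vector<float>& A, const std::vector<float>& B,
                           std::vector<float>& C, int m, int n, int k) {
    int split_k_slices = std::max(1, k / n);  // same heuristic as the algo above
    int k_per_slice = (k + split_k_slices - 1) / split_k_slices;
    std::vector<float> workspace(static_cast<size_t>(m) * n * split_k_slices, 0.f);
    for (int s = 0; s < split_k_slices; ++s) {  // slices are independent (parallel on GPU)
        float* partial = workspace.data() + static_cast<size_t>(s) * m * n;
        int k_begin = s * k_per_slice, k_end = std::min(k, k_begin + k_per_slice);
        for (int i = 0; i < m; ++i)
            for (int kk = k_begin; kk < k_end; ++kk)
                for (int j = 0; j < n; ++j)
                    partial[i * n + j] += A[i * k + kk] * B[kk * n + j];
    }
    for (int i = 0; i < m * n; ++i) {  // ReduceAdd across slices
        float acc = 0.f;
        for (int s = 0; s < split_k_slices; ++s)
            acc += workspace[static_cast<size_t>(s) * m * n + i];
        C[i] = acc;
    }
}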
@@ -42,7 +42,8 @@ size_t MatrixMulForwardImpl::AlgoFloat32SIMT::get_workspace_in_bytes(
return 0_z;
}
void MatrixMulForwardImpl::AlgoFloat32SIMT::exec(const ExecArgs& args) const {
void MatrixMulForwardImpl::AlgoFloat32SIMT::do_exec(
const ExecArgs& args) const {
int64_t lda = args.tensor_a.layout.stride[0],
ldb = args.tensor_b.layout.stride[0],
ldc = args.tensor_c.layout.stride[0];
@@ -65,12 +66,14 @@ void MatrixMulForwardImpl::AlgoFloat32SIMT::exec(const ExecArgs& args) const {
auto layoutB = param.transposeB ? LayoutTypeID::kColumnMajor
: LayoutTypeID::kRowMajor;
int alignment = min_alignment_requirement();
GemmKey key{NumericTypeID::kF32,
layoutA,
NumericTypeID::kF32,
layoutB,
NumericTypeID::kF32,
LayoutTypeID::kRowMajor,
NumericTypeID::kF32,
m_algo_param.threadblock_m,
m_algo_param.threadblock_n,
m_algo_param.threadblock_k,
@@ -79,8 +82,10 @@ void MatrixMulForwardImpl::AlgoFloat32SIMT::exec(const ExecArgs& args) const {
m_algo_param.warp_k,
1,
1,
1,
2,
1,
2,
alignment,
alignment,
SplitKMode::kNone};
const Operation* op = Singleton::get().operation_table.find_op(key);
@@ -22,7 +22,7 @@ using namespace cuda;
bool MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::is_available(
const SizeArgs& args) const {
auto&& param = args.opr->param();
int m = args.layout_c.shape[0], n = args.layout_c.shape[1],
int n = args.layout_c.shape[1],
k = args.layout_a.shape[param.transposeA ? 0 : 1];
bool available =
args.opr->param().format == param::MatrixMul::Format::DEFAULT &&
@@ -32,8 +32,8 @@ bool MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::is_available(
auto&& device_prop = cuda::current_device_prop();
int y_grid_limit = device_prop.maxGridSize[1];
// limit y grid
available &= ((m + m_algo_param.threadblock_m - 1) /
m_algo_param.threadblock_m <=
available &= ((n + m_algo_param.threadblock_n - 1) /
m_algo_param.threadblock_n <=
y_grid_limit);
return available;
}
@@ -47,7 +47,7 @@ size_t MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::get_workspace_in_bytes(
return args.layout_c.dtype.size(m * n * split_k_slices);
}
void MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::exec(
void MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::do_exec(
const ExecArgs& args) const {
int64_t lda = args.tensor_a.layout.stride[0],
ldb = args.tensor_b.layout.stride[0],
@@ -72,12 +72,14 @@ void MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::exec(
auto layoutB = param.transposeB ? LayoutTypeID::kColumnMajor
: LayoutTypeID::kRowMajor;
int alignment = min_alignment_requirement();
GemmKey key{NumericTypeID::kF32,
layoutA,
NumericTypeID::kF32,
layoutB,
NumericTypeID::kF32,
LayoutTypeID::kRowMajor,
NumericTypeID::kF32,
m_algo_param.threadblock_m,
m_algo_param.threadblock_n,
m_algo_param.threadblock_k,
@@ -87,7 +89,9 @@ void MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::exec(
1,
1,
1,
2,
2,
alignment,
alignment,
SplitKMode::kParallel};
Operation const* op = Singleton::get().operation_table.find_op(key);
/**
* \file dnn/src/cuda/matrix_mul/cutlass_matrix_mul_base.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/cuda/handle.h"
#include "src/cuda/matrix_mul/algos.h"
#include "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cuh"
#include "src/cuda/utils.h"
#if CUDA_VERSION >= 9020
using namespace megdnn;
using namespace cuda;
std::string
MatrixMulForwardImpl::AlgoCutlassMatrixMulBase::AlgoParam::to_string() const {
return ssprintf("%dX%dX%d_%dX%dX%d", threadblock_m, threadblock_n,
threadblock_k, warp_m, warp_n, warp_k);
}
std::pair<bool, TensorLayoutArray>
MatrixMulForwardImpl::AlgoCutlassMatrixMulBase::construct_aligned_layouts(
const SizeArgs& args) const {
int alignment = max_alignment(args);
int min_alignment = min_alignment_requirement();
bool aligned = alignment >= min_alignment;
if (aligned)
return std::make_pair(!aligned, TensorLayoutArray{{}});
auto&& param = args.opr->param();
int m = args.layout_c.shape[0], n = args.layout_c.shape[1],
k = args.layout_a.shape[param.transposeA ? 0 : 1];
size_t align_m = get_aligned_power2(m, min_alignment);
size_t align_n = get_aligned_power2(n, min_alignment);
size_t align_k = get_aligned_power2(k, min_alignment);
TensorLayoutArray layouts;
layouts.emplace_back(TensorLayout{{align_m, align_k}, args.layout_a.dtype});
layouts.emplace_back(TensorLayout{{align_k, align_n}, args.layout_b.dtype});
layouts.emplace_back(TensorLayout{{align_m, align_n}, args.layout_c.dtype});
return std::make_pair(!aligned, std::move(layouts));
}
void MatrixMulForwardImpl::AlgoCutlassMatrixMulBase::exec(
const ExecArgs& args) const {
auto aligned = construct_aligned_layouts(args);
if (!aligned.first)
return do_exec(args);
const auto& layouts = aligned.second;
auto tensor_a = args.tensor_a;
auto tensor_b = args.tensor_b;
auto workspace = args.workspace;
size_t copy_size = 0;
for (const auto& ly : layouts)
copy_size += ly.span().dist_byte();
auto&& param = args.opr->param();
auto&& stream = cuda_stream(args.opr->handle());
cuda_check(cudaMemsetAsync(workspace.raw_ptr, 0, copy_size, stream));
auto&& relayout = args.opr->handle()->create_operator<RelayoutForward>();
auto copy_stride = [](const TensorLayout& src, TensorLayout& dst,
bool trans) {
dst.stride[0] = src.stride[0], dst.stride[1] = src.stride[1];
if (trans)
std::swap(dst.stride[0], dst.stride[1]);
};
copy_stride(layouts[0], tensor_a.layout, param.transposeA);
tensor_a.raw_ptr = workspace.raw_ptr;
relayout->exec(args.tensor_a, tensor_a);
workspace.raw_ptr += layouts[0].span().dist_byte();
workspace.size -= layouts[0].span().dist_byte();
copy_stride(layouts[1], tensor_b.layout, param.transposeB);
tensor_b.raw_ptr = workspace.raw_ptr;
relayout->exec(args.tensor_b, tensor_b);
workspace.raw_ptr += layouts[1].span().dist_byte();
workspace.size -= layouts[1].span().dist_byte();
decltype(tensor_a) tensor_c{workspace.raw_ptr, layouts[2]};
workspace.raw_ptr += layouts[2].span().dist_byte();
workspace.size -= layouts[2].span().dist_byte();
auto&& matmul = args.opr->handle()->create_operator<MatrixMulForward>();
matmul->param().transposeA = false;
matmul->param().transposeB = false;
matmul->param().compute_mode = args.opr->param().compute_mode;
tensor_a.layout = layouts[0];
tensor_b.layout = layouts[1];
ExecArgs args_{static_cast<MatrixMulForwardImpl*>(matmul.get()), tensor_a,
tensor_b, tensor_c, workspace};
do_exec(args_);
tensor_c.layout.TensorShape::operator=(args.layout_c);
relayout->exec(tensor_c, args.tensor_c);
}
int MatrixMulForwardImpl::AlgoCutlassMatrixMulBase::max_alignment(
const SizeArgs& args) const {
auto&& dtype_a = args.layout_a.dtype;
auto&& dtype_b = args.layout_b.dtype;
auto&& dtype_c = args.layout_c.dtype;
auto get_alignment = [](const DType& dt, int len) {
int size_bits = dt.size(1) * 8;
int align = 128;
while (align > 1) {
if ((len * size_bits) % align == 0)
break;
align = align / 2;
}
return align / size_bits;
};
int lda = args.layout_a.stride[0], ldb = args.layout_b.stride[0],
ldc = args.layout_c.stride[0];
auto&& param = args.opr->param();
int m = args.layout_c.shape[0], n = args.layout_c.shape[1],
k = args.layout_a.shape[param.transposeA ? 0 : 1];
int max_align = get_alignment(dtype_a, lda);
max_align = std::min(get_alignment(dtype_a, m), max_align);
max_align = std::min(get_alignment(dtype_a, n), max_align);
max_align = std::min(get_alignment(dtype_a, k), max_align);
max_align = std::min(get_alignment(dtype_a, lda), max_align);
max_align = std::min(get_alignment(dtype_b, ldb), max_align);
max_align = std::min(get_alignment(dtype_c, ldc), max_align);
return max_align;
}
#endif
// vim: syntax=cpp.doxygen
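The max_alignment helper above derives the widest power-of-two vector access (capped at 128 bits) that divides every extent and leading dimension; the fp16 tensor-op algos then require at least 2-element alignment (their min_alignment_requirement), and otherwise fall back to construct_aligned_layouts plus a relayout into padded workspace tensors. A small host-side illustration of that calculation, as an editorial sketch rather than part of the commit:

#include <cstdio>

// Editorial sketch: largest power-of-two alignment (in elements) supported by a
// given extent/leading dimension, mirroring the get_alignment lambda above.
static int alignment_in_elements(int len, int elem_bits) {
    int align_bits = 128;  // widest CUDA global-memory access
    while (align_bits > elem_bits &&
           (static_cast<long long>(len) * elem_bits) % align_bits != 0)
        align_bits /= 2;
    return align_bits / elem_bits;
}

int main() {
    // fp16 elements are 16 bits wide.
    std::printf("ld=128 -> %d elements\n", alignment_in_elements(128, 16));  // 8
    std::printf("ld=100 -> %d elements\n", alignment_in_elements(100, 16));  // 4: >= 2, runs directly
    std::printf("ld=99  -> %d elements\n", alignment_in_elements(99, 16));   // 1: below the required 2,
                                                                             // triggers the padded path
    return 0;
}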
@@ -42,9 +42,12 @@ public:
class AlgoBFloat16;
#endif
#if CUDA_VERSION >= 9020
class AlgoCutlassMatrixMulBase;
class AlgoFloat32SIMT;
class AlgoFloat32SIMTSplitK;
class AlgoFloat32SIMTGemvBatchedStrided;
class AlgoFloat16TensorOp;
class AlgoFloat16TensorOpSplitK;
#endif
class AlgoPack;
@@ -184,7 +184,8 @@ void matrix_mul::check_matrix_mul(DType A_dtype, DType B_dtype, DType C_dtype,
const ExecutionPolicyAlgoName& algo,
param::MatrixMul::Format format, size_t nbase,
float eps, std::vector<TestArg>&& user_args,
bool force_deduce_dst) {
bool force_deduce_dst,
param::MatrixMul::ComputeMode compute_mode) {
megdnn_assert(A_dtype.enumv() == B_dtype.enumv());
Checker<Opr> checker(handle);
checker.set_force_deduce_dst(force_deduce_dst);
@@ -261,6 +262,7 @@ void matrix_mul::check_matrix_mul(DType A_dtype, DType B_dtype, DType C_dtype,
Param param;
param.transposeA = arg.mask & 0x1;
param.transposeB = arg.mask & 0x2;
param.compute_mode = compute_mode;
param.format = format;
checker.set_dtype(0, A_dtype)
.set_dtype(1, B_dtype)
@@ -69,7 +69,9 @@ void check_matrix_mul(
const ExecutionPolicyAlgoName& algo = {"", {}},
param::MatrixMul::Format format = param::MatrixMul::Format::DEFAULT,
size_t nbase = 8, float eps = 1e-3, std::vector<TestArg>&& args = {},
bool force_deduce_dst = true);
bool force_deduce_dst = true,
param::MatrixMul::ComputeMode compute_mode =
param::MatrixMul::ComputeMode::DEFAULT);
void check_matrix_mul(
DType A_dtype, DType B_dtype, DType C_dtype, Handle* handle,
@@ -21,6 +21,7 @@
#include "test/cuda/fixture.h"
#include "test/cuda/utils.h"
#define MEGDNN_WITH_BENCHMARK 1
#if CUDA_VERSION >= 9020
namespace megdnn {
namespace test {
@@ -215,6 +216,14 @@ std::vector<BenchArgs> get_feat_model_args() {
return args;
}
std::vector<BenchArgs> get_f16_feat_model_args() {
std::vector<BenchArgs> args;
args.emplace_back(BenchArgs{128, 9216, 9216});
args.emplace_back(BenchArgs{128, 6400, 6400});
args.emplace_back(BenchArgs{128, 5184, 5184});
return args;
}
void benchmark_matrix_mul(
Handle* handle, const std::vector<BenchArgs>& args, DType A_dtype,
DType B_dtype, DType C_dtype, const char* algo = nullptr,
@@ -364,6 +373,82 @@ MEGDNN_FOREACH_CUTLASS_KERNEL(cb)
#undef cb
#undef MEGDNN_FOREACH_CUTLASS_KERNEL
#define MEGDNN_FOREACH_CUTLASS_KERNEL(cb) \
cb(1, 256, 128, 32, 64, 64, 32, 8, 8, 4); \
cb(2, 128, 256, 32, 64, 64, 32, 8, 8, 4); \
cb(3, 128, 128, 32, 64, 64, 32, 8, 8, 4);
#define cb(name, tbm, tbn, tbk, wm, wn, wk, im, in, ik) \
TEST_F(CUDA, CUTLASS_F16_884_GEMM_##name) { \
require_compute_capability(7, 0); \
matrix_mul::check_matrix_mul<MatrixMulForward>( \
dtype::Float16(), dtype::Float16(), dtype::Float16(), \
handle_cuda(), \
"CUTLASS_FLOAT16_TENSOR_OP_h" #im #in #ik "_" #tbm "X" #tbn \
"X" #tbk "_" #wm "X" #wn "X" #wk, \
param::MatrixMul::Format::DEFAULT, 8, 1e-2, \
matrix_mul::get_matmul_args()); \
}
MEGDNN_FOREACH_CUTLASS_KERNEL(cb)
#undef cb
#define cb(name, tbm, tbn, tbk, wm, wn, wk, im, in, ik) \
TEST_F(CUDA, CUTLASS_F16_884_GEMM_SPLIT_K_##name) { \
require_compute_capability(7, 0); \
matrix_mul::check_matrix_mul<MatrixMulForward>( \
dtype::Float16(), dtype::Float16(), dtype::Float16(), \
handle_cuda(), \
"CUTLASS_FLOAT16_TENSOR_OP_SPLIT_K_h" #im #in #ik "_" #tbm \
"X" #tbn "X" #tbk "_" #wm "X" #wn "X" #wk, \
param::MatrixMul::Format::DEFAULT, 8, 1e-3, \
matrix_mul::get_matmul_args_split_k(), true, \
param::MatrixMul::ComputeMode::FLOAT32); \
}
MEGDNN_FOREACH_CUTLASS_KERNEL(cb)
#undef cb
#undef MEGDNN_FOREACH_CUTLASS_KERNEL
#define MEGDNN_FOREACH_CUTLASS_KERNEL(cb) \
cb(1, 256, 128, 32, 64, 64, 32, 16, 8, 8); \
cb(2, 128, 256, 32, 64, 64, 32, 16, 8, 8); \
cb(3, 128, 128, 32, 64, 64, 32, 16, 8, 8);
#define cb(name, tbm, tbn, tbk, wm, wn, wk, im, in, ik) \
TEST_F(CUDA, CUTLASS_F16_1688_GEMM_##name) { \
require_compute_capability(7, 5); \
matrix_mul::check_matrix_mul<MatrixMulForward>( \
dtype::Float16(), dtype::Float16(), dtype::Float16(), \
handle_cuda(), \
"CUTLASS_FLOAT16_TENSOR_OP_h" #im #in #ik "_" #tbm "X" #tbn \
"X" #tbk "_" #wm "X" #wn "X" #wk, \
param::MatrixMul::Format::DEFAULT, 8, 1e-2, \
matrix_mul::get_matmul_args(), true, \
param::MatrixMul::ComputeMode::FLOAT32); \
}
MEGDNN_FOREACH_CUTLASS_KERNEL(cb)
#undef cb
#define cb(name, tbm, tbn, tbk, wm, wn, wk, im, in, ik) \
TEST_F(CUDA, CUTLASS_F16_1688_GEMM_SPLIT_K_##name) { \
require_compute_capability(7, 5); \
matrix_mul::check_matrix_mul<MatrixMulForward>( \
dtype::Float16(), dtype::Float16(), dtype::Float16(), \
handle_cuda(), \
"CUTLASS_FLOAT16_TENSOR_OP_SPLIT_K_h" #im #in #ik "_" #tbm \
"X" #tbn "X" #tbk "_" #wm "X" #wn "X" #wk, \
param::MatrixMul::Format::DEFAULT, 8, 1e-3, \
matrix_mul::get_matmul_args_split_k()); \
}
MEGDNN_FOREACH_CUTLASS_KERNEL(cb)
#undef cb
#undef MEGDNN_FOREACH_CUTLASS_KERNEL
#if MEGDNN_WITH_BENCHMARK
TEST_F(CUDA, BENCHMARK_CUTLASS_MATMUL) {
benchmark_matrix_mul(handle_cuda(), get_square_matmul_args(),
@@ -376,6 +461,12 @@ TEST_F(CUDA, BENCHMARK_CUTLASS_MATMUL_FEAT) {
dtype::Float32(), dtype::Float32(),
"CUTLASS_FLOAT32_SIMT");
}
TEST_F(CUDA, BENCHMARK_CUTLASS_F16_MATMUL_FEAT) {
benchmark_matrix_mul(handle_cuda(), get_f16_feat_model_args(),
dtype::Float16(), dtype::Float16(), dtype::Float16(),
"CUTLASS_FLOAT16_TENSOR_OP");
}
#endif
} // namespace test
} // namespace megdnn