Commit 47fe7663 authored by Megvii Engine Team

feat(dnn/cuda): add implicit bmm kernels for large kernel depthwise convolution backward filter opr

GitOrigin-RevId: 932e7689e89f2864884546935ca13f656842f41c
Parent dcc96935
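For context on what the new kernels compute: the backward-filter (wgrad) pass of a depthwise convolution is independent per channel, so it reduces to one small GEMM per channel, i.e. a batched GEMM over the channel dimension. The sketch below is illustrative only and not part of this diff; the function name and the numpy dependency are assumptions. It shows the reference semantics for cross-correlation with dilation 1, which is the case the new kernels target.

# Illustrative sketch, not part of this diff: reference semantics of depthwise
# conv2d wgrad expressed as one GEMM per channel (a batched GEMM over C).
# Function and variable names are hypothetical; assumes numpy.
import numpy as np

def dwconv2d_wgrad_reference(src, diff, fh, fw, stride=1, padding=0):
    """src: (N, C, Hi, Wi), diff: (N, C, Ho, Wo) -> grad: (C, 1, fh, fw)."""
    n, c, hi, wi = src.shape
    _, _, ho, wo = diff.shape
    pad = np.pad(src, ((0, 0), (0, 0), (padding, padding), (padding, padding)))
    grad = np.zeros((c, 1, fh, fw), dtype=np.float32)
    for ch in range(c):
        # im2col for this channel: (fh*fw, N*Ho*Wo) matrix of source patches
        cols = np.empty((fh * fw, n * ho * wo), dtype=np.float32)
        idx = 0
        for b in range(n):
            for oh in range(ho):
                for ow in range(wo):
                    patch = pad[b, ch,
                                oh * stride:oh * stride + fh,
                                ow * stride:ow * stride + fw]
                    cols[:, idx] = patch.reshape(-1)
                    idx += 1
        # one small GEMM per channel: (fh*fw, K) x (K,), with K = N*Ho*Wo
        grad[ch, 0] = (cols @ diff[:, ch].reshape(-1)).reshape(fh, fw)
    return grad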
......@@ -17,6 +17,8 @@ genrule(
CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations dwconv2d_fprop --type tensorop884 $(@D)
CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations dwconv2d_dgrad --type simt $(@D)
CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations dwconv2d_dgrad --type tensorop884 $(@D)
CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations dwconv2d_wgrad --type simt $(@D)
CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations dwconv2d_wgrad --type tensorop884 $(@D)
""",
tools = ["//brain/megbrain/dnn/scripts/cutlass_generator:generator.py"],
visibility = ["//visibility:public"],
......
......@@ -317,7 +317,7 @@ class EmitDeconvInstance:
def __init__(self):
self.template = """
// kernel instance "${operation_name}" generated by cutlass generator
using Deconvolution =
using Convolution =
typename cutlass::conv::device::Deconvolution<
${element_src},
${layout_src},
......@@ -415,6 +415,103 @@ using Deconvolution =
return SubstituteTemplate(self.template, values)
class EmitConvolutionBackwardFilterInstance:
def __init__(self):
self.template = """
// kernel instance "${operation_name}" generated by cutlass generator
using Convolution =
typename cutlass::conv::device::ConvolutionBackwardFilter<
${element_src},
${layout_src},
${element_diff},
${layout_diff},
${element_grad},
${layout_grad},
${element_accumulator},
${conv_type},
${opcode_class},
${arch},
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
${epilogue_functor}<
${element_grad},
${epilogue_vector_length},
${element_accumulator},
${element_epilogue}
>,
${swizzling_functor},
${stages},
${alignment_src},
${alignment_diff},
${special_optimization},
${math_operator},
${implicit_gemm_mode}>;
"""
def emit(self, operation):
warp_shape = [
int(
operation.tile_description.threadblock_shape[idx]
/ operation.tile_description.warp_count[idx]
)
for idx in range(3)
]
epilogue_vector_length = int(
min(operation.dst.alignment * DataTypeSize[operation.dst.element], 128)
/ DataTypeSize[operation.dst.element]
)
values = {
"operation_name": operation.procedural_name(),
"conv_type": ConvTypeTag[operation.conv_type],
"element_src": DataTypeTag[operation.src.element],
"layout_src": LayoutTag[operation.src.layout],
"element_diff": DataTypeTag[operation.flt.element],
"layout_diff": LayoutTag[operation.flt.layout],
"element_grad": DataTypeTag[operation.dst.element],
"layout_grad": LayoutTag[operation.dst.layout],
"element_accumulator": DataTypeTag[operation.accumulator_type()],
"opcode_class": OpcodeClassTag[
operation.tile_description.math_instruction.opcode_class
],
"arch": "cutlass::arch::Sm%d" % operation.arch,
"threadblock_shape_m": str(operation.tile_description.threadblock_shape[0]),
"threadblock_shape_n": str(operation.tile_description.threadblock_shape[1]),
"threadblock_shape_k": str(operation.tile_description.threadblock_shape[2]),
"warp_shape_m": str(warp_shape[0]),
"warp_shape_n": str(warp_shape[1]),
"warp_shape_k": str(warp_shape[2]),
"instruction_shape_m": str(
operation.tile_description.math_instruction.instruction_shape[0]
),
"instruction_shape_n": str(
operation.tile_description.math_instruction.instruction_shape[1]
),
"instruction_shape_k": str(
operation.tile_description.math_instruction.instruction_shape[2]
),
"epilogue_vector_length": str(epilogue_vector_length),
"epilogue_functor": EpilogueFunctorTag[operation.epilogue_functor],
"element_epilogue": str(DataTypeTag[operation.element_epilogue]),
"swizzling_functor": SwizzlingFunctorTag[operation.swizzling_functor],
"stages": str(operation.tile_description.stages),
"alignment_src": str(operation.src.alignment),
"alignment_diff": str(operation.flt.alignment),
"special_optimization": SpecialOptimizeDescTag[
operation.special_optimization
],
"math_operator": MathOperationTag[
operation.tile_description.math_instruction.math_operation
],
"implicit_gemm_mode": ImplicitGemmModeTag[operation.implicit_gemm_mode],
}
return SubstituteTemplate(self.template, values)
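The emit() method above fills the ${...} placeholders of the template with the values it builds. A minimal sketch of the kind of substitution SubstituteTemplate is expected to perform is shown below; this is an assumption for clarity only, the real helper is defined elsewhere in the generator.

# Hedged sketch of the ${key} substitution that emit() relies on (assumption;
# the actual SubstituteTemplate helper lives elsewhere in the generator).
import re

def substitute_template_sketch(template, values):
    return re.sub(r"\$\{(\w+)\}",
                  lambda m: values.get(m.group(1), m.group(0)),
                  template)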
###################################################################################################
#
# Generator functions for all layouts
......@@ -500,6 +597,7 @@ def GenerateConv2d(
epilogues = [
EpilogueFunctor.BiasAddLinearCombination,
EpilogueFunctor.BiasAddLinearCombinationRelu,
EpilogueFunctor.LinearCombination,
]
if conv_type == ConvType.Convolution:
epilogues.append(EpilogueFunctor.BiasAddLinearCombinationHSwish)
......@@ -544,11 +642,15 @@ def GenerateConv2d(
def filter_epilogue_with_conv_kind(
epilogue: EpilogueFunctor, conv_kind: ConvKind
) -> bool:
if conv_kind == ConvKind.Fprop:
return epilogue == EpilogueFunctor.LinearCombination
elif conv_kind == ConvKind.Dgrad:
return (
(conv_kind == ConvKind.Dgrad or conv_kind == ConvKind.Wgrad)
and epilogue != EpilogueFunctor.BiasAddLinearCombinationClamp
epilogue != EpilogueFunctor.BiasAddLinearCombinationClamp
and epilogue != EpilogueFunctor.BiasAddLinearCombination
)
elif conv_kind == ConvKind.Wgrad:
return epilogue != EpilogueFunctor.LinearCombination
# loop over all tile descriptions
for tile in tile_descriptions:
......@@ -557,7 +659,7 @@ def GenerateConv2d(
bias_type, epilogues = get_bias_type_and_epilogues(tile, dst_type)
flt_align = get_flt_align(tile)
flt_align = flt_align if conv_kind == ConvKind.Wgrad else get_flt_align(tile)
dst_align = get_dst_align(tile, dst_layout)
......@@ -771,11 +873,14 @@ class EmitConvSingleKernelWrapper:
if self.operation.conv_kind == ConvKind.Fprop:
self.instance_emitter = EmitConv2dInstance()
self.convolution_name = "Convolution"
else:
assert self.operation.conv_kind == ConvKind.Dgrad
self.convolution_name = "ConvolutionOperation"
elif self.operation.conv_kind == ConvKind.Dgrad:
self.instance_emitter = EmitDeconvInstance()
self.convolution_name = "Deconvolution"
self.convolution_name = "ConvolutionOperation"
else:
assert self.operation.conv_kind == ConvKind.Wgrad
self.instance_emitter = EmitConvolutionBackwardFilterInstance()
self.convolution_name = "ConvolutionBackwardFilterOperation"
self.header_template = """
#if __CUDACC_VER_MAJOR__ > ${required_cuda_ver_major} || (__CUDACC_VER_MAJOR__ == ${required_cuda_ver_major} && __CUDACC_VER_MINOR__ >= ${required_cuda_ver_minor})
......@@ -800,7 +905,7 @@ namespace cutlass {
namespace library {
void initialize_${operation_name}(Manifest &manifest) {
manifest.append(new ConvolutionOperation<${convolution_name}>(
manifest.append(new ${convolution_name}<Convolution>(
"${operation_name}"
));
}
......
......@@ -5,6 +5,7 @@ from generator import (
GenerateDeconvOperations,
GenerateDwconv2dFpropOperations,
GenerateDwconv2dDgradOperations,
GenerateDwconv2dWgradOperations,
)
......@@ -28,7 +29,7 @@ def write_op_list(f, gen_op, gen_type):
elif gen_op == "dwconv2d_dgrad":
operations = GenerateDwconv2dDgradOperations(GenArg(gen_op, gen_type))
elif gen_op == "dwconv2d_wgrad":
pass
operations = GenerateDwconv2dWgradOperations(GenArg(gen_op, gen_type))
for op in operations:
f.write(' "%s.cu",\n' % op.procedural_name())
if gen_op != "gemv":
......@@ -52,4 +53,6 @@ if __name__ == "__main__":
write_op_list(f, "dwconv2d_fprop", "tensorop884")
write_op_list(f, "dwconv2d_dgrad", "simt")
write_op_list(f, "dwconv2d_dgrad", "tensorop884")
write_op_list(f, "dwconv2d_wgrad", "simt")
write_op_list(f, "dwconv2d_wgrad", "tensorop884")
f.write("]")
......@@ -1115,6 +1115,9 @@ def GenerateDwconv2d_Simt(args, conv_kind):
dst_types = [DataType.f32]
if conv_kind == ConvKind.Wgrad:
alignment_constraints = [32]
else:
alignment_constraints = [128, 32]
operations = []
......@@ -1244,7 +1247,9 @@ def GenerateDwconv2d_Simt(args, conv_kind):
32,
32,
SpecialOptimizeDesc.NoneSpecialOpt,
ImplicitGemmMode.GemmTN,
ImplicitGemmMode.GemmNT
if conv_kind == ConvKind.Wgrad
else ImplicitGemmMode.GemmTN,
)
return operations
......@@ -1277,11 +1282,14 @@ def GenerateDwconv2d_TensorOp_884(args, conv_kind):
dst_layouts = [LayoutType.TensorNCHW]
if conv_kind == ConvKind.Wgrad:
dst_types = [DataType.f32]
else:
dst_types = [DataType.f16]
alignment_constraints = [128, 32, 16]
cuda_major = 10
cuda_minor = 2
cuda_minor = 1
operations = []
for math_inst in math_instructions:
......@@ -1295,6 +1303,30 @@ def GenerateDwconv2d_TensorOp_884(args, conv_kind):
for layout in layouts:
for dst_type, dst_layout in zip(dst_types, dst_layouts):
for alignment_src in alignment_constraints:
if conv_kind == ConvKind.Wgrad:
# skip io16xc16
if math_inst.element_accumulator == DataType.f16:
continue
for alignment_diff in alignment_constraints:
operations += GenerateConv2d(
ConvType.DepthwiseConvolution,
conv_kind,
tile_descriptions,
layout[0],
layout[1],
dst_layout,
dst_type,
min_cc,
alignment_src,
alignment_diff,
32, # always f32 output
SpecialOptimizeDesc.NoneSpecialOpt,
ImplicitGemmMode.GemmNT,
False,
cuda_major,
cuda_minor,
)
else:
operations += GenerateConv2d(
ConvType.DepthwiseConvolution,
conv_kind,
......@@ -1501,7 +1533,7 @@ def GeneratesGemm_TensorOp_884(args):
# 1
]
cuda_major = 10
cuda_minor = 2
cuda_minor = 1
operations = []
for math_inst in math_instructions:
......@@ -1595,6 +1627,17 @@ def GenerateDwconv2dDgradOperations(args):
return GenerateDwconv2d_TensorOp_884(args, ConvKind.Dgrad)
def GenerateDwconv2dWgradOperations(args):
if args.type == "simt":
return GenerateDwconv2d_Simt(args, ConvKind.Wgrad)
else:
assert args.type == "tensorop884", (
"operation dwconv2d fprop only support"
"simt, tensorop884. (got:{})".format(args.type)
)
return GenerateDwconv2d_TensorOp_884(args, ConvKind.Wgrad)
def GenerateGemmOperations(args):
if args.type == "tensorop884":
return GeneratesGemm_TensorOp_884(args)
......@@ -1668,8 +1711,9 @@ if __name__ == "__main__":
operations = GenerateDwconv2dFpropOperations(args)
elif args.operations == "dwconv2d_dgrad":
operations = GenerateDwconv2dDgradOperations(args)
elif args.operations == "dwconv2d_wgrad":
pass
else:
assert args.operations == "dwconv2d_wgrad", "invalid operation"
operations = GenerateDwconv2dWgradOperations(args)
if (
args.operations == "conv2d"
......
......@@ -483,6 +483,7 @@ EpilogueFunctorTag = {
#
ShortEpilogueNames = {
EpilogueFunctor.LinearCombination: "id",
EpilogueFunctor.BiasAddLinearCombinationHSwishClamp: "hswish",
EpilogueFunctor.BiasAddLinearCombinationReluClamp: "relu",
EpilogueFunctor.BiasAddLinearCombinationClamp: "id",
......
......@@ -1382,4 +1382,60 @@ cutlass_gen_list = [
"cutlass_tensorop_h884dwdgrad_id_f16_128x64x32_32x32x32_2_nchw_nchw_align1x1.cu",
"cutlass_tensorop_h884dwdgrad_id_f16_64x64x32_32x32x32_2_nchw_nchw_align1x1.cu",
"all_dwconv2d_dgrad_tensorop884_operations.cu",
"cutlass_simt_sdwwgrad_id_f32_32x32x8_32x32x8_2_nchw_nchw_align1x1.cu",
"cutlass_simt_sdwwgrad_id_f32_32x64x8_32x64x8_2_nchw_nchw_align1x1.cu",
"cutlass_simt_sdwwgrad_id_f32_64x32x8_64x32x8_2_nchw_nchw_align1x1.cu",
"cutlass_simt_sdwwgrad_id_f32_32x128x8_32x64x8_2_nchw_nchw_align1x1.cu",
"cutlass_simt_sdwwgrad_id_f32_64x64x8_32x64x8_2_nchw_nchw_align1x1.cu",
"cutlass_simt_sdwwgrad_id_f32_128x32x8_64x32x8_2_nchw_nchw_align1x1.cu",
"cutlass_simt_sdwwgrad_id_f32_64x128x8_32x64x8_2_nchw_nchw_align1x1.cu",
"cutlass_simt_sdwwgrad_id_f32_128x64x8_64x32x8_2_nchw_nchw_align1x1.cu",
"cutlass_simt_sdwwgrad_id_f32_128x128x8_32x64x8_2_nchw_nchw_align1x1.cu",
"all_dwconv2d_wgrad_simt_operations.cu",
"cutlass_tensorop_s884dwwgrad_id_f16_128x256x32_64x64x32_2_nchw_nchw_align8x8.cu",
"cutlass_tensorop_s884dwwgrad_id_f16_128x128x32_32x32x32_2_nchw_nchw_align8x8.cu",
"cutlass_tensorop_s884dwwgrad_id_f16_64x128x32_32x32x32_2_nchw_nchw_align8x8.cu",
"cutlass_tensorop_s884dwwgrad_id_f16_128x64x32_32x32x32_2_nchw_nchw_align8x8.cu",
"cutlass_tensorop_s884dwwgrad_id_f16_64x64x32_32x32x32_2_nchw_nchw_align8x8.cu",
"cutlass_tensorop_s884dwwgrad_id_f16_128x256x32_64x64x32_2_nchw_nchw_align8x2.cu",
"cutlass_tensorop_s884dwwgrad_id_f16_128x128x32_32x32x32_2_nchw_nchw_align8x2.cu",
"cutlass_tensorop_s884dwwgrad_id_f16_64x128x32_32x32x32_2_nchw_nchw_align8x2.cu",
"cutlass_tensorop_s884dwwgrad_id_f16_128x64x32_32x32x32_2_nchw_nchw_align8x2.cu",
"cutlass_tensorop_s884dwwgrad_id_f16_64x64x32_32x32x32_2_nchw_nchw_align8x2.cu",
"cutlass_tensorop_s884dwwgrad_id_f16_128x256x32_64x64x32_2_nchw_nchw_align8x1.cu",
"cutlass_tensorop_s884dwwgrad_id_f16_128x128x32_32x32x32_2_nchw_nchw_align8x1.cu",
"cutlass_tensorop_s884dwwgrad_id_f16_64x128x32_32x32x32_2_nchw_nchw_align8x1.cu",
"cutlass_tensorop_s884dwwgrad_id_f16_128x64x32_32x32x32_2_nchw_nchw_align8x1.cu",
"cutlass_tensorop_s884dwwgrad_id_f16_64x64x32_32x32x32_2_nchw_nchw_align8x1.cu",
"cutlass_tensorop_s884dwwgrad_id_f16_128x256x32_64x64x32_2_nchw_nchw_align2x8.cu",
"cutlass_tensorop_s884dwwgrad_id_f16_128x128x32_32x32x32_2_nchw_nchw_align2x8.cu",
"cutlass_tensorop_s884dwwgrad_id_f16_64x128x32_32x32x32_2_nchw_nchw_align2x8.cu",
"cutlass_tensorop_s884dwwgrad_id_f16_128x64x32_32x32x32_2_nchw_nchw_align2x8.cu",
"cutlass_tensorop_s884dwwgrad_id_f16_64x64x32_32x32x32_2_nchw_nchw_align2x8.cu",
"cutlass_tensorop_s884dwwgrad_id_f16_128x256x32_64x64x32_2_nchw_nchw_align2x2.cu",
"cutlass_tensorop_s884dwwgrad_id_f16_128x128x32_32x32x32_2_nchw_nchw_align2x2.cu",
"cutlass_tensorop_s884dwwgrad_id_f16_64x128x32_32x32x32_2_nchw_nchw_align2x2.cu",
"cutlass_tensorop_s884dwwgrad_id_f16_128x64x32_32x32x32_2_nchw_nchw_align2x2.cu",
"cutlass_tensorop_s884dwwgrad_id_f16_64x64x32_32x32x32_2_nchw_nchw_align2x2.cu",
"cutlass_tensorop_s884dwwgrad_id_f16_128x256x32_64x64x32_2_nchw_nchw_align2x1.cu",
"cutlass_tensorop_s884dwwgrad_id_f16_128x128x32_32x32x32_2_nchw_nchw_align2x1.cu",
"cutlass_tensorop_s884dwwgrad_id_f16_64x128x32_32x32x32_2_nchw_nchw_align2x1.cu",
"cutlass_tensorop_s884dwwgrad_id_f16_128x64x32_32x32x32_2_nchw_nchw_align2x1.cu",
"cutlass_tensorop_s884dwwgrad_id_f16_64x64x32_32x32x32_2_nchw_nchw_align2x1.cu",
"cutlass_tensorop_s884dwwgrad_id_f16_128x256x32_64x64x32_2_nchw_nchw_align1x8.cu",
"cutlass_tensorop_s884dwwgrad_id_f16_128x128x32_32x32x32_2_nchw_nchw_align1x8.cu",
"cutlass_tensorop_s884dwwgrad_id_f16_64x128x32_32x32x32_2_nchw_nchw_align1x8.cu",
"cutlass_tensorop_s884dwwgrad_id_f16_128x64x32_32x32x32_2_nchw_nchw_align1x8.cu",
"cutlass_tensorop_s884dwwgrad_id_f16_64x64x32_32x32x32_2_nchw_nchw_align1x8.cu",
"cutlass_tensorop_s884dwwgrad_id_f16_128x256x32_64x64x32_2_nchw_nchw_align1x2.cu",
"cutlass_tensorop_s884dwwgrad_id_f16_128x128x32_32x32x32_2_nchw_nchw_align1x2.cu",
"cutlass_tensorop_s884dwwgrad_id_f16_64x128x32_32x32x32_2_nchw_nchw_align1x2.cu",
"cutlass_tensorop_s884dwwgrad_id_f16_128x64x32_32x32x32_2_nchw_nchw_align1x2.cu",
"cutlass_tensorop_s884dwwgrad_id_f16_64x64x32_32x32x32_2_nchw_nchw_align1x2.cu",
"cutlass_tensorop_s884dwwgrad_id_f16_128x256x32_64x64x32_2_nchw_nchw_align1x1.cu",
"cutlass_tensorop_s884dwwgrad_id_f16_128x128x32_32x32x32_2_nchw_nchw_align1x1.cu",
"cutlass_tensorop_s884dwwgrad_id_f16_64x128x32_32x32x32_2_nchw_nchw_align1x1.cu",
"cutlass_tensorop_s884dwwgrad_id_f16_128x64x32_32x32x32_2_nchw_nchw_align1x1.cu",
"cutlass_tensorop_s884dwwgrad_id_f16_64x64x32_32x32x32_2_nchw_nchw_align1x1.cu",
"all_dwconv2d_wgrad_tensorop884_operations.cu",
]
\ No newline at end of file
......@@ -185,6 +185,8 @@ if(MGE_WITH_CUDA)
gen_cutlass_kimpl(dwconv2d_fprop tensorop884 CUTLASS_SOURCES)
gen_cutlass_kimpl(dwconv2d_dgrad simt CUTLASS_SOURCES)
gen_cutlass_kimpl(dwconv2d_dgrad tensorop884 CUTLASS_SOURCES)
gen_cutlass_kimpl(dwconv2d_wgrad simt CUTLASS_SOURCES)
gen_cutlass_kimpl(dwconv2d_wgrad tensorop884 CUTLASS_SOURCES)
list(APPEND SOURCES ${CUTLASS_SOURCES})
list(APPEND SOURCES ${CUSOURCES})
endif()
......
......@@ -317,7 +317,7 @@ void ConvBiasForwardImpl::AlgoPack::fill_dwconv_algos() {
for (auto&& algo : f32_implicit_bmm) {
all_algos.push_back(&algo);
}
#if CUDA_VERSION >= 10020
#if CUDA_VERSION >= 10010
/// preferred algo
f16_implicit_bmm.emplace_back(AlgoParam{64, 128, 32, 32, 32, 32, 8, 8, 4, 2});
f16_implicit_bmm.emplace_back(AlgoParam{128, 128, 32, 32, 32, 32, 8, 8, 4, 2});
......
......@@ -50,6 +50,7 @@ bool ConvBiasForwardImpl::AlgoFloat16NCHWHMMAImplicitBatchedGemm::is_available(
RETURN_IF_FALSE(param.mode == Mode::CROSS_CORRELATION);
// check if channelwise convolution
RETURN_IF_FALSE(fm.icpg == 1 && fm.ocpg == 1);
RETURN_IF_FALSE(param.dilate_h == 1 && param.dilate_w == 1);
const auto* op = get_cutlass_conv_op(
args, ConvOperator::kFprop, ConvType::kDepthwiseConvolution, false, false);
RETURN_IF_FALSE(op != nullptr);
......
......@@ -50,6 +50,7 @@ bool ConvBiasForwardImpl::AlgoFloat32NCHWFMAImplicitBatchedGemm::is_available(
RETURN_IF_FALSE(param.mode == Mode::CROSS_CORRELATION);
// check if channelwise convolution
RETURN_IF_FALSE(fm.icpg == 1 && fm.ocpg == 1);
RETURN_IF_FALSE(param.dilate_h == 1 && param.dilate_w == 1);
const auto* op = get_cutlass_conv_op(
args, ConvOperator::kFprop, ConvType::kDepthwiseConvolution, false, false);
RETURN_IF_FALSE(op != nullptr);
......
......@@ -146,15 +146,17 @@ ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic(
args.filter_meta.stride[0] != 1 ||
args.filter_meta.stride[1] != 1 || hw_size < 512;
//! choose for large kernel cases
size_t fh = args.filter_meta.spatial[2], fw = args.filter_meta.spatial[3];
size_t fh = args.filter_meta.spatial[0], fw = args.filter_meta.spatial[1];
size_t hi = src[2], wi = src[3];
const bool prefer_dnn_lk_implbmm = hi <= 2 * fh && wi <= 2 * fw;
//! avoid bad case in cudnn, check dnn chanwise impl first
if (is_chanwise) {
if (prefer_dnn_lk_implbmm) {
#if CUDA_VERSION >= 10020
if (sm_algo_pack.f16_implicit_bmm[0].is_available_attribute(
args, positive_attr, negative_attr, workspace_limit_in_bytes))
return &sm_algo_pack.f16_implicit_bmm[0];
#endif
if (sm_algo_pack.f32_implicit_bmm[0].is_available_attribute(
args, positive_attr, negative_attr, workspace_limit_in_bytes))
return &sm_algo_pack.f32_implicit_bmm[0];
......
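The heuristic hunk above prefers the implicit batched-GEMM algorithms only for "large kernel" shapes, i.e. when the spatial input is at most twice the filter size in each dimension. A minimal sketch of that predicate (assumed semantics, illustrative only):

# Sketch of the large-kernel heuristic above (illustrative, assumed semantics):
# prefer the implicit batched-GEMM algos when hi <= 2*fh and wi <= 2*fw.
def prefer_dnn_lk_implbmm(hi, wi, fh, fw):
    return hi <= 2 * fh and wi <= 2 * fw

assert prefer_dnn_lk_implbmm(32, 32, 31, 31)      # 31x31 filter on 32x32 input
assert not prefer_dnn_lk_implbmm(64, 64, 7, 7)    # small filter: fall through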
......@@ -72,7 +72,7 @@ void ConvolutionBackwardDataImpl::AlgoPack::fill_dwconv_algos() {
all_algos.push_back(&algo);
}
}
#if CUDA_VERSION >= 10020
#if CUDA_VERSION >= 10010
{
using AlgoParam = AlgoFloat16NCHWHMMAImplicitBatchedGemm::AlgoParam;
/// preferred algo
......
......@@ -24,8 +24,10 @@ const void* ConvolutionBackwardDataImpl::AlgoFloat16NCHWHMMAImplicitBatchedGemm:
int alignment_diff = 0;
int wo = args.diff_layout->dtype.size(args.diff_layout->operator[](3));
for (int candidate : {16, 4, 2}) {
if (wo % candidate == 0)
if (wo % candidate == 0) {
alignment_diff = candidate;
break;
}
}
alignment_diff /= args.diff_layout->dtype.size(1);
NumericTypeID accumulator_dtype =
......@@ -85,6 +87,7 @@ bool ConvolutionBackwardDataImpl::AlgoFloat16NCHWHMMAImplicitBatchedGemm::is_ava
RETURN_IF_FALSE(param.mode == Mode::CROSS_CORRELATION);
// check if channelwise convolution
RETURN_IF_FALSE(fm.icpg == 1 && fm.ocpg == 1);
RETURN_IF_FALSE(param.dilate_h == 1 && param.dilate_w == 1);
const auto* op = get_available_op(args);
RETURN_IF_FALSE(op != nullptr);
return true;
......
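The alignment probe in the hunk above computes the output-row width in bytes, picks the largest candidate byte alignment that divides it, and converts that back to an element count. A rough Python equivalent (illustrative only; names are hypothetical):

# Rough equivalent of the alignment probe above (illustrative only):
# row width in bytes -> largest dividing candidate -> alignment in elements.
def alignment_in_elems(width_elems, elem_bytes, candidates=(16, 4, 2)):
    width_bytes = width_elems * elem_bytes
    for cand in candidates:
        if width_bytes % cand == 0:
            return cand // elem_bytes
    return 0

assert alignment_in_elems(16, 2) == 8   # fp16, Wo=16 -> 8-element (16-byte) access
assert alignment_in_elems(15, 2) == 1   # fp16, Wo=15 -> 1-element access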
......@@ -24,8 +24,10 @@ const void* ConvolutionBackwardDataImpl::AlgoFloat32NCHWFMAImplicitBatchedGemm::
int alignment_diff = 0;
int wo = args.diff_layout->dtype.size(args.diff_layout->operator[](3));
for (int candidate : {16, 4}) {
if (wo % candidate == 0)
if (wo % candidate == 0) {
alignment_diff = candidate;
break;
}
}
alignment_diff /= args.diff_layout->dtype.size(1);
ConvolutionKey key{
......@@ -81,6 +83,7 @@ bool ConvolutionBackwardDataImpl::AlgoFloat32NCHWFMAImplicitBatchedGemm::is_avai
RETURN_IF_FALSE(param.mode == Mode::CROSS_CORRELATION);
// check if channelwise convolution
RETURN_IF_FALSE(fm.icpg == 1 && fm.ocpg == 1);
RETURN_IF_FALSE(param.dilate_h == 1 && param.dilate_w == 1);
const auto* op = get_available_op(args);
RETURN_IF_FALSE(op != nullptr);
return true;
......
......@@ -25,6 +25,7 @@ ConvolutionBackwardFilterImpl::AlgoPack::AlgoPack() {
for (auto&& i : cudnn) {
all_algos.push_back(&i);
}
fill_dwconv_algos();
all_algos.push_back(&matmul);
all_algos.push_back(&group);
......@@ -48,6 +49,39 @@ ConvolutionBackwardFilterImpl::AlgoCUDNN* ConvolutionBackwardFilterImpl::AlgoPac
"can not find cudnn bwd_filter algorithm %d", static_cast<int>(algo)));
}
void ConvolutionBackwardFilterImpl::AlgoPack::fill_dwconv_algos() {
{
using AlgoParam = AlgoFloat32NCHWFMAImplicitBatchedGemm::AlgoParam;
/// preferred algo
implbmm_nchw_fma.emplace_back(AlgoParam{64, 128, 8, 32, 64, 8, 2});
implbmm_nchw_fma.emplace_back(AlgoParam{128, 128, 8, 32, 64, 8, 2});
implbmm_nchw_fma.emplace_back(AlgoParam{128, 64, 8, 64, 32, 8, 2});
implbmm_nchw_fma.emplace_back(AlgoParam{128, 32, 8, 64, 32, 8, 2});
implbmm_nchw_fma.emplace_back(AlgoParam{32, 128, 8, 32, 64, 8, 2});
implbmm_nchw_fma.emplace_back(AlgoParam{64, 64, 8, 32, 64, 8, 2});
implbmm_nchw_fma.emplace_back(AlgoParam{32, 64, 8, 32, 64, 8, 2});
implbmm_nchw_fma.emplace_back(AlgoParam{32, 32, 8, 32, 32, 8, 2});
implbmm_nchw_fma.emplace_back(AlgoParam{64, 32, 8, 64, 32, 8, 2});
for (auto&& algo : implbmm_nchw_fma) {
all_algos.push_back(&algo);
}
}
#if CUDA_VERSION >= 10010
{
using AlgoParam = AlgoFloat16NCHWHMMAImplicitBatchedGemm::AlgoParam;
/// preferred algo
implbmm_nchw_hmma.emplace_back(AlgoParam{64, 128, 32, 32, 32, 32, 8, 8, 4, 2});
implbmm_nchw_hmma.emplace_back(AlgoParam{128, 128, 32, 32, 32, 32, 8, 8, 4, 2});
implbmm_nchw_hmma.emplace_back(AlgoParam{128, 256, 32, 64, 64, 32, 8, 8, 4, 2});
implbmm_nchw_hmma.emplace_back(AlgoParam{128, 64, 32, 32, 32, 32, 8, 8, 4, 2});
implbmm_nchw_hmma.emplace_back(AlgoParam{64, 64, 32, 32, 32, 32, 8, 8, 4, 2});
for (auto&& algo : implbmm_nchw_hmma) {
all_algos.push_back(&algo);
}
}
#endif
}
ConvolutionBackwardFilterImpl::AlgoPack ConvolutionBackwardFilterImpl::sm_algo_pack;
ConvolutionBackwardFilterImpl::AlgoBase::SizeArgs::SizeArgs(
......
......@@ -37,6 +37,8 @@ public:
CUDA_CHANWISE,
CUDA_BFLOAT16,
CUDA_GROUP_CONV_GENERAL,
CUDA_IMPLICIT_BATCHED_GEMM_FMA_NCHW_F32,
CUDA_IMPLICIT_BATCHED_GEMM_HMMA_NCHW_F16,
};
using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>;
......@@ -210,9 +212,86 @@ private:
WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const;
};
class ConvolutionBackwardFilterImpl::AlgoFloat32NCHWFMAImplicitBatchedGemm final
: public AlgoBase {
public:
struct AlgoParam {
int threadblock_m;
int threadblock_n;
int threadblock_k;
int warp_m;
int warp_n;
int warp_k;
int stage;
std::string to_string() {
return ssprintf(
"_%dX%dX%d_%dX%dX%d_%dstage", threadblock_m, threadblock_n,
threadblock_k, warp_m, warp_n, warp_k, stage);
}
};
AlgoFloat32NCHWFMAImplicitBatchedGemm(AlgoParam algo_param)
: m_algo_param{algo_param},
m_name{ssprintf(
"FLOAT32_NCHW_FMA_IMPLICIT_BATCHED_GEMM%s",
m_algo_param.to_string().c_str())} {}
bool is_available(const SizeArgs& args) const override;
size_t get_workspace_in_bytes(const SizeArgs& args) const override { return 0; }
void exec(const ExecArgs& args) const override;
const char* name() const override { return m_name.c_str(); }
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_BATCHED_GEMM_FMA_NCHW_F32)
private:
const void* get_available_op(const SizeArgs& args) const;
AlgoParam m_algo_param;
std::string m_name;
};
class ConvolutionBackwardFilterImpl::AlgoFloat16NCHWHMMAImplicitBatchedGemm final
: public AlgoBase {
public:
/// add instruction shape as a member of the algo param, because the f16 tensor core
/// has 2 different matrix shapes (i.e. mma.884 and mma.1688)
struct AlgoParam {
int threadblock_m;
int threadblock_n;
int threadblock_k;
int warp_m;
int warp_n;
int warp_k;
int instruction_m;
int instruction_n;
int instruction_k;
int stage;
std::string to_string() {
return ssprintf(
"_%dX%dX%d_%dX%dX%d_mma%dX%dX%d_%dstage", threadblock_m,
threadblock_n, threadblock_k, warp_m, warp_n, warp_k, instruction_m,
instruction_n, instruction_k, stage);
}
};
AlgoFloat16NCHWHMMAImplicitBatchedGemm(AlgoParam algo_param)
: m_algo_param{algo_param},
m_name{ssprintf(
"FLOAT16_NCHW_HMMA_IMPLICIT_BATCHED_GEMM%s",
m_algo_param.to_string().c_str())} {}
bool is_available(const SizeArgs& args) const override;
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
void exec(const ExecArgs& args) const override;
const char* name() const override { return m_name.c_str(); }
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_BATCHED_GEMM_HMMA_NCHW_F16)
private:
const void* get_available_op(const SizeArgs& args) const;
AlgoParam m_algo_param;
std::string m_name;
};
class ConvolutionBackwardFilterImpl::AlgoPack : NonCopyableObj {
// defined in cudnn.cpp
void fill_cudnn_algos();
void fill_dwconv_algos();
AlgoBase::Mapper m_all_algos_map;
......@@ -224,6 +303,8 @@ public:
AlgoChanwise chanwise;
AlgoGroupConvGeneral group;
AlgoBFloat16 bfloat16;
std::vector<AlgoFloat32NCHWFMAImplicitBatchedGemm> implbmm_nchw_fma;
std::vector<AlgoFloat16NCHWHMMAImplicitBatchedGemm> implbmm_nchw_hmma;
std::vector<AlgoBase*>
//! all algorithms
......
/**
* \file
* dnn/src/cuda/convolution/backward_filter/implicit_batched_gemm_float16_nchw_hmma.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/cuda/convolution/backward_filter/algo.h"
#include "src/cuda/cutlass/singleton.h"
#include "src/cuda/utils.h"
using namespace megdnn;
using namespace cuda;
using namespace cutlass::library;
const void* ConvolutionBackwardFilterImpl::AlgoFloat16NCHWHMMAImplicitBatchedGemm::
get_available_op(const SizeArgs& args) const {
auto get_alignment = [](const TensorLayout& layout) {
int alignment = 0;
int width = layout.dtype.size(layout[3]);
for (int candidate : {16, 4, 2}) {
if (width % candidate == 0) {
alignment = candidate;
break;
}
}
alignment /= layout.dtype.size(1);
return alignment;
};
int alignment_src = get_alignment(*args.src_layout);
int alignment_diff = get_alignment(*args.diff_layout);
megdnn_assert(alignment_src >= 1 && alignment_diff >= 1);
NumericTypeID accumulator_dtype =
args.opr->param().compute_mode == param::Convolution::ComputeMode::DEFAULT
? NumericTypeID::kF16
: NumericTypeID::kF32;
ConvolutionKey key{
cutlass::conv::Operator::kWgrad,
NumericTypeID::kF16, // src tensor data type
LayoutTypeID::kTensorNCHW, // src tensor layout
NumericTypeID::kF16, // diff tensor data type
LayoutTypeID::kTensorNCHW, // diff tensor layout
NumericTypeID::kF32, // grad tensor data type
LayoutTypeID::kTensorNCHW, // grad tensor layout
NumericTypeID::kF32, // dummy argument, not used.
LayoutTypeID::kTensorNCHW, // dummy argument, not used
accumulator_dtype,
cutlass::conv::ConvType::kDepthwiseConvolution,
m_algo_param.threadblock_m,
m_algo_param.threadblock_n,
m_algo_param.threadblock_k,
m_algo_param.warp_m,
m_algo_param.warp_n,
m_algo_param.warp_k,
m_algo_param.instruction_m,
m_algo_param.instruction_n,
m_algo_param.instruction_k,
cutlass::epilogue::EpilogueType::kLinearCombination, // no bias
m_algo_param.stage,
cutlass::conv::SpecialOptimizeDesc::NONE,
alignment_src,
alignment_diff,
true};
return (void*)Singleton::get().operation_table.find_op(key);
}
bool ConvolutionBackwardFilterImpl::AlgoFloat16NCHWHMMAImplicitBatchedGemm::
is_available(const SizeArgs& args) const {
#define RETURN_IF_FALSE(stmt_) \
if (!(stmt_)) \
return false;
RETURN_IF_FALSE(is_compute_capability_required(7, 0));
RETURN_IF_FALSE(
args.src_layout->is_contiguous() && args.diff_layout->is_contiguous() &&
args.grad_layout->is_contiguous());
using Param = param::Convolution;
using Format = Param::Format;
using Sparse = Param::Sparse;
using Mode = Param::Mode;
using ComputeMode = Param::ComputeMode;
auto&& param = args.opr->param();
auto&& fm = args.grad_filter_meta;
RETURN_IF_FALSE(param.compute_mode == ComputeMode::FLOAT32);
RETURN_IF_FALSE(
param.format == Format::NCHW &&
args.src_layout->dtype.enumv() == DTypeEnum::Float16 &&
args.diff_layout->dtype.enumv() == DTypeEnum::Float16 &&
args.grad_layout->dtype.enumv() == DTypeEnum::Float16);
RETURN_IF_FALSE(param.sparse == Sparse::GROUP);
RETURN_IF_FALSE(param.mode == Mode::CROSS_CORRELATION);
// check if channelwise convolution
RETURN_IF_FALSE(fm.icpg == 1 && fm.ocpg == 1);
RETURN_IF_FALSE(param.dilate_h == 1 && param.dilate_w == 1);
const auto* op = get_available_op(args);
RETURN_IF_FALSE(op != nullptr);
return true;
#undef RETURN_IF_FALSE
}
size_t ConvolutionBackwardFilterImpl::AlgoFloat16NCHWHMMAImplicitBatchedGemm::
get_workspace_in_bytes(const SizeArgs& args) const {
auto layout = *args.grad_layout;
// modify data type
layout.modify_dtype_inplace(dtype::Float32());
return layout.span().dist_byte();
}
void ConvolutionBackwardFilterImpl::AlgoFloat16NCHWHMMAImplicitBatchedGemm::exec(
const ExecArgs& args) const {
auto&& param = args.opr->param();
auto&& fm = args.grad_filter_meta;
int hi = args.src_layout->operator[](2), wi = args.src_layout->operator[](3);
int n = args.diff_layout->operator[](0), ho = args.diff_layout->operator[](2),
wo = args.diff_layout->operator[](3);
int co = fm.group, ci = co, groups = co;
int fh = fm.spatial[0], fw = fm.spatial[1];
int sh = fm.stride[0], sw = fm.stride[1];
int ph = fm.padding[0], pw = fm.padding[1];
int dh = param.dilate_h, dw = param.dilate_w;
// check if channelwise convolution
megdnn_assert(fm.icpg == 1 && fm.ocpg == 1);
auto&& stream = cuda_stream(args.opr->handle());
float alpha = 1.f;
float beta = 0.f;
const Operation* op = (const Operation*)get_available_op(args);
cutlass::conv::Conv2dProblemSize problem_size{
n, hi, wi, ci, co, fh, fw, ho,
wo, ph, pw, sh, sw, dh, dw, cutlass::conv::Mode::kCrossCorrelation,
1, // split k slices, always 1
groups, // groups
};
cutlass::library::ConvolutionArguments conv_args{
problem_size,
args.src_tensor->raw_ptr(),
args.diff_tensor->raw_ptr(),
nullptr,
nullptr,
args.workspace.raw_ptr,
&alpha,
&beta,
nullptr,
nullptr,
nullptr,
nullptr,
nullptr,
nullptr};
cutlass_check(op->run(&conv_args, nullptr, stream));
after_kernel_launch();
auto&& typecvt = args.opr->handle()->create_operator<TypeCvt>();
auto f32_grad_layout = *args.grad_layout;
// modify data type
f32_grad_layout.modify_dtype_inplace(dtype::Float32());
TensorND src{args.workspace.raw_ptr, f32_grad_layout},
dst{args.grad_tensor->raw_ptr(), *args.grad_layout};
typecvt->exec(src, dst);
}
// vim: syntax=cpp.doxygen
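Note how the f16 algo above always produces an f32 grad: the CUTLASS kernel writes its result into an f32 workspace (get_workspace_in_bytes returns the f32-sized span of the grad layout) and a TypeCvt then converts it into the user's f16 grad tensor. A minimal sketch of that two-step output path (illustrative only, assuming numpy):

# Minimal sketch of the f16 wgrad output path above (illustrative only):
# the kernel stores an f32 result in a workspace, then a type conversion
# produces the final f16 grad tensor.
import numpy as np

def f16_wgrad_output(run_kernel_f32, grad_shape):
    workspace = np.empty(grad_shape, dtype=np.float32)   # f32 workspace
    workspace[...] = run_kernel_f32()                     # kernel output (f32)
    return workspace.astype(np.float16)                   # TypeCvt f32 -> f16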
/**
* \file
* dnn/src/cuda/convolution/backward_filter/implicit_batched_gemm_float32_nchw_fma.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/cuda/convolution/backward_filter/algo.h"
#include "src/cuda/cutlass/singleton.h"
#include "src/cuda/utils.h"
using namespace megdnn;
using namespace cuda;
using namespace cutlass::library;
const void* ConvolutionBackwardFilterImpl::AlgoFloat32NCHWFMAImplicitBatchedGemm::
get_available_op(const SizeArgs& args) const {
ConvolutionKey key{
cutlass::conv::Operator::kWgrad,
NumericTypeID::kF32, // src tensor data type
LayoutTypeID::kTensorNCHW, // src tensor layout
NumericTypeID::kF32, // diff tensor data type
LayoutTypeID::kTensorNCHW, // diff tensor layout
NumericTypeID::kF32, // grad tensor data type
LayoutTypeID::kTensorNCHW, // grad tensor layout
NumericTypeID::kF32, // dummy argument, not used.
LayoutTypeID::kTensorNCHW, // dummy argument, not used
NumericTypeID::kF32,
cutlass::conv::ConvType::kDepthwiseConvolution,
m_algo_param.threadblock_m,
m_algo_param.threadblock_n,
m_algo_param.threadblock_k,
m_algo_param.warp_m,
m_algo_param.warp_n,
m_algo_param.warp_k,
1,
1,
1,
cutlass::epilogue::EpilogueType::kLinearCombination, // no bias
m_algo_param.stage,
cutlass::conv::SpecialOptimizeDesc::NONE,
1,
1,
true};
return (void*)Singleton::get().operation_table.find_op(key);
}
bool ConvolutionBackwardFilterImpl::AlgoFloat32NCHWFMAImplicitBatchedGemm::is_available(
const SizeArgs& args) const {
#define RETURN_IF_FALSE(stmt_) \
if (!(stmt_)) \
return false;
RETURN_IF_FALSE(is_compute_capability_required(6, 1));
RETURN_IF_FALSE(
args.src_layout->is_contiguous() && args.diff_layout->is_contiguous() &&
args.grad_layout->is_contiguous());
using Param = param::Convolution;
using Format = Param::Format;
using Sparse = Param::Sparse;
using Mode = Param::Mode;
auto&& param = args.opr->param();
auto&& fm = args.grad_filter_meta;
RETURN_IF_FALSE(
param.format == Format::NCHW &&
args.src_layout->dtype.enumv() == DTypeEnum::Float32 &&
args.diff_layout->dtype.enumv() == DTypeEnum::Float32 &&
args.grad_layout->dtype.enumv() == DTypeEnum::Float32);
RETURN_IF_FALSE(param.sparse == Sparse::GROUP);
RETURN_IF_FALSE(param.mode == Mode::CROSS_CORRELATION);
// check if channelwise convolution
RETURN_IF_FALSE(fm.icpg == 1 && fm.ocpg == 1);
RETURN_IF_FALSE(param.dilate_h == 1 && param.dilate_w == 1);
const auto* op = get_available_op(args);
RETURN_IF_FALSE(op != nullptr);
return true;
#undef RETURN_IF_FALSE
}
void ConvolutionBackwardFilterImpl::AlgoFloat32NCHWFMAImplicitBatchedGemm::exec(
const ExecArgs& args) const {
auto&& param = args.opr->param();
auto&& fm = args.grad_filter_meta;
int hi = args.src_layout->operator[](2), wi = args.src_layout->operator[](3);
int n = args.diff_layout->operator[](0), ho = args.diff_layout->operator[](2),
wo = args.diff_layout->operator[](3);
int co = fm.group, ci = co, groups = co;
int fh = fm.spatial[0], fw = fm.spatial[1];
int sh = fm.stride[0], sw = fm.stride[1];
int ph = fm.padding[0], pw = fm.padding[1];
int dh = param.dilate_h, dw = param.dilate_w;
// check if channelwise convolution
megdnn_assert(fm.icpg == 1 && fm.ocpg == 1);
auto&& stream = cuda_stream(args.opr->handle());
float alpha = 1.f;
float beta = 0.f;
const Operation* op = (const Operation*)get_available_op(args);
cutlass::conv::Conv2dProblemSize problem_size{
n, hi, wi, ci, co, fh, fw, ho,
wo, ph, pw, sh, sw, dh, dw, cutlass::conv::Mode::kCrossCorrelation,
1, // split k slices, always 1
groups, // groups
};
cutlass::library::ConvolutionArguments conv_args{
problem_size,
args.src_tensor->raw_ptr(),
args.diff_tensor->raw_ptr(),
nullptr,
nullptr,
args.grad_tensor->raw_ptr(),
&alpha,
&beta,
nullptr,
nullptr,
nullptr,
nullptr,
nullptr,
nullptr};
cutlass_check(op->run(&conv_args, nullptr, stream));
after_kernel_launch();
}
// vim: syntax=cpp.doxygen
......@@ -116,15 +116,18 @@ ConvolutionBackwardDataImpl::Algorithm* ConvolutionBackwardDataImpl::
AlgoBase::SizeArgs args(this, filter, diff, grad);
//! choose for large kernel cases
size_t fh = args.filter_meta.spatial[2], fw = args.filter_meta.spatial[3];
size_t fh = args.filter_meta.spatial[0], fw = args.filter_meta.spatial[1];
size_t ho = diff[2], wo = diff[3];
const bool prefer_dnn_lk_implbmm = args.filter_meta.format == Param::Format::NCHW &&
ho <= 2 * fh && wo <= 2 * fw;
if (prefer_dnn_lk_implbmm) {
if (sm_algo_pack.implbmm_nchw_hmma.is_available_attribute(
#if CUDA_VERSION >= 10020
if (sm_algo_pack.implbmm_nchw_hmma[0].is_available_attribute(
args, positive_attr, negative_attr, workspace_limit_in_bytes))
return &sm_algo_pack.implbmm_nchw_hmma[0];
if (sm_algo_pack.implbmm_nchw_fma.is_available_attribute(args, positive_attr, negative_attr, workspace_limit_in_bytes))
#endif
if (sm_algo_pack.implbmm_nchw_fma[0].is_available_attribute(
args, positive_attr, negative_attr, workspace_limit_in_bytes))
return &sm_algo_pack.implbmm_nchw_fma[0];
}
......@@ -255,6 +258,23 @@ ConvolutionBackwardFilterImpl::Algorithm* ConvolutionBackwardFilterImpl::
const AlgoAttribute& negative_attr) {
AlgoBase::SizeArgs args(this, src, diff, grad);
//! choose for large kernel cases
size_t fh = args.grad_filter_meta.spatial[0], fw = args.grad_filter_meta.spatial[1];
size_t ho = diff[2], wo = diff[3];
const bool prefer_dnn_lk_implbmm =
args.grad_filter_meta.format == Param::Format::NCHW && ho <= 2 * fh &&
wo <= 2 * fw;
if (prefer_dnn_lk_implbmm) {
#if CUDA_VERSION >= 10020
if (sm_algo_pack.implbmm_nchw_hmma[0].is_available_attribute(
args, positive_attr, negative_attr, workspace_limit_in_bytes))
return &sm_algo_pack.implbmm_nchw_hmma[0];
#endif
if (sm_algo_pack.implbmm_nchw_fma[0].is_available_attribute(
args, positive_attr, negative_attr, workspace_limit_in_bytes))
return &sm_algo_pack.implbmm_nchw_fma[0];
}
if (args.grad_filter_meta.group > 1 &&
sm_algo_pack.chanwise.is_available_attribute(
args, positive_attr, negative_attr, workspace_limit_in_bytes)) {
......
......@@ -156,6 +156,8 @@ public:
class AlgoChanwise;
class AlgoGroupConvGeneral;
class AlgoBFloat16;
class AlgoFloat32NCHWFMAImplicitBatchedGemm;
class AlgoFloat16NCHWHMMAImplicitBatchedGemm;
class AlgoPack;
......
......@@ -135,6 +135,15 @@ namespace detail {
template <typename EpilogueOp, epilogue::EpilogueType type>
struct init_epilogue_param_;
template <typename EpilogueOp>
struct init_epilogue_param_<EpilogueOp, epilogue::EpilogueType::kLinearCombination> {
using ElementCompute = typename EpilogueOp::ElementCompute;
typename EpilogueOp::Params get(ConvolutionArguments const* conv_args) {
return {*static_cast<ElementCompute const*>(conv_args->alpha),
*static_cast<ElementCompute const*>(conv_args->beta)};
}
};
template <typename EpilogueOp>
struct init_epilogue_param_<
EpilogueOp, epilogue::EpilogueType::kBiasAddLinearCombination> {
......@@ -290,6 +299,159 @@ public:
///////////////////////////////////////////////////////////////////////////////////////////////////
/// We add a new template class to handle convolution backward filter operation, because
/// the device-level convolution operator of backward filter is different from the
/// others (convolution forward and convolution backward data).
/// But the description object is reused in this wrapper of convolution backward filter.
/// The reason is that we do not want to introduce another unnecessary structure.
/// TODO: Maybe the device-level operator in cutlass for convolution forward, backward
/// data and backward filter should be combined.
template <typename Operator_>
class ConvolutionBackwardFilterOperationBase : public Operation {
public:
using Operator = Operator_;
using ElementSrc = typename Operator::ElementSrc;
using LayoutSrc = typename Operator::LayoutSrc;
using ElementDiff = typename Operator::ElementDiff;
using LayoutDiff = typename Operator::LayoutDiff;
using ElementGrad = typename Operator::ElementGrad;
using LayoutGrad = typename Operator::LayoutGrad;
using ElementAccumulator = typename Operator::ElementAccumulator;
ConvolutionBackwardFilterOperationBase(char const* name = "unknown_convolution") {
m_description.name = name;
m_description.provider = Provider::kCUTLASS;
m_description.kind = OperationKind::kConvolution;
m_description.conv_op = Operator::kConvolutionalOperator;
m_description.tile_description.threadblock_shape = make_Coord(
Operator::ThreadblockShape::kM, Operator::ThreadblockShape::kN,
Operator::ThreadblockShape::kK);
m_description.tile_description.threadblock_stages = Operator::kStages;
m_description.tile_description.warp_count = make_Coord(
Operator::ConvolutionKernel::WarpCount::kM,
Operator::ConvolutionKernel::WarpCount::kN,
Operator::ConvolutionKernel::WarpCount::kK);
m_description.tile_description.math_instruction.instruction_shape = make_Coord(
Operator::InstructionShape::kM, Operator::InstructionShape::kN,
Operator::InstructionShape::kK);
m_description.tile_description.math_instruction.element_accumulator =
NumericTypeMap<ElementAccumulator>::kId;
m_description.tile_description.math_instruction.opcode_class =
OpcodeClassMap<typename Operator::OperatorClass>::kId;
m_description.tile_description.math_instruction.math_operation =
MathOperationMap<typename Operator::Operator>::kId;
m_description.tile_description.minimum_compute_capability =
ArchMap<typename Operator::ArchTag,
typename Operator::OperatorClass>::kMin;
m_description.tile_description.maximum_compute_capability =
ArchMap<typename Operator::ArchTag,
typename Operator::OperatorClass>::kMax;
/// src in description -> src in C++ template
m_description.src =
make_TensorDescription<ElementSrc, LayoutSrc>(Operator::kAlignmentSrc);
/// filter in description -> diff in C++ template
m_description.filter = make_TensorDescription<ElementDiff, LayoutDiff>(
Operator::kAlignmentDiff);
/// dst in description -> grad in C++ template
m_description.dst = make_TensorDescription<ElementGrad, LayoutGrad>(
Operator::kAlignmentGrad);
/// because the bias tensor is not used in the ConvolutionBackwardFilter operation, the
/// following tensor description is a dummy argument
m_description.bias = make_TensorDescription<ElementGrad, LayoutGrad>(
Operator::kAlignmentGrad);
m_description.convolution_type = Operator::kConvolutionType;
m_description.arch_tag = ArchTagMap<typename Operator::ArchTag>::kId;
m_description.epilogue_type = Operator::EpilogueOutputOp::kType;
m_description.epilogue_count = Operator::EpilogueOutputOp::kCount;
m_description.threadblock_swizzle =
ThreadblockSwizzleMap<typename Operator::ThreadblockSwizzle>::kId;
m_description.special_optimization = Operator::kSpecialOpt;
m_description.gemm_mode = Operator::kGemmMode;
/// ConvolutionBackwardFilter operation is only used for depthwise convolution,
/// so the option without_shared_load is always true
m_description.without_shared_load = true;
}
virtual OperationDescription const& description() const { return m_description; }
protected:
ConvolutionDescription m_description;
};
///////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Operator_>
class ConvolutionBackwardFilterOperation
: public ConvolutionBackwardFilterOperationBase<Operator_> {
public:
using Operator = Operator_;
using ElementSrc = typename Operator::ElementSrc;
using LayoutSrc = typename Operator::LayoutSrc;
using ElementDiff = typename Operator::ElementDiff;
using LayoutDiff = typename Operator::LayoutDiff;
using ElementGrad = typename Operator::ElementGrad;
using LayoutGrad = typename Operator::LayoutGrad;
using ElementAccumulator = typename Operator::ElementAccumulator;
using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute;
using OperatorArguments = typename Operator::Arguments;
ConvolutionBackwardFilterOperation(char const* name = "unknown_gemm")
: ConvolutionBackwardFilterOperationBase<Operator_>(name) {}
virtual Status run(
void const* arguments_ptr, void* device_workspace = nullptr,
cudaStream_t stream = nullptr) const {
cutlass::conv::Operator conv_op = this->m_description.conv_op;
ConvolutionArguments const* conv_args =
reinterpret_cast<ConvolutionArguments const*>(arguments_ptr);
const auto& ps = conv_args->problem_size;
OperatorArguments args;
args.problem_size = ps;
/// src in convolution arguments -> ref_src
args.ref_src = {
static_cast<ElementSrc*>(const_cast<void*>(conv_args->src)),
LayoutSrc::packed(implicit_gemm_tensor_b_extent(conv_op, ps))};
/// filter in convolution arguments -> ref_diff
args.ref_diff = {
static_cast<ElementDiff*>(const_cast<void*>(conv_args->filter)),
LayoutDiff::packed(implicit_gemm_tensor_a_extent(conv_op, ps))};
/// dst in convolution arguments -> ref_grad
args.ref_grad = {
static_cast<ElementGrad*>(conv_args->dst),
LayoutGrad::packed(implicit_gemm_tensor_c_extent(conv_op, ps))};
args.output_op = init_epilogue_param<typename Operator::EpilogueOutputOp>().get(
conv_args);
Operator op;
Status status = op.initialize(args, device_workspace);
if (status != Status::kSuccess) {
return status;
}
return op.run(stream);
}
};
///////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace library
} // namespace cutlass
......
......@@ -44,6 +44,11 @@ namespace cutlass {
namespace library {
/////////////////////////////////////////////////////////////////////////////////////////////////
#if ((__CUDACC_VER_MAJOR__ > 10) || \
(__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 1))
#define CUTLASS_ARCH_MMA_SM70_SUPPORTED 1
#endif
#if ((__CUDACC_VER_MAJOR__ > 10) || \
(__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2))
#define CUTLASS_ARCH_MMA_SM75_SUPPORTED 1
......@@ -56,14 +61,18 @@ void initialize_all_conv2d_simt_operations(Manifest& manifest);
void initialize_all_deconv_simt_operations(Manifest& manifest);
void initialize_all_dwconv2d_fprop_simt_operations(Manifest& manifest);
void initialize_all_dwconv2d_dgrad_simt_operations(Manifest& manifest);
#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) && CUTLASS_ARCH_MMA_SM75_SUPPORTED
void initialize_all_dwconv2d_wgrad_simt_operations(Manifest& manifest);
#if defined(CUTLASS_ARCH_MMA_SM70_SUPPORTED) && CUTLASS_ARCH_MMA_SM70_SUPPORTED
void initialize_all_gemm_tensorop884_operations(Manifest& manifest);
void initialize_all_dwconv2d_fprop_tensorop884_operations(Manifest& manifest);
void initialize_all_dwconv2d_dgrad_tensorop884_operations(Manifest& manifest);
void initialize_all_dwconv2d_wgrad_tensorop884_operations(Manifest& manifest);
#endif
#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) && CUTLASS_ARCH_MMA_SM75_SUPPORTED
void initialize_all_gemm_tensorop1688_operations(Manifest& manifest);
void initialize_all_conv2d_tensorop8816_operations(Manifest& manifest);
void initialize_all_conv2d_tensorop8832_operations(Manifest& manifest);
void initialize_all_deconv_tensorop8816_operations(Manifest& manifest);
void initialize_all_dwconv2d_fprop_tensorop884_operations(Manifest& manifest);
void initialize_all_dwconv2d_dgrad_tensorop884_operations(Manifest& manifest);
#endif
void initialize_all(Manifest& manifest) {
......@@ -72,14 +81,18 @@ void initialize_all(Manifest& manifest) {
initialize_all_deconv_simt_operations(manifest);
initialize_all_dwconv2d_fprop_simt_operations(manifest);
initialize_all_dwconv2d_dgrad_simt_operations(manifest);
#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) && CUTLASS_ARCH_MMA_SM75_SUPPORTED
initialize_all_dwconv2d_wgrad_simt_operations(manifest);
#if defined(CUTLASS_ARCH_MMA_SM70_SUPPORTED) && CUTLASS_ARCH_MMA_SM70_SUPPORTED
initialize_all_gemm_tensorop884_operations(manifest);
initialize_all_dwconv2d_fprop_tensorop884_operations(manifest);
initialize_all_dwconv2d_dgrad_tensorop884_operations(manifest);
initialize_all_dwconv2d_wgrad_tensorop884_operations(manifest);
#endif
#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) && CUTLASS_ARCH_MMA_SM75_SUPPORTED
initialize_all_gemm_tensorop1688_operations(manifest);
initialize_all_conv2d_tensorop8816_operations(manifest);
initialize_all_conv2d_tensorop8832_operations(manifest);
initialize_all_deconv_tensorop8816_operations(manifest);
initialize_all_dwconv2d_fprop_tensorop884_operations(manifest);
initialize_all_dwconv2d_dgrad_tensorop884_operations(manifest);
#endif
}
......
......@@ -279,7 +279,6 @@ struct ConvolutionKey {
struct ConvolutionKeyHasher {
inline size_t operator()(ConvolutionKey const& key) const {
return Hash()
.update(&key.conv_op, sizeof(key.conv_op))
.update(&key.conv_op, sizeof(key.conv_op))
.update(&key.element_src, sizeof(key.element_src))
.update(&key.layout_src, sizeof(key.layout_src))
......
......@@ -1322,6 +1322,8 @@ static struct {
{"batch_convolution", "BatchConvolution", conv::ConvType::kBatchConvolution},
{"local", "Local", conv::ConvType::kLocal},
{"local_share", "LocalShare", conv::ConvType::kLocalShare},
{"depthwise_convolution", "DepthwiseConvolution",
conv::ConvType::kDepthwiseConvolution},
};
/// Converts a ConvType enumerant to a string
......
......@@ -44,7 +44,7 @@ MatrixMulForwardImpl::AlgoPack::AlgoPack() {
for (auto&& algo : simt_float32_gemv_batched_strided) {
all_algos.push_back(&algo);
}
#if CUDA_VERSION >= 10020
#if CUDA_VERSION >= 10010
for (auto&& algo : tensorop_float16) {
all_algos.push_back(&algo);
}
......@@ -113,21 +113,26 @@ void MatrixMulForwardImpl::AlgoPack::fill_cutlass_algos() {
simt_float32_gemv_batched_strided.emplace_back(128);
simt_float32_gemv_batched_strided.emplace_back(64);
simt_float32_gemv_batched_strided.emplace_back(32);
#define FOREACH_CUTLASS_MATMUL_F16_SHAPES(cb) \
#define FOREACH_CUTLASS_MATMUL_MMA_SM70_SHAPES(cb) \
cb(256, 128, 32, 64, 64, 32, 8, 8, 4); \
cb(128, 256, 32, 64, 64, 32, 8, 8, 4); \
cb(128, 128, 32, 64, 64, 32, 8, 8, 4); \
cb(128, 128, 32, 64, 64, 32, 8, 8, 4);
#define FOREACH_CUTLASS_MATMUL_MMA_SM75_SHAPES(cb) \
cb(256, 128, 32, 64, 64, 32, 16, 8, 8); \
cb(128, 256, 32, 64, 64, 32, 16, 8, 8); \
cb(128, 128, 32, 64, 64, 32, 16, 8, 8);
#define cb(...) \
tensorop_float16.emplace_back(AlgoParam{__VA_ARGS__}); \
tensorop_float16_split_k.emplace_back(AlgoParam{__VA_ARGS__});
#if CUDA_VERSION >= 10010
FOREACH_CUTLASS_MATMUL_MMA_SM70_SHAPES(cb)
#endif
#if CUDA_VERSION >= 10020
FOREACH_CUTLASS_MATMUL_F16_SHAPES(cb)
FOREACH_CUTLASS_MATMUL_MMA_SM75_SHAPES(cb)
#endif
#undef cb
#undef FOREACH_CUTLASS_MATMUL_F16_SHAPES
#undef FOREACH_CUTLASS_MATMUL_MMA_SM70_SHAPES
#undef FOREACH_CUTLASS_MATMUL_MMA_SM75_SHAPES
}
#endif
......
......@@ -350,7 +350,7 @@ private:
std::string m_name;
};
#if CUDA_VERSION >= 10020
#if CUDA_VERSION >= 10010
class MatrixMulForwardImpl::AlgoFloat16TensorOp final
: public AlgoCutlassMatrixMulBase {
public:
......@@ -418,7 +418,7 @@ public:
std::vector<AlgoFloat32SIMT> simt_float32;
std::vector<AlgoFloat32SIMTSplitK> simt_float32_split_k;
std::vector<AlgoFloat32SIMTGemvBatchedStrided> simt_float32_gemv_batched_strided;
#if CUDA_VERSION >= 10020
#if CUDA_VERSION >= 10010
std::vector<AlgoFloat16TensorOp> tensorop_float16;
std::vector<AlgoFloat16TensorOpSplitK> tensorop_float16_split_k;
#endif
......
......@@ -15,7 +15,7 @@
#include "src/cuda/matrix_mul/algos.h"
#include "src/cuda/utils.h"
#if CUDA_VERSION >= 10020
#if CUDA_VERSION >= 10010
using namespace megdnn;
using namespace cuda;
......
......@@ -15,7 +15,7 @@
#include "src/cuda/matrix_mul/algos.h"
#include "src/cuda/utils.h"
#if CUDA_VERSION >= 10020
#if CUDA_VERSION >= 10010
using namespace megdnn;
using namespace cuda;
......
......@@ -46,8 +46,10 @@ public:
class AlgoFloat32SIMT;
class AlgoFloat32SIMTSplitK;
class AlgoFloat32SIMTGemvBatchedStrided;
#if CUDA_VERSION >= 10010
class AlgoFloat16TensorOp;
class AlgoFloat16TensorOpSplitK;
#endif
#endif
class AlgoPack;
......
......@@ -494,6 +494,21 @@ void check_chanwise(DType io_type, DType comp_type, Handle* handle, const char*
checker.set_param(gconv_param({M, 7, 7, 2, 2}, io16xc32))
.execs({{2, 1, 1, 15, 15}, {8, 2, 7, 7}, {8, 2, 14, 14}});
} else if (std::is_same<Op, ConvolutionBackwardFilter>::value) {
// align 8
checker.set_param(gconv_param({M, 7, 7, 1, 1}, io16xc32))
.execs({{8, 2, 16, 16}, {8, 2, 16, 16}, {2, 1, 1, 15, 15}});
// align 1
checker.set_param(gconv_param({M, 7, 7, 1, 1}, io16xc32))
.execs({{8, 2, 15, 15}, {8, 2, 15, 15}, {2, 1, 1, 15, 15}});
// align 2
checker.set_param(gconv_param({M, 7, 7, 1, 1}, io16xc32))
.execs({{8, 2, 14, 14}, {8, 2, 14, 14}, {2, 1, 1, 15, 15}});
// custom padding
checker.set_param(gconv_param({M, 3, 3, 1, 1}, io16xc32))
.execs({{8, 2, 16, 16}, {8, 2, 8, 8}, {2, 1, 1, 15, 15}});
// custom stride
checker.set_param(gconv_param({M, 7, 7, 2, 2}, io16xc32))
.execs({{8, 2, 14, 14}, {8, 2, 7, 7}, {2, 1, 1, 15, 15}});
}
}
} // namespace
......@@ -535,14 +550,32 @@ MEGDNN_FOREACH_CUTLASS_CHANWISE_CONV_FMA_KERNEL(cb)
#undef cb
#define cb(tag, tbm, tbn, tbk, wm, wn, wk) \
TEST_F(CUDA, CHANWISE_CONVOLUTION_BACKWARD_FILTER_CUTLASS_FMA_##tag) { \
require_compute_capability(6, 1); \
check_chanwise<ConvolutionBackwardFilter>( \
dtype::Float32(), dtype::Float32(), handle_cuda(), \
"FLOAT32_NCHW_FMA_IMPLICIT_BATCHED_GEMM_" #tbm "X" #tbn "X" #tbk \
"_" #wm "X" #wn "X" #wk "_2stage"); \
}
MEGDNN_FOREACH_CUTLASS_CHANWISE_CONV_FMA_KERNEL(cb)
#undef cb
#undef MEGDNN_FOREACH_CUTLASS_CHANWISE_CONV_FMA_KERNEL
#if CUDA_VERSION >= 10010
#define MEGDNN_FOREACH_CUTLASS_CHANWISE_CONV_HMMA_KERNEL(cb) \
cb(1, 128, 128, 32, 32, 32, 32); \
cb(2, 128, 256, 32, 64, 64, 32); \
cb(3, 128, 64, 32, 32, 32, 32); \
cb(4, 64, 128, 32, 32, 32, 32); \
cb(5, 64, 64, 32, 32, 32, 32);
#else
// hmma instructions need cuda version >= 10.1, disable hmma testcases in this path
#define MEGDNN_FOREACH_CUTLASS_CHANWISE_CONV_HMMA_KERNEL(cb)
#endif
// check both ioc16 and io16xc32
#define cb(tag, tbm, tbn, tbk, wm, wn, wk) \
......@@ -579,6 +612,19 @@ MEGDNN_FOREACH_CUTLASS_CHANWISE_CONV_HMMA_KERNEL(cb)
#undef cb
#define cb(tag, tbm, tbn, tbk, wm, wn, wk) \
TEST_F(CUDA, CHANWISE_CONVOLUTION_BACKWARD_FILTER_CUTLASS_HMMA_##tag) { \
require_compute_capability(7, 0); \
check_chanwise<ConvolutionBackwardFilter>( \
dtype::Float16(), dtype::Float32(), handle_cuda(), \
"FLOAT16_NCHW_HMMA_IMPLICIT_BATCHED_GEMM_" #tbm "X" #tbn "X" #tbk \
"_" #wm "X" #wn "X" #wk "_mma8X8X4_2stage"); \
}
MEGDNN_FOREACH_CUTLASS_CHANWISE_CONV_HMMA_KERNEL(cb)
#undef cb
#undef MEGDNN_FOREACH_CUTLASS_CHANWISE_CONV_HMMA_KERNEL
#if MEGDNN_WITH_BENCHMARK
......@@ -1434,6 +1480,77 @@ TEST_F(CUDA, BENCHMARK_CHANWISE_CONV_BACKWARD_DATA_LARGE_KERNEL) {
}
// clang-format on
}
TEST_F(CUDA, BENCHMARK_CHANWISE_CONV_BACKWARD_FILTER_LARGE_KERNEL) {
CUBenchmarker<ConvolutionBackwardFilter> bencher(handle_cuda());
size_t RUNS = 100;
bencher.set_display(false).set_times(RUNS);
std::unique_ptr<OprProxy<ConvolutionBackwardFilter>> proxy{
new OprProxy<ConvolutionBackwardFilter>{true}};
bencher.set_proxy(proxy);
Convolution::Param param;
param.format = ConvBias::Param::Format::NCHW;
param.sparse = Convolution::Param::Sparse::GROUP;
NormalRNG rng;
auto run = [&](size_t batch, size_t c, size_t ih, size_t iw, size_t f, size_t s) {
param.pad_h = f / 2;
param.pad_w = f / 2;
param.stride_h = s;
param.stride_w = s;
param.compute_mode = param::Convolution::ComputeMode::DEFAULT;
TensorShape src = {batch, c, ih, iw}, filter = {c, 1, 1, f, f};
TensorLayout dst_layout;
auto opr = handle_cuda()->create_operator<Convolution>();
opr->param() = param;
opr->deduce_layout(
{src, dtype::Float32()}, {filter, dtype::Float32()}, dst_layout);
float bandwith = static_cast<float>(
src.total_nr_elems() + filter.total_nr_elems() +
dst_layout.total_nr_elems()) /
(1024 * 1024 * 1024) * 1e3;
bencher.set_param(param)
.set_dtype(0, dtype::Float32())
.set_dtype(1, dtype::Float32())
.set_dtype(2, dtype::Float32())
.set_rng(0, &rng)
.set_rng(1, &rng);
bencher.proxy()->target_execution_policy = {};
auto time_in_ms_fp32 = bencher.execs({src, src, filter}) / RUNS;
bencher.set_param(param)
.set_dtype(0, dtype::Float16())
.set_dtype(1, dtype::Float16())
.set_dtype(2, dtype::Float16())
.set_rng(0, &rng)
.set_rng(1, &rng);
bencher.proxy()->target_execution_policy = {};
param.compute_mode = param::Convolution::ComputeMode::FLOAT32;
bencher.set_param(param);
auto time_in_ms_pseudo_fp16 = bencher.execs({src, src, filter}) / RUNS;
printf("stride=%zu src=%s, filter=%s, float32: %.2fms %.2fGB/s "
"pseudo float16: %.2fms %.2fGB/s "
"speedup: "
"%0.2f (fp16/fp32) \n",
s, src.to_string().c_str(), filter.to_string().c_str(), time_in_ms_fp32,
bandwith * 4 / time_in_ms_fp32, time_in_ms_pseudo_fp16,
bandwith * 2 / time_in_ms_pseudo_fp16,
time_in_ms_fp32 / time_in_ms_pseudo_fp16);
};
// clang-format off
for (size_t b : {32, 64})
for (size_t f : {3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31}) {
run(b, 384, 32, 32, f, 1);
run(b, 384, 64, 64, f, 1);
}
// clang-format on
}
#endif
// vim: syntax=cpp.doxygen
......@@ -1093,8 +1093,11 @@ TEST_F(CUDA, CONV_BIAS_FORWARD_GROUP) {
run(2, 32, 7, 7, 3, 3, 64, 1, 1, 1, 1, 1, 1, 4, nlmode);
// strided case
run(2, 32, 7, 7, 3, 3, 64, 0, 0, 2, 2, 1, 1, 8, nlmode);
// dilated conv is supported in CUDNN since version 7.5.0
#if CUDNN_VERSION >= 7500
// dilated case
run(2, 32, 7, 7, 3, 3, 64, 0, 0, 1, 1, 2, 2, 8, nlmode);
#endif
}
}
......
......@@ -213,7 +213,7 @@ std::vector<BenchArgs> get_feat_model_args() {
return args;
}
#if CUDA_VERSION >= 10020
#if CUDA_VERSION >= 10010
std::vector<BenchArgs> get_f16_feat_model_args() {
std::vector<BenchArgs> args;
args.emplace_back(BenchArgs{128, 9216, 9216});
......@@ -367,7 +367,7 @@ MEGDNN_FOREACH_CUTLASS_KERNEL(cb)
#undef cb
#undef MEGDNN_FOREACH_CUTLASS_KERNEL
#if CUDA_VERSION >= 10020
#if CUDA_VERSION >= 10010
#define MEGDNN_FOREACH_CUTLASS_KERNEL(cb) \
cb(1, 256, 128, 32, 64, 64, 32, 8, 8, 4); \
cb(2, 128, 256, 32, 64, 64, 32, 8, 8, 4); \
......@@ -403,7 +403,9 @@ MEGDNN_FOREACH_CUTLASS_KERNEL(cb)
#undef cb
#undef MEGDNN_FOREACH_CUTLASS_KERNEL
#endif
#if CUDA_VERSION >= 10020
#define MEGDNN_FOREACH_CUTLASS_KERNEL(cb) \
cb(1, 256, 128, 32, 64, 64, 32, 16, 8, 8); \
cb(2, 128, 256, 32, 64, 64, 32, 16, 8, 8); \
......@@ -454,7 +456,7 @@ TEST_F(CUDA, BENCHMARK_CUTLASS_MATMUL_FEAT) {
dtype::Float32(), "CUTLASS_FLOAT32_SIMT");
}
#if CUDA_VERSION >= 10020
#if CUDA_VERSION >= 10010
TEST_F(CUDA, BENCHMARK_CUTLASS_F16_MATMUL_FEAT) {
benchmark_matrix_mul(
handle_cuda(), get_f16_feat_model_args(), dtype::Float16(),
......