提交 b18feaab 编写于 作者: M Megvii Engine Team

feat(dnn/cuda): use cutlass remove shared load imma conv kernel

GitOrigin-RevId: 0b5574f52669ba88237967d3f00fceca0857b80f
上级 6b843ccd
......@@ -20,7 +20,7 @@ class Conv2dOperation:
#
def __init__(self, conv_kind, conv_type, arch, tile_description, src, flt, bias, dst, element_epilogue, \
epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity4, \
need_load_from_const = True, implicit_gemm_mode = ImplicitGemmMode.GemmNt):
need_load_from_const = True, implicit_gemm_mode = ImplicitGemmMode.GemmNT, without_shared_load = False):
self.operation_kind = OperationKind.Conv2d
self.conv_kind = conv_kind
......@@ -36,6 +36,7 @@ class Conv2dOperation:
self.swizzling_functor = swizzling_functor
self.need_load_from_const = need_load_from_const
self.implicit_gemm_mode = implicit_gemm_mode
self.without_shared_load = without_shared_load
#
def accumulator_type(self):
accum = self.tile_description.math_instruction.element_accumulator
......@@ -58,11 +59,15 @@ class Conv2dOperation:
unity_kernel = ''
if not self.need_load_from_const:
unity_kernel = '_1x1'
unity_kernel = '_1x1'
return "%s%s%s%s%s_%s" % (ShortDataTypeNames[self.accumulator_type()], \
reorder_k = ''
if self.without_shared_load:
reorder_k = '_roc'
return "%s%s%s%s%s%s_%s" % (ShortDataTypeNames[self.accumulator_type()], \
inst_shape, intermediate_type, ConvKindNames[self.conv_kind], unity_kernel, \
ShortEpilogueNames[self.epilogue_functor])
reorder_k, ShortEpilogueNames[self.epilogue_functor])
#
def extended_name(self):
......@@ -177,7 +182,8 @@ using Convolution =
${alignment_filter},
${nonuninity_kernel},
${math_operator},
${implicit_gemm_mode}>;
${implicit_gemm_mode},
${without_shared_load}>;
"""
......@@ -219,7 +225,8 @@ using Convolution =
'alignment_filter': str(operation.flt.alignment),
'nonuninity_kernel': str(operation.need_load_from_const).lower(),
'math_operator': MathOperationTag[operation.tile_description.math_instruction.math_operation],
'implicit_gemm_mode': ImplicitGemmModeTag[operation.implicit_gemm_mode]
'implicit_gemm_mode': ImplicitGemmModeTag[operation.implicit_gemm_mode],
'without_shared_load': str(operation.without_shared_load).lower()
}
return SubstituteTemplate(self.template, values)
......@@ -312,13 +319,13 @@ using Deconvolution =
#
def GenerateConv2d(conv_kind, tile_descriptions, src_layout, flt_layout, dst_layout, dst_type, min_cc, src_align = 32, flt_align = 32, dst_align = 128, \
skip_unity_kernel = False, implicit_gemm_mode = ImplicitGemmMode.GemmNt):
skip_unity_kernel = False, implicit_gemm_mode = ImplicitGemmMode.GemmNT, without_shared_load = False):
operations = []
element_epilogue = DataType.f32
if conv_kind == ConvKind.Fprop:
if src_layout == LayoutType.TensorNHWC:
swizzling_functor = SwizzlingFunctor.ConvFpropNHWC
if implicit_gemm_mode == ImplicitGemmMode.GemmTN:
swizzling_functor = SwizzlingFunctor.ConvFpropTrans
else:
swizzling_functor = SwizzlingFunctor.ConvFpropNCxHWx
else:
......@@ -399,10 +406,10 @@ def GenerateConv2d(conv_kind, tile_descriptions, src_layout, flt_layout, dst_lay
bias = TensorDescription(bias_type, dst_layout, max(1, int(32 / DataTypeSize[bias_type])))
dst = TensorDescription(dst_type, dst_layout, int(dst_align / DataTypeSize[dst_type]))
new_operation = Conv2dOperation(conv_kind, ConvType.Convolution, min_cc, tile, src, flt, bias, dst, element_epilogue, epilogue, swizzling_functor, True, implicit_gemm_mode)
new_operation = Conv2dOperation(conv_kind, ConvType.Convolution, min_cc, tile, src, flt, bias, dst, element_epilogue, epilogue, swizzling_functor, True, implicit_gemm_mode, without_shared_load)
operations.append(new_operation)
if not skip_unity_kernel:
new_operation = Conv2dOperation(conv_kind, ConvType.Convolution, min_cc, tile, src, flt, bias, dst, element_epilogue, epilogue, swizzling_functor, False, implicit_gemm_mode)
new_operation = Conv2dOperation(conv_kind, ConvType.Convolution, min_cc, tile, src, flt, bias, dst, element_epilogue, epilogue, swizzling_functor, False, implicit_gemm_mode, without_shared_load)
operations.append(new_operation)
return operations
......
......@@ -175,12 +175,10 @@ def GenerateConv2d_Simt(args):
TileDescription([128, 128, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc),
TileDescription([128, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 128, 32], 2, [1, 4, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 64, 32], 2, [1, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 32, 32], 2, [2, 1, 1], math_inst, min_cc, max_cc),
TileDescription([ 32, 128, 32], 2, [1, 2, 1], math_inst, min_cc, max_cc),
TileDescription([ 32, 64, 32], 2, [1, 1, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 32, 32], 2, [1, 1, 1], math_inst, min_cc, max_cc),
TileDescription([ 32, 32, 32], 2, [1, 1, 1], math_inst, min_cc, max_cc),
TileDescription([ 16, 128, 16], 1, [1, 1, 1], math_inst, min_cc, max_cc),
TileDescription([ 16, 64, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc),
]
......@@ -223,28 +221,36 @@ def GenerateConv2d_TensorOp_8816(args):
for dst_type, dst_layout in zip(dst_types, dst_layouts):
if dst_layout == LayoutType.TensorNC32HW32:
tile_descriptions = [
TileDescription([256, 128, 64], 2, [4, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 256, 64], 2, [2, 4, 1], math_inst, min_cc, max_cc),
TileDescription([256, 128, 64], 2, [4, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 64, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 64, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([ 32, 64, 64], 2, [1, 4, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 64, 32], 1, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 32, 32], 1, [2, 1, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 128, 32], 1, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([ 32, 128, 32], 1, [1, 2, 1], math_inst, min_cc, max_cc),
]
operations += GenerateConv2d(ConvKind.Fprop, tile_descriptions, layout[0], layout[1],
dst_layout, dst_type, min_cc, 128, 128, 64,
False, ImplicitGemmMode.GemmTN, True)
else:
assert dst_layout == LayoutType.TensorNC4HW4
tile_descriptions = [
TileDescription([256, 128, 64], 2, [4, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 256, 64], 2, [2, 4, 1], math_inst, min_cc, max_cc),
TileDescription([256, 128, 64], 2, [4, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 64, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 64, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([ 32, 64, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 64, 32], 1, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 32, 32], 1, [2, 1, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 128, 32], 1, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([ 32, 128, 32], 1, [1, 2, 1], math_inst, min_cc, max_cc),
]
operations += GenerateConv2d(ConvKind.Fprop, tile_descriptions, layout[0], layout[1],
operations += GenerateConv2d(ConvKind.Fprop, tile_descriptions, layout[0], layout[1],
dst_layout, dst_type, min_cc, 128, 128, 64,
False)
return operations
def GenerateConv2d_TensorOp_8832(args):
......@@ -279,12 +285,14 @@ def GenerateConv2d_TensorOp_8832(args):
for dst_layout in dst_layouts:
dst_type = math_inst.element_b
tile_descriptions = [
TileDescription([256, 128, 128], 2, [4, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 256, 128], 2, [2, 4, 1], math_inst, min_cc, max_cc),
TileDescription([128, 128, 128], 2, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 64, 128], 2, [2, 1, 1], math_inst, min_cc, max_cc),
TileDescription([128, 64, 64], 1, [2, 1, 1], math_inst, min_cc, max_cc),
]
operations += GenerateConv2d(ConvKind.Fprop, tile_descriptions, layout[0], layout[1],
dst_layout, dst_type, min_cc, 128, 128, 64,
True)
True, ImplicitGemmMode.GemmTN, True)
layouts_nhwc = [
(LayoutType.TensorNHWC, LayoutType.TensorNC8HW8, 32),
......@@ -299,14 +307,21 @@ def GenerateConv2d_TensorOp_8832(args):
for math_inst in math_instructions:
for layout in layouts_nhwc:
for dst_layout in dst_layouts_nhwc:
dst_type = math_inst.element_b
tile_descriptions = [
TileDescription([128, 32, 64], 2, [2, 1, 1], math_inst, min_cc, max_cc),
TileDescription([128, 64, 64], 2, [2, 1, 1], math_inst, min_cc, max_cc),
]
operations += GenerateConv2d(ConvKind.Fprop, tile_descriptions, layout[0], layout[1],
dst_layout, dst_type, min_cc, layout[2], layout[2], 32,
False, ImplicitGemmMode.GemmTn)
dst_type = math_inst.element_b
tile_descriptions = [
TileDescription([128, 32, 64], 1, [2, 1, 1], math_inst, min_cc, max_cc),
TileDescription([128, 64, 64], 1, [2, 1, 1], math_inst, min_cc, max_cc),
]
for tile in tile_descriptions:
operations += GenerateConv2d(ConvKind.Fprop, [tile], layout[0], layout[1],
dst_layout, dst_type, min_cc, layout[2], layout[2], 32,
False, ImplicitGemmMode.GemmTN, False)
if tile.threadblock_shape[1] == 32 or tile.threadblock_shape[1] == 64:
dst_align = 32 if tile.threadblock_shape[1] == 32 else 64
operations += GenerateConv2d(ConvKind.Fprop, [tile], layout[0], layout[1],
dst_layout, dst_type, min_cc, layout[2], layout[2], dst_align,
False, ImplicitGemmMode.GemmTN, True)
return operations
def GenerateDeconv_Simt(args):
......@@ -649,3 +664,4 @@ if __name__ == "__main__":
#
###################################################################################################
\ No newline at end of file
......@@ -464,10 +464,10 @@ EpilogueFunctorTag = {
ShortEpilogueNames = {
EpilogueFunctor.BiasAddLinearCombinationHSwishClamp: 'hswish',
EpilogueFunctor.BiasAddLinearCombinationReluClamp: 'relu',
EpilogueFunctor.BiasAddLinearCombinationClamp: 'identity',
EpilogueFunctor.BiasAddLinearCombinationClamp: 'id',
EpilogueFunctor.BiasAddLinearCombinationHSwish: 'hswish',
EpilogueFunctor.BiasAddLinearCombinationRelu: 'relu',
EpilogueFunctor.BiasAddLinearCombination: 'identity',
EpilogueFunctor.BiasAddLinearCombination: 'id',
}
......@@ -482,7 +482,7 @@ class SwizzlingFunctor(enum.Enum):
Identity4 = enum_auto()
Identity8 = enum_auto()
ConvFpropNCxHWx = enum_auto()
ConvFpropNHWC = enum_auto()
ConvFpropTrans = enum_auto()
ConvDgradNCxHWx = enum_auto()
#
......@@ -492,7 +492,7 @@ SwizzlingFunctorTag = {
SwizzlingFunctor.Identity4: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>',
SwizzlingFunctor.Identity8: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>',
SwizzlingFunctor.ConvFpropNCxHWx: 'cutlass::conv::threadblock::ConvolutionFpropNCxHWxThreadblockSwizzle',
SwizzlingFunctor.ConvFpropNHWC: 'cutlass::conv::threadblock::ConvolutionFpropNHWCThreadblockSwizzle',
SwizzlingFunctor.ConvFpropTrans: 'cutlass::conv::threadblock::ConvolutionFpropTransThreadblockSwizzle',
SwizzlingFunctor.ConvDgradNCxHWx: 'cutlass::conv::threadblock::ConvolutionDgradNCxHWxThreadblockSwizzle',
}
......@@ -563,17 +563,17 @@ StrideSupportNames = {
}
class ImplicitGemmMode(enum.Enum):
GemmNt = enum_auto()
GemmTn = enum_auto()
GemmNT = enum_auto()
GemmTN = enum_auto()
ImplicitGemmModeNames = {
ImplicitGemmMode.GemmNt: 'gemm_nt',
ImplicitGemmMode.GemmTn: 'gemm_tn',
ImplicitGemmMode.GemmNT: 'gemm_nt',
ImplicitGemmMode.GemmTN: 'gemm_tn',
}
ImplicitGemmModeTag = {
ImplicitGemmMode.GemmNt: 'cutlass::conv::ImplicitGemmMode::GEMM_NT',
ImplicitGemmMode.GemmTn: 'cutlass::conv::ImplicitGemmMode::GEMM_TN',
ImplicitGemmMode.GemmNT: 'cutlass::conv::ImplicitGemmMode::GEMM_NT',
ImplicitGemmMode.GemmTN: 'cutlass::conv::ImplicitGemmMode::GEMM_TN',
}
###################################################################################################
......
......@@ -217,56 +217,68 @@ void ConvBiasForwardImpl::AlgoPack::fill_imma_algos() {
#if CUDA_VERSION >= 10020
{
using AlgoParam = AlgoInt8NCHW32IMMAImplicitGemm::AlgoParam;
int8_nchw32_imma.emplace_back(AlgoParam{128, 256, 64, 64, 64, 64});
int8_nchw32_imma.emplace_back(AlgoParam{256, 128, 64, 64, 64, 64});
int8_nchw32_imma.emplace_back(AlgoParam{128, 128, 64, 64, 64, 64});
int8_nchw32_imma.emplace_back(AlgoParam{64, 128, 64, 32, 64, 64});
int8_nchw32_imma.emplace_back(AlgoParam{128, 64, 64, 64, 32, 64});
int8_nchw32_imma.emplace_back(AlgoParam{64, 64, 64, 32, 32, 64});
int8_nchw32_imma.emplace_back(AlgoParam{32, 64, 64, 32, 16, 64});
int8_nchw32_imma.emplace_back(AlgoParam{128, 256, 64, 64, 64, 64, 2});
int8_nchw32_imma.emplace_back(AlgoParam{256, 128, 64, 64, 64, 64, 2});
int8_nchw32_imma.emplace_back(AlgoParam{128, 128, 64, 64, 64, 64, 2});
int8_nchw32_imma.emplace_back(AlgoParam{128, 64, 64, 64, 32, 64, 2});
int8_nchw32_imma.emplace_back(AlgoParam{64, 128, 64, 32, 64, 64, 2});
int8_nchw32_imma.emplace_back(AlgoParam{128, 64, 32, 64, 32, 32, 1});
int8_nchw32_imma.emplace_back(AlgoParam{128, 32, 32, 64, 32, 32, 1});
int8_nchw32_imma.emplace_back(AlgoParam{64, 128, 32, 32, 64, 32, 1});
int8_nchw32_imma.emplace_back(AlgoParam{32, 128, 32, 32, 64, 32, 1});
}
{
using AlgoParam = AlgoInt4Int4NCHW64IMMAImplicitGemm::AlgoParam;
int4_int4_nchw64_imma.emplace_back(
AlgoParam{128, 128, 128, 64, 64, 128});
AlgoParam{128, 128, 128, 64, 64, 128, 2});
int4_int4_nchw64_imma.emplace_back(
AlgoParam{256, 128, 128, 64, 64, 128});
AlgoParam{128, 256, 128, 64, 64, 128, 2});
int4_int4_nchw64_imma.emplace_back(
AlgoParam{128, 64, 128, 64, 64, 128, 2});
int4_int4_nchw64_imma.emplace_back(
AlgoParam{128, 64, 64, 64, 64, 64, 1});
}
{
using AlgoParam = AlgoUInt4Int4NCHW64IMMAImplicitGemm::AlgoParam;
uint4_int4_nchw64_imma.emplace_back(
AlgoParam{128, 128, 128, 64, 64, 128});
AlgoParam{128, 128, 128, 64, 64, 128, 2});
uint4_int4_nchw64_imma.emplace_back(
AlgoParam{128, 256, 128, 64, 64, 128, 2});
uint4_int4_nchw64_imma.emplace_back(
AlgoParam{256, 128, 128, 64, 64, 128});
AlgoParam{128, 64, 128, 64, 64, 128, 2});
uint4_int4_nchw64_imma.emplace_back(
AlgoParam{128, 64, 64, 64, 64, 64, 1});
}
{
using AlgoParam = AlgoInt4Int4NHWCIMMAImplicitGemm::AlgoParam;
int4_int4_nhwc_imma.emplace_back(
AlgoParam{128, 32, 64, 64, 32, 64, 32});
AlgoParam{128, 32, 64, 64, 32, 64, 1, 32});
int4_int4_nhwc_imma.emplace_back(
AlgoParam{128, 32, 64, 64, 32, 64, 1, 16});
int4_int4_nhwc_imma.emplace_back(
AlgoParam{128, 32, 64, 64, 32, 64, 1, 8});
int4_int4_nhwc_imma.emplace_back(
AlgoParam{128, 32, 64, 64, 32, 64, 16});
int4_int4_nhwc_imma.emplace_back(AlgoParam{128, 32, 64, 64, 32, 64, 8});
AlgoParam{128, 64, 64, 64, 64, 64, 1, 32});
int4_int4_nhwc_imma.emplace_back(
AlgoParam{128, 64, 64, 64, 64, 64, 32});
AlgoParam{128, 64, 64, 64, 64, 64, 1, 16});
int4_int4_nhwc_imma.emplace_back(
AlgoParam{128, 64, 64, 64, 64, 64, 16});
int4_int4_nhwc_imma.emplace_back(AlgoParam{128, 64, 64, 64, 64, 64, 8});
AlgoParam{128, 64, 64, 64, 64, 64, 1, 8});
}
{
using AlgoParam = AlgoUInt4Int4NHWCIMMAImplicitGemm::AlgoParam;
uint4_int4_nhwc_imma.emplace_back(
AlgoParam{128, 32, 64, 64, 32, 64, 32});
AlgoParam{128, 32, 64, 64, 32, 64, 1, 32});
uint4_int4_nhwc_imma.emplace_back(
AlgoParam{128, 32, 64, 64, 32, 64, 16});
AlgoParam{128, 32, 64, 64, 32, 64, 1, 16});
uint4_int4_nhwc_imma.emplace_back(
AlgoParam{128, 32, 64, 64, 32, 64, 8});
AlgoParam{128, 32, 64, 64, 32, 64, 1, 8});
uint4_int4_nhwc_imma.emplace_back(
AlgoParam{128, 64, 64, 64, 64, 64, 32});
AlgoParam{128, 64, 64, 64, 64, 64, 1, 32});
uint4_int4_nhwc_imma.emplace_back(
AlgoParam{128, 64, 64, 64, 64, 64, 16});
AlgoParam{128, 64, 64, 64, 64, 64, 1, 16});
uint4_int4_nhwc_imma.emplace_back(
AlgoParam{128, 64, 64, 64, 64, 64, 8});
AlgoParam{128, 64, 64, 64, 64, 64, 1, 8});
}
#endif
}
......@@ -279,10 +291,8 @@ void ConvBiasForwardImpl::AlgoPack::fill_dp4a_algos() {
int8_nchw4_dotprod.emplace_back(AlgoParam{64, 128, 32, 64, 32, 32, 2});
int8_nchw4_dotprod.emplace_back(AlgoParam{32, 128, 32, 32, 64, 32, 2});
int8_nchw4_dotprod.emplace_back(AlgoParam{128, 32, 32, 64, 32, 32, 2});
int8_nchw4_dotprod.emplace_back(AlgoParam{64, 64, 32, 64, 32, 32, 2});
int8_nchw4_dotprod.emplace_back(AlgoParam{32, 64, 32, 32, 64, 32, 2});
int8_nchw4_dotprod.emplace_back(AlgoParam{64, 32, 32, 64, 32, 32, 2});
int8_nchw4_dotprod.emplace_back(AlgoParam{32, 32, 32, 32, 32, 32, 2});
int8_nchw4_dotprod.emplace_back(AlgoParam{16, 128, 16, 16, 128, 16, 1});
int8_nchw4_dotprod.emplace_back(AlgoParam{16, 64, 8, 16, 64, 8, 2});
}
......
......@@ -723,6 +723,7 @@ public:
int warp_m;
int warp_n;
int warp_k;
int stage;
};
AlgoInt8NCHW32IMMAImplicitGemm(AlgoParam algo_param)
: m_algo_param{algo_param} {
......@@ -770,6 +771,7 @@ public:
int warp_m;
int warp_n;
int warp_k;
int stage;
};
AlgoInt4NCHW64IMMAImplicitGemmBase(AlgoParam algo_param)
......@@ -897,6 +899,7 @@ public:
int warp_m;
int warp_n;
int warp_k;
int stage;
int access_size;
};
......
......@@ -38,7 +38,7 @@ void do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32(
const convolution::ConvParam& param, uint32_t nonlinear_mode,
float alpha, float beta, float gamma, float scale,
const GemmCoord& threadblock_shape, const GemmCoord& warp_shape,
cudaStream_t stream);
int stages, cudaStream_t stream);
template <bool NeedLoadFromConstMem>
void do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4(
......@@ -47,7 +47,7 @@ void do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4(
const convolution::ConvParam& param, uint32_t nonlinear_mode,
float alpha, float beta, float gamma, float scale,
const GemmCoord& threadblock_shape, const GemmCoord& warp_shape,
cudaStream_t stream);
int stages, cudaStream_t stream);
template <bool NeedLoadFromConstMem>
void do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4(
......@@ -83,7 +83,7 @@ void do_conv_bias_int4_int4_implicit_gemm_imma_ncdiv64hw64(
const convolution::ConvParam& param, uint32_t nonlinear_mode,
float alpha, float beta, float gamma, float scale,
const GemmCoord& threadblock_shape, const GemmCoord& warp_shape,
cudaStream_t stream);
int stages, cudaStream_t stream);
template <bool NeedLoadFromConstMem>
void do_conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64(
......@@ -92,7 +92,7 @@ void do_conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64(
const convolution::ConvParam& param, uint32_t nonlinear_mode,
float alpha, float beta, float gamma, float delta, float theta,
float scale, uint8_t src_zero_point, const GemmCoord& threadblock_shape,
const GemmCoord& warp_shape, cudaStream_t stream);
const GemmCoord& warp_shape, int stages, cudaStream_t stream);
template <bool signedness>
void do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nhwc(
......@@ -110,7 +110,7 @@ void do_conv_bias_int4_int4_implicit_gemm_imma_nhwc(
const convolution::ConvParam& param, uint32_t nonlinear_mode,
float alpha, float beta, float gamma, float scale,
const GemmCoord& threadblock_shape, const GemmCoord& warp_shape,
const int32_t access_size, cudaStream_t stream);
const int32_t access_size, int stages, cudaStream_t stream);
template <bool NeedLoadFromConstMem>
void do_conv_bias_uint4_int4_implicit_gemm_imma_nhwc(
......@@ -119,7 +119,7 @@ void do_conv_bias_uint4_int4_implicit_gemm_imma_nhwc(
const convolution::ConvParam& param, uint32_t nonlinear_mode,
float alpha, float beta, float gamma, float delta, float theta,
float scale, uint8_t src_zero_point, const GemmCoord& threadblock_shape,
const GemmCoord& warp_shape, const int32_t access_size,
const GemmCoord& warp_shape, const int32_t access_size, int stages,
cudaStream_t stream);
} // namespace cutlass_wrapper
......
/**
* \file dnn/src/cuda/conv_bias/cutlass_reorder_filter.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/cuda/conv_bias/cutlass_reorder_filter.cuh"
#include "src/cuda/query_blocksize.cuh"
#include "src/cuda/integer_subbyte_utils.cuh"
using namespace megdnn;
using namespace cuda;
using namespace cutlass_wrapper;
namespace {
template <uint32_t size_bits, uint32_t interleaved>
__device__ __forceinline__ void reorder_ncxhwx_imma_filter_func(
int8_t* dst, const int8_t* src, uint32_t OC, uint32_t IC, uint32_t FH,
uint32_t FW, uint32_t lane, bool trans_oc) {
static constexpr uint32_t elements_per_lane = 128 / size_bits;
static constexpr uint32_t threads_per_interleaved =
interleaved / elements_per_lane;
static constexpr uint32_t instruction_shape_col = 8;
// 4 threads per Quad
static constexpr uint32_t elements_per_thread = instruction_shape_col / 4;
// 4 threads per Quad
static constexpr uint32_t reordered_elements_per_thread = interleaved / 4;
uint32_t id = lane / threads_per_interleaved;
uint32_t residue = lane % threads_per_interleaved;
uint32_t ICx = IC / interleaved;
uint32_t row = id / (ICx * FH * FW);
uint32_t col = id - row * ICx * FH * FW;
// transpose ncxhwx to cxhwnx
uint32_t src_offset = id * interleaved + residue * elements_per_lane;
row = (trans_oc) ? (row / interleaved) * interleaved +
((row % reordered_elements_per_thread) /
elements_per_thread) *
instruction_shape_col +
((row % interleaved) /
reordered_elements_per_thread) *
elements_per_thread +
(row % elements_per_thread)
: row;
uint32_t dst_offset =
(col * OC + row) * interleaved + residue * elements_per_lane;
*(reinterpret_cast<int4*>(dst + dst_offset * size_bits / 8)) =
*(reinterpret_cast<const int4*>(src + src_offset * size_bits / 8));
}
template <uint32_t size_bits, uint32_t interleaved>
__global__ void reorder_ncxhwx_imma_filter_kernel(
int8_t* __restrict__ dst_filter, const int8_t* __restrict__ src_filter,
uint32_t OC, uint32_t IC, uint32_t FH, uint32_t FW, bool trans_oc) {
static constexpr uint32_t elements_per_lane = 128 / size_bits;
const uint32_t size = OC * IC * FH * FW / elements_per_lane;
uint32_t lane = threadIdx.x + blockIdx.x * blockDim.x;
if (lane < size) {
reorder_ncxhwx_imma_filter_func<size_bits, interleaved>(
dst_filter, src_filter, OC, IC, FH, FW, lane, trans_oc);
}
}
template <uint32_t size_bits, uint32_t alignbits, uint32_t interleaved>
__device__ __forceinline__ void reorder_nhwc_imma_filter_func(
int8_t* dst, const int8_t* src, uint32_t OC, uint32_t IC, uint32_t FH,
uint32_t FW, uint32_t lane, bool trans_oc) {
static constexpr uint32_t elements_per_access = alignbits / size_bits;
static constexpr uint32_t instruction_shape_col = 8;
// 4 threads per Quad
static constexpr uint32_t elements_per_thread = instruction_shape_col / 4;
// 4 threads per Quad
static constexpr uint32_t reordered_elements_per_thread = interleaved / 4;
uint32_t ICx = IC / elements_per_access;
uint32_t k = lane / (ICx * FH * FW);
uint32_t cxrs = lane - k * ICx * FH * FW;
uint32_t rs = cxrs / ICx;
uint32_t cx = cxrs - rs * ICx;
// transpose nhwc to ncxhwx
uint32_t src_offset = lane * elements_per_access;
// reorder k
k = (trans_oc)
? (k / interleaved) * interleaved +
((k % reordered_elements_per_thread) /
elements_per_thread) *
instruction_shape_col +
((k % interleaved) / reordered_elements_per_thread) *
elements_per_thread +
(k % elements_per_thread)
: k;
uint32_t dst_offset =
(k * ICx * FH * FW + cx * FH * FW + rs) * elements_per_access;
if (alignbits == 32) {
*(reinterpret_cast<int*>(dst + dst_offset * size_bits / 8)) = *(
reinterpret_cast<const int*>(src + src_offset * size_bits / 8));
} else if (alignbits == 64) {
*(reinterpret_cast<int2*>(dst + dst_offset * size_bits / 8)) =
*(reinterpret_cast<const int2*>(src +
src_offset * size_bits / 8));
} else {
*(reinterpret_cast<int4*>(dst + dst_offset * size_bits / 8)) =
*(reinterpret_cast<const int4*>(src +
src_offset * size_bits / 8));
}
}
template <uint32_t size_bits, uint32_t alignbits, uint32_t interleaved>
__global__ void reorder_nhwc_imma_filter_kernel(
int8_t* __restrict__ dst_filter, const int8_t* __restrict__ src_filter,
uint32_t OC, uint32_t IC, uint32_t FH, uint32_t FW, bool trans_oc) {
static constexpr uint32_t elements_per_access = alignbits / size_bits;
const uint32_t size = OC * IC * FH * FW / elements_per_access;
uint32_t lane = threadIdx.x + blockIdx.x * blockDim.x;
if (lane < size) {
reorder_nhwc_imma_filter_func<size_bits, alignbits, interleaved>(
dst_filter, src_filter, OC, IC, FH, FW, lane, trans_oc);
}
}
} // namespace
template <uint32_t size_bits, uint32_t interleaved>
void megdnn::cuda::cutlass_wrapper::reorder_ncxhwx_imma_filter(
int8_t* dst_filter, const int8_t* src_filter, uint32_t OC, uint32_t IC,
uint32_t FH, uint32_t FW, bool trans_oc, cudaStream_t stream) {
static constexpr uint32_t elements_per_lane = 128 / size_bits;
uint32_t nr_threads =
query_blocksize_for_kernel(reinterpret_cast<const void*>(
reorder_ncxhwx_imma_filter_kernel<size_bits, interleaved>));
uint32_t vthreads = DIVUP(OC * IC * FH * FW, elements_per_lane);
nr_threads = std::min(nr_threads, vthreads);
uint32_t nr_blocks = DIVUP(vthreads, nr_threads);
reorder_ncxhwx_imma_filter_kernel<size_bits, interleaved>
<<<nr_blocks, nr_threads, 0, stream>>>(dst_filter, src_filter, OC,
IC, FH, FW, trans_oc);
after_kernel_launch();
}
template <uint32_t size_bits, uint32_t alignbits>
void megdnn::cuda::cutlass_wrapper::reorder_nhwc_imma_filter(
int8_t* dst_filter, const int8_t* src_filter, uint32_t OC, uint32_t IC,
uint32_t FH, uint32_t FW, bool trans_oc, uint32_t oc_interleaved,
cudaStream_t stream) {
static constexpr uint32_t elements_per_access = alignbits / size_bits;
uint32_t nr_threads =
query_blocksize_for_kernel(reinterpret_cast<const void*>(
reorder_nhwc_imma_filter_kernel<size_bits, alignbits, 32>));
uint32_t vthreads = DIVUP(OC * IC * FH * FW, elements_per_access);
nr_threads = std::min(nr_threads, vthreads);
uint32_t nr_blocks = DIVUP(vthreads, nr_threads);
if (oc_interleaved == 32) {
reorder_nhwc_imma_filter_kernel<size_bits, alignbits, 32>
<<<nr_blocks, nr_threads, 0, stream>>>(
dst_filter, src_filter, OC, IC, FH, FW, trans_oc);
} else {
reorder_nhwc_imma_filter_kernel<size_bits, alignbits, 64>
<<<nr_blocks, nr_threads, 0, stream>>>(
dst_filter, src_filter, OC, IC, FH, FW, trans_oc);
}
after_kernel_launch();
}
#define INST(_size_bits, _interleaved) \
template void megdnn::cuda::cutlass_wrapper::reorder_ncxhwx_imma_filter< \
_size_bits, _interleaved>(int8_t * dst_filter, \
const int8_t* src_filter, uint32_t OC, \
uint32_t IC, uint32_t FH, uint32_t FW, \
bool trans_oc, cudaStream_t stream);
INST(8, 32)
INST(4, 64)
#undef INST
#define INST(_size_bits, _alignbits) \
template void megdnn::cuda::cutlass_wrapper::reorder_nhwc_imma_filter< \
_size_bits, _alignbits>( \
int8_t * dst_filter, const int8_t* src_filter, uint32_t OC, \
uint32_t IC, uint32_t FH, uint32_t FW, bool trans_oc, \
uint32_t oc_interleaved, cudaStream_t stream);
INST(4, 32)
INST(4, 64)
INST(4, 128)
#undef INST
// vim: syntax=cuda.doxygen
/**
* \file dnn/src/cuda/conv_bias/cutlass_reorder_filter.cuh
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "src/cuda/utils.cuh"
namespace megdnn {
namespace cuda {
namespace cutlass_wrapper {
template <uint32_t size_bits, uint32_t interleaved>
void reorder_ncxhwx_imma_filter(int8_t* dst_filter, const int8_t* src_filter,
uint32_t OC, uint32_t IC, uint32_t FH,
uint32_t FW, bool trans_oc,
cudaStream_t stream);
template <uint32_t size_bits, uint32_t alignbits>
void reorder_nhwc_imma_filter(int8_t* dst_filter, const int8_t* src_filter,
uint32_t OC, uint32_t IC, uint32_t FH,
uint32_t FW, bool trans_oc,
uint32_t oc_interleaved, cudaStream_t stream);
} // namespace cutlass_wrapper
} // namespace cuda
} // namespace megdnn
......@@ -102,7 +102,7 @@ void ConvBiasForwardImpl::AlgoInt4Int4NCHW64IMMAImplicitGemm::do_exec(
reinterpret_cast<int8_t*>(z_ptr),
reinterpret_cast<int8_t*>(args.dst_tensor->raw_ptr), nullptr,
kern_param, nonlinear_mode, alpha, beta, gamma, dst_scale,
threadblock_shape, warp_shape, stream);
threadblock_shape, warp_shape, m_algo_param.stage, stream);
}
#endif
......
......@@ -104,7 +104,7 @@ void ConvBiasForwardImpl::AlgoInt4Int4NHWCIMMAImplicitGemm::do_exec(
reinterpret_cast<int8_t*>(args.dst_tensor->raw_ptr), nullptr,
kern_param, nonlinear_mode, alpha, beta, gamma, dst_scale,
threadblock_shape, warp_shape, m_algo_param.access_size,
stream);
m_algo_param.stage, stream);
} else {
cutlass_wrapper::do_conv_bias_int4_int4_implicit_gemm_imma_nhwc<true>(
reinterpret_cast<int8_t*>(args.src_tensor->raw_ptr),
......@@ -114,7 +114,7 @@ void ConvBiasForwardImpl::AlgoInt4Int4NHWCIMMAImplicitGemm::do_exec(
reinterpret_cast<int8_t*>(args.dst_tensor->raw_ptr), nullptr,
kern_param, nonlinear_mode, alpha, beta, gamma, dst_scale,
threadblock_shape, warp_shape, m_algo_param.access_size,
stream);
m_algo_param.stage, stream);
}
}
#endif
......
......@@ -12,6 +12,7 @@
#include "./algo.h"
#include "src/common/conv_bias.h"
#include "src/cuda/conv_bias/cutlass_reorder_filter.cuh"
#include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh"
#include "src/cuda/conv_bias/reduce_filter.cuh"
#include "src/cuda/convolution_helper/parameter.cuh"
......@@ -121,41 +122,26 @@ void ConvBiasForwardImpl::AlgoInt4NCHW64IMMAImplicitGemmBase::exec(
std::string ConvBiasForwardImpl::AlgoInt4NCHW64IMMAImplicitGemmBase::to_string(
AlgoParam algo_param) {
return ssprintf("%dX%dX%d_%dX%dX%d", algo_param.threadblock_m,
return ssprintf("%dX%dX%d_%dX%dX%d_%d", algo_param.threadblock_m,
algo_param.threadblock_n, algo_param.threadblock_k,
algo_param.warp_m, algo_param.warp_n, algo_param.warp_k);
algo_param.warp_m, algo_param.warp_n, algo_param.warp_k,
algo_param.stage);
}
void ConvBiasForwardImpl::AlgoInt4NCHW64IMMAImplicitGemmBase::reorder_filter(
const ExecArgs& args, void* reordered_filter) const {
auto&& param = args.opr->param();
size_t ci = args.src_layout->operator[](1) * 64;
size_t co = args.dst_layout->operator[](1) * 64;
auto&& fm = args.filter_meta;
size_t n = args.src_layout->operator[](0),
ci = args.src_layout->operator[](1) * 64,
hi = args.src_layout->operator[](2),
wi = args.src_layout->operator[](3);
size_t co = args.dst_layout->operator[](1) * 64,
ho = args.dst_layout->operator[](2),
wo = args.dst_layout->operator[](3);
UNPACK_CONV_PARAMETER(fm, param);
MARK_USED_VAR;
// filter: KCRS64 => CRSK64
TensorLayout src{{co, ci / 64, fh, fw, 64}, dtype::QuantizedS4()};
src.init_contiguous_stride();
TensorLayout dst = src;
dst.stride[0] = 64;
dst.stride[1] = co * fh * fw * 64;
dst.stride[2] = co * fw * 64;
dst.stride[3] = co * 64;
dst.stride[4] = 1;
TensorND ts_src, ts_dst;
ts_src.raw_ptr = args.filter_tensor->raw_ptr;
ts_src.layout = src;
ts_dst.raw_ptr = reordered_filter;
ts_dst.layout = dst;
auto&& transpose = args.opr->handle()->create_operator<RelayoutForward>();
transpose->exec(ts_src, ts_dst);
size_t fh = fm.spatial[0], fw = fm.spatial[1];
cudaStream_t stream = cuda_stream(args.opr->handle());
// filter: KCRS64 => CRSK64 and reorder oc
cutlass_wrapper::reorder_ncxhwx_imma_filter<4, 64>(
reinterpret_cast<int8_t*>(reordered_filter),
reinterpret_cast<int8_t*>(args.filter_tensor->raw_ptr), co, ci, fh,
fw, true, stream);
}
#endif
......
......@@ -12,6 +12,7 @@
#include "./algo.h"
#include "src/common/conv_bias.h"
#include "src/cuda/conv_bias/cutlass_reorder_filter.cuh"
#include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh"
#include "src/cuda/conv_bias/reduce_filter.cuh"
#include "src/cuda/convolution_helper/parameter.cuh"
......@@ -128,10 +129,10 @@ void ConvBiasForwardImpl::AlgoInt4NHWCIMMAImplicitGemmBase::exec(
std::string ConvBiasForwardImpl::AlgoInt4NHWCIMMAImplicitGemmBase::to_string(
AlgoParam algo_param) {
return ssprintf("%dX%dX%d_%dX%dX%d_%d", algo_param.threadblock_m,
return ssprintf("%dX%dX%d_%dX%dX%d_%d_%d", algo_param.threadblock_m,
algo_param.threadblock_n, algo_param.threadblock_k,
algo_param.warp_m, algo_param.warp_n, algo_param.warp_k,
algo_param.access_size);
algo_param.stage, algo_param.access_size);
}
void ConvBiasForwardImpl::AlgoInt4NHWCIMMAImplicitGemmBase::reorder_filter(
......@@ -142,17 +143,32 @@ void ConvBiasForwardImpl::AlgoInt4NHWCIMMAImplicitGemmBase::reorder_filter(
fh = args.filter_layout->operator[](1),
fw = args.filter_layout->operator[](2);
// reformat grad from nhwc to ncxhwx
TensorLayout exec_src{{co, fh, fw, ci / iterleaved, (size_t)iterleaved / 2},
dtype::Int8()};
TensorLayout exec_dst{{co, ci / iterleaved, fh, fw, (size_t)iterleaved / 2},
dtype::Int8()};
exec_src = exec_src.dimshuffle({0, 3, 1, 2, 4});
cudaStream_t stream = cuda_stream(args.opr->handle());
auto&& relayout = args.opr->handle()->create_operator<RelayoutForward>();
relayout->exec({args.filter_tensor->raw_ptr, exec_src},
{reordered_filter, exec_dst});
// reformat filter from nhwc to ncxhwx and reorder oc
// use trans_oc threadblock_n must be 32 or 64
bool trans_oc = ((co % m_algo_param.threadblock_n == 0) &&
(m_algo_param.threadblock_n == 32 ||
m_algo_param.threadblock_n == 64));
uint32_t oc_iterleave = (m_algo_param.threadblock_n == 64) ? 64 : 32;
if (iterleaved == 8) {
cutlass_wrapper::reorder_nhwc_imma_filter<4, 32>(
reinterpret_cast<int8_t*>(reordered_filter),
reinterpret_cast<int8_t*>(args.filter_tensor->raw_ptr), co, ci,
fh, fw, trans_oc, oc_iterleave, stream);
} else if (iterleaved == 16) {
cutlass_wrapper::reorder_nhwc_imma_filter<4, 64>(
reinterpret_cast<int8_t*>(reordered_filter),
reinterpret_cast<int8_t*>(args.filter_tensor->raw_ptr), co, ci,
fh, fw, trans_oc, oc_iterleave, stream);
} else {
megdnn_assert(iterleaved == 32);
cutlass_wrapper::reorder_nhwc_imma_filter<4, 128>(
reinterpret_cast<int8_t*>(reordered_filter),
reinterpret_cast<int8_t*>(args.filter_tensor->raw_ptr), co, ci,
fh, fw, trans_oc, oc_iterleave, stream);
}
}
#endif
......
......@@ -11,6 +11,7 @@
*/
#include "./algo.h"
#include "src/cuda/conv_bias/cutlass_reorder_filter.cuh"
#include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh"
#include "src/cuda/convolution_helper/parameter.cuh"
#include "src/cuda/utils.h"
......@@ -110,11 +111,14 @@ void ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::exec(
size_t ho = args.dst_layout->operator[](2),
wo = args.dst_layout->operator[](3);
size_t co;
bool trans_oc;
if (param.format == Format::NCHW32) {
co = args.dst_layout->operator[](1) * 32;
trans_oc = true;
} else {
megdnn_assert(param.format == Format::NCHW32_NCHW4);
co = args.dst_layout->operator[](1) * 4;
trans_oc = false;
}
UNPACK_CONV_PARAMETER(fm, param);
MARK_USED_VAR
......@@ -123,23 +127,11 @@ void ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::exec(
int8_t* filter_ptr = nullptr;
if (args.preprocessed_filter == nullptr) {
filter_ptr = reinterpret_cast<int8_t*>(args.workspace.raw_ptr);
// reformat filter from nchw32 to chwn32
TensorLayout src{{co, ci / 32, fh, fw, 32}, dtype::Int8()};
src.init_contiguous_stride();
TensorLayout dst = src;
dst.stride[0] = 32;
dst.stride[1] = co * fh * fw * 32;
dst.stride[2] = co * fw * 32;
dst.stride[3] = co * 32;
dst.stride[4] = 1;
TensorND ts_src, ts_dst;
ts_src.raw_ptr = args.filter_tensor->raw_ptr;
ts_src.layout = src;
ts_dst.raw_ptr = args.workspace.raw_ptr;
ts_dst.layout = dst;
auto&& transpose =
args.opr->handle()->create_operator<RelayoutForward>();
transpose->exec(ts_src, ts_dst);
// filter: KCRS32 => CRSK32 and reorder oc
cutlass_wrapper::reorder_ncxhwx_imma_filter<8, 32>(
filter_ptr,
reinterpret_cast<int8_t*>(args.filter_tensor->raw_ptr), co, ci,
fh, fw, trans_oc, stream);
} else {
filter_ptr = reinterpret_cast<int8_t*>(
args.preprocessed_filter->tensors[0].raw_ptr);
......@@ -182,7 +174,7 @@ void ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::exec(
cutlass_wrapper::GemmCoord{m_algo_param.warp_m,
m_algo_param.warp_n,
m_algo_param.warp_k},
stream);
m_algo_param.stage, stream);
} else {
megdnn_assert(param.format == Format::NCHW32_NCHW4);
cutlass_wrapper::
......@@ -202,7 +194,7 @@ void ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::exec(
cutlass_wrapper::GemmCoord{m_algo_param.warp_m,
m_algo_param.warp_n,
m_algo_param.warp_k},
stream);
m_algo_param.stage, stream);
}
} else {
if (param.format == Format::NCHW32) {
......@@ -218,7 +210,7 @@ void ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::exec(
cutlass_wrapper::GemmCoord{m_algo_param.warp_m,
m_algo_param.warp_n,
m_algo_param.warp_k},
stream);
m_algo_param.stage, stream);
} else {
megdnn_assert(param.format == Format::NCHW32_NCHW4);
cutlass_wrapper::
......@@ -238,7 +230,7 @@ void ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::exec(
cutlass_wrapper::GemmCoord{m_algo_param.warp_m,
m_algo_param.warp_n,
m_algo_param.warp_k},
stream);
m_algo_param.stage, stream);
}
}
after_kernel_launch();
......@@ -246,9 +238,10 @@ void ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::exec(
std::string ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::to_string(
AlgoParam algo_param) {
return ssprintf("%uX%uX%u_%uX%uX%u", algo_param.threadblock_m,
return ssprintf("%uX%uX%u_%uX%uX%u_%u", algo_param.threadblock_m,
algo_param.threadblock_n, algo_param.threadblock_k,
algo_param.warp_m, algo_param.warp_n, algo_param.warp_k);
algo_param.warp_m, algo_param.warp_n, algo_param.warp_k,
algo_param.stage);
}
size_t ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::
......@@ -267,36 +260,26 @@ void ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::exec_preprocess(
using Format = Param::Format;
auto&& param = args.opr->param();
auto&& fm = args.filter_meta;
size_t n = args.src_layout->operator[](0),
ci = args.src_layout->operator[](1) * 32,
hi = args.src_layout->operator[](2),
wi = args.src_layout->operator[](3);
size_t ho = args.dst_layout->operator[](2),
wo = args.dst_layout->operator[](3);
size_t ci = args.src_layout->operator[](1) * 32;
size_t co;
bool trans_oc;
if (param.format == Format::NCHW32) {
co = args.dst_layout->operator[](1) * 32;
trans_oc = true;
} else {
megdnn_assert(param.format == Format::NCHW32_NCHW4);
co = args.dst_layout->operator[](1) * 4;
trans_oc = false;
}
UNPACK_CONV_PARAMETER(fm, param);
MARK_USED_VAR
TensorLayout src{{co, ci / 32, fh, fw, 32}, dtype::Int8()};
src.init_contiguous_stride();
TensorLayout dst = src;
dst.stride[0] = 32;
dst.stride[1] = co * fh * fw * 32;
dst.stride[2] = co * fw * 32;
dst.stride[3] = co * 32;
dst.stride[4] = 1;
TensorND ts_src, ts_dst;
ts_src.raw_ptr = args.filter_tensor->raw_ptr;
ts_src.layout = src;
ts_dst.raw_ptr = args.preprocessed_filter->tensors[0].raw_ptr;
ts_dst.layout = dst;
auto&& transpose = args.opr->handle()->create_operator<RelayoutForward>();
transpose->exec(ts_src, ts_dst);
size_t fh = fm.spatial[0], fw = fm.spatial[1];
cudaStream_t stream = cuda_stream(args.opr->handle());
// filter: KCRS32 => CRSK32 and reorder oc
cutlass_wrapper::reorder_ncxhwx_imma_filter<8, 32>(
reinterpret_cast<int8_t*>(
args.preprocessed_filter->tensors[0].raw_ptr),
reinterpret_cast<int8_t*>(args.filter_tensor->raw_ptr), co, ci, fh,
fw, trans_oc, stream);
}
#endif
......
......@@ -144,7 +144,8 @@ void ConvBiasForwardImpl::AlgoUInt4Int4NCHW64IMMAImplicitGemm::do_exec(
reinterpret_cast<uint8_t*>(z_ptr),
reinterpret_cast<uint8_t*>(args.dst_tensor->raw_ptr), nullptr,
kern_param, nonlinear_mode, alpha, beta, gamma, delta, theta,
dst_scale, src_zero, threadblock_shape, warp_shape, stream);
dst_scale, src_zero, threadblock_shape, warp_shape,
m_algo_param.stage, stream);
}
void ConvBiasForwardImpl::AlgoUInt4Int4NCHW64IMMAImplicitGemm::update_bias(
......
......@@ -147,7 +147,7 @@ void ConvBiasForwardImpl::AlgoUInt4Int4NHWCIMMAImplicitGemm::do_exec(
reinterpret_cast<uint8_t*>(args.dst_tensor->raw_ptr), nullptr,
kern_param, nonlinear_mode, alpha, beta, gamma, delta, theta,
dst_scale, src_zero, threadblock_shape, warp_shape,
m_algo_param.access_size, stream);
m_algo_param.access_size, m_algo_param.stage, stream);
} else {
cutlass_wrapper::do_conv_bias_uint4_int4_implicit_gemm_imma_nhwc<true>(
reinterpret_cast<uint8_t*>(args.src_tensor->raw_ptr),
......@@ -157,7 +157,7 @@ void ConvBiasForwardImpl::AlgoUInt4Int4NHWCIMMAImplicitGemm::do_exec(
reinterpret_cast<uint8_t*>(args.dst_tensor->raw_ptr), nullptr,
kern_param, nonlinear_mode, alpha, beta, gamma, delta, theta,
dst_scale, src_zero, threadblock_shape, warp_shape,
m_algo_param.access_size, stream);
m_algo_param.access_size, m_algo_param.stage, stream);
}
}
......
......@@ -840,21 +840,21 @@ TEST_F(CUDA, CUTLASS_CONV_BIAS_INT8_NCHW32_IMMA) {
param.pad_h = param.pad_w = 1;
param.stride_h = param.stride_w = 1;
param.format = param::ConvBias::Format::NCHW32;
checker.set_param(param).execs({{16, 16, 7, 7, 32},
{512, 16, 3, 3, 32},
{1, 16, 1, 1, 32},
checker.set_param(param).execs({{16, 8, 7, 7, 32},
{256, 8, 3, 3, 32},
{1, 8, 1, 1, 32},
{},
{}});
param.nonlineMode = param::ConvBias::NonlineMode::RELU;
checker.set_param(param).execs({{16, 16, 7, 7, 32},
{512, 16, 1, 1, 32},
{1, 16, 1, 1, 32},
checker.set_param(param).execs({{16, 8, 7, 7, 32},
{256, 8, 1, 1, 32},
{1, 8, 1, 1, 32},
{},
{}});
param.nonlineMode = param::ConvBias::NonlineMode::H_SWISH;
checker.set_param(param).execs({{16, 16, 7, 7, 32},
{512, 16, 3, 3, 32},
{1, 16, 1, 1, 32},
checker.set_param(param).execs({{16, 8, 7, 7, 32},
{256, 8, 3, 3, 32},
{1, 8, 1, 1, 32},
{},
{}});
// use non integer scale
......@@ -867,18 +867,18 @@ TEST_F(CUDA, CUTLASS_CONV_BIAS_INT8_NCHW32_IMMA) {
.set_epsilon(1 + 1e-3)
.set_max_avg_error(1e-1)
.set_max_avg_biased_error(1e-1)
.execs({{16, 16, 7, 7, 32},
{512, 16, 3, 3, 32},
{1, 16, 1, 1, 32},
{16, 16, 7, 7, 32},
.execs({{16, 8, 7, 7, 32},
{256, 8, 3, 3, 32},
{1, 8, 1, 1, 32},
{16, 8, 7, 7, 32},
{}});
};
std::string algo = ConvBias::algo_name<ConvBias::DirectParam>(
"INT8_NCHW32_IMMA_IMPLICIT_GEMM_256X128X64_64X64X64",
"INT8_NCHW32_IMMA_IMPLICIT_GEMM_128X128X64_64X64X64_2",
ConvBias::DirectParam{});
check(algo);
algo = ConvBias::algo_name<ConvBias::DirectParam>(
"INT8_NCHW32_IMMA_IMPLICIT_GEMM_32X64X64_32X16X64",
"INT8_NCHW32_IMMA_IMPLICIT_GEMM_128X32X32_64X32X32_1",
ConvBias::DirectParam{});
check(algo);
}
......@@ -969,7 +969,7 @@ TEST_F(CUDA, CUTLASS_CONV_BIAS_INT8_NCHW32_NCHW4) {
checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<
ConvBiasForward>(
ConvBias::algo_name<ConvBias::DirectParam>(
"INT8_NCHW32_IMMA_IMPLICIT_GEMM_256X128X64_64X64X64",
"INT8_NCHW32_IMMA_IMPLICIT_GEMM_128X128X64_64X64X64_2",
ConvBias::DirectParam{})
.c_str()));
checker.set_dtype(0, dtype::QuantizedS8(1.9980618f))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册