Commit 9b4b910d authored by: Megvii Engine Team

feat(dnn/cuda): integrate cutlass operation table and replace all cutlass wrappers

GitOrigin-RevId: 2a70335441e8a844dcf3c2d00bbf6db381ad9623
Parent: b18feaab
......@@ -163,7 +163,7 @@ using Convolution =
${element_bias},
${layout_bias},
${element_accumulator},
${conv_type},
${opcode_class},
${arch},
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
......@@ -246,6 +246,7 @@ using Deconvolution =
${element_bias},
${layout_bias},
${element_accumulator},
${conv_type},
${opcode_class},
${arch},
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
......@@ -276,6 +277,7 @@ using Deconvolution =
values = {
'operation_name': operation.procedural_name(),
'conv_type': ConvTypeTag[operation.conv_type],
'element_src': DataTypeTag[operation.src.element],
'layout_src': LayoutTag[operation.src.layout],
'element_flt': DataTypeTag[operation.flt.element],
......@@ -530,44 +532,17 @@ void initialize_${configuration_name}(Manifest &manifest) {
###################################################################################################
class EmitConvSingleKernelWrapper():
def __init__(self, kernel_path, operation, wrapper_path):
def __init__(self, kernel_path, operation):
self.kernel_path = kernel_path
self.wrapper_path = wrapper_path
self.operation = operation
self.conv_wrappers = { \
ConvKind.Fprop: """
template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<Convolution>(
const typename Convolution::ElementSrc* d_src,
const typename Convolution::ElementFilter* d_filter,
const typename Convolution::ElementBias* d_bias,
const typename Convolution::ElementDst* d_z,
typename Convolution::ElementDst* d_dst,
int* workspace,
typename Convolution::ConvolutionParameter const& conv_param,
typename Convolution::EpilogueOutputOp::Params const& epilogue,
cudaStream_t stream,
typename Convolution::ExtraParam extra_param);
""", \
ConvKind.Dgrad: """
template void megdnn::cuda::cutlass_wrapper::cutlass_deconvolution_wrapper<Deconvolution>(
const typename Deconvolution::ElementSrc* d_src,
const typename Deconvolution::ElementFilter* d_filter,
const typename Deconvolution::ElementBias* d_bias,
const typename Deconvolution::ElementDst* d_z,
typename Deconvolution::ElementDst* d_dst,
int* workspace,
typename Deconvolution::ConvolutionParameter const& conv_param,
typename Deconvolution::EpilogueOutputOp::Params const& epilogue,
cudaStream_t stream);
""", \
}
if self.operation.conv_kind == ConvKind.Fprop:
self.instance_emitter = EmitConv2dInstance()
self.convolution_name = "Convolution"
else:
assert self.operation.conv_kind == ConvKind.Dgrad
self.instance_emitter = EmitDeconvInstance()
self.convolution_name = "Deconvolution"
self.header_template = """
#if !MEGDNN_TEGRA_X1
......@@ -575,13 +550,30 @@ template void megdnn::cuda::cutlass_wrapper::cutlass_deconvolution_wrapper<Decon
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "${wrapper_path}"
#pragma GCC diagnostic ignored "-Wuninitialized"
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
#include "cutlass/convolution/device/convolution.h"
#include "src/cuda/cutlass/manifest.h"
#include "src/cuda/cutlass/convolution_operation.h"
"""
self.instance_template = """
${operation_instance}
"""
self.wrapper_template = """
${wrapper_instance}
self.manifest_template = """
namespace cutlass {
namespace library {
void initialize_${operation_name}(Manifest &manifest) {
manifest.append(new ConvolutionOperation<${convolution_name}>(
"${operation_name}"
));
}
} // namespace library
} // namespace cutlass
"""
self.epilogue_template = """
......@@ -593,9 +585,7 @@ ${wrapper_instance}
def __enter__(self):
self.kernel_path = os.path.join(self.kernel_path, "%s.cu" % self.operation.procedural_name())
self.kernel_file = LazyFile(self.kernel_path)
self.kernel_file.write(SubstituteTemplate(self.header_template, {
'wrapper_path': self.wrapper_path,
}))
self.kernel_file.write(self.header_template)
return self
#
......@@ -604,11 +594,12 @@ ${wrapper_instance}
'operation_instance': self.instance_emitter.emit(self.operation),
}))
# emit wrapper
wrapper = SubstituteTemplate(self.wrapper_template, {
'wrapper_instance': self.conv_wrappers[self.operation.conv_kind],
# emit manifest helper
manifest = SubstituteTemplate(self.manifest_template, {
'operation_name': self.operation.procedural_name(),
'convolution_name': self.convolution_name
})
self.kernel_file.write(wrapper)
self.kernel_file.write(manifest)
#
def __exit__(self, exception_type, exception_value, traceback):
......
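For reference, each generated conv kernel file now pairs the emitted CUTLASS instance with a small manifest-registration helper instead of an explicit wrapper instantiation. The sketch below is a hypothetical rendering of the templates above for the Fprop case: the instance definition is elided, and the operation name is borrowed from the generated kernel list purely for illustration.

#if !MEGDNN_TEGRA_X1
// header_template: warning pragmas plus the cutlass/manifest includes
#include "cutlass/convolution/device/convolution.h"
#include "src/cuda/cutlass/manifest.h"
#include "src/cuda/cutlass/convolution_operation.h"

// instance_template: `using Convolution = cutlass::conv::device::Convolution<...>;`

// manifest_template with ${operation_name} and ${convolution_name} substituted:
namespace cutlass {
namespace library {

void initialize_cutlass_simt_s8_ifprop_id_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4(Manifest &manifest) {
    manifest.append(new ConvolutionOperation<Convolution>(
        "cutlass_simt_s8_ifprop_id_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4"
    ));
}

} // namespace library
} // namespace cutlass
#endif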
......@@ -940,8 +940,8 @@ void initialize_${configuration_name}(Manifest &manifest) {
///////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace library
} // namespace cutlass
///////////////////////////////////////////////////////////////////////////////////////////////////
......@@ -995,48 +995,101 @@ void initialize_${configuration_name}(Manifest &manifest) {
###################################################################################################
class EmitGemmSingleKernelWrapper:
def __init__(self, kernel_path, gemm_operation, wrapper_path):
def __init__(self, kernel_path, gemm_operation):
self.kernel_path = kernel_path
self.wrapper_path = wrapper_path
self.operation = gemm_operation
gemm_wrapper = """
template void megdnn::cuda::cutlass_wrapper::cutlass_matrix_mul_wrapper<Operation_${operation_name}>(
const typename Operation_${operation_name}::ElementA* d_A, size_t lda,
const typename Operation_${operation_name}::ElementB* d_B, size_t ldb,
typename Operation_${operation_name}::ElementC* d_C, size_t ldc,
int* workspace,
cutlass::gemm::GemmCoord const& problem_size,
typename Operation_${operation_name}::EpilogueOutputOp::Params const& epilogue,
cudaStream_t stream, int split_k_slices);
instance_emitters = {
GemmKind.Gemm: EmitGemmInstance(),
GemmKind.SplitKParallel: EmitGemmSplitKParallelInstance(),
}
self.instance_emitter = instance_emitters[self.operation.gemm_kind]
self.header_template = """
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// ignore warnings from cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#pragma GCC diagnostic ignored "-Wuninitialized"
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/gemm/device/gemm_splitk_parallel.h"
#include "src/cuda/cutlass/manifest.h"
#include "src/cuda/cutlass/gemm_operation.h"
"""
self.instance_template = """
${operation_instance}
"""
self.manifest_template = """
namespace cutlass {
namespace library {
void initialize_${operation_name}(Manifest &manifest) {
manifest.append(new GemmOperation<
Operation_${operation_name}
>("${operation_name}"));
}
} // namespace library
} // namespace cutlass
"""
self.epilogue_template = """
#pragma GCC diagnostic pop
#endif
"""
#
def __enter__(self):
self.kernel_path = os.path.join(self.kernel_path, "%s.cu" % self.operation.procedural_name())
self.kernel_file = LazyFile(self.kernel_path)
self.kernel_file.write(self.header_template)
return self
#
def emit(self):
self.kernel_file.write(SubstituteTemplate(self.instance_template, {
'operation_instance': self.instance_emitter.emit(self.operation),
}))
gemv_wrapper = """
# emit manifest helper
manifest = SubstituteTemplate(self.manifest_template, {
'operation_name': self.operation.procedural_name(),
})
self.kernel_file.write(manifest)
#
def __exit__(self, exception_type, exception_value, traceback):
self.kernel_file.write(self.epilogue_template)
self.kernel_file.close()
###################################################################################################
###################################################################################################
class EmitGemvSingleKernelWrapper:
def __init__(self, kernel_path, gemm_operation, wrapper_path):
self.kernel_path = kernel_path
self.wrapper_path = wrapper_path
self.operation = gemm_operation
self.wrapper_template = """
template void megdnn::cuda::cutlass_wrapper::
cutlass_vector_matrix_mul_batched_strided_wrapper<Operation_${operation_name}>(
BatchedGemmCoord const& problem_size,
const typename Operation_${operation_name}::ElementA* d_A, size_t lda, size_t batch_stride_a,
const typename Operation_${operation_name}::ElementB* d_B, size_t ldb, size_t batch_stride_b,
typename Operation_${operation_name}::ElementCD* d_C, size_t ldc, size_t batch_stride_c,
cudaStream_t stream);
"""
if self.operation.gemm_kind == GemmKind.SplitKParallel or \
self.operation.gemm_kind == GemmKind.Gemm:
self.wrapper_template = gemm_wrapper
else:
assert self.operation.gemm_kind == GemmKind.GemvBatchedStrided
self.wrapper_template = gemv_wrapper
instance_emitters = {
GemmKind.Gemm: EmitGemmInstance(),
GemmKind.SplitKParallel: EmitGemmSplitKParallelInstance(),
GemmKind.GemvBatchedStrided: EmitGemvBatchedStridedInstance(),
}
self.instance_emitter = instance_emitters[self.operation.gemm_kind]
self.instance_emitter = EmitGemvBatchedStridedInstance()
self.header_template = """
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// ignore warnings from cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
......@@ -1055,10 +1108,10 @@ ${operation_instance}
"""
#
def __enter__(self):
self.kernel_path = os.path.join(self.kernel_path, "%s.cu" % self.operation.procedural_name())
self.kernel_file = LazyFile(self.kernel_path)
self.kernel_file.write(SubstituteTemplate(self.header_template, {
'wrapper_path': self.wrapper_path,
}))
return self
......@@ -1070,7 +1123,7 @@ ${operation_instance}
# emit wrapper
wrapper = SubstituteTemplate(self.wrapper_template, {
'operation_name': self.operation.procedural_name(),
})
self.kernel_file.write(wrapper)
......@@ -1079,7 +1132,5 @@ ${operation_instance}
self.kernel_file.write(self.epilogue_template)
self.kernel_file.close()
###################################################################################################
###################################################################################################
......@@ -23,6 +23,8 @@ def write_op_list(f, gen_op, gen_type):
operations = GenerateDeconvOperations(GenArg(gen_op, gen_type))
for op in operations:
f.write(' "%s.cu",\n' % op.procedural_name())
if gen_op != "gemv":
f.write(' "all_%s_%s_operations.cu",\n' % (gen_op, gen_type))
if __name__ == "__main__":
......
......@@ -292,7 +292,7 @@ def GenerateConv2d_TensorOp_8832(args):
]
operations += GenerateConv2d(ConvKind.Fprop, tile_descriptions, layout[0], layout[1],
dst_layout, dst_type, min_cc, 128, 128, 64,
True, ImplicitGemmMode.GemmTN, True)
False, ImplicitGemmMode.GemmTN, True)
layouts_nhwc = [
(LayoutType.TensorNHWC, LayoutType.TensorNC8HW8, 32),
......@@ -633,16 +633,10 @@ if __name__ == "__main__":
parser.add_argument("--type", type=str, choices=['simt', 'tensorop8816', 'tensorop8832'],
default='simt', help="kernel type of CUTLASS kernel generator")
operation2wrapper_path = {
"gemm": "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cuinl", \
"gemv": "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper_batched_gemv_strided.cuinl", \
"conv2d": "src/cuda/conv_bias/implicit_gemm_conv_bias_cutlass_wrapper.cuinl", \
"deconv": "src/cuda/convolution/backward_data/implicit_gemm_deconv_cutlass_wrapper.cuinl", \
}
gemv_wrapper_path = "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper_batched_gemv_strided.cuinl"
args = parser.parse_args()
wrapper_path = operation2wrapper_path[args.operations]
if args.operations == "gemm":
operations = GenerateGemmOperations(args)
elif args.operations == "gemv":
......@@ -652,16 +646,22 @@ if __name__ == "__main__":
elif args.operations == "deconv":
operations = GenerateDeconvOperations(args)
if args.operations == "conv2d" or args.operations == "deconv":
for operation in operations:
with EmitConvSingleKernelWrapper(args.output, operation, wrapper_path) as emitter:
with EmitConvSingleKernelWrapper(args.output, operation) as emitter:
emitter.emit()
elif args.operations == "gemm" or args.operations == "gemv":
elif args.operations == "gemm":
for operation in operations:
with EmitGemmSingleKernelWrapper(args.output, operation, wrapper_path) as emitter:
with EmitGemmSingleKernelWrapper(args.output, operation) as emitter:
emitter.emit()
elif args.operations == "gemv":
for operation in operations:
with EmitGemvSingleKernelWrapper(args.output, operation, gemv_wrapper_path) as emitter:
emitter.emit()
if args.operations != "gemv":
GenerateManifest(args, operations, args.output)
#
###################################################################################################
\ No newline at end of file
......@@ -137,6 +137,7 @@ cutlass_gen_list = [
"cutlass_simt_sgemm_split_k_parallel_256x32_8x2_tt_align1.cu",
"cutlass_simt_sgemm_256x64_8x2_tt_align1.cu",
"cutlass_simt_sgemm_split_k_parallel_256x64_8x2_tt_align1.cu",
"all_gemm_simt_operations.cu",
"cutlass_simt_sgemv_batched_strided_1x128_32_tt_align4x4.cu",
"cutlass_simt_sgemv_batched_strided_1x128_16_tt_align4x2.cu",
"cutlass_simt_sgemv_batched_strided_1x128_8_tt_align4x1.cu",
......@@ -169,6 +170,7 @@ cutlass_gen_list = [
"cutlass_simt_s8_idgrad_id_s8_16x128x16_16x64x16_2_nc4hw4_k4rsc4.cu",
"cutlass_simt_s8_idgrad_id_s8_16x128x16_16x128x16_1_nc4hw4_k4rsc4.cu",
"cutlass_simt_s8_idgrad_id_s8_16x64x8_16x64x8_2_nc4hw4_k4rsc4.cu",
"all_deconv_simt_operations.cu",
"cutlass_simt_s8_ifprop_id_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4.cu",
"cutlass_simt_s8_ifprop_1x1_id_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4.cu",
"cutlass_simt_s8_ifprop_relu_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4.cu",
......@@ -373,6 +375,7 @@ cutlass_gen_list = [
"cutlass_simt_f32_ifprop_1x1_relu_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4_nchw.cu",
"cutlass_simt_f32_ifprop_hswish_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4_nchw.cu",
"cutlass_simt_f32_ifprop_1x1_hswish_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4_nchw.cu",
"all_conv2d_simt_operations.cu",
"cutlass_tensorop_s8_i8816fprop_roc_id_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32.cu",
"cutlass_tensorop_s8_i8816fprop_1x1_roc_id_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32.cu",
"cutlass_tensorop_s8_i8816fprop_roc_relu_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32.cu",
......@@ -481,26 +484,47 @@ cutlass_gen_list = [
"cutlass_tensorop_s8_i8816fprop_1x1_relu_s8_32x128x32_32x64x32_1_nc32hw32_c32rsk32_nc4hw4.cu",
"cutlass_tensorop_s8_i8816fprop_hswish_s8_32x128x32_32x64x32_1_nc32hw32_c32rsk32_nc4hw4.cu",
"cutlass_tensorop_s8_i8816fprop_1x1_hswish_s8_32x128x32_32x64x32_1_nc32hw32_c32rsk32_nc4hw4.cu",
"all_conv2d_tensorop8816_operations.cu",
"cutlass_tensorop_s4_i8832fprop_roc_id_s4_128x256x128_64x64x128_2_nc64hw64_c64rsk64.cu",
"cutlass_tensorop_s4_i8832fprop_1x1_roc_id_s4_128x256x128_64x64x128_2_nc64hw64_c64rsk64.cu",
"cutlass_tensorop_s4_i8832fprop_roc_relu_s4_128x256x128_64x64x128_2_nc64hw64_c64rsk64.cu",
"cutlass_tensorop_s4_i8832fprop_1x1_roc_relu_s4_128x256x128_64x64x128_2_nc64hw64_c64rsk64.cu",
"cutlass_tensorop_s4_i8832fprop_roc_hswish_s4_128x256x128_64x64x128_2_nc64hw64_c64rsk64.cu",
"cutlass_tensorop_s4_i8832fprop_1x1_roc_hswish_s4_128x256x128_64x64x128_2_nc64hw64_c64rsk64.cu",
"cutlass_tensorop_s4_i8832fprop_roc_id_s4_128x128x128_64x64x128_2_nc64hw64_c64rsk64.cu",
"cutlass_tensorop_s4_i8832fprop_1x1_roc_id_s4_128x128x128_64x64x128_2_nc64hw64_c64rsk64.cu",
"cutlass_tensorop_s4_i8832fprop_roc_relu_s4_128x128x128_64x64x128_2_nc64hw64_c64rsk64.cu",
"cutlass_tensorop_s4_i8832fprop_1x1_roc_relu_s4_128x128x128_64x64x128_2_nc64hw64_c64rsk64.cu",
"cutlass_tensorop_s4_i8832fprop_roc_hswish_s4_128x128x128_64x64x128_2_nc64hw64_c64rsk64.cu",
"cutlass_tensorop_s4_i8832fprop_1x1_roc_hswish_s4_128x128x128_64x64x128_2_nc64hw64_c64rsk64.cu",
"cutlass_tensorop_s4_i8832fprop_roc_id_s4_128x64x128_64x64x128_2_nc64hw64_c64rsk64.cu",
"cutlass_tensorop_s4_i8832fprop_1x1_roc_id_s4_128x64x128_64x64x128_2_nc64hw64_c64rsk64.cu",
"cutlass_tensorop_s4_i8832fprop_roc_relu_s4_128x64x128_64x64x128_2_nc64hw64_c64rsk64.cu",
"cutlass_tensorop_s4_i8832fprop_1x1_roc_relu_s4_128x64x128_64x64x128_2_nc64hw64_c64rsk64.cu",
"cutlass_tensorop_s4_i8832fprop_roc_hswish_s4_128x64x128_64x64x128_2_nc64hw64_c64rsk64.cu",
"cutlass_tensorop_s4_i8832fprop_1x1_roc_hswish_s4_128x64x128_64x64x128_2_nc64hw64_c64rsk64.cu",
"cutlass_tensorop_s4_i8832fprop_roc_id_s4_128x64x64_64x64x64_1_nc64hw64_c64rsk64.cu",
"cutlass_tensorop_s4_i8832fprop_1x1_roc_id_s4_128x64x64_64x64x64_1_nc64hw64_c64rsk64.cu",
"cutlass_tensorop_s4_i8832fprop_roc_relu_s4_128x64x64_64x64x64_1_nc64hw64_c64rsk64.cu",
"cutlass_tensorop_s4_i8832fprop_1x1_roc_relu_s4_128x64x64_64x64x64_1_nc64hw64_c64rsk64.cu",
"cutlass_tensorop_s4_i8832fprop_roc_hswish_s4_128x64x64_64x64x64_1_nc64hw64_c64rsk64.cu",
"cutlass_tensorop_s4_i8832fprop_1x1_roc_hswish_s4_128x64x64_64x64x64_1_nc64hw64_c64rsk64.cu",
"cutlass_tensorop_u4_i8832fprop_roc_id_u4_s4_128x256x128_64x64x128_2_nc64hw64_c64rsk64.cu",
"cutlass_tensorop_u4_i8832fprop_1x1_roc_id_u4_s4_128x256x128_64x64x128_2_nc64hw64_c64rsk64.cu",
"cutlass_tensorop_u4_i8832fprop_roc_relu_u4_s4_128x256x128_64x64x128_2_nc64hw64_c64rsk64.cu",
"cutlass_tensorop_u4_i8832fprop_1x1_roc_relu_u4_s4_128x256x128_64x64x128_2_nc64hw64_c64rsk64.cu",
"cutlass_tensorop_u4_i8832fprop_roc_id_u4_s4_128x128x128_64x64x128_2_nc64hw64_c64rsk64.cu",
"cutlass_tensorop_u4_i8832fprop_1x1_roc_id_u4_s4_128x128x128_64x64x128_2_nc64hw64_c64rsk64.cu",
"cutlass_tensorop_u4_i8832fprop_roc_relu_u4_s4_128x128x128_64x64x128_2_nc64hw64_c64rsk64.cu",
"cutlass_tensorop_u4_i8832fprop_1x1_roc_relu_u4_s4_128x128x128_64x64x128_2_nc64hw64_c64rsk64.cu",
"cutlass_tensorop_u4_i8832fprop_roc_id_u4_s4_128x64x128_64x64x128_2_nc64hw64_c64rsk64.cu",
"cutlass_tensorop_u4_i8832fprop_1x1_roc_id_u4_s4_128x64x128_64x64x128_2_nc64hw64_c64rsk64.cu",
"cutlass_tensorop_u4_i8832fprop_roc_relu_u4_s4_128x64x128_64x64x128_2_nc64hw64_c64rsk64.cu",
"cutlass_tensorop_u4_i8832fprop_1x1_roc_relu_u4_s4_128x64x128_64x64x128_2_nc64hw64_c64rsk64.cu",
"cutlass_tensorop_u4_i8832fprop_roc_id_u4_s4_128x64x64_64x64x64_1_nc64hw64_c64rsk64.cu",
"cutlass_tensorop_u4_i8832fprop_1x1_roc_id_u4_s4_128x64x64_64x64x64_1_nc64hw64_c64rsk64.cu",
"cutlass_tensorop_u4_i8832fprop_roc_relu_u4_s4_128x64x64_64x64x64_1_nc64hw64_c64rsk64.cu",
"cutlass_tensorop_u4_i8832fprop_1x1_roc_relu_u4_s4_128x64x64_64x64x64_1_nc64hw64_c64rsk64.cu",
"cutlass_tensorop_s4_i8832fprop_id_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu",
"cutlass_tensorop_s4_i8832fprop_1x1_id_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu",
"cutlass_tensorop_s4_i8832fprop_relu_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu",
......@@ -621,4 +645,5 @@ cutlass_gen_list = [
"cutlass_tensorop_u4_i8832fprop_1x1_roc_id_u4_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu",
"cutlass_tensorop_u4_i8832fprop_roc_relu_u4_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu",
"cutlass_tensorop_u4_i8832fprop_1x1_roc_relu_u4_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu",
"all_conv2d_tensorop8832_operations.cu",
]
\ No newline at end of file
......@@ -8,6 +8,7 @@ import enum
import os.path
import shutil
from lazy_file import LazyFile
from library import *
from gemm_operation import *
from conv2d_operation import *
......@@ -349,3 +350,41 @@ void initialize_all(Manifest &manifest) {
#
###################################################################################################
def GenerateManifest(args, operations, output_dir):
manifest_path = os.path.join(output_dir, "all_%s_%s_operations.cu" % (args.operations, args.type))
f = LazyFile(manifest_path)
f.write("""
/*
Generated by generator.py - Do not edit.
*/
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
#include "cutlass/cutlass.h"
#include "src/cuda/cutlass/library.h"
#include "src/cuda/cutlass/manifest.h"
namespace cutlass {
namespace library {
""")
for op in operations:
f.write("void initialize_%s(Manifest &manifest);\n" % op.procedural_name())
f.write("""
void initialize_all_%s_%s_operations(Manifest &manifest) {
""" % (args.operations, args.type))
for op in operations:
f.write(" initialize_%s(manifest);\n" % op.procedural_name())
f.write("""
}
} // namespace library
} // namespace cutlass
#endif
""")
f.close()
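As a concrete illustration, running the generator with `--operations conv2d --type simt` makes this helper write an `all_conv2d_simt_operations.cu` along the following lines (only two of the per-kernel declarations are shown; the real file lists every generated operation):

/*
  Generated by generator.py - Do not edit.
*/

#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)

#include "cutlass/cutlass.h"
#include "src/cuda/cutlass/library.h"
#include "src/cuda/cutlass/manifest.h"

namespace cutlass {
namespace library {

void initialize_cutlass_simt_s8_ifprop_id_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4(Manifest &manifest);
void initialize_cutlass_simt_s8_ifprop_relu_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4(Manifest &manifest);

void initialize_all_conv2d_simt_operations(Manifest &manifest) {
  initialize_cutlass_simt_s8_ifprop_id_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4(manifest);
  initialize_cutlass_simt_s8_ifprop_relu_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4(manifest);
}

} // namespace library
} // namespace cutlass

#endif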
......@@ -217,68 +217,77 @@ void ConvBiasForwardImpl::AlgoPack::fill_imma_algos() {
#if CUDA_VERSION >= 10020
{
using AlgoParam = AlgoInt8NCHW32IMMAImplicitGemm::AlgoParam;
int8_nchw32_imma.emplace_back(AlgoParam{128, 256, 64, 64, 64, 64, 2});
int8_nchw32_imma.emplace_back(AlgoParam{256, 128, 64, 64, 64, 64, 2});
int8_nchw32_imma.emplace_back(AlgoParam{128, 128, 64, 64, 64, 64, 2});
int8_nchw32_imma.emplace_back(AlgoParam{128, 64, 64, 64, 32, 64, 2});
int8_nchw32_imma.emplace_back(AlgoParam{64, 128, 64, 32, 64, 64, 2});
int8_nchw32_imma.emplace_back(AlgoParam{128, 64, 32, 64, 32, 32, 1});
int8_nchw32_imma.emplace_back(AlgoParam{128, 32, 32, 64, 32, 32, 1});
int8_nchw32_imma.emplace_back(AlgoParam{64, 128, 32, 32, 64, 32, 1});
int8_nchw32_imma.emplace_back(AlgoParam{32, 128, 32, 32, 64, 32, 1});
int8_nchw32_imma.emplace_back(
AlgoParam{128, 256, 64, 64, 64, 64, 8, 8, 16, 2});
int8_nchw32_imma.emplace_back(
AlgoParam{256, 128, 64, 64, 64, 64, 8, 8, 16, 2});
int8_nchw32_imma.emplace_back(
AlgoParam{128, 128, 64, 64, 64, 64, 8, 8, 16, 2});
int8_nchw32_imma.emplace_back(
AlgoParam{128, 64, 64, 64, 32, 64, 8, 8, 16, 2});
int8_nchw32_imma.emplace_back(
AlgoParam{64, 128, 64, 32, 64, 64, 8, 8, 16, 2});
int8_nchw32_imma.emplace_back(
AlgoParam{128, 64, 32, 64, 32, 32, 8, 8, 16, 1});
int8_nchw32_imma.emplace_back(
AlgoParam{128, 32, 32, 64, 32, 32, 8, 8, 16, 1});
int8_nchw32_imma.emplace_back(
AlgoParam{64, 128, 32, 32, 64, 32, 8, 8, 16, 1});
int8_nchw32_imma.emplace_back(
AlgoParam{32, 128, 32, 32, 64, 32, 8, 8, 16, 1});
}
{
using AlgoParam = AlgoInt4Int4NCHW64IMMAImplicitGemm::AlgoParam;
int4_int4_nchw64_imma.emplace_back(
AlgoParam{128, 128, 128, 64, 64, 128, 2});
AlgoParam{128, 128, 128, 64, 64, 128, 8, 8, 32, 2});
int4_int4_nchw64_imma.emplace_back(
AlgoParam{128, 256, 128, 64, 64, 128, 2});
AlgoParam{128, 256, 128, 64, 64, 128, 8, 8, 32, 2});
int4_int4_nchw64_imma.emplace_back(
AlgoParam{128, 64, 128, 64, 64, 128, 2});
AlgoParam{128, 64, 128, 64, 64, 128, 8, 8, 32, 2});
int4_int4_nchw64_imma.emplace_back(
AlgoParam{128, 64, 64, 64, 64, 64, 1});
AlgoParam{128, 64, 64, 64, 64, 64, 8, 8, 32, 1});
}
{
using AlgoParam = AlgoUInt4Int4NCHW64IMMAImplicitGemm::AlgoParam;
uint4_int4_nchw64_imma.emplace_back(
AlgoParam{128, 128, 128, 64, 64, 128, 2});
AlgoParam{128, 128, 128, 64, 64, 128, 8, 8, 32, 2});
uint4_int4_nchw64_imma.emplace_back(
AlgoParam{128, 256, 128, 64, 64, 128, 2});
AlgoParam{128, 256, 128, 64, 64, 128, 8, 8, 32, 2});
uint4_int4_nchw64_imma.emplace_back(
AlgoParam{128, 64, 128, 64, 64, 128, 2});
AlgoParam{128, 64, 128, 64, 64, 128, 8, 8, 32, 2});
uint4_int4_nchw64_imma.emplace_back(
AlgoParam{128, 64, 64, 64, 64, 64, 1});
AlgoParam{128, 64, 64, 64, 64, 64, 8, 8, 32, 1});
}
{
using AlgoParam = AlgoInt4Int4NHWCIMMAImplicitGemm::AlgoParam;
int4_int4_nhwc_imma.emplace_back(
AlgoParam{128, 32, 64, 64, 32, 64, 1, 32});
AlgoParam{128, 32, 64, 64, 32, 64, 8, 8, 32, 1, 32});
int4_int4_nhwc_imma.emplace_back(
AlgoParam{128, 32, 64, 64, 32, 64, 1, 16});
AlgoParam{128, 32, 64, 64, 32, 64, 8, 8, 32, 1, 16});
int4_int4_nhwc_imma.emplace_back(
AlgoParam{128, 32, 64, 64, 32, 64, 1, 8});
AlgoParam{128, 32, 64, 64, 32, 64, 8, 8, 32, 1, 8});
int4_int4_nhwc_imma.emplace_back(
AlgoParam{128, 64, 64, 64, 64, 64, 1, 32});
AlgoParam{128, 64, 64, 64, 64, 64, 8, 8, 32, 1, 32});
int4_int4_nhwc_imma.emplace_back(
AlgoParam{128, 64, 64, 64, 64, 64, 1, 16});
AlgoParam{128, 64, 64, 64, 64, 64, 8, 8, 32, 1, 16});
int4_int4_nhwc_imma.emplace_back(
AlgoParam{128, 64, 64, 64, 64, 64, 1, 8});
AlgoParam{128, 64, 64, 64, 64, 64, 8, 8, 32, 1, 8});
}
{
using AlgoParam = AlgoUInt4Int4NHWCIMMAImplicitGemm::AlgoParam;
uint4_int4_nhwc_imma.emplace_back(
AlgoParam{128, 32, 64, 64, 32, 64, 1, 32});
AlgoParam{128, 32, 64, 64, 32, 64, 8, 8, 32, 1, 32});
uint4_int4_nhwc_imma.emplace_back(
AlgoParam{128, 32, 64, 64, 32, 64, 1, 16});
AlgoParam{128, 32, 64, 64, 32, 64, 8, 8, 32, 1, 16});
uint4_int4_nhwc_imma.emplace_back(
AlgoParam{128, 32, 64, 64, 32, 64, 1, 8});
AlgoParam{128, 32, 64, 64, 32, 64, 8, 8, 32, 1, 8});
uint4_int4_nhwc_imma.emplace_back(
AlgoParam{128, 64, 64, 64, 64, 64, 1, 32});
AlgoParam{128, 64, 64, 64, 64, 64, 8, 8, 32, 1, 32});
uint4_int4_nhwc_imma.emplace_back(
AlgoParam{128, 64, 64, 64, 64, 64, 1, 16});
AlgoParam{128, 64, 64, 64, 64, 64, 8, 8, 32, 1, 16});
uint4_int4_nhwc_imma.emplace_back(
AlgoParam{128, 64, 64, 64, 64, 64, 1, 8});
AlgoParam{128, 64, 64, 64, 64, 64, 8, 8, 32, 1, 8});
}
#endif
}
......@@ -286,15 +295,24 @@ void ConvBiasForwardImpl::AlgoPack::fill_imma_algos() {
void ConvBiasForwardImpl::AlgoPack::fill_dp4a_algos() {
using AlgoParam = AlgoInt8NCHW4DotProdImplicitGemm::AlgoParam;
int8_nchw4_dotprod.emplace_back(AlgoParam{128, 128, 32, 64, 32, 32, 2});
int8_nchw4_dotprod.emplace_back(AlgoParam{128, 64, 32, 64, 32, 32, 2});
int8_nchw4_dotprod.emplace_back(AlgoParam{64, 128, 32, 64, 32, 32, 2});
int8_nchw4_dotprod.emplace_back(AlgoParam{32, 128, 32, 32, 64, 32, 2});
int8_nchw4_dotprod.emplace_back(AlgoParam{128, 32, 32, 64, 32, 32, 2});
int8_nchw4_dotprod.emplace_back(AlgoParam{32, 64, 32, 32, 64, 32, 2});
int8_nchw4_dotprod.emplace_back(AlgoParam{64, 32, 32, 64, 32, 32, 2});
int8_nchw4_dotprod.emplace_back(AlgoParam{16, 128, 16, 16, 128, 16, 1});
int8_nchw4_dotprod.emplace_back(AlgoParam{16, 64, 8, 16, 64, 8, 2});
int8_nchw4_dotprod.emplace_back(
AlgoParam{128, 128, 32, 64, 32, 32, 1, 1, 4, 2});
int8_nchw4_dotprod.emplace_back(
AlgoParam{128, 64, 32, 64, 32, 32, 1, 1, 4, 2});
int8_nchw4_dotprod.emplace_back(
AlgoParam{64, 128, 32, 64, 32, 32, 1, 1, 4, 2});
int8_nchw4_dotprod.emplace_back(
AlgoParam{32, 128, 32, 32, 64, 32, 1, 1, 4, 2});
int8_nchw4_dotprod.emplace_back(
AlgoParam{128, 32, 32, 64, 32, 32, 1, 1, 4, 2});
int8_nchw4_dotprod.emplace_back(
AlgoParam{32, 64, 32, 32, 64, 32, 1, 1, 4, 2});
int8_nchw4_dotprod.emplace_back(
AlgoParam{64, 32, 32, 64, 32, 32, 1, 1, 4, 2});
int8_nchw4_dotprod.emplace_back(
AlgoParam{16, 128, 16, 16, 128, 16, 1, 1, 4, 1});
int8_nchw4_dotprod.emplace_back(
AlgoParam{16, 64, 8, 16, 64, 8, 1, 1, 4, 2});
}
ConvBiasForwardImpl::AlgoBase*
......
......@@ -28,6 +28,17 @@
#include <memory>
#include <unordered_map>
namespace cutlass {
namespace library {
// forward declaration of cutlass library concepts, so that algo.h does
// not need to depend on cutlass headers
class Operation;
} // namespace library
} // namespace cutlass
namespace megdnn {
namespace cuda {
......@@ -505,9 +516,44 @@ public:
MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_CHWN4_DOTPROD_INT8)
};
class ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm final
: public AlgoBase {
/*********************** Cutlass Algorithms ************************/
/* The inheritance hierarchy of the cutlass algorithm classes:
*
* AlgoCutlassConvolutionBase
* +
* +--- AlgoInt8NCHW4DotProdImplicitGemm
* +--- AlgoInt8NCHW32IMMAImplicitGemm
* +
* +--- AlgoInt4NCHW64IMMAImplicitGemmBase
* +----+--- AlgoInt4Int4NCHW64IMMAImplicitGemm
* +----+--- AlgoUInt4Int4NCHW64IMMAImplicitGemm
* +
* +--- AlgoInt4NHWCIMMAImplicitGemmBase
* +----+--- AlgoInt4Int4NHWCIMMAImplicitGemm
* +----+--- AlgoUInt4Int4NHWCIMMAImplicitGemm
* +
*/
/*
* The base class for all cutlass algorithm classes
*/
class ConvBiasForwardImpl::AlgoCutlassConvolutionBase : public AlgoBase {
public:
// corresponds to cutlass::conv::Operator; defined here so that algo.h does
// not depend on cutlass headers
enum class ConvOperator { kFprop, kDgrad, kWgrad };
// corresponds to cutlass::conv::ConvType; defined here so that algo.h does
// not depend on cutlass headers
enum class ConvType {
kConvolution,
kBatchConvolution,
kLocal,
kLocalShare
};
// common parameters for operation selection
struct AlgoParam {
int threadblock_m;
int threadblock_n;
......@@ -515,21 +561,54 @@ public:
int warp_m;
int warp_n;
int warp_k;
int instruction_m;
int instruction_n;
int instruction_k;
int stage;
std::string to_string() {
/// default algorithm
if (threadblock_m == 128 && threadblock_n == 128 &&
threadblock_k == 32 && warp_m == 32 && warp_n == 64 &&
warp_k == 32 && stage == 2) {
return "";
}
return ssprintf("_%dX%dX%d_%dX%dX%d_%dstage", threadblock_m,
threadblock_n, threadblock_k, warp_m, warp_n,
warp_k, stage);
}
int access_size;
AlgoParam(int threadblock_m_, int threadblock_n_, int threadblock_k_,
int warp_m_, int warp_n_, int warp_k_, int instruction_m_,
int instruction_n_, int instruction_k_, int stage_,
int access_size_ = 0);
std::string to_string() const;
};
AlgoCutlassConvolutionBase(AlgoParam algo_param)
: m_algo_param{algo_param} {}
// generate a cutlass::library::ConvolutionKey and find the corresponding
// operation (cutlass kernel) from the global OperationTable
const cutlass::library::Operation* get_cutlass_conv_op(
const SizeArgs& args, ConvOperator conv_op, ConvType conv_type,
bool load_from_const, bool without_shared_load) const;
// execute the cutlass kernel found by get_cutlass_conv_op. we give
// subclasses full freedom to decide where and how these arguments are
// extracted
void execute_cutlass_conv_op(const cutlass::library::Operation* op,
const void* src, const void* filter,
const void* bias, const void* z, void* dst,
void* workspace, size_t n, size_t hi,
size_t wi, size_t ci, size_t co, size_t fh,
size_t fw, size_t ho, size_t wo, size_t ph,
size_t pw, size_t sh, size_t sw, size_t dh,
size_t dw, const void* alpha, const void* beta,
const void* gamma, const void* delta,
const void* theta, const void* threshold,
const void* dst_scale, cudaStream_t stream,
const void* extra_param = nullptr) const;
protected:
AlgoParam m_algo_param;
};
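A minimal usage sketch, assuming a derived algorithm's exec() has already extracted its device pointers and problem sizes (the variable names below are placeholders for illustration, not part of the actual implementation):

// inside SomeDerivedAlgo::exec(const ExecArgs& args) -- illustrative only
const cutlass::library::Operation* op =
        get_cutlass_conv_op(args, ConvOperator::kFprop, ConvType::kConvolution,
                            /* load_from_const */ true,
                            /* without_shared_load */ false);
execute_cutlass_conv_op(op, src_ptr, filter_ptr, bias_ptr, z_ptr, dst_ptr,
                        workspace_ptr, n, hi, wi, ci, co, fh, fw, ho, wo,
                        ph, pw, sh, sw, dh, dw, &alpha, &beta, &gamma,
                        &delta, &theta, &threshold, &dst_scale, stream);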
class ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm final
: public AlgoCutlassConvolutionBase {
public:
AlgoInt8NCHW4DotProdImplicitGemm(AlgoParam algo_param)
: m_algo_param{algo_param},
: AlgoCutlassConvolutionBase(algo_param),
m_name{ssprintf("INT8_NCHW4_DOTPROD_IMPLICIT_GEMM%s",
m_algo_param.to_string().c_str())} {}
bool is_available(const SizeArgs& args) const override;
......@@ -555,7 +634,6 @@ public:
private:
WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr,
const SizeArgs& args) const;
AlgoParam m_algo_param;
std::string m_name;
};
......@@ -714,19 +792,10 @@ private:
#if CUDA_VERSION >= 10020
class ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm final
: public AlgoBase {
: public AlgoCutlassConvolutionBase {
public:
struct AlgoParam {
int threadblock_m;
int threadblock_n;
int threadblock_k;
int warp_m;
int warp_n;
int warp_k;
int stage;
};
AlgoInt8NCHW32IMMAImplicitGemm(AlgoParam algo_param)
: m_algo_param{algo_param} {
: AlgoCutlassConvolutionBase(algo_param) {
m_name = ConvBias::algo_name<ConvBias::DirectParam>(
ssprintf("INT8_NCHW32_IMMA_IMPLICIT_GEMM_%s",
to_string(m_algo_param).c_str()),
......@@ -757,25 +826,14 @@ private:
WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr,
const SizeArgs& args) const;
AlgoParam m_algo_param;
std::string m_name;
};
class ConvBiasForwardImpl::AlgoInt4NCHW64IMMAImplicitGemmBase
: public AlgoBase {
: public AlgoCutlassConvolutionBase {
public:
struct AlgoParam {
int threadblock_m;
int threadblock_n;
int threadblock_k;
int warp_m;
int warp_n;
int warp_k;
int stage;
};
AlgoInt4NCHW64IMMAImplicitGemmBase(AlgoParam algo_param)
: m_algo_param(algo_param) {}
: AlgoCutlassConvolutionBase(algo_param) {}
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
......@@ -799,16 +857,9 @@ protected:
virtual std::tuple<float, float, float, float, float> get_constants(
const ExecArgs& args) const = 0;
virtual void do_exec(const ExecArgs& args, void* filter_ptr, void* bias_ptr,
void* z_ptr, convolution::ConvParam kern_param,
uint32_t nonlinear_mode, float alpha, float beta,
float gamma, float delta, float theta,
cudaStream_t stream) const = 0;
void reorder_filter(const ExecArgs& args, void* reordered_filter) const;
std::string m_name;
AlgoParam m_algo_param;
};
class ConvBiasForwardImpl::AlgoInt4Int4NCHW64IMMAImplicitGemm final
......@@ -842,11 +893,6 @@ private:
std::tuple<float, float, float, float, float> get_constants(
const ExecArgs& args) const override;
void do_exec(const ExecArgs& args, void* filter_ptr, void* bias_ptr,
void* z_ptr, convolution::ConvParam kern_param,
uint32_t nonlinear_mode, float alpha, float beta, float gamma,
float delta, float theta, cudaStream_t stream) const override;
};
class ConvBiasForwardImpl::AlgoUInt4Int4NCHW64IMMAImplicitGemm final
......@@ -881,30 +927,15 @@ private:
std::tuple<float, float, float, float, float> get_constants(
const ExecArgs& args) const override;
void do_exec(const ExecArgs& args, void* filter_ptr, void* bias_ptr,
void* z_ptr, convolution::ConvParam kern_param,
uint32_t nonlinear_mode, float alpha, float beta, float gamma,
float delta, float theta, cudaStream_t stream) const override;
void update_bias(const ExecArgs& args, void* updated_bias,
void* reduce_filter_ptr, void* reduce_workspace) const;
};
class ConvBiasForwardImpl::AlgoInt4NHWCIMMAImplicitGemmBase : public AlgoBase {
class ConvBiasForwardImpl::AlgoInt4NHWCIMMAImplicitGemmBase
: public AlgoCutlassConvolutionBase {
public:
struct AlgoParam {
int threadblock_m;
int threadblock_n;
int threadblock_k;
int warp_m;
int warp_n;
int warp_k;
int stage;
int access_size;
};
AlgoInt4NHWCIMMAImplicitGemmBase(AlgoParam algo_param)
: m_algo_param(algo_param) {}
: AlgoCutlassConvolutionBase(algo_param) {}
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
......@@ -928,17 +959,10 @@ protected:
virtual std::tuple<float, float, float, float, float> get_constants(
const ExecArgs& args) const = 0;
virtual void do_exec(const ExecArgs& args, void* filter_ptr, void* bias_ptr,
void* z_ptr, convolution::ConvParam kern_param,
uint32_t nonlinear_mode, float alpha, float beta,
float gamma, float delta, float theta,
cudaStream_t stream) const = 0;
void reorder_filter(const ExecArgs& args, int interleaved,
void* reordered_filter) const;
std::string m_name;
AlgoParam m_algo_param;
};
class ConvBiasForwardImpl::AlgoInt4Int4NHWCIMMAImplicitGemm final
......@@ -971,11 +995,6 @@ private:
std::tuple<float, float, float, float, float> get_constants(
const ExecArgs& args) const override;
void do_exec(const ExecArgs& args, void* filter_ptr, void* bias_ptr,
void* z_ptr, convolution::ConvParam kern_param,
uint32_t nonlinear_mode, float alpha, float beta, float gamma,
float delta, float theta, cudaStream_t stream) const override;
};
class ConvBiasForwardImpl::AlgoUInt4Int4NHWCIMMAImplicitGemm final
......@@ -1009,11 +1028,6 @@ private:
std::tuple<float, float, float, float, float> get_constants(
const ExecArgs& args) const override;
void do_exec(const ExecArgs& args, void* filter_ptr, void* bias_ptr,
void* z_ptr, convolution::ConvParam kern_param,
uint32_t nonlinear_mode, float alpha, float beta, float gamma,
float delta, float theta, cudaStream_t stream) const override;
void update_bias(const ExecArgs& args, void* updated_bias,
void* reduce_filter_ptr, void* reduce_workspace) const;
};
......
/**
* \file dnn/src/cuda/conv_bias/cutlass_convolution_base.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/cuda/conv_bias/algo.h"
#include "src/cuda/cutlass/singleton.h"
namespace megdnn {
namespace cuda {
using namespace cutlass::library;
using namespace cutlass::epilogue;
ConvBiasForwardImpl::AlgoCutlassConvolutionBase::AlgoParam::AlgoParam(
int threadblock_m_, int threadblock_n_, int threadblock_k_, int warp_m_,
int warp_n_, int warp_k_, int instruction_m_, int instruction_n_,
int instruction_k_, int stage_, int access_size_)
: threadblock_m(threadblock_m_),
threadblock_n(threadblock_n_),
threadblock_k(threadblock_k_),
warp_m(warp_m_),
warp_n(warp_n_),
warp_k(warp_k_),
instruction_m(instruction_m_),
instruction_n(instruction_n_),
instruction_k(instruction_k_),
stage(stage_),
access_size(access_size_) {}
std::string
ConvBiasForwardImpl::AlgoCutlassConvolutionBase::AlgoParam::to_string() const {
/// default algorithm
if (threadblock_m == 128 && threadblock_n == 128 && threadblock_k == 32 &&
warp_m == 32 && warp_n == 64 && warp_k == 32 && stage == 2) {
return "";
}
return ssprintf("_%dX%dX%d_%dX%dX%d_%dstage", threadblock_m, threadblock_n,
threadblock_k, warp_m, warp_n, warp_k, stage);
}
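For illustration of the naming scheme, a sketch with hypothetical parameter values (only the threadblock/warp shapes and stage count participate in the suffix):

using Param = ConvBiasForwardImpl::AlgoCutlassConvolutionBase::AlgoParam;
// the default shape maps to an empty suffix, so the algorithm name carries no shape tag
Param{128, 128, 32, 32, 64, 32, 1, 1, 4, 2}.to_string();  // -> ""
// any other shape is encoded as _<threadblock>_<warp>_<stage>stage
Param{16, 64, 8, 16, 64, 8, 1, 1, 4, 2}.to_string();      // -> "_16X64X8_16X64X8_2stage"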
namespace {
using Base = ConvBiasForwardImpl::AlgoCutlassConvolutionBase;
cutlass::conv::Operator convert_conv_op(Base::ConvOperator conv_op) {
switch (conv_op) {
case Base::ConvOperator::kFprop:
return cutlass::conv::Operator::kFprop;
case Base::ConvOperator::kDgrad:
return cutlass::conv::Operator::kDgrad;
case Base::ConvOperator::kWgrad:
return cutlass::conv::Operator::kWgrad;
default:
megdnn_assert(0, "invalid conv op");
}
}
cutlass::conv::ConvType convert_conv_type(Base::ConvType conv_type) {
switch (conv_type) {
case Base::ConvType::kConvolution:
return cutlass::conv::ConvType::kConvolution;
case Base::ConvType::kBatchConvolution:
return cutlass::conv::ConvType::kBatchConvolution;
case Base::ConvType::kLocal:
return cutlass::conv::ConvType::kLocal;
case Base::ConvType::kLocalShare:
return cutlass::conv::ConvType::kLocalShare;
default:
megdnn_assert(0, "invalid conv type");
}
}
NumericTypeID convert_dtype(DTypeEnum dtype) {
switch (dtype) {
case DTypeEnum::Float32:
return NumericTypeID::kF32;
case DTypeEnum::Float16:
return NumericTypeID::kF16;
case DTypeEnum::Int8:
return NumericTypeID::kS8;
case DTypeEnum::QuantizedS32:
return NumericTypeID::kS32;
case DTypeEnum::QuantizedS8:
return NumericTypeID::kS8;
case DTypeEnum::QuantizedS4:
return NumericTypeID::kS4;
case DTypeEnum::Quantized4Asymm:
return NumericTypeID::kU4;
default:
megdnn_assert(0, "invalid dtype");
}
}
struct LayoutPack {
LayoutTypeID src;
LayoutTypeID filter;
LayoutTypeID dst;
LayoutTypeID bias;
};
LayoutPack get_layout_pack(const param::ConvBias::Format format,
int access_type) {
using Format = param::ConvBias::Format;
switch (format) {
case Format::NCHW4:
return {LayoutTypeID::kTensorNC4HW4, LayoutTypeID::kTensorC4RSK4,
LayoutTypeID::kTensorNC4HW4, LayoutTypeID::kTensorNC4HW4};
case Format::NCHW4_NCHW:
return {LayoutTypeID::kTensorNC4HW4, LayoutTypeID::kTensorC4RSK4,
LayoutTypeID::kTensorNCHW, LayoutTypeID::kTensorNCHW};
case Format::NCHW4_NHWC:
return {LayoutTypeID::kTensorNC4HW4, LayoutTypeID::kTensorC4RSK4,
LayoutTypeID::kTensorNHWC, LayoutTypeID::kTensorNHWC};
case Format::NCHW4_NCHW32:
return {LayoutTypeID::kTensorNC4HW4, LayoutTypeID::kTensorC4RSK4,
LayoutTypeID::kTensorNC32HW32,
LayoutTypeID::kTensorNC32HW32};
case Format::NCHW32:
return {LayoutTypeID::kTensorNC32HW32,
LayoutTypeID::kTensorC32RSK32,
LayoutTypeID::kTensorNC32HW32,
LayoutTypeID::kTensorNC32HW32};
case Format::NCHW32_NCHW4:
return {LayoutTypeID::kTensorNC32HW32,
LayoutTypeID::kTensorC32RSK32, LayoutTypeID::kTensorNC4HW4,
LayoutTypeID::kTensorNC4HW4};
case Format::NCHW64:
return {LayoutTypeID::kTensorNC64HW64,
LayoutTypeID::kTensorC64RSK64,
LayoutTypeID::kTensorNC64HW64,
LayoutTypeID::kTensorNC64HW64};
case Format::NHWC:
switch (access_type) {
case 8:
return {LayoutTypeID::kTensorNHWC,
LayoutTypeID::kTensorNC8HW8,
LayoutTypeID::kTensorNHWC,
LayoutTypeID::kTensorNHWC};
case 16:
return {LayoutTypeID::kTensorNHWC,
LayoutTypeID::kTensorNC16HW16,
LayoutTypeID::kTensorNHWC,
LayoutTypeID::kTensorNHWC};
case 32:
return {LayoutTypeID::kTensorNHWC,
LayoutTypeID::kTensorNC32HW32,
LayoutTypeID::kTensorNHWC,
LayoutTypeID::kTensorNHWC};
default:
megdnn_assert(0, "invalid access_type");
}
default:
megdnn_assert(0, "invalid format");
}
}
EpilogueType get_epilogue_type(const param::ConvBias::NonlineMode mode,
bool clamp) {
using NonlineMode = param::ConvBias::NonlineMode;
if (clamp) {
if (mode == NonlineMode::IDENTITY) {
return EpilogueType::kBiasAddLinearCombinationClamp;
} else if (mode == NonlineMode::RELU) {
return EpilogueType::kBiasAddLinearCombinationReluClamp;
} else if (mode == NonlineMode::H_SWISH) {
return EpilogueType::kBiasAddLinearCombinationHSwishClamp;
}
} else {
if (mode == NonlineMode::IDENTITY) {
return EpilogueType::kBiasAddLinearCombination;
} else if (mode == NonlineMode::RELU) {
return EpilogueType::kBiasAddLinearCombinationRelu;
} else if (mode == NonlineMode::H_SWISH) {
return EpilogueType::kBiasAddLinearCombinationHSwish;
}
}
megdnn_assert(0, "invalid nonlinear mode");
}
} // namespace
const Operation*
ConvBiasForwardImpl::AlgoCutlassConvolutionBase::get_cutlass_conv_op(
const SizeArgs& args, ConvOperator conv_op, ConvType conv_type,
bool load_from_const, bool without_shared_load) const {
using Format = param::ConvBias::Format;
auto&& param = args.opr->param();
auto layouts = get_layout_pack(param.format, m_algo_param.access_size);
auto epilogue_type = get_epilogue_type(param.nonlineMode,
param.format != Format::NCHW4_NCHW);
ConvolutionKey key{convert_conv_op(conv_op),
convert_dtype(args.src_layout->dtype.enumv()),
layouts.src,
convert_dtype(args.filter_layout->dtype.enumv()),
layouts.filter,
convert_dtype(args.dst_layout->dtype.enumv()),
layouts.dst,
convert_dtype(args.bias_layout->dtype.enumv()),
layouts.bias,
convert_conv_type(conv_type),
m_algo_param.threadblock_m,
m_algo_param.threadblock_n,
m_algo_param.threadblock_k,
m_algo_param.warp_m,
m_algo_param.warp_n,
m_algo_param.warp_k,
m_algo_param.instruction_m,
m_algo_param.instruction_n,
m_algo_param.instruction_k,
epilogue_type,
m_algo_param.stage,
load_from_const,
without_shared_load};
return Singleton::get().operation_table.find_op(key);
}
void ConvBiasForwardImpl::AlgoCutlassConvolutionBase::execute_cutlass_conv_op(
const Operation* op, const void* src, const void* filter,
const void* bias, const void* z, void* dst, void* workspace, size_t n,
size_t hi, size_t wi, size_t ci, size_t co, size_t fh, size_t fw,
size_t ho, size_t wo, size_t ph, size_t pw, size_t sh, size_t sw,
size_t dh, size_t dw, const void* alpha, const void* beta,
const void* gamma, const void* delta, const void* theta,
const void* threshold, const void* dst_scale, cudaStream_t stream,
const void* extra_param) const {
// gcc prints warnings when size_t values are implicitly narrowed to int
cutlass::conv::Conv2dProblemSize problem_size{
int(n), int(hi), int(wi), int(ci),
int(co), int(fh), int(fw), int(ho),
int(wo), int(ph), int(pw), int(sh),
int(sw), int(dh), int(dw), cutlass::conv::Mode::kCrossCorrelation};
ConvolutionArguments conv_args{
problem_size, src, filter, bias, z,
dst, alpha, beta, gamma, delta,
theta, threshold, dst_scale, extra_param};
cutlass_check(op->run(&conv_args, workspace, stream));
}
} // namespace cuda
} // namespace megdnn
/**
* \file dnn/src/cuda/conv_bias/cutlass_convolution_wrapper.cuh
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "cutlass/gemm/gemm.h"
#include "src/cuda/convolution_helper/parameter.cuh"
#include "src/cuda/utils.cuh"
namespace megdnn {
namespace cuda {
namespace cutlass_wrapper {
using GemmCoord = cutlass::gemm::GemmCoord;
template <typename Convolution>
void cutlass_convolution_wrapper(
const typename Convolution::ElementSrc* d_src,
const typename Convolution::ElementFilter* d_filter,
const typename Convolution::ElementBias* d_bias,
const typename Convolution::ElementDst* d_z,
typename Convolution::ElementDst* d_dst, int* workspace,
typename Convolution::ConvolutionParameter const& conv_param,
typename Convolution::EpilogueOutputOp::Params const& epilogue,
cudaStream_t stream, typename Convolution::ExtraParam extra_param = {});
template <bool NeedLoadFromConstMem>
void do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32(
const int8_t* d_src, const int8_t* d_filter, const int32_t* d_bias,
const int8_t* d_z, int8_t* d_dst, int* workspace,
const convolution::ConvParam& param, uint32_t nonlinear_mode,
float alpha, float beta, float gamma, float scale,
const GemmCoord& threadblock_shape, const GemmCoord& warp_shape,
int stages, cudaStream_t stream);
template <bool NeedLoadFromConstMem>
void do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4(
const int8_t* d_src, const int8_t* d_filter, const int32_t* d_bias,
const int8_t* d_z, int8_t* d_dst, int* workspace,
const convolution::ConvParam& param, uint32_t nonlinear_mode,
float alpha, float beta, float gamma, float scale,
const GemmCoord& threadblock_shape, const GemmCoord& warp_shape,
int stages, cudaStream_t stream);
template <bool NeedLoadFromConstMem>
void do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4(
const int8_t* d_src, const int8_t* d_filter, const int32_t* d_bias,
const int8_t* d_z, int8_t* d_dst, int* workspace,
const convolution::ConvParam& param, uint32_t nonlinear_mode,
float alpha, float beta, float gamma, float scale,
const GemmCoord& threadblock_shape, const GemmCoord& warp_shape,
int stages, cudaStream_t stream);
template <bool NeedLoadFromConstMem>
void do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nchw(
const int8_t* d_src, const int8_t* d_filter, const float* d_bias,
const float* d_z, float* d_dst, int* workspace,
const convolution::ConvParam& param, uint32_t nonlinear_mode,
float alpha, float beta, float gamma, float scale,
const GemmCoord& threadblock_shape, const GemmCoord& warp_shape,
int stages, cudaStream_t stream);
template <bool NeedLoadFromConstMem>
void do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32(
const int8_t* d_src, const int8_t* d_filter, const int32_t* d_bias,
const int8_t* d_z, int8_t* d_dst, int* workspace,
const convolution::ConvParam& param, uint32_t nonlinear_mode,
float alpha, float beta, float gamma, float scale,
const GemmCoord& threadblock_shape, const GemmCoord& warp_shape,
int stages, cudaStream_t stream);
template <bool NeedLoadFromConstMem>
void do_conv_bias_int4_int4_implicit_gemm_imma_ncdiv64hw64(
const int8_t* d_src, const int8_t* d_filter, const int32_t* d_bias,
const int8_t* d_z, int8_t* d_dst, int* workspace,
const convolution::ConvParam& param, uint32_t nonlinear_mode,
float alpha, float beta, float gamma, float scale,
const GemmCoord& threadblock_shape, const GemmCoord& warp_shape,
int stages, cudaStream_t stream);
template <bool NeedLoadFromConstMem>
void do_conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64(
const uint8_t* d_src, const int8_t* d_filter, const int32_t* d_bias,
const uint8_t* d_z, uint8_t* d_dst, int* workspace,
const convolution::ConvParam& param, uint32_t nonlinear_mode,
float alpha, float beta, float gamma, float delta, float theta,
float scale, uint8_t src_zero_point, const GemmCoord& threadblock_shape,
const GemmCoord& warp_shape, int stages, cudaStream_t stream);
template <bool signedness>
void do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nhwc(
const int8_t* d_src, const int8_t* d_filter, const int32_t* d_bias,
const int8_t* d_z, int8_t* d_dst, int* workspace,
const convolution::ConvParam& param, uint32_t nonlinear_mode,
float alpha, float beta, float gamma, float delta, float theta,
float scale, const GemmCoord& threadblock_shape,
const GemmCoord& warp_shape, int stages, cudaStream_t stream);
template <bool NeedLoadFromConstMem>
void do_conv_bias_int4_int4_implicit_gemm_imma_nhwc(
const int8_t* d_src, const int8_t* d_filter, const int32_t* d_bias,
const int8_t* d_z, int8_t* d_dst, int* workspace,
const convolution::ConvParam& param, uint32_t nonlinear_mode,
float alpha, float beta, float gamma, float scale,
const GemmCoord& threadblock_shape, const GemmCoord& warp_shape,
const int32_t access_size, int stages, cudaStream_t stream);
template <bool NeedLoadFromConstMem>
void do_conv_bias_uint4_int4_implicit_gemm_imma_nhwc(
const uint8_t* d_src, const int8_t* d_filter, const int32_t* d_bias,
const uint8_t* d_z, uint8_t* d_dst, int* workspace,
const convolution::ConvParam& param, uint32_t nonlinear_mode,
float alpha, float beta, float gamma, float delta, float theta,
float scale, uint8_t src_zero_point, const GemmCoord& threadblock_shape,
const GemmCoord& warp_shape, const int32_t access_size, int stages,
cudaStream_t stream);
} // namespace cutlass_wrapper
} // namespace cuda
} // namespace megdnn
// vim: syntax=cuda.doxygen
/**
* \file
* dnn/src/cuda/conv_bias/int8/implicit_gemm_conv_bias_cutlass_wrapper.cuinl
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "cutlass/convolution/device/convolution.h"
#include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh"
using namespace megdnn;
using namespace cuda;
using namespace cutlass_wrapper;
template <typename Convolution>
void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper(
const typename Convolution::ElementSrc* d_src,
const typename Convolution::ElementFilter* d_filter,
const typename Convolution::ElementBias* d_bias,
const typename Convolution::ElementDst* d_z,
typename Convolution::ElementDst* d_dst, int* workspace,
typename Convolution::ConvolutionParameter const& conv_param,
typename Convolution::EpilogueOutputOp::Params const& epilogue,
cudaStream_t stream, typename Convolution::ExtraParam extra_param) {
typename Convolution::TensorRefSrc tensor_src{
const_cast<typename Convolution::ElementSrc*>(d_src),
Convolution::LayoutSrc::packed(
{conv_param.N, conv_param.H, conv_param.W, conv_param.C})};
typename Convolution::TensorRefFilter tensor_filter{
const_cast<typename Convolution::ElementFilter*>(d_filter),
Convolution::LayoutFilter::packed(
{conv_param.K, conv_param.R, conv_param.S, conv_param.C})};
typename Convolution::TensorRefBias tensor_bias{
const_cast<typename Convolution::ElementBias*>(d_bias),
Convolution::LayoutBias::packed({1, 1, 1, conv_param.K})};
typename Convolution::TensorRefDst tensor_z{
const_cast<typename Convolution::ElementDst*>(d_z),
Convolution::LayoutDst::packed(
{conv_param.N, conv_param.P, conv_param.Q, conv_param.K})};
typename Convolution::TensorRefDst tensor_dst{
d_dst,
Convolution::LayoutDst::packed(
{conv_param.N, conv_param.P, conv_param.Q, conv_param.K})};
typename Convolution::Arguments arguments{conv_param,
tensor_src.non_const_ref(),
tensor_filter.non_const_ref(),
tensor_bias.non_const_ref(),
tensor_z.non_const_ref(),
tensor_dst.non_const_ref(),
epilogue,
{},
{},
extra_param};
Convolution conv_op;
cutlass_check(conv_op.initialize(arguments, workspace));
cutlass_check(conv_op(stream));
after_kernel_launch();
}
// vim: syntax=cuda.doxygen
......@@ -10,8 +10,7 @@
* implied.
*/
#include "./algo.h"
#include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh"
#include "src/cuda/conv_bias/algo.h"
using namespace megdnn;
using namespace cuda;
......@@ -81,29 +80,6 @@ ConvBiasForwardImpl::AlgoInt4Int4NCHW64IMMAImplicitGemm::get_constants(
return {alpha, beta, gamma, delta, theta};
}
void ConvBiasForwardImpl::AlgoInt4Int4NCHW64IMMAImplicitGemm::do_exec(
const ExecArgs& args, void* filter_ptr, void* bias_ptr, void* z_ptr,
ConvParam kern_param, uint32_t nonlinear_mode, float alpha, float beta,
float gamma, float delta, float theta, cudaStream_t stream) const {
float dst_scale = args.dst_layout->dtype.param<dtype::QuantizedS4>().scale;
cutlass_wrapper::GemmCoord threadblock_shape{m_algo_param.threadblock_m,
m_algo_param.threadblock_n,
m_algo_param.threadblock_k};
cutlass_wrapper::GemmCoord warp_shape{
m_algo_param.warp_m, m_algo_param.warp_n, m_algo_param.warp_k};
cutlass_wrapper::do_conv_bias_int4_int4_implicit_gemm_imma_ncdiv64hw64<
true>(reinterpret_cast<int8_t*>(args.src_tensor->raw_ptr),
reinterpret_cast<int8_t*>(filter_ptr),
reinterpret_cast<int32_t*>(bias_ptr),
reinterpret_cast<int8_t*>(z_ptr),
reinterpret_cast<int8_t*>(args.dst_tensor->raw_ptr), nullptr,
kern_param, nonlinear_mode, alpha, beta, gamma, dst_scale,
threadblock_shape, warp_shape, m_algo_param.stage, stream);
}
#endif
// vim: syntax=cpp.doxygen
......@@ -10,8 +10,7 @@
* implied.
*/
#include "./algo.h"
#include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh"
#include "src/cuda/conv_bias/algo.h"
using namespace megdnn;
using namespace cuda;
......@@ -81,42 +80,6 @@ ConvBiasForwardImpl::AlgoInt4Int4NHWCIMMAImplicitGemm::get_constants(
return {alpha, beta, gamma, delta, theta};
}
void ConvBiasForwardImpl::AlgoInt4Int4NHWCIMMAImplicitGemm::do_exec(
const ExecArgs& args, void* filter_ptr, void* bias_ptr, void* z_ptr,
ConvParam kern_param, uint32_t nonlinear_mode, float alpha, float beta,
float gamma, float delta, float theta, cudaStream_t stream) const {
float dst_scale = args.dst_layout->dtype.param<dtype::QuantizedS4>().scale;
cutlass_wrapper::GemmCoord threadblock_shape{m_algo_param.threadblock_m,
m_algo_param.threadblock_n,
m_algo_param.threadblock_k};
cutlass_wrapper::GemmCoord warp_shape{
m_algo_param.warp_m, m_algo_param.warp_n, m_algo_param.warp_k};
if (kern_param.fh == 1 && kern_param.fw == 1) {
cutlass_wrapper::do_conv_bias_int4_int4_implicit_gemm_imma_nhwc<false>(
reinterpret_cast<int8_t*>(args.src_tensor->raw_ptr),
reinterpret_cast<int8_t*>(filter_ptr),
reinterpret_cast<int32_t*>(bias_ptr),
reinterpret_cast<int8_t*>(z_ptr),
reinterpret_cast<int8_t*>(args.dst_tensor->raw_ptr), nullptr,
kern_param, nonlinear_mode, alpha, beta, gamma, dst_scale,
threadblock_shape, warp_shape, m_algo_param.access_size,
m_algo_param.stage, stream);
} else {
cutlass_wrapper::do_conv_bias_int4_int4_implicit_gemm_imma_nhwc<true>(
reinterpret_cast<int8_t*>(args.src_tensor->raw_ptr),
reinterpret_cast<int8_t*>(filter_ptr),
reinterpret_cast<int32_t*>(bias_ptr),
reinterpret_cast<int8_t*>(z_ptr),
reinterpret_cast<int8_t*>(args.dst_tensor->raw_ptr), nullptr,
kern_param, nonlinear_mode, alpha, beta, gamma, dst_scale,
threadblock_shape, warp_shape, m_algo_param.access_size,
m_algo_param.stage, stream);
}
}
#endif
// vim: syntax=cpp.doxygen
......@@ -10,10 +10,9 @@
* implied.
*/
#include "./algo.h"
#include "src/common/conv_bias.h"
#include "src/cuda/conv_bias/algo.h"
#include "src/cuda/conv_bias/cutlass_reorder_filter.cuh"
#include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh"
#include "src/cuda/conv_bias/reduce_filter.cuh"
#include "src/cuda/convolution_helper/parameter.cuh"
#include "src/cuda/utils.h"
......@@ -102,22 +101,40 @@ void ConvBiasForwardImpl::AlgoInt4NCHW64IMMAImplicitGemmBase::exec(
if (args.z_layout->ndim > 0)
z_ptr = args.z_tensor->raw_ptr;
// \note these cutlass epilogue constants are passed to the method
// `execute_cutlass_conv_op` by pointer and reinterpreted as ElementCompute*;
// using a different dtype here results in undefined epilogue behavior
float alpha, beta, gamma, delta, theta;
std::tie(alpha, beta, gamma, delta, theta) = get_constants(args);
float dst_scale = 0.f;
float threshold = 0.f;
uint8_t src_zero = 0;
bool load_from_const = !(fh == 1 && fw == 1);
bool without_shared_load = true;
if (args.dst_layout->dtype.enumv() == DTypeEnum::Quantized4Asymm) {
dst_scale =
args.dst_layout->dtype.param<dtype::Quantized4Asymm>().scale;
src_zero = args.src_layout->dtype.param<dtype::Quantized4Asymm>()
.zero_point;
} else { // DTypeEnum::QuantizedS4
dst_scale = args.dst_layout->dtype.param<dtype::QuantizedS4>().scale;
}
ConvParam kern_param;
kern_param.n = n, kern_param.co = co, kern_param.ci = ci,
kern_param.hi = hi, kern_param.wi = wi, kern_param.ho = ho,
kern_param.wo = wo, kern_param.ph = ph, kern_param.pw = pw,
kern_param.sh = sh, kern_param.sw = sw, kern_param.fh = fh,
kern_param.fw = fw;
cudaStream_t stream = cuda_stream(args.opr->handle());
uint32_t nonlinear_mode = static_cast<uint32_t>(param.nonlineMode);
const auto* op = get_cutlass_conv_op(args, ConvOperator::kFprop,
ConvType::kConvolution,
load_from_const, without_shared_load);
cudaStream_t stream = cuda_stream(args.opr->handle());
execute_cutlass_conv_op(op, args.src_tensor->raw_ptr, filter_ptr, bias_ptr,
z_ptr, args.dst_tensor->raw_ptr, nullptr, n, hi, wi,
ci, co, fh, fw, ho, wo, ph, pw, sh, sw, dh, dw,
&alpha, &beta, &gamma, &delta, &theta, &threshold,
&dst_scale, stream, &src_zero);
do_exec(args, filter_ptr, bias_ptr, z_ptr, kern_param, nonlinear_mode,
alpha, beta, gamma, delta, theta, stream);
after_kernel_launch();
}
std::string ConvBiasForwardImpl::AlgoInt4NCHW64IMMAImplicitGemmBase::to_string(
......
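The \note in exec above — the epilogue constants travel as void pointers and are read back as ElementCompute const* — is why alpha, beta, gamma, delta and theta must all be plain float. A minimal standalone sketch of the failure mode, with a hypothetical read_as_float standing in for the epilogue parameter initialization done later in convolution_operation.h:

#include <cstdio>

// Hypothetical stand-in for the epilogue parameter init: it only receives a
// void pointer and trusts the caller to have stored the epilogue's
// ElementCompute type (float for these kernels).
static void read_as_float(const void* p) {
    std::printf("%f\n", *static_cast<const float*>(p));
}

int main() {
    float ok = 1.25f;
    double wrong = 1.25;    // declaring the constant as double breaks the contract
    read_as_float(&ok);     // prints 1.250000
    read_as_float(&wrong);  // reinterprets the wrong bytes: garbage / undefined behavior
    return 0;
}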
......@@ -10,10 +10,9 @@
* implied.
*/
#include "./algo.h"
#include "src/common/conv_bias.h"
#include "src/cuda/conv_bias/algo.h"
#include "src/cuda/conv_bias/cutlass_reorder_filter.cuh"
#include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh"
#include "src/cuda/conv_bias/reduce_filter.cuh"
#include "src/cuda/convolution_helper/parameter.cuh"
#include "src/cuda/utils.h"
......@@ -109,22 +108,43 @@ void ConvBiasForwardImpl::AlgoInt4NHWCIMMAImplicitGemmBase::exec(
if (args.z_layout->ndim > 0)
z_ptr = args.z_tensor->raw_ptr;
// \note these constants of cutlass epilogue will be passed to method
// `execute_cutlass_conv_op` by pointer and interpreted as ElementCompute*,
// a different dtype here results in undefined epilogue behaviors
float alpha, beta, gamma, delta, theta;
std::tie(alpha, beta, gamma, delta, theta) = get_constants(args);
float dst_scale = 0.f;
float threshold = 0.f;
uint8_t src_zero = 0;
bool load_from_const = !(fh == 1 && fw == 1);
bool without_shared_load = ((co % m_algo_param.threadblock_n == 0) &&
(m_algo_param.threadblock_n == 32 ||
m_algo_param.threadblock_n == 64));
if (args.dst_layout->dtype.enumv() == DTypeEnum::Quantized4Asymm) {
dst_scale =
args.dst_layout->dtype.param<dtype::Quantized4Asymm>().scale;
src_zero = args.src_layout->dtype.param<dtype::Quantized4Asymm>()
.zero_point;
} else { // DTypeEnum::QuantizedS4
dst_scale = args.dst_layout->dtype.param<dtype::QuantizedS4>().scale;
}
ConvParam kern_param;
kern_param.n = n, kern_param.co = co, kern_param.ci = ci,
kern_param.hi = hi, kern_param.wi = wi, kern_param.ho = ho,
kern_param.wo = wo, kern_param.ph = ph, kern_param.pw = pw,
kern_param.sh = sh, kern_param.sw = sw, kern_param.fh = fh,
kern_param.fw = fw;
cudaStream_t stream = cuda_stream(args.opr->handle());
uint32_t nonlinear_mode = static_cast<uint32_t>(param.nonlineMode);
const auto* op = get_cutlass_conv_op(args, ConvOperator::kFprop,
ConvType::kConvolution,
load_from_const, without_shared_load);
cudaStream_t stream = cuda_stream(args.opr->handle());
execute_cutlass_conv_op(op, args.src_tensor->raw_ptr, filter_ptr, bias_ptr,
z_ptr, args.dst_tensor->raw_ptr, nullptr, n, hi, wi,
ci, co, fh, fw, ho, wo, ph, pw, sh, sw, dh, dw,
&alpha, &beta, &gamma, &delta, &theta, &threshold,
&dst_scale, stream, &src_zero);
do_exec(args, filter_ptr, bias_ptr, z_ptr, kern_param, nonlinear_mode,
alpha, beta, gamma, delta, theta, stream);
after_kernel_launch();
}
std::string ConvBiasForwardImpl::AlgoInt4NHWCIMMAImplicitGemmBase::to_string(
......
......@@ -10,12 +10,11 @@
* implied.
*/
#include "./algo.h"
#include "src/common/conv_bias.h"
#include "src/cuda/conv_bias/algo.h"
#include "src/cuda/conv_bias/cutlass_reorder_filter.cuh"
#include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh"
#include "src/cuda/convolution_helper/parameter.cuh"
#include "src/cuda/utils.h"
#include "src/common/conv_bias.h"
using namespace megdnn;
using namespace cuda;
......@@ -38,8 +37,7 @@ bool ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::is_available(
bool available = true;
auto&& param = args.opr->param();
auto&& fm = args.filter_meta;
if (!check_bias_share_in_channel(*(args.bias_layout),
param.format))
if (!check_bias_share_in_channel(*(args.bias_layout), param.format))
return false;
if (param.format != Format::NCHW32 && param.format != Format::NCHW32_NCHW4)
return false;
......@@ -137,19 +135,16 @@ void ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::exec(
args.preprocessed_filter->tensors[0].raw_ptr);
}
ConvParam kern_param;
kern_param.n = n, kern_param.co = co, kern_param.ci = ci,
kern_param.hi = hi, kern_param.wi = wi, kern_param.ho = ho,
kern_param.wo = wo, kern_param.ph = ph, kern_param.pw = pw,
kern_param.sh = sh, kern_param.sw = sw, kern_param.fh = fh,
kern_param.fw = fw;
float src_scale = args.src_layout->dtype.param<dtype::QuantizedS8>().scale,
filter_scale =
args.filter_layout->dtype.param<dtype::QuantizedS8>().scale,
bias_scale =
args.bias_layout->dtype.param<dtype::QuantizedS32>().scale,
dst_scale = args.dst_layout->dtype.param<dtype::QuantizedS8>().scale;
// \note these constants of cutlass epilogue will be passed to method
// `execute_cutlass_conv_op` by pointer and interpreted as ElementCompute*,
// a different dtype here results in undefined epilogue behaviors
float alpha = src_scale * filter_scale / dst_scale,
beta = bias_scale / dst_scale;
int8_t* z_dev_ptr = nullptr;
......@@ -159,80 +154,20 @@ void ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::exec(
float z_scale = args.z_layout->dtype.param<dtype::QuantizedS8>().scale;
gamma = z_scale / dst_scale;
}
uint32_t nonlinear_mode = static_cast<uint32_t>(param.nonlineMode);
if (fh == 1 && fw == 1) {
if (param.format == Format::NCHW32) {
cutlass_wrapper::do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32<
false>(
args.src_tensor->compatible_ptr<int8_t>(), filter_ptr,
args.bias_tensor->compatible_ptr<int32_t>(), z_dev_ptr,
args.dst_tensor->compatible_ptr<int8_t>(), nullptr,
kern_param, nonlinear_mode, alpha, beta, gamma, dst_scale,
cutlass_wrapper::GemmCoord{m_algo_param.threadblock_m,
m_algo_param.threadblock_n,
m_algo_param.threadblock_k},
cutlass_wrapper::GemmCoord{m_algo_param.warp_m,
m_algo_param.warp_n,
m_algo_param.warp_k},
m_algo_param.stage, stream);
} else {
megdnn_assert(param.format == Format::NCHW32_NCHW4);
cutlass_wrapper::
do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4<
false>(
args.src_tensor->compatible_ptr<int8_t>(),
filter_ptr,
args.bias_tensor->compatible_ptr<int32_t>(),
z_dev_ptr,
args.dst_tensor->compatible_ptr<int8_t>(), nullptr,
kern_param, nonlinear_mode, alpha, beta, gamma,
dst_scale,
cutlass_wrapper::GemmCoord{
m_algo_param.threadblock_m,
m_algo_param.threadblock_n,
m_algo_param.threadblock_k},
cutlass_wrapper::GemmCoord{m_algo_param.warp_m,
m_algo_param.warp_n,
m_algo_param.warp_k},
m_algo_param.stage, stream);
}
} else {
if (param.format == Format::NCHW32) {
cutlass_wrapper::do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32<
true>(
args.src_tensor->compatible_ptr<int8_t>(), filter_ptr,
args.bias_tensor->compatible_ptr<int32_t>(), z_dev_ptr,
args.dst_tensor->compatible_ptr<int8_t>(), nullptr,
kern_param, nonlinear_mode, alpha, beta, gamma, dst_scale,
cutlass_wrapper::GemmCoord{m_algo_param.threadblock_m,
m_algo_param.threadblock_n,
m_algo_param.threadblock_k},
cutlass_wrapper::GemmCoord{m_algo_param.warp_m,
m_algo_param.warp_n,
m_algo_param.warp_k},
m_algo_param.stage, stream);
} else {
megdnn_assert(param.format == Format::NCHW32_NCHW4);
cutlass_wrapper::
do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4<
true>(
args.src_tensor->compatible_ptr<int8_t>(),
filter_ptr,
args.bias_tensor->compatible_ptr<int32_t>(),
z_dev_ptr,
args.dst_tensor->compatible_ptr<int8_t>(), nullptr,
kern_param, nonlinear_mode, alpha, beta, gamma,
dst_scale,
cutlass_wrapper::GemmCoord{
m_algo_param.threadblock_m,
m_algo_param.threadblock_n,
m_algo_param.threadblock_k},
cutlass_wrapper::GemmCoord{m_algo_param.warp_m,
m_algo_param.warp_n,
m_algo_param.warp_k},
m_algo_param.stage, stream);
}
}
float delta = 0.f, theta = 0.f, threshold = 0.f;
bool load_from_const = !(fh == 1 && fw == 1);
bool without_shared_load = (param.format == Format::NCHW32);
const auto* op = get_cutlass_conv_op(args, ConvOperator::kFprop,
ConvType::kConvolution,
load_from_const, without_shared_load);
execute_cutlass_conv_op(
op, args.src_tensor->raw_ptr, filter_ptr, args.bias_tensor->raw_ptr,
z_dev_ptr, args.dst_tensor->raw_ptr, nullptr, n, hi, wi, ci, co, fh,
fw, ho, wo, ph, pw, sh, sw, dh, dw, &alpha, &beta, &gamma, &delta,
&theta, &threshold, &dst_scale, stream);
after_kernel_launch();
}
......@@ -249,9 +184,8 @@ size_t ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::
return 0_z;
}
SmallVector<TensorLayout> ConvBiasForwardImpl::
AlgoInt8NCHW32IMMAImplicitGemm::deduce_preprocessed_filter_layout(
const SizeArgs& args) const {
SmallVector<TensorLayout> ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::
deduce_preprocessed_filter_layout(const SizeArgs& args) const {
return {args.filter_layout->collapse_contiguous()};
}
......
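The alpha and beta computed in AlgoInt8NCHW32IMMAImplicitGemm::exec above fold the dequantize/requantize chain into the epilogue: alpha = src_scale * filter_scale / dst_scale rescales the int32 accumulator, and beta = bias_scale / dst_scale rescales the quantized bias. A small self-checking example with made-up scales (not taken from the diff):

#include <cassert>
#include <cmath>

int main() {
    // made-up scales, purely illustrative
    float src_scale = 0.5f, filter_scale = 0.25f, dst_scale = 2.0f;
    float bias_scale = src_scale * filter_scale;  // common convention for conv bias
    int acc = 40, bias_q = 8;                     // int32 accumulator and quantized bias
    float alpha = src_scale * filter_scale / dst_scale;
    float beta = bias_scale / dst_scale;
    // real-valued result, then requantized by dst_scale
    float dst_real = acc * src_scale * filter_scale + bias_q * bias_scale;
    assert(std::fabs(alpha * acc + beta * bias_q - dst_real / dst_scale) < 1e-6f);
    return 0;
}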
......@@ -6,14 +6,14 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "./algo.h"
#include "src/cuda/utils.h"
#include "src/cuda/convolution_helper/parameter.cuh"
#include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh"
#include "src/common/conv_bias.h"
#include "src/cuda/conv_bias/algo.h"
#include "src/cuda/convolution_helper/parameter.cuh"
#include "src/cuda/utils.h"
using namespace megdnn;
using namespace cuda;
......@@ -34,8 +34,7 @@ bool ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::is_available(
bool available = true;
auto&& param = args.opr->param();
auto&& fm = args.filter_meta;
if (!check_bias_share_in_channel(*(args.bias_layout),
param.format))
if (!check_bias_share_in_channel(*(args.bias_layout), param.format))
return false;
bool valid_format = param.format == Format::NCHW4_NCHW32 &&
m_algo_param.threadblock_m % 32 == 0;
......@@ -48,7 +47,8 @@ bool ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::is_available(
(args.dst_layout->dtype.enumv() == DTypeEnum::QuantizedS4 ||
args.dst_layout->dtype.enumv() == DTypeEnum::Quantized4Asymm);
valid_format |= param.format == Format::NCHW4;
if (!valid_format) return false;
if (!valid_format)
return false;
size_t n = args.src_layout->operator[](0),
ci = args.src_layout->operator[](1) * 4,
hi = args.src_layout->operator[](2),
......@@ -170,16 +170,13 @@ void ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::exec(
args.preprocessed_filter->tensors[0].raw_ptr);
}
convolution::ConvParam kern_param;
kern_param.n = n, kern_param.co = co, kern_param.ci = ci,
kern_param.hi = hi, kern_param.wi = wi, kern_param.ho = ho,
kern_param.wo = wo, kern_param.ph = ph, kern_param.pw = pw,
kern_param.sh = sh, kern_param.sw = sw, kern_param.fh = fh,
kern_param.fw = fw;
float src_scale = args.src_layout->dtype.param<dtype::QuantizedS8>().scale,
filter_scale =
args.filter_layout->dtype.param<dtype::QuantizedS8>().scale;
// \note these constants of cutlass epilogue will be passed to method
// `execute_cutlass_conv_op` by pointer and interpreted as ElementCompute*,
// a different dtype here results in undefined epilogue behaviors
float alpha = src_scale * filter_scale;
float beta = 1.f;
float dst_scale = 1.f;
......@@ -192,13 +189,15 @@ void ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::exec(
if (args.bias_layout->dtype.enumv() == DTypeEnum::QuantizedS32) {
megdnn_assert(args.dst_layout->dtype.category() ==
DTypeCategory::QUANTIZED);
float bias_scale = args.bias_layout->dtype.param<dtype::QuantizedS32>()
.scale;
float bias_scale =
args.bias_layout->dtype.param<dtype::QuantizedS32>().scale;
dst_scale = get_scale(args.dst_layout->dtype);
alpha /= dst_scale, beta = bias_scale / dst_scale;
}
float delta = 0.f;
void* z_ptr = nullptr;
if (args.z_layout->ndim > 0) {
z_ptr = args.z_tensor->raw_ptr;
gamma = 1.f;
if (args.z_layout->dtype.category() == DTypeCategory::QUANTIZED) {
megdnn_assert(args.dst_layout->dtype.category() ==
......@@ -213,98 +212,20 @@ void ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::exec(
delta = -z_zero * gamma;
}
}
uint32_t nonlinear_mode = static_cast<uint32_t>(param.nonlineMode);
bool nonunity_kernel = !(fh == 1 && fw == 1);
#define DISPATCH(_nonunity_kernel) \
if (nonunity_kernel == _nonunity_kernel) { \
cb(_nonunity_kernel) \
}
if (param.format == Format::NCHW4) {
#define cb(_nonunity_kernel) \
cutlass_wrapper::do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4< \
_nonunity_kernel>( \
args.src_tensor->compatible_ptr<int8_t>(), filter_ptr, \
args.bias_tensor->compatible_ptr<int32_t>(), \
args.z_tensor->compatible_ptr<int8_t>(), \
args.dst_tensor->compatible_ptr<int8_t>(), nullptr, kern_param, \
nonlinear_mode, alpha, beta, gamma, dst_scale, \
cutlass_wrapper::GemmCoord{m_algo_param.threadblock_m, \
m_algo_param.threadblock_n, \
m_algo_param.threadblock_k}, \
cutlass_wrapper::GemmCoord{m_algo_param.warp_m, \
m_algo_param.warp_n, \
m_algo_param.warp_k}, \
m_algo_param.stage, stream);
DISPATCH(true);
DISPATCH(false);
#undef cb
} else if (param.format == Format::NCHW4_NCHW) {
#define cb(_nonunity_kernel) \
cutlass_wrapper::do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nchw< \
_nonunity_kernel>( \
args.src_tensor->compatible_ptr<int8_t>(), filter_ptr, \
args.bias_tensor->compatible_ptr<float>(), \
args.z_tensor->compatible_ptr<float>(), \
args.dst_tensor->compatible_ptr<float>(), nullptr, kern_param, \
nonlinear_mode, alpha, beta, gamma, dst_scale, \
cutlass_wrapper::GemmCoord{m_algo_param.threadblock_m, \
m_algo_param.threadblock_n, \
m_algo_param.threadblock_k}, \
cutlass_wrapper::GemmCoord{m_algo_param.warp_m, \
m_algo_param.warp_n, \
m_algo_param.warp_k}, \
m_algo_param.stage, stream);
DISPATCH(true);
DISPATCH(false);
#undef cb
} else if (param.format == Format::NCHW4_NHWC) {
#define cb(_signedness) \
cutlass_wrapper::do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nhwc< \
_signedness>( \
args.src_tensor->compatible_ptr<int8_t>(), filter_ptr, \
args.bias_tensor->compatible_ptr<int32_t>(), \
reinterpret_cast<int8_t*>(args.z_tensor->raw_ptr), \
reinterpret_cast<int8_t*>(args.dst_tensor->raw_ptr), nullptr, \
kern_param, nonlinear_mode, alpha, beta, gamma, delta, theta, \
dst_scale, \
cutlass_wrapper::GemmCoord{m_algo_param.threadblock_m, \
m_algo_param.threadblock_n, \
m_algo_param.threadblock_k}, \
cutlass_wrapper::GemmCoord{m_algo_param.warp_m, \
m_algo_param.warp_n, \
m_algo_param.warp_k}, \
m_algo_param.stage, stream);
if (args.dst_layout->dtype.enumv() == DTypeEnum::QuantizedS4) {
cb(true);
} else {
megdnn_assert(args.dst_layout->dtype.enumv() ==
DTypeEnum::Quantized4Asymm);
cb(false);
}
#undef cb
} else {
megdnn_assert(param.format == Format::NCHW4_NCHW32);
#define cb(_nonunity_kernel) \
cutlass_wrapper:: \
do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32< \
_nonunity_kernel>( \
args.src_tensor->compatible_ptr<int8_t>(), filter_ptr, \
args.bias_tensor->compatible_ptr<int32_t>(), \
args.z_tensor->compatible_ptr<int8_t>(), \
args.dst_tensor->compatible_ptr<int8_t>(), nullptr, \
kern_param, nonlinear_mode, alpha, beta, gamma, dst_scale, \
cutlass_wrapper::GemmCoord{m_algo_param.threadblock_m, \
m_algo_param.threadblock_n, \
m_algo_param.threadblock_k}, \
cutlass_wrapper::GemmCoord{m_algo_param.warp_m, \
m_algo_param.warp_n, \
m_algo_param.warp_k}, \
m_algo_param.stage, stream);
DISPATCH(true);
DISPATCH(false);
#undef cb
#undef DISPATCH
}
float threshold = 0.f;
bool load_from_const = !(fh == 1 && fw == 1);
bool without_shared_load = false;
const auto* op = get_cutlass_conv_op(args, ConvOperator::kFprop,
ConvType::kConvolution,
load_from_const, without_shared_load);
execute_cutlass_conv_op(
op, args.src_tensor->raw_ptr, filter_ptr, args.bias_tensor->raw_ptr,
z_ptr, args.dst_tensor->raw_ptr, nullptr, n, hi, wi, ci, co, fh, fw,
ho, wo, ph, pw, sh, sw, dh, dw, &alpha, &beta, &gamma, &delta,
&theta, &threshold, &dst_scale, stream);
after_kernel_launch();
}
......
......@@ -10,8 +10,7 @@
* implied.
*/
#include "./algo.h"
#include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh"
#include "src/cuda/conv_bias/algo.h"
#include "src/cuda/conv_bias/reduce_filter.cuh"
#include "src/cuda/utils.h"
......@@ -120,32 +119,15 @@ ConvBiasForwardImpl::AlgoUInt4Int4NCHW64IMMAImplicitGemm::get_constants(
delta = -z_zero * gamma;
}
return {alpha, beta, gamma, delta, theta};
}
// identity epilogue has no theta:
// alpha * accumulator + beta * bias + gamma * source + delta
if (args.opr->param().nonlineMode ==
param::ConvBias::NonlineMode::IDENTITY) {
delta += theta;
theta = 0.f;
}
void ConvBiasForwardImpl::AlgoUInt4Int4NCHW64IMMAImplicitGemm::do_exec(
const ExecArgs& args, void* filter_ptr, void* bias_ptr, void* z_ptr,
ConvParam kern_param, uint32_t nonlinear_mode, float alpha, float beta,
float gamma, float delta, float theta, cudaStream_t stream) const {
float dst_scale =
args.dst_layout->dtype.param<dtype::Quantized4Asymm>().scale;
uint8_t src_zero =
args.src_layout->dtype.param<dtype::Quantized4Asymm>().zero_point;
cutlass_wrapper::GemmCoord threadblock_shape{m_algo_param.threadblock_m,
m_algo_param.threadblock_n,
m_algo_param.threadblock_k};
cutlass_wrapper::GemmCoord warp_shape{
m_algo_param.warp_m, m_algo_param.warp_n, m_algo_param.warp_k};
cutlass_wrapper::do_conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64<
true>(reinterpret_cast<uint8_t*>(args.src_tensor->raw_ptr),
reinterpret_cast<int8_t*>(filter_ptr),
reinterpret_cast<int32_t*>(bias_ptr),
reinterpret_cast<uint8_t*>(z_ptr),
reinterpret_cast<uint8_t*>(args.dst_tensor->raw_ptr), nullptr,
kern_param, nonlinear_mode, alpha, beta, gamma, delta, theta,
dst_scale, src_zero, threadblock_shape, warp_shape,
m_algo_param.stage, stream);
return {alpha, beta, gamma, delta, theta};
}
void ConvBiasForwardImpl::AlgoUInt4Int4NCHW64IMMAImplicitGemm::update_bias(
......
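The fold added in get_constants above (delta += theta; theta = 0 for IDENTITY) assumes theta would otherwise enter the epilogue only as an additive constant, so for the identity form alpha * accumulator + beta * bias + gamma * source + delta it can simply be absorbed into delta. A trivial check of that algebra with made-up values:

#include <cassert>
#include <cmath>

int main() {
    float alpha = 2.f, acc = 3.f, beta = 1.f, bias = 4.f, gamma = 0.5f, z = 8.f;
    float delta = 0.25f, theta = 0.75f;
    float with_theta = alpha * acc + beta * bias + gamma * z + delta + theta;
    float folded = alpha * acc + beta * bias + gamma * z + (delta + theta);
    assert(std::fabs(with_theta - folded) < 1e-6f);
    return 0;
}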
......@@ -10,8 +10,7 @@
* implied.
*/
#include "./algo.h"
#include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh"
#include "src/cuda/conv_bias/algo.h"
#include "src/cuda/conv_bias/reduce_filter.cuh"
#include "src/cuda/utils.h"
......@@ -121,44 +120,15 @@ ConvBiasForwardImpl::AlgoUInt4Int4NHWCIMMAImplicitGemm::get_constants(
delta = -z_zero * gamma;
}
return {alpha, beta, gamma, delta, theta};
}
void ConvBiasForwardImpl::AlgoUInt4Int4NHWCIMMAImplicitGemm::do_exec(
const ExecArgs& args, void* filter_ptr, void* bias_ptr, void* z_ptr,
ConvParam kern_param, uint32_t nonlinear_mode, float alpha, float beta,
float gamma, float delta, float theta, cudaStream_t stream) const {
float dst_scale =
args.dst_layout->dtype.param<dtype::Quantized4Asymm>().scale;
uint8_t src_zero =
args.src_layout->dtype.param<dtype::Quantized4Asymm>().zero_point;
cutlass_wrapper::GemmCoord threadblock_shape{m_algo_param.threadblock_m,
m_algo_param.threadblock_n,
m_algo_param.threadblock_k};
cutlass_wrapper::GemmCoord warp_shape{
m_algo_param.warp_m, m_algo_param.warp_n, m_algo_param.warp_k};
if (kern_param.fh == 1 && kern_param.fw == 1) {
cutlass_wrapper::do_conv_bias_uint4_int4_implicit_gemm_imma_nhwc<false>(
reinterpret_cast<uint8_t*>(args.src_tensor->raw_ptr),
reinterpret_cast<int8_t*>(filter_ptr),
reinterpret_cast<int32_t*>(bias_ptr),
reinterpret_cast<uint8_t*>(z_ptr),
reinterpret_cast<uint8_t*>(args.dst_tensor->raw_ptr), nullptr,
kern_param, nonlinear_mode, alpha, beta, gamma, delta, theta,
dst_scale, src_zero, threadblock_shape, warp_shape,
m_algo_param.access_size, m_algo_param.stage, stream);
} else {
cutlass_wrapper::do_conv_bias_uint4_int4_implicit_gemm_imma_nhwc<true>(
reinterpret_cast<uint8_t*>(args.src_tensor->raw_ptr),
reinterpret_cast<int8_t*>(filter_ptr),
reinterpret_cast<int32_t*>(bias_ptr),
reinterpret_cast<uint8_t*>(z_ptr),
reinterpret_cast<uint8_t*>(args.dst_tensor->raw_ptr), nullptr,
kern_param, nonlinear_mode, alpha, beta, gamma, delta, theta,
dst_scale, src_zero, threadblock_shape, warp_shape,
m_algo_param.access_size, m_algo_param.stage, stream);
// identity epilogue has no theta:
// alpha * accumulator + beta * bias + gamma * source + delta
if (args.opr->param().nonlineMode ==
param::ConvBias::NonlineMode::IDENTITY) {
delta += theta;
theta = 0.f;
}
return {alpha, beta, gamma, delta, theta};
}
void ConvBiasForwardImpl::AlgoUInt4Int4NHWCIMMAImplicitGemm::update_bias(
......
......@@ -57,6 +57,7 @@ public:
class AlgoBatchedMatmul;
class AlgoGroupConvGeneral;
class AlgoQUInt4x4x32WMMA;
class AlgoCutlassConvolutionBase;
class AlgoInt8CHWN4DotProdImplicitGemm;
class AlgoInt8NCHW4DotProdImplicitGemm;
class AlgoInt8CHWN4IMMAImplicitGemm;
......
/**
* \file src/cuda/convolution/backward_data/cutlass_deconvolution_wrapper.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#if !MEGDNN_TEGRA_X1
#include "cutlass/convolution/device/convolution.h"
#endif
#include "src/common/opr_param_defs_enumv.cuh"
#include "src/cuda/convolution/backward_data/cutlass_deconvolution_wrapper.cuh"
#pragma GCC diagnostic pop
using namespace megdnn;
using namespace cuda;
using namespace cutlass_wrapper;
/* ================ cutlass kernel wrapper for nchw4 layout ================= */
#if MEGDNN_TEGRA_X1
void megdnn::cuda::cutlass_wrapper::do_deconv_int8_implicit_gemm_dp4a_ncdiv4hw4(
const int8_t* /* d_src */, const int8_t* /* d_filter */,
int8_t* /* d_dst */, int* /* workspace */,
const convolution::ConvParam& /* param */, float /* alpha */,
const GemmCoord& /* threadblock_shape */,
const GemmCoord& /* warp_shape */, int /* stages */,
cudaStream_t /* stream */) {}
#else
void megdnn::cuda::cutlass_wrapper::do_deconv_int8_implicit_gemm_dp4a_ncdiv4hw4(
const int8_t* d_src, const int8_t* d_filter, int8_t* d_dst,
int* workspace, const convolution::ConvParam& param, float alpha,
const GemmCoord& threadblock_shape, const GemmCoord& warp_shape,
int stages, cudaStream_t stream) {
#define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \
threadblock_k_, warp_m_, warp_n_, \
warp_k_, stage_, aligned_) \
if (threadblock_shape.m() == threadblock_m_ && \
threadblock_shape.n() == threadblock_n_ && \
threadblock_shape.k() == threadblock_k_ && \
warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \
warp_shape.k() == warp_k_ && stages == stage_) { \
using ThreadBlockShape = \
cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \
threadblock_k_>; \
using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \
using InstructionShape = cutlass::gemm::GemmShape<1, 1, 4>; \
using Deconvolution = cutlass::conv::device::Deconvolution< \
int8_t, cutlass::layout::TensorNCxHWx<4>, int8_t, \
cutlass::layout::TensorKxRSCx<4>, ElementOutput, \
cutlass::layout::TensorNCxHWx<4>, int32_t, \
cutlass::layout::TensorNCxHWx<4>, int32_t, \
cutlass::arch::OpClassSimt, cutlass::arch::Sm61, \
ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \
cutlass::conv::threadblock:: \
ConvolutionDgradNCxHWxThreadblockSwizzle, \
stage_, 4, aligned_, true, cutlass::arch::OpMultiplyAdd>; \
typename Deconvolution::ConvolutionParameter conv_param( \
param.n, param.hi, param.wi, param.ci, param.co, param.fh, \
param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \
param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \
return cutlass_deconvolution_wrapper<Deconvolution>( \
d_src, d_filter, nullptr, nullptr, d_dst, workspace, \
conv_param, epilogue, stream); \
}
#define DISPATCH_KERNEL \
DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 64, 8, 16, 64, 8, 2, 4); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 128, 16, 16, 64, 16, 2, 4); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 128, 16, 16, 128, 16, 1, 8); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 128, 32, 32, 64, 32, 2, 16); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 32, 64, 32, 32, 2, 16); \
megdnn_assert(false, \
"unsupported threadblock shape (%dx%dx%d) and warp shape " \
"(%dx%dx%d)", \
threadblock_shape.m(), threadblock_shape.n(), \
threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \
warp_shape.k());
using ElementOutput = int8_t;
using ElementAccumulator = int32_t;
using ElementBias = int32_t;
using ElementCompute = float;
using EpilogueOp = cutlass::epilogue::thread::BiasAddLinearCombinationClamp<
ElementOutput, 4, ElementAccumulator, ElementBias, ElementCompute>;
typename EpilogueOp::Params epilogue{alpha, 0, 0};
DISPATCH_KERNEL;
#undef DISPATCH_KERNEL_WITH_TILE_SHAPE
#undef DISPATCH_KERNEL
}
#endif
// vim: syntax=cuda.doxygen
/**
* \file src/cuda/convolution/backward_data/cutlass_deconvolution_wrapper.cuh
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "cutlass/gemm/gemm.h"
#include "src/cuda/convolution_helper/parameter.cuh"
#include "src/cuda/utils.cuh"
namespace megdnn {
namespace cuda {
namespace cutlass_wrapper {
using GemmCoord = cutlass::gemm::GemmCoord;
template <typename Convolution>
void cutlass_deconvolution_wrapper(
const typename Convolution::ElementSrc* d_src,
const typename Convolution::ElementFilter* d_filter,
const typename Convolution::ElementBias* d_bias,
const typename Convolution::ElementDst* d_z,
typename Convolution::ElementDst* d_dst, int* workspace,
typename Convolution::ConvolutionParameter const& conv_param,
typename Convolution::EpilogueOutputOp::Params const& epilogue,
cudaStream_t stream);
void do_deconv_int8_implicit_gemm_dp4a_ncdiv4hw4(
const int8_t* d_src, const int8_t* d_filter, int8_t* d_dst,
int* workspace, const convolution::ConvParam& param, float alpha,
const GemmCoord& threadblock_shape, const GemmCoord& warp_shape,
int stages, cudaStream_t stream);
} // namespace cutlass_wrapper
} // namespace cuda
} // namespace megdnn
// vim: syntax=cuda.doxygen
/**
* \file
* dnn/src/cuda/convolution/backward_data/implicit_gemm_deconv_cutlass_wrapper.cuinl
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "cutlass/convolution/device/convolution.h"
#include "src/cuda/convolution/backward_data/cutlass_deconvolution_wrapper.cuh"
using namespace megdnn;
using namespace cuda;
using namespace cutlass_wrapper;
template <typename Deconvolution>
void megdnn::cuda::cutlass_wrapper::cutlass_deconvolution_wrapper(
const typename Deconvolution::ElementSrc* d_src,
const typename Deconvolution::ElementFilter* d_filter,
const typename Deconvolution::ElementBias* d_bias,
const typename Deconvolution::ElementDst* d_z,
typename Deconvolution::ElementDst* d_dst, int* workspace,
typename Deconvolution::ConvolutionParameter const& conv_param,
typename Deconvolution::EpilogueOutputOp::Params const& epilogue,
cudaStream_t stream) {
typename Deconvolution::TensorRefSrc tensor_src{
const_cast<typename Deconvolution::ElementSrc*>(d_src),
Deconvolution::LayoutSrc::packed(
{conv_param.N, conv_param.P, conv_param.Q, conv_param.K})};
typename Deconvolution::TensorRefFilter tensor_filter{
const_cast<typename Deconvolution::ElementFilter*>(d_filter),
Deconvolution::LayoutFilter::packed(
{conv_param.K, conv_param.R, conv_param.S, conv_param.C})};
typename Deconvolution::TensorRefBias tensor_bias{
const_cast<typename Deconvolution::ElementBias*>(d_bias),
Deconvolution::LayoutBias::packed({1, 1, 1, conv_param.K})};
typename Deconvolution::TensorRefDst tensor_z{
const_cast<typename Deconvolution::ElementDst*>(d_z),
Deconvolution::LayoutDst::packed(
{conv_param.N, conv_param.H, conv_param.W, conv_param.C})};
typename Deconvolution::TensorRefDst tensor_dst{
d_dst,
Deconvolution::LayoutDst::packed(
{conv_param.N, conv_param.H, conv_param.W, conv_param.C})};
typename Deconvolution::Arguments arguments{conv_param,
tensor_src.non_const_ref(),
tensor_filter.non_const_ref(),
tensor_bias.non_const_ref(),
tensor_z.non_const_ref(),
tensor_dst.non_const_ref(),
epilogue};
Deconvolution deconv_op;
cutlass_check(deconv_op.initialize(arguments, workspace));
cutlass_check(deconv_op(stream));
after_kernel_launch();
}
// vim: syntax=cuda.doxygen
/**
* \file dnn/src/cuda/convolution/backward_data/implicit_gemm_int8_nchw4_dp4a.cpp
* \file
* dnn/src/cuda/convolution/backward_data/implicit_gemm_int8_nchw4_dp4a.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
......@@ -10,11 +11,11 @@
* implied.
*/
#include "./algo.h"
#include "src/cuda/utils.h"
#include "src/cuda/convolution_helper/parameter.cuh"
#include "src/cuda/convolution/backward_data/cutlass_deconvolution_wrapper.cuh"
#include "src/cuda/convolution/backward_data/algo.h"
#include "src/cuda/convolution/backward_data/deconv_int8_helper.cuh"
#include "src/cuda/convolution_helper/parameter.cuh"
#include "src/cuda/cutlass/singleton.h"
#include "src/cuda/utils.h"
using namespace megdnn;
using namespace cuda;
......@@ -70,6 +71,7 @@ size_t ConvolutionBackwardDataImpl::AlgoInt8NCHW4DotProdImplicitGemm::
void ConvolutionBackwardDataImpl::AlgoInt8NCHW4DotProdImplicitGemm::exec(
const ExecArgs& args) const {
auto&& param = args.opr->param();
auto&& fm = args.filter_meta;
size_t n = args.diff_layout->operator[](0),
co = args.diff_layout->operator[](1) * 4,
......@@ -81,6 +83,7 @@ void ConvolutionBackwardDataImpl::AlgoInt8NCHW4DotProdImplicitGemm::exec(
size_t fh = fm.spatial[0], fw = fm.spatial[1];
size_t sh = fm.stride[0], sw = fm.stride[1];
size_t ph = fm.padding[0], pw = fm.padding[1];
size_t dh = param.dilate_h, dw = param.dilate_w;
auto&& stream = cuda_stream(args.opr->handle());
......@@ -93,12 +96,6 @@ void ConvolutionBackwardDataImpl::AlgoInt8NCHW4DotProdImplicitGemm::exec(
filter_ptr, args.filter_tensor->compatible_ptr<int8_t>(), co,
ci, fh, fw, stream);
}
convolution::ConvParam kern_param;
kern_param.n = n, kern_param.co = co, kern_param.ci = ci,
kern_param.hi = hi, kern_param.wi = wi, kern_param.ho = ho,
kern_param.wo = wo, kern_param.ph = ph, kern_param.pw = pw,
kern_param.sh = sh, kern_param.sw = sw, kern_param.fh = fh,
kern_param.fw = fw;
float diff_scale =
args.diff_layout->dtype.param<dtype::QuantizedS8>().scale,
......@@ -106,17 +103,60 @@ void ConvolutionBackwardDataImpl::AlgoInt8NCHW4DotProdImplicitGemm::exec(
args.filter_layout->dtype.param<dtype::QuantizedS8>().scale,
grad_scale =
args.grad_layout->dtype.param<dtype::QuantizedS8>().scale;
float alpha = diff_scale * filter_scale / grad_scale;
cutlass_wrapper::do_deconv_int8_implicit_gemm_dp4a_ncdiv4hw4(
args.diff_tensor->compatible_ptr<int8_t>(), filter_ptr,
args.grad_tensor->compatible_ptr<int8_t>(), nullptr, kern_param,
alpha,
cutlass_wrapper::GemmCoord{m_algo_param.threadblock_m,
m_algo_param.threadblock_n,
m_algo_param.threadblock_k},
cutlass_wrapper::GemmCoord{m_algo_param.warp_m, m_algo_param.warp_n,
m_algo_param.warp_k},
m_algo_param.stage, stream);
// \note these constants of cutlass epilogue will be passed to struct
// `ConvolutionArguments` by pointer and interpreted as ElementCompute*, a
// different dtype here results in undefined epilogue behaviors
float alpha = diff_scale * filter_scale / grad_scale, beta = 0.f,
gamma = 0.f, delta = 0.f;
using namespace cutlass::library;
// only use 16x64x8_16x64x8_2stages impl
ConvolutionKey key{
cutlass::conv::Operator::kDgrad,
NumericTypeID::kS8,
LayoutTypeID::kTensorNC4HW4,
NumericTypeID::kS8,
LayoutTypeID::kTensorK4RSC4,
NumericTypeID::kS8,
LayoutTypeID::kTensorNC4HW4,
NumericTypeID::kS32,
LayoutTypeID::kTensorNC4HW4,
cutlass::conv::ConvType::kConvolution,
m_algo_param.threadblock_m,
m_algo_param.threadblock_n,
m_algo_param.threadblock_k,
m_algo_param.warp_m,
m_algo_param.warp_n,
m_algo_param.warp_k,
1,
1,
4,
cutlass::epilogue::EpilogueType::kBiasAddLinearCombinationClamp,
m_algo_param.stage,
true,
false};
const Operation* op = Singleton::get().operation_table.find_op(key);
// gcc prints warnings when size_t values are implicitly narrowed to int
cutlass::conv::Conv2dProblemSize problem_size{
int(n), int(hi), int(wi), int(ci),
int(co), int(fh), int(fw), int(ho),
int(wo), int(ph), int(pw), int(sh),
int(sw), int(dh), int(dw), cutlass::conv::Mode::kCrossCorrelation};
cutlass::library::ConvolutionArguments conv_args{
problem_size, args.diff_tensor->compatible_ptr<int8_t>(),
filter_ptr, nullptr,
nullptr, args.grad_tensor->compatible_ptr<int8_t>(),
&alpha, &beta,
&gamma, &delta,
nullptr, nullptr,
nullptr, nullptr};
cutlass_check(op->run(&conv_args, nullptr, stream));
after_kernel_launch();
}
......
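The int(...) casts feeding Conv2dProblemSize above exist because brace initialization rejects size_t-to-int narrowing, as the comment notes. A minimal illustration with a toy struct (not the cutlass type):

#include <cstddef>

struct Extent { int n, h, w, c; };  // toy stand-in for a problem-size type

int main() {
    std::size_t n = 16, h = 32, w = 32, c = 64;
    // Extent bad{n, h, w, c};                  // narrowing size_t -> int: gcc/clang diagnose this
    Extent ok{int(n), int(h), int(w), int(c)};  // explicit casts keep the build quiet
    return ok.n == 16 ? 0 : 1;
}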
......@@ -11,16 +11,16 @@
* implied.
*/
#include "./algo.h"
#include "src/cuda/utils.h"
#include "src/cuda/convolution/backward_data/algo.h"
#include "src/cuda/convolution_helper/parameter.cuh"
#include "src/cuda/convolution/backward_data/cutlass_deconvolution_wrapper.cuh"
#include "src/cuda/cutlass/singleton.h"
#include "src/cuda/utils.h"
using namespace megdnn;
using namespace cuda;
bool ConvolutionBackwardDataImpl::AlgoInt8NCHWDotProdImplicitGemm::
is_available(const SizeArgs& args) const {
bool ConvolutionBackwardDataImpl::AlgoInt8NCHWDotProdImplicitGemm::is_available(
const SizeArgs& args) const {
auto&& fm = args.filter_meta;
if (fm.format != Param::Format::NCHW)
return false;
......@@ -42,7 +42,8 @@ bool ConvolutionBackwardDataImpl::AlgoInt8NCHWDotProdImplicitGemm::
// TODO support group deconv int8
available &= (fm.group == 1);
// ic and oc must be multiples of 4
available &= ((fm.group * fm.icpg) % 4 == 0 && (fm.group * fm.ocpg) % 4 == 0);
available &=
((fm.group * fm.icpg) % 4 == 0 && (fm.group * fm.ocpg) % 4 == 0);
// mode must be cross correlation
available &= !fm.should_flip;
// mode must be 2D
......@@ -73,6 +74,7 @@ size_t ConvolutionBackwardDataImpl::AlgoInt8NCHWDotProdImplicitGemm::
void ConvolutionBackwardDataImpl::AlgoInt8NCHWDotProdImplicitGemm::exec(
const ExecArgs& args) const {
auto&& param = args.opr->param();
auto&& fm = args.filter_meta;
size_t n = args.diff_layout->operator[](0),
co = args.diff_layout->operator[](1),
......@@ -84,6 +86,7 @@ void ConvolutionBackwardDataImpl::AlgoInt8NCHWDotProdImplicitGemm::exec(
size_t fh = fm.spatial[0], fw = fm.spatial[1];
size_t sh = fm.stride[0], sw = fm.stride[1];
size_t ph = fm.padding[0], pw = fm.padding[1];
size_t dh = param.dilate_h, dw = param.dilate_w;
auto&& stream = cuda_stream(args.opr->handle());
......@@ -120,26 +123,63 @@ void ConvolutionBackwardDataImpl::AlgoInt8NCHWDotProdImplicitGemm::exec(
}
int8_t* inner_grad_ptr = reinterpret_cast<int8_t*>(bundle.get(2));
convolution::ConvParam kern_param;
kern_param.n = n, kern_param.co = co, kern_param.ci = ci,
kern_param.hi = hi, kern_param.wi = wi, kern_param.ho = ho,
kern_param.wo = wo, kern_param.ph = ph, kern_param.pw = pw,
kern_param.sh = sh, kern_param.sw = sw, kern_param.fh = fh,
kern_param.fw = fw;
float diff_scale =
args.diff_layout->dtype.param<dtype::QuantizedS8>().scale,
filter_scale =
args.filter_layout->dtype.param<dtype::QuantizedS8>().scale,
grad_scale =
args.grad_layout->dtype.param<dtype::QuantizedS8>().scale;
float alpha = diff_scale * filter_scale / grad_scale;
// \note these constants of cutlass epilogue will be passed to struct
// `ConvolutionArguments` by pointer and interpreted as ElementCompute*, a
// different dtype here results in undefined epilogue behaviors
float alpha = diff_scale * filter_scale / grad_scale, beta = 0.f,
gamma = 0.f, delta = 0.f;
using namespace cutlass::library;
// only use 16x64x8_16x64x8_2stages impl
cutlass_wrapper::do_deconv_int8_implicit_gemm_dp4a_ncdiv4hw4(
inner_diff_ptr, inner_filter_ptr, inner_grad_ptr, nullptr,
kern_param, alpha, cutlass_wrapper::GemmCoord{16, 64, 8},
cutlass_wrapper::GemmCoord{16, 64, 8}, 2, stream);
ConvolutionKey key{
cutlass::conv::Operator::kDgrad,
NumericTypeID::kS8,
LayoutTypeID::kTensorNC4HW4,
NumericTypeID::kS8,
LayoutTypeID::kTensorK4RSC4,
NumericTypeID::kS8,
LayoutTypeID::kTensorNC4HW4,
NumericTypeID::kS32,
LayoutTypeID::kTensorNC4HW4,
cutlass::conv::ConvType::kConvolution,
16,
64,
8,
16,
64,
8,
1,
1,
4,
cutlass::epilogue::EpilogueType::kBiasAddLinearCombinationClamp,
2,
true,
false};
const Operation* op = Singleton::get().operation_table.find_op(key);
// gcc prints warnings when size_t values are implicitly narrowed to int
cutlass::conv::Conv2dProblemSize problem_size{
int(n), int(hi), int(wi), int(ci),
int(co), int(fh), int(fw), int(ho),
int(wo), int(ph), int(pw), int(sh),
int(sw), int(dh), int(dw), cutlass::conv::Mode::kCrossCorrelation};
cutlass::library::ConvolutionArguments conv_args{
problem_size, inner_diff_ptr, inner_filter_ptr, nullptr,
nullptr, inner_grad_ptr, &alpha, &beta,
&gamma, &delta, nullptr, nullptr,
nullptr, nullptr};
cutlass_check(op->run(&conv_args, nullptr, stream));
after_kernel_launch();
......
/***************************************************************************************************
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
*modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
*notice, this list of conditions and the following disclaimer in the
*documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its
*contributors may be used to endorse or promote products derived from this
*software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
*AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
*IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
*DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT,
*INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
*DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
 *OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
*NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
*EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/**
* \file dnn/src/cuda/cutlass/arch_mappings.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "cutlass/arch/arch.h"
#include "cutlass/arch/mma.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace library {
/////////////////////////////////////////////////////////////////////////////////////////////////
template <typename ArchTag, typename OperatorClass>
struct ArchMap;
template <>
struct ArchMap<arch::Sm50, arch::OpClassSimt> {
static int const kMin = 50;
static int const kMax = 1024;
};
template <>
struct ArchMap<arch::Sm60, arch::OpClassSimt> {
static int const kMin = 60;
static int const kMax = 1024;
};
template <>
struct ArchMap<arch::Sm61, arch::OpClassSimt> {
static int const kMin = 61;
static int const kMax = 1024;
};
template <>
struct ArchMap<arch::Sm70, arch::OpClassWmmaTensorOp> {
static int const kMin = 70;
static int const kMax = 1024;
};
template <>
struct ArchMap<arch::Sm70, arch::OpClassTensorOp> {
static int const kMin = 70;
static int const kMax = 75;
};
template <typename OperatorClass>
struct ArchMap<arch::Sm75, OperatorClass> {
static int const kMin = 75;
static int const kMax = 1024;
};
template <typename OperatorClass>
struct ArchMap<arch::Sm80, OperatorClass> {
static int const kMin = 80;
static int const kMax = 1024;
};
template <typename OperatorClass>
struct ArchMap<arch::Sm86, OperatorClass> {
static int const kMin = 86;
static int const kMax = 1024;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace library
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
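Each ArchMap specialization above pins, per (architecture tag, opcode class) pair, the compute-capability window in which a generated kernel is valid. A toy sketch of how such a table is queried, assuming for illustration an inclusive [kMin, kMax] test; the helper below is hypothetical, not the library's API:

#include <cstdio>

// Toy mirror of the ArchMap idea with stand-in tags.
struct Sm61 {};
struct OpClassSimt {};

template <typename ArchTag, typename OperatorClass>
struct ToyArchMap;

template <>
struct ToyArchMap<Sm61, OpClassSimt> {
    static int const kMin = 61;
    static int const kMax = 1024;
};

template <typename ArchTag, typename OperatorClass>
bool runnable_on(int device_cc) {
    return device_cc >= ToyArchMap<ArchTag, OperatorClass>::kMin &&
           device_cc <= ToyArchMap<ArchTag, OperatorClass>::kMax;
}

int main() {
    std::printf("%d\n", runnable_on<Sm61, OpClassSimt>(75));  // 1: new enough
    std::printf("%d\n", runnable_on<Sm61, OpClassSimt>(52));  // 0: below kMin
    return 0;
}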
/***************************************************************************************************
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
*modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
*notice, this list of conditions and the following disclaimer in the
*documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its
*contributors may be used to endorse or promote products derived from this
*software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
*AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
*IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
*DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT,
*INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
*DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
*OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TOR (INCLUDING
*NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
*EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/**
* \file dnn/src/cuda/cutlass/convolution_operation.h
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "cutlass/convolution/device/convolution.h"
#include "src/cuda/cutlass/library_internal.h"
///////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace library {
///////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Operator_>
class ConvolutionOperationBase : public Operation {
public:
using Operator = Operator_;
using ElementSrc = typename Operator::ElementSrc;
using LayoutSrc = typename Operator::LayoutSrc;
using ElementFilter = typename Operator::ElementFilter;
using LayoutFilter = typename Operator::LayoutFilter;
using ElementDst = typename Operator::ElementDst;
using LayoutDst = typename Operator::LayoutDst;
using ElementBias = typename Operator::ElementBias;
using LayoutBias = typename Operator::LayoutBias;
using ElementAccumulator = typename Operator::ElementAccumulator;
ConvolutionOperationBase(char const* name = "unknown_convolution") {
m_description.name = name;
m_description.provider = Provider::kCUTLASS;
m_description.kind = OperationKind::kConvolution;
m_description.conv_op = Operator::kConvolutionalOperator;
m_description.tile_description.threadblock_shape = make_Coord(
Operator::ThreadblockShape::kM, Operator::ThreadblockShape::kN,
Operator::ThreadblockShape::kK);
m_description.tile_description.threadblock_stages = Operator::kStages;
m_description.tile_description.warp_count =
make_Coord(Operator::ConvolutionKernel::WarpCount::kM,
Operator::ConvolutionKernel::WarpCount::kN,
Operator::ConvolutionKernel::WarpCount::kK);
m_description.tile_description.math_instruction.instruction_shape =
make_Coord(Operator::InstructionShape::kM,
Operator::InstructionShape::kN,
Operator::InstructionShape::kK);
m_description.tile_description.math_instruction.element_accumulator =
NumericTypeMap<ElementAccumulator>::kId;
m_description.tile_description.math_instruction.opcode_class =
OpcodeClassMap<typename Operator::OperatorClass>::kId;
m_description.tile_description.math_instruction.math_operation =
MathOperationMap<typename Operator::Operator>::kId;
m_description.tile_description.minimum_compute_capability =
ArchMap<typename Operator::ArchTag,
typename Operator::OperatorClass>::kMin;
m_description.tile_description.maximum_compute_capability =
ArchMap<typename Operator::ArchTag,
typename Operator::OperatorClass>::kMax;
m_description.src = make_TensorDescription<ElementSrc, LayoutSrc>(
Operator::kAlignmentSrc);
m_description.filter =
make_TensorDescription<ElementFilter, LayoutFilter>(
Operator::kAlignmentFilter);
m_description.dst = make_TensorDescription<ElementDst, LayoutDst>(
Operator::kAlignmentDst);
m_description.bias = make_TensorDescription<ElementBias, LayoutBias>(
Operator::kAlignmentDst);
m_description.convolution_type = Operator::kConvolutionType;
m_description.arch_tag = ArchTagMap<typename Operator::ArchTag>::kId;
m_description.epilogue_type = Operator::EpilogueOutputOp::kType;
m_description.epilogue_count = Operator::EpilogueOutputOp::kCount;
m_description.threadblock_swizzle = ThreadblockSwizzleMap<
typename Operator::ThreadblockSwizzle>::kId;
m_description.need_load_from_const_mem =
Operator::kNeedLoadFromConstMem;
m_description.gemm_mode = Operator::kGemmMode;
m_description.without_shared_load = Operator::kWithoutSharedLoad;
}
virtual OperationDescription const& description() const {
return m_description;
}
protected:
ConvolutionDescription m_description;
};
///////////////////////////////////////////////////////////////////////////////////////////////////
namespace detail {
template <typename EpilogueOp, epilogue::EpilogueType type>
struct init_epilogue_param_;
template <typename EpilogueOp>
struct init_epilogue_param_<EpilogueOp,
epilogue::EpilogueType::kBiasAddLinearCombination> {
using ElementCompute = typename EpilogueOp::ElementCompute;
typename EpilogueOp::Params get(ConvolutionArguments const* conv_args) {
return {*static_cast<ElementCompute const*>(conv_args->alpha),
*static_cast<ElementCompute const*>(conv_args->beta),
*static_cast<ElementCompute const*>(conv_args->gamma),
*static_cast<ElementCompute const*>(conv_args->delta)};
}
};
template <typename EpilogueOp>
struct init_epilogue_param_<
EpilogueOp, epilogue::EpilogueType::kBiasAddLinearCombinationClamp> {
using ElementCompute = typename EpilogueOp::ElementCompute;
typename EpilogueOp::Params get(ConvolutionArguments const* conv_args) {
return {*static_cast<ElementCompute const*>(conv_args->alpha),
*static_cast<ElementCompute const*>(conv_args->beta),
*static_cast<ElementCompute const*>(conv_args->gamma),
*static_cast<ElementCompute const*>(conv_args->delta)};
}
};
template <typename EpilogueOp>
struct init_epilogue_param_<
EpilogueOp, epilogue::EpilogueType::kBiasAddLinearCombinationRelu> {
using ElementCompute = typename EpilogueOp::ElementCompute;
typename EpilogueOp::Params get(ConvolutionArguments const* conv_args) {
return {*static_cast<ElementCompute const*>(conv_args->alpha),
*static_cast<ElementCompute const*>(conv_args->beta),
*static_cast<ElementCompute const*>(conv_args->gamma),
*static_cast<ElementCompute const*>(conv_args->threshold),
*static_cast<ElementCompute const*>(conv_args->delta),
*static_cast<ElementCompute const*>(conv_args->theta)};
}
};
template <typename EpilogueOp>
struct init_epilogue_param_<
EpilogueOp,
epilogue::EpilogueType::kBiasAddLinearCombinationReluClamp> {
using ElementCompute = typename EpilogueOp::ElementCompute;
typename EpilogueOp::Params get(ConvolutionArguments const* conv_args) {
return {*static_cast<ElementCompute const*>(conv_args->alpha),
*static_cast<ElementCompute const*>(conv_args->beta),
*static_cast<ElementCompute const*>(conv_args->gamma),
*static_cast<ElementCompute const*>(conv_args->threshold),
*static_cast<ElementCompute const*>(conv_args->delta),
*static_cast<ElementCompute const*>(conv_args->theta)};
}
};
template <typename EpilogueOp>
struct init_epilogue_param_<
EpilogueOp, epilogue::EpilogueType::kBiasAddLinearCombinationHSwish> {
using ElementCompute = typename EpilogueOp::ElementCompute;
typename EpilogueOp::Params get(ConvolutionArguments const* conv_args) {
return {*static_cast<ElementCompute const*>(conv_args->alpha),
*static_cast<ElementCompute const*>(conv_args->beta),
*static_cast<ElementCompute const*>(conv_args->gamma),
*static_cast<ElementCompute const*>(conv_args->scale),
*static_cast<ElementCompute const*>(conv_args->delta),
*static_cast<ElementCompute const*>(conv_args->theta)};
}
};
template <typename EpilogueOp>
struct init_epilogue_param_<
EpilogueOp,
epilogue::EpilogueType::kBiasAddLinearCombinationHSwishClamp> {
using ElementCompute = typename EpilogueOp::ElementCompute;
typename EpilogueOp::Params get(ConvolutionArguments const* conv_args) {
return {*static_cast<ElementCompute const*>(conv_args->alpha),
*static_cast<ElementCompute const*>(conv_args->beta),
*static_cast<ElementCompute const*>(conv_args->gamma),
*static_cast<ElementCompute const*>(conv_args->scale),
*static_cast<ElementCompute const*>(conv_args->delta),
*static_cast<ElementCompute const*>(conv_args->theta)};
}
};
} // namespace detail
template <typename EpilogueOp>
struct init_epilogue_param
: public detail::init_epilogue_param_<EpilogueOp, EpilogueOp::kType> {};
///////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Operator_>
class ConvolutionOperation : public ConvolutionOperationBase<Operator_> {
public:
using Operator = Operator_;
using ElementSrc = typename Operator::ElementSrc;
using LayoutSrc = typename Operator::LayoutSrc;
using ElementFilter = typename Operator::ElementFilter;
using LayoutFilter = typename Operator::LayoutFilter;
using ElementBias = typename Operator::ElementBias;
using LayoutBias = typename Operator::LayoutBias;
using ElementDst = typename Operator::ElementDst;
using LayoutDst = typename Operator::LayoutDst;
using ElementAccumulator = typename Operator::ElementAccumulator;
using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute;
using OperatorArguments = typename Operator::Arguments;
ConvolutionOperation(char const* name = "unknown_gemm")
: ConvolutionOperationBase<Operator_>(name) {}
virtual Status run(void const* arguments_ptr,
void* device_workspace = nullptr,
cudaStream_t stream = nullptr) const {
cutlass::conv::Operator conv_op = this->m_description.conv_op;
ConvolutionArguments const* conv_args =
reinterpret_cast<ConvolutionArguments const*>(arguments_ptr);
const auto& ps = conv_args->problem_size;
OperatorArguments args;
args.problem_size = ps;
args.ref_src = {
static_cast<ElementSrc*>(const_cast<void*>(conv_args->src)),
LayoutSrc::packed(implicit_gemm_tensor_a_extent(conv_op, ps))};
args.ref_filter = {static_cast<ElementFilter*>(
const_cast<void*>(conv_args->filter)),
LayoutFilter::packed(
implicit_gemm_tensor_b_extent(conv_op, ps))};
args.ref_bias = {
static_cast<ElementBias*>(const_cast<void*>(conv_args->bias)),
LayoutBias::packed(
implicit_gemm_tensor_bias_extent(conv_op, ps))};
args.ref_z = {
static_cast<ElementDst*>(const_cast<void*>(conv_args->z)),
LayoutDst::packed(implicit_gemm_tensor_c_extent(conv_op, ps))};
args.ref_dst = {
static_cast<ElementDst*>(conv_args->dst),
LayoutDst::packed(implicit_gemm_tensor_c_extent(conv_op, ps))};
args.output_op =
init_epilogue_param<typename Operator::EpilogueOutputOp>().get(
conv_args);
if (conv_args->extra_param) {
args.extra_param =
*reinterpret_cast<typename Operator::ExtraParam const*>(
conv_args->extra_param);
}
Operator op;
Status status = op.initialize(args, device_workspace);
if (status != Status::kSuccess) {
return status;
}
return op.run(stream);
}
};
///////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace library
} // namespace cutlass
///////////////////////////////////////////////////////////////////////////////////////////////////
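init_epilogue_param above picks the Params-packing routine by partially specializing on EpilogueOp::kType. A standalone toy of the same enum-keyed dispatch pattern; the names are illustrative, not cutlass's:

#include <cstdio>

enum class Kind { kPlain, kClamped };

struct PlainPolicy { static constexpr Kind kKind = Kind::kPlain; };
struct ClampedPolicy { static constexpr Kind kKind = Kind::kClamped; };

// primary template, keyed on (Policy, enum value)
template <typename Policy, Kind kind>
struct pack_params_;

template <typename Policy>
struct pack_params_<Policy, Kind::kPlain> {
    int get() const { return 1; }  // would build the "plain" Params here
};

template <typename Policy>
struct pack_params_<Policy, Kind::kClamped> {
    int get() const { return 2; }  // would build the "clamped" Params here
};

// public entry point: forwards to the specialization selected by Policy::kKind
template <typename Policy>
struct pack_params : pack_params_<Policy, Policy::kKind> {};

int main() {
    std::printf("%d %d\n", pack_params<PlainPolicy>().get(),
                pack_params<ClampedPolicy>().get());  // prints "1 2"
    return 0;
}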
/***************************************************************************************************
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
*modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
*notice, this list of conditions and the following disclaimer in the
*documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its
*contributors may be used to endorse or promote products derived from this
*software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
*AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
*IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
*DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT,
*INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
*DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
*OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TOR (INCLUDING
*NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
*EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/**
* \file dnn/src/cuda/cutlass/gemm_operation.h
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "cutlass/gemm/device/gemm.h"
#include "src/cuda/cutlass/library_internal.h"
///////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace library {
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Check whether Operator has member ReductionKernel using SFINAE (Substitution
/// Failure Is Not An Error)
template <typename Operator>
struct split_k_mode {
template <typename T>
static char check(typename T::ReductionKernel*);
template <typename T>
static int check(...);
SplitKMode operator()() {
if (sizeof(check<Operator>(0)) == sizeof(char)) {
// cutlass::gemm::device::GemmSplitKParallel
return SplitKMode::kParallel;
} else {
// cutlass::gemm::device::Gemm
return SplitKMode::kNone;
}
}
};
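A standalone reduction of the SFINAE check used by split_k_mode above: overload resolution prefers the char-returning check when T::ReductionKernel exists and silently falls back otherwise. Toy types, not the cutlass ones:

#include <cstdio>

struct WithReduction { struct ReductionKernel {}; };
struct WithoutReduction {};

template <typename Operator>
struct has_reduction_kernel {
    template <typename T>
    static char check(typename T::ReductionKernel*);  // viable only if the member type exists
    template <typename T>
    static int check(...);                            // fallback for everything else
    static bool const value = sizeof(check<Operator>(0)) == sizeof(char);
};

int main() {
    std::printf("%d %d\n", has_reduction_kernel<WithReduction>::value,
                has_reduction_kernel<WithoutReduction>::value);  // prints "1 0"
    return 0;
}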
///////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Operator_>
class GemmOperationBase : public Operation {
public:
using Operator = Operator_;
using ElementA = typename Operator::ElementA;
using LayoutA = typename Operator::LayoutA;
using ElementB = typename Operator::ElementB;
using LayoutB = typename Operator::LayoutB;
using ElementC = typename Operator::ElementC;
using LayoutC = typename Operator::LayoutC;
using ElementAccumulator = typename Operator::ElementAccumulator;
GemmOperationBase(char const* name = "unknown_gemm") {
m_description.name = name;
m_description.provider = Provider::kCUTLASS;
m_description.kind = OperationKind::kGemm;
m_description.gemm_kind = GemmKind::kGemm;
m_description.tile_description.threadblock_shape = make_Coord(
Operator::ThreadblockShape::kM, Operator::ThreadblockShape::kN,
Operator::ThreadblockShape::kK);
m_description.tile_description.threadblock_stages = Operator::kStages;
m_description.tile_description.warp_count =
make_Coord(Operator::GemmKernel::WarpCount::kM,
Operator::GemmKernel::WarpCount::kN,
Operator::GemmKernel::WarpCount::kK);
m_description.tile_description.math_instruction.instruction_shape =
make_Coord(Operator::InstructionShape::kM,
Operator::InstructionShape::kN,
Operator::InstructionShape::kK);
m_description.tile_description.math_instruction.element_accumulator =
NumericTypeMap<ElementAccumulator>::kId;
m_description.tile_description.math_instruction.opcode_class =
OpcodeClassMap<typename Operator::OperatorClass>::kId;
m_description.tile_description.math_instruction.math_operation =
MathOperationMap<typename Operator::Operator>::kId;
m_description.tile_description.minimum_compute_capability =
ArchMap<typename Operator::ArchTag,
typename Operator::OperatorClass>::kMin;
m_description.tile_description.maximum_compute_capability =
ArchMap<typename Operator::ArchTag,
typename Operator::OperatorClass>::kMax;
m_description.A = make_TensorDescription<ElementA, LayoutA>(
Operator::kAlignmentA);
m_description.B = make_TensorDescription<ElementB, LayoutB>(
Operator::kAlignmentB);
m_description.C = make_TensorDescription<ElementC, LayoutC>(
Operator::kAlignmentC);
m_description.stages = Operator::kStages;
split_k_mode<Operator> mode;
m_description.split_k_mode = mode();
}
virtual OperationDescription const& description() const {
return m_description;
}
protected:
GemmDescription m_description;
};
///////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Operator_>
class GemmOperation : public GemmOperationBase<Operator_> {
public:
using Operator = Operator_;
using ElementA = typename Operator::ElementA;
using LayoutA = typename Operator::LayoutA;
using ElementB = typename Operator::ElementB;
using LayoutB = typename Operator::LayoutB;
using ElementC = typename Operator::ElementC;
using LayoutC = typename Operator::LayoutC;
using ElementAccumulator = typename Operator::ElementAccumulator;
using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute;
using OperatorArguments = typename Operator::Arguments;
GemmOperation(char const* name = "unknown_gemm")
: GemmOperationBase<Operator_>(name) {}
virtual Status run(void const* arguments_ptr,
void* device_workspace = nullptr,
cudaStream_t stream = nullptr) const {
GemmArguments const* gemm_args =
reinterpret_cast<GemmArguments const*>(arguments_ptr);
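        // Repack the type-erased GemmArguments into the concrete
        // Operator::Arguments expected by the device-level CUTLASS kernel.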
OperatorArguments args;
args.problem_size = gemm_args->problem_size;
args.ref_A = {static_cast<ElementA const*>(gemm_args->A),
int(gemm_args->lda)};
args.ref_B = {static_cast<ElementB const*>(gemm_args->B),
int(gemm_args->ldb)};
args.ref_C = {static_cast<ElementC const*>(gemm_args->C),
int(gemm_args->ldc)};
args.ref_D = {static_cast<ElementC*>(gemm_args->D),
int(gemm_args->ldd)};
args.split_k_slices = gemm_args->split_k_slices;
args.epilogue = {*static_cast<ElementCompute const*>(gemm_args->alpha),
*static_cast<ElementCompute const*>(gemm_args->beta)};
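        // initialize() sets up kernel parameters with the caller-provided
        // workspace; run() then launches the kernel on the given stream.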
Operator op;
Status status = op.initialize(args, device_workspace);
if (status != Status::kSuccess) {
return status;
}
return op.run(stream);
}
};
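// A minimal usage sketch (illustrative only, not part of this header): the
// type-erased interface is driven with GemmArguments; `table`, `key` and the
// concrete pointers/strides below are assumptions for the example.
//
//   GemmArguments gemm_args;
//   gemm_args.problem_size = {M, N, K};
//   gemm_args.A = d_A;  gemm_args.lda = lda;
//   gemm_args.B = d_B;  gemm_args.ldb = ldb;
//   gemm_args.C = d_C;  gemm_args.ldc = ldc;
//   gemm_args.D = d_D;  gemm_args.ldd = ldd;
//   gemm_args.split_k_slices = 1;
//   gemm_args.alpha = &alpha;  gemm_args.beta = &beta;
//   Operation const* op = table.find_op(key);
//   op->run(&gemm_args, workspace, stream);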
///////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace library
} // namespace cutlass
///////////////////////////////////////////////////////////////////////////////////////////////////
/***************************************************************************************************
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
*modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
*notice, this list of conditions and the following disclaimer in the
*documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its
*contributors may be used to endorse or promote products derived from this
*software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
*AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
*IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
*DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT,
*INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
*DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
 *OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
*NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
*EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/**
* \file dnn/src/cuda/cutlass/initialize_all.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/cuda/cutlass/manifest.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace library {
/////////////////////////////////////////////////////////////////////////////////////////////////
#if __CUDACC_VER_MAJOR__ > 9 || \
(__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
void initialize_all_gemm_simt_operations(Manifest& manifest);
void initialize_all_conv2d_simt_operations(Manifest& manifest);
void initialize_all_conv2d_tensorop8816_operations(Manifest& manifest);
void initialize_all_conv2d_tensorop8832_operations(Manifest& manifest);
void initialize_all_deconv_simt_operations(Manifest& manifest);
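// Each initialize_all_* function is procedurally generated (see generator.py)
// and registers one family of kernels with the manifest.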
void initialize_all(Manifest& manifest) {
initialize_all_gemm_simt_operations(manifest);
initialize_all_conv2d_simt_operations(manifest);
initialize_all_conv2d_tensorop8816_operations(manifest);
initialize_all_conv2d_tensorop8832_operations(manifest);
initialize_all_deconv_simt_operations(manifest);
}
#else
void initialize_all(Manifest& manifest) {}
#endif
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace library
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
/***************************************************************************************************
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
*modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
*notice, this list of conditions and the following disclaimer in the
*documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its
*contributors may be used to endorse or promote products derived from this
*software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
*AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
*IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
*DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT,
*INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
*DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
 *OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
*NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
*EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/**
* \file dnn/src/cuda/cutlass/manifest.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include <memory>
#include "src/cuda/cutlass/manifest.h"
namespace cutlass {
namespace library {
//////////////////////////////////////////////////////////////////////////////////////////////////////////
/// Top-level initialization
Status Manifest::initialize() {
if (!operations_.empty()) {
operations_.clear();
}
// initialize procedurally generated cutlass op in manifest object
initialize_all(*this);
return Status::kSuccess;
}
/// Used for initialization
void Manifest::reserve(size_t operation_count) {
operations_.reserve(operation_count);
}
/// Graceful shutdown
Status Manifest::release() {
operations_.clear();
return Status::kSuccess;
}
/// Appends an operation and takes ownership
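/// (the raw pointer is adopted into a std::unique_ptr held by operations_)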
void Manifest::append(Operation* operation_ptr) {
operations_.emplace_back(operation_ptr);
}
/// Returns an iterator to the first operation
OperationVector const& Manifest::operations() const {
return operations_;
}
/// Returns a const iterator
OperationVector::const_iterator Manifest::begin() const {
return operations_.begin();
}
/// Returns a const iterator
OperationVector::const_iterator Manifest::end() const {
return operations_.end();
}
///////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace library
} // namespace cutlass
///////////////////////////////////////////////////////////////////////////////////////////////////
/***************************************************************************************************
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
*modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
*notice, this list of conditions and the following disclaimer in the
*documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its
*contributors may be used to endorse or promote products derived from this
*software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
*AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
*IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
*DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT,
*INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
*DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
 *OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
*NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
*EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/**
* \file dnn/src/cuda/cutlass/manifest.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include <list>
#include <map>
#include <memory>
#include <vector>
#include "src/cuda/cutlass/library.h"
///////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace library {
///////////////////////////////////////////////////////////////////////////////////////////////////
// Forward declaration
class Manifest;
// init and insert all cutlass gemm operations in manifest object (procedurally
// generated using generator.py)
void initialize_all(Manifest& manifest);
/////////////////////////////////////////////////////////////////////////////////////////////////////////
/// List of operations
using OperationVector = std::vector<std::unique_ptr<Operation>>;
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Manifest of CUTLASS Library
class Manifest {
private:
/// Operation provider
Provider provider_;
/// Global list of operations
OperationVector operations_;
public:
Manifest(Provider provider = library::Provider::kCUTLASS)
: provider_(provider) {}
/// Top-level initialization
Status initialize();
/// Used for initialization
void reserve(size_t operation_count);
/// Graceful shutdown
Status release();
/// Appends an operation and takes ownership
void append(Operation* operation_ptr);
/// Returns an iterator to the first operation
OperationVector const& operations() const;
/// Returns a const iterator
OperationVector::const_iterator begin() const;
/// Returns a const iterator
OperationVector::const_iterator end() const;
};
///////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace library
} // namespace cutlass
///////////////////////////////////////////////////////////////////////////////////////////////////
/***************************************************************************************************
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
*modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
*notice, this list of conditions and the following disclaimer in the
*documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its
*contributors may be used to endorse or promote products derived from this
*software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
*AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
*IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
*DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT,
*INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
*DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
 *OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
*NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
*EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/**
* \file dnn/src/cuda/cutlass/operation_table.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/common/utils.h"
#include "src/cuda/cutlass/operation_table.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace library {
/////////////////////////////////////////////////////////////////////////////////////////////////
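// Flatten the reflective GemmDescription into the POD GemmKey used for table
// lookup; the warp shape is recovered as threadblock_shape / warp_count.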
GemmKey get_gemm_key_from_desc(const GemmDescription& desc) {
GemmKey key;
key.element_A = desc.A.element;
key.layout_A = desc.A.layout;
key.element_B = desc.B.element;
key.layout_B = desc.B.layout;
key.element_C = desc.C.element;
key.layout_C = desc.C.layout;
key.threadblock_shape_m = desc.tile_description.threadblock_shape.m();
key.threadblock_shape_n = desc.tile_description.threadblock_shape.n();
key.threadblock_shape_k = desc.tile_description.threadblock_shape.k();
key.warp_shape_m = desc.tile_description.threadblock_shape.m() /
desc.tile_description.warp_count.m();
key.warp_shape_n = desc.tile_description.threadblock_shape.n() /
desc.tile_description.warp_count.n();
key.warp_shape_k = desc.tile_description.threadblock_shape.k() /
desc.tile_description.warp_count.k();
key.instruction_shape_m =
desc.tile_description.math_instruction.instruction_shape.m();
key.instruction_shape_n =
desc.tile_description.math_instruction.instruction_shape.n();
key.instruction_shape_k =
desc.tile_description.math_instruction.instruction_shape.k();
key.stages = desc.stages;
key.split_k_mode = desc.split_k_mode;
return key;
}
/////////////////////////////////////////////////////////////////////////////////////////////////
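// Same flattening for convolutions: the ConvolutionKey additionally captures
// the epilogue type and the implementation switches
// (need_load_from_const_mem / without_shared_load).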
ConvolutionKey get_convolution_key_from_desc(
const ConvolutionDescription& desc) {
ConvolutionKey key;
key.conv_op = desc.conv_op;
key.element_src = desc.src.element;
key.layout_src = desc.src.layout;
key.element_filter = desc.filter.element;
key.layout_filter = desc.filter.layout;
key.element_dst = desc.dst.element;
key.layout_dst = desc.dst.layout;
key.element_bias = desc.bias.element;
key.layout_bias = desc.bias.layout;
key.convolution_type = desc.convolution_type;
key.threadblock_shape_m = desc.tile_description.threadblock_shape.m();
key.threadblock_shape_n = desc.tile_description.threadblock_shape.n();
key.threadblock_shape_k = desc.tile_description.threadblock_shape.k();
key.warp_shape_m = desc.tile_description.threadblock_shape.m() /
desc.tile_description.warp_count.m();
key.warp_shape_n = desc.tile_description.threadblock_shape.n() /
desc.tile_description.warp_count.n();
key.warp_shape_k = desc.tile_description.threadblock_shape.k() /
desc.tile_description.warp_count.k();
key.instruction_shape_m =
desc.tile_description.math_instruction.instruction_shape.m();
key.instruction_shape_n =
desc.tile_description.math_instruction.instruction_shape.n();
key.instruction_shape_k =
desc.tile_description.math_instruction.instruction_shape.k();
key.epilogue_type = desc.epilogue_type;
key.stages = desc.tile_description.threadblock_stages;
key.need_load_from_const_mem = desc.need_load_from_const_mem;
key.without_shared_load = desc.without_shared_load;
return key;
}
/////////////////////////////////////////////////////////////////////////////////////////////////
void OperationTable::append(Manifest const& manifest) {
// Insert operations into appropriate data structure
for (auto const& operation : manifest) {
OperationDescription const& desc = operation->description();
// insert all gemm operations into operation table
if (desc.kind == OperationKind::kGemm) {
GemmKey key = get_gemm_key_from_desc(
static_cast<GemmDescription const&>(desc));
gemm_operations[key].push_back(operation.get());
}
// insert all conv operations into operation table
if (desc.kind == OperationKind::kConvolution) {
ConvolutionKey key = get_convolution_key_from_desc(
static_cast<ConvolutionDescription const&>(desc));
convolution_operations[key].push_back(operation.get());
}
}
}
/////////////////////////////////////////////////////////////////////////////////////////////////
Operation const* OperationTable::find_op(GemmKey const& key) const {
megdnn_assert(gemm_operations.count(key) > 0,
"key not found in cutlass operation table");
auto const& ops = gemm_operations.at(key);
megdnn_assert(ops.size() == 1, "exactly one kernel expected, got %zu",
ops.size());
return ops[0];
}
/////////////////////////////////////////////////////////////////////////////////////////////////
Operation const* OperationTable::find_op(ConvolutionKey const& key) const {
megdnn_assert(convolution_operations.count(key) > 0,
"key not found in cutlass operation table");
auto const& ops = convolution_operations.at(key);
megdnn_assert(ops.size() == 1, "exactly one kernel expected, got %zu",
ops.size());
return ops[0];
}
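// A minimal build-and-query sketch (illustrative only; `key` and the argument
// setup are assumptions, not part of this file):
//
//   Manifest manifest;
//   manifest.initialize();      // registers all generated operations
//   OperationTable table;
//   table.append(manifest);     // bucket them by GemmKey / ConvolutionKey
//   Operation const* op = table.find_op(key);
//   op->run(&args, workspace, stream);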
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace library
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
......@@ -21,22 +21,6 @@ namespace cutlass_wrapper {
using GemmCoord = cutlass::gemm::GemmCoord;
using BatchedGemmCoord = cutlass::gemm::BatchedGemmCoord;
template <typename Gemm>
void cutlass_matrix_mul_wrapper(
const typename Gemm::ElementA* d_A, size_t lda,
const typename Gemm::ElementB* d_B, size_t ldb,
typename Gemm::ElementC* d_C, size_t ldc, int* workspace,
GemmCoord const& problem_size,
typename Gemm::EpilogueOutputOp::Params const& epilogue,
cudaStream_t stream, int split_k_slices = 1);
void cutlass_matrix_mul_float32_simt(
const float* d_A, bool transpose_A, size_t lda, const float* d_B,
bool transpose_B, size_t ldb, float* d_C, size_t ldc, int* workspace,
GemmCoord const& problem_size, float alpha, float beta,
const GemmCoord& threadblock_shape, const GemmCoord& warp_shape,
cudaStream_t stream, int split_k_slices = 1);
template <typename GemvKernel>
void cutlass_vector_matrix_mul_batched_strided_wrapper(
BatchedGemmCoord const& problem_size,
......