Commit 888f4e46 authored by Megvii Engine Team

feat(dnn/cuda): add implicit bmm large kernel dwconv2d dgrad kernels

GitOrigin-RevId: fcb7974d62f4cc032544d507144e121bcf01c565
Parent 08d8635f
......@@ -13,6 +13,10 @@ genrule(
CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations conv2d --type simt $(@D)
CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations conv2d --type tensorop8816 $(@D)
CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations conv2d --type tensorop8832 $(@D)
CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations dwconv2d_fprop --type simt $(@D)
CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations dwconv2d_fprop --type tensorop884 $(@D)
CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations dwconv2d_dgrad --type simt $(@D)
CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations dwconv2d_dgrad --type tensorop884 $(@D)
""",
tools = ["//brain/megbrain/dnn/scripts/cutlass_generator:generator.py"],
visibility = ["//visibility:public"],
......
......@@ -545,8 +545,9 @@ def GenerateConv2d(
epilogue: EpilogueFunctor, conv_kind: ConvKind
) -> bool:
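# True for Dgrad/Wgrad paired with an epilogue that is not one of the bias-add variants.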
return (
conv_kind == ConvKind.Dgrad
(conv_kind == ConvKind.Dgrad or conv_kind == ConvKind.Wgrad)
and epilogue != EpilogueFunctor.BiasAddLinearCombinationClamp
and epilogue != EpilogueFunctor.BiasAddLinearCombination
)
# loop over all tile descriptions
......
......@@ -3,6 +3,8 @@ from generator import (
GenerateGemvOperations,
GenerateConv2dOperations,
GenerateDeconvOperations,
GenerateDwconv2dFpropOperations,
GenerateDwconv2dDgradOperations,
)
......@@ -21,6 +23,12 @@ def write_op_list(f, gen_op, gen_type):
operations = GenerateConv2dOperations(GenArg(gen_op, gen_type))
elif gen_op == "deconv":
operations = GenerateDeconvOperations(GenArg(gen_op, gen_type))
elif gen_op == "dwconv2d_fprop":
operations = GenerateDwconv2dFpropOperations(GenArg(gen_op, gen_type))
elif gen_op == "dwconv2d_dgrad":
operations = GenerateDwconv2dDgradOperations(GenArg(gen_op, gen_type))
elif gen_op == "dwconv2d_wgrad":
pass
for op in operations:
f.write(' "%s.cu",\n' % op.procedural_name())
if gen_op != "gemv":
......@@ -40,4 +48,8 @@ if __name__ == "__main__":
write_op_list(f, "conv2d", "simt")
write_op_list(f, "conv2d", "tensorop8816")
write_op_list(f, "conv2d", "tensorop8832")
write_op_list(f, "dwconv2d_fprop", "simt")
write_op_list(f, "dwconv2d_fprop", "tensorop884")
write_op_list(f, "dwconv2d_dgrad", "simt")
write_op_list(f, "dwconv2d_dgrad", "tensorop884")
f.write("]")
......@@ -1056,7 +1056,8 @@ def GenerateGemm_Simt(args):
return operations
def GenerateDwconv2dFprop_Simt(args):
#
def GenerateDwconv2d_Simt(args, conv_kind):
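# Shared SIMT generator for depthwise conv2d; conv_kind selects Fprop or Dgrad.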
################################################################################
# warps per threadblock
################################################################################
......@@ -1121,10 +1122,10 @@ def GenerateDwconv2dFprop_Simt(args):
tile_descriptions = [
TileDescription([128, 128, 8], 2, [4, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 64, 8], 2, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([64, 128, 8], 2, [1, 4, 1], math_inst, min_cc, max_cc),
TileDescription([64, 128, 8], 2, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 32, 8], 2, [2, 1, 1], math_inst, min_cc, max_cc),
TileDescription([32, 128, 8], 2, [1, 2, 1], math_inst, min_cc, max_cc),
TileDescription([64, 64, 8], 2, [1, 2, 1], math_inst, min_cc, max_cc),
TileDescription([64, 64, 8], 2, [2, 1, 1], math_inst, min_cc, max_cc),
TileDescription([32, 64, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc),
TileDescription([64, 32, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc),
TileDescription([32, 32, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc),
......@@ -1232,7 +1233,7 @@ def GenerateDwconv2dFprop_Simt(args):
for alignment_src in alignment_constraints:
operations += GenerateConv2d(
ConvType.DepthwiseConvolution,
ConvKind.Fprop,
conv_kind,
[tile],
layout[0],
layout[1],
......@@ -1249,7 +1250,7 @@ def GenerateDwconv2dFprop_Simt(args):
#
def GenerateDwconv2dFprop_TensorOp_884(args):
def GenerateDwconv2d_TensorOp_884(args, conv_kind):
layouts = [(LayoutType.TensorNCHW, LayoutType.TensorNCHW)]
math_instructions = [
......@@ -1296,7 +1297,7 @@ def GenerateDwconv2dFprop_TensorOp_884(args):
for alignment_src in alignment_constraints:
operations += GenerateConv2d(
ConvType.DepthwiseConvolution,
ConvKind.Fprop,
conv_kind,
tile_descriptions,
layout[0],
layout[1],
......@@ -1574,13 +1575,24 @@ def GenerateDeconvOperations(args):
def GenerateDwconv2dFpropOperations(args):
if args.type == "simt":
return GenerateDwconv2dFprop_Simt(args)
return GenerateDwconv2d_Simt(args, ConvKind.Fprop)
else:
assert args.type == "tensorop884", (
"operation dwconv2d fprop only support"
"simt, tensorop884. (got:{})".format(args.type)
)
return GenerateDwconv2dFprop_TensorOp_884(args)
return GenerateDwconv2d_TensorOp_884(args, ConvKind.Fprop)
def GenerateDwconv2dDgradOperations(args):
if args.type == "simt":
return GenerateDwconv2d_Simt(args, ConvKind.Dgrad)
else:
assert args.type == "tensorop884", (
"operation dwconv2d fprop only support"
"simt, tensorop884. (got:{})".format(args.type)
)
return GenerateDwconv2d_TensorOp_884(args, ConvKind.Dgrad)
def GenerateGemmOperations(args):
......@@ -1655,7 +1667,7 @@ if __name__ == "__main__":
elif args.operations == "dwconv2d_fprop":
operations = GenerateDwconv2dFpropOperations(args)
elif args.operations == "dwconv2d_dgrad":
pass
operations = GenerateDwconv2dDgradOperations(args)
elif args.operations == "dwconv2d_wgrad":
pass
......
......@@ -183,6 +183,8 @@ if(MGE_WITH_CUDA)
gen_cutlass_kimpl(conv2d tensorop8832 CUTLASS_SOURCES)
gen_cutlass_kimpl(dwconv2d_fprop simt CUTLASS_SOURCES)
gen_cutlass_kimpl(dwconv2d_fprop tensorop884 CUTLASS_SOURCES)
gen_cutlass_kimpl(dwconv2d_dgrad simt CUTLASS_SOURCES)
gen_cutlass_kimpl(dwconv2d_dgrad tensorop884 CUTLASS_SOURCES)
list(APPEND SOURCES ${CUTLASS_SOURCES})
list(APPEND SOURCES ${CUSOURCES})
endif()
......
......@@ -304,12 +304,13 @@ void ConvBiasForwardImpl::AlgoPack::fill_imma_algos() {
void ConvBiasForwardImpl::AlgoPack::fill_dwconv_algos() {
using AlgoParam = AlgoCutlassConvolutionBase::AlgoParam;
/// preferred algo
f32_implicit_bmm.emplace_back(AlgoParam{64, 128, 8, 32, 64, 8, 1, 1, 1, 2});
f32_implicit_bmm.emplace_back(AlgoParam{128, 128, 8, 32, 64, 8, 1, 1, 1, 2});
f32_implicit_bmm.emplace_back(AlgoParam{128, 64, 8, 64, 32, 8, 1, 1, 1, 2});
f32_implicit_bmm.emplace_back(AlgoParam{128, 32, 8, 64, 32, 8, 1, 1, 1, 2});
f32_implicit_bmm.emplace_back(AlgoParam{32, 128, 8, 32, 64, 8, 1, 1, 1, 2});
f32_implicit_bmm.emplace_back(AlgoParam{64, 128, 8, 64, 32, 8, 1, 1, 1, 2});
f32_implicit_bmm.emplace_back(AlgoParam{64, 64, 8, 64, 32, 8, 1, 1, 1, 2});
f32_implicit_bmm.emplace_back(AlgoParam{64, 64, 8, 32, 64, 8, 1, 1, 1, 2});
f32_implicit_bmm.emplace_back(AlgoParam{32, 64, 8, 32, 64, 8, 1, 1, 1, 2});
f32_implicit_bmm.emplace_back(AlgoParam{32, 32, 8, 32, 32, 8, 1, 1, 1, 2});
f32_implicit_bmm.emplace_back(AlgoParam{64, 32, 8, 64, 32, 8, 1, 1, 1, 2});
......@@ -317,10 +318,11 @@ void ConvBiasForwardImpl::AlgoPack::fill_dwconv_algos() {
all_algos.push_back(&algo);
}
#if CUDA_VERSION >= 10020
/// preferred algo
f16_implicit_bmm.emplace_back(AlgoParam{64, 128, 32, 32, 32, 32, 8, 8, 4, 2});
f16_implicit_bmm.emplace_back(AlgoParam{128, 128, 32, 32, 32, 32, 8, 8, 4, 2});
f16_implicit_bmm.emplace_back(AlgoParam{128, 256, 32, 64, 64, 32, 8, 8, 4, 2});
f16_implicit_bmm.emplace_back(AlgoParam{128, 64, 32, 32, 32, 32, 8, 8, 4, 2});
f16_implicit_bmm.emplace_back(AlgoParam{64, 128, 32, 32, 32, 32, 8, 8, 4, 2});
f16_implicit_bmm.emplace_back(AlgoParam{64, 64, 32, 32, 32, 32, 8, 8, 4, 2});
for (auto&& algo : f16_implicit_bmm) {
all_algos.push_back(&algo);
......
......@@ -272,8 +272,10 @@ std::pair<int, int> get_tensor_alignment(
alignment_src /= src.dtype.size(1);
};
/// TODO: need a better way to check whether tensor core instruction is used
if (format == Format::NCHW32 || format == Format::NCHW32_NCHW4 ||
format == Format::NCHW64 || format == Format::NCHW64) {
format == Format::NCHW64 || format == Format::NCHW64 ||
format == Format::NHWC) {
get_tensor_alignment_tensor_op();
} else if (
format == Format::NCHW4 || format == Format::NCHW4_NCHW ||
......
......@@ -23,6 +23,7 @@ bool ConvBiasForwardImpl::AlgoFloat16NCHWHMMAImplicitBatchedGemm::is_available(
#define RETURN_IF_FALSE(stmt_) \
if (!(stmt_)) \
return false;
RETURN_IF_FALSE(is_compute_capability_required(7, 0));
RETURN_IF_FALSE(
args.src_layout->is_contiguous() && args.dst_layout->is_contiguous());
using Param = param::ConvBias;
......
......@@ -41,6 +41,7 @@ ConvolutionBackwardDataImpl::AlgoPack::AlgoPack() {
all_algos.push_back(&algo);
int8_algos.push_back(&algo);
}
fill_dwconv_algos();
int8_algos.push_back(&int8_nchw_dotprod);
all_algos.push_back(&int8_nchw_dotprod);
......@@ -54,6 +55,39 @@ ConvolutionBackwardDataImpl::AlgoPack::AlgoPack() {
}
}
void ConvolutionBackwardDataImpl::AlgoPack::fill_dwconv_algos() {
{
using AlgoParam = AlgoFloat32NCHWFMAImplicitBatchedGemm::AlgoParam;
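// AlgoParam fields: {threadblock_m, threadblock_n, threadblock_k, warp_m, warp_n, warp_k, stages}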
/// preferred algo
implbmm_nchw_fma.emplace_back(AlgoParam{64, 128, 8, 32, 64, 8, 2});
implbmm_nchw_fma.emplace_back(AlgoParam{128, 128, 8, 32, 64, 8, 2});
implbmm_nchw_fma.emplace_back(AlgoParam{128, 64, 8, 64, 32, 8, 2});
implbmm_nchw_fma.emplace_back(AlgoParam{128, 32, 8, 64, 32, 8, 2});
implbmm_nchw_fma.emplace_back(AlgoParam{32, 128, 8, 32, 64, 8, 2});
implbmm_nchw_fma.emplace_back(AlgoParam{64, 64, 8, 32, 64, 8, 2});
implbmm_nchw_fma.emplace_back(AlgoParam{32, 64, 8, 32, 64, 8, 2});
implbmm_nchw_fma.emplace_back(AlgoParam{32, 32, 8, 32, 32, 8, 2});
implbmm_nchw_fma.emplace_back(AlgoParam{64, 32, 8, 64, 32, 8, 2});
for (auto&& algo : implbmm_nchw_fma) {
all_algos.push_back(&algo);
}
}
#if CUDA_VERSION >= 10020
{
using AlgoParam = AlgoFloat16NCHWHMMAImplicitBatchedGemm::AlgoParam;
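// AlgoParam fields: {threadblock_m/n/k, warp_m/n/k, instruction_m/n/k, stages}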
/// preferred algo
implbmm_nchw_hmma.emplace_back(AlgoParam{64, 128, 32, 32, 32, 32, 8, 8, 4, 2});
implbmm_nchw_hmma.emplace_back(AlgoParam{128, 128, 32, 32, 32, 32, 8, 8, 4, 2});
implbmm_nchw_hmma.emplace_back(AlgoParam{128, 256, 32, 64, 64, 32, 8, 8, 4, 2});
implbmm_nchw_hmma.emplace_back(AlgoParam{128, 64, 32, 32, 32, 32, 8, 8, 4, 2});
implbmm_nchw_hmma.emplace_back(AlgoParam{64, 64, 32, 32, 32, 32, 8, 8, 4, 2});
for (auto&& algo : implbmm_nchw_hmma) {
all_algos.push_back(&algo);
}
}
#endif
}
MEGDNN_DEF_GET_ALGO_FROM_DESC(ConvolutionBackwardDataImpl)
ConvolutionBackwardDataImpl::AlgoCUDNN* ConvolutionBackwardDataImpl::AlgoPack::
......
......@@ -41,7 +41,9 @@ public:
CUDA_GROUP_CONV_GENERAL,
CUDA_IMPLICIT_GEMM_NCHW4_DOTPROD_INT8,
CUDA_IMPLICIT_GEMM_NCHW_DOTPROD_INT8,
CUDA_IMPLICIT_GEMM_NHWC_IMMA_INT8
CUDA_IMPLICIT_GEMM_NHWC_IMMA_INT8,
CUDA_IMPLICIT_BATCHED_GEMM_FMA_NCHW_F32,
CUDA_IMPLICIT_BATCHED_GEMM_HMMA_NCHW_F16,
};
using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>;
......@@ -315,6 +317,82 @@ private:
std::string m_name;
};
class ConvolutionBackwardDataImpl::AlgoFloat32NCHWFMAImplicitBatchedGemm final
: public AlgoBase {
public:
struct AlgoParam {
int threadblock_m;
int threadblock_n;
int threadblock_k;
int warp_m;
int warp_n;
int warp_k;
int stage;
std::string to_string() {
return ssprintf(
"_%dX%dX%d_%dX%dX%d_%dstage", threadblock_m, threadblock_n,
threadblock_k, warp_m, warp_n, warp_k, stage);
}
};
AlgoFloat32NCHWFMAImplicitBatchedGemm(AlgoParam algo_param)
: m_algo_param{algo_param},
m_name{ssprintf(
"FLOAT32_NCHW_FMA_IMPLICIT_BATCHED_GEMM%s",
m_algo_param.to_string().c_str())} {}
bool is_available(const SizeArgs& args) const override;
size_t get_workspace_in_bytes(const SizeArgs& args) const override { return 0; }
void exec(const ExecArgs& args) const override;
const char* name() const override { return m_name.c_str(); }
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_BATCHED_GEMM_FMA_NCHW_F32)
private:
const void* get_available_op(const SizeArgs& args) const;
AlgoParam m_algo_param;
std::string m_name;
};
class ConvolutionBackwardDataImpl::AlgoFloat16NCHWHMMAImplicitBatchedGemm final
: public AlgoBase {
public:
/// add the instruction shape as a member of the algo param, because the f16 tensor
/// core has 2 different matrix shapes (i.e. mma.884 and mma.1688)
struct AlgoParam {
int threadblock_m;
int threadblock_n;
int threadblock_k;
int warp_m;
int warp_n;
int warp_k;
int instruction_m;
int instruction_n;
int instruction_k;
int stage;
std::string to_string() {
return ssprintf(
"_%dX%dX%d_%dX%dX%d_mma%dX%dX%d_%dstage", threadblock_m,
threadblock_n, threadblock_k, warp_m, warp_n, warp_k, instruction_m,
instruction_n, instruction_k, stage);
}
};
AlgoFloat16NCHWHMMAImplicitBatchedGemm(AlgoParam algo_param)
: m_algo_param{algo_param},
m_name{ssprintf(
"FLOAT16_NCHW_HMMA_IMPLICIT_BATCHED_GEMM%s",
m_algo_param.to_string().c_str())} {}
bool is_available(const SizeArgs& args) const override;
size_t get_workspace_in_bytes(const SizeArgs& args) const override { return 0; }
void exec(const ExecArgs& args) const override;
const char* name() const override { return m_name.c_str(); }
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_BATCHED_GEMM_HMMA_NCHW_F16)
private:
const void* get_available_op(const SizeArgs& args) const;
AlgoParam m_algo_param;
std::string m_name;
};
class ConvolutionBackwardDataImpl::AlgoPack : NonCopyableObj {
// defined in cudnn.cpp
void fill_cudnn_algos();
......@@ -322,6 +400,7 @@ class ConvolutionBackwardDataImpl::AlgoPack : NonCopyableObj {
void fill_int8_dp4a_algos();
// defined in implicit_gemm_int8_nhwc_imma.cpp
void fill_int8_imma_algos();
void fill_dwconv_algos();
AlgoBase::Mapper m_all_algos_map;
......@@ -337,6 +416,8 @@ public:
std::vector<AlgoInt8NCHW4DotProdImplicitGemm> int8_nchw4_dotprod;
AlgoInt8NCHWDotProdImplicitGemm int8_nchw_dotprod;
std::vector<AlgoInt8NHWCIMMAImplicitGemm> int8_nhwc_imma;
std::vector<AlgoFloat32NCHWFMAImplicitBatchedGemm> implbmm_nchw_fma;
std::vector<AlgoFloat16NCHWHMMAImplicitBatchedGemm> implbmm_nchw_hmma;
std::vector<AlgoBase*>
//! all algorithms
......
/**
* \file
* dnn/src/cuda/convolution/backward_data/implicit_batched_gemm_float16_nchw_hmma.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/cuda/convolution/backward_data/algo.h"
#include "src/cuda/cutlass/singleton.h"
#include "src/cuda/utils.h"
using namespace megdnn;
using namespace cuda;
using namespace cutlass::library;
const void* ConvolutionBackwardDataImpl::AlgoFloat16NCHWHMMAImplicitBatchedGemm::
get_available_op(const SizeArgs& args) const {
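// derive the element alignment of the diff tensor from the byte size of its innermost (W) extent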
int alignment_diff = 0;
int wo = args.diff_layout->dtype.size(args.diff_layout->operator[](3));
for (int candidate : {16, 4, 2}) {
if (wo % candidate == 0)
alignment_diff = candidate;
}
alignment_diff /= args.diff_layout->dtype.size(1);
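// DEFAULT compute mode accumulates in fp16; FLOAT32 compute mode accumulates in fp32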
NumericTypeID accumulator_dtype =
args.opr->param().compute_mode == param::Convolution::ComputeMode::DEFAULT
? NumericTypeID::kF16
: NumericTypeID::kF32;
ConvolutionKey key{
cutlass::conv::Operator::kDgrad,
NumericTypeID::kF16,
LayoutTypeID::kTensorNCHW,
NumericTypeID::kF16,
LayoutTypeID::kTensorNCHW,
NumericTypeID::kF16,
LayoutTypeID::kTensorNCHW,
NumericTypeID::kF16,
LayoutTypeID::kTensorNCHW,
accumulator_dtype,
cutlass::conv::ConvType::kDepthwiseConvolution,
m_algo_param.threadblock_m,
m_algo_param.threadblock_n,
m_algo_param.threadblock_k,
m_algo_param.warp_m,
m_algo_param.warp_n,
m_algo_param.warp_k,
m_algo_param.instruction_m,
m_algo_param.instruction_n,
m_algo_param.instruction_k,
cutlass::epilogue::EpilogueType::kBiasAddLinearCombination,
m_algo_param.stage,
cutlass::conv::SpecialOptimizeDesc::NONE,
alignment_diff,
1,
false};
return (void*)Singleton::get().operation_table.find_op(key);
}
bool ConvolutionBackwardDataImpl::AlgoFloat16NCHWHMMAImplicitBatchedGemm::is_available(
const SizeArgs& args) const {
#define RETURN_IF_FALSE(stmt_) \
if (!(stmt_)) \
return false;
RETURN_IF_FALSE(is_compute_capability_required(7, 0));
RETURN_IF_FALSE(
args.diff_layout->is_contiguous() && args.grad_layout->is_contiguous());
using Param = param::Convolution;
using Format = Param::Format;
using Sparse = Param::Sparse;
using Mode = Param::Mode;
auto&& param = args.opr->param();
auto&& fm = args.filter_meta;
RETURN_IF_FALSE(
param.format == Format::NCHW &&
args.diff_layout->dtype.enumv() == DTypeEnum::Float16 &&
args.filter_layout->dtype.enumv() == DTypeEnum::Float16 &&
args.grad_layout->dtype.enumv() == DTypeEnum::Float16);
RETURN_IF_FALSE(param.sparse == Sparse::GROUP);
RETURN_IF_FALSE(param.mode == Mode::CROSS_CORRELATION);
// check if channelwise convolution
RETURN_IF_FALSE(fm.icpg == 1 && fm.ocpg == 1);
const auto* op = get_available_op(args);
RETURN_IF_FALSE(op != nullptr);
return true;
#undef RETURN_IF_FALSE
}
void ConvolutionBackwardDataImpl::AlgoFloat16NCHWHMMAImplicitBatchedGemm::exec(
const ExecArgs& args) const {
auto&& param = args.opr->param();
auto&& fm = args.filter_meta;
int n = args.diff_layout->operator[](0), ho = args.diff_layout->operator[](2),
wo = args.diff_layout->operator[](3);
int hi = args.grad_layout->operator[](2), wi = args.grad_layout->operator[](3);
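// channelwise: one input and one output channel per group, so ci == co == groups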
int co = fm.group, ci = co, groups = co;
int fh = fm.spatial[0], fw = fm.spatial[1];
int sh = fm.stride[0], sw = fm.stride[1];
int ph = fm.padding[0], pw = fm.padding[1];
int dh = param.dilate_h, dw = param.dilate_w;
// check if channelwise convolution
megdnn_assert(fm.icpg == 1 && fm.ocpg == 1);
auto&& stream = cuda_stream(args.opr->handle());
float alpha = 1.f;
float beta = 0.f;
float gamma = 0.f;
float delta = 0.f;
const Operation* op = (const Operation*)get_available_op(args);
cutlass::conv::Conv2dProblemSize problem_size{
n, hi, wi, ci, co, fh, fw, ho,
wo, ph, pw, sh, sw, dh, dw, cutlass::conv::Mode::kCrossCorrelation,
1, // split k slices, always 1
groups, // groups
};
cutlass::library::ConvolutionArguments conv_args{
problem_size,
args.diff_tensor->raw_ptr(),
args.filter_tensor->raw_ptr(),
nullptr,
nullptr,
args.grad_tensor->raw_ptr(),
&alpha,
&beta,
&gamma,
&delta,
nullptr,
nullptr,
nullptr,
nullptr};
cutlass_check(op->run(&conv_args, nullptr, stream));
after_kernel_launch();
}
// vim: syntax=cpp.doxygen
/**
* \file
* dnn/src/cuda/convolution/backward_data/implicit_batched_gemm_float32_nchw_fma.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/cuda/convolution/backward_data/algo.h"
#include "src/cuda/cutlass/singleton.h"
#include "src/cuda/utils.h"
using namespace megdnn;
using namespace cuda;
using namespace cutlass::library;
const void* ConvolutionBackwardDataImpl::AlgoFloat32NCHWFMAImplicitBatchedGemm::
get_available_op(const SizeArgs& args) const {
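// derive the element alignment of the diff tensor from the byte size of its innermost (W) extent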
int alignment_diff = 0;
int wo = args.diff_layout->dtype.size(args.diff_layout->operator[](3));
for (int candidate : {16, 4}) {
if (wo % candidate == 0)
alignment_diff = candidate;
}
alignment_diff /= args.diff_layout->dtype.size(1);
ConvolutionKey key{
cutlass::conv::Operator::kDgrad,
NumericTypeID::kF32,
LayoutTypeID::kTensorNCHW,
NumericTypeID::kF32,
LayoutTypeID::kTensorNCHW,
NumericTypeID::kF32,
LayoutTypeID::kTensorNCHW,
NumericTypeID::kF32,
LayoutTypeID::kTensorNCHW,
NumericTypeID::kF32,
cutlass::conv::ConvType::kDepthwiseConvolution,
m_algo_param.threadblock_m,
m_algo_param.threadblock_n,
m_algo_param.threadblock_k,
m_algo_param.warp_m,
m_algo_param.warp_n,
m_algo_param.warp_k,
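// instruction shape 1x1x1: plain FMA (SIMT), no tensor core MMA instruction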
1,
1,
1,
cutlass::epilogue::EpilogueType::kBiasAddLinearCombination,
m_algo_param.stage,
cutlass::conv::SpecialOptimizeDesc::NONE,
alignment_diff,
1,
false};
return (void*)Singleton::get().operation_table.find_op(key);
}
bool ConvolutionBackwardDataImpl::AlgoFloat32NCHWFMAImplicitBatchedGemm::is_available(
const SizeArgs& args) const {
#define RETURN_IF_FALSE(stmt_) \
if (!(stmt_)) \
return false;
RETURN_IF_FALSE(
args.diff_layout->is_contiguous() && args.grad_layout->is_contiguous());
using Param = param::Convolution;
using Format = Param::Format;
using Sparse = Param::Sparse;
using Mode = Param::Mode;
auto&& param = args.opr->param();
auto&& fm = args.filter_meta;
RETURN_IF_FALSE(
param.format == Format::NCHW &&
args.diff_layout->dtype.enumv() == DTypeEnum::Float32 &&
args.filter_layout->dtype.enumv() == DTypeEnum::Float32 &&
args.grad_layout->dtype.enumv() == DTypeEnum::Float32);
RETURN_IF_FALSE(param.sparse == Sparse::GROUP);
RETURN_IF_FALSE(param.mode == Mode::CROSS_CORRELATION);
// check if channelwise convolution
RETURN_IF_FALSE(fm.icpg == 1 && fm.ocpg == 1);
const auto* op = get_available_op(args);
RETURN_IF_FALSE(op != nullptr);
return true;
#undef RETURN_IF_FALSE
}
void ConvolutionBackwardDataImpl::AlgoFloat32NCHWFMAImplicitBatchedGemm::exec(
const ExecArgs& args) const {
auto&& param = args.opr->param();
auto&& fm = args.filter_meta;
int n = args.diff_layout->operator[](0), ho = args.diff_layout->operator[](2),
wo = args.diff_layout->operator[](3);
int hi = args.grad_layout->operator[](2), wi = args.grad_layout->operator[](3);
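// channelwise: one input and one output channel per group, so ci == co == groups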
int co = fm.group, ci = co, groups = co;
int fh = fm.spatial[0], fw = fm.spatial[1];
int sh = fm.stride[0], sw = fm.stride[1];
int ph = fm.padding[0], pw = fm.padding[1];
int dh = param.dilate_h, dw = param.dilate_w;
// check if channelwise convolution
megdnn_assert(fm.icpg == 1 && fm.ocpg == 1);
auto&& stream = cuda_stream(args.opr->handle());
float alpha = 1.f;
float beta = 0.f;
float gamma = 0.f;
float delta = 0.f;
const Operation* op = (const Operation*)get_available_op(args);
cutlass::conv::Conv2dProblemSize problem_size{
n, hi, wi, ci, co, fh, fw, ho,
wo, ph, pw, sh, sw, dh, dw, cutlass::conv::Mode::kCrossCorrelation,
1, // split k slices, always 1
groups, // groups
};
cutlass::library::ConvolutionArguments conv_args{
problem_size,
args.diff_tensor->raw_ptr(),
args.filter_tensor->raw_ptr(),
nullptr,
nullptr,
args.grad_tensor->raw_ptr(),
&alpha,
&beta,
&gamma,
&delta,
nullptr,
nullptr,
nullptr,
nullptr};
cutlass_check(op->run(&conv_args, nullptr, stream));
after_kernel_launch();
}
// vim: syntax=cpp.doxygen
......@@ -54,7 +54,7 @@ const void* ConvolutionBackwardDataImpl::AlgoInt8NCHW4DotProdImplicitGemm::
m_algo_param.stage,
special_optimization,
4,
16,
4,
false};
return (void*)Singleton::get().operation_table.find_op(key);
}
......
......@@ -102,6 +102,8 @@ public:
class AlgoInt8NCHW4DotProdImplicitGemm;
class AlgoInt8NCHWDotProdImplicitGemm;
class AlgoInt8NHWCIMMAImplicitGemm;
class AlgoFloat32NCHWFMAImplicitBatchedGemm;
class AlgoFloat16NCHWHMMAImplicitBatchedGemm;
class AlgoPack;
......
......@@ -55,6 +55,7 @@ void initialize_all_gemm_simt_operations(Manifest& manifest);
void initialize_all_conv2d_simt_operations(Manifest& manifest);
void initialize_all_deconv_simt_operations(Manifest& manifest);
void initialize_all_dwconv2d_fprop_simt_operations(Manifest& manifest);
void initialize_all_dwconv2d_dgrad_simt_operations(Manifest& manifest);
#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) && CUTLASS_ARCH_MMA_SM75_SUPPORTED
void initialize_all_gemm_tensorop884_operations(Manifest& manifest);
void initialize_all_gemm_tensorop1688_operations(Manifest& manifest);
......@@ -62,6 +63,7 @@ void initialize_all_conv2d_tensorop8816_operations(Manifest& manifest);
void initialize_all_conv2d_tensorop8832_operations(Manifest& manifest);
void initialize_all_deconv_tensorop8816_operations(Manifest& manifest);
void initialize_all_dwconv2d_fprop_tensorop884_operations(Manifest& manifest);
void initialize_all_dwconv2d_dgrad_tensorop884_operations(Manifest& manifest);
#endif
void initialize_all(Manifest& manifest) {
......@@ -69,6 +71,7 @@ void initialize_all(Manifest& manifest) {
initialize_all_conv2d_simt_operations(manifest);
initialize_all_deconv_simt_operations(manifest);
initialize_all_dwconv2d_fprop_simt_operations(manifest);
initialize_all_dwconv2d_dgrad_simt_operations(manifest);
#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) && CUTLASS_ARCH_MMA_SM75_SUPPORTED
initialize_all_gemm_tensorop884_operations(manifest);
initialize_all_gemm_tensorop1688_operations(manifest);
......@@ -76,6 +79,7 @@ void initialize_all(Manifest& manifest) {
initialize_all_conv2d_tensorop8832_operations(manifest);
initialize_all_deconv_tensorop8816_operations(manifest);
initialize_all_dwconv2d_fprop_tensorop884_operations(manifest);
initialize_all_dwconv2d_dgrad_tensorop884_operations(manifest);
#endif
}
......
......@@ -569,6 +569,7 @@ public:
});
return ret;
}
megdnn_assert(false, "Expected algo not found: %s\n", policy_name.name.c_str());
return ret;
}
......
......@@ -497,15 +497,15 @@ void check_chanwise(DType io_type, DType comp_type, Handle* handle, const char*
}
} // namespace
#define MEGDNN_FOREACH_CUTLASS_CHANWISE_CONV_FWD_FMA_KERNEL(cb) \
cb(1, 128, 128, 8, 32, 64, 8); \
cb(2, 128, 64, 8, 64, 32, 8); \
cb(3, 128, 32, 8, 64, 32, 8); \
cb(4, 64, 128, 8, 64, 32, 8); \
cb(5, 32, 128, 8, 32, 64, 8); \
cb(6, 64, 64, 8, 64, 32, 8); \
cb(7, 32, 64, 8, 32, 64, 8); \
cb(8, 32, 32, 8, 32, 32, 8); \
#define MEGDNN_FOREACH_CUTLASS_CHANWISE_CONV_FMA_KERNEL(cb) \
cb(1, 128, 128, 8, 32, 64, 8); \
cb(2, 128, 64, 8, 64, 32, 8); \
cb(3, 128, 32, 8, 64, 32, 8); \
cb(4, 64, 128, 8, 32, 64, 8); \
cb(5, 32, 128, 8, 32, 64, 8); \
cb(6, 64, 64, 8, 32, 64, 8); \
cb(7, 32, 64, 8, 32, 64, 8); \
cb(8, 32, 32, 8, 32, 32, 8); \
cb(9, 64, 32, 8, 64, 32, 8);
#define cb(tag, tbm, tbn, tbk, wm, wn, wk) \
......@@ -516,16 +516,29 @@ void check_chanwise(DType io_type, DType comp_type, Handle* handle, const char*
"_" #wm "X" #wn "X" #wk "_2stage"); \
}
MEGDNN_FOREACH_CUTLASS_CHANWISE_CONV_FWD_FMA_KERNEL(cb)
MEGDNN_FOREACH_CUTLASS_CHANWISE_CONV_FMA_KERNEL(cb)
#undef cb
#undef MEGDNN_FOREACH_CUTLASS_CHANWISE_CONV_FWD_FMA_KERNEL
#define MEGDNN_FOREACH_CUTLASS_CHANWISE_CONV_FWD_HMMA_KERNEL(cb) \
cb(1, 128, 128, 32, 32, 32, 32); \
cb(2, 128, 256, 32, 64, 64, 32); \
cb(3, 128, 64, 32, 32, 32, 32); \
cb(4, 64, 128, 32, 32, 32, 32); \
#define cb(tag, tbm, tbn, tbk, wm, wn, wk) \
TEST_F(CUDA, CHANWISE_CONVOLUTION_BACKWARD_DATA_CUTLASS_FMA_##tag) { \
check_chanwise<ConvolutionBackwardData>( \
dtype::Float32(), dtype::Float32(), handle_cuda(), \
"FLOAT32_NCHW_FMA_IMPLICIT_BATCHED_GEMM_" #tbm "X" #tbn "X" #tbk \
"_" #wm "X" #wn "X" #wk "_2stage"); \
}
MEGDNN_FOREACH_CUTLASS_CHANWISE_CONV_FMA_KERNEL(cb)
#undef cb
#undef MEGDNN_FOREACH_CUTLASS_CHANWISE_CONV_FMA_KERNEL
#define MEGDNN_FOREACH_CUTLASS_CHANWISE_CONV_HMMA_KERNEL(cb) \
cb(1, 128, 128, 32, 32, 32, 32); \
cb(2, 128, 256, 32, 64, 64, 32); \
cb(3, 128, 64, 32, 32, 32, 32); \
cb(4, 64, 128, 32, 32, 32, 32); \
cb(5, 64, 64, 32, 32, 32, 32);
// check both ioc16 and io16xc32
......@@ -541,9 +554,26 @@ MEGDNN_FOREACH_CUTLASS_CHANWISE_CONV_FWD_FMA_KERNEL(cb)
"_" #wm "X" #wn "X" #wk "_2stage"); \
}
MEGDNN_FOREACH_CUTLASS_CHANWISE_CONV_FWD_HMMA_KERNEL(cb)
MEGDNN_FOREACH_CUTLASS_CHANWISE_CONV_HMMA_KERNEL(cb)
#undef cb
#define cb(tag, tbm, tbn, tbk, wm, wn, wk) \
TEST_F(CUDA, CHANWISE_CONVOLUTION_BACKWARD_DATA_CUTLASS_HMMA_##tag) { \
check_chanwise<ConvolutionBackwardData>( \
dtype::Float16(), dtype::Float16(), handle_cuda(), \
"FLOAT16_NCHW_HMMA_IMPLICIT_BATCHED_GEMM_" #tbm "X" #tbn "X" #tbk \
"_" #wm "X" #wn "X" #wk "_mma8X8X4_2stage"); \
check_chanwise<ConvolutionBackwardData>( \
dtype::Float16(), dtype::Float32(), handle_cuda(), \
"FLOAT16_NCHW_HMMA_IMPLICIT_BATCHED_GEMM_" #tbm "X" #tbn "X" #tbk \
"_" #wm "X" #wn "X" #wk "_mma8X8X4_2stage"); \
}
MEGDNN_FOREACH_CUTLASS_CHANWISE_CONV_HMMA_KERNEL(cb)
#undef cb
#undef MEGDNN_FOREACH_CUTLASS_CHANWISE_CONV_FWD_HMMA_KERNEL
#if MEGDNN_WITH_BENCHMARK
......@@ -1324,6 +1354,81 @@ TEST_F(CUDA, BENCHMARK_CHANWISE_CONV_FORWARD_LARGE_KERNEL) {
// clang-format on
}
TEST_F(CUDA, BENCHMARK_CHANWISE_CONV_BACKWARD_DATA_LARGE_KERNEL) {
CUBenchmarker<ConvolutionBackwardData> bencher(handle_cuda());
size_t RUNS = 100;
bencher.set_display(false).set_times(RUNS);
std::unique_ptr<OprProxy<ConvolutionBackwardData>> proxy{
new OprProxy<ConvolutionBackwardData>{true}};
bencher.set_proxy(proxy);
Convolution::Param param;
param.format = ConvBias::Param::Format::NCHW;
param.sparse = Convolution::Param::Sparse::GROUP;
NormalRNG rng;
auto run = [&](size_t batch, size_t c, size_t ih, size_t iw, size_t f, size_t s) {
param.pad_h = f / 2;
param.pad_w = f / 2;
param.stride_h = s;
param.stride_w = s;
param.compute_mode = param::Convolution::ComputeMode::DEFAULT;
TensorShape src = {batch, c, ih, iw}, filter = {c, 1, 1, f, f};
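// GROUP-sparse filter layout: {groups, ocpg, icpg, fh, fw}; channelwise means ocpg == icpg == 1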
TensorLayout dst_layout;
auto opr = handle_cuda()->create_operator<Convolution>();
opr->param() = param;
opr->deduce_layout(
{src, dtype::Float32()}, {filter, dtype::Float32()}, dst_layout);
float bandwith = static_cast<float>(
src.total_nr_elems() + filter.total_nr_elems() +
dst_layout.total_nr_elems()) /
(1024 * 1024 * 1024) * 1e3;
bencher.set_param(param)
.set_dtype(0, dtype::Float32())
.set_dtype(1, dtype::Float32())
.set_dtype(2, dtype::Float32())
.set_rng(0, &rng)
.set_rng(1, &rng);
bencher.proxy()->target_execution_policy = {};
auto time_in_ms_fp32 = bencher.execs({filter, src, src}) / RUNS;
bencher.set_param(param)
.set_dtype(0, dtype::Float16())
.set_dtype(1, dtype::Float16())
.set_dtype(2, dtype::Float16())
.set_rng(0, &rng)
.set_rng(1, &rng);
bencher.proxy()->target_execution_policy = {};
auto time_in_ms_fp16 = bencher.execs({filter, src, src}) / RUNS;
bencher.proxy()->target_execution_policy.algo.reset();
param.compute_mode = param::Convolution::ComputeMode::FLOAT32;
bencher.set_param(param);
auto time_in_ms_pseudo_fp16 = bencher.execs({filter, src, src}) / RUNS;
printf("stride=%zu src=%s, filter=%s, float32: %.2fms %.2fGB/s "
"float16: %.2fms %.2fGB/s "
"pseudo float16: %.2fms %.2fGB/s "
"speedup: "
"%0.2f (fp16/fp32) %.2f (fp16/pseudo fp16)\n",
s, src.to_string().c_str(), filter.to_string().c_str(), time_in_ms_fp32,
bandwidth * 4 / time_in_ms_fp32, time_in_ms_fp16,
bandwidth * 2 / time_in_ms_fp16, time_in_ms_pseudo_fp16,
bandwidth * 2 / time_in_ms_pseudo_fp16, time_in_ms_fp32 / time_in_ms_fp16,
time_in_ms_pseudo_fp16 / time_in_ms_fp16);
};
// clang-format off
for (size_t b : {32, 64})
for (size_t f : {3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31}) {
run(b, 384, 32, 32, f, 1);
run(b, 384, 64, 64, f, 1);
}
// clang-format on
}
#endif
// vim: syntax=cpp.doxygen