Commit 96050073 authored by Megvii Engine Team, committed by 王彪

feat(dnn/cuda): add implicit bmm large kernel dwconv2d fprop impl

GitOrigin-RevId: feb09ebb5836d26433c4a82940bb5f22795da381
Parent 19fe2e94
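For context: in a depthwise (channel-wise) convolution each input channel is convolved with its own single-channel filter, so the C channels never mix. A minimal reference implementation in plain C++ (illustrative only; the NCHW indexing and names are generic, not MegEngine APIs):

```cpp
// Naive depthwise conv2d forward, NCHW layout, zero padding.
// src: N x C x IH x IW, flt: C x 1 x FH x FW, dst: N x C x OH x OW.
void dwconv2d_fprop_ref(
        const float* src, const float* flt, float* dst, int N, int C,
        int IH, int IW, int FH, int FW, int OH, int OW, int PH, int PW,
        int SH, int SW) {
    for (int n = 0; n < N; ++n)
        for (int c = 0; c < C; ++c)  // channels are independent problems
            for (int oh = 0; oh < OH; ++oh)
                for (int ow = 0; ow < OW; ++ow) {
                    float acc = 0.f;
                    for (int fh = 0; fh < FH; ++fh)
                        for (int fw = 0; fw < FW; ++fw) {
                            int ih = oh * SH - PH + fh;
                            int iw = ow * SW - PW + fw;
                            if (ih < 0 || ih >= IH || iw < 0 || iw >= IW)
                                continue;  // zero padding
                            acc += src[((n * C + c) * IH + ih) * IW + iw] *
                                   flt[(c * FH + fh) * FW + fw];
                        }
                    dst[((n * C + c) * OH + oh) * OW + ow] = acc;
                }
}
```

Because every (n, c) slice is an independent 2D convolution, the whole operator maps onto a batched GEMM with batch size N * C, which is presumably what "implicit bmm" in the commit title refers to; with large filter windows the per-GEMM reduction dimension (FH * FW) is big enough for this to pay off.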
@@ -181,6 +181,8 @@ if(MGE_WITH_CUDA)
   gen_cutlass_kimpl(conv2d simt CUTLASS_SOURCES)
   gen_cutlass_kimpl(conv2d tensorop8816 CUTLASS_SOURCES)
   gen_cutlass_kimpl(conv2d tensorop8832 CUTLASS_SOURCES)
+  gen_cutlass_kimpl(dwconv2d_fprop simt CUTLASS_SOURCES)
+  gen_cutlass_kimpl(dwconv2d_fprop tensorop884 CUTLASS_SOURCES)
   list(APPEND SOURCES ${CUTLASS_SOURCES})
   list(APPEND SOURCES ${CUSOURCES})
 endif()
......
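gen_cutlass_kimpl is the existing CMake helper that invokes the CUTLASS kernel generator at build time; the two new calls extend the same generation to the depthwise fprop kernels, once for the SIMT (f32 FMA) backend and once for tensorop884 (f16 HMMA with an 8x8x4 MMA shape). The generator itself is not in this diff, but the apparent contract is one pre-instantiated kernel per tile configuration, matching the configurations registered in fill_dwconv_algos() below. A rough, purely hypothetical illustration of that one-kernel-per-tile idea (the filename scheme is invented):

```cpp
#include <cstdio>

// Each entry stands for one generated .cu: a single template
// instantiation per (threadblock tile, warp tile) pair.
struct Tile { int tb_m, tb_n, tb_k, warp_m, warp_n, warp_k; };

int main() {
    // A subset of the f32 SIMT tiles listed in fill_dwconv_algos() below.
    const Tile tiles[] = {
            {128, 128, 8, 32, 64, 8},
            {128, 64, 8, 64, 32, 8},
            {64, 64, 8, 64, 32, 8},
            {32, 32, 8, 32, 32, 8},
    };
    for (const Tile& t : tiles)
        std::printf("dwconv2d_fprop_simt_%dx%dx%d_%dx%dx%d.cu\n", t.tb_m,
                    t.tb_n, t.tb_k, t.warp_m, t.warp_n, t.warp_k);
}
```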
@@ -92,6 +92,7 @@ ConvBiasForwardImpl::AlgoPack::AlgoPack() {
     for (auto&& algo : int8_nchw4_dotprod) {
         all_algos.push_back(&algo);
     }
+    fill_dwconv_algos();
     all_algos.push_back(&int8_chwn4_dotprod);
     all_algos.push_back(&fallback_nchw_qs8);
     for (size_t i = all_algo_size; i < all_algos.size(); ++i) {
@@ -301,6 +302,32 @@ void ConvBiasForwardImpl::AlgoPack::fill_imma_algos() {
 }
 #endif
+void ConvBiasForwardImpl::AlgoPack::fill_dwconv_algos() {
+    using AlgoParam = AlgoCutlassConvolutionBase::AlgoParam;
+    f32_implicit_bmm.emplace_back(AlgoParam{128, 128, 8, 32, 64, 8, 1, 1, 1, 2});
+    f32_implicit_bmm.emplace_back(AlgoParam{128, 64, 8, 64, 32, 8, 1, 1, 1, 2});
+    f32_implicit_bmm.emplace_back(AlgoParam{128, 32, 8, 64, 32, 8, 1, 1, 1, 2});
+    f32_implicit_bmm.emplace_back(AlgoParam{32, 128, 8, 32, 64, 8, 1, 1, 1, 2});
+    f32_implicit_bmm.emplace_back(AlgoParam{64, 128, 8, 64, 32, 8, 1, 1, 1, 2});
+    f32_implicit_bmm.emplace_back(AlgoParam{64, 64, 8, 64, 32, 8, 1, 1, 1, 2});
+    f32_implicit_bmm.emplace_back(AlgoParam{32, 64, 8, 32, 64, 8, 1, 1, 1, 2});
+    f32_implicit_bmm.emplace_back(AlgoParam{32, 32, 8, 32, 32, 8, 1, 1, 1, 2});
+    f32_implicit_bmm.emplace_back(AlgoParam{64, 32, 8, 64, 32, 8, 1, 1, 1, 2});
+    for (auto&& algo : f32_implicit_bmm) {
+        all_algos.push_back(&algo);
+    }
+#if CUDA_VERSION >= 10020
+    f16_implicit_bmm.emplace_back(AlgoParam{128, 128, 32, 32, 32, 32, 8, 8, 4, 2});
+    f16_implicit_bmm.emplace_back(AlgoParam{128, 256, 32, 64, 64, 32, 8, 8, 4, 2});
+    f16_implicit_bmm.emplace_back(AlgoParam{128, 64, 32, 32, 32, 32, 8, 8, 4, 2});
+    f16_implicit_bmm.emplace_back(AlgoParam{64, 128, 32, 32, 32, 32, 8, 8, 4, 2});
+    f16_implicit_bmm.emplace_back(AlgoParam{64, 64, 32, 32, 32, 32, 8, 8, 4, 2});
+    for (auto&& algo : f16_implicit_bmm) {
+        all_algos.push_back(&algo);
+    }
+#endif
+}
+
 void ConvBiasForwardImpl::AlgoPack::fill_dp4a_algos() {
     using AlgoParam = AlgoInt8NCHW4DotProdImplicitGemm::AlgoParam;
     int8_nchw4_dotprod.emplace_back(AlgoParam{128, 128, 32, 64, 32, 32, 1, 1, 4, 2});
......
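The ten integers in each AlgoParam initializer are easier to read with names attached. A sketch of the field layout, inferred from the m_algo_param.threadblock_m / threadblock_n / stage accesses visible later in this diff and from the 8x8x4 shape in the f16 (tensorop884) entries; the authoritative definition is AlgoCutlassConvolutionBase::AlgoParam, which is not shown here:

```cpp
// Inferred layout of the 10-value initializers above (an assumption).
struct AlgoParam {
    int threadblock_m, threadblock_n, threadblock_k;  // CTA tile
    int warp_m, warp_n, warp_k;                       // warp tile
    int instruction_m, instruction_n, instruction_k;  // MMA shape; 1x1x1 = scalar FMA
    int stage;                                        // pipeline stages
};

// First f16 entry above: 128x128x32 CTA tile, 32x32x32 warp tile,
// 8x8x4 HMMA instruction, 2-stage pipeline.
AlgoParam example{128, 128, 32, 32, 32, 32, 8, 8, 4, 2};
```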
This diff is collapsed.
@@ -71,6 +71,9 @@ public:
     class AlgoInt4Int4NHWCIMMAImplicitGemm;
     class AlgoUInt4Int4NHWCIMMAImplicitGemm;
     class AlgoBFloat16;
+    // The following algorithms are suitable for channel-wise convolution
+    class AlgoFloat32NCHWFMAImplicitBatchedGemm;
+    class AlgoFloat16NCHWHMMAImplicitBatchedGemm;
     class AlgoPack;
......
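The two new class names encode the supported layout (NCHW), the instruction kind (FMA on CUDA cores for f32, HMMA tensor cores for f16), and the strategy: implicit batched GEMM. That batched-GEMM view is sketched below in plain C++ for clarity; the real kernel is "implicit", i.e. it indexes the im2col matrix on the fly instead of materializing it, and none of these names are MegEngine APIs:

```cpp
#include <vector>

// Depthwise conv as a batched GEMM: one GEMM of shape
// (OH*OW) x (FH*FW) times (FH*FW) x 1 per (sample, channel) pair.
void dwconv_as_batched_gemm(
        const float* src, const float* flt, float* dst, int N, int C,
        int IH, int IW, int FH, int FW, int OH, int OW, int PH, int PW,
        int SH, int SW) {
    std::vector<float> A((size_t)OH * OW * FH * FW);  // im2col buffer
    for (int b = 0; b < N * C; ++b) {                 // GEMM batch = N*C
        int n = b / C, c = b % C;
        // A: one row per output pixel, one column per filter tap.
        for (int o = 0; o < OH * OW; ++o)
            for (int k = 0; k < FH * FW; ++k) {
                int ih = (o / OW) * SH - PH + k / FW;
                int iw = (o % OW) * SW - PW + k % FW;
                A[(size_t)o * FH * FW + k] =
                        (ih < 0 || ih >= IH || iw < 0 || iw >= IW)
                                ? 0.f
                                : src[((n * C + c) * IH + ih) * IW + iw];
            }
        // GEMV: this output slice = A * flt[c].
        for (int o = 0; o < OH * OW; ++o) {
            float acc = 0.f;
            for (int k = 0; k < FH * FW; ++k)
                acc += A[(size_t)o * FH * FW + k] * flt[c * FH * FW + k];
            dst[(size_t)b * OH * OW + o] = acc;
        }
    }
}
```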
@@ -39,6 +39,7 @@ const void* ConvolutionBackwardDataImpl::AlgoInt8NCHW4DotProdImplicitGemm::
         LayoutTypeID::kTensorNC4HW4,
         NumericTypeID::kS32,
         LayoutTypeID::kTensorNC4HW4,
+        NumericTypeID::kS32,
         cutlass::conv::ConvType::kConvolution,
         m_algo_param.threadblock_m,
         m_algo_param.threadblock_n,
@@ -52,6 +53,8 @@ const void* ConvolutionBackwardDataImpl::AlgoInt8NCHW4DotProdImplicitGemm::
         cutlass::epilogue::EpilogueType::kBiasAddLinearCombinationClamp,
         m_algo_param.stage,
         special_optimization,
+        4,
+        16,
         false};
     return (void*)Singleton::get().operation_table.find_op(key);
 }
......
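This hunk (and the two near-identical ones that follow) widens the key used to look up a pre-generated CUTLASS kernel: one extra NumericTypeID (apparently the element type of one more operand) and two trailing integers, which the NHWC variant below fills with m_algo_param.access_size for both, so they apparently describe per-operand vector-access alignment. The lookup itself is a keyed table; a self-contained sketch of that pattern (field names are illustrative, not the real key type from the CUTLASS wrapper):

```cpp
#include <map>
#include <tuple>

// Every property that selects a unique pre-instantiated kernel goes into
// the key; widening the key is how new kernel families stay distinguishable.
struct OpKey {
    int threadblock_m, threadblock_n, threadblock_k;
    int stages;
    int alignment_src, alignment_filter;  // the two new trailing ints
    bool without_shared_load;

    bool operator<(const OpKey& rhs) const {
        return std::tie(threadblock_m, threadblock_n, threadblock_k, stages,
                        alignment_src, alignment_filter, without_shared_load) <
               std::tie(rhs.threadblock_m, rhs.threadblock_n,
                        rhs.threadblock_k, rhs.stages, rhs.alignment_src,
                        rhs.alignment_filter, rhs.without_shared_load);
    }
};

struct Operation {};  // stands in for the kernel entry point, workspace query, ...

std::map<OpKey, const Operation*> operation_table;

const Operation* find_op(const OpKey& key) {
    auto it = operation_table.find(key);
    return it == operation_table.end() ? nullptr : it->second;
}
```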
@@ -39,6 +39,7 @@ const void* ConvolutionBackwardDataImpl::AlgoInt8NCHWDotProdImplicitGemm::
         LayoutTypeID::kTensorNC4HW4,
         NumericTypeID::kS32,
         LayoutTypeID::kTensorNC4HW4,
+        NumericTypeID::kS32,
         cutlass::conv::ConvType::kConvolution,
         16,
         64,
@@ -52,6 +53,8 @@ const void* ConvolutionBackwardDataImpl::AlgoInt8NCHWDotProdImplicitGemm::
         cutlass::epilogue::EpilogueType::kBiasAddLinearCombinationClamp,
         2,
         special_optimization,
+        4,
+        4,
         false};
     return (void*)Singleton::get().operation_table.find_op(key);
 }
......
@@ -50,6 +50,7 @@ const void* ConvolutionBackwardDataImpl::AlgoInt8NHWCIMMAImplicitGemm::get_avail
         LayoutTypeID::kTensorNHWC,
         NumericTypeID::kS32,
         LayoutTypeID::kTensorNHWC,
+        NumericTypeID::kS32,
         cutlass::conv::ConvType::kConvolution,
         m_algo_param.threadblock_m,
         m_algo_param.threadblock_n,
@@ -63,6 +64,8 @@ const void* ConvolutionBackwardDataImpl::AlgoInt8NHWCIMMAImplicitGemm::get_avail
         cutlass::epilogue::EpilogueType::kBiasAddLinearCombinationClamp,
         m_algo_param.stage,
         special_optimization,
+        m_algo_param.access_size,
+        m_algo_param.access_size,
         false};
     return (void*)Singleton::get().operation_table.find_op(key);
 }
......
@@ -223,6 +223,9 @@ enum class ThreadblockSwizzleID {
     kConvolutionFpropTrans,
     kConvolutionDgradNCxHWx,
     kConvolutionDgradTrans,
+    kDepthwiseConvolutionFprop,
+    kDepthwiseConvolutionDgrad,
+    kDepthwiseConvolutionWgrad,
     kInvalid
 };
......
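A threadblock swizzle decides how CUDA grid coordinates map onto tiles of the problem, and the three new IDs name swizzles for the depthwise kernels (only fprop is implemented by this commit per its title; the dgrad/wgrad IDs appear to be added ahead of their implementations). A minimal sketch of what a depthwise fprop swizzle plausibly does, assuming the GEMM batch index (one independent GEMM per sample-channel pair) rides on the third grid dimension:

```cpp
// Illustrative only; the real kDepthwiseConvolutionFprop swizzle lives
// in the CUTLASS fork and is not shown in this diff.
struct TileCoord { int tile_m, tile_n, batch; };

struct DepthwiseFpropSwizzleSketch {
    // grid = (ceil(M / tile_m), ceil(N / tile_n), N * C)
    TileCoord map(int block_x, int block_y, int block_z) const {
        return {block_x, block_y, block_z};  // batch on gridDim.z
    }
};
```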