Commit 739f927c authored by: Megvii Engine Team

feat(dnn/cuda): opt dp4a conv for small channels based on cutlass

GitOrigin-RevId: 2a74c35f27ea8186d79313ad607bb8fa89e846fb
Parent 4f9948d0
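For context: "dp4a" refers to the `__dp4a` CUDA intrinsic (compute capability 6.1 and up), a four-way int8 dot product accumulated into an int32; it is the primitive underlying the INT8 NCHW4 implicit-GEMM convolution kernels this commit tunes. A minimal CUDA sketch, purely illustrative and not part of the diff:

// Illustrative only: __dp4a(a, b, c) treats a and b as four packed int8
// lanes, multiplies them lane-wise, sums the four products, and adds c.
__global__ void dp4a_demo(const int* a, const int* b, int* out) {
    out[0] = __dp4a(a[0], b[0], /*accumulator=*/0);
}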
@@ -260,16 +260,17 @@ void ConvBiasForwardImpl::AlgoPack::fill_imma_algos() {
 void ConvBiasForwardImpl::AlgoPack::fill_dp4a_algos() {
     using AlgoParam = AlgoInt8NCHW4DotProdImplicitGemm::AlgoParam;
-    int8_nchw4_dotprod.emplace_back(AlgoParam{128, 128, 32, 64, 32, 32});
-    int8_nchw4_dotprod.emplace_back(AlgoParam{128, 64, 32, 64, 32, 32});
-    int8_nchw4_dotprod.emplace_back(AlgoParam{64, 128, 32, 64, 32, 32});
-    int8_nchw4_dotprod.emplace_back(AlgoParam{32, 128, 32, 32, 64, 32});
-    int8_nchw4_dotprod.emplace_back(AlgoParam{128, 32, 32, 64, 32, 32});
-    int8_nchw4_dotprod.emplace_back(AlgoParam{64, 64, 32, 64, 32, 32});
-    int8_nchw4_dotprod.emplace_back(AlgoParam{32, 64, 32, 32, 64, 32});
-    int8_nchw4_dotprod.emplace_back(AlgoParam{64, 32, 32, 64, 32, 32});
-    int8_nchw4_dotprod.emplace_back(AlgoParam{32, 32, 32, 32, 32, 32});
-    int8_nchw4_dotprod.emplace_back(AlgoParam{16, 64, 8, 16, 64, 8});
+    int8_nchw4_dotprod.emplace_back(AlgoParam{128, 128, 32, 64, 32, 32, 2});
+    int8_nchw4_dotprod.emplace_back(AlgoParam{128, 64, 32, 64, 32, 32, 2});
+    int8_nchw4_dotprod.emplace_back(AlgoParam{64, 128, 32, 64, 32, 32, 2});
+    int8_nchw4_dotprod.emplace_back(AlgoParam{32, 128, 32, 32, 64, 32, 2});
+    int8_nchw4_dotprod.emplace_back(AlgoParam{128, 32, 32, 64, 32, 32, 2});
+    int8_nchw4_dotprod.emplace_back(AlgoParam{64, 64, 32, 64, 32, 32, 2});
+    int8_nchw4_dotprod.emplace_back(AlgoParam{32, 64, 32, 32, 64, 32, 2});
+    int8_nchw4_dotprod.emplace_back(AlgoParam{64, 32, 32, 64, 32, 32, 2});
+    int8_nchw4_dotprod.emplace_back(AlgoParam{32, 32, 32, 32, 32, 32, 2});
+    int8_nchw4_dotprod.emplace_back(AlgoParam{16, 128, 16, 16, 128, 16, 1});
+    int8_nchw4_dotprod.emplace_back(AlgoParam{16, 64, 8, 16, 64, 8, 2});
 }
...
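The trailing value appended to every `AlgoParam` entry above is the number of software-pipeline stages, and `{16, 128, 16, 16, 128, 16, 1}` is the new small-channel tile. A sketch of how one entry decodes, with field names taken from the struct in the next hunk:

// Sketch only: field order mirrors the AlgoParam struct shown below.
struct AlgoParam {
    int threadblock_m, threadblock_n, threadblock_k;  // CTA-level GEMM tile
    int warp_m, warp_n, warp_k;                       // per-warp GEMM tile
    int stage;                                        // pipeline stages (new field)
};

// New small-channel config: a 16x128x16 threadblock computed by a single
// 16x128x16 warp tile, with a 1-stage (non-double-buffered) pipeline.
AlgoParam small_channel{16, 128, 16, 16, 128, 16, 1};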
@@ -407,15 +407,16 @@ public:
         int warp_m;
         int warp_n;
         int warp_k;
+        int stage;
         std::string to_string() {
             /// default algorithm
             if (threadblock_m == 128 && threadblock_n == 128 &&
                 threadblock_k == 32 && warp_m == 32 && warp_n == 64 &&
-                warp_k == 32) {
+                warp_k == 32 && stage == 2) {
                 return "";
             }
-            return ssprintf("_%dX%dX%d_%dX%dX%d", threadblock_m, threadblock_n,
-                            threadblock_k, warp_m, warp_n, warp_k);
+            return ssprintf("_%dX%dX%d_%dX%dX%d_%dstage", threadblock_m, threadblock_n,
+                            threadblock_k, warp_m, warp_n, warp_k, stage);
         }
     };
     AlgoInt8NCHW4DotProdImplicitGemm(AlgoParam algo_param)
...
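Since `stage` is now part of the name suffix, the same tile shape at different pipeline depths gets distinct algorithm names. A worked example using values from this diff:

// Worked example of the naming scheme:
AlgoParam def{128, 128, 32, 32, 64, 32, 2};
AlgoParam small{16, 128, 16, 16, 128, 16, 1};
def.to_string();    // "": the default config keeps the bare algorithm name
small.to_string();  // "_16X128X16_16X128X16_1stage"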
@@ -172,7 +172,7 @@ void megdnn::cuda::cutlass_wrapper::
         const GemmCoord& warp_shape, cudaStream_t stream) {
 #define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_,   \
                                         threadblock_k_, warp_m_, warp_n_, \
-                                        warp_k_, aligned_)                \
+                                        warp_k_, stage_, aligned_)        \
     if (threadblock_shape.m() == threadblock_m_ &&                        \
         threadblock_shape.n() == threadblock_n_ &&                        \
         threadblock_shape.k() == threadblock_k_ &&                        \
@@ -194,7 +194,7 @@ void megdnn::cuda::cutlass_wrapper::
                 cutlass::convolution::threadblock::                            \
                         ConvolutionNCxHWxThreadblockSwizzle<                   \
                                 cutlass::convolution::ConvType::kConvolution>, \
-                2, 4, aligned_, NeedLoadFromConstMem>;                         \
+                stage_, 4, aligned_, NeedLoadFromConstMem>;                    \
         typename Convolution::ConvolutionParameter conv_param{                 \
                 param.n,  param.ci, param.co, param.hi, param.wi,              \
                 param.fh, param.fw, param.ho, param.wo, param.sh,              \
@@ -204,16 +204,17 @@ void megdnn::cuda::cutlass_wrapper::
                 epilogue, stream);                                            \
     }
 #define DISPATCH_KERNEL                                                       \
-    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 32, 64, 32, 32, 16);            \
-    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 32, 64, 32, 32, 16);             \
-    DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 32, 64, 32, 32, 16);             \
-    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 32, 64, 32, 32, 16);             \
-    DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 128, 32, 32, 64, 32, 16);             \
-    DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 64, 32, 64, 32, 32, 16);              \
-    DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 64, 32, 32, 64, 32, 16);              \
-    DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 32, 32, 64, 32, 32, 16);              \
-    DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 32, 32, 32, 32, 32, 16);              \
-    DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 64, 8, 16, 64, 8, 4);                 \
+    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 32, 64, 32, 32, 2, 16);         \
+    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 32, 64, 32, 32, 2, 16);          \
+    DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 32, 64, 32, 32, 2, 16);          \
+    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 32, 64, 32, 32, 2, 16);          \
+    DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 128, 32, 32, 64, 32, 2, 16);          \
+    DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 64, 32, 64, 32, 32, 2, 16);           \
+    DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 64, 32, 32, 64, 32, 2, 16);           \
+    DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 32, 32, 64, 32, 32, 2, 16);           \
+    DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 32, 32, 32, 32, 32, 2, 16);           \
+    DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 128, 16, 16, 128, 16, 1, 8);          \
+    DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 64, 8, 16, 64, 8, 2, 4);              \
     megdnn_assert(false,                                                      \
                   "unsupported threadblock shape (%dx%dx%d) and warp shape "  \
                   "(%dx%dx%d)",                                               \
...
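The new `stage_` macro argument replaces the hard-coded `2` in the CUTLASS `Convolution` template, so pipeline depth is now chosen per tile shape, and each `DISPATCH_KERNEL_WITH_TILE_SHAPE` expansion is one runtime shape check that instantiates a compile-time kernel. A self-contained analog of the pattern (illustration only, not MegDNN's actual code):

#include <cstdio>

struct Shape {
    int m, n, k;
};

// Stand-in for the CUTLASS kernel: tile shape and stage count are
// compile-time template parameters, selected by a runtime comparison.
template <int TbM, int TbN, int TbK, int Stages>
void run_conv_kernel() {
    std::printf("launch %dx%dx%d kernel, %d stage(s)\n", TbM, TbN, TbK, Stages);
}

void dispatch(Shape tb) {
#define DISPATCH_TILE(m_, n_, k_, stage_)         \
    if (tb.m == m_ && tb.n == n_ && tb.k == k_) { \
        run_conv_kernel<m_, n_, k_, stage_>();    \
        return;                                   \
    }
    DISPATCH_TILE(128, 128, 32, 2);
    DISPATCH_TILE(16, 128, 16, 1);  // new small-channel tile, single stage
    DISPATCH_TILE(16, 64, 8, 2);
#undef DISPATCH_TILE
    // Mirrors the megdnn_assert(false, ...) fall-through above.
    std::printf("unsupported threadblock shape (%dx%dx%d)\n", tb.m, tb.n, tb.k);
}

int main() {
    dispatch({16, 128, 16});  // prints: launch 16x128x16 kernel, 1 stage(s)
    return 0;
}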
@@ -97,6 +97,13 @@ std::vector<BenchArgs> get_detection_bench_args(size_t batch = 16) {
     return args;
 }
 
+std::vector<BenchArgs> get_det_first_bench_args(size_t batch = 16) {
+    std::vector<BenchArgs> args;
+    args.emplace_back(BenchArgs{batch, 4, 736, 1280, 16, 3, 2});
+    args.emplace_back(BenchArgs{batch, 16, 384, 640, 16, 3, 1});
+    return args;
+}
+
 void benchmark_target_algo(
         Handle* handle, const std::vector<BenchArgs>& args, DType src_dtype,
         DType filter_dtype, DType bias_dtype, DType dst_dtype,
@@ -1236,6 +1243,28 @@ TEST_F(CUDA, BENCHMARK_CUTLASS_CONV_BIAS_INT8_NCHW4) {
             dtype::QuantizedS32{1.2f * 1.3f}, dtype::QuantizedS8{1.0f},
             "INT8_NCHW4_DOTPROD_IMPLICIT_GEMM", param::ConvBias::Format::NCHW4);
 }
+
+TEST_F(CUDA, BENCHMARK_SASS_CONV_BIAS_INT8_NCHW4_DET_FIRST) {
+    require_compute_capability(6, 1);
+    std::string algo = ConvBias::algo_name<ConvBias::DirectParam>(
+            "SASS_INT8_NCHW4_DOTPROD_IMPLICIT_GEMM_128X32_64",
+            ConvBias::DirectParam{});
+    benchmark_target_algo(handle_cuda(), get_det_first_bench_args(16),
+                          dtype::QuantizedS8{1.2f}, dtype::QuantizedS8{1.3f},
+                          dtype::QuantizedS32{1.2f * 1.3f},
+                          dtype::QuantizedS8{1.0f}, algo.c_str(),
+                          param::ConvBias::Format::NCHW4);
+}
+
+TEST_F(CUDA, BENCHMARK_CUTLASS_CONV_BIAS_INT8_NCHW4_DET_FIRST) {
+    require_compute_capability(6, 1);
+    benchmark_target_algo(
+            handle_cuda(), get_det_first_bench_args(16),
+            dtype::QuantizedS8{1.2f}, dtype::QuantizedS8{1.3f},
+            dtype::QuantizedS32{1.2f * 1.3f}, dtype::QuantizedS8{1.0f},
+            "INT8_NCHW4_DOTPROD_IMPLICIT_GEMM_16", param::ConvBias::Format::NCHW4);
+}
 #endif
 } // namespace test
 } // namespace megdnn
...
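The new benchmark cases exercise the first convolutions of a detection network, where input channels are tiny (4, then 16) and the spatial extent is large: exactly the regime the 16x128x16 single-stage tile targets. Assuming `BenchArgs` in this test file is laid out as `{batch, ci, hi, wi, co, filter, stride}` (an assumption; check the struct definition near the top of the file), the two cases decode as:

// Hedged decode under the assumed field order {n, ci, hi, wi, co, f, s}:
// {16,  4, 736, 1280, 16, 3, 2} -> 16x4x736x1280 input, 16 3x3 filters, stride 2
// {16, 16, 384,  640, 16, 3, 1} -> 16x16x384x640 input, 16 3x3 filters, stride 1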
-Subproject commit 5a7f4bfa0e57f92140c8236322a86730132e0847
+Subproject commit 41426ea4074dcfc448b1c9979ea7617407590c04