From 3bda3347981babe1fddf6873ef0a01ce2bb0f4de Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 19 Jan 2021 18:00:07 +0800 Subject: [PATCH] fix(dnn/fallback): fix segmentfault caused by im2col/conv1x1 using fallback naive matmul. GitOrigin-RevId: 03ef904b113ffcc63e6e298827321fe085075b4a --- dnn/src/fallback/matrix_mul/algos.cpp | 36 ++++++++++++++++++++------- dnn/test/fallback/matrix_mul.cpp | 18 ++++++++++++++ 2 files changed, 45 insertions(+), 9 deletions(-) diff --git a/dnn/src/fallback/matrix_mul/algos.cpp b/dnn/src/fallback/matrix_mul/algos.cpp index 371702ad3..280a29761 100644 --- a/dnn/src/fallback/matrix_mul/algos.cpp +++ b/dnn/src/fallback/matrix_mul/algos.cpp @@ -49,15 +49,33 @@ void kern_naive(const MatrixMulImpl::KernParam& kern_param) { MIDOUT_BEGIN(megdnn_fb_matmul_naive, void) { size_t M = kern_param.M, N = kern_param.N, K = kern_param.K; size_t LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; - -#define DISPATCH(TA, TB) \ - if (kern_param.trA == TA && kern_param.trB == TB) { \ - naive::dispatch_ta_tb( \ - kern_param.A_ptr, kern_param.B_ptr, kern_param.C_ptr, \ - kern_param.workspace_ptr, M, N, K, LDA, LDB, LDC, \ - kern_param.A_type, kern_param.B_type, kern_param.C_type, \ - kern_param.format, kern_param.compute_mode); \ - return; \ + auto get_pack_size = [kern_param]() -> size_t { + switch (kern_param.format) { + case param::MatrixMul::Format::MK4: + case param::MatrixMul::Format::MK4_DOT: + return 4_z; + case param::MatrixMul::Format::MK8: + return 8_z; + default: + return 1_z; + } + }; + + size_t pack_size = get_pack_size(); + megdnn_assert( + (M % pack_size == 0 && K % pack_size == 0), + "M and N must time of pack_size M: %zu N: %zu pack_size: %zu", + M, N, pack_size); + +#define DISPATCH(TA, TB) \ + if (kern_param.trA == TA && kern_param.trB == TB) { \ + naive::dispatch_ta_tb( \ + kern_param.A_ptr, kern_param.B_ptr, kern_param.C_ptr, \ + kern_param.workspace_ptr, M / pack_size, N, K / pack_size, \ + LDA, LDB, LDC, kern_param.A_type, kern_param.B_type, \ + kern_param.C_type, kern_param.format, \ + kern_param.compute_mode); \ + return; \ } DISPATCH(true, true); DISPATCH(true, false); diff --git a/dnn/test/fallback/matrix_mul.cpp b/dnn/test/fallback/matrix_mul.cpp index 92807b9ac..682f2e194 100644 --- a/dnn/test/fallback/matrix_mul.cpp +++ b/dnn/test/fallback/matrix_mul.cpp @@ -46,6 +46,24 @@ TEST_F(FALLBACK, MATRIX_MUL) { } } +TEST_F(FALLBACK, MATRIX_MUL_NAIVE_MK4) { + matrix_mul::check_matrix_mul(dtype::Float32{}, dtype::Float32{}, + dtype::Float32{}, handle(), "FB_NAIVE", + param::MatrixMul::Format::MK4, 1); +} + +TEST_F(FALLBACK, MATRIX_MUL_NAIVE_MK8) { + matrix_mul::check_matrix_mul(dtype::Float32{}, dtype::Float32{}, + dtype::Float32{}, handle(), "FB_NAIVE", + param::MatrixMul::Format::MK8, 1); +} + +TEST_F(FALLBACK, MATRIX_MUL_NAIVE_MK4_DOT) { + matrix_mul::check_matrix_mul(dtype::Float32{}, dtype::Float32{}, + dtype::Float32{}, handle(), "FB_NAIVE", + param::MatrixMul::Format::MK4_DOT, 1); +} + TEST_F(FALLBACK, MATRIX_MUL_NAIVE) { Checker checker(handle()); checker.set_before_exec_callback(AlgoChecker("FB_NAIVE")); -- GitLab