From a6bc250d1c8e2aa52ab4ae3678f012819c8b66e4 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 22 May 2020 13:15:45 +0800 Subject: [PATCH] feat(dnn/common): add matmul impl for naive with matrix format mk4_dot GitOrigin-RevId: 7c6fbdfa973413b1b8d2f2fb0d18f8bb5ee7f243 --- dnn/scripts/opr_param_defs.py | 5 ++- dnn/src/common/matrix_mul.cpp | 2 ++ dnn/src/naive/matrix_mul/matrix_mul_helper.h | 29 +++++++++++++++ dnn/src/naive/matrix_mul/opr_impl.cpp | 37 +++++++++++--------- 4 files changed, 56 insertions(+), 17 deletions(-) diff --git a/dnn/scripts/opr_param_defs.py b/dnn/scripts/opr_param_defs.py index 46e77482..5f4a9cec 100755 --- a/dnn/scripts/opr_param_defs.py +++ b/dnn/scripts/opr_param_defs.py @@ -433,7 +433,10 @@ pdef('PowC', 'power with constant exponent').add_fields('float32', 'exp', 0) 'layout is (K/4, M/4, 4(k), 4(m)) x (K/4, N, 4(k))'), Doc('MK8', 'Split 8 from M and K, better for neon compute:' '(M/8, K/8, 8(k), 8(m)) x (K/8, N, 8(k)). if transposeA the ' - 'layout is (K/8, M/8, 8(k), 8(m)) x (K/8, N, 8(k))')) + 'layout is (K/8, M/8, 8(k), 8(m)) x (K/8, N, 8(k))'), + Doc('MK4_DOT', 'Split 4 from M and K, better for neon dotprod:' + '(M/4, K/4, 4(m), 4(k)) x (K/4, N, 4(k)). if transposeA the ' + 'layout is (K/4, M/4, 4(m), 4(k)) x (K/4, N, 4(k))')) ) (pdef('Winograd', 'winograd param used in convbias'). 
diff --git a/dnn/src/common/matrix_mul.cpp b/dnn/src/common/matrix_mul.cpp index b12871c4..484d3371 100644 --- a/dnn/src/common/matrix_mul.cpp +++ b/dnn/src/common/matrix_mul.cpp @@ -186,6 +186,8 @@ size_t MatrixMulForward::pack_size(const Param::Format format) { return 1; case Param::Format::MK4: return 4; + case Param::Format::MK4_DOT: + return 4; case Param::Format::MK8: return 8; default: diff --git a/dnn/src/naive/matrix_mul/matrix_mul_helper.h b/dnn/src/naive/matrix_mul/matrix_mul_helper.h index c688a944..7ea7606b 100644 --- a/dnn/src/naive/matrix_mul/matrix_mul_helper.h +++ b/dnn/src/naive/matrix_mul/matrix_mul_helper.h @@ -82,6 +82,35 @@ void run_matrix_mul_mk4_tpl(const itype* A, const itype* B, otype* C, size_t M, } } +template +void run_matrix_mul_mk4_dot_tpl(const itype* A, const itype* B, otype* C, + size_t M, size_t N, size_t K, size_t LDA, + size_t LDB, size_t LDC, const DType& A_type, + const DType& B_type) { + Getter getterA(A_type), getterB(B_type); + for (size_t m = 0; m < M; ++m) { + for (size_t n = 0; n < N; ++n) { + comp_type res[4] = {comp_type(0)}; + for (size_t k = 0; k < K; ++k) { + for (size_t i = 0; i < 4; i++) { + comp_type av, bv; + for (size_t j = 0; j < 4; j++) { + av = transA ? getterA(A[k * LDA + m * 16 + 4 * i + j]) + : getterA(A[m * LDA + k * 16 + 4 * i + j]), + bv = transB ? 
getterB(B[n * LDB + k * 4 + j]) + : getterB(B[k * LDB + n * 4 + j]); + res[i] += av * bv; + } + } + } + for (size_t i = 0; i < 4; i++) { + C[m * LDC + n * 4 + i] = res[i]; + } + } + } +} + template void run_matrix_mul_mk8_tpl(const itype* A, const itype* B, otype* C, size_t M, diff --git a/dnn/src/naive/matrix_mul/opr_impl.cpp b/dnn/src/naive/matrix_mul/opr_impl.cpp index 70ba4c49..5a3be598 100644 --- a/dnn/src/naive/matrix_mul/opr_impl.cpp +++ b/dnn/src/naive/matrix_mul/opr_impl.cpp @@ -38,22 +38,27 @@ void dispatch_ta_tb(_megdnn_tensor_in A, _megdnn_tensor_in B, auto LDA = A.layout.stride[0], LDB = B.layout.stride[0], LDC = C.layout.stride[0]; -#define cb(_itype, _otype, _comp_type) \ - if (param.format == param::MatrixMul::Format::DEFAULT) { \ - return run_matrix_mul_tpl<_itype, _otype, TA, TB, _comp_type>( \ - A.compatible_ptr<_itype>(), B.compatible_ptr<_itype>(), \ - C.compatible_ptr<_otype>(), M, N, K, LDA, LDB, LDC, \ - A.layout.dtype, B.layout.dtype); \ - } else if (param.format == param::MatrixMul::Format::MK4) { \ - return run_matrix_mul_mk4_tpl<_itype, _otype, TA, TB, _comp_type>( \ - A.compatible_ptr<_itype>(), B.compatible_ptr<_itype>(), \ - C.compatible_ptr<_otype>(), M, N, K, LDA, LDB, LDC, \ - A.layout.dtype, B.layout.dtype); \ - } else if (param.format == param::MatrixMul::Format::MK8) { \ - return run_matrix_mul_mk8_tpl<_itype, _otype, TA, TB, _comp_type>( \ - A.compatible_ptr<_itype>(), B.compatible_ptr<_itype>(), \ - C.compatible_ptr<_otype>(), M, N, K, LDA, LDB, LDC, \ - A.layout.dtype, B.layout.dtype); \ +#define cb(_itype, _otype, _comp_type) \ + if (param.format == param::MatrixMul::Format::DEFAULT) { \ + return run_matrix_mul_tpl<_itype, _otype, TA, TB, _comp_type>( \ + A.compatible_ptr<_itype>(), B.compatible_ptr<_itype>(), \ + C.compatible_ptr<_otype>(), M, N, K, LDA, LDB, LDC, \ + A.layout.dtype, B.layout.dtype); \ + } else if (param.format == param::MatrixMul::Format::MK4) { \ + return run_matrix_mul_mk4_tpl<_itype, _otype, TA, TB, 
_comp_type>( \ + A.compatible_ptr<_itype>(), B.compatible_ptr<_itype>(), \ + C.compatible_ptr<_otype>(), M, N, K, LDA, LDB, LDC, \ + A.layout.dtype, B.layout.dtype); \ + } else if (param.format == param::MatrixMul::Format::MK4_DOT) { \ + return run_matrix_mul_mk4_dot_tpl<_itype, _otype, TA, TB, _comp_type>( \ + A.compatible_ptr<_itype>(), B.compatible_ptr<_itype>(), \ + C.compatible_ptr<_otype>(), M, N, K, LDA, LDB, LDC, \ + A.layout.dtype, B.layout.dtype); \ + } else if (param.format == param::MatrixMul::Format::MK8) { \ + return run_matrix_mul_mk8_tpl<_itype, _otype, TA, TB, _comp_type>( \ + A.compatible_ptr<_itype>(), B.compatible_ptr<_itype>(), \ + C.compatible_ptr<_otype>(), M, N, K, LDA, LDB, LDC, \ + A.layout.dtype, B.layout.dtype); \ } if (A.layout.dtype == dtype::Float32()) { -- GitLab