提交 a6bc250d 编写于 作者: M Megvii Engine Team

feat(dnn/common): add matmul impl for naive with matrix format mk4_dot

GitOrigin-RevId: 7c6fbdfa973413b1b8d2f2fb0d18f8bb5ee7f243
上级 bb872965
......@@ -433,7 +433,10 @@ pdef('PowC', 'power with constant exponent').add_fields('float32', 'exp', 0)
'layout is (K/4, M/4, 4(k), 4(m)) x (K/4, N, 4(k))'),
Doc('MK8', 'Split 8 from M and K, better for neon compute:'
'(M/8, K/8, 8(k), 8(m)) x (K/8, N, 8(k)). if transposeA the '
'layout is (K/8, M/8, 8(k), 8(m)) x (K/8, N, 8(k))'))
'layout is (K/8, M/8, 8(k), 8(m)) x (K/8, N, 8(k))'),
Doc('MK4_DOT', 'Split 4 from M and K, better for neon dotprod:'
'M/4, K/4, 4(m), 4(k)) x (K/4, N, 4(k)). if transposeA the '
'layout is (K/4, M/4, 4(m), 4(k)) x (K/4, N, 4(k))'))
)
(pdef('Winograd', 'winograd param used in convbias').
......
......@@ -186,6 +186,8 @@ size_t MatrixMulForward::pack_size(const Param::Format format) {
return 1;
case Param::Format::MK4:
return 4;
case Param::Format::MK4_DOT:
return 4;
case Param::Format::MK8:
return 8;
default:
......
......@@ -82,6 +82,35 @@ void run_matrix_mul_mk4_tpl(const itype* A, const itype* B, otype* C, size_t M,
}
}
template <typename itype, typename otype, bool transA, bool transB,
typename comp_type = otype>
void run_matrix_mul_mk4_dot_tpl(const itype* A, const itype* B, otype* C,
size_t M, size_t N, size_t K, size_t LDA,
size_t LDB, size_t LDC, const DType& A_type,
const DType& B_type) {
Getter<itype, comp_type> getterA(A_type), getterB(B_type);
for (size_t m = 0; m < M; ++m) {
for (size_t n = 0; n < N; ++n) {
comp_type res[4] = {comp_type(0)};
for (size_t k = 0; k < K; ++k) {
for (size_t i = 0; i < 4; i++) {
comp_type av, bv;
for (size_t j = 0; j < 4; j++) {
av = transA ? getterA(A[k * LDA + m * 16 + 4 * i + j])
: getterA(A[m * LDA + k * 16 + 4 * i + j]),
bv = transB ? getterB(B[n * LDB + k * 4 + j])
: getterB(B[k * LDB + n * 4 + j]);
res[i] += av * bv;
}
}
}
for (size_t i = 0; i < 4; i++) {
C[m * LDC + n * 4 + i] = res[i];
}
}
}
}
template <typename itype, typename otype, bool transA, bool transB,
typename comp_type = otype>
void run_matrix_mul_mk8_tpl(const itype* A, const itype* B, otype* C, size_t M,
......
......@@ -38,22 +38,27 @@ void dispatch_ta_tb(_megdnn_tensor_in A, _megdnn_tensor_in B,
auto LDA = A.layout.stride[0], LDB = B.layout.stride[0],
LDC = C.layout.stride[0];
#define cb(_itype, _otype, _comp_type) \
if (param.format == param::MatrixMul::Format::DEFAULT) { \
return run_matrix_mul_tpl<_itype, _otype, TA, TB, _comp_type>( \
A.compatible_ptr<_itype>(), B.compatible_ptr<_itype>(), \
C.compatible_ptr<_otype>(), M, N, K, LDA, LDB, LDC, \
A.layout.dtype, B.layout.dtype); \
} else if (param.format == param::MatrixMul::Format::MK4) { \
return run_matrix_mul_mk4_tpl<_itype, _otype, TA, TB, _comp_type>( \
A.compatible_ptr<_itype>(), B.compatible_ptr<_itype>(), \
C.compatible_ptr<_otype>(), M, N, K, LDA, LDB, LDC, \
A.layout.dtype, B.layout.dtype); \
} else if (param.format == param::MatrixMul::Format::MK8) { \
return run_matrix_mul_mk8_tpl<_itype, _otype, TA, TB, _comp_type>( \
A.compatible_ptr<_itype>(), B.compatible_ptr<_itype>(), \
C.compatible_ptr<_otype>(), M, N, K, LDA, LDB, LDC, \
A.layout.dtype, B.layout.dtype); \
#define cb(_itype, _otype, _comp_type) \
if (param.format == param::MatrixMul::Format::DEFAULT) { \
return run_matrix_mul_tpl<_itype, _otype, TA, TB, _comp_type>( \
A.compatible_ptr<_itype>(), B.compatible_ptr<_itype>(), \
C.compatible_ptr<_otype>(), M, N, K, LDA, LDB, LDC, \
A.layout.dtype, B.layout.dtype); \
} else if (param.format == param::MatrixMul::Format::MK4) { \
return run_matrix_mul_mk4_tpl<_itype, _otype, TA, TB, _comp_type>( \
A.compatible_ptr<_itype>(), B.compatible_ptr<_itype>(), \
C.compatible_ptr<_otype>(), M, N, K, LDA, LDB, LDC, \
A.layout.dtype, B.layout.dtype); \
} else if (param.format == param::MatrixMul::Format::MK4_DOT) { \
return run_matrix_mul_mk4_dot_tpl<_itype, _otype, TA, TB, _comp_type>( \
A.compatible_ptr<_itype>(), B.compatible_ptr<_itype>(), \
C.compatible_ptr<_otype>(), M, N, K, LDA, LDB, LDC, \
A.layout.dtype, B.layout.dtype); \
} else if (param.format == param::MatrixMul::Format::MK8) { \
return run_matrix_mul_mk8_tpl<_itype, _otype, TA, TB, _comp_type>( \
A.compatible_ptr<_itype>(), B.compatible_ptr<_itype>(), \
C.compatible_ptr<_otype>(), M, N, K, LDA, LDB, LDC, \
A.layout.dtype, B.layout.dtype); \
}
if (A.layout.dtype == dtype::Float32()) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册