提交 3bda3347 编写于 作者: M Megvii Engine Team

fix(dnn/fallback): fix segmentfault caused by im2col/conv1x1 using

 fallback naive matmul.

GitOrigin-RevId: 03ef904b113ffcc63e6e298827321fe085075b4a
上级 2eed7d83
......@@ -49,15 +49,33 @@ void kern_naive(const MatrixMulImpl::KernParam& kern_param) {
MIDOUT_BEGIN(megdnn_fb_matmul_naive, void) {
size_t M = kern_param.M, N = kern_param.N, K = kern_param.K;
size_t LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
#define DISPATCH(TA, TB) \
if (kern_param.trA == TA && kern_param.trB == TB) { \
naive::dispatch_ta_tb<TA, TB>( \
kern_param.A_ptr, kern_param.B_ptr, kern_param.C_ptr, \
kern_param.workspace_ptr, M, N, K, LDA, LDB, LDC, \
kern_param.A_type, kern_param.B_type, kern_param.C_type, \
kern_param.format, kern_param.compute_mode); \
return; \
auto get_pack_size = [kern_param]() -> size_t {
switch (kern_param.format) {
case param::MatrixMul::Format::MK4:
case param::MatrixMul::Format::MK4_DOT:
return 4_z;
case param::MatrixMul::Format::MK8:
return 8_z;
default:
return 1_z;
}
};
size_t pack_size = get_pack_size();
megdnn_assert(
(M % pack_size == 0 && K % pack_size == 0),
"M and N must time of pack_size M: %zu N: %zu pack_size: %zu",
M, N, pack_size);
#define DISPATCH(TA, TB) \
if (kern_param.trA == TA && kern_param.trB == TB) { \
naive::dispatch_ta_tb<TA, TB>( \
kern_param.A_ptr, kern_param.B_ptr, kern_param.C_ptr, \
kern_param.workspace_ptr, M / pack_size, N, K / pack_size, \
LDA, LDB, LDC, kern_param.A_type, kern_param.B_type, \
kern_param.C_type, kern_param.format, \
kern_param.compute_mode); \
return; \
}
DISPATCH(true, true);
DISPATCH(true, false);
......
......@@ -46,6 +46,24 @@ TEST_F(FALLBACK, MATRIX_MUL) {
}
}
TEST_F(FALLBACK, MATRIX_MUL_NAIVE_MK4) {
matrix_mul::check_matrix_mul(dtype::Float32{}, dtype::Float32{},
dtype::Float32{}, handle(), "FB_NAIVE",
param::MatrixMul::Format::MK4, 1);
}
TEST_F(FALLBACK, MATRIX_MUL_NAIVE_MK8) {
matrix_mul::check_matrix_mul(dtype::Float32{}, dtype::Float32{},
dtype::Float32{}, handle(), "FB_NAIVE",
param::MatrixMul::Format::MK8, 1);
}
TEST_F(FALLBACK, MATRIX_MUL_NAIVE_MK4_DOT) {
matrix_mul::check_matrix_mul(dtype::Float32{}, dtype::Float32{},
dtype::Float32{}, handle(), "FB_NAIVE",
param::MatrixMul::Format::MK4_DOT, 1);
}
TEST_F(FALLBACK, MATRIX_MUL_NAIVE) {
Checker<MatrixMul> checker(handle());
checker.set_before_exec_callback(AlgoChecker<MatrixMul>("FB_NAIVE"));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册