提交 9320bf92 编写于 作者: M Megvii Engine Team

feat(mgb/dnn): add matmul mk4 dot naive test

GitOrigin-RevId: 2f16d4f89b900101977270eb6446541e5d558a32
上级 a6bc250d
......@@ -30,6 +30,7 @@ void run_matmul_mk_format(Handle* handle, param::MatrixMul::Format format,
auto extra_impl = [](const TensorNDArray& tensors, param::MatrixMul param,
Handle* handle, size_t pack_size) {
megdnn_assert((param.format == param::MatrixMul::Format::MK4 ||
param.format == param::MatrixMul::Format::MK4_DOT ||
param.format == param::MatrixMul::Format::MK8) &&
tensors.size() == 3);
param::MatrixMul new_param = param;
......@@ -41,18 +42,34 @@ void run_matmul_mk_format(Handle* handle, param::MatrixMul::Format format,
TensorLayoutArray default_layouts, mk4_layouts;
if (param.transposeA) {
default_layouts.emplace_back(tensors[0].layout.reshape({K, M}));
mk4_layouts.emplace_back(
default_layouts.back()
.reshape({K / pack_size, M / pack_size, pack_size,
pack_size})
.dimshuffle({0, 2, 1, 3}));
if (param.format == param::MatrixMul::Format::MK4_DOT) {
mk4_layouts.emplace_back(
default_layouts.back()
.reshape({K / pack_size, M / pack_size,
pack_size, pack_size})
.dimshuffle({0, 3, 1, 2}));
} else {
mk4_layouts.emplace_back(
default_layouts.back()
.reshape({K / pack_size, M / pack_size,
pack_size, pack_size})
.dimshuffle({0, 2, 1, 3}));
}
} else {
default_layouts.emplace_back(tensors[0].layout.reshape({M, K}));
mk4_layouts.emplace_back(
default_layouts.back()
.reshape({M / pack_size, K / pack_size, pack_size,
pack_size})
.dimshuffle({0, 3, 1, 2}));
if (param.format == param::MatrixMul::Format::MK4_DOT) {
mk4_layouts.emplace_back(
default_layouts.back()
.reshape({M / pack_size, K / pack_size,
pack_size, pack_size})
.dimshuffle({0, 2, 1, 3}));
} else {
mk4_layouts.emplace_back(
default_layouts.back()
.reshape({M / pack_size, K / pack_size,
pack_size, pack_size})
.dimshuffle({0, 3, 1, 2}));
}
}
if (param.transposeB) {
default_layouts.emplace_back(tensors[1].layout.reshape({N, K}));
......@@ -238,6 +255,11 @@ TEST_F(NAIVE, MATRIX_MUL_MK8) {
dtype::Int16(), dtype::Int16(), dtype::Int32());
}
TEST_F(NAIVE, MATRIX_MUL_MK4_DOT) {
run_matmul_mk_format(handle(), param::MatrixMul::Format::MK4_DOT,
dtype::Int8(), dtype::Int8(), dtype::Int32());
}
TEST_F(NAIVE, MATRIX_MUL_BFLOAT16) {
Checker<MatrixMul> checker(handle(), /* check_dispatch */ false);
MatrixMul::Param param, fp32_param;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册