#include "test/common/matrix_mul.h"
#include "test/common/checker.h"
#include "test/common/rng.h"
#include "test/common/task_record_check.h"
#include "test/fallback/fixture.h"

namespace megdnn {
namespace test {

namespace {

/*!
 * \brief Run every case from matrix_mul::get_matmul_args() through \p checker
 *      as a 2-D Float32 matrix multiplication.
 *
 * Bit 0 of each case's mask selects transposeA and bit 1 selects transposeB.
 * A is laid out as (k, m) when transposed and (m, k) otherwise; B as (n, k)
 * when transposed and (k, n) otherwise; C is always (m, n).
 *
 * \tparam CheckerT a checker providing set_param(MatrixMul::Param) and
 *      execl({A, B, C} layouts); used with both Checker<MatrixMul> and
 *      TaskRecordChecker<MatrixMul> below.
 */
template <typename CheckerT>
void check_float32_matmul_args(CheckerT& checker) {
    using Param = MatrixMul::Param;
    for (auto&& arg : matrix_mul::get_matmul_args()) {
        const auto m = arg.m, n = arg.n, k = arg.k;
        Param param;
        param.transposeA = arg.mask & 1;
        param.transposeB = arg.mask & 2;
        const TensorShape a_shape =
                param.transposeA ? TensorShape{k, m} : TensorShape{m, k};
        const TensorShape b_shape =
                param.transposeB ? TensorShape{n, k} : TensorShape{k, n};
        const TensorShape c_shape{m, n};
        checker.set_param(param);
        checker.execl(
                {TensorLayout(a_shape, dtype::Float32()),
                 TensorLayout(b_shape, dtype::Float32()),
                 TensorLayout(c_shape, dtype::Float32())});
    }
}

}  // anonymous namespace

//! All 2-D Float32 cases, with no algorithm pinned on the checker.
TEST_F(FALLBACK, MATRIX_MUL) {
    Checker<MatrixMul> checker(handle());
    check_float32_matmul_args(checker);
}

TEST_F(FALLBACK, MATRIX_MUL_MK4_GI) {
    matrix_mul::check_matrix_mul(
            dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(),
            "FB_GI_F32_MK4_4x8", param::MatrixMul::Format::MK4, 1);
}

TEST_F(FALLBACK, MATRIX_MUL_GI_F32_4x12) {
    matrix_mul::check_matrix_mul(
            dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(),
            "FB_GI_F32_4x12");
}

TEST_F(FALLBACK, MATRIX_MUL_GI_PACK_MK4) {
    matrix_mul::check_matrix_mul(
            dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(),
            "FB_GI_F32_MK4_PACK_4x12", param::MatrixMul::Format::MK4, 1);
}

//! Same 2-D Float32 cases as MATRIX_MUL, driven through TaskRecordChecker.
TEST_F(FALLBACK, MATRIX_MUL_RECORD) {
    TaskRecordChecker<MatrixMul> checker(1);
    check_float32_matmul_args(checker);
}

TEST_F(FALLBACK, MATRIX_MUL_NAIVE_MK4) {
    matrix_mul::check_matrix_mul(
            dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(),
            "FB_NAIVE", param::MatrixMul::Format::MK4, 1);
}

TEST_F(FALLBACK, MATRIX_MUL_NAIVE_MK8) {
    matrix_mul::check_matrix_mul(
            dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(),
            "FB_NAIVE", param::MatrixMul::Format::MK8, 1);
}

TEST_F(FALLBACK, MATRIX_MUL_NAIVE_MK4_DOT) {
    matrix_mul::check_matrix_mul(
            dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(),
            "FB_NAIVE", param::MatrixMul::Format::MK4_DOT, 1);
}

//! Same 2-D Float32 cases as MATRIX_MUL, with the "FB_NAIVE" algorithm pinned.
TEST_F(FALLBACK, MATRIX_MUL_NAIVE) {
    Checker<MatrixMul> checker(handle());
    checker.set_before_exec_callback(AlgoChecker<MatrixMul>("FB_NAIVE"));
    check_float32_matmul_args(checker);
}

//! Batched (3-D) cases from get_batched_matmul_args(); the output shape is
//! passed empty, presumably so the checker deduces it.
TEST_F(FALLBACK, BATCHED_MATRIX_MUL) {
    Checker<BatchedMatrixMul> checker(handle());
    using Param = MatrixMul::Param;
    for (auto&& arg : matrix_mul::get_batched_matmul_args()) {
        const auto b = arg.b, m = arg.m, n = arg.n, k = arg.k;
        Param param;
        param.transposeA = arg.mask & 1;
        param.transposeB = arg.mask & 2;
        const TensorShape a_shape =
                param.transposeA ? TensorShape{b, k, m} : TensorShape{b, m, k};
        const TensorShape b_shape =
                param.transposeB ? TensorShape{b, n, k} : TensorShape{b, k, n};
        checker.set_param(param);
        checker.execs({a_shape, b_shape, {}});
    }
}

#if MEGDNN_WITH_BENCHMARK
TEST_F(FALLBACK, BENCHMARK_MATRIX_MUL_FB_GI_F32_4x12) {
    auto args = matrix_mul::get_benchmark_matmul_args();
    matrix_mul::benchmark_single_algo(
            handle(), args, dtype::Float32{}, dtype::Float32{}, dtype::Float32{},
            "FB_GI_F32_4x12", param::MatrixMul::Format::DEFAULT);
}

TEST_F(FALLBACK, BENCHMARK_MATRIX_MUL_GI_PACK_MK4) {
    auto args = matrix_mul::get_benchmark_matmul_args();
    matrix_mul::benchmark_single_algo(
            handle(), args, dtype::Float32{}, dtype::Float32{}, dtype::Float32{},
            "FB_GI_F32_MK4_PACK_4x12", param::MatrixMul::Format::MK4);
}
#endif

}  // namespace test
}  // namespace megdnn

// vim: syntax=cpp.doxygen