Commit ca2828dd authored by Megvii Engine Team, committed by huangxinda

fix(dnn/x86): fix x86 int8 matmul ldc bug

GitOrigin-RevId: 2502f99000d5e90fdc410b1d0bf731668cd1077c
Parent: aa4e8476
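Every hunk below makes the same correction: the output tile pointer in the x86 int8 GEMM kernels was computed from the logical column count n, while the caller may pass a C matrix whose leading dimension ldc (row stride in elements) is larger than n. With such a strided C, an offset of m_offset * n lands in the wrong row. A minimal sketch of the indexing rule, assuming a plain row-major int32_t buffer (illustration only, not the MegDNN kernel code):

#include <cassert>
#include <cstddef>
#include <cstdint>

// Illustration only: address of the output tile that starts at row m_offset /
// column n_offset of a row-major C whose rows are ldc elements apart.
int32_t* tile_ptr(int32_t* c_ptr, size_t m_offset, size_t n_offset, size_t ldc) {
    return c_ptr + m_offset * ldc + n_offset;  // advance whole rows by ldc, not by n
}

int main() {
    constexpr size_t n = 4, ldc = 2 * n + 2;  // strided C, mirroring the updated tests
    int32_t c[8 * ldc] = {};                  // 8 rows, each ldc elements wide
    // With ldc != n the old formula "m_offset * n" no longer points at the row start:
    assert(tile_ptr(c, 2, 0, ldc) == c + 2 * ldc);
    assert(c + 2 * n != c + 2 * ldc);
    return 0;
}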
@@ -71,13 +71,13 @@ void gemm_avx2_s8s8s32_2x4x16::kern(const dt_int8* pack_a_ptr,
         auto iter_a_ptr = pack_a_ptr + m_offset * roundup_k;
         for (size_t n_offset = 0; n_offset < n_end; n_offset += n_tile) {
             auto iter_b_ptr = pack_b_ptr + n_offset * roundup_k;
-            auto iter_c_ptr = c_ptr + m_offset * n + n_offset;
+            auto iter_c_ptr = c_ptr + m_offset * ldc + n_offset;
             matmul_avx2_2x4x16::kern_gemm_s8s8s32_2x4x16(iter_a_ptr, iter_b_ptr,
                                                          iter_c_ptr, ldc, k);
         }
         if (n_end < n) {
             auto iter_b_ptr = pack_b_ptr + n_end * roundup_k;
-            auto iter_c_ptr = c_ptr + m_offset * n + n_end;
+            auto iter_c_ptr = c_ptr + m_offset * ldc + n_end;
             matmul_avx2_2x4x16::kern_gemm_s8s8s32_2x4x16_remain(
                     iter_a_ptr, iter_b_ptr, iter_c_ptr, ldc, k, m_tile,
                     n_remain);
@@ -87,14 +87,14 @@ void gemm_avx2_s8s8s32_2x4x16::kern(const dt_int8* pack_a_ptr,
         auto iter_a_ptr = pack_a_ptr + m_end * roundup_k;
         for (size_t n_offset = 0; n_offset < n_end; n_offset += n_tile) {
             auto iter_b_ptr = pack_b_ptr + n_offset * roundup_k;
-            auto iter_c_ptr = c_ptr + m_end * n + n_offset;
+            auto iter_c_ptr = c_ptr + m_end * ldc + n_offset;
             matmul_avx2_2x4x16::kern_gemm_s8s8s32_2x4x16_remain(
                     iter_a_ptr, iter_b_ptr, iter_c_ptr, ldc, k, m_remain,
                     n_tile);
         }
         if (n_end < n) {
             auto iter_b_ptr = pack_b_ptr + n_end * roundup_k;
-            auto iter_c_ptr = c_ptr + m_end * n + n_end;
+            auto iter_c_ptr = c_ptr + m_end * ldc + n_end;
             matmul_avx2_2x4x16::kern_gemm_s8s8s32_2x4x16_remain(
                     iter_a_ptr, iter_b_ptr, iter_c_ptr, ldc, k, m_remain,
                     n_remain);
...
@@ -59,13 +59,13 @@ static inline void gemm_kern(const dt_int16* pack_a_ptr,
         auto iter_a_ptr = pack_a_ptr + m_offset * roundup_k;
         for (size_t n_offset = 0; n_offset < n_end; n_offset += n_tile) {
             auto iter_b_ptr = pack_b_ptr + n_offset * roundup_k;
-            auto iter_c_ptr = c_ptr + m_offset * n + n_offset;
+            auto iter_c_ptr = c_ptr + m_offset * ldc + n_offset;
             matmul_avx2_4x16x2::kern_gemm_s8s8s32_avx2_4x16x2(
                     iter_a_ptr, iter_b_ptr, iter_c_ptr, ldc, k);
         }
         if (n_remain > 0) {
             auto iter_b_ptr = pack_b_ptr + n_end * roundup_k;
-            auto iter_c_ptr = c_ptr + m_offset * n + n_end;
+            auto iter_c_ptr = c_ptr + m_offset * ldc + n_end;
             if (n_remain <= 8) {
                 matmul_avx2_4x16x2::kern_gemm_s8s8s32_avx2_4x16x2_n8_remain_n(
                         iter_a_ptr, iter_b_ptr, iter_c_ptr, ldc, k, n_remain);
@@ -79,13 +79,13 @@ static inline void gemm_kern(const dt_int16* pack_a_ptr,
         auto iter_a_ptr = pack_a_ptr + m_end * roundup_k;
         for (size_t n_offset = 0; n_offset < n_end; n_offset += n_tile) {
             auto iter_b_ptr = pack_b_ptr + n_offset * roundup_k;
-            auto iter_c_ptr = c_ptr + m_end * n + n_offset;
+            auto iter_c_ptr = c_ptr + m_end * ldc + n_offset;
             matmul_avx2_4x16x2::kern_gemm_s8s8s32_avx2_4x16x2_remain_m(
                     iter_a_ptr, iter_b_ptr, iter_c_ptr, ldc, k, m_remain);
         }
         if (n_remain > 0) {
             auto iter_b_ptr = pack_b_ptr + n_end * roundup_k;
-            auto iter_c_ptr = c_ptr + m_end * n + n_end;
+            auto iter_c_ptr = c_ptr + m_end * ldc + n_end;
             if (n_remain <= 8) {
                 matmul_avx2_4x16x2::kern_gemm_s8s8s32_avx2_4x16x2_n8_remain_m_n(
                         iter_a_ptr, iter_b_ptr, iter_c_ptr, ldc, k, m_remain,
...
@@ -59,13 +59,13 @@ static inline void gemm_kern(const dt_int16* pack_a_ptr,
         auto iter_a_ptr = pack_a_ptr + m_offset * roundup_k;
         for (int n_offset = 0; n_offset < n_end; n_offset += n_tile) {
             auto iter_b_ptr = pack_b_ptr + n_offset * roundup_k;
-            auto iter_c_ptr = c_ptr + m_offset * n + n_offset;
+            auto iter_c_ptr = c_ptr + m_offset * ldc + n_offset;
             matmul_sse_4x8x2::kern_gemm_s8s8s32_sse_4x8x2(
                     iter_a_ptr, iter_b_ptr, iter_c_ptr, ldc, k);
         }
         if (n_remain > 0) {
             auto iter_b_ptr = pack_b_ptr + n_end * roundup_k;
-            auto iter_c_ptr = c_ptr + m_offset * n + n_end;
+            auto iter_c_ptr = c_ptr + m_offset * ldc + n_end;
             matmul_sse_4x8x2::kern_gemm_s8s8s32_sse_4x8x2_remain_n(
                     iter_a_ptr, iter_b_ptr, iter_c_ptr, ldc, k, n_remain);
         }
@@ -74,13 +74,13 @@ static inline void gemm_kern(const dt_int16* pack_a_ptr,
         auto iter_a_ptr = pack_a_ptr + m_end * roundup_k;
         for (int n_offset = 0; n_offset < n_end; n_offset += n_tile) {
             auto iter_b_ptr = pack_b_ptr + n_offset * roundup_k;
-            auto iter_c_ptr = c_ptr + m_end * n + n_offset;
+            auto iter_c_ptr = c_ptr + m_end * ldc + n_offset;
             matmul_sse_4x8x2::kern_gemm_s8s8s32_sse_4x8x2_remain_m(
                     iter_a_ptr, iter_b_ptr, iter_c_ptr, ldc, k, m_remain);
         }
         if (n_remain > 0) {
             auto iter_b_ptr = pack_b_ptr + n_end * roundup_k;
-            auto iter_c_ptr = c_ptr + m_end * n + n_end;
+            auto iter_c_ptr = c_ptr + m_end * ldc + n_end;
             matmul_sse_4x8x2::kern_gemm_s8s8s32_sse_4x8x2_remain_m_n(
                     iter_a_ptr, iter_b_ptr, iter_c_ptr, ldc, k, m_remain,
                     n_remain);
...
@@ -78,6 +78,7 @@ protected:
     TensorsConstriant m_tensor_constraint;
     bool m_no_naive_and_check = false;
     bool m_stable_check = false;
+    bool m_force_deduce_dst = true;
     /**
      * the offset from the start of malloc memory
      *
@@ -236,6 +237,12 @@ public:
         return *this;
     }
+    //! force deduce dst
+    Checker& set_force_deduce_dst(bool force_deduce_dst) {
+        m_force_deduce_dst = force_deduce_dst;
+        return *this;
+    }
     Checker& set_no_naive_check(bool no_naive_and_check) {
         m_no_naive_and_check = no_naive_and_check;
         return *this;
@@ -343,7 +350,10 @@ void Checker<Opr, Proxy>::exec(TensorLayoutArray layouts) {
     auto opr_cur = this->opr();
     opr_naive->param() = m_param;
    opr_cur->param() = m_param;
-    m_naive_proxy.deduce_layout(opr_naive.get(), layouts);
+    bool deduce_layout = layouts.back().ndim == 0;
+    if (deduce_layout || m_force_deduce_dst) {
+        m_naive_proxy.deduce_layout(opr_naive.get(), layouts);
+    }
     auto exec_naive = [this, &opr_naive, &layouts,
                        &opr_relayout](const TensorValueArray& values) {
         TensorValueArray contig_values = values;
...
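The checker change above is needed because Checker::exec used to call deduce_layout unconditionally, which overwrites a test-supplied dst layout with the contiguous one; a C layout with ldc > n therefore never reached the algorithm and the kernel bug stayed invisible. With m_force_deduce_dst (default true, so existing tests behave as before) a test can keep its explicit dst layout, while deduction still runs whenever the dst layout is left empty (ndim == 0). A toy model of that decision, with all type and function names invented for illustration:

#include <cstddef>
#include <vector>

// Toy stand-in for the real TensorLayout/Checker types, just to show the rule.
struct ToyLayout {
    size_t ndim;                     // 0 means "dst layout not specified by the test"
    std::vector<ptrdiff_t> stride;   // stride[0] plays the role of ldc for a matmul dst
};

// Deduce (i.e. overwrite with a contiguous dst layout) only if the dst is
// unspecified or the caller still forces deduction.
void maybe_deduce_dst(std::vector<ToyLayout>& layouts, bool force_deduce_dst, size_t n) {
    bool dst_unspecified = layouts.back().ndim == 0;
    if (dst_unspecified || force_deduce_dst) {
        layouts.back() = ToyLayout{2, {static_cast<ptrdiff_t>(n), 1}};  // contiguous C
    }
    // otherwise the strided dst layout supplied by the test survives and reaches
    // the kernel, exercising the ldc path
}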
@@ -101,7 +101,7 @@ std::vector<matrix_mul::TestArg> matrix_mul::get_matmul_args_mask(
         size_t Astride = mask & 1 ? m + 2 : k + 2;
         // B: (k, n)
         size_t Bstride = mask & 2 ? k + 2 : n + 2;
-        size_t Cstride = n + 2;
+        size_t Cstride = n * 2 + 2;
         args.emplace_back(m, n, k, mask, Astride, Bstride, Cstride);
     }
     return args;
@@ -183,9 +183,11 @@ void matrix_mul::check_matrix_mul(DType A_dtype, DType B_dtype, DType C_dtype,
                                   Handle* handle,
                                   const ExecutionPolicyAlgoName& algo,
                                   param::MatrixMul::Format format, size_t nbase,
-                                  float eps, std::vector<TestArg>&& user_args) {
+                                  float eps, std::vector<TestArg>&& user_args,
+                                  bool force_deduce_dst) {
     megdnn_assert(A_dtype.enumv() == B_dtype.enumv());
     Checker<Opr> checker(handle);
+    checker.set_force_deduce_dst(force_deduce_dst);
     if (!algo.name.empty()) {
         checker.set_before_exec_callback(AlgoChecker<Opr>(algo));
     }
@@ -245,16 +247,16 @@ void matrix_mul::check_matrix_mul(DType A_dtype, DType B_dtype, DType C_dtype,
     for (auto& arg : args) {
         size_t m = arg.m, n = arg.n, k = arg.k;
-#if MEGDNN_WITH_CUDA
-        //[NOTE]: cublas can only process 4B aligned 8-bit input matrix;
-        bool is_dt_8bit = A_dtype.enumv() == DTypeEnum::Int8 ||
-                          A_dtype.enumv() == DTypeEnum::QuantizedS8 ||
-                          A_dtype.enumv() == DTypeEnum::Uint8 ||
-                          A_dtype.enumv() == DTypeEnum::Quantized8Asymm;
-        if (is_dt_8bit && ((m % 4 != 0) || (n % 4 != 0))) {
-            continue;
-        }
-#endif
+        if (handle->type() == Handle::HandleType::CUDA) {
+            //! NOTE: cublas can only process 4B aligned 8-bit input matrix;
+            bool is_dt_8bit = A_dtype.enumv() == DTypeEnum::Int8 ||
+                              A_dtype.enumv() == DTypeEnum::QuantizedS8 ||
+                              A_dtype.enumv() == DTypeEnum::Uint8 ||
+                              A_dtype.enumv() == DTypeEnum::Quantized8Asymm;
+            if (is_dt_8bit && ((m % 4 != 0) || (n % 4 != 0))) {
+                continue;
+            }
+        }
 
         Param param;
         param.transposeA = arg.mask & 0x1;
@@ -312,20 +314,22 @@ void matrix_mul::check_batched_matrix_mul(DType A_dtype, DType B_dtype,
                                           DType C_dtype, Handle* handle,
                                           const ExecutionPolicyAlgoName& algo,
                                           float eps,
-                                          std::vector<TestArg>&& args) {
+                                          std::vector<TestArg>&& args,
+                                          bool force_deduce_dst) {
     check_matrix_mul<megdnn::BatchedMatrixMul>(
             A_dtype, B_dtype, C_dtype, handle, algo,
             param::MatrixMul::Format::DEFAULT, 8, eps,
-            std::forward<decltype(args)>(args));
+            std::forward<decltype(args)>(args), force_deduce_dst);
 }
 
 void matrix_mul::check_matrix_mul(DType A_dtype, DType B_dtype, DType C_dtype,
                                   Handle* handle,
                                   const ExecutionPolicyAlgoName& algo,
                                   param::MatrixMul::Format format, size_t nbase,
-                                  float eps) {
+                                  float eps, bool force_deduce_dst) {
     check_matrix_mul<megdnn::MatrixMul>(A_dtype, B_dtype, C_dtype, handle, algo,
-                                        format, nbase, eps);
+                                        format, nbase, eps, {},
+                                        force_deduce_dst);
 }
 
 #if MEGDNN_WITH_BENCHMARK
...
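Changing Cstride from n + 2 to n * 2 + 2 makes every row of the test's C matrix clearly non-contiguous, so a kernel that advances by n instead of ldc writes into the middle of neighbouring rows and the comparison fails immediately. A small sketch of how a strided result can be checked against a densely packed reference, assuming row-major data (hypothetical helper, not part of the test suite):

#include <cstddef>
#include <cstdint>

// Hypothetical helper: compare an ldc-strided C against a densely packed
// reference of the same logical m x n shape.
bool strided_equals_dense(const int32_t* c, size_t ldc,
                          const int32_t* ref, size_t m, size_t n) {
    for (size_t i = 0; i < m; ++i)
        for (size_t j = 0; j < n; ++j)
            if (c[i * ldc + j] != ref[i * n + j])  // only the first n of each ldc-row count
                return false;
    return true;
}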
@@ -68,19 +68,21 @@ void check_matrix_mul(
         DType A_dtype, DType B_dtype, DType C_dtype, Handle* handle,
         const ExecutionPolicyAlgoName& algo = {"", {}},
         param::MatrixMul::Format format = param::MatrixMul::Format::DEFAULT,
-        size_t nbase = 8, float eps = 1e-3, std::vector<TestArg>&& args = {});
+        size_t nbase = 8, float eps = 1e-3, std::vector<TestArg>&& args = {},
+        bool force_deduce_dst = true);
 
 void check_matrix_mul(
         DType A_dtype, DType B_dtype, DType C_dtype, Handle* handle,
         const ExecutionPolicyAlgoName& algo = {"", {}},
         param::MatrixMul::Format format = param::MatrixMul::Format::DEFAULT,
-        size_t nbase = 8, float eps = 1e-3);
+        size_t nbase = 8, float eps = 1e-3, bool force_deduce_dst = true);
 
 void check_batched_matrix_mul(DType A_dtype, DType B_dtype, DType C_dtype,
                               Handle* handle,
                               const ExecutionPolicyAlgoName& algo = {"", {}},
                               float eps = 1e-3,
-                              std::vector<TestArg>&& args = {});
+                              std::vector<TestArg>&& args = {},
+                              bool force_deduce_dst = true);
 
 #if MEGDNN_WITH_BENCHMARK
 std::vector<TestArg> get_benchmark_matmul_args();
...
@@ -44,21 +44,31 @@ TEST_F(X86, MATRIX_MUL_MKLDNN_8X8X32) {
 //! FIXME: need to add tests of GEMV and QUINT8
 TEST_F(X86, MATRIX_MUL_AVX2_8X8X32) {
     matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int32{},
-                                 handle(), "X86_INT8X8X32_AVX2_2X4X16");
+                                 handle(), "X86_INT8X8X32_AVX2_2X4X16",
+                                 param::MatrixMul::Format::DEFAULT, 8, 1e-3,
+                                 false);
     matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int32{},
-                                 handle(), "X86_INT8X8X32_AVX2_4X16X2");
+                                 handle(), "X86_INT8X8X32_AVX2_4X16X2",
+                                 param::MatrixMul::Format::DEFAULT, 8, 1e-3,
+                                 false);
 }
 
 TEST_F(X86, MATRIX_MUL_AVX2_8X8X16) {
     matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int16{},
-                                 handle(), "X86_INT8X8X16_AVX2");
+                                 handle(), "X86_INT8X8X16_AVX2",
+                                 param::MatrixMul::Format::DEFAULT, 8, 1e-3,
+                                 false);
 }
 
 TEST_F(X86, MATRIX_MUL_SSE_8X8X16) {
     matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int16{},
-                                 handle(), "X86_INT8X8X16_SSE");
+                                 handle(), "X86_INT8X8X16_SSE",
+                                 param::MatrixMul::Format::DEFAULT, 8, 1e-3,
+                                 false);
 }
 
 TEST_F(X86, MATRIX_MUL_SSE_8X8X32) {
     matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int32{},
-                                 handle(), "X86_INT8X8X32_SSE_4X8X2");
+                                 handle(), "X86_INT8X8X32_SSE_4X8X2",
+                                 param::MatrixMul::Format::DEFAULT, 8, 1e-3,
+                                 false);
 }
 
 #if MEGDNN_X86_WITH_MKL && SUPPORT_MKL_PACKED_GEMM
@@ -72,7 +82,7 @@ TEST_F(X86, MATRIX_MUL_MKL_PACKA) {
 TEST_F(X86, MATRIX_MUL_AVX2_MK8_8X8) {
     matrix_mul::check_matrix_mul(dtype::Float32{}, dtype::Float32{},
                                  dtype::Float32{}, handle(), "X86_F32MK8_8X8",
-                                 param::MatrixMul::Format::MK8, 1);
+                                 param::MatrixMul::Format::MK8, 1, 1e-3, false);
 }
 
 #if MEGDNN_WITH_BENCHMARK
...