Commit 8ffed043 authored by Megvii Engine Team, committed by Xu Xinran

fix(dnn/x86): fix matrix_mul quantized performance on vnni

GitOrigin-RevId: 4af6b8be60dd654003576bbd6b817411252fd306
Parent 1d860f4d
@@ -95,6 +95,14 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj {
public:
AlgoPack() {
//! FIXME: preference to use mkldnn algo on VNNI devices
//! But now mkldnn algo preference issue with NCHW->NHWC->NCHW
#if MEGDNN_X86_WITH_MKL_DNN
//! Create the mkldnn algo
all_algos.emplace_back(&mkldnn_conv_fp32);
all_algos.emplace_back(&mkldnn_matmul_qint8);
all_algos.emplace_back(&mkldnn_qint8);
#endif
all_algos.emplace_back(&stride1_direct_large_group);
all_algos.emplace_back(&stride1_direct_small_group);
all_algos.emplace_back(&stride2_direct_large_group);
@@ -105,14 +113,6 @@ public:
all_algos.emplace_back(&avx2_stride2_chanwsie_qint8);
all_algos.emplace_back(&matmul);
//! preference to use mkldnn algo on VNNI devices
#if MEGDNN_X86_WITH_MKL_DNN
//! Create the mkldnn algo
all_algos.emplace_back(&mkldnn_conv_fp32);
all_algos.emplace_back(&mkldnn_matmul_qint8);
all_algos.emplace_back(&mkldnn_qint8);
#endif
static CpuOprDelegationStorage<> storage;
auto matmul_opr = storage.get<MatrixMul>();
auto&& matmul_algos =
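Moving the mkldnn entries to the front of AlgoPack matters because, broadly, the conv_bias dispatcher walks all_algos in registration order and takes the first algorithm whose usability/preference checks pass, so registration order acts as a priority list. Below is a minimal first-usable-wins sketch of that pattern (illustrative only, with a hypothetical Algo interface and pick_algo helper, not MegDNN's actual heuristic):

```cpp
// Illustrative sketch: why emplace_back order behaves like a priority list.
// `Algo`, `Param` and `pick_algo` are hypothetical stand-ins, not MegDNN code.
#include <vector>

struct Param {};

struct Algo {
    virtual ~Algo() = default;
    virtual bool usable(const Param&) const = 0;
};

// Walk the registered algorithms in order and return the first usable one, so
// entries registered earlier (here: the mkldnn algos on VNNI-capable devices)
// win over later entries such as the direct/chanwise AVX2 kernels.
Algo* pick_algo(const std::vector<Algo*>& all_algos, const Param& param) {
    for (Algo* algo : all_algos) {
        if (algo->usable(param))
            return algo;
    }
    return nullptr;  // no registered algorithm can handle this parameter set
}
```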
@@ -172,15 +172,18 @@ bool ConvBiasImpl::is_matmul_quantized_prefer(
chanwise_avx2_stride2_qint8_usable_preferred(param) ||
direct_avx2_stride1_int8_usable_preferred(param) ||
direct_avx2_stride2_int8_usable_preferred(param);
}
#if MEGDNN_X86_WITH_MKL_DNN
conv_direct_chanwise_mkldnn_usable =
conv_direct_chanwise_mkldnn_usable ||
mkldnn_qint8_usable_preferred(param) ||
mkldnn_matmul_qint8_usable_preferred(param);
conv_direct_chanwise_mkldnn_usable =
conv_direct_chanwise_mkldnn_usable ||
mkldnn_qint8_usable_preferred(param) ||
mkldnn_matmul_qint8_usable_preferred(param);
#endif
}
return !conv_direct_chanwise_mkldnn_usable;
return !conv_direct_chanwise_mkldnn_usable ||
(is_supported(SIMDType::VNNI) &&
!chanwise_avx2_stride1_qint8_usable_preferred(param) &&
!chanwise_avx2_stride2_qint8_usable_preferred(param));
}
// vim: syntax=cpp.doxygen
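With the new return expression, is_matmul_quantized_prefer reports the matmul algorithm as preferred on VNNI-capable CPUs whenever neither channel-wise AVX2 kernel (stride 1 or stride 2) is preferred, even if an mkldnn direct/chanwise kernel is usable. is_supported(SIMDType::VNNI) is MegDNN's runtime CPU-feature query; the sketch below only illustrates what such a check typically does on GCC/Clang (CPUID leaf 7, sub-leaf 0, ECX bit 11 reports AVX512_VNNI) and is not MegDNN's implementation. A production check would additionally confirm OS support for the AVX-512 register state via XGETBV, which is omitted here.

```cpp
// Hedged sketch: runtime AVX512-VNNI detection with GCC/Clang's <cpuid.h>.
#include <cpuid.h>
#include <cstdio>

bool cpu_has_avx512_vnni() {
    unsigned int eax = 0, ebx = 0, ecx = 0, edx = 0;
    // CPUID leaf 7, sub-leaf 0: structured extended feature flags.
    if (!__get_cpuid_count(7, 0, &eax, &ebx, &ecx, &edx))
        return false;
    // AVX512_VNNI is reported in ECX bit 11.
    return (ecx >> 11) & 1u;
}

int main() {
    std::printf("AVX512-VNNI supported: %d\n", cpu_has_avx512_vnni() ? 1 : 0);
    return 0;
}
```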
@@ -6,7 +6,8 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/x86/matrix_mul/algos.h"
@@ -17,7 +18,6 @@
#include "src/x86/matrix_mul/f32/strategy.h"
MIDOUT_DECL(megdnn_x86_matmul_kern)
MIDOUT_DECL(megdnn_x86_matmul_kern_mk8_8x8)
using namespace megdnn;
@@ -45,17 +45,16 @@ void f32_blas_kern(const MatrixMulImpl::KernParam& kern_param) {
#if MEGDNN_X86_WITH_MKL && SUPPORT_MKL_PACKED_GEMM
void f32_blas_kern_only_packA(const MatrixMulImpl::KernParam& kern_param,
const void* a_panel, const void* b_panel) {
MEGDNN_MARK_USED_VAR(b_panel);
const void* a_panel, const void* b_panel) {
MEGDNN_MARK_USED_VAR(b_panel);
auto m = kern_param.M, n = kern_param.N, k = kern_param.K;
const auto Bptr = kern_param.B<dt_float32>();
auto Cptr = kern_param.C<dt_float32>();
auto Atrd = kern_param.LDA, Btrd = kern_param.LDB, Ctrd = kern_param.LDC;
disable_denorm();
cblas_sgemm_compute(CblasRowMajor, CblasPacked, CblasNoTrans, m, n, k,
static_cast<const float*>(a_panel), Atrd,
Bptr, Btrd, 0.0f, Cptr,
Ctrd);
static_cast<const float*>(a_panel), Atrd, Bptr, Btrd,
0.0f, Cptr, Ctrd);
}
#endif
@@ -111,8 +110,9 @@ WorkspaceBundle MatrixMulImpl::AlgoF32MKLPackA::get_bundle(
return {nullptr, {a_size, 0, 0}};
}
void MatrixMulImpl::AlgoF32MKLPackA::pack_A(const KernParam& kern_param, void* out,
size_t index, size_t stride) const {
void MatrixMulImpl::AlgoF32MKLPackA::pack_A(const KernParam& kern_param,
void* out, size_t index,
size_t stride) const {
MEGDNN_MARK_USED_VAR(stride);
MEGDNN_MARK_USED_VAR(index);
auto m = kern_param.M, n = kern_param.N, k = kern_param.K;
@@ -164,7 +164,7 @@ size_t get_kern_workspace(MatrixMulImpl::KernSizeParam kern_size_param) {
bool MatrixMulImpl::AlgoInt8x8x32Vnni::usable(
const KernSizeParam& kern_size_param) const {
return kern_size_param.A_type == kern_size_param.B_type &&
return kern_size_param.A_type.enumv() == kern_size_param.B_type.enumv() &&
((kern_size_param.A_type.enumv() == DTypeEnum::Int8 &&
kern_size_param.C_type.enumv() == DTypeEnum::Int32) ||
(kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8 &&
@@ -389,9 +389,10 @@ size_t MatrixMulImpl::AlgoInt8x8x32AVX2M2N4K16::get_workspace(
m, n, k, trans_a, trans_b, strategy, cacheline)
.get_workspace_size();
}
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32AVX2M2N4K16, megdnn_x86_matmul_kern,
8, x86::matmul::gemm_avx2_s8s8s32_2x4x16,
dt_int8, dt_int32);
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32AVX2M2N4K16,
megdnn_x86_matmul_kern, 8,
x86::matmul::gemm_avx2_s8s8s32_2x4x16,
dt_int8, dt_int32);
/*************************AlgoInt8x8x32SSEM4N8K2********************/
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32SSEM4N8K2::get_kern(
@@ -426,9 +427,10 @@ size_t MatrixMulImpl::AlgoInt8x8x32SSEM4N8K2::get_workspace(
m, n, k, trans_a, trans_b, strategy, cacheline)
.get_workspace_size();
}
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL(
AlgoInt8x8x32SSEM4N8K2, megdnn_x86_matmul_kern, 9,
x86::matmul::gemm_sse_s8s8s32_4x8x2, dt_int8, dt_int32, dt_int16);
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL(AlgoInt8x8x32SSEM4N8K2,
megdnn_x86_matmul_kern, 9,
x86::matmul::gemm_sse_s8s8s32_4x8x2,
dt_int8, dt_int32, dt_int16);
/*************************AlgoF32MK8_8x8********************/
MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32MK8_8x8::get_kern(
......
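The functional change in this file is in AlgoInt8x8x32Vnni::usable above: the A/B dtype check now compares enumv() values instead of full DType equality. For quantized types, DType equality also compares the quantization parameters, so two QuantizedS8 operands with different scales would fail the old check and the VNNI kernel would silently fall back to a slower path; comparing only the type enum keeps it selectable. A small sketch of the distinction, assuming megdnn's public DType/dtype::QuantizedS8 API and hypothetical example scales:

```cpp
// Sketch only: DType equality vs. DTypeEnum comparison for quantized types.
// Assumes megdnn's dtype API from "megdnn/dtype.h"; the scales are made up.
#include "megdnn/dtype.h"
#include <cstdio>

int main() {
    megdnn::DType a = megdnn::dtype::QuantizedS8(0.50f);
    megdnn::DType b = megdnn::dtype::QuantizedS8(0.25f);

    // Full equality also compares the quantization scale, so this prints 0...
    std::printf("a == b                : %d\n", a == b ? 1 : 0);
    // ...while both operands share the same type enum, so this prints 1.
    std::printf("a.enumv() == b.enumv(): %d\n", a.enumv() == b.enumv() ? 1 : 0);
    return 0;
}
```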
@@ -6,7 +6,8 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/x86/matrix_mul/opr_impl.h"
@@ -41,9 +42,6 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj {
public:
AlgoPack() {
if (is_supported(SIMDType::VNNI)) {
#if MEGDNN_X86_WITH_MKL_DNN
all_algos.emplace_back(&algoint8x8x32mkldnn);
#endif
#if MEGDNN_X86_WITH_VNNI
all_algos.emplace_back(&algoint8x8x32vnni);
#endif
......