Commit 871e6a51 authored by Megvii Engine Team, committed by Xu Xinran

feat(dnn/x86): opt x86 quantized heuristic

GitOrigin-RevId: 72abe9efcc0653625ce022956cf67673f0b8cf6d
Parent 6c29548d
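This commit moves each x86 int8 conv-bias algorithm's availability and preference checks out of the Algo classes into free functions (the new algo_usable_preferred.{h,cpp} below), so the same heuristics can back both per-algorithm dispatch and ConvBiasImpl::is_matmul_quantized_prefer(). A minimal sketch of the shared pattern, using names that appear in the diff (simplified, not itself part of the commit):

// Each algorithm exposes three predicates over the kernel-size parameters:
//   <algo>_usable(param)            -- shape/dtype/ISA constraints hold
//   <algo>_preferred(param)         -- heuristic expects it to be fastest
//   <algo>_usable_preferred(param)  -- the conjunction consumed by callers
bool direct_avx2_stride1_int8_usable_preferred(
        const ConvBiasImpl::NCBKernSizeParam& param) {
    return direct_avx2_stride1_int8_usable(param) &&
           direct_avx2_stride1_int8_preferred(param);
}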
/**
* \file dnn/src/x86/conv_bias/int8/algo_usable_preferred.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/x86/conv_bias/int8/algo_usable_preferred.h"
#include "src/x86/utils.h"
#if MEGDNN_X86_WITH_MKL_DNN
#include <mkldnn.hpp>
#endif
#include <cstring>
#if MEGDNN_X86_WITH_MKL_DNN
using namespace dnnl;
#endif
using namespace megdnn;
using namespace x86;
namespace megdnn {
namespace x86 {
bool chanwise_avx2_stride1_qint8_usable(
const ConvBiasImpl::NCBKernSizeParam& param) {
auto&& fm = param.filter_meta;
auto FH = fm.spatial[0];
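//! NOTE: the accepted (src, filter, dst) dtype triples below are
//! (qint8, qint8, qint8), (qint8, qint8, qint32) and (int8, int8, int32);
//! "chanwise" requires icpg == ocpg == 1, and full per-element bias
//! (BiasMode::BIAS) is rejected.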
bool available =
(param.bias_mode != BiasMode::BIAS) &&
((param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
param.filter_type.enumv() == DTypeEnum::QuantizedS8 &&
param.dst_type.enumv() == DTypeEnum::QuantizedS8) ||
(((param.src_type.enumv() == DTypeEnum::Int8 &&
param.filter_type.enumv() == DTypeEnum::Int8 &&
param.dst_type.enumv() == DTypeEnum::Int32) ||
(param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
param.filter_type.enumv() == DTypeEnum::QuantizedS8 &&
param.dst_type.enumv() == DTypeEnum::QuantizedS32)))) &&
fm.format == ConvBiasImpl::Param::Format::NCHW &&
fm.spatial_ndim == 2 && fm.dilation[0] == 1 &&
fm.dilation[1] == 1 && (FH == 2 || FH == 3 || FH == 5 || FH == 7) &&
fm.stride[0] == 1 && fm.stride[1] == 1 && (fm.icpg == 1) &&
(fm.ocpg == 1) && is_supported(SIMDType::AVX2);
return available;
}
bool chanwise_avx2_stride1_qint8_preferred(
const ConvBiasImpl::NCBKernSizeParam& param) {
MEGDNN_MARK_USED_VAR(param);
return true;
}
bool chanwise_avx2_stride1_qint8_usable_preferred(
const ConvBiasImpl::NCBKernSizeParam& param) {
return chanwise_avx2_stride1_qint8_usable(param) &&
chanwise_avx2_stride1_qint8_preferred(param);
}
bool chanwise_avx2_stride2_qint8_usable(
const ConvBiasImpl::NCBKernSizeParam& param) {
auto&& fm = param.filter_meta;
auto FH = fm.spatial[0];
bool available =
(param.bias_mode != BiasMode::BIAS) &&
((param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
param.filter_type.enumv() == DTypeEnum::QuantizedS8 &&
param.dst_type.enumv() == DTypeEnum::QuantizedS8) ||
(((param.src_type.enumv() == DTypeEnum::Int8 &&
param.filter_type.enumv() == DTypeEnum::Int8 &&
param.dst_type.enumv() == DTypeEnum::Int32) ||
(param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
param.filter_type.enumv() == DTypeEnum::QuantizedS8 &&
param.dst_type.enumv() == DTypeEnum::QuantizedS32)))) &&
fm.format == ConvBiasImpl::Param::Format::NCHW &&
fm.spatial_ndim == 2 && fm.dilation[0] == 1 &&
fm.dilation[1] == 1 && (FH == 2 || FH == 3 || FH == 5 || FH == 7) &&
fm.stride[0] == 2 && fm.stride[1] == 2 && (fm.icpg == 1) &&
(fm.ocpg == 1) && is_supported(SIMDType::AVX2);
return available;
}
bool chanwise_avx2_stride2_qint8_preferred(
const ConvBiasImpl::NCBKernSizeParam& param) {
MEGDNN_MARK_USED_VAR(param);
return true;
}
bool chanwise_avx2_stride2_qint8_usable_preferred(
const ConvBiasImpl::NCBKernSizeParam& param) {
return chanwise_avx2_stride2_qint8_usable(param) &&
chanwise_avx2_stride2_qint8_preferred(param);
}
bool direct_avx2_stride1_int8_usable(
const ConvBiasImpl::NCBKernSizeParam& param) {
auto&& fm = param.filter_meta;
auto FH = fm.spatial[0];
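//! NOTE: a quantized qint8 output may carry any bias/nonlinearity, while
//! int32/qint32 outputs must use NO_BIAS + IDENTITY, presumably because
//! the int32 path skips the post-process stage.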
bool available = ((param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
param.filter_type.enumv() == DTypeEnum::QuantizedS8 &&
param.dst_type.enumv() == DTypeEnum::QuantizedS8) ||
(((param.src_type.enumv() == DTypeEnum::Int8 &&
param.filter_type.enumv() == DTypeEnum::Int8 &&
param.dst_type.enumv() == DTypeEnum::Int32) ||
(param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
param.filter_type.enumv() == DTypeEnum::QuantizedS8 &&
param.dst_type.enumv() == DTypeEnum::QuantizedS32)) &&
param.bias_mode == BiasMode::NO_BIAS &&
param.nonlineMode == NonlineMode::IDENTITY)) &&
fm.format == ConvBiasImpl::Param::Format::NCHW &&
fm.spatial_ndim == 2 && fm.dilation[0] == 1 &&
fm.dilation[1] == 1 &&
(FH == 2 || FH == 3 || FH == 5 || FH == 7) &&
fm.stride[0] == 1 && fm.stride[1] == 1 &&
is_supported(SIMDType::AVX2);
return available;
}
bool direct_avx2_stride1_int8_preferred(
const ConvBiasImpl::NCBKernSizeParam& param) {
auto&& fm = param.filter_meta;
auto IC = fm.icpg;
auto OC = fm.ocpg;
auto is_preferred = true;
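//! assumed rationale: once both IC and OC are large, the im2col + matmul
//! implementations tend to win, so the direct kernel steps aside.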
if (IC > 128 && OC > 128)
is_preferred = false;
return is_preferred;
}
bool direct_avx2_stride1_int8_usable_preferred(
const ConvBiasImpl::NCBKernSizeParam& param) {
return direct_avx2_stride1_int8_usable(param) &&
direct_avx2_stride1_int8_preferred(param);
}
bool direct_avx2_stride2_int8_usable(
const ConvBiasImpl::NCBKernSizeParam& param) {
auto&& fm = param.filter_meta;
auto FH = fm.spatial[0];
bool available = ((param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
param.filter_type.enumv() == DTypeEnum::QuantizedS8 &&
param.dst_type.enumv() == DTypeEnum::QuantizedS8) ||
(((param.src_type.enumv() == DTypeEnum::Int8 &&
param.filter_type.enumv() == DTypeEnum::Int8 &&
param.dst_type.enumv() == DTypeEnum::Int32) ||
(param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
param.filter_type.enumv() == DTypeEnum::QuantizedS8 &&
param.dst_type.enumv() == DTypeEnum::QuantizedS32)) &&
param.bias_mode == BiasMode::NO_BIAS &&
param.nonlineMode == NonlineMode::IDENTITY)) &&
fm.format == ConvBiasImpl::Param::Format::NCHW &&
fm.spatial_ndim == 2 && fm.dilation[0] == 1 &&
fm.dilation[1] == 1 &&
(FH == 2 || FH == 3 || FH == 5 || FH == 7) &&
fm.stride[0] == 2 && fm.stride[1] == 2 &&
is_supported(SIMDType::AVX2);
return available;
}
bool direct_avx2_stride2_int8_preferred(
const ConvBiasImpl::NCBKernSizeParam& param) {
auto&& fm = param.filter_meta;
auto IC = fm.icpg;
auto OC = fm.ocpg;
auto is_preferred = false;
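//! assumed rationale: for small channel counts the packing overhead of
//! matmul-based algorithms dominates, so the direct stride-2 kernel wins.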
if (IC <= 31 && OC <= 31)
is_preferred = true;
return is_preferred;
}
bool direct_avx2_stride2_int8_usable_preferred(
const ConvBiasImpl::NCBKernSizeParam& param) {
return direct_avx2_stride2_int8_usable(param) &&
direct_avx2_stride2_int8_preferred(param);
}
#if MEGDNN_X86_WITH_MKL_DNN
bool mkldnn_qint8_usable(const ConvBiasImpl::NCBKernSizeParam& param) {
auto&& fm = param.filter_meta;
return (param.src_type.enumv() == DTypeEnum::QuantizedS8 ||
param.src_type.enumv() == DTypeEnum::Int8) &&
(param.dst_type.enumv() == DTypeEnum::QuantizedS32 ||
param.dst_type.enumv() == DTypeEnum::Int32) &&
fm.format == param::ConvBias::Format::NCHW && fm.spatial_ndim == 2 &&
fm.dilation[0] == 1 && fm.dilation[1] == 1 && !fm.should_flip &&
param.bias_mode == BiasMode::NO_BIAS &&
param.nonlineMode == NonlineMode::IDENTITY;
}
bool mkldnn_qint8_preferred(const ConvBiasImpl::NCBKernSizeParam& param) {
MEGDNN_MARK_USED_VAR(param);
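//! assumed rationale: the oneDNN int8 kernels only clearly pay off on
//! CPUs with VNNI, so preference is gated on instruction-set support alone.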
return is_supported(SIMDType::VNNI);
}
bool mkldnn_qint8_usable_preferred(
const ConvBiasImpl::NCBKernSizeParam& param) {
return mkldnn_qint8_usable(param) && mkldnn_qint8_preferred(param);
}
bool mkldnn_matmul_qint8_usable(const ConvBiasImpl::NCBKernSizeParam& param) {
auto&& fm = param.filter_meta;
return (param.src_type.enumv() == DTypeEnum::QuantizedS8 ||
param.src_type.enumv() == DTypeEnum::Int8) &&
(param.dst_type.enumv() == DTypeEnum::QuantizedS32 ||
param.dst_type.enumv() == DTypeEnum::Int32) &&
fm.format == param::ConvBias::Format::NCHW && fm.spatial_ndim == 2 &&
fm.group == 1 && fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
param.bias_mode == BiasMode::NO_BIAS &&
param.nonlineMode == NonlineMode::IDENTITY &&
//! The matmul opr is only used in single-thread mode.
//! TODO: support the no-pack matmul algo in the fallback im2col + matmul.
param.nr_threads == 1_z;
}
bool mkldnn_matmul_qint8_preferred(
const ConvBiasImpl::NCBKernSizeParam& param) {
auto is_preferred = true;
auto&& fm = param.filter_meta;
megdnn_assert_internal(fm.group == 1 && fm.dilation[0] == 1 &&
fm.dilation[1] == 1);
// single-channel conv should never use matrix mul
if (fm.ocpg == 1 || fm.icpg == 1)
is_preferred = false;
return is_preferred && is_supported(SIMDType::VNNI);
}
bool mkldnn_matmul_qint8_usable_preferred(
const ConvBiasImpl::NCBKernSizeParam& param) {
return mkldnn_matmul_qint8_usable(param) &&
mkldnn_matmul_qint8_preferred(param);
}
#endif
} // namespace x86
} // namespace megdnn
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/x86/conv_bias/int8/algo_usable_preferred.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "src/common/utils.h"
#include "src/x86/conv_bias/opr_impl.h"
namespace megdnn {
namespace x86 {
bool chanwise_avx2_stride1_qint8_usable(const ConvBiasImpl::NCBKernSizeParam&);
bool chanwise_avx2_stride1_qint8_preferred(
const ConvBiasImpl::NCBKernSizeParam&);
bool chanwise_avx2_stride1_qint8_usable_preferred(
const ConvBiasImpl::NCBKernSizeParam&);
bool chanwise_avx2_stride2_qint8_usable(const ConvBiasImpl::NCBKernSizeParam&);
bool chanwise_avx2_stride2_qint8_preferred(
const ConvBiasImpl::NCBKernSizeParam&);
bool chanwise_avx2_stride2_qint8_usable_preferred(
const ConvBiasImpl::NCBKernSizeParam&);
bool direct_avx2_stride1_int8_usable(const ConvBiasImpl::NCBKernSizeParam&);
bool direct_avx2_stride1_int8_preferred(const ConvBiasImpl::NCBKernSizeParam&);
bool direct_avx2_stride1_int8_usable_preferred(
const ConvBiasImpl::NCBKernSizeParam&);
bool direct_avx2_stride2_int8_usable(const ConvBiasImpl::NCBKernSizeParam&);
bool direct_avx2_stride2_int8_preferred(const ConvBiasImpl::NCBKernSizeParam&);
bool direct_avx2_stride2_int8_usable_preferred(
const ConvBiasImpl::NCBKernSizeParam&);
#if MEGDNN_X86_WITH_MKL_DNN
bool mkldnn_qint8_usable(const ConvBiasImpl::NCBKernSizeParam&);
bool mkldnn_qint8_preferred(const ConvBiasImpl::NCBKernSizeParam&);
bool mkldnn_qint8_usable_preferred(const ConvBiasImpl::NCBKernSizeParam&);
bool mkldnn_matmul_qint8_usable(const ConvBiasImpl::NCBKernSizeParam&);
bool mkldnn_matmul_qint8_preferred(const ConvBiasImpl::NCBKernSizeParam&);
bool mkldnn_matmul_qint8_usable_preferred(
const ConvBiasImpl::NCBKernSizeParam&);
#endif
} // namespace x86
} // namespace megdnn
// vim: syntax=cpp.doxygen
@@ -14,6 +14,7 @@
#include "src/common/opr_delegate.h"
#include "src/common/utils.h"
#include "src/fallback/convolution/img2col_helper.h"
#include "src/x86/conv_bias/int8/algo_usable_preferred.h"
#include "src/x86/conv_bias/int8/avx2_chanwise_stride1.h"
#include "src/x86/conv_bias/int8/avx2_chanwise_stride2.h"
#include "src/x86/conv_bias/int8/avx2_direct_conv_stride1.h"
@@ -37,25 +38,7 @@ using namespace x86;
bool ConvBiasImpl::AlgoChanWiseAvx2Stride1Qint8::usable(
FallbackConvBiasImpl* /*opr*/, const NCBKernSizeParam& param,
AlgoSelectionStrategy /*algo_selection_strategy*/) const {
auto&& fm = param.filter_meta;
auto FH = fm.spatial[0];
bool aviliable =
(param.bias_mode != BiasMode::BIAS) &&
((param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
param.filter_type.enumv() == DTypeEnum::QuantizedS8 &&
param.dst_type.enumv() == DTypeEnum::QuantizedS8) ||
(((param.src_type.enumv() == DTypeEnum::Int8 &&
param.filter_type.enumv() == DTypeEnum::Int8 &&
param.dst_type.enumv() == DTypeEnum::Int32) ||
(param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
param.filter_type.enumv() == DTypeEnum::QuantizedS8 &&
param.dst_type.enumv() == DTypeEnum::QuantizedS32)))) &&
fm.format == Param::Format::NCHW && fm.spatial_ndim == 2 &&
fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
(FH == 2 || FH == 3 || FH == 5 || FH == 7) && fm.stride[0] == 1 &&
fm.stride[1] == 1 && (fm.icpg == 1) && (fm.ocpg == 1) &&
is_supported(SIMDType::AVX2);
return aviliable;
return chanwise_avx2_stride1_qint8_usable(param);
}
WorkspaceBundle ConvBiasImpl::AlgoChanWiseAvx2Stride1Qint8::get_bundle(
@@ -94,28 +77,15 @@ ConvBiasImpl::AlgoChanWiseAvx2Stride1Qint8::get_kimpls(
return avx2_chanwise_stride1::get_kimpls(param, bundle);
}
bool ConvBiasImpl::AlgoChanWiseAvx2Stride1Qint8::is_preferred(
FallbackConvBiasImpl*, const NCBKernSizeParam& param) const {
return chanwise_avx2_stride1_qint8_preferred(param);
}
bool ConvBiasImpl::AlgoChanWiseAvx2Stride2Qint8::usable(
FallbackConvBiasImpl* /*opr*/, const NCBKernSizeParam& param,
AlgoSelectionStrategy /*algo_selection_strategy*/) const {
auto&& fm = param.filter_meta;
auto FH = fm.spatial[0];
bool aviliable =
(param.bias_mode != BiasMode::BIAS) &&
((param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
param.filter_type.enumv() == DTypeEnum::QuantizedS8 &&
param.dst_type.enumv() == DTypeEnum::QuantizedS8) ||
(((param.src_type.enumv() == DTypeEnum::Int8 &&
param.filter_type.enumv() == DTypeEnum::Int8 &&
param.dst_type.enumv() == DTypeEnum::Int32) ||
(param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
param.filter_type.enumv() == DTypeEnum::QuantizedS8 &&
param.dst_type.enumv() == DTypeEnum::QuantizedS32)))) &&
fm.format == Param::Format::NCHW && fm.spatial_ndim == 2 &&
fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
(FH == 2 || FH == 3 || FH == 5 || FH == 7) && fm.stride[0] == 2 &&
fm.stride[1] == 2 && (fm.icpg == 1) && (fm.ocpg == 1) &&
is_supported(SIMDType::AVX2);
return aviliable;
return chanwise_avx2_stride2_qint8_usable(param);
}
WorkspaceBundle ConvBiasImpl::AlgoChanWiseAvx2Stride2Qint8::get_bundle(
@@ -154,28 +124,15 @@ ConvBiasImpl::AlgoChanWiseAvx2Stride2Qint8::get_kimpls(
return avx2_chanwise_stride2::get_kimpls(param, bundle);
}
bool ConvBiasImpl::AlgoChanWiseAvx2Stride2Qint8::is_preferred(
FallbackConvBiasImpl*, const NCBKernSizeParam& param) const {
return chanwise_avx2_stride2_qint8_preferred(param);
}
bool ConvBiasImpl::AlgoDirectAvx2Stride1Int8::usable(
FallbackConvBiasImpl* /*opr*/, const NCBKernSizeParam& param,
AlgoSelectionStrategy /*algo_selection_strategy*/) const {
auto&& fm = param.filter_meta;
auto FH = fm.spatial[0];
bool aviliable = ((param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
param.filter_type.enumv() == DTypeEnum::QuantizedS8 &&
param.dst_type.enumv() == DTypeEnum::QuantizedS8) ||
(((param.src_type.enumv() == DTypeEnum::Int8 &&
param.filter_type.enumv() == DTypeEnum::Int8 &&
param.dst_type.enumv() == DTypeEnum::Int32) ||
(param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
param.filter_type.enumv() == DTypeEnum::QuantizedS8 &&
param.dst_type.enumv() == DTypeEnum::QuantizedS32)) &&
param.bias_mode == BiasMode::NO_BIAS &&
param.nonlineMode == NonlineMode::IDENTITY)) &&
fm.format == Param::Format::NCHW && fm.spatial_ndim == 2 &&
fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
(FH == 2 || FH == 3 || FH == 5 || FH == 7) &&
fm.stride[0] == 1 && fm.stride[1] == 1 &&
is_supported(SIMDType::AVX2);
return aviliable;
return direct_avx2_stride1_int8_usable(param);
}
WorkspaceBundle ConvBiasImpl::AlgoDirectAvx2Stride1Int8::get_bundle(
@@ -224,19 +181,75 @@ ConvBiasImpl::AlgoDirectAvx2Stride1Int8::get_kimpls(
return direct_conv_avx2_stride1::get_kimpls(param, bundle);
}
bool ConvBiasImpl::AlgoDirectAvx2Stride1Int8::is_preferred(
FallbackConvBiasImpl*, const NCBKernSizeParam& param) const {
return direct_avx2_stride1_int8_preferred(param);
}
/* ===================== avx2 int8 stride 2 ===================== */
bool ConvBiasImpl::AlgoAVX2DirectConvStride2::usable(
FallbackConvBiasImpl* /*opr*/, const NCBKernSizeParam& param,
AlgoSelectionStrategy) const {
return direct_avx2_stride2_int8_usable(param);
}
WorkspaceBundle ConvBiasImpl::AlgoAVX2DirectConvStride2::get_bundle(
const NCBKernSizeParam& param) {
auto&& fm = param.filter_meta;
size_t N = param.n;
size_t IC = fm.icpg;
size_t OC = fm.ocpg;
size_t IH = param.isz[0];
size_t IW = param.isz[1];
size_t OH = param.osz[0];
size_t OW = param.osz[1];
size_t FH = fm.spatial[0];
size_t FW = fm.spatial[1];
size_t GROUP = fm.group;
size_t IC_STEP = 2, OC_STEP = 4;
size_t pad_h = fm.padding[0];
size_t pad_w = fm.padding[1];
size_t src_size = 0, filter_size = 0;
//! pack filter, pack src
filter_size = GROUP * round_up(OC, OC_STEP) * round_up(IC, IC_STEP) * FH *
FW * sizeof(int16_t);
//! AVX2 (256-bit) loads may run up to 32 bytes past the packed row when
//! w_remain < 16, hence the extra 32-byte tail below.
src_size = N * GROUP * div_ceil(IC, IC_STEP) * (IH + 2 * pad_h) *
(IW + 2 * pad_w) * 2 * sizeof(int8_t) +
32;
bool need_post_process = param.dst_type.enumv() == DTypeEnum::QuantizedS8;
if (need_post_process) {
size_t dst_tmp = N * GROUP * OC * OW * OH * sizeof(int32_t);
return WorkspaceBundle(nullptr, {src_size, filter_size, dst_tmp});
} else {
return WorkspaceBundle(nullptr, {src_size, filter_size});
}
}
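//! Worked example (hypothetical shape, not taken from the commit):
//! N=1, GROUP=1, IC=4, OC=8, IH=IW=32, FH=FW=3, pad=1, stride=2,
//! dst dtype QuantizedS8:
//!   filter_size = 1 * round_up(8,4) * round_up(4,2) * 3 * 3 * 2 = 576 B
//!   src_size    = 1 * 1 * div_ceil(4,2) * 34 * 34 * 2 + 32      = 4656 B
//!   OH = OW = (32 + 2*1 - 3) / 2 + 1 = 16
//!   dst_tmp     = 1 * 1 * 8 * 16 * 16 * 4                       = 8192 B
//! so the bundle holds {src 4656, filter 576, dst_tmp 8192} bytes.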
size_t ConvBiasImpl::AlgoAVX2DirectConvStride2::get_workspace(
FallbackConvBiasImpl*, const NCBKernSizeParam& param) const {
return get_bundle(param).total_size_in_bytes();
}
SmallVector<fallback::ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoAVX2DirectConvStride2::get_kimpls(
const NCBKernSizeParam& param) const {
auto bundle = get_bundle(param);
return direct_conv_avx2_stride2::get_kimpls(param, bundle);
}
bool ConvBiasImpl::AlgoAVX2DirectConvStride2::is_preferred(
FallbackConvBiasImpl*, const NCBKernSizeParam& param) const {
return direct_avx2_stride2_int8_preferred(param);
}
#if MEGDNN_X86_WITH_MKL_DNN
bool ConvBiasImpl::AlgoMkldnnQint8::usable(FallbackConvBiasImpl*,
const NCBKernSizeParam& param,
AlgoSelectionStrategy) const {
auto&& fm = param.filter_meta;
return (param.src_type.enumv() == DTypeEnum::QuantizedS8 ||
param.src_type.enumv() == DTypeEnum::Int8) &&
(param.dst_type.enumv() == DTypeEnum::QuantizedS32 ||
param.dst_type.enumv() == DTypeEnum::Int32) &&
fm.format == param::ConvBias::Format::NCHW && fm.spatial_ndim == 2 &&
fm.dilation[0] == 1 && fm.dilation[1] == 1 && !fm.should_flip &&
param.bias_mode == BiasMode::NO_BIAS &&
param.nonlineMode == NonlineMode::IDENTITY;
return mkldnn_qint8_usable(param);
}
WorkspaceBundle ConvBiasImpl::AlgoMkldnnQint8::get_bundle(
@@ -412,39 +425,25 @@ void ConvBiasImpl::AlgoMkldnnQint8::kern_mkldnn_s8x8x32(
stream_mkldnn.wait();
}
}
#undef REORDER_MEMORY
#endif
#if MEGDNN_X86_WITH_MKL_DNN
bool ConvBiasImpl::AlgoMkldnnQint8::is_preferred(
FallbackConvBiasImpl*, const NCBKernSizeParam& param) const {
return mkldnn_qint8_preferred(param);
}
/* ===================== mkldnn qint8 matmul algo ===================== */
bool ConvBiasImpl::AlgoMkldnnMatmulQint8::usable(FallbackConvBiasImpl*,
const NCBKernSizeParam& param,
AlgoSelectionStrategy) const {
auto&& fm = param.filter_meta;
return (param.src_type.enumv() == DTypeEnum::QuantizedS8 ||
param.src_type.enumv() == DTypeEnum::Int8) &&
(param.dst_type.enumv() == DTypeEnum::QuantizedS32 ||
param.dst_type.enumv() == DTypeEnum::Int32) &&
fm.format == param::ConvBias::Format::NCHW && fm.spatial_ndim == 2 &&
fm.group == 1 && fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
param.bias_mode == BiasMode::NO_BIAS &&
param.nonlineMode == NonlineMode::IDENTITY &&
//! The matmul opr is only used in single thread
//! TODO:support the no pack matmul algo in fallback im2col + matmul
param.nr_threads == 1_z;
return mkldnn_matmul_qint8_usable(param);
}
bool ConvBiasImpl::AlgoMkldnnMatmulQint8::is_preferred(
FallbackConvBiasImpl*, const NCBKernSizeParam& param) const {
auto&& fm = param.filter_meta;
megdnn_assert_internal(fm.group == 1 && fm.dilation[0] == 1 &&
fm.dilation[1] == 1);
// single channel conv should never use matrix mul
if (fm.ocpg == 1 || fm.icpg == 1)
return false;
return true;
return mkldnn_matmul_qint8_preferred(param);
}
WorkspaceBundle ConvBiasImpl::AlgoMkldnnMatmulQint8::get_bundle(
const NCBKernSizeParam& param) {
UNPACK_CONV_F32_NCB_KERN_SIZES(param);
@@ -473,6 +472,7 @@ WorkspaceBundle ConvBiasImpl::AlgoMkldnnMatmulQint8::get_bundle(
}
return {nullptr, {part0, part1, part2}};
}
MatrixMul* ConvBiasImpl::AlgoMkldnnMatmulQint8::get_matmul_opr() {
static CpuOprDelegationStorage<> storage;
return storage.get<MatrixMul>();
@@ -553,76 +553,5 @@ void ConvBiasImpl::AlgoMkldnnMatmulQint8::kern_mkldnn_matmul_s8x8x32(
}
#endif
- /* ===================== avx2 int8 stride 2 ===================== */
- bool ConvBiasImpl::AlgoAVX2DirectConvStride2::usable(
- FallbackConvBiasImpl* /*opr*/, const NCBKernSizeParam& param,
- AlgoSelectionStrategy) const {
- auto&& fm = param.filter_meta;
- auto FH = fm.spatial[0];
- bool aviliable = ((param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
- param.filter_type.enumv() == DTypeEnum::QuantizedS8 &&
- param.dst_type.enumv() == DTypeEnum::QuantizedS8) ||
- (((param.src_type.enumv() == DTypeEnum::Int8 &&
- param.filter_type.enumv() == DTypeEnum::Int8 &&
- param.dst_type.enumv() == DTypeEnum::Int32) ||
- (param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
- param.filter_type.enumv() == DTypeEnum::QuantizedS8 &&
- param.dst_type.enumv() == DTypeEnum::QuantizedS32)) &&
- param.bias_mode == BiasMode::NO_BIAS &&
- param.nonlineMode == NonlineMode::IDENTITY)) &&
- fm.format == Param::Format::NCHW && fm.spatial_ndim == 2 &&
- fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
- (FH == 2 || FH == 3 || FH == 5 || FH == 7) &&
- fm.stride[0] == 2 && fm.stride[1] == 2 &&
- is_supported(SIMDType::AVX2);
- return aviliable;
- }
- WorkspaceBundle ConvBiasImpl::AlgoAVX2DirectConvStride2::get_bundle(
- const NCBKernSizeParam& param) {
- auto&& fm = param.filter_meta;
- size_t N = param.n;
- size_t IC = fm.icpg;
- size_t OC = fm.ocpg;
- size_t IH = param.isz[0];
- size_t IW = param.isz[1];
- size_t OH = param.osz[0];
- size_t OW = param.osz[1];
- size_t FH = fm.spatial[0];
- size_t FW = fm.spatial[1];
- size_t GROUP = fm.group;
- size_t IC_STEP = 2, OC_STEP = 4;
- size_t pad_h = fm.padding[0];
- size_t pad_w = fm.padding[1];
- size_t src_size = 0, filter_size = 0;
- //! pack filter, pack src
- filter_size = GROUP * round_up(OC, OC_STEP) * round_up(IC, IC_STEP) * FH *
- FW * sizeof(int16_t);
- //! avx256 iw max offset 32, caused by w_remain < 16
- src_size = N * GROUP * div_ceil(IC, IC_STEP) * (IH + 2 * pad_h) *
- (IW + 2 * pad_w) * 2 * sizeof(int8_t) +
- 32;
- bool need_post_process = param.dst_type.enumv() == DTypeEnum::QuantizedS8;
- if (need_post_process) {
- size_t dst_tmp = N * GROUP * OC * OW * OH * sizeof(int32_t);
- return WorkspaceBundle(nullptr, {src_size, filter_size, dst_tmp});
- } else {
- return WorkspaceBundle(nullptr, {src_size, filter_size});
- }
- }
- size_t ConvBiasImpl::AlgoAVX2DirectConvStride2::get_workspace(
- FallbackConvBiasImpl*, const NCBKernSizeParam& param) const {
- return get_bundle(param).total_size_in_bytes();
- }
- SmallVector<fallback::ConvBiasImpl::NCBKern>
- ConvBiasImpl::AlgoAVX2DirectConvStride2::get_kimpls(
- const NCBKernSizeParam& param) const {
- auto bundle = get_bundle(param);
- return direct_conv_avx2_stride2::get_kimpls(param, bundle);
- }
// vim: syntax=cpp.doxygen
@@ -35,6 +35,8 @@
return get_kimpls(param);
}
void* type() const override;
bool is_preferred(FallbackConvBiasImpl*,
const NCBKernSizeParam& param) const override;
};
/* ===================== avx2 stride2 chanwise algo ===================== */
@@ -57,6 +59,8 @@
return get_kimpls(param);
}
void* type() const override;
bool is_preferred(FallbackConvBiasImpl*,
const NCBKernSizeParam& param) const override;
};
/* ===================== avx2 stride1 direct algo ===================== */
@@ -79,6 +83,32 @@
return get_kimpls(param);
}
void* type() const override;
bool is_preferred(FallbackConvBiasImpl*,
const NCBKernSizeParam& param) const override;
};
/* ================== avx2 int8 direct conv stride2 algo ================== */
class ConvBiasImpl::AlgoAVX2DirectConvStride2 final : public AlgoBase {
SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const;
static WorkspaceBundle get_bundle(const NCBKernSizeParam& param);
public:
bool is_reproducible() const override { return true; }
const char* name() const override {
return "X86_CONV_BIAS_DIRECT_AVX2_INT8_STRIDE2";
}
bool usable(FallbackConvBiasImpl* opr, const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;
size_t get_workspace(FallbackConvBiasImpl* opr,
const NCBKernSizeParam& param) const override;
SmallVector<NCBKern> dispatch_kerns(
fallback::ConvBiasImpl*,
const NCBKernSizeParam& param) const override {
return get_kimpls(param);
}
void* type() const override;
bool is_preferred(FallbackConvBiasImpl*,
const NCBKernSizeParam& param) const override;
};
#if MEGDNN_X86_WITH_MKL_DNN
@@ -117,6 +147,8 @@
return {{kern, {group, n, 1_z}}};
}
void* type() const override;
bool is_preferred(FallbackConvBiasImpl*,
const NCBKernSizeParam& param) const override;
};
/* ===================== mkldnn qint8 matmul algo ===================== */
class ConvBiasImpl::AlgoMkldnnMatmulQint8 final : public AlgoBase {
@@ -148,27 +180,7 @@
void* type() const override;
};
#endif
- /* ================== avx2 int8 direct conv stride2 algo ================== */
- class ConvBiasImpl::AlgoAVX2DirectConvStride2 final : public AlgoBase {
- SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const;
- static WorkspaceBundle get_bundle(const NCBKernSizeParam& param);
- public:
- bool is_reproducible() const override { return true; }
- const char* name() const override {
- return "X86_CONV_BIAS_DIRECT_AVX2_INT8_STRIDE2";
- }
- bool usable(FallbackConvBiasImpl* opr, const NCBKernSizeParam& param,
- AlgoSelectionStrategy algo_selection_strategy) const override;
- size_t get_workspace(FallbackConvBiasImpl* opr,
- const NCBKernSizeParam& param) const override;
- SmallVector<NCBKern> dispatch_kerns(
- fallback::ConvBiasImpl*,
- const NCBKernSizeParam& param) const override {
- return get_kimpls(param);
- }
- void* type() const override;
- };
} // namespace x86
} // namespace megdnn
@@ -16,6 +16,7 @@
#include "src/common/metahelper.h"
#include "src/common/opr_delegate.h"
#include "src/x86/conv_bias/f32/algos.h"
#include "src/x86/conv_bias/int8/algo_usable_preferred.h"
#include "src/x86/conv_bias/int8/algos.h"
#include "src/x86/matrix_mul/opr_impl.h"
@@ -94,12 +95,6 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj {
public:
AlgoPack() {
- #if MEGDNN_X86_WITH_MKL_DNN
- //! Create the mkldnn algo
- all_algos.emplace_back(&mkldnn_conv_fp32);
- all_algos.emplace_back(&mkldnn_matmul_qint8);
- all_algos.emplace_back(&mkldnn_qint8);
- #endif
all_algos.emplace_back(&stride1_direct_large_group);
all_algos.emplace_back(&stride1_direct_small_group);
all_algos.emplace_back(&stride2_direct_large_group);
@@ -110,6 +105,14 @@ public:
all_algos.emplace_back(&avx2_stride2_chanwsie_qint8);
all_algos.emplace_back(&matmul);
//! preference to use mkldnn algo on VNNI devices
#if MEGDNN_X86_WITH_MKL_DNN
//! Create the mkldnn algo
all_algos.emplace_back(&mkldnn_conv_fp32);
all_algos.emplace_back(&mkldnn_matmul_qint8);
all_algos.emplace_back(&mkldnn_qint8);
#endif
static CpuOprDelegationStorage<> storage;
auto matmul_opr = storage.get<MatrixMul>();
auto&& matmul_algos =
@@ -159,4 +162,25 @@ const char* ConvBiasImpl::get_algorithm_set_name() const {
return "X0";
}
bool ConvBiasImpl::is_matmul_quantized_prefer(
const ConvBiasImpl::NCBKernSizeParam& param) {
bool conv_direct_chanwise_mkldnn_usable = true;
if (param.dst_type.enumv() == DTypeEnum::QuantizedS8 ||
param.dst_type.enumv() == DTypeEnum::QuantizedS32) {
conv_direct_chanwise_mkldnn_usable =
chanwise_avx2_stride1_qint8_usable_preferred(param) ||
chanwise_avx2_stride2_qint8_usable_preferred(param) ||
direct_avx2_stride1_int8_usable_preferred(param) ||
direct_avx2_stride2_int8_usable_preferred(param);
}
#if MEGDNN_X86_WITH_MKL_DNN
conv_direct_chanwise_mkldnn_usable =
conv_direct_chanwise_mkldnn_usable ||
mkldnn_qint8_usable_preferred(param) ||
mkldnn_matmul_qint8_usable_preferred(param);
#endif
return !conv_direct_chanwise_mkldnn_usable;
}
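//! In effect the im2col/matmul-based quantized algorithm is preferred only
//! when no specialized int8 algorithm claims the shape. Hypothetical
//! example: a 1x1 quantized conv with IC = OC = 256 -- FH = 1 is outside
//! the direct/chanwise kernels' {2, 3, 5, 7} filter window, so without
//! MKL-DNN every *_usable_preferred() above is false and
//! is_matmul_quantized_prefer() returns true.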
// vim: syntax=cpp.doxygen
@@ -53,6 +53,9 @@ public:
size_t& IW2, size_t& OH2, size_t& OW2);
const char* get_algorithm_set_name() const override;
bool is_matmul_quantized_prefer(
const ConvBiasImpl::NCBKernSizeParam& ncb_param) override;
};
} // namespace x86