提交 2a3f4d09 编写于 作者: M Megvii Engine Team

refactor(dnn/arm): refactor CPU heuristic algo selection

GitOrigin-RevId: 60d2646bb33316411caa18686eec724dc1f6c430
上级 981f487b
......@@ -76,6 +76,18 @@ enum class AlgoSelectionStrategy {
FULL_RUN = 2,
};
/**
* \brief separate algo by datatype for Matmul and conv
*/
enum class AlgoDataType : uint32_t {
FLOAT32 = 1 << 0,
FLOAT16 = 1 << 1,
QINT8X8X32 = 1 << 2,
QUINT8X8X32 = 1 << 3,
INT8X8X16 = 1 << 4,
INT16X16X32 = 1 << 5,
};
/*!
* \brief Abstract representation of an algorithm for implementing
* the operator
......
......@@ -27,6 +27,10 @@ public:
size_t get_workspace(const NCBKernSizeParam& param) const override;
SmallVector<NCBKern> dispatch_kerns(const NCBKernSizeParam&) const override;
ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::FLOAT16, AlgoCategory::DIRECT};
}
};
} // namespace aarch64
} // namespace megdnn
......
......@@ -32,6 +32,10 @@ public:
size_t get_workspace(const NCBKernSizeParam& param) const override;
SmallVector<NCBKern> dispatch_kerns(const NCBKernSizeParam&) const override;
ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT};
}
};
} // namespace aarch64
......
......@@ -45,6 +45,9 @@ public:
return static_cast<ConvBiasImpl*>(conv_bias_opr)
->is_matmul_quantized_prefer(param);
}
ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::QINT8X8X32, AlgoCategory::IM2COL};
}
};
} // namespace aarch64
......
......@@ -50,10 +50,9 @@ SmallVector<ConvBiasImpl::AlgoBase*> ConvBiasImpl::algo_pack() {
auto&& algos = arm_common::ConvBiasImpl::algo_pack();
algos.insert(algos.begin(), sl_algo_pack.direct_algos.begin(),
sl_algo_pack.direct_algos.end());
//! We put matmul algos at the end. Because matmul will get privilege when
//! We put matmul algos at the begin. Because matmul will get privilege when
//! prefer return true. See
//! fallback::ConvolutionImpl::ncb_1g_get_all_algorithms for more details.
algos.insert(algos.end(), sl_algo_pack.matmul_algos.begin(),
algos.insert(algos.begin(), sl_algo_pack.matmul_algos.begin(),
sl_algo_pack.matmul_algos.end());
return std::move(algos);
}
......
......@@ -45,6 +45,9 @@ public:
return static_cast<ConvBiasImpl*>(conv_bias_opr)
->is_matmul_quantized_prefer(param);
}
ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::QUINT8X8X32, AlgoCategory::IM2COL};
}
};
} // namespace aarch64
} // namespace megdnn
......
......@@ -89,7 +89,8 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32K8x12x1::get_kern(
}
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF32K8x12x1, megdnn_aarch64_matmul_kern,
"AlgoF32K8x12x1Impl"_hash,
aarch64::matmul::sgemm_8x12, float, float);
aarch64::matmul::sgemm_8x12, float, float,
AlgoDataType::FLOAT32, DEFAULT);
/* ===================== F32_MK4_8X12X1 algo ===================== */
bool MatrixMulImpl::AlgoF32MK4_8x12x1::usable(
......@@ -151,7 +152,7 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF32MK4_8x12x1,
megdnn_aarch64_matmul_kern,
"AlgoF32MK4_8x12x1Impl"_hash,
aarch64::matmul::sgemm_mk4_8x12, float,
float);
float, AlgoDataType::FLOAT32, MK4);
/* ===================== F32K4X16X1 algo ===================== */
......@@ -210,7 +211,8 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32K4x16x1::get_kern(
}
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF32K4x16x1, megdnn_aarch64_matmul_kern,
"AlgoF32K4x16x1Impl"_hash,
aarch64::matmul::sgemm_4x16, float, float);
aarch64::matmul::sgemm_4x16, float, float,
AlgoDataType::FLOAT32, MK4);
/* ===================== F32MK4_4x16 algo ===================== */
......@@ -328,7 +330,8 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF16K8x24x1::get_kern(
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF16K8x24x1, megdnn_aarch64_matmul_kern,
"AlogF16K8x24x1Impl"_hash,
aarch64::matmul::hgemm_8x24, dt_float16,
dt_float16);
dt_float16, AlgoDataType::FLOAT16,
DEFAULT);
/* ===================== F16_MK8_8x8 algo ===================== */
bool MatrixMulImpl::AlgoF16MK8_8x8::usable(
......@@ -449,7 +452,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32K8x12x4DotProd,
megdnn_aarch64_matmul_kern,
"AlgoInt8x8x32K8x12x4DotProdImpl"_hash,
aarch64::matmul::gemm_s8_8x12, int8_t,
int32_t);
int32_t, AlgoDataType::QINT8X8X32,
DEFAULT);
/* =================== Int8x8x32 MK4 8X12X4 Dotprod algo =================== */
namespace {
......@@ -520,7 +524,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32MK4_8x12x4DotProd,
megdnn_aarch64_matmul_kern,
"AlgoInt8x8x32MK4_8x12x4DotProdImpl"_hash,
aarch64::matmul::gemm_mk4_s8_8x12, int8_t,
int32_t);
int32_t, AlgoDataType::QINT8X8X32,
MK4_DOT);
#else
/* ===================== Int8x8x32 MK4 4x4x16 algo ===================== */
......@@ -593,7 +598,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32MK4_4x4x16,
megdnn_aarch64_matmul_kern,
"AlgoInt8x8x32MK4_4x4x16Impl"_hash,
aarch64::matmul::gemm_mk4_s8_4x4, int8_t,
int32_t);
int32_t, AlgoDataType::QINT8X8X32,
MK4);
/* ===================== Int8x8x32 K4x4x16 algo ===================== */
namespace {
......@@ -656,7 +662,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32K4x4x16,
megdnn_aarch64_matmul_kern,
"AlgoInt8x8x32K4x4x16Impl"_hash,
aarch64::matmul::gemm_s8_4x4, int8_t,
int32_t);
int32_t, AlgoDataType::QINT8X8X32,
DEFAULT);
/* ===================== Int8x8x32 K8x8x8 algo ===================== */
namespace {
void int8x8x32_k8x8x8_kern(const MatrixMulImpl::KernParam& kern_param) {
......@@ -717,7 +724,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32K8x8x8,
megdnn_aarch64_matmul_kern,
"AlgoInt8x8x32K8x8x8Impl"_hash,
aarch64::matmul::gemm_s8_8x8, int8_t,
int32_t);
int32_t, AlgoDataType::QINT8X8X32,
DEFAULT);
#endif
/* ===================== Int8x8x16 K8x8x8 algo ===================== */
......@@ -785,7 +793,7 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x16K8x8x8,
megdnn_aarch64_matmul_kern,
"AlgoInt8x8x16K8x8x8Impl"_hash,
aarch64::matmul::gemm_s8x8x16_8x8, int8_t,
int16_t);
int16_t, AlgoDataType::INT8X8X16, DEFAULT);
/* ===================== Int8x8x16 K4x4x16 algo ===================== */
namespace {
void int8x8x16_k4x4x16_kern(const MatrixMulImpl::KernParam& kern_param) {
......@@ -852,7 +860,7 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x16K4x4x16,
megdnn_aarch64_matmul_kern,
"AlgoInt8x8x16K4x4x16Impl"_hash,
aarch64::matmul::gemm_s8x8x16_4x4, int8_t,
int16_t);
int16_t, AlgoDataType::INT8X8X16, DEFAULT);
/* ===================== Int8x8x16 K16x12x4 algo ===================== */
namespace {
......@@ -929,7 +937,8 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16MK4_16x12x4::get_kern(
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL(
AlgoInt8x8x16MK4_16x12x4, megdnn_aarch64_matmul_kern,
"AlgoInt8x8x16MK4_16x12x4Impl"_hash,
aarch64::matmul::gemm_s8x8x16_mk4_16x12_a53, int8_t, int16_t, int16_t);
aarch64::matmul::gemm_s8x8x16_mk4_16x12_a53, int8_t, int16_t, int16_t,
AlgoDataType::INT8X8X16, MK4);
/* ===================== Int8x8x16 MK4 4x4x8 algo ===================== */
namespace {
......@@ -1007,7 +1016,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x16MK4_4x4x8,
megdnn_aarch64_matmul_kern,
"AlgoInt8x8x16MK4_4x4x8_Impl"_hash,
aarch64::matmul::gemm_s8x8x16_mk4_4x4_a72,
int8_t, int16_t);
int8_t, int16_t, AlgoDataType::INT8X8X16,
MK4);
/* ===================== Int16x16x32 K12x8x1 algo ===================== */
namespace {
......@@ -1078,7 +1088,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt16x16x32K12x8x1,
megdnn_aarch64_matmul_kern,
"AlgoInt16x16x32K12x8x1Impl"_hash,
aarch64::matmul::gemm_s16_12x8x1, int16_t,
int32_t);
int32_t, AlgoDataType::INT16X16X32,
DEFAULT);
/* ===================== Int16x16x32MK8_8x8 algo ===================== */
......@@ -1201,7 +1212,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoQuint8K8x8x4DotProd,
megdnn_aarch64_matmul_kern,
"AlgoQuint8K8x8x4DotProdImpl"_hash,
aarch64::matmul::gemm_u8_8x8, uint8_t,
int32_t);
int32_t, AlgoDataType::QUINT8X8X32,
DEFAULT);
/* ===================== Quint8 Gemv DotProd algo ===================== */
namespace {
void quint8_gemv_dotprod_kern(const MatrixMulImpl::KernParam& kern_param) {
......@@ -1307,7 +1319,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoQuint8K8x8x8,
megdnn_aarch64_matmul_kern,
"AlgoQuint8K8x8x8Impl"_hash,
aarch64::matmul::gemm_u8_8x8, uint8_t,
int32_t);
int32_t, AlgoDataType::QUINT8X8X32,
DEFAULT);
#endif
/* ===================== Int8x8x16 K8x8x8 algo ===================== */
......@@ -1378,6 +1391,7 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16MK4_K8x8x8::get_kern(
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x16MK4_K8x8x8,
megdnn_aarch64_matmul_kern,
"AlgoInt8x8x16MK4_K8x8x8Impl"_hash,
aarch64::matmul::gemm_s8x8x16_mk4_8x8x8, int8_t,
int16_t);
aarch64::matmul::gemm_s8x8x16_mk4_8x8x8,
int8_t, int16_t, AlgoDataType::INT8X8X16,
MK4);
// vim: syntax=cpp.doxygen
......@@ -61,7 +61,7 @@ public:
kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(4, 16, 4, 4)
MEGDNN_OVERRIDE_MATMUL_DESC(4, 16, 4, 4, AlgoDataType::FLOAT32, MK4)
};
class MatrixMulImpl::AlgoF32Gemv final
......@@ -88,7 +88,7 @@ public:
kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(8, 8, 8, 2)
MEGDNN_OVERRIDE_MATMUL_DESC(8, 8, 8, 2, AlgoDataType::FLOAT16, MK8)
};
#endif
......@@ -253,7 +253,7 @@ public:
kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(8, 8, 8, 2)
MEGDNN_OVERRIDE_MATMUL_DESC(8, 8, 8, 2, AlgoDataType::INT16X16X32, MK8)
};
#if __ARM_FEATURE_DOTPROD
......@@ -281,7 +281,7 @@ public:
void* type() const override { return sm_arm_common_algo_type; }
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2)
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::QUINT8X8X32, DEFAULT)
};
#else
......
......@@ -29,7 +29,7 @@ public:
}
return m_name.c_str();
}
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE();
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT16);
};
class ConvBiasImpl::AlgoFP16WinogradF45 final : public AlgoBase {
......@@ -44,7 +44,7 @@ public:
}
return m_name.c_str();
}
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE();
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT16);
};
class ConvBiasImpl::AlgoFP16WinogradF63 final : public AlgoBase {
......@@ -60,7 +60,7 @@ public:
return m_name.c_str();
}
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE();
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT16);
};
class ConvBiasImpl::AlgoFP16WinogradF23_8x8 final : public AlgoBase {
public:
......@@ -74,7 +74,7 @@ public:
}
return m_name.c_str();
}
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE();
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT16);
};
class ConvBiasImpl::AlgoF16Direct final : public AlgoBase {
......@@ -90,6 +90,10 @@ public:
virtual SmallVector<NCBKern> dispatch_kerns(
const NCBKernSizeParam& param) const override;
ConvAlgoTypePack get_algo_type() const override{
return {AlgoDataType::FLOAT16, AlgoCategory::DIRECT};
}
};
class ConvBiasImpl::AlgoF16DirectStride1 final : public AlgoBase {
......@@ -103,6 +107,9 @@ public:
size_t get_workspace(const NCBKernSizeParam& param) const override;
virtual SmallVector<NCBKern> dispatch_kerns(
const NCBKernSizeParam& param) const override;
ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::FLOAT16, AlgoCategory::DIRECT};
}
};
} // namespace arm_common
......
......@@ -29,7 +29,7 @@ public:
}
return m_name.c_str();
}
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE();
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32);
};
class ConvBiasImpl::AlgoFP32WinogradF63 final : public AlgoBase {
......@@ -44,7 +44,7 @@ public:
}
return m_name.c_str();
}
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE();
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32);
};
class ConvBiasImpl::AlgoFP32WinogradF63_4x4 final : public AlgoBase {
......@@ -59,7 +59,7 @@ public:
}
return m_name.c_str();
}
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE();
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32);
};
class ConvBiasImpl::AlgoFP32WinogradF54 final : public AlgoBase {
......@@ -74,7 +74,7 @@ public:
}
return m_name.c_str();
}
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE();
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32);
};
class ConvBiasImpl::AlgoFP32WinogradF45 final : public AlgoBase {
......@@ -89,7 +89,7 @@ public:
}
return m_name.c_str();
}
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE();
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32);
};
//===================== NCHW44 Winograd Support =====================//
......@@ -106,7 +106,7 @@ public:
}
return m_name.c_str();
}
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE();
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32);
};
class ConvBiasImpl::AlgoFP32WinogradF63_4x4_NCHW44 final : public AlgoBase {
......@@ -122,7 +122,7 @@ public:
}
return m_name.c_str();
}
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE();
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32);
};
class ConvBiasImpl::AlgoFP32WinogradF73_4x4_NCHW44 final : public AlgoBase {
......@@ -138,7 +138,7 @@ public:
}
return m_name.c_str();
}
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE();
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32);
};
// ================================================================= //
......@@ -154,6 +154,9 @@ public:
size_t get_workspace(const NCBKernSizeParam& param) const override;
virtual SmallVector<NCBKern> dispatch_kerns(
const NCBKernSizeParam& param) const override;
ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT};
}
};
class ConvBiasImpl::AlgoF32DirectStride1 final : public AlgoBase {
......@@ -168,6 +171,9 @@ public:
size_t get_workspace(const NCBKernSizeParam& param) const override;
virtual SmallVector<NCBKern> dispatch_kerns(
const NCBKernSizeParam& param) const override;
ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT};
}
};
class ConvBiasImpl::AlgoF32DirectStride2 final : public AlgoBase {
......@@ -182,6 +188,9 @@ public:
size_t get_workspace(const NCBKernSizeParam& param) const override;
virtual SmallVector<NCBKern> dispatch_kerns(
const NCBKernSizeParam& param) const override;
ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT};
}
};
class ConvBiasImpl::AlgoF32DirectNCHW44 final : public AlgoBase {
......@@ -197,6 +206,9 @@ public:
size_t get_workspace(const NCBKernSizeParam& param) const override;
virtual SmallVector<NCBKern> dispatch_kerns(
const NCBKernSizeParam& param) const override;
ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT};
}
};
class ConvBiasImpl::AlgoF32DirectNCHWNCHW44 final : public AlgoBase {
......@@ -212,6 +224,9 @@ public:
size_t get_workspace(const NCBKernSizeParam& param) const override;
virtual SmallVector<NCBKern> dispatch_kerns(
const NCBKernSizeParam& param) const override;
ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT};
}
};
class ConvBiasImpl::AlgoF32ChannelWiseNCHW44 final : public AlgoBase {
......@@ -226,6 +241,9 @@ public:
size_t get_workspace(const NCBKernSizeParam& param) const override;
virtual SmallVector<NCBKern> dispatch_kerns(
const NCBKernSizeParam& param) const override;
ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT};
}
};
} // namespace arm_common
......
......@@ -29,6 +29,10 @@ public:
const NCBKernSizeParam& param) const override;
bool is_preferred(const NCBKernSizeParam& param) const override;
ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT};
}
};
class ConvBiasImpl::AlgoS8DirectStride2 final : public AlgoBase {
......@@ -42,6 +46,9 @@ public:
size_t get_workspace(const NCBKernSizeParam& param) const override;
virtual SmallVector<NCBKern> dispatch_kerns(
const NCBKernSizeParam& param) const override;
ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT};
}
};
class ConvBiasImpl::AlgoS8DirectNCHW44 final : public AlgoBase {
......@@ -55,6 +62,9 @@ public:
virtual SmallVector<NCBKern> dispatch_kerns(
const NCBKernSizeParam& param) const override;
bool is_preferred(const NCBKernSizeParam& param) const override;
ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT};
}
};
class ConvBiasImpl::AlgoS8DirectNCHWNCHW44 final : public AlgoBase {
......@@ -68,6 +78,9 @@ public:
virtual SmallVector<NCBKern> dispatch_kerns(
const NCBKernSizeParam& param) const override;
bool is_preferred(const NCBKernSizeParam& param) const override;
ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT};
}
};
class ConvBiasImpl::AlgoS8ChanWiseStride1NCHW44 final : public AlgoBase {
......@@ -79,6 +92,9 @@ public:
size_t get_workspace(const NCBKernSizeParam& param) const override;
virtual SmallVector<NCBKern> dispatch_kerns(
const NCBKernSizeParam& param) const override;
ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT};
}
};
class ConvBiasImpl::AlgoS8ChanWiseStride2NCHW44 final : public AlgoBase {
......@@ -90,6 +106,9 @@ public:
size_t get_workspace(const NCBKernSizeParam& param) const override;
virtual SmallVector<NCBKern> dispatch_kerns(
const NCBKernSizeParam& param) const override;
ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT};
}
};
#if __ARM_FEATURE_DOTPROD
......@@ -104,6 +123,9 @@ public:
size_t get_workspace(const NCBKernSizeParam&) const override;
virtual SmallVector<NCBKern> dispatch_kerns(
const NCBKernSizeParam& param) const override;
ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT};
}
};
class ConvBiasImpl::AlgoDotS8DirectStride1 final : public AlgoBase {
......@@ -117,6 +139,9 @@ public:
size_t get_workspace(const NCBKernSizeParam&) const override;
virtual SmallVector<NCBKern> dispatch_kerns(
const NCBKernSizeParam& param) const override;
ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT};
}
};
class ConvBiasImpl::AlgoDotS8DirectStride2 final : public AlgoBase {
......@@ -131,6 +156,9 @@ public:
size_t get_workspace(const NCBKernSizeParam&) const override;
virtual SmallVector<NCBKern> dispatch_kerns(
const NCBKernSizeParam& param) const override;
ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT};
}
};
class ConvBiasImpl::AlgoDotS8Direct_NCHW44 final : public AlgoBase {
......@@ -148,6 +176,10 @@ public:
const NCBKernSizeParam& param) const override;
bool is_preferred(const NCBKernSizeParam& param) const override;
ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT};
}
};
#endif
......@@ -163,7 +195,7 @@ public:
}
return m_name.c_str();
}
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE();
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::QINT8X8X32);
};
//=======================input int8 compute fp32 output int8============
......@@ -180,7 +212,7 @@ public:
}
return m_name.c_str();
}
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE();
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::QINT8X8X32);
};
//=======================input int8 compute int16 output int8============
......@@ -198,7 +230,7 @@ public:
return m_name.c_str();
}
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE();
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::QINT8X8X32);
};
} // namespace arm_common
......
......@@ -36,6 +36,9 @@ public:
size_t get_workspace(const NCBKernSizeParam& param) const override;
virtual SmallVector<NCBKern> dispatch_kerns(
const NCBKernSizeParam& param) const override;
ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::INT8X8X16, AlgoCategory::DIRECT};
}
};
class ConvBiasImpl::AlgoS8x8x16DirectNCHW44 final : public AlgoBase {
......@@ -48,6 +51,9 @@ public:
size_t get_workspace(const NCBKernSizeParam& param) const override;
virtual SmallVector<NCBKern> dispatch_kerns(
const NCBKernSizeParam& param) const override;
ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::INT8X8X16, AlgoCategory::DIRECT};
}
};
class ConvBiasImpl::AlgoI8x8x16Stride2 final : public AlgoBase {
......@@ -71,6 +77,9 @@ public:
size_t get_workspace(const NCBKernSizeParam& param) const override;
virtual SmallVector<NCBKern> dispatch_kerns(
const NCBKernSizeParam& param) const override;
ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::INT8X8X16, AlgoCategory::DIRECT};
}
};
class ConvBiasImpl::AlgoI8x8x16Stride2Filter2 final : public AlgoBase {
......@@ -84,6 +93,9 @@ public:
size_t get_workspace(const NCBKernSizeParam& param) const override;
virtual SmallVector<NCBKern> dispatch_kerns(
const NCBKernSizeParam& param) const override;
ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::INT8X8X16, AlgoCategory::DIRECT};
}
};
class ConvBiasImpl::AlgoS8x8x16ChanWiseStride1Stride2NCHW44 final : public AlgoBase {
......@@ -96,6 +108,9 @@ public:
const NCBKernSizeParam& param) const override;
virtual SmallVector<NCBKern> dispatch_kerns(
const NCBKernSizeParam& param) const override;
ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::INT8X8X16, AlgoCategory::DIRECT};
}
};
class ConvBiasImpl::AlgoI8x8x16DirectNCHWNCHW44 final : public AlgoBase {
......@@ -111,6 +126,9 @@ public:
size_t get_workspace(const NCBKernSizeParam& param) const override;
virtual SmallVector<NCBKern> dispatch_kerns(
const NCBKernSizeParam& param) const override;
ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::INT8X8X16, AlgoCategory::DIRECT};
}
};
} // namespace arm_common
......
......@@ -10,6 +10,7 @@
* implied.
*/
#include "megdnn/opr_param_defs.h"
#include "src/arm_common/conv_bias/int8/algos.h"
#include "src/arm_common/conv_bias/int8x8x16/algos.h"
#include "src/arm_common/conv_bias/quint8/algos.h"
......@@ -122,9 +123,11 @@ public:
static CpuOprDelegationStorage<2> storage;
auto matmul_opr = storage.get<MatrixMul, 0>();
using MatmulFormat = param::MatrixMul::Format;
auto&& matmul_algos =
static_cast<arm_common::MatrixMulImpl*>(matmul_opr)
->algo_pack();
->select_algo_type(
{AlgoDataType::FLOAT32, MatmulFormat::MK4});
for (auto&& algo : matmul_algos) {
if (algo->type() == nullptr)
continue;
......@@ -133,38 +136,62 @@ public:
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
tile_size));
winograd_algos.emplace_back(refhold.back().get());
refhold.emplace_back(new AlgoFP32WinogradF63(
refhold.emplace_back(new AlgoFP32WinogradF63_4x4(
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
tile_size));
winograd_algos.emplace_back(refhold.back().get());
refhold.emplace_back(new AlgoFP32WinogradF63_4x4(
refhold.emplace_back(new AlgoFP32WinogradF63_4x4_NCHW44(
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
tile_size));
winograd_algos.emplace_back(refhold.back().get());
refhold.emplace_back(new AlgoFP32WinogradF54(
refhold.emplace_back(new AlgoFP32WinogradF23_4x4_NCHW44(
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
tile_size));
winograd_algos.emplace_back(refhold.back().get());
refhold.emplace_back(new AlgoFP32WinogradF45(
//! uncomment this when low precision mode is done
#if 0
refhold.emplace_back(new AlgoFP32WinogradF73_4x4_NCHW44(
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
tile_size));
winograd_algos.emplace_back(refhold.back().get());
refhold.emplace_back(new AlgoFP32WinogradF23_4x4_NCHW44(
#endif
//! Qint8x8x32 winograd compute with fp32
refhold.emplace_back(new AlgoS8CF32WinogradF23_4x4_NCHW44(
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
tile_size));
winograd_algos.emplace_back(refhold.back().get());
refhold.emplace_back(new AlgoFP32WinogradF63_4x4_NCHW44(
}
}
matmul_algos = static_cast<arm_common::MatrixMulImpl*>(matmul_opr)
->select_algo_type({AlgoDataType::FLOAT32,
MatmulFormat::DEFAULT});
for (auto&& algo : matmul_algos) {
if (algo->type() == nullptr)
continue;
for (uint32_t tile_size : {16, 8, 24, 32}) {
refhold.emplace_back(new AlgoFP32WinogradF63(
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
tile_size));
winograd_algos.emplace_back(refhold.back().get());
//! uncomment this when low precision mode is done
#if 0
refhold.emplace_back(new AlgoFP32WinogradF73_4x4_NCHW44(
refhold.emplace_back(new AlgoFP32WinogradF54(
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
tile_size));
winograd_algos.emplace_back(refhold.back().get());
#endif
refhold.emplace_back(new AlgoFP32WinogradF45(
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
tile_size));
winograd_algos.emplace_back(refhold.back().get());
}
}
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
matmul_algos = static_cast<arm_common::MatrixMulImpl*>(matmul_opr)
->select_algo_type({AlgoDataType::FLOAT16,
MatmulFormat::DEFAULT});
for (auto&& algo : matmul_algos) {
if (algo->type() == nullptr)
continue;
for (uint32_t tile_size : {16, 8, 24, 32}) {
refhold.emplace_back(new AlgoFP16WinogradF23(
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
tile_size));
......@@ -177,19 +204,33 @@ public:
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
tile_size));
winograd_algos.emplace_back(refhold.back().get());
}
}
matmul_algos = static_cast<arm_common::MatrixMulImpl*>(matmul_opr)
->select_algo_type({AlgoDataType::FLOAT16,
MatmulFormat::MK8});
for (auto&& algo : matmul_algos) {
if (algo->type() == nullptr)
continue;
for (uint32_t tile_size : {16, 8, 24, 32}) {
refhold.emplace_back(new AlgoFP16WinogradF23_8x8(
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
tile_size));
winograd_algos.emplace_back(refhold.back().get());
}
}
#endif
matmul_algos = static_cast<arm_common::MatrixMulImpl*>(matmul_opr)
->select_algo_type({AlgoDataType::INT16X16X32,
MatmulFormat::MK8});
for (auto&& algo : matmul_algos) {
if (algo->type() == nullptr)
continue;
for (uint32_t tile_size : {16, 8, 24, 32}) {
refhold.emplace_back(new AlgoS8WinogradF23_8x8(
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
tile_size));
winograd_algos.emplace_back(refhold.back().get());
refhold.emplace_back(new AlgoS8CF32WinogradF23_4x4_NCHW44(
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
tile_size));
winograd_algos.emplace_back(refhold.back().get());
refhold.emplace_back(new AlgoS8WinogradF23_8x8_NCHW44(
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
tile_size));
......@@ -240,6 +281,42 @@ bool ConvBiasImpl::is_matmul_quantized_prefer(
return conv_direct_unusable;
}
SmallVector<AlgoCategory> ConvBiasImpl::suggest_algo_category_order(
const NCBKernSizeParam& param) const {
auto IC = param.filter_meta.icpg;
auto OC = param.filter_meta.ocpg;
auto FH = param.filter_meta.spatial[0];
auto FW = param.filter_meta.spatial[1];
//! TODO: now winograd only support fast-run
if (param.filter_meta.format == param::ConvBias::Format::NCHW_WINOGRAD ||
param.filter_meta.format == param::ConvBias::Format::NCHW44_WINOGRAD ||
param.filter_meta.format == param::ConvBias::Format::NCHW88_WINOGRAD) {
return {AlgoCategory::WINOGRAD};
}
//! im2col
bool im2col_prefer = (IC >= 32 || OC >= 32);
//! quantized algo use matmul when direct algo is unusable
if (param.src_type.category() == DTypeCategory::QUANTIZED) {
im2col_prefer = is_matmul_quantized_prefer(param);
}
//! conv1x1
im2col_prefer |= (FH == 1 && FW == 1);
//! nchw44 and nchw44-dot hybird mode is direct
if (param.filter_meta.format == param::ConvBias::Format::NCHW44 ||
param.filter_meta.format == param::ConvBias::Format::NCHW44_DOT) {
if (IC < 4) {
im2col_prefer = false;
}
}
if (im2col_prefer) {
return {AlgoCategory::IM2COL, AlgoCategory::DIRECT,
AlgoCategory::NAIVE};
} else {
return {AlgoCategory::DIRECT, AlgoCategory::IM2COL,
AlgoCategory::NAIVE};
}
}
const char* ConvBiasImpl::get_algorithm_set_name() const {
// arm common version 0
return "AC0";
......
......@@ -28,6 +28,9 @@ public:
bool is_matmul_quantized_prefer(
const ConvBiasImpl::NCBKernSizeParam& ncb_param) const override;
SmallVector<AlgoCategory> suggest_algo_category_order(
const NCBKernSizeParam& param) const override;
class AlgoPack;
protected:
......@@ -90,7 +93,7 @@ private:
class AlgoF16Direct;
class AlgoF16DirectStride1;
#endif
};
};
} // namespace arm_common
} // namespace megdnn
......
......@@ -29,6 +29,9 @@ public:
size_t get_workspace(const NCBKernSizeParam& param) const override;
virtual SmallVector<NCBKern> dispatch_kerns(
const NCBKernSizeParam& param) const override;
ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::QUINT8X8X32, AlgoCategory::DIRECT};
}
};
class ConvBiasImpl::AlgoQU8DirectStride2 final : public AlgoBase {
......@@ -42,6 +45,9 @@ public:
size_t get_workspace(const NCBKernSizeParam& param) const override;
virtual SmallVector<NCBKern> dispatch_kerns(
const NCBKernSizeParam& param) const override;
ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::QUINT8X8X32, AlgoCategory::DIRECT};
}
};
#if __ARM_FEATURE_DOTPROD
class ConvBiasImpl::AlgoDotU8DirectStride1 final : public AlgoBase {
......@@ -56,6 +62,9 @@ public:
size_t get_workspace(const NCBKernSizeParam& param) const override;
virtual SmallVector<NCBKern> dispatch_kerns(
const NCBKernSizeParam& param) const override;
ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::QUINT8X8X32, AlgoCategory::DIRECT};
}
};
class ConvBiasImpl::AlgoDotU8DirectStride2 final : public AlgoBase {
......@@ -69,6 +78,9 @@ public:
size_t get_workspace(const NCBKernSizeParam& param) const override;
virtual SmallVector<NCBKern> dispatch_kerns(
const NCBKernSizeParam& param) const override;
ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::QUINT8X8X32, AlgoCategory::DIRECT};
}
};
#endif
} // namespace arm_common
......
......@@ -26,7 +26,7 @@ public:
kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4)
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4, AlgoDataType::INT8X8X16, DEFAULT)
};
class MatrixMulImpl::AlgoInt8x8x32Gemv : public AlgoBase {
......@@ -40,7 +40,7 @@ public:
void* type() const override { return sm_arm_common_algo_type; }
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2)
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::QINT8X8X32, DEFAULT)
};
class MatrixMulImpl::AlgoInt8x8x32GemvMK4 : public AlgoBase {
......@@ -54,7 +54,7 @@ public:
void* type() const override { return sm_arm_common_algo_type; }
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2)
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::QINT8X8X32, MK4)
};
#if __ARM_FEATURE_DOTPROD
......@@ -69,7 +69,7 @@ public:
void* type() const override { return sm_arm_common_algo_type; }
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2)
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::QINT8X8X32, MK4_DOT)
};
#endif
......@@ -87,7 +87,7 @@ public:
void* type() const override { return sm_arm_common_algo_type; }
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4)
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4, AlgoDataType::FLOAT32, DEFAULT)
};
class MatrixMulImpl::AlgoF32GemvMK4 : public AlgoBase {
......@@ -101,7 +101,7 @@ public:
void* type() const override { return sm_arm_common_algo_type; }
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(4, 1, 1, 4)
MEGDNN_OVERRIDE_MATMUL_DESC(4, 1, 1, 4, AlgoDataType::FLOAT32, MK4)
};
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
......@@ -116,7 +116,7 @@ public:
void* type() const override { return sm_arm_common_algo_type; }
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2)
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::FLOAT16, DEFAULT)
};
#endif
......@@ -131,7 +131,13 @@ public:
void* type() const override { return sm_arm_common_algo_type; }
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(1, 1, 1, 4)
MEGDNN_OVERRIDE_MATMUL_DESC(
1, 1, 1, 4,
static_cast<AlgoDataType>(
static_cast<uint32_t>(AlgoDataType::FLOAT16) |
static_cast<uint32_t>(AlgoDataType::FLOAT32) |
static_cast<uint32_t>(AlgoDataType::QINT8X8X32)),
DEFAULT)
};
} // namespace arm_common
......
......@@ -25,7 +25,7 @@ void* const MatrixMulImpl::sm_arm_common_algo_type =
class MatrixMulImpl::AlgoPack : NonCopyableObj {
AlgoInt8x8x16 int8x8x16;
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
AlgoF16Gemv f16gemv;
AlgoF16Gemv f16gemv;
#endif
AlgoInt8x8x32Gemv int8x8x32_gemv;
AlgoInt8x8x32GemvMK4 int8x8x32_gemv_mk4;
......@@ -34,10 +34,11 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj {
#endif
AlgoGevm gevm;
AlgoF32GemvMK4 f32_gemv_mk4;
public:
AlgoPack() {
all_algos.emplace_back(&int8x8x16);
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
all_algos.emplace_back(&f16gemv);
#endif
#if __ARM_FEATURE_DOTPROD
......@@ -47,7 +48,7 @@ public:
all_algos.emplace_back(&int8x8x32_gemv_mk4);
all_algos.emplace_back(&f32_gemv_mk4);
all_algos.emplace_back(&gevm);
}
}
SmallVector<AlgoBase*> all_algos;
};
......
......@@ -37,6 +37,9 @@ public:
size_t group = param.filter_meta.group;
return {{kimpl, {group, 1_z, 1_z}}};
}
ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::QINT8X8X32, AlgoCategory::IM2COL};
}
};
} // namespace armv7
......
......@@ -38,6 +38,10 @@ public:
size_t group = param.filter_meta.group;
return {{kimpl, {group, 1_z, 1_z}}};
}
ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::QUINT8X8X32, AlgoCategory::IM2COL};
}
};
} // namespace armv7
......
......@@ -85,7 +85,8 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32::get_kern(
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF32, megdnn_armv7_matmul_kern,
"AlgoF32Impl"_hash,
armv7::matmul::sgemm_4x12, float, float);
armv7::matmul::sgemm_4x12, float, float,
AlgoDataType::FLOAT32, DEFAULT);
/* ===================== F32 algo mk4 K4x12 ===================== */
......@@ -154,7 +155,7 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF32MK4Pack4x12,
megdnn_armv7_matmul_kern,
"AlgoF32MK4Pack4x12"_hash,
armv7::matmul::sgemm_mk4_pack_4x12, float,
float);
float, AlgoDataType::FLOAT32, MK4);
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
/* ===================== F16 K4x16x1 algo ===================== */
......@@ -215,7 +216,8 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF16K4x16x1::get_kern(
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF16K4x16x1, megdnn_armv7_matmul_kern,
"AlgoF16K4x16x1"_hash,
armv7::matmul::hgemm_4x16, dt_float16,
dt_float16);
dt_float16, AlgoDataType::FLOAT16,
DEFAULT);
#endif
......@@ -280,7 +282,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32K4x2x16,
megdnn_armv7_matmul_kern,
"AlgoInt8x8x32K4x2x16"_hash,
armv7::matmul::gemm_s8_4x2, int8_t,
int32_t);
int32_t, AlgoDataType::QINT8X8X32,
DEFAULT);
/* ===================== Int8x8x32 Kernel 4x8x8 algo ===================== */
namespace {
......@@ -342,7 +345,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32K4x8x8,
megdnn_armv7_matmul_kern,
"AlgoInt8x8x32K4x8x8"_hash,
armv7::matmul::gemm_s8_4x8, int8_t,
int32_t);
int32_t, AlgoDataType::QINT8X8X32,
DEFAULT);
/* ===================== Quint8 Kernel 4x8x8 algo ===================== */
namespace {
......@@ -402,7 +406,8 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoQuint8K4x8x8::get_kern(
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoQuint8K4x8x8, megdnn_armv7_matmul_kern,
"AlgoQuint8K4x8x8"_hash,
armv7::matmul::gemm_u8_4x8, uint8_t,
int32_t);
int32_t, AlgoDataType::QUINT8X8X32,
DEFAULT);
/* ===================== Int8x8x16 Kernel 2x4x16 algo ===================== */
namespace {
......@@ -468,7 +473,7 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x16K4x2x16,
megdnn_armv7_matmul_kern,
"AlgoInt8x8x16K4x2x16"_hash,
armv7::matmul::gemm_s8x8x16_4x2, int8_t,
int16_t);
int16_t, AlgoDataType::INT8X8X16, DEFAULT);
/* ===================== Int8x8x16 Kernel 4x8x8 algo ===================== */
namespace {
......@@ -534,7 +539,7 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x16K4x8x8,
megdnn_armv7_matmul_kern,
"AlgoInt8x8x16K4x8x8"_hash,
armv7::matmul::gemm_s8x8x16_4x8, int8_t,
int16_t);
int16_t, AlgoDataType::INT8X8X16, DEFAULT);
/* =================== Int8x8x16 Kernel MK4 8x8x4 algo ===================*/
......@@ -602,7 +607,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL(AlgoInt8x8x16MK4_8x8x4,
megdnn_armv7_matmul_kern,
"AlgoInt8x8x16MK4_8x8x4"_hash,
armv7::matmul::gemm_s8x8x16_mk4_8x8,
int8_t, int16_t, int16_t);
int8_t, int16_t, int16_t,
AlgoDataType::INT8X8X16, MK4);
/* ===================== Int16x16x32 Kernel 12x4x1 algo ===================== */
......@@ -668,7 +674,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt16x16x32K12x4x1,
megdnn_armv7_matmul_kern,
"AlgoInt16x16x32K12x4x1"_hash,
armv7::matmul::gemm_s16x16x32_12x4,
int16_t, int32_t);
int16_t, int32_t,
AlgoDataType::INT16X16X32, DEFAULT);
#if __ARM_FEATURE_DOTPROD
/* ===================== Int8 K6x8x4 algo ===================== */
namespace {
......@@ -724,7 +731,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32K6x8x4,
megdnn_armv7_matmul_kern,
"AlgoInt8x8x32K6x8x4"_hash,
armv7::matmul::gemm_dots8_6x8, int8_t,
int32_t);
int32_t, AlgoDataType::QINT8X8X32,
DEFAULT);
/* ===================== Quint8 K4x8x4 algo ===================== */
namespace {
void quint8_dot_k4x8x4_kern(const MatrixMulImpl::KernParam& kern_param) {
......@@ -786,7 +794,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoQuint8DotK4x8x4,
megdnn_armv7_matmul_kern,
"AlgoQuint8DotK4x8x4"_hash,
armv7::matmul::gemm_dot_quint8_4x8,
uint8_t, int32_t);
uint8_t, int32_t,
AlgoDataType::QUINT8X8X32, DEFAULT);
/* ======================== Int8 MK4 8x4x4 dot algo ======================== */
namespace {
......@@ -854,7 +863,7 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32MK4_8x4x4DotProd,
megdnn_armv7_matmul_kern,
"AlgoInt8x8x32MK4_8x4x4DotProd"_hash,
armv7::matmul::gemm_mk4_dots8_8x4, int8_t,
int32_t);
int32_t, AlgoDataType::QINT8X8X32, MK4_DOT);
#endif
/* ===================== F32 algo K4x8 ===================== */
......@@ -1099,6 +1108,6 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32MK4_4x2x16,
megdnn_armv7_matmul_kern,
"AlgoInt8x8x32MK4_4x2x16"_hash,
armv7::matmul::gemm_mk4_s8_4x2, int8_t,
int32_t);
int32_t, AlgoDataType::QINT8X8X32, MK4);
// vim: syntax=cpp.doxygen
......@@ -50,7 +50,7 @@ public:
kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(4, 8, 4, 4)
MEGDNN_OVERRIDE_MATMUL_DESC(4, 8, 4, 4, AlgoDataType::FLOAT32, MK4)
};
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
......@@ -73,7 +73,7 @@ public:
kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(4, 8, 8, 2)
MEGDNN_OVERRIDE_MATMUL_DESC(4, 8, 8, 2, AlgoDataType::FLOAT16, MK8)
};
#endif
#if __ARM_FEATURE_DOTPROD
......@@ -205,7 +205,7 @@ public:
kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(4, 8, 8, 2)
MEGDNN_OVERRIDE_MATMUL_DESC(4, 8, 8, 2, AlgoDataType::INT16X16X32, MK8)
};
class MatrixMulImpl::AlgoInt8x8x32MK4_4x2x16 final : public AlgoBase {
......
......@@ -18,7 +18,6 @@ namespace armv7 {
class MatrixMulImpl : public arm_common::MatrixMulImpl {
public:
using arm_common::MatrixMulImpl::MatrixMulImpl;
SmallVector<AlgoBase*> algo_pack() override;
private:
......
......@@ -110,6 +110,11 @@ void __log__(LogLevel level, const char* file, const char* func, int line,
} while (0)
#endif // megdnn_ENABLE_LOGGING
template <typename T>
constexpr int32_t cast_int(T data) {
return static_cast<int32_t>(data);
}
/* helper functions */
/**
* \brief Get the next `stride' index lexicographically.
......@@ -187,6 +192,29 @@ std::unique_ptr<T> make_unique(Args&&... args) {
return std::unique_ptr<T>(new T(std::forward<Args>(args)...));
}
/*!
* \brief check whether the source enum contain the target data type enum
*/
bool inline contain_data_type(detail::AlgoDataType source,
detail::AlgoDataType target) {
return static_cast<bool>(static_cast<uint32_t>(source) &
static_cast<uint32_t>(target));
}
/*!
* \brief get the source enum contain the data type number
*/
template<typename T>
size_t nr_type_contain(T index) {
uint32_t sr_index = static_cast<uint32_t>(index);
size_t nr_type = 0;
while (sr_index != 0) {
nr_type++;
sr_index &= (sr_index - 1);
}
return nr_type;
}
/**
* \brief Aligned workspace bundle.
*
......
......@@ -26,6 +26,16 @@ public:
AlgoSelectionStrategy algo_selection_strategy) const override;
size_t get_workspace(const NCBKernSizeParam& param) const override;
SmallVector<NCBKern> dispatch_kerns(const NCBKernSizeParam&) const override;
ConvAlgoTypePack get_algo_type() const override {
auto support_data_type = static_cast<AlgoDataType>(
static_cast<uint32_t>(AlgoDataType::FLOAT16) |
static_cast<uint32_t>(AlgoDataType::FLOAT32) |
static_cast<uint32_t>(AlgoDataType::INT8X8X16) |
static_cast<uint32_t>(AlgoDataType::QINT8X8X32) |
static_cast<uint32_t>(AlgoDataType::QUINT8X8X32));
return {support_data_type, AlgoCategory::NAIVE};
}
};
class ConvBiasImpl::AlgoWinogradF32 final : public AlgoBase {
......@@ -46,6 +56,10 @@ public:
size_t get_workspace(const NCBKernSizeParam& param) const override;
SmallVector<NCBKern> dispatch_kerns(const NCBKernSizeParam&) const override;
ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::FLOAT32, AlgoCategory::WINOGRAD};
}
private:
MatrixMulImpl::AlgoBase* m_matmul_algo;
mutable std::string m_name;
......@@ -70,6 +84,10 @@ public:
size_t get_workspace(const NCBKernSizeParam& param) const override;
SmallVector<NCBKern> dispatch_kerns(const NCBKernSizeParam&) const override;
ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::FLOAT32, AlgoCategory::WINOGRAD};
}
private:
MatrixMulImpl::AlgoBase* m_matmul_algo;
mutable std::string m_name;
......@@ -94,6 +112,10 @@ public:
size_t get_workspace(const NCBKernSizeParam& param) const override;
SmallVector<NCBKern> dispatch_kerns(const NCBKernSizeParam&) const override;
ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::QINT8X8X32, AlgoCategory::WINOGRAD};
}
private:
MatrixMulImpl::AlgoBase* m_matmul_algo;
mutable std::string m_name;
......@@ -118,6 +140,10 @@ public:
size_t get_workspace(const NCBKernSizeParam& param) const override;
SmallVector<NCBKern> dispatch_kerns(const NCBKernSizeParam&) const override;
ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::QINT8X8X32, AlgoCategory::WINOGRAD};
}
private:
MatrixMulImpl::AlgoBase* m_matmul_algo;
mutable std::string m_name;
......
......@@ -140,7 +140,7 @@ using BiasMode = ConvBiasForward::BiasMode;
break; \
}
#define MEGDNN_WINOGRAD_ALGO_FUN_DECLARE() \
#define MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(_algo_data_type) \
bool is_reproducible() const override { return true; } \
bool usable(const NCBKernSizeParam& param, \
AlgoSelectionStrategy algo_selection_strategy) const override; \
......@@ -153,6 +153,9 @@ using BiasMode = ConvBiasForward::BiasMode;
const override; \
virtual SmallVector<NCBKern> dispatch_preprocess_kerns( \
const NCBKernSizeParam& param) const override; \
ConvAlgoTypePack get_algo_type() const override { \
return {_algo_data_type, AlgoCategory::WINOGRAD}; \
} \
\
private: \
fallback::MatrixMulImpl::AlgoBase* m_matmul_algo; \
......
......@@ -288,7 +288,8 @@ bool ConvBiasImpl::AlgoConv1x1::is_preferred(
size_t OH = param.osz[0];
size_t OW = param.osz[1];
if (OH * OW != 1) {
return true;
return m_matmul_algo->algoset() !=
MatrixMulImpl::AlgoBase::AlgoSet::ALGO_TYPE_GEMV;
} else {
#if (MEGDNN_ARMV7 || MEGDNN_AARCH64)
if (param.src_type.enumv() == DTypeEnum::Int8 &&
......
......@@ -56,6 +56,11 @@ public:
SmallVector<NCBKern> dispatch_preprocess_kerns(
const NCBKernSizeParam& param) const override;
ConvAlgoTypePack get_algo_type() const override{
return {m_matmul_algo->matmul_description().algo_type.data_type,
AlgoCategory::IM2COL};
}
protected:
size_t get_oc_tile_size_heuristic(const NCBKernSizeParam& param) const;
......
......@@ -34,6 +34,16 @@ public:
bool is_preferred(const NCBKernSizeParam&) const override;
ConvAlgoTypePack get_algo_type() const override {
auto support_data_type = static_cast<AlgoDataType>(
static_cast<uint32_t>(AlgoDataType::FLOAT16) |
static_cast<uint32_t>(AlgoDataType::FLOAT32) |
static_cast<uint32_t>(AlgoDataType::INT8X8X16) |
static_cast<uint32_t>(AlgoDataType::QINT8X8X32) |
static_cast<uint32_t>(AlgoDataType::QUINT8X8X32));
return {support_data_type, AlgoCategory::IM2COL};
}
protected:
size_t get_oc_tile_size_heuristic(const NCBKernSizeParam& param) const;
};
......
......@@ -48,15 +48,25 @@ public:
SmallVector<NCBKern> dispatch_preprocess_kerns(
const NCBKernSizeParam& param) const override;
bool is_preferred(const NCBKernSizeParam& param) const override {
if (param.src_type.category() == DTypeCategory::QUANTIZED) {
static CpuOprDelegationStorage<1> storage;
auto conv_bias_opr = storage.get<ConvBias, 0>();
return static_cast<ConvBiasImpl*>(conv_bias_opr)
->is_matmul_quantized_prefer(param);
size_t OH = param.osz[0];
size_t OW = param.osz[1];
//! gemm and oh * ow > 1 is prefer
//! gemv and oh * ow == 1 is prefer
if ((m_matmul_algo->algoset() !=
MatrixMulImpl::AlgoBase::AlgoSet::ALGO_TYPE_GEMV &&
OH * OW > 1) ||
(m_matmul_algo->algoset() ==
MatrixMulImpl::AlgoBase::AlgoSet::ALGO_TYPE_GEMV &&
OH * OW == 1)) {
return true;
} else {
return false;
}
auto&& fm = param.filter_meta;
auto OC = fm.ocpg, IC = fm.icpg;
return OC >= 32 || IC >= 32;
}
ConvAlgoTypePack get_algo_type() const override {
return {m_matmul_algo->matmul_description().algo_type.data_type,
AlgoCategory::IM2COL};
}
private:
......
......@@ -48,11 +48,26 @@ void incr_ptr(T*& dst, ptrdiff_t delta) {
} // namespace
#if MEGDNN_X86
#define SKIP_GEMV()
//! As we haven't direct conv for int8x8x16 yet, if we disable gemv here, it may
//! fallback to naive implementation, which may cause performance very low, so
//! here we just enable im2col for gemv in x86 backend.
//! FIXME: remove it when we add direct conv support for int8x8x16
#else
#define SKIP_GEMV() \
if (algo->algoset() == MatrixMulImpl::AlgoBase::AlgoSet::ALGO_TYPE_GEMV) { \
continue; \
}
#endif
class ConvBiasImpl::AlgoPack : NonCopyableObj {
AlgoNaive algo_naive;
SmallVector<std::unique_ptr<AlgoBase>> refhold;
public:
AlgoPack() {
refhold.emplace_back(new AlgoConv1x1Gemv());
all_algos.emplace_back(refhold.back().get());
......@@ -110,8 +125,6 @@ public:
all_algos.emplace_back(refhold.back().get());
#endif
}
//! reverse matmul algo, when the algo is_prefer can be selected first
std::reverse(all_algos.begin(), all_algos.end());
all_algos.emplace_back(&algo_naive);
}
SmallVector<AlgoBase*> all_algos;
......@@ -121,6 +134,22 @@ SmallVector<ConvBiasImpl::AlgoBase*> ConvBiasImpl::algo_pack() {
static AlgoPack sl_algo_pack;
return sl_algo_pack.all_algos;
}
SmallVector<ConvBiasImpl::AlgoBase*> ConvBiasImpl::select_algo_type(
ConvAlgoTypePack target_type) {
megdnn_assert(nr_type_contain(target_type.data_type),
"ConvBias algo selection only support one type");
SmallVector<ConvBiasImpl::AlgoBase*> algos;
for (auto&& algo : algo_pack()) {
auto algo_type = algo->get_algo_type();
if (contain_data_type(algo_type.data_type, target_type.data_type) &&
algo_type.algo_category == target_type.algo_category) {
algos.push_back(algo);
}
}
return algos;
}
bool ConvBiasImpl::is_naive_algo(ConvBiasImpl::Algorithm* algo) {
return algo == nullptr || strcmp(algo->name(), "DEFAULT") == 0;
}
......@@ -248,12 +277,32 @@ ConvBiasImpl::Algorithm* ConvBiasImpl::get_algorithm_heuristic(
ConvBiasImpl::Algorithm* ConvBiasImpl::get_algorithm_heuristic_with_ncb(
const NCBKernSizeParam& param, size_t workspace_limit_in_bytes,
bool reproducible) {
for (auto i : get_all_algorithms_with_ncb(param)) {
if (static_cast<AlgoBase*>(i)->usable_reproducible(
param, AlgoSelectionStrategy::HEURISTIC, reproducible) &&
NCB_ALGO_FUNC(get_workspace, i, param) <=
workspace_limit_in_bytes) {
return i;
auto algo_data_type = param.deduce_algo_data_type();
auto suggest_category_order = suggest_algo_category_order(param);
for (auto category : suggest_category_order) {
auto&& origin_algos = select_algo_type({algo_data_type, category});
ConvBiasImpl::Algorithm* heuristic_algo = nullptr;
for (auto i : origin_algos) {
bool usable_reproducible =
static_cast<AlgoBase*>(i)->usable_reproducible(
param, AlgoSelectionStrategy::HEURISTIC,
reproducible);
if (usable_reproducible &&
static_cast<AlgoBase*>(i)->get_workspace(param) <=
workspace_limit_in_bytes) {
//! store the first usable algo if no prefer algo, choose it as
//! the target algo
if (!heuristic_algo) {
heuristic_algo = i;
}
//! choose the first prefer algo
if (i->is_preferred(param)) {
return i;
}
}
}
if (heuristic_algo) {
return heuristic_algo;
}
}
return nullptr;
......@@ -300,9 +349,8 @@ ConvBiasImpl::NCBKernSizeParam ConvBiasImpl::make_ncb_kern_size_param(
sizeof(ConvolutionImpl::CanonizedFilterMeta),
"sizeof CanonizedFilterMeta in convolution and conv_bias "
"should be equal");
CanonizedFilterMeta fm = check_layout_fwd(src, filter, dst);
ConvolutionImpl::CanonizedFilterMeta conv_fm;
conv_fm.copy_from(fm);
auto&& fm = check_layout_fwd(src, filter, dst);
auto& conv_fm = reinterpret_cast<ConvolutionImpl::CanonizedFilterMeta&>(fm);
param::MatrixMul::Format format = param::MatrixMul::Format::DEFAULT;
if (param().format == Param::Format::NCHW_WINOGRAD ||
......@@ -367,7 +415,7 @@ ConvBiasImpl::NCBKernParam ConvBiasImpl::make_ncb_kern_param(
void ConvBiasImpl::exec_with_ncb_kern(const NCBKernParam& param,
ConvBiasImpl::Algorithm* algo) {
auto ncb_kerns = NCB_ALGO_FUNC(dispatch_kerns, algo, param);
auto&& ncb_kerns = NCB_ALGO_FUNC(dispatch_kerns, algo, param);
for (auto&& kernel : ncb_kerns) {
auto run = [kernel, param](size_t index, size_t thread_id) {
CpuNDRange ndrange_id(kernel.global_size, index);
......@@ -380,7 +428,7 @@ void ConvBiasImpl::exec_with_ncb_kern(const NCBKernParam& param,
void ConvBiasImpl::exec_preprocess_with_ncb_kern(
const NCBKernParam& param, ConvBiasImpl::Algorithm* algo) {
auto ncb_kerns = NCB_ALGO_FUNC(dispatch_preprocess_kerns, algo, param);
auto&& ncb_kerns = NCB_ALGO_FUNC(dispatch_preprocess_kerns, algo, param);
for (auto&& kernel : ncb_kerns) {
auto run = [kernel, param](size_t index, size_t thread_id) {
CpuNDRange ndrange_id(kernel.global_size, index);
......@@ -405,7 +453,6 @@ std::vector<ConvBiasImpl::Algorithm*> ConvBiasImpl::get_all_algorithms_with_ncb(
}
}
}
std::reverse(prefer_algos.begin(), prefer_algos.end());
//! Prefer algo inserted from begin
algos.insert(algos.begin(), prefer_algos.begin(), prefer_algos.end());
return algos;
......@@ -425,6 +472,35 @@ ConvBiasImpl::Algorithm* ConvBiasImpl::get_algorithm(
return m_prev_selected_algo;
}
SmallVector<AlgoCategory> ConvBiasImpl::suggest_algo_category_order(
const NCBKernSizeParam& param) const {
auto IC = param.filter_meta.icpg;
auto OC = param.filter_meta.ocpg;
auto FH = param.filter_meta.spatial[0];
auto FW = param.filter_meta.spatial[1];
//! TODO: now winograd only support in fast-run
if (param.filter_meta.format == param::ConvBias::Format::NCHW_WINOGRAD ||
param.filter_meta.format == param::ConvBias::Format::NCHW44_WINOGRAD ||
param.filter_meta.format == param::ConvBias::Format::NCHW88_WINOGRAD) {
return {AlgoCategory::WINOGRAD};
}
//! im2col + matmul
bool im2col_prefer = (IC >= 32 || OC >= 32);
//! quantized algo use matmul when direct algo is unusable
if (param.src_type.category() == DTypeCategory::QUANTIZED) {
im2col_prefer = is_matmul_quantized_prefer(param);
}
//! conv1x1
im2col_prefer |= (FH == 1 && FW == 1);
if (im2col_prefer) {
return {AlgoCategory::IM2COL, AlgoCategory::DIRECT,
AlgoCategory::NAIVE};
} else {
return {AlgoCategory::DIRECT, AlgoCategory::IM2COL,
AlgoCategory::NAIVE};
}
}
const char* ConvBiasImpl::get_algorithm_set_name() const {
// fallback version 0
return "F0";
......
......@@ -18,6 +18,8 @@
#include "src/fallback/matrix_mul/opr_impl.h"
#include "src/naive/conv_bias/opr_impl.h"
#include <unordered_map>
namespace megdnn {
namespace fallback {
......@@ -44,6 +46,7 @@ class ConvBiasImpl : public naive::ConvBiasForwardImpl {
public:
using naive::ConvBiasForwardImpl::ConvBiasForwardImpl;
using AlgoSelectionStrategy = detail::AlgoSelectionStrategy;
using AlgoDataType = detail::AlgoDataType;
//! implemented by exec_with_ncb_kern()
void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter,
......@@ -94,6 +97,8 @@ public:
size_t workspace_limit_in_bytes,
bool reproducible) override;
//! size param for kernels with non-contiguous batch
struct NCBKernSizeParam : ConvolutionImpl::NCBKernSizeParam {
NCBKernSizeParam() = default;
......@@ -244,6 +249,9 @@ public:
return (!reproducible || is_reproducible()) &&
usable(param, algo_selection_strategy);
}
//! get the type of the algo
virtual ConvAlgoTypePack get_algo_type() const = 0;
};
/**
......@@ -251,6 +259,17 @@ public:
*/
virtual SmallVector<AlgoBase*> algo_pack();
/**
* \brief select algo according to input algo type
*/
SmallVector<AlgoBase*> select_algo_type(ConvAlgoTypePack algo_type);
/**
* \brief suggest algo category according to the param
*/
virtual SmallVector<AlgoCategory> suggest_algo_category_order(
const NCBKernSizeParam& param) const;
protected:
virtual void exec_with_ncb_kern(const NCBKernParam& param,
ConvBiasImpl::Algorithm* algo);
......
......@@ -83,6 +83,10 @@ public:
SmallVector<NCBKern> dispatch_kern(
const NCBKernSizeParam& /*param*/) const override;
ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::FLOAT32, AlgoCategory::NAIVE};
}
};
class ConvolutionImpl::AlgoNaive final : public AlgoBase {
......@@ -96,11 +100,17 @@ public:
SmallVector<NCBKern> dispatch_kern(
const NCBKernSizeParam& /*param*/) const override;
ConvAlgoTypePack get_algo_type() const override {
auto support_data_type = static_cast<AlgoDataType>(
static_cast<uint32_t>(AlgoDataType::INT8X8X16) |
static_cast<uint32_t>(AlgoDataType::QINT8X8X32) |
static_cast<uint32_t>(AlgoDataType::QUINT8X8X32));
return {support_data_type, AlgoCategory::NAIVE};
}
};
class ConvolutionImpl::AlgoDefault final : public AlgoBase {
static ConvBiasImpl::NCBKernSizeParam init_conv_bias_param(
const NCBKernSizeParam& param);
WorkspaceBundle get_bundle(const NCBKernSizeParam& param) const;
static SmallVector<NCBKern> get_kimpl(ConvBiasImpl::AlgoBase* algo,
const NCBKernSizeParam& param);
......@@ -136,6 +146,13 @@ public:
//! select matmul to the highest preference
bool is_preferred(const NCBKernSizeParam& param) const override;
static ConvBiasImpl::NCBKernSizeParam init_conv_bias_param(
const NCBKernSizeParam& param);
ConvAlgoTypePack get_algo_type() const override {
return m_algorithm->get_algo_type();
}
private:
std::string m_name;
ConvBiasImpl::AlgoBase* m_algorithm;
......
......@@ -23,6 +23,7 @@
#include "midout.h"
#include <cstring>
#include <unordered_map>
MIDOUT_DECL(megdnn_fb_convbwd_float)
......@@ -75,6 +76,22 @@ SmallVector<ConvolutionImpl::AlgoBase*> ConvolutionImpl::algo_pack() {
static AlgoPack sl_algo_pack;
return sl_algo_pack.all_algos;
}
SmallVector<ConvolutionImpl::AlgoBase*> ConvolutionImpl::select_algo_type(
ConvAlgoTypePack target_type) {
megdnn_assert(nr_type_contain(target_type.data_type),
"ConvBias algo selection only support one type");
SmallVector<ConvolutionImpl::AlgoBase*> algos;
for (auto&& algo : algo_pack()) {
auto algo_type = algo->get_algo_type();
if (contain_data_type(algo_type.data_type, target_type.data_type) &&
algo_type.algo_category == target_type.algo_category) {
algos.push_back(algo);
}
}
return algos;
}
bool ConvolutionImpl::is_naive_algo(ConvolutionImpl::Algorithm* algo) {
return algo == nullptr || strcmp(algo->name(), "DEFAULT") == 0;
}
......@@ -249,9 +266,9 @@ ConvolutionImpl::NCBKernParam ConvolutionImpl::make_ncb_kern_param(
void ConvolutionImpl::exec_preprocess_with_ncb_kern(const NCBKernParam& param,
Algorithm* algo) {
auto kerns = NCB_ALGO_FUNC(dispatch_preprocess_kern, algo, param);
auto fallback_handle = handle();
for (auto kernel : kerns) {
auto&& kerns = NCB_ALGO_FUNC(dispatch_preprocess_kern, algo, param);
auto&& fallback_handle = handle();
for (auto&& kernel : kerns) {
megdnn_assert(
param.filter_meta.format == Param::Format::NCHW ||
param.filter_meta.format == Param::Format::NHWC ||
......@@ -270,9 +287,9 @@ void ConvolutionImpl::exec_preprocess_with_ncb_kern(const NCBKernParam& param,
void ConvolutionImpl::exec_with_ncb_kern(const NCBKernParam& param,
Algorithm* algo) {
auto kerns = NCB_ALGO_FUNC(dispatch_kern, algo, param);
auto fallback_handle = handle();
for (auto kernel : kerns) {
auto&& kerns = NCB_ALGO_FUNC(dispatch_kern, algo, param);
auto&& fallback_handle = handle();
for (auto&& kernel : kerns) {
megdnn_assert(
param.filter_meta.format == Param::Format::NCHW ||
param.filter_meta.format == Param::Format::NHWC ||
......@@ -292,13 +309,32 @@ void ConvolutionImpl::exec_with_ncb_kern(const NCBKernParam& param,
ConvolutionImpl::Algorithm* ConvolutionImpl::get_algorithm_heuristic_with_ncb(
const NCBKernSizeParam& param, size_t workspace_limit_in_bytes,
bool reproducible) {
for (auto i : get_all_algorithms_with_ncb(param)) {
bool usable_reproducible =
static_cast<AlgoBase*>(i)->usable_reproducible(
param, AlgoSelectionStrategy::HEURISTIC, reproducible);
if (usable_reproducible && NCB_ALGO_FUNC(get_workspace, i, param) <=
workspace_limit_in_bytes) {
return i;
auto algo_data_type = param.deduce_algo_data_type();
auto suggest_category_order = suggest_algo_category_order(param);
for (auto category : suggest_category_order) {
auto&& origin_algos = select_algo_type({algo_data_type, category});
ConvolutionImpl::Algorithm* heuristic_algo = nullptr;
for (auto i : origin_algos) {
bool usable_reproducible =
static_cast<AlgoBase*>(i)->usable_reproducible(
param, AlgoSelectionStrategy::HEURISTIC,
reproducible);
if (usable_reproducible &&
static_cast<AlgoBase*>(i)->get_workspace(param) <=
workspace_limit_in_bytes) {
//! store the first usable algo if no prefer algo, choose it as
//! the target algo
if (!heuristic_algo) {
heuristic_algo = i;
}
//! choose the first prefer algo
if (i->is_preferred(param)) {
return i;
}
}
}
if (heuristic_algo) {
return heuristic_algo;
}
}
return nullptr;
......@@ -317,8 +353,6 @@ ConvolutionImpl::get_all_algorithms_with_ncb(const NCBKernSizeParam& param) {
}
}
}
std::reverse(prefer_algos.begin(), prefer_algos.end());
//! Prefer algo inserted from begin
ret.insert(ret.begin(), prefer_algos.begin(), prefer_algos.end());
return ret;
}
......@@ -337,11 +371,45 @@ ConvolutionImpl::Algorithm* ConvolutionImpl::get_algorithm(
return m_prev_selected_algo;
}
SmallVector<AlgoCategory> ConvolutionImpl::suggest_algo_category_order(
const NCBKernSizeParam& param) const {
static CpuOprDelegationStorage<1> storage;
auto conv_bias_opr = storage.get<ConvBias, 0>();
auto conv_bias_param =
ConvolutionImpl::AlgoDefault::init_conv_bias_param(param);
return static_cast<ConvBiasImpl*>(conv_bias_opr)
->suggest_algo_category_order(conv_bias_param);
}
const char* ConvolutionImpl::get_algorithm_set_name() const {
// fallback version 0
return "F0";
}
ConvolutionImpl::AlgoDataType
ConvolutionImpl::NCBKernSizeParam::deduce_algo_data_type() const {
if (src_type.enumv() == DTypeEnum::Float32) {
return ConvolutionImpl::AlgoDataType::FLOAT32;
#if !MEGDNN_DISABLE_FLOAT16
} else if (src_type.enumv() == DTypeEnum::Float16) {
return ConvolutionImpl::AlgoDataType::FLOAT16;
#endif
} else if (src_type.enumv() == DTypeEnum::Int8 ||
src_type.enumv() == DTypeEnum::QuantizedS8) {
if (dst_type.enumv() == DTypeEnum::Int16) {
return ConvolutionImpl::AlgoDataType::INT8X8X16;
} else {
return ConvolutionImpl::AlgoDataType::QINT8X8X32;
}
} else if (src_type.enumv() == DTypeEnum::Quantized8Asymm) {
return ConvolutionImpl::AlgoDataType::QUINT8X8X32;
} else {
megdnn_throw(ssprintf("megdnn not support data type of %s * %s -> %s\n",
src_type.name(), filter_type.name(),
dst_type.name()));
}
}
/* ===================== ConvolutionBackwardData ===================== */
void* const ConvolutionBackwardDataImpl::sm_fallback_deconv_algo_type =
......
......@@ -10,11 +10,28 @@
*/
#pragma once
#include "megdnn/oprs/base.h"
#include "src/common/utils.h"
#include "src/fallback/handle.h"
#include "src/naive/convolution/opr_impl.h"
namespace megdnn {
/**
* \brief Convolutino algo category
*/
enum class AlgoCategory : int32_t {
DIRECT = 0,
IM2COL = 1,
WINOGRAD = 2,
NAIVE = 3,
};
struct ConvAlgoTypePack {
detail::AlgoDataType data_type : 32;
AlgoCategory algo_category : 32;
};
namespace fallback {
/*!
......@@ -33,6 +50,7 @@ class ConvolutionImpl : public naive::ConvolutionForwardImpl {
public:
using naive::ConvolutionForwardImpl::ConvolutionForwardImpl;
using AlgoSelectionStrategy = detail::AlgoSelectionStrategy;
using AlgoDataType = detail::AlgoDataType;
//! implemented by exec_with_ncb_kern()
void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter,
......@@ -86,6 +104,8 @@ public:
size_t nr_threads;
//! weight_preprocess info
const PreprocessedFilter* preprocessed_filter;
//! get the data type category of the param for select the algo
AlgoDataType deduce_algo_data_type() const;
};
//! memory param for kernels with non-contiguous batch
......@@ -211,6 +231,9 @@ public:
return (!reproducible || is_reproducible()) &&
usable(param, algo_selection_strategy);
}
//! get the type of the algo
virtual ConvAlgoTypePack get_algo_type() const = 0;
};
/**
......@@ -218,6 +241,11 @@ public:
*/
virtual SmallVector<AlgoBase*> algo_pack();
/**
* \brief select algo according to input algo type
*/
SmallVector<AlgoBase*> select_algo_type(ConvAlgoTypePack algo_type);
protected:
virtual void exec_with_ncb_kern(const NCBKernParam& param, Algorithm* algo);
......@@ -258,6 +286,9 @@ private:
_megdnn_tensor_out dst,
const PreprocessedFilter* preprocessed_filter,
_megdnn_workspace workspace);
SmallVector<AlgoCategory> suggest_algo_category_order(
const NCBKernSizeParam& param) const;
};
class ConvolutionBackwardDataImpl : public naive::ConvolutionBackwardDataImpl {
......
......@@ -76,7 +76,7 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32K8x12x1::get_kern(
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF32K8x12x1, megdnn_fb_matmul_f32_kern,
5, matmul::fallback::sgemm_8x12, float,
float);
float, AlgoDataType::FLOAT32, DEFAULT);
/* ===================== gemv algo ===================== */
bool MatrixMulImpl::AlgoGemv::usable(
......
......@@ -37,7 +37,15 @@ public:
kern_t get_kern(const KernSizeParam&) const override;
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4)
MEGDNN_OVERRIDE_MATMUL_DESC(
8, 16, 1, 4,
static_cast<AlgoDataType>(
static_cast<uint32_t>(AlgoDataType::FLOAT16) |
static_cast<uint32_t>(AlgoDataType::FLOAT32) |
static_cast<uint32_t>(AlgoDataType::INT8X8X16) |
static_cast<uint32_t>(AlgoDataType::QINT8X8X32) |
static_cast<uint32_t>(AlgoDataType::QUINT8X8X32)),
DEFAULT)
};
} // namespace fallback
......
......@@ -352,13 +352,15 @@ void gemm_kern(const Tin* packA, const Tin* packB, size_t M, size_t N, size_t K,
DType dtype_c) \
: A_dtype(dtype_a), B_dtype(dtype_b), C_dtype(dtype_c) {}
#define MEGDNN_OVERRIDE_MATMUL_DESC(_m, _n, _k, _packa_type_size) \
MatmulDescription matmul_description() const override { \
MatmulDescription mdesc; \
mdesc.packmode = packmode(); \
mdesc.innerblocksize = {_m, _n, _k}; \
mdesc.packa_type_size = _packa_type_size; \
return mdesc; \
#define MEGDNN_OVERRIDE_MATMUL_DESC(_m, _n, _k, _packa_type_size, _data_type, \
_format) \
MatmulDescription matmul_description() const override { \
MatmulDescription mdesc; \
mdesc.packmode = packmode(); \
mdesc.innerblocksize = {_m, _n, _k}; \
mdesc.packa_type_size = _packa_type_size; \
mdesc.algo_type = {_data_type, Param::Format::_format}; \
return mdesc; \
}
#define MEGDNN_REG_GEMM_FUNC_FOR_IM2COL() \
......@@ -373,7 +375,7 @@ void gemm_kern(const Tin* packA, const Tin* packB, size_t M, size_t N, size_t K,
#define MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL( \
_algo_name, _midout_name, _mid_index, _strategy, _i_type, _c_type, \
_packa_type) \
_packa_type, _support_data_type, _format) \
\
MatrixMulImpl::kern_naked_t MatrixMulImpl::_algo_name::get_kern_naked( \
const KernSizeParam&) const { \
......@@ -474,14 +476,16 @@ void gemm_kern(const Tin* packA, const Tin* packB, size_t M, size_t N, size_t K,
mdesc.innerblocksize = {_strategy::KERNEL_H, _strategy::KERNEL_W, \
_strategy::UNROLL_K}; \
mdesc.packa_type_size = sizeof(_packa_type); \
mdesc.algo_type = {_support_data_type, Param::Format::_format}; \
return mdesc; \
}
#define MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL( \
_algo_name, _midout_name, _mid_index, _strategy, _i_type, _c_type) \
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL(_algo_name, _midout_name, \
_mid_index, _strategy, \
_i_type, _c_type, _i_type)
#define MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL( \
_algo_name, _midout_name, _mid_index, _strategy, _i_type, _c_type, \
_support_data_type, _format) \
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL( \
_algo_name, _midout_name, _mid_index, _strategy, _i_type, _c_type, \
_i_type, _support_data_type, _format)
} // namespace matmul
} // namespace megdnn
......
......@@ -38,6 +38,22 @@ SmallVector<MatrixMulImpl::AlgoBase*> MatrixMulImpl::algo_pack() {
return s_algo_pack.all_algos;
}
SmallVector<MatrixMulImpl::AlgoBase*> MatrixMulImpl::select_algo_type(
AlgoTypePack index) {
megdnn_assert(nr_type_contain(index.data_type),
"Matmul algo selection only support one type");
SmallVector<MatrixMulImpl::AlgoBase*> algos;
for (auto&& algo : algo_pack()) {
auto algo_desc = algo->matmul_description();
if (contain_data_type(algo_desc.algo_type.data_type,
index.data_type) &&
algo_desc.algo_type.format == index.format) {
algos.push_back(algo);
}
}
return algos;
}
std::vector<MatrixMul::Algorithm*> MatrixMulImpl::get_all_algorithms(
const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) {
std::vector<Algorithm*> gemm_algos, gemv_algos;
......@@ -71,17 +87,25 @@ MatrixMul::Algorithm* MatrixMulImpl::get_algorithm_heuristic(
"require reproducible algorithm, but given algorithm is not "
"reproducible");
}
auto algos = get_all_algorithms(A, B, C);
AlgoTypePack algo_type;
algo_type.data_type = kern_size_param.deduce_algo_data_type();
algo_type.format = kern_size_param.format;
auto algos = select_algo_type(algo_type);
Algorithm *heuristic_algo = nullptr;
for (auto&& algo : algos) {
if (static_cast<AlgoBase*>(algo)->preferred_reproducible(
if (static_cast<AlgoBase*>(algo)->usable(kern_size_param) &&
static_cast<AlgoBase*>(algo)->preferred_reproducible(
kern_size_param, reproducible) &&
static_cast<AlgoBase*>(algo)->get_workspace(kern_size_param) <=
workspace_limit_in_bytes) {
return algo;
if (algo->algoset() == AlgoBase::AlgoSet::ALGO_TYPE_GEMV) {
return algo;
} else if (!heuristic_algo) {
heuristic_algo = algo;
}
}
}
return nullptr;
return heuristic_algo;
}
MatrixMulImpl::KernSizeParam MatrixMulImpl::make_kern_size_param(
......@@ -150,4 +174,34 @@ void MatrixMulImpl::exec(_megdnn_tensor_in A, _megdnn_tensor_in B,
naive::MatrixMulForwardImpl::exec(A, B, C, workspace);
}
MatrixMulImpl::AlgoDataType
MatrixMulImpl::KernSizeParam::deduce_algo_data_type() const {
megdnn_assert(A_type.enumv() == B_type.enumv(),
"Matmul A type and B type of different ctype\n");
if (A_type.enumv() == DTypeEnum::Float32) {
return MatrixMulImpl::AlgoDataType::FLOAT32;
#if !MEGDNN_DISABLE_FLOAT16
} else if (A_type.enumv() == DTypeEnum::Float16) {
return MatrixMulImpl::AlgoDataType::FLOAT16;
#endif
} else if (A_type.enumv() == DTypeEnum::Int8 ||
A_type.enumv() == DTypeEnum::QuantizedS8) {
if (C_type.enumv() == DTypeEnum::Int16) {
return MatrixMulImpl::AlgoDataType::INT8X8X16;
} else {
megdnn_assert(C_type.enumv() == DTypeEnum::Int32 ||
C_type.enumv() == DTypeEnum::QuantizedS32);
return MatrixMulImpl::AlgoDataType::QINT8X8X32;
}
} else if (A_type.enumv() == DTypeEnum::Quantized8Asymm) {
return MatrixMulImpl::AlgoDataType::QUINT8X8X32;
} else if (A_type.enumv() == DTypeEnum::Int16) {
return MatrixMulImpl::AlgoDataType::INT16X16X32;
} else {
megdnn_throw(ssprintf(
"megdnn matmul not support data type of %s * %s -> %s\n",
A_type.name(), B_type.name(), C_type.name()));
}
}
// vim: syntax=cpp.doxygen
......@@ -10,14 +10,23 @@
* implied.
*/
#pragma once
#include "megdnn/opr_param_defs.h"
#include "src/common/utils.h"
#include "src/naive/matrix_mul/opr_impl.h"
#include <unordered_map>
namespace megdnn {
namespace fallback {
struct AlgoTypePack {
detail::AlgoDataType data_type : 32;
param::MatrixMul::Format format : 32;
};
namespace fallback {
class MatrixMulImpl : public naive::MatrixMulForwardImpl {
public:
using naive::MatrixMulForwardImpl::MatrixMulForwardImpl;
using AlgoDataType = detail::AlgoDataType;
bool is_thread_safe() const override { return true; }
......@@ -34,6 +43,8 @@ public:
bool trA, trB;
Param::ComputeMode compute_mode;
Param::Format format;
//! get the data type category of the param for select the algo
AlgoDataType deduce_algo_data_type() const;
};
struct KernParam : public KernSizeParam {
......@@ -110,6 +121,7 @@ public:
struct MatmulDescription {
PackMode packmode;
InnerBlockSize innerblocksize;
AlgoTypePack algo_type;
size_t packa_type_size;
};
......@@ -146,6 +158,11 @@ public:
*/
virtual SmallVector<AlgoBase*> algo_pack();
/**
* \brief select algo according to input algo type
*/
SmallVector<AlgoBase*> select_algo_type(AlgoTypePack algo_type);
protected:
KernSizeParam make_kern_size_param(const TensorLayout& A,
const TensorLayout& B,
......
......@@ -48,6 +48,10 @@ public:
}
void* type() const override;
ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT};
}
};
/* ===================== direct-stride2 algo ===================== */
......@@ -81,6 +85,10 @@ public:
}
void* type() const override;
ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT};
}
};
/* =========================== winograd ======================== */
class ConvBiasImpl::AlgoFP32WinogradF63_8x8 final : public AlgoBase {
......@@ -96,7 +104,7 @@ public:
return m_name.c_str();
}
void* type() const override;
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE();
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32);
};
class ConvBiasImpl::AlgoFP32WinogradF23_8x8 final : public AlgoBase {
......@@ -112,7 +120,7 @@ public:
return m_name.c_str();
}
void* type() const override;
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE();
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32);
};
/* ===================== matmul algo ===================== */
......@@ -151,6 +159,9 @@ public:
}
void* type() const override;
ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::FLOAT32, AlgoCategory::IM2COL};
}
};
#if MEGDNN_X86_WITH_MKL_DNN
......@@ -192,6 +203,10 @@ public:
return {{kern, {1_z, 1_z, 1_z}}};
}
void* type() const override;
ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT};
}
};
#endif
// vim: syntax=cpp.doxygen
......@@ -224,8 +224,6 @@ bool mkldnn_matmul_qint8_preferred(
const ConvBiasImpl::NCBKernSizeParam& param) {
auto is_preferred = true;
auto&& fm = param.filter_meta;
megdnn_assert_internal(fm.group == 1 && fm.dilation[0] == 1 &&
fm.dilation[1] == 1);
// single channel conv should never use matrix mul
if (fm.ocpg == 1 || fm.icpg == 1)
......
......@@ -34,6 +34,10 @@ public:
}
void* type() const override;
bool is_preferred(const NCBKernSizeParam& param) const override;
ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT};
}
};
/* ===================== avx2 stride2 chanwise algo ===================== */
......@@ -55,6 +59,10 @@ public:
}
void* type() const override;
bool is_preferred(const NCBKernSizeParam& param) const override;
ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT};
}
};
/* ===================== avx2 stride1 direct algo ===================== */
......@@ -76,6 +84,10 @@ public:
}
void* type() const override;
bool is_preferred(const NCBKernSizeParam& param) const override;
ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT};
}
};
/* ================== avx2 int8 direct conv stride2 algo ================== */
......@@ -97,6 +109,10 @@ public:
}
void* type() const override;
bool is_preferred(const NCBKernSizeParam& param) const override;
ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT};
}
};
#if MEGDNN_X86_WITH_MKL_DNN
......@@ -134,6 +150,10 @@ public:
}
void* type() const override;
bool is_preferred(const NCBKernSizeParam& param) const override;
ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT};
}
};
/* ===================== mkldnn qint8 matmul algo ===================== */
class ConvBiasImpl::AlgoMkldnnMatmulQint8 final : public AlgoBase {
......@@ -160,6 +180,10 @@ public:
bool is_preferred(const NCBKernSizeParam& param) const override;
void* type() const override;
ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::QINT8X8X32, AlgoCategory::IM2COL};
}
};
#endif
......
......@@ -103,10 +103,10 @@ public:
#endif
all_algos.emplace_back(&stride1_direct);
all_algos.emplace_back(&stride2_direct);
all_algos.emplace_back(&avx2_stride1_direct_int8);
all_algos.emplace_back(&avx2_stride2_direct);
all_algos.emplace_back(&avx2_stride1_chanwsie_qint8);
all_algos.emplace_back(&avx2_stride2_chanwsie_qint8);
all_algos.emplace_back(&avx2_stride1_direct_int8);
all_algos.emplace_back(&avx2_stride2_direct);
all_algos.emplace_back(&matmul);
static CpuOprDelegationStorage<> storage;
......@@ -182,4 +182,41 @@ bool ConvBiasImpl::is_matmul_quantized_prefer(
!chanwise_avx2_stride2_qint8_usable_preferred(param));
}
SmallVector<AlgoCategory>
ConvBiasImpl::suggest_algo_category_order(const NCBKernSizeParam& param) const {
auto IC = param.filter_meta.icpg;
auto OC = param.filter_meta.ocpg;
auto FH = param.filter_meta.spatial[0];
auto FW = param.filter_meta.spatial[1];
//! TODO: now winograd only support fast-run
if (param.filter_meta.format == param::ConvBias::Format::NCHW_WINOGRAD ||
param.filter_meta.format == param::ConvBias::Format::NCHW44_WINOGRAD ||
param.filter_meta.format == param::ConvBias::Format::NCHW88_WINOGRAD) {
return {AlgoCategory::WINOGRAD};
}
//! nchw88 use mkl-dnn which algo is direct
if (param.filter_meta.format == param::ConvBias::Format::NCHW88) {
return {AlgoCategory::DIRECT, AlgoCategory::IM2COL};
}
//! im2col + matmul
bool im2col_prefer = (IC >= 32 || OC >= 32);
//! quantized algo use matmul when direct algo is unusable
if (param.src_type.category() == DTypeCategory::QUANTIZED) {
im2col_prefer = is_matmul_quantized_prefer(param);
}
//! conv1x1
im2col_prefer |= (FH == 1 && FW == 1);
//! x86 8x8x16 not optmized, so it will use fallback im2col+matmul
if (param.deduce_algo_data_type() == AlgoDataType::INT8X8X16) {
im2col_prefer = true;
}
if (im2col_prefer) {
return {AlgoCategory::IM2COL, AlgoCategory::DIRECT,
AlgoCategory::NAIVE};
} else {
return {AlgoCategory::DIRECT, AlgoCategory::IM2COL,
AlgoCategory::NAIVE};
}
}
// vim: syntax=cpp.doxygen
......@@ -24,6 +24,8 @@ public:
bool is_thread_safe() const override { return true; }
SmallVector<AlgoBase*> algo_pack() override;
SmallVector<AlgoCategory> suggest_algo_category_order(
const NCBKernSizeParam& param) const override;
class AlgoDirect;
class AlgoDirectStride2;
......
......@@ -184,11 +184,10 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32Vnni::get_kern(
return int8x8x32_kern_vnni;
}
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL(AlgoInt8x8x32Vnni,
megdnn_x86_matmul_kern,
"AlgoInt8x8x32Vnni"_hash,
x86::matmul::gemm_int8_vnni_12x32x4,
dt_int8, dt_int32, dt_uint8);
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL(
AlgoInt8x8x32Vnni, megdnn_x86_matmul_kern, "AlgoInt8x8x32Vnni"_hash,
x86::matmul::gemm_int8_vnni_12x32x4, dt_int8, dt_int32,
dt_uint8AlgoDataType::QINT8X8X32, DEFAULT);
#endif
/* ===================== Int8 mkldnn algo ===================== */
......@@ -397,7 +396,8 @@ size_t MatrixMulImpl::AlgoInt8x8x16AVX2::get_workspace(
}
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL(
AlgoInt8x8x16AVX2, megdnn_x86_matmul_kern, "AlgoInt8x8x16AVX2"_hash,
x86::matmul::gemm_avx2_s8s8s16_4x16x2, dt_int8, dt_int16, dt_int16);
x86::matmul::gemm_avx2_s8s8s16_4x16x2, dt_int8, dt_int16, dt_int16,
AlgoDataType::INT8X8X16, DEFAULT);
/*************************AlgoInt8x8x16SSE********************/
void MatrixMulImpl::AlgoInt8x8x16SSE::gemm_s8s8s16_sse_4x8x2(
......@@ -474,7 +474,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL(AlgoInt8x8x16SSE,
megdnn_x86_matmul_kern,
"AlgoInt8x8x16SSE"_hash,
x86::matmul::gemm_sse_s8s8s16_4x8x2,
dt_int8, dt_int16, dt_int16);
dt_int8, dt_int16, dt_int16,
AlgoDataType::INT8X8X16, DEFAULT);
/*************************AlgoInt8x8x32AVX2M4N16K2********************/
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32AVX2M4N16K2::get_kern(
......@@ -516,7 +517,7 @@ size_t MatrixMulImpl::AlgoInt8x8x32AVX2M4N16K2::get_workspace(
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL(
AlgoInt8x8x32AVX2M4N16K2, megdnn_x86_matmul_kern,
"AlgoInt8x8x32AVX2M4N16K2"_hash, x86::matmul::gemm_avx2_s8s8s32_4x16x2,
dt_int8, dt_int32, dt_int16);
dt_int8, dt_int32, dt_int16, AlgoDataType::QINT8X8X32, DEFAULT);
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32AVX2M2N4K16::get_kern(
const KernSizeParam&) const {
......@@ -556,7 +557,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32AVX2M2N4K16,
megdnn_x86_matmul_kern,
"AlgoInt8x8x32AVX2M2N4K16"_hash,
x86::matmul::gemm_avx2_s8s8s32_2x4x16,
dt_int8, dt_int32);
dt_int8, dt_int32,
AlgoDataType::QINT8X8X32, DEFAULT);
/*************************AlgoInt8x8x32SSEM4N8K2********************/
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32SSEM4N8K2::get_kern(
......@@ -596,7 +598,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL(AlgoInt8x8x32SSEM4N8K2,
megdnn_x86_matmul_kern,
"AlgoInt8x8x32SSEM4N8K2"_hash,
x86::matmul::gemm_sse_s8s8s32_4x8x2,
dt_int8, dt_int32, dt_int16);
dt_int8, dt_int32, dt_int16,
AlgoDataType::QINT8X8X32, DEFAULT);
/*************************AlgoF32MK8_8x8********************/
MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32MK8_8x8::get_kern(
......
......@@ -27,7 +27,7 @@ public:
kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_x86_algo_type; }
PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4)
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4, AlgoDataType::FLOAT32, DEFAULT)
};
#if MEGDNN_X86_WITH_MKL && SUPPORT_MKL_PACKED_GEMM
......@@ -49,7 +49,7 @@ public:
WorkspaceBundle get_bundle(const KernSizeParam& param) const override;
InnerBlockSize get_inner_block_size() const override{ return {8, 16, 1}; };
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4)
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4, AlgoDataType::FLOAT32, DEFAULT)
};
#endif
......@@ -127,7 +127,7 @@ public:
kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_x86_algo_type; }
PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(8, 8, 8, 4)
MEGDNN_OVERRIDE_MATMUL_DESC(8, 8, 8, 4, AlgoDataType::FLOAT32, MK8)
};
#if MEGDNN_X86_WITH_VNNI
......@@ -153,7 +153,7 @@ public:
kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_x86_algo_type; }
PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2)
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::QINT8X8X32, DEFAULT)
};
#endif
} // namespace x86
......
......@@ -495,8 +495,9 @@ class AlgoChooser {
}
}
mgb_assert(found,
"algo got by heuristic not found in "
"candidate list");
"algo %s got by heuristic not found in "
"candidate list",
heu->name());
return std::move(ret);
}
......@@ -628,7 +629,7 @@ public:
auto algo = get_algo(ctx);
size_t workspace = ctx.get_workspace_size_bytes(algo);
mgb_log_debug(
"%s: input shapes (%s %s, %s %s) -> (%s %s): algo=%s "
"%s:tensor layouts (%s %s, %s %s)->(%s %s) :algo=%s "
"workspace=%.2fMiB reproducible=%d",
mgb_opr->dyn_typeinfo()->name,
layouts[0].to_string().c_str(),
......@@ -636,8 +637,7 @@ public:
layouts[1].to_string().c_str(),
layouts[1].dtype.name(),
layouts[layouts.size() - 1].to_string().c_str(),
layouts[layouts.size() - 1].dtype.name(),
algo->name(),
layouts[layouts.size() - 1].dtype.name(), algo->name(),
workspace / (1024 * 1024.0), algo->is_reproducible());
megdnn_opr->execution_policy() = {algo};
return workspace;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册