Commit e05c795b authored by Megvii Engine Team

refactor(dnn/arm): refactor direct algo in algo selection

GitOrigin-RevId: d195f44decb45847fa46e0e90d6e64368c07539c
Parent 134a1026
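
This commit drops the per-instance `m_large_group` flag from the ARM direct conv-bias algorithms: `usable()` no longer rejects an instance under the HEURISTIC strategy based on a stored flag, and every place that needs the grouping decision recomputes it from the kernel-size parameters. Below is a minimal, self-contained sketch of the pattern; the toy struct and the helper `is_large_group()` are stand-ins introduced here for illustration, not MegDNN API.

    #include <cstddef>

    // Toy stand-in for the fields of NCBKernSizeParam that this refactor
    // actually reads (assumption: only the group count and the thread count).
    struct ToyKernSizeParam {
        size_t group;       // param.filter_meta.group in the real code
        size_t nr_threads;  // worker threads available to the dispatcher
    };

    // Old pattern (removed by this commit): each algo stored `bool m_large_group`
    // set at construction, and usable() under HEURISTIC kept only the instance
    // whose flag matched `group >= nr_threads`, so two objects per algorithm
    // ("*_LARGE_GROUP" / "*_SMALL_GROUP") had to be registered.
    //
    // New pattern: derive the property on demand wherever it is needed
    // (get_workspace, get_kimpls, get_bundle) and register a single instance.
    inline bool is_large_group(const ToyKernSizeParam& param) {
        // one whole group per thread only pays off when there are at least
        // as many groups as threads
        return param.group >= param.nr_threads;
    }

With the flag gone, `usable()` reduces to the kernel's real constraints (dtype, layout, stride, filter size), as the hunks below show.
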
......@@ -22,26 +22,19 @@ using namespace aarch64;
/* ===================== stride-2 algo ===================== */
MIDOUT_DECL(megdnn_aarch64_conv_bias_stride2_conv2357_fp16)
bool ConvBiasImpl::AlgoF16DirectStride2::usable(
const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const {
bool ConvBiasImpl::AlgoF16DirectStride2::usable(const NCBKernSizeParam& param,
AlgoSelectionStrategy) const {
MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp16, 0, 0) {
auto&& fm = param.filter_meta;
auto FH = fm.spatial[0];
bool aviliable =
param.filter_meta.format == param::Convolution::Format::NCHW &&
param.src_type.enumv() == DTypeEnum::Float16 &&
param.filter_type.enumv() == DTypeEnum::Float16 &&
param.dst_type.enumv() == DTypeEnum::Float16 &&
!fm.should_flip && fm.spatial_ndim == 2 &&
fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
fm.stride[0] == 2 && fm.stride[1] == 2 && FH == fm.spatial[1] &&
(FH == 2 || FH == 3 || FH == 5 || FH == 7);
if (algo_selection_strategy == AlgoSelectionStrategy::HEURISTIC) {
bool large_group = param.filter_meta.group >= param.nr_threads;
aviliable &= (large_group == m_large_group);
}
return aviliable;
return param.filter_meta.format == param::Convolution::Format::NCHW &&
param.src_type.enumv() == DTypeEnum::Float16 &&
param.filter_type.enumv() == DTypeEnum::Float16 &&
param.dst_type.enumv() == DTypeEnum::Float16 &&
!fm.should_flip && fm.spatial_ndim == 2 && fm.dilation[0] == 1 &&
fm.dilation[1] == 1 && fm.stride[0] == 2 && fm.stride[1] == 2 &&
FH == fm.spatial[1] &&
(FH == 2 || FH == 3 || FH == 5 || FH == 7);
}
MIDOUT_END();
return false;
......@@ -50,8 +43,9 @@ bool ConvBiasImpl::AlgoF16DirectStride2::usable(
size_t ConvBiasImpl::AlgoF16DirectStride2::get_workspace(
const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp16, 0, 1) {
bool large_group = param.filter_meta.group >= param.nr_threads;
auto wbundle = arm_common::MultithreadDirectConvCommon<
dt_float16, __fp16>::get_bundle_stride(param, m_large_group);
dt_float16, __fp16>::get_bundle_stride(param, large_group);
return wbundle.total_size_in_bytes();
}
MIDOUT_END();
......@@ -77,6 +71,7 @@ ConvBiasImpl::AlgoF16DirectStride2::get_kimpls(
size_t IC = param.filter_meta.icpg;
size_t OC = param.filter_meta.ocpg;
size_t group = fm.group;
bool large_group = group >= param.nr_threads;
using Func = std::function<void(const __fp16*, const __fp16*, __fp16*,
size_t, size_t, size_t, size_t, size_t)>;
Func conv = nullptr;
......@@ -91,11 +86,11 @@ ConvBiasImpl::AlgoF16DirectStride2::get_kimpls(
}
WorkspaceBundle bundle = arm_common::MultithreadDirectConvCommon<
dt_float16, __fp16>::get_bundle_stride(param, m_large_group);
dt_float16, __fp16>::get_bundle_stride(param, large_group);
SmallVector<NCBKern> ret_kerns;
//! Dense conv and small group
if (m_large_group) {
if (large_group) {
//! Channel wise conv and big groups
auto exec_one_group = [bundle, conv](
const NCBKernParam& kern_param,
......
......@@ -18,15 +18,9 @@ namespace aarch64 {
/* ===================== stride-2 algo ===================== */
class ConvBiasImpl::AlgoF16DirectStride2 final : public AlgoBase {
SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const;
bool m_large_group;
public:
AlgoF16DirectStride2(bool large_group) : m_large_group(large_group) {}
bool is_reproducible() const override { return true; }
const char* name() const override {
return m_large_group ? "ARMV8F16STRD2_LARGE_GROUP"
: "ARMV8F16STRD2_SMALL_GROUP";
}
const char* name() const override { return "ARMV8F16STRD2"; }
bool usable(const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;
......
......@@ -21,26 +21,19 @@ using namespace megdnn;
using namespace aarch64;
MIDOUT_DECL(megdnn_aarch64_conv_bias_stride2_conv2357_fp32)
bool ConvBiasImpl::AlgoF32DirectStride2::usable(
const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const {
bool ConvBiasImpl::AlgoF32DirectStride2::usable(const NCBKernSizeParam& param,
AlgoSelectionStrategy) const {
MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp32, 0, 0) {
auto&& fm = param.filter_meta;
auto FH = fm.spatial[0];
bool aviliable =
param.filter_meta.format == param::ConvBias::Format::NCHW &&
param.src_type.enumv() == DTypeEnum::Float32 &&
param.filter_type.enumv() == DTypeEnum::Float32 &&
param.dst_type.enumv() == DTypeEnum::Float32 &&
!fm.should_flip && fm.spatial_ndim == 2 &&
fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
fm.stride[0] == 2 && fm.stride[1] == 2 && FH == fm.spatial[1] &&
(FH == 2 || FH == 3 || FH == 5 || FH == 7);
if (algo_selection_strategy == AlgoSelectionStrategy::HEURISTIC) {
bool large_group = param.filter_meta.group >= param.nr_threads;
aviliable &= (large_group == m_large_group);
}
return aviliable;
return param.filter_meta.format == param::ConvBias::Format::NCHW &&
param.src_type.enumv() == DTypeEnum::Float32 &&
param.filter_type.enumv() == DTypeEnum::Float32 &&
param.dst_type.enumv() == DTypeEnum::Float32 &&
!fm.should_flip && fm.spatial_ndim == 2 && fm.dilation[0] == 1 &&
fm.dilation[1] == 1 && fm.stride[0] == 2 && fm.stride[1] == 2 &&
FH == fm.spatial[1] &&
(FH == 2 || FH == 3 || FH == 5 || FH == 7);
}
MIDOUT_END();
return false;
......@@ -49,8 +42,9 @@ bool ConvBiasImpl::AlgoF32DirectStride2::usable(
size_t ConvBiasImpl::AlgoF32DirectStride2::get_workspace(
const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp32, 0, 1) {
bool large_group = param.filter_meta.group >= param.nr_threads;
auto wbundle = arm_common::MultithreadDirectConvCommon<
float, float>::get_bundle_stride(param, m_large_group);
float, float>::get_bundle_stride(param, large_group);
return wbundle.total_size_in_bytes();
}
MIDOUT_END();
......@@ -75,6 +69,7 @@ ConvBiasImpl::AlgoF32DirectStride2::get_kimpls(
size_t IC = param.filter_meta.icpg;
size_t OC = param.filter_meta.ocpg;
size_t group = fm.group;
bool large_group = group >= param.nr_threads;
using Func = std::function<void(const float*, const float*, float*, size_t,
size_t, size_t, size_t, size_t)>;
Func conv = nullptr;
......@@ -89,11 +84,11 @@ ConvBiasImpl::AlgoF32DirectStride2::get_kimpls(
}
WorkspaceBundle bundle = arm_common::MultithreadDirectConvCommon<
float, float>::get_bundle_stride(param, m_large_group);
float, float>::get_bundle_stride(param, large_group);
SmallVector<NCBKern> ret_kerns;
//! Dense conv and small group
if (m_large_group) {
if (large_group) {
//! Channel wise conv and big groups
auto exec_one_group = [bundle, conv](
const NCBKernParam& kern_param,
......
......@@ -22,15 +22,9 @@ using FallbackConvBiasImpl = fallback::ConvBiasImpl;
class ConvBiasImpl::AlgoF32DirectStride2 final : public AlgoBase {
SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const;
bool m_large_group;
public:
AlgoF32DirectStride2(bool large_group) : m_large_group(large_group) {}
bool is_reproducible() const override { return true; }
const char* name() const override {
return m_large_group ? "ARMV8F32STRD2_LARGE_GROUP"
: "ARMV8F32STRD2_SMALL_GROUP";
}
const char* name() const override { return "ARMV8F32STRD2"; }
bool usable(const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;
......
......@@ -25,13 +25,11 @@ using namespace megdnn;
using namespace aarch64;
class ConvBiasImpl::AlgoPack : NonCopyableObj {
AlgoF32DirectStride2 f32_direct_stride2_large_group{true};
AlgoF32DirectStride2 f32_direct_stride2_small_group{false};
AlgoF32DirectStride2 f32_direct_stride2;
AlgoS8MatrixMul s8_matrix_mul;
AlgoQU8MatrixMul qu8_matrix_mul;
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
AlgoF16DirectStride2 f16_direct_stride2_large_group{true};
AlgoF16DirectStride2 f16_direct_stride2_small_group{false};
AlgoF16DirectStride2 f16_direct_stride2;
#endif
public:
......@@ -39,11 +37,9 @@ public:
matmul_algos.emplace_back(&qu8_matrix_mul);
matmul_algos.emplace_back(&s8_matrix_mul);
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
direct_algos.emplace_back(&f16_direct_stride2_large_group);
direct_algos.emplace_back(&f16_direct_stride2_small_group);
direct_algos.emplace_back(&f16_direct_stride2);
#endif
direct_algos.emplace_back(&f32_direct_stride2_large_group);
direct_algos.emplace_back(&f32_direct_stride2_small_group);
direct_algos.emplace_back(&f32_direct_stride2);
}
SmallVector<AlgoBase*> direct_algos;
SmallVector<AlgoBase*> matmul_algos;
......
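
A side effect of the refactor, visible in the AlgoPack hunk above, is that the algorithm names lose their `_LARGE_GROUP` / `_SMALL_GROUP` suffix: one object now covers both cases. A hedged, toy-typed sketch of the simplified registration (only `ARMV8F32STRD2` and `ARMV8F16STRD2` come from the diff; the toy class names are illustrative):

    #include <vector>

    struct ToyAlgo {                     // stand-in for AlgoBase
        const char* name;
    };

    struct ToyAlgoPack {                 // stand-in for ConvBiasImpl::AlgoPack
        ToyAlgo f32_direct_stride2{"ARMV8F32STRD2"};  // previously two objects,
        ToyAlgo f16_direct_stride2{"ARMV8F16STRD2"};  // constructed {true}/{false}
        std::vector<ToyAlgo*> direct_algos;
        ToyAlgoPack() {
            direct_algos.push_back(&f16_direct_stride2);
            direct_algos.push_back(&f32_direct_stride2);
        }
    };

Anything that previously selected an algorithm by a suffixed name would need the merged name instead.
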
......@@ -192,9 +192,8 @@ MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL(AlgoFP16WinogradF23_8x8,
MIDOUT_DECL(megdnn_arm_common_conv_bias_fp16_kimpl)
bool ConvBiasImpl::AlgoF16Direct::usable(
const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const {
bool ConvBiasImpl::AlgoF16Direct::usable(const NCBKernSizeParam& param,
AlgoSelectionStrategy) const {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_fp16_kimpl, 0, 0) {
auto&& fm = param.filter_meta;
auto FH = fm.spatial[0];
......@@ -203,20 +202,14 @@ bool ConvBiasImpl::AlgoF16Direct::usable(
// ``param.osz[0]*param.osz[1] >= 8'' comes from the fact that the
// kernel may have access to up to 8 fp16 after the end of the memory
// chunk.
bool aviliable = fm.format == param::ConvBias::Format::NCHW &&
param.src_type.enumv() == DTypeEnum::Float16 &&
param.filter_type.enumv() == DTypeEnum::Float16 &&
param.dst_type.enumv() == DTypeEnum::Float16 &&
fm.spatial_ndim == 2 && fm.dilation[0] == 1 &&
fm.dilation[1] == 1 &&
param.isz[0] * param.isz[1] >= 8 &&
param.osz[0] * param.osz[1] >= 8 && FH <= 7 &&
SH == 1 && SW == 1;
if (algo_selection_strategy == AlgoSelectionStrategy::HEURISTIC) {
bool large_group = param.filter_meta.group >= param.nr_threads;
aviliable &= (large_group == m_large_group);
}
return aviliable;
return fm.format == param::ConvBias::Format::NCHW &&
param.src_type.enumv() == DTypeEnum::Float16 &&
param.filter_type.enumv() == DTypeEnum::Float16 &&
param.dst_type.enumv() == DTypeEnum::Float16 &&
fm.spatial_ndim == 2 && fm.dilation[0] == 1 &&
fm.dilation[1] == 1 && param.isz[0] * param.isz[1] >= 8 &&
param.osz[0] * param.osz[1] >= 8 && FH <= 7 && SH == 1 &&
SW == 1;
}
MIDOUT_END();
return false;
......@@ -225,9 +218,10 @@ bool ConvBiasImpl::AlgoF16Direct::usable(
size_t ConvBiasImpl::AlgoF16Direct::get_workspace(
const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_fp16_kimpl, 0, 1) {
bool large_group = param.filter_meta.group >= param.nr_threads;
auto wbundle =
MultithreadDirectConvCommon<dt_float16, __fp16>::get_bundle(
param, m_large_group);
param, large_group);
return wbundle.total_size_in_bytes();
}
MIDOUT_END();
......@@ -241,13 +235,14 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF16Direct::get_kimpls(
size_t IC = param.filter_meta.icpg;
size_t OC = param.filter_meta.ocpg;
size_t group = fm.group;
bool large_group = group >= param.nr_threads;
WorkspaceBundle bundle =
MultithreadDirectConvCommon<dt_float16, __fp16>::get_bundle(
param, m_large_group);
param, large_group);
SmallVector<NCBKern> ret_kerns;
//! When group >= nr_threads, treat it as large_group, each thread process
//! one group for better performance
if (m_large_group) {
if (large_group) {
//! Channel wise conv and big groups
auto exec_one_group = [bundle](const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index) mutable {
......@@ -316,27 +311,18 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF16Direct::dispatch_kerns(
/* ===================== stride-1 algo ===================== */
bool ConvBiasImpl::AlgoF16DirectStride1::usable(
const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const {
bool ConvBiasImpl::AlgoF16DirectStride1::usable(const NCBKernSizeParam& param,
AlgoSelectionStrategy) const {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_fp16_kimpl, 1, 0) {
auto&& fm = param.filter_meta;
auto FH = fm.spatial[0];
bool aviliable =
param.filter_meta.format == param::ConvBias::Format::NCHW &&
param.src_type.enumv() == DTypeEnum::Float16 &&
param.filter_type.enumv() == DTypeEnum::Float16 &&
param.dst_type.enumv() == DTypeEnum::Float16 &&
!fm.should_flip && fm.spatial_ndim == 2 &&
fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
fm.stride[0] == 1 && fm.stride[1] == 1 && FH == fm.spatial[1] &&
(FH == 2 || FH == 3 || FH == 5);
if (algo_selection_strategy ==
ConvBiasImpl::AlgoSelectionStrategy::HEURISTIC) {
bool large_group = param.filter_meta.group >= param.nr_threads;
aviliable &= (large_group == m_large_group);
}
return aviliable;
return param.filter_meta.format == param::ConvBias::Format::NCHW &&
param.src_type.enumv() == DTypeEnum::Float16 &&
param.filter_type.enumv() == DTypeEnum::Float16 &&
param.dst_type.enumv() == DTypeEnum::Float16 &&
!fm.should_flip && fm.spatial_ndim == 2 && fm.dilation[0] == 1 &&
fm.dilation[1] == 1 && fm.stride[0] == 1 && fm.stride[1] == 1 &&
FH == fm.spatial[1] && (FH == 2 || FH == 3 || FH == 5);
}
MIDOUT_END();
return false;
......@@ -351,6 +337,7 @@ ConvBiasImpl::AlgoF16DirectStride1::get_kimpls(
size_t IC = param.filter_meta.icpg;
size_t OC = param.filter_meta.ocpg;
size_t group = fm.group;
bool large_group = group >= param.nr_threads;
using Func = std::function<void(const __fp16*, const __fp16*, __fp16*,
size_t, size_t, size_t, size_t, size_t)>;
Func conv_kern_function = nullptr;
......@@ -371,11 +358,11 @@ ConvBiasImpl::AlgoF16DirectStride1::get_kimpls(
WorkspaceBundle bundle =
MultithreadDirectConvCommon<dt_float16, __fp16>::get_bundle_stride(
param, m_large_group);
param, large_group);
SmallVector<NCBKern> ret_kerns;
//! When group >= nr_threads, treat it as large_group, each thread process
//! one group for better performance
if (m_large_group) {
if (large_group) {
//! Channel wise conv and big groups
auto exec_one_group = [bundle, conv_kern_function](
const NCBKernParam& kern_param,
......@@ -423,8 +410,9 @@ ConvBiasImpl::AlgoF16DirectStride1::get_kimpls(
size_t ConvBiasImpl::AlgoF16DirectStride1::get_workspace(
const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_fp16_kimpl, 1, 1) {
bool large_group = param.filter_meta.group >= param.nr_threads;
auto bundle = MultithreadDirectConvCommon<
dt_float16, __fp16>::get_bundle_stride(param, m_large_group);
dt_float16, __fp16>::get_bundle_stride(param, large_group);
return bundle.total_size_in_bytes();
}
MIDOUT_END();
......
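
The only behavioural split that remains is inside kernel generation: as the comments in the hunks above state, when `group >= nr_threads` each thread processes one whole group; otherwise the work is split more finely so a handful of groups does not starve the thread pool. The sketch below is a toy rendering of that branch, loosely modelled on the `SmallVector<NCBKern>` task list; `ToyTask` and `make_tasks` are invented names, and the small-group branch is an assumption, since its body is collapsed in this diff.

    #include <cstddef>
    #include <vector>

    struct ToyTask {
        size_t parallelism;  // independent task ids the dispatcher may run
    };

    inline std::vector<ToyTask> make_tasks(size_t batch, size_t group, size_t oc,
                                           size_t nr_threads) {
        bool large_group = group >= nr_threads;  // recomputed, never stored
        std::vector<ToyTask> tasks;
        if (large_group) {
            // one task per (batch, group): the worker does the padding copy,
            // weight preparation and the whole group's convolution itself
            tasks.push_back({batch * group});
        } else {
            // assumption: split further (e.g. per output channel) so that a
            // few groups can still keep all nr_threads workers busy
            tasks.push_back({batch * group * oc});
        }
        return tasks;
    }
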
......@@ -79,15 +79,10 @@ public:
class ConvBiasImpl::AlgoF16Direct final : public AlgoBase {
SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const;
bool m_large_group;
public:
AlgoF16Direct(bool is_large_group) : m_large_group{is_large_group} {}
bool is_reproducible() const override { return true; }
const char* name() const override {
return m_large_group ? "F16DIRECT_LARGE_GROUP"
: "F16DIRECT_SMALL_GROUP";
}
const char* name() const override { return "F16DIRECT"; }
bool usable(const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;
......@@ -99,14 +94,10 @@ public:
class ConvBiasImpl::AlgoF16DirectStride1 final : public AlgoBase {
SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const;
bool m_large_group;
public:
AlgoF16DirectStride1(bool is_large_group) : m_large_group{is_large_group} {}
bool is_reproducible() const override { return true; }
const char* name() const override {
return m_large_group ? "F16STRD1_LARGE_GROUP" : "F16STRD1_SMALL_GROUP";
}
const char* name() const override { return "F16STRD1"; }
bool usable(const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;
size_t get_workspace(const NCBKernSizeParam& param) const override;
......
......@@ -334,9 +334,8 @@ MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL(AlgoFP32WinogradF63_4x4_NCHW44,
/* ===================== direct algo ===================== */
MIDOUT_DECL(megdnn_arm_common_conv_bias_f32_kimpl);
bool ConvBiasImpl::AlgoF32Direct::usable(
const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const {
bool ConvBiasImpl::AlgoF32Direct::usable(const NCBKernSizeParam& param,
AlgoSelectionStrategy) const {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 0, 0) {
auto&& fm = param.filter_meta;
auto FH = fm.spatial[0];
......@@ -345,20 +344,14 @@ bool ConvBiasImpl::AlgoF32Direct::usable(
// ``param.osz[0]*param.osz[1] >= 4'' comes from the fact that the
// kernel may have access to up to 4 floats after the end of the memory
// chunk.
bool aviliable = fm.format == param::ConvBias::Format::NCHW &&
param.src_type.enumv() == DTypeEnum::Float32 &&
param.filter_type.enumv() == DTypeEnum::Float32 &&
param.dst_type.enumv() == DTypeEnum::Float32 &&
fm.spatial_ndim == 2 && fm.dilation[0] == 1 &&
fm.dilation[1] == 1 &&
param.isz[0] * param.isz[1] >= 4 &&
param.osz[0] * param.osz[1] >= 4 && FH <= 7 &&
SH == 1 && SW == 1;
if (algo_selection_strategy == AlgoSelectionStrategy::HEURISTIC) {
bool large_group = param.filter_meta.group >= param.nr_threads;
aviliable &= (large_group == m_large_group);
}
return aviliable;
return fm.format == param::ConvBias::Format::NCHW &&
param.src_type.enumv() == DTypeEnum::Float32 &&
param.filter_type.enumv() == DTypeEnum::Float32 &&
param.dst_type.enumv() == DTypeEnum::Float32 &&
fm.spatial_ndim == 2 && fm.dilation[0] == 1 &&
fm.dilation[1] == 1 && param.isz[0] * param.isz[1] >= 4 &&
param.osz[0] * param.osz[1] >= 4 && FH <= 7 && SH == 1 &&
SW == 1;
}
MIDOUT_END();
return false;
......@@ -366,8 +359,9 @@ bool ConvBiasImpl::AlgoF32Direct::usable(
size_t ConvBiasImpl::AlgoF32Direct::get_workspace(
const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 0, 1) {
bool large_group = param.filter_meta.group >= param.nr_threads;
auto wbundle = MultithreadDirectConvCommon<float, float>::get_bundle(
param, m_large_group);
param, large_group);
return wbundle.total_size_in_bytes();
}
MIDOUT_END();
......@@ -380,13 +374,14 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF32Direct::get_kimpls(
size_t IC = param.filter_meta.icpg;
size_t OC = param.filter_meta.ocpg;
size_t group = fm.group;
bool large_group = group >= param.nr_threads;
WorkspaceBundle bundle =
MultithreadDirectConvCommon<float, float>::get_bundle(
param, m_large_group);
param, large_group);
SmallVector<NCBKern> ret_kerns;
//! When group >= nr_threads, treat it as large_group, each thread process
//! one group for better performance
if (m_large_group) {
if (large_group) {
//! Channel wise conv and big groups
auto exec_one_group = [bundle](const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index) mutable {
......@@ -452,27 +447,19 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF32Direct::dispatch_kerns(
return {};
}
/* ===================== stride-1 algo ===================== */
bool ConvBiasImpl::AlgoF32DirectStride1::usable(
const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const {
bool ConvBiasImpl::AlgoF32DirectStride1::usable(const NCBKernSizeParam& param,
AlgoSelectionStrategy) const {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 1, 1) {
auto&& fm = param.filter_meta;
auto FH = fm.spatial[0];
bool aviliable =
param.filter_meta.format == param::ConvBias::Format::NCHW &&
param.src_type.enumv() == DTypeEnum::Float32 &&
param.filter_type.enumv() == DTypeEnum::Float32 &&
param.dst_type.enumv() == DTypeEnum::Float32 &&
!fm.should_flip && fm.spatial_ndim == 2 &&
fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
fm.stride[0] == 1 && fm.stride[1] == 1 && FH == fm.spatial[1] &&
(FH == 2 || FH == 3 || FH == 5 || FH == 7);
if (algo_selection_strategy ==
ConvBiasImpl::AlgoSelectionStrategy::HEURISTIC) {
bool large_group = param.filter_meta.group >= param.nr_threads;
aviliable &= (large_group == m_large_group);
}
return aviliable;
return param.filter_meta.format == param::ConvBias::Format::NCHW &&
param.src_type.enumv() == DTypeEnum::Float32 &&
param.filter_type.enumv() == DTypeEnum::Float32 &&
param.dst_type.enumv() == DTypeEnum::Float32 &&
!fm.should_flip && fm.spatial_ndim == 2 && fm.dilation[0] == 1 &&
fm.dilation[1] == 1 && fm.stride[0] == 1 && fm.stride[1] == 1 &&
FH == fm.spatial[1] &&
(FH == 2 || FH == 3 || FH == 5 || FH == 7);
}
MIDOUT_END();
return false;
......@@ -481,9 +468,10 @@ bool ConvBiasImpl::AlgoF32DirectStride1::usable(
size_t ConvBiasImpl::AlgoF32DirectStride1::get_workspace(
const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 1, 1) {
bool large_group = param.filter_meta.group >= param.nr_threads;
auto bundle =
MultithreadDirectConvCommon<float, float>::get_bundle_stride(
param, m_large_group);
param, large_group);
return bundle.total_size_in_bytes();
}
MIDOUT_END();
......@@ -499,6 +487,7 @@ ConvBiasImpl::AlgoF32DirectStride1::get_kimpls(
size_t IC = param.filter_meta.icpg;
size_t OC = param.filter_meta.ocpg;
size_t group = fm.group;
bool large_group = group >= param.nr_threads;
using Func = std::function<void(const float*, const float*, float*, size_t,
size_t, size_t, size_t, size_t)>;
Func conv_kern_function = nullptr;
......@@ -522,11 +511,11 @@ ConvBiasImpl::AlgoF32DirectStride1::get_kimpls(
WorkspaceBundle bundle =
MultithreadDirectConvCommon<float, float>::get_bundle_stride(
param, m_large_group);
param, large_group);
SmallVector<NCBKern> ret_kerns;
//! When group >= nr_threads, treat it as large_group, each thread process
//! one group for better performance
if (m_large_group) {
if (large_group) {
//! Channel wise conv and big groups
auto exec_one_group = [bundle, conv_kern_function](
const NCBKernParam& kern_param,
......@@ -580,27 +569,19 @@ ConvBiasImpl::AlgoF32DirectStride1::dispatch_kerns(
/* ===================== stride-2 algo ===================== */
bool ConvBiasImpl::AlgoF32DirectStride2::usable(
const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const {
bool ConvBiasImpl::AlgoF32DirectStride2::usable(const NCBKernSizeParam& param,
AlgoSelectionStrategy) const {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 2, 0) {
auto&& fm = param.filter_meta;
auto FH = fm.spatial[0];
bool aviliable =
param.filter_meta.format == param::ConvBias::Format::NCHW &&
param.src_type.enumv() == DTypeEnum::Float32 &&
param.filter_type.enumv() == DTypeEnum::Float32 &&
param.dst_type.enumv() == DTypeEnum::Float32 &&
!fm.should_flip && fm.spatial_ndim == 2 &&
fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
fm.stride[0] == 2 && fm.stride[1] == 2 && FH == fm.spatial[1] &&
(FH == 2 || FH == 3 || FH == 5 || FH == 7);
if (algo_selection_strategy ==
ConvBiasImpl::AlgoSelectionStrategy::HEURISTIC) {
bool large_group = param.filter_meta.group >= param.nr_threads;
aviliable &= (large_group == m_large_group);
}
return aviliable;
return param.filter_meta.format == param::ConvBias::Format::NCHW &&
param.src_type.enumv() == DTypeEnum::Float32 &&
param.filter_type.enumv() == DTypeEnum::Float32 &&
param.dst_type.enumv() == DTypeEnum::Float32 &&
!fm.should_flip && fm.spatial_ndim == 2 && fm.dilation[0] == 1 &&
fm.dilation[1] == 1 && fm.stride[0] == 2 && fm.stride[1] == 2 &&
FH == fm.spatial[1] &&
(FH == 2 || FH == 3 || FH == 5 || FH == 7);
}
MIDOUT_END();
return false;
......@@ -608,9 +589,10 @@ bool ConvBiasImpl::AlgoF32DirectStride2::usable(
size_t ConvBiasImpl::AlgoF32DirectStride2::get_workspace(
const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 2, 1) {
bool large_group = param.filter_meta.group >= param.nr_threads;
auto bundle =
MultithreadDirectConvCommon<float, float>::get_bundle_stride(
param, m_large_group);
param, large_group);
return bundle.total_size_in_bytes();
}
MIDOUT_END();
......@@ -625,6 +607,7 @@ ConvBiasImpl::AlgoF32DirectStride2::get_kimpls(
size_t IC = param.filter_meta.icpg;
size_t OC = param.filter_meta.ocpg;
size_t group = fm.group;
bool large_group = group >= param.nr_threads;
using Func = std::function<void(const float*, const float*, float*, size_t,
size_t, size_t, size_t, size_t)>;
Func conv_kern_function = nullptr;
......@@ -648,11 +631,11 @@ ConvBiasImpl::AlgoF32DirectStride2::get_kimpls(
WorkspaceBundle bundle =
MultithreadDirectConvCommon<float, float>::get_bundle_stride(
param, m_large_group);
param, large_group);
SmallVector<NCBKern> ret_kerns;
//! When group >= nr_threads, treat it as large_group, each thread process
//! one group for better performance
if (m_large_group) {
if (large_group) {
//! Channel wise conv and big groups
auto exec_one_group = [bundle, conv_kern_function](
const NCBKernParam& kern_param,
......
......@@ -128,15 +128,10 @@ public:
class ConvBiasImpl::AlgoF32Direct final : public AlgoBase {
SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const;
bool m_large_group;
public:
AlgoF32Direct(bool is_large_group) : m_large_group{is_large_group} {}
bool is_reproducible() const override { return true; }
const char* name() const override {
return m_large_group ? "F32DIRECT_LARGE_GROUP"
: "F32DIRECT_SMALL_GROUP";
}
const char* name() const override { return "F32DIRECT"; }
bool usable(const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;
......@@ -147,14 +142,10 @@ public:
class ConvBiasImpl::AlgoF32DirectStride1 final : public AlgoBase {
SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const;
bool m_large_group;
public:
AlgoF32DirectStride1(bool is_large_group) : m_large_group{is_large_group} {}
bool is_reproducible() const override { return true; }
const char* name() const override {
return m_large_group ? "F32STRD1_LARGE_GROUP" : "F32STRD1_SMALL_GROUP";
}
const char* name() const override { return "F32STRD1"; }
bool usable(const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;
......@@ -165,14 +156,10 @@ public:
class ConvBiasImpl::AlgoF32DirectStride2 final : public AlgoBase {
SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const;
bool m_large_group;
public:
AlgoF32DirectStride2(bool is_large_group) : m_large_group{is_large_group} {}
bool is_reproducible() const override { return true; }
const char* name() const override {
return m_large_group ? "F32STRD2_LARGE_GROUP" : "F32STRD2_SMALL_GROUP";
}
const char* name() const override { return "F32STRD2"; }
bool usable(const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;
......
......@@ -27,17 +27,10 @@ using namespace arm_common;
MIDOUT_DECL(megdnn_arm_common_conv_bias_int8)
/* ===================== stride1 algo ===================== */
bool ConvBiasImpl::AlgoS8DirectStride1::usable(
const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const {
bool avaible = direct_int8_stride1::can_conv_direct_stride1_int8(param);
auto fm = param.filter_meta;
if (algo_selection_strategy ==
ConvBiasImpl::AlgoSelectionStrategy::HEURISTIC) {
bool large_group = fm.group >= param.nr_threads;
avaible &= (large_group == m_large_group);
}
return avaible;
bool ConvBiasImpl::AlgoS8DirectStride1::usable(const NCBKernSizeParam& param,
AlgoSelectionStrategy) const {
return direct_int8_stride1::can_conv_direct_stride1_int8(param);
}
bool ConvBiasImpl::AlgoS8DirectStride1::is_preferred(
const NCBKernSizeParam& param) const {
......@@ -53,8 +46,9 @@ bool ConvBiasImpl::AlgoS8DirectStride1::is_preferred(
}
size_t ConvBiasImpl::AlgoS8DirectStride1::get_workspace(
const NCBKernSizeParam& param) const {
auto bundle = direct_int8_stride1::get_bundle(param, m_large_group);
const NCBKernSizeParam& param) const {
bool large_group = param.filter_meta.group >= param.nr_threads;
auto bundle = direct_int8_stride1::get_bundle(param, large_group);
return bundle.total_size_in_bytes();
}
......@@ -62,7 +56,8 @@ SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoS8DirectStride1::dispatch_kerns(
const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, 1, 0) {
return direct_int8_stride1::get_kimpls(param, m_large_group);
bool large_group = param.filter_meta.group >= param.nr_threads;
return direct_int8_stride1::get_kimpls(param, large_group);
}
MIDOUT_END();
return {};
......@@ -117,21 +112,15 @@ ConvBiasImpl::AlgoS8ChanWiseStride2NCHW44::dispatch_kerns(
}
/* ===================== stride2 algo ===================== */
bool ConvBiasImpl::AlgoS8DirectStride2::usable(
const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const {
bool avaible = direct_int8_stride2::can_conv_direct_stride2_int8(param);
if (algo_selection_strategy ==
ConvBiasImpl::AlgoSelectionStrategy::HEURISTIC) {
bool large_group = param.filter_meta.group >= param.nr_threads;
avaible &= (large_group == m_large_group);
}
return avaible;
bool ConvBiasImpl::AlgoS8DirectStride2::usable(const NCBKernSizeParam& param,
AlgoSelectionStrategy) const {
return direct_int8_stride2::can_conv_direct_stride2_int8(param);
}
size_t ConvBiasImpl::AlgoS8DirectStride2::get_workspace(
const NCBKernSizeParam& param) const {
auto bundle = direct_int8_stride2::get_bundle(param, m_large_group);
bool large_group = param.filter_meta.group >= param.nr_threads;
auto bundle = direct_int8_stride2::get_bundle(param, large_group);
return bundle.total_size_in_bytes();
}
......@@ -139,7 +128,8 @@ SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoS8DirectStride2::dispatch_kerns(
const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, 1, 1) {
return direct_int8_stride2::get_kimpls(param, m_large_group);
bool large_group = param.filter_meta.group >= param.nr_threads;
return direct_int8_stride2::get_kimpls(param, large_group);
}
MIDOUT_END();
return {};
......@@ -147,24 +137,15 @@ ConvBiasImpl::AlgoS8DirectStride2::dispatch_kerns(
#if __ARM_FEATURE_DOTPROD
/* ===================== dot stride1 algo ======================== */
bool ConvBiasImpl::AlgoDotS8DirectStride1::usable(
const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const {
bool avaible =
direct_dotprod_int8_stride1::can_conv_direct_stride1_int8(param);
if (algo_selection_strategy ==
ConvBiasImpl::AlgoSelectionStrategy::HEURISTIC) {
bool large_group = param.filter_meta.group >= param.nr_threads;
avaible &= (large_group == m_large_group);
}
return avaible;
bool ConvBiasImpl::AlgoDotS8DirectStride1::usable(const NCBKernSizeParam& param,
AlgoSelectionStrategy) const {
return direct_dotprod_int8_stride1::can_conv_direct_stride1_int8(param);
}
size_t ConvBiasImpl::AlgoDotS8DirectStride1::get_workspace(
const NCBKernSizeParam& param) const {
auto bundle = direct_dotprod_int8_stride1::get_bundle(param, m_large_group);
bool large_group = param.filter_meta.group >= param.nr_threads;
auto bundle = direct_dotprod_int8_stride1::get_bundle(param, large_group);
return bundle.total_size_in_bytes();
}
......@@ -172,29 +153,23 @@ SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoDotS8DirectStride1::dispatch_kerns(
const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, 2, 1) {
return direct_dotprod_int8_stride1::get_kimpls(param, m_large_group);
bool large_group = param.filter_meta.group >= param.nr_threads;
return direct_dotprod_int8_stride1::get_kimpls(param, large_group);
}
MIDOUT_END();
return {};
}
/* ===================== dot stride2 algo ======================== */
bool ConvBiasImpl::AlgoDotS8DirectStride2::usable(
const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const {
bool avaible =
direct_dotprod_int8_stride2::can_conv_direct_stride2_int8(param);
if (algo_selection_strategy ==
ConvBiasImpl::AlgoSelectionStrategy::HEURISTIC) {
bool large_group = param.filter_meta.group >= param.nr_threads;
avaible &= (large_group == m_large_group);
}
return avaible;
bool ConvBiasImpl::AlgoDotS8DirectStride2::usable(const NCBKernSizeParam& param,
AlgoSelectionStrategy) const {
return direct_dotprod_int8_stride2::can_conv_direct_stride2_int8(param);
}
size_t ConvBiasImpl::AlgoDotS8DirectStride2::get_workspace(
const NCBKernSizeParam& param) const {
auto bundle = direct_dotprod_int8_stride2::get_bundle(param, m_large_group);
const NCBKernSizeParam& param) const {
bool large_group = param.filter_meta.group >= param.nr_threads;
auto bundle = direct_dotprod_int8_stride2::get_bundle(param, large_group);
return bundle.total_size_in_bytes();
}
......@@ -202,7 +177,8 @@ SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoDotS8DirectStride2::dispatch_kerns(
const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, 2, 2) {
return direct_dotprod_int8_stride2::get_kimpls(param, m_large_group);
bool large_group = param.filter_meta.group >= param.nr_threads;
return direct_dotprod_int8_stride2::get_kimpls(param, large_group);
}
MIDOUT_END();
return {};
......
......@@ -18,14 +18,10 @@ namespace megdnn {
namespace arm_common {
class ConvBiasImpl::AlgoS8DirectStride1 final : public AlgoBase {
bool m_large_group;
public:
AlgoS8DirectStride1(bool large_group) : m_large_group(large_group) {}
bool is_reproducible() const override { return true; }
const char* name() const override {
return m_large_group ? "S8STRD1_LARGE_GROUP" : "S8STRD1_SMALL_GROUP";
}
const char* name() const override { return "S8STRD1"; }
bool usable(const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;
size_t get_workspace(const NCBKernSizeParam& param) const override;
......@@ -36,14 +32,10 @@ public:
};
class ConvBiasImpl::AlgoS8DirectStride2 final : public AlgoBase {
bool m_large_group;
public:
AlgoS8DirectStride2(bool large_group) : m_large_group(large_group) {}
bool is_reproducible() const override { return true; }
const char* name() const override {
return m_large_group ? "S8STRD2_LARGE_GROUP" : "S8STRD2_SMALL_GROUP";
}
const char* name() const override { return "S8STRD2"; }
bool usable(const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;
......@@ -115,16 +107,10 @@ public:
};
class ConvBiasImpl::AlgoDotS8DirectStride1 final : public AlgoBase {
bool m_large_group;
public:
AlgoDotS8DirectStride1(bool large_group) : m_large_group(large_group) {}
bool is_reproducible() const override { return true; }
const char* name() const override {
return m_large_group ? "ARMDOTS8STRD1_LARGE_GROUP"
: "ARMDOTS8STRD1_SMALL_GROUP";
}
const char* name() const override { return "ARMDOTS8STRD1"; }
bool usable(const NCBKernSizeParam&,
AlgoSelectionStrategy algo_selection_strategy) const override;
......@@ -134,15 +120,10 @@ public:
};
class ConvBiasImpl::AlgoDotS8DirectStride2 final : public AlgoBase {
bool m_large_group;
public:
AlgoDotS8DirectStride2(bool large_group) : m_large_group(large_group) {}
bool is_reproducible() const override { return true; }
const char* name() const override {
return m_large_group ? "ARMDOTS8STRD2_LARGE_GROUP"
: "ARMDOTS8STRD2_SMALL_GROUP";
}
const char* name() const override { return "ARMDOTS8STRD2"; }
bool usable(const NCBKernSizeParam&,
AlgoSelectionStrategy algo_selection_strategy) const override;
......
......@@ -82,28 +82,20 @@ void get_rectified_size_str2(size_t IH, size_t IW, size_t OH, size_t OW,
} // namespace
/* ===================== direct algo ===================== */
bool ConvBiasImpl::AlgoI8x8x16Direct::usable(
const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const {
bool ConvBiasImpl::AlgoI8x8x16Direct::usable(const NCBKernSizeParam& param,
AlgoSelectionStrategy) const {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl, 1, 0) {
auto&& fm = param.filter_meta;
auto FH = fm.spatial[0];
bool aviliable =
param.bias_mode == BiasMode::NO_BIAS &&
param.nonlineMode == NonlineMode::IDENTITY &&
fm.format == param::ConvBias::Format::NCHW && !fm.should_flip &&
param.src_type.enumv() == DTypeEnum::Int8 &&
param.filter_type.enumv() == DTypeEnum::Int8 &&
param.dst_type.enumv() == DTypeEnum::Int16 &&
fm.spatial_ndim == 2 && fm.dilation[0] == 1 &&
fm.dilation[1] == 1 && fm.stride[0] == 1 && fm.stride[1] == 1 &&
FH == fm.spatial[1] && (FH == 2 || FH == 3 || FH == 5);
if (algo_selection_strategy ==
ConvBiasImpl::AlgoSelectionStrategy::HEURISTIC) {
bool large_group = param.filter_meta.group >= param.nr_threads;
aviliable &= (large_group == m_large_group);
}
return aviliable;
return param.bias_mode == BiasMode::NO_BIAS &&
param.nonlineMode == NonlineMode::IDENTITY &&
fm.format == param::ConvBias::Format::NCHW && !fm.should_flip &&
param.src_type.enumv() == DTypeEnum::Int8 &&
param.filter_type.enumv() == DTypeEnum::Int8 &&
param.dst_type.enumv() == DTypeEnum::Int16 &&
fm.spatial_ndim == 2 && fm.dilation[0] == 1 &&
fm.dilation[1] == 1 && fm.stride[0] == 1 && fm.stride[1] == 1 &&
FH == fm.spatial[1] && (FH == 2 || FH == 3 || FH == 5);
}
MIDOUT_END();
return false;
......@@ -117,11 +109,12 @@ WorkspaceBundle ConvBiasImpl::AlgoI8x8x16Direct::get_bundle(
auto OH = param.osz[0], OW = param.osz[1];
auto PH = fm.padding[0], PW = fm.padding[1];
size_t OH2, OW2, IH2, IW2;
bool large_group = group >= param.nr_threads;
get_rectified_size_str1(IH, IW, OH, OW, PH, PW, IH2, IW2, OH2, OW2);
size_t part0 = 0u, part1 = 0u;
if (need_src_copy_str1(param)) {
part0 = m_large_group ? IC * IH2 * IW2 * sizeof(int8_t) * nr_threads
: IC * IH2 * IW2 * sizeof(int8_t) * group * batch;
part0 = large_group ? IC * IH2 * IW2 * sizeof(int8_t) * nr_threads
: IC * IH2 * IW2 * sizeof(int8_t) * group * batch;
}
if (need_dst_copy_str1(param)) {
part1 = OH2 * OW2 * sizeof(int16_t) * nr_threads + 16;
......@@ -255,9 +248,10 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoI8x8x16Direct::get_kimpls(
size_t IC = param.filter_meta.icpg;
size_t OC = param.filter_meta.ocpg;
size_t group = fm.group;
bool large_group = group >= param.nr_threads;
WorkspaceBundle bundle = get_bundle(param);
SmallVector<NCBKern> ret_kerns;
if (m_large_group) {
if (large_group) {
auto exec_one_group = [bundle](const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index) mutable {
auto fm = kern_param.filter_meta;
......@@ -302,28 +296,20 @@ ConvBiasImpl::AlgoI8x8x16Direct::dispatch_kerns(
}
/* ===================== stride-2 algo ===================== */
bool ConvBiasImpl::AlgoI8x8x16Stride2::usable(
const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const {
bool ConvBiasImpl::AlgoI8x8x16Stride2::usable(const NCBKernSizeParam& param,
AlgoSelectionStrategy) const {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl, 2, 0) {
auto&& fm = param.filter_meta;
auto FH = fm.spatial[0];
bool aviliable = param.bias_mode == BiasMode::NO_BIAS &&
param.nonlineMode == NonlineMode::IDENTITY &&
fm.format == param::ConvBias::Format::NCHW &&
!fm.should_flip &&
param.src_type.enumv() == DTypeEnum::Int8 &&
param.filter_type.enumv() == DTypeEnum::Int8 &&
param.dst_type.enumv() == DTypeEnum::Int16 &&
fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
fm.stride[0] == 2 && fm.stride[1] == 2 &&
FH == fm.spatial[1] && (FH == 2 || FH == 3 || FH == 5);
if (algo_selection_strategy ==
ConvBiasImpl::AlgoSelectionStrategy::HEURISTIC) {
bool large_group = param.filter_meta.group >= param.nr_threads;
aviliable &= (large_group == m_large_group);
}
return aviliable;
return param.bias_mode == BiasMode::NO_BIAS &&
param.nonlineMode == NonlineMode::IDENTITY &&
fm.format == param::ConvBias::Format::NCHW && !fm.should_flip &&
param.src_type.enumv() == DTypeEnum::Int8 &&
param.filter_type.enumv() == DTypeEnum::Int8 &&
param.dst_type.enumv() == DTypeEnum::Int16 &&
fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
fm.stride[0] == 2 && fm.stride[1] == 2 && FH == fm.spatial[1] &&
(FH == 2 || FH == 3 || FH == 5);
}
MIDOUT_END();
return false;
......@@ -340,9 +326,10 @@ WorkspaceBundle ConvBiasImpl::AlgoI8x8x16Stride2::get_bundle(
size_t OH2, OW2, IH2, IW2;
get_rectified_size_str2(IH, IW, OH, OW, FH, FW, PH, PW, IH2, IW2, OH2, OW2);
size_t part0 = 0u, part1 = 0u;
bool large_group = group >= param.nr_threads;
if (need_src_copy_str2(param)) {
part0 = m_large_group ? IC * IH2 * IW2 * sizeof(int8_t) * nr_threads
: IC * IH2 * IW2 * sizeof(int8_t) * group * batch;
part0 = large_group ? IC * IH2 * IW2 * sizeof(int8_t) * nr_threads
: IC * IH2 * IW2 * sizeof(int8_t) * group * batch;
}
if (need_dst_copy_str2(param)) {
part1 = OH2 * OW2 * sizeof(int16_t) * nr_threads + 16;
......@@ -475,9 +462,10 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoI8x8x16Stride2::get_kimpls(
size_t IC = param.filter_meta.icpg;
size_t OC = param.filter_meta.ocpg;
size_t group = fm.group;
bool large_group = group >= param.nr_threads;
WorkspaceBundle bundle = get_bundle(param);
SmallVector<NCBKern> ret_kerns;
if (m_large_group) {
if (large_group) {
auto exec_one_group = [bundle](const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index) mutable {
auto fm = kern_param.filter_meta;
......
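
The workspace arithmetic in the `get_bundle()` hunks above shows why the grouping decision still matters even though it is no longer stored: with large groups each worker owns one scratch copy of the padded source at a time, so `nr_threads` buffers suffice; with small groups every `(group, batch)` slice may need its own copy. A self-contained restatement of that rule (the `IH2`/`IW2` padding computation is omitted and the helper name is illustrative):

    #include <cstddef>
    #include <cstdint>

    inline size_t padded_src_workspace_bytes(size_t IC, size_t IH2, size_t IW2,
                                             size_t group, size_t batch,
                                             size_t nr_threads) {
        const bool large_group = group >= nr_threads;
        // mirrors the diff:
        //   part0 = large_group ? IC * IH2 * IW2 * sizeof(int8_t) * nr_threads
        //                       : IC * IH2 * IW2 * sizeof(int8_t) * group * batch;
        return large_group ? IC * IH2 * IW2 * sizeof(int8_t) * nr_threads
                           : IC * IH2 * IW2 * sizeof(int8_t) * group * batch;
    }
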
......@@ -26,15 +26,10 @@ class ConvBiasImpl::AlgoI8x8x16Direct final : public AlgoBase {
const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index,
const CpuNDRange& workspace_ids);
bool m_large_group;
public:
AlgoI8x8x16Direct(bool large_group) : m_large_group(large_group) {}
bool is_reproducible() const override { return true; }
const char* name() const override {
return m_large_group ? "I8816DIRECT_LARGE_GROUP"
: "I8816DIRECT_SMALL_GROUP";
}
const char* name() const override { return "I8816DIRECT"; }
bool usable(const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;
size_t get_workspace(const NCBKernSizeParam& param) const override;
......@@ -53,15 +48,9 @@ class ConvBiasImpl::AlgoI8x8x16Stride2 final : public AlgoBase {
const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index,
const CpuNDRange& workspace_ids);
bool m_large_group;
public:
AlgoI8x8x16Stride2(bool large_group) : m_large_group(large_group) {}
bool is_reproducible() const override { return true; }
const char* name() const override {
return m_large_group ? "I8816STRD2_LARGE_GROUP"
: "I8816STRD2_SMALL_GROUP";
}
const char* name() const override { return "I8816STRD2"; }
bool usable(const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;
......
......@@ -40,28 +40,20 @@ uint8_t arm_common_algo_type_storage;
} // anonymous namespace
class ConvBiasImpl::AlgoPack : NonCopyableObj {
AlgoQU8DirectStride2 qu8_direct_stride2_large_group{true};
AlgoQU8DirectStride2 qu8_direct_stride2_small_group{false};
AlgoQU8DirectStride1 qu8_direct_stride1_large_group{true};
AlgoQU8DirectStride1 qu8_direct_stride1_small_group{false};
AlgoS8DirectStride2 s8_direct_stride2_large_group{true};
AlgoS8DirectStride2 s8_direct_stride2_small_group{false};
AlgoQU8DirectStride2 qu8_direct_stride2;
AlgoQU8DirectStride1 qu8_direct_stride1;
AlgoS8DirectStride2 s8_direct_stride2;
AlgoS8DirectNCHW44 s8_direct_nchw44;
AlgoS8DirectNCHWNCHW44 s8_direct_nchw_nchw44;
AlgoS8DirectStride1 s8_direct_stride1_large_group{true};
AlgoS8DirectStride1 s8_direct_stride1_small_group{false};
AlgoS8DirectStride1 s8_direct_stride1;
AlgoS8ChanWiseStride1NCHW44 s8_channel_wise_stride1_nchw44;
AlgoS8ChanWiseStride2NCHW44 s8_channel_wise_stride2_nchw44;
#if __ARM_FEATURE_DOTPROD
AlgoDotS8DirectStride1 ds8_direct_stride1_large_group{true};
AlgoDotS8DirectStride1 ds8_direct_stride1_small_group{false};
AlgoDotS8DirectStride2 ds8_direct_stride2_large_group{true};
AlgoDotS8DirectStride2 ds8_direct_stride2_small_group{false};
AlgoDotU8DirectStride1 du8_direct_stride1_large_group{true};
AlgoDotU8DirectStride1 du8_direct_stride1_small_group{false};
AlgoDotU8DirectStride2 du8_direct_stride2_large_group{true};
AlgoDotU8DirectStride2 du8_direct_stride2_small_group{false};
AlgoDotS8DirectStride1 ds8_direct_stride1;
AlgoDotS8DirectStride2 ds8_direct_stride2;
AlgoDotU8DirectStride1 du8_direct_stride1;
AlgoDotU8DirectStride2 du8_direct_stride2;
AlgoDotS8Direct_NCHW44 ds8_direct_nchw44;
AlgoDotS8DirectNCHWNCHW44 ds8_direct_nchw_nchw44;
......@@ -71,23 +63,16 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj {
AlgoF32ChannelWiseNCHW44 f32_chanel_wise_nchw44;
AlgoF32DirectNCHW44 f32_direct_nchw44;
AlgoF32Direct f32_direct_large_group{true};
AlgoF32Direct f32_direct_small_group{false};
AlgoF32DirectStride2 f32_direct_stride2_large_group{true};
AlgoF32DirectStride2 f32_direct_stride2_small_group{false};
AlgoF32DirectStride1 f32_direct_stride1_large_group{true};
AlgoF32DirectStride1 f32_direct_stride1_small_group{false};
AlgoF32Direct f32_direct;
AlgoF32DirectStride2 f32_direct_stride2;
AlgoF32DirectStride1 f32_direct_stride1;
AlgoI8x8x16Direct i8x8x16_direct_large_group{true};
AlgoI8x8x16Direct i8x8x16_direct_small_group{false};
AlgoI8x8x16Stride2 i8x8x16_stride2_large_group{true};
AlgoI8x8x16Stride2 i8x8x16_stride2_small_group{false};
AlgoI8x8x16Direct i8x8x16_direct;
AlgoI8x8x16Stride2 i8x8x16_stride2;
AlgoI8x8x16Stride2Filter2 i8x8x16_stride2_filter2;
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
AlgoF16Direct f16_direct_large_group{true};
AlgoF16Direct f16_direct_small_group{false};
AlgoF16DirectStride1 f16_direct_stride1_large_group{true};
AlgoF16DirectStride1 f16_direct_stride1_small_group{false};
AlgoF16Direct f16_direct;
AlgoF16DirectStride1 f16_direct_stride1;
#endif
SmallVector<std::unique_ptr<AlgoBase>> refhold;
......@@ -95,54 +80,39 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj {
public:
AlgoPack() {
#if __ARM_FEATURE_DOTPROD
direct_algos.emplace_back(&ds8_direct_stride1_large_group);
direct_algos.emplace_back(&ds8_direct_stride1_small_group);
direct_algos.emplace_back(&ds8_direct_stride2_large_group);
direct_algos.emplace_back(&ds8_direct_stride2_small_group);
direct_algos.emplace_back(&du8_direct_stride1_large_group);
direct_algos.emplace_back(&du8_direct_stride1_small_group);
direct_algos.emplace_back(&du8_direct_stride2_large_group);
direct_algos.emplace_back(&du8_direct_stride2_small_group);
direct_algos.emplace_back(&ds8_direct_stride1);
direct_algos.emplace_back(&ds8_direct_stride2);
direct_algos.emplace_back(&du8_direct_stride1);
direct_algos.emplace_back(&du8_direct_stride2);
direct_algos.emplace_back(&ds8_direct_nchw44);
direct_algos.emplace_back(&ds8_direct_nchw_nchw44);
#endif
direct_algos.emplace_back(&qu8_direct_stride2_large_group);
direct_algos.emplace_back(&qu8_direct_stride2_small_group);
direct_algos.emplace_back(&qu8_direct_stride1_large_group);
direct_algos.emplace_back(&qu8_direct_stride1_small_group);
direct_algos.emplace_back(&s8_direct_stride2_large_group);
direct_algos.emplace_back(&s8_direct_stride2_small_group);
direct_algos.emplace_back(&qu8_direct_stride2);
direct_algos.emplace_back(&qu8_direct_stride1);
direct_algos.emplace_back(&s8_direct_stride2);
direct_algos.emplace_back(&s8_direct_nchw44);
direct_algos.emplace_back(&s8_direct_nchw_nchw44);
direct_algos.emplace_back(&s8_direct_stride1_large_group);
direct_algos.emplace_back(&s8_direct_stride1_small_group);
direct_algos.emplace_back(&s8_direct_stride1);
direct_algos.emplace_back(&s8_channel_wise_stride1_nchw44);
direct_algos.emplace_back(&s8_channel_wise_stride2_nchw44);
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
direct_algos.emplace_back(&f16_direct_stride1_large_group);
direct_algos.emplace_back(&f16_direct_stride1_small_group);
direct_algos.emplace_back(&f16_direct_large_group);
direct_algos.emplace_back(&f16_direct_small_group);
direct_algos.emplace_back(&f16_direct_stride1);
direct_algos.emplace_back(&f16_direct);
#endif
direct_algos.emplace_back(&i8x8x16_direct_large_group);
direct_algos.emplace_back(&i8x8x16_direct_small_group);
direct_algos.emplace_back(&i8x8x16_direct);
direct_algos.emplace_back(&i8x8x16_stride2_filter2);
direct_algos.emplace_back(&i8x8x16_stride2_large_group);
direct_algos.emplace_back(&i8x8x16_stride2_small_group);
direct_algos.emplace_back(&i8x8x16_stride2);
direct_algos.emplace_back(&f32_direct_stride2_nchw_nchw44);
direct_algos.emplace_back(&f32_chanel_wise_nchw44);
direct_algos.emplace_back(&f32_direct_nchw44);
direct_algos.emplace_back(&f32_direct_stride1_large_group);
direct_algos.emplace_back(&f32_direct_stride1_small_group);
direct_algos.emplace_back(&f32_direct_stride2_large_group);
direct_algos.emplace_back(&f32_direct_stride2_small_group);
direct_algos.emplace_back(&f32_direct_large_group);
direct_algos.emplace_back(&f32_direct_small_group);
direct_algos.emplace_back(&f32_direct_stride1);
direct_algos.emplace_back(&f32_direct_stride2);
direct_algos.emplace_back(&f32_direct);
static CpuOprDelegationStorage<2> storage;
auto matmul_opr = storage.get<MatrixMul, 0>();
......
......@@ -25,21 +25,15 @@ using namespace megdnn;
using namespace arm_common;
/* ===================== stride1 algo ===================== */
bool ConvBiasImpl::AlgoQU8DirectStride1::usable(
const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const {
bool avaible = direct_quint8_stride1::can_conv_direct_stride1_quint8(param);
if (algo_selection_strategy ==
ConvBiasImpl::AlgoSelectionStrategy::HEURISTIC) {
bool large_group = param.filter_meta.group >= param.nr_threads;
avaible &= (large_group == m_large_group);
}
return avaible;
bool ConvBiasImpl::AlgoQU8DirectStride1::usable(const NCBKernSizeParam& param,
AlgoSelectionStrategy) const {
return direct_quint8_stride1::can_conv_direct_stride1_quint8(param);
}
size_t ConvBiasImpl::AlgoQU8DirectStride1::get_workspace(
const NCBKernSizeParam& param) const {
auto bundle = direct_quint8_stride1::get_bundle(param, m_large_group);
bool large_group = param.filter_meta.group >= param.nr_threads;
auto bundle = direct_quint8_stride1::get_bundle(param, large_group);
return bundle.total_size_in_bytes();
}
......@@ -47,7 +41,8 @@ SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoQU8DirectStride1::dispatch_kerns(
const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_quint8, 0, 0) {
return direct_quint8_stride1::get_kimpls(param, m_large_group);
bool large_group = param.filter_meta.group >= param.nr_threads;
return direct_quint8_stride1::get_kimpls(param, large_group);
}
MIDOUT_END();
return {};
......@@ -55,20 +50,15 @@ ConvBiasImpl::AlgoQU8DirectStride1::dispatch_kerns(
/* ===================== stride2 algo ===================== */
bool ConvBiasImpl::AlgoQU8DirectStride2::usable(
const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const {
bool avaible = direct_quint8_stride2::can_conv_direct_stride2_quint8(param);
if (algo_selection_strategy ==
ConvBiasImpl::AlgoSelectionStrategy::HEURISTIC) {
bool large_group = param.filter_meta.group >= param.nr_threads;
avaible &= (large_group == m_large_group);
}
return avaible;
const NCBKernSizeParam& param,
AlgoSelectionStrategy) const {
return direct_quint8_stride2::can_conv_direct_stride2_quint8(param);
}
size_t ConvBiasImpl::AlgoQU8DirectStride2::get_workspace(
const NCBKernSizeParam& param) const {
auto bundle = direct_quint8_stride2::get_bundle(param, m_large_group);
bool large_group = param.filter_meta.group >= param.nr_threads;
auto bundle = direct_quint8_stride2::get_bundle(param, large_group);
return bundle.total_size_in_bytes();
}
......@@ -76,31 +66,23 @@ SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoQU8DirectStride2::dispatch_kerns(
const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_quint8, 0, 1) {
return direct_quint8_stride2::get_kimpls(param, m_large_group);
bool large_group = param.filter_meta.group >= param.nr_threads;
return direct_quint8_stride2::get_kimpls(param, large_group);
}
MIDOUT_END();
return {};
}
#if __ARM_FEATURE_DOTPROD
/* ===================== stride1 algo ===================== */
bool ConvBiasImpl::AlgoDotU8DirectStride1::usable(
const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const {
bool avaible =
direct_dotprod_quint8_stride1::can_conv_direct_stride1_quint8(
param);
if (algo_selection_strategy ==
ConvBiasImpl::AlgoSelectionStrategy::HEURISTIC) {
bool large_group = param.filter_meta.group >= param.nr_threads;
avaible &= (large_group == m_large_group);
}
return avaible;
bool ConvBiasImpl::AlgoDotU8DirectStride1::usable(const NCBKernSizeParam& param,
AlgoSelectionStrategy) const {
return direct_dotprod_quint8_stride1::can_conv_direct_stride1_quint8(param);
}
size_t ConvBiasImpl::AlgoDotU8DirectStride1::get_workspace(
const NCBKernSizeParam& param) const {
auto bundle =
direct_dotprod_quint8_stride1::get_bundle(param, m_large_group);
bool large_group = param.filter_meta.group >= param.nr_threads;
auto bundle = direct_dotprod_quint8_stride1::get_bundle(param, large_group);
return bundle.total_size_in_bytes();
}
......@@ -108,31 +90,23 @@ SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoDotU8DirectStride1::dispatch_kerns(
const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_quint8, 1, 0) {
return direct_dotprod_quint8_stride1::get_kimpls(param, m_large_group);
bool large_group = param.filter_meta.group >= param.nr_threads;
return direct_dotprod_quint8_stride1::get_kimpls(param, large_group);
}
MIDOUT_END();
return {};
}
/* ===================== stride2 algo ===================== */
bool ConvBiasImpl::AlgoDotU8DirectStride2::usable(
const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const {
bool avaible =
direct_dotprod_quint8_stride2::can_conv_direct_stride2_quint8(
param);
if (algo_selection_strategy ==
ConvBiasImpl::AlgoSelectionStrategy::HEURISTIC) {
bool large_group = param.filter_meta.group >= param.nr_threads;
avaible &= (large_group == m_large_group);
}
return avaible;
bool ConvBiasImpl::AlgoDotU8DirectStride2::usable(const NCBKernSizeParam& param,
AlgoSelectionStrategy) const {
return direct_dotprod_quint8_stride2::can_conv_direct_stride2_quint8(param);
}
size_t ConvBiasImpl::AlgoDotU8DirectStride2::get_workspace(
const NCBKernSizeParam& param) const {
auto bundle =
direct_dotprod_quint8_stride2::get_bundle(param, m_large_group);
bool large_group = param.filter_meta.group >= param.nr_threads;
auto bundle = direct_dotprod_quint8_stride2::get_bundle(param, large_group);
return bundle.total_size_in_bytes();
}
......@@ -140,7 +114,8 @@ SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoDotU8DirectStride2::dispatch_kerns(
const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_quint8, 1, 1) {
return direct_dotprod_quint8_stride2::get_kimpls(param, m_large_group);
bool large_group = param.filter_meta.group >= param.nr_threads;
return direct_dotprod_quint8_stride2::get_kimpls(param, large_group);
}
MIDOUT_END();
return {};
......
......@@ -18,14 +18,10 @@ namespace megdnn {
namespace arm_common {
class ConvBiasImpl::AlgoQU8DirectStride1 final : public AlgoBase {
bool m_large_group;
public:
AlgoQU8DirectStride1(bool large_group) : m_large_group(large_group) {}
bool is_reproducible() const override { return true; }
const char* name() const override {
return m_large_group ? "QU8STRD1_LARGE_GROUP" : "QU8STRD1_SMALL_GROUP";
}
const char* name() const override { return "QU8STRD1"; }
bool usable(const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;
......@@ -36,14 +32,10 @@ public:
};
class ConvBiasImpl::AlgoQU8DirectStride2 final : public AlgoBase {
bool m_large_group;
public:
AlgoQU8DirectStride2(bool large_group) : m_large_group(large_group) {}
bool is_reproducible() const override { return true; }
const char* name() const override {
return m_large_group ? "QU8STRD2_LARGE_GROUP" : "QU8STRD2_SMALL_GROUP";
}
const char* name() const override { return "QU8STRD2"; }
bool usable(const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;
......@@ -53,15 +45,10 @@ public:
};
#if __ARM_FEATURE_DOTPROD
class ConvBiasImpl::AlgoDotU8DirectStride1 final : public AlgoBase {
bool m_large_group;
public:
AlgoDotU8DirectStride1(bool large_group) : m_large_group(large_group) {}
bool is_reproducible() const override { return true; }
const char* name() const override {
return m_large_group ? "ARMDOTU8STRD1_LARGE_GROUP"
: "ARMDOTU8STRD1_SMALL_GROUP";
}
const char* name() const override { return "ARMDOTU8STRD1"; }
bool usable(const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;
......@@ -72,15 +59,10 @@ public:
};
class ConvBiasImpl::AlgoDotU8DirectStride2 final : public AlgoBase {
bool m_large_group;
public:
AlgoDotU8DirectStride2(bool large_group) : m_large_group(large_group) {}
bool is_reproducible() const override { return true; }
const char* name() const override {
return m_large_group ? "ARMDOTU8STRD2_LARGE_GROUP"
: "ARMDOTU8STRD2_SMALL_GROUP";
}
const char* name() const override { return "ARMDOTU8STRD2"; }
bool usable(const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;
......
......@@ -65,9 +65,10 @@ void get_rectified_size(size_t IH, size_t IW, size_t OH, size_t OW, size_t FH,
size_t IC = param.filter_meta.icpg; \
size_t OC = param.filter_meta.ocpg; \
size_t group = fm.group; \
bool large_group = group >= param.nr_threads; \
WorkspaceBundle bundle = get_bundle(param); \
SmallVector<NCBKern> ret_kerns; \
if (m_large_group) { \
if (large_group) { \
auto exec_one_group = [bundle]( \
const NCBKernParam& kern_param, \
const NCBKernIndex& ncb_index) mutable { \
......@@ -104,22 +105,15 @@ void get_rectified_size(size_t IH, size_t IW, size_t OH, size_t OW, size_t FH,
/* ===================== direct algo ===================== */
bool ConvBiasImpl::AlgoDirect::usable(
const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const {
bool ConvBiasImpl::AlgoDirect::usable(const NCBKernSizeParam& param,
AlgoSelectionStrategy) const {
auto&& fm = param.filter_meta;
bool aviliable = fm.format == Param::Format::NCHW && fm.spatial_ndim == 2 &&
param.src_type.enumv() == DTypeEnum::Float32 &&
param.filter_type.enumv() == DTypeEnum::Float32 &&
param.dst_type.enumv() == DTypeEnum::Float32 &&
fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
fm.spatial[0] <= 7 && fm.stride[0] == 1 &&
fm.stride[1] == 1;
if (algo_selection_strategy == AlgoSelectionStrategy::HEURISTIC) {
bool large_group = param.filter_meta.group >= param.nr_threads;
aviliable &= (large_group == m_large_group);
}
return aviliable;
return fm.format == Param::Format::NCHW && fm.spatial_ndim == 2 &&
param.src_type.enumv() == DTypeEnum::Float32 &&
param.filter_type.enumv() == DTypeEnum::Float32 &&
param.dst_type.enumv() == DTypeEnum::Float32 &&
fm.dilation[0] == 1 && fm.dilation[1] == 1 && fm.spatial[0] <= 7 &&
fm.stride[0] == 1 && fm.stride[1] == 1;
}
WorkspaceBundle ConvBiasImpl::AlgoDirect::get_bundle(
const NCBKernSizeParam& param) const {
......@@ -133,9 +127,10 @@ WorkspaceBundle ConvBiasImpl::AlgoDirect::get_bundle(
get_rectified_img_size(IH, IW, FH, FW, OH, OW, fm.padding[0], fm.padding[1],
IH2, IW2, OH2, OW2);
size_t part0 = 0u, part1 = 0u;
bool large_group = group >= param.nr_threads;
if (IH != IH2 || IW != IW2) {
part0 = m_large_group ? IC * IH2 * IW2 * sizeof(float) * nr_threads
: IC * IH2 * IW2 * sizeof(float) * group * batch;
part0 = large_group ? IC * IH2 * IW2 * sizeof(float) * nr_threads
: IC * IH2 * IW2 * sizeof(float) * group * batch;
}
if (OH != OH2 || OW != OW2) {
part1 = OH2 * OW2 * sizeof(float) * nr_threads;
......@@ -319,24 +314,17 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoDirect::get_kimpls(
GET_KERN;
}
/* ===================== direct-stride2 algo ===================== */
bool ConvBiasImpl::AlgoDirectStride2::usable(
const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const {
bool ConvBiasImpl::AlgoDirectStride2::usable(const NCBKernSizeParam& param,
AlgoSelectionStrategy) const {
auto&& fm = param.filter_meta;
auto FH = fm.spatial[0];
bool aviliable =
param.filter_meta.format == param::ConvBias::Format::NCHW &&
param.src_type.enumv() == DTypeEnum::Float32 &&
param.filter_type.enumv() == DTypeEnum::Float32 &&
param.dst_type.enumv() == DTypeEnum::Float32 && !fm.should_flip &&
fm.spatial_ndim == 2 && fm.dilation[0] == 1 &&
fm.dilation[1] == 1 && fm.stride[0] == 2 && fm.stride[1] == 2 &&
FH == fm.spatial[1] && (FH == 2 || FH == 3 || FH == 5 || FH == 7);
if (algo_selection_strategy == AlgoSelectionStrategy::HEURISTIC) {
bool large_group = param.filter_meta.group >= param.nr_threads;
aviliable &= (large_group == m_large_group);
}
return aviliable;
return param.filter_meta.format == param::ConvBias::Format::NCHW &&
param.src_type.enumv() == DTypeEnum::Float32 &&
param.filter_type.enumv() == DTypeEnum::Float32 &&
param.dst_type.enumv() == DTypeEnum::Float32 && !fm.should_flip &&
fm.spatial_ndim == 2 && fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
fm.stride[0] == 2 && fm.stride[1] == 2 && FH == fm.spatial[1] &&
(FH == 2 || FH == 3 || FH == 5 || FH == 7);
}
WorkspaceBundle ConvBiasImpl::AlgoDirectStride2::get_bundle(
......@@ -352,10 +340,10 @@ WorkspaceBundle ConvBiasImpl::AlgoDirectStride2::get_bundle(
size_t src_size = 0, dst_size = 0;
size_t IH2, IW2, OH2, OW2;
get_rectified_size(IH, IW, OH, OW, FH, FW, PH, PW, IH2, IW2, OH2, OW2);
    bool large_group = group >= param.nr_threads;
if (need_src_copy(param)) {
src_size = m_large_group
? IC * IH2 * IW2 * sizeof(float) * nr_threads
: IC * IH2 * IW2 * sizeof(float) * group * batch;
src_size = large_group ? IC * IH2 * IW2 * sizeof(float) * nr_threads
: IC * IH2 * IW2 * sizeof(float) * group * batch;
}
if (need_dst_copy(param)) {
// we only need one dst plane
......
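In both get_bundle hunks above, the two branches trade memory for parallelism: the large-group case keeps one padded source plane per thread, the small-group case one per (batch, group) pair so several threads can split work inside a group. A hedged sketch of that arithmetic; src_copy_bytes and its parameters are illustrative names, not megdnn API.
#include <cstddef>
// Illustrative only: the padded sizes IH2/IW2 really come from
// get_rectified_size(); here they are plain arguments.
std::size_t src_copy_bytes(std::size_t IC, std::size_t IH2, std::size_t IW2,
                           std::size_t batch, std::size_t group,
                           std::size_t nr_threads, bool large_group) {
    const std::size_t plane = IC * IH2 * IW2 * sizeof(float);
    // large group: each thread pads whole groups, so one scratch plane per
    // thread suffices; small group: keep a plane per (batch, group) so the
    // threads working inside the same group do not clash.
    return large_group ? plane * nr_threads : plane * group * batch;
}
For example, with batch 1, 32 groups and 4 threads, the large-group branch allocates 4 scratch planes instead of 32.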
......@@ -29,14 +29,10 @@ class ConvBiasImpl::AlgoDirect final : public AlgoBase {
const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index,
const CpuNDRange& workspace_ids);
bool m_large_group;
public:
AlgoDirect(bool large_group) : m_large_group(large_group) {}
bool is_reproducible() const override { return true; }
const char* name() const override {
return m_large_group ? "X86_CONV_BIAS_DIRECT_STRIDE1_LARGE_GROUP"
: "X86_CONV_BIAS_DIRECT_STRIDE1_SMALL_GROUP";
return "X86_CONV_BIAS_DIRECT_STRIDE1_LARGE_GROUP";
}
bool usable(const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;
......@@ -65,14 +61,10 @@ class ConvBiasImpl::AlgoDirectStride2 final : public AlgoBase {
const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index,
const CpuNDRange& workspace_ids);
bool m_large_group;
public:
AlgoDirectStride2(bool large_group) : m_large_group(large_group) {}
bool is_reproducible() const override { return true; }
const char* name() const override {
return m_large_group ? "X86_CONV_BIAS_DIRECT_STRIDE2_LARGE_GROUP"
: "X86_CONV_BIAS_DIRECT_STRIDE2_SMALL_GROUP";
return "X86_CONV_BIAS_DIRECT_STRIDE2_LARGE_GROUP";
}
bool usable(const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;
......
......@@ -76,10 +76,8 @@ void* ConvBiasImpl::AlgoChanWiseAvx2Stride2Qint8::type() const {
}
class ConvBiasImpl::AlgoPack : NonCopyableObj {
AlgoDirect stride1_direct_large_group{true};
AlgoDirect stride1_direct_small_group{false};
AlgoDirectStride2 stride2_direct_large_group{true};
AlgoDirectStride2 stride2_direct_small_group{false};
AlgoDirect stride1_direct;
AlgoDirectStride2 stride2_direct;
AlgoDirectAvx2Stride1Int8 avx2_stride1_direct_int8;
AlgoAVX2DirectConvStride2 avx2_stride2_direct;
AlgoChanWiseAvx2Stride1Qint8 avx2_stride1_chanwsie_qint8;
......@@ -103,10 +101,8 @@ public:
all_algos.emplace_back(&mkldnn_matmul_qint8);
all_algos.emplace_back(&mkldnn_qint8);
#endif
all_algos.emplace_back(&stride1_direct_large_group);
all_algos.emplace_back(&stride1_direct_small_group);
all_algos.emplace_back(&stride2_direct_large_group);
all_algos.emplace_back(&stride2_direct_small_group);
all_algos.emplace_back(&stride1_direct);
all_algos.emplace_back(&stride2_direct);
all_algos.emplace_back(&avx2_stride1_direct_int8);
all_algos.emplace_back(&avx2_stride2_direct);
all_algos.emplace_back(&avx2_stride1_chanwsie_qint8);
......
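The AlgoPack hunk above halves the number of registered direct algos: one instance per stride replaces the {true}/{false} large/small pair. A toy sketch of the registration pattern after the change; AlgoSketch and make_all_algos are hypothetical stand-ins for the real AlgoBase machinery.
#include <vector>
struct AlgoSketch { const char* name; };
// Previously two instances per stride were built with a flag baked into the
// constructor; now one default-constructed instance is enough, because the
// grouping decision happens inside dispatch, not at construction.
AlgoSketch stride1_direct{"X86_CONV_BIAS_DIRECT_STRIDE1_LARGE_GROUP"};
AlgoSketch stride2_direct{"X86_CONV_BIAS_DIRECT_STRIDE2_LARGE_GROUP"};
std::vector<AlgoSketch*> make_all_algos() {
    return {&stride1_direct, &stride2_direct};
}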
......@@ -81,15 +81,10 @@ void checker_conv_bias(std::vector<conv_bias::TestArg> args, Handle* handle,
{arg.src, arg.filter, arg.bias, {}, {}});
}
}
TEST_F(AARCH64_MULTI_THREADS, CONVBIAS_DIRECT_FP32_STR2_LARGE_GROUP) {
TEST_F(AARCH64_MULTI_THREADS, CONVBIAS_DIRECT_FP32_STR2) {
check_conv_bias(
conv_bias::get_conv_bias_args({2, 3, 5, 7}, 2, false, false, false),
handle(), "ARMV8F32STRD2_LARGE_GROUP");
}
TEST_F(AARCH64_MULTI_THREADS, CONVBIAS_DIRECT_FP32_STR2_SMALL_GROUP) {
check_conv_bias(
conv_bias::get_conv_bias_args({2, 3, 5, 7}, 2, false, false, false),
handle(), "ARMV8F32STRD2_SMALL_GROUP");
handle(), "ARMV8F32STRD2");
}
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
......@@ -114,17 +109,11 @@ void checker_conv_bias_fp16(std::vector<conv_bias::TestArg> args,
}
}
TEST_F(AARCH64_MULTI_THREADS, CONVBIAS_DIRECT_FP16_STR2_LARGE_GROUP) {
NormalRNG rng(1);
checker_conv_bias_f16(
conv_bias::get_conv_bias_args({2, 3, 5}, 2, false, false, false),
handle(), rng, "ARMV8F16STRD2_LARGE_GROUP", 0.04);
}
TEST_F(AARCH64_MULTI_THREADS, CONVBIAS_DIRECT_FP16_STR2_SMALL_GROUP) {
TEST_F(AARCH64_MULTI_THREADS, CONVBIAS_DIRECT_FP16_STR2) {
NormalRNG rng(1);
checker_conv_bias_f16(
conv_bias::get_conv_bias_args({2, 3, 5}, 2, false, false, false),
handle(), rng, "ARMV8F16STRD2_SMALL_GROUP", 0.04);
handle(), rng, "ARMV8F16STRD2", 0.04);
}
#endif
......
......@@ -1310,8 +1310,7 @@ TEST_F(ARM_COMMON, BENCHMARK_CHANNEL_WISE_F32_STRIDE1_NCHW44) {
benchmark0.set_param(param);
benchmark0.set_times(RUN);
benchmark0.set_before_exec_callback(
conv_bias::ConvBiasAlgoChecker<ConvBiasForward>(
"F32STRD1_LARGE_GROUP"));
conv_bias::ConvBiasAlgoChecker<ConvBiasForward>("F32STRD1"));
auto opr = handle()->create_operator<ConvBias>();
opr->param() = param;
......@@ -1385,8 +1384,7 @@ TEST_F(ARM_COMMON, BENCHMARK_CHANNEL_WISE_F32_STRIDE2_NCHW44) {
benchmark0.set_param(param);
benchmark0.set_times(RUN);
benchmark0.set_before_exec_callback(
conv_bias::ConvBiasAlgoChecker<ConvBiasForward>(
"F32STRD2_LARGE_GROUP"));
conv_bias::ConvBiasAlgoChecker<ConvBiasForward>("F32STRD2"));
auto opr = handle()->create_operator<ConvBias>();
opr->param() = param;
......@@ -1464,8 +1462,7 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_QINT8_STRIDE1_NCHW44) {
benchmark0.set_param(param);
benchmark0.set_times(RUN);
benchmark0.set_before_exec_callback(
conv_bias::ConvBiasAlgoChecker<ConvBiasForward>(
"S8STRD1_LARGE_GROUP"));
conv_bias::ConvBiasAlgoChecker<ConvBiasForward>("S8STRD1"));
auto opr = handle()->create_operator<ConvBias>();
opr->param() = param;
......
......@@ -356,15 +356,10 @@ void checker_conv_bias_int8x8x32_multi(std::vector<conv_bias::TestArg> args,
}
/**********************************F32 direct************************/
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_LARGE_GROUP) {
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32) {
check_conv_bias(
get_conv_bias_args({1, 2, 3, 4, 5, 6, 7}, 1, false, false, false),
handle(), "F32DIRECT_LARGE_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_SMALL_GROUP) {
check_conv_bias(
get_conv_bias_args({1, 2, 3, 4, 5, 6, 7}, 1, false, false, false),
handle(), "F32DIRECT_SMALL_GROUP");
handle(), "F32DIRECT");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S1_K7) {
......@@ -391,21 +386,13 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S2) {
handle(), "F32_CONV_NCHW44_DIRECT");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_STR1_LARGE_GROUP) {
check_conv_bias(get_conv_bias_args({2, 3, 5, 7}, 1, false, false, false),
handle(), "F32STRD1_LARGE_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_STR1_SMALL_GROUP) {
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_STR1) {
check_conv_bias(get_conv_bias_args({2, 3, 5, 7}, 1, false, false, false),
handle(), "F32STRD1_SMALL_GROUP");
handle(), "F32STRD1");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_STR2_LARGE_GROUP) {
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_STR2) {
check_conv_bias(get_conv_bias_args({2, 3, 5, 7}, 2, false, false, false),
handle(), "F32STRD2_LARGE_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_STR2_SMALL_GROUP) {
check_conv_bias(get_conv_bias_args({2, 3, 5, 7}, 2, false, false, false),
handle(), "F32STRD2_SMALL_GROUP");
handle(), "F32STRD2");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_NCHW_NCHW44_F32_S2) {
check_conv_bias(get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, false,
......@@ -437,72 +424,41 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_CHANNEL_WISE_STRIDE2_FP32_NCHW44) {
/**********************************F16 direct************************/
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16_LARGE_GROUP) {
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16) {
NormalRNG rng(1);
checker_conv_bias_f16(
get_conv_bias_args({1, 2, 3, 4, 5, 6, 7}, 1, false, false, false),
handle(), rng, "F16DIRECT_LARGE_GROUP", 0.03);
handle(), rng, "F16DIRECT", 0.03);
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16_SMALL_GROUP) {
NormalRNG rng(1);
checker_conv_bias_f16(
get_conv_bias_args({1, 2, 3, 4, 5, 6, 7}, 1, false, false, false),
handle(), rng, "F16DIRECT_SMALL_GROUP", 0.03);
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16_STR1_LARGE_GROUP) {
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16_STR1) {
NormalRNG rng(1);
checker_conv_bias_f16(get_conv_bias_args({2, 3, 5}, 1, false, false, false),
handle(), rng, "F16STRD1_LARGE_GROUP", 0.03);
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16_STR1_SMALL_GROUP) {
NormalRNG rng(1);
checker_conv_bias_f16(get_conv_bias_args({2, 3, 5}, 1, false, false, false),
handle(), rng, "F16STRD1_SMALL_GROUP", 0.03);
handle(), rng, "F16STRD1", 0.03);
}
#endif
/**********************************algo 8816 direct************************/
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT16_DIRECT_LARGE_GROUP) {
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT16_DIRECT) {
checker_conv_bias_int8x8x16(
get_conv_bias_args({2, 3, 5}, 1, false, true, true), handle(),
"I8816DIRECT_LARGE_GROUP");
"I8816DIRECT");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT16_DIRECT_SMALL_GROUP) {
checker_conv_bias_int8x8x16(
get_conv_bias_args({2, 3, 5}, 1, false, true, true), handle(),
"I8816DIRECT_SMALL_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT16_STRIDE2_LARGE_GROUP) {
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT16_STRIDE2) {
checker_conv_bias_int8x8x16(
get_conv_bias_args({2, 3, 5}, 2, false, true, true), handle(),
"I8816STRD2_LARGE_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT16_STRIDE2_SMALL_GROUP) {
checker_conv_bias_int8x8x16(
get_conv_bias_args({2, 3, 5}, 2, false, true, true), handle(),
"I8816STRD2_SMALL_GROUP");
"I8816STRD2");
}
/**********************************algo 8-8-32 direct************************/
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT32_STRIDE1_LARGE_GROUP) {
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT32_STRIDE1) {
checker_conv_bias_int8x8x32_multi(
get_conv_bias_args({2, 3, 5, 7}, 1, false, true, true), handle(),
"S8STRD1_LARGE_GROUP");
"S8STRD1");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT32_STRIDE1_SMALL_GROUP) {
checker_conv_bias_int8x8x32_multi(
get_conv_bias_args({2, 3, 5, 7}, 1, false, true, true), handle(),
"S8STRD1_SMALL_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT32_STRIDE2_LARGE_GROUP) {
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT32_STRIDE2) {
checker_conv_bias_int8x8x32_multi(
get_conv_bias_args({2, 3, 5, 7}, 2, false, true, true), handle(),
"S8STRD2_LARGE_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT32_STRIDE2_SMALL_GROUP) {
checker_conv_bias_int8x8x32_multi(
get_conv_bias_args({2, 3, 5, 7}, 2, false, true, true), handle(),
"S8STRD2_SMALL_GROUP");
"S8STRD2");
}
TEST_F(ARM_COMMON_MULTI_THREADS,
......@@ -520,25 +476,15 @@ TEST_F(ARM_COMMON_MULTI_THREADS,
}
/********************************qint8 direct******************************/
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_LARGE_GROUP) {
checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args(
{2, 3, 5, 7}, 1, false, false, false),
handle(), "S8STRD1_LARGE_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_SMALL_GROUP) {
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1) {
checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args(
{2, 3, 5, 7}, 1, false, false, false),
handle(), "S8STRD1_SMALL_GROUP");
handle(), "S8STRD1");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2_LARGE_GROUP) {
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2) {
checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args(
{2, 3, 5, 7}, 2, false, false, false),
handle(), "S8STRD2_LARGE_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2_SMALL_GROUP) {
checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args(
{2, 3, 5, 7}, 2, false, false, false),
handle(), "S8STRD2_SMALL_GROUP");
handle(), "S8STRD2");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_NCHW44) {
checker_conv_bias_qint8x8x8(
......@@ -586,25 +532,15 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_NCHW_NCHW44_S2) {
}
/*****************************quint8 direct****************************/
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE1_LARGE_GROUP) {
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE1) {
checker_conv_bias_quint8x8x8(get_int8_quint8_conv_bias_args(
{2, 3, 5, 7}, 1, false, false, false),
handle(), "QU8STRD1_LARGE_GROUP");
handle(), "QU8STRD1");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE1_SMALL_GROUP) {
checker_conv_bias_quint8x8x8(get_int8_quint8_conv_bias_args(
{2, 3, 5, 7}, 1, false, false, false),
handle(), "QU8STRD1_SMALL_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE2_LARGE_GROUP) {
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE2) {
checker_conv_bias_quint8x8x8(get_int8_quint8_conv_bias_args(
{2, 3, 5, 7}, 2, false, false, false),
handle(), "QU8STRD2_LARGE_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE2_SMALL_GROUP) {
checker_conv_bias_quint8x8x8(get_int8_quint8_conv_bias_args(
{2, 3, 5, 7}, 2, false, false, false),
handle(), "QU8STRD2_SMALL_GROUP");
handle(), "QU8STRD2");
}
/****************************dot qint8 direct*************************/
......@@ -624,100 +560,53 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_DOT_NCHW_NCHW44) {
}
checker_conv_bias_qint8x8x8(args, handle(), "ARMDOTS8_NCHW_NCHW44");
}
TEST_F(ARM_COMMON_MULTI_THREADS,
CONV_BIAS_INT8_STRIDE1_WITHDOTPROD_LARGE_GROUP) {
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_WITHDOTPROD) {
checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args(
{2, 3, 5, 7}, 1, false, false, false),
handle(), "ARMDOTS8STRD1_LARGE_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS,
CONV_BIAS_INT8_STRIDE1_WITHDOTPROD_SMALL_GROUP) {
checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args(
{2, 3, 5, 7}, 1, false, false, false),
handle(), "ARMDOTS8STRD1_SMALL_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS,
CONV_BIAS_INT8_STRIDE2_WITHDOTPROD_LARGE_GROUP) {
checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args(
{2, 3, 5, 7}, 2, false, false, false),
handle(), "ARMDOTS8STRD2_LARGE_GROUP");
handle(), "ARMDOTS8STRD1");
}
TEST_F(ARM_COMMON_MULTI_THREADS,
CONV_BIAS_INT8_STRIDE2_WITHDOTPROD_SMALL_GROUP) {
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2_WITHDOTPROD) {
checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args(
{2, 3, 5, 7}, 2, false, false, false),
handle(), "ARMDOTS8STRD2_SMALL_GROUP");
handle(), "ARMDOTS8STRD2");
}
/****************************dot 8-8-32 direct*************************/
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_I8832STRD1_WITHDOT_LARGE_GROUP) {
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_I8832STRD1_WITHDOT) {
checker_conv_bias_qint8x8x32(
get_conv_bias_args({2, 3, 5, 7}, 1, false, true, true), handle(),
"ARMDOTS8STRD1_LARGE_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_I8832STRD1_WITHDOT_SMALL_GROUP) {
checker_conv_bias_qint8x8x32(
get_conv_bias_args({2, 3, 5, 7}, 1, false, true, true), handle(),
"ARMDOTS8STRD1_SMALL_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_I8832STRD2_WITHDOT_LARGE_GROUP) {
checker_conv_bias_qint8x8x32(
get_conv_bias_args({2, 3, 5, 7}, 2, false, true, true), handle(),
"ARMDOTS8STRD2_LARGE_GROUP");
"ARMDOTS8STRD1");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_I8832STRD2_WITHDOT_SMALL_GROUP) {
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_I8832STRD2_WITHDOT) {
checker_conv_bias_qint8x8x32(
get_conv_bias_args({2, 3, 5, 7}, 2, false, true, true), handle(),
"ARMDOTS8STRD2_SMALL_GROUP");
"ARMDOTS8STRD2");
}
/******************************dot quint8*****************************/
TEST_F(ARM_COMMON_MULTI_THREADS,
CONV_BIAS_QUINT8_STRIDE1_WITHDOTPROD_LARGE_GROUP) {
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE1_WITHDOTPROD) {
checker_conv_bias_quint8x8x8(get_int8_quint8_conv_bias_args(
{2, 3, 5, 7}, 1, false, false, false),
handle(), "ARMDOTU8STRD1_LARGE_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS,
CONV_BIAS_QUINT8_STRIDE1_WITHDOTPROD_SMALL_GROUP) {
checker_conv_bias_quint8x8x8(get_int8_quint8_conv_bias_args(
{2, 3, 5, 7}, 1, false, false, false),
handle(), "ARMDOTU8STRD1_SMALL_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS,
CONV_BIAS_QUINT8_STRIDE2_WITHDOTPROD_LARGE_GROUP) {
checker_conv_bias_quint8x8x8(
get_int8_quint8_conv_bias_args({2, 5, 7}, 2, false, false, false),
handle(), "ARMDOTU8STRD2_LARGE_GROUP");
handle(), "ARMDOTU8STRD1");
}
TEST_F(ARM_COMMON_MULTI_THREADS,
CONV_BIAS_QUINT8_STRIDE2_WITHDOTPROD_SMALL_GROUP) {
//! TODO: this test does not cover kernel size 3; adding it currently causes
//! a bus error on armv7
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE2_WITHDOTPROD) {
checker_conv_bias_quint8x8x8(
get_int8_quint8_conv_bias_args({2, 5, 7}, 2, false, false, false),
handle(), "ARMDOTU8STRD2_SMALL_GROUP");
handle(), "ARMDOTU8STRD2");
}
/******************************dot quint8x8x32***********************/
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_QUINT8_DIRECT_STRIDE1_LARGE_GROUP) {
checker_conv_bias_quint8x8x32(
get_conv_bias_args({2, 3, 5, 7}, 1, false, true, true), handle(),
"ARMDOTU8STRD1_LARGE_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_QUINT8_DIRECT_STRIDE1_SMALL_GROUP) {
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_QUINT8_DIRECT_STRIDE1) {
checker_conv_bias_quint8x8x32(
get_conv_bias_args({2, 3, 5, 7}, 1, false, true, true), handle(),
"ARMDOTU8STRD1_SMALL_GROUP");
"ARMDOTU8STRD1");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_QUINT8_DIRECT_STRIDE2_LARGE_GROUP) {
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_QUINT8_DIRECT_STRIDE2) {
checker_conv_bias_quint8x8x32(
get_conv_bias_args({2, 3, 5, 7}, 2, false, true, true), handle(),
"ARMDOTU8STRD2_LARGE_GROUP");
"ARMDOTU8STRD2");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_QUINT8_DIRECT_STRIDE2_SMALL_GROUP) {
checker_conv_bias_quint8x8x32(
get_conv_bias_args({2, 3, 5, 7}, 2, false, true, true), handle(),
"ARMDOTU8STRD2_SMALL_GROUP");
}
/******************************dot int8x8x8 nchw44 ***********************/
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_DIRECT_DOT_NCHW44_S1_Q8x8x8) {
using namespace conv_bias;
......
......@@ -125,7 +125,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_DIRECTF32) {
bench_case(1, 32, 32, 80, 80, 3, 4);
bench_case(1, 32, 32, 80, 80, 3, 32);
std::string algo_name = "F32DIRECT_LARGE_GROUP";
std::string algo_name = "F32DIRECT";
printf("Benchmark F32DIRECT_LARGE_GROUP algo\n");
std::vector<DType> data_type = {dtype::Float32(), dtype::Float32(),
dtype::Float32(), dtype::Float32()};
......@@ -137,7 +137,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_DIRECTF32) {
{1, {4}}, data_type);
shapes_and_computation.clear();
algo_name = "F32DIRECT_SMALL_GROUP";
algo_name = "F32DIRECT";
printf("Benchmark F32DIRECT_SMALL_GROUP algo\n");
bench_case(1, 32, 32, 200, 200, 3, 1);
bench_case(1, 32, 32, 128, 128, 3, 1);
......@@ -186,7 +186,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_DIRECTF32_STR1) {
bench_case(1, 32, 32, 80, 80, 3, 4);
bench_case(1, 32, 32, 80, 80, 3, 32);
std::string algo_name = "F32STRD1_LARGE_GROUP";
std::string algo_name = "F32STRD1";
printf("Benchmark F32STRD1_LARGE_GROUP algo\n");
std::vector<DType> data_type = {dtype::Float32(), dtype::Float32(),
dtype::Float32(), dtype::Float32()};
......@@ -198,7 +198,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_DIRECTF32_STR1) {
{1, {4}}, data_type);
shapes_and_computation.clear();
algo_name = "F32STRD1_SMALL_GROUP";
algo_name = "F32STRD1";
printf("Benchmark F32STRD1_SMALL_GROUP algo\n");
bench_case(1, 32, 32, 200, 200, 3, 1);
bench_case(1, 32, 32, 128, 128, 3, 1);
......@@ -249,7 +249,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_DIRECTF32_STR2) {
bench_case(1, 32, 32, 80, 80, 3, 4, 1, 2);
bench_case(1, 32, 32, 80, 80, 3, 32, 1, 2);
std::string algo_name = "F32STRD2_LARGE_GROUP";
std::string algo_name = "F32STRD2";
printf("Benchmark F32STRD2_LARGE_GROUP algo\n");
std::vector<DType> data_type = {dtype::Float32(), dtype::Float32(),
dtype::Float32(), dtype::Float32()};
......@@ -261,7 +261,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_DIRECTF32_STR2) {
{1, {4}}, data_type);
shapes_and_computation.clear();
algo_name = "F32STRD2_SMALL_GROUP";
algo_name = "F32STRD2";
printf("Benchmark F32STRD2_SMALL_GROUP algo\n");
bench_case(1, 32, 32, 200, 200, 3, 1, 1, 2);
bench_case(1, 32, 32, 128, 128, 3, 1, 1, 2);
......@@ -313,7 +313,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_DIRECTF16) {
bench_case(1, 32, 32, 80, 80, 3, 4);
bench_case(1, 32, 32, 80, 80, 3, 32);
std::string algo_name = "F16DIRECT_LARGE_GROUP";
std::string algo_name = "F16DIRECT";
printf("Benchmark F16DIRECT_LARGE_GROUP algo\n");
std::vector<DType> data_type = {dtype::Float16(), dtype::Float16(),
dtype::Float16(), dtype::Float16()};
......@@ -325,7 +325,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_DIRECTF16) {
{1, {4}}, data_type);
shapes_and_computation.clear();
algo_name = "F16DIRECT_SMALL_GROUP";
algo_name = "F16DIRECT";
printf("Benchmark F16DIRECT_SMALL_GROUP algo\n");
bench_case(1, 32, 32, 200, 200, 3, 1);
bench_case(1, 32, 32, 128, 128, 3, 1);
......@@ -375,7 +375,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_DIRECTF16_STR1) {
bench_case(1, 32, 32, 80, 80, 3, 4);
bench_case(1, 32, 32, 80, 80, 3, 32);
std::string algo_name = "F16STRD1_LARGE_GROUP";
std::string algo_name = "F16STRD1";
printf("Benchmark F16STRD1_LARGE_GROUP algo\n");
std::vector<DType> data_type = {dtype::Float16(), dtype::Float16(),
dtype::Float16(), dtype::Float16()};
......@@ -387,7 +387,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_DIRECTF16_STR1) {
{1, {4}}, data_type);
shapes_and_computation.clear();
algo_name = "F16STRD1_SMALL_GROUP";
algo_name = "F16STRD1";
printf("Benchmark F16STRD1_SMALL_GROUP algo\n");
bench_case(1, 32, 32, 200, 200, 3, 1);
bench_case(1, 32, 32, 128, 128, 3, 1);
......@@ -439,7 +439,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS,
bench_case(1, 32, 32, 80, 80, 3, 4);
bench_case(1, 32, 32, 80, 80, 3, 32);
std::string algo_name = "I8816DIRECT_LARGE_GROUP";
std::string algo_name = "I8816DIRECT";
printf("Benchmark I8816DIRECT_LARGE_GROUP algo\n");
std::vector<DType> data_type = {dtype::Int8(), dtype::Int8(),
dtype::Int16(), dtype::Int16()};
......@@ -451,7 +451,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS,
{1, {4}}, data_type);
shapes_and_computation.clear();
algo_name = "I8816DIRECT_SMALL_GROUP";
algo_name = "I8816DIRECT";
printf("Benchmark I8816DIRECT_SMALL_GROUP algo\n");
bench_case(1, 32, 32, 200, 200, 3, 1);
bench_case(1, 32, 32, 128, 128, 3, 1);
......@@ -503,7 +503,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS,
bench_case(1, 32, 32, 80, 80, 3, 4, 1, 2);
bench_case(1, 32, 32, 80, 80, 3, 32, 1, 2);
std::string algo_name = "I8816STRD2_LARGE_GROUP";
std::string algo_name = "I8816STRD2";
printf("Benchmark I8816STRD2_LARGE_GROUP algo\n");
std::vector<DType> data_type = {dtype::Int8(), dtype::Int8(),
dtype::Int16(), dtype::Int16()};
......@@ -515,7 +515,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS,
{1, {4}}, data_type);
shapes_and_computation.clear();
algo_name = "I8816STRD2_SMALL_GROUP";
algo_name = "I8816STRD2";
printf("Benchmark I8816STRD2_SMALL_GROUP algo\n");
bench_case(1, 32, 32, 200, 200, 3, 1, 1, 2);
bench_case(1, 32, 32, 128, 128, 3, 1, 1, 2);
......@@ -567,7 +567,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS,
bench_case(1, 32, 32, 80, 80, 3, 4, 1, 1);
bench_case(1, 32, 32, 80, 80, 3, 32, 1, 1);
std::string algo_name = "S8STRD1_LARGE_GROUP";
std::string algo_name = "S8STRD1";
printf("Benchmark S8STRD1_LARGE_GROUP algo\n");
std::vector<DType> data_type = {
dtype::QuantizedS8(2.5f), dtype::QuantizedS8(2.5f),
......@@ -580,7 +580,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS,
{1, {4}}, data_type);
shapes_and_computation.clear();
algo_name = "S8STRD1_SMALL_GROUP";
algo_name = "S8STRD1";
printf("Benchmark S8STRD1_SMALL_GROUP algo\n");
bench_case(1, 32, 32, 200, 200, 3, 1, 1, 1);
bench_case(1, 32, 32, 128, 128, 3, 1, 1, 1);
......@@ -866,7 +866,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS,
bench_case(1, 32, 32, 80, 80, 3, 4, 1, 2);
bench_case(1, 32, 32, 80, 80, 3, 32, 1, 2);
std::string algo_name = "S8STRD2_LARGE_GROUP";
std::string algo_name = "S8STRD2";
printf("Benchmark S8STRD2_LARGE_GROUP algo\n");
std::vector<DType> data_type = {
dtype::QuantizedS8(2.5f), dtype::QuantizedS8(2.5f),
......@@ -879,7 +879,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS,
{1, {4}}, data_type);
shapes_and_computation.clear();
algo_name = "S8STRD2_SMALL_GROUP";
algo_name = "S8STRD2";
printf("Benchmark S8STRD2_SMALL_GROUP algo\n");
bench_case(1, 32, 32, 200, 200, 3, 1, 1, 2);
bench_case(1, 32, 32, 128, 128, 3, 1, 1, 2);
......@@ -932,7 +932,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS,
bench_case(1, 32, 32, 80, 80, 3, 4, 1, 1);
bench_case(1, 32, 32, 80, 80, 3, 32, 1, 1);
std::string algo_name = "ARMDOTS8STRD1_LARGE_GROUP";
std::string algo_name = "ARMDOTS8STRD1";
printf("Benchmark ARMDOTS8STRD1_LARGE_GROUP algo\n");
std::vector<DType> data_type = {
dtype::QuantizedS8(2.5f), dtype::QuantizedS8(2.5f),
......@@ -945,7 +945,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS,
{1, {4}}, data_type);
shapes_and_computation.clear();
algo_name = "ARMDOTS8STRD1_SMALL_GROUP";
algo_name = "ARMDOTS8STRD1";
printf("Benchmark ARMDOTS8STRD1_SMALL_GROUP algo\n");
bench_case(1, 32, 32, 200, 200, 3, 1, 1, 1);
bench_case(1, 32, 32, 128, 128, 3, 1, 1, 1);
......@@ -997,7 +997,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS,
bench_case(1, 32, 32, 80, 80, 3, 4, 1, 2);
bench_case(1, 32, 32, 80, 80, 3, 32, 1, 2);
std::string algo_name = "ARMDOTS8STRD2_LARGE_GROUP";
std::string algo_name = "ARMDOTS8STRD2";
printf("Benchmark ARMDOTS8STRD2_LARGE_GROUP algo\n");
std::vector<DType> data_type = {
dtype::QuantizedS8(2.5f), dtype::QuantizedS8(2.5f),
......@@ -1010,7 +1010,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS,
{1, {4}}, data_type);
shapes_and_computation.clear();
algo_name = "ARMDOTS8STRD2_SMALL_GROUP";
algo_name = "ARMDOTS8STRD2";
printf("Benchmark ARMDOTS8STRD2_SMALL_GROUP algo\n");
bench_case(1, 32, 32, 200, 200, 3, 1, 1, 2);
bench_case(1, 32, 32, 128, 128, 3, 1, 1, 2);
......@@ -1064,7 +1064,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS,
bench_case(1, 32, 32, 80, 80, 3, 4, 1, 1);
bench_case(1, 32, 32, 80, 80, 3, 32, 1, 1);
std::string algo_name = "QU8STRD1_LARGE_GROUP";
std::string algo_name = "QU8STRD1";
printf("Benchmark QU8STRD1_LARGE_GROUP algo\n");
std::vector<DType> data_type = {dtype::Quantized8Asymm(0.2f, 100),
dtype::Quantized8Asymm(0.2f, 120),
......@@ -1078,7 +1078,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS,
{1, {4}}, data_type);
shapes_and_computation.clear();
algo_name = "QU8STRD1_SMALL_GROUP";
algo_name = "QU8STRD1";
printf("Benchmark QU8STRD1_SMALL_GROUP algo\n");
bench_case(1, 32, 32, 200, 200, 3, 1, 1, 1);
bench_case(1, 32, 32, 128, 128, 3, 1, 1, 1);
......@@ -1130,7 +1130,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS,
bench_case(1, 32, 32, 80, 80, 3, 4, 1, 2);
bench_case(1, 32, 32, 80, 80, 3, 32, 1, 2);
std::string algo_name = "QU8STRD2_LARGE_GROUP";
std::string algo_name = "QU8STRD2";
printf("Benchmark QU8STRD2_LARGE_GROUP algo\n");
std::vector<DType> data_type = {dtype::Quantized8Asymm(0.2f, 100),
dtype::Quantized8Asymm(0.2f, 120),
......@@ -1144,7 +1144,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS,
{1, {4}}, data_type);
shapes_and_computation.clear();
algo_name = "QU8STRD2_SMALL_GROUP";
algo_name = "QU8STRD2";
printf("Benchmark QU8STRD2_SMALL_GROUP algo\n");
bench_case(1, 32, 32, 200, 200, 3, 1, 1, 2);
bench_case(1, 32, 32, 128, 128, 3, 1, 1, 2);
......@@ -1198,7 +1198,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS,
bench_case(1, 32, 32, 80, 80, 3, 4, 1, 1);
bench_case(1, 32, 32, 80, 80, 3, 32, 1, 1);
std::string algo_name = "ARMDOTU8STRD1_LARGE_GROUP";
std::string algo_name = "ARMDOTU8STRD1";
printf("Benchmark ARMDOTU8STRD1_LARGE_GROUP algo\n");
std::vector<DType> data_type = {dtype::Quantized8Asymm(0.2f, 100),
dtype::Quantized8Asymm(0.2f, 120),
......@@ -1212,7 +1212,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS,
{1, {4}}, data_type);
shapes_and_computation.clear();
algo_name = "ARMDOTU8STRD1_SMALL_GROUP";
algo_name = "ARMDOTU8STRD1";
printf("Benchmark ARMDOTS8STRD1_SMALL_GROUP algo\n");
bench_case(1, 32, 32, 200, 200, 3, 1, 1, 1);
bench_case(1, 32, 32, 128, 128, 3, 1, 1, 1);
......@@ -1265,7 +1265,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS,
bench_case(1, 32, 32, 80, 80, 5, 4, 1, 2);
bench_case(1, 32, 32, 80, 80, 5, 32, 1, 2);
std::string algo_name = "ARMDOTU8STRD2_LARGE_GROUP";
std::string algo_name = "ARMDOTU8STRD2";
printf("Benchmark ARMDOTU8STRD2_LARGE_GROUP algo\n");
std::vector<DType> data_type = {dtype::Quantized8Asymm(0.2f, 100),
dtype::Quantized8Asymm(0.2f, 120),
......@@ -1279,7 +1279,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS,
{1, {4}}, data_type);
shapes_and_computation.clear();
algo_name = "ARMDOTU8STRD2_SMALL_GROUP";
algo_name = "ARMDOTU8STRD2";
printf("Benchmark ARMDOTU8STRD2_SMALL_GROUP algo\n");
bench_case(1, 32, 32, 200, 200, 5, 1, 1, 2);
bench_case(1, 32, 32, 128, 128, 5, 1, 1, 2);
......
......@@ -176,7 +176,7 @@ TEST_F(ARM_COMMON, BENCHMARK_CONVOLUTION_STRIDE1_I8x8x32_WITHDOTPROD) {
constexpr size_t RUN = 50;
Benchmarker<Convolution> benchmark(handle());
benchmark.set_before_exec_callback(
AlgoChecker<Convolution>("CONVOLUTION_DEFAULT_ARMDOTS8STRD1_SMALL_GROUP"));
AlgoChecker<Convolution>("CONVOLUTION_DEFAULT_ARMDOTS8STRD1"));
benchmark.set_dtype(0, dtype::Int8())
.set_dtype(1, dtype::Int8())
.set_dtype(2, dtype::Int32());
......@@ -243,7 +243,7 @@ TEST_F(ARM_COMMON, BENCHMARK_CONVOLUTION_STRIDE2_I8x8x32_WITHDOTPROD) {
constexpr size_t RUN = 10;
Benchmarker<Convolution> benchmark(handle());
benchmark.set_before_exec_callback(
AlgoChecker<Convolution>("CONVOLUTION_DEFAULT_ARMDOTS8STRD2_SMALL_GROUP"));
AlgoChecker<Convolution>("CONVOLUTION_DEFAULT_ARMDOTS8STRD2"));
benchmark.set_dtype(0, dtype::Int8())
.set_dtype(1, dtype::Int8())
.set_dtype(2, dtype::Int32());
......@@ -317,7 +317,7 @@ TEST_F(ARM_COMMON, BENCHMARK_CONVOLUTION_STRIDE1_QUINT8_WITHDOTPROD) {
benchmark.set_display(false);
benchmark.set_times(RUN);
benchmark.set_before_exec_callback(AlgoChecker<ConvolutionForward>(
"CONVOLUTION_DEFAULT_ARMDOTU8STRD1_SMALL_GROUP"));
"CONVOLUTION_DEFAULT_ARMDOTU8STRD1"));
Benchmarker<Convolution> benchmark_float(handle());
benchmark_float.set_display(false);
......@@ -387,7 +387,7 @@ TEST_F(ARM_COMMON, BENCHMARK_CONVOLUTION_STRIDE2_QUINT8_WITHDOTPROD) {
benchmark.set_display(false);
benchmark.set_times(RUN);
benchmark.set_before_exec_callback(AlgoChecker<ConvolutionForward>(
"CONVOLUTION_DEFAULT_ARMDOTU8STRD2_SMALL_GROUP"));
"CONVOLUTION_DEFAULT_ARMDOTU8STRD2"));
Benchmarker<Convolution> benchmark_float(handle());
benchmark_float.set_display(false);
......
......@@ -583,7 +583,7 @@ TEST_F(X86_MULTI_THREADS, AVX2_CONV_BIAS_DIRECT_STRIDE2_S8S8S8) {
}
}
TEST_F(X86_MULTI_THREADS, CONV_BIAS_DIRECT_STRIDE1_SMALL_GROUP) {
TEST_F(X86_MULTI_THREADS, CONV_BIAS_DIRECT_STRIDE1_DENSE) {
using namespace conv_bias;
std::vector<TestArg> args;
......@@ -633,19 +633,19 @@ TEST_F(X86_MULTI_THREADS, CONV_BIAS_DIRECT_STRIDE1_SMALL_GROUP) {
.set_rng(2, &rng);
checker.set_before_exec_callback(
conv_bias::ConvBiasAlgoChecker<ConvBiasForward>(
"X86_CONV_BIAS_DIRECT_STRIDE1_SMALL_GROUP"));
"X86_CONV_BIAS_DIRECT_STRIDE1_LARGE_GROUP"));
for (auto&& arg : args) {
checker.set_param(arg.param).exec(
{arg.src, arg.filter, arg.bias, {}, {}});
}
}
TEST_F(X86_MULTI_THREADS, CONV_BIAS_DIRECT_STRIDE1_LARGE_GROUP) {
TEST_F(X86_MULTI_THREADS, CONV_BIAS_DIRECT_STRIDE1_GROUP) {
using namespace conv_bias;
std::vector<TestArg> args;
auto run = [&](size_t oc, size_t ic, size_t w, size_t h, size_t kernel,
size_t p, NonlineMode nonline_mode) {
auto run = [&](size_t group, size_t channel, size_t w, size_t h,
size_t kernel, size_t p, NonlineMode nonline_mode) {
if (w + 2 * p < kernel || h + 2 * p < kernel)
return;
param::ConvBias param;
......@@ -654,30 +654,37 @@ TEST_F(X86_MULTI_THREADS, CONV_BIAS_DIRECT_STRIDE1_LARGE_GROUP) {
param.pad_h = p;
param.pad_w = p;
param.nonlineMode = nonline_mode;
param.sparse = param::ConvBias::Sparse::GROUP;
//! no bias
args.emplace_back(param, TensorShape{1, ic, h, w},
TensorShape{oc, ic, kernel, kernel}, TensorShape{});
args.emplace_back(
param, TensorShape{1, channel, h, w},
TensorShape{group, channel / group, channel / group, kernel, kernel},
TensorShape{});
//! bias channel
args.emplace_back(param, TensorShape{2, ic, h, w},
TensorShape{oc, ic, kernel, kernel},
TensorShape{1, oc, 1, 1});
args.emplace_back(param, TensorShape{2, channel, h, w},
TensorShape{group, channel / group, channel / group,
kernel, kernel},
TensorShape{1, channel, 1, 1});
//! bias
args.emplace_back(param, TensorShape{2, ic, h, w},
TensorShape{oc, ic, kernel, kernel},
TensorShape{2, oc, (h + param.pad_h * 2 - kernel) + 1,
(w + param.pad_w * 2 - kernel) + 1});
args.emplace_back(
param, TensorShape{2, channel, h, w},
TensorShape{group, channel / group, channel / group, kernel,
kernel},
TensorShape{2, channel, (h + param.pad_h * 2 - kernel) + 1,
(w + param.pad_w * 2 - kernel) + 1});
};
for (size_t kernel : {1, 2, 3, 4, 5, 6, 7})
for (size_t ic : {1, 4, 8, 16})
for (size_t oc : {1, 4, 8})
for (size_t channel : {4, 8, 16})
for (size_t group : {1, 2, 4})
for (size_t p : {0, 2})
for (size_t size : {20, 21, 24})
for (NonlineMode nonline_mode :
{NonlineMode::RELU, NonlineMode::SIGMOID,
NonlineMode::H_SWISH, NonlineMode::IDENTITY}) {
run(oc, ic, size, size, kernel, p, nonline_mode);
run(group, channel, size, size, kernel, p,
nonline_mode);
}
Checker<ConvBias> checker(handle());
......@@ -697,7 +704,7 @@ TEST_F(X86_MULTI_THREADS, CONV_BIAS_DIRECT_STRIDE1_LARGE_GROUP) {
}
}
TEST_F(X86_MULTI_THREADS, CONV_BIAS_DIRECT_STRIDE2) {
TEST_F(X86_MULTI_THREADS, CONV_BIAS_DIRECT_STRIDE2_DENSE) {
using namespace conv_bias;
std::vector<TestArg> args;
......@@ -738,11 +745,68 @@ TEST_F(X86_MULTI_THREADS, CONV_BIAS_DIRECT_STRIDE2) {
.set_rng(2, &rng);
checker.set_before_exec_callback(
conv_bias::ConvBiasAlgoChecker<ConvBiasForward>(
"X86_CONV_BIAS_DIRECT_STRIDE2_SMALL_GROUP"));
"X86_CONV_BIAS_DIRECT_STRIDE2_LARGE_GROUP"));
for (auto&& arg : args) {
checker.set_param(arg.param).exec(
{arg.src, arg.filter, arg.bias, {}, {}});
}
}
TEST_F(X86_MULTI_THREADS, CONV_BIAS_DIRECT_STRIDE2_GROUP) {
using namespace conv_bias;
std::vector<TestArg> args;
auto run = [&](size_t group, size_t channel, size_t w, size_t h,
size_t kernel, size_t p, NonlineMode nonline_mode) {
if (w + 2 * p < kernel || h + 2 * p < kernel)
return;
param::ConvBias param;
param.stride_h = 2;
param.stride_w = 2;
param.pad_h = p;
param.pad_w = p;
param.nonlineMode = nonline_mode;
param.sparse = param::ConvBias::Sparse::GROUP;
//! no bias
args.emplace_back(
param, TensorShape{1, channel, h, w},
TensorShape{group, channel / group, channel / group, kernel, kernel},
TensorShape{});
//! bias channel
args.emplace_back(param, TensorShape{2, channel, h, w},
TensorShape{group, channel / group, channel / group,
kernel, kernel},
TensorShape{1, channel, 1, 1});
//! bias
args.emplace_back(
param, TensorShape{2, channel, h, w},
TensorShape{group, channel / group, channel / group, kernel,
kernel},
TensorShape{2, channel, (h + param.pad_h * 2 - kernel) / 2 + 1,
(w + param.pad_w * 2 - kernel) / 2 + 1});
};
for (size_t kernel : {2, 3, 5, 7})
for (size_t channel : {4, 8, 16})
for (size_t group : {1, 2, 4})
for (size_t p : {0, 2})
for (size_t size : {20, 21, 24})
for (NonlineMode nonline_mode :
{NonlineMode::RELU, NonlineMode::SIGMOID,
NonlineMode::H_SWISH, NonlineMode::IDENTITY}) {
run(group, channel, size, size, kernel, p,
nonline_mode);
}
Checker<ConvBias> checker(handle());
UniformIntRNG rng{-50, 50};
checker.set_dtype(0, dtype::Float32())
.set_dtype(1, dtype::Float32())
.set_dtype(2, dtype::Float32())
.set_rng(0, &rng)
.set_rng(1, &rng)
.set_rng(2, &rng);
checker.set_before_exec_callback(
conv_bias::ConvBiasAlgoChecker<ConvBiasForward>(
"X86_CONV_BIAS_DIRECT_STRIDE2_LARGE_GROUP"));
......@@ -2502,7 +2566,7 @@ TEST_F(X86_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_DIRECTF32) {
bench_case(1, 32, 32, 80, 80, 3, 32);
std::string algo_name = "X86_CONV_BIAS_DIRECT_STRIDE1_LARGE_GROUP";
printf("Benchmark X86_CONV_BIAS_DIRECT_STRIDE1_LARGE_GROUP algo\n");
printf("Benchmark X86_CONV_BIAS_DIRECT_STRIDE1_GROUP algo\n");
benchmark_impl(param, shapes_and_computation, algo_name, RUNS,
{4, {4, 5, 6, 7}}, {1, {4}}, data_type);
benchmark_impl(param, shapes_and_computation, algo_name, RUNS,
......@@ -2511,8 +2575,8 @@ TEST_F(X86_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_DIRECTF32) {
{1, {4}}, data_type);
shapes_and_computation.clear();
algo_name = "X86_CONV_BIAS_DIRECT_STRIDE1_SMALL_GROUP";
printf("Benchmark X86_CONV_BIAS_DIRECT_STRIDE1_SMALL_GROUP algo\n");
algo_name = "X86_CONV_BIAS_DIRECT_STRIDE1_LARGE_GROUP";
printf("Benchmark X86_CONV_BIAS_DIRECT_STRIDE1_DENSE algo\n");
bench_case(1, 32, 32, 200, 200, 3, 1);
bench_case(1, 32, 32, 128, 128, 3, 1);
bench_case(1, 32, 32, 100, 100, 3, 1);
......
......@@ -125,7 +125,7 @@ TEST_F(X86, DEFAULT_CONV_DIRECT_STRIDE1) {
Checker<ConvolutionForward> checker(handle());
checker.set_before_exec_callback(AlgoChecker<ConvolutionForward>(
"CONVOLUTION_DEFAULT_X86_CONV_BIAS_DIRECT_STRIDE1_SMALL_GROUP"));
"CONVOLUTION_DEFAULT_X86_CONV_BIAS_DIRECT_STRIDE1_LARGE_GROUP"));
checker.set_epsilon(1);
UniformIntRNG rng{-50, 50};
checker.set_dtype(0, dtype::Float32())
......@@ -167,7 +167,7 @@ TEST_F(X86, DEFAULT_CONV_DIRECT_STRIDE2) {
Checker<ConvolutionForward> checker(handle());
checker.set_before_exec_callback(AlgoChecker<ConvolutionForward>(
"CONVOLUTION_DEFAULT_X86_CONV_BIAS_DIRECT_STRIDE2_SMALL_GROUP"));
"CONVOLUTION_DEFAULT_X86_CONV_BIAS_DIRECT_STRIDE2_LARGE_GROUP"));
checker.set_epsilon(1);
UniformIntRNG rng{-50, 50};
checker.set_dtype(0, dtype::Float32())
......