提交 89374521 编写于 作者: M Megvii Engine Team

fix(dnn/arm_common): add nchw44 float channel wise s1/s2

GitOrigin-RevId: 73e6aa1e57c36b5f8bc4e05faf4e9f06ec7e5cb7
上级 9f997ac5
......@@ -178,13 +178,16 @@ public:
const NCBKernSizeParam& param) const override;
};
class ConvBiasImpl::AlgoF32DirectNCHW44 final : public AlgoBase {
class ConvBiasImpl::AlgoF32DirectStride1 final : public AlgoBase {
SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const;
bool m_large_group;
public:
AlgoF32DirectNCHW44() {}
AlgoF32DirectStride1(bool is_large_group) : m_large_group{is_large_group} {}
bool is_reproducible() const override { return true; }
const char* name() const override { return "F32_CONV_NCHW44_DIRECT"; }
const char* name() const override {
return m_large_group ? "F32STRD1_LARGE_GROUP" : "F32STRD1_SMALL_GROUP";
}
bool usable(fallback::ConvBiasImpl*, const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;
......@@ -194,13 +197,17 @@ public:
fallback::ConvBiasImpl* opr,
const NCBKernSizeParam& param) const override;
};
class ConvBiasImpl::AlgoF32DirectStride2NCHW44 final : public AlgoBase {
class ConvBiasImpl::AlgoF32DirectStride2 final : public AlgoBase {
SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const;
bool m_large_group;
public:
AlgoF32DirectStride2NCHW44() {}
AlgoF32DirectStride2(bool is_large_group) : m_large_group{is_large_group} {}
bool is_reproducible() const override { return true; }
const char* name() const override { return "F32_CONV_NCHW44_DIRECT_S2"; }
const char* name() const override {
return m_large_group ? "F32STRD2_LARGE_GROUP" : "F32STRD2_SMALL_GROUP";
}
bool usable(fallback::ConvBiasImpl*, const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;
......@@ -211,16 +218,13 @@ public:
const NCBKernSizeParam& param) const override;
};
class ConvBiasImpl::AlgoF32DirectStride1 final : public AlgoBase {
class ConvBiasImpl::AlgoF32DirectNCHW44 final : public AlgoBase {
SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const;
bool m_large_group;
public:
AlgoF32DirectStride1(bool is_large_group) : m_large_group{is_large_group} {}
AlgoF32DirectNCHW44() {}
bool is_reproducible() const override { return true; }
const char* name() const override {
return m_large_group ? "F32STRD1_LARGE_GROUP" : "F32STRD1_SMALL_GROUP";
}
const char* name() const override { return "F32_CONV_NCHW44_DIRECT"; }
bool usable(fallback::ConvBiasImpl*, const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;
......@@ -231,33 +235,29 @@ public:
const NCBKernSizeParam& param) const override;
};
class ConvBiasImpl::AlgoF32DirectStride2 final : public AlgoBase {
class ConvBiasImpl::AlgoF32DirectStride2NCHWNCHW44 final : public AlgoBase {
SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const;
bool m_large_group;
public:
AlgoF32DirectStride2(bool is_large_group) : m_large_group{is_large_group} {}
AlgoF32DirectStride2NCHWNCHW44() {}
bool is_reproducible() const override { return true; }
const char* name() const override {
return m_large_group ? "F32STRD2_LARGE_GROUP" : "F32STRD2_SMALL_GROUP";
}
bool usable(fallback::ConvBiasImpl*, const NCBKernSizeParam& param,
const char* name() const override { return "F32_CONV_NCHW_NCHW44"; }
bool usable(fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;
size_t get_workspace(fallback::ConvBiasImpl*,
size_t get_workspace(fallback::ConvBiasImpl* opr,
const NCBKernSizeParam& param) const override;
virtual SmallVector<NCBKern> dispatch_kerns(
fallback::ConvBiasImpl* opr,
const NCBKernSizeParam& param) const override;
};
class ConvBiasImpl::AlgoF32DirectStride2NCHWNCHW44 final : public AlgoBase {
class ConvBiasImpl::AlgoF32ChannelWiseNCHW44 final : public AlgoBase {
SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const;
public:
AlgoF32DirectStride2NCHWNCHW44() {}
bool is_reproducible() const override { return true; }
const char* name() const override { return "F32_CONV_NCHW_NCHW44"; }
const char* name() const override { return "F32_CHANNEL_WISE_NCHW44"; }
bool usable(fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;
......
/**
* \file dnn/src/arm_common/conv_bias/fp32/channel_wise_nchw44_algo.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/channel_wise_nchw44_kern.h"
#include "src/arm_common/conv_bias/fp32/algos.h"
#include "src/arm_common/elemwise_op.h"
#include "midout.h"
using namespace megdnn;
using namespace arm_common;
using conv_fun = std::function<void(
const float* src, const float* filter, const float* bias, float* dst,
const size_t IH, const size_t IW, const size_t OH, const size_t OW,
const size_t PH, size_t PW)>;
MIDOUT_DECL(conv_bias_fp32_channel_wise_nchw44)
bool ConvBiasImpl::AlgoF32ChannelWiseNCHW44::usable(
fallback::ConvBiasImpl*, const NCBKernSizeParam& param,
AlgoSelectionStrategy) const {
auto&& fm = param.filter_meta;
auto FH = fm.spatial[0];
size_t OC = fm.ocpg;
size_t IC = fm.icpg;
size_t GROUP = fm.group;
bool ok_type = (param.src_type.enumv() == DTypeEnum::Float32 &&
param.filter_type.enumv() == DTypeEnum::Float32 &&
(param.dst_type.enumv() == DTypeEnum::Float32));
bool ok_format = OC == 1 && IC == 1 && GROUP % 4 == 0 &&
fm.format == param::Convolution::Format::NCHW44;
bool ok_filter = fm.spatial_ndim == 2 && FH == fm.spatial[1] &&
(FH == 2 || FH == 3 || FH == 5);
bool ok_slide = fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
fm.stride[0] == fm.stride[1] &&
(fm.stride[0] == 1 || fm.stride[0] == 2);
bool ok_conv = !fm.should_flip;
bool avaible = ok_type && ok_format && ok_filter && ok_slide && ok_conv;
return avaible;
}
size_t ConvBiasImpl::AlgoF32ChannelWiseNCHW44::get_workspace(
fallback::ConvBiasImpl*, const NCBKernSizeParam&) const {
return 0;
}
SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoF32ChannelWiseNCHW44::dispatch_kerns(
fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const {
const constexpr size_t pack_group_size = 4_z;
auto fm = param.filter_meta;
const int batch = param.n;
const int group = fm.group;
const int stride = fm.stride[0];
conv_fun do_conv_fun = nullptr;
// NOTE: remain_w is not used to gen hash of midout for compatible with
// shape runtime
#define DO_CONV_KERN_FUN(_stride, filter, bias_mode, op) \
MIDOUT_BEGIN(conv_bias_fp32_channel_wise_nchw44, \
midout_iv(#_stride #filter #bias_mode #op##_hash)) { \
do_conv_fun = channel_wise_nchw44_float:: \
do_conv_kern_##_stride##_##filter##x##filter<bias_mode, op>; \
} \
MIDOUT_END();
#define GET_OP_PARAM(_stride, filter, bias_mode) \
switch (param.nonlineMode) { \
case param::ConvBias::NonlineMode::IDENTITY: \
DO_CONV_KERN_FUN(_stride, filter, bias_mode, NoneOp<dt_float32>) \
break; \
case param::ConvBias::NonlineMode::RELU: \
DO_CONV_KERN_FUN(_stride, filter, bias_mode, ReluOp<dt_float32>) \
break; \
case param::ConvBias::NonlineMode::SIGMOID: \
DO_CONV_KERN_FUN(_stride, filter, bias_mode, \
SigmoidOp<dt_float32>) \
break; \
case param::ConvBias::NonlineMode::H_SWISH: \
DO_CONV_KERN_FUN(_stride, filter, bias_mode, HSwishOp<dt_float32>) \
break; \
default: \
megdnn_assert(0); \
break; \
}
#define GET_BIAS_MODE_PARAM(_stride, filter) \
switch (param.bias_mode) { \
case BiasMode::NO_BIAS: \
GET_OP_PARAM(_stride, filter, BiasMode::NO_BIAS) \
break; \
case BiasMode::BROADCAST_CHANNEL_BIAS: \
GET_OP_PARAM(_stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS) \
break; \
case BiasMode::BIAS: \
GET_OP_PARAM(_stride, filter, BiasMode::BIAS) \
break; \
default: \
megdnn_assert(0); \
break; \
}
#define DISPATCH_CONV_KERN(_stride) \
switch (param.filter_meta.spatial[0]) { \
case 2: \
GET_BIAS_MODE_PARAM(_stride, 2) \
break; \
case 3: \
GET_BIAS_MODE_PARAM(_stride, 3) \
break; \
case 5: \
GET_BIAS_MODE_PARAM(_stride, 5) \
break; \
default: \
megdnn_assert(0); \
break; \
}
#define DISPATCH_STRIDE() \
if (1 == stride) { \
DISPATCH_CONV_KERN(stride1); \
} else { \
DISPATCH_CONV_KERN(stride2); \
}
DISPATCH_STRIDE();
#undef DO_CONV_KERN_FUN
#undef GET_REMAIN_W_PARAM
#undef GET_OP_PARAM
#undef GET_BIAS_MODE_PARAM
#undef DISPATCH_CONV_KERN
#undef DISPATCH_STRIDE
megdnn_assert(do_conv_fun);
SmallVector<ConvBiasImpl::NCBKern> ret_kerns;
CpuNDRange ncb_range = {static_cast<size_t>(batch),
static_cast<size_t>(group / pack_group_size)};
auto do_conv = [do_conv_fun](const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index) {
size_t PH = kern_param.filter_meta.padding[0];
size_t PW = kern_param.filter_meta.padding[1];
size_t OH = kern_param.osz[0];
size_t OW = kern_param.osz[1];
size_t IH = kern_param.isz[0];
size_t IW = kern_param.isz[1];
size_t batch_id = ncb_index.ndrange_id[0];
size_t group_id = ncb_index.ndrange_id[1];
const float* sptr =
kern_param.src<float>(batch_id, group_id, 0, pack_group_size);
const float* fptr = kern_param.filter<float>(group_id, pack_group_size);
float* dst =
kern_param.dst<float>(batch_id, group_id, 0, pack_group_size);
const float* bptr =
kern_param.bias<float>(batch_id, group_id, 0, pack_group_size);
//! copy in case of illegal read src when padding is zero
do_conv_fun(sptr, fptr, bptr, dst, IH, IW, OH, OW, PH, PW);
};
ret_kerns.push_back({do_conv, ncb_range});
return ret_kerns;
}
//vim: syntax=cpp.doxygen
/**
* \file dnn/src/arm_common/conv_bias/fp32/channel_wise_nchw44_kern.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "src/arm_common/conv_bias/opr_impl.h"
#include "src/fallback/conv_bias/common.h"
namespace megdnn {
namespace arm_common {
namespace channel_wise_nchw44_float {
#define KERN(stride, i) \
template <BiasMode bias_mode, typename Op> \
void do_conv_kern_##stride##_##i##x##i( \
const float* src, const float* filter, const float* bias, \
float* dst, const size_t IH, const size_t IW, const size_t OH, \
const size_t OW, const size_t PH, const size_t PW);
KERN(stride1, 2)
KERN(stride1, 3)
KERN(stride1, 5)
KERN(stride2, 2)
KERN(stride2, 3)
KERN(stride2, 5)
#undef KERN
} // namespace channel_wise_nchw44_float
} // namespace arm_common
} // namespace megdnn
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/arm_common/conv_bias/int8/direct.cpp
* \file dnn/src/arm_common/conv_bias/int8/channel_wise_kernel.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
......@@ -1640,4 +1640,5 @@ FOR_STRIDE
#undef FOR_BIAS
#undef FOR_OP
#undef INSTANTIATION
// vim: syntax=cpp.doxygen
......@@ -65,14 +65,17 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj {
AlgoDotU8DirectStride2 du8_direct_stride2_small_group{false};
#endif
AlgoF32DirectStride2NCHWNCHW44 f32_direct_stride2_nchw_nchw44;
AlgoF32ChannelWiseNCHW44 f32_chanel_wise_nchw44;
AlgoF32DirectNCHW44 f32_direct_nchw44;
AlgoF32Direct f32_direct_large_group{true};
AlgoF32Direct f32_direct_small_group{false};
AlgoF32DirectNCHW44 f32_direct_nchw44;
AlgoF32DirectStride2 f32_direct_stride2_large_group{true};
AlgoF32DirectStride2 f32_direct_stride2_small_group{false};
AlgoF32DirectStride1 f32_direct_stride1_large_group{true};
AlgoF32DirectStride1 f32_direct_stride1_small_group{false};
AlgoF32DirectStride2NCHWNCHW44 f32_direct_stride2_nchw_nchw44;
AlgoI8x8x16Direct i8x8x16_direct_large_group{true};
AlgoI8x8x16Direct i8x8x16_direct_small_group{false};
AlgoI8x8x16Stride2 i8x8x16_stride2_large_group{true};
......@@ -125,8 +128,11 @@ public:
direct_algos.emplace_back(&i8x8x16_stride2_filter2);
direct_algos.emplace_back(&i8x8x16_stride2_large_group);
direct_algos.emplace_back(&i8x8x16_stride2_small_group);
direct_algos.emplace_back(&f32_direct_stride2_nchw_nchw44);
direct_algos.emplace_back(&f32_chanel_wise_nchw44);
direct_algos.emplace_back(&f32_direct_nchw44);
direct_algos.emplace_back(&f32_direct_stride1_large_group);
direct_algos.emplace_back(&f32_direct_stride1_small_group);
direct_algos.emplace_back(&f32_direct_stride2_large_group);
......
......@@ -66,10 +66,10 @@ private:
#endif
class AlgoF32Direct;
class AlgoF32DirectStride1;
class AlgoF32DirectNCHW44;
class AlgoF32DirectStride2;
class AlgoF32DirectStride2NCHWNCHW44;
class AlgoF32DirectStride2NCHW44;
class AlgoF32ChannelWiseNCHW44;
class AlgoF32DirectNCHW44;
class AlgoI8x8x16Direct;
class AlgoI8x8x16Stride2;
......
......@@ -1086,6 +1086,155 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_QUINT8_STRIDE2) {
used1 / used0);
}
}
TEST_F(ARM_COMMON, BENCHMARK_CHANNEL_WISE_F32_STRIDE1_NCHW44) {
// have to remove preferred restrict in usable func before run the benchmark
using namespace conv_bias;
param::ConvBias param;
param.stride_h = 1;
param.stride_w = 1;
param.pad_h = 1;
param.pad_w = 1;
param.nonlineMode = NonlineMode::RELU;
param.sparse = param::ConvBias::Sparse::GROUP;
constexpr size_t RUN = 50;
Benchmarker<ConvBias> benchmark0(handle());
benchmark0.set_display(false);
benchmark0.set_param(param);
benchmark0.set_times(RUN);
benchmark0.set_before_exec_callback(
conv_bias::ConvBiasAlgoChecker<ConvBiasForward>(
"F32STRD1_LARGE_GROUP"));
auto opr = handle()->create_operator<ConvBias>();
opr->param() = param;
param.format = param::ConvBias::Format::NCHW44;
Benchmarker<ConvBias> benchmark1(handle());
benchmark1.set_display(false);
benchmark1.set_param(param);
benchmark1.set_times(RUN);
benchmark1.set_before_exec_callback(
conv_bias::ConvBiasAlgoChecker<ConvBiasForward>(
"F32_CHANNEL_WISE_NCHW44"));
auto run = [&](size_t group, size_t w, size_t h, size_t kernel) {
TensorLayout dst_layout;
opr->deduce_layout({{1, group * 4, h, w}, dtype::Int8()},
{{group * 4, 1, 1, kernel, kernel}, dtype::Int8()},
{{1, group * 4, 1, 1}, dtype::Int32()}, {},
dst_layout);
//! dst.nr_elems * IC * FH * FW * 2
float computations = dst_layout.total_nr_elems() * kernel * kernel *
2.0 / (1024 * 1024 * 1024) * 1e3;
auto used0 = benchmark0.exec({{1, group * 4, h, w},
{group * 4, 1, 1, kernel, kernel},
{1, group * 4, 1, 1},
{},
{}}) /
RUN;
auto used1 = benchmark1.exec({{1, group, h, w, 4},
{group, 1, 1, kernel, kernel, 4},
{1, group, 1, 1, 4},
{},
{}}) /
RUN;
printf("group/h/w/kernel:%zu,%zu,%zu,%zu: nchw: %f ms %f Gflops "
"nchw44: "
"%f ms %f GFlops "
"speedup: %f\n",
group, h, w, kernel, used0, computations / used0, used1,
computations / used1, used0 / used1);
};
for (size_t group : {8, 16, 32, 64}) {
for (size_t kerenl : {2, 3, 5}) {
run(group, 112, 112, kerenl);
run(group, 56, 56, kerenl);
run(group, 48, 48, kerenl);
run(group, 28, 28, kerenl);
run(group, 14, 14, kerenl);
}
}
run(8, 112, 112, 3);
run(32, 56, 56, 3);
run(64, 28, 28, 3);
run(128, 14, 14, 3);
}
TEST_F(ARM_COMMON, BENCHMARK_CHANNEL_WISE_F32_STRIDE2_NCHW44) {
// have to remove preferred restrict in usable func before run the benchmark
using namespace conv_bias;
param::ConvBias param;
param.stride_h = 2;
param.stride_w = 2;
param.pad_h = 1;
param.pad_w = 1;
param.nonlineMode = NonlineMode::RELU;
param.sparse = param::ConvBias::Sparse::GROUP;
constexpr size_t RUN = 50;
Benchmarker<ConvBias> benchmark0(handle());
benchmark0.set_display(false);
benchmark0.set_param(param);
benchmark0.set_times(RUN);
benchmark0.set_before_exec_callback(
conv_bias::ConvBiasAlgoChecker<ConvBiasForward>(
"F32STRD2_LARGE_GROUP"));
auto opr = handle()->create_operator<ConvBias>();
opr->param() = param;
param.format = param::ConvBias::Format::NCHW44;
Benchmarker<ConvBias> benchmark1(handle());
benchmark1.set_display(false);
benchmark1.set_param(param);
benchmark1.set_times(RUN);
benchmark1.set_before_exec_callback(
conv_bias::ConvBiasAlgoChecker<ConvBiasForward>(
"F32_CHANNEL_WISE_NCHW44"));
auto run = [&](size_t group, size_t w, size_t h, size_t kernel) {
TensorLayout dst_layout;
opr->deduce_layout({{1, group * 4, h, w}, dtype::Int8()},
{{group * 4, 1, 1, kernel, kernel}, dtype::Int8()},
{{1, group * 4, 1, 1}, dtype::Int32()}, {},
dst_layout);
//! dst.nr_elems * IC * FH * FW * 2
float computations = dst_layout.total_nr_elems() * kernel * kernel *
2.0 / (1024 * 1024 * 1024) * 1e3;
auto used0 = benchmark0.exec({{1, group * 4, h, w},
{group * 4, 1, 1, kernel, kernel},
{1, group * 4, 1, 1},
{},
{}}) /
RUN;
auto used1 = benchmark1.exec({{1, group, h, w, 4},
{group, 1, 1, kernel, kernel, 4},
{1, group, 1, 1, 4},
{},
{}}) /
RUN;
printf("group/h/w/kernel:%zu,%zu,%zu,%zu: nchw: %f ms %f Gflops "
"nchw44: "
"%f ms %f GFlops "
"speedup: %f\n",
group, h, w, kernel, used0, computations / used0, used1,
computations / used1, used0 / used1);
};
for (size_t group : {8, 16, 32, 64}) {
for (size_t kerenl : {2, 3, 5}) {
run(group, 112, 112, kerenl);
run(group, 56, 56, kerenl);
run(group, 48, 48, kerenl);
run(group, 28, 28, kerenl);
run(group, 14, 14, kerenl);
}
}
run(8, 112, 112, 3);
run(32, 56, 56, 3);
run(64, 28, 28, 3);
run(128, 14, 14, 3);
}
TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_QINT8_STRIDE1_NCHW44) {
// have to remove preferred restrict in usable func before run the benchmark
......
......@@ -181,9 +181,9 @@ std::vector<conv_bias::TestArg> get_nchw44_conv_bias_args(
return args;
}
std::vector<conv_bias::TestArg> get_int8_quint8_nchw44_channel_wise_args(
std::vector<conv_bias::TestArg> get_nchw44_channel_wise_args(
std::vector<size_t> kernel, size_t stride, bool no_bias,
bool no_nonlinemode) {
bool no_nonlinemode, bool no_full_bias) {
using namespace conv_bias;
using Param = param::ConvBias;
using NLMode = param::ConvBias::NonlineMode;
......@@ -213,6 +213,15 @@ std::vector<conv_bias::TestArg> get_int8_quint8_nchw44_channel_wise_args(
TensorShape{group, 1, 1, kernel, kernel, 4},
TensorShape{1, group, 1, 1, 4});
}
if (!no_full_bias) {
args.emplace_back(
param, TensorShape{n, group, h, w, 4},
TensorShape{group, 1, 1, kernel, kernel, 4},
TensorShape{n, group,
(h + 2 * param.pad_w - kernel) / stride + 1,
(w + 2 * param.pad_w - kernel) / stride + 1,
4});
}
};
std::vector<NLMode> nonlinemode = {NLMode::IDENTITY};
......@@ -224,7 +233,7 @@ std::vector<conv_bias::TestArg> get_int8_quint8_nchw44_channel_wise_args(
for (auto nlmode : nonlinemode) {
for (bool pad : {true}) {
for (size_t group : {1, 2, 4, 7, 128}) {
for (size_t size : {4, 5, 6, 7, 8, 9, 10, 15, 40}) {
for (size_t size : {4, 6, 7, 9, 15, 40}) {
for (size_t kern : kernel) {
pack(n, group, size, size, kern, stride, nlmode,
pad);
......@@ -234,7 +243,7 @@ std::vector<conv_bias::TestArg> get_int8_quint8_nchw44_channel_wise_args(
}
for (bool pad : {false}) {
for (size_t group : {1, 2, 7, 128}) {
for (size_t size : {7, 8, 9, 10, 15, 40}) {
for (size_t size : {7, 9, 15, 40}) {
for (size_t kern : kernel) {
pack(n, group, size, size, kern, stride, nlmode,
pad);
......@@ -374,6 +383,18 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_NCHW_NCHW44_F32) {
false, true),
handle(), "F32_CONV_NCHW_NCHW44");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_CHANNEL_WISE_STRIDE1_FP32_NCHW44) {
check_conv_bias(
get_nchw44_channel_wise_args({2, 3, 5}, 1, false, false, false),
handle(), "F32_CHANNEL_WISE_NCHW44");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_CHANNEL_WISE_STRIDE2_FP32_NCHW44) {
check_conv_bias(
get_nchw44_channel_wise_args({2, 3, 5}, 2, false, false, false),
handle(), "F32_CHANNEL_WISE_NCHW44");
}
/**********************************F16 direct************************/
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16_LARGE_GROUP) {
......@@ -447,14 +468,14 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT32_STRIDE2_SMALL_GROUP) {
TEST_F(ARM_COMMON_MULTI_THREADS,
CONV_BIAS_INT8_INT8_INT32_CHANNEL_WISE_DIRECT1_NCHW44) {
checker_conv_bias_int8x8x32_multi(
get_int8_quint8_nchw44_channel_wise_args({2, 3, 5}, 1, false, true),
get_nchw44_channel_wise_args({2, 3, 5}, 1, false, true, true),
handle(), "S8_CHAN_WISE_STRD1_NCHW44");
}
TEST_F(ARM_COMMON_MULTI_THREADS,
CONV_BIAS_INT8_INT8_INT32_CHANNEL_WISE_DIRECT2_NCHW44) {
checker_conv_bias_int8x8x32_multi(
get_int8_quint8_nchw44_channel_wise_args({2, 3, 5}, 2, false, true),
get_nchw44_channel_wise_args({2, 3, 5}, 2, false, true, true),
handle(), "S8_CHAN_WISE_STRD2_NCHW44");
}
......@@ -490,14 +511,14 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2_NCHW44) {
handle(), "S8_NCHW44_DIRECT_STRD2");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QS8_CHANNEL_WISE_DIRECT1_NCHW44) {
checker_conv_bias_qint8x8x8(get_int8_quint8_nchw44_channel_wise_args(
{2, 3, 5}, 1, false, false),
checker_conv_bias_qint8x8x8(
get_nchw44_channel_wise_args({2, 3, 5}, 1, false, false, true),
handle(), "S8_CHAN_WISE_STRD1_NCHW44");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QS8_CHANNEL_WISE_DIRECT2_NCHW44) {
checker_conv_bias_qint8x8x8(get_int8_quint8_nchw44_channel_wise_args(
{2, 3, 5}, 2, false, false),
checker_conv_bias_qint8x8x8(
get_nchw44_channel_wise_args({2, 3, 5}, 2, false, false, true),
handle(), "S8_CHAN_WISE_STRD2_NCHW44");
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册