Commit 89374521 authored by Megvii Engine Team

fix(dnn/arm_common): add nchw44 float channel wise s1/s2

GitOrigin-RevId: 73e6aa1e57c36b5f8bc4e05faf4e9f06ec7e5cb7
Parent 9f997ac5
......@@ -178,13 +178,16 @@ public:
const NCBKernSizeParam& param) const override;
};
class ConvBiasImpl::AlgoF32DirectNCHW44 final : public AlgoBase {
class ConvBiasImpl::AlgoF32DirectStride1 final : public AlgoBase {
SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const;
bool m_large_group;
public:
AlgoF32DirectNCHW44() {}
AlgoF32DirectStride1(bool is_large_group) : m_large_group{is_large_group} {}
bool is_reproducible() const override { return true; }
const char* name() const override { return "F32_CONV_NCHW44_DIRECT"; }
const char* name() const override {
return m_large_group ? "F32STRD1_LARGE_GROUP" : "F32STRD1_SMALL_GROUP";
}
bool usable(fallback::ConvBiasImpl*, const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;
......@@ -194,13 +197,17 @@ public:
fallback::ConvBiasImpl* opr,
const NCBKernSizeParam& param) const override;
};
class ConvBiasImpl::AlgoF32DirectStride2NCHW44 final : public AlgoBase {
class ConvBiasImpl::AlgoF32DirectStride2 final : public AlgoBase {
SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const;
bool m_large_group;
public:
AlgoF32DirectStride2NCHW44() {}
AlgoF32DirectStride2(bool is_large_group) : m_large_group{is_large_group} {}
bool is_reproducible() const override { return true; }
const char* name() const override { return "F32_CONV_NCHW44_DIRECT_S2"; }
const char* name() const override {
return m_large_group ? "F32STRD2_LARGE_GROUP" : "F32STRD2_SMALL_GROUP";
}
bool usable(fallback::ConvBiasImpl*, const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;
......@@ -211,16 +218,13 @@ public:
const NCBKernSizeParam& param) const override;
};
class ConvBiasImpl::AlgoF32DirectStride1 final : public AlgoBase {
class ConvBiasImpl::AlgoF32DirectNCHW44 final : public AlgoBase {
SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const;
bool m_large_group;
public:
AlgoF32DirectStride1(bool is_large_group) : m_large_group{is_large_group} {}
AlgoF32DirectNCHW44() {}
bool is_reproducible() const override { return true; }
const char* name() const override {
return m_large_group ? "F32STRD1_LARGE_GROUP" : "F32STRD1_SMALL_GROUP";
}
const char* name() const override { return "F32_CONV_NCHW44_DIRECT"; }
bool usable(fallback::ConvBiasImpl*, const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;
......@@ -231,33 +235,29 @@ public:
const NCBKernSizeParam& param) const override;
};
class ConvBiasImpl::AlgoF32DirectStride2 final : public AlgoBase {
class ConvBiasImpl::AlgoF32DirectStride2NCHWNCHW44 final : public AlgoBase {
SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const;
bool m_large_group;
public:
AlgoF32DirectStride2(bool is_large_group) : m_large_group{is_large_group} {}
AlgoF32DirectStride2NCHWNCHW44() {}
bool is_reproducible() const override { return true; }
const char* name() const override {
return m_large_group ? "F32STRD2_LARGE_GROUP" : "F32STRD2_SMALL_GROUP";
}
bool usable(fallback::ConvBiasImpl*, const NCBKernSizeParam& param,
const char* name() const override { return "F32_CONV_NCHW_NCHW44"; }
bool usable(fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;
size_t get_workspace(fallback::ConvBiasImpl*,
size_t get_workspace(fallback::ConvBiasImpl* opr,
const NCBKernSizeParam& param) const override;
virtual SmallVector<NCBKern> dispatch_kerns(
fallback::ConvBiasImpl* opr,
const NCBKernSizeParam& param) const override;
};
class ConvBiasImpl::AlgoF32DirectStride2NCHWNCHW44 final : public AlgoBase {
class ConvBiasImpl::AlgoF32ChannelWiseNCHW44 final : public AlgoBase {
SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const;
public:
AlgoF32DirectStride2NCHWNCHW44() {}
bool is_reproducible() const override { return true; }
const char* name() const override { return "F32_CONV_NCHW_NCHW44"; }
const char* name() const override { return "F32_CHANNEL_WISE_NCHW44"; }
bool usable(fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;
......
/**
* \file dnn/src/arm_common/conv_bias/fp32/channel_wise_nchw44_algo.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/channel_wise_nchw44_kern.h"
#include "src/arm_common/conv_bias/fp32/algos.h"
#include "src/arm_common/elemwise_op.h"
#include "midout.h"
using namespace megdnn;
using namespace arm_common;
using conv_fun = std::function<void(
const float* src, const float* filter, const float* bias, float* dst,
const size_t IH, const size_t IW, const size_t OH, const size_t OW,
const size_t PH, size_t PW)>;
MIDOUT_DECL(conv_bias_fp32_channel_wise_nchw44)
bool ConvBiasImpl::AlgoF32ChannelWiseNCHW44::usable(
fallback::ConvBiasImpl*, const NCBKernSizeParam& param,
AlgoSelectionStrategy) const {
auto&& fm = param.filter_meta;
auto FH = fm.spatial[0];
size_t OC = fm.ocpg;
size_t IC = fm.icpg;
size_t GROUP = fm.group;
bool ok_type = (param.src_type.enumv() == DTypeEnum::Float32 &&
param.filter_type.enumv() == DTypeEnum::Float32 &&
(param.dst_type.enumv() == DTypeEnum::Float32));
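// channel-wise NCHW44: one input and one output channel per group, with
// groups packed by 4 along the channel axis, hence GROUP % 4 == 0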
bool ok_format = OC == 1 && IC == 1 && GROUP % 4 == 0 &&
fm.format == param::Convolution::Format::NCHW44;
bool ok_filter = fm.spatial_ndim == 2 && FH == fm.spatial[1] &&
(FH == 2 || FH == 3 || FH == 5);
bool ok_slide = fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
fm.stride[0] == fm.stride[1] &&
(fm.stride[0] == 1 || fm.stride[0] == 2);
bool ok_conv = !fm.should_flip;
bool available = ok_type && ok_format && ok_filter && ok_slide && ok_conv;
return available;
}
size_t ConvBiasImpl::AlgoF32ChannelWiseNCHW44::get_workspace(
fallback::ConvBiasImpl*, const NCBKernSizeParam&) const {
return 0;
}
SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoF32ChannelWiseNCHW44::dispatch_kerns(
fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const {
const constexpr size_t pack_group_size = 4_z;
auto fm = param.filter_meta;
const int batch = param.n;
const int group = fm.group;
const int stride = fm.stride[0];
conv_fun do_conv_fun = nullptr;
// NOTE: remain_w is not used when generating the midout hash, to stay
// compatible with runtime-varying shapes
#define DO_CONV_KERN_FUN(_stride, filter, bias_mode, op) \
MIDOUT_BEGIN(conv_bias_fp32_channel_wise_nchw44, \
midout_iv(#_stride #filter #bias_mode #op##_hash)) { \
do_conv_fun = channel_wise_nchw44_float:: \
do_conv_kern_##_stride##_##filter##x##filter<bias_mode, op>; \
} \
MIDOUT_END();
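// The macros below pick a concrete kernel instantiation along the chain
// stride -> filter size -> bias mode -> activation op, with each leaf
// wrapped in a midout marker.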
#define GET_OP_PARAM(_stride, filter, bias_mode) \
switch (param.nonlineMode) { \
case param::ConvBias::NonlineMode::IDENTITY: \
DO_CONV_KERN_FUN(_stride, filter, bias_mode, NoneOp<dt_float32>) \
break; \
case param::ConvBias::NonlineMode::RELU: \
DO_CONV_KERN_FUN(_stride, filter, bias_mode, ReluOp<dt_float32>) \
break; \
case param::ConvBias::NonlineMode::SIGMOID: \
DO_CONV_KERN_FUN(_stride, filter, bias_mode, \
SigmoidOp<dt_float32>) \
break; \
case param::ConvBias::NonlineMode::H_SWISH: \
DO_CONV_KERN_FUN(_stride, filter, bias_mode, HSwishOp<dt_float32>) \
break; \
default: \
megdnn_assert(0); \
break; \
}
#define GET_BIAS_MODE_PARAM(_stride, filter) \
switch (param.bias_mode) { \
case BiasMode::NO_BIAS: \
GET_OP_PARAM(_stride, filter, BiasMode::NO_BIAS) \
break; \
case BiasMode::BROADCAST_CHANNEL_BIAS: \
GET_OP_PARAM(_stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS) \
break; \
case BiasMode::BIAS: \
GET_OP_PARAM(_stride, filter, BiasMode::BIAS) \
break; \
default: \
megdnn_assert(0); \
break; \
}
#define DISPATCH_CONV_KERN(_stride) \
switch (param.filter_meta.spatial[0]) { \
case 2: \
GET_BIAS_MODE_PARAM(_stride, 2) \
break; \
case 3: \
GET_BIAS_MODE_PARAM(_stride, 3) \
break; \
case 5: \
GET_BIAS_MODE_PARAM(_stride, 5) \
break; \
default: \
megdnn_assert(0); \
break; \
}
#define DISPATCH_STRIDE() \
if (1 == stride) { \
DISPATCH_CONV_KERN(stride1); \
} else { \
DISPATCH_CONV_KERN(stride2); \
}
DISPATCH_STRIDE();
#undef DO_CONV_KERN_FUN
#undef GET_REMAIN_W_PARAM
#undef GET_OP_PARAM
#undef GET_BIAS_MODE_PARAM
#undef DISPATCH_CONV_KERN
#undef DISPATCH_STRIDE
megdnn_assert(do_conv_fun);
SmallVector<ConvBiasImpl::NCBKern> ret_kerns;
CpuNDRange ncb_range = {static_cast<size_t>(batch),
static_cast<size_t>(group / pack_group_size)};
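// one kernel invocation per (batch, packed-group) pair; each invocation
// computes the full output plane for its group of 4 channels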
auto do_conv = [do_conv_fun](const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index) {
size_t PH = kern_param.filter_meta.padding[0];
size_t PW = kern_param.filter_meta.padding[1];
size_t OH = kern_param.osz[0];
size_t OW = kern_param.osz[1];
size_t IH = kern_param.isz[0];
size_t IW = kern_param.isz[1];
size_t batch_id = ncb_index.ndrange_id[0];
size_t group_id = ncb_index.ndrange_id[1];
const float* sptr =
kern_param.src<float>(batch_id, group_id, 0, pack_group_size);
const float* fptr = kern_param.filter<float>(group_id, pack_group_size);
float* dst =
kern_param.dst<float>(batch_id, group_id, 0, pack_group_size);
const float* bptr =
kern_param.bias<float>(batch_id, group_id, 0, pack_group_size);
//! copy src to avoid illegal reads when padding is zero
do_conv_fun(sptr, fptr, bptr, dst, IH, IW, OH, OW, PH, PW);
};
ret_kerns.push_back({do_conv, ncb_range});
return ret_kerns;
}
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/arm_common/conv_bias/fp32/channel_wise_nchw44_kern.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/channel_wise_nchw44_kern.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/arm_common/utils.h"
#include "src/common/unroll_macro.h"
#include "src/common/utils.h"
#include "src/fallback/conv_bias/common.h"
using namespace megdnn;
using namespace arm_common;
namespace {
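// load_vec<n>: load n consecutive float32x4 vectors (n NCHW44 pixels of
// 4 channels each) starting at src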
template <int size>
void load_vec(float32x4_t* dst, const float* src);
#define cb(i) dst[i] = vld1q_f32(src + i * 4);
#define LOAD_MACRO(n) \
template <> \
inline void load_vec<n>(float32x4_t * dst, const float* src) { \
UNROLL_CALL_NOWRAPPER(n, cb); \
}
LOAD_MACRO(2);
LOAD_MACRO(3);
LOAD_MACRO(4);
LOAD_MACRO(5);
LOAD_MACRO(6);
LOAD_MACRO(7);
LOAD_MACRO(8);
LOAD_MACRO(9);
#undef cb
#undef LOAD_MACRO
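// compute_vec<n>: dst += src[i] * filter[i] for i in [0, n), via NEON
// multiply-accumulate (vmlaq_f32)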
template <int size>
void compute_vec(float32x4_t& dst, float32x4_t* src, float32x4_t* filter);
#define cb(i) dst = vmlaq_f32(dst, src[i], filter[i]);
#define COMPUTE_MACRO(n) \
template <> \
inline void compute_vec<n>(float32x4_t & dst, float32x4_t * src, \
float32x4_t * filter) { \
UNROLL_CALL_NOWRAPPER(n, cb); \
}
COMPUTE_MACRO(2);
COMPUTE_MACRO(3);
COMPUTE_MACRO(5);
#undef cb
#undef COMPUTE_MACRO
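// load_bias_vec<bias_mode, n>: fill n accumulators either from the
// per-element bias array (BiasMode::BIAS) or from the broadcast init vector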
template <BiasMode bias_mode, int size>
struct load_bias_vec;
#define cb_bias(i) dst[i] = vld1q_f32((bptr) + i * 4);
#define cb_init(i) dst[i] = init;
#define INIT_BIAS_MACRO(n) \
template <BiasMode bias_mode> \
struct load_bias_vec<bias_mode, n> { \
static void impl(float32x4_t* dst, const float32x4_t& init, \
const float* bptr) { \
if (bias_mode == BiasMode::BIAS) { \
UNROLL_CALL_NOWRAPPER(n, cb_bias); \
} else { \
UNROLL_CALL_NOWRAPPER(n, cb_init); \
} \
} \
};
INIT_BIAS_MACRO(1);
INIT_BIAS_MACRO(2);
INIT_BIAS_MACRO(4);
#undef cb_bias
#undef cb_init
#undef INIT_BIAS_MACRO
} // namespace
#define COMPUTE_PADDING_KERNEL() \
do { \
int iw = ow * stride - PW; \
float32x4_t result; \
load_bias_vec<bias_mode, 1>::impl(&result, init, \
bias + oh * OW * 4 + ow * 4); \
for (int kh = 0; kh < fh; kh++) { \
if (kh + ih < 0 || kh + ih >= static_cast<int>(IH)) \
continue; \
for (int kw = 0; kw < fh; kw++) { \
if (kw + iw < 0 || kw + iw >= static_cast<int>(IW)) \
continue; \
const float* sptr = src + (kh + ih) * IW * 4 + (kw + iw) * 4; \
result = vmlaq_f32(result, kernel[kh * fh + kw], \
vld1q_f32(sptr)); \
} \
} \
float* output = dst + oh * OW * 4 + ow * 4; \
op(result, output); \
} while (0)
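// PaddingCompute handles only the border outputs whose receptive field
// overlaps the zero padding, skipping out-of-range taps; the vectorized
// main loops handle the interior.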
template <BiasMode bias_mode, typename Op>
struct PaddingCompute {
static void compute(const float* src, const float* bias, float* dst,
const int fh, const int stride, const size_t IH,
const size_t IW, const size_t OH, const size_t OW,
const size_t PH, const size_t PW,
const float32x4_t* kernel, const float32x4_t& init) {
size_t oh_start = (PH + stride - 1) / stride;
size_t ow_start = (PW + stride - 1) / stride;
size_t oh_end = (IH + PH - fh) / stride + 1;
size_t ow_end = (IW + PW - fh) / stride + 1;
Op op;
for (size_t oh = 0; oh < oh_start; oh++) {
int ih = oh * stride - PH;
for (size_t ow = 0; ow < OW; ow++) {
COMPUTE_PADDING_KERNEL();
}
}
for (size_t oh = oh_start; oh < oh_end; oh++) {
int ih = oh * stride - PH;
for (size_t ow = 0; ow < ow_start; ow++) {
COMPUTE_PADDING_KERNEL();
}
for (size_t ow = ow_end; ow < OW; ow++) {
COMPUTE_PADDING_KERNEL();
}
}
for (size_t oh = oh_end; oh < OH; oh++) {
int ih = oh * stride - PH;
for (size_t ow = 0; ow < OW; ow++) {
COMPUTE_PADDING_KERNEL();
}
}
}
};
template <BiasMode bias_mode, typename Op>
struct PaddingComputeK3P1 {
static void compute(const float* src, const float* bias, float* dst,
const size_t stride, const size_t IH, const size_t IW,
const size_t OH, const size_t OW,
const float32x4_t* kernel, const float32x4_t& init) {
constexpr size_t PH = 1, PW = 1, FH = 3;
size_t oh_start = (PH + stride - 1) / stride;
size_t ow_start = (PW + stride - 1) / stride;
size_t oh_end = (IH + PH - FH) / stride + 1;
size_t ow_end = (IW + PW - FH) / stride + 1;
megdnn_assert(oh_start == ow_start && oh_start == 1,
"channel wise padding param error");
megdnn_assert(ow_end == OW - 1 || ow_end == OW, "padding PW error");
megdnn_assert(oh_end == OH - 1 || oh_end == OH, "padding PH error");
Op op;
// line one left
{
float32x4_t result;
load_bias_vec<bias_mode, 1>::impl(&result, init, bias);
result = vmlaq_f32(result, kernel[4], vld1q_f32(src));
result = vmlaq_f32(result, kernel[5], vld1q_f32(src + 4));
result = vmlaq_f32(result, kernel[7], vld1q_f32(src + IW * 4));
result = vmlaq_f32(result, kernel[8], vld1q_f32(src + IW * 4 + 4));
float* output = dst;
op(result, output);
}
// line one mid
for (size_t ow = ow_start; ow < ow_end; ow++) {
int iw = ow * stride - PW;
float32x4_t result;
load_bias_vec<bias_mode, 1>::impl(&result, init, bias + ow * 4);
const float* sptr = src + iw * 4;
result = vmlaq_f32(result, kernel[3], vld1q_f32(sptr));
result = vmlaq_f32(result, kernel[4], vld1q_f32(sptr + 4));
result = vmlaq_f32(result, kernel[5], vld1q_f32(sptr + 8));
result = vmlaq_f32(result, kernel[6], vld1q_f32(sptr + IW * 4));
result = vmlaq_f32(result, kernel[7], vld1q_f32(sptr + IW * 4 + 4));
result = vmlaq_f32(result, kernel[8], vld1q_f32(sptr + IW * 4 + 8));
float* output = dst + ow * 4;
op(result, output);
}
// line one right
if (OW != ow_end) {
float32x4_t result;
load_bias_vec<bias_mode, 1>::impl(&result, init,
bias + (OW - 1) * 4);
const float* sptr = src + (ow_end * stride - PW) * 4;
result = vmlaq_f32(result, kernel[3], vld1q_f32(sptr));
result = vmlaq_f32(result, kernel[4], vld1q_f32(sptr + 4));
result = vmlaq_f32(result, kernel[6], vld1q_f32(sptr + IW * 4));
result = vmlaq_f32(result, kernel[7], vld1q_f32(sptr + IW * 4 + 4));
float* output = dst + ow_end * 4;
op(result, output);
}
// mid line
for (size_t oh = oh_start; oh < oh_end; oh++) {
int ih = oh * stride - PH;
// left
{
float32x4_t result;
load_bias_vec<bias_mode, 1>::impl(&result, init,
bias + oh * OW * 4);
const float* sptr = src + ih * IW * 4;
result = vmlaq_f32(result, kernel[1], vld1q_f32(sptr));
result = vmlaq_f32(result, kernel[2], vld1q_f32(sptr + 4));
result = vmlaq_f32(result, kernel[4], vld1q_f32(sptr + IW * 4));
result = vmlaq_f32(result, kernel[5],
vld1q_f32(sptr + IW * 4 + 4));
result = vmlaq_f32(result, kernel[7],
vld1q_f32(sptr + 2 * IW * 4));
result = vmlaq_f32(result, kernel[8],
vld1q_f32(sptr + 2 * IW * 4 + 4));
float* output = dst + oh * OW * 4;
op(result, output);
}
// right
if (OW != ow_end) {
float32x4_t result;
load_bias_vec<bias_mode, 1>::impl(
&result, init, bias + oh * OW * 4 + (OW - 1) * 4);
const float* sptr =
src + ih * IW * 4 + (ow_end * stride - PW) * 4;
result = vmlaq_f32(result, kernel[0], vld1q_f32(sptr));
result = vmlaq_f32(result, kernel[1], vld1q_f32(sptr + 4));
result = vmlaq_f32(result, kernel[3], vld1q_f32(sptr + IW * 4));
result = vmlaq_f32(result, kernel[4],
vld1q_f32(sptr + IW * 4 + 4));
result = vmlaq_f32(result, kernel[6],
vld1q_f32(sptr + 2 * IW * 4));
result = vmlaq_f32(result, kernel[7],
vld1q_f32(sptr + 2 * IW * 4 + 4));
float* output = dst + oh * OW * 4 + ow_end * 4;
op(result, output);
}
}
// last line left
if (OH != oh_end) {
size_t oh = OH - 1;
{
float32x4_t result;
load_bias_vec<bias_mode, 1>::impl(&result, init,
bias + oh * OW * 4);
const float* sptr = src + (oh_end * stride - PH) * IW * 4;
result = vmlaq_f32(result, kernel[1], vld1q_f32(sptr));
result = vmlaq_f32(result, kernel[2], vld1q_f32(sptr + 4));
result = vmlaq_f32(result, kernel[4], vld1q_f32(sptr + IW * 4));
result = vmlaq_f32(result, kernel[5],
vld1q_f32(sptr + IW * 4 + 4));
float* output = dst + oh_end * OW * 4;
op(result, output);
}
// last line mid
for (size_t ow = ow_start; ow < ow_end; ow++) {
int iw = ow * stride - PW;
float32x4_t result;
load_bias_vec<bias_mode, 1>::impl(&result, init,
bias + oh * OW * 4 + ow * 4);
const float* sptr =
src + (oh_end * stride - PH) * IW * 4 + iw * 4;
result = vmlaq_f32(result, kernel[0], vld1q_f32(sptr));
result = vmlaq_f32(result, kernel[1], vld1q_f32(sptr + 4));
result = vmlaq_f32(result, kernel[2], vld1q_f32(sptr + 8));
result = vmlaq_f32(result, kernel[3], vld1q_f32(sptr + IW * 4));
result = vmlaq_f32(result, kernel[4],
vld1q_f32(sptr + IW * 4 + 4));
result = vmlaq_f32(result, kernel[5],
vld1q_f32(sptr + IW * 4 + 8));
float* output = dst + oh_end * OW * 4 + ow * 4;
op(result, output);
}
// last line right
if (OW != ow_end) {
float32x4_t result;
load_bias_vec<bias_mode, 1>::impl(
&result, init, bias + oh * OW * 4 + (OW - 1) * 4);
const float* sptr = src + (oh_end * stride - PH) * IW * 4 +
(ow_end * stride - PW) * 4;
result = vmlaq_f32(result, kernel[0], vld1q_f32(sptr));
result = vmlaq_f32(result, kernel[1], vld1q_f32(sptr + 4));
result = vmlaq_f32(result, kernel[3], vld1q_f32(sptr + IW * 4));
result = vmlaq_f32(result, kernel[4],
vld1q_f32(sptr + IW * 4 + 4));
float* output = dst + oh_end * OW * 4 + ow_end * 4;
op(result, output);
}
}
}
};
#undef COMPUTE_PADDING_KERNEL
template <BiasMode bias_mode, typename Op>
void channel_wise_nchw44_float::do_conv_kern_stride1_2x2(
const float* src, const float* filter, const float* bias, float* dst,
const size_t IH, const size_t IW, const size_t OH, const size_t OW,
const size_t PH, const size_t PW) {
float32x4_t kernel[4];
load_vec<4>(kernel, filter);
Op op;
float32x4_t init;
if (bias_mode == BiasMode::NO_BIAS) {
init = vdupq_n_f32(0.f);
} else if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) {
init = vld1q_f32(bias);
}
size_t oh_start = PH;
size_t ow_start = PW;
size_t oh_end = IH + PH - 1;
size_t ow_end = IW + PW - 1;
if (PH || PW) {
PaddingCompute<bias_mode, Op>::compute(src, bias, dst, 2, 1, IH, IW, OH,
OW, PH, PW, kernel, init);
}
#define COMPUTE_2X2(dst, src, kernel) \
compute_vec<2>(dst[0], &src[0], kernel); \
compute_vec<2>(dst[1], &src[1], kernel); \
compute_vec<2>(dst[2], &src[2], kernel); \
compute_vec<2>(dst[3], &src[3], kernel)
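// Main loop: a 2-row x 4-column output tile; the shared middle input row
// (src_v[1]) feeds both output rows.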
size_t oh = oh_start;
for (; oh + 1 < oh_end; oh += 2) {
size_t ih = oh - oh_start;
size_t ow = ow_start;
for (; ow + 3 < ow_end; ow += 4) {
size_t iw = ow - ow_start;
const float* input = src + ih * IW * 4 + iw * 4;
float* output = dst + oh * OW * 4 + ow * 4;
float32x4_t dst_v[2][4];
load_bias_vec<bias_mode, 4>::impl(dst_v[0], init,
bias + oh * OW * 4 + ow * 4);
load_bias_vec<bias_mode, 4>::impl(
dst_v[1], init, bias + (oh + 1) * OW * 4 + ow * 4);
float32x4_t src_v[3][5];
load_vec<5>(src_v[0], input);
COMPUTE_2X2(dst_v[0], src_v[0], &kernel[0]);
load_vec<5>(src_v[1], input + IW * 4);
COMPUTE_2X2(dst_v[0], src_v[1], &kernel[2]);
COMPUTE_2X2(dst_v[1], src_v[1], &kernel[0]);
load_vec<5>(src_v[2], input + 2 * IW * 4);
COMPUTE_2X2(dst_v[1], src_v[2], &kernel[2]);
op({{dst_v[0][0], dst_v[0][1]}}, output);
op({{dst_v[0][2], dst_v[0][3]}}, output + 8);
op({{dst_v[1][0], dst_v[1][1]}}, output + OW * 4);
op({{dst_v[1][2], dst_v[1][3]}}, output + OW * 4 + 8);
}
for (; ow < ow_end; ow++) {
size_t iw = ow - ow_start;
const float* input = src + ih * IW * 4 + iw * 4;
float* output = dst + oh * OW * 4 + ow * 4;
float32x4_t dst_v[2];
load_bias_vec<bias_mode, 1>::impl(&dst_v[0], init,
bias + oh * OW * 4 + ow * 4);
load_bias_vec<bias_mode, 1>::impl(
&dst_v[1], init, bias + (oh + 1) * OW * 4 + ow * 4);
float32x4_t src_v[3][2];
load_vec<2>(src_v[0], input);
compute_vec<2>(dst_v[0], &src_v[0][0], &kernel[0]);
load_vec<2>(src_v[1], input + IW * 4);
compute_vec<2>(dst_v[0], &src_v[1][0], &kernel[2]);
compute_vec<2>(dst_v[1], &src_v[1][0], &kernel[0]);
load_vec<2>(src_v[2], input + 2 * IW * 4);
compute_vec<2>(dst_v[1], &src_v[2][0], &kernel[2]);
op(dst_v[0], output);
op(dst_v[1], output + OW * 4);
}
}
for (; oh < oh_end; oh++) {
size_t ih = oh - oh_start;
size_t ow = ow_start;
for (; ow + 3 < ow_end; ow += 4) {
size_t iw = ow - ow_start;
const float* input = src + ih * IW * 4 + iw * 4;
float* output = dst + oh * OW * 4 + ow * 4;
float32x4_t dst_v[1][4];
load_bias_vec<bias_mode, 4>::impl(dst_v[0], init,
bias + oh * OW * 4 + ow * 4);
float32x4_t src_v[2][5];
load_vec<5>(src_v[0], input);
COMPUTE_2X2(dst_v[0], src_v[0], &kernel[0]);
load_vec<5>(src_v[1], input + IW * 4);
COMPUTE_2X2(dst_v[0], src_v[1], &kernel[2]);
op({{dst_v[0][0], dst_v[0][1]}}, output);
op({{dst_v[0][2], dst_v[0][3]}}, output + 8);
}
for (; ow < ow_end; ow++) {
size_t iw = ow - ow_start;
const float* input = src + ih * IW * 4 + iw * 4;
float* output = dst + oh * OW * 4 + ow * 4;
float32x4_t dst_v;
load_bias_vec<bias_mode, 1>::impl(&dst_v, init,
bias + oh * OW * 4 + ow * 4);
float32x4_t src_v[2][2];
load_vec<2>(src_v[0], input);
compute_vec<2>(dst_v, &src_v[0][0], &kernel[0]);
load_vec<2>(src_v[1], input + IW * 4);
compute_vec<2>(dst_v, &src_v[1][0], &kernel[2]);
op(dst_v, output);
}
}
#undef COMPUTE_2X2
}
template <BiasMode bias_mode, typename Op>
void channel_wise_nchw44_float::do_conv_kern_stride1_3x3(
const float* src, const float* filter, const float* bias, float* dst,
const size_t IH, const size_t IW, const size_t OH, const size_t OW,
const size_t PH, const size_t PW) {
float32x4_t kernel[9];
load_vec<9>(kernel, filter);
Op op;
float32x4_t init;
if (bias_mode == BiasMode::NO_BIAS) {
init = vdupq_n_f32(0.f);
} else if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) {
init = vld1q_f32(bias);
}
size_t oh_start = PH;
size_t ow_start = PW;
size_t oh_end = IH + PH - 2;
size_t ow_end = IW + PW - 2;
if (PH == 1 && PW == 1) {
PaddingComputeK3P1<bias_mode, Op>::compute(src, bias, dst, 1, IH, IW,
OH, OW, kernel, init);
} else if (PH || PW) {
PaddingCompute<bias_mode, Op>::compute(src, bias, dst, 3, 1, IH, IW, OH,
OW, PH, PW, kernel, init);
}
size_t oh = oh_start;
for (; oh + 1 < oh_end; oh += 2) {
size_t ih = oh - PH;
size_t ow = ow_start;
for (; ow + 3 < ow_end; ow += 4) {
size_t iw = ow - PW;
const float* input = src + ih * IW * 4 + iw * 4;
float* output = dst + oh * OW * 4 + ow * 4;
float32x4_t dst_v[2][4];
load_bias_vec<bias_mode, 4>::impl(dst_v[0], init,
bias + oh * OW * 4 + ow * 4);
load_bias_vec<bias_mode, 4>::impl(
dst_v[1], init, bias + (oh + 1) * OW * 4 + ow * 4);
float32x4_t src_v[2][6];
load_vec<6>(src_v[0], input);
compute_vec<3>(dst_v[0][0], &src_v[0][0], &kernel[0]);
compute_vec<3>(dst_v[0][1], &src_v[0][1], &kernel[0]);
compute_vec<3>(dst_v[0][2], &src_v[0][2], &kernel[0]);
compute_vec<3>(dst_v[0][3], &src_v[0][3], &kernel[0]);
load_vec<6>(src_v[1], input + IW * 4);
compute_vec<3>(dst_v[0][0], &src_v[1][0], &kernel[3]);
compute_vec<3>(dst_v[0][1], &src_v[1][1], &kernel[3]);
compute_vec<3>(dst_v[0][2], &src_v[1][2], &kernel[3]);
compute_vec<3>(dst_v[0][3], &src_v[1][3], &kernel[3]);
compute_vec<3>(dst_v[1][0], &src_v[1][0], &kernel[0]);
compute_vec<3>(dst_v[1][1], &src_v[1][1], &kernel[0]);
compute_vec<3>(dst_v[1][2], &src_v[1][2], &kernel[0]);
compute_vec<3>(dst_v[1][3], &src_v[1][3], &kernel[0]);
load_vec<6>(src_v[0], input + 2 * IW * 4);
compute_vec<3>(dst_v[0][0], &src_v[0][0], &kernel[6]);
compute_vec<3>(dst_v[0][1], &src_v[0][1], &kernel[6]);
compute_vec<3>(dst_v[0][2], &src_v[0][2], &kernel[6]);
compute_vec<3>(dst_v[0][3], &src_v[0][3], &kernel[6]);
compute_vec<3>(dst_v[1][0], &src_v[0][0], &kernel[3]);
compute_vec<3>(dst_v[1][1], &src_v[0][1], &kernel[3]);
compute_vec<3>(dst_v[1][2], &src_v[0][2], &kernel[3]);
compute_vec<3>(dst_v[1][3], &src_v[0][3], &kernel[3]);
load_vec<6>(src_v[1], input + 3 * IW * 4);
compute_vec<3>(dst_v[1][0], &src_v[1][0], &kernel[6]);
compute_vec<3>(dst_v[1][1], &src_v[1][1], &kernel[6]);
compute_vec<3>(dst_v[1][2], &src_v[1][2], &kernel[6]);
compute_vec<3>(dst_v[1][3], &src_v[1][3], &kernel[6]);
op({{dst_v[0][0], dst_v[0][1]}}, output);
op({{dst_v[0][2], dst_v[0][3]}}, output + 8);
op({{dst_v[1][0], dst_v[1][1]}}, output + OW * 4);
op({{dst_v[1][2], dst_v[1][3]}}, output + OW * 4 + 8);
}
for (; ow < ow_end; ow++) {
size_t iw = ow - PW;
const float* input = src + ih * IW * 4 + iw * 4;
float* output = dst + oh * OW * 4 + ow * 4;
float32x4_t dst_v[2];
load_bias_vec<bias_mode, 1>::impl(&dst_v[0], init,
bias + oh * OW * 4 + ow * 4);
load_bias_vec<bias_mode, 1>::impl(
&dst_v[1], init, bias + (oh + 1) * OW * 4 + ow * 4);
float32x4_t src_v[2][3];
load_vec<3>(src_v[0], input);
compute_vec<3>(dst_v[0], &src_v[0][0], &kernel[0]);
load_vec<3>(src_v[1], input + IW * 4);
compute_vec<3>(dst_v[0], &src_v[1][0], &kernel[3]);
compute_vec<3>(dst_v[1], &src_v[1][0], &kernel[0]);
load_vec<3>(src_v[0], input + 2 * IW * 4);
compute_vec<3>(dst_v[0], &src_v[0][0], &kernel[6]);
compute_vec<3>(dst_v[1], &src_v[0][0], &kernel[3]);
load_vec<3>(src_v[1], input + 3 * IW * 4);
compute_vec<3>(dst_v[1], &src_v[1][0], &kernel[6]);
op(dst_v[0], output);
op(dst_v[1], output + OW * 4);
}
}
for (; oh < oh_end; oh++) {
size_t ih = oh - PH;
size_t ow = ow_start;
for (; ow + 3 < ow_end; ow += 4) {
size_t iw = ow - PW;
const float* input = src + ih * IW * 4 + iw * 4;
float* output = dst + oh * OW * 4 + ow * 4;
float32x4_t dst_v[4];
load_bias_vec<bias_mode, 4>::impl(&dst_v[0], init,
bias + oh * OW * 4 + ow * 4);
float32x4_t src_v[2][6];
load_vec<6>(src_v[0], input);
compute_vec<3>(dst_v[0], &src_v[0][0], &kernel[0]);
compute_vec<3>(dst_v[1], &src_v[0][1], &kernel[0]);
compute_vec<3>(dst_v[2], &src_v[0][2], &kernel[0]);
compute_vec<3>(dst_v[3], &src_v[0][3], &kernel[0]);
load_vec<6>(src_v[1], input + IW * 4);
compute_vec<3>(dst_v[0], &src_v[1][0], &kernel[3]);
compute_vec<3>(dst_v[1], &src_v[1][1], &kernel[3]);
compute_vec<3>(dst_v[2], &src_v[1][2], &kernel[3]);
compute_vec<3>(dst_v[3], &src_v[1][3], &kernel[3]);
load_vec<6>(src_v[0], input + 2 * IW * 4);
compute_vec<3>(dst_v[0], &src_v[0][0], &kernel[6]);
compute_vec<3>(dst_v[1], &src_v[0][1], &kernel[6]);
compute_vec<3>(dst_v[2], &src_v[0][2], &kernel[6]);
compute_vec<3>(dst_v[3], &src_v[0][3], &kernel[6]);
op({{dst_v[0], dst_v[1]}}, output);
op({{dst_v[2], dst_v[3]}}, output + 8);
}
for (; ow < ow_end; ow++) {
size_t iw = ow - PW;
const float* input = src + ih * IW * 4 + iw * 4;
float* output = dst + oh * OW * 4 + ow * 4;
float32x4_t dst_v;
load_bias_vec<bias_mode, 1>::impl(&dst_v, init,
bias + oh * OW * 4 + ow * 4);
float32x4_t src_v[3][3];
load_vec<3>(src_v[0], input);
compute_vec<3>(dst_v, &src_v[0][0], &kernel[0]);
load_vec<3>(src_v[1], input + IW * 4);
compute_vec<3>(dst_v, &src_v[1][0], &kernel[3]);
load_vec<3>(src_v[2], input + 2 * IW * 4);
compute_vec<3>(dst_v, &src_v[2][0], &kernel[6]);
op(dst_v, output);
}
}
}
template <BiasMode bias_mode, typename Op>
void channel_wise_nchw44_float::do_conv_kern_stride1_5x5(
const float* src, const float* filter, const float* bias, float* dst,
const size_t IH, const size_t IW, const size_t OH, const size_t OW,
const size_t PH, const size_t PW) {
Op op;
float32x4_t init;
if (bias_mode == BiasMode::NO_BIAS) {
init = vdupq_n_f32(0.f);
} else if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) {
init = vld1q_f32(bias);
}
size_t oh_start = PH;
size_t ow_start = PW;
size_t oh_end = IH + PH - 4;
size_t ow_end = IW + PW - 4;
if (PH || PW) {
PaddingCompute<bias_mode, Op>::compute(
src, bias, dst, 5, 1, IH, IW, OH, OW, PH, PW,
reinterpret_cast<const float32x4_t*>(filter), init);
}
size_t oh = oh_start;
for (; oh + 1 < oh_end; oh += 2) {
size_t ih = oh - PH;
size_t ow = ow_start;
for (; ow + 1 < ow_end; ow += 2) {
size_t iw = ow - PW;
const float* input = src + ih * IW * 4 + iw * 4;
float* output = dst + oh * OW * 4 + ow * 4;
float32x4_t dst_v[2][2];
load_bias_vec<bias_mode, 2>::impl(dst_v[0], init,
bias + oh * OW * 4 + ow * 4);
load_bias_vec<bias_mode, 2>::impl(
dst_v[1], init, bias + (oh + 1) * OW * 4 + ow * 4);
float32x4_t kernel[2][5];
float32x4_t src_v[2][6];
#define COMPUTE_5X5_4(i, dst, src, kernel0, kernel1) \
load_vec<5>(kernel0, filter + i * 5 * 4); \
load_vec<6>(src, input + i * IW * 4); \
compute_vec<5>(dst[0][0], &src[0], kernel0); \
compute_vec<5>(dst[0][1], &src[1], kernel0); \
compute_vec<5>(dst[1][0], &src[0], kernel1); \
compute_vec<5>(dst[1][1], &src[1], kernel1)
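// Filter rows alternate between two register banks (kernel[0]/kernel[1]) and
// input rows between src_v[0]/src_v[1]; input line i uses filter row i for
// the upper output row and filter row i - 1 for the lower one.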
// line 0
load_vec<5>(kernel[0], filter);
load_vec<6>(src_v[0], input);
compute_vec<5>(dst_v[0][0], &src_v[0][0], kernel[0]);
compute_vec<5>(dst_v[0][1], &src_v[0][1], kernel[0]);
// line 1
COMPUTE_5X5_4(1, dst_v, src_v[1], kernel[1], kernel[0]);
// line 2
COMPUTE_5X5_4(2, dst_v, src_v[0], kernel[0], kernel[1]);
// line 3
COMPUTE_5X5_4(3, dst_v, src_v[1], kernel[1], kernel[0]);
// line 4
COMPUTE_5X5_4(4, dst_v, src_v[0], kernel[0], kernel[1]);
// line 5
load_vec<6>(src_v[1], input + 5 * IW * 4);
compute_vec<5>(dst_v[1][0], &src_v[1][0], kernel[0]);
compute_vec<5>(dst_v[1][1], &src_v[1][1], kernel[0]);
#undef COMPUTE_5X5_4
op({{dst_v[0][0], dst_v[0][1]}}, output);
op({{dst_v[1][0], dst_v[1][1]}}, output + OW * 4);
}
for (; ow < ow_end; ow++) {
size_t iw = ow - PW;
const float* input = src + ih * IW * 4 + iw * 4;
float* output = dst + oh * OW * 4 + ow * 4;
float32x4_t dst_v[2][1];
load_bias_vec<bias_mode, 1>::impl(dst_v[0], init,
bias + oh * OW * 4 + ow * 4);
load_bias_vec<bias_mode, 1>::impl(
dst_v[1], init, bias + (oh + 1) * OW * 4 + ow * 4);
float32x4_t kernel[2][5];
float32x4_t src_v[2][5];
#define COMPUTE_5X5_2(i, dst, src, kernel0, kernel1) \
load_vec<5>(kernel0, filter + i * 5 * 4); \
load_vec<6>(src, input + i * IW * 4); \
compute_vec<5>(dst[0][0], &src[0], kernel0); \
compute_vec<5>(dst[1][0], &src[0], kernel1);
// line 0
load_vec<5>(kernel[0], filter);
load_vec<5>(src_v[0], input);
compute_vec<5>(dst_v[0][0], &src_v[0][0], kernel[0]);
// line 1
COMPUTE_5X5_2(1, dst_v, src_v[1], kernel[1], kernel[0]);
// line 2
COMPUTE_5X5_2(2, dst_v, src_v[0], kernel[0], kernel[1]);
// line 3
COMPUTE_5X5_2(3, dst_v, src_v[1], kernel[1], kernel[0]);
// line 4
COMPUTE_5X5_2(4, dst_v, src_v[0], kernel[0], kernel[1]);
// line 5
load_vec<5>(src_v[1], input + 5 * IW * 4);
compute_vec<5>(dst_v[1][0], &src_v[1][0], kernel[0]);
#undef COMPUTE_5X5_2
op(dst_v[0][0], output);
op(dst_v[1][0], output + OW * 4);
}
}
for (; oh < oh_end; oh++) {
size_t ih = oh - PH;
size_t ow = ow_start;
for (; ow + 1 < ow_end; ow += 2) {
size_t iw = ow - PW;
const float* input = src + ih * IW * 4 + iw * 4;
float* output = dst + oh * OW * 4 + ow * 4;
float32x4_t dst_v[1][2];
load_bias_vec<bias_mode, 2>::impl(dst_v[0], init,
bias + oh * OW * 4 + ow * 4);
float32x4_t kernel[2][5];
float32x4_t src_v[2][6];
#define COMPUTE_5X5_2(i, dst, src, kernel) \
load_vec<5>(kernel, filter + i * 5 * 4); \
load_vec<6>(src, input + i * IW * 4); \
compute_vec<5>(dst[0][0], &src[0], kernel); \
compute_vec<5>(dst[0][1], &src[1], kernel)
// line 0
COMPUTE_5X5_2(0, dst_v, src_v[0], kernel[0]);
// line 1
COMPUTE_5X5_2(1, dst_v, src_v[1], kernel[1]);
// line 2
COMPUTE_5X5_2(2, dst_v, src_v[0], kernel[0]);
// line 3
COMPUTE_5X5_2(3, dst_v, src_v[1], kernel[1]);
// line 4
COMPUTE_5X5_2(4, dst_v, src_v[0], kernel[0]);
#undef COMPUTE_5X5_2
op({{dst_v[0][0], dst_v[0][1]}}, output);
}
for (; ow < ow_end; ow++) {
size_t iw = ow - PW;
const float* input = src + ih * IW * 4 + iw * 4;
float* output = dst + oh * OW * 4 + ow * 4;
float32x4_t dst_v;
load_bias_vec<bias_mode, 1>::impl(&dst_v, init,
bias + oh * OW * 4 + ow * 4);
float32x4_t kernel[2][5];
float32x4_t src_v[2][5];
#define COMPUTE_5X5_1(i, dst, src, kernel) \
load_vec<5>(kernel, filter + i * 5 * 4); \
load_vec<6>(src, input + i * IW * 4); \
compute_vec<5>(dst, &src[0], kernel)
// line 0
COMPUTE_5X5_1(0, dst_v, src_v[0], kernel[0]);
// line 1
COMPUTE_5X5_1(1, dst_v, src_v[1], kernel[1]);
// line 2
COMPUTE_5X5_1(2, dst_v, src_v[0], kernel[0]);
// line 3
COMPUTE_5X5_1(3, dst_v, src_v[1], kernel[1]);
// line 4
COMPUTE_5X5_1(4, dst_v, src_v[0], kernel[0]);
#undef COMPUTE_5X5_1
op(dst_v, output);
}
}
}
template <BiasMode bias_mode, typename Op>
void channel_wise_nchw44_float::do_conv_kern_stride2_2x2(
const float* src, const float* filter, const float* bias, float* dst,
const size_t IH, const size_t IW, const size_t OH, const size_t OW,
const size_t PH, const size_t PW) {
float32x4_t kernel[4];
load_vec<4>(kernel, filter);
Op op;
float32x4_t init;
if (bias_mode == BiasMode::NO_BIAS) {
init = vdupq_n_f32(0.f);
} else if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) {
init = vld1q_f32(bias);
}
size_t oh_start = (PH + 1) / 2;
size_t ow_start = (PW + 1) / 2;
size_t oh_end = (IH + PH) / 2;
size_t ow_end = (IW + PW) / 2;
if (PH || PW) {
PaddingCompute<bias_mode, Op>::compute(src, bias, dst, 2, 2, IH, IW, OH,
OW, PH, PW, kernel, init);
}
#define COMPUTE_2X2(dst, src, kernel) \
compute_vec<2>(dst[0], &src[0], kernel); \
compute_vec<2>(dst[1], &src[2], kernel); \
compute_vec<2>(dst[2], &src[4], kernel); \
compute_vec<2>(dst[3], &src[6], kernel)
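// Stride-2 2x2: four outputs per iteration, consuming input pairs at even
// offsets (src[0], src[2], src[4], src[6]).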
size_t oh = oh_start;
for (; oh < oh_end; oh++) {
size_t ih = oh * 2 - PH;
size_t ow = ow_start;
for (; ow + 3 < ow_end; ow += 4) {
size_t iw = ow * 2 - PW;
const float* input = src + ih * IW * 4 + iw * 4;
float* output = dst + oh * OW * 4 + ow * 4;
float32x4_t dst_v[4];
load_bias_vec<bias_mode, 4>::impl(&dst_v[0], init,
bias + oh * OW * 4 + ow * 4);
float32x4_t src_v[2][8];
load_vec<8>(src_v[0], input);
COMPUTE_2X2(dst_v, src_v[0], &kernel[0]);
load_vec<8>(src_v[1], input + IW * 4);
COMPUTE_2X2(dst_v, src_v[1], &kernel[2]);
#undef COMPUTE_2X2
op({{dst_v[0], dst_v[1]}}, output);
op({{dst_v[2], dst_v[3]}}, output + 8);
}
for (; ow < ow_end; ow++) {
size_t iw = ow * 2 - PW;
const float* input = src + ih * IW * 4 + iw * 4;
float* output = dst + oh * OW * 4 + ow * 4;
float32x4_t dst_v;
load_bias_vec<bias_mode, 1>::impl(&dst_v, init,
bias + oh * OW * 4 + ow * 4);
float32x4_t src_v[2][2];
load_vec<2>(src_v[0], input);
compute_vec<2>(dst_v, &src_v[0][0], &kernel[0]);
load_vec<2>(src_v[1], input + IW * 4);
compute_vec<2>(dst_v, &src_v[1][0], &kernel[2]);
op(dst_v, output);
}
}
#undef COMPUTE_2X2
}
template <BiasMode bias_mode, typename Op>
void channel_wise_nchw44_float::do_conv_kern_stride2_3x3(
const float* src, const float* filter, const float* bias, float* dst,
const size_t IH, const size_t IW, const size_t OH, const size_t OW,
const size_t PH, const size_t PW) {
float32x4_t kernel[9];
load_vec<9>(kernel, filter);
Op op;
float32x4_t init;
if (bias_mode == BiasMode::NO_BIAS) {
init = vdupq_n_f32(0.f);
} else if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) {
init = vld1q_f32(bias);
}
size_t oh_start = (PH + 1) / 2;
size_t ow_start = (PW + 1) / 2;
size_t oh_end = (IH + PH - 3) / 2 + 1;
size_t ow_end = (IW + PW - 3) / 2 + 1;
if (PH == 1 && PW == 1) {
PaddingComputeK3P1<bias_mode, Op>::compute(src, bias, dst, 2, IH, IW,
OH, OW, kernel, init);
} else if (PH || PW) {
PaddingCompute<bias_mode, Op>::compute(src, bias, dst, 3, 2, IH, IW, OH,
OW, PH, PW, kernel, init);
}
size_t oh = oh_start;
for (; oh + 1 < oh_end; oh += 2) {
size_t ih = oh * 2 - PH;
size_t ow = ow_start;
for (; ow + 1 < ow_end; ow += 2) {
size_t iw = ow * 2 - PW;
const float* input = src + ih * IW * 4 + iw * 4;
float* output = dst + oh * OW * 4 + ow * 4;
float32x4_t dst_v[2][2];
load_bias_vec<bias_mode, 2>::impl(dst_v[0], init,
bias + oh * OW * 4 + ow * 4);
load_bias_vec<bias_mode, 2>::impl(
dst_v[1], init, bias + (oh + 1) * OW * 4 + ow * 4);
float32x4_t src_v[2][5];
load_vec<5>(src_v[0], input);
compute_vec<3>(dst_v[0][0], &src_v[0][0], &kernel[0]);
compute_vec<3>(dst_v[0][1], &src_v[0][2], &kernel[0]);
load_vec<5>(src_v[1], input + IW * 4);
compute_vec<3>(dst_v[0][0], &src_v[1][0], &kernel[3]);
compute_vec<3>(dst_v[0][1], &src_v[1][2], &kernel[3]);
load_vec<5>(src_v[0], input + 2 * IW * 4);
compute_vec<3>(dst_v[0][0], &src_v[0][0], &kernel[6]);
compute_vec<3>(dst_v[0][1], &src_v[0][2], &kernel[6]);
compute_vec<3>(dst_v[1][0], &src_v[0][0], &kernel[0]);
compute_vec<3>(dst_v[1][1], &src_v[0][2], &kernel[0]);
load_vec<5>(src_v[1], input + 3 * IW * 4);
compute_vec<3>(dst_v[1][0], &src_v[1][0], &kernel[3]);
compute_vec<3>(dst_v[1][1], &src_v[1][2], &kernel[3]);
load_vec<5>(src_v[0], input + 4 * IW * 4);
compute_vec<3>(dst_v[1][0], &src_v[0][0], &kernel[6]);
compute_vec<3>(dst_v[1][1], &src_v[0][2], &kernel[6]);
op({{dst_v[0][0], dst_v[0][1]}}, output);
op({{dst_v[1][0], dst_v[1][1]}}, output + OW * 4);
}
for (; ow < ow_end; ow++) {
size_t iw = ow * 2 - PW;
const float* input = src + ih * IW * 4 + iw * 4;
float* output = dst + oh * OW * 4 + ow * 4;
float32x4_t dst_v[2];
load_bias_vec<bias_mode, 1>::impl(&dst_v[0], init,
bias + oh * OW * 4 + ow * 4);
load_bias_vec<bias_mode, 1>::impl(
&dst_v[1], init, bias + (oh + 1) * OW * 4 + ow * 4);
float32x4_t src_v[2][3];
load_vec<3>(src_v[0], input);
compute_vec<3>(dst_v[0], &src_v[0][0], &kernel[0]);
load_vec<3>(src_v[1], input + IW * 4);
compute_vec<3>(dst_v[0], &src_v[1][0], &kernel[3]);
load_vec<3>(src_v[0], input + 2 * IW * 4);
compute_vec<3>(dst_v[0], &src_v[0][0], &kernel[6]);
compute_vec<3>(dst_v[1], &src_v[0][0], &kernel[0]);
load_vec<3>(src_v[1], input + 3 * IW * 4);
compute_vec<3>(dst_v[1], &src_v[1][0], &kernel[3]);
load_vec<3>(src_v[0], input + 4 * IW * 4);
compute_vec<3>(dst_v[1], &src_v[0][0], &kernel[6]);
op(dst_v[0], output);
op(dst_v[1], output + OW * 4);
}
}
for (; oh < oh_end; oh++) {
size_t ih = oh * 2 - PH;
size_t ow = ow_start;
for (; ow + 1 < ow_end; ow += 2) {
size_t iw = ow * 2 - PW;
const float* input = src + ih * IW * 4 + iw * 4;
float* output = dst + oh * OW * 4 + ow * 4;
float32x4_t dst_v[2];
load_bias_vec<bias_mode, 2>::impl(&dst_v[0], init,
bias + oh * OW * 4 + ow * 4);
float32x4_t src_v[3][5];
load_vec<5>(src_v[0], input);
compute_vec<3>(dst_v[0], &src_v[0][0], &kernel[0]);
compute_vec<3>(dst_v[1], &src_v[0][2], &kernel[0]);
load_vec<5>(src_v[1], input + IW * 4);
compute_vec<3>(dst_v[0], &src_v[1][0], &kernel[3]);
compute_vec<3>(dst_v[1], &src_v[1][2], &kernel[3]);
load_vec<5>(src_v[2], input + 2 * IW * 4);
compute_vec<3>(dst_v[0], &src_v[2][0], &kernel[6]);
compute_vec<3>(dst_v[1], &src_v[2][2], &kernel[6]);
op({{dst_v[0], dst_v[1]}}, output);
}
for (; ow < ow_end; ow++) {
size_t iw = ow * 2 - PW;
const float* input = src + ih * IW * 4 + iw * 4;
float* output = dst + oh * OW * 4 + ow * 4;
float32x4_t dst_v;
load_bias_vec<bias_mode, 1>::impl(&dst_v, init,
bias + oh * OW * 4 + ow * 4);
float32x4_t src_v[3][3];
load_vec<3>(src_v[0], input);
compute_vec<3>(dst_v, &src_v[0][0], &kernel[0]);
load_vec<3>(src_v[1], input + IW * 4);
compute_vec<3>(dst_v, &src_v[1][0], &kernel[3]);
load_vec<3>(src_v[2], input + 2 * IW * 4);
compute_vec<3>(dst_v, &src_v[2][0], &kernel[6]);
op(dst_v, output);
}
}
}
template <BiasMode bias_mode, typename Op>
void channel_wise_nchw44_float::do_conv_kern_stride2_5x5(
const float* src, const float* filter, const float* bias, float* dst,
const size_t IH, const size_t IW, const size_t OH, const size_t OW,
const size_t PH, const size_t PW) {
Op op;
float32x4_t init;
if (bias_mode == BiasMode::NO_BIAS) {
init = vdupq_n_f32(0.f);
} else if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) {
init = vld1q_f32(bias);
}
constexpr size_t stride = 2;
size_t oh_start = (PH + stride - 1) / stride;
size_t ow_start = (PW + stride - 1) / stride;
size_t oh_end = (IH + PH - 5) / stride + 1;
size_t ow_end = (IW + PW - 5) / stride + 1;
if (PH || PW) {
PaddingCompute<bias_mode, Op>::compute(
src, bias, dst, 5, stride, IH, IW, OH, OW, PH, PW,
reinterpret_cast<const float32x4_t*>(filter), init);
}
size_t oh = oh_start;
for (; oh + 1 < oh_end; oh += 2) {
size_t ih = oh * stride - PH;
size_t ow = ow_start;
for (; ow + 1 < ow_end; ow += 2) {
size_t iw = ow * stride - PW;
const float* input = src + ih * IW * 4 + iw * 4;
float* output = dst + oh * OW * 4 + ow * 4;
float32x4_t dst_v[2][2];
load_bias_vec<bias_mode, 2>::impl(dst_v[0], init,
bias + oh * OW * 4 + ow * 4);
load_bias_vec<bias_mode, 2>::impl(
dst_v[1], init, bias + (oh + 1) * OW * 4 + ow * 4);
float32x4_t kernel[3][5];
float32x4_t src_v[2][7];
#define COMPUTE_5X5_4(i, dst, src, kernel0, kernel1) \
load_vec<5>(kernel0, filter + i * 5 * 4); \
load_vec<7>(src, input + i * IW * 4); \
compute_vec<5>(dst[0][0], &src[0], kernel0); \
compute_vec<5>(dst[0][1], &src[2], kernel0); \
compute_vec<5>(dst[1][0], &src[0], kernel1); \
compute_vec<5>(dst[1][1], &src[2], kernel1)
#define COMPUTE_5X5_2(i, dst, src, kernel) \
load_vec<7>(src, input + i * IW * 4); \
compute_vec<5>(dst[0], &src[0], kernel); \
compute_vec<5>(dst[1], &src[2], kernel)
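// Filter row i lives in register bank kernel[i % 3]; input line i feeds the
// upper output row with filter row i and the lower one with filter row i - 2
// (stride 2).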
// line 0
load_vec<5>(kernel[0], filter);
COMPUTE_5X5_2(0, dst_v[0], src_v[0], kernel[0]);
// line 1
load_vec<5>(kernel[1], filter + 5 * 4);
COMPUTE_5X5_2(1, dst_v[0], src_v[1], kernel[1]);
// line 2
COMPUTE_5X5_4(2, dst_v, src_v[0], kernel[2], kernel[0]);
// line 3
COMPUTE_5X5_4(3, dst_v, src_v[1], kernel[0], kernel[1]);
// line 4
COMPUTE_5X5_4(4, dst_v, src_v[0], kernel[1], kernel[2]);
// line 5
COMPUTE_5X5_2(5, dst_v[1], src_v[1], kernel[0]);
// line 6
COMPUTE_5X5_2(6, dst_v[1], src_v[0], kernel[1]);
#undef COMPUTE_5X5_4
#undef COMPUTE_5X5_2
op({{dst_v[0][0], dst_v[0][1]}}, output);
op({{dst_v[1][0], dst_v[1][1]}}, output + OW * 4);
}
for (; ow < ow_end; ow++) {
size_t iw = ow * stride - PW;
const float* input = src + ih * IW * 4 + iw * 4;
float* output = dst + oh * OW * 4 + ow * 4;
float32x4_t dst_v[2];
load_bias_vec<bias_mode, 1>::impl(&dst_v[0], init,
bias + oh * OW * 4 + ow * 4);
load_bias_vec<bias_mode, 1>::impl(
&dst_v[1], init, bias + (oh + 1) * OW * 4 + ow * 4);
float32x4_t kernel[3][5];
float32x4_t src_v[2][5];
#define COMPUTE_5X5_2(i, dst, src, kernel0, kernel1) \
load_vec<5>(kernel0, filter + i * 5 * 4); \
load_vec<5>(src, input + i * IW * 4); \
compute_vec<5>(dst[0], &src[0], kernel0); \
compute_vec<5>(dst[1], &src[0], kernel1);
#define COMPUTE_5X5_1(i, dst, src, kernel) \
load_vec<5>(src, input + i * IW * 4); \
compute_vec<5>(dst, &src[0], kernel); \
// line 0
load_vec<5>(kernel[0], filter);
COMPUTE_5X5_1(0, dst_v[0], src_v[0], kernel[0]);
// line 1
load_vec<5>(kernel[1], filter + 5 * 4);
COMPUTE_5X5_1(1, dst_v[0], src_v[1], kernel[1]);
// line 2
COMPUTE_5X5_2(2, dst_v, src_v[0], kernel[2], kernel[0]);
// line 3
COMPUTE_5X5_2(3, dst_v, src_v[1], kernel[0], kernel[1]);
// line 4
COMPUTE_5X5_2(4, dst_v, src_v[0], kernel[1], kernel[2]);
// line 5
COMPUTE_5X5_1(5, dst_v[1], src_v[1], kernel[0]);
// line 6
COMPUTE_5X5_1(6, dst_v[1], src_v[0], kernel[1]);
#undef COMPUTE_5X5_2
#undef COMPUTE_5X5_1
op(dst_v[0], output);
op(dst_v[1], output + OW * 4);
}
}
for (; oh < oh_end; oh++) {
size_t ih = oh * stride - PH;
size_t ow = ow_start;
for (; ow < ow_end; ow++) {
size_t iw = ow * stride - PW;
const float* input = src + ih * IW * 4 + iw * 4;
float* output = dst + oh * OW * 4 + ow * 4;
float32x4_t dst_v;
load_bias_vec<bias_mode, 1>::impl(&dst_v, init,
bias + oh * OW * 4 + ow * 4);
float32x4_t kernel[2][5];
float32x4_t src_v[2][5];
#define COMPUTE_5X5_1(i, dst, src, kernel) \
load_vec<5>(kernel, filter + i * 5 * 4); \
load_vec<6>(src, input + i * IW * 4); \
compute_vec<5>(dst, &src[0], kernel)
// line 0
COMPUTE_5X5_1(0, dst_v, src_v[0], kernel[0]);
// line 1
COMPUTE_5X5_1(1, dst_v, src_v[1], kernel[1]);
// line 2
COMPUTE_5X5_1(2, dst_v, src_v[0], kernel[0]);
// line 3
COMPUTE_5X5_1(3, dst_v, src_v[1], kernel[1]);
// line 4
COMPUTE_5X5_1(4, dst_v, src_v[0], kernel[0]);
#undef COMPUTE_5X5_1
op(dst_v, output);
}
}
}
#define INSTANTIATION(stride, i, bias, Op) \
template void \
channel_wise_nchw44_float::do_conv_kern_##stride##_##i##x##i<bias, \
Op>( \
const float*, const float*, const float*, float*, \
const size_t, const size_t, const size_t, const size_t, \
const size_t, const size_t);
#define FOR_OP(stride, i, bias) \
INSTANTIATION(stride, i, bias, SigmoidOp<dt_float32>) \
INSTANTIATION(stride, i, bias, ReluOp<dt_float32>) \
INSTANTIATION(stride, i, bias, HSwishOp<dt_float32>) \
INSTANTIATION(stride, i, bias, NoneOp<dt_float32>)
#define FOR_BIAS(stride, i) \
FOR_OP(stride, i, BiasMode::NO_BIAS) \
FOR_OP(stride, i, BiasMode::BROADCAST_CHANNEL_BIAS) \
FOR_OP(stride, i, BiasMode::BIAS)
#define FOR_FILTER(stride) \
FOR_BIAS(stride, 2) \
FOR_BIAS(stride, 3) \
FOR_BIAS(stride, 5)
#define FOR_STRIDE \
FOR_FILTER(stride1) \
FOR_FILTER(stride2)
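// Explicitly instantiate every kernel variant:
// 2 strides x 3 filter sizes x 3 bias modes x 4 activation ops = 72 kernels.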
FOR_STRIDE
#undef FOR_STRIDE
#undef FOR_FILTER
#undef FOR_BIAS
#undef FOR_OP
#undef INSTANTIATION
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/arm_common/conv_bias/fp32/channel_wise_nchw44_kern.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "src/arm_common/conv_bias/opr_impl.h"
#include "src/fallback/conv_bias/common.h"
namespace megdnn {
namespace arm_common {
namespace channel_wise_nchw44_float {
#define KERN(stride, i) \
template <BiasMode bias_mode, typename Op> \
void do_conv_kern_##stride##_##i##x##i( \
const float* src, const float* filter, const float* bias, \
float* dst, const size_t IH, const size_t IW, const size_t OH, \
const size_t OW, const size_t PH, const size_t PW);
KERN(stride1, 2)
KERN(stride1, 3)
KERN(stride1, 5)
KERN(stride2, 2)
KERN(stride2, 3)
KERN(stride2, 5)
#undef KERN
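// Usage sketch (shapes are illustrative) for one packed NCHW44 group:
// do_conv_kern_stride1_3x3<BiasMode::BROADCAST_CHANNEL_BIAS,
//                          ReluOp<dt_float32>>(
//         src, filter, bias, dst, IH, IW, OH, OW, PH, PW);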
} // namespace channel_wise_nchw44_float
} // namespace arm_common
} // namespace megdnn
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/arm_common/conv_bias/int8/direct.cpp
* \file dnn/src/arm_common/conv_bias/int8/channel_wise_kernel.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
......@@ -1640,4 +1640,5 @@ FOR_STRIDE
#undef FOR_BIAS
#undef FOR_OP
#undef INSTANTIATION
// vim: syntax=cpp.doxygen
......@@ -65,14 +65,17 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj {
AlgoDotU8DirectStride2 du8_direct_stride2_small_group{false};
#endif
AlgoF32DirectStride2NCHWNCHW44 f32_direct_stride2_nchw_nchw44;
AlgoF32ChannelWiseNCHW44 f32_channel_wise_nchw44;
AlgoF32DirectNCHW44 f32_direct_nchw44;
AlgoF32Direct f32_direct_large_group{true};
AlgoF32Direct f32_direct_small_group{false};
AlgoF32DirectNCHW44 f32_direct_nchw44;
AlgoF32DirectStride2 f32_direct_stride2_large_group{true};
AlgoF32DirectStride2 f32_direct_stride2_small_group{false};
AlgoF32DirectStride1 f32_direct_stride1_large_group{true};
AlgoF32DirectStride1 f32_direct_stride1_small_group{false};
AlgoF32DirectStride2NCHWNCHW44 f32_direct_stride2_nchw_nchw44;
AlgoI8x8x16Direct i8x8x16_direct_large_group{true};
AlgoI8x8x16Direct i8x8x16_direct_small_group{false};
AlgoI8x8x16Stride2 i8x8x16_stride2_large_group{true};
......@@ -125,8 +128,11 @@ public:
direct_algos.emplace_back(&i8x8x16_stride2_filter2);
direct_algos.emplace_back(&i8x8x16_stride2_large_group);
direct_algos.emplace_back(&i8x8x16_stride2_small_group);
direct_algos.emplace_back(&f32_direct_stride2_nchw_nchw44);
direct_algos.emplace_back(&f32_channel_wise_nchw44);
direct_algos.emplace_back(&f32_direct_nchw44);
direct_algos.emplace_back(&f32_direct_stride1_large_group);
direct_algos.emplace_back(&f32_direct_stride1_small_group);
direct_algos.emplace_back(&f32_direct_stride2_large_group);
......
......@@ -66,10 +66,10 @@ private:
#endif
class AlgoF32Direct;
class AlgoF32DirectStride1;
class AlgoF32DirectNCHW44;
class AlgoF32DirectStride2;
class AlgoF32DirectStride2NCHWNCHW44;
class AlgoF32DirectStride2NCHW44;
class AlgoF32ChannelWiseNCHW44;
class AlgoF32DirectNCHW44;
class AlgoI8x8x16Direct;
class AlgoI8x8x16Stride2;
......
......@@ -1086,6 +1086,155 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_QUINT8_STRIDE2) {
used1 / used0);
}
}
TEST_F(ARM_COMMON, BENCHMARK_CHANNEL_WISE_F32_STRIDE1_NCHW44) {
// have to remove the preferred restriction in the usable function before running the benchmark
using namespace conv_bias;
param::ConvBias param;
param.stride_h = 1;
param.stride_w = 1;
param.pad_h = 1;
param.pad_w = 1;
param.nonlineMode = NonlineMode::RELU;
param.sparse = param::ConvBias::Sparse::GROUP;
constexpr size_t RUN = 50;
Benchmarker<ConvBias> benchmark0(handle());
benchmark0.set_display(false);
benchmark0.set_param(param);
benchmark0.set_times(RUN);
benchmark0.set_before_exec_callback(
conv_bias::ConvBiasAlgoChecker<ConvBiasForward>(
"F32STRD1_LARGE_GROUP"));
auto opr = handle()->create_operator<ConvBias>();
opr->param() = param;
param.format = param::ConvBias::Format::NCHW44;
Benchmarker<ConvBias> benchmark1(handle());
benchmark1.set_display(false);
benchmark1.set_param(param);
benchmark1.set_times(RUN);
benchmark1.set_before_exec_callback(
conv_bias::ConvBiasAlgoChecker<ConvBiasForward>(
"F32_CHANNEL_WISE_NCHW44"));
auto run = [&](size_t group, size_t w, size_t h, size_t kernel) {
TensorLayout dst_layout;
opr->deduce_layout({{1, group * 4, h, w}, dtype::Int8()},
{{group * 4, 1, 1, kernel, kernel}, dtype::Int8()},
{{1, group * 4, 1, 1}, dtype::Int32()}, {},
dst_layout);
//! dst.nr_elems * IC * FH * FW * 2
float computations = dst_layout.total_nr_elems() * kernel * kernel *
2.0 / (1024 * 1024 * 1024) * 1e3;
auto used0 = benchmark0.exec({{1, group * 4, h, w},
{group * 4, 1, 1, kernel, kernel},
{1, group * 4, 1, 1},
{},
{}}) /
RUN;
auto used1 = benchmark1.exec({{1, group, h, w, 4},
{group, 1, 1, kernel, kernel, 4},
{1, group, 1, 1, 4},
{},
{}}) /
RUN;
printf("group/h/w/kernel:%zu,%zu,%zu,%zu: nchw: %f ms %f Gflops "
"nchw44: "
"%f ms %f GFlops "
"speedup: %f\n",
group, h, w, kernel, used0, computations / used0, used1,
computations / used1, used0 / used1);
};
for (size_t group : {8, 16, 32, 64}) {
for (size_t kernel : {2, 3, 5}) {
run(group, 112, 112, kernel);
run(group, 56, 56, kernel);
run(group, 48, 48, kernel);
run(group, 28, 28, kernel);
run(group, 14, 14, kernel);
}
}
run(8, 112, 112, 3);
run(32, 56, 56, 3);
run(64, 28, 28, 3);
run(128, 14, 14, 3);
}
TEST_F(ARM_COMMON, BENCHMARK_CHANNEL_WISE_F32_STRIDE2_NCHW44) {
// have to remove the preferred restriction in the usable function before running the benchmark
using namespace conv_bias;
param::ConvBias param;
param.stride_h = 2;
param.stride_w = 2;
param.pad_h = 1;
param.pad_w = 1;
param.nonlineMode = NonlineMode::RELU;
param.sparse = param::ConvBias::Sparse::GROUP;
constexpr size_t RUN = 50;
Benchmarker<ConvBias> benchmark0(handle());
benchmark0.set_display(false);
benchmark0.set_param(param);
benchmark0.set_times(RUN);
benchmark0.set_before_exec_callback(
conv_bias::ConvBiasAlgoChecker<ConvBiasForward>(
"F32STRD2_LARGE_GROUP"));
auto opr = handle()->create_operator<ConvBias>();
opr->param() = param;
param.format = param::ConvBias::Format::NCHW44;
Benchmarker<ConvBias> benchmark1(handle());
benchmark1.set_display(false);
benchmark1.set_param(param);
benchmark1.set_times(RUN);
benchmark1.set_before_exec_callback(
conv_bias::ConvBiasAlgoChecker<ConvBiasForward>(
"F32_CHANNEL_WISE_NCHW44"));
auto run = [&](size_t group, size_t w, size_t h, size_t kernel) {
TensorLayout dst_layout;
opr->deduce_layout({{1, group * 4, h, w}, dtype::Int8()},
{{group * 4, 1, 1, kernel, kernel}, dtype::Int8()},
{{1, group * 4, 1, 1}, dtype::Int32()}, {},
dst_layout);
//! dst.nr_elems * IC * FH * FW * 2
float computations = dst_layout.total_nr_elems() * kernel * kernel *
2.0 / (1024 * 1024 * 1024) * 1e3;
auto used0 = benchmark0.exec({{1, group * 4, h, w},
{group * 4, 1, 1, kernel, kernel},
{1, group * 4, 1, 1},
{},
{}}) /
RUN;
auto used1 = benchmark1.exec({{1, group, h, w, 4},
{group, 1, 1, kernel, kernel, 4},
{1, group, 1, 1, 4},
{},
{}}) /
RUN;
printf("group/h/w/kernel:%zu,%zu,%zu,%zu: nchw: %f ms %f Gflops "
"nchw44: "
"%f ms %f GFlops "
"speedup: %f\n",
group, h, w, kernel, used0, computations / used0, used1,
computations / used1, used0 / used1);
};
for (size_t group : {8, 16, 32, 64}) {
for (size_t kernel : {2, 3, 5}) {
run(group, 112, 112, kernel);
run(group, 56, 56, kernel);
run(group, 48, 48, kernel);
run(group, 28, 28, kernel);
run(group, 14, 14, kernel);
}
}
run(8, 112, 112, 3);
run(32, 56, 56, 3);
run(64, 28, 28, 3);
run(128, 14, 14, 3);
}
TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_QINT8_STRIDE1_NCHW44) {
// have to remove the preferred restriction in the usable function before running the benchmark
......
......@@ -181,9 +181,9 @@ std::vector<conv_bias::TestArg> get_nchw44_conv_bias_args(
return args;
}
std::vector<conv_bias::TestArg> get_int8_quint8_nchw44_channel_wise_args(
std::vector<conv_bias::TestArg> get_nchw44_channel_wise_args(
std::vector<size_t> kernel, size_t stride, bool no_bias,
bool no_nonlinemode) {
bool no_nonlinemode, bool no_full_bias) {
using namespace conv_bias;
using Param = param::ConvBias;
using NLMode = param::ConvBias::NonlineMode;
......@@ -213,6 +213,15 @@ std::vector<conv_bias::TestArg> get_int8_quint8_nchw44_channel_wise_args(
TensorShape{group, 1, 1, kernel, kernel, 4},
TensorShape{1, group, 1, 1, 4});
}
if (!no_full_bias) {
args.emplace_back(
param, TensorShape{n, group, h, w, 4},
TensorShape{group, 1, 1, kernel, kernel, 4},
TensorShape{n, group,
(h + 2 * param.pad_h - kernel) / stride + 1,
(w + 2 * param.pad_w - kernel) / stride + 1,
4});
}
};
std::vector<NLMode> nonlinemode = {NLMode::IDENTITY};
......@@ -224,7 +233,7 @@ std::vector<conv_bias::TestArg> get_int8_quint8_nchw44_channel_wise_args(
for (auto nlmode : nonlinemode) {
for (bool pad : {true}) {
for (size_t group : {1, 2, 4, 7, 128}) {
for (size_t size : {4, 5, 6, 7, 8, 9, 10, 15, 40}) {
for (size_t size : {4, 6, 7, 9, 15, 40}) {
for (size_t kern : kernel) {
pack(n, group, size, size, kern, stride, nlmode,
pad);
......@@ -234,7 +243,7 @@ std::vector<conv_bias::TestArg> get_int8_quint8_nchw44_channel_wise_args(
}
for (bool pad : {false}) {
for (size_t group : {1, 2, 7, 128}) {
for (size_t size : {7, 8, 9, 10, 15, 40}) {
for (size_t size : {7, 9, 15, 40}) {
for (size_t kern : kernel) {
pack(n, group, size, size, kern, stride, nlmode,
pad);
......@@ -374,6 +383,18 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_NCHW_NCHW44_F32) {
false, true),
handle(), "F32_CONV_NCHW_NCHW44");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_CHANNEL_WISE_STRIDE1_FP32_NCHW44) {
check_conv_bias(
get_nchw44_channel_wise_args({2, 3, 5}, 1, false, false, false),
handle(), "F32_CHANNEL_WISE_NCHW44");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_CHANNEL_WISE_STRIDE2_FP32_NCHW44) {
check_conv_bias(
get_nchw44_channel_wise_args({2, 3, 5}, 2, false, false, false),
handle(), "F32_CHANNEL_WISE_NCHW44");
}
/**********************************F16 direct************************/
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16_LARGE_GROUP) {
......@@ -447,14 +468,14 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT32_STRIDE2_SMALL_GROUP) {
TEST_F(ARM_COMMON_MULTI_THREADS,
CONV_BIAS_INT8_INT8_INT32_CHANNEL_WISE_DIRECT1_NCHW44) {
checker_conv_bias_int8x8x32_multi(
get_int8_quint8_nchw44_channel_wise_args({2, 3, 5}, 1, false, true),
get_nchw44_channel_wise_args({2, 3, 5}, 1, false, true, true),
handle(), "S8_CHAN_WISE_STRD1_NCHW44");
}
TEST_F(ARM_COMMON_MULTI_THREADS,
CONV_BIAS_INT8_INT8_INT32_CHANNEL_WISE_DIRECT2_NCHW44) {
checker_conv_bias_int8x8x32_multi(
get_int8_quint8_nchw44_channel_wise_args({2, 3, 5}, 2, false, true),
get_nchw44_channel_wise_args({2, 3, 5}, 2, false, true, true),
handle(), "S8_CHAN_WISE_STRD2_NCHW44");
}
......@@ -490,15 +511,15 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2_NCHW44) {
handle(), "S8_NCHW44_DIRECT_STRD2");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QS8_CHANNEL_WISE_DIRECT1_NCHW44) {
checker_conv_bias_qint8x8x8(get_int8_quint8_nchw44_channel_wise_args(
{2, 3, 5}, 1, false, false),
handle(), "S8_CHAN_WISE_STRD1_NCHW44");
checker_conv_bias_qint8x8x8(
get_nchw44_channel_wise_args({2, 3, 5}, 1, false, false, true),
handle(), "S8_CHAN_WISE_STRD1_NCHW44");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QS8_CHANNEL_WISE_DIRECT2_NCHW44) {
checker_conv_bias_qint8x8x8(get_int8_quint8_nchw44_channel_wise_args(
{2, 3, 5}, 2, false, false),
handle(), "S8_CHAN_WISE_STRD2_NCHW44");
checker_conv_bias_qint8x8x8(
get_nchw44_channel_wise_args({2, 3, 5}, 2, false, false, true),
handle(), "S8_CHAN_WISE_STRD2_NCHW44");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_NCHW_NCHW44) {
......