提交 c9986df5 编写于 作者: M Megvii Engine Team

feat(dnn/arm): add fp32 nchw_nchw44 conv

GitOrigin-RevId: f19fe892d9f3e4c166d4835804bf5fc0ad31ccbc
上级 ca855d8d
......@@ -6,7 +6,8 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
......@@ -156,7 +157,6 @@ private:
uint32_t m_tile_size;
};
class ConvBiasImpl::AlgoF32Direct final : public AlgoBase {
SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const;
bool m_large_group;
......@@ -217,6 +217,24 @@ public:
fallback::ConvBiasImpl* opr,
const NCBKernSizeParam& param) const override;
};
class ConvBiasImpl::AlgoF32DirectStride2NCHWNCHW44 final : public AlgoBase {
SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const;
public:
AlgoF32DirectStride2NCHWNCHW44() {}
bool is_reproducible() const override { return true; }
const char* name() const override { return "F32_CONV_NCHW_NCHW44"; }
bool usable(fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;
size_t get_workspace(fallback::ConvBiasImpl* opr,
const NCBKernSizeParam& param) const override;
virtual SmallVector<NCBKern> dispatch_kerns(
fallback::ConvBiasImpl* opr,
const NCBKernSizeParam& param) const override;
};
} // namespace arm_common
} // namespace megdnn
......
/**
* \file
dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw_nchw44_algo.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express
* or implied.
*/
#include "megdnn/oprs.h"
#include "src/arm_common/conv_bias/fp32/algos.h"
#include "src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw_nchw44_kern.h"
#include "src/arm_common/conv_bias/fp32/strategy.h"
#include "src/arm_common/elemwise_op.h"
#include "src/common/opr_delegate.h"
#include "midout.h"
using namespace megdnn;
using namespace arm_common;
using conv_fun = std::function<void(
WorkspaceBundle bundle, const ConvBiasImpl::NCBKernParam& kern_param,
const ConvBiasImpl::NCBKernIndex& ncb_index,
const CpuNDRange& workspace_ids, const CpuNDRange& ncb_range)>;
MIDOUT_DECL(megdnn_arm_common_conv_bias_fp32_nchw_nchw44_stride2)
namespace {
static inline int block_helper(const int nthread, const int amount,
const int per_unit_bytes) {
MEGDNN_MARK_USED_VAR(per_unit_bytes);
const int block_per_thread = div_ceil(amount, nthread);
const int best_block = 16;
const int max_block_num = div_ceil(block_per_thread, best_block);
const int min_block_num = std::max(max_block_num - 1, 1);
const int max_block = div_ceil(block_per_thread, max_block_num);
const int min_block = div_ceil(block_per_thread, min_block_num);
const int max_loss = std::abs(max_block_num * max_block - block_per_thread);
const int min_loss = std::abs(min_block_num * min_block - block_per_thread);
int block = max_loss > min_loss ? min_block : max_block;
return block;
}
static inline size_t get_perthread_cache_bytes(const int ic, const int ih2,
const int iw2) {
// border_size is used to avoid read illegal memory
int border_size = 64 * 2;
return ic * ih2 * iw2 * sizeof(float) + border_size;
}
static void get_rectified_size(
const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param, int& ih2,
int& iw2, int& oh2, int& ow2) {
int iw = param.isz[1];
int oh = param.osz[0];
int ow = param.osz[1];
oh2 = oh;
ow2 = ow;
constexpr int cacheline = 64 / sizeof(float);
int block_oh = block_helper(param.nr_threads, oh, 0);
auto&& fm = param.filter_meta;
const int stride_h = static_cast<int>(fm.stride[0]);
const int filter_h = static_cast<int>(fm.spatial[0]);
ih2 = block_oh * stride_h + filter_h - stride_h;
iw2 = round_up(iw + 2 * static_cast<int>(fm.padding[1]), cacheline);
}
static WorkspaceBundle get_bundle(const ConvBiasImpl::NCBKernSizeParam& param) {
auto&& fm = param.filter_meta;
int group = fm.group;
int ic = fm.icpg;
int oc = fm.ocpg;
int fh = fm.spatial[0];
int fw = fm.spatial[1];
int ih2, iw2, oh2, ow2;
get_rectified_size(param, ih2, iw2, oh2, ow2);
int oh_block = block_helper(param.nr_threads, oh2, 0);
megdnn_assert(oh_block != 0, "oh_block!=0");
size_t src_size = get_perthread_cache_bytes(ic, ih2, iw2);
size_t weight_size = group * oc * ic * fh * fw * sizeof(float);
return {nullptr, {src_size * param.nr_threads, weight_size}};
};
static inline void copy_pad_src(float* sptr_base, const float* sptr_origin,
int ph, int pw, int pad_right, int ih, int iw,
int iw2, int pad_top, int pad_bottom, int ic,
int ic_stride) {
MEGDNN_MARK_USED_VAR(ph);
rep(ic_idx, ic) {
const float* sptr = sptr_origin + ic_idx * ic_stride;
memset(sptr_base, 0, sizeof(float) * iw2 * pad_top);
sptr_base += iw2 * pad_top;
rep(ih_idx, ih) {
memset(sptr_base, 0, sizeof(float) * pw);
sptr_base += pw;
memcpy(sptr_base, sptr, sizeof(float) * iw);
sptr_base += iw;
sptr += iw;
memset(sptr_base, 0, sizeof(float) * pad_right);
sptr_base += pad_right;
}
memset(sptr_base, 0, sizeof(float) * iw2 * pad_bottom);
sptr_base += iw2 * pad_bottom;
}
}
static void pack_weight(WorkspaceBundle bundle,
const ConvBiasImpl::NCBKernParam& kern_param,
const ConvBiasImpl::NCBKernIndex& ncb_index) {
bundle.set(kern_param.workspace_ptr);
const int group_id = ncb_index.ndrange_id[0];
int fh = kern_param.filter_meta.spatial[0];
int fw = kern_param.filter_meta.spatial[1];
int oc = kern_param.filter_meta.ocpg;
int ic = kern_param.filter_meta.icpg;
int oc_block = oc;
int oc_idx = 0;
const float* fptr =
kern_param.filter<dt_float32>(group_id) + oc_idx * fh * fw * ic;
auto packed_weight = reinterpret_cast<float*>(bundle.get(1)) +
group_id * oc * ic * fh * fw + oc_idx * ic * fh * fw;
conv_bias::pack_weight_fp32_nchw_nchw44(fptr, packed_weight, oc_block, fh,
fw, ic);
}
template <size_t filter, BiasMode bias_mode, typename Op>
static void do_conv_kern(WorkspaceBundle bundle,
const ConvBiasImpl::NCBKernParam& kern_param,
const ConvBiasImpl::NCBKernIndex& ncb_index,
const CpuNDRange&, const CpuNDRange&) {
const int oh = kern_param.osz[0];
const int ow = kern_param.osz[1];
const int fh = kern_param.filter_meta.spatial[0];
const int fw = kern_param.filter_meta.spatial[1];
const int ic = kern_param.filter_meta.icpg;
const int oc = kern_param.filter_meta.ocpg;
const int ih = kern_param.isz[0];
const int iw = kern_param.isz[1];
const int stride_h = kern_param.filter_meta.stride[0];
const int ph = kern_param.filter_meta.padding[0];
const int pw = kern_param.filter_meta.padding[1];
int ih2 = 0;
int iw2 = 0;
int oh2 = 0;
int ow2 = 0;
get_rectified_size(kern_param, ih2, iw2, oh2, ow2);
bundle.set(kern_param.workspace_ptr);
constexpr int pack_c = 4;
const int batch_id = ncb_index.ndrange_id[0];
const int group_id = ncb_index.ndrange_id[1];
int oc_idx = 0;
int oc_block = oc;
int oh_block = block_helper(kern_param.nr_threads, oh2, 0);
const int oh_idx = ncb_index.ndrange_id[2];
const int oh_block_real = std::min(oh - oh_idx * oh_block, oh_block);
const int ih_real = oh_block_real * stride_h + fh - stride_h;
const int src_top_pad = std::max(ph - oh_idx * oh_block * stride_h, 0);
const int src_bottom_pad = std::max(
(oh_idx * oh_block + oh_block_real - 1) * stride_h + fh - ih - ph,
0);
const int remain_right_pad = std::max(iw2 - iw - pw, 0);
const int src_offset = std::max(oh_idx * oh_block * stride_h - ph, 0) * iw;
const float* origin_sptr = static_cast<const float*>(kern_param.src<float>(
batch_id, group_id, 0, 1, 1)) +
src_offset;
const size_t src_size = get_perthread_cache_bytes(ic, ih2, iw2);
float* sptr = reinterpret_cast<float*>((int8_t*)bundle.get(0) +
ncb_index.thread_id * src_size);
copy_pad_src(sptr, origin_sptr, ph, pw, remain_right_pad,
ih_real - src_top_pad - src_bottom_pad, iw, iw2, src_top_pad,
src_bottom_pad, ic, ih * iw);
// pack weight
auto packed_weight = reinterpret_cast<float*>(bundle.get(1)) +
group_id * oc * ic * fh * fw + oc_idx * ic * fh * fw;
// get param
float_t* dst = kern_param.dst<float_t>(batch_id, group_id) +
oh_idx * oh_block * ow * pack_c;
const float* bptr =
kern_param.bias<dt_float32>(batch_id, group_id) + oc_idx;
Op op;
#define KERN1_NCHW44_CONV(filter) \
conv_bias::conv_direct_stride2_##filter##x##filter##_fp32_nchw_nchw44< \
\
bias_mode, Op>(sptr, packed_weight, bptr, nullptr, dst, oc_block, \
ic, ih_real, iw2, oh, oh_block_real, ow, op, ph, \
pw)
DISPATCH_FILTER(filter, KERN1_NCHW44_CONV);
#undef KERN1_NCHW44_CONV
}
} // namespace
/* ===================== stride2 algo ===================== */
bool ConvBiasImpl::AlgoF32DirectStride2NCHWNCHW44::usable(
fallback::ConvBiasImpl*, const NCBKernSizeParam& param,
AlgoSelectionStrategy) const {
auto&& fm = param.filter_meta;
auto fh = fm.spatial[0];
int oc = fm.ocpg;
bool ok_type = ((param.src_type.enumv() == DTypeEnum::Float32 &&
param.filter_type.enumv() == DTypeEnum::Float32 &&
(param.dst_type.enumv() == DTypeEnum::Float32))) &&
(fm.format == param::Convolution::Format::NCHW44);
bool ok_src_dst = fm.icpg < 4 && (oc % 4 == 0 && oc >= 4) && fm.group == 1;
bool ok_filter = fm.spatial_ndim == 2 && fh == fm.spatial[1] &&
(fh == 3 || fh == 5 || fh == 7);
bool ok_slide = fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
fm.stride[0] == 2 && fm.stride[1] == 2;
bool ok_conv = !fm.should_flip && param.bias_mode != BiasMode::BIAS;
bool avaible = ok_type && ok_src_dst && ok_filter && ok_slide && ok_conv;
return avaible;
}
size_t ConvBiasImpl::AlgoF32DirectStride2NCHWNCHW44::get_workspace(
fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const {
return get_bundle(param).total_size_in_bytes();
}
SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoF32DirectStride2NCHWNCHW44::dispatch_kerns(
fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const {
auto fm = param.filter_meta;
const int batch = param.n;
const int group = fm.group;
WorkspaceBundle wbundle = get_bundle(param);
conv_fun do_conv_fun = nullptr;
// NOTE: remain_w is not used to gen hash of midout for compatible with
// shape runtime
#define DO_CONV_KERN_FUN(filter, bias_mode, op) \
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_fp32_nchw_nchw44_stride2, \
midout_iv(#filter #bias_mode #op##_hash)) { \
do_conv_fun = do_conv_kern<filter, bias_mode, op>; \
} \
MIDOUT_END();
#define GET_OP_PARAM(filter, bias_mode) \
switch (param.nonlineMode) { \
case param::ConvBias::NonlineMode::IDENTITY: \
DO_CONV_KERN_FUN(filter, bias_mode, NoneOp<dt_float32>) \
break; \
case param::ConvBias::NonlineMode::RELU: \
DO_CONV_KERN_FUN(filter, bias_mode, ReluOp<dt_float32>) \
break; \
case param::ConvBias::NonlineMode::H_SWISH: \
DO_CONV_KERN_FUN(filter, bias_mode, HSwishOp<dt_float32>) \
break; \
default: \
megdnn_assert(0); \
break; \
}
#define GET_BIAS_MODE_PARAM(filter) \
switch (param.bias_mode) { \
case BiasMode::NO_BIAS: \
GET_OP_PARAM(filter, BiasMode::NO_BIAS) \
break; \
case BiasMode::BROADCAST_CHANNEL_BIAS: \
GET_OP_PARAM(filter, BiasMode::BROADCAST_CHANNEL_BIAS) \
break; \
default: \
megdnn_assert(0); \
break; \
}
#define DISPATCH_CONV_KERN() \
switch (param.filter_meta.spatial[0]) { \
case 3: \
GET_BIAS_MODE_PARAM(3) \
break; \
case 5: \
GET_BIAS_MODE_PARAM(5) \
break; \
case 7: \
GET_BIAS_MODE_PARAM(7) \
break; \
default: \
megdnn_assert(0); \
break; \
}
DISPATCH_CONV_KERN();
#undef DO_CONV_KERN_FUN
#undef GET_REMAIN_W_PARAM
#undef GET_OP_PARAM
#undef GET_BIAS_MODE_PARAM
#undef DISPATCH_CONV_KERN
megdnn_assert(do_conv_fun);
SmallVector<ConvBiasImpl::NCBKern> ret_kerns;
WorkspaceBundle bundle = wbundle;
int oh = param.osz[0];
int oh_block = block_helper(param.nr_threads, oh, 0);
auto do_pack_weight = [bundle](const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index) {
pack_weight(bundle, kern_param, ncb_index);
};
ret_kerns.push_back({do_pack_weight, {static_cast<size_t>(group)}});
CpuNDRange ncb_range = {static_cast<size_t>(batch),
static_cast<size_t>(group),
static_cast<size_t>(div_ceil(oh, oh_block))};
auto do_conv = [bundle, do_conv_fun, ncb_range](
const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index) {
do_conv_fun(bundle, kern_param, ncb_index, ncb_index.ndrange_id,
ncb_range);
};
ret_kerns.push_back({do_conv, ncb_range});
return ret_kerns;
}
// vim: syntax=cpp.doxygen
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw_nchw44_kern.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw_nchw44_kern.h"
#include "src/arm_common/conv_bias/intrinsic_helper.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/unroll_macro.h"
#include "src/common/utils.h"
#include "src/fallback/conv_bias/common.h"
using namespace megdnn;
using namespace arm_common;
namespace {
template <int src_idx, int weight_idx, int c_dim, typename Func, typename T,
typename T2, typename T3, typename T4>
struct ShiftCalHelper {
static void impl(T& c, T2& src, T3& weight);
};
template <int src_idx, int weight_idx, typename Func, typename T, typename T2,
typename T3, typename T4>
struct ShiftCalHelper<src_idx, weight_idx, 2, Func, T, T2, T3, T4> {
static void impl(T& c, T2& src, T3& weight) {
constexpr int stride = 2;
#define cb(step) \
c[0][step] = Func::template impl<(step * stride + src_idx) % 4>( \
c[0][step], weight[0][weight_idx], \
src[(step * stride + src_idx) / 4]); \
c[1][step] = Func::template impl<(step * stride + src_idx) % 4>( \
c[1][step], weight[1][weight_idx], \
src[(step * stride + src_idx) / 4]);
UNROLL_CALL_RAW(8, cb);
#undef cb
}
};
template <int src_idx, int weight_idx, typename Func, typename T, typename T2,
typename T3, typename T4>
struct ShiftCalHelper<src_idx, weight_idx, 1, Func, T, T2, T3, T4> {
static void impl(T& c, T2& src, T3& weight) {
constexpr int stride = 2;
#define cb(step) \
c[0][step] = Func::template impl<(step * stride + src_idx) % 4>( \
c[0][step], weight[0][weight_idx], \
src[(step * stride + src_idx) / 4]);
UNROLL_CALL_RAW(8, cb);
#undef cb
}
};
template <int src_idx, int weight_idx, int c_dim, typename FUNC, typename T,
typename T2, typename T3>
inline void cal_helper(T& c, T2& src, T3& weight) {
ShiftCalHelper<src_idx, weight_idx, c_dim, FUNC, T, T2, T3, int>::impl(
c, src, weight);
};
template <int oc>
struct OCHelper {
public:
static const int val = -1;
};
template <>
struct OCHelper<4> {
public:
static const int val = 1;
};
template <>
struct OCHelper<8> {
public:
static const int val = 2;
};
/**
* oc8_ow8(m = 8, n = 8) and oc4_ow8(m = 4, n = 8) gemm like kernel
* */
template <BiasMode bias_mode, typename Op, int remain_w, int filter_size,
int oc_block>
struct KerNeonXXs2NchwNchw44FP32 {
static void impl(const float32_t* src_ptr, const float32_t* weight_ptr,
const float32_t* bias_ptr, float32_t* dst_ptr, int ic,
int ih, int iw, int ld_dst_oc, const Op& op);
};
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block>
struct KerNeonXXs2NchwNchw44FP32<bias_mode, Op, remain_w, 7, oc_block> {
static void impl(const float32_t* src_ptr, const float32_t* weight_ptr,
const float32_t* bias_ptr, float32_t* dst_ptr, int ic,
int ih, int iw, int ld_dst_oc, const Op& op) {
constexpr int loop_ic_step = 1;
constexpr int filter_size = 7;
constexpr int oc_step = 4;
constexpr int simd_len = 4;
constexpr int src_reg_size = 6;
constexpr int ld_weight_fw = oc_step * filter_size;
const int ld_weight_oc = oc_step * filter_size * filter_size * ic;
const int ld_weight_ic = oc_step * filter_size * filter_size;
const int ld_src_ic = ih * iw;
constexpr int c_dim = OCHelper<oc_block>::val;
float32x4_t c[c_dim][8];
init_ocx_ow8<c_dim, bias_mode>(c, bias_ptr, oc_step);
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
float32x4_t src[src_reg_size];
float32x4_t weight[c_dim][filter_size];
#define KERNEL_CB(step) \
load_helper<src_reg_size, 0, simd_len, 0, Vld1q_f32>( \
src, src_ptr + step * iw, 0); \
load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>( \
weight, weight_ptr + step * ld_weight_fw, ld_weight_oc); \
cal_helper<0, 0, c_dim, Vfmaq_laneq_f32>(c, src, weight); \
cal_helper<1, 1, c_dim, Vfmaq_laneq_f32>(c, src, weight); \
cal_helper<2, 2, c_dim, Vfmaq_laneq_f32>(c, src, weight); \
cal_helper<3, 3, c_dim, Vfmaq_laneq_f32>(c, src, weight); \
cal_helper<4, 4, c_dim, Vfmaq_laneq_f32>(c, src, weight); \
cal_helper<5, 5, c_dim, Vfmaq_laneq_f32>(c, src, weight); \
cal_helper<6, 6, c_dim, Vfmaq_laneq_f32>(c, src, weight);
UNROLL_CALL_RAW(7, KERNEL_CB)
#undef KERNEL_CB
src_ptr += ld_src_ic;
weight_ptr += ld_weight_ic;
}
store_ocx_ow8_remain_static<c_dim, remain_w, Op>(c, op, dst_ptr,
ld_dst_oc);
}
};
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block>
struct KerNeonXXs2NchwNchw44FP32<bias_mode, Op, remain_w, 5, oc_block> {
static void impl(const float32_t* src_ptr, const float32_t* weight_ptr,
const float32_t* bias_ptr, float32_t* dst_ptr, int ic,
int ih, int iw, int ld_dst_oc, const Op& op) {
constexpr int loop_ic_step = 1;
constexpr int filter_size = 5;
constexpr int oc_step = 4;
constexpr int simd_len = 4;
constexpr int src_reg_size = 5;
constexpr int ld_weight_fw = oc_step * filter_size;
const int ld_weight_oc = oc_step * filter_size * filter_size * ic;
const int ld_weight_ic = oc_step * filter_size * filter_size;
const int ld_src_ic = ih * iw;
constexpr int c_dim = OCHelper<oc_block>::val;
float32x4_t c[c_dim][8];
init_ocx_ow8<c_dim, bias_mode>(c, bias_ptr, oc_step);
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
float32x4_t src[src_reg_size];
float32x4_t weight[c_dim][filter_size];
#define KERNEL_CB(step) \
load_helper<src_reg_size, 0, simd_len, 0, Vld1q_f32>( \
src, src_ptr + step * iw, 0); \
load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>( \
weight, weight_ptr + step * ld_weight_fw, ld_weight_oc); \
cal_helper<0, 0, c_dim, Vfmaq_laneq_f32>(c, src, weight); \
cal_helper<1, 1, c_dim, Vfmaq_laneq_f32>(c, src, weight); \
cal_helper<2, 2, c_dim, Vfmaq_laneq_f32>(c, src, weight); \
cal_helper<3, 3, c_dim, Vfmaq_laneq_f32>(c, src, weight); \
cal_helper<4, 4, c_dim, Vfmaq_laneq_f32>(c, src, weight);
UNROLL_CALL_RAW(5, KERNEL_CB)
#undef KERNEL_CB
src_ptr += ld_src_ic;
weight_ptr += ld_weight_ic;
}
store_ocx_ow8_remain_static<c_dim, remain_w, Op>(c, op, dst_ptr,
ld_dst_oc);
}
};
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block>
struct KerNeonXXs2NchwNchw44FP32<bias_mode, Op, remain_w, 3, oc_block> {
static void impl(const float32_t* src_ptr, const float32_t* weight_ptr,
const float32_t* bias_ptr, float32_t* dst_ptr, int ic,
int ih, int iw, int ld_dst_oc, const Op& op) {
constexpr int loop_ic_step = 1;
constexpr int filter_size = 3;
constexpr int oc_step = 4;
constexpr int simd_len = 4;
constexpr int src_reg_size = 5;
constexpr int ld_weight_fw = oc_step * filter_size;
const int ld_weight_oc = oc_step * filter_size * filter_size * ic;
const int ld_weight_ic = oc_step * filter_size * filter_size;
const int ld_src_ic = ih * iw;
constexpr int c_dim = OCHelper<oc_block>::val;
float32x4_t c[c_dim][8];
init_ocx_ow8<c_dim, bias_mode>(c, bias_ptr, oc_step);
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
float32x4_t src[src_reg_size];
float32x4_t weight[c_dim][filter_size];
// row 0
load_helper<5, 0, simd_len, 0, Vld1q_f32>(src, src_ptr, 0);
load_helper<3, 0, oc_step, c_dim, Vld1q_f32>(weight, weight_ptr,
ld_weight_oc);
cal_helper<0, 0, c_dim, Vfmaq_laneq_f32>(c, src, weight);
cal_helper<1, 1, c_dim, Vfmaq_laneq_f32>(c, src, weight);
cal_helper<2, 2, c_dim, Vfmaq_laneq_f32>(c, src, weight);
// row 1
load_helper<5, 0, simd_len, 0, Vld1q_f32>(src, src_ptr + iw, 0);
load_helper<3, 0, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr + 1 * ld_weight_fw, ld_weight_oc);
cal_helper<0, 0, c_dim, Vfmaq_laneq_f32>(c, src, weight);
cal_helper<1, 1, c_dim, Vfmaq_laneq_f32>(c, src, weight);
cal_helper<2, 2, c_dim, Vfmaq_laneq_f32>(c, src, weight);
// row 2
load_helper<5, 0, simd_len, 0, Vld1q_f32>(src, src_ptr + 2 * iw, 0);
load_helper<3, 0, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr + 2 * ld_weight_fw, ld_weight_oc);
cal_helper<0, 0, c_dim, Vfmaq_laneq_f32>(c, src, weight);
cal_helper<1, 1, c_dim, Vfmaq_laneq_f32>(c, src, weight);
cal_helper<2, 2, c_dim, Vfmaq_laneq_f32>(c, src, weight);
src_ptr += ld_src_ic;
weight_ptr += ld_weight_ic;
}
store_ocx_ow8_remain_static<c_dim, remain_w, Op>(c, op, dst_ptr,
ld_dst_oc);
}
};
} // namespace
void conv_bias::pack_weight_fp32_nchw_nchw44(const float32_t* in_ptr,
float32_t* dst_ptr, const int oc,
const int kh, const int kw,
const int ic) {
constexpr int oc_step = 4;
const int filter_oc_stride = kh * kw * ic;
const int filter_ic_stride = kh * kw * oc_step;
for (int oc_idx = 0; oc_idx < oc; oc_idx += oc_step) {
const float32_t* in_ptr_oc = in_ptr + oc_idx * filter_oc_stride;
float32_t* dst_ptr_oc = dst_ptr + oc_idx * filter_oc_stride;
for (int kh_idx = 0; kh_idx < kh; ++kh_idx) {
for (int kw_idx = 0; kw_idx < kw; ++kw_idx) {
for (int ic_idx = 0; ic_idx < ic; ++ic_idx) {
float32x4_t vsrc = vld1q_f32(in_ptr_oc);
vst1q_f32(dst_ptr_oc + ic_idx * filter_ic_stride, vsrc);
in_ptr_oc += oc_step;
}
dst_ptr_oc += oc_step;
}
}
}
}
template <BiasMode bias_mode, typename Op, int filter_size>
static void conv_direct_stride2_fp32_nchw_nchw44(
const float32_t* src, const float32_t* filter, const float32_t* bias,
float32_t*, float32_t* dst, const int oc, const int ic, const int ih,
const int iw, const int oh, const int oh_block, const int ow,
const Op& op, const int, const int) {
constexpr int fh = filter_size;
constexpr int fw = filter_size;
constexpr int ic_step = 1;
constexpr int big_oc_step = 8;
constexpr int oc_step = 4;
constexpr int ih_step = 1;
constexpr int oh_step = 1;
constexpr int ow_step = 8;
constexpr int stride_h = 2;
constexpr int stride_w = 2;
constexpr int pack_iw_len = 1;
const int img_stride = oh * ow;
const int ow_end = ow / ow_step * ow_step;
const int ow_remain = ow - ow_end;
const int oc_end = oc / big_oc_step * big_oc_step;
const int oc_remain = oc - oc_end;
const int ld_dst_oc = oc_step * img_stride;
using remain_fun = std::function<void(
const float32_t* src_ptr, const float32_t* weight_ptr,
const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc, const Op& op)>;
remain_fun kern_big_oc_remain = nullptr;
remain_fun kern_small_oc_remain = nullptr;
switch (ow_remain) {
#define cb(step) \
case step: \
kern_big_oc_remain = \
KerNeonXXs2NchwNchw44FP32<bias_mode, Op, step, filter_size, \
big_oc_step>::impl; \
kern_small_oc_remain = \
KerNeonXXs2NchwNchw44FP32<bias_mode, Op, step, filter_size, \
oc_step>::impl; \
break;
UNROLL_CALL_RAW(8, cb);
default:
megdnn_assert(0, "no remain %d for kern", ow_remain);
}
for (int oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) {
const int weight_offset = oc_idx * ic * fh * fw;
for (int oh_idx = 0; oh_idx < oh_block; oh_idx += oh_step) {
for (int ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
const int src_offset =
(oh_idx * stride_h * iw + ow_idx * stride_w * ih_step) *
ic_step * pack_iw_len;
const int dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step;
KerNeonXXs2NchwNchw44FP32<
bias_mode, Op, 0, filter_size,
big_oc_step>::impl(src + src_offset,
filter + weight_offset,
bias + oc_idx, dst + dst_offset, ic,
ih, iw, ld_dst_oc, op);
}
if (ow_remain > 0) {
const int src_offset =
(oh_idx * stride_h * iw + ow_end * stride_w * ih_step) *
ic_step * pack_iw_len;
const int dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step;
kern_big_oc_remain(src + src_offset, filter + weight_offset,
bias + oc_idx, dst + dst_offset, ic, ih, iw,
ld_dst_oc, op);
}
}
}
if (oc_remain > 0) {
int oc_idx = oc_end;
const int weight_offset = oc_idx * ic * fh * fw;
for (int oh_idx = 0; oh_idx < oh_block; oh_idx += oh_step) {
for (int ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
const int src_offset =
(oh_idx * stride_h * iw + ow_idx * stride_w * ih_step) *
ic_step * pack_iw_len;
const int dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step;
KerNeonXXs2NchwNchw44FP32<bias_mode, Op, 0, filter_size,
oc_step>::impl(src + src_offset,
filter + weight_offset,
bias + oc_idx,
dst + dst_offset, ic,
ih, iw, ld_dst_oc, op);
}
if (ow_remain > 0) {
const int src_offset =
(oh_idx * stride_h * iw + ow_end * stride_w * ih_step) *
ic_step * pack_iw_len;
const int dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step;
kern_small_oc_remain(src + src_offset, filter + weight_offset,
bias + oc_idx, dst + dst_offset, ic, ih,
iw, ld_dst_oc, op);
}
}
}
}
#define CONSTRUCT_FUNC(filter_size) \
template <BiasMode bias_mode, typename Op> \
void conv_bias:: \
conv_direct_stride2_##filter_size##x##filter_size##_fp32_nchw_nchw44( \
const float32_t* src, const float32_t* filter, \
const float32_t* bias, float32_t* temp, float32_t* dst, \
const int oc, const int ic, const int ih, const int iw, \
const int oh, const int oh_block, const int ow, \
const Op& op, const int ph, const int pw) { \
conv_direct_stride2_fp32_nchw_nchw44<bias_mode, Op, filter_size>( \
src, filter, bias, temp, dst, oc, ic, ih, iw, oh, oh_block, \
ow, op, ph, pw); \
}
CONSTRUCT_FUNC(3);
CONSTRUCT_FUNC(5);
CONSTRUCT_FUNC(7);
#undef CONSTRUCT_FUNC
template <BiasMode bias_mode, typename Op>
void conv_bias::conv_direct_stride2_2x2_fp32_nchw_nchw44(
const float32_t*, const float32_t*, const float32_t*, float32_t*,
float32_t*, const int, const int, const int, const int, const int,
const int, const int, const Op&, const int, const int) {
megdnn_assert(0, "not imple nchw_nchw44 2x2s2 conv");
}
#define INSTANTIATION(stride, i, bias, Op) \
template void conv_bias:: \
conv_direct_##stride##_##i##x##i##_fp32_nchw_nchw44<bias, Op>( \
const float32_t*, const float32_t*, const float32_t*, \
float32_t*, float32_t*, const int, const int, const int, \
const int, const int, const int, const int, const Op&, \
const int, const int);
#define FOR_OP(stride, i, bias) \
INSTANTIATION(stride, i, bias, NoneOp<dt_float32>) \
INSTANTIATION(stride, i, bias, ReluOp<dt_float32>) \
INSTANTIATION(stride, i, bias, HSwishOp<dt_float32>)
#define FOR_BIAS(stride, i) \
FOR_OP(stride, i, BiasMode::NO_BIAS) \
FOR_OP(stride, i, BiasMode::BROADCAST_CHANNEL_BIAS)
#define FOR_FILTER(stride) \
FOR_BIAS(stride, 2) \
FOR_BIAS(stride, 3) \
FOR_BIAS(stride, 5) \
FOR_BIAS(stride, 7)
FOR_FILTER(stride2)
#undef FOR_STRIDE
#undef FOR_FILTER
#undef FOR_IC
#undef FOR_BIAS
#undef FOR_NONLINEAR
#undef FOR_REMAIN
#undef INSTANTIATION
/**
* \file dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw_nchw44_kern.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/opr_impl.h"
#include "src/fallback/conv_bias/common.h"
namespace megdnn {
namespace arm_common {
namespace conv_bias {
#define KERN(stride, i, layout) \
template <BiasMode bias_mode, typename Op> \
void conv_direct_##stride##_##i##x##i##_fp32_nchw_##layout( \
const float* src, const float* filter, const float* bias, \
float* temp, float* dst, const int oc, const int ic, const int ih, \
const int iw, const int oh, const int oh_block, const int ow, \
const Op& op, const int ph, const int pw);
KERN(stride2, 2, nchw44)
KERN(stride2, 3, nchw44)
KERN(stride2, 5, nchw44)
KERN(stride2, 7, nchw44)
#undef KERN
void pack_weight_fp32_nchw_nchw44(const float_t* in_ptr, float_t* dst_ptr,
const int oc, const int kh, const int kw,
const int ic);
} // namespace conv_bias
} // namespace arm_common
} // namespace megdnn
\ No newline at end of file
......@@ -174,7 +174,167 @@ inline void store_ocx_ow4_remain_static(T& c, const Op& op, int8_t* dst_ptr,
int ld_dst_oc) {
StoreOcxOw4Remain<c_dim, ow_remain, Op, T>::impl(c, op, dst_ptr, ld_dst_oc);
}
////////////////////Store_OCX_OW8_Remain/////////////////////////
template <int c_dim, int ow_remain, typename Op, typename T>
struct StoreOcxOw8Remain {
static void impl(T& c, const Op& op, float32_t* dst_ptr, int ld_dst_oc);
};
template <typename Op, typename T>
struct StoreOcxOw8Remain<2, 0, Op, T> {
static void impl(T& c, const Op& op, float32_t* dst_ptr, int ld_dst_oc) {
op({{c[0][0], c[0][1]}}, dst_ptr);
op({{c[0][2], c[0][3]}}, dst_ptr + 8);
op({{c[0][4], c[0][5]}}, dst_ptr + 16);
op({{c[0][6], c[0][7]}}, dst_ptr + 24);
op({{c[1][0], c[1][1]}}, dst_ptr + ld_dst_oc);
op({{c[1][2], c[1][3]}}, dst_ptr + ld_dst_oc + 8);
op({{c[1][4], c[1][5]}}, dst_ptr + ld_dst_oc + 16);
op({{c[1][6], c[1][7]}}, dst_ptr + ld_dst_oc + 24);
}
};
template <typename Op, typename T>
struct StoreOcxOw8Remain<2, 7, Op, T> {
static void impl(T& c, const Op& op, float32_t* dst_ptr, int ld_dst_oc) {
op({{c[0][0], c[0][1]}}, dst_ptr);
op({{c[0][2], c[0][3]}}, dst_ptr + 8);
op({{c[0][4], c[0][5]}}, dst_ptr + 16);
op(c[0][6], dst_ptr + 24);
op({{c[1][0], c[1][1]}}, dst_ptr + ld_dst_oc);
op({{c[1][2], c[1][3]}}, dst_ptr + ld_dst_oc + 8);
op({{c[1][4], c[1][5]}}, dst_ptr + ld_dst_oc + 16);
op(c[1][6], dst_ptr + ld_dst_oc + 24);
}
};
template <typename Op, typename T>
struct StoreOcxOw8Remain<2, 6, Op, T> {
static void impl(T& c, const Op& op, float32_t* dst_ptr, int ld_dst_oc) {
op({{c[0][0], c[0][1]}}, dst_ptr);
op({{c[0][2], c[0][3]}}, dst_ptr + 8);
op({{c[0][4], c[0][5]}}, dst_ptr + 16);
op({{c[1][0], c[1][1]}}, dst_ptr + ld_dst_oc);
op({{c[1][2], c[1][3]}}, dst_ptr + ld_dst_oc + 8);
op({{c[1][4], c[1][5]}}, dst_ptr + ld_dst_oc + 16);
}
};
template <typename Op, typename T>
struct StoreOcxOw8Remain<2, 5, Op, T> {
static void impl(T& c, const Op& op, float32_t* dst_ptr, int ld_dst_oc) {
op({{c[0][0], c[0][1]}}, dst_ptr);
op({{c[0][2], c[0][3]}}, dst_ptr + 8);
op(c[0][4], dst_ptr + 16);
op({{c[1][0], c[1][1]}}, dst_ptr + ld_dst_oc);
op({{c[1][2], c[1][3]}}, dst_ptr + ld_dst_oc + 8);
op(c[1][4], dst_ptr + ld_dst_oc + 16);
}
};
template <typename Op, typename T>
struct StoreOcxOw8Remain<2, 4, Op, T> {
static void impl(T& c, const Op& op, float32_t* dst_ptr, int ld_dst_oc) {
op({{c[0][0], c[0][1]}}, dst_ptr);
op({{c[0][2], c[0][3]}}, dst_ptr + 8);
op({{c[1][0], c[1][1]}}, dst_ptr + ld_dst_oc);
op({{c[1][2], c[1][3]}}, dst_ptr + ld_dst_oc + 8);
}
};
template <typename Op, typename T>
struct StoreOcxOw8Remain<2, 3, Op, T> {
static void impl(T& c, const Op& op, float32_t* dst_ptr, int ld_dst_oc) {
op({{c[0][0], c[0][1]}}, dst_ptr);
op(c[0][2], dst_ptr + 8);
op({{c[1][0], c[1][1]}}, dst_ptr + ld_dst_oc);
op(c[1][2], dst_ptr + ld_dst_oc + 8);
}
};
template <typename Op, typename T>
struct StoreOcxOw8Remain<2, 2, Op, T> {
static void impl(T& c, const Op& op, float32_t* dst_ptr, int ld_dst_oc) {
op({{c[0][0], c[0][1]}}, dst_ptr);
op({{c[1][0], c[1][1]}}, dst_ptr + ld_dst_oc);
}
};
template <typename Op, typename T>
struct StoreOcxOw8Remain<2, 1, Op, T> {
static void impl(T& c, const Op& op, float32_t* dst_ptr, int ld_dst_oc) {
op(c[0][0], dst_ptr);
op(c[1][0], dst_ptr + ld_dst_oc);
}
};
template <typename Op, typename T>
struct StoreOcxOw8Remain<1, 0, Op, T> {
static void impl(T& c, const Op& op, float32_t* dst_ptr, int) {
op({{c[0][0], c[0][1]}}, dst_ptr);
op({{c[0][2], c[0][3]}}, dst_ptr + 8);
op({{c[0][4], c[0][5]}}, dst_ptr + 16);
op({{c[0][6], c[0][7]}}, dst_ptr + 24);
}
};
template <typename Op, typename T>
struct StoreOcxOw8Remain<1, 7, Op, T> {
static void impl(T& c, const Op& op, float32_t* dst_ptr, int) {
op({{c[0][0], c[0][1]}}, dst_ptr);
op({{c[0][2], c[0][3]}}, dst_ptr + 8);
op({{c[0][4], c[0][5]}}, dst_ptr + 16);
op(c[0][6], dst_ptr + 24);
}
};
template <typename Op, typename T>
struct StoreOcxOw8Remain<1, 6, Op, T> {
static void impl(T& c, const Op& op, float32_t* dst_ptr, int) {
op({{c[0][0], c[0][1]}}, dst_ptr);
op({{c[0][2], c[0][3]}}, dst_ptr + 8);
op({{c[0][4], c[0][5]}}, dst_ptr + 16);
}
};
template <typename Op, typename T>
struct StoreOcxOw8Remain<1, 5, Op, T> {
static void impl(T& c, const Op& op, float32_t* dst_ptr, int) {
op({{c[0][0], c[0][1]}}, dst_ptr);
op({{c[0][2], c[0][3]}}, dst_ptr + 8);
op(c[0][4], dst_ptr + 16);
}
};
template <typename Op, typename T>
struct StoreOcxOw8Remain<1, 4, Op, T> {
static void impl(T& c, const Op& op, float32_t* dst_ptr, int) {
op({{c[0][0], c[0][1]}}, dst_ptr);
op({{c[0][2], c[0][3]}}, dst_ptr + 8);
}
};
template <typename Op, typename T>
struct StoreOcxOw8Remain<1, 3, Op, T> {
static void impl(T& c, const Op& op, float32_t* dst_ptr, int) {
op({{c[0][0], c[0][1]}}, dst_ptr);
op(c[0][2], dst_ptr + 8);
}
};
template <typename Op, typename T>
struct StoreOcxOw8Remain<1, 2, Op, T> {
static void impl(T& c, const Op& op, float32_t* dst_ptr, int) {
op({{c[0][0], c[0][1]}}, dst_ptr);
}
};
template <typename Op, typename T>
struct StoreOcxOw8Remain<1, 1, Op, T> {
static void impl(T& c, const Op& op, float32_t* dst_ptr, int) {
op(c[0][0], dst_ptr);
}
};
template <int c_dim, int ow_remain, typename Op, typename T>
inline void store_ocx_ow8_remain_static(T& c, const Op& op, float32_t* dst_ptr,
int ld_dst_oc) {
StoreOcxOw8Remain<c_dim, ow_remain, Op, T>::impl(c, op, dst_ptr, ld_dst_oc);
}
////////////////////Store_OC8_OW8_Remain/////////////////////////
template <int ow_remain, typename Op>
......@@ -299,14 +459,15 @@ struct Store_OC8_OW8_Remain<1, Op> {
}
};
template <int ow_remain, typename Op>
inline void store_oc8_ow8_remain_static(int32x4_t c[2][8], const Op& op,
int8_t* dst_ptr, int ld_dst_oc) {
///////////
template <int ow_remain, typename Op, typename T, typename T2>
inline void store_oc8_ow8_remain_static(T& c, const Op& op, T2 dst_ptr,
int ld_dst_oc) {
Store_OC8_OW8_Remain<ow_remain, Op>::impl(c, op, dst_ptr, ld_dst_oc);
}
///////////////////////////////////////////////////////
//////////////////////////////////////
template <BiasMode bias_mode>
inline void init_oc4_ow8(int32x4_t c[8], const int32_t* bias_ptr) {
if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) {
......@@ -337,6 +498,49 @@ inline void init_oc8_ow8(int32x4_t c[2][8], const int32_t* bias_ptr,
#undef BAIS_INIT
}
}
/////////////////////////init_ocx_ow8////////////////////
template <int c_dim, BiasMode bias_mode, typename T, typename T2>
struct InitOcxOw8 {
static void impl(T& c, T2 bias_ptr, int oc_step);
};
template <BiasMode bias_mode, typename T, typename T2>
struct InitOcxOw8<2, bias_mode, T, T2> {
static void impl(T& c, const float32_t* bias_ptr, int oc_step) {
if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) {
#define BAIS_INIT(step) \
c[0][step] = vld1q_f32(bias_ptr); \
c[1][step] = vld1q_f32(bias_ptr + oc_step);
UNROLL_CALL_RAW(8, BAIS_INIT);
#undef BAIS_INIT
} else {
#define BAIS_INIT(step) \
c[0][step] = vdupq_n_f32(0); \
c[1][step] = vdupq_n_f32(0);
UNROLL_CALL_RAW(8, BAIS_INIT);
#undef BAIS_INIT
}
}
};
template <BiasMode bias_mode, typename T, typename T2>
struct InitOcxOw8<1, bias_mode, T, T2> {
static void impl(T& c, const float32_t* bias_ptr, int) {
if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) {
#define BAIS_INIT(step) c[0][step] = vld1q_f32(bias_ptr);
UNROLL_CALL_RAW(8, BAIS_INIT);
#undef BAIS_INIT
} else {
#define BAIS_INIT(step) c[0][step] = vdupq_n_f32(0);
UNROLL_CALL_RAW(8, BAIS_INIT);
#undef BAIS_INIT
}
}
};
template <int c_dim, BiasMode bias_mode, typename T, typename T2>
inline void init_ocx_ow8(T& c, T2 bias_ptr, int oc_step) {
InitOcxOw8<c_dim, bias_mode, T, T2>::impl(c, bias_ptr, oc_step);
}
/////////////////////init_ocx_ow4/////////////////////
template <int c_dim, BiasMode bias_mode, typename T>
struct InitOcxOw4 {
static void impl(T& c, const int32_t* bias_ptr, int oc_step);
......@@ -383,57 +587,54 @@ inline void init_ocx_ow4(T& c, const int32_t* bias_ptr, int oc_step) {
}
///////////////////////////////////////
template <int weight_number, int base_offset, int ptr_step, int oc_block,
typename Func, typename T, typename... XT>
typename Func, typename T, typename T2, typename... XT>
struct LoadHelper {
static void impl(T& weight, const int8_t* ptr, int oc_offset, XT... args);
static void impl(T& weight, T2 ptr, int oc_offset, XT... args);
};
#define WEIGHT_CB(step) \
src[step] = Func::impl(ptr + base_offset + step * ptr_step, args...);
template <int base_offset, int ptr_step, typename Func, typename T,
template <int base_offset, int ptr_step, typename Func, typename T, typename T2,
typename... XT>
struct LoadHelper<1, base_offset, ptr_step, 0, Func, T, XT...> {
static void impl(T& src, const int8_t* ptr, int oc_offset, XT... args) {
struct LoadHelper<1, base_offset, ptr_step, 0, Func, T, T2, XT...> {
static void impl(T& src, T2 ptr, int, XT... args) {
UNROLL_CALL_RAW(1, WEIGHT_CB);
}
};
template <int base_offset, int ptr_step, typename Func, typename T,
template <int base_offset, int ptr_step, typename Func, typename T, typename T2,
typename... XT>
struct LoadHelper<2, base_offset, ptr_step, 0, Func, T, XT...> {
static void impl(T& src, const int8_t* ptr, int oc_offset, XT... args) {
struct LoadHelper<2, base_offset, ptr_step, 0, Func, T, T2, XT...> {
static void impl(T& src, T2 ptr, int, XT... args) {
UNROLL_CALL_RAW(2, WEIGHT_CB);
}
};
template <int base_offset, int ptr_step, typename Func, typename T,
template <int base_offset, int ptr_step, typename Func, typename T, typename T2,
typename... XT>
struct LoadHelper<3, base_offset, ptr_step, 0, Func, T, XT...> {
static void impl(T& src, const int8_t* ptr, int oc_offset, XT... args) {
struct LoadHelper<3, base_offset, ptr_step, 0, Func, T, T2, XT...> {
static void impl(T& src, T2 ptr, int, XT... args) {
UNROLL_CALL_RAW(3, WEIGHT_CB);
}
};
template <int base_offset, int ptr_step, typename Func, typename T,
template <int base_offset, int ptr_step, typename Func, typename T, typename T2,
typename... XT>
struct LoadHelper<4, base_offset, ptr_step, 0, Func, T, XT...> {
static void impl(T& src, const int8_t* ptr, int oc_offset, XT... args) {
MEGDNN_MARK_USED_VAR(oc_offset);
struct LoadHelper<4, base_offset, ptr_step, 0, Func, T, T2, XT...> {
static void impl(T& src, T2 ptr, int, XT... args) {
UNROLL_CALL_RAW(4, WEIGHT_CB);
}
};
template <int base_offset, int ptr_step, typename Func, typename T,
template <int base_offset, int ptr_step, typename Func, typename T, typename T2,
typename... XT>
struct LoadHelper<5, base_offset, ptr_step, 0, Func, T, XT...> {
static void impl(T& src, const int8_t* ptr, int oc_offset, XT... args) {
MEGDNN_MARK_USED_VAR(oc_offset);
struct LoadHelper<5, base_offset, ptr_step, 0, Func, T, T2, XT...> {
static void impl(T& src, T2 ptr, int, XT... args) {
UNROLL_CALL_RAW(5, WEIGHT_CB);
}
};
template <int base_offset, int ptr_step, typename Func, typename T,
template <int base_offset, int ptr_step, typename Func, typename T, typename T2,
typename... XT>
struct LoadHelper<6, base_offset, ptr_step, 0, Func, T, XT...> {
static void impl(T& src, const int8_t* ptr, int oc_offset, XT... args) {
MEGDNN_MARK_USED_VAR(oc_offset);
struct LoadHelper<6, base_offset, ptr_step, 0, Func, T, T2, XT...> {
static void impl(T& src, T2 ptr, int, XT... args) {
UNROLL_CALL_RAW(6, WEIGHT_CB);
}
};
......@@ -441,27 +642,36 @@ struct LoadHelper<6, base_offset, ptr_step, 0, Func, T, XT...> {
#define WEIGHT_CB(step) \
src[0][step] = Func::impl(ptr + base_offset + step * ptr_step);
template <int base_offset, int ptr_step, typename Func, typename T>
struct LoadHelper<1, base_offset, ptr_step, 1, Func, T> {
static void impl(T& src, const int8_t* ptr, int oc_offset) {
MEGDNN_MARK_USED_VAR(oc_offset);
UNROLL_CALL_RAW(1, WEIGHT_CB);
}
template <int base_offset, int ptr_step, typename Func, typename T, typename T2>
struct LoadHelper<1, base_offset, ptr_step, 1, Func, T, T2> {
static void impl(T& src, T2 ptr, int) { UNROLL_CALL_RAW(1, WEIGHT_CB); }
};
template <int base_offset, int ptr_step, typename Func, typename T>
struct LoadHelper<2, base_offset, ptr_step, 1, Func, T> {
static void impl(T& src, const int8_t* ptr, int oc_offset) {
MEGDNN_MARK_USED_VAR(oc_offset);
UNROLL_CALL_RAW(2, WEIGHT_CB);
}
template <int base_offset, int ptr_step, typename Func, typename T, typename T2>
struct LoadHelper<2, base_offset, ptr_step, 1, Func, T, T2> {
static void impl(T& src, T2 ptr, int) { UNROLL_CALL_RAW(2, WEIGHT_CB); }
};
template <int base_offset, int ptr_step, typename Func, typename T>
struct LoadHelper<3, base_offset, ptr_step, 1, Func, T> {
static void impl(T& src, const int8_t* ptr, int oc_offset) {
MEGDNN_MARK_USED_VAR(oc_offset);
UNROLL_CALL_RAW(3, WEIGHT_CB);
}
template <int base_offset, int ptr_step, typename Func, typename T, typename T2>
struct LoadHelper<3, base_offset, ptr_step, 1, Func, T, T2> {
static void impl(T& src, T2 ptr, int) { UNROLL_CALL_RAW(3, WEIGHT_CB); }
};
template <int base_offset, int ptr_step, typename Func, typename T, typename T2>
struct LoadHelper<4, base_offset, ptr_step, 1, Func, T, T2> {
static void impl(T& src, T2 ptr, int) { UNROLL_CALL_RAW(4, WEIGHT_CB); }
};
template <int base_offset, int ptr_step, typename Func, typename T, typename T2>
struct LoadHelper<5, base_offset, ptr_step, 1, Func, T, T2> {
static void impl(T& src, T2 ptr, int) { UNROLL_CALL_RAW(5, WEIGHT_CB); }
};
template <int base_offset, int ptr_step, typename Func, typename T, typename T2>
struct LoadHelper<6, base_offset, ptr_step, 1, Func, T, T2> {
static void impl(T& src, T2 ptr, int) { UNROLL_CALL_RAW(6, WEIGHT_CB); }
};
template <int base_offset, int ptr_step, typename Func, typename T, typename T2>
struct LoadHelper<7, base_offset, ptr_step, 1, Func, T, T2> {
static void impl(T& src, T2 ptr, int) { UNROLL_CALL_RAW(7, WEIGHT_CB); }
};
#undef WEIGHT_CB
......@@ -470,40 +680,63 @@ struct LoadHelper<3, base_offset, ptr_step, 1, Func, T> {
src[0][step] = Func::impl(ptr + base_offset + step * ptr_step); \
src[1][step] = Func::impl(ptr + base_offset + step * ptr_step + oc_offset);
template <int base_offset, int ptr_step, typename Func, typename T>
struct LoadHelper<1, base_offset, ptr_step, 2, Func, T> {
static void impl(T& src, const int8_t* ptr, int oc_offset) {
template <int base_offset, int ptr_step, typename Func, typename T, typename T2>
struct LoadHelper<1, base_offset, ptr_step, 2, Func, T, T2> {
static void impl(T& src, T2 ptr, int oc_offset) {
UNROLL_CALL_RAW(1, WEIGHT_CB);
}
};
template <int base_offset, int ptr_step, typename Func, typename T>
struct LoadHelper<2, base_offset, ptr_step, 2, Func, T> {
static void impl(T& src, const int8_t* ptr, int oc_offset) {
template <int base_offset, int ptr_step, typename Func, typename T, typename T2>
struct LoadHelper<2, base_offset, ptr_step, 2, Func, T, T2> {
static void impl(T& src, T2 ptr, int oc_offset) {
UNROLL_CALL_RAW(2, WEIGHT_CB);
}
};
template <int base_offset, int ptr_step, typename Func, typename T>
struct LoadHelper<3, base_offset, ptr_step, 2, Func, T> {
static void impl(T& src, const int8_t* ptr, int oc_offset) {
template <int base_offset, int ptr_step, typename Func, typename T, typename T2>
struct LoadHelper<3, base_offset, ptr_step, 2, Func, T, T2> {
static void impl(T& src, T2 ptr, int oc_offset) {
UNROLL_CALL_RAW(3, WEIGHT_CB);
}
};
template <int base_offset, int ptr_step, typename Func, typename T, typename T2>
struct LoadHelper<4, base_offset, ptr_step, 2, Func, T, T2> {
static void impl(T& src, T2 ptr, int oc_offset) {
UNROLL_CALL_RAW(4, WEIGHT_CB);
}
};
template <int base_offset, int ptr_step, typename Func, typename T, typename T2>
struct LoadHelper<5, base_offset, ptr_step, 2, Func, T, T2> {
static void impl(T& src, T2 ptr, int oc_offset) {
UNROLL_CALL_RAW(5, WEIGHT_CB);
}
};
template <int base_offset, int ptr_step, typename Func, typename T, typename T2>
struct LoadHelper<6, base_offset, ptr_step, 2, Func, T, T2> {
static void impl(T& src, T2 ptr, int oc_offset) {
UNROLL_CALL_RAW(6, WEIGHT_CB);
}
};
template <int base_offset, int ptr_step, typename Func, typename T, typename T2>
struct LoadHelper<7, base_offset, ptr_step, 2, Func, T, T2> {
static void impl(T& src, T2 ptr, int oc_offset) {
UNROLL_CALL_RAW(7, WEIGHT_CB);
}
};
#undef WEIGHT_CB
template <int weight_number, int base_offset, int ptr_step, int c_dim,
typename Func, typename T>
inline void load_helper(T& weight, const int8_t* ptr, int oc_offset) {
LoadHelper<weight_number, base_offset, ptr_step, c_dim, Func, T>::impl(
typename Func, typename T, typename T2>
inline void load_helper(T& weight, T2 ptr, int oc_offset) {
LoadHelper<weight_number, base_offset, ptr_step, c_dim, Func, T, T2>::impl(
weight, ptr, oc_offset);
}
template <int weight_number, int base_offset, int ptr_step, int c_dim,
typename Func, typename T, typename... XT>
inline void load_helper_x(T& weight, const int8_t* ptr, int oc_offset,
XT... args) {
LoadHelper<weight_number, base_offset, ptr_step, c_dim, Func, T,
typename Func, typename T, typename T2, typename... XT>
inline void load_helper_x(T& weight, T2 ptr, int oc_offset, XT... args) {
LoadHelper<weight_number, base_offset, ptr_step, c_dim, Func, T, T2,
XT...>::impl(weight, ptr, oc_offset, args...);
}
......
......@@ -34,6 +34,9 @@ struct Vmlal_s16 {
struct Vld1q_s8 {
static int8x16_t impl(const int8_t* ptr) { return vld1q_s8(ptr); }
};
struct Vld1q_f32 {
static float32x4_t impl(const float32_t* ptr) { return vld1q_f32(ptr); }
};
struct Vld1_s8 {
static int8x8_t impl(const int8_t* ptr) { return vld1_s8(ptr); }
};
......@@ -50,5 +53,13 @@ struct Vldq_tbl_low_s8 {
struct Vld1_dup_s8_s16 {
static int16x8_t impl(const int8_t* ptr) { return vld1_dup_s8_s16(ptr); }
};
struct Vfmaq_laneq_f32 {
template <const int lane>
static float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v) {
return vfmaq_laneq_f32(a, b, v, lane);
}
};
} // namespace
} // namespace megdnn
\ No newline at end of file
......@@ -71,6 +71,7 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj {
AlgoF32DirectStride2 f32_direct_stride2_small_group{false};
AlgoF32DirectStride1 f32_direct_stride1_large_group{true};
AlgoF32DirectStride1 f32_direct_stride1_small_group{false};
AlgoF32DirectStride2NCHWNCHW44 f32_direct_stride2_nchw_nchw44;
AlgoI8x8x16Direct i8x8x16_direct_large_group{true};
AlgoI8x8x16Direct i8x8x16_direct_small_group{false};
AlgoI8x8x16Stride2 i8x8x16_stride2_large_group{true};
......@@ -123,6 +124,7 @@ public:
direct_algos.emplace_back(&i8x8x16_stride2_filter2);
direct_algos.emplace_back(&i8x8x16_stride2_large_group);
direct_algos.emplace_back(&i8x8x16_stride2_small_group);
direct_algos.emplace_back(&f32_direct_stride2_nchw_nchw44);
direct_algos.emplace_back(&f32_direct_stride1_large_group);
direct_algos.emplace_back(&f32_direct_stride1_small_group);
direct_algos.emplace_back(&f32_direct_stride2_large_group);
......
......@@ -67,6 +67,7 @@ private:
class AlgoF32Direct;
class AlgoF32DirectStride1;
class AlgoF32DirectStride2;
class AlgoF32DirectStride2NCHWNCHW44;
class AlgoI8x8x16Direct;
class AlgoI8x8x16Stride2;
class AlgoI8x8x16Stride2Filter2;
......
......@@ -45,13 +45,17 @@ struct HSwishOp;
vst1q_##_func_suffix(dst, vitem.val[0]); \
vst1q_##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \
} \
void operator()(const _neon_type& src, _ctype* dst) const { \
auto vitem = operator()(src); \
vst1q_##_func_suffix(dst, vitem); \
} \
_neon_type2 operator()(const _neon_type2& src) const { \
auto val1 = src.val[0]; \
auto val2 = src.val[1]; \
H_SWISH_KERN(_func_suffix, val1, val2); \
return {{val1, val2}}; \
} \
_neon_type operator()(const _neon_type& src) { \
_neon_type operator()(const _neon_type& src) const { \
auto val_zero = vdupq_n_##_func_suffix(0.f); \
auto val_six = vdupq_n_##_func_suffix(6.f); \
auto val_three = vdupq_n_##_func_suffix(3.f); \
......@@ -64,6 +68,7 @@ struct HSwishOp;
val_rec_six); \
} \
};
OP(dt_float32, float32x4_t, float32x4x2_t, f32, 4)
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
OP(__fp16, float16x8_t, float16x8x2_t, f16, 8)
......
......@@ -6,7 +6,8 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
......@@ -30,6 +31,13 @@ struct NoneOp;
using NoneOpBase::operator(); \
constexpr static size_t SIMD_WIDTH = _simd_width; \
_neon_type2 operator()(const _neon_type2& src) const { return src; } \
void operator()(const _neon_type2& src, _ctype* dst) const { \
vst1q_##_func_suffix(dst, src.val[0]); \
vst1q_##_func_suffix(dst + SIMD_WIDTH, src.val[1]); \
} \
void operator()(const _neon_type& src, _ctype* dst) const { \
vst1q_##_func_suffix(dst, src); \
} \
_neon_type operator()(const _neon_type& src) const { return src; } \
};
......
......@@ -47,11 +47,16 @@ struct ReluOp;
auto vitem1 = vmaxq_##_func_suffix(src.val[1], vzero); \
return {{vitem0, vitem1}}; \
} \
void operator()(const _neon_type& src, _ctype* dst) const { \
auto vitem = operator()(src); \
vst1q_##_func_suffix(dst, vitem); \
} \
_neon_type operator()(const _neon_type& src) const { \
auto vzero = vdupq_n_##_func_suffix(0); \
return vmaxq_##_func_suffix(src, vzero); \
} \
};
OP(dt_float32, float32x4_t, float32x4x2_t, f32, 4)
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
OP(__fp16, float16x8_t, float16x8x2_t, f16, 8)
......
......@@ -479,6 +479,39 @@ UNROLL_CALL_RAW(4, cb);
#undef cb
} // namespace
#define vdup_laneq_s16(vec, lane) Vdup_laneq_s16_armv7<lane>::impl(vec)
namespace {
template <int lane>
struct Vfmap_laneq_f32_armv7 {
static float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v);
};
template <>
struct Vfmap_laneq_f32_armv7<0> {
static float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v) {
return vmlaq_lane_f32(a, b, vget_low_f32(v), 0);
}
};
template <>
struct Vfmap_laneq_f32_armv7<1> {
static float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v) {
return vmlaq_lane_f32(a, b, vget_low_f32(v), 1);
}
};
template <>
struct Vfmap_laneq_f32_armv7<2> {
static float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v) {
return vmlaq_lane_f32(a, b, vget_high_f32(v), 0);
}
};
template <>
struct Vfmap_laneq_f32_armv7<3> {
static float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v) {
return vmlaq_lane_f32(a, b, vget_high_f32(v), 1);
}
};
} // namespace
#define vfmaq_laneq_f32(a, b, v, lane) \
Vfmap_laneq_f32_armv7<lane>::impl(a, b, v)
#endif
......
......@@ -85,7 +85,7 @@ TEST_F(ARM_COMMON, CONV_BIAS_MATMUL_QU8) {
#if MEGDNN_WITH_BENCHMARK
static void benchmark_convbias(Handle* handle) {
static void benchmark_convbias(Handle* handle, bool is_fp32 = false) {
constexpr size_t RUNS = 30;
Benchmarker<ConvBias> benchmarker_int(handle);
......@@ -102,15 +102,25 @@ static void benchmark_convbias(Handle* handle) {
Benchmarker<ConvBias> benchmarker_float(handle);
benchmarker_float.set_display(false).set_times(RUNS);
benchmarker_float.set_before_exec_callback(
conv_bias::ConvBiasAlgoChecker<ConvBias>(".+"));
conv_bias::ConvBiasAlgoChecker<ConvBias>(
"IM2COLMATMUL:AARCH64_F32K8X12X1:192"));
Benchmarker<ConvBias> benchmarker_int_nchw44(handle);
benchmarker_int_nchw44.set_times(RUNS)
.set_dtype(0, dtype::QuantizedS8(2.5))
.set_dtype(1, dtype::QuantizedS8(2.5))
.set_dtype(2, dtype::QuantizedS32(6.25))
.set_dtype(4, dtype::QuantizedS8(60.25))
.set_display(false);
if (is_fp32) {
benchmarker_int_nchw44.set_times(RUNS)
.set_dtype(0, dtype::Float32())
.set_dtype(1, dtype::Float32())
.set_dtype(2, dtype::Float32())
.set_dtype(4, dtype::Float32())
.set_display(false);
} else {
benchmarker_int_nchw44.set_times(RUNS)
.set_dtype(0, dtype::QuantizedS8(2.5))
.set_dtype(1, dtype::QuantizedS8(2.5))
.set_dtype(2, dtype::QuantizedS32(6.25))
.set_dtype(4, dtype::QuantizedS8(60.25))
.set_display(false);
}
benchmarker_int_nchw44.set_before_exec_callback(
conv_bias::ConvBiasAlgoChecker<ConvBias>(".+"));
......@@ -151,7 +161,6 @@ static void benchmark_convbias(Handle* handle) {
auto int_nchw44_used = benchmarker_int_nchw44.set_param(param).exec(
{src, filter, bias, {}, dst}) /
RUNS;
float computations = IC * (FS * FS) * dst.total_nr_elems() * 2 * 1e-6;
printf("run: %s %s %s->%s \n", src.to_string().c_str(),
filter.to_string().c_str(), bias.to_string().c_str(),
......@@ -160,32 +169,42 @@ static void benchmark_convbias(Handle* handle) {
computations / float_used);
printf("int_nchw: %f ms %f Gflops, ", int_used,
computations / int_used);
printf("int_nchw44: %f ms %f Gflops %f speedup, ", int_nchw44_used,
computations / int_nchw44_used, int_used / int_nchw44_used);
auto speed_up = int_used / int_nchw44_used;
if (is_fp32) {
speed_up = float_used / int_nchw44_used;
printf("fp32_nchw44: %f ms %f Gflops %f speedup, ", int_nchw44_used,
computations / int_nchw44_used, speed_up);
} else {
printf("int_nchw44: %f ms %f Gflops %f speedup, ", int_nchw44_used,
computations / int_nchw44_used, speed_up);
}
printf("\n");
};
run(1, 3, 32, 224, 224, 3, 2, true);
run(1, 3, 64, 224, 224, 5, 2, true);
run(1, 3, 64, 224, 224, 7, 2, true);
run(1, 3, 32, 224, 224, 7, 2, true);
for (size_t stride : {1, 2}) {
printf("stride %zu\n", stride);
for (size_t filter_size : {2, 3, 5, 7}) {
for (size_t img_size : {32}) {
for (size_t channel : {8, 16, 32, 64, 128, 256}) {
run(1, channel, channel, img_size, img_size, filter_size,
stride, false);
if (is_fp32) {
run(1, 3, 32, 224, 224, 3, 2, true);
run(1, 3, 64, 224, 224, 7, 2, true);
} else {
for (size_t stride : {1, 2}) {
printf("stride %zu\n", stride);
for (size_t filter_size : {2, 3, 5, 7}) {
for (size_t img_size : {32}) {
for (size_t channel : {8, 16, 32, 64, 128, 256}) {
run(1, channel, channel, img_size, img_size,
filter_size, stride, false);
}
}
}
}
}
}
TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_NCHW44) {
benchmark_convbias(handle());
benchmark_convbias(handle(), true);
}
TEST_F(ARM_COMMON_MULTI_THREADS, BENCHMARK_CONVBIAS_NCHW44) {
benchmark_convbias(handle());
benchmark_convbias(handle(), true);
}
#endif
TEST_F(ARM_COMMON, CONV_BIAS_MATMUL_QS8) {
using namespace conv_bias;
......@@ -1464,7 +1483,8 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_QUINT8_STRIDE2_WITHDOTPROD) {
#if MEGDNN_WITH_BENCHMARK
namespace {
std::vector<conv_bias::TestArg> get_conv_bias_1x1_benchmark_args(size_t pack_size = 1) {
std::vector<conv_bias::TestArg> get_conv_bias_1x1_benchmark_args(
size_t pack_size = 1) {
using namespace conv_bias;
std::vector<TestArg> args;
param::ConvBias param;
......@@ -1474,15 +1494,17 @@ std::vector<conv_bias::TestArg> get_conv_bias_1x1_benchmark_args(size_t pack_siz
param.pad_w = 0;
param.nonlineMode = param::ConvBias::NonlineMode::IDENTITY;
auto bench_case = [&](size_t OC, size_t IC, size_t H, size_t W) {
if(pack_size == 1)
if (pack_size == 1)
args.emplace_back(param, TensorShape{1, IC, H, W},
TensorShape{OC, IC, 1, 1}, TensorShape{});
TensorShape{OC, IC, 1, 1}, TensorShape{});
else {
if(pack_size == 4)
if (pack_size == 4)
param.format = param::ConvBias::Format::NCHW44;
args.emplace_back(param, TensorShape{1, IC / pack_size, H, W, pack_size},
TensorShape{OC / pack_size, IC / pack_size, 1, 1, pack_size, pack_size},
TensorShape{});
args.emplace_back(param,
TensorShape{1, IC / pack_size, H, W, pack_size},
TensorShape{OC / pack_size, IC / pack_size, 1, 1,
pack_size, pack_size},
TensorShape{});
}
};
......
......@@ -78,9 +78,10 @@ std::vector<conv_bias::TestArg> get_nchw44_conv_bias_args(
std::vector<TestArg> args;
auto pack = [&](size_t n, size_t oc, size_t ic, size_t h, size_t w,
size_t kernel, size_t stride, size_t group, NLMode nlmode) {
size_t kernel, size_t stride, size_t group, NLMode nlmode,
int any_pad = -1) {
constexpr int pack_c = 4;
const size_t pad = no_pad ? 0 : kernel / 2;
const size_t pad = any_pad >= 0 ? any_pad : kernel / 2;
auto bias_mode = no_bias ? megdnn::BiasMode::NO_BIAS
: megdnn::BiasMode::BROADCAST_CHANNEL_BIAS;
auto oc_per_group = oc / group;
......@@ -90,7 +91,8 @@ std::vector<conv_bias::TestArg> get_nchw44_conv_bias_args(
ic_per_group > 0;
bool nchw_disable = group > 1 || ic_per_group >= 4;
bool nchw44_disable = ic_per_group % pack_c != 0;
if (!(ok_group)) {
bool invalid_pad = (w + 2 * pad < kernel) || (h + 2 * pad < kernel);
if (!(ok_group) || invalid_pad) {
return;
}
if ((is_input_nchw && nchw_disable) ||
......@@ -107,6 +109,7 @@ std::vector<conv_bias::TestArg> get_nchw44_conv_bias_args(
param.pad_h = pad;
param.pad_w = pad;
param.nonlineMode = nlmode;
auto src_tensor_shape = TensorShape{n, ic / pack_c, h, w, pack_c};
auto weight_tensor_shape = TensorShape{
oc / pack_c, ic / pack_c, kernel_h, kernel_w, pack_c, pack_c};
......@@ -338,6 +341,11 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_STR2_SMALL_GROUP) {
check_conv_bias(get_conv_bias_args({2, 3, 5, 7}, 2, false, false, false),
handle(), "F32STRD2_SMALL_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_NCHW_NCHW44_F32) {
check_conv_bias(
get_nchw44_conv_bias_args({3, 5, 7}, 2, false, false, false, true),
handle(), "F32_CONV_NCHW_NCHW44");
}
/**********************************F16 direct************************/
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16_LARGE_GROUP) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册