提交 6c29548d 编写于 作者: M Megvii Engine Team 提交者: Xu Xinran

fix(dnn/arm): fix nchw_nchw44 dot stride1 support

GitOrigin-RevId: c8d3d55b258e2a43c27b903808566f2ea1857842
上级 02cbb13b
/**
* \file dnn/src/arm_common/conv_bias/block_helper.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/common/utils.h"
namespace megdnn {
namespace {
// block_helper is used to calculate oh block size
static inline int l2_block_helper(const int nthread, const int amount,
const int size_per_unit) {
constexpr int l2_cache_size = 256 * 1024;
const int block_per_thread = div_ceil(amount, nthread);
const int best_block = std::min(
amount, (l2_cache_size + size_per_unit / 2) / size_per_unit);
const int max_block_num = div_ceil(block_per_thread, best_block);
const int min_block_num = std::max(max_block_num - 1, 1);
const int max_block = div_ceil(block_per_thread, max_block_num);
const int min_block = div_ceil(block_per_thread, min_block_num);
const int max_loss = std::abs(max_block_num * max_block - block_per_thread);
const int min_loss = std::abs(min_block_num * min_block - block_per_thread);
int block = max_loss > min_loss ? min_block : max_block;
return block;
}
} // namespace
} // namespace megdnn
// vim: syntax=cpp.doxygen
\ No newline at end of file
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
*/ */
#include "megdnn/oprs.h" #include "megdnn/oprs.h"
#include "src/arm_common/conv_bias/block_helper.h"
#include "src/arm_common/conv_bias/fp32/algos.h" #include "src/arm_common/conv_bias/fp32/algos.h"
#include "src/arm_common/conv_bias/fp32/f32_direct_stride1_nchw44_kern.h" #include "src/arm_common/conv_bias/fp32/f32_direct_stride1_nchw44_kern.h"
#include "src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw44_kern.h" #include "src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw44_kern.h"
...@@ -26,22 +27,7 @@ using conv_fun = std::function<void( ...@@ -26,22 +27,7 @@ using conv_fun = std::function<void(
const CpuNDRange& workspace_ids, const CpuNDRange& ncb_range)>; const CpuNDRange& workspace_ids, const CpuNDRange& ncb_range)>;
MIDOUT_DECL(megdnn_arm_common_conv_bias_fp32_nchw44_stride1) MIDOUT_DECL(megdnn_arm_common_conv_bias_fp32_nchw44_stride1)
namespace { namespace {
// block_helper is used to calculate oh block size
static inline int block_helper(const int nthread, const int amount,
const int size_per_unit) {
constexpr int l2_cache_size = 256 * 1024;
const int block_per_thread = div_ceil(amount, nthread);
const int best_block = std::min(
amount, (l2_cache_size + size_per_unit / 2) / size_per_unit);
const int max_block_num = div_ceil(block_per_thread, best_block);
const int min_block_num = std::max(max_block_num - 1, 1);
const int max_block = div_ceil(block_per_thread, max_block_num);
const int min_block = div_ceil(block_per_thread, min_block_num);
const int max_loss = std::abs(max_block_num * max_block - block_per_thread);
const int min_loss = std::abs(min_block_num * min_block - block_per_thread);
int block = max_loss > min_loss ? min_block : max_block;
return block;
}
static inline size_t get_perthread_cache_bytes(const int ic, const int ih2, static inline size_t get_perthread_cache_bytes(const int ic, const int ih2,
const int iw2) { const int iw2) {
// border_size is used to avoid read illegal memory // border_size is used to avoid read illegal memory
...@@ -60,7 +46,7 @@ static void get_rectified_size( ...@@ -60,7 +46,7 @@ static void get_rectified_size(
ow2 = ow; ow2 = ow;
constexpr int cacheline = 64 / sizeof(float); constexpr int cacheline = 64 / sizeof(float);
int block_oh = int block_oh =
block_helper(param.nr_threads, oh, ic * iw * sizeof(float) * 2); l2_block_helper(param.nr_threads, oh, ic * iw * sizeof(float) * 2);
auto&& fm = param.filter_meta; auto&& fm = param.filter_meta;
const int stride_h = static_cast<int>(fm.stride[0]); const int stride_h = static_cast<int>(fm.stride[0]);
const int filter_h = static_cast<int>(fm.spatial[0]); const int filter_h = static_cast<int>(fm.spatial[0]);
...@@ -106,8 +92,8 @@ static void do_conv_kern(WorkspaceBundle bundle, ...@@ -106,8 +92,8 @@ static void do_conv_kern(WorkspaceBundle bundle,
const int group_id = ncb_index.ndrange_id[1]; const int group_id = ncb_index.ndrange_id[1];
constexpr int oc_idx = 0; constexpr int oc_idx = 0;
int oc_block = oc; int oc_block = oc;
int oh_block = block_helper(kern_param.nr_threads, oh2, int oh_block = l2_block_helper(kern_param.nr_threads, oh2,
ic * iw * sizeof(float) * stride_h); ic * iw * sizeof(float) * stride_h);
const int oh_idx = ncb_index.ndrange_id[2]; const int oh_idx = ncb_index.ndrange_id[2];
const int oh_block_real = std::min(oh - oh_idx * oh_block, oh_block); const int oh_block_real = std::min(oh - oh_idx * oh_block, oh_block);
const int ih_real = oh_block_real * stride_h + fh - stride_h; const int ih_real = oh_block_real * stride_h + fh - stride_h;
...@@ -298,8 +284,8 @@ ConvBiasImpl::AlgoF32DirectNCHW44::dispatch_kerns( ...@@ -298,8 +284,8 @@ ConvBiasImpl::AlgoF32DirectNCHW44::dispatch_kerns(
int ic = param.filter_meta.icpg; int ic = param.filter_meta.icpg;
int iw = param.isz[1]; int iw = param.isz[1];
int stride_h = param.filter_meta.stride[0]; int stride_h = param.filter_meta.stride[0];
int oh_block = block_helper(param.nr_threads, oh, int oh_block = l2_block_helper(param.nr_threads, oh,
ic * iw * sizeof(float) * stride_h); ic * iw * sizeof(float) * stride_h);
CpuNDRange ncb_range = {static_cast<size_t>(batch), CpuNDRange ncb_range = {static_cast<size_t>(batch),
static_cast<size_t>(group), static_cast<size_t>(group),
static_cast<size_t>(div_ceil(oh, oh_block))}; static_cast<size_t>(div_ceil(oh, oh_block))};
......
...@@ -133,6 +133,21 @@ public: ...@@ -133,6 +133,21 @@ public:
}; };
#if __ARM_FEATURE_DOTPROD #if __ARM_FEATURE_DOTPROD
class ConvBiasImpl::AlgoDotS8DirectNCHWNCHW44 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
const char* name() const override { return "ARMDOTS8_NCHW_NCHW44"; }
bool usable(FallbackConvBiasImpl*, const NCBKernSizeParam&,
AlgoSelectionStrategy algo_selection_strategy) const override;
size_t get_workspace(FallbackConvBiasImpl*,
const NCBKernSizeParam&) const override;
virtual SmallVector<NCBKern> dispatch_kerns(
fallback::ConvBiasImpl* opr,
const NCBKernSizeParam& param) const override;
};
class ConvBiasImpl::AlgoDotS8DirectStride1 final : public AlgoBase { class ConvBiasImpl::AlgoDotS8DirectStride1 final : public AlgoBase {
bool m_large_group; bool m_large_group;
......
/**
* \file
* dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_algo.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express
* or implied.
*/
#if __ARM_FEATURE_DOTPROD
#include "megdnn/oprs.h"
#include "src/arm_common/conv_bias/block_helper.h"
#include "src/arm_common/conv_bias/int8/algos.h"
#include "src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_kern.h"
#include "src/arm_common/elemwise_op.h"
#include "midout.h"
using namespace megdnn;
using namespace arm_common;
using conv_fun = std::function<void(
WorkspaceBundle bundle, const ConvBiasImpl::NCBKernParam& kern_param,
const ConvBiasImpl::NCBKernIndex& ncb_index,
const CpuNDRange& workspace_ids, const CpuNDRange& ncb_range)>;
MIDOUT_DECL(megdnn_arm_common_conv_bias_int8_nchw44_dot)
namespace {
static inline size_t get_perthread_cache_bytes(const int ic, const int ih2,
const int iw2,
const int stride) {
//! border_size is used to avoid read illegal memory
constexpr int cacheline_size = 64;
constexpr int border_size = 2 * cacheline_size;
const int pack_iw_len = stride == 1 ? 4 : 1;
return round_up(
ic * ih2 * iw2 * pack_iw_len * (int)sizeof(int8_t) + border_size,
cacheline_size);
}
static inline size_t get_temp_bytes(const int iw, const int pw) {
//! border_size is used to avoid read illegal memory
constexpr int cacheline_size = 64;
constexpr int border_size = 1 * cacheline_size;
return round_up(iw + pw * 2, cacheline_size) + border_size;
}
static void get_rectified_size(
const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param, int& ih2,
int& iw2) {
auto&& fm = param.filter_meta;
const int stride_h = static_cast<int>(fm.stride[0]);
const int filter_h = static_cast<int>(fm.spatial[0]);
int ic = param.filter_meta.icpg;
int iw = param.isz[1];
int oh = param.osz[0];
int block_oh = l2_block_helper(param.nr_threads, oh,
ic * iw * sizeof(int8_t) * stride_h);
ih2 = block_oh * stride_h + filter_h - stride_h;
iw2 = iw + 2 * static_cast<int>(fm.padding[1]);
}
static WorkspaceBundle get_bundle(const ConvBiasImpl::NCBKernSizeParam& param) {
auto&& fm = param.filter_meta;
int ic = fm.icpg;
int fh = fm.spatial[0];
int fw = fm.spatial[1];
int iw = param.isz[1];
int pw = param.filter_meta.padding[1];
int stride_w = param.filter_meta.stride[1];
int ih2, iw2;
get_rectified_size(param, ih2, iw2);
size_t src_size = get_perthread_cache_bytes(ic, ih2, iw2, stride_w);
size_t weight_size = fm.group * fm.icpg * fm.ocpg * fh * round_up(fw, 4);
size_t temp_size = 0;
if (fm.stride[0] == 1) {
temp_size = get_temp_bytes(iw, pw);
}
return {nullptr,
{src_size * param.nr_threads, weight_size,
temp_size * param.nr_threads}};
};
void do_weight_trans(WorkspaceBundle bundle,
const ConvBiasImpl::NCBKernParam& kern_param,
const ConvBiasImpl::NCBKernIndex&, const CpuNDRange&) {
const int ic = kern_param.filter_meta.icpg;
const int oc = kern_param.filter_meta.ocpg;
const int fh = kern_param.filter_meta.spatial[0];
const int fw = kern_param.filter_meta.spatial[1];
const int fw2 = round_up(fw, 4);
bundle.set(kern_param.workspace_ptr);
auto packed_weight = reinterpret_cast<int8_t*>(bundle.get(1));
auto origin_weight = kern_param.filter<dt_int8>();
pack_weight_int8_nchw_nchw44_dot(packed_weight, origin_weight, oc, ic, fh,
fw, fw2);
}
template <size_t filter, BiasMode bias_mode, typename Op, int stride>
static void do_conv_kern(WorkspaceBundle bundle,
const ConvBiasImpl::NCBKernParam& kern_param,
const ConvBiasImpl::NCBKernIndex& ncb_index,
const CpuNDRange&, const CpuNDRange&) {
const int oh = kern_param.osz[0];
const int ow = kern_param.osz[1];
const int fh = kern_param.filter_meta.spatial[0];
const int fw = kern_param.filter_meta.spatial[1];
const int ic = kern_param.filter_meta.icpg;
const int oc = kern_param.filter_meta.ocpg;
const int ih = kern_param.isz[0];
const int iw = kern_param.isz[1];
const int stride_h = kern_param.filter_meta.stride[0];
const int stride_w = kern_param.filter_meta.stride[1];
const int ph = kern_param.filter_meta.padding[0];
const int pw = kern_param.filter_meta.padding[1];
int ih2 = 0;
int iw2 = 0;
get_rectified_size(kern_param, ih2, iw2);
bundle.set(kern_param.workspace_ptr);
constexpr int pack_c = 4;
const int batch_id = ncb_index.ndrange_id[0];
const int group_id = ncb_index.ndrange_id[1];
constexpr int oc_idx = 0;
int oc_block = oc;
int oh_block = l2_block_helper(kern_param.nr_threads, oh,
ic * iw * sizeof(int8_t) * stride_h);
const int oh_idx = ncb_index.ndrange_id[2];
const int oh_block_real = std::min(oh - oh_idx * oh_block, oh_block);
const int ih_real = oh_block_real * stride_h + fh - stride_h;
const int src_top_pad = std::max(ph - oh_idx * oh_block * stride_h, 0);
const int src_bottom_pad = std::max(
(oh_idx * oh_block + oh_block_real - 1) * stride_h + fh - ih - ph,
0);
const int remain_right_pad = std::max(iw2 - iw - pw, 0);
const int src_offset = std::max(oh_idx * oh_block * stride_h - ph, 0) * iw;
const int8_t* origin_sptr =
static_cast<const int8_t*>(
kern_param.src<int8_t>(batch_id, group_id, 0, 1, 1)) +
src_offset;
const size_t src_size = get_perthread_cache_bytes(ic, ih2, iw2, stride_w);
int8_t* sptr = reinterpret_cast<int8_t*>(bundle.get(0)) +
ncb_index.thread_id * src_size;
int8_t* tmp_ptr = nullptr;
if (stride == 1) {
const size_t tmp_size = get_temp_bytes(iw, pw);
tmp_ptr = reinterpret_cast<int8_t*>(bundle.get(2)) +
ncb_index.thread_id * tmp_size;
}
pack_src_int8_nchw_nchw44_dot<stride>(
sptr, origin_sptr, ph, pw, remain_right_pad,
ih_real - src_top_pad - src_bottom_pad, iw, iw2, src_top_pad,
src_bottom_pad, ic, ih * iw, tmp_ptr);
const int8_t* fptr =
reinterpret_cast<int8_t*>(bundle.get(1)) + oc_idx * fh * fw * ic;
int8_t* dst = kern_param.dst<int8_t>(batch_id, group_id) +
oh_idx * oh_block * ow * pack_c;
const int bias_offset = oc_idx;
const int32_t* bptr =
kern_param.bias<dt_int32>(batch_id, group_id) + bias_offset;
float scale_bias = kern_param.bias_type.param<dtype::QuantizedS32>().scale;
float scale_dst = kern_param.dst_type.param<dtype::QuantizedS8>().scale;
Op op(scale_bias, scale_dst);
conv_direct_int8_nchw_nchw44_dot<bias_mode, Op, filter, stride>(
sptr, fptr, bptr, nullptr, dst, oc_block, ic, ih_real, iw2, oh,
oh_block_real, ow, op);
}
} // namespace
bool ConvBiasImpl::AlgoDotS8DirectNCHWNCHW44::usable(
fallback::ConvBiasImpl*, const NCBKernSizeParam& param,
AlgoSelectionStrategy) const {
auto&& fm = param.filter_meta;
auto fh = fm.spatial[0];
int oc = fm.ocpg;
int ic = fm.icpg;
bool ok_type = ((param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
param.filter_type.enumv() == DTypeEnum::QuantizedS8 &&
(param.dst_type.enumv() == DTypeEnum::QuantizedS8))) &&
(fm.format == param::Convolution::Format::NCHW44);
bool ok_src_dst = (oc % 4 == 0 && oc >= 4 && ic < 4);
bool ok_filter = fm.spatial_ndim == 2 && fh == fm.spatial[1] &&
(fh == 2 || fh == 3 || fh == 5 || fh == 7);
bool ok_slide = fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
fm.stride[0] == fm.stride[1] &&
(fm.stride[0] == 1 || fm.stride[0] == 2);
bool ok_conv = !fm.should_flip && param.bias_mode != BiasMode::BIAS;
bool avaible = ok_type && ok_src_dst && ok_filter && ok_slide && ok_conv;
return avaible;
}
size_t ConvBiasImpl::AlgoDotS8DirectNCHWNCHW44::get_workspace(
fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const {
return get_bundle(param).total_size_in_bytes();
}
SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoDotS8DirectNCHWNCHW44::dispatch_kerns(
fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const {
auto fm = param.filter_meta;
const int batch = param.n;
const int group = fm.group;
WorkspaceBundle wbundle = get_bundle(param);
conv_fun do_conv_fun = nullptr;
// NOTE: remain_w is not used to gen hash of midout for compatible with
// shape runtime
#define DO_CONV_KERN_FUN(stride, filter, bias_mode, op) \
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8_nchw44_dot, \
midout_iv(#stride #filter #bias_mode #op##_hash)) { \
do_conv_fun = do_conv_kern<filter, bias_mode, op, stride>; \
} \
MIDOUT_END();
#define GET_OP_PARAM(stride, filter, bias_mode) \
switch (param.nonlineMode) { \
case param::ConvBias::NonlineMode::IDENTITY: \
DO_CONV_KERN_FUN(stride, filter, bias_mode, \
TypeCvtOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
break; \
case param::ConvBias::NonlineMode::RELU: \
DO_CONV_KERN_FUN(stride, filter, bias_mode, \
ReluOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
break; \
case param::ConvBias::NonlineMode::H_SWISH: \
DO_CONV_KERN_FUN(stride, filter, bias_mode, \
HSwishOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
break; \
default: \
megdnn_assert(0); \
break; \
}
#define GET_BIAS_MODE_PARAM(stride, filter) \
switch (param.bias_mode) { \
case BiasMode::NO_BIAS: \
GET_OP_PARAM(stride, filter, BiasMode::NO_BIAS) \
break; \
case BiasMode::BROADCAST_CHANNEL_BIAS: \
GET_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS) \
break; \
default: \
megdnn_assert(0); \
break; \
}
#define DISPATCH_CONV_KERN(stride) \
switch (param.filter_meta.spatial[0]) { \
case 2: \
GET_BIAS_MODE_PARAM(stride, 2) \
break; \
case 3: \
GET_BIAS_MODE_PARAM(stride, 3) \
break; \
case 5: \
GET_BIAS_MODE_PARAM(stride, 5) \
break; \
case 7: \
GET_BIAS_MODE_PARAM(stride, 7) \
break; \
default: \
megdnn_assert(0); \
break; \
}
switch (param.filter_meta.stride[0]) {
case 1:
DISPATCH_CONV_KERN(1);
break;
case 2:
DISPATCH_CONV_KERN(2);
break;
default:
megdnn_assert(0);
break;
}
#undef DO_CONV_KERN_FUN
#undef GET_REMAIN_W_PARAM
#undef GET_OP_PARAM
#undef GET_BIAS_MODE_PARAM
#undef DISPATCH_CONV_KERN
megdnn_assert(do_conv_fun);
SmallVector<ConvBiasImpl::NCBKern> ret_kerns;
WorkspaceBundle bundle = wbundle;
int oh = param.osz[0];
int ic = param.filter_meta.icpg;
int iw = param.isz[1];
int stride_h = param.filter_meta.stride[0];
int oh_block = l2_block_helper(param.nr_threads, oh,
ic * iw * sizeof(int8_t) * stride_h);
CpuNDRange ncb_range = {static_cast<size_t>(batch),
static_cast<size_t>(group),
static_cast<size_t>(div_ceil(oh, oh_block))};
auto do_trans_weight = [bundle](const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index) {
do_weight_trans(bundle, kern_param, ncb_index, ncb_index.ndrange_id);
};
ret_kerns.push_back({do_trans_weight, {1}});
auto do_conv = [bundle, do_conv_fun, ncb_range](
const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index) {
do_conv_fun(bundle, kern_param, ncb_index, ncb_index.ndrange_id,
ncb_range);
};
ret_kerns.push_back({do_conv, ncb_range});
return ret_kerns;
}
#endif
// vim: syntax=cpp.doxygen
...@@ -55,6 +55,7 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj { ...@@ -55,6 +55,7 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj {
AlgoS8ChanWiseStride2NCHW44 s8_channel_wise_stride2_nchw44; AlgoS8ChanWiseStride2NCHW44 s8_channel_wise_stride2_nchw44;
#if __ARM_FEATURE_DOTPROD #if __ARM_FEATURE_DOTPROD
AlgoDotS8DirectNCHWNCHW44 ds8_direct_stride2_nchw_nchw44;
AlgoDotS8DirectStride1 ds8_direct_stride1_large_group{true}; AlgoDotS8DirectStride1 ds8_direct_stride1_large_group{true};
AlgoDotS8DirectStride1 ds8_direct_stride1_small_group{false}; AlgoDotS8DirectStride1 ds8_direct_stride1_small_group{false};
AlgoDotS8DirectStride2 ds8_direct_stride2_large_group{true}; AlgoDotS8DirectStride2 ds8_direct_stride2_large_group{true};
...@@ -93,6 +94,7 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj { ...@@ -93,6 +94,7 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj {
public: public:
AlgoPack() { AlgoPack() {
#if __ARM_FEATURE_DOTPROD #if __ARM_FEATURE_DOTPROD
direct_algos.emplace_back(&ds8_direct_stride2_nchw_nchw44);
direct_algos.emplace_back(&ds8_direct_stride1_large_group); direct_algos.emplace_back(&ds8_direct_stride1_large_group);
direct_algos.emplace_back(&ds8_direct_stride1_small_group); direct_algos.emplace_back(&ds8_direct_stride1_small_group);
direct_algos.emplace_back(&ds8_direct_stride2_large_group); direct_algos.emplace_back(&ds8_direct_stride2_large_group);
......
...@@ -62,6 +62,7 @@ private: ...@@ -62,6 +62,7 @@ private:
class AlgoFP16WinogradF23_8x8; class AlgoFP16WinogradF23_8x8;
#endif #endif
#if __ARM_FEATURE_DOTPROD #if __ARM_FEATURE_DOTPROD
class AlgoDotS8DirectNCHWNCHW44;
class AlgoDotS8DirectStride1; class AlgoDotS8DirectStride1;
class AlgoDotS8DirectStride2; class AlgoDotS8DirectStride2;
class AlgoDotU8DirectStride1; class AlgoDotU8DirectStride1;
......
...@@ -60,6 +60,14 @@ struct Vfmaq_laneq_f32 { ...@@ -60,6 +60,14 @@ struct Vfmaq_laneq_f32 {
return vfmaq_laneq_f32(a, b, v, lane); return vfmaq_laneq_f32(a, b, v, lane);
} }
}; };
#if __ARM_FEATURE_DOTPROD
struct Vdotq_laneq_s32 {
template <const int lane>
static int32x4_t impl(int32x4_t a, int8x16_t b, int8x16_t v) {
return vdotq_laneq_s32(a, b, v, lane);
}
};
#endif
} // namespace } // namespace
} // namespace megdnn } // namespace megdnn
......
...@@ -481,37 +481,71 @@ UNROLL_CALL_RAW(4, cb); ...@@ -481,37 +481,71 @@ UNROLL_CALL_RAW(4, cb);
#define vdup_laneq_s16(vec, lane) Vdup_laneq_s16_armv7<lane>::impl(vec) #define vdup_laneq_s16(vec, lane) Vdup_laneq_s16_armv7<lane>::impl(vec)
namespace { namespace {
template <int lane> template <int lane>
struct Vfmap_laneq_f32_armv7 { struct Vfmaq_laneq_f32_armv7 {
static float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v); static float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v);
}; };
template <> template <>
struct Vfmap_laneq_f32_armv7<0> { struct Vfmaq_laneq_f32_armv7<0> {
static float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v) { static float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v) {
return vmlaq_lane_f32(a, b, vget_low_f32(v), 0); return vmlaq_lane_f32(a, b, vget_low_f32(v), 0);
} }
}; };
template <> template <>
struct Vfmap_laneq_f32_armv7<1> { struct Vfmaq_laneq_f32_armv7<1> {
static float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v) { static float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v) {
return vmlaq_lane_f32(a, b, vget_low_f32(v), 1); return vmlaq_lane_f32(a, b, vget_low_f32(v), 1);
} }
}; };
template <> template <>
struct Vfmap_laneq_f32_armv7<2> { struct Vfmaq_laneq_f32_armv7<2> {
static float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v) { static float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v) {
return vmlaq_lane_f32(a, b, vget_high_f32(v), 0); return vmlaq_lane_f32(a, b, vget_high_f32(v), 0);
} }
}; };
template <> template <>
struct Vfmap_laneq_f32_armv7<3> { struct Vfmaq_laneq_f32_armv7<3> {
static float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v) { static float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v) {
return vmlaq_lane_f32(a, b, vget_high_f32(v), 1); return vmlaq_lane_f32(a, b, vget_high_f32(v), 1);
} }
}; };
} // namespace } // namespace
#define vfmaq_laneq_f32(a, b, v, lane) \ #define vfmaq_laneq_f32(a, b, v, lane) \
Vfmap_laneq_f32_armv7<lane>::impl(a, b, v) Vfmaq_laneq_f32_armv7<lane>::impl(a, b, v)
#if __ARM_FEATURE_DOTPROD
template <int lane>
struct Vdotq_laneq_s32_armv7 {
static int32x4_t impl(int32x4_t a, int8x16_t b, int8x16_t v);
};
template <>
struct Vdotq_laneq_s32_armv7<0> {
static int32x4_t impl(int32x4_t a, int8x16_t b, int8x16_t v) {
return vdotq_lane_s32(a, b, vget_low_s32(v), 0);
}
};
template <>
struct Vdotq_laneq_s32_armv7<1> {
static int32x4_t impl(int32x4_t a, int8x16_t b, int8x16_t v) {
return vdotq_lane_s32(a, b, vget_low_s32(v), 1);
}
};
template <>
struct Vdotq_laneq_s32_armv7<2> {
static int32x4_t impl(int32x4_t a, int8x16_t b, int8x16_t v) {
return vdotq_lane_s32(a, b, vget_high_s32(v), 0);
}
};
template <>
struct Vdotq_laneq_s32_armv7<3> {
static int32x4_t impl(int32x4_t a, int8x16_t b, int8x16_t v) {
return vdotq_lane_s32(a, b, vget_high_f32(v), 1);
}
};
#define vdotq_laneq_s32(a, b, v, lane) \
Vdotq_laneq_s32_armv7<lane>::impl(a, b, v)
#endif
#endif #endif
......
...@@ -109,14 +109,12 @@ static void benchmark_convbias(Handle* handle, bool is_fp32 = false) { ...@@ -109,14 +109,12 @@ static void benchmark_convbias(Handle* handle, bool is_fp32 = false) {
.set_dtype(4, dtype::QuantizedS8(60.25)) .set_dtype(4, dtype::QuantizedS8(60.25))
.set_display(false); .set_display(false);
benchmarker_int.set_before_exec_callback( benchmarker_int.set_before_exec_callback(
conv_bias::ConvBiasAlgoChecker<ConvBias>( conv_bias::ConvBiasAlgoChecker<ConvBias>("IM2COLMATMUL:.+"));
"IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16:384"));
Benchmarker<ConvBias> benchmarker_float(handle); Benchmarker<ConvBias> benchmarker_float(handle);
benchmarker_float.set_display(false).set_times(RUNS); benchmarker_float.set_display(false).set_times(RUNS);
benchmarker_float.set_before_exec_callback( benchmarker_float.set_before_exec_callback(
conv_bias::ConvBiasAlgoChecker<ConvBias>( conv_bias::ConvBiasAlgoChecker<ConvBias>("IM2COLMATMUL:.+"));
"IM2COLMATMUL:AARCH64_F32K8X12X1:192"));
Benchmarker<ConvBias> benchmarker_nchw44(handle); Benchmarker<ConvBias> benchmarker_nchw44(handle);
if (is_fp32) { if (is_fp32) {
...@@ -213,6 +211,15 @@ static void benchmark_convbias(Handle* handle, bool is_fp32 = false) { ...@@ -213,6 +211,15 @@ static void benchmark_convbias(Handle* handle, bool is_fp32 = false) {
run(1, 256, 256, 14, 14, 3, 1, false); run(1, 256, 256, 14, 14, 3, 1, false);
run(1, 512, 512, 7, 7, 3, 1, false); run(1, 512, 512, 7, 7, 3, 1, false);
} else { } else {
run(1, 1, 4, 112, 112, 2, 2, true);
run(1, 3, 32, 224, 224, 3, 2, true);
run(1, 3, 32, 224, 224, 5, 2, true);
run(1, 3, 64, 224, 224, 7, 2, true);
run(1, 1, 4, 112, 112, 2, 1, true);
run(1, 3, 32, 224, 224, 3, 1, true);
run(1, 3, 32, 224, 224, 5, 1, true);
run(1, 3, 64, 224, 224, 7, 1, true);
for (size_t stride : {1, 2}) { for (size_t stride : {1, 2}) {
printf("stride %zu\n", stride); printf("stride %zu\n", stride);
for (size_t filter_size : {2, 3, 5, 7}) { for (size_t filter_size : {2, 3, 5, 7}) {
...@@ -228,9 +235,11 @@ static void benchmark_convbias(Handle* handle, bool is_fp32 = false) { ...@@ -228,9 +235,11 @@ static void benchmark_convbias(Handle* handle, bool is_fp32 = false) {
} }
TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_NCHW44) { TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_NCHW44) {
benchmark_convbias(handle(), true); benchmark_convbias(handle(), true);
benchmark_convbias(handle(), false);
} }
TEST_F(ARM_COMMON_MULTI_THREADS, BENCHMARK_CONVBIAS_NCHW44) { TEST_F(ARM_COMMON_MULTI_THREADS, BENCHMARK_CONVBIAS_NCHW44) {
benchmark_convbias(handle(), true); benchmark_convbias(handle(), true);
benchmark_convbias(handle(), false);
} }
#endif #endif
......
...@@ -557,6 +557,16 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE2_SMALL_GROUP) { ...@@ -557,6 +557,16 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE2_SMALL_GROUP) {
/****************************dot qint8 direct*************************/ /****************************dot qint8 direct*************************/
#if __ARM_FEATURE_DOTPROD #if __ARM_FEATURE_DOTPROD
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_DOT_NCHW_NCHW44) {
checker_conv_bias_qint8x8x8(
get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, false, false,
true),
handle(), "ARMDOTS8_NCHW_NCHW44");
checker_conv_bias_qint8x8x8(
get_nchw44_conv_bias_args({2, 3, 5, 7}, 1, false, false, false,
true),
handle(), "ARMDOTS8_NCHW_NCHW44");
}
TEST_F(ARM_COMMON_MULTI_THREADS, TEST_F(ARM_COMMON_MULTI_THREADS,
CONV_BIAS_INT8_STRIDE1_WITHDOTPROD_LARGE_GROUP) { CONV_BIAS_INT8_STRIDE1_WITHDOTPROD_LARGE_GROUP) {
checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args( checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册