提交 9e876203 编写于 作者: M Megvii Engine Team 提交者: Xu Xinran

feat(dnn): add int8 direct conv dot nchw44

GitOrigin-RevId: 31830ba7a49f7c0b9fb3f011e09f934601a825a0
上级 09ceaaae
......@@ -189,6 +189,28 @@ public:
fallback::ConvBiasImpl* opr,
const NCBKernSizeParam& param) const override;
};
//! Int8 direct convolution using ARM dot-product instructions for the
//! NCHW44_DOT layout. Declaration only; the implementation lives in the
//! direct_dotprod_nchw44 algo source file.
class ConvBiasImpl::AlgoDotS8Direct_NCHW44 final : public AlgoBase {
public:
    AlgoDotS8Direct_NCHW44() {}
    //! same inputs always produce the same outputs
    bool is_reproducible() const override { return true; }
    const char* name() const override {
        return "ARMDOTS8DIRECT_NCHW44";
    }
    //! checks dtype / NCHW44_DOT layout / filter-shape / stride constraints
    bool usable(FallbackConvBiasImpl*, const NCBKernSizeParam&,
                AlgoSelectionStrategy algo_selection_strategy) const override;
    size_t get_workspace(FallbackConvBiasImpl*,
                         const NCBKernSizeParam&) const override;
    SmallVector<NCBKern> dispatch_kerns(
            fallback::ConvBiasImpl* opr,
            const NCBKernSizeParam& param) const override;
    bool is_preferred(megdnn::fallback::ConvBiasImpl*,
                      const NCBKernSizeParam& param) const override;
};
#endif
class ConvBiasImpl::AlgoS8WinogradF23_8x8 final : public AlgoBase {
......
/**
* \file dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#ifdef __ARM_FEATURE_DOTPROD
#include "src/arm_common/elemwise_helper/kimpl/typecvt.h"
#include "src/common/unroll_macro.h"
#include "src/common/utils.h"
#include "src/fallback/conv_bias/common.h"
#include "src/arm_common/conv_bias/int8/direct_dotprod_nchw44.h"
#include "src/arm_common/conv_bias/int8/direct_dotprod_nchw44_kern.h"
namespace megdnn {
namespace arm_common {
namespace direct_dotprod_nchw44 {
//! Stride-1 packed-src copy: every source row is copied verbatim into a
//! destination row that was first cleared, so left/right/top/bottom
//! padding regions end up zero. Layout is 4-channel-packed (one pixel ==
//! IC_PACK_SIZE int8 values); pad_right is implicitly handled by the
//! whole-row clear.
template <>
void copy_packed_src_int8_nchw44<1>(int8_t* dst, const int dst_step,
                                    const int8_t* src, const int src_step,
                                    const int ic, const int ic_step,
                                    const int ih, const int pad_left,
                                    const int pad_right, const int pad_top,
                                    const int pad_bottom) {
    constexpr int IC_PACK_SIZE = 4;
    //! all strides below are in bytes; sizeof(int8_t) == 1 so byte counts
    //! and element counts coincide
    const size_t dst_row_bytes = dst_step * IC_PACK_SIZE * sizeof(int8_t);
    const size_t src_row_bytes = src_step * IC_PACK_SIZE * sizeof(int8_t);
    const size_t top_bytes = pad_top * dst_row_bytes;
    const size_t bottom_bytes = pad_bottom * dst_row_bytes;
    const int left_elements = pad_left * IC_PACK_SIZE;
    for (int ic_idx = 0; ic_idx < ic; ic_idx += IC_PACK_SIZE) {
        const int8_t* row_src = src + ic_idx * ic_step;
        //! zero the rows belonging to the top padding
        memset(dst, 0, top_bytes);
        dst += top_bytes;
        for (int row = 0; row < ih; ++row) {
            //! clear the whole destination row first, then drop the real
            //! data after the left padding
            memset(dst, 0, dst_row_bytes);
            memcpy(dst + left_elements, row_src, src_row_bytes);
            dst += dst_row_bytes;
            row_src += src_row_bytes;
        }
        //! zero the rows belonging to the bottom padding
        memset(dst, 0, bottom_bytes);
        dst += bottom_bytes;
    }
}
//! Stride-2 packed-src copy: in addition to zero-padding, each row is
//! de-interleaved so even-indexed pixels and odd-indexed pixels land in
//! separate halves of the destination row (odd half starts at
//! `odd_start` pixels). One pixel is one int32 (4 packed int8 channels),
//! so vld2q_s32 splits 8 pixels per iteration into even/odd lanes.
template <>
void copy_packed_src_int8_nchw44<2>(int8_t* dst, const int dst_step,
                                    const int8_t* src, const int src_step,
                                    const int ic, const int ic_step,
                                    const int ih, const int pad_left,
                                    const int pad_right, const int pad_top,
                                    const int pad_bottom) {
    constexpr int IC_PACK_SIZE = 4;
    //! first pixel index (in int32 units) of the odd-column half of a row
    int odd_start = megdnn::div_ceil(dst_step, 2);
    //! if pad_left is even, source even columns stay even in dst;
    //! otherwise the parity flips and the roles of val[0]/val[1] swap
    bool nochange = pad_left % 2 == 0;
    rep_step(ic_idx, ic, IC_PACK_SIZE) {
        const int32_t* i_src =
                reinterpret_cast<const int32_t*>(src + ic_idx * ic_step);
        int bytes_pad_top = pad_top * dst_step * IC_PACK_SIZE * sizeof(int8_t);
        memset(dst, 0, bytes_pad_top);
        dst += bytes_pad_top / sizeof(int8_t);
        rep(ih_idx, ih) {
            int bytes_row_in_dst = dst_step * IC_PACK_SIZE * sizeof(int8_t);
            //! clear the whole row so padding columns stay zero
            memset(dst, 0, bytes_row_in_dst);
            int32_t* dst_even = reinterpret_cast<int32_t*>(dst) + pad_left / 2 +
                                pad_left % 2;
            int32_t* dst_odd =
                    reinterpret_cast<int32_t*>(dst) + odd_start + pad_left / 2;
            int i_src_idx = 0;
            if (nochange) {
                //! bulk path: 8 pixels per iteration, de-interleaved by vld2q
                for (; i_src_idx + 7 < src_step; i_src_idx += 8) {
                    int32x4x2_t tmp;
                    tmp = vld2q_s32(i_src + i_src_idx);
                    vst1q_s32(dst_even, tmp.val[0]);
                    vst1q_s32(dst_odd, tmp.val[1]);
                    dst_even += 4;
                    dst_odd += 4;
                }
            } else {
                //! parity flipped: even source lanes go to the odd half
                for (; i_src_idx + 7 < src_step; i_src_idx += 8) {
                    int32x4x2_t tmp;
                    tmp = vld2q_s32(i_src + i_src_idx);
                    vst1q_s32(dst_even, tmp.val[1]);
                    vst1q_s32(dst_odd, tmp.val[0]);
                    dst_even += 4;
                    dst_odd += 4;
                }
            }
            //! scalar tail: place remaining pixels one by one
            for (; i_src_idx < src_step; ++i_src_idx) {
                if (nochange) {
                    if (i_src_idx % 2 == 0) {
                        *dst_even = *(i_src + i_src_idx);
                        dst_even++;
                    } else {
                        *dst_odd = *(i_src + i_src_idx);
                        dst_odd++;
                    }
                } else {
                    if (i_src_idx % 2 == 0) {
                        *dst_odd = *(i_src + i_src_idx);
                        dst_odd++;
                    } else {
                        *dst_even = *(i_src + i_src_idx);
                        dst_even++;
                    }
                }
            }
            //! dst move to next row
            dst += bytes_row_in_dst / sizeof(int8_t);
            //! src move to next row (i_src is int32*, so this is in pixels)
            i_src += src_step;
        }
        //! pad bottom
        int bytes_pad_bottom =
                pad_bottom * dst_step * IC_PACK_SIZE * sizeof(int8_t);
        memset(dst, 0, bytes_pad_bottom);
        dst += bytes_pad_bottom / sizeof(int8_t);
    }
}
/**
 * \brief generic dot-product direct conv for NCHW44.
 * Tiles the output as [oc/OC_INTERVAL, oh, ow/OW_INTERVAL] and invokes the
 * matching KernNeonSdotNCHW44 micro-kernel per tile; the ow tail uses
 * kernels selected once up front, the oc tail (aarch64 only) uses the
 * mid/small OC kernels.
 * Filter layout is [OC/4, IC/4, FH, FW, 4OC, 4IC].
 *
 * Fix: the oc-remainder section below was guarded with
 * `#ifdef MEGDNN_AARCH64` while the OC interval constants use
 * `#if MEGDNN_AARCH64`; the two guards disagree when the macro is defined
 * to 0, so both now use `#if` consistently.
 */
template <typename dst_type, int stride, BiasMode bias_mode, typename Op,
          int filter_size>
void conv_direct_sdot_int8_nchw44(dst_type* dst, const int oh, const int ow,
                                  const int8_t* src, const int ih, const int iw,
                                  const int8_t* filter, const int32_t* bias,
                                  const int oh_size, const int oc, const int ic,
                                  const Op& op) {
    constexpr int FH = filter_size;
    constexpr int FW = filter_size;
    constexpr int IC_PACK_SIZE = 4;
    constexpr int OC_PACK_SIZE = 4;
#if MEGDNN_AARCH64
    //! aarch64 has enough vector registers to keep 12 output channels live
    constexpr int OC_BIG_INTERVAL = 12;
    constexpr int OC_MID_INTERVAL = 8;
    constexpr int OC_SMA_INTERVAL = 4;
#else
    constexpr int OC_BIG_INTERVAL = 4;
    constexpr int OC_MID_INTERVAL = 4;
    constexpr int OC_SMA_INTERVAL = 4;
#endif
    constexpr int OW_INTERVAL = 8;
    constexpr int SH = stride;
    const int dst_numbers_per_channel = oh * ow;
    const int ow_remain = ow % OW_INTERVAL;
    const int ow_end_idx = ow - ow_remain;
    const int oc_remain =
            oc % OC_BIG_INTERVAL;  //! NCHW44 means oc_remain = 4 or 8
    const int oc_end_idx = oc - oc_remain;
    const int dst_numbers_4channel_packed =
            dst_numbers_per_channel * OC_PACK_SIZE;
    using remain_fun = std::function<void(
            dst_type * dst, const int dst_step, const int8_t* src, const int ih,
            const int iw, const int8_t* filter, const int32_t* bias,
            const int ic, const Op& op)>;
    remain_fun kern_big_oc_remain = nullptr;
    remain_fun kern_mid_oc_remain = nullptr;
    remain_fun kern_sma_oc_remain = nullptr;
    //! resolve the ow-tail kernels once, outside the hot loops
    switch (ow_remain) {
#define cb(step)                                                         \
    case step:                                                           \
        kern_big_oc_remain =                                             \
                KernNeonSdotNCHW44<dst_type, stride, bias_mode, Op, step, \
                                   filter_size, OC_BIG_INTERVAL,         \
                                   OW_INTERVAL>::impl;                   \
        kern_mid_oc_remain =                                             \
                KernNeonSdotNCHW44<dst_type, stride, bias_mode, Op, step, \
                                   filter_size, OC_MID_INTERVAL,         \
                                   OW_INTERVAL>::impl;                   \
        kern_sma_oc_remain =                                             \
                KernNeonSdotNCHW44<dst_type, stride, bias_mode, Op, step, \
                                   filter_size, OC_SMA_INTERVAL,         \
                                   OW_INTERVAL>::impl;                   \
        break;
        UNROLL_CALL_RAW(8, cb);
#undef cb
        default:
            megdnn_assert(0, "no remain %d for kern", ow_remain);
    }
    //! filter layout is [OC/4, IC/4, FH, FW, 4OC, 4IC]
    //! cut [oc, oh, ow] into [oc/OC_INTERVAL, 1, ow/OW_INTERVAL, OW_INTERVAL,
    //! oh, OC_INTERVAL] to calculate KernNeonSdotNCHW44 calculates
    //! [OW_INTERVAL, 1, OC_INTERVAL] each time
    for (int oc_idx = 0; oc_idx < oc_end_idx; oc_idx += OC_BIG_INTERVAL) {
        const int filter_offset_in_element = oc_idx * ic * FH * FW;
        for (int oh_idx = 0; oh_idx < oh_size; ++oh_idx) {
            for (int ow_idx = 0; ow_idx < ow_end_idx; ow_idx += OW_INTERVAL) {
                const int src_offset_in_element =
                        (oh_idx * SH * iw + ow_idx) * IC_PACK_SIZE;
                const int dst_offset_in_element =
                        oc_idx * dst_numbers_per_channel +
                        (oh_idx * ow + ow_idx) * OC_PACK_SIZE;
                const int bias_offset_in_element = oc_idx;
                KernNeonSdotNCHW44<dst_type, stride, bias_mode, Op, OW_INTERVAL,
                                   filter_size, OC_BIG_INTERVAL, OW_INTERVAL>::
                        impl(dst + dst_offset_in_element,
                             dst_numbers_4channel_packed,
                             src + src_offset_in_element, ih, iw,
                             filter + filter_offset_in_element,
                             bias + bias_offset_in_element, ic, op);
            }
            //! handle the ow tail of this output row
            if (ow_remain) {
                const int src_offset_in_element =
                        (oh_idx * SH * iw + ow_end_idx) * IC_PACK_SIZE;
                const int dst_offset_in_element =
                        oc_idx * dst_numbers_per_channel +
                        (oh_idx * ow + ow_end_idx) * OC_PACK_SIZE;
                const int bias_offset_in_element = oc_idx;
                kern_big_oc_remain(dst + dst_offset_in_element,
                                   dst_numbers_4channel_packed,
                                   src + src_offset_in_element, ih, iw,
                                   filter + filter_offset_in_element,
                                   bias + bias_offset_in_element, ic, op);
            }
        }
    }
#if MEGDNN_AARCH64
    //! oc_remain must be 4 or 8 on aarch64 and must be 0 on aarch32
    if (oc_remain) {
        int oc_idx = oc_end_idx;
        const int filter_offset_in_element = oc_idx * ic * FH * FW;
        for (int oh_idx = 0; oh_idx < oh_size; ++oh_idx) {
            for (int ow_idx = 0; ow_idx < ow_end_idx; ow_idx += OW_INTERVAL) {
                const int src_offset_in_element =
                        (oh_idx * SH * iw + ow_idx) * IC_PACK_SIZE;
                const int dst_offset_in_element =
                        oc_idx * dst_numbers_per_channel +
                        (oh_idx * ow + ow_idx) * OC_PACK_SIZE;
                const int bias_offset_in_element = oc_idx;
                if (oc_remain == 8) {
                    KernNeonSdotNCHW44<
                            dst_type, stride, bias_mode, Op, OW_INTERVAL,
                            filter_size, OC_MID_INTERVAL,
                            OW_INTERVAL>::impl(dst + dst_offset_in_element,
                                               dst_numbers_4channel_packed,
                                               src + src_offset_in_element, ih,
                                               iw,
                                               filter +
                                                       filter_offset_in_element,
                                               bias + bias_offset_in_element,
                                               ic, op);
                } else {
                    KernNeonSdotNCHW44<
                            dst_type, stride, bias_mode, Op, OW_INTERVAL,
                            filter_size, OC_SMA_INTERVAL,
                            OW_INTERVAL>::impl(dst + dst_offset_in_element,
                                               dst_numbers_4channel_packed,
                                               src + src_offset_in_element, ih,
                                               iw,
                                               filter +
                                                       filter_offset_in_element,
                                               bias + bias_offset_in_element,
                                               ic, op);
                }
            }
            if (ow_remain) {
                const int src_offset_in_element =
                        (oh_idx * SH * iw + ow_end_idx) * IC_PACK_SIZE;
                const int dst_offset_in_element =
                        oc_idx * dst_numbers_per_channel +
                        (oh_idx * ow + ow_end_idx) * OC_PACK_SIZE;
                const int bias_offset_in_element = oc_idx;
                if (oc_remain == 8) {
                    kern_mid_oc_remain(dst + dst_offset_in_element,
                                       dst_numbers_4channel_packed,
                                       src + src_offset_in_element, ih, iw,
                                       filter + filter_offset_in_element,
                                       bias + bias_offset_in_element, ic, op);
                } else {
                    kern_sma_oc_remain(dst + dst_offset_in_element,
                                       dst_numbers_4channel_packed,
                                       src + src_offset_in_element, ih, iw,
                                       filter + filter_offset_in_element,
                                       bias + bias_offset_in_element, ic, op);
                }
            }
        }
    }
#endif
}
//! Generates the public per-filter-size entry points
//! conv_direct_{f}x{f}_int8_nchw44; each is a thin wrapper that forwards
//! to the generic conv_direct_sdot_int8_nchw44 with filter_size fixed at
//! compile time.
#define CONSTRUCT_FUNC(filter_size)                                           \
    template <typename dst_type, BiasMode bias_mode, typename Op, int stride> \
    void conv_direct_##filter_size##x##filter_size##_int8_nchw44(             \
            dst_type* dst, const int oh, const int ow, const int8_t* src,     \
            const int ih, const int iw, const int8_t* weight,                 \
            const int32_t* bias, const int oh_size, const int oc,             \
            const int ic, const Op& op) {                                     \
        conv_direct_sdot_int8_nchw44<dst_type, stride, bias_mode, Op,         \
                                     filter_size>(                            \
                dst, oh, ow, src, ih, iw, weight, bias, oh_size, oc, ic, op); \
    }
//! supported square filter sizes
CONSTRUCT_FUNC(2);
CONSTRUCT_FUNC(3);
CONSTRUCT_FUNC(5);
CONSTRUCT_FUNC(7);
#undef CONSTRUCT_FUNC
//! Explicit template instantiations for every supported combination:
//! post-ops {TypeCvt, None, Relu, HSwish} x bias {none, channel-broadcast}
//! x filter {2,3,5,7} x stride {1,2}. NoneOp keeps the raw int32
//! accumulator (dt_int32 dst); the other ops requantize to dt_int8.
#define INSTANTIATION(dst_type, stride, i, bias_mode, Op)                      \
    template void conv_direct_##i##x##i##_int8_nchw44<dst_type, bias_mode, Op, \
                                                      stride>(                 \
            dst_type * dst, const int oh, const int ow, const int8_t* src,     \
            const int ih, const int iw, const int8_t* weight,                  \
            const int32_t* bias, const int oh_size, const int oc,              \
            const int ic, const Op& op);
#define FOR_OP(stride, i, bias_mode)                          \
    INSTANTIATION(dt_int8, stride, i, bias_mode,              \
                  TypeCvtOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
    INSTANTIATION(dt_int32, stride, i, bias_mode,             \
                  NoneOp<dt_qint32 MEGDNN_COMMA dt_qint8>)    \
    INSTANTIATION(dt_int8, stride, i, bias_mode,              \
                  ReluOp<dt_qint32 MEGDNN_COMMA dt_qint8>)    \
    INSTANTIATION(dt_int8, stride, i, bias_mode,              \
                  HSwishOp<dt_qint32 MEGDNN_COMMA dt_qint8>)
#define FOR_BIAS(stride, i)              \
    FOR_OP(stride, i, BiasMode::NO_BIAS) \
    FOR_OP(stride, i, BiasMode::BROADCAST_CHANNEL_BIAS)
#define FOR_FILTER(stride) \
    FOR_BIAS(stride, 2)    \
    FOR_BIAS(stride, 3)    \
    FOR_BIAS(stride, 5)    \
    FOR_BIAS(stride, 7)
FOR_FILTER(1)
FOR_FILTER(2)
//! undef exactly the helper macros defined above (the old list #undef'd
//! names that were never defined and leaked FOR_OP past this point)
#undef FOR_FILTER
#undef FOR_BIAS
#undef FOR_OP
#undef INSTANTIATION
} // namespace direct_dotprod_nchw44
} // namespace arm_common
} // namespace megdnn
#endif
//vim: syntax=cpp.doxygen
\ No newline at end of file
/**
* \file dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#if __ARM_FEATURE_DOTPROD
#pragma once
#include "src/arm_common/conv_bias/opr_impl.h"
namespace megdnn {
namespace arm_common {
namespace direct_dotprod_nchw44 {
using BiasMode = ConvBiasForward::BiasMode;
/**
* @brief : do direct conv with no side effect
* input buffer's size is [ih, iw]
* output buffer's size is [oh, ow]
* filter layout is [OC/4, IC/4, FH, FW, 4, 4]
*
* @param : [output ptr] dst
* [input] oh -> dst rows
* [input] ow -> dst cols
* [input ptr] src
* [input] ih -> rows of src used by this kernel
* [input] iw -> src step in elements [iw2]
* [input ptr] filter
* [input ptr] bias
* [input] oh_size -> rows of result generated by this kernel
* [input] oc -> output channels
* [input] ic -> input channels
* [input] op -> post process operator
* @return none
*/
//! Declares the per-filter-size conv entry points (2x2/3x3/5x5/7x7); the
//! definitions are generated by CONSTRUCT_FUNC in the matching .cpp file.
#define KERN(filter_size)                                                     \
    template <typename dst_type, BiasMode bias_mode, typename Op, int stride> \
    void conv_direct_##filter_size##x##filter_size##_int8_nchw44(             \
            dst_type* dst, const int oh, const int ow, const int8_t* src,     \
            const int ih, const int iw, const int8_t* weight,                 \
            const int32_t* bias, const int oh_size, const int oc,             \
            const int ic, const Op& op)
KERN(2);
KERN(3);
KERN(5);
KERN(7);
#undef KERN
/**
* @brief : copy data from src to dst for direct conv with no side effect
* @param : [output ptr] dst
* [input] dst_step -> step of dst in numbers of elements
* [input ptr] src
* [input] src_step -> step of src in numbers of elements
* [input] ic -> input channels
* [input] ic_step -> step of ic in numbers of elements
* [input] ih -> total rows to copy
* [input] pad_left -> cols padding at left
* [input] pad_right -> cols padding at right
* [input] pad_top -> rows padding at top
* [input] pad_bottom -> rows padding at bottom
* @return none
*/
//! Copies src into a zero-padded, channel-packed buffer; `stride` selects
//! the layout transform (explicit specializations are provided for
//! stride 1 and stride 2 in the matching .cpp file).
template <int stride>
void copy_packed_src_int8_nchw44(int8_t* dst, const int dst_step,
                                 const int8_t* src, const int src_step,
                                 const int ic, const int ic_step, const int ih,
                                 const int pad_left, const int pad_right,
                                 const int pad_top, const int pad_bottom);
} // namespace direct_dotprod_nchw44
} // namespace arm_common
} // namespace megdnn
#endif
//vim: syntax=cpp.doxygen
\ No newline at end of file
/**
* \file dnn/src/arm_common/conv_bias/int8/direct_dotpord_nchw44_algo.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#if __ARM_FEATURE_DOTPROD
#include "src/arm_common/conv_bias/block_helper.h"
#include "src/arm_common/conv_bias/int8/algos.h"
#include "src/arm_common/conv_bias/int8/direct_dotprod_nchw44.h"
#include "src/arm_common/elemwise_op.h"
#include "midout.h"
using namespace megdnn;
using namespace arm_common;
MIDOUT_DECL(megdnn_arm_common_conv_bias_int8)
using direct_fun = std::function<void(
WorkspaceBundle bundle, const ConvBiasImpl::NCBKernParam& ncb_param,
const ConvBiasImpl::NCBKernIndex& ncb_index)>;
namespace {
//! Computes the padded/tiled input extent (ih, iw) and the output extent
//! (oh, ow) used to size the per-thread copy buffer. `ih` covers one
//! L2-sized tile of output rows; `iw` is rounded up to a whole cacheline.
static void get_rectified_size(
        const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param, int& ih,
        int& iw, int& oh, int& ow) {
    auto&& fm = param.filter_meta;
    const int IC = fm.icpg;
    const int IW = param.isz[1];
    const int OH = param.osz[0];
    const int OW = param.osz[1];
    oh = OH;
    ow = OW;
    constexpr int cacheline = 64 / sizeof(int8_t);
    //! output rows produced per tile, chosen to fit L2 across threads
    const int oh_tile_size =
            l2_block_helper(param.nr_threads, OH, IC * IW * sizeof(int8_t) * 2);
    const int SH = static_cast<int>(fm.stride[0]);
    const int FH = static_cast<int>(fm.spatial[0]);
    const int PW = static_cast<int>(fm.padding[1]);
    //! input rows needed to produce oh_tile_size output rows
    ih = oh_tile_size * SH + FH - SH;
    iw = round_up(IW + 2 * PW, cacheline);
}
//! Bytes of packed-src scratch needed by one worker thread. A 128-byte
//! tail region is reserved so reads past the logical end stay in-bounds
//! (border_size comment in the original: avoid reading illegal memory).
static inline int get_perthread_cache_bytes(const int ic, const int ih,
                                            const int iw) {
    constexpr int guard_bytes = 64 * 2;
    return static_cast<int>(ic * ih * iw * sizeof(int8_t)) + guard_bytes;
}
//! Builds the workspace bundle: one packed-src scratch region per worker
//! thread, sized by the rectified (padded/tiled) input extent.
static WorkspaceBundle get_bundle(const ConvBiasImpl::NCBKernSizeParam& param) {
    int ih2, iw2, oh2, ow2;
    get_rectified_size(param, ih2, iw2, oh2, ow2);
    const int IC = param.filter_meta.icpg;
    const int per_thread_bytes = get_perthread_cache_bytes(IC, ih2, iw2);
    return {nullptr, {per_thread_bytes * param.nr_threads}};
}
//! One NCB kernel invocation: copies (pads and, for stride 2,
//! de-interleaves) the input rows needed for one oh-tile into the
//! per-thread workspace, then runs the dot-product direct conv on that
//! packed buffer. ndrange is (batch, group, oh_tile).
template <typename dst_type, size_t filter_size, BiasMode bias_mode,
          typename Op, int stride>
static void conv_kern(WorkspaceBundle bundle,
                      const ConvBiasImpl::NCBKernParam& ncb_param,
                      const ConvBiasImpl::NCBKernIndex& ncb_index) {
    const int OH = ncb_param.osz[0];
    const int OW = ncb_param.osz[1];
    const int FH = ncb_param.filter_meta.spatial[0];
    const int IC = ncb_param.filter_meta.icpg;
    const int OC = ncb_param.filter_meta.ocpg;
    const int IH = ncb_param.isz[0];
    const int IW = ncb_param.isz[1];
    const int SH = ncb_param.filter_meta.stride[0];
    const int PH = ncb_param.filter_meta.padding[0];
    const int PW = ncb_param.filter_meta.padding[1];
    int ih2 = 0;
    int iw2 = 0;
    int oh2 = 0;
    int ow2 = 0;
    //! rectified extents must match what get_bundle() sized the buffer with
    get_rectified_size(ncb_param, ih2, iw2, oh2, ow2);
    constexpr int IC_PACK_SIZE = 4;
    constexpr int OC_PACK_SIZE = 4;
    bundle.set(ncb_param.workspace_ptr);
    const int batch_id = ncb_index.ndrange_id[0];
    const int group_id = ncb_index.ndrange_id[1];
    const int oh_tile_id = ncb_index.ndrange_id[2];
    const int thread_id = ncb_index.thread_id;
    //! must agree with the tile size used in dispatch_kerns
    const int oh_tile_size = l2_block_helper(ncb_param.nr_threads, OH,
                                             IC * IW * sizeof(int8_t) * 2);
    const int oh_start_row = oh_tile_id * oh_tile_size;
    const int ih_start_row = std::max(oh_start_row * SH - PH, 0);
    const int oh_real_size = std::min(OH - oh_start_row, oh_tile_size);
    const int ih_real_size = oh_real_size * SH + FH - SH;
    //! how much of the copied tile is zero padding on each side
    const int rows_padding_at_top = std::max(PH - oh_start_row * SH, 0);
    const int rows_padding_at_bottom =
            std::max((oh_start_row + oh_real_size - 1) * SH + FH - IH - PH, 0);
    const int cols_padding_at_left = PW;
    const int cols_padding_at_right = std::max(iw2 - IW - PW, 0);
    //! src layout{IC/4, IH, IW, 4}
    //! byte offset == element offset since sizeof(int8_t) == 1
    const int bytes_of_src_offset =
            ih_start_row * IW * IC_PACK_SIZE * sizeof(int8_t);
    const int8_t* copy_src = static_cast<const int8_t*>(
            ncb_param.src<int8_t>(batch_id, group_id) + bytes_of_src_offset);
    const int bytes_of_copy_per_thread =
            get_perthread_cache_bytes(IC, ih2, iw2);
    int8_t* copy_dst = reinterpret_cast<int8_t*>(bundle.get(0)) +
                       thread_id * bytes_of_copy_per_thread;
    const int rows_copy_from_src =
            ih_real_size - rows_padding_at_top - rows_padding_at_bottom;
    direct_dotprod_nchw44::copy_packed_src_int8_nchw44<stride>(
            copy_dst, iw2, copy_src, IW, IC, IH * IW, rows_copy_from_src,
            cols_padding_at_left, cols_padding_at_right, rows_padding_at_top,
            rows_padding_at_bottom);
    const int8_t* weights = ncb_param.filter<int8_t>(group_id);
    dst_type* dst = ncb_param.dst<dst_type>(batch_id, group_id) +
                    oh_start_row * OW * OC_PACK_SIZE;
    //! only broadcast or no_bias
    const int32_t* bias = ncb_param.bias<int32_t>(batch_id, group_id);
    //! placeholder scales, replaced below for quantized dst; for int32 dst
    //! the NoneOp path presumably ignores them -- confirm if reused
    Op op = Op(1.0f, 4.0f);
    if (ncb_param.dst_type.enumv() == DTypeEnum::QuantizedS8) {
        float scale_bias =
                ncb_param.bias_type.param<dtype::QuantizedS32>().scale;
        float scale_dst = ncb_param.dst_type.param<dtype::QuantizedS8>().scale;
        op = Op(scale_bias, scale_dst);
    }
#define KERN1_NCHW44_CONV(filter)                                      \
    direct_dotprod_nchw44::conv_direct_##filter##x##filter##_int8_nchw44< \
            dst_type, bias_mode, Op, stride>(dst, OH, OW, copy_dst,    \
                                             ih_real_size, iw2, weights, bias, \
                                             oh_real_size, OC, IC, op);
    DISPATCH_FILTER(filter_size, KERN1_NCHW44_CONV);
#undef KERN1_NCHW44_CONV
}
} // namespace
//! The algorithm applies when: dtypes are qint8->qint8/qint32 (qint32
//! bias) or plain int8->int32; the layout is NCHW44_DOT with channels a
//! multiple of 4; the filter is square, 2..7, undilated, not flipped; and
//! the stride is a uniform 1 or 2.
bool ConvBiasImpl::AlgoDotS8Direct_NCHW44::usable(
        FallbackConvBiasImpl*, const NCBKernSizeParam& param,
        AlgoSelectionStrategy algo_selection_strategy) const {
    auto&& fm = param.filter_meta;
    const auto FH = fm.spatial[0];
    const auto FW = fm.spatial[1];
    const auto SH = fm.stride[0];
    const auto SW = fm.stride[1];
    //! quantized path: qint8 src/filter, qint8 or qint32 dst
    bool type_ok = param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
                   param.filter_type.enumv() == DTypeEnum::QuantizedS8 &&
                   (param.dst_type.enumv() == DTypeEnum::QuantizedS8 ||
                    param.dst_type.enumv() == DTypeEnum::QuantizedS32);
    if (param.bias_type.valid()) {
        type_ok &= param.bias_type.enumv() == DTypeEnum::QuantizedS32;
    }
    //! plain int8 path: int8 src/filter with int32 dst
    type_ok |= param.src_type.enumv() == DTypeEnum::Int8 &&
               param.filter_type.enumv() == DTypeEnum::Int8 &&
               param.dst_type.enumv() == DTypeEnum::Int32;
    const bool layout_ok =
            fm.format == param::Convolution::Format::NCHW44_DOT &&
            fm.icpg % 4 == 0 && fm.ocpg % 4 == 0;
    const bool filter_ok = !fm.should_flip && fm.spatial_ndim == 2 &&
                           fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
                           FH == FW && FH >= 2 && FH <= 7;
    const bool stride_ok = SH == SW && (SH == 1 || SH == 2);
    return type_ok && layout_ok && filter_ok && stride_ok;
}
//! This algorithm is unconditionally preferred whenever it is usable.
//! The parameter is intentionally unnamed: it is not consulted, and an
//! unnamed parameter avoids an -Wunused-parameter warning.
bool ConvBiasImpl::AlgoDotS8Direct_NCHW44::is_preferred(
        megdnn::fallback::ConvBiasImpl*, const NCBKernSizeParam&) const {
    return true;
}
//! Workspace requirement is fully described by the bundle layout.
size_t ConvBiasImpl::AlgoDotS8Direct_NCHW44::get_workspace(
        FallbackConvBiasImpl*, const NCBKernSizeParam& param) const {
    auto bundle = get_bundle(param);
    return bundle.total_size_in_bytes();
}
//! Selects the conv_kern instantiation matching (dst type, filter size,
//! bias mode, nonlinearity, stride) and returns a single NCB kernel over
//! the (batch, group, oh_tile) ndrange.
//!
//! Fixes: `megdnn_assert("...")` passed a string literal as the condition
//! (always true, so the assert could never fire) -- now `megdnn_assert(0,
//! "...")`; `std::move` on the const NCBKernIndex& was a no-op and is
//! dropped; filter_meta is bound by reference instead of copied.
SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoDotS8Direct_NCHW44::dispatch_kerns(
        FallbackConvBiasImpl*, const NCBKernSizeParam& param) const {
    MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8,
                 midout_iv("ALGODOTS8DIRECT_NCHW44"_hash)) {
        auto&& fm = param.filter_meta;  //! read-only, no need to copy
        size_t BATCH = param.n;
        size_t GROUP = fm.group;
        WorkspaceBundle wbundle = get_bundle(param);
        direct_fun kernel = nullptr;
        bool quantized = param.dst_type.enumv() == DTypeEnum::QuantizedS8;
#define DO_CONV_KERN_FUN(dst_type, filter, bias_mode, op, stride)     \
    MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8,                    \
                 midout_iv(#dst_type #filter #bias_mode #op##_hash)) { \
        kernel = conv_kern<dst_type, filter, bias_mode, op, stride>;  \
    }                                                                 \
    MIDOUT_END();
#define GET_OP_PARAM(i, bias_mode, stride)                              \
    switch (param.nonlineMode) {                                        \
        case param::ConvBias::NonlineMode::IDENTITY:                    \
            if (quantized) {                                            \
                DO_CONV_KERN_FUN(dt_int8, i, bias_mode,                 \
                                 TypeCvtOp<dt_qint32 MEGDNN_COMMA dt_qint8>, \
                                 stride)                                \
            } else {                                                    \
                DO_CONV_KERN_FUN(dt_int32, i, bias_mode,                \
                                 NoneOp<dt_qint32 MEGDNN_COMMA dt_qint8>, \
                                 stride)                                \
            }                                                           \
            break;                                                      \
        case param::ConvBias::NonlineMode::RELU:                        \
            if (quantized) {                                            \
                DO_CONV_KERN_FUN(dt_int8, i, bias_mode,                 \
                                 ReluOp<dt_qint32 MEGDNN_COMMA dt_qint8>, \
                                 stride)                                \
            } else {                                                    \
                megdnn_assert(0, "No support NoQuantized RELU");        \
            }                                                           \
            break;                                                      \
        case param::ConvBias::NonlineMode::H_SWISH:                     \
            if (quantized) {                                            \
                DO_CONV_KERN_FUN(dt_int8, i, bias_mode,                 \
                                 HSwishOp<dt_qint32 MEGDNN_COMMA dt_qint8>, \
                                 stride)                                \
            } else {                                                    \
                megdnn_assert(0, "No support NoQuantized H_SWISH");     \
            }                                                           \
            break;                                                      \
        default:                                                        \
            megdnn_assert(0);                                           \
            break;                                                      \
    }
#define GET_STRIDE_PARAM(filter, bias_mode)    \
    switch (fm.stride[0]) {                    \
        case 1:                                \
            GET_OP_PARAM(filter, bias_mode, 1); \
            break;                             \
        case 2:                                \
            GET_OP_PARAM(filter, bias_mode, 2); \
            break;                             \
        default:                               \
            megdnn_assert(0);                  \
    }
#define GET_BIAS_MODE_PARAM(filter)                                    \
    switch (param.bias_mode) {                                         \
        case BiasMode::NO_BIAS:                                        \
            GET_STRIDE_PARAM(filter, BiasMode::NO_BIAS)                \
            break;                                                     \
        case BiasMode::BROADCAST_CHANNEL_BIAS:                         \
            GET_STRIDE_PARAM(filter, BiasMode::BROADCAST_CHANNEL_BIAS) \
            break;                                                     \
        default:                                                       \
            megdnn_assert(0);                                          \
            break;                                                     \
    }
#define SELECT_CONV_KERN()                      \
    switch (param.filter_meta.spatial[0]) {     \
        case 2:                                 \
            GET_BIAS_MODE_PARAM(2)              \
            break;                              \
        case 3:                                 \
            GET_BIAS_MODE_PARAM(3)              \
            break;                              \
        case 5:                                 \
            GET_BIAS_MODE_PARAM(5)              \
            break;                              \
        case 7:                                 \
            GET_BIAS_MODE_PARAM(7)              \
            break;                              \
        default:                                \
            megdnn_assert(0);                   \
            break;                              \
    }
        SELECT_CONV_KERN()
#undef DO_CONV_KERN_FUN
#undef GET_OP_PARAM
#undef GET_STRIDE_PARAM
#undef GET_BIAS_MODE_PARAM
#undef SELECT_CONV_KERN
        megdnn_assert(kernel);
        SmallVector<ConvBiasImpl::NCBKern> ret_kerns;
        int OH = param.osz[0];
        int IC = param.filter_meta.icpg;
        int IW = param.isz[1];
        //! tile size must match the one recomputed inside conv_kern
        int oh_tile_size = l2_block_helper(param.nr_threads, OH,
                                           IC * IW * sizeof(int8_t) * 2);
        size_t oh_tiles = static_cast<size_t>(div_ceil(OH, oh_tile_size));
        auto do_conv = [wbundle, kernel](const NCBKernParam& ncb_param,
                                         const NCBKernIndex& ncb_index) {
            kernel(wbundle, ncb_param, ncb_index);
        };
        ret_kerns.push_back({do_conv, {BATCH, GROUP, oh_tiles}});
        return ret_kerns;
    }
    MIDOUT_END();
    return {};
}
#endif
//vim: syntax=cpp.doxygen
\ No newline at end of file
/**
* \file dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44_kern.h
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#ifdef __ARM_FEATURE_DOTPROD
#include "src/arm_common/conv_bias/intrinsic_helper.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/intrinsic_helper.h"
#include "src/arm_common/neon_struct.h"
#include "src/common/unroll_macro.h"
namespace megdnn {
namespace arm_common {
namespace direct_dotprod_nchw44 {
constexpr int SIMD_LEN = 16;     //! int8 lanes in one NEON q register
constexpr int IC_PACK_SIZE = 4;  //! input channels per packed group
constexpr int OC_PACK_SIZE = 4;  //! output channels per packed group
//! byte stride to the next filter column in the packed layout
constexpr int filter_next_col =
        IC_PACK_SIZE * OC_PACK_SIZE;  //! [OC/4, IC/4, FH, FW, 4OC, 4IC]
//! Initializes the accumulator tile c[i][0..7] with either the per-4OC
//! bias (BROADCAST_CHANNEL_BIAS) or zeros. `row` is the number of 4-OC
//! groups in flight (1..3); the switches intentionally fall through so
//! that row==3 also initializes rows 1 and 0, etc.
template <int row, BiasMode bias_mode>
inline void init_ocx_ow8(int32x4_t c[][8], const int32_t* bias_ptr,
                         int oc_step) {
    static_assert(row == 1 || row == 2 || row == 3, "Invalid OC number.");
    if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) {
#define BIAS_INIT(step, i) c[i][step] = vld1q_s32(bias_ptr + i * oc_step);
        switch (row) {
            case 3:
                UNROLL_CALL_RAW(8, BIAS_INIT, 2);
                // fall through
            case 2:
                UNROLL_CALL_RAW(8, BIAS_INIT, 1);
                // fall through
            default:
                UNROLL_CALL_RAW(8, BIAS_INIT, 0);
        }
#undef BIAS_INIT
    } else {
#define BIAS_INIT(step, i) c[i][step] = vdupq_n_s32(0);
        switch (row) {
            case 3:
                UNROLL_CALL_RAW(8, BIAS_INIT, 2);
                // fall through
            case 2:
                UNROLL_CALL_RAW(8, BIAS_INIT, 1);
                // fall through
            default:
                UNROLL_CALL_RAW(8, BIAS_INIT, 0);
        }
#undef BIAS_INIT
    }
}
//! Store helpers for the accumulator tile res[oc_group][ow_col].
//! cbN1 stores a single (tail) column through the scalar form of op;
//! cbN2 stores a pair of adjacent columns through the vector form.
//! N is the number of 4-OC groups written (1/2/3); ld_dst_oc is the
//! element stride between consecutive 4-OC output planes.
#define cb11(col) \
    op(res[0][col], reinterpret_cast<dt_qint8*>(dst_ptr + col / 2 * 8));
#define cb21(col)                                                        \
    op(res[0][col], reinterpret_cast<dt_qint8*>(dst_ptr + col / 2 * 8)); \
    op(res[1][col],                                                      \
       reinterpret_cast<dt_qint8*>(dst_ptr + ld_dst_oc + col / 2 * 8));
#define cb31(col)                                                        \
    op(res[0][col], reinterpret_cast<dt_qint8*>(dst_ptr + col / 2 * 8)); \
    op(res[1][col],                                                      \
       reinterpret_cast<dt_qint8*>(dst_ptr + ld_dst_oc + col / 2 * 8)); \
    op(res[2][col], reinterpret_cast<dt_qint8*>(dst_ptr + ld_dst_oc +    \
                                                ld_dst_oc + col / 2 * 8));
#define cb12(step)                                 \
    op({{res[0][2 * step], res[0][2 * step + 1]}}, \
       reinterpret_cast<dt_qint8*>(dst_ptr + step * 8));
#define cb22(step)                                 \
    op({{res[0][2 * step], res[0][2 * step + 1]}}, \
       reinterpret_cast<dt_qint8*>(dst_ptr + step * 8)); \
    op({{res[1][2 * step], res[1][2 * step + 1]}}, \
       reinterpret_cast<dt_qint8*>(dst_ptr + ld_dst_oc + step * 8));
#define cb32(step)                                 \
    op({{res[0][2 * step], res[0][2 * step + 1]}}, \
       reinterpret_cast<dt_qint8*>(dst_ptr + step * 8)); \
    op({{res[1][2 * step], res[1][2 * step + 1]}}, \
       reinterpret_cast<dt_qint8*>(dst_ptr + ld_dst_oc + step * 8)); \
    op({{res[2][2 * step], res[2][2 * step + 1]}}, \
       reinterpret_cast<dt_qint8*>(dst_ptr + 2 * ld_dst_oc + step * 8));
//! Primary template: writes the accumulator tile back to dst while
//! applying the post-process Op; specialized below for 1/2/3 groups of
//! four output channels.
template <int row, int ow_remain, typename Op, typename T>
struct StoreOCxOWx {
    static void impl(int32x4_t res[][8], const Op& op, T* dst_ptr,
                     const int ld_dst_oc);
};
//! One 4-OC group. Even ow counts store column pairs (cb12); odd counts
//! store the trailing column with cb11 and then fall through to the next
//! smaller even case -- the missing `break`s are intentional.
template <int ow_remain, typename Op, typename T>
struct StoreOCxOWx<1, ow_remain, Op, T> {
    static void impl(int32x4_t res[][8], const Op& op, T* dst_ptr,
                     const int ld_dst_oc) {
        switch (ow_remain) {
            case 8:
                UNROLL_CALL_RAW(4, cb12);
                break;
            case 7:
                cb11(6);
                // fall through
            case 6:
                UNROLL_CALL_RAW(3, cb12);
                break;
            case 5:
                cb11(4);
                // fall through
            case 4:
                UNROLL_CALL_RAW(2, cb12);
                break;
            case 3:
                cb11(2);
                // fall through
            case 2:
                UNROLL_CALL_RAW(1, cb12);
                break;
            case 1:
                cb11(0);
            default:
                break;
        }
    }
};
//! Two 4-OC groups; same intentional fallthrough pattern as the 1-group
//! specialization, using the cb2x store helpers.
template <int ow_remain, typename Op, typename T>
struct StoreOCxOWx<2, ow_remain, Op, T> {
    static void impl(int32x4_t res[][8], const Op& op, T* dst_ptr,
                     const int ld_dst_oc) {
        switch (ow_remain) {
            case 8:
                UNROLL_CALL_RAW(4, cb22);
                break;
            case 7:
                cb21(6);
                // fall through
            case 6:
                UNROLL_CALL_RAW(3, cb22);
                break;
            case 5:
                cb21(4);
                // fall through
            case 4:
                UNROLL_CALL_RAW(2, cb22);
                break;
            case 3:
                cb21(2);
                // fall through
            case 2:
                UNROLL_CALL_RAW(1, cb22);
                break;
            case 1:
                cb21(0);
            default:
                break;
        }
    }
};
//! Three 4-OC groups; same intentional fallthrough pattern, using the
//! cb3x store helpers.
template <int ow_remain, typename Op, typename T>
struct StoreOCxOWx<3, ow_remain, Op, T> {
    static void impl(int32x4_t res[][8], const Op& op, T* dst_ptr,
                     const int ld_dst_oc) {
        switch (ow_remain) {
            case 8:
                UNROLL_CALL_RAW(4, cb32);
                break;
            case 7:
                cb31(6);
                // fall through
            case 6:
                UNROLL_CALL_RAW(3, cb32);
                break;
            case 5:
                cb31(4);
                // fall through
            case 4:
                UNROLL_CALL_RAW(2, cb32);
                break;
            case 3:
                cb31(2);
                // fall through
            case 2:
                UNROLL_CALL_RAW(1, cb32);
                break;
            case 1:
                cb31(0);
            default:
                break;
        }
    }
};
#undef cb11
#undef cb21
#undef cb31
#undef cb12
#undef cb22
#undef cb32
//! Entry point used by the kernels: dispatches at compile time on
//! (row, ow_remain) to the matching StoreOCxOWx specialization.
template <int row, int ow_remain, typename Op, typename T>
inline void store_ocx_owx_remain_static(int32x4_t res[][8], const Op& op,
                                        T* dst_ptr, const int ld_dst_oc) {
    StoreOCxOWx<row, ow_remain, Op, T>::impl(res, op, dst_ptr, ld_dst_oc);
}
//! Accumulates one output row: for each of the 8 ow steps, applies
//! FUNC (at call sites this is Vdotq_laneq_s32) on weight[weight_idx] and
//! lane (src_start_idx + step) % 4 of the appropriate src q-register,
//! adding into res[res_row][step].
template <int res_row, int src_row, int src_start_idx, int weight_idx,
          typename FUNC, typename T, typename T2, typename T3>
struct ShiftCalHelper {
    static void impl(T& res, T2& src, T3& weight) {
#define cb(step)                                                            \
    res[res_row][step] = FUNC::template impl<((src_start_idx + step) % 4)>( \
            res[res_row][step], weight[weight_idx],                         \
            src[src_row][(src_start_idx + step) / 4]);
        UNROLL_CALL_RAW(8, cb);
#undef cb
    }
};
//! Thin dispatcher so call sites don't have to spell out the
//! ShiftCalHelper struct template. (The stray semicolon that followed the
//! function body -- flagged by -Wextra-semi -- has been removed.)
template <int res_row, int src_row, int src_start_idx, int weight_idx,
          typename FUNC, typename T, typename T2, typename T3>
inline void cal_helper(T& res, T2& src, T3& weight) {
    ShiftCalHelper<res_row, src_row, src_start_idx, weight_idx, FUNC, T, T2,
                   T3>::impl(res, src, weight);
}
/**
* oc12_owx(m = 12, n = x) and oc8_owx(m = 8, n = x) and oc4_owx(m = 4, n = x)
* gemm like kernel
* */
//! Primary template for the micro-kernel computing an
//! [oc_interval, ow_interval] output tile; only stride == 1 and
//! stride == 2 are specialized.
template <typename dst_type, int stride, BiasMode bias_mode, typename Op,
          int ow_remain, int filter_size, int oc_interval, int ow_interval>
struct KernNeonSdotNCHW44 {
    static void impl(dst_type* dst, const int dst_step, const int8_t* src,
                     const int ih, const int iw, const int8_t* filter,
                     const int32_t* bias, const int ic, const Op& op);
};
//! Stride-1 micro-kernel: accumulates an [oc_interval, ow_interval]
//! output tile with sdot. LOOP = oc_interval / 4 selects how many 4-OC
//! result rows (1..3) are kept live; NSRC is the number of 16-byte source
//! loads covering the pixels one filter row touches.
template <typename dst_type, BiasMode bias_mode, typename Op, int ow_remain,
          int filter_size, int oc_interval, int ow_interval>
struct KernNeonSdotNCHW44<dst_type, 1, bias_mode, Op, ow_remain, filter_size,
                          oc_interval, ow_interval> {
    static void impl(dst_type* dst, const int dst_step, const int8_t* src,
                     const int ih, const int iw, const int8_t* filter,
                     const int32_t* bias, const int ic, const Op& op) {
        constexpr int FH = filter_size;
        constexpr int FW = filter_size;
        //! strides within the packed filter / src layouts
        constexpr int filter_next_row =
                FW * OC_PACK_SIZE *
                IC_PACK_SIZE;  //! [OC/4, IC/4, FH, FW, 4OC, 4IC]
        const int filter_next_4oc =
                FH * FW * ic * OC_PACK_SIZE;  //! [OC/4, IC/4, FH, FW, 4OC, 4IC]
        const int src_next_ic = ih * iw;
        const int src_next_row = iw * IC_PACK_SIZE;
        //! q-registers needed to cover one filter row of input pixels
        constexpr int NSRC = (ow_interval + filter_size - 1) / 4 + 1;
        //! number of 4-OC accumulator rows in flight
        constexpr int LOOP = oc_interval / 4;
        int32x4_t res[3][ow_interval];
        init_ocx_ow8<LOOP, bias_mode>(res, bias, OC_PACK_SIZE);
        for (int ic_idx = 0; ic_idx < ic; ic_idx += IC_PACK_SIZE) {
            const int8_t* i_src = src + ic_idx * src_next_ic;
            const int8_t* i_filter = filter + ic_idx * FH * FW * OC_PACK_SIZE;
            for (int fh_idx = 0; fh_idx < FH; ++fh_idx) {
                //! note: this local `src` array shadows the parameter
                int8x16_t src[1][4];
                int8x16_t weight[3];
                load_helper<NSRC, 0, SIMD_LEN, 1, Vld1q_s8>(src, i_src, 0);
//! do not use switch order 3,2,1 because it will slow the speed.
#define CALC_PART(step)                                               \
    switch (LOOP) {                                                   \
        case 1:                                                       \
            weight[0] = vld1q_s8(i_filter + filter_next_4oc * 0 +     \
                                 filter_next_col * step);             \
            cal_helper<0, 0, step, 0, Vdotq_laneq_s32>(res, src, weight); \
            break;                                                    \
        case 2:                                                       \
            weight[0] = vld1q_s8(i_filter + filter_next_4oc * 0 +     \
                                 filter_next_col * step);             \
            cal_helper<0, 0, step, 0, Vdotq_laneq_s32>(res, src, weight); \
            weight[1] = vld1q_s8(i_filter + filter_next_4oc * 1 +     \
                                 filter_next_col * step);             \
            cal_helper<1, 0, step, 1, Vdotq_laneq_s32>(res, src, weight); \
            break;                                                    \
        case 3:                                                       \
            weight[0] = vld1q_s8(i_filter + filter_next_4oc * 0 +     \
                                 filter_next_col * step);             \
            cal_helper<0, 0, step, 0, Vdotq_laneq_s32>(res, src, weight); \
            weight[1] = vld1q_s8(i_filter + filter_next_4oc * 1 +     \
                                 filter_next_col * step);             \
            cal_helper<1, 0, step, 1, Vdotq_laneq_s32>(res, src, weight); \
            weight[2] = vld1q_s8(i_filter + filter_next_4oc * 2 +     \
                                 filter_next_col * step);             \
            cal_helper<2, 0, step, 2, Vdotq_laneq_s32>(res, src, weight); \
            break;                                                    \
        default:                                                      \
            break;                                                    \
    }
                switch (filter_size) {
                    case 2:
                        UNROLL_CALL_RAW(2, CALC_PART);
                        break;
                    case 3:
                        UNROLL_CALL_RAW(3, CALC_PART);
                        break;
                    case 5:
                        UNROLL_CALL_RAW(5, CALC_PART);
                        break;
                    case 7:
                        UNROLL_CALL_RAW(7, CALC_PART);
                        break;
                    default:
                        break;
                }
#undef CALC_PART
                i_filter += filter_next_row;
                i_src += src_next_row;
            }
        }
        //! write the tile back through the post-process Op
        store_ocx_owx_remain_static<LOOP, ow_remain, Op>(res, op, dst,
                                                         dst_step);
    }
};
//! Stride-2 specialization: the source has been re-packed into even/odd
//! column planes, hence the src[2][3] registers and the step%2 / step/2
//! indexing into them.
template <typename dst_type, BiasMode bias_mode, typename Op, int ow_remain,
          int filter_size, int oc_interval, int ow_interval>
struct KernNeonSdotNCHW44<dst_type, 2, bias_mode, Op, ow_remain, filter_size,
                          oc_interval, ow_interval> {
    /**
     * Compute one output tile of oc_interval channels x ow_interval columns.
     * Parameters are identical to the stride-1 specialization above.
     */
    static void impl(dst_type* dst, const int dst_step, const int8_t* src,
                     const int ih, const int iw, const int8_t* filter,
                     const int32_t* bias, const int ic, const Op& op) {
        constexpr int FH = filter_size;
        constexpr int FW = filter_size;
        constexpr int filter_next_row =
                FW * OC_PACK_SIZE *
                IC_PACK_SIZE;  //! [OC/4, IC/4, FH, FW, 4OC, 4IC]
        const int filter_next_4oc =
                FH * FW * ic * OC_PACK_SIZE;  //! [OC/4, IC/4, FH, FW, 4OC, 4IC]
        const int src_next_ic = ih * iw;
        const int src_next_row = iw * IC_PACK_SIZE;
        //! source vectors per plane for stride-2 packed input
        constexpr int NSRC = (ow_interval * 2 + filter_size - 3) / 8 + 1;
        constexpr int LOOP = oc_interval / 4;  //! 4-OC groups per tile (<= 3)
        int32x4_t res[3][ow_interval];
        init_ocx_ow8<LOOP, bias_mode>(res, bias, OC_PACK_SIZE);
        for (int ic_idx = 0; ic_idx < ic; ic_idx += IC_PACK_SIZE) {
            const int8_t* i_src = src + ic_idx * src_next_ic;
            const int8_t* i_filter = filter + ic_idx * FH * FW * OC_PACK_SIZE;
            for (int fh_idx = 0; fh_idx < FH; ++fh_idx) {
                int8x16_t src[2][3];  //! shadows outer src: even/odd planes
                int8x16_t weight[3];
                //! byte distance between the even and odd column planes.
                //! NOTE(review): loop-invariant (iw is constant here); could
                //! be hoisted out of the fh loop.
                const int offset = megdnn::div_ceil(iw, 2) * IC_PACK_SIZE;
                load_helper<NSRC, 0, SIMD_LEN, 2, Vld1q_s8>(src, i_src, offset);
                //! NOTE: keep the ascending case order 1,2,3 below; the
                //! reversed order 3,2,1 was measured to be slower.
#define CALC_PART(step)                                                 \
    switch (LOOP) {                                                     \
        case 1:                                                         \
            weight[0] = vld1q_s8(i_filter + filter_next_4oc * 0 +       \
                                 filter_next_col * step);               \
            cal_helper<0, step % 2, step / 2, 0, Vdotq_laneq_s32>(res, src, \
                                                                  weight);  \
            break;                                                      \
        case 2:                                                         \
            weight[0] = vld1q_s8(i_filter + filter_next_4oc * 0 +       \
                                 filter_next_col * step);               \
            cal_helper<0, step % 2, step / 2, 0, Vdotq_laneq_s32>(res, src, \
                                                                  weight);  \
            weight[1] = vld1q_s8(i_filter + filter_next_4oc * 1 +       \
                                 filter_next_col * step);               \
            cal_helper<1, step % 2, step / 2, 1, Vdotq_laneq_s32>(res, src, \
                                                                  weight);  \
            break;                                                      \
        case 3:                                                         \
            weight[0] = vld1q_s8(i_filter + filter_next_4oc * 0 +       \
                                 filter_next_col * step);               \
            cal_helper<0, step % 2, step / 2, 0, Vdotq_laneq_s32>(res, src, \
                                                                  weight);  \
            weight[1] = vld1q_s8(i_filter + filter_next_4oc * 1 +       \
                                 filter_next_col * step);               \
            cal_helper<1, step % 2, step / 2, 1, Vdotq_laneq_s32>(res, src, \
                                                                  weight);  \
            weight[2] = vld1q_s8(i_filter + filter_next_4oc * 2 +       \
                                 filter_next_col * step);               \
            cal_helper<2, step % 2, step / 2, 2, Vdotq_laneq_s32>(res, src, \
                                                                  weight);  \
            break;                                                      \
        default:                                                        \
            break;                                                      \
    }
                //! unroll over the FW filter columns of this row
                switch (filter_size) {
                    case 2:
                        UNROLL_CALL_RAW(2, CALC_PART);
                        break;
                    case 3:
                        UNROLL_CALL_RAW(3, CALC_PART);
                        break;
                    case 5:
                        UNROLL_CALL_RAW(5, CALC_PART);
                        break;
                    case 7:
                        UNROLL_CALL_RAW(7, CALC_PART);
                        break;
                    default:
                        break;
                }
#undef CALC_PART
                i_filter += filter_next_row;
                i_src += src_next_row;
            }
        }
        //! apply op and write back, honoring the ow_remain tail
        store_ocx_owx_remain_static<LOOP, ow_remain, Op>(res, op, dst,
                                                         dst_step);
    }
};
} // namespace direct_dotprod_nchw44
} // namespace arm_common
} // namespace megdnn
#endif
//vim: syntax=cpp.doxygen
......@@ -536,6 +536,7 @@ inline void init_oc8_ow8(int32x4_t c[2][8], const int32_t* bias_ptr,
#undef BAIS_INIT
}
}
/////////////////////////init_ocx_ow8////////////////////
inline float32x4_t neon_vdupq_n(float val) {
......
......@@ -64,6 +64,8 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj {
AlgoDotU8DirectStride1 du8_direct_stride1_small_group{false};
AlgoDotU8DirectStride2 du8_direct_stride2_large_group{true};
AlgoDotU8DirectStride2 du8_direct_stride2_small_group{false};
AlgoDotS8Direct_NCHW44 ds8_direct_nchw44;
#endif
AlgoF32DirectNCHWNCHW44 f32_direct_stride2_nchw_nchw44;
......@@ -103,6 +105,8 @@ public:
direct_algos.emplace_back(&du8_direct_stride1_small_group);
direct_algos.emplace_back(&du8_direct_stride2_large_group);
direct_algos.emplace_back(&du8_direct_stride2_small_group);
direct_algos.emplace_back(&ds8_direct_nchw44);
#endif
direct_algos.emplace_back(&qu8_direct_stride2_large_group);
direct_algos.emplace_back(&qu8_direct_stride2_small_group);
......
......@@ -67,6 +67,8 @@ private:
class AlgoDotS8DirectStride2;
class AlgoDotU8DirectStride1;
class AlgoDotU8DirectStride2;
class AlgoDotS8Direct_NCHW44;
#endif
class AlgoF32Direct;
class AlgoF32DirectStride1;
......
......@@ -1809,6 +1809,81 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_QUINT8_STRIDE2_WITHDOTPROD) {
used1 / used0);
}
}
//! Speed comparison: ARMDOTS8DIRECT_NCHW44 (pinned on benchmark0) versus the
//! algorithm the default heuristic picks (benchmark1), on quantized int8
//! NCHW44_DOT convolutions.
TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_INT8_STRIDE1_WITHDOTPROD_NCHW44_DOT) {
    using namespace conv_bias;
    std::vector<TestArg> args;
    //! Append one NCHW44_DOT conv-bias case; skipped when the padded input
    //! is smaller than the kernel window.
    auto run = [&](size_t oc, size_t ic, size_t w, size_t h, size_t kernel,
                   size_t p, size_t stride, NonlineMode nonline_mode) {
        if (w + 2 * p < kernel || h + 2 * p < kernel)
            return;
        param::ConvBias param;
        param.stride_h = stride;
        param.stride_w = stride;
        param.pad_h = p;
        param.pad_w = p;
        param.nonlineMode = nonline_mode;
        param.format = param::ConvBias::Format::NCHW44_DOT;
        //! channel bias
        args.emplace_back(param, TensorShape{1, ic/4, h, w, 4},
                          TensorShape{oc/4, ic/4, kernel, kernel, 4, 4},
                          TensorShape{1, oc/4, 1, 1, 4});
    };
    for (size_t stride : {1, 2})
        for (size_t kernel : {2, 3, 5, 7})
            for(size_t oc : {64})
                for (NonlineMode nonline_mode : {NonlineMode::IDENTITY}) {
                    run(oc, oc, 56, 56, kernel, kernel / 2, stride, nonline_mode);
                }
    constexpr size_t RUN = 50;
    Benchmarker<ConvBias> benchmark0(handle());
    benchmark0.set_dtype(0, dtype::QuantizedS8(2.5f))
            .set_dtype(1, dtype::QuantizedS8(2.5f))
            .set_dtype(2, dtype::QuantizedS32(6.25f))
            .set_dtype(4, dtype::QuantizedS8(60.25f));
    benchmark0.set_display(false);
    benchmark0.set_times(RUN);
    //! force benchmark0 to run the new direct-dot algo only
    benchmark0.set_before_exec_callback(
            conv_bias::ConvBiasAlgoChecker<ConvBiasForward>("ARMDOTS8DIRECT_NCHW44"));
    Benchmarker<ConvBias> benchmark1(handle());
    benchmark1.set_dtype(0, dtype::QuantizedS8(2.5f))
            .set_dtype(1, dtype::QuantizedS8(2.5f))
            .set_dtype(2, dtype::QuantizedS32(6.25f))
            .set_dtype(4, dtype::QuantizedS8(60.25f));
    benchmark1.set_display(false);
    benchmark1.set_times(RUN);
    for (auto&& arg : args) {
        TensorLayout dst_layout;
        auto opr = handle()->create_operator<ConvBias>();
        opr->param() = arg.param;
        opr->deduce_layout({arg.src, dtype::Int8()},
                           {arg.filter, dtype::Int8()},
                           {arg.bias, dtype::Int32()}, {}, dst_layout);
        //! dst.nr_elems * IC * FH * FW * 2
        //! filter[1] == IC/4, so the factor 8.0 == 2 (MACs) * 4 (IC pack)
        float computations = dst_layout.total_nr_elems() * arg.filter[1] *
                             arg.filter[2] * arg.filter[3] * 8.0 /
                             (1024 * 1024 * 1024) * 1e3;
        auto used0 = benchmark0.set_param(arg.param).exec(
                             {arg.src, arg.filter, arg.bias, {}, {}}) /
                     RUN;
        auto used1 = benchmark1.set_param(arg.param).exec(
                             {arg.src, arg.filter, arg.bias, {}, {}}) /
                     RUN;
        printf("%s %s: Direct use: %f ms %f Gflops normal: %f ms %f GFlops "
               "speedup: %f\n",
               arg.src.to_string().c_str(), arg.filter.to_string().c_str(),
               used0, computations / used0, used1, computations / used1,
               used1 / used0);
    }
}
#endif
#endif
......
......@@ -155,7 +155,7 @@ std::vector<conv_bias::TestArg> get_nchw44_conv_bias_args(
if (support_sigmoid) {
nonlinemode.emplace_back(NLMode::SIGMOID);
}
std::vector<megdnn::BiasMode> bias_mode = {
megdnn::BiasMode::BROADCAST_CHANNEL_BIAS};
if (no_bias) {
......@@ -672,6 +672,63 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_QUINT8_DIRECT_STRIDE2_SMALL_GROUP) {
get_conv_bias_args({2, 3, 5, 7}, 2, false, true, true), handle(),
"ARMDOTU8STRD2_SMALL_GROUP");
}
/******************************dot int8x8x8 nchw44 ***********************/
//! stride-1 qint8x8x8 correctness for the NCHW44_DOT direct-dot conv algo.
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_DIRECT_DOT_NCHW44_S1_Q8x8x8) {
    using namespace conv_bias;
    auto test_args = get_nchw44_conv_bias_args({2, 3, 5, 7}, 1);
    for (TestArg& cur : test_args) {
        cur.param.format = param::ConvBias::Format::NCHW44_DOT;
    }
    checker_conv_bias_qint8x8x8(test_args, handle(), "ARMDOTS8DIRECT_NCHW44");
}
//! stride-1 qint8x8x32 correctness for the NCHW44_DOT direct-dot conv algo.
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_DIRECT_DOT_NCHW44_S1_Q8x8x32) {
    using namespace conv_bias;
    auto test_args = get_nchw44_conv_bias_args({2, 3, 5, 7}, 1, false, true, true);
    for (TestArg& cur : test_args) {
        cur.param.format = param::ConvBias::Format::NCHW44_DOT;
    }
    checker_conv_bias_qint8x8x32(test_args, handle(), "ARMDOTS8DIRECT_NCHW44");
}
//! stride-1 plain int8x8x32 correctness for the NCHW44_DOT direct-dot algo.
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_DIRECT_DOT_NCHW44_S1_8x8x32) {
    using namespace conv_bias;
    auto test_args = get_nchw44_conv_bias_args({2, 3, 5, 7}, 1, false, true, true);
    for (TestArg& cur : test_args) {
        cur.param.format = param::ConvBias::Format::NCHW44_DOT;
    }
    checker_conv_bias_int8x8x32_multi(test_args, handle(),
                                      "ARMDOTS8DIRECT_NCHW44");
}
//! stride-2 qint8x8x8 correctness for the NCHW44_DOT direct-dot conv algo.
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_DIRECT_DOT_NCHW44_S2_Q8x8x8) {
    using namespace conv_bias;
    auto test_args = get_nchw44_conv_bias_args({2, 3, 5, 7}, 2);
    for (TestArg& cur : test_args) {
        cur.param.format = param::ConvBias::Format::NCHW44_DOT;
    }
    checker_conv_bias_qint8x8x8(test_args, handle(), "ARMDOTS8DIRECT_NCHW44");
}
//! stride-2 qint8x8x32 correctness for the NCHW44_DOT direct-dot conv algo.
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_DIRECT_DOT_NCHW44_S2_Q8x8x32) {
    using namespace conv_bias;
    auto test_args = get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, true, true);
    for (TestArg& cur : test_args) {
        cur.param.format = param::ConvBias::Format::NCHW44_DOT;
    }
    checker_conv_bias_qint8x8x32(test_args, handle(), "ARMDOTS8DIRECT_NCHW44");
}
//! stride-2 plain int8x8x32 correctness for the NCHW44_DOT direct-dot algo.
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_DIRECT_DOT_NCHW44_S2_8x8x32) {
    using namespace conv_bias;
    auto test_args = get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, true, true);
    for (TestArg& cur : test_args) {
        cur.param.format = param::ConvBias::Format::NCHW44_DOT;
    }
    checker_conv_bias_int8x8x32_multi(test_args, handle(),
                                      "ARMDOTS8DIRECT_NCHW44");
}
#endif
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F23_4) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册