Commit bcf5691d authored by Megvii Engine Team, committed by Xinran Xu

feat(dnn/arm): add nchw_nchw44 i8i8i16 2x2 3x3 5x5 7x7 s1 s2 conv

GitOrigin-RevId: 8ef1541665121c01e3e934629b16c090e804cd2c
Parent c7b6ef35
@@ -6,7 +6,8 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
@@ -48,6 +49,7 @@ class ConvBiasImpl::AlgoI8x8x16Stride2 final : public AlgoBase {
const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index,
const CpuNDRange& workspace_ids);
public:
bool is_reproducible() const override { return true; }
const char* name() const override { return "I8816STRD2"; }
@@ -84,6 +86,21 @@ public:
const NCBKernSizeParam& param) const override;
};
class ConvBiasImpl::AlgoI8x8x16DirectNCHWNCHW44 final : public AlgoBase {
SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const;
public:
AlgoI8x8x16DirectNCHWNCHW44() {}
bool is_reproducible() const override { return true; }
const char* name() const override { return "I8816_CONV_NCHW_NCHW44"; }
bool usable(const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;
size_t get_workspace(const NCBKernSizeParam& param) const override;
virtual SmallVector<NCBKern> dispatch_kerns(
const NCBKernSizeParam& param) const override;
};
} // namespace arm_common
} // namespace megdnn
......
/**
* \file
* dnn/src/arm_common/conv_bias/int8x8x16/direct_nchw_nchw44_algo.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express
* or implied.
*/
#include "megdnn/oprs.h"
#include "src/arm_common/conv_bias/block_helper.h"
#include "src/arm_common/conv_bias/int8x8x16/algos.h"
#include "src/arm_common/conv_bias/int8x8x16/direct_nchw_nchw44_kern.h"
#include "src/arm_common/elemwise_op.h"
#include "src/common/opr_delegate.h"
#include "midout.h"
using namespace megdnn;
using namespace arm_common;
using conv_fun = std::function<void(
const WorkspaceBundle& bundle,
const ConvBiasImpl::NCBKernParam& kern_param,
const ConvBiasImpl::NCBKernIndex& ncb_index,
const CpuNDRange& workspace_ids, const CpuNDRange& ncb_range)>;
MIDOUT_DECL(megdnn_arm_common_conv_bias_i8i8i16_nchw_nchw44)
namespace {
static inline size_t get_perthread_cache_bytes(const int ic, const int ih2,
const int iw2) {
//! border_size is used to avoid reading illegal memory
constexpr int iw_expand = 8;
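//! each source byte is duplicated into 8 consecutive bytes (see
//! memcpy_s8_dup), so the per-thread cache must hold the expanded copy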
int border_size = 64 * 2;
return ic * ih2 * iw2 * sizeof(int8_t) * iw_expand + border_size;
}
static void get_rectified_size(
const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param, int& ih2,
int& iw2, int& oh2, int& ow2) {
int iw = param.isz[1];
int oh = param.osz[0];
int ow = param.osz[1];
oh2 = oh;
ow2 = ow;
constexpr int iw_expand = 8;
auto&& fm = param.filter_meta;
const int stride_h = static_cast<int>(fm.stride[0]);
const int filter_h = static_cast<int>(fm.spatial[0]);
const int ic = fm.icpg;
iw2 = iw + 2 * static_cast<int>(fm.padding[1]);
int block_oh = l2_block_helper(param.nr_threads, oh,
ic * iw2 * stride_h * iw_expand);
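//! producing block_oh output rows requires
//! (block_oh - 1) * stride_h + filter_h input rows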
ih2 = block_oh * stride_h + filter_h - stride_h;
}
static WorkspaceBundle get_bundle(const ConvBiasImpl::NCBKernSizeParam& param) {
auto&& fm = param.filter_meta;
int group = fm.group;
int ic = fm.icpg;
int oc = fm.ocpg;
int fh = fm.spatial[0];
int fw = fm.spatial[1];
int stride = fm.stride[0];
int ih2, iw2, oh2, ow2;
get_rectified_size(param, ih2, iw2, oh2, ow2);
constexpr int pack_oc = 8;
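//! the stride-1 packing duplicates every weight pair (layout
//! [oc / 8, ic, fh, fw, 16]), so its packed buffer is twice as large as the
//! stride-2 one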
const int weight_expand = stride == 1 ? 2 : 1;
size_t src_size = get_perthread_cache_bytes(ic, ih2, iw2);
size_t weight_size = group * round_up(oc, 8) * ic * fh * fw *
sizeof(int8_t) * weight_expand;
size_t bias_size = 0;
if (param.bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS &&
oc % pack_oc != 0) {
bias_size = round_up(oc, 8) * sizeof(int16_t);
}
return {nullptr, {src_size * param.nr_threads, weight_size, bias_size}};
}
static inline void copy_pad_src(int8_t* sptr_base, const int8_t* sptr_origin,
int ph, int pw, int pad_right, int ih, int iw,
int iw2, int pad_top, int pad_bottom, int ic,
int ic_stride) {
constexpr int iw_expand = 8;
MEGDNN_MARK_USED_VAR(ph);
rep(ic_idx, ic) {
const int8_t* sptr = sptr_origin + ic_idx * ic_stride;
memset(sptr_base, 0, sizeof(int8_t) * iw2 * pad_top * iw_expand);
sptr_base += iw2 * pad_top * iw_expand;
rep(ih_idx, ih) {
memset(sptr_base, 0, sizeof(int8_t) * pw * iw_expand);
sptr_base += pw * iw_expand;
memcpy_s8_dup(sptr_base, sptr, iw);
sptr_base += iw * iw_expand;
sptr += iw;
memset(sptr_base, 0, sizeof(int8_t) * pad_right * iw_expand);
sptr_base += pad_right * iw_expand;
}
memset(sptr_base, 0, sizeof(int8_t) * iw2 * pad_bottom * iw_expand);
sptr_base += iw2 * pad_bottom * iw_expand;
}
}
static void pack_weight(const WorkspaceBundle& bundle,
const ConvBiasImpl::NCBKernParam& kern_param,
const ConvBiasImpl::NCBKernIndex& ncb_index) {
const int group_id = ncb_index.ndrange_id[0];
int fh = kern_param.filter_meta.spatial[0];
int fw = kern_param.filter_meta.spatial[1];
int oc = kern_param.filter_meta.ocpg;
int ic = kern_param.filter_meta.icpg;
int oc_block = oc;
int stride = kern_param.filter_meta.stride[0];
constexpr int oc_idx = 0;
const int8_t* fptr =
kern_param.filter<dt_int8>(group_id) + oc_idx * fh * fw * ic;
auto packed_weight = reinterpret_cast<int8_t*>(bundle.get(1)) +
group_id * oc * ic * fh * fw + oc_idx * ic * fh * fw;
switch (stride) {
case 1:
i8i8i16_direct_nchw_nchw44::pack_weight_int8_nchw_nchw44<1>(
fptr, packed_weight, oc_block, fh, fw, ic);
break;
case 2:
i8i8i16_direct_nchw_nchw44::pack_weight_int8_nchw_nchw44<2>(
fptr, packed_weight, oc_block, fh, fw, ic);
break;
default:
break;
}
constexpr int pack_oc = 8;
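//! when oc is not a multiple of 8, copy the channel bias into a workspace
//! buffer padded to round_up(oc, 8) elements so the kernel can always load
//! a full group of 8 output channels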
if (kern_param.bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS &&
oc % pack_oc != 0) {
auto packed_bias = reinterpret_cast<int16_t*>(bundle.get(2));
memcpy(packed_bias, kern_param.bias_ptr,
round_up(oc, 8) * sizeof(int16_t));
}
}
template <size_t filter_size, BiasMode bias_mode, typename Op, size_t stride>
static void do_conv_kern(const WorkspaceBundle& bundle,
const ConvBiasImpl::NCBKernParam& kern_param,
const ConvBiasImpl::NCBKernIndex& ncb_index,
const CpuNDRange&, const CpuNDRange&) {
const int oh = kern_param.osz[0];
const int ow = kern_param.osz[1];
const int fh = kern_param.filter_meta.spatial[0];
const int fw = kern_param.filter_meta.spatial[1];
const int ic = kern_param.filter_meta.icpg;
const int oc = kern_param.filter_meta.ocpg;
const int ih = kern_param.isz[0];
const int iw = kern_param.isz[1];
const int stride_h = stride;
const int ph = kern_param.filter_meta.padding[0];
const int pw = kern_param.filter_meta.padding[1];
int ih2 = 0;
int iw2 = 0;
int oh2 = 0;
int ow2 = 0;
get_rectified_size(kern_param, ih2, iw2, oh2, ow2);
constexpr int src_expand = 8;
constexpr int weight_expand = stride == 1 ? 2 : 1;
constexpr int pack_c = 4;
const int batch_id = ncb_index.ndrange_id[0];
const int group_id = ncb_index.ndrange_id[1];
constexpr int oc_idx = 0;
int oc_block = oc;
int oh_block = l2_block_helper(kern_param.nr_threads, oh,
ic * iw2 * stride_h * src_expand);
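//! oh is blocked the same way as in get_rectified_size, so that each
//! thread's expanded input slice stays resident in the L2 cache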
const int oh_idx = ncb_index.ndrange_id[2];
const int oh_block_real = std::min(oh - oh_idx * oh_block, oh_block);
const int ih_real = oh_block_real * stride_h + fh - stride_h;
const int src_top_pad = std::max(ph - oh_idx * oh_block * stride_h, 0);
const int src_bottom_pad = std::max(
(oh_idx * oh_block + oh_block_real - 1) * stride_h + fh - ih - ph,
0);
const int remain_right_pad = std::max(iw2 - iw - pw, 0);
const int src_offset = std::max(oh_idx * oh_block * stride_h - ph, 0) * iw;
const int8_t* origin_sptr =
static_cast<const int8_t*>(
kern_param.src<int8_t>(batch_id, group_id, 0, 1, 1)) +
src_offset;
const size_t src_size = get_perthread_cache_bytes(ic, ih2, iw2);
int8_t* sptr = static_cast<int8_t*>(bundle.get(0)) +
ncb_index.thread_id * src_size;
copy_pad_src(sptr, origin_sptr, ph, pw, remain_right_pad,
ih_real - src_top_pad - src_bottom_pad, iw, iw2, src_top_pad,
src_bottom_pad, ic, ih * iw);
//! pack weight
auto packed_weight =
reinterpret_cast<int8_t*>(bundle.get(1)) +
(group_id * oc * ic * fh * fw + oc_idx * ic * fh * fw) *
weight_expand;
//! get param
int16_t* dst = kern_param.dst<int16_t>(batch_id, group_id) +
oh_idx * oh_block * ow * pack_c;
const int16_t* bptr =
kern_param.bias<dt_int16>(batch_id, group_id) + oc_idx;
constexpr int pack_oc = 8;
if (kern_param.bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS &&
oc % pack_oc != 0) {
bptr = reinterpret_cast<int16_t*>(bundle.get(2));
}
Op op;
i8i8i16_direct_nchw_nchw44::conv_direct_i8i8i16_nchw_nchw44<
bias_mode, Op, filter_size, stride>(
sptr, packed_weight, bptr, nullptr, dst, oc_block, ic, ih_real, iw2,
oh, oh_block_real, ow, op, ph, pw);
}
} // namespace
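//! usable only for the "first conv" case: a plain NCHW int8 input with
//! fewer than 4 input channels (e.g. an RGB image) convolved into an
//! NCHW44 int16 output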
bool ConvBiasImpl::AlgoI8x8x16DirectNCHWNCHW44::usable(
const NCBKernSizeParam& param, AlgoSelectionStrategy) const {
auto&& fm = param.filter_meta;
auto fh = fm.spatial[0];
int oc = fm.ocpg;
bool ok_type = ((param.src_type.enumv() == DTypeEnum::Int8 &&
param.filter_type.enumv() == DTypeEnum::Int8 &&
(param.dst_type.enumv() == DTypeEnum::Int16))) &&
(fm.format == param::Convolution::Format::NCHW44);
bool ok_src_dst = fm.icpg < 4 && (oc % 4 == 0 && oc >= 4) && fm.group == 1;
bool ok_filter = fm.spatial_ndim == 2 && fh == fm.spatial[1] &&
(fh == 2 || fh == 3 || fh == 5 || fh == 7);
bool ok_slide = fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
fm.stride[0] == fm.stride[1] &&
(fm.stride[0] == 2 || fm.stride[0] == 1);
bool ok_conv = !fm.should_flip && param.bias_mode != BiasMode::BIAS &&
param.nonlineMode == param::ConvBias::NonlineMode::IDENTITY;
bool available = ok_type && ok_src_dst && ok_filter && ok_slide && ok_conv;
return available;
}
size_t ConvBiasImpl::AlgoI8x8x16DirectNCHWNCHW44::get_workspace(
const NCBKernSizeParam& param) const {
return get_bundle(param).total_size_in_bytes();
}
SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoI8x8x16DirectNCHWNCHW44::dispatch_kerns(
const NCBKernSizeParam& param) const {
auto fm = param.filter_meta;
const int batch = param.n;
const int group = fm.group;
WorkspaceBundle bundle = get_bundle(param);
conv_fun do_conv_fun = nullptr;
//! NOTE: remain_w is not used when generating the midout hash, so the
//! kernel stays compatible with shapes that change at runtime
#define DO_CONV_KERN_FUN(stride, filter, bias_mode, op) \
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_i8i8i16_nchw_nchw44, \
midout_iv(#stride #filter #bias_mode #op##_hash)) { \
do_conv_fun = do_conv_kern<filter, bias_mode, op, stride>; \
} \
MIDOUT_END();
#define GET_OP_PARAM(stride, filter, bias_mode) \
switch (param.nonlineMode) { \
case param::ConvBias::NonlineMode::IDENTITY: \
DO_CONV_KERN_FUN(stride, filter, bias_mode, NoneOp<dt_int16>) \
break; \
default: \
megdnn_assert(0); \
break; \
}
#define GET_BIAS_MODE_PARAM(stride, filter) \
switch (param.bias_mode) { \
case BiasMode::NO_BIAS: \
GET_OP_PARAM(stride, filter, BiasMode::NO_BIAS) \
break; \
case BiasMode::BROADCAST_CHANNEL_BIAS: \
GET_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS) \
break; \
default: \
megdnn_assert(0); \
break; \
}
#define DISPATCH_CONV_KERN(stride) \
switch (param.filter_meta.spatial[0]) { \
case 2: \
GET_BIAS_MODE_PARAM(stride, 2) \
break; \
case 3: \
GET_BIAS_MODE_PARAM(stride, 3) \
break; \
case 5: \
GET_BIAS_MODE_PARAM(stride, 5) \
break; \
case 7: \
GET_BIAS_MODE_PARAM(stride, 7) \
break; \
default: \
megdnn_assert(0); \
break; \
}
switch (param.filter_meta.stride[0]) {
case 1:
DISPATCH_CONV_KERN(1);
break;
case 2:
DISPATCH_CONV_KERN(2);
break;
default:
megdnn_throw(ssprintf("Unsupport stride size %u for the first conv",
param.filter_meta.stride[0])
.c_str());
break;
}
#undef DO_CONV_KERN_FUN
#undef GET_REMAIN_W_PARAM
#undef GET_OP_PARAM
#undef GET_BIAS_MODE_PARAM
#undef DISPATCH_CONV_KERN
megdnn_assert(do_conv_fun);
constexpr int iw_expand = 8;
SmallVector<ConvBiasImpl::NCBKern> ret_kerns;
int oh = param.osz[0];
int ih2, iw2, oh2, ow2;
const int stride_h = static_cast<int>(fm.stride[0]);
const int ic = fm.icpg;
get_rectified_size(param, ih2, iw2, oh2, ow2);
int oh_block = l2_block_helper(param.nr_threads, oh,
ic * iw2 * stride_h * iw_expand);
auto do_pack_weight = [bundle](const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index) mutable {
bundle.set(kern_param.workspace_ptr);
pack_weight(bundle, kern_param, ncb_index);
};
ret_kerns.push_back({do_pack_weight, {static_cast<size_t>(group)}});
CpuNDRange ncb_range = {static_cast<size_t>(batch),
static_cast<size_t>(group),
static_cast<size_t>(div_ceil(oh, oh_block))};
auto do_conv = [bundle, do_conv_fun, ncb_range](
const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index) mutable {
bundle.set(kern_param.workspace_ptr);
do_conv_fun(bundle, kern_param, ncb_index, ncb_index.ndrange_id,
ncb_range);
};
ret_kerns.push_back({do_conv, ncb_range});
return ret_kerns;
}
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/arm_common/conv_bias/int8x8x16/direct_nchw_nchw44_kern.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "megdnn/arch.h"
#include "src/arm_common/conv_bias/intrinsic_helper.h"
#include "src/arm_common/conv_bias/opr_impl.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/unroll_macro.h"
#include "src/common/utils.h"
#include "src/fallback/conv_bias/common.h"
namespace megdnn {
namespace arm_common {
namespace i8i8i16_direct_nchw_nchw44 {
/**
* @brief repack the filter from the NCHW44 layout into the layout consumed
* by the direct kernels:
* stride2 from [oc / 4, fh, fw, ic, 4] to [oc / 8, ic, fh, fw, 8]
* stride1 from [oc / 4, fh, fw, ic, 4] to [oc / 8, ic, fh, fw, 16]
* @param in_ptr source filter in NCHW44 layout
* @param dst_ptr packed destination buffer
* @param oc number of output channels
* @param kh filter height
* @param kw filter width
* @param ic number of input channels
*/
template <int stride>
inline void pack_weight_int8_nchw_nchw44(const int8_t* in_ptr, int8_t* dst_ptr,
const int oc, const int kh,
const int kw, const int ic);
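//! A scalar sketch of the stride-2 mapping for clarity (illustrative only,
//! assuming oc % 8 == 0; the specializations below move whole int32 groups
//! of four channels instead of single bytes):
//!
//! for (int oc_pair = 0; oc_pair < oc / 8; ++oc_pair)
//!     for (int k = 0; k < kh * kw; ++k)
//!         for (int ic_idx = 0; ic_idx < ic; ++ic_idx)
//!             for (int lane = 0; lane < 8; ++lane)
//!                 dst_ptr[((oc_pair * ic + ic_idx) * kh * kw + k) * 8 + lane] =
//!                         in_ptr[((oc_pair * 2 + lane / 4) * kh * kw * ic +
//!                                 k * ic + ic_idx) * 4 + lane % 4];
//!
//! The stride-1 specialization writes each 8-lane group twice, giving the
//! [oc / 8, ic, fh, fw, 16] layout.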
template <>
inline void pack_weight_int8_nchw_nchw44<2>(const int8_t* in_ptr,
int8_t* dst_ptr, const int oc,
const int kh, const int kw,
const int ic) {
constexpr int in_pack_oc = 4;
constexpr int out_pack_oc = 8;
constexpr int out_pair = 2;
const int filter_size = kh * kw;
const int in_oc_stride = filter_size * ic;
const int oc_remain = oc % out_pack_oc;
const int oc_end = oc - oc_remain;
int32_t* pack_dst_ptr = (int32_t*)dst_ptr;
for (int oc_idx = 0; oc_idx < oc_end; oc_idx += out_pack_oc) {
const int32_t* in_oc0_ptr = (int32_t*)(in_ptr + oc_idx * in_oc_stride);
const int32_t* in_oc1_ptr =
(int32_t*)(in_ptr + (oc_idx + in_pack_oc) * in_oc_stride);
for (int filter_idx = 0; filter_idx < filter_size; ++filter_idx) {
for (int ic_idx = 0; ic_idx < ic; ++ic_idx) {
int32_t temp0 = *in_oc0_ptr++;
int32_t temp1 = *in_oc1_ptr++;
pack_dst_ptr[(ic_idx * filter_size + filter_idx) * out_pair +
0] = temp0;
pack_dst_ptr[(ic_idx * filter_size + filter_idx) * out_pair +
1] = temp1;
}
}
pack_dst_ptr += ic * filter_size * out_pair;
}
if (oc_remain > 0) {
const int32_t* in_oc0_ptr = (int32_t*)(in_ptr + oc_end * in_oc_stride);
for (int filter_idx = 0; filter_idx < filter_size; ++filter_idx) {
for (int ic_idx = 0; ic_idx < ic; ++ic_idx) {
int32_t temp0 = *in_oc0_ptr++;
int32_t temp1 = 0;
pack_dst_ptr[(ic_idx * filter_size + filter_idx) * out_pair +
0] = temp0;
pack_dst_ptr[(ic_idx * filter_size + filter_idx) * out_pair +
1] = temp1;
}
}
}
}
template <>
inline void pack_weight_int8_nchw_nchw44<1>(const int8_t* in_ptr,
int8_t* dst_ptr, const int oc,
const int kh, const int kw,
const int ic) {
constexpr int in_pack_oc = 4;
constexpr int out_pack_oc = 8;
constexpr int out_pair = 4;
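//! each int32 pair (oc 0..3 and oc 4..7) is stored twice, producing the
//! [oc / 8, ic, fh, fw, 16] layout consumed by the stride-1 kernel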
const int filter_size = kh * kw;
const int in_oc_stride = filter_size * ic;
const int oc_remain = oc % out_pack_oc;
const int oc_end = oc - oc_remain;
int32_t* pack_dst_ptr = (int32_t*)dst_ptr;
for (int oc_idx = 0; oc_idx < oc_end; oc_idx += out_pack_oc) {
const int32_t* in_oc0_ptr = (int32_t*)(in_ptr + oc_idx * in_oc_stride);
const int32_t* in_oc1_ptr =
(int32_t*)(in_ptr + (oc_idx + in_pack_oc) * in_oc_stride);
for (int filter_idx = 0; filter_idx < filter_size; ++filter_idx) {
for (int ic_idx = 0; ic_idx < ic; ++ic_idx) {
int32_t temp0 = *in_oc0_ptr++;
int32_t temp1 = *in_oc1_ptr++;
pack_dst_ptr[(ic_idx * filter_size + filter_idx) * out_pair +
0] = temp0;
pack_dst_ptr[(ic_idx * filter_size + filter_idx) * out_pair +
1] = temp1;
pack_dst_ptr[(ic_idx * filter_size + filter_idx) * out_pair +
2] = temp0;
pack_dst_ptr[(ic_idx * filter_size + filter_idx) * out_pair +
3] = temp1;
}
}
pack_dst_ptr += ic * filter_size * out_pair;
}
if (oc_remain > 0) {
const int32_t* in_oc0_ptr = (int32_t*)(in_ptr + oc_end * in_oc_stride);
for (int filter_idx = 0; filter_idx < filter_size; ++filter_idx) {
for (int ic_idx = 0; ic_idx < ic; ++ic_idx) {
int32_t temp0 = *in_oc0_ptr++;
int32_t temp1 = 0;
pack_dst_ptr[(ic_idx * filter_size + filter_idx) * out_pair +
0] = temp0;
pack_dst_ptr[(ic_idx * filter_size + filter_idx) * out_pair +
1] = temp1;
pack_dst_ptr[(ic_idx * filter_size + filter_idx) * out_pair +
2] = temp0;
pack_dst_ptr[(ic_idx * filter_size + filter_idx) * out_pair +
3] = temp1;
}
}
}
}
template <BiasMode bias_mode, typename Op, int filter_size, int stride>
void conv_direct_i8i8i16_nchw_nchw44(const int8_t* src, const int8_t* filter,
const int16_t* bias, int8_t*, int16_t* dst,
const int oc, const int ic, const int ih,
const int iw, const int oh,
const int oh_block, const int ow,
const Op& op, const int, const int);
} // namespace i8i8i16_direct_nchw_nchw44
} // namespace arm_common
} // namespace megdnn
// vim: syntax=cpp.doxygen
/**
* \file
* dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl_s1.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl.h"
INSTANCE_CONV(2, 1);
INSTANCE_CONV(3, 1);
INSTANCE_CONV(5, 1);
INSTANCE_CONV(7, 1);
// vim: syntax=cpp.doxygen
\ No newline at end of file
/**
* \file
* dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl_s2.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl.h"
INSTANCE_CONV(2, 2);
INSTANCE_CONV(3, 2);
INSTANCE_CONV(5, 2);
INSTANCE_CONV(7, 2);
// vim: syntax=cpp.doxygen
\ No newline at end of file
@@ -375,6 +375,89 @@ __ai void store_ocx_ow8_remain_static_dt(T& c, const Op& op, T2 dst_ptr,
StoreOcxOw8Remain<c_dim, ow_remain, Op, T, T2, T3>::impl(c, op, dst_ptr,
ld_dst_oc);
}
////////////////////Store_OC4_OW8_Remain/////////////////////////
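//! store an int16x8 accumulator block of up to 8 output pixels; each
//! register holds two 4-channel output groups: the low half is written to
//! dst_ptr and, when out_group == 2, the high half to dst_ptr + ld_dst_oc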
template <int c_dim, int ow_block, int nr_group, int out_group, typename T,
typename T2, typename T3>
struct StoreOc4Ow8Remain {
static __ai void impl(T& c, T2 dst_ptr, int ld_dst_oc, const int ow_remain);
};
#define cb(step) \
vst1q_lane_s64((int64_t*)(dst_ptr + step * 4), \
vreinterpretq_s64_s16(c[0][step]), 0); \
vst1q_lane_s64((int64_t*)(dst_ptr + step * 4 + ld_dst_oc), \
vreinterpretq_s64_s16(c[0][step]), 1);
#define cb2(step) \
vst1q_lane_s64((int64_t*)(dst_ptr + step * 4), \
vreinterpretq_s64_s16(c[0][step]), 0);
#define cb_case(step) \
case step: \
UNROLL_CALL_RAW(step, cb); \
break;
#define cb_case2(step) \
case step: \
UNROLL_CALL_RAW(step, cb2); \
break;
template <typename T, typename T2, typename T3>
struct StoreOc4Ow8Remain<1, 8, 2, 2, T, T2, T3> {
static __ai void impl(T& c, T2 dst_ptr, int ld_dst_oc,
const int ow_remain) {
if (ow_remain == 8) {
UNROLL_CALL_RAW(8, cb)
} else {
switch (ow_remain) {
cb_case(7);
cb_case(6);
cb_case(5);
cb_case(4);
cb_case(3);
cb_case(2);
cb_case(1);
default:
break;
}
}
}
};
template <typename T, typename T2, typename T3>
struct StoreOc4Ow8Remain<1, 8, 2, 1, T, T2, T3> {
static __ai void impl(T& c, T2 dst_ptr, int, const int ow_remain) {
if (ow_remain == 8) {
UNROLL_CALL_RAW(8, cb2)
} else {
switch (ow_remain) {
cb_case2(7);
cb_case2(6);
cb_case2(5);
cb_case2(4);
cb_case2(3);
cb_case2(2);
cb_case2(1);
default:
break;
}
}
}
};
#undef cb
#undef cb2
#undef cb_case
#undef cb_case2
template <int c_dim, int ow_block, int nr_group, int out_group, typename T,
typename T2>
__ai void store_oc4_ow8_remain_static(T& c, T2 dst_ptr, const int ld_dst_oc,
const int ow_remain) {
StoreOc4Ow8Remain<c_dim, ow_block, nr_group, out_group, T, T2, T2>::impl(
c, dst_ptr, ld_dst_oc, ow_remain);
}
////////////////////Store_OC8_OW8_Remain/////////////////////////
template <int ow_remain, typename Op>
@@ -548,13 +631,18 @@ __ai float32x4_t neon_vdupq_n(float val) {
__ai int32x4_t neon_vdupq_n(int val) {
return vdupq_n_s32(val);
}
__ai int16x8_t neon_vdupq_n(int16_t val) {
return vdupq_n_s16(val);
}
__ai float32x4_t neon_vld1q(const float* ptr) {
return vld1q_f32(ptr);
}
__ai int32x4_t neon_vld1q(const int* ptr) {
return vld1q_s32(ptr);
}
__ai int16x8_t neon_vld1q(const int16_t* ptr) {
return vld1q_s16(ptr);
}
template <int c_dim, BiasMode bias_mode, int ow_block, typename T, typename T2>
struct InitOcxOw8 {
@@ -725,6 +813,39 @@ __ai void init_ocx_ow4(T& c, const int32_t* bias_ptr, int oc_step) {
}
///////////////////////////////////////
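//! expand-copy used by copy_pad_src: every int8 of inptr is duplicated into
//! 8 consecutive bytes of outptr (which must hold count * 8 bytes), so one
//! NCHW source pixel lines up with the 8 weight lanes of a packed oc pair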
static inline void memcpy_s8_dup(int8_t* outptr, const int8_t* inptr,
int count) {
constexpr int expand = 8;
for (; count >= 8; count -= 8) {
int8x8_t in = vld1_s8(inptr);
int8x8_t in0 = vdup_lane_s8(in, 0);
int8x8_t in1 = vdup_lane_s8(in, 1);
int8x8_t in2 = vdup_lane_s8(in, 2);
int8x8_t in3 = vdup_lane_s8(in, 3);
int8x8_t in4 = vdup_lane_s8(in, 4);
int8x8_t in5 = vdup_lane_s8(in, 5);
int8x8_t in6 = vdup_lane_s8(in, 6);
int8x8_t in7 = vdup_lane_s8(in, 7);
vst1_s8(outptr + 0 * 8, in0);
vst1_s8(outptr + 1 * 8, in1);
vst1_s8(outptr + 2 * 8, in2);
vst1_s8(outptr + 3 * 8, in3);
vst1_s8(outptr + 4 * 8, in4);
vst1_s8(outptr + 5 * 8, in5);
vst1_s8(outptr + 6 * 8, in6);
vst1_s8(outptr + 7 * 8, in7);
inptr += 8;
outptr += 8 * expand;
}
for (; count > 0; --count) {
int8x8_t in0 = vld1_dup_s8(inptr++);
vst1_s8(outptr, in0);
outptr += 1 * expand;
}
}
} // namespace
} // namespace megdnn
#undef __ai
......
@@ -71,6 +71,7 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj {
AlgoI8x8x16Direct i8x8x16_direct;
AlgoI8x8x16Stride2 i8x8x16_stride2;
AlgoI8x8x16Stride2Filter2 i8x8x16_stride2_filter2;
AlgoI8x8x16DirectNCHWNCHW44 i8x8x16_nchw_nchw44;
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
AlgoF16Direct f16_direct;
AlgoF16DirectStride1 f16_direct_stride1;
@@ -107,6 +108,7 @@ public:
direct_algos.emplace_back(&i8x8x16_direct);
direct_algos.emplace_back(&i8x8x16_stride2_filter2);
direct_algos.emplace_back(&i8x8x16_stride2);
direct_algos.emplace_back(&i8x8x16_nchw_nchw44);
direct_algos.emplace_back(&f32_direct_stride2_nchw_nchw44);
direct_algos.emplace_back(&f32_chanel_wise_nchw44);
......
@@ -81,6 +81,7 @@ private:
class AlgoI8x8x16Direct;
class AlgoI8x8x16Stride2;
class AlgoI8x8x16Stride2Filter2;
class AlgoI8x8x16DirectNCHWNCHW44;
class AlgoS8WinogradF23_8x8;
class AlgoS8CF32WinogradF23_4x4_NCHW44;
class AlgoS8WinogradF23_8x8_NCHW44;
......
@@ -9,7 +9,6 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/fallback/convolution/opr_impl.h"
#include "src/common/algo_chooser.h"
#include "src/common/metahelper.h"
#include "src/common/opr_delegate.h"
@@ -19,6 +18,7 @@
#include "src/fallback/conv_bias/conv1x1/algos_conv1x1_gemv.h"
#include "src/fallback/conv_bias/im2col/algos.h"
#include "src/fallback/conv_bias/opr_impl.h"
#include "src/fallback/convolution/opr_impl.h"
#include "src/naive/convolution/algorithms.h"
#include "src/naive/handle.h"
@@ -479,7 +479,8 @@ const T* ConvBiasImpl::NCBKernParam::filter(size_t group_pack_id,
//! four formats of weight layout
//! 1. {oc/4, ic/4, fh, fw, 4, 4},
//! 2. {g, oc/4, ic/4, fh, fw, 4, 4},
//! 3. {g/4, fh, fw, 1, 1, 4},
//! 4. {oc/4, fh, fw, ic, 4}
megdnn_assert((icpg % 4 == 0 && ocpg % 4 == 0) ||
(group % 4 == 0 && icpg == 1 && ocpg == 1 &&
pack_group_size > 1) ||
......
@@ -116,7 +116,8 @@ CB_TEST(H_SWISH);
#if MEGDNN_WITH_BENCHMARK
static void benchmark_convbias(Handle* handle, std::string int_name,
std::string float_name, bool is_fp32 = false,
bool is_8x8x16 = false) {
constexpr size_t RUNS = 30;
Benchmarker<ConvBias> benchmarker_int(handle);
@@ -142,6 +143,13 @@ static void benchmark_convbias(Handle* handle, std::string int_name,
.set_dtype(2, dtype::Float32())
.set_dtype(4, dtype::Float32())
.set_display(false);
} else if (is_8x8x16) {
benchmarker_nchw44.set_times(RUNS)
.set_dtype(0, dtype::Int8())
.set_dtype(1, dtype::Int8())
.set_dtype(2, dtype::Int16())
.set_dtype(4, dtype::Int16())
.set_display(false);
} else {
benchmarker_nchw44.set_times(RUNS)
.set_dtype(0, dtype::QuantizedS8(2.5))
@@ -163,6 +171,9 @@ static void benchmark_convbias(Handle* handle, std::string int_name,
size_t FS, size_t stride, bool input_nchw = false) {
param::ConvBias param;
param.nonlineMode = param::ConvBias::NonlineMode::RELU;
if (is_8x8x16) {
param.nonlineMode = param::ConvBias::NonlineMode::IDENTITY;
}
param.stride_h = stride;
param.stride_w = stride;
@@ -235,6 +246,7 @@ static void benchmark_convbias(Handle* handle, std::string int_name,
run(1, 512, 512, 7, 7, 3, 1, false);
} else {
run(1, 1, 4, 112, 112, 2, 2, true);
run(1, 3, 8, 224, 224, 3, 2, true);
run(1, 3, 32, 224, 224, 3, 2, true);
run(1, 3, 32, 224, 224, 5, 2, true);
run(1, 3, 64, 224, 224, 7, 2, true);
@@ -271,11 +283,15 @@ TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_NCHW44) {
"IM2COLMATMUL:AARCH64_F32K8X12X1:192", true);
benchmark_convbias(handle(), "IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16:384",
"IM2COLMATMUL:AARCH64_F32K8X12X1:192", false);
benchmark_convbias(handle(), "IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16:384",
"IM2COLMATMUL:AARCH64_F32K8X12X1:192", false, true);
#else
benchmark_convbias(handle(), "IM2COLMATMUL:ARMV7_INT8X8X32_K4X8X8:384",
"IM2COLMATMUL:ARMV7_F32:192", true);
benchmark_convbias(handle(), "IM2COLMATMUL:ARMV7_INT8X8X32_K4X8X8:384",
"IM2COLMATMUL:ARMV7_F32:192", false);
benchmark_convbias(handle(), "IM2COLMATMUL:ARMV7_INT8X8X32_K4X8X8:384",
"IM2COLMATMUL:ARMV7_F32:192", false, true);
#endif
}
TEST_F(ARM_COMMON_MULTI_THREADS, BENCHMARK_CONVBIAS_NCHW44) {
......
@@ -449,7 +449,18 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT16_STRIDE2) {
get_conv_bias_args({2, 3, 5}, 2, false, true, true), handle(),
"I8816STRD2");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT16_NCHW_NCHW44_S2) {
checker_conv_bias_int8x8x16(
get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, false, true,
true),
handle(), "I8816_CONV_NCHW_NCHW44");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT16_NCHW_NCHW44_S1) {
checker_conv_bias_int8x8x16(
get_nchw44_conv_bias_args({2, 3, 5, 7}, 1, false, false, true,
true),
handle(), "I8816_CONV_NCHW_NCHW44");
}
/**********************************algo 8-8-32 direct************************/
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT32_STRIDE1) {
checker_conv_bias_int8x8x32_multi(
......