提交 7b0dbe6a 编写于 作者: M Megvii Engine Team 提交者: Xu Xinran

fix(dnn/arm): fix stride 1 support for int8 nchw_nchw44

GitOrigin-RevId: 9d718eb7a4dae3c2724ea07ba2b639fbfb319f78
上级 198f3eb5
...@@ -37,7 +37,7 @@ static inline size_t get_perthread_cache_bytes(const int ic, const int ih2, ...@@ -37,7 +37,7 @@ static inline size_t get_perthread_cache_bytes(const int ic, const int ih2,
static void get_rectified_size( static void get_rectified_size(
const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param, int& ih2, const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param, int& ih2,
int& iw2, int& oh2, int& ow2) { int& iw2, int& oh2, int& ow2) {
constexpr int cacheline = 64 / sizeof(float); constexpr int nr_elements_in_cacheline = 64 / sizeof(float);
int ic = param.filter_meta.icpg; int ic = param.filter_meta.icpg;
int iw = param.isz[1]; int iw = param.isz[1];
int oh = param.osz[0]; int oh = param.osz[0];
...@@ -52,7 +52,8 @@ static void get_rectified_size( ...@@ -52,7 +52,8 @@ static void get_rectified_size(
int block_oh = l2_block_helper(param.nr_threads, oh, int block_oh = l2_block_helper(param.nr_threads, oh,
ic * iw * sizeof(float) * stride_h); ic * iw * sizeof(float) * stride_h);
ih2 = block_oh * stride_h + filter_h - stride_h; ih2 = block_oh * stride_h + filter_h - stride_h;
iw2 = round_up(iw + 2 * static_cast<int>(fm.padding[1]), cacheline); iw2 = round_up(iw + 2 * static_cast<int>(fm.padding[1]),
nr_elements_in_cacheline);
} }
static WorkspaceBundle get_bundle(const ConvBiasImpl::NCBKernSizeParam& param) { static WorkspaceBundle get_bundle(const ConvBiasImpl::NCBKernSizeParam& param) {
......
...@@ -90,9 +90,9 @@ public: ...@@ -90,9 +90,9 @@ public:
const NCBKernSizeParam& param) const override; const NCBKernSizeParam& param) const override;
}; };
class ConvBiasImpl::AlgoS8DirectStride2NCHWNCHW44 final : public AlgoBase { class ConvBiasImpl::AlgoS8DirectNCHWNCHW44 final : public AlgoBase {
public: public:
AlgoS8DirectStride2NCHWNCHW44() {} AlgoS8DirectNCHWNCHW44() {}
bool is_reproducible() const override { return true; } bool is_reproducible() const override { return true; }
const char* name() const override { return "S8_CONV_NCHW_NCHW44"; } const char* name() const override { return "S8_CONV_NCHW_NCHW44"; }
bool usable(fallback::ConvBiasImpl*, const NCBKernSizeParam& param, bool usable(fallback::ConvBiasImpl*, const NCBKernSizeParam& param,
......
/**
* \file dnn/src/arm_common/conv_bias/int8/direct_stride2_nchw_nchw44_kern.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/opr_impl.h"
#include "src/fallback/conv_bias/common.h"
namespace megdnn {
namespace arm_common {
namespace conv_bias {

/**
 * Declares a single int8 direct-convolution kernel template whose name is
 * assembled by token pasting:
 *     conv_direct_<stride_tag>_<fsize>x<fsize>_int8_nchw_<layout_tag>
 *
 * The kernel is templated on the bias mode and on the post-processing
 * functor type Op. Only the declaration is emitted here; the definitions are
 * presumably provided by a separate source file (not visible in this header).
 *
 * The macro is local to this header and is undefined right after use.
 */
#define DECLARE_INT8_NCHW_KERN(stride_tag, fsize, layout_tag)                  \
    template <BiasMode bias_mode, typename Op>                                 \
    void conv_direct_##stride_tag##_##fsize##x##fsize##_int8_nchw_##layout_tag( \
            const int8_t* src, const int8_t* filter, const int32_t* bias,      \
            int32_t* temp, int8_t* dst, const size_t OC, const size_t IC,      \
            const size_t IH, const size_t IW, const size_t OH,                 \
            const size_t OW, const Op& op);

// Stride-2 kernels are declared for square filters of size 2, 3, 5 and 7.
DECLARE_INT8_NCHW_KERN(stride2, 2, nchw44)
DECLARE_INT8_NCHW_KERN(stride2, 3, nchw44)
DECLARE_INT8_NCHW_KERN(stride2, 5, nchw44)
DECLARE_INT8_NCHW_KERN(stride2, 7, nchw44)

#undef DECLARE_INT8_NCHW_KERN

/**
 * Packs int8 filter data from inptr into outptr for the kernels above.
 *
 * NOTE(review): judging by the name, the input is an NCHW44 weight tensor
 * described by (ic, fh, fw, oc) and the output is the layout the kernels
 * consume; confirm the exact layouts against the implementation.
 */
void pack_nchw44_weight_for_nchw_conv(const int8_t* inptr, int8_t* outptr,
                                      const int ic, const int fh,
                                      const int fw, const int oc);

/**
 * Packs int8 source data from inptr into outptr for the kernels above.
 *
 * NOTE(review): the pad arguments presumably give the amount of padding to
 * apply on each side of an ih x iw plane for each of the ic channels;
 * confirm against the implementation.
 */
void pack_nchw_src_for_nchw44_conv(const int8_t* inptr, int8_t* outptr,
                                   const int ic, const int top_pad,
                                   const int bottom_pad, const int left_pad,
                                   const int right_pad, const int ih,
                                   const int iw);

}  // namespace conv_bias
}  // namespace arm_common
}  // namespace megdnn
\ No newline at end of file
...@@ -47,7 +47,7 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj { ...@@ -47,7 +47,7 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj {
AlgoS8DirectStride2 s8_direct_stride2_large_group{true}; AlgoS8DirectStride2 s8_direct_stride2_large_group{true};
AlgoS8DirectStride2 s8_direct_stride2_small_group{false}; AlgoS8DirectStride2 s8_direct_stride2_small_group{false};
AlgoS8DirectStride2NCHW44 s8_direct_stride2_nchw44; AlgoS8DirectStride2NCHW44 s8_direct_stride2_nchw44;
AlgoS8DirectStride2NCHWNCHW44 s8_direct_stride2_nchw_nchw44; AlgoS8DirectNCHWNCHW44 s8_direct_nchw_nchw44;
AlgoS8DirectStride1 s8_direct_stride1_large_group{true}; AlgoS8DirectStride1 s8_direct_stride1_large_group{true};
AlgoS8DirectStride1 s8_direct_stride1_small_group{false}; AlgoS8DirectStride1 s8_direct_stride1_small_group{false};
AlgoS8DirectStride1NCHW44 s8_direct_stride1_nchw44; AlgoS8DirectStride1NCHW44 s8_direct_stride1_nchw44;
...@@ -115,7 +115,7 @@ public: ...@@ -115,7 +115,7 @@ public:
direct_algos.emplace_back(&s8_direct_stride2_large_group); direct_algos.emplace_back(&s8_direct_stride2_large_group);
direct_algos.emplace_back(&s8_direct_stride2_small_group); direct_algos.emplace_back(&s8_direct_stride2_small_group);
direct_algos.emplace_back(&s8_direct_stride2_nchw44); direct_algos.emplace_back(&s8_direct_stride2_nchw44);
direct_algos.emplace_back(&s8_direct_stride2_nchw_nchw44); direct_algos.emplace_back(&s8_direct_nchw_nchw44);
direct_algos.emplace_back(&s8_direct_stride1_large_group); direct_algos.emplace_back(&s8_direct_stride1_large_group);
direct_algos.emplace_back(&s8_direct_stride1_small_group); direct_algos.emplace_back(&s8_direct_stride1_small_group);
direct_algos.emplace_back(&s8_direct_stride1_nchw44); direct_algos.emplace_back(&s8_direct_stride1_nchw44);
......
...@@ -40,7 +40,7 @@ private: ...@@ -40,7 +40,7 @@ private:
class AlgoS8DirectStride1NCHW44; class AlgoS8DirectStride1NCHW44;
class AlgoS8DirectStride2; class AlgoS8DirectStride2;
class AlgoS8DirectStride2NCHW44; class AlgoS8DirectStride2NCHW44;
class AlgoS8DirectStride2NCHWNCHW44; class AlgoS8DirectNCHWNCHW44;
class AlgoQU8DirectStride1; class AlgoQU8DirectStride1;
class AlgoQU8DirectStride2; class AlgoQU8DirectStride2;
class AlgoFP32WinogradF23_4x4; class AlgoFP32WinogradF23_4x4;
......
...@@ -244,18 +244,26 @@ TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_NCHW44) { ...@@ -244,18 +244,26 @@ TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_NCHW44) {
#if MEGDNN_AARCH64 #if MEGDNN_AARCH64
benchmark_convbias(handle(), "IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16:384", benchmark_convbias(handle(), "IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16:384",
"IM2COLMATMUL:AARCH64_F32K8X12X1:192", true); "IM2COLMATMUL:AARCH64_F32K8X12X1:192", true);
benchmark_convbias(handle(), "IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16:384",
"IM2COLMATMUL:AARCH64_F32K8X12X1:192", false);
#else #else
benchmark_convbias(handle(), "IM2COLMATMUL:ARMV7_INT8X8X32_K4X8X8:384", benchmark_convbias(handle(), "IM2COLMATMUL:ARMV7_INT8X8X32_K4X8X8:384",
"IM2COLMATMUL:ARMV7_F32:192", true); "IM2COLMATMUL:ARMV7_F32:192", true);
benchmark_convbias(handle(), "IM2COLMATMUL:ARMV7_INT8X8X32_K4X8X8:384",
"IM2COLMATMUL:ARMV7_F32:192", false);
#endif #endif
} }
TEST_F(ARM_COMMON_MULTI_THREADS, BENCHMARK_CONVBIAS_NCHW44) { TEST_F(ARM_COMMON_MULTI_THREADS, BENCHMARK_CONVBIAS_NCHW44) {
#if MEGDNN_AARCH64 #if MEGDNN_AARCH64
benchmark_convbias(handle(), "IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16:384", benchmark_convbias(handle(), "IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16:384",
"IM2COLMATMUL:AARCH64_F32K8X12X1:192", true); "IM2COLMATMUL:AARCH64_F32K8X12X1:192", true);
benchmark_convbias(handle(), "IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16:384",
"IM2COLMATMUL:AARCH64_F32K8X12X1:192", false);
#else #else
benchmark_convbias(handle(), "IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16:384", benchmark_convbias(handle(), "IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16:384",
"IM2COLMATMUL:ARMV7_F32:192", true); "IM2COLMATMUL:ARMV7_F32:192", true);
benchmark_convbias(handle(), "IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16:384",
"IM2COLMATMUL:ARMV7_F32:192", false);
#endif #endif
} }
......
...@@ -541,7 +541,12 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QS8_CHANNEL_WISE_DIRECT2_NCHW44) { ...@@ -541,7 +541,12 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QS8_CHANNEL_WISE_DIRECT2_NCHW44) {
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_NCHW_NCHW44) { TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_NCHW_NCHW44) {
checker_conv_bias_qint8x8x8( checker_conv_bias_qint8x8x8(
get_nchw44_conv_bias_args({3, 5, 7}, 2, false, false, false, true), get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, false, false,
true),
handle(), "S8_CONV_NCHW_NCHW44");
checker_conv_bias_qint8x8x8(
get_nchw44_conv_bias_args({2, 3, 5, 7}, 1, false, false, false,
true),
handle(), "S8_CONV_NCHW_NCHW44"); handle(), "S8_CONV_NCHW_NCHW44");
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册