Commit 2e6e570d authored by Megvii Engine Team

feat(dnn/fallback): add armv7 im2col mk4-dot int8 and
nchw44 float 3x3 s2 fuse packb, for a speedup of about 10%

GitOrigin-RevId: 3f864cef1d41686912555738e79fa4fa9e6ef86b
Parent 457a1e01
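Note on the technique (editor's sketch, not part of the diff): "im2col + pack fuse" means the im2col transform writes its output directly in the packed B-panel layout that the matmul kernel consumes, skipping both the intermediate im2col buffer and the separate pack-B pass. Below is a minimal standalone sketch in plain C++, assuming a simple zero-padded NCHW layout and hypothetical names (im2col_pack_b, NB); the real NCHW44 kernels in this diff follow the same idea but vectorize with NEON and move 4-channel groups as single 32-bit words.

#include <cstddef>

// Sketch: im2col fused with B-panel packing for a 3x3, stride-2 convolution.
// Rather than materializing the K x N im2col matrix and packing it in a
// second pass, each im2col row is written straight into the panel-major
// packed-B layout: ceil(N / NB) panels, each holding K rows of NB values,
// with the last partial panel zero-padded.
constexpr int FH = 3, FW = 3, SH = 2, SW = 2;
constexpr int NB = 12;  // matmul n-block size

// src: IC x IH x IW input, already zero-padded; K = IC*FH*FW, N = OH*OW.
void im2col_pack_b(const float* src, float* dst, int IC, int IH, int IW,
                   int OH, int OW) {
    const int K = IC * FH * FW;
    const int N = OH * OW;
    const int panels = (N + NB - 1) / NB;
    int k = 0;
    for (int ic = 0; ic < IC; ++ic)
        for (int fh = 0; fh < FH; ++fh)
            for (int fw = 0; fw < FW; ++fw, ++k)
                for (int p = 0; p < panels; ++p) {
                    float* out = dst + (static_cast<size_t>(p) * K + k) * NB;
                    for (int j = 0; j < NB; ++j) {
                        const int n = p * NB + j;
                        if (n >= N) {  // zero-pad the tail panel
                            out[j] = 0.f;
                            continue;
                        }
                        const int oh = n / OW, ow = n % OW;
                        out[j] = src[(ic * IH + oh * SH + fh) * IW +
                                     ow * SW + fw];
                    }
                }
}

Fusing the two passes removes one full write and read of the im2col matrix; eliminating that memory round trip is what the quoted ~10% speedup targets.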
......@@ -227,28 +227,28 @@ public:
"DefaultStrategyType::FLOAT"_hash);
} else if (format == param::ConvBias::Format::NCHW44) {
#if MEGDNN_AARCH64
#if MEGDNN_AARCH64 || MEGDNN_ARMV7
auto matmul_block = matmul_algo->get_inner_block_size();
//! Optimize NCHW44 3x3s2 8X12X1 im2col+pack fuse
if (matmul_block.m == 8 && matmul_block.n == 12 &&
matmul_block.k == 1 &&
param.filter_meta.spatial[0] == 3 &&
param.filter_meta.spatial[1] == 3 &&
param.filter_meta.stride[0] == 2 &&
param.filter_meta.stride[1] == 2 &&
!param.filter_meta.should_flip) {
MIDOUT_BEGIN(
megdnn_fallback_im2col_factory_make_strategy,
midout_iv(
"DefaultStrategyType::8x12x1_fuse_packb_s2_nchw44"_hash)) {
return std::make_unique<
StrategyFuse8x12x1Nchw44K3x3S2<
float, float,
PostprocessMode::FLOAT>>();
}
MIDOUT_END();
return {};
//! Optimize NCHW44 3x3s2 aarch64 8X12X1 and armv7 4x12x1 im2col+pack fuse
if ((matmul_block.m == 8 || matmul_block.m == 4) &&
matmul_block.n == 12 && matmul_block.k == 1 &&
param.filter_meta.spatial[0] == 3 &&
param.filter_meta.spatial[1] == 3 &&
param.filter_meta.stride[0] == 2 &&
param.filter_meta.stride[1] == 2 &&
!param.filter_meta.should_flip) {
MIDOUT_BEGIN(
megdnn_fallback_im2col_factory_make_strategy,
midout_iv(
"DefaultStrategyType::8x12x1_fuse_packb_s2_nchw44"_hash)) {
return std::make_unique<
StrategyFuseXx12x1Nchw44K3x3S2<
float, float,
PostprocessMode::FLOAT>>();
}
MIDOUT_END();
return {};
}
#endif
cb1(NCHW44, DEFAULT, dt_float32, dt_float32,
......@@ -345,10 +345,10 @@ public:
"DefaultStrategyType::QINT8x8x32x8"_hash);
} else if (format == param::ConvBias::Format::NCHW44 ||
format == param::ConvBias::Format::NCHW44_DOT) {
#if MEGDNN_AARCH64
auto matmul_block = matmul_algo->get_inner_block_size();
if (format == param::ConvBias::Format::NCHW44) {
//! Optimize NCHW44 3x3s1 4X4X16 im2col+pack fuse
#if MEGDNN_AARCH64
auto matmul_block = matmul_algo->get_inner_block_size();
if (matmul_block.m == 4 && matmul_block.n == 4 &&
matmul_block.k == 16 &&
param.filter_meta.spatial[0] == 3 &&
......@@ -368,7 +368,10 @@ public:
MIDOUT_END();
return {};
}
#endif
} else {
#if MEGDNN_AARCH64
auto matmul_block = matmul_algo->get_inner_block_size();
//! Optimize NCHW44_DOT 3x3s1 8X12X4 im2col+pack fuse
if (matmul_block.m == 8 && matmul_block.n == 12 &&
matmul_block.k == 4 &&
......@@ -389,8 +392,30 @@ public:
MIDOUT_END();
return {};
}
}
#endif
#if MEGDNN_ARMV7
auto matmul_block = matmul_algo->get_inner_block_size();
if (matmul_block.m == 8 && matmul_block.n == 4 &&
matmul_block.k == 4 &&
param.filter_meta.spatial[0] == 3 &&
param.filter_meta.spatial[1] == 3 &&
param.filter_meta.stride[0] == 2 &&
param.filter_meta.stride[1] == 2 &&
!param.filter_meta.should_flip) {
MIDOUT_BEGIN(
megdnn_fallback_im2col_factory_make_strategy,
midout_iv(
"DefaultStrategyType::INT8x8x32_8x4x4_s2"_hash)) {
return std::make_unique<
StrategyFuse8x4x4Nchw44DotK3x3S2<
dt_qint32, dt_qint8,
PostprocessMode::QUANTIZED>>();
}
MIDOUT_END();
return {};
}
#endif
}
cb2(NCHW44, DEFAULT, dtype::QuantizedS8,
dtype::QuantizedS32, dtype::QuantizedS8, dt_int8,
dt_int32, dt_int8, PostprocessMode::QUANTIZED,
......
......@@ -488,12 +488,12 @@ public:
template <typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode>
class StrategyFuse8x12x1Nchw44K3x3S2
: public Strategy<float, float, float, op_ctype, op_dtype,
class StrategyFuse8x12x4Nchw44Dot
: public Strategy<dt_int8, dt_int32, dt_int8, op_ctype, op_dtype,
postprocess_mode, PackMode::DEFAULT,
FormatMode::NCHW44> {
public:
StrategyFuse8x12x1Nchw44K3x3S2() = default;
StrategyFuse8x12x4Nchw44Dot() = default;
constexpr static size_t BUNDLE_PADDING_INDEX = 0;
constexpr static size_t BUNDLE_PACKA_INDEX = 1;
......@@ -508,16 +508,15 @@ public:
fallback::MatrixMulImpl::KernParam matmul_param,
const fallback::MatrixMulImpl::AlgoBase* matmul_algo) override;
};
#else
template <typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode>
class StrategyFuse8x12x4Nchw44Dot
class StrategyFuse8x4x4Nchw44DotK3x3S2
: public Strategy<dt_int8, dt_int32, dt_int8, op_ctype, op_dtype,
postprocess_mode, PackMode::DEFAULT,
FormatMode::NCHW44> {
public:
StrategyFuse8x12x4Nchw44Dot() = default;
StrategyFuse8x4x4Nchw44DotK3x3S2() = default;
constexpr static size_t BUNDLE_PADDING_INDEX = 0;
constexpr static size_t BUNDLE_PACKA_INDEX = 1;
......@@ -534,6 +533,30 @@ public:
};
#endif
#if MEGDNN_AARCH64 || MEGDNN_ARMV7
template <typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode>
class StrategyFuseXx12x1Nchw44K3x3S2
: public Strategy<float, float, float, op_ctype, op_dtype,
postprocess_mode, PackMode::DEFAULT,
FormatMode::NCHW44> {
public:
StrategyFuseXx12x1Nchw44K3x3S2() = default;
constexpr static size_t BUNDLE_PADDING_INDEX = 0;
constexpr static size_t BUNDLE_PACKA_INDEX = 1;
constexpr static size_t THREAD_BUNDLE_PACKB_INDEX = 0;
constexpr static size_t THREAD_BUNDLE_IM2COL_INDEX = 1;
constexpr static size_t THREAD_BUNDLE_BIAS_INDEX = 2;
void exec_im2col(
const WorkspaceBundle& bundle, const WorkspaceBundle& bundle_thread,
const StrategyParam& sparam,
const fallback::ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernParam matmul_param,
const fallback::MatrixMulImpl::AlgoBase* matmul_algo) override;
};
#endif
} // namespace megdnn
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/fallback/conv_bias/im2col/strategy_fuse_nchw44_dot_s2.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "src/fallback/conv_bias/im2col/strategy_base.h"
#if MEGDNN_ARMV7
#include <arm_neon.h>
using namespace megdnn;
namespace {
#define PACKB_ONELINE() \
int out_index = 0; \
outptr = output_base; \
for (; out_index + 3 < block_size; out_index += 4) { \
std::memcpy(outptr, tmp_output, 16); \
outptr += ksize4; \
tmp_output += 4; \
} \
\
if (out_index < block_size) { \
uint32_t zerobuffer[4] = {0}; \
size_t out_remain = std::min(block_size - out_index, 4); \
std::memcpy(outptr, tmp_output, out_remain * sizeof(uint32_t)); \
outptr += out_remain; \
std::memcpy(outptr, zerobuffer, (4 - out_remain) * sizeof(uint32_t)); \
} \
output_base += 4;
#define STOR_IM2COL_DST() \
output0[count] = uint32_src[index]; \
output1[count] = uint32_src[index + 1]; \
output2[count] = uint32_src[index + 2]; \
count++; \
index += SW;
#define LOAD_AND_STOR_IM2COL_DST() \
uint32x4x2_t val_01 = vld2q_u32(&uint32_src[index]); \
index += 8; \
uint32x4_t val_index8 = vdupq_n_u32(uint32_src[index]); \
uint32x4_t val_2 = vextq_u32(val_01.val[0], val_index8, 1); \
vst1q_u32(&output0[count], val_01.val[0]); \
vst1q_u32(&output1[count], val_01.val[1]); \
vst1q_u32(&output2[count], val_2); \
count += 4;
void fuse_packb(const dt_int8* __restrict src, dt_int8* __restrict dst,
dt_int8* __restrict b_panel, const int OW, const int IC,
const int IH, const int IW, const int cur_index,
const int block_size) {
int start_h = cur_index / OW;
int cur_remain_w = cur_index % OW;
int end_h = (cur_index + block_size) / OW;
int end_remain_w = (cur_index + block_size) % OW;
bool same_line = start_h == end_h;
size_t newIC = IC / 4;
const uint32_t* uint32_src =
static_cast<const uint32_t*>(static_cast<const void*>(src));
uint32_t* output = static_cast<uint32_t*>(static_cast<void*>(dst));
uint32_t* b_output = static_cast<uint32_t*>(static_cast<void*>(b_panel));
const int packed_k = newIC * 3 * 3;
const int ksize4 = packed_k * 4;
uint32_t* outptr = b_output;
uint32_t* output_base = b_output;
constexpr int FH = 3;
constexpr int SH = 2;
constexpr int SW = 2;
if (same_line) {
rep(ic, newIC) {
rep(fh, FH) {
uint32_t* output02 = output;
uint32_t* output1 = output + block_size + 1;
size_t count = 0;
size_t index = 0;
int w = cur_remain_w;
index = (ic * IH + (start_h * SH + fh)) * IW + w * SW;
for (; w + 3 < end_remain_w; w += 4) {
uint32x4x2_t val_01 = vld2q_u32(&uint32_src[index]);
vst1q_u32(&output02[count], val_01.val[0]);
vst1q_u32(&output1[count], val_01.val[1]);
count += 4;
index += 8;
}
for (; w < end_remain_w; w++) {
output02[count] = uint32_src[index + 0];
output1[count] = uint32_src[index + 1];
count++;
index += SW;
}
output02[count] = uint32_src[index];
const uint32_t* output_ptr[3];
output_ptr[0] = output02;
output_ptr[1] = output1;
output_ptr[2] = output02 + 1;
for (int i = 0; i < 3; i++) {
const uint32_t* tmp_output = output_ptr[i];
PACKB_ONELINE();
}
}
}
} else {
rep(ic, newIC) {
rep(fh, FH) {
size_t count = 0;
size_t index = 0;
uint32_t* output0 = output;
uint32_t* output1 = output + block_size;
uint32_t* output2 = output1 + block_size;
int w = cur_remain_w;
index = (ic * IH + (SH * start_h + fh)) * IW + SW * w;
for (; w + 3 < OW; w += 4) {
LOAD_AND_STOR_IM2COL_DST()
}
for (; w < OW; w++) {
STOR_IM2COL_DST()
}
for (int h = start_h + 1; h < end_h; h++) {
int ow = 0;
index = (ic * IH + (SH * h + fh)) * IW;
for (; ow + 3 < OW; ow += 4) {
LOAD_AND_STOR_IM2COL_DST()
}
for (; ow < OW; ow++) {
STOR_IM2COL_DST()
}
}
index = (ic * IH + (SH * end_h + fh)) * IW;
w = 0;
for (; w + 3 < end_remain_w; w += 4) {
LOAD_AND_STOR_IM2COL_DST()
}
for (; w < end_remain_w; w++) {
STOR_IM2COL_DST()
}
for (int k = 0; k < 3; k++) {
const uint32_t* tmp_output = output + k * block_size;
PACKB_ONELINE();
}
}
}
}
}
#undef PACKB_ONELINE
#undef STOR_IM2COL_DST
#undef LOAD_AND_STOR_IM2COL_DST
} // namespace
template <typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode>
void StrategyFuse8x4x4Nchw44DotK3x3S2<op_ctype, op_dtype, postprocess_mode>::
exec_im2col(const WorkspaceBundle& bundle,
const WorkspaceBundle& bundle_thread,
const StrategyParam& sparam,
const fallback::ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernParam /*matmul_param*/,
const fallback::MatrixMulImpl::AlgoBase* /*matmul_algo*/) {
size_t ow = param.osz[1];
size_t ic = param.filter_meta.icpg;
size_t ih = param.isz[0] + param.filter_meta.padding[0] * 2;
size_t iw = param.isz[1] + param.filter_meta.padding[1] * 2;
size_t input_offset =
ih * iw * ic *
(sparam.group_id + param.filter_meta.group * sparam.batch_id) *
sizeof(dt_int8);
dt_int8* src2 = reinterpret_cast<dt_int8*>(
reinterpret_cast<uintptr_t>(bundle.get(BUNDLE_PADDING_INDEX)) +
input_offset);
bool is_phpwzero = param.filter_meta.padding[0] == 0 &&
param.filter_meta.padding[1] == 0;
if (is_phpwzero) {
src2 = const_cast<dt_int8*>(
param.src<dt_int8>(sparam.batch_id, sparam.group_id));
}
dt_int8* b_panel = reinterpret_cast<dt_int8*>(reinterpret_cast<uintptr_t>(
bundle_thread.get(THREAD_BUNDLE_PACKB_INDEX)));
megdnn_assert(ic % 4 == 0, "nchw44_dot requires ic to be a multiple of 4");
int8_t* im2col_dst =
static_cast<int8_t*>(bundle_thread.get(THREAD_BUNDLE_IM2COL_INDEX));
fuse_packb(src2, im2col_dst, b_panel, ow, ic, ih, iw, sparam.ohw_cur_index,
sparam.output_block_size);
}
namespace megdnn {
template class StrategyFuse8x4x4Nchw44DotK3x3S2<dt_qint32, dt_qint8,
megdnn::PostprocessMode::QUANTIZED>;
} // namespace megdnn
#endif
// vim: syntax=cpp.doxygen
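A side note on the load trick in fuse_packb above: NCHW44(-dot) stores the 4 channels of a pixel contiguously, so each channel group can be moved as one uint32, and vld2q_u32 de-interleaves eight consecutive pixels into the even (stride-2 taps) and odd columns in a single load. A minimal sketch, assuming NEON is available (the helper name is hypothetical):

#include <arm_neon.h>
#include <cstdint>

// Split eight consecutive NCHW44 pixels into even/odd columns. Each uint32
// lane carries one 4-channel int8 group, so no per-channel shuffling occurs.
inline void deinterleave_stride2(const int8_t* src, uint32_t* even,
                                 uint32_t* odd) {
    const uint32_t* s = reinterpret_cast<const uint32_t*>(src);
    uint32x4x2_t v = vld2q_u32(s);  // val[0]: pixels 0,2,4,6; val[1]: 1,3,5,7
    vst1q_u32(even, v.val[0]);
    vst1q_u32(odd, v.val[1]);
}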
......@@ -10,9 +10,8 @@
*/
#include "src/fallback/conv_bias/im2col/strategy_base.h"
#include "src/fallback/convolution/img2col_helper.h"
#if MEGDNN_AARCH64
#if MEGDNN_AARCH64 || MEGDNN_ARMV7
#include <arm_neon.h>
using namespace megdnn;
......@@ -163,7 +162,7 @@ void fuse_packb(const float* __restrict src, float* __restrict dst,
template <typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode>
void StrategyFuse8x12x1Nchw44K3x3S2<op_ctype, op_dtype, postprocess_mode>::
void StrategyFuseXx12x1Nchw44K3x3S2<op_ctype, op_dtype, postprocess_mode>::
exec_im2col(const WorkspaceBundle& bundle,
const WorkspaceBundle& bundle_thread,
const StrategyParam& sparam,
......@@ -194,14 +193,13 @@ void StrategyFuse8x12x1Nchw44K3x3S2<op_ctype, op_dtype, postprocess_mode>::
float* im2col_dst =
static_cast<float*>(bundle_thread.get(THREAD_BUNDLE_IM2COL_INDEX));
fuse_packb(src2, im2col_dst, b_panel, ow, ic, ih, iw, sparam.ohw_cur_index,
sparam.output_block_size);
}
namespace megdnn {
template class StrategyFuse8x12x1Nchw44K3x3S2<float, float,
template class StrategyFuseXx12x1Nchw44K3x3S2<float, float,
megdnn::PostprocessMode::FLOAT>;
} // namespace megdnn
......
......@@ -1461,6 +1461,25 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM_MK4_DOT) {
#undef cb
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM_MK4_DOT_S2_FUSE) {
UniformIntRNG rng{-50, 50};
#define cb(name) \
checker_conv_bias(get_nchw44_conv_bias_args({3}, 2, false, \
false, false, false, true), \
handle(), &rng, epsilon, dtype::QuantizedS8(2.5f), \
dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), \
dtype::QuantizedS8(60.25f), name);
float epsilon = 0.001;
#if MEGDNN_AARCH64
cb("IM2COLMATMUL:AARCH64_INT8X8X32_MK4_8X12X4_DOTPROD:96");
#elif MEGDNN_ARMV7
cb("IM2COLMATMUL:AARCH32_INT8_MK4_8X4X4_DOTPROD:96");
#endif
#undef cb
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_S8x8x32_MK4_DOT) {
UniformIntRNG rng{-50, 50};
......
......@@ -655,6 +655,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_INT8_NCHW44) {
bench_case(1, 512, 256, 28, 28, 3, 4, 1, 2);
}
#if __ARM_FEATURE_DOTPROD
TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_INT8_NCHW44_DOT) {
constexpr size_t RUNS = 40;
std::vector<DType> data_type = {
......@@ -708,6 +709,64 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_INT8_NCHW44_DOT) {
}
TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_INT8_NCHW44_DOT_S2) {
constexpr size_t RUNS = 40;
std::vector<DType> data_type = {
dtype::QuantizedS8(2.5f), dtype::QuantizedS8(2.5f),
dtype::QuantizedS32(6.25f), dtype::QuantizedS8(60.25f)};
auto bench_case = [&](size_t N, size_t IC, size_t OC, size_t H, size_t W,
size_t FS, size_t group, size_t P, size_t S,
bool is_nchw = false) {
param::ConvBias param;
param.nonlineMode = param::ConvBias::NonlineMode::RELU;
param.pad_h = P;
param.pad_w = P;
param.stride_h = S;
param.stride_w = S;
param.sparse = param::ConvBias::Sparse::DENSE;
param.format = param::ConvBias::Format::NCHW44_DOT;
auto OH = (H + 2 * P - FS) / static_cast<size_t>(S) + 1;
auto OW = (W + 2 * P - FS) / static_cast<size_t>(S) + 1;
TensorShape src = {N, IC / 4, H, W, 4};
TensorShape filter = {OC / 4, IC / 4, FS, FS, 4, 4};
if (group > 1) {
filter = {group, OC / group / 4, IC / group / 4, FS, FS, 4, 4};
param.sparse = param::ConvBias::Sparse::GROUP;
}
if (is_nchw) {
src = {N, IC, H, W};
filter = {OC / 4, FS, FS, IC, 4};
}
TensorShape bias = {1, OC / 4, 1, 1, 4};
TensorShape dst = {N, OC / 4, OH, OW, 4};
SmallVector<TensorShape> shapes{src, filter, bias, {}, dst};
float computations =
(((IC / group) * FS * FS + 1) * dst.total_nr_elems() * 2 +
dst.total_nr_elems()) *
1e-6;
std::vector<std::pair<SmallVector<TensorShape>, float>> shape_arg = {
std::make_pair(shapes, computations)};
benchmark_impl(param, shape_arg, ".+", RUNS, {4, {4, 5, 6, 7}},
{1, {7}}, data_type);
};
bench_case(1, 64, 64, 56, 56, 3, 1, 1, 2);
bench_case(1, 64, 64, 128, 128, 3, 1, 1, 2);
bench_case(1, 64, 64, 256, 256, 3, 1, 1, 2);
bench_case(1, 64, 64, 156, 156, 3, 1, 1, 2);
bench_case(1, 128, 128, 28, 28, 3, 1, 1, 2);
bench_case(1, 256, 256, 14, 14, 3, 1, 1, 2);
bench_case(1, 512, 512, 7, 7, 3, 1, 1, 2);
bench_case(1, 64, 64, 56, 56, 3, 4, 1, 2);
bench_case(1, 128, 128, 28, 28, 3, 4, 1, 2);
bench_case(1, 256, 256, 14, 14, 3, 4, 1, 2);
bench_case(1, 512, 512, 7, 7, 3, 4, 1, 2);
}
#endif
TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_FLOAT_NCHW44) {
constexpr size_t RUNS = 40;
std::vector<DType> data_type = {
......