提交 6b2760dd 编写于 作者: M Megvii Engine Team 提交者: Xu Xinran

feat(dnn/fallback): add float32 nchw44 fuse packb 3x3 s2

GitOrigin-RevId: 3b664bb4f578f5e3f2c36fc963217e37676c9b78
上级 7aeb4f6c
......@@ -226,6 +226,31 @@ public:
PostprocessMode::FLOAT,
"DefaultStrategyType::FLOAT"_hash);
} else if (format == param::ConvBias::Format::NCHW44) {
#if MEGDNN_AARCH64
auto matmul_block = matmul_algo->get_inner_block_size();
//! Optimize NCHW44 3x3s2 8X12X1 im2col+pack fuse
if (matmul_block.m == 8 && matmul_block.n == 12 &&
matmul_block.k == 1 &&
param.filter_meta.spatial[0] == 3 &&
param.filter_meta.spatial[1] == 3 &&
param.filter_meta.stride[0] == 2 &&
param.filter_meta.stride[1] == 2 &&
!param.filter_meta.should_flip) {
MIDOUT_BEGIN(
megdnn_fallback_im2col_factory_make_strategy,
midout_iv(
"DefaultStrategyType::8x12x1_fuse_packb_s2_nchw44"_hash)) {
return std::make_unique<
StrategyFuse8x12x1Nchw44K3x3S2<
float, float,
PostprocessMode::FLOAT>>();
}
MIDOUT_END();
return {};
}
#endif
cb1(NCHW44, DEFAULT, dt_float32, dt_float32,
PostprocessMode::FLOAT,
"DefaultStrategyTypeNCHW44::FLOAT"_hash);
......@@ -320,6 +345,52 @@ public:
"DefaultStrategyType::QINT8x8x32x8"_hash);
} else if (format == param::ConvBias::Format::NCHW44 ||
format == param::ConvBias::Format::NCHW44_DOT) {
#if MEGDNN_AARCH64
auto matmul_block = matmul_algo->get_inner_block_size();
if (format == param::ConvBias::Format::NCHW44) {
//! Optimize NCHW44 3x3s1 4X4X16 im2col+pack fuse
if (matmul_block.m == 4 && matmul_block.n == 4 &&
matmul_block.k == 16 &&
param.filter_meta.spatial[0] == 3 &&
param.filter_meta.spatial[1] == 3 &&
param.filter_meta.stride[0] == 1 &&
param.filter_meta.stride[1] == 1 &&
!param.filter_meta.should_flip) {
MIDOUT_BEGIN(
megdnn_fallback_im2col_factory_make_strategy,
midout_iv(
"DefaultStrategyType::INT8x8x32_4x4x16"_hash)) {
return std::make_unique<
StrategyFuse4x4x16Nchw44<
dt_qint32, dt_qint8,
PostprocessMode::QUANTIZED>>();
}
MIDOUT_END();
return {};
}
} else {
//! Optimize NCHW44_DOT 3x3s1 8X12X4 im2col+pack fuse
if (matmul_block.m == 8 && matmul_block.n == 12 &&
matmul_block.k == 4 &&
param.filter_meta.spatial[0] == 3 &&
param.filter_meta.spatial[1] == 3 &&
param.filter_meta.stride[0] == 1 &&
param.filter_meta.stride[1] == 1 &&
!param.filter_meta.should_flip) {
MIDOUT_BEGIN(
megdnn_fallback_im2col_factory_make_strategy,
midout_iv(
"DefaultStrategyType::INT8x8x32_8x12x4"_hash)) {
return std::make_unique<
StrategyFuse8x12x4Nchw44Dot<
dt_qint32, dt_qint8,
PostprocessMode::QUANTIZED>>();
}
MIDOUT_END();
return {};
}
}
#endif
cb2(NCHW44, DEFAULT, dtype::QuantizedS8,
dtype::QuantizedS32, dtype::QuantizedS8, dt_int8,
dt_int32, dt_int8, PostprocessMode::QUANTIZED,
......
......@@ -445,6 +445,75 @@ public:
THREAD_BUNDLE_BIAS_INDEX);
}
};
#if MEGDNN_AARCH64
template <typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode>
class StrategyFuse4x4x16Nchw44
: public Strategy<dt_int8, dt_int32, dt_int8, op_ctype, op_dtype,
postprocess_mode, PackMode::DEFAULT,
FormatMode::NCHW44> {
public:
StrategyFuse4x4x16Nchw44() = default;
constexpr static size_t BUNDLE_PADDING_INDEX = 0;
constexpr static size_t BUNDLE_PACKA_INDEX = 1;
constexpr static size_t THREAD_BUNDLE_PACKB_INDEX = 0;
constexpr static size_t THREAD_BUNDLE_IM2COL_INDEX = 1;
constexpr static size_t THREAD_BUNDLE_BIAS_INDEX = 2;
void exec_im2col(WorkspaceBundle bundle, WorkspaceBundle bundle_thread,
const StrategyParam& sparam,
const fallback::ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernParam matmul_param,
fallback::MatrixMulImpl::AlgoBase* matmul_algo) override;
};
template <typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode>
class StrategyFuse8x12x1Nchw44K3x3S2
: public Strategy<float, float, float, op_ctype, op_dtype,
postprocess_mode, PackMode::DEFAULT,
FormatMode::NCHW44> {
public:
StrategyFuse8x12x1Nchw44K3x3S2() = default;
constexpr static size_t BUNDLE_PADDING_INDEX = 0;
constexpr static size_t BUNDLE_PACKA_INDEX = 1;
constexpr static size_t THREAD_BUNDLE_PACKB_INDEX = 0;
constexpr static size_t THREAD_BUNDLE_IM2COL_INDEX = 1;
constexpr static size_t THREAD_BUNDLE_BIAS_INDEX = 2;
void exec_im2col(WorkspaceBundle bundle, WorkspaceBundle bundle_thread,
const StrategyParam& sparam,
const fallback::ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernParam matmul_param,
fallback::MatrixMulImpl::AlgoBase* matmul_algo) override;
};
template <typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode>
class StrategyFuse8x12x4Nchw44Dot
: public Strategy<dt_int8, dt_int32, dt_int8, op_ctype, op_dtype,
postprocess_mode, PackMode::DEFAULT,
FormatMode::NCHW44> {
public:
StrategyFuse8x12x4Nchw44Dot() = default;
constexpr static size_t BUNDLE_PADDING_INDEX = 0;
constexpr static size_t BUNDLE_PACKA_INDEX = 1;
constexpr static size_t THREAD_BUNDLE_PACKB_INDEX = 0;
constexpr static size_t THREAD_BUNDLE_IM2COL_INDEX = 1;
constexpr static size_t THREAD_BUNDLE_BIAS_INDEX = 2;
void exec_im2col(WorkspaceBundle bundle, WorkspaceBundle bundle_thread,
const StrategyParam& sparam,
const fallback::ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernParam matmul_param,
fallback::MatrixMulImpl::AlgoBase* matmul_algo) override;
};
#endif
} // namespace megdnn
// vim: syntax=cpp.doxygen
......@@ -14,6 +14,9 @@
#include "src/x86/conv_bias/postprocess_helper.h"
#endif
#if (MEGDNN_ARMV7 || MEGDNN_AARCH64)
#include "src/arm_common/conv_bias/postprocess_helper.h"
#endif
using namespace megdnn;
#if MEGDNN_X86
......@@ -101,11 +104,23 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
INSTANTIAL_CLASS(dt_float32, dt_float32, dt_float32, dt_float32, dt_float32,
megdnn::PostprocessMode::FLOAT)
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, __fp16, __fp16,
megdnn::PostprocessMode::FLOAT)
#else
#if !MEGDNN_DISABLE_FLOAT16
INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, dt_float16, dt_float16,
megdnn::PostprocessMode::NO_PROCESS)
#endif
#endif
#if MEGDNN_AARCH64 || MEGDNN_ARMV7
//! x86 do not have uint8 matmul so only armv7 armv8 support uint8
INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_uint8, dt_qint32, dt_quint8,
megdnn::PostprocessMode::QUANTIZED)
INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_int32, dt_qint32, dt_qint32,
megdnn::PostprocessMode::NO_PROCESS)
#endif
INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int8, dt_qint32, dt_qint8,
megdnn::PostprocessMode::QUANTIZED)
......
......@@ -11,5 +11,235 @@
#include "src/fallback/conv_bias/im2col/strategy_base.h"
#if MEGDNN_AARCH64
#include <arm_neon.h>
using namespace megdnn;
namespace {
#define TRANS_AND_STORE(input0, input1, input2, input3) \
{ \
auto tmp01 = vzipq_s32(input0, input1); \
auto tmp23 = vzipq_s32(input2, input3); \
auto dst0 = vzip1q_s64(vreinterpretq_s64_s32(tmp01.val[0]), \
vreinterpretq_s64_s32(tmp23.val[0])); \
auto dst1 = vzip2q_s64(vreinterpretq_s64_s32(tmp01.val[0]), \
vreinterpretq_s64_s32(tmp23.val[0])); \
auto dst2 = vzip1q_s64(vreinterpretq_s64_s32(tmp01.val[1]), \
vreinterpretq_s64_s32(tmp23.val[1])); \
auto dst3 = vzip2q_s64(vreinterpretq_s64_s32(tmp01.val[1]), \
vreinterpretq_s64_s32(tmp23.val[1])); \
vst1q_s32(dst, vreinterpretq_s32_s64(dst0)); \
vst1q_s32(dst + 4, vreinterpretq_s32_s64(dst1)); \
vst1q_s32(dst + 8, vreinterpretq_s32_s64(dst2)); \
vst1q_s32(dst + 12, vreinterpretq_s32_s64(dst3)); \
dst += 16; \
}
#define TRANS_AND_STORE_REMAIN(input0, input1, input2, input3, remain) \
{ \
auto tmp01 = vzipq_s32(input0, input1); \
auto tmp23 = vzipq_s32(input2, input3); \
vdst[0] = vzip1q_s64(vreinterpretq_s64_s32(tmp01.val[0]), \
vreinterpretq_s64_s32(tmp23.val[0])); \
vdst[1] = vzip2q_s64(vreinterpretq_s64_s32(tmp01.val[0]), \
vreinterpretq_s64_s32(tmp23.val[0])); \
vdst[2] = vzip1q_s64(vreinterpretq_s64_s32(tmp01.val[1]), \
vreinterpretq_s64_s32(tmp23.val[1])); \
vdst[3] = vzip2q_s64(vreinterpretq_s64_s32(tmp01.val[1]), \
vreinterpretq_s64_s32(tmp23.val[1])); \
for (size_t i = 0; i < remain; i++) { \
vst1q_s32(dst + i * 4, vreinterpretq_s32_s64(vdst[i])); \
} \
dst += 16; \
}
void optimize_fuse_im2col_packB(dt_int8* src, size_t ic, size_t iw, size_t ih,
size_t curr_iw, size_t curr_ih, dt_int8* dst_ptr) {
int* src_line0 =
reinterpret_cast<int*>(src + curr_ih * iw * 4 + curr_iw * 4);
int* src_line1 =
reinterpret_cast<int*>(src + (curr_ih + 1) * iw * 4 + curr_iw * 4);
int* src_line2 =
reinterpret_cast<int*>(src + (curr_ih + 2) * iw * 4 + curr_iw * 4);
int* dst = reinterpret_cast<int*>(dst_ptr);
int32x4_t input[12];
int remain = 0;
for (size_t c = 0; c < ic; c++) {
input[remain] = vld1q_s32(src_line0);
input[remain + 1] = vld1q_s32(src_line0 + 1);
input[remain + 2] = vld1q_s32(src_line0 + 2);
input[remain + 3] = vld1q_s32(src_line1);
input[remain + 4] = vld1q_s32(src_line1 + 1);
input[remain + 5] = vld1q_s32(src_line1 + 2);
input[remain + 6] = vld1q_s32(src_line2);
input[remain + 7] = vld1q_s32(src_line2 + 1);
input[remain + 8] = vld1q_s32(src_line2 + 2);
TRANS_AND_STORE(input[0], input[1], input[2], input[3]);
TRANS_AND_STORE(input[4], input[5], input[6], input[7]);
if (remain == 3) {
TRANS_AND_STORE(input[8], input[9], input[10], input[11]);
remain = 0;
} else {
for (int i = 0; i <= remain; i++) {
input[i] = input[8 + i];
}
remain++;
}
src_line0 += ih * iw;
src_line1 += ih * iw;
src_line2 += ih * iw;
}
//! pad remain to 4
if (remain > 0) {
TRANS_AND_STORE(input[0], input[1], input[2], input[3]);
}
}
void naive_fuse_im2col_packB(dt_int8* src, size_t ic, size_t iw, size_t ih,
size_t curr_iw, size_t curr_ih, size_t num_point,
size_t ow, dt_int8* dst_ptr) {
megdnn_assert(num_point <= 4_z,
"fuse im2col and packB of 4x4x16 num_point must less than 4");
int* src_line0 = reinterpret_cast<int*>(src + curr_ih * iw * 4);
int* src_line1 = reinterpret_cast<int*>(src + (curr_ih + 1) * iw * 4);
int* src_line2 = reinterpret_cast<int*>(src + (curr_ih + 2) * iw * 4);
int remain = 0;
int out[9][4] = {{0}};
int32x4_t input[12];
int* dst = reinterpret_cast<int*>(dst_ptr);
for (size_t c = 0; c < ic; c++) {
//! Read int buffer out
size_t index = 0, w = curr_iw, dalta_h = 0;
while (index < num_point) {
int* src_next_line0 = src_line0 + dalta_h * iw;
int* src_next_line1 = src_next_line0 + iw;
int* src_next_line2 = src_next_line1 + iw;
for (; index < num_point && w < ow; index++, w++) {
out[0][index] = src_next_line0[w];
out[1][index] = src_next_line0[w + 1];
out[2][index] = src_next_line0[w + 2];
out[3][index] = src_next_line1[w];
out[4][index] = src_next_line1[w + 1];
out[5][index] = src_next_line1[w + 2];
out[6][index] = src_next_line2[w];
out[7][index] = src_next_line2[w + 1];
out[8][index] = src_next_line2[w + 2];
}
//! next line
w = 0;
dalta_h += 1;
}
//! load int vector
input[remain] = vld1q_s32(out[0]);
input[remain + 1] = vld1q_s32(out[1]);
input[remain + 2] = vld1q_s32(out[2]);
input[remain + 3] = vld1q_s32(out[3]);
input[remain + 4] = vld1q_s32(out[4]);
input[remain + 5] = vld1q_s32(out[5]);
input[remain + 6] = vld1q_s32(out[6]);
input[remain + 7] = vld1q_s32(out[7]);
input[remain + 8] = vld1q_s32(out[8]);
int64x2_t vdst[4];
TRANS_AND_STORE_REMAIN(input[0], input[1], input[2], input[3], num_point);
TRANS_AND_STORE_REMAIN(input[4], input[5], input[6], input[7], num_point);
if (remain == 3) {
TRANS_AND_STORE_REMAIN(input[8], input[9], input[10], input[11],
num_point);
remain = 0;
} else {
for (int i = 0; i <= remain; i++) {
input[i] = input[8 + i];
}
remain++;
}
src_line0 += ih * iw;
src_line1 += ih * iw;
src_line2 += ih * iw;
}
//! pad remain to 4
if (remain > 0) {
int64x2_t vdst[4];
TRANS_AND_STORE_REMAIN(input[0], input[1], input[2], input[3],
num_point);
}
}
} // namespace
template <typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode>
void StrategyFuse4x4x16Nchw44<op_ctype, op_dtype, postprocess_mode>::
exec_im2col(WorkspaceBundle bundle, WorkspaceBundle bundle_thread,
const StrategyParam& sparam,
const fallback::ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernParam,
fallback::MatrixMulImpl::AlgoBase*) {
size_t ow = param.osz[1];
size_t ic = param.filter_meta.icpg;
size_t ih = param.isz[0] + param.filter_meta.padding[0] * 2;
size_t iw = param.isz[1] + param.filter_meta.padding[1] * 2;
constexpr static size_t pack_size = 4;
size_t input_offset =
ih * iw * ic *
(sparam.group_id + param.filter_meta.group * sparam.batch_id) *
sizeof(dt_int8);
dt_int8* src2 = reinterpret_cast<dt_int8*>(
reinterpret_cast<uintptr_t>(bundle.get(BUNDLE_PADDING_INDEX)) +
input_offset);
bool is_phpwzero = param.filter_meta.padding[0] == 0 &&
param.filter_meta.padding[1] == 0;
if (is_phpwzero) {
src2 = const_cast<dt_int8*>(
param.src<dt_int8>(sparam.batch_id, sparam.group_id));
}
dt_int8* b_panel =
reinterpret_cast<dt_int8*>(reinterpret_cast<uintptr_t>(
bundle_thread.get(THREAD_BUNDLE_PACKB_INDEX)));
megdnn_assert(ic % 4 == 0, "nchw44 with ic is not of time 4");
const int packed_k = (ic * 3 * 3) / pack_size;
const int ksize4 = round_up<int>(packed_k, 4) * 16 * sizeof(dt_int8);
size_t out_size = sparam.output_block_size;
size_t curr_index = sparam.ohw_cur_index;
size_t curr_ih = curr_index / ow;
size_t curr_iw = curr_index % ow;
size_t out_index = 0;
while (out_index < out_size) {
for (; curr_iw + 3 < ow && out_index + 3 < out_size;
curr_iw += 4, out_index += 4) {
dt_int8* dst = b_panel + (out_index / 4) * ksize4;
optimize_fuse_im2col_packB(src2, ic / 4, iw, ih, curr_iw, curr_ih,
dst);
}
if (curr_iw < ow && out_index < out_size) {
size_t out_remain = std::min(out_size - out_index, 4_z);
size_t remain_point_this_line = std::min(ow - curr_iw, out_remain);
size_t start_point_next_line =
(out_remain - remain_point_this_line) % ow;
size_t pass_lines = (out_remain - remain_point_this_line) / ow;
dt_int8* dst = b_panel + (out_index / 4) * ksize4;
naive_fuse_im2col_packB(src2, ic / 4, iw, ih, curr_iw, curr_ih,
out_remain, ow, dst);
out_index += out_remain;
curr_iw = start_point_next_line;
curr_ih += (pass_lines + 1);
} else {
curr_iw = 0;
curr_ih++;
}
}
}
#undef TRANS_AND_STORE_REMAIN
#undef TRANS_AND_STORE
namespace megdnn {
template class StrategyFuse4x4x16Nchw44<dt_qint32, dt_qint8,
megdnn::PostprocessMode::QUANTIZED>;
} // namespace megdnn
#endif
// vim: syntax=cpp.doxygen
......@@ -11,5 +11,209 @@
#include "src/fallback/conv_bias/im2col/strategy_base.h"
#if MEGDNN_AARCH64
#include <arm_neon.h>
using namespace megdnn;
namespace {
#define PACKB_ONELINE() \
int out_index = 0; \
outptr = output_base; \
for (; out_index + 11 < block_size; out_index += 12) { \
std::memcpy(outptr, tmp_output, 48); \
outptr += ksize12; \
tmp_output += 12; \
} \
\
outptr = output_base4; \
for (; out_index + 3 < block_size; out_index += 4) { \
std::memcpy(outptr, tmp_output, 16); \
outptr += ksize4; \
tmp_output += 4; \
} \
\
if (out_index < block_size) { \
uint32_t zerobuffer[4] = {0}; \
size_t out_remain = std::min(block_size - out_index, 4); \
std::memcpy(outptr, tmp_output, out_remain * sizeof(uint32_t)); \
outptr += out_remain; \
std::memcpy(outptr, zerobuffer, (4 - out_remain) * sizeof(uint32_t)); \
} \
output_base += 12; \
output_base4 += 4;
#define STOR_IM2COL_DST() \
output0[count] = uint32_src[index + 0]; \
output1[count] = uint32_src[index + 1]; \
output2[count] = uint32_src[index + 2];
#define LOAD_AND_STOR_IM2COL_DST() \
uint32x4_t v_tmp = vld1q_u32(&uint32_src[index + 4]); \
uint32x4_t v_o1 = vextq_u32(v_o0, v_tmp, 1); \
uint32x4_t v_o2 = vextq_u32(v_o0, v_tmp, 2); \
vst1q_u32(&output0[count], v_o0); \
vst1q_u32(&output1[count], v_o1); \
vst1q_u32(&output2[count], v_o2); \
v_o0 = v_tmp;
void fuse_packb(const dt_int8* __restrict src, dt_int8* __restrict dst,
dt_int8* __restrict b_panel, const int OW, const int IC,
const int IH, const int IW,
const int cur_index, const int block_size) {
int start_h = cur_index / OW;
int cur_remain_w = cur_index % OW;
int end_h = (cur_index + block_size) / OW;
int end_remain_w = (cur_index + block_size) % OW;
bool same_line = start_h == end_h ? true : false;
size_t newIC = IC / 4;
const uint32_t* uint32_src =
static_cast<const uint32_t*>(static_cast<const void*>(src));
uint32_t* output = static_cast<uint32_t*>(static_cast<void*>(dst));
uint32_t* b_output = static_cast<uint32_t*>(static_cast<void*>(b_panel));
const int packed_k = newIC * 3 * 3;
const int ksize12 = packed_k * 12 * sizeof(dt_int8);
const int ksize4 = packed_k * 4 * sizeof(dt_int8);
uint32_t* outptr = b_output;
uint32_t* output_base = b_output;
uint32_t* output_base4 = b_output + block_size / 12 * ksize12;
constexpr int FH = 3;
if (same_line) {
rep(ic, newIC) {
rep(fh, FH) {
size_t count = 0;
size_t index = 0;
int w = cur_remain_w;
index = (ic * IH + (start_h + fh)) * IW + w;
for (; w + 3 < end_remain_w; w += 4) {
vst1q_u32(&output[count], vld1q_u32(&uint32_src[index]));
count += 4;
index += 4;
}
for (; w < end_remain_w; w++) {
output[count++] = uint32_src[index++];
}
output[count++] = uint32_src[index];
output[count++] = uint32_src[index + 1];
for (int i = 0; i < 3; i++) {
const uint32_t* tmp_output = output + i;
PACKB_ONELINE();
}
}
}
} else {
rep(ic, newIC) {
rep(fh, FH) {
size_t count = 0;
size_t index = 0;
uint32_t* output0 = output;
uint32_t* output1 = output + block_size;
uint32_t* output2 = output1 + block_size;
int w = cur_remain_w;
index = (ic * IH + (start_h + fh)) * IW + w;
uint32x4_t v_o0 = vld1q_u32(&uint32_src[index]);
for ( ; w + 3 < OW; w += 4) {
LOAD_AND_STOR_IM2COL_DST();
count += 4;
index += 4;
}
for (; w < OW; w++) {
STOR_IM2COL_DST();
count++;
index++;
}
for (int h = start_h + 1; h < end_h; h++) {
int ow = 0;
index = (ic * IH + (h + fh)) * IW + ow;
v_o0 = vld1q_u32(&uint32_src[index]);
for (; ow + 3 < OW; ow += 4) {
LOAD_AND_STOR_IM2COL_DST();
count += 4;
index += 4;
}
for (; ow < OW; ow++) {
STOR_IM2COL_DST();
count++;
index++;
}
}
index = (ic * IH + (end_h + fh)) * IW;
w = 0;
v_o0 = vld1q_u32(&uint32_src[index]);
for ( ; w + 3 < end_remain_w; w+=4) {
LOAD_AND_STOR_IM2COL_DST();
count+=4;
index+=4;
}
for ( ; w < end_remain_w; w++) {
STOR_IM2COL_DST();
count++;
index++;
}
for (int k = 0; k < 3; k++) {
const uint32_t* tmp_output = output + k * block_size;
PACKB_ONELINE();
}
}
}
}
}
#undef PACKB_ONELINE
#undef STOR_IM2COL_DST
#undef LOAD_AND_STOR_IM2COL_DST
} // namespace
template <typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode>
void StrategyFuse8x12x4Nchw44Dot<op_ctype, op_dtype, postprocess_mode>::
exec_im2col(WorkspaceBundle bundle, WorkspaceBundle bundle_thread,
const StrategyParam& sparam,
const fallback::ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernParam /*matmul_param*/,
fallback::MatrixMulImpl::AlgoBase* /*matmul_algo*/) {
size_t ow = param.osz[1];
size_t ic = param.filter_meta.icpg;
size_t ih = param.isz[0] + param.filter_meta.padding[0] * 2;
size_t iw = param.isz[1] + param.filter_meta.padding[1] * 2;
size_t input_offset =
ih * iw * ic *
(sparam.group_id + param.filter_meta.group * sparam.batch_id) *
sizeof(dt_int8);
dt_int8* src2 = reinterpret_cast<dt_int8*>(
reinterpret_cast<uintptr_t>(bundle.get(BUNDLE_PADDING_INDEX)) +
input_offset);
bool is_phpwzero = param.filter_meta.padding[0] == 0 &&
param.filter_meta.padding[1] == 0;
if (is_phpwzero) {
src2 = const_cast<dt_int8*>(
param.src<dt_int8>(sparam.batch_id, sparam.group_id));
}
dt_int8* b_panel =
reinterpret_cast<dt_int8*>(reinterpret_cast<uintptr_t>(
bundle_thread.get(THREAD_BUNDLE_PACKB_INDEX)));
megdnn_assert(ic % 4 == 0, "nchw44_dot with ic is not of time 4");
int8_t* im2col_dst = static_cast<int8_t*>(
bundle_thread.get(THREAD_BUNDLE_IM2COL_INDEX));
fuse_packb(src2, im2col_dst, b_panel, ow, ic, ih, iw, sparam.ohw_cur_index,
sparam.output_block_size);
}
namespace megdnn {
template class StrategyFuse8x12x4Nchw44Dot<dt_qint32, dt_qint8,
megdnn::PostprocessMode::QUANTIZED>;
} // namespace megdnn
#endif
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/fallback/conv_bias/im2col/strategy_fuse_nchw44_fp32_s2.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "src/fallback/conv_bias/im2col/strategy_base.h"
#include "src/fallback/convolution/img2col_helper.h"
#if MEGDNN_AARCH64
#include <arm_neon.h>
using namespace megdnn;
namespace {
#define PACKB_ONELINE() \
int out_index = 0; \
outptr = output_base; \
for (; out_index + 11 < block_size; out_index += 12) { \
float32x4x4_t v0 = vld4q_f32(tmp_output); \
float32x4x4_t v1 = vld4q_f32(tmp_output + 16); \
float32x4x4_t v2 = vld4q_f32(tmp_output + 32); \
vst1q_f32(outptr, v0.val[0]); \
vst1q_f32(outptr + 4, v1.val[0]); \
vst1q_f32(outptr + 8, v2.val[0]); \
vst1q_f32(outptr + 12, v0.val[1]); \
vst1q_f32(outptr + 16, v1.val[1]); \
vst1q_f32(outptr + 20, v2.val[1]); \
vst1q_f32(outptr + 24, v0.val[2]); \
vst1q_f32(outptr + 28, v1.val[2]); \
vst1q_f32(outptr + 32, v2.val[2]); \
vst1q_f32(outptr + 36, v0.val[3]); \
vst1q_f32(outptr + 40, v1.val[3]); \
vst1q_f32(outptr + 44, v2.val[3]); \
outptr += ksize12; \
tmp_output += 48; \
} \
\
outptr = output_base4; \
for (; out_index + 3 < block_size; out_index += 4) { \
float32x4x4_t v0 = vld4q_f32(tmp_output); \
vst1q_f32(outptr, v0.val[0]); \
vst1q_f32(outptr + 4, v0.val[1]); \
vst1q_f32(outptr + 8, v0.val[2]); \
vst1q_f32(outptr + 12, v0.val[3]); \
outptr += ksize4; \
tmp_output += 16; \
} \
\
if (out_index < block_size) { \
float zerobuffer[16] = {0}; \
size_t out_remain = std::min(block_size - out_index, 4); \
std::memcpy(zerobuffer, tmp_output, out_remain * sizeof(float) * 4); \
float32x4x4_t v0 = vld4q_f32(zerobuffer); \
vst1q_f32(outptr, v0.val[0]); \
vst1q_f32(outptr + 4, v0.val[1]); \
vst1q_f32(outptr + 8, v0.val[2]); \
vst1q_f32(outptr + 12, v0.val[3]); \
} \
output_base += 48; \
output_base4 += 16;
#define LOAD_AND_STOR_IM2COL_DST() \
float32x4_t v1 = vld1q_f32(&src[index + 4]); \
float32x4_t v2 = vld1q_f32(&src[index + 8]); \
vst1q_f32(&output0[i], v0); \
vst1q_f32(&output1[i], v1); \
vst1q_f32(&output2[i], v2); \
i += 4; \
index += 8; \
v0 = v2;
void fuse_packb(const float* __restrict src, float* __restrict dst,
float* __restrict b_panel, const int OW, const int IC,
const int IH, const int IW, const int cur_index,
const int block_size) {
int start_h = cur_index / OW;
int cur_remain_w = cur_index % OW;
int end_h = (cur_index + block_size) / OW;
int end_remain_w = (cur_index + block_size) % OW;
bool same_line = start_h == end_h ? true : false;
size_t newIC = IC / 4;
float* b_output = b_panel;
const int packed_k = IC * 3 * 3;
const int ksize12 = packed_k * 12;
const int ksize4 = packed_k * 4;
float* outptr = b_output;
float* output_base = b_output;
float* output_base4 = b_output + block_size / 12 * ksize12;
constexpr int FH = 3;
constexpr int SH = 2;
constexpr int SW = 2;
if (same_line) {
rep(ic, newIC) {
rep(fh, FH) {
float* output02 = dst;
float* output1 = dst + block_size * 4 + 4;
size_t i = 0;
size_t index = 4 * (ic * IH * IW + (start_h * SH + fh) * IW +
cur_remain_w * SW);
for (int w = cur_remain_w; w < end_remain_w; w++) {
vst1q_f32(&output02[i], vld1q_f32(&src[index]));
vst1q_f32(&output1[i], vld1q_f32(&src[index + 4]));
i += 4;
index += 8;
}
vst1q_f32(&output02[i], vld1q_f32(&src[index]));
float* output[3];
output[0] = output02;
output[1] = output1;
output[2] = output02 + 4;
for (int i = 0; i < 3; i++) {
const float* tmp_output = output[i];
PACKB_ONELINE();
}
}
}
} else {
rep(ic, newIC) {
rep(fh, FH) {
float* output0 = dst;
float* output1 = dst + block_size * 4;
float* output2 = output1 + block_size * 4;
size_t i = 0;
size_t index = 4 * (ic * IH * IW + (start_h * SH + fh) * IW +
(cur_remain_w * SW));
float32x4_t v0 = vld1q_f32(&src[index]);
for (int w = cur_remain_w; w < OW; w++) {
LOAD_AND_STOR_IM2COL_DST();
}
for (int h = start_h + 1; h < end_h; h++) {
size_t index = 4 * (ic * IH * IW + (h * SH + fh) * IW);
v0 = vld1q_f32(&src[index]);
rep(ow, OW) { LOAD_AND_STOR_IM2COL_DST(); }
}
index = 4 * (ic * IH * IW + (end_h * SH + fh) * IW);
v0 = vld1q_f32(&src[index]);
for (int w = 0; w < end_remain_w; w++) {
LOAD_AND_STOR_IM2COL_DST();
}
for (int i = 0; i < 3; i++) {
const float* tmp_output = output0 + i * block_size * 4;
PACKB_ONELINE();
}
}
}
}
}
#undef PACKB_ONELINE
#undef LOAD_AND_STOR_IM2COL_DST
} // namespace
template <typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode>
void StrategyFuse8x12x1Nchw44K3x3S2<op_ctype, op_dtype, postprocess_mode>::
exec_im2col(WorkspaceBundle bundle, WorkspaceBundle bundle_thread,
const StrategyParam& sparam,
const fallback::ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernParam /*matmul_param*/,
fallback::MatrixMulImpl::AlgoBase* /*matmul_algo*/) {
size_t ow = param.osz[1];
size_t ic = param.filter_meta.icpg;
size_t ih = param.isz[0] + param.filter_meta.padding[0] * 2;
size_t iw = param.isz[1] + param.filter_meta.padding[1] * 2;
size_t input_offset =
ih * iw * ic *
(sparam.group_id + param.filter_meta.group * sparam.batch_id) *
sizeof(float);
float* src2 = reinterpret_cast<float*>(
reinterpret_cast<uintptr_t>(bundle.get(BUNDLE_PADDING_INDEX)) +
input_offset);
bool is_phpwzero = param.filter_meta.padding[0] == 0 &&
param.filter_meta.padding[1] == 0;
if (is_phpwzero) {
src2 = const_cast<float*>(
param.src<float>(sparam.batch_id, sparam.group_id));
}
float* b_panel = reinterpret_cast<float*>(reinterpret_cast<uintptr_t>(
bundle_thread.get(THREAD_BUNDLE_PACKB_INDEX)));
megdnn_assert(ic % 4 == 0, "nchw44_dot with ic is not of time 4");
float* im2col_dst =
static_cast<float*>(bundle_thread.get(THREAD_BUNDLE_IM2COL_INDEX));
fuse_packb(src2, im2col_dst, b_panel, ow, ic, ih, iw, sparam.ohw_cur_index,
sparam.output_block_size);
}
namespace megdnn {
template class StrategyFuse8x12x1Nchw44K3x3S2<float, float,
megdnn::PostprocessMode::FLOAT>;
} // namespace megdnn
#endif
// vim: syntax=cpp.doxygen
......@@ -1838,7 +1838,16 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COL_S2_MK4_PACK_F32) {
check_conv_bias(args, handle(), "IM2COLMATMUL:ARMV7_F32_MK4_PACK_4X12");
#endif
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COL_S2_MK4_PACK_F32_FUSE) {
using namespace conv_bias;
std::vector<conv_bias::TestArg> args = get_nchw44_conv_bias_args(
{3}, 2, false, false, false, false, false, true, true,false);
#if MEGDNN_AARCH64
check_conv_bias(args, handle(), "IM2COLMATMUL:AARCH64_F32_MK4_K8X12X1");
#elif MEGDNN_ARMV7
check_conv_bias(args, handle(), "IM2COLMATMUL:ARMV7_F32_MK4_PACK_4X12");
#endif
}
/***************************** Conv1x1 Algo Test ***********************/
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_F32) {
using namespace conv_bias;
......
......@@ -708,6 +708,66 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_INT8_NCHW44_DOT) {
}
TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_FLOAT_NCHW44) {
constexpr size_t RUNS = 40;
std::vector<DType> data_type = {
dtype::Float32(), dtype::Float32(),
dtype::Float32(), dtype::Float32()};
auto bench_case = [&](size_t N, size_t IC, size_t OC, size_t H, size_t W,
size_t FS, size_t group, size_t P, size_t S,
bool is_nchw = false) {
param::ConvBias param;
param.nonlineMode = param::ConvBias::NonlineMode::RELU;
param.pad_h = P;
param.pad_w = P;
param.stride_h = S;
param.stride_w = S;
param.sparse = param::ConvBias::Sparse::DENSE;
param.format = param::ConvBias::Format::NCHW44;
auto OH = (H + 2 * P - FS) / static_cast<size_t>(S) + 1;
auto OW = (W + 2 * P - FS) / static_cast<size_t>(S) + 1;
TensorShape src = {N, IC / 4, H, W, 4};
TensorShape filter = {OC / 4, IC / 4, FS, FS, 4, 4};
if (group > 1) {
filter = {group, OC / group / 4, IC / group / 4, FS, FS, 4, 4};
param.sparse = param::ConvBias::Sparse::GROUP;
}
if (is_nchw) {
src = {N, IC, H, W};
filter = {OC / 4, FS, FS, IC, 4};
}
TensorShape bias = {1, OC / 4, 1, 1, 4};
TensorShape dst = {N, OC / 4, OH, OW, 4};
SmallVector<TensorShape> shapes{src, filter, bias, {}, dst};
float computations =
(((IC / group) * FS * FS + 1) * dst.total_nr_elems() * 2 +
dst.total_nr_elems()) *
1e-6;
std::vector<std::pair<SmallVector<TensorShape>, float>> shape_arg = {
std::make_pair(shapes, computations)};
benchmark_impl(param, shape_arg, ".+", RUNS, {4, {4, 5, 6, 7}},
{1, {7}}, data_type);
};
bench_case(1, 64, 64, 56, 56, 3, 1, 1, 2);
bench_case(1, 128, 128, 28, 28, 3, 1, 1, 2);
bench_case(1, 256, 256, 14, 14, 3, 1, 1, 2);
bench_case(1, 512, 512, 7, 7, 3, 1, 1, 2);
bench_case(1, 64, 64, 56, 56, 3, 4, 1, 2);
bench_case(1, 128, 128, 28, 28, 3, 4, 1, 2);
bench_case(1, 256, 256, 14, 14, 3, 4, 1, 2);
bench_case(1, 512, 512, 7, 7, 3, 4, 1, 2);
bench_case(1, 64, 64, 56*2, 56*2, 3, 4, 1, 2);
bench_case(1, 128, 128, 28*2, 28*2, 3, 4, 1, 2);
bench_case(1, 256, 256, 14*2, 14*2, 3, 4, 1, 2);
bench_case(1, 512, 512, 7*2, 7*2, 3, 4, 1, 2);
}
TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS,
BENCHMARK_CONVBIAS_INT8_INT8_INT8_STRIDE2) {
constexpr size_t RUNS = 50;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册