提交 4d56371e 编写于 作者: M Megvii Engine Team

refactor(dnn/arm): split arm direct kernel to cut compile time

GitOrigin-RevId: b06fba83eb05ad2109d8dcc78e6a9c88498c093d
上级 55844d3e
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/f32_direct_nchw44_kern.h"
#include "src/arm_common/conv_bias/opr_impl.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/fallback/conv_bias/common.h"
namespace megdnn {
namespace arm_common {
namespace conv_bias {
//! Pack the source tensor for the stride-1 NCHW44 direct conv kernels.
//! For every group of ic_step(=4) channels, each input row of iw pixels is
//! copied into the packed buffer framed by zeroed left (pw) and right
//! (pad_right) padding, and pad_top/pad_bottom fully-zeroed rows of width
//! iw2 are emitted before/after the image. Each "pixel" is 4 floats.
template <>
void pack_src_fp32_nchw44<1>(float* sptr_base, const float* sptr_origin,
                             const int, const int pw, const int pad_right,
                             const int ih, const int iw, const int iw2,
                             const int pad_top, const int pad_bottom,
                             const int ic, const int ic_stride) {
    constexpr int ic_step = 4;
    for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) {
        const float* src_row = sptr_origin + ic_idx * ic_stride;
        float* dst = sptr_base;
        //! zero the top padding rows
        memset(dst, 0, sizeof(float) * iw2 * pad_top * ic_step);
        dst += iw2 * pad_top * ic_step;
        for (int ih_idx = 0; ih_idx < ih; ++ih_idx) {
            //! left pad, row payload, right pad
            memset(dst, 0, sizeof(float) * pw * ic_step);
            dst += pw * ic_step;
            memcpy(dst, src_row, sizeof(float) * iw * ic_step);
            dst += iw * ic_step;
            src_row += iw * ic_step;
            memset(dst, 0, sizeof(float) * pad_right * ic_step);
            dst += pad_right * ic_step;
        }
        //! zero the bottom padding rows
        memset(dst, 0, sizeof(float) * iw2 * pad_bottom * ic_step);
        dst += iw2 * pad_bottom * ic_step;
        sptr_base = dst;
    }
}
namespace {
//! Deinterleave 8 consecutive NCHW44 pixels (4 floats each), whose first
//! pixel sits at an EVEN padded position iw_idx, into the even half
//! (starting at iw_idx / 2) and the odd half (starting at
//! odd_start + iw_idx / 2) of the packed row.
static inline void odd_even_split_iw8_even(float* sptr_base, const float* sptr,
                                           const int odd_start,
                                           const int src_idx,
                                           const int iw_idx) {
    constexpr int ic_step = 4;
    const float* src = sptr + src_idx * ic_step;
    float* even_dst = sptr_base + iw_idx / 2 * ic_step;
    float* odd_dst = sptr_base + (odd_start + iw_idx / 2) * ic_step;
    float32x4_t pix[8];
    //! load all 8 pixels first, then scatter (keeps the original
    //! load-before-store ordering)
    for (int i = 0; i < 8; ++i) {
        pix[i] = vld1q_f32(src + i * ic_step);
    }
    for (int i = 0; i < 4; ++i) {
        vst1q_f32(even_dst + i * ic_step, pix[2 * i]);
        vst1q_f32(odd_dst + i * ic_step, pix[2 * i + 1]);
    }
}
//! Same as odd_even_split_iw8_even, but the first of the 8 pixels sits at an
//! ODD padded position iw_idx, so the parity alternation starts with the odd
//! half; the even half therefore begins at (iw_idx + 1) / 2.
static inline void odd_even_split_iw8_odd(float* sptr_base, const float* sptr,
                                          const int odd_start,
                                          const int src_idx, const int iw_idx) {
    constexpr int ic_step = 4;
    const float* src = sptr + src_idx * ic_step;
    float* even_dst = sptr_base + (iw_idx + 1) / 2 * ic_step;
    float* odd_dst = sptr_base + (odd_start + iw_idx / 2) * ic_step;
    float32x4_t pix[8];
    //! load all 8 pixels first, then scatter (keeps the original
    //! load-before-store ordering)
    for (int i = 0; i < 8; ++i) {
        pix[i] = vld1q_f32(src + i * ic_step);
    }
    for (int i = 0; i < 4; ++i) {
        vst1q_f32(odd_dst + i * ic_step, pix[2 * i]);
        vst1q_f32(even_dst + i * ic_step, pix[2 * i + 1]);
    }
}
} // namespace
//! Pack the source tensor for the stride-2 NCHW44 direct conv kernels.
//! As in the stride-1 variant the image is framed with zero padding, but
//! within each packed row of iw2 pixels the positions are additionally split
//! by parity: padded position x is stored at x / 2 when x is even and at
//! odd_start + x / 2 when x is odd. Each "pixel" is ic_step(=4) floats.
template <>
void pack_src_fp32_nchw44<2>(float* sptr_base, const float* sptr_origin,
                             const int ph, const int pw, const int pad_right,
                             const int ih, const int iw, const int iw2,
                             const int pad_top, const int pad_bottom,
                             const int ic, const int ic_stride) {
    constexpr int ic_step = 4;
    //! first index of the odd half inside a packed row
    int odd_start = megdnn::div_ceil(iw2, 2);
    float32x4_t zero_v = vdupq_n_f32(0.f);
    MEGDNN_MARK_USED_VAR(ph);
    //! true when the first real source pixel lands on an even padded position
    bool even_start = pw % 2 == 0;
    rep_step(ic_idx, ic, ic_step) {
        const float* sptr = sptr_origin + ic_idx * ic_stride;
        //! zero the top padding rows
        memset(sptr_base, 0, sizeof(float) * iw2 * pad_top * ic_step);
        sptr_base += iw2 * pad_top * ic_step;
        rep(ih_idx, ih) {
            //! iw_idx tracks the current padded position within the row
            int iw_idx = 0;
            //! left padding: pw zero pixels scattered to the even/odd halves
            rep(idx, pw) {
                if (iw_idx % 2 == 0) {
                    vst1q_f32(sptr_base + iw_idx / 2 * ic_step, zero_v);
                } else {
                    vst1q_f32(sptr_base + (odd_start + iw_idx / 2) * ic_step,
                              zero_v);
                }
                ++iw_idx;
            }
            int src_idx = 0;
            //! main body: split 8 source pixels at a time with NEON, picking
            //! the helper that matches the parity of the starting position
            if (even_start) {
                for (; src_idx + 7 < iw; src_idx += 8) {
                    odd_even_split_iw8_even(sptr_base, sptr, odd_start, src_idx,
                                            iw_idx);
                    iw_idx += 8;
                }
            } else {
                for (; src_idx + 7 < iw; src_idx += 8) {
                    odd_even_split_iw8_odd(sptr_base, sptr, odd_start, src_idx,
                                           iw_idx);
                    iw_idx += 8;
                }
            }
            //! scalar tail: remaining source pixels one at a time
            for (; src_idx < iw; ++src_idx) {
                if (iw_idx % 2 == 0) {
                    vst1q_f32(sptr_base + iw_idx / 2 * ic_step,
                              vld1q_f32(sptr + src_idx * ic_step));
                } else {
                    vst1q_f32(sptr_base + (odd_start + iw_idx / 2) * ic_step,
                              vld1q_f32(sptr + src_idx * ic_step));
                }
                ++iw_idx;
            }
            //! right padding: zero pixels up to iw2
            rep(idx, pad_right) {
                if (iw_idx % 2 == 0) {
                    vst1q_f32(sptr_base + iw_idx / 2 * ic_step, zero_v);
                } else {
                    vst1q_f32(sptr_base + (odd_start + iw_idx / 2) * ic_step,
                              zero_v);
                }
                ++iw_idx;
            }
            sptr_base += iw2 * ic_step;
            sptr += iw * ic_step;
        }
        //! zero the bottom padding rows
        memset(sptr_base, 0, sizeof(float) * iw2 * pad_bottom * ic_step);
        sptr_base += iw2 * pad_bottom * ic_step;
    }
}
} // namespace conv_bias
} // namespace arm_common
} // namespace megdnn
// vim: syntax=cpp.doxygen
/**
 * \file
 * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h"
// Explicit instantiation of the 2x2 stride-1 NCHW44 direct conv kernels in a
// dedicated translation unit to cut compile time.
INSTANTIATION_CONV_S1(2);
\ No newline at end of file
/**
 * \file
 * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s2.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h"
// Explicit instantiation of the 2x2 stride-2 NCHW44 direct conv kernels in a
// dedicated translation unit to cut compile time.
INSTANTIATION_CONV_S2(2);
\ No newline at end of file
/**
 * \file
 * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s1.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h"
// Explicit instantiation of the 3x3 stride-1 NCHW44 direct conv kernels in a
// dedicated translation unit to cut compile time.
INSTANTIATION_CONV_S1(3);
\ No newline at end of file
/**
 * \file
 * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s2.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h"
// Explicit instantiation of the 3x3 stride-2 NCHW44 direct conv kernels in a
// dedicated translation unit to cut compile time.
INSTANTIATION_CONV_S2(3);
\ No newline at end of file
/**
 * \file
 * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s1.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h"
// Explicit instantiation of the 5x5 stride-1 NCHW44 direct conv kernels in a
// dedicated translation unit to cut compile time.
INSTANTIATION_CONV_S1(5);
\ No newline at end of file
/**
 * \file
 * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s2.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h"
// Explicit instantiation of the 5x5 stride-2 NCHW44 direct conv kernels in a
// dedicated translation unit to cut compile time.
INSTANTIATION_CONV_S2(5);
\ No newline at end of file
/**
 * \file
 * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s1.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h"
// Explicit instantiation of the 7x7 stride-1 NCHW44 direct conv kernels in a
// dedicated translation unit to cut compile time.
INSTANTIATION_CONV_S1(7);
\ No newline at end of file
/**
 * \file
 * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s2.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h"
// Explicit instantiation of the 7x7 stride-2 NCHW44 direct conv kernels in a
// dedicated translation unit to cut compile time.
INSTANTIATION_CONV_S2(7);
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/f32_direct_stride1_nchw44_kern.cpp
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
......@@ -12,7 +12,7 @@
*/
#include "megdnn/arch.h"
#include "src/arm_common/conv_bias/fp32/f32_direct_stride1_nchw44_kern.h"
#include "src/arm_common/conv_bias/fp32/f32_direct_nchw44_kern.h"
#include "src/arm_common/conv_bias/intrinsic_helper.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/simd_macro/marm_neon.h"
......@@ -24,21 +24,21 @@ using namespace megdnn;
using namespace arm_common;
namespace {
template <int src_idx, int weight_idx, int c_dim, typename Func, int ow_block,
typename T, typename T2, typename T3, typename T4>
template <int src_idx, int weight_idx, int c_dim, int ow_block, typename T,
typename T2, typename T3, typename T4>
struct ShiftCalHelper {
static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight);
};
template <int src_idx, int weight_idx, typename Func, typename T, typename T2,
typename T3, typename T4>
struct ShiftCalHelper<src_idx, weight_idx, 2, Func, 8, T, T2, T3, T4> {
template <int src_idx, int weight_idx, typename T, typename T2, typename T3,
typename T4>
struct ShiftCalHelper<src_idx, weight_idx, 2, 8, T, T2, T3, T4> {
static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) {
#define cb(step, lane) \
c[0][step] = Func::template impl<lane>(c[0][step], weight[0][lane], \
src[(step + src_idx) % 8]); \
c[1][step] = Func::template impl<lane>(c[1][step], weight[1][lane], \
src[(step + src_idx) % 8]);
#define cb(step, lane) \
c[0][step] = vfmaq_laneq_f32(c[0][step], weight[0][lane], \
src[(step + src_idx) % 8], lane); \
c[1][step] = vfmaq_laneq_f32(c[1][step], weight[1][lane], \
src[(step + src_idx) % 8], lane);
UNROLL_CALL_RAW(8, cb, 0);
UNROLL_CALL_RAW(8, cb, 1);
......@@ -47,15 +47,15 @@ struct ShiftCalHelper<src_idx, weight_idx, 2, Func, 8, T, T2, T3, T4> {
#undef cb
}
};
template <int src_idx, int weight_idx, typename Func, typename T, typename T2,
typename T3, typename T4>
struct ShiftCalHelper<src_idx, weight_idx, 2, Func, 4, T, T2, T3, T4> {
template <int src_idx, int weight_idx, typename T, typename T2, typename T3,
typename T4>
struct ShiftCalHelper<src_idx, weight_idx, 2, 4, T, T2, T3, T4> {
static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) {
#define cb(step, lane) \
c[0][step] = Func::template impl<lane>(c[0][step], weight[0][lane], \
src[(step + src_idx) % 4]); \
c[1][step] = Func::template impl<lane>(c[1][step], weight[1][lane], \
src[(step + src_idx) % 4]);
#define cb(step, lane) \
c[0][step] = vfmaq_laneq_f32(c[0][step], weight[0][lane], \
src[(step + src_idx) % 4], lane); \
c[1][step] = vfmaq_laneq_f32(c[1][step], weight[1][lane], \
src[(step + src_idx) % 4], lane);
UNROLL_CALL_RAW(4, cb, 0);
UNROLL_CALL_RAW(4, cb, 1);
......@@ -64,13 +64,13 @@ struct ShiftCalHelper<src_idx, weight_idx, 2, Func, 4, T, T2, T3, T4> {
#undef cb
}
};
template <int src_idx, int weight_idx, typename Func, typename T, typename T2,
typename T3, typename T4>
struct ShiftCalHelper<src_idx, weight_idx, 1, Func, 8, T, T2, T3, T4> {
template <int src_idx, int weight_idx, typename T, typename T2, typename T3,
typename T4>
struct ShiftCalHelper<src_idx, weight_idx, 1, 8, T, T2, T3, T4> {
static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) {
#define cb(step, lane) \
c[0][step] = Func::template impl<lane>(c[0][step], weight[0][lane], \
src[(step + src_idx) % 8]);
#define cb(step, lane) \
c[0][step] = vfmaq_laneq_f32(c[0][step], weight[0][lane], \
src[(step + src_idx) % 8], lane);
UNROLL_CALL_RAW(8, cb, 0);
UNROLL_CALL_RAW(8, cb, 1);
......@@ -79,13 +79,13 @@ struct ShiftCalHelper<src_idx, weight_idx, 1, Func, 8, T, T2, T3, T4> {
#undef cb
}
};
template <int src_idx, int weight_idx, typename Func, typename T, typename T2,
typename T3, typename T4>
struct ShiftCalHelper<src_idx, weight_idx, 1, Func, 4, T, T2, T3, T4> {
template <int src_idx, int weight_idx, typename T, typename T2, typename T3,
typename T4>
struct ShiftCalHelper<src_idx, weight_idx, 1, 4, T, T2, T3, T4> {
static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) {
#define cb(step, lane) \
c[0][step] = Func::template impl<lane>(c[0][step], weight[0][lane], \
src[(step + src_idx) % 4]);
#define cb(step, lane) \
c[0][step] = vfmaq_laneq_f32(c[0][step], weight[0][lane], \
src[(step + src_idx) % 4], lane);
UNROLL_CALL_RAW(4, cb, 0);
UNROLL_CALL_RAW(4, cb, 1);
......@@ -95,11 +95,11 @@ struct ShiftCalHelper<src_idx, weight_idx, 1, Func, 4, T, T2, T3, T4> {
}
};
template <int src_idx, int weight_idx, int c_dim, typename FUNC, int ow_block,
typename T, typename T2, typename T3>
template <int src_idx, int weight_idx, int c_dim, int ow_block, typename T,
typename T2, typename T3>
MEGDNN_ALWAYS_INLINE void cal_helper(T& c, T2& src, T3& weight) {
ShiftCalHelper<src_idx, weight_idx, c_dim, FUNC, ow_block, T, T2, T3,
int>::impl(c, src, weight);
ShiftCalHelper<src_idx, weight_idx, c_dim, ow_block, T, T2, T3, int>::impl(
c, src, weight);
};
template <int oc>
struct OCHelper {
......@@ -162,13 +162,11 @@ struct KerNeonXXs1Nchw44FP32<bias_mode, Op, remain_w, 2, oc_block, ow_block> {
0);
load_helper<ic_step, 0, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src,
weight);
cal_helper<0, 0, c_dim, ow_block>(c, src, weight);
src[0] = vld1q_f32(src_ptr + (ow_block)*ic_step);
load_helper<ic_step, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr, ld_weight_oc);
cal_helper<1, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src,
weight);
cal_helper<1, 0, c_dim, ow_block>(c, src, weight);
src_ptr += ld_src_iw;
weight_ptr += ld_weight_fh;
}
......@@ -209,18 +207,15 @@ struct KerNeonXXs1Nchw44FP32<bias_mode, Op, remain_w, 3, oc_block, ow_block> {
0);
load_helper<ic_step, 0, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src,
weight);
cal_helper<0, 0, c_dim, ow_block>(c, src, weight);
src[0] = vld1q_f32(src_ptr + (ow_block)*ic_step);
load_helper<ic_step, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr, ld_weight_oc);
cal_helper<1, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src,
weight);
cal_helper<1, 0, c_dim, ow_block>(c, src, weight);
src[1] = vld1q_f32(src_ptr + (ow_block + 1) * ic_step);
load_helper<ic_step, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr, ld_weight_oc);
cal_helper<2, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src,
weight);
cal_helper<2, 0, c_dim, ow_block>(c, src, weight);
src_ptr += ld_src_iw;
weight_ptr += ld_weight_fh;
}
......@@ -260,32 +255,27 @@ struct KerNeonXXs1Nchw44FP32<bias_mode, Op, remain_w, 5, oc_block, ow_block> {
0);
load_helper<ic_step, 0, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src,
weight);
cal_helper<0, 0, c_dim, ow_block>(c, src, weight);
src[0] = vld1q_f32(src_ptr + (ow_block)*ic_step);
load_helper<ic_step, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr, ld_weight_oc);
cal_helper<1, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src,
weight);
cal_helper<1, 0, c_dim, ow_block>(c, src, weight);
src[1] = vld1q_f32(src_ptr + (ow_block + 1) * ic_step);
load_helper<ic_step, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr, ld_weight_oc);
cal_helper<2, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src,
weight);
cal_helper<2, 0, c_dim, ow_block>(c, src, weight);
src[2] = vld1q_f32(src_ptr + (ow_block + 2) * ic_step);
load_helper<ic_step, 3 * ld_weight, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr, ld_weight_oc);
cal_helper<3, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src,
weight);
cal_helper<3, 0, c_dim, ow_block>(c, src, weight);
src[3] = vld1q_f32(src_ptr + (ow_block + 3) * ic_step);
load_helper<ic_step, 4 * ld_weight, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr, ld_weight_oc);
cal_helper<4, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src,
weight);
cal_helper<4, 0, c_dim, ow_block>(c, src, weight);
src_ptr += ld_src_iw;
weight_ptr += ld_weight_fh;
}
......@@ -326,44 +316,37 @@ struct KerNeonXXs1Nchw44FP32<bias_mode, Op, remain_w, 7, oc_block, ow_block> {
0);
load_helper<ic_step, 0, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src,
weight);
cal_helper<0, 0, c_dim, ow_block>(c, src, weight);
src[0] = vld1q_f32(src_ptr + (ow_block)*ic_step);
load_helper<ic_step, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr, ld_weight_oc);
cal_helper<1, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src,
weight);
cal_helper<1, 0, c_dim, ow_block>(c, src, weight);
src[1] = vld1q_f32(src_ptr + (ow_block + 1) * ic_step);
load_helper<ic_step, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr, ld_weight_oc);
cal_helper<2, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src,
weight);
cal_helper<2, 0, c_dim, ow_block>(c, src, weight);
src[2] = vld1q_f32(src_ptr + (ow_block + 2) * ic_step);
load_helper<ic_step, 3 * ld_weight, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr, ld_weight_oc);
cal_helper<3, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src,
weight);
cal_helper<3, 0, c_dim, ow_block>(c, src, weight);
src[3] = vld1q_f32(src_ptr + (ow_block + 3) * ic_step);
load_helper<ic_step, 4 * ld_weight, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr, ld_weight_oc);
cal_helper<4, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src,
weight);
cal_helper<4, 0, c_dim, ow_block>(c, src, weight);
src[4] = vld1q_f32(src_ptr + (ow_block + 4) * ic_step);
load_helper<ic_step, 5 * ld_weight, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr, ld_weight_oc);
cal_helper<5, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src,
weight);
cal_helper<5, 0, c_dim, ow_block>(c, src, weight);
src[5] = vld1q_f32(src_ptr + (ow_block + 5) * ic_step);
load_helper<ic_step, 6 * ld_weight, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr, ld_weight_oc);
cal_helper<6, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src,
weight);
cal_helper<6, 0, c_dim, ow_block>(c, src, weight);
src_ptr += ld_src_iw;
weight_ptr += ld_weight_fh;
}
......@@ -375,36 +358,14 @@ struct KerNeonXXs1Nchw44FP32<bias_mode, Op, remain_w, 7, oc_block, ow_block> {
} // namespace
void conv_bias::pack_src_fp32_nchw44_stride1(
float* sptr_base, const float* sptr_origin, const int, const int pw,
const int pad_right, const int ih, const int iw, const int iw2,
const int pad_top, const int pad_bottom, const int ic,
const int ic_stride) {
constexpr int ic_step = 4;
rep_step(ic_idx, ic, ic_step) {
const float* sptr = sptr_origin + ic_idx * ic_stride;
memset(sptr_base, 0, sizeof(float) * iw2 * pad_top * ic_step);
sptr_base += iw2 * pad_top * ic_step;
rep(ih_idx, ih) {
memset(sptr_base, 0, sizeof(float) * pw * ic_step);
sptr_base += pw * ic_step;
memcpy(sptr_base, sptr, sizeof(float) * iw * ic_step);
sptr_base += iw * ic_step;
sptr += iw * ic_step;
memset(sptr_base, 0, sizeof(float) * pad_right * ic_step);
sptr_base += pad_right * ic_step;
}
memset(sptr_base, 0, sizeof(float) * iw2 * pad_bottom * ic_step);
sptr_base += iw2 * pad_bottom * ic_step;
}
}
template <BiasMode bias_mode, typename Op, int filter_size>
static void conv_direct_stride1_fp32_nchw44(
const float32_t* src, const float32_t* filter, const float32_t* bias,
float32_t*, float32_t* dst, const int oc, const int ic, const int ih,
const int iw, const int oh, const int oh_block, const int ow,
const Op& op, const int, const int) {
template <BiasMode bias_mode, typename Op, int filter_size, int stride>
void conv_bias::conv_direct_fp32_nchw44(const float* src, const float* filter,
const float* bias, float*, float* dst,
const int oc, const int ic,
const int ih, const int iw,
const int oh, const int oh_block,
const int ow, const Op& op, const int,
const int) {
constexpr int fh = filter_size;
constexpr int fw = filter_size;
constexpr int ic_step = 4;
......@@ -518,55 +479,23 @@ static void conv_direct_stride1_fp32_nchw44(
}
}
#define CONSTRUCT_FUNC(filter_size) \
template <BiasMode bias_mode, typename Op> \
void conv_bias:: \
conv_direct_stride1_##filter_size##x##filter_size##_fp32_nchw44( \
const float32_t* src, const float32_t* filter, \
const float32_t* bias, float32_t* temp, float32_t* dst, \
const int oc, const int ic, const int ih, const int iw, \
const int oh, const int oh_block, const int ow, \
const Op& op, const int ph, const int pw) { \
conv_direct_stride1_fp32_nchw44<bias_mode, Op, filter_size>( \
src, filter, bias, temp, dst, oc, ic, ih, iw, oh, oh_block, \
ow, op, ph, pw); \
}
CONSTRUCT_FUNC(2);
CONSTRUCT_FUNC(3);
CONSTRUCT_FUNC(5);
CONSTRUCT_FUNC(7);
#undef CONSTRUCT_FUNC
#define INSTANTIATION(stride, i, bias, Op) \
template void conv_bias::conv_direct_##stride##_##i##x##i##_fp32_nchw44< \
bias, Op>(const float32_t*, const float32_t*, const float32_t*, \
float32_t*, float32_t*, const int, const int, const int, \
const int, const int, const int, const int, const Op&, \
const int, const int);
#define FOR_OP(stride, i, bias) \
INSTANTIATION(stride, i, bias, NoneOp<dt_float32>) \
INSTANTIATION(stride, i, bias, ReluOp<dt_float32>) \
INSTANTIATION(stride, i, bias, HSwishOp<dt_float32>) \
INSTANTIATION(stride, i, bias, SigmoidOp<dt_float32>)
#define FOR_BIAS(stride, i) \
FOR_OP(stride, i, BiasMode::NO_BIAS) \
FOR_OP(stride, i, BiasMode::BROADCAST_CHANNEL_BIAS) \
FOR_OP(stride, i, BiasMode::BIAS)
#define FOR_FILTER(stride) \
FOR_BIAS(stride, 2) \
FOR_BIAS(stride, 3) \
FOR_BIAS(stride, 5) \
FOR_BIAS(stride, 7)
FOR_FILTER(stride1)
#undef FOR_STRIDE
#undef FOR_FILTER
#undef FOR_IC
#undef FOR_BIAS
#undef FOR_NONLINEAR
#undef FOR_REMAIN
#undef INSTANTIATION
#define INSTANTIATION(filter_size, bias_mode, Op) \
template void \
conv_bias::conv_direct_fp32_nchw44<bias_mode, Op, filter_size, 1>( \
const float* src, const float* filter, const float* bias, float*, \
float* dst, const int oc, const int ic, const int ih, \
const int iw, const int oh, const int oh_block, const int ow, \
const Op& op, const int, const int);
#define FOR_OP(filter_size, bias) \
INSTANTIATION(filter_size, bias, NoneOp<dt_float32>) \
INSTANTIATION(filter_size, bias, ReluOp<dt_float32>) \
INSTANTIATION(filter_size, bias, HSwishOp<dt_float32>) \
INSTANTIATION(filter_size, bias, SigmoidOp<dt_float32>)
#define INSTANTIATION_CONV_S1(filter_size) \
FOR_OP(filter_size, BiasMode::NO_BIAS) \
FOR_OP(filter_size, BiasMode::BROADCAST_CHANNEL_BIAS) \
FOR_OP(filter_size, BiasMode::BIAS)
// vim: syntax=cpp.doxygen
\ No newline at end of file
/**
 * \file
 * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s1.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
// Explicit instantiation of the 2x2 stride-1 NCHW->NCHW44 direct conv kernels
// in a dedicated translation unit to cut compile time.
INSTANCE_CONV(2, 1);
\ No newline at end of file
/**
 * \file
 * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s2.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
// Explicit instantiation of the 2x2 stride-2 NCHW->NCHW44 direct conv kernels
// in a dedicated translation unit to cut compile time.
INSTANCE_CONV(2, 2);
/**
 * \file
 * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s1.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
// Explicit instantiation of the 3x3 stride-1 NCHW->NCHW44 direct conv kernels
// in a dedicated translation unit to cut compile time.
INSTANCE_CONV(3, 1);
/**
 * \file
 * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s2.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
// Explicit instantiation of the 3x3 stride-2 NCHW->NCHW44 direct conv kernels
// in a dedicated translation unit to cut compile time.
INSTANCE_CONV(3, 2);
/**
 * \file
 * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
// Explicit instantiation of the 5x5 stride-1 NCHW->NCHW44 direct conv kernels
// in a dedicated translation unit to cut compile time.
INSTANCE_CONV(5, 1);
/**
 * \file
 * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
// Explicit instantiation of the 5x5 stride-2 NCHW->NCHW44 direct conv kernels
// in a dedicated translation unit to cut compile time.
INSTANCE_CONV(5, 2);
/**
 * \file
 * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
// Explicit instantiation of the 7x7 stride-1 NCHW->NCHW44 direct conv kernels
// in a dedicated translation unit to cut compile time.
INSTANCE_CONV(7, 1);
/**
 * \file
 * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
// Explicit instantiation of the 7x7 stride-2 NCHW->NCHW44 direct conv kernels
// in a dedicated translation unit to cut compile time.
INSTANCE_CONV(7, 2);
......@@ -13,8 +13,8 @@
#include "megdnn/oprs.h"
#include "src/arm_common/conv_bias/block_helper.h"
#include "src/arm_common/conv_bias/fp32/algos.h"
#include "src/arm_common/conv_bias/fp32/f32_direct_stride1_nchw44_kern.h"
#include "src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw44_kern.h"
#include "src/arm_common/conv_bias/fp32/f32_direct_nchw44_kern.h"
#include "src/arm_common/elemwise_op.h"
#include "midout.h"
......@@ -112,17 +112,11 @@ static void do_conv_kern(const WorkspaceBundle& bundle,
const size_t src_size = get_perthread_cache_bytes(ic, ih2, iw2);
float* sptr = reinterpret_cast<float*>((int8_t*)bundle.get(0) +
ncb_index.thread_id * src_size);
if (stride == 1) {
conv_bias::pack_src_fp32_nchw44_stride1(
sptr, origin_sptr, ph, pw, remain_right_pad,
ih_real - src_top_pad - src_bottom_pad, iw, iw2, src_top_pad,
src_bottom_pad, ic, ih * iw);
} else {
conv_bias::pack_src_fp32_nchw44_stride2(
sptr, origin_sptr, ph, pw, remain_right_pad,
ih_real - src_top_pad - src_bottom_pad, iw, iw2, src_top_pad,
src_bottom_pad, ic, ih * iw);
}
conv_bias::pack_src_fp32_nchw44<stride>(
sptr, origin_sptr, ph, pw, remain_right_pad,
ih_real - src_top_pad - src_bottom_pad, iw, iw2, src_top_pad,
src_bottom_pad, ic, ih * iw);
const float* fptr =
kern_param.filter<dt_float32>(group_id) + oc_idx * fh * fw * ic;
......@@ -135,25 +129,9 @@ static void do_conv_kern(const WorkspaceBundle& bundle,
kern_param.bias<dt_float32>(batch_id, group_id) + bias_offset;
Op op;
if (stride == 1) {
#define KERN1_NCHW44_CONV(filter) \
conv_bias::conv_direct_stride1_##filter##x##filter##_fp32_nchw44< \
\
bias_mode, Op>(sptr, fptr, bptr, nullptr, dst, oc_block, ic, \
ih_real, iw2, oh, oh_block_real, ow, op, ph, pw)
DISPATCH_FILTER(filter, KERN1_NCHW44_CONV);
#undef KERN1_NCHW44_CONV
} else {
#define KERN1_NCHW44_CONV(filter) \
conv_bias::conv_direct_stride2_##filter##x##filter##_fp32_nchw44< \
\
bias_mode, Op>(sptr, fptr, bptr, nullptr, dst, oc_block, ic, \
ih_real, iw2, oh, oh_block_real, ow, op, ph, pw)
DISPATCH_FILTER(filter, KERN1_NCHW44_CONV);
#undef KERN1_NCHW44_CONV
}
conv_bias::conv_direct_fp32_nchw44<bias_mode, Op, filter, stride>(
sptr, fptr, bptr, nullptr, dst, oc_block, ic, ih_real, iw2, oh,
oh_block_real, ow, op, ph, pw);
}
} // namespace
......
......@@ -15,26 +15,20 @@
namespace megdnn {
namespace arm_common {
namespace conv_bias {
#define KERN(stride, i, layout) \
template <BiasMode bias_mode, typename Op> \
void conv_direct_##stride##_##i##x##i##_fp32_##layout( \
const float* src, const float* filter, const float* bias, \
float* temp, float* dst, const int oc, const int ic, const int ih, \
const int iw, const int oh, const int oh_block, const int ow, \
const Op& op, const int ph, const int pw);
KERN(stride1, 2, nchw44)
KERN(stride1, 3, nchw44)
KERN(stride1, 5, nchw44)
KERN(stride1, 7, nchw44)
#undef KERN
template <BiasMode bias_mode, typename Op, int filter_size, int stride>
void conv_direct_fp32_nchw44(const float* src, const float* filter,
const float* bias, float*, float* dst,
const int oc, const int ic, const int ih,
const int iw, const int oh, const int oh_block,
const int ow, const Op& op, const int, const int);
template <int stride>
void pack_src_fp32_nchw44(float* sptr_base, const float* sptr_origin, const int,
const int pw, const int pad_right, const int ih,
const int iw, const int iw2, const int pad_top,
const int pad_bottom, const int ic,
const int ic_stride);
void pack_src_fp32_nchw44_stride1(float* sptr_base, const float* sptr_origin,
const int ph, const int pw,
const int pad_right, const int ih,
const int iw, const int iw2,
const int pad_top, const int pad_bottom,
const int ic, const int ic_stride);
} // namespace conv_bias
} // namespace arm_common
} // namespace megdnn
......@@ -120,7 +120,8 @@ static void pack_weight(const WorkspaceBundle& bundle,
kern_param.filter<dt_float32>(group_id) + oc_idx * fh * fw * ic;
auto packed_weight = reinterpret_cast<float*>(bundle.get(1)) +
group_id * oc * ic * fh * fw + oc_idx * ic * fh * fw;
pack_weight_fp32_nchw_nchw44(fptr, packed_weight, oc_block, fh, fw, ic);
fp32_direct_nchw_nchw44::pack_weight_fp32_nchw_nchw44(fptr, packed_weight,
oc_block, fh, fw, ic);
}
template <size_t filter_size, BiasMode bias_mode, typename Op, size_t stride>
......@@ -180,7 +181,8 @@ static void do_conv_kern(const WorkspaceBundle& bundle,
kern_param.bias<dt_float32>(batch_id, group_id) + oc_idx;
Op op;
conv_direct_fp32_nchw_nchw44<bias_mode, Op, filter_size, stride>(
fp32_direct_nchw_nchw44::conv_direct_fp32_nchw_nchw44<bias_mode, Op,
filter_size, stride>(
sptr, packed_weight, bptr, nullptr, dst, oc_block, ic, ih_real, iw2,
oh, oh_block_real, ow, op, ph, pw);
}
......
......@@ -20,295 +20,12 @@
#include "src/fallback/conv_bias/common.h"
namespace megdnn {
namespace arm_common {
namespace {
/**
*\brief ShiftCalHelper is core calculate code
*\tparam src_idx is offset for src regs
*\tparam weight_idx is offset for weight regs
*\tparam T is type of output regs
*\tparam T2 is type of src regs
*\tparam T3 is type of weight regs
*/
//! Primary template: dispatches the per-step fma accumulation on the number of
//! output-channel register groups (c_dim). Only the c_dim == 1 and c_dim == 2
//! specializations below are defined; other values fail to link by design.
template <int src_idx, int weight_idx, int c_dim, typename Func, int stride,
          typename T, typename T2, typename T3>
struct ShiftCalHelper {
    static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight);
};
//! c_dim == 2: accumulate into two output-channel register groups (c[0] and
//! c[1]) for each of the 8 unrolled ow steps. The lane index
//! (step * stride + src_idx) % 4 selects the source element inside a 4-float
//! vector; / 4 selects which loaded source vector holds it.
template <int src_idx, int weight_idx, typename Func, int stride, typename T,
          typename T2, typename T3>
struct ShiftCalHelper<src_idx, weight_idx, 2, Func, stride, T, T2, T3> {
    static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) {
#define cb(step)                                                     \
    c[0][step] = Func::template impl<(step * stride + src_idx) % 4>( \
            c[0][step], weight[0][weight_idx],                       \
            src[(step * stride + src_idx) / 4]);                     \
    c[1][step] = Func::template impl<(step * stride + src_idx) % 4>( \
            c[1][step], weight[1][weight_idx],                       \
            src[(step * stride + src_idx) / 4]);

        UNROLL_CALL_RAW(8, cb);
#undef cb
    }
};
//! c_dim == 1: same as the c_dim == 2 specialization but accumulates a single
//! output-channel register group (c[0]) over the 8 unrolled ow steps.
template <int src_idx, int weight_idx, typename Func, int stride, typename T,
          typename T2, typename T3>
struct ShiftCalHelper<src_idx, weight_idx, 1, Func, stride, T, T2, T3> {
    static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) {
#define cb(step)                                                     \
    c[0][step] = Func::template impl<(step * stride + src_idx) % 4>( \
            c[0][step], weight[0][weight_idx],                       \
            src[(step * stride + src_idx) / 4]);

        UNROLL_CALL_RAW(8, cb);
#undef cb
    }
};
//! Thin forwarding wrapper so kernel bodies can write a one-line call instead
//! of spelling out the full ShiftCalHelper specialization.
template <int src_idx, int weight_idx, int c_dim, typename FUNC, int stride,
          typename T, typename T2, typename T3>
MEGDNN_ALWAYS_INLINE void cal_helper(T& c, T2& src, T3& weight) {
    ShiftCalHelper<src_idx, weight_idx, c_dim, FUNC, stride, T, T2, T3>::impl(
            c, src, weight);
};
//! Compile-time map from an output-channel block width to the number of
//! 4-channel accumulator register groups the kernel uses:
//! 4 channels -> 1 group, 8 channels -> 2 groups, anything else -> -1
//! (unsupported; downstream array sizes become invalid on purpose).
template <int oc>
struct OCHelper {
    static constexpr int val = -1;
};

template <>
struct OCHelper<4> {
    static constexpr int val = 1;
};

template <>
struct OCHelper<8> {
    static constexpr int val = 2;
};
/**
 * oc8_ow8(m = 8, n = 8) and oc4_ow8(m = 4, n = 8) gemm like kernel.
 * Primary template; the per-filter-size specializations below (7/5/3/2)
 * provide the actual implementations. Each impl() accumulates an ow_block-wide
 * strip of output for oc_block output channels across all input channels.
 **/
template <BiasMode bias_mode, typename Op, int remain_w, int filter_size,
          int oc_block, int stride, int ow_block>
struct KerNeonXXs2NchwNchw44FP32 {
    static void impl(const float32_t* src_ptr, const float32_t* weight_ptr,
                     const float32_t* bias_ptr, float32_t* dst_ptr, int ic,
                     int ih, int iw, int ld_dst_oc, const Op& op);
};
//! 7x7 filter specialization: for each input channel (loop_ic_step == 1) it
//! walks the 7 filter rows; per row it loads one input row and the row's 7
//! taps, then fma-accumulates each tap into the c[c_dim][8] accumulators.
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
          int stride, int ow_block>
struct KerNeonXXs2NchwNchw44FP32<bias_mode, Op, remain_w, 7, oc_block, stride,
                                 ow_block> {
    static void impl(const float32_t* src_ptr, const float32_t* weight_ptr,
                     const float32_t* bias_ptr, float32_t* dst_ptr, int ic,
                     int ih, int iw, int ld_dst_oc, const Op& op) {
        constexpr int loop_ic_step = 1;
        constexpr int filter_size = 7;
        constexpr int oc_step = 4;
        constexpr int simd_len = 4;
        //! number of 4-float vectors needed to cover one input row of the
        //! sliding window (rounded up)
        constexpr int src_reg_size =
                (ow_block * stride + filter_size - stride + simd_len - 1) /
                simd_len;
        //! strides, in floats: between filter rows / oc groups / ic slices
        constexpr int ld_weight_fw = oc_step * filter_size;
        const int ld_weight_oc = oc_step * filter_size * filter_size * ic;
        const int ld_weight_ic = oc_step * filter_size * filter_size;
        const int ld_src_ic = ih * iw;
        constexpr int c_dim = OCHelper<oc_block>::val;
        float32x4_t c[c_dim][8];
        init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step);

        for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
            float32x4_t src[src_reg_size];
            float32x4_t weight[c_dim][filter_size];
//! KERNEL_CB(step): process filter row `step` -- load the input row, load the
//! 7 taps of that filter row, accumulate all 7 taps.
#define KERNEL_CB(step)                                               \
    load_helper<src_reg_size, 0, simd_len, 0, Vld1q_f32>(             \
            src, src_ptr + step * iw, 0);                             \
    load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>(           \
            weight, weight_ptr + step * ld_weight_fw, ld_weight_oc);  \
    cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); \
    cal_helper<1, 1, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); \
    cal_helper<2, 2, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); \
    cal_helper<3, 3, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); \
    cal_helper<4, 4, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); \
    cal_helper<5, 5, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); \
    cal_helper<6, 6, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight);

            UNROLL_CALL_RAW(7, KERNEL_CB)
#undef KERNEL_CB

            src_ptr += ld_src_ic;
            weight_ptr += ld_weight_ic;
        }
        //! apply the post-op and store, honoring the ow tail (remain_w)
        store_ocx_ow8_remain_static<c_dim, remain_w, Op>(c, op, dst_ptr,
                                                         ld_dst_oc);
    }
};
//! 5x5 filter specialization: same structure as the 7x7 kernel, with 5 filter
//! rows of 5 taps each.
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
          int stride, int ow_block>
struct KerNeonXXs2NchwNchw44FP32<bias_mode, Op, remain_w, 5, oc_block, stride,
                                 ow_block> {
    static void impl(const float32_t* src_ptr, const float32_t* weight_ptr,
                     const float32_t* bias_ptr, float32_t* dst_ptr, int ic,
                     int ih, int iw, int ld_dst_oc, const Op& op) {
        constexpr int loop_ic_step = 1;
        constexpr int filter_size = 5;
        constexpr int oc_step = 4;
        constexpr int simd_len = 4;
        //! vectors needed to cover one input row of the window (rounded up)
        constexpr int src_reg_size =
                (ow_block * stride + filter_size - stride + simd_len - 1) /
                simd_len;
        //! strides, in floats: between filter rows / oc groups / ic slices
        constexpr int ld_weight_fw = oc_step * filter_size;
        const int ld_weight_oc = oc_step * filter_size * filter_size * ic;
        const int ld_weight_ic = oc_step * filter_size * filter_size;
        const int ld_src_ic = ih * iw;
        constexpr int c_dim = OCHelper<oc_block>::val;
        float32x4_t c[c_dim][8];
        init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step);

        for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
            float32x4_t src[src_reg_size];
            float32x4_t weight[c_dim][filter_size];
//! KERNEL_CB(step): process filter row `step` -- load the input row, load the
//! 5 taps of that filter row, accumulate all 5 taps.
#define KERNEL_CB(step)                                               \
    load_helper<src_reg_size, 0, simd_len, 0, Vld1q_f32>(             \
            src, src_ptr + step * iw, 0);                             \
    load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>(           \
            weight, weight_ptr + step * ld_weight_fw, ld_weight_oc);  \
    cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); \
    cal_helper<1, 1, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); \
    cal_helper<2, 2, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); \
    cal_helper<3, 3, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); \
    cal_helper<4, 4, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight);

            UNROLL_CALL_RAW(5, KERNEL_CB)
#undef KERNEL_CB

            src_ptr += ld_src_ic;
            weight_ptr += ld_weight_ic;
        }
        //! apply the post-op and store, honoring the ow tail (remain_w)
        store_ocx_ow8_remain_static<c_dim, remain_w, Op>(c, op, dst_ptr,
                                                         ld_dst_oc);
    }
};
//! 3x3 filter specialization: rows 0..2 are processed explicitly (no macro),
//! each row loading one input row plus its 3 taps and accumulating them.
//! NOTE(review): the original text had a stray
//! `namespace fp32_direct_nchw_nchw44 {` line between row 0 and row 1 -- an
//! artifact of interleaved diff output that splits the kernel body and cannot
//! compile. It is removed here; the row-0/1/2 sequence now matches the
//! structure of the 2x2/5x5/7x7 specializations.
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
          int stride, int ow_block>
struct KerNeonXXs2NchwNchw44FP32<bias_mode, Op, remain_w, 3, oc_block, stride,
                                 ow_block> {
    static void impl(const float32_t* src_ptr, const float32_t* weight_ptr,
                     const float32_t* bias_ptr, float32_t* dst_ptr, int ic,
                     int ih, int iw, int ld_dst_oc, const Op& op) {
        constexpr int loop_ic_step = 1;
        constexpr int filter_size = 3;
        constexpr int oc_step = 4;
        constexpr int simd_len = 4;
        //! vectors needed to cover one input row of the window (rounded up)
        constexpr int src_reg_size =
                (ow_block * stride + filter_size - stride + simd_len - 1) /
                simd_len;
        //! strides, in floats: between filter rows / oc groups / ic slices
        constexpr int ld_weight_fw = oc_step * filter_size;
        const int ld_weight_oc = oc_step * filter_size * filter_size * ic;
        const int ld_weight_ic = oc_step * filter_size * filter_size;
        const int ld_src_ic = ih * iw;
        constexpr int c_dim = OCHelper<oc_block>::val;
        float32x4_t c[c_dim][8];
        init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step);

        for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
            float32x4_t src[src_reg_size];
            float32x4_t weight[c_dim][filter_size];
            // row 0
            load_helper<src_reg_size, 0, simd_len, 0, Vld1q_f32>(src, src_ptr,
                                                                 0);
            load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>(
                    weight, weight_ptr, ld_weight_oc);
            cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight);
            cal_helper<1, 1, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight);
            cal_helper<2, 2, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight);
            // row 1
            load_helper<src_reg_size, 0, simd_len, 0, Vld1q_f32>(
                    src, src_ptr + iw, 0);
            load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>(
                    weight, weight_ptr + 1 * ld_weight_fw, ld_weight_oc);
            cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight);
            cal_helper<1, 1, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight);
            cal_helper<2, 2, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight);
            // row 2
            load_helper<src_reg_size, 0, simd_len, 0, Vld1q_f32>(
                    src, src_ptr + 2 * iw, 0);
            load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>(
                    weight, weight_ptr + 2 * ld_weight_fw, ld_weight_oc);
            cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight);
            cal_helper<1, 1, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight);
            cal_helper<2, 2, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight);

            src_ptr += ld_src_ic;
            weight_ptr += ld_weight_ic;
        }
        //! apply the post-op and store, honoring the ow tail (remain_w)
        store_ocx_ow8_remain_static<c_dim, remain_w, Op>(c, op, dst_ptr,
                                                         ld_dst_oc);
    }
};
//! 2x2 filter specialization: rows 0 and 1 processed explicitly, each row
//! loading one input row plus its 2 taps and accumulating them.
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
          int stride, int ow_block>
struct KerNeonXXs2NchwNchw44FP32<bias_mode, Op, remain_w, 2, oc_block, stride,
                                 ow_block> {
    static void impl(const float32_t* src_ptr, const float32_t* weight_ptr,
                     const float32_t* bias_ptr, float32_t* dst_ptr, int ic,
                     int ih, int iw, int ld_dst_oc, const Op& op) {
        constexpr int loop_ic_step = 1;
        constexpr int filter_size = 2;
        constexpr int oc_step = 4;
        constexpr int simd_len = 4;
        //! vectors needed to cover one input row of the window (rounded up)
        constexpr int src_reg_size =
                (ow_block * stride + filter_size - stride + simd_len - 1) /
                simd_len;
        //! strides, in floats: between filter rows / oc groups / ic slices
        constexpr int ld_weight_fw = oc_step * filter_size;
        const int ld_weight_oc = oc_step * filter_size * filter_size * ic;
        const int ld_weight_ic = oc_step * filter_size * filter_size;
        const int ld_src_ic = ih * iw;
        constexpr int c_dim = OCHelper<oc_block>::val;
        float32x4_t c[c_dim][8];
        init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step);

        for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
            float32x4_t src[src_reg_size];
            float32x4_t weight[c_dim][filter_size];
            // row 0
            load_helper<src_reg_size, 0, simd_len, 0, Vld1q_f32>(src, src_ptr,
                                                                 0);
            load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>(
                    weight, weight_ptr, ld_weight_oc);
            cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight);
            cal_helper<1, 1, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight);
            // row 1
            load_helper<src_reg_size, 0, simd_len, 0, Vld1q_f32>(
                    src, src_ptr + iw, 0);
            load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>(
                    weight, weight_ptr + 1 * ld_weight_fw, ld_weight_oc);
            cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight);
            cal_helper<1, 1, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight);

            src_ptr += ld_src_ic;
            weight_ptr += ld_weight_ic;
        }
        //! apply the post-op and store, honoring the ow tail (remain_w)
        store_ocx_ow8_remain_static<c_dim, remain_w, Op>(c, op, dst_ptr,
                                                         ld_dst_oc);
    }
};
void pack_weight_fp32_nchw_nchw44(const float32_t* in_ptr, float32_t* dst_ptr,
const int oc, const int kh, const int kw,
const int ic) {
static inline void pack_weight_fp32_nchw_nchw44(const float32_t* in_ptr,
float32_t* dst_ptr,
const int oc, const int kh,
const int kw, const int ic) {
constexpr int oc_step = 4;
const int filter_oc_stride = kh * kw * ic;
const int filter_ic_stride = kh * kw * oc_step;
......@@ -327,115 +44,15 @@ void pack_weight_fp32_nchw_nchw44(const float32_t* in_ptr, float32_t* dst_ptr,
}
}
}
template <BiasMode bias_mode, typename Op, int filter_size, int stride>
static void conv_direct_fp32_nchw_nchw44(
const float32_t* src, const float32_t* filter, const float32_t* bias,
float32_t*, float32_t* dst, const int oc, const int ic, const int ih,
const int iw, const int oh, const int oh_block, const int ow,
const Op& op, const int, const int) {
constexpr int fh = filter_size;
constexpr int fw = filter_size;
constexpr int ic_step = 1;
constexpr int big_oc_step = 8;
constexpr int oc_step = 4;
constexpr int ih_step = 1;
constexpr int oh_step = 1;
constexpr int ow_step = 8;
constexpr int stride_h = stride;
constexpr int stride_w = stride;
constexpr int pack_iw_len = 1;
void conv_direct_fp32_nchw_nchw44(const float32_t* src, const float32_t* filter,
const float32_t* bias, float32_t*,
float32_t* dst, const int oc, const int ic,
const int ih, const int iw, const int oh,
const int oh_block, const int ow,
const Op& op, const int, const int);
} // namespace fp32_direct_nchw_nchw44
const int img_stride = oh * ow;
const int ow_end = ow / ow_step * ow_step;
const int ow_remain = ow - ow_end;
const int oc_end = oc / big_oc_step * big_oc_step;
const int oc_remain = oc - oc_end;
const int ld_dst_oc = oc_step * img_stride;
using remain_fun = std::function<void(
const float32_t* src_ptr, const float32_t* weight_ptr,
const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc, const Op& op)>;
remain_fun kern_big_oc_remain = nullptr;
remain_fun kern_small_oc_remain = nullptr;
switch (ow_remain) {
#define cb(step) \
case step: \
kern_big_oc_remain = \
KerNeonXXs2NchwNchw44FP32<bias_mode, Op, step, filter_size, \
big_oc_step, stride, ow_step>::impl; \
kern_small_oc_remain = \
KerNeonXXs2NchwNchw44FP32<bias_mode, Op, step, filter_size, \
oc_step, stride, ow_step>::impl; \
break;
UNROLL_CALL_RAW(8, cb);
default:
megdnn_assert(0, "no remain %d for kern", ow_remain);
}
for (int oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) {
const int weight_offset = oc_idx * ic * fh * fw;
for (int oh_idx = 0; oh_idx < oh_block; oh_idx += oh_step) {
for (int ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
const int src_offset =
(oh_idx * stride_h * iw + ow_idx * stride_w * ih_step) *
ic_step * pack_iw_len;
const int dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step;
KerNeonXXs2NchwNchw44FP32<bias_mode, Op, ow_step, filter_size,
big_oc_step, stride,
ow_step>::impl(src + src_offset,
filter + weight_offset,
bias + oc_idx,
dst + dst_offset, ic,
ih, iw, ld_dst_oc, op);
}
if (ow_remain > 0) {
const int src_offset =
(oh_idx * stride_h * iw + ow_end * stride_w * ih_step) *
ic_step * pack_iw_len;
const int dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step;
kern_big_oc_remain(src + src_offset, filter + weight_offset,
bias + oc_idx, dst + dst_offset, ic, ih, iw,
ld_dst_oc, op);
}
}
}
if (oc_remain > 0) {
int oc_idx = oc_end;
const int weight_offset = oc_idx * ic * fh * fw;
for (int oh_idx = 0; oh_idx < oh_block; oh_idx += oh_step) {
for (int ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
const int src_offset =
(oh_idx * stride_h * iw + ow_idx * stride_w * ih_step) *
ic_step * pack_iw_len;
const int dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step;
KerNeonXXs2NchwNchw44FP32<bias_mode, Op, ow_step, filter_size,
oc_step, stride,
ow_step>::impl(src + src_offset,
filter + weight_offset,
bias + oc_idx,
dst + dst_offset, ic,
ih, iw, ld_dst_oc, op);
}
if (ow_remain > 0) {
const int src_offset =
(oh_idx * stride_h * iw + ow_end * stride_w * ih_step) *
ic_step * pack_iw_len;
const int dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step;
kern_small_oc_remain(src + src_offset, filter + weight_offset,
bias + oc_idx, dst + dst_offset, ic, ih,
iw, ld_dst_oc, op);
}
}
}
}
} // namespace
} // namespace arm_common
} // namespace megdnn
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw44_kern.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/opr_impl.h"
#include "src/fallback/conv_bias/common.h"
namespace megdnn {
namespace arm_common {
namespace conv_bias {
#define KERN(stride, i, layout) \
template <BiasMode bias_mode, typename Op> \
void conv_direct_##stride##_##i##x##i##_fp32_##layout( \
const float* src, const float* filter, const float* bias, \
float* temp, float* dst, const int oc, const int ic, const int ih, \
const int iw, const int oh, const int oh_block, const int ow, \
const Op& op, const int ph, const int pw);
KERN(stride2, 2, nchw44)
KERN(stride2, 3, nchw44)
KERN(stride2, 5, nchw44)
KERN(stride2, 7, nchw44)
#undef KERN
void pack_src_fp32_nchw44_stride2(float* sptr_base, const float* sptr_origin,
const int ph, const int pw,
const int pad_right, const int ih,
const int iw, const int iw2,
const int pad_top, const int pad_bottom,
const int ic, const int ic_stride);
} // namespace conv_bias
} // namespace arm_common
} // namespace megdnn
......@@ -6,7 +6,8 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
*/
#ifdef __ARM_FEATURE_DOTPROD
......@@ -17,7 +18,7 @@
#include "src/fallback/conv_bias/common.h"
#include "src/arm_common/conv_bias/int8/direct_dotprod_nchw44.h"
#include "src/arm_common/conv_bias/int8/direct_dotprod_nchw44_kern.h"
namespace megdnn {
namespace arm_common {
namespace direct_dotprod_nchw44 {
......@@ -139,234 +140,9 @@ void copy_packed_src_int8_nchw44<2>(int8_t* dst, const int dst_step,
}
}
/**
 * \brief int8 nchw44 direct conv driver using dot-product kernels.
 *
 * Splits the output [oc, oh_size, ow] into register tiles: oc in steps of
 * OC_BIG_INTERVAL (with mid/small fall-backs for the oc tail on aarch64) and
 * ow in steps of OW_INTERVAL, delegating each tile to KernNeonSdotNCHW44.
 *
 * \param dst      output, nchw44 packed (4 oc per group)
 * \param src      input, nchw44 packed (4 ic per group)
 * \param filter   weights, layout [OC/4, IC/4, FH, FW, 4OC, 4IC]
 * \param bias     per-channel bias, indexed by oc
 * \param oh_size  number of output rows computed by this call
 * \param op       post-process functor applied on store
 *
 * NOTE(review): the aarch64-only tail below was guarded with
 * `#ifdef MEGDNN_AARCH64` while the interval constants above use
 * `#if MEGDNN_AARCH64`; `#ifdef` would also compile the tail when the macro
 * is defined to 0. Normalized to `#if` for consistency.
 */
template <typename dst_type, int stride, BiasMode bias_mode, typename Op,
          int filter_size>
void conv_direct_sdot_int8_nchw44(dst_type* dst, const int oh, const int ow,
                                  const int8_t* src, const int ih, const int iw,
                                  const int8_t* filter, const int32_t* bias,
                                  const int oh_size, const int oc, const int ic,
                                  const Op& op) {
    constexpr int FH = filter_size;
    constexpr int FW = filter_size;
    constexpr int IC_PACK_SIZE = 4;
    constexpr int OC_PACK_SIZE = 4;

#if MEGDNN_AARCH64
    //! aarch64 has enough SIMD registers to tile 12 output channels at once
    constexpr int OC_BIG_INTERVAL = 12;
    constexpr int OC_MID_INTERVAL = 8;
    constexpr int OC_SMA_INTERVAL = 4;
#else
    constexpr int OC_BIG_INTERVAL = 4;
    constexpr int OC_MID_INTERVAL = 4;
    constexpr int OC_SMA_INTERVAL = 4;
#endif

    constexpr int OW_INTERVAL = 8;
    constexpr int SH = stride;

    const int dst_numbers_per_channel = oh * ow;
    const int ow_remain = ow % OW_INTERVAL;
    const int ow_end_idx = ow - ow_remain;
    const int oc_remain =
            oc % OC_BIG_INTERVAL;  //! NCHW44 means oc_remain = 4 or 8
    const int oc_end_idx = oc - oc_remain;
    const int dst_numbers_4channel_packed =
            dst_numbers_per_channel * OC_PACK_SIZE;

    using remain_fun = std::function<void(
            dst_type * dst, const int dst_step, const int8_t* src, const int ih,
            const int iw, const int8_t* filter, const int32_t* bias,
            const int ic, const Op& op)>;

    remain_fun kern_big_oc_remain = nullptr;
    remain_fun kern_mid_oc_remain = nullptr;
    remain_fun kern_sma_oc_remain = nullptr;
    //! pick the ow-tail kernels once, outside the hot loops
    switch (ow_remain) {
#define cb(step)                                                          \
    case step:                                                            \
        kern_big_oc_remain =                                              \
                KernNeonSdotNCHW44<dst_type, stride, bias_mode, Op, step, \
                                   filter_size, OC_BIG_INTERVAL,          \
                                   OW_INTERVAL>::impl;                    \
        kern_mid_oc_remain =                                              \
                KernNeonSdotNCHW44<dst_type, stride, bias_mode, Op, step, \
                                   filter_size, OC_MID_INTERVAL,          \
                                   OW_INTERVAL>::impl;                    \
        kern_sma_oc_remain =                                              \
                KernNeonSdotNCHW44<dst_type, stride, bias_mode, Op, step, \
                                   filter_size, OC_SMA_INTERVAL,          \
                                   OW_INTERVAL>::impl;                    \
        break;
        UNROLL_CALL_RAW(8, cb);
#undef cb
        default:
            megdnn_assert(0, "no remain %d for kern", ow_remain);
    }

    //! filter layout is [OC/4, IC/4, FH, FW, 4OC, 4IC]
    //! cut [oc, oh, ow] into [oc/OC_INTERVAL, 1, ow/OW_INTERVAL, OW_INTERVAL,
    //! oh, OC_INTERVAL] to calculate KernNeonSdotNCHW44 calculates
    //! [OW_INTERVAL, 1, OC_INTERVAL] each time
    for (int oc_idx = 0; oc_idx < oc_end_idx; oc_idx += OC_BIG_INTERVAL) {
        const int filter_offset_in_element = oc_idx * ic * FH * FW;
        for (int oh_idx = 0; oh_idx < oh_size; ++oh_idx) {
            for (int ow_idx = 0; ow_idx < ow_end_idx; ow_idx += OW_INTERVAL) {
                const int src_offset_in_element =
                        (oh_idx * SH * iw + ow_idx) * IC_PACK_SIZE;
                const int dst_offset_in_element =
                        oc_idx * dst_numbers_per_channel +
                        (oh_idx * ow + ow_idx) * OC_PACK_SIZE;
                const int bias_offset_in_element = oc_idx;
                KernNeonSdotNCHW44<dst_type, stride, bias_mode, Op, OW_INTERVAL,
                                   filter_size, OC_BIG_INTERVAL, OW_INTERVAL>::
                        impl(dst + dst_offset_in_element,
                             dst_numbers_4channel_packed,
                             src + src_offset_in_element, ih, iw,
                             filter + filter_offset_in_element,
                             bias + bias_offset_in_element, ic, op);
            }
            if (ow_remain) {
                const int src_offset_in_element =
                        (oh_idx * SH * iw + ow_end_idx) * IC_PACK_SIZE;
                const int dst_offset_in_element =
                        oc_idx * dst_numbers_per_channel +
                        (oh_idx * ow + ow_end_idx) * OC_PACK_SIZE;
                const int bias_offset_in_element = oc_idx;
                kern_big_oc_remain(dst + dst_offset_in_element,
                                   dst_numbers_4channel_packed,
                                   src + src_offset_in_element, ih, iw,
                                   filter + filter_offset_in_element,
                                   bias + bias_offset_in_element, ic, op);
            }
        }
    }

#if MEGDNN_AARCH64
    //! oc_remain must be 4 or 8 on aarch64 and must be 0 on aarch32
    if (oc_remain) {
        int oc_idx = oc_end_idx;
        const int filter_offset_in_element = oc_idx * ic * FH * FW;
        for (int oh_idx = 0; oh_idx < oh_size; ++oh_idx) {
            for (int ow_idx = 0; ow_idx < ow_end_idx; ow_idx += OW_INTERVAL) {
                const int src_offset_in_element =
                        (oh_idx * SH * iw + ow_idx) * IC_PACK_SIZE;
                const int dst_offset_in_element =
                        oc_idx * dst_numbers_per_channel +
                        (oh_idx * ow + ow_idx) * OC_PACK_SIZE;
                const int bias_offset_in_element = oc_idx;
                if (oc_remain == 8) {
                    KernNeonSdotNCHW44<
                            dst_type, stride, bias_mode, Op, OW_INTERVAL,
                            filter_size, OC_MID_INTERVAL,
                            OW_INTERVAL>::impl(dst + dst_offset_in_element,
                                               dst_numbers_4channel_packed,
                                               src + src_offset_in_element, ih,
                                               iw,
                                               filter +
                                                       filter_offset_in_element,
                                               bias + bias_offset_in_element,
                                               ic, op);
                } else {
                    KernNeonSdotNCHW44<
                            dst_type, stride, bias_mode, Op, OW_INTERVAL,
                            filter_size, OC_SMA_INTERVAL,
                            OW_INTERVAL>::impl(dst + dst_offset_in_element,
                                               dst_numbers_4channel_packed,
                                               src + src_offset_in_element, ih,
                                               iw,
                                               filter +
                                                       filter_offset_in_element,
                                               bias + bias_offset_in_element,
                                               ic, op);
                }
            }
            if (ow_remain) {
                const int src_offset_in_element =
                        (oh_idx * SH * iw + ow_end_idx) * IC_PACK_SIZE;
                const int dst_offset_in_element =
                        oc_idx * dst_numbers_per_channel +
                        (oh_idx * ow + ow_end_idx) * OC_PACK_SIZE;
                const int bias_offset_in_element = oc_idx;
                if (oc_remain == 8) {
                    kern_mid_oc_remain(dst + dst_offset_in_element,
                                       dst_numbers_4channel_packed,
                                       src + src_offset_in_element, ih, iw,
                                       filter + filter_offset_in_element,
                                       bias + bias_offset_in_element, ic, op);
                } else {
                    kern_sma_oc_remain(dst + dst_offset_in_element,
                                       dst_numbers_4channel_packed,
                                       src + src_offset_in_element, ih, iw,
                                       filter + filter_offset_in_element,
                                       bias + bias_offset_in_element, ic, op);
                }
            }
        }
    }
#endif
}
//! Generate the fixed-filter-size entry points
//! conv_direct_{2,3,5,7}x{2,3,5,7}_int8_nchw44, each a thin forwarder to
//! conv_direct_sdot_int8_nchw44 with filter_size bound at compile time.
#define CONSTRUCT_FUNC(filter_size)                                           \
    template <typename dst_type, BiasMode bias_mode, typename Op, int stride> \
    void conv_direct_##filter_size##x##filter_size##_int8_nchw44(             \
            dst_type* dst, const int oh, const int ow, const int8_t* src,     \
            const int ih, const int iw, const int8_t* weight,                 \
            const int32_t* bias, const int oh_size, const int oc,             \
            const int ic, const Op& op) {                                     \
        conv_direct_sdot_int8_nchw44<dst_type, stride, bias_mode, Op,         \
                                     filter_size>(                            \
                dst, oh, ow, src, ih, iw, weight, bias, oh_size, oc, ic, op); \
    }
CONSTRUCT_FUNC(2);
CONSTRUCT_FUNC(3);
CONSTRUCT_FUNC(5);
CONSTRUCT_FUNC(7);
#undef CONSTRUCT_FUNC
//! Explicit template instantiations: for every stride {1,2} x filter {2,3,5,7}
//! x bias mode {none, broadcast-channel} x post-op {typecvt, none, relu,
//! hswish}, emit the conv_direct_ixi_int8_nchw44 symbol in this TU.
#define INSTANTIATION(dst_type, stride, i, bias_mode, Op)                     \
    template void conv_direct_##i##x##i##_int8_nchw44<dst_type, bias_mode, Op, \
                                                      stride>(                \
            dst_type * dst, const int oh, const int ow, const int8_t* src,    \
            const int ih, const int iw, const int8_t* weight,                 \
            const int32_t* bias, const int oh_size, const int oc,             \
            const int ic, const Op& op);

#define FOR_OP(stride, i, bias_mode)                          \
    INSTANTIATION(dt_int8, stride, i, bias_mode,              \
                  TypeCvtOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
    INSTANTIATION(dt_int32, stride, i, bias_mode,             \
                  NoneOp<dt_qint32 MEGDNN_COMMA dt_qint8>)    \
    INSTANTIATION(dt_int8, stride, i, bias_mode,              \
                  ReluOp<dt_qint32 MEGDNN_COMMA dt_qint8>)    \
    INSTANTIATION(dt_int8, stride, i, bias_mode,              \
                  HSwishOp<dt_qint32 MEGDNN_COMMA dt_qint8>)

#define FOR_BIAS(stride, i)              \
    FOR_OP(stride, i, BiasMode::NO_BIAS) \
    FOR_OP(stride, i, BiasMode::BROADCAST_CHANNEL_BIAS)

#define FOR_FILTER(stride) \
    FOR_BIAS(stride, 2)    \
    FOR_BIAS(stride, 3)    \
    FOR_BIAS(stride, 5)    \
    FOR_BIAS(stride, 7)

FOR_FILTER(1)
FOR_FILTER(2)

//! NOTE(review): FOR_STRIDE, FOR_IC, FOR_NONLINEAR and FOR_REMAIN are not
//! defined in this section -- these #undefs look like leftovers from an older
//! macro set; #undef of an undefined name is a harmless no-op.
#undef FOR_STRIDE
#undef FOR_FILTER
#undef FOR_IC
#undef FOR_BIAS
#undef FOR_NONLINEAR
#undef FOR_REMAIN
#undef INSTANTIATION
} // namespace direct_dotprod_nchw44
} // namespace arm_common
} // namespace megdnn
#endif
//vim: syntax=cpp.doxygen
// vim: syntax=cpp.doxygen
......@@ -6,7 +6,8 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
*/
#if __ARM_FEATURE_DOTPROD
......@@ -42,20 +43,13 @@ using BiasMode = ConvBiasForward::BiasMode;
* @return none
*/
#define KERN(filter_size) \
template <typename dst_type, BiasMode bias_mode, typename Op, int stride> \
void conv_direct_##filter_size##x##filter_size##_int8_nchw44( \
dst_type* dst, const int oh, const int ow, const int8_t* src, \
const int ih, const int iw, const int8_t* weight, \
const int32_t* bias, const int oh_size, const int oc, \
const int ic, const Op& op)
KERN(2);
KERN(3);
KERN(5);
KERN(7);
#undef KERN
template <typename dst_type, int stride, BiasMode bias_mode, typename Op,
int filter_size>
void conv_direct_sdot_int8_nchw44(dst_type* dst, const int oh, const int ow,
const int8_t* src, const int ih, const int iw,
const int8_t* filter, const int32_t* bias,
const int oh_size, const int oc, const int ic,
const Op& op);
/**
* @brief : copy data from src to dst for direct conv with no side effect
* @param : [output ptr] dst
......@@ -84,4 +78,4 @@ void copy_packed_src_int8_nchw44(int8_t* dst, const int dst_step,
#endif
//vim: syntax=cpp.doxygen
\ No newline at end of file
// vim: syntax=cpp.doxygen
\ No newline at end of file
......@@ -148,14 +148,10 @@ static void conv_kern(const WorkspaceBundle& bundle,
float scale_dst = ncb_param.dst_type.param<dtype::QuantizedS8>().scale;
op = Op(scale_bias, scale_dst);
}
#define KERN1_NCHW44_CONV(filter) \
direct_dotprod_nchw44::conv_direct_##filter##x##filter##_int8_nchw44< \
dst_type, bias_mode, Op, stride>(dst, OH, OW, copy_dst, \
ih_real_size, iw2, weights, bias, \
oh_real_size, OC, IC, op);
DISPATCH_FILTER(filter_size, KERN1_NCHW44_CONV);
#undef KERN1_NCHW44_CONV
direct_dotprod_nchw44::conv_direct_sdot_int8_nchw44<
dst_type, stride, bias_mode, Op, filter_size>(
dst, OH, OW, copy_dst, ih_real_size, iw2, weights, bias,
oh_real_size, OC, IC, op);
}
} // namespace
......@@ -342,4 +338,4 @@ ConvBiasImpl::AlgoDotS8Direct_NCHW44::dispatch_kerns(
#endif
//vim: syntax=cpp.doxygen
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44_kern.h
* \file
* dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_common.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
......@@ -9,8 +11,7 @@
* implied.
*/
#pragma once
#ifdef __ARM_FEATURE_DOTPROD
#if __ARM_FEATURE_DOTPROD
#include "megdnn/arch.h"
#include "src/arm_common/conv_bias/intrinsic_helper.h"
#include "src/arm_common/elemwise_op.h"
......@@ -98,7 +99,6 @@ struct StoreOCxOWx {
template <int ow_remain, typename Op, typename T>
struct StoreOCxOWx<1, ow_remain, Op, T> {
static void impl(int32x4_t res[][8], const Op& op, T* dst_ptr,
const int ld_dst_oc) {
MEGDNN_MARK_USED_VAR(ld_dst_oc);
......@@ -206,22 +206,23 @@ MEGDNN_ALWAYS_INLINE void store_ocx_owx_remain_static(int32x4_t res[][8],
}
template <int res_row, int src_row, int src_start_idx, int weight_idx,
typename FUNC, typename T, typename T2, typename T3>
typename T, typename T2, typename T3>
struct ShiftCalHelper {
static MEGDNN_ALWAYS_INLINE void impl(T& res, T2& src, T3& weight) {
#define cb(step) \
res[res_row][step] = FUNC::template impl<((src_start_idx + step) % 4)>( \
res[res_row][step], weight[weight_idx], \
src[src_row][(src_start_idx + step) / 4]);
#define cb(step) \
res[res_row][step] = \
vdotq_laneq_s32(res[res_row][step], weight[weight_idx], \
src[src_row][(src_start_idx + step) / 4], \
(src_start_idx + step) % 4);
UNROLL_CALL_RAW(8, cb);
#undef cb
}
};
template <int res_row, int src_row, int src_start_idx, int weight_idx,
typename FUNC, typename T, typename T2, typename T3>
typename T, typename T2, typename T3>
MEGDNN_ALWAYS_INLINE void cal_helper(T& res, T2& src, T3& weight) {
ShiftCalHelper<res_row, src_row, src_start_idx, weight_idx, FUNC, T, T2,
ShiftCalHelper<res_row, src_row, src_start_idx, weight_idx, T, T2,
T3>::impl(res, src, weight);
};
......@@ -237,199 +238,8 @@ struct KernNeonSdotNCHW44 {
const int32_t* bias, const int ic, const Op& op);
};
//! stride == 1 specialization: computes an ow_interval-wide strip of output
//! for oc_interval output channels using int8 dot-product accumulation.
//! Iterates ic in packs of 4 and filter rows one at a time; LOOP
//! (= oc_interval / 4) selects how many 4-oc groups are updated per step.
//! NOTE(review): filter_next_col, OC_PACK_SIZE, IC_PACK_SIZE and SIMD_LEN are
//! file-scope constants defined outside this excerpt -- presumably
//! filter_next_col == OC_PACK_SIZE * IC_PACK_SIZE given the weight layout;
//! confirm against the enclosing file.
template <typename dst_type, BiasMode bias_mode, typename Op, int ow_remain,
          int filter_size, int oc_interval, int ow_interval>
struct KernNeonSdotNCHW44<dst_type, 1, bias_mode, Op, ow_remain, filter_size,
                          oc_interval, ow_interval> {
    static void impl(dst_type* dst, const int dst_step, const int8_t* src,
                     const int ih, const int iw, const int8_t* filter,
                     const int32_t* bias, const int ic, const Op& op) {
        constexpr int FH = filter_size;
        constexpr int FW = filter_size;
        constexpr int filter_next_row =
                FW * OC_PACK_SIZE *
                IC_PACK_SIZE;  //! [OC/4, IC/4, FH, FW, 4OC, 4IC]
        const int filter_next_4oc =
                FH * FW * ic * OC_PACK_SIZE;  //! [OC/4, IC/4, FH, FW, 4OC, 4IC]
        const int src_next_ic = ih * iw;
        const int src_next_row = iw * IC_PACK_SIZE;
        //! number of 16-byte source vectors covering one window row
        constexpr int NSRC = (ow_interval + filter_size - 1) / 4 + 1;
        constexpr int LOOP = oc_interval / 4;

        int32x4_t res[3][ow_interval];
        init_ocx_ow8<LOOP, bias_mode>(res, bias, OC_PACK_SIZE);

        for (int ic_idx = 0; ic_idx < ic; ic_idx += IC_PACK_SIZE) {
            const int8_t* i_src = src + ic_idx * src_next_ic;
            const int8_t* i_filter = filter + ic_idx * FH * FW * OC_PACK_SIZE;
            for (int fh_idx = 0; fh_idx < FH; ++fh_idx) {
                int8x16_t src[1][4];
                int8x16_t weight[3];
                load_helper<NSRC, 0, SIMD_LEN, 1, Vld1q_s8>(src, i_src, 0);

//! CALC_PART(step): for filter column `step`, load one weight vector per 4-oc
//! group (up to LOOP groups) and dot-accumulate it into the matching res row.
//! do not use switch order 3,2,1 because it will slow the speed.
#define CALC_PART(step)                                                      \
    switch (LOOP) {                                                          \
        case 1:                                                              \
            weight[0] = vld1q_s8(i_filter + filter_next_4oc * 0 +            \
                                 filter_next_col * step);                    \
            cal_helper<0, 0, step, 0, Vdotq_laneq_s32>(res, src, weight);    \
            break;                                                           \
        case 2:                                                              \
            weight[0] = vld1q_s8(i_filter + filter_next_4oc * 0 +            \
                                 filter_next_col * step);                    \
            cal_helper<0, 0, step, 0, Vdotq_laneq_s32>(res, src, weight);    \
            weight[1] = vld1q_s8(i_filter + filter_next_4oc * 1 +            \
                                 filter_next_col * step);                    \
            cal_helper<1, 0, step, 1, Vdotq_laneq_s32>(res, src, weight);    \
            break;                                                           \
        case 3:                                                              \
            weight[0] = vld1q_s8(i_filter + filter_next_4oc * 0 +            \
                                 filter_next_col * step);                    \
            cal_helper<0, 0, step, 0, Vdotq_laneq_s32>(res, src, weight);    \
            weight[1] = vld1q_s8(i_filter + filter_next_4oc * 1 +            \
                                 filter_next_col * step);                    \
            cal_helper<1, 0, step, 1, Vdotq_laneq_s32>(res, src, weight);    \
            weight[2] = vld1q_s8(i_filter + filter_next_4oc * 2 +            \
                                 filter_next_col * step);                    \
            cal_helper<2, 0, step, 2, Vdotq_laneq_s32>(res, src, weight);    \
            break;                                                           \
        default:                                                             \
            break;                                                           \
    }

                //! unroll over the filter columns of this row
                switch (filter_size) {
                    case 2:
                        UNROLL_CALL_RAW(2, CALC_PART);
                        break;
                    case 3:
                        UNROLL_CALL_RAW(3, CALC_PART);
                        break;
                    case 5:
                        UNROLL_CALL_RAW(5, CALC_PART);
                        break;
                    case 7:
                        UNROLL_CALL_RAW(7, CALC_PART);
                        break;
                    default:
                        break;
                }
#undef CALC_PART

                i_filter += filter_next_row;
                i_src += src_next_row;
            }
        }
        //! apply post-op and store, honoring the ow tail (ow_remain)
        store_ocx_owx_remain_static<LOOP, ow_remain, Op>(res, op, dst,
                                                         dst_step);
    }
};
//! Stride-2 dot-product kernel. The packed source stores even output columns
//! in one half and odd columns `offset` bytes later, so the kernel indexes
//! src[step % 2][...] per filter column. Accumulates one
//! [oc_interval x ow_interval] output tile.
template <typename dst_type, BiasMode bias_mode, typename Op, int ow_remain,
          int filter_size, int oc_interval, int ow_interval>
struct KernNeonSdotNCHW44<dst_type, 2, bias_mode, Op, ow_remain, filter_size,
                          oc_interval, ow_interval> {
    static void impl(dst_type* dst, const int dst_step, const int8_t* src,
                     const int ih, const int iw, const int8_t* filter,
                     const int32_t* bias, const int ic, const Op& op) {
        constexpr int FH = filter_size;
        constexpr int FW = filter_size;
        constexpr int filter_next_row =
                FW * OC_PACK_SIZE *
                IC_PACK_SIZE;  //! [OC/4, IC/4, FH, FW, 4OC, 4IC]
        const int filter_next_4oc =
                FH * FW * ic * OC_PACK_SIZE;  //! [OC/4, IC/4, FH, FW, 4OC, 4IC]
        const int src_next_ic = ih * iw;
        const int src_next_row = iw * IC_PACK_SIZE;
        //! vectors needed per (even, odd) half to cover the tile
        constexpr int NSRC = (ow_interval * 2 + filter_size - 3) / 8 + 1;
        //! number of 4-OC register groups handled at once (1, 2 or 3)
        constexpr int LOOP = oc_interval / 4;
        int32x4_t res[3][ow_interval];
        init_ocx_ow8<LOOP, bias_mode>(res, bias, OC_PACK_SIZE);
        for (int ic_idx = 0; ic_idx < ic; ic_idx += IC_PACK_SIZE) {
            const int8_t* i_src = src + ic_idx * src_next_ic;
            const int8_t* i_filter = filter + ic_idx * FH * FW * OC_PACK_SIZE;
            for (int fh_idx = 0; fh_idx < FH; ++fh_idx) {
                int8x16_t src[2][3];
                int8x16_t weight[3];
                //! distance between the even-column and odd-column halves
                const int offset = megdnn::div_ceil(iw, 2) * IC_PACK_SIZE;
                load_helper<NSRC, 0, SIMD_LEN, 2, Vld1q_s8>(src, i_src, offset);
//! do not use switch order 3,2,1 because it will slow the speed.
//! CALC_PART(step): filter column `step` reads the even half when step is
//! even and the odd half when step is odd (src row = step % 2).
#define CALC_PART(step)                                                 \
    switch (LOOP) {                                                     \
        case 1:                                                         \
            weight[0] = vld1q_s8(i_filter + filter_next_4oc * 0 +       \
                                 filter_next_col * step);               \
            cal_helper<0, step % 2, step / 2, 0, Vdotq_laneq_s32>(res, src, \
                                                                  weight);  \
            break;                                                      \
        case 2:                                                         \
            weight[0] = vld1q_s8(i_filter + filter_next_4oc * 0 +       \
                                 filter_next_col * step);               \
            cal_helper<0, step % 2, step / 2, 0, Vdotq_laneq_s32>(res, src, \
                                                                  weight);  \
            weight[1] = vld1q_s8(i_filter + filter_next_4oc * 1 +       \
                                 filter_next_col * step);               \
            cal_helper<1, step % 2, step / 2, 1, Vdotq_laneq_s32>(res, src, \
                                                                  weight);  \
            break;                                                      \
        case 3:                                                         \
            weight[0] = vld1q_s8(i_filter + filter_next_4oc * 0 +       \
                                 filter_next_col * step);               \
            cal_helper<0, step % 2, step / 2, 0, Vdotq_laneq_s32>(res, src, \
                                                                  weight);  \
            weight[1] = vld1q_s8(i_filter + filter_next_4oc * 1 +       \
                                 filter_next_col * step);               \
            cal_helper<1, step % 2, step / 2, 1, Vdotq_laneq_s32>(res, src, \
                                                                  weight);  \
            weight[2] = vld1q_s8(i_filter + filter_next_4oc * 2 +       \
                                 filter_next_col * step);               \
            cal_helper<2, step % 2, step / 2, 2, Vdotq_laneq_s32>(res, src, \
                                                                  weight);  \
            break;                                                      \
        default:                                                        \
            break;                                                      \
    }
                //! filter width is a compile-time constant: fully unroll
                switch (filter_size) {
                    case 2:
                        UNROLL_CALL_RAW(2, CALC_PART);
                        break;
                    case 3:
                        UNROLL_CALL_RAW(3, CALC_PART);
                        break;
                    case 5:
                        UNROLL_CALL_RAW(5, CALC_PART);
                        break;
                    case 7:
                        UNROLL_CALL_RAW(7, CALC_PART);
                        break;
                    default:
                        break;
                }
#undef CALC_PART
                i_filter += filter_next_row;
                i_src += src_next_row;
            }
        }
        //! apply activation `op` and store, honoring the ow_remain tail
        store_ocx_owx_remain_static<LOOP, ow_remain, Op>(res, op, dst,
                                                         dst_step);
    }
};
} // namespace direct_dotprod_nchw44
} // namespace arm_common
} // namespace megdnn
#endif
// vim: syntax=cpp.doxygen
// vim: syntax=cpp.doxygen
\ No newline at end of file
/**
* \file
* dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s1.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#if __ARM_FEATURE_DOTPROD
#include "src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_common.h"
namespace megdnn {
namespace arm_common {
namespace direct_dotprod_nchw44 {
//! Stride-1 dot-product kernel (split-TU variant): accumulates one
//! [oc_interval x ow_interval] output tile for a single output row.
//! \p dst_step is the element distance between two packed 4-OC channel planes.
template <typename dst_type, BiasMode bias_mode, typename Op, int ow_remain,
          int filter_size, int oc_interval, int ow_interval>
struct KernNeonSdotNCHW44<dst_type, 1, bias_mode, Op, ow_remain, filter_size,
                          oc_interval, ow_interval> {
    static void impl(dst_type* dst, const int dst_step, const int8_t* src,
                     const int ih, const int iw, const int8_t* filter,
                     const int32_t* bias, const int ic, const Op& op) {
        constexpr int FH = filter_size;
        constexpr int FW = filter_size;
        constexpr int filter_next_row =
                FW * OC_PACK_SIZE *
                IC_PACK_SIZE;  //! [OC/4, IC/4, FH, FW, 4OC, 4IC]
        const int filter_next_4oc =
                FH * FW * ic * OC_PACK_SIZE;  //! [OC/4, IC/4, FH, FW, 4OC, 4IC]
        //! ic_idx below advances by IC_PACK_SIZE, so ic_idx * src_next_ic
        //! lands on the matching 4-IC slab of the packed source
        const int src_next_ic = ih * iw;
        const int src_next_row = iw * IC_PACK_SIZE;
        //! number of 16-byte source vectors needed to cover the tile
        constexpr int NSRC = (ow_interval + filter_size - 1) / 4 + 1;
        //! number of 4-OC register groups handled at once (1, 2 or 3)
        constexpr int LOOP = oc_interval / 4;
        int32x4_t res[3][ow_interval];
        init_ocx_ow8<LOOP, bias_mode>(res, bias, OC_PACK_SIZE);
        for (int ic_idx = 0; ic_idx < ic; ic_idx += IC_PACK_SIZE) {
            const int8_t* i_src = src + ic_idx * src_next_ic;
            const int8_t* i_filter = filter + ic_idx * FH * FW * OC_PACK_SIZE;
            for (int fh_idx = 0; fh_idx < FH; ++fh_idx) {
                int8x16_t src[1][4];
                int8x16_t weight[3];
                load_helper<NSRC, 0, SIMD_LEN, 1, Vld1q_s8>(src, i_src, 0);
//! do not use switch order 3,2,1 because it will slow the speed.
//! CALC_PART(step): for filter column `step`, load the weights of every
//! active 4-OC group and fold one sdot contribution into `res`.
#define CALC_PART(step)                                           \
    switch (LOOP) {                                               \
        case 1:                                                   \
            weight[0] = vld1q_s8(i_filter + filter_next_4oc * 0 + \
                                 filter_next_col * step);         \
            cal_helper<0, 0, step, 0>(res, src, weight);          \
            break;                                                \
        case 2:                                                   \
            weight[0] = vld1q_s8(i_filter + filter_next_4oc * 0 + \
                                 filter_next_col * step);         \
            cal_helper<0, 0, step, 0>(res, src, weight);          \
            weight[1] = vld1q_s8(i_filter + filter_next_4oc * 1 + \
                                 filter_next_col * step);         \
            cal_helper<1, 0, step, 1>(res, src, weight);          \
            break;                                                \
        case 3:                                                   \
            weight[0] = vld1q_s8(i_filter + filter_next_4oc * 0 + \
                                 filter_next_col * step);         \
            cal_helper<0, 0, step, 0>(res, src, weight);          \
            weight[1] = vld1q_s8(i_filter + filter_next_4oc * 1 + \
                                 filter_next_col * step);         \
            cal_helper<1, 0, step, 1>(res, src, weight);          \
            weight[2] = vld1q_s8(i_filter + filter_next_4oc * 2 + \
                                 filter_next_col * step);         \
            cal_helper<2, 0, step, 2>(res, src, weight);          \
            break;                                                \
        default:                                                  \
            break;                                                \
    }
                //! filter width is a compile-time constant: fully unroll
                switch (filter_size) {
                    case 2:
                        UNROLL_CALL_RAW(2, CALC_PART);
                        break;
                    case 3:
                        UNROLL_CALL_RAW(3, CALC_PART);
                        break;
                    case 5:
                        UNROLL_CALL_RAW(5, CALC_PART);
                        break;
                    case 7:
                        UNROLL_CALL_RAW(7, CALC_PART);
                        break;
                    default:
                        break;
                }
#undef CALC_PART
                i_filter += filter_next_row;
                i_src += src_next_row;
            }
        }
        //! apply activation `op` and store, honoring the ow_remain tail
        store_ocx_owx_remain_static<LOOP, ow_remain, Op>(res, op, dst,
                                                         dst_step);
    }
};
//! Top-level direct convolution driver for NCHW44 int8 with sdot:
//! splits [oc, oh_size, ow] into OC_*_INTERVAL x OW_INTERVAL tiles and
//! dispatches each tile to the KernNeonSdotNCHW44 register kernel.
//! filter layout is [OC/4, IC/4, FH, FW, 4OC, 4IC].
template <typename dst_type, int stride, BiasMode bias_mode, typename Op,
          int filter_size>
void conv_direct_sdot_int8_nchw44(dst_type* dst, const int oh, const int ow,
                                  const int8_t* src, const int ih, const int iw,
                                  const int8_t* filter, const int32_t* bias,
                                  const int oh_size, const int oc, const int ic,
                                  const Op& op) {
    constexpr int FH = filter_size;
    constexpr int FW = filter_size;
    constexpr int IC_PACK_SIZE = 4;
    constexpr int OC_PACK_SIZE = 4;
//! aarch64 has enough vector registers for a 12-channel accumulator tile
#if MEGDNN_AARCH64
    constexpr int OC_BIG_INTERVAL = 12;
    constexpr int OC_MID_INTERVAL = 8;
    constexpr int OC_SMA_INTERVAL = 4;
#else
    constexpr int OC_BIG_INTERVAL = 4;
    constexpr int OC_MID_INTERVAL = 4;
    constexpr int OC_SMA_INTERVAL = 4;
#endif
    constexpr int OW_INTERVAL = 8;
    constexpr int SH = stride;
    const int dst_numbers_per_channel = oh * ow;
    const int ow_remain = ow % OW_INTERVAL;
    const int ow_end_idx = ow - ow_remain;
    const int oc_remain =
            oc % OC_BIG_INTERVAL;  //! NCHW44 means oc_remain = 4 or 8
    const int oc_end_idx = oc - oc_remain;
    const int dst_numbers_4channel_packed =
            dst_numbers_per_channel * OC_PACK_SIZE;
    //! tail kernels (for the last partial OW tile) are selected once,
    //! keyed on ow_remain, before entering the main loops
    using remain_fun = std::function<void(
            dst_type * dst, const int dst_step, const int8_t* src, const int ih,
            const int iw, const int8_t* filter, const int32_t* bias,
            const int ic, const Op& op)>;
    remain_fun kern_big_oc_remain = nullptr;
    remain_fun kern_mid_oc_remain = nullptr;
    remain_fun kern_sma_oc_remain = nullptr;
    switch (ow_remain) {
#define cb(step)                                                          \
    case step:                                                            \
        kern_big_oc_remain =                                              \
                KernNeonSdotNCHW44<dst_type, stride, bias_mode, Op, step, \
                                   filter_size, OC_BIG_INTERVAL,          \
                                   OW_INTERVAL>::impl;                    \
        kern_mid_oc_remain =                                              \
                KernNeonSdotNCHW44<dst_type, stride, bias_mode, Op, step, \
                                   filter_size, OC_MID_INTERVAL,          \
                                   OW_INTERVAL>::impl;                    \
        kern_sma_oc_remain =                                              \
                KernNeonSdotNCHW44<dst_type, stride, bias_mode, Op, step, \
                                   filter_size, OC_SMA_INTERVAL,          \
                                   OW_INTERVAL>::impl;                    \
        break;
        UNROLL_CALL_RAW(8, cb);
#undef cb
        default:
            megdnn_assert(0, "no remain %d for kern", ow_remain);
    }
    //! filter layout is [OC/4, IC/4, FH, FW, 4OC, 4IC]
    //! cut [oc, oh, ow] into [oc/OC_INTERVAL, 1, ow/OW_INTERVAL, OW_INTERVAL,
    //! oh, OC_INTERVAL] to calculate KernNeonSdotNCHW44 calculates
    //! [OW_INTERVAL, 1, OC_INTERVAL] each time
    for (int oc_idx = 0; oc_idx < oc_end_idx; oc_idx += OC_BIG_INTERVAL) {
        const int filter_offset_in_element = oc_idx * ic * FH * FW;
        for (int oh_idx = 0; oh_idx < oh_size; ++oh_idx) {
            for (int ow_idx = 0; ow_idx < ow_end_idx; ow_idx += OW_INTERVAL) {
                const int src_offset_in_element =
                        (oh_idx * SH * iw + ow_idx) * IC_PACK_SIZE;
                const int dst_offset_in_element =
                        oc_idx * dst_numbers_per_channel +
                        (oh_idx * ow + ow_idx) * OC_PACK_SIZE;
                const int bias_offset_in_element = oc_idx;
                KernNeonSdotNCHW44<dst_type, stride, bias_mode, Op, OW_INTERVAL,
                                   filter_size, OC_BIG_INTERVAL, OW_INTERVAL>::
                        impl(dst + dst_offset_in_element,
                             dst_numbers_4channel_packed,
                             src + src_offset_in_element, ih, iw,
                             filter + filter_offset_in_element,
                             bias + bias_offset_in_element, ic, op);
            }
            if (ow_remain) {
                const int src_offset_in_element =
                        (oh_idx * SH * iw + ow_end_idx) * IC_PACK_SIZE;
                const int dst_offset_in_element =
                        oc_idx * dst_numbers_per_channel +
                        (oh_idx * ow + ow_end_idx) * OC_PACK_SIZE;
                const int bias_offset_in_element = oc_idx;
                kern_big_oc_remain(dst + dst_offset_in_element,
                                   dst_numbers_4channel_packed,
                                   src + src_offset_in_element, ih, iw,
                                   filter + filter_offset_in_element,
                                   bias + bias_offset_in_element, ic, op);
            }
        }
    }
//! NOTE(review): this guard uses #ifdef while the interval block above uses
//! #if -- equivalent only while MEGDNN_AARCH64 is either undefined or defined
//! to a nonzero value; consider unifying the two forms.
#ifdef MEGDNN_AARCH64
    //! oc_remain must be 4 or 8 on aarch64 and must be 0 on aarch32
    if (oc_remain) {
        int oc_idx = oc_end_idx;
        const int filter_offset_in_element = oc_idx * ic * FH * FW;
        for (int oh_idx = 0; oh_idx < oh_size; ++oh_idx) {
            for (int ow_idx = 0; ow_idx < ow_end_idx; ow_idx += OW_INTERVAL) {
                const int src_offset_in_element =
                        (oh_idx * SH * iw + ow_idx) * IC_PACK_SIZE;
                const int dst_offset_in_element =
                        oc_idx * dst_numbers_per_channel +
                        (oh_idx * ow + ow_idx) * OC_PACK_SIZE;
                const int bias_offset_in_element = oc_idx;
                if (oc_remain == 8) {
                    KernNeonSdotNCHW44<
                            dst_type, stride, bias_mode, Op, OW_INTERVAL,
                            filter_size, OC_MID_INTERVAL,
                            OW_INTERVAL>::impl(dst + dst_offset_in_element,
                                               dst_numbers_4channel_packed,
                                               src + src_offset_in_element, ih,
                                               iw,
                                               filter +
                                                       filter_offset_in_element,
                                               bias + bias_offset_in_element,
                                               ic, op);
                } else {
                    KernNeonSdotNCHW44<
                            dst_type, stride, bias_mode, Op, OW_INTERVAL,
                            filter_size, OC_SMA_INTERVAL,
                            OW_INTERVAL>::impl(dst + dst_offset_in_element,
                                               dst_numbers_4channel_packed,
                                               src + src_offset_in_element, ih,
                                               iw,
                                               filter +
                                                       filter_offset_in_element,
                                               bias + bias_offset_in_element,
                                               ic, op);
                }
            }
            if (ow_remain) {
                const int src_offset_in_element =
                        (oh_idx * SH * iw + ow_end_idx) * IC_PACK_SIZE;
                const int dst_offset_in_element =
                        oc_idx * dst_numbers_per_channel +
                        (oh_idx * ow + ow_end_idx) * OC_PACK_SIZE;
                const int bias_offset_in_element = oc_idx;
                if (oc_remain == 8) {
                    kern_mid_oc_remain(dst + dst_offset_in_element,
                                       dst_numbers_4channel_packed,
                                       src + src_offset_in_element, ih, iw,
                                       filter + filter_offset_in_element,
                                       bias + bias_offset_in_element, ic, op);
                } else {
                    kern_sma_oc_remain(dst + dst_offset_in_element,
                                       dst_numbers_4channel_packed,
                                       src + src_offset_in_element, ih, iw,
                                       filter + filter_offset_in_element,
                                       bias + bias_offset_in_element, ic, op);
                }
            }
        }
    }
#endif
}
//! Explicit instantiations: emit one conv_direct_sdot_int8_nchw44 symbol per
//! (dst dtype, activation Op, bias mode, filter size) combination for stride
//! 1, so the dispatcher TU can link against them without recompiling kernels.
#define INSTANTIATION(dst_type, stride, filter_size, bias_mode, Op)          \
    template void conv_direct_sdot_int8_nchw44<dst_type, stride, bias_mode,  \
                                               Op, filter_size>(             \
            dst_type * dst, const int oh, const int ow, const int8_t* src,   \
            const int ih, const int iw, const int8_t* weight,                \
            const int32_t* bias, const int oh_size, const int oc,            \
            const int ic, const Op& op);

//! dt_int32 output pairs only with NoneOp (no requantization performed)
#define FOR_OP(stride, i, bias_mode)                          \
    INSTANTIATION(dt_int8, stride, i, bias_mode,              \
                  TypeCvtOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
    INSTANTIATION(dt_int32, stride, i, bias_mode,             \
                  NoneOp<dt_qint32 MEGDNN_COMMA dt_qint8>)    \
    INSTANTIATION(dt_int8, stride, i, bias_mode,              \
                  ReluOp<dt_qint32 MEGDNN_COMMA dt_qint8>)    \
    INSTANTIATION(dt_int8, stride, i, bias_mode,              \
                  HSwishOp<dt_qint32 MEGDNN_COMMA dt_qint8>)

#define FOR_BIAS(stride, i)              \
    FOR_OP(stride, i, BiasMode::NO_BIAS) \
    FOR_OP(stride, i, BiasMode::BROADCAST_CHANNEL_BIAS)

#define FOR_FILTER(stride) \
    FOR_BIAS(stride, 2)    \
    FOR_BIAS(stride, 3)    \
    FOR_BIAS(stride, 5)    \
    FOR_BIAS(stride, 7)

FOR_FILTER(1)

//! tidy up helper macros; several of these names are never defined in this
//! file, and #undef of an undefined macro is a harmless no-op
#undef FOR_STRIDE
#undef FOR_FILTER
#undef FOR_IC
#undef FOR_BIAS
#undef FOR_NONLINEAR
#undef FOR_REMAIN
#undef INSTANTIATION
} // namespace direct_dotprod_nchw44
} // namespace arm_common
} // namespace megdnn
#endif
// vim: syntax=cpp.doxygen
\ No newline at end of file
/**
* \file
* dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s2.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#if __ARM_FEATURE_DOTPROD
#include "src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_common.h"
namespace megdnn {
namespace arm_common {
namespace direct_dotprod_nchw44 {
//! Stride-2 dot-product kernel (split-TU variant). The packed source stores
//! even output columns in one half and odd columns `offset` bytes later, so
//! the kernel indexes src[step % 2][...] per filter column. Accumulates one
//! [oc_interval x ow_interval] output tile.
template <typename dst_type, BiasMode bias_mode, typename Op, int ow_remain,
          int filter_size, int oc_interval, int ow_interval>
struct KernNeonSdotNCHW44<dst_type, 2, bias_mode, Op, ow_remain, filter_size,
                          oc_interval, ow_interval> {
    static void impl(dst_type* dst, const int dst_step, const int8_t* src,
                     const int ih, const int iw, const int8_t* filter,
                     const int32_t* bias, const int ic, const Op& op) {
        constexpr int FH = filter_size;
        constexpr int FW = filter_size;
        constexpr int filter_next_row =
                FW * OC_PACK_SIZE *
                IC_PACK_SIZE;  //! [OC/4, IC/4, FH, FW, 4OC, 4IC]
        const int filter_next_4oc =
                FH * FW * ic * OC_PACK_SIZE;  //! [OC/4, IC/4, FH, FW, 4OC, 4IC]
        const int src_next_ic = ih * iw;
        const int src_next_row = iw * IC_PACK_SIZE;
        //! vectors needed per (even, odd) half to cover the tile
        constexpr int NSRC = (ow_interval * 2 + filter_size - 3) / 8 + 1;
        //! number of 4-OC register groups handled at once (1, 2 or 3)
        constexpr int LOOP = oc_interval / 4;
        int32x4_t res[3][ow_interval];
        init_ocx_ow8<LOOP, bias_mode>(res, bias, OC_PACK_SIZE);
        for (int ic_idx = 0; ic_idx < ic; ic_idx += IC_PACK_SIZE) {
            const int8_t* i_src = src + ic_idx * src_next_ic;
            const int8_t* i_filter = filter + ic_idx * FH * FW * OC_PACK_SIZE;
            for (int fh_idx = 0; fh_idx < FH; ++fh_idx) {
                int8x16_t src[2][3];
                int8x16_t weight[3];
                //! distance between the even-column and odd-column halves
                const int offset = megdnn::div_ceil(iw, 2) * IC_PACK_SIZE;
                load_helper<NSRC, 0, SIMD_LEN, 2, Vld1q_s8>(src, i_src, offset);
//! do not use switch order 3,2,1 because it will slow the speed.
//! CALC_PART(step): filter column `step` reads the even half when step is
//! even and the odd half when step is odd (src row = step % 2).
#define CALC_PART(step)                                           \
    switch (LOOP) {                                               \
        case 1:                                                   \
            weight[0] = vld1q_s8(i_filter + filter_next_4oc * 0 + \
                                 filter_next_col * step);         \
            cal_helper<0, step % 2, step / 2, 0>(res, src, weight); \
            break;                                                \
        case 2:                                                   \
            weight[0] = vld1q_s8(i_filter + filter_next_4oc * 0 + \
                                 filter_next_col * step);         \
            cal_helper<0, step % 2, step / 2, 0>(res, src, weight); \
            weight[1] = vld1q_s8(i_filter + filter_next_4oc * 1 + \
                                 filter_next_col * step);         \
            cal_helper<1, step % 2, step / 2, 1>(res, src, weight); \
            break;                                                \
        case 3:                                                   \
            weight[0] = vld1q_s8(i_filter + filter_next_4oc * 0 + \
                                 filter_next_col * step);         \
            cal_helper<0, step % 2, step / 2, 0>(res, src, weight); \
            weight[1] = vld1q_s8(i_filter + filter_next_4oc * 1 + \
                                 filter_next_col * step);         \
            cal_helper<1, step % 2, step / 2, 1>(res, src, weight); \
            weight[2] = vld1q_s8(i_filter + filter_next_4oc * 2 + \
                                 filter_next_col * step);         \
            cal_helper<2, step % 2, step / 2, 2>(res, src, weight); \
            break;                                                \
        default:                                                  \
            break;                                                \
    }
                //! filter width is a compile-time constant: fully unroll
                switch (filter_size) {
                    case 2:
                        UNROLL_CALL_RAW(2, CALC_PART);
                        break;
                    case 3:
                        UNROLL_CALL_RAW(3, CALC_PART);
                        break;
                    case 5:
                        UNROLL_CALL_RAW(5, CALC_PART);
                        break;
                    case 7:
                        UNROLL_CALL_RAW(7, CALC_PART);
                        break;
                    default:
                        break;
                }
#undef CALC_PART
                i_filter += filter_next_row;
                i_src += src_next_row;
            }
        }
        //! apply activation `op` and store, honoring the ow_remain tail
        store_ocx_owx_remain_static<LOOP, ow_remain, Op>(res, op, dst,
                                                         dst_step);
    }
};
//! Top-level direct convolution driver for NCHW44 int8 with sdot (stride-2
//! TU): splits [oc, oh_size, ow] into OC_*_INTERVAL x OW_INTERVAL tiles and
//! dispatches each tile to the KernNeonSdotNCHW44 register kernel.
//! filter layout is [OC/4, IC/4, FH, FW, 4OC, 4IC].
template <typename dst_type, int stride, BiasMode bias_mode, typename Op,
          int filter_size>
void conv_direct_sdot_int8_nchw44(dst_type* dst, const int oh, const int ow,
                                  const int8_t* src, const int ih, const int iw,
                                  const int8_t* filter, const int32_t* bias,
                                  const int oh_size, const int oc, const int ic,
                                  const Op& op) {
    constexpr int FH = filter_size;
    constexpr int FW = filter_size;
    constexpr int IC_PACK_SIZE = 4;
    constexpr int OC_PACK_SIZE = 4;
//! aarch64 has enough vector registers for a 12-channel accumulator tile
#if MEGDNN_AARCH64
    constexpr int OC_BIG_INTERVAL = 12;
    constexpr int OC_MID_INTERVAL = 8;
    constexpr int OC_SMA_INTERVAL = 4;
#else
    constexpr int OC_BIG_INTERVAL = 4;
    constexpr int OC_MID_INTERVAL = 4;
    constexpr int OC_SMA_INTERVAL = 4;
#endif
    constexpr int OW_INTERVAL = 8;
    constexpr int SH = stride;
    const int dst_numbers_per_channel = oh * ow;
    const int ow_remain = ow % OW_INTERVAL;
    const int ow_end_idx = ow - ow_remain;
    const int oc_remain =
            oc % OC_BIG_INTERVAL;  //! NCHW44 means oc_remain = 4 or 8
    const int oc_end_idx = oc - oc_remain;
    const int dst_numbers_4channel_packed =
            dst_numbers_per_channel * OC_PACK_SIZE;
    //! tail kernels (for the last partial OW tile) are selected once,
    //! keyed on ow_remain, before entering the main loops
    using remain_fun = std::function<void(
            dst_type * dst, const int dst_step, const int8_t* src, const int ih,
            const int iw, const int8_t* filter, const int32_t* bias,
            const int ic, const Op& op)>;
    remain_fun kern_big_oc_remain = nullptr;
    remain_fun kern_mid_oc_remain = nullptr;
    remain_fun kern_sma_oc_remain = nullptr;
    switch (ow_remain) {
#define cb(step)                                                          \
    case step:                                                            \
        kern_big_oc_remain =                                              \
                KernNeonSdotNCHW44<dst_type, stride, bias_mode, Op, step, \
                                   filter_size, OC_BIG_INTERVAL,          \
                                   OW_INTERVAL>::impl;                    \
        kern_mid_oc_remain =                                              \
                KernNeonSdotNCHW44<dst_type, stride, bias_mode, Op, step, \
                                   filter_size, OC_MID_INTERVAL,          \
                                   OW_INTERVAL>::impl;                    \
        kern_sma_oc_remain =                                              \
                KernNeonSdotNCHW44<dst_type, stride, bias_mode, Op, step, \
                                   filter_size, OC_SMA_INTERVAL,          \
                                   OW_INTERVAL>::impl;                    \
        break;
        UNROLL_CALL_RAW(8, cb);
#undef cb
        default:
            megdnn_assert(0, "no remain %d for kern", ow_remain);
    }
    //! filter layout is [OC/4, IC/4, FH, FW, 4OC, 4IC]
    //! cut [oc, oh, ow] into [oc/OC_INTERVAL, 1, ow/OW_INTERVAL, OW_INTERVAL,
    //! oh, OC_INTERVAL] to calculate KernNeonSdotNCHW44 calculates
    //! [OW_INTERVAL, 1, OC_INTERVAL] each time
    for (int oc_idx = 0; oc_idx < oc_end_idx; oc_idx += OC_BIG_INTERVAL) {
        const int filter_offset_in_element = oc_idx * ic * FH * FW;
        for (int oh_idx = 0; oh_idx < oh_size; ++oh_idx) {
            for (int ow_idx = 0; ow_idx < ow_end_idx; ow_idx += OW_INTERVAL) {
                const int src_offset_in_element =
                        (oh_idx * SH * iw + ow_idx) * IC_PACK_SIZE;
                const int dst_offset_in_element =
                        oc_idx * dst_numbers_per_channel +
                        (oh_idx * ow + ow_idx) * OC_PACK_SIZE;
                const int bias_offset_in_element = oc_idx;
                KernNeonSdotNCHW44<dst_type, stride, bias_mode, Op, OW_INTERVAL,
                                   filter_size, OC_BIG_INTERVAL, OW_INTERVAL>::
                        impl(dst + dst_offset_in_element,
                             dst_numbers_4channel_packed,
                             src + src_offset_in_element, ih, iw,
                             filter + filter_offset_in_element,
                             bias + bias_offset_in_element, ic, op);
            }
            if (ow_remain) {
                const int src_offset_in_element =
                        (oh_idx * SH * iw + ow_end_idx) * IC_PACK_SIZE;
                const int dst_offset_in_element =
                        oc_idx * dst_numbers_per_channel +
                        (oh_idx * ow + ow_end_idx) * OC_PACK_SIZE;
                const int bias_offset_in_element = oc_idx;
                kern_big_oc_remain(dst + dst_offset_in_element,
                                   dst_numbers_4channel_packed,
                                   src + src_offset_in_element, ih, iw,
                                   filter + filter_offset_in_element,
                                   bias + bias_offset_in_element, ic, op);
            }
        }
    }
//! NOTE(review): this guard uses #ifdef while the interval block above uses
//! #if -- equivalent only while MEGDNN_AARCH64 is either undefined or defined
//! to a nonzero value; consider unifying the two forms.
#ifdef MEGDNN_AARCH64
    //! oc_remain must be 4 or 8 on aarch64 and must be 0 on aarch32
    if (oc_remain) {
        int oc_idx = oc_end_idx;
        const int filter_offset_in_element = oc_idx * ic * FH * FW;
        for (int oh_idx = 0; oh_idx < oh_size; ++oh_idx) {
            for (int ow_idx = 0; ow_idx < ow_end_idx; ow_idx += OW_INTERVAL) {
                const int src_offset_in_element =
                        (oh_idx * SH * iw + ow_idx) * IC_PACK_SIZE;
                const int dst_offset_in_element =
                        oc_idx * dst_numbers_per_channel +
                        (oh_idx * ow + ow_idx) * OC_PACK_SIZE;
                const int bias_offset_in_element = oc_idx;
                if (oc_remain == 8) {
                    KernNeonSdotNCHW44<
                            dst_type, stride, bias_mode, Op, OW_INTERVAL,
                            filter_size, OC_MID_INTERVAL,
                            OW_INTERVAL>::impl(dst + dst_offset_in_element,
                                               dst_numbers_4channel_packed,
                                               src + src_offset_in_element, ih,
                                               iw,
                                               filter +
                                                       filter_offset_in_element,
                                               bias + bias_offset_in_element,
                                               ic, op);
                } else {
                    KernNeonSdotNCHW44<
                            dst_type, stride, bias_mode, Op, OW_INTERVAL,
                            filter_size, OC_SMA_INTERVAL,
                            OW_INTERVAL>::impl(dst + dst_offset_in_element,
                                               dst_numbers_4channel_packed,
                                               src + src_offset_in_element, ih,
                                               iw,
                                               filter +
                                                       filter_offset_in_element,
                                               bias + bias_offset_in_element,
                                               ic, op);
                }
            }
            if (ow_remain) {
                const int src_offset_in_element =
                        (oh_idx * SH * iw + ow_end_idx) * IC_PACK_SIZE;
                const int dst_offset_in_element =
                        oc_idx * dst_numbers_per_channel +
                        (oh_idx * ow + ow_end_idx) * OC_PACK_SIZE;
                const int bias_offset_in_element = oc_idx;
                if (oc_remain == 8) {
                    kern_mid_oc_remain(dst + dst_offset_in_element,
                                       dst_numbers_4channel_packed,
                                       src + src_offset_in_element, ih, iw,
                                       filter + filter_offset_in_element,
                                       bias + bias_offset_in_element, ic, op);
                } else {
                    kern_sma_oc_remain(dst + dst_offset_in_element,
                                       dst_numbers_4channel_packed,
                                       src + src_offset_in_element, ih, iw,
                                       filter + filter_offset_in_element,
                                       bias + bias_offset_in_element, ic, op);
                }
            }
        }
    }
#endif
}
//! Explicit instantiations: emit one conv_direct_sdot_int8_nchw44 symbol per
//! (dst dtype, activation Op, bias mode, filter size) combination for stride
//! 2, so the dispatcher TU can link against them without recompiling kernels.
#define INSTANTIATION(dst_type, stride, filter_size, bias_mode, Op)          \
    template void conv_direct_sdot_int8_nchw44<dst_type, stride, bias_mode,  \
                                               Op, filter_size>(             \
            dst_type * dst, const int oh, const int ow, const int8_t* src,   \
            const int ih, const int iw, const int8_t* weight,                \
            const int32_t* bias, const int oh_size, const int oc,            \
            const int ic, const Op& op);

//! dt_int32 output pairs only with NoneOp (no requantization performed)
#define FOR_OP(stride, i, bias_mode)                          \
    INSTANTIATION(dt_int8, stride, i, bias_mode,              \
                  TypeCvtOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
    INSTANTIATION(dt_int32, stride, i, bias_mode,             \
                  NoneOp<dt_qint32 MEGDNN_COMMA dt_qint8>)    \
    INSTANTIATION(dt_int8, stride, i, bias_mode,              \
                  ReluOp<dt_qint32 MEGDNN_COMMA dt_qint8>)    \
    INSTANTIATION(dt_int8, stride, i, bias_mode,              \
                  HSwishOp<dt_qint32 MEGDNN_COMMA dt_qint8>)

#define FOR_BIAS(stride, i)              \
    FOR_OP(stride, i, BiasMode::NO_BIAS) \
    FOR_OP(stride, i, BiasMode::BROADCAST_CHANNEL_BIAS)

#define FOR_FILTER(stride) \
    FOR_BIAS(stride, 2)    \
    FOR_BIAS(stride, 3)    \
    FOR_BIAS(stride, 5)    \
    FOR_BIAS(stride, 7)

FOR_FILTER(2)

//! tidy up helper macros; several of these names are never defined in this
//! file, and #undef of an undefined macro is a harmless no-op
#undef FOR_STRIDE
#undef FOR_FILTER
#undef FOR_IC
#undef FOR_BIAS
#undef FOR_NONLINEAR
#undef FOR_REMAIN
#undef INSTANTIATION
} // namespace direct_dotprod_nchw44
} // namespace arm_common
} // namespace megdnn
#endif
// vim: syntax=cpp.doxygen
\ No newline at end of file
/**
* \file
* dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/int8/direct_nchw_nchw44_kern.h"
namespace megdnn {
namespace arm_common {
namespace {
//! Primary template for the int8 NCHW -> NCHW44 direct-conv inner kernel.
//! Declared (not defined) here; the per-(filter_size, oc_block, stride)
//! specializations live in the split per-stride kernel source files.
template <BiasMode bias_mode, typename Op, int remain_w, int filter_size,
          int oc_block, int stride>
struct KerNeonXXs2NchwNchw44 {
    static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
                     const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
                     int iw, int ld_dst_oc, const Op& op);
};
//! Maps an output-channel block size (in channels) to the number of packed
//! 4-OC register groups a kernel iterates over: 4 -> 1, 8 -> 2. The primary
//! template yields 0, which marks an unsupported block size.
//! `static constexpr` (instead of `static const`) matches the rest of the
//! file's compile-time constants and is implicitly inline in C++17.
template <int oc>
struct OCHelper {
public:
    static constexpr int val = 0;
};

template <>
struct OCHelper<4> {
public:
    static constexpr int val = 1;
};

template <>
struct OCHelper<8> {
public:
    static constexpr int val = 2;
};
} // namespace
} // namespace arm_common
} // namespace megdnn
// vim: syntax=cpp.doxygen
\ No newline at end of file
......@@ -114,7 +114,7 @@ static void copy_padding_kern(const WorkspaceBundle& bundle,
rep(ih_idx, IH) {
std::memset(sptr_base, 0, nr_pad_w * sizeof(int8_t));
sptr_base += nr_pad_w;
nchw44_pack_src(sptr, sptr_base, IW);
int8_direct_nchw44::nchw44_pack_src(sptr, sptr_base, IW);
sptr_base += IW * pack_ic * expend_element;
sptr += IW * pack_ic;
std::memset(sptr_base, 0, row_last_pad * sizeof(int8_t));
......@@ -125,8 +125,8 @@ static void copy_padding_kern(const WorkspaceBundle& bundle,
}
}
template <size_t filter, BiasMode bias_mode, typename Op, int ow_remain,
typename DstType, int stride>
template <size_t filter, BiasMode bias_mode, typename Op, typename DstType,
int stride>
static void do_conv_kern(const WorkspaceBundle& bundle,
const ConvBiasImpl::NCBKernParam& kern_param,
const ConvBiasImpl::NCBKernIndex& ncb_index,
......@@ -182,8 +182,10 @@ static void do_conv_kern(const WorkspaceBundle& bundle,
kern_param.bias<dt_int32>(batch_id, group_id) + oc_idx;
auto packed_weight = reinterpret_cast<int8_t*>(bundle.get(1)) +
group_id * OC * IC * FH * FW + oc_idx * IC * FH * FW;
nchw44_pack_filter(fptr, packed_weight, oc_block / 4 * IC / 4 * FH * FW);
conv_direct_int8_nchw44<bias_mode, Op, ow_remain, filter, DstType, stride>(
int8_direct_nchw44::nchw44_pack_filter(fptr, packed_weight,
oc_block / 4 * IC / 4 * FH * FW);
int8_direct_nchw44::conv_direct_int8_nchw44<bias_mode, Op, filter, DstType,
stride>(
sptr, packed_weight, bptr, nullptr, static_cast<DstType*>(dst),
oc_block, IC, IH2, IW2, OH, OW, op);
}
......@@ -233,40 +235,38 @@ ConvBiasImpl::AlgoS8DirectNCHW44::dispatch_kerns(
size_t N = param.n;
size_t IC = fm.icpg;
size_t OC = fm.ocpg;
size_t OW = param.osz[1];
size_t group = fm.group;
size_t fh = fm.spatial[0];
size_t fw = fm.spatial[1];
WorkspaceBundle wbundle = get_bundle(param);
conv_fun do_conv_fun = nullptr;
int ow_remain = OW % 8;
bool need_post_process = param.dst_type.enumv() == DTypeEnum::QuantizedS8;
// NOTE: remain_w is not used to gen hash of midout for compatible with changing
// shape runtime
#define DO_CONV_KERN_FUN(stride, dst_type, filter, bias_mode, remain_w, op) \
#define DO_CONV_KERN_FUN(stride, dst_type, filter, bias_mode, op) \
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8_nchw44, \
midout_iv(#stride #dst_type #filter #bias_mode #op##_hash)) { \
do_conv_fun = do_conv_kern<filter, bias_mode, op, remain_w, dst_type, \
stride>; \
do_conv_fun = do_conv_kern<filter, bias_mode, op, dst_type, stride>; \
} \
MIDOUT_END();
#define GET_OP_PARAM(stride, filter, bias_mode, remain_w) \
#define GET_OP_PARAM(stride, filter, bias_mode) \
if (need_post_process) { \
switch (param.nonlineMode) { \
case param::ConvBias::NonlineMode::IDENTITY: \
DO_CONV_KERN_FUN(stride, dt_qint8, filter, bias_mode, \
remain_w, \
\
TypeCvtOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
break; \
case param::ConvBias::NonlineMode::RELU: \
DO_CONV_KERN_FUN(stride, dt_qint8, filter, bias_mode, \
remain_w, \
\
ReluOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
break; \
case param::ConvBias::NonlineMode::H_SWISH: \
DO_CONV_KERN_FUN(stride, dt_qint8, filter, bias_mode, \
remain_w, \
\
HSwishOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
break; \
default: \
......@@ -277,7 +277,7 @@ ConvBiasImpl::AlgoS8DirectNCHW44::dispatch_kerns(
switch (param.nonlineMode) { \
case param::ConvBias::NonlineMode::IDENTITY: \
DO_CONV_KERN_FUN(stride, dt_int32, filter, bias_mode, \
remain_w, NoneOp<dt_int32>) \
NoneOp<dt_int32>) \
break; \
default: \
megdnn_assert( \
......@@ -287,48 +287,17 @@ ConvBiasImpl::AlgoS8DirectNCHW44::dispatch_kerns(
} \
}
#define GET_REMAIN_W_PARAM(stride, filter, bias_mode) \
switch (ow_remain) { \
case 0: \
GET_OP_PARAM(stride, filter, bias_mode, 0); \
break; \
case 1: \
GET_OP_PARAM(stride, filter, bias_mode, 1); \
break; \
case 2: \
GET_OP_PARAM(stride, filter, bias_mode, 2); \
break; \
case 3: \
GET_OP_PARAM(stride, filter, bias_mode, 3); \
break; \
case 4: \
GET_OP_PARAM(stride, filter, bias_mode, 4); \
break; \
case 5: \
GET_OP_PARAM(stride, filter, bias_mode, 5); \
break; \
case 6: \
GET_OP_PARAM(stride, filter, bias_mode, 6); \
break; \
case 7: \
GET_OP_PARAM(stride, filter, bias_mode, 7); \
break; \
default: \
megdnn_assert(0); \
}
#define GET_BIAS_MODE_PARAM(stride, filter) \
switch (param.bias_mode) { \
case BiasMode::NO_BIAS: \
GET_REMAIN_W_PARAM(stride, filter, BiasMode::NO_BIAS) \
break; \
case BiasMode::BROADCAST_CHANNEL_BIAS: \
GET_REMAIN_W_PARAM(stride, filter, \
BiasMode::BROADCAST_CHANNEL_BIAS) \
break; \
default: \
megdnn_assert(0); \
break; \
#define GET_BIAS_MODE_PARAM(stride, filter) \
switch (param.bias_mode) { \
case BiasMode::NO_BIAS: \
GET_OP_PARAM(stride, filter, BiasMode::NO_BIAS) \
break; \
case BiasMode::BROADCAST_CHANNEL_BIAS: \
GET_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS) \
break; \
default: \
megdnn_assert(0); \
break; \
}
#define DISPATCH_CONV_KERN(stride) \
......
......@@ -117,11 +117,11 @@ static void copy_padding_kern(const WorkspaceBundle& bundle,
const size_t tmp_size = get_temp_bytes(iw, pw);
int8_t* tmp_ptr = reinterpret_cast<int8_t*>(bundle.get(2)) +
ncb_index.thread_id * tmp_size;
pack_nchw_src_for_nchw44_conv<1>(sptr, sptr_base, 1, ph, ph, pw, pw, ih,
iw, iw2, pw, tmp_ptr);
int8_direct_nchw_nchw44::pack_nchw_src_for_nchw44_conv<1>(
sptr, sptr_base, 1, ph, ph, pw, pw, ih, iw, iw2, pw, tmp_ptr);
} else {
pack_nchw_src_for_nchw44_conv<2>(sptr, sptr_base, 1, ph, ph, pw, pw, ih,
iw, iw2, pw, nullptr);
int8_direct_nchw_nchw44::pack_nchw_src_for_nchw44_conv<2>(
sptr, sptr_base, 1, ph, ph, pw, pw, ih, iw, iw2, pw, nullptr);
}
}
static void pack_weight(const WorkspaceBundle& bundle,
......@@ -142,11 +142,11 @@ static void pack_weight(const WorkspaceBundle& bundle,
group_id * oc * ic * fh * fw2 + oc_idx * ic * fh * fw2;
if (stride_h == 1) {
pack_nchw44_weight_for_nchw_conv<1>(fptr, packed_weight, ic, fh, fw,
oc_block);
int8_direct_nchw_nchw44::pack_nchw44_weight_for_nchw_conv<1>(
fptr, packed_weight, ic, fh, fw, oc_block);
} else {
pack_nchw44_weight_for_nchw_conv<2>(fptr, packed_weight, ic, fh, fw,
oc_block);
int8_direct_nchw_nchw44::pack_nchw44_weight_for_nchw_conv<2>(
fptr, packed_weight, ic, fh, fw, oc_block);
}
}
template <size_t filter, BiasMode bias_mode, typename Op, int stride>
......@@ -208,7 +208,8 @@ static void do_conv_kern(const WorkspaceBundle& bundle,
int8_t* packed_weight = reinterpret_cast<int8_t*>(bundle.get(1)) +
group_id * oc * ic * fh * fw2 +
oc_idx * ic * fh * fw2;
conv_direct_int8_nchw_nchw44<bias_mode, Op, filter, stride>(
int8_direct_nchw_nchw44::conv_direct_int8_nchw_nchw44<bias_mode, Op, filter,
stride>(
sptr, packed_weight, bptr, nullptr, dst, oc_block, ic, ih2, iw2, oh,
ow, op);
}
......
......@@ -93,8 +93,8 @@ void do_weight_trans(const WorkspaceBundle& bundle,
const int fw2 = round_up(fw, 4);
auto packed_weight = reinterpret_cast<int8_t*>(bundle.get(1));
auto origin_weight = kern_param.filter<dt_int8>();
pack_weight_int8_nchw_nchw44_dot(packed_weight, origin_weight, oc, ic, fh,
fw, fw2);
dot_direct_nchw_nchw44::pack_weight_int8_nchw_nchw44_dot(
packed_weight, origin_weight, oc, ic, fh, fw, fw2);
}
template <size_t filter, BiasMode bias_mode, typename Op, int stride>
......@@ -147,7 +147,7 @@ static void do_conv_kern(const WorkspaceBundle& bundle,
tmp_ptr = reinterpret_cast<int8_t*>(bundle.get(2)) +
ncb_index.thread_id * tmp_size;
}
pack_src_int8_nchw_nchw44_dot<stride>(
dot_direct_nchw_nchw44::pack_src_int8_nchw_nchw44_dot<stride>(
sptr, origin_sptr, ph, pw, remain_right_pad,
ih_real - src_top_pad - src_bottom_pad, iw, iw2, src_top_pad,
src_bottom_pad, ic, ih * iw, tmp_ptr);
......@@ -164,7 +164,8 @@ static void do_conv_kern(const WorkspaceBundle& bundle,
float scale_bias = kern_param.bias_type.param<dtype::QuantizedS32>().scale;
float scale_dst = kern_param.dst_type.param<dtype::QuantizedS8>().scale;
Op op(scale_bias, scale_dst);
conv_direct_int8_nchw_nchw44_dot<bias_mode, Op, filter, stride>(
dot_direct_nchw_nchw44::conv_direct_int8_nchw_nchw44_dot<bias_mode, Op,
filter, stride>(
sptr, fptr, bptr, nullptr, dst, oc_block, ic, ih_real, iw2, oh,
oh_block_real, ow, op);
}
......
......@@ -2344,7 +2344,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_F32) {
#endif
std::vector<conv_bias::TestArg> gemv_args;
for (auto&& arg : args)
if(arg.src.shape[2] == 1 && arg.src.shape[3] == 1) {
if (arg.src.shape[2] == 1 && arg.src.shape[3] == 1) {
gemv_args.emplace_back(arg);
}
check_conv_bias(gemv_args, handle(), "CONV1x1_GEMV");
......@@ -2361,7 +2361,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_MK4_PACK_F32) {
#endif
std::vector<conv_bias::TestArg> gemv_args;
for (auto&& arg : args)
if(arg.src.shape[2] == 1 && arg.src.shape[3] == 1) {
if (arg.src.shape[2] == 1 && arg.src.shape[3] == 1) {
gemv_args.emplace_back(arg);
}
check_conv_bias(gemv_args, handle(), "CONV1x1_GEMV");
......
......@@ -6,7 +6,8 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "test/arm_common/fixture.h"
......@@ -30,8 +31,7 @@ TEST_F(ARM_COMMON, MATRIX_MUL_INT8x8x16) {
TEST_F(ARM_COMMON, MATRIX_MUL_QUINT8) {
matrix_mul::check_matrix_mul(dtype::Quantized8Asymm(1.2f, (uint8_t)127),
dtype::Quantized8Asymm(1.3f, (uint8_t)129),
{},
dtype::Quantized8Asymm(1.3f, (uint8_t)129), {},
handle());
}
......@@ -232,8 +232,7 @@ TEST_F(ARM_COMMON, QINT8x8x32_GEVM) {
Checker<MatrixMul> checker(handle());
using Param = MatrixMul::Param;
checker.set_before_exec_callback(
AlgoChecker<MatrixMul>("ARM_COMMON_GEVM"));
checker.set_before_exec_callback(AlgoChecker<MatrixMul>("ARM_COMMON_GEVM"));
std::unique_ptr<RNG> rng = std::make_unique<UniformIntRNG>(-127, 127);
checker.set_rng(0, rng.get()).set_rng(1, rng.get());
......@@ -251,7 +250,7 @@ TEST_F(ARM_COMMON, QINT8x8x32_GEVM) {
.set_dtype(2, dtype::QuantizedS32(6.25f))
.execs({A, B, {}});
};
// M = 1
for (size_t N : {1, 10, 16, 33, 64})
for (size_t K : {7, 512, 1024})
......@@ -263,8 +262,7 @@ TEST_F(ARM_COMMON, FP32_GEVM) {
Checker<MatrixMul> checker(handle());
using Param = MatrixMul::Param;
checker.set_before_exec_callback(
AlgoChecker<MatrixMul>("ARM_COMMON_GEVM"));
checker.set_before_exec_callback(AlgoChecker<MatrixMul>("ARM_COMMON_GEVM"));
checker.set_epsilon(1e-2);
auto run = [&](size_t M, size_t K, size_t N) {
......@@ -276,7 +274,7 @@ TEST_F(ARM_COMMON, FP32_GEVM) {
B = TensorShape{N, K};
checker.set_param(param).execs({A, B, {}});
};
// M = 1
for (size_t M : {1})
for (size_t K : {1000, 4096, 25088})
......@@ -298,15 +296,15 @@ TEST_F(ARM_COMMON, FP32_GEMV_MK4) {
param.transposeA = false;
param.transposeB = false;
TensorShape A, B;
A = TensorShape{M/4, K/4, 4, 4};
B = TensorShape{K/4, 1, 4};
A = TensorShape{M / 4, K / 4, 4, 4};
B = TensorShape{K / 4, 1, 4};
checker.set_param(param).execs({A, B, {}});
};
// N = 1
for (size_t M : {4, 16, 128, 1024})
for (size_t K : {4, 8, 12, 128, 256, 4096})
run(M, K);
run(M, K);
}
#if MEGDNN_WITH_BENCHMARK
......@@ -343,7 +341,7 @@ TEST_F(ARM_COMMON, BENCHMARK_SGEMV) {
for (size_t M : {4, 64, 1024, 4096})
for (size_t K : {128, 256, 1024, 4096})
run(M, K, 1);
run(M, K, 1);
}
TEST_F(ARM_COMMON, BENCHMARK_SGEMV_FP32) {
......@@ -372,7 +370,7 @@ TEST_F(ARM_COMMON, BENCHMARK_SGEMV_FP32) {
.exec({{2, 1024}, {1024, 512}, {}});
benchmarker.set_display(true);
}
// run gemv
run(12, 48, 1);
run(48, 12, 1);
......@@ -396,14 +394,14 @@ TEST_F(ARM_COMMON, BENCHMARK_SGEMV_MK4) {
Benchmarker<MatrixMul> benchmarker(handle());
benchmarker.set_times(exec_times);
benchmarker.set_dtype(0, dtype::Float32())
.set_dtype(1, dtype::Float32())
.set_param(param);
.set_dtype(1, dtype::Float32())
.set_param(param);
auto run = [&](size_t M, size_t K) {
printf("SGEMV_MK4: (%zu, %zu, %zu)\n", M, K, N);
printf("SGEMV_MK4: (%zu, %zu)\n", M, K);
TensorShape A, B;
A = TensorShape{M/4, K/4, 4, 4};
B = TensorShape{K/4, 1, 4};
A = TensorShape{M / 4, K / 4, 4, 4};
B = TensorShape{K / 4, 1, 4};
auto time = benchmarker.exec({A, B, {}}) / exec_times;
auto computations = 2.f * M * K * 1e-6;
auto perf = computations / time;
......@@ -422,7 +420,7 @@ TEST_F(ARM_COMMON, BENCHMARK_SGEMV_MK4) {
// run gemv mk4
for (size_t M : {4, 64, 1024, 4096})
for (size_t K : {128, 1024, 4096})
run(M, K);
run(M, K);
}
TEST_F(ARM_COMMON, BENCHMARK_SGEMV_FP16) {
......@@ -490,7 +488,7 @@ TEST_F(ARM_COMMON, BENCHMARK_SGEMM) {
//////////////////////// gemv //////////////////////////
for (size_t M : {8, 64, 112, 256}) {
for (size_t K : {8, 64, 112, 256}) {
run (M, 1, K);
run(M, 1, K);
}
}
......@@ -502,10 +500,8 @@ TEST_F(ARM_COMMON, BENCHMARK_SGEMM) {
}
}
}
}
TEST_F(ARM_COMMON, BENCHMARK_MATRIX_MUL_INT8x8x32) {
constexpr size_t RUNS = 50;
param::MatrixMul param;
......@@ -514,7 +510,8 @@ TEST_F(ARM_COMMON, BENCHMARK_MATRIX_MUL_INT8x8x32) {
.set_dtype(0, dtype::Int8{})
.set_dtype(1, dtype::Int8{})
.set_dtype(2, dtype::Int32{})
.set_param(param).set_display(false);
.set_param(param)
.set_display(false);
Benchmarker<MatrixMul> benchmarker_float(handle());
benchmarker_float.set_display(false).set_times(RUNS);
......@@ -533,7 +530,7 @@ TEST_F(ARM_COMMON, BENCHMARK_MATRIX_MUL_INT8x8x32) {
//////////////////////// gemv //////////////////////////
for (size_t M : {8, 64, 112, 256}) {
for (size_t K : {8, 64, 112, 256}) {
run (M, 1, K);
run(M, 1, K);
}
}
......@@ -618,5 +615,4 @@ TEST_F(ARM_COMMON, BENCHMARK_TRANSPOSED_MATRIX_MUL_QUINT8) {
#endif
// vim: syntax=cpp.doxygen
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册