提交 563239d3 编写于 作者: M Megvii Engine Team

feat(dnn): add arm_common nchw44 cwconv3x3s1p1 and cwconv5x5s1p2

GitOrigin-RevId: 9ea411d0e108cdc8bcb6dfa49d9e21fa741cdac1
上级 3344b580
......@@ -19,6 +19,8 @@
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
#pragma GCC diagnostic ignored "-Wunused-parameter"
using namespace megdnn;
using namespace arm_common;
using namespace fp16;
......
......@@ -284,7 +284,7 @@ void channel_wise_nchw88::do_conv_kern_stride1_3x3(
const __fp16* src, const __fp16* filter, const __fp16* bias,
__fp16* dst, const size_t IH, const size_t IW, const size_t OH,
const size_t OW, const size_t PH, const size_t PW) {
if (IH == OH && IW == OW && PH == 1 && PW == 1) {
if (IH == OH && IW == OW && IH >= 3 && IW >= 3 && PH == 1 && PW == 1) {
do_conv_kern_3x3_stride1_padding1<bias_mode, Op>(src, dst, filter, bias,
OH, OW);
return;
......
/**
* \file dnn/src/arm_common/conv_bias/fp32/channel_wise_3x3_s1p1_nchw44_kern.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/channel_wise_3x3_s1p1_nchw44_kern.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/arm_common/utils.h"
#include "src/common/utils.h"
#include "src/fallback/conv_bias/common.h"
#pragma GCC diagnostic ignored "-Wunused-parameter"
using namespace megdnn;
using namespace arm_common;
namespace {
#if defined(__ARM_FEATURE_FMA)
#define Vfmaq_f32(d, n, m) vfmaq_f32(d, n, m)
#else
#define Vfmaq_f32(d, n, m) vmlaq_f32(d, n, m)
#endif
template <int shift>
static inline void shift_src(float32x4_t rsrc[3][4]) {
float32x4_t t[4];
t[0] = rsrc[0][(shift + 0) % 4];
t[1] = rsrc[0][(shift + 1) % 4];
t[2] = rsrc[0][(shift + 2) % 4];
t[3] = rsrc[0][(shift + 3) % 4];
rsrc[0][0] = t[0];
rsrc[0][1] = t[1];
rsrc[0][2] = t[2];
rsrc[0][3] = t[3];
t[0] = rsrc[1][(shift + 0) % 4];
t[1] = rsrc[1][(shift + 1) % 4];
t[2] = rsrc[1][(shift + 2) % 4];
t[3] = rsrc[1][(shift + 3) % 4];
rsrc[1][0] = t[0];
rsrc[1][1] = t[1];
rsrc[1][2] = t[2];
rsrc[1][3] = t[3];
t[0] = rsrc[2][(shift + 0) % 4];
t[1] = rsrc[2][(shift + 1) % 4];
t[2] = rsrc[2][(shift + 2) % 4];
t[3] = rsrc[2][(shift + 3) % 4];
rsrc[2][0] = t[0];
rsrc[2][1] = t[1];
rsrc[2][2] = t[2];
rsrc[2][3] = t[3];
}
template <BiasMode bias_mode>
static inline float32x4_t load_bias(const float* bias,
const float32x4_t& init) {
if (bias_mode == BiasMode::BIAS) {
return vld1q_f32(bias);
} else {
return init;
}
}
template <int BW, int bw, bool has_top, bool has_bottom, BiasMode bias_mode>
struct compute_element {
template <typename Op>
static inline void call(const float*& src0, const float*& src1,
const float*& src2, float*& dst, const float*& bias,
const float32x4_t& init, float32x4_t rsrc[3][4],
float32x4_t rfilter[3][3], const Op& op) {
#define RSRC(i, j) rsrc[i][((j) + bw) % 4]
float32x4_t rdst = load_bias<bias_mode>(bias, init);
if (has_top) {
RSRC(0, 3) = vld1q_f32(src0 + 8);
}
{ RSRC(1, 3) = vld1q_f32(src1 + 8); }
if (has_bottom) {
RSRC(2, 3) = vld1q_f32(src2 + 8);
}
if (has_top) {
rdst = Vfmaq_f32(rdst, RSRC(0, 0), rfilter[0][0]);
rdst = Vfmaq_f32(rdst, RSRC(0, 1), rfilter[0][1]);
rdst = Vfmaq_f32(rdst, RSRC(0, 2), rfilter[0][2]);
}
{
rdst = Vfmaq_f32(rdst, RSRC(1, 0), rfilter[1][0]);
rdst = Vfmaq_f32(rdst, RSRC(1, 1), rfilter[1][1]);
rdst = Vfmaq_f32(rdst, RSRC(1, 2), rfilter[1][2]);
}
if (has_bottom) {
rdst = Vfmaq_f32(rdst, RSRC(2, 0), rfilter[2][0]);
rdst = Vfmaq_f32(rdst, RSRC(2, 1), rfilter[2][1]);
rdst = Vfmaq_f32(rdst, RSRC(2, 2), rfilter[2][2]);
}
vst1q_f32(dst, op(rdst));
if (has_top) {
src0 += 4;
}
{ src1 += 4; }
if (has_bottom) {
src2 += 4;
}
dst += 4;
bias += 4;
compute_element<BW, bw + 1, has_top, has_bottom, bias_mode>::call(
src0, src1, src2, dst, bias, init, rsrc, rfilter, op);
#undef RSRC
}
};
template <int BW, bool has_top, bool has_bottom, BiasMode bias_mode>
struct compute_element<BW, BW, has_top, has_bottom, bias_mode> {
template <typename... Types>
static inline void call(Types... args) {}
};
template <bool has_top, bool has_bottom, BiasMode bias_mode>
struct compute_element_right {
template <typename Op>
static inline void call(float*& dst, const float*& bias,
const float32x4_t& init, float32x4_t rsrc[3][4],
float32x4_t rfilter[3][3], const Op& op) {
float32x4_t rdst = load_bias<bias_mode>(bias, init);
if (has_top) {
rdst = Vfmaq_f32(rdst, rsrc[0][0], rfilter[0][0]);
rdst = Vfmaq_f32(rdst, rsrc[0][1], rfilter[0][1]);
rdst = Vfmaq_f32(rdst, rsrc[0][2], rfilter[0][2]);
}
{
rdst = Vfmaq_f32(rdst, rsrc[1][0], rfilter[1][0]);
rdst = Vfmaq_f32(rdst, rsrc[1][1], rfilter[1][1]);
rdst = Vfmaq_f32(rdst, rsrc[1][2], rfilter[1][2]);
}
if (has_bottom) {
rdst = Vfmaq_f32(rdst, rsrc[2][0], rfilter[2][0]);
rdst = Vfmaq_f32(rdst, rsrc[2][1], rfilter[2][1]);
rdst = Vfmaq_f32(rdst, rsrc[2][2], rfilter[2][2]);
}
vst1q_f32(dst, op(rdst));
dst += 4;
bias += 4;
}
};
template <bool has_top, bool has_bottom, BiasMode bias_mode>
struct compute_element_right_pad {
template <typename Op>
static inline void call(float*& dst, const float*& bias,
const float32x4_t& init, float32x4_t rsrc[3][4],
float32x4_t rfilter[3][3], const Op& op) {
float32x4_t rdst = load_bias<bias_mode>(bias, init);
if (has_top) {
rdst = Vfmaq_f32(rdst, rsrc[0][1], rfilter[0][0]);
rdst = Vfmaq_f32(rdst, rsrc[0][2], rfilter[0][1]);
}
{
rdst = Vfmaq_f32(rdst, rsrc[1][1], rfilter[1][0]);
rdst = Vfmaq_f32(rdst, rsrc[1][2], rfilter[1][1]);
}
if (has_bottom) {
rdst = Vfmaq_f32(rdst, rsrc[2][1], rfilter[2][0]);
rdst = Vfmaq_f32(rdst, rsrc[2][2], rfilter[2][1]);
}
vst1q_f32(dst, op(rdst));
dst += 4;
bias += 4;
}
};
template <bool has_top, bool has_bottom, BiasMode bias_mode>
struct compute_row {
template <typename Op>
static inline void call(const float*& src0, const float*& src1,
const float*& src2, float*& dst, const float*& bias,
const float32x4_t& init, float32x4_t rsrc[3][4],
float32x4_t rfilter[3][3], int W, const Op& op) {
if (has_top) {
rsrc[0][0] = vdupq_n_f32(0);
rsrc[0][1] = vld1q_f32(src0 + 0);
rsrc[0][2] = vld1q_f32(src0 + 4);
}
{
rsrc[1][0] = vdupq_n_f32(0);
rsrc[1][1] = vld1q_f32(src1 + 0);
rsrc[1][2] = vld1q_f32(src1 + 4);
}
if (has_bottom) {
rsrc[2][0] = vdupq_n_f32(0);
rsrc[2][1] = vld1q_f32(src2 + 0);
rsrc[2][2] = vld1q_f32(src2 + 4);
}
int w = 0;
const float* src0_ptr = src0;
const float* src1_ptr = src1;
const float* src2_ptr = src2;
float* dst_ptr = dst;
const float* bias_ptr = bias;
for (; w + 3 < W - 2; w += 4) {
compute_element<4, 0, has_top, has_bottom, bias_mode>::call(
src0_ptr, src1_ptr, src2_ptr, dst_ptr, bias_ptr, init, rsrc,
rfilter, op);
}
if (w + 1 < W - 2) {
compute_element<2, 0, has_top, has_bottom, bias_mode>::call(
src0_ptr, src1_ptr, src2_ptr, dst_ptr, bias_ptr, init, rsrc,
rfilter, op);
shift_src<2>(rsrc);
w += 2;
}
if (w < W - 2) {
compute_element<1, 0, has_top, has_bottom, bias_mode>::call(
src0_ptr, src1_ptr, src2_ptr, dst_ptr, bias_ptr, init, rsrc,
rfilter, op);
shift_src<1>(rsrc);
w += 1;
}
// compute rightmost 2 elements seperately
compute_element_right<has_top, has_bottom, bias_mode>::call(
dst_ptr, bias_ptr, init, rsrc, rfilter, op);
compute_element_right_pad<has_top, has_bottom, bias_mode>::call(
dst_ptr, bias_ptr, init, rsrc, rfilter, op);
src0 += W * 4;
src1 += W * 4;
src2 += W * 4;
dst += W * 4;
bias += W * 4;
}
};
} // namespace
template <BiasMode bias_mode, typename Op>
void channel_wise_nchw44_float::do_conv_kern_3x3_stride1_padding1(
const float* src, float* dst, const float* filter, const float* bias,
int H, int W) {
Op op;
float32x4_t init = vdupq_n_f32(0);
if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) {
init = vld1q_f32(bias);
}
const float* src0 = src - W * 4;
const float* src1 = src;
const float* src2 = src + W * 4;
float32x4_t rfilter[3][3];
rfilter[0][0] = vld1q_f32(filter + 0);
rfilter[0][1] = vld1q_f32(filter + 4);
rfilter[0][2] = vld1q_f32(filter + 8);
rfilter[1][0] = vld1q_f32(filter + 12);
rfilter[1][1] = vld1q_f32(filter + 16);
rfilter[1][2] = vld1q_f32(filter + 20);
rfilter[2][0] = vld1q_f32(filter + 24);
rfilter[2][1] = vld1q_f32(filter + 28);
rfilter[2][2] = vld1q_f32(filter + 32);
float32x4_t rsrc[3][4];
compute_row<false, true, bias_mode>::call(src0, src1, src2, dst, bias, init,
rsrc, rfilter, W, op);
for (int h = 1; h < H - 1; h += 1) {
compute_row<true, true, bias_mode>::call(src0, src1, src2, dst, bias,
init, rsrc, rfilter, W, op);
}
compute_row<true, false, bias_mode>::call(src0, src1, src2, dst, bias, init,
rsrc, rfilter, W, op);
}
#define INSTANTIATION(bias, Op) \
template void \
channel_wise_nchw44_float::do_conv_kern_3x3_stride1_padding1<bias, Op>( \
const float*, float*, const float*, const float*, int, int);
#define FOR_OP(bias) \
INSTANTIATION(bias, SigmoidOp<dt_float32>) \
INSTANTIATION(bias, ReluOp<dt_float32>) \
INSTANTIATION(bias, HSwishOp<dt_float32>) \
INSTANTIATION(bias, NoneOp<dt_float32>)
#define FOR_BIAS \
FOR_OP(BiasMode::NO_BIAS) \
FOR_OP(BiasMode::BROADCAST_CHANNEL_BIAS) \
FOR_OP(BiasMode::BIAS)
FOR_BIAS
#undef FOR_BIAS
#undef FOR_OP
#undef INSTANTIATION
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/arm_common/conv_bias/fp32/channel_wise_3x3_s1p1_nchw44_kern.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "src/arm_common/conv_bias/opr_impl.h"
#include "src/fallback/conv_bias/common.h"
namespace megdnn {
namespace arm_common {
namespace channel_wise_nchw44_float {
template <BiasMode bias_mode, typename Op>
void do_conv_kern_3x3_stride1_padding1(const float* src, float* dst,
const float* filter, const float* bias,
int H, int W);
} // namespace channel_wise_nchw44_float
} // namespace arm_common
} // namespace megdnn
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/arm_common/conv_bias/fp32/channel_wise_5x5_s1p2_nchw44_kern.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/channel_wise_5x5_s1p2_nchw44_kern.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/arm_common/utils.h"
#include "src/common/utils.h"
#include "src/fallback/conv_bias/common.h"
#pragma GCC diagnostic ignored "-Wunused-parameter"
using namespace megdnn;
using namespace arm_common;
namespace {
#if defined(__ARM_FEATURE_FMA)
#define Vfmaq_f32(d, n, m) vfmaq_f32(d, n, m)
#else
#define Vfmaq_f32(d, n, m) vmlaq_f32(d, n, m)
#endif
template <int shift>
static inline void shift_src(float32x4_t rsrc[6]) {
float32x4_t t[6];
t[0] = rsrc[(shift + 0) % 6];
t[1] = rsrc[(shift + 1) % 6];
t[2] = rsrc[(shift + 2) % 6];
t[3] = rsrc[(shift + 3) % 6];
t[4] = rsrc[(shift + 4) % 6];
t[5] = rsrc[(shift + 5) % 6];
rsrc[0] = t[0];
rsrc[1] = t[1];
rsrc[2] = t[2];
rsrc[3] = t[3];
rsrc[4] = t[4];
rsrc[5] = t[5];
}
static inline void load_filter(const float* filter, float32x4_t rfilter[5]) {
rfilter[0] = vld1q_f32(filter + 0);
rfilter[1] = vld1q_f32(filter + 4);
rfilter[2] = vld1q_f32(filter + 8);
rfilter[3] = vld1q_f32(filter + 12);
rfilter[4] = vld1q_f32(filter + 16);
}
template <BiasMode bias_mode>
static inline float32x4_t load_bias(const float* bias,
const float32x4_t& init) {
if (bias_mode == BiasMode::BIAS) {
return vld1q_f32(bias);
} else {
return init;
}
}
template <int BW, int bw, BiasMode bias_mode, bool need_load_bias,
bool need_do_op>
struct compute_element {
template <typename Op>
static inline void call(const float*& src, float*& dst, const float*& bias,
const float32x4_t& init, float32x4_t rsrc[6],
float32x4_t rfilter[5], const Op& op) {
#define RSRC(i) rsrc[((i) + bw) % 6]
float32x4_t rdst;
if (need_load_bias) {
rdst = load_bias<bias_mode>(bias, init);
} else {
rdst = vld1q_f32(dst);
}
RSRC(5) = vld1q_f32(src + 12);
rdst = Vfmaq_f32(rdst, RSRC(0), rfilter[0]);
rdst = Vfmaq_f32(rdst, RSRC(1), rfilter[1]);
rdst = Vfmaq_f32(rdst, RSRC(2), rfilter[2]);
rdst = Vfmaq_f32(rdst, RSRC(3), rfilter[3]);
rdst = Vfmaq_f32(rdst, RSRC(4), rfilter[4]);
if (need_do_op) {
rdst = op(rdst);
}
vst1q_f32(dst, rdst);
src += 4;
dst += 4;
bias += 4;
compute_element<BW, bw + 1, bias_mode, need_load_bias,
need_do_op>::call(src, dst, bias, init, rsrc, rfilter,
op);
#undef RSRC
}
};
template <int BW, BiasMode bias_mode, bool need_load_bias, bool need_do_op>
struct compute_element<BW, BW, bias_mode, need_load_bias, need_do_op> {
template <typename... Types>
static inline void call(Types... args) {}
};
template <size_t padding, BiasMode bias_mode, bool need_load_bias,
bool need_do_op>
struct compute_element_right {
template <typename Op>
static inline void call(float*& dst, const float*& bias,
const float32x4_t& init, float32x4_t rsrc[6],
float32x4_t rfilter[5], const Op& op) {
float32x4_t rdst;
if (need_load_bias) {
rdst = load_bias<bias_mode>(bias, init);
} else {
rdst = vld1q_f32(dst);
}
rdst = Vfmaq_f32(rdst, rsrc[0 + padding], rfilter[0]);
rdst = Vfmaq_f32(rdst, rsrc[1 + padding], rfilter[1]);
rdst = Vfmaq_f32(rdst, rsrc[2 + padding], rfilter[2]);
if (padding < 2) {
rdst = Vfmaq_f32(rdst, rsrc[3 + padding], rfilter[3]);
}
if (padding < 1) {
rdst = Vfmaq_f32(rdst, rsrc[4 + padding], rfilter[4]);
}
if (need_do_op) {
rdst = op(rdst);
}
vst1q_f32(dst, rdst);
dst += 4;
bias += 4;
}
};
template <BiasMode bias_mode, bool need_load_bias, bool need_do_op>
struct compute_row_src_1x5 {
template <typename Op>
static inline void call(const float* src, float* dst, const float* bias,
const float32x4_t& init, float32x4_t rsrc[6],
float32x4_t rfilter[5], int W, const Op& op) {
rsrc[0] = vdupq_n_f32(0);
rsrc[1] = vdupq_n_f32(0);
rsrc[2] = vld1q_f32(src + 0);
rsrc[3] = vld1q_f32(src + 4);
rsrc[4] = vld1q_f32(src + 8);
int w = 0;
for (; w + 5 < W - 3; w += 6) {
compute_element<6, 0, bias_mode, need_load_bias, need_do_op>::call(
src, dst, bias, init, rsrc, rfilter, op);
}
if (w + 3 < W - 3) {
compute_element<4, 0, bias_mode, need_load_bias, need_do_op>::call(
src, dst, bias, init, rsrc, rfilter, op);
shift_src<4>(rsrc);
w += 4;
}
if (w + 1 < W - 3) {
compute_element<2, 0, bias_mode, need_load_bias, need_do_op>::call(
src, dst, bias, init, rsrc, rfilter, op);
shift_src<2>(rsrc);
w += 2;
}
if (w < W - 3) {
compute_element<1, 0, bias_mode, need_load_bias, need_do_op>::call(
src, dst, bias, init, rsrc, rfilter, op);
shift_src<1>(rsrc);
w += 1;
}
// compute rightmost 3 elements seperately
compute_element_right<0, bias_mode, need_load_bias, need_do_op>::call(
dst, bias, init, rsrc, rfilter, op);
compute_element_right<1, bias_mode, need_load_bias, need_do_op>::call(
dst, bias, init, rsrc, rfilter, op);
compute_element_right<2, bias_mode, need_load_bias, need_do_op>::call(
dst, bias, init, rsrc, rfilter, op);
}
};
template <size_t top_padding, size_t bottom_padding, BiasMode bias_mode>
struct compute_row {
template <typename Op>
static inline void call(const float*& src, float*& dst, const float* filter,
const float*& bias, const float32x4_t& init,
float32x4_t rsrc[6], float32x4_t rfilter[5], int W,
const Op& op) {
if (top_padding < 1) {
load_filter(filter + 0, rfilter);
compute_row_src_1x5<bias_mode, top_padding == 0, false>::call(
src - W * 8, dst, bias, init, rsrc, rfilter, W, op);
}
if (top_padding < 2) {
load_filter(filter + 20, rfilter);
compute_row_src_1x5<bias_mode, top_padding == 1, false>::call(
src - W * 4, dst, bias, init, rsrc, rfilter, W, op);
}
{
load_filter(filter + 40, rfilter);
compute_row_src_1x5<bias_mode, top_padding == 2,
bottom_padding == 2>::call(src, dst, bias, init,
rsrc, rfilter, W,
op);
}
if (bottom_padding < 2) {
load_filter(filter + 60, rfilter);
compute_row_src_1x5<bias_mode, false, bottom_padding == 1>::call(
src + W * 4, dst, bias, init, rsrc, rfilter, W, op);
}
if (bottom_padding < 1) {
load_filter(filter + 80, rfilter);
compute_row_src_1x5<bias_mode, false, bottom_padding == 0>::call(
src + W * 8, dst, bias, init, rsrc, rfilter, W, op);
}
src += W * 4;
dst += W * 4;
bias += W * 4;
}
};
} // namespace
template <BiasMode bias_mode, typename Op>
void channel_wise_nchw44_float::do_conv_kern_5x5_stride1_padding2(
const float* src, float* dst, const float* filter, const float* bias,
int H, int W) {
Op op;
float32x4_t init = vdupq_n_f32(0);
if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) {
init = vld1q_f32(bias);
}
float32x4_t rsrc[6];
float32x4_t rfilter[5];
compute_row<2, 0, bias_mode>::call(src, dst, filter, bias, init, rsrc,
rfilter, W, op);
compute_row<1, 0, bias_mode>::call(src, dst, filter, bias, init, rsrc,
rfilter, W, op);
for (int h = 2; h < H - 2; h += 1) {
compute_row<0, 0, bias_mode>::call(src, dst, filter, bias, init, rsrc,
rfilter, W, op);
}
compute_row<0, 1, bias_mode>::call(src, dst, filter, bias, init, rsrc,
rfilter, W, op);
compute_row<0, 2, bias_mode>::call(src, dst, filter, bias, init, rsrc,
rfilter, W, op);
}
#define INSTANTIATION(bias, Op) \
template void \
channel_wise_nchw44_float::do_conv_kern_5x5_stride1_padding2<bias, Op>( \
const float*, float*, const float*, const float*, int, int);
#define FOR_OP(bias) \
INSTANTIATION(bias, SigmoidOp<dt_float32>) \
INSTANTIATION(bias, ReluOp<dt_float32>) \
INSTANTIATION(bias, HSwishOp<dt_float32>) \
INSTANTIATION(bias, NoneOp<dt_float32>)
#define FOR_BIAS \
FOR_OP(BiasMode::NO_BIAS) \
FOR_OP(BiasMode::BROADCAST_CHANNEL_BIAS) \
FOR_OP(BiasMode::BIAS)
FOR_BIAS
#undef FOR_BIAS
#undef FOR_OP
#undef INSTANTIATION
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/arm_common/conv_bias/fp32/channel_wise_5x5_s1p2_nchw44_kern.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "src/arm_common/conv_bias/opr_impl.h"
#include "src/fallback/conv_bias/common.h"
namespace megdnn {
namespace arm_common {
namespace channel_wise_nchw44_float {
template <BiasMode bias_mode, typename Op>
void do_conv_kern_5x5_stride1_padding2(const float* src, float* dst,
const float* filter, const float* bias,
int H, int W);
} // namespace channel_wise_nchw44_float
} // namespace arm_common
} // namespace megdnn
// vim: syntax=cpp.doxygen
......@@ -11,6 +11,8 @@
*/
#include "src/arm_common/conv_bias/fp32/channel_wise_nchw44_kern.h"
#include "src/arm_common/conv_bias/fp32/channel_wise_3x3_s1p1_nchw44_kern.h"
#include "src/arm_common/conv_bias/fp32/channel_wise_5x5_s1p2_nchw44_kern.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/arm_common/utils.h"
......@@ -413,6 +415,13 @@ void channel_wise_nchw44_float::do_conv_kern_stride1_3x3(
const float* src, const float* filter, const float* bias, float* dst,
const size_t IH, const size_t IW, const size_t OH, const size_t OW,
const size_t PH, const size_t PW) {
if (IH == OH && IW == OW && IH >= 3 && IW >= 3 && PH == 1 && PW == 1) {
channel_wise_nchw44_float::do_conv_kern_3x3_stride1_padding1<bias_mode,
Op>(
src, dst, filter, bias, OH, OW);
return;
}
float32x4_t kernel[9];
load_vec<9>(kernel, filter);
Op op;
......@@ -424,10 +433,7 @@ void channel_wise_nchw44_float::do_conv_kern_stride1_3x3(
size_t ow_start = PW;
size_t oh_end = IH + PH - 2;
size_t ow_end = IW + PW - 2;
if (PH == 1 && PW == 1) {
PaddingComputeK3P1<bias_mode, Op>::compute(src, bias, dst, 1, IH, IW,
OH, OW, kernel, init);
} else if (PH || PW) {
if (PH || PW) {
PaddingCompute<bias_mode, Op>::compute(src, bias, dst, 3, 1, IH, IW, OH,
OW, PH, PW, kernel, init);
}
......@@ -557,6 +563,13 @@ void channel_wise_nchw44_float::do_conv_kern_stride1_5x5(
const float* src, const float* filter, const float* bias, float* dst,
const size_t IH, const size_t IW, const size_t OH, const size_t OW,
const size_t PH, const size_t PW) {
if (IH == OH && IW == OW && IH >= 5 && IW >= 5 && PH == 2 && PW == 2) {
channel_wise_nchw44_float::do_conv_kern_5x5_stride1_padding2<bias_mode,
Op>(
src, dst, filter, bias, OH, OW);
return;
}
Op op;
float32x4_t init = vdupq_n_f32(0.f);
if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册