Commit 1e83ab63 authored by Megvii Engine Team

feat(dnn): add channelwise conv for fp16 nchw88

GitOrigin-RevId: 1bb64f82c5cc4e512e9141a2bebd35ee8033fc5d
Parent 28c066ee
...@@ -6,7 +6,8 @@
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#pragma once
...@@ -107,7 +108,7 @@ public:
virtual SmallVector<NCBKern> dispatch_kerns(
        const NCBKernSizeParam& param) const override;
ConvAlgoTypePack get_algo_type() const override {
    return {AlgoDataType::FLOAT16, AlgoCategory::DIRECT};
}
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_FP16)
...@@ -132,6 +133,26 @@ public:
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD1_FP16)
};
class ConvBiasImpl::AlgoF16ChannelWiseNCHW88 final : public AlgoBase {
SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const;
public:
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "F16_CHANNEL_WISE_NCHW88"; }
bool usable(const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;
size_t get_workspace(const NCBKernSizeParam& param) const override;
virtual SmallVector<NCBKern> dispatch_kerns(
const NCBKernSizeParam& param) const override;
ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::FLOAT16, AlgoCategory::DIRECT};
}
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_CHWNWISE_NCHW88_F16)
};
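// Note: the new AlgoF16ChannelWiseNCHW88 declaration above mirrors the existing
// fp16 direct algorithms: it advertises {FLOAT16, DIRECT} to the selection
// heuristic and is registered in ConvBiasImpl::AlgoPack further down in this
// commit.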
} // namespace arm_common
} // namespace megdnn
#endif
...
/**
* \file dnn/src/arm_common/conv_bias/f16/channel_wise_3x3_s1p1_nchw88_kern.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/f16/channel_wise_3x3_s1p1_nchw88_kern.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/arm_common/utils.h"
#include "src/common/utils.h"
#include "src/fallback/conv_bias/common.h"
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
using namespace megdnn;
using namespace arm_common;
using namespace fp16;
namespace {
template <int shift>
static inline void shift_src(float16x8_t rsrc[3][4]) {
float16x8_t t[4];
t[0] = rsrc[0][(shift + 0) % 4];
t[1] = rsrc[0][(shift + 1) % 4];
t[2] = rsrc[0][(shift + 2) % 4];
t[3] = rsrc[0][(shift + 3) % 4];
rsrc[0][0] = t[0];
rsrc[0][1] = t[1];
rsrc[0][2] = t[2];
rsrc[0][3] = t[3];
t[0] = rsrc[1][(shift + 0) % 4];
t[1] = rsrc[1][(shift + 1) % 4];
t[2] = rsrc[1][(shift + 2) % 4];
t[3] = rsrc[1][(shift + 3) % 4];
rsrc[1][0] = t[0];
rsrc[1][1] = t[1];
rsrc[1][2] = t[2];
rsrc[1][3] = t[3];
t[0] = rsrc[2][(shift + 0) % 4];
t[1] = rsrc[2][(shift + 1) % 4];
t[2] = rsrc[2][(shift + 2) % 4];
t[3] = rsrc[2][(shift + 3) % 4];
rsrc[2][0] = t[0];
rsrc[2][1] = t[1];
rsrc[2][2] = t[2];
rsrc[2][3] = t[3];
}
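// The helper above rotates each row's four cached vectors left by `shift`
// positions: after the call, rsrc[r][i] holds the old rsrc[r][(i + shift) % 4].
// It is used after the 2-wide and 1-wide tail iterations so the register ring
// buffer lines up with the fixed indexing of the right-edge computations below.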
template <BiasMode bias_mode>
static inline float16x8_t load_bias(const float16_t* bias,
const float16x8_t& init) {
if (bias_mode == BiasMode::BIAS) {
return vld1q_f16(bias);
} else {
return init;
}
}
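// load_bias separates the per-element bias case (BiasMode::BIAS), where a fresh
// vector is read from the bias pointer at every output position, from the other
// modes, where `init` (zero, or the broadcast channel bias loaded once per
// kernel call) is simply reused.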
template <int BW, int bw, bool has_top, bool has_bottom, BiasMode bias_mode>
struct compute_element {
template <typename Op>
static inline void call(const float16_t*& src0, const float16_t*& src1,
const float16_t*& src2, float16_t*& dst,
const float16_t*& bias, const float16x8_t& init,
float16x8_t rsrc[3][4], float16x8_t rfilter[3][3],
const Op& op) {
#define RSRC(i, j) rsrc[i][((j) + bw) % 4]
float16x8_t rdst = load_bias<bias_mode>(bias, init);
if (has_top) {
RSRC(0, 3) = vld1q_f16(src0 + 16);
}
{ RSRC(1, 3) = vld1q_f16(src1 + 16); }
if (has_bottom) {
RSRC(2, 3) = vld1q_f16(src2 + 16);
}
if (has_top) {
rdst = vfmaq_f16(rdst, RSRC(0, 0), rfilter[0][0]);
rdst = vfmaq_f16(rdst, RSRC(0, 1), rfilter[0][1]);
rdst = vfmaq_f16(rdst, RSRC(0, 2), rfilter[0][2]);
}
{
rdst = vfmaq_f16(rdst, RSRC(1, 0), rfilter[1][0]);
rdst = vfmaq_f16(rdst, RSRC(1, 1), rfilter[1][1]);
rdst = vfmaq_f16(rdst, RSRC(1, 2), rfilter[1][2]);
}
if (has_bottom) {
rdst = vfmaq_f16(rdst, RSRC(2, 0), rfilter[2][0]);
rdst = vfmaq_f16(rdst, RSRC(2, 1), rfilter[2][1]);
rdst = vfmaq_f16(rdst, RSRC(2, 2), rfilter[2][2]);
}
vst1q_f16(dst, op(rdst));
if (has_top) {
src0 += 8;
}
{ src1 += 8; }
if (has_bottom) {
src2 += 8;
}
dst += 8;
bias += 8;
compute_element<BW, bw + 1, has_top, has_bottom, bias_mode>::call(
src0, src1, src2, dst, bias, init, rsrc, rfilter, op);
#undef RSRC
}
};
template <int BW, bool has_top, bool has_bottom, BiasMode bias_mode>
struct compute_element<BW, BW, has_top, has_bottom, bias_mode> {
template <typename Op>
static inline void call(const float16_t*& src0, const float16_t*& src1,
const float16_t*& src2, float16_t*& dst,
const float16_t*& bias, const float16x8_t& init,
float16x8_t rsrc[3][4], float16x8_t rfilter[3][3],
const Op& op) {}
};
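// compute_element<BW, bw, ...> handles output column `bw` and then recurses
// into bw + 1; the empty specialization above terminates the recursion once
// bw == BW, so the 4-, 2- and 1-wide bodies used by compute_row below are
// fully unrolled at compile time.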
template <bool has_top, bool has_bottom, BiasMode bias_mode>
struct compute_element_right {
template <typename Op>
static inline void call(float16_t*& dst, const float16_t*& bias,
const float16x8_t& init, float16x8_t rsrc[3][4],
float16x8_t rfilter[3][3], const Op& op) {
float16x8_t rdst = load_bias<bias_mode>(bias, init);
if (has_top) {
rdst = vfmaq_f16(rdst, rsrc[0][0], rfilter[0][0]);
rdst = vfmaq_f16(rdst, rsrc[0][1], rfilter[0][1]);
rdst = vfmaq_f16(rdst, rsrc[0][2], rfilter[0][2]);
}
{
rdst = vfmaq_f16(rdst, rsrc[1][0], rfilter[1][0]);
rdst = vfmaq_f16(rdst, rsrc[1][1], rfilter[1][1]);
rdst = vfmaq_f16(rdst, rsrc[1][2], rfilter[1][2]);
}
if (has_bottom) {
rdst = vfmaq_f16(rdst, rsrc[2][0], rfilter[2][0]);
rdst = vfmaq_f16(rdst, rsrc[2][1], rfilter[2][1]);
rdst = vfmaq_f16(rdst, rsrc[2][2], rfilter[2][2]);
}
vst1q_f16(dst, op(rdst));
dst += 8;
bias += 8;
}
};
template <bool has_top, bool has_bottom, BiasMode bias_mode>
struct compute_element_right_pad {
template <typename Op>
static inline void call(float16_t*& dst, const float16_t*& bias,
const float16x8_t& init, float16x8_t rsrc[3][4],
float16x8_t rfilter[3][3], const Op& op) {
float16x8_t rdst = load_bias<bias_mode>(bias, init);
if (has_top) {
rdst = vfmaq_f16(rdst, rsrc[0][1], rfilter[0][0]);
rdst = vfmaq_f16(rdst, rsrc[0][2], rfilter[0][1]);
}
{
rdst = vfmaq_f16(rdst, rsrc[1][1], rfilter[1][0]);
rdst = vfmaq_f16(rdst, rsrc[1][2], rfilter[1][1]);
}
if (has_bottom) {
rdst = vfmaq_f16(rdst, rsrc[2][1], rfilter[2][0]);
rdst = vfmaq_f16(rdst, rsrc[2][2], rfilter[2][1]);
}
vst1q_f16(dst, op(rdst));
dst += 8;
bias += 8;
}
};
template <bool has_top, bool has_bottom, BiasMode bias_mode>
struct compute_row {
template <typename Op>
static inline void call(const float16_t*& src0, const float16_t*& src1,
const float16_t*& src2, float16_t*& dst,
const float16_t*& bias, const float16x8_t& init,
float16x8_t rsrc[3][4], float16x8_t rfilter[3][3],
int W, const Op& op) {
if (has_top) {
rsrc[0][0] = vdupq_n_f16(0);
rsrc[0][1] = vld1q_f16(src0 + 0);
rsrc[0][2] = vld1q_f16(src0 + 8);
}
{
rsrc[1][0] = vdupq_n_f16(0);
rsrc[1][1] = vld1q_f16(src1 + 0);
rsrc[1][2] = vld1q_f16(src1 + 8);
}
if (has_bottom) {
rsrc[2][0] = vdupq_n_f16(0);
rsrc[2][1] = vld1q_f16(src2 + 0);
rsrc[2][2] = vld1q_f16(src2 + 8);
}
int w = 0;
const float16_t* src0_ptr = src0;
const float16_t* src1_ptr = src1;
const float16_t* src2_ptr = src2;
float16_t* dst_ptr = dst;
const float16_t* bias_ptr = bias;
for (; w + 3 < W - 2; w += 4) {
compute_element<4, 0, has_top, has_bottom, bias_mode>::call(
src0_ptr, src1_ptr, src2_ptr, dst_ptr, bias_ptr, init, rsrc,
rfilter, op);
}
if (w + 1 < W - 2) {
compute_element<2, 0, has_top, has_bottom, bias_mode>::call(
src0_ptr, src1_ptr, src2_ptr, dst_ptr, bias_ptr, init, rsrc,
rfilter, op);
shift_src<2>(rsrc);
w += 2;
}
if (w < W - 2) {
compute_element<1, 0, has_top, has_bottom, bias_mode>::call(
src0_ptr, src1_ptr, src2_ptr, dst_ptr, bias_ptr, init, rsrc,
rfilter, op);
shift_src<1>(rsrc);
w += 1;
}
compute_element_right<has_top, has_bottom, bias_mode>::call(
dst_ptr, bias_ptr, init, rsrc, rfilter, op);
compute_element_right_pad<has_top, has_bottom, bias_mode>::call(
dst_ptr, bias_ptr, init, rsrc, rfilter, op);
src0 += W * 8;
src1 += W * 8;
src2 += W * 8;
dst += W * 8;
bias += W * 8;
}
};
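// compute_row produces one full output row. The first W - 2 output columns are
// computed in unrolled chunks of 4, 2 and 1 vectors (the leading zero vector in
// rsrc[r][0] supplies the left padding); compute_element_right then emits the
// next column from the registers already in flight, and compute_element_right_pad
// emits the last column with only two taps, since the third tap falls into the
// right zero padding. has_top / has_bottom drop the contributions of input rows
// that lie outside the image for the first and last output row.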
} // namespace
template <BiasMode bias_mode, typename Op>
void channel_wise_nchw88::do_conv_kern_3x3_stride1_padding1(
const float16_t* src, float16_t* dst, const float16_t* filter,
const float16_t* bias, int H, int W) {
Op op;
float16x8_t init;
if (bias_mode == BiasMode::NO_BIAS) {
init = vdupq_n_f16(__fp16(0.f));
} else if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) {
init = vld1q_f16(bias);
}
const float16_t* src0 = src - W * 8;
const float16_t* src1 = src;
const float16_t* src2 = src + W * 8;
float16x8_t rfilter[3][3];
rfilter[0][0] = vld1q_f16(filter + 0);
rfilter[0][1] = vld1q_f16(filter + 8);
rfilter[0][2] = vld1q_f16(filter + 16);
rfilter[1][0] = vld1q_f16(filter + 24);
rfilter[1][1] = vld1q_f16(filter + 32);
rfilter[1][2] = vld1q_f16(filter + 40);
rfilter[2][0] = vld1q_f16(filter + 48);
rfilter[2][1] = vld1q_f16(filter + 56);
rfilter[2][2] = vld1q_f16(filter + 64);
float16x8_t rsrc[3][4];
compute_row<false, true, bias_mode>::call(src0, src1, src2, dst, bias, init,
rsrc, rfilter, W, op);
for (int h = 1; h < H - 1; h += 1) {
compute_row<true, true, bias_mode>::call(src0, src1, src2, dst, bias,
init, rsrc, rfilter, W, op);
}
compute_row<true, false, bias_mode>::call(src0, src1, src2, dst, bias, init,
rsrc, rfilter, W, op);
}
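// A hedged scalar sketch (not part of this commit): a plain-C++ reference for
// the 3x3 / stride-1 / pad-1 channel-wise kernel above, assuming one NCHW88
// channel block laid out as (H, W, 8) and float instead of fp16; bias and
// activation are omitted. Useful for cross-checking the intrinsics.
static void reference_chanwise_3x3_s1p1_nchw88(const float* src, float* dst,
                                               const float* filter, int H,
                                               int W) {
    for (int oh = 0; oh < H; ++oh) {
        for (int ow = 0; ow < W; ++ow) {
            for (int c = 0; c < 8; ++c) {
                float acc = 0.f;
                for (int kh = 0; kh < 3; ++kh) {
                    for (int kw = 0; kw < 3; ++kw) {
                        int ih = oh + kh - 1;  // padding = 1
                        int iw = ow + kw - 1;
                        if (ih < 0 || ih >= H || iw < 0 || iw >= W)
                            continue;  // zero padding
                        // filter layout matches the rfilter loads: (3, 3, 8)
                        acc += src[(ih * W + iw) * 8 + c] *
                               filter[(kh * 3 + kw) * 8 + c];
                    }
                }
                dst[(oh * W + ow) * 8 + c] = acc;
            }
        }
    }
}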
#define INSTANTIATION(bias, Op) \
template void \
channel_wise_nchw88::do_conv_kern_3x3_stride1_padding1<bias, Op>( \
const float16_t*, float16_t*, const float16_t*, const float16_t*, \
int, int);
#define FOR_OP(bias) \
INSTANTIATION(bias, SigmoidOp<__fp16>) \
INSTANTIATION(bias, ReluOp<__fp16>) \
INSTANTIATION(bias, HSwishOp<__fp16>) \
INSTANTIATION(bias, NoneOp<__fp16>)
#define FOR_BIAS \
FOR_OP(BiasMode::NO_BIAS) \
FOR_OP(BiasMode::BROADCAST_CHANNEL_BIAS) \
FOR_OP(BiasMode::BIAS)
FOR_BIAS
#undef FOR_BIAS
#undef FOR_OP
#undef INSTANTIATION
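// The FOR_BIAS expansion above emits 3 bias modes x 4 activations = 12 explicit
// instantiations of do_conv_kern_3x3_stride1_padding1, so the template
// definitions in this translation unit are available to the algorithm object.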
#endif
// vim: syntax=cpp.doxygen
/**
 * \file dnn/src/arm_common/conv_bias/f16/channel_wise_3x3_s1p1_nchw88_kern.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "src/arm_common/conv_bias/opr_impl.h"
#include "src/fallback/conv_bias/common.h"
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
namespace megdnn {
namespace arm_common {
namespace fp16 {
namespace channel_wise_nchw88 {
template <BiasMode bias_mode, typename Op>
void do_conv_kern_3x3_stride1_padding1(const __fp16* src, __fp16* dst,
const __fp16* filter, const __fp16* bias,
int H, int W);
} // namespace channel_wise_nchw88
} // namespace fp16
} // namespace arm_common
} // namespace megdnn
#endif
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/arm_common/conv_bias/f16/channel_wise_nchw88_algo.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/f16/algos.h"
#include "src/arm_common/conv_bias/f16/channel_wise_nchw88_kern.h"
#include "src/arm_common/elemwise_op.h"
#include "midout.h"
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
using namespace megdnn;
using namespace arm_common;
using namespace fp16;
using conv_fun = std::function<void(
const __fp16* src, const __fp16* filter, const __fp16* bias,
__fp16* dst, const size_t IH, const size_t IW, const size_t OH,
const size_t OW, const size_t PH, size_t PW)>;
MIDOUT_DECL(conv_bias_fp16_channel_wise_nchw88)
bool ConvBiasImpl::AlgoF16ChannelWiseNCHW88::usable(
const NCBKernSizeParam& param, AlgoSelectionStrategy) const {
auto&& fm = param.filter_meta;
auto FH = fm.spatial[0];
size_t OC = fm.ocpg;
size_t IC = fm.icpg;
size_t GROUP = fm.group;
bool ok_type = (param.src_type.enumv() == DTypeEnum::Float16 &&
param.filter_type.enumv() == DTypeEnum::Float16 &&
param.bias_type.enumv() == DTypeEnum::Float16 &&
param.dst_type.enumv() == DTypeEnum::Float16);
bool ok_format = OC == 1 && IC == 1 && GROUP % 8 == 0 &&
fm.format == param::Convolution::Format::NCHW88;
bool ok_filter = fm.spatial_ndim == 2 && FH == fm.spatial[1] &&
(FH == 2 || FH == 3 || FH == 5);
bool ok_slide = fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
fm.stride[0] == fm.stride[1] &&
(fm.stride[0] == 1 || fm.stride[0] == 2);
bool ok_conv = !fm.should_flip;
bool ok_comp = param.compute_mode == Param::ComputeMode::DEFAULT;
return ok_type && ok_format && ok_filter && ok_slide && ok_conv && ok_comp;
}
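// Illustrative example (not from the commit): a depthwise 3x3 fp16 convolution
// whose channels are packed as NCHW88 -- one input and one output channel per
// group, a group count that is a multiple of 8, unit dilation, stride 1 or 2
// and DEFAULT compute mode -- passes every check above and is handled by this
// algorithm.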
size_t ConvBiasImpl::AlgoF16ChannelWiseNCHW88::get_workspace(
const NCBKernSizeParam&) const {
return 0;
}
SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoF16ChannelWiseNCHW88::dispatch_kerns(
const NCBKernSizeParam& param) const {
const constexpr size_t pack_group_size = 8_z;
auto fm = param.filter_meta;
const int batch = param.n;
const int group = fm.group;
const int stride = fm.stride[0];
conv_fun do_conv_fun = nullptr;
// NOTE: remain_w is not used to generate the midout hash, to stay compatible
// with runtime shapes
#define DO_CONV_KERN_FUN(_stride, filter, bias_mode, op) \
MIDOUT_BEGIN(conv_bias_fp16_channel_wise_nchw88, \
midout_iv(#_stride #filter #bias_mode #op##_hash)) { \
do_conv_fun = channel_wise_nchw88:: \
do_conv_kern_##_stride##_##filter##x##filter<bias_mode, op>; \
} \
MIDOUT_END();
#define GET_OP_PARAM(_stride, filter, bias_mode) \
switch (param.nonlineMode) { \
case param::ConvBias::NonlineMode::IDENTITY: \
DO_CONV_KERN_FUN(_stride, filter, bias_mode, NoneOp<__fp16>) \
break; \
case param::ConvBias::NonlineMode::RELU: \
DO_CONV_KERN_FUN(_stride, filter, bias_mode, ReluOp<__fp16>) \
break; \
case param::ConvBias::NonlineMode::SIGMOID: \
DO_CONV_KERN_FUN(_stride, filter, bias_mode, SigmoidOp<__fp16>) \
break; \
case param::ConvBias::NonlineMode::H_SWISH: \
DO_CONV_KERN_FUN(_stride, filter, bias_mode, HSwishOp<__fp16>) \
break; \
default: \
megdnn_assert(0, "not supported nonline mode"); \
break; \
}
#define GET_BIAS_MODE_PARAM(_stride, filter) \
switch (param.bias_mode) { \
case BiasMode::NO_BIAS: \
GET_OP_PARAM(_stride, filter, BiasMode::NO_BIAS) \
break; \
case BiasMode::BROADCAST_CHANNEL_BIAS: \
GET_OP_PARAM(_stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS) \
break; \
case BiasMode::BIAS: \
GET_OP_PARAM(_stride, filter, BiasMode::BIAS) \
break; \
default: \
megdnn_assert(0, "not supported bias mode"); \
break; \
}
#define DISPATCH_CONV_KERN(_stride) \
switch (param.filter_meta.spatial[0]) { \
case 2: \
GET_BIAS_MODE_PARAM(_stride, 2) \
break; \
case 3: \
GET_BIAS_MODE_PARAM(_stride, 3) \
break; \
case 5: \
GET_BIAS_MODE_PARAM(_stride, 5) \
break; \
default: \
megdnn_assert(0, "not supported stride"); \
break; \
}
#define DISPATCH_STRIDE() \
if (1 == stride) { \
DISPATCH_CONV_KERN(stride1); \
} else { \
DISPATCH_CONV_KERN(stride2); \
}
DISPATCH_STRIDE();
#undef DO_CONV_KERN_FUN
#undef GET_REMAIN_W_PARAM
#undef GET_OP_PARAM
#undef GET_BIAS_MODE_PARAM
#undef DISPATCH_CONV_KERN
#undef DISPATCH_STRIDE
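// The macro cascade above selects do_conv_fun from among 2 strides x 3 filter
// sizes x 3 bias modes x 4 activations = 72 instantiated kernels; any
// combination that falls through triggers the corresponding megdnn_assert.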
megdnn_assert(do_conv_fun, "conv filter not supported");
SmallVector<ConvBiasImpl::NCBKern> ret_kerns;
CpuNDRange ncb_range = {static_cast<size_t>(batch),
static_cast<size_t>(group / pack_group_size)};
auto do_conv = [do_conv_fun](const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index) {
size_t PH = kern_param.filter_meta.padding[0];
size_t PW = kern_param.filter_meta.padding[1];
size_t OH = kern_param.osz[0];
size_t OW = kern_param.osz[1];
size_t IH = kern_param.isz[0];
size_t IW = kern_param.isz[1];
size_t batch_id = ncb_index.ndrange_id[0];
size_t group_id = ncb_index.ndrange_id[1];
const __fp16* sptr =
reinterpret_cast<const __fp16*>(kern_param.src<dt_float16>(
batch_id, group_id, 0, pack_group_size));
const __fp16* fptr = reinterpret_cast<const __fp16*>(
kern_param.filter<dt_float16>(group_id, pack_group_size));
__fp16* dst = reinterpret_cast<__fp16*>(kern_param.dst<dt_float16>(
batch_id, group_id, 0, pack_group_size));
const __fp16* bptr =
reinterpret_cast<const __fp16*>(kern_param.bias<dt_float16>(
batch_id, group_id, 0, pack_group_size));
do_conv_fun(sptr, fptr, bptr, dst, IH, IW, OH, OW, PH, PW);
};
ret_kerns.push_back({do_conv, ncb_range});
return ret_kerns;
}
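// The single kernel returned above is dispatched over an (N, GROUP / 8)
// ND-range: each invocation convolves one batch sample and one pack of 8
// groups (a single NCHW88 channel block), so multi-threading happens at the
// granularity of 8-channel blocks.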
#endif
// vim: syntax=cpp.doxygen
/**
 * \file dnn/src/arm_common/conv_bias/f16/channel_wise_nchw88_kern.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "src/arm_common/conv_bias/opr_impl.h"
#include "src/fallback/conv_bias/common.h"
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
namespace megdnn {
namespace arm_common {
namespace fp16 {
namespace channel_wise_nchw88 {
#define KERN(stride, i) \
template <BiasMode bias_mode, typename Op> \
void do_conv_kern_##stride##_##i##x##i( \
const __fp16* src, const __fp16* filter, const __fp16* bias, \
__fp16* dst, const size_t IH, const size_t IW, const size_t OH, \
const size_t OW, const size_t PH, const size_t PW);
KERN(stride1, 2)
KERN(stride1, 3)
KERN(stride1, 5)
KERN(stride2, 2)
KERN(stride2, 3)
KERN(stride2, 5)
#undef KERN
} // namespace channel_wise_nchw88
} // namespace fp16
} // namespace arm_common
} // namespace megdnn
#endif
// vim: syntax=cpp.doxygen
...@@ -85,6 +85,7 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj {
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
AlgoF16Direct f16_direct;
AlgoF16DirectStride1 f16_direct_stride1;
AlgoF16ChannelWiseNCHW88 f16_channel_wise_nchw88;
#endif
SmallVector<std::unique_ptr<AlgoBase>> refhold;
...@@ -119,6 +120,7 @@ public:
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
m_direct_algos.emplace_back(&f16_direct_stride1);
m_direct_algos.emplace_back(&f16_direct);
m_direct_algos.emplace_back(&f16_channel_wise_nchw88);
#endif
m_direct_algos.emplace_back(&i8x8x16_direct);
m_direct_algos.emplace_back(&i8x8x16_stride2_filter2);
...
...@@ -96,6 +96,7 @@ private:
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
class AlgoF16Direct;
class AlgoF16DirectStride1;
class AlgoF16ChannelWiseNCHW88;
#endif
class AlgoPack;
...
...@@ -238,6 +238,7 @@ public:
ARM_COMMON_WINOGRAD_F23_8X8_FP16,
ARM_COMMON_DIRECT_FP16,
ARM_COMMON_DIRECT_STRD1_FP16,
ARM_COMMON_CHWNWISE_NCHW88_F16,
ARM_COMMON_WINOGRAD_F23_4X4_FP32,
ARM_COMMON_WINOGRAD_F63_FP32,
ARM_COMMON_WINOGRAD_F63_4X4_FP32,
...
...@@ -148,6 +148,81 @@ std::vector<conv_bias::TestArg> get_nchw44_channel_wise_args(
return args;
}
std::vector<conv_bias::TestArg> get_nchw88_channel_wise_args(
std::vector<size_t> kernel, size_t stride, bool no_bias,
bool no_nonlinemode, bool no_full_bias) {
using namespace conv_bias;
using Param = param::ConvBias;
using NLMode = param::ConvBias::NonlineMode;
std::vector<TestArg> args;
auto pack = [&](size_t n, size_t group, size_t w, size_t h, size_t kernel,
size_t stride, NLMode nlmode, bool pad) {
Param param;
param.stride_h = stride;
param.stride_w = stride;
if (pad) {
param.pad_h = kernel / 2;
param.pad_w = kernel / 2;
} else {
param.pad_h = 0;
param.pad_w = 0;
}
param.nonlineMode = nlmode;
param.format = param::ConvBias::Format::NCHW88;
param.sparse = param::ConvBias::Sparse::GROUP;
args.emplace_back(param, TensorShape{n, group, h, w, 8},
TensorShape{group, 1, 1, kernel, kernel, 8},
TensorShape{});
if (!no_bias) {
args.emplace_back(param, TensorShape{n, group, h, w, 8},
TensorShape{group, 1, 1, kernel, kernel, 8},
TensorShape{1, group, 1, 1, 8});
}
if (!no_full_bias) {
args.emplace_back(
param, TensorShape{n, group, h, w, 8},
TensorShape{group, 1, 1, kernel, kernel, 8},
TensorShape{n, group,
(h + 2 * param.pad_w - kernel) / stride + 1,
(w + 2 * param.pad_w - kernel) / stride + 1,
8});
}
};
std::vector<NLMode> nonlinemode = {NLMode::IDENTITY};
if (!no_nonlinemode) {
nonlinemode.emplace_back(NLMode::RELU);
nonlinemode.emplace_back(NLMode::H_SWISH);
}
for (size_t n : {1, 2}) {
for (auto nlmode : nonlinemode) {
for (bool pad : {true}) {
for (size_t group : {1, 2, 4, 7, 8, 128}) {
for (size_t size : {4, 6, 7, 9, 15, 40}) {
for (size_t kern : kernel) {
pack(n, group, size, size, kern, stride, nlmode,
pad);
}
}
}
}
for (bool pad : {false}) {
for (size_t group : {1, 2, 7, 128}) {
for (size_t size : {7, 9, 15, 40}) {
for (size_t kern : kernel) {
pack(n, group, size, size, kern, stride, nlmode,
pad);
}
}
}
}
}
}
return args;
}
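// For instance (derived from the generator above): n = 1, group = 2, size = 7,
// kern = 3, stride = 1 with padding yields src {1, 2, 7, 7, 8},
// filter {2, 1, 1, 3, 3, 8} and, in the channel-bias case, bias {1, 2, 1, 1, 8}.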
void checker_conv_bias_qint8x8x8(std::vector<conv_bias::TestArg> args,
                                 Handle* handle, const char* algo_name) {
Checker<ConvBias> checker(handle);
...@@ -317,6 +392,26 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16_STR1) {
checker_conv_bias_f16(get_conv_bias_args({2, 3, 5}, 1, false, false, false),
                      handle(), rng, "F16STRD1", 0.03);
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_CHANNEL_WISE_STRIDE1_FP16_NCHW88_1) {
NormalRNG rng(1);
checker_conv_bias_f16(
get_nchw88_channel_wise_args({2, 3}, 1, false, false, false),
handle(), rng, "F16_CHANNEL_WISE_NCHW88", 0.03);
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_CHANNEL_WISE_STRIDE1_FP16_NCHW88_2) {
NormalRNG rng(1);
checker_conv_bias_f16(
get_nchw88_channel_wise_args({5}, 1, false, false, false), handle(),
rng, "F16_CHANNEL_WISE_NCHW88", 0.03);
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_CHANNEL_WISE_STRIDE2_FP16_NCHW88) {
NormalRNG rng(1);
checker_conv_bias_f16(
get_nchw88_channel_wise_args({2, 3, 5}, 2, false, false, false),
handle(), rng, "F16_CHANNEL_WISE_NCHW88", 0.03);
}
#endif
/**********************************algo 8816 direct************************/
...
...@@ -400,6 +400,68 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_DIRECTF16_STR1) {
benchmark_impl(param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}},
               {1, {4}}, data_type);
}
TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_F16_NCHW88) {
constexpr size_t RUNS = 50;
std::string algo_name = "F16_CHANNEL_WISE_NCHW88";
printf("Benchmarker F16_CHANNEL_WISE_NCHW88 algo\n");
std::vector<DType> data_type = {dtype::Float16(), dtype::Float16(),
dtype::Float16(), dtype::Float16()};
auto bench_case = [&](size_t N, size_t IC, size_t H, size_t W, size_t FS,
size_t P, size_t S) {
param::ConvBias param;
param.nonlineMode = param::ConvBias::NonlineMode::RELU;
param.pad_h = P;
param.pad_w = P;
param.stride_h = S;
param.stride_w = S;
param.sparse = param::ConvBias::Sparse::GROUP;
param.format = param::ConvBias::Format::NCHW88;
size_t group = IC;
size_t OC = IC;
SmallVector<TensorShape> shapes{
{N, IC, H, W, 8},
{group, 1, 1, FS, FS, 8},
{1, OC, 1, 1, 8},
{},
{N, OC, (H + 2 * P - FS) / S + 1, (W + 2 * P - FS) / S + 1, 8}};
TensorShape dst{N, OC, (H + 2 * P - FS) / S + 1,
(W + 2 * P - FS) / S + 1, 8};
float computations =
((IC / group) * FS * FS * dst.total_nr_elems() * 2 +
dst.total_nr_elems()) *
1e-6;
std::vector<std::pair<SmallVector<TensorShape>, float>> shape_arg = {
std::make_pair(shapes, computations)};
benchmark_impl(param, shape_arg, algo_name, RUNS, {4, {4, 5, 6, 7}},
{1, {7}}, data_type);
};
bench_case(1, 64, 100, 100, 5, 2, 1);
bench_case(1, 64, 56, 56, 5, 2, 1);
bench_case(1, 64, 28, 28, 5, 2, 1);
bench_case(1, 64, 100, 100, 5, 2, 2);
bench_case(1, 64, 56, 56, 5, 2, 2);
bench_case(1, 64, 28, 28, 5, 2, 2);
bench_case(1, 64, 100, 100, 3, 1, 1);
bench_case(1, 64, 56, 56, 3, 1, 1);
bench_case(1, 64, 28, 28, 3, 1, 1);
bench_case(1, 64, 100, 100, 3, 1, 2);
bench_case(1, 64, 56, 56, 3, 1, 2);
bench_case(1, 64, 28, 28, 3, 1, 2);
bench_case(1, 64, 100, 100, 2, 0, 1);
bench_case(1, 64, 56, 56, 2, 0, 1);
bench_case(1, 64, 28, 28, 2, 0, 1);
bench_case(1, 64, 100, 100, 2, 0, 2);
bench_case(1, 64, 56, 56, 2, 0, 2);
bench_case(1, 64, 28, 28, 2, 0, 2);
}
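// Worked example of the `computations` term above (illustrative): for
// bench_case(1, 64, 28, 28, 3, 1, 2) the output tensor is {1, 64, 14, 14, 8},
// i.e. 100352 elements, so computations = (1 * 3 * 3 * 100352 * 2 + 100352)
// * 1e-6, roughly 1.91 MFLOPs per run.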
#endif
TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS,
       BENCHMARK_CONVBIAS_DIRECT_INT8x8x16) {
...