提交 09ceaaae 编写于 作者: M Megvii Engine Team 提交者: Xu Xinran

fix(dnn/arm): stride1 support for nchw_nchw44 fp32 conv

GitOrigin-RevId: 744c5db3dc3a867d1577f1c870e47472945234f5
上级 50db9b84
......@@ -293,11 +293,11 @@ public:
const NCBKernSizeParam& param) const override;
};
class ConvBiasImpl::AlgoF32DirectStride2NCHWNCHW44 final : public AlgoBase {
class ConvBiasImpl::AlgoF32DirectNCHWNCHW44 final : public AlgoBase {
SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const;
public:
AlgoF32DirectStride2NCHWNCHW44() {}
AlgoF32DirectNCHWNCHW44() {}
bool is_reproducible() const override { return true; }
const char* name() const override { return "F32_CONV_NCHW_NCHW44"; }
bool usable(fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param,
......
/**
* \file
dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw_nchw44_algo.cpp
dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_algo.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
......@@ -13,7 +13,7 @@
#include "megdnn/oprs.h"
#include "src/arm_common/conv_bias/fp32/algos.h"
#include "src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw_nchw44_kern.h"
#include "src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_kern.h"
#include "src/arm_common/conv_bias/fp32/strategy.h"
#include "src/arm_common/elemwise_op.h"
#include "src/common/opr_delegate.h"
......@@ -26,7 +26,7 @@ using conv_fun = std::function<void(
WorkspaceBundle bundle, const ConvBiasImpl::NCBKernParam& kern_param,
const ConvBiasImpl::NCBKernIndex& ncb_index,
const CpuNDRange& workspace_ids, const CpuNDRange& ncb_range)>;
MIDOUT_DECL(megdnn_arm_common_conv_bias_fp32_nchw_nchw44_stride2)
MIDOUT_DECL(megdnn_arm_common_conv_bias_fp32_nchw_nchw44)
namespace {
static inline int block_helper(const int nthread, const int amount,
const int per_unit_bytes) {
......@@ -120,11 +120,10 @@ static void pack_weight(WorkspaceBundle bundle,
kern_param.filter<dt_float32>(group_id) + oc_idx * fh * fw * ic;
auto packed_weight = reinterpret_cast<float*>(bundle.get(1)) +
group_id * oc * ic * fh * fw + oc_idx * ic * fh * fw;
conv_bias::pack_weight_fp32_nchw_nchw44(fptr, packed_weight, oc_block, fh,
fw, ic);
pack_weight_fp32_nchw_nchw44(fptr, packed_weight, oc_block, fh, fw, ic);
}
template <size_t filter, BiasMode bias_mode, typename Op>
template <size_t filter_size, BiasMode bias_mode, typename Op, size_t stride>
static void do_conv_kern(WorkspaceBundle bundle,
const ConvBiasImpl::NCBKernParam& kern_param,
const ConvBiasImpl::NCBKernIndex& ncb_index,
......@@ -137,7 +136,7 @@ static void do_conv_kern(WorkspaceBundle bundle,
const int oc = kern_param.filter_meta.ocpg;
const int ih = kern_param.isz[0];
const int iw = kern_param.isz[1];
const int stride_h = kern_param.filter_meta.stride[0];
const int stride_h = stride;
const int ph = kern_param.filter_meta.padding[0];
const int pw = kern_param.filter_meta.padding[1];
int ih2 = 0;
......@@ -181,21 +180,15 @@ static void do_conv_kern(WorkspaceBundle bundle,
const float* bptr =
kern_param.bias<dt_float32>(batch_id, group_id) + oc_idx;
Op op;
#define KERN1_NCHW44_CONV(filter) \
conv_bias::conv_direct_stride2_##filter##x##filter##_fp32_nchw_nchw44< \
\
bias_mode, Op>(sptr, packed_weight, bptr, nullptr, dst, oc_block, \
ic, ih_real, iw2, oh, oh_block_real, ow, op, ph, \
pw)
DISPATCH_FILTER(filter, KERN1_NCHW44_CONV);
#undef KERN1_NCHW44_CONV
conv_direct_fp32_nchw_nchw44<bias_mode, Op, filter_size, stride>(
sptr, packed_weight, bptr, nullptr, dst, oc_block, ic, ih_real, iw2,
oh, oh_block_real, ow, op, ph, pw);
}
} // namespace
/* ===================== stride2 algo ===================== */
bool ConvBiasImpl::AlgoF32DirectStride2NCHWNCHW44::usable(
bool ConvBiasImpl::AlgoF32DirectNCHWNCHW44::usable(
fallback::ConvBiasImpl*, const NCBKernSizeParam& param,
AlgoSelectionStrategy) const {
auto&& fm = param.filter_meta;
......@@ -209,19 +202,20 @@ bool ConvBiasImpl::AlgoF32DirectStride2NCHWNCHW44::usable(
bool ok_filter = fm.spatial_ndim == 2 && fh == fm.spatial[1] &&
(fh == 2 || fh == 3 || fh == 5 || fh == 7);
bool ok_slide = fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
fm.stride[0] == 2 && fm.stride[1] == 2;
fm.stride[0] == fm.stride[1] &&
(fm.stride[0] == 1 || fm.stride[0] == 2);
bool ok_conv = !fm.should_flip && param.bias_mode != BiasMode::BIAS;
bool avaible = ok_type && ok_src_dst && ok_filter && ok_slide && ok_conv;
return avaible;
}
size_t ConvBiasImpl::AlgoF32DirectStride2NCHWNCHW44::get_workspace(
size_t ConvBiasImpl::AlgoF32DirectNCHWNCHW44::get_workspace(
fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const {
return get_bundle(param).total_size_in_bytes();
}
SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoF32DirectStride2NCHWNCHW44::dispatch_kerns(
ConvBiasImpl::AlgoF32DirectNCHWNCHW44::dispatch_kerns(
fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const {
auto fm = param.filter_meta;
const int batch = param.n;
......@@ -230,61 +224,73 @@ ConvBiasImpl::AlgoF32DirectStride2NCHWNCHW44::dispatch_kerns(
conv_fun do_conv_fun = nullptr;
// NOTE: remain_w is not used to gen hash of midout for compatible with
// shape runtime
#define DO_CONV_KERN_FUN(filter, bias_mode, op) \
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_fp32_nchw_nchw44_stride2, \
midout_iv(#filter #bias_mode #op##_hash)) { \
do_conv_fun = do_conv_kern<filter, bias_mode, op>; \
} \
#define DO_CONV_KERN_FUN(stride, filter, bias_mode, op) \
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_fp32_nchw_nchw44, \
midout_iv(#stride #filter #bias_mode #op##_hash)) { \
do_conv_fun = do_conv_kern<filter, bias_mode, op, stride>; \
} \
MIDOUT_END();
#define GET_OP_PARAM(filter, bias_mode) \
switch (param.nonlineMode) { \
case param::ConvBias::NonlineMode::IDENTITY: \
DO_CONV_KERN_FUN(filter, bias_mode, NoneOp<dt_float32>) \
break; \
case param::ConvBias::NonlineMode::RELU: \
DO_CONV_KERN_FUN(filter, bias_mode, ReluOp<dt_float32>) \
break; \
case param::ConvBias::NonlineMode::H_SWISH: \
DO_CONV_KERN_FUN(filter, bias_mode, HSwishOp<dt_float32>) \
break; \
default: \
megdnn_assert(0); \
break; \
#define GET_OP_PARAM(stride, filter, bias_mode) \
switch (param.nonlineMode) { \
case param::ConvBias::NonlineMode::IDENTITY: \
DO_CONV_KERN_FUN(stride, filter, bias_mode, NoneOp<dt_float32>) \
break; \
case param::ConvBias::NonlineMode::RELU: \
DO_CONV_KERN_FUN(stride, filter, bias_mode, ReluOp<dt_float32>) \
break; \
case param::ConvBias::NonlineMode::H_SWISH: \
DO_CONV_KERN_FUN(stride, filter, bias_mode, HSwishOp<dt_float32>) \
break; \
default: \
megdnn_assert(0); \
break; \
}
#define GET_BIAS_MODE_PARAM(filter) \
switch (param.bias_mode) { \
case BiasMode::NO_BIAS: \
GET_OP_PARAM(filter, BiasMode::NO_BIAS) \
break; \
case BiasMode::BROADCAST_CHANNEL_BIAS: \
GET_OP_PARAM(filter, BiasMode::BROADCAST_CHANNEL_BIAS) \
break; \
default: \
megdnn_assert(0); \
break; \
#define GET_BIAS_MODE_PARAM(stride, filter) \
switch (param.bias_mode) { \
case BiasMode::NO_BIAS: \
GET_OP_PARAM(stride, filter, BiasMode::NO_BIAS) \
break; \
case BiasMode::BROADCAST_CHANNEL_BIAS: \
GET_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS) \
break; \
default: \
megdnn_assert(0); \
break; \
}
#define DISPATCH_CONV_KERN() \
#define DISPATCH_CONV_KERN(stride) \
switch (param.filter_meta.spatial[0]) { \
case 2: \
GET_BIAS_MODE_PARAM(2) \
GET_BIAS_MODE_PARAM(stride, 2) \
break; \
case 3: \
GET_BIAS_MODE_PARAM(3) \
GET_BIAS_MODE_PARAM(stride, 3) \
break; \
case 5: \
GET_BIAS_MODE_PARAM(5) \
GET_BIAS_MODE_PARAM(stride, 5) \
break; \
case 7: \
GET_BIAS_MODE_PARAM(7) \
GET_BIAS_MODE_PARAM(stride, 7) \
break; \
default: \
megdnn_assert(0); \
break; \
}
DISPATCH_CONV_KERN();
switch (param.filter_meta.stride[0]) {
case 1:
DISPATCH_CONV_KERN(1);
break;
case 2:
DISPATCH_CONV_KERN(2);
break;
default:
megdnn_throw(ssprintf("Unsupport stride size %u for the first conv",
param.filter_meta.stride[0])
.c_str());
break;
}
#undef DO_CONV_KERN_FUN
#undef GET_REMAIN_W_PARAM
......
/**
* \file dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw_nchw44_kern.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/opr_impl.h"
#include "src/fallback/conv_bias/common.h"
namespace megdnn {
namespace arm_common {
namespace conv_bias {
#define KERN(stride, i, layout) \
template <BiasMode bias_mode, typename Op> \
void conv_direct_##stride##_##i##x##i##_fp32_nchw_##layout( \
const float* src, const float* filter, const float* bias, \
float* temp, float* dst, const int oc, const int ic, const int ih, \
const int iw, const int oh, const int oh_block, const int ow, \
const Op& op, const int ph, const int pw);
KERN(stride2, 2, nchw44)
KERN(stride2, 3, nchw44)
KERN(stride2, 5, nchw44)
KERN(stride2, 7, nchw44)
#undef KERN
void pack_weight_fp32_nchw_nchw44(const float_t* in_ptr, float_t* dst_ptr,
const int oc, const int kh, const int kw,
const int ic);
} // namespace conv_bias
} // namespace arm_common
} // namespace megdnn
\ No newline at end of file
......@@ -66,7 +66,7 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj {
AlgoDotU8DirectStride2 du8_direct_stride2_small_group{false};
#endif
AlgoF32DirectStride2NCHWNCHW44 f32_direct_stride2_nchw_nchw44;
AlgoF32DirectNCHWNCHW44 f32_direct_stride2_nchw_nchw44;
AlgoF32ChannelWiseNCHW44 f32_chanel_wise_nchw44;
AlgoF32DirectNCHW44 f32_direct_nchw44;
......
......@@ -71,7 +71,7 @@ private:
class AlgoF32Direct;
class AlgoF32DirectStride1;
class AlgoF32DirectStride2;
class AlgoF32DirectStride2NCHWNCHW44;
class AlgoF32DirectNCHWNCHW44;
class AlgoF32ChannelWiseNCHW44;
class AlgoF32DirectNCHW44;
......
......@@ -204,6 +204,11 @@ static void benchmark_convbias(Handle* handle, std::string int_name,
run(1, 3, 32, 224, 224, 3, 2, true);
run(1, 3, 64, 224, 224, 7, 2, true);
run(1, 1, 4, 112, 112, 2, 1, true);
run(1, 3, 32, 224, 224, 3, 1, true);
run(1, 3, 64, 224, 224, 3, 1, true);
run(1, 3, 64, 224, 224, 7, 1, true);
run(1, 64, 128, 56, 56, 3, 2, false);
run(1, 128, 256, 28, 28, 3, 2, false);
run(1, 256, 512, 14, 14, 3, 2, false);
......
......@@ -392,6 +392,9 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_NCHW_NCHW44_F32) {
check_conv_bias(get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, false,
false, true),
handle(), "F32_CONV_NCHW_NCHW44");
check_conv_bias(get_nchw44_conv_bias_args({2, 3, 5, 7}, 1, false, false,
false, true),
handle(), "F32_CONV_NCHW_NCHW44");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_CHANNEL_WISE_STRIDE1_FP32_NCHW44_1) {
check_conv_bias(
......@@ -824,13 +827,14 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_PREPROCESS_NCHW44) {
auto conv_bias_opr = handle->create_operator<ConvBias>();
conv_bias_opr->param() = param;
conv_bias_opr->param().format = param::ConvBias::Format::NCHW44_WINOGRAD;
conv_bias_opr->param().format =
param::ConvBias::Format::NCHW44_WINOGRAD;
conv_bias_opr->param().output_block_size = m;
size_t conv_bias_workspace_in_bytes =
conv_bias_opr->get_workspace_in_bytes(
tensors[0].layout, filter_transform_layout,
tensors[2].layout, tensors[3].layout,
tensors[4].layout, nullptr);
tensors[2].layout, tensors[3].layout, tensors[4].layout,
nullptr);
WorkspaceBundle wb(nullptr, {filter_transform_layout.span().dist_byte(),
conv_bias_workspace_in_bytes,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册