提交 682c74df 编写于 作者: M Megvii Engine Team

feat(dnn): add direct nchw88 fp16 conv

GitOrigin-RevId: 44719e8b6436ca0aa22eabc89a7a8eb39b5f857a
上级 fca19535
......@@ -153,6 +153,27 @@ public:
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_CHWNWISE_NCHW88_F16)
};
class ConvBiasImpl::AlgoF16DirectNCHW88 final : public AlgoBase {
SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const;
public:
AlgoF16DirectNCHW88() {}
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "F16_CONV_NCHW88_DIRECT"; }
bool usable(const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;
size_t get_workspace(const NCBKernSizeParam& param) const override;
virtual SmallVector<NCBKern> dispatch_kerns(
const NCBKernSizeParam& param) const override;
ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::FLOAT16, AlgoCategory::DIRECT};
}
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_NCHW88_FP16)
};
} // namespace arm_common
} // namespace megdnn
#endif
......
/**
* \file dnn/src/arm_common/conv_bias/f16/direct_nchw88_algo.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express
* or implied.
*/
#include "megdnn/oprs.h"
#include "src/arm_common/conv_bias/block_helper.h"
#include "src/arm_common/conv_bias/f16/algos.h"
#include "src/arm_common/conv_bias/f16/direct_nchw88_kern.h"
#include "src/arm_common/elemwise_op.h"
#include "midout.h"
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
using namespace megdnn;
using namespace arm_common;
using conv_fun =
std::function<void(const WorkspaceBundle& bundle,
const ConvBiasImpl::NCBKernParam& kern_param,
const ConvBiasImpl::NCBKernIndex& ncb_index,
const CpuNDRange& workspace_ids)>;
MIDOUT_DECL(megdnn_arm_common_conv_bias_fp16_nchw88)
namespace {
static WorkspaceBundle get_bundle(const ConvBiasImpl::NCBKernSizeParam& param) {
auto&& fm = param.filter_meta;
size_t nr_threads = param.nr_threads;
size_t IC = fm.icpg / 8;
size_t PH = fm.padding[0];
size_t PW = fm.padding[1];
size_t IH2 = param.isz[0] + 2 * PH;
size_t IW2 = param.isz[1] + 2 * PW;
if (PH == 0 && PW == 0) {
return {nullptr, {}};
}
size_t s = (nr_threads * IC * IH2 * IW2 * 8) * sizeof(dt_float16);
return {nullptr, {s}};
}
void copy_padding_kern(const WorkspaceBundle& bundle,
const ConvBiasImpl::NCBKernParam& kern_param,
const ConvBiasImpl::NCBKernIndex& ncb_index,
const CpuNDRange& workspace_ids) {
auto fm = kern_param.filter_meta;
size_t group = fm.group;
size_t IH = kern_param.isz[0];
size_t IW = kern_param.isz[1];
size_t IC = fm.icpg / 8;
size_t PH = fm.padding[0];
size_t PW = fm.padding[1];
size_t IH2 = IH + 2 * PH;
size_t IW2 = IW + 2 * PW;
if (PH == 0 && PW == 0) {
return;
}
//! Used for get the workspace offset
size_t workspace_group_id = workspace_ids[0];
size_t workspace_batch_id = workspace_ids[1];
size_t channel_id = workspace_ids[2];
size_t group_id = ncb_index.ndrange_id[0];
size_t batch_id = ncb_index.ndrange_id[1];
const dt_float16* sptr =
kern_param.src<dt_float16>(batch_id, group_id, channel_id, 1, 8);
//! copy to sptr_base to eliminate padding effect
dt_float16* sptr_base = static_cast<dt_float16*>(bundle.get(0)) +
workspace_batch_id * group * IC * IH2 * IW2 * 8 +
workspace_group_id * IC * IH2 * IW2 * 8 +
channel_id * IH2 * IW2 * 8;
std::memset(sptr_base, 0, IH2 * IW2 * 8 * sizeof(dt_float16));
rep(ih, IH) {
std::memcpy(sptr_base + (ih + PH) * IW2 * 8 + PW * 8,
sptr + ih * IW * 8, IW * 8 * sizeof(dt_float16));
}
};
template <size_t FH, size_t SH, BiasMode bias_mode, typename Op>
static void do_conv_kern(const WorkspaceBundle& bundle,
const ConvBiasImpl::NCBKernParam& kern_param,
const ConvBiasImpl::NCBKernIndex& ncb_index,
const CpuNDRange& workspace_ids) {
auto fm = kern_param.filter_meta;
size_t group = fm.group;
size_t OH = kern_param.osz[0];
size_t OW = kern_param.osz[1];
size_t FW = FH;
size_t IC = fm.icpg / 8;
size_t PH = fm.padding[0];
size_t PW = fm.padding[1];
size_t IH2 = kern_param.isz[0] + 2 * PH;
size_t IW2 = kern_param.isz[1] + 2 * PW;
size_t group_id = ncb_index.ndrange_id[0];
size_t batch_id = ncb_index.ndrange_id[1];
size_t channel_id = workspace_ids[2];
//! Used for get the workspace offset
size_t workspace_batch_id = workspace_ids[1];
size_t workspace_group_id = workspace_ids[0];
const __fp16* sptr = nullptr;
if (PH == 0 && PW == 0) {
sptr = reinterpret_cast<const __fp16*>(
kern_param.src<dt_float16>(batch_id, group_id));
} else {
sptr = reinterpret_cast<const __fp16*>(
static_cast<const dt_float16*>(bundle.get(0))) +
workspace_batch_id * group * IC * IH2 * IW2 * 8 +
workspace_group_id * IC * IH2 * IW2 * 8;
}
const __fp16* filter = reinterpret_cast<const __fp16*>(
kern_param.filter<dt_float16>(group_id, 1)) +
channel_id * IC * FH * FW * 8 * 8;
const __fp16* bias_ptr = reinterpret_cast<const __fp16*>(
kern_param.bias<dt_float16>(batch_id, group_id, channel_id, 1, 8));
__fp16* dptr = reinterpret_cast<__fp16*>(
kern_param.dst<dt_float16>(batch_id, group_id, channel_id, 1, 8));
conv_bias::conv_direct_fp16_nchw88<FH, SH, bias_mode, Op>(
sptr, filter, bias_ptr, dptr, IC, IH2, IW2, OH, OW);
}
} // namespace
/* ===================== stride1 algo ===================== */
bool ConvBiasImpl::AlgoF16DirectNCHW88::usable(const NCBKernSizeParam& param,
AlgoSelectionStrategy) const {
auto&& fm = param.filter_meta;
auto fh = fm.spatial[0];
int oc = fm.ocpg;
int ic = fm.icpg;
bool ok_type = ((param.src_type.enumv() == DTypeEnum::Float16 &&
param.filter_type.enumv() == DTypeEnum::Float16 &&
(param.dst_type.enumv() == DTypeEnum::Float16))) &&
(fm.format == param::Convolution::Format::NCHW88);
bool ok_src_dst = (oc % 8 == 0 && oc >= 8 && ic % 8 == 0 && ic >= 8);
bool ok_filter = fm.spatial_ndim == 2 && fh == fm.spatial[1] &&
(fh == 1 || fh == 2 || fh == 3 || fh == 5 || fh == 7);
bool ok_slide = fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
((fm.stride[0] == 1 && fm.stride[1] == 1) ||
(fm.stride[0] == 2 && fm.stride[1] == 2));
bool ok_conv = !fm.should_flip;
bool ok_comp = param.compute_mode == Param::ComputeMode::DEFAULT;
return ok_type && ok_src_dst && ok_filter && ok_slide && ok_conv && ok_comp;
}
size_t ConvBiasImpl::AlgoF16DirectNCHW88::get_workspace(
const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_fp16_nchw88_stride1,
midout_iv("AlgoF16DirectNCHW88::get_workspace"_hash)) {
return get_bundle(param).total_size_in_bytes();
}
MIDOUT_END();
return 0;
}
SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoF16DirectNCHW88::dispatch_kerns(
const NCBKernSizeParam& param) const {
auto fm = param.filter_meta;
size_t batch = param.n;
size_t group = fm.group;
WorkspaceBundle wbundle = get_bundle(param);
conv_fun do_conv_fun = nullptr;
// NOTE: remain_w is not used to gen hash of midout for compatible with
// shape runtime
#define DO_CONV_KERN_FUN(filter, bias_mode, op, stride) \
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_fp16_nchw88, \
midout_iv(#filter #bias_mode #stride #op##_hash)) { \
do_conv_fun = do_conv_kern<filter, stride, bias_mode, op>; \
} \
MIDOUT_END();
#define GET_STRIDE_PARAM(filter, bias_mode, op) \
switch (fm.stride[0]) { \
case 1: \
DO_CONV_KERN_FUN(filter, bias_mode, op, 1); \
break; \
case 2: \
DO_CONV_KERN_FUN(filter, bias_mode, op, 2); \
break; \
\
default: \
megdnn_assert(0, "stride not supported"); \
}
#define GET_OP_PARAM(filter, bias_mode) \
switch (param.nonlineMode) { \
case param::ConvBias::NonlineMode::IDENTITY: \
GET_STRIDE_PARAM(filter, bias_mode, NoneOp<__fp16>) \
break; \
case param::ConvBias::NonlineMode::RELU: \
GET_STRIDE_PARAM(filter, bias_mode, ReluOp<__fp16>) \
break; \
case param::ConvBias::NonlineMode::H_SWISH: \
GET_STRIDE_PARAM(filter, bias_mode, HSwishOp<__fp16>) \
break; \
case param::ConvBias::NonlineMode::SIGMOID: \
GET_STRIDE_PARAM(filter, bias_mode, SigmoidOp<__fp16>) \
break; \
default: \
megdnn_assert(0, "nonline not supported"); \
break; \
}
#define GET_BIAS_MODE_PARAM(filter) \
switch (param.bias_mode) { \
case BiasMode::NO_BIAS: \
GET_OP_PARAM(filter, BiasMode::NO_BIAS) \
break; \
case BiasMode::BROADCAST_CHANNEL_BIAS: \
GET_OP_PARAM(filter, BiasMode::BROADCAST_CHANNEL_BIAS) \
break; \
case BiasMode::BIAS: \
GET_OP_PARAM(filter, BiasMode::BIAS) \
break; \
default: \
megdnn_assert(0, "bias_mode not supported"); \
break; \
}
#define DISPATCH_CONV_KERN() \
switch (param.filter_meta.spatial[0]) { \
case 1: \
GET_BIAS_MODE_PARAM(1) \
break; \
case 2: \
GET_BIAS_MODE_PARAM(2) \
break; \
case 3: \
GET_BIAS_MODE_PARAM(3) \
break; \
case 5: \
GET_BIAS_MODE_PARAM(5) \
break; \
case 7: \
GET_BIAS_MODE_PARAM(7) \
break; \
default: \
megdnn_assert(0, "filter not supported"); \
break; \
}
DISPATCH_CONV_KERN();
#undef DO_CONV_KERN_FUN
#undef GET_REMAIN_W_PARAM
#undef GET_OP_PARAM
#undef GET_BIAS_MODE_PARAM
#undef DISPATCH_CONV_KERN
megdnn_assert(do_conv_fun);
WorkspaceBundle bundle = get_bundle(param);
SmallVector<ConvBiasImpl::NCBKern> ret_kerns;
auto exec_one_group = [bundle, do_conv_fun](
const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index) mutable {
auto fm = kern_param.filter_meta;
size_t IC = fm.icpg / 8;
size_t OC = fm.ocpg / 8;
bundle.set(kern_param.workspace_ptr);
for (size_t ic = 0; ic < IC; ic++) {
copy_padding_kern(bundle, kern_param, ncb_index,
{ncb_index.thread_id, 0, ic});
}
for (size_t oc = 0; oc < OC; oc++) {
do_conv_fun(bundle, kern_param, ncb_index,
{ncb_index.thread_id, 0, oc});
}
};
// TODO: large group only, further multithread optimization required
ret_kerns.push_back({exec_one_group, {group, batch, 1_z}});
return ret_kerns;
}
#endif
// vim: syntax=cpp.doxygen
/**
* \file
* dnn/src/arm_common/conv_bias/f16/direct_nchw88_kern.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/f16/direct_nchw88_kern.h"
#include "src/arm_common/conv_bias/opr_impl.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/fallback/conv_bias/common.h"
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
using namespace megdnn;
using namespace arm_common;
template <int PC, int BW, int pc, int bw>
struct compute_fma {
static inline void call(const float16x8_t* ri, const float16x8_t* rf,
float16x8_t* rdst) {
#if defined(__aarch64__)
rdst[bw] = vfmaq_laneq_f16(rdst[bw], rf[pc], ri[bw], pc);
#else
rdst[bw] = vfmaq_f16(rdst[bw], rf[pc],
vdupq_n_f16(vgetq_lane_f16(ri[bw], pc)));
#endif
compute_fma<PC, BW, pc, bw + 1>::call(ri, rf, rdst);
}
};
template <int PC, int BW, int pc>
struct compute_fma<PC, BW, pc, BW> {
static inline void call(const float16x8_t* ri, const float16x8_t* rf,
float16x8_t* rdst) {
compute_fma<PC, BW, pc + 1, 0>::call(ri, rf, rdst);
}
};
template <int PC, int BW>
struct compute_fma<PC, BW, PC, 0> {
static inline void call(const float16x8_t* ri, const float16x8_t* rf,
float16x8_t* rdst) {}
};
template <int PC, int BW, int bw>
struct load_dst {
static inline void call(float16x8_t* rdst, const float16_t* dst_ptr) {
rdst[bw] = vld1q_f16(dst_ptr + bw * PC);
load_dst<PC, BW, bw + 1>::call(rdst, dst_ptr);
}
};
template <int PC, int BW>
struct load_dst<PC, BW, BW> {
static inline void call(float16x8_t* rdst, const float16_t* dst_ptr) {}
};
template <int PC, int SW, int BW, int bw>
struct load_src {
static inline void call(float16x8_t* ri, const float16_t* src_ptr) {
ri[bw] = vld1q_f16(src_ptr + bw * SW * PC);
load_src<PC, SW, BW, bw + 1>::call(ri, src_ptr);
}
};
template <int PC, int SW, int BW>
struct load_src<PC, SW, BW, BW> {
static inline void call(float16x8_t* ri, const float16_t* src_ptr) {}
};
template <int PC, int pc>
struct load_filter {
static inline void call(float16x8_t* rf, const float16_t* filter_ptr) {
rf[pc] = vld1q_f16(filter_ptr + pc * PC);
load_filter<PC, pc + 1>::call(rf, filter_ptr);
}
};
template <int PC>
struct load_filter<PC, PC> {
static inline void call(float16x8_t* rf, const float16_t* filter_ptr) {}
};
template <int PC, int BW, int bw>
struct store_dst {
static inline void call(const float16x8_t* rdst, float16_t* dst_ptr) {
vst1q_f16(dst_ptr + bw * PC, rdst[bw]);
store_dst<PC, BW, bw + 1>::call(rdst, dst_ptr);
}
};
template <int PC, int BW>
struct store_dst<PC, BW, BW> {
static inline void call(const float16x8_t* rdst, float16_t* dst_ptr) {}
};
template <int FH, int SH, int BW>
static inline void do_conv_kern_1xBW(const float16_t*& src, float16_t*& dst,
const float16_t* filter, int IW, int OW,
int& ow) {
constexpr int PC = 8;
constexpr int FW = FH;
constexpr int SW = SH;
float16x8_t rf[PC];
if (FH == 1 && FW == 1) {
load_filter<PC, 0>::call(rf, filter);
}
for (; ow + BW - 1 < OW; ow += BW) {
float16x8_t rdst[BW];
load_dst<PC, BW, 0>::call(rdst, dst);
for (int fh = 0; fh < FH; ++fh) {
for (int fw = 0; fw < FW; ++fw) {
float16x8_t ri[BW];
load_src<PC, SW, BW, 0>::call(ri, src + (fh * IW + fw) * PC);
if (FH > 1 || FW > 1) {
load_filter<PC, 0>::call(rf,
filter + (fh * FW + fw) * PC * PC);
}
compute_fma<PC, BW, 0, 0>::call(ri, rf, rdst);
}
}
store_dst<PC, BW, 0>::call(rdst, dst);
src += SW * BW * PC;
dst += BW * PC;
}
}
template <BiasMode bias_mode>
static void do_load_bias_kern(float16_t* dst, const float16_t* bias, int OH,
int OW) {
constexpr int PC = 8;
if (bias_mode == BiasMode::NO_BIAS) {
memset(dst, 0, OH * OW * PC * sizeof(float16_t));
} else if (bias_mode == BiasMode::BIAS) {
memcpy(dst, bias, OH * OW * PC * sizeof(float16_t));
} else if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) {
float16x8_t bias_v = vld1q_f16(bias);
int i = 0;
for (; i + 3 < OH * OW; i += 4) {
vst1q_f16(dst + PC * 0, bias_v);
vst1q_f16(dst + PC * 1, bias_v);
vst1q_f16(dst + PC * 2, bias_v);
vst1q_f16(dst + PC * 3, bias_v);
dst += PC * 4;
}
for (; i < OH * OW; i += 1) {
vst1q_f16(dst, bias_v);
dst += PC;
}
}
}
template <typename Op>
static void do_op_kern(float16_t* dst, int OH, int OW) {
constexpr int PC = 8;
Op op;
int i = 0;
for (; i + 3 < OH * OW; i += 4) {
float16x8_t dst0 = vld1q_f16(dst + PC * 0);
float16x8_t dst1 = vld1q_f16(dst + PC * 1);
float16x8_t dst2 = vld1q_f16(dst + PC * 2);
float16x8_t dst3 = vld1q_f16(dst + PC * 3);
dst0 = op(dst0);
dst1 = op(dst1);
dst2 = op(dst2);
dst3 = op(dst3);
vst1q_f16(dst + PC * 0, dst0);
vst1q_f16(dst + PC * 1, dst1);
vst1q_f16(dst + PC * 2, dst2);
vst1q_f16(dst + PC * 3, dst3);
dst += PC * 4;
}
for (; i < OH * OW; i += 1) {
vst1q_f16(dst, op(vld1q_f16(dst)));
dst += PC;
}
}
template <int FH, int SH>
static void do_conv_kern(const float16_t* src, float16_t* dst,
const float16_t* filter, int IC, int IH, int IW,
int OH, int OW) {
constexpr int PC = 8;
constexpr int FW = FH;
for (int ic = 0; ic < IC; ic += 1) {
const float16_t* src_ptr_h = src;
float16_t* dst_ptr_h = dst;
for (int oh = 0; oh < OH; oh += 1) {
const float16_t* src_ptr_w = src_ptr_h;
float16_t* dst_ptr_w = dst_ptr_h;
int ow = 0;
do_conv_kern_1xBW<FH, SH, 4>(src_ptr_w, dst_ptr_w, filter, IW, OW,
ow);
if (OW & 3) {
do_conv_kern_1xBW<FH, SH, 2>(src_ptr_w, dst_ptr_w, filter, IW,
OW, ow);
do_conv_kern_1xBW<FH, SH, 1>(src_ptr_w, dst_ptr_w, filter, IW,
OW, ow);
}
src_ptr_h += SH * IW * PC;
dst_ptr_h += OW * PC;
}
src += IH * IW * PC;
filter += FH * FW * PC * PC;
}
}
static void do_conv_kern_1x1(const float16_t* src, float16_t* dst,
const float16_t* filter, int IC, int OH, int OW) {
constexpr int PC = 8;
const int IH = OH;
const int IW = OW;
const int IHW = IH * IW;
const int OHW = OH * OW;
for (int ic = 0; ic < IC; ic += 1) {
const float16_t* src_ptr_hw = src;
float16_t* dst_ptr_hw = dst;
int ohw = 0;
do_conv_kern_1xBW<1, 1, 8>(src_ptr_hw, dst_ptr_hw, filter, IHW, OHW,
ohw);
do_conv_kern_1xBW<1, 1, 4>(src_ptr_hw, dst_ptr_hw, filter, IHW, OHW,
ohw);
do_conv_kern_1xBW<1, 1, 1>(src_ptr_hw, dst_ptr_hw, filter, IHW, OHW,
ohw);
src += IHW * PC;
filter += PC * PC;
}
}
template <size_t FH, size_t SH, BiasMode bias_mode, typename Op>
void conv_bias::conv_direct_fp16_nchw88(const __fp16* src, const __fp16* filter,
const __fp16* bias, __fp16* dst, int IC,
int IH, int IW, int OH, int OW) {
do_load_bias_kern<bias_mode>(dst, bias, OH, OW);
if (FH == 1 && SH == 1 && IH == OH && IW == OW) {
do_conv_kern_1x1(src, dst, filter, IC, OH, OW);
} else {
do_conv_kern<FH, SH>(src, dst, filter, IC, IH, IW, OH, OW);
}
do_op_kern<Op>(dst, OH, OW);
}
#define INSTANTIATION(stride, filter, bias, Op) \
template void \
conv_bias::conv_direct_fp16_nchw88<filter, stride, bias, Op>( \
const __fp16*, const __fp16*, const __fp16*, __fp16*, int, int, \
int, int, int);
#define FOR_OP(stride, filter, bias) \
INSTANTIATION(stride, filter, bias, SigmoidOp<__fp16>) \
INSTANTIATION(stride, filter, bias, ReluOp<__fp16>) \
INSTANTIATION(stride, filter, bias, HSwishOp<__fp16>) \
INSTANTIATION(stride, filter, bias, NoneOp<__fp16>)
#define FOR_BIAS(stride, filter) \
FOR_OP(stride, filter, BiasMode::NO_BIAS) \
FOR_OP(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS) \
FOR_OP(stride, filter, BiasMode::BIAS)
#define FOR_FILTER(stride) \
FOR_BIAS(stride, 1) \
FOR_BIAS(stride, 2) \
FOR_BIAS(stride, 3) \
FOR_BIAS(stride, 5) \
FOR_BIAS(stride, 7)
#define FOR_STRIDE \
FOR_FILTER(1) \
FOR_FILTER(2)
FOR_STRIDE
#undef FOR_STRIDE
#undef FOR_FILTER
#undef FOR_BIAS
#undef FOR_OP
#undef INSTANTIATION
#endif
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/arm_common/conv_bias/f16/direct_nchw88_kern.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "src/arm_common/conv_bias/opr_impl.h"
#include "src/fallback/conv_bias/common.h"
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
namespace megdnn {
namespace arm_common {
namespace conv_bias {
template <size_t FH, size_t SH, BiasMode bias_mode, typename Op>
void conv_direct_fp16_nchw88(const __fp16* src, const __fp16* filter,
const __fp16* bias, __fp16* dst, int IC, int IH,
int IW, int OH, int OW);
} // namespace conv_bias
} // namespace arm_common
} // namespace megdnn
#endif
......@@ -86,6 +86,7 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj {
AlgoF16Direct f16_direct;
AlgoF16DirectStride1 f16_direct_stride1;
AlgoF16ChannelWiseNCHW88 f16_channel_wise_nchw88;
AlgoF16DirectNCHW88 f16_direct_nchw88;
#endif
SmallVector<std::unique_ptr<AlgoBase>> refhold;
......@@ -121,6 +122,7 @@ public:
m_direct_algos.emplace_back(&f16_direct_stride1);
m_direct_algos.emplace_back(&f16_direct);
m_direct_algos.emplace_back(&f16_channel_wise_nchw88);
m_direct_algos.emplace_back(&f16_direct_nchw88);
#endif
m_direct_algos.emplace_back(&i8x8x16_direct);
m_direct_algos.emplace_back(&i8x8x16_stride2_filter2);
......@@ -252,7 +254,6 @@ public:
}
}
for (auto&& algo : m_direct_algos) {
m_all_algos_map.emplace(algo->info().desc, algo);
}
......@@ -261,8 +262,7 @@ public:
}
}
const SmallVector<fallback::ConvBiasImpl::AlgoBase*>& direct_algos()
const {
const SmallVector<fallback::ConvBiasImpl::AlgoBase*>& direct_algos() const {
return m_direct_algos;
}
const SmallVector<fallback::ConvBiasImpl::AlgoBase*>& winograd_algos()
......
......@@ -10,9 +10,9 @@
* implied.
*/
#pragma once
#include "src/common/algo_base.h"
#include "src/common/utils.h"
#include "src/fallback/conv_bias/opr_impl.h"
#include "src/common/algo_base.h"
namespace megdnn {
namespace arm_common {
......@@ -28,7 +28,8 @@ public:
}
};
SmallVector<fallback::ConvBiasImpl::AlgoBase*> get_all_packed_algo() override;
SmallVector<fallback::ConvBiasImpl::AlgoBase*> get_all_packed_algo()
override;
bool is_matmul_quantized_prefer(
const fallback::ConvBiasImpl::NCBKernSizeParam& ncb_param)
......@@ -97,6 +98,7 @@ private:
class AlgoF16Direct;
class AlgoF16DirectStride1;
class AlgoF16ChannelWiseNCHW88;
class AlgoF16DirectNCHW88;
#endif
class AlgoPack;
......
......@@ -56,8 +56,7 @@ public:
bool is_thread_safe() const override { return true; }
void exec_preprocess(const TensorLayout& src_layout,
_megdnn_tensor_in filter,
_megdnn_tensor_in bias,
_megdnn_tensor_in filter, _megdnn_tensor_in bias,
const TensorLayout& z_layout,
const TensorLayout& dst_layout,
PreprocessedFilter* preprocessed_filter,
......@@ -243,6 +242,7 @@ public:
ARM_COMMON_DIRECT_FP16,
ARM_COMMON_DIRECT_STRD1_FP16,
ARM_COMMON_CHWNWISE_NCHW88_F16,
ARM_COMMON_DIRECT_NCHW88_FP16,
ARM_COMMON_WINOGRAD_F23_4X4_FP32,
ARM_COMMON_WINOGRAD_F63_FP32,
ARM_COMMON_WINOGRAD_F63_4X4_FP32,
......@@ -288,7 +288,7 @@ public:
#else
ARMV7_MATMUL_S8,
ARMV7_MATMUL_QU8,
#endif // MEGDNN_AARCH64
#endif // MEGDNN_AARCH64
#endif
};
......
......@@ -124,8 +124,8 @@ std::vector<conv_bias::TestArg> get_nchw44_channel_wise_args(
for (size_t n : {1, 2}) {
for (auto nlmode : nonlinemode) {
for (bool pad : {true}) {
for (size_t group : {1, 2, 4, 7, 128}) {
for (size_t size : {4, 6, 7, 9, 15, 40}) {
for (size_t group : {1, 2, 4, 7, 16}) {
for (size_t size : {4, 6, 7, 9, 20}) {
for (size_t kern : kernel) {
pack(n, group, size, size, kern, stride, nlmode,
pad);
......@@ -134,8 +134,8 @@ std::vector<conv_bias::TestArg> get_nchw44_channel_wise_args(
}
}
for (bool pad : {false}) {
for (size_t group : {1, 2, 7, 128}) {
for (size_t size : {7, 9, 15, 40}) {
for (size_t group : {1, 2, 7, 16}) {
for (size_t size : {7, 9, 20}) {
for (size_t kern : kernel) {
pack(n, group, size, size, kern, stride, nlmode,
pad);
......@@ -199,8 +199,8 @@ std::vector<conv_bias::TestArg> get_nchw88_channel_wise_args(
for (size_t n : {1, 2}) {
for (auto nlmode : nonlinemode) {
for (bool pad : {true}) {
for (size_t group : {1, 2, 4, 7, 8, 128}) {
for (size_t size : {4, 6, 7, 9, 15, 40}) {
for (size_t group : {1, 2, 4, 7, 8, 16}) {
for (size_t size : {4, 6, 7, 9, 20}) {
for (size_t kern : kernel) {
pack(n, group, size, size, kern, stride, nlmode,
pad);
......@@ -209,8 +209,8 @@ std::vector<conv_bias::TestArg> get_nchw88_channel_wise_args(
}
}
for (bool pad : {false}) {
for (size_t group : {1, 2, 7, 128}) {
for (size_t size : {7, 9, 15, 40}) {
for (size_t group : {1, 2, 7, 16}) {
for (size_t size : {7, 9, 20}) {
for (size_t kern : kernel) {
pack(n, group, size, size, kern, stride, nlmode,
pad);
......@@ -412,6 +412,23 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_CHANNEL_WISE_STRIDE2_FP16_NCHW88) {
get_nchw88_channel_wise_args({2, 3, 5}, 2, false, false, false),
handle(), rng, "F16_CHANNEL_WISE_NCHW88", 0.03);
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16_NCHW88_S1) {
NormalRNG rng(1);
checker_conv_bias_f16(
get_nchw88_conv_bias_args({1, 2, 3, 5, 7}, FULL_NLMODE,
ALL_BIASMODE, 1),
handle(), rng, "F16_CONV_NCHW88_DIRECT", 0.03);
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16_NCHW88_S2) {
NormalRNG rng(1);
checker_conv_bias_f16(
get_nchw88_conv_bias_args({1, 2, 3, 5, 7}, FULL_NLMODE,
ALL_BIASMODE, 2),
handle(), rng, "F16_CONV_NCHW88_DIRECT", 0.03);
}
#endif
/**********************************algo 8816 direct************************/
......@@ -794,8 +811,6 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F63) {
check_winograd("1:6:32", checker, args);
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F63_4) {
using namespace conv_bias;
std::vector<TestArg> args = get_winograd_mk_packed_args();
......@@ -804,19 +819,15 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F63_4) {
check_winograd("4:6:16", checker, args, param::MatrixMul::Format::MK4);
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F63_4_NCHW44) {
using namespace conv_bias;
std::vector<TestArg> args =
get_nchw44_conv_bias_args({3},QUAN_NLMODE,BR_AND_NO_BIASMODE,1);
get_nchw44_conv_bias_args({3}, QUAN_NLMODE, BR_AND_NO_BIASMODE, 1);
Checker<ConvBiasForward> checker(handle());
check_winograd("4:6:16", checker, args, param::MatrixMul::Format::MK4,
param::ConvBias::Format::NCHW44);
}
//! uncomment it when low precision mode is ok
#if 0
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F73_4_NCHW44) {
......@@ -847,8 +858,6 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F54) {
check_winograd("1:5:32", checker, args);
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F45) {
using namespace conv_bias;
std::vector<TestArg> args = get_winograd_args(5);
......@@ -971,18 +980,17 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8) {
using namespace conv_bias;
Checker<ConvBiasForward> checker(handle());
auto run = [&checker](const std::vector<TestArg>& args,
DType A_dtype,
auto run = [&checker](const std::vector<TestArg>& args, DType A_dtype,
DType B_dtype, DType C_dtype, DType D_dtype,
float eps) {
for (auto&& arg : args) {
checker.set_dtype(0, A_dtype)
.set_dtype(1, B_dtype)
.set_dtype(2, C_dtype)
.set_dtype(4, D_dtype)
.set_epsilon(eps)
.set_param(arg.param)
.execs({arg.src, arg.filter, arg.bias, {}, {}});
checker.set_dtype(0, A_dtype)
.set_dtype(1, B_dtype)
.set_dtype(2, C_dtype)
.set_dtype(4, D_dtype)
.set_epsilon(eps)
.set_param(arg.param)
.execs({arg.src, arg.filter, arg.bias, {}, {}});
}
};
......@@ -997,9 +1005,8 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8) {
std::vector<TestArg> quantized_args = get_int8_nchw44_args(3, 4);
UniformIntRNG int_rng{-50, 50};
checker.set_rng(0, &int_rng).set_rng(1, &int_rng).set_rng(2, &int_rng);
run(quantized_args, dtype::QuantizedS8(2.5f),
dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f),
dtype::QuantizedS8(60.25f),1e-3);
run(quantized_args, dtype::QuantizedS8(2.5f), dtype::QuantizedS8(2.5f),
dtype::QuantizedS32(6.25f), dtype::QuantizedS8(60.25f), 1e-3);
}
TEST_F(ARM_COMMON_MULTI_THREADS,
......
......@@ -400,7 +400,8 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_DIRECTF16_STR1) {
benchmark_impl(param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}},
{1, {4}}, data_type);
}
TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_F16_NCHW88) {
TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CHANNEL_WISE_FP16_NCHW88) {
constexpr size_t RUNS = 50;
std::string algo_name = "F16_CHANNEL_WISE_NCHW88";
......@@ -462,6 +463,64 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_F16_NCHW88) {
bench_case(1, 64, 28, 28, 2, 0, 2);
}
TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_FP16_NCHW88) {
constexpr size_t RUNS = 40;
std::vector<DType> data_type = {dtype::Float16(), dtype::Float16(),
dtype::Float16(), dtype::Float16()};
auto bench_case = [&](size_t N, size_t IC, size_t OC, size_t H, size_t W,
size_t FS, size_t group, size_t P, size_t S) {
param::ConvBias param;
param.nonlineMode = param::ConvBias::NonlineMode::RELU;
param.pad_h = P;
param.pad_w = P;
param.stride_h = S;
param.stride_w = S;
param.sparse = param::ConvBias::Sparse::DENSE;
param.format = param::ConvBias::Format::NCHW88;
auto OH = (H + 2 * P - FS) / static_cast<size_t>(S) + 1;
auto OW = (W + 2 * P - FS) / static_cast<size_t>(S) + 1;
TensorShape src = {N, IC / 8, H, W, 8};
TensorShape filter = {OC / 8, IC / 8, FS, FS, 8, 8};
if (group > 1) {
filter = {group, OC / group / 8, IC / group / 8, FS, FS, 8, 8};
param.sparse = param::ConvBias::Sparse::GROUP;
}
TensorShape bias = {1, OC / 8, 1, 1, 8};
TensorShape dst = {N, OC / 8, OH, OW, 8};
SmallVector<TensorShape> shapes{src, filter, bias, {}, dst};
float computations =
(((IC / group) * FS * FS + 1) * dst.total_nr_elems() * 2 +
dst.total_nr_elems()) *
1e-6;
std::vector<std::pair<SmallVector<TensorShape>, float>> shape_arg = {
std::make_pair(shapes, computations)};
benchmark_impl(param, shape_arg, ".+", RUNS, {4, {4, 5, 6, 7}},
{1, {7}}, data_type);
};
bench_case(1, 64, 64, 28, 28, 3, 1, 1, 1);
bench_case(1, 64, 64, 28, 28, 5, 1, 2, 1);
bench_case(1, 64, 64, 28, 28, 7, 1, 3, 1);
bench_case(1, 64, 64, 28, 28, 3, 1, 1, 2);
bench_case(1, 64, 64, 28, 28, 5, 1, 2, 2);
bench_case(1, 64, 64, 28, 28, 7, 1, 3, 2);
bench_case(1, 64, 64, 28, 28, 3, 2, 1, 1);
bench_case(1, 64, 64, 28, 28, 3, 4, 1, 1);
bench_case(1, 64, 64, 28, 28, 3, 8, 1, 1);
bench_case(1, 16, 16, 28, 28, 3, 1, 1, 1);
bench_case(1, 32, 32, 28, 28, 3, 1, 1, 1);
bench_case(1, 128, 128, 28, 28, 3, 1, 1, 1);
bench_case(1, 256, 256, 28, 28, 3, 1, 1, 1);
bench_case(1, 64, 64, 7, 7, 3, 1, 1, 1);
bench_case(1, 64, 64, 14, 14, 3, 1, 1, 1);
bench_case(1, 64, 64, 56, 56, 3, 1, 1, 1);
bench_case(1, 64, 64, 112, 112, 3, 1, 1, 1);
}
#endif
TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS,
BENCHMARK_CONVBIAS_DIRECT_INT8x8x16) {
......@@ -769,10 +828,10 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_INT8_NCHW44_DOT) {
bench_case(1, 128, 128, 28, 28, 3, 4, 1, 1);
bench_case(1, 256, 256, 14, 14, 3, 4, 1, 1);
bench_case(1, 512, 512, 7, 7, 3, 4, 1, 1);
}
TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_INT8_NCHW44_DOT_S2) {
TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS,
BENCHMARK_CONVBIAS_INT8_NCHW44_DOT_S2) {
constexpr size_t RUNS = 40;
std::vector<DType> data_type = {
dtype::QuantizedS8(2.5f), dtype::QuantizedS8(2.5f),
......@@ -825,16 +884,13 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_INT8_NCHW44_DOT_S2
bench_case(1, 128, 128, 28, 28, 3, 4, 1, 2);
bench_case(1, 256, 256, 14, 14, 3, 4, 1, 2);
bench_case(1, 512, 512, 7, 7, 3, 4, 1, 2);
}
#endif
TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_FLOAT_NCHW44) {
constexpr size_t RUNS = 40;
std::vector<DType> data_type = {
dtype::Float32(), dtype::Float32(),
dtype::Float32(), dtype::Float32()};
std::vector<DType> data_type = {dtype::Float32(), dtype::Float32(),
dtype::Float32(), dtype::Float32()};
auto bench_case = [&](size_t N, size_t IC, size_t OC, size_t H, size_t W,
size_t FS, size_t group, size_t P, size_t S,
bool is_nchw = false) {
......@@ -880,15 +936,12 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_FLOAT_NCHW44) {
bench_case(1, 128, 128, 28, 28, 3, 4, 1, 2);
bench_case(1, 256, 256, 14, 14, 3, 4, 1, 2);
bench_case(1, 512, 512, 7, 7, 3, 4, 1, 2);
bench_case(1, 64, 64, 56*2, 56*2, 3, 4, 1, 2);
bench_case(1, 128, 128, 28*2, 28*2, 3, 4, 1, 2);
bench_case(1, 256, 256, 14*2, 14*2, 3, 4, 1, 2);
bench_case(1, 512, 512, 7*2, 7*2, 3, 4, 1, 2);
}
bench_case(1, 64, 64, 56 * 2, 56 * 2, 3, 4, 1, 2);
bench_case(1, 128, 128, 28 * 2, 28 * 2, 3, 4, 1, 2);
bench_case(1, 256, 256, 14 * 2, 14 * 2, 3, 4, 1, 2);
bench_case(1, 512, 512, 7 * 2, 7 * 2, 3, 4, 1, 2);
}
TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS,
BENCHMARK_CONVBIAS_INT8_INT8_INT8_STRIDE2) {
......@@ -1473,9 +1526,9 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_WINOGRAD_INT8) {
algo_name = "WINOGRAD:ARMV7_INT16X16X32_MK8_4X8:8:2:32";
#endif
std::vector<DType> data_type = {dtype::QuantizedS8(2.5f), dtype::QuantizedS8(2.5f),
dtype::QuantizedS32(6.25f) ,dtype::QuantizedS8(60.25f) };
std::vector<DType> data_type = {
dtype::QuantizedS8(2.5f), dtype::QuantizedS8(2.5f),
dtype::QuantizedS32(6.25f), dtype::QuantizedS8(60.25f)};
printf("Benchmark WINOGRAD_IN8_MK8 algo\n");
benchmark_impl(param, shapes_and_computation, algo_name, RUNS,
{4, {4, 5, 6, 7}}, {1, {4}}, data_type);
......@@ -1839,7 +1892,6 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS,
{1, {4}}, data_type);
}
TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS,
BENCHMARK_IM2COL_NCHW44_INT8x8x32_STRIDE1) {
constexpr size_t RUNS = 50;
......@@ -1852,18 +1904,17 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS,
param.stride_w = 1;
param.sparse = param::ConvBias::Sparse::DENSE;
param.format = param::ConvBias::Format::NCHW44;
std::vector<std::pair<SmallVector<TensorShape>, float>>
shapes_and_computation;
auto bench_case = [&](size_t N, size_t IC, size_t OC, size_t H, size_t W,
size_t FS, size_t group=1) {
SmallVector<TensorShape> shapes{{N, IC, H, W,4},
{OC, IC / group, FS, FS,4,4},
size_t FS, size_t group = 1) {
SmallVector<TensorShape> shapes{{N, IC, H, W, 4},
{OC, IC / group, FS, FS, 4, 4},
{/*1, OC, 1, 1*/},
{},
{N, OC, H, W,4}};
TensorShape dst{N, OC, H, W,4};
{N, OC, H, W, 4}};
TensorShape dst{N, OC, H, W, 4};
float computations =
((4 * IC / group) * FS * FS * dst.total_nr_elems() * 2 +
dst.total_nr_elems()) *
......@@ -1907,9 +1958,10 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS,
#endif
std::string algo_name = "IM2COLMATMUL:AARCH64_INT8X8X32_MK4_4X4X16:96";
printf("Benchmarker IM2COLMATMUL:AARCH64_INT8X8X32_MK4_4X4X16:96 algo\n");
std::vector<DType> data_type = {
dtype::QuantizedS8(2.5f), dtype::QuantizedS8(2.5f),
dtype::QuantizedS32(6.25f), {}};
std::vector<DType> data_type = {dtype::QuantizedS8(2.5f),
dtype::QuantizedS8(2.5f),
dtype::QuantizedS32(6.25f),
{}};
benchmark_impl(param, shapes_and_computation, algo_name, RUNS,
{4, {4, 5, 6, 7}}, {1, {4}}, data_type);
benchmark_impl(param, shapes_and_computation, algo_name, RUNS,
......@@ -1917,10 +1969,9 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS,
benchmark_impl(param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}},
{1, {4}}, data_type);
algo_name = "IM2COLMATMUL:AARCH64_INT8X8X32_MK4_4X4X16:192";
printf("Benchmarker IM2COLMATMUL:AARCH64_INT8X8X32_MK4_4X4X16:192 algo\n");
printf("Benchmarker IM2COLMATMUL:AARCH64_INT8X8X32_MK4_4X4X16:192 "
"algo\n");
benchmark_impl(param, shapes_and_computation, algo_name, RUNS,
{4, {4, 5, 6, 7}}, {1, {4}}, data_type);
benchmark_impl(param, shapes_and_computation, algo_name, RUNS,
......@@ -1929,14 +1980,14 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS,
{1, {4}}, data_type);
algo_name = "IM2COLMATMUL:AARCH64_INT8X8X32_MK4_4X4X16:384";
printf("Benchmarker IM2COLMATMUL:AARCH64_INT8X8X32_MK4_4X4X16:384 algo\n");
printf("Benchmarker IM2COLMATMUL:AARCH64_INT8X8X32_MK4_4X4X16:384 "
"algo\n");
benchmark_impl(param, shapes_and_computation, algo_name, RUNS,
{4, {4, 5, 6, 7}}, {1, {4}}, data_type);
benchmark_impl(param, shapes_and_computation, algo_name, RUNS,
{4, {4, 5, 6, 7}}, {1, {7}}, data_type);
benchmark_impl(param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}},
{1, {4}}, data_type);
}
#endif
......
......@@ -1185,9 +1185,10 @@ void check_conv_bias_preprocess(std::vector<conv_bias::TestArg> args,
}
}
void checker_conv_bias_common(std::vector<conv_bias::TestArg> args, Handle* handle,
RNG* rng, float epsilon, DType type0, DType type1,
DType type2, DType type3, const char* algo_name) {
void checker_conv_bias_common(std::vector<conv_bias::TestArg> args,
Handle* handle, RNG* rng, float epsilon,
DType type0, DType type1, DType type2,
DType type3, const char* algo_name) {
using namespace conv_bias;
Checker<ConvBias> checker(handle);
......@@ -1377,6 +1378,88 @@ std::vector<conv_bias::TestArg> get_nchw44_conv_bias_args(
}
return args;
}
std::vector<conv_bias::TestArg> get_nchw88_conv_bias_args(
std::vector<size_t> kernel_vec,
std::vector<param::ConvBias::NonlineMode> nlmode_vec,
std::vector<megdnn::BiasMode> biasmode_vec, size_t stride) {
using namespace conv_bias;
using NLMode = param::ConvBias::NonlineMode;
std::vector<TestArg> args;
auto pack = [&](size_t n, size_t oc, size_t ic, size_t h, size_t w,
size_t kernel, size_t stride, size_t group, NLMode nlmode,
megdnn::BiasMode bias_mode) {
constexpr int pack_c = 8;
const size_t pad = kernel / 2;
auto oc_per_group = oc / group;
auto ic_per_group = ic / group;
megdnn_assert(oc_per_group % pack_c == 0 && ic_per_group % pack_c == 0,
"ocpg/icpg not divided by 8");
size_t kernel_h = kernel;
size_t kernel_w = kernel;
param::ConvBias param;
param.format = param::ConvBias::Format::NCHW88;
param.stride_h = stride;
param.stride_w = stride;
param.pad_h = pad;
param.pad_w = pad;
param.nonlineMode = nlmode;
auto src_tensor_shape = TensorShape{n, ic / pack_c, h, w, pack_c};
auto weight_tensor_shape = TensorShape{
oc / pack_c, ic / pack_c, kernel_h, kernel_w, pack_c, pack_c};
auto bias_tensor_shape = TensorShape{};
if (bias_mode == megdnn::BiasMode::BROADCAST_CHANNEL_BIAS) {
bias_tensor_shape = {1, oc / pack_c, 1, 1, pack_c};
} else if (bias_mode == megdnn::BiasMode::BIAS) {
bias_tensor_shape = {n, oc / pack_c,
(h + 2 * pad - kernel) / stride + 1,
(w + 2 * pad - kernel) / stride + 1, pack_c};
}
if (group == 1) {
param.sparse = param::ConvBias::Sparse::DENSE;
} else {
param.sparse = param::ConvBias::Sparse::GROUP;
weight_tensor_shape = TensorShape{group,
oc_per_group / pack_c,
ic_per_group / pack_c,
kernel_h,
kernel_w,
pack_c,
pack_c};
}
args.emplace_back(param, src_tensor_shape, weight_tensor_shape,
bias_tensor_shape);
};
for (auto bias : biasmode_vec)
for (auto nlmode : nlmode_vec)
for (size_t n : {1, 2})
for (size_t kernel : kernel_vec)
for (size_t oc : {8, 16})
for (size_t ic : {8, 16, 24})
for (size_t h : {1, 3, 12})
for (size_t w : {1, 8, 13}) {
for (size_t group = 1; group < oc / 8;
++group) {
if (ic % (group * 8) ||
oc % (group * 8)) {
continue;
}
if (kernel < h || kernel < w) {
continue;
}
pack(n, oc, ic, h, w, kernel, stride,
group, nlmode, bias);
}
}
return args;
}
} // namespace conv_bias
} // namespace test
} // namespace megdnn
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册