Commit 97ddff4d authored by Megvii Engine Team

feat(dnn): add fp16 hybrid direct conv algo

GitOrigin-RevId: 6192b0ffb40ed812fc386299d0e1354bf4556f11
Parent 5705368c
#include "src/common/utils.h"
namespace megdnn {
namespace {
// block_helper is used to calculate oh block size
static inline int l2_block_helper(
const int nthread, const int amount, const int size_per_unit) {
constexpr int l2_cache_size = 256 * 1024;
const int block_per_thread = div_ceil(amount, nthread);
const int best_block =
std::min(amount, (l2_cache_size + size_per_unit / 2) / size_per_unit);
const int max_block_num = div_ceil(block_per_thread, best_block);
const int min_block_num = std::max(max_block_num - 1, 1);
const int max_block = div_ceil(block_per_thread, max_block_num);
const int min_block = div_ceil(block_per_thread, min_block_num);
const int max_loss = std::abs(max_block_num * max_block - block_per_thread);
const int min_loss = std::abs(min_block_num * min_block - block_per_thread);
int block = max_loss > min_loss ? min_block : max_block;
return block;
}
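// A hedged worked example (editor's addition, numbers assumed): with
// nthread = 4, amount = oh = 64 and size_per_unit = 16 KiB, block_per_thread
// = div_ceil(64, 4) = 16 and best_block = min(64, (256 KiB + 8 KiB) / 16 KiB)
// = 16, so max_block_num = min_block_num = 1 and the helper returns an oh
// block of 16 rows; each thread's working set then fits the 256 KiB L2 budget.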
} // namespace
} // namespace megdnn
// vim: syntax=cpp.doxygen
\ No newline at end of file
......@@ -156,6 +156,23 @@ public:
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_NCHW88_FP16)
};
class ConvBiasImpl::AlgoF16DirectNchwNchw88 final : public AlgoBase {
public:
AlgoF16DirectNchwNchw88() = default;
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
const char* name() const override { return "DIRECT_CONV_F16_NCHW_NCHW88"; }
bool usable(
const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;
size_t get_workspace(const NCBKernSizeParam& param) const override;
virtual SmallVector<NCBKern> dispatch_kerns(
const NCBKernSizeParam& param) const override;
ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::FLOAT16, AlgoCategory::DIRECT};
}
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_NCHW_NCHW88_FP16)
};
} // namespace arm_common
} // namespace megdnn
#endif
......
#include "megdnn/oprs.h"
#include "src/arm_common/conv_bias/block_helper.h"
#include "src/arm_common/conv_bias/f16/algos.h"
#include "src/arm_common/conv_bias/f16/direct_nchw88_kern.h"
......
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
#include "src/arm_common/conv_bias/f16/algos.h"
#include "src/arm_common/conv_bias/f16/direct_nchw_nchw88_kern.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/common/nchw_nchwxx_valid.h"
#include "src/common/utils.h"
#include "src/fallback/conv_bias/gi/block_helper.h"
#include "src/fallback/conv_bias/opr_impl.h"
using namespace megdnn;
using namespace arm_common;
MIDOUT_DECL(megdnn_arm_common_direct_conv_nchw_nchw88_fp16)
namespace {
static void get_rectified_size(
const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param, int& ih2,
int& iw2, int& oh2, int& ow2, int& block_oh) {
int ic = param.filter_meta.icpg;
int iw = param.isz[1];
int oh = param.osz[0];
int ow = param.osz[1];
oh2 = oh;
ow2 = ow;
iw2 = iw + 2 * static_cast<int>(param.filter_meta.padding[1]);
const int sh = static_cast<int>(param.filter_meta.stride[0]);
block_oh = l2_block_helper(param.nr_threads, oh, ic * iw2 * sh * sizeof(__fp16));
const int fh = static_cast<int>(param.filter_meta.spatial[0]);
ih2 = (block_oh - 1) * sh + fh;
}
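// For illustration (editor's note): ih2 is the number of padded input rows one
// output block consumes. E.g. with block_oh = 16, stride sh = 2 and filter
// height fh = 3, ih2 = (16 - 1) * 2 + 3 = 33 input rows must be staged per block.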
static WorkspaceBundle get_bundle(
const fallback::ConvBiasImpl::NCBKernSizeParam& param) {
const auto& fm = param.filter_meta;
const int group = fm.group;
const int ic = fm.icpg;
const int oc = fm.ocpg;
const int fh = fm.spatial[0];
const int fw = fm.spatial[1];
int ih2, iw2, oh2, ow2, oh_block;
get_rectified_size(param, ih2, iw2, oh2, ow2, oh_block);
megdnn_assert(oh_block != 0, "oh_block == 0");
const size_t src_size = ic * ih2 * iw2 * sizeof(__fp16);
const size_t weight_size = group * oc * ic * fh * fw * sizeof(__fp16);
return {nullptr, {src_size * param.nr_threads, weight_size}};
}
static void pack_weight(
const WorkspaceBundle& bundle,
const fallback::ConvBiasImpl::NCBKernParam& kern_param,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) {
const int group_id = ncb_index.ndrange_id[0];
const auto& fm = kern_param.filter_meta;
const int oc = fm.ocpg;
const int ic = fm.icpg;
const int fh = fm.spatial[0];
const int fw = fm.spatial[1];
const int oc_idx = 0;
const int oc_block = oc;
const __fp16* weight =
reinterpret_cast<const __fp16*>(kern_param.filter<dt_float16>(group_id)) +
oc_idx * ic * fh * fw;
__fp16* packed_weight = static_cast<__fp16*>(bundle.get(1)) +
group_id * oc * ic * fh * fw + oc_idx * ic * fh * fw;
fp16_direct_nchw_nchw88::pack_weight_fp16_nchw_nchw88(
weight, packed_weight, oc_block, fh, fw, ic);
}
/**
 * @brief Copy data from sptr_origin to sptr, zero-filling the padding border.
 */
static inline void src_copy_pad(
__fp16* sptr, const __fp16* sptr_origin, const int pw, const int pad_right,
const int pad_top, const int pad_bottom, const int ic, const int ih,
const int iw, const int iw2, const int ld_src_ic) {
rep(ic_idx, ic) {
const __fp16* ic_sptr_origin = sptr_origin + ic_idx * ld_src_ic;
memset(sptr, 0, iw2 * pad_top * sizeof(__fp16));
sptr += iw2 * pad_top;
rep(ih_idx, ih) {
memset(sptr, 0, pw * sizeof(__fp16));
sptr += pw;
memcpy(sptr, ic_sptr_origin, iw * sizeof(__fp16));
sptr += iw;
ic_sptr_origin += iw;
memset(sptr, 0, pad_right * sizeof(__fp16));
sptr += pad_right;
}
memset(sptr, 0, iw2 * pad_bottom * sizeof(__fp16));
sptr += iw2 * pad_bottom;
}
}
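// A small illustration (editor's addition, values assumed): with ic = 1,
// ih = 2, iw = 3, pw = 1, pad_right = 1, pad_top = 1, pad_bottom = 1 and
// hence iw2 = 1 + 3 + 1 = 5, a source block
//     a b c
//     d e f
// is staged as the zero-bordered 4 x 5 buffer
//     0 0 0 0 0
//     0 a b c 0
//     0 d e f 0
//     0 0 0 0 0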
template <BiasMode bias_mode, typename Op, int filter_size, int stride>
static void do_conv_kern(
const WorkspaceBundle& bundle, const ConvBiasImpl::NCBKernParam& kern_param,
const ConvBiasImpl::NCBKernIndex& ncb_index) {
const int oc = kern_param.filter_meta.ocpg;
const int oh = kern_param.osz[0];
const int ow = kern_param.osz[1];
const int ic = kern_param.filter_meta.icpg;
const int ih = kern_param.isz[0];
const int iw = kern_param.isz[1];
const int fh = kern_param.filter_meta.spatial[0];
const int fw = kern_param.filter_meta.spatial[1];
const int sh = kern_param.filter_meta.stride[0];
const int ph = kern_param.filter_meta.padding[0];
const int pw = kern_param.filter_meta.padding[1];
int ih2, iw2, oh2, ow2, oh_block;
get_rectified_size(kern_param, ih2, iw2, oh2, ow2, oh_block);
constexpr int pack_oc = 8;
const int batch_id = ncb_index.ndrange_id[0];
const int group_id = ncb_index.ndrange_id[1];
int oc_idx = 0;
int oc_block = oc;
const int oh_idx = ncb_index.ndrange_id[2];
const int oh_block_real = std::min(oh - oh_idx * oh_block, oh_block);
const int ih_block_real = (oh_block_real - 1) * sh + fh;
const int src_top_pad = std::max(0, ph - oh_idx * oh_block * sh);
const int src_bottom_pad =
std::max(0, (oh_idx * oh_block + oh_block_real - 1) * sh + fh - ph - ih);
const int src_right_pad = std::max(iw2 - iw - pw, 0);
const int src_offset = std::max(oh_idx * oh_block * sh - ph, 0) * iw;
const __fp16* origin_ptr = reinterpret_cast<const __fp16*>(
kern_param.src<dt_float16>(batch_id, group_id)) +
src_offset;
const size_t src_size = sizeof(__fp16) * ic * ih2 * iw2;
__fp16* sptr = reinterpret_cast<__fp16*>(
reinterpret_cast<int8_t*>(bundle.get(0)) + ncb_index.thread_id * src_size);
src_copy_pad(
sptr, origin_ptr, pw, src_right_pad, src_top_pad, src_bottom_pad, ic,
ih_block_real - src_top_pad - src_buttom_pad, iw, iw2, ih * iw);
//! packed weight
const __fp16* weight = reinterpret_cast<const __fp16*>(bundle.get(1)) +
group_id * oc * ic * fh * fw + oc_idx * ic * fh * fw;
__fp16* dst =
reinterpret_cast<__fp16*>(kern_param.dst<dt_float16>(batch_id, group_id)) +
oh_idx * oh_block * ow * pack_oc;
const __fp16* bias = reinterpret_cast<const __fp16*>(
kern_param.bias<dt_float16>(batch_id, group_id)) +
oc_idx;
Op op;
fp16_direct_nchw_nchw88::fp16_direct_conv_nchw_nchw88<
bias_mode, Op, filter_size, stride>(
sptr, weight, bias, dst, oc_block, ic, ih_block_real, iw2, oh,
oh_block_real, ow2, op);
}
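// Editor's note on the padding math above: for an oh block starting at output
// row oh_idx * oh_block, the source rows are clipped against the tensor, so
// src_top_pad covers the part of ph not yet consumed by earlier blocks,
// src_bottom_pad the overrun past ih, and src_offset skips the already
// consumed input rows; src_copy_pad then rebuilds a contiguous zero-padded
// block in the per-thread workspace.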
} // namespace
bool ConvBiasImpl::AlgoF16DirectNchwNchw88::usable(
const NCBKernSizeParam& param, AlgoSelectionStrategy) const {
return nchw_nchwxx_valid<NchwNchwxxType::NCHW88_FP16>(
param.src_type.enumv(), param.filter_type.enumv(), param.dst_type.enumv(),
param.filter_meta, param.bias_mode, param.nonlineMode);
}
size_t ConvBiasImpl::AlgoF16DirectNchwNchw88::get_workspace(
const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(
megdnn_arm_common_direct_conv_nchw_nchw88_fp16,
midout_iv("AlgoF16DirectNchwNchw88::get_workspace"_hash)) {
return get_bundle(param).total_size_in_bytes();
}
MIDOUT_END();
return 0;
}
SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF16DirectNchwNchw88::
dispatch_kerns(const NCBKernSizeParam& param) const {
using conv_func_ptr = std::function<void(
const WorkspaceBundle& bundle,
const fallback::ConvBiasImpl::NCBKernParam& kern_param,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index)>;
const auto& fm = param.filter_meta;
const int batch = param.n;
const int group = fm.group;
auto bundle = get_bundle(param);
conv_func_ptr conv_func = nullptr;
#define CONV_FUNC(bias_mode, op, filter, stride) \
MIDOUT_BEGIN( \
megdnn_arm_common_direct_conv_nchw_nchw88_fp16, \
midout_iv(#bias_mode #op #filter #stride##_hash)) { \
conv_func = do_conv_kern<bias_mode, op, filter, stride>; \
} \
MIDOUT_END();
#define FOR_OP(bias_mode, filter, stride) \
switch (param.nonlineMode) { \
case param::ConvBias::NonlineMode::IDENTITY: \
CONV_FUNC(bias_mode, NoneOp<__fp16>, filter, stride); \
break; \
case param::ConvBias::NonlineMode::RELU: \
CONV_FUNC(bias_mode, ReluOp<__fp16>, filter, stride); \
break; \
case param::ConvBias::NonlineMode::H_SWISH: \
CONV_FUNC(bias_mode, HSwishOp<__fp16>, filter, stride); \
break; \
default: \
megdnn_assert(0); \
break; \
}
#define FOR_BIAS_MODE(filter, stride) \
switch (param.bias_mode) { \
case BiasMode::NO_BIAS: \
FOR_OP(BiasMode::NO_BIAS, filter, stride); \
break; \
case BiasMode::BROADCAST_CHANNEL_BIAS: \
FOR_OP(BiasMode::BROADCAST_CHANNEL_BIAS, filter, stride); \
break; \
default: \
megdnn_assert(0); \
break; \
}
#define FOR_FILTER(stride) \
switch (fm.spatial[0]) { \
case 2: \
FOR_BIAS_MODE(2, stride); \
break; \
case 3: \
FOR_BIAS_MODE(3, stride); \
break; \
case 5: \
FOR_BIAS_MODE(5, stride); \
break; \
case 7: \
FOR_BIAS_MODE(7, stride); \
break; \
default: \
megdnn_assert(0); \
break; \
}
#define FOR_STRIDE() \
switch (fm.stride[0]) { \
case 1: \
FOR_FILTER(1); \
break; \
case 2: \
FOR_FILTER(2); \
break; \
default: \
megdnn_assert(0); \
break; \
}
FOR_STRIDE()
#undef FOR_STRIDE
#undef FOR_FILTER
#undef FOR_BIAS_MODE
#undef FOR_OP
#undef CONV_FUNC
megdnn_assert(conv_func);
SmallVector<NCBKern> ret_kerns;
int oh = param.osz[0];
int ih2, iw2, oh2, ow2, oh_block;
get_rectified_size(param, ih2, iw2, oh2, ow2, oh_block);
auto do_pack_weight = [bundle](
const ConvBiasImpl::NCBKernParam& kern_param,
const ConvBiasImpl::NCBKernIndex& ncb_index) mutable {
bundle.set(kern_param.workspace_ptr);
pack_weight(bundle, kern_param, ncb_index);
};
ret_kerns.push_back({do_pack_weight, {static_cast<size_t>(group)}});
CpuNDRange ncb_range{
static_cast<size_t>(batch), static_cast<size_t>(group),
static_cast<size_t>(div_ceil(oh, oh_block))};
auto do_conv = [bundle, conv_func](
const ConvBiasImpl::NCBKernParam& kern_param,
const ConvBiasImpl::NCBKernIndex& ncb_index) mutable {
bundle.set(kern_param.workspace_ptr);
conv_func(bundle, kern_param, ncb_index);
};
ret_kerns.push_back({do_conv, ncb_range});
return ret_kerns;
}
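// Editor's note on the dispatch above: the kernels run in two stages. The
// first NCBKern packs the weights once over an ndrange of {group}; the second
// runs the convolution over {batch, group, div_ceil(oh, oh_block)}, so each
// invocation processes one oh block of one (batch, group) pair.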
#endif
\ No newline at end of file
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
#include "src/arm_common/conv_bias/f16/direct_nchw_nchw88_kern_common.h"
#pragma once
namespace megdnn {
namespace arm_common {
namespace fp16_direct_nchw_nchw88 {
//! (OC/8, FH, FW, IC, 8) --> (OC/8, IC, FH, FW, 8)
static inline void pack_weight_fp16_nchw_nchw88(
const __fp16* in_ptr, __fp16* dst_ptr, const int oc, const int fh, const int fw,
const int ic) {
constexpr int oc_step = 8;
const int ld_ic = fh * fw * oc_step;
const int ld_oc = ic * fh * fw;
for (int oc_idx = 0; oc_idx < oc; oc_idx += oc_step) {
const __fp16* in_ptr_oc = in_ptr + ld_oc * oc_idx;
__fp16* dst_ptr_oc = dst_ptr + ld_oc * oc_idx;
for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
for (int fw_idx = 0; fw_idx < fw; ++fw_idx) {
for (int ic_idx = 0; ic_idx < ic; ++ic_idx) {
vst1q_f16(dst_ptr_oc + ic_idx * ld_ic, vld1q_f16(in_ptr_oc));
in_ptr_oc += oc_step;
}
dst_ptr_oc += oc_step;
}
}
}
}
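// Layout illustration (editor's addition): the element at source index
// (oc_outer, fh_idx, fw_idx, ic_idx, oc_inner) lands at destination index
// (oc_outer, ic_idx, fh_idx, fw_idx, oc_inner); the 8 oc lanes of one
// (fh, fw, ic) position move together through a single vld1q_f16/vst1q_f16 pair.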
template <BiasMode bias_mode, typename Op, int filter_size, int stride>
static void fp16_direct_conv_nchw_nchw88(
const __fp16* src, const __fp16* filter, const __fp16* bias, __fp16* dst,
const int oc, const int ic, const int ih, const int iw, const int oh,
const int oh_block, const int ow, const Op& op) {
ConvDirectNchwNchw88Fp16<bias_mode, Op, filter_size, stride>::impl(
src, filter, bias, dst, oc, ic, ih, iw, oh, oh_block, ow, op);
}
} // namespace fp16_direct_nchw_nchw88
} // namespace arm_common
} // namespace megdnn
#endif
\ No newline at end of file
#pragma once
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
#include "src/arm_common/conv_bias/intrinsic_helper.h"
#include "src/arm_common/conv_bias/opr_impl.h"
#include "src/arm_common/intrinsic_helper.h"
#include "src/arm_common/neon_struct.h"
#include "src/common/unroll_macro.h"
#include "src/fallback/conv_bias/common.h"
namespace megdnn {
namespace arm_common {
namespace {
/**
 * @brief Given src -> (N, IC, IH, IW), weight -> (OC, IC, FH, FW) and
 * bias -> (1, OC, 1, 1), accumulate
 * (n, ic, ih, iw) * (oc : oc + nr_oc_block * oc_block, ic, fh, fw) +
 * (0, oc : oc + nr_oc_block * oc_block, 0, 0)
*
* @tparam src_idx related to ih and iw above
* @tparam weight_idx related to fh and fw above
* @tparam nr_oc_block number of oc block
* @tparam stride
* @tparam nr_ow This function calculates the value of nr_ow positions at a time.
 * @tparam T1 type of the bias/accumulator register array
 * @tparam T2 type of the src register array
 * @tparam T3 type of the weight register array
*/
template <
int src_idx, int weight_idx, int nr_oc_block, int stride, int nr_ow,
typename T1, typename T2, typename T3>
struct CalHelper {
static MEGDNN_ALWAYS_INLINE void impl(T1& bias, const T2& src, const T3& weight);
};
template <
int src_idx, int weight_idx, int nr_oc_block, int stride, typename T1,
typename T2, typename T3>
struct CalHelper<src_idx, weight_idx, nr_oc_block, stride, 0, T1, T2, T3> {
static MEGDNN_ALWAYS_INLINE void impl(T1& bias, const T2& src, const T3& weight) {}
};
#if defined(__ARM_FEATURE_FMA)
#if MEGDNN_AARCH64
#define fma_lane_f16(a, b, v, lane) vfmaq_laneq_f16((a), (b), (v), (lane))
#else
#define fma_lane_f16(a, b, v, lane) \
vfmaq_f16((a), (b), vdupq_n_f16(vgetq_lane_f16(v, lane)))
#endif
#else
#if MEGDNN_AARCH64
#define fma_lane_f16(a, b, v, lane) vaddq_f16((a), vmulq_laneq_f16((b), (v), (lane)))
#else
#define fma_lane_f16(a, b, v, lane) \
vaddq_f16((a), vmulq_n_f16(b, vgetq_lane_f16(v, lane)))
#endif
#endif
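// Semantics of fma_lane_f16 (editor's summary): for each of the 8 f16 lanes,
// a[i] += b[i] * v[lane]. On AArch64 this maps to a single
// vfmaq_laneq_f16/vmulq_laneq_f16; on ARMv7, which has no by-lane f16 form,
// the lane is first broadcast via vgetq_lane_f16 before the multiply.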
#define cb1(step) \
bias[0][step] = fma_lane_f16( \
bias[0][step], weight[0][weight_idx], src[(step * stride + src_idx) / 8], \
(step * stride + src_idx) % 8);
#define cb2(step) \
bias[0][step] = fma_lane_f16( \
bias[0][step], weight[0][weight_idx], src[(step * stride + src_idx) / 8], \
(step * stride + src_idx) % 8); \
bias[1][step] = fma_lane_f16( \
bias[1][step], weight[1][weight_idx], src[(step * stride + src_idx) / 8], \
(step * stride + src_idx) % 8);
#define CAL_HELPER(nr_ow) \
template < \
int src_idx, int weight_idx, int stride, typename T1, typename T2, \
typename T3> \
struct CalHelper<src_idx, weight_idx, 1, stride, nr_ow, T1, T2, T3> { \
static MEGDNN_ALWAYS_INLINE void impl( \
T1& bias, const T2& src, const T3& weight) { \
UNROLL_CALL_NOWRAPPER(nr_ow, cb1); \
} \
}; \
template < \
int src_idx, int weight_idx, int stride, typename T1, typename T2, \
typename T3> \
struct CalHelper<src_idx, weight_idx, 2, stride, nr_ow, T1, T2, T3> { \
static MEGDNN_ALWAYS_INLINE void impl( \
T1& bias, const T2& src, const T3& weight) { \
UNROLL_CALL_NOWRAPPER(nr_ow, cb2); \
} \
};
CAL_HELPER(1)
CAL_HELPER(2)
CAL_HELPER(3)
CAL_HELPER(4)
CAL_HELPER(5)
CAL_HELPER(6)
CAL_HELPER(7)
CAL_HELPER(8)
#undef CAL_HELPER
#undef cb2
#undef cb1
#undef fma_lane_f16
template <
int src_idx, int weight_idx, int nr_oc_block, int stride, int nr_ow,
typename T1, typename T2, typename T3>
MEGDNN_ALWAYS_INLINE void cal_helper(T1& bias, const T2& src, const T3& weight) {
CalHelper<src_idx, weight_idx, nr_oc_block, stride, nr_ow, T1, T2, T3>::impl(
bias, src, weight);
}
template <int oc>
struct OCHelper {
static constexpr int val = -1;
};
template <>
struct OCHelper<8> {
static constexpr int val = 1;
};
template <>
struct OCHelper<16> {
static constexpr int val = 2;
};
template <
BiasMode bias_mode, typename Op, int nr_ow, int filter, int oc_block,
int stride>
struct KernFilterXStrideXNchwNchw88FP16 {
static void impl(
const __fp16* src_ptr, const __fp16* weight_ptr, const __fp16* bias_ptr,
__fp16* dst_ptr, int ic, int ih, int iw, int ld_dst_oc, const Op& op);
};
#define KERNEL_CB(i) cal_helper<i, i, nr_oc_block, stride, nr_ow>(bias, src, weight);
#define KERNEL(step, FILTER_SIZE) \
load_helper<src_reg_size, 0, simd_len, 0, Vld1q_f16>(src, src_ptr + step * iw, 0); \
load_helper<filter_size, 0, simd_len, nr_oc_block, Vld1q_f16>( \
weight, weight_ptr + step * ld_weight_fh, ld_weight_oc); \
UNROLL_CALL_RAW(FILTER_SIZE, KERNEL_CB);
#define INSTANCE_KERN(FILTER_SIZE) \
template <BiasMode bias_mode, typename Op, int nr_ow, int oc_block, int stride> \
struct KernFilterXStrideXNchwNchw88FP16< \
bias_mode, Op, nr_ow, FILTER_SIZE, oc_block, stride> { \
static void impl( \
const __fp16* src_ptr, const __fp16* weight_ptr, \
const __fp16* bias_ptr, __fp16* dst_ptr, int ic, int ih, int iw, \
int ld_dst_oc, const Op& op) { \
constexpr int filter_size = FILTER_SIZE; \
constexpr int oc_step = 8; \
constexpr int simd_len = 8; \
constexpr int src_reg_size = \
((nr_ow - 1) * stride + filter_size + simd_len - 1) / simd_len; \
\
constexpr int ld_weight_fh = filter_size * oc_step; \
constexpr int ld_weight_ic = ld_weight_fh * filter_size; \
const int ld_weight_oc = ld_weight_ic * ic; \
const int ld_src_ic = ih * iw; \
\
constexpr int nr_oc_block = OCHelper<oc_block>::val; \
float16x8_t bias[nr_oc_block][nr_ow]; \
init_ocx_ow8<nr_oc_block, bias_mode, nr_ow>(bias, bias_ptr, oc_step); \
\
rep(ic_idx, ic) { \
float16x8_t src[src_reg_size], weight[nr_oc_block][filter_size]; \
UNROLL_CALL_ONE_ARG_RAW(FILTER_SIZE, KERNEL, FILTER_SIZE); \
src_ptr += ld_src_ic; \
weight_ptr += ld_weight_ic; \
} \
store_ocx_ow8_remain_static<nr_oc_block, nr_ow, Op, 16>( \
bias, op, dst_ptr, ld_dst_oc); \
} \
};
INSTANCE_KERN(2)
INSTANCE_KERN(3)
INSTANCE_KERN(5)
INSTANCE_KERN(7)
#undef INSTANCE_KERN
#undef KERNEL
#undef KERNEL_CB
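// Register-count illustration (editor's addition): src_reg_size above is the
// number of float16x8_t registers needed to cover one input row of the
// current output tile. E.g. with nr_ow = 8, stride = 1, filter_size = 3 and
// simd_len = 8: ((8 - 1) * 1 + 3 + 8 - 1) / 8 = 2 source registers per row.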
template <BiasMode bias_mode, typename Op, int filter_size, int stride>
struct ConvDirectNchwNchw88Fp16 {
static MEGDNN_ALWAYS_INLINE void impl(
const __fp16* src, const __fp16* filter, const __fp16* bias, __fp16* dst,
const int oc, const int ic, const int ih, const int iw, const int oh,
const int oh_block, const int ow, const Op& op) {
constexpr int fh = filter_size;
constexpr int fw = filter_size;
#ifdef MEGDNN_ARMV7
constexpr int big_oc_step = 8;
#else
constexpr int big_oc_step = 16;
#endif
constexpr int oc_step = 8;
constexpr int ow_step = 8;
constexpr int sh = stride;
constexpr int sw = stride;
const int ld_dst_oc = oh * ow * oc_step;
const int ow_end = ow / ow_step * ow_step;
const int ow_remain = ow - ow_end;
const int oc_end = oc / big_oc_step * big_oc_step;
const int oc_remain = oc - oc_end;
using remain_func = std::function<void(
const __fp16* src_ptr, const __fp16* weight_ptr, const __fp16* bias_ptr,
__fp16* dst_ptr, int ic, int ih, int iw, int ld_dst_oc, const Op& op)>;
remain_func big_oc_remain = nullptr, small_oc_remain = nullptr;
if (ow_remain) {
switch (ow_remain) {
#define cb(i) \
case i + 1: \
big_oc_remain = KernFilterXStrideXNchwNchw88FP16< \
bias_mode, Op, i + 1, filter_size, big_oc_step, stride>::impl; \
small_oc_remain = KernFilterXStrideXNchwNchw88FP16< \
bias_mode, Op, i + 1, filter_size, oc_step, stride>::impl; \
break;
UNROLL_CALL_NOWRAPPER(7, cb);
#undef cb
default:
megdnn_assert(0, "Don't support remain %d for kern", ow_remain);
}
}
for (int oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) {
const __fp16* weight_ptr = filter + oc_idx * ic * fh * fw;
rep(oh_idx, oh_block) {
for (int ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
const __fp16* src_ptr = src + oh_idx * sh * iw + ow_idx * sw;
const __fp16* bias_ptr = bias + oc_idx;
__fp16* dst_ptr =
dst + oc_idx * oh * ow + (oh_idx * ow + ow_idx) * oc_step;
KernFilterXStrideXNchwNchw88FP16<
bias_mode, Op, ow_step, filter_size, big_oc_step, stride>::
impl(src_ptr, weight_ptr, bias_ptr, dst_ptr, ic, ih, iw,
ld_dst_oc, op);
}
if (ow_remain > 0) {
const __fp16* src_ptr = src + oh_idx * sh * iw + ow_end * sw;
const __fp16* bias_ptr = bias + oc_idx;
__fp16* dst_ptr =
dst + oc_idx * oh * ow + (oh_idx * ow + ow_end) * oc_step;
big_oc_remain(
src_ptr, weight_ptr, bias_ptr, dst_ptr, ic, ih, iw,
ld_dst_oc, op);
}
}
}
if (oc_remain > 0) {
const __fp16* weight_ptr = filter + oc_end * ic * fh * fw;
rep(oh_idx, oh_block) {
for (int ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
const __fp16* src_ptr = src + oh_idx * sh * iw + ow_idx * sw;
const __fp16* bias_ptr = bias + oc_end;
__fp16* dst_ptr =
dst + oc_end * oh * ow + (oh_idx * ow + ow_idx) * oc_step;
KernFilterXStrideXNchwNchw88FP16<
bias_mode, Op, ow_step, filter_size, oc_step, stride>::
impl(src_ptr, weight_ptr, bias_ptr, dst_ptr, ic, ih, iw,
ld_dst_oc, op);
}
if (ow_remain > 0) {
const __fp16* src_ptr = src + oh_idx * sh * iw + ow_end * sw;
const __fp16* bias_ptr = bias + oc_end;
__fp16* dst_ptr =
dst + oc_end * oh * ow + (oh_idx * ow + ow_end) * oc_step;
small_oc_remain(
src_ptr, weight_ptr, bias_ptr, dst_ptr, ic, ih, iw,
ld_dst_oc, op);
}
}
}
}
};
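// Editor's summary of the tiling above: oc is walked in big_oc_step blocks
// (16 on AArch64, i.e. two float16x8 accumulator rows; 8 on ARMv7), with a
// trailing 8-wide block for the remainder; ow is walked in tiles of
// ow_step = 8, and any 1..7 leftover columns dispatch through the remain_func
// pointers selected once before the loops.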
} // namespace
} // namespace arm_common
} // namespace megdnn
#endif
\ No newline at end of file
#include "src/arm_common/conv_bias/block_helper.h"
#include "src/fallback/conv_bias/gi/block_helper.h"
#if MGB_ENABLE_DOT
#include "src/arm_common/conv_bias/int8/algos.h"
#include "src/arm_common/conv_bias/int8/direct_dotprod_nchw44.h"
......
......@@ -12,11 +12,11 @@
*/
#include "megdnn/oprs.h"
#if MGB_ENABLE_DOT
#include "src/arm_common/conv_bias/block_helper.h"
#include "src/arm_common/conv_bias/int8/algos.h"
#include "src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_kern.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/common/nchw_nchwxx_valid.h"
#include "src/fallback/conv_bias/gi/block_helper.h"
#include "midout.h"
......
......@@ -12,12 +12,12 @@
*/
#include "megdnn/oprs.h"
#include "src/arm_common/conv_bias/block_helper.h"
#include "src/arm_common/conv_bias/int8x8x16/algos.h"
#include "src/arm_common/conv_bias/int8x8x16/direct_nchw_nchw44_kern.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/common/nchw_nchwxx_valid.h"
#include "src/common/opr_delegate.h"
#include "src/fallback/conv_bias/gi/block_helper.h"
#include "midout.h"
......
......@@ -174,191 +174,195 @@ __ai void store_ocx_ow4_remain_static(
StoreOcxOw4Remain<c_dim, ow_remain, Op, T>::impl(c, op, dst_ptr, ld_dst_oc);
}
////////////////////Store_OCX_OW8_Remain/////////////////////////
template <int c_dim, int ow_remain, typename Op, typename T, typename T2, typename T3>
template <
int c_dim, int ow_remain, typename Op, typename T, typename T2, typename T3,
size_t simd_lenx2>
struct StoreOcxOw8Remain {
static __ai void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc);
};
template <typename Op, typename T, typename T2, typename T3>
struct StoreOcxOw8Remain<2, 0, Op, T, T2, T3> {
template <
int c_dim, typename Op, typename T, typename T2, typename T3, size_t simd_lenx2>
struct StoreOcxOw8Remain<c_dim, 0, Op, T, T2, T3, simd_lenx2> {
static __ai void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) {
op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr));
op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8));
op({{c[0][4], c[0][5]}}, reinterpret_cast<T3>(dst_ptr + 16));
op({{c[0][6], c[0][7]}}, reinterpret_cast<T3>(dst_ptr + 24));
op({{c[1][0], c[1][1]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc));
op({{c[1][2], c[1][3]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 8));
op({{c[1][4], c[1][5]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 16));
op({{c[1][6], c[1][7]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 24));
MEGDNN_MARK_USED_VAR(c);
MEGDNN_MARK_USED_VAR(op);
MEGDNN_MARK_USED_VAR(dst_ptr);
MEGDNN_MARK_USED_VAR(ld_dst_oc);
}
};
template <typename Op, typename T, typename T2, typename T3>
struct StoreOcxOw8Remain<2, 8, Op, T, T2, T3> {
template <typename Op, typename T, typename T2, typename T3, size_t simd_lenx2>
struct StoreOcxOw8Remain<2, 8, Op, T, T2, T3, simd_lenx2> {
static __ai void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) {
op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr));
op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8));
op({{c[0][4], c[0][5]}}, reinterpret_cast<T3>(dst_ptr + 16));
op({{c[0][6], c[0][7]}}, reinterpret_cast<T3>(dst_ptr + 24));
op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + simd_lenx2));
op({{c[0][4], c[0][5]}}, reinterpret_cast<T3>(dst_ptr + simd_lenx2 * 2));
op({{c[0][6], c[0][7]}}, reinterpret_cast<T3>(dst_ptr + simd_lenx2 * 3));
op({{c[1][0], c[1][1]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc));
op({{c[1][2], c[1][3]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 8));
op({{c[1][4], c[1][5]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 16));
op({{c[1][6], c[1][7]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 24));
op({{c[1][2], c[1][3]}},
reinterpret_cast<T3>(dst_ptr + ld_dst_oc + simd_lenx2));
op({{c[1][4], c[1][5]}},
reinterpret_cast<T3>(dst_ptr + ld_dst_oc + simd_lenx2 * 2));
op({{c[1][6], c[1][7]}},
reinterpret_cast<T3>(dst_ptr + ld_dst_oc + simd_lenx2 * 3));
}
};
template <typename Op, typename T, typename T2, typename T3>
struct StoreOcxOw8Remain<2, 7, Op, T, T2, T3> {
template <typename Op, typename T, typename T2, typename T3, size_t simd_lenx2>
struct StoreOcxOw8Remain<2, 7, Op, T, T2, T3, simd_lenx2> {
static __ai void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) {
op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr));
op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8));
op({{c[0][4], c[0][5]}}, reinterpret_cast<T3>(dst_ptr + 16));
op(c[0][6], reinterpret_cast<T3>(dst_ptr + 24));
op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + simd_lenx2));
op({{c[0][4], c[0][5]}}, reinterpret_cast<T3>(dst_ptr + simd_lenx2 * 2));
op(c[0][6], reinterpret_cast<T3>(dst_ptr + simd_lenx2 * 3));
op({{c[1][0], c[1][1]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc));
op({{c[1][2], c[1][3]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 8));
op({{c[1][4], c[1][5]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 16));
op(c[1][6], reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 24));
op({{c[1][2], c[1][3]}},
reinterpret_cast<T3>(dst_ptr + ld_dst_oc + simd_lenx2));
op({{c[1][4], c[1][5]}},
reinterpret_cast<T3>(dst_ptr + ld_dst_oc + simd_lenx2 * 2));
op(c[1][6], reinterpret_cast<T3>(dst_ptr + ld_dst_oc + simd_lenx2 * 3));
}
};
template <typename Op, typename T, typename T2, typename T3>
struct StoreOcxOw8Remain<2, 6, Op, T, T2, T3> {
template <typename Op, typename T, typename T2, typename T3, size_t simd_lenx2>
struct StoreOcxOw8Remain<2, 6, Op, T, T2, T3, simd_lenx2> {
static __ai void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) {
op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr));
op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8));
op({{c[0][4], c[0][5]}}, reinterpret_cast<T3>(dst_ptr + 16));
op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + simd_lenx2));
op({{c[0][4], c[0][5]}}, reinterpret_cast<T3>(dst_ptr + simd_lenx2 * 2));
op({{c[1][0], c[1][1]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc));
op({{c[1][2], c[1][3]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 8));
op({{c[1][4], c[1][5]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 16));
op({{c[1][2], c[1][3]}},
reinterpret_cast<T3>(dst_ptr + ld_dst_oc + simd_lenx2));
op({{c[1][4], c[1][5]}},
reinterpret_cast<T3>(dst_ptr + ld_dst_oc + simd_lenx2 * 2));
}
};
template <typename Op, typename T, typename T2, typename T3>
struct StoreOcxOw8Remain<2, 5, Op, T, T2, T3> {
template <typename Op, typename T, typename T2, typename T3, size_t simd_lenx2>
struct StoreOcxOw8Remain<2, 5, Op, T, T2, T3, simd_lenx2> {
static __ai void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) {
op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr));
op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8));
op(c[0][4], reinterpret_cast<T3>(dst_ptr + 16));
op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + simd_lenx2));
op(c[0][4], reinterpret_cast<T3>(dst_ptr + simd_lenx2 * 2));
op({{c[1][0], c[1][1]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc));
op({{c[1][2], c[1][3]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 8));
op(c[1][4], reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 16));
op({{c[1][2], c[1][3]}},
reinterpret_cast<T3>(dst_ptr + ld_dst_oc + simd_lenx2));
op(c[1][4], reinterpret_cast<T3>(dst_ptr + ld_dst_oc + simd_lenx2 * 2));
}
};
template <typename Op, typename T, typename T2, typename T3>
struct StoreOcxOw8Remain<2, 4, Op, T, T2, T3> {
template <typename Op, typename T, typename T2, typename T3, size_t simd_lenx2>
struct StoreOcxOw8Remain<2, 4, Op, T, T2, T3, simd_lenx2> {
static __ai void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) {
op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr));
op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8));
op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + simd_lenx2));
op({{c[1][0], c[1][1]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc));
op({{c[1][2], c[1][3]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 8));
op({{c[1][2], c[1][3]}},
reinterpret_cast<T3>(dst_ptr + ld_dst_oc + simd_lenx2));
}
};
template <typename Op, typename T, typename T2, typename T3>
struct StoreOcxOw8Remain<2, 3, Op, T, T2, T3> {
template <typename Op, typename T, typename T2, typename T3, size_t simd_lenx2>
struct StoreOcxOw8Remain<2, 3, Op, T, T2, T3, simd_lenx2> {
static __ai void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) {
op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr));
op(c[0][2], reinterpret_cast<T3>(dst_ptr + 8));
op(c[0][2], reinterpret_cast<T3>(dst_ptr + simd_lenx2));
op({{c[1][0], c[1][1]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc));
op(c[1][2], reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 8));
op(c[1][2], reinterpret_cast<T3>(dst_ptr + ld_dst_oc + simd_lenx2));
}
};
template <typename Op, typename T, typename T2, typename T3>
struct StoreOcxOw8Remain<2, 2, Op, T, T2, T3> {
template <typename Op, typename T, typename T2, typename T3, size_t simd_lenx2>
struct StoreOcxOw8Remain<2, 2, Op, T, T2, T3, simd_lenx2> {
static __ai void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) {
op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr));
op({{c[1][0], c[1][1]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc));
}
};
template <typename Op, typename T, typename T2, typename T3>
struct StoreOcxOw8Remain<2, 1, Op, T, T2, T3> {
template <typename Op, typename T, typename T2, typename T3, size_t simd_lenx2>
struct StoreOcxOw8Remain<2, 1, Op, T, T2, T3, simd_lenx2> {
static __ai void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) {
op(c[0][0], reinterpret_cast<T3>(dst_ptr));
op(c[1][0], reinterpret_cast<T3>(dst_ptr + ld_dst_oc));
}
};
template <typename Op, typename T, typename T2, typename T3>
struct StoreOcxOw8Remain<1, 0, Op, T, T2, T3> {
static __ai void impl(T& c, const Op& op, T2 dst_ptr, int) {
op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr));
op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8));
op({{c[0][4], c[0][5]}}, reinterpret_cast<T3>(dst_ptr + 16));
op({{c[0][6], c[0][7]}}, reinterpret_cast<T3>(dst_ptr + 24));
}
};
template <typename Op, typename T, typename T2, typename T3>
struct StoreOcxOw8Remain<1, 8, Op, T, T2, T3> {
template <typename Op, typename T, typename T2, typename T3, size_t simd_lenx2>
struct StoreOcxOw8Remain<1, 8, Op, T, T2, T3, simd_lenx2> {
static __ai void impl(T& c, const Op& op, T2 dst_ptr, int) {
op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr));
op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8));
op({{c[0][4], c[0][5]}}, reinterpret_cast<T3>(dst_ptr + 16));
op({{c[0][6], c[0][7]}}, reinterpret_cast<T3>(dst_ptr + 24));
op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + simd_lenx2));
op({{c[0][4], c[0][5]}}, reinterpret_cast<T3>(dst_ptr + simd_lenx2 * 2));
op({{c[0][6], c[0][7]}}, reinterpret_cast<T3>(dst_ptr + simd_lenx2 * 3));
}
};
template <typename Op, typename T, typename T2, typename T3>
struct StoreOcxOw8Remain<1, 7, Op, T, T2, T3> {
template <typename Op, typename T, typename T2, typename T3, size_t simd_lenx2>
struct StoreOcxOw8Remain<1, 7, Op, T, T2, T3, simd_lenx2> {
static __ai void impl(T& c, const Op& op, T2 dst_ptr, int) {
op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr));
op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8));
op({{c[0][4], c[0][5]}}, reinterpret_cast<T3>(dst_ptr + 16));
op(c[0][6], reinterpret_cast<T3>(dst_ptr + 24));
op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + simd_lenx2));
op({{c[0][4], c[0][5]}}, reinterpret_cast<T3>(dst_ptr + simd_lenx2 * 2));
op(c[0][6], reinterpret_cast<T3>(dst_ptr + simd_lenx2 * 3));
}
};
template <typename Op, typename T, typename T2, typename T3>
struct StoreOcxOw8Remain<1, 6, Op, T, T2, T3> {
template <typename Op, typename T, typename T2, typename T3, size_t simd_lenx2>
struct StoreOcxOw8Remain<1, 6, Op, T, T2, T3, simd_lenx2> {
static __ai void impl(T& c, const Op& op, T2 dst_ptr, int) {
op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr));
op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8));
op({{c[0][4], c[0][5]}}, reinterpret_cast<T3>(dst_ptr + 16));
op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + simd_lenx2));
op({{c[0][4], c[0][5]}}, reinterpret_cast<T3>(dst_ptr + simd_lenx2 * 2));
}
};
template <typename Op, typename T, typename T2, typename T3>
struct StoreOcxOw8Remain<1, 5, Op, T, T2, T3> {
template <typename Op, typename T, typename T2, typename T3, size_t simd_lenx2>
struct StoreOcxOw8Remain<1, 5, Op, T, T2, T3, simd_lenx2> {
static __ai void impl(T& c, const Op& op, T2 dst_ptr, int) {
op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr));
op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8));
op(c[0][4], reinterpret_cast<T3>(dst_ptr + 16));
op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + simd_lenx2));
op(c[0][4], reinterpret_cast<T3>(dst_ptr + simd_lenx2 * 2));
}
};
template <typename Op, typename T, typename T2, typename T3>
struct StoreOcxOw8Remain<1, 4, Op, T, T2, T3> {
template <typename Op, typename T, typename T2, typename T3, size_t simd_lenx2>
struct StoreOcxOw8Remain<1, 4, Op, T, T2, T3, simd_lenx2> {
static __ai void impl(T& c, const Op& op, T2 dst_ptr, int) {
op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr));
op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8));
op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + simd_lenx2));
}
};
template <typename Op, typename T, typename T2, typename T3>
struct StoreOcxOw8Remain<1, 3, Op, T, T2, T3> {
template <typename Op, typename T, typename T2, typename T3, size_t simd_lenx2>
struct StoreOcxOw8Remain<1, 3, Op, T, T2, T3, simd_lenx2> {
static __ai void impl(T& c, const Op& op, T2 dst_ptr, int) {
op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr));
op(c[0][2], reinterpret_cast<T3>(dst_ptr + 8));
op(c[0][2], reinterpret_cast<T3>(dst_ptr + simd_lenx2));
}
};
template <typename Op, typename T, typename T2, typename T3>
struct StoreOcxOw8Remain<1, 2, Op, T, T2, T3> {
template <typename Op, typename T, typename T2, typename T3, size_t simd_lenx2>
struct StoreOcxOw8Remain<1, 2, Op, T, T2, T3, simd_lenx2> {
static __ai void impl(T& c, const Op& op, T2 dst_ptr, int) {
op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr));
}
};
template <typename Op, typename T, typename T2, typename T3>
struct StoreOcxOw8Remain<1, 1, Op, T, T2, T3> {
template <typename Op, typename T, typename T2, typename T3, size_t simd_lenx2>
struct StoreOcxOw8Remain<1, 1, Op, T, T2, T3, simd_lenx2> {
static __ai void impl(T& c, const Op& op, T2 dst_ptr, int) {
op(c[0][0], reinterpret_cast<T3>(dst_ptr));
}
};
template <int c_dim, int ow_remain, typename Op, typename T, typename T2>
template <
int c_dim, int ow_remain, typename Op, size_t simd_lenx2 = 8, typename T,
typename T2>
__ai void store_ocx_ow8_remain_static(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) {
StoreOcxOw8Remain<c_dim, ow_remain, Op, T, T2, T2>::impl(c, op, dst_ptr, ld_dst_oc);
StoreOcxOw8Remain<c_dim, ow_remain, Op, T, T2, T2, simd_lenx2>::impl(
c, op, dst_ptr, ld_dst_oc);
}
template <int c_dim, int ow_remain, typename Op, typename T3, typename T, typename T2>
template <
int c_dim, int ow_remain, typename Op, typename T3, size_t simd_lenx2 = 8,
typename T, typename T2>
__ai void store_ocx_ow8_remain_static_dt(
T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) {
StoreOcxOw8Remain<c_dim, ow_remain, Op, T, T2, T3>::impl(c, op, dst_ptr, ld_dst_oc);
StoreOcxOw8Remain<c_dim, ow_remain, Op, T, T2, T3, simd_lenx2>::impl(
c, op, dst_ptr, ld_dst_oc);
}
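// Editor's note: simd_lenx2 is the element stride between consecutive
// register pairs written by op. It defaults to 8 (two float32x4 vectors) so
// existing fp32 callers stay unchanged; the fp16 nchw-nchw88 kernel passes 16
// (two float16x8 vectors) so that stores advance by 16 halves.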
////////////////////Store_OCX_OW8_Remain/////////////////////////
template <
......@@ -611,6 +615,15 @@ __ai int32x4_t neon_vld1q(const int* ptr) {
__ai int16x8_t neon_vld1q(const int16_t* ptr) {
return vld1q_s16(ptr);
}
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
__ai float16x8_t neon_vdupq_n(__fp16 val) {
return vdupq_n_f16(val);
}
__ai float16x8_t neon_vld1q(const __fp16* ptr) {
return vld1q_f16(ptr);
}
#endif
template <typename T>
struct NeonLdqSimd;
template <>
......
......@@ -67,6 +67,7 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj {
AlgoF16DirectStride1 f16_direct_stride1;
AlgoF16ChannelWiseNCHW88 f16_channel_wise_nchw88;
AlgoF16DirectNCHW88 f16_direct_nchw88;
AlgoF16DirectNchwNchw88 f16_direct_nchw_nchw88;
#endif
SmallVector<std::unique_ptr<AlgoBase>> refhold;
......@@ -104,6 +105,7 @@ public:
m_direct_algos.emplace_back(&f16_direct);
m_direct_algos.emplace_back(&f16_channel_wise_nchw88);
m_direct_algos.emplace_back(&f16_direct_nchw88);
m_direct_algos.emplace_back(&f16_direct_nchw_nchw88);
#endif
m_direct_algos.emplace_back(&i8x8x16_direct);
m_direct_algos.emplace_back(&i8x8x16_stride2_filter2);
......
......@@ -73,6 +73,7 @@ private:
class AlgoF16DirectStride1;
class AlgoF16ChannelWiseNCHW88;
class AlgoF16DirectNCHW88;
class AlgoF16DirectNchwNchw88;
#endif
class AlgoPack;
......
......@@ -31,6 +31,11 @@ struct Vld1q_f32 {
struct Vld1_s8 {
static __ai int8x8_t impl(const int8_t* ptr) { return vld1_s8(ptr); }
};
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
struct Vld1q_f16 {
static __ai float16x8_t impl(const __fp16* ptr) { return vld1q_f16(ptr); }
};
#endif
struct Vldq_dup_4s8_8s16 {
static __ai int16x8_t impl(const int8_t* ptr) { return vldq_dup_4s8_8s16(ptr); }
};
......
......@@ -8,12 +8,15 @@ using NchwNchwxxFuncInterface = std::function<bool(
const ConvolutionBase<param::Convolution>::CanonizedFilterMeta& fm,
const ConvBiasForward::BiasMode bias_mode,
const param::ConvBias::NonlineMode nonline_mode)>;
static SmallVector<NchwNchwxxFuncInterface> g_func_vec{
nchw_nchwxx_valid<NchwNchwxxType::NCHW44_FP32>,
nchw_nchwxx_valid<NchwNchwxxType::NCHW44_INT8>,
nchw_nchwxx_valid<NchwNchwxxType::NCHW44_INT8_INT8_INT16>,
nchw_nchwxx_valid<NchwNchwxxType::NCHW44_INT8_DOT>,
nchw_nchwxx_valid<NchwNchwxxType::NCHW88>,
static SmallVector<NchwNchwxxFuncInterface> g_func_vec {
nchw_nchwxx_valid<NchwNchwxxType::NCHW44_FP32>,
nchw_nchwxx_valid<NchwNchwxxType::NCHW44_INT8>,
nchw_nchwxx_valid<NchwNchwxxType::NCHW44_INT8_INT8_INT16>,
nchw_nchwxx_valid<NchwNchwxxType::NCHW44_INT8_DOT>,
nchw_nchwxx_valid<NchwNchwxxType::NCHW88>,
#if !MEGDNN_DISABLE_FLOAT16
nchw_nchwxx_valid<NchwNchwxxType::NCHW88_FP16>,
#endif
};
} // namespace
bool ConvBiasForward::is_nchw_nchwxx_optimized(
......
......@@ -9,6 +9,9 @@ enum NchwNchwxxType {
NCHW44_INT8_INT8_INT16,
NCHW44_INT8_DOT,
NCHW88,
#if !MEGDNN_DISABLE_FLOAT16
NCHW88_FP16,
#endif
};
template <NchwNchwxxType T>
static inline bool nchw_nchwxx_valid(
......@@ -139,6 +142,35 @@ inline bool nchw_nchwxx_valid<NCHW88>(
bool available = ok_type && ok_src_dst && ok_slide && ok_conv;
return available;
}
#if !MEGDNN_DISABLE_FLOAT16
template <>
inline bool nchw_nchwxx_valid<NCHW88_FP16>(
const DTypeEnum src_dtype, const DTypeEnum filter_dtype,
const DTypeEnum dst_dtype,
const ConvolutionBase<param::Convolution>::CanonizedFilterMeta& fm,
const BiasMode bias_mode, const param::ConvBias::NonlineMode nonline_mode) {
bool ok_type =
((src_dtype == DTypeEnum::Float16 && filter_dtype == DTypeEnum::Float16 &&
(dst_dtype == DTypeEnum::Float16))) &&
(fm.format == param::Convolution::Format::NCHW88);
bool ok_nonline = nonline_mode == param::ConvBias::NonlineMode::IDENTITY ||
nonline_mode == param::ConvBias::NonlineMode::RELU ||
nonline_mode == param::ConvBias::NonlineMode::H_SWISH;
bool ok_src_dst =
fm.icpg < 8 && (fm.ocpg % 8 == 0 && fm.ocpg >= 8) && fm.group == 1;
bool ok_filter = fm.spatial_ndim == 2 && fm.spatial[0] == fm.spatial[1] &&
(fm.spatial[0] == 2 || fm.spatial[0] == 3 || fm.spatial[0] == 5 ||
fm.spatial[0] == 7);
bool ok_slide = fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
fm.stride[0] == fm.stride[1] &&
(fm.stride[0] == 1 || fm.stride[0] == 2);
bool ok_conv = !fm.should_flip && bias_mode != BiasMode::BIAS;
bool available =
ok_type && ok_nonline && ok_src_dst && ok_filter && ok_slide && ok_conv;
return available;
}
#endif
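// Usage illustration (editor's addition, shapes assumed): a typical first
// conv layer with fp16 NCHW input, icpg = 3 (< 8), ocpg = 32 (a multiple of
// 8), group = 1, a 3x3 filter, stride 1 or 2, no dilation, IDENTITY/RELU/
// H_SWISH nonlinearity and channel bias (or no bias) satisfies
// nchw_nchwxx_valid<NCHW88_FP16>; icpg >= 8 or full-tensor BIAS mode does not.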
} // namespace
} // namespace megdnn
\ No newline at end of file
......@@ -239,6 +239,23 @@
#define UNROLL_CALL_NOWRAPPER_D2(step, step2, cb) \
UNROLL_CALL_RAW_D2(step, step2, cb)
/////////////////// unroll call with one argument ///////////////////
//! Use these when the argument forwarded to cb must itself be expanded as a
//! macro inside cb. Reason: `##` before __VA_ARGS__ removes the ',' when no
//! variadic arguments are passed, but it also suppresses further expansion,
//! so the variadic form cannot nest macros.
//! Ref: https://stackoverflow.com/questions/5891221/variadic-macros-with-zero-arguments
#define UNROLL_ONE_ARG_RAW2(cb, arg) cb(0, arg) cb(1, arg)
#define UNROLL_ONE_ARG_RAW3(cb, arg) UNROLL_ONE_ARG_RAW2(cb, arg) cb(2, arg)
#define UNROLL_ONE_ARG_RAW5(cb, arg) \
UNROLL_ONE_ARG_RAW3(cb, arg) \
cb(3, arg) cb(4, arg)
#define UNROLL_ONE_ARG_RAW7(cb, arg) \
UNROLL_ONE_ARG_RAW5(cb, arg) \
cb(5, arg) cb(6, arg)
#define UNROLL_CALL_ONE_ARG_RAW(step, cb, arg) UNROLL_ONE_ARG_RAW##step(cb, arg)
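// A hypothetical usage sketch (editor's addition, names assumed):
//   #define LOAD_ROW(step, FS) load_row(step, (FS));
//   UNROLL_CALL_ONE_ARG_RAW(3, LOAD_ROW, FILTER_SIZE)
// expands to load_row(0, (FILTER_SIZE)); load_row(1, (FILTER_SIZE));
// load_row(2, (FILTER_SIZE)); and still works when FILTER_SIZE is itself a
// macro parameter, which the ##__VA_ARGS__ form cannot guarantee.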
// clang-format on
// vim: syntax=cpp.doxygen
......@@ -22,4 +22,4 @@ static inline int l2_block_helper(
} // namespace
} // namespace megdnn
// vim: syntax=cpp.doxygen
// vim: syntax=cpp.doxygen
\ No newline at end of file
......@@ -255,27 +255,13 @@ struct StoreOcxOw8Remain {
static GI_FORCEINLINE void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc);
};
template <typename Op, typename T, typename T2, typename T3>
struct StoreOcxOw8Remain<2, 0, Op, T, T2, T3> {
template <int c_dim, typename Op, typename T, typename T2, typename T3>
struct StoreOcxOw8Remain<c_dim, 0, Op, T, T2, T3> {
static GI_FORCEINLINE void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) {
ParamElemFixLenVisitor<typename Op::src_ctype> vis;
op(vis(c[0][0]), reinterpret_cast<T3>(dst_ptr));
op(vis(c[0][1]), reinterpret_cast<T3>(dst_ptr + 4));
op(vis(c[0][2]), reinterpret_cast<T3>(dst_ptr + 8));
op(vis(c[0][3]), reinterpret_cast<T3>(dst_ptr + 12));
op(vis(c[0][4]), reinterpret_cast<T3>(dst_ptr + 16));
op(vis(c[0][5]), reinterpret_cast<T3>(dst_ptr + 20));
op(vis(c[0][6]), reinterpret_cast<T3>(dst_ptr + 24));
op(vis(c[0][7]), reinterpret_cast<T3>(dst_ptr + 28));
op(vis(c[1][0]), reinterpret_cast<T3>(dst_ptr + ld_dst_oc));
op(vis(c[1][1]), reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 4));
op(vis(c[1][2]), reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 8));
op(vis(c[1][3]), reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 12));
op(vis(c[1][4]), reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 16));
op(vis(c[1][5]), reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 20));
op(vis(c[1][6]), reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 24));
op(vis(c[1][7]), reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 28));
MEGDNN_MARK_USED_VAR(c);
MEGDNN_MARK_USED_VAR(op);
MEGDNN_MARK_USED_VAR(dst_ptr);
MEGDNN_MARK_USED_VAR(ld_dst_oc);
}
};
template <typename Op, typename T, typename T2, typename T3>
......@@ -406,20 +392,6 @@ struct StoreOcxOw8Remain<2, 1, Op, T, T2, T3> {
}
};
template <typename Op, typename T, typename T2, typename T3>
struct StoreOcxOw8Remain<1, 0, Op, T, T2, T3> {
static GI_FORCEINLINE void impl(T& c, const Op& op, T2 dst_ptr, int) {
ParamElemFixLenVisitor<typename Op::src_ctype> vis;
op(vis(c[0][0]), reinterpret_cast<T3>(dst_ptr));
op(vis(c[0][1]), reinterpret_cast<T3>(dst_ptr + 4));
op(vis(c[0][2]), reinterpret_cast<T3>(dst_ptr + 8));
op(vis(c[0][3]), reinterpret_cast<T3>(dst_ptr + 12));
op(vis(c[0][4]), reinterpret_cast<T3>(dst_ptr + 16));
op(vis(c[0][5]), reinterpret_cast<T3>(dst_ptr + 20));
op(vis(c[0][6]), reinterpret_cast<T3>(dst_ptr + 24));
op(vis(c[0][7]), reinterpret_cast<T3>(dst_ptr + 28));
}
};
template <typename Op, typename T, typename T2, typename T3>
struct StoreOcxOw8Remain<1, 8, Op, T, T2, T3> {
static GI_FORCEINLINE void impl(T& c, const Op& op, T2 dst_ptr, int) {
......
......@@ -335,19 +335,23 @@ bool ConvBiasImpl::AlgoIm2col::usable(
return false;
}
if (format == param::ConvBias::Format::NCHW88) {
if (matmul_desc.packmode != Pack_Mode::DEFAULT) {
//! current NCHW88 im2col only supports DEFAULT-mode matmul
bool is_packmode_not_default = (matmul_desc.packmode != Pack_Mode::DEFAULT);
//! NCHW88 hybrid mode and channel-wise convolution are not supported
bool is_hybrid_mode_or_channel_wise =
(param.filter_meta.icpg < 8_z || param.filter_meta.ocpg == 1);
if (is_packmode_not_default || is_hybrid_mode_or_channel_wise) {
return false;
}
}
if (format == param::ConvBias::Format::NCHW44 ||
format == param::ConvBias::Format::NCHW44_DOT) {
//! current NCHW44 im2col only supports DEFAULT-mode matmul
if (matmul_desc.packmode != Pack_Mode::DEFAULT) {
return false;
//! NCHW44 hybrid mode and channel-wise convolution are not supported
} else if (
param.filter_meta.icpg < 4_z || param.filter_meta.icpg == 1 ||
param.filter_meta.ocpg == 1) {
bool is_packmode_not_default = (matmul_desc.packmode != Pack_Mode::DEFAULT);
//! NCHW44 hybrid mode and channel-wise convolution are not supported
bool is_hybrid_mode_or_channel_wise =
(param.filter_meta.icpg < 4_z || param.filter_meta.ocpg == 1);
if (is_packmode_not_default || is_hybrid_mode_or_channel_wise) {
return false;
}
}
......@@ -358,12 +362,11 @@ bool ConvBiasImpl::AlgoIm2col::usable(
}
if (format == param::ConvBias::Format::NCHW44) {
//! current NCHW44 im2col only supports DEFAULT-mode matmul
if (matmul_desc.packmode != Pack_Mode::DEFAULT) {
return false;
//! NCHW44 hybrid mode and channel-wise convolution are not supported
} else if (
param.filter_meta.icpg < 4_z || param.filter_meta.icpg == 1 ||
param.filter_meta.ocpg == 1) {
bool is_packmode_not_default = (matmul_desc.packmode != Pack_Mode::DEFAULT);
//! NCHW44 hybrid mode and channel-wise convolution are not supported
bool is_hybrid_mode_or_channel_wise =
(param.filter_meta.icpg < 4_z || param.filter_meta.ocpg == 1);
if (is_packmode_not_default || is_hybrid_mode_or_channel_wise) {
return false;
}
}
......@@ -411,15 +414,14 @@ bool ConvBiasImpl::AlgoIm2col::usable(
matmul_desc.innerblocksize.n, m_ohw_tile_size, matmul_desc.packmode);
fallback::MatrixMulImpl::KernSizeParam matmul_param =
get_matmul_kern_param(param, ohw_tile_size, oc_tile_size);
bool matmulusable = m_matmul_algo->usable(matmul_param);
return matmulusable &&
(!(param.filter_meta.spatial[0] == param.filter_meta.spatial[1] &&
return (!(param.filter_meta.spatial[0] == param.filter_meta.spatial[1] &&
param.filter_meta.spatial[0] == 1 &&
param.filter_meta.stride[0] == param.filter_meta.stride[1] &&
param.filter_meta.stride[0] == 1)) &&
(param.filter_meta.dilation[0] == param.filter_meta.dilation[1] &&
param.filter_meta.dilation[0] == 1) &&
param.compute_mode == param::ConvBias::ComputeMode::DEFAULT;
param.compute_mode == param::ConvBias::ComputeMode::DEFAULT &&
m_matmul_algo->usable(matmul_param);
}
MIDOUT_END();
return false;
......
......@@ -255,6 +255,7 @@ public:
ARM_COMMON_DIRECT_STRD1_FP16,
ARM_COMMON_CHWNWISE_NCHW88_F16,
ARM_COMMON_DIRECT_NCHW88_FP16,
ARM_COMMON_DIRECT_NCHW_NCHW88_FP16,
ARM_COMMON_DIRECT_STRD1_S8,
ARM_COMMON_DIRECT_STRD2_S8,
ARM_COMMON_DIRECT_NCHW44,
......
......@@ -11,65 +11,6 @@
namespace megdnn {
namespace test {
std::vector<conv_bias::TestArg> get_conv_bias_args(
std::vector<size_t> kernel, size_t stride) {
using namespace conv_bias;
using Param = param::ConvBias;
using NLMode = param::ConvBias::NonlineMode;
std::vector<TestArg> args;
auto pack = [&](size_t n, size_t oc, size_t ic, size_t w, size_t h, size_t kernel,
size_t stride, NLMode nonline_mode) {
Param param;
param.stride_h = stride;
param.stride_w = stride;
param.pad_h = kernel == 1 ? 0 : kernel / 2;
param.pad_w = kernel == 1 ? 0 : kernel / 2;
param.nonlineMode = nonline_mode;
//! no bias
args.emplace_back(
param, TensorShape{n, ic, h, w}, TensorShape{oc, ic, kernel, kernel},
TensorShape{});
//! bias broadcast channel
args.emplace_back(
param, TensorShape{n, ic, h, w}, TensorShape{oc, ic, kernel, kernel},
TensorShape{1, oc, 1, 1});
//! bias
args.emplace_back(
param, TensorShape{n, ic, h, w}, TensorShape{oc, ic, kernel, kernel},
TensorShape{
n, oc, (h + 2 * param.pad_h - kernel) / stride + 1,
(w + 2 * param.pad_h - kernel) / stride + 1});
};
for (auto nlmode : {NLMode::IDENTITY, NLMode::RELU, NLMode::SIGMOID}) {
for (size_t n : {1, 2}) {
for (size_t ic : {1, 2, 3, 4, 8}) {
for (size_t oc : {1, 2, 3, 4, 8}) {
for (size_t size : {1, 2, 3, 4, 8, 24}) {
for (size_t k : kernel) {
pack(n, oc, ic, size + 24, size + 24, k, stride, nlmode);
}
}
}
}
}
}
return args;
}
void checker_conv_bias(
std::vector<conv_bias::TestArg> args, Handle* handle, const char* algo_name) {
using namespace conv_bias;
Checker<ConvBias> checker(handle);
checker.set_before_exec_callback(
conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name));
for (auto&& arg : args) {
checker.set_param(arg.param).execs({arg.src, arg.filter, arg.bias, {}, {}});
}
}
TEST_F(AARCH64_MULTI_THREADS, CONVBIAS_DIRECT_FP32_STR2) {
check_conv_bias(
conv_bias::get_conv_bias_args({2, 3, 5, 7}, 2, false, false, false),
......
......@@ -378,6 +378,56 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16) {
get_conv_bias_args({1, 2, 3, 4, 5, 6, 7}, 1, false, false, false), handle(),
rng, "F16DIRECT", 0.03);
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16_NCHW_NCHW88) {
NormalRNG rng(1);
using namespace conv_bias;
std::vector<conv_bias::TestArg> args;
auto pack = [&](size_t oc, size_t ic, size_t ih, size_t iw, size_t n, size_t filter,
size_t stride, size_t pad, param::ConvBias::NonlineMode nlmode,
megdnn::BiasMode bias_mode) {
if (ih + 2 * pad < filter || iw + 2 * pad < filter)
return;
param::ConvBias param;
constexpr size_t pack_c = 8;
param.format = param::ConvBias::Format::NCHW88;
param.stride_h = stride;
param.stride_w = stride;
param.pad_h = pad;
param.pad_w = pad;
param.nonlineMode = nlmode;
auto src_tensor_shape = TensorShape{n, ic, ih, iw};
auto weight_tensor_shape = TensorShape{oc / pack_c, filter, filter, ic, pack_c};
auto bias_tensor_shape = TensorShape{};
if (bias_mode == megdnn::BiasMode::BROADCAST_CHANNEL_BIAS) {
bias_tensor_shape = {1, oc / pack_c, 1, 1, pack_c};
}
args.emplace_back(
param, src_tensor_shape, weight_tensor_shape, bias_tensor_shape);
};
for (size_t n : {1}) {
for (size_t oc : {8, 16}) {
for (size_t ic : {4}) {
for (size_t ih = 1; ih < 71; ih += 17) {
for (size_t filter : {2, 3, 5, 7}) {
for (size_t stride : {1, 2}) {
for (size_t pad : {0, 1}) {
for (auto nlmode : QUAN_NLMODE) {
for (auto bias_mode : BR_AND_BIAS_BIASMODE) {
pack(oc, ic, ih, ih, n, filter, stride, pad,
nlmode, bias_mode);
}
}
}
}
}
}
}
}
}
checker_conv_bias_f16(args, handle(), rng, "DIRECT_CONV_F16_NCHW_NCHW88", 0.03);
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16_STR1) {
NormalRNG rng(1);
checker_conv_bias_f16(
......
......@@ -1677,6 +1677,68 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_IM2COL_NCHW44_VS_N
algo_name_nchw44, data_type_fp32, RUNS, {1, {4}});
}
TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS,
BENCHMARK_CONVBIAS_DIRECT_HYBRID_NCHW44_VS_NCHW88) {
constexpr size_t RUNS = 50;
using NLMode = param::ConvBias::NonlineMode;
std::vector<conv_bias::TestArg> args_nchw88, args_nchw44;
auto bench_case = [&](size_t N, size_t IC, size_t OC, size_t H, size_t W,
size_t FS) {
param::ConvBias param_nchw88, param_nchw44;
param_nchw88.format = param::ConvBias::Format::NCHW88;
param_nchw44.format = param::ConvBias::Format::NCHW44;
for (size_t pad : {1}) {
for (size_t stride : {1, 2}) {
for (auto nlmode : {NLMode::RELU, NLMode::IDENTITY, NLMode::H_SWISH}) {
param_nchw88.nonlineMode = nlmode;
param_nchw88.pad_h = pad;
param_nchw88.pad_w = pad;
param_nchw88.stride_h = stride;
param_nchw88.stride_w = stride;
param_nchw44.nonlineMode = nlmode;
param_nchw44.pad_h = pad;
param_nchw44.pad_w = pad;
param_nchw44.stride_h = stride;
param_nchw44.stride_w = stride;
args_nchw88.emplace_back(
param_nchw88, TensorShape{N, IC, H, W},
TensorShape{OC / 8, FS, FS, IC, 8},
TensorShape{1, OC / 8, 1, 1, 8});
args_nchw44.emplace_back(
param_nchw44, TensorShape{N, IC, H, W},
TensorShape{OC / 4, FS, FS, IC, 4},
TensorShape{1, OC / 4, 1, 1, 4});
}
}
}
};
std::vector<DType> data_type_fp16 = {
dtype::Float16(), dtype::Float16(), dtype::Float16(), dtype::Float16()};
std::vector<DType> data_type_fp32 = {
dtype::Float32(), dtype::Float32(), dtype::Float32(), dtype::Float32()};
bench_case(1, 3, 32, 400, 400, 7);
bench_case(1, 3, 32, 300, 300, 5);
bench_case(1, 3, 32, 100, 100, 3);
bench_case(1, 3, 32, 80, 80, 2);
bench_case(1, 3, 64, 200, 200, 7);
bench_case(1, 3, 64, 128, 128, 5);
bench_case(1, 3, 64, 100, 100, 3);
bench_case(1, 3, 64, 80, 80, 2);
bench_case(1, 3, 128, 200, 200, 7);
bench_case(1, 3, 128, 128, 128, 5);
bench_case(1, 3, 128, 100, 100, 3);
bench_case(1, 3, 128, 80, 80, 2);
std::string algo_name_nchw88 = "DIRECT_CONV_F16_NCHW_NCHW88";
std::string algo_name_nchw44 = "F32_CONV_NCHW_NCHW44";
benchmark_with_contrast(
args_nchw88, algo_name_nchw88, data_type_fp16, args_nchw44,
algo_name_nchw44, data_type_fp32, RUNS, {1, {4}});
}
TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS,
BENCHMARK_CHANNEL_WISE_INT8_INT8_INT8_STRIDE1) {
constexpr size_t RUNS = 50;
......
......@@ -1163,7 +1163,7 @@ void benchmark_with_contrast(
{arg.src, data_type[0]}, {arg.filter, data_type[1]},
{arg.bias, data_type[2]}, {}, dst_layout);
float computation = (dst_layout.total_nr_elems() * arg.filter[1] *
arg.filter[2] * arg.filter[3] * arg.filter[4] * 2.0) /
arg.filter[2] * arg.filter[3] * 2.0) /
(1024 * 1024 * 1024) * 1e3;
benchmarker.set_param(arg.param);
auto used = benchmarker.exec({arg.src, arg.filter, arg.bias, {}, {}}) / RUNS;
......@@ -1176,8 +1176,7 @@ void benchmark_with_contrast(
{arg_contrast.bias, data_type_contrast[2]}, {}, dst_layout_contrast);
float computation_contrast =
(dst_layout_contrast.total_nr_elems() * arg_contrast.filter[1] *
arg_contrast.filter[2] * arg_contrast.filter[3] *
arg_contrast.filter[4] * 2.0) /
arg_contrast.filter[2] * arg_contrast.filter[3] * 2.0) /
(1024 * 1024 * 1024) * 1e3;
benchmarker_contrast.set_param(arg_contrast.param);
auto used_contrast = benchmarker_contrast.exec(
......