From ad91f35e2a3fb354ed4f5c93ee62e1493a2fc899 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 6 Feb 2023 15:10:12 +0800 Subject: [PATCH] feat(dnn/fallback): add gi fp16 nchw88 winograd f23 algo GitOrigin-RevId: ac15495535c78005b5fef729a40d6f8a7dcf9980 --- dnn/src/fallback/conv_bias/gi/fp16/algos.cpp | 38 ++ dnn/src/fallback/conv_bias/gi/fp16/algos.h | 18 + dnn/src/fallback/conv_bias/gi/fp16/strategy.h | 3 + .../gi/fp16/strategy_f23_mk8_nchw88.cpp | 334 ++++++++++++++++++ dnn/src/fallback/conv_bias/opr_impl.cpp | 5 + dnn/src/fallback/conv_bias/opr_impl.h | 2 + dnn/test/fallback/conv_bias.cpp | 50 +++ 7 files changed, 450 insertions(+) create mode 100644 dnn/src/fallback/conv_bias/gi/fp16/strategy_f23_mk8_nchw88.cpp diff --git a/dnn/src/fallback/conv_bias/gi/fp16/algos.cpp b/dnn/src/fallback/conv_bias/gi/fp16/algos.cpp index bcaaec51b..8d51bff6b 100644 --- a/dnn/src/fallback/conv_bias/gi/fp16/algos.cpp +++ b/dnn/src/fallback/conv_bias/gi/fp16/algos.cpp @@ -53,5 +53,43 @@ MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL( AlgoFP16WinogradF43_8x8_NCHW88, winograd::winograd_F43_mk8_f16_nchw88, megdnn_fallback_winograd_fp16_nchw88, param::MatrixMul::Format::MK8); +/* =================== AlgoFP16WinogradF23_8x8_NCHW88 ===================== */ + +bool ConvBiasImpl::AlgoFP16WinogradF23_8x8_NCHW88::usable( + const NCBKernSizeParam& param, + AlgoSelectionStrategy /*algo_selection_strategy*/) const { + MIDOUT_BEGIN( + megdnn_fallback_winograd_fp16_nchw88, + midout_iv("AlgoFP16WinogradF23_8x8_NCHW88"_hash)) { + if (param.filter_meta.icpg % 8 != 0 || param.filter_meta.ocpg % 8 != 0) + return false; + using Strategy = winograd::winograd_F23_mk8_f16_nchw88; + Strategy strategy(param.src_type, param.filter_type, param.dst_type); + auto&& matmul_param = + megdnn::winograd::ConvBias( + strategy, m_tile_size, param) + .get_matmul_kern_param(param); + return m_matmul_algo->usable(matmul_param) && + m_matmul_algo->packmode() == + fallback::MatrixMulImpl::AlgoBase::PackMode::NO_PACK && + param.filter_meta.format == param::ConvBias::Format::NCHW88 && + !param.filter_meta.should_flip && + (param.filter_meta.spatial[0] == param.filter_meta.spatial[1] && + param.filter_meta.spatial[0] == 3) && + (param.filter_meta.stride[0] == param.filter_meta.stride[1] && + param.filter_meta.stride[0] == 1) && + (param.filter_meta.dilation[0] == param.filter_meta.dilation[1] && + param.filter_meta.dilation[0] == 1) && + param.compute_mode == param::ConvBias::ComputeMode::DEFAULT && + param.src_type.enumv() == DTypeEnum::Float16; + } + MIDOUT_END(); + return false; +} + +MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL( + AlgoFP16WinogradF23_8x8_NCHW88, winograd::winograd_F23_mk8_f16_nchw88, + megdnn_fallback_winograd_fp16_nchw88, param::MatrixMul::Format::MK8); + #endif // vim: syntax=cpp.doxygen diff --git a/dnn/src/fallback/conv_bias/gi/fp16/algos.h b/dnn/src/fallback/conv_bias/gi/fp16/algos.h index 5eddadf5b..0045148f8 100644 --- a/dnn/src/fallback/conv_bias/gi/fp16/algos.h +++ b/dnn/src/fallback/conv_bias/gi/fp16/algos.h @@ -28,6 +28,24 @@ public: MEGDNN_DECL_ALGO_TYPE(GI_COMMON_WINOGRAD_F43_8X8_NCHW88_F16) }; +class ConvBiasImpl::AlgoFP16WinogradF23_8x8_NCHW88 final : public AlgoBase { +public: + AlgoFP16WinogradF23_8x8_NCHW88( + fallback::MatrixMulImpl::AlgoBase* matmul_algo, uint32_t tile_size) + : m_matmul_algo{matmul_algo}, m_tile_size{tile_size} {} + const char* name() const override { + if (m_name.empty()) { + m_name = ConvBiasImpl::algo_name( + m_matmul_algo->name(), {8, 2, m_tile_size, 3}, + param::ConvBias::Format::NCHW88); + } + 
return m_name.c_str(); + } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } + MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT16); + MEGDNN_DECL_ALGO_TYPE(GI_COMMON_WINOGRAD_F23_8X8_NCHW88_F16) +}; + } // namespace fallback } // namespace megdnn diff --git a/dnn/src/fallback/conv_bias/gi/fp16/strategy.h b/dnn/src/fallback/conv_bias/gi/fp16/strategy.h index 36f245a2a..20203bb91 100644 --- a/dnn/src/fallback/conv_bias/gi/fp16/strategy.h +++ b/dnn/src/fallback/conv_bias/gi/fp16/strategy.h @@ -14,6 +14,9 @@ namespace winograd { MEGDNN_REG_WINOGRAD_STRATEGY( dt_float16, dt_float16, dt_float16, dt_float16, 4, 3, 8, 8, winograd_F43_mk8_f16_nchw88) +MEGDNN_REG_WINOGRAD_STRATEGY( + dt_float16, dt_float16, dt_float16, dt_float16, 2, 3, 8, 8, + winograd_F23_mk8_f16_nchw88) } // namespace winograd } // namespace fallback } // namespace megdnn diff --git a/dnn/src/fallback/conv_bias/gi/fp16/strategy_f23_mk8_nchw88.cpp b/dnn/src/fallback/conv_bias/gi/fp16/strategy_f23_mk8_nchw88.cpp new file mode 100644 index 000000000..0da616ffb --- /dev/null +++ b/dnn/src/fallback/conv_bias/gi/fp16/strategy_f23_mk8_nchw88.cpp @@ -0,0 +1,334 @@ +#include "src/fallback/conv_bias/gi/fp16/strategy.h" + +#if defined(GI_SUPPORT_F16) + +#include "src/common/unroll_macro.h" +#include "src/common/utils.h" +#include "src/fallback/conv_bias/gi/fp16/helper.h" +#include "src/fallback/conv_bias/winograd/winograd.h" + +#include "src/fallback/elemwise_helper/op_unary.h" +#include "src/naive/matrix_mul/matrix_mul_helper.h" + +#include "midout.h" +MIDOUT_DECL(megdnn_fallback_winograd_nchw88_fp16_F23_mk8) + +using namespace megdnn; +using namespace fallback; +namespace { + +constexpr size_t alpha = 2 + 3 - 1; +constexpr size_t pack_size = 8; + +struct InputTransformF23_NCHW88 { + template + static void transform( + gi_float16_t* patchT, const gi_float16_t* input, + gi_float16_t* input_transform_buf, size_t ih_start, size_t iw_start, + size_t IH, size_t IW, size_t unit_idx, size_t nr_units_in_tile, size_t ic, + size_t IC) { + size_t IW8 = IW * pack_size; + size_t icb = ic / pack_size; + size_t iw8_start = iw_start * pack_size; + size_t ICB = IC / pack_size; + +#define cb(m, n) GI_FLOAT16_t d##m##n; + UNROLL_CALL_NOWRAPPER_D2(4, 4, cb); +#undef cb + + if (!(inner && ic + pack_size < IC)) { + memset(patchT, 0, sizeof(gi_float16_t) * pack_size * alpha * alpha); + } + if (inner) { + MEGDNN_MARK_USED_VAR(patchT); + const gi_float16_t* input_ptr = + input + icb * IH * IW8 + ih_start * IW8 + iw8_start; +#define cb(n, m) d##m##n = GiLoadFloat16(input_ptr + pack_size * n); + + UNROLL_CALL_RAW(4, cb, 0); + input_ptr += IW8; + UNROLL_CALL_RAW(4, cb, 1); + input_ptr += IW8; + UNROLL_CALL_RAW(4, cb, 2); + input_ptr += IW8; + UNROLL_CALL_RAW(4, cb, 3); +#undef cb + } else { + int ih0_act = std::max(ih_start, 0), + ih1_act = std::min(ih_start + alpha, IH), + iw0_act = std::max(iw_start, 0), + iw1_act = std::min(iw_start + alpha, IW); + const gi_float16_t* input_ptr = input + icb * IH * IW8; + // partial copy + for (int ih = ih0_act; ih < ih1_act; ++ih) { + for (int iw = iw0_act; iw < iw1_act; ++iw) { + size_t iho = ih - ih_start, iwo = iw - iw_start; + auto src = GiLoadFloat16(input_ptr + ih * IW8 + iw * pack_size); + GiStoreFloat16( + patchT + iho * alpha * pack_size + iwo * pack_size, src); + } + } +#define cb(m, n) \ + d##m##n = GiLoadFloat16(patchT + m * alpha * pack_size + n * pack_size); + UNROLL_CALL_NOWRAPPER_D2(4, 4, cb); +#undef cb + } + //! 1 0 -1 0 d00 d01 d02 d03 1 0 0 0 + //! 
0 1 1 0 d10 d11 d12 d13 0 1 -1 -1 + //! 0 -1 1 0 d20 d21 d22 d23 -1 1 1 0 + //! 0 -1 0 1 d30 d31 d32 d33 0 0 0 1 +#define cb(m) \ + auto t0##m = SUBF16(d0##m, d2##m); \ + auto t1##m = ADDF16(d1##m, d2##m); \ + auto t2##m = SUBF16(d2##m, d1##m); \ + auto t3##m = SUBF16(d3##m, d1##m); + + UNROLL_CALL_NOWRAPPER(4, cb); +#undef cb + +#define cb(m) \ + d##m##0 = SUBF16(t##m##0, t##m##2); \ + d##m##1 = ADDF16(t##m##1, t##m##2); \ + d##m##2 = SUBF16(t##m##2, t##m##1); \ + d##m##3 = SUBF16(t##m##3, t##m##1); + + UNROLL_CALL_NOWRAPPER(4, cb); +#undef cb + +#define cb(m, n) \ + GiStoreFloat16( \ + input_transform_buf + \ + (m * alpha + n) * ICB * nr_units_in_tile * pack_size + \ + icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \ + d##m##n); + UNROLL_CALL_NOWRAPPER_D2(4, 4, cb) +#undef cb + } +}; + +#define CONCAT(a, idx) a##idx +template +struct OutputTransformF23_NCHW88 { + static void transform( + const gi_float16_t* output_transform_buf, const gi_float16_t* bias, + gi_float16_t* output, gi_float16_t* transform_mid_buf, size_t oh_start, + size_t ow_start, size_t OH, size_t OW, size_t oc_start, size_t oc_end, + size_t oc_index, size_t unit_idx, size_t nr_units_in_tile, + const DType& src_dtype, const DType& dst_dtype) { + MEGDNN_MARK_USED_VAR(transform_mid_buf); + Op op(src_dtype, dst_dtype); + //! AT * m * A + size_t OCB = (oc_end - oc_start) / pack_size; + size_t oc = oc_start + oc_index; + size_t ocb = oc_index / pack_size; + +#define cb(m, n) \ + auto v##m##n = GiLoadFloat16( \ + output_transform_buf + \ + (m * alpha + n) * OCB * nr_units_in_tile * pack_size + \ + ocb * nr_units_in_tile * pack_size + unit_idx * pack_size); + UNROLL_CALL_NOWRAPPER_D2(4, 4, cb); +#undef cb + + //! 1 1 1 0 v00 v01 v02 v03 1 0 + //! 0 1 -1 1 v10 v11 v12 v13 1 1 + //! v20 v21 v22 v23 1 -1 + //! v30 v31 v32 v33 0 1 + +#define cb(m) \ + auto t0##m = ADDF16(ADDF16(v0##m, v1##m), v2##m); \ + auto t1##m = ADDF16(SUBF16(v1##m, v2##m), v3##m); + + UNROLL_CALL_NOWRAPPER(4, cb); +#undef cb + +#define cb(m) \ + v##m##0 = ADDF16(ADDF16(t##m##0, t##m##1), t##m##2); \ + v##m##1 = ADDF16(SUBF16(t##m##1, t##m##2), t##m##3); + + UNROLL_CALL_NOWRAPPER(2, cb); +#undef cb + + GI_FLOAT16_t vbias; + if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { + vbias = GiLoadFloat16(bias + oc); + +#define cb(m, n) v##m##n = ADDF16(v##m##n, vbias); + UNROLL_CALL_RAW_D2(2, 2, cb); +#undef cb + } + if (bmode != BiasMode::BIAS) { +#define cb(m, n) v##m##n = op(CONCAT(v##m, n)); + UNROLL_CALL_RAW_D2(2, 2, cb); +#undef cb + } +#define out_save(oho, owo) \ + do { \ + size_t oh = oh_start + oho; \ + size_t ow = ow_start + owo; \ + if (oh < OH && ow < OW) { \ + if (bmode == BiasMode::BIAS) { \ + v##oho##owo = ADDF16( \ + v##oho##owo, GiLoadFloat16( \ + bias + oc * OH * OW + \ + oh * OW * pack_size + ow * pack_size)); \ + v##oho##owo = op(v##oho##owo); \ + } \ + GiStoreFloat16( \ + output + oc * OH * OW + oh * OW * pack_size + ow * pack_size, \ + v##oho##owo); \ + } \ + } while (0); + UNROLL_CALL_RAW_D2(2, 2, out_save); +#undef out_save + } +}; +#undef CONCAT +} // namespace + +namespace megdnn { +namespace fallback { +namespace winograd { + +MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_F23_mk8_f16_nchw88) +void winograd_F23_mk8_f16_nchw88::filter( + const dt_float16* filter, dt_float16* filter_transform_buf, + dt_float16* transform_mid_buf, size_t OC, size_t IC, size_t oc_start, + size_t oc_end) { + //! 1 0 0 v00 v01 v02 1 0.5 0.5 0 + //! 0.5 0.5 0.5 v10 v11 v12 0 0.5 -0.5 0 + //! 0.5 -0.5 0.5 v20 v21 v22 0 0.5 0.5 1 + //! 
0 0 1 + + constexpr size_t pack_size = 8; + + MEGDNN_MARK_USED_VAR(transform_mid_buf); + megdnn_assert( + (oc_end - oc_start) % pack_size == 0 && oc_start % pack_size == 0 && + oc_end % pack_size == 0 && IC % pack_size == 0 && + OC % pack_size == 0, + "NCHW88 Winograd filter transform requires both OC and IC " + "are times of 8"); + size_t OCB = OC / pack_size; + size_t ICB = IC / pack_size; + + for (size_t ocb = oc_start / pack_size; ocb < oc_end / pack_size; ocb++) { + for (size_t icb = 0; icb < ICB; icb++) { + for (size_t ic_inner = 0; ic_inner < pack_size; ic_inner++) { + const gi_float16_t* fptr = + reinterpret_cast(filter) + + (ocb * ICB + icb) * KERNEL_SIZE * KERNEL_SIZE * pack_size * + pack_size + + ic_inner * pack_size; + +#define cb(m, n) \ + GI_FLOAT16_t g##m##n = \ + GiLoadFloat16(fptr + (m * KERNEL_SIZE + n) * pack_size * pack_size); + UNROLL_CALL_NOWRAPPER_D2(3, 3, cb) +#undef cb + +#define FILTER_TRANSFORM(n, wd, g) \ + auto wd##n##0 = g##0##n; \ + tmp00 = ADDF16(g##0##n, g##2##n); \ + tmp0 = MULSF16(tmp00, 0.5); \ + tmp1 = MULSF16(g##1##n, 0.5); \ + auto wd##n##1 = ADDF16(tmp0, tmp1); \ + auto wd##n##2 = SUBF16(tmp0, tmp1); \ + auto wd##n##3 = g##2##n; + GI_FLOAT16_t tmp0, tmp1, tmp00; + UNROLL_CALL_RAW(3, FILTER_TRANSFORM, wd, g); + UNROLL_CALL_RAW(4, FILTER_TRANSFORM, ret, wd); +#undef FILTER_TRANSFORM +#define cb_save(m, n) \ + GiStoreFloat16( \ + reinterpret_cast(filter_transform_buf) + \ + (m * ALPHA + n) * OCB * ICB * pack_size * pack_size + \ + ocb * ICB * pack_size * pack_size + icb * pack_size * pack_size + \ + ic_inner * pack_size, \ + ret##m##n); + UNROLL_CALL_NOWRAPPER_D2(4, 4, cb_save) +#undef cb_save + } + } + } +} + +void winograd_F23_mk8_f16_nchw88::input( + const dt_float16* input, dt_float16* input_transform_buf, + dt_float16* transform_mid_buf, size_t IH, size_t IW, size_t IC, size_t PH, + size_t PW, size_t unit_start_idx, size_t nr_units_in_tile) { + megdnn_assert(IC % IC_BLOCK_SIZE == 0); + constexpr int alpha = 3 + 2 - 1; + + // OW = IW + 2 * PW - KERNEL_SIZE + 1 + auto units_w = div_ceil(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE); + gi_float16_t* patchT = reinterpret_cast(transform_mid_buf) + + IC_BLOCK_SIZE * alpha * alpha; + + for (size_t ic = 0; ic < IC; ic += IC_BLOCK_SIZE) { + rep(unit_idx, nr_units_in_tile) { + size_t index = unit_start_idx + unit_idx; + size_t nh = index / units_w; + size_t nw = index % units_w; + int ih_start = nh * OUTPUT_BLOCK_SIZE - PH; + int iw_start = nw * OUTPUT_BLOCK_SIZE - PW; + if (ih_start >= 0 && ih_start + alpha <= static_cast(IH) && + iw_start >= 0 && iw_start + alpha <= static_cast(IW)) { + InputTransformF23_NCHW88::transform( + patchT, reinterpret_cast(input), + reinterpret_cast(input_transform_buf), ih_start, + iw_start, IH, IW, unit_idx, nr_units_in_tile, ic, IC); + } else { + InputTransformF23_NCHW88::transform( + patchT, reinterpret_cast(input), + reinterpret_cast(input_transform_buf), ih_start, + iw_start, IH, IW, unit_idx, nr_units_in_tile, ic, IC); + } + } + } +} + +void winograd_F23_mk8_f16_nchw88::output( + const dt_float16* output_transform_buf, const dt_float16* bias, + dt_float16* output, dt_float16* transform_mid_buf, BiasMode bmode, + NonlineMode nonline_mode, size_t OH, size_t OW, size_t oc_start, size_t oc_end, + size_t unit_start_idx, size_t nr_units_in_tile) { +#define cb(_bmode, _nonline_op, ...) 
\ + for (size_t oc = oc_start; oc < oc_end; oc += OC_BLOCK_SIZE) { \ + size_t oc_index = oc - oc_start; \ + rep(unit_idx, nr_units_in_tile) { \ + size_t index = unit_start_idx + unit_idx; \ + auto nh = index / units_w; \ + auto nw = index % units_w; \ + size_t oh_start = nh * OUTPUT_BLOCK_SIZE; \ + size_t ow_start = nw * OUTPUT_BLOCK_SIZE; \ + OutputTransformF23_NCHW88<_bmode, _nonline_op>::transform( \ + reinterpret_cast(output_transform_buf), \ + reinterpret_cast(bias), \ + reinterpret_cast(output), \ + reinterpret_cast(transform_mid_buf), oh_start, \ + ow_start, OH, OW, oc_start, oc_end, oc_index, unit_idx, \ + nr_units_in_tile, src_dtype, dst_dtype); \ + } \ + } + + auto units_w = div_ceil(OW, OUTPUT_BLOCK_SIZE); + constexpr size_t pack_size = 8; + + size_t OC = oc_end - oc_start; + megdnn_assert( + OC % pack_size == 0 && oc_start % pack_size == 0 && oc_end % pack_size == 0, + "NCHW88 Winograd filter transform requires OC is times of 8"); + + GI_DISPATCH_CONV_WINOGRAD_BIAS( + megdnn_fallback_winograd_nchw88_fp16_F23_mk8, cb, gi_float16_t, + gi_float16_t, bmode, nonline_mode); +#undef cb +} + +} // namespace winograd +} // namespace fallback +} // namespace megdnn + +#endif +// vim: syntax=cpp.doxygen diff --git a/dnn/src/fallback/conv_bias/opr_impl.cpp b/dnn/src/fallback/conv_bias/opr_impl.cpp index 99c0239e3..edc2af637 100644 --- a/dnn/src/fallback/conv_bias/opr_impl.cpp +++ b/dnn/src/fallback/conv_bias/opr_impl.cpp @@ -199,6 +199,11 @@ public: static_cast(algo), tile_size)); m_gi_winograd_algos.emplace_back(refhold.back().get()); + + refhold.emplace_back(new AlgoFP16WinogradF23_8x8_NCHW88( + static_cast(algo), + tile_size)); + m_gi_winograd_algos.emplace_back(refhold.back().get()); } } #endif diff --git a/dnn/src/fallback/conv_bias/opr_impl.h b/dnn/src/fallback/conv_bias/opr_impl.h index c12adb602..5c7bc5634 100644 --- a/dnn/src/fallback/conv_bias/opr_impl.h +++ b/dnn/src/fallback/conv_bias/opr_impl.h @@ -226,6 +226,7 @@ public: GI_COMMON_WINOGRAD_F43_4X4_NCHW44_F32, GI_COMMON_WINOGRAD_F63_4X4_NCHW44_F32, GI_COMMON_WINOGRAD_F73_4X4_NCHW44_F32, + GI_COMMON_WINOGRAD_F23_8X8_NCHW88_F16, GI_COMMON_WINOGRAD_F43_8X8_NCHW88_F16, GI_COMMON_DIRECT_FP32, GI_COMMON_DIRECT_STRD1_FP32, @@ -394,6 +395,7 @@ private: class AlgoFP32WinogradF63_4x4_NCHW44; class AlgoFP32WinogradF73_4x4_NCHW44; + class AlgoFP16WinogradF23_8x8_NCHW88; class AlgoFP16WinogradF43_8x8_NCHW88; class AlgoF32Direct; diff --git a/dnn/test/fallback/conv_bias.cpp b/dnn/test/fallback/conv_bias.cpp index ec7b60e21..5ca707a27 100644 --- a/dnn/test/fallback/conv_bias.cpp +++ b/dnn/test/fallback/conv_bias.cpp @@ -609,6 +609,18 @@ TEST_F(FALLBACK_MULTI_THREADS, CONVBIAS_GI_WINOGRAD_F43_4_NCHW44) { } #if defined(GI_SUPPORT_F16) +TEST_F(FALLBACK_MULTI_THREADS, CONVBIAS_GI_WINOGRAD_F23_8_NCHW88_FP16) { + using namespace conv_bias; + std::vector args = + get_nchw88_conv_bias_args({3}, FULL_NLMODE, BR_AND_NO_BIASMODE, 1); + + Checker> checker( + handle()); + Float16PeriodicalRNG rng(0x3c00); + check_winograd_fp16( + "8:2:", checker, args, &rng, 0.003, param::MatrixMul::Format::MK8, + "WINOGRAD_NCHW88"); +} TEST_F(FALLBACK_MULTI_THREADS, CONVBIAS_GI_WINOGRAD_F43_8_NCHW88_FP16) { using namespace conv_bias; @@ -1360,6 +1372,44 @@ TEST_F(FALLBACK, BENCHMARK_GI_CONVBIAS_WINOGRAD_F63_MK4_NCHW_VS_NCHW44) { } #if defined(GI_SUPPORT_F16) +TEST_F(FALLBACK, BENCHMARK_GI_WINOGRAD_F23_FP32_NCHW44_VS_FP16_NCHW88) { + auto&& args_fp16 = conv_bias::get_winograd_benchmark_args(3, 8, 8); + auto&& args_fp32 = conv_bias::get_winograd_benchmark_args(3, 4, 4); + 
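+ // The helper below estimates the cost of the equivalent direct convolution
+ // (output elements x ICB x FH x FW x 8 input channels per block, x2 for
+ // multiply + add), so that the fp16 NCHW88 and fp32 NCHW44 runs are
+ // contrasted against the same work estimate.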
+ auto cal_computation = [](const conv_bias::TestArg& arg) { + TensorShape dst_shape{ + arg.src[0], arg.filter[0], + (arg.src[2] + arg.param.pad_h * 2 - arg.filter[2]) / + arg.param.stride_h + + 1, + (arg.src[3] + arg.param.pad_w * 2 - arg.filter[3]) / + arg.param.stride_w + + 1, + arg.filter[5]}; + return dst_shape.total_nr_elems() * arg.filter[1] * arg.filter[2] * + arg.filter[3] * arg.filter[4] * 2.0 / (1024 * 1024 * 1024) * 1e3; + }; + + std::vector> args_with_computation_fp16, + args_with_computation_fp32; + for (const auto& arg : args_fp16) { + args_with_computation_fp16.emplace_back(arg, cal_computation(arg)); + } + for (const auto& arg : args_fp32) { + args_with_computation_fp32.emplace_back(arg, cal_computation(arg)); + } + + std::vector data_type_fp16 = { + dtype::Float16(), dtype::Float16(), dtype::Float16(), dtype::Float16()}; + std::vector data_type_fp32 = { + dtype::Float32(), dtype::Float32(), dtype::Float32(), dtype::Float32()}; + std::string algo_name_fp16 = "WINOGRAD_NCHW88:FB_GI_F16_MK8_8x8:8:2"; + std::string algo_name_fp32 = "WINOGRAD_NCHW44:FB_GI_F32_MK4_4x8:4:2"; + benchmark_with_contrast( + args_with_computation_fp16, algo_name_fp16, data_type_fp16, + args_with_computation_fp32, algo_name_fp32, data_type_fp32, 10, {1, {0}}); +} + TEST_F(FALLBACK, BENCHMARK_GI_WINOGRAD_F43_FP32_NCHW44_VS_FP16_NCHW88) { auto&& args_fp16 = conv_bias::get_winograd_benchmark_args(3, 8, 8); auto&& args_fp32 = conv_bias::get_winograd_benchmark_args(3, 4, 4); -- GitLab
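A standalone scalar sketch of the F(2x2, 3x3) algebra used above (illustrative only, not part of the patch): the kernels in strategy_f23_mk8_nchw88.cpp compute Y = AT((G g GT) .* (BT d B))A, with BT, G and AT exactly as spelled out in the //! comments. The plain C++ below — local illustrative names, no GI intrinsics — applies those matrices to one 4x4 input tile and checks the 2x2 result against a direct 3x3 correlation, the un-flipped filter convention that usable() enforces via !param.filter_meta.should_flip.

#include <cmath>
#include <cstdio>

// Row-major matrix product: c(m x n) = a(m x k) * b(k x n).
static void matmul(const float* a, const float* b, float* c, int m, int k, int n) {
    for (int i = 0; i < m; ++i)
        for (int j = 0; j < n; ++j) {
            float acc = 0.f;
            for (int p = 0; p < k; ++p)
                acc += a[i * k + p] * b[p * n + j];
            c[i * n + j] = acc;
        }
}

int main() {
    // F(2x2, 3x3) transform matrices, copied from the //! comments in
    // strategy_f23_mk8_nchw88.cpp (input, filter and output transforms).
    const float BT[16] = {1, 0, -1, 0,  0, 1, 1, 0,  0, -1, 1, 0,  0, -1, 0, 1};
    const float B[16]  = {1, 0, 0, 0,  0, 1, -1, -1,  -1, 1, 1, 0,  0, 0, 0, 1};
    const float G[12]  = {1, 0, 0,  0.5f, 0.5f, 0.5f,  0.5f, -0.5f, 0.5f,  0, 0, 1};
    const float GT[12] = {1, 0.5f, 0.5f, 0,  0, 0.5f, -0.5f, 0,  0, 0.5f, 0.5f, 1};
    const float AT[8]  = {1, 1, 1, 0,  0, 1, -1, 1};
    const float A[8]   = {1, 0,  1, 1,  1, -1,  0, 1};

    float d[16], g[9];  // one 4x4 input tile, one 3x3 kernel
    for (int i = 0; i < 16; ++i) d[i] = 0.1f * i - 0.3f;
    for (int i = 0; i < 9; ++i) g[i] = 0.05f * i + 0.2f;

    // U = G g GT (filter), V = BT d B (input), M = U .* V, y = AT M A (output).
    float t43[12], U[16], t44[16], V[16], M[16], t24[8], y[4];
    matmul(G, g, t43, 4, 3, 3);
    matmul(t43, GT, U, 4, 3, 4);
    matmul(BT, d, t44, 4, 4, 4);
    matmul(t44, B, V, 4, 4, 4);
    for (int i = 0; i < 16; ++i) M[i] = U[i] * V[i];
    matmul(AT, M, t24, 2, 4, 4);
    matmul(t24, A, y, 2, 4, 2);

    // Reference: direct 3x3 correlation (no filter flip) over the same tile.
    for (int oh = 0; oh < 2; ++oh)
        for (int ow = 0; ow < 2; ++ow) {
            float ref = 0.f;
            for (int fh = 0; fh < 3; ++fh)
                for (int fw = 0; fw < 3; ++fw)
                    ref += d[(oh + fh) * 4 + ow + fw] * g[fh * 3 + fw];
            std::printf("winograd=%.6f direct=%.6f |diff|=%g\n",
                        y[oh * 2 + ow], ref, std::fabs(y[oh * 2 + ow] - ref));
        }
    return 0;
}

The printed |diff| should be zero up to float rounding. In the patch itself the same algebra runs on GI fp16 vectors, one lane per channel of an NCHW88 pack, and the elementwise product M = U .* V generalizes to the MK8 matrix multiply that also reduces over input channels.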