diff --git a/dnn/include/megdnn/oprs/nn.h b/dnn/include/megdnn/oprs/nn.h index 205c5aac8bb75a3244c37101320235caccb1b0d9..f81aae5600036145cac4c59d8a1f45161da7a0f8 100644 --- a/dnn/include/megdnn/oprs/nn.h +++ b/dnn/include/megdnn/oprs/nn.h @@ -453,19 +453,21 @@ public: }; //! param for winograd algos. + /* Aggregate of four unsigned fields; operator== below compares all four of them. */ struct WinogradParam { uint32_t channel_block_size; uint32_t output_block_size; uint32_t tile_size; + uint32_t filter_size; /* fourth field, added by this change; the name() initializers updated below pass 3 or 5 for it */ bool operator==(const WinogradParam& rhs) const { return channel_block_size == rhs.channel_block_size && output_block_size == rhs.output_block_size && - tile_size == rhs.tile_size; + tile_size == rhs.tile_size && filter_size == rhs.filter_size; } std::string to_string() const; }; - static constexpr WinogradParam INVALID_WINOGRAD_PARAM = {0, 0, 0}; + static constexpr WinogradParam INVALID_WINOGRAD_PARAM = {0, 0, 0, 0}; struct DirectParam { std::string to_string() const { return ""; } diff --git a/dnn/src/arm_common/conv_bias/f16/algos.h b/dnn/src/arm_common/conv_bias/f16/algos.h index a9aa87916da5d33b3858f3bea52c2f0bf857f964..ee25b5a1abc415e9ba328cdda5fe80400a5730ba 100644 --- a/dnn/src/arm_common/conv_bias/f16/algos.h +++ b/dnn/src/arm_common/conv_bias/f16/algos.h @@ -14,7 +14,7 @@ public: const char* name() const override { if (m_name.empty()) { m_name = ConvBiasImpl::algo_name( - m_matmul_algo->name(), {1, 2, m_tile_size}); + m_matmul_algo->name(), {1, 2, m_tile_size, 3}); /* braces follow the WinogradParam field order: channel_block_size, output_block_size, tile_size, filter_size */ } return m_name.c_str(); } @@ -33,7 +33,7 @@ public: const char* name() const override { if (m_name.empty()) { m_name = ConvBiasImpl::algo_name( - m_matmul_algo->name(), {1, 4, m_tile_size}); + m_matmul_algo->name(), {1, 4, m_tile_size, 5}); } return m_name.c_str(); } @@ -51,7 +51,7 @@ public: const char* name() const override { if (m_name.empty()) { m_name = ConvBiasImpl::algo_name( - m_matmul_algo->name(), {1, 6, m_tile_size}); + m_matmul_algo->name(), {1, 6, m_tile_size, 3}); } return m_name.c_str(); } @@ -69,7 +69,7 @@ public: const char* name() const override { if
(m_name.empty()) { m_name = ConvBiasImpl::algo_name( - m_matmul_algo->name(), {8, 2, m_tile_size}); + m_matmul_algo->name(), {8, 2, m_tile_size, 3}); } return m_name.c_str(); } diff --git a/dnn/src/arm_common/conv_bias/int8/algos.h b/dnn/src/arm_common/conv_bias/int8/algos.h index 5c505953a4d5253218fb1b0e9bfe1bee5a40d4a8..7d8c19534bb3ce58251144e3d777bc16e2cceff0 100644 --- a/dnn/src/arm_common/conv_bias/int8/algos.h +++ b/dnn/src/arm_common/conv_bias/int8/algos.h @@ -221,7 +221,7 @@ public: const char* name() const override { if (m_name.empty()) { m_name = ConvBiasImpl::algo_name( - m_matmul_algo->name(), {8, 2, m_tile_size}); + m_matmul_algo->name(), {8, 2, m_tile_size, 3}); } return m_name.c_str(); } @@ -239,7 +239,7 @@ public: const char* name() const override { if (m_name.empty()) { m_name = ConvBiasImpl::algo_name( - m_matmul_algo->name(), {4, 2, m_tile_size}, + m_matmul_algo->name(), {4, 2, m_tile_size, 3}, param::ConvBias::Format::NCHW44); } return m_name.c_str(); @@ -258,7 +258,7 @@ public: const char* name() const override { if (m_name.empty()) { m_name = ConvBiasImpl::algo_name( - m_matmul_algo->name(), {8, 2, m_tile_size}, + m_matmul_algo->name(), {8, 2, m_tile_size, 3}, param::ConvBias::Format::NCHW44); } return m_name.c_str(); diff --git a/dnn/src/common/conv_bias.cpp b/dnn/src/common/conv_bias.cpp index a8b51a3a355f06fb60087e134cb0db144ce90d75..96031113e7ff634f46dcce7dab0bd4e14cd6dc30 100644 --- a/dnn/src/common/conv_bias.cpp +++ b/dnn/src/common/conv_bias.cpp @@ -176,7 +176,9 @@ template struct NCHW44ParamTrait; std::string ConvBias::WinogradParam::to_string() const { - return ssprintf("%u:%u:%u", channel_block_size, output_block_size, tile_size); + return ssprintf( /* now four colon-separated unsigned fields, filter_size last */ + "%u:%u:%u:%u", channel_block_size, output_block_size, tile_size, + filter_size); } template diff --git a/dnn/src/common/unroll_macro.h b/dnn/src/common/unroll_macro.h index da0f95bb37504a46796586df6f502a1f43b1da2b..41f88ff100b5a4de6b68b9e2eb12e4621f472b51 100644 ---
a/dnn/src/common/unroll_macro.h +++ b/dnn/src/common/unroll_macro.h @@ -165,6 +165,18 @@ cb(8, 0, ##a) cb(8, 1, ##a) cb(8, 2, ##a) cb(8, 3, ##a) \ cb(8, 4, ##a) cb(8, 5, ##a) cb(8, 6, ##a) cb(8, 7, ##a) cb(8, 8, ##a) +#define UNROLL_RAW_4x2(cb, v0, a...) \ + cb(0, 0, ##a) cb(0, 1, ##a) cb(1, 0, ##a) cb(1, 1, ##a) \ + cb(2, 0, ##a) cb(2, 1, ##a) cb(3, 0, ##a) cb(3, 1, ##a) + /* UNROLL_RAW_Nx2 (N = 4, 5, 6) expands to cb(i, j, args...) for i in [0, N) and j in {0, 1}, in row-major order; the 5x2 and 6x2 forms reuse the previous form and append one more row. v0 is forwarded between the forms but never appears in an expansion. */ +#define UNROLL_RAW_5x2(cb, v0, a...) \ + UNROLL_RAW_4x2(cb, v0, ##a) \ + cb(4, 0, ##a) cb(4, 1, ##a) + +#define UNROLL_RAW_6x2(cb, v0, a...) \ + UNROLL_RAW_5x2(cb, v0, ##a) \ + cb(5, 0, ##a) cb(5, 1, ##a) + /* the D2 dispatchers below select one of the RAW forms by pasting step and step2 into UNROLL_RAW_<step>x<step2> */ #define UNROLL_CALL0_D2(step, step2, cb, v...) \ UNROLL_RAW_##step##x##step2(cb, 0, ##v) #define UNROLL_CALL1_D2(step, step2, cb, v...) \ diff --git a/dnn/src/fallback/conv_bias/algos.h b/dnn/src/fallback/conv_bias/algos.h index c98f2dec608250007c7bf031243cb25a3aa869d9..221afc6a912d13821f4145224f252ca3ed6a485f 100644 --- a/dnn/src/fallback/conv_bias/algos.h +++ b/dnn/src/fallback/conv_bias/algos.h @@ -42,7 +42,7 @@ public: if (m_name.empty()) { m_name = ConvBiasImpl::algo_name( ssprintf("FALLBACK_WINOGRAD_F32-%s", m_matmul_algo->name()), - {1, 2, UNIT_TILE_SIZE}); + {1, 2, UNIT_TILE_SIZE, 3}); /* fourth initializer is the new fourth WinogradParam value */ } return m_name.c_str(); } @@ -74,7 +74,7 @@ public: if (m_name.empty()) { m_name = ConvBiasImpl::algo_name( ssprintf("FALLBACK_WINOGRAD_F32-%s", m_matmul_algo->name()), - {4, 2, UNIT_TILE_SIZE}); + {4, 2, UNIT_TILE_SIZE, 3}); } return m_name.c_str(); } @@ -106,7 +106,7 @@ public: if (m_name.empty()) { m_name = ConvBiasImpl::algo_name( ssprintf("FALLBACK_WINOGRAD_QS8-%s", m_matmul_algo->name()), - {1, 2, UNIT_TILE_SIZE}); + {1, 2, UNIT_TILE_SIZE, 3}); } return m_name.c_str(); } @@ -138,7 +138,7 @@ public: if (m_name.empty()) { m_name = ConvBiasImpl::algo_name( ssprintf("FALLBACK_WINOGRAD_QS8-%s", m_matmul_algo->name()), - {8, 2, UNIT_TILE_SIZE}); + {8, 2, UNIT_TILE_SIZE, 3}); } return m_name.c_str(); } diff --git a/dnn/src/fallback/conv_bias/gi/fp32/algos.cpp
b/dnn/src/fallback/conv_bias/gi/fp32/algos.cpp index 1b87e35526befa9f883cab6585e382c55be42007..5ae18365faf338d54d6608b45529011bbeef3358 100644 --- a/dnn/src/fallback/conv_bias/gi/fp32/algos.cpp +++ b/dnn/src/fallback/conv_bias/gi/fp32/algos.cpp @@ -84,6 +84,38 @@ MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL( AlgoFP32WinogradF63, winograd::winograd_6x3_1x1_f, megdnn_fallback_winograd_fp32, param::MatrixMul::Format::DEFAULT); +/* ======================= AlgoFP32WinogradF43 ======================== */ + /* Returns true only when the wrapped matmul algorithm accepts the matmul kernel param derived from the 4x3 strategy and the convolution is NCHW, not flipped, square 3x3, stride 1, dilation 1, DEFAULT compute mode, with a Float32 source. The selection-strategy argument is ignored. */ +bool ConvBiasImpl::AlgoFP32WinogradF43::usable( + const NCBKernSizeParam& param, + AlgoSelectionStrategy /*algo_selection_strategy*/) const { + MEGDNN_MARK_USED_VAR(param); + MIDOUT_BEGIN(megdnn_fallback_winograd_fp32, 5, 0) { + using Strategy = winograd::winograd_4x3_1x1_f; + Strategy strategy(param.src_type, param.filter_type, param.dst_type); + auto&& matmul_param = + megdnn::winograd::ConvBias(strategy, m_tile_size, param) + .get_matmul_kern_param(param); + return m_matmul_algo->usable(matmul_param) && + param.filter_meta.format == param::ConvBias::Format::NCHW && + !param.filter_meta.should_flip && + (param.filter_meta.spatial[0] == param.filter_meta.spatial[1] && + param.filter_meta.spatial[0] == 3) && + (param.filter_meta.stride[0] == param.filter_meta.stride[1] && + param.filter_meta.stride[0] == 1) && + (param.filter_meta.dilation[0] == param.filter_meta.dilation[1] && + param.filter_meta.dilation[0] == 1) && + param.compute_mode == param::ConvBias::ComputeMode::DEFAULT && + param.src_type.enumv() == DTypeEnum::Float32; + } + MIDOUT_END(); + return false; +} + /* NOTE(review): presumably this macro emits the remaining AlgoFP32WinogradF43 member definitions; its body is not visible here, confirm against its definition. */ +MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL( + AlgoFP32WinogradF43, winograd::winograd_4x3_1x1_f, + megdnn_fallback_winograd_fp32, param::MatrixMul::Format::DEFAULT); + /* ======================= AlgoFP32WinogradF54 ======================== */ bool ConvBiasImpl::AlgoFP32WinogradF54::usable( diff --git a/dnn/src/fallback/conv_bias/gi/fp32/algos.h b/dnn/src/fallback/conv_bias/gi/fp32/algos.h index
3f2ee46ca25710d5f2636b72ccadcff09b2530bb..cbba0fad2f9e8868f4a0813025e0345b8aace5dc 100644 --- a/dnn/src/fallback/conv_bias/gi/fp32/algos.h +++ b/dnn/src/fallback/conv_bias/gi/fp32/algos.h @@ -14,7 +14,7 @@ public: const char* name() const override { if (m_name.empty()) { m_name = ConvBiasImpl::algo_name( - m_matmul_algo->name(), {4, 2, m_tile_size}); + m_matmul_algo->name(), {4, 2, m_tile_size, 3}); } return m_name.c_str(); } @@ -31,7 +31,7 @@ public: const char* name() const override { if (m_name.empty()) { m_name = ConvBiasImpl::algo_name( - m_matmul_algo->name(), {1, 6, m_tile_size}); + m_matmul_algo->name(), {1, 6, m_tile_size, 3}); } return m_name.c_str(); } @@ -42,6 +42,28 @@ public: MEGDNN_DECL_ALGO_TYPE(GI_COMMON_WINOGRAD_F63_FP32) }; +class ConvBiasImpl::AlgoFP32WinogradF43 final : public AlgoBase { /* Winograd algorithm entry that wraps a matmul algorithm and a tile size; both are stored by the constructor below. */ +public: + AlgoFP32WinogradF43( + fallback::MatrixMulImpl::AlgoBase* matmul_algo, uint32_t tile_size) + : m_matmul_algo{matmul_algo}, m_tile_size{tile_size} {} + /* Builds the name on first use from the matmul algorithm name and the param {1, 4, m_tile_size, 3}, caches it in m_name and returns the cached C string. */ + const char* name() const override { + if (m_name.empty()) { + m_name = ConvBiasImpl::algo_name( + m_matmul_algo->name(), {1, 4, m_tile_size, 3}); + } + return m_name.c_str(); + } + /* Reports the REPRODUCIBLE and NAIVE attribute flags combined. */ + AlgoAttribute attribute() const override { + return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::NAIVE; + } + + MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); + MEGDNN_DECL_ALGO_TYPE(GI_COMMON_WINOGRAD_F43_FP32); +}; + class ConvBiasImpl::AlgoFP32WinogradF63_4x4 final : public AlgoBase { public: AlgoFP32WinogradF63_4x4( @@ -50,7 +72,7 @@ public: const char* name() const override { if (m_name.empty()) { m_name = ConvBiasImpl::algo_name( - m_matmul_algo->name(), {4, 6, m_tile_size}); + m_matmul_algo->name(), {4, 6, m_tile_size, 3}); } return m_name.c_str(); } @@ -67,7 +89,7 @@ public: const char* name() const override { if (m_name.empty()) { m_name = ConvBiasImpl::algo_name( - m_matmul_algo->name(), {1, 5, m_tile_size}); + m_matmul_algo->name(), {1, 5, m_tile_size, 4}); } return m_name.c_str(); } @@ -86,7 +108,7
@@ public: const char* name() const override { if (m_name.empty()) { m_name = ConvBiasImpl::algo_name( - m_matmul_algo->name(), {1, 4, m_tile_size}); + m_matmul_algo->name(), {1, 4, m_tile_size, 5}); } return m_name.c_str(); } @@ -106,7 +128,7 @@ public: const char* name() const override { if (m_name.empty()) { m_name = ConvBiasImpl::algo_name( - m_matmul_algo->name(), {4, 2, m_tile_size}, + m_matmul_algo->name(), {4, 2, m_tile_size, 3}, param::ConvBias::Format::NCHW44); } return m_name.c_str(); @@ -124,7 +146,7 @@ public: const char* name() const override { if (m_name.empty()) { m_name = ConvBiasImpl::algo_name( - m_matmul_algo->name(), {4, 6, m_tile_size}, + m_matmul_algo->name(), {4, 6, m_tile_size, 3}, param::ConvBias::Format::NCHW44); } return m_name.c_str(); @@ -142,7 +164,7 @@ public: const char* name() const override { if (m_name.empty()) { m_name = ConvBiasImpl::algo_name( - m_matmul_algo->name(), {4, 7, m_tile_size}, + m_matmul_algo->name(), {4, 7, m_tile_size, 3}, param::ConvBias::Format::NCHW44); } return m_name.c_str(); diff --git a/dnn/src/fallback/conv_bias/gi/fp32/filter_transform.h b/dnn/src/fallback/conv_bias/gi/fp32/filter_transform.h index 748718946f025dd2a2fdef6d6981964060a63413..501eabfe7bf6f8a9470dbc97c721ba52f522a5b4 100644 --- a/dnn/src/fallback/conv_bias/gi/fp32/filter_transform.h +++ b/dnn/src/fallback/conv_bias/gi/fp32/filter_transform.h @@ -155,6 +155,124 @@ struct FilterTransform6X3 { #undef FILTER_TRANSFORM #undef GET_VECTOR_ELEM +template <param::MatrixMul::Format format = param::MatrixMul::Format::DEFAULT> +struct FilterTransform4X3 { +#define FILTER_TRANSFORM(d, wd, ADDC, SUBC, MULC) \ + do { \ + wd##0 = MULC(d##0, 0.25f); \ + auto tmp0 = MULC(ADDC(d##0, d##2), -0.1666667f); \ + auto tmp1 = MULC(d##1, -0.1666667f); \ + wd##1 = ADDC(tmp0, tmp1); \ + wd##2 = SUBC(tmp0, tmp1); \ + tmp0 = ADDC(MULC(d##0, 0.0416667f), MULC(d##2, 0.1666667f)); \ + tmp1 = MULC(d##1, 0.0833333f); \ + wd##3 = ADDC(tmp0, tmp1); \ + wd##4 = SUBC(tmp0, tmp1); \ + wd##5 = d##2; \ + } while (0); + /* For every (oc, ic) with oc in [oc_start, oc_end) and ic in [0, IC): loads the 3x3 filter, applies FILTER_TRANSFORM along both dimensions (3 -> 6 each) and scatters the alpha x alpha result (alpha = 6) into filter_transform_buf, using the plain layout when format is DEFAULT and the 4x4-blocked (ocb, icb, ic4, oc4) layout otherwise. transform_mid_buf is scratch space. */ + static void transform( + const
float* filter, float* filter_transform_buf, float* transform_mid_buf, + size_t OC, size_t IC, size_t oc_start, size_t oc_end) { + constexpr size_t alpha = 4 + 3 - 1; + size_t OCB = OC / 4; + size_t ICB = IC / 4; + for (size_t oc = oc_start; oc < oc_end; oc++) { + rep(ic, IC) { + const float* fptr = filter + (oc * IC + ic) * 3 * 3; + + GI_FLOAT32_t g0 = GiLoadFloat32(fptr); + GI_FLOAT32_t g1 = GiLoadFloat32(fptr + 3); + /* the third row is loaded one element early (fptr + 5) and then shifted by one lane, so the four-float load does not start at fptr + 6 */ + GI_FLOAT32_t g2 = GiLoadFloat32(fptr + 6 - 1); + GI_FLOAT32_t zeros = GiZeroFloat32(); + g2 = GiExtqFloat32(g2, zeros, 1); + +#define cb(i) GI_FLOAT32_t wd##i = GiZeroFloat32(); +#if MEGDNN_AARCH64 + UNROLL_CALL_NOWRAPPER(8, cb); +#else + UNROLL_CALL_NOWRAPPER(6, cb); +#endif +#undef cb + + FILTER_TRANSFORM(g, wd, ADDF, SUBF, MULSF); + + size_t ocb = oc / 4; + size_t oc4 = oc % 4; + size_t icb = ic / 4; + size_t ic4 = ic % 4; +#if MEGDNN_AARCH64 + +#define cb(i) GI_FLOAT32_V2_t wdt##i; + UNROLL_CALL_NOWRAPPER(3, cb); +#undef cb + +#define cb(i) GI_FLOAT32_V2_t ret##i; + UNROLL_CALL_NOWRAPPER(6, cb); +#undef cb + + TRANSPOSE_8x3(wd, wdt); + FILTER_TRANSFORM(wdt, ret, ADDFV2, SUBFV2, MULSFV2); + +#define cb(i) GiStoreFloat32V2(transform_mid_buf + i * alpha, ret##i); + UNROLL_CALL_NOWRAPPER(6, cb); +#undef cb + rep(i, alpha) rep(j, alpha) { + if (format == param::MatrixMul::Format::DEFAULT) { + filter_transform_buf[(i * alpha + j) * OC * IC + ic * OC + oc] = + transform_mid_buf[j * alpha + i]; + } else { + filter_transform_buf + [(i * alpha + j) * OCB * ICB * 4 * 4 + + ocb * ICB * 4 * 4 + icb * 4 * 4 + ic4 * 4 + oc4] = + transform_mid_buf[j * alpha + i]; + } + } + +#else +#define cb(i) \ + do { \ + mid_buf1[0] = GET_VECTOR_ELEM(wd, i, 0) * 0.25f; \ + auto tmp0 = \ + (GET_VECTOR_ELEM(wd, i, 0) + GET_VECTOR_ELEM(wd, i, 2)) * -0.1666667f; \ + auto tmp1 = GET_VECTOR_ELEM(wd, i, 1) * -0.1666667f; \ + mid_buf1[1] = tmp0 + tmp1; \ + mid_buf1[2] = tmp0 - tmp1; \ + tmp0 = GET_VECTOR_ELEM(wd, i, 0) * 0.0416667f + \ + GET_VECTOR_ELEM(wd, i, 2) * 0.1666667f; \ +
tmp1 = GET_VECTOR_ELEM(wd, i, 1) * 0.0833333f; \ + mid_buf1[3] = tmp0 + tmp1; \ + mid_buf1[4] = tmp0 - tmp1; \ + mid_buf1[5] = GET_VECTOR_ELEM(wd, i, 2); \ + mid_buf1 += 6; \ + } while (0); +#define GET_VECTOR_ELEM(s, i, idx) GiExtractLane##idx##Float32(CONCAT(s, i)) + + float* mid_buf1 = transform_mid_buf; + UNROLL_CALL_NOWRAPPER(6, cb); + mid_buf1 = transform_mid_buf; +#undef cb + + rep(i, alpha) rep(j, alpha) { + if (format == param::MatrixMul::Format::DEFAULT) { + filter_transform_buf[(i * alpha + j) * OC * IC + ic * OC + oc] = + transform_mid_buf[i * alpha + j]; + } else { + filter_transform_buf + [(i * alpha + j) * OCB * ICB * 4 * 4 + + ocb * ICB * 4 * 4 + icb * 4 * 4 + ic4 * 4 + oc4] = + transform_mid_buf[i * alpha + j]; + } + } +#endif + } + } + } +}; +#undef FILTER_TRANSFORM +#undef GET_VECTOR_ELEM + } // namespace fallback } // namespace megdnn diff --git a/dnn/src/fallback/conv_bias/gi/fp32/helper.h b/dnn/src/fallback/conv_bias/gi/fp32/helper.h index 00ecce48110d3ad1a2d8dd24ce5a8aab6a74eb1a..979411f3554bdace5cf9bb19e57f8bb67adbcdeb 100644 --- a/dnn/src/fallback/conv_bias/gi/fp32/helper.h +++ b/dnn/src/fallback/conv_bias/gi/fp32/helper.h @@ -116,6 +116,46 @@ inline void transpose_4x4(const float* src, float* dst, int lda, int ldb) { GiReinterpretqFloat32ToS64(b7.val[1]))); \ } while (0); +#define TRANSPOSE_6x6(a, ret) \ + do { \ + auto b0 = GiZipqFloat32(CONCAT(a, 00), CONCAT(a, 10)); \ + auto b1 = GiZipqFloat32(CONCAT(a, 01), CONCAT(a, 11)); \ + auto b2 = GiZipqFloat32(CONCAT(a, 20), CONCAT(a, 30)); \ + auto b3 = GiZipqFloat32(CONCAT(a, 21), CONCAT(a, 31)); \ + auto b4 = GiZipqFloat32(CONCAT(a, 40), CONCAT(a, 50)); \ + auto b5 = GiZipqFloat32(CONCAT(a, 41), CONCAT(a, 51)); \ + CONCAT(ret, 00) = GiReinterpretqS64ToFloat32(GiZip1qS64( \ + GiReinterpretqFloat32ToS64(b0.val[0]), \ + GiReinterpretqFloat32ToS64(b2.val[0]))); \ + CONCAT(ret, 01) = b4.val[0]; \ + CONCAT(ret, 10) = GiReinterpretqS64ToFloat32(GiZip2qS64( \ + GiReinterpretqFloat32ToS64(b0.val[0]),
\ + GiReinterpretqFloat32ToS64(b2.val[0]))); \ + CONCAT(ret, 11) = GiReinterpretqS64ToFloat32(GiZip2qS64( \ + GiReinterpretqFloat32ToS64(b4.val[0]), \ + GiReinterpretqFloat32ToS64(b5.val[0]))); \ + CONCAT(ret, 20) = GiReinterpretqS64ToFloat32(GiZip1qS64( \ + GiReinterpretqFloat32ToS64(b0.val[1]), \ + GiReinterpretqFloat32ToS64(b2.val[1]))); \ + CONCAT(ret, 21) = b4.val[1]; \ + CONCAT(ret, 30) = GiReinterpretqS64ToFloat32(GiZip2qS64( \ + GiReinterpretqFloat32ToS64(b0.val[1]), \ + GiReinterpretqFloat32ToS64(b2.val[1]))); \ + CONCAT(ret, 31) = GiReinterpretqS64ToFloat32(GiZip2qS64( \ + GiReinterpretqFloat32ToS64(b4.val[1]), \ + GiReinterpretqFloat32ToS64(b5.val[1]))); \ + CONCAT(ret, 40) = GiReinterpretqS64ToFloat32(GiZip1qS64( \ + GiReinterpretqFloat32ToS64(b1.val[0]), \ + GiReinterpretqFloat32ToS64(b3.val[0]))); \ + CONCAT(ret, 41) = b5.val[0]; \ + CONCAT(ret, 50) = GiReinterpretqS64ToFloat32(GiZip2qS64( \ + GiReinterpretqFloat32ToS64(b1.val[0]), \ + GiReinterpretqFloat32ToS64(b3.val[0]))); \ + CONCAT(ret, 51) = GiReinterpretqS64ToFloat32(GiZip2qS64( \ + GiReinterpretqFloat32ToS64(b5.val[0]), \ + GiReinterpretqFloat32ToS64(b4.val[0]))); \ + } while (0); + /* TRANSPOSE_6x6 above: a and ret are name prefixes; row i is the register pair <prefix>i0 / <prefix>i1. Assuming the Gi zip helpers behave like the NEON vzipq / vzip1q / vzip2q operations (TODO confirm), <prefix>i0 carries columns 0-3 and the two low lanes of <prefix>i1 carry columns 4-5; the two high lanes of each ret i1 register are don't-care values. */ #define TRANSPOSE_8x3(a, ret) \ auto b0 = GiZipqFloat32(CONCAT(a, 0), CONCAT(a, 1)); \ auto b1 = GiZipqFloat32(CONCAT(a, 2), CONCAT(a, 3)); \ diff --git a/dnn/src/fallback/conv_bias/gi/fp32/strategy.h b/dnn/src/fallback/conv_bias/gi/fp32/strategy.h index dc48c6f4ad6f5475bd0b7ad9fc8e068158d02e65..fe09b269d935bf333abac3735e448b7802d2031b 100644 --- a/dnn/src/fallback/conv_bias/gi/fp32/strategy.h +++ b/dnn/src/fallback/conv_bias/gi/fp32/strategy.h @@ -12,6 +12,8 @@ MEGDNN_REG_WINOGRAD_STRATEGY( MEGDNN_REG_WINOGRAD_STRATEGY(float, float, float, float, 6, 3, 1, 1, winograd_6x3_1x1_f) +MEGDNN_REG_WINOGRAD_STRATEGY(float, float, float, float, 4, 3, 1, 1, winograd_4x3_1x1_f) + MEGDNN_REG_WINOGRAD_STRATEGY(float, float, float, float, 6, 3, 4, 4, winograd_6x3_4x4_f) MEGDNN_REG_WINOGRAD_STRATEGY(float, float, float, float, 5,
4, 1, 1, winograd_5x4_1x1_f) diff --git a/dnn/src/fallback/conv_bias/gi/fp32/strategy_4x3.cpp b/dnn/src/fallback/conv_bias/gi/fp32/strategy_4x3.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a05197231d88b3877b8bcd83b2cffe5786d7d0c4 --- /dev/null +++ b/dnn/src/fallback/conv_bias/gi/fp32/strategy_4x3.cpp @@ -0,0 +1,372 @@ +#include "src/common/unroll_macro.h" +#include "src/common/utils.h" +#include "src/fallback/conv_bias/gi/fp32/filter_transform.h" +#include "src/fallback/conv_bias/gi/fp32/helper.h" +#include "src/fallback/conv_bias/gi/fp32/strategy.h" +#include "src/fallback/conv_bias/winograd/winograd.h" +#include "src/fallback/elemwise_helper/op_unary.h" +#include "src/naive/matrix_mul/matrix_mul_helper.h" + +#include "midout.h" +MIDOUT_DECL(megdnn_fallback_winograd_fp32_F43) + +using namespace megdnn; +using namespace fallback; +namespace { + +/** + * input transform + * + * wd0 = 4 * (d0 - d2) - (d2 - d4) + * wd1 = -4 * (d1 + d2) + (d3 + d4) + * wd2 = 4 * (d1 - d2) + (d4 - d3) + * wd3 = 2 * (d3 - d1) - (d2 - d4) + * wd4 = -2 * (d3 - d1) - (d2 - d4) + * wd5 = -4 * (d3 - d1) + (d5 - d3) + */ + +#define INPUT_TRANSFORM(d, wd, i) \ + do { \ + auto tmp0 = SUBF(d##2##i, d##4##i); \ + auto tmp1 = SUBF(d##3##i, d##1##i); \ + wd##0##i = SUBF(MULSF(SUBF(d##0##i, d##2##i), 4.0f), tmp0); \ + wd##1##i = SUBF(ADDF(d##3##i, d##4##i), MULSF(ADDF(d##1##i, d##2##i), 4.0f)); \ + wd##2##i = ADDF(MULSF(SUBF(d##1##i, d##2##i), 4.0f), SUBF(d##4##i, d##3##i)); \ + wd##3##i = SUBF(MULSF(tmp1, 2.0f), tmp0); \ + wd##4##i = SUBF(MULSF(tmp1, -2.0f), tmp0); \ + wd##5##i = SUBF(SUBF(d##5##i, d##3##i), MULSF(tmp1, 4.0f)); \ + } while (0); + +#define INPUT_TRANSFORM_V2(d, wd) \ + INPUT_TRANSFORM(d, wd, 0); \ + INPUT_TRANSFORM(d, wd, 1); + /* A 6-wide row i is kept in two registers: s##i##0 holds columns 0-3 and s##i##1 holds columns 4-5 in its low lanes. */ +#define GET_VECTOR_HIGH_ELEM(s, i, idx) GiExtractLane##idx##Float32(s##i##1) +#define GET_VECTOR_LOW_ELEM(s, i, idx) GiExtractLane##idx##Float32(s##i##0) + +struct InputTransform4X3 { /* transform<inner>() handles one 6x6 input window for channel ic: inner == true reads the window straight from input; inner == false first copies the in-range part into the zero-filled transform_mid_buf, so positions outside the image read as 0. The transformed tile is written to input_transform_buf at (element, unit_idx, ic). */ + template <bool inner> + static void transform( + const
float* input, float* input_transform_buf, float* transform_mid_buf, + int ih_start, int iw_start, size_t ic, size_t IH, size_t IW, size_t IC, + size_t unit_idx, size_t nr_units_in_tile) { + constexpr size_t alpha = 4 + 3 - 1; + if (!inner) { + memset(transform_mid_buf, 0, sizeof(float) * alpha * alpha); + } + +#define cb(i, j) GI_FLOAT32_t d##i##j; + UNROLL_CALL_NOWRAPPER_D2(6, 2, cb); +#undef cb + if (inner) { + const float* input_ptr = input + ic * IH * IW + ih_start * IW + iw_start; +#define cb(i, j) d##i##j = GiLoadFloat32(input_ptr + IW * i + 4 * j); + UNROLL_CALL_NOWRAPPER_D2(5, 2, cb); +#undef cb + d50 = GiLoadFloat32(input_ptr + IW * 5); + d51 = GiLoadFloat32LowHalf(input_ptr + IW * 5 + 4); + } else { + int ih0_act = std::max(ih_start, 0), /* clamp in int: ih_start / iw_start can be negative, and adding the unsigned alpha directly would wrap for starts below -alpha and defeat the upper clamp */ + ih1_act = std::min(ih_start + static_cast<int>(alpha), static_cast<int>(IH)), + iw0_act = std::max(iw_start, 0), + iw1_act = std::min(iw_start + static_cast<int>(alpha), static_cast<int>(IW)); + for (int ih = ih0_act; ih < ih1_act; ++ih) { + for (int iw = iw0_act; iw < iw1_act; ++iw) { + size_t iho = ih - ih_start, iwo = iw - iw_start; + transform_mid_buf[iho * alpha + iwo] = + input[ic * IH * IW + ih * IW + iw]; + } + } +#define cb(i, j) d##i##j = GiLoadFloat32(transform_mid_buf + alpha * i + 4 * j); + UNROLL_CALL_NOWRAPPER_D2(5, 2, cb); +#undef cb + d50 = GiLoadFloat32(transform_mid_buf + alpha * 5); + d51 = GiLoadFloat32LowHalf(transform_mid_buf + alpha * 5 + 4); + } + +#define cb(i, j) GI_FLOAT32_t wd##i##j; + UNROLL_CALL_NOWRAPPER_D2(6, 2, cb); +#undef cb + + INPUT_TRANSFORM_V2(d, wd); + +#if MEGDNN_AARCH64 +#define cb(i, j) GI_FLOAT32_t ret##i##j; + UNROLL_CALL_NOWRAPPER_D2(6, 2, cb); +#undef cb + TRANSPOSE_6x6(wd, d); + INPUT_TRANSFORM_V2(d, ret); + +#define cb(i, j) GiStoreFloat32(transform_mid_buf + i * alpha + j * 4, ret##i##j); + UNROLL_CALL_NOWRAPPER_D2(5, 2, cb); +#undef cb + /* the last row is stored separately so that only two floats are written past offset 5 * alpha + 4 */ + GiStoreFloat32(transform_mid_buf + 5 * alpha, ret50); + float tmp[4]; + GiStoreFloat32(tmp, ret51); + memcpy(transform_mid_buf + 5 * alpha + 4, tmp, sizeof(float) * 2); + + rep(i, alpha) rep(j, alpha) {
+ input_transform_buf + [(i * alpha + j) * nr_units_in_tile * IC + unit_idx * IC + ic] = + transform_mid_buf[j * alpha + i]; + } +#else + //! 4 0 0 0 0 0 + //! 0 -4 4 -2 2 4 + //! -5 -4 -4 -1 -1 0 + //! 0 1 -1 2 -2 -5 + //! 1 1 1 1 1 0 + //! 0 0 0 0 0 1 +#define cb(i) \ + do { \ + auto tmp0 = GET_VECTOR_LOW_ELEM(wd, i, 2) - GET_VECTOR_HIGH_ELEM(wd, i, 0); \ + auto tmp1 = GET_VECTOR_LOW_ELEM(wd, i, 3) - GET_VECTOR_LOW_ELEM(wd, i, 1); \ + mid_buf1[0] = \ + (GET_VECTOR_LOW_ELEM(wd, i, 0) - GET_VECTOR_LOW_ELEM(wd, i, 2)) * \ + 4.0f - \ + tmp0; \ + mid_buf1[1] = \ + (GET_VECTOR_LOW_ELEM(wd, i, 1) + GET_VECTOR_LOW_ELEM(wd, i, 2)) * \ + -4.0f + \ + (GET_VECTOR_LOW_ELEM(wd, i, 3) + GET_VECTOR_HIGH_ELEM(wd, i, 0)); \ + mid_buf1[2] = \ + (GET_VECTOR_LOW_ELEM(wd, i, 1) - GET_VECTOR_LOW_ELEM(wd, i, 2)) * \ + 4.0f + \ + (GET_VECTOR_HIGH_ELEM(wd, i, 0) - GET_VECTOR_LOW_ELEM(wd, i, 3)); \ + mid_buf1[3] = 2.0f * tmp1 - tmp0; \ + mid_buf1[4] = -2.0f * tmp1 - tmp0; \ + mid_buf1[5] = -4.0f * tmp1 + (GET_VECTOR_HIGH_ELEM(wd, i, 1) - \ + GET_VECTOR_LOW_ELEM(wd, i, 3)); \ + mid_buf1 += 6; \ + } while (0); + + float* mid_buf1 = transform_mid_buf; + UNROLL_CALL_NOWRAPPER(6, cb); + mid_buf1 = transform_mid_buf; + +#undef cb + rep(i, alpha) rep(j, alpha) { + input_transform_buf + [(i * alpha + j) * nr_units_in_tile * IC + unit_idx * IC + ic] = + transform_mid_buf[i * alpha + j]; + } +#endif + } +}; + +#undef INPUT_TRANSFORM_V2 +#undef INPUT_TRANSFORM + +/** + * Output Transform: use fma + * + * s0 = m0 + (m1 + m2) + (m3 + m4) + * s1 = (m1 - m2) + 2 * (m3 - m4) + * s2 = (m1 + m2) + 4 * (m3 + m4) + * s3 = (m1 - m2) + 8 * (m3 - m4) + m5 + */ +#define OUTPUT_TRANSFORM(m, s, i) \ + do { \ + auto m1addm2 = ADDF(m##1##i, m##2##i); \ + auto m1subm2 = SUBF(m##1##i, m##2##i); \ + auto m3addm4 = ADDF(m##3##i, m##4##i); \ + auto m3subm4 = SUBF(m##3##i, m##4##i); \ + s##0##i = m##0##i; \ + s##0##i = ADDF(s##0##i, m1addm2); \ + s##0##i = ADDF(s##0##i, m3addm4); \ + s##1##i = m1subm2; \ + s##1##i =
GiMultiplyAddScalarFloat32(s##1##i, m3subm4, 2.0f); \ + s##2##i = m1addm2; \ + s##2##i = GiMultiplyAddScalarFloat32(s##2##i, m3addm4, 4.0f); \ + s##3##i = m1subm2; \ + s##3##i = GiMultiplyAddScalarFloat32(s##3##i, m3subm4, 8.0f); \ + s##3##i = ADDF(s##3##i, m##5##i); \ + } while (0); + +#define OUTPUT_TRANSFORM_V2(m, s) \ + OUTPUT_TRANSFORM(m, s, 0); \ + OUTPUT_TRANSFORM(m, s, 1); + +template <BiasMode bmode, typename Op> +struct OutputTransform4X3 { /* transform() gathers the 6x6 tile of (oc_index, unit_idx) from output_transform_buf, reduces it 6 -> 4 along both dimensions, then adds the bias selected by bmode, applies Op and writes at most a 4x4 patch at (oh_start, ow_start), clipped to OH x OW. transform_mid_buf is scratch space. */ + static void transform( + const float* output_transform_buf, const float* bias, float* output, + float* transform_mid_buf, size_t oh_start, size_t ow_start, size_t OH, + size_t OW, size_t oc_start, size_t oc_end, size_t oc_index, size_t unit_idx, + size_t nr_units_in_tile, const DType& src_dtype, const DType& dst_dtype) { + constexpr size_t alpha = 4 + 3 - 1; + Op op(src_dtype, dst_dtype); + float* mid_buf1 = transform_mid_buf; + + //! AT * m * A + size_t OC = oc_end - oc_start; + size_t oc = oc_start + oc_index; + +#define cb(m, n) \ + transform_mid_buf[m * alpha + n] = output_transform_buf \ + [(m * alpha + n) * nr_units_in_tile * OC + unit_idx * OC + oc_index]; + UNROLL_CALL_NOWRAPPER_D2(6, 6, cb); +#undef cb + +#define cb(i, j) auto m##i##j = GiLoadFloat32(transform_mid_buf + alpha * i + 4 * j); + UNROLL_CALL_NOWRAPPER_D2(5, 2, cb); +#undef cb + GI_FLOAT32_t m50, m51; + m50 = GiLoadFloat32(transform_mid_buf + alpha * 5); + m51 = GiLoadFloat32LowHalf(transform_mid_buf + alpha * 5 + 4); +#define cb(i, j) GI_FLOAT32_t s##i##j; + UNROLL_CALL_NOWRAPPER_D2(4, 2, cb); +#undef cb + + OUTPUT_TRANSFORM_V2(m, s); + /** + * Output transform: s * A + * + * 1 0 0 0 + * 1 1 1 1 + * 1 -1 1 -1 + * 1 2 4 8 + * 1 -2 4 -8 + * 0 0 0 1 + */ +#define cb(i) \ + do { \ + auto m1addm2 = GET_VECTOR_LOW_ELEM(s, i, 1) + GET_VECTOR_LOW_ELEM(s, i, 2); \ + auto m1subm2 = GET_VECTOR_LOW_ELEM(s, i, 1) - GET_VECTOR_LOW_ELEM(s, i, 2); \ + auto m3addm4 = GET_VECTOR_LOW_ELEM(s, i, 3) + GET_VECTOR_HIGH_ELEM(s, i, 0); \ + auto m3subm4 = GET_VECTOR_LOW_ELEM(s, i, 3) -
GET_VECTOR_HIGH_ELEM(s, i, 0); \ + mid_buf1[0] = GET_VECTOR_LOW_ELEM(s, i, 0) + m1addm2 + m3addm4; \ + mid_buf1[1] = m1subm2 + 2.f * m3subm4; \ + mid_buf1[2] = m1addm2 + 4.f * m3addm4; \ + mid_buf1[3] = m1subm2 + 8.f * m3subm4 + GET_VECTOR_HIGH_ELEM(s, i, 1); \ + mid_buf1 += 4; \ + } while (0); + /* the 4x4 result overwrites the start of transform_mid_buf; the 6x6 values were already loaded into the m registers above */ + mid_buf1 = transform_mid_buf; + UNROLL_CALL_NOWRAPPER(4, cb); + mid_buf1 = transform_mid_buf; +#undef cb + + if (oh_start + 4 <= OH && ow_start + 4 <= OW) { + GI_FLOAT32_t bias0; + if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { + bias0 = GiBroadcastFloat32(bias[oc]); + } + rep(i, 4) { + size_t oh = oh_start + i; + GI_FLOAT32_t item0 = GiLoadFloat32(mid_buf1); + + if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { + item0 = GiAddFloat32(item0, bias0); + } else if (bmode == BiasMode::BIAS) { + bias0 = GiLoadFloat32(bias + oc * OH * OW + oh * OW + ow_start); + item0 = GiAddFloat32(item0, bias0); + } + item0 = op(item0); + GiStoreFloat32(output + oc * OH * OW + oh * OW + ow_start, item0); + + mid_buf1 += 4; + } + } else { + for (size_t oho = 0; oho < 4 && oh_start + oho < OH; ++oho) { + for (size_t owo = 0; owo < 4 && ow_start + owo < OW; ++owo) { + size_t oh = oh_start + oho; + size_t ow = ow_start + owo; + float res = mid_buf1[oho * 4 + owo]; + if (bmode == BiasMode::BIAS) { + res += bias[oc * OH * OW + oh * OW + ow]; + } else if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { + res += bias[oc]; + } + res = op(res); + output[oc * OH * OW + oh * OW + ow] = res; + } + } + } + } +}; + +#undef GET_VECTOR_HIGH_ELEM +#undef GET_VECTOR_LOW_ELEM +#undef OUTPUT_TRANSFORM_V2 +#undef OUTPUT_TRANSFORM + +} // namespace + +namespace megdnn { +namespace fallback { +namespace winograd { + +MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_4x3_1x1_f) + /* Forwards to FilterTransform4X3 with the DEFAULT matmul format. */ +void winograd_4x3_1x1_f::filter( + const float* filter, float* filter_transform_buf, float* transform_mid_buf, + size_t OC, size_t IC, size_t oc_start, size_t oc_end) { + FilterTransform4X3<param::MatrixMul::Format::DEFAULT>::transform( + filter, filter_transform_buf,
transform_mid_buf, OC, IC, oc_start, oc_end); +} + /* Runs the input transform for every (ic, unit) of the tile; the linear unit index is mapped to a window start in steps of OUTPUT_BLOCK_SIZE, offset by the padding PH / PW. */ +void winograd_4x3_1x1_f::input( + const float* input, float* input_transform_buf, float* transform_mid_buf, + size_t IH, size_t IW, size_t IC, size_t PH, size_t PW, size_t unit_start_idx, + size_t nr_units_in_tile) { + constexpr int alpha = 3 + 4 - 1; + + // OW = IW + 2 * PW - KERNEL_SIZE + 1 + auto units_w = div_ceil(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE); + rep(ic, IC) { + rep(unit_idx, nr_units_in_tile) { + size_t index = unit_start_idx + unit_idx; + size_t nh = index / units_w; + size_t nw = index % units_w; + int ih_start = nh * OUTPUT_BLOCK_SIZE - PH; + int iw_start = nw * OUTPUT_BLOCK_SIZE - PW; + if (ih_start >= 0 && ih_start + alpha <= static_cast<int>(IH) && + iw_start >= 0 && iw_start + alpha <= static_cast<int>(IW)) { + InputTransform4X3::transform<true>( + input, input_transform_buf, transform_mid_buf, ih_start, + iw_start, ic, IH, IW, IC, unit_idx, nr_units_in_tile); + /* above: the window lies fully inside the image; below: it overlaps the padding and takes the zero-filled staging path */ + } else { + InputTransform4X3::transform<false>( + input, input_transform_buf, transform_mid_buf, ih_start, + iw_start, ic, IH, IW, IC, unit_idx, nr_units_in_tile); + } + } + } +} + /* Runs the output transform once per (output channel, unit); bias mode and nonlinearity are turned into template arguments through the dispatch macro. */ +void winograd_4x3_1x1_f::output( + const float* output_transform_buf, const float* bias, float* output, + float* transform_mid_buf, BiasMode bmode, NonlineMode nonline_mode, size_t OH, + size_t OW, size_t oc_start, size_t oc_end, size_t unit_start_idx, + size_t nr_units_in_tile) { +#define cb(_bmode, _nonline_op, ...)
\ + OutputTransform4X3<_bmode MEGDNN_COMMA _nonline_op>::transform(__VA_ARGS__); + + auto units_w = div_ceil(OW, OUTPUT_BLOCK_SIZE); + + for (size_t oc = oc_start; oc < oc_end; oc++) { + size_t oc_index = oc - oc_start; + rep(unit_idx, nr_units_in_tile) { + size_t index = unit_start_idx + unit_idx; + auto nh = index / units_w; + auto nw = index % units_w; + size_t oh_start = nh * OUTPUT_BLOCK_SIZE; + size_t ow_start = nw * OUTPUT_BLOCK_SIZE; + GI_DISPATCH_CONV_WINOGRAD_BIAS( + megdnn_fallback_winograd_fp32_F43, cb, float, float, bmode, + nonline_mode, output_transform_buf, bias, output, transform_mid_buf, + oh_start, ow_start, OH, OW, oc_start, oc_end, oc_index, unit_idx, + nr_units_in_tile, src_dtype, dst_dtype); + } + } +#undef cb +} + +} // namespace winograd +} // namespace fallback +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/fallback/conv_bias/gi/fp32/strategy_6x3.cpp b/dnn/src/fallback/conv_bias/gi/fp32/strategy_6x3.cpp index 4adf2af0642b35f268c30f91442835af12fbdb5b..d1e4161c613f03a6efcbe9174b0605db18f9f98f 100644 --- a/dnn/src/fallback/conv_bias/gi/fp32/strategy_6x3.cpp +++ b/dnn/src/fallback/conv_bias/gi/fp32/strategy_6x3.cpp @@ -93,7 +93,6 @@ struct InputTransform6X3 { #undef cb INPUT_TRANSFORM(d, wd); - #if MEGDNN_AARCH64 #define cb(i) GI_FLOAT32_V2_t ret##i; UNROLL_CALL_NOWRAPPER(8, cb); diff --git a/dnn/src/fallback/conv_bias/opr_impl.cpp b/dnn/src/fallback/conv_bias/opr_impl.cpp index 897dbc95a792ef0d5233f958bf82e608035dba17..18db4d0eb8127a65535c43397ea9d09502d8acd9 100644 --- a/dnn/src/fallback/conv_bias/opr_impl.cpp +++ b/dnn/src/fallback/conv_bias/opr_impl.cpp @@ -159,6 +159,10 @@ public: static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), tile_size)); m_gi_winograd_algos.emplace_back(refhold.back().get()); + refhold.emplace_back(new AlgoFP32WinogradF43( + static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), + tile_size)); + m_gi_winograd_algos.emplace_back(refhold.back().get()); refhold.emplace_back(new AlgoFP32WinogradF54( static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), tile_size)); diff --git
a/dnn/src/fallback/conv_bias/opr_impl.h b/dnn/src/fallback/conv_bias/opr_impl.h index d39ed2d36f1b3946c6e02ebf15636f464584ea1a..5a0fe75ae3131a9e4ebbbf8a210b93a830901b18 100644 --- a/dnn/src/fallback/conv_bias/opr_impl.h +++ b/dnn/src/fallback/conv_bias/opr_impl.h @@ -217,6 +217,7 @@ public: FB_IM2COL, GI_COMMON_WINOGRAD_F23_4X4_FP32, GI_COMMON_WINOGRAD_F63_FP32, + GI_COMMON_WINOGRAD_F43_FP32, /* NOTE(review): this looks like an enumerator list without explicit values; if so, inserting here shifts the value of every enumerator that follows. Confirm nothing persists or compares these numeric values. */ GI_COMMON_WINOGRAD_F63_4X4_FP32, GI_COMMON_WINOGRAD_F54_FP32, GI_COMMON_WINOGRAD_F45_FP32, @@ -379,6 +380,7 @@ private: class AlgoFP32WinogradF23_4x4; class AlgoFP32WinogradF63; + class AlgoFP32WinogradF43; /* forward declaration of the algorithm class added by this change */ class AlgoFP32WinogradF63_4x4; class AlgoFP32WinogradF54; class AlgoFP32WinogradF45; diff --git a/dnn/src/x86/conv_bias/f32/algos.h b/dnn/src/x86/conv_bias/f32/algos.h index ce12eab90227107ce00861e462fcbe6666bb622d..23dc27baa6c651ee711fc42b78f5f64ca6542002 100644 --- a/dnn/src/x86/conv_bias/f32/algos.h +++ b/dnn/src/x86/conv_bias/f32/algos.h @@ -83,7 +83,7 @@ public: const char* name() const override { if (m_name.empty()) { m_name = ConvBiasImpl::algo_name( - m_matmul_algo->name(), {8, 6, m_tile_size}); + m_matmul_algo->name(), {8, 6, m_tile_size, 3}); /* fourth initializer is the new fourth WinogradParam value */ } return m_name.c_str(); } @@ -100,7 +100,7 @@ public: const char* name() const override { if (m_name.empty()) { m_name = ConvBiasImpl::algo_name( - m_matmul_algo->name(), {8, 2, m_tile_size}); + m_matmul_algo->name(), {8, 2, m_tile_size, 3}); } return m_name.c_str(); } diff --git a/dnn/test/arm_common/conv_bias.cpp b/dnn/test/arm_common/conv_bias.cpp index 5a6413b60a97275b2ece1b511f6977f416bb11fa..3bdfa2c7e5b88a68a3e3966f9b2b123f58d45ad1 100644 --- a/dnn/test/arm_common/conv_bias.cpp +++ b/dnn/test/arm_common/conv_bias.cpp @@ -47,6 +47,25 @@ TEST_F(ARM_COMMON, CONV_BIAS_MATMUL) { } } +TEST_F(ARM_COMMON, CONV_BIAS_WINOGRAD) { /* Executes a single ConvBias case (source {1, 3, 351, 257}, filter {5, 3, 3, 3}, padding 1) with the algorithm restricted by the pattern WINOGRAD:.*:1:4:.*:3. */ + using namespace conv_bias; + std::vector args = get_quantized_args(); /* NOTE(review): args is not used anywhere in this test */ + Checker> checker( /* NOTE(review): the template argument lists on std::vector, Checker and ConvBiasAlgoChecker look incomplete in this text; confirm against the real source file */ + handle()); + checker.set_before_exec_callback( + conv_bias::ConvBiasAlgoChecker("WINOGRAD:.*:1:4:.*:3")); +
ConvBiasForward::Param param; + param.pad_h = 1; + param.pad_w = 1; + checker.set_param(param); + checker.execs( + {{1, 3, 351, 257}, + {5, 3, 3, 3}, + {}, + {}, + {}}); // Input, weight, bias, ..., Output +} + TEST_F(ARM_COMMON, CONV_BIAS_RECORD) { using namespace conv_bias; std::vector args = get_quantized_args(); @@ -987,6 +1006,13 @@ TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F23) { #endif } +TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F43_F63) { /* AArch64 only: compares the algorithm matching ...:1:4:.*:3 against the one matching ...:1:6 with kernel size 3 */ +#if MEGDNN_AARCH64 + benchmark_winograd_compare( + "WINOGRAD:AARCH64_F32K8X12X1:1:4:.*:3", "WINOGRAD:AARCH64_F32K8X12X1:1:6", + handle(), 3); +#endif +} TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F63) { #if MEGDNN_AARCH64 benchmark_winograd("WINOGRAD:AARCH64_F32K8X12X1:1:6", handle(), 3); @@ -1005,9 +1031,9 @@ TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F54) { TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F45) { #if MEGDNN_AARCH64 - benchmark_winograd("WINOGRAD:AARCH64_F32K8X12X1:1:4", handle(), 5); + benchmark_winograd("WINOGRAD:AARCH64_F32K8X12X1:1:4:.*:5", handle(), 5); /* the pattern now also pins the fourth name field (5), since ...:1:4 alone no longer distinguishes this algorithm from the new ...:1:4:<tile>:3 one */ #else - benchmark_winograd("WINOGRAD:ARMV7_F32:1:4", handle(), 5); + benchmark_winograd("WINOGRAD:ARMV7_F32:1:4:.*:5", handle(), 5); #endif } @@ -1026,11 +1052,12 @@ TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F16_F23) { TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F16_F45) { #if MEGDNN_AARCH64 benchmark_winograd_fp16( - "WINOGRAD:AARCH64_F32K8X12X1:1:4", "WINOGRAD:AARCH64_F16_K8X24X1:1:4", - handle(), 5); + "WINOGRAD:AARCH64_F32K8X12X1:1:4:.*:5", + "WINOGRAD:AARCH64_F16_K8X24X1:1:4:.*:5", handle(), 5); #else benchmark_winograd_fp16( - "WINOGRAD:ARMV7_F32:1:4", "WINOGRAD:AARCH32_F16_K4X16X1:1:4", handle(), 5); + "WINOGRAD:ARMV7_F32:1:4:.*:5", "WINOGRAD:AARCH32_F16_K4X16X1:1:4:.*:5", + handle(), 5); #endif } TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F16_F63) { diff --git a/dnn/test/arm_common/conv_bias_multi_thread.cpp b/dnn/test/arm_common/conv_bias_multi_thread.cpp index
6ca56acf483d7f00f3ffae97be156ea5205374ea..d6043f379a3a02a516555476c332609bbc27fdc0 100644 --- a/dnn/test/arm_common/conv_bias_multi_thread.cpp +++ b/dnn/test/arm_common/conv_bias_multi_thread.cpp @@ -800,6 +800,13 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD) { #endif } +TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F32_F43) { + using namespace conv_bias; + std::vector args = get_winograd_args(3); + Checker checker(handle()); + check_winograd("1:4:32", checker, args); +} + //! uncomment it when low precision mode is ok #if 0 TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F73_4_NCHW44) { diff --git a/dnn/test/common/conv_bias.cpp b/dnn/test/common/conv_bias.cpp index 09cbcaff6d5351df12936984f8c19d4bdeb8c84f..6476176fc8944f450a9f67753f5ad574d4a22a76 100644 --- a/dnn/test/common/conv_bias.cpp +++ b/dnn/test/common/conv_bias.cpp @@ -921,6 +921,7 @@ std::vector get_winograd_benchmark_args( TensorShape{oc, ic, kernel, kernel}, {1, oc, 1, 1}}); }; + for (size_t ic : {8, 16, 32, 64}) { for (size_t oc : {8, 16, 32, 64}) { pack(oc, ic, 56, 56, kernel, kernel / 2); @@ -1041,6 +1042,60 @@ void benchmark_winograd_weight_preprocess( computations / used_winograd); } } + +void benchmark_winograd_compare( + const char* algoA_name, const char* algoB_name, megdnn::Handle* handle, + size_t kernel, size_t pack_size) { + auto&& args = get_winograd_benchmark_args(kernel, pack_size); + using namespace conv_bias; + constexpr size_t RUN = 10; + + Benchmarker> + benchmark_winograd(handle); + benchmark_winograd.set_display(false); + benchmark_winograd.set_times(RUN); + + for (auto&& arg : args) { + TensorLayout dst_layout; + auto opr = handle->create_operator(); + opr->param() = arg.param; + opr->deduce_layout( + {arg.src, dtype::Float32()}, {arg.filter, dtype::Float32()}, + {arg.bias, dtype::Float32()}, {}, dst_layout); + //! 
dst.nr_elems * IC * FH * FW * 2 + float computations = dst_layout.total_nr_elems() * arg.filter[1] * + arg.filter[2] * arg.filter[3] * 2.0 / + (1024 * 1024 * 1024) * 1e3; + + param::Convolution conv_param; + conv_param.pad_h = arg.param.pad_h; + conv_param.pad_w = arg.param.pad_w; + conv_param.stride_h = arg.param.stride_h; + conv_param.stride_w = arg.param.stride_w; + + benchmark_winograd.set_param(arg.param); + auto used_winograd1 = + algo_benchmark< + ConvBias, OprWeightPreprocessBenchmarkProxy, Timer>( + benchmark_winograd, {arg.src, arg.filter, {}, {}, {}}, + algoA_name) / + RUN; + auto used_winograd2 = + algo_benchmark< + ConvBias, OprWeightPreprocessBenchmarkProxy, Timer>( + benchmark_winograd, {arg.src, arg.filter, {}, {}, {}}, + algoB_name) / + RUN; + + printf("%s %s: %s: %f ms %f Gflops %s: %f ms %f GFlops " + "speedup: " + "%f\n", + arg.src.to_string().c_str(), arg.filter.to_string().c_str(), algoA_name, + used_winograd1, computations / used_winograd1, algoB_name, + used_winograd2, computations / used_winograd2, + used_winograd2 / used_winograd1); + } +} #endif // MEGDNN_WITH_BENCHMARK template diff --git a/dnn/test/common/conv_bias.h b/dnn/test/common/conv_bias.h index d389f0da96561b8b053d547b1ef5f7a5359d4f6c..0d0961ee721aa8d3c4cb6ea379b04c6ac9273265 100644 --- a/dnn/test/common/conv_bias.h +++ b/dnn/test/common/conv_bias.h @@ -69,6 +69,9 @@ void benchmark_winograd( void benchmark_winograd_weight_preprocess( const char* algo_name, megdnn::Handle* handle, size_t kernel, size_t pack_size = 1); +void benchmark_winograd_compare( + const char* algoA_name, const char* algoB_name, megdnn::Handle* handle, + size_t kernel, size_t pack_size = 1); #endif // MEGDNN_WITH_BENCHMARK template void check_winograd(