提交 a6a2646c 编写于 作者: M Megvii Engine Team

feat(arm): add AlgoFP32Winograd F43, and add filter size into name of winograd-related algorithms

GitOrigin-RevId: 909503a90dd729f8c8ac8f570602ee92a481facb
上级 b8821edb
......@@ -453,19 +453,21 @@ public:
};
//! param for winograd algos.
struct WinogradParam {
uint32_t channel_block_size;
uint32_t output_block_size;
uint32_t tile_size;
uint32_t filter_size;
bool operator==(const WinogradParam& rhs) const {
return channel_block_size == rhs.channel_block_size &&
output_block_size == rhs.output_block_size &&
tile_size == rhs.tile_size;
tile_size == rhs.tile_size && filter_size == rhs.filter_size;
}
std::string to_string() const;
};
static constexpr WinogradParam INVALID_WINOGRAD_PARAM = {0, 0, 0};
static constexpr WinogradParam INVALID_WINOGRAD_PARAM = {0, 0, 0, 0};
struct DirectParam {
std::string to_string() const { return ""; }
......
......@@ -14,7 +14,7 @@ public:
const char* name() const override {
if (m_name.empty()) {
m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>(
m_matmul_algo->name(), {1, 2, m_tile_size});
m_matmul_algo->name(), {1, 2, m_tile_size, 3});
}
return m_name.c_str();
}
......@@ -33,7 +33,7 @@ public:
const char* name() const override {
if (m_name.empty()) {
m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>(
m_matmul_algo->name(), {1, 4, m_tile_size});
m_matmul_algo->name(), {1, 4, m_tile_size, 5});
}
return m_name.c_str();
}
......@@ -51,7 +51,7 @@ public:
const char* name() const override {
if (m_name.empty()) {
m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>(
m_matmul_algo->name(), {1, 6, m_tile_size});
m_matmul_algo->name(), {1, 6, m_tile_size, 3});
}
return m_name.c_str();
}
......@@ -69,7 +69,7 @@ public:
const char* name() const override {
if (m_name.empty()) {
m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>(
m_matmul_algo->name(), {8, 2, m_tile_size});
m_matmul_algo->name(), {8, 2, m_tile_size, 3});
}
return m_name.c_str();
}
......
......@@ -221,7 +221,7 @@ public:
const char* name() const override {
if (m_name.empty()) {
m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>(
m_matmul_algo->name(), {8, 2, m_tile_size});
m_matmul_algo->name(), {8, 2, m_tile_size, 3});
}
return m_name.c_str();
}
......@@ -239,7 +239,7 @@ public:
const char* name() const override {
if (m_name.empty()) {
m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>(
m_matmul_algo->name(), {4, 2, m_tile_size},
m_matmul_algo->name(), {4, 2, m_tile_size, 3},
param::ConvBias::Format::NCHW44);
}
return m_name.c_str();
......@@ -258,7 +258,7 @@ public:
const char* name() const override {
if (m_name.empty()) {
m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>(
m_matmul_algo->name(), {8, 2, m_tile_size},
m_matmul_algo->name(), {8, 2, m_tile_size, 3},
param::ConvBias::Format::NCHW44);
}
return m_name.c_str();
......
......@@ -176,7 +176,9 @@ template <typename T>
struct NCHW44ParamTrait;
std::string ConvBias::WinogradParam::to_string() const {
return ssprintf("%u:%u:%u", channel_block_size, output_block_size, tile_size);
return ssprintf(
"%u:%u:%u:%u", channel_block_size, output_block_size, tile_size,
filter_size);
}
template <typename T>
......
......@@ -165,6 +165,18 @@
cb(8, 0, ##a) cb(8, 1, ##a) cb(8, 2, ##a) cb(8, 3, ##a) \
cb(8, 4, ##a) cb(8, 5, ##a) cb(8, 6, ##a) cb(8, 7, ##a) cb(8, 8, ##a)
#define UNROLL_RAW_4x2(cb, v0, a...) \
cb(0, 0, ##a) cb(0, 1, ##a) cb(1, 0, ##a) cb(1, 1, ##a) \
cb(2, 0, ##a) cb(2, 1, ##a) cb(3, 0, ##a) cb(3, 1, ##a)
#define UNROLL_RAW_5x2(cb, v0, a...) \
UNROLL_RAW_4x2(cb, v0, ##a) \
cb(4, 0, ##a) cb(4, 1, ##a)
#define UNROLL_RAW_6x2(cb, v0, a...) \
UNROLL_RAW_5x2(cb, v0, ##a) \
cb(5, 0, ##a) cb(5, 1, ##a)
#define UNROLL_CALL0_D2(step, step2, cb, v...) \
UNROLL_RAW_##step##x##step2(cb, 0, ##v)
#define UNROLL_CALL1_D2(step, step2, cb, v...) \
......
......@@ -42,7 +42,7 @@ public:
if (m_name.empty()) {
m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>(
ssprintf("FALLBACK_WINOGRAD_F32-%s", m_matmul_algo->name()),
{1, 2, UNIT_TILE_SIZE});
{1, 2, UNIT_TILE_SIZE, 3});
}
return m_name.c_str();
}
......@@ -74,7 +74,7 @@ public:
if (m_name.empty()) {
m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>(
ssprintf("FALLBACK_WINOGRAD_F32-%s", m_matmul_algo->name()),
{4, 2, UNIT_TILE_SIZE});
{4, 2, UNIT_TILE_SIZE, 3});
}
return m_name.c_str();
}
......@@ -106,7 +106,7 @@ public:
if (m_name.empty()) {
m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>(
ssprintf("FALLBACK_WINOGRAD_QS8-%s", m_matmul_algo->name()),
{1, 2, UNIT_TILE_SIZE});
{1, 2, UNIT_TILE_SIZE, 3});
}
return m_name.c_str();
}
......@@ -138,7 +138,7 @@ public:
if (m_name.empty()) {
m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>(
ssprintf("FALLBACK_WINOGRAD_QS8-%s", m_matmul_algo->name()),
{8, 2, UNIT_TILE_SIZE});
{8, 2, UNIT_TILE_SIZE, 3});
}
return m_name.c_str();
}
......
......@@ -84,6 +84,38 @@ MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL(
AlgoFP32WinogradF63, winograd::winograd_6x3_1x1_f,
megdnn_fallback_winograd_fp32, param::MatrixMul::Format::DEFAULT);
/* ======================= AlgoFP32WinogradF43 ======================== */
bool ConvBiasImpl::AlgoFP32WinogradF43::usable(
const NCBKernSizeParam& param,
AlgoSelectionStrategy /*algo_selection_strategy*/) const {
MEGDNN_MARK_USED_VAR(param);
MIDOUT_BEGIN(megdnn_fallback_winograd_fp32, 5, 0) {
using Strategy = winograd::winograd_4x3_1x1_f;
Strategy strategy(param.src_type, param.filter_type, param.dst_type);
auto&& matmul_param =
megdnn::winograd::ConvBias<Strategy>(strategy, m_tile_size, param)
.get_matmul_kern_param(param);
return m_matmul_algo->usable(matmul_param) &&
param.filter_meta.format == param::ConvBias::Format::NCHW &&
!param.filter_meta.should_flip &&
(param.filter_meta.spatial[0] == param.filter_meta.spatial[1] &&
param.filter_meta.spatial[0] == 3) &&
(param.filter_meta.stride[0] == param.filter_meta.stride[1] &&
param.filter_meta.stride[0] == 1) &&
(param.filter_meta.dilation[0] == param.filter_meta.dilation[1] &&
param.filter_meta.dilation[0] == 1) &&
param.compute_mode == param::ConvBias::ComputeMode::DEFAULT &&
param.src_type.enumv() == DTypeEnum::Float32;
}
MIDOUT_END();
return false;
}
MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL(
AlgoFP32WinogradF43, winograd::winograd_4x3_1x1_f,
megdnn_fallback_winograd_fp32, param::MatrixMul::Format::DEFAULT);
/* ======================= AlgoFP32WinogradF54 ======================== */
bool ConvBiasImpl::AlgoFP32WinogradF54::usable(
......
......@@ -14,7 +14,7 @@ public:
const char* name() const override {
if (m_name.empty()) {
m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>(
m_matmul_algo->name(), {4, 2, m_tile_size});
m_matmul_algo->name(), {4, 2, m_tile_size, 3});
}
return m_name.c_str();
}
......@@ -31,7 +31,7 @@ public:
const char* name() const override {
if (m_name.empty()) {
m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>(
m_matmul_algo->name(), {1, 6, m_tile_size});
m_matmul_algo->name(), {1, 6, m_tile_size, 3});
}
return m_name.c_str();
}
......@@ -42,6 +42,28 @@ public:
MEGDNN_DECL_ALGO_TYPE(GI_COMMON_WINOGRAD_F63_FP32)
};
class ConvBiasImpl::AlgoFP32WinogradF43 final : public AlgoBase {
public:
AlgoFP32WinogradF43(
fallback::MatrixMulImpl::AlgoBase* matmul_algo, uint32_t tile_size)
: m_matmul_algo{matmul_algo}, m_tile_size{tile_size} {}
const char* name() const override {
if (m_name.empty()) {
m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>(
m_matmul_algo->name(), {1, 4, m_tile_size, 3});
}
return m_name.c_str();
}
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::NAIVE;
}
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32);
MEGDNN_DECL_ALGO_TYPE(GI_COMMON_WINOGRAD_F43_FP32);
};
class ConvBiasImpl::AlgoFP32WinogradF63_4x4 final : public AlgoBase {
public:
AlgoFP32WinogradF63_4x4(
......@@ -50,7 +72,7 @@ public:
const char* name() const override {
if (m_name.empty()) {
m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>(
m_matmul_algo->name(), {4, 6, m_tile_size});
m_matmul_algo->name(), {4, 6, m_tile_size, 3});
}
return m_name.c_str();
}
......@@ -67,7 +89,7 @@ public:
const char* name() const override {
if (m_name.empty()) {
m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>(
m_matmul_algo->name(), {1, 5, m_tile_size});
m_matmul_algo->name(), {1, 5, m_tile_size, 4});
}
return m_name.c_str();
}
......@@ -86,7 +108,7 @@ public:
const char* name() const override {
if (m_name.empty()) {
m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>(
m_matmul_algo->name(), {1, 4, m_tile_size});
m_matmul_algo->name(), {1, 4, m_tile_size, 5});
}
return m_name.c_str();
}
......@@ -106,7 +128,7 @@ public:
const char* name() const override {
if (m_name.empty()) {
m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>(
m_matmul_algo->name(), {4, 2, m_tile_size},
m_matmul_algo->name(), {4, 2, m_tile_size, 3},
param::ConvBias::Format::NCHW44);
}
return m_name.c_str();
......@@ -124,7 +146,7 @@ public:
const char* name() const override {
if (m_name.empty()) {
m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>(
m_matmul_algo->name(), {4, 6, m_tile_size},
m_matmul_algo->name(), {4, 6, m_tile_size, 3},
param::ConvBias::Format::NCHW44);
}
return m_name.c_str();
......@@ -142,7 +164,7 @@ public:
const char* name() const override {
if (m_name.empty()) {
m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>(
m_matmul_algo->name(), {4, 7, m_tile_size},
m_matmul_algo->name(), {4, 7, m_tile_size, 3},
param::ConvBias::Format::NCHW44);
}
return m_name.c_str();
......
......@@ -155,6 +155,124 @@ struct FilterTransform6X3 {
#undef FILTER_TRANSFORM
#undef GET_VECTOR_ELEM
template <param::MatrixMul::Format format = param::MatrixMul::Format::DEFAULT>
struct FilterTransform4X3 {
#define FILTER_TRANSFORM(d, wd, ADDC, SUBC, MULC) \
do { \
wd##0 = MULC(d##0, 0.25f); \
auto tmp0 = MULC(ADDC(d##0, d##2), -0.1666667f); \
auto tmp1 = MULC(d##1, -0.1666667f); \
wd##1 = ADDC(tmp0, tmp1); \
wd##2 = SUBC(tmp0, tmp1); \
tmp0 = ADDC(MULC(d##0, 0.0416667f), MULC(d##2, 0.1666667f)); \
tmp1 = MULC(d##1, 0.0833333f); \
wd##3 = ADDC(tmp0, tmp1); \
wd##4 = SUBC(tmp0, tmp1); \
wd##5 = d##2; \
} while (0);
static void transform(
const float* filter, float* filter_transform_buf, float* transform_mid_buf,
size_t OC, size_t IC, size_t oc_start, size_t oc_end) {
constexpr size_t alpha = 4 + 3 - 1;
size_t OCB = OC / 4;
size_t ICB = IC / 4;
for (size_t oc = oc_start; oc < oc_end; oc++) {
rep(ic, IC) {
const float* fptr = filter + (oc * IC + ic) * 3 * 3;
GI_FLOAT32_t g0 = GiLoadFloat32(fptr);
GI_FLOAT32_t g1 = GiLoadFloat32(fptr + 3);
GI_FLOAT32_t g2 = GiLoadFloat32(fptr + 6 - 1);
GI_FLOAT32_t zeros = GiZeroFloat32();
g2 = GiExtqFloat32(g2, zeros, 1);
#define cb(i) GI_FLOAT32_t wd##i = GiZeroFloat32();
#if MEGDNN_AARCH64
UNROLL_CALL_NOWRAPPER(8, cb);
#else
UNROLL_CALL_NOWRAPPER(6, cb);
#endif
#undef cb
FILTER_TRANSFORM(g, wd, ADDF, SUBF, MULSF);
size_t ocb = oc / 4;
size_t oc4 = oc % 4;
size_t icb = ic / 4;
size_t ic4 = ic % 4;
#if MEGDNN_AARCH64
#define cb(i) GI_FLOAT32_V2_t wdt##i;
UNROLL_CALL_NOWRAPPER(3, cb);
#undef cb
#define cb(i) GI_FLOAT32_V2_t ret##i;
UNROLL_CALL_NOWRAPPER(6, cb);
#undef cb
TRANSPOSE_8x3(wd, wdt);
FILTER_TRANSFORM(wdt, ret, ADDFV2, SUBFV2, MULSFV2);
#define cb(i) GiStoreFloat32V2(transform_mid_buf + i * alpha, ret##i);
UNROLL_CALL_NOWRAPPER(6, cb);
#undef cb
rep(i, alpha) rep(j, alpha) {
if (format == param::MatrixMul::Format::DEFAULT) {
filter_transform_buf[(i * alpha + j) * OC * IC + ic * OC + oc] =
transform_mid_buf[j * alpha + i];
} else {
filter_transform_buf
[(i * alpha + j) * OCB * ICB * 4 * 4 +
ocb * ICB * 4 * 4 + icb * 4 * 4 + ic4 * 4 + oc4] =
transform_mid_buf[j * alpha + i];
}
}
#else
#define cb(i) \
do { \
mid_buf1[0] = GET_VECTOR_ELEM(wd, i, 0) * 0.25f; \
auto tmp0 = \
(GET_VECTOR_ELEM(wd, i, 0) + GET_VECTOR_ELEM(wd, i, 2)) * -0.1666667f; \
auto tmp1 = GET_VECTOR_ELEM(wd, i, 1) * -0.1666667f; \
mid_buf1[1] = tmp0 + tmp1; \
mid_buf1[2] = tmp0 - tmp1; \
tmp0 = GET_VECTOR_ELEM(wd, i, 0) * 0.0416667f + \
GET_VECTOR_ELEM(wd, i, 2) * 0.1666667f; \
tmp1 = GET_VECTOR_ELEM(wd, i, 1) * 0.0833333f; \
mid_buf1[3] = tmp0 + tmp1; \
mid_buf1[4] = tmp0 - tmp1; \
mid_buf1[5] = GET_VECTOR_ELEM(wd, i, 2); \
mid_buf1 += 6; \
} while (0);
#define GET_VECTOR_ELEM(s, i, idx) GiExtractLane##idx##Float32(CONCAT(s, i))
float* mid_buf1 = transform_mid_buf;
UNROLL_CALL_NOWRAPPER(6, cb);
mid_buf1 = transform_mid_buf;
#undef cb
rep(i, alpha) rep(j, alpha) {
if (format == param::MatrixMul::Format::DEFAULT) {
filter_transform_buf[(i * alpha + j) * OC * IC + ic * OC + oc] =
transform_mid_buf[i * alpha + j];
} else {
filter_transform_buf
[(i * alpha + j) * OCB * ICB * 4 * 4 +
ocb * ICB * 4 * 4 + icb * 4 * 4 + ic4 * 4 + oc4] =
transform_mid_buf[i * alpha + j];
}
}
#endif
}
}
}
};
#undef FILTER_TRANSFORM
#undef GET_VECTOR_ELEM
} // namespace fallback
} // namespace megdnn
......
......@@ -116,6 +116,46 @@ inline void transpose_4x4(const float* src, float* dst, int lda, int ldb) {
GiReinterpretqFloat32ToS64(b7.val[1]))); \
} while (0);
#define TRANSPOSE_6x6(a, ret) \
do { \
auto b0 = GiZipqFloat32(CONCAT(a, 00), CONCAT(a, 10)); \
auto b1 = GiZipqFloat32(CONCAT(a, 01), CONCAT(a, 11)); \
auto b2 = GiZipqFloat32(CONCAT(a, 20), CONCAT(a, 30)); \
auto b3 = GiZipqFloat32(CONCAT(a, 21), CONCAT(a, 31)); \
auto b4 = GiZipqFloat32(CONCAT(a, 40), CONCAT(a, 50)); \
auto b5 = GiZipqFloat32(CONCAT(a, 41), CONCAT(a, 51)); \
CONCAT(ret, 00) = GiReinterpretqS64ToFloat32(GiZip1qS64( \
GiReinterpretqFloat32ToS64(b0.val[0]), \
GiReinterpretqFloat32ToS64(b2.val[0]))); \
CONCAT(ret, 01) = b4.val[0]; \
CONCAT(ret, 10) = GiReinterpretqS64ToFloat32(GiZip2qS64( \
GiReinterpretqFloat32ToS64(b0.val[0]), \
GiReinterpretqFloat32ToS64(b2.val[0]))); \
CONCAT(ret, 11) = GiReinterpretqS64ToFloat32(GiZip2qS64( \
GiReinterpretqFloat32ToS64(b4.val[0]), \
GiReinterpretqFloat32ToS64(b5.val[0]))); \
CONCAT(ret, 20) = GiReinterpretqS64ToFloat32(GiZip1qS64( \
GiReinterpretqFloat32ToS64(b0.val[1]), \
GiReinterpretqFloat32ToS64(b2.val[1]))); \
CONCAT(ret, 21) = b4.val[1]; \
CONCAT(ret, 30) = GiReinterpretqS64ToFloat32(GiZip2qS64( \
GiReinterpretqFloat32ToS64(b0.val[1]), \
GiReinterpretqFloat32ToS64(b2.val[1]))); \
CONCAT(ret, 31) = GiReinterpretqS64ToFloat32(GiZip2qS64( \
GiReinterpretqFloat32ToS64(b4.val[1]), \
GiReinterpretqFloat32ToS64(b5.val[1]))); \
CONCAT(ret, 40) = GiReinterpretqS64ToFloat32(GiZip1qS64( \
GiReinterpretqFloat32ToS64(b1.val[0]), \
GiReinterpretqFloat32ToS64(b3.val[0]))); \
CONCAT(ret, 41) = b5.val[0]; \
CONCAT(ret, 50) = GiReinterpretqS64ToFloat32(GiZip2qS64( \
GiReinterpretqFloat32ToS64(b1.val[0]), \
GiReinterpretqFloat32ToS64(b3.val[0]))); \
CONCAT(ret, 51) = GiReinterpretqS64ToFloat32(GiZip2qS64( \
GiReinterpretqFloat32ToS64(b5.val[0]), \
GiReinterpretqFloat32ToS64(b4.val[0]))); \
} while (0);
#define TRANSPOSE_8x3(a, ret) \
auto b0 = GiZipqFloat32(CONCAT(a, 0), CONCAT(a, 1)); \
auto b1 = GiZipqFloat32(CONCAT(a, 2), CONCAT(a, 3)); \
......
......@@ -12,6 +12,8 @@ MEGDNN_REG_WINOGRAD_STRATEGY(
MEGDNN_REG_WINOGRAD_STRATEGY(float, float, float, float, 6, 3, 1, 1, winograd_6x3_1x1_f)
MEGDNN_REG_WINOGRAD_STRATEGY(float, float, float, float, 4, 3, 1, 1, winograd_4x3_1x1_f)
MEGDNN_REG_WINOGRAD_STRATEGY(float, float, float, float, 6, 3, 4, 4, winograd_6x3_4x4_f)
MEGDNN_REG_WINOGRAD_STRATEGY(float, float, float, float, 5, 4, 1, 1, winograd_5x4_1x1_f)
......
#include "src/common/unroll_macro.h"
#include "src/common/utils.h"
#include "src/fallback/conv_bias/gi/fp32/filter_transform.h"
#include "src/fallback/conv_bias/gi/fp32/helper.h"
#include "src/fallback/conv_bias/gi/fp32/strategy.h"
#include "src/fallback/conv_bias/winograd/winograd.h"
#include "src/fallback/elemwise_helper/op_unary.h"
#include "src/naive/matrix_mul/matrix_mul_helper.h"
#include "midout.h"
MIDOUT_DECL(megdnn_fallback_winograd_fp32_F43)
using namespace megdnn;
using namespace fallback;
namespace {
/**
* input transform
*
* wd0 = 4 * (d0 - d2) - (d2 - d4)
* wd1 = -4 * (d1 + d2) + (d3 + d4)
* wd2 = 4 * (d1 - d2) + (d4 - d3)
* wd3 = 2 * (d3 - d1) - (d2 - d4)
* wd4 = -2 * (d3 - d1) - (d2 - d4)
* wd5 = -4 * (d3 - d1) + (d5 - d3)
*/
#define INPUT_TRANSFORM(d, wd, i) \
do { \
auto tmp0 = SUBF(d##2##i, d##4##i); \
auto tmp1 = SUBF(d##3##i, d##1##i); \
wd##0##i = SUBF(MULSF(SUBF(d##0##i, d##2##i), 4.0f), tmp0); \
wd##1##i = SUBF(ADDF(d##3##i, d##4##i), MULSF(ADDF(d##1##i, d##2##i), 4.0f)); \
wd##2##i = ADDF(MULSF(SUBF(d##1##i, d##2##i), 4.0f), SUBF(d##4##i, d##3##i)); \
wd##3##i = SUBF(MULSF(tmp1, 2.0f), tmp0); \
wd##4##i = SUBF(MULSF(tmp1, -2.0f), tmp0); \
wd##5##i = SUBF(SUBF(d##5##i, d##3##i), MULSF(tmp1, 4.0f)); \
} while (0);
#define INPUT_TRANSFORM_V2(d, wd) \
INPUT_TRANSFORM(d, wd, 0); \
INPUT_TRANSFORM(d, wd, 1);
#define GET_VECTOR_HIGH_ELEM(s, i, idx) GiExtractLane##idx##Float32(s##i##1)
#define GET_VECTOR_LOW_ELEM(s, i, idx) GiExtractLane##idx##Float32(s##i##0)
struct InputTransform4X3 {
template <bool inner>
static void transform(
const float* input, float* input_transform_buf, float* transform_mid_buf,
int ih_start, int iw_start, size_t ic, size_t IH, size_t IW, size_t IC,
size_t unit_idx, size_t nr_units_in_tile) {
constexpr size_t alpha = 4 + 3 - 1;
if (!inner) {
memset(transform_mid_buf, 0, sizeof(float) * alpha * alpha);
}
#define cb(i, j) GI_FLOAT32_t d##i##j;
UNROLL_CALL_NOWRAPPER_D2(6, 2, cb);
#undef cb
if (inner) {
const float* input_ptr = input + ic * IH * IW + ih_start * IW + iw_start;
#define cb(i, j) d##i##j = GiLoadFloat32(input_ptr + IW * i + 4 * j);
UNROLL_CALL_NOWRAPPER_D2(5, 2, cb);
#undef cb
d50 = GiLoadFloat32(input_ptr + IW * 5);
d51 = GiLoadFloat32LowHalf(input_ptr + IW * 5 + 4);
} else {
int ih0_act = std::max<int>(ih_start, 0),
ih1_act = std::min<int>(ih_start + alpha, IH),
iw0_act = std::max<int>(iw_start, 0),
iw1_act = std::min<int>(iw_start + alpha, IW);
for (int ih = ih0_act; ih < ih1_act; ++ih) {
for (int iw = iw0_act; iw < iw1_act; ++iw) {
size_t iho = ih - ih_start, iwo = iw - iw_start;
transform_mid_buf[iho * alpha + iwo] =
input[ic * IH * IW + ih * IW + iw];
}
}
#define cb(i, j) d##i##j = GiLoadFloat32(transform_mid_buf + alpha * i + 4 * j);
UNROLL_CALL_NOWRAPPER_D2(5, 2, cb);
#undef cb
d50 = GiLoadFloat32(transform_mid_buf + alpha * 5);
d51 = GiLoadFloat32LowHalf(transform_mid_buf + alpha * 5 + 4);
}
#define cb(i, j) GI_FLOAT32_t wd##i##j;
UNROLL_CALL_NOWRAPPER_D2(6, 2, cb);
#undef cb
INPUT_TRANSFORM_V2(d, wd);
#if MEGDNN_AARCH64
#define cb(i, j) GI_FLOAT32_t ret##i##j;
UNROLL_CALL_NOWRAPPER_D2(6, 2, cb);
#undef cb
TRANSPOSE_6x6(wd, d);
INPUT_TRANSFORM_V2(d, ret);
#define cb(i, j) GiStoreFloat32(transform_mid_buf + i * alpha + j * 4, ret##i##j);
UNROLL_CALL_NOWRAPPER_D2(5, 2, cb);
#undef cb
GiStoreFloat32(transform_mid_buf + 5 * alpha, ret50);
float tmp[4];
GiStoreFloat32(tmp, ret51);
memcpy(transform_mid_buf + 5 * alpha + 4, tmp, sizeof(float) * 2);
rep(i, alpha) rep(j, alpha) {
input_transform_buf
[(i * alpha + j) * nr_units_in_tile * IC + unit_idx * IC + ic] =
transform_mid_buf[j * alpha + i];
}
#else
//! 4 0 0 0 0 0
//! 0 -4 4 -2 2 4
//! -5 -4 -4 -1 -1 0
//! 0 1 -1 2 -2 -5
//! 1 1 1 1 1 0
//! 0 0 0 0 0 1
#define cb(i) \
do { \
auto tmp0 = GET_VECTOR_LOW_ELEM(wd, i, 2) - GET_VECTOR_HIGH_ELEM(wd, i, 0); \
auto tmp1 = GET_VECTOR_LOW_ELEM(wd, i, 3) - GET_VECTOR_LOW_ELEM(wd, i, 1); \
mid_buf1[0] = \
(GET_VECTOR_LOW_ELEM(wd, i, 0) - GET_VECTOR_LOW_ELEM(wd, i, 2)) * \
4.0f - \
tmp0; \
mid_buf1[1] = \
(GET_VECTOR_LOW_ELEM(wd, i, 1) + GET_VECTOR_LOW_ELEM(wd, i, 2)) * \
-4.0f + \
(GET_VECTOR_LOW_ELEM(wd, i, 3) + GET_VECTOR_HIGH_ELEM(wd, i, 0)); \
mid_buf1[2] = \
(GET_VECTOR_LOW_ELEM(wd, i, 1) - GET_VECTOR_LOW_ELEM(wd, i, 2)) * \
4.0f + \
(GET_VECTOR_HIGH_ELEM(wd, i, 0) - GET_VECTOR_LOW_ELEM(wd, i, 3)); \
mid_buf1[3] = 2.0f * tmp1 - tmp0; \
mid_buf1[4] = -2.0f * tmp1 - tmp0; \
mid_buf1[5] = -4.0f * tmp1 + (GET_VECTOR_HIGH_ELEM(wd, i, 1) - \
GET_VECTOR_LOW_ELEM(wd, i, 3)); \
mid_buf1 += 6; \
} while (0);
float* mid_buf1 = transform_mid_buf;
UNROLL_CALL_NOWRAPPER(6, cb);
mid_buf1 = transform_mid_buf;
#undef cb
rep(i, alpha) rep(j, alpha) {
input_transform_buf
[(i * alpha + j) * nr_units_in_tile * IC + unit_idx * IC + ic] =
transform_mid_buf[i * alpha + j];
}
#endif
}
};
#undef INPUT_TRANSFORM_V2
#undef INPUT_TRANSFORM
/**
* Output Transform: use fma
*
* s0 = m0 + (m1 + m2) + (m3 + m4)
* s1 = (m1 - m2) + 2 * (m3 - m4)
* s2 = (m1 + m2) + 4 * (m3 + m4)
* s3 = (m1 - m2) + 8 * (m3 - m4) + m5
*/
#define OUTPUT_TRANSFORM(m, s, i) \
do { \
auto m1addm2 = ADDF(m##1##i, m##2##i); \
auto m1subm2 = SUBF(m##1##i, m##2##i); \
auto m3addm4 = ADDF(m##3##i, m##4##i); \
auto m3subm4 = SUBF(m##3##i, m##4##i); \
s##0##i = m##0##i; \
s##0##i = ADDF(s##0##i, m1addm2); \
s##0##i = ADDF(s##0##i, m3addm4); \
s##1##i = m1subm2; \
s##1##i = GiMultiplyAddScalarFloat32(s##1##i, m3subm4, 2.0f); \
s##2##i = m1addm2; \
s##2##i = GiMultiplyAddScalarFloat32(s##2##i, m3addm4, 4.0f); \
s##3##i = m1subm2; \
s##3##i = GiMultiplyAddScalarFloat32(s##3##i, m3subm4, 8.0f); \
s##3##i = ADDF(s##3##i, m##5##i); \
} while (0);
#define OUTPUT_TRANSFORM_V2(m, s) \
OUTPUT_TRANSFORM(m, s, 0); \
OUTPUT_TRANSFORM(m, s, 1);
template <BiasMode bmode, typename Op>
struct OutputTransform4X3 {
static void transform(
const float* output_transform_buf, const float* bias, float* output,
float* transform_mid_buf, size_t oh_start, size_t ow_start, size_t OH,
size_t OW, size_t oc_start, size_t oc_end, size_t oc_index, size_t unit_idx,
size_t nr_units_in_tile, const DType& src_dtype, const DType& dst_dtype) {
constexpr size_t alpha = 4 + 3 - 1;
Op op(src_dtype, dst_dtype);
float* mid_buf1 = transform_mid_buf;
//! AT * m * A
size_t OC = oc_end - oc_start;
size_t oc = oc_start + oc_index;
#define cb(m, n) \
transform_mid_buf[m * alpha + n] = output_transform_buf \
[(m * alpha + n) * nr_units_in_tile * OC + unit_idx * OC + oc_index];
UNROLL_CALL_NOWRAPPER_D2(6, 6, cb);
#undef cb
#define cb(i, j) auto m##i##j = GiLoadFloat32(transform_mid_buf + alpha * i + 4 * j);
UNROLL_CALL_NOWRAPPER_D2(5, 2, cb);
#undef cb
GI_FLOAT32_t m50, m51;
m50 = GiLoadFloat32(transform_mid_buf + alpha * 5);
m51 = GiLoadFloat32LowHalf(transform_mid_buf + alpha * 5 + 4);
#define cb(i, j) GI_FLOAT32_t s##i##j;
UNROLL_CALL_NOWRAPPER_D2(4, 2, cb);
#undef cb
OUTPUT_TRANSFORM_V2(m, s);
/**
* Output transform: s * A
*
* 1 0 0 0
* 1 1 1 1
* 1 -1 1 -1
* 1 2 4 8
* 1 -2 4 -8
* 0 0 0 1
*/
#define cb(i) \
do { \
auto m1addm2 = GET_VECTOR_LOW_ELEM(s, i, 1) + GET_VECTOR_LOW_ELEM(s, i, 2); \
auto m1subm2 = GET_VECTOR_LOW_ELEM(s, i, 1) - GET_VECTOR_LOW_ELEM(s, i, 2); \
auto m3addm4 = GET_VECTOR_LOW_ELEM(s, i, 3) + GET_VECTOR_HIGH_ELEM(s, i, 0); \
auto m3subm4 = GET_VECTOR_LOW_ELEM(s, i, 3) - GET_VECTOR_HIGH_ELEM(s, i, 0); \
mid_buf1[0] = GET_VECTOR_LOW_ELEM(s, i, 0) + m1addm2 + m3addm4; \
mid_buf1[1] = m1subm2 + 2.f * m3subm4; \
mid_buf1[2] = m1addm2 + 4.f * m3addm4; \
mid_buf1[3] = m1subm2 + 8.f * m3subm4 + GET_VECTOR_HIGH_ELEM(s, i, 1); \
mid_buf1 += 4; \
} while (0);
mid_buf1 = transform_mid_buf;
UNROLL_CALL_NOWRAPPER(4, cb);
mid_buf1 = transform_mid_buf;
#undef cb
if (oh_start + 4 <= OH && ow_start + 4 <= OW) {
GI_FLOAT32_t bias0;
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) {
bias0 = GiBroadcastFloat32(bias[oc]);
}
rep(i, 4) {
size_t oh = oh_start + i;
GI_FLOAT32_t item0 = GiLoadFloat32(mid_buf1);
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) {
item0 = GiAddFloat32(item0, bias0);
} else if (bmode == BiasMode::BIAS) {
bias0 = GiLoadFloat32(bias + oc * OH * OW + oh * OW + ow_start);
item0 = GiAddFloat32(item0, bias0);
}
item0 = op(item0);
GiStoreFloat32(output + oc * OH * OW + oh * OW + ow_start, item0);
mid_buf1 += 4;
}
} else {
for (size_t oho = 0; oho < 4 && oh_start + oho < OH; ++oho) {
for (size_t owo = 0; owo < 4 && ow_start + owo < OW; ++owo) {
size_t oh = oh_start + oho;
size_t ow = ow_start + owo;
float res = mid_buf1[oho * 4 + owo];
if (bmode == BiasMode::BIAS) {
res += bias[oc * OH * OW + oh * OW + ow];
} else if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) {
res += bias[oc];
}
res = op(res);
output[oc * OH * OW + oh * OW + ow] = res;
}
}
}
}
};
#undef GET_VECTOR_HIGH_ELEM
#undef GET_VECTOR_LOW_ELEM
#undef OUTPUT_TRANSFORM_V2
#undef OUTPUT_TRANSFORM
} // namespace
namespace megdnn {
namespace fallback {
namespace winograd {
MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_4x3_1x1_f)
void winograd_4x3_1x1_f::filter(
const float* filter, float* filter_transform_buf, float* transform_mid_buf,
size_t OC, size_t IC, size_t oc_start, size_t oc_end) {
FilterTransform4X3<param::MatrixMul::Format::DEFAULT>::transform(
filter, filter_transform_buf, transform_mid_buf, OC, IC, oc_start, oc_end);
}
void winograd_4x3_1x1_f::input(
const float* input, float* input_transform_buf, float* transform_mid_buf,
size_t IH, size_t IW, size_t IC, size_t PH, size_t PW, size_t unit_start_idx,
size_t nr_units_in_tile) {
constexpr int alpha = 3 + 4 - 1;
// OW = IW + 2 * PW - KERNEL_SIZE + 1
auto units_w = div_ceil<size_t>(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE);
rep(ic, IC) {
rep(unit_idx, nr_units_in_tile) {
size_t index = unit_start_idx + unit_idx;
size_t nh = index / units_w;
size_t nw = index % units_w;
int ih_start = nh * OUTPUT_BLOCK_SIZE - PH;
int iw_start = nw * OUTPUT_BLOCK_SIZE - PW;
if (ih_start >= 0 && ih_start + alpha <= static_cast<int>(IH) &&
iw_start >= 0 && iw_start + alpha <= static_cast<int>(IW)) {
InputTransform4X3::transform<true>(
input, input_transform_buf, transform_mid_buf, ih_start,
iw_start, ic, IH, IW, IC, unit_idx, nr_units_in_tile);
} else {
InputTransform4X3::transform<false>(
input, input_transform_buf, transform_mid_buf, ih_start,
iw_start, ic, IH, IW, IC, unit_idx, nr_units_in_tile);
}
}
}
}
void winograd_4x3_1x1_f::output(
const float* output_transform_buf, const float* bias, float* output,
float* transform_mid_buf, BiasMode bmode, NonlineMode nonline_mode, size_t OH,
size_t OW, size_t oc_start, size_t oc_end, size_t unit_start_idx,
size_t nr_units_in_tile) {
#define cb(_bmode, _nonline_op, ...) \
OutputTransform4X3<_bmode MEGDNN_COMMA _nonline_op>::transform(__VA_ARGS__);
auto units_w = div_ceil<size_t>(OW, OUTPUT_BLOCK_SIZE);
for (size_t oc = oc_start; oc < oc_end; oc++) {
size_t oc_index = oc - oc_start;
rep(unit_idx, nr_units_in_tile) {
size_t index = unit_start_idx + unit_idx;
auto nh = index / units_w;
auto nw = index % units_w;
size_t oh_start = nh * OUTPUT_BLOCK_SIZE;
size_t ow_start = nw * OUTPUT_BLOCK_SIZE;
GI_DISPATCH_CONV_WINOGRAD_BIAS(
megdnn_fallback_winograd_fp32_F43, cb, float, float, bmode,
nonline_mode, output_transform_buf, bias, output, transform_mid_buf,
oh_start, ow_start, OH, OW, oc_start, oc_end, oc_index, unit_idx,
nr_units_in_tile, src_dtype, dst_dtype);
}
}
#undef cb
}
} // namespace winograd
} // namespace fallback
} // namespace megdnn
// vim: syntax=cpp.doxygen
......@@ -93,7 +93,6 @@ struct InputTransform6X3 {
#undef cb
INPUT_TRANSFORM(d, wd);
#if MEGDNN_AARCH64
#define cb(i) GI_FLOAT32_V2_t ret##i;
UNROLL_CALL_NOWRAPPER(8, cb);
......
......@@ -159,6 +159,10 @@ public:
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
tile_size));
m_gi_winograd_algos.emplace_back(refhold.back().get());
refhold.emplace_back(new AlgoFP32WinogradF43(
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
tile_size));
m_gi_winograd_algos.emplace_back(refhold.back().get());
refhold.emplace_back(new AlgoFP32WinogradF54(
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
tile_size));
......
......@@ -217,6 +217,7 @@ public:
FB_IM2COL,
GI_COMMON_WINOGRAD_F23_4X4_FP32,
GI_COMMON_WINOGRAD_F63_FP32,
GI_COMMON_WINOGRAD_F43_FP32,
GI_COMMON_WINOGRAD_F63_4X4_FP32,
GI_COMMON_WINOGRAD_F54_FP32,
GI_COMMON_WINOGRAD_F45_FP32,
......@@ -379,6 +380,7 @@ private:
class AlgoFP32WinogradF23_4x4;
class AlgoFP32WinogradF63;
class AlgoFP32WinogradF43;
class AlgoFP32WinogradF63_4x4;
class AlgoFP32WinogradF54;
class AlgoFP32WinogradF45;
......
......@@ -83,7 +83,7 @@ public:
const char* name() const override {
if (m_name.empty()) {
m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>(
m_matmul_algo->name(), {8, 6, m_tile_size});
m_matmul_algo->name(), {8, 6, m_tile_size, 3});
}
return m_name.c_str();
}
......@@ -100,7 +100,7 @@ public:
const char* name() const override {
if (m_name.empty()) {
m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>(
m_matmul_algo->name(), {8, 2, m_tile_size});
m_matmul_algo->name(), {8, 2, m_tile_size, 3});
}
return m_name.c_str();
}
......
......@@ -47,6 +47,25 @@ TEST_F(ARM_COMMON, CONV_BIAS_MATMUL) {
}
}
TEST_F(ARM_COMMON, CONV_BIAS_WINOGRAD) {
using namespace conv_bias;
std::vector<TestArg> args = get_quantized_args();
Checker<ConvBiasForward, OprWeightPreprocessProxy<ConvBiasForward>> checker(
handle());
checker.set_before_exec_callback(
conv_bias::ConvBiasAlgoChecker<ConvBias>("WINOGRAD:.*:1:4:.*:3"));
ConvBiasForward::Param param;
param.pad_h = 1;
param.pad_w = 1;
checker.set_param(param);
checker.execs(
{{1, 3, 351, 257},
{5, 3, 3, 3},
{},
{},
{}}); // Input, weight, bias, ..., Output
}
TEST_F(ARM_COMMON, CONV_BIAS_RECORD) {
using namespace conv_bias;
std::vector<TestArg> args = get_quantized_args();
......@@ -987,6 +1006,13 @@ TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F23) {
#endif
}
TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F43_F63) {
#if MEGDNN_AARCH64
benchmark_winograd_compare(
"WINOGRAD:AARCH64_F32K8X12X1:1:4:.*:3", "WINOGRAD:AARCH64_F32K8X12X1:1:6",
handle(), 3);
#endif
}
TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F63) {
#if MEGDNN_AARCH64
benchmark_winograd("WINOGRAD:AARCH64_F32K8X12X1:1:6", handle(), 3);
......@@ -1005,9 +1031,9 @@ TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F54) {
TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F45) {
#if MEGDNN_AARCH64
benchmark_winograd("WINOGRAD:AARCH64_F32K8X12X1:1:4", handle(), 5);
benchmark_winograd("WINOGRAD:AARCH64_F32K8X12X1:1:4:.*:5", handle(), 5);
#else
benchmark_winograd("WINOGRAD:ARMV7_F32:1:4", handle(), 5);
benchmark_winograd("WINOGRAD:ARMV7_F32:1:4:.*:5", handle(), 5);
#endif
}
......@@ -1026,11 +1052,12 @@ TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F16_F23) {
TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F16_F45) {
#if MEGDNN_AARCH64
benchmark_winograd_fp16(
"WINOGRAD:AARCH64_F32K8X12X1:1:4", "WINOGRAD:AARCH64_F16_K8X24X1:1:4",
handle(), 5);
"WINOGRAD:AARCH64_F32K8X12X1:1:4:.*:5",
"WINOGRAD:AARCH64_F16_K8X24X1:1:4:.*:5", handle(), 5);
#else
benchmark_winograd_fp16(
"WINOGRAD:ARMV7_F32:1:4", "WINOGRAD:AARCH32_F16_K4X16X1:1:4", handle(), 5);
"WINOGRAD:ARMV7_F32:1:4:.*:5", "WINOGRAD:AARCH32_F16_K4X16X1:1:4:.*:5",
handle(), 5);
#endif
}
TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F16_F63) {
......
......@@ -800,6 +800,13 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD) {
#endif
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F32_F43) {
using namespace conv_bias;
std::vector<TestArg> args = get_winograd_args(3);
Checker<ConvBiasForward> checker(handle());
check_winograd("1:4:32", checker, args);
}
//! uncomment it when low precision mode is ok
#if 0
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F73_4_NCHW44) {
......
......@@ -921,6 +921,7 @@ std::vector<conv_bias::TestArg> get_winograd_benchmark_args(
TensorShape{oc, ic, kernel, kernel},
{1, oc, 1, 1}});
};
for (size_t ic : {8, 16, 32, 64}) {
for (size_t oc : {8, 16, 32, 64}) {
pack(oc, ic, 56, 56, kernel, kernel / 2);
......@@ -1041,6 +1042,60 @@ void benchmark_winograd_weight_preprocess(
computations / used_winograd);
}
}
void benchmark_winograd_compare(
const char* algoA_name, const char* algoB_name, megdnn::Handle* handle,
size_t kernel, size_t pack_size) {
auto&& args = get_winograd_benchmark_args(kernel, pack_size);
using namespace conv_bias;
constexpr size_t RUN = 10;
Benchmarker<ConvBias, Timer, OprWeightPreprocessBenchmarkProxy<ConvBias>>
benchmark_winograd(handle);
benchmark_winograd.set_display(false);
benchmark_winograd.set_times(RUN);
for (auto&& arg : args) {
TensorLayout dst_layout;
auto opr = handle->create_operator<ConvBias>();
opr->param() = arg.param;
opr->deduce_layout(
{arg.src, dtype::Float32()}, {arg.filter, dtype::Float32()},
{arg.bias, dtype::Float32()}, {}, dst_layout);
//! dst.nr_elems * IC * FH * FW * 2
float computations = dst_layout.total_nr_elems() * arg.filter[1] *
arg.filter[2] * arg.filter[3] * 2.0 /
(1024 * 1024 * 1024) * 1e3;
param::Convolution conv_param;
conv_param.pad_h = arg.param.pad_h;
conv_param.pad_w = arg.param.pad_w;
conv_param.stride_h = arg.param.stride_h;
conv_param.stride_w = arg.param.stride_w;
benchmark_winograd.set_param(arg.param);
auto used_winograd1 =
algo_benchmark<
ConvBias, OprWeightPreprocessBenchmarkProxy<ConvBias>, Timer>(
benchmark_winograd, {arg.src, arg.filter, {}, {}, {}},
algoA_name) /
RUN;
auto used_winograd2 =
algo_benchmark<
ConvBias, OprWeightPreprocessBenchmarkProxy<ConvBias>, Timer>(
benchmark_winograd, {arg.src, arg.filter, {}, {}, {}},
algoB_name) /
RUN;
printf("%s %s: %s: %f ms %f Gflops %s: %f ms %f GFlops "
"speedup: "
"%f\n",
arg.src.to_string().c_str(), arg.filter.to_string().c_str(), algoA_name,
used_winograd1, computations / used_winograd1, algoB_name,
used_winograd2, computations / used_winograd2,
used_winograd2 / used_winograd1);
}
}
#endif // MEGDNN_WITH_BENCHMARK
template <class Checker>
......
......@@ -69,6 +69,9 @@ void benchmark_winograd(
void benchmark_winograd_weight_preprocess(
const char* algo_name, megdnn::Handle* handle, size_t kernel,
size_t pack_size = 1);
void benchmark_winograd_compare(
const char* algoA_name, const char* algoB_name, megdnn::Handle* handle,
size_t kernel, size_t pack_size = 1);
#endif // MEGDNN_WITH_BENCHMARK
template <class Checker>
void check_winograd(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册