Commit 63032170 authored by Megvii Engine Team

feat(dnn/fallback): add gi fp16 nchw88 winograd F63 algo
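
F63 denotes the Winograd F(6x6, 3x3) variant: each 6x6 output block of a
stride-1, non-dilated 3x3 convolution is computed from an 8x8 input tile
(alpha = 6 + 3 - 1), with channels packed by 8 (NCHW88) so the batched matmul
runs in MK8 format, mirroring the existing fp16 F23/F43 NCHW88 algos. The
transforms added below implement the usual Winograd decomposition
Y = A^T * ((G * g * G^T) (.) (B^T * d * B)) * A, where (.) is the element-wise
product, as a filter transform (Gg * GT), an input transform (BT * d * B), an
MK8 matmul over the transformed tiles, and an output transform (AT * m * A).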

GitOrigin-RevId: d986e1cbebd0f9ad89c27a62bd4e951459165d3d
Parent 10503cfa
......@@ -91,5 +91,43 @@ MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL(
AlgoFP16WinogradF23_8x8_NCHW88, winograd::winograd_F23_mk8_f16_nchw88,
megdnn_fallback_winograd_fp16_nchw88, param::MatrixMul::Format::MK8);
/* =================== AlgoFP16WinogradF63_8x8_NCHW88 ===================== */
bool ConvBiasImpl::AlgoFP16WinogradF63_8x8_NCHW88::usable(
const NCBKernSizeParam& param,
AlgoSelectionStrategy /*algo_selection_strategy*/) const {
MIDOUT_BEGIN(
megdnn_fallback_winograd_fp16_nchw88,
midout_iv("AlgoFP16WinogradF63_8x8_NCHW88"_hash)) {
if (param.filter_meta.icpg % 8 != 0 || param.filter_meta.ocpg % 8 != 0)
return false;
using Strategy = winograd::winograd_F63_mk8_f16_nchw88;
Strategy strategy(param.src_type, param.filter_type, param.dst_type);
auto&& matmul_param =
megdnn::winograd::ConvBias<Strategy, param::MatrixMul::Format::MK8>(
strategy, m_tile_size, param)
.get_matmul_kern_param(param);
return m_matmul_algo->usable(matmul_param) &&
m_matmul_algo->packmode() ==
fallback::MatrixMulImpl::AlgoBase::PackMode::NO_PACK &&
param.filter_meta.format == param::ConvBias::Format::NCHW88 &&
!param.filter_meta.should_flip &&
(param.filter_meta.spatial[0] == param.filter_meta.spatial[1] &&
param.filter_meta.spatial[0] == 3) &&
(param.filter_meta.stride[0] == param.filter_meta.stride[1] &&
param.filter_meta.stride[0] == 1) &&
(param.filter_meta.dilation[0] == param.filter_meta.dilation[1] &&
param.filter_meta.dilation[0] == 1) &&
param.compute_mode == param::ConvBias::ComputeMode::DEFAULT &&
param.src_type.enumv() == DTypeEnum::Float16;
}
MIDOUT_END();
return false;
}
MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL(
AlgoFP16WinogradF63_8x8_NCHW88, winograd::winograd_F63_mk8_f16_nchw88,
megdnn_fallback_winograd_fp16_nchw88, param::MatrixMul::Format::MK8);
#endif
// vim: syntax=cpp.doxygen
......@@ -46,6 +46,24 @@ public:
MEGDNN_DECL_ALGO_TYPE(GI_COMMON_WINOGRAD_F23_8X8_NCHW88_F16)
};
class ConvBiasImpl::AlgoFP16WinogradF63_8x8_NCHW88 final : public AlgoBase {
public:
AlgoFP16WinogradF63_8x8_NCHW88(
fallback::MatrixMulImpl::AlgoBase* matmul_algo, uint32_t tile_size)
: m_matmul_algo{matmul_algo}, m_tile_size{tile_size} {}
const char* name() const override {
if (m_name.empty()) {
m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>(
m_matmul_algo->name(), {8, 6, m_tile_size, 3},
param::ConvBias::Format::NCHW88);
}
return m_name.c_str();
}
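//! With the FB_GI_F16_MK8_8x8 matmul this resolves to the
//! "WINOGRAD_NCHW88:FB_GI_F16_MK8_8x8:8:6" algo string used by the benchmark
//! below; the 8:6 part comes from the {8, 6, m_tile_size, 3} winograd param.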
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT16);
MEGDNN_DECL_ALGO_TYPE(GI_COMMON_WINOGRAD_F63_8X8_NCHW88_F16)
};
} // namespace fallback
} // namespace megdnn
......
......@@ -7,4 +7,5 @@
#define MULSF16 GiMultiplyScalerFloat16
#endif
#define CONCAT(a, idx) a##idx
// vim: syntax=cpp.doxygen
......@@ -17,6 +17,9 @@ MEGDNN_REG_WINOGRAD_STRATEGY(
MEGDNN_REG_WINOGRAD_STRATEGY(
dt_float16, dt_float16, dt_float16, dt_float16, 2, 3, 8, 8,
winograd_F23_mk8_f16_nchw88)
MEGDNN_REG_WINOGRAD_STRATEGY(
dt_float16, dt_float16, dt_float16, dt_float16, 6, 3, 8, 8,
winograd_F63_mk8_f16_nchw88)
} // namespace winograd
} // namespace fallback
} // namespace megdnn
......
#include "src/fallback/conv_bias/gi/fp16/strategy.h"
#if defined(GI_SUPPORT_F16)
#include "src/common/unroll_macro.h"
#include "src/common/utils.h"
#include "src/common/winograd/winograd_helper.h"
#include "src/fallback/conv_bias/gi/fp16/helper.h"
#include "src/fallback/conv_bias/winograd/winograd.h"
#include "src/fallback/elemwise_helper/op_unary.h"
#include "midout.h"
MIDOUT_DECL(megdnn_fallback_winograd_fp16_F63_mk8)
using namespace megdnn;
using namespace fallback;
namespace {
constexpr size_t alpha = 6 + 3 - 1;
constexpr size_t pack_size = 8;
constexpr gi_float16_t input_parameters[16] = {5.25f, 4.25f, 0.5f, 0.25f, 2.5f, 1.25f,
2.0f, 4.0f, 5.0f, 0.0f, 0.0f, 0.0f,
0.0f, 0.0f, 0.0f, 0.0f};
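//! The constants above are the non-trivial coefficients of B^T for F(6, 3);
//! MADD/MSUB below select them by lane index (v0 lanes 0..7 = 5.25 .. 4.0,
//! v1 lane 0 = 5.0). The trailing zeros pad the array to two full float16x8
//! registers.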
struct InputTransformF63_NCHW88 {
template <bool inner>
static void prepare(
const gi_float16_t* input, gi_float16_t* patch, gi_float16_t* patchT,
int ih_start, int iw_start, size_t IH, size_t IW, size_t ic, size_t IC) {
MEGDNN_MARK_USED_VAR(patch);
size_t IW8 = IW * pack_size;
size_t iw8_start = iw_start * pack_size;
size_t icb = ic / pack_size;
if (!(inner && ic + pack_size < IC)) {
memset(patchT, 0, sizeof(gi_float16_t) * pack_size * alpha * alpha);
}
if (inner) {
const gi_float16_t* input_ptr =
input + icb * IH * IW8 + ih_start * IW8 + iw8_start;
for (size_t ih = 0; ih < alpha; ih++) {
#define cb(i) auto v##i = GiLoadFloat16(input_ptr + pack_size * i);
UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
#define cb(i) GiStoreFloat16(patchT + ih * pack_size * alpha + i * pack_size, v##i);
UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
input_ptr += IW8;
}
} else {
int ih0_act = std::max<int>(ih_start, 0),
ih1_act = std::min<int>(ih_start + alpha, IH),
iw0_act = std::max<int>(iw_start, 0),
iw1_act = std::min<int>(iw_start + alpha, IW);
const gi_float16_t* input_ptr = input + icb * IH * IW8;
// partial copy
for (int ih = ih0_act; ih < ih1_act; ++ih) {
for (int iw = iw0_act; iw < iw1_act; ++iw) {
size_t iho = ih - ih_start, iwo = iw - iw_start;
auto src = GiLoadFloat16(input_ptr + ih * IW8 + iw * pack_size);
GiStoreFloat16(
patchT + iho * pack_size * alpha + iwo * pack_size, src);
}
}
}
}
static void transform(
const gi_float16_t* patchT, gi_float16_t* input_transform_buf,
size_t unit_idx, size_t nr_units_in_tile, size_t ic, size_t IC) {
// BT * d * B
size_t ICB = IC / pack_size;
size_t icb = ic / pack_size;
GI_FLOAT16_t d0, d1, d2, d3, d4, d5, d6, d7;
#if defined(GI_TARGET_X86) || defined(GI_RVV_INTRINSICS)
//! the x86 and rvv GiSimdFmaLane API is slow; as an alternative, use
//! GiMultiplyAddScalarFloat16
#define MADD(a, b, c, d) GiMultiplyAddScalarFloat16(a, b, *(c + d))
#define MSUB(a, b, c, d) GiMultiplySubScalarFloat16(a, b, *(c + d))
const gi_float16_t* v0 = input_parameters + 0;
const gi_float16_t* v1 = input_parameters + 8;
#else
#define MADD(a, b, c, d) GiSimdFmaLaneFloat16(a, b, c, d)
#define MSUB(a, b, c, d) GiFmsqLaneQFloat16(a, b, c, d)
GI_FLOAT16_t v0 = GiLoadFloat16(input_parameters + 0);
GI_FLOAT16_t v1 = GiLoadFloat16(input_parameters + 8);
#endif
//! B
//!      1      0      0      0      0      0      0      0
//!      0      1     -1    0.5   -0.5      2     -2     -1
//!  -5.25      1      1   0.25   0.25      4      4      0
//!      0  -4.25   4.25   -2.5    2.5   -2.5    2.5   5.25
//!   5.25  -4.25  -4.25  -1.25  -1.25     -5     -5      0
//!      0      1     -1      2     -2    0.5   -0.5  -5.25
//!     -1      1      1      1      1      1      1      0
//!      0      0      0      0      0      0      0      1
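//! The transform is done in two passes: the first cb loop combines the rows
//! of the 8x8 patch (computing d * B), the second combines the columns of
//! that intermediate (computing B^T * (d * B)) and stores each of the 64
//! results into input_transform_buf with layout
//! (alpha * alpha, ICB, nr_units_in_tile, pack_size).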
#define cb(i) \
d1 = GiLoadFloat16(patchT + i * alpha * pack_size + 1 * pack_size); \
d2 = GiLoadFloat16(patchT + i * alpha * pack_size + 2 * pack_size); \
d3 = GiLoadFloat16(patchT + i * alpha * pack_size + 3 * pack_size); \
d4 = GiLoadFloat16(patchT + i * alpha * pack_size + 4 * pack_size); \
d5 = GiLoadFloat16(patchT + i * alpha * pack_size + 5 * pack_size); \
d6 = GiLoadFloat16(patchT + i * alpha * pack_size + 6 * pack_size); \
auto t##i##0 = GiLoadFloat16(patchT + i * alpha * pack_size + 0 * pack_size); \
auto t##i##7 = GiLoadFloat16(patchT + i * alpha * pack_size + 7 * pack_size); \
auto t##i##1 = d6; \
auto t##i##2 = d6; \
auto t##i##3 = d6; \
auto t##i##4 = d6; \
auto t##i##5 = d6; \
auto t##i##6 = d6; \
t##i##0 = SUBF16(t##i##0, d6); \
t##i##1 = ADDF16(t##i##1, d1); \
t##i##2 = SUBF16(t##i##2, d1); \
t##i##3 = MADD(t##i##3, d1, v0, 2); \
t##i##4 = MSUB(t##i##4, d1, v0, 2); \
t##i##5 = MADD(t##i##5, d1, v0, 6); \
t##i##6 = MSUB(t##i##6, d1, v0, 6); \
t##i##7 = SUBF16(t##i##7, d1); \
t##i##0 = MSUB(t##i##0, d2, v0, 0); \
t##i##1 = ADDF16(t##i##1, d2); \
t##i##2 = ADDF16(t##i##2, d2); \
t##i##3 = MADD(t##i##3, d2, v0, 3); \
t##i##4 = MADD(t##i##4, d2, v0, 3); \
t##i##5 = MADD(t##i##5, d2, v0, 7); \
t##i##6 = MADD(t##i##6, d2, v0, 7); \
t##i##1 = MSUB(t##i##1, d3, v0, 1); \
t##i##2 = MADD(t##i##2, d3, v0, 1); \
t##i##3 = MSUB(t##i##3, d3, v0, 4); \
t##i##4 = MADD(t##i##4, d3, v0, 4); \
t##i##5 = MSUB(t##i##5, d3, v0, 4); \
t##i##6 = MADD(t##i##6, d3, v0, 4); \
t##i##7 = MADD(t##i##7, d3, v0, 0); \
t##i##0 = MADD(t##i##0, d4, v0, 0); \
t##i##1 = MSUB(t##i##1, d4, v0, 1); \
t##i##2 = MSUB(t##i##2, d4, v0, 1); \
t##i##3 = MSUB(t##i##3, d4, v0, 5); \
t##i##4 = MSUB(t##i##4, d4, v0, 5); \
t##i##5 = MSUB(t##i##5, d4, v1, 0); \
t##i##6 = MSUB(t##i##6, d4, v1, 0); \
t##i##1 = ADDF16(t##i##1, d5); \
t##i##2 = SUBF16(t##i##2, d5); \
t##i##3 = MADD(t##i##3, d5, v0, 6); \
t##i##4 = MSUB(t##i##4, d5, v0, 6); \
t##i##5 = MADD(t##i##5, d5, v0, 2); \
t##i##6 = MSUB(t##i##6, d5, v0, 2); \
t##i##7 = MSUB(t##i##7, d5, v0, 0);
UNROLL_CALL_RAW(8, cb);
#undef cb
#define cb(i) \
d0 = t0##i; \
d1 = t6##i; \
d2 = t6##i; \
d3 = t6##i; \
d4 = t6##i; \
d5 = t6##i; \
d6 = t6##i; \
d7 = t7##i; \
d0 = SUBF16(d0, t6##i); \
d1 = ADDF16(d1, t1##i); \
d2 = SUBF16(d2, t1##i); \
d3 = MADD(d3, t1##i, v0, 2); \
d4 = MSUB(d4, t1##i, v0, 2); \
d5 = MADD(d5, t1##i, v0, 6); \
d6 = MSUB(d6, t1##i, v0, 6); \
d7 = SUBF16(d7, t1##i); \
d0 = MSUB(d0, t2##i, v0, 0); \
d1 = ADDF16(d1, t2##i); \
d2 = ADDF16(d2, t2##i); \
d3 = MADD(d3, t2##i, v0, 3); \
d4 = MADD(d4, t2##i, v0, 3); \
d5 = MADD(d5, t2##i, v0, 7); \
d6 = MADD(d6, t2##i, v0, 7); \
d1 = MSUB(d1, t3##i, v0, 1); \
d2 = MADD(d2, t3##i, v0, 1); \
d3 = MSUB(d3, t3##i, v0, 4); \
d4 = MADD(d4, t3##i, v0, 4); \
d5 = MSUB(d5, t3##i, v0, 4); \
d6 = MADD(d6, t3##i, v0, 4); \
d7 = MADD(d7, t3##i, v0, 0); \
d0 = MADD(d0, t4##i, v0, 0); \
d1 = MSUB(d1, t4##i, v0, 1); \
d2 = MSUB(d2, t4##i, v0, 1); \
d3 = MSUB(d3, t4##i, v0, 5); \
d4 = MSUB(d4, t4##i, v0, 5); \
d5 = MSUB(d5, t4##i, v1, 0); \
d6 = MSUB(d6, t4##i, v1, 0); \
d1 = ADDF16(d1, t5##i); \
d2 = SUBF16(d2, t5##i); \
d3 = MADD(d3, t5##i, v0, 6); \
d4 = MSUB(d4, t5##i, v0, 6); \
d5 = MADD(d5, t5##i, v0, 2); \
d6 = MSUB(d6, t5##i, v0, 2); \
d7 = MSUB(d7, t5##i, v0, 0); \
GiStoreFloat16( \
input_transform_buf + \
(0 * alpha + i) * ICB * nr_units_in_tile * pack_size + \
icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \
d0); \
GiStoreFloat16( \
input_transform_buf + \
(1 * alpha + i) * ICB * nr_units_in_tile * pack_size + \
icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \
d1); \
GiStoreFloat16( \
input_transform_buf + \
(2 * alpha + i) * ICB * nr_units_in_tile * pack_size + \
icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \
d2); \
GiStoreFloat16( \
input_transform_buf + \
(3 * alpha + i) * ICB * nr_units_in_tile * pack_size + \
icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \
d3); \
GiStoreFloat16( \
input_transform_buf + \
(4 * alpha + i) * ICB * nr_units_in_tile * pack_size + \
icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \
d4); \
GiStoreFloat16( \
input_transform_buf + \
(5 * alpha + i) * ICB * nr_units_in_tile * pack_size + \
icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \
d5); \
GiStoreFloat16( \
input_transform_buf + \
(6 * alpha + i) * ICB * nr_units_in_tile * pack_size + \
icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \
d6); \
GiStoreFloat16( \
input_transform_buf + \
(7 * alpha + i) * ICB * nr_units_in_tile * pack_size + \
icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \
d7);
UNROLL_CALL_RAW(8, cb);
#undef cb
#undef MADD
#undef MSUB
}
};
template <BiasMode bmode, typename Op>
struct OutputTransformF63_NCHW88 {
static void transform(
const gi_float16_t* output_transform_buf, const gi_float16_t* bias,
gi_float16_t* output, gi_float16_t* transform_mid_buf, size_t oh_start,
size_t ow_start, size_t OH, size_t OW, size_t oc_start, size_t oc_end,
size_t oc_index, size_t unit_idx, size_t nr_units_in_tile,
const DType& src_dtype, const DType& dst_dtype) {
MEGDNN_MARK_USED_VAR(transform_mid_buf);
Op op(src_dtype, dst_dtype);
//! AT * m * A
size_t oc = oc_start + oc_index;
size_t OCB = (oc_end - oc_start) / pack_size;
size_t ocb = oc_index / pack_size;
#define cb(m, n) \
auto v##m##n = GiLoadFloat16( \
output_transform_buf + \
(m * alpha + n) * OCB * nr_units_in_tile * pack_size + \
ocb * nr_units_in_tile * pack_size + unit_idx * pack_size);
UNROLL_CALL_NOWRAPPER_D2(8, 8, cb);
#undef cb
/**
* A
*
*  1     0     0       0       0         0
*  1     1     1       1       1         1
*  1    -1     1      -1       1        -1
*  1     2     4       8      16        32
*  1    -2     4      -8      16       -32
*  1   0.5  0.25   0.125  0.0625   0.03125
*  1  -0.5  0.25  -0.125  0.0625  -0.03125
*  0     0     0       0       0         1
*/
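//! The output transform is also done in two passes: the first cb loop
//! computes A^T * m column by column (t0..t5), the second applies A along
//! the remaining axis, leaving the 6x6 result A^T * m * A in v00..v55
//! before bias and activation are applied.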
/*
* v1addv2 = v1##m + v2##m;
* v1subv2 = v1##m - v2##m;
* v3addv4 = v3##m + v4##m;
* v3subv4 = v3##m - v4##m;
* v5addv6 = v5##m + v6##m;
* v5subv6 = v5##m - v6##m;
* auto t0##m = v0##m + v1addv2 + v3addv4 + v5addv6;
* auto t1##m = v1subv2 + v3subv4 * 2.f + v5subv6 * 0.5f;
* auto t2##m = v1addv2 + v3addv4 * 4.f + v5addv6 * 0.25f;
* auto t3##m = v1subv2 + v3subv4 * 8.f + v5subv6 * 0.125f;
* auto t4##m = v1addv2 + v3addv4 * 16.f + v5addv6 * 0.0625f;
* auto t5##m = v1subv2 + v3subv4 * 32.f + v5subv6 * 0.03125f + v7##m;
*/
GI_FLOAT16_t v1addv2, v1subv2, v3addv4, v3subv4, v5addv6, v5subv6;
#define cb(m) \
v1addv2 = ADDF16(v1##m, v2##m); \
v1subv2 = SUBF16(v1##m, v2##m); \
v3addv4 = ADDF16(v3##m, v4##m); \
v3subv4 = SUBF16(v3##m, v4##m); \
v5addv6 = ADDF16(v5##m, v6##m); \
v5subv6 = SUBF16(v5##m, v6##m); \
auto t0##m = ADDF16(ADDF16(ADDF16(v0##m, v1addv2), v3addv4), v5addv6); \
auto t1##m = \
ADDF16(ADDF16(v1subv2, MULSF16(v3subv4, 2.f)), MULSF16(v5subv6, 0.5f)); \
auto t2##m = \
ADDF16(ADDF16(v1addv2, MULSF16(v3addv4, 4.f)), MULSF16(v5addv6, 0.25f)); \
auto t3##m = \
ADDF16(ADDF16(v1subv2, MULSF16(v3subv4, 8.f)), MULSF16(v5subv6, 0.125f)); \
auto t4##m = ADDF16( \
ADDF16(v1addv2, MULSF16(v3addv4, 16.f)), MULSF16(v5addv6, 0.0625f)); \
auto t5##m = \
ADDF16(ADDF16(ADDF16(v1subv2, MULSF16(v3subv4, 32.f)), \
MULSF16(v5subv6, 0.03125f)), \
v7##m);
UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
/*
* v1addv2 = t##m##1 + t##m##2;
* v1subv2 = t##m##1 - t##m##2;
* v3addv4 = t##m##3 + t##m##4;
* v3subv4 = t##m##3 - t##m##4;
* v5addv6 = t##m##5 + t##m##6;
* v5subv6 = t##m##5 - t##m##6;
* v##m##0 = t##m##0 + v1addv2 + v3addv4 + v5addv6;
* v##m##1 = v1subv2 + v3subv4 * 2.f + v5subv6 * 0.5f;
* v##m##2 = v1addv2 + v3addv4 * 4.f + v5addv6 * 0.25f;
* v##m##3 = v1subv2 + v3subv4 * 8.f + v5subv6 * 0.125f;
* v##m##4 = v1addv2 + v3addv4 * 16.f + v5addv6 * 0.0625f;
* v##m##5 = v1subv2 + v3subv4 * 32.f + v5subv6 * 0.03125f + t##m##7;
*/
#define cb(m) \
v1addv2 = ADDF16(t##m##1, t##m##2); \
v1subv2 = SUBF16(t##m##1, t##m##2); \
v3addv4 = ADDF16(t##m##3, t##m##4); \
v3subv4 = SUBF16(t##m##3, t##m##4); \
v5addv6 = ADDF16(t##m##5, t##m##6); \
v5subv6 = SUBF16(t##m##5, t##m##6); \
v##m##0 = ADDF16(ADDF16(ADDF16(t##m##0, v1addv2), v3addv4), v5addv6); \
v##m##1 = ADDF16(ADDF16(v1subv2, MULSF16(v3subv4, 2.f)), MULSF16(v5subv6, 0.5f)); \
v##m##2 = ADDF16(ADDF16(v1addv2, MULSF16(v3addv4, 4.f)), MULSF16(v5addv6, 0.25f)); \
v##m##3 = \
ADDF16(ADDF16(v1subv2, MULSF16(v3subv4, 8.f)), MULSF16(v5subv6, 0.125f)); \
v##m##4 = ADDF16( \
ADDF16(v1addv2, MULSF16(v3addv4, 16.f)), MULSF16(v5addv6, 0.0625f)); \
v##m##5 = \
ADDF16(ADDF16(ADDF16(v1subv2, MULSF16(v3subv4, 32.f)), \
MULSF16(v5subv6, 0.03125f)), \
t##m##7);
UNROLL_CALL_NOWRAPPER(6, cb);
#undef cb
GI_FLOAT16_t vbias;
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) {
vbias = GiLoadFloat16(bias + oc);
#define cb(m, n) v##m##n = ADDF16(v##m##n, vbias);
UNROLL_CALL_RAW_D2(6, 6, cb);
#undef cb
}
if (bmode != BiasMode::BIAS) {
#define cb(m, n) v##m##n = op(CONCAT(v##m, n));
UNROLL_CALL_RAW_D2(6, 6, cb);
#undef cb
}
#define out_save(oho, owo) \
do { \
size_t oh = oh_start + oho; \
size_t ow = ow_start + owo; \
if (oh < OH && ow < OW) { \
if (bmode == BiasMode::BIAS) { \
v##oho##owo = ADDF16( \
v##oho##owo, GiLoadFloat16( \
bias + oc * OH * OW + \
oh * OW * pack_size + ow * pack_size)); \
v##oho##owo = op(v##oho##owo); \
} \
GiStoreFloat16( \
output + oc * OH * OW + oh * OW * pack_size + ow * pack_size, \
v##oho##owo); \
} \
} while (0);
UNROLL_CALL_RAW_D2(6, 6, out_save);
}
#undef out_save
};
} // namespace
namespace megdnn {
namespace fallback {
namespace winograd {
MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_F63_mk8_f16_nchw88)
void winograd_F63_mk8_f16_nchw88::filter(
const dt_float16* filter, dt_float16* filter_transform_buf,
dt_float16* transform_mid_buf, size_t OC, size_t IC, size_t oc_start,
size_t oc_end) {
constexpr size_t pack_size = 8;
// Gg * GT
// G
// 1.0000000 0.0000000 0.0000000
// -0.2222222 -0.2222222 -0.2222222
// -0.2222222 0.2222222 -0.2222222
// 0.0111111 0.0222222 0.0444444
// 0.0111111 -0.0222222 0.0444444
// 0.7111111 0.3555556 0.1777778
// 0.7111111 -0.3555556 0.1777778
// 0.0000000 0.0000000 1.0000000
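// The decimals above correspond to the exact fractions -2/9, 1/90, 1/45,
// 2/45, 32/45, 16/45 and 8/45; FILTER_TRANSFORM below spells some of them
// as fractions and the rest as decimal literals.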
MEGDNN_MARK_USED_VAR(transform_mid_buf);
megdnn_assert(
(oc_end - oc_start) % pack_size == 0 && oc_start % pack_size == 0 &&
oc_end % pack_size == 0 && IC % pack_size == 0 &&
OC % pack_size == 0,
"NCHW88 Winograd filter transform requires both OC and IC "
"are times of 8");
size_t ICB = IC / pack_size;
for (size_t ocb = oc_start / pack_size; ocb < oc_end / pack_size; ocb++) {
for (size_t icb = 0; icb < ICB; icb++) {
for (size_t ic_inner = 0; ic_inner < pack_size; ic_inner++) {
const gi_float16_t* fptr =
reinterpret_cast<const gi_float16_t*>(filter) +
(ocb * ICB + icb) * KERNEL_SIZE * KERNEL_SIZE * pack_size *
pack_size +
ic_inner * pack_size;
#define cb(m, n) \
GI_FLOAT16_t g##m##n = \
GiLoadFloat16(fptr + (m * KERNEL_SIZE + n) * pack_size * pack_size);
UNROLL_CALL_NOWRAPPER_D2(3, 3, cb)
#undef cb
/*
* auto wd##n##0 = g##0##n;
* tmp0 = (g##0##n + g##2##n) * -0.2222222f;
* tmp1 = g##1##n * -0.2222222f;
* auto wd##n##1 = tmp0 + tmp1;
* auto wd##n##2 = tmp0 - tmp1;
* tmp0 = g##0##n * 0.0111111f + g##2##n * 0.0444444f;
* tmp1 = g##1##n * 0.0222222f;
* auto wd##n##3 = tmp0 + tmp1;
* auto wd##n##4 = tmp0 - tmp1;
* tmp0 = g##0##n * 0.7111111f + g##2##n * 0.1777778f;
* tmp1 = g##1##n * 0.3555556f;
* auto wd##n##5 = tmp0 + tmp1;
* auto wd##n##6 = tmp0 - tmp1;
* auto wd##n##7 = g##2##n;
*/
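//! FILTER_TRANSFORM is applied twice: the first pass over the 3 columns of g
//! yields wd (G * g stored transposed, 3 x 8), the second pass over the 8
//! rows of that intermediate yields ret = G * g * G^T, which cb_save scatters
//! into filter_transform_buf.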
#define FILTER_TRANSFORM(n, wd, g) \
auto wd##n##0 = g##0##n; \
tmp0 = MULSF16(ADDF16(g##0##n, g##2##n), -2.0f / 9); \
tmp1 = MULSF16(g##1##n, -2.0f / 9); \
auto wd##n##1 = ADDF16(tmp0, tmp1); \
auto wd##n##2 = SUBF16(tmp0, tmp1); \
tmp0 = ADDF16(MULSF16(g##0##n, 1.0f / 90), MULSF16(g##2##n, 2.0f / 45)); \
tmp1 = MULSF16(g##1##n, 1.0f / 45); \
auto wd##n##3 = ADDF16(tmp0, tmp1); \
auto wd##n##4 = SUBF16(tmp0, tmp1); \
tmp0 = ADDF16(MULSF16(g##0##n, 0.7111111f), MULSF16(g##2##n, 0.1777778f)); \
tmp1 = MULSF16(g##1##n, 0.3555556f); \
auto wd##n##5 = ADDF16(tmp0, tmp1); \
auto wd##n##6 = SUBF16(tmp0, tmp1); \
auto wd##n##7 = g##2##n;
GI_FLOAT16_t tmp0, tmp1;
UNROLL_CALL_RAW(3, FILTER_TRANSFORM, wd, g);
UNROLL_CALL_RAW(8, FILTER_TRANSFORM, ret, wd);
#undef FILTER_TRANSFORM
#define cb_save(m, n) \
GiStoreFloat16( \
reinterpret_cast<gi_float16_t*>(filter_transform_buf) + \
(m * alpha + n) * OC * IC + ocb * IC * pack_size + \
icb * pack_size * pack_size + ic_inner * pack_size, \
ret##m##n);
UNROLL_CALL_NOWRAPPER_D2(8, 8, cb_save)
#undef cb_save
}
}
}
}
void winograd_F63_mk8_f16_nchw88::input(
const dt_float16* input, dt_float16* input_transform_buf,
dt_float16* transform_mid_buf, size_t IH, size_t IW, size_t IC, size_t PH,
size_t PW, size_t unit_start_idx, size_t nr_units_in_tile) {
constexpr size_t pack_size = 8;
megdnn_assert(IC % pack_size == 0);
constexpr int alpha = 3 + 6 - 1;
// OW = IW + 2 * PW - KERNEL_SIZE + 1
auto units_w = div_ceil<size_t>(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE);
gi_float16_t* patch = reinterpret_cast<gi_float16_t*>(transform_mid_buf);
gi_float16_t* patchT = reinterpret_cast<gi_float16_t*>(transform_mid_buf) +
pack_size * alpha * alpha;
for (size_t ic = 0; ic < IC; ic += pack_size) {
rep(unit_idx, nr_units_in_tile) {
size_t index = unit_start_idx + unit_idx;
size_t nh = index / units_w;
size_t nw = index % units_w;
int ih_start = nh * OUTPUT_BLOCK_SIZE - PH;
int iw_start = nw * OUTPUT_BLOCK_SIZE - PW;
if (ih_start >= 0 && ih_start + alpha <= static_cast<int>(IH) &&
iw_start >= 0 && iw_start + alpha <= static_cast<int>(IW)) {
InputTransformF63_NCHW88::prepare<true>(
reinterpret_cast<const gi_float16_t*>(input), patch, patchT,
ih_start, iw_start, IH, IW, ic, IC);
InputTransformF63_NCHW88::transform(
patchT, reinterpret_cast<gi_float16_t*>(input_transform_buf),
unit_idx, nr_units_in_tile, ic, IC);
} else {
InputTransformF63_NCHW88::prepare<false>(
reinterpret_cast<const gi_float16_t*>(input), patch, patchT,
ih_start, iw_start, IH, IW, ic, IC);
InputTransformF63_NCHW88::transform(
patchT, reinterpret_cast<gi_float16_t*>(input_transform_buf),
unit_idx, nr_units_in_tile, ic, IC);
}
}
}
}
void winograd_F63_mk8_f16_nchw88::output(
const dt_float16* output_transform_buf, const dt_float16* bias,
dt_float16* output, dt_float16* transform_mid_buf, BiasMode bmode,
NonlineMode nonline_mode, size_t OH, size_t OW, size_t oc_start, size_t oc_end,
size_t unit_start_idx, size_t nr_units_in_tile) {
#define cb(_bmode, _nonline_op, ...) \
for (size_t oc = oc_start; oc < oc_end; oc += pack_size) { \
size_t oc_index = oc - oc_start; \
rep(unit_idx, nr_units_in_tile) { \
size_t index = unit_start_idx + unit_idx; \
auto nh = index / units_w; \
auto nw = index % units_w; \
size_t oh_start = nh * OUTPUT_BLOCK_SIZE; \
size_t ow_start = nw * OUTPUT_BLOCK_SIZE; \
OutputTransformF63_NCHW88<_bmode MEGDNN_COMMA _nonline_op>::transform( \
reinterpret_cast<const gi_float16_t*>(output_transform_buf), \
reinterpret_cast<const gi_float16_t*>(bias), \
reinterpret_cast<gi_float16_t*>(output), \
reinterpret_cast<gi_float16_t*>(transform_mid_buf), oh_start, \
ow_start, OH, OW, oc_start, oc_end, oc_index, unit_idx, \
nr_units_in_tile, src_dtype, dst_dtype); \
} \
}
auto units_w = div_ceil<size_t>(OW, OUTPUT_BLOCK_SIZE);
constexpr size_t pack_size = 8;
size_t OC = oc_end - oc_start;
megdnn_assert(
OC % pack_size == 0 && oc_start % pack_size == 0 && oc_end % pack_size == 0,
"NCHW88 Winograd filter transform requires OC is times of 8");
GI_DISPATCH_CONV_WINOGRAD_BIAS(
megdnn_fallback_winograd_fp16_F63_mk8, cb, gi_float16_t, gi_float16_t,
bmode, nonline_mode);
#undef cb
}
} // namespace winograd
} // namespace fallback
} // namespace megdnn
#endif
// vim: syntax=cpp.doxygen
......@@ -204,6 +204,11 @@ public:
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
tile_size));
m_gi_winograd_algos.emplace_back(refhold.back().get());
refhold.emplace_back(new AlgoFP16WinogradF63_8x8_NCHW88(
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
tile_size));
m_gi_winograd_algos.emplace_back(refhold.back().get());
}
}
#endif
......
......@@ -228,6 +228,7 @@ public:
GI_COMMON_WINOGRAD_F73_4X4_NCHW44_F32,
GI_COMMON_WINOGRAD_F23_8X8_NCHW88_F16,
GI_COMMON_WINOGRAD_F43_8X8_NCHW88_F16,
GI_COMMON_WINOGRAD_F63_8X8_NCHW88_F16,
GI_COMMON_DIRECT_FP32,
GI_COMMON_DIRECT_STRD1_FP32,
GI_COMMON_DIRECT_STRD2_FP32,
......@@ -397,6 +398,7 @@ private:
class AlgoFP16WinogradF23_8x8_NCHW88;
class AlgoFP16WinogradF43_8x8_NCHW88;
class AlgoFP16WinogradF63_8x8_NCHW88;
class AlgoF32Direct;
class AlgoF32DirectStride1;
......
......@@ -634,6 +634,19 @@ TEST_F(FALLBACK_MULTI_THREADS, CONVBIAS_GI_WINOGRAD_F43_8_NCHW88_FP16) {
"8:4:", checker, args, &rng, 0.006, param::MatrixMul::Format::MK8,
"WINOGRAD_NCHW88");
}
TEST_F(FALLBACK_MULTI_THREADS, CONVBIAS_GI_WINOGRAD_F63_8_NCHW88_FP16) {
using namespace conv_bias;
std::vector<TestArg> args =
get_nchw88_conv_bias_args({3}, FULL_NLMODE, BR_AND_NO_BIASMODE, 1);
Checker<ConvBiasForward, OprWeightPreprocessProxy<ConvBiasForward>> checker(
handle());
Float16PeriodicalRNG rng(0x3c00);
check_winograd_fp16(
"8:6:", checker, args, &rng, 0.019, param::MatrixMul::Format::MK8,
"WINOGRAD_NCHW88");
}
#endif
TEST_F(FALLBACK_MULTI_THREADS, CONVBIAS_GI_WINOGRAD_F43_4_WEIGHT_PREPROCESS) {
......@@ -1407,7 +1420,7 @@ TEST_F(FALLBACK, BENCHMARK_GI_WINOGRAD_F23_FP32_NCHW44_VS_FP16_NCHW88) {
std::string algo_name_fp32 = "WINOGRAD_NCHW44:FB_GI_F32_MK4_4x8:4:2";
benchmark_with_contrast(
args_with_computation_fp16, algo_name_fp16, data_type_fp16,
args_with_computation_fp32, algo_name_fp32, data_type_fp32, 10, {1, {0}});
args_with_computation_fp32, algo_name_fp32, data_type_fp32, 10, {1, {4}});
}
TEST_F(FALLBACK, BENCHMARK_GI_WINOGRAD_F43_FP32_NCHW44_VS_FP16_NCHW88) {
......@@ -1447,6 +1460,44 @@ TEST_F(FALLBACK, BENCHMARK_GI_WINOGRAD_F43_FP32_NCHW44_VS_FP16_NCHW88) {
args_with_computation_fp16, algo_name_fp16, data_type_fp16,
args_with_computation_fp32, algo_name_fp32, data_type_fp32, 10, {1, {0}});
}
TEST_F(FALLBACK, BENCHMARK_GI_WINOGRAD_F63_FP32_NCHW44_VS_FP16_NCHW88) {
auto&& args_fp16 = conv_bias::get_winograd_benchmark_args(3, 8, 8);
auto&& args_fp32 = conv_bias::get_winograd_benchmark_args(3, 4, 4);
auto cal_computation = [](const conv_bias::TestArg& arg) {
TensorShape dst_shape{
arg.src[0], arg.filter[0],
(arg.src[2] + arg.param.pad_h * 2 - arg.filter[2]) /
arg.param.stride_h +
1,
(arg.src[3] + arg.param.pad_w * 2 - arg.filter[3]) /
arg.param.stride_w +
1,
arg.filter[5]};
return dst_shape.total_nr_elems() * arg.filter[1] * arg.filter[2] *
arg.filter[3] * arg.filter[4] * 2.0 / (1024 * 1024 * 1024) * 1e3;
};
std::vector<std::pair<conv_bias::TestArg, float>> args_with_computation_fp16,
args_with_computation_fp32;
for (const auto& arg : args_fp16) {
args_with_computation_fp16.emplace_back(arg, cal_computation(arg));
}
for (const auto& arg : args_fp32) {
args_with_computation_fp32.emplace_back(arg, cal_computation(arg));
}
std::vector<DType> data_type_fp16 = {
dtype::Float16(), dtype::Float16(), dtype::Float16(), dtype::Float16()};
std::vector<DType> data_type_fp32 = {
dtype::Float32(), dtype::Float32(), dtype::Float32(), dtype::Float32()};
std::string algo_name_fp16 = "WINOGRAD_NCHW88:FB_GI_F16_MK8_8x8:8:6";
std::string algo_name_fp32 = "WINOGRAD_NCHW44:FB_GI_F32_MK4_4x8:4:6";
benchmark_with_contrast(
args_with_computation_fp16, algo_name_fp16, data_type_fp16,
args_with_computation_fp32, algo_name_fp32, data_type_fp32, 10, {1, {4}});
}
#endif
#endif
......