Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
63032170
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
63032170
编写于
2月 08, 2023
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(dnn/fallback): add gi fp16 nchw88 winograd F63 algo
GitOrigin-RevId: d986e1cbebd0f9ad89c27a62bd4e951459165d3d
上级
10503cfa
变更
8
隐藏空白更改
内联
并排
Showing
8 changed files
with
693 additions
and
1 deletion
+693
-1
dnn/src/fallback/conv_bias/gi/fp16/algos.cpp
dnn/src/fallback/conv_bias/gi/fp16/algos.cpp
+38
-0
dnn/src/fallback/conv_bias/gi/fp16/algos.h
dnn/src/fallback/conv_bias/gi/fp16/algos.h
+18
-0
dnn/src/fallback/conv_bias/gi/fp16/helper.h
dnn/src/fallback/conv_bias/gi/fp16/helper.h
+1
-0
dnn/src/fallback/conv_bias/gi/fp16/strategy.h
dnn/src/fallback/conv_bias/gi/fp16/strategy.h
+3
-0
dnn/src/fallback/conv_bias/gi/fp16/strategy_f63_mk8_nchw88.cpp
...rc/fallback/conv_bias/gi/fp16/strategy_f63_mk8_nchw88.cpp
+574
-0
dnn/src/fallback/conv_bias/opr_impl.cpp
dnn/src/fallback/conv_bias/opr_impl.cpp
+5
-0
dnn/src/fallback/conv_bias/opr_impl.h
dnn/src/fallback/conv_bias/opr_impl.h
+2
-0
dnn/test/fallback/conv_bias.cpp
dnn/test/fallback/conv_bias.cpp
+52
-1
未找到文件。
dnn/src/fallback/conv_bias/gi/fp16/algos.cpp
浏览文件 @
63032170
...
...
@@ -91,5 +91,43 @@ MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL(
AlgoFP16WinogradF23_8x8_NCHW88
,
winograd
::
winograd_F23_mk8_f16_nchw88
,
megdnn_fallback_winograd_fp16_nchw88
,
param
::
MatrixMul
::
Format
::
MK8
);
/* =================== AlgoFP16WinogradF63_8x8_NCHW88 ===================== */
bool
ConvBiasImpl
::
AlgoFP16WinogradF63_8x8_NCHW88
::
usable
(
const
NCBKernSizeParam
&
param
,
AlgoSelectionStrategy
/*algo_selection_strategy*/
)
const
{
MIDOUT_BEGIN
(
megdnn_fallback_winograd_fp16_nchw88
,
midout_iv
(
"AlgoFP16WinogradF63_8x8_NCHW88"
_hash
))
{
if
(
param
.
filter_meta
.
icpg
%
8
!=
0
||
param
.
filter_meta
.
ocpg
%
8
!=
0
)
return
false
;
using
Strategy
=
winograd
::
winograd_F63_mk8_f16_nchw88
;
Strategy
strategy
(
param
.
src_type
,
param
.
filter_type
,
param
.
dst_type
);
auto
&&
matmul_param
=
megdnn
::
winograd
::
ConvBias
<
Strategy
,
param
::
MatrixMul
::
Format
::
MK8
>
(
strategy
,
m_tile_size
,
param
)
.
get_matmul_kern_param
(
param
);
return
m_matmul_algo
->
usable
(
matmul_param
)
&&
m_matmul_algo
->
packmode
()
==
fallback
::
MatrixMulImpl
::
AlgoBase
::
PackMode
::
NO_PACK
&&
param
.
filter_meta
.
format
==
param
::
ConvBias
::
Format
::
NCHW88
&&
!
param
.
filter_meta
.
should_flip
&&
(
param
.
filter_meta
.
spatial
[
0
]
==
param
.
filter_meta
.
spatial
[
1
]
&&
param
.
filter_meta
.
spatial
[
0
]
==
3
)
&&
(
param
.
filter_meta
.
stride
[
0
]
==
param
.
filter_meta
.
stride
[
1
]
&&
param
.
filter_meta
.
stride
[
0
]
==
1
)
&&
(
param
.
filter_meta
.
dilation
[
0
]
==
param
.
filter_meta
.
dilation
[
1
]
&&
param
.
filter_meta
.
dilation
[
0
]
==
1
)
&&
param
.
compute_mode
==
param
::
ConvBias
::
ComputeMode
::
DEFAULT
&&
param
.
src_type
.
enumv
()
==
DTypeEnum
::
Float16
;
}
MIDOUT_END
();
return
false
;
}
//! Instantiates the remaining member functions (kern dispatch etc.) of the
//! F63 algorithm through the shared winograd helper macro, bound to the
//! F63 strategy and the MK8 matmul format.
MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL(
        AlgoFP16WinogradF63_8x8_NCHW88, winograd::winograd_F63_mk8_f16_nchw88,
        megdnn_fallback_winograd_fp16_nchw88, param::MatrixMul::Format::MK8);
#endif
// vim: syntax=cpp.doxygen
dnn/src/fallback/conv_bias/gi/fp16/algos.h
浏览文件 @
63032170
...
...
@@ -46,6 +46,24 @@ public:
MEGDNN_DECL_ALGO_TYPE
(
GI_COMMON_WINOGRAD_F23_8X8_NCHW88_F16
)
};
//! F(6x6, 3x3) fp16 winograd conv_bias algorithm for the NCHW88 layout,
//! delegating the batched gemm stage to an 8x8 MK8 matmul algorithm.
class ConvBiasImpl::AlgoFP16WinogradF63_8x8_NCHW88 final : public AlgoBase {
public:
    //! \param matmul_algo  underlying MK8 matmul used for the gemm stage
    //!                     (not owned; assumed to outlive this object)
    //! \param tile_size    number of winograd units processed per tile
    AlgoFP16WinogradF63_8x8_NCHW88(
            fallback::MatrixMulImpl::AlgoBase* matmul_algo, uint32_t tile_size)
            : m_matmul_algo{matmul_algo}, m_tile_size{tile_size} {}
    //! Lazily builds the canonical winograd algorithm name; the param pack
    //! {8, 6, m_tile_size, 3} encodes pack size, output block, tile size and
    //! kernel size.
    const char* name() const override {
        if (m_name.empty()) {
            m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>(
                    m_matmul_algo->name(), {8, 6, m_tile_size, 3},
                    param::ConvBias::Format::NCHW88);
        }
        return m_name.c_str();
    }
    AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
    MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT16);
    MEGDNN_DECL_ALGO_TYPE(GI_COMMON_WINOGRAD_F63_8X8_NCHW88_F16)
};
}
// namespace fallback
}
// namespace megdnn
...
...
dnn/src/fallback/conv_bias/gi/fp16/helper.h
浏览文件 @
63032170
...
...
@@ -7,4 +7,5 @@
#define MULSF16 GiMultiplyScalerFloat16
#endif
#define CONCAT(a, idx) a##idx
// vim: syntax=cpp.doxygen
dnn/src/fallback/conv_bias/gi/fp16/strategy.h
浏览文件 @
63032170
...
...
@@ -17,6 +17,9 @@ MEGDNN_REG_WINOGRAD_STRATEGY(
MEGDNN_REG_WINOGRAD_STRATEGY
(
dt_float16
,
dt_float16
,
dt_float16
,
dt_float16
,
2
,
3
,
8
,
8
,
winograd_F23_mk8_f16_nchw88
)
//! Registers the F(6x6, 3x3) strategy: fp16 src/filter/dst/compute types,
//! output block 6, kernel 3, 8x8 channel packing (NCHW88).
MEGDNN_REG_WINOGRAD_STRATEGY(
        dt_float16, dt_float16, dt_float16, dt_float16, 6, 3, 8, 8,
        winograd_F63_mk8_f16_nchw88)
}
// namespace winograd
}
// namespace fallback
}
// namespace megdnn
...
...
dnn/src/fallback/conv_bias/gi/fp16/strategy_f63_mk8_nchw88.cpp
0 → 100644
浏览文件 @
63032170
#include "src/fallback/conv_bias/gi/fp16/strategy.h"
#if defined(GI_SUPPORT_F16)
#include "src/common/unroll_macro.h"
#include "src/common/utils.h"
#include "src/common/winograd/winograd_helper.h"
#include "src/fallback/conv_bias/gi/fp16/helper.h"
#include "src/fallback/conv_bias/winograd/winograd.h"
#include "src/fallback/elemwise_helper/op_unary.h"
#include "midout.h"
MIDOUT_DECL(megdnn_fallback_winograd_fp16_F63_mk8)

using namespace megdnn;
using namespace fallback;

namespace {
//! side of the transformed tile: output block (6) + kernel (3) - 1
constexpr size_t alpha = 6 + 3 - 1;
//! channel packing factor of the NCHW88 layout
constexpr size_t pack_size = 8;

//! constants of the input-transform matrix B, laid out as two 8-lane fp16
//! vectors so the fma helpers below can broadcast them by lane index
constexpr gi_float16_t input_parameters[16] = {
        5.25f, 4.25f, 0.5f, 0.25f, 2.5f, 1.25f, 2.0f, 4.0f,
        5.0f,  0.0f,  0.0f, 0.0f,  0.0f, 0.0f,  0.0f, 0.0f};
//! Gathers an alpha x alpha x 8 input patch and applies the winograd input
//! transform B^T * d * B for one NCHW88 channel block.
struct InputTransformF63_NCHW88 {
    //! Copies one alpha x alpha patch of the packed input into patchT.
    //! \tparam inner    true when the patch lies entirely inside the image,
    //!                  enabling a fast contiguous row copy
    //! \param patch     unused scratch, kept for interface symmetry
    //! \param patchT    destination buffer of alpha * alpha * 8 fp16 values
    //! \param ih_start  top row of the patch; may be negative with padding
    //! \param iw_start  left column of the patch; may be negative with padding
    template <bool inner>
    static void prepare(
            const gi_float16_t* input, gi_float16_t* patch, gi_float16_t* patchT,
            int ih_start, int iw_start, size_t IH, size_t IW, size_t ic, size_t IC) {
        MEGDNN_MARK_USED_VAR(patch);
        size_t IW8 = IW * pack_size;
        size_t iw8_start = iw_start * pack_size;
        size_t icb = ic / pack_size;
        //! zero-fill so the transform reads defined data when part of the
        //! patch falls outside the image (or on the last channel block)
        if (!(inner && ic + pack_size < IC)) {
            memset(patchT, 0, sizeof(gi_float16_t) * pack_size * alpha * alpha);
        }
        if (inner) {
            const gi_float16_t* input_ptr =
                    input + icb * IH * IW8 + ih_start * IW8 + iw8_start;
            //! fully-inner patch: copy alpha rows of 8 packed vectors each
            for (size_t ih = 0; ih < alpha; ih++) {
#define cb(i) auto v##i = GiLoadFloat16(input_ptr + pack_size * i);
                UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
#define cb(i) GiStoreFloat16(patchT + ih * pack_size * alpha + i * pack_size, v##i);
                UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
                input_ptr += IW8;
            }
        } else {
            //! clamp the patch window to the valid image region
            int ih0_act = std::max<int>(ih_start, 0),
                ih1_act = std::min<int>(ih_start + alpha, IH),
                iw0_act = std::max<int>(iw_start, 0),
                iw1_act = std::min<int>(iw_start + alpha, IW);
            const gi_float16_t* input_ptr = input + icb * IH * IW8;
            // partial copy
            for (int ih = ih0_act; ih < ih1_act; ++ih) {
                for (int iw = iw0_act; iw < iw1_act; ++iw) {
                    size_t iho = ih - ih_start, iwo = iw - iw_start;
                    auto src = GiLoadFloat16(input_ptr + ih * IW8 + iw * pack_size);
                    GiStoreFloat16(
                            patchT + iho * pack_size * alpha + iwo * pack_size, src);
                }
            }
        }
    }

    //! Applies B^T * d * B to the patch in patchT and scatters the 8x8
    //! transformed coefficients into input_transform_buf with layout
    //! (alpha*alpha, ICB, nr_units_in_tile, 8).
    static void transform(
            const gi_float16_t* patchT, gi_float16_t* input_transform_buf,
            size_t unit_idx, size_t nr_units_in_tile, size_t ic, size_t IC) {
        // BT * d * B
        size_t ICB = IC / pack_size;
        size_t icb = ic / pack_size;
        GI_FLOAT16_t d0, d1, d2, d3, d4, d5, d6, d7;
#if defined(GI_TARGET_X86) || defined(GI_RVV_INTRINSICS)
        //! x86 and rvv GiSimdFmaLane API is slowly, as an alternate, use
        //! GiMultiplyAddScalarFloat32
#define MADD(a, b, c, d) GiMultiplyAddScalarFloat16(a, b, *(c + d))
#define MSUB(a, b, c, d) GiMultiplySubScalarFloat16(a, b, *(c + d))
        const gi_float16_t* v0 = input_parameters + 0;
        const gi_float16_t* v1 = input_parameters + 8;
        // const float* v2 = input_parameters + 8;
#else
#define MADD(a, b, c, d) GiSimdFmaLaneFloat16(a, b, c, d)
#define MSUB(a, b, c, d) GiFmsqLaneQFloat16(a, b, c, d)
        GI_FLOAT16_t v0 = GiLoadFloat16(input_parameters + 0);
        GI_FLOAT16_t v1 = GiLoadFloat16(input_parameters + 8);
        // GI_FLOAT32_t v2 = GiLoadFloat32(input_parameters + 8);
#endif
        //! B
        //! 1     0     0     0     0    0    0     0
        //! 0     1    -1     0.5  -0.5  2   -2    -1
        //! -5.25 1     1     0.25  0.25 4    4     0
        //! 0    -4.25  4.25 -2.5   2.5 -2.5  2.5   5.25
        //! 5.25 -4.25 -4.25 -1.25 -1.25 -5  -5     0
        //! 0     1    -1     2    -2    0.5 -0.5  -5.25
        //! -1    1     1     1     1    1    1     0
        //! 0     0     0     0     0    0    0     1
        //! first pass: per row i, compute t##i##k = (B^T * d) row
#define cb(i)                                                                     \
    d1 = GiLoadFloat16(patchT + i * alpha * pack_size + 1 * pack_size);           \
    d2 = GiLoadFloat16(patchT + i * alpha * pack_size + 2 * pack_size);           \
    d3 = GiLoadFloat16(patchT + i * alpha * pack_size + 3 * pack_size);           \
    d4 = GiLoadFloat16(patchT + i * alpha * pack_size + 4 * pack_size);           \
    d5 = GiLoadFloat16(patchT + i * alpha * pack_size + 5 * pack_size);           \
    d6 = GiLoadFloat16(patchT + i * alpha * pack_size + 6 * pack_size);           \
    auto t##i##0 = GiLoadFloat16(patchT + i * alpha * pack_size + 0 * pack_size); \
    auto t##i##7 = GiLoadFloat16(patchT + i * alpha * pack_size + 7 * pack_size); \
    auto t##i##1 = d6;                                                            \
    auto t##i##2 = d6;                                                            \
    auto t##i##3 = d6;                                                            \
    auto t##i##4 = d6;                                                            \
    auto t##i##5 = d6;                                                            \
    auto t##i##6 = d6;                                                            \
    t##i##0 = SUBF16(t##i##0, d6);                                                \
    t##i##1 = ADDF16(t##i##1, d1);                                                \
    t##i##2 = SUBF16(t##i##2, d1);                                                \
    t##i##3 = MADD(t##i##3, d1, v0, 2);                                           \
    t##i##4 = MSUB(t##i##4, d1, v0, 2);                                           \
    t##i##5 = MADD(t##i##5, d1, v0, 6);                                           \
    t##i##6 = MSUB(t##i##6, d1, v0, 6);                                           \
    t##i##7 = SUBF16(t##i##7, d1);                                                \
    t##i##0 = MSUB(t##i##0, d2, v0, 0);                                           \
    t##i##1 = ADDF16(t##i##1, d2);                                                \
    t##i##2 = ADDF16(t##i##2, d2);                                                \
    t##i##3 = MADD(t##i##3, d2, v0, 3);                                           \
    t##i##4 = MADD(t##i##4, d2, v0, 3);                                           \
    t##i##5 = MADD(t##i##5, d2, v0, 7);                                           \
    t##i##6 = MADD(t##i##6, d2, v0, 7);                                           \
    t##i##1 = MSUB(t##i##1, d3, v0, 1);                                           \
    t##i##2 = MADD(t##i##2, d3, v0, 1);                                           \
    t##i##3 = MSUB(t##i##3, d3, v0, 4);                                           \
    t##i##4 = MADD(t##i##4, d3, v0, 4);                                           \
    t##i##5 = MSUB(t##i##5, d3, v0, 4);                                           \
    t##i##6 = MADD(t##i##6, d3, v0, 4);                                           \
    t##i##7 = MADD(t##i##7, d3, v0, 0);                                           \
    t##i##0 = MADD(t##i##0, d4, v0, 0);                                           \
    t##i##1 = MSUB(t##i##1, d4, v0, 1);                                           \
    t##i##2 = MSUB(t##i##2, d4, v0, 1);                                           \
    t##i##3 = MSUB(t##i##3, d4, v0, 5);                                           \
    t##i##4 = MSUB(t##i##4, d4, v0, 5);                                           \
    t##i##5 = MSUB(t##i##5, d4, v1, 0);                                           \
    t##i##6 = MSUB(t##i##6, d4, v1, 0);                                           \
    t##i##1 = ADDF16(t##i##1, d5);                                                \
    t##i##2 = SUBF16(t##i##2, d5);                                                \
    t##i##3 = MADD(t##i##3, d5, v0, 6);                                           \
    t##i##4 = MSUB(t##i##4, d5, v0, 6);                                           \
    t##i##5 = MADD(t##i##5, d5, v0, 2);                                           \
    t##i##6 = MSUB(t##i##6, d5, v0, 2);                                           \
    t##i##7 = MSUB(t##i##7, d5, v0, 0);
        UNROLL_CALL_RAW(8, cb);
#undef cb
        //! second pass: per column i, multiply by B on the right and store
        //! each of the 8 results to its (row, col) plane of the buffer
#define cb(i)                                                                 \
    d0 = t0##i;                                                               \
    d1 = t6##i;                                                               \
    d2 = t6##i;                                                               \
    d3 = t6##i;                                                               \
    d4 = t6##i;                                                               \
    d5 = t6##i;                                                               \
    d6 = t6##i;                                                               \
    d7 = t7##i;                                                               \
    d0 = SUBF16(d0, t6##i);                                                   \
    d1 = ADDF16(d1, t1##i);                                                   \
    d2 = SUBF16(d2, t1##i);                                                   \
    d3 = MADD(d3, t1##i, v0, 2);                                              \
    d4 = MSUB(d4, t1##i, v0, 2);                                              \
    d5 = MADD(d5, t1##i, v0, 6);                                              \
    d6 = MSUB(d6, t1##i, v0, 6);                                              \
    d7 = SUBF16(d7, t1##i);                                                   \
    d0 = MSUB(d0, t2##i, v0, 0);                                              \
    d1 = ADDF16(d1, t2##i);                                                   \
    d2 = ADDF16(d2, t2##i);                                                   \
    d3 = MADD(d3, t2##i, v0, 3);                                              \
    d4 = MADD(d4, t2##i, v0, 3);                                              \
    d5 = MADD(d5, t2##i, v0, 7);                                              \
    d6 = MADD(d6, t2##i, v0, 7);                                              \
    d1 = MSUB(d1, t3##i, v0, 1);                                              \
    d2 = MADD(d2, t3##i, v0, 1);                                              \
    d3 = MSUB(d3, t3##i, v0, 4);                                              \
    d4 = MADD(d4, t3##i, v0, 4);                                              \
    d5 = MSUB(d5, t3##i, v0, 4);                                              \
    d6 = MADD(d6, t3##i, v0, 4);                                              \
    d7 = MADD(d7, t3##i, v0, 0);                                              \
    d0 = MADD(d0, t4##i, v0, 0);                                              \
    d1 = MSUB(d1, t4##i, v0, 1);                                              \
    d2 = MSUB(d2, t4##i, v0, 1);                                              \
    d3 = MSUB(d3, t4##i, v0, 5);                                              \
    d4 = MSUB(d4, t4##i, v0, 5);                                              \
    d5 = MSUB(d5, t4##i, v1, 0);                                              \
    d6 = MSUB(d6, t4##i, v1, 0);                                              \
    d1 = ADDF16(d1, t5##i);                                                   \
    d2 = SUBF16(d2, t5##i);                                                   \
    d3 = MADD(d3, t5##i, v0, 6);                                              \
    d4 = MSUB(d4, t5##i, v0, 6);                                              \
    d5 = MADD(d5, t5##i, v0, 2);                                              \
    d6 = MSUB(d6, t5##i, v0, 2);                                              \
    d7 = MSUB(d7, t5##i, v0, 0);                                              \
    GiStoreFloat16(                                                           \
            input_transform_buf +                                             \
                    (0 * alpha + i) * ICB * nr_units_in_tile * pack_size +    \
                    icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \
            d0);                                                              \
    GiStoreFloat16(                                                           \
            input_transform_buf +                                             \
                    (1 * alpha + i) * ICB * nr_units_in_tile * pack_size +    \
                    icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \
            d1);                                                              \
    GiStoreFloat16(                                                           \
            input_transform_buf +                                             \
                    (2 * alpha + i) * ICB * nr_units_in_tile * pack_size +    \
                    icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \
            d2);                                                              \
    GiStoreFloat16(                                                           \
            input_transform_buf +                                             \
                    (3 * alpha + i) * ICB * nr_units_in_tile * pack_size +    \
                    icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \
            d3);                                                              \
    GiStoreFloat16(                                                           \
            input_transform_buf +                                             \
                    (4 * alpha + i) * ICB * nr_units_in_tile * pack_size +    \
                    icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \
            d4);                                                              \
    GiStoreFloat16(                                                           \
            input_transform_buf +                                             \
                    (5 * alpha + i) * ICB * nr_units_in_tile * pack_size +    \
                    icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \
            d5);                                                              \
    GiStoreFloat16(                                                           \
            input_transform_buf +                                             \
                    (6 * alpha + i) * ICB * nr_units_in_tile * pack_size +    \
                    icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \
            d6);                                                              \
    GiStoreFloat16(                                                           \
            input_transform_buf +                                             \
                    (7 * alpha + i) * ICB * nr_units_in_tile * pack_size +    \
                    icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \
            d7);
        UNROLL_CALL_RAW(8, cb);
#undef cb
#undef MADD
#undef MSUB
    }
};
//! Applies the winograd output transform A^T * m * A for one NCHW88 output
//! channel block, then adds bias (per bmode) and applies the nonlinearity Op
//! before storing the 6x6 output block.
template <BiasMode bmode, typename Op>
struct OutputTransformF63_NCHW88 {
    //! \param output_transform_buf  gemm result, (alpha*alpha, OCB, units, 8)
    //! \param oh_start,ow_start     top-left corner of this 6x6 output block
    //! \param oc_index              channel offset relative to oc_start
    static void transform(
            const gi_float16_t* output_transform_buf, const gi_float16_t* bias,
            gi_float16_t* output, gi_float16_t* transform_mid_buf, size_t oh_start,
            size_t ow_start, size_t OH, size_t OW, size_t oc_start, size_t oc_end,
            size_t oc_index, size_t unit_idx, size_t nr_units_in_tile,
            const DType& src_dtype, const DType& dst_dtype) {
        MEGDNN_MARK_USED_VAR(transform_mid_buf);
        Op op(src_dtype, dst_dtype);
        //! AT * m * A
        size_t oc = oc_start + oc_index;
        size_t OCB = (oc_end - oc_start) / pack_size;
        size_t ocb = oc_index / pack_size;
        //! load the full 8x8 transformed tile for this unit/channel block
#define cb(m, n)                                                             \
    auto v##m##n = GiLoadFloat16(                                            \
            output_transform_buf +                                           \
            (m * alpha + n) * OCB * nr_units_in_tile * pack_size +           \
            ocb * nr_units_in_tile * pack_size + unit_idx * pack_size);
        UNROLL_CALL_NOWRAPPER_D2(8, 8, cb);
#undef cb
        /**
         * A
         *
         * 1    0      0       0        0         0
         * 1    1      1       1        1         1
         * 1   -1      1      -1        1        -1
         * 1    2      4       8       16        32
         * 1   -2      4      -8       16       -32
         * 1    0.5    0.25    0.125    0.0625    0.03125
         * 1   -0.5    0.25   -0.125    0.0625   -0.03125
         * 0    0      0       0        0         1
         */
        /*
         * v1addv2 = v1##m + v2##m;
         * v1subv2 = v1##m - v2##m;
         * v3addv4 = v3##m + v4##m;
         * v3subv4 = v3##m - v4##m;
         * v5addv6 = v5##m + v6##m;
         * v5subv6 = v5##m - v6##m;
         * auto t0##m = v0##m + v1addv2 + v3addv4 + v5addv6;
         * auto t1##m = v1subv2 + v3subv4 * 2.f + v5subv6 * 0.5f;
         * auto t2##m = v1addv2 + v3addv4 * 4.f + v5addv6 * 0.25f;
         * auto t3##m = v1subv2 + v3subv4 * 8.f + v5subv6 * 0.125f;
         * auto t4##m = v1addv2 + v3addv4 * 16.f + v5addv6 * 0.0625f;
         * auto t5##m = v1subv2 + v3subv4 * 32.f + v5subv6 * 0.03125f + v7##m;
         */
        GI_FLOAT16_t v1addv2, v1subv2, v3addv4, v3subv4, v5addv6, v5subv6;
        //! first pass: t = A^T * m (8 columns -> 6 rows)
#define cb(m)                                                                      \
    v1addv2 = ADDF16(v1##m, v2##m);                                                \
    v1subv2 = SUBF16(v1##m, v2##m);                                                \
    v3addv4 = ADDF16(v3##m, v4##m);                                                \
    v3subv4 = SUBF16(v3##m, v4##m);                                                \
    v5addv6 = ADDF16(v5##m, v6##m);                                                \
    v5subv6 = SUBF16(v5##m, v6##m);                                                \
    auto t0##m = ADDF16(ADDF16(ADDF16(v0##m, v1addv2), v3addv4), v5addv6);         \
    auto t1##m =                                                                   \
            ADDF16(ADDF16(v1subv2, MULSF16(v3subv4, 2.f)), MULSF16(v5subv6, 0.5f)); \
    auto t2##m =                                                                   \
            ADDF16(ADDF16(v1addv2, MULSF16(v3addv4, 4.f)), MULSF16(v5addv6, 0.25f)); \
    auto t3##m =                                                                   \
            ADDF16(ADDF16(v1subv2, MULSF16(v3subv4, 8.f)), MULSF16(v5subv6, 0.125f)); \
    auto t4##m = ADDF16(                                                           \
            ADDF16(v1addv2, MULSF16(v3addv4, 16.f)), MULSF16(v5addv6, 0.0625f));   \
    auto t5##m =                                                                   \
            ADDF16(ADDF16(ADDF16(v1subv2, MULSF16(v3subv4, 32.f)),                 \
                          MULSF16(v5subv6, 0.03125f)),                             \
                   v7##m);
        UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
        /*
         * v1addv2 = t##m##1 + t##m##2;
         * v1subv2 = t##m##1 - t##m##2;
         * v3addv4 = t##m##3 + t##m##4;
         * v3subv4 = t##m##3 - t##m##4;
         * v5addv6 = t##m##5 + t##m##6;
         * v5subv6 = t##m##5 - t##m##6;
         * v##m##0 = t##m##0 + v1addv2 + v3addv4 + v5addv6;
         * v##m##1 = v1subv2 + v3subv4 * 2.f + v5subv6 * 0.5f;
         * v##m##2 = v1addv2 + v3addv4 * 4.f + v5addv6 * 0.25f;
         * v##m##3 = v1subv2 + v3subv4 * 8.f + v5subv6 * 0.125f;
         * v##m##4 = v1addv2 + v3addv4 * 16.f + v5addv6 * 0.0625f;
         * v##m##5 = v1subv2 + v3subv4 * 32.f + v5subv6 * 0.03125f + t##m##7;
         */
        //! second pass: v = t * A (8 columns -> 6), reusing v##m##n storage
#define cb(m)                                                                        \
    v1addv2 = ADDF16(t##m##1, t##m##2);                                              \
    v1subv2 = SUBF16(t##m##1, t##m##2);                                              \
    v3addv4 = ADDF16(t##m##3, t##m##4);                                              \
    v3subv4 = SUBF16(t##m##3, t##m##4);                                              \
    v5addv6 = ADDF16(t##m##5, t##m##6);                                              \
    v5subv6 = SUBF16(t##m##5, t##m##6);                                              \
    v##m##0 = ADDF16(ADDF16(ADDF16(t##m##0, v1addv2), v3addv4), v5addv6);            \
    v##m##1 = ADDF16(ADDF16(v1subv2, MULSF16(v3subv4, 2.f)), MULSF16(v5subv6, 0.5f)); \
    v##m##2 = ADDF16(ADDF16(v1addv2, MULSF16(v3addv4, 4.f)), MULSF16(v5addv6, 0.25f)); \
    v##m##3 =                                                                        \
            ADDF16(ADDF16(v1subv2, MULSF16(v3subv4, 8.f)), MULSF16(v5subv6, 0.125f)); \
    v##m##4 = ADDF16(                                                                \
            ADDF16(v1addv2, MULSF16(v3addv4, 16.f)), MULSF16(v5addv6, 0.0625f));     \
    v##m##5 =                                                                        \
            ADDF16(ADDF16(ADDF16(v1subv2, MULSF16(v3subv4, 32.f)),                   \
                          MULSF16(v5subv6, 0.03125f)),                               \
                   t##m##7);
        UNROLL_CALL_NOWRAPPER(6, cb);
#undef cb
        GI_FLOAT16_t vbias;
        //! per-channel bias is broadcast over the whole 6x6 block
        if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) {
            vbias = GiLoadFloat16(bias + oc);
#define cb(m, n) v##m##n = ADDF16(v##m##n, vbias);
            UNROLL_CALL_RAW_D2(6, 6, cb);
#undef cb
        }
        //! for non-elementwise bias the nonlinearity is applied here; the
        //! BIAS mode defers it until after the per-pixel bias add below
        if (bmode != BiasMode::BIAS) {
#define cb(m, n) v##m##n = op(CONCAT(v##m, n));
            UNROLL_CALL_RAW_D2(6, 6, cb);
#undef cb
        }
        //! store each of the 6x6 results, skipping pixels past the border
#define out_save(oho, owo)                                                         \
    do {                                                                           \
        size_t oh = oh_start + oho;                                                \
        size_t ow = ow_start + owo;                                                \
        if (oh < OH && ow < OW) {                                                  \
            if (bmode == BiasMode::BIAS) {                                         \
                v##oho##owo = ADDF16(                                              \
                        v##oho##owo, GiLoadFloat16(                                \
                                             bias + oc * OH * OW +                 \
                                             oh * OW * pack_size + ow * pack_size)); \
                v##oho##owo = op(v##oho##owo);                                     \
            }                                                                      \
            GiStoreFloat16(                                                        \
                    output + oc * OH * OW + oh * OW * pack_size + ow * pack_size,  \
                    v##oho##owo);                                                  \
        }                                                                          \
    } while (0);
        UNROLL_CALL_RAW_D2(6, 6, out_save);
    }
#undef out_save
};
}
// namespace
namespace
megdnn
{
namespace
fallback
{
namespace
winograd
{
MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_F63_mk8_f16_nchw88)

//! Transforms the 3x3 filters into the 8x8 winograd domain, computing
//! G * g * G^T for every (output, input) channel pair in [oc_start, oc_end)
//! and storing the coefficients into filter_transform_buf.
void winograd_F63_mk8_f16_nchw88::filter(
        const dt_float16* filter, dt_float16* filter_transform_buf,
        dt_float16* transform_mid_buf, size_t OC, size_t IC, size_t oc_start,
        size_t oc_end) {
    constexpr size_t pack_size = 8;
    // Gg * GT
    // G
    // 1.0000000   0.0000000  0.0000000
    // -0.2222222 -0.2222222 -0.2222222
    // -0.2222222  0.2222222 -0.2222222
    // 0.0111111   0.0222222  0.0444444
    // 0.0111111  -0.0222222  0.0444444
    // 0.7111111   0.3555556  0.1777778
    // 0.7111111  -0.3555556  0.1777778
    // 0.0000000   0.0000000  1.0000000
    MEGDNN_MARK_USED_VAR(transform_mid_buf);
    megdnn_assert(
            (oc_end - oc_start) % pack_size == 0 && oc_start % pack_size == 0 &&
                    oc_end % pack_size == 0 && IC % pack_size == 0 &&
                    OC % pack_size == 0,
            "NCHW88 Winograd filter transform requires both OC and IC "
            "are times of 8");
    size_t ICB = IC / pack_size;
    for (size_t ocb = oc_start / pack_size; ocb < oc_end / pack_size; ocb++) {
        for (size_t icb = 0; icb < ICB; icb++) {
            for (size_t ic_inner = 0; ic_inner < pack_size; ic_inner++) {
                //! base of the (ocb, icb) 3x3x8x8 filter block, offset to
                //! the current input-channel lane
                const gi_float16_t* fptr =
                        reinterpret_cast<const gi_float16_t*>(filter) +
                        (ocb * ICB + icb) * KERNEL_SIZE * KERNEL_SIZE * pack_size *
                                pack_size +
                        ic_inner * pack_size;
                //! load the 3x3 kernel taps for this lane
#define cb(m, n)                                                                  \
    GI_FLOAT16_t g##m##n =                                                        \
            GiLoadFloat16(fptr + (m * KERNEL_SIZE + n) * pack_size * pack_size);
                UNROLL_CALL_NOWRAPPER_D2(3, 3, cb)
#undef cb
                /*
                 * auto wd##n##0 = g##0##n;
                 * tmp0 = (g##0##n + g##2##n) * -0.2222222f;
                 * tmp1 = g##1##n * -0.2222222f;
                 * auto wd##n##1 = tmp0 + tmp1;
                 * auto wd##n##2 = tmp0 - tmp1;
                 * tmp0 = g##0##n * 0.0111111f + g##2##n * 0.0444444f;
                 * tmp1 = g##1##n * 0.0222222f;
                 * auto wd##n##3 = tmp0 + tmp1;
                 * auto wd##n##4 = tmp0 - tmp1;
                 * tmp0 = g##0##n * 0.7111111f + g##2##n * 0.1777778f;
                 * tmp1 = g##1##n * 0.3555556f;
                 * auto wd##n##5 = tmp0 + tmp1;
                 * auto wd##n##6 = tmp0 - tmp1;
                 * auto wd##n##7 = g##2##n;
                 */
                //! one application of G; used twice: wd = G*g, ret = wd*G^T
#define FILTER_TRANSFORM(n, wd, g)                                              \
    auto wd##n##0 = g##0##n;                                                    \
    tmp0 = MULSF16(ADDF16(g##0##n, g##2##n), -2.0f / 9);                        \
    tmp1 = MULSF16(g##1##n, -2.0f / 9);                                         \
    auto wd##n##1 = ADDF16(tmp0, tmp1);                                         \
    auto wd##n##2 = SUBF16(tmp0, tmp1);                                         \
    tmp0 = ADDF16(MULSF16(g##0##n, 1.0f / 90), MULSF16(g##2##n, 2.0f / 45));    \
    tmp1 = MULSF16(g##1##n, 1.0f / 45);                                         \
    auto wd##n##3 = ADDF16(tmp0, tmp1);                                         \
    auto wd##n##4 = SUBF16(tmp0, tmp1);                                         \
    tmp0 = ADDF16(MULSF16(g##0##n, 0.7111111f), MULSF16(g##2##n, 0.1777778f));  \
    tmp1 = MULSF16(g##1##n, 0.3555556f);                                        \
    auto wd##n##5 = ADDF16(tmp0, tmp1);                                         \
    auto wd##n##6 = SUBF16(tmp0, tmp1);                                         \
    auto wd##n##7 = g##2##n;
                GI_FLOAT16_t tmp0, tmp1;
                UNROLL_CALL_RAW(3, FILTER_TRANSFORM, wd, g);
                UNROLL_CALL_RAW(8, FILTER_TRANSFORM, ret, wd);
#undef FILTER_TRANSFORM
                //! layout of the transformed filter:
                //! (alpha, alpha, OCB, ICB, pack, pack)
#define cb_save(m, n)                                                           \
    GiStoreFloat16(                                                             \
            reinterpret_cast<gi_float16_t*>(filter_transform_buf) +             \
                    (m * alpha + n) * OC * IC + ocb * IC * pack_size +          \
                    icb * pack_size * pack_size + ic_inner * pack_size,         \
            ret##m##n);
                UNROLL_CALL_NOWRAPPER_D2(8, 8, cb_save)
#undef cb_save
            }
        }
    }
}
//! Transforms the padded input into the winograd domain for the units in
//! [unit_start_idx, unit_start_idx + nr_units_in_tile), one NCHW88 channel
//! block at a time; transform_mid_buf provides scratch for one patch.
void winograd_F63_mk8_f16_nchw88::input(
        const dt_float16* input, dt_float16* input_transform_buf,
        dt_float16* transform_mid_buf, size_t IH, size_t IW, size_t IC, size_t PH,
        size_t PW, size_t unit_start_idx, size_t nr_units_in_tile) {
    constexpr size_t pack_size = 8;
    megdnn_assert(IC % pack_size == 0);
    constexpr int alpha = 3 + 6 - 1;

    // OW = IW + 2 * PW - KERNEL_SIZE + 1
    auto units_w = div_ceil<size_t>(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE);
    gi_float16_t* patch = reinterpret_cast<gi_float16_t*>(transform_mid_buf);
    gi_float16_t* patchT = reinterpret_cast<gi_float16_t*>(transform_mid_buf) +
                           pack_size * alpha * alpha;

    for (size_t ic = 0; ic < IC; ic += pack_size) {
        rep(unit_idx, nr_units_in_tile) {
            size_t index = unit_start_idx + unit_idx;
            size_t nh = index / units_w;
            size_t nw = index % units_w;
            //! top-left corner of the patch; negative when inside the padding
            int ih_start = nh * OUTPUT_BLOCK_SIZE - PH;
            int iw_start = nw * OUTPUT_BLOCK_SIZE - PW;
            if (ih_start >= 0 && ih_start + alpha <= static_cast<int>(IH) &&
                iw_start >= 0 && iw_start + alpha <= static_cast<int>(IW)) {
                //! patch fully inside the image: fast contiguous copy
                InputTransformF63_NCHW88::prepare<true>(
                        reinterpret_cast<const gi_float16_t*>(input), patch, patchT,
                        ih_start, iw_start, IH, IW, ic, IC);
                InputTransformF63_NCHW88::transform(
                        patchT,
                        reinterpret_cast<gi_float16_t*>(input_transform_buf),
                        unit_idx, nr_units_in_tile, ic, IC);
            } else {
                //! patch touches the border: zero-fill plus partial copy
                InputTransformF63_NCHW88::prepare<false>(
                        reinterpret_cast<const gi_float16_t*>(input), patch, patchT,
                        ih_start, iw_start, IH, IW, ic, IC);
                InputTransformF63_NCHW88::transform(
                        patchT,
                        reinterpret_cast<gi_float16_t*>(input_transform_buf),
                        unit_idx, nr_units_in_tile, ic, IC);
            }
        }
    }
}
//! Applies the output transform for every output channel block and unit in
//! the tile, dispatching over bias mode and nonlinearity via the shared
//! GI_DISPATCH_CONV_WINOGRAD_BIAS helper.
void winograd_F63_mk8_f16_nchw88::output(
        const dt_float16* output_transform_buf, const dt_float16* bias,
        dt_float16* output, dt_float16* transform_mid_buf, BiasMode bmode,
        NonlineMode nonline_mode, size_t OH, size_t OW, size_t oc_start,
        size_t oc_end, size_t unit_start_idx, size_t nr_units_in_tile) {
//! body instantiated once per (bias mode, nonlinearity) combination by the
//! dispatch macro below
#define cb(_bmode, _nonline_op, ...)                                               \
    for (size_t oc = oc_start; oc < oc_end; oc += pack_size) {                     \
        size_t oc_index = oc - oc_start;                                           \
        rep(unit_idx, nr_units_in_tile) {                                          \
            size_t index = unit_start_idx + unit_idx;                              \
            auto nh = index / units_w;                                             \
            auto nw = index % units_w;                                             \
            size_t oh_start = nh * OUTPUT_BLOCK_SIZE;                              \
            size_t ow_start = nw * OUTPUT_BLOCK_SIZE;                              \
            OutputTransformF63_NCHW88<_bmode MEGDNN_COMMA _nonline_op>::transform( \
                    reinterpret_cast<const gi_float16_t*>(output_transform_buf),   \
                    reinterpret_cast<const gi_float16_t*>(bias),                   \
                    reinterpret_cast<gi_float16_t*>(output),                       \
                    reinterpret_cast<gi_float16_t*>(transform_mid_buf), oh_start,  \
                    ow_start, OH, OW, oc_start, oc_end, oc_index, unit_idx,        \
                    nr_units_in_tile, src_dtype, dst_dtype);                       \
        }                                                                          \
    }

    auto units_w = div_ceil<size_t>(OW, OUTPUT_BLOCK_SIZE);
    constexpr size_t pack_size = 8;
    size_t OC = oc_end - oc_start;
    megdnn_assert(
            OC % pack_size == 0 && oc_start % pack_size == 0 &&
                    oc_end % pack_size == 0,
            "NCHW88 Winograd filter transform requires OC is times of 8");

    GI_DISPATCH_CONV_WINOGRAD_BIAS(
            megdnn_fallback_winograd_fp16_F63_mk8, cb, gi_float16_t, gi_float16_t,
            bmode, nonline_mode);
#undef cb
}
}
// namespace winograd
}
// namespace fallback
}
// namespace megdnn
#endif
// vim: syntax=cpp.doxygen
dnn/src/fallback/conv_bias/opr_impl.cpp
浏览文件 @
63032170
...
...
@@ -204,6 +204,11 @@ public:
static_cast
<
fallback
::
MatrixMulImpl
::
AlgoBase
*>
(
algo
),
tile_size
));
m_gi_winograd_algos
.
emplace_back
(
refhold
.
back
().
get
());
refhold
.
emplace_back
(
new
AlgoFP16WinogradF63_8x8_NCHW88
(
static_cast
<
fallback
::
MatrixMulImpl
::
AlgoBase
*>
(
algo
),
tile_size
));
m_gi_winograd_algos
.
emplace_back
(
refhold
.
back
().
get
());
}
}
#endif
...
...
dnn/src/fallback/conv_bias/opr_impl.h
浏览文件 @
63032170
...
...
@@ -228,6 +228,7 @@ public:
GI_COMMON_WINOGRAD_F73_4X4_NCHW44_F32
,
GI_COMMON_WINOGRAD_F23_8X8_NCHW88_F16
,
GI_COMMON_WINOGRAD_F43_8X8_NCHW88_F16
,
GI_COMMON_WINOGRAD_F63_8X8_NCHW88_F16
,
GI_COMMON_DIRECT_FP32
,
GI_COMMON_DIRECT_STRD1_FP32
,
GI_COMMON_DIRECT_STRD2_FP32
,
...
...
@@ -397,6 +398,7 @@ private:
class
AlgoFP16WinogradF23_8x8_NCHW88
;
class
AlgoFP16WinogradF43_8x8_NCHW88
;
class
AlgoFP16WinogradF63_8x8_NCHW88
;
class
AlgoF32Direct
;
class
AlgoF32DirectStride1
;
...
...
dnn/test/fallback/conv_bias.cpp
浏览文件 @
63032170
...
...
@@ -634,6 +634,19 @@ TEST_F(FALLBACK_MULTI_THREADS, CONVBIAS_GI_WINOGRAD_F43_8_NCHW88_FP16) {
"8:4:"
,
checker
,
args
,
&
rng
,
0.006
,
param
::
MatrixMul
::
Format
::
MK8
,
"WINOGRAD_NCHW88"
);
}
//! Correctness test of the F(6x6, 3x3) fp16 NCHW88 winograd algorithm
//! ("8:6:" = pack 8, output block 6) against the reference implementation,
//! over all nonlinearities and broadcast/no-bias modes.
TEST_F(FALLBACK_MULTI_THREADS, CONVBIAS_GI_WINOGRAD_F63_8_NCHW88_FP16) {
    using namespace conv_bias;
    std::vector<TestArg> args =
            get_nchw88_conv_bias_args({3}, FULL_NLMODE, BR_AND_NO_BIASMODE, 1);
    Checker<ConvBiasForward, OprWeightPreprocessProxy<ConvBiasForward>> checker(
            handle());
    //! periodical RNG keeps fp16 accumulation error bounded; 0x3c00 == 1.0h
    Float16PeriodicalRNG rng(0x3c00);
    //! epsilon 0.019 accounts for the larger F63 transform error in fp16
    check_winograd_fp16(
            "8:6:", checker, args, &rng, 0.019, param::MatrixMul::Format::MK8,
            "WINOGRAD_NCHW88");
}
#endif
TEST_F
(
FALLBACK_MULTI_THREADS
,
CONVBIAS_GI_WINOGRAD_F43_4_WEIGHT_PREPROCESS
)
{
...
...
@@ -1407,7 +1420,7 @@ TEST_F(FALLBACK, BENCHMARK_GI_WINOGRAD_F23_FP32_NCHW44_VS_FP16_NCHW88) {
std
::
string
algo_name_fp32
=
"WINOGRAD_NCHW44:FB_GI_F32_MK4_4x8:4:2"
;
benchmark_with_contrast
(
args_with_computation_fp16
,
algo_name_fp16
,
data_type_fp16
,
args_with_computation_fp32
,
algo_name_fp32
,
data_type_fp32
,
10
,
{
1
,
{
0
}});
args_with_computation_fp32
,
algo_name_fp32
,
data_type_fp32
,
10
,
{
1
,
{
4
}});
}
TEST_F
(
FALLBACK
,
BENCHMARK_GI_WINOGRAD_F43_FP32_NCHW44_VS_FP16_NCHW88
)
{
...
...
@@ -1447,6 +1460,44 @@ TEST_F(FALLBACK, BENCHMARK_GI_WINOGRAD_F43_FP32_NCHW44_VS_FP16_NCHW88) {
args_with_computation_fp16
,
algo_name_fp16
,
data_type_fp16
,
args_with_computation_fp32
,
algo_name_fp32
,
data_type_fp32
,
10
,
{
1
,
{
0
}});
}
//! Benchmark contrasting the fp16 NCHW88 F63 winograd algorithm against its
//! fp32 NCHW44 counterpart on the same convolution workloads.
TEST_F(FALLBACK, BENCHMARK_GI_WINOGRAD_F63_FP32_NCHW44_VS_FP16_NCHW88) {
    auto&& args_fp16 = conv_bias::get_winograd_benchmark_args(3, 8, 8);
    auto&& args_fp32 = conv_bias::get_winograd_benchmark_args(3, 4, 4);
    //! theoretical GFLOP count of the direct convolution (used to report
    //! throughput): 2 * dst_elems * IC * FH * FW, scaled to Gflops * ms
    auto cal_computation = [](const conv_bias::TestArg& arg) {
        TensorShape dst_shape{
                arg.src[0], arg.filter[0],
                (arg.src[2] + arg.param.pad_h * 2 - arg.filter[2]) /
                                arg.param.stride_h +
                        1,
                (arg.src[3] + arg.param.pad_w * 2 - arg.filter[3]) /
                                arg.param.stride_w +
                        1,
                arg.filter[5]};
        return dst_shape.total_nr_elems() * arg.filter[1] * arg.filter[2] *
               arg.filter[3] * arg.filter[4] * 2.0 / (1024 * 1024 * 1024) * 1e3;
    };
    std::vector<std::pair<conv_bias::TestArg, float>> args_with_computation_fp16,
            args_with_computation_fp32;
    for (const auto& arg : args_fp16) {
        args_with_computation_fp16.emplace_back(arg, cal_computation(arg));
    }
    for (const auto& arg : args_fp32) {
        args_with_computation_fp32.emplace_back(arg, cal_computation(arg));
    }

    std::vector<DType> data_type_fp16 = {
            dtype::Float16(), dtype::Float16(), dtype::Float16(), dtype::Float16()};
    std::vector<DType> data_type_fp32 = {
            dtype::Float32(), dtype::Float32(), dtype::Float32(), dtype::Float32()};
    //! "8:6" / "4:6" select the F(6x6, 3x3) variants of each algorithm
    std::string algo_name_fp16 = "WINOGRAD_NCHW88:FB_GI_F16_MK8_8x8:8:6";
    std::string algo_name_fp32 = "WINOGRAD_NCHW44:FB_GI_F32_MK4_4x8:4:6";

    benchmark_with_contrast(
            args_with_computation_fp16, algo_name_fp16, data_type_fp16,
            args_with_computation_fp32, algo_name_fp32, data_type_fp32, 10,
            {1, {4}});
}
#endif
#endif
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录