Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
wjd2002
Ncnn
比较版本
b8cf8cb73e7be5fc8be4b558d53efefd9b99b8a8...4a78b6d457c82e6d513792d4ab8020369c223d5d
N
Ncnn
项目概览
wjd2002
/
Ncnn
9 个月 前同步成功
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
N
Ncnn
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
源分支
4a78b6d457c82e6d513792d4ab8020369c223d5d
选择Git版本
...
目标分支
b8cf8cb73e7be5fc8be4b558d53efefd9b99b8a8
选择Git版本
比较
Commits (2)
https://gitcode.net/wjd2002/ncnn/-/commit/e112461d3019f8714f91fcba52e6c7ec4bd20172
write shape, fuse sam image encoder attention (#4792)
2023-06-12T11:43:21+08:00
nihui
nihuini@tencent.com
* write shape, fuse sam image encoder attention * set more dynamic shape as static * less warning for constant tensor node
https://gitcode.net/wjd2002/ncnn/-/commit/4a78b6d457c82e6d513792d4ab8020369c223d5d
Update HUAWEI KunPeng 920 platform (#4795)
2023-06-12T16:48:11+08:00
Zhang Geng
mobtgzhang@outlook.com
隐藏空白更改
内联
并排
Showing
9 changed files
with
354 additions
and
109 deletions
+354
-109
benchmark/README.md
benchmark/README.md
+130
-1
tools/pnnx/src/ir.cpp
tools/pnnx/src/ir.cpp
+1
-3
tools/pnnx/src/pass_level2.cpp
tools/pnnx/src/pass_level2.cpp
+44
-0
tools/pnnx/src/pass_level3/fuse_expression.cpp
tools/pnnx/src/pass_level3/fuse_expression.cpp
+2
-1
tools/pnnx/src/pass_level5/eliminate_reshape_shape_expression.cpp
...nx/src/pass_level5/eliminate_reshape_shape_expression.cpp
+69
-12
tools/pnnx/src/pass_level5/fuse_channel_shuffle.cpp
tools/pnnx/src/pass_level5/fuse_channel_shuffle.cpp
+1
-1
tools/pnnx/src/pass_level5/fuse_multiheadattention.cpp
tools/pnnx/src/pass_level5/fuse_multiheadattention.cpp
+11
-11
tools/pnnx/src/pass_level5/fuse_scaled_dot_product_attention.cpp
...nnx/src/pass_level5/fuse_scaled_dot_product_attention.cpp
+60
-0
tools/pnnx/src/pass_ncnn/fuse_convert_shufflechannel_slice.cpp
.../pnnx/src/pass_ncnn/fuse_convert_shufflechannel_slice.cpp
+36
-80
未找到文件。
benchmark/README.md
浏览文件 @
4a78b6d4
...
...
@@ -4095,6 +4095,135 @@ cooling_down = 0
vision_transformer min = 1617.60 max = 1634.13 avg = 1625.87
FastestDet min = 10.19 max = 10.55 avg = 10.36
```
### HUAWEI KunPeng 920 2251K (x8 cores)
test on UOS 1050
```
mobtgzhang@mobtgzhang-PC:~/ncnn/benchmark$ ./benchncnn 10 1 0 -1 0
loop_count = 10
num_threads = 1
powersave = 0
gpu_device = -1
cooling_down = 0
squeezenet min = 12.11 max = 12.40 avg = 12.25
squeezenet_int8 min = 14.24 max = 14.50 avg = 14.36
mobilenet min = 20.52 max = 21.11 avg = 20.63
mobilenet_int8 min = 18.29 max = 18.63 avg = 18.45
mobilenet_v2 min = 13.73 max = 13.90 avg = 13.79
mobilenet_v3 min = 11.37 max = 11.49 avg = 11.41
shufflenet min = 7.90 max = 7.96 avg = 7.92
shufflenet_v2 min = 8.09 max = 8.13 avg = 8.11
mnasnet min = 13.26 max = 13.44 avg = 13.30
proxylessnasnet min = 16.19 max = 16.39 avg = 16.26
efficientnet_b0 min = 34.92 max = 35.22 avg = 35.04
efficientnetv2_b0 min = 43.82 max = 44.39 avg = 43.94
regnety_400m min = 17.55 max = 18.02 avg = 17.65
blazeface min = 3.05 max = 3.08 avg = 3.07
googlenet min = 58.65 max = 59.26 avg = 58.89
googlenet_int8 min = 60.55 max = 63.00 avg = 61.96
resnet18 min = 34.27 max = 35.43 avg = 34.84
resnet18_int8 min = 60.79 max = 62.15 avg = 61.47
alexnet min = 42.01 max = 44.43 avg = 43.36
vgg16 min = 174.46 max = 177.33 avg = 175.57
vgg16_int8 min = 453.93 max = 457.03 avg = 454.79
resnet50 min = 95.36 max = 96.27 avg = 95.55
resnet50_int8 min = 119.77 max = 121.26 avg = 120.46
squeezenet_ssd min = 39.05 max = 39.69 avg = 39.20
squeezenet_ssd_int8 min = 55.06 max = 56.23 avg = 55.72
mobilenet_ssd min = 45.20 max = 45.96 avg = 45.49
mobilenet_ssd_int8 min = 39.40 max = 40.13 avg = 39.76
mobilenet_yolo min = 98.86 max = 99.85 avg = 99.34
mobilenetv2_yolov3 min = 51.17 max = 52.89 avg = 51.89
yolov4-tiny min = 66.43 max = 67.23 avg = 66.70
nanodet_m min = 20.59 max = 20.79 avg = 20.71
yolo-fastest-1.1 min = 7.90 max = 7.99 avg = 7.93
yolo-fastestv2 min = 7.45 max = 7.49 avg = 7.47
vision_transformer min = 1586.33 max = 1595.34 avg = 1589.76
FastestDet min = 7.45 max = 7.52 avg = 7.47
mobtgzhang@mobtgzhang-PC:~/ncnn/benchmark$ ./benchncnn 10 8 0 -1 0
loop_count = 10
num_threads = 8
powersave = 0
gpu_device = -1
cooling_down = 0
squeezenet min = 2.93 max = 3.10 avg = 3.00
squeezenet_int8 min = 3.47 max = 3.56 avg = 3.52
mobilenet min = 3.89 max = 4.04 avg = 3.94
mobilenet_int8 min = 3.29 max = 3.39 avg = 3.33
mobilenet_v2 min = 3.95 max = 4.08 avg = 3.98
mobilenet_v3 min = 3.45 max = 3.59 avg = 3.49
shufflenet min = 3.42 max = 4.66 avg = 3.62
shufflenet_v2 min = 2.60 max = 2.94 avg = 2.68
mnasnet min = 3.46 max = 3.57 avg = 3.52
proxylessnasnet min = 3.94 max = 12.34 avg = 4.88
efficientnet_b0 min = 7.31 max = 7.60 avg = 7.38
efficientnetv2_b0 min = 9.01 max = 9.22 avg = 9.08
regnety_400m min = 8.56 max = 9.36 avg = 8.70
blazeface min = 1.36 max = 3.52 avg = 1.60
googlenet min = 11.80 max = 12.02 avg = 11.93
googlenet_int8 min = 11.87 max = 23.09 avg = 13.16
resnet18 min = 7.27 max = 7.64 avg = 7.38
resnet18_int8 min = 11.02 max = 11.73 avg = 11.20
alexnet min = 9.05 max = 9.35 avg = 9.17
vgg16 min = 44.13 max = 50.84 avg = 46.89
vgg16_int8 min = 75.15 max = 80.73 avg = 77.52
resnet50 min = 18.72 max = 27.49 avg = 19.96
resnet50_int8 min = 22.72 max = 36.80 avg = 26.78
squeezenet_ssd min = 13.96 max = 27.42 avg = 15.62
squeezenet_ssd_int8 min = 15.01 max = 29.53 avg = 19.51
mobilenet_ssd min = 9.37 max = 13.34 avg = 10.44
mobilenet_ssd_int8 min = 8.07 max = 24.28 avg = 9.83
mobilenet_yolo min = 22.06 max = 24.89 avg = 22.91
mobilenetv2_yolov3 min = 14.41 max = 15.97 avg = 14.78
yolov4-tiny min = 20.71 max = 23.96 avg = 21.42
nanodet_m min = 6.37 max = 6.59 avg = 6.45
yolo-fastest-1.1 min = 4.27 max = 4.52 avg = 4.34
yolo-fastestv2 min = 3.53 max = 3.63 avg = 3.58
vision_transformer min = 435.60 max = 523.43 avg = 479.70
FastestDet min = 3.54 max = 7.95 avg = 5.24
mobtgzhang@mobtgzhang-PC:~/ncnn/benchmark$ ./benchncnn 10 4 2 -1 0
loop_count = 10
num_threads = 4
powersave = 2
gpu_device = -1
cooling_down = 0
squeezenet min = 4.04 max = 4.22 avg = 4.09
squeezenet_int8 min = 4.64 max = 4.76 avg = 4.69
mobilenet min = 6.04 max = 6.06 avg = 6.05
mobilenet_int8 min = 5.23 max = 5.32 avg = 5.25
mobilenet_v2 min = 5.00 max = 5.03 avg = 5.01
mobilenet_v3 min = 4.49 max = 4.69 avg = 4.52
shufflenet min = 3.90 max = 3.94 avg = 3.91
shufflenet_v2 min = 3.27 max = 3.48 avg = 3.33
mnasnet min = 4.80 max = 4.83 avg = 4.82
proxylessnasnet min = 5.20 max = 5.28 avg = 5.23
efficientnet_b0 min = 10.53 max = 11.06 avg = 10.68
efficientnetv2_b0 min = 13.18 max = 13.37 avg = 13.25
regnety_400m min = 9.20 max = 9.25 avg = 9.22
blazeface min = 1.43 max = 1.45 avg = 1.44
googlenet min = 17.63 max = 17.78 avg = 17.71
googlenet_int8 min = 17.63 max = 18.03 avg = 17.85
resnet18 min = 10.34 max = 10.59 avg = 10.40
resnet18_int8 min = 17.93 max = 18.84 avg = 18.25
alexnet min = 13.28 max = 13.37 avg = 13.31
vgg16 min = 55.41 max = 56.60 avg = 55.70
vgg16_int8 min = 123.71 max = 125.34 avg = 124.48
resnet50 min = 27.82 max = 28.22 avg = 27.95
resnet50_int8 min = 34.50 max = 34.89 avg = 34.70
squeezenet_ssd min = 14.67 max = 15.19 avg = 14.85
squeezenet_ssd_int8 min = 19.76 max = 20.32 avg = 19.87
mobilenet_ssd min = 13.15 max = 13.38 avg = 13.21
mobilenet_ssd_int8 min = 11.52 max = 11.70 avg = 11.60
mobilenet_yolo min = 30.95 max = 31.28 avg = 31.05
mobilenetv2_yolov3 min = 20.04 max = 20.36 avg = 20.16
yolov4-tiny min = 25.61 max = 26.73 avg = 25.80
nanodet_m min = 7.93 max = 7.97 avg = 7.95
yolo-fastest-1.1 min = 4.52 max = 4.59 avg = 4.53
yolo-fastestv2 min = 3.74 max = 3.88 avg = 3.77
vision_transformer min = 546.94 max = 726.81 avg = 698.27
FastestDet min = 3.59 max = 3.61 avg = 3.60
```
### Intel Atom x5-Z8350
```
...
...
@@ -5738,4 +5867,4 @@ cooling_down = 0
yolo-fastestv2 min = 5.68 max = 7.20 avg = 5.88
vision_transformer min = 600.83 max = 666.35 avg = 617.33
FastestDet min = 6.05 max = 6.72 avg = 6.23
```
\ No newline at end of file
```
tools/pnnx/src/ir.cpp
浏览文件 @
4a78b6d4
...
...
@@ -269,10 +269,8 @@ Parameter::Parameter(const torch::jit::Node* value_node)
}
else
{
const
int
ndim
=
(
int
)
t
.
dim
();
// constant tensor will become pnnx attribute node later
type
=
8
;
fprintf
(
stderr
,
"unknown Parameter value kind %s of TensorType, t.dim = %d
\n
"
,
value_node
->
kind
().
toDisplayString
(),
ndim
);
}
break
;
...
...
tools/pnnx/src/pass_level2.cpp
浏览文件 @
4a78b6d4
...
...
@@ -100,6 +100,50 @@ void GraphRewriterPass::write(Operator* op, const std::map<std::string, Paramete
op
->
params
[
x
.
first
]
=
Parameter
::
parse_from_string
(
str
);
}
for
(
size_t
i
=
0
;
i
<
op
->
inputs
.
size
();
i
++
)
{
Operand
*
operand
=
op
->
inputs
[
i
];
std
::
vector
<
int
>&
shape
=
operand
->
shape
;
for
(
size_t
j
=
0
;
j
<
shape
.
size
();
j
++
)
{
int
ai
=
shape
[
j
];
if
(
ai
==
-
233
)
{
std
::
string
key
=
operand
->
params
.
at
(
std
::
string
(
"__shape_"
)
+
std
::
to_string
(
j
)).
s
;
if
(
captured_params
.
find
(
key
)
==
captured_params
.
end
())
{
fprintf
(
stderr
,
"replace pattern param %%%s missing captured
\n
"
,
key
.
c_str
());
return
;
}
shape
[
j
]
=
captured_params
.
at
(
key
).
i
;
}
}
}
for
(
size_t
i
=
0
;
i
<
op
->
outputs
.
size
();
i
++
)
{
Operand
*
operand
=
op
->
outputs
[
i
];
std
::
vector
<
int
>&
shape
=
operand
->
shape
;
for
(
size_t
j
=
0
;
j
<
shape
.
size
();
j
++
)
{
int
ai
=
shape
[
j
];
if
(
ai
==
-
233
)
{
std
::
string
key
=
operand
->
params
.
at
(
std
::
string
(
"__shape_"
)
+
std
::
to_string
(
j
)).
s
;
if
(
captured_params
.
find
(
key
)
==
captured_params
.
end
())
{
fprintf
(
stderr
,
"replace pattern param %%%s missing captured
\n
"
,
key
.
c_str
());
return
;
}
shape
[
j
]
=
captured_params
.
at
(
key
).
i
;
}
}
}
}
void
GraphRewriterPass
::
write
(
Operator
*
op
,
const
std
::
map
<
std
::
string
,
Parameter
>&
captured_params
,
const
std
::
map
<
std
::
string
,
Attribute
>&
captured_attrs
)
const
...
...
tools/pnnx/src/pass_level3/fuse_expression.cpp
浏览文件 @
4a78b6d4
...
...
@@ -100,7 +100,8 @@ static bool operand_maybe_tensor(const Operand* operand)
||
op
->
type
==
"aten::div"
||
op
->
type
==
"aten::floor_divide"
||
op
->
type
==
"aten::mul"
||
op
->
type
==
"aten::pow"
)
||
op
->
type
==
"aten::pow"
||
op
->
type
==
"aten::remainder"
)
{
return
operand_maybe_tensor
(
op
->
inputs
[
0
])
||
operand_maybe_tensor
(
op
->
inputs
[
1
]);
}
...
...
tools/pnnx/src/pass_level5/eliminate_reshape_shape_expression.cpp
浏览文件 @
4a78b6d4
...
...
@@ -31,13 +31,12 @@ static bool token_is_interger_literal(const std::string& t)
return
iss
.
eof
()
&&
!
iss
.
fail
();
}
static
std
::
vector
<
int
>
build_shape
(
const
std
::
string
&
expr
)
static
void
build_shape
(
const
std
::
string
&
expr
,
std
::
vector
<
int
>&
shape
,
std
::
vector
<
std
::
string
>&
expr_tokens
)
{
std
::
string
listexpr
=
expr
.
substr
(
1
,
expr
.
size
()
-
2
);
std
::
vector
<
int
>
shape
;
std
::
string
t
;
std
::
string
et
;
int
level
=
0
;
for
(
size_t
i
=
0
;
i
<
listexpr
.
size
();
i
++
)
{
...
...
@@ -47,21 +46,26 @@ static std::vector<int> build_shape(const std::string& expr)
{
level
+=
1
;
t
=
"-1"
;
et
+=
ch
;
}
else
if
(
ch
==
')'
||
ch
==
']'
)
{
level
-=
1
;
t
=
"-1"
;
et
+=
ch
;
}
else
if
(
level
==
0
&&
ch
==
','
)
{
int
dimsize
=
token_is_interger_literal
(
t
)
?
std
::
stoi
(
t
)
:
-
1
;
shape
.
push_back
(
dimsize
);
expr_tokens
.
push_back
(
et
);
t
.
clear
();
et
.
clear
();
}
else
{
t
+=
ch
;
et
+=
ch
;
}
}
...
...
@@ -71,7 +75,26 @@ static std::vector<int> build_shape(const std::string& expr)
shape
.
push_back
(
dimsize
);
}
return
shape
;
if
(
level
==
0
&&
!
et
.
empty
())
{
expr_tokens
.
push_back
(
et
);
}
}
// Joins the given expression tokens into one bracketed list expression:
// '[' + tokens separated by ',' + ']'.
// For example {"1", "size(@0,1)", "-1"} becomes "[1,size(@0,1),-1]".
// An empty token list produces "[]".
static std::string build_expr(const std::vector<std::string>& expr_tokens)
{
    std::string joined = "[";

    bool is_first = true;
    for (const std::string& token : expr_tokens)
    {
        // the separator goes between tokens only, never after the last one
        if (!is_first)
            joined += ',';

        joined += token;
        is_first = false;
    }

    joined += ']';
    return joined;
}
void
eliminate_reshape_shape_expression
(
Graph
&
graph
)
...
...
@@ -98,18 +121,21 @@ void eliminate_reshape_shape_expression(Graph& graph)
if
(
expr
.
empty
()
||
expr
[
0
]
!=
'['
)
continue
;
std
::
vector
<
int
>
shape
=
build_shape
(
expr
);
std
::
vector
<
int
>
outshape
=
op
->
outputs
[
0
]
->
shape
;
if
(
outshape
.
empty
())
continue
;
std
::
vector
<
int
>
shape
;
std
::
vector
<
std
::
string
>
expr_tokens
;
build_shape
(
expr
,
shape
,
expr_tokens
);
// replace -1 with static dim-size
std
::
vector
<
int
>
outshape
=
op
->
outputs
[
0
]
->
shape
;
if
(
!
outshape
.
empty
())
for
(
size_t
j
=
0
;
j
<
outshape
.
size
();
j
++
)
{
for
(
size_t
j
=
0
;
j
<
outshape
.
size
();
j
++
)
if
(
outshape
[
j
]
!=
-
1
)
{
if
(
outshape
[
j
]
!=
-
1
)
{
shape
[
j
]
=
outshape
[
j
];
}
shape
[
j
]
=
outshape
[
j
];
expr_tokens
[
j
]
=
std
::
to_string
(
outshape
[
j
]);
}
}
...
...
@@ -124,7 +150,10 @@ void eliminate_reshape_shape_expression(Graph& graph)
}
if
(
dynamic_dim_count
>
1
)
{
op_expr
->
params
[
"expr"
]
=
build_expr
(
expr_tokens
);
continue
;
}
matched
=
true
;
...
...
@@ -156,6 +185,34 @@ void eliminate_reshape_shape_expression(Graph& graph)
if
(
!
matched
)
break
;
}
for
(
size_t
i
=
0
;
i
<
graph
.
ops
.
size
();
i
++
)
{
Operator
*
op
=
graph
.
ops
[
i
];
if
(
op
->
type
!=
"Tensor.view"
&&
op
->
type
!=
"Tensor.reshape"
)
continue
;
if
(
op
->
inputs
.
size
()
!=
1
)
continue
;
std
::
vector
<
int
>
outshape
=
op
->
outputs
[
0
]
->
shape
;
if
(
outshape
.
empty
())
continue
;
std
::
vector
<
int
>
shape
=
op
->
params
.
at
(
"shape"
).
ai
;
// replace -1 with static dim-size
for
(
size_t
j
=
0
;
j
<
outshape
.
size
();
j
++
)
{
if
(
outshape
[
j
]
!=
-
1
)
{
shape
[
j
]
=
outshape
[
j
];
}
}
op
->
params
[
"shape"
]
=
shape
;
}
}
}
// namespace pnnx
tools/pnnx/src/pass_level5/fuse_channel_shuffle.cpp
浏览文件 @
4a78b6d4
...
...
@@ -56,7 +56,7 @@ public:
pnnx.Input input 0 1 input
Tensor.view op_0 1 1 input 13 shape=(%batch,%groups,%channels_per_group,%h,%w)
torch.transpose op_1 1 1 13 14 dim0=1 dim1=2
Tensor.reshape op_2 1 1 14 out shape=(%batch,
-1
,%h,%w)
Tensor.reshape op_2 1 1 14 out shape=(%batch,
%channels
,%h,%w)
pnnx.Output output 1 0 out
)PNNXIR"
;
}
...
...
tools/pnnx/src/pass_level5/fuse_multiheadattention.cpp
浏览文件 @
4a78b6d4
...
...
@@ -1060,14 +1060,14 @@ nn.Linear op_1 1 1 input 4 bias=%kbias in_features=%embed_d
nn.Linear op_2 1 1 input 6 bias=%vbias in_features=%embed_dim out_features=%embed_dim @bias @weight
pnnx.Expression op_3 1 1 2 3 expr=mul(@0,%inv_sqrt_embed_dim_per_head)
Tensor.view op_4 1 1 3 8 shape=(%batch,%size,%num_heads,%feat_per_head)
Tensor.view op_5 1 1 4 5 shape=(%batch,
-1
,%num_heads,%feat_per_head)
Tensor.view op_6 1 1 6 7 shape=(%batch,
-1
,%num_heads,%feat_per_head)
Tensor.view op_5 1 1 4 5 shape=(%batch,
%size
,%num_heads,%feat_per_head)
Tensor.view op_6 1 1 6 7 shape=(%batch,
%size
,%num_heads,%feat_per_head)
torch.transpose op_7 1 1 8 9 dim0=1 dim1=2
torch.transpose op_8 1 1 5 10 dim0=1 dim1=2
torch.transpose op_9 1 1 7 11 dim0=1 dim1=2
Tensor.reshape op_10 1 1 9 14 shape=(%num_heads,
-1
,%feat_per_head)
Tensor.reshape op_11 1 1 10 12 shape=(%num_heads,
-1
,%feat_per_head)
Tensor.reshape op_12 1 1 11 17 shape=(%num_heads,
-1
,%feat_per_head)
Tensor.reshape op_10 1 1 9 14 shape=(%num_heads,
%batch_mul_size
,%feat_per_head)
Tensor.reshape op_11 1 1 10 12 shape=(%num_heads,
%batch_mul_size
,%feat_per_head)
Tensor.reshape op_12 1 1 11 17 shape=(%num_heads,
%batch_mul_size
,%feat_per_head)
torch.transpose op_13 1 1 12 13 dim0=1 dim1=2
torch.bmm op_14 2 1 14 13 15
F.softmax op_15 1 1 15 16 dim=-1
...
...
@@ -1094,14 +1094,14 @@ nn.Linear op_1 1 1 input 5 bias=%kbias in_features=%embed_d
nn.Linear op_2 1 1 input 7 bias=%vbias in_features=%embed_dim out_features=%embed_dim @bias @weight
pnnx.Expression op_3 1 1 3 4 expr=mul(@0,%inv_sqrt_embed_dim_per_head)
Tensor.view op_4 1 1 4 9 shape=(%batch,%size,%num_heads,%feat_per_head)
Tensor.view op_5 1 1 5 6 shape=(%batch,
-1
,%num_heads,%feat_per_head)
Tensor.view op_6 1 1 7 8 shape=(%batch,
-1
,%num_heads,%feat_per_head)
Tensor.view op_5 1 1 5 6 shape=(%batch,
%size
,%num_heads,%feat_per_head)
Tensor.view op_6 1 1 7 8 shape=(%batch,
%size
,%num_heads,%feat_per_head)
torch.transpose op_7 1 1 9 10 dim0=1 dim1=2
torch.transpose op_8 1 1 6 11 dim0=1 dim1=2
torch.transpose op_9 1 1 8 12 dim0=1 dim1=2
Tensor.reshape op_10 1 1 10 15 shape=(%num_heads,
-1
,%feat_per_head)
Tensor.reshape op_11 1 1 11 13 shape=(%num_heads,
-1
,%feat_per_head)
Tensor.reshape op_12 1 1 12 21 shape=(%num_heads,
-1
,%feat_per_head)
Tensor.reshape op_10 1 1 10 15 shape=(%num_heads,
%batch_mul_size
,%feat_per_head)
Tensor.reshape op_11 1 1 11 13 shape=(%num_heads,
%batch_mul_size
,%feat_per_head)
Tensor.reshape op_12 1 1 12 21 shape=(%num_heads,
%batch_mul_size
,%feat_per_head)
torch.transpose op_13 1 1 13 14 dim0=1 dim1=2
torch.bmm op_14 2 1 15 14 16
Tensor.view op_15 1 1 16 17 shape=(%batch,%num_heads,%size,%size)
...
...
@@ -1301,7 +1301,7 @@ pnnx.Expression op_7 2 1 33 attn_mask 35 expr=add(@0,@1)
Tensor.view op_8 1 1 35 36 shape=(1,%batch,%num_heads,%size,%size)
pnnx.Attribute op_9 0 1 37 @data=(1,%batch,1,%size,%size)f32
pnnx.Expression op_10 2 1 36 37 38 expr=add(@0,@1)
Tensor.view op_11 1 1 38 39 shape=(
-1
,%num_heads,%size,%size)
Tensor.view op_11 1 1 38 39 shape=(
%batch
,%num_heads,%size,%size)
F.softmax op_12 1 1 39 40 dim=-1
torch.matmul op_13 2 1 40 30 41
torch.transpose op_14 1 1 41 42 dim0=1 dim1=2
...
...
tools/pnnx/src/pass_level5/fuse_scaled_dot_product_attention.cpp
浏览文件 @
4a78b6d4
...
...
@@ -79,13 +79,73 @@ pnnx.Output output 1 0 out
}
};
// Graph rewrite pass for an attention subgraph in which two extra inputs,
// Rh and Rw, are added to the query x key^T scores before the softmax.
// The matched subgraph is replaced by a single F.scaled_dot_product_attention
// whose attn_mask is add(Rh, Rw) reshaped to (%batch,%qsize,%qsize).
class fuse_scaled_dot_product_attention_pass_1 : public GraphRewriterPass
{
public:
    // Pattern to find: scaled query, matmul with transposed key, view to
    // (%batch,%h,%w,%h,%w), add Rh and Rw, view back to (%batch,%qsize,%qsize),
    // softmax over the last dim, matmul with value.
    const char* match_pattern_graph() const
    {
        return R"PNNXIR(7767517
14 13
pnnx.Input input_0 0 1 query #query=(%batch,%qsize,%feat_per_head)f32
pnnx.Input input_1 0 1 key #key=(%batch,%kvsize,%feat_per_head)f32
pnnx.Input input_2 0 1 value #value=(%batch,%kvsize,%feat_per_head)f32
pnnx.Input input_Rh 0 1 Rh #Rh=(%batch,%h,%w,%h,1)f32
pnnx.Input input_Rw 0 1 Rw #Rw=(%batch,%h,%w,1,%w)f32
pnnx.Expression op_0 1 1 query 17 expr=mul(@0,%inv_sqrt_embed_dim_per_head)
torch.transpose op_1 1 1 key 22 dim0=-2 dim1=-1
torch.matmul op_2 2 1 17 22 23
Tensor.view op_3 1 1 23 24 shape=(%batch,%h,%w,%h,%w)
pnnx.Expression op_4 3 1 24 Rh Rw 28 expr=add(add(@0,@1),@2)
Tensor.view op_5 1 1 28 29 shape=(%batch,%qsize,%qsize)
F.softmax op_6 1 1 29 30 dim=-1
torch.matmul op_7 2 1 30 value out
pnnx.Output output 1 0 out
)PNNXIR";
    }

    // Replacement: Rh + Rw is computed once, reshaped to the score shape and
    // handed to F.scaled_dot_product_attention as attn_mask.
    const char* replace_pattern_graph() const
    {
        return R"PNNXIR(7767517
9 8
pnnx.Input input_0 0 1 query
pnnx.Input input_1 0 1 key
pnnx.Input input_2 0 1 value
pnnx.Input input_Rh 0 1 Rh
pnnx.Input input_Rw 0 1 Rw
pnnx.Expression RhRw 2 1 Rh Rw RhRw expr=add(@0,@1) #RhRw=(%batch,%h,%w,%h,%w)f32
Tensor.reshape attn_mask 1 1 RhRw attn_mask shape=(%batch,%qsize,%qsize) #attn_mask=(%batch,%qsize,%qsize)f32
F.scaled_dot_product_attention op_0 4 1 query key value attn_mask out dropout_p=0.0 is_causal=False $attn_mask=attn_mask
pnnx.Output output 1 0 out
)PNNXIR";
    }

    // Accepts the structural match only when the captured values are
    // consistent: qsize must equal h * w, and the captured multiplier must be
    // within 0.001 (per NearlyEqual) of 1 / sqrt(feat_per_head).
    bool match(const std::map<std::string, Parameter>& captured_params) const
    {
        // all captures are read up front, in the same order as before
        const int query_len = captured_params.at("qsize").i;
        const int grid_h = captured_params.at("h").i;
        const int grid_w = captured_params.at("w").i;
        const int head_dim = captured_params.at("feat_per_head").i;
        const float scale = captured_params.at("inv_sqrt_embed_dim_per_head").f;

        // the scale is only checked when the size check has passed
        return query_len == grid_h * grid_w
               && NearlyEqual(scale, 1.f / sqrt(head_dim), 0.001);
    }
};
// Runs both scaled-dot-product-attention fusion passes over the graph.
// The body is compiled only when TORCH_VERSION_MAJOR >= 2; otherwise this
// function does nothing.
void fuse_scaled_dot_product_attention(Graph& graph)
{
#if TORCH_VERSION_MAJOR >= 2
    fuse_scaled_dot_product_attention_pass plain_attention;
    fuse_scaled_dot_product_attention_pass_1 biased_attention;

    // the same opindex variable is handed to both rewrites, first pass first
    int opindex = 0;
    pnnx_graph_rewrite(graph, &plain_attention, opindex);
    pnnx_graph_rewrite(graph, &biased_attention, opindex);
#endif
}
...
...
tools/pnnx/src/pass_ncnn/fuse_convert_shufflechannel_slice.cpp
浏览文件 @
4a78b6d4
...
...
@@ -35,51 +35,58 @@ public:
{
return
R"PNNXIR(7767517
6 6
pnnx.Input input 0 1 input
Tensor.reshape op_0 1 1 input a shape=
%shape
torch.permute op_1 1 1 a b dims=
%dims
Tensor.reshape op_2 1 1 b c shape=
%shape2
pnnx.Input input 0 1 input
#input=(%batch,%c,%h,%w)f32
Tensor.reshape op_0 1 1 input a shape=
(%batch_mul_ch_per_group,%groups,%h_mul_w)
torch.permute op_1 1 1 a b dims=
(1,0,2)
Tensor.reshape op_2 1 1 b c shape=
(%groups,%batch,%ch_per_group,%h,%w)
torch.unbind op_3 1 2 c out0 out1 dim=0
pnnx.Output output 2 0 out0 out1
)PNNXIR"
;
}
const
char
*
type_str
()
const
{
return
"ncnn._shufflechannel_slice"
;
}
const
char
*
name_str
()
const
const
char
*
replace_pattern_graph
()
const
{
return
"shufflechannel_slice"
;
return
R"PNNXIR(7767517
4 4
pnnx.Input input 0 1 input
ShuffleChannel shufflechannel 1 1 input a 0=%groups 1=1 #a=(%batch,%c,%h,%w)f32
Slice slice 1 2 a out0 out1 0=(-233,-233) 1=0
pnnx.Output output 2 0 out0 out1
)PNNXIR"
;
}
bool
match
(
const
std
::
map
<
std
::
string
,
Parameter
>&
captured_params
)
const
{
// (116,2,1024)
// (1,0,2)
// (2,-1,116,32,32)
const
std
::
vector
<
int
>&
shape
=
captured_params
.
at
(
"shape"
).
ai
;
const
std
::
vector
<
int
>&
dims
=
captured_params
.
at
(
"dims"
).
ai
;
const
std
::
vector
<
int
>&
shape2
=
captured_params
.
at
(
"shape2"
).
ai
;
if
(
dims
!=
std
::
vector
<
int
>
{
1
,
0
,
2
})
const
int
groups
=
captured_params
.
at
(
"groups"
).
i
;
const
int
batch
=
captured_params
.
at
(
"batch"
).
i
;
const
int
batch_mul_ch_per_group
=
captured_params
.
at
(
"batch_mul_ch_per_group"
).
i
;
const
int
ch_per_group
=
captured_params
.
at
(
"ch_per_group"
).
i
;
const
int
h_mul_w
=
captured_params
.
at
(
"h_mul_w"
).
i
;
const
int
c
=
captured_params
.
at
(
"c"
).
i
;
const
int
h
=
captured_params
.
at
(
"h"
).
i
;
const
int
w
=
captured_params
.
at
(
"w"
).
i
;
if
(
groups
!=
2
||
groups
*
ch_per_group
!=
c
)
return
false
;
if
(
shape
[
0
]
!=
shape2
[
2
]
||
shape
[
1
]
!=
shape2
[
0
]
||
shape
[
2
]
!=
shape2
[
3
]
*
shape2
[
4
]
||
shape
[
1
]
!=
2
||
shape2
[
1
]
!=
-
1
)
if
(
batch_mul_ch_per_group
!=
batch
*
ch_per_group
)
return
false
;
if
(
h_mul_w
!=
h
*
w
)
return
false
;
return
true
;
}
void
write
(
Operator
*
op
,
const
std
::
map
<
std
::
string
,
Parameter
>&
captured_param
s
)
const
void
write
(
const
std
::
map
<
std
::
string
,
Operator
*>&
ops
,
const
std
::
map
<
std
::
string
,
Parameter
>&
captured_params
,
const
std
::
map
<
std
::
string
,
Attribute
>&
captured_attr
s
)
const
{
const
std
::
vector
<
int
>&
shape
=
captured_params
.
at
(
"shape"
).
ai
;
GraphRewriterPass
::
write
(
ops
,
captured_params
,
captured_attrs
)
;
int
groups
=
shape
[
1
]
;
const
int
batch_index
=
ops
.
at
(
"shufflechannel"
)
->
inputs
[
0
]
->
params
[
"__batch_index"
].
i
;
op
->
params
[
"0"
]
=
groups
;
op
->
params
[
"1"
]
=
1
;
ops
.
at
(
"slice"
)
->
inputs
[
0
]
->
params
[
"__batch_index"
]
=
batch_index
;
ops
.
at
(
"slice"
)
->
outputs
[
0
]
->
params
[
"__batch_index"
]
=
batch_index
;
ops
.
at
(
"slice"
)
->
outputs
[
1
]
->
params
[
"__batch_index"
]
=
batch_index
;
}
};
...
...
@@ -90,10 +97,10 @@ public:
{
return
R"PNNXIR(7767517
6 6
pnnx.Input input 0 1 input
Tensor.reshape op_0 1 1 input a shape=
%shape
Tensor.permute op_1 1 1 a b dims=
%dims
Tensor.reshape op_2 1 1 b c shape=
%shape2
pnnx.Input input 0 1 input
#input=(%batch,%c,%h,%w)f32
Tensor.reshape op_0 1 1 input a shape=
(%batch_mul_ch_per_group,%groups,%h_mul_w)
Tensor.permute op_1 1 1 a b dims=
(1,0,2)
Tensor.reshape op_2 1 1 b c shape=
(%groups,%batch,%ch_per_group,%h,%w)
torch.unbind op_3 1 2 c out0 out1 dim=0
pnnx.Output output 2 0 out0 out1
)PNNXIR"
;
...
...
@@ -108,57 +115,6 @@ void fuse_convert_shufflechannel_slice(Graph& graph)
pnnx_graph_rewrite
(
graph
,
&
a
,
opindex
);
pnnx_graph_rewrite
(
graph
,
&
b
,
opindex
);
int
op_index
=
0
;
while
(
1
)
{
bool
matched
=
false
;
for
(
Operator
*
op
:
graph
.
ops
)
{
if
(
op
->
type
!=
"ncnn._shufflechannel_slice"
)
continue
;
matched
=
true
;
const
int
batch_index
=
op
->
inputs
[
0
]
->
params
[
"__batch_index"
].
i
;
op
->
type
=
"ShuffleChannel"
;
op
->
name
=
std
::
string
(
"shufflechannel_"
)
+
std
::
to_string
(
op_index
++
);
Operand
*
out0
=
op
->
outputs
[
0
];
Operand
*
out1
=
op
->
outputs
[
1
];
Operator
*
slice
=
graph
.
new_operator_after
(
"Slice"
,
op
->
name
+
"_slice"
,
op
);
Operand
*
slice_in
=
graph
.
new_operand
(
op
->
name
+
"_slice_in"
);
slice_in
->
params
[
"__batch_index"
]
=
batch_index
;
out0
->
params
[
"__batch_index"
]
=
batch_index
;
out1
->
params
[
"__batch_index"
]
=
batch_index
;
slice
->
inputs
.
push_back
(
slice_in
);
slice
->
outputs
.
push_back
(
out0
);
slice
->
outputs
.
push_back
(
out1
);
op
->
outputs
.
clear
();
op
->
outputs
.
push_back
(
slice_in
);
out0
->
producer
=
slice
;
out1
->
producer
=
slice
;
slice_in
->
producer
=
op
;
slice_in
->
consumers
.
push_back
(
slice
);
slice
->
params
[
"0"
]
=
std
::
vector
<
int
>
{
-
233
,
-
233
};
slice
->
params
[
"1"
]
=
0
;
break
;
}
if
(
!
matched
)
break
;
}
}
}
// namespace ncnn
...
...