Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
wjd2002
Ncnn
提交
05ad0c52
N
Ncnn
项目概览
wjd2002
/
Ncnn
10 个月 前同步成功
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
N
Ncnn
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
未验证
提交
05ad0c52
编写于
5月 06, 2023
作者:
N
nihui
提交者:
GitHub
5月 06, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
pnnx fuse gelu (#4702)
上级
490816b2
变更
3
隐藏空白更改
内联
并排
Showing
3 changed files
with
132 additions
and
11 deletions
+132
-11
tools/pnnx/src/pass_level2/F_gelu.cpp
tools/pnnx/src/pass_level2/F_gelu.cpp
+111
-0
tools/pnnx/src/pass_level2/F_local_response_norm.cpp
tools/pnnx/src/pass_level2/F_local_response_norm.cpp
+6
-6
tools/pnnx/tests/test_F_gelu.py
tools/pnnx/tests/test_F_gelu.py
+15
-5
未找到文件。
tools/pnnx/src/pass_level2/F_gelu.cpp
浏览文件 @
05ad0c52
...
...
@@ -59,4 +59,115 @@ pnnx.Output output 1 0 out
REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS
(
F_gelu_1
,
10
)
class
F_gelu_2
:
public
GraphRewriterPass
{
public:
// x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
const
char
*
match_pattern_graph
()
const
{
return
R"PNNXIR(7767517
11 10
pnnx.Input input_0 0 1 input
prim::Constant op_0 0 1 12 value=%0p5
aten::mul op_1 2 1 input 12 13
prim::Constant op_2 0 1 15 value=%sqrt2
aten::div op_3 2 1 input 15 16
aten::erf op_4 1 1 16 17
prim::Constant op_5 0 1 20 value=%1
prim::Constant op_6 0 1 21 value=1
aten::add op_7 3 1 17 20 21 22
aten::mul op_8 2 1 13 22 out
pnnx.Output output 1 0 out
)PNNXIR"
;
}
bool
match
(
const
std
::
map
<
std
::
string
,
Parameter
>&
captured_params
)
const
{
if
(
captured_params
.
at
(
"0p5"
).
f
!=
0.5
f
)
return
false
;
if
(
fabs
(
captured_params
.
at
(
"sqrt2"
).
f
-
sqrt
(
2.
f
))
>
0.0001
f
)
return
false
;
if
((
captured_params
.
at
(
"1"
).
type
==
2
&&
captured_params
.
at
(
"1"
).
i
!=
1
)
||
(
captured_params
.
at
(
"1"
).
type
==
3
&&
captured_params
.
at
(
"1"
).
f
!=
1.
f
))
return
false
;
return
true
;
}
const
char
*
type_str
()
const
{
return
"F.gelu"
;
}
void
write
(
Operator
*
/*op*/
,
const
std
::
map
<
std
::
string
,
Parameter
>&
/*captured_params*/
)
const
{
}
};
REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS
(
F_gelu_2
,
9
)
class
F_gelu_3
:
public
GraphRewriterPass
{
public:
// 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
const
char
*
match_pattern_graph
()
const
{
return
R"PNNXIR(7767517
17 16
pnnx.Input input_0 0 1 input
prim::Constant op_0 0 1 60 value=%0p5
aten::mul op_1 2 1 input 60 26
prim::Constant op_2 0 1 28 value=%3
aten::pow op_3 2 1 input 28 29
prim::Constant op_4 0 1 30 value=%0p044715
aten::mul op_5 2 1 29 30 31
prim::Constant op_6 0 1 61 value=1
aten::add op_7 3 1 input 31 61 35
prim::Constant op_8 0 1 36 value=%sqrt2dpi
aten::mul op_9 2 1 35 36 37
aten::tanh op_10 1 1 37 39
prim::Constant op_11 0 1 62 value=%1
prim::Constant op_12 0 1 63 value=%1_1
aten::add op_13 3 1 39 62 63 42
aten::mul op_14 2 1 26 42 out
pnnx.Output output 1 0 out
)PNNXIR"
;
}
bool
match
(
const
std
::
map
<
std
::
string
,
Parameter
>&
captured_params
)
const
{
if
(
captured_params
.
at
(
"0p5"
).
f
!=
0.5
f
)
return
false
;
if
(
fabs
(
captured_params
.
at
(
"0p044715"
).
f
-
0.044715
f
)
>
0.0001
f
)
return
false
;
if
(
fabs
(
captured_params
.
at
(
"sqrt2dpi"
).
f
-
sqrt
(
2.
f
/
M_PI
))
>
0.0001
f
)
return
false
;
if
((
captured_params
.
at
(
"1"
).
type
==
2
&&
captured_params
.
at
(
"1"
).
i
!=
1
)
||
(
captured_params
.
at
(
"1"
).
type
==
3
&&
captured_params
.
at
(
"1"
).
f
!=
1.
f
))
return
false
;
if
((
captured_params
.
at
(
"3"
).
type
==
2
&&
captured_params
.
at
(
"3"
).
i
!=
3
)
||
(
captured_params
.
at
(
"3"
).
type
==
3
&&
captured_params
.
at
(
"3"
).
f
!=
3.
f
))
return
false
;
if
((
captured_params
.
at
(
"1_1"
).
type
==
2
&&
captured_params
.
at
(
"1_1"
).
i
!=
1
)
||
(
captured_params
.
at
(
"1_1"
).
type
==
3
&&
captured_params
.
at
(
"1_1"
).
f
!=
1.
f
))
return
false
;
return
true
;
}
const
char
*
type_str
()
const
{
return
"F.gelu"
;
}
void
write
(
Operator
*
/*op*/
,
const
std
::
map
<
std
::
string
,
Parameter
>&
/*captured_params*/
)
const
{
}
};
REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS
(
F_gelu_3
,
9
)
}
// namespace pnnx
tools/pnnx/src/pass_level2/F_local_response_norm.cpp
浏览文件 @
05ad0c52
...
...
@@ -66,7 +66,7 @@ pnnx.Output output 1 0 out
return
"F.local_response_norm"
;
}
bool
match
_captured_params
(
const
std
::
map
<
std
::
string
,
Parameter
>&
captured_params
)
const
bool
match
(
const
std
::
map
<
std
::
string
,
Parameter
>&
captured_params
)
const
{
if
(
captured_params
.
at
(
"padzero"
).
type
==
2
)
return
captured_params
.
at
(
"padzero"
).
i
==
0
;
...
...
@@ -168,7 +168,7 @@ pnnx.Output output 1 0 out
return
"F.local_response_norm"
;
}
bool
match
_captured_params
(
const
std
::
map
<
std
::
string
,
Parameter
>&
captured_params
)
const
bool
match
(
const
std
::
map
<
std
::
string
,
Parameter
>&
captured_params
)
const
{
if
(
captured_params
.
at
(
"padzero"
).
type
==
2
)
return
captured_params
.
at
(
"padzero"
).
i
==
0
;
...
...
@@ -274,7 +274,7 @@ pnnx.Output output 1 0 out
return
"F.local_response_norm"
;
}
bool
match
_captured_params
(
const
std
::
map
<
std
::
string
,
Parameter
>&
captured_params
)
const
bool
match
(
const
std
::
map
<
std
::
string
,
Parameter
>&
captured_params
)
const
{
if
(
captured_params
.
at
(
"padzero"
).
type
==
2
)
return
captured_params
.
at
(
"padzero"
).
i
==
0
;
...
...
@@ -347,7 +347,7 @@ pnnx.Output output 1 0 out
return
"F.local_response_norm"
;
}
bool
match
_captured_params
(
const
std
::
map
<
std
::
string
,
Parameter
>&
captured_params
)
const
bool
match
(
const
std
::
map
<
std
::
string
,
Parameter
>&
captured_params
)
const
{
if
(
captured_params
.
at
(
"padzero"
).
type
==
2
)
return
captured_params
.
at
(
"padzero"
).
i
==
0
;
...
...
@@ -450,7 +450,7 @@ pnnx.Output output 1 0 out
return
"F.local_response_norm"
;
}
bool
match
_captured_params
(
const
std
::
map
<
std
::
string
,
Parameter
>&
captured_params
)
const
bool
match
(
const
std
::
map
<
std
::
string
,
Parameter
>&
captured_params
)
const
{
if
(
captured_params
.
at
(
"padzero"
).
type
==
2
)
return
captured_params
.
at
(
"padzero"
).
i
==
0
;
...
...
@@ -557,7 +557,7 @@ pnnx.Output output 1 0 out
return
"F.local_response_norm"
;
}
bool
match
_captured_params
(
const
std
::
map
<
std
::
string
,
Parameter
>&
captured_params
)
const
bool
match
(
const
std
::
map
<
std
::
string
,
Parameter
>&
captured_params
)
const
{
if
(
captured_params
.
at
(
"padzero"
).
type
==
2
)
return
captured_params
.
at
(
"padzero"
).
i
==
0
;
...
...
tools/pnnx/tests/test_F_gelu.py
浏览文件 @
05ad0c52
...
...
@@ -15,6 +15,13 @@
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
math
def gelu_forward_0(x):
    # Exact (erf-based) GELU written out long-hand:
    #   x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
    scaled = x / math.sqrt(2.0)
    return x * 0.5 * (1.0 + torch.erf(scaled))
def gelu_forward_1(x):
    # Tanh-approximated GELU written out long-hand:
    #   0.5 * x * (1.0 + torch.tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    inner = x + 0.044715 * torch.pow(x, 3.0)
    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * inner))
class
Model
(
nn
.
Module
):
def
__init__
(
self
):
...
...
@@ -23,8 +30,8 @@ class Model(nn.Module):
def
forward
(
self
,
x
,
y
,
z
,
w
):
x
=
F
.
gelu
(
x
)
y
=
F
.
gelu
(
y
)
z
=
F
.
gelu
(
z
)
w
=
F
.
gelu
(
w
)
z
=
gelu_forward_0
(
z
)
w
=
gelu_forward_1
(
w
)
return
x
,
y
,
z
,
w
def
test
():
...
...
@@ -37,7 +44,7 @@ def test():
z
=
torch
.
rand
(
1
,
3
,
12
,
16
)
w
=
torch
.
rand
(
1
,
5
,
7
,
9
,
11
)
a
0
,
a1
,
a2
,
a3
=
net
(
x
,
y
,
z
,
w
)
a
=
net
(
x
,
y
,
z
,
w
)
# export torchscript
mod
=
torch
.
jit
.
trace
(
net
,
(
x
,
y
,
z
,
w
))
...
...
@@ -49,9 +56,12 @@ def test():
# pnnx inference
import
test_F_gelu_pnnx
b
0
,
b1
,
b2
,
b3
=
test_F_gelu_pnnx
.
test_inference
()
b
=
test_F_gelu_pnnx
.
test_inference
()
return
torch
.
equal
(
a0
,
b0
)
and
torch
.
equal
(
a1
,
b1
)
and
torch
.
equal
(
a2
,
b2
)
and
torch
.
equal
(
a3
,
b3
)
for
a0
,
b0
in
zip
(
a
,
b
):
if
not
torch
.
allclose
(
a0
,
b0
,
1e-4
,
1e-4
):
return
False
return
True
if
__name__
==
"__main__"
:
if
test
():
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录