Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
10cc040d
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
10cc040d
编写于
11月 15, 2021
作者:
J
jiangcheng
提交者:
GitHub
11月 15, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add fetch op for cinn graph output node of build_cinn_pass (#37172)
上级
444a7358
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
34 addition
and
14 deletion
+34
-14
paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc
paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc
+9
-0
paddle/fluid/framework/paddle2cinn/build_cinn_pass_test.cc
paddle/fluid/framework/paddle2cinn/build_cinn_pass_test.cc
+19
-10
python/paddle/fluid/tests/unittests/test_resnet50_with_cinn.py
...n/paddle/fluid/tests/unittests/test_resnet50_with_cinn.py
+6
-4
未找到文件。
paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc
浏览文件 @
10cc040d
...
...
@@ -193,9 +193,18 @@ void AddOutputVar(const GraphNodeSet& output_vars, const GraphNodeSet& cluster,
const
GraphNodeMap
&
old_op2new_op
,
const
GraphNodeMap
&
old_var2new_var
,
Graph
*
graph
)
{
for
(
auto
*
old_var
:
output_vars
)
{
// create fetch op
OpDesc
desc
;
desc
.
SetType
(
"fetch"
);
desc
.
SetInput
(
"X"
,
{
old_var
->
Name
()});
auto
op
=
graph
->
CreateOpNode
(
&
desc
);
auto
*
var
=
old_var2new_var
.
at
(
old_var
);
VLOG
(
4
)
<<
"Add Output Var Node: "
<<
var
->
Name
();
// link fetch op and fetch var
IR_NODE_LINK_TO
(
var
,
op
);
for
(
auto
*
old_op
:
old_var
->
inputs
)
{
if
(
cluster
.
count
(
old_op
))
{
IR_NODE_LINK_TO
(
old_op2new_op
.
at
(
old_op
),
var
);
...
...
paddle/fluid/framework/paddle2cinn/build_cinn_pass_test.cc
浏览文件 @
10cc040d
...
...
@@ -264,7 +264,7 @@ TEST(BuildCinnPassTest, AllOpSupportCinn) {
// After search, there should has just one cinn subgraph
// feed --> v1 --
// | --> mul --> v3 --
// v2 -- | --> add --> v5 --> relu --> v6
// v2 -- | --> add --> v5 --> relu --> v6
--> fetch
// feed --> v4 --
auto
compilation_keys
=
GetCompilationKeys
(
*
g
);
ASSERT_EQ
(
compilation_keys
.
size
(),
static_cast
<
size_t
>
(
1
));
...
...
@@ -272,13 +272,14 @@ TEST(BuildCinnPassTest, AllOpSupportCinn) {
const
auto
&
subgraph
=
cinn_compiler
->
FindGraph
(
compilation_keys
[
0
]);
const
auto
&
subnodes
=
subgraph
.
Nodes
();
ASSERT_EQ
(
subnodes
.
size
(),
static_cast
<
size_t
>
(
1
1
));
ASSERT_EQ
(
subnodes
.
size
(),
static_cast
<
size_t
>
(
1
2
));
ASSERT_TRUE
(
CheckGraphIndependence
(
subnodes
));
ASSERT_TRUE
(
CheckNodeExisted
(
subnodes
,
"mul"
));
ASSERT_TRUE
(
CheckNodeExisted
(
subnodes
,
"add"
));
ASSERT_TRUE
(
CheckNodeExisted
(
subnodes
,
"relu"
));
ASSERT_EQ
(
CountNode
(
subnodes
,
"feed"
),
2
);
ASSERT_EQ
(
CountNode
(
subnodes
,
"fetch"
),
1
);
// No-parameter input should has feed op
auto
new_v1
=
GetNode
(
subnodes
,
"var1"
);
...
...
@@ -292,6 +293,13 @@ TEST(BuildCinnPassTest, AllOpSupportCinn) {
ASSERT_TRUE
(
new_v2
->
inputs
.
empty
());
ASSERT_EQ
(
new_v2
->
outputs
.
size
(),
static_cast
<
size_t
>
(
1
));
ASSERT_EQ
(
new_v2
->
outputs
[
0
]
->
Name
(),
"mul"
);
// output should has fetch op
auto
new_v6
=
GetNode
(
subnodes
,
"var6"
);
ASSERT_EQ
(
new_v6
->
inputs
.
size
(),
static_cast
<
size_t
>
(
1
));
ASSERT_EQ
(
new_v6
->
outputs
.
size
(),
static_cast
<
size_t
>
(
1
));
ASSERT_EQ
(
new_v6
->
inputs
[
0
]
->
Name
(),
"relu"
);
ASSERT_EQ
(
new_v6
->
outputs
[
0
]
->
Name
(),
"fetch"
);
}
std
::
unique_ptr
<
Graph
>
BuildGraphWithOneCinnSubgraph
()
{
...
...
@@ -379,7 +387,7 @@ TEST(BuildCinnPassTest, OneCinnSubgraph) {
// After search, there should has just one cinn subgraph
// feed --> v1 --
// | --> mul --> v3 --> relu --> v4
// | --> mul --> v3 --> relu --> v4
--> fetch
// v2 --
auto
compilation_keys
=
GetCompilationKeys
(
*
g
);
ASSERT_EQ
(
compilation_keys
.
size
(),
static_cast
<
size_t
>
(
1
));
...
...
@@ -387,12 +395,13 @@ TEST(BuildCinnPassTest, OneCinnSubgraph) {
const
auto
&
subgraph
=
cinn_compiler
->
FindGraph
(
compilation_keys
[
0
]);
const
auto
&
subnodes
=
subgraph
.
Nodes
();
ASSERT_EQ
(
subnodes
.
size
(),
static_cast
<
size_t
>
(
7
));
ASSERT_EQ
(
subnodes
.
size
(),
static_cast
<
size_t
>
(
8
));
ASSERT_TRUE
(
CheckGraphIndependence
(
subnodes
));
ASSERT_TRUE
(
CheckNodeExisted
(
subnodes
,
"mul"
));
ASSERT_TRUE
(
CheckNodeExisted
(
subnodes
,
"relu"
));
ASSERT_EQ
(
CountNode
(
subnodes
,
"feed"
),
1
);
ASSERT_EQ
(
CountNode
(
subnodes
,
"fetch"
),
1
);
}
std
::
unique_ptr
<
Graph
>
BuildGraphWithMultiCinnSubgraph
()
{
...
...
@@ -496,10 +505,10 @@ TEST(BuildCinnPassTest, MultiCinnSubgraph) {
ASSERT_EQ
(
compilation_keys
.
size
(),
static_cast
<
size_t
>
(
2
));
// subgraph1:
// feed --> v4 --> relu --> v5
// feed --> v4 --> relu --> v5
--> fetch
// subgraph2:
// feed --> v1 --
// | --> mul --> v3
// | --> mul --> v3
--> fetch
// v2 --
auto
*
cinn_compiler
=
CinnCompiler
::
GetInstance
();
const
auto
&
subgraph1
=
cinn_compiler
->
FindGraph
(
compilation_keys
[
0
]);
...
...
@@ -511,11 +520,11 @@ TEST(BuildCinnPassTest, MultiCinnSubgraph) {
ASSERT_TRUE
(
CheckGraphIndependence
(
subnodes2
));
if
(
CheckNodeExisted
(
subnodes1
,
"relu"
))
{
ASSERT_EQ
(
subnodes1
.
size
(),
static_cast
<
size_t
>
(
4
));
ASSERT_EQ
(
subnodes2
.
size
(),
static_cast
<
size_t
>
(
5
));
}
else
{
ASSERT_EQ
(
subnodes2
.
size
(),
static_cast
<
size_t
>
(
4
));
ASSERT_EQ
(
subnodes1
.
size
(),
static_cast
<
size_t
>
(
5
));
ASSERT_EQ
(
subnodes2
.
size
(),
static_cast
<
size_t
>
(
6
));
}
else
{
ASSERT_EQ
(
subnodes2
.
size
(),
static_cast
<
size_t
>
(
5
));
ASSERT_EQ
(
subnodes1
.
size
(),
static_cast
<
size_t
>
(
6
));
}
}
...
...
python/paddle/fluid/tests/unittests/test_resnet50_with_cinn.py
浏览文件 @
10cc040d
...
...
@@ -40,15 +40,14 @@ def set_cinn_flag(val):
class
TestResnet50Accuracy
(
unittest
.
TestCase
):
def
reader
(
self
,
limit
):
for
_
in
range
(
limit
):
yield
np
.
random
.
randint
(
0
,
256
,
size
=
[
32
,
3
,
224
,
224
]).
astype
(
'float32'
),
\
np
.
random
.
randint
(
0
,
1000
,
size
=
[
32
]).
astype
(
'int64'
)
yield
{
'image'
:
np
.
random
.
randint
(
0
,
256
,
size
=
[
32
,
3
,
224
,
224
]).
astype
(
'float32'
),
\
'label'
:
np
.
random
.
randint
(
0
,
1000
,
size
=
[
32
]).
astype
(
'int64'
)}
def
generate_random_data
(
self
,
loop_num
=
10
):
feed
=
[]
data
=
self
.
reader
(
loop_num
)
for
_
in
range
(
loop_num
):
x
,
y
=
next
(
data
)
feed
.
append
({
'image'
:
x
,
'label'
:
y
})
feed
.
append
(
next
(
data
))
return
feed
def
build_program
(
self
,
main_program
,
startup_program
):
...
...
@@ -57,6 +56,9 @@ class TestResnet50Accuracy(unittest.TestCase):
name
=
'image'
,
shape
=
[
32
,
3
,
224
,
224
],
dtype
=
'float32'
)
label
=
paddle
.
static
.
data
(
name
=
'label'
,
shape
=
[
32
],
dtype
=
'int64'
)
# TODO: stop_gradient slower training speed, need fix
image
.
stop_gradient
=
False
model
=
paddle
.
vision
.
models
.
resnet50
()
prediction
=
model
(
image
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录