Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
ac6ef06f
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
ac6ef06f
编写于
3月 07, 2019
作者:
Z
Zhen Wang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add the Clone method in Graph. test=develop
上级
01eddf12
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
49 addition
and
9 deletion
+49
-9
paddle/fluid/framework/ir/graph.cc
paddle/fluid/framework/ir/graph.cc
+33
-0
paddle/fluid/framework/ir/graph.h
paddle/fluid/framework/ir/graph.h
+4
-0
paddle/fluid/framework/ir/node.h
paddle/fluid/framework/ir/node.h
+1
-0
paddle/fluid/pybind/ir.cc
paddle/fluid/pybind/ir.cc
+2
-0
python/paddle/fluid/contrib/slim/tests/test_graph.py
python/paddle/fluid/contrib/slim/tests/test_graph.py
+8
-8
python/paddle/fluid/framework.py
python/paddle/fluid/framework.py
+1
-1
未找到文件。
paddle/fluid/framework/ir/graph.cc
浏览文件 @
ac6ef06f
...
@@ -152,6 +152,39 @@ void Graph::ResolveHazard(
...
@@ -152,6 +152,39 @@ void Graph::ResolveHazard(
}
}
}
}
std
::
shared_ptr
<
Graph
>
Graph
::
Clone
()
{
auto
cloned_graph
=
std
::
make_shared
<
Graph
>
(
this
->
program_
);
cloned_graph
->
ReleaseNodes
();
cloned_graph
->
num_node_created_
=
0
;
std
::
unordered_map
<
ir
::
Node
*
,
ir
::
Node
*>
origin_to_cloned
;
for
(
auto
*
n
:
this
->
node_set_
)
{
ir
::
Node
*
cloned_node
=
nullptr
;
if
(
n
->
IsCtrlVar
())
{
cloned_node
=
cloned_graph
->
CreateControlDepVar
();
}
else
if
(
!
n
->
var_desc_
&&
!
n
->
op_desc_
)
{
// empty node
cloned_node
=
cloned_graph
->
CreateEmptyNode
(
n
->
Name
(),
n
->
NodeType
());
}
else
if
(
n
->
IsVar
())
{
cloned_node
=
cloned_graph
->
CreateVarNode
(
n
->
Var
());
}
else
if
(
n
->
IsOp
())
{
cloned_node
=
cloned_graph
->
CreateOpNode
(
n
->
Op
());
}
if
(
cloned_node
)
{
origin_to_cloned
[
n
]
=
cloned_node
;
}
else
{
PADDLE_THROW
(
"The cloned node's type is not supported!"
);
}
}
for
(
auto
*
n
:
this
->
node_set_
)
{
for
(
auto
it
=
n
->
inputs
.
begin
();
it
!=
n
->
inputs
.
end
();
it
++
)
{
origin_to_cloned
[
n
]
->
inputs
.
push_back
(
origin_to_cloned
[
*
it
]);
}
for
(
auto
it
=
n
->
outputs
.
begin
();
it
!=
n
->
outputs
.
end
();
it
++
)
{
origin_to_cloned
[
n
]
->
outputs
.
push_back
(
origin_to_cloned
[
*
it
]);
}
}
return
cloned_graph
;
}
bool
IsControlDepVar
(
const
ir
::
Node
&
var
)
{
bool
IsControlDepVar
(
const
ir
::
Node
&
var
)
{
return
var
.
Name
().
find
(
ir
::
Node
::
kControlDepVarName
)
!=
std
::
string
::
npos
;
return
var
.
Name
().
find
(
ir
::
Node
::
kControlDepVarName
)
!=
std
::
string
::
npos
;
}
}
...
...
paddle/fluid/framework/ir/graph.h
浏览文件 @
ac6ef06f
...
@@ -213,6 +213,10 @@ class Graph {
...
@@ -213,6 +213,10 @@ class Graph {
void
ResolveHazard
(
void
ResolveHazard
(
const
std
::
map
<
std
::
string
,
std
::
vector
<
ir
::
Node
*>>
&
var_nodes
);
const
std
::
map
<
std
::
string
,
std
::
vector
<
ir
::
Node
*>>
&
var_nodes
);
// Create a new and duplicated graph.
// WARN: The method only clones the graph structure, not its attributes.
std
::
shared_ptr
<
Graph
>
Clone
();
private:
private:
std
::
map
<
std
::
string
,
std
::
vector
<
ir
::
Node
*>>
InitFromProgram
(
std
::
map
<
std
::
string
,
std
::
vector
<
ir
::
Node
*>>
InitFromProgram
(
const
ProgramDesc
&
program
);
const
ProgramDesc
&
program
);
...
...
paddle/fluid/framework/ir/node.h
浏览文件 @
ac6ef06f
...
@@ -14,6 +14,7 @@ limitations under the License. */
...
@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#pragma once
#include <memory>
#include <string>
#include <string>
#include <typeindex>
#include <typeindex>
#include <typeinfo>
#include <typeinfo>
...
...
paddle/fluid/pybind/ir.cc
浏览文件 @
ac6ef06f
...
@@ -18,6 +18,7 @@
...
@@ -18,6 +18,7 @@
#include <string>
#include <string>
#include <unordered_map>
#include <unordered_map>
#include <unordered_set>
#include <unordered_set>
#include <utility>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
...
@@ -54,6 +55,7 @@ void BindGraph(py::module *m) {
...
@@ -54,6 +55,7 @@ void BindGraph(py::module *m) {
"The graph is a Directed Acyclic Single Static Assignment Graph, see "
"The graph is a Directed Acyclic Single Static Assignment Graph, see "
"`paddle::ir::Graph` for details."
)
"`paddle::ir::Graph` for details."
)
.
def
(
py
::
init
<
const
ProgramDesc
&>
())
.
def
(
py
::
init
<
const
ProgramDesc
&>
())
.
def
(
"clone"
,
&
Graph
::
Clone
)
.
def
(
"has"
,
&
Graph
::
Has
)
.
def
(
"has"
,
&
Graph
::
Has
)
.
def
(
"get_int"
,
&
Graph
::
Get
<
int
>
)
.
def
(
"get_int"
,
&
Graph
::
Get
<
int
>
)
.
def
(
"get_float"
,
&
Graph
::
Get
<
float
>
)
.
def
(
"get_float"
,
&
Graph
::
Get
<
float
>
)
...
...
python/paddle/fluid/contrib/slim/tests/test_graph.py
浏览文件 @
ac6ef06f
...
@@ -60,20 +60,12 @@ class TestGraph(unittest.TestCase):
...
@@ -60,20 +60,12 @@ class TestGraph(unittest.TestCase):
opt
=
fluid
.
optimizer
.
Adam
(
learning_rate
=
0.001
)
opt
=
fluid
.
optimizer
.
Adam
(
learning_rate
=
0.001
)
opt
.
minimize
(
loss
)
opt
.
minimize
(
loss
)
graph
=
IrGraph
(
core
.
Graph
(
main
.
desc
),
for_test
=
False
)
graph
=
IrGraph
(
core
.
Graph
(
main
.
desc
),
for_test
=
False
)
backup_graph
=
graph
.
clone
()
self
.
assertEqual
(
len
(
graph
.
all_nodes
()),
len
(
backup_graph
.
all_nodes
()))
marked_nodes
=
set
()
marked_nodes
=
set
()
for
op
in
graph
.
all_op_nodes
():
for
op
in
graph
.
all_op_nodes
():
if
op
.
name
().
find
(
'conv2d'
)
>
-
1
:
if
op
.
name
().
find
(
'conv2d'
)
>
-
1
:
marked_nodes
.
add
(
op
)
marked_nodes
.
add
(
op
)
if
not
for_ci
:
if
not
for_ci
:
graph
.
draw
(
'.'
,
'residual'
,
marked_nodes
)
graph
.
draw
(
'.'
,
'residual'
,
marked_nodes
)
backup_marked_nodes
=
set
()
for
op
in
backup_graph
.
all_op_nodes
():
if
op
.
name
().
find
(
'conv2d'
)
>
-
1
:
backup_marked_nodes
.
add
(
op
)
backup_graph
.
draw
(
'.'
,
'backup'
,
backup_marked_nodes
)
self
.
assertFalse
(
graph
.
has_circle
())
self
.
assertFalse
(
graph
.
has_circle
())
self
.
assertEqual
(
graph
.
graph_num
(),
1
)
self
.
assertEqual
(
graph
.
graph_num
(),
1
)
nodes
=
graph
.
topology_sort
()
nodes
=
graph
.
topology_sort
()
...
@@ -83,6 +75,14 @@ class TestGraph(unittest.TestCase):
...
@@ -83,6 +75,14 @@ class TestGraph(unittest.TestCase):
nodes_num
=
len
(
graph
.
all_nodes
())
nodes_num
=
len
(
graph
.
all_nodes
())
graph
.
safe_remove_nodes
(
marked_nodes
)
graph
.
safe_remove_nodes
(
marked_nodes
)
self
.
assertEqual
(
len
(
graph
.
all_nodes
()),
nodes_num
-
len
(
marked_nodes
))
self
.
assertEqual
(
len
(
graph
.
all_nodes
()),
nodes_num
-
len
(
marked_nodes
))
backup_graph
=
graph
.
clone
()
self
.
assertEqual
(
len
(
graph
.
all_nodes
()),
len
(
backup_graph
.
all_nodes
()))
if
not
for_ci
:
backup_marked_nodes
=
set
()
for
op
in
backup_graph
.
all_op_nodes
():
if
op
.
name
().
find
(
'conv2d'
)
>
-
1
:
backup_marked_nodes
.
add
(
op
)
backup_graph
.
draw
(
'.'
,
'backup'
,
backup_marked_nodes
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
...
...
python/paddle/fluid/framework.py
浏览文件 @
ac6ef06f
...
@@ -2012,7 +2012,7 @@ class IrGraph(object):
...
@@ -2012,7 +2012,7 @@ class IrGraph(object):
Returns:
Returns:
IrGraph: A new and duplicated graph.
IrGraph: A new and duplicated graph.
"""
"""
g
=
core
.
Graph
(
self
.
graph
.
origin_program_desc
()
)
g
=
self
.
graph
.
clone
(
)
return
IrGraph
(
g
,
self
.
_for_test
)
return
IrGraph
(
g
,
self
.
_for_test
)
def
is_test
(
self
):
def
is_test
(
self
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录