Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
qq_38905368
tensorflow
提交
04e8759e
T
tensorflow
项目概览
qq_38905368
/
tensorflow
与 Fork 源项目一致
从无法访问的项目Fork
通知
5
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
T
tensorflow
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
04e8759e
编写于
12月 11, 2018
作者:
A
Andy Ly
提交者:
TensorFlower Gardener
12月 11, 2018
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Grappler] Add helper functions to GraphView.
PiperOrigin-RevId: 225109110
上级
ae244e6d
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
92 addition
and
20 deletion
+92
-20
tensorflow/core/grappler/graph_view.h
tensorflow/core/grappler/graph_view.h
+41
-19
tensorflow/core/grappler/graph_view_test.cc
tensorflow/core/grappler/graph_view_test.cc
+34
-0
tensorflow/core/grappler/utils.cc
tensorflow/core/grappler/utils.cc
+6
-1
tensorflow/core/grappler/utils.h
tensorflow/core/grappler/utils.h
+4
-0
tensorflow/core/grappler/utils_test.cc
tensorflow/core/grappler/utils_test.cc
+7
-0
未找到文件。
tensorflow/core/grappler/graph_view.h
浏览文件 @
04e8759e
...
...
@@ -111,32 +111,37 @@ class GraphViewInternal {
GraphDefT
*
graph
()
const
{
return
graph_
;
}
// Find
a node by name or return `nullptr` if it's not in a
graph view.
// Find
s a node by name or return `nullptr` if it's not in the
graph view.
NodeDefT
*
GetNode
(
absl
::
string_view
node_name
)
const
{
return
gtl
::
FindWithDefault
(
nodes_
,
node_name
,
nullptr
);
}
// Get the specified input port. Note that the special '-1' port_id can be
// Checks if a node by name is in the graph view.
bool
HasNode
(
absl
::
string_view
node_name
)
const
{
return
GetNode
(
node_name
)
!=
nullptr
;
}
// Gets the specified input port. Note that the special '-1' port_id can be
// used to access the controlling nodes (i.e. the nodes connected to node_name
// through an incoming control dependency).
InputPort
GetInputPort
(
absl
::
string_view
node_name
,
int
port_id
)
const
{
return
InputPort
(
GetNode
(
node_name
),
port_id
);
}
// Get the specified output port. Note that the special '-1' port_id can be
// Get
s
the specified output port. Note that the special '-1' port_id can be
// used to access the controlled nodes (i.e. the nodes connected to node_name
// through an outgoing control dependency).
OutputPort
GetOutputPort
(
absl
::
string_view
node_name
,
int
port_id
)
const
{
return
OutputPort
(
GetNode
(
node_name
),
port_id
);
}
// Get the input (resp. output) port(s) in the immediate fanout (resp. fanin)
// of an output (resp. input) port.
// Gets the input port(s) in the immediate fanout of an output port.
const
absl
::
flat_hash_set
<
InputPort
>&
GetFanout
(
const
OutputPort
&
port
)
const
{
return
gtl
::
FindWithDefault
(
fanouts_
,
port
,
fanout_not_found_value_
);
}
// Gets the output port(s) in the immediate fanin of an input port.
absl
::
flat_hash_set
<
OutputPort
>
GetFanin
(
const
InputPort
&
port
)
const
{
if
(
port
.
port_id
>=
0
)
return
{
GetRegularFanin
(
port
)};
...
...
@@ -162,9 +167,22 @@ class GraphViewInternal {
return
GetOutputPort
(
tensor_id
.
node
(),
tensor_id
.
index
());
}
// Get all the input (resp. output) ports in the immediate fanout (resp
// fanin) of a node. Include the controlling nodes iff
// include_controlling_nodes is true.
// Checks if a tensor id is a fanin of the node.
bool
HasFanin
(
const
NodeDef
&
node
,
const
TensorId
&
fanin
)
const
{
if
(
fanin
.
index
()
<
-
1
)
{
return
false
;
}
string
fanin_string
=
TensorIdToString
(
fanin
);
for
(
int
i
=
0
;
i
<
node
.
input_size
();
++
i
)
{
if
(
node
.
input
(
i
)
==
fanin_string
)
{
return
true
;
}
}
return
false
;
}
// Gets all the input ports in the immediate fanout of a node. Include the
// controlled nodes iff include_controlled_nodes is true.
absl
::
flat_hash_set
<
InputPort
>
GetFanouts
(
const
NodeDef
&
node
,
bool
include_controlled_nodes
)
const
{
absl
::
flat_hash_set
<
InputPort
>
result
;
...
...
@@ -185,6 +203,8 @@ class GraphViewInternal {
return
result
;
}
// Gets all the output ports in the immediate fanin of a node. Include the
// controlling nodes iff include_controlling_nodes is true.
absl
::
flat_hash_set
<
OutputPort
>
GetFanins
(
const
NodeDef
&
node
,
bool
include_controlling_nodes
)
const
{
absl
::
flat_hash_set
<
OutputPort
>
result
;
...
...
@@ -198,7 +218,7 @@ class GraphViewInternal {
return
result
;
}
// Get the number of ports in the immediate fanin of a node. Count the
// Get
s
the number of ports in the immediate fanin of a node. Count the
// controlling nodes iff include_controlling_nodes is true.
int
NumFanins
(
const
NodeDef
&
node
,
bool
include_controlling_nodes
)
const
{
int
count
=
0
;
...
...
@@ -211,14 +231,14 @@ class GraphViewInternal {
return
count
;
}
// Get the number of ports in the immediate fanout of a node. Count the
// controll
ing nodes iff include_controlling
_nodes is true.
int
NumFanouts
(
const
NodeDef
&
node
,
bool
include_controll
ing
_nodes
)
const
{
// Get
s
the number of ports in the immediate fanout of a node. Count the
// controll
ed nodes iff include_controlled
_nodes is true.
int
NumFanouts
(
const
NodeDef
&
node
,
bool
include_controll
ed
_nodes
)
const
{
int
count
=
0
;
OutputPort
port
;
port
.
node
=
const_cast
<
NodeDefT
*>
(
&
node
);
const
int
first_port_id
=
include_controll
ing
_nodes
?
-
1
:
0
;
const
int
first_port_id
=
include_controll
ed
_nodes
?
-
1
:
0
;
const
int
last_port_id
=
gtl
::
FindWithDefault
(
max_regular_output_port_
,
port
.
node
,
-
1
);
...
...
@@ -231,8 +251,8 @@ class GraphViewInternal {
return
count
;
}
// Get
all the edges in the immediate fanout (resp fanin) of a node.
//
Include the control edges iff include_controlling
_edges is true.
// Get
s all the edges in the immediate fanout of a node. Include the
//
controlled edges iff include_controlled
_edges is true.
absl
::
flat_hash_set
<
Edge
>
GetFanoutEdges
(
const
NodeDef
&
node
,
bool
include_controlled_edges
)
const
{
absl
::
flat_hash_set
<
Edge
>
result
;
...
...
@@ -248,14 +268,16 @@ class GraphViewInternal {
auto
it
=
fanouts_
.
find
(
port
);
if
(
it
!=
fanouts_
.
end
())
{
for
(
auto
itr
=
it
->
second
.
begin
();
itr
!=
it
->
second
.
end
();
++
itr
)
{
result
.
emplace
(
/*src
*/
OutputPort
(
const_cast
<
NodeDefT
*>
(
&
node
),
i
),
/*dst
*/
*
itr
);
result
.
emplace
(
/*src
=*/
OutputPort
(
const_cast
<
NodeDefT
*>
(
&
node
),
i
),
/*dst
=*/
*
itr
);
}
}
}
return
result
;
}
// Gets all the edges in the immediate fanin of a node. Include the
// controlling edges iff include_controlling_edges is true.
absl
::
flat_hash_set
<
Edge
>
GetFaninEdges
(
const
NodeDef
&
node
,
bool
include_controlling_edges
)
const
{
absl
::
flat_hash_set
<
Edge
>
result
;
...
...
@@ -265,8 +287,8 @@ class GraphViewInternal {
auto
it
=
nodes_
.
find
(
tensor_id
.
node
());
if
(
it
!=
nodes_
.
end
())
{
result
.
emplace
(
/*src
*/
OutputPort
(
it
->
second
,
tensor_id
.
index
()),
/*dst
*/
InputPort
(
const_cast
<
NodeDefT
*>
(
&
node
),
i
));
result
.
emplace
(
/*src
=*/
OutputPort
(
it
->
second
,
tensor_id
.
index
()),
/*dst
=*/
InputPort
(
const_cast
<
NodeDefT
*>
(
&
node
),
i
));
}
}
return
result
;
...
...
tensorflow/core/grappler/graph_view_test.cc
浏览文件 @
04e8759e
...
...
@@ -230,6 +230,40 @@ TEST_F(GraphViewTest, ControlDependencies) {
EXPECT_EQ
(
0
,
(
*
fanin
.
begin
()).
port_id
);
}
TEST_F
(
GraphViewTest
,
HasNode
)
{
tensorflow
::
Scope
s
=
tensorflow
::
Scope
::
NewRootScope
();
Output
a
=
ops
::
Const
(
s
.
WithOpName
(
"a"
),
0.0
f
,
{
10
,
10
});
GrapplerItem
item
;
TF_CHECK_OK
(
s
.
ToGraphDef
(
&
item
.
graph
));
GraphView
graph
(
&
item
.
graph
);
EXPECT_EQ
(
true
,
graph
.
HasNode
(
"a"
));
EXPECT_EQ
(
false
,
graph
.
HasNode
(
"b"
));
}
TEST_F
(
GraphViewTest
,
HasFanin
)
{
tensorflow
::
Scope
s
=
tensorflow
::
Scope
::
NewRootScope
();
Output
a
=
ops
::
Const
(
s
.
WithOpName
(
"a"
),
0.0
f
,
{
10
,
10
});
Output
b
=
ops
::
Square
(
s
.
WithOpName
(
"b"
),
{
a
});
Output
c
=
ops
::
Sqrt
(
s
.
WithOpName
(
"c"
),
{
b
});
Output
d
=
ops
::
AddN
(
s
.
WithOpName
(
"d"
).
WithControlDependencies
(
a
),
{
b
,
c
});
GrapplerItem
item
;
TF_CHECK_OK
(
s
.
ToGraphDef
(
&
item
.
graph
));
GraphView
graph
(
&
item
.
graph
);
const
NodeDef
*
d_node
=
graph
.
GetNode
(
"d"
);
EXPECT_NE
(
nullptr
,
d_node
);
EXPECT_EQ
(
true
,
graph
.
HasFanin
(
*
d_node
,
{
"a"
,
Graph
::
kControlSlot
}));
EXPECT_EQ
(
false
,
graph
.
HasFanin
(
*
d_node
,
{
"a"
,
0
}));
EXPECT_EQ
(
true
,
graph
.
HasFanin
(
*
d_node
,
{
"b"
,
0
}));
EXPECT_EQ
(
false
,
graph
.
HasFanin
(
*
d_node
,
{
"b"
,
Graph
::
kControlSlot
}));
EXPECT_EQ
(
true
,
graph
.
HasFanin
(
*
d_node
,
{
"c"
,
0
}));
EXPECT_EQ
(
false
,
graph
.
HasFanin
(
*
d_node
,
{
"c"
,
Graph
::
kControlSlot
}));
}
}
// namespace
}
// namespace grappler
}
// namespace tensorflow
tensorflow/core/grappler/utils.cc
浏览文件 @
04e8759e
...
...
@@ -144,11 +144,16 @@ void NodeMap::UpdateOutput(const string& node_name,
outputs
.
insert
(
nodes_
[
NodeName
(
new_output_name
)]);
}
string
TensorIdToString
(
const
TensorId
&
tensor_id
)
{
return
tensor_id
.
index
()
==
0
?
string
(
tensor_id
.
node
())
:
tensor_id
.
ToString
();
}
bool
IsSameInput
(
const
string
&
name1
,
const
string
&
name2
)
{
if
(
name1
==
name2
)
return
true
;
TensorId
tensor1
=
ParseTensorName
(
name1
);
TensorId
tensor2
=
ParseTensorName
(
name2
);
return
tensor1
.
node
()
==
tensor2
.
node
()
&&
tensor1
.
index
()
==
tensor2
.
index
()
;
return
tensor1
==
tensor2
;
}
bool
IsControlInput
(
const
string
&
name
)
{
...
...
tensorflow/core/grappler/utils.h
浏览文件 @
04e8759e
...
...
@@ -100,6 +100,10 @@ class SetVector {
std
::
vector
<
T
>
vector_
;
};
// Returns formatted string from TensorId specific to grappler. Specifically,
// for the 0 port (first output), only the node name is returned.
string
TensorIdToString
(
const
TensorId
&
tensor_id
);
// True iff 'name' refers to a control inputs, i.e. a node name prefixed with
// the ^ character.
bool
IsControlInput
(
const
string
&
name
);
...
...
tensorflow/core/grappler/utils_test.cc
浏览文件 @
04e8759e
...
...
@@ -464,6 +464,13 @@ TEST_F(UtilsTest, SetTensorValueBFloat16IntMin) {
Tensor
(
bfloat16
(
std
::
numeric_limits
<
int
>::
min
())),
t
);
}
TEST_F
(
UtilsTest
,
TensorIdToString
)
{
EXPECT_EQ
(
"^foo"
,
TensorIdToString
({
"foo"
,
-
1
}));
EXPECT_EQ
(
"foo"
,
TensorIdToString
({
"foo"
,
0
}));
EXPECT_EQ
(
"foo:1"
,
TensorIdToString
({
"foo"
,
1
}));
EXPECT_EQ
(
"foo:2"
,
TensorIdToString
({
"foo"
,
2
}));
}
}
// namespace
}
// namespace grappler
}
// namespace tensorflow
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录