Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
wux_labs
Tensorflow
提交
d15c612f
T
Tensorflow
项目概览
wux_labs
/
Tensorflow
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
T
Tensorflow
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
d15c612f
编写于
6月 05, 2019
作者:
A
Andy Ly
提交者:
TensorFlower Gardener
6月 05, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Grappler] Migrate FrameView to use utils::GraphView/utils::MutableGraphView.
PiperOrigin-RevId: 251659253
上级
cd09510f
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
385 addition
and
187 deletion
+385
-187
tensorflow/core/grappler/optimizers/BUILD
tensorflow/core/grappler/optimizers/BUILD
+4
-0
tensorflow/core/grappler/optimizers/layout_optimizer.cc
tensorflow/core/grappler/optimizers/layout_optimizer.cc
+3
-1
tensorflow/core/grappler/optimizers/loop_optimizer.h
tensorflow/core/grappler/optimizers/loop_optimizer.h
+2
-0
tensorflow/core/grappler/optimizers/loop_optimizer_test.cc
tensorflow/core/grappler/optimizers/loop_optimizer_test.cc
+177
-63
tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.h
...low/core/grappler/optimizers/scoped_allocator_optimizer.h
+1
-0
tensorflow/core/grappler/utils/BUILD
tensorflow/core/grappler/utils/BUILD
+3
-1
tensorflow/core/grappler/utils/frame.cc
tensorflow/core/grappler/utils/frame.cc
+95
-61
tensorflow/core/grappler/utils/frame.h
tensorflow/core/grappler/utils/frame.h
+8
-3
tensorflow/core/grappler/utils/frame_test.cc
tensorflow/core/grappler/utils/frame_test.cc
+92
-58
未找到文件。
tensorflow/core/grappler/optimizers/BUILD
浏览文件 @
d15c612f
...
...
@@ -498,6 +498,7 @@ cc_library(
"//tensorflow/core/grappler:devices"
,
"//tensorflow/core/grappler:grappler_item"
,
"//tensorflow/core/grappler:op_types"
,
"//tensorflow/core/grappler:utils"
,
"//tensorflow/core/grappler/clusters:cluster"
,
"//tensorflow/core/grappler/costs:graph_properties"
,
"//tensorflow/core/grappler/costs:virtual_placer"
,
...
...
@@ -703,6 +704,7 @@ cc_library(
"//tensorflow/core/grappler:grappler_item"
,
"//tensorflow/core/grappler:mutable_graph_view"
,
"//tensorflow/core/grappler:op_types"
,
"//tensorflow/core/grappler:utils"
,
"//tensorflow/core/grappler/costs:graph_properties"
,
"//tensorflow/core/grappler/utils:frame"
,
"//tensorflow/core/grappler/utils:traversal"
,
...
...
@@ -724,6 +726,7 @@ tf_cuda_cc_test(
"//tensorflow/core/grappler:grappler_item"
,
"//tensorflow/core/grappler:utils"
,
"//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder"
,
"//tensorflow/core/grappler/utils:graph_view"
,
"//tensorflow/core/grappler/utils:grappler_test"
,
],
)
...
...
@@ -883,6 +886,7 @@ cc_library(
"//tensorflow/core:protos_all_cc"
,
"//tensorflow/core/grappler:grappler_item"
,
"//tensorflow/core/grappler:op_types"
,
"//tensorflow/core/grappler:utils"
,
"//tensorflow/core/grappler/costs:graph_properties"
,
"//tensorflow/core/grappler/utils:frame"
,
],
...
...
tensorflow/core/grappler/optimizers/layout_optimizer.cc
浏览文件 @
d15c612f
...
...
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/optimizers/layout_optimizer.h"
#include <deque>
#include <unordered_set>
...
...
@@ -28,7 +30,7 @@ limitations under the License.
#include "tensorflow/core/grappler/devices.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/
optimizers/layout_optimizer
.h"
#include "tensorflow/core/grappler/
utils
.h"
#include "tensorflow/core/grappler/utils/frame.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
...
...
tensorflow/core/grappler/optimizers/loop_optimizer.h
浏览文件 @
d15c612f
...
...
@@ -17,8 +17,10 @@ limitations under the License.
#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_LOOP_OPTIMIZER_H_
#include <unordered_set>
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/frame.h"
#include "tensorflow/core/protobuf/rewriter_config.pb.h"
...
...
tensorflow/core/grappler/optimizers/loop_optimizer_test.cc
浏览文件 @
d15c612f
...
...
@@ -14,12 +14,14 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/optimizers/loop_optimizer.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/graph_view.h"
#include "tensorflow/core/grappler/utils/grappler_test.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
...
...
@@ -104,26 +106,42 @@ TEST_F(LoopOptimizerTest, Basic) {
TF_EXPECT_OK
(
optimizer
.
Optimize
(
nullptr
,
item
,
&
output
));
{
// Original graph.
GraphView
view
(
&
graph
);
Status
status
;
utils
::
GraphView
view
(
&
graph
,
&
status
);
TF_ASSERT_OK
(
status
);
FrameView
frames
;
TF_EXPECT_OK
(
frames
.
InferFromGraphView
(
view
));
EXPECT_EQ
(
frames
.
num_frames
(),
1
);
ASSERT_EQ
(
frames
.
Frames
(
*
view
.
GetNode
(
"InvariantAdd"
)).
size
(),
1
);
EXPECT_EQ
(
frames
.
Frames
(
*
view
.
GetNode
(
"InvariantAdd"
)).
back
(),
0
);
ASSERT_EQ
(
frames
.
Frames
(
*
view
.
GetNode
(
"VariantAdd"
)).
size
(),
1
);
EXPECT_EQ
(
frames
.
Frames
(
*
view
.
GetNode
(
"VariantAdd"
)).
back
(),
0
);
const
auto
*
invariant_add_node
=
view
.
GetNode
(
"InvariantAdd"
);
ASSERT_NE
(
invariant_add_node
,
nullptr
);
const
auto
*
invariant_add_node_def
=
invariant_add_node
->
node
();
ASSERT_EQ
(
frames
.
Frames
(
*
invariant_add_node_def
).
size
(),
1
);
EXPECT_EQ
(
frames
.
Frames
(
*
invariant_add_node_def
).
back
(),
0
);
const
auto
*
variant_add_node
=
view
.
GetNode
(
"VariantAdd"
);
ASSERT_NE
(
variant_add_node
,
nullptr
);
const
auto
*
variant_add_node_def
=
variant_add_node
->
node
();
ASSERT_EQ
(
frames
.
Frames
(
*
variant_add_node_def
).
size
(),
1
);
EXPECT_EQ
(
frames
.
Frames
(
*
variant_add_node_def
).
back
(),
0
);
}
{
// Optimized graph.
GraphView
view
(
&
output
);
Status
status
;
utils
::
GraphView
view
(
&
output
,
&
status
);
TF_ASSERT_OK
(
status
);
FrameView
frames
;
TF_EXPECT_OK
(
frames
.
InferFromGraphView
(
view
));
EXPECT_EQ
(
frames
.
num_frames
(),
1
);
ASSERT_EQ
(
frames
.
Frames
(
*
view
.
GetNode
(
"InvariantAdd"
)).
size
(),
0
);
ASSERT_EQ
(
frames
.
Frames
(
*
view
.
GetNode
(
"VariantAdd"
)).
size
(),
1
);
EXPECT_EQ
(
frames
.
Frames
(
*
view
.
GetNode
(
"VariantAdd"
)).
back
(),
0
);
const
auto
*
invariant_add_node
=
view
.
GetNode
(
"InvariantAdd"
);
ASSERT_NE
(
invariant_add_node
,
nullptr
);
const
auto
*
invariant_add_node_def
=
invariant_add_node
->
node
();
ASSERT_EQ
(
frames
.
Frames
(
*
invariant_add_node_def
).
size
(),
0
);
const
auto
*
variant_add_node
=
view
.
GetNode
(
"VariantAdd"
);
ASSERT_NE
(
variant_add_node
,
nullptr
);
const
auto
*
variant_add_node_def
=
variant_add_node
->
node
();
ASSERT_EQ
(
frames
.
Frames
(
*
variant_add_node_def
).
size
(),
1
);
EXPECT_EQ
(
frames
.
Frames
(
*
variant_add_node_def
).
back
(),
0
);
}
}
...
...
@@ -155,25 +173,41 @@ TEST_F(LoopOptimizerTest, Const) {
TF_EXPECT_OK
(
optimizer
.
Optimize
(
nullptr
,
item
,
&
output
));
{
// Original graph.
GraphView
view
(
&
graph
);
Status
status
;
utils
::
GraphView
view
(
&
graph
,
&
status
);
TF_ASSERT_OK
(
status
);
FrameView
frames
;
TF_EXPECT_OK
(
frames
.
InferFromGraphView
(
view
));
EXPECT_EQ
(
frames
.
num_frames
(),
1
);
ASSERT_EQ
(
frames
.
Frames
(
*
view
.
GetNode
(
"InvariantAdd"
)).
size
(),
1
);
EXPECT_EQ
(
frames
.
Frames
(
*
view
.
GetNode
(
"InvariantAdd"
)).
back
(),
0
);
ASSERT_EQ
(
frames
.
Frames
(
*
view
.
GetNode
(
"Const"
)).
size
(),
1
);
EXPECT_EQ
(
frames
.
Frames
(
*
view
.
GetNode
(
"Const"
)).
back
(),
0
);
const
auto
*
invariant_add_node
=
view
.
GetNode
(
"InvariantAdd"
);
ASSERT_NE
(
invariant_add_node
,
nullptr
);
const
auto
*
invariant_add_node_def
=
invariant_add_node
->
node
();
ASSERT_EQ
(
frames
.
Frames
(
*
invariant_add_node_def
).
size
(),
1
);
EXPECT_EQ
(
frames
.
Frames
(
*
invariant_add_node_def
).
back
(),
0
);
const
auto
*
const_node
=
view
.
GetNode
(
"Const"
);
ASSERT_NE
(
const_node
,
nullptr
);
const
auto
*
const_node_node_def
=
const_node
->
node
();
ASSERT_EQ
(
frames
.
Frames
(
*
const_node_node_def
).
size
(),
1
);
EXPECT_EQ
(
frames
.
Frames
(
*
const_node_node_def
).
back
(),
0
);
}
{
// Optimized graph.
GraphView
view
(
&
output
);
Status
status
;
utils
::
GraphView
view
(
&
output
,
&
status
);
TF_ASSERT_OK
(
status
);
FrameView
frames
;
TF_EXPECT_OK
(
frames
.
InferFromGraphView
(
view
));
EXPECT_EQ
(
frames
.
num_frames
(),
1
);
ASSERT_EQ
(
frames
.
Frames
(
*
view
.
GetNode
(
"InvariantAdd"
)).
size
(),
0
);
ASSERT_EQ
(
frames
.
Frames
(
*
view
.
GetNode
(
"Const"
)).
size
(),
0
);
const
auto
*
invariant_add_node
=
view
.
GetNode
(
"InvariantAdd"
);
ASSERT_NE
(
invariant_add_node
,
nullptr
);
const
auto
*
invariant_add_node_def
=
invariant_add_node
->
node
();
ASSERT_EQ
(
frames
.
Frames
(
*
invariant_add_node_def
).
size
(),
0
);
const
auto
*
const_node
=
view
.
GetNode
(
"Const"
);
ASSERT_NE
(
const_node
,
nullptr
);
const
auto
*
const_node_node_def
=
const_node
->
node
();
ASSERT_EQ
(
frames
.
Frames
(
*
const_node_node_def
).
size
(),
0
);
}
}
...
...
@@ -206,23 +240,33 @@ TEST_F(LoopOptimizerTest, ControlOutput) {
TF_EXPECT_OK
(
optimizer
.
Optimize
(
nullptr
,
item
,
&
output
));
{
// Original graph.
GraphView
view
(
&
graph
);
Status
status
;
utils
::
GraphView
view
(
&
graph
,
&
status
);
TF_ASSERT_OK
(
status
);
FrameView
frames
;
TF_EXPECT_OK
(
frames
.
InferFromGraphView
(
view
));
EXPECT_EQ
(
frames
.
num_frames
(),
1
);
ASSERT_EQ
(
frames
.
Frames
(
*
view
.
GetNode
(
"InvariantAdd"
)).
size
(),
1
);
EXPECT_EQ
(
frames
.
Frames
(
*
view
.
GetNode
(
"InvariantAdd"
)).
back
(),
0
);
const
auto
*
invariant_add_node
=
view
.
GetNode
(
"InvariantAdd"
);
ASSERT_NE
(
invariant_add_node
,
nullptr
);
const
auto
*
invariant_add_node_def
=
invariant_add_node
->
node
();
ASSERT_EQ
(
frames
.
Frames
(
*
invariant_add_node_def
).
size
(),
1
);
EXPECT_EQ
(
frames
.
Frames
(
*
invariant_add_node_def
).
back
(),
0
);
}
{
// Optimized graph.
GraphView
view
(
&
output
);
Status
status
;
utils
::
GraphView
view
(
&
output
,
&
status
);
TF_ASSERT_OK
(
status
);
FrameView
frames
;
TF_EXPECT_OK
(
frames
.
InferFromGraphView
(
view
));
EXPECT_EQ
(
frames
.
num_frames
(),
1
);
ASSERT_EQ
(
frames
.
Frames
(
*
view
.
GetNode
(
"InvariantAdd"
)).
size
(),
1
);
EXPECT_EQ
(
frames
.
Frames
(
*
view
.
GetNode
(
"InvariantAdd"
)).
back
(),
0
);
const
auto
*
invariant_add_node
=
view
.
GetNode
(
"InvariantAdd"
);
ASSERT_NE
(
invariant_add_node
,
nullptr
);
const
auto
*
invariant_add_node_def
=
invariant_add_node
->
node
();
ASSERT_EQ
(
frames
.
Frames
(
*
invariant_add_node_def
).
size
(),
1
);
EXPECT_EQ
(
frames
.
Frames
(
*
invariant_add_node_def
).
back
(),
0
);
}
}
...
...
@@ -270,30 +314,52 @@ TEST_F(LoopOptimizerTest, NestedLoop1) {
TF_EXPECT_OK
(
optimizer
.
Optimize
(
nullptr
,
item
,
&
output
));
{
// Original graph.
GraphView
view
(
&
graph
);
Status
status
;
utils
::
GraphView
view
(
&
graph
,
&
status
);
TF_ASSERT_OK
(
status
);
FrameView
frames
;
TF_EXPECT_OK
(
frames
.
InferFromGraphView
(
view
));
EXPECT_EQ
(
frames
.
num_frames
(),
2
);
ASSERT_EQ
(
frames
.
Frames
(
*
view
.
GetNode
(
"InvariantAdd2"
)).
size
(),
2
);
EXPECT_EQ
(
frames
.
Frames
(
*
view
.
GetNode
(
"InvariantAdd2"
)).
back
(),
1
);
ASSERT_EQ
(
frames
.
Frames
(
*
view
.
GetNode
(
"VariantAdd2"
)).
size
(),
2
);
EXPECT_EQ
(
frames
.
Frames
(
*
view
.
GetNode
(
"VariantAdd2"
)).
back
(),
1
);
ASSERT_EQ
(
frames
.
Frames
(
*
view
.
GetNode
(
"InvariantAdd"
)).
size
(),
1
);
EXPECT_EQ
(
frames
.
Frames
(
*
view
.
GetNode
(
"InvariantAdd"
)).
back
(),
0
);
const
auto
*
invariant_add_2_node
=
view
.
GetNode
(
"InvariantAdd2"
);
ASSERT_NE
(
invariant_add_2_node
,
nullptr
);
const
auto
*
invariant_add_2_node_def
=
invariant_add_2_node
->
node
();
ASSERT_EQ
(
frames
.
Frames
(
*
invariant_add_2_node_def
).
size
(),
2
);
EXPECT_EQ
(
frames
.
Frames
(
*
invariant_add_2_node_def
).
back
(),
1
);
const
auto
*
variant_add_2_node
=
view
.
GetNode
(
"VariantAdd2"
);
ASSERT_NE
(
variant_add_2_node
,
nullptr
);
const
auto
*
variant_add_2_node_def
=
variant_add_2_node
->
node
();
ASSERT_EQ
(
frames
.
Frames
(
*
variant_add_2_node_def
).
size
(),
2
);
EXPECT_EQ
(
frames
.
Frames
(
*
variant_add_2_node_def
).
back
(),
1
);
const
auto
*
invariant_add_node
=
view
.
GetNode
(
"InvariantAdd"
);
ASSERT_NE
(
invariant_add_node
,
nullptr
);
const
auto
*
invariant_add_node_def
=
invariant_add_node
->
node
();
ASSERT_EQ
(
frames
.
Frames
(
*
invariant_add_node_def
).
size
(),
1
);
EXPECT_EQ
(
frames
.
Frames
(
*
invariant_add_node_def
).
back
(),
0
);
}
{
// Optimized graph.
GraphView
view
(
&
output
);
Status
status
;
utils
::
GraphView
view
(
&
output
,
&
status
);
TF_ASSERT_OK
(
status
);
FrameView
frames
;
TF_EXPECT_OK
(
frames
.
InferFromGraphView
(
view
));
EXPECT_EQ
(
frames
.
num_frames
(),
2
);
ASSERT_EQ
(
frames
.
Frames
(
*
view
.
GetNode
(
"InvariantAdd2"
)).
size
(),
1
);
EXPECT_EQ
(
frames
.
Frames
(
*
view
.
GetNode
(
"InvariantAdd2"
)).
back
(),
0
);
ASSERT_EQ
(
frames
.
Frames
(
*
view
.
GetNode
(
"VariantAdd2"
)).
size
(),
2
);
EXPECT_EQ
(
frames
.
Frames
(
*
view
.
GetNode
(
"VariantAdd2"
)).
back
(),
1
);
ASSERT_EQ
(
frames
.
Frames
(
*
view
.
GetNode
(
"InvariantAdd"
)).
size
(),
0
);
const
auto
*
invariant_add_2_node
=
view
.
GetNode
(
"InvariantAdd2"
);
ASSERT_NE
(
invariant_add_2_node
,
nullptr
);
const
auto
*
invariant_add_2_node_def
=
invariant_add_2_node
->
node
();
ASSERT_EQ
(
frames
.
Frames
(
*
invariant_add_2_node_def
).
size
(),
1
);
EXPECT_EQ
(
frames
.
Frames
(
*
invariant_add_2_node_def
).
back
(),
0
);
const
auto
*
variant_add_2_node
=
view
.
GetNode
(
"VariantAdd2"
);
ASSERT_NE
(
variant_add_2_node
,
nullptr
);
const
auto
*
variant_add_2_node_def
=
variant_add_2_node
->
node
();
ASSERT_EQ
(
frames
.
Frames
(
*
variant_add_2_node_def
).
size
(),
2
);
EXPECT_EQ
(
frames
.
Frames
(
*
variant_add_2_node_def
).
back
(),
1
);
const
auto
*
invariant_add_node
=
view
.
GetNode
(
"InvariantAdd"
);
ASSERT_NE
(
invariant_add_node
,
nullptr
);
const
auto
*
invariant_add_node_def
=
invariant_add_node
->
node
();
ASSERT_EQ
(
frames
.
Frames
(
*
invariant_add_node_def
).
size
(),
0
);
}
}
...
...
@@ -341,26 +407,42 @@ TEST_F(LoopOptimizerTest, NestedLoop2) {
TF_EXPECT_OK
(
optimizer
.
Optimize
(
nullptr
,
item
,
&
output
));
{
// Original graph.
GraphView
view
(
&
graph
);
Status
status
;
utils
::
GraphView
view
(
&
graph
,
&
status
);
TF_ASSERT_OK
(
status
);
FrameView
frames
;
TF_EXPECT_OK
(
frames
.
InferFromGraphView
(
view
));
EXPECT_EQ
(
frames
.
num_frames
(),
2
);
ASSERT_EQ
(
frames
.
Frames
(
*
view
.
GetNode
(
"InvariantAdd2"
)).
size
(),
2
);
EXPECT_EQ
(
frames
.
Frames
(
*
view
.
GetNode
(
"InvariantAdd2"
)).
back
(),
1
);
ASSERT_EQ
(
frames
.
Frames
(
*
view
.
GetNode
(
"VariantAdd2"
)).
size
(),
2
);
EXPECT_EQ
(
frames
.
Frames
(
*
view
.
GetNode
(
"VariantAdd2"
)).
back
(),
1
);
const
auto
*
invariant_add_2_node
=
view
.
GetNode
(
"InvariantAdd2"
);
ASSERT_NE
(
invariant_add_2_node
,
nullptr
);
const
auto
*
invariant_add_2_node_def
=
invariant_add_2_node
->
node
();
ASSERT_EQ
(
frames
.
Frames
(
*
invariant_add_2_node_def
).
size
(),
2
);
EXPECT_EQ
(
frames
.
Frames
(
*
invariant_add_2_node_def
).
back
(),
1
);
const
auto
*
variant_add_2_node
=
view
.
GetNode
(
"VariantAdd2"
);
ASSERT_NE
(
variant_add_2_node
,
nullptr
);
const
auto
*
variant_add_2_node_def
=
variant_add_2_node
->
node
();
ASSERT_EQ
(
frames
.
Frames
(
*
variant_add_2_node_def
).
size
(),
2
);
EXPECT_EQ
(
frames
.
Frames
(
*
variant_add_2_node_def
).
back
(),
1
);
}
{
// Optimized graph.
GraphView
view
(
&
output
);
Status
status
;
utils
::
GraphView
view
(
&
output
,
&
status
);
TF_ASSERT_OK
(
status
);
FrameView
frames
;
TF_EXPECT_OK
(
frames
.
InferFromGraphView
(
view
));
EXPECT_EQ
(
frames
.
num_frames
(),
2
);
ASSERT_EQ
(
frames
.
Frames
(
*
view
.
GetNode
(
"InvariantAdd2"
)).
size
(),
0
);
ASSERT_EQ
(
frames
.
Frames
(
*
view
.
GetNode
(
"VariantAdd2"
)).
size
(),
2
);
EXPECT_EQ
(
frames
.
Frames
(
*
view
.
GetNode
(
"VariantAdd2"
)).
back
(),
1
);
const
auto
*
invariant_add_2_node
=
view
.
GetNode
(
"InvariantAdd2"
);
ASSERT_NE
(
invariant_add_2_node
,
nullptr
);
const
auto
*
invariant_add_2_node_def
=
invariant_add_2_node
->
node
();
ASSERT_EQ
(
frames
.
Frames
(
*
invariant_add_2_node_def
).
size
(),
0
);
const
auto
*
variant_add_2_node
=
view
.
GetNode
(
"VariantAdd2"
);
ASSERT_NE
(
variant_add_2_node
,
nullptr
);
const
auto
*
variant_add_2_node_def
=
variant_add_2_node
->
node
();
ASSERT_EQ
(
frames
.
Frames
(
*
variant_add_2_node_def
).
size
(),
2
);
EXPECT_EQ
(
frames
.
Frames
(
*
variant_add_2_node_def
).
back
(),
1
);
}
}
...
...
@@ -408,27 +490,43 @@ TEST_F(LoopOptimizerTest, NestedLoopConst1) {
TF_EXPECT_OK
(
optimizer
.
Optimize
(
nullptr
,
item
,
&
output
));
{
// Original graph.
GraphView
view
(
&
graph
);
Status
status
;
utils
::
GraphView
view
(
&
graph
,
&
status
);
TF_ASSERT_OK
(
status
);
FrameView
frames
;
TF_EXPECT_OK
(
frames
.
InferFromGraphView
(
view
));
EXPECT_EQ
(
frames
.
num_frames
(),
2
);
ASSERT_EQ
(
frames
.
Frames
(
*
view
.
GetNode
(
"InvariantAdd2"
)).
size
(),
2
);
EXPECT_EQ
(
frames
.
Frames
(
*
view
.
GetNode
(
"InvariantAdd2"
)).
back
(),
1
);
ASSERT_EQ
(
frames
.
Frames
(
*
view
.
GetNode
(
"Const2"
)).
size
(),
2
);
EXPECT_EQ
(
frames
.
Frames
(
*
view
.
GetNode
(
"Const2"
)).
back
(),
1
);
const
auto
*
invariant_add_2_node
=
view
.
GetNode
(
"InvariantAdd2"
);
ASSERT_NE
(
invariant_add_2_node
,
nullptr
);
const
auto
*
invariant_add_2_node_def
=
invariant_add_2_node
->
node
();
ASSERT_EQ
(
frames
.
Frames
(
*
invariant_add_2_node_def
).
size
(),
2
);
EXPECT_EQ
(
frames
.
Frames
(
*
invariant_add_2_node_def
).
back
(),
1
);
const
auto
*
const_2_node
=
view
.
GetNode
(
"Const2"
);
ASSERT_NE
(
const_2_node
,
nullptr
);
const
auto
*
const_2_node_def
=
const_2_node
->
node
();
ASSERT_EQ
(
frames
.
Frames
(
*
const_2_node_def
).
size
(),
2
);
EXPECT_EQ
(
frames
.
Frames
(
*
const_2_node_def
).
back
(),
1
);
}
{
// Optimized graph.
GraphView
view
(
&
output
);
Status
status
;
utils
::
GraphView
view
(
&
output
,
&
status
);
TF_ASSERT_OK
(
status
);
FrameView
frames
;
TF_EXPECT_OK
(
frames
.
InferFromGraphView
(
view
));
EXPECT_EQ
(
frames
.
num_frames
(),
2
);
ASSERT_EQ
(
frames
.
Frames
(
*
view
.
GetNode
(
"InvariantAdd2"
)).
size
(),
1
);
EXPECT_EQ
(
frames
.
Frames
(
*
view
.
GetNode
(
"InvariantAdd2"
)).
back
(),
0
);
ASSERT_EQ
(
frames
.
Frames
(
*
view
.
GetNode
(
"Const2"
)).
size
(),
1
);
EXPECT_EQ
(
frames
.
Frames
(
*
view
.
GetNode
(
"Const2"
)).
back
(),
0
);
const
auto
*
invariant_add_2_node
=
view
.
GetNode
(
"InvariantAdd2"
);
ASSERT_NE
(
invariant_add_2_node
,
nullptr
);
const
auto
*
invariant_add_2_node_def
=
invariant_add_2_node
->
node
();
ASSERT_EQ
(
frames
.
Frames
(
*
invariant_add_2_node_def
).
size
(),
1
);
EXPECT_EQ
(
frames
.
Frames
(
*
invariant_add_2_node_def
).
back
(),
0
);
const
auto
*
const_2_node
=
view
.
GetNode
(
"Const2"
);
ASSERT_NE
(
const_2_node
,
nullptr
);
const
auto
*
const_2_node_def
=
const_2_node
->
node
();
ASSERT_EQ
(
frames
.
Frames
(
*
const_2_node_def
).
size
(),
1
);
EXPECT_EQ
(
frames
.
Frames
(
*
const_2_node_def
).
back
(),
0
);
}
}
...
...
@@ -476,25 +574,41 @@ TEST_F(LoopOptimizerTest, NestedLoopConst2) {
TF_EXPECT_OK
(
optimizer
.
Optimize
(
nullptr
,
item
,
&
output
));
{
// Original graph.
GraphView
view
(
&
graph
);
Status
status
;
utils
::
GraphView
view
(
&
graph
,
&
status
);
TF_ASSERT_OK
(
status
);
FrameView
frames
;
TF_EXPECT_OK
(
frames
.
InferFromGraphView
(
view
));
EXPECT_EQ
(
frames
.
num_frames
(),
2
);
ASSERT_EQ
(
frames
.
Frames
(
*
view
.
GetNode
(
"InvariantAdd2"
)).
size
(),
2
);
EXPECT_EQ
(
frames
.
Frames
(
*
view
.
GetNode
(
"InvariantAdd2"
)).
back
(),
1
);
ASSERT_EQ
(
frames
.
Frames
(
*
view
.
GetNode
(
"Const2"
)).
size
(),
2
);
EXPECT_EQ
(
frames
.
Frames
(
*
view
.
GetNode
(
"Const2"
)).
back
(),
1
);
const
auto
*
invariant_add_2_node
=
view
.
GetNode
(
"InvariantAdd2"
);
ASSERT_NE
(
invariant_add_2_node
,
nullptr
);
const
auto
*
invariant_add_2_node_def
=
invariant_add_2_node
->
node
();
ASSERT_EQ
(
frames
.
Frames
(
*
invariant_add_2_node_def
).
size
(),
2
);
EXPECT_EQ
(
frames
.
Frames
(
*
invariant_add_2_node_def
).
back
(),
1
);
const
auto
*
const_2_node
=
view
.
GetNode
(
"Const2"
);
ASSERT_NE
(
const_2_node
,
nullptr
);
const
auto
*
const_2_node_def
=
const_2_node
->
node
();
ASSERT_EQ
(
frames
.
Frames
(
*
const_2_node_def
).
size
(),
2
);
EXPECT_EQ
(
frames
.
Frames
(
*
const_2_node_def
).
back
(),
1
);
}
{
// Optimized graph.
GraphView
view
(
&
output
);
Status
status
;
utils
::
GraphView
view
(
&
output
,
&
status
);
TF_ASSERT_OK
(
status
);
FrameView
frames
;
TF_EXPECT_OK
(
frames
.
InferFromGraphView
(
view
));
EXPECT_EQ
(
frames
.
num_frames
(),
2
);
ASSERT_EQ
(
frames
.
Frames
(
*
view
.
GetNode
(
"InvariantAdd2"
)).
size
(),
0
);
ASSERT_EQ
(
frames
.
Frames
(
*
view
.
GetNode
(
"Const2"
)).
size
(),
0
);
const
auto
*
invariant_add_2_node
=
view
.
GetNode
(
"InvariantAdd2"
);
ASSERT_NE
(
invariant_add_2_node
,
nullptr
);
const
auto
*
invariant_add_2_node_def
=
invariant_add_2_node
->
node
();
ASSERT_EQ
(
frames
.
Frames
(
*
invariant_add_2_node_def
).
size
(),
0
);
const
auto
*
const_2_node
=
view
.
GetNode
(
"Const2"
);
ASSERT_NE
(
const_2_node
,
nullptr
);
const
auto
*
const_2_node_def
=
const_2_node
->
node
();
ASSERT_EQ
(
frames
.
Frames
(
*
const_2_node_def
).
size
(),
0
);
}
}
...
...
tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.h
浏览文件 @
d15c612f
...
...
@@ -21,6 +21,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/protobuf/rewriter_config.pb.h"
namespace
tensorflow
{
...
...
tensorflow/core/grappler/utils/BUILD
浏览文件 @
d15c612f
...
...
@@ -78,10 +78,10 @@ cc_library(
hdrs
=
[
"frame.h"
],
visibility
=
[
"//visibility:public"
],
deps
=
[
":graph_view"
,
"//tensorflow/core:lib"
,
"//tensorflow/core:lib_internal"
,
"//tensorflow/core:protos_all_cc"
,
"//tensorflow/core/grappler:graph_view"
,
"//tensorflow/core/grappler:op_types"
,
"@com_google_absl//absl/container:flat_hash_map"
,
],
...
...
@@ -93,6 +93,8 @@ tf_cc_test(
srcs
=
[
"frame_test.cc"
],
deps
=
[
":frame"
,
":graph_view"
,
"//tensorflow/core:lib"
,
"//tensorflow/core:lib_proto_parsing"
,
"//tensorflow/core:protos_all_cc"
,
"//tensorflow/core:test"
,
...
...
tensorflow/core/grappler/utils/frame.cc
浏览文件 @
d15c612f
...
...
@@ -14,10 +14,11 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/utils/frame.h"
#include <deque>
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/grappler/graph_view.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/lib/core/errors.h"
...
...
@@ -26,101 +27,134 @@ namespace grappler {
namespace
{}
// namespace
Status
FrameView
::
InferFromGraphView
(
const
GraphView
&
graph_view
)
{
template
<
typename
GraphViewT
>
inline
Status
FrameView
::
InferFromGraphViewT
(
const
GraphViewT
&
graph_view
)
{
if
(
is_inferred_
)
{
return
errors
::
Internal
(
"FrameView was already inferred from the graph"
);
}
is_inferred_
=
true
;
std
::
deque
<
const
NodeDef
*>
ready_nod
es
;
std
::
deque
<
int
>
ready_node_indic
es
;
// All nodes without inputs are automatically added to the ready queue.
for
(
const
NodeDef
&
node
:
graph_view
.
graph
()
->
node
())
{
if
(
node
.
input_size
()
==
0
)
{
ready_node
s
.
push_back
(
&
node
);
node_to_frames_
[
&
node
]
=
node_has_no_frames_
;
for
(
const
auto
&
node
:
graph_view
.
GetNodes
())
{
if
(
node
.
NumRegularFanins
()
+
node
.
NumControllingFanins
()
==
0
)
{
ready_node
_indices
.
push_back
(
node
.
node_index
()
);
node_to_frames_
[
node
.
node
()
]
=
node_has_no_frames_
;
}
}
const
auto
*
graph
=
graph_view
.
graph
();
// We assign unique int id to each frame, and use this map to track what
// frames we've already seen in the graph.
absl
::
flat_hash_map
<
string
,
int
>
frame_name_to_id
;
while
(
!
ready_nodes
.
empty
())
{
const
NodeDef
*
ready_node
=
ready_nodes
.
front
();
absl
::
flat_hash_set
<
GraphView
::
InputPort
>
fanouts
=
graph_view
.
GetFanouts
(
*
ready_node
,
/*include_controlled_nodes=*/
true
);
auto
process_fanout
=
[
this
,
graph
](
absl
::
flat_hash_map
<
string
,
int
>*
frame_name_to_id
,
std
::
deque
<
int
>*
ready_node_indices
,
const
NodeDef
*
ready_node
,
int
fanout_node_index
)
{
const
NodeDef
*
fanout_node
=
&
graph
->
node
(
fanout_node_index
);
if
(
!
node_to_frames_
.
contains
(
fanout_node
))
{
// If we have never seen this node before, we add all frames from the
// incoming node (and pop/push frames if coming from Exit/Enter nodes).
std
::
vector
<
int
>
frame_ids
=
node_to_frames_
[
ready_node
];
if
(
IsExit
(
*
ready_node
))
{
frame_ids
.
pop_back
();
}
for
(
const
GraphView
::
InputPort
&
fanout
:
fanouts
)
{
if
(
node_to_frames_
.
count
(
fanout
.
node
)
<
1
)
{
// If we have never seen this node before, we add all frames from the
// incoming node (and pop/push frames if coming from Exit/Enter nodes).
std
::
vector
<
int
>
frame_ids
=
node_to_frames_
[
ready_node
];
if
(
IsEnter
(
*
fanout_node
))
{
const
AttrValue
*
frame_name_attr
=
AttrSlice
(
*
fanout_node
).
Find
(
"frame_name"
);
if
(
IsExit
(
*
ready_node
))
{
frame_ids
.
pop_back
();
if
(
!
frame_name_attr
)
{
return
errors
::
InvalidArgument
(
"Missing frame name for the Enter node: "
,
SummarizeNodeDef
(
*
fanout_node
));
}
if
(
IsEnter
(
*
fanout
.
node
))
{
const
AttrValue
*
frame_name_attr
=
AttrSlice
(
*
fanout
.
node
).
Find
(
"frame_name"
);
const
string
&
frame_name
=
frame_name_attr
->
s
();
int
frame_id
;
if
(
!
frame_name_attr
)
{
return
errors
::
InvalidArgument
(
"Missing frame name for the Enter node: "
,
SummarizeNodeDef
(
*
fanout
.
node
));
}
absl
::
string_view
frame_name
=
frame_name_attr
->
s
();
int
frame_id
;
if
(
frame_name_to_id
.
count
(
frame_name
))
{
frame_id
=
frame_name_to_id
[
frame_name
];
}
else
{
frame_id
=
static_cast
<
int
>
(
frame_name_to_id
.
size
());
frame_name_to_id
[
frame_name
]
=
frame_id
;
}
frame_ids
.
push_back
(
frame_id
);
if
(
frame_name_to_id
->
contains
(
frame_name
))
{
frame_id
=
(
*
frame_name_to_id
)[
frame_name
];
}
else
{
frame_id
=
static_cast
<
int
>
(
frame_name_to_id
->
size
());
(
*
frame_name_to_id
)[
frame_name
]
=
frame_id
;
}
ready_nodes
.
push_back
(
fanout
.
node
);
node_to_frames_
[
fanout
.
node
]
=
std
::
move
(
frame_ids
);
frame_ids
.
push_back
(
frame_id
);
}
}
else
{
// If we've already seen this node before, we need to make sure that
// graph is correct and same nodes doesn't have incoming edges with
// conflicting frames (all inputs must be produces in the same frame).
ready_node_indices
->
push_back
(
fanout_node_index
);
node_to_frames_
[
fanout_node
]
=
std
::
move
(
frame_ids
);
std
::
vector
<
int
>
frame_ids_fanout
=
node_to_frames_
[
fanout
.
node
];
std
::
vector
<
int
>
frame_ids_node
=
node_to_frames_
[
ready_node
];
}
else
{
// If we've already seen this node before, we need to make sure that graph
// is correct and same nodes doesn't have incoming edges with conflicting
// frames (all inputs must be produces in the same frame).
if
(
IsEnter
(
*
fanout
.
node
))
{
frame_ids_fanout
.
pop_back
();
}
if
(
IsExit
(
*
ready_node
))
{
frame_ids_node
.
pop_back
();
}
std
::
vector
<
int
>
frame_ids_fanout
=
node_to_frames_
[
fanout_node
];
std
::
vector
<
int
>
frame_ids_node
=
node_to_frames_
[
ready_node
];
if
(
frame_ids_node
!=
frame_ids_fanout
)
{
return
errors
::
InvalidArgument
(
"Invalid graph: Frame ids for node "
,
ready_node
->
name
(),
" does not match frame ids for it's fanout "
,
fanout
.
node
->
name
());
}
if
(
IsEnter
(
*
fanout_node
))
{
frame_ids_fanout
.
pop_back
();
}
if
(
IsExit
(
*
ready_node
))
{
frame_ids_node
.
pop_back
();
}
if
(
frame_ids_node
!=
frame_ids_fanout
)
{
return
errors
::
InvalidArgument
(
"Invalid graph: Frame ids for node "
,
ready_node
->
name
(),
" does not match frame ids for it's fanout "
,
fanout_node
->
name
());
}
}
return
Status
::
OK
();
};
while
(
!
ready_node_indices
.
empty
())
{
const
int
ready_node_index
=
ready_node_indices
.
front
();
ready_node_indices
.
pop_front
();
const
auto
*
ready_node_view
=
graph_view
.
GetNode
(
ready_node_index
);
const
NodeDef
*
ready_node_def
=
ready_node_view
->
node
();
for
(
const
auto
&
regular_fanouts_port_i
:
ready_node_view
->
GetRegularFanouts
())
{
for
(
const
auto
&
regular_fanout
:
regular_fanouts_port_i
)
{
TF_RETURN_IF_ERROR
(
process_fanout
(
&
frame_name_to_id
,
&
ready_node_indices
,
ready_node_def
,
regular_fanout
.
node_index
()));
}
}
ready_nodes
.
pop_front
();
for
(
const
auto
&
controlled_fanout
:
ready_node_view
->
GetControlledFanouts
())
{
TF_RETURN_IF_ERROR
(
process_fanout
(
&
frame_name_to_id
,
&
ready_node_indices
,
ready_node_def
,
controlled_fanout
.
node_index
()));
}
}
num_frames_
=
static_cast
<
int
>
(
frame_name_to_id
.
size
());
return
Status
::
OK
();
}
Status
FrameView
::
InferFromGraphView
(
const
utils
::
GraphView
&
graph_view
)
{
return
InferFromGraphViewT
(
graph_view
);
}
Status
FrameView
::
InferFromGraphView
(
const
utils
::
MutableGraphView
&
graph_view
)
{
return
InferFromGraphViewT
(
graph_view
);
}
Status
FrameView
::
InferFromGraph
(
const
GraphDef
&
graph
)
{
return
InferFromGraphView
(
GraphView
(
&
graph
));
Status
status
;
utils
::
GraphView
graph_view
(
&
graph
,
&
status
);
TF_RETURN_IF_ERROR
(
status
);
return
InferFromGraphViewT
(
graph_view
);
}
const
std
::
vector
<
int
>&
FrameView
::
Frames
(
const
NodeDef
&
node
)
const
{
...
...
tensorflow/core/grappler/utils/frame.h
浏览文件 @
d15c612f
...
...
@@ -16,10 +16,9 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_GRAPPLER_UTILS_FRAME_H_
#define TENSORFLOW_CORE_GRAPPLER_UTILS_FRAME_H_
#include <unordered_map>
#include "absl/container/flat_hash_map.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/grappler/graph_view.h"
#include "tensorflow/core/grappler/
utils/
graph_view.h"
#include "tensorflow/core/lib/core/status.h"
namespace
tensorflow
{
...
...
@@ -40,7 +39,10 @@ class FrameView {
// Infers nodes execution frames from the GraphView. Returns an error if
// called multiple times.
Status
InferFromGraphView
(
const
GraphView
&
graph_view
);
Status
InferFromGraphView
(
const
utils
::
GraphView
&
graph_view
);
// Infers nodes execution frames from the MutableGraphView. Returns an error
// if called multiple times.
Status
InferFromGraphView
(
const
utils
::
MutableGraphView
&
graph_view
);
// Infers nodes execution by constructing temporary GraphView and passing it
// to InferFromGraphView.
Status
InferFromGraph
(
const
GraphDef
&
graph
);
...
...
@@ -56,6 +58,9 @@ class FrameView {
bool
is_inferred
()
const
{
return
is_inferred_
;
}
private:
template
<
typename
GraphViewT
>
inline
Status
InferFromGraphViewT
(
const
GraphViewT
&
graph_view
);
bool
is_inferred_
;
// true if it was inferred from the graph
int
num_frames_
;
// number of frames present in a graph
absl
::
flat_hash_map
<
const
NodeDef
*
,
std
::
vector
<
int
>>
node_to_frames_
;
...
...
tensorflow/core/grappler/utils/frame_test.cc
浏览文件 @
d15c612f
...
...
@@ -14,8 +14,11 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/utils/frame.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/grappler/utils/graph_view.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
...
...
@@ -23,19 +26,23 @@ namespace tensorflow {
namespace
grappler
{
namespace
{
using
GraphTypes
=
::
testing
::
Types
<
GraphDef
,
utils
::
GraphView
,
utils
::
MutableGraphView
>
;
template
<
typename
T
>
class
FrameViewTest
:
public
::
testing
::
Test
{
protected:
static
NodeDef
CreateNode
(
const
string
&
name
,
const
std
::
vector
<
string
>&
inputs
)
{
NodeDef
CreateNode
(
const
string
&
name
,
const
std
::
vector
<
string
>&
inputs
)
{
return
CreateNode
(
name
,
""
,
""
,
inputs
);
}
static
NodeDef
CreateNode
(
const
string
&
name
,
const
string
&
op
,
const
std
::
vector
<
string
>&
inputs
)
{
NodeDef
CreateNode
(
const
string
&
name
,
const
string
&
op
,
const
std
::
vector
<
string
>&
inputs
)
{
return
CreateNode
(
name
,
op
,
""
,
inputs
);
}
static
NodeDef
CreateNode
(
const
string
&
name
,
const
string
&
op
,
const
string
&
frame
,
const
std
::
vector
<
string
>&
inputs
)
{
NodeDef
CreateNode
(
const
string
&
name
,
const
string
&
op
,
const
string
&
frame
,
const
std
::
vector
<
string
>&
inputs
)
{
NodeDef
node
;
node
.
set_name
(
name
);
if
(
!
op
.
empty
())
{
...
...
@@ -53,30 +60,56 @@ class FrameViewTest : public ::testing::Test {
}
};
TEST_F
(
FrameViewTest
,
NestedLoop
)
{
TYPED_TEST_SUITE
(
FrameViewTest
,
GraphTypes
);
template
<
typename
T
>
void
InferFromGraph
(
FrameView
*
frame_view
,
GraphDef
*
graph
,
bool
valid
)
{
Status
status
;
T
graph_view
(
graph
,
&
status
);
TF_ASSERT_OK
(
status
);
status
=
frame_view
->
InferFromGraphView
(
graph_view
);
if
(
valid
)
{
TF_ASSERT_OK
(
status
);
}
else
{
ASSERT_FALSE
(
status
.
ok
());
}
}
template
<
>
void
InferFromGraph
<
GraphDef
>
(
FrameView
*
frame_view
,
GraphDef
*
graph
,
bool
valid
)
{
Status
status
=
frame_view
->
InferFromGraph
(
*
graph
);
if
(
valid
)
{
TF_ASSERT_OK
(
status
);
}
else
{
ASSERT_FALSE
(
status
.
ok
());
}
}
TYPED_TEST
(
FrameViewTest
,
NestedLoop
)
{
GraphDef
graph
;
// Create a two-level nested loop
*
graph
.
add_node
()
=
CreateNode
(
"0"
,
{});
*
graph
.
add_node
()
=
CreateNode
(
"1"
,
"Enter"
,
"while/context1"
,
{
"0"
});
*
graph
.
add_node
()
=
CreateNode
(
"2"
,
{
"1"
});
*
graph
.
add_node
()
=
CreateNode
(
"3"
,
"Merge"
,
{
"2"
,
"14"
});
*
graph
.
add_node
()
=
CreateNode
(
"4"
,
{
"3"
});
*
graph
.
add_node
()
=
CreateNode
(
"5"
,
"Switch"
,
{
"4"
});
*
graph
.
add_node
()
=
CreateNode
(
"6"
,
{
"5"
});
*
graph
.
add_node
()
=
CreateNode
(
"7"
,
"Enter"
,
"while/context2"
,
{
"6"
});
*
graph
.
add_node
()
=
CreateNode
(
"8"
,
{
"7"
});
*
graph
.
add_node
()
=
CreateNode
(
"9"
,
"Merge"
,
{
"8"
,
"12"
});
*
graph
.
add_node
()
=
CreateNode
(
"10"
,
{
"9"
});
*
graph
.
add_node
()
=
CreateNode
(
"11"
,
"Switch"
,
{
"10"
});
*
graph
.
add_node
()
=
CreateNode
(
"12"
,
"NextIteration"
,
{
"11"
});
*
graph
.
add_node
()
=
CreateNode
(
"13"
,
"Exit"
,
{
"11"
});
*
graph
.
add_node
()
=
CreateNode
(
"14"
,
"NextIteration"
,
{
"13"
});
*
graph
.
add_node
()
=
CreateNode
(
"15"
,
{
"5"
});
*
graph
.
add_node
()
=
CreateNode
(
"16"
,
"Exit"
,
{
"15"
});
*
graph
.
add_node
()
=
CreateNode
(
"17"
,
{
"16"
});
*
graph
.
add_node
()
=
this
->
CreateNode
(
"0"
,
{});
*
graph
.
add_node
()
=
this
->
CreateNode
(
"1"
,
"Enter"
,
"while/context1"
,
{
"0"
});
*
graph
.
add_node
()
=
this
->
CreateNode
(
"2"
,
{
"1"
});
*
graph
.
add_node
()
=
this
->
CreateNode
(
"3"
,
"Merge"
,
{
"2"
,
"14"
});
*
graph
.
add_node
()
=
this
->
CreateNode
(
"4"
,
{
"3"
});
*
graph
.
add_node
()
=
this
->
CreateNode
(
"5"
,
"Switch"
,
{
"4"
});
*
graph
.
add_node
()
=
this
->
CreateNode
(
"6"
,
{
"5"
});
*
graph
.
add_node
()
=
this
->
CreateNode
(
"7"
,
"Enter"
,
"while/context2"
,
{
"6"
});
*
graph
.
add_node
()
=
this
->
CreateNode
(
"8"
,
{
"7"
});
*
graph
.
add_node
()
=
this
->
CreateNode
(
"9"
,
"Merge"
,
{
"8"
,
"12"
});
*
graph
.
add_node
()
=
this
->
CreateNode
(
"10"
,
{
"9"
});
*
graph
.
add_node
()
=
this
->
CreateNode
(
"11"
,
"Switch"
,
{
"10"
});
*
graph
.
add_node
()
=
this
->
CreateNode
(
"12"
,
"NextIteration"
,
{
"11"
});
*
graph
.
add_node
()
=
this
->
CreateNode
(
"13"
,
"Exit"
,
{
"11"
});
*
graph
.
add_node
()
=
this
->
CreateNode
(
"14"
,
"NextIteration"
,
{
"13"
});
*
graph
.
add_node
()
=
this
->
CreateNode
(
"15"
,
{
"5"
});
*
graph
.
add_node
()
=
this
->
CreateNode
(
"16"
,
"Exit"
,
{
"15"
});
*
graph
.
add_node
()
=
this
->
CreateNode
(
"17"
,
{
"16"
});
FrameView
frame_view
;
ASSERT_TRUE
(
frame_view
.
InferFromGraph
(
graph
).
ok
()
);
InferFromGraph
<
TypeParam
>
(
&
frame_view
,
&
graph
,
/*valid=*/
true
);
std
::
unordered_map
<
string
,
std
::
vector
<
int
>>
expected
=
{
{
"0"
,
{}},
{
"1"
,
{
0
}},
{
"2"
,
{
0
}},
{
"3"
,
{
0
}},
...
...
@@ -93,15 +126,16 @@ TEST_F(FrameViewTest, NestedLoop) {
}
}
T
EST_F
(
FrameViewTest
,
MultipleInputsToEnter
)
{
T
YPED_TEST
(
FrameViewTest
,
MultipleInputsToEnter
)
{
GraphDef
graph
;
*
graph
.
add_node
()
=
CreateNode
(
"0"
,
{});
*
graph
.
add_node
()
=
CreateNode
(
"1"
,
{});
*
graph
.
add_node
()
=
CreateNode
(
"2"
,
"Enter"
,
"while/context"
,
{
"0"
,
"1"
});
*
graph
.
add_node
()
=
CreateNode
(
"3"
,
"Exit"
,
{
"2"
});
*
graph
.
add_node
()
=
this
->
CreateNode
(
"0"
,
{});
*
graph
.
add_node
()
=
this
->
CreateNode
(
"1"
,
{});
*
graph
.
add_node
()
=
this
->
CreateNode
(
"2"
,
"Enter"
,
"while/context"
,
{
"0"
,
"1"
});
*
graph
.
add_node
()
=
this
->
CreateNode
(
"3"
,
"Exit"
,
{
"2"
});
FrameView
frame_view
;
ASSERT_TRUE
(
frame_view
.
InferFromGraph
(
graph
).
ok
()
);
InferFromGraph
<
TypeParam
>
(
&
frame_view
,
&
graph
,
/*valid=*/
true
);
std
::
unordered_map
<
string
,
std
::
vector
<
int
>>
expected
=
{
{
"0"
,
{}},
{
"1"
,
{}},
{
"2"
,
{
0
}},
{
"3"
,
{
0
}}};
...
...
@@ -114,16 +148,16 @@ TEST_F(FrameViewTest, MultipleInputsToEnter) {
}
}
T
EST_F
(
FrameViewTest
,
ExitOutput
)
{
T
YPED_TEST
(
FrameViewTest
,
ExitOutput
)
{
GraphDef
graph
;
*
graph
.
add_node
()
=
CreateNode
(
"0"
,
{});
*
graph
.
add_node
()
=
CreateNode
(
"1"
,
"Enter"
,
"while/context"
,
{
"0"
});
*
graph
.
add_node
()
=
CreateNode
(
"2"
,
"Exit"
,
{
"1"
});
*
graph
.
add_node
()
=
CreateNode
(
"3"
,
{});
*
graph
.
add_node
()
=
CreateNode
(
"4"
,
{
"2"
,
"3"
});
*
graph
.
add_node
()
=
this
->
CreateNode
(
"0"
,
{});
*
graph
.
add_node
()
=
this
->
CreateNode
(
"1"
,
"Enter"
,
"while/context"
,
{
"0"
});
*
graph
.
add_node
()
=
this
->
CreateNode
(
"2"
,
"Exit"
,
{
"1"
});
*
graph
.
add_node
()
=
this
->
CreateNode
(
"3"
,
{});
*
graph
.
add_node
()
=
this
->
CreateNode
(
"4"
,
{
"2"
,
"3"
});
FrameView
frame_view
;
ASSERT_TRUE
(
frame_view
.
InferFromGraph
(
graph
).
ok
()
);
InferFromGraph
<
TypeParam
>
(
&
frame_view
,
&
graph
,
/*valid=*/
true
);
std
::
unordered_map
<
string
,
std
::
vector
<
int
>>
expected
=
{
{
"0"
,
{}},
{
"1"
,
{
0
}},
{
"2"
,
{
0
}},
{
"3"
,
{}},
{
"4"
,
{}}};
...
...
@@ -136,21 +170,21 @@ TEST_F(FrameViewTest, ExitOutput) {
}
}
T
EST_F
(
FrameViewTest
,
MultipleEnterNodes
)
{
T
YPED_TEST
(
FrameViewTest
,
MultipleEnterNodes
)
{
GraphDef
graph
;
*
graph
.
add_node
()
=
CreateNode
(
"0"
,
{});
*
graph
.
add_node
()
=
CreateNode
(
"1"
,
"Enter"
,
"while/context"
,
{
"0"
});
*
graph
.
add_node
()
=
CreateNode
(
"2"
,
{
"1"
});
*
graph
.
add_node
()
=
CreateNode
(
"5"
,
{});
*
graph
.
add_node
()
=
CreateNode
(
"4"
,
"Enter"
,
"while/context"
,
{
"5"
});
*
graph
.
add_node
()
=
CreateNode
(
"3"
,
{
"4"
,
"2"
});
*
graph
.
add_node
()
=
CreateNode
(
"6"
,
"Merge"
,
{
"3"
,
"8"
});
*
graph
.
add_node
()
=
CreateNode
(
"7"
,
"Switch"
,
{
"6"
});
*
graph
.
add_node
()
=
CreateNode
(
"8"
,
"NextIteration"
,
{
"7"
});
*
graph
.
add_node
()
=
CreateNode
(
"9"
,
"Exit"
,
{
"7"
});
*
graph
.
add_node
()
=
this
->
CreateNode
(
"0"
,
{});
*
graph
.
add_node
()
=
this
->
CreateNode
(
"1"
,
"Enter"
,
"while/context"
,
{
"0"
});
*
graph
.
add_node
()
=
this
->
CreateNode
(
"2"
,
{
"1"
});
*
graph
.
add_node
()
=
this
->
CreateNode
(
"5"
,
{});
*
graph
.
add_node
()
=
this
->
CreateNode
(
"4"
,
"Enter"
,
"while/context"
,
{
"5"
});
*
graph
.
add_node
()
=
this
->
CreateNode
(
"3"
,
{
"4"
,
"2"
});
*
graph
.
add_node
()
=
this
->
CreateNode
(
"6"
,
"Merge"
,
{
"3"
,
"8"
});
*
graph
.
add_node
()
=
this
->
CreateNode
(
"7"
,
"Switch"
,
{
"6"
});
*
graph
.
add_node
()
=
this
->
CreateNode
(
"8"
,
"NextIteration"
,
{
"7"
});
*
graph
.
add_node
()
=
this
->
CreateNode
(
"9"
,
"Exit"
,
{
"7"
});
FrameView
frame_view
;
ASSERT_TRUE
(
frame_view
.
InferFromGraph
(
graph
).
ok
()
);
InferFromGraph
<
TypeParam
>
(
&
frame_view
,
&
graph
,
/*valid=*/
true
);
std
::
unordered_map
<
string
,
std
::
vector
<
int
>>
expected
=
{
{
"0"
,
{}},
{
"1"
,
{
0
}},
{
"2"
,
{
0
}},
{
"3"
,
{
0
}},
{
"4"
,
{
0
}},
...
...
@@ -164,15 +198,15 @@ TEST_F(FrameViewTest, MultipleEnterNodes) {
}
}
T
EST_F
(
FrameViewTest
,
ConflictingFrames
)
{
T
YPED_TEST
(
FrameViewTest
,
ConflictingFrames
)
{
GraphDef
graph
;
*
graph
.
add_node
()
=
CreateNode
(
"0"
,
{});
*
graph
.
add_node
()
=
CreateNode
(
"1"
,
"Enter"
,
"while/context1"
,
{
"0"
});
*
graph
.
add_node
()
=
CreateNode
(
"2"
,
"Enter"
,
"while/context2"
,
{
"1"
});
*
graph
.
add_node
()
=
CreateNode
(
"3"
,
{
"1"
,
"2"
});
*
graph
.
add_node
()
=
this
->
CreateNode
(
"0"
,
{});
*
graph
.
add_node
()
=
this
->
CreateNode
(
"1"
,
"Enter"
,
"while/context1"
,
{
"0"
});
*
graph
.
add_node
()
=
this
->
CreateNode
(
"2"
,
"Enter"
,
"while/context2"
,
{
"1"
});
*
graph
.
add_node
()
=
this
->
CreateNode
(
"3"
,
{
"1"
,
"2"
});
FrameView
frame_view
;
ASSERT_FALSE
(
frame_view
.
InferFromGraph
(
graph
).
ok
()
);
InferFromGraph
<
TypeParam
>
(
&
frame_view
,
&
graph
,
/*valid=*/
false
);
}
}
// namespace
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录