Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
xxadev
tensorflow
提交
07249465
T
tensorflow
项目概览
xxadev
/
tensorflow
与 Fork 源项目一致
从无法访问的项目Fork
通知
3
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
T
tensorflow
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
07249465
编写于
3月 15, 2017
作者:
P
Peter Hawkins
提交者:
TensorFlower Gardener
3月 15, 2017
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[XLA] Add test case for nested while loops.
Change: 150204362
上级
500277ad
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
68 addition
and
0 deletion
+68
-0
tensorflow/compiler/xla/tests/while_test.cc
tensorflow/compiler/xla/tests/while_test.cc
+68
-0
未找到文件。
tensorflow/compiler/xla/tests/while_test.cc
浏览文件 @
07249465
...
...
@@ -369,6 +369,74 @@ TEST_F(WhileTest, WhileWithPrngScalarResult) {
}
}
// Tests nested while loops.
//
// int32 result = 0;
// while (result < 30) {
// int i = 0;
// while (i < 7) {
// result = result + 2;
// i = i + 1;
// }
// }
XLA_TEST_F
(
WhileTest
,
NestedWhileWithScalarResult
)
{
auto
outer_result_shape
=
ShapeUtil
::
MakeShape
(
S32
,
{});
auto
inner_result_shape
=
ShapeUtil
::
MakeTupleShape
(
{
ShapeUtil
::
MakeShape
(
S32
,
{}),
ShapeUtil
::
MakeShape
(
S32
,
{})});
Computation
inner_condition
;
{
ComputationBuilder
builder
(
client_
,
"inner_condition"
);
auto
params
=
builder
.
Parameter
(
0
,
inner_result_shape
,
"prev"
);
auto
i
=
builder
.
GetTupleElement
(
params
,
0
);
builder
.
Lt
(
i
,
builder
.
ConstantR0
<
int32
>
(
7
));
inner_condition
=
builder
.
Build
().
ConsumeValueOrDie
();
}
// Creates a computation for the outer loop condition:
// repeat while result < 30.
Computation
outer_condition
;
{
ComputationBuilder
builder
(
client_
,
"outer_condition"
);
auto
prev
=
builder
.
Parameter
(
0
,
outer_result_shape
,
"prev"
);
builder
.
Lt
(
prev
,
builder
.
ConstantR0
<
int32
>
(
30
));
outer_condition
=
builder
.
Build
().
ConsumeValueOrDie
();
}
// Creates a computation for the inner loop body: add 1 to `i`, and add 2 to
// `result`.
Computation
inner_body
;
{
ComputationBuilder
builder
(
client_
,
"inner_body"
);
auto
params
=
builder
.
Parameter
(
0
,
inner_result_shape
,
"prev"
);
auto
i
=
builder
.
GetTupleElement
(
params
,
0
);
auto
result
=
builder
.
GetTupleElement
(
params
,
1
);
i
=
builder
.
Add
(
builder
.
ConstantR0
<
int32
>
(
1
),
i
);
result
=
builder
.
Add
(
builder
.
ConstantR0
<
int32
>
(
2
),
result
);
auto
output
=
builder
.
Tuple
({
i
,
result
});
inner_body
=
builder
.
Build
().
ConsumeValueOrDie
();
}
// Creates a computation for the outer loop: run the inner loop with i = 0.
Computation
outer_body
;
{
ComputationBuilder
builder
(
client_
,
"outer_body"
);
auto
prev
=
builder
.
Parameter
(
0
,
outer_result_shape
,
"prev"
);
auto
init
=
builder
.
Tuple
({
builder
.
ConstantR0
<
int32
>
(
0
),
prev
});
auto
result
=
builder
.
While
(
inner_condition
,
inner_body
,
init
);
auto
output
=
builder
.
GetTupleElement
(
result
,
1
);
outer_body
=
builder
.
Build
().
ConsumeValueOrDie
();
}
// Create a While node with computations for the condition and the body.
ComputationBuilder
builder
(
client_
,
TestName
());
auto
init
=
builder
.
ConstantR0
<
int32
>
(
0
);
auto
result
=
builder
.
While
(
outer_condition
,
outer_body
,
init
);
auto
shape
=
builder
.
GetShape
(
result
).
ConsumeValueOrDie
();
ComputeAndCompareR0
<
int32
>
(
&
builder
,
42
,
{});
}
void
BM_WhileLoop
(
int
num_iters
)
{
// Benchmark a simple kernel to measure while loop overheads.
tensorflow
::
testing
::
StopTiming
();
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录