提交 07249465 编写于 作者: P Peter Hawkins 提交者: TensorFlower Gardener

[XLA] Add test case for nested while loops.

Change: 150204362
上级 500277ad
......@@ -369,6 +369,74 @@ TEST_F(WhileTest, WhileWithPrngScalarResult) {
}
}
// Tests nested while loops.
//
// int32 result = 0;
// while (result < 30) {
// int i = 0;
// while (i < 7) {
// result = result + 2;
// i = i + 1;
// }
// }
XLA_TEST_F(WhileTest, NestedWhileWithScalarResult) {
auto outer_result_shape = ShapeUtil::MakeShape(S32, {});
auto inner_result_shape = ShapeUtil::MakeTupleShape(
{ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(S32, {})});
Computation inner_condition;
{
ComputationBuilder builder(client_, "inner_condition");
auto params = builder.Parameter(0, inner_result_shape, "prev");
auto i = builder.GetTupleElement(params, 0);
builder.Lt(i, builder.ConstantR0<int32>(7));
inner_condition = builder.Build().ConsumeValueOrDie();
}
// Creates a computation for the outer loop condition:
// repeat while result < 30.
Computation outer_condition;
{
ComputationBuilder builder(client_, "outer_condition");
auto prev = builder.Parameter(0, outer_result_shape, "prev");
builder.Lt(prev, builder.ConstantR0<int32>(30));
outer_condition = builder.Build().ConsumeValueOrDie();
}
// Creates a computation for the inner loop body: add 1 to `i`, and add 2 to
// `result`.
Computation inner_body;
{
ComputationBuilder builder(client_, "inner_body");
auto params = builder.Parameter(0, inner_result_shape, "prev");
auto i = builder.GetTupleElement(params, 0);
auto result = builder.GetTupleElement(params, 1);
i = builder.Add(builder.ConstantR0<int32>(1), i);
result = builder.Add(builder.ConstantR0<int32>(2), result);
auto output = builder.Tuple({i, result});
inner_body = builder.Build().ConsumeValueOrDie();
}
// Creates a computation for the outer loop: run the inner loop with i = 0.
Computation outer_body;
{
ComputationBuilder builder(client_, "outer_body");
auto prev = builder.Parameter(0, outer_result_shape, "prev");
auto init = builder.Tuple({builder.ConstantR0<int32>(0), prev});
auto result = builder.While(inner_condition, inner_body, init);
auto output = builder.GetTupleElement(result, 1);
outer_body = builder.Build().ConsumeValueOrDie();
}
// Create a While node with computations for the condition and the body.
ComputationBuilder builder(client_, TestName());
auto init = builder.ConstantR0<int32>(0);
auto result = builder.While(outer_condition, outer_body, init);
auto shape = builder.GetShape(result).ConsumeValueOrDie();
ComputeAndCompareR0<int32>(&builder, 42, {});
}
void BM_WhileLoop(int num_iters) {
// Benchmark a simple kernel to measure while loop overheads.
tensorflow::testing::StopTiming();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册