From 072494650d52f7f0f0385e33e1d2593843a788f6 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 15 Mar 2017 08:58:57 -0800 Subject: [PATCH] [XLA] Add test case for nested while loops. Change: 150204362 --- tensorflow/compiler/xla/tests/while_test.cc | 68 +++++++++++++++++++++ 1 file changed, 68 insertions(+) diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc index 086e1adccd8..4cff1990865 100644 --- a/tensorflow/compiler/xla/tests/while_test.cc +++ b/tensorflow/compiler/xla/tests/while_test.cc @@ -369,6 +369,74 @@ TEST_F(WhileTest, WhileWithPrngScalarResult) { } } +// Tests nested while loops. +// +// int32 result = 0; +// while (result < 30) { +// int i = 0; +// while (i < 7) { +// result = result + 2; +// i = i + 1; +// } +// } +XLA_TEST_F(WhileTest, NestedWhileWithScalarResult) { + auto outer_result_shape = ShapeUtil::MakeShape(S32, {}); + auto inner_result_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(S32, {})}); + + Computation inner_condition; + { + ComputationBuilder builder(client_, "inner_condition"); + auto params = builder.Parameter(0, inner_result_shape, "prev"); + auto i = builder.GetTupleElement(params, 0); + builder.Lt(i, builder.ConstantR0(7)); + inner_condition = builder.Build().ConsumeValueOrDie(); + } + + // Creates a computation for the outer loop condition: + // repeat while result < 30. + Computation outer_condition; + { + ComputationBuilder builder(client_, "outer_condition"); + auto prev = builder.Parameter(0, outer_result_shape, "prev"); + builder.Lt(prev, builder.ConstantR0(30)); + outer_condition = builder.Build().ConsumeValueOrDie(); + } + + // Creates a computation for the inner loop body: add 1 to `i`, and add 2 to + // `result`. + Computation inner_body; + { + ComputationBuilder builder(client_, "inner_body"); + auto params = builder.Parameter(0, inner_result_shape, "prev"); + auto i = builder.GetTupleElement(params, 0); + auto result = builder.GetTupleElement(params, 1); + i = builder.Add(builder.ConstantR0(1), i); + result = builder.Add(builder.ConstantR0(2), result); + auto output = builder.Tuple({i, result}); + inner_body = builder.Build().ConsumeValueOrDie(); + } + + // Creates a computation for the outer loop: run the inner loop with i = 0. + Computation outer_body; + { + ComputationBuilder builder(client_, "outer_body"); + auto prev = builder.Parameter(0, outer_result_shape, "prev"); + auto init = builder.Tuple({builder.ConstantR0(0), prev}); + auto result = builder.While(inner_condition, inner_body, init); + auto output = builder.GetTupleElement(result, 1); + outer_body = builder.Build().ConsumeValueOrDie(); + } + + // Create a While node with computations for the condition and the body. + ComputationBuilder builder(client_, TestName()); + auto init = builder.ConstantR0(0); + auto result = builder.While(outer_condition, outer_body, init); + auto shape = builder.GetShape(result).ConsumeValueOrDie(); + + ComputeAndCompareR0(&builder, 42, {}); +} + void BM_WhileLoop(int num_iters) { // Benchmark a simple kernel to measure while loop overheads. tensorflow::testing::StopTiming(); -- GitLab