diff --git a/tensorflow/compiler/tests/sort_ops_test.py b/tensorflow/compiler/tests/sort_ops_test.py
index 7c50a53a620ca0724c78a4abb321cb0d9fd3ce4e..bbadb955356e0f9c7736b4667628d56892e1f684 100644
--- a/tensorflow/compiler/tests/sort_ops_test.py
+++ b/tensorflow/compiler/tests/sort_ops_test.py
@@ -237,23 +237,12 @@ class XlaSortOpTest(xla_test.XLATestCase, parameterized.TestCase):
     self._assertOpOutputMatchesExpected(wrap_sort, inputs, expected=inputs)
 
-  @parameterized.product(
-      dtype=[
-          dtypes.bfloat16.as_numpy_dtype,
-          np.float16,
-          np.float32,
-          np.float64,
-          np.int32,
-          np.uint32,
-          np.int64,
-          np.uint64,
-          np.uint8,
-          np.int8,
-      ],
-      rank=[1, 2, 3],
-  )
-  def testTopK(self, dtype, rank):
-    if dtype in self.numeric_types:
+  def testTopK(self):
+    supported_types = set([
+        dtypes.bfloat16.as_numpy_dtype, np.float16, np.float32, np.float64,
+        np.int32, np.uint32, np.int64, np.uint64, np.uint8, np.int8,
+    ])
+    for dtype in supported_types.intersection(self.numeric_types):
       # Use small input size for bfloat16. Otherwise, we'll get duplicate values
       # after conversion to bfloat16, so the possible resulting index array is
       # no longer unique.
@@ -266,26 +255,53 @@ class XlaSortOpTest(xla_test.XLATestCase, parameterized.TestCase):
       else:
         array_size = 200 * 1000
         k_options = [0, 1, 2, 10, 20, 100, 1000, 200 * 1000]
+      for x in [np.arange(array_size)]:
+        np.random.shuffle(x)
+        for k in k_options:
+          indices = x.argsort()[::-1][:k]
 
-      # Tile array to tensor of specified rank, then shuffle along the last dim
-      x = np.arange(array_size)
-      x = np.tile(x, (2,) * (rank - 1) + (1,))
-      np.apply_along_axis(np.random.shuffle, -1, x)
+          def topk(v, k=k):
+            return nn_ops.top_k(v, k=k, sorted=True)
 
-      sorted_indices = x.argsort(axis=-1)[..., ::-1]
-      sorted_values = np.sort(x, axis=-1)[..., ::-1]
-      for k in k_options:
-        indices = sorted_indices[..., :k]
-        expected = sorted_values[..., :k]
+          self._assertOpOutputMatchesExpected(
+              topk, [x.astype(dtype)],
+              expected=[x[indices].astype(dtype), indices])
 
-        def topk(v, k=k):
-          return nn_ops.top_k(v, k=k, sorted=True)
+  @parameterized.named_parameters(
+      ("HalfPrecision", dtypes.bfloat16.as_numpy_dtype),
+      ("HalfFloatPrecision", np.float16),
+      ("SinglePrecision", np.float32),
+      ("DoublePrecision", np.float64),
+      ("Int32", np.int32),
+      ("UnsignedInt32", np.uint32),
+      ("Int64", np.int64),
+      ("UnsignedInt64", np.uint64),
+  )
+  def testTopK2D(self, dtype):
+    if dtype in self.numeric_types:
+      # Use small input size for bfloat16. Otherwise, we'll get duplicate values
+      # after conversion to bfloat16, so the possible resulting index array is
+      # no longer unique.
+      if dtype in (dtypes.bfloat16.as_numpy_dtype, np.float16):
+        array_size = 10
+        k_options = [0, 1, 2, 10]
+      else:
+        array_size = 200 * 1000
+        k_options = [0, 1, 2, 10, 20, 100, 1000, 200 * 1000]
+      batch = 16
+      for x in [np.arange(batch * array_size)]:
+        np.random.shuffle(x)
+        x = np.reshape(x, [batch, array_size])
+        for k in k_options:
+          indices = x.argsort(axis=1)[::, -1:-k - 1:-1]
+          expected = np.sort(x, axis=1)[::, -1:-k - 1:-1]
 
-        self._assertOpOutputMatchesExpected(
-            topk,
-            [x.astype(dtype)],
-            expected=[expected.astype(dtype), indices],
-        )
+          def topk(v, k=k):
+            return nn_ops.top_k(v, k=k, sorted=True)
+
+          self._assertOpOutputMatchesExpected(
+              topk, [x.astype(dtype)],
+              expected=[expected.astype(dtype), indices])
 
   def testTopKZeros(self):
     """Tests that positive and negative zeros sort correctly."""
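The rewritten tests derive their expected values and indices directly from numpy. For reference, here is a self-contained sketch of the check testTopK2D performs; the sizes (batch=4, array_size=16, k=3) are illustrative, not the test's own:

import numpy as np

batch, array_size, k = 4, 16, 3
x = np.arange(batch * array_size)  # unique values, so indices are unique too
np.random.shuffle(x)
x = np.reshape(x, [batch, array_size])

# Ascending argsort, then walk the last k columns backwards, so column 0
# holds the position of each row's maximum.
indices = x.argsort(axis=1)[::, -1:-k - 1:-1]
expected = np.sort(x, axis=1)[::, -1:-k - 1:-1]

assert expected.shape == (batch, k)
assert (np.take_along_axis(x, indices, axis=1) == expected).all()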
diff --git a/tensorflow/compiler/xla/service/topk_rewriter.cc b/tensorflow/compiler/xla/service/topk_rewriter.cc
index 8117ce4599c92c552320db1334be0c7f64a62e87..37e5bbca234e57f071ef97d4b70bc4e43eaeaade 100644
--- a/tensorflow/compiler/xla/service/topk_rewriter.cc
+++ b/tensorflow/compiler/xla/service/topk_rewriter.cc
@@ -196,6 +196,8 @@ std::optional<int64_t> TopkRewriter::SortIsInTopK(HloInstruction* inst) {
     return std::nullopt;
   }
   const int64_t sort_dim = sort->sort_dimension();
+  const int64_t batch_dim = sort_dim == 1 ? 0 : 1;
+  const bool has_batch = data->shape().rank() == 2;
 
   bool supported = true;
   std::optional<int64_t> k;
@@ -220,15 +222,10 @@ std::optional<int64_t> TopkRewriter::SortIsInTopK(HloInstruction* inst) {
       supported = false;
       break;
     }
-    for (int64_t i = 0; i < slice->slice_limits().size(); ++i) {
-      if (i != sort_dim &&
-          slice->slice_limits(i) != slice->operand(0)->shape().dimensions(i)) {
-        // Slicing along a non-sort dimension isn't supported.
-        supported = false;
-        break;
-      }
-    }
-    if (!supported) {
+    if (has_batch && slice->slice_limits(batch_dim) !=
+                         slice->operand(0)->shape().dimensions(batch_dim)) {
+      // Slicing along the batch dimension isn't supported.
+      supported = false;
       break;
     }
     if (k == std::nullopt) {
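Combined with the rank check added to TransformToCustomCall below, the matcher now only accepts a slice of a sort when the input is rank 1 or 2 and the batch dimension of a rank-2 input is left intact. A plain-Python sketch of that decision logic, with shapes as tuples and a hypothetical helper name (this is not XLA API):

def slice_is_topk(input_shape, slice_limits, sort_dim):
  rank = len(input_shape)
  if rank not in (1, 2):  # rank > 2 inputs are no longer rewritten
    return None
  batch_dim = 0 if sort_dim == 1 else 1
  if rank == 2 and slice_limits[batch_dim] != input_shape[batch_dim]:
    return None  # slicing along the batch dimension isn't supported
  return slice_limits[sort_dim]  # the k of the top-k

assert slice_is_topk((16, 1000), (16, 5), sort_dim=1) == 5        # rewritten
assert slice_is_topk((16, 1000), (8, 5), sort_dim=1) is None      # batch sliced
assert slice_is_topk((3, 8, 100), (3, 8, 5), sort_dim=2) is None  # rank 3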
@@ -260,57 +257,29 @@ StatusOr<bool> TopkRewriter::TransformToCustomCall(
     HloSortInstruction* sort = DynCast<HloSortInstruction>(inst);
     HloInstruction* data = sort->mutable_operand(0);
     const PrimitiveType element_type = data->shape().element_type();
-    const Shape data_shape = data->shape();
-    if (element_type != F32 && element_type != BF16) {
+    if ((data->shape().rank() != 1 && data->shape().rank() != 2) ||
+        (element_type != F32 && element_type != BF16)) {
       continue;
     }
-    // Sort dimension must be the first or last dimension.
     const int64_t sort_dim = sort->sort_dimension();
-    if (sort_dim != 0 && sort_dim != data_shape.rank() - 1) {
-      continue;
-    }
+    const int64_t batch_dim = sort_dim == 1 ? 0 : 1;
+    const bool has_batch = data->shape().rank() == 2;
 
     // Profitability check.
     if (!is_profitable_to_convert_(sort, *k)) {
       continue;
     }
 
-    HloInstruction* input = data;
-    const bool has_batch = data_shape.rank() >= 2;
-    const int64_t input_size = data_shape.dimensions(sort_dim);
-    int64_t batch_size = 1;
-    Shape topk_input_shape;
-
-    if (has_batch) {
-      // The TopK custom call expects either a 1d tensor or a 2d tensor with
-      // the last dimension being the sort dimension. An input with rank > 2
-      // is reshaped into a 2d tensor by combining non-sort dimensions into a
-      // single batch dimension. The original non-sort dimensions are
-      // restored for the outputs with another reshape after the custom call.
-      batch_size =
-          ShapeUtil::ElementsIn(data_shape) / data_shape.dimensions(sort_dim);
-      topk_input_shape =
-          ShapeUtil::MakeShape(element_type, {batch_size, input_size});
-
-      if (data_shape.rank() > 2) {
-        // Reshape to 2d.
-        input = comp->AddInstruction(HloInstruction::CreateReshape(
-            sort_dim == 0
-                ? ShapeUtil::MakeShape(element_type, {input_size, batch_size})
-                : ShapeUtil::MakeShape(element_type,
-                                       {batch_size, input_size}),
-            input));
-      }
-
-      if (sort_dim == 0) {
-        // Transpose for the custom call when sorting the first dimension.
-        input = comp->AddInstruction(
-            HloInstruction::CreateTranspose(topk_input_shape, input, {1, 0}));
-      }
-    } else {
-      topk_input_shape = data_shape;
+    const int64_t batch_size =
+        has_batch ? sort->operand(0)->shape().dimensions(batch_dim) : 1;
+    const int64_t input_size = sort->operand(0)->shape().dimensions(sort_dim);
+    HloInstruction* input = sort->mutable_operand(0);
+    if (has_batch && sort_dim == 0) {
+      input = comp->AddInstruction(HloInstruction::CreateTranspose(
+          ShapeUtil::MakeShape(element_type, {batch_size, input_size}), input,
+          {1, 0}));
     }
 
     Shape topk_shape =
@@ -331,28 +300,13 @@ StatusOr<bool> TopkRewriter::TransformToCustomCall(
         comp->AddInstruction(HloInstruction::CreateGetTupleElement(
             topk->shape().tuple_shapes(1), topk, 1));
 
-    if (has_batch) {
-      if (sort_dim == 0) {
-        // Transpose back.
-        value_gte = comp->AddInstruction(HloInstruction::CreateTranspose(
-            ShapeUtil::MakeShape(element_type, {k.value(), batch_size}),
-            value_gte, {1, 0}));
-        index_gte = comp->AddInstruction(HloInstruction::CreateTranspose(
-            ShapeUtil::MakeShape(S32, {k.value(), batch_size}), index_gte,
-            {1, 0}));
-      }
-      if (data_shape.rank() > 2) {
-        // Reshape back.
-        Shape value_shape = data_shape;
-        value_shape.set_dimensions(sort_dim, k.value());
-        value_gte = comp->AddInstruction(
-            HloInstruction::CreateReshape(value_shape, value_gte));
-
-        Shape index_shape =
-            ShapeUtil::MakeShape(S32, value_shape.dimensions());
-        index_gte = comp->AddInstruction(
-            HloInstruction::CreateReshape(index_shape, index_gte));
-      }
+    if (has_batch && sort_dim == 0) {
+      value_gte = comp->AddInstruction(HloInstruction::CreateTranspose(
+          ShapeUtil::MakeShape(element_type, {k.value(), batch_size}),
+          value_gte, {1, 0}));
+      index_gte = comp->AddInstruction(HloInstruction::CreateTranspose(
+          ShapeUtil::MakeShape(S32, {k.value(), batch_size}), index_gte,
+          {1, 0}));
     }
 
     for (HloInstruction* user : sort->users()) {
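With the reshape path gone, the only layout fixup TransformToCustomCall still performs is the transpose pair around the custom call when sort_dim == 0. A numpy sketch of that round trip; topk_last_dim is a hypothetical stand-in for the TopK custom call (which operates on the last dimension), not an XLA API:

import numpy as np

def topk_last_dim(x, k):
  # Descending top-k along the last dimension, like the custom call.
  idx = np.argsort(-x, axis=-1)[..., :k]
  return np.take_along_axis(x, idx, axis=-1), idx

n, batch, k = 100, 8, 5
x = np.random.randn(n, batch)  # sort_dim == 0: the sorted axis comes first

values, indices = topk_last_dim(x.T, k)  # transpose in: [batch, n]
values, indices = values.T, indices.T    # transpose back out: [k, batch]

assert values.shape == (k, batch) and indices.shape == (k, batch)
assert (values == np.sort(x, axis=0)[::-1][:k]).all()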
diff --git a/tensorflow/compiler/xla/service/topk_rewriter_test.cc b/tensorflow/compiler/xla/service/topk_rewriter_test.cc
index cfd1c7bb650090f6a2a3e01df240c0bec504d6e4..cc31870e462e4d1a3b0c787004d60bcfbb1e1d2c 100644
--- a/tensorflow/compiler/xla/service/topk_rewriter_test.cc
+++ b/tensorflow/compiler/xla/service/topk_rewriter_test.cc
@@ -326,42 +326,6 @@ ENTRY cluster {
   EXPECT_THAT(cc->custom_call_target(), "TopK");
 }
 
-TEST_F(TopkRewriterTest, RewriteReshape) {
-  const std::string hlo_string = R"(
-HloModule module
-)" + getComparator() + R"(
-ENTRY cluster {
-  %arg_tuple.1 = f32[3,8,1234567] parameter(0)
-  %iota.4 = s32[3,8,1234567] iota(), iota_dimension=2
-  %sort.27 = (f32[3,8,1234567], s32[3,8,1234567]) sort(%arg_tuple.1, %iota.4),
-    dimensions={2}, is_stable=true, to_apply=%compare
-  %get-tuple-element.28 = f32[3,8,1234567] get-tuple-element(%sort.27), index=0
-  %slice.29 = f32[3,8,5] slice(%get-tuple-element.28), slice={[0:3], [0:8], [0:5]}
-  %get-tuple-element.30 = s32[3,8,1234567] get-tuple-element(%sort.27), index=1
-  %slice.31 = s32[3,8,5] slice(%get-tuple-element.30), slice={[0:3], [0:8], [0:5]}
-  ROOT %tuple.32 = (f32[3,8,5], s32[3,8,5]) tuple(%slice.29, %slice.31)
-})";
-  TF_ASSERT_OK_AND_ASSIGN(auto module,
-                          ParseAndReturnVerifiedModule(hlo_string));
-  TopkRewriter rewriter(
-      [](const HloSortInstruction*, int64_t) { return true; });
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, rewriter.Run(module.get()));
-  TF_ASSERT_OK(HloDCE().Run(module.get()).status());
-  EXPECT_TRUE(changed);
-  EXPECT_THAT(module->entry_computation()->root_instruction(),
-              GmockMatch(m::Tuple(
-                  m::Reshape(m::GetTupleElement(
-                      m::CustomCall(m::Reshape(m::Parameter(0))), 0)),
-                  m::Reshape(m::GetTupleElement(
-                      m::CustomCall(m::Reshape(m::Parameter(0))), 1)))));
-  const HloInstruction* cc = module->entry_computation()
-                                 ->root_instruction()
-                                 ->operand(0)
-                                 ->operand(0)
-                                 ->operand(0);
-  EXPECT_THAT(cc->custom_call_target(), "TopK");
-}
-
 TEST_F(TopkRewriterTest, RewriteNoIota) {
   const std::string hlo_string = R"(
 HloModule module