Commit 95abdf8d authored by Kevin Chen, committed by TensorFlower Gardener

Rollback of: Update TopkRewriter to handle tensors with rank > 2

Broke argsort on GPU.

PiperOrigin-RevId: 561136138
Parent 26612c86
@@ -237,23 +237,12 @@ class XlaSortOpTest(xla_test.XLATestCase, parameterized.TestCase):
     self._assertOpOutputMatchesExpected(wrap_sort, inputs, expected=inputs)
 
-  @parameterized.product(
-      dtype=[
-          dtypes.bfloat16.as_numpy_dtype,
-          np.float16,
-          np.float32,
-          np.float64,
-          np.int32,
-          np.uint32,
-          np.int64,
-          np.uint64,
-          np.uint8,
-          np.int8,
-      ],
-      rank=[1, 2, 3],
-  )
-  def testTopK(self, dtype, rank):
-    if dtype in self.numeric_types:
+  def testTopK(self):
+    supported_types = set([
+        dtypes.bfloat16.as_numpy_dtype, np.float16, np.float32, np.float64,
+        np.int32, np.uint32, np.int64, np.uint64, np.uint8, np.int8,
+    ])
+    for dtype in supported_types.intersection(self.numeric_types):
       # Use small input size for bfloat16. Otherwise, we'll get duplicate values
       # after conversion to bfloat16, so the possible resulting index array is
       # no longer unique.
@@ -266,26 +255,53 @@ class XlaSortOpTest(xla_test.XLATestCase, parameterized.TestCase):
       else:
         array_size = 200 * 1000
         k_options = [0, 1, 2, 10, 20, 100, 1000, 200 * 1000]
 
-      # Tile array to tensor of specified rank, then shuffle along the last dim
-      x = np.arange(array_size)
-      x = np.tile(x, (2,) * (rank - 1) + (1,))
-      np.apply_along_axis(np.random.shuffle, -1, x)
-      sorted_indices = x.argsort(axis=-1)[..., ::-1]
-      sorted_values = np.sort(x, axis=-1)[..., ::-1]
-      for k in k_options:
-        indices = sorted_indices[..., :k]
-        expected = sorted_values[..., :k]
-
-        def topk(v, k=k):
-          return nn_ops.top_k(v, k=k, sorted=True)
-
-        self._assertOpOutputMatchesExpected(
-            topk, [x.astype(dtype)],
-            expected=[expected.astype(dtype), indices])
+      for x in [np.arange(array_size)]:
+        np.random.shuffle(x)
+        for k in k_options:
+          indices = x.argsort()[::-1][:k]
+
+          def topk(v, k=k):
+            return nn_ops.top_k(v, k=k, sorted=True)
+
+          self._assertOpOutputMatchesExpected(
+              topk, [x.astype(dtype)],
+              expected=[x[indices].astype(dtype), indices])
+
+  @parameterized.named_parameters(
+      ("HalfPrecision", dtypes.bfloat16.as_numpy_dtype),
+      ("HalfFloatPrecision", np.float16),
+      ("SinglePrecision", np.float32),
+      ("DoublePrecision", np.float64),
+      ("Int32", np.int32),
+      ("UnsignedInt32", np.uint32),
+      ("Int64", np.int64),
+      ("UnsignedInt64", np.uint64),
+  )
+  def testTopK2D(self, dtype):
+    if dtype in self.numeric_types:
+      # Use small input size for bfloat16. Otherwise, we'll get duplicate values
+      # after conversion to bfloat16, so the possible resulting index array is
+      # no longer unique.
+      if dtype in (dtypes.bfloat16.as_numpy_dtype, np.float16):
+        array_size = 10
+        k_options = [0, 1, 2, 10]
+      else:
+        array_size = 200 * 1000
+        k_options = [0, 1, 2, 10, 20, 100, 1000, 200 * 1000]
+      batch = 16
+      for x in [np.arange(batch * array_size)]:
+        np.random.shuffle(x)
+        x = np.reshape(x, [batch, array_size])
+        for k in k_options:
+          indices = x.argsort(axis=1)[::, -1:-k - 1:-1]
+          expected = np.sort(x, axis=1)[::, -1:-k - 1:-1]
+
+          def topk(v, k=k):
+            return nn_ops.top_k(v, k=k, sorted=True)
+
+          self._assertOpOutputMatchesExpected(
+              topk,
+              [x.astype(dtype)],
+              expected=[expected.astype(dtype), indices],
+          )
 
   def testTopKZeros(self):
     """Tests that positive and negative zeros sort correctly."""
...
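The restored tests build their expected results with plain numpy before comparing against nn_ops.top_k. A minimal standalone sketch of that slicing idiom (illustrative only; `batch`, `array_size`, and `k` here are example values, not taken from the diff):

import numpy as np

# The 2-D test's expected values: `[::, -1:-k - 1:-1]` keeps every row and
# walks the last axis backwards from its largest element, so each row ends
# up with its k largest values in descending order.
batch, array_size, k = 4, 10, 3
x = np.arange(batch * array_size)
np.random.shuffle(x)
x = np.reshape(x, [batch, array_size])

indices = x.argsort(axis=1)[::, -1:-k - 1:-1]    # per-row indices of the k max
expected = np.sort(x, axis=1)[::, -1:-k - 1:-1]  # per-row k max, descending

# Gathering x at `indices` row by row reproduces `expected`.
assert (np.take_along_axis(x, indices, axis=1) == expected).all()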
@@ -196,6 +196,8 @@ std::optional<int64_t> TopkRewriter::SortIsInTopK(HloInstruction* inst) {
     return std::nullopt;
   }
   const int64_t sort_dim = sort->sort_dimension();
+  const int64_t batch_dim = sort_dim == 1 ? 0 : 1;
+  const bool has_batch = data->shape().rank() == 2;
 
   bool supported = true;
   std::optional<int64_t> k;
@@ -220,15 +222,10 @@ std::optional<int64_t> TopkRewriter::SortIsInTopK(HloInstruction* inst) {
       supported = false;
       break;
    }
-    for (int64_t i = 0; i < slice->slice_limits().size(); ++i) {
-      if (i != sort_dim &&
-          slice->slice_limits(i) != slice->operand(0)->shape().dimensions(i)) {
-        // Slicing along a non-sort dimension isn't supported.
-        supported = false;
-        break;
-      }
-    }
-    if (!supported) {
+    if (has_batch && slice->slice_limits(batch_dim) !=
+                         slice->operand(0)->shape().dimensions(batch_dim)) {
+      // Slicing along the batch dimension isn't supported.
+      supported = false;
       break;
     }
     if (k == std::nullopt) {
@@ -260,57 +257,29 @@ StatusOr<bool> TopkRewriter::TransformToCustomCall(
     HloSortInstruction* sort = DynCast<HloSortInstruction>(inst);
     HloInstruction* data = sort->mutable_operand(0);
     const PrimitiveType element_type = data->shape().element_type();
-    const Shape data_shape = data->shape();
-    if (element_type != F32 && element_type != BF16) {
+    if ((data->shape().rank() != 1 && data->shape().rank() != 2) ||
+        (element_type != F32 && element_type != BF16)) {
       continue;
     }
-    // Sort dimension must be the first or last dimension.
     const int64_t sort_dim = sort->sort_dimension();
-    if (sort_dim != 0 && sort_dim != data_shape.rank() - 1) {
-      continue;
-    }
+    const int64_t batch_dim = sort_dim == 1 ? 0 : 1;
+    const bool has_batch = data->shape().rank() == 2;
 
     // Profitability check.
     if (!is_profitable_to_convert_(sort, *k)) {
       continue;
     }
 
-    HloInstruction* input = data;
-    const bool has_batch = data_shape.rank() >= 2;
-    const int64_t input_size = data_shape.dimensions(sort_dim);
-    int64_t batch_size = 1;
-    Shape topk_input_shape;
-    if (has_batch) {
-      // The TopK custom call expects either a 1d tensor or a 2d tensor with
-      // the last dimension being the sort dimension. An input with rank > 2
-      // is reshaped into a 2d tensor by combining non-sort dimensions into a
-      // single batch dimension. The original non-sort dimensions are
-      // restored for the outputs with another reshape after the custom call.
-      batch_size =
-          ShapeUtil::ElementsIn(data_shape) / data_shape.dimensions(sort_dim);
-      topk_input_shape =
-          ShapeUtil::MakeShape(element_type, {batch_size, input_size});
-      if (data_shape.rank() > 2) {
-        // Reshape to 2d.
-        input = comp->AddInstruction(HloInstruction::CreateReshape(
-            sort_dim == 0
-                ? ShapeUtil::MakeShape(element_type, {input_size, batch_size})
-                : ShapeUtil::MakeShape(element_type,
-                                       {batch_size, input_size}),
-            input));
-      }
-      if (sort_dim == 0) {
-        // Transpose for the custom call when sorting the first dimension.
-        input = comp->AddInstruction(
-            HloInstruction::CreateTranspose(topk_input_shape, input, {1, 0}));
-      }
-    } else {
-      topk_input_shape = data_shape;
-    }
+    const int64_t batch_size =
+        has_batch ? sort->operand(0)->shape().dimensions(batch_dim) : 1;
+    const int64_t input_size = sort->operand(0)->shape().dimensions(sort_dim);
+    HloInstruction* input = sort->mutable_operand(0);
+    if (has_batch && sort_dim == 0) {
+      input = comp->AddInstruction(HloInstruction::CreateTranspose(
+          ShapeUtil::MakeShape(element_type, {batch_size, input_size}), input,
+          {1, 0}));
+    }
     Shape topk_shape =
@@ -331,28 +300,13 @@ StatusOr<bool> TopkRewriter::TransformToCustomCall(
         comp->AddInstruction(HloInstruction::CreateGetTupleElement(
             topk->shape().tuple_shapes(1), topk, 1));
-    if (has_batch) {
-      if (sort_dim == 0) {
-        // Transpose back.
-        value_gte = comp->AddInstruction(HloInstruction::CreateTranspose(
-            ShapeUtil::MakeShape(element_type, {k.value(), batch_size}),
-            value_gte, {1, 0}));
-        index_gte = comp->AddInstruction(HloInstruction::CreateTranspose(
-            ShapeUtil::MakeShape(S32, {k.value(), batch_size}), index_gte,
-            {1, 0}));
-      }
-      if (data_shape.rank() > 2) {
-        // Reshape back.
-        Shape value_shape = data_shape;
-        value_shape.set_dimensions(sort_dim, k.value());
-        value_gte = comp->AddInstruction(
-            HloInstruction::CreateReshape(value_shape, value_gte));
-        Shape index_shape =
-            ShapeUtil::MakeShape(S32, value_shape.dimensions());
-        index_gte = comp->AddInstruction(
-            HloInstruction::CreateReshape(index_shape, index_gte));
-      }
-    }
+    if (has_batch && sort_dim == 0) {
+      value_gte = comp->AddInstruction(HloInstruction::CreateTranspose(
+          ShapeUtil::MakeShape(element_type, {k.value(), batch_size}),
+          value_gte, {1, 0}));
+      index_gte = comp->AddInstruction(HloInstruction::CreateTranspose(
+          ShapeUtil::MakeShape(S32, {k.value(), batch_size}), index_gte,
+          {1, 0}));
+    }
     for (HloInstruction* user : sort->users()) {
...
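For context on what the rolled-back branch did: a rank > 2 input was flattened into a [batch_size, input_size] tensor before the TopK custom call (transposing first when the sort dimension came first), and the outputs were transposed and reshaped back afterwards. A rough numpy model of that shape bookkeeping (an illustrative sketch, not the HLO code; `topk_via_2d` is a hypothetical name):

import numpy as np

def topk_via_2d(data, k, sort_dim):
  # Mirrors the removed logic: the sort dim must be first or last.
  assert sort_dim in (0, data.ndim - 1)
  input_size = data.shape[sort_dim]
  batch_size = data.size // input_size  # ShapeUtil::ElementsIn / sort dim
  if sort_dim == 0:
    # Reshape to 2d keeping the sort dim first, then transpose it last.
    x2d = data.reshape(input_size, batch_size).T
  else:
    x2d = data.reshape(batch_size, input_size)  # combine non-sort dims
  idx = np.argsort(x2d, axis=1)[:, -1:-k - 1:-1]  # [batch_size, k]
  vals = np.take_along_axis(x2d, idx, axis=1)
  if sort_dim == 0:
    # Transpose back, then restore the original non-sort dimensions.
    out_shape = (k,) + data.shape[1:]
    return vals.T.reshape(out_shape), idx.T.reshape(out_shape)
  out_shape = data.shape[:-1] + (k,)
  return vals.reshape(out_shape), idx.reshape(out_shape)

After the rollback, TransformToCustomCall only accepts rank-1 and rank-2 inputs again, so this flattening round-trip no longer happens.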
@@ -326,42 +326,6 @@ ENTRY cluster {
   EXPECT_THAT(cc->custom_call_target(), "TopK");
 }
 
-TEST_F(TopkRewriterTest, RewriteReshape) {
-  const std::string hlo_string = R"(
-HloModule module
-)" + getComparator() + R"(
-ENTRY cluster {
-  %arg_tuple.1 = f32[3,8,1234567] parameter(0)
-  %iota.4 = s32[3,8,1234567] iota(), iota_dimension=2
-  %sort.27 = (f32[3,8,1234567], s32[3,8,1234567]) sort(%arg_tuple.1, %iota.4),
-    dimensions={2}, is_stable=true, to_apply=%compare
-  %get-tuple-element.28 = f32[3,8,1234567] get-tuple-element(%sort.27), index=0
-  %slice.29 = f32[3,8,5] slice(%get-tuple-element.28), slice={[0:3], [0:8], [0:5]}
-  %get-tuple-element.30 = s32[3,8,1234567] get-tuple-element(%sort.27), index=1
-  %slice.31 = s32[3,8,5] slice(%get-tuple-element.30), slice={[0:3], [0:8], [0:5]}
-  ROOT %tuple.32 = (f32[3,8,5], s32[3,8,5]) tuple(%slice.29, %slice.31)
-})";
-  TF_ASSERT_OK_AND_ASSIGN(auto module,
-                          ParseAndReturnVerifiedModule(hlo_string));
-  TopkRewriter rewriter(
-      [](const HloSortInstruction*, int64_t) { return true; });
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, rewriter.Run(module.get()));
-  TF_ASSERT_OK(HloDCE().Run(module.get()).status());
-  EXPECT_TRUE(changed);
-  EXPECT_THAT(module->entry_computation()->root_instruction(),
-              GmockMatch(m::Tuple(
-                  m::Reshape(m::GetTupleElement(
-                      m::CustomCall(m::Reshape(m::Parameter(0))), 0)),
-                  m::Reshape(m::GetTupleElement(
-                      m::CustomCall(m::Reshape(m::Parameter(0))), 1)))));
-  const HloInstruction* cc = module->entry_computation()
-                                 ->root_instruction()
-                                 ->operand(0)
-                                 ->operand(0)
-                                 ->operand(0);
-  EXPECT_THAT(cc->custom_call_target(), "TopK");
-}
-
 TEST_F(TopkRewriterTest, RewriteNoIota) {
   const std::string hlo_string = R"(
 HloModule module
...
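The deleted RewriteReshape test above exercised exactly the rank-3 pattern (sort along dimension 2, slice every other dimension in full) that the rewriter no longer converts. A small Python model combining the restored rank and batch-dimension checks (hypothetical helper name; the real code operates on HLO instructions, and the rank check lives in TransformToCustomCall rather than SortIsInTopK):

# Returns k if the slice only narrows the sort dimension of a rank-1 or
# rank-2 operand, else None.
def slice_selects_top_k(slice_limits, operand_dims, sort_dim):
  if len(operand_dims) not in (1, 2):
    return None  # rank > 2 is rejected again after the rollback
  batch_dim = 0 if sort_dim == 1 else 1
  has_batch = len(operand_dims) == 2
  if has_batch and slice_limits[batch_dim] != operand_dims[batch_dim]:
    return None  # slicing along the batch dimension isn't supported
  return slice_limits[sort_dim]

# The rank-3 pattern from the deleted test is no longer matched:
assert slice_selects_top_k([3, 8, 5], [3, 8, 1234567], sort_dim=2) is None
# A 2-D sort sliced to k along the sort dimension still is:
assert slice_selects_top_k([16, 5], [16, 1000], sort_dim=1) == 5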