提交 a4030d8b 编写于 作者: M Mihai Maruseac

[tflite] Validate segment ids for segment_sum.

Segment identifiers in segment_sum should be in a 1-D tensor of same size as the first dimension of the input. The values of the tensor should be integers from {0, 1, 2, ... k-1}, where k is the first dimension of the input. The segment identifiers must not contain jumps and must be increasing.

See https://www.tensorflow.org/api_docs/python/tf/math#Segmentation as the source for these constraints.

PiperOrigin-RevId: 332510942
Change-Id: I898beaba00642c918bcd4b4d4ce893ebb190d869
上级 5e303955
......@@ -32,11 +32,24 @@ TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
const TfLiteTensor* data,
const TfLiteTensor* segment_ids,
TfLiteTensor* output) {
int max_index = -1;
// Segment ids should be of same cardinality as first input dimension and they
// should be increasing by at most 1, from 0 (e.g., [0, 0, 1, 2, 3] is valid)
const int segment_id_size = segment_ids->dims->data[0];
if (segment_id_size > 0) {
max_index = segment_ids->data.i32[segment_id_size - 1];
TF_LITE_ENSURE_EQ(context, segment_id_size, data->dims->data[0]);
int previous_segment_id = -1;
for (int i = 0; i < segment_id_size; i++) {
const int current_segment_id = GetTensorData<int32_t>(segment_ids)[i];
if (i == 0) {
TF_LITE_ENSURE_EQ(context, current_segment_id, 0);
} else {
int delta = current_segment_id - previous_segment_id;
TF_LITE_ENSURE(context, delta == 0 || delta == 1);
}
previous_segment_id = current_segment_id;
}
const int max_index = previous_segment_id;
const int data_rank = NumDimensions(data);
TfLiteIntArray* output_shape = TfLiteIntArrayCreate(NumDimensions(data));
output_shape->data[0] = max_index + 1;
......
......@@ -108,5 +108,37 @@ TEST(SegmentSumOpModelTest, Float32Test_ThreeDimensions) {
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 2, 1}));
}
TEST(SegmentSumOpModelTest, TestFailIfSegmentsAreNotSorted) {
SegmentSumOpModel<int32_t> model({TensorType_INT32, {3, 2}},
{TensorType_INT32, {3}});
model.PopulateTensor<int32_t>(model.data(), {1, 2, 3, 4, 5, 6});
model.PopulateTensor<int32_t>(model.segment_ids(), {0, 3, 1});
ASSERT_EQ(model.InvokeUnchecked(), kTfLiteError);
}
TEST(SegmentSumOpModelTest, TestFailIfSegmentsAreNotConsecutive) {
SegmentSumOpModel<int32_t> model({TensorType_INT32, {3, 2}},
{TensorType_INT32, {3}});
model.PopulateTensor<int32_t>(model.data(), {1, 2, 3, 4, 5, 6});
model.PopulateTensor<int32_t>(model.segment_ids(), {0, 3, 5});
ASSERT_EQ(model.InvokeUnchecked(), kTfLiteError);
}
TEST(SegmentSumOpModelTest, TestFailIfSegmentsAreNegative) {
SegmentSumOpModel<int32_t> model({TensorType_INT32, {3, 2}},
{TensorType_INT32, {3}});
model.PopulateTensor<int32_t>(model.data(), {1, 2, 3, 4, 5, 6});
model.PopulateTensor<int32_t>(model.segment_ids(), {-1, 0, 1});
ASSERT_EQ(model.InvokeUnchecked(), kTfLiteError);
}
TEST(SegmentSumOpModelTest, TestFailIfSegmentsAreNotTheRightCardinality) {
SegmentSumOpModel<int32_t> model({TensorType_INT32, {3, 2}},
{TensorType_INT32, {2}});
model.PopulateTensor<int32_t>(model.data(), {1, 2, 3, 4, 5, 6});
model.PopulateTensor<int32_t>(model.segment_ids(), {0, 1});
ASSERT_EQ(model.InvokeUnchecked(), kTfLiteError);
}
} // namespace
} // namespace tflite
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册