未验证 提交 14b2d686 编写于 作者: S Sachin Joglekar 提交者: GitHub

Merge pull request #41140 from srjoglekar246/cherrypicks_5EM5L

r2.3 cherry-pick request: Shift padded NMS compat window forward to fix TFLite conversion
...@@ -30,6 +30,7 @@ from six.moves import xrange # pylint: disable=redefined-builtin ...@@ -30,6 +30,7 @@ from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.compiler.tests import xla_test from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_image_ops from tensorflow.python.ops import gen_image_ops
from tensorflow.python.ops import image_ops from tensorflow.python.ops import image_ops
...@@ -774,6 +775,7 @@ class ResizeBilinearNonAlignCornersTest(xla_test.XLATestCase): ...@@ -774,6 +775,7 @@ class ResizeBilinearNonAlignCornersTest(xla_test.XLATestCase):
class NonMaxSuppressionTest(xla_test.XLATestCase): class NonMaxSuppressionTest(xla_test.XLATestCase):
@test_util.disable_mlir_bridge("%1")
def testNMS128From1024(self): def testNMS128From1024(self):
num_boxes = 1024 num_boxes = 1024
boxes_np = np.random.normal(50, 10, (num_boxes, 4)).astype("f4") boxes_np = np.random.normal(50, 10, (num_boxes, 4)).astype("f4")
...@@ -808,6 +810,7 @@ class NonMaxSuppressionTest(xla_test.XLATestCase): ...@@ -808,6 +810,7 @@ class NonMaxSuppressionTest(xla_test.XLATestCase):
self.assertEqual(indices_tf.size, max_output_size) self.assertEqual(indices_tf.size, max_output_size)
@test_util.disable_mlir_bridge("%1")
def testNMS3From6Boxes(self): def testNMS3From6Boxes(self):
# Three boxes are selected based on IOU. # Three boxes are selected based on IOU.
boxes_data = [[0, 0, 1, 1], [0, 0.1, 1, 1.1], [0, -0.1, 1, 0.9], boxes_data = [[0, 0, 1, 1], [0, 0.1, 1, 1.1], [0, -0.1, 1, 0.9],
...@@ -849,6 +852,7 @@ class NonMaxSuppressionTest(xla_test.XLATestCase): ...@@ -849,6 +852,7 @@ class NonMaxSuppressionTest(xla_test.XLATestCase):
self.assertEqual(num_valid, 3) self.assertEqual(num_valid, 3)
self.assertAllClose(indices_tf[:num_valid], [3, 0, 5]) self.assertAllClose(indices_tf[:num_valid], [3, 0, 5])
@test_util.disable_mlir_bridge("%1")
def testNMS3Then2WithScoreThresh(self): def testNMS3Then2WithScoreThresh(self):
# Three boxes are selected based on IOU. # Three boxes are selected based on IOU.
# One is filtered out by score threshold. # One is filtered out by score threshold.
...@@ -891,6 +895,7 @@ class NonMaxSuppressionTest(xla_test.XLATestCase): ...@@ -891,6 +895,7 @@ class NonMaxSuppressionTest(xla_test.XLATestCase):
self.assertEqual(num_valid, 2) self.assertEqual(num_valid, 2)
self.assertAllClose(indices_tf[:num_valid], [3, 0]) self.assertAllClose(indices_tf[:num_valid], [3, 0])
@test_util.disable_mlir_bridge("%1")
def testNMS3Then1WithScoreMaxThresh(self): def testNMS3Then1WithScoreMaxThresh(self):
# Three boxes are selected based on IOU. # Three boxes are selected based on IOU.
# One is filtered out by score threshold. # One is filtered out by score threshold.
...@@ -934,6 +939,7 @@ class NonMaxSuppressionTest(xla_test.XLATestCase): ...@@ -934,6 +939,7 @@ class NonMaxSuppressionTest(xla_test.XLATestCase):
self.assertEqual(num_valid, 1) self.assertEqual(num_valid, 1)
self.assertAllClose(indices_tf[:num_valid], [3]) self.assertAllClose(indices_tf[:num_valid], [3])
@test_util.disable_mlir_bridge("%1")
def testSelectFromContinuousOverLap(self): def testSelectFromContinuousOverLap(self):
# Tests that a suppressed box does not itself suppress other boxes. # Tests that a suppressed box does not itself suppress other boxes.
...@@ -978,6 +984,7 @@ class NonMaxSuppressionTest(xla_test.XLATestCase): ...@@ -978,6 +984,7 @@ class NonMaxSuppressionTest(xla_test.XLATestCase):
class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase): class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase):
@test_util.disable_mlir_bridge("%1")
def testBatchedNMSFrom6(self): def testBatchedNMSFrom6(self):
boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4], boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4],
[0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]], [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]],
...@@ -1015,6 +1022,7 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase): ...@@ -1015,6 +1022,7 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase):
indices_output) indices_output)
self.assertAllEqual([5, 4], num_valid_output) self.assertAllEqual([5, 4], num_valid_output)
@test_util.disable_mlir_bridge("%1")
def testBatchedNMSFrom6Max3(self): def testBatchedNMSFrom6Max3(self):
boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4], boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4],
[0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]], [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]],
...@@ -1048,6 +1056,7 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase): ...@@ -1048,6 +1056,7 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase):
self.assertAllEqual([[0, 1, 2], [0, 1, 3]], indices_output) self.assertAllEqual([[0, 1, 2], [0, 1, 3]], indices_output)
self.assertAllEqual([3, 3], num_valid_output) self.assertAllEqual([3, 3], num_valid_output)
@test_util.disable_mlir_bridge("%1")
def testBatchedNMSSingleFrom6Max3(self): def testBatchedNMSSingleFrom6Max3(self):
boxes_data = [[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4], boxes_data = [[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4],
[0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]] [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]]
...@@ -1078,6 +1087,7 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase): ...@@ -1078,6 +1087,7 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase):
self.assertAllEqual([0, 1, 2], indices_output) self.assertAllEqual([0, 1, 2], indices_output)
self.assertAllEqual(3, num_valid_output) self.assertAllEqual(3, num_valid_output)
@test_util.disable_mlir_bridge("%1")
def testBatchedNMSSingleFrom6NoPad(self): def testBatchedNMSSingleFrom6NoPad(self):
boxes_data = [[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4], boxes_data = [[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4],
[0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]] [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]]
...@@ -1107,6 +1117,7 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase): ...@@ -1107,6 +1117,7 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase):
self.assertAllEqual([0, 1, 2, 4, 5], indices_output) self.assertAllEqual([0, 1, 2, 4, 5], indices_output)
self.assertAllEqual(5, num_valid_output) self.assertAllEqual(5, num_valid_output)
@test_util.disable_mlir_bridge("%1")
def testBatchedNMSBatchDimsFrom6Max3(self): def testBatchedNMSBatchDimsFrom6Max3(self):
boxes_data = [[[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4], boxes_data = [[[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4],
[0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]], [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]],
...@@ -1140,6 +1151,7 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase): ...@@ -1140,6 +1151,7 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase):
self.assertAllEqual([[[0, 1, 2], [0, 1, 3]]], indices_output) self.assertAllEqual([[[0, 1, 2], [0, 1, 3]]], indices_output)
self.assertAllEqual([[3, 3]], num_valid_output) self.assertAllEqual([[3, 3]], num_valid_output)
@test_util.disable_mlir_bridge("%1")
def testBatchedNMSScoreThresholdFrom6Max3(self): def testBatchedNMSScoreThresholdFrom6Max3(self):
boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4], boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4],
[0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]], [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]],
...@@ -1175,6 +1187,7 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase): ...@@ -1175,6 +1187,7 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase):
self.assertAllEqual([3, 2], num_valid_output) self.assertAllEqual([3, 2], num_valid_output)
self.assertAllEqual([[0, 1, 2], [0, 1, invalid_index]], indices_output) self.assertAllEqual([[0, 1, 2], [0, 1, invalid_index]], indices_output)
@test_util.disable_mlir_bridge("%1")
def testBatchedNMSUnsortedInputFrom6(self): def testBatchedNMSUnsortedInputFrom6(self):
boxes_data = [[[0, 2, 1, 2], [3, 3, 4, 4], [0, 0, 1, 1], boxes_data = [[[0, 2, 1, 2], [3, 3, 4, 4], [0, 0, 1, 1],
[0, 0.4, 1, 1.4], [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8]], [0, 0.4, 1, 1.4], [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8]],
...@@ -1211,6 +1224,7 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase): ...@@ -1211,6 +1224,7 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase):
indices_output) indices_output)
self.assertAllEqual([5, 4], num_valid_output) self.assertAllEqual([5, 4], num_valid_output)
@test_util.disable_mlir_bridge("%1")
def testBatchedNMSNoncanonicalizedInputFrom6(self): def testBatchedNMSNoncanonicalizedInputFrom6(self):
boxes_data = [[[1, 0, 0, 1], [4, 3, 3, 4], [1, 0.4, 0, 1.4], boxes_data = [[[1, 0, 0, 1], [4, 3, 3, 4], [1, 0.4, 0, 1.4],
[1, 0.6, 0, 1.6], [1, 0.8, 0, 1.8], [1, 2, 0, 2]], [1, 0.6, 0, 1.6], [1, 0.8, 0, 1.8], [1, 2, 0, 2]],
...@@ -1248,6 +1262,7 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase): ...@@ -1248,6 +1262,7 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase):
indices_output) indices_output)
self.assertAllEqual([5, 4], num_valid_output) self.assertAllEqual([5, 4], num_valid_output)
@test_util.disable_mlir_bridge("%1")
def testBatchedNMSScoreThresholdCanInputsFrom6Max3(self): def testBatchedNMSScoreThresholdCanInputsFrom6Max3(self):
boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4], boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4],
[0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]], [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]],
...@@ -1283,6 +1298,7 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase): ...@@ -1283,6 +1298,7 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase):
self.assertAllEqual([3, 2], num_valid_output) self.assertAllEqual([3, 2], num_valid_output)
self.assertAllEqual([[0, 1, 2], [0, 1, invalid_index]], indices_output) self.assertAllEqual([[0, 1, 2], [0, 1, invalid_index]], indices_output)
@test_util.disable_mlir_bridge("%1")
def testBatchedNMSFrom6DynamicInput(self): def testBatchedNMSFrom6DynamicInput(self):
boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4], boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4],
[0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]], [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]],
......
...@@ -4579,11 +4579,11 @@ def non_max_suppression_padded(boxes, ...@@ -4579,11 +4579,11 @@ def non_max_suppression_padded(boxes,
Raises: Raises:
ValueError: When set pad_to_max_output_size to False for batched input. ValueError: When set pad_to_max_output_size to False for batched input.
""" """
# if no new arguments are used and no later than 2020/4/20, use the old # if no new arguments are used and no later than 2020/6/23, use the old
# version to give us time to fix TFLite conversion # version to give us time to fix TFLite conversion after the TF 2.3 release.
if (not sorted_input) and \ if (not sorted_input) and \
(not canonicalized_coordinates) and \ (not canonicalized_coordinates) and \
tile_size == 512 and not compat.forward_compatible(2020, 4, 20): tile_size == 512 and not compat.forward_compatible(2020, 6, 23):
return non_max_suppression_padded_v1( return non_max_suppression_padded_v1(
boxes, scores, max_output_size, iou_threshold, score_threshold, boxes, scores, max_output_size, iou_threshold, score_threshold,
pad_to_max_output_size, name) pad_to_max_output_size, name)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册