提交 3afd339f 编写于 作者: A A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 485693087
上级 69bbdc1c
......@@ -477,3 +477,130 @@ class ArgmaxKerasLayer(tf.keras.layers.Layer):
axis=self.axis,
output_type=self.output_type,
name=self.name)
_or = tf.maximum
_and = tf.minimum
_reduce_or = tf.reduce_max
def _tensor_sum_vectors(a, b):
return tf.reshape(a, [1, 1, 1, -1]) + tf.reshape(b, [1, 1, -1, 1])
def _tensor_product_iou(boxes):
"""Computes pairwise IOU.
Reason to use 4-D tensors is to follow TPU compiler preference.
Args:
boxes: A 2-D float `Tensor` of shape `[num_boxes, 4]`.
Returns:
A 4-D float `Tensor` of shape `[1, 1, num_boxes, num_boxes]` containing
pairwise IOU.
"""
boxes = tf.reshape(boxes, [-1, 4])
boxes = tf.transpose(boxes, [1, 0])
bottom, left, top, right = tf.split(boxes, 4, 0)
height, width = top - bottom, right - left
area = height * width
area_sum = _tensor_sum_vectors(area, area)
bottom_pad, left_pad, top_pad, right_pad = (
tf.nn.relu(_tensor_sum_vectors(x, -x))
for x in (-bottom, -left, top, right))
height_pad, width_pad = bottom_pad + top_pad, left_pad + right_pad
intersection = tf.nn.relu(height - height_pad) * tf.nn.relu(width - width_pad)
union = area_sum - intersection
iou = tf.math.divide(intersection, union + _same(union))
return iou
def _greater(x):
"""Avoid non lowerable layers in boolean comparison.
Logical operation results in tensor of boolean type. However in serving such
a tensors cannot be cast to values because of NNAPI specs.
`tf.where` operation result in `select` instruction lowering, which not runs
well on all generations of edge-tpus.
Args:
x: any numeric tensor.
Returns:
tf.where(x > tf.zero_like(x), tf.one_like(x), tf.zero_like(x))
"""
x_clip = tf.minimum(tf.nn.relu(x), tf.constant(1, dtype=x.dtype))
return -tf.math.floor(-x_clip)
def _same(x):
"""Avoid non lowerable layers in boolean equality.
Logical operation results in tensor of boolean type. However in serving such
a tensors cannot be cast to values because of NNAPI specs.
`tf.where` operation result in `select` instruction lowering, which not runs
well on all generations of edge-tpus.
Args:
x: any numeric tensor.
Returns:
tf.where(x == tf.zero_like(x), tf.one_like(x), tf.zero_like(x))
"""
x_clip = tf.minimum(tf.abs(x), tf.constant(1, dtype=x.dtype))
return tf.constant(1, dtype=x.dtype) + tf.math.floor(-x_clip)
def non_max_suppression_padded(boxes: tf.Tensor,
scores: tf.Tensor,
output_size: int,
iou_threshold: float = 0.5) -> tf.Tensor:
"""Selects a subset of boxes which have highest score among IOU-similar boxes.
Prunes away boxes that have high intersection-over-union (IOU) overlap
with boxes having higher score. Boxes are supplied as `[y1, x1, y2, x2]`,
where `(y1, x1)` and `(y2, x2)` are the coordinates of any diagonal pair of
box corners. Note that this algorithm is agnostic to the coordinate system.
Thus translating or reflections of the coordinate system result in the same
boxes being selected by the algorithm. The output of this operation is a
set of integers indexing into the input collection of bounding boxes
representing the selected boxes.
Set will be returned padded on the right with `-1` values. The bounding
box coordinates corresponding to the selected indices can then be obtained
using the `tf.gather` operation. For example:
```python
selected_indices = vision.modeling.layers.non_max_suppression_padded(
boxes, scores, max_output_size, iou_threshold)
selected_boxes = tf.gather(boxes, selected_indices)
```
Args:
boxes: A 2-D float `Tensor` of shape `[num_boxes, 4]`.
scores: A 1-D float `Tensor` of shape `[num_boxes]` representing a single
score corresponding to each box (each row of boxes).
output_size: A scalar integer `Tensor` representing the maximum number of
boxes to be selected by non-max suppression.
iou_threshold: A 0-D float tensor representing the threshold for deciding
whether boxes overlap too much with respect to IOU.
Returns:
selected_indices: A 1-D integer `Tensor` of shape `[output_size]`
representing the selected indices from the boxes tensor and `-1` values
for the padding.
"""
order = tf.range(tf.size(scores), dtype=tf.float32)
relative_order = _tensor_sum_vectors(order, -order)
relative_scores = _tensor_sum_vectors(scores, -scores)
similar = _greater(_tensor_product_iou(boxes) - iou_threshold)
worse = _greater(relative_scores)
same_later = _and(_same(relative_scores), _greater(relative_order))
similar_worse_or_same_later = _and(similar, _or(worse, same_later))
prunable = _reduce_or(similar_worse_or_same_later, axis=-1)
remaining = (tf.constant(1.) - prunable)
# top_k runs on TPU cores, let it happen, TPU tiles implementation is slower.
top_k = tf.math.top_k(remaining * tf.exp(scores), output_size)
return tf.squeeze(
tf.cast(top_k.indices, top_k.values.dtype) * _greater(top_k.values) -
_same(top_k.values))
......@@ -17,6 +17,7 @@
import itertools
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from official.projects.edgetpu.vision.modeling import custom_layers
......@@ -186,5 +187,70 @@ class ArgmaxTest(parameterized.TestCase, tf.test.TestCase):
self.assertAllEqual(control_output, test_output)
def random_boxes(n):
a = tf.random.uniform(shape=[n, 2])
b = tf.random.uniform(shape=[n, 2])
l = tf.minimum(a, b)
u = tf.maximum(a, b)
return tf.concat([l, u], axis=-1)
class NonMaxSuppressionTest(parameterized.TestCase, tf.test.TestCase):
@parameterized.parameters((16, 8, 500, 0.016), (31, 17, 300, 0.033),
(71, 41, 300, 0.065), (150, 100, 250, 0.137),
(300, 300, 250, 0.126), (600, 600, 100, 0.213))
def test_reference_match(self, n, top, runs, max_deviation):
"""Compares that new optimized method is close to reference method.
Runs two algorithms with same sets of input boxes and scores, and measures
deviation between returned sets of prunned boxes.
(*) Avoid flakiness with safe boundary (go/python-tips/048): deviation
between two sets is a positive number, which may vary from test to test.
Doing multiple runs expected to reduce average deviation variation following
LLN theorem. Therefore by having first test run we know upper deviation
bound which algorithm would not exceed until broken (in any feasible amount
of time in the future). Use of this safe boundary makes test non-flaky.
(**) Parametrized inputs description. See safe deviation choice is higher
than absolute deviation to avoid flaky tesing.
in # | out # | deflake # | test time | deviation | safe threshold
---- | ----- | --------- | --------- | --------- | --------------
18 | 8 | 500 | 6 sec | 0.4% | 1.6%
31 | 17 | 300 | 7 sec | 1.0% | 3.3%
71 | 41 | 300 | 7 sec | 3.4% | 6.5%
150 | 100 | 250 | 7 sec | 8.2% | 13.7%
300 | 300 | 250 | 10 sec | 7.4% | 12.6%
600 | 600 | 100 | 9 sec | 9.6% | 21.3%
Args:
n: number of boxes and scores on input of the algorithm.
top: limit of output boxes count.
runs: for the statistical testing number of runs to performs to avoid
tests flakiness.
max_deviation: mean limit on deviation between optimized and reference
algorithms. Please read notes why this number may be set higher to avoid
flaky testing.
"""
deviation_rate = 0
for _ in range(runs):
boxes = random_boxes(n)
scores = tf.random.uniform(shape=[n])
optimized = custom_layers.non_max_suppression_padded(boxes, scores, top)
optimized = {*optimized.numpy().astype(int).tolist()} - {-1}
reference = tf.image.non_max_suppression(boxes, scores, top)
reference = {*reference.numpy().tolist()}
deviation_rate += len(optimized ^ reference) / len(optimized | reference)
deviation_rate = deviation_rate / runs
# six sigma estimate via LLN theorem
safe_margin = 6 * (deviation_rate / np.sqrt(runs) + 1 / runs)
self.assertLess(
deviation_rate,
max_deviation,
msg='Deviation rate between optimized and reference implementations is '
'higher than expected. If you are tuning the test, recommended safe '
'deviation rate is '
f'{deviation_rate} + {safe_margin} = {deviation_rate + safe_margin}')
if __name__ == '__main__':
tf.test.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册