提交 bdf4a7bb 编写于 作者: F Fan Yang 提交者: A. Unique TensorFlower

Internal change.

PiperOrigin-RevId: 492381485
上级 462459cf
......@@ -101,14 +101,15 @@ def _same(x):
return tf.constant(1, dtype=x.dtype) + tf.math.floor(-x_clip)
def shard_tensors(axis: int, block_size: int,
*tensors: tf.Tensor) -> Iterable[Sequence[tf.Tensor]]:
def shard_tensors(
axis: int, block_size: int,
tensors: Sequence[tf.Tensor]) -> Iterable[Sequence[tf.Tensor]]:
"""Consistently splits multiple tensors sharding-style.
Args:
axis: axis to be used to split tensors
block_size: block size to split tensors.
*tensors: list of tensors.
tensors: list of tensors.
Returns:
List of shards, each shard has exactly one peace of each input tesnor.
......@@ -206,7 +207,7 @@ def non_max_suppression_padded(boxes: tf.Tensor,
scores = tf.reshape(scores, [batch_size, boxes_size])
block = max(1, _RECOMMENDED_NMS_MEMORY // (boxes_size * boxes_size))
indices = []
for boxes_i, scores_i in shard_tensors(0, block, boxes, scores):
for boxes_i, scores_i in shard_tensors(0, block, (boxes, scores)):
indices.append(
_non_max_suppression_as_is(boxes_i, scores_i, output_size,
iou_threshold, refinements))
......
......@@ -191,7 +191,7 @@ class NonMaxSuppressionTest(parameterized.TestCase, tf.test.TestCase):
[15, 16, 17, 18, 19],
[20, 21, 22, 23, 24],
]])
for i, (a_i, b_i) in enumerate(edgetpu.shard_tensors(1, 3, a, b)):
for i, (a_i, b_i) in enumerate(edgetpu.shard_tensors(1, 3, (a, b))):
self.assertAllEqual(a_i, a[:, i * 3:i * 3 + 3])
self.assertAllEqual(b_i, b[:, i * 3:i * 3 + 3, :])
......@@ -233,7 +233,7 @@ class NonMaxSuppressionTest(parameterized.TestCase, tf.test.TestCase):
shape=axis * [1] + [10000], dtype=tf.float32)
top_1000_direct: tf.Tensor = tf.math.top_k(sample, 1000).values
top_1000_sharded: Optional[tf.Tensor] = None
for (piece,) in edgetpu.shard_tensors(axis, 1500, sample):
for (piece,) in edgetpu.shard_tensors(axis, 1500, (sample,)):
(top_1000_sharded,) = edgetpu.concat_and_top_k(
1000, (top_1000_sharded, piece))
self.assertAllEqual(top_1000_direct, top_1000_sharded)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册