Commit 5938a718
Authored Nov 29, 2022 by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 491555987

Parent: c5662b16

Showing 5 changed files with 519 additions and 480 deletions (+519 -480)
official/projects/edgetpu/vision/modeling/custom_layers.py  +1 -309
official/projects/edgetpu/vision/modeling/custom_layers_test.py  +0 -169
official/vision/modeling/layers/detection_generator.py  +2 -2
official/vision/modeling/layers/edgetpu.py  +327 -0
official/vision/modeling/layers/edgetpu_test.py  +189 -0
official/projects/edgetpu/vision/modeling/custom_layers.py @ 5938a718
@@ -14,10 +14,9 @@
 """Customized keras layers used in the EdgeTPU models."""

-from collections.abc import Iterable, MutableMapping, Sequence
+from collections.abc import MutableMapping
 import inspect
 from typing import Any, Optional, Union

 import tensorflow as tf

 from official.modeling import tf_utils
@@ -479,310 +478,3 @@ class ArgmaxKerasLayer(tf.keras.layers.Layer):
         axis=self.axis,
         output_type=self.output_type,
         name=self.name)
-
-
-_or = tf.maximum
-_and = tf.minimum
-_reduce_or = tf.reduce_max
-
-
-def _tensor_sum_vectors(a, b):
-  a = tf.tile(tf.reshape(a, [1, -1, 1, a.shape[-1]]), [1, 1, a.shape[-1], 1])
-  b = tf.tile(tf.reshape(b, [1, -1, a.shape[-1], 1]), [1, 1, 1, a.shape[-1]])
-  return a + b
-
-
-def _tensor_product_iou(boxes):
-  """Computes pairwise IOU.
-
-  Reason to use 4-D tensors is to follow TPU compiler preference.
-
-  Args:
-    boxes: A 2-D float `Tensor` of shape `[num_boxes, 4]`.
-
-  Returns:
-    A 4-D float `Tensor` of shape `[1, 1, num_boxes, num_boxes]` containing
-    pairwise IOU.
-  """
-  boxes_size = boxes.shape[-2]
-  # Code below will do frequent operand broadcasting.
-  # TPU compiler has (empirically) fewer issues broadcasting if
-  # - batch (first) dimension is 1. (Special consideration sharding)
-  # - there are 4 dimensions. (Standard traversal mapping)
-  # - last dimension is not 1. (Structure alignment)
-  tpu_friendly_shape = [1, -1, 1, boxes_size]
-  bottom, left, top, right = (
-      tf.reshape(side, tpu_friendly_shape) for side in tf.split(boxes, 4, -1))
-  height, width = top - bottom, right - left
-  area = height * width
-  area_sum = _tensor_sum_vectors(area, area)
-  bottom_pad, left_pad, top_pad, right_pad = (
-      tf.nn.relu(_tensor_sum_vectors(x, -x))
-      for x in (-bottom, -left, top, right))
-  height_pad, width_pad = bottom_pad + top_pad, left_pad + right_pad
-  intersection = tf.nn.relu(height - height_pad) * tf.nn.relu(width - width_pad)
-  union = area_sum - intersection
-  iou = tf.math.divide(intersection, union + _same(union))
-  return iou
-
-
-def _greater(x):
-  """Avoids non-lowerable layers in boolean comparison.
-
-  A logical operation results in a tensor of boolean type. However, in serving
-  such tensors cannot be cast to values because of NNAPI specs.
-  The `tf.where` operation lowers to a `select` instruction, which does not run
-  well on all generations of EdgeTPUs.
-
-  Args:
-    x: any numeric tensor.
-
-  Returns:
-    Equivalent of
-    tf.where(x > tf.zeros_like(x), tf.ones_like(x), tf.zeros_like(x)).
-  """
-  x_clip = tf.minimum(tf.nn.relu(x), tf.constant(1, dtype=x.dtype))
-  return -tf.math.floor(-x_clip)
-
-
-def _same(x):
-  """Avoids non-lowerable layers in boolean equality.
-
-  A logical operation results in a tensor of boolean type. However, in serving
-  such tensors cannot be cast to values because of NNAPI specs.
-  The `tf.where` operation lowers to a `select` instruction, which does not run
-  well on all generations of EdgeTPUs.
-
-  Args:
-    x: any numeric tensor.
-
-  Returns:
-    Equivalent of
-    tf.where(x == tf.zeros_like(x), tf.ones_like(x), tf.zeros_like(x)).
-  """
-  x_clip = tf.minimum(tf.abs(x), tf.constant(1, dtype=x.dtype))
-  return tf.constant(1, dtype=x.dtype) + tf.math.floor(-x_clip)
-
-
-def shard_tensors(axis: int, block_size: int,
-                  *tensors: tf.Tensor) -> Iterable[Sequence[tf.Tensor]]:
-  """Consistently splits multiple tensors sharding-style.
-
-  Args:
-    axis: axis to be used to split tensors.
-    block_size: block size to split tensors.
-    *tensors: list of tensors.
-
-  Returns:
-    List of shards, each shard has exactly one piece of each input tensor.
-
-  Raises:
-    ValueError: if input tensors have different sizes of the sharded dimension.
-  """
-  for validate_axis in range(axis + 1):
-    consistent_length: int = tensors[0].shape[validate_axis]
-    for tensor in tensors:
-      if tensor.shape[validate_axis] != consistent_length:
-        raise ValueError('Inconsistent shapes in shard_tensors: first is '
-                         f'{tensors[0].shape} and other is {tensor.shape}')
-  batch_size: int = tensors[0].shape[axis]
-  if block_size >= batch_size:
-    return [tensors]
-  else:
-    blocks = batch_size // block_size
-    remainder = batch_size % block_size
-    if remainder:
-      tensor_parts = []
-      for tensor in tensors:
-        shape: tf.TensorShape = tensor.shape
-        body: tf.Tensor = tf.slice(tensor, [0] * len(shape), [
-            size if i != axis else blocks * block_size
-            for i, size in enumerate(shape)
-        ])
-        tail: tf.Tensor = tf.slice(tensor, [
-            0 if i != axis else (blocks * block_size)
-            for i, _ in enumerate(shape)
-        ], [
-            size if i != axis else (size - blocks * block_size)
-            for i, size in enumerate(shape)
-        ])
-        tensor_parts.append(tf.split(body, blocks, axis) + [tail])
-      return zip(*tensor_parts)
-    else:
-      return zip(*[tf.split(tensor, blocks, axis) for tensor in tensors])
-
-
-# TODO(b/258007436): Number is based on existing compiler limitations while
-# running bf16 NMS on edgetpu. Remove manual sharding when the compiler issue
-# is fixed.
-_RECOMMENDED_NMS_MEMORY = 360000
-
-
-def non_max_suppression_padded(boxes: tf.Tensor,
-                               scores: tf.Tensor,
-                               output_size: int,
-                               iou_threshold: float = 0.5) -> tf.Tensor:
-  """Selects a subset of boxes which have highest score among IOU-similar boxes.
-
-  Prunes away boxes that have high intersection-over-union (IOU) overlap
-  with boxes having higher score. Boxes are supplied as `[y1, x1, y2, x2]`,
-  where `(y1, x1)` and `(y2, x2)` are the coordinates of any diagonal pair of
-  box corners. Note that this algorithm is agnostic to the coordinate system.
-  Thus translations or reflections of the coordinate system result in the same
-  boxes being selected by the algorithm. The output of this operation is a
-  set of integers indexing into the input collection of bounding boxes
-  representing the selected boxes.
-
-  Set will be returned padded on the right with `-1` values. The bounding
-  box coordinates corresponding to the selected indices can then be obtained
-  using the `tf.gather` operation. For example:
-    ```python
-      selected_indices = vision.modeling.layers.non_max_suppression_padded(
-          boxes, scores, max_output_size, iou_threshold)
-      selected_boxes = tf.gather(boxes, selected_indices)
-    ```
-
-  See the following documentation for implementation details.
-  third_party/tensorflow_models/official/projects/edgetpu/vision/modeling/g3doc/non_max_suppression.md
-
-  Args:
-    boxes: A 2-D+ float `Tensor` of shape `[...batch_dims, num_boxes, 4]`.
-    scores: A 1-D+ float `Tensor` of shape `[...batch_dims, num_boxes]`
-      representing a single score corresponding to each box (each row of boxes).
-    output_size: A scalar integer `Tensor` representing the maximum number of
-      boxes to be selected by non-max suppression.
-    iou_threshold: A 0-D float tensor representing the threshold for deciding
-      whether boxes overlap too much with respect to IOU.
-
-  Returns:
-    A 1-D+ integer `Tensor` of shape `[...batch_dims, output_size]` representing
-    the selected indices from the boxes tensor and `-1` values for the padding.
-  """
-  # Does partitioning job to help compiler converge with memory.
-  batch_shape = boxes.shape[:-2]
-  batch_size = tf.reduce_prod(batch_shape).numpy()
-  boxes_size, struct_size = boxes.shape[-2:]
-  boxes = tf.reshape(boxes, [batch_size, boxes_size, struct_size])
-  scores = tf.reshape(scores, [batch_size, boxes_size])
-  block = max(1, _RECOMMENDED_NMS_MEMORY // (boxes_size * boxes_size))
-  indices = []
-  for boxes_i, scores_i in shard_tensors(0, block, boxes, scores):
-    indices.append(
-        _non_max_suppression_as_is(boxes_i, scores_i, output_size,
-                                   iou_threshold))
-  indices = tf.concat(indices, axis=0)
-  return tf.reshape(indices, batch_shape + [output_size])
-
-
-def _non_max_suppression_as_is(boxes: tf.Tensor,
-                               scores: tf.Tensor,
-                               output_size: int,
-                               iou_threshold: float = 0.5) -> tf.Tensor:
-  """Selects a subset of boxes which have highest score among IOU-similar boxes.
-
-  Args:
-    boxes: A 2-D+ float `Tensor` of shape `[...batch_dims, num_boxes, 4]`.
-    scores: A 1-D+ float `Tensor` of shape `[...batch_dims, num_boxes]`
-      representing a single score corresponding to each box (each row of boxes).
-    output_size: A scalar integer `Tensor` representing the maximum number of
-      boxes to be selected by non-max suppression.
-    iou_threshold: A 0-D float tensor representing the threshold for deciding
-      whether boxes overlap too much with respect to IOU.
-
-  Returns:
-    A 1-D+ integer `Tensor` of shape `[...batch_dims, output_size]` representing
-    the selected indices from the boxes tensor and `-1` values for the padding.
-  """
-  batch_shape = boxes.shape[:-2]
-  batch_size = tf.reduce_prod(batch_shape).numpy()
-  boxes_size = boxes.shape[-2]
-  if boxes.shape[-1] != 4:
-    raise ValueError(f'Boxes shape ({boxes.shape}) last dimension must be 4 '
-                     'to represent [y1, x1, y2, x2] boxes coordinates')
-  if scores.shape != boxes.shape[:-1]:
-    raise ValueError(f'Boxes shape ({boxes.shape}) and scores shape '
-                     f'({scores.shape}) do not match.')
-  order = tf.range(boxes_size, dtype=tf.float32)
-  relative_order = _tensor_sum_vectors(order, -order)
-  relative_scores = _tensor_sum_vectors(scores, -scores)
-  similar = _greater(_tensor_product_iou(boxes) - iou_threshold)
-  worse = _greater(relative_scores)
-  same_later = _and(_same(relative_scores), _greater(relative_order))
-  similar_worse_or_same_later = _and(similar, _or(worse, same_later))
-  prunable = _reduce_or(similar_worse_or_same_later, axis=-1)
-  remaining = tf.constant(1.) - prunable
-  scores = tf.reshape(tf.exp(scores), [1, 1, batch_size, boxes_size])
-  remaining = tf.reshape(remaining, [1, 1, batch_size, boxes_size])
-  # top_k runs on TPU cores, let it happen, TPU tiles implementation is slower.
-  top_k = tf.math.top_k(scores * remaining, output_size)
-  indices = (
-      tf.cast(top_k.indices, top_k.values.dtype) * _greater(top_k.values) -
-      _same(top_k.values))
-  return tf.reshape(indices, batch_shape + [output_size])
-
-
-def concat_and_top_k(
-    top_k: int, scores_pair: tuple[Optional[tf.Tensor], tf.Tensor],
-    *other_pairs: tuple[Optional[tf.Tensor], tf.Tensor]
-) -> tuple[tf.Tensor, ...]:
-  """Combines shards of top_k operation, when sharded along filtered dimension.
-
-  General idea is that sometimes top_k dimension is very large, while top_k is
-  moderately low. (Keep in mind sample of 15K pre-top_k dimension and 150 top_k)
-  In that case it is possible to break top_k input into groups significantly
-  larger than top_k and significantly smaller than the pre-top_k dimension
-  (Keep in mind 1500).
-  We do top_k over first 1500 elements, then join the 150 remaining with new
-  1500 elements (1650 in total), repeat top_k. This function provides repeatedly
-  used method which will concat and top_k in that case.
-
-  For example with top_k = 2 and scores_pair = ([10, 6], [9, 8, 7]), output
-  scores will be [10, 9].
-
-  Other pairs are filtered using indexes generated from scores. This is a pretty
-  common case of filtering structure by its score.
-
-  For example with one extra pair of box per score:
-  top_k = 2
-  scores_pair = ([10, 6],
-                 [9, 8, 7])
-  other_pairs = [([[0, 0, 10, 10], [0, 0, 6, 6]],
-                  [[1, 1, 9, 9], [1, 1, 8, 8], [1, 1, 7, 7]])]
-  Output is:
-  ([10, 9], [[0, 0, 10, 10], [1, 1, 9, 9]])
-
-  See also 'test_top_k_sharded_fusion' unit test with end to end example.
-
-  Args:
-    top_k: is top_k argument of sharded tf.math.top_k.
-    scores_pair: Tuple (<previous shards combination>, <additional shard>)
-      scores to be aggregated using top_k.
-    *other_pairs: Tuples (<previous shards combination>, <additional shard>)
-      other values to be aggregated using indexes of top_k scores.
-
-  Returns:
-    Tuple of scores based top_k aggregations with additional shards.
-  """
-  scores, scores_shard = scores_pair
-  if other_pairs:
-    others, others_shard = zip(*other_pairs)
-  else:
-    others = others_shard = []
-  # Same as tf.rank, but avoiding tensor form for graph mode execution.
-  top_k_dim: int = len(scores_shard.shape) - 1
-  if scores is None:
-    # First shard becomes aggregation
-    scores = scores_shard
-    others = others_shard
-  else:
-    # Merge shard into aggregation
-    scores = tf.concat([scores, scores_shard], top_k_dim)
-    others = [
-        tf.concat([other, other_shard], top_k_dim)
-        for other, other_shard in zip(others, others_shard)
-    ]
-  # When shards are uneven some will be smaller than requested top_k
-  if scores.shape[top_k_dim] > top_k:
-    scores, indices = tf.nn.top_k(scores, top_k)
-    others = [
-        tf.gather(other, indices, axis=top_k_dim, batch_dims=top_k_dim)
-        for other in others
-    ]
-  return scores, *others
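
Note on `concat_and_top_k`: the docstring example can be traced with a plain-Python analogue. This is only a sketch for 1-D lists, not the TensorFlow code from the commit; `concat_and_top_k_lists` is a hypothetical name, and ties are broken by position, which may differ from `tf.nn.top_k`.

```python
def concat_and_top_k_lists(top_k, scores_pair, *other_pairs):
    """List-based analogue of concat_and_top_k for 1-D data (illustrative)."""
    scores, scores_shard = scores_pair
    # Merge the previous aggregation (None on the first shard) with the shard.
    merged = list(scores or []) + list(scores_shard)
    others = [list(prev or []) + list(shard) for prev, shard in other_pairs]
    # Positions of the top_k largest scores, highest first.
    order = sorted(range(len(merged)), key=lambda i: -merged[i])[:top_k]
    return ([merged[i] for i in order],
            *[[other[i] for i in order] for other in others])


# The example from the docstring: scores plus one box per score.
scores, boxes = concat_and_top_k_lists(
    2,
    ([10, 6], [9, 8, 7]),
    ([[0, 0, 10, 10], [0, 0, 6, 6]],
     [[1, 1, 9, 9], [1, 1, 8, 8], [1, 1, 7, 7]]))

# Folding shards one at a time gives the same result as a direct top_k.
data = [(i * 37) % 101 for i in range(50)]  # 50 distinct values
acc = None
for start in range(0, len(data), 15):
    (acc,) = concat_and_top_k_lists(5, (acc, data[start:start + 15]))
```

The fold never holds more than `top_k` plus one shard of elements at a time, which seems to be the point of sharding along the filtered dimension.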
official/projects/edgetpu/vision/modeling/custom_layers_test.py @ 5938a718
@@ -15,10 +15,8 @@
 """Tests for custom_layers."""

 import itertools
-from typing import Optional

 from absl.testing import parameterized
-import numpy as np
 import tensorflow as tf

 from official.projects.edgetpu.vision.modeling import custom_layers
@@ -186,170 +184,3 @@ class ArgmaxTest(parameterized.TestCase, tf.test.TestCase):
     test_output = custom_layers.argmax(
         random_inputs, axis=axis, output_type=output_type)
     self.assertAllEqual(control_output, test_output)
-
-
-def random_boxes(shape):
-  a = tf.random.uniform(shape=shape + [2])
-  b = tf.random.uniform(shape=shape + [2])
-  l = tf.minimum(a, b)
-  u = tf.maximum(a, b)
-  return tf.concat([l, u], axis=-1)
-
-
-def _maximum_activation_size(model):
-  max_size = 0
-  for layer in model.layers:
-    outputs = layer.output
-    if not isinstance(outputs, list):
-      outputs = [outputs]
-    for output in outputs:
-      if hasattr(output, 'shape'):
-        size = np.prod(output.shape)
-        max_size = max(max_size, size)
-        print('Layer', size, output.shape, layer.name)
-  return max_size
-
-
-class NonMaxSuppressionTest(parameterized.TestCase, tf.test.TestCase):
-
-  def setUp(self):
-    super().setUp()
-    tf.random.set_seed(42)
-
-  @parameterized.parameters((16, 8, 200, 0.009), (31, 17, 100, 0.013),
-                            (71, 41, 100, 0.045), (150, 100, 100, 0.129),
-                            (300, 300, 100, 0.116), (600, 600, 50, 0.176))
-  def test_reference_match(self, n, top, runs, max_deviation):
-    """Checks that the new optimized method is close to the reference method.
-
-    Runs two algorithms with the same sets of input boxes and scores, and
-    measures deviation between returned sets of pruned boxes.
-    Read more about test results at ./g3doc/non_max_suppression.md
-
-    (*) Avoid flakiness with safe boundary (go/python-tips/048): deviation
-    between two sets is a positive number, which may vary from test to test.
-    Doing multiple runs is expected to reduce average deviation variation
-    following LLN theorem. Therefore by having first test run we know upper
-    deviation bound which algorithm would not exceed until broken (in any
-    feasible amount of time in the future). Use of this safe boundary makes
-    test non-flaky.
-
-    Args:
-      n: number of boxes and scores on input of the algorithm.
-      top: limit of output boxes count.
-      runs: for the statistical testing, number of runs to perform to avoid
-        test flakiness.
-      max_deviation: mean limit on deviation between optimized and reference
-        algorithms. Please read notes why this number may be set higher to avoid
-        flaky testing.
-    """
-    deviation_rate = 0
-    min_union = 2 * n
-    boxes = random_boxes([runs, n])
-    scores = tf.random.uniform(shape=[runs, n])
-    test = custom_layers.non_max_suppression_padded(boxes, scores, top)
-    for run in range(runs):
-      reference = tf.image.non_max_suppression(boxes[run], scores[run], top)
-      reference = {*reference.numpy().tolist()}
-      optimized = {*test[run].numpy().astype(int).tolist()} - {-1}
-      union_size = len(optimized | reference)
-      deviation_rate += len(optimized ^ reference) / union_size
-      min_union = min(min_union, union_size)
-    deviation_rate = deviation_rate / runs
-    # six sigma estimate via LLN theorem
-    safe_margin = 6 * (deviation_rate / np.sqrt(runs) + 1 / (runs * min_union))
-    self.assertLess(
-        deviation_rate,
-        max_deviation,
-        msg='Deviation rate between optimized and reference implementations is '
-        'higher than expected. If you are tuning the test, recommended safe '
-        'deviation rate is '
-        f'{deviation_rate} + {safe_margin} = {deviation_rate + safe_margin}')
-
-  @parameterized.parameters(([16], 8), ([91, 150], 100), ([20, 20, 200], 10))
-  def test_sharded_match(self, shape: list[int], top: int):
-    boxes = random_boxes(shape)
-    scores = tf.random.uniform(shape=shape)
-    optimized = custom_layers.non_max_suppression_padded(boxes, scores, top)
-    reference = custom_layers._non_max_suppression_as_is(boxes, scores, top)
-    self.assertAllEqual(optimized, reference)
-
-  _sharded_nms = custom_layers.non_max_suppression_padded
-  _stright_nms = custom_layers._non_max_suppression_as_is
-
-  @parameterized.parameters(([16], 8, _sharded_nms, True),
-                            ([16], 8, _stright_nms, True),
-                            ([91, 150], 100, _sharded_nms, True),
-                            ([91, 150], 100, _stright_nms, False),
-                            ([20, 20, 200], 10, _sharded_nms, True),
-                            ([20, 20, 200], 10, _stright_nms, False))
-  def test_sharded_size(self, shape: list[int], top: int, algorithm,
-                        fits_as_is: bool):
-    scores = tf.keras.Input(shape=shape, batch_size=1)
-    boxes = tf.keras.Input(shape=shape + [4], batch_size=1)
-    optimized = algorithm(boxes, scores, top)
-    model = tf.keras.Model(inputs=[boxes, scores], outputs=optimized)
-    max_size = _maximum_activation_size(model)
-    if fits_as_is:
-      # Sharding done or not needed.
-      self.assertLessEqual(max_size, custom_layers._RECOMMENDED_NMS_MEMORY)
-    else:
-      # Sharding needed.
-      self.assertGreater(max_size, custom_layers._RECOMMENDED_NMS_MEMORY)
-
-  def test_shard_tensors(self):
-    a: tf.Tensor = tf.constant([[0, 1, 2, 3, 4]])
-    b: tf.Tensor = tf.constant([[
-        [0, 1, 2, 3, 4],
-        [5, 6, 7, 8, 9],
-        [10, 11, 12, 13, 14],
-        [15, 16, 17, 18, 19],
-        [20, 21, 22, 23, 24],
-    ]])
-    for i, (a_i, b_i) in enumerate(custom_layers.shard_tensors(1, 3, a, b)):
-      self.assertAllEqual(a_i, a[:, i * 3:i * 3 + 3])
-      self.assertAllEqual(b_i, b[:, i * 3:i * 3 + 3, :])
-
-  def test_top_k_sharded_fusion_arguments_validation(self):
-    # Input scores is not a pair of aggregation and shard.
-    self.assertRaises(ValueError, custom_layers.concat_and_top_k, 100,
-                      tf.zeros(shape=[1000]))
-    # Input other values are not pairs of aggregation and shard.
-    self.assertRaises(TypeError, custom_layers.concat_and_top_k, 100,
-                      (None, tf.zeros(shape=[1000])), None,
-                      tf.zeros(shape=[1000]))
-    # Insufficient rank to do top_k
-    self.assertRaises(IndexError, custom_layers.concat_and_top_k, 100,
-                      (None, tf.constant(1.)))
-
-  @parameterized.parameters(0, 1, 2)
-  def test_top_k_sharded_fusion_vs_top_k_unsharded(self, axis: int):
-    r"""Tests `horizontal` sharding using shard_tensors and concat_and_top_k.
-
-    Will generate and test graph (on diagram 4 shards, in test 7 shards: six
-    of 1500 elements plus a tail of 1000):
-    Input
-    -----
-      |
-    +-------+--------------------------------------------
-    | Split |-----------------------                     \
-    +-------+---                    \                     |
-      |         \                    |                    |
-    +-------+ +--------+ +-------+ +--------+ +-------+ +--------+ +-------+
-    | top k |-| concat |-| top k |-| concat |-| top k |-| concat |-| top k |
-    +-------+ +--------+ +-------+ +--------+ +-------+ +--------+ +-------+
-                                                                       |
-                                                                     Output
-                                                                     ------
-
-    Args:
-      axis: test top_k axis (tensor rank will be axis + 1)
-    """
-    sample: tf.Tensor = tf.random.uniform(
-        shape=axis * [1] + [10000], dtype=tf.float32)
-    top_1000_direct: tf.Tensor = tf.math.top_k(sample, 1000).values
-    top_1000_sharded: Optional[tf.Tensor] = None
-    for (piece,) in custom_layers.shard_tensors(axis, 1500, sample):
-      (top_1000_sharded,) = custom_layers.concat_and_top_k(
-          1000, (top_1000_sharded, piece))
-    self.assertAllEqual(top_1000_direct, top_1000_sharded)
-
-
-if __name__ == '__main__':
-  tf.test.main()
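
Note on `test_reference_match`: the deviation metric and the margin formula can be checked by hand on small index sets. The sketch below uses only the standard library; `deviation` and `safe_margin` are illustrative names, while the formulas are copied from the test body above.

```python
import math


def deviation(optimized, reference):
    """Symmetric difference over union of two index sets (0 means identical)."""
    optimized = set(optimized) - {-1}  # drop the -1 padding
    reference = set(reference)
    union = optimized | reference
    return len(optimized ^ reference) / len(union), len(union)


def safe_margin(mean_deviation, runs, min_union):
    """Margin the test suggests adding when tuning max_deviation."""
    return 6 * (mean_deviation / math.sqrt(runs) + 1 / (runs * min_union))


# Two of three selected boxes agree; one differs on each side.
rate, union_size = deviation([0, 3, 5, -1], [0, 3, 7])
margin = safe_margin(rate, runs=100, min_union=union_size)
```

Here the sets differ in two of four distinct indices, so the rate is 0.5. The margin shrinks as `runs` grows.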
official/vision/modeling/layers/detection_generator.py @ 5938a718
@@ -18,7 +18,7 @@ from typing import Any, Dict, List, Optional, Mapping, Sequence, Tuple
 # Import libraries
 import tensorflow as tf

-from official.projects.edgetpu.vision.modeling import custom_layers
+from official.vision.modeling.layers import edgetpu
 from official.vision.ops import box_ops
 from official.vision.ops import nms
 from official.vision.ops import preprocess_ops
@@ -428,7 +428,7 @@ def _generate_detections_v3(
       boxes, scores, min_score_threshold=pre_nms_score_threshold)

   # EdgeTPU-friendly class-wise NMS, -1 for invalid.
-  indices = custom_layers.non_max_suppression_padded(
+  indices = edgetpu.non_max_suppression_padded(
       boxes,
       scores,
       max_num_detections,
official/vision/modeling/layers/edgetpu.py (new file, 0 → 100644) @ 5938a718
+# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""EdgeTPU oriented layers and tools."""
+from collections.abc import Iterable, Sequence
+from typing import Optional
+
+import numpy as np
+import tensorflow as tf
+
+_or = tf.maximum
+_and = tf.minimum
+_reduce_or = tf.reduce_max
+
+
+def _tensor_sum_vectors(a, b):
+  a = tf.tile(tf.reshape(a, [1, -1, 1, a.shape[-1]]), [1, 1, a.shape[-1], 1])
+  b = tf.tile(tf.reshape(b, [1, -1, a.shape[-1], 1]), [1, 1, 1, a.shape[-1]])
+  return a + b
+
+
+def _tensor_product_iou(boxes):
+  """Computes pairwise IOU.
+
+  Reason to use 4-D tensors is to follow TPU compiler preference.
+
+  Args:
+    boxes: A 2-D float `Tensor` of shape `[num_boxes, 4]`.
+
+  Returns:
+    A 4-D float `Tensor` of shape `[1, 1, num_boxes, num_boxes]` containing
+    pairwise IOU.
+  """
+  boxes_size = boxes.shape[-2]
+  # Code below will do frequent operand broadcasting.
+  # TPU compiler has (empirically) fewer issues broadcasting if
+  # - batch (first) dimension is 1. (Special consideration sharding)
+  # - there are 4 dimensions. (Standard traversal mapping)
+  # - last dimension is not 1. (Structure alignment)
+  tpu_friendly_shape = [1, -1, 1, boxes_size]
+  bottom, left, top, right = (
+      tf.reshape(side, tpu_friendly_shape) for side in tf.split(boxes, 4, -1))
+  height, width = top - bottom, right - left
+  area = height * width
+  area_sum = _tensor_sum_vectors(area, area)
+  bottom_pad, left_pad, top_pad, right_pad = (
+      tf.nn.relu(_tensor_sum_vectors(x, -x))
+      for x in (-bottom, -left, top, right))
+  height_pad, width_pad = bottom_pad + top_pad, left_pad + right_pad
+  intersection = tf.nn.relu(height - height_pad) * tf.nn.relu(width - width_pad)
+  union = area_sum - intersection
+  iou = tf.math.divide(intersection, union + _same(union))
+  return iou
+
+
+def _greater(x):
+  """Avoids non-lowerable layers in boolean comparison.
+
+  A logical operation results in a tensor of boolean type. However, in serving
+  such tensors cannot be cast to values because of NNAPI specs.
+  The `tf.where` operation lowers to a `select` instruction, which does not run
+  well on all generations of EdgeTPUs.
+
+  Args:
+    x: any numeric tensor.
+
+  Returns:
+    Equivalent of
+    tf.where(x > tf.zeros_like(x), tf.ones_like(x), tf.zeros_like(x)).
+  """
+  x_clip = tf.minimum(tf.nn.relu(x), tf.constant(1, dtype=x.dtype))
+  return -tf.math.floor(-x_clip)
+
+
+def _same(x):
+  """Avoids non-lowerable layers in boolean equality.
+
+  A logical operation results in a tensor of boolean type. However, in serving
+  such tensors cannot be cast to values because of NNAPI specs.
+  The `tf.where` operation lowers to a `select` instruction, which does not run
+  well on all generations of EdgeTPUs.
+
+  Args:
+    x: any numeric tensor.
+
+  Returns:
+    Equivalent of
+    tf.where(x == tf.zeros_like(x), tf.ones_like(x), tf.zeros_like(x)).
+  """
+  x_clip = tf.minimum(tf.abs(x), tf.constant(1, dtype=x.dtype))
+  return tf.constant(1, dtype=x.dtype) + tf.math.floor(-x_clip)
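
Note on `_greater` and `_same`: both helpers encode a boolean as the float 0.0 or 1.0 using only clipping, negation and floor, so no boolean tensor or `select` is needed. A scalar sketch with the standard library is shown below; the function names are illustrative and the scalars stand in for tensors.

```python
import math


def greater(x: float) -> float:
    """1.0 if x > 0 else 0.0, built from relu, min, negation and floor."""
    x_clip = min(max(x, 0.0), 1.0)  # relu, then clip to [0, 1]
    return float(-math.floor(-x_clip))  # ceil of a value in [0, 1]


def same(x: float) -> float:
    """1.0 if x == 0 else 0.0."""
    x_clip = min(abs(x), 1.0)  # 0 only when x is exactly 0
    return float(1 + math.floor(-x_clip))


# With 0/1 floats, logical and/or reduce to min/max, matching the
# _and = tf.minimum and _or = tf.maximum aliases in this file.
logical_and = min
logical_or = max
```

Any positive input, however small, clips into (0, 1] and rounds up to 1, which is how `-floor(-x)` acts as a ceiling here.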
+
+
+def shard_tensors(axis: int, block_size: int,
+                  *tensors: tf.Tensor) -> Iterable[Sequence[tf.Tensor]]:
+  """Consistently splits multiple tensors sharding-style.
+
+  Args:
+    axis: axis to be used to split tensors.
+    block_size: block size to split tensors.
+    *tensors: list of tensors.
+
+  Returns:
+    List of shards, each shard has exactly one piece of each input tensor.
+
+  Raises:
+    ValueError: if input tensors have different sizes of the sharded dimension.
+  """
+  for validate_axis in range(axis + 1):
+    consistent_length: int = tensors[0].shape[validate_axis]
+    for tensor in tensors:
+      if tensor.shape[validate_axis] != consistent_length:
+        raise ValueError('Inconsistent shapes in shard_tensors: first is '
+                         f'{tensors[0].shape} and other is {tensor.shape}')
+  batch_size: int = tensors[0].shape[axis]
+  if block_size >= batch_size:
+    return [tensors]
+  else:
+    blocks = batch_size // block_size
+    remainder = batch_size % block_size
+    if remainder:
+      tensor_parts = []
+      for tensor in tensors:
+        shape: tf.TensorShape = tensor.shape
+        body: tf.Tensor = tf.slice(tensor, [0] * len(shape), [
+            size if i != axis else blocks * block_size
+            for i, size in enumerate(shape)
+        ])
+        tail: tf.Tensor = tf.slice(tensor, [
+            0 if i != axis else (blocks * block_size)
+            for i, _ in enumerate(shape)
+        ], [
+            size if i != axis else (size - blocks * block_size)
+            for i, size in enumerate(shape)
+        ])
+        tensor_parts.append(tf.split(body, blocks, axis) + [tail])
+      return zip(*tensor_parts)
+    else:
+      return zip(*[tf.split(tensor, blocks, axis) for tensor in tensors])
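
Note on `shard_tensors`: the contract is that every shard holds one aligned piece of each input, with a shorter tail shard when the length is not a multiple of the block size. A list-based sketch of the same idea follows; `shard_lists` is a hypothetical helper that works on 1-D sequences only.

```python
def shard_lists(block_size, *sequences):
    """Yields tuples holding one aligned slice of every input sequence."""
    length = len(sequences[0])
    if any(len(seq) != length for seq in sequences):
        raise ValueError('Inconsistent lengths in shard_lists')
    for start in range(0, length, block_size):
        # The last slice is shorter when length is not a multiple of block_size.
        yield tuple(seq[start:start + block_size] for seq in sequences)


shards = list(shard_lists(3, [0, 1, 2, 3, 4], ['a', 'b', 'c', 'd', 'e']))
single = list(shard_lists(10, [0, 1, 2], ['a', 'b', 'c']))
```

Five elements with a block size of 3 give one full shard and a tail of two. A block size at least as large as the input gives a single shard.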
+
+
+# TODO(b/258007436): Number is based on existing compiler limitations while
+# running bf16 NMS on edgetpu. Remove manual sharding when the compiler issue
+# is fixed.
+_RECOMMENDED_NMS_MEMORY = 360000
+
+
+def non_max_suppression_padded(boxes: tf.Tensor,
+                               scores: tf.Tensor,
+                               output_size: int,
+                               iou_threshold: float = 0.5) -> tf.Tensor:
+  """Selects a subset of boxes which have highest score among IOU-similar boxes.
+
+  Prunes away boxes that have high intersection-over-union (IOU) overlap
+  with boxes having higher score. Boxes are supplied as `[y1, x1, y2, x2]`,
+  where `(y1, x1)` and `(y2, x2)` are the coordinates of any diagonal pair of
+  box corners. Note that this algorithm is agnostic to the coordinate system.
+  Thus translations or reflections of the coordinate system result in the same
+  boxes being selected by the algorithm. The output of this operation is a
+  set of integers indexing into the input collection of bounding boxes
+  representing the selected boxes.
+
+  Set will be returned padded on the right with `-1` values. The bounding
+  box coordinates corresponding to the selected indices can then be obtained
+  using the `tf.gather` operation. For example:
+    ```python
+      selected_indices = vision.modeling.layers.non_max_suppression_padded(
+          boxes, scores, max_output_size, iou_threshold)
+      selected_boxes = tf.gather(boxes, selected_indices)
+    ```
+
+  See the following documentation for implementation details.
+  third_party/tensorflow_models/official/projects/edgetpu/vision/modeling/g3doc/non_max_suppression.md
+
+  Args:
+    boxes: A 2-D+ float `Tensor` of shape `[...batch_dims, num_boxes, 4]`.
+    scores: A 1-D+ float `Tensor` of shape `[...batch_dims, num_boxes]`
+      representing a single score corresponding to each box (each row of boxes).
+    output_size: A scalar integer `Tensor` representing the maximum number of
+      boxes to be selected by non-max suppression.
+    iou_threshold: A 0-D float tensor representing the threshold for deciding
+      whether boxes overlap too much with respect to IOU.
+
+  Returns:
+    A 1-D+ integer `Tensor` of shape `[...batch_dims, output_size]` representing
+    the selected indices from the boxes tensor and `-1` values for the padding.
+  """
+  # Does partitioning job to help compiler converge with memory.
+  batch_shape = boxes.shape[:-2]
+  batch_size = np.prod(batch_shape, dtype=np.int32)
+  boxes_size, struct_size = boxes.shape[-2:]
+  boxes = tf.reshape(boxes, [batch_size, boxes_size, struct_size])
+  scores = tf.reshape(scores, [batch_size, boxes_size])
+  block = max(1, _RECOMMENDED_NMS_MEMORY // (boxes_size * boxes_size))
+  indices = []
+  for boxes_i, scores_i in shard_tensors(0, block, boxes, scores):
+    indices.append(
+        _non_max_suppression_as_is(boxes_i, scores_i, output_size,
+                                   iou_threshold))
+  indices = tf.concat(indices, axis=0)
+  return tf.reshape(indices, batch_shape + [output_size])
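
Note on the shard size in `non_max_suppression_padded`: each batch row needs a `boxes_size x boxes_size` pairwise tensor, so the number of rows per shard is the memory budget divided by that square, with a floor of one row. The arithmetic for a few sizes is sketched below; `rows_per_shard` is an illustrative name, and the constant is copied from `_RECOMMENDED_NMS_MEMORY`.

```python
RECOMMENDED_NMS_MEMORY = 360000  # same value as _RECOMMENDED_NMS_MEMORY


def rows_per_shard(boxes_size: int) -> int:
    """Batch rows whose pairwise tensors fit the budget (at least one)."""
    return max(1, RECOMMENDED_NMS_MEMORY // (boxes_size * boxes_size))


# Maps boxes_size to (rows per shard, elements in one shard's pairwise tensor).
table = {
    n: (rows_per_shard(n), rows_per_shard(n) * n * n)
    for n in (16, 150, 200, 600, 1000)
}
```

For 150 boxes a shard holds 16 rows and for 200 boxes it holds 9, both exactly at the budget. For 1000 boxes even a single row exceeds the budget, so sharding along the batch axis alone cannot help there.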
def
_non_max_suppression_as_is
(
boxes
:
tf
.
Tensor
,
scores
:
tf
.
Tensor
,
output_size
:
int
,
iou_threshold
:
float
=
0.5
)
->
tf
.
Tensor
:
"""Selects a subset of boxes which have highest score among IOU-similar boxes.
Args:
boxes: A 2-D+ float `Tensor` of shape `[...batch_dims, num_boxes, 4]`.
scores: A 1-D+ float `Tensor` of shape `[...batch_dims, num_boxes]`
representing a single score corresponding to each box (each row of boxes).
output_size: A scalar integer `Tensor` representing the maximum number of
boxes to be selected by non-max suppression.
iou_threshold: A 0-D float tensor representing the threshold for deciding
whether boxes overlap too much with respect to IOU.
Returns:
A 1-D+ integer `Tensor` of shape `[...batch_dims, output_size]` representing
the selected indices from the boxes tensor and `-1` values for the padding.
"""
  batch_shape = boxes.shape[:-2]
  batch_size = np.prod(batch_shape, dtype=np.int32)
  boxes_size = boxes.shape[-2]
  if boxes.shape[-1] != 4:
    raise ValueError(f'Boxes shape ({boxes.shape}) last dimension must be 4 '
                     'to represent [y1, x1, y2, x2] boxes coordinates')
  if scores.shape != boxes.shape[:-1]:
    raise ValueError(f'Boxes shape ({boxes.shape}) and scores shape '
                     f'({scores.shape}) do not match.')
  order = tf.range(boxes_size, dtype=tf.float32)
  relative_order = _tensor_sum_vectors(order, -order)
  relative_scores = _tensor_sum_vectors(scores, -scores)
  similar = _greater(_tensor_product_iou(boxes) - iou_threshold)
  worse = _greater(relative_scores)
  same_later = _and(_same(relative_scores), _greater(relative_order))
  similar_worse_or_same_later = _and(similar, _or(worse, same_later))
  prunable = _reduce_or(similar_worse_or_same_later, axis=-1)
  remaining = tf.constant(1.) - prunable
  scores = tf.reshape(tf.exp(scores), [1, 1, batch_size, boxes_size])
  remaining = tf.reshape(remaining, [1, 1, batch_size, boxes_size])
  # top_k runs on TPU cores, let it happen, TPU tiles implementation is slower.
  top_k = tf.math.top_k(scores * remaining, output_size)
  indices = (
      tf.cast(top_k.indices, top_k.values.dtype) * _greater(top_k.values) -
      _same(top_k.values))
  return tf.reshape(indices, batch_shape + [output_size])
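
The mask logic above prunes in a single pass: a box is dropped when any IOU-similar box has a higher score, or the same score and an earlier position. The pure-Python sketch below mirrors that rule on plain lists so it can be traced by hand. The names `iou` and `one_shot_nms` are hypothetical and not part of `edgetpu.py`, and the tie-break direction (earlier box wins) is an assumption, since the helper `_tensor_sum_vectors` is only partly visible here.

```python
def iou(a, b):
  """Intersection over union of two [y1, x1, y2, x2] boxes."""
  inter_h = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
  inter_w = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
  inter = inter_h * inter_w
  area_a = (a[2] - a[0]) * (a[3] - a[1])
  area_b = (b[2] - b[0]) * (b[3] - b[1])
  union = area_a + area_b - inter
  return inter / union if union > 0 else 0.0


def one_shot_nms(boxes, scores, output_size, iou_threshold=0.5):
  """Hypothetical list-based analogue of the single-pass pruning rule."""
  n = len(boxes)
  remaining = []
  for i in range(n):
    prunable = False
    for j in range(n):
      similar = iou(boxes[i], boxes[j]) > iou_threshold
      worse = scores[i] < scores[j]
      # Assumption: among equal scores the earlier box wins.
      same_later = scores[i] == scores[j] and i > j
      if similar and (worse or same_later):
        prunable = True
    if not prunable:
      remaining.append(i)
  # Stable sort keeps index order among equal scores.
  remaining.sort(key=lambda i: -scores[i])
  top = remaining[:output_size]
  # Pad with -1, as the docstring above describes for the returned indices.
  return top + [-1] * (output_size - len(top))


# A chain of overlapping boxes: 0 overlaps 1, 1 overlaps 2, 0 does not overlap 2.
chain = [[0, 0, 10, 10], [0, 3, 10, 13], [0, 6, 10, 16]]
print(one_shot_nms(chain, [0.9, 0.8, 0.7], 3))  # [0, -1, -1]
```

In the chain example, box 2 is pruned by box 1 even though box 1 is itself pruned. A greedy, sequential NMS would keep box 2 in that situation, which is one way the two approaches can return different sets.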


def concat_and_top_k(
    top_k: int, scores_pair: tuple[Optional[tf.Tensor], tf.Tensor],
    *other_pairs: tuple[Optional[tf.Tensor], tf.Tensor]
) -> tuple[tf.Tensor, ...]:
  """Combines shards of top_k operation, when sharded along filtered dimension.

  General idea is that sometimes top_k dimension is very large, while top_k is
  moderately low. (Keep in mind sample of 15K pre-top_k dimension and 150 top_k)
  In that case it is possible to break top_k input into groups significantly
  larger than top_k and significantly lower than pre-top_k (Keep in mind 1500).
  We do top_k over first 1500 elements, then join 150 remaining with new 1500
  elements (1650 in total), repeat top_k. This function provides repeatedly used
  method which will concat and top_k in that case.

  For example with top_k = 2 and scores_pair = ([10, 6], [9, 8, 7]), output
  scores will be [10, 9].

  Other pairs are filtered using indexes generated from scores. This is a pretty
  common case of filtering structure by its score.

  For example with one extra pair of box per score:
    top_k = 2
    scores_pair = ([10, 6],
                   [9, 8, 7])
    other_pairs = [([[0, 0, 10, 10], [0, 0, 6, 6]],
                    [[1, 1, 9, 9], [1, 1, 8, 8], [1, 1, 7, 7]])]
  Output is:
    ([10, 9], [[0, 0, 10, 10], [1, 1, 9, 9]])

  See also 'test_top_k_sharded_fusion' unit test with end to end example.

  Args:
    top_k: is top_k argument of sharded tf.math.top_k.
    scores_pair: Tuple (<previous shards combination>, <additional shard>)
      scores to be aggregated using top_k.
    *other_pairs: Tuples (<previous shards combination>, <additional shard>)
      other values to be aggregated using indexes of top_k scores.

  Returns:
    Tuple of scores based top_k aggregations with additional shards.
  """
  scores, scores_shard = scores_pair
  if other_pairs:
    others, others_shard = zip(*other_pairs)
  else:
    others = others_shard = []
  # Same as tf.rank, but avoiding tensor form for graph mode execution.
  top_k_dim: int = len(scores_shard.shape) - 1
  if scores is None:
    # First shard becomes aggregation
    scores = scores_shard
    others = others_shard
  else:
    # Merge shard into aggregation
    scores = tf.concat([scores, scores_shard], top_k_dim)
    others = [
        tf.concat([other, other_shard], top_k_dim)
        for other, other_shard in zip(others, others_shard)
    ]
  # When shards are uneven some will be smaller than requested top_k
  if scores.shape[top_k_dim] > top_k:
    scores, indices = tf.nn.top_k(scores, top_k)
    others = [
        tf.gather(other, indices, axis=top_k_dim, batch_dims=top_k_dim)
        for other in others
    ]
  return scores, *others
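
The docstring example can be reproduced without TensorFlow. The sketch below is a 1-D, list-based analogue of the concat-then-top_k step. The function name `concat_and_top_k_lists` is hypothetical, and it handles only flat lists rather than batched tensors.

```python
def concat_and_top_k_lists(top_k, scores_pair, *other_pairs):
  """Hypothetical 1-D analogue of concat_and_top_k for plain lists."""
  scores, scores_shard = scores_pair
  if scores is None:
    # First shard becomes the aggregation.
    scores = list(scores_shard)
    others = [list(shard) for _, shard in other_pairs]
  else:
    # Merge the new shard into the aggregation.
    scores = list(scores) + list(scores_shard)
    others = [list(prev) + list(shard) for prev, shard in other_pairs]
  # Only filter when there is more data than requested.
  if len(scores) > top_k:
    order = sorted(range(len(scores)), key=lambda i: -scores[i])[:top_k]
    scores = [scores[i] for i in order]
    others = [[other[i] for i in order] for other in others]
  return (scores, *others)


result = concat_and_top_k_lists(
    2,
    ([10, 6], [9, 8, 7]),
    ([[0, 0, 10, 10], [0, 0, 6, 6]],
     [[1, 1, 9, 9], [1, 1, 8, 8], [1, 1, 7, 7]]))
print(result)  # ([10, 9], [[0, 0, 10, 10], [1, 1, 9, 9]])

# Streaming use: fold shards of 4 elements into a running top 3.
data = [5, 1, 9, 3, 7, 2, 8, 6, 4, 0]
aggregate = None
for start in range(0, len(data), 4):
  (aggregate,) = concat_and_top_k_lists(3, (aggregate, data[start:start + 4]))
print(aggregate)  # [9, 8, 7]
```

Each step works on at most `top_k` plus one shard of elements, which is what keeps the intermediate size bounded.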
official/vision/modeling/layers/edgetpu_test.py
0 → 100644
View file @ 5938a718
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests EdgeTPU oriented layers and tools."""

from typing import Optional

from absl.testing import parameterized
import numpy as np
import tensorflow as tf

from official.vision.modeling.layers import edgetpu


def random_boxes(shape):
  a = tf.random.uniform(shape=shape + [2])
  b = tf.random.uniform(shape=shape + [2])
  l = tf.minimum(a, b)
  u = tf.maximum(a, b)
  return tf.concat([l, u], axis=-1)


def _maximum_activation_size(model):
  max_size = 0
  for layer in model.layers:
    outputs = layer.output
    if not isinstance(outputs, list):
      outputs = [outputs]
    for output in outputs:
      if hasattr(output, 'shape'):
        size = np.prod(output.shape)
        max_size = max(max_size, size)
        print('Layer', size, output.shape, layer.name)
  return max_size


class NonMaxSuppressionTest(parameterized.TestCase, tf.test.TestCase):

  def setUp(self):
    super().setUp()
    tf.random.set_seed(42)

  @parameterized.parameters((16, 8, 200, 0.009), (31, 17, 100, 0.013),
                            (71, 41, 100, 0.045), (150, 100, 100, 0.129),
                            (300, 300, 100, 0.116), (600, 600, 50, 0.176))
  def test_reference_match(self, n, top, runs, max_deviation):
    """Checks that the new optimized method is close to the reference method.

    Runs both algorithms on the same sets of input boxes and scores, and
    measures the deviation between the returned sets of pruned boxes.
    Read more about test results at ./g3doc/non_max_suppression.md

    (*) Avoid flakiness with safe boundary (go/python-tips/048): deviation
    between two sets is a non-negative number, which may vary from test to test.
    Doing multiple runs is expected to reduce the variation of the average
    deviation, following the law of large numbers (LLN). Therefore, after a
    first test run we know an upper deviation bound which the algorithm would
    not exceed until broken (in any feasible amount of time in the future). Use
    of this safe boundary makes the test non-flaky.

    Args:
      n: number of boxes and scores on input of the algorithm.
      top: limit of output boxes count.
      runs: number of runs to perform for the statistical testing, to avoid
        test flakiness.
      max_deviation: mean limit on deviation between optimized and reference
        algorithms. Please read the note (*) above on why this number may be
        set higher to avoid flaky testing.
    """
    deviation_rate = 0
    min_union = 2 * n
    boxes = random_boxes([runs, n])
    scores = tf.random.uniform(shape=[runs, n])
    test = edgetpu.non_max_suppression_padded(boxes, scores, top)
    for run in range(runs):
      reference = tf.image.non_max_suppression(boxes[run], scores[run], top)
      reference = {*reference.numpy().tolist()}
      optimized = {*test[run].numpy().astype(int).tolist()} - {-1}
      union_size = len(optimized | reference)
      deviation_rate += len(optimized ^ reference) / union_size
      min_union = min(min_union, union_size)
    deviation_rate = deviation_rate / runs
    # six sigma estimate via LLN theorem
    safe_margin = 6 * (deviation_rate / np.sqrt(runs) + 1 / (runs * min_union))
    self.assertLess(
        deviation_rate,
        max_deviation,
        msg='Deviation rate between optimized and reference implementations is '
        'higher than expected. If you are tuning the test, recommended safe '
        'deviation rate is '
        f'{deviation_rate} + {safe_margin} = {deviation_rate + safe_margin}')
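
The deviation measured above is the size of the symmetric difference between the two index sets divided by the size of their union, averaged over runs. A minimal standalone sketch of that arithmetic follows. The helper names `set_deviation` and `mean_deviation_and_margin` are hypothetical, and the input index lists are made up for illustration only.

```python
import math


def set_deviation(optimized, reference):
  """Returns (|A ^ B| / |A | B|, |A | B|) after dropping -1 padding."""
  optimized = set(optimized) - {-1}
  reference = set(reference)
  union = optimized | reference
  return len(optimized ^ reference) / len(union), len(union)


def mean_deviation_and_margin(pairs, n):
  """Averages deviation over runs and applies the margin formula from the test."""
  runs = len(pairs)
  total = 0.0
  min_union = 2 * n
  for optimized, reference in pairs:
    rate, union_size = set_deviation(optimized, reference)
    total += rate
    min_union = min(min_union, union_size)
  mean = total / runs
  safe_margin = 6 * (mean / math.sqrt(runs) + 1 / (runs * min_union))
  return mean, safe_margin


# Illustrative inputs only: one run that disagrees on one box, one that agrees.
pairs = [([0, 2, -1], [0, 1]), ([0, 1], [0, 1])]
mean, margin = mean_deviation_and_margin(pairs, n=3)
print(mean, margin)
```

With these inputs the first run contributes 2/3 (two differing indices out of a union of three) and the second contributes 0, so the mean is 1/3.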

  @parameterized.parameters(([16], 8), ([91, 150], 100), ([20, 20, 200], 10))
  def test_sharded_match(self, shape: list[int], top: int):
    boxes = random_boxes(shape)
    scores = tf.random.uniform(shape=shape)
    optimized = edgetpu.non_max_suppression_padded(boxes, scores, top)
    reference = edgetpu._non_max_suppression_as_is(boxes, scores, top)
    self.assertAllEqual(optimized, reference)

  _sharded_nms = edgetpu.non_max_suppression_padded
  _stright_nms = edgetpu._non_max_suppression_as_is

  @parameterized.parameters(([16], 8, _sharded_nms, True),
                            ([16], 8, _stright_nms, True),
                            ([91, 150], 100, _sharded_nms, True),
                            ([91, 150], 100, _stright_nms, False),
                            ([20, 20, 200], 10, _sharded_nms, True),
                            ([20, 20, 200], 10, _stright_nms, False))
  def test_sharded_size(self, shape: list[int], top: int, algorithm,
                        fits_as_is: bool):
    scores = tf.keras.Input(shape=shape, batch_size=1)
    boxes = tf.keras.Input(shape=shape + [4], batch_size=1)
    optimized = algorithm(boxes, scores, top)
    model = tf.keras.Model(inputs=[boxes, scores], outputs=optimized)
    max_size = _maximum_activation_size(model)
    if fits_as_is:
      # Sharding done or not needed.
      self.assertLessEqual(max_size, edgetpu._RECOMMENDED_NMS_MEMORY)
    else:
      # Sharding needed.
      self.assertGreater(max_size, edgetpu._RECOMMENDED_NMS_MEMORY)

  def test_shard_tensors(self):
    a: tf.Tensor = tf.constant([[0, 1, 2, 3, 4]])
    b: tf.Tensor = tf.constant([[
        [0, 1, 2, 3, 4],
        [5, 6, 7, 8, 9],
        [10, 11, 12, 13, 14],
        [15, 16, 17, 18, 19],
        [20, 21, 22, 23, 24],
    ]])
    for i, (a_i, b_i) in enumerate(edgetpu.shard_tensors(1, 3, a, b)):
      self.assertAllEqual(a_i, a[:, i * 3:i * 3 + 3])
      self.assertAllEqual(b_i, b[:, i * 3:i * 3 + 3, :])
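
The slices checked above suggest that sharding cuts every tensor into consecutive blocks along one axis and yields the matching blocks together, with a shorter final block when the axis length is not a multiple of the block size. The sketch below shows that idea for flat Python lists. The generator `shard_lists` is hypothetical, and the handling of the last partial block is an assumption inferred from the test rather than from the `shard_tensors` source.

```python
def shard_lists(block_size, *sequences):
  """Yields tuples of aligned slices, each at most block_size long."""
  length = len(sequences[0])
  for start in range(0, length, block_size):
    yield tuple(seq[start:start + block_size] for seq in sequences)


scores = [0, 1, 2, 3, 4]
labels = ['a', 'b', 'c', 'd', 'e']
shards = list(shard_lists(3, scores, labels))
print(shards)  # [([0, 1, 2], ['a', 'b', 'c']), ([3, 4], ['d', 'e'])]
```

Yielding aligned tuples lets a caller unpack each shard as `(scores_i, labels_i)` and keep related values in step.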

  def test_top_k_sharded_fusion_arguments_validation(self):
    # Input scores is not pair of aggregation and shard.
    self.assertRaises(ValueError, edgetpu.concat_and_top_k, 100,
                      tf.zeros(shape=[1000]))
    # Input other values is not pairs of aggregation and shard.
    self.assertRaises(TypeError, edgetpu.concat_and_top_k, 100,
                      (None, tf.zeros(shape=[1000])), None,
                      tf.zeros(shape=[1000]))
    # Insufficient rank to do top_k
    self.assertRaises(IndexError, edgetpu.concat_and_top_k, 100,
                      (None, tf.constant(1.)))

  @parameterized.parameters(0, 1, 2)
  def test_top_k_sharded_fusion_vs_top_k_unsharded(self, axis: int):
    r"""Tests `horizontal` sharding using shard_tensors and concat_and_top_k.

    Will generate and test graph (on diagram 4 shards, in test 6 shards):
    Input
    -----
      |
    +-------+--------------------------------------------
    | Split |-----------------------                     \
    +-------+---                    \                    |
      |         \                   |                    |
    +-------+ +--------+ +-------+ +--------+ +-------+ +--------+ +-------+
    | top k |-| concat |-| top k |-| concat |-| top k |-| concat |-| top k |
    +-------+ +--------+ +-------+ +--------+ +-------+ +--------+ +-------+
                                                                      |
                                                                    Output
                                                                    ------

    Args:
      axis: test top_k axis (tensor rank will be axis + 1)
    """
    sample: tf.Tensor = tf.random.uniform(
        shape=axis * [1] + [10000], dtype=tf.float32)
    top_1000_direct: tf.Tensor = tf.math.top_k(sample, 1000).values
    top_1000_sharded: Optional[tf.Tensor] = None
    for (piece,) in edgetpu.shard_tensors(axis, 1500, sample):
      (top_1000_sharded,) = edgetpu.concat_and_top_k(
          1000, (top_1000_sharded, piece))
    self.assertAllEqual(top_1000_direct, top_1000_sharded)


if __name__ == '__main__':
  tf.test.main()