Commit c5662b16
Authored on Nov 28, 2022 by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 491544318
Parent: f1e746c2

Showing 2 changed files with 127 additions and 9 deletions.
official/projects/edgetpu/vision/modeling/custom_layers.py       (+77, -6)
official/projects/edgetpu/vision/modeling/custom_layers_test.py  (+50, -3)
official/projects/edgetpu/vision/modeling/custom_layers.py
@@ -14,8 +14,10 @@
 """Customized keras layers used in the EdgeTPU models."""

+from collections.abc import Iterable, MutableMapping, Sequence
 import inspect
-from typing import Any, Iterable, MutableMapping, Optional, Sequence, Tuple, Union
+from typing import Any, Optional, Union

 import tensorflow as tf

 from official.modeling import tf_utils
@@ -26,12 +28,12 @@ class GroupConv2D(tf.keras.layers.Conv2D):
   def __init__(self,
                filters: int,
-               kernel_size: Union[int, Tuple[int, int]],
+               kernel_size: Union[int, tuple[int, int]],
                groups: int,
-               strides: Tuple[int, int] = (1, 1),
+               strides: tuple[int, int] = (1, 1),
                padding: str = 'valid',
                data_format: str = 'channels_last',
-               dilation_rate: Tuple[int, int] = (1, 1),
+               dilation_rate: tuple[int, int] = (1, 1),
                activation: Any = None,
                use_bias: bool = True,
                kernel_initializer: Any = 'glorot_uniform',
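The replacements above (and in the hunks that follow) swap `typing.Tuple` and `typing.List` for the builtin `tuple` and `list` generics; under PEP 585 (Python 3.9+) builtin containers are directly subscriptable in annotations, so the extra `typing` imports can be dropped. A minimal sketch of the annotation style, assuming nothing beyond the diff (the `pair` helper is hypothetical, not part of this commit):

```python
from typing import Union


def pair(kernel_size: Union[int, tuple[int, int]]) -> tuple[int, int]:
  """Normalizes an int or (h, w) kernel size to an explicit (h, w) pair."""
  if isinstance(kernel_size, int):
    return (kernel_size, kernel_size)
  return tuple(kernel_size)


print(pair(3))       # (3, 3)
print(pair((3, 5)))  # (3, 5)
```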
@@ -149,7 +151,7 @@ class GroupConv2D(tf.keras.layers.Conv2D):
         groups=1,
         **kwargs)  # pytype: disable=bad-return-type  # typed-keras

-  def build(self, input_shape: Tuple[int, ...]) -> None:
+  def build(self, input_shape: tuple[int, ...]) -> None:
     """Builds GroupConv2D layer as a collection of smaller Conv2D layers."""
     input_shape = tf.TensorShape(input_shape)
     input_channel = self._get_input_channel(input_shape)
@@ -271,7 +273,7 @@ class GroupConv2DKerasModel(tf.keras.Model):
   def __init__(self,
                filters: int,
-               kernel_size: Tuple[int, int],
+               kernel_size: tuple[int, int],
                groups: int,
                batch_norm_layer: Optional[tf.keras.layers.Layer] = None,
                bn_epsilon: float = 1e-3,
@@ -715,3 +717,72 @@ def _non_max_suppression_as_is(boxes: tf.Tensor,
       tf.cast(top_k.indices, top_k.values.dtype) * _greater(top_k.values) -
       _same(top_k.values))
   return tf.reshape(indices, batch_shape + [output_size])
+
+
+def concat_and_top_k(
+    top_k: int, scores_pair: tuple[Optional[tf.Tensor], tf.Tensor],
+    *other_pairs: tuple[Optional[tf.Tensor], tf.Tensor]
+) -> tuple[tf.Tensor, ...]:
+  """Combines shards of a top_k operation sharded along the filtered dimension.
+
+  The general idea is that sometimes the pre-top_k dimension is very large,
+  while top_k is moderately low (keep in mind a sample with a 15K pre-top_k
+  dimension and 150 top_k). In that case it is possible to break the top_k
+  input into groups significantly larger than top_k and significantly smaller
+  than the pre-top_k dimension (keep in mind 1500). We do top_k over the first
+  1500 elements, then join the 150 survivors with the next 1500 elements
+  (1650 in total) and repeat top_k. This function provides the repeatedly used
+  step which concatenates and applies top_k in that case.
+
+  For example, with top_k = 2 and scores_pair = ([10, 6], [9, 8, 7]), the
+  output scores will be [10, 9].
+
+  Other pairs are filtered using the indices generated from the scores. This
+  is a pretty common case of filtering a structure by its score. For example,
+  with one extra pair of boxes per score:
+
+  top_k = 2
+  scores_pair = ([10, 6],
+                 [9, 8, 7])
+  other_pairs = [([[0, 0, 10, 10], [0, 0, 6, 6]],
+                  [[1, 1, 9, 9], [1, 1, 8, 8], [1, 1, 7, 7]])]
+
+  Output is:
+  ([10, 9], [[0, 0, 10, 10], [1, 1, 9, 9]])
+
+  See also the 'test_top_k_sharded_fusion' unit test with an end to end
+  example.
+
+  Args:
+    top_k: the top_k argument of the sharded tf.math.top_k.
+    scores_pair: tuple (<previous shards combination>, <additional shard>) of
+      scores to be aggregated using top_k.
+    *other_pairs: tuples (<previous shards combination>, <additional shard>)
+      of other values to be aggregated using the indices of the top_k scores.
+
+  Returns:
+    Tuple of score-based top_k aggregations with the additional shards merged.
+  """
+  scores, scores_shard = scores_pair
+  if other_pairs:
+    others, others_shard = zip(*other_pairs)
+  else:
+    others = others_shard = []
+  # Same as tf.rank, but avoiding tensor form for graph mode execution.
+  top_k_dim: int = len(scores_shard.shape) - 1
+  if scores is None:
+    # First shard becomes the aggregation.
+    scores = scores_shard
+    others = others_shard
+  else:
+    # Merge the shard into the aggregation.
+    scores = tf.concat([scores, scores_shard], top_k_dim)
+    others = [
+        tf.concat([other, other_shard], top_k_dim)
+        for other, other_shard in zip(others, others_shard)
+    ]
+  # When shards are uneven, some will be smaller than the requested top_k.
+  if scores.shape[top_k_dim] > top_k:
+    scores, indices = tf.nn.top_k(scores, top_k)
+    others = [
+        tf.gather(other, indices, axis=top_k_dim, batch_dims=top_k_dim)
+        for other in others
+    ]
+  return scores, *others
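The merge step this new function implements can be sanity-checked against its docstring example with a small pure-Python, 1-D sketch (no TensorFlow; `concat_and_top_k_1d` is a hypothetical stand-in written for illustration, not the library function):

```python
def concat_and_top_k_1d(top_k, scores_pair, *other_pairs):
  """1-D, pure-Python sketch of the sharded top_k merge step."""
  scores, shard = scores_pair
  # Concatenate the running aggregation (None on the first shard) with the
  # new shard; do the same for every companion pair.
  scores = list(shard) if scores is None else list(scores) + list(shard)
  others = [list(s) if agg is None else list(agg) + list(s)
            for agg, s in other_pairs]
  if len(scores) > top_k:
    # Keep the top_k scores and gather companions by the same indices.
    idx = sorted(range(len(scores)), key=lambda i: -scores[i])[:top_k]
    scores = [scores[i] for i in idx]
    others = [[o[i] for i in idx] for o in others]
  return (scores, *others)


# The docstring example: merge aggregated scores [10, 6] with shard [9, 8, 7]
# (and the matching boxes), keeping the top 2.
scores, boxes = concat_and_top_k_1d(
    2,
    ([10, 6], [9, 8, 7]),
    ([[0, 0, 10, 10], [0, 0, 6, 6]],
     [[1, 1, 9, 9], [1, 1, 8, 8], [1, 1, 7, 7]]))
print(scores)  # [10, 9]
print(boxes)   # [[0, 0, 10, 10], [1, 1, 9, 9]]
```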
official/projects/edgetpu/vision/modeling/custom_layers_test.py
@@ -15,7 +15,7 @@
 """Tests for custom_layers."""

 import itertools
-from typing import List
+from typing import Optional

 from absl.testing import parameterized
 import numpy as np
@@ -212,6 +212,10 @@ def _maximum_activation_size(model):
 class NonMaxSuppressionTest(parameterized.TestCase, tf.test.TestCase):

+  def setUp(self):
+    super().setUp()
+    tf.random.set_seed(42)
+
   @parameterized.parameters((16, 8, 200, 0.009), (31, 17, 100, 0.013),
                             (71, 41, 100, 0.045), (150, 100, 100, 0.129),
                             (300, 300, 100, 0.116), (600, 600, 50, 0.176))
@@ -261,7 +265,7 @@ class NonMaxSuppressionTest(parameterized.TestCase, tf.test.TestCase):
         f'{deviation_rate} + {safe_margin} = {deviation_rate + safe_margin}')

   @parameterized.parameters(([16], 8), ([91, 150], 100), ([20, 20, 200], 10))
-  def test_sharded_match(self, shape: List[int], top: int):
+  def test_sharded_match(self, shape: list[int], top: int):
     boxes = random_boxes(shape)
     scores = tf.random.uniform(shape=shape)
     optimized = custom_layers.non_max_suppression_padded(boxes, scores, top)
@@ -277,7 +281,7 @@ class NonMaxSuppressionTest(parameterized.TestCase, tf.test.TestCase):
                             ([91, 150], 100, _stright_nms, False),
                             ([20, 20, 200], 10, _sharded_nms, True),
                             ([20, 20, 200], 10, _stright_nms, False))
-  def test_sharded_size(self, shape: List[int], top: int, algorithm,
+  def test_sharded_size(self, shape: list[int], top: int, algorithm,
                         fits_as_is: bool):
     scores = tf.keras.Input(shape=shape, batch_size=1)
     boxes = tf.keras.Input(shape=shape + [4], batch_size=1)
@@ -304,5 +308,48 @@ class NonMaxSuppressionTest(parameterized.TestCase, tf.test.TestCase):
self
.
assertAllEqual
(
a_i
,
a
[:,
i
*
3
:
i
*
3
+
3
])
self
.
assertAllEqual
(
b_i
,
b
[:,
i
*
3
:
i
*
3
+
3
,
:])
def
test_top_k_sharded_fusion_arguments_validation
(
self
):
# Input scores is not pair of agregation and shard.
self
.
assertRaises
(
ValueError
,
custom_layers
.
concat_and_top_k
,
100
,
tf
.
zeros
(
shape
=
[
1000
]))
# Input other values is not pairs of agregation and shard.
self
.
assertRaises
(
TypeError
,
custom_layers
.
concat_and_top_k
,
100
,
(
None
,
tf
.
zeros
(
shape
=
[
1000
])),
None
,
tf
.
zeros
(
shape
=
[
1000
]))
# Insufficient rank to do top_k
self
.
assertRaises
(
IndexError
,
custom_layers
.
concat_and_top_k
,
100
,
(
None
,
tf
.
constant
(
1.
)))
@
parameterized
.
parameters
(
0
,
1
,
2
)
def
test_top_k_sharded_fusion_vs_top_k_unsharded
(
self
,
axis
:
int
):
r
"""Tests `horizontal` sharding using shard_tensors and concat_and_top_k.
Will generate and test graph (on diagram 4 shards, in test 6 shards):
Input
-----
|
+-------+--------------------------------------------
| Split |----------------------- \
+-------+--- \ |
| \ | |
+-------+ +--------+ +-------+ +--------+ +-------+ +--------+ +-------+
| top k |-| concat |-| top k |-| concat |-| top k |-| concat |-| top k |
+-------+ +--------+ +-------+ +--------+ +-------+ +--------+ +-------+
|
Output
------
Args:
axis: test top_k axis (tensor rank will be axis + 1)
"""
sample
:
tf
.
Tensor
=
tf
.
random
.
uniform
(
shape
=
axis
*
[
1
]
+
[
10000
],
dtype
=
tf
.
float32
)
top_1000_direct
:
tf
.
Tensor
=
tf
.
math
.
top_k
(
sample
,
1000
).
values
top_1000_sharded
:
Optional
[
tf
.
Tensor
]
=
None
for
(
piece
,)
in
custom_layers
.
shard_tensors
(
axis
,
1500
,
sample
):
(
top_1000_sharded
,)
=
custom_layers
.
concat_and_top_k
(
1000
,
(
top_1000_sharded
,
piece
))
self
.
assertAllEqual
(
top_1000_direct
,
top_1000_sharded
)
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
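The property the new parameterized test asserts, that repeatedly merging 1500-element shards and re-taking the top 1000 matches a single top_k over the full input, holds because a true top-k element, once seen, can never be evicted: fewer than k larger values exist in the whole input. A TensorFlow-free sketch of the same check (`top_k_1d` and `sharded_top_k_1d` are illustrative helpers, not part of the commit):

```python
import random


def top_k_1d(values, k):
  # The k largest values in descending order (mimics tf.math.top_k(...).values).
  return sorted(values, reverse=True)[:k]


def sharded_top_k_1d(values, k, shard_size):
  # Fold shards into a running aggregation, re-selecting the top k whenever
  # the merged list exceeds k elements.
  agg = None
  for start in range(0, len(values), shard_size):
    piece = values[start:start + shard_size]
    merged = piece if agg is None else agg + piece
    agg = top_k_1d(merged, k) if len(merged) > k else merged
  return agg


random.seed(42)
sample = [random.random() for _ in range(10000)]
assert sharded_top_k_1d(sample, 1000, 1500) == top_k_1d(sample, 1000)
print('sharded == direct')
```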