Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
曾经的那一瞬间
Models
提交
3bac1426
M
Models
项目概览
曾经的那一瞬间
/
Models
大约 1 年 前同步成功
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
Models
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
3bac1426
编写于
9月 11, 2020
作者:
V
Vighnesh Birodkar
提交者:
TF Object Detection Team
9月 11, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fixes and tests for hourglass variants.
PiperOrigin-RevId: 331166835
上级
643d492b
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
44 addition
and
15 deletion
+44
-15
research/object_detection/models/keras_models/hourglass_network.py
...object_detection/models/keras_models/hourglass_network.py
+27
-11
research/object_detection/models/keras_models/hourglass_network_tf2_test.py
...tection/models/keras_models/hourglass_network_tf2_test.py
+17
-4
未找到文件。
research/object_detection/models/keras_models/hourglass_network.py
浏览文件 @
3bac1426
...
...
@@ -226,7 +226,12 @@ def _make_repeated_residual_blocks(out_channels, num_blocks,
residual_channels
=
out_channels
for
i
in
range
(
num_blocks
-
1
):
# Only use the stride at the first block so we don't repeatedly downsample
# the input
stride
=
initial_stride
if
i
==
0
else
1
# If the stide is more than 1, we cannot use an identity layer for the
# skip connection and are forced to use a conv for the skip connection.
skip_conv
=
stride
>
1
blocks
.
append
(
...
...
@@ -234,8 +239,18 @@ def _make_repeated_residual_blocks(out_channels, num_blocks,
skip_conv
=
skip_conv
)
)
skip_conv
=
residual_channels
!=
out_channels
blocks
.
append
(
ResidualBlock
(
out_channels
=
out_channels
,
skip_conv
=
skip_conv
))
if
num_blocks
==
1
:
# If there is only 1 block, the for loop above is not run,
# therefore we honor the requested stride in the last residual block
stride
=
initial_stride
# We are forced to use a conv in the skip connection if stride > 1
skip_conv
=
stride
>
1
else
:
stride
=
1
skip_conv
=
residual_channels
!=
out_channels
blocks
.
append
(
ResidualBlock
(
out_channels
=
out_channels
,
skip_conv
=
skip_conv
,
stride
=
stride
))
return
blocks
...
...
@@ -494,7 +509,7 @@ def hourglass_104():
)
def
single_stage_hourglass
(
blocks_per_stage
,
num_channels
):
def
single_stage_hourglass
(
blocks_per_stage
,
num_channels
,
downsample
=
True
):
nc
=
num_channels
channel_dims
=
[
nc
,
nc
*
2
,
nc
*
2
,
nc
*
3
,
nc
*
3
,
nc
*
3
,
nc
*
4
]
num_stages
=
len
(
blocks_per_stage
)
-
1
...
...
@@ -504,20 +519,21 @@ def single_stage_hourglass(blocks_per_stage, num_channels):
num_hourglasses
=
1
,
num_stages
=
num_stages
,
blocks_per_stage
=
blocks_per_stage
,
downsample
=
downsample
)
def
hourglass_10
(
num_channels
):
return
single_stage_hourglass
([
1
,
1
],
num_channels
)
def
hourglass_10
(
num_channels
,
downsample
=
True
):
return
single_stage_hourglass
([
1
,
1
],
num_channels
,
downsample
)
def
hourglass_20
(
num_channels
):
return
single_stage_hourglass
([
1
,
1
,
1
,
2
],
num_channels
)
def
hourglass_20
(
num_channels
,
downsample
=
True
):
return
single_stage_hourglass
([
1
,
2
,
2
],
num_channels
,
downsample
)
def
hourglass_32
(
num_channels
):
return
single_stage_hourglass
([
1
,
1
,
2
,
2
,
2
],
num_channels
)
def
hourglass_32
(
num_channels
,
downsample
=
True
):
return
single_stage_hourglass
([
2
,
2
,
2
,
2
],
num_channels
,
downsample
)
def
hourglass_52
(
num_channels
):
return
single_stage_hourglass
([
2
,
2
,
2
,
2
,
2
,
4
],
num_channels
)
def
hourglass_52
(
num_channels
,
downsample
=
True
):
return
single_stage_hourglass
([
2
,
2
,
2
,
2
,
2
,
4
],
num_channels
,
downsample
)
research/object_detection/models/keras_models/hourglass_network_tf2_test.py
浏览文件 @
3bac1426
...
...
@@ -111,21 +111,34 @@ class HourglassDepthTest(tf.test.TestCase):
self
.
assertEqual
(
hourglass
.
hourglass_depth
(
net
),
104
)
def
test_hourglass_10
(
self
):
net
=
hourglass
.
hourglass_10
(
2
)
net
=
hourglass
.
hourglass_10
(
2
,
downsample
=
False
)
self
.
assertEqual
(
hourglass
.
hourglass_depth
(
net
),
10
)
outputs
=
net
(
tf
.
zeros
((
2
,
32
,
32
,
3
)))
self
.
assertEqual
(
outputs
[
0
].
shape
,
(
2
,
32
,
32
,
4
))
def
test_hourglass_20
(
self
):
net
=
hourglass
.
hourglass_20
(
2
)
net
=
hourglass
.
hourglass_20
(
2
,
downsample
=
False
)
self
.
assertEqual
(
hourglass
.
hourglass_depth
(
net
),
20
)
outputs
=
net
(
tf
.
zeros
((
2
,
32
,
32
,
3
)))
self
.
assertEqual
(
outputs
[
0
].
shape
,
(
2
,
32
,
32
,
4
))
def
test_hourglass_32
(
self
):
net
=
hourglass
.
hourglass_32
(
2
)
net
=
hourglass
.
hourglass_32
(
2
,
downsample
=
False
)
self
.
assertEqual
(
hourglass
.
hourglass_depth
(
net
),
32
)
outputs
=
net
(
tf
.
zeros
((
2
,
32
,
32
,
3
)))
self
.
assertEqual
(
outputs
[
0
].
shape
,
(
2
,
32
,
32
,
4
))
def
test_hourglass_52
(
self
):
net
=
hourglass
.
hourglass_52
(
2
)
net
=
hourglass
.
hourglass_52
(
2
,
downsample
=
False
)
self
.
assertEqual
(
hourglass
.
hourglass_depth
(
net
),
52
)
outputs
=
net
(
tf
.
zeros
((
2
,
32
,
32
,
3
)))
self
.
assertEqual
(
outputs
[
0
].
shape
,
(
2
,
32
,
32
,
4
))
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录