Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
曾经的那一瞬间
Models
提交
a8518117
M
Models
项目概览
曾经的那一瞬间
/
Models
大约 1 年 前同步成功
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
Models
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
a8518117
编写于
9月 10, 2020
作者:
V
Vighnesh Birodkar
提交者:
TF Object Detection Team
9月 10, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Make downsampling optional in hourglass.
PiperOrigin-RevId: 331013782
上级
9d2a7242
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
55 addition
and
6 deletion
+55
-6
research/object_detection/models/keras_models/hourglass_network.py
...object_detection/models/keras_models/hourglass_network.py
+49
-6
research/object_detection/models/keras_models/hourglass_network_tf2_test.py
...tection/models/keras_models/hourglass_network_tf2_test.py
+6
-0
未找到文件。
research/object_detection/models/keras_models/hourglass_network.py
浏览文件 @
a8518117
...
...
@@ -174,6 +174,36 @@ class InputDownsampleBlock(tf.keras.layers.Layer):
return
self
.
residual_block
(
self
.
conv_block
(
inputs
))
class
InputConvBlock
(
tf
.
keras
.
layers
.
Layer
):
"""Block for the initial feature convolution.
This block is used in the hourglass network when we don't want to downsample
the input.
"""
def
__init__
(
self
,
out_channels_initial_conv
,
out_channels_residual_block
):
"""Initializes the downsample block.
Args:
out_channels_initial_conv: int, the desired number of output channels
in the initial conv layer.
out_channels_residual_block: int, the desired number of output channels
in the underlying residual block.
"""
super
(
InputConvBlock
,
self
).
__init__
()
# TODO(vighneshb) explore if 3x3 works here.
self
.
conv_block
=
ConvolutionalBlock
(
kernel_size
=
7
,
out_channels
=
out_channels_initial_conv
,
stride
=
1
,
padding
=
'valid'
)
self
.
residual_block
=
ResidualBlock
(
out_channels
=
out_channels_residual_block
,
stride
=
1
,
skip_conv
=
True
)
def
call
(
self
,
inputs
):
return
self
.
residual_block
(
self
.
conv_block
(
inputs
))
def
_make_repeated_residual_blocks
(
out_channels
,
num_blocks
,
initial_stride
=
1
,
residual_channels
=
None
):
"""Stack Residual blocks one after the other.
...
...
@@ -285,7 +315,7 @@ class HourglassNetwork(tf.keras.Model):
"""The hourglass network."""
def
__init__
(
self
,
num_stages
,
channel_dims
,
blocks_per_stage
,
num_hourglasses
):
num_hourglasses
,
downsample
=
True
):
"""Intializes the feature extractor.
Args:
...
...
@@ -300,15 +330,24 @@ class HourglassNetwork(tf.keras.Model):
stage in the hourglass network
num_hourglasses: int, number of hourglas networks to stack
sequentially.
downsample: bool, if set, downsamples the input by a factor of 4 before
applying the rest of the network.
"""
super
(
HourglassNetwork
,
self
).
__init__
()
self
.
num_hourglasses
=
num_hourglasses
self
.
downsample_input
=
InputDownsampleBlock
(
out_channels_initial_conv
=
channel_dims
[
0
],
out_channels_residual_block
=
channel_dims
[
1
]
)
self
.
downsample
=
downsample
if
downsample
:
self
.
downsample_input
=
InputDownsampleBlock
(
out_channels_initial_conv
=
channel_dims
[
0
],
out_channels_residual_block
=
channel_dims
[
1
]
)
else
:
self
.
conv_input
=
InputConvBlock
(
out_channels_initial_conv
=
channel_dims
[
0
],
out_channels_residual_block
=
channel_dims
[
1
]
)
self
.
hourglass_network
=
[]
self
.
output_conv
=
[]
...
...
@@ -343,7 +382,11 @@ class HourglassNetwork(tf.keras.Model):
def
call
(
self
,
inputs
):
inputs
=
self
.
downsample_input
(
inputs
)
if
self
.
downsample
:
inputs
=
self
.
downsample_input
(
inputs
)
else
:
inputs
=
self
.
conv_input
(
inputs
)
outputs
=
[]
for
i
in
range
(
self
.
num_hourglasses
):
...
...
research/object_detection/models/keras_models/hourglass_network_tf2_test.py
浏览文件 @
a8518117
...
...
@@ -78,6 +78,12 @@ class HourglassFeatureExtractorTest(tf.test.TestCase, parameterized.TestCase):
output
=
layer
(
np
.
zeros
((
2
,
32
,
32
,
8
),
dtype
=
np
.
float32
))
self
.
assertEqual
(
output
.
shape
,
(
2
,
8
,
8
,
8
))
def
test_input_conv_block
(
self
):
layer
=
hourglass
.
InputConvBlock
(
out_channels_initial_conv
=
4
,
out_channels_residual_block
=
8
)
output
=
layer
(
np
.
zeros
((
2
,
32
,
32
,
8
),
dtype
=
np
.
float32
))
self
.
assertEqual
(
output
.
shape
,
(
2
,
32
,
32
,
8
))
def
test_encoder_decoder_block
(
self
):
layer
=
hourglass
.
EncoderDecoderBlock
(
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录