提交 3bac1426 编写于 作者: V Vighnesh Birodkar 提交者: TF Object Detection Team

Fixes and tests for hourglass variants.

PiperOrigin-RevId: 331166835
上级 643d492b
......@@ -226,7 +226,12 @@ def _make_repeated_residual_blocks(out_channels, num_blocks,
residual_channels = out_channels
for i in range(num_blocks - 1):
# Only use the stride at the first block so we don't repeatedly downsample
# the input
stride = initial_stride if i == 0 else 1
# If the stide is more than 1, we cannot use an identity layer for the
# skip connection and are forced to use a conv for the skip connection.
skip_conv = stride > 1
blocks.append(
......@@ -234,8 +239,18 @@ def _make_repeated_residual_blocks(out_channels, num_blocks,
skip_conv=skip_conv)
)
skip_conv = residual_channels != out_channels
blocks.append(ResidualBlock(out_channels=out_channels, skip_conv=skip_conv))
if num_blocks == 1:
# If there is only 1 block, the for loop above is not run,
# therefore we honor the requested stride in the last residual block
stride = initial_stride
# We are forced to use a conv in the skip connection if stride > 1
skip_conv = stride > 1
else:
stride = 1
skip_conv = residual_channels != out_channels
blocks.append(ResidualBlock(out_channels=out_channels, skip_conv=skip_conv,
stride=stride))
return blocks
......@@ -494,7 +509,7 @@ def hourglass_104():
)
def single_stage_hourglass(blocks_per_stage, num_channels):
def single_stage_hourglass(blocks_per_stage, num_channels, downsample=True):
nc = num_channels
channel_dims = [nc, nc * 2, nc * 2, nc * 3, nc * 3, nc * 3, nc * 4]
num_stages = len(blocks_per_stage) - 1
......@@ -504,20 +519,21 @@ def single_stage_hourglass(blocks_per_stage, num_channels):
num_hourglasses=1,
num_stages=num_stages,
blocks_per_stage=blocks_per_stage,
downsample=downsample
)
def hourglass_10(num_channels):
return single_stage_hourglass([1, 1], num_channels)
def hourglass_10(num_channels, downsample=True):
return single_stage_hourglass([1, 1], num_channels, downsample)
def hourglass_20(num_channels):
return single_stage_hourglass([1, 1, 1, 2], num_channels)
def hourglass_20(num_channels, downsample=True):
return single_stage_hourglass([1, 2, 2], num_channels, downsample)
def hourglass_32(num_channels):
return single_stage_hourglass([1, 1, 2, 2, 2], num_channels)
def hourglass_32(num_channels, downsample=True):
return single_stage_hourglass([2, 2, 2, 2], num_channels, downsample)
def hourglass_52(num_channels):
return single_stage_hourglass([2, 2, 2, 2, 2, 4], num_channels)
def hourglass_52(num_channels, downsample=True):
return single_stage_hourglass([2, 2, 2, 2, 2, 4], num_channels, downsample)
......@@ -111,21 +111,34 @@ class HourglassDepthTest(tf.test.TestCase):
self.assertEqual(hourglass.hourglass_depth(net), 104)
def test_hourglass_10(self):
net = hourglass.hourglass_10(2)
net = hourglass.hourglass_10(2, downsample=False)
self.assertEqual(hourglass.hourglass_depth(net), 10)
outputs = net(tf.zeros((2, 32, 32, 3)))
self.assertEqual(outputs[0].shape, (2, 32, 32, 4))
def test_hourglass_20(self):
net = hourglass.hourglass_20(2)
net = hourglass.hourglass_20(2, downsample=False)
self.assertEqual(hourglass.hourglass_depth(net), 20)
outputs = net(tf.zeros((2, 32, 32, 3)))
self.assertEqual(outputs[0].shape, (2, 32, 32, 4))
def test_hourglass_32(self):
net = hourglass.hourglass_32(2)
net = hourglass.hourglass_32(2, downsample=False)
self.assertEqual(hourglass.hourglass_depth(net), 32)
outputs = net(tf.zeros((2, 32, 32, 3)))
self.assertEqual(outputs[0].shape, (2, 32, 32, 4))
def test_hourglass_52(self):
net = hourglass.hourglass_52(2)
net = hourglass.hourglass_52(2, downsample=False)
self.assertEqual(hourglass.hourglass_depth(net), 52)
outputs = net(tf.zeros((2, 32, 32, 3)))
self.assertEqual(outputs[0].shape, (2, 32, 32, 4))
if __name__ == '__main__':
tf.test.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册