提交 9d2a7242 编写于 作者: V Vighnesh Birodkar 提交者: TF Object Detection Team

Support for hourglass-10,20,32 and 52 and function to compute hourglass depth.

PiperOrigin-RevId: 330992374
上级 9b8b13e8
......@@ -372,8 +372,76 @@ class HourglassNetwork(tf.keras.Model):
return self.num_hourglasses
def _layer_depth(layer):
"""Compute depth of Conv/Residual blocks or lists of them."""
if isinstance(layer, list):
return sum([_layer_depth(l) for l in layer])
elif isinstance(layer, ConvolutionalBlock):
return 1
elif isinstance(layer, ResidualBlock):
return 2
else:
raise ValueError('Unknown layer - {}'.format(layer))
def _encoder_decoder_depth(network):
"""Helper function to compute depth of encoder-decoder blocks."""
encoder_block2_layers = _layer_depth(network.encoder_block2)
decoder_block_layers = _layer_depth(network.decoder_block)
if isinstance(network.inner_block[0], EncoderDecoderBlock):
assert len(network.inner_block) == 1, 'Inner block is expected as length 1.'
inner_block_layers = _encoder_decoder_depth(network.inner_block[0])
return inner_block_layers + encoder_block2_layers + decoder_block_layers
elif isinstance(network.inner_block[0], ResidualBlock):
return (encoder_block2_layers + decoder_block_layers +
_layer_depth(network.inner_block))
else:
raise ValueError('Unknown inner block type.')
def hourglass_depth(network):
"""Helper function to verify depth of hourglass backbone."""
input_conv_layers = 3 # 1 ResidualBlock and 1 ConvBlock
# Only intermediate_conv2 and intermediate_residual are applied before
# sending inputs to the later stages.
intermediate_layers = (
_layer_depth(network.intermediate_conv2) +
_layer_depth(network.intermediate_residual)
)
# network.output_conv is applied before sending input to the later stages
output_layers = _layer_depth(network.output_conv)
encoder_decoder_layers = sum(_encoder_decoder_depth(net) for net in
network.hourglass_network)
return (input_conv_layers + encoder_decoder_layers + intermediate_layers
+ output_layers)
def hourglass_104():
"""The Hourglass-104 backbone."""
"""The Hourglass-104 backbone.
The architecture parameters are taken from [1].
Returns:
network: An HourglassNetwork object implementing the Hourglass-104
backbone.
[1]: https://arxiv.org/abs/1904.07850
"""
return HourglassNetwork(
channel_dims=[128, 256, 256, 384, 384, 384, 512],
......@@ -381,3 +449,32 @@ def hourglass_104():
num_stages=5,
blocks_per_stage=[2, 2, 2, 2, 2, 4],
)
def single_stage_hourglass(blocks_per_stage, num_channels):
nc = num_channels
channel_dims = [nc, nc * 2, nc * 2, nc * 3, nc * 3, nc * 3, nc * 4]
num_stages = len(blocks_per_stage) - 1
channel_dims = channel_dims[:num_stages + 2]
return HourglassNetwork(
channel_dims=channel_dims,
num_hourglasses=1,
num_stages=num_stages,
blocks_per_stage=blocks_per_stage,
)
def hourglass_10(num_channels):
return single_stage_hourglass([1, 1], num_channels)
def hourglass_20(num_channels):
return single_stage_hourglass([1, 1, 1, 2], num_channels)
def hourglass_32(num_channels):
return single_stage_hourglass([1, 1, 2, 2, 2], num_channels)
def hourglass_52(num_channels):
return single_stage_hourglass([2, 2, 2, 2, 2, 4], num_channels)
......@@ -96,5 +96,30 @@ class HourglassFeatureExtractorTest(tf.test.TestCase, parameterized.TestCase):
self.assertEqual(outputs[1].shape, (2, 16, 16, 6))
@unittest.skipIf(tf_version.is_tf1(), 'Skipping TF2.X only test.')
class HourglassDepthTest(tf.test.TestCase):
def test_hourglass_104(self):
net = hourglass.hourglass_104()
self.assertEqual(hourglass.hourglass_depth(net), 104)
def test_hourglass_10(self):
net = hourglass.hourglass_10(2)
self.assertEqual(hourglass.hourglass_depth(net), 10)
def test_hourglass_20(self):
net = hourglass.hourglass_20(2)
self.assertEqual(hourglass.hourglass_depth(net), 20)
def test_hourglass_32(self):
net = hourglass.hourglass_32(2)
self.assertEqual(hourglass.hourglass_depth(net), 32)
def test_hourglass_52(self):
net = hourglass.hourglass_52(2)
self.assertEqual(hourglass.hourglass_depth(net), 52)
if __name__ == '__main__':
tf.test.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册