提交 a8518117 编写于 作者: V Vighnesh Birodkar 提交者: TF Object Detection Team

Make downsampling optional in hourglass.

PiperOrigin-RevId: 331013782
上级 9d2a7242
......@@ -174,6 +174,36 @@ class InputDownsampleBlock(tf.keras.layers.Layer):
return self.residual_block(self.conv_block(inputs))
class InputConvBlock(tf.keras.layers.Layer):
"""Block for the initial feature convolution.
This block is used in the hourglass network when we don't want to downsample
the input.
"""
def __init__(self, out_channels_initial_conv, out_channels_residual_block):
"""Initializes the downsample block.
Args:
out_channels_initial_conv: int, the desired number of output channels
in the initial conv layer.
out_channels_residual_block: int, the desired number of output channels
in the underlying residual block.
"""
super(InputConvBlock, self).__init__()
# TODO(vighneshb) explore if 3x3 works here.
self.conv_block = ConvolutionalBlock(
kernel_size=7, out_channels=out_channels_initial_conv, stride=1,
padding='valid')
self.residual_block = ResidualBlock(
out_channels=out_channels_residual_block, stride=1, skip_conv=True)
def call(self, inputs):
return self.residual_block(self.conv_block(inputs))
def _make_repeated_residual_blocks(out_channels, num_blocks,
initial_stride=1, residual_channels=None):
"""Stack Residual blocks one after the other.
......@@ -285,7 +315,7 @@ class HourglassNetwork(tf.keras.Model):
"""The hourglass network."""
def __init__(self, num_stages, channel_dims, blocks_per_stage,
num_hourglasses):
num_hourglasses, downsample=True):
"""Intializes the feature extractor.
Args:
......@@ -300,15 +330,24 @@ class HourglassNetwork(tf.keras.Model):
stage in the hourglass network
num_hourglasses: int, number of hourglas networks to stack
sequentially.
downsample: bool, if set, downsamples the input by a factor of 4 before
applying the rest of the network.
"""
super(HourglassNetwork, self).__init__()
self.num_hourglasses = num_hourglasses
self.downsample_input = InputDownsampleBlock(
out_channels_initial_conv=channel_dims[0],
out_channels_residual_block=channel_dims[1]
)
self.downsample = downsample
if downsample:
self.downsample_input = InputDownsampleBlock(
out_channels_initial_conv=channel_dims[0],
out_channels_residual_block=channel_dims[1]
)
else:
self.conv_input = InputConvBlock(
out_channels_initial_conv=channel_dims[0],
out_channels_residual_block=channel_dims[1]
)
self.hourglass_network = []
self.output_conv = []
......@@ -343,7 +382,11 @@ class HourglassNetwork(tf.keras.Model):
def call(self, inputs):
inputs = self.downsample_input(inputs)
if self.downsample:
inputs = self.downsample_input(inputs)
else:
inputs = self.conv_input(inputs)
outputs = []
for i in range(self.num_hourglasses):
......
......@@ -78,6 +78,12 @@ class HourglassFeatureExtractorTest(tf.test.TestCase, parameterized.TestCase):
output = layer(np.zeros((2, 32, 32, 8), dtype=np.float32))
self.assertEqual(output.shape, (2, 8, 8, 8))
def test_input_conv_block(self):
layer = hourglass.InputConvBlock(
out_channels_initial_conv=4, out_channels_residual_block=8)
output = layer(np.zeros((2, 32, 32, 8), dtype=np.float32))
self.assertEqual(output.shape, (2, 32, 32, 8))
def test_encoder_decoder_block(self):
layer = hourglass.EncoderDecoderBlock(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册