提交 4b617781 编写于 作者: A A. Unique TensorFlower 提交者: TF Object Detection Team

INTERNAL CHANGE.

PiperOrigin-RevId: 397358689
上级 b656833a
......@@ -1182,13 +1182,24 @@ def _build_center_net_feature_extractor(feature_extractor_config, is_training):
list(feature_extractor_config.channel_stds),
'bgr_ordering':
feature_extractor_config.bgr_ordering,
'depth_multiplier':
feature_extractor_config.depth_multiplier,
'use_separable_conv':
use_separable_conv,
'upsampling_interpolation':
feature_extractor_config.upsampling_interpolation,
}
if feature_extractor_config.HasField('depth_multiplier'):
kwargs.update({
'depth_multiplier': feature_extractor_config.depth_multiplier,
})
if feature_extractor_config.HasField('use_separable_conv'):
kwargs.update({
'use_separable_conv': use_separable_conv,
})
if feature_extractor_config.HasField('upsampling_interpolation'):
kwargs.update({
'upsampling_interpolation':
feature_extractor_config.upsampling_interpolation,
})
if feature_extractor_config.HasField('use_depthwise'):
kwargs.update({
'use_depthwise': feature_extractor_config.use_depthwise,
})
return CENTER_NET_EXTRACTOR_FUNCTION_MAP[feature_extractor_config.type](
......
......@@ -476,7 +476,7 @@ class ModelBuilderTF2Test(
num_classes: 10
feature_extractor {
type: "mobilenet_v2_fpn"
depth_multiplier: 1.0
depth_multiplier: 2.0
use_separable_conv: true
upsampling_interpolation: "bilinear"
}
......@@ -513,6 +513,21 @@ class ModelBuilderTF2Test(
# Verify that there are up_sampling2d layers.
self.assertGreater(num_up_sampling2d_layers, 0)
# Verify that the FPN ops uses separable conv.
for layer in fpn.layers:
# Convolution layers with kernel size not equal to (1, 1) should be
# separable 2D convolutions.
if 'conv' in layer.name and layer.kernel_size != (1, 1):
self.assertIsInstance(layer, tf.keras.layers.SeparableConv2D)
# Verify that the backbone indeed double the number of channel according to
# the depthmultiplier.
backbone = feature_extractor.get_layer('model')
first_conv = backbone.get_layer('Conv1')
# Note that the first layer typically has 32 filters, but this model has
# a depth multiplier of 2.
self.assertEqual(64, first_conv.filters)
def test_create_center_net_deepmac(self):
"""Test building a CenterNet DeepMAC model."""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册