Commit 986ffac4 authored by Yeqing Li, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 282065024
Parent a9387332
@@ -24,6 +24,7 @@
 # Note that we need the trailing `/` to avoid an incorrect match.
 # [1]: https://github.com/facebookresearch/Detectron/blob/master/detectron/core/config.py#L198
 RESNET50_FROZEN_VAR_PREFIX = r'(resnet\d+/)conv2d(|_([1-9]|10))\/'
+RESNET_FROZEN_VAR_PREFIX = r'(resnet\d+)\/(conv2d(|_([1-9]|10))|batch_normalization(|_([1-9]|10)))\/'
 # pylint: disable=line-too-long
@@ -38,6 +39,7 @@ RETINANET_CFG = {
     'optimizer': {
         'type': 'momentum',
         'momentum': 0.9,
+        'nesterov': False,
     },
     'learning_rate': {
         'type': 'step',
@@ -56,6 +58,7 @@ RETINANET_CFG = {
         # TODO(b/142174042): Support transpose_input option.
         'transpose_input': False,
         'l2_weight_decay': 0.0001,
+        'input_sharding': False,
     },
     'eval': {
         'batch_size': 8,
@@ -65,6 +68,7 @@ RETINANET_CFG = {
         'type': 'box',
         'val_json_file': '',
         'eval_file_pattern': '',
+        'input_sharding': True,
     },
     'predict': {
         'predict_batch_size': 8,
@@ -165,7 +169,8 @@ RETINANET_CFG = {
         'num_classes': 91,
         'max_total_size': 100,
         'nms_iou_threshold': 0.5,
-        'score_threshold': 0.05
+        'score_threshold': 0.05,
+        'pre_nms_num_boxes': 5000,
     },
     'enable_summary': False,
 }
......
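Note: the new RESNET_FROZEN_VAR_PREFIX extends the conv-only RESNET50 prefix so that the matching batch-normalization variables are frozen as well. A minimal sketch of how such a prefix regex partitions a variable list (the variable names below are hypothetical):

    import re

    RESNET_FROZEN_VAR_PREFIX = (
        r'(resnet\d+)\/(conv2d(|_([1-9]|10))'
        r'|batch_normalization(|_([1-9]|10)))\/')

    # Hypothetical variable names, for illustration only.
    names = [
        'resnet50/conv2d/kernel:0',
        'resnet50/conv2d_10/kernel:0',
        'resnet50/conv2d_11/kernel:0',            # not matched: beyond _10
        'resnet50/batch_normalization_3/gamma:0',
        'retinanet_head/class-0/kernel:0',        # not matched: head stays trainable
    ]
    frozen = [n for n in names if re.match(RESNET_FROZEN_VAR_PREFIX, n)]
    trainable = [n for n in names if not re.match(RESNET_FROZEN_VAR_PREFIX, n)]
    print(frozen)  # first, second and fourth names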
@@ -58,6 +58,15 @@ class InputFn(object):
     self._parser_fn = factory.parser_generator(params, mode)
     self._dataset_fn = tf.data.TFRecordDataset
 
+    self._input_sharding = (not self._is_training)
+    try:
+      if self._is_training:
+        self._input_sharding = params.train.input_sharding
+      else:
+        self._input_sharding = params.eval.input_sharding
+    except KeyError:
+      pass
+
   def __call__(self, ctx=None, batch_size: int = None):
     """Provides tf.data.Dataset object.
@@ -74,7 +83,7 @@ class InputFn(object):
     dataset = tf.data.Dataset.list_files(
         self._file_pattern, shuffle=self._is_training)
-    if ctx and ctx.num_input_pipelines > 1:
+    if self._input_sharding and ctx and ctx.num_input_pipelines > 1:
       dataset = dataset.shard(ctx.num_input_pipelines, ctx.input_pipeline_id)
     if self._is_training:
       dataset = dataset.repeat()
@@ -82,6 +91,7 @@ class InputFn(object):
     dataset = dataset.interleave(
         map_func=lambda file_name: self._dataset_fn(file_name), cycle_length=32,
         num_parallel_calls=tf.data.experimental.AUTOTUNE)
+    dataset = dataset.cache()
     if self._is_training:
       dataset = dataset.shuffle(64)
......
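Note: file sharding is now opt-in per mode; the config above defaults it to False for training and True for eval, and the guard only shards when more than one input pipeline exists. A minimal sketch of the resulting pipeline shape, with a hypothetical file pattern and cache() placed before repeat() so the cache can fill on the first pass:

    import tensorflow as tf

    def make_dataset(file_pattern, is_training, input_sharding,
                     ctx: tf.distribute.InputContext = None):
      dataset = tf.data.Dataset.list_files(file_pattern, shuffle=is_training)
      if input_sharding and ctx and ctx.num_input_pipelines > 1:
        # Each host keeps every num_input_pipelines-th file, so hosts
        # read disjoint file subsets instead of all reading everything.
        dataset = dataset.shard(ctx.num_input_pipelines, ctx.input_pipeline_id)
      dataset = dataset.interleave(
          tf.data.TFRecordDataset, cycle_length=32,
          num_parallel_calls=tf.data.experimental.AUTOTUNE)
      # cache() keeps the records in memory after the first full pass, so
      # later epochs skip file I/O; only safe if the shard fits in memory.
      dataset = dataset.cache()
      if is_training:
        dataset = dataset.shuffle(64).repeat()
      return dataset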
@@ -58,6 +58,13 @@ class DetectionDistributedExecutor(executor.DistributedExecutor):
           trainable_variables)
     logging.info('Filter trainable variables from %d to %d',
                  len(model.trainable_variables), len(trainable_variables))
+    _update_state = lambda labels, outputs: None
+    if isinstance(metric, tf.keras.metrics.Metric):
+      _update_state = lambda labels, outputs: metric.update_state(
+          labels, outputs)
+    else:
+      logging.error('Detection: train metric is not an instance of '
+                    'tf.keras.metrics.Metric.')
 
     def _replicated_step(inputs):
       """Replicated training step."""
@@ -71,11 +78,7 @@ class DetectionDistributedExecutor(executor.DistributedExecutor):
           v = tf.reduce_mean(v) / strategy.num_replicas_in_sync
           losses[k] = v
         loss = losses['total_loss']
-        if isinstance(metric, tf.keras.metrics.Metric):
-          metric.update_state(labels, outputs)
-        else:
-          logging.error('train metric is not an instance of '
-                        'tf.keras.metrics.Metric.')
+        _update_state(labels, outputs)
       grads = tape.gradient(loss, trainable_variables)
       optimizer.apply_gradients(zip(grads, trainable_variables))
......
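Note: the change hoists the isinstance check out of _replicated_step, so the check and any error log run once when the step is built, and the traced step body contains only the resolved callable. A small self-contained sketch of the pattern (the Mean metric is a stand-in):

    import tensorflow as tf

    metric = tf.keras.metrics.Mean('loss')  # stand-in metric

    # Resolve the update callable once, outside the traced step.
    if isinstance(metric, tf.keras.metrics.Metric):
      _update_state = lambda value: metric.update_state(value)
    else:
      _update_state = lambda value: None  # no-op fallback

    @tf.function
    def step(value):
      _update_state(value)  # the chosen lambda is baked into the trace
      return value * 2.0

    step(tf.constant(1.0))
    print(metric.result().numpy())  # 1.0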
@@ -36,8 +36,15 @@ class OptimizerFactory(object):
   def __init__(self, params):
     """Creates an optimizer based on the specified flags."""
     if params.type == 'momentum':
+      nesterov = False
+      try:
+        nesterov = params.nesterov
+      except KeyError:
+        pass
       self._optimizer = functools.partial(
-          tf.keras.optimizers.SGD, momentum=0.9, nesterov=True)
+          tf.keras.optimizers.SGD,
+          momentum=params.momentum,
+          nesterov=nesterov)
     elif params.type == 'adam':
       self._optimizer = tf.keras.optimizers.Adam
     elif params.type == 'adadelta':
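Note: momentum and nesterov now come from the config instead of being hard-coded (the old code silently forced nesterov=True). A sketch of the equivalent behavior with a plain dict standing in for the params object:

    import functools
    import tensorflow as tf

    params = {'type': 'momentum', 'momentum': 0.9}  # hypothetical config

    # Same default as the try/except above: a missing key means False.
    nesterov = params.get('nesterov', False)
    optimizer_fn = functools.partial(
        tf.keras.optimizers.SGD,
        momentum=params['momentum'],
        nesterov=nesterov)

    optimizer = optimizer_fn(learning_rate=0.08)
    print(optimizer.get_config()['nesterov'])  # False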
@@ -133,11 +140,10 @@ class Model(object):
     """
     return _make_filter_trainable_variables_fn(self._frozen_variable_prefix)
 
-  def weight_decay_loss(self, l2_weight_decay, keras_model):
-    # TODO(yeqing): Correct the filter according to cr/269707763.
+  def weight_decay_loss(self, l2_weight_decay, trainable_variables):
     return l2_weight_decay * tf.add_n([
         tf.nn.l2_loss(v)
         for v in trainable_variables
-        for v in self._keras_model.trainable_variables
         if 'batch_normalization' not in v.name and 'bias' not in v.name
     ])
......
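Note: weight_decay_loss() now takes the already-filtered variable list instead of reaching into the Keras model, and keeps excluding batch-norm parameters and biases from decay. A minimal numeric check of that filter (the variables are hypothetical stand-ins):

    import tensorflow as tf

    variables = [
        tf.Variable(tf.ones([3, 3]), name='conv2d/kernel'),
        tf.Variable(tf.zeros([3]), name='conv2d/bias'),
        tf.Variable(tf.ones([3]), name='batch_normalization/gamma'),
    ]

    def weight_decay_loss(l2_weight_decay, trainable_variables):
      return l2_weight_decay * tf.add_n([
          tf.nn.l2_loss(v)
          for v in trainable_variables
          if 'batch_normalization' not in v.name and 'bias' not in v.name
      ])

    # Only conv2d/kernel contributes: 1e-4 * (9 * 1**2) / 2 = 0.00045.
    print(weight_decay_loss(1e-4, variables).numpy())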
@@ -40,7 +40,8 @@ def generate_detections_factory(params):
       _generate_detections,
       max_total_size=params.max_total_size,
       nms_iou_threshold=params.nms_iou_threshold,
-      score_threshold=params.score_threshold)
+      score_threshold=params.score_threshold,
+      pre_nms_num_boxes=params.pre_nms_num_boxes)
   return func
......
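Note: the factory now forwards pre_nms_num_boxes, which caps how many top-scoring boxes per image enter NMS and so bounds NMS cost. A sketch of what such a cap does (the shapes and scores are hypothetical):

    import tensorflow as tf

    scores = tf.random.uniform([2, 9000])  # [batch, num_boxes]
    pre_nms_num_boxes = 5000

    # Keep only the top-k scores (and their box indices) per image;
    # these feed NMS instead of the full box set.
    top_scores, top_indices = tf.nn.top_k(scores, k=pre_nms_num_boxes)
    print(top_scores.shape)  # (2, 5000)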
@@ -120,6 +120,9 @@ class RetinanetModel(base_model.Model):
     if self._keras_model is None:
       raise ValueError('build_loss_fn() must be called after build_model().')
 
+    filter_fn = self.make_filter_trainable_variables_fn()
+    trainable_variables = filter_fn(self._keras_model.trainable_variables)
+
     def _total_loss_fn(labels, outputs):
       cls_loss = self._cls_loss_fn(outputs['cls_outputs'],
                                    labels['cls_targets'],
@@ -129,7 +132,7 @@ class RetinanetModel(base_model.Model):
                                    labels['num_positives'])
       model_loss = cls_loss + self._box_loss_weight * box_loss
       l2_regularization_loss = self.weight_decay_loss(self._l2_weight_decay,
-                                                      self._keras_model)
+                                                      trainable_variables)
       total_loss = model_loss + l2_regularization_loss
       return {
           'total_loss': total_loss,
......
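Note: the trainable variables are filtered once, before _total_loss_fn is defined, and the closure captures the filtered list, so frozen variables contribute neither gradients nor weight decay. A compact sketch of that pattern (the prefix and variables are hypothetical):

    import re
    import tensorflow as tf

    frozen_prefix = r'resnet\d+/conv2d(|_[1-9])/'  # hypothetical prefix
    variables = [
        tf.Variable(tf.ones([2, 2]), name='resnet50/conv2d/kernel'),
        tf.Variable(tf.ones([2, 2]), name='retinanet_head/kernel'),
    ]
    # Filter once, outside the loss closure.
    trainable_variables = [
        v for v in variables if not re.match(frozen_prefix, v.name)]

    def total_loss_fn(model_loss, l2_weight_decay=1e-4):
      l2 = l2_weight_decay * tf.add_n(
          [tf.nn.l2_loss(v) for v in trainable_variables])
      return model_loss + l2

    # Only the head variable is decayed: 1.0 + 1e-4 * (4 / 2) = 1.0002.
    print(total_loss_fn(tf.constant(1.0)).numpy())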