提交 986ffac4 编写于 作者: Y Yeqing Li 提交者: A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 282065024
上级 a9387332
......@@ -24,6 +24,7 @@
# Note that we need to trailing `/` to avoid the incorrect match.
# [1]: https://github.com/facebookresearch/Detectron/blob/master/detectron/core/config.py#L198
RESNET50_FROZEN_VAR_PREFIX = r'(resnet\d+/)conv2d(|_([1-9]|10))\/'
RESNET_FROZEN_VAR_PREFIX = r'(resnet\d+)\/(conv2d(|_([1-9]|10))|batch_normalization(|_([1-9]|10)))\/'
# pylint: disable=line-too-long
......@@ -38,6 +39,7 @@ RETINANET_CFG = {
'optimizer': {
'type': 'momentum',
'momentum': 0.9,
'nesterov': False,
},
'learning_rate': {
'type': 'step',
......@@ -56,6 +58,7 @@ RETINANET_CFG = {
# TODO(b/142174042): Support transpose_input option.
'transpose_input': False,
'l2_weight_decay': 0.0001,
'input_sharding': False,
},
'eval': {
'batch_size': 8,
......@@ -65,6 +68,7 @@ RETINANET_CFG = {
'type': 'box',
'val_json_file': '',
'eval_file_pattern': '',
'input_sharding': True,
},
'predict': {
'predict_batch_size': 8,
......@@ -165,7 +169,8 @@ RETINANET_CFG = {
'num_classes': 91,
'max_total_size': 100,
'nms_iou_threshold': 0.5,
'score_threshold': 0.05
'score_threshold': 0.05,
'pre_nms_num_boxes': 5000,
},
'enable_summary': False,
}
......
......@@ -58,6 +58,15 @@ class InputFn(object):
self._parser_fn = factory.parser_generator(params, mode)
self._dataset_fn = tf.data.TFRecordDataset
self._input_sharding = (not self._is_training)
try:
if self._is_training:
self._input_sharding = params.train.input_sharding
else:
self._input_sharding = params.eval.input_sharding
except KeyError:
pass
def __call__(self, ctx=None, batch_size: int = None):
"""Provides tf.data.Dataset object.
......@@ -74,7 +83,7 @@ class InputFn(object):
dataset = tf.data.Dataset.list_files(
self._file_pattern, shuffle=self._is_training)
if ctx and ctx.num_input_pipelines > 1:
if self._input_sharding and ctx and ctx.num_input_pipelines > 1:
dataset = dataset.shard(ctx.num_input_pipelines, ctx.input_pipeline_id)
if self._is_training:
dataset = dataset.repeat()
......@@ -82,6 +91,7 @@ class InputFn(object):
dataset = dataset.interleave(
map_func=lambda file_name: self._dataset_fn(file_name), cycle_length=32,
num_parallel_calls=tf.data.experimental.AUTOTUNE)
dataset = dataset.cache()
if self._is_training:
dataset = dataset.shuffle(64)
......
......@@ -58,6 +58,13 @@ class DetectionDistributedExecutor(executor.DistributedExecutor):
trainable_variables)
logging.info('Filter trainable variables from %d to %d',
len(model.trainable_variables), len(trainable_variables))
_update_state = lambda labels, outputs: None
if isinstance(metric, tf.keras.metrics.Metric):
_update_state = lambda labels, outputs: metric.update_state(
labels, outputs)
else:
logging.error('Detection: train metric is not an instance of '
'tf.keras.metrics.Metric.')
def _replicated_step(inputs):
"""Replicated training step."""
......@@ -71,11 +78,7 @@ class DetectionDistributedExecutor(executor.DistributedExecutor):
v = tf.reduce_mean(v) / strategy.num_replicas_in_sync
losses[k] = v
loss = losses['total_loss']
if isinstance(metric, tf.keras.metrics.Metric):
metric.update_state(labels, outputs)
else:
logging.error('train metric is not an instance of '
'tf.keras.metrics.Metric.')
_update_state(labels, outputs)
grads = tape.gradient(loss, trainable_variables)
optimizer.apply_gradients(zip(grads, trainable_variables))
......
......@@ -36,8 +36,15 @@ class OptimizerFactory(object):
def __init__(self, params):
"""Creates optimized based on the specified flags."""
if params.type == 'momentum':
nesterov = False
try:
nesterov = params.nesterov
except KeyError:
pass
self._optimizer = functools.partial(
tf.keras.optimizers.SGD, momentum=0.9, nesterov=True)
tf.keras.optimizers.SGD,
momentum=params.momentum,
nesterov=nesterov)
elif params.type == 'adam':
self._optimizer = tf.keras.optimizers.Adam
elif params.type == 'adadelta':
......@@ -133,11 +140,10 @@ class Model(object):
"""
return _make_filter_trainable_variables_fn(self._frozen_variable_prefix)
def weight_decay_loss(self, l2_weight_decay, keras_model):
# TODO(yeqing): Correct the filter according to cr/269707763.
def weight_decay_loss(self, l2_weight_decay, trainable_variables):
return l2_weight_decay * tf.add_n([
tf.nn.l2_loss(v)
for v in self._keras_model.trainable_variables
for v in trainable_variables
if 'batch_normalization' not in v.name and 'bias' not in v.name
])
......
......@@ -40,7 +40,8 @@ def generate_detections_factory(params):
_generate_detections,
max_total_size=params.max_total_size,
nms_iou_threshold=params.nms_iou_threshold,
score_threshold=params.score_threshold)
score_threshold=params.score_threshold,
pre_nms_num_boxes=params.pre_nms_num_boxes)
return func
......
......@@ -120,6 +120,9 @@ class RetinanetModel(base_model.Model):
if self._keras_model is None:
raise ValueError('build_loss_fn() must be called after build_model().')
filter_fn = self.make_filter_trainable_variables_fn()
trainable_variables = filter_fn(self._keras_model.trainable_variables)
def _total_loss_fn(labels, outputs):
cls_loss = self._cls_loss_fn(outputs['cls_outputs'],
labels['cls_targets'],
......@@ -129,7 +132,7 @@ class RetinanetModel(base_model.Model):
labels['num_positives'])
model_loss = cls_loss + self._box_loss_weight * box_loss
l2_regularization_loss = self.weight_decay_loss(self._l2_weight_decay,
self._keras_model)
trainable_variables)
total_loss = model_loss + l2_regularization_loss
return {
'total_loss': total_loss,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册