提交 8d9a16ce 编写于 作者: Y Yeqing Li 提交者: A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 285844156
上级 913640d4
......@@ -30,10 +30,11 @@ REGULARIZATION_VAR_REGEX = r'.*(kernel|weight):0$'
BASE_CFG = {
'model_dir': '',
'use_tpu': True,
'strategy_type': 'tpu',
'isolate_session_state': False,
'train': {
'iterations_per_loop': 100,
'train_batch_size': 64,
'batch_size': 64,
'total_steps': 22500,
'num_cores_per_replica': None,
'input_partition_dims': None,
......@@ -57,13 +58,13 @@ BASE_CFG = {
'frozen_variable_prefix': RESNET_FROZEN_VAR_PREFIX,
'train_file_pattern': '',
'train_dataset_type': 'tfrecord',
'transpose_input': True,
'transpose_input': False,
'regularization_variable_regex': REGULARIZATION_VAR_REGEX,
'l2_weight_decay': 0.0001,
'gradient_clip_norm': 0.0,
},
'eval': {
'eval_batch_size': 8,
'batch_size': 8,
'eval_samples': 5000,
'min_eval_interval': 180,
'eval_timeout': None,
......
......@@ -34,6 +34,7 @@ MASKRCNN_CFG.override({
'maskrcnn_parser': {
'use_bfloat16': True,
'output_size': [1024, 1024],
'num_channels': 3,
'rpn_match_threshold': 0.7,
'rpn_unmatched_threshold': 0.3,
'rpn_batch_size_per_im': 256,
......
......@@ -275,6 +275,10 @@ class Parser(object):
if self._use_bfloat16:
image = tf.cast(image, dtype=tf.bfloat16)
inputs = {
'image': image,
'image_info': image_info,
}
# Packs labels for model_fn outputs.
labels = {
'anchor_boxes': input_anchor.multilevel_boxes,
......@@ -282,15 +286,16 @@ class Parser(object):
'rpn_score_targets': rpn_score_targets,
'rpn_box_targets': rpn_box_targets,
}
labels['gt_boxes'] = input_utils.pad_to_fixed_size(
boxes, self._max_num_instances, -1)
labels['gt_classes'] = input_utils.pad_to_fixed_size(
inputs['gt_boxes'] = input_utils.pad_to_fixed_size(boxes,
self._max_num_instances,
-1)
inputs['gt_classes'] = input_utils.pad_to_fixed_size(
classes, self._max_num_instances, -1)
if self._include_mask:
labels['gt_masks'] = input_utils.pad_to_fixed_size(
inputs['gt_masks'] = input_utils.pad_to_fixed_size(
masks, self._max_num_instances, -1)
return image, labels
return inputs, labels
def _parse_eval_data(self, data):
"""Parses data for evaluation."""
......@@ -348,11 +353,7 @@ class Parser(object):
self._anchor_size,
(image_height, image_width))
labels = {
'source_id': dataloader_utils.process_source_id(data['source_id']),
'anchor_boxes': input_anchor.multilevel_boxes,
'image_info': image_info,
}
labels = {}
if self._mode == ModeKeys.PREDICT_WITH_GT:
# Converts boxes from normalized coordinates to pixel coordinates.
......@@ -372,6 +373,11 @@ class Parser(object):
groundtruths['source_id'])
groundtruths = dataloader_utils.pad_groundtruths_to_fixed_size(
groundtruths, self._max_num_instances)
# TODO(yeqing): Remove the `groundtruths` layer key (no longer needed).
labels['groundtruths'] = groundtruths
inputs = {
'image': image,
'image_info': image_info,
}
return image, labels
return inputs, labels
......@@ -99,6 +99,7 @@ class Model(object):
params.train.learning_rate)
self._frozen_variable_prefix = params.train.frozen_variable_prefix
self._l2_weight_decay = params.train.l2_weight_decay
# Checkpoint restoration.
self._checkpoint = params.train.checkpoint.as_dict()
......
......@@ -147,6 +147,7 @@ class RpnBoxLoss(object):
"""Region Proposal Network box regression loss function."""
def __init__(self, params):
self._delta = params.huber_loss_delta
self._huber_loss = tf.keras.losses.Huber(
delta=params.huber_loss_delta, reduction=tf.keras.losses.Reduction.SUM)
......@@ -212,7 +213,7 @@ class FastrcnnClassLoss(object):
a scalar tensor representing total class loss.
"""
with tf.name_scope('fast_rcnn_loss'):
_, _, _, num_classes = class_outputs.get_shape().as_list()
_, _, num_classes = class_outputs.get_shape().as_list()
class_targets = tf.cast(class_targets, dtype=tf.int32)
class_targets_one_hot = tf.one_hot(class_targets, num_classes)
return self._fast_rcnn_class_loss(class_outputs, class_targets_one_hot)
......@@ -320,9 +321,6 @@ class FastrcnnBoxLoss(object):
class MaskrcnnLoss(object):
"""Mask R-CNN instance segmentation mask loss function."""
def __init__(self):
raise ValueError('Not TF 2.0 ready.')
def __call__(self, mask_outputs, mask_targets, select_class_targets):
"""Computes the mask loss of Mask-RCNN.
......
......@@ -56,7 +56,6 @@ class RetinanetModel(base_model.Model):
self._generate_detections_fn = postprocess_ops.MultilevelDetectionGenerator(
params.postprocess)
self._l2_weight_decay = params.train.l2_weight_decay
self._transpose_input = params.train.transpose_input
assert not self._transpose_input, 'Transpose input is not supportted.'
# Input layer.
......@@ -134,6 +133,7 @@ class RetinanetModel(base_model.Model):
return self._keras_model
def post_processing(self, labels, outputs):
# TODO(yeqing): Moves the output related part into build_outputs.
required_output_fields = ['cls_outputs', 'box_outputs']
for field in required_output_fields:
if field not in outputs:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册