提交 d41a1626 编写于 作者: A A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 510189719
上级 8b35b7ab
......@@ -673,18 +673,12 @@ class MaskrcnnHead(tf.keras.layers.Layer):
])
with tf.name_scope('masks_post_processing'):
# TODO(pengchong): Figure out the way not to use the static inferred
# batch size.
batch_size, num_masks = class_indices.get_shape().as_list()
mask_outputs = tf.transpose(a=mask_outputs, perm=[0, 1, 4, 2, 3])
# Constructs indices for gather.
batch_indices = tf.tile(
tf.expand_dims(tf.range(batch_size), axis=1), [1, num_masks])
mask_indices = tf.tile(
tf.expand_dims(tf.range(num_masks), axis=0), [batch_size, 1])
gather_indices = tf.stack(
[batch_indices, mask_indices, class_indices], axis=2)
mask_outputs = tf.gather_nd(mask_outputs, gather_indices)
mask_outputs = tf.gather(
mask_outputs,
tf.cast(class_indices, tf.int32),
axis=-1,
batch_dims=2,
)
return mask_outputs
......
......@@ -208,12 +208,7 @@ class DeepMaskHead(tf.keras.layers.Layer):
roi_width * upsample_factor], representing the mask predictions.
"""
roi_features, roi_classes = inputs
features_shape = tf.shape(roi_features)
batch_size, num_rois, height, width, filters = (
features_shape[0], features_shape[1], features_shape[2],
features_shape[3], features_shape[4])
if batch_size is None:
batch_size = tf.shape(roi_features)[0]
_, num_rois, height, width, filters = roi_features.get_shape().as_list()
x = tf.reshape(roi_features, [-1, height, width, filters])
......@@ -229,29 +224,15 @@ class DeepMaskHead(tf.keras.layers.Layer):
mask_width = width * self._config_dict['upsample_factor']
if self._config_dict['class_agnostic']:
logits = tf.reshape(logits, [-1, num_rois, mask_height, mask_width, 1])
return tf.reshape(logits, [-1, num_rois, mask_height, mask_width])
else:
logits = tf.reshape(
logits,
[-1, num_rois, mask_height, mask_width,
self._config_dict['num_classes']])
batch_indices = tf.tile(
tf.expand_dims(tf.range(batch_size), axis=1), [1, num_rois])
mask_indices = tf.tile(
tf.expand_dims(tf.range(num_rois), axis=0), [batch_size, 1])
if self._config_dict['class_agnostic']:
class_gather_indices = tf.zeros_like(roi_classes, dtype=tf.int32)
else:
class_gather_indices = tf.cast(roi_classes, dtype=tf.int32)
gather_indices = tf.stack(
[batch_indices, mask_indices, class_gather_indices],
axis=2)
mask_outputs = tf.gather_nd(
tf.transpose(logits, [0, 1, 4, 2, 3]), gather_indices)
return mask_outputs
return tf.gather(
logits, tf.cast(roi_classes, dtype=tf.int32), axis=-1, batch_dims=2
)
def _build_convnet_variant(self):
......
......@@ -399,10 +399,7 @@ class MaskHead(tf.keras.layers.Layer):
roi_width * upsample_factor], representing the mask predictions.
"""
roi_features, roi_classes = inputs
batch_size, num_rois, height, width, filters = (
roi_features.get_shape().as_list())
if batch_size is None:
batch_size = tf.shape(roi_features)[0]
_, num_rois, height, width, filters = roi_features.get_shape().as_list()
x = tf.reshape(roi_features, [-1, height, width, filters])
for conv, bn in zip(self._convs, self._conv_norms):
......@@ -420,29 +417,15 @@ class MaskHead(tf.keras.layers.Layer):
mask_width = width * self._config_dict['upsample_factor']
if self._config_dict['class_agnostic']:
logits = tf.reshape(logits, [-1, num_rois, mask_height, mask_width, 1])
return tf.reshape(logits, [-1, num_rois, mask_height, mask_width])
else:
logits = tf.reshape(
logits,
[-1, num_rois, mask_height, mask_width,
self._config_dict['num_classes']])
batch_indices = tf.tile(
tf.expand_dims(tf.range(batch_size), axis=1), [1, num_rois])
mask_indices = tf.tile(
tf.expand_dims(tf.range(num_rois), axis=0), [batch_size, 1])
if self._config_dict['class_agnostic']:
class_gather_indices = tf.zeros_like(roi_classes, dtype=tf.int32)
else:
class_gather_indices = tf.cast(roi_classes, dtype=tf.int32)
gather_indices = tf.stack(
[batch_indices, mask_indices, class_gather_indices],
axis=2)
mask_outputs = tf.gather_nd(
tf.transpose(logits, [0, 1, 4, 2, 3]), gather_indices)
return mask_outputs
return tf.gather(
logits, tf.cast(roi_classes, dtype=tf.int32), axis=-1, batch_dims=2
)
def get_config(self):
return self._config_dict
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册