提交 02ff7788 编写于 作者: R Reed Wanderman-Milne 提交者: A. Unique TensorFlower

Allow mask rcnn to be run with mixed precision without NaNs.

Some parts of the forward pass would previously overflow in float16. Such parts are now done in float32.

PiperOrigin-RevId: 380857663
上级 f13895b9
......@@ -198,7 +198,8 @@ def multilevel_crop_and_resize(features,
# Assigns boxes to the right level.
box_width = boxes[:, :, 3] - boxes[:, :, 1]
box_height = boxes[:, :, 2] - boxes[:, :, 0]
areas_sqrt = tf.cast(tf.sqrt(box_height * box_width), tf.float32)
areas_sqrt = tf.sqrt(
tf.cast(box_height, tf.float32) * tf.cast(box_width, tf.float32))
levels = tf.cast(
tf.math.floordiv(
tf.math.log(tf.divide(areas_sqrt, 224.0)),
......@@ -456,6 +457,12 @@ def crop_mask_in_target_box(masks,
[batch_size, num_boxes, output_size, output_size].
"""
with tf.name_scope('crop_mask_in_target_box'):
# Cast to float32, as the y_transform and other transform variables may
# overflow in float16
masks = tf.cast(masks, tf.float32)
boxes = tf.cast(boxes, tf.float32)
target_boxes = tf.cast(target_boxes, tf.float32)
batch_size, num_masks, height, width = masks.get_shape().as_list()
if batch_size is None:
batch_size = tf.shape(masks)[0]
......
......@@ -132,6 +132,9 @@ class IouSimilarity:
Output shape:
[M, N], or [B, M, N]
"""
boxes_1 = tf.cast(boxes_1, tf.float32)
boxes_2 = tf.cast(boxes_2, tf.float32)
boxes_1_rank = len(boxes_1.shape)
boxes_2_rank = len(boxes_2.shape)
if boxes_1_rank < 2 or boxes_1_rank > 3:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册