diff --git a/research/object_detection/core/post_processing.py b/research/object_detection/core/post_processing.py index 79b1d098be640dee47116d86846e3a98b928feed..fea777640d5326458d71f4332aefc02d25c770ba 100644 --- a/research/object_detection/core/post_processing.py +++ b/research/object_detection/core/post_processing.py @@ -402,10 +402,10 @@ def _clip_boxes(boxes, clip_window): window. """ ymin, xmin, ymax, xmax = tf.unstack(boxes, axis=-1) - clipped_ymin = tf.maximum(ymin, clip_window[:, 0]) - clipped_xmin = tf.maximum(xmin, clip_window[:, 1]) - clipped_ymax = tf.minimum(ymax, clip_window[:, 2]) - clipped_xmax = tf.minimum(xmax, clip_window[:, 3]) + clipped_ymin = tf.maximum(ymin, clip_window[:, 0, tf.newaxis]) + clipped_xmin = tf.maximum(xmin, clip_window[:, 1, tf.newaxis]) + clipped_ymax = tf.minimum(ymax, clip_window[:, 2, tf.newaxis]) + clipped_xmax = tf.minimum(xmax, clip_window[:, 3, tf.newaxis]) return tf.stack([clipped_ymin, clipped_xmin, clipped_ymax, clipped_xmax], axis=-1)