提交 014c9383 编写于 作者: A A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 447823991
上级 a7ba08aa
......@@ -12,9 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Box matcher implementation."""
from typing import List, Tuple
import tensorflow as tf
......@@ -43,15 +43,19 @@ class BoxMatcher:
assigned positive_value.
"""
def __init__(self, thresholds, indicators, force_match_for_each_col=False):
def __init__(self,
thresholds: List[float],
indicators: List[int],
force_match_for_each_col: bool = False):
"""Construct BoxMatcher.
Args:
thresholds: A list of thresholds to classify boxes into
different buckets. The list needs to be sorted, and will be prepended
with -Inf and appended with +Inf.
indicators: A list of values to assign for each bucket. len(`indicators`)
must equal to len(`thresholds`) + 1.
thresholds: A list of thresholds to classify the matches into different
types (e.g. positive or negative or ignored match). The list needs to be
sorted, and will be prepended with -Inf and appended with +Inf.
indicators: A list of values representing match types (e.g. positive or
negative or ignored match). len(`indicators`) must equal to
len(`thresholds`) + 1.
force_match_for_each_col: If True, ensures that each column is matched to
at least one row (which is not guaranteed otherwise if the
positive_threshold is high). Defaults to False. If True, all force
......@@ -74,19 +78,20 @@ class BoxMatcher:
self.thresholds = thresholds
self._force_match_for_each_col = force_match_for_each_col
def __call__(self, similarity_matrix):
def __call__(self,
similarity_matrix: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
"""Tries to match each column of the similarity matrix to a row.
Args:
similarity_matrix: A float tensor of shape [N, M] representing any
similarity metric.
similarity_matrix: A float tensor of shape [num_rows, num_cols] or
[batch_size, num_rows, num_cols] representing any similarity metric.
Returns:
A integer tensor of shape [N] with corresponding match indices for each
of M columns, for positive match, the match result will be the
corresponding row index, for negative match, the match will be
`negative_value`, for ignored match, the match result will be
`ignore_value`.
matched_columns: An integer tensor of shape [num_rows] or [batch_size,
num_rows] storing the index of the matched column for each row.
match_indicators: An integer tensor of shape [num_rows] or [batch_size,
num_rows] storing the match type indicator (e.g. positive or negative or
ignored match).
"""
squeeze_result = False
if len(similarity_matrix.shape) == 2:
......@@ -101,29 +106,37 @@ class BoxMatcher:
"""Performs matching when the rows of similarity matrix are empty.
When the rows are empty, all detections are false positives. So we return
a tensor of -1's to indicate that the columns do not match to any rows.
a tensor of -1's to indicate that the rows do not match to any columns.
Returns:
matches: int32 tensor indicating the row each column matches to.
matched_columns: An integer tensor of shape [num_rows] or [batch_size,
num_rows] storing the index of the matched column for each row.
match_indicators: An integer tensor of shape [num_rows] or [batch_size,
num_rows] storing the match type indicator (e.g. positive or negative
or ignored match).
"""
with tf.name_scope('empty_gt_boxes'):
matches = tf.zeros([batch_size, num_rows], dtype=tf.int32)
match_labels = -tf.ones([batch_size, num_rows], dtype=tf.int32)
return matches, match_labels
matched_columns = tf.zeros([batch_size, num_rows], dtype=tf.int32)
match_indicators = -tf.ones([batch_size, num_rows], dtype=tf.int32)
return matched_columns, match_indicators
def _match_when_rows_are_non_empty():
"""Performs matching when the rows of similarity matrix are non empty.
Returns:
matches: int32 tensor indicating the row each column matches to.
matched_columns: An integer tensor of shape [num_rows] or [batch_size,
num_rows] storing the index of the matched column for each row.
match_indicators: An integer tensor of shape [num_rows] or [batch_size,
num_rows] storing the match type indicator (e.g. positive or negative
or ignored match).
"""
# Matches for each column
with tf.name_scope('non_empty_gt_boxes'):
matches = tf.argmax(similarity_matrix, axis=-1, output_type=tf.int32)
matched_columns = tf.argmax(
similarity_matrix, axis=-1, output_type=tf.int32)
# Get logical indices of ignored and unmatched columns as tf.int64
matched_vals = tf.reduce_max(similarity_matrix, axis=-1)
matched_indicators = tf.zeros([batch_size, num_rows], tf.int32)
match_indicators = tf.zeros([batch_size, num_rows], tf.int32)
match_dtype = matched_vals.dtype
for (ind, low, high) in zip(self.indicators, self.thresholds[:-1],
......@@ -133,48 +146,46 @@ class BoxMatcher:
mask = tf.logical_and(
tf.greater_equal(matched_vals, low_threshold),
tf.less(matched_vals, high_threshold))
matched_indicators = self._set_values_using_indicator(
matched_indicators, mask, ind)
match_indicators = self._set_values_using_indicator(
match_indicators, mask, ind)
if self._force_match_for_each_col:
# [batch_size, M], for each col (groundtruth_box), find the best
# matching row (anchor).
force_match_column_ids = tf.argmax(
# [batch_size, num_cols], for each column (groundtruth_box), find the
# best matching row (anchor).
matching_rows = tf.argmax(
input=similarity_matrix, axis=1, output_type=tf.int32)
# [batch_size, M, N]
force_match_column_indicators = tf.one_hot(
force_match_column_ids, depth=num_rows)
# [batch_size, N], for each row (anchor), find the largest column
# index for groundtruth box
force_match_row_ids = tf.argmax(
input=force_match_column_indicators, axis=1, output_type=tf.int32)
# [batch_size, N]
force_match_column_mask = tf.cast(
tf.reduce_max(force_match_column_indicators, axis=1),
tf.bool)
# [batch_size, N]
final_matches = tf.where(force_match_column_mask, force_match_row_ids,
matches)
final_matched_indicators = tf.where(
force_match_column_mask, self.indicators[-1] *
tf.ones([batch_size, num_rows], dtype=tf.int32),
matched_indicators)
return final_matches, final_matched_indicators
else:
return matches, matched_indicators
# [batch_size, num_cols, num_rows], a transposed 0-1 mapping matrix M,
# where M[j, i] = 1 means column j is matched to row i.
column_to_row_match_mapping = tf.one_hot(
matching_rows, depth=num_rows)
# [batch_size, num_rows], for each row (anchor), find the matched
# column (groundtruth_box).
force_matched_columns = tf.argmax(
input=column_to_row_match_mapping, axis=1, output_type=tf.int32)
# [batch_size, num_rows]
force_matched_column_mask = tf.cast(
tf.reduce_max(column_to_row_match_mapping, axis=1), tf.bool)
# [batch_size, num_rows]
matched_columns = tf.where(force_matched_column_mask,
force_matched_columns, matched_columns)
match_indicators = tf.where(
force_matched_column_mask, self.indicators[-1] *
tf.ones([batch_size, num_rows], dtype=tf.int32), match_indicators)
return matched_columns, match_indicators
num_gt_boxes = similarity_matrix.shape.as_list()[-1] or tf.shape(
similarity_matrix)[-1]
result_match, result_matched_indicators = tf.cond(
matched_columns, match_indicators = tf.cond(
pred=tf.greater(num_gt_boxes, 0),
true_fn=_match_when_rows_are_non_empty,
false_fn=_match_when_rows_are_empty)
if squeeze_result:
result_match = tf.squeeze(result_match, axis=0)
result_matched_indicators = tf.squeeze(result_matched_indicators, axis=0)
matched_columns = tf.squeeze(matched_columns, axis=0)
match_indicators = tf.squeeze(match_indicators, axis=0)
return result_match, result_matched_indicators
return matched_columns, match_indicators
def _set_values_using_indicator(self, x, indicator, val):
"""Set the indicated fields of x to val.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册