未验证 提交 24566e95 编写于 作者: Q qingqing01 提交者: GitHub

Support empty bbox in bipartite math op (#26488)

上级 87843beb
...@@ -222,10 +222,12 @@ class BipartiteMatchKernel : public framework::OpKernel<T> { ...@@ -222,10 +222,12 @@ class BipartiteMatchKernel : public framework::OpKernel<T> {
} else { } else {
auto lod = dist_mat->lod().back(); auto lod = dist_mat->lod().back();
for (size_t i = 0; i < lod.size() - 1; ++i) { for (size_t i = 0; i < lod.size() - 1; ++i) {
Tensor one_ins = dist_mat->Slice(lod[i], lod[i + 1]); if (lod[i + 1] > lod[i]) {
BipartiteMatch(one_ins, indices + i * col, dist + i * col); Tensor one_ins = dist_mat->Slice(lod[i], lod[i + 1]);
if (type == "per_prediction") { BipartiteMatch(one_ins, indices + i * col, dist + i * col);
ArgMaxMatch(one_ins, indices + i * col, dist + i * col, threshold); if (type == "per_prediction") {
ArgMaxMatch(one_ins, indices + i * col, dist + i * col, threshold);
}
} }
} }
} }
......
...@@ -65,7 +65,7 @@ def batch_bipartite_match(distance, lod, match_type=None, dist_threshold=None): ...@@ -65,7 +65,7 @@ def batch_bipartite_match(distance, lod, match_type=None, dist_threshold=None):
"""Bipartite Matching algorithm for batch input. """Bipartite Matching algorithm for batch input.
Arg: Arg:
distance (numpy.array) : The distance of two entries with shape [M, N]. distance (numpy.array) : The distance of two entries with shape [M, N].
lod (list of int): The offsets of each input in this batch. lod (list of int): The length of each input in this batch.
""" """
n = len(lod) n = len(lod)
m = distance.shape[1] m = distance.shape[1]
...@@ -73,6 +73,7 @@ def batch_bipartite_match(distance, lod, match_type=None, dist_threshold=None): ...@@ -73,6 +73,7 @@ def batch_bipartite_match(distance, lod, match_type=None, dist_threshold=None):
match_dist = np.zeros((n, m), dtype=np.float32) match_dist = np.zeros((n, m), dtype=np.float32)
cur_offset = 0 cur_offset = 0
for i in range(n): for i in range(n):
if lod[i] == 0: continue
bipartite_match(distance[cur_offset:(cur_offset + lod[i]), :], bipartite_match(distance[cur_offset:(cur_offset + lod[i]), :],
match_indices[i, :], match_dist[i, :]) match_indices[i, :], match_dist[i, :])
if match_type == 'per_prediction': if match_type == 'per_prediction':
...@@ -155,5 +156,22 @@ class TestBipartiteMatchOpWithPerPredictionType(OpTest): ...@@ -155,5 +156,22 @@ class TestBipartiteMatchOpWithPerPredictionType(OpTest):
self.check_output() self.check_output()
class TestBipartiteMatchOpWithEmptyLoD(OpTest):
def setUp(self):
self.op_type = 'bipartite_match'
lod = [[5, 6, 0, 12]]
dist = np.random.random((23, 217)).astype('float32')
match_indices, match_dist = batch_bipartite_match(dist, lod[0])
self.inputs = {'DistMat': (dist, lod)}
self.outputs = {
'ColToRowMatchIndices': match_indices,
'ColToRowMatchDist': match_dist,
}
def test_check_output(self):
self.check_output()
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册