未验证 提交 24566e95 编写于 作者: Q qingqing01 提交者: GitHub

Support empty bbox in bipartite math op (#26488)

上级 87843beb
......@@ -222,10 +222,12 @@ class BipartiteMatchKernel : public framework::OpKernel<T> {
} else {
auto lod = dist_mat->lod().back();
for (size_t i = 0; i < lod.size() - 1; ++i) {
Tensor one_ins = dist_mat->Slice(lod[i], lod[i + 1]);
BipartiteMatch(one_ins, indices + i * col, dist + i * col);
if (type == "per_prediction") {
ArgMaxMatch(one_ins, indices + i * col, dist + i * col, threshold);
if (lod[i + 1] > lod[i]) {
Tensor one_ins = dist_mat->Slice(lod[i], lod[i + 1]);
BipartiteMatch(one_ins, indices + i * col, dist + i * col);
if (type == "per_prediction") {
ArgMaxMatch(one_ins, indices + i * col, dist + i * col, threshold);
}
}
}
}
......
......@@ -65,7 +65,7 @@ def batch_bipartite_match(distance, lod, match_type=None, dist_threshold=None):
"""Bipartite Matching algorithm for batch input.
Arg:
distance (numpy.array) : The distance of two entries with shape [M, N].
lod (list of int): The offsets of each input in this batch.
lod (list of int): The length of each input in this batch.
"""
n = len(lod)
m = distance.shape[1]
......@@ -73,6 +73,7 @@ def batch_bipartite_match(distance, lod, match_type=None, dist_threshold=None):
match_dist = np.zeros((n, m), dtype=np.float32)
cur_offset = 0
for i in range(n):
if lod[i] == 0: continue
bipartite_match(distance[cur_offset:(cur_offset + lod[i]), :],
match_indices[i, :], match_dist[i, :])
if match_type == 'per_prediction':
......@@ -155,5 +156,22 @@ class TestBipartiteMatchOpWithPerPredictionType(OpTest):
self.check_output()
class TestBipartiteMatchOpWithEmptyLoD(OpTest):
def setUp(self):
self.op_type = 'bipartite_match'
lod = [[5, 6, 0, 12]]
dist = np.random.random((23, 217)).astype('float32')
match_indices, match_dist = batch_bipartite_match(dist, lod[0])
self.inputs = {'DistMat': (dist, lod)}
self.outputs = {
'ColToRowMatchIndices': match_indices,
'ColToRowMatchDist': match_dist,
}
def test_check_output(self):
self.check_output()
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册