diff --git a/paddle/fluid/operators/detection/bipartite_match_op.cc b/paddle/fluid/operators/detection/bipartite_match_op.cc index 16e1699e12c832d54af14f673577dcc32b015d6d..5cd853758926e622d0f87e6f8bbaba2cf3b9f85e 100644 --- a/paddle/fluid/operators/detection/bipartite_match_op.cc +++ b/paddle/fluid/operators/detection/bipartite_match_op.cc @@ -222,10 +222,12 @@ class BipartiteMatchKernel : public framework::OpKernel { } else { auto lod = dist_mat->lod().back(); for (size_t i = 0; i < lod.size() - 1; ++i) { - Tensor one_ins = dist_mat->Slice(lod[i], lod[i + 1]); - BipartiteMatch(one_ins, indices + i * col, dist + i * col); - if (type == "per_prediction") { - ArgMaxMatch(one_ins, indices + i * col, dist + i * col, threshold); + if (lod[i + 1] > lod[i]) { + Tensor one_ins = dist_mat->Slice(lod[i], lod[i + 1]); + BipartiteMatch(one_ins, indices + i * col, dist + i * col); + if (type == "per_prediction") { + ArgMaxMatch(one_ins, indices + i * col, dist + i * col, threshold); + } } } } diff --git a/python/paddle/fluid/tests/unittests/test_bipartite_match_op.py b/python/paddle/fluid/tests/unittests/test_bipartite_match_op.py index 5cc8e2ba15d260b988ee66a5711aed42ca04c10b..cc2b1165ec304a63671b48d4702142ea38c9a2c1 100644 --- a/python/paddle/fluid/tests/unittests/test_bipartite_match_op.py +++ b/python/paddle/fluid/tests/unittests/test_bipartite_match_op.py @@ -65,7 +65,7 @@ def batch_bipartite_match(distance, lod, match_type=None, dist_threshold=None): """Bipartite Matching algorithm for batch input. Arg: distance (numpy.array) : The distance of two entries with shape [M, N]. - lod (list of int): The offsets of each input in this batch. + lod (list of int): The length of each input in this batch. """ n = len(lod) m = distance.shape[1] @@ -73,6 +73,7 @@ def batch_bipartite_match(distance, lod, match_type=None, dist_threshold=None): match_dist = np.zeros((n, m), dtype=np.float32) cur_offset = 0 for i in range(n): + if lod[i] == 0: continue bipartite_match(distance[cur_offset:(cur_offset + lod[i]), :], match_indices[i, :], match_dist[i, :]) if match_type == 'per_prediction': @@ -155,5 +156,22 @@ class TestBipartiteMatchOpWithPerPredictionType(OpTest): self.check_output() +class TestBipartiteMatchOpWithEmptyLoD(OpTest): + def setUp(self): + self.op_type = 'bipartite_match' + lod = [[5, 6, 0, 12]] + dist = np.random.random((23, 217)).astype('float32') + match_indices, match_dist = batch_bipartite_match(dist, lod[0]) + + self.inputs = {'DistMat': (dist, lod)} + self.outputs = { + 'ColToRowMatchIndices': match_indices, + 'ColToRowMatchDist': match_dist, + } + + def test_check_output(self): + self.check_output() + + if __name__ == '__main__': unittest.main()