From 778b71fc93cf1cc541cabfddbd1b229898229506 Mon Sep 17 00:00:00 2001 From: baiyf Date: Wed, 27 Jun 2018 16:51:42 +0800 Subject: [PATCH] Optimize bipartite_match_op in large scale input (#11730) * optimize bipartite_match_op in large scale input --- .../operators/detection/bipartite_match_op.cc | 98 +++++++++++++------ .../unittests/test_bipartite_match_op.py | 17 ++++ 2 files changed, 84 insertions(+), 31 deletions(-) diff --git a/paddle/fluid/operators/detection/bipartite_match_op.cc b/paddle/fluid/operators/detection/bipartite_match_op.cc index d437ad5c198..c23b65fe4de 100644 --- a/paddle/fluid/operators/detection/bipartite_match_op.cc +++ b/paddle/fluid/operators/detection/bipartite_match_op.cc @@ -51,6 +51,12 @@ class BipartiteMatchOp : public framework::OperatorWithKernel { } }; +template +bool DistPairDescend(std::tuple pair1, + std::tuple pair2) { + return std::get<2>(pair1) > std::get<2>(pair2); +} + template class BipartiteMatchKernel : public framework::OpKernel { public: @@ -58,46 +64,76 @@ class BipartiteMatchKernel : public framework::OpKernel { // The match_dist must be initialized to 0 at first. void BipartiteMatch(const Tensor& dist, int* match_indices, T* match_dist) const { - constexpr T kEPS = static_cast(1e-6); PADDLE_ENFORCE_EQ(dist.dims().size(), 2, "The rank of dist must be 2."); int64_t row = dist.dims()[0]; int64_t col = dist.dims()[1]; auto* dist_data = dist.data(); - std::vector row_pool; - for (int i = 0; i < row; ++i) { - row_pool.push_back(i); - } - while (row_pool.size() > 0) { - int max_idx = -1; - int max_row_idx = -1; - T max_dist = -1; - for (int64_t j = 0; j < col; ++j) { - if (match_indices[j] != -1) { - continue; + // Test result: When row==130 the speed of these two methods almost the same + if (row >= 130) { + std::vector> match_pair; + + for (int64_t i = 0; i < row; ++i) { + for (int64_t j = 0; j < col; ++j) { + match_pair.push_back(std::make_tuple(i, j, dist_data[i * col + j])); } - for (size_t k = 0; k < row_pool.size(); ++k) { - int m = row_pool[k]; - // distance is 0 between m-th row and j-th column - if (dist_data[m * col + j] < kEPS) { + } + std::sort(match_pair.begin(), match_pair.end(), DistPairDescend); + std::vector row_indices(row, -1); + + int64_t idx = 0; + for (int64_t k = 0; k < row * col; ++k) { + int64_t i = std::get<0>(match_pair[k]); + int64_t j = std::get<1>(match_pair[k]); + T dist = std::get<2>(match_pair[k]); + + if (idx >= row) { + break; + } + if (match_indices[j] == -1 && row_indices[i] == -1 && dist > 0) { + match_indices[j] = i; + row_indices[i] = j; + match_dist[j] = dist; + idx += 1; + } + } + } else { + constexpr T kEPS = static_cast(1e-6); + std::vector row_pool; + for (int i = 0; i < row; ++i) { + row_pool.push_back(i); + } + while (row_pool.size() > 0) { + int max_idx = -1; + int max_row_idx = -1; + T max_dist = -1; + for (int64_t j = 0; j < col; ++j) { + if (match_indices[j] != -1) { continue; } - if (dist_data[m * col + j] > max_dist) { - max_idx = j; - max_row_idx = m; - max_dist = dist_data[m * col + j]; + for (size_t k = 0; k < row_pool.size(); ++k) { + int m = row_pool[k]; + // distance is 0 between m-th row and j-th column + if (dist_data[m * col + j] < kEPS) { + continue; + } + if (dist_data[m * col + j] > max_dist) { + max_idx = j; + max_row_idx = m; + max_dist = dist_data[m * col + j]; + } } } - } - if (max_idx == -1) { - // Cannot find good match. - break; - } else { - PADDLE_ENFORCE_EQ(match_indices[max_idx], -1); - match_indices[max_idx] = max_row_idx; - match_dist[max_idx] = max_dist; - // Erase the row index. - row_pool.erase( - std::find(row_pool.begin(), row_pool.end(), max_row_idx)); + if (max_idx == -1) { + // Cannot find good match. + break; + } else { + PADDLE_ENFORCE_EQ(match_indices[max_idx], -1); + match_indices[max_idx] = max_row_idx; + match_dist[max_idx] = max_dist; + // Erase the row index. + row_pool.erase( + std::find(row_pool.begin(), row_pool.end(), max_row_idx)); + } } } } diff --git a/python/paddle/fluid/tests/unittests/test_bipartite_match_op.py b/python/paddle/fluid/tests/unittests/test_bipartite_match_op.py index 1a245fd756c..d5bd726c4a8 100644 --- a/python/paddle/fluid/tests/unittests/test_bipartite_match_op.py +++ b/python/paddle/fluid/tests/unittests/test_bipartite_match_op.py @@ -114,6 +114,23 @@ class TestBipartiteMatchOpWithoutLoD(OpTest): self.check_output() +class TestBipartiteMatchOpWithoutLoDLargeScaleInput(OpTest): + def setUp(self): + self.op_type = 'bipartite_match' + lod = [[300]] + dist = np.random.random((300, 17)).astype('float32') + match_indices, match_dist = batch_bipartite_match(dist, lod[0]) + + self.inputs = {'DistMat': dist} + self.outputs = { + 'ColToRowMatchIndices': match_indices, + 'ColToRowMatchDist': match_dist, + } + + def test_check_output(self): + self.check_output() + + class TestBipartiteMatchOpWithPerPredictionType(OpTest): def setUp(self): self.op_type = 'bipartite_match' -- GitLab