From 4948f7b3fe20ff1aa87bd23d84e4fdba42a88e73 Mon Sep 17 00:00:00 2001 From: qingqing01 Date: Tue, 27 Feb 2018 10:33:56 +0800 Subject: [PATCH] Enhance bipartite_match_op to support argmax matching after bipartite matching. (#8580) * Enhance bipartite_match_op to support argmax matching after bipartite matching. * Fix typo error. --- paddle/fluid/operators/bipartite_match_op.cc | 57 ++++++++++++++++++- python/paddle/fluid/layers/detection.py | 19 ++++++- .../unittests/test_bipartite_match_op.py | 44 +++++++++++++- 3 files changed, 112 insertions(+), 8 deletions(-) diff --git a/paddle/fluid/operators/bipartite_match_op.cc b/paddle/fluid/operators/bipartite_match_op.cc index c536cf6b6b..2b3f26c0a8 100644 --- a/paddle/fluid/operators/bipartite_match_op.cc +++ b/paddle/fluid/operators/bipartite_match_op.cc @@ -94,6 +94,38 @@ class BipartiteMatchKernel : public framework::OpKernel { } } + void ArgMaxMatch(const Tensor& dist, int* match_indices, T* match_dist, + T overlap_threshold) const { + constexpr T kEPS = static_cast(1e-6); + int64_t row = dist.dims()[0]; + int64_t col = dist.dims()[1]; + auto* dist_data = dist.data(); + for (int64_t j = 0; j < col; ++j) { + if (match_indices[j] != -1) { + // the j-th column has been matched to one entity. + continue; + } + int max_row_idx = -1; + T max_dist = -1; + for (int i = 0; i < row; ++i) { + T dist = dist_data[i * col + j]; + if (dist < kEPS) { + // distance is 0 between m-th row and j-th column + continue; + } + if (dist >= overlap_threshold && dist > max_dist) { + max_row_idx = i; + max_dist = dist; + } + } + if (max_row_idx != -1) { + PADDLE_ENFORCE_EQ(match_indices[j], -1); + match_indices[j] = max_row_idx; + match_dist[j] = max_dist; + } + } + } + void Compute(const framework::ExecutionContext& context) const override { auto* dist_mat = context.Input("DistMat"); auto* match_indices = context.Output("ColToRowMatchIndices"); @@ -120,13 +152,21 @@ class BipartiteMatchKernel : public framework::OpKernel { int* indices = match_indices->data(); T* dist = match_dist->data(); + auto type = context.Attr("match_type"); + auto threshold = context.Attr("dist_threshold"); if (n == 1) { BipartiteMatch(*dist_mat, indices, dist); + if (type == "per_prediction") { + ArgMaxMatch(*dist_mat, indices, dist, threshold); + } } else { auto lod = dist_mat->lod().back(); for (size_t i = 0; i < lod.size() - 1; ++i) { Tensor one_ins = dist_mat->Slice(lod[i], lod[i + 1]); BipartiteMatch(one_ins, indices + i * col, dist + i * col); + if (type == "per_prediction") { + ArgMaxMatch(one_ins, indices + i * col, dist + i * col, threshold); + } } } } @@ -147,6 +187,19 @@ class BipartiteMatchOpMaker : public framework::OpProtoAndCheckerMaker { "This tensor can contain LoD information to represent a batch of " "inputs. One instance of this batch can contain different numbers of " "entities."); + AddAttr( + "match_type", + "(string, defalut: per_prediction) " + "The type of matching method, should be 'bipartite' or " + "'per_prediction', 'bipartite' by defalut.") + .SetDefault("bipartite") + .InEnum({"bipartite", "per_prediction"}); + AddAttr( + "dist_threshold", + "(float, defalut: 0.5) " + "If `match_type` is 'per_prediction', this threshold is to determine " + "the extra matching bboxes based on the maximum distance.") + .SetDefault(0.5); AddOutput("ColToRowMatchIndices", "(Tensor) A 2-D Tensor with shape [N, M] in int type. " "N is the batch size. If ColToRowMatchIndices[i][j] is -1, it " @@ -168,10 +221,10 @@ distance matrix. For input 2D matrix, the bipartite matching algorithm can find the matched column for each row, also can find the matched row for each column. And this operator only calculate matched indices from column to row. For each instance, the number of matched indices is the number of -of columns of the input ditance matrix. +of columns of the input distance matrix. There are two outputs to save matched indices and distance. -A simple description, this algothrim matched the best (maximum distance) +A simple description, this algorithm matched the best (maximum distance) row entity to the column entity and the matched indices are not duplicated in each row of ColToRowMatchIndices. If the column entity is not matched any row entity, set -1 in ColToRowMatchIndices. diff --git a/python/paddle/fluid/layers/detection.py b/python/paddle/fluid/layers/detection.py index 5ae4da1ea3..25522249c8 100644 --- a/python/paddle/fluid/layers/detection.py +++ b/python/paddle/fluid/layers/detection.py @@ -132,7 +132,10 @@ def detection_output(scores, return nmsed_outs -def bipartite_match(dist_matrix, name=None): +def bipartite_match(dist_matrix, + match_type=None, + dist_threshold=None, + name=None): """ **Bipartite matchint operator** @@ -164,6 +167,11 @@ def bipartite_match(dist_matrix, name=None): This tensor can contain LoD information to represent a batch of inputs. One instance of this batch can contain different numbers of entities. + match_type(string|None): The type of matching method, should be + 'bipartite' or 'per_prediction', 'bipartite' by defalut. + dist_threshold(float|None): If `match_type` is 'per_prediction', + this threshold is to determine the extra matching bboxes based + on the maximum distance, 0.5 by defalut. Returns: match_indices(Variable): A 2-D Tensor with shape [N, M] in int type. N is the batch size. If match_indices[i][j] is -1, it @@ -183,6 +191,10 @@ def bipartite_match(dist_matrix, name=None): helper.append_op( type='bipartite_match', inputs={'DistMat': dist_matrix}, + attrs={ + 'match_type': match_type, + 'dist_threshold': dist_threshold, + }, outputs={ 'ColToRowMatchIndices': match_indices, 'ColToRowMatchDist': match_distance @@ -333,7 +345,7 @@ def ssd_loss(location, loc_loss_weight (float): Weight for localization loss, 1.0 by default. conf_loss_weight (float): Weight for confidence loss, 1.0 by default. match_type (str): The type of matching method during training, should - be 'bipartite' or 'per_prediction'. + be 'bipartite' or 'per_prediction', 'per_prediction' by defalut. mining_type (str): The hard example mining type, should be 'hard_example' or 'max_negative', now only support `max_negative`. @@ -381,7 +393,8 @@ def ssd_loss(location, # 1.1 Compute IOU similarity between ground-truth boxes and prior boxes. iou = iou_similarity(x=gt_box, y=prior_box) # 1.2 Compute matched boundding box by bipartite matching algorithm. - matched_indices, matched_dist = bipartite_match(iou) + matched_indices, matched_dist = bipartite_match(iou, match_type, + overlap_threshold) # 2. Compute confidence for mining hard examples # 2.1. Get the target label based on matched indices diff --git a/python/paddle/fluid/tests/unittests/test_bipartite_match_op.py b/python/paddle/fluid/tests/unittests/test_bipartite_match_op.py index 9f9af2f55e..f7461ee6da 100644 --- a/python/paddle/fluid/tests/unittests/test_bipartite_match_op.py +++ b/python/paddle/fluid/tests/unittests/test_bipartite_match_op.py @@ -46,7 +46,20 @@ def bipartite_match(distance, match_indices, match_dist): idx += 1 -def batch_bipartite_match(distance, lod): +def argmax_match(distance, match_indices, match_dist, threshold): + r, c = distance.shape + for j in xrange(c): + if match_indices[j] != -1: + continue + col_dist = distance[:, j] + indices = np.argwhere(col_dist >= threshold).flatten() + if len(indices) < 1: + continue + match_indices[j] = indices[np.argmax(col_dist[indices])] + match_dist[j] = col_dist[match_indices[j]] + + +def batch_bipartite_match(distance, lod, match_type=None, dist_threshold=None): """Bipartite Matching algorithm for batch input. Arg: distance (numpy.array) : The distance of two entries with shape [M, N]. @@ -59,6 +72,9 @@ def batch_bipartite_match(distance, lod): for i in range(len(lod) - 1): bipartite_match(distance[lod[i]:lod[i + 1], :], match_indices[i, :], match_dist[i, :]) + if match_type == 'per_prediction': + argmax_match(distance[lod[i]:lod[i + 1], :], match_indices[i, :], + match_dist[i, :], dist_threshold) return match_indices, match_dist @@ -71,8 +87,8 @@ class TestBipartiteMatchOpWithLoD(OpTest): self.inputs = {'DistMat': (dist, lod)} self.outputs = { - 'ColToRowMatchIndices': (match_indices), - 'ColToRowMatchDist': (match_dist), + 'ColToRowMatchIndices': match_indices, + 'ColToRowMatchDist': match_dist, } def test_check_output(self): @@ -96,5 +112,27 @@ class TestBipartiteMatchOpWithoutLoD(OpTest): self.check_output() +class TestBipartiteMatchOpWithPerPredictionType(OpTest): + def setUp(self): + self.op_type = 'bipartite_match' + lod = [[0, 5, 11, 23]] + dist = np.random.random((23, 237)).astype('float32') + match_indices, match_dist = batch_bipartite_match(dist, lod[0], + 'per_prediction', 0.5) + + self.inputs = {'DistMat': (dist, lod)} + self.outputs = { + 'ColToRowMatchIndices': match_indices, + 'ColToRowMatchDist': match_dist, + } + self.attrs = { + 'match_type': 'per_prediction', + 'dist_threshold': 0.5, + } + + def test_check_output(self): + self.check_output() + + if __name__ == '__main__': unittest.main() -- GitLab