From 575289209f38fb7b0342cd75d1742217723f1de0 Mon Sep 17 00:00:00 2001 From: qingqing01 Date: Tue, 23 Jan 2018 11:28:51 +0800 Subject: [PATCH] Fix bug and unit test in bipartite_match_op. --- paddle/operators/bipartite_match_op.cc | 3 +- .../v2/fluid/tests/test_bipartite_match_op.py | 34 +++++++++---------- 2 files changed, 18 insertions(+), 19 deletions(-) diff --git a/paddle/operators/bipartite_match_op.cc b/paddle/operators/bipartite_match_op.cc index b0f7376d27..83c8778fe4 100644 --- a/paddle/operators/bipartite_match_op.cc +++ b/paddle/operators/bipartite_match_op.cc @@ -21,8 +21,6 @@ namespace operators { using Tensor = framework::Tensor; using LoDTensor = framework::LoDTensor; -constexpr char kEPS = 1e-6; - class BipartiteMatchOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -46,6 +44,7 @@ class BipartiteMatchKernel : public framework::OpKernel { // The match_dist must be initialized to 0 at first. void BipartiteMatch(const Tensor& dist, int* match_indices, T* match_dist) const { + constexpr T kEPS = static_cast(1e-6); PADDLE_ENFORCE_EQ(dist.dims().size(), 2, "The rank of dist must be 2."); int64_t row = dist.dims()[0]; int64_t col = dist.dims()[1]; diff --git a/python/paddle/v2/fluid/tests/test_bipartite_match_op.py b/python/paddle/v2/fluid/tests/test_bipartite_match_op.py index 34101b1da4..7413829897 100644 --- a/python/paddle/v2/fluid/tests/test_bipartite_match_op.py +++ b/python/paddle/v2/fluid/tests/test_bipartite_match_op.py @@ -16,13 +16,13 @@ import numpy as np from op_test import OpTest -def bipartite_match(distance, match_indices, match_dis): +def bipartite_match(distance, match_indices, match_dist): """Bipartite Matching algorithm. Arg: distance (numpy.array) : The distance of two entries with shape [M, N]. match_indices (numpy.array): the matched indices from column to row with shape [1, N], it must be initialized to -1. - match_dis (numpy.array): The matched distance from column to row + match_dist (numpy.array): The matched distance from column to row with shape [1, N], it must be initialized to 0. """ match_pair = [] @@ -36,13 +36,13 @@ def bipartite_match(distance, match_indices, match_dis): row_indices = -1 * np.ones((row, ), dtype=np.int) idx = 0 - for i, j, dis in match_sorted: + for i, j, dist in match_sorted: if idx >= row: break - if match_indices[j] == -1 and row_indices[i] == -1 and dis > 0: + if match_indices[j] == -1 and row_indices[i] == -1 and dist > 0: match_indices[j] = i row_indices[i] = j - match_dis[j] = dis + match_dist[j] = dist idx += 1 @@ -55,24 +55,24 @@ def batch_bipartite_match(distance, lod): n = len(lod) - 1 m = distance.shape[1] match_indices = -1 * np.ones((n, m), dtype=np.int) - match_dis = np.zeros((n, m), dtype=np.float32) + match_dist = np.zeros((n, m), dtype=np.float32) for i in range(len(lod) - 1): bipartite_match(distance[lod[i]:lod[i + 1], :], match_indices[i, :], - match_dis[i, :]) - return match_indices, match_dis + match_dist[i, :]) + return match_indices, match_dist class TestBipartiteMatchOpForWithLoD(OpTest): def setUp(self): self.op_type = 'bipartite_match' lod = [[0, 5, 11, 23]] - dis = np.random.random((23, 217)).astype('float32') - match_indices, match_dis = batch_bipartite_match(dis, lod[0]) + dist = np.random.random((23, 217)).astype('float32') + match_indices, match_dist = batch_bipartite_match(dist, lod[0]) - self.inputs = {'DistMat': (dis, lod)} + self.inputs = {'DistMat': (dist, lod)} self.outputs = { 'ColToRowMatchIndices': (match_indices), - 'ColToRowMatchDis': (match_dis), + 'ColToRowMatchDis': (match_dist), } def test_check_output(self): @@ -83,13 +83,13 @@ class TestBipartiteMatchOpWithoutLoD(OpTest): def setUp(self): self.op_type = 'bipartite_match' lod = [[0, 8]] - dis = np.random.random((8, 17)).astype('float32') - match_indices, match_dis = batch_bipartite_match(dis, lod[0]) + dist = np.random.random((8, 17)).astype('float32') + match_indices, match_dist = batch_bipartite_match(dist, lod[0]) - self.inputs = {'DistMat': dis} + self.inputs = {'DistMat': dist} self.outputs = { - 'ColToRowMatchIndices': (match_indices), - 'ColToRowMatchDis': (match_dis), + 'ColToRowMatchIndices': match_indices, + 'ColToRowMatchDis': match_dist, } def test_check_output(self): -- GitLab