提交 57528920 编写于 作者: Q qingqing01

Fix bug and unit test in bipartite_match_op.

上级 9b1a17a8
...@@ -21,8 +21,6 @@ namespace operators { ...@@ -21,8 +21,6 @@ namespace operators {
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor; using LoDTensor = framework::LoDTensor;
constexpr char kEPS = 1e-6;
class BipartiteMatchOp : public framework::OperatorWithKernel { class BipartiteMatchOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
...@@ -46,6 +44,7 @@ class BipartiteMatchKernel : public framework::OpKernel<T> { ...@@ -46,6 +44,7 @@ class BipartiteMatchKernel : public framework::OpKernel<T> {
// The match_dist must be initialized to 0 at first. // The match_dist must be initialized to 0 at first.
void BipartiteMatch(const Tensor& dist, int* match_indices, void BipartiteMatch(const Tensor& dist, int* match_indices,
T* match_dist) const { T* match_dist) const {
constexpr T kEPS = static_cast<T>(1e-6);
PADDLE_ENFORCE_EQ(dist.dims().size(), 2, "The rank of dist must be 2."); PADDLE_ENFORCE_EQ(dist.dims().size(), 2, "The rank of dist must be 2.");
int64_t row = dist.dims()[0]; int64_t row = dist.dims()[0];
int64_t col = dist.dims()[1]; int64_t col = dist.dims()[1];
......
...@@ -16,13 +16,13 @@ import numpy as np ...@@ -16,13 +16,13 @@ import numpy as np
from op_test import OpTest from op_test import OpTest
def bipartite_match(distance, match_indices, match_dis): def bipartite_match(distance, match_indices, match_dist):
"""Bipartite Matching algorithm. """Bipartite Matching algorithm.
Arg: Arg:
distance (numpy.array) : The distance of two entries with shape [M, N]. distance (numpy.array) : The distance of two entries with shape [M, N].
match_indices (numpy.array): the matched indices from column to row match_indices (numpy.array): the matched indices from column to row
with shape [1, N], it must be initialized to -1. with shape [1, N], it must be initialized to -1.
match_dis (numpy.array): The matched distance from column to row match_dist (numpy.array): The matched distance from column to row
with shape [1, N], it must be initialized to 0. with shape [1, N], it must be initialized to 0.
""" """
match_pair = [] match_pair = []
...@@ -36,13 +36,13 @@ def bipartite_match(distance, match_indices, match_dis): ...@@ -36,13 +36,13 @@ def bipartite_match(distance, match_indices, match_dis):
row_indices = -1 * np.ones((row, ), dtype=np.int) row_indices = -1 * np.ones((row, ), dtype=np.int)
idx = 0 idx = 0
for i, j, dis in match_sorted: for i, j, dist in match_sorted:
if idx >= row: if idx >= row:
break break
if match_indices[j] == -1 and row_indices[i] == -1 and dis > 0: if match_indices[j] == -1 and row_indices[i] == -1 and dist > 0:
match_indices[j] = i match_indices[j] = i
row_indices[i] = j row_indices[i] = j
match_dis[j] = dis match_dist[j] = dist
idx += 1 idx += 1
...@@ -55,24 +55,24 @@ def batch_bipartite_match(distance, lod): ...@@ -55,24 +55,24 @@ def batch_bipartite_match(distance, lod):
n = len(lod) - 1 n = len(lod) - 1
m = distance.shape[1] m = distance.shape[1]
match_indices = -1 * np.ones((n, m), dtype=np.int) match_indices = -1 * np.ones((n, m), dtype=np.int)
match_dis = np.zeros((n, m), dtype=np.float32) match_dist = np.zeros((n, m), dtype=np.float32)
for i in range(len(lod) - 1): for i in range(len(lod) - 1):
bipartite_match(distance[lod[i]:lod[i + 1], :], match_indices[i, :], bipartite_match(distance[lod[i]:lod[i + 1], :], match_indices[i, :],
match_dis[i, :]) match_dist[i, :])
return match_indices, match_dis return match_indices, match_dist
class TestBipartiteMatchOpForWithLoD(OpTest): class TestBipartiteMatchOpForWithLoD(OpTest):
def setUp(self): def setUp(self):
self.op_type = 'bipartite_match' self.op_type = 'bipartite_match'
lod = [[0, 5, 11, 23]] lod = [[0, 5, 11, 23]]
dis = np.random.random((23, 217)).astype('float32') dist = np.random.random((23, 217)).astype('float32')
match_indices, match_dis = batch_bipartite_match(dis, lod[0]) match_indices, match_dist = batch_bipartite_match(dist, lod[0])
self.inputs = {'DistMat': (dis, lod)} self.inputs = {'DistMat': (dist, lod)}
self.outputs = { self.outputs = {
'ColToRowMatchIndices': (match_indices), 'ColToRowMatchIndices': (match_indices),
'ColToRowMatchDis': (match_dis), 'ColToRowMatchDis': (match_dist),
} }
def test_check_output(self): def test_check_output(self):
...@@ -83,13 +83,13 @@ class TestBipartiteMatchOpWithoutLoD(OpTest): ...@@ -83,13 +83,13 @@ class TestBipartiteMatchOpWithoutLoD(OpTest):
def setUp(self): def setUp(self):
self.op_type = 'bipartite_match' self.op_type = 'bipartite_match'
lod = [[0, 8]] lod = [[0, 8]]
dis = np.random.random((8, 17)).astype('float32') dist = np.random.random((8, 17)).astype('float32')
match_indices, match_dis = batch_bipartite_match(dis, lod[0]) match_indices, match_dist = batch_bipartite_match(dist, lod[0])
self.inputs = {'DistMat': dis} self.inputs = {'DistMat': dist}
self.outputs = { self.outputs = {
'ColToRowMatchIndices': (match_indices), 'ColToRowMatchIndices': match_indices,
'ColToRowMatchDis': (match_dis), 'ColToRowMatchDis': match_dist,
} }
def test_check_output(self): def test_check_output(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册