未验证 提交 59bcb589 编写于 作者: Y Yu Yang 提交者: GitHub

Merge pull request #7759 from qingqing01/bipartite_match_op_fix

Fix bug and unit test in bipartite_match_op.
......@@ -21,8 +21,6 @@ namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
constexpr char kEPS = 1e-6;
class BipartiteMatchOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
......@@ -46,6 +44,7 @@ class BipartiteMatchKernel : public framework::OpKernel<T> {
// The match_dist must be initialized to 0 at first.
void BipartiteMatch(const Tensor& dist, int* match_indices,
T* match_dist) const {
constexpr T kEPS = static_cast<T>(1e-6);
PADDLE_ENFORCE_EQ(dist.dims().size(), 2, "The rank of dist must be 2.");
int64_t row = dist.dims()[0];
int64_t col = dist.dims()[1];
......
......@@ -16,13 +16,13 @@ import numpy as np
from op_test import OpTest
def bipartite_match(distance, match_indices, match_dis):
def bipartite_match(distance, match_indices, match_dist):
"""Bipartite Matching algorithm.
Arg:
distance (numpy.array) : The distance of two entries with shape [M, N].
match_indices (numpy.array): the matched indices from column to row
with shape [1, N], it must be initialized to -1.
match_dis (numpy.array): The matched distance from column to row
match_dist (numpy.array): The matched distance from column to row
with shape [1, N], it must be initialized to 0.
"""
match_pair = []
......@@ -36,13 +36,13 @@ def bipartite_match(distance, match_indices, match_dis):
row_indices = -1 * np.ones((row, ), dtype=np.int)
idx = 0
for i, j, dis in match_sorted:
for i, j, dist in match_sorted:
if idx >= row:
break
if match_indices[j] == -1 and row_indices[i] == -1 and dis > 0:
if match_indices[j] == -1 and row_indices[i] == -1 and dist > 0:
match_indices[j] = i
row_indices[i] = j
match_dis[j] = dis
match_dist[j] = dist
idx += 1
......@@ -55,24 +55,24 @@ def batch_bipartite_match(distance, lod):
n = len(lod) - 1
m = distance.shape[1]
match_indices = -1 * np.ones((n, m), dtype=np.int)
match_dis = np.zeros((n, m), dtype=np.float32)
match_dist = np.zeros((n, m), dtype=np.float32)
for i in range(len(lod) - 1):
bipartite_match(distance[lod[i]:lod[i + 1], :], match_indices[i, :],
match_dis[i, :])
return match_indices, match_dis
match_dist[i, :])
return match_indices, match_dist
class TestBipartiteMatchOpForWithLoD(OpTest):
def setUp(self):
self.op_type = 'bipartite_match'
lod = [[0, 5, 11, 23]]
dis = np.random.random((23, 217)).astype('float32')
match_indices, match_dis = batch_bipartite_match(dis, lod[0])
dist = np.random.random((23, 217)).astype('float32')
match_indices, match_dist = batch_bipartite_match(dist, lod[0])
self.inputs = {'DistMat': (dis, lod)}
self.inputs = {'DistMat': (dist, lod)}
self.outputs = {
'ColToRowMatchIndices': (match_indices),
'ColToRowMatchDis': (match_dis),
'ColToRowMatchDis': (match_dist),
}
def test_check_output(self):
......@@ -83,13 +83,13 @@ class TestBipartiteMatchOpWithoutLoD(OpTest):
def setUp(self):
self.op_type = 'bipartite_match'
lod = [[0, 8]]
dis = np.random.random((8, 17)).astype('float32')
match_indices, match_dis = batch_bipartite_match(dis, lod[0])
dist = np.random.random((8, 17)).astype('float32')
match_indices, match_dist = batch_bipartite_match(dist, lod[0])
self.inputs = {'DistMat': dis}
self.inputs = {'DistMat': dist}
self.outputs = {
'ColToRowMatchIndices': (match_indices),
'ColToRowMatchDis': (match_dis),
'ColToRowMatchIndices': match_indices,
'ColToRowMatchDis': match_dist,
}
def test_check_output(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册