提交 e44dedf9 编写于 作者: D dangqingqing

Fix the warning and unit test.

上级 74af23b6
...@@ -62,7 +62,7 @@ class BipartiteMatchKernel : public framework::OpKernel<T> { ...@@ -62,7 +62,7 @@ class BipartiteMatchKernel : public framework::OpKernel<T> {
if (match_indices[j] != -1) { if (match_indices[j] != -1) {
continue; continue;
} }
for (int k = 0; k < row_pool.size(); ++k) { for (size_t k = 0; k < row_pool.size(); ++k) {
int m = row_pool[k]; int m = row_pool[k];
// distance is 0 between m-th row and j-th column // distance is 0 between m-th row and j-th column
if (dist_data[m * col + j] < kEPS) { if (dist_data[m * col + j] < kEPS) {
......
...@@ -69,7 +69,7 @@ class TestBipartiteMatchOpForWithLoD(OpTest): ...@@ -69,7 +69,7 @@ class TestBipartiteMatchOpForWithLoD(OpTest):
dis = np.random.random((23, 217)).astype('float32') dis = np.random.random((23, 217)).astype('float32')
match_indices, match_dis = batch_bipartite_match(dis, lod[0]) match_indices, match_dis = batch_bipartite_match(dis, lod[0])
self.inputs = {'DisMat': (dis, lod)} self.inputs = {'DistMat': (dis, lod)}
self.outputs = { self.outputs = {
'ColToRowMatchIndices': (match_indices), 'ColToRowMatchIndices': (match_indices),
'ColToRowMatchDis': (match_dis), 'ColToRowMatchDis': (match_dis),
...@@ -86,7 +86,7 @@ class TestBipartiteMatchOpWithoutLoD(OpTest): ...@@ -86,7 +86,7 @@ class TestBipartiteMatchOpWithoutLoD(OpTest):
dis = np.random.random((8, 17)).astype('float32') dis = np.random.random((8, 17)).astype('float32')
match_indices, match_dis = batch_bipartite_match(dis, lod[0]) match_indices, match_dis = batch_bipartite_match(dis, lod[0])
self.inputs = {'DisMat': dis} self.inputs = {'DistMat': dis}
self.outputs = { self.outputs = {
'ColToRowMatchIndices': (match_indices), 'ColToRowMatchIndices': (match_indices),
'ColToRowMatchDis': (match_dis), 'ColToRowMatchDis': (match_dis),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册