# test_bipartite_match_op.py
#  Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import unittest
import numpy as np
from op_test import OpTest


def bipartite_match(distance, match_indices, match_dis):
    """Bipartite Matching algorithm.

    Greedily matches rows to columns. All (row, col) pairs are visited in
    descending order of distance. A pair is accepted when its row and its
    column are both still unmatched and its distance is positive. Matching
    stops as soon as every row has been matched.

    Results are written in place into ``match_indices`` and ``match_dis``;
    nothing is returned.

    Args:
        distance (numpy.array): The distance of two entries with shape [M, N].
            Larger values are preferred when matching.
        match_indices (numpy.array): The matched indices from column to row,
            a 1-D array of length N. It must be initialized to -1. Entry ``j``
            receives the row matched to column ``j``, or stays -1.
        match_dis (numpy.array): The matched distance from column to row,
            a 1-D array of length N. It must be initialized to 0. Entry ``j``
            receives ``distance[match_indices[j], j]`` for matched columns.
    """
    num_rows, num_cols = distance.shape
    match_pair = [(i, j, distance[i][j])
                  for i in range(num_rows) for j in range(num_cols)]

    # sorted() stays stable with reverse=True, so ties keep row-major order.
    match_sorted = sorted(match_pair, key=lambda tup: tup[2], reverse=True)

    # row_indices[i] is the column matched to row i, or -1 while row i is free.
    # ``np.int`` was removed in NumPy 1.24; builtin ``int`` is what it aliased.
    row_indices = -1 * np.ones((num_rows, ), dtype=int)

    num_matched = 0
    for i, j, dis in match_sorted:
        if num_matched >= num_rows:
            # Every row already has a column; nothing more can be matched.
            break
        if match_indices[j] == -1 and row_indices[i] == -1 and dis > 0:
            match_indices[j] = i
            row_indices[i] = j
            match_dis[j] = dis
            num_matched += 1


def batch_bipartite_match(distance, lod):
    """Bipartite Matching algorithm for batch input.

    The rows of ``distance`` are split into ``len(lod) - 1`` consecutive
    segments. Segment ``i`` is ``distance[lod[i]:lod[i + 1], :]``, and each
    segment is matched independently with ``bipartite_match``.

    Args:
        distance (numpy.array): The distance of two entries with shape [M, N],
            where M is the total number of rows over all segments.
        lod (list of int): The offsets of each input in this batch.

    Returns:
        tuple: ``(match_indices, match_dis)``, each with shape
        [len(lod) - 1, N].
        ``match_indices`` holds the matched row index for every column,
        counted from the start of that column's segment, or -1 if the column
        is unmatched.
        ``match_dis`` (float32) holds the matched distance, or 0.
    """
    num_segments = len(lod) - 1
    num_cols = distance.shape[1]
    # ``np.int`` was removed in NumPy 1.24; builtin ``int`` is what it aliased.
    match_indices = -1 * np.ones((num_segments, num_cols), dtype=int)
    match_dis = np.zeros((num_segments, num_cols), dtype=np.float32)
    for i, (start, end) in enumerate(zip(lod[:-1], lod[1:])):
        # match_indices[i, :] and match_dis[i, :] are views, so
        # bipartite_match fills the output rows in place.
        bipartite_match(distance[start:end, :], match_indices[i, :],
                        match_dis[i, :])
    return match_indices, match_dis


class TestBipartiteMatchOpForWithLoD(OpTest):
    """Checks the ``bipartite_match`` op on a DistMat input given with LoD.

    The 23-row distance matrix is split into three segments, and the expected
    outputs come from the NumPy reference ``batch_bipartite_match``.
    """

    def setUp(self):
        self.op_type = 'bipartite_match'
        # Row offsets of the three segments: rows [0, 5), [5, 11), [11, 23).
        lod = [[0, 5, 11, 23]]
        dis = np.random.random((23, 217)).astype('float32')
        match_indices, match_dis = batch_bipartite_match(dis, lod[0])

        self.inputs = {'DistMat': (dis, lod)}
        self.outputs = {
            'ColToRowMatchIndices': match_indices,
            'ColToRowMatchDis': match_dis,
        }

    def test_check_output(self):
        self.check_output()


class TestBipartiteMatchOpWithoutLoD(OpTest):
    """Checks the ``bipartite_match`` op on a plain DistMat with no LoD.

    The expected outputs are computed by treating all 8 rows as one segment.
    """

    def setUp(self):
        self.op_type = 'bipartite_match'
        # A single segment spanning every row of the distance matrix.
        lod = [[0, 8]]
        dis = np.random.random((8, 17)).astype('float32')
        match_indices, match_dis = batch_bipartite_match(dis, lod[0])

        self.inputs = {'DistMat': dis}
        self.outputs = {
            'ColToRowMatchIndices': match_indices,
            'ColToRowMatchDis': match_dis,
        }

    def test_check_output(self):
        self.check_output()


if __name__ == "__main__":
    # Executed directly: hand control to unittest to run this module's tests.
    unittest.main()