#  Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
from eager_op_test import OpTest


def bipartite_match(distance, match_indices, match_dist):
    """Greedy bipartite matching between the rows and columns of ``distance``.

    Candidate (row, column) pairs are visited in descending order of
    distance. A pair is accepted when neither its row nor its column has
    been matched yet and its distance is strictly positive. The results are
    written in place into ``match_indices`` and ``match_dist``.

    Args:
        distance (numpy.array): The distance of two entries with shape [M, N].
        match_indices (numpy.array): The matched row index for each column,
            indexed by column (N entries). It must be initialized to -1.
        match_dist (numpy.array): The matched distance for each column,
            indexed by column (N entries). It must be initialized to 0.
    """
    row, col = distance.shape
    match_pair = [
        (i, j, distance[i][j]) for i in range(row) for j in range(col)
    ]
    # sorted() is stable, so pairs with equal distance keep row-major order.
    match_sorted = sorted(match_pair, key=lambda tup: tup[2], reverse=True)

    # row_indices[i] holds the column matched to row i, or -1 if unmatched.
    row_indices = -1 * np.ones((row,), dtype=np.int_)

    # Each match consumes one row and one column, so no further pair can be
    # accepted once min(row, col) matches have been made.
    max_matches = min(row, col)
    matched = 0
    for i, j, dist in match_sorted:
        if matched >= max_matches:
            break
        if match_indices[j] == -1 and row_indices[i] == -1 and dist > 0:
            match_indices[j] = i
            row_indices[i] = j
            match_dist[j] = dist
            matched += 1


def argmax_match(distance, match_indices, match_dist, threshold):
    """Match each still-unmatched column to its best row above a threshold.

    Columns whose entry in ``match_indices`` is not -1 are left untouched.
    For every other column, the row with the largest distance among the rows
    satisfying ``distance >= threshold`` is selected. If no row reaches the
    threshold the column stays unmatched. The results are written in place.

    Args:
        distance (numpy.array): The distance of two entries with shape [M, N].
        match_indices (numpy.array): The matched row index for each column,
            indexed by column; -1 marks an unmatched column.
        match_dist (numpy.array): The matched distance for each column,
            indexed by column.
        threshold (float): Minimum distance a row needs to be a candidate.
    """
    _, col = distance.shape
    for j in range(col):
        if match_indices[j] != -1:
            continue
        col_dist = distance[:, j]
        # Row indices whose distance to column j reaches the threshold.
        candidates = np.flatnonzero(col_dist >= threshold)
        if candidates.size == 0:
            continue
        best_row = candidates[np.argmax(col_dist[candidates])]
        match_indices[j] = best_row
        match_dist[j] = col_dist[best_row]


def batch_bipartite_match(distance, lod, match_type=None, dist_threshold=None):
    """Bipartite Matching algorithm for batch input.

    The rows of ``distance`` are split into consecutive segments whose
    lengths are given by ``lod``, and each segment is matched independently
    against all columns. Segments of length 0 are skipped and keep their
    initial values (-1 for indices, 0 for distances).

    Args:
        distance (numpy.array): The distance of two entries with shape [M, N].
        lod (list of int): The length of each input in this batch.
        match_type (str, optional): When equal to 'per_prediction', columns
            left unmatched by the bipartite step are additionally matched
            with ``argmax_match``.
        dist_threshold (float, optional): Threshold forwarded to
            ``argmax_match``; only used when ``match_type`` is
            'per_prediction'.

    Returns:
        tuple: ``(match_indices, match_dist)``, both with shape
        [len(lod), N]. ``match_indices`` holds, per segment and column, the
        matched row index relative to the start of that segment, or -1.
        ``match_dist`` holds the corresponding distances as float32.
    """
    n = len(lod)
    m = distance.shape[1]
    match_indices = -1 * np.ones((n, m), dtype=np.int_)
    match_dist = np.zeros((n, m), dtype=np.float32)
    cur_offset = 0
    for i, length in enumerate(lod):
        if length == 0:
            continue
        # Rows belonging to the i-th input of the batch.
        segment = distance[cur_offset : (cur_offset + length), :]
        # The row slices are views, so the helpers fill the outputs in place.
        bipartite_match(segment, match_indices[i, :], match_dist[i, :])
        if match_type == 'per_prediction':
            argmax_match(
                segment,
                match_indices[i, :],
                match_dist[i, :],
                dist_threshold,
            )
        cur_offset += length
    return match_indices, match_dist


class TestBipartiteMatchOpWithLoD(OpTest):
    """Check the ``bipartite_match`` op on a LoD input with three segments."""

    def setUp(self):
        self.op_type = 'bipartite_match'
        # Three inputs of 5, 6 and 12 rows, i.e. 23 rows in total.
        lod = [[5, 6, 12]]
        dist = np.random.random((23, 217)).astype('float32')
        # Expected outputs come from the NumPy reference implementation.
        match_indices, match_dist = batch_bipartite_match(dist, lod[0])

        self.inputs = {'DistMat': (dist, lod)}
        self.outputs = {
            'ColToRowMatchIndices': match_indices,
            'ColToRowMatchDist': match_dist,
        }

    def test_check_output(self):
        self.check_output(check_dygraph=False)


class TestBipartiteMatchOpWithoutLoD(OpTest):
    """Check the ``bipartite_match`` op on a plain array input (no LoD)."""

    def setUp(self):
        self.op_type = 'bipartite_match'
        # Only used by the reference: the whole matrix is one 8-row input.
        lod = [[8]]
        dist = np.random.random((8, 17)).astype('float32')
        # Expected outputs come from the NumPy reference implementation.
        match_indices, match_dist = batch_bipartite_match(dist, lod[0])

        # The op receives the bare array, without LoD information.
        self.inputs = {'DistMat': dist}
        self.outputs = {
            'ColToRowMatchIndices': match_indices,
            'ColToRowMatchDist': match_dist,
        }

    def test_check_output(self):
        self.check_output(check_dygraph=False)


class TestBipartiteMatchOpWithoutLoDLargeScaleInput(OpTest):
    """Check the ``bipartite_match`` op on a 300-row input without LoD."""

    def setUp(self):
        self.op_type = 'bipartite_match'
        # Only used by the reference: the whole matrix is one 300-row input.
        lod = [[300]]
        dist = np.random.random((300, 17)).astype('float32')
        # Expected outputs come from the NumPy reference implementation.
        match_indices, match_dist = batch_bipartite_match(dist, lod[0])

        # The op receives the bare array, without LoD information.
        self.inputs = {'DistMat': dist}
        self.outputs = {
            'ColToRowMatchIndices': match_indices,
            'ColToRowMatchDist': match_dist,
        }

    def test_check_output(self):
        self.check_output(check_dygraph=False)


class TestBipartiteMatchOpWithPerPredictionType(OpTest):
    """Check the ``bipartite_match`` op with ``match_type='per_prediction'``."""

    def setUp(self):
        self.op_type = 'bipartite_match'
        # Three inputs of 5, 6 and 12 rows, i.e. 23 rows in total.
        lod = [[5, 6, 12]]
        dist = np.random.random((23, 237)).astype('float32')
        # The reference is run with the same match type and threshold that
        # are passed to the op through ``self.attrs`` below.
        match_indices, match_dist = batch_bipartite_match(
            dist, lod[0], 'per_prediction', 0.5
        )

        self.inputs = {'DistMat': (dist, lod)}
        self.outputs = {
            'ColToRowMatchIndices': match_indices,
            'ColToRowMatchDist': match_dist,
        }
        self.attrs = {
            'match_type': 'per_prediction',
            'dist_threshold': 0.5,
        }

    def test_check_output(self):
        self.check_output(check_dygraph=False)


class TestBipartiteMatchOpWithEmptyLoD(OpTest):
    """Check the ``bipartite_match`` op when one LoD segment is empty."""

    def setUp(self):
        self.op_type = 'bipartite_match'
        # Four inputs; the third has zero rows, the rest total 23 rows.
        lod = [[5, 6, 0, 12]]
        dist = np.random.random((23, 217)).astype('float32')
        # Expected outputs come from the NumPy reference implementation.
        match_indices, match_dist = batch_bipartite_match(dist, lod[0])

        self.inputs = {'DistMat': (dist, lod)}
        self.outputs = {
            'ColToRowMatchIndices': match_indices,
            'ColToRowMatchDist': match_dist,
        }

    def test_check_output(self):
        self.check_output(check_dygraph=False)


# Run the test cases in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()