From bf33b191d0cbb950d50f003f08ed3f16f0e2b92e Mon Sep 17 00:00:00 2001 From: dangqingqing Date: Thu, 18 Jan 2018 18:41:08 +0800 Subject: [PATCH] Add bipartite matching operator and unit testing. --- paddle/operators/bipartite_match_op.cc | 178 ++++++++++++++++++ .../v2/fluid/tests/test_bipartite_match_op.py | 100 ++++++++++ 2 files changed, 278 insertions(+) create mode 100644 paddle/operators/bipartite_match_op.cc create mode 100644 python/paddle/v2/fluid/tests/test_bipartite_match_op.py diff --git a/paddle/operators/bipartite_match_op.cc b/paddle/operators/bipartite_match_op.cc new file mode 100644 index 00000000000..8dbade65a5b --- /dev/null +++ b/paddle/operators/bipartite_match_op.cc @@ -0,0 +1,178 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/framework/op_registry.h" +#include "paddle/operators/math/math_function.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using LoDTensor = framework::LoDTensor; + +class BipartiteMatchOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("DisMat"), + "Input(DisMat) of BipartiteMatch should not be null."); + + auto dims = ctx->GetInputDim("DisMat"); + PADDLE_ENFORCE_EQ(dims.size(), 2, "The rank of Input(DisMat) must be 2."); + + ctx->SetOutputDim("ColToRowMatchIndices", dims); + ctx->SetOutputDim("ColToRowMatchDis", dims); + } +}; + +template +class BipartiteMatchKernel : public framework::OpKernel { + public: + // The match_indices must be initialized to -1 at first. + // The match_dis must be initialized to 0 at first. + void BipartiteMatch(const Tensor& dis, int* match_indices, + T* match_dis) const { + int64_t row = dis.dims()[0]; + int64_t col = dis.dims()[1]; + auto* dis_data = dis.data(); + std::vector row_pool; + for (int i = 0; i < row; ++i) { + row_pool.push_back(i); + } + while (row_pool.size() > 0) { + int max_idx = -1; + int max_row_idx = -1; + T max_dis = -1; + for (int64_t j = 0; j < col; ++j) { + if (match_indices[j] != -1) { + continue; + } + for (int k = 0; k < row_pool.size(); ++k) { + int m = row_pool[k]; + // distance is 0 between m-th row and j-th column + if (dis_data[m * col + j] < 1e-6) { + continue; + } + if (dis_data[m * col + j] > max_dis) { + max_idx = j; + max_row_idx = m; + max_dis = dis_data[m * col + j]; + } + } + } + if (max_idx == -1) { + // Cannot find good match. + break; + } else { + PADDLE_ENFORCE_EQ(match_indices[max_idx], -1); + match_indices[max_idx] = max_row_idx; + match_dis[max_idx] = max_dis; + // Erase the row index. + row_pool.erase( + std::find(row_pool.begin(), row_pool.end(), max_row_idx)); + } + } + } + + void Compute(const framework::ExecutionContext& context) const override { + auto* dis_mat = context.Input("DisMat"); + auto* match_indices = context.Output("ColToRowMatchIndices"); + auto* match_dis = context.Output("ColToRowMatchDis"); + + auto& dev_ctx = context.device_context(); + + auto col = dis_mat->dims()[1]; + + int64_t n = dis_mat->lod().size() == 0 + ? 1 + : static_cast(dis_mat->lod().back().size() - 1); + match_indices->mutable_data({n, col}, context.GetPlace()); + match_dis->mutable_data({n, col}, context.GetPlace()); + + math::SetConstant iset; + iset(dev_ctx, match_indices, static_cast(-1)); + math::SetConstant tset; + tset(dev_ctx, match_dis, static_cast(0)); + + int* indices = match_indices->data(); + T* dis = match_dis->data(); + if (n == 1) { + BipartiteMatch(*dis_mat, indices, dis); + } else { + auto lod = dis_mat->lod().back(); + for (size_t i = 0; i < lod.size() - 1; ++i) { + Tensor one_ins = dis_mat->Slice(lod[i], lod[i + 1]); + BipartiteMatch(one_ins, indices + i * col, dis + i * col); + } + } + } +}; + +class BipartiteMatchOpMaker : public framework::OpProtoAndCheckerMaker { + public: + BipartiteMatchOpMaker(OpProto* proto, OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput( + "DisMat", + "(LoDTensor or Tensor) this input is a 2-D LoDTensor with shape " + "[K, M]. It is pair-wise distance matrix between the entities " + "represented by each row and each column. For example, assumed one " + "entity is A with shape [K], another entity is B with shape [M]. The " + "DisMat[i][j] is the distance between A[i] and B[j]. The bigger " + "the distance is, the more similar the pairs are. Please note, " + "This tensor can contain LoD information to represent a batch of " + "inputs. One instance of this batch can contain different numbers of " + "entities."); + AddOutput("ColToRowMatchIndices", + "(Tensor) A 2-D Tensor with shape [N, M] in int type. " + "N is the batch size. If ColToRowMatchIndices[i][j] is -1, it " + "means B[j] does not match any entity in i-th instance. " + "Otherwise, it means B[j] is matched to row " + "RowToColMatchIndices[i][j] in i-th instance. The row number of " + "i-th instance is saved in RowToColMatchIndices[i][j]."); + AddOutput("ColToRowMatchDis", + "(Tensor) A 2-D Tensor with shape [N, M] in float type. " + "N is batch size. If ColToRowMatchIndices[i][j] is -1, " + "ColToRowMatchDis[i][j] is also -1.0. Otherwise, assumed " + "RowToColMatchIndices[i][j] = d, and the row offsets of each " + "instance are called LoD. Then " + "ColToRowMatchDis[i][j] = DisMat[d+LoD[i]][j]"); + AddComment(R"DOC( +This operator is a greedy bipartite matching algorithm, which is used to +obtain the matching with the (greedy) maximum distance based on the input +distance matrix. There are two outputs to save matched indices and distance. +And this operator only calculate matched indices from column to row. +A simple description, this algothrim matched the best (maximum distance) +row entity to the column entity and the matched indices are not duplicated +in each row of ColToRowMatchIndices. If the column entity is not matched +any row entity, set -1 in ColToRowMatchIndices. + +Please note that the input DisMat can be LoDTensor (with LoD) or Tensor. +If LoDTensor with LoD, the height of ColToRowMatchIndices is batch size. +If Tensor, the height of ColToRowMatchIndices is 1. + +)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(bipartite_match, ops::BipartiteMatchOp, + ops::BipartiteMatchOpMaker, + paddle::framework::EmptyGradOpMaker); +REGISTER_OP_CPU_KERNEL(bipartite_match, ops::BipartiteMatchKernel, + ops::BipartiteMatchKernel); diff --git a/python/paddle/v2/fluid/tests/test_bipartite_match_op.py b/python/paddle/v2/fluid/tests/test_bipartite_match_op.py new file mode 100644 index 00000000000..8f1db35d3c5 --- /dev/null +++ b/python/paddle/v2/fluid/tests/test_bipartite_match_op.py @@ -0,0 +1,100 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. +import unittest +import numpy as np +from op_test import OpTest + + +def bipartite_match(distance, match_indices, match_dis): + """Bipartite Matching algorithm. + Arg: + distance (numpy.array) : The distance of two entries with shape [M, N]. + match_indices (numpy.array): the matched indices from column to row + with shape [1, N], it must be initialized to -1. + match_dis (numpy.array): The matched distance from column to row + with shape [1, N], it must be initialized to 0. + """ + match_pair = [] + row, col = distance.shape + for i in range(row): + for j in range(col): + match_pair.append((i, j, distance[i][j])) + + match_sorted = sorted(match_pair, key=lambda tup: tup[2], reverse=True) + + row_indices = -1 * np.ones((row, ), dtype=np.int) + + idx = 0 + for i, j, dis in match_sorted: + if idx >= row: + break + if match_indices[j] == -1 and row_indices[i] == -1 and dis > 0: + match_indices[j] = i + row_indices[i] = j + match_dis[j] = dis + idx += 1 + + +def batch_bipartite_match(distance, lod): + """Bipartite Matching algorithm for batch input. + Arg: + distance (numpy.array) : The distance of two entries with shape [M, N]. + lod (list of int): The offsets of each input in this batch. + """ + n = len(lod) - 1 + m = distance.shape[1] + match_indices = -1 * np.ones((n, m), dtype=np.int) + match_dis = np.zeros((n, m), dtype=np.float32) + for i in range(len(lod) - 1): + bipartite_match(distance[lod[i]:lod[i + 1], :], match_indices[i, :], + match_dis[i, :]) + return match_indices, match_dis + + +class TestBipartiteMatchOpForWithLoD(OpTest): + def setUp(self): + self.op_type = 'bipartite_match' + lod = [[0, 5, 11, 23]] + dis = np.random.random((23, 217)).astype('float32') + match_indices, match_dis = batch_bipartite_match(dis, lod[0]) + + self.inputs = {'DisMat': (dis, lod)} + self.outputs = { + 'ColToRowMatchIndices': (match_indices), + 'ColToRowMatchDis': (match_dis), + } + + def test_check_output(self): + self.check_output() + + +class TestBipartiteMatchOpWithoutLoD(OpTest): + def setUp(self): + self.op_type = 'bipartite_match' + lod = [[0, 8]] + dis = np.random.random((8, 17)).astype('float32') + match_indices, match_dis = batch_bipartite_match(dis, lod[0]) + + self.inputs = {'DisMat': dis} + self.outputs = { + 'ColToRowMatchIndices': (match_indices), + 'ColToRowMatchDis': (match_dis), + } + + def test_check_output(self): + self.check_output() + + +if __name__ == '__main__': + unittest.main() -- GitLab