diff --git a/paddle/operators/bipartite_match_op.cc b/paddle/operators/bipartite_match_op.cc index 8dbade65a5b2d181b3a531f25c7004d6fe1d3f4d..c2d30c7d926b66703b5a4d8e4ea0a91d6f77c1c3 100644 --- a/paddle/operators/bipartite_match_op.cc +++ b/paddle/operators/bipartite_match_op.cc @@ -21,6 +21,8 @@ namespace operators { using Tensor = framework::Tensor; using LoDTensor = framework::LoDTensor; +constexpr char kEPS = 1e-6; + class BipartiteMatchOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -41,12 +43,13 @@ template class BipartiteMatchKernel : public framework::OpKernel { public: // The match_indices must be initialized to -1 at first. - // The match_dis must be initialized to 0 at first. - void BipartiteMatch(const Tensor& dis, int* match_indices, - T* match_dis) const { - int64_t row = dis.dims()[0]; - int64_t col = dis.dims()[1]; - auto* dis_data = dis.data(); + // The match_dist must be initialized to 0 at first. + void BipartiteMatch(const Tensor& dist, int* match_indices, + T* match_dist) const { + PADDLE_ENFORCE_EQ(dist.dims().size(), 2, "The rank of dist must be 2."); + int64_t row = dist.dims()[0]; + int64_t col = dist.dims()[1]; + auto* dist_data = dist.data(); std::vector row_pool; for (int i = 0; i < row; ++i) { row_pool.push_back(i); @@ -54,7 +57,7 @@ class BipartiteMatchKernel : public framework::OpKernel { while (row_pool.size() > 0) { int max_idx = -1; int max_row_idx = -1; - T max_dis = -1; + T max_dist = -1; for (int64_t j = 0; j < col; ++j) { if (match_indices[j] != -1) { continue; @@ -62,13 +65,13 @@ class BipartiteMatchKernel : public framework::OpKernel { for (int k = 0; k < row_pool.size(); ++k) { int m = row_pool[k]; // distance is 0 between m-th row and j-th column - if (dis_data[m * col + j] < 1e-6) { + if (dist_data[m * col + j] < kEPS) { continue; } - if (dis_data[m * col + j] > max_dis) { + if (dist_data[m * col + j] > max_dist) { max_idx = j; max_row_idx = m; - max_dis = dis_data[m * col + j]; + max_dist = dist_data[m * col + j]; } } } @@ -78,7 +81,7 @@ class BipartiteMatchKernel : public framework::OpKernel { } else { PADDLE_ENFORCE_EQ(match_indices[max_idx], -1); match_indices[max_idx] = max_row_idx; - match_dis[max_idx] = max_dis; + match_dist[max_idx] = max_dist; // Erase the row index. row_pool.erase( std::find(row_pool.begin(), row_pool.end(), max_row_idx)); @@ -87,34 +90,38 @@ class BipartiteMatchKernel : public framework::OpKernel { } void Compute(const framework::ExecutionContext& context) const override { - auto* dis_mat = context.Input("DisMat"); + auto* dist_mat = context.Input("DisMat"); auto* match_indices = context.Output("ColToRowMatchIndices"); - auto* match_dis = context.Output("ColToRowMatchDis"); + auto* match_dist = context.Output("ColToRowMatchDis"); auto& dev_ctx = context.device_context(); - auto col = dis_mat->dims()[1]; + auto col = dist_mat->dims()[1]; - int64_t n = dis_mat->lod().size() == 0 + int64_t n = dist_mat->lod().size() == 0UL ? 1 - : static_cast(dis_mat->lod().back().size() - 1); + : static_cast(dist_mat->lod().back().size() - 1); + if (dist_mat->lod().size()) { + PADDLE_ENFORCE_EQ(dist_mat->lod().size(), 1UL, + "Only support 1 level of LoD."); + } match_indices->mutable_data({n, col}, context.GetPlace()); - match_dis->mutable_data({n, col}, context.GetPlace()); + match_dist->mutable_data({n, col}, context.GetPlace()); math::SetConstant iset; iset(dev_ctx, match_indices, static_cast(-1)); math::SetConstant tset; - tset(dev_ctx, match_dis, static_cast(0)); + tset(dev_ctx, match_dist, static_cast(0)); int* indices = match_indices->data(); - T* dis = match_dis->data(); + T* dist = match_dist->data(); if (n == 1) { - BipartiteMatch(*dis_mat, indices, dis); + BipartiteMatch(*dist_mat, indices, dist); } else { - auto lod = dis_mat->lod().back(); + auto lod = dist_mat->lod().back(); for (size_t i = 0; i < lod.size() - 1; ++i) { - Tensor one_ins = dis_mat->Slice(lod[i], lod[i + 1]); - BipartiteMatch(one_ins, indices + i * col, dis + i * col); + Tensor one_ins = dist_mat->Slice(lod[i], lod[i + 1]); + BipartiteMatch(one_ins, indices + i * col, dist + i * col); } } } @@ -131,7 +138,7 @@ class BipartiteMatchOpMaker : public framework::OpProtoAndCheckerMaker { "represented by each row and each column. For example, assumed one " "entity is A with shape [K], another entity is B with shape [M]. The " "DisMat[i][j] is the distance between A[i] and B[j]. The bigger " - "the distance is, the more similar the pairs are. Please note, " + "the distance is, the better macthing the pairs are. Please note, " "This tensor can contain LoD information to represent a batch of " "inputs. One instance of this batch can contain different numbers of " "entities."); @@ -140,20 +147,25 @@ class BipartiteMatchOpMaker : public framework::OpProtoAndCheckerMaker { "N is the batch size. If ColToRowMatchIndices[i][j] is -1, it " "means B[j] does not match any entity in i-th instance. " "Otherwise, it means B[j] is matched to row " - "RowToColMatchIndices[i][j] in i-th instance. The row number of " - "i-th instance is saved in RowToColMatchIndices[i][j]."); + "ColToRowMatchIndices[i][j] in i-th instance. The row number of " + "i-th instance is saved in ColToRowMatchIndices[i][j]."); AddOutput("ColToRowMatchDis", "(Tensor) A 2-D Tensor with shape [N, M] in float type. " "N is batch size. If ColToRowMatchIndices[i][j] is -1, " "ColToRowMatchDis[i][j] is also -1.0. Otherwise, assumed " - "RowToColMatchIndices[i][j] = d, and the row offsets of each " + "ColToRowMatchIndices[i][j] = d, and the row offsets of each " "instance are called LoD. Then " "ColToRowMatchDis[i][j] = DisMat[d+LoD[i]][j]"); AddComment(R"DOC( This operator is a greedy bipartite matching algorithm, which is used to -obtain the matching with the (greedy) maximum distance based on the input -distance matrix. There are two outputs to save matched indices and distance. -And this operator only calculate matched indices from column to row. +obtain the matching with the maximum distance based on the input +distance matrix. For input 2D matrix, the bipartite matching algorithm can +find the matched column for each row, also can find the matched row for +each column. And this operator only calculate matched indices from column +to row. For each instance, the number of matched indices is the number of +of columns of the input ditance matrix. + +There are two outputs to save matched indices and distance. A simple description, this algothrim matched the best (maximum distance) row entity to the column entity and the matched indices are not duplicated in each row of ColToRowMatchIndices. If the column entity is not matched