提交 07908686 编写于 作者: D dangqingqing

Update some comments and add more check.

上级 c2edcde1
...@@ -21,6 +21,8 @@ namespace operators { ...@@ -21,6 +21,8 @@ namespace operators {
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor; using LoDTensor = framework::LoDTensor;
constexpr char kEPS = 1e-6;
class BipartiteMatchOp : public framework::OperatorWithKernel { class BipartiteMatchOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
...@@ -41,12 +43,13 @@ template <typename T> ...@@ -41,12 +43,13 @@ template <typename T>
class BipartiteMatchKernel : public framework::OpKernel<T> { class BipartiteMatchKernel : public framework::OpKernel<T> {
public: public:
// The match_indices must be initialized to -1 at first. // The match_indices must be initialized to -1 at first.
// The match_dis must be initialized to 0 at first. // The match_dist must be initialized to 0 at first.
void BipartiteMatch(const Tensor& dis, int* match_indices, void BipartiteMatch(const Tensor& dist, int* match_indices,
T* match_dis) const { T* match_dist) const {
int64_t row = dis.dims()[0]; PADDLE_ENFORCE_EQ(dist.dims().size(), 2, "The rank of dist must be 2.");
int64_t col = dis.dims()[1]; int64_t row = dist.dims()[0];
auto* dis_data = dis.data<T>(); int64_t col = dist.dims()[1];
auto* dist_data = dist.data<T>();
std::vector<int> row_pool; std::vector<int> row_pool;
for (int i = 0; i < row; ++i) { for (int i = 0; i < row; ++i) {
row_pool.push_back(i); row_pool.push_back(i);
...@@ -54,7 +57,7 @@ class BipartiteMatchKernel : public framework::OpKernel<T> { ...@@ -54,7 +57,7 @@ class BipartiteMatchKernel : public framework::OpKernel<T> {
while (row_pool.size() > 0) { while (row_pool.size() > 0) {
int max_idx = -1; int max_idx = -1;
int max_row_idx = -1; int max_row_idx = -1;
T max_dis = -1; T max_dist = -1;
for (int64_t j = 0; j < col; ++j) { for (int64_t j = 0; j < col; ++j) {
if (match_indices[j] != -1) { if (match_indices[j] != -1) {
continue; continue;
...@@ -62,13 +65,13 @@ class BipartiteMatchKernel : public framework::OpKernel<T> { ...@@ -62,13 +65,13 @@ class BipartiteMatchKernel : public framework::OpKernel<T> {
for (int k = 0; k < row_pool.size(); ++k) { for (int k = 0; k < row_pool.size(); ++k) {
int m = row_pool[k]; int m = row_pool[k];
// distance is 0 between m-th row and j-th column // distance is 0 between m-th row and j-th column
if (dis_data[m * col + j] < 1e-6) { if (dist_data[m * col + j] < kEPS) {
continue; continue;
} }
if (dis_data[m * col + j] > max_dis) { if (dist_data[m * col + j] > max_dist) {
max_idx = j; max_idx = j;
max_row_idx = m; max_row_idx = m;
max_dis = dis_data[m * col + j]; max_dist = dist_data[m * col + j];
} }
} }
} }
...@@ -78,7 +81,7 @@ class BipartiteMatchKernel : public framework::OpKernel<T> { ...@@ -78,7 +81,7 @@ class BipartiteMatchKernel : public framework::OpKernel<T> {
} else { } else {
PADDLE_ENFORCE_EQ(match_indices[max_idx], -1); PADDLE_ENFORCE_EQ(match_indices[max_idx], -1);
match_indices[max_idx] = max_row_idx; match_indices[max_idx] = max_row_idx;
match_dis[max_idx] = max_dis; match_dist[max_idx] = max_dist;
// Erase the row index. // Erase the row index.
row_pool.erase( row_pool.erase(
std::find(row_pool.begin(), row_pool.end(), max_row_idx)); std::find(row_pool.begin(), row_pool.end(), max_row_idx));
...@@ -87,34 +90,38 @@ class BipartiteMatchKernel : public framework::OpKernel<T> { ...@@ -87,34 +90,38 @@ class BipartiteMatchKernel : public framework::OpKernel<T> {
} }
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* dis_mat = context.Input<LoDTensor>("DisMat"); auto* dist_mat = context.Input<LoDTensor>("DisMat");
auto* match_indices = context.Output<Tensor>("ColToRowMatchIndices"); auto* match_indices = context.Output<Tensor>("ColToRowMatchIndices");
auto* match_dis = context.Output<Tensor>("ColToRowMatchDis"); auto* match_dist = context.Output<Tensor>("ColToRowMatchDis");
auto& dev_ctx = context.device_context<platform::CPUDeviceContext>(); auto& dev_ctx = context.device_context<platform::CPUDeviceContext>();
auto col = dis_mat->dims()[1]; auto col = dist_mat->dims()[1];
int64_t n = dis_mat->lod().size() == 0 int64_t n = dist_mat->lod().size() == 0UL
? 1 ? 1
: static_cast<int64_t>(dis_mat->lod().back().size() - 1); : static_cast<int64_t>(dist_mat->lod().back().size() - 1);
if (dist_mat->lod().size()) {
PADDLE_ENFORCE_EQ(dist_mat->lod().size(), 1UL,
"Only support 1 level of LoD.");
}
match_indices->mutable_data<int>({n, col}, context.GetPlace()); match_indices->mutable_data<int>({n, col}, context.GetPlace());
match_dis->mutable_data<T>({n, col}, context.GetPlace()); match_dist->mutable_data<T>({n, col}, context.GetPlace());
math::SetConstant<platform::CPUDeviceContext, int> iset; math::SetConstant<platform::CPUDeviceContext, int> iset;
iset(dev_ctx, match_indices, static_cast<int>(-1)); iset(dev_ctx, match_indices, static_cast<int>(-1));
math::SetConstant<platform::CPUDeviceContext, T> tset; math::SetConstant<platform::CPUDeviceContext, T> tset;
tset(dev_ctx, match_dis, static_cast<T>(0)); tset(dev_ctx, match_dist, static_cast<T>(0));
int* indices = match_indices->data<int>(); int* indices = match_indices->data<int>();
T* dis = match_dis->data<T>(); T* dist = match_dist->data<T>();
if (n == 1) { if (n == 1) {
BipartiteMatch(*dis_mat, indices, dis); BipartiteMatch(*dist_mat, indices, dist);
} else { } else {
auto lod = dis_mat->lod().back(); auto lod = dist_mat->lod().back();
for (size_t i = 0; i < lod.size() - 1; ++i) { for (size_t i = 0; i < lod.size() - 1; ++i) {
Tensor one_ins = dis_mat->Slice(lod[i], lod[i + 1]); Tensor one_ins = dist_mat->Slice(lod[i], lod[i + 1]);
BipartiteMatch(one_ins, indices + i * col, dis + i * col); BipartiteMatch(one_ins, indices + i * col, dist + i * col);
} }
} }
} }
...@@ -131,7 +138,7 @@ class BipartiteMatchOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -131,7 +138,7 @@ class BipartiteMatchOpMaker : public framework::OpProtoAndCheckerMaker {
"represented by each row and each column. For example, assumed one " "represented by each row and each column. For example, assumed one "
"entity is A with shape [K], another entity is B with shape [M]. The " "entity is A with shape [K], another entity is B with shape [M]. The "
"DisMat[i][j] is the distance between A[i] and B[j]. The bigger " "DisMat[i][j] is the distance between A[i] and B[j]. The bigger "
"the distance is, the more similar the pairs are. Please note, " "the distance is, the better macthing the pairs are. Please note, "
"This tensor can contain LoD information to represent a batch of " "This tensor can contain LoD information to represent a batch of "
"inputs. One instance of this batch can contain different numbers of " "inputs. One instance of this batch can contain different numbers of "
"entities."); "entities.");
...@@ -140,20 +147,25 @@ class BipartiteMatchOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -140,20 +147,25 @@ class BipartiteMatchOpMaker : public framework::OpProtoAndCheckerMaker {
"N is the batch size. If ColToRowMatchIndices[i][j] is -1, it " "N is the batch size. If ColToRowMatchIndices[i][j] is -1, it "
"means B[j] does not match any entity in i-th instance. " "means B[j] does not match any entity in i-th instance. "
"Otherwise, it means B[j] is matched to row " "Otherwise, it means B[j] is matched to row "
"RowToColMatchIndices[i][j] in i-th instance. The row number of " "ColToRowMatchIndices[i][j] in i-th instance. The row number of "
"i-th instance is saved in RowToColMatchIndices[i][j]."); "i-th instance is saved in ColToRowMatchIndices[i][j].");
AddOutput("ColToRowMatchDis", AddOutput("ColToRowMatchDis",
"(Tensor) A 2-D Tensor with shape [N, M] in float type. " "(Tensor) A 2-D Tensor with shape [N, M] in float type. "
"N is batch size. If ColToRowMatchIndices[i][j] is -1, " "N is batch size. If ColToRowMatchIndices[i][j] is -1, "
"ColToRowMatchDis[i][j] is also -1.0. Otherwise, assumed " "ColToRowMatchDis[i][j] is also -1.0. Otherwise, assumed "
"RowToColMatchIndices[i][j] = d, and the row offsets of each " "ColToRowMatchIndices[i][j] = d, and the row offsets of each "
"instance are called LoD. Then " "instance are called LoD. Then "
"ColToRowMatchDis[i][j] = DisMat[d+LoD[i]][j]"); "ColToRowMatchDis[i][j] = DisMat[d+LoD[i]][j]");
AddComment(R"DOC( AddComment(R"DOC(
This operator is a greedy bipartite matching algorithm, which is used to This operator is a greedy bipartite matching algorithm, which is used to
obtain the matching with the (greedy) maximum distance based on the input obtain the matching with the maximum distance based on the input
distance matrix. There are two outputs to save matched indices and distance. distance matrix. For input 2D matrix, the bipartite matching algorithm can
And this operator only calculate matched indices from column to row. find the matched column for each row, also can find the matched row for
each column. And this operator only calculate matched indices from column
to row. For each instance, the number of matched indices is the number of
of columns of the input ditance matrix.
There are two outputs to save matched indices and distance.
A simple description, this algothrim matched the best (maximum distance) A simple description, this algothrim matched the best (maximum distance)
row entity to the column entity and the matched indices are not duplicated row entity to the column entity and the matched indices are not duplicated
in each row of ColToRowMatchIndices. If the column entity is not matched in each row of ColToRowMatchIndices. If the column entity is not matched
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册