diff --git a/paddle/operators/bipartite_match_op.cc b/paddle/operators/bipartite_match_op.cc index c2d30c7d926b66703b5a4d8e4ea0a91d6f77c1c3..0fcff6e26d7e9a84ee326267181e849d9eac3e9c 100644 --- a/paddle/operators/bipartite_match_op.cc +++ b/paddle/operators/bipartite_match_op.cc @@ -28,11 +28,11 @@ class BipartiteMatchOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("DisMat"), - "Input(DisMat) of BipartiteMatch should not be null."); + PADDLE_ENFORCE(ctx->HasInput("DistMat"), + "Input(DistMat) of BipartiteMatch should not be null."); - auto dims = ctx->GetInputDim("DisMat"); - PADDLE_ENFORCE_EQ(dims.size(), 2, "The rank of Input(DisMat) must be 2."); + auto dims = ctx->GetInputDim("DistMat"); + PADDLE_ENFORCE_EQ(dims.size(), 2, "The rank of Input(DistMat) must be 2."); ctx->SetOutputDim("ColToRowMatchIndices", dims); ctx->SetOutputDim("ColToRowMatchDis", dims); @@ -90,7 +90,7 @@ class BipartiteMatchKernel : public framework::OpKernel { } void Compute(const framework::ExecutionContext& context) const override { - auto* dist_mat = context.Input("DisMat"); + auto* dist_mat = context.Input("DistMat"); auto* match_indices = context.Output("ColToRowMatchIndices"); auto* match_dist = context.Output("ColToRowMatchDis"); @@ -132,12 +132,12 @@ class BipartiteMatchOpMaker : public framework::OpProtoAndCheckerMaker { BipartiteMatchOpMaker(OpProto* proto, OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput( - "DisMat", + "DistMat", "(LoDTensor or Tensor) this input is a 2-D LoDTensor with shape " "[K, M]. It is pair-wise distance matrix between the entities " "represented by each row and each column. For example, assumed one " "entity is A with shape [K], another entity is B with shape [M]. The " - "DisMat[i][j] is the distance between A[i] and B[j]. The bigger " + "DistMat[i][j] is the distance between A[i] and B[j]. The bigger " "the distance is, the better macthing the pairs are. Please note, " "This tensor can contain LoD information to represent a batch of " "inputs. One instance of this batch can contain different numbers of " @@ -155,7 +155,7 @@ class BipartiteMatchOpMaker : public framework::OpProtoAndCheckerMaker { "ColToRowMatchDis[i][j] is also -1.0. Otherwise, assumed " "ColToRowMatchIndices[i][j] = d, and the row offsets of each " "instance are called LoD. Then " - "ColToRowMatchDis[i][j] = DisMat[d+LoD[i]][j]"); + "ColToRowMatchDis[i][j] = DistMat[d+LoD[i]][j]"); AddComment(R"DOC( This operator is a greedy bipartite matching algorithm, which is used to obtain the matching with the maximum distance based on the input @@ -171,7 +171,7 @@ row entity to the column entity and the matched indices are not duplicated in each row of ColToRowMatchIndices. If the column entity is not matched any row entity, set -1 in ColToRowMatchIndices. -Please note that the input DisMat can be LoDTensor (with LoD) or Tensor. +Please note that the input DistMat can be LoDTensor (with LoD) or Tensor. If LoDTensor with LoD, the height of ColToRowMatchIndices is batch size. If Tensor, the height of ColToRowMatchIndices is 1.