提交 530df1b2 编写于 作者: D dangqingqing

Fix the naming.

上级 07908686
...@@ -28,11 +28,11 @@ class BipartiteMatchOp : public framework::OperatorWithKernel { ...@@ -28,11 +28,11 @@ class BipartiteMatchOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("DisMat"), PADDLE_ENFORCE(ctx->HasInput("DistMat"),
"Input(DisMat) of BipartiteMatch should not be null."); "Input(DistMat) of BipartiteMatch should not be null.");
auto dims = ctx->GetInputDim("DisMat"); auto dims = ctx->GetInputDim("DistMat");
PADDLE_ENFORCE_EQ(dims.size(), 2, "The rank of Input(DisMat) must be 2."); PADDLE_ENFORCE_EQ(dims.size(), 2, "The rank of Input(DistMat) must be 2.");
ctx->SetOutputDim("ColToRowMatchIndices", dims); ctx->SetOutputDim("ColToRowMatchIndices", dims);
ctx->SetOutputDim("ColToRowMatchDis", dims); ctx->SetOutputDim("ColToRowMatchDis", dims);
...@@ -90,7 +90,7 @@ class BipartiteMatchKernel : public framework::OpKernel<T> { ...@@ -90,7 +90,7 @@ class BipartiteMatchKernel : public framework::OpKernel<T> {
} }
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* dist_mat = context.Input<LoDTensor>("DisMat"); auto* dist_mat = context.Input<LoDTensor>("DistMat");
auto* match_indices = context.Output<Tensor>("ColToRowMatchIndices"); auto* match_indices = context.Output<Tensor>("ColToRowMatchIndices");
auto* match_dist = context.Output<Tensor>("ColToRowMatchDis"); auto* match_dist = context.Output<Tensor>("ColToRowMatchDis");
...@@ -132,12 +132,12 @@ class BipartiteMatchOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -132,12 +132,12 @@ class BipartiteMatchOpMaker : public framework::OpProtoAndCheckerMaker {
BipartiteMatchOpMaker(OpProto* proto, OpAttrChecker* op_checker) BipartiteMatchOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput( AddInput(
"DisMat", "DistMat",
"(LoDTensor or Tensor) this input is a 2-D LoDTensor with shape " "(LoDTensor or Tensor) this input is a 2-D LoDTensor with shape "
"[K, M]. It is pair-wise distance matrix between the entities " "[K, M]. It is pair-wise distance matrix between the entities "
"represented by each row and each column. For example, assumed one " "represented by each row and each column. For example, assumed one "
"entity is A with shape [K], another entity is B with shape [M]. The " "entity is A with shape [K], another entity is B with shape [M]. The "
"DisMat[i][j] is the distance between A[i] and B[j]. The bigger " "DistMat[i][j] is the distance between A[i] and B[j]. The bigger "
"the distance is, the better macthing the pairs are. Please note, " "the distance is, the better macthing the pairs are. Please note, "
"This tensor can contain LoD information to represent a batch of " "This tensor can contain LoD information to represent a batch of "
"inputs. One instance of this batch can contain different numbers of " "inputs. One instance of this batch can contain different numbers of "
...@@ -155,7 +155,7 @@ class BipartiteMatchOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -155,7 +155,7 @@ class BipartiteMatchOpMaker : public framework::OpProtoAndCheckerMaker {
"ColToRowMatchDis[i][j] is also -1.0. Otherwise, assumed " "ColToRowMatchDis[i][j] is also -1.0. Otherwise, assumed "
"ColToRowMatchIndices[i][j] = d, and the row offsets of each " "ColToRowMatchIndices[i][j] = d, and the row offsets of each "
"instance are called LoD. Then " "instance are called LoD. Then "
"ColToRowMatchDis[i][j] = DisMat[d+LoD[i]][j]"); "ColToRowMatchDis[i][j] = DistMat[d+LoD[i]][j]");
AddComment(R"DOC( AddComment(R"DOC(
This operator is a greedy bipartite matching algorithm, which is used to This operator is a greedy bipartite matching algorithm, which is used to
obtain the matching with the maximum distance based on the input obtain the matching with the maximum distance based on the input
...@@ -171,7 +171,7 @@ row entity to the column entity and the matched indices are not duplicated ...@@ -171,7 +171,7 @@ row entity to the column entity and the matched indices are not duplicated
in each row of ColToRowMatchIndices. If the column entity is not matched in each row of ColToRowMatchIndices. If the column entity is not matched
any row entity, set -1 in ColToRowMatchIndices. any row entity, set -1 in ColToRowMatchIndices.
Please note that the input DisMat can be LoDTensor (with LoD) or Tensor. Please note that the input DistMat can be LoDTensor (with LoD) or Tensor.
If LoDTensor with LoD, the height of ColToRowMatchIndices is batch size. If LoDTensor with LoD, the height of ColToRowMatchIndices is batch size.
If Tensor, the height of ColToRowMatchIndices is 1. If Tensor, the height of ColToRowMatchIndices is 1.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册