未验证 提交 61d98f27 编写于 作者: Y Yibing Liu 提交者: GitHub

Merge pull request #6058 from kuke/refine_rank_loss_op

Revise comments in rank_loss_op
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
...@@ -35,9 +35,10 @@ class RankLossOp : public framework::OperatorWithKernel { ...@@ -35,9 +35,10 @@ class RankLossOp : public framework::OperatorWithKernel {
auto right_dims = ctx->GetInputDim("Right"); auto right_dims = ctx->GetInputDim("Right");
PADDLE_ENFORCE((label_dims == left_dims) && (left_dims == right_dims), PADDLE_ENFORCE((label_dims == left_dims) && (left_dims == right_dims),
"All inputs must have the same size"); "All inputs must have the same size.");
PADDLE_ENFORCE((label_dims.size() == 2) && (label_dims[1] == 1), PADDLE_ENFORCE(
"All inputs must be row vector with size batch_size x 1."); (label_dims.size() == 2) && (label_dims[1] == 1),
"All inputs must be 2-D tensors with shape [batch_size x 1].");
ctx->SetOutputDim("Out", label_dims); ctx->SetOutputDim("Out", label_dims);
} }
}; };
...@@ -48,10 +49,17 @@ class RankLossOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -48,10 +49,17 @@ class RankLossOpMaker : public framework::OpProtoAndCheckerMaker {
framework::OpAttrChecker *op_checker) framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Label", AddInput("Label",
"The label indicating A ranked higher than B or not, row vector."); "(2-D Tensor with shape [batch_size x 1]) "
AddInput("Left", "The output of RankNet for doc A, vector."); "The label indicating A ranked higher than B or not.");
AddInput("Right", "The output of RankNet for doc B, vetor."); AddInput("Left",
AddOutput("Out", "The output loss of RankLoss operator, vector."); "(2-D Tensor with shape [batch_size x 1]) "
"The output of RankNet for doc A.");
AddInput("Right",
"(2-D Tensor with shape [batch_size x 1]) "
"The output of RankNet for doc B.");
AddOutput("Out",
"(2-D Tensor with shape [batch_size x 1]) "
"The output loss of RankLoss operator.");
AddComment(R"DOC( AddComment(R"DOC(
RankLoss Operator. RankLoss Operator.
...@@ -65,16 +73,17 @@ P = {0, 1} or {0, 0.5, 1}, where 0.5 means no information about the rank of ...@@ -65,16 +73,17 @@ P = {0, 1} or {0, 0.5, 1}, where 0.5 means no information about the rank of
the input pair. the input pair.
The RankLoss operator takes three inputs: Left (o_i), Right (o_j) and Label The RankLoss operator takes three inputs: Left (o_i), Right (o_j) and Label
(P_{i,j}), which represent the output of RankNet for the two docs and the label, (P_{i,j}), which represent the output score of RankNet for the two docs and
respectively, and yields the rank loss C_{i,j} using the following equation: the label respectively, and yields the rank loss C_{i,j} using the following
equation:
\f$$ $$
C_{i,j} = -\tilde{P_{ij}} * o_{i,j} + log(1 + e^{o_{i,j}}) \\ C_{i,j} = -\tilde{P_{ij}} * o_{i,j} + \log(1 + e^{o_{i,j}}) \\
o_{i,j} = o_i - o_j \\ o_{i,j} = o_i - o_j \\
\tilde{P_{i,j}} = \left \{0, 0.5, 1 \right \} \ or \ \left \{0, 1 \right \} \tilde{P_{i,j}} = \left \{0, 0.5, 1 \right \} \ or \ \left \{0, 1 \right \}
\f$$ $$
The operator can take inputs of one sample or in batch. The operator can take batch inputs with size batch_size (batch_size >= 1).
)DOC"); )DOC");
} }
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册