提交 ff8a6778 编写于 作者: Y Yibing Liu

Revise comments in rank_loss_op

上级 5a3d1362
...@@ -35,9 +35,10 @@ class RankLossOp : public framework::OperatorWithKernel { ...@@ -35,9 +35,10 @@ class RankLossOp : public framework::OperatorWithKernel {
auto right_dims = ctx->GetInputDim("Right"); auto right_dims = ctx->GetInputDim("Right");
PADDLE_ENFORCE((label_dims == left_dims) && (left_dims == right_dims), PADDLE_ENFORCE((label_dims == left_dims) && (left_dims == right_dims),
"All inputs must have the same size"); "All inputs must have the same size.");
PADDLE_ENFORCE((label_dims.size() == 2) && (label_dims[1] == 1), PADDLE_ENFORCE(
"All inputs must be row vector with size batch_size x 1."); (label_dims.size() == 2) && (label_dims[1] == 1),
"All inputs must be 2-D tensors with shape [batch_size x 1].");
ctx->SetOutputDim("Out", label_dims); ctx->SetOutputDim("Out", label_dims);
} }
}; };
...@@ -48,10 +49,17 @@ class RankLossOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -48,10 +49,17 @@ class RankLossOpMaker : public framework::OpProtoAndCheckerMaker {
framework::OpAttrChecker *op_checker) framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Label", AddInput("Label",
"The label indicating A ranked higher than B or not, row vector."); "(2-D Tensor with shape [batch_size x 1]) "
AddInput("Left", "The output of RankNet for doc A, vector."); "The label indicating A ranked higher than B or not.");
AddInput("Right", "The output of RankNet for doc B, vetor."); AddInput("Left",
AddOutput("Out", "The output loss of RankLoss operator, vector."); "(2-D Tensor with shape [batch_size x 1]) "
"The output of RankNet for doc A.");
AddInput("Right",
"(2-D Tensor with shape [batch_size x 1]) "
"The output of RankNet for doc B.");
AddOutput("Out",
"(2-D Tensor with shape [batch_size x 1]) "
"The output loss of RankLoss operator.");
AddComment(R"DOC( AddComment(R"DOC(
RankLoss Operator. RankLoss Operator.
...@@ -65,8 +73,9 @@ P = {0, 1} or {0, 0.5, 1}, where 0.5 means no information about the rank of ...@@ -65,8 +73,9 @@ P = {0, 1} or {0, 0.5, 1}, where 0.5 means no information about the rank of
the input pair. the input pair.
The RankLoss operator takes three inputs: Left (o_i), Right (o_j) and Label The RankLoss operator takes three inputs: Left (o_i), Right (o_j) and Label
(P_{i,j}), which represent the output of RankNet for the two docs and the label, (P_{i,j}), which represent the output score of RankNet for the two docs and
respectively, and yields the rank loss C_{i,j} using the following equation: the label respectively, and yields the rank loss C_{i,j} using the following
equation:
\f$$ \f$$
C_{i,j} = -\tilde{P_{ij}} * o_{i,j} + log(1 + e^{o_{i,j}}) \\ C_{i,j} = -\tilde{P_{ij}} * o_{i,j} + log(1 + e^{o_{i,j}}) \\
...@@ -74,7 +83,7 @@ respectively, and yields the rank loss C_{i,j} using the following equation: ...@@ -74,7 +83,7 @@ respectively, and yields the rank loss C_{i,j} using the following equation:
\tilde{P_{i,j}} = \left \{0, 0.5, 1 \right \} \ or \ \left \{0, 1 \right \} \tilde{P_{i,j}} = \left \{0, 0.5, 1 \right \} \ or \ \left \{0, 1 \right \}
\f$$ \f$$
The operator can take inputs of one sample or in batch. The operator can take batch inputs with size batch_size (batch_size >= 1).
)DOC"); )DOC");
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册