diff --git a/paddle/operators/margin_rank_loss_op.cc b/paddle/operators/margin_rank_loss_op.cc
index 3f94f73fe6ed456098a6d1bf77662158e2fac7f4..16c9b20a265833d98e0100dadb12ea2938ea0275 100644
--- a/paddle/operators/margin_rank_loss_op.cc
+++ b/paddle/operators/margin_rank_loss_op.cc
@@ -22,7 +22,7 @@ class MarginRankLossOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
  protected:
-  void InferShape(framework::InferShapeContextBase *ctx) const override {
+  void InferShape(framework::InferShapeContext *ctx) const override {
     // input check
     PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) shouldn't be null.");
     PADDLE_ENFORCE(ctx->HasInput("X1"), "Input(X1) shouldn't be null.");
@@ -47,11 +47,11 @@ class MarginRankLossOpMaker : public framework::OpProtoAndCheckerMaker {
                         framework::OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X1",
-             "(2-D tensor with shape [batch_size x 1]) In pairwise ranking, "
-             "X1 is the score for one item to be ranked.");
+             "(2-D tensor with shape [batch_size x 1]) The score for "
+             "one item X1 to be ranked, from the pairwise ranking model.");
     AddInput("X2",
-             "(2-D tensor with shape [batch_size x 1]) In pairwise ranking, "
-             "X2 is the score for another item to be ranked.");
+             "(2-D tensor with shape [batch_size x 1]) The score for "
+             "another item X2 to be ranked, from the pairwise ranking model.");
     AddInput("Label",
              "(2-D tensor with shape [batch_size x 1]) "
              "The label indicating X1 ranked higher than X2 or not, "
@@ -63,19 +63,25 @@ class MarginRankLossOpMaker : public framework::OpProtoAndCheckerMaker {
               "to indicate whether each element of Output(Out) is activated.")
         .AsIntermediate();
     AddOutput("Out",
-              "(2-D tensor with shape [batch_size x 1])"
+              "(2-D tensor with shape [batch_size x 1]) "
               "The output loss of MarginRankLoss operator.");
     AddComment(R"DOC(
-MarginRankLoss operator measures the loss given a pair of input {`X1`, `X2`}
-and the `Label` with attribute `margin`, where `Label = +1` indicating X1 is
-ranked higher than `X2`, otherwise `Label = -1`. The loss turns out
+MarginRankLoss operator measures the loss given a pair of training samples
+{`X1`, `X2`} and the `Label` with attribute `margin`, where `Label = +1`
+indicates `X1` is ranked higher than `X2`, otherwise `Label = -1`. The loss
+is calculated as
 
-loss(X1, X2, Label) = max(0, -Label * (X1 - X2) + margin)
+loss(X1, X2, Label) = max(0, -Label * (X1 - X2) + margin).
 
 The attribute `margin` involved here helps make the predictions more robust.
-Only when the difference between `X1` and `X2` is greater than `margin`, it is
-possible for these two items contribute to the final loss.
+Denote the item ranked higher as the positive sample and the other as the
+negative sample. If the scores of the two samples satisfy
+
+positive sample - negative sample < margin,
+
+the pair of samples will contribute to the loss, which will backpropagate and
+train the ranking model to enlarge the difference between the two scores.
 
 For batch input with size `batch_size`, `X1`, `X2` and `Label` all have the
 same shape [batch_size x 1].
 
@@ -89,7 +95,7 @@ class MarginRankLossGradOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
  protected:
-  void InferShape(framework::InferShapeContextBase *ctx) const override {
+  void InferShape(framework::InferShapeContext *ctx) const override {
     PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) shouldn't be null.");
     PADDLE_ENFORCE(ctx->HasInput("X1"), "Input(X1) shouldn't be null.");
     PADDLE_ENFORCE(ctx->HasInput("X2"), "Input(X2) shouldn't be null.");
diff --git a/paddle/operators/margin_rank_loss_op.h b/paddle/operators/margin_rank_loss_op.h
index ec00643ecd41e09c21dc781c66a42f7bc0dfb907..8d0830147ecc465909e8988e90125929829f6f34 100644
--- a/paddle/operators/margin_rank_loss_op.h
+++ b/paddle/operators/margin_rank_loss_op.h
@@ -35,7 +35,7 @@ struct Heaviside {
 };
 
 template <typename Place, typename T>
-class MarginRankLossKernel : public framework::OpKernel {
+class MarginRankLossKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const {
     auto* out_t = ctx.Output("Out");
@@ -63,7 +63,7 @@ class MarginRankLossKernel : public framework::OpKernel {
 };
 
 template <typename Place, typename T>
-class MarginRankLossGradKernel : public framework::OpKernel {
+class MarginRankLossGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const {
     auto* d_x1_t =
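
For reference, the loss formula described in the DOC string above can be sketched as a small standalone C++ program. This is an illustrative sketch only, not the operator's actual kernel; the MarginRankLoss function, MarginRankLossResult struct, and the sample values are made up for the example.

// Standalone sketch of loss(X1, X2, Label) = max(0, -Label * (X1 - X2) + margin).
// "activated" mirrors the Heaviside-style indicator: 1 if the raw value is positive,
// i.e. the pair contributes to the loss.
#include <algorithm>
#include <cstdio>
#include <vector>

struct MarginRankLossResult {
  std::vector<float> loss;       // per-sample loss, shape [batch_size x 1]
  std::vector<float> activated;  // 1.0 if the sample contributes to the loss
};

MarginRankLossResult MarginRankLoss(const std::vector<float>& x1,
                                    const std::vector<float>& x2,
                                    const std::vector<float>& label,
                                    float margin) {
  MarginRankLossResult res;
  res.loss.resize(x1.size());
  res.activated.resize(x1.size());
  for (size_t i = 0; i < x1.size(); ++i) {
    float raw = -label[i] * (x1[i] - x2[i]) + margin;
    res.loss[i] = std::max(0.0f, raw);
    res.activated[i] = raw > 0.0f ? 1.0f : 0.0f;
  }
  return res;
}

int main() {
  // Label = +1 means x1 should be ranked higher than x2.
  std::vector<float> x1 = {0.9f, 0.2f, 0.6f};
  std::vector<float> x2 = {0.3f, 0.8f, 0.55f};
  std::vector<float> label = {1.0f, 1.0f, 1.0f};
  MarginRankLossResult r = MarginRankLoss(x1, x2, label, /*margin=*/0.1f);
  for (size_t i = 0; i < r.loss.size(); ++i) {
    std::printf("loss[%zu] = %.2f, activated = %.0f\n", i, r.loss[i],
                r.activated[i]);
  }
  return 0;
}

The third sample illustrates the margin: its score difference (0.05) is positive but smaller than the margin (0.1), so it still contributes to the loss. Since the loss is max(0, -Label * (X1 - X2) + margin), its gradient with respect to X1 is -Label for activated samples and zero otherwise, which appears to be what the intermediate `Activated` output is kept for in the backward pass.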