diff --git a/paddle/operators/margin_rank_loss_op.cc b/paddle/operators/margin_rank_loss_op.cc index 47faaf71630dd98931c7674032b77b156268dca2..8d62dbb4c67c9a4ba83cfcd8180f19decee8e16e 100644 --- a/paddle/operators/margin_rank_loss_op.cc +++ b/paddle/operators/margin_rank_loss_op.cc @@ -25,47 +25,67 @@ class MarginRankLossOp : public framework::OperatorWithKernel { void InferShape(const framework::InferShapeContext &ctx) const override { // input check PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"), - "Input(Label) shouldn't be null"); - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X1"), "Input(X1) shouldn't be null"); - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X2"), "Input(X2) shouldn't be null"); + "Input(Label) shouldn't be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X1"), "Input(X1) shouldn't be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X2"), "Input(X2) shouldn't be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), + "Output(Out) shouldn't be null."); auto label_dims = ctx.Input("Label")->dims(); auto x1_dims = ctx.Input("X1")->dims(); auto x2_dims = ctx.Input("X2")->dims(); PADDLE_ENFORCE((label_dims == x1_dims) && (x1_dims == x2_dims) && (label_dims.size() == 2) && (label_dims[1] == 1), - "All inputs must be vector with the same size"); - ctx.Output("Activated")->Resize(label_dims); - ctx.Output("Out")->Resize(label_dims); + "All inputs must be vector with the same size."); + auto act_t = ctx.Output("Activated"); + auto out_t = ctx.Output("Out"); + if (act_t) { + act_t->Resize(label_dims); + } + if (out_t) { + out_t->Resize(label_dims); + } } }; -template <typename AttrType> +template <typename T> class MarginRankLossOpMaker : public framework::OpProtoAndCheckerMaker { public: MarginRankLossOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("X1", "The first variable to be ranked, row vector."); - AddInput("X2", "The second variable to be ranked, row vector."); + AddInput("X1", + "(2-D tensor with shape 
[batch_size x 1]) In pairwise ranking, " + "X1 is the score for one item to be ranked."); + AddInput("X2", + "(2-D tensor with shape [batch_size x 1]) In pairwise ranking, " + "X2 is the score for another item to be ranked."); AddInput("Label", - "The label indicating X1 ranked higher than X2 " - "or not, row vector."); - AddAttr<AttrType>("margin", "Margin for MarginRankLossOp, scalar.") - .SetDefault(0); + "(2-D tensor with shape [batch_size x 1]) " + "The label indicating X1 ranked higher than X2 or not, " + "can only be +1 or -1."); + AddAttr<T>("margin", "(scalar, default 0) Margin for MarginRankLossOp.") + .SetDefault(static_cast<T>(0)); AddOutput("Activated", - "Intermediate tensor to indicate whether each element of " - "Output(Out) is activated.") + "(2-D tensor with shape [batch_size x 1]) Intermediate tensor " + "to indicate whether each element of Output(Out) is activated.") .AsIntermediate(); - AddOutput("Out", "The output loss of MarginRankLoss operator"); + AddOutput("Out", + "(2-D tensor with shape [batch_size x 1]) " + "The output loss of MarginRankLoss operator."); AddComment(R"DOC( MarginRankLoss operator measures the loss given a pair of input {`X1`, `X2`} -and the `Label` with attribute `margin`, where `Label = 1` indicating X1 is +and the `Label` with attribute `margin`, where `Label = +1` indicating X1 is ranked higher than `X2`, otherwise `Label = -1`. The loss turns out loss(X1, X2, Label) = max(0, -Label * (X1 - X2) + margin) -For batch input, `X1`, `X2` and `Label` all have the same size batch_size x 1. +The attribute `margin` involved here helps make the predictions more robust. +Only when the difference between `X1` and `X2` is greater than `margin`, it is +possible for these two items to contribute to the final loss. + +For batch input with size `batch_size`, `X1`, `X2` and `Label` +all have the same shape [batch_size x 1].
)DOC"); } diff --git a/paddle/operators/margin_rank_loss_op.h b/paddle/operators/margin_rank_loss_op.h index 3d63343a613fe8dea3abe09fc2440743f4b9ac75..ec00643ecd41e09c21dc781c66a42f7bc0dfb907 100644 --- a/paddle/operators/margin_rank_loss_op.h +++ b/paddle/operators/margin_rank_loss_op.h @@ -23,26 +23,18 @@ namespace operators { template <typename T> struct ReLU { HOSTDEVICE T operator()(const T& val) const { - if (val < 0) { - return static_cast<T>(0); - } else { - return val; - } + return val > 0 ? val : static_cast<T>(0); } }; template <typename T> struct Heaviside { HOSTDEVICE T operator()(const T& val) const { - if (val > 0) { - return static_cast<T>(1); - } else { - return static_cast<T>(0); - } + return static_cast<T>(val > 0 ? 1 : 0); } }; -template <typename Place, typename T, typename AttrType> +template <typename Place, typename T> class MarginRankLossKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const { @@ -56,7 +48,7 @@ class MarginRankLossKernel : public framework::OpKernel { out_t->mutable_data<T>(ctx.GetPlace()); act_t->mutable_data<T>(ctx.GetPlace()); - auto margin = static_cast<T>(ctx.Attr<AttrType>("margin")); + auto margin = static_cast<T>(ctx.Attr<T>("margin")); auto out = framework::EigenVector<T>::Flatten(*out_t); auto act = framework::EigenVector<T>::Flatten(*act_t);