From bc2e26ee1b05b6be442cdcd014a1fdaa3b611ec9 Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Thu, 28 Sep 2017 12:17:48 +0800 Subject: [PATCH] refine comments and clean code in marigin_rank_loss_op --- paddle/operators/margin_rank_loss_op.cc | 56 +++++++++++++++++-------- paddle/operators/margin_rank_loss_op.h | 16 ++----- 2 files changed, 42 insertions(+), 30 deletions(-) diff --git a/paddle/operators/margin_rank_loss_op.cc b/paddle/operators/margin_rank_loss_op.cc index 47faaf7163..8d62dbb4c6 100644 --- a/paddle/operators/margin_rank_loss_op.cc +++ b/paddle/operators/margin_rank_loss_op.cc @@ -25,47 +25,67 @@ class MarginRankLossOp : public framework::OperatorWithKernel { void InferShape(const framework::InferShapeContext &ctx) const override { // input check PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"), - "Input(Label) shouldn't be null"); - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X1"), "Input(X1) shouldn't be null"); - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X2"), "Input(X2) shouldn't be null"); + "Input(Label) shouldn't be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X1"), "Input(X1) shouldn't be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X2"), "Input(X2) shouldn't be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), + "Output(X2) shouldn't be null."); auto label_dims = ctx.Input("Label")->dims(); auto x1_dims = ctx.Input("X1")->dims(); auto x2_dims = ctx.Input("X2")->dims(); PADDLE_ENFORCE((label_dims == x1_dims) && (x1_dims == x2_dims) && (label_dims.size() == 2) && (label_dims[1] == 1), - "All inputs must be vector with the same size"); - ctx.Output("Activated")->Resize(label_dims); - ctx.Output("Out")->Resize(label_dims); + "All inputs must be vector with the same size."); + auto act_t = ctx.Output("Activated"); + auto out_t = ctx.Output("Out"); + if (act_t) { + act_t->Resize(label_dims); + } + if (out_t) { + out_t->Resize(label_dims); + } } }; -template +template class MarginRankLossOpMaker : public framework::OpProtoAndCheckerMaker { public: MarginRankLossOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("X1", "The first variable to be ranked, row vector."); - AddInput("X2", "The second variable to be ranked, row vector."); + AddInput("X1", + "(2-D tensor with shape [batch_size x 1]) In pairwise ranking, " + "X1 is the score for one item to be ranked."); + AddInput("X2", + "(2-D tensor with shape [batch_size x 1]) In pairwise ranking, " + "X2 is the score for another item to be ranked."); AddInput("Label", - "The label indicating X1 ranked higher than X2 " - "or not, row vector."); - AddAttr("margin", "Margin for MarginRankLossOp, scalar.") - .SetDefault(0); + "(2-D tensor with shape [batch_size x 1]) " + "The label indicating X1 ranked higher than X2 or not, " + "can only be +1 or -1."); + AddAttr("margin", "(scalar, default 0) Margin for MarginRankLossOp.") + .SetDefault(static_cast(0)); AddOutput("Activated", - "Intermediate tensor to indicate whether each element of " - "Output(Out) is activated.") + "(2-D tensor with shape [batch_size x 1]) Intermediate tensor " + "to indicate whether each element of Output(Out) is activated.") .AsIntermediate(); - AddOutput("Out", "The output loss of MarginRankLoss operator"); + AddOutput("Out", + "(2-D tensor with shape [batch_size x 1])" + "The output loss of MarginRankLoss operator"); AddComment(R"DOC( MarginRankLoss operator measures the loss given a pair of input {`X1`, `X2`} -and the `Label` with attribute `margin`, where `Label = 1` indicating X1 is +and the `Label` with attribute `margin`, where `Label = +1` indicating X1 is ranked higher than `X2`, otherwise `Label = -1`. The loss turns out loss(X1, X2, Label) = max(0, -Label * (X1 - X2) + margin) -For batch input, `X1`, `X2` and `Label` all have the same size batch_size x 1. +The attribute `margin` involved here helps make the predictions more robust. +Only when the difference between `X1` and `X2` is greater than `margin`, it is +possible for these two items contribute to the final loss. + +For batch input with size `batch_size`, `X1`, `X2` and `Label` +all have the same shape [batch_size x 1]. )DOC"); } diff --git a/paddle/operators/margin_rank_loss_op.h b/paddle/operators/margin_rank_loss_op.h index 3d63343a61..ec00643ecd 100644 --- a/paddle/operators/margin_rank_loss_op.h +++ b/paddle/operators/margin_rank_loss_op.h @@ -23,26 +23,18 @@ namespace operators { template struct ReLU { HOSTDEVICE T operator()(const T& val) const { - if (val < 0) { - return static_cast(0); - } else { - return val; - } + return val > 0 ? val : static_cast(0); } }; template struct Heaviside { HOSTDEVICE T operator()(const T& val) const { - if (val > 0) { - return static_cast(1); - } else { - return static_cast(0); - } + return static_cast(val > 0 ? 1 : 0); } }; -template +template class MarginRankLossKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const { @@ -56,7 +48,7 @@ class MarginRankLossKernel : public framework::OpKernel { out_t->mutable_data(ctx.GetPlace()); act_t->mutable_data(ctx.GetPlace()); - auto margin = static_cast(ctx.Attr("margin")); + auto margin = static_cast(ctx.Attr("margin")); auto out = framework::EigenVector::Flatten(*out_t); auto act = framework::EigenVector::Flatten(*act_t); -- GitLab