提交 bc2e26ee 编写于 作者: Y Yibing Liu

refine comments and clean code in margin_rank_loss_op

上级 dc186af7
...@@ -25,47 +25,67 @@ class MarginRankLossOp : public framework::OperatorWithKernel { ...@@ -25,47 +25,67 @@ class MarginRankLossOp : public framework::OperatorWithKernel {
void InferShape(const framework::InferShapeContext &ctx) const override {
  // Validate that all inputs and the primary output are wired up.
  PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"),
                          "Input(Label) shouldn't be null.");
  PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X1"), "Input(X1) shouldn't be null.");
  PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X2"), "Input(X2) shouldn't be null.");
  // Fix: message previously read "Output(X2)" although the variable
  // being checked is "Out".
  PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"),
                          "Output(Out) shouldn't be null.");
  auto label_dims = ctx.Input<framework::Tensor>("Label")->dims();
  auto x1_dims = ctx.Input<framework::Tensor>("X1")->dims();
  auto x2_dims = ctx.Input<framework::Tensor>("X2")->dims();
  // All three inputs must share the same column-vector shape
  // [batch_size x 1].
  PADDLE_ENFORCE((label_dims == x1_dims) && (x1_dims == x2_dims) &&
                     (label_dims.size() == 2) && (label_dims[1] == 1),
                 "All inputs must be vector with the same size.");
  // Outputs may be absent in some shape-inference passes, so only
  // resize the ones that are actually present.
  auto act_t = ctx.Output<framework::LoDTensor>("Activated");
  auto out_t = ctx.Output<framework::LoDTensor>("Out");
  if (act_t) {
    act_t->Resize(label_dims);
  }
  if (out_t) {
    out_t->Resize(label_dims);
  }
}
}; };
template <typename AttrType> template <typename T>
class MarginRankLossOpMaker : public framework::OpProtoAndCheckerMaker { class MarginRankLossOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
// Declares the inputs, outputs, attributes, and documentation of the
// margin_rank_loss operator for the op registry.
MarginRankLossOpMaker(framework::OpProto *proto,
                      framework::OpAttrChecker *op_checker)
    : OpProtoAndCheckerMaker(proto, op_checker) {
  AddInput("X1",
           "(2-D tensor with shape [batch_size x 1]) In pairwise ranking, "
           "X1 is the score for one item to be ranked.");
  AddInput("X2",
           "(2-D tensor with shape [batch_size x 1]) In pairwise ranking, "
           "X2 is the score for another item to be ranked.");
  AddInput("Label",
           "(2-D tensor with shape [batch_size x 1]) "
           "The label indicating X1 ranked higher than X2 or not, "
           "can only be +1 or -1.");
  AddAttr<T>("margin", "(scalar, default 0) Margin for MarginRankLossOp.")
      .SetDefault(static_cast<T>(0));
  AddOutput("Activated",
            "(2-D tensor with shape [batch_size x 1]) Intermediate tensor "
            "to indicate whether each element of Output(Out) is activated.")
      .AsIntermediate();
  // Fix: the two adjacent literals previously concatenated without a
  // separating space ("...x 1])The output...") and lacked a period.
  AddOutput("Out",
            "(2-D tensor with shape [batch_size x 1]) "
            "The output loss of MarginRankLoss operator.");
  AddComment(R"DOC(

MarginRankLoss operator measures the loss given a pair of input {`X1`, `X2`}
and the `Label` with attribute `margin`, where `Label = +1` indicating X1 is
ranked higher than `X2`, otherwise `Label = -1`. The loss turns out

loss(X1, X2, Label) = max(0, -Label * (X1 - X2) + margin)

The attribute `margin` involved here helps make the predictions more robust.
Only when the difference between `X1` and `X2` is greater than `margin`, it is
possible for these two items contribute to the final loss.

For batch input with size `batch_size`, `X1`, `X2` and `Label`
all have the same shape [batch_size x 1].

)DOC");
}
......
...@@ -23,26 +23,18 @@ namespace operators { ...@@ -23,26 +23,18 @@ namespace operators {
template <typename T> template <typename T>
struct ReLU { struct ReLU {
HOSTDEVICE T operator()(const T& val) const { HOSTDEVICE T operator()(const T& val) const {
if (val < 0) { return val > 0 ? val : static_cast<T>(0);
return static_cast<T>(0);
} else {
return val;
}
} }
}; };
template <typename T> template <typename T>
struct Heaviside { struct Heaviside {
HOSTDEVICE T operator()(const T& val) const { HOSTDEVICE T operator()(const T& val) const {
if (val > 0) { return static_cast<T>(val > 0 ? 1 : 0);
return static_cast<T>(1);
} else {
return static_cast<T>(0);
}
} }
}; };
template <typename Place, typename T, typename AttrType = T> template <typename Place, typename T>
class MarginRankLossKernel : public framework::OpKernel { class MarginRankLossKernel : public framework::OpKernel {
public: public:
void Compute(const framework::ExecutionContext& ctx) const { void Compute(const framework::ExecutionContext& ctx) const {
...@@ -56,7 +48,7 @@ class MarginRankLossKernel : public framework::OpKernel { ...@@ -56,7 +48,7 @@ class MarginRankLossKernel : public framework::OpKernel {
out_t->mutable_data<T>(ctx.GetPlace()); out_t->mutable_data<T>(ctx.GetPlace());
act_t->mutable_data<T>(ctx.GetPlace()); act_t->mutable_data<T>(ctx.GetPlace());
auto margin = static_cast<T>(ctx.Attr<AttrType>("margin")); auto margin = static_cast<T>(ctx.Attr<T>("margin"));
auto out = framework::EigenVector<T>::Flatten(*out_t); auto out = framework::EigenVector<T>::Flatten(*out_t);
auto act = framework::EigenVector<T>::Flatten(*act_t); auto act = framework::EigenVector<T>::Flatten(*act_t);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册