提交 cf4b2db7 编写于 作者: Y Yibing Liu

change the dims of input of rank_loss_op

上级 ece32910
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
...@@ -37,11 +36,10 @@ class RankLossOp : public framework::OperatorWithKernel { ...@@ -37,11 +36,10 @@ class RankLossOp : public framework::OperatorWithKernel {
auto label_dims = ctx.Input<framework::Tensor>("Label")->dims(); auto label_dims = ctx.Input<framework::Tensor>("Label")->dims();
auto left_dims = ctx.Input<framework::Tensor>("Left")->dims(); auto left_dims = ctx.Input<framework::Tensor>("Left")->dims();
auto right_dims = ctx.Input<framework::Tensor>("Right")->dims(); auto right_dims = ctx.Input<framework::Tensor>("Right")->dims();
PADDLE_ENFORCE((label_dims.size() == 1) && (left_dims.size() == 1) &&
(right_dims.size() == 1),
"The rank of all inputs must be 1.");
PADDLE_ENFORCE((label_dims == left_dims) && (left_dims == right_dims), PADDLE_ENFORCE((label_dims == left_dims) && (left_dims == right_dims),
"All inputs must have the same size"); "All inputs must have the same size");
PADDLE_ENFORCE((label_dims.size() == 2) && (label_dims[1] == 1),
"All inputs must be row vector with size batch_sizex1.");
ctx.Output<framework::LoDTensor>("Out")->Resize(label_dims); ctx.Output<framework::LoDTensor>("Out")->Resize(label_dims);
} }
}; };
...@@ -52,10 +50,10 @@ class RankLossOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -52,10 +50,10 @@ class RankLossOpMaker : public framework::OpProtoAndCheckerMaker {
framework::OpAttrChecker *op_checker) framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Label", AddInput("Label",
"The label indicating A ranked higher than B or not, 1-D tensor."); "The label indicating A ranked higher than B or not, row vector.");
AddInput("Left", "The output of RankNet for doc A, 1-D tensor."); AddInput("Left", "The output of RankNet for doc A, vector.");
AddInput("Right", "The output of RankNet for doc B, 1-D tensor"); AddInput("Right", "The output of RankNet for doc B, vetor");
AddOutput("Out", "The output loss of RankLoss operator, 1-D tensor."); AddOutput("Out", "The output loss of RankLoss operator, vector.");
AddComment(R"DOC(RankLoss operator AddComment(R"DOC(RankLoss operator
Rank loss operator for RankNet[1]. RankNet is a pairwise ranking model with Rank loss operator for RankNet[1]. RankNet is a pairwise ranking model with
......
...@@ -8,9 +8,9 @@ class TestRankLossOp(OpTest): ...@@ -8,9 +8,9 @@ class TestRankLossOp(OpTest):
self.op_type = "rank_loss" self.op_type = "rank_loss"
batch_size = 5 batch_size = 5
# labels_{i} = {0, 1.0} or {0, 0.5, 1.0} # labels_{i} = {0, 1.0} or {0, 0.5, 1.0}
label = np.random.randint(0, 2, size=(batch_size, )).astype("float32") label = np.random.randint(0, 2, size=(batch_size, 1)).astype("float32")
left = np.random.random((batch_size, )).astype("float32") left = np.random.random((batch_size, 1)).astype("float32")
right = np.random.random((batch_size, )).astype("float32") right = np.random.random((batch_size, 1)).astype("float32")
loss = np.log(1.0 + np.exp(left - right)) - label * (left - right) loss = np.log(1.0 + np.exp(left - right)) - label * (left - right)
self.inputs = {'Label': label, 'Left': left, 'Right': right} self.inputs = {'Label': label, 'Left': left, 'Right': right}
self.outputs = {'Out': loss} self.outputs = {'Out': loss}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册