Unverified commit 137e6336, authored by Aurelius84, committed by GitHub

Remove constraint that last dimension is forced to be 1 in rank_loss (#19997)

* fix input shape check test=develop

* move PADDLE_ENFORCE test=develop
Parent 101a2b61
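For context before the diff: rank_loss implements the pairwise RankNet loss. Given a pair of scores Left and Right and a Label in {0, 0.5, 1.0} that says which item should rank higher, the loss is log(1 + exp(Left - Right)) - Label * (Left - Right). This patch removes the requirement that all three inputs be 2-D [batch_size, 1] tensors; each may now independently be 1-D [batch_size] or 2-D [batch_size, 1]. A minimal NumPy sketch of the forward computation (illustrative only; the helper name is ours, not part of the operator):

import numpy as np

def rank_loss_np(label, left, right):
    # Flatten so [N] and [N, 1] inputs are treated alike, mirroring the relaxed check.
    label, left, right = (np.ravel(x) for x in (label, left, right))
    # RankNet pairwise loss, same formula as the unit test below.
    return np.log(1.0 + np.exp(left - right)) - label * (left - right)

label = np.random.randint(0, 2, size=5).astype("float32")    # shape (5,)
left = np.random.random((5, 1)).astype("float32")            # shape (5, 1)
right = np.random.random(5).astype("float32")                # shape (5,)
print(rank_loss_np(label, left, right))                      # mixed ranks now allowed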
@@ -27,20 +27,55 @@ class RankLossOp : public framework::OperatorWithKernel {
       : OperatorWithKernel(type, inputs, outputs, attrs) {}
 
   void InferShape(framework::InferShapeContext *ctx) const override {
     // input check
-    PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) shouldn't be null.");
-    PADDLE_ENFORCE(ctx->HasInput("Left"), "Input(Left) shouldn't be null.");
-    PADDLE_ENFORCE(ctx->HasInput("Right"), "Input(Right) shouldn't be null.");
+    PADDLE_ENFORCE_EQ(ctx->HasInput("Label"), true,
+                      "Input(Label) shouldn't be null.");
+    PADDLE_ENFORCE_EQ(ctx->HasInput("Left"), true,
+                      "Input(Left) shouldn't be null.");
+    PADDLE_ENFORCE_EQ(ctx->HasInput("Right"), true,
+                      "Input(Right) shouldn't be null.");
     auto label_dims = ctx->GetInputDim("Label");
     auto left_dims = ctx->GetInputDim("Left");
     auto right_dims = ctx->GetInputDim("Right");
-    PADDLE_ENFORCE((label_dims == left_dims) && (left_dims == right_dims),
-                   "All inputs must have the same size.");
-    PADDLE_ENFORCE(
-        (label_dims.size() == 2) && (label_dims[1] == 1),
-        "All inputs must be 2-D tensors with shape [batch_size x 1].");
+    // check label_dims valid
+    PADDLE_ENFORCE_GE(label_dims.size(), 1,
+                      "The dimension size of Input(Label) must be greater than "
+                      "or equal to 1.");
+    PADDLE_ENFORCE_LE(
+        label_dims.size(), 2,
+        "The dimension size of Input(Label) must be less than or equal to 2.");
+    if (label_dims.size() == 2U) {
+      PADDLE_ENFORCE_EQ(label_dims[1], 1,
+                        "The last dimension of Input(Label) must be 1.");
+    }
+    // check left_dims valid
+    PADDLE_ENFORCE_GE(left_dims.size(), 1,
+                      "The dimension size of Input(Left) must be greater than "
+                      "or equal to 1.");
+    PADDLE_ENFORCE_LE(
+        left_dims.size(), 2,
+        "The dimension size of Input(Left) must be less than or equal to 2.");
+    if (left_dims.size() == 2U) {
+      PADDLE_ENFORCE_EQ(left_dims[1], 1,
+                        "The last dimension of Input(Left) must be 1.");
+    }
+    // check right_dims valid
+    PADDLE_ENFORCE_GE(right_dims.size(), 1,
+                      "The dimension size of Input(Right) must be greater than "
+                      "or equal to 1.");
+    PADDLE_ENFORCE_LE(
+        right_dims.size(), 2,
+        "The dimension size of Input(Right) must be less than or equal to 2.");
+    if (right_dims.size() == 2U) {
+      PADDLE_ENFORCE_EQ(right_dims[1], 1,
+                        "The last dimension of Input(Right) must be 1.");
+    }
+    PADDLE_ENFORCE_EQ(label_dims[0], left_dims[0],
+                      "The first dimension of Input(Label) and Input(Left) "
+                      "must have the same value.");
+    PADDLE_ENFORCE_EQ(label_dims[0], right_dims[0],
+                      "The first dimension of Input(Label) and Input(Right) "
+                      "must have the same value.");
     ctx->SetOutputDim("Out", label_dims);
   }
 };
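In summary, the hunk above trades the old all-or-nothing check (every input exactly [batch_size x 1]) for a per-input rule: each of Label, Left, and Right may have rank 1, or rank 2 with a trailing dimension of 1, and only the leading batch dimensions must match across inputs. A hypothetical Python restatement of the same predicate, for readability only:

def dims_ok(dims):
    # rank 1, or rank 2 whose last dimension is 1
    return len(dims) == 1 or (len(dims) == 2 and dims[1] == 1)

def shapes_accepted(label_dims, left_dims, right_dims):
    same_batch = label_dims[0] == left_dims[0] == right_dims[0]
    return all(map(dims_ok, (label_dims, left_dims, right_dims))) and same_batch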
@@ -98,21 +133,25 @@ class RankLossGradOp : public framework::OperatorWithKernel {
       : OperatorWithKernel(type, inputs, outputs, attrs) {}
 
   void InferShape(framework::InferShapeContext *ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) shouldn't be null.");
-    PADDLE_ENFORCE(ctx->HasInput("Left"), "Input(Left) shouldn't be null.");
-    PADDLE_ENFORCE(ctx->HasInput("Right"), "Input(Right) shouldn't be null.");
-    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
-                   "Input(Out@GRAD) shouldn't be null.");
-    auto dims = ctx->GetInputDim("Left");
+    PADDLE_ENFORCE_EQ(ctx->HasInput("Label"), true,
+                      "Input(Label) shouldn't be null.");
+    PADDLE_ENFORCE_EQ(ctx->HasInput("Left"), true,
+                      "Input(Left) shouldn't be null.");
+    PADDLE_ENFORCE_EQ(ctx->HasInput("Right"), true,
+                      "Input(Right) shouldn't be null.");
+    PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")), true,
+                      "Input(Out@GRAD) shouldn't be null.");
+    auto left_dims = ctx->GetInputDim("Left");
+    auto right_dims = ctx->GetInputDim("Right");
     auto left_grad_name = framework::GradVarName("Left");
     auto right_grad_name = framework::GradVarName("Right");
     if (ctx->HasOutput(left_grad_name)) {
-      ctx->SetOutputDim(left_grad_name, dims);
+      ctx->SetOutputDim(left_grad_name, left_dims);
     }
     if (ctx->HasOutput(right_grad_name)) {
-      ctx->SetOutputDim(right_grad_name, dims);
+      ctx->SetOutputDim(right_grad_name, right_dims);
     }
   }
 };
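The gradient change above is purely about shape inference: each gradient output now takes the dims of its own forward input rather than both copying Left's, which matters once Left and Right may have different ranks. The math itself is unchanged; from the forward formula, with s = sigmoid(Left - Right), dLoss/dLeft = s - Label and dLoss/dRight = Label - s, each scaled by the incoming Out@GRAD. A NumPy sketch of that derivative (our illustration, not the kernel source):

import numpy as np

def rank_loss_grad_np(label, left, right, out_grad):
    # d/dx log(1 + exp(x)) = sigmoid(x), with x = left - right
    s = 1.0 / (1.0 + np.exp(-(left - right)))
    d_left = (s - label) * out_grad
    d_right = (label - s) * out_grad
    return d_left, d_right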
......
@@ -22,14 +22,24 @@ from op_test import OpTest
 class TestRankLossOp(OpTest):
     def setUp(self):
         self.op_type = "rank_loss"
-        batch_size = 5
+        shape = (5, 1)
         # labels_{i} = {0, 1.0} or {0, 0.5, 1.0}
-        label = np.random.randint(0, 2, size=(batch_size, 1)).astype("float32")
-        left = np.random.random((batch_size, 1)).astype("float32")
-        right = np.random.random((batch_size, 1)).astype("float32")
+        label_shape, left_shape, right_shape = self.set_shape()
+        label = np.random.randint(0, 2, size=shape).astype("float32")
+        left = np.random.random(shape).astype("float32")
+        right = np.random.random(shape).astype("float32")
         loss = np.log(1.0 + np.exp(left - right)) - label * (left - right)
-        self.inputs = {'Label': label, 'Left': left, 'Right': right}
-        self.outputs = {'Out': loss}
+        loss = np.reshape(loss, label_shape)
+        self.inputs = {
+            'Label': label.reshape(label_shape),
+            'Left': left.reshape(left_shape),
+            'Right': right.reshape(right_shape)
+        }
+        self.outputs = {'Out': loss.reshape(label_shape)}
+
+    def set_shape(self):
+        batch_size = 5
+        return (batch_size, 1), (batch_size, 1), (batch_size, 1)
 
     def test_check_output(self):
         self.check_output()
@@ -44,5 +54,35 @@ class TestRankLossOp(OpTest):
         self.check_grad(["Left"], "Out", no_grad_set=set('Right'))
 
 
+class TestRankLossOp1(TestRankLossOp):
+    def set_shape(self):
+        batch_size = 5
+        return (batch_size), (batch_size, 1), (batch_size, 1)
+
+
+class TestRankLossOp2(TestRankLossOp):
+    def set_shape(self):
+        batch_size = 5
+        return (batch_size, 1), (batch_size), (batch_size, 1)
+
+
+class TestRankLossOp3(TestRankLossOp):
+    def set_shape(self):
+        batch_size = 5
+        return (batch_size, 1), (batch_size, 1), (batch_size)
+
+
+class TestRankLossOp4(TestRankLossOp):
+    def set_shape(self):
+        batch_size = 5
+        return (batch_size), (batch_size), (batch_size, 1)
+
+
+class TestRankLossOp5(TestRankLossOp):
+    def set_shape(self):
+        batch_size = 5
+        return (batch_size), (batch_size), (batch_size)
+
+
 if __name__ == '__main__':
     unittest.main()
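With the relaxed check, callers no longer need to unsqueeze rank-1 score tensors before this layer. A hedged usage sketch (assuming the fluid 1.6-era fluid.data and fluid.layers.rank_loss(label, left, right) APIs; not taken from this PR):

import paddle.fluid as fluid

# 1-D [batch_size] inputs; before this patch the op enforced [batch_size x 1].
label = fluid.data(name="label", shape=[5], dtype="float32")
left = fluid.data(name="left", shape=[5], dtype="float32")
right = fluid.data(name="right", shape=[5], dtype="float32")
out = fluid.layers.rank_loss(label, left, right)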