未验证 提交 137e6336 编写于 作者: A Aurelius84 提交者: GitHub

Remove constraint that last dimension is forced to be 1 in rank_loss (#19997)

* fix input shape check test=develop

* move PADDLE_ENFORCE test=develop
上级 101a2b61
...@@ -27,20 +27,55 @@ class RankLossOp : public framework::OperatorWithKernel { ...@@ -27,20 +27,55 @@ class RankLossOp : public framework::OperatorWithKernel {
: OperatorWithKernel(type, inputs, outputs, attrs) {} : OperatorWithKernel(type, inputs, outputs, attrs) {}
void InferShape(framework::InferShapeContext *ctx) const override { void InferShape(framework::InferShapeContext *ctx) const override {
// input check PADDLE_ENFORCE_EQ(ctx->HasInput("Label"), true,
PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) shouldn't be null."); "Input(Label) shouldn't be null.");
PADDLE_ENFORCE(ctx->HasInput("Left"), "Input(Left) shouldn't be null."); PADDLE_ENFORCE_EQ(ctx->HasInput("Left"), true,
PADDLE_ENFORCE(ctx->HasInput("Right"), "Input(Right) shouldn't be null."); "Input(Left) shouldn't be null.");
PADDLE_ENFORCE_EQ(ctx->HasInput("Right"), true,
"Input(Right) shouldn't be null.");
auto label_dims = ctx->GetInputDim("Label"); auto label_dims = ctx->GetInputDim("Label");
auto left_dims = ctx->GetInputDim("Left"); auto left_dims = ctx->GetInputDim("Left");
auto right_dims = ctx->GetInputDim("Right"); auto right_dims = ctx->GetInputDim("Right");
// check label_dims valid
PADDLE_ENFORCE((label_dims == left_dims) && (left_dims == right_dims), PADDLE_ENFORCE_GE(label_dims.size(), 1,
"All inputs must have the same size."); "The dimension size of Input(Label) must be greater than "
PADDLE_ENFORCE( "or equal to 1.");
(label_dims.size() == 2) && (label_dims[1] == 1), PADDLE_ENFORCE_LE(
"All inputs must be 2-D tensors with shape [batch_size x 1]."); label_dims.size(), 2,
"The dimension size of Input(Label) must be less than or equal to 2.");
if (label_dims.size() == 2U) {
PADDLE_ENFORCE_EQ(label_dims[1], 1,
"The last dimension of Input(Label) must be 1.");
}
// check left_dims valid
PADDLE_ENFORCE_GE(left_dims.size(), 1,
"The dimension size of Input(Left) must be greater than "
"or equal to 1.");
PADDLE_ENFORCE_LE(
left_dims.size(), 2,
"The dimension size of Input(Left) must be less than or equal to 2.");
if (left_dims.size() == 2U) {
PADDLE_ENFORCE_EQ(left_dims[1], 1,
"The last dimension of Input(Left) must be 1.");
}
// check right_dims valid
PADDLE_ENFORCE_GE(right_dims.size(), 1,
"The dimension size of Input(Right) must be greater than "
"or equal to 1.");
PADDLE_ENFORCE_LE(
right_dims.size(), 2,
"The dimension size of Input(Right) must be less than or equal to 2.");
if (right_dims.size() == 2U) {
PADDLE_ENFORCE_EQ(right_dims[1], 1,
"The last dimension of Input(Right) must be 1.");
}
PADDLE_ENFORCE_EQ(label_dims[0], left_dims[0],
"The first dimension of Input(Label) and Input(Left) "
"must have the same value.");
PADDLE_ENFORCE_EQ(label_dims[0], right_dims[0],
"The first dimension of Input(Label) and Input(Right) "
"must have the same value.");
ctx->SetOutputDim("Out", label_dims); ctx->SetOutputDim("Out", label_dims);
} }
}; };
...@@ -98,21 +133,25 @@ class RankLossGradOp : public framework::OperatorWithKernel { ...@@ -98,21 +133,25 @@ class RankLossGradOp : public framework::OperatorWithKernel {
: OperatorWithKernel(type, inputs, outputs, attrs) {} : OperatorWithKernel(type, inputs, outputs, attrs) {}
void InferShape(framework::InferShapeContext *ctx) const override { void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) shouldn't be null."); PADDLE_ENFORCE_EQ(ctx->HasInput("Label"), true,
PADDLE_ENFORCE(ctx->HasInput("Left"), "Input(Left) shouldn't be null."); "Input(Label) shouldn't be null.");
PADDLE_ENFORCE(ctx->HasInput("Right"), "Input(Right) shouldn't be null."); PADDLE_ENFORCE_EQ(ctx->HasInput("Left"), true,
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), "Input(Left) shouldn't be null.");
"Input(Out@GRAD) shouldn't be null."); PADDLE_ENFORCE_EQ(ctx->HasInput("Right"), true,
auto dims = ctx->GetInputDim("Left"); "Input(Right) shouldn't be null.");
PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")), true,
"Input(Out@GRAD) shouldn't be null.");
auto left_dims = ctx->GetInputDim("Left");
auto right_dims = ctx->GetInputDim("Right");
auto left_grad_name = framework::GradVarName("Left"); auto left_grad_name = framework::GradVarName("Left");
auto right_grad_name = framework::GradVarName("Right"); auto right_grad_name = framework::GradVarName("Right");
if (ctx->HasOutput(left_grad_name)) { if (ctx->HasOutput(left_grad_name)) {
ctx->SetOutputDim(left_grad_name, dims); ctx->SetOutputDim(left_grad_name, left_dims);
} }
if (ctx->HasOutput(right_grad_name)) { if (ctx->HasOutput(right_grad_name)) {
ctx->SetOutputDim(right_grad_name, dims); ctx->SetOutputDim(right_grad_name, right_dims);
} }
} }
}; };
......
...@@ -22,14 +22,24 @@ from op_test import OpTest ...@@ -22,14 +22,24 @@ from op_test import OpTest
class TestRankLossOp(OpTest): class TestRankLossOp(OpTest):
def setUp(self): def setUp(self):
self.op_type = "rank_loss" self.op_type = "rank_loss"
batch_size = 5 shape = (5, 1)
# labels_{i} = {0, 1.0} or {0, 0.5, 1.0} # labels_{i} = {0, 1.0} or {0, 0.5, 1.0}
label = np.random.randint(0, 2, size=(batch_size, 1)).astype("float32") label_shape, left_shape, right_shape = self.set_shape()
left = np.random.random((batch_size, 1)).astype("float32") label = np.random.randint(0, 2, size=shape).astype("float32")
right = np.random.random((batch_size, 1)).astype("float32") left = np.random.random(shape).astype("float32")
right = np.random.random(shape).astype("float32")
loss = np.log(1.0 + np.exp(left - right)) - label * (left - right) loss = np.log(1.0 + np.exp(left - right)) - label * (left - right)
self.inputs = {'Label': label, 'Left': left, 'Right': right} loss = np.reshape(loss, label_shape)
self.outputs = {'Out': loss} self.inputs = {
'Label': label.reshape(label_shape),
'Left': left.reshape(left_shape),
'Right': right.reshape(right_shape)
}
self.outputs = {'Out': loss.reshape(label_shape)}
def set_shape(self):
batch_size = 5
return (batch_size, 1), (batch_size, 1), (batch_size, 1)
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
...@@ -44,5 +54,35 @@ class TestRankLossOp(OpTest): ...@@ -44,5 +54,35 @@ class TestRankLossOp(OpTest):
self.check_grad(["Left"], "Out", no_grad_set=set('Right')) self.check_grad(["Left"], "Out", no_grad_set=set('Right'))
class TestRankLossOp1(TestRankLossOp):
def set_shape(self):
batch_size = 5
return (batch_size), (batch_size, 1), (batch_size, 1)
class TestRankLossOp2(TestRankLossOp):
def set_shape(self):
batch_size = 5
return (batch_size, 1), (batch_size), (batch_size, 1)
class TestRankLossOp3(TestRankLossOp):
def set_shape(self):
batch_size = 5
return (batch_size, 1), (batch_size, 1), (batch_size)
class TestRankLossOp4(TestRankLossOp):
def set_shape(self):
batch_size = 5
return (batch_size), (batch_size), (batch_size, 1)
class TestRankLossOp5(TestRankLossOp):
def set_shape(self):
batch_size = 5
return (batch_size), (batch_size), (batch_size)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册