未验证 提交 77cb396e 编写于 作者: L Li Fuchen 提交者: GitHub

OP(rank_loss, similarity_focus, squeeze) error message enhancement (#24448) (#24467)

* enhance rank_loss error message, test=develop

* enhance similarity_focus error message, test=develop

* enhance squeeze error message, test=develop
上级 e01e77d7
...@@ -27,55 +27,87 @@ class RankLossOp : public framework::OperatorWithKernel { ...@@ -27,55 +27,87 @@ class RankLossOp : public framework::OperatorWithKernel {
: OperatorWithKernel(type, inputs, outputs, attrs) {} : OperatorWithKernel(type, inputs, outputs, attrs) {}
void InferShape(framework::InferShapeContext *ctx) const override { void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("Label"), true, OP_INOUT_CHECK(ctx->HasInput("Label"), "Input", "Label", "RankLoss");
"Input(Label) shouldn't be null."); OP_INOUT_CHECK(ctx->HasInput("Left"), "Input", "Left", "RankLoss");
PADDLE_ENFORCE_EQ(ctx->HasInput("Left"), true, OP_INOUT_CHECK(ctx->HasInput("Right"), "Input", "Right", "RankLoss");
"Input(Left) shouldn't be null.");
PADDLE_ENFORCE_EQ(ctx->HasInput("Right"), true,
"Input(Right) shouldn't be null.");
auto label_dims = ctx->GetInputDim("Label"); auto label_dims = ctx->GetInputDim("Label");
auto left_dims = ctx->GetInputDim("Left"); auto left_dims = ctx->GetInputDim("Left");
auto right_dims = ctx->GetInputDim("Right"); auto right_dims = ctx->GetInputDim("Right");
// check label_dims valid // check label_dims valid
PADDLE_ENFORCE_GE(label_dims.size(), 1, PADDLE_ENFORCE_GE(
"The dimension size of Input(Label) must be greater than " label_dims.size(), 1,
"or equal to 1."); platform::errors::InvalidArgument(
"The dimension size of Input(Label) must be greater than "
"or equal to 1, but received %d.",
label_dims.size()));
PADDLE_ENFORCE_LE( PADDLE_ENFORCE_LE(
label_dims.size(), 2, label_dims.size(), 2,
"The dimension size of Input(Label) must be less than or equal to 2."); platform::errors::InvalidArgument("The dimension size of Input(Label) "
"must be less than or equal to 2, "
"but received %d.",
label_dims.size()));
if (label_dims.size() == 2U) { if (label_dims.size() == 2U) {
PADDLE_ENFORCE_EQ(label_dims[1], 1, PADDLE_ENFORCE_EQ(
"The last dimension of Input(Label) must be 1."); label_dims[1], 1,
platform::errors::InvalidArgument(
"The last dimension of Input(Label) must be 1, but received %d.",
label_dims[1]));
} }
// check left_dims valid // check left_dims valid
PADDLE_ENFORCE_GE(left_dims.size(), 1, PADDLE_ENFORCE_GE(
"The dimension size of Input(Left) must be greater than " left_dims.size(), 1,
"or equal to 1."); platform::errors::InvalidArgument(
"The dimension size of Input(Left) must be greater than "
"or equal to 1, but received %d.",
left_dims.size()));
PADDLE_ENFORCE_LE( PADDLE_ENFORCE_LE(
left_dims.size(), 2, left_dims.size(), 2,
"The dimension size of Input(Left) must be less than or equal to 2."); platform::errors::InvalidArgument("The dimension size of Input(Left) "
"must be less than or equal to 2, "
"but received %d.",
left_dims.size()));
if (left_dims.size() == 2U) { if (left_dims.size() == 2U) {
PADDLE_ENFORCE_EQ(left_dims[1], 1, PADDLE_ENFORCE_EQ(
"The last dimension of Input(Left) must be 1."); left_dims[1], 1,
platform::errors::InvalidArgument(
"The last dimension of Input(Left) must be 1, but received %d.",
left_dims[1]));
} }
// check right_dims valid // check right_dims valid
PADDLE_ENFORCE_GE(right_dims.size(), 1, PADDLE_ENFORCE_GE(
"The dimension size of Input(Right) must be greater than " right_dims.size(), 1,
"or equal to 1."); platform::errors::InvalidArgument(
"The dimension size of Input(Right) must be greater than "
"or equal to 1, but received %d.",
right_dims.size()));
PADDLE_ENFORCE_LE( PADDLE_ENFORCE_LE(
right_dims.size(), 2, right_dims.size(), 2,
"The dimension size of Input(Right) must be less than or equal to 2."); platform::errors::InvalidArgument("The dimension size of Input(Right) "
"must be less than or equal to 2, "
"but received %d.",
right_dims.size()));
if (right_dims.size() == 2U) { if (right_dims.size() == 2U) {
PADDLE_ENFORCE_EQ(right_dims[1], 1, PADDLE_ENFORCE_EQ(
"The last dimension of Input(Right) must be 1."); right_dims[1], 1,
platform::errors::InvalidArgument(
"The last dimension of Input(Right) must be 1, but received %d.",
right_dims[1]));
} }
PADDLE_ENFORCE_EQ(label_dims[0], left_dims[0], PADDLE_ENFORCE_EQ(
"The first dimension of Input(Label) and Input(Left) " label_dims[0], left_dims[0],
"must have the same value."); platform::errors::InvalidArgument(
PADDLE_ENFORCE_EQ(label_dims[0], right_dims[0], "The first dimension of Input(Label) and Input(Left) "
"The first dimension of Input(Label) and Input(Right) " "must have the same value. But received Label.dims[0]=%d, "
"must have the same value."); "Left.dims[0]=%d.",
label_dims[0], left_dims[0]));
PADDLE_ENFORCE_EQ(
label_dims[0], right_dims[0],
platform::errors::InvalidArgument(
"The first dimension of Input(Label) and Input(Right) "
"must have the same value. But received Label.dims[0]=%d, "
"Right.dims[0]=%d.",
label_dims[0], right_dims[0]));
ctx->SetOutputDim("Out", label_dims); ctx->SetOutputDim("Out", label_dims);
} }
}; };
...@@ -133,14 +165,12 @@ class RankLossGradOp : public framework::OperatorWithKernel { ...@@ -133,14 +165,12 @@ class RankLossGradOp : public framework::OperatorWithKernel {
: OperatorWithKernel(type, inputs, outputs, attrs) {} : OperatorWithKernel(type, inputs, outputs, attrs) {}
void InferShape(framework::InferShapeContext *ctx) const override { void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("Label"), true, OP_INOUT_CHECK(ctx->HasInput("Label"), "Input", "Label", "RankLossGrad");
"Input(Label) shouldn't be null."); OP_INOUT_CHECK(ctx->HasInput("Left"), "Input", "Left", "RankLossGrad");
PADDLE_ENFORCE_EQ(ctx->HasInput("Left"), true, OP_INOUT_CHECK(ctx->HasInput("Right"), "Input", "Right", "RankLossGrad");
"Input(Left) shouldn't be null."); OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
PADDLE_ENFORCE_EQ(ctx->HasInput("Right"), true, framework::GradVarName("Out"), "RankLossGrad");
"Input(Right) shouldn't be null.");
PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")), true,
"Input(Out@GRAD) shouldn't be null.");
auto left_dims = ctx->GetInputDim("Left"); auto left_dims = ctx->GetInputDim("Left");
auto right_dims = ctx->GetInputDim("Right"); auto right_dims = ctx->GetInputDim("Right");
auto left_grad_name = framework::GradVarName("Left"); auto left_grad_name = framework::GradVarName("Left");
......
...@@ -59,10 +59,15 @@ class SimilarityFocusOp : public framework::OperatorWithKernel { ...@@ -59,10 +59,15 @@ class SimilarityFocusOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null."); OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "SimilarityFocus");
PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should be not null."); OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "SimilarityFocus");
auto x_dims = ctx->GetInputDim("X"); auto x_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE_EQ(x_dims.size(), 4, "Input(X)'s rank should be 4."); PADDLE_ENFORCE_EQ(
x_dims.size(), 4,
platform::errors::InvalidArgument(
"The dimension size of Input(X) be 4, but received %d.",
x_dims.size()));
ctx->SetOutputDim("Out", x_dims); ctx->SetOutputDim("Out", x_dims);
ctx->ShareLoD("X", /*->*/ "Out"); ctx->ShareLoD("X", /*->*/ "Out");
} }
......
...@@ -43,13 +43,19 @@ class SimilarityFocusKernel : public framework::OpKernel<T> { ...@@ -43,13 +43,19 @@ class SimilarityFocusKernel : public framework::OpKernel<T> {
dim[i] = x->dims()[i]; dim[i] = x->dims()[i];
} }
if (indexes.size() < 1) { PADDLE_ENFORCE_GT(
PADDLE_THROW("Indexes' size can not be 0."); indexes.size(), 0,
} platform::errors::InvalidArgument("The size of Attr(indexes) must be "
for (auto index : indexes) { "greater than 0, but received %d.",
if (dim[axis] < index) { indexes.size()));
PADDLE_THROW("Index exceeds tensor shape limit.");
} for (size_t i = 0; i < indexes.size(); i++) {
PADDLE_ENFORCE_GT(
dim[axis], indexes[i],
platform::errors::InvalidArgument(
"Each value of Attr(indexes) must be less than X.dim[axis], "
"but indexes[%d] received %d.",
i, indexes[i]));
} }
int64_t array_size = 1; int64_t array_size = 1;
...@@ -72,6 +78,16 @@ class SimilarityFocusKernel : public framework::OpKernel<T> { ...@@ -72,6 +78,16 @@ class SimilarityFocusKernel : public framework::OpKernel<T> {
d3 * dim[3] + d4; d3 * dim[3] + d4;
}; };
PADDLE_ENFORCE_GT(
axis, 0,
platform::errors::InvalidArgument(
"The value of Attr(axis) must be 1 or 2 or 3, but received %d.",
axis));
PADDLE_ENFORCE_LT(
axis, 4,
platform::errors::InvalidArgument(
"The value of Attr(axis) must be 1 or 2 or 3, but received %d.",
axis));
memset(out_data, 0, sizeof(T) * batch_size * dim[1] * dim[2] * dim[3]); memset(out_data, 0, sizeof(T) * batch_size * dim[1] * dim[2] * dim[3]);
for (int i = 0; i < batch_size; ++i) { for (int i = 0; i < batch_size; ++i) {
for (auto index : indexes) { for (auto index : indexes) {
...@@ -156,8 +172,6 @@ class SimilarityFocusKernel : public framework::OpKernel<T> { ...@@ -156,8 +172,6 @@ class SimilarityFocusKernel : public framework::OpKernel<T> {
break; break;
} }
} }
} else {
PADDLE_THROW("Axis must be 1 or 2 or 3");
} }
} }
} }
......
...@@ -27,29 +27,19 @@ class SqueezeOp : public framework::OperatorWithKernel { ...@@ -27,29 +27,19 @@ class SqueezeOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override { void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Squeeze");
"Input(X) of Squeeze operator should not be null."); OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Squeeze");
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
"Output(Out) of Squeeze operator should not be null.");
const auto &x_dims = ctx->GetInputDim("X"); const auto &x_dims = ctx->GetInputDim("X");
// Check input tensor dims (<6) Eigen limit. // Check input tensor dims (<6) Eigen limit.
PADDLE_ENFORCE_LE(x_dims.size(), 6, PADDLE_ENFORCE_LE(x_dims.size(), 6,
"ShapeError: the dimensions of Input(X) " platform::errors::InvalidArgument(
"should be in the range of [1, 6] (Eigen limit)." "The dimensions of Input(X) "
"But received X's dimensions = %d, X's shape=[%s].", "should be in the range of [1, 6] (Eigen limit)."
x_dims.size(), x_dims); "But received X's dimensions = %d, X's shape=[%s].",
x_dims.size(), x_dims));
const auto &axes = ctx->Attrs().Get<std::vector<int>>("axes"); const auto &axes = ctx->Attrs().Get<std::vector<int>>("axes");
for (int a : axes) {
PADDLE_ENFORCE_LT(
a, x_dims.size(),
"ShapeError: The squeeze axis should be less than input "
"tensor's dimensions. But received axis = %d, input "
"tensor's dimensions = %d, input tensor's shape = [%s].",
a, x_dims.size(), x_dims);
}
auto out_dims = GetOutputShape(axes, x_dims); auto out_dims = GetOutputShape(axes, x_dims);
ctx->SetOutputDim("Out", out_dims); ctx->SetOutputDim("Out", out_dims);
if (x_dims[0] == out_dims[0]) { if (x_dims[0] == out_dims[0]) {
...@@ -78,10 +68,18 @@ class SqueezeOp : public framework::OperatorWithKernel { ...@@ -78,10 +68,18 @@ class SqueezeOp : public framework::OperatorWithKernel {
for (size_t idx = 0; idx < num_squeeze_dims; ++idx) { for (size_t idx = 0; idx < num_squeeze_dims; ++idx) {
int current = squeeze_dims[idx] < 0 ? squeeze_dims[idx] + in_dims.size() int current = squeeze_dims[idx] < 0 ? squeeze_dims[idx] + in_dims.size()
: squeeze_dims[idx]; : squeeze_dims[idx];
PADDLE_ENFORCE_GE(current, 0, PADDLE_ENFORCE_GE(
"Invalid axis, the axis should >= 0." current, 0,
"Current axis is:%d, input tensor's shape = [%s].", platform::errors::InvalidArgument(
current, in_dims); "Each axis in Attr(axes) should be in the range of [%d, %d]"
"But current axis is:%d, input tensor's shape = [%s].",
-in_dims.size(), in_dims.size() - 1, current, in_dims));
PADDLE_ENFORCE_LT(
current, in_dims.size(),
platform::errors::InvalidArgument(
"Each axis in Attr(axes) should be in the range of [%d, %d]"
"But current axis is:%d, input tensor's shape = [%s].",
-in_dims.size(), in_dims.size() - 1, current, in_dims));
if (!(should_squeeze[current])) { if (!(should_squeeze[current])) {
++cnt_squeezed_dims; ++cnt_squeezed_dims;
...@@ -171,28 +169,19 @@ class Squeeze2Op : public framework::OperatorWithKernel { ...@@ -171,28 +169,19 @@ class Squeeze2Op : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override { void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Squeeze2");
"Input(X) of Squeeze operator should not be null."); OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Squeeze2");
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
"Output(Out) of Squeeze operator should not be null.");
const auto &x_dims = ctx->GetInputDim("X"); const auto &x_dims = ctx->GetInputDim("X");
// Check input tensor dims (<6) Eigen limit. // Check input tensor dims (<6) Eigen limit.
PADDLE_ENFORCE_LE(x_dims.size(), 6, PADDLE_ENFORCE_LE(x_dims.size(), 6,
"ShapeError: the dimensions of Input(X) " platform::errors::InvalidArgument(
"should be in the range of [1, 6] (Eigen limit)." "The dimensions of Input(X) "
"But received X's dimensions = %d, X's shape = [%s].", "should be in the range of [1, 6] (Eigen limit)."
x_dims.size(), x_dims); "But received X's dimensions = %d, X's shape = [%s].",
x_dims.size(), x_dims));
const auto &axes = ctx->Attrs().Get<std::vector<int>>("axes"); const auto &axes = ctx->Attrs().Get<std::vector<int>>("axes");
for (int a : axes) {
PADDLE_ENFORCE_LT(
a, x_dims.size(),
"ShapeError: The squeeze axis should be less than input "
"tensor's dimensions. But received axis = %d, input "
"tensor's dimensions = %d, input tensor's shape = [%s].",
a, x_dims.size(), x_dims);
}
auto out_dims = SqueezeOp::GetOutputShape(axes, x_dims); auto out_dims = SqueezeOp::GetOutputShape(axes, x_dims);
ctx->SetOutputDim("Out", out_dims); ctx->SetOutputDim("Out", out_dims);
...@@ -202,8 +191,8 @@ class Squeeze2Op : public framework::OperatorWithKernel { ...@@ -202,8 +191,8 @@ class Squeeze2Op : public framework::OperatorWithKernel {
ctx->ShareLoD("X", "Out"); ctx->ShareLoD("X", "Out");
} }
PADDLE_ENFORCE_EQ(ctx->HasOutput("XShape"), true, OP_INOUT_CHECK(ctx->HasOutput("XShape"), "Output", "XShape", "Squeeze2");
"Output(XShape) of Squeeze operator should not be null.");
std::vector<int64_t> xshape_dims(x_dims.size() + 1); std::vector<int64_t> xshape_dims(x_dims.size() + 1);
xshape_dims[0] = 0; xshape_dims[0] = 0;
for (int i = 0; i < x_dims.size(); ++i) { for (int i = 0; i < x_dims.size(); ++i) {
...@@ -233,10 +222,10 @@ class Squeeze2GradOp : public framework::OperatorWithKernel { ...@@ -233,10 +222,10 @@ class Squeeze2GradOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *context) const override { void InferShape(framework::InferShapeContext *context) const override {
PADDLE_ENFORCE_EQ(context->HasInput("XShape"), true, OP_INOUT_CHECK(context->HasInput("XShape"), "Input", "XShape",
"Input(XShape) shouldn't be null."); "Squeeze2Grad");
PADDLE_ENFORCE_EQ(context->HasInput(framework::GradVarName("Out")), true, OP_INOUT_CHECK(context->HasInput(framework::GradVarName("Out")), "Input",
"Input(Out@GRAD) shouldn't be null."); framework::GradVarName("Out"), "Squeeze2Grad");
auto xshape_dims = context->GetInputDim("XShape"); auto xshape_dims = context->GetInputDim("XShape");
auto x_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size()); auto x_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size());
context->SetOutputDim(framework::GradVarName("X"), x_dims); context->SetOutputDim(framework::GradVarName("X"), x_dims);
......
...@@ -62,16 +62,25 @@ class SqueezeKernel : public framework::OpKernel<T> { ...@@ -62,16 +62,25 @@ class SqueezeKernel : public framework::OpKernel<T> {
int current = squeeze_dims[idx] < 0 ? squeeze_dims[idx] + in_dims.size() int current = squeeze_dims[idx] < 0 ? squeeze_dims[idx] + in_dims.size()
: squeeze_dims[idx]; : squeeze_dims[idx];
PADDLE_ENFORCE_GE(current, 0, PADDLE_ENFORCE_GE(
"Invalid axis, the axis should >= 0." current, 0,
"Current axis is:%d, input tensor's shape = [%s].", platform::errors::InvalidArgument(
current, in_dims); "Each axis in Attr(axes) should be in the range of [%d, %d]"
"But current axis is:%d, input tensor's shape = [%s].",
-in_dims.size(), in_dims.size() - 1, current, in_dims));
PADDLE_ENFORCE_LT(
current, in_dims.size(),
platform::errors::InvalidArgument(
"Each axis in Attr(axes) should be in the range of [%d, %d]"
"But current axis is:%d, input tensor's shape = [%s].",
-in_dims.size(), in_dims.size() - 1, current, in_dims));
PADDLE_ENFORCE_EQ(in_dims[current], 1, PADDLE_ENFORCE_EQ(in_dims[current], 1,
"Invalid axis index, the axis that will be squeezed " platform::errors::InvalidArgument(
"should be equal to 1. But current axis = %d," "The size of axis that will be squeezed "
"input tensor's shape = [%s].", "should be equal to 1. But current axis = %d,"
in_dims[current], in_dims); "input tensor's shape = [%s].",
in_dims[current], in_dims));
if (!(should_squeeze[current])) { if (!(should_squeeze[current])) {
++cnt_squeezed_dims; ++cnt_squeezed_dims;
......
...@@ -169,7 +169,7 @@ class WarpCTCGradOp : public framework::OperatorWithKernel { ...@@ -169,7 +169,7 @@ class WarpCTCGradOp : public framework::OperatorWithKernel {
OP_INOUT_CHECK(ctx->HasInput("WarpCTCGrad"), "Input", "WarpCTCGrad", OP_INOUT_CHECK(ctx->HasInput("WarpCTCGrad"), "Input", "WarpCTCGrad",
"WarpCTCGrad"); "WarpCTCGrad");
OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("Logits")), "Output", OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("Logits")), "Output",
"WarpCTCGrad", "WarpCTCGrad"); framework::GradVarName("Logits"), "WarpCTCGrad");
ctx->SetOutputDim(framework::GradVarName("Logits"), ctx->SetOutputDim(framework::GradVarName("Logits"),
ctx->GetInputDim("Logits")); ctx->GetInputDim("Logits"));
ctx->ShareLoD("Logits", /*->*/ framework::GradVarName("Logits")); ctx->ShareLoD("Logits", /*->*/ framework::GradVarName("Logits"));
......
...@@ -1323,15 +1323,9 @@ def rank_loss(label, left, right, name=None): ...@@ -1323,15 +1323,9 @@ def rank_loss(label, left, right, name=None):
""" """
helper = LayerHelper('rank_loss', **locals()) helper = LayerHelper('rank_loss', **locals())
check_variable_and_dtype(label, 'label', ['float32'], "rank_loss")
if not (isinstance(label, Variable)): check_variable_and_dtype(left, 'left', ['float32'], "rank_loss")
raise ValueError("The label should be a Variable") check_variable_and_dtype(right, 'right', ['float32'], "rank_loss")
if not (isinstance(left, Variable)):
raise ValueError("The left should be a Variable")
if not (isinstance(right, Variable)):
raise ValueError("The right should be a Variable")
out = helper.create_variable_for_type_inference("float32") out = helper.create_variable_for_type_inference("float32")
......
...@@ -13613,10 +13613,10 @@ def similarity_focus(input, axis, indexes, name=None): ...@@ -13613,10 +13613,10 @@ def similarity_focus(input, axis, indexes, name=None):
""" """
helper = LayerHelper('similarity_focus', **locals()) helper = LayerHelper('similarity_focus', **locals())
# check attrs # check attrs
if isinstance(axis, int) is False: check_variable_and_dtype(input, 'input', ['float32', 'float64'],
raise TypeError("axis must be int type.") "similarity_focus")
if isinstance(indexes, list) is False: check_type(axis, 'axis', int, "similarity_focus")
raise TypeError("indexes must be list type.") check_type(indexes, 'indexes', list, "similarity_focus")
if axis != 1 and axis != 2 and axis != 3: if axis != 1 and axis != 2 and axis != 3:
raise ValueError("axis must be 1, 2 or 3.") raise ValueError("axis must be 1, 2 or 3.")
if len(indexes) == 0: if len(indexes) == 0:
......
...@@ -17,6 +17,8 @@ from __future__ import print_function ...@@ -17,6 +17,8 @@ from __future__ import print_function
import unittest import unittest
import numpy as np import numpy as np
from op_test import OpTest from op_test import OpTest
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard
class TestRankLossOp(OpTest): class TestRankLossOp(OpTest):
...@@ -84,5 +86,31 @@ class TestRankLossOp5(TestRankLossOp): ...@@ -84,5 +86,31 @@ class TestRankLossOp5(TestRankLossOp):
return (batch_size), (batch_size), (batch_size) return (batch_size), (batch_size), (batch_size)
class TestRankLossOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program(), Program()):
label = fluid.data(name="label", shape=[16, 1], dtype="float32")
left = fluid.data(name="left", shape=[16, 1], dtype="float32")
right = fluid.data(name="right", shape=[16, 1], dtype="float32")
def test_label_Variable():
label_data = np.random.rand(16, 1).astype("float32")
out = fluid.layers.rank_loss(label_data, left, right)
self.assertRaises(TypeError, test_label_Variable)
def test_left_Variable():
left_data = np.random.rand(16, 1).astype("float32")
out = fluid.layers.rank_loss(label, left_data, right)
self.assertRaises(TypeError, test_left_Variable)
def test_right_Variable():
right_data = np.random.rand(16, 1).astype("float32")
out = fluid.layers.rank_loss(label, left, right_data)
self.assertRaises(TypeError, test_right_Variable)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -18,6 +18,8 @@ import unittest ...@@ -18,6 +18,8 @@ import unittest
import numpy as np import numpy as np
import paddle.fluid.core as core import paddle.fluid.core as core
from op_test import OpTest from op_test import OpTest
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard
class TestSimilarityFocusOp(OpTest): class TestSimilarityFocusOp(OpTest):
...@@ -213,5 +215,32 @@ class TestSimilarityFocusOp_axis3(OpTest): ...@@ -213,5 +215,32 @@ class TestSimilarityFocusOp_axis3(OpTest):
self.check_output() self.check_output()
class TestSimilarityFocusOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program(), Program()):
data = fluid.data(name='data', shape=[16, 3, 2, 2], dtype='float32')
def test_input_Variable():
input = np.random.rand(16, 3, 2, 2).astype("float32")
out = fluid.layers.similarity_focus(
input=input, axis=1, indexes=[0])
self.assertRaises(TypeError, test_input_Variable)
def test_axis_Int():
axis = 1.0
out = fluid.layers.similarity_focus(
input=data, axis=axis, indexes=[0])
self.assertRaises(TypeError, test_axis_Int)
def test_indexes_List():
indexes = 0
out = fluid.layers.similarity_focus(
input=data, axis=1, indexes=indexes)
self.assertRaises(TypeError, test_indexes_List)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册