未验证 提交 a398464e 编写于 作者: L lijianshe02 提交者: GitHub

API/OP (affine_channel, group_norm, layer_norm, random_crop, unpool, … (#24118)

* API/OP (affine_channel, group_norm, layer_norm, random_crop, unpool, log_loss) error message enhancement test=develop
上级 ab4d3140
......@@ -61,14 +61,10 @@ class AffineChannelOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of AffineChannelOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Scale"),
"Input(Scale) of AffineChannelOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Bias"),
"Input(Bias) of AffineChannelOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of AffineChannelOp should not be null.");
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "AffineChannel");
OP_INOUT_CHECK(ctx->HasInput("Scale"), "Input", "Scale", "AffineChannel");
OP_INOUT_CHECK(ctx->HasInput("Bias"), "Input", "Bias", "AffineChannel");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "AffineChannel");
auto x_dims = ctx->GetInputDim("X");
auto scale_dims = ctx->GetInputDim("Scale");
......@@ -80,13 +76,32 @@ class AffineChannelOp : public framework::OperatorWithKernel {
? x_dims[1]
: x_dims[x_dims.size() - 1]);
PADDLE_ENFORCE_EQ(scale_dims.size(), 1UL);
PADDLE_ENFORCE_EQ(b_dims.size(), 1UL);
PADDLE_ENFORCE_EQ(
scale_dims.size(), 1UL,
platform::errors::InvalidArgument(
"The dimensions of Input(Scale) must be 1,"
"But received the dimensions of Input(Scale) is [%d] ",
scale_dims.size()));
PADDLE_ENFORCE_EQ(b_dims.size(), 1UL,
platform::errors::InvalidArgument(
"The dimensions of Input(Bias) must be 1,"
"But received the dimensions of Input(Bias) is [%d] ",
scale_dims.size()));
if (ctx->IsRuntime() || scale_dims[0] > 0) {
PADDLE_ENFORCE_EQ(scale_dims[0], C);
PADDLE_ENFORCE_EQ(
scale_dims[0], C,
platform::errors::InvalidArgument(
"The first dimension value of Input(Scale) must be [%d],"
"But received [%d].",
C, scale_dims[0]));
}
if (ctx->IsRuntime() || b_dims[0] > 0) {
PADDLE_ENFORCE_EQ(b_dims[0], C);
PADDLE_ENFORCE_EQ(
b_dims[0], C,
platform::errors::InvalidArgument(
"The first dimension value of Input(Bias) must be [%d],"
"But received [%d].",
C, b_dims[0]));
}
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
......@@ -98,19 +113,19 @@ class AffineChannelOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null.");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
framework::GradVarName("Out"), "AffineChannelGrad");
if (ctx->HasOutput(framework::GradVarName("X"))) {
PADDLE_ENFORCE(ctx->HasInput("Scale"),
"Input(Scale) should not be null.");
OP_INOUT_CHECK(ctx->HasInput("Scale"), "Input", "Scale",
"AffineChannelGrad");
ctx->SetOutputDim(framework::GradVarName("X"),
ctx->GetInputDim(framework::GradVarName("Out")));
}
if (ctx->HasOutput(framework::GradVarName("Scale"))) {
// Scale@GRAD and Bias@GRAD must exist at the same time.
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Bias")),
"Output(Scale@GRAD) should not be null.");
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("Bias")), "Output",
framework::GradVarName("Bias"), "AffineChannelGrad");
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "AffineChannelGrad");
ctx->SetOutputDim(framework::GradVarName("Scale"),
ctx->GetInputDim("Scale"));
ctx->SetOutputDim(framework::GradVarName("Bias"),
......
......@@ -27,36 +27,62 @@ class LayerNormOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of LayerNormOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Y"),
"Output(Y) of LayerNormOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Mean"),
"Output(Mean) of LayerNormOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Variance"),
"Output(Variance) of LayerNormOp should not be null.");
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "LayerNorm");
OP_INOUT_CHECK(ctx->HasOutput("Y"), "Output", "Y", "LayerNorm");
OP_INOUT_CHECK(ctx->HasOutput("Mean"), "Output", "Mean", "LayerNorm");
OP_INOUT_CHECK(ctx->HasOutput("Variance"), "Output", "Variance",
"LayerNorm");
auto x_dim = ctx->GetInputDim("X");
auto begin_norm_axis = ctx->Attrs().Get<int>("begin_norm_axis");
PADDLE_ENFORCE_LT(begin_norm_axis, x_dim.size(),
"'begin_norm_axis' must be less than the rank of X.");
PADDLE_ENFORCE_LT(
begin_norm_axis, x_dim.size(),
platform::errors::InvalidArgument(
"'begin_norm_axis' must be less than the dimensions of X,"
"But received 'begin_norm_axis' is [%d],"
"received the dimensions of X is [%d].",
begin_norm_axis, x_dim.size()));
auto matrix_dim = framework::flatten_to_2d(x_dim, begin_norm_axis);
int left = static_cast<int>(matrix_dim[0]);
int right = static_cast<int>(matrix_dim[1]);
if (ctx->HasInput("Scale")) {
PADDLE_ENFORCE_EQ(ctx->GetInputDim("Scale").size(), 1);
PADDLE_ENFORCE_EQ(ctx->GetInputDim("Scale").size(), 1,
platform::errors::InvalidArgument(
"The dimensions of Input(Scale) must be 1, but "
"received dimensions of"
"Input(Scale) is [%d]",
ctx->GetInputDim("Scale").size()));
if (ctx->IsRuntime()) {
PADDLE_ENFORCE_EQ(ctx->GetInputDim("Scale")[0], right,
"scale should with right");
PADDLE_ENFORCE_EQ(
ctx->GetInputDim("Scale")[0], right,
platform::errors::InvalidArgument(
"The first dimension value of Input(Scale) must equal to be the"
"second dimension value of the flattened 2D matrix of Input(X),"
"But received the first dimension value of Input(Scale) is"
"[%d], the second dimension value of the flattened 2D matrix of"
" Input(Scale) is [%d].",
ctx->GetInputDim("Scale")[0], right));
}
}
if (ctx->HasInput("Bias")) {
PADDLE_ENFORCE_EQ(ctx->GetInputDim("Bias").size(), 1);
PADDLE_ENFORCE_EQ(ctx->GetInputDim("Bias").size(), 1,
platform::errors::InvalidArgument(
"The dimensions of Input(Bias) must be 1, but "
"received dimensions of"
"Input(Bias) is [%d]",
ctx->GetInputDim("Bias").size()));
if (ctx->IsRuntime()) {
PADDLE_ENFORCE_EQ(ctx->GetInputDim("Bias")[0], right,
"bias should with right");
PADDLE_ENFORCE_EQ(
ctx->GetInputDim("Bias")[0], right,
platform::errors::InvalidArgument(
"The first dimension value of Input(Bias) must equal to be the"
"second dimension value of the flattened 2D matrix of Input(X),"
"But received the first dimension value of Input(Bias) is"
"[%d], the second dimension value of the flattened 2D matrix of"
" Input(Bias) is [%d].",
ctx->GetInputDim("Scale")[0], right));
}
}
......@@ -90,8 +116,11 @@ class LayerNormOpMaker : public framework::OpProtoAndCheckerMaker {
"Constant for numerical stability [default 1e-5].")
.SetDefault(1e-5)
.AddCustomChecker([](const float &epsilon) {
PADDLE_ENFORCE(epsilon >= 0.0f && epsilon <= 0.001f,
"'epsilon' should be between 0.0 and 0.001.");
PADDLE_ENFORCE_EQ(epsilon >= 0.0f && epsilon <= 0.001f, true,
platform::errors::InvalidArgument(
"'epsilon' in Op(LayerNorm) should be between"
"0.0 and 0.001, But received [%s].",
epsilon));
});
AddAttr<int>("begin_norm_axis",
"the axis of `begin_norm_axis ... Rank(X) - 1` will be "
......@@ -100,7 +129,10 @@ class LayerNormOpMaker : public framework::OpProtoAndCheckerMaker {
.SetDefault(1)
.AddCustomChecker([](const int &begin_norm_axis) {
PADDLE_ENFORCE_GT(begin_norm_axis, 0,
"'begin_norm_axis' should be greater than zero.");
platform::errors::InvalidArgument(
"'begin_norm_axis' in Op(LayerNorm) should be"
"greater than zero. But received [%d].",
begin_norm_axis));
});
AddComment(R"DOC(
......@@ -122,14 +154,12 @@ class LayerNormGradOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext *ctx) const override {
// check input
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of LayerNormOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Mean"),
"Input(Mean) of LayerNormOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Variance"),
"Input(Variance) of LayerNormOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")),
"Input(Y@GRAD) of LayerNormOp should not be null.");
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "LayerNormGrad");
OP_INOUT_CHECK(ctx->HasInput("Mean"), "Input", "Mean", "LayerNormGrad");
OP_INOUT_CHECK(ctx->HasInput("Variance"), "Input", "Variance",
"LayerNormGrad");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Y")), "Input",
framework::GradVarName("Y"), "LayerNormGrad");
// check output
if (ctx->HasOutput(framework::GradVarName("X"))) {
......
......@@ -23,25 +23,37 @@ class LogLossOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Predicted"),
"Input(Predicted) must be initialized.");
PADDLE_ENFORCE(ctx->HasInput("Labels"),
"Input(Labels) must be initialized.");
OP_INOUT_CHECK(ctx->HasInput("Predicted"), "Input", "Predicted", "LogLoss");
OP_INOUT_CHECK(ctx->HasInput("Labels"), "Input", "Labels", "LogLoss");
auto pred_dims = ctx->GetInputDim("Predicted");
auto label_dims = ctx->GetInputDim("Labels");
if (ctx->IsRuntime() || (framework::product(pred_dims) > 0 &&
framework::product(label_dims) > 0)) {
PADDLE_ENFORCE_EQ(pred_dims, label_dims);
PADDLE_ENFORCE_EQ(
pred_dims, label_dims,
platform::errors::InvalidArgument(
"The dimensions of Input(Predicted) must be equal to the"
"dimensions of Input(Labels), but received dimensions of "
"Input(Predicted)"
"is [%s], received dimensions of Input(Labels) is [%s].",
pred_dims, label_dims));
}
PADDLE_ENFORCE_EQ(pred_dims.size(), 2,
"The rank of Input(Predicted) must be 2 and the shape is "
"[batch_size, 1].");
platform::errors::InvalidArgument(
"The dimensions of Input(Predicted) must be 2,"
"But received dimensions of Input(Predicted)"
"is [%d]",
pred_dims.size()));
if (ctx->IsRuntime()) {
PADDLE_ENFORCE_EQ(pred_dims[1], 1,
"Each row of Input(Predicted) contains a real value, "
"so the 2nd dimension of Input(X) must be 1.");
PADDLE_ENFORCE_EQ(
pred_dims[1], 1,
platform::errors::InvalidArgument(
"Each row of Input(Predicted) contains a real value, "
"so the 2nd dimension of Input(X) must be 1,"
"But got [%d]",
pred_dims[1]));
}
ctx->SetOutputDim("Loss", {pred_dims[0], 1});
ctx->ShareLoD("Predicted", "Loss");
......@@ -87,18 +99,25 @@ class LogLossGradOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Predicted"),
"Input(Predicted) should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Labels"),
"Input(Labels) should not be null.");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Loss")),
"Input(Loss@GRAD) should not be null.");
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Predicted")),
"Output(Predicted@GRAD) should not be null.");
OP_INOUT_CHECK(ctx->HasInput("Predicted"), "Input", "Predicted",
"LogLossGrad");
OP_INOUT_CHECK(ctx->HasInput("Labels"), "Input", "Labels", "LogLossGrad");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Loss")), "Input",
framework::GradVarName("Loss"), "LogLossGrad");
OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("Predicted")),
"Output", framework::GradVarName("Predicted"),
"LogLossGrad");
auto pred_dims = ctx->GetInputDim("Predicted");
auto loss_grad_dims = ctx->GetInputDim(framework::GradVarName("Loss"));
PADDLE_ENFORCE_EQ(loss_grad_dims, pred_dims);
PADDLE_ENFORCE_EQ(loss_grad_dims, pred_dims,
platform::errors::InvalidArgument(
"The dimensions of loss_grad must be equal to the "
"dimensions of Predicted,"
"But received dimensions of loss_grad is [%s], "
"received Predicted is "
"[%s]",
loss_grad_dims, pred_dims));
auto pred_grad_name = framework::GradVarName("Predicted");
ctx->SetOutputDim(pred_grad_name, pred_dims);
......
......@@ -27,7 +27,11 @@ class RandomCropOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_GT(
x_dim.size(), static_cast<int64_t>(shape.size()),
platform::errors::InvalidArgument(
"Rank of Input(X) must be equal to length of Attr(shape)"));
"The dimensions of Input(X) must be greater than the length of "
"Attr(shape),"
"But received dimensions of Input(X) is [%d], receivecd length"
"of Attr(shape) is [%d].",
x_dim.size(), static_cast<int64_t>(shape.size())));
auto out_dim = framework::vectorize<int>(x_dim);
for (size_t i = 1; i <= shape.size(); ++i) {
size_t x_i = x_dim.size() - i;
......@@ -36,7 +40,10 @@ class RandomCropOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_GE(
x_dim[x_i], shape[shape_i],
platform::errors::InvalidArgument(
"Size of Input(X) must be larger than Attr(shape)"));
"The dimensions of Input(X) must be larger than Attr(shape),"
"But received dimensions of Input(X) is [%d], received"
"size of Attr(shape) is [%d].",
x_dim[x_i], shape[shape_i]));
}
out_dim[x_i] = shape[shape_i];
}
......
......@@ -83,15 +83,9 @@ class UnpoolOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(
ctx->HasInput("X"), true,
platform::errors::NotFound("Input(X) of UnpoolOp is not found."));
PADDLE_ENFORCE_EQ(
ctx->HasInput("Indices"), true,
platform::errors::NotFound("Input(Indices) of UnpoolOp is not found."));
PADDLE_ENFORCE_EQ(
ctx->HasOutput("Out"), true,
platform::errors::NotFound("Output(Out) of UnpoolOp is not found."));
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Unpool");
OP_INOUT_CHECK(ctx->HasInput("Indices"), "Input", "Indices", "Unpool");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Unpool");
auto in_x_dims = ctx->GetInputDim("X");
auto in_y_dims = ctx->GetInputDim("Indices");
std::string unpooling_type =
......@@ -101,8 +95,8 @@ class UnpoolOp : public framework::OperatorWithKernel {
std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
PADDLE_ENFORCE_EQ(in_x_dims.size() == 4, true,
platform::errors::InvalidArgument(
"Unpooling intput(X) must be of 4-dimensional, but "
"received X's dimension is %d.",
"Unpooling Intput(X) must be of 4-dimensional, but "
"received Input(X)'s dimension is %d.",
in_x_dims.size()));
PADDLE_ENFORCE_EQ(in_x_dims, in_y_dims);
......@@ -146,12 +140,9 @@ class UnpoolOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(
ctx->HasInput("X"), true,
platform::errors::NotFound("Input(X) of UnpoolOpGradOp is not found."));
PADDLE_ENFORCE_EQ(ctx->HasOutput(framework::GradVarName("X")), true,
platform::errors::NotFound(
"Input(X@GRAD) of UnpoolOpGradOp is not found."));
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "UnpoolGrad");
OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output",
framework::GradVarName("X"), "UnpoolGrad");
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}
};
......
......@@ -3406,6 +3406,8 @@ def layer_norm(input,
assert in_dygraph_mode(
) is not True, "please use LayerNorm instead of layer_norm in dygraph mode!"
helper = LayerHelper('layer_norm', **locals())
check_variable_and_dtype(input, 'input', ['float32', 'float64'],
'layer_norm')
dtype = helper.input_dtype()
# create intput and parameters
......@@ -3510,7 +3512,8 @@ def group_norm(input,
"""
helper = LayerHelper('group_norm', **locals())
dtype = helper.input_dtype()
check_variable_and_dtype(input, 'input', ['float32', 'float64'],
'group_norm')
# create intput and parameters
inputs = {'X': input}
input_shape = input.shape
......@@ -8182,6 +8185,10 @@ def random_crop(x, shape, seed=None):
"""
helper = LayerHelper("random_crop", **locals())
check_variable_and_dtype(x, 'x',
['float32', 'float64', 'uint8', 'int16', 'int32'],
'random_crop')
check_type(shape, 'shape', (list, Variable), 'random_crop')
dtype = x.dtype
out = helper.create_variable_for_type_inference(dtype)
if seed is None:
......@@ -12072,6 +12079,9 @@ def affine_channel(x,
"""
helper = LayerHelper("affine_channel", **locals())
check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'affine_channel')
check_type(scale, 'scale', (Variable, type(None)), 'affine_channel')
check_type(bias, 'bias', (Variable, type(None)), 'affine_channel')
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(
......@@ -12390,6 +12400,8 @@ def log_loss(input, label, epsilon=1e-4, name=None):
cost = fluid.layers.log_loss(input=prob, label=label)
"""
helper = LayerHelper('log_loss', **locals())
check_variable_and_dtype(input, 'input', ['float32'], 'log_loss')
check_variable_and_dtype(label, 'label', ['float32'], 'log_loss')
loss = helper.create_variable_for_type_inference(dtype=input.dtype)
......
......@@ -21,6 +21,7 @@ import unittest
import numpy as np
from op_test import OpTest
import paddle.fluid.core as core
import paddle.fluid as fluid
def affine_channel(x, scale, bias, layout):
......@@ -67,6 +68,38 @@ class TestAffineChannelOp(OpTest):
self.layout = 'NCHW'
class TestAffineChannelOpError(unittest.TestCase):
def test_errors(self):
with fluid.program_guard(fluid.Program()):
def test_x_type():
input_data = np.random.random(2, 1, 2, 2).astype("float32")
fluid.layers.affine_channel(input_data)
self.assertRaises(TypeError, test_x_type)
def test_x_dtype():
x2 = fluid.layers.data(
name='x2', shape=[None, 1, 2, 2], dtype='int32')
fluid.layers.affine_channel(x2)
self.assertRaises(TypeError, test_x_dtype)
def test_scale_type():
x3 = fluid.layers.data(
name='x3', shape=[None, 1, 2, 2], dtype='float32')
fluid.layers.affine_channel(x3, scale=1)
self.assertRaises(TypeError, test_scale_type)
def test_bias_type():
x4 = fluid.layers.data(
name='x4', shape=[None, 1, 2, 2], dtype='float32')
fluid.layers.affine_channel(x4, bias=1)
self.assertRaises(TypeError, test_bias_type)
class TestAffineChannelNHWC(TestAffineChannelOp):
def init_test_case(self):
self.shape = [2, 12, 12, 100]
......
......@@ -40,6 +40,26 @@ def group_norm_naive(x, scale, bias, epsilon, groups, data_layout):
return output, mean.reshape((N, G)), var.reshape((N, G))
class TestGroupNormOpError(unittest.TestCase):
def test_errors(self):
with fluid.program_guard(fluid.Program(), fluid.Program()):
def test_x_type():
input = np.random.random(2, 100, 3, 5).astype('float32')
goups = 2
fluid.layers.group_norm(input, groups)
self.assertRaises(TypeError, test_x_type)
def test_x_dtype():
x2 = fluid.layers.data(
name='x2', shape=[2, 100, 3, 5], dtype='int32')
groups = 2
fluid.layers.group_norm(x2, groups)
self.assertRaises(TypeError, test_x_dtype)
class TestGroupNormOp(OpTest):
def setUp(self):
self.op_type = "group_norm"
......
......@@ -17,6 +17,7 @@ from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
import paddle.fluid as fluid
def sigmoid_array(x):
......@@ -49,5 +50,34 @@ class TestLogLossOp(OpTest):
self.check_grad(['Predicted'], 'Loss', max_relative_error=0.03)
class TestLogLossOpError(unittest.TestCase):
def test_errors(self):
with fluid.program_guard(fluid.Program()):
def test_x_type():
input_data = np.random.random(100, 1).astype("float32")
fluid.layers.log_loss(input_data)
self.assertRaises(TypeError, test_x_type)
def test_x_dtype():
x2 = fluid.layers.data(name='x2', shape=[100, 1], dtype='int32')
fluid.layers.log_loss(x2)
self.assertRaises(TypeError, test_x_dtype)
def test_label_type():
input_data = np.random.random(100, 1).astype("float32")
fluid.layers.log_loss(input_data)
self.assertRaises(TypeError, test_label_type)
def test_label_dtype():
x2 = fluid.layers.data(name='x2', shape=[100, 1], dtype='int32')
fluid.layers.log_loss(x2)
self.assertRaises(TypeError, test_label_dtype)
if __name__ == '__main__':
unittest.main()
......@@ -18,6 +18,7 @@ import unittest
import numpy as np
import paddle.fluid.core as core
from op_test import OpTest
import paddle.fluid as fluid
class TestRandomCropOp(OpTest):
......@@ -45,5 +46,30 @@ class TestRandomCropOp(OpTest):
self.assertIn(True, is_equal)
class TestRandomCropOpError(unittest.TestCase):
def test_errors(self):
with fluid.program_guard(fluid.Program()):
def test_x_type():
input_data = np.random.random(2, 3, 256, 256).astype("float32")
fluid.layers.random_crop(input_data)
self.assertRaises(TypeError, test_x_type)
def test_x_dtype():
x2 = fluid.layers.data(
name='x2', shape=[None, 3, 256, 256], dtype='float16')
fluid.layers.random_crop(x2)
self.assertRaises(TypeError, test_x_dtype)
def test_shape_type():
x3 = fluid.layers.data(
name='x3', shape=[None, 3, 256, 256], dtype='float32')
fluid.layers.random_crop(x3, shape=1)
self.assertRaises(TypeError, test_shape_type)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册