未验证 提交 2787944c 编写于 作者: Z zhupengyang 提交者: GitHub

Ops(relu6/selu/soft_relu/softshrink/stanh/swish/thresholded_relu/hard_shrink/h...

Ops(relu6/selu/soft_relu/softshrink/stanh/swish/thresholded_relu/hard_shrink/hard_sigmoid/hard_swish/hsigmoid/maxout) error message enhancement (#23718)
上级 0b6f09e7
......@@ -61,30 +61,15 @@ class HierarchicalSigmoidOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
platform::errors::NotFound(
"Input(X) of HierarchicalSigmoidOp is not found."));
PADDLE_ENFORCE_EQ(
ctx->HasInput("Label"), true,
platform::errors::NotFound(
"Input(Label) of HierarchicalSigmoidOp is not found."));
PADDLE_ENFORCE_EQ(ctx->HasInput("W"), true,
platform::errors::NotFound(
"Input(W) of HierarchicalSigmoidOp is not found."));
PADDLE_ENFORCE_EQ(
ctx->HasOutput("Out"), true,
platform::errors::NotFound(
"Output(Out) of HierarchicalSigmoidOp is not found."));
PADDLE_ENFORCE_EQ(
ctx->HasOutput("PreOut"), true,
platform::errors::NotFound(
"Output(PreOut) of HierarchicalSigmoidOp is not found."));
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "hsigmoid");
OP_INOUT_CHECK(ctx->HasInput("Label"), "Input", "Label", "hsigmoid");
OP_INOUT_CHECK(ctx->HasInput("W"), "Input", "W", "hsigmoid");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "hsigmoid");
OP_INOUT_CHECK(ctx->HasOutput("PreOut"), "Output", "PreOut", "hsigmoid");
auto with_prefetch = ctx->Attrs().Get<bool>("remote_prefetch");
if (with_prefetch) {
PADDLE_ENFORCE_EQ(
ctx->HasOutput("W_Out"), true,
platform::errors::NotFound(
"Output(W_Out) of HierarchicalSigmoidOp is not found."));
OP_INOUT_CHECK(ctx->HasOutput("W_Out"), "Output", "W_Out", "hsigmoid");
}
const int64_t batch_size = ctx->GetInputDim("X")[0];
std::vector<int64_t> output_shape({batch_size, 1});
......@@ -213,30 +198,15 @@ class HierarchicalSigmoidGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(
ctx->HasInput("W"), true,
platform::errors::NotFound(
"Input(W) of HierarchicalSigmoidGradOp is not found."));
PADDLE_ENFORCE_EQ(
ctx->HasInput("Label"), true,
platform::errors::NotFound(
"Input(Label) of HierarchicalSigmoidGradOp is not found."));
PADDLE_ENFORCE_EQ(
ctx->HasInput(framework::GradVarName("Out")), true,
platform::errors::NotFound(
"Input(Out@Grad) of HierarchicalSigmoidGradOp is not found."));
PADDLE_ENFORCE_EQ(
ctx->HasInput("PreOut"), true,
platform::errors::NotFound(
"Input(Preout) of HierarchicalSigmoidGradOp is not found."));
PADDLE_ENFORCE_EQ(
ctx->HasOutput(framework::GradVarName("W")), true,
platform::errors::NotFound(
"Output(W@Grad of HierarchicalSigmoidGradOp is not found."));
PADDLE_ENFORCE_EQ(
ctx->HasOutput(framework::GradVarName("X")), true,
platform::errors::NotFound(
"Output(X@Grad of HierarchicalSigmoidGradOp is not found."));
OP_INOUT_CHECK(ctx->HasInput("W"), "Input", "W", "hsigmoid_grad");
OP_INOUT_CHECK(ctx->HasInput("Label"), "Input", "Label", "hsigmoid_grad");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
"Out@Grad", "hsigmoid_grad");
OP_INOUT_CHECK(ctx->HasInput("PreOut"), "Input", "PreOut", "hsigmoid_grad");
OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("W")), "Output",
"W@Grad", "hsigmoid_grad");
OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output",
"X@Grad", "hsigmoid_grad");
if (ctx->HasOutput(framework::GradVarName("Bias"))) {
ctx->SetOutputDim(framework::GradVarName("Bias"),
......
......@@ -203,8 +203,9 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> {
zero(dev_ctx, w_grad, static_cast<T>(0.0));
bit_code->MulGradWeight(pre_out_grad, w_grad, in);
} else {
PADDLE_ENFORCE(path != nullptr,
"Sparse mode should not be used without custom tree!");
PADDLE_ENFORCE_NOT_NULL(path,
platform::errors::NotFound(
"Custom tree must be set for sparse mode!"));
framework::Vector<int64_t> real_rows = PathToRows(*path);
auto* w_grad =
ctx.Output<framework::SelectedRows>(framework::GradVarName("W"));
......
......@@ -72,24 +72,26 @@ class MaxOutOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
"Input(X) of MaxoutOpshould not be null.");
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
"Output(Out) of MaxoutOp should not be null.");
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "maxout");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "maxout");
auto in_x_dims = ctx->GetInputDim("X");
int groups = ctx->Attrs().Get<int>("groups");
int axis = ctx->Attrs().Get<int>("axis");
// check groups > 1
PADDLE_ENFORCE_GT(groups, 1,
"Attr(groups) of Op(maxout) should be larger than 1.");
PADDLE_ENFORCE_GT(groups, 1, platform::errors::InvalidArgument(
"Attr(groups) of Op(maxout) should be "
"larger than 1. But received %d.",
groups));
PADDLE_ENFORCE_EQ(
in_x_dims[axis] % groups, 0,
"ValueError: The number of input channels for Op(maxout) "
"should be divisible by Attr(groups). But received: the "
"input's channels is [%d], the shape of input is [%s], "
"the Attr(groups) is [%d], the Attr(axis) is [%d]. The "
"error may come from wrong Attr(groups) or Attr(axis) setting.",
in_x_dims[axis], in_x_dims, groups, axis);
platform::errors::InvalidArgument(
"The number of input channels for Op(maxout) "
"should be divisible by Attr(groups). But received: the "
"input's channels is [%d], the shape of input is [%s], "
"the Attr(groups) is [%d], the Attr(axis) is [%d]. The "
"error may come from wrong Attr(groups) or Attr(axis) setting.",
in_x_dims[axis], in_x_dims, groups, axis));
std::vector<int64_t> output_shape(
{in_x_dims[0], in_x_dims[1], in_x_dims[2], in_x_dims[3]});
output_shape[axis] = in_x_dims[axis] / groups;
......@@ -101,10 +103,9 @@ class MaxOutOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of MaxOutOpGrad must not be null.");
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
"Output(Grad@X) of MaxOutOpGrad should not be null.");
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "maxout_grad");
OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output",
"X@Grad", "maxout_grad");
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}
};
......
......@@ -28,10 +28,8 @@ class SeluOp : public framework::OperatorWithKernel {
: OperatorWithKernel(type, inputs, outputs, attrs) {}
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of SeluOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of SeluOp should not be null.");
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "selu");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "selu");
ctx->ShareDim("X", /*->*/ "Out");
ctx->ShareLoD("X", /*->*/ "Out");
......@@ -105,9 +103,9 @@ class SeluGradOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null");
PADDLE_ENFORCE(ctx->HasInput("Out"), "Input(Out) should not be null");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
"Out@GRAD", "selu_grad");
OP_INOUT_CHECK(ctx->HasInput("Out"), "Input", "Out", "selu_grad");
auto x_grad_name = framework::GradVarName("X");
ctx->SetOutputDim(x_grad_name, ctx->GetInputDim("Out"));
}
......
......@@ -923,6 +923,8 @@ def hsigmoid(input,
value=0.05), bias_attr=fluid.initializer.Constant(value=.0))
# out = [[0.62792355], [0.62792355], [0.62792355], [0.62792355]]
"""
check_variable_and_dtype(input, 'input', ['float32', 'float64'], 'hsigmoid')
check_variable_and_dtype(label, 'label', ['int64'], 'hsigmoid')
helper = LayerHelper('hierarchical_sigmoid', **locals())
dtype = helper.input_dtype()
......
......@@ -8280,6 +8280,8 @@ def selu(x, scale=None, alpha=None, name=None):
res = exe.run(fluid.default_main_program(), feed={'x':img}, fetch_list=[output])
print(res) # [array([[0. , 1.050701],[2.101402, 3.152103]], dtype=float32)]
"""
check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'selu')
helper = LayerHelper('selu', **locals())
dtype = helper.input_dtype(input_param_name='x')
out = helper.create_variable_for_type_inference(dtype)
......@@ -8888,6 +8890,8 @@ def relu6(x, threshold=6.0, name=None):
# [[0. 0. ]
# [2.5 6. ]]
"""
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], 'relu6')
helper = LayerHelper('relu6', **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(
......@@ -8980,6 +8984,8 @@ def stanh(x, scale_a=0.67, scale_b=1.7159, name=None):
# [0.62705994, 0.23110689, 0.56902856]], dtype=float32)]
"""
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], 'stanh')
helper = LayerHelper('stanh', **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(
......@@ -9014,6 +9020,9 @@ def hard_sigmoid(x, slope=0.2, offset=0.5, name=None):
data = fluid.layers.fill_constant(shape=[3, 2], value=0.5, dtype='float32') # [[0.5, 0.5], [0.5, 0.5], [0.5, 0.5]]
result = fluid.layers.hard_sigmoid(data) # [[0.6, 0.6], [0.6, 0.6], [0.6, 0.6]]
"""
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
'hard_sigmoid')
helper = LayerHelper('hard_sigmoid', **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(
......@@ -9094,6 +9103,8 @@ def swish(x, beta=1.0, name=None):
# array([[-0.03916847, 0.8835007 , -0.25835553],
# [ 0.51126915, 0.82324016, 0.06915068]], dtype=float32)
"""
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], 'swish')
helper = LayerHelper('swish', **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(
......@@ -9293,6 +9304,9 @@ def soft_relu(x, threshold=40.0, name=None):
res = exe.run(fluid.default_main_program(), feed={'x':img}, fetch_list=[output])
print(res) # [array([[0.6931472, 1.3132616], [2.126928 , 3.0485873]], dtype=float32)]
"""
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
'soft_relu')
helper = LayerHelper('soft_relu', **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(
......@@ -11786,6 +11800,8 @@ def maxout(x, groups, name=None, axis=1):
dtype='float32')
out = fluid.layers.maxout(input, groups=2)
"""
check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'maxout')
helper = LayerHelper("maxout", **locals())
if axis not in [1, -1, 3]:
raise ValueError(
......@@ -14005,6 +14021,9 @@ def hard_swish(x, threshold=6.0, scale=6.0, offset=3.0, name=None):
out, = exe.run(feed={'x':x_data}, fetch_list=[y.name])
print(out) # [[0.66666667, 1.66666667,3., 4.]]
"""
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
'hard_swish')
helper = LayerHelper('hard_swish', **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(
......
......@@ -17,6 +17,7 @@ import os
from .layer_function_generator import generate_layer_fn, generate_activation_fn
from .. import core
from ..framework import convert_np_dtype_to_dtype_
from ..data_feeder import check_variable_and_dtype
__activations_noattr__ = [
'sigmoid',
......@@ -64,6 +65,9 @@ _softshrink_ = generate_layer_fn('softshrink')
def softshrink(x, alpha=None):
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
'softshrink')
locals_var = locals().copy()
kwargs = dict()
for name, val in locals_var.items():
......@@ -107,6 +111,9 @@ _hard_shrink_ = generate_layer_fn('hard_shrink')
def hard_shrink(x, threshold=None):
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
'hard_shrink')
locals_var = locals().copy()
kwargs = dict()
for name, val in locals_var.items():
......@@ -163,6 +170,9 @@ _thresholded_relu_ = generate_layer_fn('thresholded_relu')
def thresholded_relu(x, threshold=None):
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
'thresholded_relu')
locals_var = locals().copy()
kwargs = dict()
for name, val in locals_var.items():
......
......@@ -220,6 +220,19 @@ class TestHardShrink(TestActivation):
self.check_grad(['X'], 'Out')
class TestHardShrinkOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program()):
# The input type must be Variable.
self.assertRaises(TypeError, fluid.layers.hard_shrink, 1)
# The input dtype must be float16, float32, float64.
x_int32 = fluid.data(name='x_int32', shape=[12, 10], dtype='int32')
self.assertRaises(TypeError, fluid.layers.hard_shrink, x_int32)
# support the input dtype is float16
x_fp16 = fluid.data(name='x_fp16', shape=[12, 10], dtype='float16')
fluid.layers.hard_shrink(x_fp16)
class TestSoftShrink(TestActivation):
def setUp(self):
self.op_type = "softshrink"
......@@ -241,6 +254,19 @@ class TestSoftShrink(TestActivation):
self.check_grad(['X'], 'Out')
class TestSoftShrinkOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program()):
# The input type must be Variable.
self.assertRaises(TypeError, fluid.layers.softshrink, 1)
# The input dtype must be float16, float32, float64.
x_int32 = fluid.data(name='x_int32', shape=[12, 10], dtype='int32')
self.assertRaises(TypeError, fluid.layers.softshrink, x_int32)
# support the input dtype is float16
x_fp16 = fluid.data(name='x_fp16', shape=[12, 10], dtype='float16')
fluid.layers.softshrink(x_fp16)
class TestSqrt(TestActivation, TestParameter):
def setUp(self):
self.op_type = "sqrt"
......@@ -586,6 +612,19 @@ class TestRelu6(TestActivation):
self.check_grad(['X'], 'Out')
class TestRelu6OpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program()):
# The input type must be Variable.
self.assertRaises(TypeError, fluid.layers.relu6, 1)
# The input dtype must be float16, float32, float64.
x_int32 = fluid.data(name='x_int32', shape=[12, 10], dtype='int32')
self.assertRaises(TypeError, fluid.layers.relu6, x_int32)
# support the input dtype is float16
x_fp16 = fluid.data(name='x_fp16', shape=[12, 10], dtype='float16')
fluid.layers.relu6(x_fp16)
class TestHardSwish(TestActivation):
def setUp(self):
self.op_type = 'hard_swish'
......@@ -610,6 +649,19 @@ class TestHardSwish(TestActivation):
self.check_grad(['X'], 'Out')
class TestHardSwishOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program()):
# The input type must be Variable.
self.assertRaises(TypeError, fluid.layers.hard_swish, 1)
# The input dtype must be float16, float32, float64.
x_int32 = fluid.data(name='x_int32', shape=[12, 10], dtype='int32')
self.assertRaises(TypeError, fluid.layers.hard_swish, x_int32)
# support the input dtype is float16
x_fp16 = fluid.data(name='x_fp16', shape=[12, 10], dtype='float16')
fluid.layers.hard_swish(x_fp16)
class TestSoftRelu(TestActivation):
def setUp(self):
self.op_type = "soft_relu"
......@@ -635,6 +687,19 @@ class TestSoftRelu(TestActivation):
self.check_grad(['X'], 'Out', max_relative_error=0.02)
class TestSoftReluOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program()):
# The input type must be Variable.
self.assertRaises(TypeError, fluid.layers.soft_relu, 1)
# The input dtype must be float16, float32, float64.
x_int32 = fluid.data(name='x_int32', shape=[12, 10], dtype='int32')
self.assertRaises(TypeError, fluid.layers.soft_relu, x_int32)
# support the input dtype is float16
x_fp16 = fluid.data(name='x_fp16', shape=[12, 10], dtype='float16')
fluid.layers.soft_relu(x_fp16)
class TestELU(TestActivation):
def setUp(self):
self.op_type = "elu"
......@@ -812,6 +877,19 @@ class TestSTanh(TestActivation):
self.check_grad(['X'], 'Out')
class TestSTanhOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program()):
# The input type must be Variable.
self.assertRaises(TypeError, fluid.layers.stanh, 1)
# The input dtype must be float16, float32, float64.
x_int32 = fluid.data(name='x_int32', shape=[12, 10], dtype='int32')
self.assertRaises(TypeError, fluid.layers.stanh, x_int32)
# support the input dtype is float16
x_fp16 = fluid.data(name='x_fp16', shape=[12, 10], dtype='float16')
fluid.layers.stanh(x_fp16)
class TestSoftplus(TestActivation):
def setUp(self):
self.op_type = "softplus"
......@@ -870,6 +948,19 @@ class TestThresholdedRelu(TestActivation):
self.check_grad(['X'], 'Out')
class TestThresholdedReluOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program()):
# The input type must be Variable.
self.assertRaises(TypeError, fluid.layers.thresholded_relu, 1)
# The input dtype must be float16, float32, float64.
x_int32 = fluid.data(name='x_int32', shape=[12, 10], dtype='int32')
self.assertRaises(TypeError, fluid.layers.thresholded_relu, x_int32)
# support the input dtype is float16
x_fp16 = fluid.data(name='x_fp16', shape=[12, 10], dtype='float16')
fluid.layers.thresholded_relu(x_fp16)
class TestHardSigmoid(TestActivation):
def setUp(self):
self.op_type = "hard_sigmoid"
......@@ -899,6 +990,19 @@ class TestHardSigmoid(TestActivation):
self.check_grad(['X'], 'Out')
class TestHardSigmoidOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program()):
# The input type must be Variable.
self.assertRaises(TypeError, fluid.layers.hard_sigmoid, 1)
# The input dtype must be float16, float32, float64.
x_int32 = fluid.data(name='x_int32', shape=[12, 10], dtype='int32')
self.assertRaises(TypeError, fluid.layers.hard_sigmoid, x_int32)
# support the input dtype is float16
x_fp16 = fluid.data(name='x_fp16', shape=[12, 10], dtype='float16')
fluid.layers.hard_sigmoid(x_fp16)
class TestSwish(TestActivation):
def setUp(self):
self.op_type = "swish"
......@@ -918,6 +1022,19 @@ class TestSwish(TestActivation):
self.check_grad(['X'], 'Out', max_relative_error=0.008)
class TestSwishOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program()):
# The input type must be Variable.
self.assertRaises(TypeError, fluid.layers.swish, 1)
# The input dtype must be float16, float32, float64.
x_int32 = fluid.data(name='x_int32', shape=[12, 10], dtype='int32')
self.assertRaises(TypeError, fluid.layers.swish, x_int32)
# support the input dtype is float16
x_fp16 = fluid.data(name='x_fp16', shape=[12, 10], dtype='float16')
fluid.layers.swish(x_fp16)
#------------------ Test Cudnn Activation----------------------
def create_test_act_cudnn_class(parent, atol=1e-3, grad_atol=1e-3):
@unittest.skipIf(not core.is_compiled_with_cuda(),
......
......@@ -18,6 +18,7 @@ import unittest
import numpy as np
import paddle.fluid.core as core
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard
import math
from op_test import OpTest, skip_check_grad_ci
......@@ -378,5 +379,27 @@ class TestHSigmoidOpWithCostumTreeWithoutBias(OpTest):
self.check_grad(['X', 'W'], ['Out'], no_grad_set=set('Label'))
class TestHSigmoidOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program()):
label = fluid.data('label', [4, 1], 'int64')
# The input type must be Variable.
self.assertRaises(TypeError, fluid.layers.hsigmoid, 1, label, 2)
# The input dtype must be float16, float32, float64.
x_int32 = fluid.data(name='x_int32', shape=[4, 3], dtype='int32')
self.assertRaises(TypeError, fluid.layers.hsigmoid, x_int32, label,
2)
# support the input dtype is float32
x_fp32 = fluid.data(name='x_fp32', shape=[4, 3], dtype='float32')
fluid.layers.hsigmoid(x_fp32, label, 2)
# The label type must be Variable.
self.assertRaises(TypeError, fluid.layers.hsigmoid, x_fp32, 1, 2)
# The label dtype must be int64.
label_int32 = fluid.data('label_int32', [4, 1], 'int32')
self.assertRaises(TypeError, fluid.layers.hsigmoid, x_fp32,
label_int32, 2)
if __name__ == '__main__':
unittest.main()
......@@ -17,6 +17,7 @@ from __future__ import print_function
import unittest
import numpy as np
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard
import paddle.fluid.core as core
from op_test import OpTest
......@@ -96,5 +97,18 @@ class TestMaxOutOpAxisAPI(unittest.TestCase):
self.assertRaises(ValueError, _attr_axis)
class TestMaxOutOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program()):
# The input type must be Variable.
self.assertRaises(TypeError, fluid.layers.maxout, 1, 2)
# The input dtype must be float16, float32, float64.
x_int32 = fluid.data(name='x_int32', shape=[12, 10], dtype='int32')
self.assertRaises(TypeError, fluid.layers.maxout, x_int32, 2)
# support the input dtype is float32
x_fp32 = fluid.data(name='x_fp32', shape=[12, 10], dtype='float32')
fluid.layers.maxout(x_fp32, 2)
if __name__ == '__main__':
unittest.main()
......@@ -18,6 +18,8 @@ import unittest
import numpy as np
import six
from op_test import OpTest
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard
class SeluTest(OpTest):
......@@ -67,5 +69,18 @@ class SeluTest(OpTest):
self.check_grad(['X'], 'Out')
class TestSeluOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program()):
# The input type must be Variable.
self.assertRaises(TypeError, fluid.layers.selu, 1)
# The input dtype must be float16, float32, float64.
x_int32 = fluid.data(name='x_int32', shape=[12, 10], dtype='int32')
self.assertRaises(TypeError, fluid.layers.selu, x_int32)
# support the input dtype is float32
x_fp32 = fluid.data(name='x_fp32', shape=[12, 10], dtype='float32')
fluid.layers.selu(x_fp32)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册