diff --git a/paddle/fluid/operators/batch_norm_op.cc b/paddle/fluid/operators/batch_norm_op.cc index 9a1d724c73962e37f71102afd65c49bbc14088cb..72c023dd9924351543a496a70645e5aa876cc639 100644 --- a/paddle/fluid/operators/batch_norm_op.cc +++ b/paddle/fluid/operators/batch_norm_op.cc @@ -58,8 +58,18 @@ void BatchNormOp::InferShape(framework::InferShapeContext *ctx) const { const DataLayout data_layout = framework::StringToDataLayout( ctx->Attrs().Get("data_layout")); - PADDLE_ENFORCE(x_dims.size() >= 2 && x_dims.size() <= 5, - "Input X must have 2 to 5 dimensions."); + PADDLE_ENFORCE_GE( + x_dims.size(), 2, + "ShapeError: the dimension of input X must greater than or equal to 2." + "But received: the shape of input X = [%s], the dimension of input X =" + "[%d]", + x_dims, x_dims.size()); + PADDLE_ENFORCE_LE( + x_dims.size(), 5, + "ShapeError: the dimension of input X must smaller than or equal to 5." + "But received: the shape of input X = [%s], the dimension of input X =" + "[%d]", + x_dims, x_dims.size()); const int64_t C = (data_layout == DataLayout::kNCHW ? x_dims[1] @@ -68,8 +78,16 @@ void BatchNormOp::InferShape(framework::InferShapeContext *ctx) const { auto scale_dim = ctx->GetInputDim("Scale"); auto bias_dim = ctx->GetInputDim("Bias"); - PADDLE_ENFORCE_EQ(scale_dim.size(), 1UL); - PADDLE_ENFORCE_EQ(scale_dim.size(), 1UL); + PADDLE_ENFORCE_EQ(scale_dim.size(), 1UL, + "ShapeError: the dimension of scale must equal to 1." + "But received: the shape of scale is [%s], the dimension " + "of scale is [%d]", + scale_dim, scale_dim.size()); + PADDLE_ENFORCE_EQ( + bias_dim.size(), 1UL, + "ShapeError: the dimension of bias must equal to 1." + "But received: the shape of bias is [%s],the dimension of bias is [%d]", + bias_dim, bias_dim.size()); bool check = true; if ((!ctx->IsRuntime()) && (framework::product(scale_dim) <= 0 || @@ -78,8 +96,14 @@ void BatchNormOp::InferShape(framework::InferShapeContext *ctx) const { } if (check) { - PADDLE_ENFORCE_EQ(scale_dim[0], C); - PADDLE_ENFORCE_EQ(scale_dim[0], C); + PADDLE_ENFORCE_EQ(scale_dim[0], C, + "ShapeError: the shape of scale must equal to [%d]" + "But received: the shape of scale is [%d]", + C, scale_dim[0]); + PADDLE_ENFORCE_EQ(bias_dim[0], C, + "ShapeError: the shape of bias must equal to [%d]" + "But received: the shape of bias is [%d]", + C, bias_dim[0]); } ctx->SetOutputDim("Y", x_dims); ctx->SetOutputDim("MeanOut", {C}); diff --git a/paddle/fluid/operators/conv_op.cc b/paddle/fluid/operators/conv_op.cc index 1230b848fddd6b951b52b187d8febe27819478a0..7e8a94c9e506dc72a4439868f6e9ec63f4a17bf0 100644 --- a/paddle/fluid/operators/conv_op.cc +++ b/paddle/fluid/operators/conv_op.cc @@ -51,25 +51,49 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const { const bool channel_last = (data_format == "NHWC" || data_format == "NDHWC"); PADDLE_ENFORCE_EQ(in_dims.size() == 4 || in_dims.size() == 5, true, - "Conv intput should be 4-D or 5-D tensor, get %u", - in_dims.size()); + "ShapeError: Conv input should be 4-D or 5-D tensor. But " + "received: %u-D Tensor," + "the shape of Conv input is [%s]", + in_dims.size(), in_dims); PADDLE_ENFORCE_EQ( in_dims.size(), filter_dims.size(), - "Conv input dimension and filter dimension should be the same."); - PADDLE_ENFORCE_EQ( - in_dims.size() - strides.size() == 2U, true, - "Conv input dimension and strides dimension should be consistent."); + "ShapeError: Conv input dimension and filter dimension should be the " + "equal." + "But received: the shape of Conv input is [%s], input dimension of Conv " + "input is [%d]," + "the shape of filter is [%s], the filter dimension of Conv is [%d]", + in_dims, in_dims.size(), filter_dims, filter_dims.size()); + + int in_sub_stride_size = in_dims.size() - strides.size(); + PADDLE_ENFORCE_EQ(in_dims.size() - strides.size() == 2U, true, + "ShapeError: the dimension of input minus the dimension of " + "stride must be euqal to 2." + "But received: the dimension of input minus the dimension " + "of stride is [%d], the" + "input dimension of Conv is [%d], the shape of Conv input " + "is [%s], the stride" + "dimension of Conv is [%d]", + in_sub_stride_size, in_dims.size(), in_dims, + strides.size()); const auto input_channels = channel_last ? in_dims[in_dims.size() - 1] : in_dims[1]; - PADDLE_ENFORCE_EQ(input_channels, filter_dims[1] * groups, - "The number of input channels should be equal to filter " - "channels * groups."); + PADDLE_ENFORCE_EQ( + input_channels, filter_dims[1] * groups, + "ShapeError: The number of input channels should be equal to filter " + "channels * groups. But received: the input channels is [%d], the shape" + "of input is [%s], the filter channel is [%d], the shape of filter is " + "[%s]," + "the groups is [%d]", + in_dims[1], in_dims, filter_dims[1], filter_dims, groups); PADDLE_ENFORCE_EQ( filter_dims[0] % groups, 0, - "The number of output channels should be divided by groups."); + "ShapeError: The number of output channels should be divided by groups." + "But received: the output channels is [%d], the shape of filter is [%s]" + "(the first dimension of filter is output channel), the groups is [%d]", + filter_dims[0], filter_dims, groups); framework::DDim in_data_dims; if (channel_last) { diff --git a/paddle/fluid/operators/cross_entropy_op.cc b/paddle/fluid/operators/cross_entropy_op.cc index 624b2b9c00de1e6812496a9164a4189c27e87146..66986e14d8f4a25c57c2c90c422171564874239f 100644 --- a/paddle/fluid/operators/cross_entropy_op.cc +++ b/paddle/fluid/operators/cross_entropy_op.cc @@ -41,30 +41,56 @@ class CrossEntropyOpBase : public framework::OperatorWithKernel { bool check = ctx->IsRuntime() || !contain_unknown_dim; if (check) { - PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1), - framework::slice_ddim(label_dims, 0, rank - 1), - "Input(X) and Input(Label) shall have the same shape " - "except the last dimension."); + PADDLE_ENFORCE_EQ( + framework::slice_ddim(x_dims, 0, rank - 1), + framework::slice_ddim(label_dims, 0, rank - 1), + "ShapeError: Input(X) and Input(Label) shall have the same shape " + "except the last dimension. But received: the shape of Input(X) is " + "[%s]," + "the shape of Input(Label) is [%s].", + x_dims, label_dims); } if (IsSoftLabel(ctx)) { PADDLE_ENFORCE_EQ( rank, label_dims.size(), - "If Attr(soft_label) == true, Input(X) and Input(Label) " - "shall have the same rank."); + "ShapeError: If Attr(soft_label) == true, Input(X) and Input(Label) " + "shall have the same dimensions. But received: the dimensions of " + "Input(X) is [%d]," + "the shape of Input(X) is [%s], the dimensions of Input(Label) is " + "[%d], the shape of" + "Input(Label) is [%s]", + rank, x_dims, label_dims.size(), label_dims); + if (check) { - PADDLE_ENFORCE_EQ(x_dims[rank - 1], label_dims[rank - 1], - "If Attr(soft_label) == true, the last dimension of " - "Input(X) and Input(Label) should be equal."); + PADDLE_ENFORCE_EQ( + x_dims[rank - 1], label_dims[rank - 1], + "ShapeError: If Attr(soft_label) == true, the last dimension of " + "Input(X) and Input(Label) should be equal. But received: the" + "last dimension of Input(X) is [%d], the shape of Input(X) is [%s]," + "the last dimension of Input(Label) is [%d], the shape of " + "Input(Label)" + "is [%s], the last dimension is [%d].", + x_dims[rank - 1], x_dims, label_dims[rank - 1], label_dims, + rank - 1); } } else { if (rank == label_dims.size()) { - PADDLE_ENFORCE_EQ(label_dims[rank - 1], 1UL, - "the last dimension of Input(Label) should be 1."); - } else { PADDLE_ENFORCE_EQ( - rank, label_dims.size() + 1, - "The rank of Input(X) should be equal to Input(Label) plus 1."); + label_dims[rank - 1], 1UL, + "ShapeError: the last dimension of Input(Label) should be 1." + "But received: the last dimension of Input(Label) is [%d]," + "the last dimension is [%d]", + label_dims[rank - 1], rank - 1); + } else { + PADDLE_ENFORCE_EQ(rank, label_dims.size() + 1, + "ShapeError: The rank of Input(X) should be equal to " + "Input(Label) plus 1." + "But received: The dimension of Input(X) is [%d], " + "the shape of Input(X) is [%s]," + "the dimension of Input(Label) is [%d], the shape of " + "Input(Label) is [%s]", + rank, x_dims, label_dims.size(), label_dims); } } diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 261ebd9fb30907bee3ca3f94f12af70a390d9ee5..c7bc28091df3c97b20be081899bfa8020ffe8c6f 100755 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -1677,6 +1677,20 @@ def dropout(x, """ helper = LayerHelper('dropout', **locals()) + + if not isinstance(x, Variable): + raise TypeError( + "The type of 'input' in dropout must be Variable, but received %s" % + (type(x))) + if convert_dtype(x.dtype) in ['float16']: + warnings.warn( + "The data type of 'input' in dropout only support float16 on GPU now." + ) + if convert_dtype(x.dtype) not in ['float16', 'float32', 'float64']: + raise TypeError( + "The data type of 'input' in dropout must be float16 or float32 or float64, but received %s." + % (convert_dtype(x.dtype))) + out = helper.create_variable_for_type_inference(dtype=x.dtype) mask = helper.create_variable_for_type_inference( dtype=core.VarDesc.VarType.UINT8, stop_gradient=True) @@ -1749,6 +1763,19 @@ def cross_entropy(input, label, soft_label=False, ignore_index=kIgnoreIndex): predict = fluid.layers.fc(input=x, size=class_num, act='softmax') cost = fluid.layers.cross_entropy(input=predict, label=label) """ + if not isinstance(input, Variable): + raise TypeError( + "The type of 'input' in cross_entropy must be Variable, but received %s" + % (type(input))) + if convert_dtype(input.dtype) in ['float16']: + warnings.warn( + "The data type of 'input' in cross_entropy only support float16 on GPU now." + ) + if convert_dtype(input.dtype) not in ['float16', 'float32', 'float64']: + raise TypeError( + "The data type of 'input' in cross_entropy must be float16 or float32 or float64, but received %s." + % (convert_dtype(input.dtype))) + if not soft_label: return cross_entropy2(input, label, ignore_index) helper = LayerHelper('cross_entropy', **locals()) @@ -2397,6 +2424,20 @@ def conv2d(input, conv2d = fluid.layers.conv2d(input=data, num_filters=2, filter_size=3, act="relu") """ + if not isinstance(input, Variable): + raise TypeError( + "The type of 'input' in conv2d must be Variable, but received %s" % + (type(input))) + if convert_dtype(input.dtype) in ['float16']: + warnings.warn( + "The data type of 'input' in conv2d only support float16 on GPU now." + ) + if convert_dtype(input.dtype) not in ['float16', 'float32', 'float64']: + raise TypeError( + "The data type of 'input' in conv2d must be float16 or float32 or float64, but received %s." + % (convert_dtype(input.dtype))) + + num_channels = input.shape[1] if not isinstance(use_cudnn, bool): raise ValueError("Attr(use_cudnn) should be True or False. Received " "Attr(use_cudnn): %s. " % str(use_cudnn)) @@ -2427,9 +2468,9 @@ def conv2d(input, else: if num_channels % groups != 0: raise ValueError( - "The number of input channels must be divisible by Attr(groups). " - "Received: number of channels(%s), groups(%s)." % - (str(num_channels), str(groups))) + "the channel of input must be divisible by groups," + "received: the channel of input is {}, the shape of input is {}" + ", the groups is {}".format(num_channels, input.shape, groups)) num_filter_channels = num_channels // groups filter_size = utils.convert_to_list(filter_size, 2, 'filter_size') @@ -3740,8 +3781,21 @@ def batch_norm(input, """ assert bias_attr is not False, "bias_attr should not be False in batch_norm." helper = LayerHelper('batch_norm', **locals()) - dtype = helper.input_dtype() + if not isinstance(input, Variable): + raise TypeError( + "The type of 'input' in batch_norm must be Variable, but received %s" + % (type(input))) + if convert_dtype(input.dtype) in ['float16']: + warnings.warn( + "The data type of 'input' in batch_norm only support float16 on GPU now." + ) + if convert_dtype(input.dtype) not in ['float16', 'float32', 'float64']: + raise TypeError( + "The data type of 'input' in batch_norm must be float16 or float32 or float64, but received %s." + % (convert_dtype(input.dtype))) + + dtype = helper.input_dtype() # use fp32 for bn parameter if dtype == core.VarDesc.VarType.FP16: dtype = core.VarDesc.VarType.FP32 diff --git a/python/paddle/fluid/tests/unittests/test_batch_norm_op.py b/python/paddle/fluid/tests/unittests/test_batch_norm_op.py index ec96e5f79ca39998dad8d2222cecd573d477ce5b..f26fc47f430e2995041b3ef6f50b81f05437e8ed 100644 --- a/python/paddle/fluid/tests/unittests/test_batch_norm_op.py +++ b/python/paddle/fluid/tests/unittests/test_batch_norm_op.py @@ -21,6 +21,8 @@ from paddle.fluid.op import Operator import paddle.fluid as fluid from op_test import OpTest from paddle.fluid.framework import grad_var_name +import paddle.fluid as fluid +from paddle.fluid import Program, program_guard def _reference_testing(x, scale, offset, mean, var, epsilon, data_format): @@ -530,5 +532,19 @@ class TestBatchNormOpFreezeStatsAndScaleBiasTraining( self.fetch_list = ['y', 'mean', 'variance', 'x@GRAD'] +class TestBatchNormOpError(OpTest): + def test_errors(self): + with program_guard(Program(), Program()): + # the input of batch_norm must be Variable. + x1 = fluid.create_lod_tensor( + np.array([-1, 3, 5, 5]), [[1, 1, 1, 1]], fluid.CPUPlace()) + self.assertRaises(TypeError, fluid.layers.batch_norm, x1) + + # the input dtype of batch_norm must be float16 or float32 or float64 + # float16 only can be set on GPU place + x2 = fluid.layers.data(name='x2', shape=[3, 4, 5, 6], dtype="int32") + self.assertRaises(TypeError, fluid.layers.batch_norm, x2) + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_conv2d_op.py b/python/paddle/fluid/tests/unittests/test_conv2d_op.py index c9dd714f2a5562c083d039d23de85dd057ba262b..e7388f415aa7a5c87f7f025cdec727323c61dd53 100644 --- a/python/paddle/fluid/tests/unittests/test_conv2d_op.py +++ b/python/paddle/fluid/tests/unittests/test_conv2d_op.py @@ -18,8 +18,9 @@ import unittest import numpy as np import paddle.fluid.core as core -from op_test import OpTest import paddle.fluid as fluid +from op_test import OpTest +from paddle.fluid import Program, program_guard def conv2d_forward_naive(input, @@ -647,6 +648,28 @@ class TestCUDNNExhaustiveSearch(TestConv2dOp): self.exhaustive_search = True +class TestConv2dOpError(OpTest): + def test_errors(self): + with program_guard(Program(), Program()): + + def test_Variable(): + # the input of conv2d must be Variable. + x1 = fluid.create_lod_tensor( + np.array([-1, 3, 5, 5]), [[1, 1, 1, 1]], fluid.CPUPlace()) + fluid.layers.conv2d(x1, 1, 1) + + self.assertRaises(TypeError, test_Variable) + + def test_dtype(): + # the input dtype of conv2d must be float16 or float32 or float64 + # float16 only can be set on GPU place + x2 = fluid.layers.data( + name='x2', shape=[3, 4, 5, 6], dtype="int32") + fluid.layers.conv2d(x2, 1, 1) + + self.assertRaises(TypeError, test_dtype) + + # Please Don't remove the following code. # Currently, CI use cudnn V5.0 which not support dilation conv. # class TestCUDNNWithDilation(TestWithDilation): diff --git a/python/paddle/fluid/tests/unittests/test_cross_entropy_op.py b/python/paddle/fluid/tests/unittests/test_cross_entropy_op.py index fc8484df2d5e219a6ecc335cd00c735119de7f32..613f074b4a8a40c1ee595df877a44e89e87ec36b 100644 --- a/python/paddle/fluid/tests/unittests/test_cross_entropy_op.py +++ b/python/paddle/fluid/tests/unittests/test_cross_entropy_op.py @@ -18,6 +18,8 @@ import unittest import numpy as np import paddle.fluid.core as core from op_test import OpTest, randomize_probability +import paddle.fluid as fluid +from paddle.fluid import Program, program_guard class TestCrossEntropyOp(OpTest): @@ -356,5 +358,32 @@ create_test_class(TestCrossEntropyOp7, "TestCrossEntropyF16Op7") create_test_class(TestCrossEntropyOp7RemoveLastDim, "TestCrossEntropyF16Op7RemoveLastDim") + +class TestCrossEntropyOpError(OpTest): + def test_errors(self): + with program_guard(Program(), Program()): + + def test_Variable(): + # the input of cross_entropy must be Variable. + x1 = fluid.create_lod_tensor( + np.array([-1, 3, 5, 5]), [[1, 1, 1, 1]], fluid.CPUPlace()) + lab1 = fluid.create_lod_tensor( + np.array([-1, 3, 5, 5]), [[1, 1, 1, 1]], fluid.CPUPlace()) + fluid.layers.cross_entropy(x1, lab1) + + self.assertRaises(TypeError, test_Variable) + + def test_dtype(): + # the input dtype of cross_entropy must be float16 or float32 or float64 + # float16 only can be set on GPU place + x2 = fluid.layers.data( + name='x2', shape=[3, 4, 5, 6], dtype="int32") + lab2 = fluid.layers.data( + name='lab2', shape=[3, 4, 5, 6], dtype="int32") + fluid.layers.cross_entropy(x2, lab2) + + self.assertRaises(TypeError, test_dtype) + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_dropout_op.py b/python/paddle/fluid/tests/unittests/test_dropout_op.py index 59918a7bb21c42359f7d6c4f6109ca4b1cdc4449..08ec1fce8d3dad06c10686914fd1f834076cef8e 100644 --- a/python/paddle/fluid/tests/unittests/test_dropout_op.py +++ b/python/paddle/fluid/tests/unittests/test_dropout_op.py @@ -18,6 +18,8 @@ import unittest import numpy as np import paddle.fluid.core as core from op_test import OpTest +import paddle.fluid as fluid +from paddle.fluid import Program, program_guard class TestDropoutOp(OpTest): @@ -180,5 +182,27 @@ class TestFP16DropoutOp2(TestFP16DropoutOp): self.fix_seed = False +class TestDropoutOpError(OpTest): + def test_errors(self): + with program_guard(Program(), Program()): + + def test_Variable(): + # the input of dropout must be Variable. + x1 = fluid.create_lod_tensor( + np.array([-1, 3, 5, 5]), [[1, 1, 1, 1]], fluid.CPUPlace()) + fluid.layers.dropout(x1, dropout_prob=0.5) + + self.assertRaises(TypeError, test_Variable) + + def test_dtype(): + # the input dtype of dropout must be float16 or float32 or float64 + # float16 only can be set on GPU place + x2 = fluid.layers.data( + name='x2', shape=[3, 4, 5, 6], dtype="int32") + fluid.layers.dropout(x2, dropout_prob=0.5) + + self.assertRaises(TypeError, test_dtype) + + if __name__ == '__main__': unittest.main()