diff --git a/paddle/fluid/operators/conv_op.cc b/paddle/fluid/operators/conv_op.cc index f3dd0dcb46c36fb666c0f4766446098f939a9b6c..85bb4e5baa058a4cc5e6e4b9e1aec9ac75b3c5ea 100644 --- a/paddle/fluid/operators/conv_op.cc +++ b/paddle/fluid/operators/conv_op.cc @@ -175,9 +175,14 @@ framework::OpKernelType ConvOp::GetExpectedKernelType( input_data_type != framework::proto::VarType::UINT8 && input_data_type != framework::proto::VarType::BF16) { auto filter_data_type = ctx.Input("Filter")->type(); - PADDLE_ENFORCE_EQ(input_data_type, filter_data_type, - platform::errors::InvalidArgument( - "input and filter data type should be consistent")); + PADDLE_ENFORCE_EQ( + input_data_type, filter_data_type, + platform::errors::InvalidArgument( + "input and filter data type should be consistent, " + "but received input data type is %s and filter type " + "is %s", + paddle::framework::DataTypeToString(input_data_type), + paddle::framework::DataTypeToString(filter_data_type))); } if (input_data_type == framework::proto::VarType::FP16) { PADDLE_ENFORCE_EQ(library, framework::LibraryType::kCUDNN, diff --git a/python/paddle/fluid/tests/unittests/test_functional_conv2d.py b/python/paddle/fluid/tests/unittests/test_functional_conv2d.py index 68be0bf5d561ef0d8fe92005dd9ddb47c21aca51..766e1bb1d34af21c073ef5cbd83d3feb27efe8cb 100644 --- a/python/paddle/fluid/tests/unittests/test_functional_conv2d.py +++ b/python/paddle/fluid/tests/unittests/test_functional_conv2d.py @@ -442,5 +442,20 @@ class TestFunctionalConv2DErrorCase10(TestFunctionalConv2DError): self.data_format = "NHWC" +class TestFunctionalConv2DErrorCase11(TestFunctionalConv2DError): + def setUp(self): + self.in_channels = 3 + self.out_channels = 5 + self.filter_shape = 3 + self.padding = 0 + self.stride = 1 + self.dilation = 1 + self.groups = 1 + self.no_bias = False + self.act = "sigmoid" + self.use_cudnn = False + self.data_format = "NHCW" + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/nn/layer/conv.py b/python/paddle/nn/layer/conv.py index 2c6308d11292563bcc512a6cee23b1e0a33d6a3c..389920b923876dff9d6c663a607e7b8752efd7f1 100644 --- a/python/paddle/nn/layer/conv.py +++ b/python/paddle/nn/layer/conv.py @@ -85,6 +85,12 @@ class _ConvNd(layers.Layer): "when padding_mode in ['reflect', 'replicate', 'circular'], type of padding must be int" ) + valid_format = {'NHWC', 'NCHW', 'NDHWC', 'NCDHW', 'NLC', 'NCL'} + if data_format not in valid_format: + raise ValueError( + "data_format must be one of {}, but got data_format='{}'". + format(valid_format, data_format)) + channel_last = (data_format == "NHWC") or (data_format == "NDHWC") or ( data_format == "NLC") if channel_last: