From 43d6abf0a550faa973fc096acde85a5a2fb23516 Mon Sep 17 00:00:00 2001 From: wangguanzhong Date: Tue, 9 Mar 2021 14:47:50 +0800 Subject: [PATCH] update conv2d, test=develop (#31480) --- paddle/fluid/operators/conv_op.cc | 11 ++++++++--- .../tests/unittests/test_functional_conv2d.py | 15 +++++++++++++++ python/paddle/nn/layer/conv.py | 6 ++++++ 3 files changed, 29 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/operators/conv_op.cc b/paddle/fluid/operators/conv_op.cc index f3dd0dcb46c..85bb4e5baa0 100644 --- a/paddle/fluid/operators/conv_op.cc +++ b/paddle/fluid/operators/conv_op.cc @@ -175,9 +175,14 @@ framework::OpKernelType ConvOp::GetExpectedKernelType( input_data_type != framework::proto::VarType::UINT8 && input_data_type != framework::proto::VarType::BF16) { auto filter_data_type = ctx.Input("Filter")->type(); - PADDLE_ENFORCE_EQ(input_data_type, filter_data_type, - platform::errors::InvalidArgument( - "input and filter data type should be consistent")); + PADDLE_ENFORCE_EQ( + input_data_type, filter_data_type, + platform::errors::InvalidArgument( + "input and filter data type should be consistent, " + "but received input data type is %s and filter type " + "is %s", + paddle::framework::DataTypeToString(input_data_type), + paddle::framework::DataTypeToString(filter_data_type))); } if (input_data_type == framework::proto::VarType::FP16) { PADDLE_ENFORCE_EQ(library, framework::LibraryType::kCUDNN, diff --git a/python/paddle/fluid/tests/unittests/test_functional_conv2d.py b/python/paddle/fluid/tests/unittests/test_functional_conv2d.py index 68be0bf5d56..766e1bb1d34 100644 --- a/python/paddle/fluid/tests/unittests/test_functional_conv2d.py +++ b/python/paddle/fluid/tests/unittests/test_functional_conv2d.py @@ -442,5 +442,20 @@ class TestFunctionalConv2DErrorCase10(TestFunctionalConv2DError): self.data_format = "NHWC" +class TestFunctionalConv2DErrorCase11(TestFunctionalConv2DError): + def setUp(self): + self.in_channels = 3 + self.out_channels = 5 + self.filter_shape = 3 + self.padding = 0 + self.stride = 1 + self.dilation = 1 + self.groups = 1 + self.no_bias = False + self.act = "sigmoid" + self.use_cudnn = False + self.data_format = "NHCW" + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/nn/layer/conv.py b/python/paddle/nn/layer/conv.py index 2c6308d1129..389920b9238 100644 --- a/python/paddle/nn/layer/conv.py +++ b/python/paddle/nn/layer/conv.py @@ -85,6 +85,12 @@ class _ConvNd(layers.Layer): "when padding_mode in ['reflect', 'replicate', 'circular'], type of padding must be int" ) + valid_format = {'NHWC', 'NCHW', 'NDHWC', 'NCDHW', 'NLC', 'NCL'} + if data_format not in valid_format: + raise ValueError( + "data_format must be one of {}, but got data_format='{}'". + format(valid_format, data_format)) + channel_last = (data_format == "NHWC") or (data_format == "NDHWC") or ( data_format == "NLC") if channel_last: -- GitLab