未验证 提交 43d6abf0 编写于 作者: W wangguanzhong 提交者: GitHub

update conv2d, test=develop (#31480)

上级 50af0c2c
......@@ -175,9 +175,14 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
input_data_type != framework::proto::VarType::UINT8 &&
input_data_type != framework::proto::VarType::BF16) {
auto filter_data_type = ctx.Input<Tensor>("Filter")->type();
PADDLE_ENFORCE_EQ(input_data_type, filter_data_type,
PADDLE_ENFORCE_EQ(
input_data_type, filter_data_type,
platform::errors::InvalidArgument(
"input and filter data type should be consistent"));
"input and filter data type should be consistent, "
"but received input data type is %s and filter type "
"is %s",
paddle::framework::DataTypeToString(input_data_type),
paddle::framework::DataTypeToString(filter_data_type)));
}
if (input_data_type == framework::proto::VarType::FP16) {
PADDLE_ENFORCE_EQ(library, framework::LibraryType::kCUDNN,
......
......@@ -442,5 +442,20 @@ class TestFunctionalConv2DErrorCase10(TestFunctionalConv2DError):
self.data_format = "NHWC"
class TestFunctionalConv2DErrorCase11(TestFunctionalConv2DError):
def setUp(self):
self.in_channels = 3
self.out_channels = 5
self.filter_shape = 3
self.padding = 0
self.stride = 1
self.dilation = 1
self.groups = 1
self.no_bias = False
self.act = "sigmoid"
self.use_cudnn = False
self.data_format = "NHCW"
if __name__ == "__main__":
unittest.main()
......@@ -85,6 +85,12 @@ class _ConvNd(layers.Layer):
"when padding_mode in ['reflect', 'replicate', 'circular'], type of padding must be int"
)
valid_format = {'NHWC', 'NCHW', 'NDHWC', 'NCDHW', 'NLC', 'NCL'}
if data_format not in valid_format:
raise ValueError(
"data_format must be one of {}, but got data_format='{}'".
format(valid_format, data_format))
channel_last = (data_format == "NHWC") or (data_format == "NDHWC") or (
data_format == "NLC")
if channel_last:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册