diff --git a/python/paddle/fluid/tests/unittests/parallel_dygraph_sync_batch_norm.py b/python/paddle/fluid/tests/unittests/parallel_dygraph_sync_batch_norm.py
index 5e2059592b5170fdd623f4a20b9fa47612ff2a6a..1320623f8f8422f14677a3ca629735838dc94aa8 100644
--- a/python/paddle/fluid/tests/unittests/parallel_dygraph_sync_batch_norm.py
+++ b/python/paddle/fluid/tests/unittests/parallel_dygraph_sync_batch_norm.py
@@ -26,7 +26,7 @@ import paddle.fluid as fluid
 import paddle.fluid.dygraph as dygraph
 from paddle.fluid import core
 from paddle.fluid.optimizer import SGDOptimizer
-from paddle.nn import Conv2D, Pool2D, Linear, SyncBatchNorm
+from paddle.nn import Conv2d, Pool2D, Linear, SyncBatchNorm
 from paddle.fluid.dygraph.base import to_variable
 from test_dist_base import runtime_main, TestParallelDyGraphRunnerBase
 
@@ -42,26 +42,24 @@ class TestLayer(fluid.dygraph.Layer):
                  act=None):
         super(TestLayer, self).__init__()
 
-        self._conv = Conv2D(
-            num_channels=num_channels,
-            num_filters=num_filters,
-            filter_size=filter_size,
+        self._conv = Conv2d(
+            in_channels=num_channels,
+            out_channels=num_filters,
+            kernel_size=filter_size,
             stride=stride,
             padding=(filter_size - 1) // 2,
             groups=groups,
-            act=None,
             bias_attr=False)
 
         self._sync_batch_norm = SyncBatchNorm(num_filters)
 
-        self._conv2 = Conv2D(
-            num_channels=num_filters,
-            num_filters=num_filters,
-            filter_size=filter_size,
+        self._conv2 = Conv2d(
+            in_channels=num_filters,
+            out_channels=num_filters,
+            kernel_size=filter_size,
             stride=stride,
             padding=(filter_size - 1) // 2,
             groups=groups,
-            act=None,
             bias_attr=False)
 
         self._sync_batch_norm2 = SyncBatchNorm(
diff --git a/python/paddle/fluid/tests/unittests/test_conv2d_layer.py b/python/paddle/fluid/tests/unittests/test_conv2d_layer.py
index 64653ce2e7b8630030094b4004ecb17d56d3ff43..6bfe2aca530ddea6b49f12ad34dd9672e2a99ab5 100644
--- a/python/paddle/fluid/tests/unittests/test_conv2d_layer.py
+++ b/python/paddle/fluid/tests/unittests/test_conv2d_layer.py
@@ -20,6 +20,10 @@ import paddle.fluid.initializer as I
 import unittest
 
 
+def _reverse_repeat_list(t, n):
+    return list(x for x in reversed(t) for _ in range(n))
+
+
 class Conv2DTestCase(unittest.TestCase):
     def __init__(self,
                  methodName='runTest',
@@ -29,12 +33,11 @@ class Conv2DTestCase(unittest.TestCase):
                  num_filters=8,
                  filter_size=3,
                  padding=0,
+                 padding_mode='zeros',
                  stride=1,
                  dilation=1,
                  groups=1,
-                 act=None,
                  no_bias=False,
-                 use_cudnn=True,
                  data_format="NCHW",
                  dtype="float32"):
         super(Conv2DTestCase, self).__init__(methodName)
@@ -45,12 +48,16 @@ class Conv2DTestCase(unittest.TestCase):
         self.filter_size = filter_size
         self.padding = padding
+        if padding_mode in {'reflect', 'replicate', 'circular'}:
+            _paired_padding = fluid.layers.utils.convert_to_list(padding, 2,
+                                                                 'padding')
+            self._reversed_padding_repeated_twice = _reverse_repeat_list(
+                _paired_padding, 2)
+        self.padding_mode = padding_mode
         self.stride = stride
         self.dilation = dilation
         self.groups = groups
-        self.act = act
         self.no_bias = no_bias
-        self.use_cudnn = use_cudnn
         self.data_format = data_format
         self.dtype = dtype
 
@@ -91,19 +98,27 @@ class Conv2DTestCase(unittest.TestCase):
             bias_attr = False
         else:
             bias_attr = I.NumpyArrayInitializer(self.bias)
+        if self.padding_mode != 'zeros':
+            x_var = F.pad(x_var,
+                          self._reversed_padding_repeated_twice,
+                          mode=self.padding_mode,
+                          data_format=self.data_format)
+            padding = 0
+        else:
+            padding = self.padding
+
         y_var = fluid.layers.conv2d(
             x_var,
             self.num_filters,
             self.filter_size,
-            padding=self.padding,
+            padding=padding,
             stride=self.stride,
             dilation=self.dilation,
             groups=self.groups,
             param_attr=weight_attr,
             bias_attr=bias_attr,
-            use_cudnn=self.use_cudnn,
-            act=self.act,
             data_format=self.data_format)
+
         feed_dict = {"input": self.input}
         exe = fluid.Executor(place)
         exe.run(start)
@@ -122,16 +137,24 @@ class Conv2DTestCase(unittest.TestCase):
                 "weight", self.weight_shape, dtype=self.dtype)
             b_var = fluid.data(
                 "bias", (self.num_filters, ), dtype=self.dtype)
+
+            if self.padding_mode != 'zeros':
+                x_var = F.pad(x_var,
+                              self._reversed_padding_repeated_twice,
+                              mode=self.padding_mode,
+                              data_format=self.data_format)
+                padding = 0
+            else:
+                padding = self.padding
+
             y_var = F.conv2d(
                 x_var,
                 w_var,
                 b_var if not self.no_bias else None,
-                padding=self.padding,
+                padding=padding,
                 stride=self.stride,
                 dilation=self.dilation,
                 groups=self.groups,
-                act=self.act,
-                use_cudnn=self.use_cudnn,
                 data_format=self.data_format)
         feed_dict = {"input": self.input, "weight": self.weight}
         if self.bias is not None:
@@ -143,18 +166,16 @@ class Conv2DTestCase(unittest.TestCase):
 
     def paddle_nn_layer(self):
         x_var = dg.to_variable(self.input)
-        conv = nn.Conv2D(
+        conv = nn.Conv2d(
             self.num_channels,
             self.num_filters,
             self.filter_size,
             padding=self.padding,
+            padding_mode=self.padding_mode,
             stride=self.stride,
             dilation=self.dilation,
             groups=self.groups,
-            act=self.act,
-            use_cudnn=self.use_cudnn,
-            data_format=self.data_format,
-            dtype=self.dtype)
+            data_format=self.data_format)
         conv.weight.set_value(self.weight)
         if not self.no_bias:
             conv.bias.set_value(self.bias)
@@ -198,7 +219,7 @@ def add_cases(suite):
             methodName='runTest', stride=2, dilation=(2, 1)))
     suite.addTest(
         Conv2DTestCase(
-            methodName='runTest', padding="same", no_bias=True, act="sigmoid"))
+            methodName='runTest', padding="same", no_bias=True))
     suite.addTest(
         Conv2DTestCase(
             methodName='runTest', filter_size=(3, 3), padding='valid'))
@@ -222,15 +243,28 @@ def add_cases(suite):
             num_filters=6,
             num_channels=3,
             groups=3,
-            use_cudnn=False,
-            act="sigmoid",
             padding="valid"))
+    suite.addTest(
+        Conv2DTestCase(
+            methodName='runTest',
+            filter_size=(3, 3),
+            padding=1,
+            padding_mode='reflect'))
+    suite.addTest(
+        Conv2DTestCase(
+            methodName='runTest',
+            filter_size=(3, 3),
+            padding=1,
+            padding_mode='replicate'))
+    suite.addTest(
+        Conv2DTestCase(
+            methodName='runTest',
+            filter_size=(3, 3),
+            padding=1,
+            padding_mode='circular'))
 
 
 def add_error_cases(suite):
-    suite.addTest(
-        Conv2DErrorTestCase(
-            methodName='runTest', use_cudnn="not_valid"))
     suite.addTest(
         Conv2DErrorTestCase(
             methodName='runTest', num_channels=5, groups=2))
diff --git a/python/paddle/fluid/tests/unittests/test_conv3d_layer.py b/python/paddle/fluid/tests/unittests/test_conv3d_layer.py
index cf582c6210b76c6546de6d09d9219dbf4005bb17..56355a1c95e0396d0dec53cae02c3a99bf874013 100644
--- a/python/paddle/fluid/tests/unittests/test_conv3d_layer.py
+++ b/python/paddle/fluid/tests/unittests/test_conv3d_layer.py
@@ -32,9 +32,7 @@ class Conv3DTestCase(unittest.TestCase):
                  stride=1,
                  dilation=1,
                  groups=1,
-                 act=None,
                  no_bias=False,
-                 use_cudnn=True,
                  data_format="NCDHW",
                  dtype="float32"):
         super(Conv3DTestCase, self).__init__(methodName)
@@ -48,9 +46,7 @@ class Conv3DTestCase(unittest.TestCase):
         self.stride = stride
         self.dilation = dilation
         self.groups = groups
-        self.act = act
         self.no_bias = no_bias
-        self.use_cudnn = use_cudnn
         self.data_format = data_format
         self.dtype = dtype
 
@@ -101,8 +97,6 @@ class Conv3DTestCase(unittest.TestCase):
             groups=self.groups,
             param_attr=weight_attr,
             bias_attr=bias_attr,
-            use_cudnn=self.use_cudnn,
-            act=self.act,
             data_format=self.data_format)
         feed_dict = {"input": self.input}
         exe = fluid.Executor(place)
@@ -130,8 +124,6 @@ class Conv3DTestCase(unittest.TestCase):
                 stride=self.stride,
                 dilation=self.dilation,
                 groups=self.groups,
-                act=self.act,
-                use_cudnn=self.use_cudnn,
                 data_format=self.data_format)
         feed_dict = {"input": self.input, "weight": self.weight}
         if self.bias is not None:
@@ -143,7 +135,7 @@ class Conv3DTestCase(unittest.TestCase):
 
     def paddle_nn_layer(self):
         x_var = dg.to_variable(self.input)
-        conv = nn.Conv3D(
+        conv = nn.Conv3d(
             self.num_channels,
             self.num_filters,
             self.filter_size,
@@ -151,10 +143,7 @@ class Conv3DTestCase(unittest.TestCase):
             stride=self.stride,
             dilation=self.dilation,
             groups=self.groups,
-            act=self.act,
-            use_cudnn=self.use_cudnn,
-            data_format=self.data_format,
-            dtype=self.dtype)
+            data_format=self.data_format)
         conv.weight.set_value(self.weight)
         if not self.no_bias:
             conv.bias.set_value(self.bias)
@@ -225,15 +214,10 @@ def add_cases(suite):
             num_filters=6,
             num_channels=3,
             groups=3,
-            use_cudnn=False,
-            act="sigmoid",
             padding="valid"))
 
 
 def add_error_cases(suite):
-    suite.addTest(
-        Conv3DErrorTestCase(
-            methodName='runTest', use_cudnn="not_valid"))
     suite.addTest(
         Conv3DErrorTestCase(
             methodName='runTest', num_channels=5, groups=2))
diff --git a/python/paddle/fluid/tests/unittests/test_dygraph_weight_norm.py b/python/paddle/fluid/tests/unittests/test_dygraph_weight_norm.py
index f33334d536d14970475bf8674b3f38b855b22f88..466226c53fabbd315acd19c6421f210d0ca225c1 100644
--- a/python/paddle/fluid/tests/unittests/test_dygraph_weight_norm.py
+++ b/python/paddle/fluid/tests/unittests/test_dygraph_weight_norm.py
@@ -117,7 +117,7 @@ class TestDygraphWeightNorm(unittest.TestCase):
 
     def test_check_output(self):
         fluid.enable_imperative()
-        linear = paddle.nn.Conv2D(2, 3, 3)
+        linear = paddle.nn.Conv2d(2, 3, 3)
         before_weight = linear.weight.numpy()
         if self.dim == None:
             self.dim = -1
@@ -169,7 +169,7 @@ class TestDygraphRemoveWeightNorm(unittest.TestCase):
 
     def test_check_output(self):
         fluid.enable_imperative()
-        linear = paddle.nn.Conv2D(2, 3, 3)
+        linear = paddle.nn.Conv2d(2, 3, 3)
         before_weight = linear.weight
         wn = weight_norm(linear, dim=self.dim)
         rwn = remove_weight_norm(linear)
diff --git a/python/paddle/fluid/tests/unittests/test_functional_conv2d.py b/python/paddle/fluid/tests/unittests/test_functional_conv2d.py
index c43454eaaee9e3b2f9aa371453e58b009c99a52c..68be0bf5d561ef0d8fe92005dd9ddb47c21aca51 100644
--- a/python/paddle/fluid/tests/unittests/test_functional_conv2d.py
+++ b/python/paddle/fluid/tests/unittests/test_functional_conv2d.py
@@ -37,7 +37,6 @@ class TestFunctionalConv2D(TestCase):
         self.groups = 1
         self.no_bias = False
         self.act = "sigmoid"
-        self.use_cudnn = True
         self.data_format = "NHWC"
 
     def prepare(self):
@@ -88,7 +87,6 @@ class TestFunctionalConv2D(TestCase):
                 param_attr=I.NumpyArrayInitializer(self.weight),
                 bias_attr=False
                 if self.no_bias else I.NumpyArrayInitializer(self.bias),
-                use_cudnn=self.use_cudnn,
                 act=self.act,
                 data_format=self.data_format)
         exe = fluid.Executor(self.place)
@@ -121,9 +119,11 @@ class TestFunctionalConv2D(TestCase):
                 stride=self.stride,
                 dilation=self.dilation,
                 groups=self.groups,
-                act=self.act,
-                data_format=self.data_format,
-                use_cudnn=self.use_cudnn)
+                data_format=self.data_format)
+
+            if self.act == 'sigmoid':
+                y = F.sigmoid(y)
+
         exe = fluid.Executor(self.place)
         exe.run(start)
         feed_dict = {"input": self.input, "weight": self.weight}
@@ -144,10 +144,12 @@ class TestFunctionalConv2D(TestCase):
             padding=self.padding,
             stride=self.stride,
             dilation=self.dilation,
-            act=self.act,
             groups=self.groups,
-            data_format=self.data_format,
-            use_cudnn=self.use_cudnn)
+            data_format=self.data_format)
+
+        if self.act == 'sigmoid':
+            y = F.sigmoid(y)
+
         out = y.numpy()
         return out
 
@@ -185,7 +187,6 @@ class TestFunctionalConv2DError(TestCase):
         self.groups = 1
         self.no_bias = False
         self.act = "sigmoid"
-        self.use_cudnn = True
         self.data_format = "NHWC"
 
     def test_exception(self):
@@ -228,9 +229,7 @@ class TestFunctionalConv2DError(TestCase):
                 stride=self.stride,
                 dilation=self.dilation,
                 groups=self.groups,
-                act=self.act,
-                data_format=self.data_format,
-                use_cudnn=self.use_cudnn)
+                data_format=self.data_format)
 
 
 class TestFunctionalConv2DCase2(TestFunctionalConv2D):
@@ -383,21 +382,6 @@ class TestFunctionalConv2DErrorCase4(TestFunctionalConv2DError):
         self.data_format = "NCHW"
 
 
-class TestFunctionalConv2DErrorCase6(TestFunctionalConv2DError):
-    def setUp(self):
-        self.in_channels = 3
-        self.out_channels = 5
-        self.filter_shape = 3
-        self.padding = "same"
-        self.stride = 1
-        self.dilation = 1
-        self.groups = 1
-        self.no_bias = False
-        self.act = "sigmoid"
-        self.use_cudnn = "not_valid"
-        self.data_format = "NCHW"
-
-
 class TestFunctionalConv2DErrorCase7(TestFunctionalConv2DError):
     def setUp(self):
         self.in_channels = 3
diff --git a/python/paddle/fluid/tests/unittests/test_functional_conv3d.py b/python/paddle/fluid/tests/unittests/test_functional_conv3d.py
index 195e3812f94843f6ccdd05cbc317238765e4c06b..b413a56c07a9ce3afbe15baffbffaf92a3d42129 100644
--- a/python/paddle/fluid/tests/unittests/test_functional_conv3d.py
+++ b/python/paddle/fluid/tests/unittests/test_functional_conv3d.py
@@ -37,7 +37,6 @@ class TestFunctionalConv3D(TestCase):
         self.groups = 1
         self.no_bias = False
         self.act = "sigmoid"
-        self.use_cudnn = True
         self.data_format = "NDHWC"
 
     def prepare(self):
@@ -88,7 +87,6 @@ class TestFunctionalConv3D(TestCase):
                 param_attr=I.NumpyArrayInitializer(self.weight),
                 bias_attr=False
                 if self.no_bias else I.NumpyArrayInitializer(self.bias),
-                use_cudnn=self.use_cudnn,
                 act=self.act,
                 data_format=self.data_format)
         exe = fluid.Executor(self.place)
@@ -121,9 +119,11 @@ class TestFunctionalConv3D(TestCase):
                 stride=self.stride,
                 dilation=self.dilation,
                 groups=self.groups,
-                act=self.act,
-                data_format=self.data_format,
-                use_cudnn=self.use_cudnn)
+                data_format=self.data_format)
+
+            if self.act == 'sigmoid':
+                y = F.sigmoid(y)
+
         exe = fluid.Executor(self.place)
         exe.run(start)
         feed_dict = {"input": self.input, "weight": self.weight}
@@ -144,10 +144,12 @@ class TestFunctionalConv3D(TestCase):
             padding=self.padding,
             stride=self.stride,
             dilation=self.dilation,
-            act=self.act,
             groups=self.groups,
-            data_format=self.data_format,
-            use_cudnn=self.use_cudnn)
+            data_format=self.data_format)
+
+        if self.act == 'sigmoid':
+            y = F.sigmoid(y)
+
         out = y.numpy()
         return out
 
@@ -185,7 +187,6 @@ class TestFunctionalConv3DError(TestCase):
         self.groups = 1
         self.no_bias = False
         self.act = "sigmoid"
-        self.use_cudnn = True
         self.data_format = "NDHWC"
 
     def test_exception(self):
@@ -228,9 +229,10 @@ class TestFunctionalConv3DError(TestCase):
                 stride=self.stride,
                 dilation=self.dilation,
                 groups=self.groups,
-                act=self.act,
-                data_format=self.data_format,
-                use_cudnn=self.use_cudnn)
+                data_format=self.data_format)
+
+            if self.act == 'sigmoid':
+                y = F.sigmoid(y)
 
 
 class TestFunctionalConv3DCase2(TestFunctionalConv3D):
@@ -244,7 +246,6 @@ class TestFunctionalConv3DCase2(TestFunctionalConv3D):
         self.groups = 1
         self.no_bias = False
         self.act = "sigmoid"
-        self.use_cudnn = True
         self.data_format = "NDHWC"
 
 
@@ -259,7 +260,6 @@ class TestFunctionalConv3DCase3(TestFunctionalConv3D):
         self.groups = 1
         self.no_bias = False
         self.act = "sigmoid"
-        self.use_cudnn = True
         self.data_format = "NDHWC"
 
 
@@ -274,7 +274,6 @@ class TestFunctionalConv3DCase4(TestFunctionalConv3D):
         self.groups = 1
         self.no_bias = False
         self.act = "sigmoid"
-        self.use_cudnn = True
         self.data_format = "NDHWC"
 
 
@@ -289,7 +288,6 @@ class TestFunctionalConv3DCase5(TestFunctionalConv3D):
         self.groups = 1
         self.no_bias = False
         self.act = "sigmoid"
-        self.use_cudnn = True
         self.data_format = "NDHWC"
 
 
@@ -304,7 +302,6 @@ class TestFunctionalConv3DCase6(TestFunctionalConv3D):
         self.groups = 1
         self.no_bias = False
         self.act = "sigmoid"
-        self.use_cudnn = True
         self.data_format = "NCDHW"
 
 
@@ -319,7 +316,6 @@ class TestFunctionalConv3DCase7(TestFunctionalConv3D):
         self.groups = 2
         self.no_bias = False
         self.act = "sigmoid"
-        self.use_cudnn = True
         self.data_format = "NCDHW"
 
 
@@ -349,7 +345,6 @@ class TestFunctionalConv3DErrorCase2(TestFunctionalConv3DError):
         self.groups = 1
         self.no_bias = False
         self.act = "sigmoid"
-        self.use_cudnn = False
         self.data_format = "NCDHW"
 
 
@@ -364,7 +359,6 @@ class TestFunctionalConv3DErrorCase3(TestFunctionalConv3DError):
         self.groups = 2
         self.no_bias = False
         self.act = "sigmoid"
-        self.use_cudnn = False
         self.data_format = "not_valid"
 
 
@@ -379,22 +373,6 @@ class TestFunctionalConv3DErrorCase4(TestFunctionalConv3DError):
         self.groups = 2
         self.no_bias = False
         self.act = "sigmoid"
-        self.use_cudnn = False
-        self.data_format = "NCDHW"
-
-
-class TestFunctionalConv3DErrorCase6(TestFunctionalConv3DError):
-    def setUp(self):
-        self.in_channels = 3
-        self.out_channels = 5
-        self.filter_shape = 3
-        self.padding = "same"
-        self.stride = 1
-        self.dilation = 1
-        self.groups = 1
-        self.no_bias = False
-        self.act = "sigmoid"
-        self.use_cudnn = "not_valid"
         self.data_format = "NCDHW"
 
 
@@ -409,7 +387,6 @@ class TestFunctionalConv3DErrorCase7(TestFunctionalConv3DError):
         self.groups = 1
         self.no_bias = False
         self.act = "sigmoid"
-        self.use_cudnn = True
         self.data_format = "not_valid"
 
 
@@ -424,7 +401,6 @@ class TestFunctionalConv3DErrorCase8(TestFunctionalConv3DError):
         self.groups = 1
         self.no_bias = False
         self.act = "sigmoid"
-        self.use_cudnn = True
         self.data_format = "NCDHW"
 
 
@@ -439,7 +415,6 @@ class TestFunctionalConv3DErrorCase9(TestFunctionalConv3DError):
         self.groups = 1
         self.no_bias = False
         self.act = "sigmoid"
-        self.use_cudnn = False
         self.data_format = "NCDHW"
 
 
@@ -454,7 +429,6 @@ class TestFunctionalConv3DErrorCase10(TestFunctionalConv3DError):
         self.groups = 2
         self.no_bias = False
         self.act = "sigmoid"
-        self.use_cudnn = False
         self.data_format = "NDHWC"
 
 
diff --git a/python/paddle/fluid/tests/unittests/test_imperative_layer_apply.py b/python/paddle/fluid/tests/unittests/test_imperative_layer_apply.py
index a391c088a3640c097ff0f4ff714bf50470c575c6..b15ad911ee79d47011be6eaa4bde62ba71c55c0e 100644
--- a/python/paddle/fluid/tests/unittests/test_imperative_layer_apply.py
+++ b/python/paddle/fluid/tests/unittests/test_imperative_layer_apply.py
@@ -28,11 +28,11 @@ class LeNetDygraph(fluid.dygraph.Layer):
         super(LeNetDygraph, self).__init__()
         self.num_classes = num_classes
         self.features = nn.Sequential(
-            nn.Conv2D(
+            nn.Conv2d(
                 1, 6, 3, stride=1, padding=1),
             nn.ReLU(),
             nn.Pool2D(2, 'max', 2),
-            nn.Conv2D(
+            nn.Conv2d(
                 6, 16, 5, stride=1, padding=0),
             nn.ReLU(),
             nn.Pool2D(2, 'max', 2))
@@ -61,7 +61,7 @@ def init_weights(layer):
         new_bias = paddle.fill_constant(
             layer.bias.shape, layer.bias.dtype, value=-0.1)
         layer.bias.set_value(new_bias)
-    elif type(layer) == nn.Conv2D:
+    elif type(layer) == nn.Conv2d:
         new_weight = paddle.fill_constant(
             layer.weight.shape, layer.weight.dtype, value=0.7)
         layer.weight.set_value(new_weight)
@@ -81,7 +81,7 @@ class TestLayerApply(unittest.TestCase):
             if type(layer) == nn.Linear:
                 np.testing.assert_allclose(layer.weight.numpy(), 0.9)
                 np.testing.assert_allclose(layer.bias.numpy(), -0.1)
-            elif type(layer) == nn.Conv2D:
+            elif type(layer) == nn.Conv2d:
                 np.testing.assert_allclose(layer.weight.numpy(), 0.7)
                 np.testing.assert_allclose(layer.bias.numpy(), -0.2)
 
diff --git a/python/paddle/fluid/tests/unittests/test_imperative_layer_children.py b/python/paddle/fluid/tests/unittests/test_imperative_layer_children.py
index e6d8b052d7f1839466d09e4756b49cc6a38554cc..c7e0902341a59649219cf94ef9741fdf7ae09233 100644
--- a/python/paddle/fluid/tests/unittests/test_imperative_layer_children.py
+++ b/python/paddle/fluid/tests/unittests/test_imperative_layer_children.py
@@ -27,11 +27,11 @@ class LeNetDygraph(fluid.dygraph.Layer):
     def __init__(self):
         super(LeNetDygraph, self).__init__()
         self.features = nn.Sequential(
-            nn.Conv2D(
+            nn.Conv2d(
                 1, 6, 3, stride=1, padding=1),
             nn.ReLU(),
             nn.Pool2D(2, 'max', 2),
-            nn.Conv2D(
+            nn.Conv2d(
                 6, 16, 5, stride=1, padding=0),
             nn.ReLU(),
             nn.Pool2D(2, 'max', 2))
diff --git a/python/paddle/fluid/tests/unittests/test_retain_graph.py b/python/paddle/fluid/tests/unittests/test_retain_graph.py
index 360a2de1df6ed8c97dac9ca4473e0b49a240cb5f..5f3e8ff737a04cd40c715f267d8cd1f391acee1c 100644
--- a/python/paddle/fluid/tests/unittests/test_retain_graph.py
+++ b/python/paddle/fluid/tests/unittests/test_retain_graph.py
@@ -26,7 +26,7 @@ paddle.manual_seed(SEED)
 class Generator(fluid.dygraph.Layer):
     def __init__(self):
         super(Generator, self).__init__()
-        self.conv1 = paddle.nn.Conv2D(3, 3, 3, 1)
+        self.conv1 = paddle.nn.Conv2d(3, 3, 3, padding=1)
 
     def forward(self, x):
         x = self.conv1(x)
@@ -37,7 +37,7 @@ class Generator(fluid.dygraph.Layer):
 class Discriminator(fluid.dygraph.Layer):
     def __init__(self):
         super(Discriminator, self).__init__()
-        self.convd = paddle.nn.Conv2D(6, 3, 1)
+        self.convd = paddle.nn.Conv2d(6, 3, 1)
 
     def forward(self, x):
         x = self.convd(x)
diff --git a/python/paddle/incubate/hapi/tests/test_model.py b/python/paddle/incubate/hapi/tests/test_model.py
index 103ab30724ca883ce02358487b005203be805455..25b62667af416e7c05b08d8be458527ccbda9d55 100644
--- a/python/paddle/incubate/hapi/tests/test_model.py
+++ b/python/paddle/incubate/hapi/tests/test_model.py
@@ -23,7 +23,7 @@ import shutil
 import tempfile
 
 from paddle import fluid
-from paddle.nn import Conv2D, Pool2D, Linear, ReLU, Sequential
+from paddle.nn import Conv2d, Pool2D, Linear, ReLU, Sequential
 from paddle.fluid.dygraph.base import to_variable
 
 import paddle.incubate.hapi as hapi
@@ -40,11 +40,11 @@ class LeNetDygraph(fluid.dygraph.Layer):
         super(LeNetDygraph, self).__init__()
         self.num_classes = num_classes
         self.features = Sequential(
-            Conv2D(
+            Conv2d(
                 1, 6, 3, stride=1, padding=1),
             ReLU(),
             Pool2D(2, 'max', 2),
-            Conv2D(
+            Conv2d(
                 6, 16, 5, stride=1, padding=0),
             ReLU(),
             Pool2D(2, 'max', 2))
diff --git a/python/paddle/incubate/hapi/tests/test_uncombined_weight2state_dict.py b/python/paddle/incubate/hapi/tests/test_uncombined_weight2state_dict.py
index c2035a8b5c5958d54c79d6ee0ff6df654bb35d51..26ec53014b1c3b113a0e1ee82f3b9edfe9f48a3f 100644
--- a/python/paddle/incubate/hapi/tests/test_uncombined_weight2state_dict.py
+++ b/python/paddle/incubate/hapi/tests/test_uncombined_weight2state_dict.py
@@ -22,7 +22,7 @@ import shutil
 import tempfile
 
 from paddle import fluid
-from paddle.nn import Conv2D, Pool2D, Linear, ReLU, Sequential
+from paddle.nn import Conv2d, Pool2D, Linear, ReLU, Sequential
 
 from paddle.incubate.hapi.utils import uncombined_weight_to_state_dict
 
@@ -32,11 +32,11 @@ class LeNetDygraph(fluid.dygraph.Layer):
         super(LeNetDygraph, self).__init__()
         self.num_classes = num_classes
         self.features = Sequential(
-            Conv2D(
+            Conv2d(
                 1, 6, 3, stride=1, padding=1),
             ReLU(),
             Pool2D(2, 'max', 2),
-            Conv2D(
+            Conv2d(
                 6, 16, 5, stride=1, padding=0),
             ReLU(),
             Pool2D(2, 'max', 2))
diff --git a/python/paddle/incubate/hapi/vision/models/lenet.py b/python/paddle/incubate/hapi/vision/models/lenet.py
index db1d894b4aa5f2535795c6350faad6ee3aee1164..dc7b094de0f26e04b9f07d011d3ce492950df269 100644
--- a/python/paddle/incubate/hapi/vision/models/lenet.py
+++ b/python/paddle/incubate/hapi/vision/models/lenet.py
@@ -13,7 +13,7 @@
 #limitations under the License.
 
 import paddle.fluid as fluid
-from paddle.nn import Conv2D, Pool2D, Linear, ReLU, Sequential
+from paddle.nn import Conv2d, Pool2D, Linear, ReLU, Sequential
 
 __all__ = ['LeNet']
 
@@ -39,11 +39,11 @@ class LeNet(fluid.dygraph.Layer):
         super(LeNet, self).__init__()
         self.num_classes = num_classes
         self.features = Sequential(
-            Conv2D(
+            Conv2d(
                 1, 6, 3, stride=1, padding=1),
             ReLU(),
             Pool2D(2, 'max', 2),
-            Conv2D(
+            Conv2d(
                 6, 16, 5, stride=1, padding=0),
             ReLU(),
             Pool2D(2, 'max', 2))
diff --git a/python/paddle/incubate/hapi/vision/models/vgg.py b/python/paddle/incubate/hapi/vision/models/vgg.py
index 74e7228e5249fe990d037c9f12e75b6d4839c591..30f6e120b2502113045b3583686360f4ed2c32ac 100644
--- a/python/paddle/incubate/hapi/vision/models/vgg.py
+++ b/python/paddle/incubate/hapi/vision/models/vgg.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import paddle.fluid as fluid
-from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear
+from paddle.nn import Conv2d, Pool2D, BatchNorm, Linear, ReLU
 from paddle.fluid.dygraph.container import Sequential
 
 from ...download import get_weights_path_from_url
@@ -105,12 +105,11 @@ def make_layers(cfg, batch_norm=False):
             layers += [Pool2D(pool_size=2, pool_stride=2)]
         else:
             if batch_norm:
-                conv2d = Conv2D(in_channels, v, filter_size=3, padding=1)
-                layers += [conv2d, BatchNorm(v, act='relu')]
+                conv2d = Conv2d(in_channels, v, kernel_size=3, padding=1)
+                layers += [conv2d, BatchNorm(v), ReLU()]
             else:
-                conv2d = Conv2D(
-                    in_channels, v, filter_size=3, padding=1, act='relu')
-                layers += [conv2d]
+                conv2d = Conv2d(in_channels, v, kernel_size=3, padding=1)
+                layers += [conv2d, ReLU()]
             in_channels = v
     return Sequential(*layers)
 
diff --git a/python/paddle/nn/__init__.py b/python/paddle/nn/__init__.py
index 6b449c259b4d4a1fd4c1ab93019b57961ffdf057..75ab3ab5c9bceea3aa3582754736d4adbad648b6 100644
--- a/python/paddle/nn/__init__.py
+++ b/python/paddle/nn/__init__.py
@@ -93,9 +93,9 @@ from .layer.common import Dropout2D  #DEFINE_ALIAS
 from .layer.common import Dropout3D  #DEFINE_ALIAS
 from .layer.pooling import AdaptiveAvgPool2d  #DEFINE_ALIAS
 from .layer.pooling import AdaptiveAvgPool3d  #DEFINE_ALIAS
-from .layer.conv import Conv2D  #DEFINE_ALIAS
+from .layer.conv import Conv2d  #DEFINE_ALIAS
+from .layer.conv import Conv3d  #DEFINE_ALIAS
 from .layer.conv import ConvTranspose2d  #DEFINE_ALIAS
-from .layer.conv import Conv3D  #DEFINE_ALIAS
 from .layer.conv import ConvTranspose3d  #DEFINE_ALIAS
 # from .layer.conv import TreeConv  #DEFINE_ALIAS
 # from .layer.conv import Conv1D  #DEFINE_ALIAS
diff --git a/python/paddle/nn/functional/conv.py b/python/paddle/nn/functional/conv.py
index fc0656c89dba948c4b5e8f40cd634430e9ff72b6..25f39808c523771af82b4c232b848e455d5a25ec 100644
--- a/python/paddle/nn/functional/conv.py
+++ b/python/paddle/nn/functional/conv.py
@@ -88,20 +88,16 @@ def _update_padding_nd(padding, channel_last, num_dims):
     return padding, padding_algorithm
 
 
-def conv2d(input,
+def conv2d(x,
            weight,
            bias=None,
-           padding=0,
            stride=1,
+           padding=0,
           dilation=1,
           groups=1,
-           use_cudnn=True,
-           act=None,
           data_format="NCHW",
           name=None):
     """
-    :alias_main: paddle.nn.functional.conv2d
-    :alias: paddle.nn.functional.conv2d,paddle.nn.functional.conv.conv2d
 
     The convolution2D layer calculates the output based on the input, filter
     and strides, paddings, dilations, groups parameters. Input and
@@ -153,12 +149,15 @@ def conv2d(input,
         W_{out}&= \\frac{(W_{in} + 2 * paddings[1] - (dilations[1] * (W_f - 1) + 1))}{strides[1]} + 1
 
     Args:
-        input (Variable): The input is 4-D Tensor with shape [N, C, H, W], the data type
+        x (Tensor): The input is 4-D Tensor with shape [N, C, H, W], the data type
             of input is float16 or float32 or float64.
-        weight (Variable): The convolution kernel with shape [M, C/g, kH, kW], where M is
+        weight (Tensor): The convolution kernel with shape [M, C/g, kH, kW], where M is
             the number of output channels, g is the number of groups, kH is the filter's
             height, kW is the filter's width.
-        bias (Variable, optional): The bias with shape [M,].
+        bias (Tensor, optional): The bias with shape [M,].
+        stride (int|tuple): The stride size. It means the stride in convolution.
+            If stride is a tuple, it must contain two integers, (stride_height, stride_width).
+            Otherwise, stride_height = stride_width = stride. Default: stride = 1.
         padding (string|int|list|tuple): The padding size. It means the number of zero-paddings
            on both sides for each dimension.If `padding` is a string, either
            'VALID' or 'SAME' which is the padding algorithm. If padding size is a tuple or list,
@@ -169,9 +168,6 @@ def conv2d(input,
            when `data_format` is `"NHWC"`, `pool_padding` can be in the form
            `[[0,0], [pad_height_top, pad_height_bottom], [pad_width_left, pad_width_right], [0,0]]`.
            Default: padding = 0.
-        stride (int|tuple): The stride size. It means the stride in convolution.
-            If stride is a tuple, it must contain two integers, (stride_height, stride_width).
-            Otherwise, stride_height = stride_width = stride. Default: stride = 1.
         dilation (int|tuple): The dilation size. It means the spacing between the kernel points.
             If dilation is a tuple, it must contain two integers, (dilation_height, dilation_width).
             Otherwise, dilation_height = dilation_width = dilation. Default: dilation = 1.
         groups (int): The groups number of the Conv2d Layer. According to grouped
             convolution in Alex Krizhevsky's Deep CNN paper: when group=2,
             the first half of the filters is only connected to the first half
             of the input channels, while the second half of the filters is only
             connected to the second half of the input channels. Default: groups=1.
-        use_cudnn (bool): Use cudnn kernel or not, it is valid only when the cudnn
-            library is installed. Default: True
-        act (str): Activation type, if it is set to None, activation is not appended.
-            Default: None
         data_format (str, optional): Specify the data format of the input, and the data format of the output
             will be consistent with that of the input. An optional string from: `"NCHW"`, `"NHWC"`.
             The default is `"NCHW"`. When it is `"NCHW"`, the data is stored in the order of:
@@ -194,13 +186,9 @@ def conv2d(input,
             None by default.
 
     Returns:
-        A Variable holding Tensor representing the conv2d, whose data type is the
-        same with input. If act is None, the tensor variable storing the convolution
-        result, and if act is not None, the tensor variable storing convolution
-        and non-linearity activation result.
+        A Tensor representing the conv2d result, whose data type is the same with input.
 
     Raises:
-        ValueError: If the type of `use_cudnn` is not bool.
         ValueError: If `data_format` is not "NCHW" or "NHWC".
         ValueError: If the channel dimmention of the input is less than or equal to zero.
         ValueError: If `padding` is a string, but not "SAME" or "VALID".
@@ -215,62 +203,65 @@ def conv2d(input,
 
     Examples:
         .. code-block:: python
 
-          from paddle import fluid
+          import paddle
           import paddle.nn.functional as F
-          import paddle.fluid.dygraph as dg
           import numpy as np
 
           x = np.random.randn(2, 3, 8, 8).astype(np.float32)
           w = np.random.randn(6, 3, 3, 3).astype(np.float32)
 
-          place = fluid.CPUPlace()
-          with dg.guard(place):
-              x_var = dg.to_variable(x)
-              w_var = dg.to_variable(w)
-              y_var = F.conv2d(x_var, w_var, act="relu")
-              y_np = y_var.numpy()
+          paddle.disable_static()
+
+          x_var = paddle.to_tensor(x)
+          w_var = paddle.to_tensor(w)
+          y_var = F.conv2d(x_var, w_var)
+          y_np = y_var.numpy()
+          print(y_np.shape)  # (2, 6, 6, 6)
     """
     # entry checks
-    if not isinstance(use_cudnn, bool):
-        raise ValueError("Attr(use_cudnn) should be True or False. "
-                         "Received Attr(use_cudnn): {}.".format(use_cudnn))
     if data_format not in ["NCHW", "NHWC"]:
         raise ValueError("Attr(data_format) should be 'NCHW' or 'NHWC'. "
                          "Received Attr(data_format): {}.".format(data_format))
 
     channel_last = (data_format == "NHWC")
     channel_dim = -1 if channel_last else 1
-    num_channels = input.shape[channel_dim]
+    num_channels = x.shape[channel_dim]
     num_filters = weight.shape[0]
     if num_channels < 0:
         raise ValueError("The channel dimmention of the input({}) "
                          "should be defined. Received: {}.".format(
-                             input.shape, num_channels))
+                             x.shape, num_channels))
     if num_channels % groups != 0:
         raise ValueError(
             "the channel of input must be divisible by groups,"
             "received: the channel of input is {}, the shape of input is {}"
-            ", the groups is {}".format(num_channels, input.shape, groups))
+            ", the groups is {}".format(num_channels, x.shape, groups))
     if num_filters % groups != 0:
         raise ValueError(
             "the number of filters must be divisible by groups,"
             "received: the number of filters is {}, the shape of weight is {}"
             ", the groups is {}".format(num_filters, weight.shape, groups))
 
+    # use_cudnn = True if core.is_compiled_with_cuda() else False
+    cudnn_version = get_cudnn_version()
+
+    use_cudnn = True if (core.is_compiled_with_cuda() and
+                         cudnn_version is not None) else False
+
     # update attrs
     padding, padding_algorithm = _update_padding_nd(padding, channel_last, 2)
     stride = utils.convert_to_list(stride, 2, 'stride')
     dilation = utils.convert_to_list(dilation, 2, 'dilation')
 
     l_type = "conv2d"
-    if (num_channels == groups and num_filters % num_channels == 0 and
-            not use_cudnn):
+    if (num_channels == groups and num_filters % num_channels == 0):
         l_type = 'depthwise_conv2d'
+        use_cudnn = False
 
-    inputs = {'Input': [input], 'Filter': [weight]}
+    inputs = {'Input': [x], 'Filter': [weight]}
     attrs = {
         'strides': stride,
         'paddings': padding,
@@ -288,15 +279,13 @@ def conv2d(input,
                  'groups', groups, 'use_cudnn', use_cudnn, 'use_mkldnn', False,
                  'fuse_relu_before_depthwise_conv', False, "padding_algorithm",
                  padding_algorithm, "data_format", data_format)
-        pre_bias = getattr(core.ops, l_type)(input, weight, *attrs)
+        pre_bias = getattr(core.ops, l_type)(x, weight, *attrs)
         if bias is not None:
-            pre_act = nn.elementwise_add(pre_bias, bias, axis=channel_dim)
+            out = nn.elementwise_add(pre_bias, bias, axis=channel_dim)
         else:
-            pre_act = pre_bias
-        out = dygraph_utils._append_activation_in_dygraph(
-            pre_act, act, use_cudnn=use_cudnn)
+            out = pre_bias
    else:
-        inputs = {'Input': [input], 'Filter': [weight]}
+        inputs = {'Input': [x], 'Filter': [weight]}
         attrs = {
             'strides': stride,
             'paddings': padding,
@@ -308,8 +297,8 @@ def conv2d(input,
             "padding_algorithm": padding_algorithm,
             "data_format": data_format
         }
-        check_variable_and_dtype(input, 'input',
-                                 ['float16', 'float32', 'float64'], 'conv2d')
+        check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
+                                 'conv2d')
         helper = LayerHelper(l_type, **locals())
         dtype = helper.input_dtype()
         pre_bias = helper.create_variable_for_type_inference(dtype)
@@ -317,10 +306,10 @@ def conv2d(input,
         helper.append_op(
             type=l_type, inputs=inputs, outputs=outputs, attrs=attrs)
         if bias is not None:
-            pre_act = nn.elementwise_add(pre_bias, bias, axis=channel_dim)
+            out = nn.elementwise_add(pre_bias, bias, axis=channel_dim)
         else:
-            pre_act = pre_bias
-        out = helper.append_activation(pre_act)
+            out = pre_bias
+
     return out
 
 
@@ -571,20 +560,16 @@ def conv_transpose2d(x,
     return out
 
 
-def conv3d(input,
+def conv3d(x,
            weight,
           bias=None,
-           padding=0,
           stride=1,
+           padding=0,
           dilation=1,
           groups=1,
-           use_cudnn=True,
-           act=None,
           data_format="NCDHW",
           name=None):
     """
-    :alias_main: paddle.nn.functional.conv3d
-    :alias: paddle.nn.functional.conv3d,paddle.nn.functional.conv.conv3d
 
     The convolution3D layer calculates the output based on the input, filter
     and strides, paddings, dilations, groups parameters. Input(Input) and
@@ -630,12 +615,15 @@ def conv3d(input,
         W_{out}&= \\frac{(W_{in} + 2 * paddings[2] - (dilations[2] * (W_f - 1) + 1))}{strides[2]} + 1
 
     Args:
-        input (Variable): The input is 5-D Tensor with shape [N, C, D, H, W], the data
+        x (Tensor): The input is 5-D Tensor with shape [N, C, D, H, W], the data
             type of input is float16 or float32 or float64.
         weight (Variable): The convolution kernel, a Tensor with shape [M, C/g, kD, kH, kW],
             where M is the number of filters(output channels), g is the number of groups,
             kD, kH, kW are the filter's depth, height and width respectively.
-        bias (Variable, optional): The bias, a Tensor of shape [M, ].
+        bias (Tensor, optional): The bias, a Tensor of shape [M, ].
+        stride (int|tuple): The stride size. It means the stride in convolution. If stride is a
+            tuple, it must contain three integers, (stride_depth, stride_height, stride_width).
+            Otherwise, stride_depth = stride_height = stride_width = stride. Default: stride = 1.
         padding (string|int|list|tuple): The padding size. It means the number of zero-paddings
            on both sides for each dimension. If `padding` is a string, either 'VALID' or
            'SAME' which is the padding algorithm. If padding size is a tuple or list,
@@ -646,9 +634,6 @@ def conv3d(input,
            when `data_format` is `"NDHWC"`, `pool_padding` can be in the form
            `[[0,0], [pad_depth_front, pad_depth_back], [pad_height_top, pad_height_bottom],
            [pad_width_left, pad_width_right], [0,0]]`. Default: padding = 0.
-        stride (int|tuple): The stride size. It means the stride in convolution. If stride is a
-            tuple, it must contain three integers, (stride_depth, stride_height, stride_width).
-            Otherwise, stride_depth = stride_height = stride_width = stride. Default: stride = 1.
         dilation (int|tuple): The dilation size. It means the spacing between the kernel points.
             If dilation is a tuple, it must contain three integers, (dilation_depth, dilation_height,
             dilation_width). Otherwise, dilation_depth = dilation_height = dilation_width = dilation.
@@ -658,10 +643,6 @@ def conv3d(input,
             the first half of the filters is only connected to the first half
             of the input channels, while the second half of the filters is only
             connected to the second half of the input channels. Default: groups=1
-        use_cudnn (bool): Use cudnn kernel or not, it is valid only when the cudnn
-            library is installed. Default: True
-        act (str): Activation type, if it is set to None, activation is not appended.
-            Default: None.
         data_format (str, optional): Specify the data format of the input, and the data format of the output
             will be consistent with that of the input. An optional string from: `"NCHW"`, `"NHWC"`.
             The default is `"NCHW"`. When it is `"NCHW"`, the data is stored in the order of:
@@ -671,13 +652,9 @@ def conv3d(input,
             None by default.
 
     Returns:
-        A Variable holding Tensor representing the conv3d, whose data type is
-        the same with input. If act is None, the tensor variable storing the
-        convolution result, and if act is not None, the tensor variable storing
-        convolution and non-linearity activation result.
+        A Tensor representing the conv3d, whose data type is the same with input.
 
     Raises:
-        ValueError: If the type of `use_cudnn` is not bool.
         ValueError: If `data_format` is not "NCDHW" or "NDHWC".
         ValueError: If the channel dimmention of the input is less than or equal to zero.
         ValueError: If `padding` is a string, but not "SAME" or "VALID".
@@ -711,10 +691,6 @@ def conv3d(input,
             # (2, 6, 6, 6, 6)
     """
     # entry check
-    if not isinstance(use_cudnn, bool):
-        raise ValueError("Attr(use_cudnn) should be True or False. Received "
-                         "Attr(use_cudnn): {}. ".format(use_cudnn))
-
     if data_format not in ["NCDHW", "NDHWC"]:
         raise ValueError(
             "Attr(data_format) should be 'NCDHW' or 'NDHWC'. Received "
@@ -722,12 +698,12 @@ def conv3d(input,
 
     channel_last = (data_format == "NDHWC")
     channel_dim = -1 if channel_last else 1
-    num_channels = input.shape[channel_dim]
+    num_channels = x.shape[channel_dim]
     num_filters = weight.shape[0]
     if num_channels < 0:
         raise ValueError(
             "The channel dimmention of the input({}) should be defined. "
-            "Received: {}.".format(input.shape, num_channels))
+            "Received: {}.".format(x.shape, num_channels))
     if num_channels % groups != 0:
         raise ValueError(
             "The number of input channels must be divisible by Attr(groups). "
@@ -739,6 +715,10 @@ def conv3d(input,
             "Received: number of filters({}), groups({}).".format(num_filters,
                                                                   groups))
 
+    cudnn_version = get_cudnn_version()
+    use_cudnn = True if (core.is_compiled_with_cuda() and
+                         cudnn_version is not None) else False
+
     padding, padding_algorithm = _update_padding_nd(padding, channel_last, 3)
     stride = utils.convert_to_list(stride, 3, 'stride')
     dilation = utils.convert_to_list(dilation, 3, 'dilation')
@@ -749,15 +729,13 @@ def conv3d(input,
                  'groups', groups, 'use_cudnn', use_cudnn, 'use_mkldnn', False,
                  "padding_algorithm", padding_algorithm, "data_format",
                  data_format)
-        pre_bias = getattr(core.ops, op_type)(input, weight, *attrs)
+        pre_bias = getattr(core.ops, op_type)(x, weight, *attrs)
         if bias is not None:
-            pre_act = nn.elementwise_add(pre_bias, bias, axis=channel_dim)
+            out = nn.elementwise_add(pre_bias, bias, axis=channel_dim)
         else:
-            pre_act = pre_bias
-        out = dygraph_utils._append_activation_in_dygraph(
-            pre_act, act, use_cudnn=use_cudnn)
+            out = pre_bias
     else:
-        inputs = {'Input': [input], 'Filter': [weight]}
+        inputs = {'Input': [x], 'Filter': [weight]}
         attrs = {
             'strides': stride,
             'paddings': padding,
@@ -770,8 +748,8 @@ def conv3d(input,
         }
         helper = LayerHelper(op_type, **locals())
         dtype = helper.input_dtype()
-        check_variable_and_dtype(input, 'input',
-                                 ['float16', 'float32', 'float64'], 'conv3d')
+        check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
+                                 'conv3d')
 
         pre_bias = helper.create_variable_for_type_inference(dtype)
         outputs = {"Output": [pre_bias]}
@@ -779,10 +757,9 @@ def conv3d(input,
         helper.append_op(
             type=op_type, inputs=inputs, outputs=outputs, attrs=attrs)
         if bias is not None:
-            pre_act = nn.elementwise_add(pre_bias, bias, axis=channel_dim)
+            out = nn.elementwise_add(pre_bias, bias, axis=channel_dim)
         else:
-            pre_act = pre_bias
-            out = helper.append_activation(pre_act)
+            out = pre_bias
 
     return out
 
diff --git a/python/paddle/nn/layer/__init__.py b/python/paddle/nn/layer/__init__.py
index 8442bac7a305ed76b8c999783c734b753b205984..7604f1884db46fb2d8461c533a7bb6902970b332 100644
--- a/python/paddle/nn/layer/__init__.py
+++ b/python/paddle/nn/layer/__init__.py
@@ -57,9 +57,9 @@ from .common import Dropout2D  #DEFINE_ALIAS
 from .common import Dropout3D  #DEFINE_ALIAS
 from .pooling import AdaptiveAvgPool2d  #DEFINE_ALIAS
 from .pooling import AdaptiveAvgPool3d  #DEFINE_ALIAS
-from .conv import Conv2D  #DEFINE_ALIAS
+from .conv import Conv2d  #DEFINE_ALIAS
+from .conv import Conv3d  #DEFINE_ALIAS
 from .conv import ConvTranspose2d  #DEFINE_ALIAS
-from .conv import Conv3D  #DEFINE_ALIAS
 from .conv import ConvTranspose3d  #DEFINE_ALIAS
 # from .conv import TreeConv  #DEFINE_ALIAS
 # from .conv import Conv1D  #DEFINE_ALIAS
diff --git a/python/paddle/nn/layer/conv.py b/python/paddle/nn/layer/conv.py
index 2e0cfb96244d21ab84b0c6ac1a6a8dcebdfded53..32e15dea523debe5f2cbdcdc25d08a15c7ae1798 100644
--- a/python/paddle/nn/layer/conv.py
+++ b/python/paddle/nn/layer/conv.py
@@ -15,9 +15,9 @@
 # TODO: define classes of convolutional neural network
 
 __all__ = [
-    'Conv2D',
+    'Conv2d',
+    'Conv3d',
     'ConvTranspose2d',
-    'Conv3D',
     'ConvTranspose3d',
     #       'TreeConv',
     #       'Conv1D'
@@ -38,6 +38,15 @@ def _get_default_param_initializer(num_channels, filter_size):
     return Normal(0.0, std, 0)
 
 
+def _reverse_repeat_list(t, n):
+    """Reverse the order of `t` and repeat each element for `n` times.
+
+    This can be used to translate padding arg used by Conv and Pooling modules
+    to the ones used by `F.pad`.
+    """
+    return list(x for x in reversed(t) for _ in range(n))
+
+
 class _ConvNd(layers.Layer):
     def __init__(self,
                  in_channels,
@@ -63,17 +72,38 @@ class _ConvNd(layers.Layer):
         self._out_channels = out_channels
         self._data_format = data_format
 
+        valid_padding_modes = {'zeros', 'reflect', 'replicate', 'circular'}
+        if padding_mode not in valid_padding_modes:
+            raise ValueError(
+                "padding_mode must be one of {}, but got padding_mode='{}'".
+                format(valid_padding_modes, padding_mode))
+
+        if padding_mode in {'reflect', 'replicate', 'circular'
+                            } and not isinstance(padding, np.int):
+            raise TypeError(
+                "when padding_mode in ['reflect', 'replicate', 'circular'], type of padding must be int"
+            )
+
         self._stride = utils.convert_to_list(stride, dims, 'stride')
         self._dilation = utils.convert_to_list(dilation, dims, 'dilation')
         self._kernel_size = utils.convert_to_list(kernel_size, dims,
                                                   'kernel_size')
         self._padding = padding
+        self._padding_mode = padding_mode
         self.output_padding = output_padding
 
         if transposed:
             filter_shape = [self._in_channels, out_channels // groups
                             ] + self._kernel_size
         else:
+            if in_channels % groups != 0:
+                raise ValueError("in_channels must be divisible by groups.")
+
+            if padding_mode in {'reflect', 'replicate', 'circular'}:
+                _paired_padding = utils.convert_to_list(padding, 2, 'padding')
+                self._reversed_padding_repeated_twice = _reverse_repeat_list(
+                    _paired_padding, 2)
+
             filter_shape = [out_channels, in_channels // groups
                             ] + self._kernel_size
 
@@ -83,12 +113,10 @@ class _ConvNd(layers.Layer):
             attr=self._bias_attr, shape=[self._out_channels], is_bias=True)
 
 
-class Conv2D(layers.Layer):
+class Conv2d(_ConvNd):
     """
-    :alias_main: paddle.nn.Conv2D
-    :alias: paddle.nn.Conv2D,paddle.nn.layer.Conv2D,paddle.nn.layer.conv.Conv2D
-
-    This interface is used to construct a callable object of the ``Conv2D`` class.
+    This interface is used to construct a callable object of the ``Conv2d`` class.
     For more details, refer to code examples.
     The convolution2D layer calculates the output based on the input, filter
     and strides, paddings, dilations, groups parameters. Input and
@@ -120,32 +148,13 @@ class Conv2D(layers.Layer):
    * :math:`\\sigma`: Activation function.
    * :math:`Out`: Output value, the shape of :math:`Out` and :math:`X` may be different.
 
-    Example:
-
-        - Input:
-
-          Input shape: :math:`(N, C_{in}, H_{in}, W_{in})`
-
-          Filter shape: :math:`(C_{out}, C_{in}, H_f, W_f)`
-
-        - Output:
-
-          Output shape: :math:`(N, C_{out}, H_{out}, W_{out})`
-
-        Where
-
-        .. math::
-
-            H_{out}&= \\frac{(H_{in} + 2 * paddings[0] - (dilations[0] * (H_f - 1) + 1))}{strides[0]} + 1 \\\\
-            W_{out}&= \\frac{(W_{in} + 2 * paddings[1] - (dilations[1] * (W_f - 1) + 1))}{strides[1]} + 1
-
     Parameters:
-        num_channels(int): The number of channels in the input image.
-        num_filters(int): The number of filter. It is as same as the output
-            feature map.
-        filter_size (int or tuple): The filter size. If filter_size is a tuple,
-            it must contain two integers, (filter_size_H, filter_size_W).
-            Otherwise, the filter will be a square.
+        in_channels(int): The number of channels in the input image.
+        out_channels(int): The number of channels produced by convolution.
+        kernel_size (int|list|tuple): The size of convolution kernel.
+        stride (int|list|tuple, optional): The stride size. If stride is a tuple, it must
+            contain two integers, (stride_H, stride_W). Otherwise, the
+            stride_H = stride_W = stride. Default: 1.
         padding(int|str|tuple|list, optional): The padding size. Padding coule be in one of the following forms.
            1. a string in ['valid', 'same'].
            2. an int, which means each spartial dimension(depth, height, width) is zero paded by size of `padding`on both sides
            3. a list[int] or tuple[int] whose length is the number of spartial dimensions, which contains the amount of padding on each side for each spartial dimension. It has the form [pad_d1, pad_d2, ...].
            4. a list[int] or tuple[int] whose length is 2 * number of spartial dimensions. It has the form [pad_before, pad_after, pad_before, pad_after, ...] for all spartial dimensions.
            5. a list or tuple of pairs of ints. It has the form [[pad_before, pad_after], [pad_before, pad_after], ...]. Note that, the batch dimension and channel dimension are also included. Each pair of integers correspond to the amount of padding for a dimension of the input. Padding in batch dimension and channel dimension should be [0, 0] or (0, 0).
            The default value is 0.
-        stride (int or tuple, optional): The stride size. If stride is a tuple, it must
-            contain two integers, (stride_H, stride_W). Otherwise, the
-            stride_H = stride_W = stride. Default: 1.
-        dilation (int or tuple, optional): The dilation size. If dilation is a tuple, it must
+        padding_mode (str, optional): ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. Default: ``'zeros'`` .
+        dilation (int|list|tuple, optional): The dilation size. If dilation is a tuple, it must
             contain two integers, (dilation_H, dilation_W). Otherwise, the
             dilation_H = dilation_W = dilation. Default: 1.
         groups (int, optional): The groups number of the Conv2d Layer. According to grouped
             convolution in Alex Krizhevsky's Deep CNN paper: when group=2,
             the first half of the filters is only connected to the first half
             of the input channels, while the second half of the filters is only
             connected to the second half of the input channels. Default: 1.
-        param_attr (ParamAttr, optional): The parameter attribute for learnable weights(Parameter)
+        weight_attr (ParamAttr, optional): The parameter attribute for learnable weights(Parameter)
             of conv2d. If it is set to None or one attribute of ParamAttr, conv2d
             will create ParamAttr as param_attr. If the Initializer of the param_attr
             is not set, the parameter is initialized with :math:`Normal(0.0, std)`,
             and the :math:`std` is :math:`(\\frac{2.0 }{filter\\_elem\\_num})^{0.5}`. Default: None.
-        bias_attr (ParamAttr or bool, optional): The attribute for the bias of conv2d.
+        bias_attr (ParamAttr|bool, optional): The attribute for the bias of conv2d.
             If it is set to False, no bias will be added to the output units.
             If it is set to None or one attribute of ParamAttr, conv2d
             will create ParamAttr as bias_attr. If the Initializer of the bias_attr
             is not set, the bias is initialized zero. Default: None.
-        use_cudnn (bool, optional): Use cudnn kernel or not, it is valid only when the cudnn
-            library is installed. Default: True.
-        act (str, optional): Activation type, if it is set to None, activation is not appended.
-            Default: None.
         data_format (str, optional): Data format that specifies the layout of input.
             It can be "NCHW" or "NHWC". Default: "NCHW".
-        dtype (str, optional): Data type, it can be "float32" or "float64". Default: "float32".
 
     Attribute:
         **weight** (Parameter): the learnable weights of filter of this layer.
 
         **bias** (Parameter or None): the learnable bias of this layer.
 
-    Returns:
-        None
-
-    Raises:
-        ValueError: if ``use_cudnn`` is not a bool value.
+    Shape:
+
+        - x: :math:`(N, C_{in}, H_{in}, W_{in})`
+
+        - output: :math:`(N, C_{out}, H_{out}, W_{out})`
+
+        Where
+
+        .. math::
+
+            H_{out}&= \\frac{(H_{in} + 2 * paddings[0] - (dilations[0] * (kernel_size[0] - 1) + 1))}{strides[0]} + 1 \\\\
+            W_{out}&= \\frac{(W_{in} + 2 * paddings[1] - (dilations[1] * (kernel_size[1] - 1) + 1))}{strides[1]} + 1
 
     Examples:
         .. code-block:: python
 
           import numpy as np
-          from paddle import fluid
-          import paddle.fluid.dygraph as dg
-          from paddle import nn
+
+          import paddle
+          import paddle.nn as nn
 
           x = np.random.uniform(-1, 1, (2, 4, 8, 8)).astype('float32')
-          place = fluid.CPUPlace()
-          with dg.guard(place):
-              x_var = dg.to_variable(x)
-              conv = nn.Conv2D(4, 6, (3, 3))
-              y_var = conv(x_var)
-              y_np = y_var.numpy()
-              print(y_np.shape)
+
+          paddle.disable_static()
+
+          x_var = paddle.to_tensor(x)
+          conv = nn.Conv2d(4, 6, (3, 3))
+          y_var = conv(x_var)
+          y_np = y_var.numpy()
+          print(y_np.shape)  # (2, 6, 6, 6)
     """
 
     def __init__(self,
-                 num_channels,
-                 num_filters,
-                 filter_size,
-                 padding=0,
+                 in_channels,
+                 out_channels,
+                 kernel_size,
                  stride=1,
+                 padding=0,
                  dilation=1,
                  groups=1,
-                 param_attr=None,
+                 padding_mode='zeros',
+                 weight_attr=None,
                  bias_attr=None,
-                 use_cudnn=True,
-                 act=None,
-                 data_format="NCHW",
-                 dtype='float32'):
-        super(Conv2D, self).__init__()
-        assert param_attr is not False, "param_attr should not be False here."
-        self._num_channels = num_channels
-        self._num_filters = num_filters
-        self._groups = groups
-        if num_channels % groups != 0:
-            raise ValueError("num_channels must be divisible by groups.")
-        self._act = act
-        self._data_format = data_format
-        self._dtype = dtype
-        if not isinstance(use_cudnn, bool):
-            raise ValueError("use_cudnn should be True or False")
-        self._use_cudnn = use_cudnn
-
-        self._filter_size = utils.convert_to_list(filter_size, 2, 'filter_size')
-        self._stride = utils.convert_to_list(stride, 2, 'stride')
-        self._dilation = utils.convert_to_list(dilation, 2, 'dilation')
-        channel_last = (data_format == "NHWC")
-        self._padding = padding  # leave it to F.conv2d
-
-        self._param_attr = param_attr
-        self._bias_attr = bias_attr
-
-        num_filter_channels = num_channels // groups
-        filter_shape = [self._num_filters, num_filter_channels
-                        ] + self._filter_size
+                 data_format="NCHW"):
+        super(Conv2d, self).__init__(
+            in_channels,
+            out_channels,
+            kernel_size,
+            False,
+            2,
+            stride=stride,
+            padding=padding,
+            padding_mode=padding_mode,
+            dilation=dilation,
+            groups=groups,
+            weight_attr=weight_attr,
+            bias_attr=bias_attr,
+            data_format=data_format)
 
-        self.weight = self.create_parameter(
-            attr=self._param_attr,
-            shape=filter_shape,
-            dtype=self._dtype,
-            default_initializer=_get_default_param_initializer(
-                self._num_channels, filter_shape))
-        self.bias = self.create_parameter(
-            attr=self._bias_attr,
-            shape=[self._num_filters],
-            dtype=self._dtype,
-            is_bias=True)
+    def forward(self, x):
+        if self._padding_mode != 'zeros':
+            x = F.pad(x,
+                      self._reversed_padding_repeated_twice,
+                      mode=self._padding_mode,
+                      data_format=self._data_format)
+            return F.conv2d(
+                x,
+                self.weight,
+                bias=self.bias,
+                stride=self._stride,
+                dilation=self._dilation,
+                groups=self._groups,
+                data_format=self._data_format)
 
-    def forward(self, input):
         out = F.conv2d(
-            input,
+            x,
             self.weight,
             bias=self.bias,
+            padding=self._padding,
             stride=self._stride,
             dilation=self._dilation,
             groups=self._groups,
-            use_cudnn=self._use_cudnn,
-            act=self._act,
             data_format=self._data_format)
         return out
 
@@ -458,14 +454,12 @@ class ConvTranspose2d(_ConvNd):
         return out
 
 
-class Conv3D(layers.Layer):
+class Conv3d(_ConvNd):
     """
-    :alias_main: paddle.nn.Conv3D
-    :alias: paddle.nn.Conv3D,paddle.nn.layer.Conv3D,paddle.nn.layer.conv.Conv3D
-
-    **Convlution3D Layer**
+    **Convolution3d Layer**
 
-    The convolution3D layer calculates the output based on the input, filter
+    The convolution3d layer calculates the output based on the input, filter
     and strides, paddings, dilations, groups parameters. Input(Input) and
     Output(Output) are multidimensional tensors with a shape of
     :math:`[N, C, D, H, W]` . Where N is batch size, C is the number of
@@ -490,33 +484,11 @@ class Conv3D(layers.Layer):
    * :math:`\\sigma`: Activation function.
    * :math:`Out`: Output value, the shape of :math:`Out` and :math:`X` may be different.
 
-    Example:
-
-        - Input:
-
-          Input shape: :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})`
-
-          Filter shape: :math:`(C_{out}, C_{in}, D_f, H_f, W_f)`
-
-        - Output:
-          Output shape: :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})`
-
-        Where
-
-        .. math::
-
-            D_{out}&= \\frac{(D_{in} + 2 * paddings[0] - (dilations[0] * (D_f - 1) + 1))}{strides[0]} + 1 \\\\
-            H_{out}&= \\frac{(H_{in} + 2 * paddings[1] - (dilations[1] * (H_f - 1) + 1))}{strides[1]} + 1 \\\\
-            W_{out}&= \\frac{(W_{in} + 2 * paddings[2] - (dilations[2] * (W_f - 1) + 1))}{strides[2]} + 1
-
     Parameters:
-        num_channels(int): The number of channels in the input image.
-        num_filters(int): The number of filter. It is as same as the output image channel.
-        filter_size (int|tuple, optional): The filter size. If filter_size is a tuple,
-            it must contain three integers, (filter_size_D, filter_size_H, filter_size_W).
-            Otherwise, the filter will be a square, filter_size_depth = filter_size_height
-            = filter_size_width = filter_size.
-        stride (int|tuple, optional): The stride size. If stride is a tuple, it must
+        in_channels(int): The number of input channels in the input image.
+        out_channels(int): The number of output channels produced by the convolution.
+        kernel_size (int|list|tuple): The size of the convolving kernel.
+        stride (int|list|tuple, optional): The stride size. If stride is a tuple, it must
             contain three integers, (stride_D, stride_H, stride_W). Otherwise, the
             stride_D = stride_H = stride_W = stride. The default value is 1.
         padding (int|str|tuple|list, optional): The padding size. Padding coule be in one of the following forms.
            1. a string in ['valid', 'same'].
            2. an int, which means each spartial dimension(depth, height, width) is zero paded by size of `padding`on both sides
            3. a list[int] or tuple[int] whose length is the number of spartial dimensions, which contains the amount of padding on each side for each spartial dimension. It has the form [pad_d1, pad_d2, ...].
            4. a list[int] or tuple[int] whose length is 2 * number of spartial dimensions. It has the form [pad_before, pad_after, pad_before, pad_after, ...] for all spartial dimensions.
            5. a list or tuple of pairs of ints. It has the form [[pad_before, pad_after], [pad_before, pad_after], ...]. Note that, the batch dimension and channel dimension are also included. Each pair of integers correspond to the amount of padding for a dimension of the input. Padding in batch dimension and channel dimension should be [0, 0] or (0, 0).
            The default value is 0.
-        dilation (int|tuple, optional): The dilation size. If dilation is a tuple, it must
+        dilation (int|list|tuple, optional): The dilation size. If dilation is a tuple, it must
             contain three integers, (dilation_D, dilation_H, dilation_W). Otherwise, the
             dilation_D = dilation_H = dilation_W = dilation. The default value is 1.
         groups (int, optional): The groups number of the Conv3d Layer. According to grouped
             convolution in Alex Krizhevsky's Deep CNN paper: when group=2,
             the first half of the filters is only connected to the first half
             of the input channels, while the second half of the filters is only
             connected to the second half of the input channels. The default value is 1.
-        param_attr (ParamAttr, optional): The parameter attribute for learnable parameters/weights
+        padding_mode (str, optional): ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. Default: ``'zeros'``.
+        weight_attr (ParamAttr, optional): The parameter attribute for learnable parameters/weights
             of conv3d. If it is set to None or one attribute of ParamAttr, conv3d
             will create ParamAttr as param_attr. If it is set to None, the parameter
             is initialized with :math:`Normal(0.0, std)`, and the :math:`std` is
@@ -544,21 +517,27 @@ class Conv3D(layers.Layer):
             If it is set to None or one attribute of ParamAttr, conv3d
             will create ParamAttr as bias_attr. If the Initializer of the bias_attr
             is not set, the bias is initialized zero. The default value is None.
-        use_cudnn (bool, optional): Use cudnn kernel or not, it is valid only when the cudnn
-            library is installed. The default value is True.
-        act (str, optional): Activation type, if it is set to None, activation is not appended.
-            The default value is None.
         data_format (str, optional): Data format that specifies the layout of input.
             It can be "NCDHW" or "NDHWC". Default: "NCDHW".
 
     Attribute:
         **weight** (Parameter): the learnable weights of filters of this layer.
 
         **bias** (Parameter): the learnable bias of this layer.
 
-    Returns:
-        None.
+    Shape:
+
+        - x: :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})`
+
+        - output: :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})`
+
+        Where
+
+        .. math::
+
+            D_{out}&= \\frac{(D_{in} + 2 * paddings[0] - (dilations[0] * (D_f - 1) + 1))}{strides[0]} + 1 \\\\
+            H_{out}&= \\frac{(H_{in} + 2 * paddings[1] - (dilations[1] * (H_f - 1) + 1))}{strides[1]} + 1 \\\\
+            W_{out}&= \\frac{(W_{in} + 2 * paddings[2] - (dilations[2] * (W_f - 1) + 1))}{strides[2]} + 1
 
     Raises:
         ValueError: If the shapes of input, filter_size, stride, padding and
@@ -568,85 +547,73 @@ class Conv3D(layers.Layer):
         .. code-block:: python
 
           import numpy as np
-          from paddle import fluid
-          import paddle.fluid.dygraph as dg
-          from paddle import nn
+
+          import paddle
+          import paddle.nn as nn
 
           x = np.random.uniform(-1, 1, (2, 4, 8, 8, 8)).astype('float32')
-          place = fluid.CPUPlace()
-          with dg.guard(place):
-              x_var = dg.to_variable(x)
-              conv = nn.Conv3D(4, 6, (3, 3, 3))
-              y_var = conv(x_var)
-              y_np = y_var.numpy()
-              print(y_np.shape)
+
+          paddle.disable_static()
+
+          x_var = paddle.to_tensor(x)
+          conv = nn.Conv3d(4, 6, (3, 3, 3))
+          y_var = conv(x_var)
+          y_np = y_var.numpy()
+          print(y_np.shape)  # (2, 6, 6, 6, 6)
     """
 
     def __init__(self,
-                 num_channels,
-                 num_filters,
-                 filter_size,
+                 in_channels,
+                 out_channels,
+                 kernel_size,
                  padding=0,
                  stride=1,
                  dilation=1,
                  groups=1,
-                 param_attr=None,
+                 padding_mode='zeros',
+                 weight_attr=None,
                  bias_attr=None,
-                 use_cudnn=True,
-                 act=None,
-                 data_format="NCDHW",
-                 dtype='float32'):
-        super(Conv3D, self).__init__()
-        assert param_attr is not False, "param_attr should not be False here."
-        self._num_channels = num_channels
-        self._num_filters = num_filters
-        self._groups = groups
-        self._act = act
-        self._use_cudnn = use_cudnn
-        self._dtype = dtype
-        self._data_format = data_format
-
-        self._stride = utils.convert_to_list(stride, 3, 'stride')
-        self._dilation = utils.convert_to_list(dilation, 3, 'dilation')
-        self._filter_size = utils.convert_to_list(filter_size, 3, 'filter_size')
-        channel_last = (data_format == "NDHWC")
-        self._padding = padding
-
-        self._param_attr = param_attr
-        self._bias_attr = bias_attr
-
-        if num_channels % groups != 0:
-            raise ValueError("num_channels must be divisible by groups.")
-        num_filter_channels = num_channels // groups
-
-        filter_shape = [num_filters, num_filter_channels] + self._filter_size
-
-        self.weight = self.create_parameter(
-            attr=self._param_attr,
-            shape=filter_shape,
-            dtype=self._dtype,
-            default_initializer=_get_default_param_initializer(
-                self._num_channels, self._filter_size))
+                 data_format="NCDHW"):
+        super(Conv3d, self).__init__(
+            in_channels,
+            out_channels,
+            kernel_size,
+            False,
+            3,
+            stride=stride,
+            padding=padding,
+            padding_mode=padding_mode,
+            dilation=dilation,
+            groups=groups,
+            weight_attr=weight_attr,
+            bias_attr=bias_attr,
+            data_format=data_format)
 
-        self.bias = self.create_parameter(
-            attr=self._bias_attr,
-            shape=[self._num_filters],
-            dtype=self._dtype,
-            is_bias=True)
+    def forward(self, x):
+        if self._padding_mode != 'zeros':
+            x = F.pad(x,
+                      self._reversed_padding_repeated_twice,
+                      mode=self._padding_mode,
+                      data_format=self._data_format)
+            return F.conv3d(
+                x,
+                self.weight,
+                bias=self.bias,
+                stride=self._stride,
+                dilation=self._dilation,
+                groups=self._groups,
+                data_format=self._data_format)
 
-    def forward(self, input):
         out = F.conv3d(
-            input,
+            x,
             self.weight,
             bias=self.bias,
             padding=self._padding,
             stride=self._stride,
             dilation=self._dilation,
             groups=self._groups,
-            use_cudnn=self._use_cudnn,
-            act=self._act,
             data_format=self._data_format)
         return out
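
Usage note (editor's sketch, not part of the patch): the snippet below assumes only the post-patch API shown in this diff — the lowercase `Conv2d`/`Conv3d` layer names taking `in_channels`/`out_channels`/`kernel_size`, the removed `act` and `use_cudnn` parameters (activations become separate calls, and cuDNN is now selected automatically when Paddle is compiled with CUDA and a cuDNN version is found), and the new `padding_mode` option, which the layers translate into an `F.pad` call before the convolution.

    import numpy as np
    import paddle
    import paddle.nn as nn
    import paddle.nn.functional as F

    paddle.disable_static()

    x = paddle.to_tensor(np.random.randn(2, 4, 8, 8).astype(np.float32))

    # Layer API: in_channels/out_channels/kernel_size replace the old
    # num_channels/num_filters/filter_size names; padding_mode is new and
    # requires an int padding.
    conv = nn.Conv2d(4, 6, (3, 3), padding=1, padding_mode='reflect')
    y = conv(x)  # shape: (2, 6, 8, 8)

    # The old Conv2D(..., act="sigmoid") becomes an explicit call:
    y = F.sigmoid(y)

    # Functional API: no `act` or `use_cudnn` keyword arguments either.
    w = paddle.to_tensor(np.random.randn(6, 4, 3, 3).astype(np.float32))
    z = F.conv2d(x, w, stride=1, padding=0)  # shape: (2, 6, 6, 6)

The same pattern applies to `nn.Conv3d`/`F.conv3d` with 5-D inputs.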