未验证 提交 3a9417f4 编写于 作者: L LielinJiang 提交者: GitHub

Update 2.0 convolution api (#26491)

* update Conv2d Conv3d conv2d conv3d api
上级 7c42f056
......@@ -26,7 +26,7 @@ import paddle.fluid as fluid
import paddle.fluid.dygraph as dygraph
from paddle.fluid import core
from paddle.fluid.optimizer import SGDOptimizer
from paddle.nn import Conv2D, Pool2D, Linear, SyncBatchNorm
from paddle.nn import Conv2d, Pool2D, Linear, SyncBatchNorm
from paddle.fluid.dygraph.base import to_variable
from test_dist_base import runtime_main, TestParallelDyGraphRunnerBase
......@@ -42,26 +42,24 @@ class TestLayer(fluid.dygraph.Layer):
act=None):
super(TestLayer, self).__init__()
self._conv = Conv2D(
num_channels=num_channels,
num_filters=num_filters,
filter_size=filter_size,
self._conv = Conv2d(
in_channels=num_channels,
out_channels=num_filters,
kernel_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
groups=groups,
act=None,
bias_attr=False)
self._sync_batch_norm = SyncBatchNorm(num_filters)
self._conv2 = Conv2D(
num_channels=num_filters,
num_filters=num_filters,
filter_size=filter_size,
self._conv2 = Conv2d(
in_channels=num_filters,
out_channels=num_filters,
kernel_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
groups=groups,
act=None,
bias_attr=False)
self._sync_batch_norm2 = SyncBatchNorm(
......
......@@ -20,6 +20,10 @@ import paddle.fluid.initializer as I
import unittest
def _reverse_repeat_list(t, n):
return list(x for x in reversed(t) for _ in range(n))
class Conv2DTestCase(unittest.TestCase):
def __init__(self,
methodName='runTest',
......@@ -29,12 +33,11 @@ class Conv2DTestCase(unittest.TestCase):
num_filters=8,
filter_size=3,
padding=0,
padding_mode='zeros',
stride=1,
dilation=1,
groups=1,
act=None,
no_bias=False,
use_cudnn=True,
data_format="NCHW",
dtype="float32"):
super(Conv2DTestCase, self).__init__(methodName)
......@@ -45,12 +48,16 @@ class Conv2DTestCase(unittest.TestCase):
self.filter_size = filter_size
self.padding = padding
if padding_mode in {'reflect', 'replicate', 'circular'}:
_paired_padding = fluid.layers.utils.convert_to_list(padding, 2,
'padding')
self._reversed_padding_repeated_twice = _reverse_repeat_list(
_paired_padding, 2)
self.padding_mode = padding_mode
self.stride = stride
self.dilation = dilation
self.groups = groups
self.act = act
self.no_bias = no_bias
self.use_cudnn = use_cudnn
self.data_format = data_format
self.dtype = dtype
......@@ -91,19 +98,27 @@ class Conv2DTestCase(unittest.TestCase):
bias_attr = False
else:
bias_attr = I.NumpyArrayInitializer(self.bias)
if self.padding_mode != 'zeros':
x_var = F.pad(x_var,
self._reversed_padding_repeated_twice,
mode=self.padding_mode,
data_format=self.data_format)
padding = 0
else:
padding = self.padding
y_var = fluid.layers.conv2d(
x_var,
self.num_filters,
self.filter_size,
padding=self.padding,
padding=padding,
stride=self.stride,
dilation=self.dilation,
groups=self.groups,
param_attr=weight_attr,
bias_attr=bias_attr,
use_cudnn=self.use_cudnn,
act=self.act,
data_format=self.data_format)
feed_dict = {"input": self.input}
exe = fluid.Executor(place)
exe.run(start)
......@@ -122,16 +137,24 @@ class Conv2DTestCase(unittest.TestCase):
"weight", self.weight_shape, dtype=self.dtype)
b_var = fluid.data(
"bias", (self.num_filters, ), dtype=self.dtype)
if self.padding_mode != 'zeros':
x_var = F.pad(x_var,
self._reversed_padding_repeated_twice,
mode=self.padding_mode,
data_format=self.data_format)
padding = 0
else:
padding = self.padding
y_var = F.conv2d(
x_var,
w_var,
b_var if not self.no_bias else None,
padding=self.padding,
padding=padding,
stride=self.stride,
dilation=self.dilation,
groups=self.groups,
act=self.act,
use_cudnn=self.use_cudnn,
data_format=self.data_format)
feed_dict = {"input": self.input, "weight": self.weight}
if self.bias is not None:
......@@ -143,18 +166,16 @@ class Conv2DTestCase(unittest.TestCase):
def paddle_nn_layer(self):
x_var = dg.to_variable(self.input)
conv = nn.Conv2D(
conv = nn.Conv2d(
self.num_channels,
self.num_filters,
self.filter_size,
padding=self.padding,
padding_mode=self.padding_mode,
stride=self.stride,
dilation=self.dilation,
groups=self.groups,
act=self.act,
use_cudnn=self.use_cudnn,
data_format=self.data_format,
dtype=self.dtype)
data_format=self.data_format)
conv.weight.set_value(self.weight)
if not self.no_bias:
conv.bias.set_value(self.bias)
......@@ -198,7 +219,7 @@ def add_cases(suite):
methodName='runTest', stride=2, dilation=(2, 1)))
suite.addTest(
Conv2DTestCase(
methodName='runTest', padding="same", no_bias=True, act="sigmoid"))
methodName='runTest', padding="same", no_bias=True))
suite.addTest(
Conv2DTestCase(
methodName='runTest', filter_size=(3, 3), padding='valid'))
......@@ -222,15 +243,28 @@ def add_cases(suite):
num_filters=6,
num_channels=3,
groups=3,
use_cudnn=False,
act="sigmoid",
padding="valid"))
suite.addTest(
Conv2DTestCase(
methodName='runTest',
filter_size=(3, 3),
padding=1,
padding_mode='reflect'))
suite.addTest(
Conv2DTestCase(
methodName='runTest',
filter_size=(3, 3),
padding=1,
padding_mode='replicate'))
suite.addTest(
Conv2DTestCase(
methodName='runTest',
filter_size=(3, 3),
padding=1,
padding_mode='circular'))
def add_error_cases(suite):
suite.addTest(
Conv2DErrorTestCase(
methodName='runTest', use_cudnn="not_valid"))
suite.addTest(
Conv2DErrorTestCase(
methodName='runTest', num_channels=5, groups=2))
......
......@@ -32,9 +32,7 @@ class Conv3DTestCase(unittest.TestCase):
stride=1,
dilation=1,
groups=1,
act=None,
no_bias=False,
use_cudnn=True,
data_format="NCDHW",
dtype="float32"):
super(Conv3DTestCase, self).__init__(methodName)
......@@ -48,9 +46,7 @@ class Conv3DTestCase(unittest.TestCase):
self.stride = stride
self.dilation = dilation
self.groups = groups
self.act = act
self.no_bias = no_bias
self.use_cudnn = use_cudnn
self.data_format = data_format
self.dtype = dtype
......@@ -101,8 +97,6 @@ class Conv3DTestCase(unittest.TestCase):
groups=self.groups,
param_attr=weight_attr,
bias_attr=bias_attr,
use_cudnn=self.use_cudnn,
act=self.act,
data_format=self.data_format)
feed_dict = {"input": self.input}
exe = fluid.Executor(place)
......@@ -130,8 +124,6 @@ class Conv3DTestCase(unittest.TestCase):
stride=self.stride,
dilation=self.dilation,
groups=self.groups,
act=self.act,
use_cudnn=self.use_cudnn,
data_format=self.data_format)
feed_dict = {"input": self.input, "weight": self.weight}
if self.bias is not None:
......@@ -143,7 +135,7 @@ class Conv3DTestCase(unittest.TestCase):
def paddle_nn_layer(self):
x_var = dg.to_variable(self.input)
conv = nn.Conv3D(
conv = nn.Conv3d(
self.num_channels,
self.num_filters,
self.filter_size,
......@@ -151,10 +143,7 @@ class Conv3DTestCase(unittest.TestCase):
stride=self.stride,
dilation=self.dilation,
groups=self.groups,
act=self.act,
use_cudnn=self.use_cudnn,
data_format=self.data_format,
dtype=self.dtype)
data_format=self.data_format)
conv.weight.set_value(self.weight)
if not self.no_bias:
conv.bias.set_value(self.bias)
......@@ -225,15 +214,10 @@ def add_cases(suite):
num_filters=6,
num_channels=3,
groups=3,
use_cudnn=False,
act="sigmoid",
padding="valid"))
def add_error_cases(suite):
suite.addTest(
Conv3DErrorTestCase(
methodName='runTest', use_cudnn="not_valid"))
suite.addTest(
Conv3DErrorTestCase(
methodName='runTest', num_channels=5, groups=2))
......
......@@ -117,7 +117,7 @@ class TestDygraphWeightNorm(unittest.TestCase):
def test_check_output(self):
fluid.enable_imperative()
linear = paddle.nn.Conv2D(2, 3, 3)
linear = paddle.nn.Conv2d(2, 3, 3)
before_weight = linear.weight.numpy()
if self.dim == None:
self.dim = -1
......@@ -169,7 +169,7 @@ class TestDygraphRemoveWeightNorm(unittest.TestCase):
def test_check_output(self):
fluid.enable_imperative()
linear = paddle.nn.Conv2D(2, 3, 3)
linear = paddle.nn.Conv2d(2, 3, 3)
before_weight = linear.weight
wn = weight_norm(linear, dim=self.dim)
rwn = remove_weight_norm(linear)
......
......@@ -37,7 +37,6 @@ class TestFunctionalConv2D(TestCase):
self.groups = 1
self.no_bias = False
self.act = "sigmoid"
self.use_cudnn = True
self.data_format = "NHWC"
def prepare(self):
......@@ -88,7 +87,6 @@ class TestFunctionalConv2D(TestCase):
param_attr=I.NumpyArrayInitializer(self.weight),
bias_attr=False
if self.no_bias else I.NumpyArrayInitializer(self.bias),
use_cudnn=self.use_cudnn,
act=self.act,
data_format=self.data_format)
exe = fluid.Executor(self.place)
......@@ -121,9 +119,11 @@ class TestFunctionalConv2D(TestCase):
stride=self.stride,
dilation=self.dilation,
groups=self.groups,
act=self.act,
data_format=self.data_format,
use_cudnn=self.use_cudnn)
data_format=self.data_format)
if self.act == 'sigmoid':
y = F.sigmoid(y)
exe = fluid.Executor(self.place)
exe.run(start)
feed_dict = {"input": self.input, "weight": self.weight}
......@@ -144,10 +144,12 @@ class TestFunctionalConv2D(TestCase):
padding=self.padding,
stride=self.stride,
dilation=self.dilation,
act=self.act,
groups=self.groups,
data_format=self.data_format,
use_cudnn=self.use_cudnn)
data_format=self.data_format)
if self.act == 'sigmoid':
y = F.sigmoid(y)
out = y.numpy()
return out
......@@ -185,7 +187,6 @@ class TestFunctionalConv2DError(TestCase):
self.groups = 1
self.no_bias = False
self.act = "sigmoid"
self.use_cudnn = True
self.data_format = "NHWC"
def test_exception(self):
......@@ -228,9 +229,7 @@ class TestFunctionalConv2DError(TestCase):
stride=self.stride,
dilation=self.dilation,
groups=self.groups,
act=self.act,
data_format=self.data_format,
use_cudnn=self.use_cudnn)
data_format=self.data_format)
class TestFunctionalConv2DCase2(TestFunctionalConv2D):
......@@ -383,21 +382,6 @@ class TestFunctionalConv2DErrorCase4(TestFunctionalConv2DError):
self.data_format = "NCHW"
class TestFunctionalConv2DErrorCase6(TestFunctionalConv2DError):
def setUp(self):
self.in_channels = 3
self.out_channels = 5
self.filter_shape = 3
self.padding = "same"
self.stride = 1
self.dilation = 1
self.groups = 1
self.no_bias = False
self.act = "sigmoid"
self.use_cudnn = "not_valid"
self.data_format = "NCHW"
class TestFunctionalConv2DErrorCase7(TestFunctionalConv2DError):
def setUp(self):
self.in_channels = 3
......
......@@ -37,7 +37,6 @@ class TestFunctionalConv3D(TestCase):
self.groups = 1
self.no_bias = False
self.act = "sigmoid"
self.use_cudnn = True
self.data_format = "NDHWC"
def prepare(self):
......@@ -88,7 +87,6 @@ class TestFunctionalConv3D(TestCase):
param_attr=I.NumpyArrayInitializer(self.weight),
bias_attr=False
if self.no_bias else I.NumpyArrayInitializer(self.bias),
use_cudnn=self.use_cudnn,
act=self.act,
data_format=self.data_format)
exe = fluid.Executor(self.place)
......@@ -121,9 +119,11 @@ class TestFunctionalConv3D(TestCase):
stride=self.stride,
dilation=self.dilation,
groups=self.groups,
act=self.act,
data_format=self.data_format,
use_cudnn=self.use_cudnn)
data_format=self.data_format)
if self.act == 'sigmoid':
y = F.sigmoid(y)
exe = fluid.Executor(self.place)
exe.run(start)
feed_dict = {"input": self.input, "weight": self.weight}
......@@ -144,10 +144,12 @@ class TestFunctionalConv3D(TestCase):
padding=self.padding,
stride=self.stride,
dilation=self.dilation,
act=self.act,
groups=self.groups,
data_format=self.data_format,
use_cudnn=self.use_cudnn)
data_format=self.data_format)
if self.act == 'sigmoid':
y = F.sigmoid(y)
out = y.numpy()
return out
......@@ -185,7 +187,6 @@ class TestFunctionalConv3DError(TestCase):
self.groups = 1
self.no_bias = False
self.act = "sigmoid"
self.use_cudnn = True
self.data_format = "NDHWC"
def test_exception(self):
......@@ -228,9 +229,10 @@ class TestFunctionalConv3DError(TestCase):
stride=self.stride,
dilation=self.dilation,
groups=self.groups,
act=self.act,
data_format=self.data_format,
use_cudnn=self.use_cudnn)
data_format=self.data_format)
if self.act == 'sigmoid':
y = F.sigmoid(y)
class TestFunctionalConv3DCase2(TestFunctionalConv3D):
......@@ -244,7 +246,6 @@ class TestFunctionalConv3DCase2(TestFunctionalConv3D):
self.groups = 1
self.no_bias = False
self.act = "sigmoid"
self.use_cudnn = True
self.data_format = "NDHWC"
......@@ -259,7 +260,6 @@ class TestFunctionalConv3DCase3(TestFunctionalConv3D):
self.groups = 1
self.no_bias = False
self.act = "sigmoid"
self.use_cudnn = True
self.data_format = "NDHWC"
......@@ -274,7 +274,6 @@ class TestFunctionalConv3DCase4(TestFunctionalConv3D):
self.groups = 1
self.no_bias = False
self.act = "sigmoid"
self.use_cudnn = True
self.data_format = "NDHWC"
......@@ -289,7 +288,6 @@ class TestFunctionalConv3DCase5(TestFunctionalConv3D):
self.groups = 1
self.no_bias = False
self.act = "sigmoid"
self.use_cudnn = True
self.data_format = "NDHWC"
......@@ -304,7 +302,6 @@ class TestFunctionalConv3DCase6(TestFunctionalConv3D):
self.groups = 1
self.no_bias = False
self.act = "sigmoid"
self.use_cudnn = True
self.data_format = "NCDHW"
......@@ -319,7 +316,6 @@ class TestFunctionalConv3DCase7(TestFunctionalConv3D):
self.groups = 2
self.no_bias = False
self.act = "sigmoid"
self.use_cudnn = True
self.data_format = "NCDHW"
......@@ -349,7 +345,6 @@ class TestFunctionalConv3DErrorCase2(TestFunctionalConv3DError):
self.groups = 1
self.no_bias = False
self.act = "sigmoid"
self.use_cudnn = False
self.data_format = "NCDHW"
......@@ -364,7 +359,6 @@ class TestFunctionalConv3DErrorCase3(TestFunctionalConv3DError):
self.groups = 2
self.no_bias = False
self.act = "sigmoid"
self.use_cudnn = False
self.data_format = "not_valid"
......@@ -379,22 +373,6 @@ class TestFunctionalConv3DErrorCase4(TestFunctionalConv3DError):
self.groups = 2
self.no_bias = False
self.act = "sigmoid"
self.use_cudnn = False
self.data_format = "NCDHW"
class TestFunctionalConv3DErrorCase6(TestFunctionalConv3DError):
def setUp(self):
self.in_channels = 3
self.out_channels = 5
self.filter_shape = 3
self.padding = "same"
self.stride = 1
self.dilation = 1
self.groups = 1
self.no_bias = False
self.act = "sigmoid"
self.use_cudnn = "not_valid"
self.data_format = "NCDHW"
......@@ -409,7 +387,6 @@ class TestFunctionalConv3DErrorCase7(TestFunctionalConv3DError):
self.groups = 1
self.no_bias = False
self.act = "sigmoid"
self.use_cudnn = True
self.data_format = "not_valid"
......@@ -424,7 +401,6 @@ class TestFunctionalConv3DErrorCase8(TestFunctionalConv3DError):
self.groups = 1
self.no_bias = False
self.act = "sigmoid"
self.use_cudnn = True
self.data_format = "NCDHW"
......@@ -439,7 +415,6 @@ class TestFunctionalConv3DErrorCase9(TestFunctionalConv3DError):
self.groups = 1
self.no_bias = False
self.act = "sigmoid"
self.use_cudnn = False
self.data_format = "NCDHW"
......@@ -454,7 +429,6 @@ class TestFunctionalConv3DErrorCase10(TestFunctionalConv3DError):
self.groups = 2
self.no_bias = False
self.act = "sigmoid"
self.use_cudnn = False
self.data_format = "NDHWC"
......
......@@ -28,11 +28,11 @@ class LeNetDygraph(fluid.dygraph.Layer):
super(LeNetDygraph, self).__init__()
self.num_classes = num_classes
self.features = nn.Sequential(
nn.Conv2D(
nn.Conv2d(
1, 6, 3, stride=1, padding=1),
nn.ReLU(),
nn.Pool2D(2, 'max', 2),
nn.Conv2D(
nn.Conv2d(
6, 16, 5, stride=1, padding=0),
nn.ReLU(),
nn.Pool2D(2, 'max', 2))
......@@ -61,7 +61,7 @@ def init_weights(layer):
new_bias = paddle.fill_constant(
layer.bias.shape, layer.bias.dtype, value=-0.1)
layer.bias.set_value(new_bias)
elif type(layer) == nn.Conv2D:
elif type(layer) == nn.Conv2d:
new_weight = paddle.fill_constant(
layer.weight.shape, layer.weight.dtype, value=0.7)
layer.weight.set_value(new_weight)
......@@ -81,7 +81,7 @@ class TestLayerApply(unittest.TestCase):
if type(layer) == nn.Linear:
np.testing.assert_allclose(layer.weight.numpy(), 0.9)
np.testing.assert_allclose(layer.bias.numpy(), -0.1)
elif type(layer) == nn.Conv2D:
elif type(layer) == nn.Conv2d:
np.testing.assert_allclose(layer.weight.numpy(), 0.7)
np.testing.assert_allclose(layer.bias.numpy(), -0.2)
......
......@@ -27,11 +27,11 @@ class LeNetDygraph(fluid.dygraph.Layer):
def __init__(self):
super(LeNetDygraph, self).__init__()
self.features = nn.Sequential(
nn.Conv2D(
nn.Conv2d(
1, 6, 3, stride=1, padding=1),
nn.ReLU(),
nn.Pool2D(2, 'max', 2),
nn.Conv2D(
nn.Conv2d(
6, 16, 5, stride=1, padding=0),
nn.ReLU(),
nn.Pool2D(2, 'max', 2))
......
......@@ -26,7 +26,7 @@ paddle.manual_seed(SEED)
class Generator(fluid.dygraph.Layer):
def __init__(self):
super(Generator, self).__init__()
self.conv1 = paddle.nn.Conv2D(3, 3, 3, 1)
self.conv1 = paddle.nn.Conv2d(3, 3, 3, padding=1)
def forward(self, x):
x = self.conv1(x)
......@@ -37,7 +37,7 @@ class Generator(fluid.dygraph.Layer):
class Discriminator(fluid.dygraph.Layer):
def __init__(self):
super(Discriminator, self).__init__()
self.convd = paddle.nn.Conv2D(6, 3, 1)
self.convd = paddle.nn.Conv2d(6, 3, 1)
def forward(self, x):
x = self.convd(x)
......
......@@ -23,7 +23,7 @@ import shutil
import tempfile
from paddle import fluid
from paddle.nn import Conv2D, Pool2D, Linear, ReLU, Sequential
from paddle.nn import Conv2d, Pool2D, Linear, ReLU, Sequential
from paddle.fluid.dygraph.base import to_variable
import paddle.incubate.hapi as hapi
......@@ -40,11 +40,11 @@ class LeNetDygraph(fluid.dygraph.Layer):
super(LeNetDygraph, self).__init__()
self.num_classes = num_classes
self.features = Sequential(
Conv2D(
Conv2d(
1, 6, 3, stride=1, padding=1),
ReLU(),
Pool2D(2, 'max', 2),
Conv2D(
Conv2d(
6, 16, 5, stride=1, padding=0),
ReLU(),
Pool2D(2, 'max', 2))
......
......@@ -22,7 +22,7 @@ import shutil
import tempfile
from paddle import fluid
from paddle.nn import Conv2D, Pool2D, Linear, ReLU, Sequential
from paddle.nn import Conv2d, Pool2D, Linear, ReLU, Sequential
from paddle.incubate.hapi.utils import uncombined_weight_to_state_dict
......@@ -32,11 +32,11 @@ class LeNetDygraph(fluid.dygraph.Layer):
super(LeNetDygraph, self).__init__()
self.num_classes = num_classes
self.features = Sequential(
Conv2D(
Conv2d(
1, 6, 3, stride=1, padding=1),
ReLU(),
Pool2D(2, 'max', 2),
Conv2D(
Conv2d(
6, 16, 5, stride=1, padding=0),
ReLU(),
Pool2D(2, 'max', 2))
......
......@@ -13,7 +13,7 @@
#limitations under the License.
import paddle.fluid as fluid
from paddle.nn import Conv2D, Pool2D, Linear, ReLU, Sequential
from paddle.nn import Conv2d, Pool2D, Linear, ReLU, Sequential
__all__ = ['LeNet']
......@@ -39,11 +39,11 @@ class LeNet(fluid.dygraph.Layer):
super(LeNet, self).__init__()
self.num_classes = num_classes
self.features = Sequential(
Conv2D(
Conv2d(
1, 6, 3, stride=1, padding=1),
ReLU(),
Pool2D(2, 'max', 2),
Conv2D(
Conv2d(
6, 16, 5, stride=1, padding=0),
ReLU(),
Pool2D(2, 'max', 2))
......
......@@ -13,7 +13,7 @@
# limitations under the License.
import paddle.fluid as fluid
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear
from paddle.nn import Conv2d, Pool2D, BatchNorm, Linear, ReLU
from paddle.fluid.dygraph.container import Sequential
from ...download import get_weights_path_from_url
......@@ -105,12 +105,11 @@ def make_layers(cfg, batch_norm=False):
layers += [Pool2D(pool_size=2, pool_stride=2)]
else:
if batch_norm:
conv2d = Conv2D(in_channels, v, filter_size=3, padding=1)
layers += [conv2d, BatchNorm(v, act='relu')]
conv2d = Conv2d(in_channels, v, kernel_size=3, padding=1)
layers += [conv2d, BatchNorm(v), ReLU()]
else:
conv2d = Conv2D(
in_channels, v, filter_size=3, padding=1, act='relu')
layers += [conv2d]
conv2d = Conv2d(in_channels, v, kernel_size=3, padding=1)
layers += [conv2d, ReLU()]
in_channels = v
return Sequential(*layers)
......
......@@ -93,9 +93,9 @@ from .layer.common import Dropout2D #DEFINE_ALIAS
from .layer.common import Dropout3D #DEFINE_ALIAS
from .layer.pooling import AdaptiveAvgPool2d #DEFINE_ALIAS
from .layer.pooling import AdaptiveAvgPool3d #DEFINE_ALIAS
from .layer.conv import Conv2D #DEFINE_ALIAS
from .layer.conv import Conv2d #DEFINE_ALIAS
from .layer.conv import Conv3d #DEFINE_ALIAS
from .layer.conv import ConvTranspose2d #DEFINE_ALIAS
from .layer.conv import Conv3D #DEFINE_ALIAS
from .layer.conv import ConvTranspose3d #DEFINE_ALIAS
# from .layer.conv import TreeConv #DEFINE_ALIAS
# from .layer.conv import Conv1D #DEFINE_ALIAS
......
......@@ -88,20 +88,16 @@ def _update_padding_nd(padding, channel_last, num_dims):
return padding, padding_algorithm
def conv2d(input,
def conv2d(x,
weight,
bias=None,
padding=0,
stride=1,
padding=0,
dilation=1,
groups=1,
use_cudnn=True,
act=None,
data_format="NCHW",
name=None):
"""
:alias_main: paddle.nn.functional.conv2d
:alias: paddle.nn.functional.conv2d,paddle.nn.functional.conv.conv2d
The convolution2D layer calculates the output based on the input, filter
and strides, paddings, dilations, groups parameters. Input and
......@@ -153,12 +149,15 @@ def conv2d(input,
W_{out}&= \\frac{(W_{in} + 2 * paddings[1] - (dilations[1] * (W_f - 1) + 1))}{strides[1]} + 1
Args:
input (Variable): The input is 4-D Tensor with shape [N, C, H, W], the data type
x (Tensor): The input is 4-D Tensor with shape [N, C, H, W], the data type
of input is float16 or float32 or float64.
weight (Variable): The convolution kernel with shape [M, C/g, kH, kW], where M is
weight (Tensor): The convolution kernel with shape [M, C/g, kH, kW], where M is
the number of output channels, g is the number of groups, kH is the filter's
height, kW is the filter's width.
bias (Variable, optional): The bias with shape [M,].
bias (Tensor, optional): The bias with shape [M,].
stride (int|tuple): The stride size. It means the stride in convolution.
If stride is a tuple, it must contain two integers, (stride_height, stride_width).
Otherwise, stride_height = stride_width = stride. Default: stride = 1.
padding (string|int|list|tuple): The padding size. It means the number of zero-paddings
on both sides for each dimension.If `padding` is a string, either 'VALID' or
'SAME' which is the padding algorithm. If padding size is a tuple or list,
......@@ -169,9 +168,6 @@ def conv2d(input,
when `data_format` is `"NHWC"`, `pool_padding` can be in the form
`[[0,0], [pad_height_top, pad_height_bottom], [pad_width_left, pad_width_right], [0,0]]`.
Default: padding = 0.
stride (int|tuple): The stride size. It means the stride in convolution.
If stride is a tuple, it must contain two integers, (stride_height, stride_width).
Otherwise, stride_height = stride_width = stride. Default: stride = 1.
dilation (int|tuple): The dilation size. It means the spacing between the kernel
points. If dilation is a tuple, it must contain two integers, (dilation_height,
dilation_width). Otherwise, dilation_height = dilation_width = dilation.
......@@ -181,10 +177,6 @@ def conv2d(input,
the first half of the filters is only connected to the first half
of the input channels, while the second half of the filters is only
connected to the second half of the input channels. Default: groups=1.
use_cudnn (bool): Use cudnn kernel or not, it is valid only when the cudnn
library is installed. Default: True
act (str): Activation type, if it is set to None, activation is not appended.
Default: None
data_format (str, optional): Specify the data format of the input, and the data format of the output
will be consistent with that of the input. An optional string from: `"NCHW"`, `"NHWC"`.
The default is `"NCHW"`. When it is `"NCHW"`, the data is stored in the order of:
......@@ -194,13 +186,9 @@ def conv2d(input,
None by default.
Returns:
A Variable holding Tensor representing the conv2d, whose data type is the
same with input. If act is None, the tensor variable storing the convolution
result, and if act is not None, the tensor variable storing convolution
and non-linearity activation result.
A Tensor representing the conv2d result, whose data type is the same with input.
Raises:
ValueError: If the type of `use_cudnn` is not bool.
ValueError: If `data_format` is not "NCHW" or "NHWC".
ValueError: If the channel dimmention of the input is less than or equal to zero.
ValueError: If `padding` is a string, but not "SAME" or "VALID".
......@@ -215,62 +203,65 @@ def conv2d(input,
Examples:
.. code-block:: python
from paddle import fluid
import paddle
import paddle.nn.functional as F
import paddle.fluid.dygraph as dg
import numpy as np
x = np.random.randn(2, 3, 8, 8).astype(np.float32)
w = np.random.randn(6, 3, 3, 3).astype(np.float32)
place = fluid.CPUPlace()
with dg.guard(place):
x_var = dg.to_variable(x)
w_var = dg.to_variable(w)
y_var = F.conv2d(x_var, w_var, act="relu")
y_np = y_var.numpy()
paddle.disable_static()
x_var = paddle.to_tensor(x)
w_var = paddle.to_tensor(w)
y_var = F.conv2d(x_var, w_var)
y_np = y_var.numpy()
print(y_np.shape)
# (2, 6, 6, 6)
"""
# entry checks
if not isinstance(use_cudnn, bool):
raise ValueError("Attr(use_cudnn) should be True or False. "
"Received Attr(use_cudnn): {}.".format(use_cudnn))
if data_format not in ["NCHW", "NHWC"]:
raise ValueError("Attr(data_format) should be 'NCHW' or 'NHWC'. "
"Received Attr(data_format): {}.".format(data_format))
channel_last = (data_format == "NHWC")
channel_dim = -1 if channel_last else 1
num_channels = input.shape[channel_dim]
num_channels = x.shape[channel_dim]
num_filters = weight.shape[0]
if num_channels < 0:
raise ValueError("The channel dimmention of the input({}) "
"should be defined. Received: {}.".format(
input.shape, num_channels))
x.shape, num_channels))
if num_channels % groups != 0:
raise ValueError(
"the channel of input must be divisible by groups,"
"received: the channel of input is {}, the shape of input is {}"
", the groups is {}".format(num_channels, input.shape, groups))
", the groups is {}".format(num_channels, x.shape, groups))
if num_filters % groups != 0:
raise ValueError(
"the number of filters must be divisible by groups,"
"received: the number of filters is {}, the shape of weight is {}"
", the groups is {}".format(num_filters, weight.shape, groups))
# use_cudnn = True if core.is_compiled_with_cuda() else False
cudnn_version = get_cudnn_version()
use_cudnn = True if (core.is_compiled_with_cuda() and
cudnn_version is not None) else False
# update attrs
padding, padding_algorithm = _update_padding_nd(padding, channel_last, 2)
stride = utils.convert_to_list(stride, 2, 'stride')
dilation = utils.convert_to_list(dilation, 2, 'dilation')
l_type = "conv2d"
if (num_channels == groups and num_filters % num_channels == 0 and
not use_cudnn):
if (num_channels == groups and num_filters % num_channels == 0):
l_type = 'depthwise_conv2d'
use_cudnn = False
inputs = {'Input': [input], 'Filter': [weight]}
inputs = {'Input': [x], 'Filter': [weight]}
attrs = {
'strides': stride,
'paddings': padding,
......@@ -288,15 +279,13 @@ def conv2d(input,
'groups', groups, 'use_cudnn', use_cudnn, 'use_mkldnn', False,
'fuse_relu_before_depthwise_conv', False, "padding_algorithm",
padding_algorithm, "data_format", data_format)
pre_bias = getattr(core.ops, l_type)(input, weight, *attrs)
pre_bias = getattr(core.ops, l_type)(x, weight, *attrs)
if bias is not None:
pre_act = nn.elementwise_add(pre_bias, bias, axis=channel_dim)
out = nn.elementwise_add(pre_bias, bias, axis=channel_dim)
else:
pre_act = pre_bias
out = dygraph_utils._append_activation_in_dygraph(
pre_act, act, use_cudnn=use_cudnn)
out = pre_bias
else:
inputs = {'Input': [input], 'Filter': [weight]}
inputs = {'Input': [x], 'Filter': [weight]}
attrs = {
'strides': stride,
'paddings': padding,
......@@ -308,8 +297,8 @@ def conv2d(input,
"padding_algorithm": padding_algorithm,
"data_format": data_format
}
check_variable_and_dtype(input, 'input',
['float16', 'float32', 'float64'], 'conv2d')
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
'conv2d')
helper = LayerHelper(l_type, **locals())
dtype = helper.input_dtype()
pre_bias = helper.create_variable_for_type_inference(dtype)
......@@ -317,10 +306,10 @@ def conv2d(input,
helper.append_op(
type=l_type, inputs=inputs, outputs=outputs, attrs=attrs)
if bias is not None:
pre_act = nn.elementwise_add(pre_bias, bias, axis=channel_dim)
out = nn.elementwise_add(pre_bias, bias, axis=channel_dim)
else:
pre_act = pre_bias
out = helper.append_activation(pre_act)
out = pre_bias
return out
......@@ -571,20 +560,16 @@ def conv_transpose2d(x,
return out
def conv3d(input,
def conv3d(x,
weight,
bias=None,
padding=0,
stride=1,
padding=0,
dilation=1,
groups=1,
use_cudnn=True,
act=None,
data_format="NCDHW",
name=None):
"""
:alias_main: paddle.nn.functional.conv3d
:alias: paddle.nn.functional.conv3d,paddle.nn.functional.conv.conv3d
The convolution3D layer calculates the output based on the input, filter
and strides, paddings, dilations, groups parameters. Input(Input) and
......@@ -630,12 +615,15 @@ def conv3d(input,
W_{out}&= \\frac{(W_{in} + 2 * paddings[2] - (dilations[2] * (W_f - 1) + 1))}{strides[2]} + 1
Args:
input (Variable): The input is 5-D Tensor with shape [N, C, D, H, W], the data
x (Tensor): The input is 5-D Tensor with shape [N, C, D, H, W], the data
type of input is float16 or float32 or float64.
weight (Variable): The convolution kernel, a Tensor with shape [M, C/g, kD, kH, kW],
where M is the number of filters(output channels), g is the number of groups,
kD, kH, kW are the filter's depth, height and width respectively.
bias (Variable, optional): The bias, a Tensor of shape [M, ].
bias (Tensor, optional): The bias, a Tensor of shape [M, ].
stride (int|tuple): The stride size. It means the stride in convolution. If stride is a
tuple, it must contain three integers, (stride_depth, stride_height, stride_width).
Otherwise, stride_depth = stride_height = stride_width = stride. Default: stride = 1.
padding (string|int|list|tuple): The padding size. It means the number of zero-paddings
on both sides for each dimension. If `padding` is a string, either 'VALID' or
'SAME' which is the padding algorithm. If padding size is a tuple or list,
......@@ -646,9 +634,6 @@ def conv3d(input,
when `data_format` is `"NDHWC"`, `pool_padding` can be in the form
`[[0,0], [pad_depth_front, pad_depth_back], [pad_height_top, pad_height_bottom], [pad_width_left, pad_width_right], [0,0]]`.
Default: padding = 0.
stride (int|tuple): The stride size. It means the stride in convolution. If stride is a
tuple, it must contain three integers, (stride_depth, stride_height, stride_width).
Otherwise, stride_depth = stride_height = stride_width = stride. Default: stride = 1.
dilation (int|tuple): The dilation size. It means the spacing between the kernel points.
If dilation is a tuple, it must contain three integers, (dilation_depth, dilation_height,
dilation_width). Otherwise, dilation_depth = dilation_height = dilation_width = dilation.
......@@ -658,10 +643,6 @@ def conv3d(input,
the first half of the filters is only connected to the first half
of the input channels, while the second half of the filters is only
connected to the second half of the input channels. Default: groups=1
use_cudnn (bool): Use cudnn kernel or not, it is valid only when the cudnn
library is installed. Default: True
act (str): Activation type, if it is set to None, activation is not appended.
Default: None.
data_format (str, optional): Specify the data format of the input, and the data format of the output
will be consistent with that of the input. An optional string from: `"NCHW"`, `"NHWC"`.
The default is `"NCHW"`. When it is `"NCHW"`, the data is stored in the order of:
......@@ -671,13 +652,12 @@ def conv3d(input,
None by default.
Returns:
A Variable holding Tensor representing the conv3d, whose data type is
A Tensor representing the conv3d, whose data type is
the same with input. If act is None, the tensor variable storing the
convolution result, and if act is not None, the tensor variable storing
convolution and non-linearity activation result.
Raises:
ValueError: If the type of `use_cudnn` is not bool.
ValueError: If `data_format` is not "NCDHW" or "NDHWC".
ValueError: If the channel dimmention of the input is less than or equal to zero.
ValueError: If `padding` is a string, but not "SAME" or "VALID".
......@@ -711,10 +691,6 @@ def conv3d(input,
# (2, 6, 6, 6, 6)
"""
# entry check
if not isinstance(use_cudnn, bool):
raise ValueError("Attr(use_cudnn) should be True or False. Received "
"Attr(use_cudnn): {}. ".format(use_cudnn))
if data_format not in ["NCDHW", "NDHWC"]:
raise ValueError(
"Attr(data_format) should be 'NCDHW' or 'NDHWC'. Received "
......@@ -722,12 +698,12 @@ def conv3d(input,
channel_last = (data_format == "NDHWC")
channel_dim = -1 if channel_last else 1
num_channels = input.shape[channel_dim]
num_channels = x.shape[channel_dim]
num_filters = weight.shape[0]
if num_channels < 0:
raise ValueError(
"The channel dimmention of the input({}) should be defined. "
"Received: {}.".format(input.shape, num_channels))
"Received: {}.".format(x.shape, num_channels))
if num_channels % groups != 0:
raise ValueError(
"The number of input channels must be divisible by Attr(groups). "
......@@ -739,6 +715,10 @@ def conv3d(input,
"Received: number of filters({}), groups({}).".format(num_filters,
groups))
cudnn_version = get_cudnn_version()
use_cudnn = True if (core.is_compiled_with_cuda() and
cudnn_version is not None) else False
padding, padding_algorithm = _update_padding_nd(padding, channel_last, 3)
stride = utils.convert_to_list(stride, 3, 'stride')
dilation = utils.convert_to_list(dilation, 3, 'dilation')
......@@ -749,15 +729,13 @@ def conv3d(input,
'groups', groups, 'use_cudnn', use_cudnn, 'use_mkldnn', False,
"padding_algorithm", padding_algorithm, "data_format",
data_format)
pre_bias = getattr(core.ops, op_type)(input, weight, *attrs)
pre_bias = getattr(core.ops, op_type)(x, weight, *attrs)
if bias is not None:
pre_act = nn.elementwise_add(pre_bias, bias, axis=channel_dim)
out = nn.elementwise_add(pre_bias, bias, axis=channel_dim)
else:
pre_act = pre_bias
out = dygraph_utils._append_activation_in_dygraph(
pre_act, act, use_cudnn=use_cudnn)
out = pre_bias
else:
inputs = {'Input': [input], 'Filter': [weight]}
inputs = {'Input': [x], 'Filter': [weight]}
attrs = {
'strides': stride,
'paddings': padding,
......@@ -770,8 +748,8 @@ def conv3d(input,
}
helper = LayerHelper(op_type, **locals())
dtype = helper.input_dtype()
check_variable_and_dtype(input, 'input',
['float16', 'float32', 'float64'], 'conv3d')
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
'conv3d')
pre_bias = helper.create_variable_for_type_inference(dtype)
outputs = {"Output": [pre_bias]}
......@@ -779,10 +757,9 @@ def conv3d(input,
helper.append_op(
type=op_type, inputs=inputs, outputs=outputs, attrs=attrs)
if bias is not None:
pre_act = nn.elementwise_add(pre_bias, bias, axis=channel_dim)
out = nn.elementwise_add(pre_bias, bias, axis=channel_dim)
else:
pre_act = pre_bias
out = helper.append_activation(pre_act)
out = pre_bias
return out
......
......@@ -57,9 +57,9 @@ from .common import Dropout2D #DEFINE_ALIAS
from .common import Dropout3D #DEFINE_ALIAS
from .pooling import AdaptiveAvgPool2d #DEFINE_ALIAS
from .pooling import AdaptiveAvgPool3d #DEFINE_ALIAS
from .conv import Conv2D #DEFINE_ALIAS
from .conv import Conv2d #DEFINE_ALIAS
from .conv import Conv3d #DEFINE_ALIAS
from .conv import ConvTranspose2d #DEFINE_ALIAS
from .conv import Conv3D #DEFINE_ALIAS
from .conv import ConvTranspose3d #DEFINE_ALIAS
# from .conv import TreeConv #DEFINE_ALIAS
# from .conv import Conv1D #DEFINE_ALIAS
......
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册