未验证 提交 0b74428d 编写于 作者: L LielinJiang 提交者: GitHub

Fix Conv2DTanspose bug when padding='same' (#29915)

* fix conv_transpose bug when padding=same
上级 11de384c
...@@ -219,6 +219,7 @@ def add_cases(suite): ...@@ -219,6 +219,7 @@ def add_cases(suite):
suite.addTest( suite.addTest(
Conv2DTransposeTestCase( Conv2DTransposeTestCase(
methodName='runTest', padding="valid")) methodName='runTest', padding="valid"))
suite.addTest(Conv2DTransposeTestCase(methodName='runTest', padding="same"))
suite.addTest( suite.addTest(
Conv2DTransposeTestCase( Conv2DTransposeTestCase(
methodName='runTest', filter_size=1, padding=(2, 3))) methodName='runTest', filter_size=1, padding=(2, 3)))
......
...@@ -100,14 +100,12 @@ class _ConvNd(layers.Layer): ...@@ -100,14 +100,12 @@ class _ConvNd(layers.Layer):
self._padding_mode = padding_mode self._padding_mode = padding_mode
self.output_padding = output_padding self.output_padding = output_padding
if dims != 1: if dims != 1:
self._padding, self._padding_algorithm = _update_padding_nd( self._updated_padding, self._padding_algorithm = _update_padding_nd(
padding, channel_last, dims) padding, channel_last, dims)
if transposed: if transposed:
filter_shape = [self._in_channels, out_channels // groups filter_shape = [self._in_channels, out_channels // groups
] + self._kernel_size ] + self._kernel_size
self._padding, self._padding_algorithm = _update_padding_nd(
padding, channel_last, dims)
else: else:
if in_channels % groups != 0: if in_channels % groups != 0:
raise ValueError("in_channels must be divisible by groups.") raise ValueError("in_channels must be divisible by groups.")
...@@ -118,7 +116,8 @@ class _ConvNd(layers.Layer): ...@@ -118,7 +116,8 @@ class _ConvNd(layers.Layer):
self._reversed_padding_repeated_twice = _reverse_repeat_list( self._reversed_padding_repeated_twice = _reverse_repeat_list(
_paired_padding, 2) _paired_padding, 2)
self._padding, _ = _update_padding_nd(0, channel_last, dims) self._updated_padding, self._padding_algorithm = _update_padding_nd(
0, channel_last, dims)
filter_shape = [out_channels, in_channels // groups filter_shape = [out_channels, in_channels // groups
] + self._kernel_size ] + self._kernel_size
...@@ -634,7 +633,7 @@ class Conv2D(_ConvNd): ...@@ -634,7 +633,7 @@ class Conv2D(_ConvNd):
self.weight, self.weight,
bias=self.bias, bias=self.bias,
stride=self._stride, stride=self._stride,
padding=self._padding, padding=self._updated_padding,
padding_algorithm=self._padding_algorithm, padding_algorithm=self._padding_algorithm,
dilation=self._dilation, dilation=self._dilation,
groups=self._groups, groups=self._groups,
...@@ -951,7 +950,7 @@ class Conv3D(_ConvNd): ...@@ -951,7 +950,7 @@ class Conv3D(_ConvNd):
self.weight, self.weight,
bias=self.bias, bias=self.bias,
stride=self._stride, stride=self._stride,
padding=self._padding, padding=self._updated_padding,
padding_algorithm=self._padding_algorithm, padding_algorithm=self._padding_algorithm,
dilation=self._dilation, dilation=self._dilation,
groups=self._groups, groups=self._groups,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册