未验证 提交 e29c50c2 编写于 作者: 傅剑寒 提交者: GitHub

remove pad2d in nn.py (#47854)

上级 29782728
......@@ -116,7 +116,6 @@ __all__ = [
'crop_tensor',
'prelu',
'flatten',
'pad2d',
'unique',
'unique_with_counts',
'scale',
......@@ -7579,151 +7578,6 @@ def crop_tensor(x, shape=None, offsets=None, name=None):
return out
def pad2d(
input,
paddings=[0, 0, 0, 0],
mode='constant',
pad_value=0.0,
data_format="NCHW",
name=None,
):
"""
Pad 2-d images according to 'paddings' and 'mode'.
If mode is 'reflect', paddings[0] and paddings[1] must be no greater
than height-1. And the width dimension has the same condition.
Parameters:
input (Tensor): The input image with [N, C, H, W] format or [N, H, W, C] format, which is a 4-D Tensor with data type float32.
paddings (Tensor | List[int32]): The padding size. If padding is a List, it must
contain four integers, (padding_top, padding_bottom, padding_left, padding_right).
Otherwise, it is a 1-D Tensor with shape [4]. Data type is int32.
Default is [0, 0, 0, 0].
mode (str): Three modes: 'constant' (default), 'reflect', 'edge' .
When in 'constant' mode, this op uses a constant value to pad the input tensor.
When in 'reflect' mode, uses reflection of the input boundaries to pad the input tensor.
When in 'edge' mode, uses input boundaries to pad the input tensor.
Default is 'constant'
pad_value (float32): The value to fill the padded areas in 'constant' mode . Default is 0.0
data_format (str): An string from: "NHWC", "NCHW". Specify the data format of
the input data.
Default is "NCHW"
name (str, optional) : The default value is None. Normally there is no need for
user to set this property. For more information, please refer to :ref:`api_guide_Name` .
Returns:
Tensor, a 4-D Tensor padded according to paddings and mode and data type is same as input.
Examples:
.. code-block:: text
Input = [[[[1., 2., 3.],
[4., 5., 6.]]]]
Case 0:
paddings = [0, 1, 2, 3],
mode = 'constant'
pad_value = 0
Out = [[[[0., 0., 1., 2., 3., 0., 0., 0.],
[0., 0., 4., 5., 6., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.]]]]
Case 1:
paddings = [0, 1, 2, 1],
mode = 'reflect'
Out = [[[[3., 2., 1., 2., 3., 2.],
[6., 5., 4., 5., 6., 5.],
[3., 2., 1., 2., 3., 2.]]]]
Case 2:
paddings = [0, 1, 2, 1],
mode = 'edge'
Out = [[[[1., 1., 1., 2., 3., 3.],
[4., 4., 4., 5., 6., 6.],
[4., 4., 4., 5., 6., 6.]]]]
Code Examples:
.. code-block:: python
import numpy as np
import paddle
import paddle.nn.functional as F
# example 1
x_shape = (1, 1, 3, 4)
x = np.arange(np.prod(x_shape), dtype=np.float32).reshape(x_shape) + 1
tensor_x = paddle.to_tensor(x)
y = paddle.fluid.layers.pad2d(tensor_x, paddings=[1, 2, 2, 1], pad_value=1, mode='constant')
print(y.numpy())
# [[[[ 1. 1. 1. 1. 1. 1. 1.]
# [ 1. 1. 1. 2. 3. 4. 1.]
# [ 1. 1. 5. 6. 7. 8. 1.]
# [ 1. 1. 9. 10. 11. 12. 1.]
# [ 1. 1. 1. 1. 1. 1. 1.]
# [ 1. 1. 1. 1. 1. 1. 1.]]]]
# example 2
x_shape = (1, 1, 2, 3)
x = np.arange(np.prod(x_shape), dtype=np.float32).reshape(x_shape) + 1
tensor_x = paddle.to_tensor(x)
y = paddle.fluid.layers.pad2d(tensor_x, paddings=[1, 1, 1, 1], mode='reflect')
print(y.numpy())
# [[[[5. 4. 5. 6. 5.]
# [2. 1. 2. 3. 2.]
# [5. 4. 5. 6. 5.]
# [2. 1. 2. 3. 2.]]]]
"""
if _non_static_mode():
_paddings = (
paddings.numpy().tolist()
if isinstance(paddings, Variable)
else paddings
)
return _legacy_C_ops.pad2d(
input,
'mode',
mode,
'pad_value',
pad_value,
'data_format',
data_format,
'paddings',
_paddings,
)
check_variable_and_dtype(
input,
'input',
['float16', 'float32', 'float64', 'int32', 'int64'],
"pad2d",
)
attrs = {'mode': mode, 'pad_value': pad_value, 'data_format': data_format}
inputs = {'X': [input]}
if isinstance(paddings, Variable):
inputs['Paddings'] = [paddings]
attrs['paddings'] = []
else:
attrs['paddings'] = paddings
helper = LayerHelper('pad2d', **locals())
assert mode in [
'reflect',
'edge',
'constant',
], "mode should be one of constant, reflect, edge."
dtype = helper.input_dtype(input_param_name='input')
out = helper.create_variable_for_type_inference(dtype)
helper.append_op(
type='pad2d', inputs=inputs, outputs={"Out": out}, attrs=attrs
)
return out
@deprecated(since="2.0.0", update_to="paddle.static.nn.prelu")
def prelu(x, mode, param_attr=None, data_format="NCHW", name=None):
r"""
......
......@@ -179,10 +179,12 @@ class build_resnet_block(fluid.dygraph.Layer):
self.dim = dim
def forward(self, inputs):
out_res = fluid.layers.pad2d(inputs, [1, 1, 1, 1], mode="reflect")
pad1 = paddle.nn.Pad2D([1, 1, 1, 1], mode="reflect")
out_res = pad1(inputs)
out_res = self.conv0(out_res)
out_res = fluid.layers.pad2d(out_res, [1, 1, 1, 1], mode="reflect")
pad2 = paddle.nn.Pad2D([1, 1, 1, 1], mode="reflect")
out_res = pad2(out_res)
out_res = self.conv1(out_res)
return out_res + inputs
......@@ -253,7 +255,8 @@ class build_generator_resnet_9blocks(fluid.dygraph.Layer):
)
def forward(self, inputs):
pad_input = fluid.layers.pad2d(inputs, [3, 3, 3, 3], mode="reflect")
pad1 = paddle.nn.Pad2D([3, 3, 3, 3], mode="reflect")
pad_input = pad1(inputs)
y = self.conv0(pad_input)
y = self.conv1(y)
y = self.conv2(y)
......@@ -261,7 +264,8 @@ class build_generator_resnet_9blocks(fluid.dygraph.Layer):
y = build_resnet_block_i(y)
y = self.deconv0(y)
y = self.deconv1(y)
y = fluid.layers.pad2d(y, [3, 3, 3, 3], mode="reflect")
pad2 = paddle.nn.Pad2D([3, 3, 3, 3], mode="reflect")
y = pad2(y)
y = self.conv3(y)
y = paddle.tanh(y)
return y
......@@ -461,9 +465,10 @@ class DeConv2D(fluid.dygraph.Layer):
def forward(self, inputs):
conv = self._deconv(inputs)
conv = fluid.layers.pad2d(
conv, paddings=self.outpadding, mode='constant', pad_value=0.0
tmp_pad = paddle.nn.Pad2D(
padding=self.outpadding, mode='constant', value=0.0
)
conv = tmp_pad(conv)
if self.norm:
conv = self.bn(conv)
......
......@@ -3502,23 +3502,15 @@ class TestBook(LayerTest):
input = self._get_data(
name="input", shape=[3, 100, 100], dtype="float32"
)
paddings = layers.fill_constant(shape=[4], dtype='int32', value=1)
out = layers.pad2d(
input,
paddings=[1, 2, 3, 4],
mode='reflect',
data_format='NCHW',
name="shape",
)
out_1 = layers.pad2d(
input,
paddings=paddings,
tmp_pad = paddle.nn.Pad2D(
padding=[1, 2, 3, 4],
mode='reflect',
data_format='NCHW',
name="shape",
)
out = tmp_pad(input)
return out
return out_1
def make_prelu(self):
with program_guard(
......
......@@ -12,11 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from op_test import OpTest
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard
import unittest
class TestPad2dOp(OpTest):
......@@ -138,21 +136,5 @@ class TestCase7(TestPad2dOp):
self.variable_paddings = True
class TestPad2dOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program(), Program()):
input_data = np.random.random((2, 2, 2, 2)).astype("float32")
def test_Variable():
fluid.layers.pad2d(input=input_data, paddings=[1, 1, 1, 1])
self.assertRaises(TypeError, test_Variable)
data = fluid.data(
name='data', shape=[None, 3, 20, 20], dtype='float16'
)
fluid.layers.pad2d(input=data, paddings=[1, 1, 1, 1])
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册