未验证 提交 d0f1be51 编写于 作者: 张春乔 提交者: GitHub

[API Enhancement] No.6 support single `int` input in UpsamplingNearest2D and...

[API Enhancement] No.6 support single `int` input in UpsamplingNearest2D and UpsamplingBilinear2D (#56470)

* enhance single int input in UpsamplingNearest2D and UpsamplingBilinear2D

* add unittest

* add unittest
上级 3568a99c
......@@ -489,6 +489,8 @@ class UpsamplingNearest2D(Layer):
self, size=None, scale_factor=None, data_format='NCHW', name=None
):
super().__init__()
if isinstance(size, int):
size = [size, size]
self.size = size
self.scale_factor = scale_factor
self.data_format = data_format
......@@ -575,6 +577,8 @@ class UpsamplingBilinear2D(Layer):
self, size=None, scale_factor=None, data_format='NCHW', name=None
):
super().__init__()
if isinstance(size, int):
size = [size, size]
self.size = size
self.scale_factor = scale_factor
self.data_format = data_format
......
......@@ -116,11 +116,21 @@ class TestLayerPrint(unittest.TestCase):
str(module), 'UpsamplingNearest2D(size=[12, 12], data_format=NCHW)'
)
module = nn.UpsamplingNearest2D(size=12)
self.assertEqual(
str(module), 'UpsamplingNearest2D(size=[12, 12], data_format=NCHW)'
)
module = nn.UpsamplingBilinear2D(size=[12, 12])
self.assertEqual(
str(module), 'UpsamplingBilinear2D(size=[12, 12], data_format=NCHW)'
)
module = nn.UpsamplingBilinear2D(size=12)
self.assertEqual(
str(module), 'UpsamplingBilinear2D(size=[12, 12], data_format=NCHW)'
)
module = nn.Bilinear(in1_features=5, in2_features=4, out_features=1000)
self.assertEqual(
str(module),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册