未验证 提交 d0f1be51 编写于 作者: 张春乔 提交者: GitHub

[API Enhancement] No.6 support single `int` input in UpsamplingNearest2D and...

[API Enhancement] No.6 support single `int` input in UpsamplingNearest2D and UpsamplingBilinear2D (#56470)

* enhance single int input in UpsamplingNearest2D and UpsamplingBilinear2D

* add unittest

* add unittest
上级 3568a99c
...@@ -489,6 +489,8 @@ class UpsamplingNearest2D(Layer): ...@@ -489,6 +489,8 @@ class UpsamplingNearest2D(Layer):
self, size=None, scale_factor=None, data_format='NCHW', name=None self, size=None, scale_factor=None, data_format='NCHW', name=None
): ):
super().__init__() super().__init__()
if isinstance(size, int):
size = [size, size]
self.size = size self.size = size
self.scale_factor = scale_factor self.scale_factor = scale_factor
self.data_format = data_format self.data_format = data_format
...@@ -575,6 +577,8 @@ class UpsamplingBilinear2D(Layer): ...@@ -575,6 +577,8 @@ class UpsamplingBilinear2D(Layer):
self, size=None, scale_factor=None, data_format='NCHW', name=None self, size=None, scale_factor=None, data_format='NCHW', name=None
): ):
super().__init__() super().__init__()
if isinstance(size, int):
size = [size, size]
self.size = size self.size = size
self.scale_factor = scale_factor self.scale_factor = scale_factor
self.data_format = data_format self.data_format = data_format
......
...@@ -116,11 +116,21 @@ class TestLayerPrint(unittest.TestCase): ...@@ -116,11 +116,21 @@ class TestLayerPrint(unittest.TestCase):
str(module), 'UpsamplingNearest2D(size=[12, 12], data_format=NCHW)' str(module), 'UpsamplingNearest2D(size=[12, 12], data_format=NCHW)'
) )
module = nn.UpsamplingNearest2D(size=12)
self.assertEqual(
str(module), 'UpsamplingNearest2D(size=[12, 12], data_format=NCHW)'
)
module = nn.UpsamplingBilinear2D(size=[12, 12]) module = nn.UpsamplingBilinear2D(size=[12, 12])
self.assertEqual( self.assertEqual(
str(module), 'UpsamplingBilinear2D(size=[12, 12], data_format=NCHW)' str(module), 'UpsamplingBilinear2D(size=[12, 12], data_format=NCHW)'
) )
module = nn.UpsamplingBilinear2D(size=12)
self.assertEqual(
str(module), 'UpsamplingBilinear2D(size=[12, 12], data_format=NCHW)'
)
module = nn.Bilinear(in1_features=5, in2_features=4, out_features=1000) module = nn.Bilinear(in1_features=5, in2_features=4, out_features=1000)
self.assertEqual( self.assertEqual(
str(module), str(module),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册