未验证 提交 d6ee0868 编写于 作者: L Leo Chen 提交者: GitHub

fix unsqueeze in dygraph (#27107)

* fix unsqueeze in dygraph

* add test

* add test
上级 58f3ef98
...@@ -6306,6 +6306,15 @@ def unsqueeze(input, axes, name=None): ...@@ -6306,6 +6306,15 @@ def unsqueeze(input, axes, name=None):
""" """
if in_dygraph_mode(): if in_dygraph_mode():
if isinstance(axes, int):
axes = [axes]
elif isinstance(axes, Variable):
axes = [axes.numpy().item(0)]
elif isinstance(axes, (list, tuple)):
axes = [
item.numpy().item(0) if isinstance(item, Variable) else item
for item in axes
]
out, _ = core.ops.unsqueeze2(input, 'axes', axes) out, _ = core.ops.unsqueeze2(input, 'axes', axes)
return out return out
......
...@@ -134,29 +134,60 @@ class API_TestUnsqueeze3(unittest.TestCase): ...@@ -134,29 +134,60 @@ class API_TestUnsqueeze3(unittest.TestCase):
result1, = exe.run(feed={"data1": input, result1, = exe.run(feed={"data1": input,
"data2": input2}, "data2": input2},
fetch_list=[result_squeeze]) fetch_list=[result_squeeze])
self.assertTrue(np.allclose(input1, result1)) self.assertTrue(np.array_equal(input1, result1))
self.assertEqual(input1.shape, result1.shape)
class API_TestDyUnsqueeze(unittest.TestCase): class API_TestDyUnsqueeze(unittest.TestCase):
def test_out(self): def test_out(self):
with fluid.dygraph.guard(): with fluid.dygraph.guard():
input_1 = np.random.random([5, 1, 10]).astype("int32") input_1 = np.random.random([5, 1, 10]).astype("int32")
input1 = np.squeeze(input_1, axis=1) input1 = np.expand_dims(input_1, axis=1)
input = fluid.dygraph.to_variable(input_1) input = fluid.dygraph.to_variable(input_1)
output = paddle.unsqueeze(input, axis=[1]) output = paddle.unsqueeze(input, axis=[1])
out_np = output.numpy() out_np = output.numpy()
self.assertTrue(np.allclose(input1, out_np)) self.assertTrue(np.array_equal(input1, out_np))
self.assertEqual(input1.shape, out_np.shape)
class API_TestDyUnsqueeze2(unittest.TestCase): class API_TestDyUnsqueeze2(unittest.TestCase):
def test_out(self): def test_out(self):
with fluid.dygraph.guard(): with fluid.dygraph.guard():
input_1 = np.random.random([5, 1, 10]).astype("int32") input1 = np.random.random([5, 10]).astype("int32")
input1 = np.squeeze(input_1, axis=1) out1 = np.expand_dims(input1, axis=1)
input = fluid.dygraph.to_variable(input_1) input = fluid.dygraph.to_variable(input1)
output = paddle.unsqueeze(input, axis=1) output = paddle.unsqueeze(input, axis=1)
out_np = output.numpy() out_np = output.numpy()
self.assertTrue(np.allclose(input1, out_np)) self.assertTrue(np.array_equal(out1, out_np))
self.assertEqual(out1.shape, out_np.shape)
class API_TestDyUnsqueezeAxisTensor(unittest.TestCase):
def test_out(self):
with fluid.dygraph.guard():
input1 = np.random.random([5, 10]).astype("int32")
out1 = np.expand_dims(input1, axis=1)
input = fluid.dygraph.to_variable(input1)
output = paddle.unsqueeze(input, axis=paddle.to_tensor([1]))
out_np = output.numpy()
self.assertTrue(np.array_equal(out1, out_np))
self.assertEqual(out1.shape, out_np.shape)
class API_TestDyUnsqueezeAxisTensorList(unittest.TestCase):
def test_out(self):
with fluid.dygraph.guard():
input1 = np.random.random([5, 10]).astype("int32")
# Actually, expand_dims supports tuple since version 1.18.0
out1 = np.expand_dims(input1, axis=1)
out1 = np.expand_dims(out1, axis=2)
input = fluid.dygraph.to_variable(input1)
output = paddle.unsqueeze(
fluid.dygraph.to_variable(input1),
axis=[paddle.to_tensor([1]), paddle.to_tensor([2])])
out_np = output.numpy()
self.assertTrue(np.array_equal(out1, out_np))
self.assertEqual(out1.shape, out_np.shape)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -746,8 +746,6 @@ def unsqueeze(x, axis, name=None): ...@@ -746,8 +746,6 @@ def unsqueeze(x, axis, name=None):
print(out3.shape) # [1, 1, 1, 5, 10] print(out3.shape) # [1, 1, 1, 5, 10]
""" """
if isinstance(axis, int):
axis = [axis]
return layers.unsqueeze(x, axis, name) return layers.unsqueeze(x, axis, name)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册