From d6ee0868a45a80b6560f1e70c3a41ed4064683d0 Mon Sep 17 00:00:00 2001 From: Leo Chen Date: Tue, 8 Sep 2020 10:29:23 +0800 Subject: [PATCH] fix unsqueeze in dygraph (#27107) * fix unsqueeze in dygraph * add test * add test --- python/paddle/fluid/layers/nn.py | 9 ++++ .../tests/unittests/test_unsqueeze_op.py | 45 ++++++++++++++++--- python/paddle/tensor/manipulation.py | 2 - 3 files changed, 47 insertions(+), 9 deletions(-) diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 9313de8c64..70f48e82fd 100755 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -6306,6 +6306,15 @@ def unsqueeze(input, axes, name=None): """ if in_dygraph_mode(): + if isinstance(axes, int): + axes = [axes] + elif isinstance(axes, Variable): + axes = [axes.numpy().item(0)] + elif isinstance(axes, (list, tuple)): + axes = [ + item.numpy().item(0) if isinstance(item, Variable) else item + for item in axes + ] out, _ = core.ops.unsqueeze2(input, 'axes', axes) return out diff --git a/python/paddle/fluid/tests/unittests/test_unsqueeze_op.py b/python/paddle/fluid/tests/unittests/test_unsqueeze_op.py index 9382d53e7f..6f713172f1 100644 --- a/python/paddle/fluid/tests/unittests/test_unsqueeze_op.py +++ b/python/paddle/fluid/tests/unittests/test_unsqueeze_op.py @@ -134,29 +134,60 @@ class API_TestUnsqueeze3(unittest.TestCase): result1, = exe.run(feed={"data1": input, "data2": input2}, fetch_list=[result_squeeze]) - self.assertTrue(np.allclose(input1, result1)) + self.assertTrue(np.array_equal(input1, result1)) + self.assertEqual(input1.shape, result1.shape) class API_TestDyUnsqueeze(unittest.TestCase): def test_out(self): with fluid.dygraph.guard(): input_1 = np.random.random([5, 1, 10]).astype("int32") - input1 = np.squeeze(input_1, axis=1) + input1 = np.expand_dims(input_1, axis=1) input = fluid.dygraph.to_variable(input_1) output = paddle.unsqueeze(input, axis=[1]) out_np = output.numpy() - self.assertTrue(np.allclose(input1, out_np)) + self.assertTrue(np.array_equal(input1, out_np)) + self.assertEqual(input1.shape, out_np.shape) class API_TestDyUnsqueeze2(unittest.TestCase): def test_out(self): with fluid.dygraph.guard(): - input_1 = np.random.random([5, 1, 10]).astype("int32") - input1 = np.squeeze(input_1, axis=1) - input = fluid.dygraph.to_variable(input_1) + input1 = np.random.random([5, 10]).astype("int32") + out1 = np.expand_dims(input1, axis=1) + input = fluid.dygraph.to_variable(input1) output = paddle.unsqueeze(input, axis=1) out_np = output.numpy() - self.assertTrue(np.allclose(input1, out_np)) + self.assertTrue(np.array_equal(out1, out_np)) + self.assertEqual(out1.shape, out_np.shape) + + +class API_TestDyUnsqueezeAxisTensor(unittest.TestCase): + def test_out(self): + with fluid.dygraph.guard(): + input1 = np.random.random([5, 10]).astype("int32") + out1 = np.expand_dims(input1, axis=1) + input = fluid.dygraph.to_variable(input1) + output = paddle.unsqueeze(input, axis=paddle.to_tensor([1])) + out_np = output.numpy() + self.assertTrue(np.array_equal(out1, out_np)) + self.assertEqual(out1.shape, out_np.shape) + + +class API_TestDyUnsqueezeAxisTensorList(unittest.TestCase): + def test_out(self): + with fluid.dygraph.guard(): + input1 = np.random.random([5, 10]).astype("int32") + # Actually, expand_dims supports tuple since version 1.18.0 + out1 = np.expand_dims(input1, axis=1) + out1 = np.expand_dims(out1, axis=2) + input = fluid.dygraph.to_variable(input1) + output = paddle.unsqueeze( + fluid.dygraph.to_variable(input1), + axis=[paddle.to_tensor([1]), paddle.to_tensor([2])]) + out_np = output.numpy() + self.assertTrue(np.array_equal(out1, out_np)) + self.assertEqual(out1.shape, out_np.shape) if __name__ == "__main__": diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 71ac809ddf..db1222fa42 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -746,8 +746,6 @@ def unsqueeze(x, axis, name=None): print(out3.shape) # [1, 1, 1, 5, 10] """ - if isinstance(axis, int): - axis = [axis] return layers.unsqueeze(x, axis, name) -- GitLab