提交 b586cc2a 编写于 作者: Y Yibing Liu

Fix typos in unsqueeze & unsequeeze wrapper

上级 1443d762
......@@ -4485,6 +4485,7 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=True, name=None):
return helper.append_activation(out)
def squeeze(x, axes, inplace=False, name=None):
"""
Remove single-dimensional entries from the shape of a tensor. Takes a
......@@ -4511,7 +4512,7 @@ def squeeze(x, axes, inplace=False, name=None):
Args:
x (Variable): The input variable to be squeezed.
axes (list): List of integers, indicating the dimensions to be squeezed.
name (str): Name for this layers.
name (str|None): Name for this layer.
Returns:
Variable: Output squeezed variable.
......@@ -4530,8 +4531,9 @@ def squeeze(x, axes, inplace=False, name=None):
attrs={"axes": axes},
outputs={"Out": out})
return out
return out
def unsqueeze(x, axes, inplace=False, name=None):
"""
Insert single-dimensional entries to the shape of a tensor. Takes one
......@@ -4545,7 +4547,7 @@ def unsqueeze(x, axes, inplace=False, name=None):
Args:
x (Variable): The input variable to be unsqueezed.
axes (list): List of integers, indicating the dimensions to be inserted.
name (str): Name for this layers.
name (str|None): Name for this layer.
Returns:
Variable: Output unsqueezed variable.
......@@ -4564,7 +4566,8 @@ def unsqueeze(x, axes, inplace=False, name=None):
attrs={"axes": axes},
outputs={"Out": out})
return out
return out
def lod_reset(x, y=None, target_lod=None):
"""
......
......@@ -243,16 +243,16 @@ class TestBook(unittest.TestCase):
def test_sequence_unsqueeze(self):
program = Program()
with program_guard(program):
x = layers.data(name='x', shape=[8,2], dtype='float32')
x = layers.data(name='x', shape=[8, 2], dtype='float32')
out = layers.unsqueeze(x=x, axes=[1])
self.assertIsNotNone(out)
print(str(program))
def test_squeeze(self):
program = Program()
with program_guard(program):
x = layers.data(name='x', shape=[1, 1, 4], dtype='float32')
out = layers.squeeze(x=x, axes=[0])
out = layers.squeeze(x=x, axes=[2])
self.assertIsNotNone(out)
print(str(program))
......@@ -277,7 +277,6 @@ class TestBook(unittest.TestCase):
out = layers.sequence_reshape(input=x, new_dim=16)
self.assertIsNotNone(out)
print(str(program))
def test_im2sequence(self):
program = Program()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册