提交 d4bb2ca7 编写于 作者: G guosheng

Follow comments and refine the python wrapper of reshape_op

上级 454b0a96
......@@ -3361,7 +3361,9 @@ def reshape(x, shape, act=None, inplace=True, name=None):
Examples:
.. code-block:: python
data = fluid.layers.data(name='data', shape=[2, 4, 6], dtype='float32')
data = fluid.layers.data(
name='data', shape=[2, 4, 6], dtype='float32'
)
reshaped = fluid.layers.reshape(
x=data, shape=[-1, 0, 3, 2], act='tanh', inplace=True
)
......@@ -3371,6 +3373,21 @@ def reshape(x, shape, act=None, inplace=True, name=None):
if not (isinstance(shape, list) or isinstance(shape, tuple)):
raise ValueError("Input shape must be a python lsit or tuple.")
# Validate the shape
unk_dim_idx = -1
for dim_idx, dim_size in enumerate(shape):
if dim_size == -1:
assert unk_dim_idx == -1, (
"Only one dimension in shape can be unknown.")
unk_dim_idx = dim_idx
elif dim_size == 0:
assert dim_idx < len(x.shape), (
"The indice of 0s in shape can not exceed Rank(X).")
else:
assert dim_size > 0, (
"Each dimension size given in shape must not be negtive "
"except one unknown dimension.")
helper = LayerHelper("reshape", **locals())
reshaped = helper.create_tmp_variable(dtype=x.dtype)
helper.append_op(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册