From d4bb2ca71f72e31b78231e1bc0907330392ef759 Mon Sep 17 00:00:00 2001 From: guosheng Date: Thu, 22 Mar 2018 13:36:58 +0800 Subject: [PATCH] Follow comments and refine the python wrapper of reshape_op --- python/paddle/fluid/layers/nn.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index b4e3e83e3ab..d98e1bdfcaf 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -3361,7 +3361,9 @@ def reshape(x, shape, act=None, inplace=True, name=None): Examples: .. code-block:: python - data = fluid.layers.data(name='data', shape=[2, 4, 6], dtype='float32') + data = fluid.layers.data( + name='data', shape=[2, 4, 6], dtype='float32' + ) reshaped = fluid.layers.reshape( x=data, shape=[-1, 0, 3, 2], act='tanh', inplace=True ) @@ -3371,6 +3373,21 @@ def reshape(x, shape, act=None, inplace=True, name=None): if not (isinstance(shape, list) or isinstance(shape, tuple)): raise ValueError("Input shape must be a python lsit or tuple.") + # Validate the shape + unk_dim_idx = -1 + for dim_idx, dim_size in enumerate(shape): + if dim_size == -1: + assert unk_dim_idx == -1, ( + "Only one dimension in shape can be unknown.") + unk_dim_idx = dim_idx + elif dim_size == 0: + assert dim_idx < len(x.shape), ( + "The indice of 0s in shape can not exceed Rank(X).") + else: + assert dim_size > 0, ( + "Each dimension size given in shape must not be negtive " + "except one unknown dimension.") + helper = LayerHelper("reshape", **locals()) reshaped = helper.create_tmp_variable(dtype=x.dtype) helper.append_op( -- GitLab