提交 ed4aa219 编写于 作者: X Xin Pan

Small doc fix and clean up of reshape

上级 c90e64e7
...@@ -4263,14 +4263,18 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=True, name=None): ...@@ -4263,14 +4263,18 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=True, name=None):
say :attr:`actual_shape` has a higher priority say :attr:`actual_shape` has a higher priority
than :attr:`shape`. than :attr:`shape`.
act (str): The non-linear activation to be applied to output variable. act (str): The non-linear activation to be applied to output variable.
inplace(bool): If this flag is set true, a new output tensor is created inplace(bool): If this flag is set true, the output
whose data is copied from input x, otherwise the output shares data with input without copying, otherwise
shares data with input without copying. a new output tensor is created
whose data is copied from input x.
name (str): The name of this layer. It is optional. name (str): The name of this layer. It is optional.
Returns: Returns:
Variable: The output tensor. Variable: The output tensor.
Raises:
TypeError: if actual_shape is neither Variable nor None.
Examples: Examples:
.. code-block:: python .. code-block:: python
...@@ -4282,6 +4286,11 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=True, name=None): ...@@ -4282,6 +4286,11 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=True, name=None):
if not (isinstance(shape, list) or isinstance(shape, tuple)): if not (isinstance(shape, list) or isinstance(shape, tuple)):
raise ValueError("Input shape must be a python lsit or tuple.") raise ValueError("Input shape must be a python lsit or tuple.")
inputs = {"X": x}
if isinstance(actual_shape, Variable):
inputs["Shape"] = actual_shape
elif actual_shape is not None:
raise TypeError("actual_shape should either be Variable or None")
# Validate the shape # Validate the shape
unk_dim_idx = -1 unk_dim_idx = -1
...@@ -4302,9 +4311,7 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=True, name=None): ...@@ -4302,9 +4311,7 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=True, name=None):
reshaped = helper.create_tmp_variable(dtype=x.dtype) reshaped = helper.create_tmp_variable(dtype=x.dtype)
helper.append_op( helper.append_op(
type="reshape", type="reshape",
inputs={"X": x, inputs=inputs,
"Shape": actual_shape}
if isinstance(actual_shape, Variable) else {"X": x},
attrs={"shape": shape, attrs={"shape": shape,
"inplace": inplace}, "inplace": inplace},
outputs={"Out": reshaped}) outputs={"Out": reshaped})
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册