提交 e7faae01 编写于 作者: F fengjiayi

Refine assign layer

Give assign layer's second parameter 'output' a default value: None. If
it is None, the output variable will be created inside the layer.
上级 732eef57
...@@ -155,7 +155,7 @@ def cast(x, dtype): ...@@ -155,7 +155,7 @@ def cast(x, dtype):
Examples: Examples:
.. code-block:: python .. code-block:: python
data = fluid.layers.data(name='x', shape=[13], dtype='float32') data = fluid.layers.data(name='x', shape=[13], dtype='float32')
result = fluid.layers.cast(x=data, dtype='float64') result = fluid.layers.cast(x=data, dtype='float64')
""" """
...@@ -188,7 +188,7 @@ def concat(input, axis=0, name=None): ...@@ -188,7 +188,7 @@ def concat(input, axis=0, name=None):
Examples: Examples:
.. code-block:: python .. code-block:: python
out = fluid.layers.concat(input=[Efirst, Esecond, Ethird, Efourth]) out = fluid.layers.concat(input=[Efirst, Esecond, Ethird, Efourth])
""" """
helper = LayerHelper('concat', **locals()) helper = LayerHelper('concat', **locals())
...@@ -234,7 +234,7 @@ def sums(input, out=None): ...@@ -234,7 +234,7 @@ def sums(input, out=None):
return out return out
def assign(input, output): def assign(input, output=None):
""" """
**Assign** **Assign**
...@@ -242,7 +242,7 @@ def assign(input, output): ...@@ -242,7 +242,7 @@ def assign(input, output):
Args: Args:
input(Variable|numpy.ndarray): The source variable input(Variable|numpy.ndarray): The source variable
output(Variable): The destination variable output(Variable|None): The destination variable
Returns: Returns:
Variable: The destination variable that was supplied as the *output*. Variable: The destination variable that was supplied as the *output*.
...@@ -255,6 +255,8 @@ def assign(input, output): ...@@ -255,6 +255,8 @@ def assign(input, output):
fluid.layers.assign(hidden, out) fluid.layers.assign(hidden, out)
""" """
helper = LayerHelper('assign', **locals()) helper = LayerHelper('assign', **locals())
if output is None:
output = helper.create_tmp_variable(dtype=input.dtype)
if isinstance(input, Variable): if isinstance(input, Variable):
helper.append_op( helper.append_op(
type='assign', inputs={'X': [input]}, outputs={'Out': [output]}) type='assign', inputs={'X': [input]}, outputs={'Out': [output]})
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册