Unverified commit 84d8e49d, authored by MRXLT, committed by GitHub

refine adam/strided_slice && fix doc for rmsprop/unstack (#27740)

* refine parameters order && doc

* update rmsprop doc

* refine adam/transpose/unstack/strided_slice

* fix bug && doc

* fix doc

* bug fix

* bug fix

* fix doc

* fix doc

* fix doc

* fix doc

* deprecate old strided_slice

* update doc

* set default value for name

* update doc
Parent e96fc6ab
@@ -10241,9 +10241,9 @@ def unstack(x, axis=0, num=None):
     Examples:
         .. code-block:: python

-            import paddle.fluid as fluid
-            x = fluid.data(name='x', shape=[2, 3, 5], dtype='float32')  # create a tensor with shape=[2, 3, 5]
-            y = fluid.layers.unstack(x, axis=1)  # unstack with second axis, which results 3 tensors with shape=[2, 5]
+            import paddle
+            x = paddle.ones(name='x', shape=[2, 3, 5], dtype='float32')  # create a tensor with shape=[2, 3, 5]
+            y = paddle.unstack(x, axis=1)  # unstack with second axis, which results 3 tensors with shape=[2, 5]
     """
     helper = LayerHelper('unstack', **locals())
@@ -11017,7 +11017,7 @@ def slice(input, axes, starts, ends):
     return out


-@templatedoc()
+@deprecated(since='2.0.0', update_to="paddle.strided_slice")
 def strided_slice(input, axes, starts, ends, strides):
     """
     :alias_main: paddle.strided_slice
@@ -11095,7 +11095,9 @@ def strided_slice(input, axes, starts, ends, strides):
         .. code-block:: python

             import paddle.fluid as fluid
+            import paddle
+            paddle.enable_static()
             input = fluid.data(
                 name="input", shape=[3, 4, 5, 6], dtype='float32')
...
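For reference, the `fluid.layers.strided_slice` example above now runs under static graph mode, while the replacement API (`paddle.strided_slice`, added later in this commit) can be called imperatively. A minimal sketch, assuming Paddle >= 2.0 with dygraph mode on by default; the slice parameters are taken from the new unit test below:

import paddle  # assumption: dygraph (imperative) mode is the default

# Same tensor shape as the fluid.data example, expressed with the 2.0 API.
x = paddle.zeros(shape=[3, 4, 5, 6], dtype='float32')
out = paddle.strided_slice(x, axes=[1, 2, 3],
                           starts=[-3, 0, 2], ends=[3, 2, 4], strides=[1, 1, 1])
print(out.shape)  # [3, 2, 2, 2], i.e. x[:, 1:3:1, 0:2:1, 2:4:1]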
@@ -16,6 +16,9 @@ from op_test import OpTest
 import numpy as np
 import unittest
 import paddle.fluid as fluid
+import paddle
+
+paddle.enable_static()


 def strided_slice_native_forward(input, axes, starts, ends, strides):
@@ -498,6 +501,16 @@ class TestStridedSliceAPI(unittest.TestCase):
         assert np.array_equal(res_6, input[-3:3, 0:100:2, :, -1:2:-1])
         assert np.array_equal(res_7, input[-1, 0:100:2, :, -1:2:-1])

+    def test_dygraph_op(self):
+        x = paddle.zeros(shape=[3, 4, 5, 6], dtype="float32")
+        axes = [1, 2, 3]
+        starts = [-3, 0, 2]
+        ends = [3, 2, 4]
+        strides_1 = [1, 1, 1]
+        sliced_1 = paddle.strided_slice(
+            x, axes=axes, starts=starts, ends=ends, strides=strides_1)
+        assert sliced_1.shape == (3, 2, 2, 2)
+
 if __name__ == "__main__":
     unittest.main()
@@ -128,16 +128,13 @@ def transpose(x, perm, name=None):
        .. code-block:: python

            import paddle
-           import numpy as np
-           import paddle.fluid.dygraph as dg
-           with dg.guard():
-               a = np.array([[1.0 + 1.0j, 2.0 + 1.0j], [3.0+1.0j, 4.0+1.0j]])
-               x = dg.to_variable(a)
-               y = paddle.complex.transpose(x, [1, 0])
-               print(y.numpy())
-               # [[1.+1.j 3.+1.j]
-               #  [2.+1.j 4.+1.j]]
+           x = paddle.to_tensor([[1.0 + 1.0j, 2.0 + 1.0j], [3.0+1.0j, 4.0+1.0j], [5.0+1.0j, 6.0+1.0j]])
+           x_transposed = paddle.complex.transpose(x, [1, 0])
+           print(x_transposed.numpy())
+           #[[1.+1.j 3.+1.j 5.+1.j]
+           # [2.+1.j 4.+1.j 6.+1.j]]
    """
    complex_variable_exists([x], "transpose")
    real = layers.transpose(x.real, perm, name)
...
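The expected output in the rewritten `paddle.complex.transpose` example can be cross-checked with plain NumPy; an illustrative sketch using NumPy only:

import numpy as np

a = np.array([[1.0 + 1.0j, 2.0 + 1.0j],
              [3.0 + 1.0j, 4.0 + 1.0j],
              [5.0 + 1.0j, 6.0 + 1.0j]])
# Swap the two axes, matching paddle.complex.transpose(x, [1, 0]).
print(a.transpose(1, 0))
# [[1.+1.j 3.+1.j 5.+1.j]
#  [2.+1.j 4.+1.j 6.+1.j]]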
@@ -72,9 +72,6 @@ class Adam(Optimizer):
            some derived class of ``GradientClipBase`` . There are three cliping strategies
            ( :ref:`api_fluid_clip_GradientClipByGlobalNorm` , :ref:`api_fluid_clip_GradientClipByNorm` ,
            :ref:`api_fluid_clip_GradientClipByValue` ). Default None, meaning there is no gradient clipping.
-        name (str, optional): Normally there is no need for user to set this property.
-            For more information, please refer to :ref:`api_guide_Name`.
-            The default value is None.
        lazy_mode (bool, optional): The official Adam algorithm has two moving-average accumulators.
            The accumulators are updated at every step. Every element of the two moving-average
            is updated in both dense mode and sparse mode. If the size of parameter is very large,
@@ -82,17 +79,17 @@ class Adam(Optimizer):
            gradient in current mini-batch, so it will be much more faster. But this mode has
            different semantics with the original Adam algorithm and may lead to different result.
            The default value is False.
+        name (str, optional): Normally there is no need for user to set this property.
+            For more information, please refer to :ref:`api_guide_Name`.
+            The default value is None.

    Examples:
        .. code-block:: python

            import paddle
-           import numpy as np

-           paddle.disable_static()
-           inp = np.random.uniform(-0.1, 0.1, [10, 10]).astype("float32")
            linear = paddle.nn.Linear(10, 10)
-           inp = paddle.to_tensor(inp)
+           inp = paddle.rand([10,10], dtype="float32")
            out = linear(inp)
            loss = paddle.mean(out)
            adam = paddle.optimizer.Adam(learning_rate=0.1,
@@ -105,12 +102,9 @@ class Adam(Optimizer):
            # Adam with beta1/beta2 as Tensor and weight_decay as float
            import paddle
-           import numpy as np

-           paddle.disable_static()
-           inp = np.random.uniform(-0.1, 0.1, [10, 10]).astype("float32")
            linear = paddle.nn.Linear(10, 10)
-           inp = paddle.to_tensor(inp)
+           inp = paddle.rand([10,10], dtype="float32")
            out = linear(inp)
            loss = paddle.mean(out)
@@ -140,8 +134,8 @@ class Adam(Optimizer):
                 parameters=None,
                 weight_decay=None,
                 grad_clip=None,
-                name=None,
-                lazy_mode=False):
+                lazy_mode=False,
+                name=None):
        assert learning_rate is not None
        assert beta1 is not None
        assert beta2 is not None
@@ -266,10 +260,8 @@ class Adam(Optimizer):
            .. code-block:: python

                import paddle
-               import numpy as np

-               paddle.disable_static()
-               value = np.arange(26).reshape(2, 13).astype("float32")
-               a = paddle.to_tensor(value)
+               a = paddle.rand([2,13], dtype="float32")
                linear = paddle.nn.Linear(13, 5)
                # This can be any optimizer supported by dygraph.
                adam = paddle.optimizer.Adam(learning_rate = 0.01,
...
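Because the Adam constructor now takes `lazy_mode` before `name`, code that passed these two arguments positionally would need updating; passing them as keywords is unaffected either way. A minimal sketch of keyword usage, assuming the `paddle.optimizer.Adam` API documented above:

import paddle

linear = paddle.nn.Linear(10, 10)
# Keyword arguments are robust to the lazy_mode/name reordering in this commit.
adam = paddle.optimizer.Adam(learning_rate=0.1,
                             parameters=linear.parameters(),
                             lazy_mode=False,
                             name="adam_example")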
@@ -64,9 +64,6 @@ class AdamW(Adam):
            some derived class of ``GradientClipBase`` . There are three cliping strategies
            ( :ref:`api_fluid_clip_GradientClipByGlobalNorm` , :ref:`api_fluid_clip_GradientClipByNorm` ,
            :ref:`api_fluid_clip_GradientClipByValue` ). Default None, meaning there is no gradient clipping.
-        name (str, optional): Normally there is no need for user to set this property.
-            For more information, please refer to :ref:`api_guide_Name`.
-            The default value is None.
        lazy_mode (bool, optional): The official Adam algorithm has two moving-average accumulators.
            The accumulators are updated at every step. Every element of the two moving-average
            is updated in both dense mode and sparse mode. If the size of parameter is very large,
@@ -74,18 +71,18 @@ class AdamW(Adam):
            gradient in current mini-batch, so it will be much more faster. But this mode has
            different semantics with the original Adam algorithm and may lead to different result.
            The default value is False.
+        name (str, optional): Normally there is no need for user to set this property.
+            For more information, please refer to :ref:`api_guide_Name`.
+            The default value is None.

    **Notes**:
        **Currently, AdamW doesn't support sparse parameter optimization.**

    Examples:
        .. code-block:: python

            import paddle
-           import numpy as np

-           paddle.disable_static()
-           inp = np.random.uniform(-0.1, 0.1, [10, 10]).astype("float32")
            linear = paddle.nn.Linear(10, 10)
-           inp = paddle.to_tensor(inp)
+           inp = paddle.rand([10,10], dtype="float32")
            out = linear(inp)
            loss = paddle.mean(out)
@@ -112,8 +109,8 @@ class AdamW(Adam):
                 weight_decay=0.01,
                 apply_decay_param_fun=None,
                 grad_clip=None,
-                name=None,
-                lazy_mode=False):
+                lazy_mode=False,
+                name=None):
        assert learning_rate is not None
        assert beta1 is not None
        assert beta2 is not None
...
@@ -104,24 +104,18 @@ class RMSProp(Optimizer):
        .. code-block:: python

            import paddle
-           import numpy as np

-           paddle.disable_static()
-           inp = np.random.uniform(-0.1, 0.1, [10, 10]).astype("float32")
+           inp = paddle.rand([10,10], dtype="float32")
            linear = paddle.nn.Linear(10, 10)
-           inp = paddle.to_tensor(inp)
            out = linear(inp)
            loss = paddle.mean(out)
-           beta1 = paddle.to_tensor([0.9], dtype="float32")
-           beta2 = paddle.to_tensor([0.99], dtype="float32")
-           adam = paddle.optimizer.RMSProp(learning_rate=0.1,
+           rmsprop = paddle.optimizer.RMSProp(learning_rate=0.1,
                            parameters=linear.parameters(),
                            weight_decay=0.01)
            out.backward()
-           adam.step()
-           adam.clear_grad()
+           rmsprop.step()
+           rmsprop.clear_grad()

    """
...
@@ -25,7 +25,6 @@ import six
 # TODO: define functions to manipulate a tensor
 from ..fluid.layers import cast  #DEFINE_ALIAS
 from ..fluid.layers import slice  #DEFINE_ALIAS
-from ..fluid.layers import strided_slice  #DEFINE_ALIAS
 from ..fluid.layers import transpose  #DEFINE_ALIAS
 from ..fluid.layers import unstack  #DEFINE_ALIAS
@@ -1461,3 +1460,89 @@ def gather_nd(x, index, name=None):
    """
    return paddle.fluid.layers.gather_nd(input=x, index=index, name=name)
+
+
+def strided_slice(x, axes, starts, ends, strides, name=None):
+    """
+    This operator produces a slice of ``x`` along multiple axes. Similar to numpy:
+    https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html
+    Slice uses ``axes``, ``starts`` and ``ends`` attributes to specify the start and
+    end dimension for each axis in the list of axes and uses this information
+    to slice the input data tensor. If a negative value such as :math:`-i` is passed to
+    ``starts`` or ``ends``, it represents the reverse position of the
+    axis :math:`i-1` th (here 0 is the initial position). The ``strides`` represents the steps of
+    slicing, and if ``strides`` is negative, the slice operation is in the opposite direction.
+    If the value passed to ``starts`` or ``ends`` is greater than n
+    (the number of elements in this dimension), it represents n.
+    For slicing to the end of a dimension with unknown size, it is recommended
+    to pass in INT_MAX. The size of ``axes`` must be equal to that of ``starts``, ``ends`` and ``strides``.
+    The following examples explain how strided_slice works:
+
+    .. code-block:: text
+
+        Case1:
+            Given:
+                data = [ [1, 2, 3, 4], [5, 6, 7, 8], ]
+                axes = [0, 1]
+                starts = [1, 0]
+                ends = [2, 3]
+                strides = [1, 1]
+            Then:
+                result = [ [5, 6, 7], ]
+
+        Case2:
+            Given:
+                data = [ [1, 2, 3, 4], [5, 6, 7, 8], ]
+                axes = [0, 1]
+                starts = [0, 1]
+                ends = [2, 0]
+                strides = [1, -1]
+            Then:
+                result = [ [8, 7, 6], ]
+
+        Case3:
+            Given:
+                data = [ [1, 2, 3, 4], [5, 6, 7, 8], ]
+                axes = [0, 1]
+                starts = [0, 1]
+                ends = [-1, 1000]
+                strides = [1, 3]
+            Then:
+                result = [ [2], ]
+
+    Args:
+        x (Tensor): An N-D ``Tensor``. The data type is ``float32``, ``float64``, ``int32`` or ``int64``.
+        axes (list|tuple): The data type is ``int32``. Axes that ``starts`` and ``ends`` apply to.
+            It's optional. If it is not provided, it will be treated as :math:`[0,1,...,len(starts)-1]`.
+        starts (list|tuple|Tensor): The data type is ``int32``. If ``starts`` is a list or tuple, its elements
+            should be integers or Tensors with shape [1]. If ``starts`` is a Tensor, it should be a 1-D Tensor.
+            It represents starting indices of the corresponding axes in ``axes``.
+        ends (list|tuple|Tensor): The data type is ``int32``. If ``ends`` is a list or tuple, its elements
+            should be integers or Tensors with shape [1]. If ``ends`` is a Tensor, it should be a 1-D Tensor.
+            It represents ending indices of the corresponding axes in ``axes``.
+        strides (list|tuple|Tensor): The data type is ``int32``. If ``strides`` is a list or tuple, its elements
+            should be integers or Tensors with shape [1]. If ``strides`` is a Tensor, it should be a 1-D Tensor.
+            It represents the slice step of the corresponding axes in ``axes``.
+        name (str, optional): The default value is None. Normally there is no need for user to set this property.
+            For more information, please refer to :ref:`api_guide_Name` .
+
+    Returns:
+        Tensor: A ``Tensor`` with the same dimension as ``x``. The data type is the same as ``x``.
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+            x = paddle.zeros(shape=[3,4,5,6], dtype="float32")
+            # example 1:
+            # attr starts is a list which doesn't contain a Tensor.
+            axes = [1, 2, 3]
+            starts = [-3, 0, 2]
+            ends = [3, 2, 4]
+            strides_1 = [1, 1, 1]
+            strides_2 = [1, 1, 2]
+            sliced_1 = paddle.strided_slice(x, axes=axes, starts=starts, ends=ends, strides=strides_1)
+            # sliced_1 is x[:, 1:3:1, 0:2:1, 2:4:1].
+            # example 2:
+            # attr starts is a list which contains a Tensor.
+            minus_3 = paddle.fill_constant([1], "int32", -3)
+            sliced_2 = paddle.strided_slice(x, axes=axes, starts=[minus_3, 0, 2], ends=ends, strides=strides_2)
+            # sliced_2 is x[:, 1:3:1, 0:2:1, 2:4:2].
+    """
+    return paddle.fluid.layers.strided_slice(
+        input=x, axes=axes, starts=starts, ends=ends, strides=strides)
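The docstring above says the new `paddle.strided_slice` mirrors NumPy basic slicing; the sketch below (assuming dygraph mode in Paddle >= 2.0) checks one case against its NumPy equivalent:

import numpy as np
import paddle

data = np.arange(24, dtype="float32").reshape(2, 3, 4)
x = paddle.to_tensor(data)

# axes/starts/ends/strides form of data[:, 0:3:2, 1:4:1]
out = paddle.strided_slice(x, axes=[1, 2], starts=[0, 1], ends=[3, 4], strides=[2, 1])
assert np.array_equal(out.numpy(), data[:, 0:3:2, 1:4:1])  # shape (2, 2, 3)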