Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
2074d369
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
2074d369
编写于
6月 19, 2018
作者:
Q
Qiao Longfei
提交者:
GitHub
6月 19, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #11532 from jacquesqiao/add-none-layers-api-doc
Add none layers api doc
上级
9c90dc97
706f3839
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
528 addition
and
130 deletion
+528
-130
python/paddle/fluid/clip.py
python/paddle/fluid/clip.py
+123
-11
python/paddle/fluid/inferencer.py
python/paddle/fluid/inferencer.py
+37
-9
python/paddle/fluid/initializer.py
python/paddle/fluid/initializer.py
+141
-91
python/paddle/fluid/optimizer.py
python/paddle/fluid/optimizer.py
+184
-16
python/paddle/fluid/regularizer.py
python/paddle/fluid/regularizer.py
+43
-3
未找到文件。
python/paddle/fluid/clip.py
浏览文件 @
2074d369
...
@@ -24,8 +24,6 @@ __all__ = [
...
@@ -24,8 +24,6 @@ __all__ = [
'GradientClipByValue'
,
'GradientClipByValue'
,
'GradientClipByNorm'
,
'GradientClipByNorm'
,
'GradientClipByGlobalNorm'
,
'GradientClipByGlobalNorm'
,
'append_gradient_clip_ops'
,
'error_clip_callback'
,
]
]
...
@@ -38,6 +36,25 @@ class BaseErrorClipAttr(object):
...
@@ -38,6 +36,25 @@ class BaseErrorClipAttr(object):
class
ErrorClipByValue
(
BaseErrorClipAttr
):
class
ErrorClipByValue
(
BaseErrorClipAttr
):
"""
Clips tensor values to the range [min, max].
Given a tensor t, this operation clips its value to min and max inplace.
- Any values less than min are set to min.
- Any values greater than max are set to max.
Args:
max (float): The maximum value to clip by.
min (float, optional): The minimum value to clip by. if not set by user,
\
will be set to -max by framework.
Examples:
.. code-block:: python
var = fluid.framework.Variable(..., error_clip=ErrorClipByValue(max=5.0), ...)
"""
def
__init__
(
self
,
max
,
min
=
None
):
def
__init__
(
self
,
max
,
min
=
None
):
max
=
float
(
max
)
max
=
float
(
max
)
if
min
is
None
:
if
min
is
None
:
...
@@ -99,6 +116,31 @@ class NullGradientClipAttr(BaseGradientClipAttr):
...
@@ -99,6 +116,31 @@ class NullGradientClipAttr(BaseGradientClipAttr):
class
GradientClipByValue
(
BaseGradientClipAttr
):
class
GradientClipByValue
(
BaseGradientClipAttr
):
"""
Clips gradient values to the range [min, max].
Given a tensor t, this operation clips its value to min and max inplace.
- Any values less than min are set to min.
- Any values greater than max are set to max.
Args:
max (float): The maximum value to clip by.
min (float, optional): The minimum value to clip by. if not set by user,
\
will be set to -max by framework.
Examples:
.. code-block:: python
w_param_attrs = ParamAttr(name=None,
initializer=UniformInitializer(low=-1.0, high=1.0, seed=0),
learning_rate=1.0,
regularizer=L1Decay(1.0),
trainable=True,
clip=GradientClipByValue(-1.0, 1.0))
y_predict = fluid.layers.fc(input=x, size=1, param_attr=w_param_attrs)
"""
def
__init__
(
self
,
max
,
min
=
None
):
def
__init__
(
self
,
max
,
min
=
None
):
max
=
float
(
max
)
max
=
float
(
max
)
if
min
is
None
:
if
min
is
None
:
...
@@ -120,6 +162,37 @@ class GradientClipByValue(BaseGradientClipAttr):
...
@@ -120,6 +162,37 @@ class GradientClipByValue(BaseGradientClipAttr):
class
GradientClipByNorm
(
BaseGradientClipAttr
):
class
GradientClipByNorm
(
BaseGradientClipAttr
):
"""
Clips tensor values to a maximum L2-norm.
This operator limits the L2 norm of the input :math:`X` within :math:`max\_norm`.
If the L2 norm of :math:`X` is less than or equal to :math:`max\_norm`, :math:`Out`
will be the same as :math:`X`. If the L2 norm of :math:`X` is greater than
:math:`max\_norm`, :math:`X` will be linearly scaled to make the L2 norm of
:math:`Out` equal to :math:`max\_norm`, as shown in the following formula:
.. math::
Out =
\\
frac{max\_norm * X}{norm(X)},
where :math:`norm(X)` represents the L2 norm of :math:`X`.
Args:
clip_norm (float): The maximum norm value
Examples:
.. code-block:: python
w_param_attrs = ParamAttr(name=None,
initializer=UniformInitializer(low=-1.0, high=1.0, seed=0),
learning_rate=1.0,
regularizer=L1Decay(1.0),
trainable=True,
clip=GradientClipByNorm(clip_norm=2.0))
y_predict = fluid.layers.fc(input=x, size=1, param_attr=w_param_attrs)
"""
def
__init__
(
self
,
clip_norm
):
def
__init__
(
self
,
clip_norm
):
self
.
clip_norm
=
clip_norm
self
.
clip_norm
=
clip_norm
...
@@ -135,6 +208,44 @@ class GradientClipByNorm(BaseGradientClipAttr):
...
@@ -135,6 +208,44 @@ class GradientClipByNorm(BaseGradientClipAttr):
class
GradientClipByGlobalNorm
(
BaseGradientClipAttr
):
class
GradientClipByGlobalNorm
(
BaseGradientClipAttr
):
"""
Clips values of multiple tensors by the ratio of the sum of their norms.
Given a list of tensors t_list, and a clipping ratio clip_norm, this
operation returns a list of clipped tensors list_clipped and the global
norm (global_norm) of all tensors in t_list.
To perform the clipping, the values :math:`t\_list[i]` are set to:
.. math::
t\_list[i] = t\_list[i] *
\\
frac{clip\_norm}{\max(global\_norm, clip\_norm)}
where:
.. math::
global\_norm = \sqrt{\sum_{i=0}^{N-1}(l2norm(t\_list[i]))^2}
If :math:`clip\_norm > global\_norm` then the entries in t_list remain as they are,
otherwise they're all shrunk by the global ratio.
Args:
clip_norm (float): The maximum norm value
group_name (str, optional): The group name for this clip.
Examples:
.. code-block:: python
p_g_clip = fluid.backward.append_backward(loss=avg_cost_clip)
with fluid.program_guard(main_program=prog_clip):
fluid.clip.set_gradient_clip(
fluid.clip.GradientClipByGlobalNorm(clip_norm=2.0))
p_g_clip = fluid.clip.append_gradient_clip_ops(p_g_clip)
"""
def
__init__
(
self
,
clip_norm
,
group_name
=
"default_group"
):
def
__init__
(
self
,
clip_norm
,
group_name
=
"default_group"
):
if
not
isinstance
(
group_name
,
basestring
):
if
not
isinstance
(
group_name
,
basestring
):
raise
TypeError
(
"'group_name' must be a basestring."
)
raise
TypeError
(
"'group_name' must be a basestring."
)
...
@@ -183,15 +294,16 @@ class GradientClipByGlobalNorm(BaseGradientClipAttr):
...
@@ -183,15 +294,16 @@ class GradientClipByGlobalNorm(BaseGradientClipAttr):
def
set_gradient_clip
(
clip
,
param_list
=
None
,
program
=
None
):
def
set_gradient_clip
(
clip
,
param_list
=
None
,
program
=
None
):
"""
"""
To specify parameters that require gradient clip.
To specify parameters that require gradient clip.
Args:
clip(BaseGradientClipAttr): An instance of some derived class of BaseGradientClipAttr,
Args:
which describes the type and detailed attributes of required gradient clip.
clip(BaseGradientClipAttr): An instance of some derived class of BaseGradientClipAttr,
param_list(list, None by default): Parameters that require gradient clip.
which describes the type and detailed attributes of required gradient clip.
It can be a list of parameter or a list of parameter's name.
param_list(list(Variable)): Parameters that require gradient clip.
When it's None, all parameters in the program will be included.
It can be a list of parameter or a list of parameter's name.
program(Program, None by default): The program where parameters are.
When it's None, all parameters in the program will be included.
Will be the default main program when assigned with None.
program(Program): The program where parameters are.
Will be the default main program when assigned with None.
"""
"""
if
not
isinstance
(
clip
,
BaseGradientClipAttr
):
if
not
isinstance
(
clip
,
BaseGradientClipAttr
):
raise
TypeError
(
raise
TypeError
(
...
...
python/paddle/fluid/inferencer.py
浏览文件 @
2074d369
...
@@ -27,13 +27,30 @@ __all__ = ['Inferencer', ]
...
@@ -27,13 +27,30 @@ __all__ = ['Inferencer', ]
class
Inferencer
(
object
):
class
Inferencer
(
object
):
"""
Inferencer High Level API.
Args:
infer_func (Python func): Infer function that will return predict Variable
param_path (str): The path where the inference model is saved by fluid.io.save_params
place (Place): place to do the inference
parallel (bool): use parallel_executor to run the inference, it will use multi CPU/GPU.
Examples:
.. code-block:: python
def inference_program():
x = fluid.layers.data(name='x', shape=[13], dtype='float32')
y_predict = fluid.layers.fc(input=x, size=1, act=None)
return y_predict
place = fluid.CPUPlace()
inferencer = fluid.Inferencer(
infer_func=inference_program, param_path="/tmp/model", place=place)
"""
def
__init__
(
self
,
infer_func
,
param_path
,
place
=
None
,
parallel
=
False
):
def
__init__
(
self
,
infer_func
,
param_path
,
place
=
None
,
parallel
=
False
):
"""
:param infer_func: a function that will return predict Variable
:param param_path: the path where the inference model is saved by fluid.io.save_params
:param place: place to do the inference
:param parallel: use parallel_executor to run the inference, it will use multi CPU/GPU.
"""
self
.
param_path
=
param_path
self
.
param_path
=
param_path
self
.
scope
=
core
.
Scope
()
self
.
scope
=
core
.
Scope
()
self
.
parallel
=
parallel
self
.
parallel
=
parallel
...
@@ -60,9 +77,20 @@ class Inferencer(object):
...
@@ -60,9 +77,20 @@ class Inferencer(object):
def
infer
(
self
,
inputs
,
return_numpy
=
True
):
def
infer
(
self
,
inputs
,
return_numpy
=
True
):
"""
"""
:param inputs: a map of {"input_name": input_var} that will be feed into the inference program
Do Inference for Inputs
to get the predict value
:return: the predict value of the inference model
Args:
inputs (map): a map of {"input_name": input_var} that will be feed into the inference program
return_numpy (bool): transform return value into numpy or not
Returns:
Tensor or Numpy: the predict value of the inference model for the inputs
Examples:
.. code-block:: python
tensor_x = numpy.random.uniform(0, 10, [batch_size, 13]).astype("float32")
results = inferencer.infer({'x': tensor_x})
"""
"""
if
not
isinstance
(
inputs
,
dict
):
if
not
isinstance
(
inputs
,
dict
):
raise
ValueError
(
raise
ValueError
(
...
...
python/paddle/fluid/initializer.py
浏览文件 @
2074d369
...
@@ -19,26 +19,39 @@ from framework import convert_np_dtype_to_dtype_
...
@@ -19,26 +19,39 @@ from framework import convert_np_dtype_to_dtype_
from
core
import
VarDesc
from
core
import
VarDesc
__all__
=
[
__all__
=
[
'Constant'
,
'Uniform'
,
'Normal'
,
'Xavier'
,
'Bilinear'
,
'force_init_on_cpu'
,
'Constant'
,
'Uniform'
,
'Normal'
,
'Xavier'
,
'Bilinear'
,
'MSRA'
,
'init_on_cpu'
,
'ConstantInitializer'
,
'UniformInitializer'
,
'force_init_on_cpu'
,
'init_on_cpu'
,
'ConstantInitializer'
,
'NormalInitializer'
,
'XavierInitializer'
,
'BilinearInitializer'
'UniformInitializer'
,
'NormalInitializer'
,
'XavierInitializer'
,
'BilinearInitializer'
,
'MSRAInitializer'
]
]
_force_init_on_cpu_
=
False
_force_init_on_cpu_
=
False
def
force_init_on_cpu
():
def
force_init_on_cpu
():
"""
The flag of whether force to init variables on CPU.
Examples:
.. code-block:: python
if force_init_on_cpu():
pass
"""
return
_force_init_on_cpu_
return
_force_init_on_cpu_
@
contextlib
.
contextmanager
@
contextlib
.
contextmanager
def
init_on_cpu
():
def
init_on_cpu
():
"""
"""
Switch program with `with` statement
Force the variable to be inited on CPU.
Examples:
Examples:
>>> with init_on_cpu():
.. code-block:: python
>>> step = layers.create_global_var()
with init_on_cpu():
step = layers.create_global_var()
"""
"""
global
_force_init_on_cpu_
global
_force_init_on_cpu_
...
@@ -104,14 +117,18 @@ class Initializer(object):
...
@@ -104,14 +117,18 @@ class Initializer(object):
class
ConstantInitializer
(
Initializer
):
class
ConstantInitializer
(
Initializer
):
"""Implements the constant initializer
"""Implements the constant initializer
Args:
value (float): constant value to initialize the variable
Examples:
.. code-block:: python
fc = fluid.layers.fc(input=x, size=10,
param_attr=fluid.initializer.Constant(value=2.0))
"""
"""
def
__init__
(
self
,
value
=
0.0
,
force_cpu
=
False
):
def
__init__
(
self
,
value
=
0.0
,
force_cpu
=
False
):
"""Constructor for ConstantInitializer
Args:
value: constant value to initialize the variable
"""
assert
value
is
not
None
assert
value
is
not
None
super
(
ConstantInitializer
,
self
).
__init__
()
super
(
ConstantInitializer
,
self
).
__init__
()
self
.
_value
=
value
self
.
_value
=
value
...
@@ -146,16 +163,20 @@ class ConstantInitializer(Initializer):
...
@@ -146,16 +163,20 @@ class ConstantInitializer(Initializer):
class
UniformInitializer
(
Initializer
):
class
UniformInitializer
(
Initializer
):
"""Implements the random uniform distribution initializer
"""Implements the random uniform distribution initializer
Args:
low (float): lower boundary of the uniform distribution
high (float): upper boundary of the uniform distribution
seed (int): random seed
Examples:
.. code-block:: python
fc = fluid.layers.fc(input=x, size=10,
param_attr=fluid.initializer.Uniform(low=-0.5, high=0.5))
"""
"""
def
__init__
(
self
,
low
=-
1.0
,
high
=
1.0
,
seed
=
0
):
def
__init__
(
self
,
low
=-
1.0
,
high
=
1.0
,
seed
=
0
):
"""Constructor for UniformInitializer
Args:
low: lower boundary of the uniform distribution
high: upper boundary of the uniform distribution
seed: random seed
"""
assert
low
is
not
None
assert
low
is
not
None
assert
high
is
not
None
assert
high
is
not
None
assert
high
>=
low
assert
high
>=
low
...
@@ -196,17 +217,21 @@ class UniformInitializer(Initializer):
...
@@ -196,17 +217,21 @@ class UniformInitializer(Initializer):
class
NormalInitializer
(
Initializer
):
class
NormalInitializer
(
Initializer
):
"""Implements the random Normal(Gaussian) distribution initializer
"""Implements the Random Normal(Gaussian) distribution initializer
Args:
loc (float): mean of the normal distribution
scale (float): standard deviation of the normal distribution
seed (int): random seed
Examples:
.. code-block:: python
fc = fluid.layers.fc(input=x, size=10,
param_attr=fluid.initializer.Normal(loc=0.0, scale=2.0))
"""
"""
def
__init__
(
self
,
loc
=
0.0
,
scale
=
1.0
,
seed
=
0
):
def
__init__
(
self
,
loc
=
0.0
,
scale
=
1.0
,
seed
=
0
):
"""Constructor for NormalInitializer
Args:
loc: mean of the normal distribution
scale: standard deviation of the normal distribution
seed: random seed
"""
assert
loc
is
not
None
assert
loc
is
not
None
assert
scale
is
not
None
assert
scale
is
not
None
assert
seed
is
not
None
assert
seed
is
not
None
...
@@ -246,39 +271,49 @@ class NormalInitializer(Initializer):
...
@@ -246,39 +271,49 @@ class NormalInitializer(Initializer):
class
XavierInitializer
(
Initializer
):
class
XavierInitializer
(
Initializer
):
"""Implements the Xavier initializer
"""
This class implements the Xavier weight initializer from the paper
This class implements the Xavier weight initializer from the paper
Understanding the difficulty of training deep feedforward neural
`Understanding the difficulty of training deep feedforward neural
networks[1] by Xavier Glorot and Yoshua Bengio.
networks <http://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf>`_
by Xavier Glorot and Yoshua Bengio.
This initializer is designed to keep the scale of the gradients
This initializer is designed to keep the scale of the gradients
approximately same in all the layers. In case of Uniform distribution,
approximately same in all the layers. In case of Uniform distribution,
the range is [-x, x], where x = sqrt(6 / (fan_in + fan_out)).
the range is [-x, x], where
.. math::
x = \sqrt{
\\
frac{6.0}{fan\_in + fan\_out}}
In case of Normal distribution, the mean is 0 and the standard deviation
In case of Normal distribution, the mean is 0 and the standard deviation
is sqrt(2/ (fan_in + fan_out)).
is
.. math::
\sqrt{
\\
frac{2.0}{fan\_in + fan\_out}}
Args:
uniform (bool): whether to use uniform or normal distribution
fan_in (float): fan_in for Xavier initialization. If None, it is
inferred from the variable.
fan_out (float): fan_out for Xavier initialization. If None, it is
inferred from the variable.
seed (int): random seed
Note:
It is recommended to set fan_in and fan_out to None for most cases.
Examples:
.. code-block:: python
fc = fluid.layers.fc(
input=queries, size=10,
param_attr=fluid.initializer.Xavier(uniform=False))
References:
[1] Understanding the difficulty of training deep feedforward neural
networks. International conference on artificial intelligence and
statistics.
(http://proceedings.mlr.press/v9/glorot10a.html)
"""
"""
def
__init__
(
self
,
uniform
=
True
,
fan_in
=
None
,
fan_out
=
None
,
seed
=
0
):
def
__init__
(
self
,
uniform
=
True
,
fan_in
=
None
,
fan_out
=
None
,
seed
=
0
):
"""Constructor for XavierInitializer
Args:
uniform: whether to use uniform or normal distribution
fan_in: fan_in for Xavier initialization. If None, it is
inferred from the variable.
fan_out: fan_out for Xavier initialization. If None, it is
inferred from the variable.
seed: random seed
Note: It is recommended to set fan_in and fan_out to None for
most cases.
"""
assert
uniform
is
not
None
assert
uniform
is
not
None
assert
seed
is
not
None
assert
seed
is
not
None
super
(
XavierInitializer
,
self
).
__init__
()
super
(
XavierInitializer
,
self
).
__init__
()
...
@@ -342,30 +377,42 @@ class MSRAInitializer(Initializer):
...
@@ -342,30 +377,42 @@ class MSRAInitializer(Initializer):
"""Implements the MSRA initializer a.k.a. Kaiming Initializer
"""Implements the MSRA initializer a.k.a. Kaiming Initializer
This class implements the weight initialization from the paper
This class implements the weight initialization from the paper
Delving Deep into Rectifiers: Surpassing Human-Level Performance on
`Delving Deep into Rectifiers: Surpassing Human-Level Performance on
ImageNet Classification[1] by Kaiming He, Xiangyu Zhang, Shaoqing Ren
ImageNet Classification <https://arxiv.org/abs/1502.01852>`_
and Jian Sun. This is a robust initialization method that particularly
by Kaiming He, Xiangyu Zhang, Shaoqing Ren and Jian Sun. This is a
considers the rectifier nonlinearities. In case of Uniform distribution,
robust initialization method that particularly considers the rectifier
the range is [-x, x], where x = sqrt(6 / fan_in). In case of Normal
nonlinearities. In case of Uniform distribution, the range is [-x, x], where
distribution, the mean is 0 and the standard deviation
is sqrt(2/ fan_in).
.. math::
References:
x = \sqrt{
\\
frac{6.0}{fan\_in}}
[1] Delving Deep into Rectifiers: Surpassing Human-Level Performance
on ImageNet Classification
In case of Normal distribution, the mean is 0 and the standard deviation
(https://arxiv.org/abs/1502.01852)
is
.. math::
\sqrt{
\\
frac{2.0}{fan\_in}}
Args:
uniform (bool): whether to use uniform or normal distribution
fan_in (float): fan_in for MSRAInitializer. If None, it is
\
inferred from the variable.
seed (int): random seed
Note:
It is recommended to set fan_in to None for most cases.
Examples:
.. code-block:: python
fc = fluid.layers.fc(
input=queries, size=10,
param_attr=fluid.initializer.MSRA(uniform=False))
"""
"""
def
__init__
(
self
,
uniform
=
True
,
fan_in
=
None
,
seed
=
0
):
def
__init__
(
self
,
uniform
=
True
,
fan_in
=
None
,
seed
=
0
):
"""Constructor for MSRAInitializer
"""Constructor for MSRAInitializer
Args:
uniform: whether to use uniform or normal distribution
fan_in: fan_in for MSRAInitializer. If None, it is
inferred from the variable.
seed: random seed
Note: It is recommended to set fan_in to None for most cases.
"""
"""
assert
uniform
is
not
None
assert
uniform
is
not
None
assert
seed
is
not
None
assert
seed
is
not
None
...
@@ -425,34 +472,37 @@ class MSRAInitializer(Initializer):
...
@@ -425,34 +472,37 @@ class MSRAInitializer(Initializer):
class
BilinearInitializer
(
Initializer
):
class
BilinearInitializer
(
Initializer
):
"""Implements the bilinear initializer.
"""
This initializer can be used in transposed convolution operator to
This initializer can be used in transposed convolution operator to
act as upsampling. Users can upsample a feature map with shape of
act as upsampling. Users can upsample a feature map with shape of
(B, C, H, W) by any integer factor. The usage is:
(B, C, H, W) by any integer factor. The usage is:
>>> factor = 2
Examples:
>>> w_attr = ParamAttr(learning_rate=0., regularizer=L2Decay(0.),
>>> initializer=Bilinear())
.. code-block:: python
>>> conv_up = fluid.layers.conv2d_transpose(
>>> input,
factor = 2
>>> num_filters=C,
w_attr = ParamAttr(learning_rate=0., regularizer=L2Decay(0.),
>>> output_size=None,
initializer=Bilinear())
>>> filter_size=2 * factor - factor % 2,
conv_up = fluid.layers.conv2d_transpose(
>>> padding=ceil((factor - 1) / 2.),
input,
>>> stride=factor,
num_filters=C,
>>> groups=C,
output_size=None,
>>> param_attr=w_attr,
filter_size=2 * factor - factor % 2,
>>> bias_attr=False)
padding=ceil((factor - 1) / 2.),
stride=factor,
groups=C,
Where, `num_filters=C` and `groups=C` means this is channel-wise tranposed
param_attr=w_attr,
bias_attr=False)
Where, `num_filters=C` and `groups=C` means this is channel-wise transposed
convolution. The filter shape will be (C, 1, K, K) where K is `filer_size`,
convolution. The filter shape will be (C, 1, K, K) where K is `filer_size`,
This initializer will set a (K, K) interpolation kernel for every channel
This initializer will set a (K, K) interpolation kernel for every channel
of the filter identically. The resulting shape of the output feature map
of the filter identically. The resulting shape of the output feature map
will be (B, C, factor * H, factor * W). Note that the learning rate and the
will be (B, C, factor * H, factor * W). Note that the learning rate and the
weight decay are set to 0 in order to keep coefficient values of bilinear
weight decay are set to 0 in order to keep coefficient values of bilinear
interpolation unchanged during training.
interpolation unchanged during training.
"""
"""
def
__init__
(
self
):
def
__init__
(
self
):
...
@@ -469,7 +519,7 @@ class BilinearInitializer(Initializer):
...
@@ -469,7 +519,7 @@ class BilinearInitializer(Initializer):
be added.
be added.
Returns:
Returns:
the initialization op
Operator:
the initialization op
Raises:
Raises:
ValueError: If type of `var` and `block` is not right.
ValueError: If type of `var` and `block` is not right.
...
...
python/paddle/fluid/optimizer.py
浏览文件 @
2074d369
...
@@ -29,7 +29,7 @@ __all__ = [
...
@@ -29,7 +29,7 @@ __all__ = [
'SGD'
,
'Momentum'
,
'Adagrad'
,
'Adam'
,
'Adamax'
,
'DecayedAdagrad'
,
'Ftrl'
,
'SGD'
,
'Momentum'
,
'Adagrad'
,
'Adam'
,
'Adamax'
,
'DecayedAdagrad'
,
'Ftrl'
,
'SGDOptimizer'
,
'MomentumOptimizer'
,
'AdagradOptimizer'
,
'AdamOptimizer'
,
'SGDOptimizer'
,
'MomentumOptimizer'
,
'AdagradOptimizer'
,
'AdamOptimizer'
,
'AdamaxOptimizer'
,
'DecayedAdagradOptimizer'
,
'RMSPropOptimizer'
,
'AdamaxOptimizer'
,
'DecayedAdagradOptimizer'
,
'RMSPropOptimizer'
,
'FtrlOptimizer'
,
'Adadelta'
,
'ModelAverage'
,
'Optimizer'
'FtrlOptimizer'
,
'Adadelta'
,
'ModelAverage'
,
'Optimizer'
,
'RMSPropOptimizer'
]
]
...
@@ -192,15 +192,15 @@ class Optimizer(object):
...
@@ -192,15 +192,15 @@ class Optimizer(object):
"""Add optimization operators to update gradients to variables.
"""Add optimization operators to update gradients to variables.
Args:
Args:
loss: the target that this optimization is for.
loss(Variable): the target that this optimization is for.
parameters_and_grads: a list of (variable, gradient) pair to update.
parameters_and_grads(list(tuple(Variable, Variable))):
a list of (variable, gradient) pair to update.
Returns:
Returns:
return_op_list: a list of operators that will complete one step of
return_op_list: a list of operators that will complete one step of
optimization. This will include parameter update ops, global step
optimization. This will include parameter update ops, global step
update ops and any other custom ops required by subclasses to manage
update ops and any other custom ops required by subclasses to manage
their internal state.
their internal state.
:param startup_program:
"""
"""
# This is a default implementation of create_optimization_pass that
# This is a default implementation of create_optimization_pass that
# can be shared by most optimizers. This implementation assumes that
# can be shared by most optimizers. This implementation assumes that
...
@@ -268,7 +268,22 @@ class Optimizer(object):
...
@@ -268,7 +268,22 @@ class Optimizer(object):
class
SGDOptimizer
(
Optimizer
):
class
SGDOptimizer
(
Optimizer
):
""" Simple SGD optimizer without any state.
"""
Optimizer of the stochastic gradient descent algorithm.
.. math::
param\_out = param - learning\_rate * grad
Args:
learning_rate (float|Variable): the learning rate used to update parameters.
\
Can be a float value or a Variable with one float value as data element.
Examples:
.. code-block:: python
sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.2)
sgd_optimizer.minimize(cost)
"""
"""
def
__init__
(
self
,
learning_rate
,
**
kwargs
):
def
__init__
(
self
,
learning_rate
,
**
kwargs
):
...
@@ -294,7 +309,37 @@ class SGDOptimizer(Optimizer):
...
@@ -294,7 +309,37 @@ class SGDOptimizer(Optimizer):
class
MomentumOptimizer
(
Optimizer
):
class
MomentumOptimizer
(
Optimizer
):
"""Simple Momentum optimizer with velocity state
"""
Simple Momentum optimizer with velocity state
This optimizer has a flag for Nestrov Momentum.
The update equations are as follows:
.. math::
& velocity = mu * velocity + gradient
& if (use\_nesterov):
&\quad param = param - gradient * learning\_rate + mu * velocity * learning\_rate
& else:
&\quad param = param - learning\_rate * velocity
Args:
learning_rate (float|Variable): the learning rate used to update parameters.
\
Can be a float value or a Variable with one float value as data element.
momentum (float): momentum factor
use_nesterov (bool): enables Nesterov momentum
Examples:
.. code-block:: python
optimizer = fluid.optimizer.Momentum(learning_rate=0.2, momentum=0.1)
optimizer.minimize(cost)
"""
"""
_velocity_acc_str
=
"velocity"
_velocity_acc_str
=
"velocity"
...
@@ -338,7 +383,32 @@ class MomentumOptimizer(Optimizer):
...
@@ -338,7 +383,32 @@ class MomentumOptimizer(Optimizer):
class
AdagradOptimizer
(
Optimizer
):
class
AdagradOptimizer
(
Optimizer
):
"""Simple Adagrad optimizer with moment state
"""
**Adaptive Gradient Algorithm (Adagrad)**
The update is done as follows:
.. math::
moment\_out &= moment + grad * grad
param\_out &= param -
\\
frac{learning\_rate * grad}{\sqrt{moment\_out} + \epsilon}
The original paper(http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf)
does not have the epsilon attribute. It is added here in our implementation
as also proposed here: http://cs231n.github.io/neural-networks-3/#ada
for numerical stability to avoid the division by zero error.
Args:
learning_rate (float|Variable): the learning rate used to update parameters.
\
Can be a float value or a Variable with one float value as data element.
epsilon (float): a small float value for numerical stability.
Examples:
.. code-block:: python
optimizer = fluid.optimizer.Adagrad(learning_rate=0.2)
optimizer.minimize(cost)
"""
"""
_moment_acc_str
=
"moment"
_moment_acc_str
=
"moment"
...
@@ -379,7 +449,40 @@ class AdagradOptimizer(Optimizer):
...
@@ -379,7 +449,40 @@ class AdagradOptimizer(Optimizer):
class
AdamOptimizer
(
Optimizer
):
class
AdamOptimizer
(
Optimizer
):
"""Implements the Adam Optimizer
"""
This implements the Adam optimizer from Section 2 of the Adam
paper : https://arxiv.org/abs/1412.6980.
Adam is a first-order gradient-based optimization method based on
adaptive estimates of lower-order moments.
Adam updates:
.. math::
t & = t + 1
moment\_1\_out & = {
\\
beta}_1 * moment\_1 + (1 - {
\\
beta}_1) * grad
moment\_2\_out & = {
\\
beta}_2 * moment\_2 + (1 - {
\\
beta}_2) * grad * grad
learning\_rate & = learning\_rate *
\\
\\
frac{\sqrt{1 - {
\\
beta}_2^t}}{1 - {
\\
beta}_1^t}
param\_out & = param - learning\_rate *
\\
frac{moment\_1}{\sqrt{moment\_2} + \epsilon}
Args:
learning_rate (float|Variable): the learning rate used to update parameters.
\
Can be a float value or a Variable with one float value as data element.
beta1 (float): The exponential decay rate for the 1st moment estimates.
beta2 (float): The exponential decay rate for the 2nd moment estimates.
epsilon (float): a small float value for numerical stability.
Examples:
.. code-block:: python
optimizer = fluid.optimizer.Adam(learning_rate=0.2)
optimizer.minimize(cost)
"""
"""
_moment1_acc_str
=
"moment1"
_moment1_acc_str
=
"moment1"
_moment2_acc_str
=
"moment2"
_moment2_acc_str
=
"moment2"
...
@@ -484,7 +587,42 @@ class AdamOptimizer(Optimizer):
...
@@ -484,7 +587,42 @@ class AdamOptimizer(Optimizer):
class
AdamaxOptimizer
(
Optimizer
):
class
AdamaxOptimizer
(
Optimizer
):
"""Implements the Adamax Optimizer
"""
We implement the Adamax optimizer from Section 7 of the Adam
paper: https://arxiv.org/abs/1412.6980. Adamax is a variant of the
Adam algorithm based on the infinity norm.
Adamax updates:
.. math::
t & = t + 1
moment\_out & = {
\\
beta}_1 * moment + (1 - {
\\
beta}_1) * grad
inf\_norm\_out & = max({
\\
beta}_2 * inf\_norm + \epsilon, |grad|)
learning\_rate & =
\\
frac{learning\_rate}{1 - {
\\
beta}_1^t}
param\_out & = param - learning\_rate *
\\
frac{moment\_out}{inf\_norm\_out}
The original paper does not have an epsilon attribute.
However, it is added here for numerical stability to prevent the
division by 0 error.
Args:
learning_rate (float|Variable): the learning rate used to update parameters.
\
Can be a float value or a Variable with one float value as data element.
beta1 (float): The exponential decay rate for the 1st moment estimates.
beta2 (float): The exponential decay rate for the 2nd moment estimates.
epsilon (float): a small float value for numerical stability.
Examples:
.. code-block:: python
optimizer = fluid.optimizer.Adamax(learning_rate=0.2)
optimizer.minimize(cost)
"""
"""
_moment_acc_str
=
"moment"
_moment_acc_str
=
"moment"
_inf_norm_acc_str
=
"inf_norm"
_inf_norm_acc_str
=
"inf_norm"
...
@@ -568,7 +706,34 @@ class AdamaxOptimizer(Optimizer):
...
@@ -568,7 +706,34 @@ class AdamaxOptimizer(Optimizer):
class
DecayedAdagradOptimizer
(
Optimizer
):
class
DecayedAdagradOptimizer
(
Optimizer
):
"""Simple Decayed Adagrad optimizer with moment state
"""
**Decayed Adagrad Optimizer**
The original paper(http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf)
The update is done as follows:
.. math::
moment\_out & = decay * moment + (1 - decay) * grad * grad
param\_out & = param -
\\
frac{learning\_rate * grad}{\sqrt{moment\_out} + \epsilon}
The original paper(http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf)
does not have an epsilon attribute. It is added here for numerical
stability to avoid the division by zero error.
Args:
learning_rate (float|Variable): the learning rate used to update parameters.
\
Can be a float value or a Variable with one float value as data element.
decay (float): decay rate.
epsilon (float): a small float value for numerical stability.
Examples:
.. code-block:: python
optimizer = fluid.optimizer.DecayedAdagrad(learning_rate=0.2)
optimizer.minimize(cost)
"""
"""
_moment_acc_str
=
"moment"
_moment_acc_str
=
"moment"
...
@@ -614,6 +779,7 @@ class DecayedAdagradOptimizer(Optimizer):
...
@@ -614,6 +779,7 @@ class DecayedAdagradOptimizer(Optimizer):
class
AdadeltaOptimizer
(
Optimizer
):
class
AdadeltaOptimizer
(
Optimizer
):
"""
"""
**Adadelta Optimizer**
**Adadelta Optimizer**
Simple Adadelta optimizer with average squared grad state and
Simple Adadelta optimizer with average squared grad state and
average squared update state.
average squared update state.
The details of adadelta please refer to this
The details of adadelta please refer to this
...
@@ -703,26 +869,26 @@ class RMSPropOptimizer(Optimizer):
...
@@ -703,26 +869,26 @@ class RMSPropOptimizer(Optimizer):
.. math::
.. math::
r(w, t) & =
\\
rho r(w, t-1) + (1 -
\\
rho)(
\\
nabla Q_{i}(w))^2
\\\\
r(w, t) & =
\\
rho r(w, t-1) + (1 -
\\
rho)(
\\
nabla Q_{i}(w))^2
w & = w -
\\
frac{
\\
eta} {
\\
sqrt{r(w,t) +
\\
epsilon}}
\\
nabla Q_{i}(w)
w & = w -
\\
frac{
\\
eta} {
\\
sqrt{r(w,t) +
\\
epsilon}}
\\
nabla Q_{i}(w)
The first equation calculates moving average of the squared gradient for
The first equation calculates moving average of the squared gradient for
each weight. Then dividing the gradient by :math:
`sqrt{v(w,t)}`.
each weight. Then dividing the gradient by :math:`sqrt{v(w,t)}`.
In some cases, adding a momentum term :math: `
\\
beta` is beneficial.
In some cases, adding a momentum term :math: `
\\
beta` is beneficial.
In our implementation, Nesterov momentum is used:
In our implementation, Nesterov momentum is used:
.. math::
.. math::
r(w, t) & =
\\
rho r(w, t-1) + (1 -
\\
rho)(
\\
nabla Q_{i}(w))^2
\\\\
r(w, t) & =
\\
rho r(w, t-1) + (1 -
\\
rho)(
\\
nabla Q_{i}(w))^2
v(w, t) & =
\\
beta v(w, t-1) +
\\
frac{
\\
eta} {
\\
sqrt{v(w,t) +
v(w, t) & =
\\
beta v(w, t-1) +
\\
frac{
\\
eta} {
\\
sqrt{v(w,t) +
\\
epsilon}}
\\
nabla Q_{i}(w)
\\
epsilon}}
\\
nabla Q_{i}(w)
w & = w - v(w, t)
w & = w - v(w, t)
where, :math:
`
\\
rho` is a hyperparameter and typical values are 0.9, 0.95
where, :math:`
\\
rho` is a hyperparameter and typical values are 0.9, 0.95
and so on. :math: `beta` is the momentum term. :math: `
\\
epsilon` is a
and so on. :math: `beta` is the momentum term. :math: `
\\
epsilon` is a
smoothing term to avoid division by zero, usually set somewhere in range
smoothing term to avoid division by zero, usually set somewhere in range
from 1e-4 to 1e-8.
from 1e-4 to 1e-8.
...
@@ -733,7 +899,7 @@ class RMSPropOptimizer(Optimizer):
...
@@ -733,7 +899,7 @@ class RMSPropOptimizer(Optimizer):
rho(float): rho is :math: `
\\
rho` in equation, set 0.95 by default.
rho(float): rho is :math: `
\\
rho` in equation, set 0.95 by default.
epsilon(float): :math: `
\\
epsilon` in equation is smoothing term to
epsilon(float): :math: `
\\
epsilon` in equation is smoothing term to
avoid division by zero, set 1e-6 by default.
avoid division by zero, set 1e-6 by default.
momentum(float): :math:
`
\\
beta` in equation is the momentum term,
momentum(float): :math:`
\\
beta` in equation is the momentum term,
set 0.0 by default.
set 0.0 by default.
Raises:
Raises:
...
@@ -952,7 +1118,9 @@ class ModelAverage(Optimizer):
...
@@ -952,7 +1118,9 @@ class ModelAverage(Optimizer):
max_average_window: The maximum size of average window.
max_average_window: The maximum size of average window.
Examples:
Examples:
...
.. code-block:: python
optimizer = fluid.optimizer.Momentum()
optimizer = fluid.optimizer.Momentum()
_, params_grads = optimizer.minimize(cost)
_, params_grads = optimizer.minimize(cost)
model_average = fluid.optimizer.ModelAverage(params_grads, 0.15,
model_average = fluid.optimizer.ModelAverage(params_grads, 0.15,
...
...
python/paddle/fluid/regularizer.py
浏览文件 @
2074d369
...
@@ -16,8 +16,8 @@ import framework
...
@@ -16,8 +16,8 @@ import framework
from
.
import
core
from
.
import
core
__all__
=
[
__all__
=
[
'append_regularization_ops'
,
'
WeightDecayRegularizer'
,
'L1Decay'
,
'L2Decay
'
,
'append_regularization_ops'
,
'
L1Decay'
,
'L2Decay'
,
'L1DecayRegularizer
'
,
'L
1DecayRegularizer'
,
'L
2DecayRegularizer'
'L2DecayRegularizer'
]
]
...
@@ -36,7 +36,8 @@ def append_regularization_ops(parameters_and_grads, regularization=None):
...
@@ -36,7 +36,8 @@ def append_regularization_ops(parameters_and_grads, regularization=None):
set. It will be applied with regularizer.
set. It will be applied with regularizer.
Returns:
Returns:
list of (parameters, gradients) pair with the regularized gradient
list[(Variable, Variable)]: list of (parameters, gradients)
\
pair with the regularized gradient
Raises:
Raises:
Exception: Unknown regularization type
Exception: Unknown regularization type
...
@@ -100,6 +101,24 @@ class WeightDecayRegularizer(object):
...
@@ -100,6 +101,24 @@ class WeightDecayRegularizer(object):
class
L2DecayRegularizer
(
WeightDecayRegularizer
):
class
L2DecayRegularizer
(
WeightDecayRegularizer
):
"""Implements the L2 Weight Decay Regularization
"""Implements the L2 Weight Decay Regularization
Small values of L2 can help prevent over fitting the training data.
.. math::
L2WeightDecay = reg\_coeff * parameter
Args:
regularization_coeff(float): regularization coeff
Examples:
.. code-block:: python
optimizer = fluid.optimizer.Adagrad(
learning_rate=1e-4,
regularization=fluid.regularizer.L2DecayRegularizer(
regularization_coeff=0.1))
optimizer.minimize(avg_cost)
"""
"""
def
__init__
(
self
,
regularization_coeff
=
0.0
):
def
__init__
(
self
,
regularization_coeff
=
0.0
):
...
@@ -154,6 +173,27 @@ class L2DecayRegularizer(WeightDecayRegularizer):
...
@@ -154,6 +173,27 @@ class L2DecayRegularizer(WeightDecayRegularizer):
class
L1DecayRegularizer
(
WeightDecayRegularizer
):
class
L1DecayRegularizer
(
WeightDecayRegularizer
):
"""Implements the L1 Weight Decay Regularization
"""Implements the L1 Weight Decay Regularization
L1 regularization encourages sparsity.
.. math::
L1WeightDecay = reg\_coeff * sign(parameter)
Args:
regularization_coeff(float): regularization coeff
Examples:
.. code-block:: python
program = fluid.framework.Program()
block = program.global_block()
mul_x = block.create_parameter(
dtype="float32",
shape=[5, 10],
lod_level=0,
name="mul.x",
regularizer=fluid.regularizer.L1DecayRegularizer(0.5))
"""
"""
def
__init__
(
self
,
regularization_coeff
=
0.0
):
def
__init__
(
self
,
regularization_coeff
=
0.0
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录