Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
c36bf197
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
695
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
c36bf197
编写于
3月 27, 2017
作者:
Q
qijun
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add optimizer doc
上级
929090ed
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
131 addition
and
2 deletion
+131
-2
doc/api/v2/config/optimizer.rst
doc/api/v2/config/optimizer.rst
+0
-2
python/paddle/v2/optimizer.py
python/paddle/v2/optimizer.py
+131
-0
未找到文件。
doc/api/v2/config/optimizer.rst
浏览文件 @
c36bf197
.. _api_v2.optimizer:
==========
Optimizer
==========
...
...
python/paddle/v2/optimizer.py
浏览文件 @
c36bf197
...
...
@@ -47,6 +47,35 @@ class Optimizer(object):
class
Momentum
(
Optimizer
):
"""
SGD Optimizer.
SGD is an optimization method, trying to find a neural network that
minimize the "cost/error" of it by iteration. In paddle's implementation
SGD Optimizer is synchronized, which means all gradients will be wait to
calculate and reduced into one gradient, then do optimize operation.
The neural network consider the learning problem of minimizing an objective
function, that has the form of a sum
.. math::
Q(w) =
\\
sum_{i}^{n} Q_i(w)
The value of function Q sometimes is the cost of neural network (Mean
Square Error between prediction and label for example). The function Q is
parametrised by w, the weight/bias of neural network. And weights is what to
be learned. The i is the i-th observation in (trainning) data.
So, the SGD method will optimize the weight by
.. math::
w = w -
\\
eta
\\
nabla Q(w) = w -
\\
eta
\\
sum_{i}^{n}
\\
nabla Q_i(w)
where :math:`
\\
eta` is learning rate. And :math:`n` is batch size.
"""
def
__init__
(
self
,
momentum
=
None
,
sparse
=
False
,
**
kwargs
):
learning_method
=
v1_optimizers
.
MomentumOptimizer
(
momentum
=
momentum
,
sparse
=
sparse
)
...
...
@@ -55,6 +84,26 @@ class Momentum(Optimizer):
class
Adam
(
Optimizer
):
"""
Adam optimizer.
The details of please refer `Adam: A Method for Stochastic Optimization
<https://arxiv.org/abs/1412.6980>`_
.. math::
m(w, t) & =
\\
beta_1 m(w, t-1) + (1 -
\\
beta_1)
\\
nabla Q_i(w)
\\\\
v(w, t) & =
\\
beta_2 v(w, t-1) + (1 -
\\
beta_2)(
\\
nabla Q_i(w)) ^2
\\\\
w & = w -
\\
frac{
\\
eta}{
\\
sqrt{v(w,t) +
\\
epsilon}}
:param beta1: the :math:`
\\
beta_1` in equation.
:type beta1: float
:param beta2: the :math:`
\\
beta_2` in equation.
:type beta2: float
:param epsilon: the :math:`
\\
epsilon` in equation. It is used to prevent
divided by zero.
:type epsilon: float
"""
def
__init__
(
self
,
beta1
=
0.9
,
beta2
=
0.999
,
epsilon
=
1e-8
,
**
kwargs
):
learning_method
=
v1_optimizers
.
AdamOptimizer
(
beta1
=
beta1
,
beta2
=
beta2
,
epsilon
=
epsilon
)
...
...
@@ -62,6 +111,24 @@ class Adam(Optimizer):
class
Adamax
(
Optimizer
):
"""
Adamax optimizer.
The details of please refer this `Adam: A Method for Stochastic Optimization
<https://arxiv.org/abs/1412.6980>`_
.. math::
m_t & =
\\
beta_1 * m_{t-1} + (1-
\\
beta_1)*
\\
nabla Q_i(w)
\\\\
u_t & = max(
\\
beta_2*u_{t-1}, abs(
\\
nabla Q_i(w)))
\\\\
w_t & = w_{t-1} - (
\\
eta/(1-
\\
beta_1^t))*m_t/u_t
:param beta1: the :math:`
\\
beta_1` in the equation.
:type beta1: float
:param beta2: the :math:`
\\
beta_2` in the equation.
:type beta2: float
"""
def
__init__
(
self
,
beta1
=
0.9
,
beta2
=
0.999
,
**
kwargs
):
learning_method
=
v1_optimizers
.
AdamaxOptimizer
(
beta1
=
beta1
,
beta2
=
beta2
)
...
...
@@ -69,12 +136,40 @@ class Adamax(Optimizer):
class
AdaGrad
(
Optimizer
):
"""
Adagrad(for ADAptive GRAdient algorithm) optimizer.
For details please refer this `Adaptive Subgradient Methods for
Online Learning and Stochastic Optimization
<http://www.magicbroom.info/Papers/DuchiHaSi10.pdf>`_.
.. math::
G &=
\\
sum_{
\\
tau=1}^{t} g_{
\\
tau} g_{
\\
tau}^T
\\\\
w & = w -
\\
eta diag(G)^{-
\\
frac{1}{2}}
\\
circ g
"""
def
__init__
(
self
,
**
kwargs
):
learning_method
=
v1_optimizers
.
AdaGradOptimizer
()
super
(
AdaGrad
,
self
).
__init__
(
learning_method
=
learning_method
,
**
kwargs
)
class
DecayedAdaGrad
(
Optimizer
):
"""
AdaGrad method with decayed sum gradients. The equations of this method
show as follow.
.. math::
E(g_t^2) &=
\\
rho * E(g_{t-1}^2) + (1-
\\
rho) * g^2
\\\\
learning
\\
_rate &= 1/sqrt( ( E(g_t^2) +
\\
epsilon )
:param rho: The :math:`
\\
rho` parameter in that equation
:type rho: float
:param epsilon: The :math:`
\\
epsilon` parameter in that equation.
:type epsilon: float
"""
def
__init__
(
self
,
rho
=
0.95
,
epsilon
=
1e-06
,
**
kwargs
):
learning_method
=
v1_optimizers
.
DecayedAdaGradOptimizer
(
rho
=
rho
,
epsilon
=
epsilon
)
...
...
@@ -83,6 +178,24 @@ class DecayedAdaGrad(Optimizer):
class
AdaDelta
(
Optimizer
):
"""
AdaDelta method. The details of adadelta please refer to this
`ADADELTA: AN ADAPTIVE LEARNING RATE METHOD
<http://www.matthewzeiler.com/pubs/googleTR2012/googleTR2012.pdf>`_.
.. math::
E(g_t^2) &=
\\
rho * E(g_{t-1}^2) + (1-
\\
rho) * g^2
\\\\
learning
\\
_rate &= sqrt( ( E(dx_{t-1}^2) +
\\
epsilon ) / (
\\
E(g_t^2) +
\\
epsilon ) )
\\\\
E(dx_t^2) &=
\\
rho * E(dx_{t-1}^2) + (1-
\\
rho) * (-g*learning
\\
_rate)^2
:param rho: :math:`
\\
rho` in equation
:type rho: float
:param epsilon: :math:`
\\
rho` in equation
:type epsilon: float
"""
def
__init__
(
self
,
rho
=
0.95
,
epsilon
=
1e-06
,
**
kwargs
):
learning_method
=
v1_optimizers
.
AdaDeltaOptimizer
(
rho
=
rho
,
epsilon
=
epsilon
)
...
...
@@ -91,6 +204,24 @@ class AdaDelta(Optimizer):
class
RMSProp
(
Optimizer
):
"""
RMSProp(for Root Mean Square Propagation) optimizer. For details please
refer this `slide <http://www.cs.toronto.edu/~tijmen/csc321/slides/
lecture_slides_lec6.pdf>`_.
The equations of this method as follows:
.. math::
v(w, t) & =
\\
rho v(w, t-1) + (1 -
\\
rho)(
\\
nabla Q_{i}(w))^2
\\\\
w & = w -
\\
frac{
\\
eta} {
\\
sqrt{v(w,t) +
\\
epsilon}}
\\
nabla Q_{i}(w)
:param rho: the :math:`
\\
rho` in the equation. The forgetting factor.
:type rho: float
:param epsilon: the :math:`
\\
epsilon` in the equation.
:type epsilon: float
"""
def
__init__
(
self
,
rho
=
0.95
,
epsilon
=
1e-6
,
**
kwargs
):
learning_method
=
v1_optimizers
.
RMSPropOptimizer
(
rho
=
rho
,
epsilon
=
epsilon
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录