Unverified commit f66d08c2
Authored Sep 20, 2018 by Xin Pan; committed by GitHub on Sep 20, 2018

Merge pull request #13493 from panyx0718/doc

convert **kwargs to explicit arguments

Parent commits: 943c46c7, 88ae3f16
Showing 8 changed files with 116 additions and 72 deletions (+116 -72)

benchmark/fluid/args.py                                    +0   -4
benchmark/fluid/models/resnet.py                           +0   -5
benchmark/fluid/models/resnet_with_preprocess.py           +0   -5
benchmark/fluid/models/se_resnext.py                       +1   -7
paddle/fluid/API.spec                                      +10  -10
python/paddle/fluid/optimizer.py                           +99  -32
python/paddle/fluid/regularizer.py                         +5   -8
python/paddle/fluid/tests/book/test_recognize_digits.py    +1   -1
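In short: every optimizer constructor that previously swallowed extra options through **kwargs now declares regularization and name explicitly and forwards them to the Optimizer base class, and the experimental LARS_weight_decay path is removed along the way. A minimal sketch of the before/after pattern, using simplified stand-ins rather than the real paddle.fluid classes:

# Simplified stand-ins illustrating the signature change in this commit;
# these are not the actual paddle.fluid.optimizer classes.

class OptimizerBase(object):
    # After the change, the base class takes only explicit keyword arguments.
    def __init__(self, learning_rate, regularization=None, name=None):
        self.learning_rate = learning_rate
        self.regularization = regularization
        self.name = name


class SGDBefore(OptimizerBase):
    # Before: extra options were forwarded blindly, so unsupported keys
    # could be accepted (or silently ignored) at every call site.
    def __init__(self, learning_rate, **kwargs):
        super(SGDBefore, self).__init__(learning_rate=learning_rate, **kwargs)


class SGDAfter(OptimizerBase):
    # After: the supported options are spelled out, so unknown keywords fail
    # fast and the documented signature matches what the constructor accepts.
    def __init__(self, learning_rate, regularization=None, name=None):
        super(SGDAfter, self).__init__(
            learning_rate=learning_rate,
            regularization=regularization,
            name=name)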
benchmark/fluid/args.py

@@ -136,10 +136,6 @@ def parse_args():
         '--no_random',
         action='store_true',
         help='If set, keep the random seed and do not shuffle the data.')
-    parser.add_argument(
-        '--use_lars',
-        action='store_true',
-        help='If set, use lars for optimizers, ONLY support resnet module.')
     parser.add_argument(
         '--reduce_strategy',
         type=str,
benchmark/fluid/models/resnet.py

@@ -200,11 +200,6 @@ def get_model(args, is_train, main_prog, startup_prog):
     # configure optimize
     optimizer = None
     if is_train:
-        if args.use_lars:
-            lars_decay = 1.0
-        else:
-            lars_decay = 0.0
-
         total_images = 1281167 / trainer_count
 
         step = int(total_images / (args.batch_size * args.gpus) + 1)
benchmark/fluid/models/resnet_with_preprocess.py

@@ -224,11 +224,6 @@ def get_model(args, is_train, main_prog, startup_prog):
     # configure optimize
     optimizer = None
     if is_train:
-        if args.use_lars:
-            lars_decay = 1.0
-        else:
-            lars_decay = 0.0
-
         total_images = 1281167 / trainer_count
 
         step = int(total_images / args.batch_size + 1)
benchmark/fluid/models/se_resnext.py

@@ -244,11 +244,6 @@ def get_model(args, is_train, main_prog, startup_prog):
     optimizer = None
     if is_train:
-        if args.use_lars:
-            lars_decay = 1.0
-        else:
-            lars_decay = 0.0
-
         total_images = 1281167 / trainer_count
 
         step = int(total_images / args.batch_size + 1)
@@ -262,8 +257,7 @@ def get_model(args, is_train, main_prog, startup_prog):
             learning_rate=fluid.layers.piecewise_decay(
                 boundaries=bd, values=lr),
             momentum=0.9,
-            regularization=fluid.regularizer.L2Decay(1e-4),
-            LARS_weight_decay=lars_decay)
+            regularization=fluid.regularizer.L2Decay(1e-4))
         optimizer.minimize(avg_cost)
 
     if args.memory_optimize:
paddle/fluid/API.spec

@@ -350,25 +350,25 @@ paddle.fluid.nets.simple_img_conv_pool ArgSpec(args=['input', 'num_filters', 'fi
 paddle.fluid.nets.sequence_conv_pool ArgSpec(args=['input', 'num_filters', 'filter_size', 'param_attr', 'act', 'pool_type'], varargs=None, keywords=None, defaults=(None, 'sigmoid', 'max'))
 paddle.fluid.nets.glu ArgSpec(args=['input', 'dim'], varargs=None, keywords=None, defaults=(-1,))
 paddle.fluid.nets.scaled_dot_product_attention ArgSpec(args=['queries', 'keys', 'values', 'num_heads', 'dropout_rate'], varargs=None, keywords=None, defaults=(1, 0.0))
-paddle.fluid.optimizer.SGDOptimizer.__init__ ArgSpec(args=['self', 'learning_rate'], varargs=None, keywords='kwargs', defaults=None)
+paddle.fluid.optimizer.SGDOptimizer.__init__ ArgSpec(args=['self', 'learning_rate', 'regularization', 'name'], varargs=None, keywords=None, defaults=(None, None))
 paddle.fluid.optimizer.SGDOptimizer.minimize ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None, None))
-paddle.fluid.optimizer.MomentumOptimizer.__init__ ArgSpec(args=['self', 'learning_rate', 'momentum', 'use_nesterov'], varargs=None, keywords='kwargs', defaults=(False,))
+paddle.fluid.optimizer.MomentumOptimizer.__init__ ArgSpec(args=['self', 'learning_rate', 'momentum', 'use_nesterov', 'regularization', 'name'], varargs=None, keywords=None, defaults=(False, None, None))
 paddle.fluid.optimizer.MomentumOptimizer.minimize ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None, None))
-paddle.fluid.optimizer.AdagradOptimizer.__init__ ArgSpec(args=['self', 'learning_rate', 'epsilon'], varargs=None, keywords='kwargs', defaults=(1e-06,))
+paddle.fluid.optimizer.AdagradOptimizer.__init__ ArgSpec(args=['self', 'learning_rate', 'epsilon', 'regularization', 'name'], varargs=None, keywords=None, defaults=(1e-06, None, None))
 paddle.fluid.optimizer.AdagradOptimizer.minimize ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None, None))
-paddle.fluid.optimizer.AdamOptimizer.__init__ ArgSpec(args=['self', 'learning_rate', 'beta1', 'beta2', 'epsilon'], varargs=None, keywords='kwargs', defaults=(0.001, 0.9, 0.999, 1e-08))
+paddle.fluid.optimizer.AdamOptimizer.__init__ ArgSpec(args=['self', 'learning_rate', 'beta1', 'beta2', 'epsilon', 'regularization', 'name'], varargs=None, keywords=None, defaults=(0.001, 0.9, 0.999, 1e-08, None, None))
 paddle.fluid.optimizer.AdamOptimizer.minimize ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None, None))
-paddle.fluid.optimizer.AdamaxOptimizer.__init__ ArgSpec(args=['self', 'learning_rate', 'beta1', 'beta2', 'epsilon'], varargs=None, keywords='kwargs', defaults=(0.001, 0.9, 0.999, 1e-08))
+paddle.fluid.optimizer.AdamaxOptimizer.__init__ ArgSpec(args=['self', 'learning_rate', 'beta1', 'beta2', 'epsilon', 'regularization', 'name'], varargs=None, keywords=None, defaults=(0.001, 0.9, 0.999, 1e-08, None, None))
 paddle.fluid.optimizer.AdamaxOptimizer.minimize ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None, None))
-paddle.fluid.optimizer.DecayedAdagradOptimizer.__init__ ArgSpec(args=['self', 'learning_rate', 'decay', 'epsilon'], varargs=None, keywords='kwargs', defaults=(0.95, 1e-06))
+paddle.fluid.optimizer.DecayedAdagradOptimizer.__init__ ArgSpec(args=['self', 'learning_rate', 'decay', 'epsilon', 'regularization', 'name'], varargs=None, keywords=None, defaults=(0.95, 1e-06, None, None))
 paddle.fluid.optimizer.DecayedAdagradOptimizer.minimize ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None, None))
-paddle.fluid.optimizer.FtrlOptimizer.__init__ ArgSpec(args=['self', 'learning_rate', 'l1', 'l2', 'lr_power'], varargs=None, keywords='kwargs', defaults=(0.0, 0.0, -0.5))
+paddle.fluid.optimizer.FtrlOptimizer.__init__ ArgSpec(args=['self', 'learning_rate', 'l1', 'l2', 'lr_power', 'regularization', 'name'], varargs=None, keywords=None, defaults=(0.0, 0.0, -0.5, None, None))
 paddle.fluid.optimizer.FtrlOptimizer.minimize ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None, None))
-paddle.fluid.optimizer.RMSPropOptimizer.__init__ ArgSpec(args=['self', 'learning_rate', 'rho', 'epsilon', 'momentum', 'centered'], varargs=None, keywords='kwargs', defaults=(0.95, 1e-06, 0.0, False))
+paddle.fluid.optimizer.RMSPropOptimizer.__init__ ArgSpec(args=['self', 'learning_rate', 'rho', 'epsilon', 'momentum', 'centered', 'regularization', 'name'], varargs=None, keywords=None, defaults=(0.95, 1e-06, 0.0, False, None, None))
 paddle.fluid.optimizer.RMSPropOptimizer.minimize ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None, None))
-paddle.fluid.optimizer.AdadeltaOptimizer.__init__ ArgSpec(args=['self', 'learning_rate', 'epsilon', 'rho'], varargs=None, keywords='kwargs', defaults=(1e-06, 0.95))
+paddle.fluid.optimizer.AdadeltaOptimizer.__init__ ArgSpec(args=['self', 'learning_rate', 'epsilon', 'rho', 'regularization', 'name'], varargs=None, keywords=None, defaults=(1e-06, 0.95, None, None))
 paddle.fluid.optimizer.AdadeltaOptimizer.minimize ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None, None))
-paddle.fluid.optimizer.ModelAverage.__init__ ArgSpec(args=['self', 'average_window_rate', 'min_average_window', 'max_average_window'], varargs=None, keywords='kwargs', defaults=(10000, 10000))
+paddle.fluid.optimizer.ModelAverage.__init__ ArgSpec(args=['self', 'average_window_rate', 'min_average_window', 'max_average_window', 'regularization', 'name'], varargs=None, keywords=None, defaults=(10000, 10000, None, None))
 paddle.fluid.optimizer.ModelAverage.apply ArgSpec(args=[], varargs='args', keywords='kwds', defaults=None)
 paddle.fluid.optimizer.ModelAverage.minimize ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None, None))
 paddle.fluid.optimizer.ModelAverage.restore ArgSpec(args=['self', 'executor'], varargs=None, keywords=None, defaults=None)
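Each line in API.spec records the signature of a public API in the format produced by Python's inspect.getargspec, which is why dropping **kwargs shows up as keywords='kwargs' becoming keywords=None and the defaults tuple gaining two trailing None values. A hedged sketch of how such an entry could be reproduced (illustrative only, not the project's actual spec generator; modern Python exposes the same data via inspect.getfullargspec, where the keywords field is named varkw):

import inspect


def argspec_line(func):
    # Rebuild an ArgSpec(...)-style string from a function signature.
    spec = inspect.getfullargspec(func)
    return "ArgSpec(args=%r, varargs=%r, keywords=%r, defaults=%r)" % (
        spec.args, spec.varargs, spec.varkw, spec.defaults)


def init_old(self, learning_rate, **kwargs):  # pre-change style
    pass


def init_new(self, learning_rate, regularization=None, name=None):  # post-change style
    pass


print(argspec_line(init_old))
# ArgSpec(args=['self', 'learning_rate'], varargs=None, keywords='kwargs', defaults=None)
print(argspec_line(init_new))
# ArgSpec(args=['self', 'learning_rate', 'regularization', 'name'], varargs=None, keywords=None, defaults=(None, None))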
python/paddle/fluid/optimizer.py

@@ -43,11 +43,7 @@ class Optimizer(object):
     but need to use one of it's implementation.
     """
 
-    def __init__(self,
-                 learning_rate,
-                 regularization=None,
-                 LARS_weight_decay=0.0,
-                 name=None):
+    def __init__(self, learning_rate, regularization=None, name=None):
         if not isinstance(learning_rate, float) and \
                 not isinstance(learning_rate, framework.Variable):
             raise TypeError("learning rate should be float or Variable")
@@ -68,7 +64,6 @@ class Optimizer(object):
         # {accum_name : { paramter_name : accumulator_for_parameter, ...}, ...}
         self._accumulators = defaultdict(lambda: dict())
         self.helper = None
-        self._LARS_weight_decay = LARS_weight_decay
 
     def _create_global_learning_rate(self):
         lr = self._global_learning_rate()
@@ -109,7 +104,6 @@ class Optimizer(object):
         param = param_and_grad[0]
         param_lr = param.optimize_attr['learning_rate']
         if type(param_lr) == Variable:
-            # param learning rate has been updated (LARS)
             print("returns updated param lr ", param_lr)
             return param_lr
         else:
@@ -227,10 +221,6 @@ class Optimizer(object):
         self._create_accumulators(loss.block,
                                   [p[0] for p in parameters_and_grads])
         self._create_global_learning_rate()
-        if self._LARS_weight_decay > 0.0:
-            layers.append_LARS(parameters_and_grads,
-                               self._global_learning_rate(),
-                               self._LARS_weight_decay)
 
         optimize_ops = []
         for param_and_grad in parameters_and_grads:
@@ -287,6 +277,9 @@ class SGDOptimizer(Optimizer):
     Args:
         learning_rate (float|Variable): the learning rate used to update parameters. \
         Can be a float value or a Variable with one float value as data element.
+        regularization: A Regularizer, such as
+                        fluid.regularizer.L2DecayRegularizer.
+        name: A optional name prefix.
 
     Examples:
         .. code-block:: python
@@ -295,10 +288,12 @@
             sgd_optimizer.minimize(cost)
     """
 
-    def __init__(self, learning_rate, **kwargs):
+    def __init__(self, learning_rate, regularization=None, name=None):
         assert learning_rate is not None
         super(SGDOptimizer, self).__init__(
-            learning_rate=learning_rate, **kwargs)
+            learning_rate=learning_rate,
+            regularization=regularization,
+            name=name)
         self.type = "sgd"
 
     def _append_optimize_op(self, block, param_and_grad):
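With the explicit signature, the regularizer is passed straight to the constructor. A short usage sketch under the fluid API of that era; the toy network around it is illustrative, only the optimizer call reflects the new signature:

import paddle.fluid as fluid

# Toy regression network, just to have a loss to minimize.
x = fluid.layers.data(name='x', shape=[13], dtype='float32')
y = fluid.layers.data(name='y', shape=[1], dtype='float32')
y_predict = fluid.layers.fc(input=x, size=1)
cost = fluid.layers.square_error_cost(input=y_predict, label=y)
avg_cost = fluid.layers.mean(cost)

# New-style construction: regularization and name are explicit keyword
# arguments instead of entries smuggled in through **kwargs.
sgd_optimizer = fluid.optimizer.SGD(
    learning_rate=0.01,
    regularization=fluid.regularizer.L2DecayRegularizer(
        regularization_coeff=1e-4))
sgd_optimizer.minimize(avg_cost)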
@@ -343,6 +338,9 @@ class MomentumOptimizer(Optimizer):
         Can be a float value or a Variable with one float value as data element.
         momentum (float): momentum factor
         use_nesterov (bool): enables Nesterov momentum
+        regularization: A Regularizer, such as
+                        fluid.regularizer.L2DecayRegularizer.
+        name: A optional name prefix.
 
     Examples:
         .. code-block:: python
@@ -352,11 +350,18 @@
     """
     _velocity_acc_str = "velocity"
 
-    def __init__(self, learning_rate, momentum, use_nesterov=False, **kwargs):
+    def __init__(self,
+                 learning_rate,
+                 momentum,
+                 use_nesterov=False,
+                 regularization=None,
+                 name=None):
         assert learning_rate is not None
         assert momentum is not None
         super(MomentumOptimizer, self).__init__(
-            learning_rate=learning_rate, **kwargs)
+            learning_rate=learning_rate,
+            regularization=regularization,
+            name=name)
         self.type = "momentum"
         self._momentum = momentum
         self._use_nesterov = bool(use_nesterov)
@@ -412,6 +417,9 @@ class AdagradOptimizer(Optimizer):
         learning_rate (float|Variable): the learning rate used to update parameters. \
         Can be a float value or a Variable with one float value as data element.
         epsilon (float): a small float value for numerical stability.
+        regularization: A Regularizer, such as
+                        fluid.regularizer.L2DecayRegularizer.
+        name: A optional name prefix.
 
     Examples:
         .. code-block:: python
@@ -421,11 +429,17 @@
     """
     _moment_acc_str = "moment"
 
-    def __init__(self, learning_rate, epsilon=1.0e-6, **kwargs):
+    def __init__(self,
+                 learning_rate,
+                 epsilon=1.0e-6,
+                 regularization=None,
+                 name=None):
         assert learning_rate is not None
         assert epsilon is not None
         super(AdagradOptimizer, self).__init__(
-            learning_rate=learning_rate, **kwargs)
+            learning_rate=learning_rate,
+            regularization=regularization,
+            name=name)
         self.type = "adagrad"
         self._epsilon = epsilon
@@ -485,6 +499,9 @@ class AdamOptimizer(Optimizer):
         beta1 (float): The exponential decay rate for the 1st moment estimates.
         beta2 (float): The exponential decay rate for the 2nd moment estimates.
         epsilon (float): a small float value for numerical stability.
+        regularization: A Regularizer, such as
+                        fluid.regularizer.L2DecayRegularizer.
+        name: A optional name prefix.
 
     Examples:
         .. code-block:: python
@@ -503,13 +520,16 @@
                  beta1=0.9,
                  beta2=0.999,
                  epsilon=1e-8,
-                 **kwargs):
+                 regularization=None,
+                 name=None):
         assert learning_rate is not None
         assert beta1 is not None
         assert beta2 is not None
         assert epsilon is not None
         super(AdamOptimizer, self).__init__(
-            learning_rate=learning_rate, **kwargs)
+            learning_rate=learning_rate,
+            regularization=regularization,
+            name=name)
         self.type = "adam"
         self._beta1 = beta1
         self._beta2 = beta2
@@ -629,6 +649,9 @@ class AdamaxOptimizer(Optimizer):
         beta1 (float): The exponential decay rate for the 1st moment estimates.
         beta2 (float): The exponential decay rate for the 2nd moment estimates.
         epsilon (float): a small float value for numerical stability.
+        regularization: A Regularizer, such as
+                        fluid.regularizer.L2DecayRegularizer.
+        name: A optional name prefix.
 
     Examples:
         .. code-block:: python
@@ -645,13 +668,16 @@
                  beta1=0.9,
                  beta2=0.999,
                  epsilon=1e-8,
-                 **kwargs):
+                 regularization=None,
+                 name=None):
         assert learning_rate is not None
         assert beta1 is not None
         assert beta2 is not None
         assert epsilon is not None
         super(AdamaxOptimizer, self).__init__(
-            learning_rate=learning_rate, **kwargs)
+            learning_rate=learning_rate,
+            regularization=regularization,
+            name=name)
         self.type = "adamax"
         self._beta1 = beta1
         self._beta2 = beta2
@@ -742,6 +768,9 @@ class DecayedAdagradOptimizer(Optimizer):
         Can be a float value or a Variable with one float value as data element.
         decay (float): decay rate.
         epsilon (float): a small float value for numerical stability.
+        regularization: A Regularizer, such as
+                        fluid.regularizer.L2DecayRegularizer.
+        name: A optional name prefix.
 
     Examples:
         .. code-block:: python
@@ -751,13 +780,20 @@
     """
     _moment_acc_str = "moment"
 
-    def __init__(self, learning_rate, decay=0.95, epsilon=1.0e-6, **kwargs):
+    def __init__(self,
+                 learning_rate,
+                 decay=0.95,
+                 epsilon=1.0e-6,
+                 regularization=None,
+                 name=None):
         assert learning_rate is not None
         assert decay is not None
         assert epsilon is not None
         super(DecayedAdagradOptimizer, self).__init__(
-            learning_rate=learning_rate, **kwargs)
+            learning_rate=learning_rate,
+            regularization=regularization,
+            name=name)
         self.type = "decayed_adagrad"
         self._decay = decay
         self._epsilon = epsilon
@@ -811,6 +847,9 @@ class AdadeltaOptimizer(Optimizer):
         learning_rate(float): global learning rate
         rho(float): rho in equation
         epsilon(float): epsilon in equation
+        regularization: A Regularizer, such as
+                        fluid.regularizer.L2DecayRegularizer.
+        name: A optional name prefix.
 
     Examples:
         .. code-block:: python
@@ -823,7 +862,12 @@
     _avg_squared_grad_acc_str = "_avg_squared_grad"
     _avg_squared_update_acc_str = "_avg_squared_update"
 
-    def __init__(self, learning_rate, epsilon=1.0e-6, rho=0.95, **kwargs):
+    def __init__(self,
+                 learning_rate,
+                 epsilon=1.0e-6,
+                 rho=0.95,
+                 regularization=None,
+                 name=None):
         if learning_rate is None:
             raise ValueError("learning_rate is not set.")
         if epsilon is None:
@@ -831,7 +875,9 @@ class AdadeltaOptimizer(Optimizer):
         if rho is None:
             raise ValueError("rho is not set.")
         super(AdadeltaOptimizer, self).__init__(
-            learning_rate=learning_rate, **kwargs)
+            learning_rate=learning_rate,
+            regularization=regularization,
+            name=name)
         self.type = "adadelta"
         self._epsilon = epsilon
         self._rho = rho
@@ -932,6 +978,9 @@ class RMSPropOptimizer(Optimizer):
         the gradient; if False, by the uncentered second moment. Setting this to
         True may help with training, but is slightly more expensive in terms of
         computation and memory. Defaults to False.
+        regularization: A Regularizer, such as
+                        fluid.regularizer.L2DecayRegularizer.
+        name: A optional name prefix.
 
     Raises:
         ValueError: If learning_rate, rho, epsilon, momentum are None.
@@ -953,9 +1002,12 @@
                  epsilon=1.0e-6,
                  momentum=0.0,
                  centered=False,
-                 **kwargs):
+                 regularization=None,
+                 name=None):
         super(RMSPropOptimizer, self).__init__(
-            learning_rate=learning_rate, **kwargs)
+            learning_rate=learning_rate,
+            regularization=regularization,
+            name=name)
         if learning_rate is None:
             raise ValueError("learning_rate is not set.")
         if rho is None:
@@ -1061,6 +1113,9 @@ class FtrlOptimizer(Optimizer):
         l1 (float):
         l2 (float):
         lr_power (float):
+        regularization: A Regularizer, such as
+                        fluid.regularizer.L2DecayRegularizer.
+        name: A optional name prefix.
 
     Raises:
         ValueError: If learning_rate, rho, epsilon, momentum are None.
@@ -1075,9 +1130,17 @@
     _squared_acc_str = "squared"
     _linear_acc_str = "linear"
 
-    def __init__(self, learning_rate, l1=0.0, l2=0.0, lr_power=-0.5, **kwargs):
+    def __init__(self,
+                 learning_rate,
+                 l1=0.0,
+                 l2=0.0,
+                 lr_power=-0.5,
+                 regularization=None,
+                 name=None):
         super(FtrlOptimizer, self).__init__(
-            learning_rate=learning_rate, **kwargs)
+            learning_rate=learning_rate,
+            regularization=regularization,
+            name=name)
         if learning_rate is None:
             raise ValueError("learning_rate is not set.")
@@ -1155,7 +1218,9 @@ class ModelAverage(Optimizer):
         average_window_rate: The rate of average window.
         min_average_window: The minimum size of average window.
         max_average_window: The maximum size of average window.
+        regularization: A Regularizer, such as
+                        fluid.regularizer.L2DecayRegularizer.
+        name: A optional name prefix.
 
     Examples:
         .. code-block:: python
@@ -1178,8 +1243,10 @@
                  average_window_rate,
                  min_average_window=10000,
                  max_average_window=10000,
-                 **kwargs):
-        super(ModelAverage, self).__init__(0.0, **kwargs)
+                 regularization=None,
+                 name=None):
+        super(ModelAverage, self).__init__(
+            0.0, regularization=regularization, name=name)
         self.average_window = average_window_rate
         self.min_average_window = min_average_window
         self.max_average_window = max_average_window
python/paddle/fluid/regularizer.py

@@ -190,14 +190,11 @@ class L1DecayRegularizer(WeightDecayRegularizer):
     Examples:
         .. code-block:: python
 
-            program = fluid.framework.Program()
-            block = program.global_block()
-            mul_x = block.create_parameter(
-                dtype="float32",
-                shape=[5, 10],
-                lod_level=0,
-                name="mul.x",
-                regularizer=fluid.regularizer.L1DecayRegularizer(0.5))
+            optimizer = fluid.optimizer.Adagrad(
+                learning_rate=1e-4,
+                regularization=fluid.regularizer.L1DecayRegularizer(
+                    regularization_coeff=0.1))
+            optimizer.minimize(avg_cost)
     """
 
     def __init__(self, regularization_coeff=0.0):
python/paddle/fluid/tests/book/test_recognize_digits.py

@@ -99,7 +99,7 @@ def train(nn_type,
     test_program = fluid.default_main_program().clone(for_test=True)
 
-    optimizer = fluid.optimizer.Adam(learning_rate=0.001, LARS_weight_decay=0.3)
+    optimizer = fluid.optimizer.Adam(learning_rate=0.001)
     optimizer.minimize(avg_loss)
 
     place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
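The test change is a direct consequence of the new signatures: Adam no longer accepts arbitrary keyword arguments, so the old LARS_weight_decay=0.3 argument would now be rejected by Python itself rather than forwarded to the base class. A tiny stand-in (not the real optimizer) showing the failure mode:

# Stand-in class with the post-change calling convention.
class AdamLike(object):
    def __init__(self, learning_rate, regularization=None, name=None):
        self.learning_rate = learning_rate


AdamLike(learning_rate=0.001)  # fine

try:
    AdamLike(learning_rate=0.001, LARS_weight_decay=0.3)
except TypeError as exc:
    # __init__() got an unexpected keyword argument 'LARS_weight_decay'
    print(exc)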