Commit 8d4a79e5 (PaddlePaddle / PaddleClas)
Authored on Feb 07, 2023 by tianyi1997; committed on Feb 28, 2023 by HydrogenSulfate

Add cyclic learning rate

Parent: 3c21282d
Showing 1 changed file with 83 additions and 0 deletions.

ppcls/optimizer/learning_rate.py  (+83, -0)
@@ -253,6 +253,87 @@ class Cosine(LRBase):
        return learning_rate
class Cyclic(LRBase):
    """Cyclic learning rate decay

    Args:
        epochs (int): total epoch(s)
        step_each_epoch (int): number of iterations within an epoch
        base_learning_rate (float): Initial learning rate, which is the lower boundary in the cycle. The paper
            recommends setting base_learning_rate to 1/3 or 1/4 of max_learning_rate.
        max_learning_rate (float): Maximum learning rate in the cycle; it defines the cycle amplitude. Because a
            scaling function is applied while the learning rate is adjusted, max_learning_rate may not actually
            be reached.
        warmup_epoch (int): number of warmup epoch(s)
        warmup_start_lr (float): start learning rate within warmup
        step_size_up (int): Number of training steps used to increase the learning rate in a cycle. The length of
            one cycle is step_size_up + step_size_down. According to the paper, the step size should be at least
            3 or 4 times the number of steps in one epoch.
        step_size_down (int, optional): Number of training steps used to decrease the learning rate in a cycle.
            If not specified, it defaults to the value of step_size_up. Default: None
        mode (str, optional): One of 'triangular', 'triangular2' or 'exp_range'. Ignored if scale_fn is
            specified. Default: 'triangular'
        exp_gamma (float): Constant in the 'exp_range' scaling function: exp_gamma**iterations. Used only when
            mode = 'exp_range'. Default: 1.0
        scale_fn (function, optional): A custom scaling function that replaces the three built-in policies. It
            must take a single argument and satisfy 0 <= scale_fn(x) <= 1 for all x >= 0. If specified, 'mode'
            is ignored. Default: None
        scale_mode (str, optional): One of 'cycle' or 'iterations'. Defines whether scale_fn is evaluated on the
            cycle number or on cycle iterations (total iterations since the start of training). Default: 'cycle'
        last_epoch (int, optional): The index of the last epoch. Can be set to restart training. Default: -1,
            which means the initial learning rate.
        by_epoch (bool): learning rate decays by epoch when by_epoch is True, else by iter
        verbose (bool, optional): If True, prints a message to stdout for each update. Defaults to False
    """

    def __init__(self,
                 epochs,
                 step_each_epoch,
                 base_learning_rate,
                 max_learning_rate,
                 warmup_epoch,
                 warmup_start_lr,
                 step_size_up,
                 step_size_down=None,
                 mode='triangular',
                 exp_gamma=1.0,
                 scale_fn=None,
                 scale_mode='cycle',
                 by_epoch=False,
                 last_epoch=-1,
                 verbose=False):
        super(Cyclic, self).__init__(epochs, step_each_epoch,
                                     base_learning_rate, warmup_epoch,
                                     warmup_start_lr, last_epoch, by_epoch,
                                     verbose)
        self.base_learning_rate = base_learning_rate
        self.max_learning_rate = max_learning_rate
        self.step_size_up = step_size_up
        self.step_size_down = step_size_down
        self.mode = mode
        self.exp_gamma = exp_gamma
        self.scale_fn = scale_fn
        self.scale_mode = scale_mode

    def __call__(self):
        learning_rate = lr.CyclicLR(
            base_learning_rate=self.base_learning_rate,
            max_learning_rate=self.max_learning_rate,
            step_size_up=self.step_size_up,
            step_size_down=self.step_size_down,
            mode=self.mode,
            exp_gamma=self.exp_gamma,
            scale_fn=self.scale_fn,
            scale_mode=self.scale_mode,
            last_epoch=self.last_epoch,
            verbose=self.verbose)

        if self.warmup_steps > 0:
            learning_rate = self.linear_warmup(learning_rate)

        setattr(learning_rate, "by_epoch", self.by_epoch)
        return learning_rate


class Step(LRBase):
    """Step learning rate decay
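For reference, the Cyclic wrapper above simply forwards its arguments to Paddle's built-in cyclic scheduler. The short sketch below is not part of this commit; it assumes a Paddle release that provides paddle.optimizer.lr.CyclicLR, and the numeric values are illustrative only. With mode='triangular', the learning rate rises linearly from base_learning_rate to max_learning_rate over step_size_up steps and then falls back over step_size_down steps.

import paddle

# Stand-alone cyclic schedule: 4 iterations up, 4 iterations down per cycle.
sched = paddle.optimizer.lr.CyclicLR(
    base_learning_rate=0.001,  # lower boundary of the cycle
    max_learning_rate=0.003,   # upper boundary of the cycle
    step_size_up=4,
    step_size_down=4,
    mode='triangular')

for step in range(9):
    print(step, sched.get_lr())  # climbs 0.001 -> 0.003, then back down
    sched.step()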
@@ -421,6 +502,7 @@ class ReduceOnPlateau(LRBase):
        last_epoch (int, optional): last epoch. Defaults to -1.
        by_epoch (bool, optional): learning rate decays by epoch when by_epoch is True, else by iter. Defaults to False.
    """

    def __init__(self,
                 epochs,
                 step_each_epoch,
@@ -488,6 +570,7 @@ class CosineFixmatch(LRBase):
        last_epoch (int, optional): last epoch. Defaults to -1.
        by_epoch (bool, optional): learning rate decays by epoch when by_epoch is True, else by iter. Defaults to False.
    """

    def __init__(self,
                 epochs,
                 step_each_epoch,
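As a usage sketch (again not part of the commit), the new class can be instantiated directly; the import path assumes a PaddleClas checkout where the ppcls package is importable, and the hyperparameter values are illustrative only. Because warmup is enabled here, __call__ wraps the cyclic schedule in the linear warmup provided by LRBase.

from ppcls.optimizer.learning_rate import Cyclic

# 5 warmup epochs, then a triangular cycle of 2000 + 2000 iterations.
lr_builder = Cyclic(
    epochs=100,
    step_each_epoch=500,
    base_learning_rate=0.001,
    max_learning_rate=0.004,  # base is 1/4 of max, as the paper suggests
    warmup_epoch=5,
    warmup_start_lr=0.0,
    step_size_up=2000)

scheduler = lr_builder()  # a warmup-wrapped CyclicLR, since warmup_steps > 0
print(getattr(scheduler, "by_epoch"))  # False: this schedule steps per iteration

In practice PaddleClas normally builds schedulers like this from the Optimizer.lr section of a training config rather than by direct instantiation.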