PaddleOCR commit 107a316f (unverified)

add onecycle (#5252)

Authored on Jan 16, 2022 by bupt906; committed via GitHub on Jan 16, 2022.
Parent commit: 8bae1e40
2 changed files with 164 additions and 1 deletion:

ppocr/optimizer/learning_rate.py    +51   -1
ppocr/optimizer/lr_scheduler.py     +113  -0
ppocr/optimizer/learning_rate.py

@@ -18,7 +18,7 @@ from __future__ import print_function
 from __future__ import unicode_literals
 
 from paddle.optimizer import lr
-from .lr_scheduler import CyclicalCosineDecay
+from .lr_scheduler import CyclicalCosineDecay, OneCycleDecay
 
 
 class Linear(object):
@@ -226,3 +226,53 @@ class CyclicalCosine(object):
                 end_lr=self.learning_rate,
                 last_epoch=self.last_epoch)
         return learning_rate
+
+
+class OneCycle(object):
+    """
+    One Cycle learning rate decay
+    Args:
+        max_lr(float): Upper learning rate boundaries
+        epochs(int): total training epochs
+        step_each_epoch(int): steps each epoch
+        anneal_strategy(str): {'cos', 'linear'} Specifies the annealing strategy: "cos" for cosine annealing,
+            "linear" for linear annealing. Default: 'cos'
+        three_phase(bool): If True, use a third phase of the schedule to annihilate the learning rate according
+            to 'final_div_factor' instead of modifying the second phase (the first two phases will be
+            symmetrical about the step indicated by 'pct_start').
+        last_epoch(int, optional): The index of last epoch. Can be set to restart training. Default: -1,
+            means initial learning rate.
+    """
+
+    def __init__(self,
+                 max_lr,
+                 epochs,
+                 step_each_epoch,
+                 anneal_strategy='cos',
+                 three_phase=False,
+                 warmup_epoch=0,
+                 last_epoch=-1,
+                 **kwargs):
+        super(OneCycle, self).__init__()
+        self.max_lr = max_lr
+        self.epochs = epochs
+        self.steps_per_epoch = step_each_epoch
+        self.anneal_strategy = anneal_strategy
+        self.three_phase = three_phase
+        self.last_epoch = last_epoch
+        self.warmup_epoch = round(warmup_epoch * step_each_epoch)
+
+    def __call__(self):
+        learning_rate = OneCycleDecay(
+            max_lr=self.max_lr,
+            epochs=self.epochs,
+            steps_per_epoch=self.steps_per_epoch,
+            anneal_strategy=self.anneal_strategy,
+            three_phase=self.three_phase,
+            last_epoch=self.last_epoch)
+        if self.warmup_epoch > 0:
+            learning_rate = lr.LinearWarmup(
+                learning_rate=learning_rate,
+                warmup_steps=self.warmup_epoch,
+                start_lr=0.0,
+                end_lr=self.max_lr,
+                last_epoch=self.last_epoch)
+        return learning_rate
\ No newline at end of file
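For context, the OneCycle wrapper added above follows the same pattern as the other schedulers in ppocr/optimizer/learning_rate.py: it is normally built from the Optimizer.lr section of a training config and then called once to produce the paddle scheduler. Below is a minimal sketch of instantiating it directly; the hyperparameter values (max_lr, epochs, step_each_epoch, warmup_epoch) are illustrative assumptions, not values taken from this commit.

# Illustrative sketch only: parameter values are assumptions for demonstration,
# not part of commit 107a316f.
from ppocr.optimizer.learning_rate import OneCycle

lr_builder = OneCycle(
    max_lr=0.001,            # peak learning rate of the cycle (assumed value)
    epochs=50,               # total training epochs (assumed value)
    step_each_epoch=200,     # iterations per epoch (assumed value)
    anneal_strategy='cos',   # 'cos' or 'linear'
    three_phase=False,
    warmup_epoch=2)          # converted internally to round(2 * 200) = 400 warmup steps

scheduler = lr_builder()     # an OneCycleDecay, wrapped in lr.LinearWarmup because warmup_epoch > 0
# The returned scheduler can then be passed to a paddle optimizer, e.g.
# paddle.optimizer.Adam(learning_rate=scheduler, parameters=model.parameters())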
ppocr/optimizer/lr_scheduler.py

@@ -47,3 +47,116 @@ class CyclicalCosineDecay(LRScheduler):
         lr = self.eta_min + 0.5 * (self.base_lr - self.eta_min) * \
             (1 + math.cos(math.pi * reletive_epoch / self.cycle))
         return lr
+
+
+class OneCycleDecay(LRScheduler):
+    """
+    One Cycle learning rate decay
+    A learning rate which can be referred in https://arxiv.org/abs/1708.07120
+    Code refered in https://pytorch.org/docs/stable/_modules/torch/optim/lr_scheduler.html#OneCycleLR
+    """
+
+    def __init__(self,
+                 max_lr,
+                 epochs=None,
+                 steps_per_epoch=None,
+                 pct_start=0.3,
+                 anneal_strategy='cos',
+                 div_factor=25.,
+                 final_div_factor=1e4,
+                 three_phase=False,
+                 last_epoch=-1,
+                 verbose=False):
+
+        # Validate total_steps
+        if epochs <= 0 or not isinstance(epochs, int):
+            raise ValueError(
+                "Expected positive integer epochs, but got {}".format(epochs))
+        if steps_per_epoch <= 0 or not isinstance(steps_per_epoch, int):
+            raise ValueError(
+                "Expected positive integer steps_per_epoch, but got {}".format(
+                    steps_per_epoch))
+        self.total_steps = epochs * steps_per_epoch
+
+        self.max_lr = max_lr
+        self.initial_lr = self.max_lr / div_factor
+        self.min_lr = self.initial_lr / final_div_factor
+
+        if three_phase:
+            self._schedule_phases = [
+                {
+                    'end_step': float(pct_start * self.total_steps) - 1,
+                    'start_lr': self.initial_lr,
+                    'end_lr': self.max_lr,
+                },
+                {
+                    'end_step': float(2 * pct_start * self.total_steps) - 2,
+                    'start_lr': self.max_lr,
+                    'end_lr': self.initial_lr,
+                },
+                {
+                    'end_step': self.total_steps - 1,
+                    'start_lr': self.initial_lr,
+                    'end_lr': self.min_lr,
+                },
+            ]
+        else:
+            self._schedule_phases = [
+                {
+                    'end_step': float(pct_start * self.total_steps) - 1,
+                    'start_lr': self.initial_lr,
+                    'end_lr': self.max_lr,
+                },
+                {
+                    'end_step': self.total_steps - 1,
+                    'start_lr': self.max_lr,
+                    'end_lr': self.min_lr,
+                },
+            ]
+
+        # Validate pct_start
+        if pct_start < 0 or pct_start > 1 or not isinstance(pct_start, float):
+            raise ValueError(
+                "Expected float between 0 and 1 pct_start, but got {}".format(
+                    pct_start))
+
+        # Validate anneal_strategy
+        if anneal_strategy not in ['cos', 'linear']:
+            raise ValueError(
+                "anneal_strategy must by one of 'cos' or 'linear', instead got {}".
+                format(anneal_strategy))
+        elif anneal_strategy == 'cos':
+            self.anneal_func = self._annealing_cos
+        elif anneal_strategy == 'linear':
+            self.anneal_func = self._annealing_linear
+
+        super(OneCycleDecay, self).__init__(max_lr, last_epoch, verbose)
+
+    def _annealing_cos(self, start, end, pct):
+        "Cosine anneal from `start` to `end` as pct goes from 0.0 to 1.0."
+        cos_out = math.cos(math.pi * pct) + 1
+        return end + (start - end) / 2.0 * cos_out
+
+    def _annealing_linear(self, start, end, pct):
+        "Linearly anneal from `start` to `end` as pct goes from 0.0 to 1.0."
+        return (end - start) * pct + start
+
+    def get_lr(self):
+        computed_lr = 0.0
+        step_num = self.last_epoch
+
+        if step_num > self.total_steps:
+            raise ValueError(
+                "Tried to step {} times. The specified number of total steps is {}"
+                .format(step_num + 1, self.total_steps))
+        start_step = 0
+        for i, phase in enumerate(self._schedule_phases):
+            end_step = phase['end_step']
+            if step_num <= end_step or i == len(self._schedule_phases) - 1:
+                pct = (step_num - start_step) / (end_step - start_step)
+                computed_lr = self.anneal_func(phase['start_lr'],
+                                               phase['end_lr'], pct)
+                break
+            start_step = phase['end_step']
+
+        return computed_lr
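To make the phase logic above concrete, here is a small standalone sketch (no paddle dependency) that re-implements the default two-phase schedule and prints the learning rate at a few steps. The values total_steps=100, pct_start=0.3 and max_lr=0.1 are assumptions chosen for illustration; div_factor=25. and final_div_factor=1e4 match the OneCycleDecay defaults.

# Standalone sketch of the two-phase (three_phase=False) one-cycle schedule.
# All numeric values are illustrative assumptions, not taken from the commit.
import math

max_lr = 0.1
div_factor = 25.
final_div_factor = 1e4
total_steps = 100
pct_start = 0.3

initial_lr = max_lr / div_factor        # learning rate at step 0
min_lr = initial_lr / final_div_factor  # learning rate at the final step

phases = [
    {'end_step': float(pct_start * total_steps) - 1,
     'start_lr': initial_lr, 'end_lr': max_lr},
    {'end_step': total_steps - 1,
     'start_lr': max_lr, 'end_lr': min_lr},
]

def annealing_cos(start, end, pct):
    # Same cosine interpolation as OneCycleDecay._annealing_cos
    return end + (start - end) / 2.0 * (math.cos(math.pi * pct) + 1)

def lr_at(step_num):
    # Mirrors OneCycleDecay.get_lr: find the phase containing this step
    # and interpolate between its start_lr and end_lr.
    start_step = 0
    for i, phase in enumerate(phases):
        end_step = phase['end_step']
        if step_num <= end_step or i == len(phases) - 1:
            pct = (step_num - start_step) / (end_step - start_step)
            return annealing_cos(phase['start_lr'], phase['end_lr'], pct)
        start_step = phase['end_step']

for step in (0, 14, 29, 60, 99):
    print(f"step {step:3d}: lr = {lr_at(step):.6g}")
# The rate climbs from initial_lr to max_lr over the first ~30% of steps,
# then anneals back down towards min_lr for the remainder of training.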