Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleSeg
提交
a676fa99
P
PaddleSeg
项目概览
PaddlePaddle
/
PaddleSeg
通知
285
Star
8
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
53
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleSeg
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
53
Issue
53
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
a676fa99
编写于
11月 13, 2019
作者:
F
fuyi02
提交者:
wuzewu
11月 13, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add warmup strategy (#86)
上级
43f56a57
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
29 addition
and
0 deletion
+29
-0
pdseg/solver.py
pdseg/solver.py
+25
-0
pdseg/utils/config.py
pdseg/utils/config.py
+4
-0
未找到文件。
pdseg/solver.py
浏览文件 @
a676fa99
...
...
@@ -34,6 +34,25 @@ class Solver(object):
self
.
main_prog
=
main_prog
self
.
start_prog
=
start_prog
def
lr_warmup
(
self
,
learning_rate
,
warmup_steps
,
start_lr
,
end_lr
):
linear_step
=
end_lr
-
start_lr
lr
=
fluid
.
layers
.
tensor
.
create_global_var
(
shape
=
[
1
],
value
=
0.0
,
dtype
=
'float32'
,
persistable
=
True
,
name
=
"learning_rate_warmup"
)
global_step
=
fluid
.
layers
.
learning_rate_scheduler
.
_decay_step_counter
()
with
fluid
.
layers
.
control_flow
.
Switch
()
as
switch
:
with
switch
.
case
(
global_step
<
warmup_steps
):
decayed_lr
=
start_lr
+
linear_step
*
(
global_step
/
warmup_steps
)
fluid
.
layers
.
tensor
.
assign
(
decayed_lr
,
lr
)
with
switch
.
default
():
fluid
.
layers
.
tensor
.
assign
(
learning_rate
,
lr
)
return
lr
def
piecewise_decay
(
self
):
gamma
=
cfg
.
SOLVER
.
GAMMA
bd
=
[
self
.
step_per_epoch
*
e
for
e
in
cfg
.
SOLVER
.
DECAY_EPOCH
]
...
...
@@ -63,6 +82,12 @@ class Solver(object):
raise
Exception
(
"unsupport learning decay policy! only support poly,piecewise,cosine"
)
if
cfg
.
SOLVER
.
LR_WARMUP
:
start_lr
=
0
end_lr
=
cfg
.
SOLVER
.
LR
warmup_steps
=
cfg
.
SOLVER
.
LR_WARMUP_STEPS
decayed_lr
=
self
.
lr_warmup
(
decayed_lr
,
warmup_steps
,
start_lr
,
end_lr
)
return
decayed_lr
def
sgd_optimizer
(
self
,
lr_policy
,
loss
):
...
...
pdseg/utils/config.py
浏览文件 @
a676fa99
...
...
@@ -154,6 +154,10 @@ cfg.SOLVER.BEGIN_EPOCH = 1
cfg
.
SOLVER
.
NUM_EPOCHS
=
30
# loss的选择,支持softmax_loss, bce_loss, dice_loss
cfg
.
SOLVER
.
LOSS
=
[
"softmax_loss"
]
# 是否开启warmup学习策略
cfg
.
SOLVER
.
LR_WARMUP
=
False
# warmup的迭代次数
cfg
.
SOLVER
.
LR_WARMUP_STEPS
=
2000
########################## 测试配置 ###########################################
# 测试模型路径
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录