提交 a676fa99 编写于 作者: F fuyi02 提交者: wuzewu

add warmup strategy (#86)

上级 43f56a57
...@@ -34,6 +34,25 @@ class Solver(object): ...@@ -34,6 +34,25 @@ class Solver(object):
self.main_prog = main_prog self.main_prog = main_prog
self.start_prog = start_prog self.start_prog = start_prog
def lr_warmup(self, learning_rate, warmup_steps, start_lr, end_lr):
linear_step = end_lr - start_lr
lr = fluid.layers.tensor.create_global_var(
shape=[1],
value=0.0,
dtype='float32',
persistable=True,
name="learning_rate_warmup")
global_step = fluid.layers.learning_rate_scheduler._decay_step_counter()
with fluid.layers.control_flow.Switch() as switch:
with switch.case(global_step < warmup_steps):
decayed_lr = start_lr + linear_step * (global_step / warmup_steps)
fluid.layers.tensor.assign(decayed_lr, lr)
with switch.default():
fluid.layers.tensor.assign(learning_rate, lr)
return lr
def piecewise_decay(self): def piecewise_decay(self):
gamma = cfg.SOLVER.GAMMA gamma = cfg.SOLVER.GAMMA
bd = [self.step_per_epoch * e for e in cfg.SOLVER.DECAY_EPOCH] bd = [self.step_per_epoch * e for e in cfg.SOLVER.DECAY_EPOCH]
...@@ -63,6 +82,12 @@ class Solver(object): ...@@ -63,6 +82,12 @@ class Solver(object):
raise Exception( raise Exception(
"unsupport learning decay policy! only support poly,piecewise,cosine" "unsupport learning decay policy! only support poly,piecewise,cosine"
) )
if cfg.SOLVER.LR_WARMUP:
start_lr = 0
end_lr = cfg.SOLVER.LR
warmup_steps = cfg.SOLVER.LR_WARMUP_STEPS
decayed_lr = self.lr_warmup(decayed_lr, warmup_steps, start_lr, end_lr)
return decayed_lr return decayed_lr
def sgd_optimizer(self, lr_policy, loss): def sgd_optimizer(self, lr_policy, loss):
......
...@@ -154,6 +154,10 @@ cfg.SOLVER.BEGIN_EPOCH = 1 ...@@ -154,6 +154,10 @@ cfg.SOLVER.BEGIN_EPOCH = 1
cfg.SOLVER.NUM_EPOCHS = 30 cfg.SOLVER.NUM_EPOCHS = 30
# loss的选择,支持softmax_loss, bce_loss, dice_loss # loss的选择,支持softmax_loss, bce_loss, dice_loss
cfg.SOLVER.LOSS = ["softmax_loss"] cfg.SOLVER.LOSS = ["softmax_loss"]
# 是否开启warmup学习策略
cfg.SOLVER.LR_WARMUP = False
# warmup的迭代次数
cfg.SOLVER.LR_WARMUP_STEPS = 2000
########################## 测试配置 ########################################### ########################## 测试配置 ###########################################
# 测试模型路径 # 测试模型路径
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册