From a676fa992d66ffdde5c340de2009ed6de4b5603e Mon Sep 17 00:00:00 2001 From: fuyi02 <39895596+fuyi02@users.noreply.github.com> Date: Wed, 13 Nov 2019 10:46:18 +0800 Subject: [PATCH] add warmup strategy (#86) --- pdseg/solver.py | 25 +++++++++++++++++++++++++ pdseg/utils/config.py | 4 ++++ 2 files changed, 29 insertions(+) diff --git a/pdseg/solver.py b/pdseg/solver.py index 8eea7400..48461181 100644 --- a/pdseg/solver.py +++ b/pdseg/solver.py @@ -34,6 +34,25 @@ class Solver(object): self.main_prog = main_prog self.start_prog = start_prog + def lr_warmup(self, learning_rate, warmup_steps, start_lr, end_lr): + linear_step = end_lr - start_lr + lr = fluid.layers.tensor.create_global_var( + shape=[1], + value=0.0, + dtype='float32', + persistable=True, + name="learning_rate_warmup") + + global_step = fluid.layers.learning_rate_scheduler._decay_step_counter() + + with fluid.layers.control_flow.Switch() as switch: + with switch.case(global_step < warmup_steps): + decayed_lr = start_lr + linear_step * (global_step / warmup_steps) + fluid.layers.tensor.assign(decayed_lr, lr) + with switch.default(): + fluid.layers.tensor.assign(learning_rate, lr) + return lr + def piecewise_decay(self): gamma = cfg.SOLVER.GAMMA bd = [self.step_per_epoch * e for e in cfg.SOLVER.DECAY_EPOCH] @@ -63,6 +82,12 @@ class Solver(object): raise Exception( "unsupport learning decay policy! only support poly,piecewise,cosine" ) + + if cfg.SOLVER.LR_WARMUP: + start_lr = 0 + end_lr = cfg.SOLVER.LR + warmup_steps = cfg.SOLVER.LR_WARMUP_STEPS + decayed_lr = self.lr_warmup(decayed_lr, warmup_steps, start_lr, end_lr) return decayed_lr def sgd_optimizer(self, lr_policy, loss): diff --git a/pdseg/utils/config.py b/pdseg/utils/config.py index f8bf7969..4cac671d 100644 --- a/pdseg/utils/config.py +++ b/pdseg/utils/config.py @@ -154,6 +154,10 @@ cfg.SOLVER.BEGIN_EPOCH = 1 cfg.SOLVER.NUM_EPOCHS = 30 # loss的选择,支持softmax_loss, bce_loss, dice_loss cfg.SOLVER.LOSS = ["softmax_loss"] +# 是否开启warmup学习策略 +cfg.SOLVER.LR_WARMUP = False +# warmup的迭代次数 +cfg.SOLVER.LR_WARMUP_STEPS = 2000 ########################## 测试配置 ########################################### # 测试模型路径 -- GitLab