diff --git a/ppcls/optimizer/learning_rate.py b/ppcls/optimizer/learning_rate.py
index 0227ae43e0f5c99894dcaeb1b51425fbdc9a8c82..87c08e747b52f91f555a5787390d5d55c617278f 100644
--- a/ppcls/optimizer/learning_rate.py
+++ b/ppcls/optimizer/learning_rate.py
@@ -112,35 +112,19 @@ class CosineWarmup(object):
         self.lr = lr
         self.step_each_epoch = step_each_epoch
         self.epochs = epochs
-        self.warmup_epoch = fluid.layers.fill_constant(
-            shape=[1],
-            value=float(warmup_epoch),
-            dtype='float32',
-            force_cpu=True)
+        self.warmup_epoch = warmup_epoch
 
     def __call__(self):
-        global_step = _decay_step_counter()
-        learning_rate = fluid.layers.tensor.create_global_var(
-            shape=[1],
-            value=0.0,
-            dtype='float32',
-            persistable=True,
-            name="learning_rate")
-        epoch = ops.floor(global_step / self.step_each_epoch)
-        with fluid.layers.control_flow.Switch() as switch:
-            with switch.case(epoch < self.warmup_epoch):
-                decayed_lr = self.lr * \
-                    (global_step / (self.step_each_epoch * self.warmup_epoch))
-                fluid.layers.tensor.assign(
-                    input=decayed_lr, output=learning_rate)
-            with switch.default():
-                current_step = global_step - self.warmup_epoch * self.step_each_epoch
-                total_step = (
-                    self.epochs - self.warmup_epoch) * self.step_each_epoch
-                decayed_lr = self.lr * \
-                    (ops.cos(current_step * math.pi / total_step) + 1) / 2
-                fluid.layers.tensor.assign(
-                    input=decayed_lr, output=learning_rate)
+        learning_rate = fluid.layers.cosine_decay(
+            learning_rate=self.lr,
+            step_each_epoch=self.step_each_epoch,
+            epochs=self.epochs)
+
+        learning_rate = fluid.layers.linear_lr_warmup(
+            learning_rate,
+            warmup_steps=self.warmup_epoch * self.step_each_epoch,
+            start_lr=0.0,
+            end_lr=self.lr)
 
         return learning_rate
 
@@ -169,37 +153,22 @@ class ExponentialWarmup(object):
         super(ExponentialWarmup, self).__init__()
         self.lr = lr
         self.step_each_epoch = step_each_epoch
-        self.decay_epochs = decay_epochs * self.step_each_epoch
+        self.decay_epochs = decay_epochs
         self.decay_rate = decay_rate
-        self.warmup_epoch = fluid.layers.fill_constant(
-            shape=[1],
-            value=float(warmup_epoch),
-            dtype='float32',
-            force_cpu=True)
+        self.warmup_epoch = warmup_epoch
 
     def __call__(self):
-        global_step = _decay_step_counter()
-        learning_rate = fluid.layers.tensor.create_global_var(
-            shape=[1],
-            value=0.0,
-            dtype='float32',
-            persistable=True,
-            name="learning_rate")
-
-        epoch = ops.floor(global_step / self.step_each_epoch)
-        with fluid.layers.control_flow.Switch() as switch:
-            with switch.case(epoch < self.warmup_epoch):
-                decayed_lr = self.lr * \
-                    (global_step / (self.step_each_epoch * self.warmup_epoch))
-                fluid.layers.tensor.assign(
-                    input=decayed_lr, output=learning_rate)
-            with switch.default():
-                rest_step = global_step - self.warmup_epoch * self.step_each_epoch
-                div_res = ops.floor(rest_step / self.decay_epochs)
-
-                decayed_lr = self.lr * (self.decay_rate**div_res)
-                fluid.layers.tensor.assign(
-                    input=decayed_lr, output=learning_rate)
+        learning_rate = fluid.layers.exponential_decay(
+            learning_rate=self.lr,
+            decay_steps=self.decay_epochs * self.step_each_epoch,
+            decay_rate=self.decay_rate,
+            staircase=False)
+
+        learning_rate = fluid.layers.linear_lr_warmup(
+            learning_rate,
+            warmup_steps=self.warmup_epoch * self.step_each_epoch,
+            start_lr=0.0,
+            end_lr=self.lr)
 
         return learning_rate
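
For reference, a minimal standalone sketch of the composition this patch relies on: fluid.layers.cosine_decay builds the base schedule and fluid.layers.linear_lr_warmup wraps it to ramp the learning rate linearly from 0 to the target value over the warmup steps, replacing the hand-rolled Switch/assign control flow. The sketch assumes the PaddlePaddle 1.x static-graph API; the hyperparameter values and the Momentum optimizer are illustrative only and not part of this patch.

import paddle.fluid as fluid

# Illustrative hyperparameters (not taken from the patch).
base_lr = 0.1
step_each_epoch = 100   # iterations per epoch
epochs = 120
warmup_epoch = 5

main_prog, startup_prog = fluid.Program(), fluid.Program()
with fluid.program_guard(main_prog, startup_prog):
    # Base schedule: cosine decay over the full training budget.
    learning_rate = fluid.layers.cosine_decay(
        learning_rate=base_lr,
        step_each_epoch=step_each_epoch,
        epochs=epochs)
    # Warmup wrapper: ramp linearly from 0 to base_lr during the first
    # warmup_epoch * step_each_epoch global steps, then follow the wrapped
    # cosine schedule.
    learning_rate = fluid.layers.linear_lr_warmup(
        learning_rate,
        warmup_steps=warmup_epoch * step_each_epoch,
        start_lr=0.0,
        end_lr=base_lr)
    optimizer = fluid.optimizer.Momentum(
        learning_rate=learning_rate, momentum=0.9)

ExponentialWarmup composes the same way, with fluid.layers.exponential_decay(decay_steps=decay_epochs * step_each_epoch, decay_rate=decay_rate, staircase=False) in place of the cosine schedule.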