diff --git a/ppgan/solver/lr_scheduler.py b/ppgan/solver/lr_scheduler.py
index 3c17e3da0f06848ec446a5fbf141762eeae0b918..8e3f87baa3a76d31bfdd06c1646131d47dedeffa 100644
--- a/ppgan/solver/lr_scheduler.py
+++ b/ppgan/solver/lr_scheduler.py
@@ -12,25 +12,8 @@ def build_lr_scheduler(cfg):
                 0, epoch + 1 - cfg.start_epoch) / float(cfg.decay_epochs + 1)
             return lr_l
 
-        scheduler = paddle.optimizer.lr_scheduler.LambdaLR(
-            cfg.learning_rate, lr_lambda=lambda_rule)
+        scheduler = paddle.optimizer.lr.LambdaLR(cfg.learning_rate,
+                                                 lr_lambda=lambda_rule)
         return scheduler
     else:
         raise NotImplementedError
-
-
-# paddle.optimizer.lr_scheduler
-class LinearDecay(paddle.optimizer.lr_scheduler._LRScheduler):
-    def __init__(self, learning_rate, step_per_epoch, start_epoch,
-                 decay_epochs):
-        super(LinearDecay, self).__init__()
-        self.learning_rate = learning_rate
-        self.start_epoch = start_epoch
-        self.decay_epochs = decay_epochs
-        self.step_per_epoch = step_per_epoch
-
-    def step(self):
-        cur_epoch = int(self.step_num // self.step_per_epoch)
-        decay_rate = 1.0 - max(
-            0, cur_epoch + 1 - self.start_epoch) / float(self.decay_epochs + 1)
-        return self.create_lr_var(decay_rate * self.learning_rate)
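
Note: the `lambda_rule` kept by this patch leaves the learning-rate multiplier at 1.0 until `cfg.start_epoch` and then decays it linearly toward zero over `cfg.decay_epochs`. The standalone sketch below only reproduces that multiplier math; the config values (`start_epoch=100`, `decay_epochs=100`, `base_lr=2e-4`) are illustrative assumptions, not values from this patch, and no Paddle API is exercised.

# Sketch of the linear decay rule used by build_lr_scheduler.
# All numeric values here are assumed for illustration.
start_epoch = 100
decay_epochs = 100
base_lr = 2e-4

def lambda_rule(epoch):
    # Multiplier stays at 1.0 until start_epoch, then falls linearly toward 0.
    return 1.0 - max(0, epoch + 1 - start_epoch) / float(decay_epochs + 1)

for epoch in (0, 99, 100, 150, 199):
    print(epoch, base_lr * lambda_rule(epoch))

With these assumed values the effective learning rate stays at 2e-4 through epoch 99 and reaches roughly 2e-6 at epoch 199, which matches the behaviour the removed `LinearDecay` class implemented by hand.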