diff --git a/ppcls/optimizer/learning_rate.py b/ppcls/optimizer/learning_rate.py index 11a5c915243f122f1771194f5c2bcb6271dfb8f3..c88c5ea7625f014311c05a729c41619c5b26156f 100644 --- a/ppcls/optimizer/learning_rate.py +++ b/ppcls/optimizer/learning_rate.py @@ -116,9 +116,9 @@ class Constant(LRBase): last_epoch=-1, by_epoch=False, **kwargs): - super(Linear, self).__init__(epochs, step_each_epoch, learning_rate, - warmup_epoch, warmup_start_lr, last_epoch, - by_epoch) + super(Constant, self).__init__(epochs, step_each_epoch, learning_rate, + warmup_epoch, warmup_start_lr, + last_epoch, by_epoch) def __call__(self): learning_rate = lr.LRScheduler( @@ -220,7 +220,7 @@ class Cosine(LRBase): last_epoch=-1, by_epoch=False, **kwargs): - super(Linear, self).__init__(epochs, step_each_epoch, learning_rate, + super(Cosine, self).__init__(epochs, step_each_epoch, learning_rate, warmup_epoch, warmup_start_lr, last_epoch, by_epoch) self.T_max = (self.epochs - self.warmup_epoch) * self.step_each_epoch @@ -269,9 +269,9 @@ class Step(LRBase): last_epoch=-1, by_epoch=False, **kwargs): - super(Linear, self).__init__(epochs, step_each_epoch, learning_rate, - warmup_epoch, warmup_start_lr, last_epoch, - by_epoch) + super(Step, self).__init__(epochs, step_each_epoch, learning_rate, + warmup_epoch, warmup_start_lr, last_epoch, + by_epoch) self.step_size = step_size * step_each_epoch self.gamma = gamma if self.by_epoch: @@ -315,7 +315,7 @@ class Piecewise(LRBase): last_epoch=-1, by_epoch=False, **kwargs): - super(Linear, + super(Piecewise, self).__init__(epochs, step_each_epoch, values[0], warmup_epoch, warmup_start_lr, last_epoch, by_epoch) self.values = values @@ -362,9 +362,9 @@ class MultiStepDecay(LRBase): last_epoch=-1, by_epoch=False, **kwargs): - super(Linear, self).__init__(epochs, step_each_epoch, learning_rate, - warmup_epoch, warmup_start_lr, last_epoch, - by_epoch) + super(MultiStepDecay, self).__init__( + epochs, step_each_epoch, learning_rate, warmup_epoch, + warmup_start_lr, last_epoch, by_epoch) self.milestones = [x * step_each_epoch for x in milestones] self.gamma = gamma if self.by_epoch: