diff --git a/ppdet/optimizer/optimizer.py b/ppdet/optimizer/optimizer.py index 2d0714078eec14dadd57f5689ae6a41039562202..3c528fcfe85888d7126d75220a07fbfd94480ad3 100644 --- a/ppdet/optimizer/optimizer.py +++ b/ppdet/optimizer/optimizer.py @@ -165,17 +165,20 @@ class LinearWarmup(object): of `epochs` is higher than `steps`. Default: None. """ - def __init__(self, steps=500, start_factor=1. / 3, epochs=None): + def __init__(self, steps=500, start_factor=1. / 3, epochs=None, epochs_first=True): super(LinearWarmup, self).__init__() self.steps = steps self.start_factor = start_factor self.epochs = epochs + self.epochs_first = epochs_first def __call__(self, base_lr, step_per_epoch): boundary = [] value = [] - warmup_steps = self.epochs * step_per_epoch \ - if self.epochs is not None else self.steps + if self.epochs_first and self.epochs is not None: + warmup_steps = self.epochs * step_per_epoch + else: + warmup_steps = self.steps warmup_steps = max(warmup_steps, 1) for i in range(warmup_steps + 1): if warmup_steps > 0: