From c48d6bfb1e664818e79920f80217d2c9f7851e38 Mon Sep 17 00:00:00 2001 From: Tingquan Gao <35441050@qq.com> Date: Thu, 24 Aug 2023 14:49:36 +0800 Subject: [PATCH] support to specify priority about steps and epochs in LinearWarmup (#8564) for uapi --- ppdet/optimizer/optimizer.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/ppdet/optimizer/optimizer.py b/ppdet/optimizer/optimizer.py index 2d0714078..3c528fcfe 100644 --- a/ppdet/optimizer/optimizer.py +++ b/ppdet/optimizer/optimizer.py @@ -165,17 +165,20 @@ class LinearWarmup(object): of `epochs` is higher than `steps`. Default: None. """ - def __init__(self, steps=500, start_factor=1. / 3, epochs=None): + def __init__(self, steps=500, start_factor=1. / 3, epochs=None, epochs_first=True): super(LinearWarmup, self).__init__() self.steps = steps self.start_factor = start_factor self.epochs = epochs + self.epochs_first = epochs_first def __call__(self, base_lr, step_per_epoch): boundary = [] value = [] - warmup_steps = self.epochs * step_per_epoch \ - if self.epochs is not None else self.steps + if self.epochs_first and self.epochs is not None: + warmup_steps = self.epochs * step_per_epoch + else: + warmup_steps = self.steps warmup_steps = max(warmup_steps, 1) for i in range(warmup_steps + 1): if warmup_steps > 0: -- GitLab