From f91bc7ba0be1f0425c352a122953f8acbebc6757 Mon Sep 17 00:00:00 2001 From: Tingquan Gao Date: Wed, 22 Sep 2021 14:35:37 +0800 Subject: [PATCH] perf: add parameter validation (#1249) When using warm up, the total epoch num must be greater than warm up epoch num. Otherwise, there will be raising warning and warm up epoch num will be set to total epoch num. --- ppcls/optimizer/learning_rate.py | 38 +++++++++++++++++++++++++------- 1 file changed, 30 insertions(+), 8 deletions(-) diff --git a/ppcls/optimizer/learning_rate.py b/ppcls/optimizer/learning_rate.py index ea938b12..b59387dd 100644 --- a/ppcls/optimizer/learning_rate.py +++ b/ppcls/optimizer/learning_rate.py @@ -11,12 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + from __future__ import (absolute_import, division, print_function, unicode_literals) from paddle.optimizer import lr from paddle.optimizer.lr import LRScheduler +from ppcls.utils import logger + class Linear(object): """ @@ -41,7 +44,11 @@ class Linear(object): warmup_start_lr=0.0, last_epoch=-1, **kwargs): - super(Linear, self).__init__() + super().__init__() + if warmup_epoch >= epochs: + msg = f"When using warm up, the value of \"Global.epochs\" must be greater than value of \"Optimizer.lr.warmup_epoch\". The value of \"Optimizer.lr.warmup_epoch\" has been set to {epochs}." + logger.warning(msg) + warmup_epoch = epochs self.learning_rate = learning_rate self.steps = (epochs - warmup_epoch) * step_each_epoch self.end_lr = end_lr @@ -56,7 +63,8 @@ class Linear(object): decay_steps=self.steps, end_lr=self.end_lr, power=self.power, - last_epoch=self.last_epoch) + last_epoch=self. + last_epoch) if self.steps > 0 else self.learning_rate if self.warmup_steps > 0: learning_rate = lr.LinearWarmup( learning_rate=learning_rate, @@ -90,7 +98,11 @@ class Cosine(object): warmup_start_lr=0.0, last_epoch=-1, **kwargs): - super(Cosine, self).__init__() + super().__init__() + if warmup_epoch >= epochs: + msg = f"When using warm up, the value of \"Global.epochs\" must be greater than value of \"Optimizer.lr.warmup_epoch\". The value of \"Optimizer.lr.warmup_epoch\" has been set to {epochs}." + logger.warning(msg) + warmup_epoch = epochs self.learning_rate = learning_rate self.T_max = (epochs - warmup_epoch) * step_each_epoch self.eta_min = eta_min @@ -103,7 +115,8 @@ class Cosine(object): learning_rate=self.learning_rate, T_max=self.T_max, eta_min=self.eta_min, - last_epoch=self.last_epoch) + last_epoch=self. + last_epoch) if self.T_max > 0 else self.learning_rate if self.warmup_steps > 0: learning_rate = lr.LinearWarmup( learning_rate=learning_rate, @@ -132,12 +145,17 @@ class Step(object): learning_rate, step_size, step_each_epoch, + epochs, gamma, warmup_epoch=0, warmup_start_lr=0.0, last_epoch=-1, **kwargs): - super(Step, self).__init__() + super().__init__() + if warmup_epoch >= epochs: + msg = f"When using warm up, the value of \"Global.epochs\" must be greater than value of \"Optimizer.lr.warmup_epoch\". The value of \"Optimizer.lr.warmup_epoch\" has been set to {epochs}." + logger.warning(msg) + warmup_epoch = epochs self.step_size = step_each_epoch * step_size self.learning_rate = learning_rate self.gamma = gamma @@ -177,11 +195,16 @@ class Piecewise(object): step_each_epoch, decay_epochs, values, + epochs, warmup_epoch=0, warmup_start_lr=0.0, last_epoch=-1, **kwargs): - super(Piecewise, self).__init__() + super().__init__() + if warmup_epoch >= epochs: + msg = f"When using warm up, the value of \"Global.epochs\" must be greater than value of \"Optimizer.lr.warmup_epoch\". The value of \"Optimizer.lr.warmup_epoch\" has been set to {epochs}." + logger.warning(msg) + warmup_epoch = epochs self.boundaries = [step_each_epoch * e for e in decay_epochs] self.values = values self.last_epoch = last_epoch @@ -294,8 +317,7 @@ class MultiStepDecay(LRScheduler): raise ValueError('gamma should be < 1.0.') self.milestones = [x * step_each_epoch for x in milestones] self.gamma = gamma - super(MultiStepDecay, self).__init__(learning_rate, last_epoch, - verbose) + super().__init__(learning_rate, last_epoch, verbose) def get_lr(self): for i in range(len(self.milestones)): -- GitLab