Unverified commit f91bc7ba authored by T Tingquan Gao, committed by GitHub

perf: add parameter validation (#1249)

When using warm up, the total epoch num must be greater than the warm up epoch num. Otherwise,
a warning is raised and the warm up epoch num is set to the total epoch num.
Parent 283f9fa9
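The same guard is added to each scheduler constructor in the diff below. A minimal standalone sketch of the clamping behaviour, using a plain Python logger and illustrative values (the function name and numbers here are not part of the commit):

import logging

logger = logging.getLogger(__name__)

def clamp_warmup_epoch(epochs, warmup_epoch):
    # Warm up must end before training ends; otherwise warn and clamp it
    # so the main schedule is left with zero or more decay steps.
    if warmup_epoch >= epochs:
        logger.warning(
            'When using warm up, the value of "Global.epochs" must be greater '
            'than the value of "Optimizer.lr.warmup_epoch". '
            '"Optimizer.lr.warmup_epoch" has been set to %d.', epochs)
        warmup_epoch = epochs
    return warmup_epoch

# Hypothetical config: 5 warm-up epochs requested, but only 3 epochs in total.
assert clamp_warmup_epoch(epochs=3, warmup_epoch=5) == 3  # warning is logged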
@@ -11,12 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import (absolute_import, division, print_function,
unicode_literals)
from paddle.optimizer import lr
from paddle.optimizer.lr import LRScheduler
from ppcls.utils import logger
class Linear(object):
"""
@@ -41,7 +44,11 @@ class Linear(object):
warmup_start_lr=0.0,
last_epoch=-1,
**kwargs):
super(Linear, self).__init__()
super().__init__()
if warmup_epoch >= epochs:
msg = f"When using warm up, the value of \"Global.epochs\" must be greater than value of \"Optimizer.lr.warmup_epoch\". The value of \"Optimizer.lr.warmup_epoch\" has been set to {epochs}."
logger.warning(msg)
warmup_epoch = epochs
self.learning_rate = learning_rate
self.steps = (epochs - warmup_epoch) * step_each_epoch
self.end_lr = end_lr
@@ -56,7 +63,8 @@
decay_steps=self.steps,
end_lr=self.end_lr,
power=self.power,
last_epoch=self.last_epoch)
last_epoch=self.last_epoch) if self.steps > 0 else self.learning_rate
if self.warmup_steps > 0:
learning_rate = lr.LinearWarmup(
learning_rate=learning_rate,
@@ -90,7 +98,11 @@ class Cosine(object):
warmup_start_lr=0.0,
last_epoch=-1,
**kwargs):
super(Cosine, self).__init__()
super().__init__()
if warmup_epoch >= epochs:
msg = f"When using warm up, the value of \"Global.epochs\" must be greater than value of \"Optimizer.lr.warmup_epoch\". The value of \"Optimizer.lr.warmup_epoch\" has been set to {epochs}."
logger.warning(msg)
warmup_epoch = epochs
self.learning_rate = learning_rate
self.T_max = (epochs - warmup_epoch) * step_each_epoch
self.eta_min = eta_min
@@ -103,7 +115,8 @@
learning_rate=self.learning_rate,
T_max=self.T_max,
eta_min=self.eta_min,
last_epoch=self.last_epoch)
last_epoch=self.last_epoch) if self.T_max > 0 else self.learning_rate
if self.warmup_steps > 0:
learning_rate = lr.LinearWarmup(
learning_rate=learning_rate,
@@ -132,12 +145,17 @@ class Step(object):
learning_rate,
step_size,
step_each_epoch,
epochs,
gamma,
warmup_epoch=0,
warmup_start_lr=0.0,
last_epoch=-1,
**kwargs):
super(Step, self).__init__()
super().__init__()
if warmup_epoch >= epochs:
msg = f"When using warm up, the value of \"Global.epochs\" must be greater than value of \"Optimizer.lr.warmup_epoch\". The value of \"Optimizer.lr.warmup_epoch\" has been set to {epochs}."
logger.warning(msg)
warmup_epoch = epochs
self.step_size = step_each_epoch * step_size
self.learning_rate = learning_rate
self.gamma = gamma
@@ -177,11 +195,16 @@ class Piecewise(object):
step_each_epoch,
decay_epochs,
values,
epochs,
warmup_epoch=0,
warmup_start_lr=0.0,
last_epoch=-1,
**kwargs):
super(Piecewise, self).__init__()
super().__init__()
if warmup_epoch >= epochs:
msg = f"When using warm up, the value of \"Global.epochs\" must be greater than value of \"Optimizer.lr.warmup_epoch\". The value of \"Optimizer.lr.warmup_epoch\" has been set to {epochs}."
logger.warning(msg)
warmup_epoch = epochs
self.boundaries = [step_each_epoch * e for e in decay_epochs]
self.values = values
self.last_epoch = last_epoch
@@ -294,8 +317,7 @@ class MultiStepDecay(LRScheduler):
raise ValueError('gamma should be < 1.0.')
self.milestones = [x * step_each_epoch for x in milestones]
self.gamma = gamma
super(MultiStepDecay, self).__init__(learning_rate, last_epoch,
verbose)
super().__init__(learning_rate, last_epoch, verbose)
def get_lr(self):
for i in range(len(self.milestones)):
......
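For context, a hedged usage sketch of the patched Cosine class, assuming the constructor signature implied by the hunks above (learning_rate, step_each_epoch, epochs, warmup_epoch) and that calling the instance builds the paddle scheduler, as elsewhere in this file:

from ppcls.optimizer.learning_rate import Cosine

# warmup_epoch (10) >= epochs (5): the new check logs the warning, clamps
# warmup_epoch to 5, so T_max becomes 0 and the bare float learning rate is
# handed straight to lr.LinearWarmup instead of a CosineAnnealingDecay.
lr_builder = Cosine(
    learning_rate=0.1,
    step_each_epoch=100,
    epochs=5,
    warmup_epoch=10)
scheduler = lr_builder()  # a paddle.optimizer.lr.LinearWarmup instance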