Commit 834a4071 authored by leilei_snow

Add the function of checking nan or inf

Parent 46acf238
@@ -15,6 +15,7 @@
 """Check parameters."""
 import re
 import inspect
+import math
 from enum import Enum
 from functools import reduce, wraps
 from itertools import repeat
@@ -318,6 +319,16 @@ class Validator:
         raise ValueError(f'{msg_prefix} type of `{arg_name}` should be one of {type_names},'
                          f' but got {get_typename(arg_type)}.')

+    @staticmethod
+    def check_float_legal_value(arg_name, arg_value, prim_name):
+        """Check whether `arg_value` is a legal float value, i.e. neither inf nor nan."""
+        msg_prefix = f'For \'{prim_name}\' the' if prim_name else "The"
+        if isinstance(arg_value, float):
+            if math.isinf(arg_value) or math.isnan(arg_value):
+                raise ValueError(f"{msg_prefix} `{arg_name}` must be legal value, but got {arg_value}.")
+            return arg_value
+        raise TypeError(f"{msg_prefix} `{arg_name}` must be float.")
+

 class ParamValidator:
     """Parameter validator. NOTICE: this class will be replaced by `class Validator`"""
@@ -28,7 +28,7 @@ def piecewise_constant_lr(milestone, learning_rates):
     `milestone`. Let the output learning rate be `y`.

     .. math::
-        y[i] = x_t for i \in [M_{t-1}, M_t)
+        y[i] = x_t,\ for\ i \in [M_{t-1}, M_t)

     Args:
         milestone (list[int]): A list of milestone. This list is a monotone increasing list.
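Rendered, the corrected formula says the learning rate is piecewise constant over the milestone intervals:

```latex
y[i] = x_t,\quad \text{for } i \in [M_{t-1}, M_t)
```

i.e. every step `i` in the half-open interval between milestones `M_{t-1}` and `M_t` receives the constant rate `x_t`.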
@@ -52,7 +52,7 @@ def piecewise_constant_lr(milestone, learning_rates):
     last_item = 0
     for i, item in enumerate(milestone):
         validator.check_integer(f'milestone[{i}]', item, 0, Rel.GT, None)
-        validator.check_value_type(f'learning_rates[{i}]', learning_rates[i], [float], None)
+        validator.check_float_legal_value(f'learning_rates[{i}]', learning_rates[i], None)
         if item < last_item:
             raise ValueError(f'The value of milestone[{i}] must be greater than milestone[{i - 1}]')
         lr += [learning_rates[i]] * (item - last_item)
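To see what the validated loop produces, here is a hand-run sketch of the `piecewise_constant_lr` logic with hypothetical inputs (it assumes `last_item` is updated to `item` at the end of each iteration, which the truncated hunk does not show):

```python
milestone = [2, 5, 10]              # hypothetical milestones
learning_rates = [0.1, 0.05, 0.01]  # hypothetical rates, one per interval

lr, last_item = [], 0
for i, item in enumerate(milestone):
    lr += [learning_rates[i]] * (item - last_item)  # repeat the rate until the milestone
    last_item = item

print(lr)  # [0.1, 0.1, 0.05, 0.05, 0.05, 0.01, 0.01, 0.01, 0.01, 0.01]
```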
@@ -66,7 +66,9 @@ def _check_inputs(learning_rate, decay_rate, total_step, step_per_epoch, decay_e
     validator.check_integer('step_per_epoch', step_per_epoch, 0, Rel.GT, None)
     validator.check_integer('decay_epoch', decay_epoch, 0, Rel.GT, None)
     validator.check_float_positive('learning_rate', learning_rate, None)
+    validator.check_float_legal_value('learning_rate', learning_rate, None)
     validator.check_float_positive('decay_rate', decay_rate, None)
+    validator.check_float_legal_value('decay_rate', decay_rate, None)
     validator.check_value_type('is_stair', is_stair, [bool], None)
@@ -229,7 +231,9 @@ def cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch, decay_epoch):
         [0.1, 0.1, 0.05500000000000001, 0.05500000000000001, 0.01, 0.01]
     """
     validator.check_float_positive('min_lr', min_lr, None)
+    validator.check_float_legal_value('min_lr', min_lr, None)
     validator.check_float_positive('max_lr', max_lr, None)
+    validator.check_float_legal_value('max_lr', max_lr, None)
     validator.check_integer('total_step', total_step, 0, Rel.GT, None)
     validator.check_integer('step_per_epoch', step_per_epoch, 0, Rel.GT, None)
     validator.check_integer('decay_epoch', decay_epoch, 0, Rel.GT, None)
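The sample output in the docstring is consistent with the usual cosine annealing formula; a quick sanity check (the values `min_lr=0.01`, `max_lr=0.1`, `decay_epoch=2` are inferred from the sample output, so treat them as assumptions):

```python
import math

min_lr, max_lr, decay_epoch = 0.01, 0.1, 2  # assumed from the docstring sample
for epoch in range(decay_epoch + 1):
    lr = min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * epoch / decay_epoch))
    print(lr)  # ~0.1, ~0.055, ~0.01 -- the three distinct values in the sample
```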
@@ -280,11 +284,14 @@ def polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_e
         [0.1, 0.1, 0.07363961030678928, 0.07363961030678928, 0.01, 0.01]
     """
     validator.check_float_positive('learning_rate', learning_rate, None)
+    validator.check_float_legal_value('learning_rate', learning_rate, None)
     validator.check_float_positive('end_learning_rate', end_learning_rate, None)
+    validator.check_float_legal_value('end_learning_rate', end_learning_rate, None)
     validator.check_float_positive('power', power, None)
+    validator.check_float_legal_value('power', power, None)
     validator.check_integer('total_step', total_step, 0, Rel.GT, None)
     validator.check_integer('step_per_epoch', step_per_epoch, 0, Rel.GT, None)
     validator.check_integer('decay_epoch', decay_epoch, 0, Rel.GT, None)
     validator.check_value_type('power', power, [float], None)
     validator.check_value_type('update_decay_epoch', update_decay_epoch, [bool], None)
     function = lambda x, y: (x, min(x, y))
@@ -298,3 +305,13 @@ def polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_e
         decay_epoch, tmp_epoch = function(decay_epoch, current_epoch)
         lr.append(delta * (1 - tmp_epoch / decay_epoch) ** power + end_learning_rate)
     return lr
+
+
+__all__ = [
+    'piecewise_constant_lr',
+    'exponential_decay_lr',
+    'natural_exp_decay_lr',
+    'inverse_decay_lr',
+    'cosine_decay_lr',
+    'polynomial_decay_lr'
+]
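The third sample value in the `polynomial_decay_lr` docstring follows directly from the decay expression in the loop above; a quick reproduction (the values `learning_rate=0.1`, `end_learning_rate=0.01`, `power=0.5`, `decay_epoch=2` are inferred from the sample output, so treat them as assumptions):

```python
delta = 0.1 - 0.01                      # learning_rate - end_learning_rate
lr = delta * (1 - 1 / 2) ** 0.5 + 0.01  # tmp_epoch=1, decay_epoch=2, power=0.5
print(lr)                               # ~0.0736396103..., matching the docstring sample
```

The added `__all__` list also makes the six schedulers the module's explicit public surface.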