diff --git a/mindspore/_checkparam.py b/mindspore/_checkparam.py index 7b8c89351cc8b32bbc1450eaa65a9ab751a5a316..ae4274137115e080b8d8479e3b51ef633e3c7fb2 100644 --- a/mindspore/_checkparam.py +++ b/mindspore/_checkparam.py @@ -15,6 +15,7 @@ """Check parameters.""" import re import inspect +import math from enum import Enum from functools import reduce, wraps from itertools import repeat @@ -318,6 +319,16 @@ class Validator: raise ValueError(f'{msg_prefix} type of `{arg_name}` should be one of {type_names},' f' but got {get_typename(arg_type)}.') + @staticmethod + def check_float_legal_value(arg_name, arg_value, prim_name): + """Checks whether a legal value of float type""" + msg_prefix = f'For \'{prim_name}\' the' if prim_name else "The" + if isinstance(arg_value, float): + if math.isinf(arg_value) or math.isnan(arg_value): + raise ValueError(f"{msg_prefix} `{arg_name}` must be legal value, but got {arg_value}.") + return arg_value + raise TypeError(f"{msg_prefix} `{arg_name}` must be float.") + class ParamValidator: """Parameter validator. NOTICE: this class will be replaced by `class Validator`""" diff --git a/mindspore/nn/dynamic_lr.py b/mindspore/nn/dynamic_lr.py index 266587c5c38dd9ce9627a39049385f6e27a5d8b3..1811114680d1fa8969922d838e9d41acadba5421 100644 --- a/mindspore/nn/dynamic_lr.py +++ b/mindspore/nn/dynamic_lr.py @@ -28,7 +28,7 @@ def piecewise_constant_lr(milestone, learning_rates): `milestone`. Let the output learning rate be `y`. .. math:: - y[i] = x_t for i \in [M_{t-1}, M_t) + y[i] = x_t,\ for\ i \in [M_{t-1}, M_t) Args: milestone (Union[list[int], tuple[int]]): A list of milestone. This list is a monotone increasing list. @@ -52,7 +52,7 @@ def piecewise_constant_lr(milestone, learning_rates): last_item = 0 for i, item in enumerate(milestone): validator.check_integer(f'milestone[{i}]', item, 0, Rel.GT, None) - validator.check_value_type(f'learning_rates[{i}]', learning_rates[i], [float], None) + validator.check_float_legal_value(f'learning_rates[{i}]', learning_rates[i], None) if item < last_item: raise ValueError(f'The value of milestone[{i}] must be greater than milestone[{i - 1}]') lr += [learning_rates[i]] * (item - last_item) @@ -66,7 +66,9 @@ def _check_inputs(learning_rate, decay_rate, total_step, step_per_epoch, decay_e validator.check_integer('step_per_epoch', step_per_epoch, 0, Rel.GT, None) validator.check_integer('decay_epoch', decay_epoch, 0, Rel.GT, None) validator.check_float_positive('learning_rate', learning_rate, None) + validator.check_float_legal_value('learning_rate', learning_rate, None) validator.check_float_positive('decay_rate', decay_rate, None) + validator.check_float_legal_value('decay_rate', decay_rate, None) validator.check_value_type('is_stair', is_stair, [bool], None) @@ -229,7 +231,9 @@ def cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch, decay_epoch): [0.1, 0.1, 0.05500000000000001, 0.05500000000000001, 0.01, 0.01] """ validator.check_float_positive('min_lr', min_lr, None) + validator.check_float_legal_value('min_lr', min_lr, None) validator.check_float_positive('max_lr', max_lr, None) + validator.check_float_legal_value('max_lr', max_lr, None) validator.check_integer('total_step', total_step, 0, Rel.GT, None) validator.check_integer('step_per_epoch', step_per_epoch, 0, Rel.GT, None) validator.check_integer('decay_epoch', decay_epoch, 0, Rel.GT, None) @@ -280,11 +284,14 @@ def polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_e [0.1, 0.1, 0.07363961030678928, 0.07363961030678928, 0.01, 0.01] """ validator.check_float_positive('learning_rate', learning_rate, None) + validator.check_float_legal_value('learning_rate', learning_rate, None) validator.check_float_positive('end_learning_rate', end_learning_rate, None) + validator.check_float_legal_value('end_learning_rate', end_learning_rate, None) + validator.check_float_positive('power', power, None) + validator.check_float_legal_value('power', power, None) validator.check_integer('total_step', total_step, 0, Rel.GT, None) validator.check_integer('step_per_epoch', step_per_epoch, 0, Rel.GT, None) validator.check_integer('decay_epoch', decay_epoch, 0, Rel.GT, None) - validator.check_value_type('power', power, [float], None) validator.check_value_type('update_decay_epoch', update_decay_epoch, [bool], None) function = lambda x, y: (x, min(x, y)) @@ -298,3 +305,13 @@ def polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_e decay_epoch, tmp_epoch = function(decay_epoch, current_epoch) lr.append(delta * (1 - tmp_epoch / decay_epoch) ** power + end_learning_rate) return lr + + +__all__ = [ + 'piecewise_constant_lr', + 'exponential_decay_lr', + 'natural_exp_decay_lr', + 'inverse_decay_lr', + 'cosine_decay_lr', + 'polynomial_decay_lr' +]