diff --git a/mindarmour/diff_privacy/mechanisms/mechanisms.py b/mindarmour/diff_privacy/mechanisms/mechanisms.py index 68c5d170b6ea13dee9e921d43d5177936bf17828..a1103373e80aa26d2b93cabe044fac48f569c49f 100644 --- a/mindarmour/diff_privacy/mechanisms/mechanisms.py +++ b/mindarmour/diff_privacy/mechanisms/mechanisms.py @@ -26,6 +26,7 @@ from mindspore.common import dtype as mstype from mindarmour.utils._check_param import check_param_type from mindarmour.utils._check_param import check_value_positive from mindarmour.utils._check_param import check_param_in_range +from mindarmour.utils._check_param import check_value_non_negative from mindarmour.utils.logger import LogUtil LOGGER = LogUtil.get_instance() @@ -204,8 +205,10 @@ class NoiseGaussianRandom(_Mechanisms): def __init__(self, norm_bound=1.0, initial_noise_multiplier=1.0, seed=0, decay_policy=None): super(NoiseGaussianRandom, self).__init__() + norm_bound = check_param_type('norm_bound', norm_bound, float) self._norm_bound = check_value_positive('norm_bound', norm_bound) self._norm_bound = Tensor(norm_bound, mstype.float32) + initial_noise_multiplier = check_param_type('initial_noise_multiplier', initial_noise_multiplier, float) self._initial_noise_multiplier = check_value_positive('initial_noise_multiplier', initial_noise_multiplier) self._initial_noise_multiplier = Tensor(initial_noise_multiplier, mstype.float32) @@ -213,7 +216,8 @@ class NoiseGaussianRandom(_Mechanisms): if decay_policy is not None: raise ValueError('decay_policy must be None in GaussianRandom class, but got {}.'.format(decay_policy)) self._decay_policy = decay_policy - self._seed = seed + seed = check_param_type('seed', seed, int) + self._seed = check_value_non_negative('seed', seed) def construct(self, gradients): """ @@ -400,7 +404,8 @@ class AdaClippingWithGaussianRandom(Cell): self._sub = P.Sub() self._mul = P.Mul() self._exp = P.Exp() - self._seed = seed + seed = check_param_type('seed', seed, int) + self._seed = check_value_non_negative('seed', seed) def construct(self, empirical_fraction, norm_bound): """ diff --git a/mindarmour/diff_privacy/train/model.py b/mindarmour/diff_privacy/train/model.py index f235cb6b802884bb87aa65c32c2d41aeeaeae966..875e9d9494f4f74b5a8f1d848b970bdc9b4f922a 100644 --- a/mindarmour/diff_privacy/train/model.py +++ b/mindarmour/diff_privacy/train/model.py @@ -50,8 +50,7 @@ from mindspore import ParameterTuple from mindarmour.utils.logger import LogUtil from mindarmour.diff_privacy.mechanisms.mechanisms import \ _MechanismsParamsUpdater -from mindarmour.utils._check_param import check_param_type -from mindarmour.utils._check_param import check_value_positive +from mindarmour.utils._check_param import check_value_positive, check_param_type from mindarmour.utils._check_param import check_int_positive LOGGER = LogUtil.get_instance()