提交 b7772333 编写于 作者: Z zhenghuanhuan

fix issue.

上级 ca524d71
......@@ -26,6 +26,7 @@ from mindspore.common import dtype as mstype
from mindarmour.utils._check_param import check_param_type
from mindarmour.utils._check_param import check_value_positive
from mindarmour.utils._check_param import check_param_in_range
from mindarmour.utils._check_param import check_value_non_negative
from mindarmour.utils.logger import LogUtil
LOGGER = LogUtil.get_instance()
......@@ -204,8 +205,10 @@ class NoiseGaussianRandom(_Mechanisms):
def __init__(self, norm_bound=1.0, initial_noise_multiplier=1.0, seed=0, decay_policy=None):
super(NoiseGaussianRandom, self).__init__()
norm_bound = check_param_type('norm_bound', norm_bound, float)
self._norm_bound = check_value_positive('norm_bound', norm_bound)
self._norm_bound = Tensor(norm_bound, mstype.float32)
initial_noise_multiplier = check_param_type('initial_noise_multiplier', initial_noise_multiplier, float)
self._initial_noise_multiplier = check_value_positive('initial_noise_multiplier',
initial_noise_multiplier)
self._initial_noise_multiplier = Tensor(initial_noise_multiplier, mstype.float32)
......@@ -213,7 +216,8 @@ class NoiseGaussianRandom(_Mechanisms):
if decay_policy is not None:
raise ValueError('decay_policy must be None in GaussianRandom class, but got {}.'.format(decay_policy))
self._decay_policy = decay_policy
self._seed = seed
seed = check_param_type('seed', seed, int)
self._seed = check_value_non_negative('seed', seed)
def construct(self, gradients):
"""
......@@ -400,7 +404,8 @@ class AdaClippingWithGaussianRandom(Cell):
self._sub = P.Sub()
self._mul = P.Mul()
self._exp = P.Exp()
self._seed = seed
seed = check_param_type('seed', seed, int)
self._seed = check_value_non_negative('seed', seed)
def construct(self, empirical_fraction, norm_bound):
"""
......
......@@ -50,8 +50,7 @@ from mindspore import ParameterTuple
from mindarmour.utils.logger import LogUtil
from mindarmour.diff_privacy.mechanisms.mechanisms import \
_MechanismsParamsUpdater
from mindarmour.utils._check_param import check_param_type
from mindarmour.utils._check_param import check_value_positive
from mindarmour.utils._check_param import check_value_positive, check_param_type
from mindarmour.utils._check_param import check_int_positive
LOGGER = LogUtil.get_instance()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册