提交 b7772333 编写于 作者: Z zhenghuanhuan

fix issue.

上级 ca524d71
...@@ -26,6 +26,7 @@ from mindspore.common import dtype as mstype ...@@ -26,6 +26,7 @@ from mindspore.common import dtype as mstype
from mindarmour.utils._check_param import check_param_type from mindarmour.utils._check_param import check_param_type
from mindarmour.utils._check_param import check_value_positive from mindarmour.utils._check_param import check_value_positive
from mindarmour.utils._check_param import check_param_in_range from mindarmour.utils._check_param import check_param_in_range
from mindarmour.utils._check_param import check_value_non_negative
from mindarmour.utils.logger import LogUtil from mindarmour.utils.logger import LogUtil
LOGGER = LogUtil.get_instance() LOGGER = LogUtil.get_instance()
...@@ -204,8 +205,10 @@ class NoiseGaussianRandom(_Mechanisms): ...@@ -204,8 +205,10 @@ class NoiseGaussianRandom(_Mechanisms):
def __init__(self, norm_bound=1.0, initial_noise_multiplier=1.0, seed=0, decay_policy=None): def __init__(self, norm_bound=1.0, initial_noise_multiplier=1.0, seed=0, decay_policy=None):
super(NoiseGaussianRandom, self).__init__() super(NoiseGaussianRandom, self).__init__()
norm_bound = check_param_type('norm_bound', norm_bound, float)
self._norm_bound = check_value_positive('norm_bound', norm_bound) self._norm_bound = check_value_positive('norm_bound', norm_bound)
self._norm_bound = Tensor(norm_bound, mstype.float32) self._norm_bound = Tensor(norm_bound, mstype.float32)
initial_noise_multiplier = check_param_type('initial_noise_multiplier', initial_noise_multiplier, float)
self._initial_noise_multiplier = check_value_positive('initial_noise_multiplier', self._initial_noise_multiplier = check_value_positive('initial_noise_multiplier',
initial_noise_multiplier) initial_noise_multiplier)
self._initial_noise_multiplier = Tensor(initial_noise_multiplier, mstype.float32) self._initial_noise_multiplier = Tensor(initial_noise_multiplier, mstype.float32)
...@@ -213,7 +216,8 @@ class NoiseGaussianRandom(_Mechanisms): ...@@ -213,7 +216,8 @@ class NoiseGaussianRandom(_Mechanisms):
if decay_policy is not None: if decay_policy is not None:
raise ValueError('decay_policy must be None in GaussianRandom class, but got {}.'.format(decay_policy)) raise ValueError('decay_policy must be None in GaussianRandom class, but got {}.'.format(decay_policy))
self._decay_policy = decay_policy self._decay_policy = decay_policy
self._seed = seed seed = check_param_type('seed', seed, int)
self._seed = check_value_non_negative('seed', seed)
def construct(self, gradients): def construct(self, gradients):
""" """
...@@ -400,7 +404,8 @@ class AdaClippingWithGaussianRandom(Cell): ...@@ -400,7 +404,8 @@ class AdaClippingWithGaussianRandom(Cell):
self._sub = P.Sub() self._sub = P.Sub()
self._mul = P.Mul() self._mul = P.Mul()
self._exp = P.Exp() self._exp = P.Exp()
self._seed = seed seed = check_param_type('seed', seed, int)
self._seed = check_value_non_negative('seed', seed)
def construct(self, empirical_fraction, norm_bound): def construct(self, empirical_fraction, norm_bound):
""" """
......
...@@ -50,8 +50,7 @@ from mindspore import ParameterTuple ...@@ -50,8 +50,7 @@ from mindspore import ParameterTuple
from mindarmour.utils.logger import LogUtil from mindarmour.utils.logger import LogUtil
from mindarmour.diff_privacy.mechanisms.mechanisms import \ from mindarmour.diff_privacy.mechanisms.mechanisms import \
_MechanismsParamsUpdater _MechanismsParamsUpdater
from mindarmour.utils._check_param import check_param_type from mindarmour.utils._check_param import check_value_positive, check_param_type
from mindarmour.utils._check_param import check_value_positive
from mindarmour.utils._check_param import check_int_positive from mindarmour.utils._check_param import check_int_positive
LOGGER = LogUtil.get_instance() LOGGER = LogUtil.get_instance()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册