diff --git a/mindarmour/diff_privacy/mechanisms/mechanisms.py b/mindarmour/diff_privacy/mechanisms/mechanisms.py index dd8f5da4ceef44e9884aa22a85f4b41a0886664a..44aaea49e17e3addc3b118e87a32fe86e6bd7fac 100644 --- a/mindarmour/diff_privacy/mechanisms/mechanisms.py +++ b/mindarmour/diff_privacy/mechanisms/mechanisms.py @@ -159,7 +159,10 @@ class AdaGaussianRandom(Mechanisms): alpha = check_param_type('alpha', alpha, float) self._alpha = Tensor(np.array(alpha, np.float32)) - self._decay_policy = check_param_type('decay_policy', decay_policy, str) + if decay_policy not in ['Time', 'Step']: + raise NameError("The decay_policy must be in ['Time', 'Step'], but " + "get {}".format(decay_policy)) + self._decay_policy = decay_policy self._mean = 0.0 self._sub = P.Sub() self._mul = P.Mul() diff --git a/mindarmour/utils/_check_param.py b/mindarmour/utils/_check_param.py index 06a37c66ca3c977f59a843ae38568439111a7426..36ebc8724d667f71dc7232faf044f908ab7a5e5b 100644 --- a/mindarmour/utils/_check_param.py +++ b/mindarmour/utils/_check_param.py @@ -43,7 +43,7 @@ def check_param_type(arg_name, arg_value, valid_type): valid_type, type(arg_value).__name__) LOGGER.error(TAG, msg) - raise ValueError(msg) + raise TypeError(msg) return arg_value @@ -54,7 +54,7 @@ def check_param_multi_types(arg_name, arg_value, valid_types): msg = 'type of {} must be in {}, but got {}' \ .format(arg_name, valid_types, type(arg_value).__name__) LOGGER.error(TAG, msg) - raise ValueError(msg) + raise TypeError(msg) return arg_value @@ -157,7 +157,7 @@ def check_numpy_param(arg_name, arg_value): msg = 'type of {} must be in (list, tuple, numpy.ndarray)'.format( arg_name) LOGGER.error(TAG, msg) - raise ValueError(msg) + raise TypeError(msg) return arg_value diff --git a/tests/ut/python/attacks/test_iterative_gradient_method.py b/tests/ut/python/attacks/test_iterative_gradient_method.py index 3a9fcb024a94634c08770c2be0ce0c0b524f113c..34cc2e384be7b02bbdefe7bfed2dc54901797711 100644 --- a/tests/ut/python/attacks/test_iterative_gradient_method.py +++ b/tests/ut/python/attacks/test_iterative_gradient_method.py @@ -167,7 +167,7 @@ def test_momentum_diverse_input_iterative_method(): @pytest.mark.env_card @pytest.mark.component_mindarmour def test_error(): - with pytest.raises(ValueError): + with pytest.raises(TypeError): # check_param_multi_types assert IterativeGradientMethod(Net(), bounds=None) attack = IterativeGradientMethod(Net(), bounds=(0.0, 1.0)) diff --git a/tests/ut/python/detectors/test_region_based_detector.py b/tests/ut/python/detectors/test_region_based_detector.py index c9587498cf0448c3ba1e43da5c415a6f354ee963..f4b891acbdc1261a664392041fa5e19055b8c10d 100644 --- a/tests/ut/python/detectors/test_region_based_detector.py +++ b/tests/ut/python/detectors/test_region_based_detector.py @@ -100,16 +100,16 @@ def test_value_error(): with pytest.raises(ValueError): assert RegionBasedDetector(model, search_step=0) - with pytest.raises(ValueError): + with pytest.raises(TypeError): assert RegionBasedDetector(model, sparse='False') detector = RegionBasedDetector(model) - with pytest.raises(ValueError): + with pytest.raises(TypeError): # radius must not empty assert detector.detect(adv) radius = detector.fit(ori, labels) detector.set_radius(radius) - with pytest.raises(ValueError): + with pytest.raises(TypeError): # adv type should be in (list, tuple, numpy.ndarray) assert detector.detect(adv.tostring())