提交 ec9d46a7 编写于 作者: Z ZhidanLiu

fix parameter check

上级 f3baf9db
......@@ -159,7 +159,10 @@ class AdaGaussianRandom(Mechanisms):
alpha = check_param_type('alpha', alpha, float)
self._alpha = Tensor(np.array(alpha, np.float32))
self._decay_policy = check_param_type('decay_policy', decay_policy, str)
if decay_policy not in ['Time', 'Step']:
raise NameError("The decay_policy must be in ['Time', 'Step'], but "
"get {}".format(decay_policy))
self._decay_policy = decay_policy
self._mean = 0.0
self._sub = P.Sub()
self._mul = P.Mul()
......
......@@ -43,7 +43,7 @@ def check_param_type(arg_name, arg_value, valid_type):
valid_type,
type(arg_value).__name__)
LOGGER.error(TAG, msg)
raise ValueError(msg)
raise TypeError(msg)
return arg_value
......@@ -54,7 +54,7 @@ def check_param_multi_types(arg_name, arg_value, valid_types):
msg = 'type of {} must be in {}, but got {}' \
.format(arg_name, valid_types, type(arg_value).__name__)
LOGGER.error(TAG, msg)
raise ValueError(msg)
raise TypeError(msg)
return arg_value
......@@ -157,7 +157,7 @@ def check_numpy_param(arg_name, arg_value):
msg = 'type of {} must be in (list, tuple, numpy.ndarray)'.format(
arg_name)
LOGGER.error(TAG, msg)
raise ValueError(msg)
raise TypeError(msg)
return arg_value
......
......@@ -167,7 +167,7 @@ def test_momentum_diverse_input_iterative_method():
@pytest.mark.env_card
@pytest.mark.component_mindarmour
def test_error():
with pytest.raises(ValueError):
with pytest.raises(TypeError):
# check_param_multi_types
assert IterativeGradientMethod(Net(), bounds=None)
attack = IterativeGradientMethod(Net(), bounds=(0.0, 1.0))
......
......@@ -100,16 +100,16 @@ def test_value_error():
with pytest.raises(ValueError):
assert RegionBasedDetector(model, search_step=0)
with pytest.raises(ValueError):
with pytest.raises(TypeError):
assert RegionBasedDetector(model, sparse='False')
detector = RegionBasedDetector(model)
with pytest.raises(ValueError):
with pytest.raises(TypeError):
# radius must not empty
assert detector.detect(adv)
radius = detector.fit(ori, labels)
detector.set_radius(radius)
with pytest.raises(ValueError):
with pytest.raises(TypeError):
# adv type should be in (list, tuple, numpy.ndarray)
assert detector.detect(adv.tostring())
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册