提交 b21849d2 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!14 Add M-DI2-FGSM method

Merge pull request !14 from zheng-huanhuan/2_master
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import time
import numpy as np
import pytest
from scipy.special import softmax
from mindspore import Model
from mindspore import Tensor
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindarmour.attacks.iterative_gradient_method import MomentumDiverseInputIterativeMethod
from mindarmour.utils.logger import LogUtil
from mindarmour.evaluations.attack_evaluation import AttackEvaluate
from lenet5_net import LeNet5
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
sys.path.append("..")
from data_processing import generate_mnist_dataset
LOGGER = LogUtil.get_instance()
TAG = 'M_DI2_FGSM_Test'
LOGGER.set_level('INFO')
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_card
@pytest.mark.component_mindarmour
def test_momentum_diverse_input_iterative_method():
"""
M-DI2-FGSM Attack Test
"""
# upload trained network
ckpt_name = './trained_ckpt_file/checkpoint_lenet-10_1875.ckpt'
net = LeNet5()
load_dict = load_checkpoint(ckpt_name)
load_param_into_net(net, load_dict)
# get test data
data_list = "./MNIST_unzip/test"
batch_size = 32
ds = generate_mnist_dataset(data_list, batch_size, sparse=False)
# prediction accuracy before attack
model = Model(net)
batch_num = 32 # the number of batches of attacking samples
test_images = []
test_labels = []
predict_labels = []
i = 0
for data in ds.create_tuple_iterator():
i += 1
images = data[0].astype(np.float32)
labels = data[1]
test_images.append(images)
test_labels.append(labels)
pred_labels = np.argmax(model.predict(Tensor(images)).asnumpy(),
axis=1)
predict_labels.append(pred_labels)
if i >= batch_num:
break
predict_labels = np.concatenate(predict_labels)
true_labels = np.argmax(np.concatenate(test_labels), axis=1)
accuracy = np.mean(np.equal(predict_labels, true_labels))
LOGGER.info(TAG, "prediction accuracy before attacking is : %s", accuracy)
# attacking
attack = MomentumDiverseInputIterativeMethod(net)
start_time = time.clock()
adv_data = attack.batch_generate(np.concatenate(test_images),
np.concatenate(test_labels), batch_size=32)
stop_time = time.clock()
pred_logits_adv = model.predict(Tensor(adv_data)).asnumpy()
# rescale predict confidences into (0, 1).
pred_logits_adv = softmax(pred_logits_adv, axis=1)
pred_labels_adv = np.argmax(pred_logits_adv, axis=1)
accuracy_adv = np.mean(np.equal(pred_labels_adv, true_labels))
LOGGER.info(TAG, "prediction accuracy after attacking is : %s", accuracy_adv)
attack_evaluate = AttackEvaluate(np.concatenate(test_images).transpose(0, 2, 3, 1),
np.concatenate(test_labels),
adv_data.transpose(0, 2, 3, 1),
pred_logits_adv)
LOGGER.info(TAG, 'mis-classification rate of adversaries is : %s',
attack_evaluate.mis_classification_rate())
LOGGER.info(TAG, 'The average confidence of adversarial class is : %s',
attack_evaluate.avg_conf_adv_class())
LOGGER.info(TAG, 'The average confidence of true class is : %s',
attack_evaluate.avg_conf_true_class())
LOGGER.info(TAG, 'The average distance (l0, l2, linf) between original '
'samples and adversarial samples are: %s',
attack_evaluate.avg_lp_distance())
LOGGER.info(TAG, 'The average structural similarity between original '
'samples and adversarial samples are: %s',
attack_evaluate.avg_ssim())
LOGGER.info(TAG, 'The average costing time is %s',
(stop_time - start_time)/(batch_num*batch_size))
if __name__ == '__main__':
test_momentum_diverse_input_iterative_method()
...@@ -26,6 +26,8 @@ __all__ = ['FastGradientMethod', ...@@ -26,6 +26,8 @@ __all__ = ['FastGradientMethod',
'BasicIterativeMethod', 'BasicIterativeMethod',
'MomentumIterativeMethod', 'MomentumIterativeMethod',
'ProjectedGradientDescent', 'ProjectedGradientDescent',
'DiverseInputIterativeMethod',
'MomentumDiverseInputIterativeMethod',
'DeepFool', 'DeepFool',
'CarliniWagnerL2Attack', 'CarliniWagnerL2Attack',
'JSMAAttack', 'JSMAAttack',
......
...@@ -46,7 +46,7 @@ class GradientMethod(Attack): ...@@ -46,7 +46,7 @@ class GradientMethod(Attack):
Default: None. Default: None.
bounds (tuple): Upper and lower bounds of data, indicating the data range. bounds (tuple): Upper and lower bounds of data, indicating the data range.
In form of (clip_min, clip_max). Default: None. In form of (clip_min, clip_max). Default: None.
loss_fn (Loss): Loss function for optimization. loss_fn (Loss): Loss function for optimization. Default: None.
""" """
def __init__(self, network, eps=0.07, alpha=None, bounds=None, def __init__(self, network, eps=0.07, alpha=None, bounds=None,
...@@ -151,7 +151,7 @@ class FastGradientMethod(GradientMethod): ...@@ -151,7 +151,7 @@ class FastGradientMethod(GradientMethod):
Possible values: np.inf, 1 or 2. Default: 2. Possible values: np.inf, 1 or 2. Default: 2.
is_targeted (bool): If True, targeted attack. If False, untargeted is_targeted (bool): If True, targeted attack. If False, untargeted
attack. Default: False. attack. Default: False.
loss_fn (Loss): Loss function for optimization. loss_fn (Loss): Loss function for optimization. Default: None.
Examples: Examples:
>>> attack = FastGradientMethod(network) >>> attack = FastGradientMethod(network)
...@@ -214,7 +214,7 @@ class RandomFastGradientMethod(FastGradientMethod): ...@@ -214,7 +214,7 @@ class RandomFastGradientMethod(FastGradientMethod):
Possible values: np.inf, 1 or 2. Default: 2. Possible values: np.inf, 1 or 2. Default: 2.
is_targeted (bool): If True, targeted attack. If False, untargeted is_targeted (bool): If True, targeted attack. If False, untargeted
attack. Default: False. attack. Default: False.
loss_fn (Loss): Loss function for optimization. loss_fn (Loss): Loss function for optimization. Default: None.
Raises: Raises:
ValueError: eps is smaller than alpha! ValueError: eps is smaller than alpha!
...@@ -255,7 +255,7 @@ class FastGradientSignMethod(GradientMethod): ...@@ -255,7 +255,7 @@ class FastGradientSignMethod(GradientMethod):
In form of (clip_min, clip_max). Default: (0.0, 1.0). In form of (clip_min, clip_max). Default: (0.0, 1.0).
is_targeted (bool): If True, targeted attack. If False, untargeted is_targeted (bool): If True, targeted attack. If False, untargeted
attack. Default: False. attack. Default: False.
loss_fn (Loss): Loss function for optimization. loss_fn (Loss): Loss function for optimization. Default: None.
Examples: Examples:
>>> attack = FastGradientSignMethod(network) >>> attack = FastGradientSignMethod(network)
...@@ -314,7 +314,7 @@ class RandomFastGradientSignMethod(FastGradientSignMethod): ...@@ -314,7 +314,7 @@ class RandomFastGradientSignMethod(FastGradientSignMethod):
In form of (clip_min, clip_max). Default: (0.0, 1.0). In form of (clip_min, clip_max). Default: (0.0, 1.0).
is_targeted (bool): True: targeted attack. False: untargeted attack. is_targeted (bool): True: targeted attack. False: untargeted attack.
Default: False. Default: False.
loss_fn (Loss): Loss function for optimization. loss_fn (Loss): Loss function for optimization. Default: None.
Raises: Raises:
ValueError: eps is smaller than alpha! ValueError: eps is smaller than alpha!
...@@ -350,7 +350,7 @@ class LeastLikelyClassMethod(FastGradientSignMethod): ...@@ -350,7 +350,7 @@ class LeastLikelyClassMethod(FastGradientSignMethod):
Default: None. Default: None.
bounds (tuple): Upper and lower bounds of data, indicating the data range. bounds (tuple): Upper and lower bounds of data, indicating the data range.
In form of (clip_min, clip_max). Default: (0.0, 1.0). In form of (clip_min, clip_max). Default: (0.0, 1.0).
loss_fn (Loss): Loss function for optimization. loss_fn (Loss): Loss function for optimization. Default: None.
Examples: Examples:
>>> attack = LeastLikelyClassMethod(network) >>> attack = LeastLikelyClassMethod(network)
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
from abc import abstractmethod from abc import abstractmethod
import numpy as np import numpy as np
from PIL import Image, ImageOps
from mindspore.nn import SoftmaxCrossEntropyWithLogits from mindspore.nn import SoftmaxCrossEntropyWithLogits
from mindspore import Tensor from mindspore import Tensor
...@@ -115,7 +116,7 @@ class IterativeGradientMethod(Attack): ...@@ -115,7 +116,7 @@ class IterativeGradientMethod(Attack):
bounds (tuple): Upper and lower bounds of data, indicating the data range. bounds (tuple): Upper and lower bounds of data, indicating the data range.
In form of (clip_min, clip_max). Default: (0.0, 1.0). In form of (clip_min, clip_max). Default: (0.0, 1.0).
nb_iter (int): Number of iteration. Default: 5. nb_iter (int): Number of iteration. Default: 5.
loss_fn (Loss): Loss function for optimization. loss_fn (Loss): Loss function for optimization. Default: None.
""" """
def __init__(self, network, eps=0.3, eps_iter=0.1, bounds=(0.0, 1.0), nb_iter=5, def __init__(self, network, eps=0.3, eps_iter=0.1, bounds=(0.0, 1.0), nb_iter=5,
loss_fn=None): loss_fn=None):
...@@ -178,14 +179,13 @@ class BasicIterativeMethod(IterativeGradientMethod): ...@@ -178,14 +179,13 @@ class BasicIterativeMethod(IterativeGradientMethod):
is_targeted (bool): If True, targeted attack. If False, untargeted is_targeted (bool): If True, targeted attack. If False, untargeted
attack. Default: False. attack. Default: False.
nb_iter (int): Number of iteration. Default: 5. nb_iter (int): Number of iteration. Default: 5.
loss_fn (Loss): Loss function for optimization. loss_fn (Loss): Loss function for optimization. Default: None.
attack (class): The single step gradient method of each iteration. In attack (class): The single step gradient method of each iteration. In
this class, FGSM is used. this class, FGSM is used.
Examples: Examples:
>>> attack = BasicIterativeMethod(network) >>> attack = BasicIterativeMethod(network)
""" """
def __init__(self, network, eps=0.3, eps_iter=0.1, bounds=(0.0, 1.0), def __init__(self, network, eps=0.3, eps_iter=0.1, bounds=(0.0, 1.0),
is_targeted=False, nb_iter=5, loss_fn=None): is_targeted=False, nb_iter=5, loss_fn=None):
super(BasicIterativeMethod, self).__init__(network, super(BasicIterativeMethod, self).__init__(network,
...@@ -227,14 +227,22 @@ class BasicIterativeMethod(IterativeGradientMethod): ...@@ -227,14 +227,22 @@ class BasicIterativeMethod(IterativeGradientMethod):
clip_min, clip_max = self._bounds clip_min, clip_max = self._bounds
clip_diff = clip_max - clip_min clip_diff = clip_max - clip_min
for _ in range(self._nb_iter): for _ in range(self._nb_iter):
adv_x = self._attack.generate(inputs, labels) if 'self.prob' in globals():
d_inputs = _transform_inputs(inputs, self.prob)
else:
d_inputs = inputs
adv_x = self._attack.generate(d_inputs, labels)
perturs = np.clip(adv_x - arr_x, (0 - self._eps)*clip_diff, perturs = np.clip(adv_x - arr_x, (0 - self._eps)*clip_diff,
self._eps*clip_diff) self._eps*clip_diff)
adv_x = arr_x + perturs adv_x = arr_x + perturs
inputs = adv_x inputs = adv_x
else: else:
for _ in range(self._nb_iter): for _ in range(self._nb_iter):
adv_x = self._attack.generate(inputs, labels) if 'self.prob' in globals():
d_inputs = _transform_inputs(inputs, self.prob)
else:
d_inputs = inputs
adv_x = self._attack.generate(d_inputs, labels)
adv_x = np.clip(adv_x, arr_x - self._eps, arr_x + self._eps) adv_x = np.clip(adv_x, arr_x - self._eps, arr_x + self._eps)
inputs = adv_x inputs = adv_x
return adv_x return adv_x
...@@ -261,7 +269,7 @@ class MomentumIterativeMethod(IterativeGradientMethod): ...@@ -261,7 +269,7 @@ class MomentumIterativeMethod(IterativeGradientMethod):
decay_factor (float): Decay factor in iterations. Default: 1.0. decay_factor (float): Decay factor in iterations. Default: 1.0.
norm_level (Union[int, numpy.inf]): Order of the norm. Possible values: norm_level (Union[int, numpy.inf]): Order of the norm. Possible values:
np.inf, 1 or 2. Default: 'inf'. np.inf, 1 or 2. Default: 'inf'.
loss_fn (Loss): Loss function for optimization. loss_fn (Loss): Loss function for optimization. Default: None.
""" """
def __init__(self, network, eps=0.3, eps_iter=0.1, bounds=(0.0, 1.0), def __init__(self, network, eps=0.3, eps_iter=0.1, bounds=(0.0, 1.0),
...@@ -303,9 +311,13 @@ class MomentumIterativeMethod(IterativeGradientMethod): ...@@ -303,9 +311,13 @@ class MomentumIterativeMethod(IterativeGradientMethod):
clip_min, clip_max = self._bounds clip_min, clip_max = self._bounds
clip_diff = clip_max - clip_min clip_diff = clip_max - clip_min
for _ in range(self._nb_iter): for _ in range(self._nb_iter):
gradient = self._gradient(inputs, labels) if 'self.prob' in globals():
d_inputs = _transform_inputs(inputs, self.prob)
else:
d_inputs = inputs
gradient = self._gradient(d_inputs, labels)
momentum = self._decay_factor*momentum + gradient momentum = self._decay_factor*momentum + gradient
adv_x = inputs + self._eps_iter*np.sign(momentum) adv_x = d_inputs + self._eps_iter*np.sign(momentum)
perturs = np.clip(adv_x - arr_x, (0 - self._eps)*clip_diff, perturs = np.clip(adv_x - arr_x, (0 - self._eps)*clip_diff,
self._eps*clip_diff) self._eps*clip_diff)
adv_x = arr_x + perturs adv_x = arr_x + perturs
...@@ -313,12 +325,15 @@ class MomentumIterativeMethod(IterativeGradientMethod): ...@@ -313,12 +325,15 @@ class MomentumIterativeMethod(IterativeGradientMethod):
inputs = adv_x inputs = adv_x
else: else:
for _ in range(self._nb_iter): for _ in range(self._nb_iter):
gradient = self._gradient(inputs, labels) if 'self.prob' in globals():
d_inputs = _transform_inputs(inputs, self.prob)
else:
d_inputs = inputs
gradient = self._gradient(d_inputs, labels)
momentum = self._decay_factor*momentum + gradient momentum = self._decay_factor*momentum + gradient
adv_x = inputs + self._eps_iter*np.sign(momentum) adv_x = d_inputs + self._eps_iter*np.sign(momentum)
adv_x = np.clip(adv_x, arr_x - self._eps, arr_x + self._eps) adv_x = np.clip(adv_x, arr_x - self._eps, arr_x + self._eps)
inputs = adv_x inputs = adv_x
return adv_x return adv_x
def _gradient(self, inputs, labels): def _gradient(self, inputs, labels):
...@@ -372,7 +387,7 @@ class ProjectedGradientDescent(BasicIterativeMethod): ...@@ -372,7 +387,7 @@ class ProjectedGradientDescent(BasicIterativeMethod):
nb_iter (int): Number of iteration. Default: 5. nb_iter (int): Number of iteration. Default: 5.
norm_level (Union[int, numpy.inf]): Order of the norm. Possible values: norm_level (Union[int, numpy.inf]): Order of the norm. Possible values:
np.inf, 1 or 2. Default: 'inf'. np.inf, 1 or 2. Default: 'inf'.
loss_fn (Loss): Loss function for optimization. loss_fn (Loss): Loss function for optimization. Default: None.
""" """
def __init__(self, network, eps=0.3, eps_iter=0.1, bounds=(0.0, 1.0), def __init__(self, network, eps=0.3, eps_iter=0.1, bounds=(0.0, 1.0),
...@@ -430,3 +445,114 @@ class ProjectedGradientDescent(BasicIterativeMethod): ...@@ -430,3 +445,114 @@ class ProjectedGradientDescent(BasicIterativeMethod):
adv_x = np.clip(adv_x, arr_x - self._eps, arr_x + self._eps) adv_x = np.clip(adv_x, arr_x - self._eps, arr_x + self._eps)
inputs = adv_x inputs = adv_x
return adv_x return adv_x
class DiverseInputIterativeMethod(BasicIterativeMethod):
"""
The Diverse Input Iterative Method attack.
References: `Xie, Cihang and Zhang, et al., "Improving Transferability of
Adversarial Examples With Input Diversity," in CVPR, 2019 <https://arxiv.org/abs/1803.06978>`_
Args:
network (Cell): Target model.
eps (float): Proportion of adversarial perturbation generated by the
attack to data range. Default: 0.3.
bounds (tuple): Upper and lower bounds of data, indicating the data range.
In form of (clip_min, clip_max). Default: (0.0, 1.0).
is_targeted (bool): If True, targeted attack. If False, untargeted
attack. Default: False.
prob (float): Transformation probability. Default: 0.5.
loss_fn (Loss): Loss function for optimization. Default: None.
"""
def __init__(self, network, eps=0.3, bounds=(0.0, 1.0),
is_targeted=False, prob=0.5, loss_fn=None):
# reference to paper hyper parameters setting.
eps_iter = 16*2/255
nb_iter = int(min(eps*255 + 4, 1.25*255*eps))
super(DiverseInputIterativeMethod, self).__init__(network,
eps=eps,
eps_iter=eps_iter,
bounds=bounds,
is_targeted=is_targeted,
nb_iter=nb_iter,
loss_fn=loss_fn)
# FGSM default alpha is None equal alpha=1
self.prob = check_param_type('prob', prob, float)
class MomentumDiverseInputIterativeMethod(MomentumIterativeMethod):
"""
The Momentum Diverse Input Iterative Method attack.
References: `Xie, Cihang and Zhang, et al., "Improving Transferability of
Adversarial Examples With Input Diversity," in CVPR, 2019 <https://arxiv.org/abs/1803.06978>`_
Args:
network (Cell): Target model.
eps (float): Proportion of adversarial perturbation generated by the
attack to data range. Default: 0.3.
bounds (tuple): Upper and lower bounds of data, indicating the data range.
In form of (clip_min, clip_max). Default: (0.0, 1.0).
is_targeted (bool): If True, targeted attack. If False, untargeted
attack. Default: False.
norm_level (Union[int, numpy.inf]): Order of the norm. Possible values:
np.inf, 1 or 2. Default: 'l1'.
prob (float): Transformation probability. Default: 0.5.
loss_fn (Loss): Loss function for optimization. Default: None.
"""
def __init__(self, network, eps=0.3, bounds=(0.0, 1.0),
is_targeted=False, norm_level='l1', prob=0.5, loss_fn=None):
eps_iter = 16*2 / 255
nb_iter = int(min(eps*255 + 4, 1.25*255*eps))
super(MomentumDiverseInputIterativeMethod, self).__init__(network=network,
eps=eps,
eps_iter=eps_iter,
bounds=bounds,
nb_iter=nb_iter,
is_targeted=is_targeted,
norm_level=norm_level,
loss_fn=loss_fn)
self.prob = check_param_type('prob', prob, float)
def _transform_inputs(inputs, prob, low=29, high=33, full_aug=False):
"""
Inputs data augmentation.
Args:
inputs (Union[np.int8, np.float]): Inputs.
prob (float): The probability of augmentation.
low (int): Lower bound of resize image width. Default: 29.
high (int): Upper bound of resize image height. Default: 33.
full_aug (bool): type of augmentation method, use interpolation and padding
as default. Default: False.
Returns:
numpy.ndarray, the augmentation data.
"""
raw_shape = inputs[0].shape
tran_mask = np.random.uniform(0, 1, size=inputs.shape[0]) < prob
tran_inputs = inputs[tran_mask]
raw_inputs = inputs[tran_mask == 0]
tran_outputs = []
for sample in tran_inputs:
width = np.random.choice(np.arange(low, high))
# resize
sample = (sample*255).astype(np.uint8)
d_image = Image.fromarray(sample, mode='L').resize((width, width), Image.NEAREST)
# pad
left_pad = (raw_shape[0] - width) // 2
right_pad = raw_shape[0] - width - left_pad
top_pad = (raw_shape[1] - width) // 2
bottom_pad = raw_shape[1] - width - top_pad
p_sample = ImageOps.expand(d_image,
border=(left_pad, top_pad, right_pad, bottom_pad))
tran_outputs.append(np.array(p_sample).astype(np.float) / 255)
if full_aug:
# gaussian noise
tran_outputs = np.random.normal(tran_outputs.shape) + tran_outputs
tran_outputs.extend(raw_inputs)
if not np.any(tran_outputs-raw_inputs):
LOGGER.error(TAG, 'the transform function does not take effect.')
return tran_outputs
...@@ -242,7 +242,7 @@ def normalize_value(value, norm_level): ...@@ -242,7 +242,7 @@ def normalize_value(value, norm_level):
Raises: Raises:
NotImplementedError: If norm_level is not in [1, 2 , np.inf, '1', '2', NotImplementedError: If norm_level is not in [1, 2 , np.inf, '1', '2',
'inf] 'inf', 'l1', 'l2']
""" """
norm_level = check_norm_level(norm_level) norm_level = check_norm_level(norm_level)
ori_shape = value.shape ori_shape = value.shape
......
numpy >= 1.17.0 numpy >= 1.17.0
scipy >= 1.3.3 scipy >= 1.3.3
matplotlib >= 3.1.3 matplotlib >= 3.1.3
Pillow >= 2.0.0
pytest >= 4.3.1 pytest >= 4.3.1
wheel >= 0.32.0 wheel >= 0.32.0
setuptools >= 40.8.0 setuptools >= 40.8.0
...@@ -95,7 +95,8 @@ setup( ...@@ -95,7 +95,8 @@ setup(
install_requires=[ install_requires=[
'scipy >= 1.3.3', 'scipy >= 1.3.3',
'numpy >= 1.17.0', 'numpy >= 1.17.0',
'matplotlib >= 3.1.3' 'matplotlib >= 3.1.3',
'Pillow >= 2.0.0'
], ],
) )
print(find_packages()) print(find_packages())
...@@ -25,6 +25,8 @@ from mindarmour.attacks import BasicIterativeMethod ...@@ -25,6 +25,8 @@ from mindarmour.attacks import BasicIterativeMethod
from mindarmour.attacks import MomentumIterativeMethod from mindarmour.attacks import MomentumIterativeMethod
from mindarmour.attacks import ProjectedGradientDescent from mindarmour.attacks import ProjectedGradientDescent
from mindarmour.attacks import IterativeGradientMethod from mindarmour.attacks import IterativeGradientMethod
from mindarmour.attacks import DiverseInputIterativeMethod
from mindarmour.attacks import MomentumDiverseInputIterativeMethod
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
...@@ -91,7 +93,7 @@ def test_momentum_iterative_method(): ...@@ -91,7 +93,7 @@ def test_momentum_iterative_method():
for i in range(5): for i in range(5):
attack = MomentumIterativeMethod(Net(), nb_iter=i + 1) attack = MomentumIterativeMethod(Net(), nb_iter=i + 1)
ms_adv_x = attack.generate(input_np, label) ms_adv_x = attack.generate(input_np, label)
assert np.any(ms_adv_x != input_np), 'Basic iterative method: generate' \ assert np.any(ms_adv_x != input_np), 'Momentum iterative method: generate' \
' value must not be equal to' \ ' value must not be equal to' \
' original value.' ' original value.'
...@@ -119,6 +121,48 @@ def test_projected_gradient_descent_method(): ...@@ -119,6 +121,48 @@ def test_projected_gradient_descent_method():
' original value.' ' original value.'
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_card
@pytest.mark.component_mindarmour
def test_diverse_input_iterative_method():
"""
Diverse input iterative method unit test.
"""
input_np = np.asarray([[0.1, 0.2, 0.7]], np.float32)
label = np.asarray([2], np.int32)
label = np.eye(3)[label].astype(np.float32)
for i in range(5):
attack = DiverseInputIterativeMethod(Net())
ms_adv_x = attack.generate(input_np, label)
assert np.any(ms_adv_x != input_np), 'Diverse input iterative method: generate' \
' value must not be equal to' \
' original value.'
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_card
@pytest.mark.component_mindarmour
def test_momentum_diverse_input_iterative_method():
"""
Momentum diverse input iterative method unit test.
"""
input_np = np.asarray([[0.1, 0.2, 0.7]], np.float32)
label = np.asarray([2], np.int32)
label = np.eye(3)[label].astype(np.float32)
for i in range(5):
attack = MomentumDiverseInputIterativeMethod(Net())
ms_adv_x = attack.generate(input_np, label)
assert np.any(ms_adv_x != input_np), 'Momentum diverse input iterative method: ' \
'generate value must not be equal to' \
' original value.'
@pytest.mark.level0 @pytest.mark.level0
@pytest.mark.platform_arm_ascend_training @pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training @pytest.mark.platform_x86_ascend_training
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册