From 810df70db68509b9ddab1ac6ab96539fac93bad1 Mon Sep 17 00:00:00 2001 From: zheng-huanhuan Date: Sat, 18 Apr 2020 14:55:22 +0800 Subject: [PATCH] 1. Add DI2-FGSM method in iterative_gradient_method.py 2. Add M-DI2-FGSM method in iterative_gradient_method.py 3. Add corresponding ut test. 4. Add mnist example of M_DI2_FGSM method. --- example/mnist_demo/mnist_attack_mdi2fgsm.py | 118 ++++++++++++++ mindarmour/attacks/__init__.py | 2 + mindarmour/attacks/gradient_method.py | 12 +- .../attacks/iterative_gradient_method.py | 150 ++++++++++++++++-- mindarmour/utils/_check_param.py | 2 +- requirements.txt | 1 + setup.py | 3 +- .../attacks/test_iterative_gradient_method.py | 46 +++++- 8 files changed, 313 insertions(+), 21 deletions(-) create mode 100644 example/mnist_demo/mnist_attack_mdi2fgsm.py diff --git a/example/mnist_demo/mnist_attack_mdi2fgsm.py b/example/mnist_demo/mnist_attack_mdi2fgsm.py new file mode 100644 index 0000000..eb983b5 --- /dev/null +++ b/example/mnist_demo/mnist_attack_mdi2fgsm.py @@ -0,0 +1,118 @@ +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import sys +import time +import numpy as np +import pytest +from scipy.special import softmax + +from mindspore import Model +from mindspore import Tensor +from mindspore import context +from mindspore.train.serialization import load_checkpoint, load_param_into_net + +from mindarmour.attacks.iterative_gradient_method import MomentumDiverseInputIterativeMethod + +from mindarmour.utils.logger import LogUtil +from mindarmour.evaluations.attack_evaluation import AttackEvaluate + +from lenet5_net import LeNet5 + +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + + +sys.path.append("..") +from data_processing import generate_mnist_dataset + +LOGGER = LogUtil.get_instance() +TAG = 'M_DI2_FGSM_Test' +LOGGER.set_level('INFO') + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_card +@pytest.mark.component_mindarmour +def test_momentum_diverse_input_iterative_method(): + """ + M-DI2-FGSM Attack Test + """ + # upload trained network + ckpt_name = './trained_ckpt_file/checkpoint_lenet-10_1875.ckpt' + net = LeNet5() + load_dict = load_checkpoint(ckpt_name) + load_param_into_net(net, load_dict) + + # get test data + data_list = "./MNIST_unzip/test" + batch_size = 32 + ds = generate_mnist_dataset(data_list, batch_size, sparse=False) + + # prediction accuracy before attack + model = Model(net) + batch_num = 32 # the number of batches of attacking samples + test_images = [] + test_labels = [] + predict_labels = [] + i = 0 + for data in ds.create_tuple_iterator(): + i += 1 + images = data[0].astype(np.float32) + labels = data[1] + test_images.append(images) + test_labels.append(labels) + pred_labels = np.argmax(model.predict(Tensor(images)).asnumpy(), + axis=1) + predict_labels.append(pred_labels) + if i >= batch_num: + break + predict_labels = np.concatenate(predict_labels) + true_labels = np.argmax(np.concatenate(test_labels), axis=1) + accuracy = np.mean(np.equal(predict_labels, true_labels)) + LOGGER.info(TAG, "prediction accuracy before attacking is : %s", accuracy) + + # attacking + attack = MomentumDiverseInputIterativeMethod(net) + start_time = time.clock() + adv_data = attack.batch_generate(np.concatenate(test_images), + np.concatenate(test_labels), batch_size=32) + stop_time = time.clock() + pred_logits_adv = model.predict(Tensor(adv_data)).asnumpy() + # rescale predict confidences into (0, 1). + pred_logits_adv = softmax(pred_logits_adv, axis=1) + pred_labels_adv = np.argmax(pred_logits_adv, axis=1) + accuracy_adv = np.mean(np.equal(pred_labels_adv, true_labels)) + LOGGER.info(TAG, "prediction accuracy after attacking is : %s", accuracy_adv) + attack_evaluate = AttackEvaluate(np.concatenate(test_images).transpose(0, 2, 3, 1), + np.concatenate(test_labels), + adv_data.transpose(0, 2, 3, 1), + pred_logits_adv) + LOGGER.info(TAG, 'mis-classification rate of adversaries is : %s', + attack_evaluate.mis_classification_rate()) + LOGGER.info(TAG, 'The average confidence of adversarial class is : %s', + attack_evaluate.avg_conf_adv_class()) + LOGGER.info(TAG, 'The average confidence of true class is : %s', + attack_evaluate.avg_conf_true_class()) + LOGGER.info(TAG, 'The average distance (l0, l2, linf) between original ' + 'samples and adversarial samples are: %s', + attack_evaluate.avg_lp_distance()) + LOGGER.info(TAG, 'The average structural similarity between original ' + 'samples and adversarial samples are: %s', + attack_evaluate.avg_ssim()) + LOGGER.info(TAG, 'The average costing time is %s', + (stop_time - start_time)/(batch_num*batch_size)) + + +if __name__ == '__main__': + test_momentum_diverse_input_iterative_method() diff --git a/mindarmour/attacks/__init__.py b/mindarmour/attacks/__init__.py index 11072ea..2e66469 100644 --- a/mindarmour/attacks/__init__.py +++ b/mindarmour/attacks/__init__.py @@ -26,6 +26,8 @@ __all__ = ['FastGradientMethod', 'BasicIterativeMethod', 'MomentumIterativeMethod', 'ProjectedGradientDescent', + 'DiverseInputIterativeMethod', + 'MomentumDiverseInputIterativeMethod', 'DeepFool', 'CarliniWagnerL2Attack', 'JSMAAttack', diff --git a/mindarmour/attacks/gradient_method.py b/mindarmour/attacks/gradient_method.py index e70290e..66cab6f 100644 --- a/mindarmour/attacks/gradient_method.py +++ b/mindarmour/attacks/gradient_method.py @@ -46,7 +46,7 @@ class GradientMethod(Attack): Default: None. bounds (tuple): Upper and lower bounds of data, indicating the data range. In form of (clip_min, clip_max). Default: None. - loss_fn (Loss): Loss function for optimization. + loss_fn (Loss): Loss function for optimization. Default: None. """ def __init__(self, network, eps=0.07, alpha=None, bounds=None, @@ -151,7 +151,7 @@ class FastGradientMethod(GradientMethod): Possible values: np.inf, 1 or 2. Default: 2. is_targeted (bool): If True, targeted attack. If False, untargeted attack. Default: False. - loss_fn (Loss): Loss function for optimization. + loss_fn (Loss): Loss function for optimization. Default: None. Examples: >>> attack = FastGradientMethod(network) @@ -214,7 +214,7 @@ class RandomFastGradientMethod(FastGradientMethod): Possible values: np.inf, 1 or 2. Default: 2. is_targeted (bool): If True, targeted attack. If False, untargeted attack. Default: False. - loss_fn (Loss): Loss function for optimization. + loss_fn (Loss): Loss function for optimization. Default: None. Raises: ValueError: eps is smaller than alpha! @@ -255,7 +255,7 @@ class FastGradientSignMethod(GradientMethod): In form of (clip_min, clip_max). Default: (0.0, 1.0). is_targeted (bool): If True, targeted attack. If False, untargeted attack. Default: False. - loss_fn (Loss): Loss function for optimization. + loss_fn (Loss): Loss function for optimization. Default: None. Examples: >>> attack = FastGradientSignMethod(network) @@ -314,7 +314,7 @@ class RandomFastGradientSignMethod(FastGradientSignMethod): In form of (clip_min, clip_max). Default: (0.0, 1.0). is_targeted (bool): True: targeted attack. False: untargeted attack. Default: False. - loss_fn (Loss): Loss function for optimization. + loss_fn (Loss): Loss function for optimization. Default: None. Raises: ValueError: eps is smaller than alpha! @@ -350,7 +350,7 @@ class LeastLikelyClassMethod(FastGradientSignMethod): Default: None. bounds (tuple): Upper and lower bounds of data, indicating the data range. In form of (clip_min, clip_max). Default: (0.0, 1.0). - loss_fn (Loss): Loss function for optimization. + loss_fn (Loss): Loss function for optimization. Default: None. Examples: >>> attack = LeastLikelyClassMethod(network) diff --git a/mindarmour/attacks/iterative_gradient_method.py b/mindarmour/attacks/iterative_gradient_method.py index 135ac6d..337fec8 100644 --- a/mindarmour/attacks/iterative_gradient_method.py +++ b/mindarmour/attacks/iterative_gradient_method.py @@ -15,6 +15,7 @@ from abc import abstractmethod import numpy as np +from PIL import Image, ImageOps from mindspore.nn import SoftmaxCrossEntropyWithLogits from mindspore import Tensor @@ -115,7 +116,7 @@ class IterativeGradientMethod(Attack): bounds (tuple): Upper and lower bounds of data, indicating the data range. In form of (clip_min, clip_max). Default: (0.0, 1.0). nb_iter (int): Number of iteration. Default: 5. - loss_fn (Loss): Loss function for optimization. + loss_fn (Loss): Loss function for optimization. Default: None. """ def __init__(self, network, eps=0.3, eps_iter=0.1, bounds=(0.0, 1.0), nb_iter=5, loss_fn=None): @@ -178,14 +179,13 @@ class BasicIterativeMethod(IterativeGradientMethod): is_targeted (bool): If True, targeted attack. If False, untargeted attack. Default: False. nb_iter (int): Number of iteration. Default: 5. - loss_fn (Loss): Loss function for optimization. + loss_fn (Loss): Loss function for optimization. Default: None. attack (class): The single step gradient method of each iteration. In this class, FGSM is used. Examples: >>> attack = BasicIterativeMethod(network) """ - def __init__(self, network, eps=0.3, eps_iter=0.1, bounds=(0.0, 1.0), is_targeted=False, nb_iter=5, loss_fn=None): super(BasicIterativeMethod, self).__init__(network, @@ -227,14 +227,22 @@ class BasicIterativeMethod(IterativeGradientMethod): clip_min, clip_max = self._bounds clip_diff = clip_max - clip_min for _ in range(self._nb_iter): - adv_x = self._attack.generate(inputs, labels) + if 'self.prob' in globals(): + d_inputs = _transform_inputs(inputs, self.prob) + else: + d_inputs = inputs + adv_x = self._attack.generate(d_inputs, labels) perturs = np.clip(adv_x - arr_x, (0 - self._eps)*clip_diff, self._eps*clip_diff) adv_x = arr_x + perturs inputs = adv_x else: for _ in range(self._nb_iter): - adv_x = self._attack.generate(inputs, labels) + if 'self.prob' in globals(): + d_inputs = _transform_inputs(inputs, self.prob) + else: + d_inputs = inputs + adv_x = self._attack.generate(d_inputs, labels) adv_x = np.clip(adv_x, arr_x - self._eps, arr_x + self._eps) inputs = adv_x return adv_x @@ -261,7 +269,7 @@ class MomentumIterativeMethod(IterativeGradientMethod): decay_factor (float): Decay factor in iterations. Default: 1.0. norm_level (Union[int, numpy.inf]): Order of the norm. Possible values: np.inf, 1 or 2. Default: 'inf'. - loss_fn (Loss): Loss function for optimization. + loss_fn (Loss): Loss function for optimization. Default: None. """ def __init__(self, network, eps=0.3, eps_iter=0.1, bounds=(0.0, 1.0), @@ -303,9 +311,13 @@ class MomentumIterativeMethod(IterativeGradientMethod): clip_min, clip_max = self._bounds clip_diff = clip_max - clip_min for _ in range(self._nb_iter): - gradient = self._gradient(inputs, labels) + if 'self.prob' in globals(): + d_inputs = _transform_inputs(inputs, self.prob) + else: + d_inputs = inputs + gradient = self._gradient(d_inputs, labels) momentum = self._decay_factor*momentum + gradient - adv_x = inputs + self._eps_iter*np.sign(momentum) + adv_x = d_inputs + self._eps_iter*np.sign(momentum) perturs = np.clip(adv_x - arr_x, (0 - self._eps)*clip_diff, self._eps*clip_diff) adv_x = arr_x + perturs @@ -313,12 +325,15 @@ class MomentumIterativeMethod(IterativeGradientMethod): inputs = adv_x else: for _ in range(self._nb_iter): - gradient = self._gradient(inputs, labels) + if 'self.prob' in globals(): + d_inputs = _transform_inputs(inputs, self.prob) + else: + d_inputs = inputs + gradient = self._gradient(d_inputs, labels) momentum = self._decay_factor*momentum + gradient - adv_x = inputs + self._eps_iter*np.sign(momentum) + adv_x = d_inputs + self._eps_iter*np.sign(momentum) adv_x = np.clip(adv_x, arr_x - self._eps, arr_x + self._eps) inputs = adv_x - return adv_x def _gradient(self, inputs, labels): @@ -372,7 +387,7 @@ class ProjectedGradientDescent(BasicIterativeMethod): nb_iter (int): Number of iteration. Default: 5. norm_level (Union[int, numpy.inf]): Order of the norm. Possible values: np.inf, 1 or 2. Default: 'inf'. - loss_fn (Loss): Loss function for optimization. + loss_fn (Loss): Loss function for optimization. Default: None. """ def __init__(self, network, eps=0.3, eps_iter=0.1, bounds=(0.0, 1.0), @@ -430,3 +445,114 @@ class ProjectedGradientDescent(BasicIterativeMethod): adv_x = np.clip(adv_x, arr_x - self._eps, arr_x + self._eps) inputs = adv_x return adv_x + + +class DiverseInputIterativeMethod(BasicIterativeMethod): + """ + The Diverse Input Iterative Method attack. + + References: `Xie, Cihang and Zhang, et al., "Improving Transferability of + Adversarial Examples With Input Diversity," in CVPR, 2019 `_ + + Args: + network (Cell): Target model. + eps (float): Proportion of adversarial perturbation generated by the + attack to data range. Default: 0.3. + bounds (tuple): Upper and lower bounds of data, indicating the data range. + In form of (clip_min, clip_max). Default: (0.0, 1.0). + is_targeted (bool): If True, targeted attack. If False, untargeted + attack. Default: False. + prob (float): Transformation probability. Default: 0.5. + loss_fn (Loss): Loss function for optimization. Default: None. + """ + def __init__(self, network, eps=0.3, bounds=(0.0, 1.0), + is_targeted=False, prob=0.5, loss_fn=None): + # reference to paper hyper parameters setting. + eps_iter = 16*2/255 + nb_iter = int(min(eps*255 + 4, 1.25*255*eps)) + super(DiverseInputIterativeMethod, self).__init__(network, + eps=eps, + eps_iter=eps_iter, + bounds=bounds, + is_targeted=is_targeted, + nb_iter=nb_iter, + loss_fn=loss_fn) + # FGSM default alpha is None equal alpha=1 + self.prob = check_param_type('prob', prob, float) + + +class MomentumDiverseInputIterativeMethod(MomentumIterativeMethod): + """ + The Momentum Diverse Input Iterative Method attack. + + References: `Xie, Cihang and Zhang, et al., "Improving Transferability of + Adversarial Examples With Input Diversity," in CVPR, 2019 `_ + + Args: + network (Cell): Target model. + eps (float): Proportion of adversarial perturbation generated by the + attack to data range. Default: 0.3. + bounds (tuple): Upper and lower bounds of data, indicating the data range. + In form of (clip_min, clip_max). Default: (0.0, 1.0). + is_targeted (bool): If True, targeted attack. If False, untargeted + attack. Default: False. + norm_level (Union[int, numpy.inf]): Order of the norm. Possible values: + np.inf, 1 or 2. Default: 'l1'. + prob (float): Transformation probability. Default: 0.5. + loss_fn (Loss): Loss function for optimization. Default: None. + """ + def __init__(self, network, eps=0.3, bounds=(0.0, 1.0), + is_targeted=False, norm_level='l1', prob=0.5, loss_fn=None): + eps_iter = 16*2 / 255 + nb_iter = int(min(eps*255 + 4, 1.25*255*eps)) + super(MomentumDiverseInputIterativeMethod, self).__init__(network=network, + eps=eps, + eps_iter=eps_iter, + bounds=bounds, + nb_iter=nb_iter, + is_targeted=is_targeted, + norm_level=norm_level, + loss_fn=loss_fn) + self.prob = check_param_type('prob', prob, float) + + +def _transform_inputs(inputs, prob, low=29, high=33, full_aug=False): + """ + Inputs data augmentation. + + Args: + inputs (Union[np.int8, np.float]): Inputs. + prob (float): The probability of augmentation. + low (int): Lower bound of resize image width. Default: 29. + high (int): Upper bound of resize image height. Default: 33. + full_aug (bool): type of augmentation method, use interpolation and padding + as default. Default: False. + + Returns: + numpy.ndarray, the augmentation data. + """ + raw_shape = inputs[0].shape + tran_mask = np.random.uniform(0, 1, size=inputs.shape[0]) < prob + tran_inputs = inputs[tran_mask] + raw_inputs = inputs[tran_mask == 0] + tran_outputs = [] + for sample in tran_inputs: + width = np.random.choice(np.arange(low, high)) + # resize + sample = (sample*255).astype(np.uint8) + d_image = Image.fromarray(sample, mode='L').resize((width, width), Image.NEAREST) + # pad + left_pad = (raw_shape[0] - width) // 2 + right_pad = raw_shape[0] - width - left_pad + top_pad = (raw_shape[1] - width) // 2 + bottom_pad = raw_shape[1] - width - top_pad + p_sample = ImageOps.expand(d_image, + border=(left_pad, top_pad, right_pad, bottom_pad)) + tran_outputs.append(np.array(p_sample).astype(np.float) / 255) + if full_aug: + # gaussian noise + tran_outputs = np.random.normal(tran_outputs.shape) + tran_outputs + tran_outputs.extend(raw_inputs) + if not np.any(tran_outputs-raw_inputs): + LOGGER.error(TAG, 'the transform function does not take effect.') + return tran_outputs diff --git a/mindarmour/utils/_check_param.py b/mindarmour/utils/_check_param.py index 3ac0eba..06a37c6 100644 --- a/mindarmour/utils/_check_param.py +++ b/mindarmour/utils/_check_param.py @@ -242,7 +242,7 @@ def normalize_value(value, norm_level): Raises: NotImplementedError: If norm_level is not in [1, 2 , np.inf, '1', '2', - 'inf] + 'inf', 'l1', 'l2'] """ norm_level = check_norm_level(norm_level) ori_shape = value.shape diff --git a/requirements.txt b/requirements.txt index 30e1b74..39afce4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,7 @@ numpy >= 1.17.0 scipy >= 1.3.3 matplotlib >= 3.1.3 +Pillow >= 2.0.0 pytest >= 4.3.1 wheel >= 0.32.0 setuptools >= 40.8.0 diff --git a/setup.py b/setup.py index 7173529..c425c2c 100644 --- a/setup.py +++ b/setup.py @@ -95,7 +95,8 @@ setup( install_requires=[ 'scipy >= 1.3.3', 'numpy >= 1.17.0', - 'matplotlib >= 3.1.3' + 'matplotlib >= 3.1.3', + 'Pillow >= 2.0.0' ], ) print(find_packages()) diff --git a/tests/ut/python/attacks/test_iterative_gradient_method.py b/tests/ut/python/attacks/test_iterative_gradient_method.py index 8a0b580..9a766e2 100644 --- a/tests/ut/python/attacks/test_iterative_gradient_method.py +++ b/tests/ut/python/attacks/test_iterative_gradient_method.py @@ -25,6 +25,8 @@ from mindarmour.attacks import BasicIterativeMethod from mindarmour.attacks import MomentumIterativeMethod from mindarmour.attacks import ProjectedGradientDescent from mindarmour.attacks import IterativeGradientMethod +from mindarmour.attacks import DiverseInputIterativeMethod +from mindarmour.attacks import MomentumDiverseInputIterativeMethod context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") @@ -91,7 +93,7 @@ def test_momentum_iterative_method(): for i in range(5): attack = MomentumIterativeMethod(Net(), nb_iter=i + 1) ms_adv_x = attack.generate(input_np, label) - assert np.any(ms_adv_x != input_np), 'Basic iterative method: generate' \ + assert np.any(ms_adv_x != input_np), 'Momentum iterative method: generate' \ ' value must not be equal to' \ ' original value.' @@ -119,6 +121,48 @@ def test_projected_gradient_descent_method(): ' original value.' +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_card +@pytest.mark.component_mindarmour +def test_diverse_input_iterative_method(): + """ + Diverse input iterative method unit test. + """ + input_np = np.asarray([[0.1, 0.2, 0.7]], np.float32) + label = np.asarray([2], np.int32) + label = np.eye(3)[label].astype(np.float32) + + for i in range(5): + attack = DiverseInputIterativeMethod(Net()) + ms_adv_x = attack.generate(input_np, label) + assert np.any(ms_adv_x != input_np), 'Diverse input iterative method: generate' \ + ' value must not be equal to' \ + ' original value.' + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_card +@pytest.mark.component_mindarmour +def test_momentum_diverse_input_iterative_method(): + """ + Momentum diverse input iterative method unit test. + """ + input_np = np.asarray([[0.1, 0.2, 0.7]], np.float32) + label = np.asarray([2], np.int32) + label = np.eye(3)[label].astype(np.float32) + + for i in range(5): + attack = MomentumDiverseInputIterativeMethod(Net()) + ms_adv_x = attack.generate(input_np, label) + assert np.any(ms_adv_x != input_np), 'Momentum diverse input iterative method: ' \ + 'generate value must not be equal to' \ + ' original value.' + + @pytest.mark.level0 @pytest.mark.platform_arm_ascend_training @pytest.mark.platform_x86_ascend_training -- GitLab