未验证 提交 66d3b668 编写于 作者: G gx-wind 提交者: GitHub

Merge pull request #841 from buaawht/new_method

Add MomentumIteratorAttack to Advbox
......@@ -54,6 +54,7 @@ The structure of Advbox module are as follows:
| ├── mnist_tutorial_fgsm.py
| ├── mnist_tutorial_bim.py
| ├── mnist_tutorial_ilcm.py
| ├── mnist_tutorial_mifgsm.py
| ├── mnist_tutorial_jsma.py
| └── mnist_tutorial_deepfool.py
└── README.md
......@@ -77,6 +78,7 @@ The `./tutorials/` folder provides some tutorials to generate adversarial exampl
* [FGSM](https://arxiv.org/abs/1412.6572)
* [BIM](https://arxiv.org/abs/1607.02533)
* [ILCM](https://arxiv.org/abs/1607.02533)
* [MI-FGSM](https://arxiv.org/pdf/1710.06081.pdf)
* [JSMA](https://arxiv.org/pdf/1511.07528)
* [DeepFool](https://arxiv.org/abs/1511.04599)
......@@ -91,6 +93,7 @@ Benchmarks on a vanilla CNN model.
|FGSM| 57.8% | 26.55% | 0.3 | One shot| *** |
|BIM| 97.4% | --- | 0.1 | 100 | **** |
|ILCM| --- | 100.0% | 0.1 | 100 | **** |
|MI-FGSM| 94.4% | 100.0% | 0.1 | 100 | **** |
|JSMA| 96.8% | 90.4%| 0.1 | 2000 | *** |
|DeepFool| 97.7% | 51.3% | --- | 100 | **** |
......@@ -101,8 +104,9 @@ Benchmarks on a vanilla CNN model.
* [Intriguing properties of neural networks](https://arxiv.org/abs/1312.6199), C. Szegedy et al., arxiv 2014
* [Explaining and Harnessing Adversarial Examples](https://arxiv.org/abs/1412.6572), I. Goodfellow et al., ICLR 2015
* [Adversarial Examples In The Physical World](https://arxiv.org/pdf/1607.02533v3.pdf), A. Kurakin et al., ICLR workshop 2017
* [Boosting Adversarial Attacks with Momentum](https://arxiv.org/abs/1710.06081), Yinpeng Dong et al., arxiv 2018
* [The Limitations of Deep Learning in Adversarial Settings](https://arxiv.org/abs/1511.07528), N. Papernot et al., ESSP 2016
* [DeepFool: a simple and accurate method to fool deep neural networks](https://arxiv.org/abs/1511.04599), S. Moosavi-Dezfooli et al., CVPR 2016
* [Foolbox: A Python toolbox to benchmark the robustness of machine learning models] (https://arxiv.org/abs/1707.04131), Jonas Rauber et al., arxiv 2018
* [Foolbox: A Python toolbox to benchmark the robustness of machine learning models](https://arxiv.org/abs/1707.04131), Jonas Rauber et al., arxiv 2018
* [CleverHans: An adversarial example library for constructing attacks, building defenses, and benchmarking both](https://github.com/tensorflow/cleverhans#setting-up-cleverhans)
* [Threat of Adversarial Attacks on Deep Learning in Computer Vision: A Survey](https://arxiv.org/abs/1801.00553), Naveed Akhtar, Ajmal Mian, arxiv 2018
......@@ -14,7 +14,8 @@ __all__ = [
'GradientMethodAttack', 'FastGradientSignMethodAttack', 'FGSM',
'FastGradientSignMethodTargetedAttack', 'FGSMT',
'BasicIterativeMethodAttack', 'BIM',
'IterativeLeastLikelyClassMethodAttack', 'ILCM'
'IterativeLeastLikelyClassMethodAttack', 'ILCM', 'MomentumIteratorAttack',
'MIFGSM'
]
......@@ -76,9 +77,9 @@ class GradientMethodAttack(Attack):
for epsilon in epsilons[:]:
step = 1
adv_img = adversary.original
if epsilon == 0.0:
continue
for i in range(steps):
if epsilon == 0.0:
continue
if adversary.is_targeted_attack:
gradient = -self.model.gradient(adv_img,
adversary.target_label)
......@@ -175,7 +176,103 @@ class BasicIterativeMethodAttack(IterativeLeastLikelyClassMethodAttack):
super(BasicIterativeMethodAttack, self).__init__(model, False)
class MomentumIteratorAttack(GradientMethodAttack):
"""
The Momentum Iterative Fast Gradient Sign Method (Dong et al. 2017).
This method won the first places in NIPS 2017 Non-targeted Adversarial
Attacks and Targeted Adversarial Attacks. The original paper used
hard labels for this attack; no label smoothing. inf norm.
Paper link: https://arxiv.org/pdf/1710.06081.pdf
"""
def __init__(self, model, support_targeted=True):
"""
:param model(model): The model to be attacked.
:param support_targeted(bool): Does this attack method support targeted.
"""
super(MomentumIteratorAttack, self).__init__(model)
self.support_targeted = support_targeted
def _apply(self,
adversary,
norm_ord=np.inf,
epsilons=0.1,
steps=100,
epsilon_steps=100,
decay_factor=1):
"""
Apply the momentum iterative gradient attack method.
:param adversary(Adversary):
The Adversary object.
:param norm_ord(int):
Order of the norm, such as np.inf, 1, 2, etc. It can't be 0.
:param epsilons(list|tuple|float):
Attack step size (input variation).
Largest step size if epsilons is not iterable.
:param epsilon_steps:
The number of Epsilons' iteration for each attack iteration.
:param steps:
The number of attack iteration.
:param decay_factor:
The decay factor for the momentum term.
:return:
adversary(Adversary): The Adversary object.
"""
if norm_ord == 0:
raise ValueError("L0 norm is not supported!")
if not self.support_targeted:
if adversary.is_targeted_attack:
raise ValueError(
"This attack method doesn't support targeted attack!")
assert self.model.channel_axis() == adversary.original.ndim
assert (self.model.channel_axis() == 1 or
self.model.channel_axis() == adversary.original.shape[0] or
self.model.channel_axis() == adversary.original.shape[-1])
if not isinstance(epsilons, Iterable):
epsilons = np.linspace(0, epsilons, num=epsilon_steps)
min_, max_ = self.model.bounds()
pre_label = adversary.original_label
for epsilon in epsilons[:]:
if epsilon == 0.0:
continue
step = 1
adv_img = adversary.original
momentum = 0
for i in range(steps):
if adversary.is_targeted_attack:
gradient = -self.model.gradient(adv_img,
adversary.target_label)
else:
gradient = self.model.gradient(adv_img, pre_label)
# normalize gradient
velocity = gradient / self._norm(gradient, ord=1)
momentum = decay_factor * momentum + velocity
if norm_ord == np.inf:
normalized_grad = np.sign(momentum)
else:
normalized_grad = self._norm(momentum, ord=norm_ord)
perturbation = epsilon * normalized_grad
adv_img = adv_img + perturbation
adv_img = np.clip(adv_img, min_, max_)
adv_label = np.argmax(self.model.predict(adv_img))
logging.info(
'step={}, epsilon = {:.5f}, pre_label = {}, adv_label={}'
.format(step, epsilon, pre_label, adv_label))
if adversary.try_accept_the_example(adv_img, adv_label):
return adversary
step += 1
return adversary
FGSM = FastGradientSignMethodAttack
FGSMT = FastGradientSignMethodTargetedAttack
BIM = BasicIterativeMethodAttack
ILCM = IterativeLeastLikelyClassMethodAttack
MIFGSM = MomentumIteratorAttack
"""
MIFGSM tutorial on mnist using advbox tool.
MIFGSM is a broad class of momentum iterative gradient-based methods based on FSGM.
It supports non-targeted attack and targeted attack.
"""
import sys
sys.path.append("..")
import matplotlib.pyplot as plt
import numpy as np
import paddle.fluid as fluid
import paddle.v2 as paddle
from advbox.adversary import Adversary
from advbox.attacks.gradient_method import MIFGSM
from advbox.models.paddle import PaddleModel
from tutorials.mnist_model import mnist_cnn_model
def main():
"""
Advbox demo which demonstrate how to use advbox.
"""
TOTAL_NUM = 500
IMG_NAME = 'img'
LABEL_NAME = 'label'
img = fluid.layers.data(name=IMG_NAME, shape=[1, 28, 28], dtype='float32')
# gradient should flow
img.stop_gradient = False
label = fluid.layers.data(name=LABEL_NAME, shape=[1], dtype='int64')
logits = mnist_cnn_model(img)
cost = fluid.layers.cross_entropy(input=logits, label=label)
avg_cost = fluid.layers.mean(x=cost)
# use CPU
place = fluid.CPUPlace()
# use GPU
# place = fluid.CUDAPlace(0)
exe = fluid.Executor(place)
BATCH_SIZE = 1
train_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.mnist.train(), buf_size=128 * 10),
batch_size=BATCH_SIZE)
test_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.mnist.test(), buf_size=128 * 10),
batch_size=BATCH_SIZE)
fluid.io.load_params(
exe, "./mnist/", main_program=fluid.default_main_program())
# advbox demo
m = PaddleModel(
fluid.default_main_program(),
IMG_NAME,
LABEL_NAME,
logits.name,
avg_cost.name, (-1, 1),
channel_axis=1)
attack = MIFGSM(m)
attack_config = {
"norm_ord": np.inf,
"epsilons": 0.1,
"steps": 100,
"decay_factor": 1
}
# use train data to generate adversarial examples
total_count = 0
fooling_count = 0
for data in train_reader():
total_count += 1
adversary = Adversary(data[0][0], data[0][1])
# MIFGSM non-targeted attack
adversary = attack(adversary, **attack_config)
# MIFGSM targeted attack
# tlabel = 0
# adversary.set_target(is_targeted_attack=True, target_label=tlabel)
# adversary = attack(adversary, **attack_config)
if adversary.is_successful():
fooling_count += 1
print(
'attack success, original_label=%d, adversarial_label=%d, count=%d'
% (data[0][1], adversary.adversarial_label, total_count))
# plt.imshow(adversary.target, cmap='Greys_r')
# plt.show()
# np.save('adv_img', adversary.target)
else:
print('attack failed, original_label=%d, count=%d' %
(data[0][1], total_count))
if total_count >= TOTAL_NUM:
print(
"[TRAIN_DATASET]: fooling_count=%d, total_count=%d, fooling_rate=%f"
% (fooling_count, total_count,
float(fooling_count) / total_count))
break
# use test data to generate adversarial examples
total_count = 0
fooling_count = 0
for data in test_reader():
total_count += 1
adversary = Adversary(data[0][0], data[0][1])
# MIFGSM non-targeted attack
adversary = attack(adversary, **attack_config)
# MIFGSM targeted attack
# tlabel = 0
# adversary.set_target(is_targeted_attack=True, target_label=tlabel)
# adversary = attack(adversary, **attack_config)
if adversary.is_successful():
fooling_count += 1
print(
'attack success, original_label=%d, adversarial_label=%d, count=%d'
% (data[0][1], adversary.adversarial_label, total_count))
# plt.imshow(adversary.target, cmap='Greys_r')
# plt.show()
# np.save('adv_img', adversary.target)
else:
print('attack failed, original_label=%d, count=%d' %
(data[0][1], total_count))
if total_count >= TOTAL_NUM:
print(
"[TEST_DATASET]: fooling_count=%d, total_count=%d, fooling_rate=%f"
% (fooling_count, total_count,
float(fooling_count) / total_count))
break
print("mifgsm attack done")
if __name__ == '__main__':
main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册