Unverified commit 7a0013c5 authored by wgzqz, committed by GitHub

Merge pull request #602 from guangzhuwu/develop

Add deepfool.
......@@ -61,7 +61,7 @@ class Adversary(object):
def _is_successful(self, adversarial_label):
"""
Check whether adversarial_label is the expected adversarial label.
:param adversarial_label: adversarial label.
:return: bool
"""
......@@ -87,10 +87,12 @@ class Adversary(object):
:return: bool
"""
assert adversarial_example is not None
+ assert self.__original.shape == adversarial_example.shape
ok = self._is_successful(adversarial_label)
if ok:
- self.__adversarial_example = adversarial_example.reshape(
-     self.__original.shape)
+ self.__adversarial_example = adversarial_example
self.__adversarial_label = adversarial_label
return ok
......
......@@ -3,6 +3,7 @@ Attack methods
"""
from .base import Attack
+ from .deepfool import DeepFoolAttack
from .gradientsign import FGSM
from .gradientsign import GradientSignAttack
from .iterator_gradientsign import IFGSM
......
......@@ -54,7 +54,7 @@ class Attack(object):
"""
if adversary.original_label is None:
adversary.original_label = np.argmax(
- self.model.predict([(adversary.original, 0)]))
+ self.model.predict(adversary.original))
if adversary.is_targeted_attack and adversary.target_label is None:
if adversary.target is None:
raise ValueError(
......@@ -62,7 +62,8 @@ class Attack(object):
'adversary.target_label or adversary.target must be set.')
else:
adversary.target_label = np.argmax(
-     self.model.predict([(adversary.target_label, 0)]))
+     self.model.predict(
+         self.model.scale_input(adversary.target)))
logging.info('adversary:\noriginal_label: {}'
'\n target_label: {}'
......
"""
This module provides the DeepFool attack method. DeepFool is a simple and
accurate adversarial attack.
"""
from __future__ import division
import logging
import numpy as np
from .base import Attack
class DeepFoolAttack(Attack):
"""
"DeepFool: a simple and accurate method to fool deep neural networks",
Seyed-Mohsen Moosavi-Dezfooli, Alhussein Fawzi, Pascal Frossard,
https://arxiv.org/abs/1511.04599
"""
def _apply(self, adversary, iterations=100, overshoot=0.02):
"""
Apply the DeepFool attack.
Args:
adversary(Adversary): The Adversary object.
iterations(int): The maximum number of iterations.
overshoot(float): The perturbation added at each step is scaled by (1 + overshoot).
Return:
adversary: The Adversary object.
"""
assert adversary is not None
pre_label = adversary.original_label
min_, max_ = self.model.bounds()
f = self.model.predict(adversary.original)
if adversary.is_targeted_attack:
labels = [adversary.target_label]
else:
max_class_count = 10
class_count = self.model.num_classes()
if class_count > max_class_count:
labels = np.argsort(f)[-(max_class_count + 1):-1]
else:
labels = np.arange(class_count)
gradient = self.model.gradient(adversary.original, pre_label)
x = adversary.original
for iteration in xrange(iterations):
w = np.inf
w_norm = np.inf
pert = np.inf
for k in labels:
if k == pre_label:
continue
gradient_k = self.model.gradient(x, k)
w_k = gradient_k - gradient
f_k = f[k] - f[pre_label]
w_k_norm = np.linalg.norm(w_k) + 1e-8
pert_k = (np.abs(f_k) + 1e-8) / w_k_norm
if pert_k < pert:
pert = pert_k
w = w_k
w_norm = w_k_norm
r_i = -w * pert / w_norm # The gradient is -gradient in the paper.
x = x + (1 + overshoot) * r_i
x = np.clip(x, min_, max_)
f = self.model.predict(x)
gradient = self.model.gradient(x, pre_label)
adv_label = np.argmax(f)
logging.info('iteration = {}, f = {}, pre_label = {}'
', adv_label={}'.format(iteration, f[pre_label],
pre_label, adv_label))
if adversary.try_accept_the_example(x, adv_label):
return adversary
return adversary
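The loop above searches, at each iteration, for the closest decision boundary among the candidate labels and steps just across it. For reference, here is a self-contained NumPy sketch of a single such step on a toy linear classifier; the names W, b and x0 are made up for the illustration, and because it uses raw classifier gradients (the gradient of logit k is simply W[k]) rather than the cross-entropy loss gradients returned by model.gradient, the step direction is +w instead of the -w used in the code above.

import numpy as np

np.random.seed(0)
num_classes, dim = 5, 10
W = np.random.randn(num_classes, dim)   # toy linear classifier: f(x) = W.dot(x) + b
b = np.random.randn(num_classes)
x0 = np.random.randn(dim)

f = W.dot(x0) + b
pre_label = int(np.argmax(f))
grad_pre = W[pre_label]                  # gradient of f[pre_label] w.r.t. x

pert, w, w_norm = np.inf, None, None
for k in range(num_classes):
    if k == pre_label:
        continue
    w_k = W[k] - grad_pre                # difference of class gradients
    f_k = f[k] - f[pre_label]            # difference of class scores
    w_k_norm = np.linalg.norm(w_k) + 1e-8
    pert_k = (abs(f_k) + 1e-8) / w_k_norm
    if pert_k < pert:                    # keep the closest decision boundary
        pert, w, w_norm = pert_k, w_k, w_k_norm

overshoot = 0.02
r = w * pert / w_norm                    # minimal step towards that boundary
x1 = x0 + (1 + overshoot) * r            # overshoot slightly to cross it
print('label before: %d, label after: %d'
      % (pre_label, int(np.argmax(W.dot(x1) + b))))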
......@@ -37,19 +37,18 @@ class GradientSignAttack(Attack):
min_, max_ = self.model.bounds()
if adversary.is_targeted_attack:
- gradient = self.model.gradient([(adversary.original,
-     adversary.target_label)])
+ gradient = self.model.gradient(adversary.original,
+     adversary.target_label)
gradient_sign = -np.sign(gradient) * (max_ - min_)
else:
- gradient = self.model.gradient([(adversary.original,
-     adversary.original_label)])
+ gradient = self.model.gradient(adversary.original,
+     adversary.original_label)
gradient_sign = np.sign(gradient) * (max_ - min_)
- original = adversary.original.reshape(gradient_sign.shape)
for epsilon in epsilons:
- adv_img = original + epsilon * gradient_sign
+ adv_img = adversary.original + epsilon * gradient_sign
adv_img = np.clip(adv_img, min_, max_)
- adv_label = np.argmax(self.model.predict([(adv_img, 0)]))
+ adv_label = np.argmax(self.model.predict(adv_img))
logging.info('epsilon = {:.3f}, pre_label = {}, adv_label={}'.
format(epsilon, pre_label, adv_label))
if adversary.try_accept_the_example(adv_img, adv_label):
......
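The untargeted branch above reduces to a single sign step scaled by the input range, followed by clipping back into the model bounds. A minimal NumPy sketch with made-up values for the bounds, epsilon, image and gradient:

import numpy as np

min_, max_ = 0.0, 1.0                            # model.bounds()
epsilon = 0.1
original = np.array([[0.2, 0.8], [0.5, 0.1]])    # toy "image"
gradient = np.array([[0.3, -0.7], [0.0, 0.4]])   # toy loss gradient w.r.t. the input

# Untargeted FGSM: move in the direction that increases the loss for the
# original label, scaled by the value range, then clip back into bounds.
gradient_sign = np.sign(gradient) * (max_ - min_)
adv_img = np.clip(original + epsilon * gradient_sign, min_, max_)
print(adv_img)

The targeted branch only flips the sign, stepping towards lower loss for the target label instead.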
......@@ -35,21 +35,19 @@ class IteratorGradientSignAttack(Attack):
min_, max_ = self.model.bounds()
for epsilon in epsilons:
- adv_img = None
+ adv_img = adversary.original
for _ in range(steps):
if adversary.is_targeted_attack:
- gradient = self.model.gradient([(adversary.original,
-     adversary.target_label)])
+ gradient = self.model.gradient(adversary.original,
+     adversary.target_label)
gradient_sign = -np.sign(gradient) * (max_ - min_)
else:
- gradient = self.model.gradient([(adversary.original,
-     adversary.original_label)])
+ gradient = self.model.gradient(adversary.original,
+     adversary.original_label)
gradient_sign = np.sign(gradient) * (max_ - min_)
- if adv_img is None:
-     adv_img = adversary.original.reshape(gradient_sign.shape)
adv_img = adv_img + gradient_sign * epsilon
adv_img = np.clip(adv_img, min_, max_)
- adv_label = np.argmax(self.model.predict([(adv_img, 0)]))
+ adv_label = np.argmax(self.model.predict(adv_img))
logging.info('epsilon = {:.3f}, pre_label = {}, adv_label={}'.
format(epsilon, pre_label, adv_label))
if adversary.try_accept_the_example(adv_img, adv_label):
......
......@@ -43,26 +43,31 @@ class Model(object):
return self._channel_axis
def _process_input(self, input_):
- res = input_
+ res = None
sub, div = self._preprocess
if np.any(sub != 0):
res = input_ - sub
assert np.any(div != 0)
if np.any(div != 1):
- res /= div
+ if res is None:  # "res = input_ - sub" is not executed!
+     res = input_ / div
+ else:
+     res /= div
+ if res is None:  # "res = (input_ - sub) / div" is not executed!
+     return input_
return res
@abstractmethod
- def predict(self, image_batch):
+ def predict(self, data):
"""
- Calculate the prediction of the image batch.
+ Calculate the prediction of the data.
Args:
- image_batch(numpy.ndarray): image batch of shape (batch_size,
+ data(numpy.ndarray): input data with shape (size,
height, width, channels).
Return:
- numpy.ndarray: predictions of the images with shape (batch_size,
+ numpy.ndarray: predictions of the data with shape (batch_size,
num_of_classes).
"""
raise NotImplementedError
......@@ -78,12 +83,14 @@ class Model(object):
raise NotImplementedError
@abstractmethod
- def gradient(self, image_batch):
+ def gradient(self, data, label):
"""
Calculate the gradient of the cross-entropy loss w.r.t the image.
Args:
- image_batch(list): The image and label tuple list.
+ data(numpy.ndarray): input data with shape (size, height, width,
+     channels).
+ label(int): Label used to calculate the gradient.
Return:
numpy.ndarray: gradient of the cross-entropy loss w.r.t the image
......
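The rewritten _process_input still computes (input_ - sub) / div, but it now skips whichever steps are no-ops and, unlike the old "res = input_" followed by "res /= div", no longer divides the caller's array in place when sub is zero. A standalone sketch of the same branching (process_input, sub and div are illustrative stand-ins for the method and the self._preprocess pair):

import numpy as np

def process_input(input_, sub, div):
    # Mirror of the new _process_input logic: (input_ - sub) / div,
    # skipping no-op steps and leaving input_ untouched.
    res = None
    if np.any(sub != 0):
        res = input_ - sub
    assert np.any(div != 0)
    if np.any(div != 1):
        if res is None:              # the subtraction step was skipped
            res = input_ / div
        else:
            res /= div               # safe: res is already a fresh array
    if res is None:                  # neither step ran
        return input_
    return res

x = np.array([0.0, 127.5, 255.0])
print(process_input(x, sub=127.5, div=127.5))    # -> [-1.  0.  1.]
print(process_input(x, sub=0.0, div=1.0) is x)   # -> True, no copy was made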
......@@ -3,6 +3,7 @@ Paddle model
"""
from __future__ import absolute_import
import numpy as np
import paddle.v2.fluid as fluid
from .base import Model
......@@ -54,24 +55,28 @@ class PaddleModel(Model):
self._gradient = filter(lambda p: p[0].name == self._input_name,
param_grads)[0][1]
- def predict(self, image_batch):
+ def predict(self, data):
"""
- Predict the label of the image_batch.
+ Calculate the prediction of the data.
- Args:
-     image_batch(list): The image and label tuple list.
- Return:
-     numpy.ndarray: predictions of the images with shape (batch_size,
-     num_of_classes).
+ Args:
+     data(numpy.ndarray): input data with shape (size,
+         height, width, channels).
+ Return:
+     numpy.ndarray: predictions of the data with shape (batch_size,
+     num_of_classes).
"""
+ scaled_data = self._process_input(data)
feeder = fluid.DataFeeder(
feed_list=[self._input_name, self._logits_name],
place=self._place,
program=self._program)
predict_var = self._program.block(0).var(self._predict_name)
predict = self._exe.run(self._program,
- feed=feeder.feed(image_batch),
+ feed=feeder.feed([(scaled_data, 0)]),
fetch_list=[predict_var])
predict = np.squeeze(predict, axis=0)
return predict
def num_classes(self):
......@@ -85,21 +90,27 @@ class PaddleModel(Model):
assert len(predict_var.shape) == 2
return predict_var.shape[1]
- def gradient(self, image_batch):
+ def gradient(self, data, label):
"""
- Calculate the gradient of the loss w.r.t the input.
+ Calculate the gradient of the cross-entropy loss w.r.t the image.
Args:
- image_batch(list): The image and label tuple list.
+ data(numpy.ndarray): input data with shape (size, height, width,
+     channels).
+ label(int): Label used to calculate the gradient.
Return:
- list: The list of the gradient of the image.
+ numpy.ndarray: gradient of the cross-entropy loss w.r.t the image
+     with the shape (height, width, channel).
"""
+ scaled_data = self._process_input(data)
feeder = fluid.DataFeeder(
feed_list=[self._input_name, self._logits_name],
place=self._place,
program=self._program)
grad, = self._exe.run(self._program,
- feed=feeder.feed(image_batch),
+ feed=feeder.feed([(scaled_data, label)]),
fetch_list=[self._gradient])
- return grad
+ return grad.reshape(data.shape)
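Taken together, the attacks interact with a model chiefly through bounds(), num_classes(), predict(data) returning per-class scores for a single unscaled input, and gradient(data, label) returning a loss gradient with the same shape as the input. A toy NumPy stand-in that honours this contract (an illustration only, not the real fluid-backed PaddleModel) might look like:

import numpy as np

class ToyModel(object):
    """Toy stand-in for the model interface the attacks use:
    a linear softmax classifier over flattened inputs."""

    def __init__(self, weights, bias, bounds=(0.0, 1.0)):
        self._w = weights          # shape: (num_classes, input_dim)
        self._b = bias             # shape: (num_classes,)
        self._bounds = bounds

    def bounds(self):
        return self._bounds

    def num_classes(self):
        return self._w.shape[0]

    def predict(self, data):
        logits = self._w.dot(data.reshape(-1)) + self._b
        exp = np.exp(logits - logits.max())
        return exp / exp.sum()     # 1-D class scores, as f[pre_label] above expects

    def gradient(self, data, label):
        # Cross-entropy loss gradient w.r.t. the input, reshaped back to the
        # input's shape (mirroring grad.reshape(data.shape) above).
        probs = self.predict(data)
        one_hot = np.eye(self.num_classes())[label]
        return (probs - one_hot).dot(self._w).reshape(data.shape)

np.random.seed(0)
model = ToyModel(np.random.randn(3, 4), np.zeros(3))
img = np.random.rand(2, 2)
print(model.predict(img))              # scores over 3 classes
print(model.gradient(img, 0).shape)    # (2, 2), same shape as img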