提交 ca131878 编写于 作者: wgzqz's avatar wgzqz

Add targeted attack methods

上级 d824116b
"""
A set of tools for generating adversarial example on paddle platform
A set of tools for generating adversarial example on paddle platform
"""
from . import attacks # type: ignore # noqa: F401
from . import models # type: ignore # noqa: F401
from .adversary import Adversary # noqa: F401
"""
The base model of the model.
"""
from abc import ABCMeta, abstractmethod
import logging
from abc import ABCMeta
from abc import abstractmethod
import numpy as np
class Attack(object):
"""
Abstract base class for adversarial attacks. `Attack` represent an adversarial attack
which search an adversarial example. subclass should implement the _apply() method.
Abstract base class for adversarial attacks. `Attack` represent an
adversarial attack which search an adversarial example. subclass should
implement the _apply() method.
Args:
model(Model): an instance of the class advbox.base.Model.
......@@ -18,22 +23,48 @@ class Attack(object):
def __init__(self, model):
self.model = model
def __call__(self, image_label):
def __call__(self, adversary, **kwargs):
"""
Generate the adversarial sample.
Args:
image_label(list): The image and label tuple list with one element.
adversary(object): The adversary object.
**kwargs: Other params.
"""
adv_img = self._apply(image_label)
return adv_img
self._preprocess(adversary)
return self._apply(adversary, **kwargs)
@abstractmethod
def _apply(self, image_label):
def _apply(self, adversary):
"""
Search an adversarial example.
Args:
image_batch(list): The image and label tuple list with one element.
adversary(object): The adversary object.
"""
raise NotImplementedError
def _preprocess(self, adversary):
"""
Preprocess the adversary object.
:param adversary: adversary
:return: None
"""
if adversary.original_label is None:
adversary.original_label = np.argmax(
self.model.predict([(adversary.original, 0)]))
if adversary.is_targeted_attack and adversary.target_label is None:
if adversary.target is None:
raise ValueError(
'When adversary.is_targeted_attack is True, '
'adversary.target_label or adversary.target must be set.')
else:
adversary.target_label_label = np.argmax(
self.model.predict([(adversary.target_label, 0)]))
logging.info('adversary:\noriginal_label: {}'
'\n target_lable: {}'
'\n is_targeted_attack: {}'.format(
adversary.original_label, adversary.target_label,
adversary.is_targeted_attack))
......@@ -2,37 +2,50 @@
This module provide the attack method for FGSM's implement.
"""
from __future__ import division
import numpy as np
import logging
from collections import Iterable
import numpy as np
from .base import Attack
class GradientSignAttack(Attack):
"""
This attack was originally implemented by Goodfellow et al. (2015) with the
infinity norm (and is known as the "Fast Gradient Sign Method"). This is therefore called
the Fast Gradient Method.
infinity norm (and is known as the "Fast Gradient Sign Method").
This is therefore called the Fast Gradient Method.
Paper link: https://arxiv.org/abs/1412.6572
"""
def _apply(self, image_label, epsilons=1000):
assert len(image_label) == 1
pre_label = np.argmax(self.model.predict(image_label))
min_, max_ = self.model.bounds()
gradient = self.model.gradient(image_label)
gradient_sign = np.sign(gradient) * (max_ - min_)
def _apply(self, adversary, epsilons=1000):
assert adversary is not None
if not isinstance(epsilons, Iterable):
epsilons = np.linspace(0, 1, num=epsilons + 1)
epsilons = np.linspace(0, 1, num=epsilons + 1)[1:]
pre_label = adversary.original_label
min_, max_ = self.model.bounds()
if adversary.is_targeted_attack:
gradient = self.model.gradient([(adversary.original,
adversary.target_label)])
gradient_sign = -np.sign(gradient) * (max_ - min_)
else:
gradient = self.model.gradient([(adversary.original,
adversary.original_label)])
gradient_sign = np.sign(gradient) * (max_ - min_)
for epsilon in epsilons:
adv_img = image_label[0][0].reshape(
gradient_sign.shape) + epsilon * gradient_sign
adv_img = adversary.original + epsilon * gradient_sign
adv_img = np.clip(adv_img, min_, max_)
adv_label = np.argmax(self.model.predict([(adv_img, 0)]))
if pre_label != adv_label:
return adv_img
logging.info('epsilon = {:.3f}, pre_label = {}, adv_label={}'.
format(epsilon, pre_label, adv_label))
if adversary.try_accept_the_example(adv_img, adv_label):
return adversary
return adversary
FGSM = GradientSignAttack
......@@ -2,8 +2,12 @@
This module provide the attack method for Iterator FGSM's implement.
"""
from __future__ import division
import numpy as np
import logging
from collections import Iterable
import numpy as np
from .base import Attack
......@@ -13,31 +17,43 @@ class IteratorGradientSignAttack(Attack):
Paper link: https://arxiv.org/pdf/1607.02533.pdf
"""
def _apply(self, image_label, epsilons=100, steps=10):
def _apply(self, adversary, epsilons=100, steps=10):
"""
Apply the iterative gradient sign attack.
Args:
image_label(list): The image and label tuple list of one element.
adversary(object): The image and label tuple list of one element.
epsilons(list|tuple|int): The epsilon (input variation parameter).
steps(int): The number of iterator steps.
Return:
numpy.ndarray: The adversarail sample generated by the algorithm.
"""
assert len(image_label) == 1
pre_label = np.argmax(self.model.predict(image_label))
gradient = self.model.gradient(image_label)
min_, max_ = self.model.bounds()
if not isinstance(epsilons, Iterable):
epsilons = np.linspace(0, 1, num=epsilons + 1)
epsilons = np.linspace(0, 1 / steps, num=epsilons + 1)[1:]
pre_label = adversary.original_label
min_, max_ = self.model.bounds()
for epsilon in epsilons:
adv_img = image_label[0][0].reshape(gradient.shape)
adv_img = adversary.original
for _ in range(steps):
gradient = self.model.gradient([(adv_img, image_label[0][1])])
gradient_sign = np.sign(gradient) * (max_ - min_)
adv_img = adv_img + epsilon * gradient_sign
if adversary.is_targeted_attack:
gradient = self.model.gradient([(adversary.original,
adversary.target_label)])
gradient_sign = -np.sign(gradient) * (max_ - min_)
else:
gradient = self.model.gradient([(adversary.original,
adversary.original_label)])
gradient_sign = np.sign(gradient) * (max_ - min_)
adv_img = adv_img + gradient_sign * epsilon
adv_img = np.clip(adv_img, min_, max_)
adv_label = np.argmax(self.model.predict([(adv_img, 0)]))
if pre_label != adv_label:
return adv_img
logging.info('epsilon = {:.3f}, pre_label = {}, adv_label={}'.
format(epsilon, pre_label, adv_label))
if adversary.try_accept_the_example(adv_img, adv_label):
return adversary
return adversary
IFGSM = IteratorGradientSignAttack
"""
Paddle model for target of attack
Paddle model for target of attack
"""
from .base import Model # noqa: F401
from .paddle import PaddleModel # noqa: F401
......@@ -2,21 +2,21 @@
The base model of the model.
"""
from abc import ABCMeta
import abc
from abc import abstractmethod
abstractmethod = abc.abstractmethod
import numpy as np
class Model(object):
"""
Base class of model to provide attack.
Args:
bounds(tuple): The lower and upper bound for the image pixel.
channel_axis(int): The index of the axis that represents the color channel.
preprocess(tuple): Two element tuple used to preprocess the input. First
substract the first element, then divide the second element.
channel_axis(int): The index of the axis that represents the color
channel.
preprocess(tuple): Two element tuple used to preprocess the input.
First substract the first element, then divide the second element.
"""
__metaclass__ = ABCMeta
......@@ -45,10 +45,10 @@ class Model(object):
def _process_input(self, input_):
res = input_
sub, div = self._preprocess
if sub != 0:
if np.any(sub != 0):
res = input_ - sub
assert div != 0
if div != 1:
assert np.any(div != 0)
if np.any(div != 1):
res /= div
return res
......@@ -58,10 +58,12 @@ class Model(object):
Calculate the prediction of the image batch.
Args:
image_batch(numpy.ndarray): image batch of shape (batch_size, height, width, channels).
image_batch(numpy.ndarray): image batch of shape (batch_size,
height, width, channels).
Return:
numpy.ndarray: predictions of the images with shape (batch_size, num_of_classes).
numpy.ndarray: predictions of the images with shape (batch_size,
num_of_classes).
"""
raise NotImplementedError
......@@ -84,7 +86,7 @@ class Model(object):
image_batch(list): The image and label tuple list.
Return:
numpy.ndarray: gradient of the cross-entropy loss w.r.t the image with
the shape (height, width, channel).
numpy.ndarray: gradient of the cross-entropy loss w.r.t the image
with the shape (height, width, channel).
"""
raise NotImplementedError
from __future__ import absolute_import
import numpy as np
import paddle.v2 as paddle
import paddle.v2.fluid as fluid
from paddle.v2.fluid.framework import program_guard
from .base import Model
......@@ -11,10 +8,12 @@ from .base import Model
class PaddleModel(Model):
"""
Create a PaddleModel instance.
When you need to generate a adversarial sample, you should construct an instance of PaddleModel.
When you need to generate a adversarial sample, you should construct an
instance of PaddleModel.
Args:
program(paddle.v2.fluid.framework.Program): The program of the model which generate the adversarial sample.
program(paddle.v2.fluid.framework.Program): The program of the model
which generate the adversarial sample.
input_name(string): The name of the input.
logits_name(string): The name of the logits.
predict_name(string): The name of the predict.
......@@ -30,12 +29,12 @@ class PaddleModel(Model):
bounds,
channel_axis=3,
preprocess=None):
super(PaddleModel, self).__init__(
bounds=bounds, channel_axis=channel_axis, preprocess=preprocess)
if preprocess is None:
preprocess = (0, 1)
super(PaddleModel, self).__init__(
bounds=bounds, channel_axis=channel_axis, preprocess=preprocess)
self._program = program
self._place = fluid.CPUPlace()
self._exe = fluid.Executor(self._place)
......@@ -58,7 +57,8 @@ class PaddleModel(Model):
Args:
image_batch(list): The image and label tuple list.
Return:
numpy.ndarray: predictions of the images with shape (batch_size, num_of_classes).
numpy.ndarray: predictions of the images with shape (batch_size,
num_of_classes).
"""
feeder = fluid.DataFeeder(
feed_list=[self._input_name, self._logits_name],
......@@ -72,7 +72,7 @@ class PaddleModel(Model):
def num_classes(self):
"""
Calculate the number of classes of the output label.
Calculate the number of classes of the output label.
Return:
int: the number of classes
......
"""
FGSM demos on mnist using advbox tool.
"""
import matplotlib.pyplot as plt
import paddle.v2 as paddle
import paddle.v2.fluid as fluid
import matplotlib.pyplot as plt
import numpy as np
from advbox.models.paddle import PaddleModel
from advbox.attacks.gradientsign import GradientSignAttack
from .advbox import Adversary
from .advbox.attacks.gradientsign import GradientSignAttack
from .advbox.models.paddle import PaddleModel
def cnn_model(img):
......@@ -18,7 +18,7 @@ def cnn_model(img):
Returns:
Variable: the label prediction
"""
#conv1 = fluid.nets.conv2d()
# conv1 = fluid.nets.conv2d()
conv_pool_1 = fluid.nets.simple_img_conv_pool(
input=img,
num_filters=20,
......@@ -76,10 +76,11 @@ def main():
att = GradientSignAttack(m)
for data in train_reader():
# fgsm attack
adv_img = att(data)
plt.imshow(n[0][0], cmap='Greys_r')
plt.show()
#np.save('adv_img', adv_img)
adversary = att(Adversary(data))
if adversary.is_successful():
plt.imshow(adversary.target, cmap='Greys_r')
plt.show()
# np.save('adv_img', adversary.target)
break
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册