diff --git a/fluid/adversarial/README.md b/fluid/adversarial/README.md new file mode 100644 index 0000000000000000000000000000000000000000..51da21918a9d6e2192a2e03eabef4fde97896bc5 --- /dev/null +++ b/fluid/adversarial/README.md @@ -0,0 +1,9 @@ +# Advbox + +Advbox is a Python toolbox to create adversarial examples that fool neural networks. It requires Python and paddle. + +## How to use + +1. train a model and save it's parameters. (like fluid_mnist.py) +2. load the parameters which is trained in step1, then reconstruct the model.(like mnist_tutorial_fgsm.py) +3. use advbox to generate the adversarial sample. diff --git a/fluid/adversarial/advbox/__init__.py b/fluid/adversarial/advbox/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b50f18f7704de4ef52e47ce9f9f80643ed1032f9 --- /dev/null +++ b/fluid/adversarial/advbox/__init__.py @@ -0,0 +1,3 @@ +""" + A set of tools for generating adversarial example on paddle platform +""" diff --git a/fluid/adversarial/advbox/attacks/base.py b/fluid/adversarial/advbox/attacks/base.py new file mode 100644 index 0000000000000000000000000000000000000000..98a65f2fddff999ac6fa98a5733128a63a60f916 --- /dev/null +++ b/fluid/adversarial/advbox/attacks/base.py @@ -0,0 +1,39 @@ +""" +The base model of the model. +""" +from abc import ABCMeta, abstractmethod + + +class Attack(object): + """ + Abstract base class for adversarial attacks. `Attack` represent an adversarial attack + which search an adversarial example. subclass should implement the _apply() method. + + Args: + model(Model): an instance of the class advbox.base.Model. + + """ + __metaclass__ = ABCMeta + + def __init__(self, model): + self.model = model + + def __call__(self, image_label): + """ + Generate the adversarial sample. + + Args: + image_label(list): The image and label tuple list with one element. + """ + adv_img = self._apply(image_label) + return adv_img + + @abstractmethod + def _apply(self, image_label): + """ + Search an adversarial example. + + Args: + image_batch(list): The image and label tuple list with one element. + """ + raise NotImplementedError diff --git a/fluid/adversarial/advbox/attacks/gradientsign.py b/fluid/adversarial/advbox/attacks/gradientsign.py new file mode 100644 index 0000000000000000000000000000000000000000..15b1d176cb11330ac290d73aec1419a3d8f3cc4c --- /dev/null +++ b/fluid/adversarial/advbox/attacks/gradientsign.py @@ -0,0 +1,38 @@ +""" +This module provide the attack method for FGSM's implement. +""" +from __future__ import division +import numpy as np +from collections import Iterable +from .base import Attack + + +class GradientSignAttack(Attack): + """ + This attack was originally implemented by Goodfellow et al. (2015) with the + infinity norm (and is known as the "Fast Gradient Sign Method"). This is therefore called + the Fast Gradient Method. + Paper link: https://arxiv.org/abs/1412.6572 + """ + + def _apply(self, image_label, epsilons=1000): + assert len(image_label) == 1 + pre_label = np.argmax(self.model.predict(image_label)) + + min_, max_ = self.model.bounds() + gradient = self.model.gradient(image_label) + gradient_sign = np.sign(gradient) * (max_ - min_) + + if not isinstance(epsilons, Iterable): + epsilons = np.linspace(0, 1, num=epsilons + 1) + + for epsilon in epsilons: + adv_img = image_label[0][0].reshape( + gradient_sign.shape) + epsilon * gradient_sign + adv_img = np.clip(adv_img, min_, max_) + adv_label = np.argmax(self.model.predict([(adv_img, 0)])) + if pre_label != adv_label: + return adv_img + + +FGSM = GradientSignAttack diff --git a/fluid/adversarial/advbox/models/__init__.py b/fluid/adversarial/advbox/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ec0ef13a79474ca5f3d9a7ec86836c0dc8555282 --- /dev/null +++ b/fluid/adversarial/advbox/models/__init__.py @@ -0,0 +1,3 @@ +""" +Paddle model for target of attack +""" diff --git a/fluid/adversarial/advbox/models/base.py b/fluid/adversarial/advbox/models/base.py new file mode 100644 index 0000000000000000000000000000000000000000..74e1045def7648b4a8df30e89312d73c0d4fe7e1 --- /dev/null +++ b/fluid/adversarial/advbox/models/base.py @@ -0,0 +1,90 @@ +""" +The base model of the model. +""" +from abc import ABCMeta +import abc + +abstractmethod = abc.abstractmethod + + +class Model(object): + """ + Base class of model to provide attack. + + + Args: + bounds(tuple): The lower and upper bound for the image pixel. + channel_axis(int): The index of the axis that represents the color channel. + preprocess(tuple): Two element tuple used to preprocess the input. First + substract the first element, then divide the second element. + """ + __metaclass__ = ABCMeta + + def __init__(self, bounds, channel_axis, preprocess=None): + assert len(bounds) == 2 + assert channel_axis in [0, 1, 2, 3] + + if preprocess is None: + preprocess = (0, 1) + self._bounds = bounds + self._channel_axis = channel_axis + self._preprocess = preprocess + + def bounds(self): + """ + Return the upper and lower bounds of the model. + """ + return self._bounds + + def channel_axis(self): + """ + Return the channel axis of the model. + """ + return self._channel_axis + + def _process_input(self, input_): + res = input_ + sub, div = self._preprocess + if sub != 0: + res = input_ - sub + assert div != 0 + if div != 1: + res /= div + return res + + @abstractmethod + def predict(self, image_batch): + """ + Calculate the prediction of the image batch. + + Args: + image_batch(numpy.ndarray): image batch of shape (batch_size, height, width, channels). + + Return: + numpy.ndarray: predictions of the images with shape (batch_size, num_of_classes). + """ + raise NotImplementedError + + @abstractmethod + def num_classes(self): + """ + Determine the number of the classes + + Return: + int: the number of the classes + """ + raise NotImplementedError + + @abstractmethod + def gradient(self, image_batch): + """ + Calculate the gradient of the cross-entropy loss w.r.t the image. + + Args: + image_batch(list): The image and label tuple list. + + Return: + numpy.ndarray: gradient of the cross-entropy loss w.r.t the image with + the shape (height, width, channel). + """ + raise NotImplementedError diff --git a/fluid/adversarial/advbox/models/paddle.py b/fluid/adversarial/advbox/models/paddle.py new file mode 100644 index 0000000000000000000000000000000000000000..976a525ca3bd38413280f8c2470b46a8914a86e1 --- /dev/null +++ b/fluid/adversarial/advbox/models/paddle.py @@ -0,0 +1,103 @@ +from __future__ import absolute_import + +import numpy as np +import paddle.v2 as paddle +import paddle.v2.fluid as fluid +from paddle.v2.fluid.framework import program_guard + +from .base import Model + + +class PaddleModel(Model): + """ + Create a PaddleModel instance. + When you need to generate a adversarial sample, you should construct an instance of PaddleModel. + + Args: + program(paddle.v2.fluid.framework.Program): The program of the model which generate the adversarial sample. + input_name(string): The name of the input. + logits_name(string): The name of the logits. + predict_name(string): The name of the predict. + cost_name(string): The name of the loss in the program. + """ + + def __init__(self, + program, + input_name, + logits_name, + predict_name, + cost_name, + bounds, + channel_axis=3, + preprocess=None): + super(PaddleModel, self).__init__( + bounds=bounds, channel_axis=channel_axis, preprocess=preprocess) + + if preprocess is None: + preprocess = (0, 1) + + self._program = program + self._place = fluid.CPUPlace() + self._exe = fluid.Executor(self._place) + + self._input_name = input_name + self._logits_name = logits_name + self._predict_name = predict_name + self._cost_name = cost_name + + # gradient + loss = self._program.block(0).var(self._cost_name) + param_grads = fluid.backward.append_backward( + loss, parameter_list=[self._input_name]) + self._gradient = dict(param_grads)[self._input_name] + + def predict(self, image_batch): + """ + Predict the label of the image_batch. + + Args: + image_batch(list): The image and label tuple list. + Return: + numpy.ndarray: predictions of the images with shape (batch_size, num_of_classes). + """ + feeder = fluid.DataFeeder( + feed_list=[self._input_name, self._logits_name], + place=self._place, + program=self._program) + predict_var = self._program.block(0).var(self._predict_name) + predict = self._exe.run( + self._program, + feed=feeder.feed(image_batch), + fetch_list=[predict_var]) + return predict + + def num_classes(self): + """ + Calculate the number of classes of the output label. + + Return: + int: the number of classes + """ + predict_var = self._program.block(0).var(self._predict_name) + assert len(predict_var.shape) == 2 + return predict_var.shape[1] + + def gradient(self, image_batch): + """ + Calculate the gradient of the loss w.r.t the input. + + Args: + image_batch(list): The image and label tuple list. + Return: + list: The list of the gradient of the image. + """ + feeder = fluid.DataFeeder( + feed_list=[self._input_name, self._logits_name], + place=self._place, + program=self._program) + + grad, = self._exe.run( + self._program, + feed=feeder.feed(image_batch), + fetch_list=[self._gradient]) + return grad diff --git a/fluid/adversarial/fluid_mnist.py b/fluid/adversarial/fluid_mnist.py new file mode 100644 index 0000000000000000000000000000000000000000..e2bfeeaf356bd3d8da638be4ac79df1f9ef316f0 --- /dev/null +++ b/fluid/adversarial/fluid_mnist.py @@ -0,0 +1,86 @@ +""" +CNN on mnist data using fluid api of paddlepaddle +""" +import paddle.v2 as paddle +import paddle.v2.fluid as fluid + + +def mnist_cnn_model(img): + """ + Mnist cnn model + + Args: + img(Varaible): the input image to be recognized + + Returns: + Variable: the label prediction + """ + conv_pool_1 = fluid.nets.simple_img_conv_pool( + input=img, + num_filters=20, + filter_size=5, + pool_size=2, + pool_stride=2, + act='relu') + + conv_pool_2 = fluid.nets.simple_img_conv_pool( + input=conv_pool_1, + num_filters=50, + filter_size=5, + pool_size=2, + pool_stride=2, + act='relu') + + logits = fluid.layers.fc(input=conv_pool_2, size=10, act='softmax') + return logits + + +def main(): + """ + Train the cnn model on mnist datasets + """ + img = fluid.layers.data(name='img', shape=[1, 28, 28], dtype='float32') + label = fluid.layers.data(name='label', shape=[1], dtype='int64') + logits = mnist_cnn_model(img) + cost = fluid.layers.cross_entropy(input=logits, label=label) + avg_cost = fluid.layers.mean(x=cost) + optimizer = fluid.optimizer.Adam(learning_rate=0.01) + optimizer.minimize(avg_cost) + + accuracy = fluid.evaluator.Accuracy(input=logits, label=label) + + BATCH_SIZE = 50 + PASS_NUM = 3 + ACC_THRESHOLD = 0.98 + LOSS_THRESHOLD = 10.0 + train_reader = paddle.batch( + paddle.reader.shuffle(paddle.dataset.mnist.train(), buf_size=500), + batch_size=BATCH_SIZE) + + place = fluid.CPUPlace() + exe = fluid.Executor(place) + feeder = fluid.DataFeeder(feed_list=[img, label], place=place) + exe.run(fluid.default_startup_program()) + + for pass_id in range(PASS_NUM): + accuracy.reset(exe) + for data in train_reader(): + loss, acc = exe.run( + fluid.default_main_program(), + feed=feeder.feed(data), + fetch_list=[avg_cost] + accuracy.metrics) + pass_acc = accuracy.eval(exe) + print("pass_id=" + str(pass_id) + " acc=" + str(acc) + " pass_acc=" + + str(pass_acc)) + if loss < LOSS_THRESHOLD and pass_acc > ACC_THRESHOLD: + break + + pass_acc = accuracy.eval(exe) + print("pass_id=" + str(pass_id) + " pass_acc=" + str(pass_acc)) + fluid.io.save_params( + exe, dirname='./mnist', main_program=fluid.default_main_program()) + print('train mnist done') + + +if __name__ == '__main__': + main() diff --git a/fluid/adversarial/mnist_tutorial_fgsm.py b/fluid/adversarial/mnist_tutorial_fgsm.py new file mode 100644 index 0000000000000000000000000000000000000000..4b9a1de8c21cb69899d850b881886265696bbcca --- /dev/null +++ b/fluid/adversarial/mnist_tutorial_fgsm.py @@ -0,0 +1,86 @@ +""" +FGSM demos on mnist using advbox tool. +""" +import paddle.v2 as paddle +import paddle.v2.fluid as fluid +import matplotlib.pyplot as plt +import numpy as np + +from advbox.models.paddle import PaddleModel +from advbox.attacks.gradientsign import GradientSignAttack + + +def cnn_model(img): + """ + Mnist cnn model + Args: + img(Varaible): the input image to be recognized + Returns: + Variable: the label prediction + """ + #conv1 = fluid.nets.conv2d() + conv_pool_1 = fluid.nets.simple_img_conv_pool( + input=img, + num_filters=20, + filter_size=5, + pool_size=2, + pool_stride=2, + act='relu') + + conv_pool_2 = fluid.nets.simple_img_conv_pool( + input=conv_pool_1, + num_filters=50, + filter_size=5, + pool_size=2, + pool_stride=2, + act='relu') + + logits = fluid.layers.fc(input=conv_pool_2, size=10, act='softmax') + return logits + + +def main(): + """ + Advbox demo which demonstrate how to use advbox. + """ + IMG_NAME = 'img' + LABEL_NAME = 'label' + + img = fluid.layers.data(name=IMG_NAME, shape=[1, 28, 28], dtype='float32') + # gradient should flow + img.stop_gradient = False + label = fluid.layers.data(name=LABEL_NAME, shape=[1], dtype='int64') + logits = cnn_model(img) + cost = fluid.layers.cross_entropy(input=logits, label=label) + avg_cost = fluid.layers.mean(x=cost) + + place = fluid.CPUPlace() + exe = fluid.Executor(place) + + BATCH_SIZE = 1 + train_reader = paddle.batch( + paddle.reader.shuffle(paddle.dataset.mnist.train(), buf_size=500), + batch_size=BATCH_SIZE) + feeder = fluid.DataFeeder( + feed_list=[IMG_NAME, LABEL_NAME], + place=place, + program=fluid.default_main_program()) + + fluid.io.load_params( + exe, "./mnist/", main_program=fluid.default_main_program()) + + # advbox demo + m = PaddleModel(fluid.default_main_program(), IMG_NAME, LABEL_NAME, + logits.name, avg_cost.name, (-1, 1)) + att = GradientSignAttack(m) + for data in train_reader(): + # fgsm attack + adv_img = att(data) + plt.imshow(n[0][0], cmap='Greys_r') + plt.show() + #np.save('adv_img', adv_img) + break + + +if __name__ == '__main__': + main()