Commit 2b3ba40e authored by gx_wind

add adversarial sample

Parent 643ff03f
# Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
A set of tools for generating adversarial example on paddle platform
"""
"""
The base model of the model.
"""
from abc import ABCMeta
#from advbox.base import Model
import abc
abstractmethod = abc.abstractmethod
class Attack(object):
"""
Abstract base class for adversarial attacks. `Attack` represents an adversarial attack
that searches for an adversarial example. Subclasses should implement the _apply() method.
Args:
model(Model): an instance of the class advbox.base.Model.
"""
__metaclass__ = ABCMeta
def __init__(self, model):
self.model = model
def __call__(self, image_batch):
"""
Generate the adversarial sample.
Args:
image_batch(list): The image and label tuple list.
"""
adv_img = self._apply(image_batch)
return adv_img
@abstractmethod
def _apply(self, image_batch):
"""
Search an adversarial example.
Args:
image_batch(list): The image and label tuple list.
"""
raise NotImplementedError
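# A minimal sketch of how a concrete attack would subclass Attack (illustrative
# only; `MyAttack` and its perturbation logic are hypothetical and not part of
# this commit -- GradientSignAttack below is the real implementation):
#
#     class MyAttack(Attack):
#         def _apply(self, image_batch):
#             # image_batch is a list of (image, label) tuples; return the
#             # perturbed image, or None if no adversarial example is found.
#             img = image_batch[0][0]
#             return img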
"""
This module provides the FGSM attack implementation.
"""
from __future__ import division
import numpy as np
from collections import Iterable
from .base import Attack
class GradientSignAttack(Attack):
"""
This attack was originally implemented by Goodfellow et al. (2015) with the
infinity norm and is known as the "Fast Gradient Sign Method".
Paper link: https://arxiv.org/abs/1412.6572
"""
def _apply(self, image_batch, epsilons=1000):
pre_label = np.argmax(self.model.predict(image_batch))
min_, max_ = self.model.bounds()
gradient = self.model.gradient(image_batch)
gradient_sign = np.sign(gradient) * (max_ - min_)
if not isinstance(epsilons, Iterable):
epsilons = np.linspace(0, 1, num = epsilons + 1)
for epsilon in epsilons:
adv_img = image_batch[0][0].reshape(gradient_sign.shape) + epsilon * gradient_sign
adv_img = np.clip(adv_img, min_, max_)
adv_label = np.argmax(self.model.predict([(adv_img, 0)]))
#print("pre_label="+str(pre_label)+ " adv_label="+str(adv_label))
if pre_label != adv_label:
#print(epsilon, pre_label, adv_label)
return adv_img
FGSM = GradientSignAttack
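# Usage sketch (illustrative only): wiring GradientSignAttack/FGSM to a model
# wrapper.  The variable `m` is assumed to be an advbox.models.paddle.PaddleModel
# built as in the demo at the end of this commit, and `data` a list of
# (image, label) tuples produced by a paddle batch reader.
#
#     attack = FGSM(m)
#     adv_img = attack(data)   # sweeps epsilon over [0, 1] and returns the first
#                              # image whose predicted label flips, or None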
# Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Paddle model for target of attack
"""
"""
The base model of the model.
"""
from abc import ABCMeta
import abc
abstractmethod = abc.abstractmethod
class Model(object):
"""
Base class of models, providing the interface required by attacks.
Args:
bounds(tuple): The lower and upper bound for the image pixel.
channel_axis(int): The index of the axis that represents the color channel.
preprocess(tuple): Two-element tuple used to preprocess the input: first
subtract the first element, then divide by the second element.
"""
__metaclass__ = ABCMeta
def __init__(self, bounds, channel_axis, preprocess=None):
assert len(bounds) == 2
assert channel_axis in [0, 1, 2, 3]
if preprocess is None:
preprocess = (0, 1)
self._bounds = bounds
self._channel_axis = channel_axis
self._preprocess = preprocess
def bounds(self):
"""
Return the upper and lower bounds of the model.
"""
return self._bounds
def channel_axis(self):
"""
Return the channel axis of the model.
"""
return self._channel_axis
def _process_input(self, input_):
res = input_
sub, div = self._preprocess
if sub != 0:
res = input_ - sub
assert div != 0
if div != 1:
res /= div
return res
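# For example, a model trained on inputs normalized to roughly [-1, 1] from raw
# pixels in [0, 255] could pass preprocess=(127.5, 127.5), so that
# _process_input computes (x - 127.5) / 127.5.  (Illustrative values only; the
# MNIST demo in this commit keeps the default preprocess=(0, 1), i.e. no change.)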
@abstractmethod
def predict(self, image_batch):
"""
Calculate the prediction of the image batch.
Args:
image_batch(numpy.ndarray): image batch of shape (batch_size, height, width, channels).
Return:
numpy.ndarray: predictions of the images with shape (batch_size, num_of_classes).
"""
raise NotImplementedError
@abstractmethod
def num_classes(self):
"""
Determine the number of classes.
Return:
int: the number of classes.
"""
raise NotImplementedError
@abstractmethod
def gradient(self, image_batch):
"""
Calculate the gradient of the cross-entropy loss w.r.t. the image.
Args:
image_batch(list): The image and label tuple list.
Return:
numpy.ndarray: gradient of the cross-entropy loss w.r.t. the image, with
the shape (height, width, channel).
"""
raise NotImplementedError
from __future__ import absolute_import
import numpy as np
import paddle.v2 as paddle
import paddle.v2.fluid as fluid
from paddle.v2.fluid.framework import program_guard
from .base import Model
class PaddleModel(Model):
"""
Create a PaddleModel instance.
When you need to generate an adversarial sample, you should construct an instance of PaddleModel.
Args:
program(paddle.v2.fluid.framework.Program): The program of the model on which the adversarial sample is generated.
input_name(string): The name of the input.
logits_name(string): The name of the logits.
predict_name(string): The name of the prediction.
cost_name(string): The name of the loss in the program.
"""
def __init__(self,
program,
input_name,
logits_name,
predict_name,
cost_name,
bounds,
channel_axis=3,
preprocess=None):
super(PaddleModel, self).__init__(
bounds=bounds,
channel_axis=channel_axis,
preprocess=preprocess)
if preprocess is None:
preprocess = (0, 1)
self._program = program
self._place = fluid.CPUPlace()
self._exe = fluid.Executor(self._place)
self._input_name = input_name
self._logits_name = logits_name
self._predict_name = predict_name
self._cost_name = cost_name
# gradient
loss = self._program.block(0).var(self._cost_name)
param_grads = fluid.backward.append_backward(loss, parameter_list=[self._input_name])
self._gradient = param_grads[0][1]
def predict(self, image_batch):
"""
Predict the label of the image_batch.
Args:
image_batch(list): The image and label tuple list.
Return:
numpy.ndarray: predictions of the images with shape (batch_size, num_of_classes).
"""
feeder = fluid.DataFeeder(
feed_list=[self._input_name, self._logits_name],
place=self._place,
program=self._program
)
predict_var = self._program.block(0).var(self._predict_name)
predict = self._exe.run(
self._program,
feed=feeder.feed(image_batch),
fetch_list=[predict_var]
)
return predict
def num_classes(self):
"""
Calculate the number of classes of the output label.
Return:
int: the number of classes
"""
predict_var = self._program.block(0).var(self._predict_name)
assert len(predict_var.shape) == 2
return predict_var.shape[1]
def gradient(self, image_batch):
"""
Calculate the gradient of the loss w.r.t the input.
Args:
image_batch(list): The image and label tuple list.
Return:
list: The list of the gradient of the image.
"""
feeder = fluid.DataFeeder(
feed_list=[self._input_name, self._logits_name],
place=self._place,
program=self._program
)
grad, = self._exe.run(
self._program,
feed=feeder.feed(image_batch),
fetch_list=[self._gradient])
return grad
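# Usage sketch (illustrative, mirroring the demo at the end of this commit):
# given a program whose input, label, prediction, and loss variables are named
# as below, the wrapper exposes predict() and gradient() over lists of
# (image, label) tuples.
#
#     m = PaddleModel(fluid.default_main_program(),
#                     'img',            # input_name
#                     'label',          # logits_name: fed alongside the image
#                     logits.name,      # predict_name: the softmax output
#                     avg_cost.name,    # cost_name: the scalar loss
#                     bounds=(-1, 1))
#     scores = m.predict(data)          # class scores for the batch
#     grads = m.gradient(data)          # d(loss)/d(input) for the batch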
################################################################################
#
# Copyright (c) 2017 Baidu.com, Inc. All Rights Reserved
#
################################################################################
"""
A pure PaddlePaddle implementation of a neural network.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import paddle.v2 as paddle
import paddle.v2.fluid as fluid
from advbox import Model
def main():
"""
example main function
"""
model_dir = "./mnist_model"
place = fluid.CPUPlace()
exe = fluid.Executor(place)
program, feed_var_names, fetch_vars = fluid.io.load_inference_model(model_dir, exe)
print(program)
if __name__ == "__main__":
main()
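# A hedged sketch of how the loaded pieces would typically be used for
# inference: load_inference_model returns the inference program, the feed
# variable names, and the fetch variables.  `some_input` below is a
# placeholder, not data shipped with this commit:
#
#     result = exe.run(program,
#                      feed={feed_var_names[0]: some_input},
#                      fetch_list=fetch_vars)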
"""
CNN on mnist data using fluid api of paddlepaddle
"""
import paddle.v2 as paddle
import paddle.v2.fluid as fluid
def mnist_cnn_model(img):
"""
Mnist cnn model
Args:
img(Variable): the input image to be recognized
Returns:
Variable: the label prediction
"""
#conv1 = fluid.nets.conv2d()
conv_pool_1 = fluid.nets.simple_img_conv_pool(
input=img,
num_filters=20,
filter_size=5,
pool_size=2,
pool_stride=2,
act='relu')
conv_pool_2 = fluid.nets.simple_img_conv_pool(
input=conv_pool_1,
num_filters=50,
filter_size=5,
pool_size=2,
pool_stride=2,
act='relu')
logits = fluid.layers.fc(
input=conv_pool_2,
size=10,
act='softmax')
return logits
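# Shape walk-through for the model above (assuming the default zero convolution
# padding of simple_img_conv_pool): a 1x28x28 input becomes 20x24x24 after the
# first 5x5 conv, 20x12x12 after 2x2 pooling, 50x8x8 after the second conv,
# 50x4x4 after pooling, and finally a 10-way softmax from the fully connected
# layer.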
def main():
"""
Train the CNN model on the MNIST dataset
"""
img = fluid.layers.data(name='img', shape=[1, 28, 28], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
logits = mnist_cnn_model(img)
cost = fluid.layers.cross_entropy(input=logits, label=label)
avg_cost = fluid.layers.mean(x=cost)
optimizer = fluid.optimizer.Adam(learning_rate=0.01)
optimizer.minimize(avg_cost)
accuracy = fluid.evaluator.Accuracy(input=logits, label=label)
BATCH_SIZE = 50
PASS_NUM = 3
ACC_THRESHOLD = 0.98
LOSS_THRESHOLD = 10.0
train_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.mnist.train(), buf_size=500),
batch_size=BATCH_SIZE)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
feeder = fluid.DataFeeder(feed_list=[img, label], place=place)
exe.run(fluid.default_startup_program())
for pass_id in range(PASS_NUM):
accuracy.reset(exe)
for data in train_reader():
loss, acc = exe.run(fluid.default_main_program(),
feed=feeder.feed(data),
fetch_list=[avg_cost] + accuracy.metrics)
pass_acc = accuracy.eval(exe)
print("pass_id=" + str(pass_id) + " acc=" + str(acc) + " pass_acc=" +
str(pass_acc))
# print loss, acc
if loss < LOSS_THRESHOLD and pass_acc > ACC_THRESHOLD:
# if the batch loss is below 10.0 and the pass accuracy is above 0.98, stop training early
break
# exit(0)
pass_acc = accuracy.eval(exe)
print("pass_id=" + str(pass_id) + " pass_acc=" + str(pass_acc))
fluid.io.save_params(exe, dirname='./mnist', main_program=fluid.default_main_program())
print('train mnist done')
exit(1)
if __name__ == '__main__':
main()
"""
This attack was originally implemented by Goodfellow et al. (2015) with the
infinity norm and is known as the "Fast Gradient Sign Method".
Paper link: https://arxiv.org/abs/1412.6572
"""
import numpy as np
import paddle.v2 as paddle
import paddle.v2.fluid as fluid
BATCH_SIZE = 50
PASS_NUM = 1
EPS = 0.3
CLIP_MIN = -1
CLIP_MAX = 1
def mnist_cnn_model(img):
"""
Mnist cnn model
Args:
img(Variable): the input image to be recognized
Returns:
Variable: the label prediction
"""
#conv1 = fluid.nets.conv2d()
conv_pool_1 = fluid.nets.simple_img_conv_pool(
input=img,
num_filters=20,
filter_size=5,
pool_size=2,
pool_stride=2,
act='relu')
conv_pool_2 = fluid.nets.simple_img_conv_pool(
input=conv_pool_1,
num_filters=50,
filter_size=5,
pool_size=2,
pool_stride=2,
act='relu')
logits = fluid.layers.fc(
input=conv_pool_2,
size=10,
act='softmax')
return logits
def main():
"""
Generate adversarial examples and evaluate accuracy on MNIST using FGSM
"""
images = fluid.layers.data(name='pixel', shape=[1, 28, 28], dtype='float32')
# The gradient should flow
images.stop_gradient = False
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
predict = mnist_cnn_model(images)
cost = fluid.layers.cross_entropy(input=predict, label=label)
avg_cost = fluid.layers.mean(x=cost)
# Cal gradient of input
params_grads = fluid.backward.append_backward_ops(avg_cost, parameter_list=['pixel'])
# data batch
train_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.mnist.train(), buf_size=500),
batch_size=BATCH_SIZE)
accuracy = fluid.evaluator.Accuracy(input=predict, label=label)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
accuracy.reset(exe)
#exe.run(fluid.default_startup_program())
feeder = fluid.DataFeeder(feed_list=[images, label], place=place)
for pass_id in range(PASS_NUM):
fluid.io.load_params(exe, "./mnist/", main_program=fluid.default_main_program())
for data in train_reader():
# cal gradient and eval accuracy
ps, acc = exe.run(
fluid.default_main_program(),
feed=feeder.feed(data),
fetch_list=[params_grads[0][1]]+accuracy.metrics)
labels = []
for idx, _ in enumerate(data):
labels.append(data[idx][1])
# generate adversarial example
batch_num = ps.shape[0]
new_data = []
for i in range(batch_num):
adv_img = np.reshape(data[i][0], (1, 28, 28)) + EPS * np.sign(ps[i])
adv_img = np.clip(adv_img, CLIP_MIN, CLIP_MAX)
#adv_imgs.append(adv_img)
t = (adv_img, data[i][1])
new_data.append(t)
# predict label
predict_label, = exe.run(
fluid.default_main_program(),
feed=feeder.feed(new_data),
fetch_list=[predict])
adv_labels = np.argmax(predict_label, axis=1)
batch_accuracy = np.mean(np.equal(labels, adv_labels))
print "pass_id=" + str(pass_id) + " acc=" + str(acc)+ " adv_acc=" + str(batch_accuracy)
if __name__ == "__main__":
main()
"""
FGSM demo on MNIST using the advbox tool.
"""
import paddle.v2 as paddle
import paddle.v2.fluid as fluid
import matplotlib.pyplot as plt
import numpy as np
from advbox.models.paddle import PaddleModel
from advbox.attacks.gradientsign import GradientSignAttack
def cnn_model(img):
"""
Mnist cnn model
Args:
img(Variable): the input image to be recognized
Returns:
Variable: the label prediction
"""
#conv1 = fluid.nets.conv2d()
conv_pool_1 = fluid.nets.simple_img_conv_pool(
input=img,
num_filters=20,
filter_size=5,
pool_size=2,
pool_stride=2,
act='relu')
conv_pool_2 = fluid.nets.simple_img_conv_pool(
input=conv_pool_1,
num_filters=50,
filter_size=5,
pool_size=2,
pool_stride=2,
act='relu')
logits = fluid.layers.fc(
input=conv_pool_2,
size=10,
act='softmax')
return logits
def main():
"""
Advbox demo which demonstrates how to use advbox.
"""
IMG_NAME = 'img'
LABEL_NAME = 'label'
img = fluid.layers.data(name=IMG_NAME, shape=[1, 28, 28], dtype='float32')
# gradient should flow
img.stop_gradient = False
label = fluid.layers.data(name=LABEL_NAME, shape=[1], dtype='int64')
logits = cnn_model(img)
cost = fluid.layers.cross_entropy(input=logits, label=label)
avg_cost = fluid.layers.mean(x=cost)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
BATCH_SIZE = 1
train_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.mnist.train(), buf_size=500),
batch_size=BATCH_SIZE)
feeder = fluid.DataFeeder(
feed_list=[IMG_NAME, LABEL_NAME],
place=place,
program=fluid.default_main_program()
)
fluid.io.load_params(exe, "./mnist/", main_program=fluid.default_main_program())
# advbox demo
m = PaddleModel(
fluid.default_main_program(),
IMG_NAME,
LABEL_NAME,
logits.name,
avg_cost.name,
(-1, 1)
)
att = GradientSignAttack(m)
for data in train_reader():
# fgsm attack
adv_img = att(data)
plt.imshow(adv_img[0][0], cmap='Greys_r')
plt.show()
#np.save('adv_img', adv_img)
break
if __name__ == '__main__':
main()
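# A possible follow-up inside the loop above (sketch only, not part of the
# original demo): check that the attack actually changed the model's
# prediction before plotting, e.g.
#
#     orig_label = np.argmax(m.predict(data))
#     adv_label = np.argmax(m.predict([(adv_img, data[0][1])]))
#     assert orig_label != adv_label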