Commit 2e7494d9 authored by: G guosheng

Merge branch 'develop' of https://github.com/PaddlePaddle/models into fix-transformer-batchsize-dev

......@@ -17,7 +17,7 @@ addons:
- python-pip
- python2.7-dev
- clang-format-3.8
ssh_known_hosts: 52.76.173.135
ssh_known_hosts: 13.229.163.131
before_install:
- if [[ "$JOB" == "PRE_COMMIT" ]]; then sudo ln -s /usr/bin/clang-format-3.8 /usr/bin/clang-format; fi
- sudo pip install -U virtualenv pre-commit pip
......
......@@ -168,7 +168,7 @@ def profile(args):
start_time = time.time()
frames_seen = 0
# load_data
(features, labels, lod) = batch_data
(features, labels, lod, _) = batch_data
feature_t.set(features, place)
feature_t.set_lod([lod])
label_t.set(labels, place)
......
......@@ -192,7 +192,7 @@ def train(args):
test_data_reader.batch_iterator(args.batch_size,
args.minimum_batch_size)):
# load_data
(features, labels, lod) = batch_data
(features, labels, lod, _) = batch_data
feature_t.set(features, place)
feature_t.set_lod([lod])
label_t.set(labels, place)
......
......@@ -4,10 +4,105 @@ The minimum PaddlePaddle version needed for the code sample in this directory is
# Advbox
Advbox is a Python toolbox to create adversarial examples that fool neural networks. It requires Python and PaddlePaddle.
Advbox is a toolbox to generate adversarial examples that fool neural networks, and it can benchmark the robustness of machine learning models.
## How to use
Advbox is based on [PaddlePaddle](https://github.com/PaddlePaddle/Paddle) Fluid and is under continual development, always welcoming contributions of the latest methods of adversarial attacks and defenses.
1. Train a model and save its parameters (e.g. fluid_mnist.py).
2. Load the parameters trained in step 1, then reconstruct the model (e.g. mnist_tutorial_fgsm.py).
3. Use Advbox to generate the adversarial samples.
## Overview
[Szegedy et al.](https://arxiv.org/abs/1312.6199) first discovered an intriguing property of deep neural networks in the context of image classification: despite their state-of-the-art performance, deep networks are surprisingly susceptible to adversarial attacks in the form of small perturbations to images that remain (almost) imperceptible to the human visual system. These perturbations are found by optimizing the input to maximize the prediction error, and the images modified by these perturbations are called `adversarial examples`. The profound implications of these results triggered wide interest among researchers in adversarial attacks and their defenses for deep learning in general.
Advbox is similar to [Foolbox](https://github.com/bethgelab/foolbox) and [CleverHans](https://github.com/tensorflow/cleverhans). CleverHans only supports the TensorFlow framework, while Foolbox interfaces with many popular machine learning frameworks such as PyTorch, Keras, TensorFlow, Theano, Lasagne and MXNet. However, neither of these great libraries supports PaddlePaddle, an easy-to-use, efficient, flexible and scalable deep learning platform originally developed by Baidu scientists and engineers to apply deep learning to many products at Baidu.
## Usage
Advbox provides stable reference implementations of modern methods for generating adversarial examples, such as FGSM, DeepFool, and JSMA. To benchmark the robustness of your neural networks, you can use Advbox to generate adversarial examples and evaluate the networks against them. A typical workflow is (a minimal sketch follows the steps below):
1. Train a model and save its parameters.
2. Load the trained parameters, then reconstruct the model.
3. Use Advbox to generate the adversarial samples.
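A minimal sketch of these three steps, assembled from the tutorials shipped under `tutorials/`. It assumes a trained MNIST model has already been saved to `./mnist/` (e.g. by `tutorials/mnist_model.py`) and that the script is run from the Advbox root so that the `advbox` and `tutorials` packages are importable:

```python
import paddle.v2 as paddle
import paddle.fluid as fluid

from advbox.adversary import Adversary
from advbox.attacks.gradient_method import FGSM
from advbox.models.paddle import PaddleModel
from tutorials.mnist_model import mnist_cnn_model

# Step 2: reconstruct the network and load the trained parameters.
img = fluid.layers.data(name='img', shape=[1, 28, 28], dtype='float32')
img.stop_gradient = False  # gradients w.r.t. the input are needed for the attack
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
logits = mnist_cnn_model(img)
cost = fluid.layers.cross_entropy(input=logits, label=label)
avg_cost = fluid.layers.mean(x=cost)

place = fluid.CPUPlace()
exe = fluid.Executor(place)
fluid.io.load_params(exe, "./mnist/", main_program=fluid.default_main_program())

# Step 3: wrap the model and generate an adversarial example with FGSM.
m = PaddleModel(fluid.default_main_program(), 'img', 'label',
                logits.name, avg_cost.name, (-1, 1))
attack = FGSM(m)
for data in paddle.batch(paddle.dataset.mnist.test(), batch_size=1)():
    adversary = attack(Adversary(data[0][0], data[0][1]))
    if adversary.is_successful():
        print('adversarial label: %d' % adversary.adversarial_label)
    break
```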
#### Dependencies
* PaddlePaddle: [the latest develop branch](http://www.paddlepaddle.org/docs/develop/documentation/en/build_and_install/pip_install_en.html)
* Python 2.x
#### Structure
Network models, implementations of attack methods, and the criterion that defines adversarial examples are the three essential elements for generating adversarial examples. For brevity, Advbox adopts misclassification as the adversarial criterion.
The structure of the Advbox module is as follows:
.
├── advbox
|   ├── __init__.py
|   ├── attack
|   |   ├── __init__.py
|   |   ├── base.py
|   |   ├── deepfool.py
|   |   ├── gradient_method.py
|   |   ├── lbfgs.py
|   |   └── saliency.py
|   ├── models
|   |   ├── __init__.py
|   |   ├── base.py
|   |   └── paddle.py
|   └── adversary.py
├── tutorials
|   ├── __init__.py
|   ├── mnist_model.py
|   ├── mnist_tutorial_lbfgs.py
|   ├── mnist_tutorial_fgsm.py
|   ├── mnist_tutorial_bim.py
|   ├── mnist_tutorial_ilcm.py
|   ├── mnist_tutorial_jsma.py
|   └── mnist_tutorial_deepfool.py
└── README.md
**advbox.attack**
Advbox implements several popular adversarial attacks that search for adversarial examples. Each attack method uses a distance measure (L1, L2, etc.) to quantify the size of adversarial perturbations. Crafting adversarial examples with Advbox is easy, since some attack methods perform internal hyperparameter tuning to find the minimum perturbation.
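For instance, the iterative gradient attacks in `gradient_method.py` take the attack step size and iteration count as keyword arguments. A sketch, where `m` is a wrapped model as described under **advbox.model** below and `original_image`/`original_label` are placeholders for one sample:

```python
import numpy as np
from advbox.adversary import Adversary
from advbox.attacks.gradient_method import BIM

attack = BIM(m)  # m: an Advbox model wrapper (see advbox.model below)

# epsilons: attack step size (input variation); steps: number of attack iterations.
# Internally, gradients are normalized with an L-inf/L1/L2 norm before each step.
attack_config = {"epsilons": 0.1, "steps": 100}

# original_image/original_label are placeholders for one input and its label.
adversary = attack(Adversary(original_image, original_label), **attack_config)
```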
**advbox.model**
Advbox implements an interface to PaddlePaddle. Additionally, interfaces to other deep learning frameworks such as TensorFlow can also be defined and employed. The module is used to compute predictions and gradients for given inputs in a specific framework.
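A sketch of constructing the PaddlePaddle wrapper, mirroring the tutorials below. Here `logits` and `avg_cost` are the prediction and loss variables of an already reconstructed network, and `some_image`/`some_label` are placeholders:

```python
import paddle.fluid as fluid
from advbox.models.paddle import PaddleModel

m = PaddleModel(
    fluid.default_main_program(),  # program that holds the reconstructed network
    'img',                         # name of the input variable
    'label',                       # name of the label variable
    logits.name,                   # name of the prediction (softmax) variable
    avg_cost.name,                 # name of the loss used to compute gradients
    (-1, 1),                       # bounds of valid input values
    channel_axis=1)                # axis that holds the image channels

# The wrapper exposes predictions and input gradients to the attack methods.
probs = m.predict(some_image)
grad = m.gradient(some_image, some_label)
```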
**advbox.adversary**
Adversary contains the original input, the target and the adversarial examples. It uses misclassification as the criterion to accept an adversarial example.
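A sketch of how an Adversary is created and inspected (the input here is a random placeholder image):

```python
import numpy as np
from advbox.adversary import Adversary

# Placeholder sample: a single-channel 28x28 image and its ground-truth label.
original_image = np.random.rand(1, 28, 28).astype('float32')
original_label = 7

adversary = Adversary(original_image, original_label)
# Optionally turn it into a targeted attack before running an attack method.
adversary.set_target(is_targeted_attack=True, target_label=0)

# ... adversary = attack(adversary, **attack_config) ...

if adversary.is_successful():             # misclassification is the acceptance criterion
    print(adversary.adversarial_label)    # label assigned to the accepted example
    print(adversary.adversarial_example)  # the perturbed input itself
```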
## Tutorials
The `./tutorials/` folder provides some tutorials to generate adversarial examples on the MNIST dataset. You can slightly modify the code to apply it to other datasets. These attack methods are supported in Advbox:
* [L-BFGS](https://arxiv.org/abs/1312.6199)
* [FGSM](https://arxiv.org/abs/1412.6572)
* [BIM](https://arxiv.org/abs/1607.02533)
* [ILCM](https://arxiv.org/abs/1607.02533)
* [JSMA](https://arxiv.org/pdf/1511.07528)
* [DeepFool](https://arxiv.org/abs/1511.04599)
## Testing
Benchmarks on a vanilla CNN model.
> MNIST
| adversarial attacks | fooling rate (non-targeted) | fooling rate (targeted) | max_epsilon | iterations | Strength |
|:-----:| :----: | :---: | :----: | :----: | :----: |
|L-BFGS| --- | 89.2% | --- | One shot | *** |
|FGSM| 57.8% | 26.55% | 0.3 | One shot| *** |
|BIM| 97.4% | --- | 0.1 | 100 | **** |
|ILCM| --- | 100.0% | 0.1 | 100 | **** |
|JSMA| 96.8% | 90.4% | 0.1 | 2000 | *** |
|DeepFool| 97.7% | 51.3% | --- | 100 | **** |
* The strength (more asterisks indicate a stronger attack) is based on impressions from the reviewed literature.
---
## References
* [Intriguing properties of neural networks](https://arxiv.org/abs/1312.6199), C. Szegedy et al., arXiv 2014
* [Explaining and Harnessing Adversarial Examples](https://arxiv.org/abs/1412.6572), I. Goodfellow et al., ICLR 2015
* [Adversarial Examples in the Physical World](https://arxiv.org/pdf/1607.02533v3.pdf), A. Kurakin et al., ICLR workshop 2017
* [The Limitations of Deep Learning in Adversarial Settings](https://arxiv.org/abs/1511.07528), N. Papernot et al., EuroS&P 2016
* [DeepFool: a simple and accurate method to fool deep neural networks](https://arxiv.org/abs/1511.04599), S. Moosavi-Dezfooli et al., CVPR 2016
* [Foolbox: A Python toolbox to benchmark the robustness of machine learning models](https://arxiv.org/abs/1707.04131), Jonas Rauber et al., arXiv 2018
* [CleverHans: An adversarial example library for constructing attacks, building defenses, and benchmarking both](https://github.com/tensorflow/cleverhans#setting-up-cleverhans)
* [Threat of Adversarial Attacks on Deep Learning in Computer Vision: A Survey](https://arxiv.org/abs/1801.00553), Naveed Akhtar, Ajmal Mian, arXiv 2018
......@@ -32,7 +32,12 @@ class GradientMethodAttack(Attack):
super(GradientMethodAttack, self).__init__(model)
self.support_targeted = support_targeted
def _apply(self, adversary, norm_ord=np.inf, epsilons=0.01, steps=100):
def _apply(self,
adversary,
norm_ord=np.inf,
epsilons=0.01,
steps=1,
epsilon_steps=100):
"""
Apply the gradient attack method.
:param adversary(Adversary):
......@@ -41,8 +46,11 @@ class GradientMethodAttack(Attack):
Order of the norm, such as np.inf, 1, 2, etc. It can't be 0.
:param epsilons(list|tuple|int):
Attack step size (input variation).
Largest step size if epsilons is not iterable.
:param steps:
The number of iterator steps.
The number of attack iterations.
:param epsilon_steps:
The number of candidate epsilon values swept over (each runs a full set of attack iterations).
:return:
adversary(Adversary): The Adversary object.
"""
......@@ -55,7 +63,7 @@ class GradientMethodAttack(Attack):
"This attack method doesn't support targeted attack!")
if not isinstance(epsilons, Iterable):
epsilons = np.linspace(epsilons, epsilons + 1e-10, num=steps)
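# When a scalar epsilon is given, sweep epsilon_steps evenly spaced values in
# [0, epsilons]; the 0.0 entry is skipped inside the attack loop below.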
epsilons = np.linspace(0, epsilons, num=epsilon_steps)
pre_label = adversary.original_label
min_, max_ = self.model.bounds()
......@@ -65,30 +73,33 @@ class GradientMethodAttack(Attack):
self.model.channel_axis() == adversary.original.shape[0] or
self.model.channel_axis() == adversary.original.shape[-1])
step = 1
adv_img = adversary.original
for epsilon in epsilons[:steps]:
if epsilon == 0.0:
continue
if adversary.is_targeted_attack:
gradient = -self.model.gradient(adv_img, adversary.target_label)
else:
gradient = self.model.gradient(adv_img,
adversary.original_label)
if norm_ord == np.inf:
gradient_norm = np.sign(gradient)
else:
gradient_norm = gradient / self._norm(gradient, ord=norm_ord)
adv_img = adv_img + epsilon * gradient_norm * (max_ - min_)
adv_img = np.clip(adv_img, min_, max_)
adv_label = np.argmax(self.model.predict(adv_img))
logging.info('step={}, epsilon = {:.5f}, pre_label = {}, '
'adv_label={}'.format(step, epsilon, pre_label,
adv_label))
if adversary.try_accept_the_example(adv_img, adv_label):
return adversary
step += 1
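# The restructured loop tries each candidate epsilon independently: it restarts
# from the original image and runs `steps` attack iterations per epsilon,
# instead of increasing epsilon within a single pass.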
for epsilon in epsilons[:]:
step = 1
adv_img = adversary.original
for i in range(steps):
if epsilon == 0.0:
continue
if adversary.is_targeted_attack:
gradient = -self.model.gradient(adv_img,
adversary.target_label)
else:
gradient = self.model.gradient(adv_img,
adversary.original_label)
if norm_ord == np.inf:
gradient_norm = np.sign(gradient)
else:
gradient_norm = gradient / self._norm(
gradient, ord=norm_ord)
adv_img = adv_img + epsilon * gradient_norm * (max_ - min_)
adv_img = np.clip(adv_img, min_, max_)
adv_label = np.argmax(self.model.predict(adv_img))
logging.info('step={}, epsilon = {:.5f}, pre_label = {}, '
'adv_label={}'.format(step, epsilon, pre_label,
adv_label))
if adversary.try_accept_the_example(adv_img, adv_label):
return adversary
step += 1
return adversary
@staticmethod
......@@ -113,7 +124,7 @@ class FastGradientSignMethodTargetedAttack(GradientMethodAttack):
Paper link: https://arxiv.org/abs/1412.6572
"""
def _apply(self, adversary, epsilons=0.03):
def _apply(self, adversary, epsilons=0.01):
return GradientMethodAttack._apply(
self,
adversary=adversary,
......@@ -144,7 +155,7 @@ class IterativeLeastLikelyClassMethodAttack(GradientMethodAttack):
Paper link: https://arxiv.org/abs/1607.02533
"""
def _apply(self, adversary, epsilons=0.001, steps=1000):
def _apply(self, adversary, epsilons=0.01, steps=1000):
return GradientMethodAttack._apply(
self,
adversary=adversary,
......
"""
FGSM demos on mnist using advbox tool.
"""
import matplotlib.pyplot as plt
import paddle.v2 as paddle
import paddle.fluid as fluid
from advbox.adversary import Adversary
from advbox.attacks.gradient_method import FGSM
from advbox.models.paddle import PaddleModel
def cnn_model(img):
"""
Mnist cnn model
Args:
img(Variable): the input image to be recognized
Returns:
Variable: the label prediction
"""
# conv1 = fluid.nets.conv2d()
conv_pool_1 = fluid.nets.simple_img_conv_pool(
input=img,
num_filters=20,
filter_size=5,
pool_size=2,
pool_stride=2,
act='relu')
conv_pool_2 = fluid.nets.simple_img_conv_pool(
input=conv_pool_1,
num_filters=50,
filter_size=5,
pool_size=2,
pool_stride=2,
act='relu')
logits = fluid.layers.fc(input=conv_pool_2, size=10, act='softmax')
return logits
def main():
"""
Advbox demo which demonstrates how to use Advbox.
"""
IMG_NAME = 'img'
LABEL_NAME = 'label'
img = fluid.layers.data(name=IMG_NAME, shape=[1, 28, 28], dtype='float32')
# gradient should flow
img.stop_gradient = False
label = fluid.layers.data(name=LABEL_NAME, shape=[1], dtype='int64')
logits = cnn_model(img)
cost = fluid.layers.cross_entropy(input=logits, label=label)
avg_cost = fluid.layers.mean(x=cost)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
BATCH_SIZE = 1
train_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.mnist.train(), buf_size=500),
batch_size=BATCH_SIZE)
feeder = fluid.DataFeeder(
feed_list=[IMG_NAME, LABEL_NAME],
place=place,
program=fluid.default_main_program())
fluid.io.load_params(
exe, "./mnist/", main_program=fluid.default_main_program())
# advbox demo
m = PaddleModel(fluid.default_main_program(), IMG_NAME, LABEL_NAME,
logits.name, avg_cost.name, (-1, 1))
att = FGSM(m)
for data in train_reader():
# fgsm attack
adversary = att(Adversary(data[0][0], data[0][1]))
if adversary.is_successful():
plt.imshow(adversary.target, cmap='Greys_r')
plt.show()
# np.save('adv_img', adversary.target)
break
if __name__ == '__main__':
main()
"""
JSMA demos on mnist using advbox tool.
"""
import matplotlib.pyplot as plt
import paddle.v2 as paddle
import paddle.fluid as fluid
import numpy as np
from advbox import Adversary
from advbox.attacks.saliency import SaliencyMapAttack
from advbox.models.paddle import PaddleModel
def cnn_model(img):
"""
Mnist cnn model
Args:
img(Variable): the input image to be recognized
Returns:
Variable: the label prediction
"""
# conv1 = fluid.nets.conv2d()
conv_pool_1 = fluid.nets.simple_img_conv_pool(
input=img,
num_filters=20,
filter_size=5,
pool_size=2,
pool_stride=2,
act='relu')
conv_pool_2 = fluid.nets.simple_img_conv_pool(
input=conv_pool_1,
num_filters=50,
filter_size=5,
pool_size=2,
pool_stride=2,
act='relu')
logits = fluid.layers.fc(input=conv_pool_2, size=10, act='softmax')
return logits
def main():
"""
Advbox demo which demonstrates how to use Advbox.
"""
IMG_NAME = 'img'
LABEL_NAME = 'label'
img = fluid.layers.data(name=IMG_NAME, shape=[1, 28, 28], dtype='float32')
# gradient should flow
img.stop_gradient = False
label = fluid.layers.data(name=LABEL_NAME, shape=[1], dtype='int64')
logits = cnn_model(img)
cost = fluid.layers.cross_entropy(input=logits, label=label)
avg_cost = fluid.layers.mean(x=cost)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
BATCH_SIZE = 1
train_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.mnist.train(), buf_size=500),
batch_size=BATCH_SIZE)
feeder = fluid.DataFeeder(
feed_list=[IMG_NAME, LABEL_NAME],
place=place,
program=fluid.default_main_program())
fluid.io.load_params(
exe, "./mnist/", main_program=fluid.default_main_program())
# advbox demo
m = PaddleModel(fluid.default_main_program(), IMG_NAME, LABEL_NAME,
logits.name, avg_cost.name, (-1, 1))
attack = SaliencyMapAttack(m)
total_num = 0
success_num = 0
for data in train_reader():
total_num += 1
# adversary.set_target(True, target_label=target_label)
jsma_attack = attack(Adversary(data[0][0], data[0][1]))
if jsma_attack is not None and jsma_attack.is_successful():
# plt.imshow(jsma_attack.target, cmap='Greys_r')
# plt.show()
success_num += 1
print('original_label=%d, adversary examples label =%d' %
(data[0][1], jsma_attack.adversarial_label))
# np.save('adv_img', jsma_attack.adversarial_example)
print('total num = %d, success num = %d ' % (total_num, success_num))
if total_num == 100:
break
if __name__ == '__main__':
main()
"""
A set of tutorials for generating adversarial examples with advbox.
"""
\ No newline at end of file
......@@ -30,8 +30,9 @@ def mnist_cnn_model(img):
pool_size=2,
pool_stride=2,
act='relu')
fc = fluid.layers.fc(input=conv_pool_2, size=50, act='relu')
logits = fluid.layers.fc(input=conv_pool_2, size=10, act='softmax')
logits = fluid.layers.fc(input=fc, size=10, act='softmax')
return logits
......@@ -60,7 +61,10 @@ def main():
paddle.dataset.mnist.train(), buf_size=500),
batch_size=BATCH_SIZE)
# use CPU
place = fluid.CPUPlace()
# use GPU
# place = fluid.CUDAPlace(0)
exe = fluid.Executor(place)
feeder = fluid.DataFeeder(feed_list=[img, label], place=place)
exe.run(fluid.default_startup_program())
......@@ -74,9 +78,11 @@ def main():
feed=feeder.feed(data),
fetch_list=[avg_cost, batch_acc, batch_size])
pass_acc.add(value=acc, weight=b_size)
pass_acc_val = pass_acc.eval()[0]
print("pass_id=" + str(pass_id) + " acc=" + str(acc[0]) +
" pass_acc=" + str(pass_acc.eval()[0]))
if loss < LOSS_THRESHOLD and pass_acc > ACC_THRESHOLD:
" pass_acc=" + str(pass_acc_val))
if loss < LOSS_THRESHOLD and pass_acc_val > ACC_THRESHOLD:
# early stop
break
print("pass_id=" + str(pass_id) + " pass_acc=" + str(pass_acc.eval()[
......
"""
BIM tutorial on mnist using advbox tool.
The BIM method iteratively takes multiple small steps, adjusting the direction after each step.
It only supports non-targeted attacks.
"""
import sys
sys.path.append("..")
import matplotlib.pyplot as plt
import paddle.fluid as fluid
import paddle.v2 as paddle
from advbox.adversary import Adversary
from advbox.attacks.gradient_method import BIM
from advbox.models.paddle import PaddleModel
from tutorials.mnist_model import mnist_cnn_model
def main():
"""
Advbox demo which demonstrates how to use Advbox.
"""
TOTAL_NUM = 500
IMG_NAME = 'img'
LABEL_NAME = 'label'
img = fluid.layers.data(name=IMG_NAME, shape=[1, 28, 28], dtype='float32')
# gradient should flow
img.stop_gradient = False
label = fluid.layers.data(name=LABEL_NAME, shape=[1], dtype='int64')
logits = mnist_cnn_model(img)
cost = fluid.layers.cross_entropy(input=logits, label=label)
avg_cost = fluid.layers.mean(x=cost)
# use CPU
place = fluid.CPUPlace()
# use GPU
# place = fluid.CUDAPlace(0)
exe = fluid.Executor(place)
BATCH_SIZE = 1
train_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.mnist.train(), buf_size=128 * 10),
batch_size=BATCH_SIZE)
test_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.mnist.test(), buf_size=128 * 10),
batch_size=BATCH_SIZE)
fluid.io.load_params(
exe, "./mnist/", main_program=fluid.default_main_program())
# advbox demo
m = PaddleModel(
fluid.default_main_program(),
IMG_NAME,
LABEL_NAME,
logits.name,
avg_cost.name, (-1, 1),
channel_axis=1)
attack = BIM(m)
attack_config = {"epsilons": 0.1, "steps": 100}
# use train data to generate adversarial examples
total_count = 0
fooling_count = 0
for data in train_reader():
total_count += 1
adversary = Adversary(data[0][0], data[0][1])
# BIM non-targeted attack
adversary = attack(adversary, **attack_config)
if adversary.is_successful():
fooling_count += 1
print(
'attack success, original_label=%d, adversarial_label=%d, count=%d'
% (data[0][1], adversary.adversarial_label, total_count))
# plt.imshow(adversary.target, cmap='Greys_r')
# plt.show()
# np.save('adv_img', adversary.target)
else:
print('attack failed, original_label=%d, count=%d' %
(data[0][1], total_count))
if total_count >= TOTAL_NUM:
print(
"[TRAIN_DATASET]: fooling_count=%d, total_count=%d, fooling_rate=%f"
% (fooling_count, total_count,
float(fooling_count) / total_count))
break
# use test data to generate adversarial examples
total_count = 0
fooling_count = 0
for data in test_reader():
total_count += 1
adversary = Adversary(data[0][0], data[0][1])
# BIM non-targeted attack
adversary = attack(adversary, **attack_config)
if adversary.is_successful():
fooling_count += 1
print(
'attack success, original_label=%d, adversarial_label=%d, count=%d'
% (data[0][1], adversary.adversarial_label, total_count))
# plt.imshow(adversary.target, cmap='Greys_r')
# plt.show()
# np.save('adv_img', adversary.target)
else:
print('attack failed, original_label=%d, count=%d' %
(data[0][1], total_count))
if total_count >= TOTAL_NUM:
print(
"[TEST_DATASET]: fooling_count=%d, total_count=%d, fooling_rate=%f"
% (fooling_count, total_count,
float(fooling_count) / total_count))
break
print("bim attack done")
if __name__ == '__main__':
main()
"""
DeepFool tutorial on mnist using advbox tool.
DeepFool is a simple and accurate adversarial attack method.
It supports both targeted and non-targeted attacks.
"""
import sys
sys.path.append("..")
import matplotlib.pyplot as plt
import paddle.fluid as fluid
import paddle.v2 as paddle
from advbox.adversary import Adversary
from advbox.attacks.deepfool import DeepFoolAttack
from advbox.models.paddle import PaddleModel
from tutorials.mnist_model import mnist_cnn_model
def main():
"""
Advbox demo which demonstrates how to use Advbox.
"""
TOTAL_NUM = 500
IMG_NAME = 'img'
LABEL_NAME = 'label'
img = fluid.layers.data(name=IMG_NAME, shape=[1, 28, 28], dtype='float32')
# gradient should flow
img.stop_gradient = False
label = fluid.layers.data(name=LABEL_NAME, shape=[1], dtype='int64')
logits = mnist_cnn_model(img)
cost = fluid.layers.cross_entropy(input=logits, label=label)
avg_cost = fluid.layers.mean(x=cost)
# use CPU
place = fluid.CPUPlace()
# use GPU
# place = fluid.CUDAPlace(0)
exe = fluid.Executor(place)
BATCH_SIZE = 1
train_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.mnist.train(), buf_size=128 * 10),
batch_size=BATCH_SIZE)
test_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.mnist.test(), buf_size=128 * 10),
batch_size=BATCH_SIZE)
fluid.io.load_params(
exe, "./mnist/", main_program=fluid.default_main_program())
# advbox demo
m = PaddleModel(
fluid.default_main_program(),
IMG_NAME,
LABEL_NAME,
logits.name,
avg_cost.name, (-1, 1),
channel_axis=1)
attack = DeepFoolAttack(m)
attack_config = {"iterations": 100, "overshoot": 9}
# use train data to generate adversarial examples
total_count = 0
fooling_count = 0
for data in train_reader():
total_count += 1
adversary = Adversary(data[0][0], data[0][1])
# DeepFool non-targeted attack
adversary = attack(adversary, **attack_config)
# DeepFool targeted attack
# tlabel = 0
# adversary.set_target(is_targeted_attack=True, target_label=tlabel)
# adversary = attack(adversary, **attack_config)
if adversary.is_successful():
fooling_count += 1
print(
'attack success, original_label=%d, adversarial_label=%d, count=%d'
% (data[0][1], adversary.adversarial_label, total_count))
# plt.imshow(adversary.target, cmap='Greys_r')
# plt.show()
# np.save('adv_img', adversary.target)
else:
print('attack failed, original_label=%d, count=%d' %
(data[0][1], total_count))
if total_count >= TOTAL_NUM:
print(
"[TRAIN_DATASET]: fooling_count=%d, total_count=%d, fooling_rate=%f"
% (fooling_count, total_count,
float(fooling_count) / total_count))
break
# use test data to generate adversarial examples
total_count = 0
fooling_count = 0
for data in test_reader():
total_count += 1
adversary = Adversary(data[0][0], data[0][1])
# DeepFool non-targeted attack
adversary = attack(adversary, **attack_config)
# DeepFool targeted attack
# tlabel = 0
# adversary.set_target(is_targeted_attack=True, target_label=tlabel)
# adversary = attack(adversary, **attack_config)
if adversary.is_successful():
fooling_count += 1
print(
'attack success, original_label=%d, adversarial_label=%d, count=%d'
% (data[0][1], adversary.adversarial_label, total_count))
# plt.imshow(adversary.target, cmap='Greys_r')
# plt.show()
# np.save('adv_img', adversary.target)
else:
print('attack failed, original_label=%d, count=%d' %
(data[0][1], total_count))
if total_count >= TOTAL_NUM:
print(
"[TEST_DATASET]: fooling_count=%d, total_count=%d, fooling_rate=%f"
% (fooling_count, total_count,
float(fooling_count) / total_count))
break
print("deelfool attack done")
if __name__ == '__main__':
main()
"""
FGSM tutorial on mnist using advbox tool.
The FGSM method is a non-targeted attack, while FGSMT is a targeted attack.
"""
import sys
sys.path.append("..")
import matplotlib.pyplot as plt
import numpy as np
import paddle.fluid as fluid
import paddle.v2 as paddle
from advbox.adversary import Adversary
from advbox.attacks.gradient_method import FGSM
from advbox.attacks.gradient_method import FGSMT
from advbox.models.paddle import PaddleModel
from tutorials.mnist_model import mnist_cnn_model
def main():
"""
Advbox demo which demonstrates how to use Advbox.
"""
TOTAL_NUM = 500
IMG_NAME = 'img'
LABEL_NAME = 'label'
img = fluid.layers.data(name=IMG_NAME, shape=[1, 28, 28], dtype='float32')
# gradient should flow
img.stop_gradient = False
label = fluid.layers.data(name=LABEL_NAME, shape=[1], dtype='int64')
logits = mnist_cnn_model(img)
cost = fluid.layers.cross_entropy(input=logits, label=label)
avg_cost = fluid.layers.mean(x=cost)
# use CPU
place = fluid.CPUPlace()
# use GPU
# place = fluid.CUDAPlace(0)
exe = fluid.Executor(place)
BATCH_SIZE = 1
train_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.mnist.train(), buf_size=128 * 10),
batch_size=BATCH_SIZE)
test_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.mnist.test(), buf_size=128 * 10),
batch_size=BATCH_SIZE)
fluid.io.load_params(
exe, "./mnist/", main_program=fluid.default_main_program())
# advbox demo
m = PaddleModel(
fluid.default_main_program(),
IMG_NAME,
LABEL_NAME,
logits.name,
avg_cost.name, (-1, 1),
channel_axis=1)
attack = FGSM(m)
# attack = FGSMT(m)
attack_config = {"epsilons": 0.3}
# use train data to generate adversarial examples
total_count = 0
fooling_count = 0
for data in train_reader():
total_count += 1
adversary = Adversary(data[0][0], data[0][1])
# FGSM non-targeted attack
adversary = attack(adversary, **attack_config)
# FGSMT targeted attack
# tlabel = 0
# adversary.set_target(is_targeted_attack=True, target_label=tlabel)
# adversary = attack(adversary, **attack_config)
if adversary.is_successful():
fooling_count += 1
print(
'attack success, original_label=%d, adversarial_label=%d, count=%d'
% (data[0][1], adversary.adversarial_label, total_count))
# plt.imshow(adversary.target, cmap='Greys_r')
# plt.show()
# np.save('adv_img', adversary.target)
else:
print('attack failed, original_label=%d, count=%d' %
(data[0][1], total_count))
if total_count >= TOTAL_NUM:
print(
"[TRAIN_DATASET]: fooling_count=%d, total_count=%d, fooling_rate=%f"
% (fooling_count, total_count,
float(fooling_count) / total_count))
break
# use test data to generate adversarial examples
total_count = 0
fooling_count = 0
for data in test_reader():
total_count += 1
adversary = Adversary(data[0][0], data[0][1])
# FGSM non-targeted attack
adversary = attack(adversary, **attack_config)
# FGSMT targeted attack
# tlabel = 0
# adversary.set_target(is_targeted_attack=True, target_label=tlabel)
# adversary = attack(adversary, **attack_config)
if adversary.is_successful():
fooling_count += 1
print(
'attack success, original_label=%d, adversarial_label=%d, count=%d'
% (data[0][1], adversary.adversarial_label, total_count))
# plt.imshow(adversary.target, cmap='Greys_r')
# plt.show()
# np.save('adv_img', adversary.target)
else:
print('attack failed, original_label=%d, count=%d' %
(data[0][1], total_count))
if total_count >= TOTAL_NUM:
print(
"[TEST_DATASET]: fooling_count=%d, total_count=%d, fooling_rate=%f"
% (fooling_count, total_count,
float(fooling_count) / total_count))
break
print("fgsm attack done")
if __name__ == '__main__':
main()
"""
ILCM tutorial on mnist using advbox tool.
The ILCM method extends BIM to support targeted attacks.
"""
import sys
sys.path.append("..")
import matplotlib.pyplot as plt
import paddle.fluid as fluid
import paddle.v2 as paddle
from advbox.adversary import Adversary
from advbox.attacks.gradient_method import ILCM
from advbox.models.paddle import PaddleModel
from tutorials.mnist_model import mnist_cnn_model
def main():
"""
Advbox demo which demonstrates how to use Advbox.
"""
TOTAL_NUM = 500
IMG_NAME = 'img'
LABEL_NAME = 'label'
img = fluid.layers.data(name=IMG_NAME, shape=[1, 28, 28], dtype='float32')
# gradient should flow
img.stop_gradient = False
label = fluid.layers.data(name=LABEL_NAME, shape=[1], dtype='int64')
logits = mnist_cnn_model(img)
cost = fluid.layers.cross_entropy(input=logits, label=label)
avg_cost = fluid.layers.mean(x=cost)
# use CPU
place = fluid.CPUPlace()
# use GPU
# place = fluid.CUDAPlace(0)
exe = fluid.Executor(place)
BATCH_SIZE = 1
train_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.mnist.train(), buf_size=128 * 10),
batch_size=BATCH_SIZE)
test_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.mnist.test(), buf_size=128 * 10),
batch_size=BATCH_SIZE)
fluid.io.load_params(
exe, "./mnist/", main_program=fluid.default_main_program())
# advbox demo
m = PaddleModel(
fluid.default_main_program(),
IMG_NAME,
LABEL_NAME,
logits.name,
avg_cost.name, (-1, 1),
channel_axis=1)
attack = ILCM(m)
attack_config = {"epsilons": 0.1, "steps": 100}
# use train data to generate adversarial examples
total_count = 0
fooling_count = 0
for data in train_reader():
total_count += 1
adversary = Adversary(data[0][0], data[0][1])
tlabel = 0
adversary.set_target(is_targeted_attack=True, target_label=tlabel)
# ILCM targeted attack
adversary = attack(adversary, **attack_config)
if adversary.is_successful():
fooling_count += 1
print(
'attack success, original_label=%d, adversarial_label=%d, count=%d'
% (data[0][1], adversary.adversarial_label, total_count))
# plt.imshow(adversary.target, cmap='Greys_r')
# plt.show()
# np.save('adv_img', adversary.target)
else:
print('attack failed, original_label=%d, count=%d' %
(data[0][1], total_count))
if total_count >= TOTAL_NUM:
print(
"[TRAIN_DATASET]: fooling_count=%d, total_count=%d, fooling_rate=%f"
% (fooling_count, total_count,
float(fooling_count) / total_count))
break
# use test data to generate adversarial examples
total_count = 0
fooling_count = 0
for data in test_reader():
total_count += 1
adversary = Adversary(data[0][0], data[0][1])
tlabel = 0
adversary.set_target(is_targeted_attack=True, target_label=tlabel)
# ILCM targeted attack
adversary = attack(adversary, **attack_config)
if adversary.is_successful():
fooling_count += 1
print(
'attack success, original_label=%d, adversarial_label=%d, count=%d'
% (data[0][1], adversary.adversarial_label, total_count))
# plt.imshow(adversary.target, cmap='Greys_r')
# plt.show()
# np.save('adv_img', adversary.target)
else:
print('attack failed, original_label=%d, count=%d' %
(data[0][1], total_count))
if total_count >= TOTAL_NUM:
print(
"[TEST_DATASET]: fooling_count=%d, total_count=%d, fooling_rate=%f"
% (fooling_count, total_count,
float(fooling_count) / total_count))
break
print("ilcm attack done")
if __name__ == '__main__':
main()
"""
JSMA tutorial on mnist using advbox tool.
The JSMA method supports both targeted and non-targeted attacks.
"""
import sys
sys.path.append("..")
import matplotlib.pyplot as plt
import paddle.fluid as fluid
import paddle.v2 as paddle
from advbox.adversary import Adversary
from advbox.attacks.saliency import JSMA
from advbox.models.paddle import PaddleModel
from tutorials.mnist_model import mnist_cnn_model
def main():
"""
Advbox demo which demonstrates how to use Advbox.
"""
TOTAL_NUM = 500
IMG_NAME = 'img'
LABEL_NAME = 'label'
img = fluid.layers.data(name=IMG_NAME, shape=[1, 28, 28], dtype='float32')
# gradient should flow
img.stop_gradient = False
label = fluid.layers.data(name=LABEL_NAME, shape=[1], dtype='int64')
logits = mnist_cnn_model(img)
cost = fluid.layers.cross_entropy(input=logits, label=label)
avg_cost = fluid.layers.mean(x=cost)
# use CPU
place = fluid.CPUPlace()
# use GPU
# place = fluid.CUDAPlace(0)
exe = fluid.Executor(place)
BATCH_SIZE = 1
train_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.mnist.train(), buf_size=128 * 10),
batch_size=BATCH_SIZE)
test_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.mnist.test(), buf_size=128 * 10),
batch_size=BATCH_SIZE)
fluid.io.load_params(
exe, "./mnist/", main_program=fluid.default_main_program())
# advbox demo
m = PaddleModel(
fluid.default_main_program(),
IMG_NAME,
LABEL_NAME,
logits.name,
avg_cost.name, (-1, 1),
channel_axis=1)
attack = JSMA(m)
attack_config = {
"max_iter": 2000,
"theta": 0.1,
"max_perturbations_per_pixel": 7
}
# use train data to generate adversarial examples
total_count = 0
fooling_count = 0
for data in train_reader():
total_count += 1
adversary = Adversary(data[0][0], data[0][1])
# JSMA non-targeted attack
adversary = attack(adversary, **attack_config)
# JSMA targeted attack
# tlabel = 0
# adversary.set_target(is_targeted_attack=True, target_label=tlabel)
# adversary = attack(adversary, **attack_config)
# JSMA may return None
if adversary is not None and adversary.is_successful():
fooling_count += 1
print(
'attack success, original_label=%d, adversarial_label=%d, count=%d'
% (data[0][1], adversary.adversarial_label, total_count))
# plt.imshow(adversary.target, cmap='Greys_r')
# plt.show()
# np.save('adv_img', adversary.target)
else:
print('attack failed, original_label=%d, count=%d' %
(data[0][1], total_count))
if total_count >= TOTAL_NUM:
print(
"[TRAIN_DATASET]: fooling_count=%d, total_count=%d, fooling_rate=%f"
% (fooling_count, total_count,
float(fooling_count) / total_count))
break
# use test data to generate adversarial examples
total_count = 0
fooling_count = 0
for data in test_reader():
total_count += 1
adversary = Adversary(data[0][0], data[0][1])
# JSMA non-targeted attack
adversary = attack(adversary, **attack_config)
# JSMA targeted attack
# tlabel = 0
# adversary.set_target(is_targeted_attack=True, target_label=tlabel)
# adversary = attack(adversary, **attack_config)
# JSMA may return None
if adversary is not None and adversary.is_successful():
fooling_count += 1
print(
'attack success, original_label=%d, adversarial_label=%d, count=%d'
% (data[0][1], adversary.adversarial_label, total_count))
# plt.imshow(adversary.target, cmap='Greys_r')
# plt.show()
# np.save('adv_img', adversary.target)
else:
print('attack failed, original_label=%d, count=%d' %
(data[0][1], total_count))
if total_count >= TOTAL_NUM:
print(
"[TEST_DATASET]: fooling_count=%d, total_count=%d, fooling_rate=%f"
% (fooling_count, total_count,
float(fooling_count) / total_count))
break
print("jsma attack done")
if __name__ == '__main__':
main()
"""
LBFGS tutorial on mnist using advbox tool.
The LBFGS method only supports targeted attacks.
"""
import sys
sys.path.append("..")
import matplotlib.pyplot as plt
import paddle.fluid as fluid
import paddle.v2 as paddle
from advbox.adversary import Adversary
from advbox.attacks.lbfgs import LBFGS
from advbox.models.paddle import PaddleModel
from tutorials.mnist_model import mnist_cnn_model
def main():
"""
Advbox demo which demonstrates how to use Advbox.
"""
TOTAL_NUM = 500
IMG_NAME = 'img'
LABEL_NAME = 'label'
img = fluid.layers.data(name=IMG_NAME, shape=[1, 28, 28], dtype='float32')
# gradient should flow
img.stop_gradient = False
label = fluid.layers.data(name=LABEL_NAME, shape=[1], dtype='int64')
logits = mnist_cnn_model(img)
cost = fluid.layers.cross_entropy(input=logits, label=label)
avg_cost = fluid.layers.mean(x=cost)
# use CPU
place = fluid.CPUPlace()
# use GPU
# place = fluid.CUDAPlace(0)
exe = fluid.Executor(place)
BATCH_SIZE = 1
train_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.mnist.train(), buf_size=128 * 10),
batch_size=BATCH_SIZE)
test_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.mnist.test(), buf_size=128 * 10),
batch_size=BATCH_SIZE)
fluid.io.load_params(
exe, "./mnist/", main_program=fluid.default_main_program())
# advbox demo
m = PaddleModel(
fluid.default_main_program(),
IMG_NAME,
LABEL_NAME,
logits.name,
avg_cost.name, (-1, 1),
channel_axis=1)
attack = LBFGS(m)
attack_config = {"epsilon": 0.001, }
# use train data to generate adversarial examples
total_count = 0
fooling_count = 0
for data in train_reader():
total_count += 1
adversary = Adversary(data[0][0], data[0][1])
# LBFGS targeted attack
tlabel = 0
adversary.set_target(is_targeted_attack=True, target_label=tlabel)
adversary = attack(adversary, **attack_config)
if adversary.is_successful():
fooling_count += 1
print(
'attack success, original_label=%d, adversarial_label=%d, count=%d'
% (data[0][1], adversary.adversarial_label, total_count))
# plt.imshow(adversary.target, cmap='Greys_r')
# plt.show()
# np.save('adv_img', adversary.target)
else:
print('attack failed, original_label=%d, count=%d' %
(data[0][1], total_count))
if total_count >= TOTAL_NUM:
print(
"[TRAIN_DATASET]: fooling_count=%d, total_count=%d, fooling_rate=%f"
% (fooling_count, total_count,
float(fooling_count) / total_count))
break
# use test data to generate adversarial examples
total_count = 0
fooling_count = 0
for data in test_reader():
total_count += 1
adversary = Adversary(data[0][0], data[0][1])
# LBFGS targeted attack
tlabel = 0
adversary.set_target(is_targeted_attack=True, target_label=tlabel)
adversary = attack(adversary, **attack_config)
if adversary.is_successful():
fooling_count += 1
print(
'attack success, original_label=%d, adversarial_label=%d, count=%d'
% (data[0][1], adversary.adversarial_label, total_count))
# plt.imshow(adversary.target, cmap='Greys_r')
# plt.show()
# np.save('adv_img', adversary.target)
else:
print('attack failed, original_label=%d, count=%d' %
(data[0][1], total_count))
if total_count >= TOTAL_NUM:
print(
"[TEST_DATASET]: fooling_count=%d, total_count=%d, fooling_rate=%f"
% (fooling_count, total_count,
float(fooling_count) / total_count))
break
print("lbfgs attack done")
if __name__ == '__main__':
main()
......@@ -30,6 +30,11 @@ class InferTaskConfig(object):
# the number of decoded sentences to output.
n_best = 1
# the flags indicating whether to output the special tokens.
output_bos = False
output_eos = False
output_unk = False
# the directory for loading the trained model.
model_path = "trained_models/pass_1.infer.model"
......@@ -55,6 +60,8 @@ class ModelHyperParams(object):
bos_idx = 0
# index for <eos> token
eos_idx = 1
# index for <unk> token
unk_idx = 2
# position value corresponding to the <pad> token.
pos_pad_idx = 0
......
......@@ -11,10 +11,26 @@ from config import InferTaskConfig, ModelHyperParams, \
from train import pad_batch_data
def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names,
decoder, dec_in_names, dec_out_names, beam_size, max_length,
n_best, batch_size, n_head, d_model, src_pad_idx,
trg_pad_idx, bos_idx, eos_idx):
def translate_batch(exe,
src_words,
encoder,
enc_in_names,
enc_out_names,
decoder,
dec_in_names,
dec_out_names,
beam_size,
max_length,
n_best,
batch_size,
n_head,
d_model,
src_pad_idx,
trg_pad_idx,
bos_idx,
eos_idx,
unk_idx,
output_unk=True):
"""
Run the encoder program once and run the decoder program multiple times to
implement beam search externally.
......@@ -58,7 +74,7 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names,
# Use active_beams to record the alive instances.
active_beams = range(batch_size)
def beam_backtrace(prev_branchs, next_ids, n_best=beam_size, add_bos=True):
def beam_backtrace(prev_branchs, next_ids, n_best=beam_size):
"""
Decode and select n_best sequences for one instance by backtrace.
"""
......@@ -70,7 +86,8 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names,
seq.append(next_ids[j][k])
k = prev_branchs[j][k]
seq = seq[::-1]
seq = [bos_idx] + seq if add_bos else seq
# Add the <bos>, since next_ids don't include the <bos>.
seq = [bos_idx] + seq
seqs.append(seq)
return seqs
......@@ -132,8 +149,7 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names,
trg_cur_len = trg_slf_attn_bias.shape[-1] + 1
trg_words = np.array(
[
beam_backtrace(
prev_branchs[beam_idx], next_ids[beam_idx], add_bos=True)
beam_backtrace(prev_branchs[beam_idx], next_ids[beam_idx])
for beam_idx in active_beams
],
dtype="int64")
......@@ -187,9 +203,11 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names,
predict_all = np.log(
predict_all.reshape([len(beam_inst_map) * beam_size, i + 1, -1])
[:, -1, :])
predict_all = (predict_all + scores[active_beams].reshape(
predict_all = (predict_all + scores[beam_inst_map].reshape(
[len(beam_inst_map) * beam_size, -1])).reshape(
[len(beam_inst_map), beam_size, -1])
if not output_unk: # To exclude the <unk> token.
predict_all[:, :, unk_idx] = -1e9
active_beams = []
for beam_idx in range(batch_size):
if not beam_inst_map.has_key(beam_idx):
......@@ -283,17 +301,49 @@ def main():
trg_idx2word = paddle.dataset.wmt16.get_dict(
"de", dict_size=ModelHyperParams.trg_vocab_size, reverse=True)
def post_process_seq(seq,
bos_idx=ModelHyperParams.bos_idx,
eos_idx=ModelHyperParams.eos_idx,
output_bos=InferTaskConfig.output_bos,
output_eos=InferTaskConfig.output_eos):
"""
Post-process the beam-search decoded sequence. Truncate from the first
<eos> and remove the <bos> and <eos> tokens currently.
"""
eos_pos = len(seq) - 1
for i, idx in enumerate(seq):
if idx == eos_idx:
eos_pos = i
break
seq = seq[:eos_pos + 1]
return filter(
lambda idx: (output_bos or idx != bos_idx) and \
(output_eos or idx != eos_idx),
seq)
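# For example, with bos_idx=0, eos_idx=1 and both output flags False,
# post_process_seq([0, 5, 7, 1, 9]) returns [5, 7]: the sequence is truncated
# at the first <eos> and the <bos>/<eos> tokens are dropped.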
for batch_id, data in enumerate(test_data()):
batch_seqs, batch_scores = translate_batch(
exe, [item[0] for item in data], encoder_program,
encoder_input_data_names, [enc_output.name], decoder_program,
decoder_input_data_names, [predict.name], InferTaskConfig.beam_size,
InferTaskConfig.max_length, InferTaskConfig.n_best,
len(data), ModelHyperParams.n_head, ModelHyperParams.d_model,
ModelHyperParams.src_pad_idx, ModelHyperParams.trg_pad_idx,
ModelHyperParams.bos_idx, ModelHyperParams.eos_idx)
exe, [item[0] for item in data],
encoder_program,
encoder_input_data_names, [enc_output.name],
decoder_program,
decoder_input_data_names, [predict.name],
InferTaskConfig.beam_size,
InferTaskConfig.max_length,
InferTaskConfig.n_best,
len(data),
ModelHyperParams.n_head,
ModelHyperParams.d_model,
ModelHyperParams.src_pad_idx,
ModelHyperParams.trg_pad_idx,
ModelHyperParams.bos_idx,
ModelHyperParams.eos_idx,
ModelHyperParams.unk_idx,
output_unk=InferTaskConfig.output_unk)
for i in range(len(batch_seqs)):
seqs = batch_seqs[i]
# Post-process the beam-search decoded sequences.
seqs = map(post_process_seq, batch_seqs[i])
scores = batch_scores[i]
for seq in seqs:
print(" ".join([trg_idx2word[idx] for idx in seq]))
......
......@@ -45,7 +45,7 @@ class PolicyGradient:
label=acts) # this is negative log of chosen action
neg_log_prob_weight = fluid.layers.elementwise_mul(x=neg_log_prob, y=vt)
loss = fluid.layers.reduce_mean(
x=neg_log_prob_weight) # reward guided loss
neg_log_prob_weight) # reward guided loss
sgd_optimizer = fluid.optimizer.SGD(self.lr)
sgd_optimizer.minimize(loss)
......