Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
models
提交
b24c68e2
M
models
项目概览
PaddlePaddle
/
models
大约 2 年 前同步成功
通知
232
Star
6828
Fork
2962
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
602
列表
看板
标记
里程碑
合并请求
255
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
models
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
602
Issue
602
列表
看板
标记
里程碑
合并请求
255
合并请求
255
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
b24c68e2
编写于
1月 16, 2018
作者:
G
gx_wind
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
move advbox to paddle model directory
上级
97c27ba8
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
561 addition
and
0 deletion
+561
-0
fluid/adversarial/README.md
fluid/adversarial/README.md
+9
-0
fluid/adversarial/advbox/__init__.py
fluid/adversarial/advbox/__init__.py
+16
-0
fluid/adversarial/advbox/attacks/base.py
fluid/adversarial/advbox/attacks/base.py
+52
-0
fluid/adversarial/advbox/attacks/gradientsign.py
fluid/adversarial/advbox/attacks/gradientsign.py
+51
-0
fluid/adversarial/advbox/models/__init__.py
fluid/adversarial/advbox/models/__init__.py
+16
-0
fluid/adversarial/advbox/models/base.py
fluid/adversarial/advbox/models/base.py
+103
-0
fluid/adversarial/advbox/models/paddle.py
fluid/adversarial/advbox/models/paddle.py
+116
-0
fluid/adversarial/fluid_mnist.py
fluid/adversarial/fluid_mnist.py
+99
-0
fluid/adversarial/mnist_tutorial_fgsm.py
fluid/adversarial/mnist_tutorial_fgsm.py
+99
-0
未找到文件。
fluid/adversarial/README.md
0 → 100644
浏览文件 @
b24c68e2
# Advbox
Advbox is a Python toolbox to create adversarial examples that fool neural networks. It requires Python and paddle.
## How to use
1.
train a model and save it's parameters. (like fluid_mnist.py)
2.
load the parameters which is trained in step1, then reconstruct the model.(like mnist_tutorial_fgsm.py)
3.
use advbox to generate the adversarial sample.
fluid/adversarial/advbox/__init__.py
0 → 100644
浏览文件 @
b24c68e2
# Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
A set of tools for generating adversarial example on paddle platform
"""
fluid/adversarial/advbox/attacks/base.py
0 → 100644
浏览文件 @
b24c68e2
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
"""
The base model of the model.
"""
from
abc
import
ABCMeta
,
abstractmethod
class
Attack
(
object
):
"""
Abstract base class for adversarial attacks. `Attack` represent an adversarial attack
which search an adversarial example. subclass should implement the _apply() method.
Args:
model(Model): an instance of the class advbox.base.Model.
"""
__metaclass__
=
ABCMeta
def
__init__
(
self
,
model
):
self
.
model
=
model
def
__call__
(
self
,
image_label
):
"""
Generate the adversarial sample.
Args:
image_label(list): The image and label tuple list with one element.
"""
adv_img
=
self
.
_apply
(
image_label
)
return
adv_img
@
abstractmethod
def
_apply
(
self
,
image_label
):
"""
Search an adversarial example.
Args:
image_batch(list): The image and label tuple list with one element.
"""
raise
NotImplementedError
fluid/adversarial/advbox/attacks/gradientsign.py
0 → 100644
浏览文件 @
b24c68e2
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
"""
This module provide the attack method for FGSM's implement.
"""
from
__future__
import
division
import
numpy
as
np
from
collections
import
Iterable
from
.base
import
Attack
class
GradientSignAttack
(
Attack
):
"""
This attack was originally implemented by Goodfellow et al. (2015) with the
infinity norm (and is known as the "Fast Gradient Sign Method"). This is therefore called
the Fast Gradient Method.
Paper link: https://arxiv.org/abs/1412.6572
"""
def
_apply
(
self
,
image_label
,
epsilons
=
1000
):
assert
len
(
image_label
)
==
1
pre_label
=
np
.
argmax
(
self
.
model
.
predict
(
image_label
))
min_
,
max_
=
self
.
model
.
bounds
()
gradient
=
self
.
model
.
gradient
(
image_label
)
gradient_sign
=
np
.
sign
(
gradient
)
*
(
max_
-
min_
)
if
not
isinstance
(
epsilons
,
Iterable
):
epsilons
=
np
.
linspace
(
0
,
1
,
num
=
epsilons
+
1
)
for
epsilon
in
epsilons
:
adv_img
=
image_label
[
0
][
0
].
reshape
(
gradient_sign
.
shape
)
+
epsilon
*
gradient_sign
adv_img
=
np
.
clip
(
adv_img
,
min_
,
max_
)
adv_label
=
np
.
argmax
(
self
.
model
.
predict
([(
adv_img
,
0
)]))
if
pre_label
!=
adv_label
:
return
adv_img
FGSM
=
GradientSignAttack
fluid/adversarial/advbox/models/__init__.py
0 → 100644
浏览文件 @
b24c68e2
# Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Paddle model for target of attack
"""
fluid/adversarial/advbox/models/base.py
0 → 100644
浏览文件 @
b24c68e2
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
"""
The base model of the model.
"""
from
abc
import
ABCMeta
import
abc
abstractmethod
=
abc
.
abstractmethod
class
Model
(
object
):
"""
Base class of model to provide attack.
Args:
bounds(tuple): The lower and upper bound for the image pixel.
channel_axis(int): The index of the axis that represents the color channel.
preprocess(tuple): Two element tuple used to preprocess the input. First
substract the first element, then divide the second element.
"""
__metaclass__
=
ABCMeta
def
__init__
(
self
,
bounds
,
channel_axis
,
preprocess
=
None
):
assert
len
(
bounds
)
==
2
assert
channel_axis
in
[
0
,
1
,
2
,
3
]
if
preprocess
is
None
:
preprocess
=
(
0
,
1
)
self
.
_bounds
=
bounds
self
.
_channel_axis
=
channel_axis
self
.
_preprocess
=
preprocess
def
bounds
(
self
):
"""
Return the upper and lower bounds of the model.
"""
return
self
.
_bounds
def
channel_axis
(
self
):
"""
Return the channel axis of the model.
"""
return
self
.
_channel_axis
def
_process_input
(
self
,
input_
):
res
=
input_
sub
,
div
=
self
.
_preprocess
if
sub
!=
0
:
res
=
input_
-
sub
assert
div
!=
0
if
div
!=
1
:
res
/=
div
return
res
@
abstractmethod
def
predict
(
self
,
image_batch
):
"""
Calculate the prediction of the image batch.
Args:
image_batch(numpy.ndarray): image batch of shape (batch_size, height, width, channels).
Return:
numpy.ndarray: predictions of the images with shape (batch_size, num_of_classes).
"""
raise
NotImplementedError
@
abstractmethod
def
num_classes
(
self
):
"""
Determine the number of the classes
Return:
int: the number of the classes
"""
raise
NotImplementedError
@
abstractmethod
def
gradient
(
self
,
image_batch
):
"""
Calculate the gradient of the cross-entropy loss w.r.t the image.
Args:
image_batch(list): The image and label tuple list.
Return:
numpy.ndarray: gradient of the cross-entropy loss w.r.t the image with
the shape (height, width, channel).
"""
raise
NotImplementedError
fluid/adversarial/advbox/models/paddle.py
0 → 100644
浏览文件 @
b24c68e2
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from
__future__
import
absolute_import
import
numpy
as
np
import
paddle.v2
as
paddle
import
paddle.v2.fluid
as
fluid
from
paddle.v2.fluid.framework
import
program_guard
from
.base
import
Model
class
PaddleModel
(
Model
):
"""
Create a PaddleModel instance.
When you need to generate a adversarial sample, you should construct an instance of PaddleModel.
Args:
program(paddle.v2.fluid.framework.Program): The program of the model which generate the adversarial sample.
input_name(string): The name of the input.
logits_name(string): The name of the logits.
predict_name(string): The name of the predict.
cost_name(string): The name of the loss in the program.
"""
def
__init__
(
self
,
program
,
input_name
,
logits_name
,
predict_name
,
cost_name
,
bounds
,
channel_axis
=
3
,
preprocess
=
None
):
super
(
PaddleModel
,
self
).
__init__
(
bounds
=
bounds
,
channel_axis
=
channel_axis
,
preprocess
=
preprocess
)
if
preprocess
is
None
:
preprocess
=
(
0
,
1
)
self
.
_program
=
program
self
.
_place
=
fluid
.
CPUPlace
()
self
.
_exe
=
fluid
.
Executor
(
self
.
_place
)
self
.
_input_name
=
input_name
self
.
_logits_name
=
logits_name
self
.
_predict_name
=
predict_name
self
.
_cost_name
=
cost_name
# gradient
loss
=
self
.
_program
.
block
(
0
).
var
(
self
.
_cost_name
)
param_grads
=
fluid
.
backward
.
append_backward
(
loss
,
parameter_list
=
[
self
.
_input_name
])
self
.
_gradient
=
dict
(
param_grads
)[
self
.
_input_name
]
def
predict
(
self
,
image_batch
):
"""
Predict the label of the image_batch.
Args:
image_batch(list): The image and label tuple list.
Return:
numpy.ndarray: predictions of the images with shape (batch_size, num_of_classes).
"""
feeder
=
fluid
.
DataFeeder
(
feed_list
=
[
self
.
_input_name
,
self
.
_logits_name
],
place
=
self
.
_place
,
program
=
self
.
_program
)
predict_var
=
self
.
_program
.
block
(
0
).
var
(
self
.
_predict_name
)
predict
=
self
.
_exe
.
run
(
self
.
_program
,
feed
=
feeder
.
feed
(
image_batch
),
fetch_list
=
[
predict_var
])
return
predict
def
num_classes
(
self
):
"""
Calculate the number of classes of the output label.
Return:
int: the number of classes
"""
predict_var
=
self
.
_program
.
block
(
0
).
var
(
self
.
_predict_name
)
assert
len
(
predict_var
.
shape
)
==
2
return
predict_var
.
shape
[
1
]
def
gradient
(
self
,
image_batch
):
"""
Calculate the gradient of the loss w.r.t the input.
Args:
image_batch(list): The image and label tuple list.
Return:
list: The list of the gradient of the image.
"""
feeder
=
fluid
.
DataFeeder
(
feed_list
=
[
self
.
_input_name
,
self
.
_logits_name
],
place
=
self
.
_place
,
program
=
self
.
_program
)
grad
,
=
self
.
_exe
.
run
(
self
.
_program
,
feed
=
feeder
.
feed
(
image_batch
),
fetch_list
=
[
self
.
_gradient
])
return
grad
fluid/adversarial/fluid_mnist.py
0 → 100644
浏览文件 @
b24c68e2
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
"""
CNN on mnist data using fluid api of paddlepaddle
"""
import
paddle.v2
as
paddle
import
paddle.v2.fluid
as
fluid
def
mnist_cnn_model
(
img
):
"""
Mnist cnn model
Args:
img(Varaible): the input image to be recognized
Returns:
Variable: the label prediction
"""
conv_pool_1
=
fluid
.
nets
.
simple_img_conv_pool
(
input
=
img
,
num_filters
=
20
,
filter_size
=
5
,
pool_size
=
2
,
pool_stride
=
2
,
act
=
'relu'
)
conv_pool_2
=
fluid
.
nets
.
simple_img_conv_pool
(
input
=
conv_pool_1
,
num_filters
=
50
,
filter_size
=
5
,
pool_size
=
2
,
pool_stride
=
2
,
act
=
'relu'
)
logits
=
fluid
.
layers
.
fc
(
input
=
conv_pool_2
,
size
=
10
,
act
=
'softmax'
)
return
logits
def
main
():
"""
Train the cnn model on mnist datasets
"""
img
=
fluid
.
layers
.
data
(
name
=
'img'
,
shape
=
[
1
,
28
,
28
],
dtype
=
'float32'
)
label
=
fluid
.
layers
.
data
(
name
=
'label'
,
shape
=
[
1
],
dtype
=
'int64'
)
logits
=
mnist_cnn_model
(
img
)
cost
=
fluid
.
layers
.
cross_entropy
(
input
=
logits
,
label
=
label
)
avg_cost
=
fluid
.
layers
.
mean
(
x
=
cost
)
optimizer
=
fluid
.
optimizer
.
Adam
(
learning_rate
=
0.01
)
optimizer
.
minimize
(
avg_cost
)
accuracy
=
fluid
.
evaluator
.
Accuracy
(
input
=
logits
,
label
=
label
)
BATCH_SIZE
=
50
PASS_NUM
=
3
ACC_THRESHOLD
=
0.98
LOSS_THRESHOLD
=
10.0
train_reader
=
paddle
.
batch
(
paddle
.
reader
.
shuffle
(
paddle
.
dataset
.
mnist
.
train
(),
buf_size
=
500
),
batch_size
=
BATCH_SIZE
)
place
=
fluid
.
CPUPlace
()
exe
=
fluid
.
Executor
(
place
)
feeder
=
fluid
.
DataFeeder
(
feed_list
=
[
img
,
label
],
place
=
place
)
exe
.
run
(
fluid
.
default_startup_program
())
for
pass_id
in
range
(
PASS_NUM
):
accuracy
.
reset
(
exe
)
for
data
in
train_reader
():
loss
,
acc
=
exe
.
run
(
fluid
.
default_main_program
(),
feed
=
feeder
.
feed
(
data
),
fetch_list
=
[
avg_cost
]
+
accuracy
.
metrics
)
pass_acc
=
accuracy
.
eval
(
exe
)
print
(
"pass_id="
+
str
(
pass_id
)
+
" acc="
+
str
(
acc
)
+
" pass_acc="
+
str
(
pass_acc
))
if
loss
<
LOSS_THRESHOLD
and
pass_acc
>
ACC_THRESHOLD
:
break
pass_acc
=
accuracy
.
eval
(
exe
)
print
(
"pass_id="
+
str
(
pass_id
)
+
" pass_acc="
+
str
(
pass_acc
))
fluid
.
io
.
save_params
(
exe
,
dirname
=
'./mnist'
,
main_program
=
fluid
.
default_main_program
())
print
(
'train mnist done'
)
if
__name__
==
'__main__'
:
main
()
fluid/adversarial/mnist_tutorial_fgsm.py
0 → 100644
浏览文件 @
b24c68e2
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
"""
FGSM demos on mnist using advbox tool.
"""
import
paddle.v2
as
paddle
import
paddle.v2.fluid
as
fluid
import
matplotlib.pyplot
as
plt
import
numpy
as
np
from
advbox.models.paddle
import
PaddleModel
from
advbox.attacks.gradientsign
import
GradientSignAttack
def
cnn_model
(
img
):
"""
Mnist cnn model
Args:
img(Varaible): the input image to be recognized
Returns:
Variable: the label prediction
"""
#conv1 = fluid.nets.conv2d()
conv_pool_1
=
fluid
.
nets
.
simple_img_conv_pool
(
input
=
img
,
num_filters
=
20
,
filter_size
=
5
,
pool_size
=
2
,
pool_stride
=
2
,
act
=
'relu'
)
conv_pool_2
=
fluid
.
nets
.
simple_img_conv_pool
(
input
=
conv_pool_1
,
num_filters
=
50
,
filter_size
=
5
,
pool_size
=
2
,
pool_stride
=
2
,
act
=
'relu'
)
logits
=
fluid
.
layers
.
fc
(
input
=
conv_pool_2
,
size
=
10
,
act
=
'softmax'
)
return
logits
def
main
():
"""
Advbox demo which demonstrate how to use advbox.
"""
IMG_NAME
=
'img'
LABEL_NAME
=
'label'
img
=
fluid
.
layers
.
data
(
name
=
IMG_NAME
,
shape
=
[
1
,
28
,
28
],
dtype
=
'float32'
)
# gradient should flow
img
.
stop_gradient
=
False
label
=
fluid
.
layers
.
data
(
name
=
LABEL_NAME
,
shape
=
[
1
],
dtype
=
'int64'
)
logits
=
cnn_model
(
img
)
cost
=
fluid
.
layers
.
cross_entropy
(
input
=
logits
,
label
=
label
)
avg_cost
=
fluid
.
layers
.
mean
(
x
=
cost
)
place
=
fluid
.
CPUPlace
()
exe
=
fluid
.
Executor
(
place
)
BATCH_SIZE
=
1
train_reader
=
paddle
.
batch
(
paddle
.
reader
.
shuffle
(
paddle
.
dataset
.
mnist
.
train
(),
buf_size
=
500
),
batch_size
=
BATCH_SIZE
)
feeder
=
fluid
.
DataFeeder
(
feed_list
=
[
IMG_NAME
,
LABEL_NAME
],
place
=
place
,
program
=
fluid
.
default_main_program
())
fluid
.
io
.
load_params
(
exe
,
"./mnist/"
,
main_program
=
fluid
.
default_main_program
())
# advbox demo
m
=
PaddleModel
(
fluid
.
default_main_program
(),
IMG_NAME
,
LABEL_NAME
,
logits
.
name
,
avg_cost
.
name
,
(
-
1
,
1
))
att
=
GradientSignAttack
(
m
)
for
data
in
train_reader
():
# fgsm attack
adv_img
=
att
(
data
)
plt
.
imshow
(
n
[
0
][
0
],
cmap
=
'Greys_r'
)
plt
.
show
()
#np.save('adv_img', adv_img)
break
if
__name__
==
'__main__'
:
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录