提交 343b32a0 编写于 作者: G gx_wind

fix coding standard

上级 8e8e5a89
...@@ -4,6 +4,7 @@ CNN on mnist data using fluid api of paddlepaddle ...@@ -4,6 +4,7 @@ CNN on mnist data using fluid api of paddlepaddle
import paddle.v2 as paddle import paddle.v2 as paddle
import paddle.v2.fluid as fluid import paddle.v2.fluid as fluid
def mnist_cnn_model(img): def mnist_cnn_model(img):
""" """
Mnist cnn model Mnist cnn model
...@@ -31,10 +32,7 @@ def mnist_cnn_model(img): ...@@ -31,10 +32,7 @@ def mnist_cnn_model(img):
pool_stride=2, pool_stride=2,
act='relu') act='relu')
logits = fluid.layers.fc( logits = fluid.layers.fc(input=conv_pool_2, size=10, act='softmax')
input=conv_pool_2,
size=10,
act='softmax')
return logits return logits
...@@ -73,17 +71,19 @@ def main(): ...@@ -73,17 +71,19 @@ def main():
feed=feeder.feed(data), feed=feeder.feed(data),
fetch_list=[avg_cost] + accuracy.metrics) fetch_list=[avg_cost] + accuracy.metrics)
pass_acc = accuracy.eval(exe) pass_acc = accuracy.eval(exe)
print("pass_id=" + str(pass_id) + " acc=" + str(acc) + " pass_acc=" + print("pass_id=" + str(pass_id) + " acc=" + str(acc) + " pass_acc="
str(pass_acc)) + str(pass_acc))
# print loss, acc # print loss, acc
if loss < LOSS_THRESHOLD and pass_acc > ACC_THRESHOLD: if loss < LOSS_THRESHOLD and pass_acc > ACC_THRESHOLD:
# if avg cost less than 10.0 and accuracy is larger than 0.9, we think our code is good. # if avg cost less than 10.0 and accuracy is larger than 0.9, we think our code is good.
break break
# exit(0) # exit(0)
pass_acc = accuracy.eval(exe) pass_acc = accuracy.eval(exe)
print("pass_id=" + str(pass_id) + " pass_acc=" + str(pass_acc)) print("pass_id=" + str(pass_id) + " pass_acc=" + str(pass_acc))
fluid.io.save_params(exe, dirname='./mnist', main_program=fluid.default_main_program()) fluid.io.save_params(
exe, dirname='./mnist', main_program=fluid.default_main_program())
print('train mnist done') print('train mnist done')
exit(1) exit(1)
......
...@@ -9,6 +9,7 @@ import numpy as np ...@@ -9,6 +9,7 @@ import numpy as np
from advbox.models.paddle import PaddleModel from advbox.models.paddle import PaddleModel
from advbox.attacks.gradientsign import GradientSignAttack from advbox.attacks.gradientsign import GradientSignAttack
def cnn_model(img): def cnn_model(img):
""" """
Mnist cnn model Mnist cnn model
...@@ -19,25 +20,22 @@ def cnn_model(img): ...@@ -19,25 +20,22 @@ def cnn_model(img):
""" """
#conv1 = fluid.nets.conv2d() #conv1 = fluid.nets.conv2d()
conv_pool_1 = fluid.nets.simple_img_conv_pool( conv_pool_1 = fluid.nets.simple_img_conv_pool(
input=img, input=img,
num_filters=20, num_filters=20,
filter_size=5, filter_size=5,
pool_size=2, pool_size=2,
pool_stride=2, pool_stride=2,
act='relu') act='relu')
conv_pool_2 = fluid.nets.simple_img_conv_pool( conv_pool_2 = fluid.nets.simple_img_conv_pool(
input=conv_pool_1, input=conv_pool_1,
num_filters=50, num_filters=50,
filter_size=5, filter_size=5,
pool_size=2, pool_size=2,
pool_stride=2, pool_stride=2,
act='relu') act='relu')
logits = fluid.layers.fc( logits = fluid.layers.fc(input=conv_pool_2, size=10, act='softmax')
input=conv_pool_2,
size=10,
act='softmax')
return logits return logits
...@@ -65,22 +63,16 @@ def main(): ...@@ -65,22 +63,16 @@ def main():
paddle.dataset.mnist.train(), buf_size=500), paddle.dataset.mnist.train(), buf_size=500),
batch_size=BATCH_SIZE) batch_size=BATCH_SIZE)
feeder = fluid.DataFeeder( feeder = fluid.DataFeeder(
feed_list=[IMG_NAME, LABEL_NAME], feed_list=[IMG_NAME, LABEL_NAME],
place=place, place=place,
program=fluid.default_main_program() program=fluid.default_main_program())
)
fluid.io.load_params(exe, "./mnist/", main_program=fluid.default_main_program()) fluid.io.load_params(
exe, "./mnist/", main_program=fluid.default_main_program())
# advbox demo # advbox demo
m = PaddleModel( m = PaddleModel(fluid.default_main_program(), IMG_NAME, LABEL_NAME,
fluid.default_main_program(), logits.name, avg_cost.name, (-1, 1))
IMG_NAME,
LABEL_NAME,
logits.name,
avg_cost.name,
(-1, 1)
)
att = GradientSignAttack(m) att = GradientSignAttack(m)
for data in train_reader(): for data in train_reader():
# fgsm attack # fgsm attack
...@@ -89,6 +81,7 @@ def main(): ...@@ -89,6 +81,7 @@ def main():
plt.show() plt.show()
#np.save('adv_img', adv_img) #np.save('adv_img', adv_img)
break break
if __name__ == '__main__': if __name__ == '__main__':
main() main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册