提交 95e30f35 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!102 Avoid error of graph topological order

Merge pull request !102 from pkuliuliu/master
......@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""defense example using nad"""
import os
import sys
import numpy as np
......@@ -19,41 +20,43 @@ from mindspore import Tensor
from mindspore import context
from mindspore import nn
from mindspore.nn import SoftmaxCrossEntropyWithLogits
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train import Model
from mindspore.train.callback import LossMonitor
from lenet5_net import LeNet5
from mindarmour.attacks import FastGradientSignMethod
from mindarmour.defenses import NaturalAdversarialDefense
from mindarmour.utils.logger import LogUtil
from lenet5_net import LeNet5
sys.path.append("..")
from data_processing import generate_mnist_dataset
LOGGER = LogUtil.get_instance()
LOGGER.set_level("INFO")
TAG = 'Nad_Example'
def test_nad_method():
"""
NAD-Defense test for CPU device.
NAD-Defense test.
"""
# 1. load trained network
ckpt_name = './trained_ckpt_file/checkpoint_lenet-10_1875.ckpt'
mnist_path = "./MNIST_unzip/"
batch_size = 32
# 1. train original model
ds_train = generate_mnist_dataset(os.path.join(mnist_path, "train"),
batch_size=batch_size, repeat_size=1)
net = LeNet5()
load_dict = load_checkpoint(ckpt_name)
load_param_into_net(net, load_dict)
loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
opt = nn.Momentum(net.trainable_params(), 0.01, 0.09)
nad = NaturalAdversarialDefense(net, loss_fn=loss, optimizer=opt,
bounds=(0.0, 1.0), eps=0.3)
model = Model(net, loss, opt, metrics=None)
model.train(10, ds_train, callbacks=[LossMonitor()],
dataset_sink_mode=False)
# 2. get test data
data_list = "./MNIST_unzip/test"
batch_size = 32
ds_test = generate_mnist_dataset(data_list, batch_size=batch_size)
ds_test = generate_mnist_dataset(os.path.join(mnist_path, "test"),
batch_size=batch_size, repeat_size=1)
inputs = []
labels = []
for data in ds_test.create_tuple_iterator():
......@@ -73,16 +76,15 @@ def test_nad_method():
label_pred = np.argmax(logits, axis=1)
acc_list.append(np.mean(batch_labels == label_pred))
LOGGER.debug(TAG, 'accuracy of TEST data on original model is : %s',
LOGGER.info(TAG, 'accuracy of TEST data on original model is : %s',
np.mean(acc_list))
# 4. get adv of test data
attack = FastGradientSignMethod(net, eps=0.3, loss_fn=loss)
adv_data = attack.batch_generate(inputs, labels)
LOGGER.debug(TAG, 'adv_data.shape is : %s', adv_data.shape)
LOGGER.info(TAG, 'adv_data.shape is : %s', adv_data.shape)
# 5. get accuracy of adv data on original model
net.set_train(False)
acc_list = []
batchs = adv_data.shape[0] // batch_size
for i in range(batchs):
......@@ -92,11 +94,13 @@ def test_nad_method():
label_pred = np.argmax(logits, axis=1)
acc_list.append(np.mean(batch_labels == label_pred))
LOGGER.debug(TAG, 'accuracy of adv data on original model is : %s',
LOGGER.info(TAG, 'accuracy of adv data on original model is : %s',
np.mean(acc_list))
# 6. defense
net.set_train()
nad = NaturalAdversarialDefense(net, loss_fn=loss, optimizer=opt,
bounds=(0.0, 1.0), eps=0.3)
nad.batch_defense(inputs, labels, batch_size=32, epochs=10)
# 7. get accuracy of test data on defensed model
......@@ -110,7 +114,7 @@ def test_nad_method():
label_pred = np.argmax(logits, axis=1)
acc_list.append(np.mean(batch_labels == label_pred))
LOGGER.debug(TAG, 'accuracy of TEST data on defensed model is : %s',
LOGGER.info(TAG, 'accuracy of TEST data on defensed model is : %s',
np.mean(acc_list))
# 8. get accuracy of adv data on defensed model
......@@ -123,11 +127,11 @@ def test_nad_method():
label_pred = np.argmax(logits, axis=1)
acc_list.append(np.mean(batch_labels == label_pred))
LOGGER.debug(TAG, 'accuracy of adv data on defensed model is : %s',
LOGGER.info(TAG, 'accuracy of adv data on defensed model is : %s',
np.mean(acc_list))
if __name__ == '__main__':
# device_target can be "CPU", "GPU" or "Ascend"
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
test_nad_method()
......@@ -136,6 +136,7 @@ class AdversarialDefenseWithAttacks(AdversarialDefense):
self._replace_ratio = check_param_in_range('replace_ratio',
replace_ratio,
0, 1)
self._graph_initialized = False
def defense(self, inputs, labels):
"""
......@@ -150,6 +151,9 @@ class AdversarialDefenseWithAttacks(AdversarialDefense):
"""
inputs, labels = check_pair_numpy_param('inputs', inputs, 'labels',
labels)
if not self._graph_initialized:
self._train_net(Tensor(inputs), Tensor(labels))
self._graph_initialized = True
x_len = inputs.shape[0]
n_adv = int(np.ceil(self._replace_ratio*x_len))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册