Commit 36c25d9f authored by pkuliuliu

Avoid graph topological order error

Parent 135e7a82
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """defense example using nad"""
+import os
 import sys

 import numpy as np
@@ -19,41 +20,43 @@ from mindspore import Tensor
 from mindspore import context
 from mindspore import nn
 from mindspore.nn import SoftmaxCrossEntropyWithLogits
-from mindspore.train.serialization import load_checkpoint, load_param_into_net
+from mindspore.train import Model
+from mindspore.train.callback import LossMonitor
+from lenet5_net import LeNet5

 from mindarmour.attacks import FastGradientSignMethod
 from mindarmour.defenses import NaturalAdversarialDefense
 from mindarmour.utils.logger import LogUtil

-from lenet5_net import LeNet5
-
 sys.path.append("..")
 from data_processing import generate_mnist_dataset

 LOGGER = LogUtil.get_instance()
+LOGGER.set_level("INFO")
 TAG = 'Nad_Example'


 def test_nad_method():
     """
-    NAD-Defense test for CPU device.
+    NAD-Defense test.
     """
-    # 1. load trained network
-    ckpt_name = './trained_ckpt_file/checkpoint_lenet-10_1875.ckpt'
+    mnist_path = "./MNIST_unzip/"
+    batch_size = 32
+
+    # 1. train original model
+    ds_train = generate_mnist_dataset(os.path.join(mnist_path, "train"),
+                                      batch_size=batch_size, repeat_size=1)
+
     net = LeNet5()
-    load_dict = load_checkpoint(ckpt_name)
-    load_param_into_net(net, load_dict)
-
     loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
     opt = nn.Momentum(net.trainable_params(), 0.01, 0.09)
-
-    nad = NaturalAdversarialDefense(net, loss_fn=loss, optimizer=opt,
-                                    bounds=(0.0, 1.0), eps=0.3)
+    model = Model(net, loss, opt, metrics=None)
+    model.train(10, ds_train, callbacks=[LossMonitor()],
+                dataset_sink_mode=False)

     # 2. get test data
-    data_list = "./MNIST_unzip/test"
-    batch_size = 32
-    ds_test = generate_mnist_dataset(data_list, batch_size=batch_size)
+    ds_test = generate_mnist_dataset(os.path.join(mnist_path, "test"),
+                                     batch_size=batch_size, repeat_size=1)
     inputs = []
     labels = []
     for data in ds_test.create_tuple_iterator():
@@ -73,16 +76,15 @@ def test_nad_method():
         label_pred = np.argmax(logits, axis=1)
         acc_list.append(np.mean(batch_labels == label_pred))

-    LOGGER.debug(TAG, 'accuracy of TEST data on original model is : %s',
-                 np.mean(acc_list))
+    LOGGER.info(TAG, 'accuracy of TEST data on original model is : %s',
+                np.mean(acc_list))

     # 4. get adv of test data
     attack = FastGradientSignMethod(net, eps=0.3, loss_fn=loss)
     adv_data = attack.batch_generate(inputs, labels)
-    LOGGER.debug(TAG, 'adv_data.shape is : %s', adv_data.shape)
+    LOGGER.info(TAG, 'adv_data.shape is : %s', adv_data.shape)

     # 5. get accuracy of adv data on original model
-    net.set_train(False)
     acc_list = []
     batchs = adv_data.shape[0] // batch_size
     for i in range(batchs):
@@ -92,11 +94,13 @@ def test_nad_method():
         label_pred = np.argmax(logits, axis=1)
         acc_list.append(np.mean(batch_labels == label_pred))

-    LOGGER.debug(TAG, 'accuracy of adv data on original model is : %s',
-                 np.mean(acc_list))
+    LOGGER.info(TAG, 'accuracy of adv data on original model is : %s',
+                np.mean(acc_list))

     # 6. defense
     net.set_train()
+    nad = NaturalAdversarialDefense(net, loss_fn=loss, optimizer=opt,
+                                    bounds=(0.0, 1.0), eps=0.3)
     nad.batch_defense(inputs, labels, batch_size=32, epochs=10)

     # 7. get accuracy of test data on defensed model
@@ -110,8 +114,8 @@ def test_nad_method():
         label_pred = np.argmax(logits, axis=1)
         acc_list.append(np.mean(batch_labels == label_pred))

-    LOGGER.debug(TAG, 'accuracy of TEST data on defensed model is : %s',
-                 np.mean(acc_list))
+    LOGGER.info(TAG, 'accuracy of TEST data on defensed model is : %s',
+                np.mean(acc_list))

     # 8. get accuracy of adv data on defensed model
     acc_list = []
@@ -123,11 +127,11 @@ def test_nad_method():
         label_pred = np.argmax(logits, axis=1)
         acc_list.append(np.mean(batch_labels == label_pred))

-    LOGGER.debug(TAG, 'accuracy of adv data on defensed model is : %s',
-                 np.mean(acc_list))
+    LOGGER.info(TAG, 'accuracy of adv data on defensed model is : %s',
+                np.mean(acc_list))


 if __name__ == '__main__':
     # device_target can be "CPU", "GPU" or "Ascend"
-    context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
+    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
     test_nad_method()
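Steps 3, 5, 7 and 8 of the example all repeat the same batched-accuracy loop shown in the hunks above. A small helper along the following lines could compute that score in one place; the name evaluate_accuracy and the net(Tensor(...)).asnumpy() call are illustrative assumptions, not code from this commit:

import numpy as np
from mindspore import Tensor


def evaluate_accuracy(net, inputs, labels, batch_size=32):
    """Hypothetical helper: mean per-batch accuracy of net on numpy data."""
    acc_list = []
    batchs = inputs.shape[0] // batch_size
    for i in range(batchs):
        batch_inputs = inputs[i*batch_size:(i + 1)*batch_size]
        batch_labels = labels[i*batch_size:(i + 1)*batch_size]
        # Forward pass; argmax over the logits gives the predicted class.
        logits = net(Tensor(batch_inputs)).asnumpy()
        label_pred = np.argmax(logits, axis=1)
        acc_list.append(np.mean(batch_labels == label_pred))
    return np.mean(acc_list)

The remaining hunks belong to the defense module itself, in the AdversarialDefenseWithAttacks class.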
@@ -136,6 +136,7 @@ class AdversarialDefenseWithAttacks(AdversarialDefense):
         self._replace_ratio = check_param_in_range('replace_ratio',
                                                    replace_ratio,
                                                    0, 1)
+        self._graph_initialized = False

     def defense(self, inputs, labels):
         """
@@ -150,6 +151,9 @@ class AdversarialDefenseWithAttacks(AdversarialDefense):
         """
         inputs, labels = check_pair_numpy_param('inputs', inputs, 'labels',
                                                 labels)
+        if not self._graph_initialized:
+            self._train_net(Tensor(inputs), Tensor(labels))
+            self._graph_initialized = True
         x_len = inputs.shape[0]
         n_adv = int(np.ceil(self._replace_ratio*x_len))
...
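The guard added to defense() runs the training network once on the first raw batch, so its computation graph is compiled before any adversarial examples are generated against the shared network; that appears to be how the commit avoids the graph topological order error named in its title. A minimal sketch of the same lazy-initialization pattern in isolation (the class name and train_net argument are illustrative, not repository code):

from mindspore import Tensor


class LazyGraphGuard:
    """Illustrative sketch of the one-time graph initialization added above."""

    def __init__(self, train_net):
        self._train_net = train_net        # training cell wrapping the defended model
        self._graph_initialized = False    # graph not compiled yet

    def ensure_initialized(self, inputs, labels):
        # Run one step on plain (non-adversarial) data the first time only,
        # forcing the graph to be built before any attack touches the network.
        if not self._graph_initialized:
            self._train_net(Tensor(inputs), Tensor(labels))
            self._graph_initialized = True

After the first batch the flag check is a cheap no-op, so later defense() calls behave exactly as before.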