Commit 96302d7f authored by zheng-huanhuan

initial version

<!-- Thanks for sending a pull request! Here are some tips for you:
If this is your first time, please read our contributor guidelines: https://gitee.com/mindspore/mindspore/blob/master/CONTRIBUTING.md
-->
**What type of PR is this?**
> Uncomment only one ` /kind <>` line, hit enter to put that in a new line, and remove leading whitespaces from that line:
>
> /kind bug
> /kind task
> /kind feature
**What this PR does / why we need it**:
**Which issue(s) this PR fixes**:
<!--
*Automatically closes linked issue when PR is merged.
Usage: `Fixes #<issue number>`, or `Fixes (paste link of issue)`.
-->
Fixes #
**Special notes for your reviewer**:
*.dot
*.ir
*.dat
*.pyc
*.csv
*.gz
*.tar
*.zip
*.rar
*.ipynb
.idea/
build/
dist/
local_script/
example/dataset/
example/mnist_demo/MNIST_unzip/
example/mnist_demo/trained_ckpt_file/
example/mnist_demo/model/
example/cifar_demo/model/
example/dog_cat_demo/model/
mindarmour.egg-info/
*model/
*MNIST/
*out.data/
*defensed_model/
*pre_trained_model/
*__pycache__/
*kernel_meta
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
MindSpore MindArmour
Copyright 2019-2020 Huawei Technologies Co., Ltd
# MindArmour
- [What is MindArmour](#what-is-mindarmour)
- [Setting up](#setting-up-mindarmour)
- [Docs](#docs)
- [Community](#community)
- [Contributing](#contributing)
- [Release Notes](#release-notes)
- [License](#license)
## What is MindArmour
A toolbox for MindSpore users to enhance model security and trustworthiness.
MindArmour is designed around adversarial examples and includes four submodules: adversarial example generation, adversarial example detection, model defense, and evaluation. The architecture is shown as follows:
![mindarmour_architecture](docs/mindarmour_architecture.png)
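As a quick orientation, the sketch below condenses the FGSM example under `example/mnist_demo/` into the typical generate-then-evaluate flow. It is illustrative only: it feeds random data through an untrained `LeNet5` taken from the example directory, and the module paths match the example scripts in this release.
```python
import numpy as np
from scipy.special import softmax
from mindspore import Model, Tensor, context
from mindarmour.attacks.gradient_method import FastGradientSignMethod
from mindarmour.evaluations.attack_evaluation import AttackEvaluate

from lenet5_net import LeNet5  # network definition from example/mnist_demo

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")

# Random stand-ins for a real MNIST batch; a real run would load a trained
# checkpoint and the MNIST test split as in example/mnist_demo.
net = LeNet5()
model = Model(net)
images = np.random.rand(32, 1, 32, 32).astype(np.float32)             # NCHW
labels = np.eye(10)[np.random.randint(0, 10, 32)].astype(np.float32)  # one-hot

# 1) generate adversarial examples, 2) score them with the evaluation module
attack = FastGradientSignMethod(net, eps=0.3)
adv_images = attack.batch_generate(images, labels, batch_size=32)

adv_logits = softmax(model.predict(Tensor(adv_images)).asnumpy(), axis=1)
report = AttackEvaluate(images.transpose(0, 2, 3, 1), labels,
                        adv_images.transpose(0, 2, 3, 1), adv_logits)
print(report.mis_classification_rate())
```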
## Setting up MindArmour
### Dependencies
This library uses MindSpore to accelerate graph computations performed by many machine learning models. Therefore, installing MindSpore is a prerequisite. All other dependencies are included in `setup.py`.
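A quick way to confirm the prerequisite is in place (a minimal check; the MindSpore version required depends on the MindArmour release):
```python
# Confirm that the MindSpore prerequisite is importable and report its version.
import mindspore
print(mindspore.__version__)
```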
### Installation
#### Installation for development
1. Download source code from Gitee.
```bash
git clone https://gitee.com/mindspore/mindarmour.git
```
2. Compile and install in the MindArmour directory.
```bash
$ cd mindarmour
$ python setup.py install
```
#### `Pip` installation
1. Download the whl package from the [MindSpore website](https://www.mindspore.cn/versions/en), and then run the following command:
```
pip install mindarmour-{version}-cp37-cp37m-linux_{arch}.whl
```
2. The installation is successful if no error message such as `No module named 'mindarmour'` appears when you execute the following command:
```bash
python -c 'import mindarmour'
```
## Docs
For guidance on installation, tutorials, and the API, see our [User Documentation](https://gitee.com/mindspore/docs).
## Community
- [MindSpore Slack](https://join.slack.com/t/mindspore/shared_invite/enQtOTcwMTIxMDI3NjM0LTNkMWM2MzI5NjIyZWU5ZWQ5M2EwMTQ5MWNiYzMxOGM4OWFhZjI4M2E5OGI2YTg3ODU1ODE2Njg1MThiNWI3YmQ) - Ask questions and find answers.
## Contributing
Contributions are welcome. See our [Contributor Wiki](https://gitee.com/mindspore/mindspore/blob/master/CONTRIBUTING.md) for more details.
## Release Notes
For the release notes, see [RELEASE](RELEASE.md).
## License
[Apache License 2.0](LICENSE)
# Release 0.1.0-alpha
Initial release of MindArmour.
## Major Features
- Support adversarial attack and defense on the platform of MindSpore.
- Include 13 white-box and 7 black-box attack methods.
- Provide 5 detection algorithms to detect attacks in multiple ways.
- Provide adversarial training to enhance model security.
- Provide 6 evaluation metrics for attack methods and 9 evaluation metrics for defense methods.
# MindArmour Documentation
The MindArmour documentation is in the [MindSpore Docs](https://gitee.com/mindspore/docs) repository.
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import mindspore.dataset as ds
import mindspore.dataset.transforms.vision.c_transforms as CV
import mindspore.dataset.transforms.c_transforms as C
from mindspore.dataset.transforms.vision import Inter
import mindspore.common.dtype as mstype
def generate_mnist_dataset(data_path, batch_size=32, repeat_size=1,
num_parallel_workers=1, sparse=True):
"""
create dataset for training or testing
"""
# define dataset
ds1 = ds.MnistDataset(data_path)
# define operation parameters
resize_height, resize_width = 32, 32
rescale = 1.0 / 255.0
shift = 0.0
# define map operations
resize_op = CV.Resize((resize_height, resize_width),
interpolation=Inter.LINEAR)
rescale_op = CV.Rescale(rescale, shift)
hwc2chw_op = CV.HWC2CHW()
type_cast_op = C.TypeCast(mstype.int32)
one_hot_enco = C.OneHot(10)
# apply map operations on images
if not sparse:
ds1 = ds1.map(input_columns="label", operations=one_hot_enco,
num_parallel_workers=num_parallel_workers)
type_cast_op = C.TypeCast(mstype.float32)
ds1 = ds1.map(input_columns="label", operations=type_cast_op,
num_parallel_workers=num_parallel_workers)
ds1 = ds1.map(input_columns="image", operations=resize_op,
num_parallel_workers=num_parallel_workers)
ds1 = ds1.map(input_columns="image", operations=rescale_op,
num_parallel_workers=num_parallel_workers)
ds1 = ds1.map(input_columns="image", operations=hwc2chw_op,
num_parallel_workers=num_parallel_workers)
# apply DatasetOps
buffer_size = 10000
ds1 = ds1.shuffle(buffer_size=buffer_size)
ds1 = ds1.batch(batch_size, drop_remainder=True)
ds1 = ds1.repeat(repeat_size)
return ds1
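# Example usage (an illustrative sketch, not part of the original module):
# build the test split with one-hot labels and pull a single batch, assuming
# the MNIST files have been unpacked into ./MNIST_unzip as described in the
# mnist demo README.
if __name__ == '__main__':
    mnist_ds = generate_mnist_dataset("./MNIST_unzip/test", batch_size=32,
                                      sparse=False)
    for image_batch, label_batch in mnist_ds.create_tuple_iterator():
        print(image_batch.shape, label_batch.shape)  # (32, 1, 32, 32) (32, 10)
        break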
# mnist demo
## Introduction
The MNIST database of handwritten digits has a training set of 60,000 examples and a test set of 10,000 examples. It is a subset of a larger set available from NIST. The digits have been size-normalized and centered in a fixed-size image.
## Run demo
### 1. Download dataset
```sh
$ cd example/mnist_demo
$ mkdir MNIST_unzip
$ cd MNIST_unzip
$ mkdir train
$ mkdir test
$ cd train
$ wget "http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz"
$ wget "http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz"
$ gzip train-images-idx3-ubyte.gz -d
$ gzip train-labels-idx1-ubyte.gz -d
$ cd ../test
$ wget "http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz"
$ wget "http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz"
$ gzip t10k-images-idx3-ubyte.gz -d
$ gzip t10k-labels-idx1-ubyte.gz -d
$ cd ../../
```
### 2. Train model
```sh
$ python mnist_train.py
```
### 3. Run attack test
```sh
$ mkdir out.data
$ python mnist_attack_jsma.py
```
### 4. Run defense/detector test
```sh
$ python mnist_defense_nad.py
$ python mnist_similarity_detector.py
```
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import mindspore.nn as nn
import mindspore.ops.operations as P
from mindspore.common.initializer import TruncatedNormal
def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
weight = weight_variable()
return nn.Conv2d(in_channels, out_channels,
kernel_size=kernel_size, stride=stride, padding=padding,
weight_init=weight, has_bias=False, pad_mode="valid")
def fc_with_initialize(input_channels, out_channels):
weight = weight_variable()
bias = weight_variable()
return nn.Dense(input_channels, out_channels, weight, bias)
def weight_variable():
return TruncatedNormal(0.2)
class LeNet5(nn.Cell):
"""
Lenet network
"""
def __init__(self):
super(LeNet5, self).__init__()
self.conv1 = conv(1, 6, 5)
self.conv2 = conv(6, 16, 5)
self.fc1 = fc_with_initialize(16*5*5, 120)
self.fc2 = fc_with_initialize(120, 84)
self.fc3 = fc_with_initialize(84, 10)
self.relu = nn.ReLU()
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
self.reshape = P.Reshape()
def construct(self, x):
x = self.conv1(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.conv2(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.reshape(x, (-1, 16*5*5))
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
x = self.relu(x)
x = self.fc3(x)
return x
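# Smoke test (an illustrative sketch, not part of the original file): build
# the network and run one forward pass on a random MNIST-shaped input; the
# logits should come out with shape (1, 10).
if __name__ == '__main__':
    import numpy as np
    from mindspore import Tensor, context
    context.set_context(mode=context.GRAPH_MODE)
    network = LeNet5()
    dummy_input = Tensor(np.random.rand(1, 1, 32, 32).astype(np.float32))
    print(network(dummy_input).asnumpy().shape)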
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import time
import numpy as np
import pytest
from scipy.special import softmax
from mindspore import Model
from mindspore import Tensor
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindarmour.attacks.carlini_wagner import CarliniWagnerL2Attack
from mindarmour.utils.logger import LogUtil
from mindarmour.evaluations.attack_evaluation import AttackEvaluate
from lenet5_net import LeNet5
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
sys.path.append("..")
from data_processing import generate_mnist_dataset
LOGGER = LogUtil.get_instance()
TAG = 'CW_Test'
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_card
@pytest.mark.component_mindarmour
def test_carlini_wagner_attack():
"""
CW-Attack test
"""
# load trained network
ckpt_name = './trained_ckpt_file/checkpoint_lenet-10_1875.ckpt'
net = LeNet5()
load_dict = load_checkpoint(ckpt_name)
load_param_into_net(net, load_dict)
# get test data
data_list = "./MNIST_unzip/test"
batch_size = 32
ds = generate_mnist_dataset(data_list, batch_size=batch_size)
# prediction accuracy before attack
model = Model(net)
batch_num = 3 # the number of batches of attacking samples
test_images = []
test_labels = []
predict_labels = []
i = 0
for data in ds.create_tuple_iterator():
i += 1
images = data[0].astype(np.float32)
labels = data[1]
test_images.append(images)
test_labels.append(labels)
pred_labels = np.argmax(model.predict(Tensor(images)).asnumpy(),
axis=1)
predict_labels.append(pred_labels)
if i >= batch_num:
break
predict_labels = np.concatenate(predict_labels)
true_labels = np.concatenate(test_labels)
accuracy = np.mean(np.equal(predict_labels, true_labels))
LOGGER.info(TAG, "prediction accuracy before attacking is : %s", accuracy)
# attacking
num_classes = 10
attack = CarliniWagnerL2Attack(net, num_classes, targeted=False)
start_time = time.clock()
adv_data = attack.batch_generate(np.concatenate(test_images),
np.concatenate(test_labels), batch_size=32)
stop_time = time.clock()
pred_logits_adv = model.predict(Tensor(adv_data)).asnumpy()
# rescale predict confidences into (0, 1).
pred_logits_adv = softmax(pred_logits_adv, axis=1)
pred_labels_adv = np.argmax(pred_logits_adv, axis=1)
accuracy_adv = np.mean(np.equal(pred_labels_adv, true_labels))
LOGGER.info(TAG, "prediction accuracy after attacking is : %s",
accuracy_adv)
test_labels = np.eye(10)[np.concatenate(test_labels)]
attack_evaluate = AttackEvaluate(np.concatenate(test_images).transpose(0, 2, 3, 1),
test_labels, adv_data.transpose(0, 2, 3, 1),
pred_logits_adv)
LOGGER.info(TAG, 'mis-classification rate of adversaries is : %s',
attack_evaluate.mis_classification_rate())
LOGGER.info(TAG, 'The average confidence of adversarial class is : %s',
attack_evaluate.avg_conf_adv_class())
LOGGER.info(TAG, 'The average confidence of true class is : %s',
attack_evaluate.avg_conf_true_class())
LOGGER.info(TAG, 'The average distance (l0, l2, linf) between original '
'samples and adversarial samples are: %s',
attack_evaluate.avg_lp_distance())
LOGGER.info(TAG, 'The average structural similarity between original '
'samples and adversarial samples are: %s',
attack_evaluate.avg_ssim())
LOGGER.info(TAG, 'The average time cost per sample is %s',
(stop_time - start_time)/(batch_num*batch_size))
if __name__ == '__main__':
test_carlini_wagner_attack()
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import time
import numpy as np
import pytest
from scipy.special import softmax
from mindspore import Model
from mindspore import Tensor
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindarmour.attacks.deep_fool import DeepFool
from mindarmour.utils.logger import LogUtil
from mindarmour.evaluations.attack_evaluation import AttackEvaluate
from lenet5_net import LeNet5
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
sys.path.append("..")
from data_processing import generate_mnist_dataset
LOGGER = LogUtil.get_instance()
TAG = 'DeepFool_Test'
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_card
@pytest.mark.component_mindarmour
def test_deepfool_attack():
"""
DeepFool-Attack test
"""
# load trained network
ckpt_name = './trained_ckpt_file/checkpoint_lenet-10_1875.ckpt'
net = LeNet5()
load_dict = load_checkpoint(ckpt_name)
load_param_into_net(net, load_dict)
# get test data
data_list = "./MNIST_unzip/test"
batch_size = 32
ds = generate_mnist_dataset(data_list, batch_size=batch_size)
# prediction accuracy before attack
model = Model(net)
batch_num = 3 # the number of batches of attacking samples
test_images = []
test_labels = []
predict_labels = []
i = 0
for data in ds.create_tuple_iterator():
i += 1
images = data[0].astype(np.float32)
labels = data[1]
test_images.append(images)
test_labels.append(labels)
pred_labels = np.argmax(model.predict(Tensor(images)).asnumpy(),
axis=1)
predict_labels.append(pred_labels)
if i >= batch_num:
break
predict_labels = np.concatenate(predict_labels)
true_labels = np.concatenate(test_labels)
accuracy = np.mean(np.equal(predict_labels, true_labels))
LOGGER.info(TAG, "prediction accuracy before attacking is : %s", accuracy)
# attacking
classes = 10
attack = DeepFool(net, classes, norm_level=2,
bounds=(0.0, 1.0))
start_time = time.clock()
adv_data = attack.batch_generate(np.concatenate(test_images),
np.concatenate(test_labels), batch_size=32)
stop_time = time.clock()
pred_logits_adv = model.predict(Tensor(adv_data)).asnumpy()
# rescale predict confidences into (0, 1).
pred_logits_adv = softmax(pred_logits_adv, axis=1)
pred_labels_adv = np.argmax(pred_logits_adv, axis=1)
accuracy_adv = np.mean(np.equal(pred_labels_adv, true_labels))
LOGGER.info(TAG, "prediction accuracy after attacking is : %s",
accuracy_adv)
test_labels = np.eye(10)[np.concatenate(test_labels)]
attack_evaluate = AttackEvaluate(np.concatenate(test_images).transpose(0, 2, 3, 1),
test_labels, adv_data.transpose(0, 2, 3, 1),
pred_logits_adv)
LOGGER.info(TAG, 'mis-classification rate of adversaries is : %s',
attack_evaluate.mis_classification_rate())
LOGGER.info(TAG, 'The average confidence of adversarial class is : %s',
attack_evaluate.avg_conf_adv_class())
LOGGER.info(TAG, 'The average confidence of true class is : %s',
attack_evaluate.avg_conf_true_class())
LOGGER.info(TAG, 'The average distance (l0, l2, linf) between original '
'samples and adversarial samples are: %s',
attack_evaluate.avg_lp_distance())
LOGGER.info(TAG, 'The average structural similarity between original '
'samples and adversarial samples are: %s',
attack_evaluate.avg_ssim())
LOGGER.info(TAG, 'The average time cost per sample is %s',
(stop_time - start_time)/(batch_num*batch_size))
if __name__ == '__main__':
test_deepfool_attack()
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import time
import numpy as np
import pytest
from scipy.special import softmax
from mindspore import Model
from mindspore import Tensor
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindarmour.attacks.gradient_method import FastGradientSignMethod
from mindarmour.utils.logger import LogUtil
from mindarmour.evaluations.attack_evaluation import AttackEvaluate
from lenet5_net import LeNet5
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
sys.path.append("..")
from data_processing import generate_mnist_dataset
LOGGER = LogUtil.get_instance()
TAG = 'FGSM_Test'
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_card
@pytest.mark.component_mindarmour
def test_fast_gradient_sign_method():
"""
FGSM-Attack test
"""
# load trained network
ckpt_name = './trained_ckpt_file/checkpoint_lenet-10_1875.ckpt'
net = LeNet5()
load_dict = load_checkpoint(ckpt_name)
load_param_into_net(net, load_dict)
# get test data
data_list = "./MNIST_unzip/test"
batch_size = 32
ds = generate_mnist_dataset(data_list, batch_size, sparse=False)
# prediction accuracy before attack
model = Model(net)
batch_num = 3 # the number of batches of attacking samples
test_images = []
test_labels = []
predict_labels = []
i = 0
for data in ds.create_tuple_iterator():
i += 1
images = data[0].astype(np.float32)
labels = data[1]
test_images.append(images)
test_labels.append(labels)
pred_labels = np.argmax(model.predict(Tensor(images)).asnumpy(),
axis=1)
predict_labels.append(pred_labels)
if i >= batch_num:
break
predict_labels = np.concatenate(predict_labels)
true_labels = np.argmax(np.concatenate(test_labels), axis=1)
accuracy = np.mean(np.equal(predict_labels, true_labels))
LOGGER.info(TAG, "prediction accuracy before attacking is : %s", accuracy)
# attacking
attack = FastGradientSignMethod(net, eps=0.3)
start_time = time.clock()
adv_data = attack.batch_generate(np.concatenate(test_images),
np.concatenate(test_labels), batch_size=32)
stop_time = time.clock()
np.save('./adv_data', adv_data)
pred_logits_adv = model.predict(Tensor(adv_data)).asnumpy()
# rescale predict confidences into (0, 1).
pred_logits_adv = softmax(pred_logits_adv, axis=1)
pred_labels_adv = np.argmax(pred_logits_adv, axis=1)
accuracy_adv = np.mean(np.equal(pred_labels_adv, true_labels))
LOGGER.info(TAG, "prediction accuracy after attacking is : %s", accuracy_adv)
attack_evaluate = AttackEvaluate(np.concatenate(test_images).transpose(0, 2, 3, 1),
np.concatenate(test_labels),
adv_data.transpose(0, 2, 3, 1),
pred_logits_adv)
LOGGER.info(TAG, 'mis-classification rate of adversaries is : %s',
attack_evaluate.mis_classification_rate())
LOGGER.info(TAG, 'The average confidence of adversarial class is : %s',
attack_evaluate.avg_conf_adv_class())
LOGGER.info(TAG, 'The average confidence of true class is : %s',
attack_evaluate.avg_conf_true_class())
LOGGER.info(TAG, 'The average distance (l0, l2, linf) between original '
'samples and adversarial samples are: %s',
attack_evaluate.avg_lp_distance())
LOGGER.info(TAG, 'The average structural similarity between original '
'samples and adversarial samples are: %s',
attack_evaluate.avg_ssim())
LOGGER.info(TAG, 'The average time cost per sample is %s',
(stop_time - start_time)/(batch_num*batch_size))
if __name__ == '__main__':
test_fast_gradient_sign_method()
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import time
import numpy as np
import pytest
from scipy.special import softmax
from mindspore import Tensor
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindarmour.attacks.black.genetic_attack import GeneticAttack
from mindarmour.attacks.black.black_model import BlackModel
from mindarmour.utils.logger import LogUtil
from mindarmour.evaluations.attack_evaluation import AttackEvaluate
from lenet5_net import LeNet5
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
sys.path.append("..")
from data_processing import generate_mnist_dataset
LOGGER = LogUtil.get_instance()
TAG = 'Genetic_Attack'
class ModelToBeAttacked(BlackModel):
"""model to be attack"""
def __init__(self, network):
super(ModelToBeAttacked, self).__init__()
self._network = network
def predict(self, inputs):
"""predict"""
result = self._network(Tensor(inputs.astype(np.float32)))
return result.asnumpy()
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_card
@pytest.mark.component_mindarmour
def test_genetic_attack_on_mnist():
"""
Genetic-Attack test
"""
# load trained network
ckpt_name = './trained_ckpt_file/checkpoint_lenet-10_1875.ckpt'
net = LeNet5()
load_dict = load_checkpoint(ckpt_name)
load_param_into_net(net, load_dict)
# get test data
data_list = "./MNIST_unzip/test"
batch_size = 32
ds = generate_mnist_dataset(data_list, batch_size=batch_size)
# prediction accuracy before attack
model = ModelToBeAttacked(net)
batch_num = 3 # the number of batches of attacking samples
test_images = []
test_labels = []
predict_labels = []
i = 0
for data in ds.create_tuple_iterator():
i += 1
images = data[0].astype(np.float32)
labels = data[1]
test_images.append(images)
test_labels.append(labels)
pred_labels = np.argmax(model.predict(images), axis=1)
predict_labels.append(pred_labels)
if i >= batch_num:
break
predict_labels = np.concatenate(predict_labels)
true_labels = np.concatenate(test_labels)
accuracy = np.mean(np.equal(predict_labels, true_labels))
LOGGER.info(TAG, "prediction accuracy before attacking is : %g", accuracy)
# attacking
attack = GeneticAttack(model=model, pop_size=6, mutation_rate=0.05,
per_bounds=0.1, step_size=0.25, temp=0.1,
sparse=True)
targeted_labels = np.random.randint(0, 10, size=len(true_labels))
for i in range(len(true_labels)):
if targeted_labels[i] == true_labels[i]:
targeted_labels[i] = (targeted_labels[i] + 1) % 10
start_time = time.clock()
success_list, adv_data, query_list = attack.generate(
np.concatenate(test_images), targeted_labels)
stop_time = time.clock()
LOGGER.info(TAG, 'success_list: %s', success_list)
LOGGER.info(TAG, 'average of query times is : %s', np.mean(query_list))
pred_logits_adv = model.predict(adv_data)
# rescale predict confidences into (0, 1).
pred_logits_adv = softmax(pred_logits_adv, axis=1)
pred_labels_adv = np.argmax(pred_logits_adv, axis=1)
accuracy_adv = np.mean(np.equal(pred_labels_adv, true_labels))
LOGGER.info(TAG, "prediction accuracy after attacking is : %g",
accuracy_adv)
test_labels_onehot = np.eye(10)[true_labels]
attack_evaluate = AttackEvaluate(np.concatenate(test_images),
test_labels_onehot, adv_data,
pred_logits_adv, targeted=True,
target_label=targeted_labels)
LOGGER.info(TAG, 'mis-classification rate of adversaries is : %s',
attack_evaluate.mis_classification_rate())
LOGGER.info(TAG, 'The average confidence of adversarial class is : %s',
attack_evaluate.avg_conf_adv_class())
LOGGER.info(TAG, 'The average confidence of true class is : %s',
attack_evaluate.avg_conf_true_class())
LOGGER.info(TAG, 'The average distance (l0, l2, linf) between original '
'samples and adversarial samples are: %s',
attack_evaluate.avg_lp_distance())
LOGGER.info(TAG, 'The average structural similarity between original '
'samples and adversarial samples are: %s',
attack_evaluate.avg_ssim())
LOGGER.info(TAG, 'The average time cost per sample is %s',
(stop_time - start_time)/(batch_num*batch_size))
if __name__ == '__main__':
test_genetic_attack_on_mnist()
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import numpy as np
import pytest
from mindspore import Tensor
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindarmour.attacks.black.hop_skip_jump_attack import HopSkipJumpAttack
from mindarmour.attacks.black.black_model import BlackModel
from mindarmour.utils.logger import LogUtil
from lenet5_net import LeNet5
sys.path.append("..")
from data_processing import generate_mnist_dataset
context.set_context(mode=context.GRAPH_MODE)
context.set_context(device_target="Ascend")
LOGGER = LogUtil.get_instance()
TAG = 'HopSkipJumpAttack'
class ModelToBeAttacked(BlackModel):
"""model to be attack"""
def __init__(self, network):
super(ModelToBeAttacked, self).__init__()
self._network = network
def predict(self, inputs):
"""predict"""
if len(inputs.shape) == 3:
inputs = inputs[np.newaxis, :]
result = self._network(Tensor(inputs.astype(np.float32)))
return result.asnumpy()
def random_target_labels(true_labels):
target_labels = []
for label in true_labels:
while True:
target_label = np.random.randint(0, 10)
if target_label != label:
target_labels.append(target_label)
break
return target_labels
def create_target_images(dataset, data_labels, target_labels):
res = []
for label in target_labels:
for i in range(len(data_labels)):
if data_labels[i] == label:
res.append(dataset[i])
break
return np.array(res)
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_card
@pytest.mark.component_mindarmour
def test_hsja_mnist_attack():
"""
HSJA-Attack test
"""
# load trained network
ckpt_name = './trained_ckpt_file/checkpoint_lenet-10_1875.ckpt'
net = LeNet5()
load_dict = load_checkpoint(ckpt_name)
load_param_into_net(net, load_dict)
net.set_train(False)
# get test data
data_list = "./MNIST_unzip/test"
batch_size = 32
ds = generate_mnist_dataset(data_list, batch_size=batch_size)
# prediction accuracy before attack
model = ModelToBeAttacked(net)
batch_num = 5 # the number of batches of attacking samples
test_images = []
test_labels = []
predict_labels = []
i = 0
for data in ds.create_tuple_iterator():
i += 1
images = data[0].astype(np.float32)
labels = data[1]
test_images.append(images)
test_labels.append(labels)
pred_labels = np.argmax(model.predict(images), axis=1)
predict_labels.append(pred_labels)
if i >= batch_num:
break
predict_labels = np.concatenate(predict_labels)
true_labels = np.concatenate(test_labels)
accuracy = np.mean(np.equal(predict_labels, true_labels))
LOGGER.info(TAG, "prediction accuracy before attacking is : %s",
accuracy)
test_images = np.concatenate(test_images)
# attacking
norm = 'l2'
search = 'grid_search'
target = False
attack = HopSkipJumpAttack(model, constraint=norm, stepsize_search=search)
if target:
target_labels = random_target_labels(true_labels)
target_images = create_target_images(test_images, predict_labels,
target_labels)
attack.set_target_images(target_images)
success_list, adv_data, query_list = attack.generate(test_images, target_labels)
else:
success_list, adv_data, query_list = attack.generate(test_images, None)
adv_datas = []
gts = []
for success, adv, gt in zip(success_list, adv_data, true_labels):
if success:
adv_datas.append(adv)
gts.append(gt)
if len(gts) > 0:
adv_datas = np.concatenate(np.asarray(adv_datas), axis=0)
gts = np.asarray(gts)
pred_logits_adv = model.predict(adv_datas)
pred_labels_adv = np.argmax(pred_logits_adv, axis=1)
accuracy_adv = np.mean(np.equal(pred_labels_adv, gts))
LOGGER.info(TAG, 'mis-classification rate of adversaries is : %s',
accuracy_adv)
if __name__ == '__main__':
test_hsja_mnist_attack()
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import time
import numpy as np
import pytest
from scipy.special import softmax
from mindspore import Model
from mindspore import Tensor
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindarmour.attacks.jsma import JSMAAttack
from mindarmour.utils.logger import LogUtil
from mindarmour.evaluations.attack_evaluation import AttackEvaluate
from lenet5_net import LeNet5
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
sys.path.append("..")
from data_processing import generate_mnist_dataset
LOGGER = LogUtil.get_instance()
TAG = 'JSMA_Test'
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_card
@pytest.mark.component_mindarmour
def test_jsma_attack():
"""
JSMA-Attack test
"""
# load trained network
ckpt_name = './trained_ckpt_file/checkpoint_lenet-10_1875.ckpt'
net = LeNet5()
load_dict = load_checkpoint(ckpt_name)
load_param_into_net(net, load_dict)
# get test data
data_list = "./MNIST_unzip/test"
batch_size = 32
ds = generate_mnist_dataset(data_list, batch_size=batch_size)
# prediction accuracy before attack
model = Model(net)
batch_num = 3 # the number of batches of attacking samples
test_images = []
test_labels = []
predict_labels = []
i = 0
for data in ds.create_tuple_iterator():
i += 1
images = data[0].astype(np.float32)
labels = data[1]
test_images.append(images)
test_labels.append(labels)
pred_labels = np.argmax(model.predict(Tensor(images)).asnumpy(),
axis=1)
predict_labels.append(pred_labels)
if i >= batch_num:
break
predict_labels = np.concatenate(predict_labels)
true_labels = np.concatenate(test_labels)
targeted_labels = np.random.randint(0, 10, size=len(true_labels))
for i in range(len(true_labels)):
if targeted_labels[i] == true_labels[i]:
targeted_labels[i] = (targeted_labels[i] + 1) % 10
accuracy = np.mean(np.equal(predict_labels, true_labels))
LOGGER.info(TAG, "prediction accuracy before attacking is : %g", accuracy)
# attacking
classes = 10
attack = JSMAAttack(net, classes)
start_time = time.clock()
adv_data = attack.batch_generate(np.concatenate(test_images),
targeted_labels, batch_size=32)
stop_time = time.clock()
pred_logits_adv = model.predict(Tensor(adv_data)).asnumpy()
# rescale predict confidences into (0, 1).
pred_logits_adv = softmax(pred_logits_adv, axis=1)
pred_labels_adv = np.argmax(pred_logits_adv, axis=1)
accuracy_adv = np.mean(np.equal(pred_labels_adv, true_labels))
LOGGER.info(TAG, "prediction accuracy after attacking is : %g",
accuracy_adv)
test_labels = np.eye(10)[np.concatenate(test_labels)]
attack_evaluate = AttackEvaluate(
np.concatenate(test_images).transpose(0, 2, 3, 1),
test_labels, adv_data.transpose(0, 2, 3, 1),
pred_logits_adv, targeted=True, target_label=targeted_labels)
LOGGER.info(TAG, 'mis-classification rate of adversaries is : %s',
attack_evaluate.mis_classification_rate())
LOGGER.info(TAG, 'The average confidence of adversarial class is : %s',
attack_evaluate.avg_conf_adv_class())
LOGGER.info(TAG, 'The average confidence of true class is : %s',
attack_evaluate.avg_conf_true_class())
LOGGER.info(TAG, 'The average distance (l0, l2, linf) between original '
'samples and adversarial samples are: %s',
attack_evaluate.avg_lp_distance())
LOGGER.info(TAG, 'The average structural similarity between original '
'samples and adversarial samples are: %s',
attack_evaluate.avg_ssim())
LOGGER.info(TAG, 'The average time cost per sample is %s',
(stop_time - start_time) / (batch_num*batch_size))
if __name__ == '__main__':
test_jsma_attack()
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import time
import numpy as np
import pytest
from scipy.special import softmax
from mindspore import Model
from mindspore import Tensor
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindarmour.attacks.lbfgs import LBFGS
from mindarmour.utils.logger import LogUtil
from mindarmour.evaluations.attack_evaluation import AttackEvaluate
from lenet5_net import LeNet5
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
sys.path.append("..")
from data_processing import generate_mnist_dataset
LOGGER = LogUtil.get_instance()
TAG = 'LBFGS_Test'
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_card
@pytest.mark.component_mindarmour
def test_lbfgs_attack():
"""
LBFGS-Attack test
"""
# load trained network
ckpt_name = './trained_ckpt_file/checkpoint_lenet-10_1875.ckpt'
net = LeNet5()
load_dict = load_checkpoint(ckpt_name)
load_param_into_net(net, load_dict)
# get test data
data_list = "./MNIST_unzip/test"
batch_size = 32
ds = generate_mnist_dataset(data_list, batch_size=batch_size, sparse=False)
# prediction accuracy before attack
model = Model(net)
batch_num = 3 # the number of batches of attacking samples
test_images = []
test_labels = []
predict_labels = []
i = 0
for data in ds.create_tuple_iterator():
i += 1
images = data[0].astype(np.float32)
labels = data[1]
test_images.append(images)
test_labels.append(labels)
pred_labels = np.argmax(model.predict(Tensor(images)).asnumpy(),
axis=1)
predict_labels.append(pred_labels)
if i >= batch_num:
break
predict_labels = np.concatenate(predict_labels)
true_labels = np.argmax(np.concatenate(test_labels), axis=1)
accuracy = np.mean(np.equal(predict_labels, true_labels))
LOGGER.info(TAG, "prediction accuracy before attacking is : %s", accuracy)
# attacking
is_targeted = True
if is_targeted:
targeted_labels = np.random.randint(0, 10, size=len(true_labels)).astype(np.int32)
for i in range(len(true_labels)):
if targeted_labels[i] == true_labels[i]:
targeted_labels[i] = (targeted_labels[i] + 1) % 10
else:
targeted_labels = true_labels.astype(np.int32)
targeted_labels = np.eye(10)[targeted_labels].astype(np.float32)
attack = LBFGS(net, is_targeted=is_targeted)
start_time = time.clock()
adv_data = attack.batch_generate(np.concatenate(test_images),
targeted_labels,
batch_size=batch_size)
stop_time = time.clock()
pred_logits_adv = model.predict(Tensor(adv_data)).asnumpy()
# rescale predict confidences into (0, 1).
pred_logits_adv = softmax(pred_logits_adv, axis=1)
pred_labels_adv = np.argmax(pred_logits_adv, axis=1)
accuracy_adv = np.mean(np.equal(pred_labels_adv, true_labels))
LOGGER.info(TAG, "prediction accuracy after attacking is : %s",
accuracy_adv)
attack_evaluate = AttackEvaluate(np.concatenate(test_images).transpose(0, 2, 3, 1),
np.concatenate(test_labels),
adv_data.transpose(0, 2, 3, 1),
pred_logits_adv,
targeted=is_targeted,
target_label=np.argmax(targeted_labels,
axis=1))
LOGGER.info(TAG, 'mis-classification rate of adversaries is : %s',
attack_evaluate.mis_classification_rate())
LOGGER.info(TAG, 'The average confidence of adversarial class is : %s',
attack_evaluate.avg_conf_adv_class())
LOGGER.info(TAG, 'The average confidence of true class is : %s',
attack_evaluate.avg_conf_true_class())
LOGGER.info(TAG, 'The average distance (l0, l2, linf) between original '
'samples and adversarial samples are: %s',
attack_evaluate.avg_lp_distance())
LOGGER.info(TAG, 'The average structural similarity between original '
'samples and adversarial samples are: %s',
attack_evaluate.avg_ssim())
LOGGER.info(TAG, 'The average time cost per sample is %s',
(stop_time - start_time)/(batch_num*batch_size))
if __name__ == '__main__':
test_lbfgs_attack()
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import numpy as np
import pytest
from mindspore import Tensor
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindarmour.attacks.black.natural_evolutionary_strategy import NES
from mindarmour.attacks.black.black_model import BlackModel
from mindarmour.utils.logger import LogUtil
from lenet5_net import LeNet5
sys.path.append("..")
from data_processing import generate_mnist_dataset
context.set_context(mode=context.GRAPH_MODE)
context.set_context(device_target="Ascend")
LOGGER = LogUtil.get_instance()
TAG = 'NES_Attack'
class ModelToBeAttacked(BlackModel):
"""model to be attack"""
def __init__(self, network):
super(ModelToBeAttacked, self).__init__()
self._network = network
def predict(self, inputs):
"""predict"""
if len(inputs.shape) == 3:
inputs = inputs[np.newaxis, :]
result = self._network(Tensor(inputs.astype(np.float32)))
return result.asnumpy()
def random_target_labels(true_labels, labels_list):
target_labels = []
for label in true_labels:
while True:
target_label = np.random.choice(labels_list)
if target_label != label:
target_labels.append(target_label)
break
return target_labels
def _pseudorandom_target(index, total_indices, true_class):
""" pseudo random_target """
rng = np.random.RandomState(index)
target = true_class
while target == true_class:
target = rng.randint(0, total_indices)
return target
def create_target_images(dataset, data_labels, target_labels):
res = []
for label in target_labels:
for i in range(len(data_labels)):
if data_labels[i] == label:
res.append(dataset[i])
break
return np.array(res)
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_card
@pytest.mark.component_mindarmour
def test_nes_mnist_attack():
"""
NES-Attack test
"""
# load trained network
ckpt_name = './trained_ckpt_file/checkpoint_lenet-10_1875.ckpt'
net = LeNet5()
load_dict = load_checkpoint(ckpt_name)
load_param_into_net(net, load_dict)
net.set_train(False)
# get test data
data_list = "./MNIST_unzip/test"
batch_size = 32
ds = generate_mnist_dataset(data_list, batch_size=batch_size)
# prediction accuracy before attack
model = ModelToBeAttacked(net)
# the number of batches of attacking samples
batch_num = 5
test_images = []
test_labels = []
predict_labels = []
i = 0
for data in ds.create_tuple_iterator():
i += 1
images = data[0].astype(np.float32)
labels = data[1]
test_images.append(images)
test_labels.append(labels)
pred_labels = np.argmax(model.predict(images), axis=1)
predict_labels.append(pred_labels)
if i >= batch_num:
break
predict_labels = np.concatenate(predict_labels)
true_labels = np.concatenate(test_labels)
accuracy = np.mean(np.equal(predict_labels, true_labels))
LOGGER.info(TAG, "prediction accuracy before attacking is : %s",
accuracy)
test_images = np.concatenate(test_images)
# attacking
scene = 'Query_Limit'
if scene == 'Query_Limit':
top_k = -1
elif scene == 'Partial_Info':
top_k = 5
elif scene == 'Label_Only':
top_k = 5
success = 0
queries_num = 0
nes_instance = NES(model, scene, top_k=top_k)
test_length = 32
advs = []
for img_index in range(test_length):
# Initial image and class selection
initial_img = test_images[img_index]
orig_class = true_labels[img_index]
initial_img = [initial_img]
target_class = random_target_labels([orig_class], true_labels)
target_image = create_target_images(test_images, true_labels,
target_class)
nes_instance.set_target_images(target_image)
tag, adv, queries = nes_instance.generate(initial_img, target_class)
if tag[0]:
success += 1
queries_num += queries[0]
advs.append(adv)
advs = np.reshape(advs, (len(advs), 1, 32, 32))
adv_pred = np.argmax(model.predict(advs), axis=1)
adv_accuracy = np.mean(np.equal(adv_pred, true_labels[:test_length]))
LOGGER.info(TAG, "prediction accuracy after attacking is : %s",
adv_accuracy)
if __name__ == '__main__':
test_nes_mnist_attack()
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import time
import numpy as np
import pytest
from scipy.special import softmax
from mindspore import Model
from mindspore import Tensor
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindarmour.attacks.iterative_gradient_method import ProjectedGradientDescent
from mindarmour.utils.logger import LogUtil
from mindarmour.evaluations.attack_evaluation import AttackEvaluate
from lenet5_net import LeNet5
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
sys.path.append("..")
from data_processing import generate_mnist_dataset
LOGGER = LogUtil.get_instance()
TAG = 'PGD_Test'
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_card
@pytest.mark.component_mindarmour
def test_projected_gradient_descent_method():
"""
PGD-Attack test
"""
# load trained network
ckpt_name = './trained_ckpt_file/checkpoint_lenet-10_1875.ckpt'
net = LeNet5()
load_dict = load_checkpoint(ckpt_name)
load_param_into_net(net, load_dict)
# get test data
data_list = "./MNIST_unzip/test"
batch_size = 32
ds = generate_mnist_dataset(data_list, batch_size, sparse=False)
# prediction accuracy before attack
model = Model(net)
batch_num = 32 # the number of batches of attacking samples
test_images = []
test_labels = []
predict_labels = []
i = 0
for data in ds.create_tuple_iterator():
i += 1
images = data[0].astype(np.float32)
labels = data[1]
test_images.append(images)
test_labels.append(labels)
pred_labels = np.argmax(model.predict(Tensor(images)).asnumpy(),
axis=1)
predict_labels.append(pred_labels)
if i >= batch_num:
break
predict_labels = np.concatenate(predict_labels)
true_labels = np.argmax(np.concatenate(test_labels), axis=1)
accuracy = np.mean(np.equal(predict_labels, true_labels))
LOGGER.info(TAG, "prediction accuracy before attacking is : %s", accuracy)
# attacking
attack = ProjectedGradientDescent(net, eps=0.3)
start_time = time.clock()
adv_data = attack.batch_generate(np.concatenate(test_images),
np.concatenate(test_labels), batch_size=32)
stop_time = time.clock()
np.save('./adv_data', adv_data)
pred_logits_adv = model.predict(Tensor(adv_data)).asnumpy()
# rescale predict confidences into (0, 1).
pred_logits_adv = softmax(pred_logits_adv, axis=1)
pred_labels_adv = np.argmax(pred_logits_adv, axis=1)
accuracy_adv = np.mean(np.equal(pred_labels_adv, true_labels))
LOGGER.info(TAG, "prediction accuracy after attacking is : %s", accuracy_adv)
attack_evaluate = AttackEvaluate(np.concatenate(test_images).transpose(0, 2, 3, 1),
np.concatenate(test_labels),
adv_data.transpose(0, 2, 3, 1),
pred_logits_adv)
LOGGER.info(TAG, 'mis-classification rate of adversaries is : %s',
attack_evaluate.mis_classification_rate())
LOGGER.info(TAG, 'The average confidence of adversarial class is : %s',
attack_evaluate.avg_conf_adv_class())
LOGGER.info(TAG, 'The average confidence of true class is : %s',
attack_evaluate.avg_conf_true_class())
LOGGER.info(TAG, 'The average distance (l0, l2, linf) between original '
'samples and adversarial samples are: %s',
attack_evaluate.avg_lp_distance())
LOGGER.info(TAG, 'The average structural similarity between original '
'samples and adversarial samples are: %s',
attack_evaluate.avg_ssim())
    LOGGER.info(TAG, 'The average time cost per sample is %s',
                (stop_time - start_time)/(batch_num*batch_size))
if __name__ == '__main__':
test_projected_gradient_descent_method()
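# Illustrative sketch (an addition for clarity, not part of the original
# example): what a single ProjectedGradientDescent step boils down to. The
# helper below is hypothetical and only shows the update rule -- take a
# signed-gradient step of size `alpha`, then project the result back into the
# eps-ball around the clean input and into the valid data range.
def _pgd_step_sketch(x_clean, x_adv, grad, eps=0.3, alpha=0.01,
                     clip_min=0.0, clip_max=1.0):
    """One projected-gradient-descent step (illustrative only)."""
    x_adv = x_adv + alpha*np.sign(grad)
    # project the perturbation back onto the eps-ball around x_clean
    x_adv = np.clip(x_adv, x_clean - eps, x_clean + eps)
    # keep the result inside the valid data bounds
    return np.clip(x_adv, clip_min, clip_max)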
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import numpy as np
import pytest
from scipy.special import softmax
from mindspore import Tensor
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindarmour.attacks.black.pointwise_attack import PointWiseAttack
from mindarmour.attacks.black.black_model import BlackModel
from mindarmour.utils.logger import LogUtil
from mindarmour.evaluations.attack_evaluation import AttackEvaluate
from lenet5_net import LeNet5
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
sys.path.append("..")
from data_processing import generate_mnist_dataset
LOGGER = LogUtil.get_instance()
TAG = 'Pointwise_Attack'
LOGGER.set_level('INFO')
class ModelToBeAttacked(BlackModel):
"""model to be attack"""
def __init__(self, network):
super(ModelToBeAttacked, self).__init__()
self._network = network
def predict(self, inputs):
"""predict"""
if len(inputs.shape) == 3:
inputs = inputs[np.newaxis, :]
result = self._network(Tensor(inputs.astype(np.float32)))
return result.asnumpy()
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_card
@pytest.mark.component_mindarmour
def test_pointwise_attack_on_mnist():
"""
Salt-and-Pepper-Attack test
"""
    # load trained network
ckpt_name = './trained_ckpt_file/checkpoint_lenet-10_1875.ckpt'
net = LeNet5()
load_dict = load_checkpoint(ckpt_name)
load_param_into_net(net, load_dict)
# get test data
data_list = "./MNIST_unzip/test"
batch_size = 32
ds = generate_mnist_dataset(data_list, batch_size=batch_size)
# prediction accuracy before attack
model = ModelToBeAttacked(net)
batch_num = 3 # the number of batches of attacking samples
test_images = []
test_labels = []
predict_labels = []
i = 0
for data in ds.create_tuple_iterator():
i += 1
images = data[0].astype(np.float32)
labels = data[1]
test_images.append(images)
test_labels.append(labels)
pred_labels = np.argmax(model.predict(images), axis=1)
predict_labels.append(pred_labels)
if i >= batch_num:
break
predict_labels = np.concatenate(predict_labels)
true_labels = np.concatenate(test_labels)
accuracy = np.mean(np.equal(predict_labels, true_labels))
LOGGER.info(TAG, "prediction accuracy before attacking is : %g", accuracy)
# attacking
is_target = False
attack = PointWiseAttack(model=model, is_targeted=is_target)
if is_target:
targeted_labels = np.random.randint(0, 10, size=len(true_labels))
for i in range(len(true_labels)):
if targeted_labels[i] == true_labels[i]:
targeted_labels[i] = (targeted_labels[i] + 1) % 10
else:
targeted_labels = true_labels
success_list, adv_data, query_list = attack.generate(
np.concatenate(test_images), targeted_labels)
success_list = np.arange(success_list.shape[0])[success_list]
LOGGER.info(TAG, 'success_list: %s', success_list)
LOGGER.info(TAG, 'average of query times is : %s', np.mean(query_list))
adv_preds = []
for ite_data in adv_data:
pred_logits_adv = model.predict(ite_data)
# rescale predict confidences into (0, 1).
pred_logits_adv = softmax(pred_logits_adv, axis=1)
adv_preds.extend(pred_logits_adv)
    accuracy_adv = np.mean(np.equal(np.argmax(adv_preds, axis=1), true_labels))
LOGGER.info(TAG, "prediction accuracy after attacking is : %g",
accuracy_adv)
test_labels_onehot = np.eye(10)[true_labels]
attack_evaluate = AttackEvaluate(np.concatenate(test_images),
test_labels_onehot, adv_data,
adv_preds, targeted=is_target,
target_label=targeted_labels)
LOGGER.info(TAG, 'mis-classification rate of adversaries is : %s',
attack_evaluate.mis_classification_rate())
LOGGER.info(TAG, 'The average confidence of adversarial class is : %s',
attack_evaluate.avg_conf_adv_class())
LOGGER.info(TAG, 'The average confidence of true class is : %s',
attack_evaluate.avg_conf_true_class())
LOGGER.info(TAG, 'The average distance (l0, l2, linf) between original '
'samples and adversarial samples are: %s',
attack_evaluate.avg_lp_distance())
if __name__ == '__main__':
test_pointwise_attack_on_mnist()
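# Illustrative sketch (an addition for clarity, not part of the original
# example): the core idea behind a pointwise, decision-based attack, written
# as a simplified outline rather than MindArmour's exact implementation.
# Starting from an already adversarial sample, each feature is reset to its
# clean value and the reset is kept only when the sample stays adversarial,
# which minimizes the number of perturbed pixels (the L0 distance).
def _pointwise_reduction_sketch(model, x_clean, x_adv, true_label):
    """Greedily restore features of an adversarial example (illustrative)."""
    x_work = x_adv.copy()
    for idx in np.random.permutation(x_work.size):
        saved = x_work.flat[idx]
        x_work.flat[idx] = x_clean.flat[idx]  # try restoring one feature
        pred = np.argmax(model.predict(x_work[np.newaxis, ...])[0])
        if pred == true_label:  # restoring broke adversariality, undo it
            x_work.flat[idx] = saved
    return x_work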
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import time
import numpy as np
import pytest
from scipy.special import softmax
from mindspore import Tensor
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindarmour.attacks.black.pso_attack import PSOAttack
from mindarmour.attacks.black.black_model import BlackModel
from mindarmour.utils.logger import LogUtil
from mindarmour.evaluations.attack_evaluation import AttackEvaluate
from lenet5_net import LeNet5
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
sys.path.append("..")
from data_processing import generate_mnist_dataset
LOGGER = LogUtil.get_instance()
TAG = 'PSO_Attack'
class ModelToBeAttacked(BlackModel):
"""model to be attack"""
def __init__(self, network):
super(ModelToBeAttacked, self).__init__()
self._network = network
def predict(self, inputs):
"""predict"""
result = self._network(Tensor(inputs.astype(np.float32)))
return result.asnumpy()
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_card
@pytest.mark.component_mindarmour
def test_pso_attack_on_mnist():
"""
PSO-Attack test
"""
    # load trained network
ckpt_name = './trained_ckpt_file/checkpoint_lenet-10_1875.ckpt'
net = LeNet5()
load_dict = load_checkpoint(ckpt_name)
load_param_into_net(net, load_dict)
# get test data
data_list = "./MNIST_unzip/test"
batch_size = 32
ds = generate_mnist_dataset(data_list, batch_size=batch_size)
# prediction accuracy before attack
model = ModelToBeAttacked(net)
batch_num = 3 # the number of batches of attacking samples
test_images = []
test_labels = []
predict_labels = []
i = 0
for data in ds.create_tuple_iterator():
i += 1
images = data[0].astype(np.float32)
labels = data[1]
test_images.append(images)
test_labels.append(labels)
pred_labels = np.argmax(model.predict(images), axis=1)
predict_labels.append(pred_labels)
if i >= batch_num:
break
predict_labels = np.concatenate(predict_labels)
true_labels = np.concatenate(test_labels)
accuracy = np.mean(np.equal(predict_labels, true_labels))
LOGGER.info(TAG, "prediction accuracy before attacking is : %s", accuracy)
# attacking
attack = PSOAttack(model, bounds=(0.0, 1.0), pm=0.5, sparse=True)
    start_time = time.time()
success_list, adv_data, query_list = attack.generate(
np.concatenate(test_images), np.concatenate(test_labels))
    stop_time = time.time()
LOGGER.info(TAG, 'success_list: %s', success_list)
LOGGER.info(TAG, 'average of query times is : %s', np.mean(query_list))
pred_logits_adv = model.predict(adv_data)
# rescale predict confidences into (0, 1).
pred_logits_adv = softmax(pred_logits_adv, axis=1)
pred_labels_adv = np.argmax(pred_logits_adv, axis=1)
accuracy_adv = np.mean(np.equal(pred_labels_adv, true_labels))
LOGGER.info(TAG, "prediction accuracy after attacking is : %s",
accuracy_adv)
test_labels_onehot = np.eye(10)[np.concatenate(test_labels)]
attack_evaluate = AttackEvaluate(np.concatenate(test_images),
test_labels_onehot, adv_data,
pred_logits_adv)
LOGGER.info(TAG, 'mis-classification rate of adversaries is : %s',
attack_evaluate.mis_classification_rate())
LOGGER.info(TAG, 'The average confidence of adversarial class is : %s',
attack_evaluate.avg_conf_adv_class())
LOGGER.info(TAG, 'The average confidence of true class is : %s',
attack_evaluate.avg_conf_true_class())
LOGGER.info(TAG, 'The average distance (l0, l2, linf) between original '
'samples and adversarial samples are: %s',
attack_evaluate.avg_lp_distance())
LOGGER.info(TAG, 'The average structural similarity between original '
'samples and adversarial samples are: %s',
attack_evaluate.avg_ssim())
    LOGGER.info(TAG, 'The average time cost per sample is %s',
                (stop_time - start_time)/(batch_num*batch_size))
if __name__ == '__main__':
test_pso_attack_on_mnist()
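# Illustrative sketch (an addition for clarity, not part of the original
# example): the canonical particle-swarm update that PSO-style black-box
# attacks are built on. The inertia `w` and cognitive/social factors `c1`,
# `c2` are generic PSO terminology, not the exact parameters of PSOAttack.
def _pso_update_sketch(pos, vel, personal_best, global_best,
                       w=0.5, c1=2.0, c2=2.0):
    """One particle-swarm-optimization update step (illustrative only)."""
    r1 = np.random.random(pos.shape)
    r2 = np.random.random(pos.shape)
    vel = w*vel + c1*r1*(personal_best - pos) + c2*r2*(global_best - pos)
    return pos + vel, vel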
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import numpy as np
import pytest
from scipy.special import softmax
from mindspore import Tensor
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindarmour.attacks.black.salt_and_pepper_attack import SaltAndPepperNoiseAttack
from mindarmour.attacks.black.black_model import BlackModel
from mindarmour.utils.logger import LogUtil
from mindarmour.evaluations.attack_evaluation import AttackEvaluate
from lenet5_net import LeNet5
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
sys.path.append("..")
from data_processing import generate_mnist_dataset
LOGGER = LogUtil.get_instance()
TAG = 'Salt_and_Pepper_Attack'
LOGGER.set_level('DEBUG')
class ModelToBeAttacked(BlackModel):
"""model to be attack"""
def __init__(self, network):
super(ModelToBeAttacked, self).__init__()
self._network = network
def predict(self, inputs):
"""predict"""
if len(inputs.shape) == 3:
inputs = inputs[np.newaxis, :]
result = self._network(Tensor(inputs.astype(np.float32)))
return result.asnumpy()
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_card
@pytest.mark.component_mindarmour
def test_salt_and_pepper_attack_on_mnist():
"""
Salt-and-Pepper-Attack test
"""
    # load trained network
ckpt_name = './trained_ckpt_file/checkpoint_lenet-10_1875.ckpt'
net = LeNet5()
load_dict = load_checkpoint(ckpt_name)
load_param_into_net(net, load_dict)
# get test data
data_list = "./MNIST_unzip/test"
batch_size = 32
ds = generate_mnist_dataset(data_list, batch_size=batch_size)
# prediction accuracy before attack
model = ModelToBeAttacked(net)
batch_num = 3 # the number of batches of attacking samples
test_images = []
test_labels = []
predict_labels = []
i = 0
for data in ds.create_tuple_iterator():
i += 1
images = data[0].astype(np.float32)
labels = data[1]
test_images.append(images)
test_labels.append(labels)
pred_labels = np.argmax(model.predict(images), axis=1)
predict_labels.append(pred_labels)
if i >= batch_num:
break
LOGGER.debug(TAG, 'model input image shape is: {}'.format(np.array(test_images).shape))
predict_labels = np.concatenate(predict_labels)
true_labels = np.concatenate(test_labels)
accuracy = np.mean(np.equal(predict_labels, true_labels))
LOGGER.info(TAG, "prediction accuracy before attacking is : %g", accuracy)
# attacking
is_target = False
attack = SaltAndPepperNoiseAttack(model=model,
is_targeted=is_target,
sparse=True)
if is_target:
targeted_labels = np.random.randint(0, 10, size=len(true_labels))
for i in range(len(true_labels)):
if targeted_labels[i] == true_labels[i]:
targeted_labels[i] = (targeted_labels[i] + 1) % 10
else:
targeted_labels = true_labels
LOGGER.debug(TAG, 'input shape is: {}'.format(np.concatenate(test_images).shape))
success_list, adv_data, query_list = attack.generate(
np.concatenate(test_images), targeted_labels)
success_list = np.arange(success_list.shape[0])[success_list]
LOGGER.info(TAG, 'success_list: %s', success_list)
LOGGER.info(TAG, 'average of query times is : %s', np.mean(query_list))
adv_preds = []
for ite_data in adv_data:
pred_logits_adv = model.predict(ite_data)
# rescale predict confidences into (0, 1).
pred_logits_adv = softmax(pred_logits_adv, axis=1)
adv_preds.extend(pred_logits_adv)
    accuracy_adv = np.mean(np.equal(np.argmax(adv_preds, axis=1), true_labels))
LOGGER.info(TAG, "prediction accuracy after attacking is : %g",
accuracy_adv)
test_labels_onehot = np.eye(10)[true_labels]
attack_evaluate = AttackEvaluate(np.concatenate(test_images),
test_labels_onehot, adv_data,
adv_preds, targeted=is_target,
target_label=targeted_labels)
LOGGER.info(TAG, 'mis-classification rate of adversaries is : %s',
attack_evaluate.mis_classification_rate())
LOGGER.info(TAG, 'The average confidence of adversarial class is : %s',
attack_evaluate.avg_conf_adv_class())
LOGGER.info(TAG, 'The average confidence of true class is : %s',
attack_evaluate.avg_conf_true_class())
LOGGER.info(TAG, 'The average distance (l0, l2, linf) between original '
'samples and adversarial samples are: %s',
attack_evaluate.avg_lp_distance())
if __name__ == '__main__':
test_salt_and_pepper_attack_on_mnist()
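# Illustrative sketch (an addition for clarity, not part of the original
# example): how salt-and-pepper noise is typically injected. The attack
# raises the noise ratio until the model's decision flips; the helper below
# only shows the noise step and is not MindArmour's implementation.
def _salt_and_pepper_sketch(image, noise_ratio, low=0.0, high=1.0):
    """Set a random fraction of pixels to the data bounds (illustrative)."""
    noisy = image.copy()
    mask = np.random.random(noisy.shape) < noise_ratio
    salt = np.random.random(noisy.shape) < 0.5
    noisy[mask & salt] = high   # "salt" pixels set to the upper bound
    noisy[mask & ~salt] = low   # "pepper" pixels set to the lower bound
    return noisy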
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""defense example using nad"""
import sys
import logging
import numpy as np
import pytest
from mindspore import Tensor
from mindspore import context
from mindspore import nn
from mindspore.nn import SoftmaxCrossEntropyWithLogits
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindarmour.attacks import FastGradientSignMethod
from mindarmour.defenses import NaturalAdversarialDefense
from mindarmour.utils.logger import LogUtil
from lenet5_net import LeNet5
sys.path.append("..")
from data_processing import generate_mnist_dataset
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
LOGGER = LogUtil.get_instance()
TAG = 'Nad_Example'
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_card
@pytest.mark.component_mindarmour
def test_nad_method():
"""
NAD-Defense test.
"""
# 1. load trained network
ckpt_name = './trained_ckpt_file/checkpoint_lenet-10_1875.ckpt'
net = LeNet5()
load_dict = load_checkpoint(ckpt_name)
load_param_into_net(net, load_dict)
loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=False)
opt = nn.Momentum(net.trainable_params(), 0.01, 0.09)
nad = NaturalAdversarialDefense(net, loss_fn=loss, optimizer=opt,
bounds=(0.0, 1.0), eps=0.3)
# 2. get test data
data_list = "./MNIST_unzip/test"
batch_size = 32
ds_test = generate_mnist_dataset(data_list, batch_size=batch_size,
sparse=False)
inputs = []
labels = []
for data in ds_test.create_tuple_iterator():
inputs.append(data[0].astype(np.float32))
labels.append(data[1])
inputs = np.concatenate(inputs)
labels = np.concatenate(labels)
# 3. get accuracy of test data on original model
net.set_train(False)
acc_list = []
batchs = inputs.shape[0] // batch_size
for i in range(batchs):
batch_inputs = inputs[i*batch_size : (i + 1)*batch_size]
batch_labels = np.argmax(labels[i*batch_size : (i + 1)*batch_size], axis=1)
logits = net(Tensor(batch_inputs)).asnumpy()
label_pred = np.argmax(logits, axis=1)
acc_list.append(np.mean(batch_labels == label_pred))
LOGGER.debug(TAG, 'accuracy of TEST data on original model is : %s',
np.mean(acc_list))
# 4. get adv of test data
attack = FastGradientSignMethod(net, eps=0.3)
adv_data = attack.batch_generate(inputs, labels)
LOGGER.debug(TAG, 'adv_data.shape is : %s', adv_data.shape)
# 5. get accuracy of adv data on original model
net.set_train(False)
acc_list = []
batchs = adv_data.shape[0] // batch_size
for i in range(batchs):
batch_inputs = adv_data[i*batch_size : (i + 1)*batch_size]
batch_labels = np.argmax(labels[i*batch_size : (i + 1)*batch_size], axis=1)
logits = net(Tensor(batch_inputs)).asnumpy()
label_pred = np.argmax(logits, axis=1)
acc_list.append(np.mean(batch_labels == label_pred))
LOGGER.debug(TAG, 'accuracy of adv data on original model is : %s',
np.mean(acc_list))
# 6. defense
net.set_train()
nad.batch_defense(inputs, labels, batch_size=32, epochs=10)
# 7. get accuracy of test data on defensed model
net.set_train(False)
acc_list = []
batchs = inputs.shape[0] // batch_size
for i in range(batchs):
batch_inputs = inputs[i*batch_size : (i + 1)*batch_size]
batch_labels = np.argmax(labels[i*batch_size : (i + 1)*batch_size], axis=1)
logits = net(Tensor(batch_inputs)).asnumpy()
label_pred = np.argmax(logits, axis=1)
acc_list.append(np.mean(batch_labels == label_pred))
LOGGER.debug(TAG, 'accuracy of TEST data on defensed model is : %s',
np.mean(acc_list))
# 8. get accuracy of adv data on defensed model
acc_list = []
batchs = adv_data.shape[0] // batch_size
for i in range(batchs):
batch_inputs = adv_data[i*batch_size : (i + 1)*batch_size]
batch_labels = np.argmax(labels[i*batch_size : (i + 1)*batch_size], axis=1)
logits = net(Tensor(batch_inputs)).asnumpy()
label_pred = np.argmax(logits, axis=1)
acc_list.append(np.mean(batch_labels == label_pred))
LOGGER.debug(TAG, 'accuracy of adv data on defensed model is : %s',
np.mean(acc_list))
if __name__ == '__main__':
LOGGER.set_level(logging.DEBUG)
test_nad_method()
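# Illustrative note (an addition for clarity, not part of the original
# example): NaturalAdversarialDefense is, in essence, adversarial training.
# Each batch_defense call above roughly amounts to the loop below, written
# with hypothetical helper names rather than the library's internals:
#
#     for batch_x, batch_y in iterate_batches(inputs, labels, batch_size):
#         adv_x = FastGradientSignMethod(net, eps=0.3).generate(batch_x, batch_y)
#         train_one_step(net, loss, opt, adv_x, batch_y)  # hypothetical helper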
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""evaluate example"""
import sys
import os
import time
import numpy as np
from scipy.special import softmax
from lenet5_net import LeNet5
from mindspore import Model
from mindspore import Tensor
from mindspore import context
from mindspore import nn
from mindspore.nn import Cell
from mindspore.ops.operations import TensorAdd
from mindspore.nn import SoftmaxCrossEntropyWithLogits
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindarmour.attacks import FastGradientSignMethod
from mindarmour.attacks import GeneticAttack
from mindarmour.attacks.black.black_model import BlackModel
from mindarmour.defenses import NaturalAdversarialDefense
from mindarmour.evaluations import BlackDefenseEvaluate
from mindarmour.evaluations import DefenseEvaluate
from mindarmour.utils.logger import LogUtil
from mindarmour.detectors.black.similarity_detector import SimilarityDetector
sys.path.append("..")
from data_processing import generate_mnist_dataset
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
LOGGER = LogUtil.get_instance()
TAG = 'Defense_Evaluate_Example'
def get_detector(train_images):
encoder = Model(EncoderNet(encode_dim=256))
detector = SimilarityDetector(max_k_neighbor=50, trans_model=encoder)
detector.fit(inputs=train_images)
return detector
class EncoderNet(Cell):
"""
Similarity encoder for input data
"""
def __init__(self, encode_dim):
super(EncoderNet, self).__init__()
self._encode_dim = encode_dim
self.add = TensorAdd()
def construct(self, inputs):
"""
construct the neural network
Args:
inputs (Tensor): input data to neural network.
Returns:
Tensor, output of neural network.
"""
return self.add(inputs, inputs)
def get_encode_dim(self):
"""
Get the dimension of encoded inputs
Returns:
int, dimension of encoded inputs.
"""
return self._encode_dim
class ModelToBeAttacked(BlackModel):
"""
    Model to be attacked.
"""
def __init__(self, network, defense=False, train_images=None):
super(ModelToBeAttacked, self).__init__()
self._network = network
self._queries = []
self._defense = defense
self._detector = None
self._detected_res = []
if self._defense:
self._detector = get_detector(train_images)
def predict(self, inputs):
"""
predict function
"""
query_num = inputs.shape[0]
results = []
if self._detector:
for i in range(query_num):
query = np.expand_dims(inputs[i].astype(np.float32), axis=0)
result = self._network(Tensor(query)).asnumpy()
det_num = len(self._detector.get_detected_queries())
self._detector.detect([query])
new_det_num = len(self._detector.get_detected_queries())
# If attack query detected, return random predict result
if new_det_num > det_num:
results.append(result + np.random.rand(*result.shape))
self._detected_res.append(True)
else:
results.append(result)
self._detected_res.append(False)
results = np.concatenate(results)
else:
results = self._network(Tensor(inputs.astype(np.float32))).asnumpy()
return results
def get_detected_result(self):
return self._detected_res
def test_black_defense():
# load trained network
current_dir = os.path.dirname(os.path.abspath(__file__))
ckpt_name = os.path.abspath(os.path.join(
current_dir, './trained_ckpt_file/checkpoint_lenet-10_1875.ckpt'))
# ckpt_name = './trained_ckpt_file/checkpoint_lenet-10_1875.ckpt'
wb_net = LeNet5()
load_dict = load_checkpoint(ckpt_name)
load_param_into_net(wb_net, load_dict)
# get test data
data_list = "./MNIST_unzip/test"
batch_size = 32
ds_test = generate_mnist_dataset(data_list, batch_size=batch_size,
sparse=False)
inputs = []
labels = []
for data in ds_test.create_tuple_iterator():
inputs.append(data[0].astype(np.float32))
labels.append(data[1])
inputs = np.concatenate(inputs).astype(np.float32)
labels = np.concatenate(labels).astype(np.float32)
labels_sparse = np.argmax(labels, axis=1)
target_label = np.random.randint(0, 10, size=labels_sparse.shape[0])
for idx in range(labels_sparse.shape[0]):
while target_label[idx] == labels_sparse[idx]:
target_label[idx] = np.random.randint(0, 10)
target_label = np.eye(10)[target_label].astype(np.float32)
attacked_size = 50
benign_size = 500
attacked_sample = inputs[:attacked_size]
attacked_true_label = labels[:attacked_size]
benign_sample = inputs[attacked_size:attacked_size + benign_size]
wb_model = ModelToBeAttacked(wb_net)
# gen white-box adversarial examples of test data
wb_attack = FastGradientSignMethod(wb_net, eps=0.3)
wb_adv_sample = wb_attack.generate(attacked_sample,
attacked_true_label)
wb_raw_preds = softmax(wb_model.predict(wb_adv_sample), axis=1)
accuracy_test = np.mean(
np.equal(np.argmax(wb_model.predict(attacked_sample), axis=1),
np.argmax(attacked_true_label, axis=1)))
LOGGER.info(TAG, "prediction accuracy before white-box attack is : %s",
accuracy_test)
accuracy_adv = np.mean(np.equal(np.argmax(wb_raw_preds, axis=1),
np.argmax(attacked_true_label, axis=1)))
LOGGER.info(TAG, "prediction accuracy after white-box attack is : %s",
accuracy_adv)
# improve the robustness of model with white-box adversarial examples
loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=False)
opt = nn.Momentum(wb_net.trainable_params(), 0.01, 0.09)
nad = NaturalAdversarialDefense(wb_net, loss_fn=loss, optimizer=opt,
bounds=(0.0, 1.0), eps=0.3)
wb_net.set_train(False)
nad.batch_defense(inputs[:5000], labels[:5000], batch_size=32, epochs=10)
wb_def_preds = wb_net(Tensor(wb_adv_sample)).asnumpy()
wb_def_preds = softmax(wb_def_preds, axis=1)
accuracy_def = np.mean(np.equal(np.argmax(wb_def_preds, axis=1),
np.argmax(attacked_true_label, axis=1)))
LOGGER.info(TAG, "prediction accuracy after defense is : %s", accuracy_def)
# calculate defense evaluation metrics for defense against white-box attack
wb_def_evaluate = DefenseEvaluate(wb_raw_preds, wb_def_preds,
np.argmax(attacked_true_label, axis=1))
LOGGER.info(TAG, 'defense evaluation for white-box adversarial attack')
LOGGER.info(TAG,
'classification accuracy variance (CAV) is : {:.2f}'.format(
wb_def_evaluate.cav()))
LOGGER.info(TAG, 'classification rectify ratio (CRR) is : {:.2f}'.format(
wb_def_evaluate.crr()))
LOGGER.info(TAG, 'classification sacrifice ratio (CSR) is : {:.2f}'.format(
wb_def_evaluate.csr()))
LOGGER.info(TAG,
'classification confidence variance (CCV) is : {:.2f}'.format(
wb_def_evaluate.ccv()))
LOGGER.info(TAG, 'classification output stability is : {:.2f}'.format(
wb_def_evaluate.cos()))
# calculate defense evaluation metrics for defense against black-box attack
LOGGER.info(TAG, 'defense evaluation for black-box adversarial attack')
bb_raw_preds = []
bb_def_preds = []
raw_query_counts = []
raw_query_time = []
def_query_counts = []
def_query_time = []
def_detection_counts = []
# gen black-box adversarial examples of test data
bb_net = LeNet5()
load_param_into_net(bb_net, load_dict)
bb_model = ModelToBeAttacked(bb_net, defense=False)
attack_rm = GeneticAttack(model=bb_model, pop_size=6, mutation_rate=0.05,
per_bounds=0.1, step_size=0.25, temp=0.1,
sparse=False)
attack_target_label = target_label[:attacked_size]
true_label = labels_sparse[:attacked_size + benign_size]
# evaluate robustness of original model
# gen black-box adversarial examples of test data
for idx in range(attacked_size):
raw_st = time.time()
raw_sl, raw_a, raw_qc = attack_rm.generate(
np.expand_dims(attacked_sample[idx], axis=0),
np.expand_dims(attack_target_label[idx], axis=0))
raw_t = time.time() - raw_st
bb_raw_preds.extend(softmax(bb_model.predict(raw_a), axis=1))
raw_query_counts.extend(raw_qc)
raw_query_time.append(raw_t)
for idx in range(benign_size):
raw_st = time.time()
bb_raw_pred = softmax(
bb_model.predict(np.expand_dims(benign_sample[idx], axis=0)),
axis=1)
raw_t = time.time() - raw_st
bb_raw_preds.extend(bb_raw_pred)
raw_query_counts.extend([0])
raw_query_time.append(raw_t)
accuracy_test = np.mean(
np.equal(np.argmax(bb_raw_preds[0:len(attack_target_label)], axis=1),
np.argmax(attack_target_label, axis=1)))
LOGGER.info(TAG, "attack success before adv defense is : %s",
accuracy_test)
# improve the robustness of model with similarity-based detector
bb_def_model = ModelToBeAttacked(bb_net, defense=True,
train_images=inputs[0:6000])
# attack defensed model
attack_dm = GeneticAttack(model=bb_def_model, pop_size=6,
mutation_rate=0.05,
per_bounds=0.1, step_size=0.25, temp=0.1,
sparse=False)
for idx in range(attacked_size):
def_st = time.time()
def_sl, def_a, def_qc = attack_dm.generate(
np.expand_dims(attacked_sample[idx], axis=0),
np.expand_dims(attack_target_label[idx], axis=0))
def_t = time.time() - def_st
det_res = bb_def_model.get_detected_result()
def_detection_counts.append(np.sum(det_res[-def_qc[0]:]))
bb_def_preds.extend(softmax(bb_def_model.predict(def_a), axis=1))
def_query_counts.extend(def_qc)
def_query_time.append(def_t)
for idx in range(benign_size):
def_st = time.time()
bb_def_pred = softmax(
bb_def_model.predict(np.expand_dims(benign_sample[idx], axis=0)),
axis=1)
def_t = time.time() - def_st
det_res = bb_def_model.get_detected_result()
def_detection_counts.append(np.sum(det_res[-1]))
bb_def_preds.extend(bb_def_pred)
def_query_counts.extend([0])
def_query_time.append(def_t)
accuracy_adv = np.mean(
np.equal(np.argmax(bb_def_preds[0:len(attack_target_label)], axis=1),
np.argmax(attack_target_label, axis=1)))
LOGGER.info(TAG, "attack success rate after adv defense is : %s",
accuracy_adv)
bb_raw_preds = np.array(bb_raw_preds).astype(np.float32)
bb_def_preds = np.array(bb_def_preds).astype(np.float32)
# check evaluate data
max_queries = 6000
def_evaluate = BlackDefenseEvaluate(bb_raw_preds, bb_def_preds,
np.array(raw_query_counts),
np.array(def_query_counts),
np.array(raw_query_time),
np.array(def_query_time),
np.array(def_detection_counts),
true_label, max_queries)
LOGGER.info(TAG, 'query count variance of adversaries is : {:.2f}'.format(
def_evaluate.qcv()))
LOGGER.info(TAG, 'attack success rate variance of adversaries '
'is : {:.2f}'.format(def_evaluate.asv()))
LOGGER.info(TAG, 'false positive rate (FPR) of the query-based detector '
'is : {:.2f}'.format(def_evaluate.fpr()))
LOGGER.info(TAG, 'the benign query response time variance (QRV) '
'is : {:.2f}'.format(def_evaluate.qrv()))
if __name__ == '__main__':
test_black_defense()
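# Illustrative note (an addition for clarity, not part of the original
# example): rough, simplified readings of the defense metrics logged above
# (hedged approximations, not the exact library definitions):
#     CAV ~ accuracy of the defensed model minus accuracy of the original model
#     CRR ~ fraction of samples wrong before defense but correct afterwards
#     CSR ~ fraction of samples correct before defense but wrong afterwards
#     CCV ~ average change of confidence on the true class after defense
#     QCV/ASV/FPR/QRV ~ the analogous before/after comparison applied to query
#         counts, attack success rate, detector false positives and benign
#         query response time in the black-box setting.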
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import numpy as np
import pytest
from scipy.special import softmax
from mindspore import Model
from mindspore import context
from mindspore import Tensor
from mindspore.nn import Cell
from mindspore.ops.operations import TensorAdd
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindarmour.utils.logger import LogUtil
from mindarmour.attacks.black.pso_attack import PSOAttack
from mindarmour.attacks.black.black_model import BlackModel
from mindarmour.detectors.black.similarity_detector import SimilarityDetector
from lenet5_net import LeNet5
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
sys.path.append("..")
from data_processing import generate_mnist_dataset
LOGGER = LogUtil.get_instance()
TAG = 'Similarity Detector test'
class ModelToBeAttacked(BlackModel):
"""
    Model to be attacked.
"""
def __init__(self, network):
super(ModelToBeAttacked, self).__init__()
self._network = network
self._queries = []
def predict(self, inputs):
"""
predict function
"""
query_num = inputs.shape[0]
for i in range(query_num):
self._queries.append(inputs[i].astype(np.float32))
result = self._network(Tensor(inputs.astype(np.float32)))
return result.asnumpy()
def get_queries(self):
return self._queries
class EncoderNet(Cell):
"""
Similarity encoder for input data
"""
def __init__(self, encode_dim):
super(EncoderNet, self).__init__()
self._encode_dim = encode_dim
self.add = TensorAdd()
def construct(self, inputs):
"""
construct the neural network
Args:
inputs (Tensor): input data to neural network.
Returns:
Tensor, output of neural network.
"""
return self.add(inputs, inputs)
def get_encode_dim(self):
"""
Get the dimension of encoded inputs
Returns:
int, dimension of encoded inputs.
"""
return self._encode_dim
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_card
@pytest.mark.component_mindarmour
def test_similarity_detector():
"""
Similarity Detector test.
"""
# load trained network
ckpt_name = './trained_ckpt_file/checkpoint_lenet-10_1875.ckpt'
net = LeNet5()
load_dict = load_checkpoint(ckpt_name)
load_param_into_net(net, load_dict)
# get mnist data
data_list = "./MNIST_unzip/test"
batch_size = 1000
ds = generate_mnist_dataset(data_list, batch_size=batch_size)
model = ModelToBeAttacked(net)
batch_num = 10 # the number of batches of input samples
all_images = []
true_labels = []
predict_labels = []
i = 0
for data in ds.create_tuple_iterator():
i += 1
images = data[0].astype(np.float32)
labels = data[1]
all_images.append(images)
true_labels.append(labels)
pred_labels = np.argmax(model.predict(images), axis=1)
predict_labels.append(pred_labels)
if i >= batch_num:
break
all_images = np.concatenate(all_images)
true_labels = np.concatenate(true_labels)
predict_labels = np.concatenate(predict_labels)
accuracy = np.mean(np.equal(predict_labels, true_labels))
LOGGER.info(TAG, "prediction accuracy before attacking is : %s", accuracy)
train_images = all_images[0:6000, :, :, :]
attacked_images = all_images[0:10, :, :, :]
attacked_labels = true_labels[0:10]
# generate malicious query sequence of black attack
attack = PSOAttack(model, bounds=(0.0, 1.0), pm=0.5, sparse=True,
t_max=1000)
success_list, adv_data, query_list = attack.generate(attacked_images,
attacked_labels)
LOGGER.info(TAG, 'pso attack success_list: %s', success_list)
LOGGER.info(TAG, 'average of query counts is : %s', np.mean(query_list))
pred_logits_adv = model.predict(adv_data)
# rescale predict confidences into (0, 1).
pred_logits_adv = softmax(pred_logits_adv, axis=1)
    pred_labels_adv = np.argmax(pred_logits_adv, axis=1)
    accuracy_adv = np.mean(np.equal(pred_labels_adv, attacked_labels))
LOGGER.info(TAG, "prediction accuracy after attacking is : %g",
accuracy_adv)
benign_queries = all_images[6000:10000, :, :, :]
suspicious_queries = model.get_queries()
# explicit threshold not provided, calculate threshold for K
encoder = Model(EncoderNet(encode_dim=256))
detector = SimilarityDetector(max_k_neighbor=50, trans_model=encoder)
detector.fit(inputs=train_images)
# test benign queries
detector.detect(benign_queries)
fpr = len(detector.get_detected_queries()) / benign_queries.shape[0]
LOGGER.info(TAG, 'Number of false positive of attack detector is : %s',
len(detector.get_detected_queries()))
LOGGER.info(TAG, 'False positive rate of attack detector is : %s', fpr)
# test attack queries
detector.clear_buffer()
detector.detect(suspicious_queries)
LOGGER.info(TAG, 'Number of detected attack queries is : %s',
len(detector.get_detected_queries()))
LOGGER.info(TAG, 'The detected attack query indexes are : %s',
detector.get_detected_queries())
if __name__ == '__main__':
test_similarity_detector()
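# Illustrative note (an addition for clarity, not part of the original
# example): the intuition behind SimilarityDetector. Query-based black-box
# attacks issue long runs of near-duplicate queries, so the detector keeps a
# buffer of recently seen (encoded) queries and raises an alarm when the mean
# distance to the k nearest previous queries falls below a threshold fitted
# on benign traffic. A simplified per-query check, assuming hypothetical
# `encode`, `buffer`, `k` and `threshold` objects:
#
#     feat = encode(query)
#     nearest = np.sort([np.linalg.norm(feat - b) for b in buffer])[:k]
#     if nearest.size and np.mean(nearest) < threshold:
#         report_attack_query()
#     buffer.append(feat)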
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import sys
import mindspore.nn as nn
from mindspore import context, Tensor
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train import Model
import mindspore.ops.operations as P
from mindspore.nn.metrics import Accuracy
from mindspore.ops import functional as F
from mindspore.common import dtype as mstype
from mindarmour.utils.logger import LogUtil
from lenet5_net import LeNet5
sys.path.append("..")
from data_processing import generate_mnist_dataset
LOGGER = LogUtil.get_instance()
TAG = 'Lenet5_train'
class CrossEntropyLoss(nn.Cell):
"""
Define loss for network
"""
def __init__(self):
super(CrossEntropyLoss, self).__init__()
self.cross_entropy = P.SoftmaxCrossEntropyWithLogits()
self.mean = P.ReduceMean()
self.one_hot = P.OneHot()
self.on_value = Tensor(1.0, mstype.float32)
self.off_value = Tensor(0.0, mstype.float32)
def construct(self, logits, label):
label = self.one_hot(label, F.shape(logits)[1], self.on_value, self.off_value)
loss = self.cross_entropy(logits, label)[0]
loss = self.mean(loss, (-1,))
return loss
def mnist_train(epoch_size, batch_size, lr, momentum):
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend",
enable_mem_reuse=False)
mnist_path = "./MNIST_unzip/"
ds = generate_mnist_dataset(os.path.join(mnist_path, "train"),
batch_size=batch_size, repeat_size=1)
network = LeNet5()
network.set_train()
net_loss = CrossEntropyLoss()
net_opt = nn.Momentum(network.trainable_params(), lr, momentum)
config_ck = CheckpointConfig(save_checkpoint_steps=1875, keep_checkpoint_max=10)
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", directory='./trained_ckpt_file/', config=config_ck)
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
LOGGER.info(TAG, "============== Starting Training ==============")
model.train(epoch_size, ds, callbacks=[ckpoint_cb, LossMonitor()], dataset_sink_mode=False) # train
LOGGER.info(TAG, "============== Starting Testing ==============")
param_dict = load_checkpoint("trained_ckpt_file/checkpoint_lenet-10_1875.ckpt")
load_param_into_net(network, param_dict)
ds_eval = generate_mnist_dataset(os.path.join(mnist_path, "test"), batch_size=batch_size)
acc = model.eval(ds_eval)
LOGGER.info(TAG, "============== Accuracy: %s ==============", acc)
if __name__ == '__main__':
mnist_train(10, 32, 0.001, 0.9)
"""
MindArmour, a tool box of MindSpore to enhance model security and
trustworthiness against adversarial examples.
"""
from .attacks import Attack
from .attacks.black.black_model import BlackModel
from .defenses.defense import Defense
from .detectors.detector import Detector
__all__ = ['Attack',
'BlackModel',
'Detector',
'Defense']
"""
This module includes classical black-box and white-box attack algorithms
in making adversarial examples.
"""
from .gradient_method import *
from .iterative_gradient_method import *
from .deep_fool import DeepFool
from .jsma import JSMAAttack
from .carlini_wagner import CarliniWagnerL2Attack
from .lbfgs import LBFGS
from . import black
from .black.hop_skip_jump_attack import HopSkipJumpAttack
from .black.genetic_attack import GeneticAttack
from .black.natural_evolutionary_strategy import NES
from .black.pointwise_attack import PointWiseAttack
from .black.pso_attack import PSOAttack
from .black.salt_and_pepper_attack import SaltAndPepperNoiseAttack
__all__ = ['FastGradientMethod',
'RandomFastGradientMethod',
'FastGradientSignMethod',
'RandomFastGradientSignMethod',
'LeastLikelyClassMethod',
'RandomLeastLikelyClassMethod',
'IterativeGradientMethod',
'BasicIterativeMethod',
'MomentumIterativeMethod',
'ProjectedGradientDescent',
'DeepFool',
'CarliniWagnerL2Attack',
'JSMAAttack',
'LBFGS',
'GeneticAttack',
'HopSkipJumpAttack',
'NES',
'PointWiseAttack',
'PSOAttack',
'SaltAndPepperNoiseAttack'
]
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Base Class of Attack.
"""
from abc import abstractmethod
import numpy as np
from mindarmour.utils._check_param import check_pair_numpy_param, \
check_int_positive
from mindarmour.utils.logger import LogUtil
LOGGER = LogUtil.get_instance()
TAG = 'Attack'
class Attack:
"""
The abstract base class for all attack classes creating adversarial examples.
"""
def __init__(self):
pass
def batch_generate(self, inputs, labels, batch_size=64):
"""
Generate adversarial examples in batch, based on input samples and
their labels.
Args:
inputs (numpy.ndarray): Samples based on which adversarial
examples are generated.
            labels (numpy.ndarray): Labels of samples, whose values are
                determined by specific attacks.
batch_size (int): The number of samples in one batch.
Returns:
numpy.ndarray, generated adversarial examples
Examples:
>>> inputs = Tensor([[0.2, 0.4, 0.5, 0.2], [0.7, 0.2, 0.4, 0.3]])
>>> labels = [3, 0]
>>> advs = attack.batch_generate(inputs, labels, batch_size=2)
"""
arr_x, arr_y = check_pair_numpy_param('inputs', inputs, 'labels', labels)
len_x = arr_x.shape[0]
batch_size = check_int_positive('batch_size', batch_size)
batchs = int(len_x / batch_size)
rest = len_x - batchs*batch_size
res = []
for i in range(batchs):
x_batch = arr_x[i*batch_size: (i + 1)*batch_size]
y_batch = arr_y[i*batch_size: (i + 1)*batch_size]
adv_x = self.generate(x_batch, y_batch)
# Black-attack methods will return 3 values, just get the second.
res.append(adv_x[1] if isinstance(adv_x, tuple) else adv_x)
if rest != 0:
x_batch = arr_x[batchs*batch_size:]
y_batch = arr_y[batchs*batch_size:]
adv_x = self.generate(x_batch, y_batch)
# Black-attack methods will return 3 values, just get the second.
res.append(adv_x[1] if isinstance(adv_x, tuple) else adv_x)
adv_x = np.concatenate(res, axis=0)
return adv_x
@abstractmethod
def generate(self, inputs, labels):
"""
Generate adversarial examples based on normal samples and their labels.
Args:
inputs (numpy.ndarray): Samples based on which adversarial
examples are generated.
            labels (numpy.ndarray): Labels of samples, whose values are
                determined by specific attacks.
Raises:
NotImplementedError: It is an abstract method.
"""
msg = 'The function generate() is an abstract function in class ' \
'`Attack` and should be implemented in child class.'
LOGGER.error(TAG, msg)
raise NotImplementedError(msg)
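# Illustrative sketch (an addition for clarity, not part of the original
# module): a concrete attack only needs to implement generate(); the
# inherited batch_generate() then works unchanged. The toy subclass below
# simply adds bounded uniform noise and is purely hypothetical.
class _RandomNoiseAttack(Attack):
    """Toy attack that perturbs inputs with uniform noise (illustrative)."""
    def __init__(self, eps=0.1):
        super(_RandomNoiseAttack, self).__init__()
        self._eps = eps
    def generate(self, inputs, labels):
        """Return inputs perturbed with noise bounded by eps."""
        noise = np.random.uniform(-self._eps, self._eps, size=inputs.shape)
        return np.clip(inputs + noise, 0.0, 1.0)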
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Black model.
"""
from abc import abstractmethod
import numpy as np
from mindarmour.utils.logger import LogUtil
LOGGER = LogUtil.get_instance()
TAG = 'BlackModel'
class BlackModel:
"""
The abstract class which treats the target model as a black box. The model
should be defined by users.
"""
def __init__(self):
pass
@abstractmethod
def predict(self, inputs):
"""
        Predict using the user-specified model. The shape of the predict
        results should be (m, n), where n represents the number of classes
        this model classifies.
Args:
inputs (numpy.ndarray): The input samples to be predicted.
Raises:
NotImplementedError: It is an abstract method.
"""
msg = 'The function predict() is an abstract function in class ' \
'`BlackModel` and should be implemented in child class by user.'
LOGGER.error(TAG, msg)
raise NotImplementedError(msg)
def is_adversarial(self, data, label, is_targeted):
"""
Check if input sample is adversarial example or not.
Args:
            data (numpy.ndarray): The input sample to be checked, typically a
                maliciously perturbed example.
            label (numpy.ndarray): For targeted attacks, label is the intended
                label of the perturbed example. For untargeted attacks, label
                is the original label of the corresponding unperturbed sample.
is_targeted (bool): For targeted/untargeted attacks, select True/False.
Returns:
bool.
- If True, the input sample is adversarial.
- If False, the input sample is not adversarial.
"""
logits = self.predict(np.expand_dims(data, axis=0))[0]
predicts = np.argmax(logits)
if is_targeted:
return predicts == label
return predicts != label
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Genetic-Attack.
"""
import numpy as np
from scipy.special import softmax
from mindarmour.attacks.attack import Attack
from mindarmour.utils.logger import LogUtil
from mindarmour.attacks.black.black_model import BlackModel
from mindarmour.utils._check_param import check_numpy_param, check_model, \
check_pair_numpy_param, check_param_type, check_value_positive, \
check_int_positive, check_param_multi_types
LOGGER = LogUtil.get_instance()
TAG = 'GeneticAttack'
def _mutation(cur_pop, step_noise=0.01, prob=0.005):
"""
Generate mutation samples in genetic_attack.
Args:
cur_pop (numpy.ndarray): Samples before mutation.
step_noise (float): Noise range. Default: 0.01.
prob (float): Mutation probability. Default: 0.005.
Returns:
numpy.ndarray, samples after mutation operation in genetic_attack.
Examples:
        >>> mul_pop = _mutation(np.array([[0.2, 0.3, 0.4]]), step_noise=0.03,
        >>>                     prob=0.01)
"""
cur_pop = check_numpy_param('cur_pop', cur_pop)
perturb_noise = np.clip(np.random.random(cur_pop.shape) - 0.5,
-step_noise, step_noise)
mutated_pop = perturb_noise*(
np.random.random(cur_pop.shape) < prob) + cur_pop
return mutated_pop
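# Illustrative usage (an addition for clarity, not part of the original
# module): applying _mutation to a population of two 3-feature candidates;
# roughly a fraction `prob` of the entries receive a uniform perturbation
# bounded by `step_noise`:
#
#     pop = np.array([[0.2, 0.3, 0.4], [0.3, 0.3, 0.2]])
#     mutated = _mutation(pop, step_noise=0.03, prob=0.01)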
class GeneticAttack(Attack):
"""
    The Genetic Attack represents a black-box attack based on the genetic
    algorithm, which belongs to the family of evolutionary algorithms.
    This attack was proposed by Moustafa Alzantot et al. (2018).
    References: `Moustafa Alzantot, Yash Sharma, Supriyo Chakraborty,
    "GenAttack: Practical Black-box Attacks with
    Gradient-Free Optimization" <https://arxiv.org/abs/1805.11090>`_
Args:
model (BlackModel): Target model.
        pop_size (int): The number of individuals in the population, which
            should be greater than zero. Default: 6.
mutation_rate (float): The probability of mutations. Default: 0.005.
per_bounds (float): Maximum L_inf distance.
        max_steps (int): The maximum number of iterations for each adversarial
            example. Default: 1000.
step_size (float): Attack step size. Default: 0.2.
temp (float): Sampling temperature for selection. Default: 0.3.
bounds (tuple): Upper and lower bounds of data. In form of (clip_min,
clip_max). Default: (0, 1.0)
adaptive (bool): If True, turns on dynamic scaling of mutation
parameters. If false, turns on static mutation parameters.
Default: False.
sparse (bool): If True, input labels are sparse-encoded. If False,
input labels are one-hot-encoded. Default: True.
Examples:
>>> attack = GeneticAttack(model)
"""
def __init__(self, model, pop_size=6,
mutation_rate=0.005, per_bounds=0.15, max_steps=1000,
step_size=0.20, temp=0.3, bounds=(0, 1.0), adaptive=False,
sparse=True):
super(GeneticAttack, self).__init__()
self._model = check_model('model', model, BlackModel)
self._per_bounds = check_value_positive('per_bounds', per_bounds)
self._pop_size = check_int_positive('pop_size', pop_size)
self._step_size = check_value_positive('step_size', step_size)
self._temp = check_value_positive('temp', temp)
self._max_steps = check_int_positive('max_steps', max_steps)
self._mutation_rate = check_value_positive('mutation_rate',
mutation_rate)
self._adaptive = check_param_type('adaptive', adaptive, bool)
self._bounds = check_param_multi_types('bounds', bounds, [list, tuple])
for b in self._bounds:
_ = check_param_multi_types('bound', b, [int, float])
# initial global optimum fitness value
self._best_fit = -1
# count times of no progress
self._plateau_times = 0
# count times of changing attack step
self._adap_times = 0
self._sparse = check_param_type('sparse', sparse, bool)
def generate(self, inputs, labels):
"""
Generate adversarial examples based on input data and targeted
labels (or ground_truth labels).
Args:
inputs (numpy.ndarray): Input samples.
labels (numpy.ndarray): Targeted labels.
Returns:
- numpy.ndarray, bool values for each attack result.
- numpy.ndarray, generated adversarial examples.
- numpy.ndarray, query times for each sample.
Examples:
>>> advs = attack.generate([[0.2, 0.3, 0.4],
>>> [0.3, 0.3, 0.2]],
>>> [1, 2])
"""
inputs, labels = check_pair_numpy_param('inputs', inputs,
'labels', labels)
# if input is one-hot encoded, get sparse format value
if not self._sparse:
if labels.ndim != 2:
raise ValueError('labels must be 2 dims, '
'but got {} dims.'.format(labels.ndim))
labels = np.argmax(labels, axis=1)
adv_list = []
success_list = []
query_times_list = []
for i in range(inputs.shape[0]):
is_success = False
target_label = labels[i]
iters = 0
x_ori = inputs[i]
# generate particles
ori_copies = np.repeat(
x_ori[np.newaxis, :], self._pop_size, axis=0)
# initial perturbations
cur_pert = np.clip(np.random.random(ori_copies.shape)*self._step_size,
(0 - self._per_bounds),
self._per_bounds)
query_times = 0
while iters < self._max_steps:
iters += 1
cur_pop = np.clip(
ori_copies + cur_pert, self._bounds[0], self._bounds[1])
pop_preds = self._model.predict(cur_pop)
query_times += cur_pop.shape[0]
all_preds = np.argmax(pop_preds, axis=1)
success_pop = np.equal(target_label, all_preds).astype(np.int32)
success = max(success_pop)
if success == 1:
is_success = True
adv = cur_pop[np.argmax(success_pop)]
break
target_preds = pop_preds[:, target_label]
others_preds_sum = np.sum(pop_preds, axis=1) - target_preds
fit_vals = target_preds - others_preds_sum
best_fit = max(target_preds - np.max(pop_preds))
if best_fit > self._best_fit:
self._best_fit = best_fit
self._plateau_times = 0
else:
self._plateau_times += 1
                adap_threshold = 100 if best_fit > -0.4 else 300
if self._plateau_times > adap_threshold:
self._adap_times += 1
self._plateau_times = 0
if self._adaptive:
step_noise = max(self._step_size, 0.4*(0.9**self._adap_times))
step_p = max(self._step_size, 0.5*(0.9**self._adap_times))
else:
step_noise = self._step_size
step_p = self._mutation_rate
step_temp = self._temp
elite = cur_pert[np.argmax(fit_vals)]
select_probs = softmax(fit_vals/step_temp)
select_args = np.arange(self._pop_size)
parents_arg = np.random.choice(
a=select_args, size=2*(self._pop_size - 1),
replace=True, p=select_probs)
parent1 = cur_pert[parents_arg[:self._pop_size - 1]]
parent2 = cur_pert[parents_arg[self._pop_size - 1:]]
parent1_probs = select_probs[parents_arg[:self._pop_size - 1]]
parent2_probs = select_probs[parents_arg[self._pop_size - 1:]]
parent2_probs = parent2_probs / (parent1_probs + parent2_probs)
# duplicate the probabilities to all features of each particle.
dims = len(x_ori.shape)
for _ in range(dims):
parent2_probs = parent2_probs[:, np.newaxis]
parent2_probs = np.tile(parent2_probs, ((1,) + x_ori.shape))
cross_probs = (np.random.random(parent1.shape) >
parent2_probs).astype(np.int32)
childs = parent1*cross_probs + parent2*(1 - cross_probs)
mutated_childs = _mutation(
childs, step_noise=self._per_bounds*step_noise,
prob=step_p)
cur_pert = np.concatenate((mutated_childs, elite[np.newaxis, :]))
if is_success:
LOGGER.debug(TAG, 'successfully find one adversarial sample '
'and start Reduction process.')
adv_list.append(adv)
else:
LOGGER.debug(TAG, 'fail to find adversarial sample.')
adv_list.append(elite + x_ori)
LOGGER.debug(TAG,
'iteration times is: %d and query times is: %d',
iters,
query_times)
success_list.append(is_success)
query_times_list.append(query_times)
del ori_copies, cur_pert, cur_pop
return np.asarray(success_list), \
np.asarray(adv_list), \
np.asarray(query_times_list)
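# Illustrative note (an addition for clarity, not part of the original
# module): each generation in generate() above follows the standard genetic
# algorithm loop -- evaluate the fitness of the perturbed population, keep
# the best candidate as the elite, sample parents with a softmax(fitness /
# temperature) selection, recombine them feature-wise (crossover) and mutate
# the children before the next round of model queries.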
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Natural-evolutionary-strategy Attack.
"""
import time
import numpy as np
from scipy.special import softmax
from mindarmour.attacks.attack import Attack
from mindarmour.utils.logger import LogUtil
from mindarmour.attacks.black.black_model import BlackModel
from mindarmour.utils._check_param import check_pair_numpy_param, check_model, \
check_numpy_param, check_int_positive, check_value_positive, check_param_type
LOGGER = LogUtil.get_instance()
TAG = 'NES'
def _one_hot(index, total):
arr = np.zeros((total))
arr[index] = 1.0
return arr
def _bound(image, epsilon):
    lower = np.clip(image - epsilon, 0, 1)
    upper = np.clip(image + epsilon, 0, 1)
return lower, upper
class NES(Attack):
"""
The class is an implementation of the Natural Evolutionary Strategies Attack,
including three settings: Query-Limited setting, Partial-Information setting
and Label-Only setting.
References: `Andrew Ilyas, Logan Engstrom, Anish Athalye, and Jessy Lin.
Black-box adversarial attacks with limited queries and information. In
ICML, July 2018 <https://arxiv.org/abs/1804.08598>`_
Args:
model (BlackModel): Target model.
scene (str): Scene in 'Label_Only', 'Partial_Info' or
'Query_Limit'.
max_queries (int): Maximum query numbers to generate an adversarial
example. Default: 500000.
top_k (int): For Partial-Info or Label-Only setting, indicating how
much (Top-k) information is available for the attacker. For
Query-Limited setting, this input should be set as -1. Default: -1.
num_class (int): Number of classes in dataset. Default: 10.
batch_size (int): Batch size. Default: 96.
epsilon (float): Maximum perturbation allowed in attack. Default: 0.3.
samples_per_draw (int): Number of samples draw in antithetic sampling.
Default: 96.
momentum (float): Momentum. Default: 0.9.
learning_rate (float): Learning rate. Default: 1e-2.
max_lr (float): Max Learning rate. Default: 1e-2.
min_lr (float): Min Learning rate. Default: 5e-5.
sigma (float): Step size of random noise. Default: 1e-3.
plateau_length (int): Length of plateau used in Annealing algorithm.
Default: 20.
plateau_drop (float): Drop of plateau used in Annealing algorithm.
Default: 2.0.
adv_thresh (float): Threshold of adversarial. Default: 0.15.
zero_iters (int): Number of points to use for the proxy score.
Default: 10.
starting_eps (float): Starting epsilon used in Label-Only setting.
Default: 1.0.
starting_delta_eps (float): Delta epsilon used in Label-Only setting.
Default: 0.5.
label_only_sigma (float): Sigma used in Label-Only setting.
Default: 1e-3.
        conservative (int): Conservativeness used in epsilon decay; it will
            increase if there is no convergence. Default: 2.
sparse (bool): If True, input labels are sparse-encoded. If False,
input labels are one-hot-encoded. Default: True.
Examples:
>>> SCENE = 'Label_Only'
>>> TOP_K = 5
>>> num_class = 5
>>> nes_instance = NES(user_model, SCENE, top_k=TOP_K)
>>> initial_img = np.asarray(np.random.random((32, 32)), np.float32)
>>> target_image = np.asarray(np.random.random((32, 32)), np.float32)
>>> orig_class = 0
>>> target_class = 2
>>> nes_instance.set_target_images(target_image)
>>> tag, adv, queries = nes_instance.generate([initial_img], [target_class])
"""
def __init__(self, model, scene, max_queries=10000, top_k=-1, num_class=10,
batch_size=128, epsilon=0.3, samples_per_draw=128,
momentum=0.9, learning_rate=1e-3, max_lr=5e-2, min_lr=5e-4,
sigma=1e-3, plateau_length=20, plateau_drop=2.0,
adv_thresh=0.25, zero_iters=10, starting_eps=1.0,
starting_delta_eps=0.5, label_only_sigma=1e-3, conservative=2,
sparse=True):
super(NES, self).__init__()
self._model = check_model('model', model, BlackModel)
self._scene = scene
self._max_queries = check_int_positive('max_queries', max_queries)
self._num_class = check_int_positive('num_class', num_class)
self._batch_size = check_int_positive('batch_size', batch_size)
self._samples_per_draw = check_int_positive('samples_per_draw',
samples_per_draw)
self._goal_epsilon = check_value_positive('epsilon', epsilon)
self._momentum = check_value_positive('momentum', momentum)
self._learning_rate = check_value_positive('learning_rate',
learning_rate)
self._max_lr = check_value_positive('max_lr', max_lr)
self._min_lr = check_value_positive('min_lr', min_lr)
self._sigma = check_value_positive('sigma', sigma)
self._plateau_length = check_int_positive('plateau_length',
plateau_length)
self._plateau_drop = check_value_positive('plateau_drop', plateau_drop)
# partial information arguments
self._k = top_k
self._adv_thresh = check_value_positive('adv_thresh', adv_thresh)
# label only arguments
self._zero_iters = check_int_positive('zero_iters', zero_iters)
self._starting_eps = check_value_positive('starting_eps', starting_eps)
self._starting_delta_eps = check_value_positive('starting_delta_eps',
starting_delta_eps)
self._label_only_sigma = check_value_positive('label_only_sigma',
label_only_sigma)
self._conservative = check_int_positive('conservative', conservative)
self._sparse = check_param_type('sparse', sparse, bool)
self.target_imgs = None
self.target_img = None
self.target_class = None
def generate(self, inputs, labels):
"""
Main algorithm for NES.
Args:
inputs (numpy.ndarray): Benign input samples.
labels (numpy.ndarray): Target labels.
Returns:
- numpy.ndarray, bool values for each attack result.
- numpy.ndarray, generated adversarial examples.
- numpy.ndarray, query times for each sample.
Raises:
ValueError: If top_k is less than 0 in the Label-Only or Partial-Info
setting.
ValueError: If target_imgs is None in the Label-Only or
Partial-Info setting.
ValueError: If scene is not one of 'Label_Only', 'Partial_Info'
or 'Query_Limit'.
Examples:
>>> advs = attack.generate([[0.2, 0.3, 0.4], [0.3, 0.3, 0.2]],
>>> [1, 2])
"""
inputs, labels = check_pair_numpy_param('inputs', inputs,
'labels', labels)
if not self._sparse:
labels = np.argmax(labels, axis=1)
if self._scene == 'Label_Only' or self._scene == 'Partial_Info':
if self._k < 0:
msg = "In 'Label_Only' or 'Partial_Info' mode, " \
"'top_k' must more than 0."
LOGGER.error(TAG, msg)
raise ValueError(msg)
if self.target_imgs is None:
msg = "In 'Label_Only' or 'Partial_Info' mode, " \
"'target_imgs' must be set."
LOGGER.error(TAG, msg)
raise ValueError(msg)
elif self._scene == 'Query_Limit':
self._k = self._num_class
else:
msg = "scene must be string in 'Label_Only', " \
"'Partial_Info' or 'Query_Limit' "
LOGGER.error(TAG, msg)
raise ValueError(msg)
is_advs = []
advs = []
queries = []
for sample, label, target_img in zip(inputs, labels, self.target_imgs):
is_adv, adv, query = self._generate_one(sample, label, target_img)
is_advs.append(is_adv)
advs.append(adv)
queries.append(query)
return is_advs, advs, queries
def set_target_images(self, target_images):
"""
Set target samples for targeted attack.
Args:
target_images (numpy.ndarray): Target samples for targeted attack.
"""
self.target_imgs = check_numpy_param('target_images', target_images)
def _generate_one(self, origin_image, target_label, target_image):
"""
Main algorithm for NES.
Args:
origin_image (numpy.ndarray): Benign input sample.
target_label (int): Target label.
target_image (numpy.ndarray): Target sample used as the starting point
in the Label-Only or Partial-Info setting.
Returns:
- bool.
- If True: an adversarial example was successfully generated.
- If False: failed to generate an adversarial example.
- numpy.ndarray, an adversarial example.
- int, number of queries.
"""
self.target_class = target_label
origin_image = check_numpy_param('origin_image', origin_image)
self._epsilon = self._starting_eps
lower, upper = _bound(origin_image, self._epsilon)
goal_epsilon = self._goal_epsilon
delta_epsilon = self._starting_delta_eps
if self._scene == 'Label_Only' or self._scene == 'Partial_Info':
adv = target_image
else:
adv = origin_image.copy()
# for backtracking and momentum
num_queries = 0
gradient = 0
last_ls = []
max_iters = int(np.ceil(self._max_queries / self._samples_per_draw))
for i in range(max_iters):
start = time.time()
# early stop
eval_preds = self._model.predict(adv)
eval_preds = np.argmax(eval_preds, axis=1)
padv = np.equal(eval_preds, self.target_class)
if padv and self._epsilon <= goal_epsilon:
LOGGER.debug(TAG, 'early stopping at iteration %d', i)
return True, adv, num_queries
# antithetic sampling noise
noise_pos = np.random.normal(
size=(self._batch_size // 2,) + origin_image.shape)
noise = np.concatenate((noise_pos, -noise_pos), axis=0)
eval_points = adv + self._sigma*noise
prev_g = gradient
loss, gradient = self._get_grad(origin_image, eval_points, noise)
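# momentum smoothing: keep an exponential moving average of successive
# NES gradient estimates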
gradient = self._momentum*prev_g + (1.0 - self._momentum)*gradient
# plateau learning rate annealing
last_ls.append(loss)
last_ls = self._plateau_annealing(last_ls)
# search for learning rate and epsilon decay
current_lr = self._max_lr
prop_delta_eps = 0.0
if loss < self._adv_thresh and self._epsilon > goal_epsilon:
prop_delta_eps = delta_epsilon
while current_lr >= self._min_lr:
# in partial information only or label only setting
if self._scene == 'Label_Only' or self._scene == 'Partial_Info':
proposed_epsilon = max(self._epsilon - prop_delta_eps,
goal_epsilon)
lower, upper = _bound(origin_image, proposed_epsilon)
proposed_adv = adv - current_lr*np.sign(gradient)
proposed_adv = np.clip(proposed_adv, lower, upper)
num_queries += 1
if self._preds_in_top_k(self.target_class, proposed_adv):
# The predicted label of proposed adversarial examples is in
# the top k observations.
if prop_delta_eps > 0:
delta_epsilon = max(prop_delta_eps, 0.1)
last_ls = []
adv = proposed_adv
self._epsilon = max(
self._epsilon - prop_delta_eps / self._conservative,
goal_epsilon)
break
elif current_lr >= self._min_lr*2:
current_lr = current_lr / 2
LOGGER.debug(TAG, "backtracking learning rate to %.3f",
current_lr)
else:
prop_delta_eps = prop_delta_eps / 2
if prop_delta_eps < 2e-3:
LOGGER.debug(TAG, "Did not converge.")
return False, adv, num_queries
current_lr = self._max_lr
LOGGER.debug(TAG,
"backtracking epsilon to %.3f",
self._epsilon - prop_delta_eps)
# update the number of queries
if self._scene == 'Label_Only':
num_queries += self._samples_per_draw*self._zero_iters
else:
num_queries += self._samples_per_draw
LOGGER.debug(TAG,
'Step %d: loss %.4f, lr %.2E, eps %.3f, time %.4f.',
i,
loss,
current_lr,
self._epsilon,
time.time() - start)
return False, adv, num_queries
def _plateau_annealing(self, last_loss):
last_loss = last_loss[-self._plateau_length:]
if last_loss[-1] > last_loss[0] and len(
last_loss) == self._plateau_length:
if self._max_lr > self._min_lr:
LOGGER.debug(TAG, "Annealing max learning rate.")
self._max_lr = max(self._max_lr / self._plateau_drop,
self._min_lr)
last_loss = []
return last_loss
def _softmax_cross_entropy_with_logit(self, logit):
logit = softmax(logit, axis=1)
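# cross-entropy against the one-hot target class; the np.mean over the
# class axis below equals the standard cross-entropy scaled by 1 / num_class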
onehot_label = np.zeros(self._num_class)
onehot_label[self.target_class] = 1
onehot_labels = np.tile(onehot_label, (len(logit), 1))
entropy = -onehot_labels*np.log(logit)
loss = np.mean(entropy, axis=1)
return loss
def _query_limit_loss(self, eval_points, noise):
"""
Loss in Query-Limit setting.
"""
LOGGER.debug(TAG, 'enter the function _query_limit_loss().')
loss = self._softmax_cross_entropy_with_logit(
self._model.predict(eval_points))
return loss, noise
def _partial_info_loss(self, eval_points, noise):
"""
Loss in Partial-Info setting.
"""
LOGGER.debug(TAG, 'enter the function _partial_info_loss.')
logit = self._model.predict(eval_points)
loss = np.sort(softmax(logit, axis=1))[:, -self._k:]
inds = np.argsort(logit)[:, -self._k:]
good_loss = np.where(np.equal(inds, self.target_class), loss,
np.zeros(np.shape(inds)))
good_loss = np.max(good_loss, axis=1)
losses = -np.log(good_loss)
return losses, noise
def _label_only_loss(self, origin_image, eval_points, noise):
"""
Loss in Label-Only setting.
"""
LOGGER.debug(TAG, 'enter the function _label_only_loss().')
tiled_points = np.tile(np.expand_dims(eval_points, 0),
[self._zero_iters,
*[1]*len(eval_points.shape)])
noised_eval_im = tiled_points \
+ np.random.randn(self._zero_iters,
self._batch_size,
*origin_image.shape) \
*self._label_only_sigma
noised_eval_im = np.reshape(noised_eval_im, (
self._zero_iters*self._batch_size, *origin_image.shape))
logits = self._model.predict(noised_eval_im)
inds = np.argsort(logits)[:, -self._k:]
real_inds = np.reshape(inds, (self._zero_iters, self._batch_size, -1))
rank_range = np.arange(1, self._k + 1, 1, dtype=np.float32)
tiled_rank_range = np.tile(np.reshape(rank_range, (1, 1, self._k)),
[self._zero_iters, self._batch_size, 1])
batches_in = np.where(np.equal(real_inds, self.target_class),
tiled_rank_range,
np.zeros(np.shape(tiled_rank_range)))
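# proxy score for the Label-Only setting: the target class contributes its
# rank weight (1..k) wherever it appears in the top-k of the noised copies
# and zero otherwise, so the loss below shrinks as the target class ranks
# higher on average across the random draws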
loss = 1 - np.mean(batches_in)
return loss, noise
def _preds_in_top_k(self, target_class, prop_adv_):
# query limit setting
if self._k == self._num_class:
return True
# label only and partial information setting
eval_preds = self._model.predict(prop_adv_)
if target_class not in eval_preds.argsort()[:, -self._k:]:
return False
return True
def _get_grad(self, origin_image, eval_points, noise):
"""Calculate gradient."""
losses = []
grads = []
for _ in range(self._samples_per_draw // self._batch_size):
if self._scene == 'Label_Only':
loss, np_noise = self._label_only_loss(origin_image,
eval_points,
noise)
elif self._scene == 'Partial_Info':
loss, np_noise = self._partial_info_loss(eval_points, noise)
else:
loss, np_noise = self._query_limit_loss(eval_points, noise)
# only supports three-channel images
losses_tiled = np.tile(np.reshape(loss, (-1, 1, 1, 1)),
(1,) + origin_image.shape)
grad = np.mean(losses_tiled*np_noise, axis=0) / self._sigma
grads.append(grad)
losses.append(np.mean(loss))
return np.array(losses).mean(), np.mean(np.array(grads), axis=0)
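# ---------------------------------------------------------------------------
# A minimal, self-contained sketch of the antithetic-sampling gradient
# estimator used in _get_grad above, written in plain NumPy. The names
# `toy_loss` and `nes_gradient_estimate` are hypothetical helpers introduced
# only to illustrate the technique; they are not part of the MindArmour API.
import numpy as np


def toy_loss(x):
    """A toy scalar loss: squared distance to an assumed target point 0.5."""
    return float(np.sum((x - 0.5)**2))


def nes_gradient_estimate(x, loss_fn, sigma=1e-3, batch_size=128):
    """Estimate the gradient of loss_fn at x with antithetic Gaussian noise."""
    noise_pos = np.random.normal(size=(batch_size // 2,) + x.shape)
    noise = np.concatenate((noise_pos, -noise_pos), axis=0)  # antithetic pairs
    losses = np.array([loss_fn(x + sigma*n) for n in noise])
    # NES estimator: grad ~ mean(loss_i * noise_i) / sigma
    losses_tiled = losses.reshape((-1,) + (1,)*x.ndim)
    return np.mean(losses_tiled*noise, axis=0) / sigma


# Example: one signed-gradient step, mirroring the update in _generate_one.
# x0 = np.random.random((3, 4, 4)).astype(np.float32)
# grad = nes_gradient_estimate(x0, toy_loss)
# x1 = np.clip(x0 - 1e-2*np.sign(grad), 0, 1)
# ---------------------------------------------------------------------------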
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Pointwise-Attack.
"""
import numpy as np
from mindarmour.attacks.attack import Attack
from mindarmour.attacks.black.black_model import BlackModel
from mindarmour.attacks.black.salt_and_pepper_attack import \
SaltAndPepperNoiseAttack
from mindarmour.utils._check_param import check_model, check_pair_numpy_param, \
check_int_positive, check_param_type
from mindarmour.utils.logger import LogUtil
LOGGER = LogUtil.get_instance()
TAG = 'PointWiseAttack'
class PointWiseAttack(Attack):
"""
The Pointwise Attack uses the minimum number of changed pixels to generate an
adversarial sample for each original sample. Binary search is then applied to
the changed pixels to keep the distance between the adversarial sample and the
original sample as small as possible.
References: `L. Schott, J. Rauber, M. Bethge, W. Brendel: "Towards the
first adversarially robust neural network model on MNIST", ICLR (2019)
<https://arxiv.org/abs/1805.09190>`_
Args:
model (BlackModel): Target model.
max_iter (int): Max rounds of iteration to generate an adversarial image.
Default: 1000.
search_iter (int): Max rounds of binary search. Default: 10.
is_targeted (bool): If True, targeted attack. If False, untargeted
attack. Default: False.
init_attack (Attack): Attack used to find a starting point. Default:
None.
sparse (bool): If True, input labels are sparse-encoded. If False,
input labels are one-hot-encoded. Default: True.
Examples:
>>> attack = PointWiseAttack(model)
"""
def __init__(self,
model,
max_iter=1000,
search_iter=10,
is_targeted=False,
init_attack=None,
sparse=True):
super(PointWiseAttack, self).__init__()
self._model = check_model('model', model, BlackModel)
self._max_iter = check_int_positive('max_iter', max_iter)
self._search_iter = check_int_positive('search_iter', search_iter)
self._is_targeted = check_param_type('is_targeted', is_targeted, bool)
if init_attack is None:
self._init_attack = SaltAndPepperNoiseAttack(model,
is_targeted=self._is_targeted)
else:
self._init_attack = init_attack
self._sparse = check_param_type('sparse', sparse, bool)
def generate(self, inputs, labels):
"""
Generate adversarial examples based on input samples and targeted labels.
Args:
inputs (numpy.ndarray): Benign input samples used as references to create
adversarial examples.
labels (numpy.ndarray): For targeted attack, labels are adversarial
target labels. For untargeted attack, labels are ground-truth labels.
Returns:
- numpy.ndarray, bool values for each attack result.
- numpy.ndarray, generated adversarial examples.
- numpy.ndarray, query times for each sample.
Examples:
>>> is_adv_list, adv_list, query_times_each_adv = attack.generate(
>>> [[0.1, 0.2, 0.6], [0.3, 0, 0.4]],
>>> [2, 3])
"""
arr_x, arr_y = check_pair_numpy_param('inputs', inputs, 'labels',
labels)
if not self._sparse:
arr_y = np.argmax(arr_y, axis=1)
ini_bool, ini_advs, ini_count = self._initialize_starting_point(arr_x,
arr_y)
is_adv_list = list()
adv_list = list()
query_times_each_adv = list()
for sample, sample_label, start_adv, ite_bool, ite_c in zip(arr_x,
arr_y,
ini_advs,
ini_bool,
ini_count):
if ite_bool:
LOGGER.info(TAG, 'Start optimizing.')
ori_label = np.argmax(
self._model.predict(np.expand_dims(sample, axis=0))[0])
ini_label = np.argmax(self._model.predict(np.expand_dims(start_adv, axis=0))[0])
is_adv, adv_x, query_times = self._decision_optimize(sample,
sample_label,
start_adv)
adv_label = np.argmax(
self._model.predict(np.expand_dims(adv_x, axis=0))[0])
LOGGER.debug(TAG, 'before ini attack label is :{}'.format(ori_label))
LOGGER.debug(TAG, 'after ini attack label is :{}'.format(ini_label))
LOGGER.debug(TAG, 'INPUT optimize label is :{}'.format(sample_label))
LOGGER.debug(TAG, 'after pointwise attack label is :{}'.format(adv_label))
is_adv_list.append(is_adv)
adv_list.append(adv_x)
query_times_each_adv.append(query_times + ite_c)
else:
LOGGER.info(TAG, 'Initial sample is not adversarial, pass.')
is_adv_list.append(False)
adv_list.append(start_adv)
query_times_each_adv.append(ite_c)
is_adv_list = np.array(is_adv_list)
adv_list = np.array(adv_list)
query_times_each_adv = np.array(query_times_each_adv)
LOGGER.debug(TAG, 'ret list is: {}'.format(adv_list))
return is_adv_list, adv_list, query_times_each_adv
def _decision_optimize(self, unperturbed_img, input_label, perturbed_img):
"""
Make the perturbed sample more similar to the unperturbed sample while
keeping it adversarial.
Args:
unperturbed_img (numpy.ndarray): Input sample as reference to create
adversarial example.
input_label (numpy.ndarray): Input label.
perturbed_img (numpy.ndarray): Starting point to optimize.
Returns:
- bool, attack success flag.
- numpy.ndarray, the optimized adversarial example.
- int, number of queries.
Raises:
ValueError: if the unperturbed and perturbed samples have different dtypes.
query_count = 0
img_size = unperturbed_img.size
img_shape = unperturbed_img.shape
perturbed_img = perturbed_img.reshape(-1)
unperturbed_img = unperturbed_img.reshape(-1)
recover = np.copy(perturbed_img)
if unperturbed_img.dtype != perturbed_img.dtype:
msg = 'unperturbed sample and perturbed sample must have the same' \
' dtype, but got dtype of unperturbed is: {}, dtype of perturbed ' \
'is: {}'.format(unperturbed_img.dtype, perturbed_img.dtype)
LOGGER.error(TAG, msg)
raise ValueError(msg)
LOGGER.debug(TAG, 'Before optimize, the mse distance between original '
'sample and adversarial sample is: {}'
.format(self._distance(perturbed_img, unperturbed_img)))
# recover pixel if image is adversarial
for _ in range(self._max_iter):
is_improve = False
# at the premise of adversarial feature, recover pixels
pixels_ind = np.arange(img_size)
mask = unperturbed_img != perturbed_img
np.random.shuffle(pixels_ind)
for ite_ind in pixels_ind:
if mask[ite_ind]:
recover[ite_ind] = unperturbed_img[ite_ind]
query_count += 1
is_adv = self._model.is_adversarial(
recover.reshape(img_shape), input_label, self._is_targeted)
if is_adv:
is_improve = True
perturbed_img[ite_ind] = recover[ite_ind]
break
else:
recover[ite_ind] = perturbed_img[ite_ind]
if not is_improve or (self._distance(
perturbed_img, unperturbed_img) <= self._get_threshold()):
break
LOGGER.debug(TAG, 'first round: Query count {}'.format(query_count))
LOGGER.debug(TAG, 'Starting binary searches.')
# tag the optimized pixels.
mask = unperturbed_img != perturbed_img
for _ in range(self._max_iter):
is_improve = False
pixels_ind = np.arange(img_size)
np.random.shuffle(pixels_ind)
for ite_ind in pixels_ind:
if not mask[ite_ind]:
continue
recover[ite_ind] = unperturbed_img[ite_ind]
query_count += 1
is_adv = self._model.is_adversarial(recover.reshape(img_shape),
input_label,
self._is_targeted)
if is_adv:
is_improve = True
mask[ite_ind] = True
perturbed_img[ite_ind] = recover[ite_ind]
LOGGER.debug(TAG,
'Reset {}th pixel value to original, '
'mse distance: {}.'.format(
ite_ind,
self._distance(perturbed_img,
unperturbed_img)))
break
else:
# use binary searches
optimized_value, b_query = self._binary_search(
perturbed_img,
unperturbed_img,
ite_ind,
input_label, img_shape)
query_count += b_query
if optimized_value != perturbed_img[ite_ind]:
is_improve = True
mask[ite_ind] = True
perturbed_img[ite_ind] = optimized_value
LOGGER.debug(TAG,
'Reset {}th pixel value to the binary-searched value, '
'mse distance: {}.'.format(
ite_ind,
self._distance(perturbed_img,
unperturbed_img)))
break
if not is_improve or (self._distance(
perturbed_img, unperturbed_img) <= self._get_threshold()):
LOGGER.debug(TAG, 'Second optimization round finished.')
break
LOGGER.info(TAG, 'Optimization finished, query count is {}'.format(query_count))
# return the optimized adversarial sample
return True, perturbed_img.reshape(img_shape), query_count
def _binary_search(self, perturbed_img, unperturbed_img, ite_ind,
input_label, img_shape):
"""
For a given pixel of the input, use binary search to find the value
closest to the original pixel value that keeps the sample adversarial.
Args:
perturbed_img (numpy.ndarray): Adversarial sample.
unperturbed_img (numpy.ndarray): Input sample.
ite_ind (int): The index of pixel in inputs.
input_label (numpy.ndarray): Input labels.
img_shape (tuple): Shape of the original sample.
Returns:
- float, adversarial pixel value.
- int, number of queries.
"""
query_count = 0
adv_value = perturbed_img[ite_ind]
non_adv_value = unperturbed_img[ite_ind]
for _ in range(self._search_iter):
next_value = (adv_value + non_adv_value) / 2
recover = np.copy(perturbed_img)
recover[ite_ind] = next_value
query_count += 1
is_adversarial = self._model.is_adversarial(
recover.reshape(img_shape), input_label, self._is_targeted)
if is_adversarial:
adv_value = next_value
else:
non_adv_value = next_value
return adv_value, query_count
def _initialize_starting_point(self, inputs, labels):
"""
Use init_attack to generate original adversarial inputs.
Args:
inputs (numpy.ndarray): Benign input samples used as references to create
adversarial examples.
labels (numpy.ndarray): For targeted attack, labels are adversarial
labels; for untargeted attack, labels are ground-truth labels.
Returns:
- numpy.ndarray, bool values indicating whether each starting point is adversarial.
- numpy.ndarray, initial adversarial examples generated by the init_attack method.
- numpy.ndarray, query times for each sample.
"""
is_adv, start_adv, query_c = self._init_attack.generate(inputs, labels)
return is_adv, start_adv, query_c
def _distance(self, perturbed_img, unperturbed_img):
"""
Calculate Mean Squared Error (MSE) to evaluate the optimized process.
Args:
perturbed_img (numpy.ndarray): Adversarial sample to be optimized.
unperturbed_img (numpy.ndarray): Benign sample used as reference.
Returns:
float, the Mean Squared Error (MSE).
"""
return np.square(np.subtract(perturbed_img, unperturbed_img)).mean()
def _get_threshold(self, method='MSE'):
"""
Return a float number; when the distance is smaller than this threshold,
the optimization will abort early.
Args:
method (str): Distance method. Default: 'MSE'.
Returns:
float, the optimization threshold; the smaller the number, the closer
the adversarial sample is to the original sample.
"""
predefined_threshold = 0.01
if method == 'MSE':
return predefined_threshold
return predefined_threshold
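# ---------------------------------------------------------------------------
# A minimal sketch of the per-pixel binary search used in _binary_search
# above, written in plain Python. `stays_adversarial_fn` is a hypothetical
# stand-in for BlackModel.is_adversarial applied to the image with one pixel
# replaced; it is not part of the MindArmour API.
def pixel_binary_search(adv_value, orig_value, stays_adversarial_fn,
                        search_iter=10):
    """Move one pixel value towards orig_value while staying adversarial."""
    for _ in range(search_iter):
        next_value = (adv_value + orig_value) / 2
        if stays_adversarial_fn(next_value):
            adv_value = next_value   # still adversarial, keep moving closer
        else:
            orig_value = next_value  # not adversarial, tighten the other bound
    return adv_value


# Example with a toy decision boundary: the pixel stays adversarial while its
# value is above 0.6, so the search converges towards 0.6 from above.
# print(pixel_binary_search(1.0, 0.0, lambda v: v > 0.6))
# ---------------------------------------------------------------------------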
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
PSO-Attack.
"""
import numpy as np
from mindarmour.attacks.attack import Attack
from mindarmour.utils.logger import LogUtil
from mindarmour.attacks.black.black_model import BlackModel
from mindarmour.utils._check_param import check_model, check_pair_numpy_param, \
check_numpy_param, check_value_positive, check_int_positive, \
check_param_type, check_equal_shape, check_param_multi_types
LOGGER = LogUtil.get_instance()
TAG = 'PSOAttack'
class PSOAttack(Attack):
"""
The PSO Attack represents the black-box attack based on Particle Swarm
Optimization algorithm, which belongs to differential evolution algorithms.
This attack was proposed by Rayan Mosli et al. (2019).
References: `Rayan Mosli, Matthew Wright, Bo Yuan, Yin Pan, "They Might NOT
Be Giants: Crafting Black-Box Adversarial Examples with Fewer Queries
Using Particle Swarm Optimization", arxiv: 1909.07490, 2019.
<https://arxiv.org/abs/1909.07490>`_
Args:
model (BlackModel): Target model.
step_size (float): Attack step size. Default: 0.5.
per_bounds (float): Relative variation range of perturbations. Default: 0.6.
c1 (float): Weight coefficient. Default: 2.
c2 (float): Weight coefficient. Default: 2.
c (float): Weight of perturbation loss. Default: 2.
pop_size (int): The number of particles, which should be greater
than zero. Default: 6.
t_max (int): The maximum round of iteration for each adversarial example,
which should be greater than zero. Default: 1000.
pm (float): The probability of mutations. Default: 0.5.
bounds (tuple): Upper and lower bounds of data. In form of (clip_min,
clip_max). Default: None.
targeted (bool): If True, turns on the targeted attack. If False,
turns on untargeted attack. Default: False.
reduction_iters (int): Cycle times in reduction process. Default: 3.
sparse (bool): If True, input labels are sparse-encoded. If False,
input labels are one-hot-encoded. Default: True.
Examples:
>>> attack = PSOAttack(model)
"""
def __init__(self, model, step_size=0.5, per_bounds=0.6, c1=2.0, c2=2.0,
c=2.0, pop_size=6, t_max=1000, pm=0.5, bounds=None,
targeted=False, reduction_iters=3, sparse=True):
super(PSOAttack, self).__init__()
self._model = check_model('model', model, BlackModel)
self._step_size = check_value_positive('step_size', step_size)
self._per_bounds = check_value_positive('per_bounds', per_bounds)
self._c1 = check_value_positive('c1', c1)
self._c2 = check_value_positive('c2', c2)
self._c = check_value_positive('c', c)
self._pop_size = check_int_positive('pop_size', pop_size)
self._pm = check_value_positive('pm', pm)
self._bounds = check_param_multi_types('bounds', bounds, [list, tuple])
for b in self._bounds:
_ = check_param_multi_types('bound', b, [int, float])
self._targeted = check_param_type('targeted', targeted, bool)
self._t_max = check_int_positive('t_max', t_max)
self._reduce_iters = check_int_positive('reduction_iters',
reduction_iters)
self._sparse = check_param_type('sparse', sparse, bool)
def _fitness(self, confi_ori, confi_adv, x_ori, x_adv):
"""
Calculate the fitness value for each particle.
Args:
confi_ori (float): Maximum confidence or target label confidence of
the original benign inputs' prediction confidences.
confi_adv (float): Maximum confidence or target label confidence of
the adversarial samples' prediction confidences.
x_ori (numpy.ndarray): Benign samples.
x_adv (numpy.ndarray): Adversarial samples.
Returns:
numpy.ndarray, fitness values of adversarial particles.
Examples:
>>> fitness = self._fitness(2.4, 1.2, [0.2, 0.3, 0.1], [0.21,
>>> 0.34, 0.13])
"""
x_ori = check_numpy_param('x_ori', x_ori)
x_adv = check_numpy_param('x_adv', x_adv)
fit_value = abs(
confi_ori - confi_adv) - self._c / self._pop_size*np.linalg.norm(
(x_adv - x_ori).reshape(x_adv.shape[0], -1), axis=1)
return fit_value
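# Illustrative numbers (assumed for this comment only): with confi_ori=0.9,
# confi_adv=0.3, c=2.0, pop_size=6 and a perturbation whose L2 norm is 1.5,
# the fitness is |0.9 - 0.3| - (2.0 / 6)*1.5 = 0.6 - 0.5 = 0.1, so larger
# confidence drops and smaller perturbations both increase the fitness.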
def _mutation_op(self, cur_pop):
"""
Generate mutation samples.
"""
cur_pop = check_numpy_param('cur_pop', cur_pop)
perturb_noise = np.random.random(cur_pop.shape) - 0.5
mutated_pop = perturb_noise*(np.random.random(cur_pop.shape)
< self._pm) + cur_pop
mutated_pop = np.clip(mutated_pop, cur_pop*(1 - self._per_bounds),
cur_pop*(1 + self._per_bounds))
return mutated_pop
def _reduction(self, x_ori, q_times, label, best_position):
"""
Decrease the differences between the original samples and adversarial samples.
Args:
x_ori (numpy.ndarray): Original samples.
q_times (int): Query times.
label (int): Target label or ground-truth label.
best_position (numpy.ndarray): Adversarial examples.
Returns:
numpy.ndarray, adversarial examples after reduction.
Examples:
>>> adv_reduction = self._reduction(self, [0.1, 0.2, 0.3], 20, 1,
>>> [0.12, 0.15, 0.25])
"""
x_ori = check_numpy_param('x_ori', x_ori)
best_position = check_numpy_param('best_position', best_position)
x_ori, best_position = check_equal_shape('x_ori', x_ori,
'best_position', best_position)
x_ori_fla = x_ori.flatten()
best_position_fla = best_position.flatten()
pixel_deep = self._bounds[1] - self._bounds[0]
nums_pixel = len(x_ori_fla)
for i in range(nums_pixel):
diff = x_ori_fla[i] - best_position_fla[i]
if abs(diff) > pixel_deep*0.1:
old_poi_fla = np.copy(best_position_fla)
best_position_fla[i] = np.clip(
best_position_fla[i] + diff*0.5,
self._bounds[0], self._bounds[1])
cur_label = np.argmax(
self._model.predict(np.expand_dims(
best_position_fla.reshape(x_ori.shape), axis=0))[0])
q_times += 1
if self._targeted:
if cur_label != label:
best_position_fla = old_poi_fla
else:
if cur_label == label:
best_position_fla = old_poi_fla
return best_position_fla.reshape(x_ori.shape), q_times
def generate(self, inputs, labels):
"""
Generate adversarial examples based on input data and targeted
labels (or ground_truth labels).
Args:
inputs (numpy.ndarray): Input samples.
labels (numpy.ndarray): Targeted labels or ground_truth labels.
Returns:
- numpy.ndarray, bool values for each attack result.
- numpy.ndarray, generated adversarial examples.
- numpy.ndarray, query times for each sample.
Examples:
>>> advs = attack.generate([[0.2, 0.3, 0.4], [0.3, 0.3, 0.2]],
>>> [1, 2])
"""
inputs, labels = check_pair_numpy_param('inputs', inputs,
'labels', labels)
if not self._sparse:
labels = np.argmax(labels, axis=1)
# generate one adversarial each time
if self._targeted:
target_labels = labels
adv_list = []
success_list = []
query_times_list = []
pixel_deep = self._bounds[1] - self._bounds[0]
for i in range(inputs.shape[0]):
is_success = False
q_times = 0
x_ori = inputs[i]
confidences = self._model.predict(np.expand_dims(x_ori, axis=0))[0]
q_times += 1
true_label = labels[i]
if self._targeted:
t_label = target_labels[i]
confi_ori = confidences[t_label]
else:
confi_ori = max(confidences)
# step1, initializing
# initial global optimum fitness value; it cannot be set to 0
best_fitness = -np.inf
# initial global optimum position
best_position = x_ori
x_copies = np.repeat(x_ori[np.newaxis, :], self._pop_size, axis=0)
cur_noise = np.clip((np.random.random(x_copies.shape) - 0.5)
*self._step_size,
(0 - self._per_bounds)*(x_copies + 0.1),
self._per_bounds*(x_copies + 0.1))
par = np.clip(x_copies + cur_noise,
x_copies*(1 - self._per_bounds),
x_copies*(1 + self._per_bounds))
# initial advs
par_ori = np.copy(par)
# initial optimum positions for particles
par_best_poi = np.copy(par)
# initial optimum fitness values
par_best_fit = -np.inf*np.ones(self._pop_size)
# step2, optimization
# initial velocities for particles
v_particles = np.zeros(par.shape)
is_mutation = False
iters = 0
while iters < self._t_max:
last_best_fit = best_fitness
ran_1 = np.random.random(par.shape)
ran_2 = np.random.random(par.shape)
v_particles = self._step_size*(
v_particles + self._c1*ran_1*(best_position - par)) \
+ self._c2*ran_2*(par_best_poi - par)
par = np.clip(par + v_particles,
(par_ori + 0.1*pixel_deep)*(
1 - self._per_bounds),
(par_ori + 0.1*pixel_deep)*(
1 + self._per_bounds))
if iters > 30 and is_mutation:
par = self._mutation_op(par)
if self._targeted:
confi_adv = self._model.predict(par)[:, t_label]
else:
confi_adv = np.max(self._model.predict(par), axis=1)
q_times += self._pop_size
fit_value = self._fitness(confi_ori, confi_adv, x_ori, par)
for k in range(self._pop_size):
if fit_value[k] > par_best_fit[k]:
par_best_fit[k] = fit_value[k]
par_best_poi[k] = par[k]
if fit_value[k] > best_fitness:
best_fitness = fit_value[k]
best_position = par[k]
iters += 1
cur_pre = self._model.predict(np.expand_dims(best_position,
axis=0))[0]
is_mutation = False
if (best_fitness - last_best_fit) < last_best_fit*0.05:
is_mutation = True
cur_label = np.argmax(cur_pre)
q_times += 1
if self._targeted:
if cur_label == t_label:
is_success = True
else:
if cur_label != true_label:
is_success = True
if is_success:
LOGGER.debug(TAG, 'successfully found one adversarial '
'sample, starting reduction process')
# step3, reduction
if self._targeted:
best_position, q_times = self._reduction(
x_ori, q_times, t_label, best_position)
else:
best_position, q_times = self._reduction(
x_ori, q_times, true_label, best_position)
break
if not is_success:
LOGGER.debug(TAG,
'failed to find an adversarial sample, iteration '
'times is: %d and query times is: %d',
iters,
q_times)
adv_list.append(best_position)
success_list.append(is_success)
query_times_list.append(q_times)
del x_copies, cur_noise, par, par_ori, par_best_poi
return np.asarray(success_list), \
np.asarray(adv_list), \
np.asarray(query_times_list)
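# ---------------------------------------------------------------------------
# A minimal sketch of one particle-swarm update on a toy 1-D objective,
# mirroring the velocity/position update inside generate() above. The helper
# name `pso_step` and the toy objective are hypothetical illustrations; the
# coefficients follow the class defaults (step_size=0.5, c1=c2=2.0).
import numpy as np


def pso_step(par, v_particles, par_best_poi, best_position,
             step_size=0.5, c1=2.0, c2=2.0):
    """One velocity and position update for a swarm of particles."""
    ran_1 = np.random.random(par.shape)
    ran_2 = np.random.random(par.shape)
    v_particles = step_size*(v_particles + c1*ran_1*(best_position - par)) \
        + c2*ran_2*(par_best_poi - par)
    return par + v_particles, v_particles


# Example: maximize a toy fitness -(x - 0.3)**2 with 6 particles in [0, 1].
# fitness = lambda x: -(x - 0.3)**2
# par = np.random.random(6)
# v_particles = np.zeros(6)
# par_best_poi = np.copy(par)
# best_position = par[np.argmax(fitness(par))]
# for _ in range(50):
#     par, v_particles = pso_step(par, v_particles, par_best_poi, best_position)
#     par = np.clip(par, 0.0, 1.0)
#     improved = fitness(par) > fitness(par_best_poi)
#     par_best_poi = np.where(improved, par, par_best_poi)
#     best_position = par_best_poi[np.argmax(fitness(par_best_poi))]
# ---------------------------------------------------------------------------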
"""
This module includes classical defense algorithms for defending against
adversarial examples and enhancing model security and trustworthiness.
"""
from .adversarial_defense import AdversarialDefense
from .adversarial_defense import AdversarialDefenseWithAttacks
from .adversarial_defense import EnsembleAdversarialDefense
from .natural_adversarial_defense import NaturalAdversarialDefense
from .projected_adversarial_defense import ProjectedAdversarialDefense
__all__ = ['AdversarialDefense',
'AdversarialDefenseWithAttacks',
'NaturalAdversarialDefense',
'ProjectedAdversarialDefense',
'EnsembleAdversarialDefense']
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Base Class of Defense.
"""
from abc import abstractmethod
from mindarmour.utils.logger import LogUtil
from mindarmour.utils._check_param import check_pair_numpy_param, \
check_int_positive
LOGGER = LogUtil.get_instance()
TAG = 'Defense'
class Defense:
"""
The abstract base class for all defense classes defending adversarial
examples.
Args:
network (Cell): A MindSpore-style deep learning model to be defended.
"""
def __init__(self, network):
self._network = network
@abstractmethod
def defense(self, inputs, labels):
"""
Defense model with samples.
Args:
inputs (numpy.ndarray): Samples based on which adversarial
examples are generated.
labels (numpy.ndarray): Labels of input samples.
Raises:
NotImplementedError: It is an abstract method.
"""
msg = 'The function defense() is an abstract function in class ' \
'`Defense` and should be implemented in child class.'
LOGGER.error(TAG, msg)
raise NotImplementedError(msg)
def batch_defense(self, inputs, labels, batch_size=32, epochs=5):
"""
Defense model with samples in batch.
Args:
inputs (numpy.ndarray): Samples based on which adversarial
examples are generated.
labels (numpy.ndarray): Labels of input samples.
batch_size (int): Number of samples in one batch.
epochs (int): Number of epochs.
Returns:
numpy.ndarray, loss of batch_defense operation.
Raises:
ValueError: If batch_size is not positive.
"""
inputs, labels = check_pair_numpy_param('inputs', inputs, 'labels',
labels)
x_len = len(inputs)
batch_size = check_int_positive('batch_size', batch_size)
iters_per_epoch = int(x_len / batch_size)
loss = None
for _ in range(epochs):
for step in range(iters_per_epoch):
x_batch = inputs[step*batch_size:(step + 1)*batch_size]
y_batch = labels[step*batch_size:(step + 1)*batch_size]
loss = self.defense(x_batch, y_batch)
return loss
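# ---------------------------------------------------------------------------
# A minimal sketch of how a concrete defense plugs into this base class: the
# subclass only implements defense(), and batch_defense() drives it over
# mini-batches. `ToyDefense` and its fake training step are hypothetical and
# do not correspond to any class shipped with MindArmour.
import numpy as np


class ToyDefense(Defense):
    """Toy subclass whose 'training step' just tracks a running input mean."""

    def __init__(self, network=None):
        super(ToyDefense, self).__init__(network)
        self._running_mean = 0.0

    def defense(self, inputs, labels):
        # stand-in for one adversarial-training step; returns a fake loss
        self._running_mean = 0.9*self._running_mean + 0.1*float(np.mean(inputs))
        return abs(self._running_mean)


# Example:
# x = np.random.random((64, 4)).astype(np.float32)
# y = np.random.randint(0, 2, size=64)
# loss = ToyDefense().batch_defense(x, y, batch_size=16, epochs=2)
# ---------------------------------------------------------------------------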
"""
This module includes various metrics to evaluate the result of attacks or
defenses.
"""
from .attack_evaluation import AttackEvaluate
from .defense_evaluation import DefenseEvaluate
from .visual_metrics import RadarMetric
from . import black
from .black.defense_evaluation import BlackDefenseEvaluate
__all__ = ['AttackEvaluate',
'BlackDefenseEvaluate',
'DefenseEvaluate',
'RadarMetric']