From a4c4feca4e82c484cbe0a7a5056b6e3c415f042e Mon Sep 17 00:00:00 2001
From: jin-xiulang
Date: Wed, 22 Apr 2020 16:07:24 +0800
Subject: [PATCH] Develop new module for model fuzz test.

---
 example/mnist_demo/lenet5_mnist_coverage.py      |  89 ++++++++++
 mindarmour/fuzzing/__init__.py                   |   3 +
 mindarmour/fuzzing/model_coverage_metrics.py     | 167 ++++++++++++++++++
 .../python/fuzzing/test_coverage_metrics.py      | 128 ++++++++++++++
 4 files changed, 387 insertions(+)
 create mode 100644 example/mnist_demo/lenet5_mnist_coverage.py
 create mode 100644 mindarmour/fuzzing/__init__.py
 create mode 100644 mindarmour/fuzzing/model_coverage_metrics.py
 create mode 100644 tests/ut/python/fuzzing/test_coverage_metrics.py

diff --git a/example/mnist_demo/lenet5_mnist_coverage.py b/example/mnist_demo/lenet5_mnist_coverage.py
new file mode 100644
index 0000000..b5181e8
--- /dev/null
+++ b/example/mnist_demo/lenet5_mnist_coverage.py
@@ -0,0 +1,89 @@
+# Copyright 2019 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import sys
+import numpy as np
+
+from mindspore import Model
+from mindspore import context
+from mindspore.train.serialization import load_checkpoint, load_param_into_net
+from mindspore.nn import SoftmaxCrossEntropyWithLogits
+
+from mindarmour.attacks.gradient_method import FastGradientSignMethod
+from mindarmour.utils.logger import LogUtil
+from mindarmour.fuzzing.model_coverage_metrics import ModelCoverageMetrics
+
+from lenet5_net import LeNet5
+
+sys.path.append("..")
+from data_processing import generate_mnist_dataset
+
+LOGGER = LogUtil.get_instance()
+TAG = 'Neuron coverage test'
+LOGGER.set_level('INFO')
+
+
+def test_lenet_mnist_coverage():
+    context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
+    # load the trained network
+    ckpt_name = './trained_ckpt_file/checkpoint_lenet-10_1875.ckpt'
+    net = LeNet5()
+    load_dict = load_checkpoint(ckpt_name)
+    load_param_into_net(net, load_dict)
+    model = Model(net)
+
+    # get training data
+    data_list = "./MNIST_unzip/train"
+    batch_size = 32
+    ds = generate_mnist_dataset(data_list, batch_size, sparse=True)
+    train_images = []
+    for data in ds.create_tuple_iterator():
+        images = data[0].astype(np.float32)
+        train_images.append(images)
+    train_images = np.concatenate(train_images, axis=0)
+
+    # initialize fuzz test with training dataset
+    model_fuzz_test = ModelCoverageMetrics(model, 10000, 10, train_images)
+
+    # fuzz test with original test data
+    # get test data
+    data_list = "./MNIST_unzip/test"
+    batch_size = 32
+    ds = generate_mnist_dataset(data_list, batch_size, sparse=True)
+    test_images = []
+    test_labels = []
+    for data in ds.create_tuple_iterator():
+        images = data[0].astype(np.float32)
+        labels = data[1]
+        test_images.append(images)
+        test_labels.append(labels)
+    test_images = np.concatenate(test_images, axis=0)
+    test_labels = np.concatenate(test_labels, axis=0)
+    model_fuzz_test.test_adequacy_coverage_calculate(test_images)
+    LOGGER.info(TAG, 'KMNC of this test is : %s', model_fuzz_test.get_kmnc())
+    LOGGER.info(TAG, 'NBC of this test is : %s', model_fuzz_test.get_nbc())
+    LOGGER.info(TAG, 'SNAC of this test is : %s', model_fuzz_test.get_snac())
+
+    # generate adv_data
+    loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
+    attack = FastGradientSignMethod(net, eps=0.3, loss_fn=loss)
+    adv_data = attack.batch_generate(test_images, test_labels, batch_size=32)
+    model_fuzz_test.test_adequacy_coverage_calculate(adv_data,
+                                                     bias_coefficient=0.5)
+    LOGGER.info(TAG, 'KMNC of this test is : %s', model_fuzz_test.get_kmnc())
+    LOGGER.info(TAG, 'NBC of this test is : %s', model_fuzz_test.get_nbc())
+    LOGGER.info(TAG, 'SNAC of this test is : %s', model_fuzz_test.get_snac())
+
+
+if __name__ == '__main__':
+    test_lenet_mnist_coverage()
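For orientation, the positional arguments of the ModelCoverageMetrics(model, 10000, 10, train_images) call above map onto the constructor parameters documented in model_coverage_metrics.py further down in this patch. The annotated call below is an illustrative sketch, not part of the diff; the comments spell out what each argument is taken to mean.

    # Annotated form of the constructor call used in the example above.
    model_fuzz_test = ModelCoverageMetrics(
        model,         # MindSpore Model whose coverage is being measured
        10000,         # k: number of sections each neuron's output range is split into
        10,            # n: number of tracked output neurons (here the 10 LeNet5 logits)
        train_images)  # data used to profile each neuron's lower/upper output bounds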
diff --git a/mindarmour/fuzzing/__init__.py b/mindarmour/fuzzing/__init__.py
new file mode 100644
index 0000000..c591d2b
--- /dev/null
+++ b/mindarmour/fuzzing/__init__.py
@@ -0,0 +1,3 @@
+from .model_coverage_metrics import ModelCoverageMetrics
+
+__all__ = ['ModelCoverageMetrics']
\ No newline at end of file
diff --git a/mindarmour/fuzzing/model_coverage_metrics.py b/mindarmour/fuzzing/model_coverage_metrics.py
new file mode 100644
index 0000000..bc8f562
--- /dev/null
+++ b/mindarmour/fuzzing/model_coverage_metrics.py
@@ -0,0 +1,167 @@
+# Copyright 2019 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Model-Test Coverage Metrics.
+"""
+
+import numpy as np
+
+from mindspore import Tensor
+from mindspore import Model
+
+from mindarmour.utils._check_param import check_model, check_numpy_param, \
+    check_int_positive
+
+
+class ModelCoverageMetrics:
+    """
+    Evaluate the testing adequacy of a model fuzz test.
+
+    Reference: `DeepGauge: Multi-Granularity Testing Criteria for Deep
+    Learning Systems <https://arxiv.org/abs/1803.07519>`_
+
+    Args:
+        model (Model): The pre-trained model to be tested.
+        k (int): The number of segmented sections of neurons' output intervals.
+        n (int): The number of testing neurons.
+        train_dataset (numpy.ndarray): Training dataset used to determine
+            the neurons' output boundaries.
+    """
+
+    def __init__(self, model, k, n, train_dataset):
+        self._model = check_model('model', model, Model)
+        self._k = k
+        self._n = n
+        train_dataset = check_numpy_param('train_dataset', train_dataset)
+        self._lower_bounds = [np.inf]*n
+        self._upper_bounds = [-np.inf]*n
+        self._var = [0]*n
+        self._main_section_hits = [[0 for _ in range(self._k)] for _ in
+                                   range(self._n)]
+        self._lower_corner_hits = [0]*self._n
+        self._upper_corner_hits = [0]*self._n
+        self._bounds_get(train_dataset)
+
+    def _bounds_get(self, train_dataset, batch_size=32):
+        """
+        Update the lower and upper boundaries of neurons' outputs.
+
+        Args:
+            train_dataset (numpy.ndarray): Training dataset used to
+                determine the neurons' output boundaries.
+            batch_size (int): The number of samples in a predict batch.
+                Default: 32.
+        """
+        batch_size = check_int_positive('batch_size', batch_size)
+        output_mat = []
+        batches = train_dataset.shape[0] // batch_size
+        for i in range(batches):
+            inputs = train_dataset[i*batch_size: (i + 1)*batch_size]
+            output = self._model.predict(Tensor(inputs)).asnumpy()
+            output_mat.append(output)
+            lower_compare_array = np.concatenate(
+                [output, np.array([self._lower_bounds])], axis=0)
+            self._lower_bounds = np.min(lower_compare_array, axis=0)
+            upper_compare_array = np.concatenate(
+                [output, np.array([self._upper_bounds])], axis=0)
+            self._upper_bounds = np.max(upper_compare_array, axis=0)
+        self._var = np.std(np.concatenate(np.array(output_mat), axis=0),
+                           axis=0)
+
+    def _sections_hits_count(self, dataset, intervals):
+        """
+        Update the coverage matrix of neurons' output subsections.
+
+        Args:
+            dataset (numpy.ndarray): Testing data.
+            intervals (list[float]): Segmentation intervals of neurons'
+                outputs.
+        """
+        dataset = check_numpy_param('dataset', dataset)
+        batch_output = self._model.predict(Tensor(dataset)).asnumpy()
+        batch_section_indexes = (batch_output - self._lower_bounds) // intervals
+        for section_indexes in batch_section_indexes:
+            for i in range(self._n):
+                if section_indexes[i] < 0:
+                    self._lower_corner_hits[i] = 1
+                elif section_indexes[i] >= self._k:
+                    self._upper_corner_hits[i] = 1
+                else:
+                    self._main_section_hits[i][int(section_indexes[i])] = 1
+
+    def test_adequacy_coverage_calculate(self, dataset, bias_coefficient=0,
+                                         batch_size=32):
+        """
+        Calculate the testing adequacy of the given dataset.
+
+        Args:
+            dataset (numpy.ndarray): Data for fuzz test.
+            bias_coefficient (float): The coefficient used for changing the
+                neurons' output boundaries. Default: 0.
+            batch_size (int): The number of samples in a predict batch.
+                Default: 32.
+
+        Examples:
+            >>> model_fuzz_test = ModelCoverageMetrics(model, 10000, 10, train_images)
+            >>> model_fuzz_test.test_adequacy_coverage_calculate(test_images)
+        """
+        dataset = check_numpy_param('dataset', dataset)
+        batch_size = check_int_positive('batch_size', batch_size)
+        self._lower_bounds -= bias_coefficient*self._var
+        self._upper_bounds += bias_coefficient*self._var
+        intervals = (self._upper_bounds - self._lower_bounds) / self._k
+        batches = dataset.shape[0] // batch_size
+        for i in range(batches):
+            self._sections_hits_count(
+                dataset[i*batch_size: (i + 1)*batch_size], intervals)
+
+    def get_kmnc(self):
+        """
+        Get the metric of 'k-multisection neuron coverage'.
+
+        Returns:
+            float, the metric of 'k-multisection neuron coverage'.
+
+        Examples:
+            >>> model_fuzz_test.get_kmnc()
+        """
+        kmnc = np.sum(self._main_section_hits) / (self._n*self._k)
+        return kmnc
+
+    def get_nbc(self):
+        """
+        Get the metric of 'neuron boundary coverage'.
+
+        Returns:
+            float, the metric of 'neuron boundary coverage'.
+
+        Examples:
+            >>> model_fuzz_test.get_nbc()
+        """
+        nbc = (np.sum(self._lower_corner_hits) + np.sum(
+            self._upper_corner_hits)) / (2*self._n)
+        return nbc
+
+    def get_snac(self):
+        """
+        Get the metric of 'strong neuron activation coverage'.
+
+        Returns:
+            float, the metric of 'strong neuron activation coverage'.
+
+        Examples:
+            >>> model_fuzz_test.get_snac()
+        """
+        snac = np.sum(self._upper_corner_hits) / self._n
+        return snac
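The three getters reduce the accumulated hit records to simple ratios: KMNC is the fraction of the n*k section cells that were hit, NBC is the fraction of the 2n corner cases (below the lower bound or above the upper bound) that were hit, and SNAC is the fraction of neurons pushed above their upper bound. A small worked example with hypothetical hit counts, mirroring the formulas in get_kmnc, get_nbc and get_snac above:

    # Hypothetical counts for n = 10 tracked neurons and k = 1000 sections.
    n, k = 10, 1000
    main_section_hits = 6500   # cells of the n*k section matrix that were hit
    lower_corner_hits = 4      # neurons observed below their profiled lower bound
    upper_corner_hits = 7      # neurons observed above their profiled upper bound

    kmnc = main_section_hits / (n*k)                       # 0.65
    nbc = (lower_corner_hits + upper_corner_hits) / (2*n)  # 0.55
    snac = upper_corner_hits / n                           # 0.70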
diff --git a/tests/ut/python/fuzzing/test_coverage_metrics.py b/tests/ut/python/fuzzing/test_coverage_metrics.py
new file mode 100644
index 0000000..dd98507
--- /dev/null
+++ b/tests/ut/python/fuzzing/test_coverage_metrics.py
@@ -0,0 +1,128 @@
+# Copyright 2019 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Model-fuzz coverage test.
+"""
+import numpy as np
+import pytest
+
+from mindspore.train import Model
+import mindspore.nn as nn
+from mindspore.nn import Cell
+from mindspore import context
+from mindspore.nn import SoftmaxCrossEntropyWithLogits
+
+from mindarmour.attacks.gradient_method import FastGradientSignMethod
+from mindarmour.utils.logger import LogUtil
+from mindarmour.fuzzing.model_coverage_metrics import ModelCoverageMetrics
+
+LOGGER = LogUtil.get_instance()
+TAG = 'Neuron coverage test'
+LOGGER.set_level('INFO')
+
+
+# for user
+class Net(Cell):
+    """
+    Construct the network of the target model.
+
+    Examples:
+        >>> net = Net()
+    """
+
+    def __init__(self):
+        """
+        Introduce the layers used for network construction.
+        """
+        super(Net, self).__init__()
+        self._relu = nn.ReLU()
+
+    def construct(self, inputs):
+        """
+        Construct network.
+
+        Args:
+            inputs (Tensor): Input data.
+        """
+        out = self._relu(inputs)
+        return out
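Both test functions below (like the MNIST example above) call test_adequacy_coverage_calculate a second time with bias_coefficient=0.5 on adversarial samples. As implemented above, that widens each neuron's profiled range by half of the per-neuron standard deviation on each side before scoring, and the hit matrices are not reset between calls, so the second set of metrics reflects the combined coverage of both passes. A minimal NumPy sketch of the bound adjustment, with illustrative values:

    import numpy as np

    # Illustrative profiled bounds for three neurons.
    lower_bounds = np.array([-1.0, 0.0, 0.5])
    upper_bounds = np.array([4.0, 6.0, 2.5])
    var = np.array([0.8, 1.2, 0.3])   # per-neuron std measured on the training data
    bias_coefficient = 0.5
    k = 1000

    # Same adjustment as test_adequacy_coverage_calculate applies internally.
    lower_bounds -= bias_coefficient*var
    upper_bounds += bias_coefficient*var
    intervals = (upper_bounds - lower_bounds) / k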
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_card
+@pytest.mark.component_mindarmour
+def test_lenet_mnist_coverage_cpu():
+    context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
+    # load network
+    net = Net()
+    model = Model(net)
+
+    # initialize fuzz test with training dataset
+    training_data = (np.random.random((10000, 10))*20).astype(np.float32)
+    model_fuzz_test = ModelCoverageMetrics(model, 10000, 10, training_data)
+
+    # fuzz test with original test data
+    # get test data
+    test_data = (np.random.random((2000, 10))*20).astype(np.float32)
+    test_labels = np.random.randint(0, 10, 2000).astype(np.int32)
+    model_fuzz_test.test_adequacy_coverage_calculate(test_data)
+    LOGGER.info(TAG, 'KMNC of this test is : %s', model_fuzz_test.get_kmnc())
+    LOGGER.info(TAG, 'NBC of this test is : %s', model_fuzz_test.get_nbc())
+    LOGGER.info(TAG, 'SNAC of this test is : %s', model_fuzz_test.get_snac())
+
+    # generate adv_data
+    loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
+    attack = FastGradientSignMethod(net, eps=0.3, loss_fn=loss)
+    adv_data = attack.batch_generate(test_data, test_labels, batch_size=32)
+    model_fuzz_test.test_adequacy_coverage_calculate(adv_data,
+                                                     bias_coefficient=0.5)
+    LOGGER.info(TAG, 'KMNC of this test is : %s', model_fuzz_test.get_kmnc())
+    LOGGER.info(TAG, 'NBC of this test is : %s', model_fuzz_test.get_nbc())
+    LOGGER.info(TAG, 'SNAC of this test is : %s', model_fuzz_test.get_snac())
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_card
+@pytest.mark.component_mindarmour
+def test_lenet_mnist_coverage_ascend():
+    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+    # load network
+    net = Net()
+    model = Model(net)
+
+    # initialize fuzz test with training dataset
+    training_data = (np.random.random((10000, 10))*20).astype(np.float32)
+    model_fuzz_test = ModelCoverageMetrics(model, 10000, 10, training_data)
+
+    # fuzz test with original test data
+    # get test data
+    test_data = (np.random.random((2000, 10))*20).astype(np.float32)
+    test_labels = np.random.randint(0, 10, 2000)
+    test_labels = (np.eye(10)[test_labels]).astype(np.float32)
+    model_fuzz_test.test_adequacy_coverage_calculate(test_data)
+    LOGGER.info(TAG, 'KMNC of this test is : %s', model_fuzz_test.get_kmnc())
+    LOGGER.info(TAG, 'NBC of this test is : %s', model_fuzz_test.get_nbc())
+    LOGGER.info(TAG, 'SNAC of this test is : %s', model_fuzz_test.get_snac())
+
+    # generate adv_data
+    attack = FastGradientSignMethod(net, eps=0.3)
+    adv_data = attack.batch_generate(test_data, test_labels, batch_size=32)
+    model_fuzz_test.test_adequacy_coverage_calculate(adv_data,
+                                                     bias_coefficient=0.5)
+    LOGGER.info(TAG, 'KMNC of this test is : %s', model_fuzz_test.get_kmnc())
+    LOGGER.info(TAG, 'NBC of this test is : %s', model_fuzz_test.get_nbc())
+    LOGGER.info(TAG, 'SNAC of this test is : %s', model_fuzz_test.get_snac())
\ No newline at end of file
--
GitLab