Commit a4c4feca authored by J jin-xiulang

Develop new module for model fuzz test.

Parent 248b479d
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys

import numpy as np
from mindspore import Model
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.nn import SoftmaxCrossEntropyWithLogits

from mindarmour.attacks.gradient_method import FastGradientSignMethod
from mindarmour.utils.logger import LogUtil
from mindarmour.fuzzing.model_coverage_metrics import ModelCoverageMetrics

from lenet5_net import LeNet5

sys.path.append("..")
from data_processing import generate_mnist_dataset

LOGGER = LogUtil.get_instance()
TAG = 'Neuron coverage test'
LOGGER.set_level('INFO')


def test_lenet_mnist_coverage():
    context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
    # load the trained network from checkpoint
    ckpt_name = './trained_ckpt_file/checkpoint_lenet-10_1875.ckpt'
    net = LeNet5()
    load_dict = load_checkpoint(ckpt_name)
    load_param_into_net(net, load_dict)
    model = Model(net)

    # get training data
    data_list = "./MNIST_unzip/train"
    batch_size = 32
    ds = generate_mnist_dataset(data_list, batch_size, sparse=True)
    train_images = []
    for data in ds.create_tuple_iterator():
        images = data[0].astype(np.float32)
        train_images.append(images)
    train_images = np.concatenate(train_images, axis=0)

    # initialize fuzz test with training dataset
    model_fuzz_test = ModelCoverageMetrics(model, 10000, 10, train_images)
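    # ModelCoverageMetrics(model, k, n, train_dataset): here k=10000 output
    # sections per neuron and n=10 monitored neurons (the LeNet-5 logits);
    # the training images fix each neuron's output range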
    # fuzz test with original test data
    # get test data
    data_list = "./MNIST_unzip/test"
    batch_size = 32
    ds = generate_mnist_dataset(data_list, batch_size, sparse=True)
    test_images = []
    test_labels = []
    for data in ds.create_tuple_iterator():
        images = data[0].astype(np.float32)
        labels = data[1]
        test_images.append(images)
        test_labels.append(labels)
    test_images = np.concatenate(test_images, axis=0)
    test_labels = np.concatenate(test_labels, axis=0)
    model_fuzz_test.test_adequacy_coverage_calculate(test_images)
    LOGGER.info(TAG, 'KMNC of this test is : %s', model_fuzz_test.get_kmnc())
    LOGGER.info(TAG, 'NBC of this test is : %s', model_fuzz_test.get_nbc())
    LOGGER.info(TAG, 'SNAC of this test is : %s', model_fuzz_test.get_snac())

    # generate adv_data
    loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
    attack = FastGradientSignMethod(net, eps=0.3, loss_fn=loss)
    adv_data = attack.batch_generate(test_images, test_labels, batch_size=32)
    model_fuzz_test.test_adequacy_coverage_calculate(adv_data,
                                                     bias_coefficient=0.5)
    LOGGER.info(TAG, 'KMNC of this test is : %s', model_fuzz_test.get_kmnc())
    LOGGER.info(TAG, 'NBC of this test is : %s', model_fuzz_test.get_nbc())
    LOGGER.info(TAG, 'SNAC of this test is : %s', model_fuzz_test.get_snac())


if __name__ == '__main__':
    test_lenet_mnist_coverage()
from .model_coverage_metrics import ModelCoverageMetrics
__all__ = ['ModelCoverageMetrics']
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Model-Test Coverage Metrics.
"""
import numpy as np

from mindspore import Tensor
from mindspore import Model

from mindarmour.utils._check_param import check_model, check_numpy_param, \
    check_int_positive


class ModelCoverageMetrics:
    """
    Evaluate the testing adequacy of a model fuzz test.

    Reference: `DeepGauge: Multi-Granularity Testing Criteria for Deep
    Learning Systems <https://arxiv.org/abs/1803.07519>`_

    Args:
        model (Model): The pre-trained model to be tested.
        k (int): The number of segments each neuron's output interval is
            divided into.
        n (int): The number of neurons to be tested.
        train_dataset (numpy.ndarray): Training dataset used to determine
            the neurons' output boundaries.
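
    Examples:
        >>> # a minimal usage sketch; `model` and `train_images` are assumed
        >>> # to be a trained mindspore Model and a numpy.ndarray of samples
        >>> model_fuzz_test = ModelCoverageMetrics(model, 10000, 10, train_images)
        >>> model_fuzz_test.test_adequacy_coverage_calculate(test_images)
        >>> print(model_fuzz_test.get_kmnc())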
"""
def __init__(self, model, k, n, train_dataset):
self._model = check_model('model', model, Model)
self._k = k
self._n = n
train_dataset = check_numpy_param('train_dataset', train_dataset)
self._lower_bounds = [np.inf]*n
self._upper_bounds = [-np.inf]*n
self._var = [0]*n
self._main_section_hits = [[0 for _ in range(self._k)] for _ in
range(self._n)]
self._lower_corner_hits = [0]*self._n
self._upper_corner_hits = [0]*self._n
self._bounds_get(train_dataset)

    def _bounds_get(self, train_dataset, batch_size=32):
        """
        Update the lower and upper boundaries of the neurons' outputs.

        Args:
            train_dataset (numpy.ndarray): Training dataset used to
                determine the neurons' output boundaries.
            batch_size (int): The number of samples in a predict batch.
                Default: 32.
        """
        batch_size = check_int_positive('batch_size', batch_size)
        output_mat = []
        batches = train_dataset.shape[0] // batch_size
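        # run the training data through the model batch by batch, keeping a
        # running per-neuron minimum and maximum of the predicted outputs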
        for i in range(batches):
            inputs = train_dataset[i*batch_size: (i + 1)*batch_size]
            output = self._model.predict(Tensor(inputs)).asnumpy()
            output_mat.append(output)
            lower_compare_array = np.concatenate(
                [output, np.array([self._lower_bounds])], axis=0)
            self._lower_bounds = np.min(lower_compare_array, axis=0)
            upper_compare_array = np.concatenate(
                [output, np.array([self._upper_bounds])], axis=0)
            self._upper_bounds = np.max(upper_compare_array, axis=0)
        self._var = np.std(np.concatenate(np.array(output_mat), axis=0),
                           axis=0)

    def _sections_hits_count(self, dataset, intervals):
        """
        Update the coverage matrix of neurons' output subsections.

        Args:
            dataset (numpy.ndarray): Testing data.
            intervals (list[float]): Segmentation intervals of neurons'
                outputs.
        """
        dataset = check_numpy_param('dataset', dataset)
        batch_output = self._model.predict(Tensor(dataset)).asnumpy()
        batch_section_indexes = (batch_output - self._lower_bounds) // intervals
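        # a section index below 0 means the neuron's output fell under its
        # training-time lower bound (lower corner), an index >= k means it
        # exceeded the upper bound (upper corner); anything in between marks
        # one of the k main sections as covered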
        for section_indexes in batch_section_indexes:
            for i in range(self._n):
                if section_indexes[i] < 0:
                    self._lower_corner_hits[i] = 1
                elif section_indexes[i] >= self._k:
                    self._upper_corner_hits[i] = 1
                else:
                    self._main_section_hits[i][int(section_indexes[i])] = 1

    def test_adequacy_coverage_calculate(self, dataset, bias_coefficient=0,
                                         batch_size=32):
        """
        Calculate the testing adequacy of the given dataset.

        Args:
            dataset (numpy.ndarray): Data for fuzz test.
            bias_coefficient (float): The coefficient used for changing the
                neurons' output boundaries. Default: 0.
            batch_size (int): The number of samples in a predict batch.
                Default: 32.

        Examples:
            >>> model_fuzz_test = ModelCoverageMetrics(model, 10000, 10, train_images)
            >>> model_fuzz_test.test_adequacy_coverage_calculate(test_images)
        """
        dataset = check_numpy_param('dataset', dataset)
        batch_size = check_int_positive('batch_size', batch_size)
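        # shift each neuron's recorded output range outward by
        # bias_coefficient standard deviations (widening it for a positive
        # coefficient), then split it into k equal-width sections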
        self._lower_bounds -= bias_coefficient*self._var
        self._upper_bounds += bias_coefficient*self._var
        intervals = (self._upper_bounds - self._lower_bounds) / self._k
        batches = dataset.shape[0] // batch_size
        for i in range(batches):
            self._sections_hits_count(
                dataset[i*batch_size: (i + 1)*batch_size], intervals)

    def get_kmnc(self):
        """
        Get the metric of 'k-multisection neuron coverage'.

        Returns:
            float, the metric of 'k-multisection neuron coverage'.

        Examples:
            >>> model_fuzz_test.get_kmnc()
        """
        kmnc = np.sum(self._main_section_hits) / (self._n*self._k)
        return kmnc

    def get_nbc(self):
        """
        Get the metric of 'neuron boundary coverage'.

        Returns:
            float, the metric of 'neuron boundary coverage'.

        Examples:
            >>> model_fuzz_test.get_nbc()
        """
        nbc = (np.sum(self._lower_corner_hits) + np.sum(
            self._upper_corner_hits)) / (2*self._n)
        return nbc

    def get_snac(self):
        """
        Get the metric of 'strong neuron activation coverage'.

        Returns:
            float, the metric of 'strong neuron activation coverage'.

        Examples:
            >>> model_fuzz_test.get_snac()
        """
        snac = np.sum(self._upper_corner_hits) / self._n
        return snac
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Model-fuzz coverage test.
"""
import numpy as np
import pytest

from mindspore.train import Model
import mindspore.nn as nn
from mindspore.nn import Cell
from mindspore import context
from mindspore.nn import SoftmaxCrossEntropyWithLogits

from mindarmour.attacks.gradient_method import FastGradientSignMethod
from mindarmour.utils.logger import LogUtil
from mindarmour.fuzzing.model_coverage_metrics import ModelCoverageMetrics

LOGGER = LogUtil.get_instance()
TAG = 'Neuron coverage test'
LOGGER.set_level('INFO')


# for user
class Net(Cell):
    """
    Construct the network of the target model.

    Examples:
        >>> net = Net()
    """

    def __init__(self):
        """
        Introduce the layers used for network construction.
        """
        super(Net, self).__init__()
        self._relu = nn.ReLU()

    def construct(self, inputs):
        """
        Construct network.

        Args:
            inputs (Tensor): Input data.
        """
        out = self._relu(inputs)
        return out


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_card
@pytest.mark.component_mindarmour
def test_lenet_mnist_coverage_cpu():
    context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
    # load network
    net = Net()
    model = Model(net)

    # initialize fuzz test with training dataset
    training_data = (np.random.random((10000, 10))*20).astype(np.float32)
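    # Net is a bare ReLU, so the model output keeps the 10-dim shape of the
    # random inputs; n=10 below therefore covers those 10 output neurons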
    model_fuzz_test = ModelCoverageMetrics(model, 10000, 10, training_data)

    # fuzz test with original test data
    # get test data
    test_data = (np.random.random((2000, 10))*20).astype(np.float32)
    test_labels = np.random.randint(0, 10, 2000).astype(np.int32)
    model_fuzz_test.test_adequacy_coverage_calculate(test_data)
    LOGGER.info(TAG, 'KMNC of this test is : %s', model_fuzz_test.get_kmnc())
    LOGGER.info(TAG, 'NBC of this test is : %s', model_fuzz_test.get_nbc())
    LOGGER.info(TAG, 'SNAC of this test is : %s', model_fuzz_test.get_snac())

    # generate adv_data
    loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
    attack = FastGradientSignMethod(net, eps=0.3, loss_fn=loss)
    adv_data = attack.batch_generate(test_data, test_labels, batch_size=32)
    model_fuzz_test.test_adequacy_coverage_calculate(adv_data,
                                                     bias_coefficient=0.5)
    LOGGER.info(TAG, 'KMNC of this test is : %s', model_fuzz_test.get_kmnc())
    LOGGER.info(TAG, 'NBC of this test is : %s', model_fuzz_test.get_nbc())
    LOGGER.info(TAG, 'SNAC of this test is : %s', model_fuzz_test.get_snac())


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_card
@pytest.mark.component_mindarmour
def test_lenet_mnist_coverage_ascend():
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    # load network
    net = Net()
    model = Model(net)

    # initialize fuzz test with training dataset
    training_data = (np.random.random((10000, 10))*20).astype(np.float32)
    model_fuzz_test = ModelCoverageMetrics(model, 10000, 10, training_data)

    # fuzz test with original test data
    # get test data
    test_data = (np.random.random((2000, 10))*20).astype(np.float32)
    test_labels = np.random.randint(0, 10, 2000)
    test_labels = (np.eye(10)[test_labels]).astype(np.float32)
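    # unlike the CPU test above, the labels are one-hot encoded and the
    # attack below runs with its default loss function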
    model_fuzz_test.test_adequacy_coverage_calculate(test_data)
    LOGGER.info(TAG, 'KMNC of this test is : %s', model_fuzz_test.get_kmnc())
    LOGGER.info(TAG, 'NBC of this test is : %s', model_fuzz_test.get_nbc())
    LOGGER.info(TAG, 'SNAC of this test is : %s', model_fuzz_test.get_snac())

    # generate adv_data
    attack = FastGradientSignMethod(net, eps=0.3)
    adv_data = attack.batch_generate(test_data, test_labels, batch_size=32)
    model_fuzz_test.test_adequacy_coverage_calculate(adv_data,
                                                     bias_coefficient=0.5)
    LOGGER.info(TAG, 'KMNC of this test is : %s', model_fuzz_test.get_kmnc())
    LOGGER.info(TAG, 'NBC of this test is : %s', model_fuzz_test.get_nbc())
    LOGGER.info(TAG, 'SNAC of this test is : %s', model_fuzz_test.get_snac())