提交 3e5951cd 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!77 Add comment and param-check of Fuzzer

Merge pull request !77 from pkuliuliu/master
......@@ -67,7 +67,7 @@ def test_lenet_mnist_coverage():
test_labels.append(labels)
test_images = np.concatenate(test_images, axis=0)
test_labels = np.concatenate(test_labels, axis=0)
model_fuzz_test.test_adequacy_coverage_calculate(test_images)
model_fuzz_test.calculate_coverage(test_images)
LOGGER.info(TAG, 'KMNC of this test is : %s', model_fuzz_test.get_kmnc())
LOGGER.info(TAG, 'NBC of this test is : %s', model_fuzz_test.get_nbc())
LOGGER.info(TAG, 'SNAC of this test is : %s', model_fuzz_test.get_snac())
......@@ -76,14 +76,13 @@ def test_lenet_mnist_coverage():
loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
attack = FastGradientSignMethod(net, eps=0.3, loss_fn=loss)
adv_data = attack.batch_generate(test_images, test_labels, batch_size=32)
model_fuzz_test.test_adequacy_coverage_calculate(adv_data,
bias_coefficient=0.5)
LOGGER.info(TAG, 'KMNC of this test is : %s', model_fuzz_test.get_kmnc())
LOGGER.info(TAG, 'NBC of this test is : %s', model_fuzz_test.get_nbc())
LOGGER.info(TAG, 'SNAC of this test is : %s', model_fuzz_test.get_snac())
model_fuzz_test.calculate_coverage(adv_data, bias_coefficient=0.5)
LOGGER.info(TAG, 'KMNC of this adv data is : %s', model_fuzz_test.get_kmnc())
LOGGER.info(TAG, 'NBC of this adv data is : %s', model_fuzz_test.get_nbc())
LOGGER.info(TAG, 'SNAC of this adv data is : %s', model_fuzz_test.get_snac())
if __name__ == '__main__':
# device_target can be "CPU", "GPU" or "Ascend"
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
test_lenet_mnist_coverage()
......@@ -87,9 +87,9 @@ def test_lenet_mnist_fuzzing():
LOGGER.info(TAG, 'KMNC of this test is : %s',
model_coverage_test.get_kmnc())
model_fuzz_test = Fuzzer(model, train_images, 1000, 10)
model_fuzz_test = Fuzzer(model, train_images, 10, 1000)
_, _, _, _, metrics = model_fuzz_test.fuzzing(mutate_config, initial_seeds,
eval_metric=True)
eval_metrics='auto')
if metrics:
for key in metrics:
LOGGER.info(TAG, key + ': %s', metrics[key])
......@@ -97,5 +97,5 @@ def test_lenet_mnist_fuzzing():
if __name__ == '__main__':
# device_target can be "CPU", "GPU" or "Ascend"
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
test_lenet_mnist_fuzzing()
......@@ -27,6 +27,48 @@ from mindarmour.fuzzing.image_transform import Contrast, Brightness, Blur, \
Noise, Translate, Scale, Shear, Rotate
from mindarmour.attacks import FastGradientSignMethod, \
MomentumDiverseInputIterativeMethod, ProjectedGradientDescent
from mindarmour.utils.logger import LogUtil
LOGGER = LogUtil.get_instance()
TAG = 'Fuzzer'
def _select_next(initial_seeds):
""" Randomly select a seed from `initial_seeds`."""
seed_num = choice(range(len(initial_seeds)))
seed = initial_seeds[seed_num]
del initial_seeds[seed_num]
return seed, initial_seeds
def _coverage_gains(coverages):
""" Calculate the coverage gains of mutated samples. """
gains = [0] + coverages[:-1]
gains = np.array(coverages) - np.array(gains)
return gains
def _is_trans_valid(seed, mutate_sample):
""" Check a mutated sample is valid. If the number of changed pixels in
a seed is less than pixels_change_rate*size(seed), this mutate is valid.
Else check the infinite norm of seed changes, if the value of the
infinite norm less than pixel_value_change_rate*255, this mutate is
valid too. Otherwise the opposite.
"""
is_valid = False
pixels_change_rate = 0.02
pixel_value_change_rate = 0.2
diff = np.array(seed - mutate_sample).flatten()
size = np.shape(diff)[0]
l0_norm = np.linalg.norm(diff, ord=0)
linf = np.linalg.norm(diff, ord=np.inf)
if l0_norm > pixels_change_rate*size:
if linf < 256:
is_valid = True
else:
if linf < pixel_value_change_rate*255:
is_valid = True
return is_valid
class Fuzzer:
......@@ -40,71 +82,203 @@ class Fuzzer:
target_model (Model): Target fuzz model.
train_dataset (numpy.ndarray): Training dataset used for determining
the neurons' output boundaries.
segmented_num (int): The number of segmented sections of neurons'
output intervals.
neuron_num (int): The number of testing neurons.
segmented_num (int): The number of segmented sections of neurons'
output intervals. Default: 1000.
Examples:
>>> net = Net()
>>> mutate_config = [{'method': 'Blur', 'params': {'auto_param': True}},
>>> {'method': 'Contrast','params': {'factor': 2}},
>>> {'method': 'Translate', 'params': {'x_bias': 0.1, 'y_bias': 0.2}},
>>> {'method': 'FGSM', 'params': {'eps': 0.1, 'alpha': 0.1}}]
>>> train_images = np.random.rand(32, 1, 32, 32).astype(np.float32)
>>> model_fuzz_test = Fuzzer(model, train_images, 1000, 10)
>>> samples, labels, preds, strategies, report = model_fuzz_test.fuzzing(mutate_config, initial_seeds)
"""
def __init__(self, target_model, train_dataset, segmented_num, neuron_num):
self.target_model = check_model('model', target_model, Model)
self.train_dataset = check_numpy_param('train_dataset', train_dataset)
self.coverage_metrics = ModelCoverageMetrics(target_model,
def __init__(self, target_model, train_dataset, neuron_num, segmented_num=1000):
self._target_model = check_model('model', target_model, Model)
train_dataset = check_numpy_param('train_dataset', train_dataset)
self._coverage_metrics = ModelCoverageMetrics(target_model,
segmented_num,
neuron_num, train_dataset)
# Allowed mutate strategies so far.
self.strategies = {'Contrast': Contrast, 'Brightness': Brightness,
self._strategies = {'Contrast': Contrast, 'Brightness': Brightness,
'Blur': Blur, 'Noise': Noise, 'Translate': Translate,
'Scale': Scale, 'Shear': Shear, 'Rotate': Rotate,
'FGSM': FastGradientSignMethod,
'PGD': ProjectedGradientDescent,
'MDIIM': MomentumDiverseInputIterativeMethod}
self.affine_trans_list = ['Translate', 'Scale', 'Shear', 'Rotate']
self.pixel_value_trans_list = ['Contrast', 'Brightness', 'Blur',
self._affine_trans_list = ['Translate', 'Scale', 'Shear', 'Rotate']
self._pixel_value_trans_list = ['Contrast', 'Brightness', 'Blur',
'Noise']
self.attacks_list = ['FGSM', 'PGD', 'MDIIM']
self.attack_param_checklists = {
self._attacks_list = ['FGSM', 'PGD', 'MDIIM']
self._attack_param_checklists = {
'FGSM': {'params': {'eps': {'dtype': [float, int], 'range': [0, 1]},
'alpha': {'dtype': [float, int],
'range': [0, 1]},
'bounds': {'dtype': [list, tuple],
'range': None},
}},
'range': None}}},
'PGD': {'params': {'eps': {'dtype': [float, int], 'range': [0, 1]},
'eps_iter': {'dtype': [float, int],
'range': [0, 1e5]},
'nb_iter': {'dtype': [float, int],
'range': [0, 1e5]},
'bounds': {'dtype': [list, tuple],
'range': None},
}},
'range': None}}},
'MDIIM': {
'params': {'eps': {'dtype': [float, int], 'range': [0, 1]},
'norm_level': {'dtype': [str], 'range': None},
'prob': {'dtype': [float, int], 'range': [0, 1]},
'bounds': {'dtype': [list, tuple], 'range': None},
}}}
'bounds': {'dtype': [list, tuple], 'range': None}}}}
def fuzzing(self, mutate_config, initial_seeds, coverage_metric='KMNC',
eval_metrics='auto', max_iters=10000, mutate_num_per_seed=20):
"""
Fuzzing tests for deep neural networks.
Args:
mutate_config (list): Mutate configs. The format is
[{'method': 'Blur', 'params': {'auto_param': True}}, {'method': 'Contrast', 'params': {'factor': 2}}].
The support methods list is in `self._strategies`, and the params of each
method must within the range of changeable parameters.
initial_seeds (numpy.ndarray): Initial seeds used to generate
mutated samples.
coverage_metric (str): Model coverage metric of neural networks.
Default: 'KMNC'.
eval_metrics (Union[list, tuple, str]): Evaluation metrics. If the type is 'auto',
it will calculate all the metrics, else if the type is list or tuple, it will
calculate the metrics specified by user. Default: 'auto'.
max_iters (int): Max number of select a seed to mutate.
Default: 10000.
mutate_num_per_seed (int): The number of mutate times for a seed.
Default: 20.
Returns:
- list, mutated samples in fuzzing.
- list, ground truth labels of mutated samples.
- list, preds of mutated samples.
- list, strategies of mutated samples.
- dict, metrics report of fuzzer.
Raises:
TypeError: If the type of `eval_metrics` is not str, list or tuple.
TypeError: If the type of metric in list `eval_metrics` is not str.
ValueError: If `eval_metrics` is not equal to 'auto' when it's type is str.
ValueError: If metric in list `eval_metrics` is not in ['accuracy', 'attack_success_rate',
'kmnc', 'nbc', 'snac'].
"""
eval_metrics_ = None
if isinstance(eval_metrics, (list, tuple)):
eval_metrics_ = []
avaliable_metrics = ['accuracy', 'attack_success_rate', 'kmnc', 'nbc', 'snac']
for elem in eval_metrics:
if not isinstance(elem, str):
msg = 'the type of metric in list `eval_metrics` must be str, but got {}.' \
.format(type(elem))
LOGGER.error(TAG, msg)
raise TypeError(msg)
if elem not in avaliable_metrics:
msg = 'metric in list `eval_metrics` must be in {}, but got {}.' \
.format(avaliable_metrics, elem)
LOGGER.error(TAG, msg)
raise ValueError(msg)
eval_metrics_.append(elem.lower())
elif isinstance(eval_metrics, str):
if eval_metrics != 'auto':
msg = "the value of `eval_metrics` must be 'auto' if it's type is str, " \
"but got {}.".format(eval_metrics)
LOGGER.error(TAG, msg)
raise ValueError(msg)
eval_metrics_ = 'auto'
else:
msg = "the type of `eval_metrics` must be str, list or tuple, but got {}." \
.format(type(eval_metrics))
LOGGER.error(TAG, msg)
raise TypeError(msg)
# Check whether the mutate_config meet the specification.
mutates = self._init_mutates(mutate_config)
seed, initial_seeds = _select_next(initial_seeds)
fuzz_samples = []
gt_labels = []
fuzz_preds = []
fuzz_strategies = []
iter_num = 0
while initial_seeds and iter_num < max_iters:
# Mutate a seed.
mutate_samples, mutate_strategies = self._metamorphic_mutate(seed,
mutates,
mutate_config,
mutate_num_per_seed)
# Calculate the coverages and predictions of generated samples.
coverages, predicts = self._run(mutate_samples, coverage_metric)
coverage_gains = _coverage_gains(coverages)
for mutate, cov, pred, strategy in zip(mutate_samples,
coverage_gains,
predicts, mutate_strategies):
fuzz_samples.append(mutate[0])
gt_labels.append(mutate[1])
fuzz_preds.append(pred)
fuzz_strategies.append(strategy)
# if the mutate samples has coverage gains add this samples in
# the initial seeds to guide new mutates.
if cov > 0:
initial_seeds.append(mutate)
seed, initial_seeds = _select_next(initial_seeds)
iter_num += 1
metrics_report = None
if eval_metrics_ is not None:
metrics_report = self._evaluate(fuzz_samples, gt_labels, fuzz_preds,
fuzz_strategies, eval_metrics_)
return fuzz_samples, gt_labels, fuzz_preds, fuzz_strategies, metrics_report
def _run(self, mutate_samples, coverage_metric="KNMC"):
""" Calculate the coverages and predictions of generated samples."""
samples = [s[0] for s in mutate_samples]
samples = np.array(samples)
coverages = []
predictions = self._target_model.predict(Tensor(samples.astype(np.float32)))
predictions = predictions.asnumpy()
for index in range(len(samples)):
mutate = samples[:index + 1]
self._coverage_metrics.calculate_coverage(mutate.astype(np.float32))
if coverage_metric == "KMNC":
coverages.append(self._coverage_metrics.get_kmnc())
if coverage_metric == 'NBC':
coverages.append(self._coverage_metrics.get_nbc())
if coverage_metric == 'SNAC':
coverages.append(self._coverage_metrics.get_snac())
return coverages, predictions
def _check_attack_params(self, method, params):
"""Check input parameters of attack methods."""
allow_params = self.attack_param_checklists[method]['params'].keys()
for p in params:
if p not in allow_params:
allow_params = self._attack_param_checklists[method]['params'].keys()
for param_name in params:
if param_name not in allow_params:
msg = "parameters of {} must in {}".format(method, allow_params)
raise ValueError(msg)
if p == 'bounds':
bounds = check_param_multi_types('bounds', params[p],
param_value = params[param_name]
if param_name == 'bounds':
bounds = check_param_multi_types('bounds', param_value,
[list, tuple])
for b in bounds:
_ = check_param_multi_types('bound', b, [int, float])
elif p == 'norm_level':
_ = check_norm_level(params[p])
for bound_value in bounds:
_ = check_param_multi_types('bound', bound_value, [int, float])
elif param_name == 'norm_level':
_ = check_norm_level(param_value)
else:
allow_type = self.attack_param_checklists[method]['params'][p][
allow_type = self._attack_param_checklists[method]['params'][param_name][
'dtype']
allow_range = self.attack_param_checklists[method]['params'][p][
allow_range = self._attack_param_checklists[method]['params'][param_name][
'range']
_ = check_param_multi_types(str(p), params[p], allow_type)
_ = check_param_in_range(str(p), params[p], allow_range[0],
_ = check_param_multi_types(str(param_name), param_value, allow_type)
_ = check_param_in_range(str(param_name), param_value, allow_range[0],
allow_range[1])
def _metamorphic_mutate(self, seed, mutates, mutate_config,
......@@ -117,23 +291,23 @@ class Fuzzer:
strage = choice(mutate_config)
# Choose a pixel value based transform method
if only_pixel_trans:
while strage['method'] not in self.pixel_value_trans_list:
while strage['method'] not in self._pixel_value_trans_list:
strage = choice(mutate_config)
transform = mutates[strage['method']]
params = strage['params']
method = strage['method']
if method in list(self.pixel_value_trans_list + self.affine_trans_list):
if method in list(self._pixel_value_trans_list + self._affine_trans_list):
transform.set_params(**params)
mutate_sample = transform.transform(seed[0])
else:
for p in params:
transform.__setattr__('_'+str(p), params[p])
for param_name in params:
transform.__setattr__('_' + str(param_name), params[param_name])
mutate_sample = transform.generate([seed[0].astype(np.float32)],
[seed[1]])[0]
if method not in self.pixel_value_trans_list:
if method not in self._pixel_value_trans_list:
only_pixel_trans = 1
mutate_sample = [mutate_sample, seed[1], only_pixel_trans]
if self._is_trans_valid(seed[0], mutate_sample[0]):
if _is_trans_valid(seed[0], mutate_sample[0]):
mutate_samples.append(mutate_sample)
mutate_strategies.append(method)
if not mutate_samples:
......@@ -145,29 +319,29 @@ class Fuzzer:
""" Check whether the mutate_config meet the specification."""
has_pixel_trans = False
for mutate in mutate_config:
if mutate['method'] in self.pixel_value_trans_list:
if mutate['method'] in self._pixel_value_trans_list:
has_pixel_trans = True
break
if not has_pixel_trans:
msg = "mutate methods in mutate_config at lease have one in {}".format(
self.pixel_value_trans_list)
self._pixel_value_trans_list)
raise ValueError(msg)
mutates = {}
for mutate in mutate_config:
method = mutate['method']
params = mutate['params']
if method not in self.attacks_list:
mutates[method] = self.strategies[method]()
if method not in self._attacks_list:
mutates[method] = self._strategies[method]()
else:
self._check_attack_params(method, params)
network = self.target_model._network
loss_fn = self.target_model._loss_fn
mutates[method] = self.strategies[method](network,
network = self._target_model._network
loss_fn = self._target_model._loss_fn
mutates[method] = self._strategies[method](network,
loss_fn=loss_fn)
return mutates
def evaluate(self, fuzz_samples, gt_labels, fuzz_preds,
fuzz_strategies):
def _evaluate(self, fuzz_samples, gt_labels, fuzz_preds,
fuzz_strategies, metrics):
"""
Evaluate generated fuzzing samples in three dimention: accuracy,
attack success rate and neural coverage.
......@@ -177,147 +351,40 @@ class Fuzzer:
gt_labels (numpy.ndarray): Ground Truth of seeds.
fuzz_preds (numpy.ndarray): Predictions of generated fuzz samples.
fuzz_strategies (numpy.ndarray): Mutate strategies of fuzz samples.
metrics (Union[list, tuple, str]): evaluation metrics.
Returns:
dict, evaluate metrics include accuarcy, attack success rate
and neural coverage.
"""
temp = np.argmax(gt_labels, axis=1) == np.argmax(fuzz_preds, axis=1)
metrics_report = {}
if metrics == 'auto' or 'accuracy' in metrics:
gt_labels = np.asarray(gt_labels)
fuzz_preds = np.asarray(fuzz_preds)
temp = np.argmax(gt_labels, axis=1) == np.argmax(fuzz_preds, axis=1)
acc = np.sum(temp) / np.size(temp)
metrics_report['Accuracy'] = acc
cond = [elem in self.attacks_list for elem in fuzz_strategies]
if metrics == 'auto' or 'attack_success_rate' in metrics:
cond = [elem in self._attacks_list for elem in fuzz_strategies]
temp = temp[cond]
attack_success_rate = 1 - np.sum(temp) / np.size(temp)
metrics_report['Attack_success_rate'] = attack_success_rate
self.coverage_metrics.calculate_coverage(
if metrics == 'auto' or 'kmnc' in metrics or 'nbc' in metrics or 'snac' in metrics:
self._coverage_metrics.calculate_coverage(
np.array(fuzz_samples).astype(np.float32))
kmnc = self.coverage_metrics.get_kmnc()
nbc = self.coverage_metrics.get_nbc()
snac = self.coverage_metrics.get_snac()
metrics = {}
metrics['Accuracy'] = acc
metrics['Attack_succrss_rate'] = attack_success_rate
metrics['Neural_coverage_KMNC'] = kmnc
metrics['Neural_coverage_NBC'] = nbc
metrics['Neural_coverage_SNAC'] = snac
return metrics
def fuzzing(self, mutate_config, initial_seeds, coverage_metric='KMNC',
eval_metric=True, max_iters=10000, mutate_num_per_seed=20):
"""
Fuzzing tests for deep neural networks.
if metrics == 'auto' or 'kmnc' in metrics:
kmnc = self._coverage_metrics.get_kmnc()
metrics_report['Neural_coverage_KMNC'] = kmnc
Args:
mutate_config (list): Mutate configs. The format is
[{'method': 'Blur',
'params': {'auto_param': True}},
{'method': 'Contrast',
'params': {'factor': 2}},
...]. The support methods list is in `self.strategies`,
The params of each method must within the range of changeable
parameters.
initial_seeds (numpy.ndarray): Initial seeds used to generate
mutated samples.
coverage_metric (str): Model coverage metric of neural networks.
Default: 'KMNC'.
eval_metric (bool): Whether to evaluate the generated fuzz samples.
Default: True.
max_iters (int): Max number of select a seed to mutate.
Default: 10000.
mutate_num_per_seed (int): The number of mutate times for a seed.
Default: 20.
if metrics == 'auto' or 'nbc' in metrics:
nbc = self._coverage_metrics.get_nbc()
metrics_report['Neural_coverage_NBC'] = nbc
Returns:
list, mutated samples.
"""
# Check whether the mutate_config meet the specification.
mutates = self._init_mutates(mutate_config)
seed, initial_seeds = self._select_next(initial_seeds)
fuzz_samples = []
gt_labels = []
fuzz_preds = []
fuzz_strategies = []
iter_num = 0
while initial_seeds and iter_num < max_iters:
# Mutate a seed.
mutate_samples, mutate_strategies = self._metamorphic_mutate(seed,
mutates,
mutate_config,
mutate_num_per_seed)
# Calculate the coverages and predictions of generated samples.
coverages, predicts = self._run(mutate_samples, coverage_metric)
coverage_gains = self._coverage_gains(coverages)
for mutate, cov, pred, strategy in zip(mutate_samples,
coverage_gains,
predicts, mutate_strategies):
fuzz_samples.append(mutate[0])
gt_labels.append(mutate[1])
fuzz_preds.append(pred)
fuzz_strategies.append(strategy)
# if the mutate samples has coverage gains add this samples in
# the initial seeds to guide new mutates.
if cov > 0:
initial_seeds.append(mutate)
seed, initial_seeds = self._select_next(initial_seeds)
iter_num += 1
metrics = None
if eval_metric:
metrics = self.evaluate(fuzz_samples, gt_labels, fuzz_preds,
fuzz_strategies)
return fuzz_samples, gt_labels, fuzz_preds, fuzz_strategies, metrics
def _coverage_gains(self, coverages):
""" Calculate the coverage gains of mutated samples. """
gains = [0] + coverages[:-1]
gains = np.array(coverages) - np.array(gains)
return gains
def _run(self, mutate_samples, coverage_metric="KNMC"):
""" Calculate the coverages and predictions of generated samples."""
samples = [s[0] for s in mutate_samples]
samples = np.array(samples)
coverages = []
predictions = self.target_model.predict(Tensor(samples.astype(np.float32)))
predictions = predictions.asnumpy()
for index in range(len(samples)):
mutate = samples[:index + 1]
self.coverage_metrics.calculate_coverage(mutate.astype(np.float32))
if coverage_metric == "KMNC":
coverages.append(self.coverage_metrics.get_kmnc())
if coverage_metric == 'NBC':
coverages.append(self.coverage_metrics.get_nbc())
if coverage_metric == 'SNAC':
coverages.append(self.coverage_metrics.get_snac())
return coverages, predictions
def _select_next(self, initial_seeds):
"""Randomly select a seed from `initial_seeds`."""
seed_num = choice(range(len(initial_seeds)))
seed = initial_seeds[seed_num]
del initial_seeds[seed_num]
return seed, initial_seeds
if metrics == 'auto' or 'snac' in metrics:
snac = self._coverage_metrics.get_snac()
metrics_report['Neural_coverage_SNAC'] = snac
def _is_trans_valid(self, seed, mutate_sample):
""" Check a mutated sample is valid. If the number of changed pixels in
a seed is less than pixels_change_rate*size(seed), this mutate is valid.
Else check the infinite norm of seed changes, if the value of the
infinite norm less than pixel_value_change_rate*255, this mutate is
valid too. Otherwise the opposite."""
is_valid = False
pixels_change_rate = 0.02
pixel_value_change_rate = 0.2
diff = np.array(seed - mutate_sample).flatten()
size = np.shape(diff)[0]
l0 = np.linalg.norm(diff, ord=0)
linf = np.linalg.norm(diff, ord=np.inf)
if l0 > pixels_change_rate*size:
if linf < 256:
is_valid = True
else:
if linf < pixel_value_change_rate*255:
is_valid = True
return is_valid
return metrics_report
......@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Model-fuzz coverage test.
Model-fuzzer test.
"""
import numpy as np
import pytest
......@@ -121,7 +121,7 @@ def test_fuzzing_ascend():
LOGGER.info(TAG, 'KMNC of this test is : %s',
model_coverage_test.get_kmnc())
model_fuzz_test = Fuzzer(model, train_images, 1000, 10)
model_fuzz_test = Fuzzer(model, train_images, 10, 1000)
_, _, _, _, metrics = model_fuzz_test.fuzzing(mutate_config, initial_seeds)
print(metrics)
......@@ -167,6 +167,6 @@ def test_fuzzing_cpu():
LOGGER.info(TAG, 'KMNC of this test is : %s',
model_coverage_test.get_kmnc())
model_fuzz_test = Fuzzer(model, train_images, 1000, 10)
model_fuzz_test = Fuzzer(model, train_images, 10, 1000)
_, _, _, _, metrics = model_fuzz_test.fuzzing(mutate_config, initial_seeds)
print(metrics)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册