From 425cc952835647f1780ccb3792ce1bf05c2447b9 Mon Sep 17 00:00:00 2001 From: pkuliuliu Date: Tue, 18 Aug 2020 20:59:49 +0800 Subject: [PATCH] add comment and param check --- example/mnist_demo/lenet5_mnist_coverage.py | 13 +- example/mnist_demo/lenet5_mnist_fuzzing.py | 6 +- mindarmour/fuzzing/fuzzing.py | 441 +++++++++++--------- tests/ut/python/fuzzing/test_fuzzer.py | 6 +- 4 files changed, 266 insertions(+), 200 deletions(-) diff --git a/example/mnist_demo/lenet5_mnist_coverage.py b/example/mnist_demo/lenet5_mnist_coverage.py index 6a93c30..5f299dd 100644 --- a/example/mnist_demo/lenet5_mnist_coverage.py +++ b/example/mnist_demo/lenet5_mnist_coverage.py @@ -67,7 +67,7 @@ def test_lenet_mnist_coverage(): test_labels.append(labels) test_images = np.concatenate(test_images, axis=0) test_labels = np.concatenate(test_labels, axis=0) - model_fuzz_test.test_adequacy_coverage_calculate(test_images) + model_fuzz_test.calculate_coverage(test_images) LOGGER.info(TAG, 'KMNC of this test is : %s', model_fuzz_test.get_kmnc()) LOGGER.info(TAG, 'NBC of this test is : %s', model_fuzz_test.get_nbc()) LOGGER.info(TAG, 'SNAC of this test is : %s', model_fuzz_test.get_snac()) @@ -76,14 +76,13 @@ def test_lenet_mnist_coverage(): loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True) attack = FastGradientSignMethod(net, eps=0.3, loss_fn=loss) adv_data = attack.batch_generate(test_images, test_labels, batch_size=32) - model_fuzz_test.test_adequacy_coverage_calculate(adv_data, - bias_coefficient=0.5) - LOGGER.info(TAG, 'KMNC of this test is : %s', model_fuzz_test.get_kmnc()) - LOGGER.info(TAG, 'NBC of this test is : %s', model_fuzz_test.get_nbc()) - LOGGER.info(TAG, 'SNAC of this test is : %s', model_fuzz_test.get_snac()) + model_fuzz_test.calculate_coverage(adv_data, bias_coefficient=0.5) + LOGGER.info(TAG, 'KMNC of this adv data is : %s', model_fuzz_test.get_kmnc()) + LOGGER.info(TAG, 'NBC of this adv data is : %s', model_fuzz_test.get_nbc()) + LOGGER.info(TAG, 'SNAC of this adv data is : %s', model_fuzz_test.get_snac()) if __name__ == '__main__': # device_target can be "CPU", "GPU" or "Ascend" - context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") test_lenet_mnist_coverage() diff --git a/example/mnist_demo/lenet5_mnist_fuzzing.py b/example/mnist_demo/lenet5_mnist_fuzzing.py index 8a372f9..273e534 100644 --- a/example/mnist_demo/lenet5_mnist_fuzzing.py +++ b/example/mnist_demo/lenet5_mnist_fuzzing.py @@ -87,9 +87,9 @@ def test_lenet_mnist_fuzzing(): LOGGER.info(TAG, 'KMNC of this test is : %s', model_coverage_test.get_kmnc()) - model_fuzz_test = Fuzzer(model, train_images, 1000, 10) + model_fuzz_test = Fuzzer(model, train_images, 10, 1000) _, _, _, _, metrics = model_fuzz_test.fuzzing(mutate_config, initial_seeds, - eval_metric=True) + eval_metrics='auto') if metrics: for key in metrics: LOGGER.info(TAG, key + ': %s', metrics[key]) @@ -97,5 +97,5 @@ def test_lenet_mnist_fuzzing(): if __name__ == '__main__': # device_target can be "CPU", "GPU" or "Ascend" - context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") test_lenet_mnist_fuzzing() diff --git a/mindarmour/fuzzing/fuzzing.py b/mindarmour/fuzzing/fuzzing.py index 4c5a951..d33a0ac 100644 --- a/mindarmour/fuzzing/fuzzing.py +++ b/mindarmour/fuzzing/fuzzing.py @@ -27,6 +27,48 @@ from mindarmour.fuzzing.image_transform import Contrast, Brightness, Blur, \ Noise, Translate, Scale, Shear, Rotate from mindarmour.attacks import FastGradientSignMethod, \ MomentumDiverseInputIterativeMethod, ProjectedGradientDescent +from mindarmour.utils.logger import LogUtil + +LOGGER = LogUtil.get_instance() +TAG = 'Fuzzer' + + +def _select_next(initial_seeds): + """ Randomly select a seed from `initial_seeds`.""" + seed_num = choice(range(len(initial_seeds))) + seed = initial_seeds[seed_num] + del initial_seeds[seed_num] + return seed, initial_seeds + + +def _coverage_gains(coverages): + """ Calculate the coverage gains of mutated samples. """ + gains = [0] + coverages[:-1] + gains = np.array(coverages) - np.array(gains) + return gains + + +def _is_trans_valid(seed, mutate_sample): + """ Check a mutated sample is valid. If the number of changed pixels in + a seed is less than pixels_change_rate*size(seed), this mutate is valid. + Else check the infinite norm of seed changes, if the value of the + infinite norm less than pixel_value_change_rate*255, this mutate is + valid too. Otherwise the opposite. + """ + is_valid = False + pixels_change_rate = 0.02 + pixel_value_change_rate = 0.2 + diff = np.array(seed - mutate_sample).flatten() + size = np.shape(diff)[0] + l0_norm = np.linalg.norm(diff, ord=0) + linf = np.linalg.norm(diff, ord=np.inf) + if l0_norm > pixels_change_rate*size: + if linf < 256: + is_valid = True + else: + if linf < pixel_value_change_rate*255: + is_valid = True + return is_valid class Fuzzer: @@ -40,71 +82,203 @@ class Fuzzer: target_model (Model): Target fuzz model. train_dataset (numpy.ndarray): Training dataset used for determining the neurons' output boundaries. - segmented_num (int): The number of segmented sections of neurons' - output intervals. neuron_num (int): The number of testing neurons. + segmented_num (int): The number of segmented sections of neurons' + output intervals. Default: 1000. + + Examples: + >>> net = Net() + >>> mutate_config = [{'method': 'Blur', 'params': {'auto_param': True}}, + >>> {'method': 'Contrast','params': {'factor': 2}}, + >>> {'method': 'Translate', 'params': {'x_bias': 0.1, 'y_bias': 0.2}}, + >>> {'method': 'FGSM', 'params': {'eps': 0.1, 'alpha': 0.1}}] + >>> train_images = np.random.rand(32, 1, 32, 32).astype(np.float32) + >>> model_fuzz_test = Fuzzer(model, train_images, 1000, 10) + >>> samples, labels, preds, strategies, report = model_fuzz_test.fuzzing(mutate_config, initial_seeds) """ - def __init__(self, target_model, train_dataset, segmented_num, neuron_num): - self.target_model = check_model('model', target_model, Model) - self.train_dataset = check_numpy_param('train_dataset', train_dataset) - self.coverage_metrics = ModelCoverageMetrics(target_model, - segmented_num, - neuron_num, train_dataset) + def __init__(self, target_model, train_dataset, neuron_num, segmented_num=1000): + self._target_model = check_model('model', target_model, Model) + train_dataset = check_numpy_param('train_dataset', train_dataset) + self._coverage_metrics = ModelCoverageMetrics(target_model, + segmented_num, + neuron_num, train_dataset) # Allowed mutate strategies so far. - self.strategies = {'Contrast': Contrast, 'Brightness': Brightness, - 'Blur': Blur, 'Noise': Noise, 'Translate': Translate, - 'Scale': Scale, 'Shear': Shear, 'Rotate': Rotate, - 'FGSM': FastGradientSignMethod, - 'PGD': ProjectedGradientDescent, - 'MDIIM': MomentumDiverseInputIterativeMethod} - self.affine_trans_list = ['Translate', 'Scale', 'Shear', 'Rotate'] - self.pixel_value_trans_list = ['Contrast', 'Brightness', 'Blur', - 'Noise'] - self.attacks_list = ['FGSM', 'PGD', 'MDIIM'] - self.attack_param_checklists = { + self._strategies = {'Contrast': Contrast, 'Brightness': Brightness, + 'Blur': Blur, 'Noise': Noise, 'Translate': Translate, + 'Scale': Scale, 'Shear': Shear, 'Rotate': Rotate, + 'FGSM': FastGradientSignMethod, + 'PGD': ProjectedGradientDescent, + 'MDIIM': MomentumDiverseInputIterativeMethod} + self._affine_trans_list = ['Translate', 'Scale', 'Shear', 'Rotate'] + self._pixel_value_trans_list = ['Contrast', 'Brightness', 'Blur', + 'Noise'] + self._attacks_list = ['FGSM', 'PGD', 'MDIIM'] + self._attack_param_checklists = { 'FGSM': {'params': {'eps': {'dtype': [float, int], 'range': [0, 1]}, 'alpha': {'dtype': [float, int], 'range': [0, 1]}, 'bounds': {'dtype': [list, tuple], - 'range': None}, - }}, + 'range': None}}}, 'PGD': {'params': {'eps': {'dtype': [float, int], 'range': [0, 1]}, 'eps_iter': {'dtype': [float, int], 'range': [0, 1e5]}, 'nb_iter': {'dtype': [float, int], 'range': [0, 1e5]}, 'bounds': {'dtype': [list, tuple], - 'range': None}, - }}, + 'range': None}}}, 'MDIIM': { 'params': {'eps': {'dtype': [float, int], 'range': [0, 1]}, 'norm_level': {'dtype': [str], 'range': None}, 'prob': {'dtype': [float, int], 'range': [0, 1]}, - 'bounds': {'dtype': [list, tuple], 'range': None}, - }}} + 'bounds': {'dtype': [list, tuple], 'range': None}}}} + + def fuzzing(self, mutate_config, initial_seeds, coverage_metric='KMNC', + eval_metrics='auto', max_iters=10000, mutate_num_per_seed=20): + """ + Fuzzing tests for deep neural networks. + + Args: + mutate_config (list): Mutate configs. The format is + [{'method': 'Blur', 'params': {'auto_param': True}}, {'method': 'Contrast', 'params': {'factor': 2}}]. + The support methods list is in `self._strategies`, and the params of each + method must within the range of changeable parameters. + initial_seeds (numpy.ndarray): Initial seeds used to generate + mutated samples. + coverage_metric (str): Model coverage metric of neural networks. + Default: 'KMNC'. + eval_metrics (Union[list, tuple, str]): Evaluation metrics. If the type is 'auto', + it will calculate all the metrics, else if the type is list or tuple, it will + calculate the metrics specified by user. Default: 'auto'. + max_iters (int): Max number of select a seed to mutate. + Default: 10000. + mutate_num_per_seed (int): The number of mutate times for a seed. + Default: 20. + + Returns: + - list, mutated samples in fuzzing. + + - list, ground truth labels of mutated samples. + + - list, preds of mutated samples. + + - list, strategies of mutated samples. + + - dict, metrics report of fuzzer. + + Raises: + TypeError: If the type of `eval_metrics` is not str, list or tuple. + TypeError: If the type of metric in list `eval_metrics` is not str. + ValueError: If `eval_metrics` is not equal to 'auto' when it's type is str. + ValueError: If metric in list `eval_metrics` is not in ['accuracy', 'attack_success_rate', + 'kmnc', 'nbc', 'snac']. + """ + eval_metrics_ = None + if isinstance(eval_metrics, (list, tuple)): + eval_metrics_ = [] + avaliable_metrics = ['accuracy', 'attack_success_rate', 'kmnc', 'nbc', 'snac'] + for elem in eval_metrics: + if not isinstance(elem, str): + msg = 'the type of metric in list `eval_metrics` must be str, but got {}.' \ + .format(type(elem)) + LOGGER.error(TAG, msg) + raise TypeError(msg) + if elem not in avaliable_metrics: + msg = 'metric in list `eval_metrics` must be in {}, but got {}.' \ + .format(avaliable_metrics, elem) + LOGGER.error(TAG, msg) + raise ValueError(msg) + eval_metrics_.append(elem.lower()) + elif isinstance(eval_metrics, str): + if eval_metrics != 'auto': + msg = "the value of `eval_metrics` must be 'auto' if it's type is str, " \ + "but got {}.".format(eval_metrics) + LOGGER.error(TAG, msg) + raise ValueError(msg) + eval_metrics_ = 'auto' + else: + msg = "the type of `eval_metrics` must be str, list or tuple, but got {}." \ + .format(type(eval_metrics)) + LOGGER.error(TAG, msg) + raise TypeError(msg) + + # Check whether the mutate_config meet the specification. + mutates = self._init_mutates(mutate_config) + seed, initial_seeds = _select_next(initial_seeds) + fuzz_samples = [] + gt_labels = [] + fuzz_preds = [] + fuzz_strategies = [] + iter_num = 0 + while initial_seeds and iter_num < max_iters: + # Mutate a seed. + mutate_samples, mutate_strategies = self._metamorphic_mutate(seed, + mutates, + mutate_config, + mutate_num_per_seed) + # Calculate the coverages and predictions of generated samples. + coverages, predicts = self._run(mutate_samples, coverage_metric) + coverage_gains = _coverage_gains(coverages) + for mutate, cov, pred, strategy in zip(mutate_samples, + coverage_gains, + predicts, mutate_strategies): + fuzz_samples.append(mutate[0]) + gt_labels.append(mutate[1]) + fuzz_preds.append(pred) + fuzz_strategies.append(strategy) + # if the mutate samples has coverage gains add this samples in + # the initial seeds to guide new mutates. + if cov > 0: + initial_seeds.append(mutate) + seed, initial_seeds = _select_next(initial_seeds) + iter_num += 1 + metrics_report = None + if eval_metrics_ is not None: + metrics_report = self._evaluate(fuzz_samples, gt_labels, fuzz_preds, + fuzz_strategies, eval_metrics_) + return fuzz_samples, gt_labels, fuzz_preds, fuzz_strategies, metrics_report + + def _run(self, mutate_samples, coverage_metric="KNMC"): + """ Calculate the coverages and predictions of generated samples.""" + samples = [s[0] for s in mutate_samples] + samples = np.array(samples) + coverages = [] + predictions = self._target_model.predict(Tensor(samples.astype(np.float32))) + predictions = predictions.asnumpy() + for index in range(len(samples)): + mutate = samples[:index + 1] + self._coverage_metrics.calculate_coverage(mutate.astype(np.float32)) + if coverage_metric == "KMNC": + coverages.append(self._coverage_metrics.get_kmnc()) + if coverage_metric == 'NBC': + coverages.append(self._coverage_metrics.get_nbc()) + if coverage_metric == 'SNAC': + coverages.append(self._coverage_metrics.get_snac()) + return coverages, predictions def _check_attack_params(self, method, params): """Check input parameters of attack methods.""" - allow_params = self.attack_param_checklists[method]['params'].keys() - for p in params: - if p not in allow_params: + allow_params = self._attack_param_checklists[method]['params'].keys() + for param_name in params: + if param_name not in allow_params: msg = "parameters of {} must in {}".format(method, allow_params) raise ValueError(msg) - if p == 'bounds': - bounds = check_param_multi_types('bounds', params[p], + + param_value = params[param_name] + if param_name == 'bounds': + bounds = check_param_multi_types('bounds', param_value, [list, tuple]) - for b in bounds: - _ = check_param_multi_types('bound', b, [int, float]) - elif p == 'norm_level': - _ = check_norm_level(params[p]) + for bound_value in bounds: + _ = check_param_multi_types('bound', bound_value, [int, float]) + elif param_name == 'norm_level': + _ = check_norm_level(param_value) else: - allow_type = self.attack_param_checklists[method]['params'][p][ + allow_type = self._attack_param_checklists[method]['params'][param_name][ 'dtype'] - allow_range = self.attack_param_checklists[method]['params'][p][ + allow_range = self._attack_param_checklists[method]['params'][param_name][ 'range'] - _ = check_param_multi_types(str(p), params[p], allow_type) - _ = check_param_in_range(str(p), params[p], allow_range[0], + _ = check_param_multi_types(str(param_name), param_value, allow_type) + _ = check_param_in_range(str(param_name), param_value, allow_range[0], allow_range[1]) def _metamorphic_mutate(self, seed, mutates, mutate_config, @@ -117,23 +291,23 @@ class Fuzzer: strage = choice(mutate_config) # Choose a pixel value based transform method if only_pixel_trans: - while strage['method'] not in self.pixel_value_trans_list: + while strage['method'] not in self._pixel_value_trans_list: strage = choice(mutate_config) transform = mutates[strage['method']] params = strage['params'] method = strage['method'] - if method in list(self.pixel_value_trans_list + self.affine_trans_list): + if method in list(self._pixel_value_trans_list + self._affine_trans_list): transform.set_params(**params) mutate_sample = transform.transform(seed[0]) else: - for p in params: - transform.__setattr__('_'+str(p), params[p]) + for param_name in params: + transform.__setattr__('_' + str(param_name), params[param_name]) mutate_sample = transform.generate([seed[0].astype(np.float32)], [seed[1]])[0] - if method not in self.pixel_value_trans_list: + if method not in self._pixel_value_trans_list: only_pixel_trans = 1 mutate_sample = [mutate_sample, seed[1], only_pixel_trans] - if self._is_trans_valid(seed[0], mutate_sample[0]): + if _is_trans_valid(seed[0], mutate_sample[0]): mutate_samples.append(mutate_sample) mutate_strategies.append(method) if not mutate_samples: @@ -145,29 +319,29 @@ class Fuzzer: """ Check whether the mutate_config meet the specification.""" has_pixel_trans = False for mutate in mutate_config: - if mutate['method'] in self.pixel_value_trans_list: + if mutate['method'] in self._pixel_value_trans_list: has_pixel_trans = True break if not has_pixel_trans: msg = "mutate methods in mutate_config at lease have one in {}".format( - self.pixel_value_trans_list) + self._pixel_value_trans_list) raise ValueError(msg) mutates = {} for mutate in mutate_config: method = mutate['method'] params = mutate['params'] - if method not in self.attacks_list: - mutates[method] = self.strategies[method]() + if method not in self._attacks_list: + mutates[method] = self._strategies[method]() else: self._check_attack_params(method, params) - network = self.target_model._network - loss_fn = self.target_model._loss_fn - mutates[method] = self.strategies[method](network, - loss_fn=loss_fn) + network = self._target_model._network + loss_fn = self._target_model._loss_fn + mutates[method] = self._strategies[method](network, + loss_fn=loss_fn) return mutates - def evaluate(self, fuzz_samples, gt_labels, fuzz_preds, - fuzz_strategies): + def _evaluate(self, fuzz_samples, gt_labels, fuzz_preds, + fuzz_strategies, metrics): """ Evaluate generated fuzzing samples in three dimention: accuracy, attack success rate and neural coverage. @@ -177,147 +351,40 @@ class Fuzzer: gt_labels (numpy.ndarray): Ground Truth of seeds. fuzz_preds (numpy.ndarray): Predictions of generated fuzz samples. fuzz_strategies (numpy.ndarray): Mutate strategies of fuzz samples. + metrics (Union[list, tuple, str]): evaluation metrics. Returns: dict, evaluate metrics include accuarcy, attack success rate and neural coverage. """ - - gt_labels = np.asarray(gt_labels) - fuzz_preds = np.asarray(fuzz_preds) temp = np.argmax(gt_labels, axis=1) == np.argmax(fuzz_preds, axis=1) - acc = np.sum(temp) / np.size(temp) - - cond = [elem in self.attacks_list for elem in fuzz_strategies] - temp = temp[cond] - attack_success_rate = 1 - np.sum(temp) / np.size(temp) - - self.coverage_metrics.calculate_coverage( - np.array(fuzz_samples).astype(np.float32)) - kmnc = self.coverage_metrics.get_kmnc() - nbc = self.coverage_metrics.get_nbc() - snac = self.coverage_metrics.get_snac() - - metrics = {} - metrics['Accuracy'] = acc - metrics['Attack_succrss_rate'] = attack_success_rate - metrics['Neural_coverage_KMNC'] = kmnc - metrics['Neural_coverage_NBC'] = nbc - metrics['Neural_coverage_SNAC'] = snac - return metrics + metrics_report = {} + if metrics == 'auto' or 'accuracy' in metrics: + gt_labels = np.asarray(gt_labels) + fuzz_preds = np.asarray(fuzz_preds) + acc = np.sum(temp) / np.size(temp) + metrics_report['Accuracy'] = acc - def fuzzing(self, mutate_config, initial_seeds, coverage_metric='KMNC', - eval_metric=True, max_iters=10000, mutate_num_per_seed=20): - """ - Fuzzing tests for deep neural networks. + if metrics == 'auto' or 'attack_success_rate' in metrics: + cond = [elem in self._attacks_list for elem in fuzz_strategies] + temp = temp[cond] + attack_success_rate = 1 - np.sum(temp) / np.size(temp) + metrics_report['Attack_success_rate'] = attack_success_rate - Args: - mutate_config (list): Mutate configs. The format is - [{'method': 'Blur', - 'params': {'auto_param': True}}, - {'method': 'Contrast', - 'params': {'factor': 2}}, - ...]. The support methods list is in `self.strategies`, - The params of each method must within the range of changeable - parameters. - initial_seeds (numpy.ndarray): Initial seeds used to generate - mutated samples. - coverage_metric (str): Model coverage metric of neural networks. - Default: 'KMNC'. - eval_metric (bool): Whether to evaluate the generated fuzz samples. - Default: True. - max_iters (int): Max number of select a seed to mutate. - Default: 10000. - mutate_num_per_seed (int): The number of mutate times for a seed. - Default: 20. + if metrics == 'auto' or 'kmnc' in metrics or 'nbc' in metrics or 'snac' in metrics: + self._coverage_metrics.calculate_coverage( + np.array(fuzz_samples).astype(np.float32)) - Returns: - list, mutated samples. - """ - # Check whether the mutate_config meet the specification. - mutates = self._init_mutates(mutate_config) - seed, initial_seeds = self._select_next(initial_seeds) - fuzz_samples = [] - gt_labels = [] - fuzz_preds = [] - fuzz_strategies = [] - iter_num = 0 - while initial_seeds and iter_num < max_iters: - # Mutate a seed. - mutate_samples, mutate_strategies = self._metamorphic_mutate(seed, - mutates, - mutate_config, - mutate_num_per_seed) - # Calculate the coverages and predictions of generated samples. - coverages, predicts = self._run(mutate_samples, coverage_metric) - coverage_gains = self._coverage_gains(coverages) - for mutate, cov, pred, strategy in zip(mutate_samples, - coverage_gains, - predicts, mutate_strategies): - fuzz_samples.append(mutate[0]) - gt_labels.append(mutate[1]) - fuzz_preds.append(pred) - fuzz_strategies.append(strategy) - # if the mutate samples has coverage gains add this samples in - # the initial seeds to guide new mutates. - if cov > 0: - initial_seeds.append(mutate) - seed, initial_seeds = self._select_next(initial_seeds) - iter_num += 1 - metrics = None - if eval_metric: - metrics = self.evaluate(fuzz_samples, gt_labels, fuzz_preds, - fuzz_strategies) - return fuzz_samples, gt_labels, fuzz_preds, fuzz_strategies, metrics - - def _coverage_gains(self, coverages): - """ Calculate the coverage gains of mutated samples. """ - gains = [0] + coverages[:-1] - gains = np.array(coverages) - np.array(gains) - return gains + if metrics == 'auto' or 'kmnc' in metrics: + kmnc = self._coverage_metrics.get_kmnc() + metrics_report['Neural_coverage_KMNC'] = kmnc - def _run(self, mutate_samples, coverage_metric="KNMC"): - """ Calculate the coverages and predictions of generated samples.""" - samples = [s[0] for s in mutate_samples] - samples = np.array(samples) - coverages = [] - predictions = self.target_model.predict(Tensor(samples.astype(np.float32))) - predictions = predictions.asnumpy() - for index in range(len(samples)): - mutate = samples[:index + 1] - self.coverage_metrics.calculate_coverage(mutate.astype(np.float32)) - if coverage_metric == "KMNC": - coverages.append(self.coverage_metrics.get_kmnc()) - if coverage_metric == 'NBC': - coverages.append(self.coverage_metrics.get_nbc()) - if coverage_metric == 'SNAC': - coverages.append(self.coverage_metrics.get_snac()) - return coverages, predictions + if metrics == 'auto' or 'nbc' in metrics: + nbc = self._coverage_metrics.get_nbc() + metrics_report['Neural_coverage_NBC'] = nbc - def _select_next(self, initial_seeds): - """Randomly select a seed from `initial_seeds`.""" - seed_num = choice(range(len(initial_seeds))) - seed = initial_seeds[seed_num] - del initial_seeds[seed_num] - return seed, initial_seeds - - def _is_trans_valid(self, seed, mutate_sample): - """ Check a mutated sample is valid. If the number of changed pixels in - a seed is less than pixels_change_rate*size(seed), this mutate is valid. - Else check the infinite norm of seed changes, if the value of the - infinite norm less than pixel_value_change_rate*255, this mutate is - valid too. Otherwise the opposite.""" - is_valid = False - pixels_change_rate = 0.02 - pixel_value_change_rate = 0.2 - diff = np.array(seed - mutate_sample).flatten() - size = np.shape(diff)[0] - l0 = np.linalg.norm(diff, ord=0) - linf = np.linalg.norm(diff, ord=np.inf) - if l0 > pixels_change_rate*size: - if linf < 256: - is_valid = True - else: - if linf < pixel_value_change_rate*255: - is_valid = True - return is_valid + if metrics == 'auto' or 'snac' in metrics: + snac = self._coverage_metrics.get_snac() + metrics_report['Neural_coverage_SNAC'] = snac + + return metrics_report diff --git a/tests/ut/python/fuzzing/test_fuzzer.py b/tests/ut/python/fuzzing/test_fuzzer.py index 92e6f70..60af20d 100644 --- a/tests/ut/python/fuzzing/test_fuzzer.py +++ b/tests/ut/python/fuzzing/test_fuzzer.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """ -Model-fuzz coverage test. +Model-fuzzer test. """ import numpy as np import pytest @@ -121,7 +121,7 @@ def test_fuzzing_ascend(): LOGGER.info(TAG, 'KMNC of this test is : %s', model_coverage_test.get_kmnc()) - model_fuzz_test = Fuzzer(model, train_images, 1000, 10) + model_fuzz_test = Fuzzer(model, train_images, 10, 1000) _, _, _, _, metrics = model_fuzz_test.fuzzing(mutate_config, initial_seeds) print(metrics) @@ -167,6 +167,6 @@ def test_fuzzing_cpu(): LOGGER.info(TAG, 'KMNC of this test is : %s', model_coverage_test.get_kmnc()) - model_fuzz_test = Fuzzer(model, train_images, 1000, 10) + model_fuzz_test = Fuzzer(model, train_images, 10, 1000) _, _, _, _, metrics = model_fuzz_test.fuzzing(mutate_config, initial_seeds) print(metrics) -- GitLab