提交 94ff3ad5 编写于 作者: Z ZhidanLiu

Reconstruct Fuzzer

上级 db93de3e
......@@ -19,7 +19,7 @@ from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from lenet5_net import LeNet5
from mindarmour.fuzzing.fuzzing import Fuzzing
from mindarmour.fuzzing.fuzzing import Fuzzer
from mindarmour.fuzzing.model_coverage_metrics import ModelCoverageMetrics
from mindarmour.utils.logger import LogUtil
......@@ -38,11 +38,20 @@ def test_lenet_mnist_fuzzing():
load_dict = load_checkpoint(ckpt_name)
load_param_into_net(net, load_dict)
model = Model(net)
mutate_config = [{'method': 'Blur',
'params': {'auto_param': True}},
{'method': 'Contrast',
'params': {'factor': 2}},
{'method': 'Translate',
'params': {'x_bias': 0.1, 'y_bias': 0.2}},
{'method': 'FGSM',
'params': {'eps': 0.1, 'alpha': 0.1}}
]
# get training data
data_list = "./MNIST_unzip/train"
batch_size = 32
ds = generate_mnist_dataset(data_list, batch_size, sparse=True)
ds = generate_mnist_dataset(data_list, batch_size, sparse=False)
train_images = []
for data in ds.create_tuple_iterator():
images = data[0].astype(np.float32)
......@@ -56,7 +65,7 @@ def test_lenet_mnist_fuzzing():
# get test data
data_list = "./MNIST_unzip/test"
batch_size = 32
ds = generate_mnist_dataset(data_list, batch_size, sparse=True)
ds = generate_mnist_dataset(data_list, batch_size, sparse=False)
test_images = []
test_labels = []
for data in ds.create_tuple_iterator():
......@@ -70,19 +79,20 @@ def test_lenet_mnist_fuzzing():
# make initial seeds
for img, label in zip(test_images, test_labels):
initial_seeds.append([img, label])
initial_seeds.append([img, label, 0])
initial_seeds = initial_seeds[:100]
model_coverage_test.test_adequacy_coverage_calculate(np.array(test_images[:100]).astype(np.float32))
LOGGER.info(TAG, 'KMNC of this test is : %s', model_coverage_test.get_kmnc())
model_coverage_test.calculate_coverage(
np.array(test_images[:100]).astype(np.float32))
LOGGER.info(TAG, 'KMNC of this test is : %s',
model_coverage_test.get_kmnc())
model_fuzz_test = Fuzzing(initial_seeds, model, train_images, 20)
failed_tests = model_fuzz_test.fuzzing()
if failed_tests:
model_coverage_test.test_adequacy_coverage_calculate(np.array(failed_tests).astype(np.float32))
LOGGER.info(TAG, 'KMNC of this test is : %s', model_coverage_test.get_kmnc())
else:
LOGGER.info(TAG, 'Fuzzing test identifies none failed test')
model_fuzz_test = Fuzzer(model, train_images, 1000, 10)
_, _, _, _, metrics = model_fuzz_test.fuzzing(mutate_config, initial_seeds,
eval_metric=True)
if metrics:
for key in metrics:
LOGGER.info(TAG, key + ': %s', metrics[key])
if __name__ == '__main__':
......
......@@ -227,8 +227,8 @@ class BasicIterativeMethod(IterativeGradientMethod):
clip_min, clip_max = self._bounds
clip_diff = clip_max - clip_min
for _ in range(self._nb_iter):
if 'self.prob' in globals():
d_inputs = _transform_inputs(inputs, self.prob)
if 'self._prob' in globals():
d_inputs = _transform_inputs(inputs, self._prob)
else:
d_inputs = inputs
adv_x = self._attack.generate(d_inputs, labels)
......@@ -238,8 +238,8 @@ class BasicIterativeMethod(IterativeGradientMethod):
inputs = adv_x
else:
for _ in range(self._nb_iter):
if 'self.prob' in globals():
d_inputs = _transform_inputs(inputs, self.prob)
if 'self._prob' in globals():
d_inputs = _transform_inputs(inputs, self._prob)
else:
d_inputs = inputs
adv_x = self._attack.generate(d_inputs, labels)
......@@ -311,8 +311,8 @@ class MomentumIterativeMethod(IterativeGradientMethod):
clip_min, clip_max = self._bounds
clip_diff = clip_max - clip_min
for _ in range(self._nb_iter):
if 'self.prob' in globals():
d_inputs = _transform_inputs(inputs, self.prob)
if 'self._prob' in globals():
d_inputs = _transform_inputs(inputs, self._prob)
else:
d_inputs = inputs
gradient = self._gradient(d_inputs, labels)
......@@ -325,8 +325,8 @@ class MomentumIterativeMethod(IterativeGradientMethod):
inputs = adv_x
else:
for _ in range(self._nb_iter):
if 'self.prob' in globals():
d_inputs = _transform_inputs(inputs, self.prob)
if 'self._prob' in globals():
d_inputs = _transform_inputs(inputs, self._prob)
else:
d_inputs = inputs
gradient = self._gradient(d_inputs, labels)
......@@ -476,7 +476,7 @@ class DiverseInputIterativeMethod(BasicIterativeMethod):
is_targeted=is_targeted,
nb_iter=nb_iter,
loss_fn=loss_fn)
self.prob = check_param_type('prob', prob, float)
self._prob = check_param_type('prob', prob, float)
class MomentumDiverseInputIterativeMethod(MomentumIterativeMethod):
......@@ -511,7 +511,7 @@ class MomentumDiverseInputIterativeMethod(MomentumIterativeMethod):
is_targeted=is_targeted,
norm_level=norm_level,
loss_fn=loss_fn)
self.prob = check_param_type('prob', prob, float)
self._prob = check_param_type('prob', prob, float)
def _transform_inputs(inputs, prob, low=29, high=33, full_aug=False):
......
......@@ -22,9 +22,11 @@ from mindspore import Tensor
from mindarmour.fuzzing.model_coverage_metrics import ModelCoverageMetrics
from mindarmour.utils._check_param import check_model, check_numpy_param, \
check_int_positive
from mindarmour.fuzzing.image_transform import Contrast, Brightness, Blur, Noise, \
Translate, Scale, Shear, Rotate
check_param_multi_types, check_norm_level, check_param_in_range
from mindarmour.fuzzing.image_transform import Contrast, Brightness, Blur, \
Noise, Translate, Scale, Shear, Rotate
from mindarmour.attacks import FastGradientSignMethod, \
MomentumDiverseInputIterativeMethod, ProjectedGradientDescent
class Fuzzer:
......@@ -35,129 +37,280 @@ class Fuzzer:
Neural Networks <https://dl.acm.org/doi/10.1145/3293882.3330579>`_
Args:
initial_seeds (list): Initial fuzzing seed, format: [[image, label],
[image, label], ...].
target_model (Model): Target fuzz model.
train_dataset (numpy.ndarray): Training dataset used for determining
the neurons' output boundaries.
const_k (int): The number of mutate tests for a seed.
mode (str): Image mode used in image transform, 'L' means grey graph.
Default: 'L'.
max_seed_num (int): The initial seeds max value. Default: 1000
segmented_num (int): The number of segmented sections of neurons'
output intervals.
neuron_num (int): The number of testing neurons.
"""
def __init__(self, initial_seeds, target_model, train_dataset, const_K,
mode='L', max_seed_num=1000):
self.initial_seeds = initial_seeds
def __init__(self, target_model, train_dataset, segmented_num, neuron_num):
self.target_model = check_model('model', target_model, Model)
self.train_dataset = check_numpy_param('train_dataset', train_dataset)
self.const_k = check_int_positive('const_k', const_K)
self.mode = mode
self.max_seed_num = check_int_positive('max_seed_num', max_seed_num)
self.coverage_metrics = ModelCoverageMetrics(target_model, 1000, 10,
train_dataset)
def _image_value_expand(self, image):
return image*255
def _image_value_compress(self, image):
return image / 255
def _metamorphic_mutate(self, seed, try_num=50):
if self.mode == 'L':
seed = seed[0]
info = [seed, seed]
mutate_tests = []
pixel_value_trans = ['Contrast', 'Brightness', 'Blur', 'Noise']
affine_trans = ['Translate', 'Scale', 'Shear', 'Rotate']
strages = {'Contrast': Contrast, 'Brightness': Brightness, 'Blur': Blur,
'Noise': Noise,
'Translate': Translate, 'Scale': Scale, 'Shear': Shear,
'Rotate': Rotate}
for _ in range(self.const_k):
for _ in range(try_num):
if (info[0] == info[1]).all():
trans_strage = self._random_pick_mutate(affine_trans,
pixel_value_trans)
else:
trans_strage = self._random_pick_mutate(pixel_value_trans,
[])
transform = strages[trans_strage](
self._image_value_expand(seed), self.mode)
transform.set_params(auto_param=True)
mutate_test = transform.transform()
mutate_test = np.expand_dims(
self._image_value_compress(mutate_test), 0)
if self._is_trans_valid(seed, mutate_test):
if trans_strage in affine_trans:
info[1] = mutate_test
mutate_tests.append(mutate_test)
if not mutate_tests:
mutate_tests.append(seed)
return np.array(mutate_tests)
def fuzzing(self, coverage_metric='KMNC'):
self.coverage_metrics = ModelCoverageMetrics(target_model,
segmented_num,
neuron_num, train_dataset)
# Allowed mutate strategies so far.
self.strategies = {'Contrast': Contrast, 'Brightness': Brightness,
'Blur': Blur, 'Noise': Noise, 'Translate': Translate,
'Scale': Scale, 'Shear': Shear, 'Rotate': Rotate,
'FGSM': FastGradientSignMethod,
'PGD': ProjectedGradientDescent,
'MDIIM': MomentumDiverseInputIterativeMethod}
self.affine_trans_list = ['Translate', 'Scale', 'Shear', 'Rotate']
self.pixel_value_trans_list = ['Contrast', 'Brightness', 'Blur',
'Noise']
self.attacks_list = ['FGSM', 'PGD', 'MDIIM']
self.attack_param_checklists = {
'FGSM': {'params': {'eps': {'dtype': [float, int], 'range': [0, 1]},
'alpha': {'dtype': [float, int],
'range': [0, 1]},
'bounds': {'dtype': [list, tuple],
'range': None},
}},
'PGD': {'params': {'eps': {'dtype': [float, int], 'range': [0, 1]},
'eps_iter': {'dtype': [float, int],
'range': [0, 1e5]},
'nb_iter': {'dtype': [float, int],
'range': [0, 1e5]},
'bounds': {'dtype': [list, tuple],
'range': None},
}},
'MDIIM': {
'params': {'eps': {'dtype': [float, int], 'range': [0, 1]},
'norm_level': {'dtype': [str], 'range': None},
'prob': {'dtype': [float, int], 'range': [0, 1]},
'bounds': {'dtype': [list, tuple], 'range': None},
}}}
def _check_attack_params(self, method, params):
"""Check input parameters of attack methods."""
allow_params = self.attack_param_checklists[method]['params'].keys()
for p in params:
if p not in allow_params:
msg = "parameters of {} must in {}".format(method, allow_params)
raise ValueError(msg)
if p == 'bounds':
bounds = check_param_multi_types('bounds', params[p],
[list, tuple])
for b in bounds:
_ = check_param_multi_types('bound', b, [int, float])
elif p == 'norm_level':
_ = check_norm_level(params[p])
else:
allow_type = self.attack_param_checklists[method]['params'][p][
'dtype']
allow_range = self.attack_param_checklists[method]['params'][p][
'range']
_ = check_param_multi_types(str(p), params[p], allow_type)
_ = check_param_in_range(str(p), params[p], allow_range[0],
allow_range[1])
def _metamorphic_mutate(self, seed, mutates, mutate_config,
mutate_num_per_seed):
"""Mutate a seed using strategies random selected from mutate_config."""
mutate_samples = []
mutate_strategies = []
only_pixel_trans = seed[2]
for _ in range(mutate_num_per_seed):
strage = choice(mutate_config)
# Choose a pixel value based transform method
if only_pixel_trans:
while strage['method'] not in self.pixel_value_trans_list:
strage = choice(mutate_config)
transform = mutates[strage['method']]
params = strage['params']
method = strage['method']
if method in list(self.pixel_value_trans_list + self.affine_trans_list):
transform.set_params(**params)
mutate_sample = transform.transform(seed[0])
else:
for p in params:
transform.__setattr__('_'+str(p), params[p])
mutate_sample = transform.generate([seed[0].astype(np.float32)],
[seed[1]])[0]
if method not in self.pixel_value_trans_list:
only_pixel_trans = 1
mutate_sample = [mutate_sample, seed[1], only_pixel_trans]
if self._is_trans_valid(seed[0], mutate_sample[0]):
mutate_samples.append(mutate_sample)
mutate_strategies.append(method)
if not mutate_samples:
mutate_samples.append(seed)
mutate_strategies.append(None)
return np.array(mutate_samples), mutate_strategies
def _init_mutates(self, mutate_config):
""" Check whether the mutate_config meet the specification."""
has_pixel_trans = False
for mutate in mutate_config:
if mutate['method'] in self.pixel_value_trans_list:
has_pixel_trans = True
break
if not has_pixel_trans:
msg = "mutate methods in mutate_config at lease have one in {}".format(
self.pixel_value_trans_list)
raise ValueError(msg)
mutates = {}
for mutate in mutate_config:
method = mutate['method']
params = mutate['params']
if method not in self.attacks_list:
mutates[method] = self.strategies[method]()
else:
self._check_attack_params(method, params)
network = self.target_model._network
loss_fn = self.target_model._loss_fn
mutates[method] = self.strategies[method](network,
loss_fn=loss_fn)
return mutates
def evaluate(self, fuzz_samples, gt_labels, fuzz_preds,
fuzz_strategies):
"""
Evaluate generated fuzzing samples in three dimention: accuracy,
attack success rate and neural coverage.
Args:
fuzz_samples (numpy.ndarray): Generated fuzzing samples according to seeds.
gt_labels (numpy.ndarray): Ground Truth of seeds.
fuzz_preds (numpy.ndarray): Predictions of generated fuzz samples.
fuzz_strategies (numpy.ndarray): Mutate strategies of fuzz samples.
Returns:
dict, evaluate metrics include accuarcy, attack success rate
and neural coverage.
"""
gt_labels = np.asarray(gt_labels)
fuzz_preds = np.asarray(fuzz_preds)
temp = np.argmax(gt_labels, axis=1) == np.argmax(fuzz_preds, axis=1)
acc = np.sum(temp) / np.size(temp)
cond = [elem in self.attacks_list for elem in fuzz_strategies]
temp = temp[cond]
attack_success_rate = 1 - np.sum(temp) / np.size(temp)
self.coverage_metrics.calculate_coverage(
np.array(fuzz_samples).astype(np.float32))
kmnc = self.coverage_metrics.get_kmnc()
nbc = self.coverage_metrics.get_nbc()
snac = self.coverage_metrics.get_snac()
metrics = {}
metrics['Accuracy'] = acc
metrics['Attack_succrss_rate'] = attack_success_rate
metrics['Neural_coverage_KMNC'] = kmnc
metrics['Neural_coverage_NBC'] = nbc
metrics['Neural_coverage_SNAC'] = snac
return metrics
def fuzzing(self, mutate_config, initial_seeds, coverage_metric='KMNC',
eval_metric=True, max_iters=10000, mutate_num_per_seed=20):
"""
Fuzzing tests for deep neural networks.
Args:
mutate_config (list): Mutate configs. The format is
[{'method': 'Blur',
'params': {'auto_param': True}},
{'method': 'Contrast',
'params': {'factor': 2}},
...]. The support methods list is in `self.strategies`,
The params of each method must within the range of changeable
parameters.
initial_seeds (numpy.ndarray): Initial seeds used to generate
mutated samples.
coverage_metric (str): Model coverage metric of neural networks.
Default: 'KMNC'.
eval_metric (bool): Whether to evaluate the generated fuzz samples.
Default: True.
max_iters (int): Max number of select a seed to mutate.
Default: 10000.
mutate_num_per_seed (int): The number of mutate times for a seed.
Default: 20.
Returns:
list, mutated tests mis-predicted by target DNN model.
list, mutated samples.
"""
seed = self._select_next()
failed_tests = []
seed_num = 0
while seed and seed_num < self.max_seed_num:
mutate_tests = self._metamorphic_mutate(seed[0])
coverages, predicts = self._run(mutate_tests, coverage_metric)
# Check whether the mutate_config meet the specification.
mutates = self._init_mutates(mutate_config)
seed, initial_seeds = self._select_next(initial_seeds)
fuzz_samples = []
gt_labels = []
fuzz_preds = []
fuzz_strategies = []
iter_num = 0
while initial_seeds and iter_num < max_iters:
# Mutate a seed.
mutate_samples, mutate_strategies = self._metamorphic_mutate(seed,
mutates,
mutate_config,
mutate_num_per_seed)
# Calculate the coverages and predictions of generated samples.
coverages, predicts = self._run(mutate_samples, coverage_metric)
coverage_gains = self._coverage_gains(coverages)
for mutate, cov, res in zip(mutate_tests, coverage_gains, predicts):
if np.argmax(seed[1]) != np.argmax(res):
failed_tests.append(mutate)
continue
for mutate, cov, pred, strategy in zip(mutate_samples,
coverage_gains,
predicts, mutate_strategies):
fuzz_samples.append(mutate[0])
gt_labels.append(mutate[1])
fuzz_preds.append(pred)
fuzz_strategies.append(strategy)
# if the mutate samples has coverage gains add this samples in
# the initial seeds to guide new mutates.
if cov > 0:
self.initial_seeds.append([mutate, seed[1]])
seed = self._select_next()
seed_num += 1
return failed_tests
initial_seeds.append(mutate)
seed, initial_seeds = self._select_next(initial_seeds)
iter_num += 1
metrics = None
if eval_metric:
metrics = self.evaluate(fuzz_samples, gt_labels, fuzz_preds,
fuzz_strategies)
return fuzz_samples, gt_labels, fuzz_preds, fuzz_strategies, metrics
def _coverage_gains(self, coverages):
""" Calculate the coverage gains of mutated samples. """
gains = [0] + coverages[:-1]
gains = np.array(coverages) - np.array(gains)
return gains
def _run(self, mutate_tests, coverage_metric="KNMC"):
def _run(self, mutate_samples, coverage_metric="KNMC"):
""" Calculate the coverages and predictions of generated samples."""
samples = [s[0] for s in mutate_samples]
samples = np.array(samples)
coverages = []
result = self.target_model.predict(
Tensor(mutate_tests.astype(np.float32)))
result = result.asnumpy()
for index in range(len(mutate_tests)):
mutate = np.expand_dims(mutate_tests[index], 0)
self.coverage_metrics.model_coverage_test(
mutate.astype(np.float32), batch_size=1)
predictions = self.target_model.predict(Tensor(samples.astype(np.float32)))
predictions = predictions.asnumpy()
for index in range(len(samples)):
mutate = samples[:index + 1]
self.coverage_metrics.calculate_coverage(mutate.astype(np.float32))
if coverage_metric == "KMNC":
coverages.append(self.coverage_metrics.get_kmnc())
if coverage_metric == 'NBC':
coverages.append(self.coverage_metrics.get_nbc())
if coverage_metric == 'SNAC':
coverages.append(self.coverage_metrics.get_snac())
return coverages, predictions
return coverages, result
def _select_next(self):
seed = choice(self.initial_seeds)
return seed
def _select_next(self, initial_seeds):
"""Randomly select a seed from `initial_seeds`."""
seed_num = choice(range(len(initial_seeds)))
seed = initial_seeds[seed_num]
del initial_seeds[seed_num]
return seed, initial_seeds
def _random_pick_mutate(self, affine_trans_list, pixel_value_trans_list):
strage = choice(affine_trans_list + pixel_value_trans_list)
return strage
def _is_trans_valid(self, seed, mutate_test):
def _is_trans_valid(self, seed, mutate_sample):
""" Check a mutated sample is valid. If the number of changed pixels in
a seed is less than pixels_change_rate*size(seed), this mutate is valid.
Else check the infinite norm of seed changes, if the value of the
infinite norm less than pixel_value_change_rate*255, this mutate is
valid too. Otherwise the opposite."""
is_valid = False
pixels_change_rate = 0.02
pixel_value_change_rate = 0.2
diff = np.array(seed - mutate_test).flatten()
diff = np.array(seed - mutate_sample).flatten()
size = np.shape(diff)[0]
l0 = np.linalg.norm(diff, ord=0)
linf = np.linalg.norm(diff, ord=np.inf)
......@@ -167,5 +320,4 @@ class Fuzzer:
else:
if linf < pixel_value_change_rate*255:
is_valid = True
return is_valid
......@@ -88,7 +88,8 @@ def is_rgb(img):
Bool, True if input is RGB.
"""
if is_numpy(img):
if len(np.shape(img)) == 3:
img_shape = np.shape(img)
if len(np.shape(img)) == 3 and (img_shape[0] == 3 or img_shape[2] == 3):
return True
return False
raise TypeError('img should be Numpy array. Got {}'.format(type(img)))
......@@ -127,6 +128,7 @@ class ImageTransform:
of the image is not normalized , it will be normalized between 0 to 1."""
rgb = is_rgb(image)
chw = False
gray3dim = False
normalized = is_normalized(image)
if rgb:
chw = is_chw(image)
......@@ -135,12 +137,16 @@ class ImageTransform:
else:
image = image
else:
image = image
if len(np.shape(image)) == 3:
gray3dim = True
image = image[0]
else:
image = image
if normalized:
image = np.uint8(image*255)
return rgb, chw, normalized, image
return rgb, chw, normalized, gray3dim, image
def _original_format(self, image, chw, normalized):
def _original_format(self, image, chw, normalized, gray3dim):
""" Return transformed image with original format. """
if not is_numpy(image):
image = np.array(image)
......@@ -148,6 +154,8 @@ class ImageTransform:
image = hwc_to_chw(image)
if normalized:
image = image / 255
if gray3dim:
image = np.expand_dims(image, 0)
return image
def transform(self, image):
......@@ -191,11 +199,12 @@ class Contrast(ImageTransform):
Returns:
numpy.ndarray, transformed image.
"""
_, chw, normalized, image = self._check(image)
_, chw, normalized, gray3dim, image = self._check(image)
image = to_pil(image)
img_contrast = ImageEnhance.Contrast(image)
trans_image = img_contrast.enhance(self.factor)
trans_image = self._original_format(trans_image, chw, normalized)
trans_image = self._original_format(trans_image, chw, normalized,
gray3dim)
return trans_image
......@@ -237,11 +246,12 @@ class Brightness(ImageTransform):
Returns:
numpy.ndarray, transformed image.
"""
_, chw, normalized, image = self._check(image)
_, chw, normalized, gray3dim, image = self._check(image)
image = to_pil(image)
img_contrast = ImageEnhance.Brightness(image)
trans_image = img_contrast.enhance(self.factor)
trans_image = self._original_format(trans_image, chw, normalized)
trans_image = self._original_format(trans_image, chw, normalized,
gray3dim)
return trans_image
......@@ -280,10 +290,11 @@ class Blur(ImageTransform):
Returns:
numpy.ndarray, transformed image.
"""
_, chw, normalized, image = self._check(image)
_, chw, normalized, gray3dim, image = self._check(image)
image = to_pil(image)
trans_image = image.filter(ImageFilter.GaussianBlur(radius=self.radius))
trans_image = self._original_format(trans_image, chw, normalized)
trans_image = self._original_format(trans_image, chw, normalized,
gray3dim)
return trans_image
......@@ -324,12 +335,13 @@ class Noise(ImageTransform):
Returns:
numpy.ndarray, transformed image.
"""
_, chw, normalized, image = self._check(image)
_, chw, normalized, gray3dim, image = self._check(image)
noise = np.random.uniform(low=-1, high=1, size=np.shape(image))
trans_image = np.copy(image)
trans_image[noise < -self.factor] = 0
trans_image[noise > self.factor] = 1
trans_image = self._original_format(trans_image, chw, normalized)
trans_image = self._original_format(trans_image, chw, normalized,
gray3dim)
return trans_image
......@@ -375,7 +387,7 @@ class Translate(ImageTransform):
Returns:
numpy.ndarray, transformed image.
"""
_, chw, normalized, image = self._check(image)
_, chw, normalized, gray3dim, image = self._check(image)
img = to_pil(image)
if self.auto_param:
image_shape = np.shape(image)
......@@ -383,7 +395,8 @@ class Translate(ImageTransform):
self.y_bias = image_shape[1]*self.y_bias
trans_image = img.transform(img.size, Image.AFFINE,
(1, 0, self.x_bias, 0, 1, self.y_bias))
trans_image = self._original_format(trans_image, chw, normalized)
trans_image = self._original_format(trans_image, chw, normalized,
gray3dim)
return trans_image
......@@ -431,7 +444,7 @@ class Scale(ImageTransform):
Returns:
numpy.ndarray, transformed image.
"""
rgb, chw, normalized, image = self._check(image)
rgb, chw, normalized, gray3dim, image = self._check(image)
if rgb:
h, w, _ = np.shape(image)
else:
......@@ -442,7 +455,8 @@ class Scale(ImageTransform):
trans_image = img.transform(img.size, Image.AFFINE,
(self.factor_x, 0, move_x_centor,
0, self.factor_y, move_y_centor))
trans_image = self._original_format(trans_image, chw, normalized)
trans_image = self._original_format(trans_image, chw, normalized,
gray3dim)
return trans_image
......@@ -500,7 +514,7 @@ class Shear(ImageTransform):
Returns:
numpy.ndarray, transformed image.
"""
rgb, chw, normalized, image = self._check(image)
rgb, chw, normalized, gray3dim, image = self._check(image)
img = to_pil(image)
if rgb:
h, w, _ = np.shape(image)
......@@ -523,7 +537,8 @@ class Shear(ImageTransform):
trans_image = img.transform(img.size, Image.AFFINE,
(scale, scale*self.factor_x, move_x_cen,
scale*self.factor_y, scale, move_y_cen))
trans_image = self._original_format(trans_image, chw, normalized)
trans_image = self._original_format(trans_image, chw, normalized,
gray3dim)
return trans_image
......@@ -562,8 +577,9 @@ class Rotate(ImageTransform):
Returns:
numpy.ndarray, transformed image.
"""
_, chw, normalized, image = self._check(image)
_, chw, normalized, gray3dim, image = self._check(image)
img = to_pil(image)
trans_image = img.rotate(self.angle, expand=True)
trans_image = self._original_format(trans_image, chw, normalized)
trans_image = self._original_format(trans_image, chw, normalized,
gray3dim)
return trans_image
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Model-fuzz coverage test.
"""
import numpy as np
import pytest
from mindspore import context
from mindspore import nn
from mindspore.common.initializer import TruncatedNormal
from mindspore.ops import operations as P
from mindspore.train import Model
from mindarmour.fuzzing.fuzzing import Fuzzer
from mindarmour.fuzzing.model_coverage_metrics import ModelCoverageMetrics
from mindarmour.utils.logger import LogUtil
LOGGER = LogUtil.get_instance()
TAG = 'Fuzzing test'
LOGGER.set_level('INFO')
def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
weight = weight_variable()
return nn.Conv2d(in_channels, out_channels,
kernel_size=kernel_size, stride=stride, padding=padding,
weight_init=weight, has_bias=False, pad_mode="valid")
def fc_with_initialize(input_channels, out_channels):
weight = weight_variable()
bias = weight_variable()
return nn.Dense(input_channels, out_channels, weight, bias)
def weight_variable():
return TruncatedNormal(0.02)
class Net(nn.Cell):
"""
Lenet network
"""
def __init__(self):
super(Net, self).__init__()
self.conv1 = conv(1, 6, 5)
self.conv2 = conv(6, 16, 5)
self.fc1 = fc_with_initialize(16*5*5, 120)
self.fc2 = fc_with_initialize(120, 84)
self.fc3 = fc_with_initialize(84, 10)
self.relu = nn.ReLU()
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
self.reshape = P.Reshape()
def construct(self, x):
x = self.conv1(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.conv2(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.reshape(x, (-1, 16*5*5))
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
x = self.relu(x)
x = self.fc3(x)
return x
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_fuzzing_ascend():
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
# load network
net = Net()
model = Model(net)
batch_size = 8
num_classe = 10
mutate_config = [{'method': 'Blur',
'params': {'auto_param': True}},
{'method': 'Contrast',
'params': {'factor': 2}},
{'method': 'Translate',
'params': {'x_bias': 0.1, 'y_bias': 0.2}},
{'method': 'FGSM',
'params': {'eps': 0.1, 'alpha': 0.1}}
]
# initialize fuzz test with training dataset
train_images = np.random.rand(32, 1, 32, 32).astype(np.float32)
model_coverage_test = ModelCoverageMetrics(model, 1000, 10, train_images)
# fuzz test with original test data
# get test data
test_images = np.random.rand(batch_size, 1, 32, 32).astype(np.float32)
test_labels = np.random.randint(num_classe, size=batch_size).astype(np.int32)
test_labels = (np.eye(num_classe)[test_labels]).astype(np.float32)
initial_seeds = []
# make initial seeds
for img, label in zip(test_images, test_labels):
initial_seeds.append([img, label, 0])
initial_seeds = initial_seeds[:100]
model_coverage_test.calculate_coverage(
np.array(test_images[:100]).astype(np.float32))
LOGGER.info(TAG, 'KMNC of this test is : %s',
model_coverage_test.get_kmnc())
model_fuzz_test = Fuzzer(model, train_images, 1000, 10)
_, _, _, _, metrics = model_fuzz_test.fuzzing(mutate_config, initial_seeds)
print(metrics)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_fuzzing_cpu():
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
# load network
net = Net()
model = Model(net)
batch_size = 8
num_classe = 10
mutate_config = [{'method': 'Blur',
'params': {'auto_param': True}},
{'method': 'Contrast',
'params': {'factor': 2}},
{'method': 'Translate',
'params': {'x_bias': 0.1, 'y_bias': 0.2}},
{'method': 'FGSM',
'params': {'eps': 0.1, 'alpha': 0.1}}
]
# initialize fuzz test with training dataset
train_images = np.random.rand(32, 1, 32, 32).astype(np.float32)
model_coverage_test = ModelCoverageMetrics(model, 1000, 10, train_images)
# fuzz test with original test data
# get test data
test_images = np.random.rand(batch_size, 1, 32, 32).astype(np.float32)
test_labels = np.random.randint(num_classe, size=batch_size).astype(np.int32)
test_labels = (np.eye(num_classe)[test_labels]).astype(np.float32)
initial_seeds = []
# make initial seeds
for img, label in zip(test_images, test_labels):
initial_seeds.append([img, label, 0])
initial_seeds = initial_seeds[:100]
model_coverage_test.calculate_coverage(
np.array(test_images[:100]).astype(np.float32))
LOGGER.info(TAG, 'KMNC of this test is : %s',
model_coverage_test.get_kmnc())
model_fuzz_test = Fuzzer(model, train_images, 1000, 10)
_, _, _, _, metrics = model_fuzz_test.fuzzing(mutate_config, initial_seeds)
print(metrics)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册