提交 e2f4ed2c 编写于 作者: Z ZhidanLiu

fix review bugs in fuzzing and mechanism

上级 aaa9f89f
...@@ -70,7 +70,7 @@ def test_lenet_mnist_fuzzing(): ...@@ -70,7 +70,7 @@ def test_lenet_mnist_fuzzing():
# make initial seeds # make initial seeds
for img, label in zip(test_images, test_labels): for img, label in zip(test_images, test_labels):
initial_seeds.append([img, label, 0]) initial_seeds.append([img, label])
initial_seeds = initial_seeds[:100] initial_seeds = initial_seeds[:100]
model_coverage_test.test_adequacy_coverage_calculate(np.array(test_images[:100]).astype(np.float32)) model_coverage_test.test_adequacy_coverage_calculate(np.array(test_images[:100]).astype(np.float32))
......
...@@ -14,6 +14,8 @@ ...@@ -14,6 +14,8 @@
""" """
Noise Mechanisms. Noise Mechanisms.
""" """
from abc import abstractmethod
from mindspore import Tensor from mindspore import Tensor
from mindspore.nn import Cell from mindspore.nn import Cell
from mindspore.ops import operations as P from mindspore.ops import operations as P
...@@ -22,8 +24,11 @@ from mindspore.common import dtype as mstype ...@@ -22,8 +24,11 @@ from mindspore.common import dtype as mstype
from mindarmour.utils._check_param import check_param_type from mindarmour.utils._check_param import check_param_type
from mindarmour.utils._check_param import check_value_positive from mindarmour.utils._check_param import check_value_positive
from mindarmour.utils._check_param import check_value_non_negative
from mindarmour.utils._check_param import check_param_in_range from mindarmour.utils._check_param import check_param_in_range
from mindarmour.utils.logger import LogUtil
LOGGER = LogUtil.get_instance()
TAG = 'Defense'
class MechanismsFactory: class MechanismsFactory:
...@@ -98,6 +103,7 @@ class Mechanisms(Cell): ...@@ -98,6 +103,7 @@ class Mechanisms(Cell):
Basic class of noise generated mechanism. Basic class of noise generated mechanism.
""" """
@abstractmethod
def construct(self, gradients): def construct(self, gradients):
""" """
Construct function. Construct function.
...@@ -114,8 +120,9 @@ class GaussianRandom(Mechanisms): ...@@ -114,8 +120,9 @@ class GaussianRandom(Mechanisms):
initial_noise_multiplier(float): Ratio of the standard deviation of initial_noise_multiplier(float): Ratio of the standard deviation of
Gaussian noise divided by the norm_bound, which will be used to Gaussian noise divided by the norm_bound, which will be used to
calculate privacy spent. Default: 1.5. calculate privacy spent. Default: 1.5.
mean(float): Average value of random noise. Default: 0.0. seed(int): Original random seed, if seed=0 random normal will use secure
seed(int): Original random seed. Default: 0. random number. IF seed!=0 random normal will generate values using
given seed. Default: 0.
Returns: Returns:
Tensor, generated noise with shape like given gradients. Tensor, generated noise with shape like given gradients.
...@@ -129,16 +136,14 @@ class GaussianRandom(Mechanisms): ...@@ -129,16 +136,14 @@ class GaussianRandom(Mechanisms):
>>> print(res) >>> print(res)
""" """
def __init__(self, norm_bound=0.5, initial_noise_multiplier=1.5, mean=0.0, seed=0): def __init__(self, norm_bound=0.5, initial_noise_multiplier=1.5, seed=0):
super(GaussianRandom, self).__init__() super(GaussianRandom, self).__init__()
self._norm_bound = check_value_positive('norm_bound', norm_bound) self._norm_bound = check_value_positive('norm_bound', norm_bound)
self._norm_bound = Tensor(norm_bound, mstype.float32) self._norm_bound = Tensor(norm_bound, mstype.float32)
self._initial_noise_multiplier = check_value_positive('initial_noise_multiplier', self._initial_noise_multiplier = check_value_positive('initial_noise_multiplier',
initial_noise_multiplier) initial_noise_multiplier)
self._initial_noise_multiplier = Tensor(initial_noise_multiplier, mstype.float32) self._initial_noise_multiplier = Tensor(initial_noise_multiplier, mstype.float32)
mean = check_param_type('mean', mean, float) self._mean = Tensor(0, mstype.float32)
mean = check_value_non_negative('mean', mean)
self._mean = Tensor(mean, mstype.float32)
self._normal = P.Normal(seed=seed) self._normal = P.Normal(seed=seed)
def construct(self, gradients): def construct(self, gradients):
...@@ -159,8 +164,8 @@ class GaussianRandom(Mechanisms): ...@@ -159,8 +164,8 @@ class GaussianRandom(Mechanisms):
class AdaGaussianRandom(Mechanisms): class AdaGaussianRandom(Mechanisms):
""" """
Adaptive Gaussian noise generated mechanism. Noise would be decayed with training. Decay mode could be 'Time' Adaptive Gaussian noise generated mechanism. Noise would be decayed with
mode or 'Step' mode. training. Decay mode could be 'Time' mode or 'Step' mode.
Args: Args:
norm_bound(float): Clipping bound for the l2 norm of the gradients. norm_bound(float): Clipping bound for the l2 norm of the gradients.
...@@ -191,7 +196,7 @@ class AdaGaussianRandom(Mechanisms): ...@@ -191,7 +196,7 @@ class AdaGaussianRandom(Mechanisms):
>>> print(res) >>> print(res)
""" """
def __init__(self, norm_bound=1.0, initial_noise_multiplier=1.5, mean=0.0, def __init__(self, norm_bound=1.0, initial_noise_multiplier=1.5,
noise_decay_rate=6e-4, decay_policy='Time', seed=0): noise_decay_rate=6e-4, decay_policy='Time', seed=0):
super(AdaGaussianRandom, self).__init__() super(AdaGaussianRandom, self).__init__()
norm_bound = check_value_positive('norm_bound', norm_bound) norm_bound = check_value_positive('norm_bound', norm_bound)
...@@ -205,9 +210,7 @@ class AdaGaussianRandom(Mechanisms): ...@@ -205,9 +210,7 @@ class AdaGaussianRandom(Mechanisms):
self._stddev = P.Mul()(self._norm_bound, self._initial_noise_multiplier) self._stddev = P.Mul()(self._norm_bound, self._initial_noise_multiplier)
self._noise_multiplier = Parameter(initial_noise_multiplier, self._noise_multiplier = Parameter(initial_noise_multiplier,
name='noise_multiplier') name='noise_multiplier')
mean = check_param_type('mean', mean, float) self._mean = Tensor(0, mstype.float32)
mean = check_value_non_negative('mean', mean)
self._mean = Tensor(mean, mstype.float32)
noise_decay_rate = check_param_type('noise_decay_rate', noise_decay_rate, float) noise_decay_rate = check_param_type('noise_decay_rate', noise_decay_rate, float)
check_param_in_range('noise_decay_rate', noise_decay_rate, 0.0, 1.0) check_param_in_range('noise_decay_rate', noise_decay_rate, 0.0, 1.0)
self._noise_decay_rate = Tensor(noise_decay_rate, mstype.float32) self._noise_decay_rate = Tensor(noise_decay_rate, mstype.float32)
......
...@@ -35,10 +35,10 @@ class Fuzzing: ...@@ -35,10 +35,10 @@ class Fuzzing:
Neural Networks <https://dl.acm.org/doi/10.1145/3293882.3330579>`_ Neural Networks <https://dl.acm.org/doi/10.1145/3293882.3330579>`_
Args: Args:
initial_seeds (list): Initial fuzzing seed, format: [[image, label, 0], initial_seeds (list): Initial fuzzing seed, format: [[image, label],
[image, label, 0], ...]. [image, label], ...].
target_model (Model): Target fuzz model. target_model (Model): Target fuzz model.
train_dataset (numpy.ndarray): Training dataset used for determine train_dataset (numpy.ndarray): Training dataset used for determining
the neurons' output boundaries. the neurons' output boundaries.
const_k (int): The number of mutate tests for a seed. const_k (int): The number of mutate tests for a seed.
mode (str): Image mode used in image transform, 'L' means grey graph. mode (str): Image mode used in image transform, 'L' means grey graph.
...@@ -68,8 +68,8 @@ class Fuzzing: ...@@ -68,8 +68,8 @@ class Fuzzing:
seed = seed[0] seed = seed[0]
info = [seed, seed] info = [seed, seed]
mutate_tests = [] mutate_tests = []
affine_trans = ['Contrast', 'Brightness', 'Blur', 'Noise'] pixel_value_trans = ['Contrast', 'Brightness', 'Blur', 'Noise']
pixel_value_trans = ['Translate', 'Scale', 'Shear', 'Rotate'] affine_trans = ['Translate', 'Scale', 'Shear', 'Rotate']
strages = {'Contrast': Contrast, 'Brightness': Brightness, 'Blur': Blur, strages = {'Contrast': Contrast, 'Brightness': Brightness, 'Blur': Blur,
'Noise': Noise, 'Noise': Noise,
'Translate': Translate, 'Scale': Scale, 'Shear': Shear, 'Translate': Translate, 'Scale': Scale, 'Shear': Shear,
...@@ -80,7 +80,8 @@ class Fuzzing: ...@@ -80,7 +80,8 @@ class Fuzzing:
trans_strage = self._random_pick_mutate(affine_trans, trans_strage = self._random_pick_mutate(affine_trans,
pixel_value_trans) pixel_value_trans)
else: else:
trans_strage = self._random_pick_mutate(affine_trans, []) trans_strage = self._random_pick_mutate(pixel_value_trans,
[])
transform = strages[trans_strage]( transform = strages[trans_strage](
self._image_value_expand(seed), self.mode) self._image_value_expand(seed), self.mode)
transform.random_param() transform.random_param()
...@@ -105,21 +106,21 @@ class Fuzzing: ...@@ -105,21 +106,21 @@ class Fuzzing:
Default: 'KMNC'. Default: 'KMNC'.
Returns: Returns:
list, mutated tests mis-predicted by target dnn model. list, mutated tests mis-predicted by target DNN model.
""" """
seed = self._select_next() seed = self._select_next()
failed_tests = [] failed_tests = []
seed_num = 0 seed_num = 0
while seed and seed_num < self.max_seed_num: while seed and seed_num < self.max_seed_num:
mutate_tests = self._metamorphic_mutate(seed[0]) mutate_tests = self._metamorphic_mutate(seed[0])
coverages, results = self._run(mutate_tests, coverage_metric) coverages, predicts = self._run(mutate_tests, coverage_metric)
coverage_gains = self._coverage_gains(coverages) coverage_gains = self._coverage_gains(coverages)
for mutate, cov, res in zip(mutate_tests, coverage_gains, results): for mutate, cov, res in zip(mutate_tests, coverage_gains, predicts):
if np.argmax(seed[1]) != np.argmax(res): if np.argmax(seed[1]) != np.argmax(res):
failed_tests.append(mutate) failed_tests.append(mutate)
continue continue
if cov > 0: if cov > 0:
self.initial_seeds.append([mutate, seed[1], 0]) self.initial_seeds.append([mutate, seed[1]])
seed = self._select_next() seed = self._select_next()
seed_num += 1 seed_num += 1
...@@ -154,17 +155,17 @@ class Fuzzing: ...@@ -154,17 +155,17 @@ class Fuzzing:
def _is_trans_valid(self, seed, mutate_test): def _is_trans_valid(self, seed, mutate_test):
is_valid = False is_valid = False
alpha = 0.02 pixels_change_rate = 0.02
beta = 0.2 pixel_value_change_rate = 0.2
diff = np.array(seed - mutate_test).flatten() diff = np.array(seed - mutate_test).flatten()
size = np.shape(diff)[0] size = np.shape(diff)[0]
l0 = np.linalg.norm(diff, ord=0) l0 = np.linalg.norm(diff, ord=0)
linf = np.linalg.norm(diff, ord=np.inf) linf = np.linalg.norm(diff, ord=np.inf)
if l0 > alpha*size: if l0 > pixels_change_rate*size:
if linf < 256: if linf < 256:
is_valid = True is_valid = True
else: else:
if linf < beta*255: if linf < pixel_value_change_rate*255:
is_valid = True is_valid = True
return is_valid return is_valid
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册