From 6e7475d697103a46857f7fd010cffd241e011cd3 Mon Sep 17 00:00:00 2001
From: jin-xiulang
Date: Tue, 30 Jun 2020 15:54:44 +0800
Subject: [PATCH] Fix bugs for monitor.py and model_coverage_metrics.py.

---
 mindarmour/diff_privacy/monitor/monitor.py   | 160 ++++++++++---
 mindarmour/fuzzing/model_coverage_metrics.py |  44 ++---
 2 files changed, 108 insertions(+), 96 deletions(-)

diff --git a/mindarmour/diff_privacy/monitor/monitor.py b/mindarmour/diff_privacy/monitor/monitor.py
index eedf896..f0a5697 100644
--- a/mindarmour/diff_privacy/monitor/monitor.py
+++ b/mindarmour/diff_privacy/monitor/monitor.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """
 Monitor module of differential privacy training.
 """
-import math
 import numpy as np
 from scipy import special
@@ -40,8 +39,9 @@ class PrivacyMonitorFactory:
         Create a privacy monitor class.
 
         Args:
-            policy (str): Monitor policy, 'rdp' is supported by now. RDP means R'enyi differential privacy,
-                which computed based on R'enyi divergence.
+            policy (str): Monitor policy, only 'rdp' is supported for now.
+                RDP means Rényi differential privacy, which is computed
+                based on Rényi divergence.
             args (Union[int, float, numpy.ndarray, list, str]): Parameters
                 used for creating a privacy monitor.
             kwargs (Union[int, float, numpy.ndarray, list, str]): Keyword
@@ -60,9 +60,14 @@ class PrivacyMonitorFactory:
 
 
 class RDPMonitor(Callback):
-    """
+    r"""
     Compute the privacy budget of DP training based on Renyi differential
-    privacy theory.
+    privacy (RDP) theory. According to the reference below, if a randomized
+    mechanism satisfies ε'-Rényi differential privacy of order α, then it
+    also satisfies conventional (ε, δ)-differential privacy as follows:
+
+    .. math::
+        (\epsilon'+\frac{\log(1/\delta)}{\alpha-1}, \delta)
 
     Reference: `Rényi Differential Privacy of the Sampled Gaussian Mechanism
     <https://arxiv.org/abs/1908.10530>`_
@@ -70,33 +75,43 @@ class RDPMonitor(Callback):
     Args:
         num_samples (int): The total number of samples in training data sets.
         batch_size (int): The number of samples in a batch while training.
-        initial_noise_multiplier (Union[float, int]): The initial
-            multiplier of the noise added to training parameters' gradients. Default: 1.5.
+        initial_noise_multiplier (Union[float, int]): Ratio of the standard
+            deviation of the Gaussian noise to the norm_bound, which is used
+            to calculate the privacy budget spent. Default: 1.5.
         max_eps (Union[float, int, None]): The maximum acceptable epsilon
-            budget for DP training. Default: 10.0.
+            budget for DP training, which is used to estimate the maximum
+            number of training epochs. Default: 10.0.
         target_delta (Union[float, int, None]): Target delta budget for DP
-            training. Default: 1e-3.
+            training. If target_delta is set to δ, then the privacy budget δ
+            is kept fixed during the whole training process. Default: 1e-3.
         max_delta (Union[float, int, None]): The maximum acceptable delta
-            budget for DP training. Max_delta must be less than 1 and
-            suggested to be less than 1e-3, otherwise overflow would be
-            encountered. Default: None.
+            budget for DP training, which is used to estimate the maximum
+            number of training epochs. max_delta must be less than 1 and is
+            suggested to be less than 1e-3, otherwise overflow would be
+            encountered. Default: None.
         target_eps (Union[float, int, None]): Target epsilon budget for DP
-            training. Default: None.
+            training. If target_eps is set to ε, then the privacy budget ε
+            is kept fixed during the whole training process. Default: None.
         orders (Union[None, list[int, float]]): Finite orders used for
-            computing rdp, which must be greater than 1.
+            computing rdp, which must be greater than 1. The computed privacy
+            budget differs across orders, so a list of orders can be tried
+            in order to obtain a tighter (smaller) privacy budget estimate.
+            Default: None.
         noise_decay_mode (str): Decay mode of adding noise while training,
             which can be 'no_decay', 'Time' or 'Step'. Default: 'Time'.
         noise_decay_rate (Union[float, None]): Decay rate of noise while
             training. Default: 6e-4.
         per_print_times (int): The interval steps of computing and printing
             the privacy budget. Default: 50.
-        dataset_sink_mode (bool): If True, all training data would be passed to device(Ascend) at once. If False,
-            training data would be passed to device after each step training. Default: False.
+        dataset_sink_mode (bool): If True, all training data is passed to
+            the device (Ascend) at once. If False, training data is passed
+            to the device after each training step. Default: False.
 
     Examples:
         >>> rdp = PrivacyMonitorFactory.create(policy='rdp',
         >>> num_samples=60000, batch_size=256)
         >>> network = Net()
+        >>> epochs = 2
         >>> net_loss = nn.SoftmaxCrossEntropyWithLogits()
         >>> net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9)
         >>> model = Model(network, net_loss, net_opt)
@@ -158,6 +173,15 @@ class RDPMonitor(Callback):
         self._noise_decay_rate = noise_decay_rate
         self._rdp = 0
         self._per_print_times = per_print_times
+        if self._target_eps is None and self._target_delta is None:
+            msg = 'target eps and target delta cannot both be None'
+            LOGGER.error(TAG, msg)
+            raise ValueError(msg)
+
+        if self._target_eps is not None and self._target_delta is not None:
+            msg = 'One of target eps and target delta must be None'
+            LOGGER.error(TAG, msg)
+            raise ValueError(msg)
 
         if dataset_sink_mode:
             self._per_print_times = int(self._num_samples / self._batch_size)
@@ -178,7 +202,7 @@ class RDPMonitor(Callback):
         while epoch < 10000:
             steps = self._num_samples // self._batch_size
             eps, delta = self._compute_privacy_steps(
-                list(np.arange((epoch - 1) * steps, epoch * steps + 1)))
+                list(np.arange((epoch - 1)*steps, epoch*steps + 1)))
             if self._max_eps is not None:
                 if eps <= self._max_eps:
                     epoch += 1
@@ -189,6 +213,7 @@ class RDPMonitor(Callback):
                 epoch += 1
             else:
                 break
 
+        # reset the accumulated rdp before the real training starts
         self._rdp = 0
         return epoch
@@ -233,25 +258,15 @@ class RDPMonitor(Callback):
 
         Returns:
             float, privacy budget.
""" - if self._target_eps is None and self._target_delta is None: - msg = 'target eps and target delta cannot both be None' - LOGGER.error(TAG, msg) - raise ValueError(msg) - - if self._target_eps is not None and self._target_delta is not None: - msg = 'One of target eps and target delta must be None' - LOGGER.error(TAG, msg) - raise ValueError(msg) - if self._orders is None: self._orders = ( [1.005, 1.01, 1.02, 1.08, 1.2, 2, 5, 10, 20, 40, 80]) sampling_rate = self._batch_size / self._num_samples - noise_step = self._initial_noise_multiplier + noise_stddev_step = self._initial_noise_multiplier if self._noise_decay_mode == 'no_decay': - self._rdp += self._compute_rdp(sampling_rate, noise_step) * len( + self._rdp += self._compute_rdp(sampling_rate, noise_stddev_step)*len( steps) else: if self._noise_decay_rate is None: @@ -260,33 +275,33 @@ class RDPMonitor(Callback): raise ValueError(msg) if self._noise_decay_mode == 'Time': - noise_step = [self._initial_noise_multiplier / ( - 1 + self._noise_decay_rate * step) for step in steps] + noise_stddev_step = [self._initial_noise_multiplier / ( + 1 + self._noise_decay_rate*step) for step in steps] elif self._noise_decay_mode == 'Step': - noise_step = [self._initial_noise_multiplier * ( - 1 - self._noise_decay_rate) ** step for step in steps] + noise_stddev_step = [self._initial_noise_multiplier*( + 1 - self._noise_decay_rate)**step for step in steps] self._rdp += sum( [self._compute_rdp(sampling_rate, noise) for noise in - noise_step]) + noise_stddev_step]) eps, delta = self._compute_privacy_budget(self._rdp) return eps, delta - def _compute_rdp(self, q, noise): + def _compute_rdp(self, sample_rate, noise_stddev): """ Compute rdp according to sampling rate, added noise and Renyi divergence orders. Args: - q (float): Sampling rate of each batch of samples. - noise (float): Noise multiplier. + sample_rate (float): Sampling rate of each batch of samples. + noise_stddev (float): Noise multiplier. Returns: float or numpy.ndarray, rdp values. """ rdp = np.array( - [_compute_rdp_order(q, noise, order) for order in self._orders]) + [_compute_rdp_with_order(sample_rate, noise_stddev, order) for order in self._orders]) return rdp def _compute_privacy_budget(self, rdp): @@ -317,14 +332,9 @@ class RDPMonitor(Callback): """ orders = np.atleast_1d(self._orders) rdps = np.atleast_1d(rdp) - if len(orders) != len(rdps): - msg = 'rdp lists and orders list must have the same length.' - LOGGER.error(TAG, msg) - raise ValueError(msg) - - deltas = np.exp((rdps - self._target_eps) * (orders - 1)) - min_delta = min(deltas) - return min(min_delta, 1.) + deltas = np.exp((rdps - self._target_eps)*(orders - 1)) + min_delta = np.min(deltas) + return np.min([min_delta, 1.]) def _compute_eps(self, rdp): """ @@ -338,50 +348,46 @@ class RDPMonitor(Callback): """ orders = np.atleast_1d(self._orders) rdps = np.atleast_1d(rdp) - if len(orders) != len(rdps): - msg = 'rdp lists and orders list must have the same length.' - LOGGER.error(TAG, msg) - raise ValueError(msg) - eps = rdps - math.log(self._target_delta) / (orders - 1) - return min(eps) + eps = rdps - np.log(self._target_delta) / (orders - 1) + return np.min(eps) -def _compute_rdp_order(q, sigma, alpha): +def _compute_rdp_with_order(sample_rate, noise_stddev, order): """ Compute rdp for each order. Args: - q (float): Sampling probability. - sigma (float): Noise multiplier. - alpha: The order used for computing rdp. + sample_rate (float): Sampling probability. + noise_stddev (float): Noise multiplier. 
+        order (int): The order used for computing rdp.
 
     Returns:
         float, rdp value.
     """
-    if float(alpha).is_integer():
+    if float(order).is_integer():
         log_integrate = -np.inf
-        for k in range(alpha + 1):
-            term_k = (math.log(
-                special.binom(alpha, k)) + k * math.log(q) + (
-                    alpha - k) * math.log(
-                        1 - q)) + (k * k - k) / (2 * (sigma ** 2))
+        for k in range(order + 1):
+            term_k = (np.log(
+                special.binom(order, k)) + k*np.log(sample_rate) + (
+                    order - k)*np.log(
+                        1 - sample_rate)) + (k*k - k) / (2*(noise_stddev**2))
             log_integrate = _log_add(log_integrate, term_k)
-        return float(log_integrate) / (alpha - 1)
+        return float(log_integrate) / (order - 1)
     log_part_0, log_part_1 = -np.inf, -np.inf
     k = 0
-    z0 = sigma ** 2 * math.log(1 / q - 1) + 1 / 2
+    z0 = noise_stddev**2*np.log(1 / sample_rate - 1) + 1 / 2
     while True:
-        bi_coef = special.binom(alpha, k)
-        log_coef = math.log(abs(bi_coef))
-        j = alpha - k
+        bi_coef = special.binom(order, k)
+        log_coef = np.log(abs(bi_coef))
+        j = order - k
 
-        term_k_part_0 = log_coef + k * math.log(q) + j * math.log(1 - q) + (
-            k * k - k) / (2 * (sigma ** 2)) + special.log_ndtr(
-                (z0 - k) / sigma)
+        term_k_part_0 = log_coef + k*np.log(sample_rate) + j*np.log(1 - sample_rate) + (
+            k*k - k) / (2*(noise_stddev**2)) + special.log_ndtr(
+                (z0 - k) / noise_stddev)
 
-        term_k_part_1 = log_coef + j * math.log(q) + k * math.log(1 - q) + (
-            j * j - j) / (2 * (sigma ** 2)) + special.log_ndtr(
-                (j - z0) / sigma)
+        term_k_part_1 = log_coef + j*np.log(sample_rate) + k*np.log(1 - sample_rate) + (
+            j*j - j) / (2*(noise_stddev**2)) + special.log_ndtr(
+                (j - z0) / noise_stddev)
 
         if bi_coef > 0:
             log_part_0 = _log_add(log_part_0, term_k_part_0)
@@ -391,10 +397,10 @@ def _compute_rdp_order(q, sigma, alpha):
             log_part_1 = _log_subtract(log_part_1, term_k_part_1)
 
         k += 1
-        if max(term_k_part_0, term_k_part_1) < -30:
+        if np.max([term_k_part_0, term_k_part_1]) < -30:
             break
 
-    return _log_add(log_part_0, log_part_1) / (alpha - 1)
+    return _log_add(log_part_0, log_part_1) / (order - 1)
 
 
 def _log_add(x, y):
@@ -405,7 +411,7 @@ def _log_add(x, y):
         return y
     if y == -np.inf:
         return x
-    return max(x, y) + math.log1p(math.exp(-abs(x - y)))
+    return np.max([x, y]) + np.log1p(np.exp(-abs(x - y)))
 
 
 def _log_subtract(x, y):
@@ -418,4 +424,4 @@ def _log_subtract(x, y):
         raise ValueError(msg)
     if y == -np.inf:
         return x
-    return math.log1p(math.exp(y - x)) + x
+    return np.log1p(np.exp(y - x)) + x
diff --git a/mindarmour/fuzzing/model_coverage_metrics.py b/mindarmour/fuzzing/model_coverage_metrics.py
index 95f2de1..78f0952 100644
--- a/mindarmour/fuzzing/model_coverage_metrics.py
+++ b/mindarmour/fuzzing/model_coverage_metrics.py
@@ -26,15 +26,21 @@ from mindarmour.utils._check_param import check_model, check_numpy_param, \
 
 class ModelCoverageMetrics:
     """
-    Evaluate the testing adequacy of a model fuzz test.
+    As is well known, each neuron output of a network has an output range
+    after training (which we call the original range), and a test dataset is
+    used to estimate the accuracy of the trained network. However, neurons'
+    output distributions differ across test datasets. Therefore, similar to
+    function fuzzing, model fuzzing means testing those neurons' outputs and
+    estimating the proportion of the original range that has been covered by
+    the test datasets.
 
     Reference: `DeepGauge: Multi-Granularity Testing Criteria for Deep
     Learning Systems <https://arxiv.org/abs/1803.07519>`_
 
     Args:
         model (Model): The pre-trained model which waiting for testing.
-        k (int): The number of segmented sections of neurons' output intervals.
- n (int): The number of testing neurons. + segmented_num (int): The number of segmented sections of neurons' output intervals. + neuron_num (int): The number of testing neurons. train_dataset (numpy.ndarray): Training dataset used for determine the neurons' output boundaries. @@ -49,18 +55,18 @@ class ModelCoverageMetrics: >>> print('SNAC of this test is : %s', model_fuzz_test.get_snac()) """ - def __init__(self, model, k, n, train_dataset): + def __init__(self, model, segmented_num, neuron_num, train_dataset): self._model = check_model('model', model, Model) - self._k = k - self._n = n + self._segmented_num = check_int_positive('segmented_num', segmented_num) + self._neuron_num = check_int_positive('neuron_num', neuron_num) train_dataset = check_numpy_param('train_dataset', train_dataset) - self._lower_bounds = [np.inf]*n - self._upper_bounds = [-np.inf]*n - self._var = [0]*n - self._main_section_hits = [[0 for _ in range(self._k)] for _ in - range(self._n)] - self._lower_corner_hits = [0]*self._n - self._upper_corner_hits = [0]*self._n + self._lower_bounds = [np.inf]*neuron_num + self._upper_bounds = [-np.inf]*neuron_num + self._var = [0]*neuron_num + self._main_section_hits = [[0 for _ in range(self._segmented_num)] for _ in + range(self._neuron_num)] + self._lower_corner_hits = [0]*self._neuron_num + self._upper_corner_hits = [0]*self._neuron_num self._bounds_get(train_dataset) def _bounds_get(self, train_dataset, batch_size=32): @@ -107,10 +113,10 @@ class ModelCoverageMetrics: batch_output = self._model.predict(Tensor(dataset)).asnumpy() batch_section_indexes = (batch_output - self._lower_bounds) // intervals for section_indexes in batch_section_indexes: - for i in range(self._n): + for i in range(self._neuron_num): if section_indexes[i] < 0: self._lower_corner_hits[i] = 1 - elif section_indexes[i] >= self._k: + elif section_indexes[i] >= self._segmented_num: self._upper_corner_hits[i] = 1 else: self._main_section_hits[i][int(section_indexes[i])] = 1 @@ -135,7 +141,7 @@ class ModelCoverageMetrics: batch_size = check_int_positive('batch_size', batch_size) self._lower_bounds -= bias_coefficient*self._var self._upper_bounds += bias_coefficient*self._var - intervals = (self._upper_bounds - self._lower_bounds) / self._k + intervals = (self._upper_bounds - self._lower_bounds) / self._segmented_num batches = dataset.shape[0] // batch_size for i in range(batches): self._sections_hits_count( @@ -151,7 +157,7 @@ class ModelCoverageMetrics: Examples: >>> model_fuzz_test.get_kmnc() """ - kmnc = np.sum(self._main_section_hits) / (self._n*self._k) + kmnc = np.sum(self._main_section_hits) / (self._neuron_num*self._segmented_num) return kmnc def get_nbc(self): @@ -165,7 +171,7 @@ class ModelCoverageMetrics: >>> model_fuzz_test.get_nbc() """ nbc = (np.sum(self._lower_corner_hits) + np.sum( - self._upper_corner_hits)) / (2*self._n) + self._upper_corner_hits)) / (2*self._neuron_num) return nbc def get_snac(self): @@ -178,5 +184,5 @@ class ModelCoverageMetrics: Examples: >>> model_fuzz_test.get_snac() """ - snac = np.sum(self._upper_corner_hits) / self._n + snac = np.sum(self._upper_corner_hits) / self._neuron_num return snac -- GitLab
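A note outside the patch itself: the RDP-to-DP conversion that _compute_eps implements (and that the new math block in the RDPMonitor docstring states) is easy to sanity-check standalone. The sketch below is illustrative only; rdp_to_eps is a made-up name, not a MindArmour API:

import numpy as np

def rdp_to_eps(orders, rdps, target_delta):
    # For each order a, eps = rdp(a) - log(delta) / (a - 1); keep the
    # tightest (smallest) epsilon, exactly as _compute_eps does.
    orders = np.atleast_1d(orders)
    rdps = np.atleast_1d(rdps)
    eps = rdps - np.log(target_delta) / (orders - 1)
    return np.min(eps)

# Example with the monitor's default orders and a flat rdp curve of 0.5.
orders = [1.005, 1.01, 1.02, 1.08, 1.2, 2, 5, 10, 20, 40, 80]
print(rdp_to_eps(orders, [0.5]*len(orders), target_delta=1e-3))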
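The two noise-decay schedules handled in _compute_privacy_steps likewise reduce to closed-form expressions over the global step index. A minimal sketch using the class defaults (the step count here is chosen arbitrarily):

import numpy as np

initial_noise_multiplier = 1.5
noise_decay_rate = 6e-4
steps = np.arange(1000)

# 'Time' mode: hyperbolic decay of the noise multiplier per step.
time_decay = initial_noise_multiplier / (1 + noise_decay_rate*steps)
# 'Step' mode: exponential decay of the noise multiplier per step.
step_decay = initial_noise_multiplier*(1 - noise_decay_rate)**steps

print(time_decay[-1], step_decay[-1])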
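Finally, the three coverage metrics in model_coverage_metrics.py are plain ratios over the hit counters. A standalone sketch of the same formulas; the hit arrays below are made up for illustration:

import numpy as np

neuron_num, segmented_num = 4, 5
# 1 marks a section/corner hit by at least one test sample.
main_section_hits = np.array([[1, 1, 0, 1, 0]]*neuron_num)
lower_corner_hits = np.array([0, 1, 0, 0])
upper_corner_hits = np.array([1, 0, 0, 1])

# KMNC (get_kmnc): fraction of all neuron_num*segmented_num sections hit.
kmnc = np.sum(main_section_hits) / (neuron_num*segmented_num)
# NBC (get_nbc): fraction of the 2*neuron_num boundary corners hit.
nbc = (np.sum(lower_corner_hits) + np.sum(upper_corner_hits)) / (2*neuron_num)
# SNAC (get_snac): fraction of upper corners hit.
snac = np.sum(upper_corner_hits) / neuron_num
print(kmnc, nbc, snac)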