Commit 6e7475d6 authored by J jin-xiulang

Fix bugs for monitor.py and model_coverage_metrics.py.

Parent 92165efc
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" Monitor module of differential privacy training. """
import math
import numpy as np
from scipy import special
@@ -40,8 +39,9 @@ class PrivacyMonitorFactory:
Create a privacy monitor class.
Args:
policy (str): Monitor policy, 'rdp' is supported by now. RDP means R'enyi differential privacy,
which computed based on R'enyi divergence.
policy (str): Monitor policy, only 'rdp' is supported for now. RDP
means Rényi differential privacy, which is computed based
on Rényi divergence.
args (Union[int, float, numpy.ndarray, list, str]): Parameters
used for creating a privacy monitor.
kwargs (Union[int, float, numpy.ndarray, list, str]): Keyword
@@ -60,9 +60,14 @@ class PrivacyMonitorFactory:
class RDPMonitor(Callback):
"""
r"""
Compute the privacy budget of DP training based on Renyi differential
privacy theory.
privacy (RDP) theory. According to the reference below, if a randomized
mechanism has ε'-Rényi differential privacy of order α, it
also satisfies conventional (ε, δ)-differential privacy as below:
.. math::
    (\epsilon' + \frac{\log(1/\delta)}{\alpha - 1},\ \delta)
Reference: `Rényi Differential Privacy of the Sampled Gaussian Mechanism
<https://arxiv.org/abs/1908.10530>`_
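Concretely, the conversion above can be sketched as a small standalone helper; the name `rdp_to_dp` and the inputs are illustrative, not part of the module:

```python
import numpy as np

def rdp_to_dp(rdp, orders, target_delta):
    """Sketch: convert RDP at several orders into an (eps, delta) pair.

    Applies eps = rdp + log(1/delta) / (order - 1) per order, then keeps
    the tightest (smallest) epsilon, as in the formula above.
    """
    orders = np.atleast_1d(orders)
    rdp = np.atleast_1d(rdp)
    eps = rdp + np.log(1.0 / target_delta) / (orders - 1)
    return float(np.min(eps)), target_delta
```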
@@ -70,33 +75,43 @@ class RDPMonitor(Callback):
Args:
num_samples (int): The total number of samples in training data sets.
batch_size (int): The number of samples in a batch while training.
initial_noise_multiplier (Union[float, int]): The initial
multiplier of the noise added to training parameters' gradients. Default: 1.5.
initial_noise_multiplier (Union[float, int]): Ratio of the standard
deviation of the Gaussian noise to the norm_bound, which will
be used to calculate the privacy budget spent. Default: 1.5.
max_eps (Union[float, int, None]): The maximum acceptable epsilon
budget for DP training. Default: 10.0.
budget for DP training, which is used for estimating the max
training epochs. Default: 10.0.
target_delta (Union[float, int, None]): Target delta budget for DP
training. Default: 1e-3.
training. If target_delta is set to δ, the privacy budget
δ is fixed during the whole training process. Default: 1e-3.
max_delta (Union[float, int, None]): The maximum acceptable delta
budget for DP training. Max_delta must be less than 1 and
suggested to be less than 1e-3, otherwise overflow would be
encountered. Default: None.
budget for DP training, which is used for estimating the max
training epochs. Max_delta must be less than 1 and is suggested
to be less than 1e-3; otherwise overflow would be encountered.
Default: None.
target_eps (Union[float, int, None]): Target epsilon budget for DP
training. Default: None.
training. If target_eps is set to ε, the privacy budget
ε is fixed during the whole training process. Default: None.
orders (Union[None, list[int, float]]): Finite orders used for
computing rdp, which must be greater than 1.
computing rdp, which must be greater than 1. The computed
privacy budget differs across orders; in order
to obtain a tighter (smaller) privacy budget estimation, a list
of orders can be tried. Default: None.
noise_decay_mode (str): Decay mode of adding noise while training,
which can be 'no_decay', 'Time' or 'Step'. Default: 'Time'.
noise_decay_rate (Union[float, None]): Decay rate of noise while
training. Default: 6e-4.
per_print_times (int): The interval steps of computing and printing
the privacy budget. Default: 50.
dataset_sink_mode (bool): If True, all training data would be passed to device(Ascend) at once. If False,
training data would be passed to device after each step training. Default: False.
dataset_sink_mode (bool): If True, all training data would be passed
to the device (Ascend) at once. If False, training data would be
passed to the device after each training step. Default: False.
Examples:
>>> rdp = PrivacyMonitorFactory.create(policy='rdp',
>>> num_samples=60000, batch_size=256)
>>> network = Net()
>>> epochs = 2
>>> net_loss = nn.SoftmaxCrossEntropyWithLogits()
>>> net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9)
>>> model = Model(network, net_loss, net_opt)
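The example stops after building the `Model`; a hedged continuation, assuming a prepared MindSpore dataset `ds_train` already exists, would attach the monitor as a callback:

```python
# ds_train is assumed to be an existing mindspore.dataset object.
model.train(epochs, ds_train, callbacks=[rdp], dataset_sink_mode=False)
```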
@@ -158,6 +173,15 @@ class RDPMonitor(Callback):
self._noise_decay_rate = noise_decay_rate
self._rdp = 0
self._per_print_times = per_print_times
if self._target_eps is None and self._target_delta is None:
msg = 'target eps and target delta cannot both be None'
LOGGER.error(TAG, msg)
raise ValueError(msg)
if self._target_eps is not None and self._target_delta is not None:
msg = 'One of target eps and target delta must be None'
LOGGER.error(TAG, msg)
raise ValueError(msg)
if dataset_sink_mode:
self._per_print_times = int(self._num_samples / self._batch_size)
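With data sinking enabled, control returns to callbacks only once per sunk epoch, so the print interval is pinned to one epoch's worth of steps (num_samples / batch_size) to keep the printed budget aligned with the steps actually executed.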
@@ -178,7 +202,7 @@
while epoch < 10000:
steps = self._num_samples // self._batch_size
eps, delta = self._compute_privacy_steps(
list(np.arange((epoch - 1) * steps, epoch * steps + 1)))
list(np.arange((epoch - 1)*steps, epoch*steps + 1)))
if self._max_eps is not None:
if eps <= self._max_eps:
epoch += 1
@@ -189,6 +213,7 @@
epoch += 1
else:
break
# reset the rdp for model training
self._rdp = 0
return epoch
@@ -233,25 +258,15 @@
Returns:
float, privacy budget.
"""
if self._target_eps is None and self._target_delta is None:
msg = 'target eps and target delta cannot both be None'
LOGGER.error(TAG, msg)
raise ValueError(msg)
if self._target_eps is not None and self._target_delta is not None:
msg = 'One of target eps and target delta must be None'
LOGGER.error(TAG, msg)
raise ValueError(msg)
if self._orders is None:
self._orders = (
[1.005, 1.01, 1.02, 1.08, 1.2, 2, 5, 10, 20, 40, 80])
sampling_rate = self._batch_size / self._num_samples
noise_step = self._initial_noise_multiplier
noise_stddev_step = self._initial_noise_multiplier
if self._noise_decay_mode == 'no_decay':
self._rdp += self._compute_rdp(sampling_rate, noise_step) * len(
self._rdp += self._compute_rdp(sampling_rate, noise_stddev_step)*len(
steps)
else:
if self._noise_decay_rate is None:
@@ -260,33 +275,33 @@
raise ValueError(msg)
if self._noise_decay_mode == 'Time':
noise_step = [self._initial_noise_multiplier / (
1 + self._noise_decay_rate * step) for step in steps]
noise_stddev_step = [self._initial_noise_multiplier / (
1 + self._noise_decay_rate*step) for step in steps]
elif self._noise_decay_mode == 'Step':
noise_step = [self._initial_noise_multiplier * (
1 - self._noise_decay_rate) ** step for step in steps]
noise_stddev_step = [self._initial_noise_multiplier*(
1 - self._noise_decay_rate)**step for step in steps]
self._rdp += sum(
[self._compute_rdp(sampling_rate, noise) for noise in
noise_step])
noise_stddev_step])
eps, delta = self._compute_privacy_budget(self._rdp)
return eps, delta
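For reference, the two decay schedules branched on above can be written as a standalone sketch (a simplified re-implementation for illustration, not the class internals):

```python
import numpy as np

def noise_stddev_at(step, initial=1.5, rate=6e-4, mode='Time'):
    """Sketch of the per-step noise standard deviation schedules."""
    if mode == 'Time':
        # inverse-time decay: sigma_t = sigma_0 / (1 + rate*t)
        return initial / (1 + rate*step)
    if mode == 'Step':
        # exponential decay: sigma_t = sigma_0 * (1 - rate)**t
        return initial*(1 - rate)**step
    return initial  # 'no_decay': constant noise
```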
def _compute_rdp(self, q, noise):
def _compute_rdp(self, sample_rate, noise_stddev):
"""
Compute rdp according to sampling rate, added noise and Renyi
divergence orders.
Args:
q (float): Sampling rate of each batch of samples.
noise (float): Noise multiplier.
sample_rate (float): Sampling rate of each batch of samples.
noise_stddev (float): Standard deviation of the Gaussian noise, i.e. the noise multiplier.
Returns:
float or numpy.ndarray, rdp values.
"""
rdp = np.array(
[_compute_rdp_order(q, noise, order) for order in self._orders])
[_compute_rdp_with_order(sample_rate, noise_stddev, order) for order in self._orders])
return rdp
def _compute_privacy_budget(self, rdp):
@@ -317,14 +332,9 @@
"""
orders = np.atleast_1d(self._orders)
rdps = np.atleast_1d(rdp)
if len(orders) != len(rdps):
msg = 'rdp lists and orders list must have the same length.'
LOGGER.error(TAG, msg)
raise ValueError(msg)
deltas = np.exp((rdps - self._target_eps) * (orders - 1))
min_delta = min(deltas)
return min(min_delta, 1.)
deltas = np.exp((rdps - self._target_eps)*(orders - 1))
min_delta = np.min(deltas)
return np.min([min_delta, 1.])
def _compute_eps(self, rdp):
"""
@@ -338,50 +348,46 @@
"""
orders = np.atleast_1d(self._orders)
rdps = np.atleast_1d(rdp)
if len(orders) != len(rdps):
msg = 'rdp lists and orders list must have the same length.'
LOGGER.error(TAG, msg)
raise ValueError(msg)
eps = rdps - math.log(self._target_delta) / (orders - 1)
return min(eps)
eps = rdps - np.log(self._target_delta) / (orders - 1)
return np.min(eps)
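`_compute_eps` picks the tightest ε across all orders for the fixed target δ. A quick numeric check with hypothetical RDP values (not taken from the code):

```python
>>> import numpy as np
>>> orders = np.array([2., 5., 10.])
>>> rdps = np.array([0.5, 1.2, 2.8])  # hypothetical RDP values
>>> eps = rdps - np.log(1e-3) / (orders - 1)
>>> round(float(np.min(eps)), 3)  # tightest eps, attained at order 5
2.927
```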
def _compute_rdp_order(q, sigma, alpha):
def _compute_rdp_with_order(sample_rate, noise_stddev, order):
"""
Compute rdp for each order.
Args:
q (float): Sampling probability.
sigma (float): Noise multiplier.
alpha: The order used for computing rdp.
sample_rate (float): Sampling probability.
noise_stddev (float): Standard deviation of the Gaussian noise, i.e. the noise multiplier.
order (float): The order used for computing rdp.
Returns:
float, rdp value.
"""
if float(alpha).is_integer():
if float(order).is_integer():
log_integrate = -np.inf
for k in range(alpha + 1):
term_k = (math.log(
special.binom(alpha, k)) + k * math.log(q) + (
alpha - k) * math.log(
1 - q)) + (k * k - k) / (2 * (sigma ** 2))
for k in range(int(order) + 1):
term_k = (np.log(
special.binom(order, k)) + k*np.log(sample_rate) + (
order - k)*np.log(
1 - sample_rate)) + (k*k - k) / (2*(noise_stddev**2))
log_integrate = _log_add(log_integrate, term_k)
return float(log_integrate) / (alpha - 1)
return float(log_integrate) / (order - 1)
log_part_0, log_part_1 = -np.inf, -np.inf
k = 0
z0 = sigma ** 2 * math.log(1 / q - 1) + 1 / 2
z0 = noise_stddev**2*np.log(1 / sample_rate - 1) + 1 / 2
while True:
bi_coef = special.binom(alpha, k)
log_coef = math.log(abs(bi_coef))
j = alpha - k
bi_coef = special.binom(order, k)
log_coef = np.log(abs(bi_coef))
j = order - k
term_k_part_0 = log_coef + k * math.log(q) + j * math.log(1 - q) + (
k * k - k) / (2 * (sigma ** 2)) + special.log_ndtr(
(z0 - k) / sigma)
term_k_part_0 = log_coef + k*np.log(sample_rate) + j*np.log(1 - sample_rate) + (
k*k - k) / (2*(noise_stddev**2)) + special.log_ndtr(
(z0 - k) / noise_stddev)
term_k_part_1 = log_coef + j * math.log(q) + k * math.log(1 - q) + (
j * j - j) / (2 * (sigma ** 2)) + special.log_ndtr(
(j - z0) / sigma)
term_k_part_1 = log_coef + j*np.log(sample_rate) + k*np.log(1 - sample_rate) + (
j*j - j) / (2*(noise_stddev**2)) + special.log_ndtr(
(j - z0) / noise_stddev)
if bi_coef > 0:
log_part_0 = _log_add(log_part_0, term_k_part_0)
@@ -391,10 +397,10 @@ def _compute_rdp_order(q, sigma, alpha):
log_part_1 = _log_subtract(log_part_1, term_k_part_1)
k += 1
if max(term_k_part_0, term_k_part_1) < -30:
if np.max([term_k_part_0, term_k_part_1]) < -30:
break
return _log_add(log_part_0, log_part_1) / (alpha - 1)
return _log_add(log_part_0, log_part_1) / (order - 1)
def _log_add(x, y):
@@ -405,7 +411,7 @@ def _log_add(x, y):
return y
if y == -np.inf:
return x
return max(x, y) + math.log1p(math.exp(-abs(x - y)))
return np.max([x, y]) + np.log1p(np.exp(-abs(x - y)))
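`_log_add` computes log(e^x + e^y) entirely in log space; taking the max out first keeps the remaining exponent non-positive, so large inputs cannot overflow. A small standalone check (illustrative only):

```python
import numpy as np

def log_add(x, y):
    # log(exp(x) + exp(y)) computed without leaving log space
    if x == -np.inf:
        return y
    if y == -np.inf:
        return x
    return max(x, y) + np.log1p(np.exp(-abs(x - y)))

print(log_add(1000.0, 1000.0))  # 1000.693..., whereas the naive
# np.log(np.exp(1000) + np.exp(1000)) overflows to inf
```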
def _log_subtract(x, y):
@@ -418,4 +424,4 @@ def _log_subtract(x, y):
raise ValueError(msg)
if y == -np.inf:
return x
return math.log1p(math.exp(y - x)) + x
return np.log1p(np.exp(y - x)) + x
@@ -26,15 +26,21 @@ from mindarmour.utils._check_param import check_model, check_numpy_param, \
class ModelCoverageMetrics:
"""
Evaluate the testing adequacy of a model fuzz test.
As is well known, each neuron of a network has an output range
after training (we call it the original range), and a test dataset is
used to estimate the accuracy of the trained network. However, neurons'
output distributions differ across test datasets. Therefore, similar
to function fuzzing, model fuzzing means testing those neurons' outputs
and estimating the proportion of the original range that has been
covered by test datasets.
Reference: `DeepGauge: Multi-Granularity Testing Criteria for Deep
Learning Systems <https://arxiv.org/abs/1803.07519>`_
Args:
model (Model): The pre-trained model to be tested.
k (int): The number of segmented sections of neurons' output intervals.
n (int): The number of testing neurons.
segmented_num (int): The number of segmented sections of neurons' output intervals.
neuron_num (int): The number of testing neurons.
train_dataset (numpy.ndarray): Training dataset used to determine
the neurons' output boundaries.
@@ -49,18 +55,18 @@ class ModelCoverageMetrics:
>>> print('SNAC of this test is : %s' % model_fuzz_test.get_snac())
"""
def __init__(self, model, k, n, train_dataset):
def __init__(self, model, segmented_num, neuron_num, train_dataset):
self._model = check_model('model', model, Model)
self._k = k
self._n = n
self._segmented_num = check_int_positive('segmented_num', segmented_num)
self._neuron_num = check_int_positive('neuron_num', neuron_num)
train_dataset = check_numpy_param('train_dataset', train_dataset)
self._lower_bounds = [np.inf]*n
self._upper_bounds = [-np.inf]*n
self._var = [0]*n
self._main_section_hits = [[0 for _ in range(self._k)] for _ in
range(self._n)]
self._lower_corner_hits = [0]*self._n
self._upper_corner_hits = [0]*self._n
self._lower_bounds = [np.inf]*neuron_num
self._upper_bounds = [-np.inf]*neuron_num
self._var = [0]*neuron_num
self._main_section_hits = [[0 for _ in range(self._segmented_num)] for _ in
range(self._neuron_num)]
self._lower_corner_hits = [0]*self._neuron_num
self._upper_corner_hits = [0]*self._neuron_num
self._bounds_get(train_dataset)
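`_bounds_get` itself is not shown in this hunk; a minimal sketch of the assumed logic (batched min/max tracking of every neuron's output over the training set) is:

```python
import numpy as np

def estimate_bounds(predict_fn, train_data, batch_size=32):
    """Assumed sketch of _bounds_get: track each neuron's min/max output."""
    lower, upper = None, None
    for i in range(0, len(train_data), batch_size):
        out = predict_fn(train_data[i:i + batch_size])  # (batch, neuron_num)
        batch_min, batch_max = out.min(axis=0), out.max(axis=0)
        lower = batch_min if lower is None else np.minimum(lower, batch_min)
        upper = batch_max if upper is None else np.maximum(upper, batch_max)
    return lower, upper
```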
def _bounds_get(self, train_dataset, batch_size=32):
......@@ -107,10 +113,10 @@ class ModelCoverageMetrics:
batch_output = self._model.predict(Tensor(dataset)).asnumpy()
batch_section_indexes = (batch_output - self._lower_bounds) // intervals
for section_indexes in batch_section_indexes:
for i in range(self._n):
for i in range(self._neuron_num):
if section_indexes[i] < 0:
self._lower_corner_hits[i] = 1
elif section_indexes[i] >= self._k:
elif section_indexes[i] >= self._segmented_num:
self._upper_corner_hits[i] = 1
else:
self._main_section_hits[i][int(section_indexes[i])] = 1
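For example, with a lower bound of 0.0, an upper bound of 1.0 and segmented_num = 4, the interval is 0.25, so an output of 0.6 lands in section (0.6 - 0.0) // 0.25 = 2; outputs below the lower bound, or at section index 4 and above, are counted as lower or upper corner hits instead.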
@@ -135,7 +141,7 @@
batch_size = check_int_positive('batch_size', batch_size)
self._lower_bounds -= bias_coefficient*self._var
self._upper_bounds += bias_coefficient*self._var
intervals = (self._upper_bounds - self._lower_bounds) / self._k
intervals = (self._upper_bounds - self._lower_bounds) / self._segmented_num
batches = dataset.shape[0] // batch_size
for i in range(batches):
self._sections_hits_count(
@@ -151,7 +157,7 @@
Examples:
>>> model_fuzz_test.get_kmnc()
"""
kmnc = np.sum(self._main_section_hits) / (self._n*self._k)
kmnc = np.sum(self._main_section_hits) / (self._neuron_num*self._segmented_num)
return kmnc
def get_nbc(self):
......@@ -165,7 +171,7 @@ class ModelCoverageMetrics:
>>> model_fuzz_test.get_nbc()
"""
nbc = (np.sum(self._lower_corner_hits) + np.sum(
self._upper_corner_hits)) / (2*self._n)
self._upper_corner_hits)) / (2*self._neuron_num)
return nbc
def get_snac(self):
@@ -178,5 +184,5 @@
Examples:
>>> model_fuzz_test.get_snac()
"""
snac = np.sum(self._upper_corner_hits) / self._n
snac = np.sum(self._upper_corner_hits) / self._neuron_num
return snac
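Putting the three metrics together on toy bookkeeping arrays (hypothetical values: 3 neurons, segmented_num = 4) shows how they relate:

```python
import numpy as np

main_hits = np.array([[1, 1, 0, 0],   # per-neuron section hits
                      [1, 0, 0, 0],
                      [1, 1, 1, 1]])
lower_hits = np.array([1, 0, 0])      # lower-corner hits
upper_hits = np.array([0, 0, 1])      # upper-corner hits
neuron_num, segmented_num = main_hits.shape

kmnc = np.sum(main_hits) / (neuron_num*segmented_num)             # 7/12
nbc = (np.sum(lower_hits) + np.sum(upper_hits)) / (2*neuron_num)  # 2/6
snac = np.sum(upper_hits) / neuron_num                            # 1/3
print(kmnc, nbc, snac)
```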