diff --git a/mindarmour/diff_privacy/mechanisms/mechanisms.py b/mindarmour/diff_privacy/mechanisms/mechanisms.py index e819355d8430540a430531e58c06b9a8c85c918c..02ab8a948cac77a53c8bc5afa60488927e6cc53c 100644 --- a/mindarmour/diff_privacy/mechanisms/mechanisms.py +++ b/mindarmour/diff_privacy/mechanisms/mechanisms.py @@ -86,7 +86,7 @@ class GaussianRandom(Mechanisms): >>> shape = (3, 2, 4) >>> norm_bound = 1.0 >>> initial_noise_multiplier = 1.5 - >>> net = GaussianRandom(shape, norm_bound, initial_noise_multiplier) + >>> net = GaussianRandom(norm_bound, initial_noise_multiplier) >>> res = net(shape) >>> print(res) """ diff --git a/mindarmour/diff_privacy/monitor/monitor.py b/mindarmour/diff_privacy/monitor/monitor.py index e92e2a2c8cb8bf60d77c1eb4390d21c5e284a3bc..eedf896a2d42066c496f0481fb15e7e0e8f6293c 100644 --- a/mindarmour/diff_privacy/monitor/monitor.py +++ b/mindarmour/diff_privacy/monitor/monitor.py @@ -90,6 +90,8 @@ class RDPMonitor(Callback): training. Default: 6e-4. per_print_times (int): The interval steps of computing and printing the privacy budget. Default: 50. + dataset_sink_mode (bool): If True, all training data would be passed to device(Ascend) at once. If False, + training data would be passed to device after each step training. Default: False. Examples: >>> rdp = PrivacyMonitorFactory.create(policy='rdp', @@ -104,7 +106,7 @@ class RDPMonitor(Callback): def __init__(self, num_samples, batch_size, initial_noise_multiplier=1.5, max_eps=10.0, target_delta=1e-3, max_delta=None, target_eps=None, orders=None, noise_decay_mode='Time', - noise_decay_rate=6e-4, per_print_times=50): + noise_decay_rate=6e-4, per_print_times=50, dataset_sink_mode=False): super(RDPMonitor, self).__init__() check_int_positive('num_samples', num_samples) check_int_positive('batch_size', batch_size) @@ -141,6 +143,7 @@ class RDPMonitor(Callback): noise_decay_rate = check_param_type('noise_decay_rate', noise_decay_rate, float) check_param_in_range('noise_decay_rate', noise_decay_rate, 0.0, 1.0) check_int_positive('per_print_times', per_print_times) + check_param_type('dataset_sink_mode', dataset_sink_mode, bool) self._total_echo_privacy = None self._num_samples = num_samples @@ -155,6 +158,8 @@ class RDPMonitor(Callback): self._noise_decay_rate = noise_decay_rate self._rdp = 0 self._per_print_times = per_print_times + if dataset_sink_mode: + self._per_print_times = int(self._num_samples / self._batch_size) def max_epoch_suggest(self): """