Commit 6e4ed1c3
Authored Jul 02, 2020 by mindspore-ci-bot; committed via Gitee, Jul 02, 2020

!43 Fix code specification issues for monitor.py and model_coverage_metrics.py.
Merge pull request !43 from jxlang910/master

Parents: e24cf5f6, 6e7475d6
Showing 2 changed files with 108 additions and 96 deletions (+108, -96):

- mindarmour/diff_privacy/monitor/monitor.py (+83, -77)
- mindarmour/fuzzing/model_coverage_metrics.py (+25, -19)

mindarmour/diff_privacy/monitor/monitor.py
```diff
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """ Monitor module of differential privacy training. """
-import math
 import numpy as np
 from scipy import special
```
```diff
@@ -40,8 +39,9 @@ class PrivacyMonitorFactory:
     Create a privacy monitor class.

     Args:
-        policy (str): Monitor policy, 'rdp' is supported by now. RDP means R'enyi differential privacy,
-            which computed based on R'enyi divergence.
+        policy (str): Monitor policy, 'rdp' is supported by now. RDP
+            means R'enyi differential privacy, which computed based
+            on R'enyi divergence.
         args (Union[int, float, numpy.ndarray, list, str]): Parameters
             used for creating a privacy monitor.
         kwargs (Union[int, float, numpy.ndarray, list, str]): Keyword
```
```diff
@@ -60,9 +60,14 @@
 class RDPMonitor(Callback):
-    """
+    r"""
     Compute the privacy budget of DP training based on Renyi differential
-    privacy theory.
+    privacy (RDP) theory. According to the reference below, if a randomized
+    mechanism is said to have ε'-Renyi differential privacy of order α, it
+    also satisfies conventional differential privacy (ε, δ) as below:
+
+    .. math::
+        (ε'+\frac{log(1/δ)}{α-1}, δ)

     Reference: `Rényi Differential Privacy of the Sampled Gaussian Mechanism
     <https://arxiv.org/abs/1908.10530>`_
```
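
The conversion stated in the new docstring can be checked numerically. A minimal sketch, assuming illustrative numbers rather than values produced by the monitor; `rdp_to_dp` is a hypothetical helper, not part of mindarmour's API:

```python
# An (α, ε')-RDP guarantee implies (ε' + log(1/δ)/(α - 1), δ)-DP.
import numpy as np

def rdp_to_dp(rdp_eps, alpha, delta):
    """Convert an order-`alpha` RDP bound to an (eps, delta)-DP pair."""
    return rdp_eps + np.log(1 / delta) / (alpha - 1), delta

print(rdp_to_dp(rdp_eps=0.5, alpha=20, delta=1e-3))  # (≈0.86, 0.001)
```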
```diff
@@ -70,33 +75,43 @@ class RDPMonitor(Callback):
     Args:
         num_samples (int): The total number of samples in training data sets.
         batch_size (int): The number of samples in a batch while training.
-        initial_noise_multiplier (Union[float, int]): The initial
-            multiplier of the noise added to training parameters' gradients. Default: 1.5.
+        initial_noise_multiplier (Union[float, int]): Ratio of the standard
+            deviation of Gaussian noise divided by the norm_bound, which will
+            be used to calculate privacy spent. Default: 1.5.
         max_eps (Union[float, int, None]): The maximum acceptable epsilon
-            budget for DP training. Default: 10.0.
+            budget for DP training, which is used for estimating the max
+            training epochs. Default: 10.0.
         target_delta (Union[float, int, None]): Target delta budget for DP
-            training. Default: 1e-3.
+            training. If target_delta is set to be δ, then the privacy budget
+            δ would be fixed during the whole training process. Default: 1e-3.
         max_delta (Union[float, int, None]): The maximum acceptable delta
-            budget for DP training. Max_delta must be less than 1 and
-            suggested to be less than 1e-3, otherwise overflow would be
-            encountered. Default: None.
+            budget for DP training, which is used for estimating the max
+            training epochs. Max_delta must be less than 1 and suggested
+            to be less than 1e-3, otherwise overflow would be encountered.
+            Default: None.
         target_eps (Union[float, int, None]): Target epsilon budget for DP
-            training. Default: None.
+            training. If target_eps is set to be ε, then the privacy budget
+            ε would be fixed during the whole training process. Default: None.
         orders (Union[None, list[int, float]]): Finite orders used for
-            computing rdp, which must be greater than 1.
+            computing rdp, which must be greater than 1. The computation result
+            of privacy budget would be different for various orders. In order
+            to obtain a tighter (smaller) privacy budget estimation, a list
+            of orders could be tried. Default: None.
         noise_decay_mode (str): Decay mode of adding noise while training,
             which can be 'no_decay', 'Time' or 'Step'. Default: 'Time'.
         noise_decay_rate (Union[float, None]): Decay rate of noise while
             training. Default: 6e-4.
         per_print_times (int): The interval steps of computing and printing
             the privacy budget. Default: 50.
-        dataset_sink_mode (bool): If True, all training data would be passed to device(Ascend) at once. If False,
-            training data would be passed to device after each step training. Default: False.
+        dataset_sink_mode (bool): If True, all training data would be passed
+            to device(Ascend) at once. If False, training data would be passed
+            to device after each step training. Default: False.

     Examples:
         >>> rdp = PrivacyMonitorFactory.create(policy='rdp',
         >>>     num_samples=60000, batch_size=256)
         >>> network = Net()
         >>> epochs = 2
         >>> net_loss = nn.SoftmaxCrossEntropyWithLogits()
         >>> net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9)
         >>> model = Model(network, net_loss, net_opt)
```
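
The docstring example stops after building the Model. A hedged continuation of how the monitor is typically attached; `ds_train` is a placeholder dataset, and `max_epoch_suggest()` is the epoch-estimation method whose loop is patched further down in this diff:

```python
# RDPMonitor subclasses Callback, so it is passed to Model.train like any
# other MindSpore callback (sketch; names outside the docstring example
# are assumptions).
suggest_epoch = rdp.max_epoch_suggest()  # epochs that stay within max_eps
model.train(suggest_epoch, ds_train, callbacks=[rdp], dataset_sink_mode=False)
```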
```diff
@@ -158,6 +173,15 @@ class RDPMonitor(Callback):
         self._noise_decay_rate = noise_decay_rate
         self._rdp = 0
         self._per_print_times = per_print_times
+        if self._target_eps is None and self._target_delta is None:
+            msg = 'target eps and target delta cannot both be None'
+            LOGGER.error(TAG, msg)
+            raise ValueError(msg)
+        if self._target_eps is not None and self._target_delta is not None:
+            msg = 'One of target eps and target delta must be None'
+            LOGGER.error(TAG, msg)
+            raise ValueError(msg)
+        if dataset_sink_mode:
+            self._per_print_times = int(self._num_samples / self._batch_size)
```
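
These checks move the eps/delta validation out of `_compute_privacy_steps` (removed below) and into `__init__`, so a misconfigured monitor now fails at construction. A hedged illustration of the rule, using the factory call from the docstring example; keyword pass-through follows the Args list above:

```python
# Exactly one of target_eps / target_delta may be set. Note target_delta
# defaults to 1e-3, so it must be cleared explicitly to fix epsilon instead.
rdp_fix_delta = PrivacyMonitorFactory.create('rdp', num_samples=60000,
                                             batch_size=256, target_delta=1e-3)
rdp_fix_eps = PrivacyMonitorFactory.create('rdp', num_samples=60000,
                                           batch_size=256, target_delta=None,
                                           target_eps=10.0)
```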
```diff
@@ -178,7 +202,7 @@ class RDPMonitor(Callback):
         while epoch < 10000:
             steps = self._num_samples // self._batch_size
             eps, delta = self._compute_privacy_steps(
-                list(np.arange((epoch - 1) * steps, epoch * steps + 1)))
+                list(np.arange((epoch - 1) * steps, epoch * steps + 1)))
             if self._max_eps is not None:
                 if eps <= self._max_eps:
                     epoch += 1
```

(The `-`/`+` pair in this hunk differs only in whitespace alignment of the wrapped line.)
```diff
@@ -189,6 +213,7 @@ class RDPMonitor(Callback):
                 epoch += 1
             else:
                 break
+        # reset the rdp for model training
        self._rdp = 0
        return epoch
```
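
The surrounding method searches for the largest epoch count whose accumulated privacy spend stays within budget. A minimal standalone sketch of that search, assuming a caller-supplied per-epoch cost function (the names are illustrative):

```python
# Shape of the epoch-suggestion loop above: step through epochs,
# re-computing the cumulative epsilon, until the budget is exceeded or the
# hard cap of 10000 epochs is reached.
def suggest_max_epochs(eps_after_epoch, max_eps, cap=10000):
    epoch = 1
    while epoch < cap and eps_after_epoch(epoch) <= max_eps:
        epoch += 1
    return epoch

print(suggest_max_epochs(lambda e: 0.02 * e, max_eps=10.0))  # 501
```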
```diff
@@ -233,25 +258,15 @@ class RDPMonitor(Callback):
         Returns:
             float, privacy budget.
         """
-        if self._target_eps is None and self._target_delta is None:
-            msg = 'target eps and target delta cannot both be None'
-            LOGGER.error(TAG, msg)
-            raise ValueError(msg)
-        if self._target_eps is not None and self._target_delta is not None:
-            msg = 'One of target eps and target delta must be None'
-            LOGGER.error(TAG, msg)
-            raise ValueError(msg)
         if self._orders is None:
             self._orders = (
                 [1.005, 1.01, 1.02, 1.08, 1.2, 2, 5, 10, 20, 40, 80])
         sampling_rate = self._batch_size / self._num_samples
-        noise_step = self._initial_noise_multiplier
+        noise_stddev_step = self._initial_noise_multiplier

         if self._noise_decay_mode == 'no_decay':
-            self._rdp += self._compute_rdp(sampling_rate, noise_step) * len(
+            self._rdp += self._compute_rdp(sampling_rate, noise_stddev_step) * len(
                 steps)
         else:
             if self._noise_decay_rate is None:
```
```diff
@@ -260,33 +275,33 @@ class RDPMonitor(Callback):
                 raise ValueError(msg)
             if self._noise_decay_mode == 'Time':
-                noise_step = [self._initial_noise_multiplier / (1 + self._noise_decay_rate * step)
-                              for step in steps]
+                noise_stddev_step = [self._initial_noise_multiplier / (1 + self._noise_decay_rate * step)
+                                     for step in steps]
             elif self._noise_decay_mode == 'Step':
-                noise_step = [self._initial_noise_multiplier * (1 - self._noise_decay_rate) ** step
-                              for step in steps]
+                noise_stddev_step = [self._initial_noise_multiplier * (1 - self._noise_decay_rate) ** step
+                                     for step in steps]
-            self._rdp += sum([self._compute_rdp(sampling_rate, noise)
-                              for noise in noise_step])
+            self._rdp += sum([self._compute_rdp(sampling_rate, noise)
+                              for noise in noise_stddev_step])
         eps, delta = self._compute_privacy_budget(self._rdp)
         return eps, delta

-    def _compute_rdp(self, q, noise):
+    def _compute_rdp(self, sample_rate, noise_stddev):
         """
         Compute rdp according to sampling rate, added noise and Renyi
         divergence orders.

         Args:
-            q (float): Sampling rate of each batch of samples.
-            noise (float): Noise multiplier.
+            sample_rate (float): Sampling rate of each batch of samples.
+            noise_stddev (float): Noise multiplier.

         Returns:
             float or numpy.ndarray, rdp values.
         """
         rdp = np.array(
-            [_compute_rdp_order(q, noise, order) for order in self._orders])
+            [_compute_rdp_with_order(sample_rate, noise_stddev, order)
+             for order in self._orders])
         return rdp

     def _compute_privacy_budget(self, rdp):
```
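
The two decay schedules above can be evaluated directly. A small sketch using the documented defaults for `initial_noise_multiplier` and `noise_decay_rate`; the step values are illustrative:

```python
# 'Time' decay:  sigma_t = sigma_0 / (1 + rate * t)
# 'Step' decay:  sigma_t = sigma_0 * (1 - rate) ** t
import numpy as np

sigma_0, rate = 1.5, 6e-4
steps = np.arange(5)
time_decay = sigma_0 / (1 + rate * steps)
step_decay = sigma_0 * (1 - rate) ** steps
print(time_decay[0], step_decay[0])  # both schedules start at 1.5
```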
```diff
@@ -317,14 +332,9 @@ class RDPMonitor(Callback):
         """
         orders = np.atleast_1d(self._orders)
         rdps = np.atleast_1d(rdp)
         if len(orders) != len(rdps):
             msg = 'rdp lists and orders list must have the same length.'
             LOGGER.error(TAG, msg)
             raise ValueError(msg)
-        deltas = np.exp((rdps - self._target_eps) * (orders - 1))
-        min_delta = min(deltas)
-        return min(min_delta, 1.)
+        deltas = np.exp((rdps - self._target_eps) * (orders - 1))
+        min_delta = np.min(deltas)
+        return np.min([min_delta, 1.])

     def _compute_eps(self, rdp):
         """
```
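
With a fixed target ε, δ is the smallest value of exp((rdp − ε)(α − 1)) over the tried orders, capped at 1. A worked numpy check; the rdp values are invented for illustration, not produced by the monitor:

```python
# Standalone version of the delta computation above.
import numpy as np

orders = np.array([1.5, 2, 5, 10])
rdps = np.array([0.2, 0.25, 0.6, 1.2])   # illustrative per-order RDP values
target_eps = 2.0
deltas = np.exp((rdps - target_eps) * (orders - 1))
print(np.min([np.min(deltas), 1.]))      # ≈ 7.5e-4, attained at order 10
```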
```diff
@@ -338,50 +348,46 @@ class RDPMonitor(Callback):
         """
         orders = np.atleast_1d(self._orders)
         rdps = np.atleast_1d(rdp)
         if len(orders) != len(rdps):
             msg = 'rdp lists and orders list must have the same length.'
             LOGGER.error(TAG, msg)
             raise ValueError(msg)
-        eps = rdps - math.log(self._target_delta) / (orders - 1)
-        return min(eps)
+        eps = rdps - np.log(self._target_delta) / (orders - 1)
+        return np.min(eps)


-def _compute_rdp_order(q, sigma, alpha):
+def _compute_rdp_with_order(sample_rate, noise_stddev, order):
     """
     Compute rdp for each order.

     Args:
-        q (float): Sampling probability.
-        sigma (float): Noise multiplier.
-        alpha: The order used for computing rdp.
+        sample_rate (float): Sampling probability.
+        noise_stddev (float): Noise multiplier.
+        order: The order used for computing rdp.

     Returns:
         float, rdp value.
     """
-    if float(alpha).is_integer():
+    if float(order).is_integer():
         log_integrate = -np.inf
-        for k in range(alpha + 1):
-            term_k = (math.log(special.binom(alpha, k)) + k * math.log(q) +
-                      (alpha - k) * math.log(1 - q)) + \
-                     (k * k - k) / (2 * (sigma ** 2))
+        for k in range(order + 1):
+            term_k = (np.log(special.binom(order, k)) + k * np.log(sample_rate) +
+                      (order - k) * np.log(1 - sample_rate)) + \
+                     (k * k - k) / (2 * (noise_stddev ** 2))
             log_integrate = _log_add(log_integrate, term_k)
-        return float(log_integrate) / (alpha - 1)
+        return float(log_integrate) / (order - 1)
     log_part_0, log_part_1 = -np.inf, -np.inf
     k = 0
-    z0 = sigma ** 2 * math.log(1 / q - 1) + 1 / 2
+    z0 = noise_stddev ** 2 * np.log(1 / sample_rate - 1) + 1 / 2
     while True:
-        bi_coef = special.binom(alpha, k)
-        log_coef = math.log(abs(bi_coef))
-        j = alpha - k
+        bi_coef = special.binom(order, k)
+        log_coef = np.log(abs(bi_coef))
+        j = order - k

-        term_k_part_0 = log_coef + k * math.log(q) + j * math.log(1 - q) + \
-                        (k * k - k) / (2 * (sigma ** 2)) + \
-                        special.log_ndtr((z0 - k) / sigma)
-        term_k_part_1 = log_coef + j * math.log(q) + k * math.log(1 - q) + \
-                        (j * j - j) / (2 * (sigma ** 2)) + \
-                        special.log_ndtr((j - z0) / sigma)
+        term_k_part_0 = log_coef + k * np.log(sample_rate) + j * np.log(1 - sample_rate) + \
+                        (k * k - k) / (2 * (noise_stddev ** 2)) + \
+                        special.log_ndtr((z0 - k) / noise_stddev)
+        term_k_part_1 = log_coef + j * np.log(sample_rate) + k * np.log(1 - sample_rate) + \
+                        (j * j - j) / (2 * (noise_stddev ** 2)) + \
+                        special.log_ndtr((j - z0) / noise_stddev)

         if bi_coef > 0:
             log_part_0 = _log_add(log_part_0, term_k_part_0)
```
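
For integer orders, the log-space sum above is a numerically stable form of the sampled-Gaussian RDP expression. A hedged direct-space check for a tiny case, naive on purpose; it would overflow for large orders, which is why the production code works in log space:

```python
import numpy as np
from scipy import special

# Direct-space version of the integer-order branch: sum over k of
# C(order, k) * q^k * (1-q)^(order-k) * exp((k^2 - k) / (2 * sigma^2)),
# then take log and divide by (order - 1). Values are illustrative.
q, sigma, order = 256 / 60000, 1.5, 2
integrate = sum(special.binom(order, k) * q**k * (1 - q)**(order - k)
                * np.exp((k * k - k) / (2 * sigma**2))
                for k in range(order + 1))
print(np.log(integrate) / (order - 1))  # matches _compute_rdp_with_order
```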
```diff
@@ -391,10 +397,10 @@ def _compute_rdp_order(q, sigma, alpha):
             log_part_1 = _log_subtract(log_part_1, term_k_part_1)

         k += 1
-        if max(term_k_part_0, term_k_part_1) < -30:
+        if np.max([term_k_part_0, term_k_part_1]) < -30:
             break

-    return _log_add(log_part_0, log_part_1) / (alpha - 1)
+    return _log_add(log_part_0, log_part_1) / (order - 1)


 def _log_add(x, y):
```
```diff
@@ -405,7 +411,7 @@ def _log_add(x, y):
         return y
     if y == -np.inf:
         return x
-    return max(x, y) + math.log1p(math.exp(-abs(x - y)))
+    return np.max([x, y]) + np.log1p(np.exp(-abs(x - y)))


 def _log_subtract(x, y):
```
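
`_log_add` is a two-term log-sum-exp, i.e. log(e^x + e^y) computed without underflow. A quick standalone check against the naive formula; `log_add` here simply mirrors the patched body:

```python
import numpy as np

def log_add(x, y):
    # Stable log(e^x + e^y): factor out the larger exponent so the
    # remaining exp() argument is always <= 0.
    if x == -np.inf:
        return y
    if y == -np.inf:
        return x
    return np.max([x, y]) + np.log1p(np.exp(-abs(x - y)))

assert np.isclose(log_add(-2.0, -3.0), np.log(np.exp(-2.0) + np.exp(-3.0)))
```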
```diff
@@ -418,4 +424,4 @@ def _log_subtract(x, y):
         raise ValueError(msg)
     if y == -np.inf:
         return x
-    return math.log1p(math.exp(y - x)) + x
+    return np.log1p(np.exp(y - x)) + x
```

mindarmour/fuzzing/model_coverage_metrics.py
```diff
@@ -26,15 +26,21 @@ from mindarmour.utils._check_param import check_model, check_numpy_param, \
 class ModelCoverageMetrics:
     """
-    Evaluate the testing adequacy of a model fuzz test.
+    As we all known, each neuron output of a network will have a output range
+    after training (we call it original range), and test dataset is used to
+    estimate the accuracy of the trained network. However, neurons' output
+    distribution would be different with different test datasets. Therefore,
+    similar to function fuzz, model fuzz means testing those neurons' outputs
+    and estimating the proportion of original range that has emerged with test
+    datasets.

     Reference: `DeepGauge: Multi-Granularity Testing Criteria for Deep
     Learning Systems <https://arxiv.org/abs/1803.07519>`_

     Args:
         model (Model): The pre-trained model which waiting for testing.
-        k (int): The number of segmented sections of neurons' output intervals.
-        n (int): The number of testing neurons.
+        segmented_num (int): The number of segmented sections of neurons' output intervals.
+        neuron_num (int): The number of testing neurons.
         train_dataset (numpy.ndarray): Training dataset used for determine
             the neurons' output boundaries.
```
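
With the renamed constructor arguments, building the metric object reads as below. A hedged sketch; `net` is a placeholder trained network, the data is random filler, and the import paths follow this repository's layout:

```python
import numpy as np
from mindspore.train.model import Model
from mindarmour.fuzzing.model_coverage_metrics import ModelCoverageMetrics

model = Model(net)  # `net`: a trained network, as in the docstring example
train_images = np.random.rand(1000, 1, 32, 32).astype(np.float32)  # placeholder
coverage = ModelCoverageMetrics(model, segmented_num=1000, neuron_num=10,
                                train_dataset=train_images)
```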
```diff
@@ -49,18 +55,18 @@ class ModelCoverageMetrics:
         >>> print('SNAC of this test is : %s', model_fuzz_test.get_snac())
     """

-    def __init__(self, model, k, n, train_dataset):
+    def __init__(self, model, segmented_num, neuron_num, train_dataset):
         self._model = check_model('model', model, Model)
-        self._k = k
-        self._n = n
+        self._segmented_num = check_int_positive('segmented_num', segmented_num)
+        self._neuron_num = check_int_positive('neuron_num', neuron_num)
         train_dataset = check_numpy_param('train_dataset', train_dataset)
-        self._lower_bounds = [np.inf] * n
-        self._upper_bounds = [-np.inf] * n
-        self._var = [0] * n
-        self._main_section_hits = [[0 for _ in range(self._k)]
-                                   for _ in range(self._n)]
-        self._lower_corner_hits = [0] * self._n
-        self._upper_corner_hits = [0] * self._n
+        self._lower_bounds = [np.inf] * neuron_num
+        self._upper_bounds = [-np.inf] * neuron_num
+        self._var = [0] * neuron_num
+        self._main_section_hits = [[0 for _ in range(self._segmented_num)]
+                                   for _ in range(self._neuron_num)]
+        self._lower_corner_hits = [0] * self._neuron_num
+        self._upper_corner_hits = [0] * self._neuron_num
         self._bounds_get(train_dataset)

     def _bounds_get(self, train_dataset, batch_size=32):
```
```diff
@@ -107,10 +113,10 @@ class ModelCoverageMetrics:
         batch_output = self._model.predict(Tensor(dataset)).asnumpy()
         batch_section_indexes = (batch_output - self._lower_bounds) // intervals
         for section_indexes in batch_section_indexes:
-            for i in range(self._n):
+            for i in range(self._neuron_num):
                 if section_indexes[i] < 0:
                     self._lower_corner_hits[i] = 1
-                elif section_indexes[i] >= self._k:
+                elif section_indexes[i] >= self._segmented_num:
                     self._upper_corner_hits[i] = 1
                 else:
                     self._main_section_hits[i][int(section_indexes[i])] = 1
```
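
The bucketing rule in isolation: outputs below the lower bound hit the lower corner, outputs at or beyond `segmented_num` sections hit the upper corner, and everything else marks a main section. A minimal sketch with made-up bounds:

```python
import numpy as np

lower, upper, segmented_num = 0.0, 1.0, 4
intervals = (upper - lower) / segmented_num
outputs = np.array([-0.1, 0.3, 0.99, 1.2])       # illustrative neuron outputs
section_indexes = (outputs - lower) // intervals
print(section_indexes)  # [-1. 1. 3. 4.] -> lower corner, sections 1 and 3, upper corner
```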
```diff
@@ -135,7 +141,7 @@ class ModelCoverageMetrics:
         batch_size = check_int_positive('batch_size', batch_size)
         self._lower_bounds -= bias_coefficient * self._var
         self._upper_bounds += bias_coefficient * self._var
-        intervals = (self._upper_bounds - self._lower_bounds) / self._k
+        intervals = (self._upper_bounds - self._lower_bounds) / self._segmented_num
         batches = dataset.shape[0] // batch_size
         for i in range(batches):
             self._sections_hits_count(
```
```diff
@@ -151,7 +157,7 @@ class ModelCoverageMetrics:
         Examples:
             >>> model_fuzz_test.get_kmnc()
         """
-        kmnc = np.sum(self._main_section_hits) / (self._n * self._k)
+        kmnc = np.sum(self._main_section_hits) / (self._neuron_num * self._segmented_num)
         return kmnc

     def get_nbc(self):
```
```diff
@@ -165,7 +171,7 @@ class ModelCoverageMetrics:
             >>> model_fuzz_test.get_nbc()
         """
         nbc = (np.sum(self._lower_corner_hits) +
-               np.sum(self._upper_corner_hits)) / (2 * self._n)
+               np.sum(self._upper_corner_hits)) / (2 * self._neuron_num)
         return nbc

     def get_snac(self):
```
```diff
@@ -178,5 +184,5 @@ class ModelCoverageMetrics:
         Examples:
             >>> model_fuzz_test.get_snac()
         """
-        snac = np.sum(self._upper_corner_hits) / self._n
+        snac = np.sum(self._upper_corner_hits) / self._neuron_num
         return snac
```
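
The three coverage ratios touched by this file reduce to simple sums over the hit arrays. A toy computation with hand-filled hits, matching the renamed formulas above:

```python
import numpy as np

neuron_num, segmented_num = 3, 4
main_section_hits = np.array([[1, 1, 0, 0],
                              [0, 1, 1, 1],
                              [0, 0, 0, 1]])
lower_corner_hits = np.array([1, 0, 0])
upper_corner_hits = np.array([0, 1, 1])

kmnc = np.sum(main_section_hits) / (neuron_num * segmented_num)  # 6/12 = 0.5
nbc = (np.sum(lower_corner_hits) +
       np.sum(upper_corner_hits)) / (2 * neuron_num)             # 3/6 = 0.5
snac = np.sum(upper_corner_hits) / neuron_num                    # 2/3
```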