Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MindSpore
mindarmour
提交
ad008705
M
mindarmour
项目概览
MindSpore
/
mindarmour
通知
4
Star
2
Fork
3
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindarmour
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
ad008705
编写于
7月 08, 2020
作者:
J
jin-xiulang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add ZCDPMonitor and Exponential decay mode.
上级
ac39d193
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
319 addition
and
30 deletion
+319
-30
mindarmour/diff_privacy/monitor/monitor.py
mindarmour/diff_privacy/monitor/monitor.py
+225
-27
tests/ut/python/diff_privacy/test_monitor.py
tests/ut/python/diff_privacy/test_monitor.py
+94
-3
未找到文件。
mindarmour/diff_privacy/monitor/monitor.py
浏览文件 @
ad008705
...
...
@@ -39,9 +39,8 @@ class PrivacyMonitorFactory:
Create a privacy monitor class.
Args:
policy (str): Monitor policy, 'rdp' is supported by now. RDP
means R'enyi differential privacy, which computed based
on R'enyi divergence.
policy (str): Monitor policy, 'rdp' and 'zcdp' are supported
by now.
args (Union[int, float, numpy.ndarray, list, str]): Parameters
used for creating a privacy monitor.
kwargs (Union[int, float, numpy.ndarray, list, str]): Keyword
...
...
@@ -56,7 +55,9 @@ class PrivacyMonitorFactory:
"""
if
policy
==
'rdp'
:
return
RDPMonitor
(
*
args
,
**
kwargs
)
raise
ValueError
(
"Only RDP-policy is supported by now"
)
if
policy
==
'zcdp'
:
return
ZCDPMonitor
(
*
args
,
**
kwargs
)
raise
ValueError
(
"Only RDP-policy or ZCDP-policy is supported by now"
)
class
RDPMonitor
(
Callback
):
...
...
@@ -97,24 +98,28 @@ class RDPMonitor(Callback):
of privacy budget would be different for various orders. In order
to obtain a tighter (smaller) privacy budget estimation, a list
of orders could be tried. Default: None.
noise_decay_mode (str): Decay mode of adding noise while training,
which can be 'no_decay', 'Time' or 'Step'. Default: 'Time'.
noise_decay_rate (Union[float, None]): Decay rate of noise while
training. Default: 6e-4.
noise_decay_mode (Union[None, str]): Decay mode of adding noise while
training, which can be None, 'Time', 'Step' or 'Exp'. Default: 'Time'.
noise_decay_rate (float): Decay rate of noise while training. Default: 6e-4.
per_print_times (int): The interval steps of computing and printing
the privacy budget. Default: 50.
dataset_sink_mode (bool): If True, all training data would be passed
to device(Ascend)
at onc
e. If False, training data would be passed
to device(Ascend)
one-tim
e. If False, training data would be passed
to device after each step training. Default: False.
Examples:
>>> rdp = PrivacyMonitorFactory.create(policy='rdp',
>>> num_samples=60000, batch_size=256)
>>> network = Net()
>>> epochs = 2
>>> net_loss = nn.SoftmaxCrossEntropyWithLogits()
>>> epochs = 2
>>> norm_clip = 1.0
>>> initial_noise_multiplier = 0.01
>>> mech = MechanismsFactory().create('Gaussian',
>>> norm_bound=norm_clip, initial_noise_multiplier=initial_noise_multiplier)
>>> net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9)
>>> model = Model(network, net_loss, net_opt)
>>> model = DPModel(micro_batches=2, norm_clip=norm_clip,
>>> mech=mech, network=network, loss_fn=loss, optimizer=net_opt, metrics=None)
>>> rdp = PrivacyMonitorFactory.create(policy='rdp',
>>> num_samples=60000, batch_size=256)
>>> model.train(epochs, ds, callbacks=[rdp], dataset_sink_mode=False)
"""
...
...
@@ -150,17 +155,16 @@ class RDPMonitor(Callback):
msg
=
'orders must be greater than 1'
LOGGER
.
error
(
TAG
,
msg
)
raise
ValueError
(
msg
)
if
noise_decay_mode
not
in
(
'no_decay'
,
'Step'
,
'Time'
):
msg
=
"Noise decay mode must be in ('no_decay', 'Step', 'Time')"
if
noise_decay_mode
is
not
None
:
if
noise_decay_mode
not
in
(
'Step'
,
'Time'
,
'Exp'
):
msg
=
"Noise decay mode must be in ('Step', 'Time', 'Exp')"
LOGGER
.
error
(
TAG
,
msg
)
raise
ValueError
(
msg
)
if
noise_decay_rate
is
not
None
:
noise_decay_rate
=
check_param_type
(
'noise_decay_rate'
,
noise_decay_rate
,
float
)
check_param_in_range
(
'noise_decay_rate'
,
noise_decay_rate
,
0.0
,
1.0
)
check_int_positive
(
'per_print_times'
,
per_print_times
)
check_param_type
(
'dataset_sink_mode'
,
dataset_sink_mode
,
bool
)
self
.
_total_echo_privacy
=
None
self
.
_num_samples
=
num_samples
self
.
_batch_size
=
batch_size
self
.
_initial_noise_multiplier
=
initial_noise_multiplier
...
...
@@ -232,8 +236,7 @@ class RDPMonitor(Callback):
if
cb_params
.
cur_step_num
%
self
.
_per_print_times
==
0
:
steps
=
np
.
arange
(
cur_step
-
self
.
_per_print_times
,
cur_step
+
1
)
eps
,
delta
=
self
.
_compute_privacy_steps
(
list
(
steps
))
if
np
.
isnan
(
eps
)
or
np
.
isinf
(
eps
)
or
np
.
isnan
(
delta
)
or
np
.
isinf
(
delta
):
if
np
.
isnan
(
eps
)
or
np
.
isinf
(
eps
):
msg
=
'epoch: {} step: {}, invalid eps, terminating '
\
'training.'
.
format
(
cb_params
.
cur_epoch_num
,
cur_step_in_epoch
)
...
...
@@ -265,15 +268,10 @@ class RDPMonitor(Callback):
sampling_rate
=
self
.
_batch_size
/
self
.
_num_samples
noise_stddev_step
=
self
.
_initial_noise_multiplier
if
self
.
_noise_decay_mode
==
'no_decay'
:
if
self
.
_noise_decay_mode
is
None
:
self
.
_rdp
+=
self
.
_compute_rdp
(
sampling_rate
,
noise_stddev_step
)
*
len
(
steps
)
else
:
if
self
.
_noise_decay_rate
is
None
:
msg
=
'noise_decay_rate in decay-mode cannot be None'
LOGGER
.
error
(
TAG
,
msg
)
raise
ValueError
(
msg
)
if
self
.
_noise_decay_mode
==
'Time'
:
noise_stddev_step
=
[
self
.
_initial_noise_multiplier
/
(
1
+
self
.
_noise_decay_rate
*
step
)
for
step
in
steps
]
...
...
@@ -281,6 +279,9 @@ class RDPMonitor(Callback):
elif
self
.
_noise_decay_mode
==
'Step'
:
noise_stddev_step
=
[
self
.
_initial_noise_multiplier
*
(
1
-
self
.
_noise_decay_rate
)
**
step
for
step
in
steps
]
elif
self
.
_noise_decay_mode
==
'Exp'
:
noise_stddev_step
=
[
self
.
_initial_noise_multiplier
*
np
.
exp
(
-
step
*
self
.
_noise_decay_rate
)
for
step
in
steps
]
self
.
_rdp
+=
sum
(
[
self
.
_compute_rdp
(
sampling_rate
,
noise
)
for
noise
in
noise_stddev_step
])
...
...
@@ -352,6 +353,203 @@ class RDPMonitor(Callback):
return
np
.
min
(
eps
)
class
ZCDPMonitor
(
Callback
):
r
"""
Compute the privacy budget of DP training based on zero-concentrated
differential privacy theory (zcdp). According to the reference below,
if a randomized mechanism is said to have ρ-zCDP, it also satisfies
conventional differential privacy (ε, δ) as below:
.. math::
(ρ+2\sqrt{ρlog(1/δ)}, δ)
Reference: `Concentrated Differentially Private Gradient Descent with
Adaptive per-Iteration Privacy Budget <https://arxiv.org/abs/1808.09501>`_
Args:
num_samples (int): The total number of samples in training data sets.
batch_size (int): The number of samples in a batch while training.
initial_noise_multiplier (Union[float, int]): Ratio of the standard
deviation of Gaussian noise divided by the norm_bound, which will
be used to calculate privacy spent. Default: 1.5.
max_eps (Union[float, int]): The maximum acceptable epsilon budget for
DP training, which is used for estimating the max training epochs.
Default: 10.0.
target_delta (Union[float, int]): Target delta budget for DP training.
If target_delta is set to be δ, then the privacy budget δ would be
fixed during the whole training process. Default: 1e-3.
noise_decay_mode (Union[None, str]): Decay mode of adding noise while
training, which can be None, 'Time', 'Step' or 'Exp'. Default: 'Time'.
noise_decay_rate (float): Decay rate of noise while training. Default: 6e-4.
per_print_times (int): The interval steps of computing and printing
the privacy budget. Default: 50.
dataset_sink_mode (bool): If True, all training data would be passed
to device(Ascend) one-time. If False, training data would be passed
to device after each step training. Default: False.
Examples:
>>> network = Net()
>>> net_loss = nn.SoftmaxCrossEntropyWithLogits()
>>> epochs = 2
>>> norm_clip = 1.0
>>> initial_noise_multiplier = 0.01
>>> mech = MechanismsFactory().create('Gaussian',
>>> norm_bound=norm_clip, initial_noise_multiplier=initial_noise_multiplier)
>>> net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9)
>>> model = DPModel(micro_batches=2, norm_clip=norm_clip,
>>> mech=mech, network=network, loss_fn=loss, optimizer=net_opt, metrics=None)
>>> rdp = PrivacyMonitorFactory.create(policy='rdp',
>>> num_samples=60000, batch_size=256)
>>> model.train(epochs, ds, callbacks=[rdp], dataset_sink_mode=False)
"""
def
__init__
(
self
,
num_samples
,
batch_size
,
initial_noise_multiplier
=
1.5
,
max_eps
=
10.0
,
target_delta
=
1e-3
,
noise_decay_mode
=
'Time'
,
noise_decay_rate
=
6e-4
,
per_print_times
=
50
,
dataset_sink_mode
=
False
):
super
(
ZCDPMonitor
,
self
).
__init__
()
check_int_positive
(
'num_samples'
,
num_samples
)
check_int_positive
(
'batch_size'
,
batch_size
)
if
batch_size
>=
num_samples
:
msg
=
'Batch_size must be less than num_samples.'
LOGGER
.
error
(
TAG
,
msg
)
raise
ValueError
(
msg
)
check_value_positive
(
'initial_noise_multiplier'
,
initial_noise_multiplier
)
if
noise_decay_mode
is
not
None
:
if
noise_decay_mode
not
in
(
'Step'
,
'Time'
,
'Exp'
):
msg
=
"Noise decay mode must be in ('Step', 'Time', 'Exp')"
LOGGER
.
error
(
TAG
,
msg
)
raise
ValueError
(
msg
)
noise_decay_rate
=
check_param_type
(
'noise_decay_rate'
,
noise_decay_rate
,
float
)
check_param_in_range
(
'noise_decay_rate'
,
noise_decay_rate
,
0.0
,
1.0
)
check_int_positive
(
'per_print_times'
,
per_print_times
)
check_param_type
(
'dataset_sink_mode'
,
dataset_sink_mode
,
bool
)
self
.
_num_samples
=
num_samples
self
.
_batch_size
=
batch_size
self
.
_initial_noise_multiplier
=
initial_noise_multiplier
self
.
_max_eps
=
check_value_positive
(
'max_eps'
,
max_eps
)
self
.
_target_delta
=
check_param_in_range
(
'target_delta'
,
target_delta
,
0.0
,
1.0
)
self
.
_noise_decay_mode
=
noise_decay_mode
self
.
_noise_decay_rate
=
noise_decay_rate
# initialize zcdp
self
.
_zcdp
=
0
self
.
_per_print_times
=
per_print_times
if
dataset_sink_mode
:
self
.
_per_print_times
=
int
(
self
.
_num_samples
/
self
.
_batch_size
)
def
max_epoch_suggest
(
self
):
"""
Estimate the maximum training epochs to satisfy the predefined
privacy budget.
Returns:
int, the recommended maximum training epochs.
Examples:
>>> zcdp = PrivacyMonitorFactory.create(policy='zcdp',
>>> num_samples=60000, batch_size=32)
>>> suggest_epoch = zcdp.max_epoch_suggest()
"""
epoch
=
1
while
epoch
<
10000
:
steps
=
self
.
_num_samples
//
self
.
_batch_size
eps
,
_
=
self
.
_compute_privacy_steps
(
list
(
np
.
arange
((
epoch
-
1
)
*
steps
,
epoch
*
steps
+
1
)))
if
eps
<=
self
.
_max_eps
:
epoch
+=
1
else
:
break
# initialize the zcdp for model training
self
.
_zcdp
=
0
return
epoch
def
step_end
(
self
,
run_context
):
"""
Compute privacy budget after each training step.
Args:
run_context (RunContext): Include some information of the model.
"""
cb_params
=
run_context
.
original_args
()
cur_step
=
cb_params
.
cur_step_num
cur_step_in_epoch
=
(
cb_params
.
cur_step_num
-
1
)
%
\
cb_params
.
batch_num
+
1
if
cb_params
.
cur_step_num
%
self
.
_per_print_times
==
0
:
steps
=
np
.
arange
(
cur_step
-
self
.
_per_print_times
,
cur_step
+
1
)
eps
,
delta
=
self
.
_compute_privacy_steps
(
list
(
steps
))
if
np
.
isnan
(
eps
)
or
np
.
isinf
(
eps
)
or
np
.
isnan
(
delta
)
or
np
.
isinf
(
delta
):
msg
=
'epoch: {} step: {}, invalid eps, terminating '
\
'training.'
.
format
(
cb_params
.
cur_epoch_num
,
cur_step_in_epoch
)
LOGGER
.
error
(
TAG
,
msg
)
raise
ValueError
(
msg
)
print
(
"epoch: %s step: %s, delta is %s, eps is %s"
%
(
cb_params
.
cur_epoch_num
,
cur_step_in_epoch
,
delta
,
eps
))
def
_compute_privacy_steps
(
self
,
steps
):
"""
Compute privacy budget corresponding to steps.
Args:
steps (list): Training steps.
Returns:
float, privacy budget.
"""
noise_stddev_step
=
self
.
_initial_noise_multiplier
if
self
.
_noise_decay_mode
is
None
:
self
.
_zcdp
+=
self
.
_compute_zcdp
(
noise_stddev_step
)
*
len
(
steps
)
else
:
if
self
.
_noise_decay_mode
==
'Time'
:
noise_stddev_step
=
[
self
.
_initial_noise_multiplier
/
(
1
+
self
.
_noise_decay_rate
*
step
)
for
step
in
steps
]
elif
self
.
_noise_decay_mode
==
'Step'
:
noise_stddev_step
=
[
self
.
_initial_noise_multiplier
*
(
1
-
self
.
_noise_decay_rate
)
**
step
for
step
in
steps
]
elif
self
.
_noise_decay_mode
==
'Exp'
:
noise_stddev_step
=
[
self
.
_initial_noise_multiplier
*
np
.
exp
(
-
step
*
self
.
_noise_decay_rate
)
for
step
in
steps
]
self
.
_zcdp
+=
sum
(
[
self
.
_compute_zcdp
(
noise
)
for
noise
in
noise_stddev_step
])
eps
=
self
.
_compute_eps
(
self
.
_zcdp
)
return
eps
,
self
.
_target_delta
def
_compute_zcdp
(
self
,
noise_stddev
):
"""
Compute zcdp according to added noise.
Args:
noise_stddev (float): Noise multiplier.
Returns:
float or numpy.ndarray, zcdp values.
"""
zcdp
=
1
/
(
2
*
noise_stddev
**
2
)
return
zcdp
def
_compute_eps
(
self
,
zcdp
):
"""
Compute eps for given zcdp and delta.
Args:
zcdp (Union[float, numpy.ndarray]): zero-concentrated
differential privacy.
Returns:
float, eps budget.
"""
eps
=
zcdp
+
2
*
np
.
sqrt
(
zcdp
*
np
.
log
(
1
/
self
.
_target_delta
))
return
eps
def
_compute_rdp_with_order
(
sample_rate
,
noise_stddev
,
order
):
"""
Compute rdp for each order.
...
...
tests/ut/python/diff_privacy/test_monitor.py
浏览文件 @
ad008705
...
...
@@ -53,7 +53,7 @@ def test_dp_monitor():
rdp
=
PrivacyMonitorFactory
.
create
(
policy
=
'rdp'
,
num_samples
=
60000
,
batch_size
=
batch_size
,
initial_noise_multiplier
=
0.4
,
noise_decay_rate
=
6e-
5
)
noise_decay_rate
=
6e-
3
)
suggest_epoch
=
rdp
.
max_epoch_suggest
()
LOGGER
.
info
(
TAG
,
'The recommended maximum training epochs is: %s'
,
suggest_epoch
)
...
...
@@ -83,7 +83,7 @@ def test_dp_monitor_gpu():
rdp
=
PrivacyMonitorFactory
.
create
(
policy
=
'rdp'
,
num_samples
=
60000
,
batch_size
=
batch_size
,
initial_noise_multiplier
=
0.4
,
noise_decay_rate
=
6e-
5
)
noise_decay_rate
=
6e-
3
)
suggest_epoch
=
rdp
.
max_epoch_suggest
()
LOGGER
.
info
(
TAG
,
'The recommended maximum training epochs is: %s'
,
suggest_epoch
)
...
...
@@ -113,7 +113,7 @@ def test_dp_monitor_cpu():
rdp
=
PrivacyMonitorFactory
.
create
(
policy
=
'rdp'
,
num_samples
=
60000
,
batch_size
=
batch_size
,
initial_noise_multiplier
=
0.4
,
noise_decay_rate
=
6e-
5
)
noise_decay_rate
=
6e-
3
)
suggest_epoch
=
rdp
.
max_epoch_suggest
()
LOGGER
.
info
(
TAG
,
'The recommended maximum training epochs is: %s'
,
suggest_epoch
)
...
...
@@ -129,3 +129,94 @@ def test_dp_monitor_cpu():
[
"data"
,
"label"
])
ds1
.
set_dataset_size
(
batch_size
*
batches
)
model
.
train
(
epochs
,
ds1
,
callbacks
=
[
rdp
],
dataset_sink_mode
=
False
)
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_arm_ascend_training
@
pytest
.
mark
.
platform_x86_ascend_training
@
pytest
.
mark
.
env_card
@
pytest
.
mark
.
component_mindarmour
def
test_dp_monitor_zcdp
():
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
)
batch_size
=
16
batches
=
128
epochs
=
1
zcdp
=
PrivacyMonitorFactory
.
create
(
policy
=
'zcdp'
,
num_samples
=
60000
,
batch_size
=
batch_size
,
initial_noise_multiplier
=
0.4
,
noise_decay_rate
=
6e-3
)
suggest_epoch
=
zcdp
.
max_epoch_suggest
()
LOGGER
.
info
(
TAG
,
'The recommended maximum training epochs is: %s'
,
suggest_epoch
)
network
=
LeNet5
()
net_loss
=
nn
.
SoftmaxCrossEntropyWithLogits
(
is_grad
=
False
,
sparse
=
True
,
reduction
=
"mean"
)
net_opt
=
nn
.
Momentum
(
network
.
trainable_params
(),
0.01
,
0.9
)
model
=
Model
(
network
,
net_loss
,
net_opt
)
LOGGER
.
info
(
TAG
,
"============== Starting Training =============="
)
ds1
=
ds
.
GeneratorDataset
(
dataset_generator
(
batch_size
,
batches
),
[
"data"
,
"label"
])
ds1
.
set_dataset_size
(
batch_size
*
batches
)
model
.
train
(
epochs
,
ds1
,
callbacks
=
[
zcdp
],
dataset_sink_mode
=
False
)
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_x86_gpu_inference
@
pytest
.
mark
.
env_card
@
pytest
.
mark
.
component_mindarmour
def
test_dp_monitor_zcdp_gpu
():
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"GPU"
)
batch_size
=
16
batches
=
128
epochs
=
1
zcdp
=
PrivacyMonitorFactory
.
create
(
policy
=
'zcdp'
,
num_samples
=
60000
,
batch_size
=
batch_size
,
initial_noise_multiplier
=
0.4
,
noise_decay_rate
=
6e-3
)
suggest_epoch
=
zcdp
.
max_epoch_suggest
()
LOGGER
.
info
(
TAG
,
'The recommended maximum training epochs is: %s'
,
suggest_epoch
)
network
=
LeNet5
()
net_loss
=
nn
.
SoftmaxCrossEntropyWithLogits
(
is_grad
=
False
,
sparse
=
True
,
reduction
=
"mean"
)
net_opt
=
nn
.
Momentum
(
network
.
trainable_params
(),
0.01
,
0.9
)
model
=
Model
(
network
,
net_loss
,
net_opt
)
LOGGER
.
info
(
TAG
,
"============== Starting Training =============="
)
ds1
=
ds
.
GeneratorDataset
(
dataset_generator
(
batch_size
,
batches
),
[
"data"
,
"label"
])
ds1
.
set_dataset_size
(
batch_size
*
batches
)
model
.
train
(
epochs
,
ds1
,
callbacks
=
[
zcdp
],
dataset_sink_mode
=
False
)
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_x86_cpu
@
pytest
.
mark
.
env_card
@
pytest
.
mark
.
component_mindarmour
def
test_dp_monitor_zcdp_cpu
():
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"CPU"
)
batch_size
=
16
batches
=
128
epochs
=
1
zcdp
=
PrivacyMonitorFactory
.
create
(
policy
=
'zcdp'
,
num_samples
=
60000
,
batch_size
=
batch_size
,
initial_noise_multiplier
=
0.4
,
noise_decay_rate
=
6e-3
)
suggest_epoch
=
zcdp
.
max_epoch_suggest
()
LOGGER
.
info
(
TAG
,
'The recommended maximum training epochs is: %s'
,
suggest_epoch
)
network
=
LeNet5
()
net_loss
=
nn
.
SoftmaxCrossEntropyWithLogits
(
is_grad
=
False
,
sparse
=
True
,
reduction
=
"mean"
)
net_opt
=
nn
.
Momentum
(
network
.
trainable_params
(),
0.01
,
0.9
)
model
=
Model
(
network
,
net_loss
,
net_opt
)
LOGGER
.
info
(
TAG
,
"============== Starting Training =============="
)
ds1
=
ds
.
GeneratorDataset
(
dataset_generator
(
batch_size
,
batches
),
[
"data"
,
"label"
])
ds1
.
set_dataset_size
(
batch_size
*
batches
)
model
.
train
(
epochs
,
ds1
,
callbacks
=
[
zcdp
],
dataset_sink_mode
=
False
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录