Commit f2622e42
Authored May 30, 2020 by jin-xiulang

Add value-range check for parameter 'alpha' in mechanisms.py.

Parent: d3b06562
Showing 7 changed files with 75 additions and 66 deletions (+75 -66).
example/mnist_demo/lenet5_config.py (+18 -10)
example/mnist_demo/lenet5_dp_model_train.py (+31 -32)
mindarmour/diff_privacy/mechanisms/mechanisms.py (+14 -10)
mindarmour/diff_privacy/monitor/monitor.py (+6 -8)
mindarmour/diff_privacy/optimizer/optimizer.py (+1 -1)
mindarmour/diff_privacy/train/model.py (+1 -1)
tests/ut/python/diff_privacy/test_mechanisms.py (+4 -4)
example/mnist_demo/lenet5_config.py

@@ -19,14 +19,22 @@ network config setting, will be used in train.py
 from easydict import EasyDict as edict

 mnist_cfg = edict({
-    'num_classes': 10,
-    'lr': 0.01,
-    'momentum': 0.9,
-    'epoch_size': 10,
-    'batch_size': 256,
-    'buffer_size': 1000,
-    'image_height': 32,
-    'image_width': 32,
-    'save_checkpoint_steps': 234,
-    'keep_checkpoint_max': 10,
+    'num_classes': 10,  # the number of classes of model's output
+    'lr': 0.01,  # the learning rate of model's optimizer
+    'momentum': 0.9,  # the momentum value of model's optimizer
+    'epoch_size': 10,  # training epochs
+    'batch_size': 256,  # batch size for training
+    'image_height': 32,  # the height of training samples
+    'image_width': 32,  # the width of training samples
+    'save_checkpoint_steps': 234,  # the interval steps for saving checkpoint file of the model
+    'keep_checkpoint_max': 10,  # the maximum number of checkpoint files would be saved
+    'device_target': 'Ascend',  # device used
+    'data_path': './MNIST_unzip',  # the path of training and testing data set
+    'dataset_sink_mode': False,  # whether deliver all training data to device one time
+    'micro_batches': 32,  # the number of small batches split from an original batch
+    'l2_norm_bound': 1.0,  # the clip bound of the gradients of model's training parameters
+    'initial_noise_multiplier': 1.5,  # the initial multiplication coefficient of the noise added to training
+                                      # parameters' gradients
+    'mechanisms': 'AdaGaussian',  # the method of adding noise in gradients while training
+    'optimizer': 'Momentum'  # the base optimizer used for Differential privacy training
 })
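For orientation: because mnist_cfg is an EasyDict, the training script below reads these entries by attribute access after this commit (cfg.batch_size instead of args.batch_size). A minimal standalone sketch, with illustrative values copied from the config above:

    from easydict import EasyDict as edict

    cfg = edict({'batch_size': 256, 'micro_batches': 32, 'mechanisms': 'AdaGaussian'})
    # EasyDict exposes keys as attributes, which is how lenet5_dp_model_train.py
    # consumes the config after this change.
    assert cfg.batch_size == cfg['batch_size'] == 256
    print(cfg.mechanisms)  # AdaGaussian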
example/mnist_demo/lenet5_dp_model_train.py

@@ -15,7 +15,6 @@
 python lenet5_dp_model_train.py --data_path /YourDataPath --micro_batches=2
 """
 import os
-import argparse
 import mindspore.nn as nn
 from mindspore import context

@@ -87,21 +86,7 @@ def generate_mnist_dataset(data_path, batch_size=32, repeat_size=1,
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description='MindSpore MNIST Example')
-    parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU', 'CPU'],
-                        help='device where the code will be implemented (default: Ascend)')
-    parser.add_argument('--data_path', type=str, default="./MNIST_unzip",
-                        help='path where the dataset is saved')
-    parser.add_argument('--dataset_sink_mode', type=bool, default=False,
-                        help='dataset_sink_mode is False or True')
-    parser.add_argument('--micro_batches', type=int, default=32,
-                        help='optional, if use differential privacy, need to set micro_batches')
-    parser.add_argument('--l2_norm_bound', type=float, default=1.0,
-                        help='optional, if use differential privacy, need to set l2_norm_bound')
-    parser.add_argument('--initial_noise_multiplier', type=float, default=1.5,
-                        help='optional, if use differential privacy, need to set initial_noise_multiplier')
-    args = parser.parse_args()
-    context.set_context(mode=context.PYNATIVE_MODE, device_target=args.device_target)
+    context.set_context(mode=context.PYNATIVE_MODE, device_target=cfg.device_target)
     network = LeNet5()
     net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")

@@ -111,27 +96,41 @@ if __name__ == "__main__":
                                  directory='./trained_ckpt_file/',
                                  config=config_ck)
-    ds_train = generate_mnist_dataset(os.path.join(args.data_path, "train"),
+    # get training dataset
+    ds_train = generate_mnist_dataset(os.path.join(cfg.data_path, "train"),
                                       cfg.batch_size, cfg.epoch_size)
-    if args.micro_batches and cfg.batch_size % args.micro_batches != 0:
+    if cfg.micro_batches and cfg.batch_size % cfg.micro_batches != 0:
         raise ValueError("Number of micro_batches should divide evenly batch_size")
-    gaussian_mech = DPOptimizerClassFactory(args.micro_batches)
-    gaussian_mech.set_mechanisms('Gaussian',
-                                 norm_bound=args.l2_norm_bound,
-                                 initial_noise_multiplier=args.initial_noise_multiplier)
-    net_opt = gaussian_mech.create('Momentum')(params=network.trainable_params(),
-                                               learning_rate=cfg.lr, momentum=cfg.momentum)
+    # Create a factory class of DP optimizer
+    gaussian_mech = DPOptimizerClassFactory(cfg.micro_batches)
+    # Set the method of adding noise in gradients while training. Initial_noise_multiplier is suggested to be greater
+    # than 1.0, otherwise the privacy budget would be huge, which means that the privacy protection effect is weak.
+    # mechanisms can be 'Gaussian' or 'AdaGaussian', in which noise would be decayed with 'AdaGaussian' mechanism while
+    # be constant with 'Gaussian' mechanism.
+    gaussian_mech.set_mechanisms(cfg.mechanisms,
+                                 norm_bound=cfg.l2_norm_bound,
+                                 initial_noise_multiplier=cfg.initial_noise_multiplier)
+    # Wrap the base optimizer for DP training. Momentum optimizer is suggested for LeNet5.
+    net_opt = gaussian_mech.create(cfg.optimizer)(params=network.trainable_params(),
+                                                  learning_rate=cfg.lr, momentum=cfg.momentum)
+    # Create a monitor for DP training. The function of the monitor is to compute and print the privacy budget (eps
+    # and delta) while training.
     rdp_monitor = PrivacyMonitorFactory.create('rdp',
                                                num_samples=60000,
                                                batch_size=cfg.batch_size,
-                                               initial_noise_multiplier=args.initial_noise_multiplier*args.l2_norm_bound,
-                                               per_print_times=10)
-    model = DPModel(micro_batches=args.micro_batches,
-                    norm_clip=args.l2_norm_bound,
+                                               initial_noise_multiplier=cfg.initial_noise_multiplier*cfg.l2_norm_bound,
+                                               per_print_times=50)
+    # Create the DP model for training.
+    model = DPModel(micro_batches=cfg.micro_batches,
+                    norm_clip=cfg.l2_norm_bound,
                     dp_mech=gaussian_mech.mech,
                     network=network,
                     loss_fn=net_loss,

@@ -140,12 +139,12 @@ if __name__ == "__main__":
     LOGGER.info(TAG, "============== Starting Training ==============")
     model.train(cfg['epoch_size'], ds_train, callbacks=[ckpoint_cb, LossMonitor(), rdp_monitor],
-                dataset_sink_mode=args.dataset_sink_mode)
+                dataset_sink_mode=cfg.dataset_sink_mode)

     LOGGER.info(TAG, "============== Starting Testing ==============")
     ckpt_file_name = 'trained_ckpt_file/checkpoint_lenet-10_234.ckpt'
     param_dict = load_checkpoint(ckpt_file_name)
     load_param_into_net(network, param_dict)
-    ds_eval = generate_mnist_dataset(os.path.join(args.data_path, 'test'), batch_size=cfg.batch_size)
+    ds_eval = generate_mnist_dataset(os.path.join(cfg.data_path, 'test'), batch_size=cfg.batch_size)
     acc = model.eval(ds_eval, dataset_sink_mode=False)
     LOGGER.info(TAG, "============== Accuracy: %s ==============", acc)
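The divisibility check retained above matters because DPModel splits every batch of cfg.batch_size samples into cfg.micro_batches smaller batches (per the config comments earlier) before clipping and noising gradients. A quick arithmetic sketch with the defaults from lenet5_config.py, illustrative only:

    batch_size, micro_batches = 256, 32  # defaults from lenet5_config.py
    # Same guard as in the script: every micro-batch must hold a whole number of samples.
    assert batch_size % micro_batches == 0, \
        "Number of micro_batches should divide evenly batch_size"
    print(batch_size // micro_batches)  # 8 samples per micro-batch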
mindarmour/diff_privacy/mechanisms/mechanisms.py

@@ -24,6 +24,7 @@ from mindspore.common import dtype as mstype
 from mindarmour.utils._check_param import check_param_type
 from mindarmour.utils._check_param import check_value_positive
+from mindarmour.utils._check_param import check_param_in_range


 class MechanismsFactory:

@@ -37,7 +38,8 @@ class MechanismsFactory:
     """
     Args:
         policy(str): Noise generated strategy, could be 'Gaussian' or
-            'AdaGaussian'. Default: 'AdaGaussian'.
+            'AdaGaussian'. Noise would be decayed with 'AdaGaussian' mechanism while
+            be constant with 'Gaussian' mechanism. Default: 'AdaGaussian'.
         args(Union[float, str]): Parameters used for creating noise
             mechanisms.
         kwargs(Union[float, str]): Parameters used for creating noise

@@ -115,7 +117,8 @@ class GaussianRandom(Mechanisms):
 class AdaGaussianRandom(Mechanisms):
     """
-    Adaptive Gaussian noise generated mechanism.
+    Adaptive Gaussian noise generated mechanism. Noise would be decayed with training. Decay mode could be 'Time'
+    mode or 'Step' mode.

     Args:
         norm_bound(float): Clipping bound for the l2 norm of the gradients.

@@ -123,7 +126,7 @@ class AdaGaussianRandom(Mechanisms):
         initial_noise_multiplier(float): Ratio of the standard deviation of
             Gaussian noise divided by the norm_bound, which will be used to
             calculate privacy spent. Default: 5.0.
-        alpha(float): Hyperparameter for controlling the noise decay.
+        noise_decay_rate(float): Hyperparameter for controlling the noise decay.
             Default: 6e-4.
         decay_policy(str): Noise decay strategy include 'Step' and 'Time'.
             Default: 'Time'.

@@ -135,16 +138,16 @@ class AdaGaussianRandom(Mechanisms):
         >>> shape = (3, 2, 4)
         >>> norm_bound = 1.0
         >>> initial_noise_multiplier = 0.1
-        >>> alpha = 0.5
+        >>> noise_decay_rate = 0.5
         >>> decay_policy = "Time"
         >>> net = AdaGaussianRandom(norm_bound, initial_noise_multiplier,
-        >>>                         alpha, decay_policy)
+        >>>                         noise_decay_rate, decay_policy)
         >>> res = net(shape)
         >>> print(res)
     """

     def __init__(self, norm_bound=1.5, initial_noise_multiplier=5.0,
-                 alpha=6e-4, decay_policy='Time'):
+                 noise_decay_rate=6e-4, decay_policy='Time'):
         super(AdaGaussianRandom, self).__init__()
         initial_noise_multiplier = check_value_positive('initial_noise_multiplier',
                                                         initial_noise_multiplier)

@@ -156,8 +159,9 @@ class AdaGaussianRandom(Mechanisms):
         norm_bound = check_value_positive('norm_bound', norm_bound)
         self._norm_bound = Tensor(np.array(norm_bound, np.float32))
-        alpha = check_param_type('alpha', alpha, float)
-        self._alpha = Tensor(np.array(alpha, np.float32))
+        noise_decay_rate = check_param_type('noise_decay_rate', noise_decay_rate, float)
+        check_param_in_range('noise_decay_rate', noise_decay_rate, 0.0, 1.0)
+        self._noise_decay_rate = Tensor(np.array(noise_decay_rate, np.float32))
         if decay_policy not in ['Time', 'Step']:
             raise NameError("The decay_policy must be in ['Time', 'Step'], but "

@@ -176,12 +180,12 @@ class AdaGaussianRandom(Mechanisms):
         if self._decay_policy == 'Time':
             temp = self._div(self._initial_noise_multiplier, self._noise_multiplier)
-            temp = self._add(temp, self._alpha)
+            temp = self._add(temp, self._noise_decay_rate)
             temp = self._div(self._initial_noise_multiplier, temp)
             self._noise_multiplier = Parameter(temp, name='noise_multiplier')
         else:
             one = Tensor(1, self._dtype)
-            temp = self._sub(one, self._alpha)
+            temp = self._sub(one, self._noise_decay_rate)
             temp = self._mul(temp, self._noise_multiplier)
             self._noise_multiplier = Parameter(temp, name='noise_multiplier')
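Reading the update code just above: under 'Time' decay the multiplier evolves as sigma <- sigma_0 / (sigma_0/sigma + r), and under 'Step' decay as sigma <- (1 - r) * sigma, where sigma_0 is initial_noise_multiplier and r is noise_decay_rate. Restricting r to [0.0, 1.0] through the new check_param_in_range call guarantees both updates shrink (or preserve) the multiplier; previously an out-of-range r passed with only a type check. A minimal sketch of the intended construction-time behavior, assuming the import path implied by the file layout above and that _check_param raises ValueError (neither is shown on this page):

    from mindarmour.diff_privacy.mechanisms.mechanisms import AdaGaussianRandom

    # r = 0.5 lies inside [0.0, 1.0]: constructs normally.
    net = AdaGaussianRandom(norm_bound=1.0, initial_noise_multiplier=0.1,
                            noise_decay_rate=0.5, decay_policy='Step')

    # r = 1.5 lies outside the range: after this commit the constructor
    # should reject it instead of letting 'Step' decay flip the sign of
    # the noise multiplier during training.
    AdaGaussianRandom(noise_decay_rate=1.5)  # expected to raise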
mindarmour/diff_privacy/monitor/monitor.py

@@ -20,7 +20,7 @@ from mindspore.train.callback import Callback
 from mindarmour.utils.logger import LogUtil
 from mindarmour.utils._check_param import check_int_positive, \
-    check_value_positive
+    check_value_positive, check_param_in_range, check_param_type

 LOGGER = LogUtil.get_instance()
 TAG = 'DP monitor'

@@ -40,7 +40,8 @@ class PrivacyMonitorFactory:
         Create a privacy monitor class.

         Args:
-            policy (str): Monitor policy, 'rdp' is supported by now.
+            policy (str): Monitor policy, 'rdp' is supported by now. RDP means Rényi differential privacy,
+                which is computed based on Rényi divergence.
             args (Union[int, float, numpy.ndarray, list, str]): Parameters
                 used for creating a privacy monitor.
             kwargs (Union[int, float, numpy.ndarray, list, str]): Keyword

@@ -70,7 +71,7 @@ class RDPMonitor(Callback):
         num_samples (int): The total number of samples in training data sets.
         batch_size (int): The number of samples in a batch while training.
         initial_noise_multiplier (Union[float, int]): The initial
-            multiplier of added noise. Default: 1.5.
+            multiplier of the noise added to training parameters' gradients. Default: 1.5.
         max_eps (Union[float, int, None]): The maximum acceptable epsilon
             budget for DP training. Default: 10.0.
         target_delta (Union[float, int, None]): Target delta budget for DP

@@ -137,11 +138,8 @@ class RDPMonitor(Callback):
             LOGGER.error(TAG, msg)
             raise ValueError(msg)
         if noise_decay_rate is not None:
-            check_value_positive('noise_decay_rate', noise_decay_rate)
-            if noise_decay_rate >= 1:
-                msg = 'Noise decay rate must be less than 1'
-                LOGGER.error(TAG, msg)
-                raise ValueError(msg)
+            noise_decay_rate = check_param_type('noise_decay_rate', noise_decay_rate, float)
+            check_param_in_range('noise_decay_rate', noise_decay_rate, 0.0, 1.0)
         check_int_positive('per_print_times', per_print_times)
         self._total_echo_privacy = None
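Note that mechanisms.py and monitor.py now share the same two helpers, check_param_type and check_param_in_range, in place of monitor.py's hand-rolled positive-and-less-than-1 check. The helper bodies live in mindarmour/utils/_check_param.py and are not shown on this page; a hypothetical reimplementation of the range check, for orientation only:

    def check_param_in_range(arg_name, arg_value, lower, upper):
        """Hypothetical sketch: reject arg_value outside [lower, upper]."""
        if arg_value < lower or arg_value > upper:
            raise ValueError('{} must be between {} and {}, but got {}'.format(
                arg_name, lower, upper, arg_value))
        return arg_value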
mindarmour/diff_privacy/optimizer/optimizer.py

@@ -27,7 +27,7 @@ class DPOptimizerClassFactory:
     Factory class of Optimizer.

     Args:
-        micro_batches (int): The number of small batches split from an origianl batch. Default: 2.
+        micro_batches (int): The number of small batches split from an original batch. Default: 2.

     Returns:
         Optimizer, Optimizer class
mindarmour/diff_privacy/train/model.py

@@ -70,7 +70,7 @@ class DPModel(Model):
     This class is overload mindspore.train.model.Model.

     Args:
-        micro_batches (int): The number of small batches split from an origianl batch. Default: 2.
+        micro_batches (int): The number of small batches split from an original batch. Default: 2.
         norm_clip (float): Use to clip the bound, if set 1, will retun the original data. Default: 1.0.
         dp_mech (Mechanisms): The object can generate the different type of noise. Default: None.
tests/ut/python/diff_privacy/test_mechanisms.py

@@ -45,10 +45,10 @@ def test_ada_gaussian():
     shape = (3, 2, 4)
     norm_bound = 1.0
     initial_noise_multiplier = 0.1
-    alpha = 0.5
+    noise_decay_rate = 0.5
     decay_policy = "Step"
     net = AdaGaussianRandom(norm_bound, initial_noise_multiplier,
-                            alpha, decay_policy)
+                            noise_decay_rate, decay_policy)
     res = net(shape)
     print(res)

@@ -58,7 +58,7 @@ def test_factory():
     shape = (3, 2, 4)
     norm_bound = 1.0
     initial_noise_multiplier = 0.1
-    alpha = 0.5
+    noise_decay_rate = 0.5
     decay_policy = "Step"
     noise_mechanism = MechanismsFactory()
     noise_construct = noise_mechanism.create('Gaussian',

@@ -70,7 +70,7 @@ def test_factory():
     ada_noise_construct = ada_mechanism.create('AdaGaussian',
                                                norm_bound,
                                                initial_noise_multiplier,
-                                               alpha,
+                                               noise_decay_rate,
                                                decay_policy)
     ada_noise = ada_noise_construct(shape)
     print('ada noise: ', ada_noise)
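These tests pin the rename (alpha -> noise_decay_rate) at both the class constructor and the factory call sites. A plausible way to exercise them locally, assuming a working MindSpore environment (the command is illustrative, not from this page):

    pytest tests/ut/python/diff_privacy/test_mechanisms.py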