Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MindSpore
mindarmour
提交
79c6403d
M
mindarmour
项目概览
MindSpore
/
mindarmour
通知
4
Star
2
Fork
3
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindarmour
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
79c6403d
编写于
7月 14, 2020
作者:
Z
ZhidanLiu
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add new feature: adaptive clipping
上级
ac39d193
变更
8
展开全部
隐藏空白更改
内联
并排
Showing
8 changed file
with
676 addition
and
218 deletion
+676
-218
example/mnist_demo/lenet5_config.py
example/mnist_demo/lenet5_config.py
+8
-3
example/mnist_demo/lenet5_dp.py
example/mnist_demo/lenet5_dp.py
+42
-19
mindarmour/diff_privacy/__init__.py
mindarmour/diff_privacy/__init__.py
+8
-4
mindarmour/diff_privacy/mechanisms/mechanisms.py
mindarmour/diff_privacy/mechanisms/mechanisms.py
+194
-48
mindarmour/diff_privacy/optimizer/optimizer.py
mindarmour/diff_privacy/optimizer/optimizer.py
+2
-2
mindarmour/diff_privacy/train/model.py
mindarmour/diff_privacy/train/model.py
+257
-110
tests/ut/python/diff_privacy/test_mechanisms.py
tests/ut/python/diff_privacy/test_mechanisms.py
+116
-10
tests/ut/python/diff_privacy/test_model_train.py
tests/ut/python/diff_privacy/test_model_train.py
+49
-22
未找到文件。
example/mnist_demo/lenet5_config.py
浏览文件 @
79c6403d
...
...
@@ -20,7 +20,7 @@ from easydict import EasyDict as edict
mnist_cfg
=
edict
({
'num_classes'
:
10
,
# the number of classes of model's output
'lr'
:
0.1
,
# the learning rate of model's optimizer
'lr'
:
0.
0
1
,
# the learning rate of model's optimizer
'momentum'
:
0.9
,
# the momentum value of model's optimizer
'epoch_size'
:
10
,
# training epochs
'batch_size'
:
256
,
# batch size for training
...
...
@@ -33,8 +33,13 @@ mnist_cfg = edict({
'dataset_sink_mode'
:
False
,
# whether deliver all training data to device one time
'micro_batches'
:
16
,
# the number of small batches split from an original batch
'norm_clip'
:
1.0
,
# the clip bound of the gradients of model's training parameters
'initial_noise_multiplier'
:
1
.5
,
# the initial multiplication coefficient of the noise added to training
'initial_noise_multiplier'
:
0
.5
,
# the initial multiplication coefficient of the noise added to training
# parameters' gradients
'mechanisms'
:
'AdaGaussian'
,
# the method of adding noise in gradients while training
'noise_mechanisms'
:
'AdaGaussian'
,
# the method of adding noise in gradients while training
'clip_mechanisms'
:
'Gaussian'
,
# the method of adaptive clipping gradients while training
'clip_decay_policy'
:
'Linear'
,
# Decay policy of adaptive clipping, decay_policy must be in ['Linear', 'Geometric'].
'clip_learning_rate'
:
0.001
,
# Learning rate of update norm clip.
'target_unclipped_quantile'
:
0.9
,
# Target quantile of norm clip.
'fraction_stddev'
:
0.01
,
# The stddev of Gaussian normal which used in empirical_fraction.
'optimizer'
:
'Momentum'
# the base optimizer used for Differential privacy training
})
example/mnist_demo/lenet5_dp.py
浏览文件 @
79c6403d
...
...
@@ -31,7 +31,8 @@ import mindspore.common.dtype as mstype
from
mindarmour.diff_privacy
import
DPModel
from
mindarmour.diff_privacy
import
PrivacyMonitorFactory
from
mindarmour.diff_privacy
import
MechanismsFactory
from
mindarmour.diff_privacy
import
NoiseMechanismsFactory
from
mindarmour.diff_privacy
import
ClipMechanismsFactory
from
mindarmour.utils.logger
import
LogUtil
from
lenet5_net
import
LeNet5
from
lenet5_config
import
mnist_cfg
as
cfg
...
...
@@ -87,11 +88,14 @@ def generate_mnist_dataset(data_path, batch_size=32, repeat_size=1,
if
__name__
==
"__main__"
:
# This configure can run both in pynative mode and graph mode
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
cfg
.
device_target
)
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
cfg
.
device_target
)
network
=
LeNet5
()
net_loss
=
nn
.
SoftmaxCrossEntropyWithLogits
(
is_grad
=
False
,
sparse
=
True
,
reduction
=
"mean"
)
config_ck
=
CheckpointConfig
(
save_checkpoint_steps
=
cfg
.
save_checkpoint_steps
,
keep_checkpoint_max
=
cfg
.
keep_checkpoint_max
)
net_loss
=
nn
.
SoftmaxCrossEntropyWithLogits
(
is_grad
=
False
,
sparse
=
True
,
reduction
=
"mean"
)
config_ck
=
CheckpointConfig
(
save_checkpoint_steps
=
cfg
.
save_checkpoint_steps
,
keep_checkpoint_max
=
cfg
.
keep_checkpoint_max
)
ckpoint_cb
=
ModelCheckpoint
(
prefix
=
"checkpoint_lenet"
,
directory
=
'./trained_ckpt_file/'
,
config
=
config_ck
)
...
...
@@ -102,17 +106,33 @@ if __name__ == "__main__":
cfg
.
epoch_size
)
if
cfg
.
micro_batches
and
cfg
.
batch_size
%
cfg
.
micro_batches
!=
0
:
raise
ValueError
(
"Number of micro_batches should divide evenly batch_size"
)
# Create a factory class of DP mechanisms, this method is adding noise in gradients while training.
# Initial_noise_multiplier is suggested to be greater than 1.0, otherwise the privacy budget would be huge, which
# means that the privacy protection effect is weak. Mechanisms can be 'Gaussian' or 'AdaGaussian', in which noise
# would be decayed with 'AdaGaussian' mechanism while be constant with 'Gaussian' mechanism.
mech
=
MechanismsFactory
().
create
(
cfg
.
mechanisms
,
norm_bound
=
cfg
.
norm_clip
,
initial_noise_multiplier
=
cfg
.
initial_noise_multiplier
)
net_opt
=
nn
.
Momentum
(
params
=
network
.
trainable_params
(),
learning_rate
=
cfg
.
lr
,
momentum
=
cfg
.
momentum
)
# Create a monitor for DP training. The function of the monitor is to compute and print the privacy budget(eps
# and delta) while training.
raise
ValueError
(
"Number of micro_batches should divide evenly batch_size"
)
# Create a factory class of DP noise mechanisms, this method is adding noise
# in gradients while training. Initial_noise_multiplier is suggested to be
# greater than 1.0, otherwise the privacy budget would be huge, which means
# that the privacy protection effect is weak. Mechanisms can be 'Gaussian'
# or 'AdaGaussian', in which noise would be decayed with 'AdaGaussian'
# mechanism while be constant with 'Gaussian' mechanism.
noise_mech
=
NoiseMechanismsFactory
().
create
(
cfg
.
noise_mechanisms
,
norm_bound
=
cfg
.
norm_clip
,
initial_noise_multiplier
=
cfg
.
initial_noise_multiplier
)
# Create a factory class of clip mechanisms, this method is to adaptive clip
# gradients while training, decay_policy support 'Linear' and 'Geometric',
# learning_rate is the learning rate to update clip_norm,
# target_unclipped_quantile is the target quantile of norm clip,
# fraction_stddev is the stddev of Gaussian normal which used in
# empirical_fraction, the formula is
# $empirical_fraction + N(0, fraction_stddev)$.
clip_mech
=
ClipMechanismsFactory
().
create
(
cfg
.
clip_mechanisms
,
decay_policy
=
cfg
.
clip_decay_policy
,
learning_rate
=
cfg
.
clip_learning_rate
,
target_unclipped_quantile
=
cfg
.
target_unclipped_quantile
,
fraction_stddev
=
cfg
.
fraction_stddev
)
net_opt
=
nn
.
Momentum
(
params
=
network
.
trainable_params
(),
learning_rate
=
cfg
.
lr
,
momentum
=
cfg
.
momentum
)
# Create a monitor for DP training. The function of the monitor is to
# compute and print the privacy budget(eps and delta) while training.
rdp_monitor
=
PrivacyMonitorFactory
.
create
(
'rdp'
,
num_samples
=
60000
,
batch_size
=
cfg
.
batch_size
,
...
...
@@ -121,20 +141,23 @@ if __name__ == "__main__":
# Create the DP model for training.
model
=
DPModel
(
micro_batches
=
cfg
.
micro_batches
,
norm_clip
=
cfg
.
norm_clip
,
mech
=
mech
,
noise_mech
=
noise_mech
,
clip_mech
=
clip_mech
,
network
=
network
,
loss_fn
=
net_loss
,
optimizer
=
net_opt
,
metrics
=
{
"Accuracy"
:
Accuracy
()})
LOGGER
.
info
(
TAG
,
"============== Starting Training =============="
)
model
.
train
(
cfg
[
'epoch_size'
],
ds_train
,
callbacks
=
[
ckpoint_cb
,
LossMonitor
(),
rdp_monitor
],
model
.
train
(
cfg
[
'epoch_size'
],
ds_train
,
callbacks
=
[
ckpoint_cb
,
LossMonitor
(),
rdp_monitor
],
dataset_sink_mode
=
cfg
.
dataset_sink_mode
)
LOGGER
.
info
(
TAG
,
"============== Starting Testing =============="
)
ckpt_file_name
=
'trained_ckpt_file/checkpoint_lenet-10_234.ckpt'
param_dict
=
load_checkpoint
(
ckpt_file_name
)
load_param_into_net
(
network
,
param_dict
)
ds_eval
=
generate_mnist_dataset
(
os
.
path
.
join
(
cfg
.
data_path
,
'test'
),
batch_size
=
cfg
.
batch_size
)
ds_eval
=
generate_mnist_dataset
(
os
.
path
.
join
(
cfg
.
data_path
,
'test'
),
batch_size
=
cfg
.
batch_size
)
acc
=
model
.
eval
(
ds_eval
,
dataset_sink_mode
=
False
)
LOGGER
.
info
(
TAG
,
"============== Accuracy: %s =============="
,
acc
)
mindarmour/diff_privacy/__init__.py
浏览文件 @
79c6403d
"""
This module provide Differential Privacy feature to protect user privacy.
"""
from
.mechanisms.mechanisms
import
GaussianRandom
from
.mechanisms.mechanisms
import
Noise
GaussianRandom
from
.mechanisms.mechanisms
import
AdaGaussianRandom
from
.mechanisms.mechanisms
import
MechanismsFactory
from
.mechanisms.mechanisms
import
AdaClippingWithGaussianRandom
from
.mechanisms.mechanisms
import
NoiseMechanismsFactory
from
.mechanisms.mechanisms
import
ClipMechanismsFactory
from
.monitor.monitor
import
PrivacyMonitorFactory
from
.optimizer.optimizer
import
DPOptimizerClassFactory
from
.train.model
import
DPModel
__all__
=
[
'GaussianRandom'
,
__all__
=
[
'
Noise
GaussianRandom'
,
'AdaGaussianRandom'
,
'MechanismsFactory'
,
'AdaClippingWithGaussianRandom'
,
'NoiseMechanismsFactory'
,
'ClipMechanismsFactory'
,
'PrivacyMonitorFactory'
,
'DPOptimizerClassFactory'
,
'DPModel'
]
mindarmour/diff_privacy/mechanisms/mechanisms.py
浏览文件 @
79c6403d
...
...
@@ -28,11 +28,54 @@ from mindarmour.utils._check_param import check_param_in_range
from
mindarmour.utils.logger
import
LogUtil
LOGGER
=
LogUtil
.
get_instance
()
TAG
=
'
Defense
'
TAG
=
'
NoiseMechanism
'
class
MechanismsFactory
:
""" Factory class of mechanisms"""
class
ClipMechanismsFactory
:
""" Factory class of clip mechanisms"""
def
__init__
(
self
):
pass
@
staticmethod
def
create
(
mech_name
,
*
args
,
**
kwargs
):
"""
Args:
mech_name(str): Clip noise generated strategy, support 'Gaussian' now.
args(Union[float, str]): Parameters used for creating clip mechanisms.
kwargs(Union[float, str]): Parameters used for creating clip
mechanisms.
Raises:
NameError: `mech_name` must be in ['Gaussian'].
Returns:
Mechanisms, class of noise generated Mechanism.
Examples:
>>> decay_policy = 'Linear'
>>> beta = Tensor(0.5, mstype.float32)
>>> norm_clip = Tensor(1.0, mstype.float32)
>>> beta_stddev = 0.1
>>> learning_rate = 0.1
>>> target_unclipped_quantile = 0.3
>>> clip_mechanism = ClipMechanismsFactory()
>>> ada_clip = clip_mechanism.create('Gaussian',
>>> decay_policy=decay_policy,
>>> learning_rate=learning_rate,
>>> target_unclipped_quantile=target_unclipped_quantile,
>>> fraction_stddev=beta_stddev)
>>> next_norm_clip = ada_clip(beta, norm_clip)
"""
if
mech_name
==
'Gaussian'
:
return
AdaClippingWithGaussianRandom
(
*
args
,
**
kwargs
)
raise
NameError
(
"The {} is not implement, please choose "
"['Gaussian']"
.
format
(
mech_name
))
class
NoiseMechanismsFactory
:
""" Factory class of noise mechanisms"""
def
__init__
(
self
):
pass
...
...
@@ -56,42 +99,38 @@ class MechanismsFactory:
Mechanisms, class of noise generated Mechanism.
Examples:
>>> class Net(nn.Cell):
>>> def __init__(self):
>>> super(Net, self).__init__()
>>> self.conv = nn.Conv2d(3, 64, 3, has_bias=False, weight_init='normal')
>>> self.bn = nn.BatchNorm2d(64)
>>> self.relu = nn.ReLU()
>>> self.flatten = nn.Flatten()
>>> self.fc = nn.Dense(64*224*224, 12) # padding=0
>>>
>>> def construct(self, x):
>>> x = self.conv(x)
>>> x = self.bn(x)
>>> x = self.relu(x)
>>> x = self.flatten(x)
>>> out = self.fc(x)
>>> return out
>>> norm_clip = 1.0
>>> initial_noise_multiplier = 1.5
>>> net = Net()
>>> initial_noise_multiplier = 0.01
>>> network = LeNet5()
>>> batch_size = 32
>>> batches = 128
>>> epochs = 1
>>> loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
>>> net_opt = Momentum(params=net.trainable_params(), learning_rate=0.01, momentum=0.9)
>>> mech = MechanismsFactory().create('Gaussian',
>>> norm_bound=norm_clip,
>>> initial_noise_multiplier=initial_noise_multiplier)
>>> noise_mech = NoiseMechanismsFactory().create('Gaussian',
>>> norm_bound=norm_clip,
>>> initial_noise_multiplier=initial_noise_multiplier)
>>> clip_mech = ClipMechanismsFactory().create('Gaussian',
>>> decay_policy='Linear',
>>> learning_rate=0.01,
>>> target_unclipped_quantile=0.9,
>>> fraction_stddev=0.01)
>>> net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.1,
>>> momentum=0.9)
>>> model = DPModel(micro_batches=2,
>>> norm_clip=1.0,
>>> mech=mech,
>>> network=net,
>>> clip_mech=clip_mech,
>>> norm_clip=norm_clip,
>>> noise_mech=noise_mech,
>>> network=network,
>>> loss_fn=loss,
>>> optimizer=net_opt,
>>> metrics=None)
>>> dataset = get_dataset()
>>> model.train(2, dataset)
>>> ms_ds = ds.GeneratorDataset(dataset_generator(batch_size, batches),
>>> ['data', 'label'])
>>> ms_ds.set_dataset_size(batch_size * batches)
>>> model.train(epochs, ms_ds, dataset_sink_mode=False)
"""
if
policy
==
'Gaussian'
:
return
GaussianRandom
(
*
args
,
**
kwargs
)
return
Noise
GaussianRandom
(
*
args
,
**
kwargs
)
if
policy
==
'AdaGaussian'
:
return
AdaGaussianRandom
(
*
args
,
**
kwargs
)
raise
NameError
(
"The {} is not implement, please choose "
...
...
@@ -110,7 +149,7 @@ class Mechanisms(Cell):
"""
class
GaussianRandom
(
Mechanisms
):
class
Noise
GaussianRandom
(
Mechanisms
):
"""
Gaussian noise generated mechanism.
...
...
@@ -133,18 +172,21 @@ class GaussianRandom(Mechanisms):
>>> gradients = Tensor([0.2, 0.9], mstype.float32)
>>> norm_bound = 0.5
>>> initial_noise_multiplier = 1.5
>>> net = GaussianRandom(norm_bound, initial_noise_multiplier)
>>> net =
Noise
GaussianRandom(norm_bound, initial_noise_multiplier)
>>> res = net(gradients)
>>> print(res)
"""
def
__init__
(
self
,
norm_bound
=
0.5
,
initial_noise_multiplier
=
1.5
,
seed
=
0
,
policy
=
None
):
super
(
GaussianRandom
,
self
).
__init__
()
def
__init__
(
self
,
norm_bound
=
0.5
,
initial_noise_multiplier
=
1.5
,
seed
=
0
,
policy
=
None
):
super
(
NoiseGaussianRandom
,
self
).
__init__
()
self
.
_norm_bound
=
check_value_positive
(
'norm_bound'
,
norm_bound
)
self
.
_norm_bound
=
Tensor
(
norm_bound
,
mstype
.
float32
)
self
.
_initial_noise_multiplier
=
check_value_positive
(
'initial_noise_multiplier'
,
initial_noise_multiplier
)
self
.
_initial_noise_multiplier
=
Tensor
(
initial_noise_multiplier
,
mstype
.
float32
)
self
.
_initial_noise_multiplier
=
check_value_positive
(
'initial_noise_multiplier'
,
initial_noise_multiplier
)
self
.
_initial_noise_multiplier
=
Tensor
(
initial_noise_multiplier
,
mstype
.
float32
)
self
.
_mean
=
Tensor
(
0
,
mstype
.
float32
)
self
.
_normal
=
P
.
Normal
(
seed
=
seed
)
self
.
_decay_policy
=
policy
...
...
@@ -201,17 +243,20 @@ class AdaGaussianRandom(Mechanisms):
noise_decay_rate
=
6e-4
,
decay_policy
=
'Time'
,
seed
=
0
):
super
(
AdaGaussianRandom
,
self
).
__init__
()
norm_bound
=
check_value_positive
(
'norm_bound'
,
norm_bound
)
initial_noise_multiplier
=
check_value_positive
(
'initial_noise_multiplier'
,
initial_noise_multiplier
)
initial_noise_multiplier
=
check_value_positive
(
'initial_noise_multiplier'
,
initial_noise_multiplier
)
self
.
_norm_bound
=
Tensor
(
norm_bound
,
mstype
.
float32
)
initial_noise_multiplier
=
Tensor
(
initial_noise_multiplier
,
mstype
.
float32
)
initial_noise_multiplier
=
Tensor
(
initial_noise_multiplier
,
mstype
.
float32
)
self
.
_initial_noise_multiplier
=
Parameter
(
initial_noise_multiplier
,
name
=
'initial_noise_multiplier'
)
self
.
_noise_multiplier
=
Parameter
(
initial_noise_multiplier
,
name
=
'noise_multiplier'
)
self
.
_mean
=
Tensor
(
0
,
mstype
.
float32
)
noise_decay_rate
=
check_param_type
(
'noise_decay_rate'
,
noise_decay_rate
,
float
)
noise_decay_rate
=
check_param_type
(
'noise_decay_rate'
,
noise_decay_rate
,
float
)
check_param_in_range
(
'noise_decay_rate'
,
noise_decay_rate
,
0.0
,
1.0
)
self
.
_noise_decay_rate
=
Tensor
(
noise_decay_rate
,
mstype
.
float32
)
if
decay_policy
not
in
[
'Time'
,
'Step'
,
'Exp'
]:
...
...
@@ -232,7 +277,9 @@ class AdaGaussianRandom(Mechanisms):
Tensor, generated noise with shape like given gradients.
"""
shape
=
P
.
Shape
()(
gradients
)
noise
=
self
.
_normal
(
shape
,
self
.
_mean
,
self
.
_mul
(
self
.
_noise_multiplier
,
self
.
_norm_bound
))
noise
=
self
.
_normal
(
shape
,
self
.
_mean
,
self
.
_mul
(
self
.
_noise_multiplier
,
self
.
_norm_bound
))
return
noise
...
...
@@ -241,10 +288,14 @@ class _MechanismsParamsUpdater(Cell):
Update mechanisms parameters, the parameters will refresh in train period.
Args:
policy(str): Pass in by the mechanisms class, mechanisms parameters update policy.
decay_rate(Tensor): Pass in by the mechanisms class, hyper parameter for controlling the decay size.
cur_noise_multiplier(Parameter): Pass in by the mechanisms class, current params value in this time.
init_noise_multiplier(Parameter):Pass in by the mechanisms class, initial params value to be updated.
policy(str): Pass in by the mechanisms class, mechanisms parameters
update policy.
decay_rate(Tensor): Pass in by the mechanisms class, hyper parameter for
controlling the decay size.
cur_noise_multiplier(Parameter): Pass in by the mechanisms class,
current params value in this time.
init_noise_multiplier(Parameter):Pass in by the mechanisms class,
initial params value to be updated.
Returns:
Tuple, next params value.
...
...
@@ -281,5 +332,100 @@ class _MechanismsParamsUpdater(Cell):
next_noise_multiplier
=
self
.
_assign
(
self
.
_cur_noise_multiplier
,
self
.
_mul
(
temp
,
self
.
_cur_noise_multiplier
))
else
:
next_noise_multiplier
=
self
.
_assign
(
self
.
_cur_noise_multiplier
,
self
.
_div
(
self
.
_one
,
self
.
_exp
(
self
.
_one
)))
next_noise_multiplier
=
self
.
_assign
(
self
.
_cur_noise_multiplier
,
self
.
_div
(
self
.
_one
,
self
.
_exp
(
self
.
_one
)))
return
next_noise_multiplier
class
AdaClippingWithGaussianRandom
(
Cell
):
"""
Adaptive clipping. If `decay_policy` is 'Linear', the update formula is
$ norm_clip = norm_clip - learning_rate*(beta-target_unclipped_quantile)$.
`decay_policy` is 'Geometric', the update formula is
$ norm_clip = norm_clip*exp(-learning_rate*(empirical_fraction-target_unclipped_quantile))$.
where beta is the empirical fraction of samples with the value at most
`target_unclipped_quantile`.
Args:
decay_policy(str): Decay policy of adaptive clipping, decay_policy must
be in ['Linear', 'Geometric']. Default: Linear.
learning_rate(float): Learning rate of update norm clip. Default: 0.01.
target_unclipped_quantile(float): Target quantile of norm clip. Default: 0.9.
fraction_stddev(float): The stddev of Gaussian normal which used in
empirical_fraction, the formula is $empirical_fraction + N(0, fraction_stddev)$.
seed(int): Original random seed, if seed=0 random normal will use secure
random number. IF seed!=0 random normal will generate values using
given seed. Default: 0.
Returns:
Tensor, undated norm clip .
Examples:
>>> decay_policy = 'Linear'
>>> beta = Tensor(0.5, mstype.float32)
>>> norm_clip = Tensor(1.0, mstype.float32)
>>> beta_stddev = 0.01
>>> learning_rate = 0.001
>>> target_unclipped_quantile = 0.9
>>> ada_clip = AdaClippingWithGaussianRandom(decay_policy=decay_policy,
>>> learning_rate=learning_rate,
>>> target_unclipped_quantile=target_unclipped_quantile,
>>> fraction_stddev=beta_stddev)
>>> next_norm_clip = ada_clip(beta, norm_clip)
"""
def
__init__
(
self
,
decay_policy
=
'Linear'
,
learning_rate
=
0.001
,
target_unclipped_quantile
=
0.9
,
fraction_stddev
=
0.01
,
seed
=
0
):
super
(
AdaClippingWithGaussianRandom
,
self
).
__init__
()
if
decay_policy
not
in
[
'Linear'
,
'Geometric'
]:
msg
=
"decay policy of adaptive clip must be in ['Linear', 'Geometric'],
\
but got: {}"
.
format
(
decay_policy
)
LOGGER
.
error
(
TAG
,
msg
)
raise
ValueError
(
msg
)
self
.
_decay_policy
=
decay_policy
learning_rate
=
check_param_type
(
'learning_rate'
,
learning_rate
,
float
)
learning_rate
=
check_value_positive
(
'learning_rate'
,
learning_rate
)
self
.
_learning_rate
=
Tensor
(
learning_rate
,
mstype
.
float32
)
fraction_stddev
=
check_param_type
(
'fraction_stddev'
,
fraction_stddev
,
float
)
self
.
_fraction_stddev
=
Tensor
(
fraction_stddev
,
mstype
.
float32
)
target_unclipped_quantile
=
check_param_type
(
'target_unclipped_quantile'
,
target_unclipped_quantile
,
float
)
self
.
_target_unclipped_quantile
=
Tensor
(
target_unclipped_quantile
,
mstype
.
float32
)
self
.
_zero
=
Tensor
(
0
,
mstype
.
float32
)
self
.
_add
=
P
.
TensorAdd
()
self
.
_sub
=
P
.
Sub
()
self
.
_mul
=
P
.
Mul
()
self
.
_exp
=
P
.
Exp
()
self
.
_normal
=
P
.
Normal
(
seed
=
seed
)
def
construct
(
self
,
empirical_fraction
,
norm_clip
):
"""
Update value of norm_clip.
Args:
empirical_fraction(Tensor): empirical fraction of samples with the
value at most `target_unclipped_quantile`.
norm_clip(Tensor): Clipping bound for the l2 norm of the gradients.
Returns:
Tensor, generated noise with shape like given gradients.
"""
fraction_noise
=
self
.
_normal
((
1
,),
self
.
_zero
,
self
.
_fraction_stddev
)
empirical_fraction
=
self
.
_add
(
empirical_fraction
,
fraction_noise
)
if
self
.
_decay_policy
==
'Linear'
:
grad_clip
=
self
.
_sub
(
empirical_fraction
,
self
.
_target_unclipped_quantile
)
next_norm_clip
=
self
.
_sub
(
norm_clip
,
self
.
_mul
(
self
.
_learning_rate
,
grad_clip
))
# decay_policy == 'Geometric'
else
:
grad_clip
=
self
.
_sub
(
empirical_fraction
,
self
.
_target_unclipped_quantile
)
grad_clip
=
self
.
_exp
(
self
.
_mul
(
-
self
.
_learning_rate
,
grad_clip
))
next_norm_clip
=
self
.
_mul
(
norm_clip
,
grad_clip
)
return
next_norm_clip
mindarmour/diff_privacy/optimizer/optimizer.py
浏览文件 @
79c6403d
...
...
@@ -22,7 +22,7 @@ from mindspore.ops import functional as F
from
mindspore.common
import
dtype
as
mstype
from
mindarmour.utils.logger
import
LogUtil
from
mindarmour.diff_privacy
import
MechanismsFactory
from
mindarmour.diff_privacy
import
Noise
MechanismsFactory
from
mindarmour.diff_privacy.mechanisms.mechanisms
import
_MechanismsParamsUpdater
from
mindarmour.utils._check_param
import
check_int_positive
...
...
@@ -70,7 +70,7 @@ class DPOptimizerClassFactory:
"""
def
__init__
(
self
,
micro_batches
=
2
):
self
.
_mech_factory
=
MechanismsFactory
()
self
.
_mech_factory
=
Noise
MechanismsFactory
()
self
.
mech
=
None
self
.
_micro_batches
=
check_int_positive
(
'micro_batches'
,
micro_batches
)
...
...
mindarmour/diff_privacy/train/model.py
浏览文件 @
79c6403d
此差异已折叠。
点击以展开。
tests/ut/python/diff_privacy/test_mechanisms.py
浏览文件 @
79c6403d
...
...
@@ -19,9 +19,11 @@ import pytest
from
mindspore
import
context
from
mindspore
import
Tensor
from
mindspore.common
import
dtype
as
mstype
from
mindarmour.diff_privacy
import
GaussianRandom
from
mindarmour.diff_privacy
import
Noise
GaussianRandom
from
mindarmour.diff_privacy
import
AdaGaussianRandom
from
mindarmour.diff_privacy
import
MechanismsFactory
from
mindarmour.diff_privacy
import
AdaClippingWithGaussianRandom
from
mindarmour.diff_privacy
import
NoiseMechanismsFactory
from
mindarmour.diff_privacy
import
ClipMechanismsFactory
@
pytest
.
mark
.
level0
...
...
@@ -33,7 +35,7 @@ def test_graph_gaussian():
grad
=
Tensor
([
0.3
,
0.2
,
0.4
],
mstype
.
float32
)
norm_bound
=
1.0
initial_noise_multiplier
=
0.1
net
=
GaussianRandom
(
norm_bound
,
initial_noise_multiplier
)
net
=
Noise
GaussianRandom
(
norm_bound
,
initial_noise_multiplier
)
res
=
net
(
grad
)
print
(
res
)
...
...
@@ -47,7 +49,7 @@ def test_pynative_gaussian():
grad
=
Tensor
([
0.3
,
0.2
,
0.4
],
mstype
.
float32
)
norm_bound
=
1.0
initial_noise_multiplier
=
0.1
net
=
GaussianRandom
(
norm_bound
,
initial_noise_multiplier
)
net
=
Noise
GaussianRandom
(
norm_bound
,
initial_noise_multiplier
)
res
=
net
(
grad
)
print
(
res
)
...
...
@@ -80,13 +82,13 @@ def test_graph_factory():
initial_noise_multiplier
=
0.1
alpha
=
0.5
decay_policy
=
'Step'
noise_mechanism
=
MechanismsFactory
()
noise_mechanism
=
Noise
MechanismsFactory
()
noise_construct
=
noise_mechanism
.
create
(
'Gaussian'
,
norm_bound
,
initial_noise_multiplier
)
noise
=
noise_construct
(
grad
)
print
(
'Gaussian noise: '
,
noise
)
ada_mechanism
=
MechanismsFactory
()
ada_mechanism
=
Noise
MechanismsFactory
()
ada_noise_construct
=
ada_mechanism
.
create
(
'AdaGaussian'
,
norm_bound
,
initial_noise_multiplier
,
...
...
@@ -124,13 +126,13 @@ def test_pynative_factory():
initial_noise_multiplier
=
0.1
alpha
=
0.5
decay_policy
=
'Step'
noise_mechanism
=
MechanismsFactory
()
noise_mechanism
=
Noise
MechanismsFactory
()
noise_construct
=
noise_mechanism
.
create
(
'Gaussian'
,
norm_bound
,
initial_noise_multiplier
)
noise
=
noise_construct
(
grad
)
print
(
'Gaussian noise: '
,
noise
)
ada_mechanism
=
MechanismsFactory
()
ada_mechanism
=
Noise
MechanismsFactory
()
ada_noise_construct
=
ada_mechanism
.
create
(
'AdaGaussian'
,
norm_bound
,
initial_noise_multiplier
,
...
...
@@ -151,7 +153,7 @@ def test_pynative_exponential():
initial_noise_multiplier
=
0.1
alpha
=
0.5
decay_policy
=
'Exp'
ada_mechanism
=
MechanismsFactory
()
ada_mechanism
=
Noise
MechanismsFactory
()
ada_noise_construct
=
ada_mechanism
.
create
(
'AdaGaussian'
,
norm_bound
,
initial_noise_multiplier
,
...
...
@@ -172,7 +174,7 @@ def test_graph_exponential():
initial_noise_multiplier
=
0.1
alpha
=
0.5
decay_policy
=
'Exp'
ada_mechanism
=
MechanismsFactory
()
ada_mechanism
=
Noise
MechanismsFactory
()
ada_noise_construct
=
ada_mechanism
.
create
(
'AdaGaussian'
,
norm_bound
,
initial_noise_multiplier
,
...
...
@@ -180,3 +182,107 @@ def test_graph_exponential():
decay_policy
=
decay_policy
)
ada_noise
=
ada_noise_construct
(
grad
)
print
(
'ada noise: '
,
ada_noise
)
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_x86_ascend_training
@
pytest
.
mark
.
env_onecard
@
pytest
.
mark
.
component_mindarmour
def
test_ada_clip_gaussian_random_pynative
():
context
.
set_context
(
mode
=
context
.
PYNATIVE_MODE
,
device_target
=
"Ascend"
)
decay_policy
=
'Linear'
beta
=
Tensor
(
0.5
,
mstype
.
float32
)
norm_clip
=
Tensor
(
1.0
,
mstype
.
float32
)
beta_stddev
=
0.1
learning_rate
=
0.1
target_unclipped_quantile
=
0.3
ada_clip
=
AdaClippingWithGaussianRandom
(
decay_policy
=
decay_policy
,
learning_rate
=
learning_rate
,
target_unclipped_quantile
=
target_unclipped_quantile
,
fraction_stddev
=
beta_stddev
,
seed
=
1
)
next_norm_clip
=
ada_clip
(
beta
,
norm_clip
)
print
(
'Liner next norm clip:'
,
next_norm_clip
)
decay_policy
=
'Geometric'
ada_clip
=
AdaClippingWithGaussianRandom
(
decay_policy
=
decay_policy
,
learning_rate
=
learning_rate
,
target_unclipped_quantile
=
target_unclipped_quantile
,
fraction_stddev
=
beta_stddev
,
seed
=
1
)
next_norm_clip
=
ada_clip
(
beta
,
norm_clip
)
print
(
'Geometric next norm clip:'
,
next_norm_clip
)
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_x86_ascend_training
@
pytest
.
mark
.
env_onecard
@
pytest
.
mark
.
component_mindarmour
def
test_ada_clip_gaussian_random_graph
():
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
)
decay_policy
=
'Linear'
beta
=
Tensor
(
0.5
,
mstype
.
float32
)
norm_clip
=
Tensor
(
1.0
,
mstype
.
float32
)
beta_stddev
=
0.1
learning_rate
=
0.1
target_unclipped_quantile
=
0.3
ada_clip
=
AdaClippingWithGaussianRandom
(
decay_policy
=
decay_policy
,
learning_rate
=
learning_rate
,
target_unclipped_quantile
=
target_unclipped_quantile
,
fraction_stddev
=
beta_stddev
,
seed
=
1
)
next_norm_clip
=
ada_clip
(
beta
,
norm_clip
)
print
(
'Liner next norm clip:'
,
next_norm_clip
)
decay_policy
=
'Geometric'
ada_clip
=
AdaClippingWithGaussianRandom
(
decay_policy
=
decay_policy
,
learning_rate
=
learning_rate
,
target_unclipped_quantile
=
target_unclipped_quantile
,
fraction_stddev
=
beta_stddev
,
seed
=
1
)
next_norm_clip
=
ada_clip
(
beta
,
norm_clip
)
print
(
'Geometric next norm clip:'
,
next_norm_clip
)
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_x86_ascend_training
@
pytest
.
mark
.
env_onecard
@
pytest
.
mark
.
component_mindarmour
def
test_pynative_clip_mech_factory
():
context
.
set_context
(
mode
=
context
.
PYNATIVE_MODE
,
device_target
=
"Ascend"
)
decay_policy
=
'Linear'
beta
=
Tensor
(
0.5
,
mstype
.
float32
)
norm_clip
=
Tensor
(
1.0
,
mstype
.
float32
)
beta_stddev
=
0.1
learning_rate
=
0.1
target_unclipped_quantile
=
0.3
clip_mechanism
=
ClipMechanismsFactory
()
ada_clip
=
clip_mechanism
.
create
(
'Gaussian'
,
decay_policy
=
decay_policy
,
learning_rate
=
learning_rate
,
target_unclipped_quantile
=
target_unclipped_quantile
,
fraction_stddev
=
beta_stddev
)
next_norm_clip
=
ada_clip
(
beta
,
norm_clip
)
print
(
'next_norm_clip: '
,
next_norm_clip
)
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_x86_ascend_training
@
pytest
.
mark
.
env_onecard
@
pytest
.
mark
.
component_mindarmour
def
test_graph_clip_mech_factory
():
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
)
decay_policy
=
'Linear'
beta
=
Tensor
(
0.5
,
mstype
.
float32
)
norm_clip
=
Tensor
(
1.0
,
mstype
.
float32
)
beta_stddev
=
0.1
learning_rate
=
0.1
target_unclipped_quantile
=
0.3
clip_mechanism
=
ClipMechanismsFactory
()
ada_clip
=
clip_mechanism
.
create
(
'Gaussian'
,
decay_policy
=
decay_policy
,
learning_rate
=
learning_rate
,
target_unclipped_quantile
=
target_unclipped_quantile
,
fraction_stddev
=
beta_stddev
)
next_norm_clip
=
ada_clip
(
beta
,
norm_clip
)
print
(
'next_norm_clip: '
,
next_norm_clip
)
tests/ut/python/diff_privacy/test_model_train.py
浏览文件 @
79c6403d
...
...
@@ -22,7 +22,8 @@ from mindspore import context
import
mindspore.dataset
as
ds
from
mindarmour.diff_privacy
import
DPModel
from
mindarmour.diff_privacy
import
MechanismsFactory
from
mindarmour.diff_privacy
import
NoiseMechanismsFactory
from
mindarmour.diff_privacy
import
ClipMechanismsFactory
from
mindarmour.diff_privacy
import
DPOptimizerClassFactory
from
test_network
import
LeNet5
...
...
@@ -30,10 +31,12 @@ from test_network import LeNet5
def
dataset_generator
(
batch_size
,
batches
):
"""mock training data."""
data
=
np
.
random
.
random
((
batches
*
batch_size
,
1
,
32
,
32
)).
astype
(
np
.
float32
)
label
=
np
.
random
.
randint
(
0
,
10
,
batches
*
batch_size
).
astype
(
np
.
int32
)
data
=
np
.
random
.
random
((
batches
*
batch_size
,
1
,
32
,
32
)).
astype
(
np
.
float32
)
label
=
np
.
random
.
randint
(
0
,
10
,
batches
*
batch_size
).
astype
(
np
.
int32
)
for
i
in
range
(
batches
):
yield
data
[
i
*
batch_size
:(
i
+
1
)
*
batch_size
],
label
[
i
*
batch_size
:(
i
+
1
)
*
batch_size
]
yield
data
[
i
*
batch_size
:(
i
+
1
)
*
batch_size
],
\
label
[
i
*
batch_size
:(
i
+
1
)
*
batch_size
]
@
pytest
.
mark
.
level0
...
...
@@ -55,16 +58,24 @@ def test_dp_model_with_pynative_mode():
factory_opt
.
set_mechanisms
(
'Gaussian'
,
norm_bound
=
norm_clip
,
initial_noise_multiplier
=
initial_noise_multiplier
)
net_opt
=
factory_opt
.
create
(
'Momentum'
)(
network
.
trainable_params
(),
learning_rate
=
0.1
,
momentum
=
0.9
)
net_opt
=
factory_opt
.
create
(
'Momentum'
)(
network
.
trainable_params
(),
learning_rate
=
0.1
,
momentum
=
0.9
)
clip_mech
=
ClipMechanismsFactory
().
create
(
'Gaussian'
,
decay_policy
=
'Linear'
,
learning_rate
=
0.01
,
target_unclipped_quantile
=
0.9
,
fraction_stddev
=
0.01
)
model
=
DPModel
(
micro_batches
=
micro_batches
,
norm_clip
=
norm_clip
,
mech
=
None
,
clip_mech
=
clip_mech
,
noise_mech
=
None
,
network
=
network
,
loss_fn
=
loss
,
optimizer
=
net_opt
,
metrics
=
None
)
ms_ds
=
ds
.
GeneratorDataset
(
dataset_generator
(
batch_size
,
batches
),
[
'data'
,
'label'
])
ms_ds
.
set_dataset_size
(
batch_size
*
batches
)
ms_ds
=
ds
.
GeneratorDataset
(
dataset_generator
(
batch_size
,
batches
),
[
'data'
,
'label'
])
ms_ds
.
set_dataset_size
(
batch_size
*
batches
)
model
.
train
(
epochs
,
ms_ds
,
dataset_sink_mode
=
False
)
...
...
@@ -82,19 +93,27 @@ def test_dp_model_with_graph_mode():
batches
=
128
epochs
=
1
loss
=
nn
.
SoftmaxCrossEntropyWithLogits
(
is_grad
=
False
,
sparse
=
True
)
mech
=
MechanismsFactory
().
create
(
'Gaussian'
,
norm_bound
=
norm_clip
,
initial_noise_multiplier
=
initial_noise_multiplier
)
net_opt
=
nn
.
Momentum
(
network
.
trainable_params
(),
learning_rate
=
0.1
,
momentum
=
0.9
)
noise_mech
=
NoiseMechanismsFactory
().
create
(
'Gaussian'
,
norm_bound
=
norm_clip
,
initial_noise_multiplier
=
initial_noise_multiplier
)
clip_mech
=
ClipMechanismsFactory
().
create
(
'Gaussian'
,
decay_policy
=
'Linear'
,
learning_rate
=
0.01
,
target_unclipped_quantile
=
0.9
,
fraction_stddev
=
0.01
)
net_opt
=
nn
.
Momentum
(
network
.
trainable_params
(),
learning_rate
=
0.1
,
momentum
=
0.9
)
model
=
DPModel
(
micro_batches
=
2
,
clip_mech
=
clip_mech
,
norm_clip
=
norm_clip
,
mech
=
mech
,
noise_mech
=
noise_
mech
,
network
=
network
,
loss_fn
=
loss
,
optimizer
=
net_opt
,
metrics
=
None
)
ms_ds
=
ds
.
GeneratorDataset
(
dataset_generator
(
batch_size
,
batches
),
[
'data'
,
'label'
])
ms_ds
.
set_dataset_size
(
batch_size
*
batches
)
ms_ds
=
ds
.
GeneratorDataset
(
dataset_generator
(
batch_size
,
batches
),
[
'data'
,
'label'
])
ms_ds
.
set_dataset_size
(
batch_size
*
batches
)
model
.
train
(
epochs
,
ms_ds
,
dataset_sink_mode
=
False
)
...
...
@@ -112,17 +131,25 @@ def test_dp_model_with_graph_mode_ada_gaussian():
batches
=
128
epochs
=
1
loss
=
nn
.
SoftmaxCrossEntropyWithLogits
(
is_grad
=
False
,
sparse
=
True
)
mech
=
MechanismsFactory
().
create
(
'AdaGaussian'
,
norm_bound
=
norm_clip
,
initial_noise_multiplier
=
initial_noise_multiplier
)
net_opt
=
nn
.
Momentum
(
network
.
trainable_params
(),
learning_rate
=
0.1
,
momentum
=
0.9
)
noise_mech
=
NoiseMechanismsFactory
().
create
(
'AdaGaussian'
,
norm_bound
=
norm_clip
,
initial_noise_multiplier
=
initial_noise_multiplier
)
clip_mech
=
ClipMechanismsFactory
().
create
(
'Gaussian'
,
decay_policy
=
'Linear'
,
learning_rate
=
0.01
,
target_unclipped_quantile
=
0.9
,
fraction_stddev
=
0.01
)
net_opt
=
nn
.
Momentum
(
network
.
trainable_params
(),
learning_rate
=
0.1
,
momentum
=
0.9
)
model
=
DPModel
(
micro_batches
=
2
,
clip_mech
=
clip_mech
,
norm_clip
=
norm_clip
,
mech
=
mech
,
noise_mech
=
noise_
mech
,
network
=
network
,
loss_fn
=
loss
,
optimizer
=
net_opt
,
metrics
=
None
)
ms_ds
=
ds
.
GeneratorDataset
(
dataset_generator
(
batch_size
,
batches
),
[
'data'
,
'label'
])
ms_ds
.
set_dataset_size
(
batch_size
*
batches
)
ms_ds
=
ds
.
GeneratorDataset
(
dataset_generator
(
batch_size
,
batches
),
[
'data'
,
'label'
])
ms_ds
.
set_dataset_size
(
batch_size
*
batches
)
model
.
train
(
epochs
,
ms_ds
,
dataset_sink_mode
=
False
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录