Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MindSpore
mindarmour
提交
fe97f43f
M
mindarmour
项目概览
MindSpore
/
mindarmour
通知
4
Star
2
Fork
3
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindarmour
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
fe97f43f
编写于
6月 17, 2020
作者:
Z
zhenghuanhuan
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add differential privacy in GRAPH MODE context mode.
上级
bfd8d88d
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
449 addition
and
172 deletion
+449
-172
example/mnist_demo/lenet5_config.py
example/mnist_demo/lenet5_config.py
+2
-2
example/mnist_demo/lenet5_dp.py
example/mnist_demo/lenet5_dp.py
+140
-0
example/mnist_demo/lenet5_dp_pynative_mode.py
example/mnist_demo/lenet5_dp_pynative_mode.py
+15
-23
mindarmour/diff_privacy/mechanisms/mechanisms.py
mindarmour/diff_privacy/mechanisms/mechanisms.py
+63
-50
mindarmour/diff_privacy/optimizer/optimizer.py
mindarmour/diff_privacy/optimizer/optimizer.py
+38
-13
mindarmour/diff_privacy/train/model.py
mindarmour/diff_privacy/train/model.py
+67
-51
setup.py
setup.py
+1
-1
tests/ut/python/diff_privacy/test_mechanisms.py
tests/ut/python/diff_privacy/test_mechanisms.py
+80
-21
tests/ut/python/diff_privacy/test_model_train.py
tests/ut/python/diff_privacy/test_model_train.py
+42
-11
tests/ut/python/diff_privacy/test_optimizer.py
tests/ut/python/diff_privacy/test_optimizer.py
+1
-0
未找到文件。
example/mnist_demo/lenet5_config.py
浏览文件 @
fe97f43f
...
...
@@ -20,7 +20,7 @@ from easydict import EasyDict as edict
mnist_cfg
=
edict
({
'num_classes'
:
10
,
# the number of classes of model's output
'lr'
:
0.
0
1
,
# the learning rate of model's optimizer
'lr'
:
0.1
,
# the learning rate of model's optimizer
'momentum'
:
0.9
,
# the momentum value of model's optimizer
'epoch_size'
:
10
,
# training epochs
'batch_size'
:
256
,
# batch size for training
...
...
@@ -33,7 +33,7 @@ mnist_cfg = edict({
'dataset_sink_mode'
:
False
,
# whether deliver all training data to device one time
'micro_batches'
:
16
,
# the number of small batches split from an original batch
'l2_norm_bound'
:
1.0
,
# the clip bound of the gradients of model's training parameters
'initial_noise_multiplier'
:
1.5
,
# the initial multiplication coefficient of the noise added to training
'initial_noise_multiplier'
:
0.2
,
# the initial multiplication coefficient of the noise added to training
# parameters' gradients
'mechanisms'
:
'AdaGaussian'
,
# the method of adding noise in gradients while training
'optimizer'
:
'Momentum'
# the base optimizer used for Differential privacy training
...
...
example/mnist_demo/lenet5_dp.py
0 → 100644
浏览文件 @
fe97f43f
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
python lenet5_dp.py --data_path /YourDataPath --micro_batches=2
"""
import
os
import
mindspore.nn
as
nn
from
mindspore
import
context
from
mindspore.train.callback
import
ModelCheckpoint
from
mindspore.train.callback
import
CheckpointConfig
from
mindspore.train.callback
import
LossMonitor
from
mindspore.nn.metrics
import
Accuracy
from
mindspore.train.serialization
import
load_checkpoint
,
load_param_into_net
import
mindspore.dataset
as
ds
import
mindspore.dataset.transforms.vision.c_transforms
as
CV
import
mindspore.dataset.transforms.c_transforms
as
C
from
mindspore.dataset.transforms.vision
import
Inter
import
mindspore.common.dtype
as
mstype
from
mindarmour.diff_privacy
import
DPModel
from
mindarmour.diff_privacy
import
PrivacyMonitorFactory
from
mindarmour.diff_privacy
import
MechanismsFactory
from
mindarmour.utils.logger
import
LogUtil
from
lenet5_net
import
LeNet5
from
lenet5_config
import
mnist_cfg
as
cfg
LOGGER
=
LogUtil
.
get_instance
()
LOGGER
.
set_level
(
'INFO'
)
TAG
=
'Lenet5_train'
def
generate_mnist_dataset
(
data_path
,
batch_size
=
32
,
repeat_size
=
1
,
num_parallel_workers
=
1
,
sparse
=
True
):
"""
create dataset for training or testing
"""
# define dataset
ds1
=
ds
.
MnistDataset
(
data_path
)
# define operation parameters
resize_height
,
resize_width
=
32
,
32
rescale
=
1.0
/
255.0
shift
=
0.0
# define map operations
resize_op
=
CV
.
Resize
((
resize_height
,
resize_width
),
interpolation
=
Inter
.
LINEAR
)
rescale_op
=
CV
.
Rescale
(
rescale
,
shift
)
hwc2chw_op
=
CV
.
HWC2CHW
()
type_cast_op
=
C
.
TypeCast
(
mstype
.
int32
)
# apply map operations on images
if
not
sparse
:
one_hot_enco
=
C
.
OneHot
(
10
)
ds1
=
ds1
.
map
(
input_columns
=
"label"
,
operations
=
one_hot_enco
,
num_parallel_workers
=
num_parallel_workers
)
type_cast_op
=
C
.
TypeCast
(
mstype
.
float32
)
ds1
=
ds1
.
map
(
input_columns
=
"label"
,
operations
=
type_cast_op
,
num_parallel_workers
=
num_parallel_workers
)
ds1
=
ds1
.
map
(
input_columns
=
"image"
,
operations
=
resize_op
,
num_parallel_workers
=
num_parallel_workers
)
ds1
=
ds1
.
map
(
input_columns
=
"image"
,
operations
=
rescale_op
,
num_parallel_workers
=
num_parallel_workers
)
ds1
=
ds1
.
map
(
input_columns
=
"image"
,
operations
=
hwc2chw_op
,
num_parallel_workers
=
num_parallel_workers
)
# apply DatasetOps
buffer_size
=
10000
ds1
=
ds1
.
shuffle
(
buffer_size
=
buffer_size
)
ds1
=
ds1
.
batch
(
batch_size
,
drop_remainder
=
True
)
ds1
=
ds1
.
repeat
(
repeat_size
)
return
ds1
if
__name__
==
"__main__"
:
# This configure can run both in pynative mode and graph mode
context
.
set_context
(
mode
=
context
.
PYNATIVE_MODE
,
device_target
=
cfg
.
device_target
)
network
=
LeNet5
()
net_loss
=
nn
.
SoftmaxCrossEntropyWithLogits
(
is_grad
=
False
,
sparse
=
True
,
reduction
=
"mean"
)
config_ck
=
CheckpointConfig
(
save_checkpoint_steps
=
cfg
.
save_checkpoint_steps
,
keep_checkpoint_max
=
cfg
.
keep_checkpoint_max
)
ckpoint_cb
=
ModelCheckpoint
(
prefix
=
"checkpoint_lenet"
,
directory
=
'./trained_ckpt_file/'
,
config
=
config_ck
)
# get training dataset
ds_train
=
generate_mnist_dataset
(
os
.
path
.
join
(
cfg
.
data_path
,
"train"
),
cfg
.
batch_size
,
cfg
.
epoch_size
)
if
cfg
.
micro_batches
and
cfg
.
batch_size
%
cfg
.
micro_batches
!=
0
:
raise
ValueError
(
"Number of micro_batches should divide evenly batch_size"
)
# Create a factory class of DP mechanisms, this method is adding noise in gradients while training.
# Initial_noise_multiplier is suggested to be greater than 1.0, otherwise the privacy budget would be huge, which
# means that the privacy protection effect is weak. Mechanisms can be 'Gaussian' or 'AdaGaussian', in which noise
# would be decayed with 'AdaGaussian' mechanism while be constant with 'Gaussian' mechanism.
mech
=
MechanismsFactory
().
create
(
cfg
.
mechanisms
,
norm_bound
=
cfg
.
l2_norm_bound
,
initial_noise_multiplier
=
cfg
.
initial_noise_multiplier
)
net_opt
=
nn
.
Momentum
(
params
=
network
.
trainable_params
(),
learning_rate
=
cfg
.
lr
,
momentum
=
cfg
.
momentum
)
# Create a monitor for DP training. The function of the monitor is to compute and print the privacy budget(eps
# and delta) while training.
rdp_monitor
=
PrivacyMonitorFactory
.
create
(
'rdp'
,
num_samples
=
60000
,
batch_size
=
cfg
.
batch_size
,
initial_noise_multiplier
=
cfg
.
initial_noise_multiplier
*
cfg
.
l2_norm_bound
,
per_print_times
=
10
)
# Create the DP model for training.
model
=
DPModel
(
micro_batches
=
cfg
.
micro_batches
,
norm_clip
=
cfg
.
l2_norm_bound
,
mech
=
mech
,
network
=
network
,
loss_fn
=
net_loss
,
optimizer
=
net_opt
,
metrics
=
{
"Accuracy"
:
Accuracy
()})
LOGGER
.
info
(
TAG
,
"============== Starting Training =============="
)
model
.
train
(
cfg
[
'epoch_size'
],
ds_train
,
callbacks
=
[
ckpoint_cb
,
LossMonitor
(),
rdp_monitor
],
dataset_sink_mode
=
cfg
.
dataset_sink_mode
)
LOGGER
.
info
(
TAG
,
"============== Starting Testing =============="
)
ckpt_file_name
=
'trained_ckpt_file/checkpoint_lenet-10_234.ckpt'
param_dict
=
load_checkpoint
(
ckpt_file_name
)
load_param_into_net
(
network
,
param_dict
)
ds_eval
=
generate_mnist_dataset
(
os
.
path
.
join
(
cfg
.
data_path
,
'test'
),
batch_size
=
cfg
.
batch_size
)
acc
=
model
.
eval
(
ds_eval
,
dataset_sink_mode
=
False
)
LOGGER
.
info
(
TAG
,
"============== Accuracy: %s =============="
,
acc
)
example/mnist_demo/lenet5_dp_
model_train
.py
→
example/mnist_demo/lenet5_dp_
pynative_mode
.py
浏览文件 @
fe97f43f
...
...
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
python lenet5_dp_
model_train
.py --data_path /YourDataPath --micro_batches=2
python lenet5_dp_
pynative_mode
.py --data_path /YourDataPath --micro_batches=2
"""
import
os
...
...
@@ -30,8 +30,8 @@ from mindspore.dataset.transforms.vision import Inter
import
mindspore.common.dtype
as
mstype
from
mindarmour.diff_privacy
import
DPModel
from
mindarmour.diff_privacy
import
DPOptimizerClassFactory
from
mindarmour.diff_privacy
import
PrivacyMonitorFactory
from
mindarmour.diff_privacy
import
DPOptimizerClassFactory
from
mindarmour.utils.logger
import
LogUtil
from
lenet5_net
import
LeNet5
from
lenet5_config
import
mnist_cfg
as
cfg
...
...
@@ -86,8 +86,8 @@ def generate_mnist_dataset(data_path, batch_size=32, repeat_size=1,
if
__name__
==
"__main__"
:
# This configure just can run in pynative mode.
context
.
set_context
(
mode
=
context
.
PYNATIVE_MODE
,
device_target
=
cfg
.
device_target
)
network
=
LeNet5
()
net_loss
=
nn
.
SoftmaxCrossEntropyWithLogits
(
is_grad
=
False
,
sparse
=
True
,
reduction
=
"mean"
)
config_ck
=
CheckpointConfig
(
save_checkpoint_steps
=
cfg
.
save_checkpoint_steps
,
...
...
@@ -103,34 +103,26 @@ if __name__ == "__main__":
if
cfg
.
micro_batches
and
cfg
.
batch_size
%
cfg
.
micro_batches
!=
0
:
raise
ValueError
(
"Number of micro_batches should divide evenly batch_size"
)
# Create a factory class of DP optimizer
gaussian_mech
=
DPOptimizerClassFactory
(
cfg
.
micro_batches
)
# Set the method of adding noise in gradients while training. Initial_noise_multiplier is suggested to be greater
# than 1.0, otherwise the privacy budget would be huge, which means that the privacy protection effect is weak.
# mechanisms can be 'Gaussian' or 'AdaGaussian', in which noise would be decayed with 'AdaGaussian' mechanism while
# be constant with 'Gaussian' mechanism.
gaussian_mech
.
set_mechanisms
(
cfg
.
mechanisms
,
norm_bound
=
cfg
.
l2_norm_bound
,
initial_noise_multiplier
=
cfg
.
initial_noise_multiplier
)
# Wrap the base optimizer for DP training. Momentum optimizer is suggested for LenNet5.
net_opt
=
gaussian_mech
.
create
(
cfg
.
optimizer
)(
params
=
network
.
trainable_params
(),
learning_rate
=
cfg
.
lr
,
momentum
=
cfg
.
momentum
)
# Create a factory class of DP mechanisms, this method is adding noise in gradients while training.
# Initial_noise_multiplier is suggested to be greater than 1.0, otherwise the privacy budget would be huge, which
# means that the privacy protection effect is weak. Mechanisms can be 'Gaussian' or 'AdaGaussian', in which noise
# would be decayed with 'AdaGaussian' mechanism while be constant with 'Gaussian' mechanism.
dp_opt
=
DPOptimizerClassFactory
(
micro_batches
=
cfg
.
micro_batches
)
dp_opt
.
set_mechanisms
(
cfg
.
mechanisms
,
norm_bound
=
cfg
.
l2_norm_bound
,
initial_noise_multiplier
=
cfg
.
initial_noise_multiplier
)
net_opt
=
dp_opt
.
create
(
'Momentum'
)(
params
=
network
.
trainable_params
(),
learning_rate
=
cfg
.
lr
,
momentum
=
cfg
.
momentum
)
# Create a monitor for DP training. The function of the monitor is to compute and print the privacy budget(eps
# and delta) while training.
rdp_monitor
=
PrivacyMonitorFactory
.
create
(
'rdp'
,
num_samples
=
60000
,
batch_size
=
cfg
.
batch_size
,
initial_noise_multiplier
=
cfg
.
initial_noise_multiplier
,
per_print_times
=
50
)
initial_noise_multiplier
=
cfg
.
initial_noise_multiplier
*
cfg
.
l2_norm_bound
,
per_print_times
=
10
)
# Create the DP model for training.
model
=
DPModel
(
micro_batches
=
cfg
.
micro_batches
,
norm_clip
=
cfg
.
l2_norm_bound
,
dp_mech
=
gaussian_mech
.
mech
,
mech
=
None
,
network
=
network
,
loss_fn
=
net_loss
,
optimizer
=
net_opt
,
...
...
mindarmour/diff_privacy/mechanisms/mechanisms.py
浏览文件 @
fe97f43f
...
...
@@ -14,8 +14,6 @@
"""
Noise Mechanisms.
"""
import
numpy
as
np
from
mindspore
import
Tensor
from
mindspore.nn
import
Cell
from
mindspore.ops
import
operations
as
P
...
...
@@ -24,6 +22,7 @@ from mindspore.common import dtype as mstype
from
mindarmour.utils._check_param
import
check_param_type
from
mindarmour.utils._check_param
import
check_value_positive
from
mindarmour.utils._check_param
import
check_value_non_negative
from
mindarmour.utils._check_param
import
check_param_in_range
...
...
@@ -62,7 +61,8 @@ class Mechanisms(Cell):
"""
Basic class of noise generated mechanism.
"""
def
construct
(
self
,
shape
):
def
construct
(
self
,
gradients
):
"""
Construct function.
"""
...
...
@@ -78,41 +78,47 @@ class GaussianRandom(Mechanisms):
initial_noise_multiplier(float): Ratio of the standard deviation of
Gaussian noise divided by the norm_bound, which will be used to
calculate privacy spent. Default: 1.5.
mean(float): Average value of random noise. Default: 0.0.
seed(int): Original random seed. Default: 0.
Returns:
Tensor, generated noise.
Tensor, generated noise
with shape like given gradients
.
Examples:
>>>
shape = (3, 2, 4
)
>>>
gradients = Tensor([0.2, 0.9], mstype.float32
)
>>> norm_bound = 1.0
>>> initial_noise_multiplier =
1.5
>>> initial_noise_multiplier =
0.1
>>> net = GaussianRandom(norm_bound, initial_noise_multiplier)
>>> res = net(
shape
)
>>> res = net(
gradients
)
>>> print(res)
"""
def
__init__
(
self
,
norm_bound
=
1.0
,
initial_noise_multiplier
=
1.5
):
def
__init__
(
self
,
norm_bound
=
1.0
,
initial_noise_multiplier
=
1.5
,
mean
=
0.0
,
seed
=
0
):
super
(
GaussianRandom
,
self
).
__init__
()
self
.
_norm_bound
=
check_value_positive
(
'norm_bound'
,
norm_bound
)
self
.
_norm_bound
=
Tensor
(
norm_bound
,
mstype
.
float32
)
self
.
_initial_noise_multiplier
=
check_value_positive
(
'initial_noise_multiplier'
,
initial_noise_multiplier
,)
stddev
=
self
.
_norm_bound
*
self
.
_initial_noise_multiplier
self
.
_stddev
=
stddev
self
.
_mean
=
0
def
construct
(
self
,
shape
):
initial_noise_multiplier
)
self
.
_initial_noise_multiplier
=
Tensor
(
initial_noise_multiplier
,
mstype
.
float32
)
mean
=
check_param_type
(
'mean'
,
mean
,
float
)
mean
=
check_value_non_negative
(
'mean'
,
mean
)
self
.
_mean
=
Tensor
(
mean
,
mstype
.
float32
)
self
.
_normal
=
P
.
Normal
(
seed
=
seed
)
def
construct
(
self
,
gradients
):
"""
Generated Gaussian noise.
Args:
shape(tuple): The shape of
gradients.
gradients(Tensor): The
gradients.
Returns:
Tensor, generated noise.
Tensor, generated noise
with shape like given gradients
.
"""
shape
=
check_param_type
(
'shape'
,
shape
,
tuple
)
noise
=
np
.
random
.
normal
(
self
.
_mean
,
self
.
_stddev
,
shape
)
return
Tensor
(
noise
,
mstype
.
float32
)
shape
=
P
.
Shape
()(
gradients
)
stddev
=
P
.
Mul
()(
self
.
_norm_bound
,
self
.
_initial_noise_multiplier
)
noise
=
self
.
_normal
(
shape
,
self
.
_mean
,
stddev
)
return
noise
class
AdaGaussianRandom
(
Mechanisms
):
...
...
@@ -126,54 +132,60 @@ class AdaGaussianRandom(Mechanisms):
initial_noise_multiplier(float): Ratio of the standard deviation of
Gaussian noise divided by the norm_bound, which will be used to
calculate privacy spent. Default: 5.0.
noise_decay_rate(float): Hyperparameter for controlling the noise decay.
mean(float): Average value of random noise. Default: 0.0
noise_decay_rate(float): Hyper parameter for controlling the noise decay.
Default: 6e-4.
decay_policy(str): Noise decay strategy include 'Step' and 'Time'.
Default: 'Time'.
seed(int): Original random seed. Default: 0.
Returns:
Tensor, generated noise.
Tensor, generated noise
with shape like given gradients
.
Examples:
>>>
shape = (3, 2, 4
)
>>>
gradients = Tensor([0.2, 0.9], mstype.float32
)
>>> norm_bound = 1.0
>>> initial_noise_multiplier = 0.1
>>> noise_decay_rate = 0.5
>>> initial_noise_multiplier = 5.0
>>> mean = 0.0
>>> noise_decay_rate = 6e-4
>>> decay_policy = "Time"
>>> net = AdaGaussianRandom(norm_bound, initial_noise_multiplier,
>>> net = AdaGaussianRandom(norm_bound, initial_noise_multiplier,
mean
>>> noise_decay_rate, decay_policy)
>>> res = net(
shape
)
>>> res = net(
gradients
)
>>> print(res)
"""
def
__init__
(
self
,
norm_bound
=
1.5
,
initial_noise_multiplier
=
5.0
,
noise_decay_rate
=
6e-4
,
decay_policy
=
'Time'
):
def
__init__
(
self
,
norm_bound
=
1.5
,
initial_noise_multiplier
=
5.0
,
mean
=
0.0
,
noise_decay_rate
=
6e-4
,
decay_policy
=
'Time'
,
seed
=
0
):
super
(
AdaGaussianRandom
,
self
).
__init__
()
norm_bound
=
check_value_positive
(
'norm_bound'
,
norm_bound
)
initial_noise_multiplier
=
check_value_positive
(
'initial_noise_multiplier'
,
initial_noise_multiplier
)
initial_noise_multiplier
=
Tensor
(
np
.
array
(
initial_noise_multiplier
,
np
.
float32
))
self
.
_norm_bound
=
Tensor
(
norm_bound
,
mstype
.
float32
)
initial_noise_multiplier
=
Tensor
(
initial_noise_multiplier
,
mstype
.
float32
)
self
.
_initial_noise_multiplier
=
Parameter
(
initial_noise_multiplier
,
name
=
'initial_noise_multiplier'
)
self
.
_stddev
=
P
.
Mul
()(
self
.
_norm_bound
,
self
.
_initial_noise_multiplier
)
self
.
_noise_multiplier
=
Parameter
(
initial_noise_multiplier
,
name
=
'noise_multiplier'
)
norm_bound
=
check_value_positive
(
'norm_bound'
,
norm_bound
)
self
.
_norm_bound
=
Tensor
(
np
.
array
(
norm_bound
,
np
.
float32
)
)
mean
=
check_param_type
(
'mean'
,
mean
,
float
)
mean
=
check_value_non_negative
(
'mean'
,
mean
)
self
.
_mean
=
Tensor
(
mean
,
mstype
.
float32
)
noise_decay_rate
=
check_param_type
(
'noise_decay_rate'
,
noise_decay_rate
,
float
)
check_param_in_range
(
'noise_decay_rate'
,
noise_decay_rate
,
0.0
,
1.0
)
self
.
_noise_decay_rate
=
Tensor
(
np
.
array
(
noise_decay_rate
,
np
.
float32
))
self
.
_noise_decay_rate
=
Tensor
(
noise_decay_rate
,
mstype
.
float32
)
if
decay_policy
not
in
[
'Time'
,
'Step'
]:
raise
NameError
(
"The decay_policy must be in ['Time', 'Step'], but "
"get {}"
.
format
(
decay_policy
))
self
.
_decay_policy
=
decay_policy
self
.
_mean
=
0.0
self
.
_sub
=
P
.
Sub
()
self
.
_mul
=
P
.
Mul
()
self
.
_add
=
P
.
TensorAdd
()
self
.
_div
=
P
.
Div
()
self
.
_stddev
=
self
.
_update_stddev
()
self
.
_dtype
=
mstype
.
float32
self
.
_normal
=
P
.
Normal
(
seed
=
seed
)
self
.
_assign
=
P
.
Assign
()
def
_update_multiplier
(
self
):
""" Update multiplier. """
...
...
@@ -181,31 +193,32 @@ class AdaGaussianRandom(Mechanisms):
temp
=
self
.
_div
(
self
.
_initial_noise_multiplier
,
self
.
_noise_multiplier
)
temp
=
self
.
_add
(
temp
,
self
.
_noise_decay_rate
)
temp
=
self
.
_div
(
self
.
_initial_noise_multiplier
,
temp
)
self
.
_noise_multiplier
=
Parameter
(
temp
,
name
=
'noise_multiplier'
)
self
.
_noise_multiplier
=
self
.
_assign
(
self
.
_noise_multiplier
,
self
.
_div
(
self
.
_initial_noise_multiplier
,
temp
)
)
else
:
one
=
Tensor
(
1
,
self
.
_dtype
)
temp
=
self
.
_sub
(
one
,
self
.
_noise_decay_rate
)
temp
=
self
.
_mul
(
temp
,
self
.
_noise_multiplier
)
self
.
_noise_multiplier
=
Parameter
(
temp
,
name
=
'noise_multiplier'
)
self
.
_noise_multiplier
=
self
.
_assign
(
self
.
_noise_multiplier
,
self
.
_mul
(
temp
,
self
.
_noise_multiplier
)
)
return
self
.
_noise_multiplier
def
_update_stddev
(
self
):
self
.
_stddev
=
self
.
_
mul
(
self
.
_noise_multiplier
,
self
.
_norm_bound
)
self
.
_stddev
=
self
.
_
assign
(
self
.
_stddev
,
self
.
_mul
(
self
.
_noise_multiplier
,
self
.
_norm_bound
)
)
return
self
.
_stddev
def
construct
(
self
,
shape
):
def
construct
(
self
,
gradients
):
"""
Generate adaptive Gaussian noise.
Args:
shape(tuple): The shape of
gradients.
gradients(Tensor): The
gradients.
Returns:
Tensor, generated noise.
Tensor, generated noise
with shape like given gradients
.
"""
shape
=
check_param_type
(
'shape'
,
shape
,
tuple
)
noise
=
np
.
random
.
normal
(
self
.
_mean
,
self
.
_stddev
.
asnumpy
(),
shape
)
self
.
_update_multiplier
()
self
.
_update_stddev
()
return
Tensor
(
noise
,
mstype
.
float32
)
shape
=
P
.
Shape
()(
gradients
)
noise
=
self
.
_normal
(
shape
,
self
.
_mean
,
self
.
_stddev
)
# pylint: disable=unused-variable
mt
=
self
.
_update_multiplier
()
# pylint: disable=unused-variable
std
=
self
.
_update_stddev
()
return
noise
mindarmour/diff_privacy/optimizer/optimizer.py
浏览文件 @
fe97f43f
...
...
@@ -14,13 +14,37 @@
"""
Differential privacy optimizer.
"""
import
mindspore
as
ms
from
mindspore
import
nn
from
mindspore
import
Tensor
from
mindspore.ops
import
composite
as
C
from
mindspore.ops
import
operations
as
P
from
mindspore.ops
import
functional
as
F
from
mindspore.common
import
dtype
as
mstype
from
mindarmour.diff_privacy.mechanisms.mechanisms
import
MechanismsFactory
from
mindarmour.utils._check_param
import
check_int_positive
_grad_scale
=
C
.
MultitypeFuncGraph
(
"grad_scale"
)
_reciprocal
=
P
.
Reciprocal
()
@
_grad_scale
.
register
(
"Tensor"
,
"Tensor"
)
def
tensor_grad_scale
(
scale
,
grad
):
""" grad scaling """
return
grad
*
_reciprocal
(
scale
)
class
_TupleAdd
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
_TupleAdd
,
self
).
__init__
()
self
.
add
=
P
.
TensorAdd
()
self
.
hyper_map
=
C
.
HyperMap
()
def
construct
(
self
,
input1
,
input2
):
"""Add two tuple of data."""
out
=
self
.
hyper_map
(
self
.
add
,
input1
,
input2
)
return
out
class
DPOptimizerClassFactory
:
"""
...
...
@@ -36,9 +60,10 @@ class DPOptimizerClassFactory:
>>> GaussianSGD = DPOptimizerClassFactory(micro_batches=2)
>>> GaussianSGD.set_mechanisms('Gaussian', norm_bound=1.0, initial_noise_multiplier=1.5)
>>> net_opt = GaussianSGD.create('Momentum')(params=network.trainable_params(),
>>> learning_rate=cfg.lr,
>>> momentum=cfg.momentum)
>>>
learning_rate=cfg.lr,
>>>
momentum=cfg.momentum)
"""
def
__init__
(
self
,
micro_batches
=
2
):
self
.
_mech_factory
=
MechanismsFactory
()
self
.
mech
=
None
...
...
@@ -78,6 +103,7 @@ class DPOptimizerClassFactory:
"""
Wrap original mindspore optimizer with `self._mech`.
"""
class
DPOptimizer
(
cls
):
"""
Initialize the DPOptimizerClass.
...
...
@@ -85,23 +111,22 @@ class DPOptimizerClassFactory:
Returns:
Optimizer, Optimizer class.
"""
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
(
DPOptimizer
,
self
).
__init__
(
*
args
,
**
kwargs
)
self
.
_mech
=
mech
self
.
_tuple_add
=
_TupleAdd
()
self
.
_hyper_map
=
C
.
HyperMap
()
self
.
_micro_float
=
Tensor
(
micro_batches
,
mstype
.
float32
)
def
construct
(
self
,
gradients
):
"""
construct a compute flow.
"""
g_len
=
len
(
gradients
)
gradient_noise
=
list
(
gradients
)
for
i
in
range
(
g_len
):
gradient_noise
[
i
]
=
gradient_noise
[
i
].
asnumpy
()
gradient_noise
[
i
]
=
self
.
_mech
(
gradient_noise
[
i
].
shape
).
asnumpy
()
+
gradient_noise
[
i
]
gradient_noise
[
i
]
=
gradient_noise
[
i
]
/
micro_batches
gradient_noise
[
i
]
=
Tensor
(
gradient_noise
[
i
],
ms
.
float32
)
gradients
=
tuple
(
gradient_noise
)
gradients
=
super
(
DPOptimizer
,
self
).
construct
(
gradients
)
grad_noise
=
self
.
_hyper_map
(
self
.
_mech
,
gradients
)
grads
=
self
.
_tuple_add
(
gradients
,
grad_noise
)
grads
=
self
.
_hyper_map
(
F
.
partial
(
_grad_scale
,
self
.
_micro_float
),
grads
)
gradients
=
super
(
DPOptimizer
,
self
).
construct
(
grads
)
return
gradients
return
DPOptimizer
mindarmour/diff_privacy/train/model.py
浏览文件 @
fe97f43f
...
...
@@ -16,7 +16,6 @@ Differential privacy model.
"""
from
easydict
import
EasyDict
as
edict
import
mindspore
as
ms
from
mindspore.train.model
import
Model
from
mindspore._checkparam
import
Validator
as
validator
from
mindspore._checkparam
import
Rel
...
...
@@ -48,21 +47,19 @@ from mindspore.nn.wrap.loss_scale import _grad_overflow
from
mindspore.nn
import
Cell
from
mindspore
import
ParameterTuple
from
mindarmour.diff_privacy.mechanisms
import
mechanisms
from
mindarmour.utils._check_param
import
check_param_type
from
mindarmour.utils._check_param
import
check_value_positive
from
mindarmour.utils._check_param
import
check_int_positive
GRADIENT_CLIP_TYPE
=
1
grad_scale
=
C
.
MultitypeFuncGraph
(
"grad_scale"
)
reciprocal
=
P
.
Reciprocal
()
_
grad_scale
=
C
.
MultitypeFuncGraph
(
"grad_scale"
)
_
reciprocal
=
P
.
Reciprocal
()
@
grad_scale
.
register
(
"Tensor"
,
"Tensor"
)
@
_
grad_scale
.
register
(
"Tensor"
,
"Tensor"
)
def
tensor_grad_scale
(
scale
,
grad
):
""" grad scaling """
return
grad
*
reciprocal
(
scale
)
return
grad
*
F
.
cast
(
_reciprocal
(
scale
),
F
.
dtype
(
grad
)
)
class
DPModel
(
Model
):
...
...
@@ -72,7 +69,7 @@ class DPModel(Model):
Args:
micro_batches (int): The number of small batches split from an original batch. Default: 2.
norm_clip (float): Use to clip the bound, if set 1, will retun the original data. Default: 1.0.
dp_
mech (Mechanisms): The object can generate the different type of noise. Default: None.
mech (Mechanisms): The object can generate the different type of noise. Default: None.
Examples:
>>> class Net(nn.Cell):
...
...
@@ -94,32 +91,37 @@ class DPModel(Model):
>>>
>>> net = Net()
>>> loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
>>> optim = Momentum(params=net.trainable_params(), learning_rate=0.01, momentum=0.9)
>>> gaussian_mech = DPOptimizerClassFactory()
>>> gaussian_mech.set_mechanisms('Gaussian',
>>> norm_bound=args.l2_norm_bound,
>>> initial_noise_multiplier=args.initial_noise_multiplier)
>>> net_opt = Momentum(params=net.trainable_params(), learning_rate=0.01, momentum=0.9)
>>> mech = MechanismsFactory().create('Gaussian',
>>> norm_bound=args.l2_norm_bound,
>>> initial_noise_multiplier=args.initial_noise_multiplier)
>>> model = DPModel(micro_batches=2,
>>> norm_clip=1.0,
>>>
dp_mech=gaussian_mech.
mech,
>>>
mech=
mech,
>>> network=net,
>>> loss_fn=loss,
>>> optimizer=
optim
,
>>> optimizer=
net_opt
,
>>> metrics=None)
>>> dataset = get_dataset()
>>> model.train(2, dataset)
"""
def
__init__
(
self
,
micro_batches
=
2
,
norm_clip
=
1.0
,
dp_mech
=
None
,
**
kwargs
):
def
__init__
(
self
,
micro_batches
=
2
,
norm_clip
=
1.0
,
mech
=
None
,
**
kwargs
):
if
micro_batches
:
self
.
_micro_batches
=
check_int_positive
(
'micro_batches'
,
micro_batches
)
else
:
self
.
_micro_batches
=
None
norm_clip
=
check_param_type
(
'norm_clip'
,
norm_clip
,
float
)
self
.
_norm_clip
=
check_value_positive
(
'norm_clip'
,
norm_clip
)
if
isinstance
(
dp_mech
,
mechanisms
.
Mechanisms
):
self
.
_dp_mech
=
dp_mech
else
:
raise
TypeError
(
'dp mechanisms should be instance of class Mechansms, but got {}'
.
format
(
type
(
dp_mech
)))
float_norm_clip
=
check_param_type
(
'l2_norm_clip'
,
norm_clip
,
float
)
self
.
_norm_clip
=
check_value_positive
(
'l2_norm_clip'
,
float_norm_clip
)
if
mech
is
not
None
and
"DPOptimizer"
in
kwargs
[
'optimizer'
].
__class__
.
__name__
:
raise
ValueError
(
'DPOptimizer is not supported while mech is not None'
)
if
mech
is
None
:
if
"DPOptimizer"
in
kwargs
[
'optimizer'
].
__class__
.
__name__
:
if
context
.
get_context
(
'mode'
)
!=
context
.
PYNATIVE_MODE
:
raise
ValueError
(
'DPOptimizer just support pynative mode currently.'
)
else
:
raise
ValueError
(
'DPModel should set mech or DPOptimizer configure, please refer to example.'
)
self
.
_mech
=
mech
super
(
DPModel
,
self
).
__init__
(
**
kwargs
)
def
_amp_build_train_network
(
self
,
network
,
optimizer
,
loss_fn
=
None
,
level
=
'O0'
,
**
kwargs
):
...
...
@@ -179,14 +181,14 @@ class DPModel(Model):
scale_update_cell
=
update_cell
,
micro_batches
=
self
.
_micro_batches
,
l2_norm_clip
=
self
.
_norm_clip
,
mech
=
self
.
_
dp_
mech
).
set_train
()
mech
=
self
.
_mech
).
set_train
()
return
network
network
=
_TrainOneStepCell
(
network
,
optimizer
,
loss_scale
,
micro_batches
=
self
.
_micro_batches
,
l2_norm_clip
=
self
.
_norm_clip
,
mech
=
self
.
_
dp_
mech
).
set_train
()
mech
=
self
.
_mech
).
set_train
()
return
network
def
_build_train_network
(
self
):
...
...
@@ -244,6 +246,7 @@ class _ClipGradients(nn.Cell):
Outputs:
tuple[Tensor], clipped gradients.
"""
def
__init__
(
self
):
super
(
_ClipGradients
,
self
).
__init__
()
self
.
clip_by_norm
=
nn
.
ClipByNorm
()
...
...
@@ -253,7 +256,8 @@ class _ClipGradients(nn.Cell):
"""
construct a compute flow.
"""
if
clip_type
not
in
(
0
,
1
):
# pylint: disable=consider-using-in
if
clip_type
!=
0
and
clip_type
!=
1
:
return
grads
new_grads
=
()
...
...
@@ -268,6 +272,18 @@ class _ClipGradients(nn.Cell):
return
new_grads
class
_TupleAdd
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
_TupleAdd
,
self
).
__init__
()
self
.
add
=
P
.
TensorAdd
()
self
.
hyper_map
=
C
.
HyperMap
()
def
construct
(
self
,
input1
,
input2
):
"""Add two tuple of data."""
out
=
self
.
hyper_map
(
self
.
add
,
input1
,
input2
)
return
out
class
_TrainOneStepWithLossScaleCell
(
Cell
):
r
"""
Network training with loss scaling.
...
...
@@ -347,6 +363,9 @@ class _TrainOneStepWithLossScaleCell(Cell):
self
.
_split
=
P
.
Split
(
0
,
self
.
_micro_batches
)
self
.
_clip_by_global_norm
=
_ClipGradients
()
self
.
_mech
=
mech
self
.
_tuple_add
=
_TupleAdd
()
self
.
_hyper_map
=
C
.
HyperMap
()
self
.
_micro_float
=
Tensor
(
micro_batches
,
mstype
.
float32
)
def
construct
(
self
,
data
,
label
,
sens
=
None
):
"""
...
...
@@ -368,32 +387,28 @@ class _TrainOneStepWithLossScaleCell(Cell):
weights
=
self
.
weights
record_datas
=
self
.
_split
(
data
)
record_labels
=
self
.
_split
(
label
)
grads
=
()
# first index
loss
=
self
.
network
(
record_datas
[
0
],
record_labels
[
0
])
scaling_sens_filled
=
C
.
ones_like
(
loss
)
*
F
.
cast
(
scaling_sens
,
F
.
dtype
(
loss
))
record_grad
=
self
.
grad
(
self
.
network
,
weights
)(
record_datas
[
0
],
record_labels
[
0
],
scaling_sens_filled
)
record_grad
=
self
.
_clip_by_global_norm
(
record_grad
,
GRADIENT_CLIP_TYPE
,
self
.
_l2_norm
)
grad_sum
=
list
(
record_grad
)
grad_len
=
len
(
record_grad
)
for
i
in
range
(
grad_len
):
grad_sum
[
i
]
=
grad_sum
[
i
].
asnumpy
()
grads
=
record_grad
total_loss
=
loss
for
i
in
range
(
1
,
self
.
_micro_batches
):
loss
=
self
.
network
(
record_datas
[
i
],
record_labels
[
i
])
scaling_sens_filled
=
C
.
ones_like
(
loss
)
*
F
.
cast
(
scaling_sens
,
F
.
dtype
(
loss
))
record_grad
=
self
.
grad
(
self
.
network
,
weights
)(
record_datas
[
i
],
record_labels
[
i
],
scaling_sens_filled
)
record_grad
=
self
.
_clip_by_global_norm
(
record_grad
,
GRADIENT_CLIP_TYPE
,
self
.
_l2_norm
)
for
j
in
range
(
grad_len
):
grad_sum
[
j
]
=
grad_sum
[
j
]
+
record_grad
[
j
].
asnumpy
()
grads
=
self
.
_tuple_add
(
grads
,
record_grad
)
total_loss
=
P
.
TensorAdd
()(
total_loss
,
loss
)
loss
=
P
.
Div
()(
total_loss
,
self
.
_micro_float
)
for
i
in
range
(
grad_len
)
:
grad_
sum
[
i
]
=
Tensor
(
grad_sum
[
i
],
ms
.
float32
)
grads
=
tuple
(
grad_sum
)
loss
=
self
.
network
(
data
,
label
)
if
self
.
_mech
is
not
None
:
grad_
noise
=
self
.
_hyper_map
(
self
.
_mech
,
grads
)
grads
=
self
.
_tuple_add
(
grads
,
grad_noise
)
grads
=
self
.
_hyper_map
(
F
.
partial
(
_grad_scale
,
self
.
_micro_float
),
grads
)
grads
=
self
.
hyper_map
(
F
.
partial
(
grad_scale
,
scaling_sens
),
grads
)
grads
=
self
.
hyper_map
(
F
.
partial
(
_
grad_scale
,
scaling_sens
),
grads
)
# apply grad reducer on grads
grads
=
self
.
grad_reducer
(
grads
)
# get the overflow buffer
...
...
@@ -474,6 +489,9 @@ class _TrainOneStepCell(Cell):
self
.
_split
=
P
.
Split
(
0
,
self
.
_micro_batches
)
self
.
_clip_by_global_norm
=
_ClipGradients
()
self
.
_mech
=
mech
self
.
_tuple_add
=
_TupleAdd
()
self
.
_hyper_map
=
C
.
HyperMap
()
self
.
_micro_float
=
Tensor
(
micro_batches
,
mstype
.
float32
)
def
construct
(
self
,
data
,
label
):
"""
...
...
@@ -486,23 +504,21 @@ class _TrainOneStepCell(Cell):
sens
=
P
.
Fill
()(
P
.
DType
()(
loss
),
P
.
Shape
()(
loss
),
self
.
sens
)
record_grad
=
self
.
grad
(
self
.
network
,
weights
)(
record_datas
[
0
],
record_labels
[
0
],
sens
)
record_grad
=
self
.
_clip_by_global_norm
(
record_grad
,
GRADIENT_CLIP_TYPE
,
self
.
_l2_norm
)
grad_sum
=
list
(
record_grad
)
grad_len
=
len
(
record_grad
)
for
i
in
range
(
grad_len
):
grad_sum
[
i
]
=
grad_sum
[
i
].
asnumpy
()
grads
=
record_grad
total_loss
=
loss
for
i
in
range
(
1
,
self
.
_micro_batches
):
loss
=
self
.
network
(
record_datas
[
i
],
record_labels
[
i
])
sens
=
P
.
Fill
()(
P
.
DType
()(
loss
),
P
.
Shape
()(
loss
),
self
.
sens
)
record_grad
=
self
.
grad
(
self
.
network
,
weights
)(
record_datas
[
i
],
record_labels
[
i
],
sens
)
record_grad
=
self
.
_clip_by_global_norm
(
record_grad
,
GRADIENT_CLIP_TYPE
,
self
.
_l2_norm
)
for
j
in
range
(
grad_len
):
grad_sum
[
j
]
=
grad_sum
[
j
]
+
record_grad
[
j
].
asnumpy
()
for
i
in
range
(
grad_len
):
grad_sum
[
i
]
=
Tensor
(
grad_sum
[
i
],
ms
.
float32
)
grads
=
tuple
(
grad_sum
)
loss
=
self
.
network
(
data
,
label
)
grads
=
self
.
_tuple_add
(
grads
,
record_grad
)
total_loss
=
P
.
TensorAdd
()(
total_loss
,
loss
)
loss
=
P
.
Div
()(
total_loss
,
self
.
_micro_float
)
if
self
.
_mech
is
not
None
:
grad_noise
=
self
.
_hyper_map
(
self
.
_mech
,
grads
)
grads
=
self
.
_tuple_add
(
grads
,
grad_noise
)
grads
=
self
.
_hyper_map
(
F
.
partial
(
_grad_scale
,
self
.
_micro_float
),
grads
)
if
self
.
reducer_flag
:
# apply grad reducer on grads
...
...
setup.py
浏览文件 @
fe97f43f
...
...
@@ -18,7 +18,7 @@ from setuptools import setup
from
setuptools.command.egg_info
import
egg_info
from
setuptools.command.build_py
import
build_py
version
=
'0.
3
.0'
version
=
'0.
5
.0'
cur_dir
=
os
.
path
.
dirname
(
os
.
path
.
realpath
(
__file__
))
pkg_dir
=
os
.
path
.
join
(
cur_dir
,
'build'
)
...
...
tests/ut/python/diff_privacy/test_mechanisms.py
浏览文件 @
fe97f43f
...
...
@@ -17,6 +17,8 @@ different Privacy test.
import
pytest
from
mindspore
import
context
from
mindspore
import
Tensor
from
mindspore.common
import
dtype
as
mstype
from
mindarmour.diff_privacy
import
GaussianRandom
from
mindarmour.diff_privacy
import
AdaGaussianRandom
from
mindarmour.diff_privacy
import
MechanismsFactory
...
...
@@ -26,13 +28,13 @@ from mindarmour.diff_privacy import MechanismsFactory
@
pytest
.
mark
.
platform_x86_ascend_training
@
pytest
.
mark
.
env_onecard
@
pytest
.
mark
.
component_mindarmour
def
test_gaussian
():
context
.
set_context
(
mode
=
context
.
PYNATIVE
_MODE
,
device_target
=
"Ascend"
)
shape
=
(
3
,
2
,
4
)
def
test_g
raph_g
aussian
():
context
.
set_context
(
mode
=
context
.
GRAPH
_MODE
,
device_target
=
"Ascend"
)
grad
=
Tensor
([
3
,
2
,
4
],
mstype
.
float32
)
norm_bound
=
1.0
initial_noise_multiplier
=
0.1
net
=
GaussianRandom
(
norm_bound
,
initial_noise_multiplier
)
res
=
net
(
shape
)
res
=
net
(
grad
)
print
(
res
)
...
...
@@ -40,42 +42,99 @@ def test_gaussian():
@
pytest
.
mark
.
platform_x86_ascend_training
@
pytest
.
mark
.
env_onecard
@
pytest
.
mark
.
component_mindarmour
def
test_
ada
_gaussian
():
def
test_
pynative
_gaussian
():
context
.
set_context
(
mode
=
context
.
PYNATIVE_MODE
,
device_target
=
"Ascend"
)
shape
=
(
3
,
2
,
4
)
grad
=
Tensor
([
3
,
2
,
4
],
mstype
.
float32
)
norm_bound
=
1.0
initial_noise_multiplier
=
0.1
net
=
GaussianRandom
(
norm_bound
,
initial_noise_multiplier
)
res
=
net
(
grad
)
print
(
res
)
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_x86_ascend_training
@
pytest
.
mark
.
env_onecard
@
pytest
.
mark
.
component_mindarmour
def
test_graph_ada_gaussian
():
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
)
grad
=
Tensor
([
3
,
2
,
4
],
mstype
.
float32
)
norm_bound
=
1.0
initial_noise_multiplier
=
0.1
noise_decay_rate
=
0.5
decay_policy
=
"Step"
alpha
=
0.5
decay_policy
=
'Step'
net
=
AdaGaussianRandom
(
norm_bound
,
initial_noise_multiplier
,
noise_decay_rate
,
decay_policy
)
res
=
net
(
shape
)
noise_decay_rate
=
alpha
,
decay_policy
=
decay_policy
)
res
=
net
(
grad
)
print
(
res
)
def
test_factory
():
context
.
set_context
(
mode
=
context
.
PYNATIVE_MODE
,
device_target
=
"Ascend"
)
shape
=
(
3
,
2
,
4
)
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_x86_ascend_training
@
pytest
.
mark
.
env_onecard
@
pytest
.
mark
.
component_mindarmour
def
test_graph_factory
():
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
)
grad
=
Tensor
([
3
,
2
,
4
],
mstype
.
float32
)
norm_bound
=
1.0
initial_noise_multiplier
=
0.1
noise_decay_rate
=
0.5
decay_policy
=
"Step"
alpha
=
0.5
decay_policy
=
'Step'
noise_mechanism
=
MechanismsFactory
()
noise_construct
=
noise_mechanism
.
create
(
'Gaussian'
,
norm_bound
,
initial_noise_multiplier
)
noise
=
noise_construct
(
shape
)
noise
=
noise_construct
(
grad
)
print
(
'Gaussian noise: '
,
noise
)
ada_mechanism
=
MechanismsFactory
()
ada_noise_construct
=
ada_mechanism
.
create
(
'AdaGaussian'
,
norm_bound
,
initial_noise_multiplier
,
noise_decay_rate
,
decay_policy
)
ada_noise
=
ada_noise_construct
(
shape
)
noise_decay_rate
=
alpha
,
decay_policy
=
decay_policy
)
ada_noise
=
ada_noise_construct
(
grad
)
print
(
'ada noise: '
,
ada_noise
)
if
__name__
==
'__main__'
:
# device_target can be "CPU", "GPU" or "Ascend"
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_x86_ascend_training
@
pytest
.
mark
.
env_onecard
@
pytest
.
mark
.
component_mindarmour
def
test_pynative_ada_gaussian
():
context
.
set_context
(
mode
=
context
.
PYNATIVE_MODE
,
device_target
=
"Ascend"
)
grad
=
Tensor
([
3
,
2
,
4
],
mstype
.
float32
)
norm_bound
=
1.0
initial_noise_multiplier
=
0.1
alpha
=
0.5
decay_policy
=
'Step'
net
=
AdaGaussianRandom
(
norm_bound
,
initial_noise_multiplier
,
noise_decay_rate
=
alpha
,
decay_policy
=
decay_policy
)
res
=
net
(
grad
)
print
(
res
)
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_x86_ascend_training
@
pytest
.
mark
.
env_onecard
@
pytest
.
mark
.
component_mindarmour
def
test_pynative_factory
():
context
.
set_context
(
mode
=
context
.
PYNATIVE_MODE
,
device_target
=
"Ascend"
)
grad
=
Tensor
([
3
,
2
,
4
],
mstype
.
float32
)
norm_bound
=
1.0
initial_noise_multiplier
=
0.1
alpha
=
0.5
decay_policy
=
'Step'
noise_mechanism
=
MechanismsFactory
()
noise_construct
=
noise_mechanism
.
create
(
'Gaussian'
,
norm_bound
,
initial_noise_multiplier
)
noise
=
noise_construct
(
grad
)
print
(
'Gaussian noise: '
,
noise
)
ada_mechanism
=
MechanismsFactory
()
ada_noise_construct
=
ada_mechanism
.
create
(
'AdaGaussian'
,
norm_bound
,
initial_noise_multiplier
,
noise_decay_rate
=
alpha
,
decay_policy
=
decay_policy
)
ada_noise
=
ada_noise_construct
(
grad
)
print
(
'ada noise: '
,
ada_noise
)
tests/ut/python/diff_privacy/test_model_train.py
浏览文件 @
fe97f43f
...
...
@@ -21,13 +21,15 @@ from mindspore import nn
from
mindspore
import
context
import
mindspore.dataset
as
ds
from
mindarmour.diff_privacy
import
DPOptimizerClassFactory
from
mindarmour.diff_privacy
import
DPModel
from
mindarmour.diff_privacy
import
MechanismsFactory
from
mindarmour.diff_privacy
import
DPOptimizerClassFactory
from
test_network
import
LeNet5
def
dataset_generator
(
batch_size
,
batches
):
"""mock training data."""
data
=
np
.
random
.
random
((
batches
*
batch_size
,
1
,
32
,
32
)).
astype
(
np
.
float32
)
label
=
np
.
random
.
randint
(
0
,
10
,
batches
*
batch_size
).
astype
(
np
.
int32
)
for
i
in
range
(
batches
):
...
...
@@ -39,7 +41,7 @@ def dataset_generator(batch_size, batches):
@
pytest
.
mark
.
platform_x86_ascend_training
@
pytest
.
mark
.
env_card
@
pytest
.
mark
.
component_mindarmour
def
test_dp_model
():
def
test_dp_model
_pynative_mode
():
context
.
set_context
(
mode
=
context
.
PYNATIVE_MODE
,
device_target
=
"Ascend"
)
l2_norm_bound
=
1.0
initial_noise_multiplier
=
0.01
...
...
@@ -47,21 +49,50 @@ def test_dp_model():
batch_size
=
32
batches
=
128
epochs
=
1
micro_batches
=
2
loss
=
nn
.
SoftmaxCrossEntropyWithLogits
(
is_grad
=
False
,
sparse
=
True
)
factory_opt
=
DPOptimizerClassFactory
(
micro_batches
=
micro_batches
)
factory_opt
.
set_mechanisms
(
'Gaussian'
,
norm_bound
=
l2_norm_bound
,
initial_noise_multiplier
=
initial_noise_multiplier
)
net_opt
=
factory_opt
.
create
(
'Momentum'
)(
network
.
trainable_params
(),
learning_rate
=
0.1
,
momentum
=
0.9
)
model
=
DPModel
(
micro_batches
=
micro_batches
,
norm_clip
=
l2_norm_bound
,
mech
=
None
,
network
=
network
,
loss_fn
=
loss
,
optimizer
=
net_opt
,
metrics
=
None
)
ms_ds
=
ds
.
GeneratorDataset
(
dataset_generator
(
batch_size
,
batches
),
[
'data'
,
'label'
])
ms_ds
.
set_dataset_size
(
batch_size
*
batches
)
model
.
train
(
epochs
,
ms_ds
,
dataset_sink_mode
=
False
)
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_arm_ascend_training
@
pytest
.
mark
.
platform_x86_ascend_training
@
pytest
.
mark
.
env_card
@
pytest
.
mark
.
component_mindarmour
def
test_dp_model_with_graph_mode
():
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
)
l2_norm_bound
=
1.0
initial_noise_multiplier
=
0.01
network
=
LeNet5
()
batch_size
=
32
batches
=
128
epochs
=
1
loss
=
nn
.
SoftmaxCrossEntropyWithLogits
(
is_grad
=
False
,
sparse
=
True
)
gaussian_mech
=
DPOptimizerClassFactory
(
micro_batches
=
2
)
gaussian_mech
.
set_mechanisms
(
'Gaussian'
,
norm_bound
=
l2_norm_bound
,
initial_noise_multiplier
=
initial_noise_multiplier
)
net_opt
=
gaussian_mech
.
create
(
'SGD'
)(
params
=
network
.
trainable_params
(),
learning_rate
=
0.1
,
momentum
=
0.9
)
mech
=
MechanismsFactory
().
create
(
'Gaussian'
,
norm_bound
=
l2_norm_bound
,
initial_noise_multiplier
=
initial_noise_multiplier
)
net_opt
=
nn
.
Momentum
(
network
.
trainable_params
(),
learning_rate
=
0.1
,
momentum
=
0.9
)
model
=
DPModel
(
micro_batches
=
2
,
norm_clip
=
l2_norm_bound
,
dp_mech
=
gaussian_mech
.
mech
,
mech
=
mech
,
network
=
network
,
loss_fn
=
loss
,
optimizer
=
net_opt
,
metrics
=
None
)
ms_ds
=
ds
.
GeneratorDataset
(
dataset_generator
(
batch_size
,
batches
),
[
'data'
,
'label'
])
ms_ds
.
set_dataset_size
(
batch_size
*
batches
)
model
.
train
(
epochs
,
ms_ds
)
model
.
train
(
epochs
,
ms_ds
,
dataset_sink_mode
=
False
)
tests/ut/python/diff_privacy/test_optimizer.py
浏览文件 @
fe97f43f
...
...
@@ -21,6 +21,7 @@ from mindarmour.diff_privacy import DPOptimizerClassFactory
from
test_network
import
LeNet5
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_arm_ascend_training
@
pytest
.
mark
.
platform_x86_ascend_training
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录