Commit 79c6403d
Authored on July 14, 2020 by ZhidanLiu

add new feature: adaptive clipping

Parent: ac39d193

Showing 8 changed files with 676 additions and 218 deletions (+676 -218)
example/mnist_demo/lenet5_config.py                  +8    -3
example/mnist_demo/lenet5_dp.py                      +42   -19
mindarmour/diff_privacy/__init__.py                  +8    -4
mindarmour/diff_privacy/mechanisms/mechanisms.py     +194  -48
mindarmour/diff_privacy/optimizer/optimizer.py       +2    -2
mindarmour/diff_privacy/train/model.py               +257  -110
tests/ut/python/diff_privacy/test_mechanisms.py      +116  -10
tests/ut/python/diff_privacy/test_model_train.py     +49   -22
example/mnist_demo/lenet5_config.py

...
@@ -20,7 +20,7 @@ from easydict import EasyDict as edict
 mnist_cfg = edict({
     'num_classes': 10,  # the number of classes of model's output
-    'lr': 0.1,  # the learning rate of model's optimizer
+    'lr': 0.01,  # the learning rate of model's optimizer
     'momentum': 0.9,  # the momentum value of model's optimizer
     'epoch_size': 10,  # training epochs
     'batch_size': 256,  # batch size for training
...
@@ -33,8 +33,13 @@ mnist_cfg = edict({
     'dataset_sink_mode': False,  # whether deliver all training data to device one time
     'micro_batches': 16,  # the number of small batches split from an original batch
     'norm_clip': 1.0,  # the clip bound of the gradients of model's training parameters
-    'initial_noise_multiplier': 1.5,  # the initial multiplication coefficient of the noise added to training
-                                      # parameters' gradients
-    'mechanisms': 'AdaGaussian',  # the method of adding noise in gradients while training
+    'initial_noise_multiplier': 0.5,  # the initial multiplication coefficient of the noise added to training
+                                      # parameters' gradients
+    'noise_mechanisms': 'AdaGaussian',  # the method of adding noise in gradients while training
+    'clip_mechanisms': 'Gaussian',  # the method of adaptive clipping gradients while training
+    'clip_decay_policy': 'Linear',  # Decay policy of adaptive clipping, decay_policy must be in ['Linear', 'Geometric'].
+    'clip_learning_rate': 0.001,  # Learning rate of update norm clip.
+    'target_unclipped_quantile': 0.9,  # Target quantile of norm clip.
+    'fraction_stddev': 0.01,  # The stddev of Gaussian normal which used in empirical_fraction.
     'optimizer': 'Momentum'  # the base optimizer used for Differential privacy training
 })
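The new clip_* keys above feed the adaptive clipping mechanism introduced in mechanisms.py further down. As a quick, hedged illustration of what one 'Linear' update step does with these configured values (the empirical fraction 0.8 below is an invented example number, not something produced by this config file):

# Plain-Python sketch of a single 'Linear' adaptive-clipping step using the
# values from lenet5_config.py; empirical_fraction is a made-up illustration.
clip_learning_rate = 0.001
target_unclipped_quantile = 0.9
norm_clip = 1.0

empirical_fraction = 0.8  # fraction of micro-batch gradients whose l2 norm stayed under norm_clip
next_norm_clip = norm_clip - clip_learning_rate * (empirical_fraction - target_unclipped_quantile)
print(next_norm_clip)  # 1.0001: fewer gradients fit under the bound than targeted, so it is loosened slightly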
example/mnist_demo/lenet5_dp.py

...
@@ -31,7 +31,8 @@ import mindspore.common.dtype as mstype
 from mindarmour.diff_privacy import DPModel
 from mindarmour.diff_privacy import PrivacyMonitorFactory
-from mindarmour.diff_privacy import MechanismsFactory
+from mindarmour.diff_privacy import NoiseMechanismsFactory
+from mindarmour.diff_privacy import ClipMechanismsFactory
 from mindarmour.utils.logger import LogUtil
 from lenet5_net import LeNet5
 from lenet5_config import mnist_cfg as cfg
...
@@ -87,11 +88,14 @@ def generate_mnist_dataset(data_path, batch_size=32, repeat_size=1,
 if __name__ == "__main__":
     # This configure can run both in pynative mode and graph mode
     context.set_context(mode=context.GRAPH_MODE, device_target=cfg.device_target)
     network = LeNet5()
-    net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
+    net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True,
+                                                reduction="mean")
     config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,
                                  keep_checkpoint_max=cfg.keep_checkpoint_max)
     ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet",
                                  directory='./trained_ckpt_file/',
                                  config=config_ck)
...
@@ -102,17 +106,33 @@ if __name__ == "__main__":
                              cfg.epoch_size)
     if cfg.micro_batches and cfg.batch_size % cfg.micro_batches != 0:
         raise ValueError("Number of micro_batches should divide evenly batch_size")
-    # Create a factory class of DP mechanisms, this method is adding noise in gradients while training.
-    # Initial_noise_multiplier is suggested to be greater than 1.0, otherwise the privacy budget would be huge, which
-    # means that the privacy protection effect is weak. Mechanisms can be 'Gaussian' or 'AdaGaussian', in which noise
-    # would be decayed with 'AdaGaussian' mechanism while be constant with 'Gaussian' mechanism.
-    mech = MechanismsFactory().create(cfg.mechanisms,
-                                      norm_bound=cfg.norm_clip,
-                                      initial_noise_multiplier=cfg.initial_noise_multiplier)
-    net_opt = nn.Momentum(params=network.trainable_params(), learning_rate=cfg.lr, momentum=cfg.momentum)
-    # Create a monitor for DP training. The function of the monitor is to compute and print the privacy budget(eps
-    # and delta) while training.
+    # Create a factory class of DP noise mechanisms, this method is adding noise
+    # in gradients while training. Initial_noise_multiplier is suggested to be
+    # greater than 1.0, otherwise the privacy budget would be huge, which means
+    # that the privacy protection effect is weak. Mechanisms can be 'Gaussian'
+    # or 'AdaGaussian', in which noise would be decayed with 'AdaGaussian'
+    # mechanism while be constant with 'Gaussian' mechanism.
+    noise_mech = NoiseMechanismsFactory().create(cfg.noise_mechanisms,
+                                                 norm_bound=cfg.norm_clip,
+                                                 initial_noise_multiplier=cfg.initial_noise_multiplier)
+    # Create a factory class of clip mechanisms, this method is to adaptive clip
+    # gradients while training, decay_policy support 'Linear' and 'Geometric',
+    # learning_rate is the learning rate to update clip_norm,
+    # target_unclipped_quantile is the target quantile of norm clip,
+    # fraction_stddev is the stddev of Gaussian normal which used in
+    # empirical_fraction, the formula is
+    # $empirical_fraction + N(0, fraction_stddev)$.
+    clip_mech = ClipMechanismsFactory().create(cfg.clip_mechanisms,
+                                               decay_policy=cfg.clip_decay_policy,
+                                               learning_rate=cfg.clip_learning_rate,
+                                               target_unclipped_quantile=cfg.target_unclipped_quantile,
+                                               fraction_stddev=cfg.fraction_stddev)
     net_opt = nn.Momentum(params=network.trainable_params(), learning_rate=cfg.lr, momentum=cfg.momentum)
+    # Create a monitor for DP training. The function of the monitor is to
+    # compute and print the privacy budget(eps and delta) while training.
     rdp_monitor = PrivacyMonitorFactory.create('rdp',
                                                num_samples=60000,
                                                batch_size=cfg.batch_size,
...
@@ -121,20 +141,23 @@ if __name__ == "__main__":
     # Create the DP model for training.
     model = DPModel(micro_batches=cfg.micro_batches,
                     norm_clip=cfg.norm_clip,
-                    mech=mech,
+                    noise_mech=noise_mech,
+                    clip_mech=clip_mech,
                     network=network,
                     loss_fn=net_loss,
                     optimizer=net_opt,
                     metrics={"Accuracy": Accuracy()})

     LOGGER.info(TAG, "============== Starting Training ==============")
     model.train(cfg['epoch_size'], ds_train, callbacks=[ckpoint_cb, LossMonitor(), rdp_monitor],
                 dataset_sink_mode=cfg.dataset_sink_mode)

     LOGGER.info(TAG, "============== Starting Testing ==============")
     ckpt_file_name = 'trained_ckpt_file/checkpoint_lenet-10_234.ckpt'
     param_dict = load_checkpoint(ckpt_file_name)
     load_param_into_net(network, param_dict)
     ds_eval = generate_mnist_dataset(os.path.join(cfg.data_path, 'test'), batch_size=cfg.batch_size)
     acc = model.eval(ds_eval, dataset_sink_mode=False)
     LOGGER.info(TAG, "============== Accuracy: %s ==============", acc)
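To summarize what the script above now wires together, the per-step behaviour of DP training with both mechanisms can be sketched outside MindSpore. This is a conceptual NumPy sketch, not the library implementation: the gradients are invented stand-ins, the clipping rule is a simplified per-micro-batch l2 clip, and the Gaussian noise standard deviation is taken as initial_noise_multiplier * norm_clip, matching the stddev argument the noise mechanism passes to P.Normal in mechanisms.py.

import numpy as np

rng = np.random.default_rng(0)
norm_clip = 1.0            # cfg.norm_clip
noise_multiplier = 0.5     # cfg.initial_noise_multiplier
micro_batches = 16         # cfg.micro_batches

# Stand-ins for the per-micro-batch gradients computed inside DPModel.
grads = [rng.normal(size=10) for _ in range(micro_batches)]

clipped = []
for g in grads:
    norm = np.linalg.norm(g)
    clipped.append(g * min(1.0, norm_clip / norm))  # clip each micro-batch gradient to l2 norm <= norm_clip

summed = np.sum(clipped, axis=0)
noisy = summed + rng.normal(0.0, noise_multiplier * norm_clip, size=summed.shape)  # add Gaussian noise
dp_grad = noisy / micro_batches  # average over micro-batches before the optimizer step
print(dp_grad[:3])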
mindarmour/diff_privacy/__init__.py

 """
 This module provide Differential Privacy feature to protect user privacy.
 """
-from .mechanisms.mechanisms import GaussianRandom
+from .mechanisms.mechanisms import NoiseGaussianRandom
 from .mechanisms.mechanisms import AdaGaussianRandom
-from .mechanisms.mechanisms import MechanismsFactory
+from .mechanisms.mechanisms import AdaClippingWithGaussianRandom
+from .mechanisms.mechanisms import NoiseMechanismsFactory
+from .mechanisms.mechanisms import ClipMechanismsFactory
 from .monitor.monitor import PrivacyMonitorFactory
 from .optimizer.optimizer import DPOptimizerClassFactory
 from .train.model import DPModel

-__all__ = ['GaussianRandom',
+__all__ = ['NoiseGaussianRandom',
            'AdaGaussianRandom',
-           'MechanismsFactory',
+           'AdaClippingWithGaussianRandom',
+           'NoiseMechanismsFactory',
+           'ClipMechanismsFactory',
            'PrivacyMonitorFactory',
            'DPOptimizerClassFactory',
            'DPModel']
mindarmour/diff_privacy/mechanisms/mechanisms.py

...
@@ -28,11 +28,54 @@ from mindarmour.utils._check_param import check_param_in_range
 from mindarmour.utils.logger import LogUtil

 LOGGER = LogUtil.get_instance()
-TAG = 'Defense'
+TAG = 'NoiseMechanism'


-class MechanismsFactory:
-    """ Factory class of mechanisms"""
+class ClipMechanismsFactory:
+    """ Factory class of clip mechanisms"""
+
+    def __init__(self):
+        pass
+
+    @staticmethod
+    def create(mech_name, *args, **kwargs):
+        """
+        Args:
+            mech_name(str): Clip noise generated strategy, support 'Gaussian' now.
+            args(Union[float, str]): Parameters used for creating clip mechanisms.
+            kwargs(Union[float, str]): Parameters used for creating clip
+                mechanisms.
+
+        Raises:
+            NameError: `mech_name` must be in ['Gaussian'].
+
+        Returns:
+            Mechanisms, class of noise generated Mechanism.
+
+        Examples:
+            >>> decay_policy = 'Linear'
+            >>> beta = Tensor(0.5, mstype.float32)
+            >>> norm_clip = Tensor(1.0, mstype.float32)
+            >>> beta_stddev = 0.1
+            >>> learning_rate = 0.1
+            >>> target_unclipped_quantile = 0.3
+            >>> clip_mechanism = ClipMechanismsFactory()
+            >>> ada_clip = clip_mechanism.create('Gaussian',
+            >>>                                  decay_policy=decay_policy,
+            >>>                                  learning_rate=learning_rate,
+            >>>                                  target_unclipped_quantile=target_unclipped_quantile,
+            >>>                                  fraction_stddev=beta_stddev)
+            >>> next_norm_clip = ada_clip(beta, norm_clip)
+        """
+        if mech_name == 'Gaussian':
+            return AdaClippingWithGaussianRandom(*args, **kwargs)
+        raise NameError("The {} is not implement, please choose "
+                        "['Gaussian']".format(mech_name))
+
+
+class NoiseMechanismsFactory:
+    """ Factory class of noise mechanisms"""

     def __init__(self):
         pass
...
@@ -56,42 +99,38 @@ class MechanismsFactory:
             Mechanisms, class of noise generated Mechanism.

         Examples:
-            >>> class Net(nn.Cell):
-            >>>     def __init__(self):
-            >>>         super(Net, self).__init__()
-            >>>         self.conv = nn.Conv2d(3, 64, 3, has_bias=False, weight_init='normal')
-            >>>         self.bn = nn.BatchNorm2d(64)
-            >>>         self.relu = nn.ReLU()
-            >>>         self.flatten = nn.Flatten()
-            >>>         self.fc = nn.Dense(64*224*224, 12) # padding=0
-            >>>
-            >>>     def construct(self, x):
-            >>>         x = self.conv(x)
-            >>>         x = self.bn(x)
-            >>>         x = self.relu(x)
-            >>>         x = self.flatten(x)
-            >>>         out = self.fc(x)
-            >>>         return out
             >>> norm_clip = 1.0
-            >>> initial_noise_multiplier = 1.5
-            >>> net = Net()
+            >>> initial_noise_multiplier = 0.01
+            >>> network = LeNet5()
+            >>> batch_size = 32
+            >>> batches = 128
+            >>> epochs = 1
             >>> loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
-            >>> net_opt = Momentum(params=net.trainable_params(), learning_rate=0.01, momentum=0.9)
-            >>> mech = MechanismsFactory().create('Gaussian',
-            >>>                                   norm_bound=norm_clip,
-            >>>                                   initial_noise_multiplier=initial_noise_multiplier)
+            >>> noise_mech = NoiseMechanismsFactory().create('Gaussian',
+            >>>                                              norm_bound=norm_clip,
+            >>>                                              initial_noise_multiplier=initial_noise_multiplier)
+            >>> clip_mech = ClipMechanismsFactory().create('Gaussian',
+            >>>                                            decay_policy='Linear',
+            >>>                                            learning_rate=0.01,
+            >>>                                            target_unclipped_quantile=0.9,
+            >>>                                            fraction_stddev=0.01)
+            >>> net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.1,
+            >>>                       momentum=0.9)
             >>> model = DPModel(micro_batches=2,
-            >>>                 norm_clip=1.0,
-            >>>                 mech=mech,
-            >>>                 network=net,
+            >>>                 clip_mech=clip_mech,
+            >>>                 norm_clip=norm_clip,
+            >>>                 noise_mech=noise_mech,
+            >>>                 network=network,
             >>>                 loss_fn=loss,
             >>>                 optimizer=net_opt,
             >>>                 metrics=None)
-            >>> dataset = get_dataset()
-            >>> model.train(2, dataset)
+            >>> ms_ds = ds.GeneratorDataset(dataset_generator(batch_size, batches),
+            >>>                             ['data', 'label'])
+            >>> ms_ds.set_dataset_size(batch_size * batches)
+            >>> model.train(epochs, ms_ds, dataset_sink_mode=False)
         """
         if policy == 'Gaussian':
-            return GaussianRandom(*args, **kwargs)
+            return NoiseGaussianRandom(*args, **kwargs)
         if policy == 'AdaGaussian':
             return AdaGaussianRandom(*args, **kwargs)
         raise NameError("The {} is not implement, please choose "
...
@@ -110,7 +149,7 @@ class Mechanisms(Cell):
     """


-class GaussianRandom(Mechanisms):
+class NoiseGaussianRandom(Mechanisms):
     """
     Gaussian noise generated mechanism.
...
@@ -133,18 +172,21 @@ class GaussianRandom(Mechanisms):
         >>> gradients = Tensor([0.2, 0.9], mstype.float32)
         >>> norm_bound = 0.5
         >>> initial_noise_multiplier = 1.5
-        >>> net = GaussianRandom(norm_bound, initial_noise_multiplier)
+        >>> net = NoiseGaussianRandom(norm_bound, initial_noise_multiplier)
        >>> res = net(gradients)
         >>> print(res)
     """

-    def __init__(self, norm_bound=0.5, initial_noise_multiplier=1.5, seed=0, policy=None):
-        super(GaussianRandom, self).__init__()
+    def __init__(self, norm_bound=0.5, initial_noise_multiplier=1.5, seed=0,
+                 policy=None):
+        super(NoiseGaussianRandom, self).__init__()
         self._norm_bound = check_value_positive('norm_bound', norm_bound)
         self._norm_bound = Tensor(norm_bound, mstype.float32)
-        self._initial_noise_multiplier = check_value_positive('initial_noise_multiplier',
-                                                               initial_noise_multiplier)
+        self._initial_noise_multiplier = check_value_positive(
+            'initial_noise_multiplier', initial_noise_multiplier)
         self._initial_noise_multiplier = Tensor(initial_noise_multiplier, mstype.float32)
         self._mean = Tensor(0, mstype.float32)
         self._normal = P.Normal(seed=seed)
         self._decay_policy = policy
...
@@ -201,17 +243,20 @@ class AdaGaussianRandom(Mechanisms):
                  noise_decay_rate=6e-4, decay_policy='Time', seed=0):
         super(AdaGaussianRandom, self).__init__()
         norm_bound = check_value_positive('norm_bound', norm_bound)
-        initial_noise_multiplier = check_value_positive('initial_noise_multiplier',
-                                                        initial_noise_multiplier)
+        initial_noise_multiplier = check_value_positive(
+            'initial_noise_multiplier', initial_noise_multiplier)
         self._norm_bound = Tensor(norm_bound, mstype.float32)

         initial_noise_multiplier = Tensor(initial_noise_multiplier, mstype.float32)
         self._initial_noise_multiplier = Parameter(initial_noise_multiplier,
                                                    name='initial_noise_multiplier')
         self._noise_multiplier = Parameter(initial_noise_multiplier,
                                            name='noise_multiplier')
         self._mean = Tensor(0, mstype.float32)
         noise_decay_rate = check_param_type('noise_decay_rate', noise_decay_rate, float)
         check_param_in_range('noise_decay_rate', noise_decay_rate, 0.0, 1.0)
         self._noise_decay_rate = Tensor(noise_decay_rate, mstype.float32)
         if decay_policy not in ['Time', 'Step', 'Exp']:
...
@@ -232,7 +277,9 @@ class AdaGaussianRandom(Mechanisms):
             Tensor, generated noise with shape like given gradients.
         """
         shape = P.Shape()(gradients)
-        noise = self._normal(shape, self._mean, self._mul(self._noise_multiplier, self._norm_bound))
+        noise = self._normal(shape, self._mean,
+                             self._mul(self._noise_multiplier,
+                                       self._norm_bound))
         return noise
...
@@ -241,10 +288,14 @@ class _MechanismsParamsUpdater(Cell):
     Update mechanisms parameters, the parameters will refresh in train period.

     Args:
-        policy(str): Pass in by the mechanisms class, mechanisms parameters update policy.
-        decay_rate(Tensor): Pass in by the mechanisms class, hyper parameter for controlling the decay size.
-        cur_noise_multiplier(Parameter): Pass in by the mechanisms class, current params value in this time.
-        init_noise_multiplier(Parameter):Pass in by the mechanisms class, initial params value to be updated.
+        policy(str): Pass in by the mechanisms class, mechanisms parameters
+            update policy.
+        decay_rate(Tensor): Pass in by the mechanisms class, hyper parameter for
+            controlling the decay size.
+        cur_noise_multiplier(Parameter): Pass in by the mechanisms class,
+            current params value in this time.
+        init_noise_multiplier(Parameter):Pass in by the mechanisms class,
+            initial params value to be updated.

     Returns:
         Tuple, next params value.
...
@@ -281,5 +332,100 @@ class _MechanismsParamsUpdater(Cell):
             next_noise_multiplier = self._assign(self._cur_noise_multiplier,
                                                  self._mul(temp, self._cur_noise_multiplier))
         else:
             next_noise_multiplier = self._assign(self._cur_noise_multiplier,
                                                  self._div(self._one, self._exp(self._one)))
         return next_noise_multiplier
+
+
+class AdaClippingWithGaussianRandom(Cell):
+    """
+    Adaptive clipping. If `decay_policy` is 'Linear', the update formula is
+    $norm_clip = norm_clip - learning_rate*(beta - target_unclipped_quantile)$.
+    If `decay_policy` is 'Geometric', the update formula is
+    $norm_clip = norm_clip*exp(-learning_rate*(empirical_fraction - target_unclipped_quantile))$,
+    where beta is the empirical fraction of samples with the value at most
+    `target_unclipped_quantile`.
+
+    Args:
+        decay_policy(str): Decay policy of adaptive clipping, decay_policy must
+            be in ['Linear', 'Geometric']. Default: Linear.
+        learning_rate(float): Learning rate of update norm clip. Default: 0.01.
+        target_unclipped_quantile(float): Target quantile of norm clip. Default: 0.9.
+        fraction_stddev(float): The stddev of Gaussian normal which used in
+            empirical_fraction, the formula is $empirical_fraction + N(0, fraction_stddev)$.
+        seed(int): Original random seed, if seed=0 random normal will use secure
+            random number. If seed!=0 random normal will generate values using
+            given seed. Default: 0.
+
+    Returns:
+        Tensor, updated norm clip.
+
+    Examples:
+        >>> decay_policy = 'Linear'
+        >>> beta = Tensor(0.5, mstype.float32)
+        >>> norm_clip = Tensor(1.0, mstype.float32)
+        >>> beta_stddev = 0.01
+        >>> learning_rate = 0.001
+        >>> target_unclipped_quantile = 0.9
+        >>> ada_clip = AdaClippingWithGaussianRandom(decay_policy=decay_policy,
+        >>>                                          learning_rate=learning_rate,
+        >>>                                          target_unclipped_quantile=target_unclipped_quantile,
+        >>>                                          fraction_stddev=beta_stddev)
+        >>> next_norm_clip = ada_clip(beta, norm_clip)
+    """
+
+    def __init__(self, decay_policy='Linear', learning_rate=0.001,
+                 target_unclipped_quantile=0.9, fraction_stddev=0.01, seed=0):
+        super(AdaClippingWithGaussianRandom, self).__init__()
+        if decay_policy not in ['Linear', 'Geometric']:
+            msg = "decay policy of adaptive clip must be in ['Linear', 'Geometric'], " \
+                  "but got: {}".format(decay_policy)
+            LOGGER.error(TAG, msg)
+            raise ValueError(msg)
+        self._decay_policy = decay_policy
+        learning_rate = check_param_type('learning_rate', learning_rate, float)
+        learning_rate = check_value_positive('learning_rate', learning_rate)
+        self._learning_rate = Tensor(learning_rate, mstype.float32)
+        fraction_stddev = check_param_type('fraction_stddev', fraction_stddev, float)
+        self._fraction_stddev = Tensor(fraction_stddev, mstype.float32)
+        target_unclipped_quantile = check_param_type('target_unclipped_quantile',
+                                                     target_unclipped_quantile,
+                                                     float)
+        self._target_unclipped_quantile = Tensor(target_unclipped_quantile,
+                                                 mstype.float32)
+
+        self._zero = Tensor(0, mstype.float32)
+        self._add = P.TensorAdd()
+        self._sub = P.Sub()
+        self._mul = P.Mul()
+        self._exp = P.Exp()
+        self._normal = P.Normal(seed=seed)
+
+    def construct(self, empirical_fraction, norm_clip):
+        """
+        Update value of norm_clip.
+
+        Args:
+            empirical_fraction(Tensor): empirical fraction of samples with the
+                value at most `target_unclipped_quantile`.
+            norm_clip(Tensor): Clipping bound for the l2 norm of the gradients.
+
+        Returns:
+            Tensor, generated noise with shape like given gradients.
+        """
+        fraction_noise = self._normal((1,), self._zero, self._fraction_stddev)
+        empirical_fraction = self._add(empirical_fraction, fraction_noise)
+        if self._decay_policy == 'Linear':
+            grad_clip = self._sub(empirical_fraction,
+                                  self._target_unclipped_quantile)
+            next_norm_clip = self._sub(norm_clip,
+                                       self._mul(self._learning_rate, grad_clip))
+        # decay_policy == 'Geometric'
+        else:
+            grad_clip = self._sub(empirical_fraction,
+                                  self._target_unclipped_quantile)
+            grad_clip = self._exp(self._mul(-self._learning_rate, grad_clip))
+            next_norm_clip = self._mul(norm_clip, grad_clip)
+        return next_norm_clip
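The two decay_policy branches of AdaClippingWithGaussianRandom are compact enough to restate outside the MindSpore graph. The following plain-Python paraphrase of the update rules from the docstring above omits the added fraction noise for clarity and is illustrative only, not a drop-in replacement for the Cell:

import math

def next_norm_clip(norm_clip, empirical_fraction, target_unclipped_quantile=0.9,
                   learning_rate=0.001, decay_policy='Linear'):
    # Paraphrase of the AdaClippingWithGaussianRandom update (fraction noise omitted).
    delta = empirical_fraction - target_unclipped_quantile
    if decay_policy == 'Linear':
        return norm_clip - learning_rate * delta
    # 'Geometric' policy
    return norm_clip * math.exp(-learning_rate * delta)

# Both policies move the bound the same way: when too many gradients get clipped
# (empirical fraction below the target quantile) the bound grows, otherwise it shrinks.
print(next_norm_clip(1.0, 0.5))                            # Linear
print(next_norm_clip(1.0, 0.5, decay_policy='Geometric'))  # Geometric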
mindarmour/diff_privacy/optimizer/optimizer.py

...
@@ -22,7 +22,7 @@ from mindspore.ops import functional as F
 from mindspore.common import dtype as mstype

 from mindarmour.utils.logger import LogUtil
-from mindarmour.diff_privacy import MechanismsFactory
+from mindarmour.diff_privacy import NoiseMechanismsFactory
 from mindarmour.diff_privacy.mechanisms.mechanisms import _MechanismsParamsUpdater
 from mindarmour.utils._check_param import check_int_positive
...
@@ -70,7 +70,7 @@ class DPOptimizerClassFactory:
     """
     def __init__(self, micro_batches=2):
-        self._mech_factory = MechanismsFactory()
+        self._mech_factory = NoiseMechanismsFactory()
         self.mech = None
         self._micro_batches = check_int_positive('micro_batches', micro_batches)
...
mindarmour/diff_privacy/train/model.py

...
@@ -48,7 +48,8 @@ from mindspore.nn import Cell
 from mindspore import ParameterTuple

 from mindarmour.utils.logger import LogUtil
-from mindarmour.diff_privacy.mechanisms.mechanisms import _MechanismsParamsUpdater
+from mindarmour.diff_privacy.mechanisms.mechanisms import \
+    _MechanismsParamsUpdater
 from mindarmour.utils._check_param import check_param_type
 from mindarmour.utils._check_param import check_value_positive
 from mindarmour.utils._check_param import check_int_positive
...
@@ -64,7 +65,7 @@ _reciprocal = P.Reciprocal()
 @_grad_scale.register("Tensor", "Tensor")
 def tensor_grad_scale(scale, grad):
     """ grad scaling """
     return grad * F.cast(_reciprocal(scale), F.dtype(grad))


 class DPModel(Model):
...
@@ -72,9 +73,14 @@ class DPModel(Model):
     This class is overload mindspore.train.model.Model.

     Args:
-        micro_batches (int): The number of small batches split from an original batch. Default: 2.
-        norm_clip (float): Use to clip the bound, if set 1, will retun the original data. Default: 1.0.
-        mech (Mechanisms): The object can generate the different type of noise. Default: None.
+        micro_batches (int): The number of small batches split from an original
+            batch. Default: 2.
+        norm_clip (float): Use to clip the bound, if set 1, will retun the
+            original data. Default: 1.0.
+        noise_mech (Mechanisms): The object can generate the different type of
+            noise. Default: None.
+        clip_mech (Mechanisms): The object is used to update the adaptive clip.
+            Default: None.

     Examples:
         >>> norm_clip = 1.0
...
@@ -89,63 +95,82 @@ class DPModel(Model):
         >>> factory_opt.set_mechanisms('Gaussian',
         >>>                            norm_bound=norm_clip,
         >>>                            initial_noise_multiplier=initial_noise_multiplier)
-        >>> net_opt = factory_opt.create('Momentum')(network.trainable_params(), learning_rate=0.1, momentum=0.9)
+        >>> net_opt = factory_opt.create('Momentum')(network.trainable_params(),
+        >>>                                          learning_rate=0.1, momentum=0.9)
+        >>> clip_mech = ClipMechanismsFactory().create('Gaussian',
+        >>>                                            decay_policy='Linear',
+        >>>                                            learning_rate=0.01,
+        >>>                                            target_unclipped_quantile=0.9,
+        >>>                                            fraction_stddev=0.01)
         >>> model = DPModel(micro_batches=micro_batches,
         >>>                 norm_clip=norm_clip,
-        >>>                 mech=None,
+        >>>                 clip_mech=clip_mech,
+        >>>                 noise_mech=None,
         >>>                 network=network,
         >>>                 loss_fn=loss,
         >>>                 optimizer=net_opt,
         >>>                 metrics=None)
-        >>> ms_ds = ds.GeneratorDataset(dataset_generator(batch_size, batches), ['data', 'label'])
-        >>> ms_ds.set_dataset_size(batch_size * batches)
+        >>> ms_ds = ds.GeneratorDataset(dataset_generator(batch_size, batches),
+        >>>                             ['data', 'label'])
+        >>> ms_ds.set_dataset_size(batch_size*batches)
         >>> model.train(epochs, ms_ds, dataset_sink_mode=False)
     """

-    def __init__(self, micro_batches=2, norm_clip=1.0, mech=None, **kwargs):
+    def __init__(self, micro_batches=2, norm_clip=1.0, noise_mech=None,
+                 clip_mech=None, **kwargs):
         if micro_batches:
             self._micro_batches = check_int_positive('micro_batches', micro_batches)
         else:
             self._micro_batches = None
         norm_clip = check_param_type('norm_clip', norm_clip, float)
-        self._norm_clip = check_value_positive('norm_clip', norm_clip)
-        if mech is not None and "DPOptimizer" in kwargs['optimizer'].__class__.__name__:
-            msg = 'DPOptimizer is not supported while mech is not None'
+        norm_clip = check_value_positive('norm_clip', norm_clip)
+        norm_clip = Tensor(norm_clip, mstype.float32)
+        self._norm_clip = Parameter(norm_clip, 'norm_clip')
+        if noise_mech is not None and "DPOptimizer" in kwargs['optimizer'].__class__.__name__:
+            msg = 'DPOptimizer is not supported while noise_mech is not None'
             LOGGER.error(TAG, msg)
             raise ValueError(msg)
-        if mech is None:
+        if noise_mech is None:
             if "DPOptimizer" in kwargs['optimizer'].__class__.__name__:
                 if context.get_context('mode') != context.PYNATIVE_MODE:
                     msg = 'DPOptimizer just support pynative mode currently.'
                     LOGGER.error(TAG, msg)
                     raise ValueError(msg)
             else:
-                msg = 'DPModel should set mech or DPOptimizer configure, please refer to example.'
+                msg = 'DPModel should set noise_mech or DPOptimizer configure, ' \
+                      'please refer to example.'
                 LOGGER.error(TAG, msg)
                 raise ValueError(msg)
-        self._mech = mech
+        self._noise_mech = noise_mech
+        if clip_mech is not None:
+            self._clip_mech = clip_mech
         super(DPModel, self).__init__(**kwargs)

-    def _amp_build_train_network(self, network, optimizer, loss_fn=None, level='O0', **kwargs):
+    def _amp_build_train_network(self, network, optimizer, loss_fn=None,
+                                 level='O0', **kwargs):
         """
         Build the mixed precision training cell automatically.

         Args:
             network (Cell): Definition of the network.
-            loss_fn (Union[None, Cell]): Definition of the loss_fn. If None, the `network` should have the loss inside.
-                Default: None.
+            loss_fn (Union[None, Cell]): Definition of the loss_fn. If None,
+                the `network` should have the loss inside. Default: None.
             optimizer (Optimizer): Optimizer to update the Parameter.
             level (str): Supports [O0, O2]. Default: "O0".

                 - O0: Do not change.
-                - O2: Cast network to float16, keep batchnorm and `loss_fn` (if set) run in float32,
-                  using dynamic loss scale.
+                - O2: Cast network to float16, keep batchnorm and `loss_fn`
+                  (if set) run in float32, using dynamic loss scale.

-            cast_model_type (:class:`mindspore.dtype`): Supports `mstype.float16` or `mstype.float32`.
-                If set to `mstype.float16`, use `float16` mode to train. If set, overwrite the level setting.
-            keep_batchnorm_fp32 (bool): Keep Batchnorm run in `float32`. If set, overwrite the level setting.
-            loss_scale_manager (Union[None, LossScaleManager]): If None, not scale the loss, or else
-                scale the loss by LossScaleManager. If set, overwrite the level setting.
+            cast_model_type (:class:`mindspore.dtype`): Supports `mstype.float16`
+                or `mstype.float32`. If set to `mstype.float16`, use `float16`
+                mode to train. If set, overwrite the level setting.
+            keep_batchnorm_fp32 (bool): Keep Batchnorm run in `float32`. If set,
+                overwrite the level setting.
+            loss_scale_manager (Union[None, LossScaleManager]): If None, not
+                scale the loss, or else scale the loss by LossScaleManager.
+                If set, overwrite the level setting.
         """
         validator.check_value_type('network', network, nn.Cell, None)
         validator.check_value_type('optimizer', optimizer, nn.Optimizer, None)
...
@@ -161,9 +186,11 @@ class DPModel(Model):
             _do_keep_batchnorm_fp32(network)

         if loss_fn:
             network = _add_loss_network(network, loss_fn, config.cast_model_type)

-        if _get_parallel_mode() in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
+        if _get_parallel_mode() in (ParallelMode.SEMI_AUTO_PARALLEL,
+                                    ParallelMode.AUTO_PARALLEL):
             network = _VirtualDatasetCell(network)

         loss_scale = 1.0
...
@@ -173,9 +200,12 @@ class DPModel(Model):
             update_cell = loss_scale_manager.get_update_cell()
             if update_cell is not None:
                 # only cpu not support `TrainOneStepWithLossScaleCell` for control flow.
-                if not context.get_context("enable_ge") and context.get_context("device_target") == "CPU":
-                    msg = "Only `loss_scale_manager=None` and `loss_scale_manager=FixedLossScaleManager(drop_overflow" \
-                          "_update=False)` are supported in current version. If you use `O2` option, please use " \
+                if not context.get_context("enable_ge") and context.get_context(
+                        "device_target") == "CPU":
+                    msg = "Only `loss_scale_manager=None` and " \
+                          "`loss_scale_manager=FixedLossScaleManager(drop_overflow" \
+                          "_update=False)` are supported in current version. " \
+                          "If you use `O2` option, please use " \
                           "`loss_scale_manager=None` or `FixedLossScaleManager`"
                     LOGGER.error(TAG, msg)
                     raise ValueError(msg)
...
@@ -184,15 +214,17 @@ class DPModel(Model):
                                                             scale_update_cell=update_cell,
                                                             micro_batches=self._micro_batches,
                                                             norm_clip=self._norm_clip,
-                                                            mech=self._mech).set_train()
+                                                            clip_mech=self._clip_mech,
+                                                            noise_mech=self._noise_mech).set_train()
                 return network

         network = _TrainOneStepCell(network,
                                     optimizer,
+                                    self._norm_clip,
                                     loss_scale,
                                     micro_batches=self._micro_batches,
-                                    norm_clip=self._norm_clip,
-                                    mech=self._mech).set_train()
+                                    clip_mech=self._clip_mech,
+                                    noise_mech=self._noise_mech).set_train()
         return network

     def _build_train_network(self):
...
@@ -233,7 +265,8 @@ class DPModel(Model):
         elif self._loss_fn:
             network = nn.WithLossCell(network, self._loss_fn)

-        if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
+        if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL,
+                                   ParallelMode.AUTO_PARALLEL):
             network.set_auto_parallel()
         return network
...
@@ -267,11 +300,10 @@ class _ClipGradients(nn.Cell):
         new_grads = ()
         for grad in grads:
             if clip_type == 0:
-                t = C.clip_by_value(grad, F.tuple_to_array((-clip_value,)),
-                                    F.tuple_to_array((clip_value,)))
+                norm = C.clip_by_value(grad, -clip_value, clip_value)
             else:
-                t = self.clip_by_norm(grad, F.tuple_to_array((clip_value,)))
-            new_grads = new_grads + (t,)
+                norm = self.clip_by_norm(grad, clip_value)
+            new_grads = new_grads + (norm,)
         return new_grads
...
@@ -292,20 +324,27 @@ class _TrainOneStepWithLossScaleCell(Cell):
     r"""
     Network training with loss scaling.

-    This is a training step with loss scaling. It takes a network, an optimizer and possibly a scale update
-    Cell as args. The loss scale value can be updated in both host side or device side. The
-    TrainOneStepWithLossScaleCell will be compiled to be graph which takes `data`, `label`, `sens` as input
-    data. The `sens` is acting as loss scaling value. If you want to update it on host side, the value should
-    be provided. If `sens` is not given, the loss scale update logic should be provied by `scale_update_cell`.
-    If `scale_update_cell` is not None and `sens` is provided, the `scale_update_cell` will be ignored.
+    This is a training step with loss scaling. It takes a network, an optimizer
+    and possibly a scale update Cell as args. The loss scale value can be
+    updated in both host side or device side. The TrainOneStepWithLossScaleCell
+    will be compiled to be graph which takes `data`, `label`, `sens` as input
+    data. The `sens` is acting as loss scaling value. If you want to update it
+    on host side, the value should be provided. If `sens` is not given, the loss
+    scale update logic should be provied by `scale_update_cell`. If
+    `scale_update_cell` is not None and `sens` is provided, the
+    `scale_update_cell` will be ignored.

     Args:
         network (Cell): The training network.
         optimizer (Cell): Optimizer for updating the weights.
-        scale_update_cell(Cell): The loss scaling update logic cell. Default: None.
-        micro_batches (int): The number of small batches split from an original batch. Default: None.
-        norm_clip (float): Use to clip the bound, if set 1, will return the original data. Default: 1.0.
-        mech (Mechanisms): The object can generate the different type of noise. Default: None.
+        scale_update_cell(Cell): The loss scaling update logic cell.
+            Default: None.
+        micro_batches (int): The number of small batches split from an original
+            batch. Default: None.
+        norm_clip (Tensor): Use to clip the bound, if set 1, will return the
+            original data. Default: 1.0.
+        noise_mech (Mechanisms): The object can generate the different type of
+            noise. Default: None.

     Inputs:
         - **inputs** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
...
@@ -320,7 +359,9 @@ class _TrainOneStepWithLossScaleCell(Cell):
         - **loss_scale** (Tensor) - Tensor with shape :math:`()`.
     """

-    def __init__(self, network, optimizer, scale_update_cell=None, micro_batches=None, norm_clip=1.0, mech=None):
+    def __init__(self, network, optimizer, scale_update_cell=None,
+                 micro_batches=None, norm_clip=1.0, noise_mech=None,
+                 clip_mech=None):
         super(_TrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False)
         self.network = network
         self.network.set_grad()
...
@@ -346,39 +387,54 @@ class _TrainOneStepWithLossScaleCell(Cell):
         self.allreduce = P.AllReduce()
         self.parallel_mode = _get_parallel_mode()
         self.grad_reducer = F.identity
-        self.reducer_flag = self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]
+        self.reducer_flag = self.parallel_mode in [ParallelMode.DATA_PARALLEL,
+                                                   ParallelMode.HYBRID_PARALLEL]
         if self.reducer_flag:
             mean = _get_mirror_mean()
             degree = _get_device_num()
-            self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
+            self.grad_reducer = DistributedGradReducer(optimizer.parameters,
+                                                       mean, degree)
         self.is_distributed = self.parallel_mode != ParallelMode.STAND_ALONE
         self.loss_scale = None
         self.loss_scaling_manager = scale_update_cell
         if scale_update_cell:
-            self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32),
-                                        name="loss_scale")
+            self.loss_scale = Parameter(
+                Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32),
+                name="loss_scale")
         self.add_flags(has_effect=True)

         # dp params
         self._micro_batches = micro_batches
-        norm_clip = check_param_type('norm_clip', norm_clip, float)
-        self._l2_norm = check_value_positive('norm_clip', norm_clip)
+        self._norm_clip = norm_clip
         self._split = P.Split(0, self._micro_batches)
         self._clip_by_global_norm = _ClipGradients()
-        self._mech = mech
+        self._noise_mech = noise_mech
+        self._clip_mech = clip_mech
+        self._add = P.TensorAdd()
+        self._norm = nn.Norm()
         self._tuple_add = _TupleAdd()
         self._hyper_map = C.HyperMap()
         self._micro_float = Tensor(micro_batches, mstype.float32)
-        self._mech_param_updater = None
-        if self._mech is not None and self._mech._decay_policy is not None:
-            self._mech_param_updater = _MechanismsParamsUpdater(policy=self._mech._decay_policy,
-                                                                decay_rate=self._mech._noise_decay_rate,
-                                                                cur_noise_multiplier=
-                                                                self._mech._noise_multiplier,
-                                                                init_noise_multiplier=
-                                                                self._mech._initial_noise_multiplier)
+        self._zero = Tensor(0, mstype.float32)
+        self._assign = P.Assign()
+        self._div = P.Div()
+        self._sqrt = P.Sqrt()
+        self._reduce_sum = P.ReduceSum()
+        self._square_all = P.Square()
+        self._less = P.Less()
+        self._cast = P.Cast()
+
+        self._noise_mech_param_updater = None
+        if self._noise_mech is not None and self._noise_mech._decay_policy is not None:
+            self._noise_mech_param_updater = _MechanismsParamsUpdater(
+                policy=self._noise_mech._decay_policy,
+                decay_rate=self._noise_mech._noise_decay_rate,
+                cur_noise_multiplier=self._noise_mech._noise_multiplier,
+                init_noise_multiplier=self._noise_mech._initial_noise_multiplier)

     def construct(self, data, label, sens=None):
         """
...
@@ -402,30 +458,62 @@ class _TrainOneStepWithLossScaleCell(Cell):
             record_labels = self._split(label)
             # first index
             loss = self.network(record_datas[0], record_labels[0])
-            scaling_sens_filled = C.ones_like(loss) * F.cast(scaling_sens, F.dtype(loss))
-            record_grad = self.grad(self.network, weights)(record_datas[0], record_labels[0], scaling_sens_filled)
-            record_grad = self._clip_by_global_norm(record_grad, GRADIENT_CLIP_TYPE, self._l2_norm)
+            scaling_sens_filled = C.ones_like(loss) * F.cast(scaling_sens,
+                                                             F.dtype(loss))
+            record_grad = self.grad(self.network, weights)(record_datas[0],
+                                                           record_labels[0],
+                                                           scaling_sens_filled)
+
+            beta = self._zero
+            square_sum = self._zero
+            for grad in record_grad:
+                square_sum = self._add(square_sum,
+                                       self._reduce_sum(self._square_all(grad)))
+            norm_grad = self._sqrt(square_sum)
+            beta = self._add(beta,
+                             self._cast(self._less(norm_grad, self._norm_clip),
+                                        mstype.float32))
+
+            record_grad = self._clip_by_global_norm(record_grad,
+                                                    GRADIENT_CLIP_TYPE,
+                                                    self._norm_clip)
             grads = record_grad
             total_loss = loss
             for i in range(1, self._micro_batches):
                 loss = self.network(record_datas[i], record_labels[i])
-                scaling_sens_filled = C.ones_like(loss) * F.cast(scaling_sens, F.dtype(loss))
-                record_grad = self.grad(self.network, weights)(record_datas[i], record_labels[i], scaling_sens_filled)
-                record_grad = self._clip_by_global_norm(record_grad, GRADIENT_CLIP_TYPE, self._l2_norm)
+                scaling_sens_filled = C.ones_like(loss) * F.cast(scaling_sens,
+                                                                 F.dtype(loss))
+                record_grad = self.grad(self.network, weights)(record_datas[i],
+                                                               record_labels[i],
+                                                               scaling_sens_filled)
+
+                square_sum = self._zero
+                for grad in record_grad:
+                    square_sum = self._add(square_sum,
+                                           self._reduce_sum(self._square_all(grad)))
+                norm_grad = self._sqrt(square_sum)
+                beta = self._add(beta,
+                                 self._cast(self._less(norm_grad, self._norm_clip),
+                                            mstype.float32))
+
+                record_grad = self._clip_by_global_norm(record_grad,
+                                                        GRADIENT_CLIP_TYPE,
+                                                        self._norm_clip)
                 grads = self._tuple_add(grads, record_grad)
                 total_loss = P.TensorAdd()(total_loss, loss)
             loss = P.Div()(total_loss, self._micro_float)
+            beta = self._div(beta, self._micro_batches)

-            if self._mech is not None:
+            if self._noise_mech is not None:
                 grad_noise_tuple = ()
                 for grad_item in grads:
                     grad_noise = self._mech(grad_item)
                     grad_noise_tuple = grad_noise_tuple + (grad_noise,)
                 grads = self._tuple_add(grads, grad_noise_tuple)
                 grads = self._hyper_map(F.partial(_grad_scale, self._micro_float), grads)
                 # update mech parameters
-                if self._mech_param_updater is not None:
-                    multiplier = self._mech_param_updater()
+                if self._noise_mech_param_updater is not None:
+                    multiplier = self._noise_mech_param_updater()
                     loss = F.depend(loss, multiplier)

         grads = self.hyper_map(F.partial(_grad_scale, scaling_sens), grads)
...
@@ -456,6 +544,10 @@ class _TrainOneStepWithLossScaleCell(Cell):
         else:
             opt = self.optimizer(grads)
         ret = (loss, cond, scaling_sens)
+
+        if self._clip_mech is not None:
+            next_norm_clip = self._clip_mech(beta, self._norm_clip)
+            P.assign(self._norm_clip, next_norm_clip)
         return F.depend(ret, opt)
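The beta accumulated in construct() above is the statistic handed to clip_mech: each micro-batch contributes 1 when the global l2 norm of its gradient tuple is below the current norm_clip, and the sum is divided by the number of micro-batches. A hedged NumPy illustration of that bookkeeping (the gradient tuples are invented stand-ins):

import numpy as np

rng = np.random.default_rng(1)
norm_clip = 1.0
micro_batches = 4

# Each record_grad is a tuple of per-parameter gradient arrays, as in construct().
record_grads = [(rng.normal(size=5), rng.normal(size=3)) for _ in range(micro_batches)]

beta = 0.0
for record_grad in record_grads:
    square_sum = sum(np.sum(np.square(g)) for g in record_grad)  # _reduce_sum(_square_all(grad))
    norm_grad = np.sqrt(square_sum)                              # global l2 norm of this micro-batch
    beta += float(norm_grad < norm_clip)                         # _less(...) cast to float32

beta /= micro_batches  # empirical fraction of unclipped micro-batch gradients, passed to clip_mech
print(beta)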
...
@@ -463,17 +555,22 @@ class _TrainOneStepCell(Cell):
...
@@ -463,17 +555,22 @@ class _TrainOneStepCell(Cell):
r
"""
r
"""
Network training package class.
Network training package class.
Wraps the network with an optimizer. The resulting Cell be trained with input data and label.
Wraps the network with an optimizer. The resulting Cell be trained with
Backward graph will be created in the construct function to do parameter updating. Different
input data and label. Backward graph will be created in the construct
parallel modes are available to run the training.
function to do parameter updating. Different parallel modes are available
to run the training.
Args:
Args:
network (Cell): The training network.
network (Cell): The training network.
optimizer (Cell): Optimizer for updating the weights.
optimizer (Cell): Optimizer for updating the weights.
sens (Number): The scaling number to be filled as the input of back propagation. Default value is 1.0.
sens (Number): The scaling number to be filled as the input of back
micro_batches (int): The number of small batches split from an original batch. Default: None.
propagation. Default value is 1.0.
norm_clip (float): Use to clip the bound, if set 1, will return the original data. Default: 1.0.
micro_batches (int): The number of small batches split from an original
mech (Mechanisms): The object can generate the different type of noise. Default: None.
batch. Default: None.
norm_clip (Tensor): Use to clip the bound, if set 1, will return the
original data. Default: 1.0.
noise_mech (Mechanisms): The object can generate the different type
of noise. Default: None.
Inputs:
Inputs:
- **data** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
- **data** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
...
@@ -483,7 +580,9 @@ class _TrainOneStepCell(Cell):
...
@@ -483,7 +580,9 @@ class _TrainOneStepCell(Cell):
Tensor, a scalar Tensor with shape :math:`()`.
Tensor, a scalar Tensor with shape :math:`()`.
"""
"""
def
__init__
(
self
,
network
,
optimizer
,
sens
=
1.0
,
micro_batches
=
None
,
norm_clip
=
1.0
,
mech
=
None
):
def
__init__
(
self
,
network
,
optimizer
,
norm_clip
=
1.0
,
sens
=
1.0
,
micro_batches
=
None
,
noise_mech
=
None
,
clip_mech
=
None
):
super
(
_TrainOneStepCell
,
self
).
__init__
(
auto_prefix
=
False
)
super
(
_TrainOneStepCell
,
self
).
__init__
(
auto_prefix
=
False
)
self
.
network
=
network
self
.
network
=
network
self
.
network
.
set_grad
()
self
.
network
.
set_grad
()
...
@@ -495,36 +594,51 @@ class _TrainOneStepCell(Cell):
...
@@ -495,36 +594,51 @@ class _TrainOneStepCell(Cell):
        self.reducer_flag = False
        self.grad_reducer = None
        parallel_mode = _get_parallel_mode()
        if parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
            self.reducer_flag = True
        if self.reducer_flag:
            mean = _get_mirror_mean()
            degree = _get_device_num()
            self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)

        # dp params
        if micro_batches is None:
            msg = 'micro_batches must give in differential privacy, but got value: {}'.format(micro_batches)
            LOGGER.error(TAG, msg)
            raise ValueError(msg)
        self._micro_batches = micro_batches
-        norm_clip = check_param_type('norm_clip', norm_clip, float)
-        self._l2_norm = check_value_positive('norm_clip', norm_clip)
+        self._norm_clip = norm_clip
        self._split = P.Split(0, self._micro_batches)
        self._clip_by_global_norm = _ClipGradients()
-        self._mech = mech
+        self._noise_mech = noise_mech
+        self._clip_mech = clip_mech
        self._tuple_add = _TupleAdd()
+        self._add = P.TensorAdd()
+        self._norm = nn.Norm()
        self._hyper_map = C.HyperMap()
+        self._zero = Tensor(0, mstype.float32)
+        self._assign = P.Assign()
+        self._div = P.Div()
+        self._sqrt = P.Sqrt()
+        self._reduce_sum = P.ReduceSum()
+        self._square_all = P.Square()
+        self._less = P.Less()
+        self._cast = P.Cast()
        self._micro_float = Tensor(micro_batches, mstype.float32)
-        self._mech_param_updater = None
-        if self._mech is not None and self._mech._decay_policy is not None:
-            self._mech_param_updater = _MechanismsParamsUpdater(policy=self._mech._decay_policy,
-                                                                decay_rate=self._mech._noise_decay_rate,
-                                                                cur_noise_multiplier=
-                                                                self._mech._noise_multiplier,
-                                                                init_noise_multiplier=
-                                                                self._mech._initial_noise_multiplier)
+        self._noise_mech_param_updater = None
+        if self._noise_mech is not None and self._noise_mech._decay_policy is not None:
+            self._noise_mech_param_updater = _MechanismsParamsUpdater(
+                policy=self._noise_mech._decay_policy,
+                decay_rate=self._noise_mech._noise_decay_rate,
+                cur_noise_multiplier=self._noise_mech._noise_multiplier,
+                init_noise_multiplier=self._noise_mech._initial_noise_multiplier)
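The reworked constructor now takes the clip bound as a tensor plus two separate mechanisms: noise_mech perturbs the averaged gradients, while clip_mech adapts norm_clip between steps. For orientation, a hypothetical wiring that simply mirrors the keyword arguments above; every name other than the keyword arguments is a placeholder, and the tests later in this commit only go through DPModel rather than constructing the cell directly:

    # Hypothetical wiring of the updated cell; net_with_loss, optimizer,
    # noise_mech and clip_mech are placeholders, not taken from the repository.
    train_cell = _TrainOneStepCell(net_with_loss, optimizer,
                                   norm_clip=Tensor(1.0, mstype.float32),
                                   micro_batches=16,
                                   noise_mech=noise_mech,
                                   clip_mech=clip_mech)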
    def construct(self, data, label):
        """
...
@@ -535,32 +649,65 @@ class _TrainOneStepCell(Cell):
        record_labels = self._split(label)
        loss = self.network(record_datas[0], record_labels[0])
        sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
-        record_grad = self.grad(self.network, weights)(record_datas[0], record_labels[0], sens)
-        record_grad = self._clip_by_global_norm(record_grad, GRADIENT_CLIP_TYPE, self._l2_norm)
+        record_grad = self.grad(self.network, weights)(record_datas[0],
+                                                       record_labels[0], sens)
+        beta = self._zero
+        square_sum = self._zero
+        for grad in record_grad:
+            square_sum = self._add(square_sum, self._reduce_sum(self._square_all(grad)))
+        norm_grad = self._sqrt(square_sum)
+        beta = self._add(beta, self._cast(self._less(norm_grad, self._norm_clip), mstype.float32))
+        record_grad = self._clip_by_global_norm(record_grad, GRADIENT_CLIP_TYPE, self._norm_clip)
        grads = record_grad
        total_loss = loss
        for i in range(1, self._micro_batches):
            loss = self.network(record_datas[i], record_labels[i])
            sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
-            record_grad = self.grad(self.network, weights)(record_datas[i], record_labels[i], sens)
-            record_grad = self._clip_by_global_norm(record_grad, GRADIENT_CLIP_TYPE, self._l2_norm)
+            record_grad = self.grad(self.network, weights)(record_datas[i],
+                                                           record_labels[i], sens)
+            square_sum = self._zero
+            for grad in record_grad:
+                square_sum = self._add(square_sum, self._reduce_sum(self._square_all(grad)))
+            norm_grad = self._sqrt(square_sum)
+            beta = self._add(beta, self._cast(self._less(norm_grad, self._norm_clip), mstype.float32))
+            record_grad = self._clip_by_global_norm(record_grad, GRADIENT_CLIP_TYPE, self._norm_clip)
            grads = self._tuple_add(grads, record_grad)
            total_loss = P.TensorAdd()(total_loss, loss)
-        loss = P.Div()(total_loss, self._micro_float)
+        loss = self._div(total_loss, self._micro_float)
+        beta = self._div(beta, self._micro_batches)

-        if self._mech is not None:
+        if self._noise_mech is not None:
            grad_noise_tuple = ()
            for grad_item in grads:
-                grad_noise = self._mech(grad_item)
+                grad_noise = self._noise_mech(grad_item)
                grad_noise_tuple = grad_noise_tuple + (grad_noise,)
            grads = self._tuple_add(grads, grad_noise_tuple)
            grads = self._hyper_map(F.partial(_grad_scale, self._micro_float), grads)
            # update mech parameters
-            if self._mech_param_updater is not None:
-                multiplier = self._mech_param_updater()
+            if self._noise_mech_param_updater is not None:
+                multiplier = self._noise_mech_param_updater()
                loss = F.depend(loss, multiplier)

        if self.reducer_flag:
            # apply grad reducer on grads
            grads = self.grad_reducer(grads)
+
+        if self._clip_mech is not None:
+            next_norm_clip = self._clip_mech(beta, self._norm_clip)
+            self._norm_clip = self._assign(self._norm_clip, next_norm_clip)
+            loss = F.depend(loss, next_norm_clip)
+
        return F.depend(loss, self.optimizer(grads))
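Beyond the existing DP-SGD noise addition, the new construct() logic measures how often a micro-batch's global gradient norm already falls below the current clip bound (the running sum `beta`, averaged over the micro-batches) and hands that fraction to clip_mech. A plain-NumPy sketch of the beta computation, for intuition only; the MindSpore ops in the diff above are the authoritative version:

    import numpy as np

    def empirical_unclipped_fraction(per_microbatch_grads, norm_clip):
        """beta: fraction of micro-batches whose global L2 gradient norm is below norm_clip."""
        hits = 0.0
        for grads in per_microbatch_grads:                    # one tuple of arrays per micro-batch
            square_sum = sum(np.sum(np.square(g)) for g in grads)
            norm_grad = np.sqrt(square_sum)                   # global L2 norm before clipping
            hits += float(norm_grad < norm_clip)              # 1.0 when this micro-batch is unclipped
        return hits / len(per_microbatch_grads)

The cell then calls clip_mech(beta, self._norm_clip), assigns the result back to self._norm_clip, and ties the new bound into the loss with F.depend so the assignment is actually executed before the optimizer step.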
tests/ut/python/diff_privacy/test_mechanisms.py  View file @ 79c6403d
...
@@ -19,9 +19,11 @@ import pytest
from mindspore import context
from mindspore import Tensor
from mindspore.common import dtype as mstype
-from mindarmour.diff_privacy import GaussianRandom
+from mindarmour.diff_privacy import NoiseGaussianRandom
from mindarmour.diff_privacy import AdaGaussianRandom
-from mindarmour.diff_privacy import MechanismsFactory
+from mindarmour.diff_privacy import AdaClippingWithGaussianRandom
+from mindarmour.diff_privacy import NoiseMechanismsFactory
+from mindarmour.diff_privacy import ClipMechanismsFactory


@pytest.mark.level0
...
@@ -33,7 +35,7 @@ def test_graph_gaussian():
    grad = Tensor([0.3, 0.2, 0.4], mstype.float32)
    norm_bound = 1.0
    initial_noise_multiplier = 0.1
-    net = GaussianRandom(norm_bound, initial_noise_multiplier)
+    net = NoiseGaussianRandom(norm_bound, initial_noise_multiplier)
    res = net(grad)
    print(res)
...
@@ -47,7 +49,7 @@ def test_pynative_gaussian():
    grad = Tensor([0.3, 0.2, 0.4], mstype.float32)
    norm_bound = 1.0
    initial_noise_multiplier = 0.1
-    net = GaussianRandom(norm_bound, initial_noise_multiplier)
+    net = NoiseGaussianRandom(norm_bound, initial_noise_multiplier)
    res = net(grad)
    print(res)
...
@@ -80,13 +82,13 @@ def test_graph_factory():
    initial_noise_multiplier = 0.1
    alpha = 0.5
    decay_policy = 'Step'
-    noise_mechanism = MechanismsFactory()
+    noise_mechanism = NoiseMechanismsFactory()
    noise_construct = noise_mechanism.create('Gaussian',
                                             norm_bound,
                                             initial_noise_multiplier)
    noise = noise_construct(grad)
    print('Gaussian noise: ', noise)
-    ada_mechanism = MechanismsFactory()
+    ada_mechanism = NoiseMechanismsFactory()
    ada_noise_construct = ada_mechanism.create('AdaGaussian',
                                               norm_bound,
                                               initial_noise_multiplier,
...
@@ -124,13 +126,13 @@ def test_pynative_factory():
    initial_noise_multiplier = 0.1
    alpha = 0.5
    decay_policy = 'Step'
-    noise_mechanism = MechanismsFactory()
+    noise_mechanism = NoiseMechanismsFactory()
    noise_construct = noise_mechanism.create('Gaussian',
                                             norm_bound,
                                             initial_noise_multiplier)
    noise = noise_construct(grad)
    print('Gaussian noise: ', noise)
-    ada_mechanism = MechanismsFactory()
+    ada_mechanism = NoiseMechanismsFactory()
    ada_noise_construct = ada_mechanism.create('AdaGaussian',
                                               norm_bound,
                                               initial_noise_multiplier,
...
@@ -151,7 +153,7 @@ def test_pynative_exponential():
    initial_noise_multiplier = 0.1
    alpha = 0.5
    decay_policy = 'Exp'
-    ada_mechanism = MechanismsFactory()
+    ada_mechanism = NoiseMechanismsFactory()
    ada_noise_construct = ada_mechanism.create('AdaGaussian',
                                               norm_bound,
                                               initial_noise_multiplier,
...
@@ -172,7 +174,7 @@ def test_graph_exponential():
    initial_noise_multiplier = 0.1
    alpha = 0.5
    decay_policy = 'Exp'
-    ada_mechanism = MechanismsFactory()
+    ada_mechanism = NoiseMechanismsFactory()
    ada_noise_construct = ada_mechanism.create('AdaGaussian',
                                               norm_bound,
                                               initial_noise_multiplier,
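The 'Step' and 'Exp' policies exercised by these factory tests control how the AdaGaussian noise multiplier decays across training steps. The exact update used inside the mechanism is not shown in this excerpt; the sketch below only illustrates the two typical shapes such a decay can take, with `alpha` standing in for the decay rate:

    import numpy as np

    def decayed_noise_multiplier(initial_noise_multiplier, alpha, step, policy):
        """Illustrative decay shapes only; not the library's exact formula."""
        if policy == 'Step':
            # stepwise shrinkage by a constant factor per update
            return initial_noise_multiplier * (1 - alpha) ** step
        if policy == 'Exp':
            # smooth exponential decay
            return initial_noise_multiplier * np.exp(-alpha * step)
        raise ValueError(policy)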
...
@@ -180,3 +182,107 @@ def test_graph_exponential():
                                               decay_policy=decay_policy)
    ada_noise = ada_noise_construct(grad)
    print('ada noise: ', ada_noise)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+@pytest.mark.component_mindarmour
+def test_ada_clip_gaussian_random_pynative():
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
+    decay_policy = 'Linear'
+    beta = Tensor(0.5, mstype.float32)
+    norm_clip = Tensor(1.0, mstype.float32)
+    beta_stddev = 0.1
+    learning_rate = 0.1
+    target_unclipped_quantile = 0.3
+    ada_clip = AdaClippingWithGaussianRandom(decay_policy=decay_policy,
+                                             learning_rate=learning_rate,
+                                             target_unclipped_quantile=target_unclipped_quantile,
+                                             fraction_stddev=beta_stddev,
+                                             seed=1)
+    next_norm_clip = ada_clip(beta, norm_clip)
+    print('Liner next norm clip:', next_norm_clip)
+
+    decay_policy = 'Geometric'
+    ada_clip = AdaClippingWithGaussianRandom(decay_policy=decay_policy,
+                                             learning_rate=learning_rate,
+                                             target_unclipped_quantile=target_unclipped_quantile,
+                                             fraction_stddev=beta_stddev,
+                                             seed=1)
+    next_norm_clip = ada_clip(beta, norm_clip)
+    print('Geometric next norm clip:', next_norm_clip)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+@pytest.mark.component_mindarmour
+def test_ada_clip_gaussian_random_graph():
+    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+    decay_policy = 'Linear'
+    beta = Tensor(0.5, mstype.float32)
+    norm_clip = Tensor(1.0, mstype.float32)
+    beta_stddev = 0.1
+    learning_rate = 0.1
+    target_unclipped_quantile = 0.3
+    ada_clip = AdaClippingWithGaussianRandom(decay_policy=decay_policy,
+                                             learning_rate=learning_rate,
+                                             target_unclipped_quantile=target_unclipped_quantile,
+                                             fraction_stddev=beta_stddev,
+                                             seed=1)
+    next_norm_clip = ada_clip(beta, norm_clip)
+    print('Liner next norm clip:', next_norm_clip)
+
+    decay_policy = 'Geometric'
+    ada_clip = AdaClippingWithGaussianRandom(decay_policy=decay_policy,
+                                             learning_rate=learning_rate,
+                                             target_unclipped_quantile=target_unclipped_quantile,
+                                             fraction_stddev=beta_stddev,
+                                             seed=1)
+    next_norm_clip = ada_clip(beta, norm_clip)
+    print('Geometric next norm clip:', next_norm_clip)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+@pytest.mark.component_mindarmour
+def test_pynative_clip_mech_factory():
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
+    decay_policy = 'Linear'
+    beta = Tensor(0.5, mstype.float32)
+    norm_clip = Tensor(1.0, mstype.float32)
+    beta_stddev = 0.1
+    learning_rate = 0.1
+    target_unclipped_quantile = 0.3
+    clip_mechanism = ClipMechanismsFactory()
+    ada_clip = clip_mechanism.create('Gaussian',
+                                     decay_policy=decay_policy,
+                                     learning_rate=learning_rate,
+                                     target_unclipped_quantile=target_unclipped_quantile,
+                                     fraction_stddev=beta_stddev)
+    next_norm_clip = ada_clip(beta, norm_clip)
+    print('next_norm_clip: ', next_norm_clip)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+@pytest.mark.component_mindarmour
+def test_graph_clip_mech_factory():
+    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+    decay_policy = 'Linear'
+    beta = Tensor(0.5, mstype.float32)
+    norm_clip = Tensor(1.0, mstype.float32)
+    beta_stddev = 0.1
+    learning_rate = 0.1
+    target_unclipped_quantile = 0.3
+    clip_mechanism = ClipMechanismsFactory()
+    ada_clip = clip_mechanism.create('Gaussian',
+                                     decay_policy=decay_policy,
+                                     learning_rate=learning_rate,
+                                     target_unclipped_quantile=target_unclipped_quantile,
+                                     fraction_stddev=beta_stddev)
+    next_norm_clip = ada_clip(beta, norm_clip)
+    print('next_norm_clip: ', next_norm_clip)
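These new tests fix beta = 0.5, norm_clip = 1.0, learning_rate = 0.1 and target_unclipped_quantile = 0.3. As a rough sanity check on what they should print, the sketch below applies the update rules commonly used for adaptive clipping (a linear step or a geometric/exponential step toward the target quantile); the exact formula inside AdaClippingWithGaussianRandom is an assumption here, not something shown in this diff:

    import numpy as np

    def next_clip_sketch(norm_clip, beta, target_unclipped_quantile,
                         learning_rate, fraction_stddev, decay_policy, seed=1):
        """Illustrative adaptive-clipping update, assuming the usual Linear/Geometric rules."""
        rng = np.random.default_rng(seed)
        noisy_beta = beta + rng.normal(0.0, fraction_stddev)   # beta perturbed with Gaussian noise
        err = noisy_beta - target_unclipped_quantile
        if decay_policy == 'Linear':
            return norm_clip - learning_rate * err             # ~ 1.0 - 0.1 * 0.2 = 0.98 before noise
        return norm_clip * np.exp(-learning_rate * err)        # 'Geometric': ~ exp(-0.02) ≈ 0.980

Under these assumed rules, when more than target_unclipped_quantile of the micro-batches are already unclipped the bound shrinks, and when fewer are unclipped it grows, which is the feedback loop the DPModel training step relies on.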
tests/ut/python/diff_privacy/test_model_train.py  View file @ 79c6403d
...
@@ -22,7 +22,8 @@ from mindspore import context
import mindspore.dataset as ds

from mindarmour.diff_privacy import DPModel
-from mindarmour.diff_privacy import MechanismsFactory
+from mindarmour.diff_privacy import NoiseMechanismsFactory
+from mindarmour.diff_privacy import ClipMechanismsFactory
from mindarmour.diff_privacy import DPOptimizerClassFactory

from test_network import LeNet5
...
@@ -30,10 +31,12 @@ from test_network import LeNet5
def dataset_generator(batch_size, batches):
    """mock training data."""
-    data = np.random.random((batches * batch_size, 1, 32, 32)).astype(np.float32)
-    label = np.random.randint(0, 10, batches * batch_size).astype(np.int32)
+    data = np.random.random((batches * batch_size, 1, 32, 32)).astype(
+        np.float32)
+    label = np.random.randint(0, 10, batches * batch_size).astype(np.int32)
    for i in range(batches):
-        yield data[i * batch_size:(i + 1) * batch_size], label[i * batch_size:(i + 1) * batch_size]
+        yield data[i * batch_size:(i + 1) * batch_size],\
+              label[i * batch_size:(i + 1) * batch_size]


@pytest.mark.level0
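The generator change above only rewraps lines; its output is unchanged. For reference, with an illustrative batch_size of 32 and the batches = 128 value used in these tests, each yielded pair has these shapes:

    # batch_size chosen arbitrarily for illustration
    gen = dataset_generator(32, 128)
    data, label = next(gen)
    assert data.shape == (32, 1, 32, 32) and data.dtype == np.float32
    assert label.shape == (32,) and label.dtype == np.int32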
...
@@ -55,16 +58,24 @@ def test_dp_model_with_pynative_mode():
    factory_opt.set_mechanisms('Gaussian',
                               norm_bound=norm_clip,
                               initial_noise_multiplier=initial_noise_multiplier)
    net_opt = factory_opt.create('Momentum')(network.trainable_params(), learning_rate=0.1, momentum=0.9)
+    clip_mech = ClipMechanismsFactory().create('Gaussian',
+                                               decay_policy='Linear',
+                                               learning_rate=0.01,
+                                               target_unclipped_quantile=0.9,
+                                               fraction_stddev=0.01)
    model = DPModel(micro_batches=micro_batches,
                    norm_clip=norm_clip,
-                    mech=None,
+                    clip_mech=clip_mech,
+                    noise_mech=None,
                    network=network,
                    loss_fn=loss,
                    optimizer=net_opt,
                    metrics=None)
-    ms_ds = ds.GeneratorDataset(dataset_generator(batch_size, batches), ['data', 'label'])
+    ms_ds = ds.GeneratorDataset(dataset_generator(batch_size, batches),
+                                ['data', 'label'])
    ms_ds.set_dataset_size(batch_size * batches)
    model.train(epochs, ms_ds, dataset_sink_mode=False)
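Note that this PyNative test passes noise_mech=None: the Gaussian noise is configured on the optimizer side via DPOptimizerClassFactory.set_mechanisms, so only adaptive clipping is driven through clip_mech, while the graph-mode tests below do the opposite and pair a plain nn.Momentum with an explicit noise mechanism. Side by side, condensed from the tests in this file (variable definitions such as factory_opt, norm_clip, network and loss come from the surrounding test code):

    # (a) noise injected by the factory-built optimizer, so DPModel gets noise_mech=None
    factory_opt.set_mechanisms('Gaussian', norm_bound=norm_clip,
                               initial_noise_multiplier=initial_noise_multiplier)
    net_opt = factory_opt.create('Momentum')(network.trainable_params(),
                                             learning_rate=0.1, momentum=0.9)
    model = DPModel(micro_batches=micro_batches, norm_clip=norm_clip,
                    noise_mech=None, clip_mech=clip_mech,
                    network=network, loss_fn=loss, optimizer=net_opt, metrics=None)

    # (b) a regular optimizer plus an explicit noise mechanism
    noise_mech = NoiseMechanismsFactory().create('Gaussian', norm_bound=norm_clip,
                                                 initial_noise_multiplier=initial_noise_multiplier)
    net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.1, momentum=0.9)
    model = DPModel(micro_batches=2, norm_clip=norm_clip,
                    noise_mech=noise_mech, clip_mech=clip_mech,
                    network=network, loss_fn=loss, optimizer=net_opt, metrics=None)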
...
@@ -82,19 +93,27 @@ def test_dp_model_with_graph_mode():
    batches = 128
    epochs = 1
    loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
-    mech = MechanismsFactory().create('Gaussian',
+    noise_mech = NoiseMechanismsFactory().create('Gaussian',
                                      norm_bound=norm_clip,
                                      initial_noise_multiplier=initial_noise_multiplier)
+    clip_mech = ClipMechanismsFactory().create('Gaussian',
+                                               decay_policy='Linear',
+                                               learning_rate=0.01,
+                                               target_unclipped_quantile=0.9,
+                                               fraction_stddev=0.01)
    net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.1, momentum=0.9)
    model = DPModel(micro_batches=2,
+                    clip_mech=clip_mech,
                    norm_clip=norm_clip,
-                    mech=mech,
+                    noise_mech=noise_mech,
                    network=network,
                    loss_fn=loss,
                    optimizer=net_opt,
                    metrics=None)
-    ms_ds = ds.GeneratorDataset(dataset_generator(batch_size, batches), ['data', 'label'])
+    ms_ds = ds.GeneratorDataset(dataset_generator(batch_size, batches),
+                                ['data', 'label'])
    ms_ds.set_dataset_size(batch_size * batches)
    model.train(epochs, ms_ds, dataset_sink_mode=False)
...
@@ -112,17 +131,25 @@ def test_dp_model_with_graph_mode_ada_gaussian():
    batches = 128
    epochs = 1
    loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
-    mech = MechanismsFactory().create('AdaGaussian',
+    noise_mech = NoiseMechanismsFactory().create('AdaGaussian',
                                      norm_bound=norm_clip,
                                      initial_noise_multiplier=initial_noise_multiplier)
+    clip_mech = ClipMechanismsFactory().create('Gaussian',
+                                               decay_policy='Linear',
+                                               learning_rate=0.01,
+                                               target_unclipped_quantile=0.9,
+                                               fraction_stddev=0.01)
    net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.1, momentum=0.9)
    model = DPModel(micro_batches=2,
+                    clip_mech=clip_mech,
                    norm_clip=norm_clip,
-                    mech=mech,
+                    noise_mech=noise_mech,
                    network=network,
                    loss_fn=loss,
                    optimizer=net_opt,
                    metrics=None)
-    ms_ds = ds.GeneratorDataset(dataset_generator(batch_size, batches), ['data', 'label'])
+    ms_ds = ds.GeneratorDataset(dataset_generator(batch_size, batches),
+                                ['data', 'label'])
    ms_ds.set_dataset_size(batch_size * batches)
    model.train(epochs, ms_ds, dataset_sink_mode=False)