Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MindSpore
mindarmour
提交
69e45a3d
M
mindarmour
项目概览
MindSpore
/
mindarmour
通知
4
Star
2
Fork
3
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindarmour
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
69e45a3d
编写于
7月 22, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
7月 22, 2020
浏览文件
操作
浏览文件
下载
差异文件
!55 refactor mechanisms, fix exp formula error.
Merge pull request !55 from zheng-huanhuan/master
上级
1589a23b
0de5e062
变更
13
隐藏空白更改
内联
并排
Showing
13 changed file
with
301 addition
and
298 deletion
+301
-298
example/mnist_demo/lenet5_config.py
example/mnist_demo/lenet5_config.py
+1
-1
example/mnist_demo/lenet5_dp.py
example/mnist_demo/lenet5_dp.py
+5
-4
example/mnist_demo/lenet5_dp_pynative_model.py
example/mnist_demo/lenet5_dp_pynative_model.py
+22
-7
mindarmour/diff_privacy/__init__.py
mindarmour/diff_privacy/__init__.py
+2
-2
mindarmour/diff_privacy/mechanisms/mechanisms.py
mindarmour/diff_privacy/mechanisms/mechanisms.py
+88
-104
mindarmour/diff_privacy/optimizer/optimizer.py
mindarmour/diff_privacy/optimizer/optimizer.py
+2
-2
mindarmour/diff_privacy/train/model.py
mindarmour/diff_privacy/train/model.py
+37
-38
mindarmour/fuzzing/model_coverage_metrics.py
mindarmour/fuzzing/model_coverage_metrics.py
+4
-4
requirements.txt
requirements.txt
+1
-1
setup.py
setup.py
+1
-1
tests/ut/python/diff_privacy/test_mechanisms.py
tests/ut/python/diff_privacy/test_mechanisms.py
+113
-112
tests/ut/python/diff_privacy/test_model_train.py
tests/ut/python/diff_privacy/test_model_train.py
+13
-10
tests/ut/python/diff_privacy/test_optimizer.py
tests/ut/python/diff_privacy/test_optimizer.py
+12
-12
未找到文件。
example/mnist_demo/lenet5_config.py
浏览文件 @
69e45a3d
...
@@ -32,7 +32,7 @@ mnist_cfg = edict({
...
@@ -32,7 +32,7 @@ mnist_cfg = edict({
'data_path'
:
'./MNIST_unzip'
,
# the path of training and testing data set
'data_path'
:
'./MNIST_unzip'
,
# the path of training and testing data set
'dataset_sink_mode'
:
False
,
# whether deliver all training data to device one time
'dataset_sink_mode'
:
False
,
# whether deliver all training data to device one time
'micro_batches'
:
16
,
# the number of small batches split from an original batch
'micro_batches'
:
16
,
# the number of small batches split from an original batch
'norm_
clip
'
:
1.0
,
# the clip bound of the gradients of model's training parameters
'norm_
bound
'
:
1.0
,
# the clip bound of the gradients of model's training parameters
'initial_noise_multiplier'
:
0.5
,
# the initial multiplication coefficient of the noise added to training
'initial_noise_multiplier'
:
0.5
,
# the initial multiplication coefficient of the noise added to training
# parameters' gradients
# parameters' gradients
'noise_mechanisms'
:
'AdaGaussian'
,
# the method of adding noise in gradients while training
'noise_mechanisms'
:
'AdaGaussian'
,
# the method of adding noise in gradients while training
...
...
example/mnist_demo/lenet5_dp.py
浏览文件 @
69e45a3d
...
@@ -115,8 +115,9 @@ if __name__ == "__main__":
...
@@ -115,8 +115,9 @@ if __name__ == "__main__":
# or 'AdaGaussian', in which noise would be decayed with 'AdaGaussian'
# or 'AdaGaussian', in which noise would be decayed with 'AdaGaussian'
# mechanism while be constant with 'Gaussian' mechanism.
# mechanism while be constant with 'Gaussian' mechanism.
noise_mech
=
NoiseMechanismsFactory
().
create
(
cfg
.
noise_mechanisms
,
noise_mech
=
NoiseMechanismsFactory
().
create
(
cfg
.
noise_mechanisms
,
norm_bound
=
cfg
.
norm_clip
,
norm_bound
=
cfg
.
norm_bound
,
initial_noise_multiplier
=
cfg
.
initial_noise_multiplier
)
initial_noise_multiplier
=
cfg
.
initial_noise_multiplier
,
noise_update
=
'Exp'
)
# Create a factory class of clip mechanisms, this method is to adaptive clip
# Create a factory class of clip mechanisms, this method is to adaptive clip
# gradients while training, decay_policy support 'Linear' and 'Geometric',
# gradients while training, decay_policy support 'Linear' and 'Geometric',
# learning_rate is the learning rate to update clip_norm,
# learning_rate is the learning rate to update clip_norm,
...
@@ -136,11 +137,11 @@ if __name__ == "__main__":
...
@@ -136,11 +137,11 @@ if __name__ == "__main__":
rdp_monitor
=
PrivacyMonitorFactory
.
create
(
'rdp'
,
rdp_monitor
=
PrivacyMonitorFactory
.
create
(
'rdp'
,
num_samples
=
60000
,
num_samples
=
60000
,
batch_size
=
cfg
.
batch_size
,
batch_size
=
cfg
.
batch_size
,
initial_noise_multiplier
=
cfg
.
initial_noise_multiplier
*
cfg
.
norm_
clip
,
initial_noise_multiplier
=
cfg
.
initial_noise_multiplier
*
cfg
.
norm_
bound
,
per_print_times
=
10
)
per_print_times
=
10
)
# Create the DP model for training.
# Create the DP model for training.
model
=
DPModel
(
micro_batches
=
cfg
.
micro_batches
,
model
=
DPModel
(
micro_batches
=
cfg
.
micro_batches
,
norm_
clip
=
cfg
.
norm_clip
,
norm_
bound
=
cfg
.
norm_bound
,
noise_mech
=
noise_mech
,
noise_mech
=
noise_mech
,
clip_mech
=
clip_mech
,
clip_mech
=
clip_mech
,
network
=
network
,
network
=
network
,
...
...
example/mnist_demo/lenet5_dp_pynative_mode.py
→
example/mnist_demo/lenet5_dp_pynative_mode
l
.py
浏览文件 @
69e45a3d
...
@@ -12,7 +12,7 @@
...
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
"""
"""
python lenet5_dp_pynative_mode.py --data_path /YourDataPath --micro_batches=2
python lenet5_dp_pynative_mode
l
.py --data_path /YourDataPath --micro_batches=2
"""
"""
import
os
import
os
...
@@ -32,6 +32,7 @@ import mindspore.common.dtype as mstype
...
@@ -32,6 +32,7 @@ import mindspore.common.dtype as mstype
from
mindarmour.diff_privacy
import
DPModel
from
mindarmour.diff_privacy
import
DPModel
from
mindarmour.diff_privacy
import
PrivacyMonitorFactory
from
mindarmour.diff_privacy
import
PrivacyMonitorFactory
from
mindarmour.diff_privacy
import
DPOptimizerClassFactory
from
mindarmour.diff_privacy
import
DPOptimizerClassFactory
from
mindarmour.diff_privacy
import
ClipMechanismsFactory
from
mindarmour.utils.logger
import
LogUtil
from
mindarmour.utils.logger
import
LogUtil
from
lenet5_net
import
LeNet5
from
lenet5_net
import
LeNet5
from
lenet5_config
import
mnist_cfg
as
cfg
from
lenet5_config
import
mnist_cfg
as
cfg
...
@@ -108,21 +109,35 @@ if __name__ == "__main__":
...
@@ -108,21 +109,35 @@ if __name__ == "__main__":
# means that the privacy protection effect is weak. Mechanisms can be 'Gaussian' or 'AdaGaussian', in which noise
# means that the privacy protection effect is weak. Mechanisms can be 'Gaussian' or 'AdaGaussian', in which noise
# would be decayed with 'AdaGaussian' mechanism while be constant with 'Gaussian' mechanism.
# would be decayed with 'AdaGaussian' mechanism while be constant with 'Gaussian' mechanism.
dp_opt
=
DPOptimizerClassFactory
(
micro_batches
=
cfg
.
micro_batches
)
dp_opt
=
DPOptimizerClassFactory
(
micro_batches
=
cfg
.
micro_batches
)
dp_opt
.
set_mechanisms
(
cfg
.
mechanisms
,
dp_opt
.
set_mechanisms
(
cfg
.
noise_mechanisms
,
norm_bound
=
cfg
.
norm_clip
,
norm_bound
=
cfg
.
norm_bound
,
initial_noise_multiplier
=
cfg
.
initial_noise_multiplier
)
initial_noise_multiplier
=
cfg
.
initial_noise_multiplier
,
noise_update
=
'Exp'
)
# Create a factory class of clip mechanisms, this method is to adaptive clip
# gradients while training, decay_policy support 'Linear' and 'Geometric',
# learning_rate is the learning rate to update clip_norm,
# target_unclipped_quantile is the target quantile of norm clip,
# fraction_stddev is the stddev of Gaussian normal which used in
# empirical_fraction, the formula is
# $empirical_fraction + N(0, fraction_stddev)$.
clip_mech
=
ClipMechanismsFactory
().
create
(
cfg
.
clip_mechanisms
,
decay_policy
=
cfg
.
clip_decay_policy
,
learning_rate
=
cfg
.
clip_learning_rate
,
target_unclipped_quantile
=
cfg
.
target_unclipped_quantile
,
fraction_stddev
=
cfg
.
fraction_stddev
)
net_opt
=
dp_opt
.
create
(
'Momentum'
)(
params
=
network
.
trainable_params
(),
learning_rate
=
cfg
.
lr
,
momentum
=
cfg
.
momentum
)
net_opt
=
dp_opt
.
create
(
'Momentum'
)(
params
=
network
.
trainable_params
(),
learning_rate
=
cfg
.
lr
,
momentum
=
cfg
.
momentum
)
# Create a monitor for DP training. The function of the monitor is to compute and print the privacy budget(eps
# Create a monitor for DP training. The function of the monitor is to compute and print the privacy budget(eps
# and delta) while training.
# and delta) while training.
rdp_monitor
=
PrivacyMonitorFactory
.
create
(
'rdp'
,
rdp_monitor
=
PrivacyMonitorFactory
.
create
(
'rdp'
,
num_samples
=
60000
,
num_samples
=
60000
,
batch_size
=
cfg
.
batch_size
,
batch_size
=
cfg
.
batch_size
,
initial_noise_multiplier
=
cfg
.
initial_noise_multiplier
*
cfg
.
norm_
clip
,
initial_noise_multiplier
=
cfg
.
initial_noise_multiplier
*
cfg
.
norm_
bound
,
per_print_times
=
10
)
per_print_times
=
10
)
# Create the DP model for training.
# Create the DP model for training.
model
=
DPModel
(
micro_batches
=
cfg
.
micro_batches
,
model
=
DPModel
(
micro_batches
=
cfg
.
micro_batches
,
norm_clip
=
cfg
.
norm_clip
,
norm_bound
=
cfg
.
norm_bound
,
mech
=
None
,
noise_mech
=
None
,
clip_mech
=
clip_mech
,
network
=
network
,
network
=
network
,
loss_fn
=
net_loss
,
loss_fn
=
net_loss
,
optimizer
=
net_opt
,
optimizer
=
net_opt
,
...
...
mindarmour/diff_privacy/__init__.py
浏览文件 @
69e45a3d
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
This module provide Differential Privacy feature to protect user privacy.
This module provide Differential Privacy feature to protect user privacy.
"""
"""
from
.mechanisms.mechanisms
import
NoiseGaussianRandom
from
.mechanisms.mechanisms
import
NoiseGaussianRandom
from
.mechanisms.mechanisms
import
AdaGaussianRandom
from
.mechanisms.mechanisms
import
Noise
AdaGaussianRandom
from
.mechanisms.mechanisms
import
AdaClippingWithGaussianRandom
from
.mechanisms.mechanisms
import
AdaClippingWithGaussianRandom
from
.mechanisms.mechanisms
import
NoiseMechanismsFactory
from
.mechanisms.mechanisms
import
NoiseMechanismsFactory
from
.mechanisms.mechanisms
import
ClipMechanismsFactory
from
.mechanisms.mechanisms
import
ClipMechanismsFactory
...
@@ -11,7 +11,7 @@ from .optimizer.optimizer import DPOptimizerClassFactory
...
@@ -11,7 +11,7 @@ from .optimizer.optimizer import DPOptimizerClassFactory
from
.train.model
import
DPModel
from
.train.model
import
DPModel
__all__
=
[
'NoiseGaussianRandom'
,
__all__
=
[
'NoiseGaussianRandom'
,
'AdaGaussianRandom'
,
'
Noise
AdaGaussianRandom'
,
'AdaClippingWithGaussianRandom'
,
'AdaClippingWithGaussianRandom'
,
'NoiseMechanismsFactory'
,
'NoiseMechanismsFactory'
,
'ClipMechanismsFactory'
,
'ClipMechanismsFactory'
,
...
...
mindarmour/diff_privacy/mechanisms/mechanisms.py
浏览文件 @
69e45a3d
...
@@ -19,6 +19,7 @@ from abc import abstractmethod
...
@@ -19,6 +19,7 @@ from abc import abstractmethod
from
mindspore
import
Tensor
from
mindspore
import
Tensor
from
mindspore.nn
import
Cell
from
mindspore.nn
import
Cell
from
mindspore.ops
import
operations
as
P
from
mindspore.ops
import
operations
as
P
from
mindspore.ops.composite
import
normal
from
mindspore.common.parameter
import
Parameter
from
mindspore.common.parameter
import
Parameter
from
mindspore.common
import
dtype
as
mstype
from
mindspore.common
import
dtype
as
mstype
...
@@ -55,7 +56,7 @@ class ClipMechanismsFactory:
...
@@ -55,7 +56,7 @@ class ClipMechanismsFactory:
Examples:
Examples:
>>> decay_policy = 'Linear'
>>> decay_policy = 'Linear'
>>> beta = Tensor(0.5, mstype.float32)
>>> beta = Tensor(0.5, mstype.float32)
>>> norm_
clip
= Tensor(1.0, mstype.float32)
>>> norm_
bound
= Tensor(1.0, mstype.float32)
>>> beta_stddev = 0.1
>>> beta_stddev = 0.1
>>> learning_rate = 0.1
>>> learning_rate = 0.1
>>> target_unclipped_quantile = 0.3
>>> target_unclipped_quantile = 0.3
...
@@ -65,7 +66,7 @@ class ClipMechanismsFactory:
...
@@ -65,7 +66,7 @@ class ClipMechanismsFactory:
>>> learning_rate=learning_rate,
>>> learning_rate=learning_rate,
>>> target_unclipped_quantile=target_unclipped_quantile,
>>> target_unclipped_quantile=target_unclipped_quantile,
>>> fraction_stddev=beta_stddev)
>>> fraction_stddev=beta_stddev)
>>> next_norm_
clip = ada_clip(beta, norm_clip
)
>>> next_norm_
bound = ada_clip(beta, norm_bound
)
"""
"""
if
mech_name
==
'Gaussian'
:
if
mech_name
==
'Gaussian'
:
...
@@ -81,25 +82,32 @@ class NoiseMechanismsFactory:
...
@@ -81,25 +82,32 @@ class NoiseMechanismsFactory:
pass
pass
@
staticmethod
@
staticmethod
def
create
(
policy
,
*
args
,
**
kwargs
):
def
create
(
mech_name
=
'Gaussian'
,
norm_bound
=
0.5
,
initial_noise_multiplier
=
1.5
,
seed
=
0
,
noise_decay_rate
=
6e-6
,
noise_update
=
None
):
"""
"""
Args:
Args:
policy
(str): Noise generated strategy, could be 'Gaussian' or
mech_name
(str): Noise generated strategy, could be 'Gaussian' or
'AdaGaussian'. Noise would be decayed with 'AdaGaussian' mechanism
'AdaGaussian'. Noise would be decayed with 'AdaGaussian' mechanism
while be constant with 'Gaussian' mechanism.
while be constant with 'Gaussian' mechanism.
args(Union[float, str]): Parameters used for creating noise
norm_bound(float): Clipping bound for the l2 norm of the gradients.
mechanisms.
initial_noise_multiplier(float): Ratio of the standard deviation of
kwargs(Union[float, str]): Parameters used for creating noise
Gaussian noise divided by the norm_bound, which will be used to
mechanisms.
calculate privacy spent.
seed(int): Original random seed, if seed=0 random normal will use secure
random number. IF seed!=0 random normal will generate values using
given seed.
noise_decay_rate(float): Hyper parameter for controlling the noise decay.
noise_update(str): Mechanisms parameters update policy. Default: None, no
parameters need update.
Raises:
Raises:
NameError: `
policy
` must be in ['Gaussian', 'AdaGaussian'].
NameError: `
mech_name
` must be in ['Gaussian', 'AdaGaussian'].
Returns:
Returns:
Mechanisms, class of noise generated Mechanism.
Mechanisms, class of noise generated Mechanism.
Examples:
Examples:
>>> norm_
clip
= 1.0
>>> norm_
bound
= 1.0
>>> initial_noise_multiplier = 0.01
>>> initial_noise_multiplier = 0.01
>>> network = LeNet5()
>>> network = LeNet5()
>>> batch_size = 32
>>> batch_size = 32
...
@@ -107,7 +115,7 @@ class NoiseMechanismsFactory:
...
@@ -107,7 +115,7 @@ class NoiseMechanismsFactory:
>>> epochs = 1
>>> epochs = 1
>>> loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
>>> loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
>>> noise_mech = NoiseMechanismsFactory().create('Gaussian',
>>> noise_mech = NoiseMechanismsFactory().create('Gaussian',
>>> norm_bound=norm_
clip
,
>>> norm_bound=norm_
bound
,
>>> initial_noise_multiplier=initial_noise_multiplier)
>>> initial_noise_multiplier=initial_noise_multiplier)
>>> clip_mech = ClipMechanismsFactory().create('Gaussian',
>>> clip_mech = ClipMechanismsFactory().create('Gaussian',
>>> decay_policy='Linear',
>>> decay_policy='Linear',
...
@@ -118,7 +126,7 @@ class NoiseMechanismsFactory:
...
@@ -118,7 +126,7 @@ class NoiseMechanismsFactory:
>>> momentum=0.9)
>>> momentum=0.9)
>>> model = DPModel(micro_batches=2,
>>> model = DPModel(micro_batches=2,
>>> clip_mech=clip_mech,
>>> clip_mech=clip_mech,
>>> norm_
clip=norm_clip
,
>>> norm_
bound=norm_bound
,
>>> noise_mech=noise_mech,
>>> noise_mech=noise_mech,
>>> network=network,
>>> network=network,
>>> loss_fn=loss,
>>> loss_fn=loss,
...
@@ -129,15 +137,22 @@ class NoiseMechanismsFactory:
...
@@ -129,15 +137,22 @@ class NoiseMechanismsFactory:
>>> ms_ds.set_dataset_size(batch_size * batches)
>>> ms_ds.set_dataset_size(batch_size * batches)
>>> model.train(epochs, ms_ds, dataset_sink_mode=False)
>>> model.train(epochs, ms_ds, dataset_sink_mode=False)
"""
"""
if
policy
==
'Gaussian'
:
if
mech_name
==
'Gaussian'
:
return
NoiseGaussianRandom
(
*
args
,
**
kwargs
)
return
NoiseGaussianRandom
(
norm_bound
=
norm_bound
,
if
policy
==
'AdaGaussian'
:
initial_noise_multiplier
=
initial_noise_multiplier
,
return
AdaGaussianRandom
(
*
args
,
**
kwargs
)
seed
=
seed
,
noise_update
=
noise_update
)
if
mech_name
==
'AdaGaussian'
:
return
NoiseAdaGaussianRandom
(
norm_bound
=
norm_bound
,
initial_noise_multiplier
=
initial_noise_multiplier
,
seed
=
seed
,
noise_decay_rate
=
noise_decay_rate
,
noise_update
=
noise_update
)
raise
NameError
(
"The {} is not implement, please choose "
raise
NameError
(
"The {} is not implement, please choose "
"['Gaussian', 'AdaGaussian']"
.
format
(
policy
))
"['Gaussian', 'AdaGaussian']"
.
format
(
mech_name
))
class
Mechanisms
(
Cell
):
class
_
Mechanisms
(
Cell
):
"""
"""
Basic class of noise generated mechanism.
Basic class of noise generated mechanism.
"""
"""
...
@@ -149,21 +164,19 @@ class Mechanisms(Cell):
...
@@ -149,21 +164,19 @@ class Mechanisms(Cell):
"""
"""
class
NoiseGaussianRandom
(
Mechanisms
):
class
NoiseGaussianRandom
(
_
Mechanisms
):
"""
"""
Gaussian noise generated mechanism.
Gaussian noise generated mechanism.
Args:
Args:
norm_bound(float): Clipping bound for the l2 norm of the gradients.
norm_bound(float): Clipping bound for the l2 norm of the gradients.
Default: 0.5.
initial_noise_multiplier(float): Ratio of the standard deviation of
initial_noise_multiplier(float): Ratio of the standard deviation of
Gaussian noise divided by the norm_bound, which will be used to
Gaussian noise divided by the norm_bound, which will be used to
calculate privacy spent.
Default: 1.5.
calculate privacy spent.
seed(int): Original random seed, if seed=0 random normal will use secure
seed(int): Original random seed, if seed=0 random normal will use secure
random number. IF seed!=0 random normal will generate values using
random number. IF seed!=0 random normal will generate values using
given seed. Default: 0.
given seed.
policy(str): Mechanisms parameters update policy. Default: None, no
noise_update(str): Mechanisms parameters update policy. Default: None.
parameters need update.
Returns:
Returns:
Tensor, generated noise with shape like given gradients.
Tensor, generated noise with shape like given gradients.
...
@@ -172,24 +185,25 @@ class NoiseGaussianRandom(Mechanisms):
...
@@ -172,24 +185,25 @@ class NoiseGaussianRandom(Mechanisms):
>>> gradients = Tensor([0.2, 0.9], mstype.float32)
>>> gradients = Tensor([0.2, 0.9], mstype.float32)
>>> norm_bound = 0.5
>>> norm_bound = 0.5
>>> initial_noise_multiplier = 1.5
>>> initial_noise_multiplier = 1.5
>>> net = NoiseGaussianRandom(norm_bound, initial_noise_multiplier)
>>> seed = 0
>>> noise_update = None
>>> net = NoiseGaussianRandom(norm_bound, initial_noise_multiplier, seed, noise_update)
>>> res = net(gradients)
>>> res = net(gradients)
>>> print(res)
>>> print(res)
"""
"""
def
__init__
(
self
,
norm_bound
=
0.5
,
initial_noise_multiplier
=
1.5
,
seed
=
0
,
def
__init__
(
self
,
norm_bound
,
initial_noise_multiplier
,
seed
,
noise_update
=
None
):
policy
=
None
):
super
(
NoiseGaussianRandom
,
self
).
__init__
()
super
(
NoiseGaussianRandom
,
self
).
__init__
()
self
.
_norm_bound
=
check_value_positive
(
'norm_bound'
,
norm_bound
)
self
.
_norm_bound
=
check_value_positive
(
'norm_bound'
,
norm_bound
)
self
.
_norm_bound
=
Tensor
(
norm_bound
,
mstype
.
float32
)
self
.
_norm_bound
=
Tensor
(
norm_bound
,
mstype
.
float32
)
self
.
_initial_noise_multiplier
=
check_value_positive
(
self
.
_initial_noise_multiplier
=
check_value_positive
(
'initial_noise_multiplier'
,
'initial_noise_multiplier'
,
initial_noise_multiplier
)
initial_noise_multiplier
)
self
.
_initial_noise_multiplier
=
Tensor
(
initial_noise_multiplier
,
mstype
.
float32
)
self
.
_initial_noise_multiplier
=
Tensor
(
initial_noise_multiplier
,
mstype
.
float32
)
self
.
_mean
=
Tensor
(
0
,
mstype
.
float32
)
self
.
_mean
=
Tensor
(
0
,
mstype
.
float32
)
self
.
_normal
=
P
.
Normal
(
seed
=
seed
)
if
noise_update
is
not
None
:
self
.
_decay_policy
=
policy
raise
ValueError
(
'noise_update must be None in GaussianRandom class, but got {}.'
.
format
(
noise_update
))
self
.
_noise_update
=
noise_update
self
.
_seed
=
seed
def
construct
(
self
,
gradients
):
def
construct
(
self
,
gradients
):
"""
"""
...
@@ -203,26 +217,25 @@ class NoiseGaussianRandom(Mechanisms):
...
@@ -203,26 +217,25 @@ class NoiseGaussianRandom(Mechanisms):
"""
"""
shape
=
P
.
Shape
()(
gradients
)
shape
=
P
.
Shape
()(
gradients
)
stddev
=
P
.
Mul
()(
self
.
_norm_bound
,
self
.
_initial_noise_multiplier
)
stddev
=
P
.
Mul
()(
self
.
_norm_bound
,
self
.
_initial_noise_multiplier
)
noise
=
self
.
_normal
(
shape
,
self
.
_mean
,
stddev
)
noise
=
normal
(
shape
,
self
.
_mean
,
stddev
,
self
.
_seed
)
return
noise
return
noise
class
AdaGaussianRandom
(
Mechanisms
):
class
NoiseAdaGaussianRandom
(
NoiseGaussianRandom
):
"""
"""
Adaptive Gaussian noise generated mechanism. Noise would be decayed with
Adaptive Gaussian noise generated mechanism. Noise would be decayed with
training. Decay mode could be 'Time' mode
or 'Ste
p' mode.
training. Decay mode could be 'Time' mode
, 'Step' mode, 'Ex
p' mode.
Args:
Args:
norm_bound(float): Clipping bound for the l2 norm of the gradients.
norm_bound(float): Clipping bound for the l2 norm of the gradients.
Default: 1.0.
initial_noise_multiplier(float): Ratio of the standard deviation of
initial_noise_multiplier(float): Ratio of the standard deviation of
Gaussian noise divided by the norm_bound, which will be used to
Gaussian noise divided by the norm_bound, which will be used to
calculate privacy spent. Default: 1.5.
calculate privacy spent.
seed(int): Original random seed, if seed=0 random normal will use secure
random number. IF seed!=0 random normal will generate values using
given seed.
noise_decay_rate(float): Hyper parameter for controlling the noise decay.
noise_decay_rate(float): Hyper parameter for controlling the noise decay.
Default: 6e-4.
noise_update(str): Noise decay strategy include 'Step', 'Time', 'Exp'.
decay_policy(str): Noise decay strategy include 'Step' and 'Time'.
Default: 'Time'.
seed(int): Original random seed. Default: 0.
Returns:
Returns:
Tensor, generated noise with shape like given gradients.
Tensor, generated noise with shape like given gradients.
...
@@ -231,56 +244,27 @@ class AdaGaussianRandom(Mechanisms):
...
@@ -231,56 +244,27 @@ class AdaGaussianRandom(Mechanisms):
>>> gradients = Tensor([0.2, 0.9], mstype.float32)
>>> gradients = Tensor([0.2, 0.9], mstype.float32)
>>> norm_bound = 1.0
>>> norm_bound = 1.0
>>> initial_noise_multiplier = 1.5
>>> initial_noise_multiplier = 1.5
>>> seed = 0
>>> noise_decay_rate = 6e-4
>>> noise_decay_rate = 6e-4
>>> decay_policy = "Time"
>>> noise_update = "Time"
>>> net = AdaGaussianRandom(norm_bound, initial_noise_multiplier,
>>> net = NoiseAdaGaussianRandom(norm_bound, initial_noise_multiplier, seed, noise_decay_rate, noise_update)
>>> noise_decay_rate, decay_policy)
>>> res = net(gradients)
>>> res = net(gradients)
>>> print(res)
>>> print(res)
"""
"""
def
__init__
(
self
,
norm_bound
=
1.0
,
initial_noise_multiplier
=
1.5
,
def
__init__
(
self
,
norm_bound
,
initial_noise_multiplier
,
seed
,
noise_decay_rate
,
noise_update
):
noise_decay_rate
=
6e-4
,
decay_policy
=
'Time'
,
seed
=
0
):
super
(
NoiseAdaGaussianRandom
,
self
).
__init__
(
norm_bound
=
norm_bound
,
super
(
AdaGaussianRandom
,
self
).
__init__
()
initial_noise_multiplier
=
initial_noise_multiplier
,
norm_bound
=
check_value_positive
(
'norm_bound'
,
norm_bound
)
seed
=
seed
)
initial_noise_multiplier
=
check_value_positive
(
self
.
_noise_multiplier
=
Parameter
(
self
.
_initial_noise_multiplier
,
'initial_noise_multiplier'
,
initial_noise_multiplier
)
self
.
_norm_bound
=
Tensor
(
norm_bound
,
mstype
.
float32
)
initial_noise_multiplier
=
Tensor
(
initial_noise_multiplier
,
mstype
.
float32
)
self
.
_initial_noise_multiplier
=
Parameter
(
initial_noise_multiplier
,
name
=
'initial_noise_multiplier'
)
self
.
_noise_multiplier
=
Parameter
(
initial_noise_multiplier
,
name
=
'noise_multiplier'
)
name
=
'noise_multiplier'
)
self
.
_mean
=
Tensor
(
0
,
mstype
.
float32
)
noise_decay_rate
=
check_param_type
(
'noise_decay_rate'
,
noise_decay_rate
,
float
)
noise_decay_rate
=
check_param_type
(
'noise_decay_rate'
,
noise_decay_rate
,
float
)
check_param_in_range
(
'noise_decay_rate'
,
noise_decay_rate
,
0.0
,
1.0
)
check_param_in_range
(
'noise_decay_rate'
,
noise_decay_rate
,
0.0
,
1.0
)
self
.
_noise_decay_rate
=
Tensor
(
noise_decay_rate
,
mstype
.
float32
)
self
.
_noise_decay_rate
=
Tensor
(
noise_decay_rate
,
mstype
.
float32
)
if
decay_policy
not
in
[
'Time'
,
'Step'
,
'Exp'
]:
if
noise_update
not
in
[
'Time'
,
'Step'
,
'Exp'
]:
raise
NameError
(
"The decay_policy must be in ['Time', 'Step', 'Exp'], but "
raise
NameError
(
"The noise_update must be in ['Time', 'Step', 'Exp'], but "
"get {}"
.
format
(
decay_policy
))
"get {}"
.
format
(
noise_update
))
self
.
_decay_policy
=
decay_policy
self
.
_noise_update
=
noise_update
self
.
_mul
=
P
.
Mul
()
self
.
_normal
=
P
.
Normal
(
seed
=
seed
)
def
construct
(
self
,
gradients
):
"""
Generate adaptive Gaussian noise.
Args:
gradients(Tensor): The gradients.
Returns:
Tensor, generated noise with shape like given gradients.
"""
shape
=
P
.
Shape
()(
gradients
)
noise
=
self
.
_normal
(
shape
,
self
.
_mean
,
self
.
_mul
(
self
.
_noise_multiplier
,
self
.
_norm_bound
))
return
noise
class
_MechanismsParamsUpdater
(
Cell
):
class
_MechanismsParamsUpdater
(
Cell
):
...
@@ -288,7 +272,7 @@ class _MechanismsParamsUpdater(Cell):
...
@@ -288,7 +272,7 @@ class _MechanismsParamsUpdater(Cell):
Update mechanisms parameters, the parameters will refresh in train period.
Update mechanisms parameters, the parameters will refresh in train period.
Args:
Args:
policy
(str): Pass in by the mechanisms class, mechanisms parameters
noise_update
(str): Pass in by the mechanisms class, mechanisms parameters
update policy.
update policy.
decay_rate(Tensor): Pass in by the mechanisms class, hyper parameter for
decay_rate(Tensor): Pass in by the mechanisms class, hyper parameter for
controlling the decay size.
controlling the decay size.
...
@@ -300,9 +284,9 @@ class _MechanismsParamsUpdater(Cell):
...
@@ -300,9 +284,9 @@ class _MechanismsParamsUpdater(Cell):
Returns:
Returns:
Tuple, next params value.
Tuple, next params value.
"""
"""
def
__init__
(
self
,
policy
,
decay_rate
,
cur_noise_multiplier
,
init_noise_multiplier
):
def
__init__
(
self
,
noise_update
,
decay_rate
,
cur_noise_multiplier
,
init_noise_multiplier
):
super
(
_MechanismsParamsUpdater
,
self
).
__init__
()
super
(
_MechanismsParamsUpdater
,
self
).
__init__
()
self
.
_
policy
=
policy
self
.
_
noise_update
=
noise_update
self
.
_decay_rate
=
decay_rate
self
.
_decay_rate
=
decay_rate
self
.
_cur_noise_multiplier
=
cur_noise_multiplier
self
.
_cur_noise_multiplier
=
cur_noise_multiplier
self
.
_init_noise_multiplier
=
init_noise_multiplier
self
.
_init_noise_multiplier
=
init_noise_multiplier
...
@@ -322,27 +306,27 @@ class _MechanismsParamsUpdater(Cell):
...
@@ -322,27 +306,27 @@ class _MechanismsParamsUpdater(Cell):
Returns:
Returns:
Tuple, next step parameters value.
Tuple, next step parameters value.
"""
"""
if
self
.
_
policy
==
'Time'
:
if
self
.
_
noise_update
==
'Time'
:
temp
=
self
.
_div
(
self
.
_init_noise_multiplier
,
self
.
_cur_noise_multiplier
)
temp
=
self
.
_div
(
self
.
_init_noise_multiplier
,
self
.
_cur_noise_multiplier
)
temp
=
self
.
_add
(
temp
,
self
.
_decay_rate
)
temp
=
self
.
_add
(
temp
,
self
.
_decay_rate
)
next_noise_multiplier
=
self
.
_assign
(
self
.
_cur_noise_multiplier
,
next_noise_multiplier
=
self
.
_assign
(
self
.
_cur_noise_multiplier
,
self
.
_div
(
self
.
_init_noise_multiplier
,
temp
))
self
.
_div
(
self
.
_init_noise_multiplier
,
temp
))
elif
self
.
_
policy
==
'Step'
:
elif
self
.
_
noise_update
==
'Step'
:
temp
=
self
.
_sub
(
self
.
_one
,
self
.
_decay_rate
)
temp
=
self
.
_sub
(
self
.
_one
,
self
.
_decay_rate
)
next_noise_multiplier
=
self
.
_assign
(
self
.
_cur_noise_multiplier
,
next_noise_multiplier
=
self
.
_assign
(
self
.
_cur_noise_multiplier
,
self
.
_mul
(
temp
,
self
.
_cur_noise_multiplier
))
self
.
_mul
(
temp
,
self
.
_cur_noise_multiplier
))
else
:
else
:
next_noise_multiplier
=
self
.
_assign
(
self
.
_cur_noise_multiplier
,
next_noise_multiplier
=
self
.
_assign
(
self
.
_cur_noise_multiplier
,
self
.
_div
(
self
.
_
one
,
self
.
_exp
(
self
.
_on
e
)))
self
.
_div
(
self
.
_
cur_noise_multiplier
,
self
.
_exp
(
self
.
_decay_rat
e
)))
return
next_noise_multiplier
return
next_noise_multiplier
class
AdaClippingWithGaussianRandom
(
Cell
):
class
AdaClippingWithGaussianRandom
(
Cell
):
"""
"""
Adaptive clipping. If `decay_policy` is 'Linear', the update formula is
Adaptive clipping. If `decay_policy` is 'Linear', the update formula is
$ norm_
clip = norm_clip
- learning_rate*(beta-target_unclipped_quantile)$.
$ norm_
bound = norm_bound
- learning_rate*(beta-target_unclipped_quantile)$.
`decay_policy` is 'Geometric', the update formula is
`decay_policy` is 'Geometric', the update formula is
$ norm_
clip = norm_clip
*exp(-learning_rate*(empirical_fraction-target_unclipped_quantile))$.
$ norm_
bound = norm_bound
*exp(-learning_rate*(empirical_fraction-target_unclipped_quantile))$.
where beta is the empirical fraction of samples with the value at most
where beta is the empirical fraction of samples with the value at most
`target_unclipped_quantile`.
`target_unclipped_quantile`.
...
@@ -363,7 +347,7 @@ class AdaClippingWithGaussianRandom(Cell):
...
@@ -363,7 +347,7 @@ class AdaClippingWithGaussianRandom(Cell):
Examples:
Examples:
>>> decay_policy = 'Linear'
>>> decay_policy = 'Linear'
>>> beta = Tensor(0.5, mstype.float32)
>>> beta = Tensor(0.5, mstype.float32)
>>> norm_
clip
= Tensor(1.0, mstype.float32)
>>> norm_
bound
= Tensor(1.0, mstype.float32)
>>> beta_stddev = 0.01
>>> beta_stddev = 0.01
>>> learning_rate = 0.001
>>> learning_rate = 0.001
>>> target_unclipped_quantile = 0.9
>>> target_unclipped_quantile = 0.9
...
@@ -371,7 +355,7 @@ class AdaClippingWithGaussianRandom(Cell):
...
@@ -371,7 +355,7 @@ class AdaClippingWithGaussianRandom(Cell):
>>> learning_rate=learning_rate,
>>> learning_rate=learning_rate,
>>> target_unclipped_quantile=target_unclipped_quantile,
>>> target_unclipped_quantile=target_unclipped_quantile,
>>> fraction_stddev=beta_stddev)
>>> fraction_stddev=beta_stddev)
>>> next_norm_
clip = ada_clip(beta, norm_clip
)
>>> next_norm_
bound = ada_clip(beta, norm_bound
)
"""
"""
...
@@ -400,32 +384,32 @@ class AdaClippingWithGaussianRandom(Cell):
...
@@ -400,32 +384,32 @@ class AdaClippingWithGaussianRandom(Cell):
self
.
_sub
=
P
.
Sub
()
self
.
_sub
=
P
.
Sub
()
self
.
_mul
=
P
.
Mul
()
self
.
_mul
=
P
.
Mul
()
self
.
_exp
=
P
.
Exp
()
self
.
_exp
=
P
.
Exp
()
self
.
_
normal
=
P
.
Normal
(
seed
=
seed
)
self
.
_
seed
=
seed
def
construct
(
self
,
empirical_fraction
,
norm_
clip
):
def
construct
(
self
,
empirical_fraction
,
norm_
bound
):
"""
"""
Update value of norm_
clip
.
Update value of norm_
bound
.
Args:
Args:
empirical_fraction(Tensor): empirical fraction of samples with the
empirical_fraction(Tensor): empirical fraction of samples with the
value at most `target_unclipped_quantile`.
value at most `target_unclipped_quantile`.
norm_
clip
(Tensor): Clipping bound for the l2 norm of the gradients.
norm_
bound
(Tensor): Clipping bound for the l2 norm of the gradients.
Returns:
Returns:
Tensor, generated noise with shape like given gradients.
Tensor, generated noise with shape like given gradients.
"""
"""
fraction_noise
=
self
.
_normal
((
1
,),
self
.
_zero
,
self
.
_fraction_stddev
)
fraction_noise
=
normal
((
1
,),
self
.
_zero
,
self
.
_fraction_stddev
,
self
.
_seed
)
empirical_fraction
=
self
.
_add
(
empirical_fraction
,
fraction_noise
)
empirical_fraction
=
self
.
_add
(
empirical_fraction
,
fraction_noise
)
if
self
.
_decay_policy
==
'Linear'
:
if
self
.
_decay_policy
==
'Linear'
:
grad_clip
=
self
.
_sub
(
empirical_fraction
,
grad_clip
=
self
.
_sub
(
empirical_fraction
,
self
.
_target_unclipped_quantile
)
self
.
_target_unclipped_quantile
)
next_norm_
clip
=
self
.
_sub
(
norm_clip
,
next_norm_
bound
=
self
.
_sub
(
norm_bound
,
self
.
_mul
(
self
.
_learning_rate
,
grad_clip
))
self
.
_mul
(
self
.
_learning_rate
,
grad_clip
))
# decay_policy == 'Geometric'
# decay_policy == 'Geometric'
else
:
else
:
grad_clip
=
self
.
_sub
(
empirical_fraction
,
grad_clip
=
self
.
_sub
(
empirical_fraction
,
self
.
_target_unclipped_quantile
)
self
.
_target_unclipped_quantile
)
grad_clip
=
self
.
_exp
(
self
.
_mul
(
-
self
.
_learning_rate
,
grad_clip
))
grad_clip
=
self
.
_exp
(
self
.
_mul
(
-
self
.
_learning_rate
,
grad_clip
))
next_norm_
clip
=
self
.
_mul
(
norm_clip
,
grad_clip
)
next_norm_
bound
=
self
.
_mul
(
norm_bound
,
grad_clip
)
return
next_norm_
clip
return
next_norm_
bound
mindarmour/diff_privacy/optimizer/optimizer.py
浏览文件 @
69e45a3d
...
@@ -127,8 +127,8 @@ class DPOptimizerClassFactory:
...
@@ -127,8 +127,8 @@ class DPOptimizerClassFactory:
self
.
_micro_float
=
Tensor
(
micro_batches
,
mstype
.
float32
)
self
.
_micro_float
=
Tensor
(
micro_batches
,
mstype
.
float32
)
self
.
_mech_param_updater
=
None
self
.
_mech_param_updater
=
None
if
self
.
_mech
is
not
None
and
self
.
_mech
.
_
decay_policy
is
not
None
:
if
self
.
_mech
is
not
None
and
self
.
_mech
.
_
noise_update
is
not
None
:
self
.
_mech_param_updater
=
_MechanismsParamsUpdater
(
policy
=
self
.
_mech
.
_decay_policy
,
self
.
_mech_param_updater
=
_MechanismsParamsUpdater
(
noise_update
=
self
.
_mech
.
_noise_update
,
decay_rate
=
self
.
_mech
.
_noise_decay_rate
,
decay_rate
=
self
.
_mech
.
_noise_decay_rate
,
cur_noise_multiplier
=
cur_noise_multiplier
=
self
.
_mech
.
_noise_multiplier
,
self
.
_mech
.
_noise_multiplier
,
...
...
mindarmour/diff_privacy/train/model.py
浏览文件 @
69e45a3d
...
@@ -75,7 +75,7 @@ class DPModel(Model):
...
@@ -75,7 +75,7 @@ class DPModel(Model):
Args:
Args:
micro_batches (int): The number of small batches split from an original
micro_batches (int): The number of small batches split from an original
batch. Default: 2.
batch. Default: 2.
norm_
clip (float): Use to clip the bound, if set 1, will retu
n the
norm_
bound (float): Use to clip the bound, if set 1, will retur
n the
original data. Default: 1.0.
original data. Default: 1.0.
noise_mech (Mechanisms): The object can generate the different type of
noise_mech (Mechanisms): The object can generate the different type of
noise. Default: None.
noise. Default: None.
...
@@ -83,7 +83,7 @@ class DPModel(Model):
...
@@ -83,7 +83,7 @@ class DPModel(Model):
Default: None.
Default: None.
Examples:
Examples:
>>> norm_
clip
= 1.0
>>> norm_
bound
= 1.0
>>> initial_noise_multiplier = 0.01
>>> initial_noise_multiplier = 0.01
>>> network = LeNet5()
>>> network = LeNet5()
>>> batch_size = 32
>>> batch_size = 32
...
@@ -93,7 +93,7 @@ class DPModel(Model):
...
@@ -93,7 +93,7 @@ class DPModel(Model):
>>> loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
>>> loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
>>> factory_opt = DPOptimizerClassFactory(micro_batches=micro_batches)
>>> factory_opt = DPOptimizerClassFactory(micro_batches=micro_batches)
>>> factory_opt.set_mechanisms('Gaussian',
>>> factory_opt.set_mechanisms('Gaussian',
>>> norm_bound=norm_
clip
,
>>> norm_bound=norm_
bound
,
>>> initial_noise_multiplier=initial_noise_multiplier)
>>> initial_noise_multiplier=initial_noise_multiplier)
>>> net_opt = factory_opt.create('Momentum')(network.trainable_params(),
>>> net_opt = factory_opt.create('Momentum')(network.trainable_params(),
>>> learning_rate=0.1, momentum=0.9)
>>> learning_rate=0.1, momentum=0.9)
...
@@ -103,7 +103,7 @@ class DPModel(Model):
...
@@ -103,7 +103,7 @@ class DPModel(Model):
>>> target_unclipped_quantile=0.9,
>>> target_unclipped_quantile=0.9,
>>> fraction_stddev=0.01)
>>> fraction_stddev=0.01)
>>> model = DPModel(micro_batches=micro_batches,
>>> model = DPModel(micro_batches=micro_batches,
>>> norm_
clip=norm_clip
,
>>> norm_
bound=norm_bound
,
>>> clip_mech=clip_mech,
>>> clip_mech=clip_mech,
>>> noise_mech=None,
>>> noise_mech=None,
>>> network=network,
>>> network=network,
...
@@ -116,17 +116,18 @@ class DPModel(Model):
...
@@ -116,17 +116,18 @@ class DPModel(Model):
>>> model.train(epochs, ms_ds, dataset_sink_mode=False)
>>> model.train(epochs, ms_ds, dataset_sink_mode=False)
"""
"""
def
__init__
(
self
,
micro_batches
=
2
,
norm_
clip
=
1.0
,
noise_mech
=
None
,
def
__init__
(
self
,
micro_batches
=
2
,
norm_
bound
=
1.0
,
noise_mech
=
None
,
clip_mech
=
None
,
**
kwargs
):
clip_mech
=
None
,
**
kwargs
):
if
micro_batches
:
if
micro_batches
:
self
.
_micro_batches
=
check_int_positive
(
'micro_batches'
,
self
.
_micro_batches
=
check_int_positive
(
'micro_batches'
,
micro_batches
)
micro_batches
)
else
:
else
:
self
.
_micro_batches
=
None
self
.
_micro_batches
=
None
norm_clip
=
check_param_type
(
'norm_clip'
,
norm_clip
,
float
)
norm_bound
=
check_param_type
(
'norm_bound'
,
norm_bound
,
float
)
norm_clip
=
check_value_positive
(
'norm_clip'
,
norm_clip
)
norm_bound
=
check_value_positive
(
'norm_bound'
,
norm_bound
)
norm_clip
=
Tensor
(
norm_clip
,
mstype
.
float32
)
norm_bound
=
Tensor
(
norm_bound
,
mstype
.
float32
)
self
.
_norm_clip
=
Parameter
(
norm_clip
,
'norm_clip'
)
self
.
_norm_bound
=
Parameter
(
norm_bound
,
'norm_bound'
)
if
noise_mech
is
not
None
and
"DPOptimizer"
in
kwargs
[
'optimizer'
].
__class__
.
__name__
:
if
noise_mech
is
not
None
and
"DPOptimizer"
in
kwargs
[
'optimizer'
].
__class__
.
__name__
:
msg
=
'DPOptimizer is not supported while noise_mech is not None'
msg
=
'DPOptimizer is not supported while noise_mech is not None'
LOGGER
.
error
(
TAG
,
msg
)
LOGGER
.
error
(
TAG
,
msg
)
...
@@ -219,14 +220,14 @@ class DPModel(Model):
...
@@ -219,14 +220,14 @@ class DPModel(Model):
optimizer
,
optimizer
,
scale_update_cell
=
update_cell
,
scale_update_cell
=
update_cell
,
micro_batches
=
self
.
_micro_batches
,
micro_batches
=
self
.
_micro_batches
,
norm_
clip
=
self
.
_norm_clip
,
norm_
bound
=
self
.
_norm_bound
,
clip_mech
=
self
.
_clip_mech
,
clip_mech
=
self
.
_clip_mech
,
noise_mech
=
self
.
_noise_mech
).
set_train
()
noise_mech
=
self
.
_noise_mech
).
set_train
()
return
network
return
network
network
=
_TrainOneStepCell
(
network
,
network
=
_TrainOneStepCell
(
network
,
optimizer
,
optimizer
,
self
.
_norm_
clip
,
self
.
_norm_
bound
,
loss_scale
,
loss_scale
,
micro_batches
=
self
.
_micro_batches
,
micro_batches
=
self
.
_micro_batches
,
clip_mech
=
self
.
_clip_mech
,
clip_mech
=
self
.
_clip_mech
,
...
@@ -347,7 +348,7 @@ class _TrainOneStepWithLossScaleCell(Cell):
...
@@ -347,7 +348,7 @@ class _TrainOneStepWithLossScaleCell(Cell):
Default: None.
Default: None.
micro_batches (int): The number of small batches split from an original
micro_batches (int): The number of small batches split from an original
batch. Default: None.
batch. Default: None.
norm_
clip
(Tensor): Use to clip the bound, if set 1, will return the
norm_
bound
(Tensor): Use to clip the bound, if set 1, will return the
original data. Default: 1.0.
original data. Default: 1.0.
noise_mech (Mechanisms): The object can generate the different type of
noise_mech (Mechanisms): The object can generate the different type of
noise. Default: None.
noise. Default: None.
...
@@ -366,7 +367,7 @@ class _TrainOneStepWithLossScaleCell(Cell):
...
@@ -366,7 +367,7 @@ class _TrainOneStepWithLossScaleCell(Cell):
"""
"""
def
__init__
(
self
,
network
,
optimizer
,
scale_update_cell
=
None
,
def
__init__
(
self
,
network
,
optimizer
,
scale_update_cell
=
None
,
micro_batches
=
None
,
norm_
clip
=
1.0
,
noise_mech
=
None
,
micro_batches
=
None
,
norm_
bound
=
1.0
,
noise_mech
=
None
,
clip_mech
=
None
):
clip_mech
=
None
):
super
(
_TrainOneStepWithLossScaleCell
,
self
).
__init__
(
auto_prefix
=
False
)
super
(
_TrainOneStepWithLossScaleCell
,
self
).
__init__
(
auto_prefix
=
False
)
self
.
network
=
network
self
.
network
=
network
...
@@ -405,15 +406,13 @@ class _TrainOneStepWithLossScaleCell(Cell):
...
@@ -405,15 +406,13 @@ class _TrainOneStepWithLossScaleCell(Cell):
self
.
loss_scale
=
None
self
.
loss_scale
=
None
self
.
loss_scaling_manager
=
scale_update_cell
self
.
loss_scaling_manager
=
scale_update_cell
if
scale_update_cell
:
if
scale_update_cell
:
self
.
loss_scale
=
Parameter
(
self
.
loss_scale
=
Parameter
(
Tensor
(
scale_update_cell
.
get_loss_scale
(),
dtype
=
mstype
.
float32
),
Tensor
(
scale_update_cell
.
get_loss_scale
(),
name
=
"loss_scale"
)
dtype
=
mstype
.
float32
),
name
=
"loss_scale"
)
self
.
add_flags
(
has_effect
=
True
)
self
.
add_flags
(
has_effect
=
True
)
# dp params
# dp params
self
.
_micro_batches
=
micro_batches
self
.
_micro_batches
=
micro_batches
self
.
_norm_
clip
=
norm_clip
self
.
_norm_
bound
=
norm_bound
self
.
_split
=
P
.
Split
(
0
,
self
.
_micro_batches
)
self
.
_split
=
P
.
Split
(
0
,
self
.
_micro_batches
)
self
.
_clip_by_global_norm
=
_ClipGradients
()
self
.
_clip_by_global_norm
=
_ClipGradients
()
self
.
_noise_mech
=
noise_mech
self
.
_noise_mech
=
noise_mech
...
@@ -433,9 +432,9 @@ class _TrainOneStepWithLossScaleCell(Cell):
...
@@ -433,9 +432,9 @@ class _TrainOneStepWithLossScaleCell(Cell):
self
.
_cast
=
P
.
Cast
()
self
.
_cast
=
P
.
Cast
()
self
.
_noise_mech_param_updater
=
None
self
.
_noise_mech_param_updater
=
None
if
self
.
_noise_mech
is
not
None
and
self
.
_noise_mech
.
_
decay_policy
is
not
None
:
if
self
.
_noise_mech
is
not
None
and
self
.
_noise_mech
.
_
noise_update
is
not
None
:
self
.
_noise_mech_param_updater
=
_MechanismsParamsUpdater
(
self
.
_noise_mech_param_updater
=
_MechanismsParamsUpdater
(
policy
=
self
.
_noise_mech
.
_decay_policy
,
noise_update
=
self
.
_noise_mech
.
_noise_update
,
decay_rate
=
self
.
_noise_mech
.
_noise_decay_rate
,
decay_rate
=
self
.
_noise_mech
.
_noise_decay_rate
,
cur_noise_multiplier
=
cur_noise_multiplier
=
self
.
_noise_mech
.
_noise_multiplier
,
self
.
_noise_mech
.
_noise_multiplier
,
...
@@ -477,10 +476,10 @@ class _TrainOneStepWithLossScaleCell(Cell):
...
@@ -477,10 +476,10 @@ class _TrainOneStepWithLossScaleCell(Cell):
self
.
_reduce_sum
(
self
.
_square_all
(
grad
)))
self
.
_reduce_sum
(
self
.
_square_all
(
grad
)))
norm_grad
=
self
.
_sqrt
(
square_sum
)
norm_grad
=
self
.
_sqrt
(
square_sum
)
beta
=
self
.
_add
(
beta
,
beta
=
self
.
_add
(
beta
,
self
.
_cast
(
self
.
_less
(
norm_grad
,
self
.
_norm_
clip
),
self
.
_cast
(
self
.
_less
(
norm_grad
,
self
.
_norm_
bound
),
mstype
.
float32
))
mstype
.
float32
))
record_grad
=
self
.
_clip_by_global_norm
(
record_grad
,
GRADIENT_CLIP_TYPE
,
record_grad
=
self
.
_clip_by_global_norm
(
record_grad
,
GRADIENT_CLIP_TYPE
,
self
.
_norm_
clip
)
self
.
_norm_
bound
)
grads
=
record_grad
grads
=
record_grad
total_loss
=
loss
total_loss
=
loss
for
i
in
range
(
1
,
self
.
_micro_batches
):
for
i
in
range
(
1
,
self
.
_micro_batches
):
...
@@ -497,12 +496,12 @@ class _TrainOneStepWithLossScaleCell(Cell):
...
@@ -497,12 +496,12 @@ class _TrainOneStepWithLossScaleCell(Cell):
self
.
_reduce_sum
(
self
.
_square_all
(
grad
)))
self
.
_reduce_sum
(
self
.
_square_all
(
grad
)))
norm_grad
=
self
.
_sqrt
(
square_sum
)
norm_grad
=
self
.
_sqrt
(
square_sum
)
beta
=
self
.
_add
(
beta
,
beta
=
self
.
_add
(
beta
,
self
.
_cast
(
self
.
_less
(
norm_grad
,
self
.
_norm_
clip
),
self
.
_cast
(
self
.
_less
(
norm_grad
,
self
.
_norm_
bound
),
mstype
.
float32
))
mstype
.
float32
))
record_grad
=
self
.
_clip_by_global_norm
(
record_grad
,
record_grad
=
self
.
_clip_by_global_norm
(
record_grad
,
GRADIENT_CLIP_TYPE
,
GRADIENT_CLIP_TYPE
,
self
.
_norm_
clip
)
self
.
_norm_
bound
)
grads
=
self
.
_tuple_add
(
grads
,
record_grad
)
grads
=
self
.
_tuple_add
(
grads
,
record_grad
)
total_loss
=
P
.
TensorAdd
()(
total_loss
,
loss
)
total_loss
=
P
.
TensorAdd
()(
total_loss
,
loss
)
loss
=
P
.
Div
()(
total_loss
,
self
.
_micro_float
)
loss
=
P
.
Div
()(
total_loss
,
self
.
_micro_float
)
...
@@ -552,8 +551,8 @@ class _TrainOneStepWithLossScaleCell(Cell):
...
@@ -552,8 +551,8 @@ class _TrainOneStepWithLossScaleCell(Cell):
ret
=
(
loss
,
cond
,
scaling_sens
)
ret
=
(
loss
,
cond
,
scaling_sens
)
if
self
.
_clip_mech
is
not
None
:
if
self
.
_clip_mech
is
not
None
:
next_norm_
clip
=
self
.
_clip_mech
(
beta
,
self
.
_norm_clip
)
next_norm_
bound
=
self
.
_clip_mech
(
beta
,
self
.
_norm_bound
)
P
.
assign
(
self
.
_norm_
clip
,
next_norm_clip
)
P
.
assign
(
self
.
_norm_
bound
,
next_norm_bound
)
return
F
.
depend
(
ret
,
opt
)
return
F
.
depend
(
ret
,
opt
)
...
@@ -573,7 +572,7 @@ class _TrainOneStepCell(Cell):
...
@@ -573,7 +572,7 @@ class _TrainOneStepCell(Cell):
propagation. Default value is 1.0.
propagation. Default value is 1.0.
micro_batches (int): The number of small batches split from an original
micro_batches (int): The number of small batches split from an original
batch. Default: None.
batch. Default: None.
norm_
clip
(Tensor): Use to clip the bound, if set 1, will return the
norm_
bound
(Tensor): Use to clip the bound, if set 1, will return the
original data. Default: 1.0.
original data. Default: 1.0.
noise_mech (Mechanisms): The object can generate the different type
noise_mech (Mechanisms): The object can generate the different type
of noise. Default: None.
of noise. Default: None.
...
@@ -586,7 +585,7 @@ class _TrainOneStepCell(Cell):
...
@@ -586,7 +585,7 @@ class _TrainOneStepCell(Cell):
Tensor, a scalar Tensor with shape :math:`()`.
Tensor, a scalar Tensor with shape :math:`()`.
"""
"""
def
__init__
(
self
,
network
,
optimizer
,
norm_
clip
=
1.0
,
sens
=
1.0
,
def
__init__
(
self
,
network
,
optimizer
,
norm_
bound
=
1.0
,
sens
=
1.0
,
micro_batches
=
None
,
micro_batches
=
None
,
noise_mech
=
None
,
clip_mech
=
None
):
noise_mech
=
None
,
clip_mech
=
None
):
super
(
_TrainOneStepCell
,
self
).
__init__
(
auto_prefix
=
False
)
super
(
_TrainOneStepCell
,
self
).
__init__
(
auto_prefix
=
False
)
...
@@ -616,7 +615,7 @@ class _TrainOneStepCell(Cell):
...
@@ -616,7 +615,7 @@ class _TrainOneStepCell(Cell):
LOGGER
.
error
(
TAG
,
msg
)
LOGGER
.
error
(
TAG
,
msg
)
raise
ValueError
(
msg
)
raise
ValueError
(
msg
)
self
.
_micro_batches
=
micro_batches
self
.
_micro_batches
=
micro_batches
self
.
_norm_
clip
=
norm_clip
self
.
_norm_
bound
=
norm_bound
self
.
_split
=
P
.
Split
(
0
,
self
.
_micro_batches
)
self
.
_split
=
P
.
Split
(
0
,
self
.
_micro_batches
)
self
.
_clip_by_global_norm
=
_ClipGradients
()
self
.
_clip_by_global_norm
=
_ClipGradients
()
self
.
_noise_mech
=
noise_mech
self
.
_noise_mech
=
noise_mech
...
@@ -637,9 +636,9 @@ class _TrainOneStepCell(Cell):
...
@@ -637,9 +636,9 @@ class _TrainOneStepCell(Cell):
self
.
_micro_float
=
Tensor
(
micro_batches
,
mstype
.
float32
)
self
.
_micro_float
=
Tensor
(
micro_batches
,
mstype
.
float32
)
self
.
_noise_mech_param_updater
=
None
self
.
_noise_mech_param_updater
=
None
if
self
.
_noise_mech
is
not
None
and
self
.
_noise_mech
.
_
decay_policy
is
not
None
:
if
self
.
_noise_mech
is
not
None
and
self
.
_noise_mech
.
_
noise_update
is
not
None
:
self
.
_noise_mech_param_updater
=
_MechanismsParamsUpdater
(
self
.
_noise_mech_param_updater
=
_MechanismsParamsUpdater
(
policy
=
self
.
_noise_mech
.
_decay_policy
,
noise_update
=
self
.
_noise_mech
.
_noise_update
,
decay_rate
=
self
.
_noise_mech
.
_noise_decay_rate
,
decay_rate
=
self
.
_noise_mech
.
_noise_decay_rate
,
cur_noise_multiplier
=
cur_noise_multiplier
=
self
.
_noise_mech
.
_noise_multiplier
,
self
.
_noise_mech
.
_noise_multiplier
,
...
@@ -664,11 +663,11 @@ class _TrainOneStepCell(Cell):
...
@@ -664,11 +663,11 @@ class _TrainOneStepCell(Cell):
self
.
_reduce_sum
(
self
.
_square_all
(
grad
)))
self
.
_reduce_sum
(
self
.
_square_all
(
grad
)))
norm_grad
=
self
.
_sqrt
(
square_sum
)
norm_grad
=
self
.
_sqrt
(
square_sum
)
beta
=
self
.
_add
(
beta
,
beta
=
self
.
_add
(
beta
,
self
.
_cast
(
self
.
_less
(
norm_grad
,
self
.
_norm_
clip
),
self
.
_cast
(
self
.
_less
(
norm_grad
,
self
.
_norm_
bound
),
mstype
.
float32
))
mstype
.
float32
))
record_grad
=
self
.
_clip_by_global_norm
(
record_grad
,
GRADIENT_CLIP_TYPE
,
record_grad
=
self
.
_clip_by_global_norm
(
record_grad
,
GRADIENT_CLIP_TYPE
,
self
.
_norm_
clip
)
self
.
_norm_
bound
)
grads
=
record_grad
grads
=
record_grad
total_loss
=
loss
total_loss
=
loss
for
i
in
range
(
1
,
self
.
_micro_batches
):
for
i
in
range
(
1
,
self
.
_micro_batches
):
...
@@ -683,12 +682,12 @@ class _TrainOneStepCell(Cell):
...
@@ -683,12 +682,12 @@ class _TrainOneStepCell(Cell):
self
.
_reduce_sum
(
self
.
_square_all
(
grad
)))
self
.
_reduce_sum
(
self
.
_square_all
(
grad
)))
norm_grad
=
self
.
_sqrt
(
square_sum
)
norm_grad
=
self
.
_sqrt
(
square_sum
)
beta
=
self
.
_add
(
beta
,
beta
=
self
.
_add
(
beta
,
self
.
_cast
(
self
.
_less
(
norm_grad
,
self
.
_norm_
clip
),
self
.
_cast
(
self
.
_less
(
norm_grad
,
self
.
_norm_
bound
),
mstype
.
float32
))
mstype
.
float32
))
record_grad
=
self
.
_clip_by_global_norm
(
record_grad
,
record_grad
=
self
.
_clip_by_global_norm
(
record_grad
,
GRADIENT_CLIP_TYPE
,
GRADIENT_CLIP_TYPE
,
self
.
_norm_
clip
)
self
.
_norm_
bound
)
grads
=
self
.
_tuple_add
(
grads
,
record_grad
)
grads
=
self
.
_tuple_add
(
grads
,
record_grad
)
total_loss
=
P
.
TensorAdd
()(
total_loss
,
loss
)
total_loss
=
P
.
TensorAdd
()(
total_loss
,
loss
)
loss
=
self
.
_div
(
total_loss
,
self
.
_micro_float
)
loss
=
self
.
_div
(
total_loss
,
self
.
_micro_float
)
...
@@ -712,8 +711,8 @@ class _TrainOneStepCell(Cell):
...
@@ -712,8 +711,8 @@ class _TrainOneStepCell(Cell):
grads
=
self
.
grad_reducer
(
grads
)
grads
=
self
.
grad_reducer
(
grads
)
if
self
.
_clip_mech
is
not
None
:
if
self
.
_clip_mech
is
not
None
:
next_norm_
clip
=
self
.
_clip_mech
(
beta
,
self
.
_norm_clip
)
next_norm_
bound
=
self
.
_clip_mech
(
beta
,
self
.
_norm_bound
)
self
.
_norm_
clip
=
self
.
_assign
(
self
.
_norm_clip
,
next_norm_clip
)
self
.
_norm_
bound
=
self
.
_assign
(
self
.
_norm_bound
,
next_norm_bound
)
loss
=
F
.
depend
(
loss
,
next_norm_
clip
)
loss
=
F
.
depend
(
loss
,
next_norm_
bound
)
return
F
.
depend
(
loss
,
self
.
optimizer
(
grads
))
return
F
.
depend
(
loss
,
self
.
optimizer
(
grads
))
mindarmour/fuzzing/model_coverage_metrics.py
浏览文件 @
69e45a3d
...
@@ -63,14 +63,14 @@ class ModelCoverageMetrics:
...
@@ -63,14 +63,14 @@ class ModelCoverageMetrics:
self
.
_model
=
check_model
(
'model'
,
model
,
Model
)
self
.
_model
=
check_model
(
'model'
,
model
,
Model
)
self
.
_segmented_num
=
check_int_positive
(
'segmented_num'
,
segmented_num
)
self
.
_segmented_num
=
check_int_positive
(
'segmented_num'
,
segmented_num
)
self
.
_neuron_num
=
check_int_positive
(
'neuron_num'
,
neuron_num
)
self
.
_neuron_num
=
check_int_positive
(
'neuron_num'
,
neuron_num
)
if
self
.
_neuron_num
>
1e+10
:
if
self
.
_neuron_num
>
=
1e+10
:
msg
=
'neuron_num should be less than 1e+10, otherwise a MemoryError'
\
msg
=
'neuron_num should be less than 1e+10, otherwise a MemoryError'
\
'would occur'
'would occur'
LOGGER
.
error
(
TAG
,
msg
)
LOGGER
.
error
(
TAG
,
msg
)
train_dataset
=
check_numpy_param
(
'train_dataset'
,
train_dataset
)
train_dataset
=
check_numpy_param
(
'train_dataset'
,
train_dataset
)
self
.
_lower_bounds
=
[
np
.
inf
]
*
neuron_num
self
.
_lower_bounds
=
[
np
.
inf
]
*
self
.
_
neuron_num
self
.
_upper_bounds
=
[
-
np
.
inf
]
*
neuron_num
self
.
_upper_bounds
=
[
-
np
.
inf
]
*
self
.
_
neuron_num
self
.
_var
=
[
0
]
*
neuron_num
self
.
_var
=
[
0
]
*
self
.
_
neuron_num
self
.
_main_section_hits
=
[[
0
for
_
in
range
(
self
.
_segmented_num
)]
for
_
in
self
.
_main_section_hits
=
[[
0
for
_
in
range
(
self
.
_segmented_num
)]
for
_
in
range
(
self
.
_neuron_num
)]
range
(
self
.
_neuron_num
)]
self
.
_lower_corner_hits
=
[
0
]
*
self
.
_neuron_num
self
.
_lower_corner_hits
=
[
0
]
*
self
.
_neuron_num
...
...
requirements.txt
浏览文件 @
69e45a3d
numpy
>= 1.17.0
numpy
>= 1.17.0
scipy
>= 1.3.3
scipy
>= 1.3.3
matplotlib
>= 3.
1.3
matplotlib
>= 3.
2.1
Pillow
>= 2.0.0
Pillow
>= 2.0.0
pytest
>= 4.3.1
pytest
>= 4.3.1
wheel
>= 0.32.0
wheel
>= 0.32.0
...
...
setup.py
浏览文件 @
69e45a3d
...
@@ -104,7 +104,7 @@ setup(
...
@@ -104,7 +104,7 @@ setup(
install_requires
=
[
install_requires
=
[
'scipy >= 1.3.3'
,
'scipy >= 1.3.3'
,
'numpy >= 1.17.0'
,
'numpy >= 1.17.0'
,
'matplotlib >= 3.
1.3
'
,
'matplotlib >= 3.
2.1
'
,
'Pillow >= 2.0.0'
'Pillow >= 2.0.0'
],
],
classifiers
=
[
classifiers
=
[
...
...
tests/ut/python/diff_privacy/test_mechanisms.py
浏览文件 @
69e45a3d
...
@@ -19,8 +19,7 @@ import pytest
...
@@ -19,8 +19,7 @@ import pytest
from
mindspore
import
context
from
mindspore
import
context
from
mindspore
import
Tensor
from
mindspore
import
Tensor
from
mindspore.common
import
dtype
as
mstype
from
mindspore.common
import
dtype
as
mstype
from
mindarmour.diff_privacy
import
NoiseGaussianRandom
from
mindarmour.diff_privacy
import
NoiseAdaGaussianRandom
from
mindarmour.diff_privacy
import
AdaGaussianRandom
from
mindarmour.diff_privacy
import
AdaClippingWithGaussianRandom
from
mindarmour.diff_privacy
import
AdaClippingWithGaussianRandom
from
mindarmour.diff_privacy
import
NoiseMechanismsFactory
from
mindarmour.diff_privacy
import
NoiseMechanismsFactory
from
mindarmour.diff_privacy
import
ClipMechanismsFactory
from
mindarmour.diff_privacy
import
ClipMechanismsFactory
...
@@ -30,72 +29,98 @@ from mindarmour.diff_privacy import ClipMechanismsFactory
...
@@ -30,72 +29,98 @@ from mindarmour.diff_privacy import ClipMechanismsFactory
@
pytest
.
mark
.
platform_x86_ascend_training
@
pytest
.
mark
.
platform_x86_ascend_training
@
pytest
.
mark
.
env_onecard
@
pytest
.
mark
.
env_onecard
@
pytest
.
mark
.
component_mindarmour
@
pytest
.
mark
.
component_mindarmour
def
test_graph_
gaussian
():
def
test_graph_
factory
():
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
)
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
)
grad
=
Tensor
([
0.3
,
0.2
,
0.4
],
mstype
.
float32
)
grad
=
Tensor
([
0.3
,
0.2
,
0.4
],
mstype
.
float32
)
norm_bound
=
1.0
norm_bound
=
1.0
initial_noise_multiplier
=
0.1
initial_noise_multiplier
=
0.1
net
=
NoiseGaussianRandom
(
norm_bound
,
initial_noise_multiplier
)
alpha
=
0.5
res
=
net
(
grad
)
noise_update
=
'Step'
print
(
res
)
factory
=
NoiseMechanismsFactory
()
noise_mech
=
factory
.
create
(
'Gaussian'
,
norm_bound
,
initial_noise_multiplier
)
noise
=
noise_mech
(
grad
)
print
(
'Gaussian noise: '
,
noise
)
ada_noise_mech
=
factory
.
create
(
'AdaGaussian'
,
norm_bound
,
initial_noise_multiplier
,
noise_decay_rate
=
alpha
,
noise_update
=
noise_update
)
ada_noise
=
ada_noise_mech
(
grad
)
print
(
'ada noise: '
,
ada_noise
)
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_x86_ascend_training
@
pytest
.
mark
.
platform_x86_ascend_training
@
pytest
.
mark
.
env_onecard
@
pytest
.
mark
.
env_onecard
@
pytest
.
mark
.
component_mindarmour
@
pytest
.
mark
.
component_mindarmour
def
test_pynative_
gaussian
():
def
test_pynative_
factory
():
context
.
set_context
(
mode
=
context
.
PYNATIVE_MODE
,
device_target
=
"Ascend"
)
context
.
set_context
(
mode
=
context
.
PYNATIVE_MODE
,
device_target
=
"Ascend"
)
grad
=
Tensor
([
0.3
,
0.2
,
0.4
],
mstype
.
float32
)
grad
=
Tensor
([
0.3
,
0.2
,
0.4
],
mstype
.
float32
)
norm_bound
=
1.0
norm_bound
=
1.0
initial_noise_multiplier
=
0.1
initial_noise_multiplier
=
0.1
net
=
NoiseGaussianRandom
(
norm_bound
,
initial_noise_multiplier
)
alpha
=
0.5
res
=
net
(
grad
)
noise_update
=
'Step'
print
(
res
)
factory
=
NoiseMechanismsFactory
()
noise_mech
=
factory
.
create
(
'Gaussian'
,
norm_bound
,
initial_noise_multiplier
)
noise
=
noise_mech
(
grad
)
print
(
'Gaussian noise: '
,
noise
)
ada_noise_mech
=
factory
.
create
(
'AdaGaussian'
,
norm_bound
,
initial_noise_multiplier
,
noise_decay_rate
=
alpha
,
noise_update
=
noise_update
)
ada_noise
=
ada_noise_mech
(
grad
)
print
(
'ada noise: '
,
ada_noise
)
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_x86_ascend_training
@
pytest
.
mark
.
platform_x86_ascend_training
@
pytest
.
mark
.
env_onecard
@
pytest
.
mark
.
env_onecard
@
pytest
.
mark
.
component_mindarmour
@
pytest
.
mark
.
component_mindarmour
def
test_
graph_ada
_gaussian
():
def
test_
pynative
_gaussian
():
context
.
set_context
(
mode
=
context
.
GRAPH
_MODE
,
device_target
=
"Ascend"
)
context
.
set_context
(
mode
=
context
.
PYNATIVE
_MODE
,
device_target
=
"Ascend"
)
grad
=
Tensor
([
0.3
,
0.2
,
0.4
],
mstype
.
float32
)
grad
=
Tensor
([
0.3
,
0.2
,
0.4
],
mstype
.
float32
)
norm_bound
=
1.0
norm_bound
=
1.0
initial_noise_multiplier
=
0.1
initial_noise_multiplier
=
0.1
alpha
=
0.5
alpha
=
0.5
decay_policy
=
'Step'
noise_update
=
'Step'
net
=
AdaGaussianRandom
(
norm_bound
,
initial_noise_multiplier
,
factory
=
NoiseMechanismsFactory
()
noise_decay_rate
=
alpha
,
decay_policy
=
decay_policy
)
noise_mech
=
factory
.
create
(
'Gaussian'
,
res
=
net
(
grad
)
norm_bound
,
print
(
res
)
initial_noise_multiplier
)
noise
=
noise_mech
(
grad
)
print
(
'Gaussian noise: '
,
noise
)
ada_noise_mech
=
factory
.
create
(
'AdaGaussian'
,
norm_bound
,
initial_noise_multiplier
,
noise_decay_rate
=
alpha
,
noise_update
=
noise_update
)
ada_noise
=
ada_noise_mech
(
grad
)
print
(
'ada noise: '
,
ada_noise
)
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_x86_ascend_training
@
pytest
.
mark
.
platform_x86_ascend_training
@
pytest
.
mark
.
env_onecard
@
pytest
.
mark
.
env_onecard
@
pytest
.
mark
.
component_mindarmour
@
pytest
.
mark
.
component_mindarmour
def
test_graph_
factory
():
def
test_graph_
ada_gaussian
():
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
)
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
)
grad
=
Tensor
([
0.3
,
0.2
,
0.4
],
mstype
.
float32
)
grad
=
Tensor
([
0.3
,
0.2
,
0.4
],
mstype
.
float32
)
norm_bound
=
1.0
norm_bound
=
1.0
initial_noise_multiplier
=
0.1
initial_noise_multiplier
=
0.1
alpha
=
0.5
noise_decay_rate
=
0.5
decay_policy
=
'Step'
noise_update
=
'Step'
noise_mechanism
=
NoiseMechanismsFactory
()
ada_noise_mech
=
NoiseAdaGaussianRandom
(
norm_bound
,
noise_construct
=
noise_mechanism
.
create
(
'Gaussian'
,
initial_noise_multiplier
,
norm_bound
,
seed
=
0
,
initial_noise_multiplier
)
noise_decay_rate
=
noise_decay_rate
,
noise
=
noise_construct
(
grad
)
noise_update
=
noise_update
)
print
(
'Gaussian noise: '
,
noise
)
res
=
ada_noise_mech
(
grad
)
ada_mechanism
=
NoiseMechanismsFactory
()
print
(
res
)
ada_noise_construct
=
ada_mechanism
.
create
(
'AdaGaussian'
,
norm_bound
,
initial_noise_multiplier
,
noise_decay_rate
=
alpha
,
decay_policy
=
decay_policy
)
ada_noise
=
ada_noise_construct
(
grad
)
print
(
'ada noise: '
,
ada_noise
)
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
level0
...
@@ -107,11 +132,14 @@ def test_pynative_ada_gaussian():
...
@@ -107,11 +132,14 @@ def test_pynative_ada_gaussian():
grad
=
Tensor
([
0.3
,
0.2
,
0.4
],
mstype
.
float32
)
grad
=
Tensor
([
0.3
,
0.2
,
0.4
],
mstype
.
float32
)
norm_bound
=
1.0
norm_bound
=
1.0
initial_noise_multiplier
=
0.1
initial_noise_multiplier
=
0.1
alpha
=
0.5
noise_decay_rate
=
0.5
decay_policy
=
'Step'
noise_update
=
'Step'
net
=
AdaGaussianRandom
(
norm_bound
,
initial_noise_multiplier
,
ada_noise_mech
=
NoiseAdaGaussianRandom
(
norm_bound
,
noise_decay_rate
=
alpha
,
decay_policy
=
decay_policy
)
initial_noise_multiplier
,
res
=
net
(
grad
)
seed
=
0
,
noise_decay_rate
=
noise_decay_rate
,
noise_update
=
noise_update
)
res
=
ada_noise_mech
(
grad
)
print
(
res
)
print
(
res
)
...
@@ -119,26 +147,20 @@ def test_pynative_ada_gaussian():
...
@@ -119,26 +147,20 @@ def test_pynative_ada_gaussian():
@
pytest
.
mark
.
platform_x86_ascend_training
@
pytest
.
mark
.
platform_x86_ascend_training
@
pytest
.
mark
.
env_onecard
@
pytest
.
mark
.
env_onecard
@
pytest
.
mark
.
component_mindarmour
@
pytest
.
mark
.
component_mindarmour
def
test_
pynative_factory
():
def
test_
graph_exponential
():
context
.
set_context
(
mode
=
context
.
PYNATIVE
_MODE
,
device_target
=
"Ascend"
)
context
.
set_context
(
mode
=
context
.
GRAPH
_MODE
,
device_target
=
"Ascend"
)
grad
=
Tensor
([
0.3
,
0.2
,
0.4
],
mstype
.
float32
)
grad
=
Tensor
([
0.3
,
0.2
,
0.4
],
mstype
.
float32
)
norm_bound
=
1.0
norm_bound
=
1.0
initial_noise_multiplier
=
0.1
initial_noise_multiplier
=
0.1
alpha
=
0.5
alpha
=
0.5
decay_policy
=
'Step'
noise_update
=
'Exp'
noise_mechanism
=
NoiseMechanismsFactory
()
factory
=
NoiseMechanismsFactory
()
noise_construct
=
noise_mechanism
.
create
(
'Gaussian'
,
ada_noise
=
factory
.
create
(
'AdaGaussian'
,
norm_bound
,
norm_bound
,
initial_noise_multiplier
)
initial_noise_multiplier
,
noise
=
noise_construct
(
grad
)
noise_decay_rate
=
alpha
,
print
(
'Gaussian noise: '
,
noise
)
noise_update
=
noise_update
)
ada_mechanism
=
NoiseMechanismsFactory
()
ada_noise
=
ada_noise
(
grad
)
ada_noise_construct
=
ada_mechanism
.
create
(
'AdaGaussian'
,
norm_bound
,
initial_noise_multiplier
,
noise_decay_rate
=
alpha
,
decay_policy
=
decay_policy
)
ada_noise
=
ada_noise_construct
(
grad
)
print
(
'ada noise: '
,
ada_noise
)
print
(
'ada noise: '
,
ada_noise
)
...
@@ -152,35 +174,14 @@ def test_pynative_exponential():
...
@@ -152,35 +174,14 @@ def test_pynative_exponential():
norm_bound
=
1.0
norm_bound
=
1.0
initial_noise_multiplier
=
0.1
initial_noise_multiplier
=
0.1
alpha
=
0.5
alpha
=
0.5
decay_policy
=
'Exp'
noise_update
=
'Exp'
ada_mechanism
=
NoiseMechanismsFactory
()
factory
=
NoiseMechanismsFactory
()
ada_noise_construct
=
ada_mechanism
.
create
(
'AdaGaussian'
,
ada_noise
=
factory
.
create
(
'AdaGaussian'
,
norm_bound
,
norm_bound
,
initial_noise_multiplier
,
initial_noise_multiplier
,
noise_decay_rate
=
alpha
,
noise_decay_rate
=
alpha
,
decay_policy
=
decay_policy
)
noise_update
=
noise_update
)
ada_noise
=
ada_noise_construct
(
grad
)
ada_noise
=
ada_noise
(
grad
)
print
(
'ada noise: '
,
ada_noise
)
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_x86_ascend_training
@
pytest
.
mark
.
env_onecard
@
pytest
.
mark
.
component_mindarmour
def
test_graph_exponential
():
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
)
grad
=
Tensor
([
0.3
,
0.2
,
0.4
],
mstype
.
float32
)
norm_bound
=
1.0
initial_noise_multiplier
=
0.1
alpha
=
0.5
decay_policy
=
'Exp'
ada_mechanism
=
NoiseMechanismsFactory
()
ada_noise_construct
=
ada_mechanism
.
create
(
'AdaGaussian'
,
norm_bound
,
initial_noise_multiplier
,
noise_decay_rate
=
alpha
,
decay_policy
=
decay_policy
)
ada_noise
=
ada_noise_construct
(
grad
)
print
(
'ada noise: '
,
ada_noise
)
print
(
'ada noise: '
,
ada_noise
)
...
@@ -192,7 +193,7 @@ def test_ada_clip_gaussian_random_pynative():
...
@@ -192,7 +193,7 @@ def test_ada_clip_gaussian_random_pynative():
context
.
set_context
(
mode
=
context
.
PYNATIVE_MODE
,
device_target
=
"Ascend"
)
context
.
set_context
(
mode
=
context
.
PYNATIVE_MODE
,
device_target
=
"Ascend"
)
decay_policy
=
'Linear'
decay_policy
=
'Linear'
beta
=
Tensor
(
0.5
,
mstype
.
float32
)
beta
=
Tensor
(
0.5
,
mstype
.
float32
)
norm_
clip
=
Tensor
(
1.0
,
mstype
.
float32
)
norm_
bound
=
Tensor
(
1.0
,
mstype
.
float32
)
beta_stddev
=
0.1
beta_stddev
=
0.1
learning_rate
=
0.1
learning_rate
=
0.1
target_unclipped_quantile
=
0.3
target_unclipped_quantile
=
0.3
...
@@ -201,8 +202,8 @@ def test_ada_clip_gaussian_random_pynative():
...
@@ -201,8 +202,8 @@ def test_ada_clip_gaussian_random_pynative():
target_unclipped_quantile
=
target_unclipped_quantile
,
target_unclipped_quantile
=
target_unclipped_quantile
,
fraction_stddev
=
beta_stddev
,
fraction_stddev
=
beta_stddev
,
seed
=
1
)
seed
=
1
)
next_norm_
clip
=
ada_clip
(
beta
,
norm_clip
)
next_norm_
bound
=
ada_clip
(
beta
,
norm_bound
)
print
(
'Liner next norm clip:'
,
next_norm_
clip
)
print
(
'Liner next norm clip:'
,
next_norm_
bound
)
decay_policy
=
'Geometric'
decay_policy
=
'Geometric'
ada_clip
=
AdaClippingWithGaussianRandom
(
decay_policy
=
decay_policy
,
ada_clip
=
AdaClippingWithGaussianRandom
(
decay_policy
=
decay_policy
,
...
@@ -210,8 +211,8 @@ def test_ada_clip_gaussian_random_pynative():
...
@@ -210,8 +211,8 @@ def test_ada_clip_gaussian_random_pynative():
target_unclipped_quantile
=
target_unclipped_quantile
,
target_unclipped_quantile
=
target_unclipped_quantile
,
fraction_stddev
=
beta_stddev
,
fraction_stddev
=
beta_stddev
,
seed
=
1
)
seed
=
1
)
next_norm_
clip
=
ada_clip
(
beta
,
norm_clip
)
next_norm_
bound
=
ada_clip
(
beta
,
norm_bound
)
print
(
'Geometric next norm clip:'
,
next_norm_
clip
)
print
(
'Geometric next norm clip:'
,
next_norm_
bound
)
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
level0
...
@@ -222,7 +223,7 @@ def test_ada_clip_gaussian_random_graph():
...
@@ -222,7 +223,7 @@ def test_ada_clip_gaussian_random_graph():
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
)
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
)
decay_policy
=
'Linear'
decay_policy
=
'Linear'
beta
=
Tensor
(
0.5
,
mstype
.
float32
)
beta
=
Tensor
(
0.5
,
mstype
.
float32
)
norm_
clip
=
Tensor
(
1.0
,
mstype
.
float32
)
norm_
bound
=
Tensor
(
1.0
,
mstype
.
float32
)
beta_stddev
=
0.1
beta_stddev
=
0.1
learning_rate
=
0.1
learning_rate
=
0.1
target_unclipped_quantile
=
0.3
target_unclipped_quantile
=
0.3
...
@@ -231,8 +232,8 @@ def test_ada_clip_gaussian_random_graph():
...
@@ -231,8 +232,8 @@ def test_ada_clip_gaussian_random_graph():
target_unclipped_quantile
=
target_unclipped_quantile
,
target_unclipped_quantile
=
target_unclipped_quantile
,
fraction_stddev
=
beta_stddev
,
fraction_stddev
=
beta_stddev
,
seed
=
1
)
seed
=
1
)
next_norm_
clip
=
ada_clip
(
beta
,
norm_clip
)
next_norm_
bound
=
ada_clip
(
beta
,
norm_bound
)
print
(
'Liner next norm clip:'
,
next_norm_
clip
)
print
(
'Liner next norm clip:'
,
next_norm_
bound
)
decay_policy
=
'Geometric'
decay_policy
=
'Geometric'
ada_clip
=
AdaClippingWithGaussianRandom
(
decay_policy
=
decay_policy
,
ada_clip
=
AdaClippingWithGaussianRandom
(
decay_policy
=
decay_policy
,
...
@@ -240,8 +241,8 @@ def test_ada_clip_gaussian_random_graph():
...
@@ -240,8 +241,8 @@ def test_ada_clip_gaussian_random_graph():
target_unclipped_quantile
=
target_unclipped_quantile
,
target_unclipped_quantile
=
target_unclipped_quantile
,
fraction_stddev
=
beta_stddev
,
fraction_stddev
=
beta_stddev
,
seed
=
1
)
seed
=
1
)
next_norm_
clip
=
ada_clip
(
beta
,
norm_clip
)
next_norm_
bound
=
ada_clip
(
beta
,
norm_bound
)
print
(
'Geometric next norm clip:'
,
next_norm_
clip
)
print
(
'Geometric next norm clip:'
,
next_norm_
bound
)
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
level0
...
@@ -252,18 +253,18 @@ def test_pynative_clip_mech_factory():
...
@@ -252,18 +253,18 @@ def test_pynative_clip_mech_factory():
context
.
set_context
(
mode
=
context
.
PYNATIVE_MODE
,
device_target
=
"Ascend"
)
context
.
set_context
(
mode
=
context
.
PYNATIVE_MODE
,
device_target
=
"Ascend"
)
decay_policy
=
'Linear'
decay_policy
=
'Linear'
beta
=
Tensor
(
0.5
,
mstype
.
float32
)
beta
=
Tensor
(
0.5
,
mstype
.
float32
)
norm_
clip
=
Tensor
(
1.0
,
mstype
.
float32
)
norm_
bound
=
Tensor
(
1.0
,
mstype
.
float32
)
beta_stddev
=
0.1
beta_stddev
=
0.1
learning_rate
=
0.1
learning_rate
=
0.1
target_unclipped_quantile
=
0.3
target_unclipped_quantile
=
0.3
clip_mechanism
=
ClipMechanismsFactory
()
factory
=
ClipMechanismsFactory
()
ada_clip
=
clip_mechanism
.
create
(
'Gaussian'
,
ada_clip
=
factory
.
create
(
'Gaussian'
,
decay_policy
=
decay_policy
,
decay_policy
=
decay_policy
,
learning_rate
=
learning_rate
,
learning_rate
=
learning_rate
,
target_unclipped_quantile
=
target_unclipped_quantile
,
target_unclipped_quantile
=
target_unclipped_quantile
,
fraction_stddev
=
beta_stddev
)
fraction_stddev
=
beta_stddev
)
next_norm_
clip
=
ada_clip
(
beta
,
norm_clip
)
next_norm_
bound
=
ada_clip
(
beta
,
norm_bound
)
print
(
'next_norm_
clip: '
,
next_norm_clip
)
print
(
'next_norm_
bound: '
,
next_norm_bound
)
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
level0
...
@@ -274,15 +275,15 @@ def test_graph_clip_mech_factory():
...
@@ -274,15 +275,15 @@ def test_graph_clip_mech_factory():
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
)
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
)
decay_policy
=
'Linear'
decay_policy
=
'Linear'
beta
=
Tensor
(
0.5
,
mstype
.
float32
)
beta
=
Tensor
(
0.5
,
mstype
.
float32
)
norm_
clip
=
Tensor
(
1.0
,
mstype
.
float32
)
norm_
bound
=
Tensor
(
1.0
,
mstype
.
float32
)
beta_stddev
=
0.1
beta_stddev
=
0.1
learning_rate
=
0.1
learning_rate
=
0.1
target_unclipped_quantile
=
0.3
target_unclipped_quantile
=
0.3
clip_mechanism
=
ClipMechanismsFactory
()
factory
=
ClipMechanismsFactory
()
ada_clip
=
clip_mechanism
.
create
(
'Gaussian'
,
ada_clip
=
factory
.
create
(
'Gaussian'
,
decay_policy
=
decay_policy
,
decay_policy
=
decay_policy
,
learning_rate
=
learning_rate
,
learning_rate
=
learning_rate
,
target_unclipped_quantile
=
target_unclipped_quantile
,
target_unclipped_quantile
=
target_unclipped_quantile
,
fraction_stddev
=
beta_stddev
)
fraction_stddev
=
beta_stddev
)
next_norm_
clip
=
ada_clip
(
beta
,
norm_clip
)
next_norm_
bound
=
ada_clip
(
beta
,
norm_bound
)
print
(
'next_norm_
clip: '
,
next_norm_clip
)
print
(
'next_norm_
bound: '
,
next_norm_bound
)
tests/ut/python/diff_privacy/test_model_train.py
浏览文件 @
69e45a3d
...
@@ -46,7 +46,7 @@ def dataset_generator(batch_size, batches):
...
@@ -46,7 +46,7 @@ def dataset_generator(batch_size, batches):
@
pytest
.
mark
.
component_mindarmour
@
pytest
.
mark
.
component_mindarmour
def
test_dp_model_with_pynative_mode
():
def
test_dp_model_with_pynative_mode
():
context
.
set_context
(
mode
=
context
.
PYNATIVE_MODE
,
device_target
=
"Ascend"
)
context
.
set_context
(
mode
=
context
.
PYNATIVE_MODE
,
device_target
=
"Ascend"
)
norm_
clip
=
1.0
norm_
bound
=
1.0
initial_noise_multiplier
=
0.01
initial_noise_multiplier
=
0.01
network
=
LeNet5
()
network
=
LeNet5
()
batch_size
=
32
batch_size
=
32
...
@@ -56,7 +56,7 @@ def test_dp_model_with_pynative_mode():
...
@@ -56,7 +56,7 @@ def test_dp_model_with_pynative_mode():
loss
=
nn
.
SoftmaxCrossEntropyWithLogits
(
is_grad
=
False
,
sparse
=
True
)
loss
=
nn
.
SoftmaxCrossEntropyWithLogits
(
is_grad
=
False
,
sparse
=
True
)
factory_opt
=
DPOptimizerClassFactory
(
micro_batches
=
micro_batches
)
factory_opt
=
DPOptimizerClassFactory
(
micro_batches
=
micro_batches
)
factory_opt
.
set_mechanisms
(
'Gaussian'
,
factory_opt
.
set_mechanisms
(
'Gaussian'
,
norm_bound
=
norm_
clip
,
norm_bound
=
norm_
bound
,
initial_noise_multiplier
=
initial_noise_multiplier
)
initial_noise_multiplier
=
initial_noise_multiplier
)
net_opt
=
factory_opt
.
create
(
'Momentum'
)(
network
.
trainable_params
(),
net_opt
=
factory_opt
.
create
(
'Momentum'
)(
network
.
trainable_params
(),
learning_rate
=
0.1
,
momentum
=
0.9
)
learning_rate
=
0.1
,
momentum
=
0.9
)
...
@@ -66,7 +66,7 @@ def test_dp_model_with_pynative_mode():
...
@@ -66,7 +66,7 @@ def test_dp_model_with_pynative_mode():
target_unclipped_quantile
=
0.9
,
target_unclipped_quantile
=
0.9
,
fraction_stddev
=
0.01
)
fraction_stddev
=
0.01
)
model
=
DPModel
(
micro_batches
=
micro_batches
,
model
=
DPModel
(
micro_batches
=
micro_batches
,
norm_
clip
=
norm_clip
,
norm_
bound
=
norm_bound
,
clip_mech
=
clip_mech
,
clip_mech
=
clip_mech
,
noise_mech
=
None
,
noise_mech
=
None
,
network
=
network
,
network
=
network
,
...
@@ -86,7 +86,7 @@ def test_dp_model_with_pynative_mode():
...
@@ -86,7 +86,7 @@ def test_dp_model_with_pynative_mode():
@
pytest
.
mark
.
component_mindarmour
@
pytest
.
mark
.
component_mindarmour
def
test_dp_model_with_graph_mode
():
def
test_dp_model_with_graph_mode
():
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
)
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
)
norm_
clip
=
1.0
norm_
bound
=
1.0
initial_noise_multiplier
=
0.01
initial_noise_multiplier
=
0.01
network
=
LeNet5
()
network
=
LeNet5
()
batch_size
=
32
batch_size
=
32
...
@@ -94,7 +94,7 @@ def test_dp_model_with_graph_mode():
...
@@ -94,7 +94,7 @@ def test_dp_model_with_graph_mode():
epochs
=
1
epochs
=
1
loss
=
nn
.
SoftmaxCrossEntropyWithLogits
(
is_grad
=
False
,
sparse
=
True
)
loss
=
nn
.
SoftmaxCrossEntropyWithLogits
(
is_grad
=
False
,
sparse
=
True
)
noise_mech
=
NoiseMechanismsFactory
().
create
(
'Gaussian'
,
noise_mech
=
NoiseMechanismsFactory
().
create
(
'Gaussian'
,
norm_bound
=
norm_
clip
,
norm_bound
=
norm_
bound
,
initial_noise_multiplier
=
initial_noise_multiplier
)
initial_noise_multiplier
=
initial_noise_multiplier
)
clip_mech
=
ClipMechanismsFactory
().
create
(
'Gaussian'
,
clip_mech
=
ClipMechanismsFactory
().
create
(
'Gaussian'
,
decay_policy
=
'Linear'
,
decay_policy
=
'Linear'
,
...
@@ -105,7 +105,7 @@ def test_dp_model_with_graph_mode():
...
@@ -105,7 +105,7 @@ def test_dp_model_with_graph_mode():
momentum
=
0.9
)
momentum
=
0.9
)
model
=
DPModel
(
micro_batches
=
2
,
model
=
DPModel
(
micro_batches
=
2
,
clip_mech
=
clip_mech
,
clip_mech
=
clip_mech
,
norm_
clip
=
norm_clip
,
norm_
bound
=
norm_bound
,
noise_mech
=
noise_mech
,
noise_mech
=
noise_mech
,
network
=
network
,
network
=
network
,
loss_fn
=
loss
,
loss_fn
=
loss
,
...
@@ -124,22 +124,25 @@ def test_dp_model_with_graph_mode():
...
@@ -124,22 +124,25 @@ def test_dp_model_with_graph_mode():
@
pytest
.
mark
.
component_mindarmour
@
pytest
.
mark
.
component_mindarmour
def
test_dp_model_with_graph_mode_ada_gaussian
():
def
test_dp_model_with_graph_mode_ada_gaussian
():
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
)
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
)
norm_
clip
=
1.0
norm_
bound
=
1.0
initial_noise_multiplier
=
0.01
initial_noise_multiplier
=
0.01
network
=
LeNet5
()
network
=
LeNet5
()
batch_size
=
32
batch_size
=
32
batches
=
128
batches
=
128
epochs
=
1
epochs
=
1
alpha
=
0.8
loss
=
nn
.
SoftmaxCrossEntropyWithLogits
(
is_grad
=
False
,
sparse
=
True
)
loss
=
nn
.
SoftmaxCrossEntropyWithLogits
(
is_grad
=
False
,
sparse
=
True
)
noise_mech
=
NoiseMechanismsFactory
().
create
(
'AdaGaussian'
,
noise_mech
=
NoiseMechanismsFactory
().
create
(
'AdaGaussian'
,
norm_bound
=
norm_clip
,
norm_bound
=
norm_bound
,
initial_noise_multiplier
=
initial_noise_multiplier
)
initial_noise_multiplier
=
initial_noise_multiplier
,
noise_decay_rate
=
alpha
,
noise_update
=
'Exp'
)
clip_mech
=
None
clip_mech
=
None
net_opt
=
nn
.
Momentum
(
network
.
trainable_params
(),
learning_rate
=
0.1
,
net_opt
=
nn
.
Momentum
(
network
.
trainable_params
(),
learning_rate
=
0.1
,
momentum
=
0.9
)
momentum
=
0.9
)
model
=
DPModel
(
micro_batches
=
2
,
model
=
DPModel
(
micro_batches
=
2
,
clip_mech
=
clip_mech
,
clip_mech
=
clip_mech
,
norm_
clip
=
norm_clip
,
norm_
bound
=
norm_bound
,
noise_mech
=
noise_mech
,
noise_mech
=
noise_mech
,
network
=
network
,
network
=
network
,
loss_fn
=
loss
,
loss_fn
=
loss
,
...
...
tests/ut/python/diff_privacy/test_optimizer.py
浏览文件 @
69e45a3d
...
@@ -34,10 +34,10 @@ def test_optimizer():
...
@@ -34,10 +34,10 @@ def test_optimizer():
momentum
=
0.9
momentum
=
0.9
micro_batches
=
2
micro_batches
=
2
loss
=
nn
.
SoftmaxCrossEntropyWithLogits
()
loss
=
nn
.
SoftmaxCrossEntropyWithLogits
()
gaussian_mech
=
DPOptimizerClassFactory
(
micro_batches
)
factory
=
DPOptimizerClassFactory
(
micro_batches
)
gaussian_mech
.
set_mechanisms
(
'Gaussian'
,
norm_bound
=
1.5
,
initial_noise_multiplier
=
5.0
)
factory
.
set_mechanisms
(
'Gaussian'
,
norm_bound
=
1.5
,
initial_noise_multiplier
=
5.0
)
net_opt
=
gaussian_mech
.
create
(
'SGD'
)(
params
=
network
.
trainable_params
(),
learning_rate
=
lr
,
net_opt
=
factory
.
create
(
'SGD'
)(
params
=
network
.
trainable_params
(),
learning_rate
=
lr
,
momentum
=
momentum
)
momentum
=
momentum
)
_
=
Model
(
network
,
loss_fn
=
loss
,
optimizer
=
net_opt
,
metrics
=
None
)
_
=
Model
(
network
,
loss_fn
=
loss
,
optimizer
=
net_opt
,
metrics
=
None
)
...
@@ -52,10 +52,10 @@ def test_optimizer_gpu():
...
@@ -52,10 +52,10 @@ def test_optimizer_gpu():
momentum
=
0.9
momentum
=
0.9
micro_batches
=
2
micro_batches
=
2
loss
=
nn
.
SoftmaxCrossEntropyWithLogits
()
loss
=
nn
.
SoftmaxCrossEntropyWithLogits
()
gaussian_mech
=
DPOptimizerClassFactory
(
micro_batches
)
factory
=
DPOptimizerClassFactory
(
micro_batches
)
gaussian_mech
.
set_mechanisms
(
'Gaussian'
,
norm_bound
=
1.5
,
initial_noise_multiplier
=
5.0
)
factory
.
set_mechanisms
(
'Gaussian'
,
norm_bound
=
1.5
,
initial_noise_multiplier
=
5.0
)
net_opt
=
gaussian_mech
.
create
(
'SGD'
)(
params
=
network
.
trainable_params
(),
learning_rate
=
lr
,
net_opt
=
factory
.
create
(
'SGD'
)(
params
=
network
.
trainable_params
(),
learning_rate
=
lr
,
momentum
=
momentum
)
momentum
=
momentum
)
_
=
Model
(
network
,
loss_fn
=
loss
,
optimizer
=
net_opt
,
metrics
=
None
)
_
=
Model
(
network
,
loss_fn
=
loss
,
optimizer
=
net_opt
,
metrics
=
None
)
...
@@ -70,8 +70,8 @@ def test_optimizer_cpu():
...
@@ -70,8 +70,8 @@ def test_optimizer_cpu():
momentum
=
0.9
momentum
=
0.9
micro_batches
=
2
micro_batches
=
2
loss
=
nn
.
SoftmaxCrossEntropyWithLogits
()
loss
=
nn
.
SoftmaxCrossEntropyWithLogits
()
gaussian_mech
=
DPOptimizerClassFactory
(
micro_batches
)
factory
=
DPOptimizerClassFactory
(
micro_batches
)
gaussian_mech
.
set_mechanisms
(
'Gaussian'
,
norm_bound
=
1.5
,
initial_noise_multiplier
=
5.0
)
factory
.
set_mechanisms
(
'Gaussian'
,
norm_bound
=
1.5
,
initial_noise_multiplier
=
5.0
)
net_opt
=
gaussian_mech
.
create
(
'SGD'
)(
params
=
network
.
trainable_params
(),
learning_rate
=
lr
,
net_opt
=
factory
.
create
(
'SGD'
)(
params
=
network
.
trainable_params
(),
learning_rate
=
lr
,
momentum
=
momentum
)
momentum
=
momentum
)
_
=
Model
(
network
,
loss_fn
=
loss
,
optimizer
=
net_opt
,
metrics
=
None
)
_
=
Model
(
network
,
loss_fn
=
loss
,
optimizer
=
net_opt
,
metrics
=
None
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录