MindSpore / mindarmour
Commit 41752a9c, authored May 26, 2020 by mindspore-ci-bot and committed via Gitee on May 26, 2020.
!24 Add train module and optimizer module for differential privacy.
Merge pull request !24 from zheng-huanhuan/dp_pynative
Parents: 45cede10, b2e0934b
Showing 13 changed files with 964 additions and 7 deletions.
example/mnist_demo/lenet5_config.py                 +32   -0
example/mnist_demo/lenet5_dp_model_train.py         +151  -0
mindarmour/diff_privacy/__init__.py                 +7    -1
mindarmour/diff_privacy/mechanisms/mechanisms.py    +0    -4
mindarmour/diff_privacy/monitor/monitor.py          +1    -1
mindarmour/diff_privacy/optimizer/__init__.py       +0    -0
mindarmour/diff_privacy/optimizer/optimizer.py      +116  -0
mindarmour/diff_privacy/train/__init__.py           +0    -0
mindarmour/diff_privacy/train/model.py              +515  -0
tests/ut/python/diff_privacy/test_mechanisms.py     +0    -0
tests/ut/python/diff_privacy/test_model_train.py    +65   -0
tests/ut/python/diff_privacy/test_monitor.py        +1    -1
tests/ut/python/diff_privacy/test_optimizer.py      +76   -0
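For orientation before the per-file diffs, the sketch below shows how the pieces added in this merge are meant to fit together. It is condensed from example/mnist_demo/lenet5_dp_model_train.py further down, not an additional API; here `network` and `ds_train` are placeholders for any MindSpore Cell and training dataset (the example uses LeNet5 on MNIST).

from mindspore.nn import SoftmaxCrossEntropyWithLogits
from mindarmour.diff_privacy import DPModel, DPOptimizerClassFactory, PrivacyMonitorFactory

# network: any mindspore.nn.Cell; ds_train: any MindSpore training dataset (placeholders).
# Wrap a standard optimizer so that Gaussian noise is added to gradients per micro-batch.
opt_factory = DPOptimizerClassFactory(micro_batches=2)
opt_factory.set_mechanisms('Gaussian', norm_bound=1.0, initial_noise_multiplier=0.001)
net_opt = opt_factory.create('Momentum')(params=network.trainable_params(),
                                         learning_rate=0.01, momentum=0.9)

# Track the spent privacy budget during training via an RDP monitor callback.
rdp_monitor = PrivacyMonitorFactory.create('rdp', num_samples=60000, batch_size=16,
                                           initial_noise_multiplier=0.001)

# DPModel splits each batch into micro-batches and clips/perturbs gradients while training.
model = DPModel(micro_batches=2, norm_clip=1.0, dp_mech=opt_factory.mech,
                network=network,
                loss_fn=SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True),
                optimizer=net_opt)
model.train(10, ds_train, callbacks=[rdp_monitor])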
example/mnist_demo/lenet5_config.py (new file, mode 100644)
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting, will be used in train.py
"""
from
easydict
import
EasyDict
as
edict
mnist_cfg
=
edict
({
'num_classes'
:
10
,
'lr'
:
0.01
,
'momentum'
:
0.9
,
'epoch_size'
:
10
,
'batch_size'
:
32
,
'buffer_size'
:
1000
,
'image_height'
:
32
,
'image_width'
:
32
,
'save_checkpoint_steps'
:
1875
,
'keep_checkpoint_max'
:
10
,
})
example/mnist_demo/lenet5_dp_model_train.py (new file, mode 100644)
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
python lenet5_dp_model_train.py --data_path /YourDataPath --micro_batches=2
"""
import
os
import
argparse
import
mindspore.nn
as
nn
from
mindspore
import
context
from
mindspore.train.callback
import
ModelCheckpoint
from
mindspore.train.callback
import
CheckpointConfig
from
mindspore.train.callback
import
LossMonitor
from
mindspore.nn.metrics
import
Accuracy
from
mindspore.train.serialization
import
load_checkpoint
,
load_param_into_net
import
mindspore.dataset
as
ds
import
mindspore.dataset.transforms.vision.c_transforms
as
CV
import
mindspore.dataset.transforms.c_transforms
as
C
from
mindspore.dataset.transforms.vision
import
Inter
import
mindspore.common.dtype
as
mstype
from
mindarmour.diff_privacy
import
DPModel
from
mindarmour.diff_privacy
import
DPOptimizerClassFactory
from
mindarmour.diff_privacy
import
PrivacyMonitorFactory
from
mindarmour.utils.logger
import
LogUtil
from
lenet5_net
import
LeNet5
from
lenet5_config
import
mnist_cfg
as
cfg
LOGGER
=
LogUtil
.
get_instance
()
TAG
=
'Lenet5_train'
def
generate_mnist_dataset
(
data_path
,
batch_size
=
32
,
repeat_size
=
1
,
num_parallel_workers
=
1
,
sparse
=
True
):
"""
create dataset for training or testing
"""
# define dataset
ds1
=
ds
.
MnistDataset
(
data_path
)
# define operation parameters
resize_height
,
resize_width
=
32
,
32
rescale
=
1.0
/
255.0
shift
=
0.0
# define map operations
resize_op
=
CV
.
Resize
((
resize_height
,
resize_width
),
interpolation
=
Inter
.
LINEAR
)
rescale_op
=
CV
.
Rescale
(
rescale
,
shift
)
hwc2chw_op
=
CV
.
HWC2CHW
()
type_cast_op
=
C
.
TypeCast
(
mstype
.
int32
)
# apply map operations on images
if
not
sparse
:
one_hot_enco
=
C
.
OneHot
(
10
)
ds1
=
ds1
.
map
(
input_columns
=
"label"
,
operations
=
one_hot_enco
,
num_parallel_workers
=
num_parallel_workers
)
type_cast_op
=
C
.
TypeCast
(
mstype
.
float32
)
ds1
=
ds1
.
map
(
input_columns
=
"label"
,
operations
=
type_cast_op
,
num_parallel_workers
=
num_parallel_workers
)
ds1
=
ds1
.
map
(
input_columns
=
"image"
,
operations
=
resize_op
,
num_parallel_workers
=
num_parallel_workers
)
ds1
=
ds1
.
map
(
input_columns
=
"image"
,
operations
=
rescale_op
,
num_parallel_workers
=
num_parallel_workers
)
ds1
=
ds1
.
map
(
input_columns
=
"image"
,
operations
=
hwc2chw_op
,
num_parallel_workers
=
num_parallel_workers
)
# apply DatasetOps
buffer_size
=
10000
ds1
=
ds1
.
shuffle
(
buffer_size
=
buffer_size
)
ds1
=
ds1
.
batch
(
batch_size
,
drop_remainder
=
True
)
ds1
=
ds1
.
repeat
(
repeat_size
)
return
ds1
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
(
description
=
'MindSpore MNIST Example'
)
parser
.
add_argument
(
'--device_target'
,
type
=
str
,
default
=
"Ascend"
,
choices
=
[
'Ascend'
,
'GPU'
,
'CPU'
],
help
=
'device where the code will be implemented (default: Ascend)'
)
parser
.
add_argument
(
'--data_path'
,
type
=
str
,
default
=
"./MNIST_unzip"
,
help
=
'path where the dataset is saved'
)
parser
.
add_argument
(
'--dataset_sink_mode'
,
type
=
bool
,
default
=
False
,
help
=
'dataset_sink_mode is False or True'
)
parser
.
add_argument
(
'--micro_batches'
,
type
=
float
,
default
=
None
,
help
=
'optional, if use differential privacy, need to set micro_batches'
)
parser
.
add_argument
(
'--l2_norm_bound'
,
type
=
float
,
default
=
1
,
help
=
'optional, if use differential privacy, need to set l2_norm_bound'
)
parser
.
add_argument
(
'--initial_noise_multiplier'
,
type
=
float
,
default
=
0.001
,
help
=
'optional, if use differential privacy, need to set initial_noise_multiplier'
)
args
=
parser
.
parse_args
()
context
.
set_context
(
mode
=
context
.
PYNATIVE_MODE
,
device_target
=
args
.
device_target
,
enable_mem_reuse
=
False
)
network
=
LeNet5
()
net_loss
=
nn
.
SoftmaxCrossEntropyWithLogits
(
is_grad
=
False
,
sparse
=
True
,
reduction
=
"mean"
)
config_ck
=
CheckpointConfig
(
save_checkpoint_steps
=
cfg
.
save_checkpoint_steps
,
keep_checkpoint_max
=
cfg
.
keep_checkpoint_max
)
ckpoint_cb
=
ModelCheckpoint
(
prefix
=
"checkpoint_lenet"
,
directory
=
'./trained_ckpt_file/'
,
config
=
config_ck
)
ds_train
=
generate_mnist_dataset
(
os
.
path
.
join
(
args
.
data_path
,
"train"
),
cfg
.
batch_size
,
cfg
.
epoch_size
)
if
args
.
micro_batches
and
cfg
.
batch_size
%
args
.
micro_batches
!=
0
:
raise
ValueError
(
"Number of micro_batches should divide evenly batch_size"
)
gaussian_mech
=
DPOptimizerClassFactory
(
args
.
micro_batches
)
gaussian_mech
.
set_mechanisms
(
'Gaussian'
,
norm_bound
=
args
.
l2_norm_bound
,
initial_noise_multiplier
=
args
.
initial_noise_multiplier
)
net_opt
=
gaussian_mech
.
create
(
'Momentum'
)(
params
=
network
.
trainable_params
(),
learning_rate
=
cfg
.
lr
,
momentum
=
cfg
.
momentum
)
micro_size
=
int
(
cfg
.
batch_size
//
args
.
micro_batches
)
rdp_monitor
=
PrivacyMonitorFactory
.
create
(
'rdp'
,
num_samples
=
60000
,
batch_size
=
micro_size
,
initial_noise_multiplier
=
args
.
initial_noise_multiplier
,
per_print_times
=
10
)
model
=
DPModel
(
micro_batches
=
args
.
micro_batches
,
norm_clip
=
args
.
l2_norm_bound
,
dp_mech
=
gaussian_mech
.
mech
,
network
=
network
,
loss_fn
=
net_loss
,
optimizer
=
net_opt
,
metrics
=
{
"Accuracy"
:
Accuracy
()})
LOGGER
.
info
(
TAG
,
"============== Starting Training =============="
)
model
.
train
(
cfg
[
'epoch_size'
],
ds_train
,
callbacks
=
[
ckpoint_cb
,
LossMonitor
(),
rdp_monitor
],
dataset_sink_mode
=
args
.
dataset_sink_mode
)
LOGGER
.
info
(
TAG
,
"============== Starting Testing =============="
)
ckpt_file_name
=
'trained_ckpt_file/checkpoint_lenet-10_1875.ckpt'
param_dict
=
load_checkpoint
(
ckpt_file_name
)
load_param_into_net
(
network
,
param_dict
)
ds_eval
=
generate_mnist_dataset
(
os
.
path
.
join
(
args
.
data_path
,
'test'
),
batch_size
=
cfg
.
batch_size
)
acc
=
model
.
eval
(
ds_eval
,
dataset_sink_mode
=
False
)
LOGGER
.
info
(
TAG
,
"============== Accuracy: %s =============="
,
acc
)
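One detail worth noting in the script above: cfg.batch_size must be evenly divisible by --micro_batches, and the RDP monitor is created with the per-micro-batch size rather than the full batch size. A small sketch of that bookkeeping, using the defaults from lenet5_config.py and --micro_batches=2:

batch_size = 32        # cfg.batch_size from lenet5_config.py
micro_batches = 2      # value passed on the command line via --micro_batches

if batch_size % micro_batches != 0:
    raise ValueError("Number of micro_batches should divide evenly batch_size")

micro_size = batch_size // micro_batches   # 16; the batch_size handed to PrivacyMonitorFactory.create
print(micro_size)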
mindarmour/diff_privacy/__init__.py
@@ -4,7 +4,13 @@ This module provide Differential Privacy feature to protect user privacy.
 from .mechanisms.mechanisms import GaussianRandom
 from .mechanisms.mechanisms import AdaGaussianRandom
 from .mechanisms.mechanisms import MechanismsFactory
+from .monitor.monitor import PrivacyMonitorFactory
+from .optimizer.optimizer import DPOptimizerClassFactory
+from .train.model import DPModel

 __all__ = ['GaussianRandom',
            'AdaGaussianRandom',
-           'MechanismsFactory']
+           'MechanismsFactory',
+           'PrivacyMonitorFactory',
+           'DPOptimizerClassFactory',
+           'DPModel']
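With the additions above, the differential-privacy entry points can all be imported from the package root rather than from their submodules, for example:

from mindarmour.diff_privacy import (GaussianRandom, AdaGaussianRandom, MechanismsFactory,
                                     PrivacyMonitorFactory, DPOptimizerClassFactory, DPModel)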
mindarmour/diff_privacy/mechanisms/mechanisms.py
@@ -60,10 +60,6 @@ class Mechanisms(Cell):
     """
     Basic class of noise generated mechanism.
     """
-    def __init__(self):
-        pass
-
     def construct(self, shape):
         """
         Construct function.
 ...
mindarmour/diff_privacy/monitor/monitor.py
@@ -47,7 +47,7 @@ class PrivacyMonitorFactory:
             parameters used for creating a privacy monitor.

         Returns:
-            PrivacyMonitor, a privacy monitor.
+            Callback, a privacy monitor.

         Examples:
             >>> rdp = PrivacyMonitorFactory.create(policy='rdp',
 ...
mindarmour/diff_privacy/optimizer/__init__.py (new file, mode 100644, empty)
mindarmour/diff_privacy/optimizer/optimizer.py (new file, mode 100644)
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Differential privacy optimizer.
"""
import
mindspore
as
ms
from
mindspore
import
nn
from
mindspore
import
Tensor
from
mindarmour.diff_privacy.mechanisms.mechanisms
import
MechanismsFactory
class
DPOptimizerClassFactory
:
"""
Factory class of Optimizer.
Args:
micro_batches (int): The number of small batches split from an origianl batch. Default: None.
Returns:
Optimizer, Optimizer class
Examples:
>>> GaussianSGD = DPOptimizerClassFactory(micro_batches=2)
>>> GaussianSGD.set_mechanisms('Gaussian', norm_bound=1.5, initial_noise_multiplier=5.0)
>>> net_opt = GaussianSGD.create('SGD')(params=network.trainable_params(),
>>> learning_rate=cfg.lr,
>>> momentum=cfg.momentum)
"""
def
__init__
(
self
,
micro_batches
=
None
):
self
.
_mech_factory
=
MechanismsFactory
()
self
.
mech
=
None
self
.
_micro_batches
=
micro_batches
def
set_mechanisms
(
self
,
policy
,
*
args
,
**
kwargs
):
"""
Get noise mechanism object.
Args:
policy (str): Choose mechanism type.
"""
self
.
mech
=
self
.
_mech_factory
.
create
(
policy
,
*
args
,
**
kwargs
)
def
create
(
self
,
policy
,
*
args
,
**
kwargs
):
"""
Create DP optimizer.
Args:
policy (str): Choose original optimizer type.
Returns:
Optimizer, A optimizer with DP.
"""
if
policy
==
'SGD'
:
cls
=
self
.
_get_dp_optimizer_class
(
nn
.
SGD
,
self
.
mech
,
self
.
_micro_batches
,
*
args
,
**
kwargs
)
return
cls
if
policy
==
'Momentum'
:
cls
=
self
.
_get_dp_optimizer_class
(
nn
.
Momentum
,
self
.
mech
,
self
.
_micro_batches
,
*
args
,
**
kwargs
)
return
cls
if
policy
==
'Adam'
:
cls
=
self
.
_get_dp_optimizer_class
(
nn
.
Adam
,
self
.
mech
,
self
.
_micro_batches
,
*
args
,
**
kwargs
)
return
cls
if
policy
==
'AdamWeightDecay'
:
cls
=
self
.
_get_dp_optimizer_class
(
nn
.
AdamWeightDecay
,
self
.
mech
,
self
.
_micro_batches
,
*
args
,
**
kwargs
)
return
cls
if
policy
==
'AdamWeightDecayDynamicLR'
:
cls
=
self
.
_get_dp_optimizer_class
(
nn
.
AdamWeightDecayDynamicLR
,
self
.
mech
,
self
.
_micro_batches
,
*
args
,
**
kwargs
)
return
cls
raise
NameError
(
"The {} is not implement, please choose ['SGD', 'Momentum', 'AdamWeightDecay', "
"'Adam', 'AdamWeightDecayDynamicLR']"
.
format
(
policy
))
def
_get_dp_optimizer_class
(
self
,
cls
,
mech
,
micro_batches
):
"""
Wrap original mindspore optimizer with `self._mech`.
"""
class
DPOptimizer
(
cls
):
"""
Initialize the DPOptimizerClass.
Returns:
Optimizer, Optimizer class.
"""
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
(
DPOptimizer
,
self
).
__init__
(
*
args
,
**
kwargs
)
self
.
_mech
=
mech
def
construct
(
self
,
gradients
):
"""
construct a compute flow.
"""
g_len
=
len
(
gradients
)
gradient_noise
=
list
(
gradients
)
for
i
in
range
(
g_len
):
gradient_noise
[
i
]
=
gradient_noise
[
i
].
asnumpy
()
gradient_noise
[
i
]
=
self
.
_mech
(
gradient_noise
[
i
].
shape
).
asnumpy
()
+
gradient_noise
[
i
]
gradient_noise
[
i
]
=
gradient_noise
[
i
]
/
micro_batches
gradient_noise
[
i
]
=
Tensor
(
gradient_noise
[
i
],
ms
.
float32
)
gradients
=
tuple
(
gradient_noise
)
gradients
=
super
(
DPOptimizer
,
self
).
construct
(
gradients
)
return
gradients
return
DPOptimizer
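To make the wrapped optimizer's behaviour concrete, here is a NumPy-only sketch (an illustration, not the MindSpore code above) of the transform DPOptimizer.construct applies before delegating to the wrapped optimizer: sample mechanism noise with the same shape as each gradient, add it, and average over the number of micro-batches.

import numpy as np

def perturb_gradients(gradients, noise_fn, micro_batches):
    """Add mechanism noise to each gradient, then average over micro-batches."""
    noisy = []
    for grad in gradients:
        noise = noise_fn(grad.shape)                 # stands in for self._mech(shape)
        noisy.append((grad + noise) / micro_batches)
    return tuple(noisy)

# Toy Gaussian mechanism; the standard deviation is chosen only for illustration.
grads = (np.ones((2, 3), dtype=np.float32), np.ones(5, dtype=np.float32))
noisy = perturb_gradients(grads, lambda shape: np.random.normal(0.0, 0.01, shape), micro_batches=2)
print([g.shape for g in noisy])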
mindarmour/diff_privacy/train/__init__.py (new file, mode 100644, empty)
mindarmour/diff_privacy/train/model.py (new file, mode 100644; +515 lines, diff collapsed in this view)
tests/ut/python/diff_privacy/mechanisms/test_mechanisms.py → tests/ut/python/diff_privacy/test_mechanisms.py (file moved)
tests/ut/python/diff_privacy/test_model_train.py (new file, mode 100644)
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
DP-Model test.
"""
import
pytest
import
numpy
as
np
from
mindspore
import
nn
from
mindspore.nn
import
SGD
from
mindspore.model_zoo.lenet
import
LeNet5
from
mindspore
import
context
import
mindspore.dataset
as
ds
from
mindarmour.diff_privacy
import
DPOptimizerClassFactory
from
mindarmour.diff_privacy
import
DPModel
def
dataset_generator
(
batch_size
,
batches
):
data
=
np
.
random
.
random
((
batches
*
batch_size
,
1
,
32
,
32
)).
astype
(
np
.
float32
)
label
=
np
.
random
.
randint
(
0
,
10
,
batches
*
batch_size
).
astype
(
np
.
int32
)
for
i
in
range
(
batches
):
yield
data
[
i
*
batch_size
:(
i
+
1
)
*
batch_size
],
label
[
i
*
batch_size
:(
i
+
1
)
*
batch_size
]
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_arm_ascend_training
@
pytest
.
mark
.
platform_x86_ascend_training
@
pytest
.
mark
.
env_card
@
pytest
.
mark
.
component_mindarmour
def
test_dp_model
():
context
.
set_context
(
mode
=
context
.
PYNATIVE_MODE
,
device_target
=
"Ascend"
)
l2_norm_bound
=
1.0
initial_noise_multiplier
=
0.01
net
=
LeNet5
()
batch_size
=
32
batches
=
128
epochs
=
1
loss
=
nn
.
SoftmaxCrossEntropyWithLogits
(
is_grad
=
False
,
sparse
=
True
)
optim
=
SGD
(
params
=
net
.
trainable_params
(),
learning_rate
=
0.1
,
momentum
=
0.9
)
gaussian_mech
=
DPOptimizerClassFactory
()
gaussian_mech
.
set_mechanisms
(
'Gaussian'
,
norm_bound
=
l2_norm_bound
,
initial_noise_multiplier
=
initial_noise_multiplier
)
model
=
DPModel
(
micro_batches
=
2
,
norm_clip
=
l2_norm_bound
,
dp_mech
=
gaussian_mech
.
mech
,
network
=
net
,
loss_fn
=
loss
,
optimizer
=
optim
,
metrics
=
None
)
ms_ds
=
ds
.
GeneratorDataset
(
dataset_generator
(
batch_size
,
batches
),
[
'data'
,
'label'
])
ms_ds
.
set_dataset_size
(
batch_size
*
batches
)
model
.
train
(
epochs
,
ms_ds
)
tests/ut/python/diff_privacy/test_monitor.py
@@ -23,7 +23,7 @@ from mindspore.train import Model
 import mindspore.context as context
 from mindspore.model_zoo.lenet import LeNet5

-from mindarmour.diff_privacy.monitor.monitor import PrivacyMonitorFactory
+from mindarmour.diff_privacy import PrivacyMonitorFactory
 from mindarmour.utils.logger import LogUtil

 LOGGER = LogUtil.get_instance()
 ...
tests/ut/python/diff_privacy/test_optimizer.py (new file, mode 100644)
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest

from mindspore import nn
from mindspore import context
from mindspore.model_zoo.lenet import LeNet5
from mindspore.train.model import Model

from mindarmour.diff_privacy import DPOptimizerClassFactory


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_card
@pytest.mark.component_mindarmour
def test_optimizer():
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    network = LeNet5()
    lr = 0.01
    momentum = 0.9
    micro_batches = 2
    loss = nn.SoftmaxCrossEntropyWithLogits()
    gaussian_mech = DPOptimizerClassFactory(micro_batches)
    gaussian_mech.set_mechanisms('Gaussian', norm_bound=1.5, initial_noise_multiplier=5.0)
    net_opt = gaussian_mech.create('SGD')(params=network.trainable_params(),
                                          learning_rate=lr,
                                          momentum=momentum)
    _ = Model(network, loss_fn=loss, optimizer=net_opt, metrics=None)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_inference
@pytest.mark.env_card
@pytest.mark.component_mindarmour
def test_optimizer_gpu():
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    network = LeNet5()
    lr = 0.01
    momentum = 0.9
    micro_batches = 2
    loss = nn.SoftmaxCrossEntropyWithLogits()
    gaussian_mech = DPOptimizerClassFactory(micro_batches)
    gaussian_mech.set_mechanisms('Gaussian', norm_bound=1.5, initial_noise_multiplier=5.0)
    net_opt = gaussian_mech.create('SGD')(params=network.trainable_params(),
                                          learning_rate=lr,
                                          momentum=momentum)
    _ = Model(network, loss_fn=loss, optimizer=net_opt, metrics=None)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_card
@pytest.mark.component_mindarmour
def test_optimizer_cpu():
    context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
    network = LeNet5()
    lr = 0.01
    momentum = 0.9
    micro_batches = 2
    loss = nn.SoftmaxCrossEntropyWithLogits()
    gaussian_mech = DPOptimizerClassFactory(micro_batches)
    gaussian_mech.set_mechanisms('Gaussian', norm_bound=1.5, initial_noise_multiplier=5.0)
    net_opt = gaussian_mech.create('SGD')(params=network.trainable_params(),
                                          learning_rate=lr,
                                          momentum=momentum)
    _ = Model(network, loss_fn=loss, optimizer=net_opt, metrics=None)