Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MindSpore
mindarmour
提交
f3baf9db
M
mindarmour
项目概览
MindSpore
/
mindarmour
通知
4
Star
2
Fork
3
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindarmour
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
f3baf9db
编写于
5月 29, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
5月 29, 2020
浏览文件
操作
浏览文件
下载
差异文件
!28 fix issue
Merge pull request !28 from zheng-huanhuan/dp_1
上级
07542569
9c37a110
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
41 addition
and
53 deletion
+41
-53
example/mnist_demo/lenet5_dp_model_train.py
example/mnist_demo/lenet5_dp_model_train.py
+1
-2
mindarmour/diff_privacy/optimizer/optimizer.py
mindarmour/diff_privacy/optimizer/optimizer.py
+2
-1
mindarmour/diff_privacy/train/model.py
mindarmour/diff_privacy/train/model.py
+31
-44
tests/ut/python/diff_privacy/test_model_train.py
tests/ut/python/diff_privacy/test_model_train.py
+7
-6
未找到文件。
example/mnist_demo/lenet5_dp_model_train.py
浏览文件 @
f3baf9db
...
@@ -123,10 +123,9 @@ if __name__ == "__main__":
...
@@ -123,10 +123,9 @@ if __name__ == "__main__":
net_opt
=
gaussian_mech
.
create
(
'SGD'
)(
params
=
network
.
trainable_params
(),
net_opt
=
gaussian_mech
.
create
(
'SGD'
)(
params
=
network
.
trainable_params
(),
learning_rate
=
cfg
.
lr
,
learning_rate
=
cfg
.
lr
,
momentum
=
cfg
.
momentum
)
momentum
=
cfg
.
momentum
)
micro_size
=
int
(
cfg
.
batch_size
//
args
.
micro_batches
)
rdp_monitor
=
PrivacyMonitorFactory
.
create
(
'rdp'
,
rdp_monitor
=
PrivacyMonitorFactory
.
create
(
'rdp'
,
num_samples
=
60000
,
num_samples
=
60000
,
batch_size
=
micro
_size
,
batch_size
=
cfg
.
batch
_size
,
initial_noise_multiplier
=
args
.
initial_noise_multiplier
,
initial_noise_multiplier
=
args
.
initial_noise_multiplier
,
per_print_times
=
10
)
per_print_times
=
10
)
model
=
DPModel
(
micro_batches
=
args
.
micro_batches
,
model
=
DPModel
(
micro_batches
=
args
.
micro_batches
,
...
...
mindarmour/diff_privacy/optimizer/optimizer.py
浏览文件 @
f3baf9db
...
@@ -19,6 +19,7 @@ from mindspore import nn
...
@@ -19,6 +19,7 @@ from mindspore import nn
from
mindspore
import
Tensor
from
mindspore
import
Tensor
from
mindarmour.diff_privacy.mechanisms.mechanisms
import
MechanismsFactory
from
mindarmour.diff_privacy.mechanisms.mechanisms
import
MechanismsFactory
from
mindarmour.utils._check_param
import
check_int_positive
class
DPOptimizerClassFactory
:
class
DPOptimizerClassFactory
:
...
@@ -41,7 +42,7 @@ class DPOptimizerClassFactory:
...
@@ -41,7 +42,7 @@ class DPOptimizerClassFactory:
def
__init__
(
self
,
micro_batches
=
None
):
def
__init__
(
self
,
micro_batches
=
None
):
self
.
_mech_factory
=
MechanismsFactory
()
self
.
_mech_factory
=
MechanismsFactory
()
self
.
mech
=
None
self
.
mech
=
None
self
.
_micro_batches
=
micro_batches
self
.
_micro_batches
=
check_int_positive
(
'micro_batches'
,
micro_batches
)
def
set_mechanisms
(
self
,
policy
,
*
args
,
**
kwargs
):
def
set_mechanisms
(
self
,
policy
,
*
args
,
**
kwargs
):
"""
"""
...
...
mindarmour/diff_privacy/train/model.py
浏览文件 @
f3baf9db
...
@@ -48,6 +48,8 @@ from mindspore.nn.wrap.loss_scale import _grad_overflow
...
@@ -48,6 +48,8 @@ from mindspore.nn.wrap.loss_scale import _grad_overflow
from
mindspore.nn
import
Cell
from
mindspore.nn
import
Cell
from
mindspore
import
ParameterTuple
from
mindspore
import
ParameterTuple
from
mindarmour.utils._check_param
import
check_param_type
from
mindarmour.utils._check_param
import
check_value_positive
GRADIENT_CLIP_TYPE
=
1
GRADIENT_CLIP_TYPE
=
1
grad_scale
=
C
.
MultitypeFuncGraph
(
"grad_scale"
)
grad_scale
=
C
.
MultitypeFuncGraph
(
"grad_scale"
)
...
@@ -56,6 +58,7 @@ reciprocal = P.Reciprocal()
...
@@ -56,6 +58,7 @@ reciprocal = P.Reciprocal()
@
grad_scale
.
register
(
"Tensor"
,
"Tensor"
)
@
grad_scale
.
register
(
"Tensor"
,
"Tensor"
)
def
tensor_grad_scale
(
scale
,
grad
):
def
tensor_grad_scale
(
scale
,
grad
):
""" grad scaling """
return
grad
*
reciprocal
(
scale
)
return
grad
*
reciprocal
(
scale
)
...
@@ -65,7 +68,7 @@ class DPModel(Model):
...
@@ -65,7 +68,7 @@ class DPModel(Model):
Args:
Args:
micro_batches (int): The number of small batches split from an origianl batch. Default: None.
micro_batches (int): The number of small batches split from an origianl batch. Default: None.
norm_clip (float): Use to clip the bound, if set 1, will retun the original data. Default:
None
.
norm_clip (float): Use to clip the bound, if set 1, will retun the original data. Default:
1.0
.
dp_mech (Mechanisms): The object can generate the different type of noise. Default: None.
dp_mech (Mechanisms): The object can generate the different type of noise. Default: None.
Examples:
Examples:
...
@@ -94,7 +97,7 @@ class DPModel(Model):
...
@@ -94,7 +97,7 @@ class DPModel(Model):
>>> norm_bound=args.l2_norm_bound,
>>> norm_bound=args.l2_norm_bound,
>>> initial_noise_multiplier=args.initial_noise_multiplier)
>>> initial_noise_multiplier=args.initial_noise_multiplier)
>>> model = DPModel(micro_batches=2,
>>> model = DPModel(micro_batches=2,
>>> norm_clip=1,
>>> norm_clip=1
.0
,
>>> dp_mech=gaussian_mech.mech,
>>> dp_mech=gaussian_mech.mech,
>>> network=net,
>>> network=net,
>>> loss_fn=loss,
>>> loss_fn=loss,
...
@@ -103,16 +106,17 @@ class DPModel(Model):
...
@@ -103,16 +106,17 @@ class DPModel(Model):
>>> dataset = get_dataset()
>>> dataset = get_dataset()
>>> model.train(2, dataset)
>>> model.train(2, dataset)
"""
"""
def
__init__
(
self
,
micro_batches
=
None
,
norm_clip
=
None
,
dp_mech
=
None
,
**
kwargs
):
def
__init__
(
self
,
micro_batches
=
None
,
norm_clip
=
1.0
,
dp_mech
=
None
,
**
kwargs
):
if
micro_batches
:
if
micro_batches
:
self
.
_micro_batches
=
int
(
micro_batches
)
self
.
_micro_batches
=
int
(
micro_batches
)
else
:
else
:
self
.
_micro_batches
=
None
self
.
_micro_batches
=
None
self
.
_norm_clip
=
norm_clip
float_norm_clip
=
check_param_type
(
'l2_norm_clip'
,
norm_clip
,
float
)
self
.
_norm_clip
=
check_value_positive
(
'l2_norm_clip'
,
float_norm_clip
)
self
.
_dp_mech
=
dp_mech
self
.
_dp_mech
=
dp_mech
super
(
DPModel
,
self
).
__init__
(
**
kwargs
)
super
(
DPModel
,
self
).
__init__
(
**
kwargs
)
def
amp_build_train_network
(
self
,
network
,
optimizer
,
loss_fn
=
None
,
level
=
'O0'
,
**
kwargs
):
def
_
amp_build_train_network
(
self
,
network
,
optimizer
,
loss_fn
=
None
,
level
=
'O0'
,
**
kwargs
):
"""
"""
Build the mixed precision training cell automatically.
Build the mixed precision training cell automatically.
...
@@ -185,18 +189,18 @@ class DPModel(Model):
...
@@ -185,18 +189,18 @@ class DPModel(Model):
if
self
.
_micro_batches
:
if
self
.
_micro_batches
:
if
self
.
_optimizer
:
if
self
.
_optimizer
:
if
self
.
_loss_scale_manager_set
:
if
self
.
_loss_scale_manager_set
:
network
=
self
.
amp_build_train_network
(
network
,
network
=
self
.
_
amp_build_train_network
(
network
,
self
.
_optimizer
,
self
.
_optimizer
,
self
.
_loss_fn
,
self
.
_loss_fn
,
level
=
self
.
_amp_level
,
level
=
self
.
_amp_level
,
loss_scale_manager
=
self
.
_loss_scale_manager
,
loss_scale_manager
=
self
.
_loss_scale_manager
,
keep_batchnorm_fp32
=
self
.
_keep_bn_fp32
)
keep_batchnorm_fp32
=
self
.
_keep_bn_fp32
)
else
:
else
:
network
=
self
.
amp_build_train_network
(
network
,
network
=
self
.
_
amp_build_train_network
(
network
,
self
.
_optimizer
,
self
.
_optimizer
,
self
.
_loss_fn
,
self
.
_loss_fn
,
level
=
self
.
_amp_level
,
level
=
self
.
_amp_level
,
keep_batchnorm_fp32
=
self
.
_keep_bn_fp32
)
keep_batchnorm_fp32
=
self
.
_keep_bn_fp32
)
elif
self
.
_loss_fn
:
elif
self
.
_loss_fn
:
network
=
nn
.
WithLossCell
(
network
,
self
.
_loss_fn
)
network
=
nn
.
WithLossCell
(
network
,
self
.
_loss_fn
)
else
:
else
:
...
@@ -273,8 +277,8 @@ class _TrainOneStepWithLossScaleCell(Cell):
...
@@ -273,8 +277,8 @@ class _TrainOneStepWithLossScaleCell(Cell):
network (Cell): The training network.
network (Cell): The training network.
optimizer (Cell): Optimizer for updating the weights.
optimizer (Cell): Optimizer for updating the weights.
scale_update_cell(Cell): The loss scaling update logic cell. Default: None.
scale_update_cell(Cell): The loss scaling update logic cell. Default: None.
micro_batches (int): The number of small batches split from an origi
an
l batch. Default: None.
micro_batches (int): The number of small batches split from an origi
na
l batch. Default: None.
l2_norm_clip (float): Use to clip the bound, if set 1, will retu
n the original data. Default: None
.
l2_norm_clip (float): Use to clip the bound, if set 1, will retu
rn the original data. Default: 1.0
.
mech (Mechanisms): The object can generate the different type of noise. Default: None.
mech (Mechanisms): The object can generate the different type of noise. Default: None.
Inputs:
Inputs:
...
@@ -288,21 +292,9 @@ class _TrainOneStepWithLossScaleCell(Cell):
...
@@ -288,21 +292,9 @@ class _TrainOneStepWithLossScaleCell(Cell):
- **loss** (Tensor) - Tensor with shape :math:`()`.
- **loss** (Tensor) - Tensor with shape :math:`()`.
- **overflow** (Tensor) - Tensor with shape :math:`()`, type is bool.
- **overflow** (Tensor) - Tensor with shape :math:`()`, type is bool.
- **loss_scale** (Tensor) - Tensor with shape :math:`()`.
- **loss_scale** (Tensor) - Tensor with shape :math:`()`.
Examples:
>>> net_with_loss = Net()
>>> optimizer = nn.Momentum(net_with_loss.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> manager = nn.DynamicLossScaleUpdateCell(loss_scale_value=2**12, scale_factor=2, scale_window=1000)
>>> train_network = nn.TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_update_cell=manager)
>>> train_network.set_train()
>>>
>>> inputs = Tensor(np.ones([16, 16]).astype(np.float32))
>>> label = Tensor(np.zeros([16, 16]).astype(np.float32))
>>> scaling_sens = Tensor(np.full((1), np.finfo(np.float32).max), dtype=mindspore.float32)
>>> output = train_network(inputs, label, scaling_sens)
"""
"""
def
__init__
(
self
,
network
,
optimizer
,
scale_update_cell
=
None
,
micro_batches
=
None
,
l2_norm_clip
=
None
,
mech
=
None
):
def
__init__
(
self
,
network
,
optimizer
,
scale_update_cell
=
None
,
micro_batches
=
None
,
l2_norm_clip
=
1.0
,
mech
=
None
):
super
(
_TrainOneStepWithLossScaleCell
,
self
).
__init__
(
auto_prefix
=
False
)
super
(
_TrainOneStepWithLossScaleCell
,
self
).
__init__
(
auto_prefix
=
False
)
self
.
network
=
network
self
.
network
=
network
self
.
network
.
add_flags
(
defer_inline
=
True
)
self
.
network
.
add_flags
(
defer_inline
=
True
)
...
@@ -343,7 +335,8 @@ class _TrainOneStepWithLossScaleCell(Cell):
...
@@ -343,7 +335,8 @@ class _TrainOneStepWithLossScaleCell(Cell):
# dp params
# dp params
self
.
_micro_batches
=
micro_batches
self
.
_micro_batches
=
micro_batches
self
.
_l2_norm
=
l2_norm_clip
float_norm_clip
=
check_param_type
(
'l2_norm_clip'
,
l2_norm_clip
,
float
)
self
.
_l2_norm
=
check_value_positive
(
'l2_norm_clip'
,
float_norm_clip
)
self
.
_split
=
P
.
Split
(
0
,
self
.
_micro_batches
)
self
.
_split
=
P
.
Split
(
0
,
self
.
_micro_batches
)
self
.
_clip_by_global_norm
=
_ClipGradients
()
self
.
_clip_by_global_norm
=
_ClipGradients
()
self
.
_mech
=
mech
self
.
_mech
=
mech
...
@@ -435,9 +428,9 @@ class _TrainOneStepCell(Cell):
...
@@ -435,9 +428,9 @@ class _TrainOneStepCell(Cell):
Args:
Args:
network (Cell): The training network.
network (Cell): The training network.
optimizer (Cell): Optimizer for updating the weights.
optimizer (Cell): Optimizer for updating the weights.
sens (Number): The scaling number to be filled as the input of backpropagation. Default value is 1.0.
sens (Number): The scaling number to be filled as the input of back
propagation. Default value is 1.0.
micro_batches (int): The number of small batches split from an origi
an
l batch. Default: None.
micro_batches (int): The number of small batches split from an origi
na
l batch. Default: None.
l2_norm_clip (float): Use to clip the bound, if set 1, will retu
n the original data. Default: None
.
l2_norm_clip (float): Use to clip the bound, if set 1, will retu
rn the original data. Default: 1.0
.
mech (Mechanisms): The object can generate the different type of noise. Default: None.
mech (Mechanisms): The object can generate the different type of noise. Default: None.
Inputs:
Inputs:
...
@@ -446,16 +439,9 @@ class _TrainOneStepCell(Cell):
...
@@ -446,16 +439,9 @@ class _TrainOneStepCell(Cell):
Outputs:
Outputs:
Tensor, a scalar Tensor with shape :math:`()`.
Tensor, a scalar Tensor with shape :math:`()`.
Examples:
>>> net = Net()
>>> loss_fn = nn.SoftmaxCrossEntropyWithLogits()
>>> optim = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> loss_net = nn.WithLossCell(net, loss_fn)
>>> train_net = nn.TrainOneStepCell(loss_net, optim)
"""
"""
def
__init__
(
self
,
network
,
optimizer
,
sens
=
1.0
,
micro_batches
=
None
,
l2_norm_clip
=
None
,
mech
=
None
):
def
__init__
(
self
,
network
,
optimizer
,
sens
=
1.0
,
micro_batches
=
None
,
l2_norm_clip
=
1.0
,
mech
=
None
):
super
(
_TrainOneStepCell
,
self
).
__init__
(
auto_prefix
=
False
)
super
(
_TrainOneStepCell
,
self
).
__init__
(
auto_prefix
=
False
)
self
.
network
=
network
self
.
network
=
network
self
.
network
.
add_flags
(
defer_inline
=
True
)
self
.
network
.
add_flags
(
defer_inline
=
True
)
...
@@ -475,7 +461,8 @@ class _TrainOneStepCell(Cell):
...
@@ -475,7 +461,8 @@ class _TrainOneStepCell(Cell):
# dp params
# dp params
self
.
_micro_batches
=
micro_batches
self
.
_micro_batches
=
micro_batches
self
.
_l2_norm
=
l2_norm_clip
float_norm_clip
=
check_param_type
(
'l2_norm_clip'
,
l2_norm_clip
,
float
)
self
.
_l2_norm
=
check_value_positive
(
'l2_norm_clip'
,
float_norm_clip
)
self
.
_split
=
P
.
Split
(
0
,
self
.
_micro_batches
)
self
.
_split
=
P
.
Split
(
0
,
self
.
_micro_batches
)
self
.
_clip_by_global_norm
=
_ClipGradients
()
self
.
_clip_by_global_norm
=
_ClipGradients
()
self
.
_mech
=
mech
self
.
_mech
=
mech
...
...
tests/ut/python/diff_privacy/test_model_train.py
浏览文件 @
f3baf9db
...
@@ -18,7 +18,6 @@ import pytest
...
@@ -18,7 +18,6 @@ import pytest
import
numpy
as
np
import
numpy
as
np
from
mindspore
import
nn
from
mindspore
import
nn
from
mindspore.nn
import
SGD
from
mindspore.model_zoo.lenet
import
LeNet5
from
mindspore.model_zoo.lenet
import
LeNet5
from
mindspore
import
context
from
mindspore
import
context
import
mindspore.dataset
as
ds
import
mindspore.dataset
as
ds
...
@@ -43,22 +42,24 @@ def test_dp_model():
...
@@ -43,22 +42,24 @@ def test_dp_model():
context
.
set_context
(
mode
=
context
.
PYNATIVE_MODE
,
device_target
=
"Ascend"
)
context
.
set_context
(
mode
=
context
.
PYNATIVE_MODE
,
device_target
=
"Ascend"
)
l2_norm_bound
=
1.0
l2_norm_bound
=
1.0
initial_noise_multiplier
=
0.01
initial_noise_multiplier
=
0.01
net
=
LeNet5
()
net
work
=
LeNet5
()
batch_size
=
32
batch_size
=
32
batches
=
128
batches
=
128
epochs
=
1
epochs
=
1
loss
=
nn
.
SoftmaxCrossEntropyWithLogits
(
is_grad
=
False
,
sparse
=
True
)
loss
=
nn
.
SoftmaxCrossEntropyWithLogits
(
is_grad
=
False
,
sparse
=
True
)
optim
=
SGD
(
params
=
net
.
trainable_params
(),
learning_rate
=
0.1
,
momentum
=
0.9
)
gaussian_mech
=
DPOptimizerClassFactory
(
micro_batches
=
2
)
gaussian_mech
=
DPOptimizerClassFactory
()
gaussian_mech
.
set_mechanisms
(
'Gaussian'
,
gaussian_mech
.
set_mechanisms
(
'Gaussian'
,
norm_bound
=
l2_norm_bound
,
norm_bound
=
l2_norm_bound
,
initial_noise_multiplier
=
initial_noise_multiplier
)
initial_noise_multiplier
=
initial_noise_multiplier
)
net_opt
=
gaussian_mech
.
create
(
'SGD'
)(
params
=
network
.
trainable_params
(),
learning_rate
=
0.1
,
momentum
=
0.9
)
model
=
DPModel
(
micro_batches
=
2
,
model
=
DPModel
(
micro_batches
=
2
,
norm_clip
=
l2_norm_bound
,
norm_clip
=
l2_norm_bound
,
dp_mech
=
gaussian_mech
.
mech
,
dp_mech
=
gaussian_mech
.
mech
,
network
=
net
,
network
=
net
work
,
loss_fn
=
loss
,
loss_fn
=
loss
,
optimizer
=
optim
,
optimizer
=
net_opt
,
metrics
=
None
)
metrics
=
None
)
ms_ds
=
ds
.
GeneratorDataset
(
dataset_generator
(
batch_size
,
batches
),
[
'data'
,
'label'
])
ms_ds
=
ds
.
GeneratorDataset
(
dataset_generator
(
batch_size
,
batches
),
[
'data'
,
'label'
])
ms_ds
.
set_dataset_size
(
batch_size
*
batches
)
ms_ds
.
set_dataset_size
(
batch_size
*
batches
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录