Commit a98b8471
Authored May 29, 2020 by zhenghuanhuan

[MA][diff_privacy][Func] micro_batches and dp_mech not checked
https://gitee.com/mindspore/dashboard/issues?id=I1IS9G
Parent: f3baf9db
Showing 2 changed files with 13 additions and 17 deletions (+13, -17):

mindarmour/diff_privacy/optimizer/optimizer.py   (+3, -13)
mindarmour/diff_privacy/train/model.py           (+10, -4)
mindarmour/diff_privacy/optimizer/optimizer.py
@@ -27,7 +27,7 @@ class DPOptimizerClassFactory:
     Factory class of Optimizer.

     Args:
-        micro_batches (int): The number of small batches split from an origianl batch. Default: None.
+        micro_batches (int): The number of small batches split from an origianl batch. Default: 2.

     Returns:
         Optimizer, Optimizer class
@@ -39,7 +39,7 @@ class DPOptimizerClassFactory:
         >>> learning_rate=cfg.lr,
         >>> momentum=cfg.momentum)
     """
-    def __init__(self, micro_batches=None):
+    def __init__(self, micro_batches=2):
         self._mech_factory = MechanismsFactory()
         self.mech = None
         self._micro_batches = check_int_positive('micro_batches', micro_batches)
@@ -72,17 +72,7 @@ class DPOptimizerClassFactory:
-        if policy == 'Adam':
-            cls = self._get_dp_optimizer_class(nn.Adam, self.mech, self._micro_batches, *args, **kwargs)
-            return cls
-        if policy == 'AdamWeightDecay':
-            cls = self._get_dp_optimizer_class(nn.AdamWeightDecay, self.mech, self._micro_batches, *args, **kwargs)
-            return cls
-        if policy == 'AdamWeightDecayDynamicLR':
-            cls = self._get_dp_optimizer_class(nn.AdamWeightDecayDynamicLR, self.mech, self._micro_batches, *args, **kwargs)
-            return cls
-        raise NameError("The {} is not implement, please choose ['SGD', 'Momentum', 'AdamWeightDecay', "
-                        "'Adam', 'AdamWeightDecayDynamicLR']".format(policy))
+        raise NameError("The {} is not implement, please choose ['SGD', 'Momentum', 'Adam']".format(policy))

     def _get_dp_optimizer_class(self, cls, mech, micro_batches):
         """
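Taken together, the optimizer changes mean micro_batches is validated at construction time and unsupported policies fail earlier. The following is a rough usage sketch only; it assumes the factory's set_mechanisms/create methods and the public mindarmour.diff_privacy export, and network and cfg are hypothetical placeholders:

    from mindarmour.diff_privacy import DPOptimizerClassFactory

    # micro_batches now defaults to 2 and must pass check_int_positive,
    # so a None or non-positive value raises instead of slipping through.
    factory = DPOptimizerClassFactory(micro_batches=2)
    factory.set_mechanisms('Gaussian', norm_bound=1.0, initial_noise_multiplier=1.5)

    # create() returns an optimizer class; 'network' and 'cfg' are placeholders.
    DPMomentum = factory.create('Momentum')
    net_opt = DPMomentum(params=network.trainable_params(),
                         learning_rate=cfg.lr,
                         momentum=cfg.momentum)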
mindarmour/diff_privacy/train/model.py
@@ -48,8 +48,11 @@ from mindspore.nn.wrap.loss_scale import _grad_overflow
 from mindspore.nn import Cell
 from mindspore import ParameterTuple
+from mindarmour.diff_privacy.mechanisms import mechanisms
 from mindarmour.utils._check_param import check_param_type
 from mindarmour.utils._check_param import check_value_positive
+from mindarmour.utils._check_param import check_int_positive

 GRADIENT_CLIP_TYPE = 1
 grad_scale = C.MultitypeFuncGraph("grad_scale")
@@ -67,7 +70,7 @@ class DPModel(Model):
     This class is overload mindspore.train.model.Model.

     Args:
-        micro_batches (int): The number of small batches split from an origianl batch. Default: None.
+        micro_batches (int): The number of small batches split from an origianl batch. Default: 2.
         norm_clip (float): Use to clip the bound, if set 1, will retun the original data. Default: 1.0.
         dp_mech (Mechanisms): The object can generate the different type of noise. Default: None.
@@ -106,14 +109,17 @@ class DPModel(Model):
         >>> dataset = get_dataset()
         >>> model.train(2, dataset)
     """
-    def __init__(self, micro_batches=None, norm_clip=1.0, dp_mech=None, **kwargs):
-        if micro_batches:
-            self._micro_batches = int(micro_batches)
-        else:
-            self._micro_batches = None
+    def __init__(self, micro_batches=2, norm_clip=1.0, dp_mech=None, **kwargs):
+        self._micro_batches = check_int_positive('micro_batches', micro_batches)
         float_norm_clip = check_param_type('l2_norm_clip', norm_clip, float)
         self._norm_clip = check_value_positive('l2_norm_clip', float_norm_clip)
+        if isinstance(dp_mech, mechanisms.Mechanisms):
+            self._dp_mech = dp_mech
+        else:
+            raise TypeError('dp mechanisms should be instance of class Mechansms, but got {}'
+                            .format(type(dp_mech)))
         super(DPModel, self).__init__(**kwargs)

     def _amp_build_train_network(self, network, optimizer, loss_fn=None, level='O0', **kwargs):
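The net effect in DPModel is fail-fast construction: micro_batches must be a positive integer and dp_mech must be a Mechanisms instance, where the old code silently accepted None for both. A minimal sketch of the new contract follows; it assumes the package's public DPModel/MechanismsFactory exports, and net, loss, net_opt, and dataset are hypothetical placeholders:

    from mindarmour.diff_privacy import DPModel, MechanismsFactory

    # Build a noise mechanism first; DPModel(dp_mech=None) now raises TypeError.
    mech = MechanismsFactory().create('Gaussian', norm_bound=1.0,
                                      initial_noise_multiplier=1.5)

    model = DPModel(micro_batches=2,      # checked by check_int_positive
                    norm_clip=1.0,        # still checked as a positive float
                    dp_mech=mech,         # must be a mechanisms.Mechanisms instance
                    network=net,          # placeholders: an ordinary MindSpore
                    loss_fn=loss,         # network, loss function and optimizer
                    optimizer=net_opt)
    model.train(2, dataset)               # as in the docstring example above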