Commit a98b8471 authored by: Z zhenghuanhuan

[MA][diff_privacy][Func] micro_batches and dp_mech not checked

https://gitee.com/mindspore/dashboard/issues?id=I1IS9G
Parent f3baf9db
@@ -27,7 +27,7 @@ class DPOptimizerClassFactory:
Factory class of Optimizer.
Args:
micro_batches (int): The number of small batches split from an original batch. Default: None.
micro_batches (int): The number of small batches split from an original batch. Default: 2.
Returns:
Optimizer, Optimizer class
@@ -39,7 +39,7 @@ class DPOptimizerClassFactory:
>>> learning_rate=cfg.lr,
>>> momentum=cfg.momentum)
"""
def __init__(self, micro_batches=None):
def __init__(self, micro_batches=2):
self._mech_factory = MechanismsFactory()
self.mech = None
self._micro_batches = check_int_positive('micro_batches', micro_batches)
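Review note: micro_batches is now validated up front instead of defaulting to None. A minimal sketch of how check_int_positive is assumed to behave (the real helper lives in mindarmour.utils._check_param; the error messages below are illustrative):

```python
# Illustrative stand-in for mindarmour.utils._check_param.check_int_positive;
# assumed behavior: return the value if it is a positive int, raise otherwise.
def check_int_positive(arg_name, arg_value):
    # bool is a subclass of int, so reject it explicitly (assumption).
    if isinstance(arg_value, bool) or not isinstance(arg_value, int):
        raise TypeError('{} must be an int, but got {}'.format(arg_name, type(arg_value)))
    if arg_value <= 0:
        raise ValueError('{} must be positive, but got {}'.format(arg_name, arg_value))
    return arg_value


print(check_int_positive('micro_batches', 2))   # 2
# check_int_positive('micro_batches', 0)        # would raise ValueError
# check_int_positive('micro_batches', None)     # would raise TypeError
```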
@@ -72,17 +72,7 @@ class DPOptimizerClassFactory:
if policy == 'Adam':
cls = self._get_dp_optimizer_class(nn.Adam, self.mech, self._micro_batches, *args, **kwargs)
return cls
if policy == 'AdamWeightDecay':
cls = self._get_dp_optimizer_class(nn.AdamWeightDecay, self.mech, self._micro_batches, *args, **kwargs)
return cls
if policy == 'AdamWeightDecayDynamicLR':
cls = self._get_dp_optimizer_class(nn.AdamWeightDecayDynamicLR,
self.mech,
self._micro_batches,
*args, **kwargs)
return cls
raise NameError("The {} is not implement, please choose ['SGD', 'Momentum', 'AdamWeightDecay', "
"'Adam', 'AdamWeightDecayDynamicLR']".format(policy))
raise NameError("The {} is not implement, please choose ['SGD', 'Momentum', 'Adam']".format(policy))
def _get_dp_optimizer_class(self, cls, mech, micro_batches):
"""
......
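Review note: with the AdamWeightDecay branches dropped, create() only dispatches 'SGD', 'Momentum' and 'Adam'. A hedged usage sketch based on the docstring above; the set_mechanisms arguments (norm_bound, initial_noise_multiplier) and the import path are assumptions:

```python
# Sketch of building a DP optimizer with the reduced policy set (assumed usage).
from mindspore import nn
from mindarmour.diff_privacy import DPOptimizerClassFactory

net = nn.Dense(32, 10)  # any Cell works; only its trainable params are needed here

factory = DPOptimizerClassFactory(micro_batches=2)
# Assumed: attach a Gaussian noise mechanism before creating the optimizer class.
factory.set_mechanisms('Gaussian', norm_bound=1.0, initial_noise_multiplier=1.5)

DPMomentum = factory.create('Momentum')
opt = DPMomentum(params=net.trainable_params(), learning_rate=0.01, momentum=0.9)

# factory.create('AdamWeightDecay')  # no longer dispatched, now raises NameError
```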
@@ -48,8 +48,11 @@ from mindspore.nn.wrap.loss_scale import _grad_overflow
from mindspore.nn import Cell
from mindspore import ParameterTuple
from mindarmour.diff_privacy.mechanisms import mechanisms
from mindarmour.utils._check_param import check_param_type
from mindarmour.utils._check_param import check_value_positive
from mindarmour.utils._check_param import check_int_positive
GRADIENT_CLIP_TYPE = 1
grad_scale = C.MultitypeFuncGraph("grad_scale")
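Review note: the two extra validators imported here are used in DPModel.__init__ below. Rough stand-ins for how check_param_type and check_value_positive are assumed to behave (not the actual mindarmour.utils._check_param code):

```python
# Assumed semantics of the imported validators (illustrative only).
def check_param_type(arg_name, arg_value, valid_type):
    # Return the value if it has the expected type, otherwise raise TypeError.
    if not isinstance(arg_value, valid_type):
        raise TypeError('{} must be {}, but got {}'.format(arg_name, valid_type, type(arg_value)))
    return arg_value


def check_value_positive(arg_name, arg_value):
    # Return the value if it is strictly positive, otherwise raise ValueError.
    if arg_value <= 0:
        raise ValueError('{} must be positive, but got {}'.format(arg_name, arg_value))
    return arg_value


# Mirrors the chained check on norm_clip in DPModel.__init__ below.
norm_clip = check_value_positive('l2_norm_clip', check_param_type('l2_norm_clip', 1.0, float))
```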
@@ -67,7 +70,7 @@ class DPModel(Model):
This class overloads mindspore.train.model.Model.
Args:
micro_batches (int): The number of small batches split from an original batch. Default: None.
micro_batches (int): The number of small batches split from an original batch. Default: 2.
norm_clip (float): Used to clip the bound; if set to 1, the original data will be returned. Default: 1.0.
dp_mech (Mechanisms): The object can generate the different type of noise. Default: None.
@@ -106,14 +109,17 @@ class DPModel(Model):
>>> dataset = get_dataset()
>>> model.train(2, dataset)
"""
def __init__(self, micro_batches=None, norm_clip=1.0, dp_mech=None, **kwargs):
def __init__(self, micro_batches=2, norm_clip=1.0, dp_mech=None, **kwargs):
if micro_batches:
self._micro_batches = int(micro_batches)
self._micro_batches = check_int_positive('micro_batches', micro_batches)
else:
self._micro_batches = None
float_norm_clip = check_param_type('l2_norm_clip', norm_clip, float)
self._norm_clip = check_value_positive('l2_norm_clip', float_norm_clip)
self._dp_mech = dp_mech
if isinstance(dp_mech, mechanisms.Mechanisms):
self._dp_mech = dp_mech
else:
raise TypeError('dp mechanisms should be an instance of class Mechanisms, but got {}'.format(type(dp_mech)))
super(DPModel, self).__init__(**kwargs)
def _amp_build_train_network(self, network, optimizer, loss_fn=None, level='O0', **kwargs):
......
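Review note: dp_mech now has to be a Mechanisms instance, otherwise construction fails with TypeError. A hedged end-to-end sketch; the MechanismsFactory.create arguments and the exact keyword set forwarded to Model are assumptions based on the docstring excerpt above:

```python
# Sketch of constructing DPModel with a valid noise mechanism (assumed usage).
from mindspore import nn
from mindarmour.diff_privacy import DPModel, DPOptimizerClassFactory, MechanismsFactory

net = nn.Dense(32, 10)                      # placeholder network
loss = nn.SoftmaxCrossEntropyWithLogits()   # placeholder loss function

factory = DPOptimizerClassFactory(micro_batches=2)
factory.set_mechanisms('Gaussian', norm_bound=1.0, initial_noise_multiplier=1.5)
opt = factory.create('Momentum')(params=net.trainable_params(),
                                 learning_rate=0.01, momentum=0.9)

mech = MechanismsFactory().create('Gaussian', norm_bound=1.0, initial_noise_multiplier=1.5)
model = DPModel(micro_batches=2, norm_clip=1.0, dp_mech=mech,
                network=net, loss_fn=loss, optimizer=opt)

# Anything that is not a Mechanisms instance is rejected up front, e.g.:
# DPModel(micro_batches=2, norm_clip=1.0, dp_mech='Gaussian', ...)  # TypeError
```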