提交 a98b8471 编写于 作者: Z zhenghuanhuan

[MA][diff_privacy][Func] micro_batches and dp_mech not checked

https://gitee.com/mindspore/dashboard/issues?id=I1IS9G
上级 f3baf9db
...@@ -27,7 +27,7 @@ class DPOptimizerClassFactory: ...@@ -27,7 +27,7 @@ class DPOptimizerClassFactory:
Factory class of Optimizer. Factory class of Optimizer.
Args: Args:
micro_batches (int): The number of small batches split from an origianl batch. Default: None. micro_batches (int): The number of small batches split from an origianl batch. Default: 2.
Returns: Returns:
Optimizer, Optimizer class Optimizer, Optimizer class
...@@ -39,7 +39,7 @@ class DPOptimizerClassFactory: ...@@ -39,7 +39,7 @@ class DPOptimizerClassFactory:
>>> learning_rate=cfg.lr, >>> learning_rate=cfg.lr,
>>> momentum=cfg.momentum) >>> momentum=cfg.momentum)
""" """
def __init__(self, micro_batches=None): def __init__(self, micro_batches=2):
self._mech_factory = MechanismsFactory() self._mech_factory = MechanismsFactory()
self.mech = None self.mech = None
self._micro_batches = check_int_positive('micro_batches', micro_batches) self._micro_batches = check_int_positive('micro_batches', micro_batches)
...@@ -72,17 +72,7 @@ class DPOptimizerClassFactory: ...@@ -72,17 +72,7 @@ class DPOptimizerClassFactory:
if policy == 'Adam': if policy == 'Adam':
cls = self._get_dp_optimizer_class(nn.Adam, self.mech, self._micro_batches, *args, **kwargs) cls = self._get_dp_optimizer_class(nn.Adam, self.mech, self._micro_batches, *args, **kwargs)
return cls return cls
if policy == 'AdamWeightDecay': raise NameError("The {} is not implement, please choose ['SGD', 'Momentum', 'Adam']".format(policy))
cls = self._get_dp_optimizer_class(nn.AdamWeightDecay, self.mech, self._micro_batches, *args, **kwargs)
return cls
if policy == 'AdamWeightDecayDynamicLR':
cls = self._get_dp_optimizer_class(nn.AdamWeightDecayDynamicLR,
self.mech,
self._micro_batches,
*args, **kwargs)
return cls
raise NameError("The {} is not implement, please choose ['SGD', 'Momentum', 'AdamWeightDecay', "
"'Adam', 'AdamWeightDecayDynamicLR']".format(policy))
def _get_dp_optimizer_class(self, cls, mech, micro_batches): def _get_dp_optimizer_class(self, cls, mech, micro_batches):
""" """
......
...@@ -48,8 +48,11 @@ from mindspore.nn.wrap.loss_scale import _grad_overflow ...@@ -48,8 +48,11 @@ from mindspore.nn.wrap.loss_scale import _grad_overflow
from mindspore.nn import Cell from mindspore.nn import Cell
from mindspore import ParameterTuple from mindspore import ParameterTuple
from mindarmour.diff_privacy.mechanisms import mechanisms
from mindarmour.utils._check_param import check_param_type from mindarmour.utils._check_param import check_param_type
from mindarmour.utils._check_param import check_value_positive from mindarmour.utils._check_param import check_value_positive
from mindarmour.utils._check_param import check_int_positive
GRADIENT_CLIP_TYPE = 1 GRADIENT_CLIP_TYPE = 1
grad_scale = C.MultitypeFuncGraph("grad_scale") grad_scale = C.MultitypeFuncGraph("grad_scale")
...@@ -67,7 +70,7 @@ class DPModel(Model): ...@@ -67,7 +70,7 @@ class DPModel(Model):
This class is overload mindspore.train.model.Model. This class is overload mindspore.train.model.Model.
Args: Args:
micro_batches (int): The number of small batches split from an origianl batch. Default: None. micro_batches (int): The number of small batches split from an origianl batch. Default: 2.
norm_clip (float): Use to clip the bound, if set 1, will retun the original data. Default: 1.0. norm_clip (float): Use to clip the bound, if set 1, will retun the original data. Default: 1.0.
dp_mech (Mechanisms): The object can generate the different type of noise. Default: None. dp_mech (Mechanisms): The object can generate the different type of noise. Default: None.
...@@ -106,14 +109,17 @@ class DPModel(Model): ...@@ -106,14 +109,17 @@ class DPModel(Model):
>>> dataset = get_dataset() >>> dataset = get_dataset()
>>> model.train(2, dataset) >>> model.train(2, dataset)
""" """
def __init__(self, micro_batches=None, norm_clip=1.0, dp_mech=None, **kwargs): def __init__(self, micro_batches=2, norm_clip=1.0, dp_mech=None, **kwargs):
if micro_batches: if micro_batches:
self._micro_batches = int(micro_batches) self._micro_batches = check_int_positive('micro_batches', micro_batches)
else: else:
self._micro_batches = None self._micro_batches = None
float_norm_clip = check_param_type('l2_norm_clip', norm_clip, float) float_norm_clip = check_param_type('l2_norm_clip', norm_clip, float)
self._norm_clip = check_value_positive('l2_norm_clip', float_norm_clip) self._norm_clip = check_value_positive('l2_norm_clip', float_norm_clip)
if isinstance(dp_mech, mechanisms.Mechanisms):
self._dp_mech = dp_mech self._dp_mech = dp_mech
else:
raise TypeError('dp mechanisms should be instance of class Mechansms, but got {}'.format(type(dp_mech)))
super(DPModel, self).__init__(**kwargs) super(DPModel, self).__init__(**kwargs)
def _amp_build_train_network(self, network, optimizer, loss_fn=None, level='O0', **kwargs): def _amp_build_train_network(self, network, optimizer, loss_fn=None, level='O0', **kwargs):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册