diff --git a/mindarmour/diff_privacy/optimizer/optimizer.py b/mindarmour/diff_privacy/optimizer/optimizer.py
index a844e79d7e51d689f51eb2bdf72ff7d4050566a0..0162a203fb9614a1d285bb2a2beff23c2275ce6b 100644
--- a/mindarmour/diff_privacy/optimizer/optimizer.py
+++ b/mindarmour/diff_privacy/optimizer/optimizer.py
@@ -27,7 +27,7 @@ class DPOptimizerClassFactory:
     Factory class of Optimizer.
 
     Args:
-        micro_batches (int): The number of small batches split from an origianl batch. Default: None.
+        micro_batches (int): The number of small batches split from an original batch. Default: 2.
 
     Returns:
         Optimizer, Optimizer class
@@ -39,7 +39,7 @@ class DPOptimizerClassFactory:
         >>>                                     learning_rate=cfg.lr,
         >>>                                     momentum=cfg.momentum)
     """
-    def __init__(self, micro_batches=None):
+    def __init__(self, micro_batches=2):
         self._mech_factory = MechanismsFactory()
         self.mech = None
         self._micro_batches = check_int_positive('micro_batches', micro_batches)
@@ -72,17 +72,7 @@ class DPOptimizerClassFactory:
         if policy == 'Adam':
             cls = self._get_dp_optimizer_class(nn.Adam, self.mech, self._micro_batches, *args, **kwargs)
             return cls
-        if policy == 'AdamWeightDecay':
-            cls = self._get_dp_optimizer_class(nn.AdamWeightDecay, self.mech, self._micro_batches, *args, **kwargs)
-            return cls
-        if policy == 'AdamWeightDecayDynamicLR':
-            cls = self._get_dp_optimizer_class(nn.AdamWeightDecayDynamicLR,
-                                               self.mech,
-                                               self._micro_batches,
-                                               *args, **kwargs)
-            return cls
-        raise NameError("The {} is not implement, please choose ['SGD', 'Momentum', 'AdamWeightDecay', "
-                        "'Adam', 'AdamWeightDecayDynamicLR']".format(policy))
+        raise NameError("The {} is not implemented, please choose ['SGD', 'Momentum', 'Adam']".format(policy))
 
     def _get_dp_optimizer_class(self, cls, mech, micro_batches):
         """
diff --git a/mindarmour/diff_privacy/train/model.py b/mindarmour/diff_privacy/train/model.py
index 434bbe0152ccbb866e6c251c617cc42224b99e84..b4233273d07022dd29d3a91bb7edfdc749445963 100644
--- a/mindarmour/diff_privacy/train/model.py
+++ b/mindarmour/diff_privacy/train/model.py
@@ -48,8 +48,11 @@ from mindspore.nn.wrap.loss_scale import _grad_overflow
 from mindspore.nn import Cell
 from mindspore import ParameterTuple
 
+from mindarmour.diff_privacy.mechanisms import mechanisms
 from mindarmour.utils._check_param import check_param_type
 from mindarmour.utils._check_param import check_value_positive
+from mindarmour.utils._check_param import check_int_positive
+
 
 GRADIENT_CLIP_TYPE = 1
 grad_scale = C.MultitypeFuncGraph("grad_scale")
@@ -67,7 +70,7 @@ class DPModel(Model):
     This class is overload mindspore.train.model.Model.
 
     Args:
-        micro_batches (int): The number of small batches split from an origianl batch. Default: None.
+        micro_batches (int): The number of small batches split from an original batch. Default: 2.
         norm_clip (float): Use to clip the bound, if set 1, will retun the original data. Default: 1.0.
         dp_mech (Mechanisms): The object can generate the different type of noise. Default: None.
 
@@ -106,14 +109,17 @@ class DPModel(Model):
         >>> dataset = get_dataset()
         >>> model.train(2, dataset)
     """
-    def __init__(self, micro_batches=None, norm_clip=1.0, dp_mech=None, **kwargs):
+    def __init__(self, micro_batches=2, norm_clip=1.0, dp_mech=None, **kwargs):
         if micro_batches:
-            self._micro_batches = int(micro_batches)
+            self._micro_batches = check_int_positive('micro_batches', micro_batches)
         else:
             self._micro_batches = None
         float_norm_clip = check_param_type('l2_norm_clip', norm_clip, float)
         self._norm_clip = check_value_positive('l2_norm_clip', float_norm_clip)
-        self._dp_mech = dp_mech
+        if isinstance(dp_mech, mechanisms.Mechanisms):
+            self._dp_mech = dp_mech
+        else:
+            raise TypeError('dp mechanisms should be an instance of class Mechanisms, but got {}'.format(type(dp_mech)))
         super(DPModel, self).__init__(**kwargs)
 
     def _amp_build_train_network(self, network, optimizer, loss_fn=None, level='O0', **kwargs):
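
Reviewer note: the sketch below is a minimal illustration of the behavior this patch introduces, exercising only the paths visible in the diff (the stricter micro_batches check, the reduced optimizer policy list, and the new dp_mech type check). The direct module import paths are taken from the file names above; whether DPOptimizerClassFactory.create() may be called before a mechanism is configured is an assumption, not something this diff shows.

# Sketch only: exercises the validation paths added by this patch.
from mindarmour.diff_privacy.optimizer.optimizer import DPOptimizerClassFactory
from mindarmour.diff_privacy.train.model import DPModel

# micro_batches now defaults to 2 and is validated with check_int_positive.
factory = DPOptimizerClassFactory()

# 'AdamWeightDecay' and 'AdamWeightDecayDynamicLR' are no longer accepted,
# so an unknown policy falls through to the NameError branch.
try:
    factory.create('AdamWeightDecay')   # assumes create() needs no prior set-up
except NameError as err:
    print(err)

# dp_mech must now be a mechanisms.Mechanisms instance; the TypeError is raised
# before Model.__init__ runs, so no network/optimizer kwargs are needed here.
try:
    DPModel(micro_batches=2, norm_clip=1.0, dp_mech=None)
except TypeError as err:
    print(err)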