diff --git a/example/mnist_demo/lenet5_config.py b/example/mnist_demo/lenet5_config.py index e9ac0480a560d31d14ff6821fab95d28733ec767..8206b9795c14c5725a0bd818ec3c4692c29fec5c 100644 --- a/example/mnist_demo/lenet5_config.py +++ b/example/mnist_demo/lenet5_config.py @@ -20,7 +20,7 @@ from easydict import EasyDict as edict mnist_cfg = edict({ 'num_classes': 10, # the number of classes of model's output - 'lr': 0.1, # the learning rate of model's optimizer + 'lr': 0.01, # the learning rate of model's optimizer 'momentum': 0.9, # the momentum value of model's optimizer 'epoch_size': 10, # training epochs 'batch_size': 256, # batch size for training @@ -33,8 +33,13 @@ mnist_cfg = edict({ 'dataset_sink_mode': False, # whether deliver all training data to device one time 'micro_batches': 16, # the number of small batches split from an original batch 'norm_clip': 1.0, # the clip bound of the gradients of model's training parameters - 'initial_noise_multiplier': 1.5, # the initial multiplication coefficient of the noise added to training + 'initial_noise_multiplier': 0.5, # the initial multiplication coefficient of the noise added to training # parameters' gradients - 'mechanisms': 'AdaGaussian', # the method of adding noise in gradients while training + 'noise_mechanisms': 'AdaGaussian', # the method of adding noise in gradients while training + 'clip_mechanisms': 'Gaussian', # the method of adaptive clipping gradients while training + 'clip_decay_policy': 'Linear', # Decay policy of adaptive clipping, decay_policy must be in ['Linear', 'Geometric']. + 'clip_learning_rate': 0.001, # Learning rate of update norm clip. + 'target_unclipped_quantile': 0.9, # Target quantile of norm clip. + 'fraction_stddev': 0.01, # The stddev of Gaussian normal which used in empirical_fraction. 
'optimizer': 'Momentum' # the base optimizer used for Differential privacy training }) diff --git a/example/mnist_demo/lenet5_dp.py b/example/mnist_demo/lenet5_dp.py index bf62fd3155c71a06a14a7d109bfbd7c4d09b0078..14eb299bd77d30026f795cb25e902f2ccb9582b5 100644 --- a/example/mnist_demo/lenet5_dp.py +++ b/example/mnist_demo/lenet5_dp.py @@ -31,7 +31,8 @@ import mindspore.common.dtype as mstype from mindarmour.diff_privacy import DPModel from mindarmour.diff_privacy import PrivacyMonitorFactory -from mindarmour.diff_privacy import MechanismsFactory +from mindarmour.diff_privacy import NoiseMechanismsFactory +from mindarmour.diff_privacy import ClipMechanismsFactory from mindarmour.utils.logger import LogUtil from lenet5_net import LeNet5 from lenet5_config import mnist_cfg as cfg @@ -87,11 +88,14 @@ def generate_mnist_dataset(data_path, batch_size=32, repeat_size=1, if __name__ == "__main__": # This configure can run both in pynative mode and graph mode - context.set_context(mode=context.GRAPH_MODE, device_target=cfg.device_target) + context.set_context(mode=context.GRAPH_MODE, + device_target=cfg.device_target) network = LeNet5() - net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean") - config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps, - keep_checkpoint_max=cfg.keep_checkpoint_max) + net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, + reduction="mean") + config_ck = CheckpointConfig( + save_checkpoint_steps=cfg.save_checkpoint_steps, + keep_checkpoint_max=cfg.keep_checkpoint_max) ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", directory='./trained_ckpt_file/', config=config_ck) @@ -102,17 +106,33 @@ if __name__ == "__main__": cfg.epoch_size) if cfg.micro_batches and cfg.batch_size % cfg.micro_batches != 0: - raise ValueError("Number of micro_batches should divide evenly batch_size") - # Create a factory class of DP mechanisms, this method is adding noise in gradients while training. - # Initial_noise_multiplier is suggested to be greater than 1.0, otherwise the privacy budget would be huge, which - # means that the privacy protection effect is weak. Mechanisms can be 'Gaussian' or 'AdaGaussian', in which noise - # would be decayed with 'AdaGaussian' mechanism while be constant with 'Gaussian' mechanism. - mech = MechanismsFactory().create(cfg.mechanisms, - norm_bound=cfg.norm_clip, - initial_noise_multiplier=cfg.initial_noise_multiplier) - net_opt = nn.Momentum(params=network.trainable_params(), learning_rate=cfg.lr, momentum=cfg.momentum) - # Create a monitor for DP training. The function of the monitor is to compute and print the privacy budget(eps - # and delta) while training. + raise ValueError( + "Number of micro_batches should divide evenly batch_size") + # Create a factory class of DP noise mechanisms, this method is adding noise + # in gradients while training. Initial_noise_multiplier is suggested to be + # greater than 1.0, otherwise the privacy budget would be huge, which means + # that the privacy protection effect is weak. Mechanisms can be 'Gaussian' + # or 'AdaGaussian', in which noise would be decayed with 'AdaGaussian' + # mechanism while be constant with 'Gaussian' mechanism. 
+ noise_mech = NoiseMechanismsFactory().create(cfg.noise_mechanisms, + norm_bound=cfg.norm_clip, + initial_noise_multiplier=cfg.initial_noise_multiplier) + # Create a factory class of clip mechanisms, this method is to adaptive clip + # gradients while training, decay_policy support 'Linear' and 'Geometric', + # learning_rate is the learning rate to update clip_norm, + # target_unclipped_quantile is the target quantile of norm clip, + # fraction_stddev is the stddev of Gaussian normal which used in + # empirical_fraction, the formula is + # $empirical_fraction + N(0, fraction_stddev)$. + clip_mech = ClipMechanismsFactory().create(cfg.clip_mechanisms, + decay_policy=cfg.clip_decay_policy, + learning_rate=cfg.clip_learning_rate, + target_unclipped_quantile=cfg.target_unclipped_quantile, + fraction_stddev=cfg.fraction_stddev) + net_opt = nn.Momentum(params=network.trainable_params(), + learning_rate=cfg.lr, momentum=cfg.momentum) + # Create a monitor for DP training. The function of the monitor is to + # compute and print the privacy budget(eps and delta) while training. rdp_monitor = PrivacyMonitorFactory.create('rdp', num_samples=60000, batch_size=cfg.batch_size, @@ -121,20 +141,23 @@ if __name__ == "__main__": # Create the DP model for training. model = DPModel(micro_batches=cfg.micro_batches, norm_clip=cfg.norm_clip, - mech=mech, + noise_mech=noise_mech, + clip_mech=clip_mech, network=network, loss_fn=net_loss, optimizer=net_opt, metrics={"Accuracy": Accuracy()}) LOGGER.info(TAG, "============== Starting Training ==============") - model.train(cfg['epoch_size'], ds_train, callbacks=[ckpoint_cb, LossMonitor(), rdp_monitor], + model.train(cfg['epoch_size'], ds_train, + callbacks=[ckpoint_cb, LossMonitor(), rdp_monitor], dataset_sink_mode=cfg.dataset_sink_mode) LOGGER.info(TAG, "============== Starting Testing ==============") ckpt_file_name = 'trained_ckpt_file/checkpoint_lenet-10_234.ckpt' param_dict = load_checkpoint(ckpt_file_name) load_param_into_net(network, param_dict) - ds_eval = generate_mnist_dataset(os.path.join(cfg.data_path, 'test'), batch_size=cfg.batch_size) + ds_eval = generate_mnist_dataset(os.path.join(cfg.data_path, 'test'), + batch_size=cfg.batch_size) acc = model.eval(ds_eval, dataset_sink_mode=False) LOGGER.info(TAG, "============== Accuracy: %s ==============", acc) diff --git a/mindarmour/diff_privacy/__init__.py b/mindarmour/diff_privacy/__init__.py index bf8a1494c4eb2f1e766294b85a97da8a6e97b90c..22057d7f3b876d43145449940712012f9f6b9f26 100644 --- a/mindarmour/diff_privacy/__init__.py +++ b/mindarmour/diff_privacy/__init__.py @@ -1,16 +1,20 @@ """ This module provide Differential Privacy feature to protect user privacy. 
""" -from .mechanisms.mechanisms import GaussianRandom +from .mechanisms.mechanisms import NoiseGaussianRandom from .mechanisms.mechanisms import AdaGaussianRandom -from .mechanisms.mechanisms import MechanismsFactory +from .mechanisms.mechanisms import AdaClippingWithGaussianRandom +from .mechanisms.mechanisms import NoiseMechanismsFactory +from .mechanisms.mechanisms import ClipMechanismsFactory from .monitor.monitor import PrivacyMonitorFactory from .optimizer.optimizer import DPOptimizerClassFactory from .train.model import DPModel -__all__ = ['GaussianRandom', +__all__ = ['NoiseGaussianRandom', 'AdaGaussianRandom', - 'MechanismsFactory', + 'AdaClippingWithGaussianRandom', + 'NoiseMechanismsFactory', + 'ClipMechanismsFactory', 'PrivacyMonitorFactory', 'DPOptimizerClassFactory', 'DPModel'] diff --git a/mindarmour/diff_privacy/mechanisms/mechanisms.py b/mindarmour/diff_privacy/mechanisms/mechanisms.py index 908f10be8523d84fcddbf293befbc1fe0d7a14ba..aed6043e25c0cb0c72fd155bd9cc8b3df4c0b7b2 100644 --- a/mindarmour/diff_privacy/mechanisms/mechanisms.py +++ b/mindarmour/diff_privacy/mechanisms/mechanisms.py @@ -28,11 +28,54 @@ from mindarmour.utils._check_param import check_param_in_range from mindarmour.utils.logger import LogUtil LOGGER = LogUtil.get_instance() -TAG = 'Defense' +TAG = 'NoiseMechanism' -class MechanismsFactory: - """ Factory class of mechanisms""" +class ClipMechanismsFactory: + """ Factory class of clip mechanisms""" + + def __init__(self): + pass + + @staticmethod + def create(mech_name, *args, **kwargs): + """ + Args: + mech_name(str): Clip noise generated strategy, support 'Gaussian' now. + args(Union[float, str]): Parameters used for creating clip mechanisms. + kwargs(Union[float, str]): Parameters used for creating clip + mechanisms. + + Raises: + NameError: `mech_name` must be in ['Gaussian']. + + Returns: + Mechanisms, class of noise generated Mechanism. + + Examples: + >>> decay_policy = 'Linear' + >>> beta = Tensor(0.5, mstype.float32) + >>> norm_clip = Tensor(1.0, mstype.float32) + >>> beta_stddev = 0.1 + >>> learning_rate = 0.1 + >>> target_unclipped_quantile = 0.3 + >>> clip_mechanism = ClipMechanismsFactory() + >>> ada_clip = clip_mechanism.create('Gaussian', + >>> decay_policy=decay_policy, + >>> learning_rate=learning_rate, + >>> target_unclipped_quantile=target_unclipped_quantile, + >>> fraction_stddev=beta_stddev) + >>> next_norm_clip = ada_clip(beta, norm_clip) + + """ + if mech_name == 'Gaussian': + return AdaClippingWithGaussianRandom(*args, **kwargs) + raise NameError("The {} is not implement, please choose " + "['Gaussian']".format(mech_name)) + + +class NoiseMechanismsFactory: + """ Factory class of noise mechanisms""" def __init__(self): pass @@ -56,42 +99,38 @@ class MechanismsFactory: Mechanisms, class of noise generated Mechanism. 
Examples: - >>> class Net(nn.Cell): - >>> def __init__(self): - >>> super(Net, self).__init__() - >>> self.conv = nn.Conv2d(3, 64, 3, has_bias=False, weight_init='normal') - >>> self.bn = nn.BatchNorm2d(64) - >>> self.relu = nn.ReLU() - >>> self.flatten = nn.Flatten() - >>> self.fc = nn.Dense(64*224*224, 12) # padding=0 - >>> - >>> def construct(self, x): - >>> x = self.conv(x) - >>> x = self.bn(x) - >>> x = self.relu(x) - >>> x = self.flatten(x) - >>> out = self.fc(x) - >>> return out >>> norm_clip = 1.0 - >>> initial_noise_multiplier = 1.5 - >>> net = Net() + >>> initial_noise_multiplier = 0.01 + >>> network = LeNet5() + >>> batch_size = 32 + >>> batches = 128 + >>> epochs = 1 >>> loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True) - >>> net_opt = Momentum(params=net.trainable_params(), learning_rate=0.01, momentum=0.9) - >>> mech = MechanismsFactory().create('Gaussian', - >>> norm_bound=norm_clip, - >>> initial_noise_multiplier=initial_noise_multiplier) + >>> noise_mech = NoiseMechanismsFactory().create('Gaussian', + >>> norm_bound=norm_clip, + >>> initial_noise_multiplier=initial_noise_multiplier) + >>> clip_mech = ClipMechanismsFactory().create('Gaussian', + >>> decay_policy='Linear', + >>> learning_rate=0.01, + >>> target_unclipped_quantile=0.9, + >>> fraction_stddev=0.01) + >>> net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.1, + >>> momentum=0.9) >>> model = DPModel(micro_batches=2, - >>> norm_clip=1.0, - >>> mech=mech, - >>> network=net, + >>> clip_mech=clip_mech, + >>> norm_clip=norm_clip, + >>> noise_mech=noise_mech, + >>> network=network, >>> loss_fn=loss, >>> optimizer=net_opt, >>> metrics=None) - >>> dataset = get_dataset() - >>> model.train(2, dataset) + >>> ms_ds = ds.GeneratorDataset(dataset_generator(batch_size, batches), + >>> ['data', 'label']) + >>> ms_ds.set_dataset_size(batch_size * batches) + >>> model.train(epochs, ms_ds, dataset_sink_mode=False) """ if policy == 'Gaussian': - return GaussianRandom(*args, **kwargs) + return NoiseGaussianRandom(*args, **kwargs) if policy == 'AdaGaussian': return AdaGaussianRandom(*args, **kwargs) raise NameError("The {} is not implement, please choose " @@ -110,7 +149,7 @@ class Mechanisms(Cell): """ -class GaussianRandom(Mechanisms): +class NoiseGaussianRandom(Mechanisms): """ Gaussian noise generated mechanism. 
@@ -133,18 +172,21 @@ class GaussianRandom(Mechanisms): >>> gradients = Tensor([0.2, 0.9], mstype.float32) >>> norm_bound = 0.5 >>> initial_noise_multiplier = 1.5 - >>> net = GaussianRandom(norm_bound, initial_noise_multiplier) + >>> net = NoiseGaussianRandom(norm_bound, initial_noise_multiplier) >>> res = net(gradients) >>> print(res) """ - def __init__(self, norm_bound=0.5, initial_noise_multiplier=1.5, seed=0, policy=None): - super(GaussianRandom, self).__init__() + def __init__(self, norm_bound=0.5, initial_noise_multiplier=1.5, seed=0, + policy=None): + super(NoiseGaussianRandom, self).__init__() self._norm_bound = check_value_positive('norm_bound', norm_bound) self._norm_bound = Tensor(norm_bound, mstype.float32) - self._initial_noise_multiplier = check_value_positive('initial_noise_multiplier', - initial_noise_multiplier) - self._initial_noise_multiplier = Tensor(initial_noise_multiplier, mstype.float32) + self._initial_noise_multiplier = check_value_positive( + 'initial_noise_multiplier', + initial_noise_multiplier) + self._initial_noise_multiplier = Tensor(initial_noise_multiplier, + mstype.float32) self._mean = Tensor(0, mstype.float32) self._normal = P.Normal(seed=seed) self._decay_policy = policy @@ -201,17 +243,20 @@ class AdaGaussianRandom(Mechanisms): noise_decay_rate=6e-4, decay_policy='Time', seed=0): super(AdaGaussianRandom, self).__init__() norm_bound = check_value_positive('norm_bound', norm_bound) - initial_noise_multiplier = check_value_positive('initial_noise_multiplier', - initial_noise_multiplier) + initial_noise_multiplier = check_value_positive( + 'initial_noise_multiplier', + initial_noise_multiplier) self._norm_bound = Tensor(norm_bound, mstype.float32) - initial_noise_multiplier = Tensor(initial_noise_multiplier, mstype.float32) + initial_noise_multiplier = Tensor(initial_noise_multiplier, + mstype.float32) self._initial_noise_multiplier = Parameter(initial_noise_multiplier, name='initial_noise_multiplier') self._noise_multiplier = Parameter(initial_noise_multiplier, name='noise_multiplier') self._mean = Tensor(0, mstype.float32) - noise_decay_rate = check_param_type('noise_decay_rate', noise_decay_rate, float) + noise_decay_rate = check_param_type('noise_decay_rate', + noise_decay_rate, float) check_param_in_range('noise_decay_rate', noise_decay_rate, 0.0, 1.0) self._noise_decay_rate = Tensor(noise_decay_rate, mstype.float32) if decay_policy not in ['Time', 'Step', 'Exp']: @@ -232,7 +277,9 @@ class AdaGaussianRandom(Mechanisms): Tensor, generated noise with shape like given gradients. """ shape = P.Shape()(gradients) - noise = self._normal(shape, self._mean, self._mul(self._noise_multiplier, self._norm_bound)) + noise = self._normal(shape, self._mean, + self._mul(self._noise_multiplier, + self._norm_bound)) return noise @@ -241,10 +288,14 @@ class _MechanismsParamsUpdater(Cell): Update mechanisms parameters, the parameters will refresh in train period. Args: - policy(str): Pass in by the mechanisms class, mechanisms parameters update policy. - decay_rate(Tensor): Pass in by the mechanisms class, hyper parameter for controlling the decay size. - cur_noise_multiplier(Parameter): Pass in by the mechanisms class, current params value in this time. - init_noise_multiplier(Parameter):Pass in by the mechanisms class, initial params value to be updated. + policy(str): Pass in by the mechanisms class, mechanisms parameters + update policy. + decay_rate(Tensor): Pass in by the mechanisms class, hyper parameter for + controlling the decay size. 
+ cur_noise_multiplier(Parameter): Pass in by the mechanisms class, + current params value in this time. + init_noise_multiplier(Parameter): Pass in by the mechanisms class, + initial params value to be updated. Returns: Tuple, next params value. @@ -281,5 +332,100 @@ class _MechanismsParamsUpdater(Cell): next_noise_multiplier = self._assign(self._cur_noise_multiplier, self._mul(temp, self._cur_noise_multiplier)) else: - next_noise_multiplier = self._assign(self._cur_noise_multiplier, self._div(self._one, self._exp(self._one))) + next_noise_multiplier = self._assign(self._cur_noise_multiplier, + self._div(self._one, self._exp(self._one))) return next_noise_multiplier + + +class AdaClippingWithGaussianRandom(Cell): + """ + Adaptive clipping. If `decay_policy` is 'Linear', the update formula is + $ norm_clip = norm_clip - learning_rate*(beta-target_unclipped_quantile)$. + If `decay_policy` is 'Geometric', the update formula is + $ norm_clip = norm_clip*exp(-learning_rate*(empirical_fraction-target_unclipped_quantile))$, + where beta (empirical_fraction) is the fraction of samples whose gradient + norm is at most `norm_clip`, i.e. the fraction of unclipped samples. + + Args: + decay_policy(str): Decay policy of adaptive clipping, decay_policy must + be in ['Linear', 'Geometric']. Default: 'Linear'. + learning_rate(float): Learning rate for updating norm clip. Default: 0.001. + target_unclipped_quantile(float): Target quantile of norm clip. Default: 0.9. + fraction_stddev(float): The stddev of the Gaussian noise added to + empirical_fraction, i.e. $empirical_fraction + N(0, fraction_stddev)$. + seed(int): Original random seed. If seed=0, random normal will use a secure + random number; if seed!=0, random normal will generate values using + the given seed. Default: 0. + + Returns: + Tensor, updated norm clip. + + Examples: + >>> decay_policy = 'Linear' + >>> beta = Tensor(0.5, mstype.float32) + >>> norm_clip = Tensor(1.0, mstype.float32) + >>> beta_stddev = 0.01 + >>> learning_rate = 0.001 + >>> target_unclipped_quantile = 0.9 + >>> ada_clip = AdaClippingWithGaussianRandom(decay_policy=decay_policy, + >>> learning_rate=learning_rate, + >>> target_unclipped_quantile=target_unclipped_quantile, + >>> fraction_stddev=beta_stddev) + >>> next_norm_clip = ada_clip(beta, norm_clip) + + """ + + def __init__(self, decay_policy='Linear', learning_rate=0.001, + target_unclipped_quantile=0.9, fraction_stddev=0.01, seed=0): + super(AdaClippingWithGaussianRandom, self).__init__() + if decay_policy not in ['Linear', 'Geometric']: + msg = "decay policy of adaptive clip must be in ['Linear', 'Geometric'], \ + but got: {}".format(decay_policy) + LOGGER.error(TAG, msg) + raise ValueError(msg) + self._decay_policy = decay_policy + learning_rate = check_param_type('learning_rate', learning_rate, float) + learning_rate = check_value_positive('learning_rate', learning_rate) + self._learning_rate = Tensor(learning_rate, mstype.float32) + fraction_stddev = check_param_type('fraction_stddev', fraction_stddev, float) + self._fraction_stddev = Tensor(fraction_stddev, mstype.float32) + target_unclipped_quantile = check_param_type('target_unclipped_quantile', + target_unclipped_quantile, + float) + self._target_unclipped_quantile = Tensor(target_unclipped_quantile, + mstype.float32) + + self._zero = Tensor(0, mstype.float32) + self._add = P.TensorAdd() + self._sub = P.Sub() + self._mul = P.Mul() + self._exp = P.Exp() + self._normal = P.Normal(seed=seed) + + def construct(self, empirical_fraction, norm_clip): + """ + Update value of norm_clip.
+ + Args: + empirical_fraction(Tensor): empirical fraction of samples with the + value at most `target_unclipped_quantile`. + norm_clip(Tensor): Clipping bound for the l2 norm of the gradients. + + Returns: + Tensor, generated noise with shape like given gradients. + """ + fraction_noise = self._normal((1,), self._zero, self._fraction_stddev) + empirical_fraction = self._add(empirical_fraction, fraction_noise) + if self._decay_policy == 'Linear': + grad_clip = self._sub(empirical_fraction, + self._target_unclipped_quantile) + next_norm_clip = self._sub(norm_clip, + self._mul(self._learning_rate, grad_clip)) + + # decay_policy == 'Geometric' + else: + grad_clip = self._sub(empirical_fraction, + self._target_unclipped_quantile) + grad_clip = self._exp(self._mul(-self._learning_rate, grad_clip)) + next_norm_clip = self._mul(norm_clip, grad_clip) + return next_norm_clip diff --git a/mindarmour/diff_privacy/optimizer/optimizer.py b/mindarmour/diff_privacy/optimizer/optimizer.py index 1c28ce1f0d3d2ec5a4cc23589cea2e9efb0cd75b..efd4a6a46b620eb734525e703efab1398b65560f 100644 --- a/mindarmour/diff_privacy/optimizer/optimizer.py +++ b/mindarmour/diff_privacy/optimizer/optimizer.py @@ -22,7 +22,7 @@ from mindspore.ops import functional as F from mindspore.common import dtype as mstype from mindarmour.utils.logger import LogUtil -from mindarmour.diff_privacy import MechanismsFactory +from mindarmour.diff_privacy import NoiseMechanismsFactory from mindarmour.diff_privacy.mechanisms.mechanisms import _MechanismsParamsUpdater from mindarmour.utils._check_param import check_int_positive @@ -70,7 +70,7 @@ class DPOptimizerClassFactory: """ def __init__(self, micro_batches=2): - self._mech_factory = MechanismsFactory() + self._mech_factory = NoiseMechanismsFactory() self.mech = None self._micro_batches = check_int_positive('micro_batches', micro_batches) diff --git a/mindarmour/diff_privacy/train/model.py b/mindarmour/diff_privacy/train/model.py index 7991449d4de1c6e6342f1577db881e633aea556d..dcd02b17165bfa64db09ddbc3aba4c3fbd7ec837 100644 --- a/mindarmour/diff_privacy/train/model.py +++ b/mindarmour/diff_privacy/train/model.py @@ -48,7 +48,8 @@ from mindspore.nn import Cell from mindspore import ParameterTuple from mindarmour.utils.logger import LogUtil -from mindarmour.diff_privacy.mechanisms.mechanisms import _MechanismsParamsUpdater +from mindarmour.diff_privacy.mechanisms.mechanisms import \ + _MechanismsParamsUpdater from mindarmour.utils._check_param import check_param_type from mindarmour.utils._check_param import check_value_positive from mindarmour.utils._check_param import check_int_positive @@ -64,7 +65,7 @@ _reciprocal = P.Reciprocal() @_grad_scale.register("Tensor", "Tensor") def tensor_grad_scale(scale, grad): """ grad scaling """ - return grad * F.cast(_reciprocal(scale), F.dtype(grad)) + return grad*F.cast(_reciprocal(scale), F.dtype(grad)) class DPModel(Model): @@ -72,9 +73,14 @@ class DPModel(Model): This class is overload mindspore.train.model.Model. Args: - micro_batches (int): The number of small batches split from an original batch. Default: 2. - norm_clip (float): Use to clip the bound, if set 1, will retun the original data. Default: 1.0. - mech (Mechanisms): The object can generate the different type of noise. Default: None. + micro_batches (int): The number of small batches split from an original + batch. Default: 2. + norm_clip (float): Use to clip the bound, if set 1, will retun the + original data. Default: 1.0. 
+ noise_mech (Mechanisms): The object can generate the different type of + noise. Default: None. + clip_mech (Mechanisms): The object is used to update the adaptive clip . + Default: None. Examples: >>> norm_clip = 1.0 @@ -89,63 +95,82 @@ class DPModel(Model): >>> factory_opt.set_mechanisms('Gaussian', >>> norm_bound=norm_clip, >>> initial_noise_multiplier=initial_noise_multiplier) - >>> net_opt = factory_opt.create('Momentum')(network.trainable_params(), learning_rate=0.1, momentum=0.9) + >>> net_opt = factory_opt.create('Momentum')(network.trainable_params(), + >>> learning_rate=0.1, momentum=0.9) + >>> clip_mech = ClipMechanismsFactory().create('Gaussian', + >>> decay_policy='Linear', + >>> learning_rate=0.01, + >>> target_unclipped_quantile=0.9, + >>> fraction_stddev=0.01) >>> model = DPModel(micro_batches=micro_batches, >>> norm_clip=norm_clip, - >>> mech=None, + >>> clip_mech=clip_mech, + >>> noise_mech=None, >>> network=network, >>> loss_fn=loss, >>> optimizer=net_opt, >>> metrics=None) - >>> ms_ds = ds.GeneratorDataset(dataset_generator(batch_size, batches), ['data', 'label']) - >>> ms_ds.set_dataset_size(batch_size * batches) + >>> ms_ds = ds.GeneratorDataset(dataset_generator(batch_size, batches), + >>> ['data', 'label']) + >>> ms_ds.set_dataset_size(batch_size*batches) >>> model.train(epochs, ms_ds, dataset_sink_mode=False) """ - def __init__(self, micro_batches=2, norm_clip=1.0, mech=None, **kwargs): + def __init__(self, micro_batches=2, norm_clip=1.0, noise_mech=None, + clip_mech=None, **kwargs): if micro_batches: - self._micro_batches = check_int_positive('micro_batches', micro_batches) + self._micro_batches = check_int_positive('micro_batches', + micro_batches) else: self._micro_batches = None norm_clip = check_param_type('norm_clip', norm_clip, float) - self._norm_clip = check_value_positive('norm_clip', norm_clip) - if mech is not None and "DPOptimizer" in kwargs['optimizer'].__class__.__name__: - msg = 'DPOptimizer is not supported while mech is not None' + norm_clip = check_value_positive('norm_clip', norm_clip) + norm_clip = Tensor(norm_clip, mstype.float32) + self._norm_clip = Parameter(norm_clip, 'norm_clip') + if noise_mech is not None and "DPOptimizer" in kwargs['optimizer'].__class__.__name__: + msg = 'DPOptimizer is not supported while noise_mech is not None' LOGGER.error(TAG, msg) raise ValueError(msg) - if mech is None: + if noise_mech is None: if "DPOptimizer" in kwargs['optimizer'].__class__.__name__: if context.get_context('mode') != context.PYNATIVE_MODE: msg = 'DPOptimizer just support pynative mode currently.' LOGGER.error(TAG, msg) raise ValueError(msg) else: - msg = 'DPModel should set mech or DPOptimizer configure, please refer to example.' + msg = 'DPModel should set noise_mech or DPOptimizer configure, ' \ + 'please refer to example.' LOGGER.error(TAG, msg) raise ValueError(msg) - self._mech = mech + self._noise_mech = noise_mech + if clip_mech is not None: + self._clip_mech = clip_mech super(DPModel, self).__init__(**kwargs) - def _amp_build_train_network(self, network, optimizer, loss_fn=None, level='O0', **kwargs): + def _amp_build_train_network(self, network, optimizer, loss_fn=None, + level='O0', **kwargs): """ Build the mixed precision training cell automatically. Args: network (Cell): Definition of the network. - loss_fn (Union[None, Cell]): Definition of the loss_fn. If None, the `network` should have the loss inside. - Default: None. + loss_fn (Union[None, Cell]): Definition of the loss_fn. 
If None, + the `network` should have the loss inside. Default: None. optimizer (Optimizer): Optimizer to update the Parameter. level (str): Supports [O0, O2]. Default: "O0". - O0: Do not change. - - O2: Cast network to float16, keep batchnorm and `loss_fn` (if set) run in float32, - using dynamic loss scale. - - cast_model_type (:class:`mindspore.dtype`): Supports `mstype.float16` or `mstype.float32`. - If set to `mstype.float16`, use `float16` mode to train. If set, overwrite the level setting. - keep_batchnorm_fp32 (bool): Keep Batchnorm run in `float32`. If set, overwrite the level setting. - loss_scale_manager (Union[None, LossScaleManager]): If None, not scale the loss, or else - scale the loss by LossScaleManager. If set, overwrite the level setting. + - O2: Cast network to float16, keep batchnorm and `loss_fn` + (if set) run in float32, using dynamic loss scale. + + cast_model_type (:class:`mindspore.dtype`): Supports `mstype.float16` + or `mstype.float32`. If set to `mstype.float16`, use `float16` + mode to train. If set, overwrite the level setting. + keep_batchnorm_fp32 (bool): Keep Batchnorm run in `float32`. If set, + overwrite the level setting. + loss_scale_manager (Union[None, LossScaleManager]): If None, not + scale the loss, or else scale the loss by LossScaleManager. + If set, overwrite the level setting. """ validator.check_value_type('network', network, nn.Cell, None) validator.check_value_type('optimizer', optimizer, nn.Optimizer, None) @@ -161,9 +186,11 @@ class DPModel(Model): _do_keep_batchnorm_fp32(network) if loss_fn: - network = _add_loss_network(network, loss_fn, config.cast_model_type) + network = _add_loss_network(network, loss_fn, + config.cast_model_type) - if _get_parallel_mode() in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL): + if _get_parallel_mode() in ( + ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL): network = _VirtualDatasetCell(network) loss_scale = 1.0 @@ -173,9 +200,12 @@ class DPModel(Model): update_cell = loss_scale_manager.get_update_cell() if update_cell is not None: # only cpu not support `TrainOneStepWithLossScaleCell` for control flow. - if not context.get_context("enable_ge") and context.get_context("device_target") == "CPU": - msg = "Only `loss_scale_manager=None` and `loss_scale_manager=FixedLossScaleManager(drop_overflow" \ - "_update=False)` are supported in current version. If you use `O2` option, please use " \ + if not context.get_context("enable_ge") and context.get_context( + "device_target") == "CPU": + msg = "Only `loss_scale_manager=None` and " \ + "`loss_scale_manager=FixedLossScaleManager(drop_overflow" \ + "_update=False)` are supported in current version. 
" \ + "If you use `O2` option, please use " \ "`loss_scale_manager=None` or `FixedLossScaleManager`" LOGGER.error(TAG, msg) raise ValueError(msg) @@ -184,15 +214,17 @@ class DPModel(Model): scale_update_cell=update_cell, micro_batches=self._micro_batches, norm_clip=self._norm_clip, - mech=self._mech).set_train() + clip_mech=self._clip_mech, + noise_mech=self._noise_mech).set_train() return network network = _TrainOneStepCell(network, optimizer, + self._norm_clip, loss_scale, micro_batches=self._micro_batches, - norm_clip=self._norm_clip, - mech=self._mech).set_train() + clip_mech=self._clip_mech, + noise_mech=self._noise_mech).set_train() return network def _build_train_network(self): @@ -233,7 +265,8 @@ class DPModel(Model): elif self._loss_fn: network = nn.WithLossCell(network, self._loss_fn) - if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL): + if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, + ParallelMode.AUTO_PARALLEL): network.set_auto_parallel() return network @@ -267,11 +300,10 @@ class _ClipGradients(nn.Cell): new_grads = () for grad in grads: if clip_type == 0: - t = C.clip_by_value(grad, F.tuple_to_array((-clip_value,)), - F.tuple_to_array((clip_value,))) + norm = C.clip_by_value(grad, -clip_value, clip_value) else: - t = self.clip_by_norm(grad, F.tuple_to_array((clip_value,))) - new_grads = new_grads + (t,) + norm = self.clip_by_norm(grad, clip_value) + new_grads = new_grads + (norm,) return new_grads @@ -292,20 +324,27 @@ class _TrainOneStepWithLossScaleCell(Cell): r""" Network training with loss scaling. - This is a training step with loss scaling. It takes a network, an optimizer and possibly a scale update - Cell as args. The loss scale value can be updated in both host side or device side. The - TrainOneStepWithLossScaleCell will be compiled to be graph which takes `data`, `label`, `sens` as input - data. The `sens` is acting as loss scaling value. If you want to update it on host side, the value should - be provided. If `sens` is not given, the loss scale update logic should be provied by `scale_update_cell`. - If `scale_update_cell` is not None and `sens` is provided, the `scale_update_cell` will be ignored. + This is a training step with loss scaling. It takes a network, an optimizer + and possibly a scale update Cell as args. The loss scale value can be + updated in both host side or device side. The TrainOneStepWithLossScaleCell + will be compiled to be graph which takes `data`, `label`, `sens` as input + data. The `sens` is acting as loss scaling value. If you want to update it + on host side, the value should be provided. If `sens` is not given, the loss + scale update logic should be provied by `scale_update_cell`. If + `scale_update_cell` is not None and `sens` is provided, the + `scale_update_cell` will be ignored. Args: network (Cell): The training network. optimizer (Cell): Optimizer for updating the weights. - scale_update_cell(Cell): The loss scaling update logic cell. Default: None. - micro_batches (int): The number of small batches split from an original batch. Default: None. - norm_clip (float): Use to clip the bound, if set 1, will return the original data. Default: 1.0. - mech (Mechanisms): The object can generate the different type of noise. Default: None. + scale_update_cell(Cell): The loss scaling update logic cell. + Default: None. + micro_batches (int): The number of small batches split from an original + batch. Default: None. 
+ norm_clip (Tensor): Use to clip the bound, if set 1, will return the + original data. Default: 1.0. + noise_mech (Mechanisms): The object can generate the different type of + noise. Default: None. Inputs: - **inputs** (Tensor) - Tensor of shape :math:`(N, \ldots)`. @@ -320,7 +359,9 @@ class _TrainOneStepWithLossScaleCell(Cell): - **loss_scale** (Tensor) - Tensor with shape :math:`()`. """ - def __init__(self, network, optimizer, scale_update_cell=None, micro_batches=None, norm_clip=1.0, mech=None): + def __init__(self, network, optimizer, scale_update_cell=None, + micro_batches=None, norm_clip=1.0, noise_mech=None, + clip_mech=None): super(_TrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False) self.network = network self.network.set_grad() @@ -346,39 +387,54 @@ class _TrainOneStepWithLossScaleCell(Cell): self.allreduce = P.AllReduce() self.parallel_mode = _get_parallel_mode() self.grad_reducer = F.identity - self.reducer_flag = self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL] + self.reducer_flag = self.parallel_mode in [ParallelMode.DATA_PARALLEL, + ParallelMode.HYBRID_PARALLEL] if self.reducer_flag: mean = _get_mirror_mean() degree = _get_device_num() - self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree) + self.grad_reducer = DistributedGradReducer(optimizer.parameters, + mean, degree) self.is_distributed = self.parallel_mode != ParallelMode.STAND_ALONE self.loss_scale = None self.loss_scaling_manager = scale_update_cell if scale_update_cell: - self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32), - name="loss_scale") + self.loss_scale = Parameter( + Tensor(scale_update_cell.get_loss_scale(), + dtype=mstype.float32), + name="loss_scale") self.add_flags(has_effect=True) # dp params self._micro_batches = micro_batches - norm_clip = check_param_type('norm_clip', norm_clip, float) - self._l2_norm = check_value_positive('norm_clip', norm_clip) + self._norm_clip = norm_clip self._split = P.Split(0, self._micro_batches) self._clip_by_global_norm = _ClipGradients() - self._mech = mech + self._noise_mech = noise_mech + self._clip_mech = clip_mech + self._add = P.TensorAdd() + self._norm = nn.Norm() self._tuple_add = _TupleAdd() self._hyper_map = C.HyperMap() self._micro_float = Tensor(micro_batches, mstype.float32) - - self._mech_param_updater = None - if self._mech is not None and self._mech._decay_policy is not None: - self._mech_param_updater = _MechanismsParamsUpdater(policy=self._mech._decay_policy, - decay_rate=self._mech._noise_decay_rate, - cur_noise_multiplier= - self._mech._noise_multiplier, - init_noise_multiplier= - self._mech._initial_noise_multiplier) + self._zero = Tensor(0, mstype.float32) + self._assign = P.Assign() + self._div = P.Div() + self._sqrt = P.Sqrt() + self._reduce_sum = P.ReduceSum() + self._square_all = P.Square() + self._less = P.Less() + self._cast = P.Cast() + + self._noise_mech_param_updater = None + if self._noise_mech is not None and self._noise_mech._decay_policy is not None: + self._noise_mech_param_updater = _MechanismsParamsUpdater( + policy=self._noise_mech._decay_policy, + decay_rate=self._noise_mech._noise_decay_rate, + cur_noise_multiplier= + self._noise_mech._noise_multiplier, + init_noise_multiplier= + self._noise_mech._initial_noise_multiplier) def construct(self, data, label, sens=None): """ @@ -402,30 +458,62 @@ class _TrainOneStepWithLossScaleCell(Cell): record_labels = self._split(label) # first index loss = self.network(record_datas[0], 
record_labels[0]) - scaling_sens_filled = C.ones_like(loss) * F.cast(scaling_sens, F.dtype(loss)) - record_grad = self.grad(self.network, weights)(record_datas[0], record_labels[0], scaling_sens_filled) - record_grad = self._clip_by_global_norm(record_grad, GRADIENT_CLIP_TYPE, self._l2_norm) + scaling_sens_filled = C.ones_like(loss)*F.cast(scaling_sens, + F.dtype(loss)) + record_grad = self.grad(self.network, weights)(record_datas[0], + record_labels[0], + scaling_sens_filled) + + beta = self._zero + square_sum = self._zero + for grad in record_grad: + square_sum = self._add(square_sum, + self._reduce_sum(self._square_all(grad))) + norm_grad = self._sqrt(square_sum) + beta = self._add(beta, + self._cast(self._less(norm_grad, self._norm_clip), + mstype.float32)) + record_grad = self._clip_by_global_norm(record_grad, GRADIENT_CLIP_TYPE, + self._norm_clip) grads = record_grad total_loss = loss for i in range(1, self._micro_batches): loss = self.network(record_datas[i], record_labels[i]) - scaling_sens_filled = C.ones_like(loss) * F.cast(scaling_sens, F.dtype(loss)) - record_grad = self.grad(self.network, weights)(record_datas[i], record_labels[i], scaling_sens_filled) - record_grad = self._clip_by_global_norm(record_grad, GRADIENT_CLIP_TYPE, self._l2_norm) + scaling_sens_filled = C.ones_like(loss)*F.cast(scaling_sens, + F.dtype(loss)) + record_grad = self.grad(self.network, weights)(record_datas[i], + record_labels[i], + scaling_sens_filled) + + square_sum = self._zero + for grad in record_grad: + square_sum = self._add(square_sum, + self._reduce_sum(self._square_all(grad))) + norm_grad = self._sqrt(square_sum) + beta = self._add(beta, + self._cast(self._less(norm_grad, self._norm_clip), + mstype.float32)) + + record_grad = self._clip_by_global_norm(record_grad, + GRADIENT_CLIP_TYPE, + self._norm_clip) grads = self._tuple_add(grads, record_grad) total_loss = P.TensorAdd()(total_loss, loss) loss = P.Div()(total_loss, self._micro_float) + beta = self._div(beta, self._micro_batches) - if self._mech is not None: + if self._noise_mech is not None: grad_noise_tuple = () for grad_item in grads: grad_noise = self._mech(grad_item) grad_noise_tuple = grad_noise_tuple + (grad_noise,) grads = self._tuple_add(grads, grad_noise_tuple) - grads = self._hyper_map(F.partial(_grad_scale, self._micro_float), grads) + grads = self._hyper_map(F.partial(_grad_scale, self._micro_float), + grads) # update mech parameters - if self._mech_param_updater is not None: - multiplier = self._mech_param_updater() + + if self._noise_mech_param_updater is not None: + multiplier = self._noise_mech_param_updater() loss = F.depend(loss, multiplier) grads = self.hyper_map(F.partial(_grad_scale, scaling_sens), grads) @@ -456,6 +544,10 @@ class _TrainOneStepWithLossScaleCell(Cell): else: opt = self.optimizer(grads) ret = (loss, cond, scaling_sens) + + if self._clip_mech is not None: + next_norm_clip = self._clip_mech(beta, self._norm_clip) + P.assign(self._norm_clip, next_norm_clip) return F.depend(ret, opt) @@ -463,17 +555,22 @@ class _TrainOneStepCell(Cell): r""" Network training package class. - Wraps the network with an optimizer. The resulting Cell be trained with input data and label. - Backward graph will be created in the construct function to do parameter updating. Different - parallel modes are available to run the training. + Wraps the network with an optimizer. The resulting Cell be trained with + input data and label. Backward graph will be created in the construct + function to do parameter updating. 
Different parallel modes are available + to run the training. Args: network (Cell): The training network. optimizer (Cell): Optimizer for updating the weights. - sens (Number): The scaling number to be filled as the input of back propagation. Default value is 1.0. - micro_batches (int): The number of small batches split from an original batch. Default: None. - norm_clip (float): Use to clip the bound, if set 1, will return the original data. Default: 1.0. - mech (Mechanisms): The object can generate the different type of noise. Default: None. + sens (Number): The scaling number to be filled as the input of back + propagation. Default value is 1.0. + micro_batches (int): The number of small batches split from an original + batch. Default: None. + norm_clip (Tensor): Use to clip the bound, if set 1, will return the + original data. Default: 1.0. + noise_mech (Mechanisms): The object can generate the different type + of noise. Default: None. Inputs: - **data** (Tensor) - Tensor of shape :math:`(N, \ldots)`. @@ -483,7 +580,9 @@ class _TrainOneStepCell(Cell): Tensor, a scalar Tensor with shape :math:`()`. """ - def __init__(self, network, optimizer, sens=1.0, micro_batches=None, norm_clip=1.0, mech=None): + def __init__(self, network, optimizer, norm_clip=1.0, sens=1.0, + micro_batches=None, + noise_mech=None, clip_mech=None): super(_TrainOneStepCell, self).__init__(auto_prefix=False) self.network = network self.network.set_grad() @@ -495,36 +594,51 @@ class _TrainOneStepCell(Cell): self.reducer_flag = False self.grad_reducer = None parallel_mode = _get_parallel_mode() - if parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL): + if parallel_mode in ( + ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL): self.reducer_flag = True if self.reducer_flag: mean = _get_mirror_mean() degree = _get_device_num() - self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree) + self.grad_reducer = DistributedGradReducer(optimizer.parameters, + mean, degree) # dp params if micro_batches is None: - msg = 'micro_batches must give in differential privacy, but got value: {}'.format(micro_batches) + msg = 'micro_batches must give in differential privacy, but got value: {}'.format( + micro_batches) LOGGER.error(TAG, msg) raise ValueError(msg) self._micro_batches = micro_batches - norm_clip = check_param_type('norm_clip', norm_clip, float) - self._l2_norm = check_value_positive('norm_clip', norm_clip) + self._norm_clip = norm_clip self._split = P.Split(0, self._micro_batches) self._clip_by_global_norm = _ClipGradients() - self._mech = mech + self._noise_mech = noise_mech + self._clip_mech = clip_mech self._tuple_add = _TupleAdd() + self._add = P.TensorAdd() + self._norm = nn.Norm() self._hyper_map = C.HyperMap() + self._zero = Tensor(0, mstype.float32) + self._assign = P.Assign() + self._div = P.Div() + self._sqrt = P.Sqrt() + self._reduce_sum = P.ReduceSum() + self._square_all = P.Square() + self._less = P.Less() + self._cast = P.Cast() + self._micro_float = Tensor(micro_batches, mstype.float32) - self._mech_param_updater = None - if self._mech is not None and self._mech._decay_policy is not None: - self._mech_param_updater = _MechanismsParamsUpdater(policy=self._mech._decay_policy, - decay_rate=self._mech._noise_decay_rate, - cur_noise_multiplier= - self._mech._noise_multiplier, - init_noise_multiplier= - self._mech._initial_noise_multiplier) + self._noise_mech_param_updater = None + if self._noise_mech is not None and self._noise_mech._decay_policy is not None: + 
self._noise_mech_param_updater = _MechanismsParamsUpdater( + policy=self._noise_mech._decay_policy, + decay_rate=self._noise_mech._noise_decay_rate, + cur_noise_multiplier= + self._noise_mech._noise_multiplier, + init_noise_multiplier= + self._noise_mech._initial_noise_multiplier) def construct(self, data, label): """ @@ -535,32 +649,65 @@ class _TrainOneStepCell(Cell): record_labels = self._split(label) loss = self.network(record_datas[0], record_labels[0]) sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens) - record_grad = self.grad(self.network, weights)(record_datas[0], record_labels[0], sens) - record_grad = self._clip_by_global_norm(record_grad, GRADIENT_CLIP_TYPE, self._l2_norm) + record_grad = self.grad(self.network, weights)(record_datas[0], + record_labels[0], sens) + beta = self._zero + square_sum = self._zero + for grad in record_grad: + square_sum = self._add(square_sum, + self._reduce_sum(self._square_all(grad))) + norm_grad = self._sqrt(square_sum) + beta = self._add(beta, + self._cast(self._less(norm_grad, self._norm_clip), + mstype.float32)) + + record_grad = self._clip_by_global_norm(record_grad, GRADIENT_CLIP_TYPE, + self._norm_clip) grads = record_grad total_loss = loss for i in range(1, self._micro_batches): loss = self.network(record_datas[i], record_labels[i]) sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens) - record_grad = self.grad(self.network, weights)(record_datas[i], record_labels[i], sens) - record_grad = self._clip_by_global_norm(record_grad, GRADIENT_CLIP_TYPE, self._l2_norm) + record_grad = self.grad(self.network, weights)(record_datas[i], + record_labels[i], + sens) + square_sum = self._zero + for grad in record_grad: + square_sum = self._add(square_sum, + self._reduce_sum(self._square_all(grad))) + norm_grad = self._sqrt(square_sum) + beta = self._add(beta, + self._cast(self._less(norm_grad, self._norm_clip), + mstype.float32)) + + record_grad = self._clip_by_global_norm(record_grad, + GRADIENT_CLIP_TYPE, + self._norm_clip) grads = self._tuple_add(grads, record_grad) total_loss = P.TensorAdd()(total_loss, loss) - loss = P.Div()(total_loss, self._micro_float) + loss = self._div(total_loss, self._micro_float) + beta = self._div(beta, self._micro_batches) - if self._mech is not None: + if self._noise_mech is not None: grad_noise_tuple = () for grad_item in grads: - grad_noise = self._mech(grad_item) + grad_noise = self._noise_mech(grad_item) grad_noise_tuple = grad_noise_tuple + (grad_noise,) grads = self._tuple_add(grads, grad_noise_tuple) - grads = self._hyper_map(F.partial(_grad_scale, self._micro_float), grads) + grads = self._hyper_map(F.partial(_grad_scale, self._micro_float), + grads) # update mech parameters - if self._mech_param_updater is not None: - multiplier = self._mech_param_updater() + if self._noise_mech_param_updater is not None: + multiplier = self._noise_mech_param_updater() loss = F.depend(loss, multiplier) if self.reducer_flag: # apply grad reducer on grads grads = self.grad_reducer(grads) + + if self._clip_mech is not None: + next_norm_clip = self._clip_mech(beta, self._norm_clip) + self._norm_clip = self._assign(self._norm_clip, next_norm_clip) + loss = F.depend(loss, next_norm_clip) + return F.depend(loss, self.optimizer(grads)) diff --git a/tests/ut/python/diff_privacy/test_mechanisms.py b/tests/ut/python/diff_privacy/test_mechanisms.py index 2f7262cfc545596882a1e378b0059d7504b891b6..f201031aea73f9f7af2069860b5a4798faf16ee1 100644 --- a/tests/ut/python/diff_privacy/test_mechanisms.py +++ 
b/tests/ut/python/diff_privacy/test_mechanisms.py @@ -19,9 +19,11 @@ import pytest from mindspore import context from mindspore import Tensor from mindspore.common import dtype as mstype -from mindarmour.diff_privacy import GaussianRandom +from mindarmour.diff_privacy import NoiseGaussianRandom from mindarmour.diff_privacy import AdaGaussianRandom -from mindarmour.diff_privacy import MechanismsFactory +from mindarmour.diff_privacy import AdaClippingWithGaussianRandom +from mindarmour.diff_privacy import NoiseMechanismsFactory +from mindarmour.diff_privacy import ClipMechanismsFactory @pytest.mark.level0 @@ -33,7 +35,7 @@ def test_graph_gaussian(): grad = Tensor([0.3, 0.2, 0.4], mstype.float32) norm_bound = 1.0 initial_noise_multiplier = 0.1 - net = GaussianRandom(norm_bound, initial_noise_multiplier) + net = NoiseGaussianRandom(norm_bound, initial_noise_multiplier) res = net(grad) print(res) @@ -47,7 +49,7 @@ def test_pynative_gaussian(): grad = Tensor([0.3, 0.2, 0.4], mstype.float32) norm_bound = 1.0 initial_noise_multiplier = 0.1 - net = GaussianRandom(norm_bound, initial_noise_multiplier) + net = NoiseGaussianRandom(norm_bound, initial_noise_multiplier) res = net(grad) print(res) @@ -80,13 +82,13 @@ def test_graph_factory(): initial_noise_multiplier = 0.1 alpha = 0.5 decay_policy = 'Step' - noise_mechanism = MechanismsFactory() + noise_mechanism = NoiseMechanismsFactory() noise_construct = noise_mechanism.create('Gaussian', norm_bound, initial_noise_multiplier) noise = noise_construct(grad) print('Gaussian noise: ', noise) - ada_mechanism = MechanismsFactory() + ada_mechanism = NoiseMechanismsFactory() ada_noise_construct = ada_mechanism.create('AdaGaussian', norm_bound, initial_noise_multiplier, @@ -124,13 +126,13 @@ def test_pynative_factory(): initial_noise_multiplier = 0.1 alpha = 0.5 decay_policy = 'Step' - noise_mechanism = MechanismsFactory() + noise_mechanism = NoiseMechanismsFactory() noise_construct = noise_mechanism.create('Gaussian', norm_bound, initial_noise_multiplier) noise = noise_construct(grad) print('Gaussian noise: ', noise) - ada_mechanism = MechanismsFactory() + ada_mechanism = NoiseMechanismsFactory() ada_noise_construct = ada_mechanism.create('AdaGaussian', norm_bound, initial_noise_multiplier, @@ -151,7 +153,7 @@ def test_pynative_exponential(): initial_noise_multiplier = 0.1 alpha = 0.5 decay_policy = 'Exp' - ada_mechanism = MechanismsFactory() + ada_mechanism = NoiseMechanismsFactory() ada_noise_construct = ada_mechanism.create('AdaGaussian', norm_bound, initial_noise_multiplier, @@ -172,7 +174,7 @@ def test_graph_exponential(): initial_noise_multiplier = 0.1 alpha = 0.5 decay_policy = 'Exp' - ada_mechanism = MechanismsFactory() + ada_mechanism = NoiseMechanismsFactory() ada_noise_construct = ada_mechanism.create('AdaGaussian', norm_bound, initial_noise_multiplier, @@ -180,3 +182,107 @@ def test_graph_exponential(): decay_policy=decay_policy) ada_noise = ada_noise_construct(grad) print('ada noise: ', ada_noise) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +@pytest.mark.component_mindarmour +def test_ada_clip_gaussian_random_pynative(): + context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") + decay_policy = 'Linear' + beta = Tensor(0.5, mstype.float32) + norm_clip = Tensor(1.0, mstype.float32) + beta_stddev = 0.1 + learning_rate = 0.1 + target_unclipped_quantile = 0.3 + ada_clip = AdaClippingWithGaussianRandom(decay_policy=decay_policy, + learning_rate=learning_rate, + 
target_unclipped_quantile=target_unclipped_quantile, + fraction_stddev=beta_stddev, + seed=1) + next_norm_clip = ada_clip(beta, norm_clip) + print('Liner next norm clip:', next_norm_clip) + + decay_policy = 'Geometric' + ada_clip = AdaClippingWithGaussianRandom(decay_policy=decay_policy, + learning_rate=learning_rate, + target_unclipped_quantile=target_unclipped_quantile, + fraction_stddev=beta_stddev, + seed=1) + next_norm_clip = ada_clip(beta, norm_clip) + print('Geometric next norm clip:', next_norm_clip) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +@pytest.mark.component_mindarmour +def test_ada_clip_gaussian_random_graph(): + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + decay_policy = 'Linear' + beta = Tensor(0.5, mstype.float32) + norm_clip = Tensor(1.0, mstype.float32) + beta_stddev = 0.1 + learning_rate = 0.1 + target_unclipped_quantile = 0.3 + ada_clip = AdaClippingWithGaussianRandom(decay_policy=decay_policy, + learning_rate=learning_rate, + target_unclipped_quantile=target_unclipped_quantile, + fraction_stddev=beta_stddev, + seed=1) + next_norm_clip = ada_clip(beta, norm_clip) + print('Liner next norm clip:', next_norm_clip) + + decay_policy = 'Geometric' + ada_clip = AdaClippingWithGaussianRandom(decay_policy=decay_policy, + learning_rate=learning_rate, + target_unclipped_quantile=target_unclipped_quantile, + fraction_stddev=beta_stddev, + seed=1) + next_norm_clip = ada_clip(beta, norm_clip) + print('Geometric next norm clip:', next_norm_clip) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +@pytest.mark.component_mindarmour +def test_pynative_clip_mech_factory(): + context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") + decay_policy = 'Linear' + beta = Tensor(0.5, mstype.float32) + norm_clip = Tensor(1.0, mstype.float32) + beta_stddev = 0.1 + learning_rate = 0.1 + target_unclipped_quantile = 0.3 + clip_mechanism = ClipMechanismsFactory() + ada_clip = clip_mechanism.create('Gaussian', + decay_policy=decay_policy, + learning_rate=learning_rate, + target_unclipped_quantile=target_unclipped_quantile, + fraction_stddev=beta_stddev) + next_norm_clip = ada_clip(beta, norm_clip) + print('next_norm_clip: ', next_norm_clip) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +@pytest.mark.component_mindarmour +def test_graph_clip_mech_factory(): + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + decay_policy = 'Linear' + beta = Tensor(0.5, mstype.float32) + norm_clip = Tensor(1.0, mstype.float32) + beta_stddev = 0.1 + learning_rate = 0.1 + target_unclipped_quantile = 0.3 + clip_mechanism = ClipMechanismsFactory() + ada_clip = clip_mechanism.create('Gaussian', + decay_policy=decay_policy, + learning_rate=learning_rate, + target_unclipped_quantile=target_unclipped_quantile, + fraction_stddev=beta_stddev) + next_norm_clip = ada_clip(beta, norm_clip) + print('next_norm_clip: ', next_norm_clip) diff --git a/tests/ut/python/diff_privacy/test_model_train.py b/tests/ut/python/diff_privacy/test_model_train.py index aac04b5da88cdeea603f92c2226083a6d9523459..892913c45182a128bd59cccd3cf6aa2e71a81df0 100644 --- a/tests/ut/python/diff_privacy/test_model_train.py +++ b/tests/ut/python/diff_privacy/test_model_train.py @@ -22,7 +22,8 @@ from mindspore import context import mindspore.dataset as ds from mindarmour.diff_privacy import DPModel -from mindarmour.diff_privacy import MechanismsFactory +from 
mindarmour.diff_privacy import NoiseMechanismsFactory +from mindarmour.diff_privacy import ClipMechanismsFactory from mindarmour.diff_privacy import DPOptimizerClassFactory from test_network import LeNet5 @@ -30,10 +31,12 @@ from test_network import LeNet5 def dataset_generator(batch_size, batches): """mock training data.""" - data = np.random.random((batches * batch_size, 1, 32, 32)).astype(np.float32) - label = np.random.randint(0, 10, batches * batch_size).astype(np.int32) + data = np.random.random((batches*batch_size, 1, 32, 32)).astype( + np.float32) + label = np.random.randint(0, 10, batches*batch_size).astype(np.int32) for i in range(batches): - yield data[i * batch_size:(i + 1) * batch_size], label[i * batch_size:(i + 1) * batch_size] + yield data[i*batch_size:(i + 1)*batch_size],\ + label[i*batch_size:(i + 1)*batch_size] @pytest.mark.level0 @@ -55,16 +58,24 @@ def test_dp_model_with_pynative_mode(): factory_opt.set_mechanisms('Gaussian', norm_bound=norm_clip, initial_noise_multiplier=initial_noise_multiplier) - net_opt = factory_opt.create('Momentum')(network.trainable_params(), learning_rate=0.1, momentum=0.9) + net_opt = factory_opt.create('Momentum')(network.trainable_params(), + learning_rate=0.1, momentum=0.9) + clip_mech = ClipMechanismsFactory().create('Gaussian', + decay_policy='Linear', + learning_rate=0.01, + target_unclipped_quantile=0.9, + fraction_stddev=0.01) model = DPModel(micro_batches=micro_batches, norm_clip=norm_clip, - mech=None, + clip_mech=clip_mech, + noise_mech=None, network=network, loss_fn=loss, optimizer=net_opt, metrics=None) - ms_ds = ds.GeneratorDataset(dataset_generator(batch_size, batches), ['data', 'label']) - ms_ds.set_dataset_size(batch_size * batches) + ms_ds = ds.GeneratorDataset(dataset_generator(batch_size, batches), + ['data', 'label']) + ms_ds.set_dataset_size(batch_size*batches) model.train(epochs, ms_ds, dataset_sink_mode=False) @@ -82,19 +93,27 @@ def test_dp_model_with_graph_mode(): batches = 128 epochs = 1 loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True) - mech = MechanismsFactory().create('Gaussian', - norm_bound=norm_clip, - initial_noise_multiplier=initial_noise_multiplier) - net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.1, momentum=0.9) + noise_mech = NoiseMechanismsFactory().create('Gaussian', + norm_bound=norm_clip, + initial_noise_multiplier=initial_noise_multiplier) + clip_mech = ClipMechanismsFactory().create('Gaussian', + decay_policy='Linear', + learning_rate=0.01, + target_unclipped_quantile=0.9, + fraction_stddev=0.01) + net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.1, + momentum=0.9) model = DPModel(micro_batches=2, + clip_mech=clip_mech, norm_clip=norm_clip, - mech=mech, + noise_mech=noise_mech, network=network, loss_fn=loss, optimizer=net_opt, metrics=None) - ms_ds = ds.GeneratorDataset(dataset_generator(batch_size, batches), ['data', 'label']) - ms_ds.set_dataset_size(batch_size * batches) + ms_ds = ds.GeneratorDataset(dataset_generator(batch_size, batches), + ['data', 'label']) + ms_ds.set_dataset_size(batch_size*batches) model.train(epochs, ms_ds, dataset_sink_mode=False) @@ -112,17 +131,25 @@ def test_dp_model_with_graph_mode_ada_gaussian(): batches = 128 epochs = 1 loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True) - mech = MechanismsFactory().create('AdaGaussian', - norm_bound=norm_clip, - initial_noise_multiplier=initial_noise_multiplier) - net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.1, momentum=0.9) + noise_mech = 
NoiseMechanismsFactory().create('AdaGaussian', + norm_bound=norm_clip, + initial_noise_multiplier=initial_noise_multiplier) + clip_mech = ClipMechanismsFactory().create('Gaussian', + decay_policy='Linear', + learning_rate=0.01, + target_unclipped_quantile=0.9, + fraction_stddev=0.01) + net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.1, + momentum=0.9) model = DPModel(micro_batches=2, + clip_mech=clip_mech, norm_clip=norm_clip, - mech=mech, + noise_mech=noise_mech, network=network, loss_fn=loss, optimizer=net_opt, metrics=None) - ms_ds = ds.GeneratorDataset(dataset_generator(batch_size, batches), ['data', 'label']) - ms_ds.set_dataset_size(batch_size * batches) + ms_ds = ds.GeneratorDataset(dataset_generator(batch_size, batches), + ['data', 'label']) + ms_ds.set_dataset_size(batch_size*batches) model.train(epochs, ms_ds, dataset_sink_mode=False)
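
For reference, a minimal NumPy sketch of what the renamed noise mechanisms (NoiseGaussianRandom / AdaGaussianRandom) add to each gradient tensor, per the construct shown above: zero-mean Gaussian noise with stddev equal to noise_multiplier * norm_bound (AdaGaussian additionally decays the multiplier between steps). The helper name below is illustrative and not part of the MindArmour API.

import numpy as np

def gaussian_grad_noise(grad, norm_bound=1.0, noise_multiplier=1.5, seed=0):
    # Zero-mean Gaussian noise shaped like the gradient; the stddev is the
    # product of the clipping bound and the (possibly decayed) multiplier.
    rng = np.random.default_rng(seed)
    return rng.normal(0.0, noise_multiplier * norm_bound, size=grad.shape)

grad = np.array([0.3, 0.2, 0.4], dtype=np.float32)
print(grad + gaussian_grad_noise(grad))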
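
The adaptive clipping update implemented by AdaClippingWithGaussianRandom can be checked numerically with plain NumPy. This is only a sketch of the two decay policies from the class docstring; adaptive_clip_step is an illustrative name, and the real Cell performs the same arithmetic with MindSpore ops.

import numpy as np

def adaptive_clip_step(norm_clip, empirical_fraction, decay_policy='Linear',
                       learning_rate=0.001, target_unclipped_quantile=0.9,
                       fraction_stddev=0.01, seed=0):
    rng = np.random.default_rng(seed)
    # Noise the measured unclipped fraction: empirical_fraction + N(0, fraction_stddev).
    empirical_fraction = empirical_fraction + rng.normal(0.0, fraction_stddev)
    delta = empirical_fraction - target_unclipped_quantile
    if decay_policy == 'Linear':
        # norm_clip <- norm_clip - learning_rate * (beta - target_unclipped_quantile)
        return norm_clip - learning_rate * delta
    # 'Geometric': norm_clip <- norm_clip * exp(-learning_rate * (beta - target))
    return norm_clip * np.exp(-learning_rate * delta)

# Only 50% of records were unclipped but the target is 90%, so the bound grows.
print(adaptive_clip_step(1.0, 0.5, 'Linear'))
print(adaptive_clip_step(1.0, 0.5, 'Geometric'))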
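
The new bookkeeping in _TrainOneStepCell / _TrainOneStepWithLossScaleCell that feeds the clip mechanism boils down to: for each micro-batch, compute the global L2 norm of its gradients, count it as unclipped if the norm stays below norm_clip, and average the counts over micro_batches. A rough NumPy rendering follows; function and variable names are illustrative.

import numpy as np

def unclipped_fraction(micro_batch_grads, norm_clip):
    # micro_batch_grads: one entry per micro-batch, each a list of gradient arrays.
    beta = 0.0
    for grads in micro_batch_grads:
        # Global L2 norm over all parameter gradients of this micro-batch.
        square_sum = sum(float(np.sum(np.square(g))) for g in grads)
        norm_grad = np.sqrt(square_sum)
        # Count the micro-batch as unclipped if its norm is under the bound.
        beta += float(norm_grad < norm_clip)
    return beta / len(micro_batch_grads)

grads_demo = [[np.array([0.3, 0.2]), np.array([0.1])],   # norm ~0.37, unclipped
              [np.array([1.2, 0.9]), np.array([0.4])]]   # norm ~1.55, clipped
print(unclipped_fraction(grads_demo, norm_clip=1.0))     # 0.5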
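
Putting the renamed pieces together, a training script now wires two factories into DPModel instead of one. The sketch below restates the flow of lenet5_dp.py with the values from lenet5_config.py in this change; it assumes the example's LeNet5 network and a prepared MindSpore dataset ds_train are available.

from mindspore import nn
from mindarmour.diff_privacy import DPModel
from mindarmour.diff_privacy import NoiseMechanismsFactory
from mindarmour.diff_privacy import ClipMechanismsFactory
from lenet5_net import LeNet5

network = LeNet5()
# Gradient noise: AdaGaussian decays the noise multiplier during training.
noise_mech = NoiseMechanismsFactory().create('AdaGaussian',
                                             norm_bound=1.0,
                                             initial_noise_multiplier=0.5)
# Adaptive clipping: updates norm_clip from the measured unclipped fraction.
clip_mech = ClipMechanismsFactory().create('Gaussian',
                                           decay_policy='Linear',
                                           learning_rate=0.001,
                                           target_unclipped_quantile=0.9,
                                           fraction_stddev=0.01)
net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True,
                                            reduction="mean")
net_opt = nn.Momentum(params=network.trainable_params(),
                      learning_rate=0.01, momentum=0.9)
model = DPModel(micro_batches=16,
                norm_clip=1.0,
                noise_mech=noise_mech,
                clip_mech=clip_mech,
                network=network,
                loss_fn=net_loss,
                optimizer=net_opt,
                metrics=None)
model.train(10, ds_train, dataset_sink_mode=False)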