提交 ad8a19bf 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!51 Add new feature of adaptive clipping

Merge pull request !51 from ZhidanLiu/master
......@@ -20,7 +20,7 @@ from easydict import EasyDict as edict
mnist_cfg = edict({
'num_classes': 10, # the number of classes of model's output
'lr': 0.1, # the learning rate of model's optimizer
'lr': 0.01, # the learning rate of model's optimizer
'momentum': 0.9, # the momentum value of model's optimizer
'epoch_size': 10, # training epochs
'batch_size': 256, # batch size for training
......@@ -33,8 +33,13 @@ mnist_cfg = edict({
'dataset_sink_mode': False, # whether deliver all training data to device one time
'micro_batches': 16, # the number of small batches split from an original batch
'norm_clip': 1.0, # the clip bound of the gradients of model's training parameters
'initial_noise_multiplier': 1.5, # the initial multiplication coefficient of the noise added to training
'initial_noise_multiplier': 0.5, # the initial multiplication coefficient of the noise added to training
# parameters' gradients
'mechanisms': 'AdaGaussian', # the method of adding noise in gradients while training
'noise_mechanisms': 'AdaGaussian', # the method of adding noise in gradients while training
'clip_mechanisms': 'Gaussian', # the method of adaptive clipping gradients while training
'clip_decay_policy': 'Linear', # Decay policy of adaptive clipping, decay_policy must be in ['Linear', 'Geometric'].
'clip_learning_rate': 0.001, # Learning rate of update norm clip.
'target_unclipped_quantile': 0.9, # Target quantile of norm clip.
'fraction_stddev': 0.01, # The stddev of Gaussian normal which used in empirical_fraction.
'optimizer': 'Momentum' # the base optimizer used for Differential privacy training
})
......@@ -31,7 +31,8 @@ import mindspore.common.dtype as mstype
from mindarmour.diff_privacy import DPModel
from mindarmour.diff_privacy import PrivacyMonitorFactory
from mindarmour.diff_privacy import MechanismsFactory
from mindarmour.diff_privacy import NoiseMechanismsFactory
from mindarmour.diff_privacy import ClipMechanismsFactory
from mindarmour.utils.logger import LogUtil
from lenet5_net import LeNet5
from lenet5_config import mnist_cfg as cfg
......@@ -87,11 +88,14 @@ def generate_mnist_dataset(data_path, batch_size=32, repeat_size=1,
if __name__ == "__main__":
# This configure can run both in pynative mode and graph mode
context.set_context(mode=context.GRAPH_MODE, device_target=cfg.device_target)
context.set_context(mode=context.GRAPH_MODE,
device_target=cfg.device_target)
network = LeNet5()
net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,
keep_checkpoint_max=cfg.keep_checkpoint_max)
net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True,
reduction="mean")
config_ck = CheckpointConfig(
save_checkpoint_steps=cfg.save_checkpoint_steps,
keep_checkpoint_max=cfg.keep_checkpoint_max)
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet",
directory='./trained_ckpt_file/',
config=config_ck)
......@@ -102,17 +106,33 @@ if __name__ == "__main__":
cfg.epoch_size)
if cfg.micro_batches and cfg.batch_size % cfg.micro_batches != 0:
raise ValueError("Number of micro_batches should divide evenly batch_size")
# Create a factory class of DP mechanisms, this method is adding noise in gradients while training.
# Initial_noise_multiplier is suggested to be greater than 1.0, otherwise the privacy budget would be huge, which
# means that the privacy protection effect is weak. Mechanisms can be 'Gaussian' or 'AdaGaussian', in which noise
# would be decayed with 'AdaGaussian' mechanism while be constant with 'Gaussian' mechanism.
mech = MechanismsFactory().create(cfg.mechanisms,
norm_bound=cfg.norm_clip,
initial_noise_multiplier=cfg.initial_noise_multiplier)
net_opt = nn.Momentum(params=network.trainable_params(), learning_rate=cfg.lr, momentum=cfg.momentum)
# Create a monitor for DP training. The function of the monitor is to compute and print the privacy budget(eps
# and delta) while training.
raise ValueError(
"Number of micro_batches should divide evenly batch_size")
# Create a factory class of DP noise mechanisms, this method is adding noise
# in gradients while training. Initial_noise_multiplier is suggested to be
# greater than 1.0, otherwise the privacy budget would be huge, which means
# that the privacy protection effect is weak. Mechanisms can be 'Gaussian'
# or 'AdaGaussian', in which noise would be decayed with 'AdaGaussian'
# mechanism while be constant with 'Gaussian' mechanism.
noise_mech = NoiseMechanismsFactory().create(cfg.noise_mechanisms,
norm_bound=cfg.norm_clip,
initial_noise_multiplier=cfg.initial_noise_multiplier)
# Create a factory class of clip mechanisms, this method is to adaptive clip
# gradients while training, decay_policy support 'Linear' and 'Geometric',
# learning_rate is the learning rate to update clip_norm,
# target_unclipped_quantile is the target quantile of norm clip,
# fraction_stddev is the stddev of Gaussian normal which used in
# empirical_fraction, the formula is
# $empirical_fraction + N(0, fraction_stddev)$.
clip_mech = ClipMechanismsFactory().create(cfg.clip_mechanisms,
decay_policy=cfg.clip_decay_policy,
learning_rate=cfg.clip_learning_rate,
target_unclipped_quantile=cfg.target_unclipped_quantile,
fraction_stddev=cfg.fraction_stddev)
net_opt = nn.Momentum(params=network.trainable_params(),
learning_rate=cfg.lr, momentum=cfg.momentum)
# Create a monitor for DP training. The function of the monitor is to
# compute and print the privacy budget(eps and delta) while training.
rdp_monitor = PrivacyMonitorFactory.create('rdp',
num_samples=60000,
batch_size=cfg.batch_size,
......@@ -121,20 +141,23 @@ if __name__ == "__main__":
# Create the DP model for training.
model = DPModel(micro_batches=cfg.micro_batches,
norm_clip=cfg.norm_clip,
mech=mech,
noise_mech=noise_mech,
clip_mech=clip_mech,
network=network,
loss_fn=net_loss,
optimizer=net_opt,
metrics={"Accuracy": Accuracy()})
LOGGER.info(TAG, "============== Starting Training ==============")
model.train(cfg['epoch_size'], ds_train, callbacks=[ckpoint_cb, LossMonitor(), rdp_monitor],
model.train(cfg['epoch_size'], ds_train,
callbacks=[ckpoint_cb, LossMonitor(), rdp_monitor],
dataset_sink_mode=cfg.dataset_sink_mode)
LOGGER.info(TAG, "============== Starting Testing ==============")
ckpt_file_name = 'trained_ckpt_file/checkpoint_lenet-10_234.ckpt'
param_dict = load_checkpoint(ckpt_file_name)
load_param_into_net(network, param_dict)
ds_eval = generate_mnist_dataset(os.path.join(cfg.data_path, 'test'), batch_size=cfg.batch_size)
ds_eval = generate_mnist_dataset(os.path.join(cfg.data_path, 'test'),
batch_size=cfg.batch_size)
acc = model.eval(ds_eval, dataset_sink_mode=False)
LOGGER.info(TAG, "============== Accuracy: %s ==============", acc)
"""
This module provide Differential Privacy feature to protect user privacy.
"""
from .mechanisms.mechanisms import GaussianRandom
from .mechanisms.mechanisms import NoiseGaussianRandom
from .mechanisms.mechanisms import AdaGaussianRandom
from .mechanisms.mechanisms import MechanismsFactory
from .mechanisms.mechanisms import AdaClippingWithGaussianRandom
from .mechanisms.mechanisms import NoiseMechanismsFactory
from .mechanisms.mechanisms import ClipMechanismsFactory
from .monitor.monitor import PrivacyMonitorFactory
from .optimizer.optimizer import DPOptimizerClassFactory
from .train.model import DPModel
__all__ = ['GaussianRandom',
__all__ = ['NoiseGaussianRandom',
'AdaGaussianRandom',
'MechanismsFactory',
'AdaClippingWithGaussianRandom',
'NoiseMechanismsFactory',
'ClipMechanismsFactory',
'PrivacyMonitorFactory',
'DPOptimizerClassFactory',
'DPModel']
......@@ -28,11 +28,54 @@ from mindarmour.utils._check_param import check_param_in_range
from mindarmour.utils.logger import LogUtil
LOGGER = LogUtil.get_instance()
TAG = 'Defense'
TAG = 'NoiseMechanism'
class MechanismsFactory:
""" Factory class of mechanisms"""
class ClipMechanismsFactory:
""" Factory class of clip mechanisms"""
def __init__(self):
pass
@staticmethod
def create(mech_name, *args, **kwargs):
"""
Args:
mech_name(str): Clip noise generated strategy, support 'Gaussian' now.
args(Union[float, str]): Parameters used for creating clip mechanisms.
kwargs(Union[float, str]): Parameters used for creating clip
mechanisms.
Raises:
NameError: `mech_name` must be in ['Gaussian'].
Returns:
Mechanisms, class of noise generated Mechanism.
Examples:
>>> decay_policy = 'Linear'
>>> beta = Tensor(0.5, mstype.float32)
>>> norm_clip = Tensor(1.0, mstype.float32)
>>> beta_stddev = 0.1
>>> learning_rate = 0.1
>>> target_unclipped_quantile = 0.3
>>> clip_mechanism = ClipMechanismsFactory()
>>> ada_clip = clip_mechanism.create('Gaussian',
>>> decay_policy=decay_policy,
>>> learning_rate=learning_rate,
>>> target_unclipped_quantile=target_unclipped_quantile,
>>> fraction_stddev=beta_stddev)
>>> next_norm_clip = ada_clip(beta, norm_clip)
"""
if mech_name == 'Gaussian':
return AdaClippingWithGaussianRandom(*args, **kwargs)
raise NameError("The {} is not implement, please choose "
"['Gaussian']".format(mech_name))
class NoiseMechanismsFactory:
""" Factory class of noise mechanisms"""
def __init__(self):
pass
......@@ -56,42 +99,38 @@ class MechanismsFactory:
Mechanisms, class of noise generated Mechanism.
Examples:
>>> class Net(nn.Cell):
>>> def __init__(self):
>>> super(Net, self).__init__()
>>> self.conv = nn.Conv2d(3, 64, 3, has_bias=False, weight_init='normal')
>>> self.bn = nn.BatchNorm2d(64)
>>> self.relu = nn.ReLU()
>>> self.flatten = nn.Flatten()
>>> self.fc = nn.Dense(64*224*224, 12) # padding=0
>>>
>>> def construct(self, x):
>>> x = self.conv(x)
>>> x = self.bn(x)
>>> x = self.relu(x)
>>> x = self.flatten(x)
>>> out = self.fc(x)
>>> return out
>>> norm_clip = 1.0
>>> initial_noise_multiplier = 1.5
>>> net = Net()
>>> initial_noise_multiplier = 0.01
>>> network = LeNet5()
>>> batch_size = 32
>>> batches = 128
>>> epochs = 1
>>> loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
>>> net_opt = Momentum(params=net.trainable_params(), learning_rate=0.01, momentum=0.9)
>>> mech = MechanismsFactory().create('Gaussian',
>>> norm_bound=norm_clip,
>>> initial_noise_multiplier=initial_noise_multiplier)
>>> noise_mech = NoiseMechanismsFactory().create('Gaussian',
>>> norm_bound=norm_clip,
>>> initial_noise_multiplier=initial_noise_multiplier)
>>> clip_mech = ClipMechanismsFactory().create('Gaussian',
>>> decay_policy='Linear',
>>> learning_rate=0.01,
>>> target_unclipped_quantile=0.9,
>>> fraction_stddev=0.01)
>>> net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.1,
>>> momentum=0.9)
>>> model = DPModel(micro_batches=2,
>>> norm_clip=1.0,
>>> mech=mech,
>>> network=net,
>>> clip_mech=clip_mech,
>>> norm_clip=norm_clip,
>>> noise_mech=noise_mech,
>>> network=network,
>>> loss_fn=loss,
>>> optimizer=net_opt,
>>> metrics=None)
>>> dataset = get_dataset()
>>> model.train(2, dataset)
>>> ms_ds = ds.GeneratorDataset(dataset_generator(batch_size, batches),
>>> ['data', 'label'])
>>> ms_ds.set_dataset_size(batch_size * batches)
>>> model.train(epochs, ms_ds, dataset_sink_mode=False)
"""
if policy == 'Gaussian':
return GaussianRandom(*args, **kwargs)
return NoiseGaussianRandom(*args, **kwargs)
if policy == 'AdaGaussian':
return AdaGaussianRandom(*args, **kwargs)
raise NameError("The {} is not implement, please choose "
......@@ -110,7 +149,7 @@ class Mechanisms(Cell):
"""
class GaussianRandom(Mechanisms):
class NoiseGaussianRandom(Mechanisms):
"""
Gaussian noise generated mechanism.
......@@ -133,18 +172,21 @@ class GaussianRandom(Mechanisms):
>>> gradients = Tensor([0.2, 0.9], mstype.float32)
>>> norm_bound = 0.5
>>> initial_noise_multiplier = 1.5
>>> net = GaussianRandom(norm_bound, initial_noise_multiplier)
>>> net = NoiseGaussianRandom(norm_bound, initial_noise_multiplier)
>>> res = net(gradients)
>>> print(res)
"""
def __init__(self, norm_bound=0.5, initial_noise_multiplier=1.5, seed=0, policy=None):
super(GaussianRandom, self).__init__()
def __init__(self, norm_bound=0.5, initial_noise_multiplier=1.5, seed=0,
policy=None):
super(NoiseGaussianRandom, self).__init__()
self._norm_bound = check_value_positive('norm_bound', norm_bound)
self._norm_bound = Tensor(norm_bound, mstype.float32)
self._initial_noise_multiplier = check_value_positive('initial_noise_multiplier',
initial_noise_multiplier)
self._initial_noise_multiplier = Tensor(initial_noise_multiplier, mstype.float32)
self._initial_noise_multiplier = check_value_positive(
'initial_noise_multiplier',
initial_noise_multiplier)
self._initial_noise_multiplier = Tensor(initial_noise_multiplier,
mstype.float32)
self._mean = Tensor(0, mstype.float32)
self._normal = P.Normal(seed=seed)
self._decay_policy = policy
......@@ -201,17 +243,20 @@ class AdaGaussianRandom(Mechanisms):
noise_decay_rate=6e-4, decay_policy='Time', seed=0):
super(AdaGaussianRandom, self).__init__()
norm_bound = check_value_positive('norm_bound', norm_bound)
initial_noise_multiplier = check_value_positive('initial_noise_multiplier',
initial_noise_multiplier)
initial_noise_multiplier = check_value_positive(
'initial_noise_multiplier',
initial_noise_multiplier)
self._norm_bound = Tensor(norm_bound, mstype.float32)
initial_noise_multiplier = Tensor(initial_noise_multiplier, mstype.float32)
initial_noise_multiplier = Tensor(initial_noise_multiplier,
mstype.float32)
self._initial_noise_multiplier = Parameter(initial_noise_multiplier,
name='initial_noise_multiplier')
self._noise_multiplier = Parameter(initial_noise_multiplier,
name='noise_multiplier')
self._mean = Tensor(0, mstype.float32)
noise_decay_rate = check_param_type('noise_decay_rate', noise_decay_rate, float)
noise_decay_rate = check_param_type('noise_decay_rate',
noise_decay_rate, float)
check_param_in_range('noise_decay_rate', noise_decay_rate, 0.0, 1.0)
self._noise_decay_rate = Tensor(noise_decay_rate, mstype.float32)
if decay_policy not in ['Time', 'Step', 'Exp']:
......@@ -232,7 +277,9 @@ class AdaGaussianRandom(Mechanisms):
Tensor, generated noise with shape like given gradients.
"""
shape = P.Shape()(gradients)
noise = self._normal(shape, self._mean, self._mul(self._noise_multiplier, self._norm_bound))
noise = self._normal(shape, self._mean,
self._mul(self._noise_multiplier,
self._norm_bound))
return noise
......@@ -241,10 +288,14 @@ class _MechanismsParamsUpdater(Cell):
Update mechanisms parameters, the parameters will refresh in train period.
Args:
policy(str): Pass in by the mechanisms class, mechanisms parameters update policy.
decay_rate(Tensor): Pass in by the mechanisms class, hyper parameter for controlling the decay size.
cur_noise_multiplier(Parameter): Pass in by the mechanisms class, current params value in this time.
init_noise_multiplier(Parameter):Pass in by the mechanisms class, initial params value to be updated.
policy(str): Pass in by the mechanisms class, mechanisms parameters
update policy.
decay_rate(Tensor): Pass in by the mechanisms class, hyper parameter for
controlling the decay size.
cur_noise_multiplier(Parameter): Pass in by the mechanisms class,
current params value in this time.
init_noise_multiplier(Parameter):Pass in by the mechanisms class,
initial params value to be updated.
Returns:
Tuple, next params value.
......@@ -281,5 +332,100 @@ class _MechanismsParamsUpdater(Cell):
next_noise_multiplier = self._assign(self._cur_noise_multiplier,
self._mul(temp, self._cur_noise_multiplier))
else:
next_noise_multiplier = self._assign(self._cur_noise_multiplier, self._div(self._one, self._exp(self._one)))
next_noise_multiplier = self._assign(self._cur_noise_multiplier,
self._div(self._one, self._exp(self._one)))
return next_noise_multiplier
class AdaClippingWithGaussianRandom(Cell):
"""
Adaptive clipping. If `decay_policy` is 'Linear', the update formula is
$ norm_clip = norm_clip - learning_rate*(beta-target_unclipped_quantile)$.
`decay_policy` is 'Geometric', the update formula is
$ norm_clip = norm_clip*exp(-learning_rate*(empirical_fraction-target_unclipped_quantile))$.
where beta is the empirical fraction of samples with the value at most
`target_unclipped_quantile`.
Args:
decay_policy(str): Decay policy of adaptive clipping, decay_policy must
be in ['Linear', 'Geometric']. Default: Linear.
learning_rate(float): Learning rate of update norm clip. Default: 0.01.
target_unclipped_quantile(float): Target quantile of norm clip. Default: 0.9.
fraction_stddev(float): The stddev of Gaussian normal which used in
empirical_fraction, the formula is $empirical_fraction + N(0, fraction_stddev)$.
seed(int): Original random seed, if seed=0 random normal will use secure
random number. IF seed!=0 random normal will generate values using
given seed. Default: 0.
Returns:
Tensor, undated norm clip .
Examples:
>>> decay_policy = 'Linear'
>>> beta = Tensor(0.5, mstype.float32)
>>> norm_clip = Tensor(1.0, mstype.float32)
>>> beta_stddev = 0.01
>>> learning_rate = 0.001
>>> target_unclipped_quantile = 0.9
>>> ada_clip = AdaClippingWithGaussianRandom(decay_policy=decay_policy,
>>> learning_rate=learning_rate,
>>> target_unclipped_quantile=target_unclipped_quantile,
>>> fraction_stddev=beta_stddev)
>>> next_norm_clip = ada_clip(beta, norm_clip)
"""
def __init__(self, decay_policy='Linear', learning_rate=0.001,
target_unclipped_quantile=0.9, fraction_stddev=0.01, seed=0):
super(AdaClippingWithGaussianRandom, self).__init__()
if decay_policy not in ['Linear', 'Geometric']:
msg = "decay policy of adaptive clip must be in ['Linear', 'Geometric'], \
but got: {}".format(decay_policy)
LOGGER.error(TAG, msg)
raise ValueError(msg)
self._decay_policy = decay_policy
learning_rate = check_param_type('learning_rate', learning_rate, float)
learning_rate = check_value_positive('learning_rate', learning_rate)
self._learning_rate = Tensor(learning_rate, mstype.float32)
fraction_stddev = check_param_type('fraction_stddev', fraction_stddev, float)
self._fraction_stddev = Tensor(fraction_stddev, mstype.float32)
target_unclipped_quantile = check_param_type('target_unclipped_quantile',
target_unclipped_quantile,
float)
self._target_unclipped_quantile = Tensor(target_unclipped_quantile,
mstype.float32)
self._zero = Tensor(0, mstype.float32)
self._add = P.TensorAdd()
self._sub = P.Sub()
self._mul = P.Mul()
self._exp = P.Exp()
self._normal = P.Normal(seed=seed)
def construct(self, empirical_fraction, norm_clip):
"""
Update value of norm_clip.
Args:
empirical_fraction(Tensor): empirical fraction of samples with the
value at most `target_unclipped_quantile`.
norm_clip(Tensor): Clipping bound for the l2 norm of the gradients.
Returns:
Tensor, generated noise with shape like given gradients.
"""
fraction_noise = self._normal((1,), self._zero, self._fraction_stddev)
empirical_fraction = self._add(empirical_fraction, fraction_noise)
if self._decay_policy == 'Linear':
grad_clip = self._sub(empirical_fraction,
self._target_unclipped_quantile)
next_norm_clip = self._sub(norm_clip,
self._mul(self._learning_rate, grad_clip))
# decay_policy == 'Geometric'
else:
grad_clip = self._sub(empirical_fraction,
self._target_unclipped_quantile)
grad_clip = self._exp(self._mul(-self._learning_rate, grad_clip))
next_norm_clip = self._mul(norm_clip, grad_clip)
return next_norm_clip
......@@ -22,7 +22,7 @@ from mindspore.ops import functional as F
from mindspore.common import dtype as mstype
from mindarmour.utils.logger import LogUtil
from mindarmour.diff_privacy import MechanismsFactory
from mindarmour.diff_privacy import NoiseMechanismsFactory
from mindarmour.diff_privacy.mechanisms.mechanisms import _MechanismsParamsUpdater
from mindarmour.utils._check_param import check_int_positive
......@@ -70,7 +70,7 @@ class DPOptimizerClassFactory:
"""
def __init__(self, micro_batches=2):
self._mech_factory = MechanismsFactory()
self._mech_factory = NoiseMechanismsFactory()
self.mech = None
self._micro_batches = check_int_positive('micro_batches', micro_batches)
......
......@@ -48,7 +48,8 @@ from mindspore.nn import Cell
from mindspore import ParameterTuple
from mindarmour.utils.logger import LogUtil
from mindarmour.diff_privacy.mechanisms.mechanisms import _MechanismsParamsUpdater
from mindarmour.diff_privacy.mechanisms.mechanisms import \
_MechanismsParamsUpdater
from mindarmour.utils._check_param import check_param_type
from mindarmour.utils._check_param import check_value_positive
from mindarmour.utils._check_param import check_int_positive
......@@ -64,7 +65,7 @@ _reciprocal = P.Reciprocal()
@_grad_scale.register("Tensor", "Tensor")
def tensor_grad_scale(scale, grad):
""" grad scaling """
return grad * F.cast(_reciprocal(scale), F.dtype(grad))
return grad*F.cast(_reciprocal(scale), F.dtype(grad))
class DPModel(Model):
......@@ -72,9 +73,14 @@ class DPModel(Model):
This class is overload mindspore.train.model.Model.
Args:
micro_batches (int): The number of small batches split from an original batch. Default: 2.
norm_clip (float): Use to clip the bound, if set 1, will retun the original data. Default: 1.0.
mech (Mechanisms): The object can generate the different type of noise. Default: None.
micro_batches (int): The number of small batches split from an original
batch. Default: 2.
norm_clip (float): Use to clip the bound, if set 1, will retun the
original data. Default: 1.0.
noise_mech (Mechanisms): The object can generate the different type of
noise. Default: None.
clip_mech (Mechanisms): The object is used to update the adaptive clip .
Default: None.
Examples:
>>> norm_clip = 1.0
......@@ -89,63 +95,82 @@ class DPModel(Model):
>>> factory_opt.set_mechanisms('Gaussian',
>>> norm_bound=norm_clip,
>>> initial_noise_multiplier=initial_noise_multiplier)
>>> net_opt = factory_opt.create('Momentum')(network.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> net_opt = factory_opt.create('Momentum')(network.trainable_params(),
>>> learning_rate=0.1, momentum=0.9)
>>> clip_mech = ClipMechanismsFactory().create('Gaussian',
>>> decay_policy='Linear',
>>> learning_rate=0.01,
>>> target_unclipped_quantile=0.9,
>>> fraction_stddev=0.01)
>>> model = DPModel(micro_batches=micro_batches,
>>> norm_clip=norm_clip,
>>> mech=None,
>>> clip_mech=clip_mech,
>>> noise_mech=None,
>>> network=network,
>>> loss_fn=loss,
>>> optimizer=net_opt,
>>> metrics=None)
>>> ms_ds = ds.GeneratorDataset(dataset_generator(batch_size, batches), ['data', 'label'])
>>> ms_ds.set_dataset_size(batch_size * batches)
>>> ms_ds = ds.GeneratorDataset(dataset_generator(batch_size, batches),
>>> ['data', 'label'])
>>> ms_ds.set_dataset_size(batch_size*batches)
>>> model.train(epochs, ms_ds, dataset_sink_mode=False)
"""
def __init__(self, micro_batches=2, norm_clip=1.0, mech=None, **kwargs):
def __init__(self, micro_batches=2, norm_clip=1.0, noise_mech=None,
clip_mech=None, **kwargs):
if micro_batches:
self._micro_batches = check_int_positive('micro_batches', micro_batches)
self._micro_batches = check_int_positive('micro_batches',
micro_batches)
else:
self._micro_batches = None
norm_clip = check_param_type('norm_clip', norm_clip, float)
self._norm_clip = check_value_positive('norm_clip', norm_clip)
if mech is not None and "DPOptimizer" in kwargs['optimizer'].__class__.__name__:
msg = 'DPOptimizer is not supported while mech is not None'
norm_clip = check_value_positive('norm_clip', norm_clip)
norm_clip = Tensor(norm_clip, mstype.float32)
self._norm_clip = Parameter(norm_clip, 'norm_clip')
if noise_mech is not None and "DPOptimizer" in kwargs['optimizer'].__class__.__name__:
msg = 'DPOptimizer is not supported while noise_mech is not None'
LOGGER.error(TAG, msg)
raise ValueError(msg)
if mech is None:
if noise_mech is None:
if "DPOptimizer" in kwargs['optimizer'].__class__.__name__:
if context.get_context('mode') != context.PYNATIVE_MODE:
msg = 'DPOptimizer just support pynative mode currently.'
LOGGER.error(TAG, msg)
raise ValueError(msg)
else:
msg = 'DPModel should set mech or DPOptimizer configure, please refer to example.'
msg = 'DPModel should set noise_mech or DPOptimizer configure, ' \
'please refer to example.'
LOGGER.error(TAG, msg)
raise ValueError(msg)
self._mech = mech
self._noise_mech = noise_mech
if clip_mech is not None:
self._clip_mech = clip_mech
super(DPModel, self).__init__(**kwargs)
def _amp_build_train_network(self, network, optimizer, loss_fn=None, level='O0', **kwargs):
def _amp_build_train_network(self, network, optimizer, loss_fn=None,
level='O0', **kwargs):
"""
Build the mixed precision training cell automatically.
Args:
network (Cell): Definition of the network.
loss_fn (Union[None, Cell]): Definition of the loss_fn. If None, the `network` should have the loss inside.
Default: None.
loss_fn (Union[None, Cell]): Definition of the loss_fn. If None,
the `network` should have the loss inside. Default: None.
optimizer (Optimizer): Optimizer to update the Parameter.
level (str): Supports [O0, O2]. Default: "O0".
- O0: Do not change.
- O2: Cast network to float16, keep batchnorm and `loss_fn` (if set) run in float32,
using dynamic loss scale.
cast_model_type (:class:`mindspore.dtype`): Supports `mstype.float16` or `mstype.float32`.
If set to `mstype.float16`, use `float16` mode to train. If set, overwrite the level setting.
keep_batchnorm_fp32 (bool): Keep Batchnorm run in `float32`. If set, overwrite the level setting.
loss_scale_manager (Union[None, LossScaleManager]): If None, not scale the loss, or else
scale the loss by LossScaleManager. If set, overwrite the level setting.
- O2: Cast network to float16, keep batchnorm and `loss_fn`
(if set) run in float32, using dynamic loss scale.
cast_model_type (:class:`mindspore.dtype`): Supports `mstype.float16`
or `mstype.float32`. If set to `mstype.float16`, use `float16`
mode to train. If set, overwrite the level setting.
keep_batchnorm_fp32 (bool): Keep Batchnorm run in `float32`. If set,
overwrite the level setting.
loss_scale_manager (Union[None, LossScaleManager]): If None, not
scale the loss, or else scale the loss by LossScaleManager.
If set, overwrite the level setting.
"""
validator.check_value_type('network', network, nn.Cell, None)
validator.check_value_type('optimizer', optimizer, nn.Optimizer, None)
......@@ -161,9 +186,11 @@ class DPModel(Model):
_do_keep_batchnorm_fp32(network)
if loss_fn:
network = _add_loss_network(network, loss_fn, config.cast_model_type)
network = _add_loss_network(network, loss_fn,
config.cast_model_type)
if _get_parallel_mode() in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
if _get_parallel_mode() in (
ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
network = _VirtualDatasetCell(network)
loss_scale = 1.0
......@@ -173,9 +200,12 @@ class DPModel(Model):
update_cell = loss_scale_manager.get_update_cell()
if update_cell is not None:
# only cpu not support `TrainOneStepWithLossScaleCell` for control flow.
if not context.get_context("enable_ge") and context.get_context("device_target") == "CPU":
msg = "Only `loss_scale_manager=None` and `loss_scale_manager=FixedLossScaleManager(drop_overflow" \
"_update=False)` are supported in current version. If you use `O2` option, please use " \
if not context.get_context("enable_ge") and context.get_context(
"device_target") == "CPU":
msg = "Only `loss_scale_manager=None` and " \
"`loss_scale_manager=FixedLossScaleManager(drop_overflow" \
"_update=False)` are supported in current version. " \
"If you use `O2` option, please use " \
"`loss_scale_manager=None` or `FixedLossScaleManager`"
LOGGER.error(TAG, msg)
raise ValueError(msg)
......@@ -184,15 +214,17 @@ class DPModel(Model):
scale_update_cell=update_cell,
micro_batches=self._micro_batches,
norm_clip=self._norm_clip,
mech=self._mech).set_train()
clip_mech=self._clip_mech,
noise_mech=self._noise_mech).set_train()
return network
network = _TrainOneStepCell(network,
optimizer,
self._norm_clip,
loss_scale,
micro_batches=self._micro_batches,
norm_clip=self._norm_clip,
mech=self._mech).set_train()
clip_mech=self._clip_mech,
noise_mech=self._noise_mech).set_train()
return network
def _build_train_network(self):
......@@ -233,7 +265,8 @@ class DPModel(Model):
elif self._loss_fn:
network = nn.WithLossCell(network, self._loss_fn)
if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL,
ParallelMode.AUTO_PARALLEL):
network.set_auto_parallel()
return network
......@@ -267,11 +300,10 @@ class _ClipGradients(nn.Cell):
new_grads = ()
for grad in grads:
if clip_type == 0:
t = C.clip_by_value(grad, F.tuple_to_array((-clip_value,)),
F.tuple_to_array((clip_value,)))
norm = C.clip_by_value(grad, -clip_value, clip_value)
else:
t = self.clip_by_norm(grad, F.tuple_to_array((clip_value,)))
new_grads = new_grads + (t,)
norm = self.clip_by_norm(grad, clip_value)
new_grads = new_grads + (norm,)
return new_grads
......@@ -292,20 +324,27 @@ class _TrainOneStepWithLossScaleCell(Cell):
r"""
Network training with loss scaling.
This is a training step with loss scaling. It takes a network, an optimizer and possibly a scale update
Cell as args. The loss scale value can be updated in both host side or device side. The
TrainOneStepWithLossScaleCell will be compiled to be graph which takes `data`, `label`, `sens` as input
data. The `sens` is acting as loss scaling value. If you want to update it on host side, the value should
be provided. If `sens` is not given, the loss scale update logic should be provied by `scale_update_cell`.
If `scale_update_cell` is not None and `sens` is provided, the `scale_update_cell` will be ignored.
This is a training step with loss scaling. It takes a network, an optimizer
and possibly a scale update Cell as args. The loss scale value can be
updated in both host side or device side. The TrainOneStepWithLossScaleCell
will be compiled to be graph which takes `data`, `label`, `sens` as input
data. The `sens` is acting as loss scaling value. If you want to update it
on host side, the value should be provided. If `sens` is not given, the loss
scale update logic should be provied by `scale_update_cell`. If
`scale_update_cell` is not None and `sens` is provided, the
`scale_update_cell` will be ignored.
Args:
network (Cell): The training network.
optimizer (Cell): Optimizer for updating the weights.
scale_update_cell(Cell): The loss scaling update logic cell. Default: None.
micro_batches (int): The number of small batches split from an original batch. Default: None.
norm_clip (float): Use to clip the bound, if set 1, will return the original data. Default: 1.0.
mech (Mechanisms): The object can generate the different type of noise. Default: None.
scale_update_cell(Cell): The loss scaling update logic cell.
Default: None.
micro_batches (int): The number of small batches split from an original
batch. Default: None.
norm_clip (Tensor): Use to clip the bound, if set 1, will return the
original data. Default: 1.0.
noise_mech (Mechanisms): The object can generate the different type of
noise. Default: None.
Inputs:
- **inputs** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
......@@ -320,7 +359,9 @@ class _TrainOneStepWithLossScaleCell(Cell):
- **loss_scale** (Tensor) - Tensor with shape :math:`()`.
"""
def __init__(self, network, optimizer, scale_update_cell=None, micro_batches=None, norm_clip=1.0, mech=None):
def __init__(self, network, optimizer, scale_update_cell=None,
micro_batches=None, norm_clip=1.0, noise_mech=None,
clip_mech=None):
super(_TrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False)
self.network = network
self.network.set_grad()
......@@ -346,39 +387,54 @@ class _TrainOneStepWithLossScaleCell(Cell):
self.allreduce = P.AllReduce()
self.parallel_mode = _get_parallel_mode()
self.grad_reducer = F.identity
self.reducer_flag = self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]
self.reducer_flag = self.parallel_mode in [ParallelMode.DATA_PARALLEL,
ParallelMode.HYBRID_PARALLEL]
if self.reducer_flag:
mean = _get_mirror_mean()
degree = _get_device_num()
self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
self.grad_reducer = DistributedGradReducer(optimizer.parameters,
mean, degree)
self.is_distributed = self.parallel_mode != ParallelMode.STAND_ALONE
self.loss_scale = None
self.loss_scaling_manager = scale_update_cell
if scale_update_cell:
self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32),
name="loss_scale")
self.loss_scale = Parameter(
Tensor(scale_update_cell.get_loss_scale(),
dtype=mstype.float32),
name="loss_scale")
self.add_flags(has_effect=True)
# dp params
self._micro_batches = micro_batches
norm_clip = check_param_type('norm_clip', norm_clip, float)
self._l2_norm = check_value_positive('norm_clip', norm_clip)
self._norm_clip = norm_clip
self._split = P.Split(0, self._micro_batches)
self._clip_by_global_norm = _ClipGradients()
self._mech = mech
self._noise_mech = noise_mech
self._clip_mech = clip_mech
self._add = P.TensorAdd()
self._norm = nn.Norm()
self._tuple_add = _TupleAdd()
self._hyper_map = C.HyperMap()
self._micro_float = Tensor(micro_batches, mstype.float32)
self._mech_param_updater = None
if self._mech is not None and self._mech._decay_policy is not None:
self._mech_param_updater = _MechanismsParamsUpdater(policy=self._mech._decay_policy,
decay_rate=self._mech._noise_decay_rate,
cur_noise_multiplier=
self._mech._noise_multiplier,
init_noise_multiplier=
self._mech._initial_noise_multiplier)
self._zero = Tensor(0, mstype.float32)
self._assign = P.Assign()
self._div = P.Div()
self._sqrt = P.Sqrt()
self._reduce_sum = P.ReduceSum()
self._square_all = P.Square()
self._less = P.Less()
self._cast = P.Cast()
self._noise_mech_param_updater = None
if self._noise_mech is not None and self._noise_mech._decay_policy is not None:
self._noise_mech_param_updater = _MechanismsParamsUpdater(
policy=self._noise_mech._decay_policy,
decay_rate=self._noise_mech._noise_decay_rate,
cur_noise_multiplier=
self._noise_mech._noise_multiplier,
init_noise_multiplier=
self._noise_mech._initial_noise_multiplier)
def construct(self, data, label, sens=None):
"""
......@@ -402,30 +458,62 @@ class _TrainOneStepWithLossScaleCell(Cell):
record_labels = self._split(label)
# first index
loss = self.network(record_datas[0], record_labels[0])
scaling_sens_filled = C.ones_like(loss) * F.cast(scaling_sens, F.dtype(loss))
record_grad = self.grad(self.network, weights)(record_datas[0], record_labels[0], scaling_sens_filled)
record_grad = self._clip_by_global_norm(record_grad, GRADIENT_CLIP_TYPE, self._l2_norm)
scaling_sens_filled = C.ones_like(loss)*F.cast(scaling_sens,
F.dtype(loss))
record_grad = self.grad(self.network, weights)(record_datas[0],
record_labels[0],
scaling_sens_filled)
beta = self._zero
square_sum = self._zero
for grad in record_grad:
square_sum = self._add(square_sum,
self._reduce_sum(self._square_all(grad)))
norm_grad = self._sqrt(square_sum)
beta = self._add(beta,
self._cast(self._less(norm_grad, self._norm_clip),
mstype.float32))
record_grad = self._clip_by_global_norm(record_grad, GRADIENT_CLIP_TYPE,
self._norm_clip)
grads = record_grad
total_loss = loss
for i in range(1, self._micro_batches):
loss = self.network(record_datas[i], record_labels[i])
scaling_sens_filled = C.ones_like(loss) * F.cast(scaling_sens, F.dtype(loss))
record_grad = self.grad(self.network, weights)(record_datas[i], record_labels[i], scaling_sens_filled)
record_grad = self._clip_by_global_norm(record_grad, GRADIENT_CLIP_TYPE, self._l2_norm)
scaling_sens_filled = C.ones_like(loss)*F.cast(scaling_sens,
F.dtype(loss))
record_grad = self.grad(self.network, weights)(record_datas[i],
record_labels[i],
scaling_sens_filled)
square_sum = self._zero
for grad in record_grad:
square_sum = self._add(square_sum,
self._reduce_sum(self._square_all(grad)))
norm_grad = self._sqrt(square_sum)
beta = self._add(beta,
self._cast(self._less(norm_grad, self._norm_clip),
mstype.float32))
record_grad = self._clip_by_global_norm(record_grad,
GRADIENT_CLIP_TYPE,
self._norm_clip)
grads = self._tuple_add(grads, record_grad)
total_loss = P.TensorAdd()(total_loss, loss)
loss = P.Div()(total_loss, self._micro_float)
beta = self._div(beta, self._micro_batches)
if self._mech is not None:
if self._noise_mech is not None:
grad_noise_tuple = ()
for grad_item in grads:
grad_noise = self._mech(grad_item)
grad_noise_tuple = grad_noise_tuple + (grad_noise,)
grads = self._tuple_add(grads, grad_noise_tuple)
grads = self._hyper_map(F.partial(_grad_scale, self._micro_float), grads)
grads = self._hyper_map(F.partial(_grad_scale, self._micro_float),
grads)
# update mech parameters
if self._mech_param_updater is not None:
multiplier = self._mech_param_updater()
if self._noise_mech_param_updater is not None:
multiplier = self._noise_mech_param_updater()
loss = F.depend(loss, multiplier)
grads = self.hyper_map(F.partial(_grad_scale, scaling_sens), grads)
......@@ -456,6 +544,10 @@ class _TrainOneStepWithLossScaleCell(Cell):
else:
opt = self.optimizer(grads)
ret = (loss, cond, scaling_sens)
if self._clip_mech is not None:
next_norm_clip = self._clip_mech(beta, self._norm_clip)
P.assign(self._norm_clip, next_norm_clip)
return F.depend(ret, opt)
......@@ -463,17 +555,22 @@ class _TrainOneStepCell(Cell):
r"""
Network training package class.
Wraps the network with an optimizer. The resulting Cell be trained with input data and label.
Backward graph will be created in the construct function to do parameter updating. Different
parallel modes are available to run the training.
Wraps the network with an optimizer. The resulting Cell be trained with
input data and label. Backward graph will be created in the construct
function to do parameter updating. Different parallel modes are available
to run the training.
Args:
network (Cell): The training network.
optimizer (Cell): Optimizer for updating the weights.
sens (Number): The scaling number to be filled as the input of back propagation. Default value is 1.0.
micro_batches (int): The number of small batches split from an original batch. Default: None.
norm_clip (float): Use to clip the bound, if set 1, will return the original data. Default: 1.0.
mech (Mechanisms): The object can generate the different type of noise. Default: None.
sens (Number): The scaling number to be filled as the input of back
propagation. Default value is 1.0.
micro_batches (int): The number of small batches split from an original
batch. Default: None.
norm_clip (Tensor): Use to clip the bound, if set 1, will return the
original data. Default: 1.0.
noise_mech (Mechanisms): The object can generate the different type
of noise. Default: None.
Inputs:
- **data** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
......@@ -483,7 +580,9 @@ class _TrainOneStepCell(Cell):
Tensor, a scalar Tensor with shape :math:`()`.
"""
def __init__(self, network, optimizer, sens=1.0, micro_batches=None, norm_clip=1.0, mech=None):
def __init__(self, network, optimizer, norm_clip=1.0, sens=1.0,
micro_batches=None,
noise_mech=None, clip_mech=None):
super(_TrainOneStepCell, self).__init__(auto_prefix=False)
self.network = network
self.network.set_grad()
......@@ -495,36 +594,51 @@ class _TrainOneStepCell(Cell):
self.reducer_flag = False
self.grad_reducer = None
parallel_mode = _get_parallel_mode()
if parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
if parallel_mode in (
ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
self.reducer_flag = True
if self.reducer_flag:
mean = _get_mirror_mean()
degree = _get_device_num()
self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
self.grad_reducer = DistributedGradReducer(optimizer.parameters,
mean, degree)
# dp params
if micro_batches is None:
msg = 'micro_batches must give in differential privacy, but got value: {}'.format(micro_batches)
msg = 'micro_batches must give in differential privacy, but got value: {}'.format(
micro_batches)
LOGGER.error(TAG, msg)
raise ValueError(msg)
self._micro_batches = micro_batches
norm_clip = check_param_type('norm_clip', norm_clip, float)
self._l2_norm = check_value_positive('norm_clip', norm_clip)
self._norm_clip = norm_clip
self._split = P.Split(0, self._micro_batches)
self._clip_by_global_norm = _ClipGradients()
self._mech = mech
self._noise_mech = noise_mech
self._clip_mech = clip_mech
self._tuple_add = _TupleAdd()
self._add = P.TensorAdd()
self._norm = nn.Norm()
self._hyper_map = C.HyperMap()
self._zero = Tensor(0, mstype.float32)
self._assign = P.Assign()
self._div = P.Div()
self._sqrt = P.Sqrt()
self._reduce_sum = P.ReduceSum()
self._square_all = P.Square()
self._less = P.Less()
self._cast = P.Cast()
self._micro_float = Tensor(micro_batches, mstype.float32)
self._mech_param_updater = None
if self._mech is not None and self._mech._decay_policy is not None:
self._mech_param_updater = _MechanismsParamsUpdater(policy=self._mech._decay_policy,
decay_rate=self._mech._noise_decay_rate,
cur_noise_multiplier=
self._mech._noise_multiplier,
init_noise_multiplier=
self._mech._initial_noise_multiplier)
self._noise_mech_param_updater = None
if self._noise_mech is not None and self._noise_mech._decay_policy is not None:
self._noise_mech_param_updater = _MechanismsParamsUpdater(
policy=self._noise_mech._decay_policy,
decay_rate=self._noise_mech._noise_decay_rate,
cur_noise_multiplier=
self._noise_mech._noise_multiplier,
init_noise_multiplier=
self._noise_mech._initial_noise_multiplier)
def construct(self, data, label):
"""
......@@ -535,32 +649,65 @@ class _TrainOneStepCell(Cell):
record_labels = self._split(label)
loss = self.network(record_datas[0], record_labels[0])
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
record_grad = self.grad(self.network, weights)(record_datas[0], record_labels[0], sens)
record_grad = self._clip_by_global_norm(record_grad, GRADIENT_CLIP_TYPE, self._l2_norm)
record_grad = self.grad(self.network, weights)(record_datas[0],
record_labels[0], sens)
beta = self._zero
square_sum = self._zero
for grad in record_grad:
square_sum = self._add(square_sum,
self._reduce_sum(self._square_all(grad)))
norm_grad = self._sqrt(square_sum)
beta = self._add(beta,
self._cast(self._less(norm_grad, self._norm_clip),
mstype.float32))
record_grad = self._clip_by_global_norm(record_grad, GRADIENT_CLIP_TYPE,
self._norm_clip)
grads = record_grad
total_loss = loss
for i in range(1, self._micro_batches):
loss = self.network(record_datas[i], record_labels[i])
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
record_grad = self.grad(self.network, weights)(record_datas[i], record_labels[i], sens)
record_grad = self._clip_by_global_norm(record_grad, GRADIENT_CLIP_TYPE, self._l2_norm)
record_grad = self.grad(self.network, weights)(record_datas[i],
record_labels[i],
sens)
square_sum = self._zero
for grad in record_grad:
square_sum = self._add(square_sum,
self._reduce_sum(self._square_all(grad)))
norm_grad = self._sqrt(square_sum)
beta = self._add(beta,
self._cast(self._less(norm_grad, self._norm_clip),
mstype.float32))
record_grad = self._clip_by_global_norm(record_grad,
GRADIENT_CLIP_TYPE,
self._norm_clip)
grads = self._tuple_add(grads, record_grad)
total_loss = P.TensorAdd()(total_loss, loss)
loss = P.Div()(total_loss, self._micro_float)
loss = self._div(total_loss, self._micro_float)
beta = self._div(beta, self._micro_batches)
if self._mech is not None:
if self._noise_mech is not None:
grad_noise_tuple = ()
for grad_item in grads:
grad_noise = self._mech(grad_item)
grad_noise = self._noise_mech(grad_item)
grad_noise_tuple = grad_noise_tuple + (grad_noise,)
grads = self._tuple_add(grads, grad_noise_tuple)
grads = self._hyper_map(F.partial(_grad_scale, self._micro_float), grads)
grads = self._hyper_map(F.partial(_grad_scale, self._micro_float),
grads)
# update mech parameters
if self._mech_param_updater is not None:
multiplier = self._mech_param_updater()
if self._noise_mech_param_updater is not None:
multiplier = self._noise_mech_param_updater()
loss = F.depend(loss, multiplier)
if self.reducer_flag:
# apply grad reducer on grads
grads = self.grad_reducer(grads)
if self._clip_mech is not None:
next_norm_clip = self._clip_mech(beta, self._norm_clip)
self._norm_clip = self._assign(self._norm_clip, next_norm_clip)
loss = F.depend(loss, next_norm_clip)
return F.depend(loss, self.optimizer(grads))
......@@ -19,9 +19,11 @@ import pytest
from mindspore import context
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindarmour.diff_privacy import GaussianRandom
from mindarmour.diff_privacy import NoiseGaussianRandom
from mindarmour.diff_privacy import AdaGaussianRandom
from mindarmour.diff_privacy import MechanismsFactory
from mindarmour.diff_privacy import AdaClippingWithGaussianRandom
from mindarmour.diff_privacy import NoiseMechanismsFactory
from mindarmour.diff_privacy import ClipMechanismsFactory
@pytest.mark.level0
......@@ -33,7 +35,7 @@ def test_graph_gaussian():
grad = Tensor([0.3, 0.2, 0.4], mstype.float32)
norm_bound = 1.0
initial_noise_multiplier = 0.1
net = GaussianRandom(norm_bound, initial_noise_multiplier)
net = NoiseGaussianRandom(norm_bound, initial_noise_multiplier)
res = net(grad)
print(res)
......@@ -47,7 +49,7 @@ def test_pynative_gaussian():
grad = Tensor([0.3, 0.2, 0.4], mstype.float32)
norm_bound = 1.0
initial_noise_multiplier = 0.1
net = GaussianRandom(norm_bound, initial_noise_multiplier)
net = NoiseGaussianRandom(norm_bound, initial_noise_multiplier)
res = net(grad)
print(res)
......@@ -80,13 +82,13 @@ def test_graph_factory():
initial_noise_multiplier = 0.1
alpha = 0.5
decay_policy = 'Step'
noise_mechanism = MechanismsFactory()
noise_mechanism = NoiseMechanismsFactory()
noise_construct = noise_mechanism.create('Gaussian',
norm_bound,
initial_noise_multiplier)
noise = noise_construct(grad)
print('Gaussian noise: ', noise)
ada_mechanism = MechanismsFactory()
ada_mechanism = NoiseMechanismsFactory()
ada_noise_construct = ada_mechanism.create('AdaGaussian',
norm_bound,
initial_noise_multiplier,
......@@ -124,13 +126,13 @@ def test_pynative_factory():
initial_noise_multiplier = 0.1
alpha = 0.5
decay_policy = 'Step'
noise_mechanism = MechanismsFactory()
noise_mechanism = NoiseMechanismsFactory()
noise_construct = noise_mechanism.create('Gaussian',
norm_bound,
initial_noise_multiplier)
noise = noise_construct(grad)
print('Gaussian noise: ', noise)
ada_mechanism = MechanismsFactory()
ada_mechanism = NoiseMechanismsFactory()
ada_noise_construct = ada_mechanism.create('AdaGaussian',
norm_bound,
initial_noise_multiplier,
......@@ -151,7 +153,7 @@ def test_pynative_exponential():
initial_noise_multiplier = 0.1
alpha = 0.5
decay_policy = 'Exp'
ada_mechanism = MechanismsFactory()
ada_mechanism = NoiseMechanismsFactory()
ada_noise_construct = ada_mechanism.create('AdaGaussian',
norm_bound,
initial_noise_multiplier,
......@@ -172,7 +174,7 @@ def test_graph_exponential():
initial_noise_multiplier = 0.1
alpha = 0.5
decay_policy = 'Exp'
ada_mechanism = MechanismsFactory()
ada_mechanism = NoiseMechanismsFactory()
ada_noise_construct = ada_mechanism.create('AdaGaussian',
norm_bound,
initial_noise_multiplier,
......@@ -180,3 +182,107 @@ def test_graph_exponential():
decay_policy=decay_policy)
ada_noise = ada_noise_construct(grad)
print('ada noise: ', ada_noise)
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_ada_clip_gaussian_random_pynative():
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
decay_policy = 'Linear'
beta = Tensor(0.5, mstype.float32)
norm_clip = Tensor(1.0, mstype.float32)
beta_stddev = 0.1
learning_rate = 0.1
target_unclipped_quantile = 0.3
ada_clip = AdaClippingWithGaussianRandom(decay_policy=decay_policy,
learning_rate=learning_rate,
target_unclipped_quantile=target_unclipped_quantile,
fraction_stddev=beta_stddev,
seed=1)
next_norm_clip = ada_clip(beta, norm_clip)
print('Liner next norm clip:', next_norm_clip)
decay_policy = 'Geometric'
ada_clip = AdaClippingWithGaussianRandom(decay_policy=decay_policy,
learning_rate=learning_rate,
target_unclipped_quantile=target_unclipped_quantile,
fraction_stddev=beta_stddev,
seed=1)
next_norm_clip = ada_clip(beta, norm_clip)
print('Geometric next norm clip:', next_norm_clip)
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_ada_clip_gaussian_random_graph():
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
decay_policy = 'Linear'
beta = Tensor(0.5, mstype.float32)
norm_clip = Tensor(1.0, mstype.float32)
beta_stddev = 0.1
learning_rate = 0.1
target_unclipped_quantile = 0.3
ada_clip = AdaClippingWithGaussianRandom(decay_policy=decay_policy,
learning_rate=learning_rate,
target_unclipped_quantile=target_unclipped_quantile,
fraction_stddev=beta_stddev,
seed=1)
next_norm_clip = ada_clip(beta, norm_clip)
print('Liner next norm clip:', next_norm_clip)
decay_policy = 'Geometric'
ada_clip = AdaClippingWithGaussianRandom(decay_policy=decay_policy,
learning_rate=learning_rate,
target_unclipped_quantile=target_unclipped_quantile,
fraction_stddev=beta_stddev,
seed=1)
next_norm_clip = ada_clip(beta, norm_clip)
print('Geometric next norm clip:', next_norm_clip)
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_pynative_clip_mech_factory():
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
decay_policy = 'Linear'
beta = Tensor(0.5, mstype.float32)
norm_clip = Tensor(1.0, mstype.float32)
beta_stddev = 0.1
learning_rate = 0.1
target_unclipped_quantile = 0.3
clip_mechanism = ClipMechanismsFactory()
ada_clip = clip_mechanism.create('Gaussian',
decay_policy=decay_policy,
learning_rate=learning_rate,
target_unclipped_quantile=target_unclipped_quantile,
fraction_stddev=beta_stddev)
next_norm_clip = ada_clip(beta, norm_clip)
print('next_norm_clip: ', next_norm_clip)
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_graph_clip_mech_factory():
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
decay_policy = 'Linear'
beta = Tensor(0.5, mstype.float32)
norm_clip = Tensor(1.0, mstype.float32)
beta_stddev = 0.1
learning_rate = 0.1
target_unclipped_quantile = 0.3
clip_mechanism = ClipMechanismsFactory()
ada_clip = clip_mechanism.create('Gaussian',
decay_policy=decay_policy,
learning_rate=learning_rate,
target_unclipped_quantile=target_unclipped_quantile,
fraction_stddev=beta_stddev)
next_norm_clip = ada_clip(beta, norm_clip)
print('next_norm_clip: ', next_norm_clip)
......@@ -22,7 +22,8 @@ from mindspore import context
import mindspore.dataset as ds
from mindarmour.diff_privacy import DPModel
from mindarmour.diff_privacy import MechanismsFactory
from mindarmour.diff_privacy import NoiseMechanismsFactory
from mindarmour.diff_privacy import ClipMechanismsFactory
from mindarmour.diff_privacy import DPOptimizerClassFactory
from test_network import LeNet5
......@@ -30,10 +31,12 @@ from test_network import LeNet5
def dataset_generator(batch_size, batches):
"""mock training data."""
data = np.random.random((batches * batch_size, 1, 32, 32)).astype(np.float32)
label = np.random.randint(0, 10, batches * batch_size).astype(np.int32)
data = np.random.random((batches*batch_size, 1, 32, 32)).astype(
np.float32)
label = np.random.randint(0, 10, batches*batch_size).astype(np.int32)
for i in range(batches):
yield data[i * batch_size:(i + 1) * batch_size], label[i * batch_size:(i + 1) * batch_size]
yield data[i*batch_size:(i + 1)*batch_size],\
label[i*batch_size:(i + 1)*batch_size]
@pytest.mark.level0
......@@ -55,16 +58,24 @@ def test_dp_model_with_pynative_mode():
factory_opt.set_mechanisms('Gaussian',
norm_bound=norm_clip,
initial_noise_multiplier=initial_noise_multiplier)
net_opt = factory_opt.create('Momentum')(network.trainable_params(), learning_rate=0.1, momentum=0.9)
net_opt = factory_opt.create('Momentum')(network.trainable_params(),
learning_rate=0.1, momentum=0.9)
clip_mech = ClipMechanismsFactory().create('Gaussian',
decay_policy='Linear',
learning_rate=0.01,
target_unclipped_quantile=0.9,
fraction_stddev=0.01)
model = DPModel(micro_batches=micro_batches,
norm_clip=norm_clip,
mech=None,
clip_mech=clip_mech,
noise_mech=None,
network=network,
loss_fn=loss,
optimizer=net_opt,
metrics=None)
ms_ds = ds.GeneratorDataset(dataset_generator(batch_size, batches), ['data', 'label'])
ms_ds.set_dataset_size(batch_size * batches)
ms_ds = ds.GeneratorDataset(dataset_generator(batch_size, batches),
['data', 'label'])
ms_ds.set_dataset_size(batch_size*batches)
model.train(epochs, ms_ds, dataset_sink_mode=False)
......@@ -82,19 +93,27 @@ def test_dp_model_with_graph_mode():
batches = 128
epochs = 1
loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
mech = MechanismsFactory().create('Gaussian',
norm_bound=norm_clip,
initial_noise_multiplier=initial_noise_multiplier)
net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.1, momentum=0.9)
noise_mech = NoiseMechanismsFactory().create('Gaussian',
norm_bound=norm_clip,
initial_noise_multiplier=initial_noise_multiplier)
clip_mech = ClipMechanismsFactory().create('Gaussian',
decay_policy='Linear',
learning_rate=0.01,
target_unclipped_quantile=0.9,
fraction_stddev=0.01)
net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.1,
momentum=0.9)
model = DPModel(micro_batches=2,
clip_mech=clip_mech,
norm_clip=norm_clip,
mech=mech,
noise_mech=noise_mech,
network=network,
loss_fn=loss,
optimizer=net_opt,
metrics=None)
ms_ds = ds.GeneratorDataset(dataset_generator(batch_size, batches), ['data', 'label'])
ms_ds.set_dataset_size(batch_size * batches)
ms_ds = ds.GeneratorDataset(dataset_generator(batch_size, batches),
['data', 'label'])
ms_ds.set_dataset_size(batch_size*batches)
model.train(epochs, ms_ds, dataset_sink_mode=False)
......@@ -112,17 +131,25 @@ def test_dp_model_with_graph_mode_ada_gaussian():
batches = 128
epochs = 1
loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
mech = MechanismsFactory().create('AdaGaussian',
norm_bound=norm_clip,
initial_noise_multiplier=initial_noise_multiplier)
net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.1, momentum=0.9)
noise_mech = NoiseMechanismsFactory().create('AdaGaussian',
norm_bound=norm_clip,
initial_noise_multiplier=initial_noise_multiplier)
clip_mech = ClipMechanismsFactory().create('Gaussian',
decay_policy='Linear',
learning_rate=0.01,
target_unclipped_quantile=0.9,
fraction_stddev=0.01)
net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.1,
momentum=0.9)
model = DPModel(micro_batches=2,
clip_mech=clip_mech,
norm_clip=norm_clip,
mech=mech,
noise_mech=noise_mech,
network=network,
loss_fn=loss,
optimizer=net_opt,
metrics=None)
ms_ds = ds.GeneratorDataset(dataset_generator(batch_size, batches), ['data', 'label'])
ms_ds.set_dataset_size(batch_size * batches)
ms_ds = ds.GeneratorDataset(dataset_generator(batch_size, batches),
['data', 'label'])
ms_ds.set_dataset_size(batch_size*batches)
model.train(epochs, ms_ds, dataset_sink_mode=False)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册