From 9c37a110365b9358c8bd41e164a645f649b445d3 Mon Sep 17 00:00:00 2001
From: zhenghuanhuan
Date: Thu, 28 May 2020 20:17:12 +0800
Subject: [PATCH] [CT][MA][DP] param of DPModel: norm_clip, use default,
 runtime error; [DP] param of DPModel: norm_clip=0, runtime error;
 [diff_privacy][Perf] use DPModel, the performance of LeNet deteriorates
 too much. https://gitee.com/mindspore/dashboard?issue_id=I1IMD7
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

https://gitee.com/mindspore/dashboard?issue_id=I1IQ11
https://gitee.com/mindspore/dashboard?issue_id=I1IPKH
---
 example/mnist_demo/lenet5_dp_model_train.py |  3 +-
 .../diff_privacy/optimizer/optimizer.py     |  3 +-
 mindarmour/diff_privacy/train/model.py      | 75 ++++++++-----------
 .../python/diff_privacy/test_model_train.py | 13 ++--
 4 files changed, 41 insertions(+), 53 deletions(-)

diff --git a/example/mnist_demo/lenet5_dp_model_train.py b/example/mnist_demo/lenet5_dp_model_train.py
index 089c23f..f01bfaf 100644
--- a/example/mnist_demo/lenet5_dp_model_train.py
+++ b/example/mnist_demo/lenet5_dp_model_train.py
@@ -123,10 +123,9 @@ if __name__ == "__main__":
     net_opt = gaussian_mech.create('SGD')(params=network.trainable_params(),
                                           learning_rate=cfg.lr,
                                           momentum=cfg.momentum)
-    micro_size = int(cfg.batch_size // args.micro_batches)
     rdp_monitor = PrivacyMonitorFactory.create('rdp',
                                                num_samples=60000,
-                                               batch_size=micro_size,
+                                               batch_size=cfg.batch_size,
                                                initial_noise_multiplier=args.initial_noise_multiplier,
                                                per_print_times=10)
     model = DPModel(micro_batches=args.micro_batches,
diff --git a/mindarmour/diff_privacy/optimizer/optimizer.py b/mindarmour/diff_privacy/optimizer/optimizer.py
index e16799f..a844e79 100644
--- a/mindarmour/diff_privacy/optimizer/optimizer.py
+++ b/mindarmour/diff_privacy/optimizer/optimizer.py
@@ -19,6 +19,7 @@ from mindspore import nn
 from mindspore import Tensor
 
 from mindarmour.diff_privacy.mechanisms.mechanisms import MechanismsFactory
+from mindarmour.utils._check_param import check_int_positive
 
 
 class DPOptimizerClassFactory:
@@ -41,7 +42,7 @@ class DPOptimizerClassFactory:
     def __init__(self, micro_batches=None):
         self._mech_factory = MechanismsFactory()
         self.mech = None
-        self._micro_batches = micro_batches
+        self._micro_batches = check_int_positive('micro_batches', micro_batches)
 
     def set_mechanisms(self, policy, *args, **kwargs):
         """
diff --git a/mindarmour/diff_privacy/train/model.py b/mindarmour/diff_privacy/train/model.py
index 20e14a7..434bbe0 100644
--- a/mindarmour/diff_privacy/train/model.py
+++ b/mindarmour/diff_privacy/train/model.py
@@ -48,6 +48,8 @@ from mindspore.nn.wrap.loss_scale import _grad_overflow
 from mindspore.nn import Cell
 from mindspore import ParameterTuple
 
+from mindarmour.utils._check_param import check_param_type
+from mindarmour.utils._check_param import check_value_positive
 
 GRADIENT_CLIP_TYPE = 1
 grad_scale = C.MultitypeFuncGraph("grad_scale")
@@ -56,6 +58,7 @@ reciprocal = P.Reciprocal()
 
 @grad_scale.register("Tensor", "Tensor")
 def tensor_grad_scale(scale, grad):
+    """Scale the gradient by the reciprocal of the loss scale."""
     return grad*reciprocal(scale)
 
 
@@ -65,7 +68,7 @@ class DPModel(Model):
 
     Args:
         micro_batches (int): The number of small batches split from an origianl batch. Default: None.
-        norm_clip (float): Use to clip the bound, if set 1, will retun the original data. Default: None.
+        norm_clip (float): The l2 norm bound used to clip the gradients; if set to 1.0, the original data will be returned. Default: 1.0.
         dp_mech (Mechanisms): The object can generate the different type of noise. Default: None.
 
     Examples:
@@ -94,7 +97,7 @@ class DPModel(Model):
         >>>                                     norm_bound=args.l2_norm_bound,
         >>>                                     initial_noise_multiplier=args.initial_noise_multiplier)
         >>> model = DPModel(micro_batches=2,
-        >>>                 norm_clip=1,
+        >>>                 norm_clip=1.0,
         >>>                 dp_mech=gaussian_mech.mech,
         >>>                 network=net,
         >>>                 loss_fn=loss,
@@ -103,16 +106,17 @@ class DPModel(Model):
         >>> dataset = get_dataset()
         >>> model.train(2, dataset)
     """
-    def __init__(self, micro_batches=None, norm_clip=None, dp_mech=None, **kwargs):
+    def __init__(self, micro_batches=None, norm_clip=1.0, dp_mech=None, **kwargs):
         if micro_batches:
             self._micro_batches = int(micro_batches)
         else:
             self._micro_batches = None
-        self._norm_clip = norm_clip
+        float_norm_clip = check_param_type('l2_norm_clip', norm_clip, float)
+        self._norm_clip = check_value_positive('l2_norm_clip', float_norm_clip)
         self._dp_mech = dp_mech
         super(DPModel, self).__init__(**kwargs)
 
-    def amp_build_train_network(self, network, optimizer, loss_fn=None, level='O0', **kwargs):
+    def _amp_build_train_network(self, network, optimizer, loss_fn=None, level='O0', **kwargs):
         """
         Build the mixed precision training cell automatically.
 
@@ -185,18 +189,18 @@ class DPModel(Model):
         if self._micro_batches:
             if self._optimizer:
                 if self._loss_scale_manager_set:
-                    network = self.amp_build_train_network(network,
-                                                           self._optimizer,
-                                                           self._loss_fn,
-                                                           level=self._amp_level,
-                                                           loss_scale_manager=self._loss_scale_manager,
-                                                           keep_batchnorm_fp32=self._keep_bn_fp32)
+                    network = self._amp_build_train_network(network,
+                                                            self._optimizer,
+                                                            self._loss_fn,
+                                                            level=self._amp_level,
+                                                            loss_scale_manager=self._loss_scale_manager,
+                                                            keep_batchnorm_fp32=self._keep_bn_fp32)
                 else:
-                    network = self.amp_build_train_network(network,
-                                                           self._optimizer,
-                                                           self._loss_fn,
-                                                           level=self._amp_level,
-                                                           keep_batchnorm_fp32=self._keep_bn_fp32)
+                    network = self._amp_build_train_network(network,
+                                                            self._optimizer,
+                                                            self._loss_fn,
+                                                            level=self._amp_level,
+                                                            keep_batchnorm_fp32=self._keep_bn_fp32)
         elif self._loss_fn:
             network = nn.WithLossCell(network, self._loss_fn)
         else:
@@ -273,8 +277,8 @@ class _TrainOneStepWithLossScaleCell(Cell):
         network (Cell): The training network.
         optimizer (Cell): Optimizer for updating the weights.
         scale_update_cell(Cell): The loss scaling update logic cell. Default: None.
-        micro_batches (int): The number of small batches split from an origianl batch. Default: None.
-        l2_norm_clip (float): Use to clip the bound, if set 1, will retun the original data. Default: None.
+        micro_batches (int): The number of small batches split from an original batch. Default: None.
+        l2_norm_clip (float): The l2 norm bound used to clip the gradients; if set to 1.0, the original data will be returned. Default: 1.0.
         mech (Mechanisms): The object can generate the different type of noise. Default: None.
 
     Inputs:
@@ -288,21 +292,9 @@ class _TrainOneStepWithLossScaleCell(Cell):
         - **loss** (Tensor) - Tensor with shape :math:`()`.
         - **overflow** (Tensor) - Tensor with shape :math:`()`, type is bool.
         - **loss_scale** (Tensor) - Tensor with shape :math:`()`.
-
-    Examples:
-        >>> net_with_loss = Net()
-        >>> optimizer = nn.Momentum(net_with_loss.trainable_params(), learning_rate=0.1, momentum=0.9)
-        >>> manager = nn.DynamicLossScaleUpdateCell(loss_scale_value=2**12, scale_factor=2, scale_window=1000)
-        >>> train_network = nn.TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_update_cell=manager)
-        >>> train_network.set_train()
-        >>>
-        >>> inputs = Tensor(np.ones([16, 16]).astype(np.float32))
-        >>> label = Tensor(np.zeros([16, 16]).astype(np.float32))
-        >>> scaling_sens = Tensor(np.full((1), np.finfo(np.float32).max), dtype=mindspore.float32)
-        >>> output = train_network(inputs, label, scaling_sens)
     """
 
-    def __init__(self, network, optimizer, scale_update_cell=None, micro_batches=None, l2_norm_clip=None, mech=None):
+    def __init__(self, network, optimizer, scale_update_cell=None, micro_batches=None, l2_norm_clip=1.0, mech=None):
         super(_TrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False)
         self.network = network
         self.network.add_flags(defer_inline=True)
@@ -343,7 +335,8 @@ class _TrainOneStepWithLossScaleCell(Cell):
 
         # dp params
         self._micro_batches = micro_batches
-        self._l2_norm = l2_norm_clip
+        float_norm_clip = check_param_type('l2_norm_clip', l2_norm_clip, float)
+        self._l2_norm = check_value_positive('l2_norm_clip', float_norm_clip)
         self._split = P.Split(0, self._micro_batches)
         self._clip_by_global_norm = _ClipGradients()
         self._mech = mech
@@ -435,9 +428,9 @@ class _TrainOneStepCell(Cell):
     Args:
         network (Cell): The training network.
         optimizer (Cell): Optimizer for updating the weights.
-        sens (Number): The scaling number to be filled as the input of backpropagation. Default value is 1.0.
-        micro_batches (int): The number of small batches split from an origianl batch. Default: None.
-        l2_norm_clip (float): Use to clip the bound, if set 1, will retun the original data. Default: None.
+        sens (Number): The scaling number to be filled as the input of backpropagation. Default value is 1.0.
+        micro_batches (int): The number of small batches split from an original batch. Default: None.
+        l2_norm_clip (float): The l2 norm bound used to clip the gradients; if set to 1.0, the original data will be returned. Default: 1.0.
         mech (Mechanisms): The object can generate the different type of noise. Default: None.
 
     Inputs:
@@ -446,16 +439,9 @@ class _TrainOneStepCell(Cell):
 
     Outputs:
         Tensor, a scalar Tensor with shape :math:`()`.
-
-    Examples:
-        >>> net = Net()
-        >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits()
-        >>> optim = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
-        >>> loss_net = nn.WithLossCell(net, loss_fn)
-        >>> train_net = nn.TrainOneStepCell(loss_net, optim)
     """
 
-    def __init__(self, network, optimizer, sens=1.0, micro_batches=None, l2_norm_clip=None, mech=None):
+    def __init__(self, network, optimizer, sens=1.0, micro_batches=None, l2_norm_clip=1.0, mech=None):
         super(_TrainOneStepCell, self).__init__(auto_prefix=False)
         self.network = network
         self.network.add_flags(defer_inline=True)
@@ -475,7 +461,8 @@ class _TrainOneStepCell(Cell):
 
         # dp params
         self._micro_batches = micro_batches
-        self._l2_norm = l2_norm_clip
+        float_norm_clip = check_param_type('l2_norm_clip', l2_norm_clip, float)
+        self._l2_norm = check_value_positive('l2_norm_clip', float_norm_clip)
         self._split = P.Split(0, self._micro_batches)
         self._clip_by_global_norm = _ClipGradients()
         self._mech = mech
diff --git a/tests/ut/python/diff_privacy/test_model_train.py b/tests/ut/python/diff_privacy/test_model_train.py
index c20bb84..7373967 100644
--- a/tests/ut/python/diff_privacy/test_model_train.py
+++ b/tests/ut/python/diff_privacy/test_model_train.py
@@ -18,7 +18,6 @@ import pytest
 import numpy as np
 
 from mindspore import nn
-from mindspore.nn import SGD
 from mindspore.model_zoo.lenet import LeNet5
 from mindspore import context
 import mindspore.dataset as ds
@@ -43,22 +42,24 @@ def test_dp_model():
     context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
     l2_norm_bound = 1.0
     initial_noise_multiplier = 0.01
-    net = LeNet5()
+    network = LeNet5()
     batch_size = 32
     batches = 128
     epochs = 1
     loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
-    optim = SGD(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
-    gaussian_mech = DPOptimizerClassFactory()
+    gaussian_mech = DPOptimizerClassFactory(micro_batches=2)
     gaussian_mech.set_mechanisms('Gaussian',
                                  norm_bound=l2_norm_bound,
                                  initial_noise_multiplier=initial_noise_multiplier)
+    net_opt = gaussian_mech.create('SGD')(params=network.trainable_params(),
+                                          learning_rate=0.1,
+                                          momentum=0.9)
     model = DPModel(micro_batches=2,
                     norm_clip=l2_norm_bound,
                     dp_mech=gaussian_mech.mech,
-                    network=net,
+                    network=network,
                     loss_fn=loss,
-                    optimizer=optim,
+                    optimizer=net_opt,
                     metrics=None)
     ms_ds = ds.GeneratorDataset(dataset_generator(batch_size, batches), ['data', 'label'])
     ms_ds.set_dataset_size(batch_size * batches)
-- 
GitLab
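
---

A minimal usage sketch (not part of the patch itself): after this change, DPOptimizerClassFactory requires a positive integer micro_batches, DPModel requires a positive float norm_clip, and the optimizer is built through the factory rather than created directly. The snippet mirrors the updated unit test; the LeNet5 network, loss function, hyper-parameter values, and import paths are assumptions taken from the diff, not a fixed API contract.

    from mindspore import nn
    from mindspore.model_zoo.lenet import LeNet5
    from mindarmour.diff_privacy import DPModel, DPOptimizerClassFactory

    network = LeNet5()
    loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)

    # micro_batches is now validated with check_int_positive, so passing
    # None or 0 fails fast instead of surfacing as a runtime error later.
    gaussian_mech = DPOptimizerClassFactory(micro_batches=2)
    gaussian_mech.set_mechanisms('Gaussian',
                                 norm_bound=1.0,
                                 initial_noise_multiplier=0.01)
    net_opt = gaussian_mech.create('SGD')(params=network.trainable_params(),
                                          learning_rate=0.1,
                                          momentum=0.9)

    # norm_clip is now validated with check_param_type/check_value_positive,
    # so norm_clip=0 (one of the reported issues) is rejected up front.
    model = DPModel(micro_batches=2,
                    norm_clip=1.0,
                    dp_mech=gaussian_mech.mech,
                    network=network,
                    loss_fn=loss,
                    optimizer=net_opt,
                    metrics=None)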