Commit f3baf9db authored by mindspore-ci-bot, committed by Gitee

!28 fix issue

Merge pull request !28 from zheng-huanhuan/dp_1
......@@ -123,10 +123,9 @@ if __name__ == "__main__":
net_opt = gaussian_mech.create('SGD')(params=network.trainable_params(),
learning_rate=cfg.lr,
momentum=cfg.momentum)
micro_size = int(cfg.batch_size // args.micro_batches)
rdp_monitor = PrivacyMonitorFactory.create('rdp',
num_samples=60000,
batch_size=micro_size,
batch_size=cfg.batch_size,
initial_noise_multiplier=args.initial_noise_multiplier,
per_print_times=10)
model = DPModel(micro_batches=args.micro_batches,
......
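For context on the hunk above: the RDP accountant tracks the sampling ratio of the full logical batch, while `micro_batches` only controls how that batch is split for per-group gradient clipping inside `DPModel`. A minimal sketch of the distinction, using placeholder values for `cfg.batch_size` and `args.micro_batches` (only `num_samples=60000` comes from the diff):

```python
# Illustrative only: why the monitor takes cfg.batch_size rather than micro_size.
num_samples = 60000                 # training-set size passed to the RDP monitor
batch_size = 256                    # cfg.batch_size (assumed value)
micro_batches = 2                   # args.micro_batches (assumed value)

micro_size = int(batch_size // micro_batches)   # only affects clipping granularity
sampling_ratio = batch_size / num_samples       # the quantity the privacy accountant needs

print(micro_size, sampling_ratio)   # 128 0.004266...
```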
......@@ -19,6 +19,7 @@ from mindspore import nn
from mindspore import Tensor
from mindarmour.diff_privacy.mechanisms.mechanisms import MechanismsFactory
from mindarmour.utils._check_param import check_int_positive
class DPOptimizerClassFactory:
......@@ -41,7 +42,7 @@ class DPOptimizerClassFactory:
def __init__(self, micro_batches=None):
self._mech_factory = MechanismsFactory()
self.mech = None
self._micro_batches = micro_batches
self._micro_batches = check_int_positive('micro_batches', micro_batches)
def set_mechanisms(self, policy, *args, **kwargs):
"""
......
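The hunk above adds early validation of `micro_batches` in the factory constructor. A minimal stand-in for `check_int_positive`, assuming it returns the value when valid and raises otherwise (the real helper lives in `mindarmour.utils._check_param`):

```python
# Hypothetical equivalent of the check_int_positive helper referenced in the diff.
def check_int_positive(arg_name, arg_value):
    if isinstance(arg_value, bool) or not isinstance(arg_value, int) or arg_value <= 0:
        raise ValueError('{} must be a positive integer, but got {}'.format(arg_name, arg_value))
    return arg_value

# With this check in the constructor, an invalid value fails fast instead of
# surfacing later when the optimizer splits gradients:
# DPOptimizerClassFactory(micro_batches=None)  ->  ValueError
check_int_positive('micro_batches', 2)   # returns 2
```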
......@@ -48,6 +48,8 @@ from mindspore.nn.wrap.loss_scale import _grad_overflow
from mindspore.nn import Cell
from mindspore import ParameterTuple
from mindarmour.utils._check_param import check_param_type
from mindarmour.utils._check_param import check_value_positive
GRADIENT_CLIP_TYPE = 1
grad_scale = C.MultitypeFuncGraph("grad_scale")
......@@ -56,6 +58,7 @@ reciprocal = P.Reciprocal()
@grad_scale.register("Tensor", "Tensor")
def tensor_grad_scale(scale, grad):
""" grad scaling """
return grad*reciprocal(scale)
......@@ -65,7 +68,7 @@ class DPModel(Model):
Args:
micro_batches (int): The number of small batches split from an original batch. Default: None.
norm_clip (float): Use to clip the bound, if set 1, will return the original data. Default: None.
norm_clip (float): Use to clip the bound, if set 1, will return the original data. Default: 1.0.
dp_mech (Mechanisms): The object can generate the different type of noise. Default: None.
Examples:
......@@ -94,7 +97,7 @@ class DPModel(Model):
>>> norm_bound=args.l2_norm_bound,
>>> initial_noise_multiplier=args.initial_noise_multiplier)
>>> model = DPModel(micro_batches=2,
>>> norm_clip=1,
>>> norm_clip=1.0,
>>> dp_mech=gaussian_mech.mech,
>>> network=net,
>>> loss_fn=loss,
......@@ -103,16 +106,17 @@ class DPModel(Model):
>>> dataset = get_dataset()
>>> model.train(2, dataset)
"""
def __init__(self, micro_batches=None, norm_clip=None, dp_mech=None, **kwargs):
def __init__(self, micro_batches=None, norm_clip=1.0, dp_mech=None, **kwargs):
if micro_batches:
self._micro_batches = int(micro_batches)
else:
self._micro_batches = None
self._norm_clip = norm_clip
float_norm_clip = check_param_type('l2_norm_clip', norm_clip, float)
self._norm_clip = check_value_positive('l2_norm_clip', float_norm_clip)
self._dp_mech = dp_mech
super(DPModel, self).__init__(**kwargs)
def amp_build_train_network(self, network, optimizer, loss_fn=None, level='O0', **kwargs):
def _amp_build_train_network(self, network, optimizer, loss_fn=None, level='O0', **kwargs):
"""
Build the mixed precision training cell automatically.
......@@ -185,18 +189,18 @@ class DPModel(Model):
if self._micro_batches:
if self._optimizer:
if self._loss_scale_manager_set:
network = self.amp_build_train_network(network,
self._optimizer,
self._loss_fn,
level=self._amp_level,
loss_scale_manager=self._loss_scale_manager,
keep_batchnorm_fp32=self._keep_bn_fp32)
network = self._amp_build_train_network(network,
self._optimizer,
self._loss_fn,
level=self._amp_level,
loss_scale_manager=self._loss_scale_manager,
keep_batchnorm_fp32=self._keep_bn_fp32)
else:
network = self.amp_build_train_network(network,
self._optimizer,
self._loss_fn,
level=self._amp_level,
keep_batchnorm_fp32=self._keep_bn_fp32)
network = self._amp_build_train_network(network,
self._optimizer,
self._loss_fn,
level=self._amp_level,
keep_batchnorm_fp32=self._keep_bn_fp32)
elif self._loss_fn:
network = nn.WithLossCell(network, self._loss_fn)
else:
......@@ -273,8 +277,8 @@ class _TrainOneStepWithLossScaleCell(Cell):
network (Cell): The training network.
optimizer (Cell): Optimizer for updating the weights.
scale_update_cell(Cell): The loss scaling update logic cell. Default: None.
micro_batches (int): The number of small batches split from an origianl batch. Default: None.
l2_norm_clip (float): Use to clip the bound, if set 1, will retun the original data. Default: None.
micro_batches (int): The number of small batches split from an original batch. Default: None.
l2_norm_clip (float): Use to clip the bound, if set 1, will return the original data. Default: 1.0.
mech (Mechanisms): The object can generate the different type of noise. Default: None.
Inputs:
......@@ -288,21 +292,9 @@ class _TrainOneStepWithLossScaleCell(Cell):
- **loss** (Tensor) - Tensor with shape :math:`()`.
- **overflow** (Tensor) - Tensor with shape :math:`()`, type is bool.
- **loss_scale** (Tensor) - Tensor with shape :math:`()`.
Examples:
>>> net_with_loss = Net()
>>> optimizer = nn.Momentum(net_with_loss.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> manager = nn.DynamicLossScaleUpdateCell(loss_scale_value=2**12, scale_factor=2, scale_window=1000)
>>> train_network = nn.TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_update_cell=manager)
>>> train_network.set_train()
>>>
>>> inputs = Tensor(np.ones([16, 16]).astype(np.float32))
>>> label = Tensor(np.zeros([16, 16]).astype(np.float32))
>>> scaling_sens = Tensor(np.full((1), np.finfo(np.float32).max), dtype=mindspore.float32)
>>> output = train_network(inputs, label, scaling_sens)
"""
def __init__(self, network, optimizer, scale_update_cell=None, micro_batches=None, l2_norm_clip=None, mech=None):
def __init__(self, network, optimizer, scale_update_cell=None, micro_batches=None, l2_norm_clip=1.0, mech=None):
super(_TrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False)
self.network = network
self.network.add_flags(defer_inline=True)
......@@ -343,7 +335,8 @@ class _TrainOneStepWithLossScaleCell(Cell):
# dp params
self._micro_batches = micro_batches
self._l2_norm = l2_norm_clip
float_norm_clip = check_param_type('l2_norm_clip', l2_norm_clip, float)
self._l2_norm = check_value_positive('l2_norm_clip', float_norm_clip)
self._split = P.Split(0, self._micro_batches)
self._clip_by_global_norm = _ClipGradients()
self._mech = mech
......@@ -435,9 +428,9 @@ class _TrainOneStepCell(Cell):
Args:
network (Cell): The training network.
optimizer (Cell): Optimizer for updating the weights.
sens (Number): The scaling number to be filled as the input of backpropagation. Default value is 1.0.
micro_batches (int): The number of small batches split from an origianl batch. Default: None.
l2_norm_clip (float): Use to clip the bound, if set 1, will retun the original data. Default: None.
sens (Number): The scaling number to be filled as the input of back propagation. Default value is 1.0.
micro_batches (int): The number of small batches split from an original batch. Default: None.
l2_norm_clip (float): Use to clip the bound, if set 1, will return the original data. Default: 1.0.
mech (Mechanisms): The object can generate the different type of noise. Default: None.
Inputs:
......@@ -446,16 +439,9 @@ class _TrainOneStepCell(Cell):
Outputs:
Tensor, a scalar Tensor with shape :math:`()`.
Examples:
>>> net = Net()
>>> loss_fn = nn.SoftmaxCrossEntropyWithLogits()
>>> optim = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> loss_net = nn.WithLossCell(net, loss_fn)
>>> train_net = nn.TrainOneStepCell(loss_net, optim)
"""
def __init__(self, network, optimizer, sens=1.0, micro_batches=None, l2_norm_clip=None, mech=None):
def __init__(self, network, optimizer, sens=1.0, micro_batches=None, l2_norm_clip=1.0, mech=None):
super(_TrainOneStepCell, self).__init__(auto_prefix=False)
self.network = network
self.network.add_flags(defer_inline=True)
......@@ -475,7 +461,8 @@ class _TrainOneStepCell(Cell):
# dp params
self._micro_batches = micro_batches
self._l2_norm = l2_norm_clip
float_norm_clip = check_param_type('l2_norm_clip', l2_norm_clip, float)
self._l2_norm = check_value_positive('l2_norm_clip', float_norm_clip)
self._split = P.Split(0, self._micro_batches)
self._clip_by_global_norm = _ClipGradients()
self._mech = mech
......
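The model.py changes above require `l2_norm_clip` to be a positive float because it is the bound used for per-micro-batch gradient clipping before noise is added. A NumPy sketch of the standard DP-SGD global-norm clipping that this parameter controls (an illustration, not the MindSpore implementation used by `_ClipGradients`):

```python
import numpy as np

def clip_by_global_norm(grads, l2_norm_clip):
    # Scale all gradients so their joint L2 norm is at most l2_norm_clip.
    global_norm = np.sqrt(sum(np.sum(g ** 2) for g in grads))
    scale = min(1.0, l2_norm_clip / (global_norm + 1e-12))
    return [g * scale for g in grads]

grads = [np.ones(4), np.full(4, 3.0)]        # joint L2 norm = sqrt(40) ~ 6.32
clipped = clip_by_global_norm(grads, 1.0)    # rescaled to norm ~ 1.0
```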
......@@ -18,7 +18,6 @@ import pytest
import numpy as np
from mindspore import nn
from mindspore.nn import SGD
from mindspore.model_zoo.lenet import LeNet5
from mindspore import context
import mindspore.dataset as ds
......@@ -43,22 +42,24 @@ def test_dp_model():
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
l2_norm_bound = 1.0
initial_noise_multiplier = 0.01
net = LeNet5()
network = LeNet5()
batch_size = 32
batches = 128
epochs = 1
loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
optim = SGD(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
gaussian_mech = DPOptimizerClassFactory()
gaussian_mech = DPOptimizerClassFactory(micro_batches=2)
gaussian_mech.set_mechanisms('Gaussian',
norm_bound=l2_norm_bound,
initial_noise_multiplier=initial_noise_multiplier)
net_opt = gaussian_mech.create('SGD')(params=network.trainable_params(),
learning_rate=0.1,
momentum=0.9)
model = DPModel(micro_batches=2,
norm_clip=l2_norm_bound,
dp_mech=gaussian_mech.mech,
network=net,
network=network,
loss_fn=loss,
optimizer=optim,
optimizer=net_opt,
metrics=None)
ms_ds = ds.GeneratorDataset(dataset_generator(batch_size, batches), ['data', 'label'])
ms_ds.set_dataset_size(batch_size * batches)
......
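The updated test feeds `DPModel` through `GeneratorDataset` via a `dataset_generator` helper that is not shown in the diff. A hypothetical stand-in with LeNet5-shaped inputs (1x32x32 images, 10 classes), assuming the generator yields one (data, label) pair per batch:

```python
import numpy as np

def dataset_generator(batch_size, batches):
    # Yield random image/label batches shaped for LeNet5 (illustrative only).
    for _ in range(batches):
        data = np.random.random((batch_size, 1, 32, 32)).astype(np.float32)
        label = np.random.randint(0, 10, size=batch_size).astype(np.int32)
        yield data, label
```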