提交 b2e0934b 编写于 作者: Z zheng-huanhuan 提交者: zhenghuanhuan

5-month feature: add differential privacy train and optimizer.

上级 45cede10
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting, will be used in train.py
"""
from easydict import EasyDict as edict
mnist_cfg = edict({
'num_classes': 10,
'lr': 0.01,
'momentum': 0.9,
'epoch_size': 10,
'batch_size': 32,
'buffer_size': 1000,
'image_height': 32,
'image_width': 32,
'save_checkpoint_steps': 1875,
'keep_checkpoint_max': 10,
})
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
python lenet5_dp_model_train.py --data_path /YourDataPath --micro_batches=2
"""
import os
import argparse
import mindspore.nn as nn
from mindspore import context
from mindspore.train.callback import ModelCheckpoint
from mindspore.train.callback import CheckpointConfig
from mindspore.train.callback import LossMonitor
from mindspore.nn.metrics import Accuracy
from mindspore.train.serialization import load_checkpoint, load_param_into_net
import mindspore.dataset as ds
import mindspore.dataset.transforms.vision.c_transforms as CV
import mindspore.dataset.transforms.c_transforms as C
from mindspore.dataset.transforms.vision import Inter
import mindspore.common.dtype as mstype
from mindarmour.diff_privacy import DPModel
from mindarmour.diff_privacy import DPOptimizerClassFactory
from mindarmour.diff_privacy import PrivacyMonitorFactory
from mindarmour.utils.logger import LogUtil
from lenet5_net import LeNet5
from lenet5_config import mnist_cfg as cfg
LOGGER = LogUtil.get_instance()
TAG = 'Lenet5_train'
def generate_mnist_dataset(data_path, batch_size=32, repeat_size=1,
num_parallel_workers=1, sparse=True):
"""
create dataset for training or testing
"""
# define dataset
ds1 = ds.MnistDataset(data_path)
# define operation parameters
resize_height, resize_width = 32, 32
rescale = 1.0 / 255.0
shift = 0.0
# define map operations
resize_op = CV.Resize((resize_height, resize_width),
interpolation=Inter.LINEAR)
rescale_op = CV.Rescale(rescale, shift)
hwc2chw_op = CV.HWC2CHW()
type_cast_op = C.TypeCast(mstype.int32)
# apply map operations on images
if not sparse:
one_hot_enco = C.OneHot(10)
ds1 = ds1.map(input_columns="label", operations=one_hot_enco,
num_parallel_workers=num_parallel_workers)
type_cast_op = C.TypeCast(mstype.float32)
ds1 = ds1.map(input_columns="label", operations=type_cast_op,
num_parallel_workers=num_parallel_workers)
ds1 = ds1.map(input_columns="image", operations=resize_op,
num_parallel_workers=num_parallel_workers)
ds1 = ds1.map(input_columns="image", operations=rescale_op,
num_parallel_workers=num_parallel_workers)
ds1 = ds1.map(input_columns="image", operations=hwc2chw_op,
num_parallel_workers=num_parallel_workers)
# apply DatasetOps
buffer_size = 10000
ds1 = ds1.shuffle(buffer_size=buffer_size)
ds1 = ds1.batch(batch_size, drop_remainder=True)
ds1 = ds1.repeat(repeat_size)
return ds1
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='MindSpore MNIST Example')
parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU', 'CPU'],
help='device where the code will be implemented (default: Ascend)')
parser.add_argument('--data_path', type=str, default="./MNIST_unzip",
help='path where the dataset is saved')
parser.add_argument('--dataset_sink_mode', type=bool, default=False, help='dataset_sink_mode is False or True')
parser.add_argument('--micro_batches', type=float, default=None,
help='optional, if use differential privacy, need to set micro_batches')
parser.add_argument('--l2_norm_bound', type=float, default=1,
help='optional, if use differential privacy, need to set l2_norm_bound')
parser.add_argument('--initial_noise_multiplier', type=float, default=0.001,
help='optional, if use differential privacy, need to set initial_noise_multiplier')
args = parser.parse_args()
context.set_context(mode=context.PYNATIVE_MODE, device_target=args.device_target, enable_mem_reuse=False)
network = LeNet5()
net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,
keep_checkpoint_max=cfg.keep_checkpoint_max)
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet",
directory='./trained_ckpt_file/',
config=config_ck)
ds_train = generate_mnist_dataset(os.path.join(args.data_path, "train"),
cfg.batch_size,
cfg.epoch_size)
if args.micro_batches and cfg.batch_size % args.micro_batches != 0:
raise ValueError("Number of micro_batches should divide evenly batch_size")
gaussian_mech = DPOptimizerClassFactory(args.micro_batches)
gaussian_mech.set_mechanisms('Gaussian',
norm_bound=args.l2_norm_bound,
initial_noise_multiplier=args.initial_noise_multiplier)
net_opt = gaussian_mech.create('Momentum')(params=network.trainable_params(),
learning_rate=cfg.lr,
momentum=cfg.momentum)
micro_size = int(cfg.batch_size // args.micro_batches)
rdp_monitor = PrivacyMonitorFactory.create('rdp',
num_samples=60000,
batch_size=micro_size,
initial_noise_multiplier=args.initial_noise_multiplier,
per_print_times=10)
model = DPModel(micro_batches=args.micro_batches,
norm_clip=args.l2_norm_bound,
dp_mech=gaussian_mech.mech,
network=network,
loss_fn=net_loss,
optimizer=net_opt,
metrics={"Accuracy": Accuracy()})
LOGGER.info(TAG, "============== Starting Training ==============")
model.train(cfg['epoch_size'], ds_train, callbacks=[ckpoint_cb, LossMonitor(), rdp_monitor],
dataset_sink_mode=args.dataset_sink_mode)
LOGGER.info(TAG, "============== Starting Testing ==============")
ckpt_file_name = 'trained_ckpt_file/checkpoint_lenet-10_1875.ckpt'
param_dict = load_checkpoint(ckpt_file_name)
load_param_into_net(network, param_dict)
ds_eval = generate_mnist_dataset(os.path.join(args.data_path, 'test'), batch_size=cfg.batch_size)
acc = model.eval(ds_eval, dataset_sink_mode=False)
LOGGER.info(TAG, "============== Accuracy: %s ==============", acc)
......@@ -4,7 +4,13 @@ This module provide Differential Privacy feature to protect user privacy.
from .mechanisms.mechanisms import GaussianRandom
from .mechanisms.mechanisms import AdaGaussianRandom
from .mechanisms.mechanisms import MechanismsFactory
from .monitor.monitor import PrivacyMonitorFactory
from .optimizer.optimizer import DPOptimizerClassFactory
from .train.model import DPModel
__all__ = ['GaussianRandom',
'AdaGaussianRandom',
'MechanismsFactory']
'MechanismsFactory',
'PrivacyMonitorFactory',
'DPOptimizerClassFactory',
'DPModel']
......@@ -60,10 +60,6 @@ class Mechanisms(Cell):
"""
Basic class of noise generated mechanism.
"""
def __init__(self):
pass
def construct(self, shape):
"""
Construct function.
......
......@@ -47,7 +47,7 @@ class PrivacyMonitorFactory:
parameters used for creating a privacy monitor.
Returns:
PrivacyMonitor, a privacy monitor.
Callback, a privacy monitor.
Examples:
>>> rdp = PrivacyMonitorFactory.create(policy='rdp',
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Differential privacy optimizer.
"""
import mindspore as ms
from mindspore import nn
from mindspore import Tensor
from mindarmour.diff_privacy.mechanisms.mechanisms import MechanismsFactory
class DPOptimizerClassFactory:
"""
Factory class of Optimizer.
Args:
micro_batches (int): The number of small batches split from an origianl batch. Default: None.
Returns:
Optimizer, Optimizer class
Examples:
>>> GaussianSGD = DPOptimizerClassFactory(micro_batches=2)
>>> GaussianSGD.set_mechanisms('Gaussian', norm_bound=1.5, initial_noise_multiplier=5.0)
>>> net_opt = GaussianSGD.create('SGD')(params=network.trainable_params(),
>>> learning_rate=cfg.lr,
>>> momentum=cfg.momentum)
"""
def __init__(self, micro_batches=None):
self._mech_factory = MechanismsFactory()
self.mech = None
self._micro_batches = micro_batches
def set_mechanisms(self, policy, *args, **kwargs):
"""
Get noise mechanism object.
Args:
policy (str): Choose mechanism type.
"""
self.mech = self._mech_factory.create(policy, *args, **kwargs)
def create(self, policy, *args, **kwargs):
"""
Create DP optimizer.
Args:
policy (str): Choose original optimizer type.
Returns:
Optimizer, A optimizer with DP.
"""
if policy == 'SGD':
cls = self._get_dp_optimizer_class(nn.SGD, self.mech, self._micro_batches, *args, **kwargs)
return cls
if policy == 'Momentum':
cls = self._get_dp_optimizer_class(nn.Momentum, self.mech, self._micro_batches, *args, **kwargs)
return cls
if policy == 'Adam':
cls = self._get_dp_optimizer_class(nn.Adam, self.mech, self._micro_batches, *args, **kwargs)
return cls
if policy == 'AdamWeightDecay':
cls = self._get_dp_optimizer_class(nn.AdamWeightDecay, self.mech, self._micro_batches, *args, **kwargs)
return cls
if policy == 'AdamWeightDecayDynamicLR':
cls = self._get_dp_optimizer_class(nn.AdamWeightDecayDynamicLR,
self.mech,
self._micro_batches,
*args, **kwargs)
return cls
raise NameError("The {} is not implement, please choose ['SGD', 'Momentum', 'AdamWeightDecay', "
"'Adam', 'AdamWeightDecayDynamicLR']".format(policy))
def _get_dp_optimizer_class(self, cls, mech, micro_batches):
"""
Wrap original mindspore optimizer with `self._mech`.
"""
class DPOptimizer(cls):
"""
Initialize the DPOptimizerClass.
Returns:
Optimizer, Optimizer class.
"""
def __init__(self, *args, **kwargs):
super(DPOptimizer, self).__init__(*args, **kwargs)
self._mech = mech
def construct(self, gradients):
"""
construct a compute flow.
"""
g_len = len(gradients)
gradient_noise = list(gradients)
for i in range(g_len):
gradient_noise[i] = gradient_noise[i].asnumpy()
gradient_noise[i] = self._mech(gradient_noise[i].shape).asnumpy() + gradient_noise[i]
gradient_noise[i] = gradient_noise[i] / micro_batches
gradient_noise[i] = Tensor(gradient_noise[i], ms.float32)
gradients = tuple(gradient_noise)
gradients = super(DPOptimizer, self).construct(gradients)
return gradients
return DPOptimizer
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Differential privacy model.
"""
from easydict import EasyDict as edict
import mindspore as ms
from mindspore.train.model import Model
from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel
from mindspore.train import amp
from mindspore.train.amp import _config_level
from mindspore.common import dtype as mstype
from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell
from mindspore.parallel._utils import _get_parallel_mode
from mindspore.train.model import ParallelMode
from mindspore.train.amp import _do_keep_batchnorm_fp32
from mindspore.train.amp import _add_loss_network
from mindspore import context
from mindspore import nn
from mindspore import Tensor
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops.operations import NPUGetFloatStatus
from mindspore.ops.operations import NPUAllocFloatStatus
from mindspore.ops.operations import NPUClearFloatStatus
from mindspore.ops.operations import ReduceSum
from mindspore.ops.operations import LessEqual
from mindspore.ops.operations import ControlDepend
from mindspore.parallel._utils import _get_mirror_mean
from mindspore.parallel._utils import _get_device_num
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.common.parameter import Parameter
from mindspore.nn.wrap.loss_scale import _grad_overflow
from mindspore.nn import Cell
from mindspore import ParameterTuple
GRADIENT_CLIP_TYPE = 1
grad_scale = C.MultitypeFuncGraph("grad_scale")
reciprocal = P.Reciprocal()
@grad_scale.register("Tensor", "Tensor")
def tensor_grad_scale(scale, grad):
return grad*reciprocal(scale)
class DPModel(Model):
"""
This class is overload mindspore.train.model.Model.
Args:
micro_batches (int): The number of small batches split from an origianl batch. Default: None.
norm_clip (float): Use to clip the bound, if set 1, will retun the original data. Default: None.
dp_mech (Mechanisms): The object can generate the different type of noise. Default: None.
Examples:
>>> class Net(nn.Cell):
>>> def __init__(self):
>>> super(Net, self).__init__()
>>> self.conv = nn.Conv2d(3, 64, 3, has_bias=False, weight_init='normal')
>>> self.bn = nn.BatchNorm2d(64)
>>> self.relu = nn.ReLU()
>>> self.flatten = nn.Flatten()
>>> self.fc = nn.Dense(64*224*224, 12) # padding=0
>>>
>>> def construct(self, x):
>>> x = self.conv(x)
>>> x = self.bn(x)
>>> x = self.relu(x)
>>> x = self.flatten(x)
>>> out = self.fc(x)
>>> return out
>>>
>>> net = Net()
>>> loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
>>> optim = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> gaussian_mech = DPOptimizerClassFactory()
>>> gaussian_mech.set_mechanisms('Gaussian',
>>> norm_bound=args.l2_norm_bound,
>>> initial_noise_multiplier=args.initial_noise_multiplier)
>>> model = DPModel(micro_batches=2,
>>> norm_clip=1,
>>> dp_mech=gaussian_mech.mech,
>>> network=net,
>>> loss_fn=loss,
>>> optimizer=optim,
>>> metrics=None)
>>> dataset = get_dataset()
>>> model.train(2, dataset)
"""
def __init__(self, micro_batches=None, norm_clip=None, dp_mech=None, **kwargs):
if micro_batches:
self._micro_batches = int(micro_batches)
else:
self._micro_batches = None
self._norm_clip = norm_clip
self._dp_mech = dp_mech
super(DPModel, self).__init__(**kwargs)
def amp_build_train_network(self, network, optimizer, loss_fn=None, level='O0', **kwargs):
"""
Build the mixed precision training cell automatically.
Args:
network (Cell): Definition of the network.
loss_fn (Union[None, Cell]): Definition of the loss_fn. If None, the `network` should have the loss inside.
Default: None.
optimizer (Optimizer): Optimizer to update the Parameter.
level (str): Supports [O0, O2]. Default: "O0".
- O0: Do not change.
- O2: Cast network to float16, keep batchnorm and `loss_fn` (if set) run in float32,
using dynamic loss scale.
cast_model_type (:class:`mindspore.dtype`): Supports `mstype.float16` or `mstype.float32`.
If set to `mstype.float16`, use `float16` mode to train. If set, overwrite the level setting.
keep_batchnorm_fp32 (bool): Keep Batchnorm run in `float32`. If set, overwrite the level setting.
loss_scale_manager (Union[None, LossScaleManager]): If None, not scale the loss, or else
scale the loss by LossScaleManager. If set, overwrite the level setting.
"""
validator.check_value_type('network', network, nn.Cell, None)
validator.check_value_type('optimizer', optimizer, nn.Optimizer, None)
validator.check('level', level, "", ['O0', 'O2'], Rel.IN, None)
self._check_kwargs(kwargs)
config = dict(_config_level[level], **kwargs)
config = edict(config)
if config.cast_model_type == mstype.float16:
network.to_float(mstype.float16)
if config.keep_batchnorm_fp32:
_do_keep_batchnorm_fp32(network)
if loss_fn:
network = _add_loss_network(network, loss_fn, config.cast_model_type)
if _get_parallel_mode() in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
network = _VirtualDatasetCell(network)
loss_scale = 1.0
if config.loss_scale_manager is not None:
loss_scale_manager = config.loss_scale_manager
loss_scale = loss_scale_manager.get_loss_scale()
update_cell = loss_scale_manager.get_update_cell()
if update_cell is not None:
# only cpu not support `TrainOneStepWithLossScaleCell` for control flow.
if not context.get_context("enable_ge") and context.get_context("device_target") == "CPU":
raise ValueError("Only `loss_scale_manager=None` and "
"`loss_scale_manager=FixedLossScaleManager(drop_overflow_update=False)`"
"are supported in current version. If you use `O2` option, please"
"use `loss_scale_manager=None` or `FixedLossScaleManager`")
network = _TrainOneStepWithLossScaleCell(network,
optimizer,
scale_update_cell=update_cell,
micro_batches=self._micro_batches,
l2_norm_clip=self._norm_clip,
mech=self._dp_mech).set_train()
return network
network = _TrainOneStepCell(network,
optimizer,
loss_scale,
micro_batches=self._micro_batches,
l2_norm_clip=self._norm_clip,
mech=self._dp_mech).set_train()
return network
def _build_train_network(self):
"""Build train network"""
network = self._network
if self._micro_batches:
if self._optimizer:
if self._loss_scale_manager_set:
network = self.amp_build_train_network(network,
self._optimizer,
self._loss_fn,
level=self._amp_level,
loss_scale_manager=self._loss_scale_manager,
keep_batchnorm_fp32=self._keep_bn_fp32)
else:
network = self.amp_build_train_network(network,
self._optimizer,
self._loss_fn,
level=self._amp_level,
keep_batchnorm_fp32=self._keep_bn_fp32)
elif self._loss_fn:
network = nn.WithLossCell(network, self._loss_fn)
else:
if self._optimizer:
if self._loss_scale_manager_set:
network = amp.build_train_network(network,
self._optimizer,
self._loss_fn,
level=self._amp_level,
loss_scale_manager=self._loss_scale_manager,
keep_batchnorm_fp32=self._keep_bn_fp32)
else:
network = amp.build_train_network(network,
self._optimizer,
self._loss_fn,
level=self._amp_level,
keep_batchnorm_fp32=self._keep_bn_fp32)
elif self._loss_fn:
network = nn.WithLossCell(network, self._loss_fn)
if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
network.set_auto_parallel()
return network
class _ClipGradients(nn.Cell):
"""
Clip gradients.
Inputs:
grads (tuple[Tensor]): Gradients.
clip_type (int): The way to clip, 0 for 'value', 1 for 'norm'.
clip_value (float): Specifies how much to clip.
Outputs:
tuple[Tensor], clipped gradients.
"""
def __init__(self):
super(_ClipGradients, self).__init__()
self.clip_by_norm = nn.ClipByNorm()
self.dtype = P.DType()
def construct(self, grads, clip_type, clip_value):
"""
construct a compute flow.
"""
if clip_type not in (0, 1):
return grads
new_grads = ()
for grad in grads:
if clip_type == 0:
t = C.clip_by_value(grad, F.tuple_to_array((-clip_value,)),
F.tuple_to_array((clip_value,)))
else:
t = self.clip_by_norm(grad, F.tuple_to_array((clip_value,)))
new_grads = new_grads + (t,)
return new_grads
class _TrainOneStepWithLossScaleCell(Cell):
r"""
Network training with loss scaling.
This is a training step with loss scaling. It takes a network, an optimizer and possibly a scale update
Cell as args. The loss scale value can be updated in both host side or device side. The
TrainOneStepWithLossScaleCell will be compiled to be graph which takes `data`, `label`, `sens` as input
data. The `sens` is acting as loss scaling value. If you want to update it on host side, the value should
be provided. If `sens` is not given, the loss scale update logic should be provied by `scale_update_cell`.
If `scale_update_cell` is not None and `sens` is provided, the `scale_update_cell` will be ignored.
Args:
network (Cell): The training network.
optimizer (Cell): Optimizer for updating the weights.
scale_update_cell(Cell): The loss scaling update logic cell. Default: None.
micro_batches (int): The number of small batches split from an origianl batch. Default: None.
l2_norm_clip (float): Use to clip the bound, if set 1, will retun the original data. Default: None.
mech (Mechanisms): The object can generate the different type of noise. Default: None.
Inputs:
- **inputs** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
- **label** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
- **scaling_sens** (Tensor) - Tensor of shape :math:`()`.
Outputs:
Tuple of 3 Tensor, the loss, overflow flag and current loss scaling value.
- **loss** (Tensor) - Tensor with shape :math:`()`.
- **overflow** (Tensor) - Tensor with shape :math:`()`, type is bool.
- **loss_scale** (Tensor) - Tensor with shape :math:`()`.
Examples:
>>> net_with_loss = Net()
>>> optimizer = nn.Momentum(net_with_loss.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> manager = nn.DynamicLossScaleUpdateCell(loss_scale_value=2**12, scale_factor=2, scale_window=1000)
>>> train_network = nn.TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_update_cell=manager)
>>> train_network.set_train()
>>>
>>> inputs = Tensor(np.ones([16, 16]).astype(np.float32))
>>> label = Tensor(np.zeros([16, 16]).astype(np.float32))
>>> scaling_sens = Tensor(np.full((1), np.finfo(np.float32).max), dtype=mindspore.float32)
>>> output = train_network(inputs, label, scaling_sens)
"""
def __init__(self, network, optimizer, scale_update_cell=None, micro_batches=None, l2_norm_clip=None, mech=None):
super(_TrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False)
self.network = network
self.network.add_flags(defer_inline=True)
self.weights = ParameterTuple(network.trainable_params())
self.optimizer = optimizer
self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True)
self.hyper_map = C.HyperMap()
if context.get_context("device_target") == "GPU":
self.gpu_target = True
self.float_status = P.FloatStatus()
self.addn = P.AddN()
self.reshape = P.Reshape()
else:
self.gpu_target = False
self.alloc_status = NPUAllocFloatStatus()
self.get_status = NPUGetFloatStatus()
self.clear_status = NPUClearFloatStatus()
self.reduce_sum = ReduceSum(keep_dims=False)
self.base = Tensor(1, mstype.float32)
self.less_equal = LessEqual()
self.depend_parameter_use = ControlDepend(depend_mode=1)
self.allreduce = P.AllReduce()
self.parallel_mode = _get_parallel_mode()
self.grad_reducer = F.identity
self.reducer_flag = self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]
if self.reducer_flag:
mean = _get_mirror_mean()
degree = _get_device_num()
self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
self.is_distributed = self.parallel_mode != ParallelMode.STAND_ALONE
self.loss_scale = None
self.loss_scaling_manager = scale_update_cell
if scale_update_cell:
self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32),
name="loss_scale")
self.add_flags(has_effect=True)
# dp params
self._micro_batches = micro_batches
self._l2_norm = l2_norm_clip
self._split = P.Split(0, self._micro_batches)
self._clip_by_global_norm = _ClipGradients()
self._mech = mech
def construct(self, data, label, sens=None):
"""
construct a compute flow.
"""
init = False
if not self.gpu_target:
# init overflow buffer
init = self.alloc_status()
# clear overflow buffer
self.clear_status(init)
if sens is None:
scaling_sens = self.loss_scale
else:
scaling_sens = sens
# DP clip
weights = self.weights
record_datas = self._split(data)
record_labels = self._split(label)
grads = ()
# first index
loss = self.network(record_datas[0], record_labels[0])
scaling_sens_filled = C.ones_like(loss)*F.cast(scaling_sens, F.dtype(loss))
record_grad = self.grad(self.network, weights)(record_datas[0], record_labels[0], scaling_sens_filled)
record_grad = self._clip_by_global_norm(record_grad, GRADIENT_CLIP_TYPE, self._l2_norm)
grad_sum = list(record_grad)
grad_len = len(record_grad)
for i in range(grad_len):
grad_sum[i] = grad_sum[i].asnumpy()
for i in range(1, self._micro_batches):
loss = self.network(record_datas[i], record_labels[i])
scaling_sens_filled = C.ones_like(loss)*F.cast(scaling_sens, F.dtype(loss))
record_grad = self.grad(self.network, weights)(record_datas[i], record_labels[i], scaling_sens_filled)
record_grad = self._clip_by_global_norm(record_grad, GRADIENT_CLIP_TYPE, self._l2_norm)
for j in range(grad_len):
grad_sum[j] = grad_sum[j] + record_grad[j].asnumpy()
for i in range(grad_len):
grad_sum[i] = Tensor(grad_sum[i], ms.float32)
grads = tuple(grad_sum)
loss = self.network(data, label)
grads = self.hyper_map(F.partial(grad_scale, scaling_sens), grads)
# apply grad reducer on grads
grads = self.grad_reducer(grads)
# get the overflow buffer
if not self.gpu_target:
self.get_status(init)
# sum overflow buffer elements, 0:not overflow , >0:overflow
flag_sum = self.reduce_sum(init, (0,))
else:
flag_sum = self.hyper_map(F.partial(_grad_overflow), grads)
flag_sum = self.addn(flag_sum)
# convert flag_sum to scalar
flag_sum = self.reshape(flag_sum, (()))
if self.is_distributed:
# sum overflow flag over devices
flag_reduce = self.allreduce(flag_sum)
cond = self.less_equal(self.base, flag_reduce)
else:
cond = self.less_equal(self.base, flag_sum)
overflow = cond
if sens is None:
overflow = self.loss_scaling_manager(self.loss_scale, cond)
# if there is no overflow, do optimize
if overflow:
opt = False
else:
opt = self.optimizer(grads)
ret = (loss, cond, scaling_sens)
return F.depend(ret, opt)
class _TrainOneStepCell(Cell):
r"""
Network training package class.
Wraps the network with an optimizer. The resulting Cell be trained with input data and label.
Backward graph will be created in the construct function to do parameter updating. Different
parallel modes are available to run the training.
Args:
network (Cell): The training network.
optimizer (Cell): Optimizer for updating the weights.
sens (Number): The scaling number to be filled as the input of backpropagation. Default value is 1.0.
micro_batches (int): The number of small batches split from an origianl batch. Default: None.
l2_norm_clip (float): Use to clip the bound, if set 1, will retun the original data. Default: None.
mech (Mechanisms): The object can generate the different type of noise. Default: None.
Inputs:
- **data** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
- **label** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
Outputs:
Tensor, a scalar Tensor with shape :math:`()`.
Examples:
>>> net = Net()
>>> loss_fn = nn.SoftmaxCrossEntropyWithLogits()
>>> optim = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> loss_net = nn.WithLossCell(net, loss_fn)
>>> train_net = nn.TrainOneStepCell(loss_net, optim)
"""
def __init__(self, network, optimizer, sens=1.0, micro_batches=None, l2_norm_clip=None, mech=None):
super(_TrainOneStepCell, self).__init__(auto_prefix=False)
self.network = network
self.network.add_flags(defer_inline=True)
self.weights = optimizer.parameters
self.optimizer = optimizer
self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True)
self.sens = sens
self.reducer_flag = False
self.grad_reducer = None
parallel_mode = _get_parallel_mode()
if parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
self.reducer_flag = True
if self.reducer_flag:
mean = _get_mirror_mean()
degree = _get_device_num()
self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
# dp params
self._micro_batches = micro_batches
self._l2_norm = l2_norm_clip
self._split = P.Split(0, self._micro_batches)
self._clip_by_global_norm = _ClipGradients()
self._mech = mech
def construct(self, data, label):
"""
construct a compute flow.
"""
weights = self.weights
record_datas = self._split(data)
record_labels = self._split(label)
loss = self.network(record_datas[0], record_labels[0])
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
record_grad = self.grad(self.network, weights)(record_datas[0], record_labels[0], sens)
record_grad = self._clip_by_global_norm(record_grad, GRADIENT_CLIP_TYPE, self._l2_norm)
grad_sum = list(record_grad)
grad_len = len(record_grad)
for i in range(grad_len):
grad_sum[i] = grad_sum[i].asnumpy()
for i in range(1, self._micro_batches):
loss = self.network(record_datas[i], record_labels[i])
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
record_grad = self.grad(self.network, weights)(record_datas[i], record_labels[i], sens)
record_grad = self._clip_by_global_norm(record_grad, GRADIENT_CLIP_TYPE, self._l2_norm)
for j in range(grad_len):
grad_sum[j] = grad_sum[j] + record_grad[j].asnumpy()
for i in range(grad_len):
grad_sum[i] = Tensor(grad_sum[i], ms.float32)
grads = tuple(grad_sum)
loss = self.network(data, label)
if self.reducer_flag:
# apply grad reducer on grads
grads = self.grad_reducer(grads)
return F.depend(loss, self.optimizer(grads))
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
DP-Model test.
"""
import pytest
import numpy as np
from mindspore import nn
from mindspore.nn import SGD
from mindspore.model_zoo.lenet import LeNet5
from mindspore import context
import mindspore.dataset as ds
from mindarmour.diff_privacy import DPOptimizerClassFactory
from mindarmour.diff_privacy import DPModel
def dataset_generator(batch_size, batches):
data = np.random.random((batches * batch_size, 1, 32, 32)).astype(np.float32)
label = np.random.randint(0, 10, batches * batch_size).astype(np.int32)
for i in range(batches):
yield data[i * batch_size:(i + 1) * batch_size], label[i * batch_size:(i + 1) * batch_size]
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_card
@pytest.mark.component_mindarmour
def test_dp_model():
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
l2_norm_bound = 1.0
initial_noise_multiplier = 0.01
net = LeNet5()
batch_size = 32
batches = 128
epochs = 1
loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
optim = SGD(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
gaussian_mech = DPOptimizerClassFactory()
gaussian_mech.set_mechanisms('Gaussian',
norm_bound=l2_norm_bound,
initial_noise_multiplier=initial_noise_multiplier)
model = DPModel(micro_batches=2,
norm_clip=l2_norm_bound,
dp_mech=gaussian_mech.mech,
network=net,
loss_fn=loss,
optimizer=optim,
metrics=None)
ms_ds = ds.GeneratorDataset(dataset_generator(batch_size, batches), ['data', 'label'])
ms_ds.set_dataset_size(batch_size * batches)
model.train(epochs, ms_ds)
......@@ -23,7 +23,7 @@ from mindspore.train import Model
import mindspore.context as context
from mindspore.model_zoo.lenet import LeNet5
from mindarmour.diff_privacy.monitor.monitor import PrivacyMonitorFactory
from mindarmour.diff_privacy import PrivacyMonitorFactory
from mindarmour.utils.logger import LogUtil
LOGGER = LogUtil.get_instance()
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
from mindspore import nn
from mindspore import context
from mindspore.model_zoo.lenet import LeNet5
from mindspore.train.model import Model
from mindarmour.diff_privacy import DPOptimizerClassFactory
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_card
@pytest.mark.component_mindarmour
def test_optimizer():
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
network = LeNet5()
lr = 0.01
momentum = 0.9
micro_batches = 2
loss = nn.SoftmaxCrossEntropyWithLogits()
gaussian_mech = DPOptimizerClassFactory(micro_batches)
gaussian_mech.set_mechanisms('Gaussian', norm_bound=1.5, initial_noise_multiplier=5.0)
net_opt = gaussian_mech.create('SGD')(params=network.trainable_params(), learning_rate=lr,
momentum=momentum)
_ = Model(network, loss_fn=loss, optimizer=net_opt, metrics=None)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_inference
@pytest.mark.env_card
@pytest.mark.component_mindarmour
def test_optimizer_gpu():
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
network = LeNet5()
lr = 0.01
momentum = 0.9
micro_batches = 2
loss = nn.SoftmaxCrossEntropyWithLogits()
gaussian_mech = DPOptimizerClassFactory(micro_batches)
gaussian_mech.set_mechanisms('Gaussian', norm_bound=1.5, initial_noise_multiplier=5.0)
net_opt = gaussian_mech.create('SGD')(params=network.trainable_params(), learning_rate=lr,
momentum=momentum)
_ = Model(network, loss_fn=loss, optimizer=net_opt, metrics=None)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_card
@pytest.mark.component_mindarmour
def test_optimizer_cpu():
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
network = LeNet5()
lr = 0.01
momentum = 0.9
micro_batches = 2
loss = nn.SoftmaxCrossEntropyWithLogits()
gaussian_mech = DPOptimizerClassFactory(micro_batches)
gaussian_mech.set_mechanisms('Gaussian', norm_bound=1.5, initial_noise_multiplier=5.0)
net_opt = gaussian_mech.create('SGD')(params=network.trainable_params(), learning_rate=lr,
momentum=momentum)
_ = Model(network, loss_fn=loss, optimizer=net_opt, metrics=None)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册