Commit ffe483ca authored by mindspore-ci-bot, committed by Gitee

!498 Update tutorial of differential privacy.

Merge pull request !498 from ZhidanLiu/master
MindArmour differential privacy module Differential-Privacy implements the differential privacy optimizer.
The LeNet model and MNIST dataset are used as an example to describe how to use the differential privacy optimizer to train a neural network model on MindSpore.
> This example is for the Ascend 910 AI processor. You can download the complete sample code from <https://gitee.com/mindspore/mindarmour/blob/master/example/mnist_demo/lenet5_dp.py>.
## Implementation
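The import section is elided from this diff. The sketch below shows the imports the updated script presumably relies on; the `mindarmour.diff_privacy` module paths and the `LogUtil` helper are assumptions based on MindArmour releases of this period and may differ in your version.

```python
# Assumed imports: the diff elides this section, so the module paths below
# are a best guess for the MindSpore/MindArmour releases of this period.
import os

from easydict import EasyDict as edict

import mindspore.nn as nn
from mindspore import context
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
from mindspore.train.serialization import load_checkpoint, load_param_into_net

from mindarmour.diff_privacy import DPModel
from mindarmour.diff_privacy import NoiseMechanismsFactory, ClipMechanismsFactory
from mindarmour.diff_privacy import PrivacyMonitorFactory
from mindarmour.utils.logger import LogUtil

LOGGER = LogUtil.get_instance()
TAG = 'Lenet5_train'
```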
### Configuring Parameters
1. Set the running environment, dataset path, model training parameters, checkpoint storage parameters, and differential privacy parameters. Replace `data_path` with your data path.
```python
cfg = edict({
    'num_classes': 10,          # the number of classes of model's output
    'lr': 0.01,                 # the learning rate of model's optimizer
    'momentum': 0.9,            # the momentum value of model's optimizer
    'epoch_size': 10,           # training epochs
    'batch_size': 256,          # batch size for training
    # ... (entries elided in this diff excerpt)
    'data_path': './MNIST_unzip',       # the path of the training and testing data set
    'dataset_sink_mode': False,         # whether to deliver all training data to the device at once
    'micro_batches': 16,                # the number of small batches split from an original batch
    'norm_bound': 1.0,                  # the clip bound of the gradients of model's training parameters
    'initial_noise_multiplier': 1.0,    # the initial multiplication coefficient of the noise added to
                                        # training parameters' gradients
    'noise_mechanisms': 'Gaussian',     # the method of adding noise to gradients while training
    'clip_mechanisms': 'Gaussian',      # the method of adaptively clipping gradients while training
    'clip_decay_policy': 'Linear',      # decay policy of adaptive clipping; must be 'Linear' or 'Geometric'
    'clip_learning_rate': 0.001,        # learning rate for updating the norm clip
    'target_unclipped_quantile': 0.9,   # target quantile of the norm clip
    'fraction_stddev': 0.01,            # the stddev of the Gaussian noise used in empirical_fraction
    'optimizer': 'Momentum'             # the base optimizer used for differential privacy training
})
```
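A few quick arithmetic checks on these values show where the numbers used later in the tutorial come from (plain Python, assuming the 60,000-image MNIST training set and that the final partial batch is dropped):

```python
# batch_size must be an integer multiple of micro_batches.
assert 256 % 16 == 0

# Each batch of 256 samples is split into 16 micro-batches of 16 samples.
print(256 // 16)     # 16

# One epoch is 234 full steps, which matches per_print_times=234 below and
# the final checkpoint name checkpoint_lenet-10_234.ckpt.
print(60000 // 256)  # 234
```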
2. Configure necessary information, including the environment information and execution mode.
```python
context.set_context(mode=context.GRAPH_MODE, device_target=cfg.device_target)
```
For details about the API configuration, see the `context.set_context` description.
1. Set parameters of a differential privacy optimizer.
- Determine whether the values of the `micro_batches` and `batch_size` parameters meet the requirements. The value of `batch_size` must be an integer multiple of `micro_batches`.
- Instantiate a differential privacy factory class.
- Set a noise mechanism for the differential privacy. Currently, the Gaussian noise mechanism with a fixed standard deviation (`Gaussian`) and the Gaussian noise mechanism with an adaptive standard deviation (`AdaGaussian`) are supported.
- Set an optimizer type. Currently, `SGD`, `Momentum`, and `Adam` are supported.
- Set up an RDP differential privacy budget monitor to observe changes in the differential privacy budget $\epsilon$ at each step.
```python
if cfg.micro_batches and cfg.batch_size % cfg.micro_batches != 0:
    raise ValueError(
        "batch_size should be an integer multiple of micro_batches")
# Create the noise mechanism, which adds noise to gradients during training.
# initial_noise_multiplier is suggested to be greater than 1.0; otherwise the
# privacy budget would be huge, which means that the privacy protection
# effect is weak. Mechanisms can be 'Gaussian' or 'AdaGaussian': noise decays
# during training with the 'AdaGaussian' mechanism and stays constant with
# the 'Gaussian' mechanism.
noise_mech = NoiseMechanismsFactory().create(cfg.noise_mechanisms,
                                             norm_bound=cfg.norm_bound,
                                             initial_noise_multiplier=cfg.initial_noise_multiplier,
                                             decay_policy=None)
# Create the clip mechanism, which adaptively clips gradients during
# training. decay_policy supports 'Linear' and 'Geometric'; learning_rate is
# the learning rate for updating clip_norm; target_unclipped_quantile is the
# target quantile of the norm clip; fraction_stddev is the stddev of the
# Gaussian noise used in empirical_fraction, i.e.
# empirical_fraction + N(0, fraction_stddev).
clip_mech = ClipMechanismsFactory().create(cfg.clip_mechanisms,
                                           decay_policy=cfg.clip_decay_policy,
                                           learning_rate=cfg.clip_learning_rate,
                                           target_unclipped_quantile=cfg.target_unclipped_quantile,
                                           fraction_stddev=cfg.fraction_stddev)
net_opt = nn.Momentum(params=network.trainable_params(),
                      learning_rate=cfg.lr, momentum=cfg.momentum)
# Create a monitor for DP training. The monitor computes and prints the
# privacy budget (eps and delta) during training.
rdp_monitor = PrivacyMonitorFactory.create('rdp',
                                           num_samples=60000,
                                           batch_size=cfg.batch_size,
                                           initial_noise_multiplier=cfg.initial_noise_multiplier,
                                           per_print_times=234,
                                           noise_decay_mode=None)
```
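For reference, the adaptive clipping step configured above can be summarized as follows. This is a sketch assuming the `Linear` policy follows the additive update of adaptive-clipping DP-SGD (Andrew et al., "Differentially Private Learning with Adaptive Clipping"); the exact MindArmour update rule may differ in details:

$$
\tilde{b}_t = \mathrm{empirical\_fraction}_t + \mathcal{N}\bigl(0,\ \mathrm{fraction\_stddev}^2\bigr),
\qquad
C_{t+1} = C_t - \eta_{\mathrm{clip}}\,\bigl(\tilde{b}_t - \gamma\bigr)
$$

where $C_t$ is the current `norm_bound`, $\eta_{\mathrm{clip}}$ is `clip_learning_rate`, $\gamma$ is `target_unclipped_quantile`, and $\tilde{b}_t$ is the noisy fraction of unclipped micro-batch gradients. Under the `Geometric` policy the bound would instead be scaled multiplicatively, e.g. $C_{t+1} = C_t \exp(-\eta_{\mathrm{clip}}(\tilde{b}_t - \gamma))$.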
2. Package the LeNet model as a differential privacy model by passing the network to `DPModel`.
```python
# Create the DP model for training.
model = DPModel(micro_batches=cfg.micro_batches,
                norm_bound=cfg.norm_bound,
                noise_mech=noise_mech,
                clip_mech=clip_mech,
                network=network,
                loss_fn=net_loss,
                optimizer=net_opt,
                # The diff truncates the call here; a metrics argument such
                # as the following is assumed from the surrounding tutorial.
                metrics={'accuracy': nn.Accuracy()})
```
3. Train and test the model.
```python
LOGGER.info(TAG, "============== Starting Training ==============")
model.train(cfg['epoch_size'], ds_train,
            callbacks=[ckpoint_cb, LossMonitor(), rdp_monitor],
            dataset_sink_mode=cfg.dataset_sink_mode)

LOGGER.info(TAG, "============== Starting Testing ==============")
ckpt_file_name = 'trained_ckpt_file/checkpoint_lenet-10_234.ckpt'
param_dict = load_checkpoint(ckpt_file_name)
load_param_into_net(network, param_dict)
ds_eval = generate_mnist_dataset(os.path.join(cfg.data_path, 'test'),
                                 batch_size=cfg.batch_size)
acc = model.eval(ds_eval, dataset_sink_mode=False)
LOGGER.info(TAG, "============== Accuracy: %s ==============", acc)
```
4. Run the following command.
Execute the script:
```bash
python lenet5_dp.py
```
In the preceding command, replace `lenet5_dp.py` with the name of your script.
5. Display the result.
The accuracy of the LeNet model without differential privacy is 99%, while the accuracy of the LeNet model with Gaussian noise and adaptive clipping differential privacy is about 97%.
```
============== Starting Training ==============
...
============== Starting Testing ==============
...
============== Accuracy: 0.9767 ==============
```
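As a rough cross-check of the budget the RDP monitor prints, the sketch below composes plain Gaussian-mechanism RDP over all training steps and converts it to an $(\epsilon, \delta)$ guarantee. It ignores privacy amplification by batch subsampling, so it heavily overestimates $\epsilon$; the monitor's subsampled accounting reports a much tighter value. The function name and step count are illustrative only.

```python
import math

def rough_epsilon(sigma, steps, delta=1e-5, alphas=range(2, 256)):
    """Crude (epsilon, delta) upper bound from Gaussian RDP composition.

    One Gaussian mechanism has RDP alpha / (2 * sigma**2) at order alpha;
    RDP composes additively over steps and converts to (eps, delta)-DP via
    eps = rdp + log(1/delta) / (alpha - 1). No subsampling amplification
    is applied, so this overestimates the budget the monitor reports.
    """
    return min(steps * a / (2 * sigma ** 2) + math.log(1 / delta) / (a - 1)
               for a in alphas)

# 10 epochs x 234 steps/epoch with initial_noise_multiplier = 1.0:
print(rough_epsilon(sigma=1.0, steps=2340))
```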
### References
The LeNet model and the MNIST dataset are used as an example to describe how to use the differential privacy optimizer to train a neural network model on MindSpore.
> This example is for the Ascend 910 AI processor. You can download the complete sample code from <https://gitee.com/mindspore/mindarmour/blob/master/example/mnist_demo/lenet5_dp.py>
## Implementation
### Configuring Parameters
1. Set the running environment, dataset path, model training parameters, checkpoint storage parameters, and differential privacy parameters. Replace `data_path` with the path to your dataset.
```python
cfg = edict({
    'num_classes': 10,          # the number of classes of model's output
    'lr': 0.01,                 # the learning rate of model's optimizer
    'momentum': 0.9,            # the momentum value of model's optimizer
    'epoch_size': 10,           # training epochs
    'batch_size': 256,          # batch size for training
    # ... (entries elided in this diff excerpt)
    'data_path': './MNIST_unzip',       # the path of the training and testing data set
    'dataset_sink_mode': False,         # whether to deliver all training data to the device at once
    'micro_batches': 16,                # the number of small batches split from an original batch
    'norm_bound': 1.0,                  # the clip bound of the gradients of model's training parameters
    'initial_noise_multiplier': 1.0,    # the initial multiplication coefficient of the noise added to
                                        # training parameters' gradients
    'noise_mechanisms': 'Gaussian',     # the method of adding noise to gradients while training
    'clip_mechanisms': 'Gaussian',      # the method of adaptively clipping gradients while training
    'clip_decay_policy': 'Linear',      # decay policy of adaptive clipping; must be 'Linear' or 'Geometric'
    'clip_learning_rate': 0.001,        # learning rate for updating the norm clip
    'target_unclipped_quantile': 0.9,   # target quantile of the norm clip
    'fraction_stddev': 0.01,            # the stddev of the Gaussian noise used in empirical_fraction
    'optimizer': 'Momentum'             # the base optimizer used for differential privacy training
})
```
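For intuition about how `micro_batches`, `norm_bound`, and `initial_noise_multiplier` interact, the per-step private gradient can be sketched as below, assuming `DPModel` performs the standard DP-SGD aggregation (clip each micro-batch gradient, sum, add Gaussian noise, average); the exact MindArmour implementation may differ:

$$
\tilde{g} \;=\; \frac{1}{m}\left(\sum_{i=1}^{m} g_i \cdot \min\!\left(1,\ \frac{C}{\lVert g_i \rVert_2}\right) \;+\; \mathcal{N}\bigl(0,\ (z\,C)^2 I\bigr)\right)
$$

where $m$ is `micro_batches` (16), $C$ is `norm_bound` (1.0), and $z$ is `initial_noise_multiplier` (1.0).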
2. Configure necessary information, including the environment information and execution mode.
```python
context.set_context(mode=context.GRAPH_MODE, device_target=cfg.device_target)
```
For details about the API configuration, see the `context.set_context` description.
1. Set parameters of the differential privacy optimizer.
- Check whether the `micro_batches` and `batch_size` parameters meet the requirements: `batch_size` must be an integer multiple of `micro_batches`.
- Instantiate the differential privacy factory class.
- Set the noise mechanism for differential privacy. Currently, the Gaussian noise mechanism with a fixed standard deviation (`Gaussian`) and the Gaussian noise mechanism with an adaptive standard deviation (`AdaGaussian`) are supported.
- Set the optimizer type. Currently, `SGD`, `Momentum`, and `Adam` are supported.
- Set up an RDP differential privacy budget monitor to observe changes in the differential privacy budget $\epsilon$ at each step.
```python
if cfg.micro_batches and cfg.batch_size % cfg.micro_batches != 0:
    raise ValueError(
        "batch_size should be an integer multiple of micro_batches")
# Create the noise mechanism, which adds noise to gradients during training.
# initial_noise_multiplier is suggested to be greater than 1.0; otherwise the
# privacy budget would be huge, which means that the privacy protection
# effect is weak. Mechanisms can be 'Gaussian' or 'AdaGaussian': noise decays
# during training with the 'AdaGaussian' mechanism and stays constant with
# the 'Gaussian' mechanism.
noise_mech = NoiseMechanismsFactory().create(cfg.noise_mechanisms,
                                             norm_bound=cfg.norm_bound,
                                             initial_noise_multiplier=cfg.initial_noise_multiplier,
                                             decay_policy=None)
# Create the clip mechanism, which adaptively clips gradients during
# training. decay_policy supports 'Linear' and 'Geometric'; learning_rate is
# the learning rate for updating clip_norm; target_unclipped_quantile is the
# target quantile of the norm clip; fraction_stddev is the stddev of the
# Gaussian noise used in empirical_fraction, i.e.
# empirical_fraction + N(0, fraction_stddev).
clip_mech = ClipMechanismsFactory().create(cfg.clip_mechanisms,
                                           decay_policy=cfg.clip_decay_policy,
                                           learning_rate=cfg.clip_learning_rate,
                                           target_unclipped_quantile=cfg.target_unclipped_quantile,
                                           fraction_stddev=cfg.fraction_stddev)
net_opt = nn.Momentum(params=network.trainable_params(),
                      learning_rate=cfg.lr, momentum=cfg.momentum)
# Create a monitor for DP training. The monitor computes and prints the
# privacy budget (eps and delta) during training.
rdp_monitor = PrivacyMonitorFactory.create('rdp',
                                           num_samples=60000,
                                           batch_size=cfg.batch_size,
                                           initial_noise_multiplier=cfg.initial_noise_multiplier,
                                           per_print_times=234,
                                           noise_decay_mode=None)
```
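To make the clipping-plus-noise aggregation concrete, here is a small NumPy sketch of the per-step computation that `DPModel` presumably performs; it is illustrative only and is not MindArmour code:

```python
import numpy as np

def dp_sgd_aggregate(micro_grads, norm_bound=1.0, noise_multiplier=1.0):
    """Illustrative DP-SGD step: clip each micro-batch gradient to
    norm_bound in L2 norm, sum, add Gaussian noise, and average.
    A conceptual sketch, not MindArmour code."""
    clipped = [g * min(1.0, norm_bound / (np.linalg.norm(g) + 1e-12))
               for g in micro_grads]
    total = np.sum(clipped, axis=0)
    noise = np.random.normal(0.0, noise_multiplier * norm_bound,
                             size=total.shape)
    return (total + noise) / len(micro_grads)

# Example: 16 micro-batch gradients for a toy 10-parameter model.
grads = [np.random.randn(10) for _ in range(16)]
print(dp_sgd_aggregate(grads))
```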
2. Wrap the LeNet model as a differential privacy model: simply pass the network to `DPModel`.
```python
# Create the DP model for training.
model = DPModel(micro_batches=cfg.micro_batches,
                norm_bound=cfg.norm_bound,
                noise_mech=noise_mech,
                clip_mech=clip_mech,
                network=network,
                loss_fn=net_loss,
                optimizer=net_opt,
                # The diff truncates the call here; a metrics argument such
                # as the following is assumed from the surrounding tutorial.
                metrics={'accuracy': nn.Accuracy()})
```
3. Train and test the model.
```python
LOGGER.info(TAG, "============== Starting Training ==============")
model.train(cfg['epoch_size'], ds_train,
            callbacks=[ckpoint_cb, LossMonitor(), rdp_monitor],
            dataset_sink_mode=cfg.dataset_sink_mode)

LOGGER.info(TAG, "============== Starting Testing ==============")
ckpt_file_name = 'trained_ckpt_file/checkpoint_lenet-10_234.ckpt'
param_dict = load_checkpoint(ckpt_file_name)
load_param_into_net(network, param_dict)
ds_eval = generate_mnist_dataset(os.path.join(cfg.data_path, 'test'),
                                 batch_size=cfg.batch_size)
acc = model.eval(ds_eval, dataset_sink_mode=False)
LOGGER.info(TAG, "============== Accuracy: %s ==============", acc)
```
4. Run the command.
Execute the script from the command line:
```bash
python lenet5_dp.py
```
Replace `lenet5_dp.py` with the name of your script.
5. Display the result.
The accuracy of the LeNet model without differential privacy stabilizes at 99%, while the differentially private LeNet model with Gaussian noise and adaptive clipping converges, with accuracy stabilizing at about 97.6%.
```
============== Starting Training ==============
...
============== Starting Testing ==============
...
============== Accuracy: 0.9767 ==============
```
### References