提交 41752a9c 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!24 Add train module and optimizer module for differential privacy.

Merge pull request !24 from zheng-huanhuan/dp_pynative
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting, will be used in train.py
"""
from easydict import EasyDict as edict
mnist_cfg = edict({
'num_classes': 10,
'lr': 0.01,
'momentum': 0.9,
'epoch_size': 10,
'batch_size': 32,
'buffer_size': 1000,
'image_height': 32,
'image_width': 32,
'save_checkpoint_steps': 1875,
'keep_checkpoint_max': 10,
})
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
python lenet5_dp_model_train.py --data_path /YourDataPath --micro_batches=2
"""
import os
import argparse
import mindspore.nn as nn
from mindspore import context
from mindspore.train.callback import ModelCheckpoint
from mindspore.train.callback import CheckpointConfig
from mindspore.train.callback import LossMonitor
from mindspore.nn.metrics import Accuracy
from mindspore.train.serialization import load_checkpoint, load_param_into_net
import mindspore.dataset as ds
import mindspore.dataset.transforms.vision.c_transforms as CV
import mindspore.dataset.transforms.c_transforms as C
from mindspore.dataset.transforms.vision import Inter
import mindspore.common.dtype as mstype
from mindarmour.diff_privacy import DPModel
from mindarmour.diff_privacy import DPOptimizerClassFactory
from mindarmour.diff_privacy import PrivacyMonitorFactory
from mindarmour.utils.logger import LogUtil
from lenet5_net import LeNet5
from lenet5_config import mnist_cfg as cfg
LOGGER = LogUtil.get_instance()
TAG = 'Lenet5_train'
def generate_mnist_dataset(data_path, batch_size=32, repeat_size=1,
num_parallel_workers=1, sparse=True):
"""
create dataset for training or testing
"""
# define dataset
ds1 = ds.MnistDataset(data_path)
# define operation parameters
resize_height, resize_width = 32, 32
rescale = 1.0 / 255.0
shift = 0.0
# define map operations
resize_op = CV.Resize((resize_height, resize_width),
interpolation=Inter.LINEAR)
rescale_op = CV.Rescale(rescale, shift)
hwc2chw_op = CV.HWC2CHW()
type_cast_op = C.TypeCast(mstype.int32)
# apply map operations on images
if not sparse:
one_hot_enco = C.OneHot(10)
ds1 = ds1.map(input_columns="label", operations=one_hot_enco,
num_parallel_workers=num_parallel_workers)
type_cast_op = C.TypeCast(mstype.float32)
ds1 = ds1.map(input_columns="label", operations=type_cast_op,
num_parallel_workers=num_parallel_workers)
ds1 = ds1.map(input_columns="image", operations=resize_op,
num_parallel_workers=num_parallel_workers)
ds1 = ds1.map(input_columns="image", operations=rescale_op,
num_parallel_workers=num_parallel_workers)
ds1 = ds1.map(input_columns="image", operations=hwc2chw_op,
num_parallel_workers=num_parallel_workers)
# apply DatasetOps
buffer_size = 10000
ds1 = ds1.shuffle(buffer_size=buffer_size)
ds1 = ds1.batch(batch_size, drop_remainder=True)
ds1 = ds1.repeat(repeat_size)
return ds1
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='MindSpore MNIST Example')
parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU', 'CPU'],
help='device where the code will be implemented (default: Ascend)')
parser.add_argument('--data_path', type=str, default="./MNIST_unzip",
help='path where the dataset is saved')
parser.add_argument('--dataset_sink_mode', type=bool, default=False, help='dataset_sink_mode is False or True')
parser.add_argument('--micro_batches', type=float, default=None,
help='optional, if use differential privacy, need to set micro_batches')
parser.add_argument('--l2_norm_bound', type=float, default=1,
help='optional, if use differential privacy, need to set l2_norm_bound')
parser.add_argument('--initial_noise_multiplier', type=float, default=0.001,
help='optional, if use differential privacy, need to set initial_noise_multiplier')
args = parser.parse_args()
context.set_context(mode=context.PYNATIVE_MODE, device_target=args.device_target, enable_mem_reuse=False)
network = LeNet5()
net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,
keep_checkpoint_max=cfg.keep_checkpoint_max)
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet",
directory='./trained_ckpt_file/',
config=config_ck)
ds_train = generate_mnist_dataset(os.path.join(args.data_path, "train"),
cfg.batch_size,
cfg.epoch_size)
if args.micro_batches and cfg.batch_size % args.micro_batches != 0:
raise ValueError("Number of micro_batches should divide evenly batch_size")
gaussian_mech = DPOptimizerClassFactory(args.micro_batches)
gaussian_mech.set_mechanisms('Gaussian',
norm_bound=args.l2_norm_bound,
initial_noise_multiplier=args.initial_noise_multiplier)
net_opt = gaussian_mech.create('Momentum')(params=network.trainable_params(),
learning_rate=cfg.lr,
momentum=cfg.momentum)
micro_size = int(cfg.batch_size // args.micro_batches)
rdp_monitor = PrivacyMonitorFactory.create('rdp',
num_samples=60000,
batch_size=micro_size,
initial_noise_multiplier=args.initial_noise_multiplier,
per_print_times=10)
model = DPModel(micro_batches=args.micro_batches,
norm_clip=args.l2_norm_bound,
dp_mech=gaussian_mech.mech,
network=network,
loss_fn=net_loss,
optimizer=net_opt,
metrics={"Accuracy": Accuracy()})
LOGGER.info(TAG, "============== Starting Training ==============")
model.train(cfg['epoch_size'], ds_train, callbacks=[ckpoint_cb, LossMonitor(), rdp_monitor],
dataset_sink_mode=args.dataset_sink_mode)
LOGGER.info(TAG, "============== Starting Testing ==============")
ckpt_file_name = 'trained_ckpt_file/checkpoint_lenet-10_1875.ckpt'
param_dict = load_checkpoint(ckpt_file_name)
load_param_into_net(network, param_dict)
ds_eval = generate_mnist_dataset(os.path.join(args.data_path, 'test'), batch_size=cfg.batch_size)
acc = model.eval(ds_eval, dataset_sink_mode=False)
LOGGER.info(TAG, "============== Accuracy: %s ==============", acc)
......@@ -4,7 +4,13 @@ This module provide Differential Privacy feature to protect user privacy.
from .mechanisms.mechanisms import GaussianRandom
from .mechanisms.mechanisms import AdaGaussianRandom
from .mechanisms.mechanisms import MechanismsFactory
from .monitor.monitor import PrivacyMonitorFactory
from .optimizer.optimizer import DPOptimizerClassFactory
from .train.model import DPModel
__all__ = ['GaussianRandom',
'AdaGaussianRandom',
'MechanismsFactory']
'MechanismsFactory',
'PrivacyMonitorFactory',
'DPOptimizerClassFactory',
'DPModel']
......@@ -60,10 +60,6 @@ class Mechanisms(Cell):
"""
Basic class of noise generated mechanism.
"""
def __init__(self):
pass
def construct(self, shape):
"""
Construct function.
......
......@@ -47,7 +47,7 @@ class PrivacyMonitorFactory:
parameters used for creating a privacy monitor.
Returns:
PrivacyMonitor, a privacy monitor.
Callback, a privacy monitor.
Examples:
>>> rdp = PrivacyMonitorFactory.create(policy='rdp',
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Differential privacy optimizer.
"""
import mindspore as ms
from mindspore import nn
from mindspore import Tensor
from mindarmour.diff_privacy.mechanisms.mechanisms import MechanismsFactory
class DPOptimizerClassFactory:
"""
Factory class of Optimizer.
Args:
micro_batches (int): The number of small batches split from an origianl batch. Default: None.
Returns:
Optimizer, Optimizer class
Examples:
>>> GaussianSGD = DPOptimizerClassFactory(micro_batches=2)
>>> GaussianSGD.set_mechanisms('Gaussian', norm_bound=1.5, initial_noise_multiplier=5.0)
>>> net_opt = GaussianSGD.create('SGD')(params=network.trainable_params(),
>>> learning_rate=cfg.lr,
>>> momentum=cfg.momentum)
"""
def __init__(self, micro_batches=None):
self._mech_factory = MechanismsFactory()
self.mech = None
self._micro_batches = micro_batches
def set_mechanisms(self, policy, *args, **kwargs):
"""
Get noise mechanism object.
Args:
policy (str): Choose mechanism type.
"""
self.mech = self._mech_factory.create(policy, *args, **kwargs)
def create(self, policy, *args, **kwargs):
"""
Create DP optimizer.
Args:
policy (str): Choose original optimizer type.
Returns:
Optimizer, A optimizer with DP.
"""
if policy == 'SGD':
cls = self._get_dp_optimizer_class(nn.SGD, self.mech, self._micro_batches, *args, **kwargs)
return cls
if policy == 'Momentum':
cls = self._get_dp_optimizer_class(nn.Momentum, self.mech, self._micro_batches, *args, **kwargs)
return cls
if policy == 'Adam':
cls = self._get_dp_optimizer_class(nn.Adam, self.mech, self._micro_batches, *args, **kwargs)
return cls
if policy == 'AdamWeightDecay':
cls = self._get_dp_optimizer_class(nn.AdamWeightDecay, self.mech, self._micro_batches, *args, **kwargs)
return cls
if policy == 'AdamWeightDecayDynamicLR':
cls = self._get_dp_optimizer_class(nn.AdamWeightDecayDynamicLR,
self.mech,
self._micro_batches,
*args, **kwargs)
return cls
raise NameError("The {} is not implement, please choose ['SGD', 'Momentum', 'AdamWeightDecay', "
"'Adam', 'AdamWeightDecayDynamicLR']".format(policy))
def _get_dp_optimizer_class(self, cls, mech, micro_batches):
"""
Wrap original mindspore optimizer with `self._mech`.
"""
class DPOptimizer(cls):
"""
Initialize the DPOptimizerClass.
Returns:
Optimizer, Optimizer class.
"""
def __init__(self, *args, **kwargs):
super(DPOptimizer, self).__init__(*args, **kwargs)
self._mech = mech
def construct(self, gradients):
"""
construct a compute flow.
"""
g_len = len(gradients)
gradient_noise = list(gradients)
for i in range(g_len):
gradient_noise[i] = gradient_noise[i].asnumpy()
gradient_noise[i] = self._mech(gradient_noise[i].shape).asnumpy() + gradient_noise[i]
gradient_noise[i] = gradient_noise[i] / micro_batches
gradient_noise[i] = Tensor(gradient_noise[i], ms.float32)
gradients = tuple(gradient_noise)
gradients = super(DPOptimizer, self).construct(gradients)
return gradients
return DPOptimizer
此差异已折叠。
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
DP-Model test.
"""
import pytest
import numpy as np
from mindspore import nn
from mindspore.nn import SGD
from mindspore.model_zoo.lenet import LeNet5
from mindspore import context
import mindspore.dataset as ds
from mindarmour.diff_privacy import DPOptimizerClassFactory
from mindarmour.diff_privacy import DPModel
def dataset_generator(batch_size, batches):
data = np.random.random((batches * batch_size, 1, 32, 32)).astype(np.float32)
label = np.random.randint(0, 10, batches * batch_size).astype(np.int32)
for i in range(batches):
yield data[i * batch_size:(i + 1) * batch_size], label[i * batch_size:(i + 1) * batch_size]
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_card
@pytest.mark.component_mindarmour
def test_dp_model():
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
l2_norm_bound = 1.0
initial_noise_multiplier = 0.01
net = LeNet5()
batch_size = 32
batches = 128
epochs = 1
loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
optim = SGD(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
gaussian_mech = DPOptimizerClassFactory()
gaussian_mech.set_mechanisms('Gaussian',
norm_bound=l2_norm_bound,
initial_noise_multiplier=initial_noise_multiplier)
model = DPModel(micro_batches=2,
norm_clip=l2_norm_bound,
dp_mech=gaussian_mech.mech,
network=net,
loss_fn=loss,
optimizer=optim,
metrics=None)
ms_ds = ds.GeneratorDataset(dataset_generator(batch_size, batches), ['data', 'label'])
ms_ds.set_dataset_size(batch_size * batches)
model.train(epochs, ms_ds)
......@@ -23,7 +23,7 @@ from mindspore.train import Model
import mindspore.context as context
from mindspore.model_zoo.lenet import LeNet5
from mindarmour.diff_privacy.monitor.monitor import PrivacyMonitorFactory
from mindarmour.diff_privacy import PrivacyMonitorFactory
from mindarmour.utils.logger import LogUtil
LOGGER = LogUtil.get_instance()
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
from mindspore import nn
from mindspore import context
from mindspore.model_zoo.lenet import LeNet5
from mindspore.train.model import Model
from mindarmour.diff_privacy import DPOptimizerClassFactory
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_card
@pytest.mark.component_mindarmour
def test_optimizer():
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
network = LeNet5()
lr = 0.01
momentum = 0.9
micro_batches = 2
loss = nn.SoftmaxCrossEntropyWithLogits()
gaussian_mech = DPOptimizerClassFactory(micro_batches)
gaussian_mech.set_mechanisms('Gaussian', norm_bound=1.5, initial_noise_multiplier=5.0)
net_opt = gaussian_mech.create('SGD')(params=network.trainable_params(), learning_rate=lr,
momentum=momentum)
_ = Model(network, loss_fn=loss, optimizer=net_opt, metrics=None)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_inference
@pytest.mark.env_card
@pytest.mark.component_mindarmour
def test_optimizer_gpu():
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
network = LeNet5()
lr = 0.01
momentum = 0.9
micro_batches = 2
loss = nn.SoftmaxCrossEntropyWithLogits()
gaussian_mech = DPOptimizerClassFactory(micro_batches)
gaussian_mech.set_mechanisms('Gaussian', norm_bound=1.5, initial_noise_multiplier=5.0)
net_opt = gaussian_mech.create('SGD')(params=network.trainable_params(), learning_rate=lr,
momentum=momentum)
_ = Model(network, loss_fn=loss, optimizer=net_opt, metrics=None)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_card
@pytest.mark.component_mindarmour
def test_optimizer_cpu():
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
network = LeNet5()
lr = 0.01
momentum = 0.9
micro_batches = 2
loss = nn.SoftmaxCrossEntropyWithLogits()
gaussian_mech = DPOptimizerClassFactory(micro_batches)
gaussian_mech.set_mechanisms('Gaussian', norm_bound=1.5, initial_noise_multiplier=5.0)
net_opt = gaussian_mech.create('SGD')(params=network.trainable_params(), learning_rate=lr,
momentum=momentum)
_ = Model(network, loss_fn=loss, optimizer=net_opt, metrics=None)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册