diff --git a/example/membership_inference_demo/eval.py b/example/membership_inference_demo/eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b735ac87bd28ce53e8d18d2542f153ba381fde5
--- /dev/null
+++ b/example/membership_inference_demo/eval.py
@@ -0,0 +1,132 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Eval"""
+import os
+import argparse
+import datetime
+import mindspore.nn as nn
+
+from mindspore import context
+from mindspore.nn.optim.momentum import Momentum
+from mindspore.train.model import Model
+from mindspore.train.serialization import load_checkpoint, load_param_into_net
+from mindspore.ops import operations as P
+from mindspore.ops import functional as F
+from mindspore.common import dtype as mstype
+from mindarmour.utils import LogUtil
+
+from vgg.vgg import vgg16
+from vgg.dataset import vgg_create_dataset100
+from vgg.config import cifar_cfg as cfg
+
+
+class ParameterReduce(nn.Cell):
+    """ParameterReduce"""
+    def __init__(self):
+        super(ParameterReduce, self).__init__()
+        self.cast = P.Cast()
+        self.reduce = P.AllReduce()
+
+    def construct(self, x):
+        one = self.cast(F.scalar_to_array(1.0), mstype.float32)
+        out = x*one
+        ret = self.reduce(out)
+        return ret
+
+
+def parse_args(cloud_args=None):
+    """parse_args"""
+    parser = argparse.ArgumentParser('mindspore classification test')
+    parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'],
+                        help='device where the code will be implemented. (Default: Ascend)')
+    # dataset related
+    parser.add_argument('--data_path', type=str, default='', help='eval data dir')
+    parser.add_argument('--per_batch_size', default=32, type=int, help='batch size for per npu')
+    # network related
+    parser.add_argument('--graph_ckpt', type=int, default=1, help='graph ckpt or feed ckpt')
+    parser.add_argument('--pre_trained', default='', type=str, help='full path of pretrained model to load. '
+                        'If it is a directory, it will test all ckpt')
+
+    # logging related
+    parser.add_argument('--log_path', type=str, default='outputs/', help='path to save log')
+    parser.add_argument('--rank', type=int, default=0, help='local rank of distributed')
+    parser.add_argument('--group_size', type=int, default=1, help='world size of distributed')
+
+    args_opt = parser.parse_args()
+    args_opt = merge_args(args_opt, cloud_args)
+
+    args_opt.image_size = cfg.image_size
+    args_opt.num_classes = cfg.num_classes
+    args_opt.per_batch_size = cfg.batch_size
+    args_opt.momentum = cfg.momentum
+    args_opt.weight_decay = cfg.weight_decay
+    args_opt.buffer_size = cfg.buffer_size
+    args_opt.pad_mode = cfg.pad_mode
+    args_opt.padding = cfg.padding
+    args_opt.has_bias = cfg.has_bias
+    args_opt.batch_norm = cfg.batch_norm
+    args_opt.initialize_mode = cfg.initialize_mode
+    args_opt.has_dropout = cfg.has_dropout
+
+    args_opt.image_size = list(map(int, args_opt.image_size.split(',')))
+
+    return args_opt
+
+
+def merge_args(args, cloud_args):
+    """merge_args"""
+    args_dict = vars(args)
+    if isinstance(cloud_args, dict):
+        for key in cloud_args.keys():
+            val = cloud_args[key]
+            if key in args_dict and val:
+                arg_type = type(args_dict[key])
+                if arg_type is not type(None):
+                    val = arg_type(val)
+                args_dict[key] = val
+    return args
+
+
+def test(cloud_args=None):
+    """test"""
+    args = parse_args(cloud_args)
+    context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
+                        device_target=args.device_target, save_graphs=False)
+    if os.getenv('DEVICE_ID', "not_set").isdigit():
+        context.set_context(device_id=int(os.getenv('DEVICE_ID')))
+
+    args.outputs_dir = os.path.join(args.log_path,
+                                    datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
+
+    args.logger = LogUtil.get_instance()
+    args.logger.set_level(20)
+
+    net = vgg16(num_classes=args.num_classes, args=args)
+    opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, args.momentum,
+                   weight_decay=args.weight_decay)
+    loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False)
+    model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'})
+
+    param_dict = load_checkpoint(args.pre_trained)
+    load_param_into_net(net, param_dict)
+    net.set_train(False)
+
+    dataset_test = vgg_create_dataset100(args.data_path, args.image_size, args.per_batch_size, training=False)
+    res = model.eval(dataset_test)
+    print("result: ", res)
+
+
+if __name__ == "__main__":
+    test()
diff --git a/example/membership_inference_demo/main.py b/example/membership_inference_demo/main.py
new file mode 100644
index 0000000000000000000000000000000000000000..d87a3c6130f999fe999325ebdad7939b70b335e5
--- /dev/null
+++ b/example/membership_inference_demo/main.py
@@ -0,0 +1,122 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+Examples of membership inference
+"""
+import argparse
+import sys
+
+from vgg.vgg import vgg16
+from vgg.config import cifar_cfg as cfg
+from vgg.utils.util import get_param_groups
+from vgg.dataset import vgg_create_dataset100
+
+import numpy as np
+
+from mindspore.train import Model
+from mindspore.train.serialization import load_param_into_net, load_checkpoint
+import mindspore.nn as nn
+from mindarmour.diff_privacy.evaluation.membership_inference import MembershipInference
+from mindarmour.utils import LogUtil
+logging = LogUtil.get_instance()
+logging.set_level(20)
+
+sys.path.append("../../")
+
+TAG = "membership inference example"
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser("main case arg parser.")
+    parser.add_argument("--device_target", type=str, default="Ascend",
+                        choices=["Ascend"])
+    parser.add_argument("--data_path", type=str, required=True,
+                        help="Data home path for Cifar100.")
+    parser.add_argument("--pre_trained", type=str, required=True,
+                        help="Checkpoint path.")
+    args = parser.parse_args()
+    args.num_classes = cfg.num_classes
+    args.batch_norm = cfg.batch_norm
+    args.has_dropout = cfg.has_dropout
+    args.has_bias = cfg.has_bias
+    args.initialize_mode = cfg.initialize_mode
+    args.padding = cfg.padding
+    args.pad_mode = cfg.pad_mode
+    args.weight_decay = cfg.weight_decay
+    args.loss_scale = cfg.loss_scale
+
+    # load the pretrained model
+    net = vgg16(args.num_classes, args)
+    loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
+    opt = nn.Momentum(params=get_param_groups(net), learning_rate=0.1, momentum=0.9,
+                      weight_decay=args.weight_decay, loss_scale=args.loss_scale)
+    load_param_into_net(net, load_checkpoint(args.pre_trained))
+    model = Model(network=net, loss_fn=loss, optimizer=opt)
+    logging.info(TAG, "The model is loaded.")
+    attacker = MembershipInference(model)
+    config = [
+        {
+            "method": "knn",
+            "params": {
+                "n_neighbors": [3, 5, 7]
+            }
+        },
+        {
+            "method": "lr",
+            "params": {
+                "C": np.logspace(-4, 2, 10)
+            }
+        },
+        {
+            "method": "mlp",
+            "params": {
+                "hidden_layer_sizes": [(64,), (32, 32)],
+                "solver": ["adam"],
+                "alpha": [0.0001, 0.001, 0.01]
+            }
+        },
+        {
+            "method": "rf",
+            "params": {
+                "n_estimators": [100],
+                "max_features": ["auto", "sqrt"],
+                "max_depth": [5, 10, 20, None],
+                "min_samples_split": [2, 5, 10],
+                "min_samples_leaf": [1, 2, 4]
+            }
+        }
+    ]
+
+    # load and split dataset
+    train_dataset = vgg_create_dataset100(data_home=args.data_path, image_size=(224, 224),
+                                          batch_size=64, num_samples=10000, shuffle=False)
+    test_dataset = vgg_create_dataset100(data_home=args.data_path, image_size=(224, 224),
+                                         batch_size=64, num_samples=10000, shuffle=False, training=False)
+    train_train, eval_train = train_dataset.split([0.8, 0.2])
+    train_test, eval_test = test_dataset.split([0.8, 0.2])
+    logging.info(TAG, "Data loading is complete.")
+
+    logging.info(TAG, "Start training the inference model.")
+    attacker.train(train_train, train_test, config)
+    logging.info(TAG, "Training of the inference model is complete.")
+
+    logging.info(TAG, "Start the evaluation phase.")
+    metrics = ["precision", "accuracy", "recall"]
+    result = attacker.eval(eval_train, eval_test, metrics)
+
+    # Show the metrics for each attack method.
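+    # A precision or recall of -1 means that metric was undefined for the attacker
+    # (zero denominator in _eval_info).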
+    count = len(config)
+    for i in range(count):
+        print("Method: {}, {}".format(config[i]["method"], result[i]))
diff --git a/example/membership_inference_demo/train.py b/example/membership_inference_demo/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..944da0b3bead0fc8a472a8018c9f956b51a5d06d
--- /dev/null
+++ b/example/membership_inference_demo/train.py
@@ -0,0 +1,198 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+#################train vgg16 example on cifar100########################
+python train.py --data_path=$DATA_HOME --device_id=$DEVICE_ID
+"""
+import argparse
+import datetime
+import os
+import random
+
+import numpy as np
+
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore import context
+from mindspore.nn.optim.momentum import Momentum
+from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
+from mindspore.train.model import Model
+from mindspore.train.serialization import load_param_into_net, load_checkpoint
+from mindarmour.utils import LogUtil
+
+from vgg.dataset import vgg_create_dataset100
+from vgg.warmup_step_lr import warmup_step_lr
+from vgg.warmup_cosine_annealing_lr import warmup_cosine_annealing_lr
+from vgg.warmup_step_lr import lr_steps
+from vgg.utils.util import get_param_groups
+from vgg.vgg import vgg16
+from vgg.config import cifar_cfg as cfg
+
+TAG = "train"
+
+random.seed(1)
+np.random.seed(1)
+
+
+def parse_args(cloud_args=None):
+    """parameters"""
+    parser = argparse.ArgumentParser('mindspore classification training')
+    parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'],
+                        help='device where the code will be implemented. (Default: Ascend)')
+    parser.add_argument('--device_id', type=int, default=1, help='device id of GPU or Ascend. (Default: 1)')
+
+    # dataset related
+    parser.add_argument('--data_path', type=str, default='', help='train data dir')
+
+    # network related
+    parser.add_argument('--pre_trained', default='', type=str, help='model_path, local pretrained model to load')
+    parser.add_argument('--lr_gamma', type=float, default=0.1,
+                        help='decrease lr by a factor of exponential lr_scheduler')
+    parser.add_argument('--eta_min', type=float, default=0., help='eta_min in cosine_annealing scheduler')
+    parser.add_argument('--T_max', type=int, default=150, help='T-max in cosine_annealing scheduler')
+
+    # logging and checkpoint related
+    parser.add_argument('--log_interval', type=int, default=100, help='logging interval')
+    parser.add_argument('--ckpt_path', type=str, default='outputs/', help='checkpoint save location')
+    parser.add_argument('--ckpt_interval', type=int, default=2, help='ckpt_interval')
+    parser.add_argument('--is_save_on_master', type=int, default=1, help='save ckpt on master or all rank')
+
+    args_opt = parser.parse_args()
+    args_opt = merge_args(args_opt, cloud_args)
+
+    args_opt.rank = 0
+    args_opt.group_size = 1
+    args_opt.label_smooth = cfg.label_smooth
+    args_opt.label_smooth_factor = cfg.label_smooth_factor
+    args_opt.lr_scheduler = cfg.lr_scheduler
+    args_opt.loss_scale = cfg.loss_scale
+    args_opt.max_epoch = cfg.max_epoch
+    args_opt.warmup_epochs = cfg.warmup_epochs
+    args_opt.lr = cfg.lr
+    args_opt.lr_init = cfg.lr_init
+    args_opt.lr_max = cfg.lr_max
+    args_opt.momentum = cfg.momentum
+    args_opt.weight_decay = cfg.weight_decay
+    args_opt.per_batch_size = cfg.batch_size
+    args_opt.num_classes = cfg.num_classes
+    args_opt.buffer_size = cfg.buffer_size
+    args_opt.ckpt_save_max = cfg.keep_checkpoint_max
+    args_opt.pad_mode = cfg.pad_mode
+    args_opt.padding = cfg.padding
+    args_opt.has_bias = cfg.has_bias
+    args_opt.batch_norm = cfg.batch_norm
+    args_opt.initialize_mode = cfg.initialize_mode
+    args_opt.has_dropout = cfg.has_dropout
+
+    args_opt.lr_epochs = list(map(int, cfg.lr_epochs.split(',')))
+    args_opt.image_size = list(map(int, cfg.image_size.split(',')))
+
+    return args_opt
+
+
+def merge_args(args_opt, cloud_args):
+    """Merge cloud_args into args_opt."""
+    args_dict = vars(args_opt)
+    if isinstance(cloud_args, dict):
+        for key_arg in cloud_args.keys():
+            val = cloud_args[key_arg]
+            if key_arg in args_dict and val:
+                arg_type = type(args_dict[key_arg])
+                if arg_type is not type(None):
+                    val = arg_type(val)
+                args_dict[key_arg] = val
+    return args_opt
+
+
+if __name__ == '__main__':
+    args = parse_args()
+
+    device_num = int(os.environ.get("DEVICE_NUM", 1))
+
+    context.set_context(device_id=args.device_id)
+    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
+
+    # select whether only the master rank saves ckpt or all ranks save; compatible with model parallel
+    args.rank_save_ckpt_flag = 0
+    if args.is_save_on_master:
+        if args.rank == 0:
+            args.rank_save_ckpt_flag = 1
+    else:
+        args.rank_save_ckpt_flag = 1
+
+    # logger
+    args.outputs_dir = os.path.join(args.ckpt_path,
+                                    datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
+    args.logger = LogUtil.get_instance()
+    args.logger.set_level(20)
+
+    # load train data set
+    dataset = vgg_create_dataset100(args.data_path, args.image_size, args.per_batch_size, args.rank, args.group_size)
+    batch_num = dataset.get_dataset_size()
+    args.steps_per_epoch = dataset.get_dataset_size()
+
+    # network
+    args.logger.info(TAG, 'start create network')
+
+    # get network and init
+    network = vgg16(args.num_classes, args)
+
+    # pre_trained
+    if args.pre_trained:
+        load_param_into_net(network, load_checkpoint(args.pre_trained))
+
+    # lr scheduler
+    if args.lr_scheduler == 'exponential':
+        lr = warmup_step_lr(args.lr,
+                            args.lr_epochs,
+                            args.steps_per_epoch,
+                            args.warmup_epochs,
+                            args.max_epoch,
+                            gamma=args.lr_gamma,
+                            )
+    elif args.lr_scheduler == 'cosine_annealing':
+        lr = warmup_cosine_annealing_lr(args.lr,
+                                        args.steps_per_epoch,
+                                        args.warmup_epochs,
+                                        args.max_epoch,
+                                        args.T_max,
+                                        args.eta_min)
+    elif args.lr_scheduler == 'step':
+        lr = lr_steps(0, lr_init=args.lr_init, lr_max=args.lr_max, warmup_epochs=args.warmup_epochs,
+                      total_epochs=args.max_epoch, steps_per_epoch=batch_num)
+    else:
+        raise NotImplementedError(args.lr_scheduler)
+
+    # optimizer
+    opt = Momentum(params=get_param_groups(network),
+                   learning_rate=Tensor(lr),
+                   momentum=args.momentum,
+                   weight_decay=args.weight_decay,
+                   loss_scale=args.loss_scale)
+
+    loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False)
+    model = Model(network, loss_fn=loss, optimizer=opt, metrics={'acc'},
+                  amp_level="O2", keep_batchnorm_fp32=False, loss_scale_manager=None)
+
+    # checkpoint save; initialize callbacks so model.train does not hit a NameError
+    # when no rank saves a checkpoint
+    callbacks = None
+    if args.rank_save_ckpt_flag:
+        ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval*args.steps_per_epoch,
+                                       keep_checkpoint_max=args.ckpt_save_max)
+        ckpt_cb = ModelCheckpoint(config=ckpt_config,
+                                  directory=args.outputs_dir,
+                                  prefix='{}'.format(args.rank))
+        callbacks = ckpt_cb
+
+    model.train(args.max_epoch, dataset, callbacks=callbacks)
diff --git a/example/membership_inference_demo/vgg/__init__.py b/example/membership_inference_demo/vgg/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..301ef9dcb71d51fb0b849da4c221c67947ab09df
--- /dev/null
+++ b/example/membership_inference_demo/vgg/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
diff --git a/example/membership_inference_demo/vgg/config.py b/example/membership_inference_demo/vgg/config.py
new file mode 100755
index 0000000000000000000000000000000000000000..86b9192332aece178e3c1e2d15a35d4fb14a600c
--- /dev/null
+++ b/example/membership_inference_demo/vgg/config.py
@@ -0,0 +1,45 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================ +""" +network config setting, will be used in train.py and eval.py +""" +from easydict import EasyDict as edict + +# config for vgg16, cifar100 +cifar_cfg = edict({ + "num_classes": 100, + "lr": 0.01, + "lr_init": 0.01, + "lr_max": 0.1, + "lr_epochs": '30,60,90,120', + "lr_scheduler": "step", + "warmup_epochs": 5, + "batch_size": 64, + "max_epoch": 100, + "momentum": 0.9, + "weight_decay": 5e-4, + "loss_scale": 1.0, + "label_smooth": 0, + "label_smooth_factor": 0, + "buffer_size": 10, + "image_size": '224,224', + "pad_mode": 'same', + "padding": 0, + "has_bias": False, + "batch_norm": True, + "keep_checkpoint_max": 10, + "initialize_mode": "XavierUniform", + "has_dropout": False +}) diff --git a/example/membership_inference_demo/vgg/crossentropy.py b/example/membership_inference_demo/vgg/crossentropy.py new file mode 100755 index 0000000000000000000000000000000000000000..5118cb5161218035cc88ac18dfce6ea086322566 --- /dev/null +++ b/example/membership_inference_demo/vgg/crossentropy.py @@ -0,0 +1,39 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""define loss function for network""" +from mindspore.nn.loss.loss import _Loss +from mindspore.ops import operations as P +from mindspore.ops import functional as F +from mindspore import Tensor +from mindspore.common import dtype as mstype +import mindspore.nn as nn + + +class CrossEntropy(_Loss): + """the redefined loss function with SoftmaxCrossEntropyWithLogits""" + + def __init__(self, smooth_factor=0., num_classes=1001): + super(CrossEntropy, self).__init__() + self.onehot = P.OneHot() + self.on_value = Tensor(1.0 - smooth_factor, mstype.float32) + self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32) + self.ce = nn.SoftmaxCrossEntropyWithLogits() + self.mean = P.ReduceMean(False) + + def construct(self, logit, label): + one_hot_label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value) + loss = self.ce(logit, one_hot_label) + loss = self.mean(loss, 0) + return loss diff --git a/example/membership_inference_demo/vgg/dataset.py b/example/membership_inference_demo/vgg/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..f687a15281258554374633e16749df23bd59b53a --- /dev/null +++ b/example/membership_inference_demo/vgg/dataset.py @@ -0,0 +1,75 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +dataset processing. +""" +import os +from mindspore.common import dtype as mstype +import mindspore.dataset as de +import mindspore.dataset.transforms.c_transforms as C +import mindspore.dataset.transforms.vision.c_transforms as vision + + +def vgg_create_dataset100(data_home, image_size, batch_size, rank_id=0, rank_size=1, repeat_num=1, + training=True, num_samples=None, shuffle=True): + """Data operations.""" + de.config.set_seed(1) + data_dir = os.path.join(data_home, "train") + if not training: + data_dir = os.path.join(data_home, "test") + + if num_samples is not None: + data_set = de.Cifar100Dataset(data_dir, num_shards=rank_size, shard_id=rank_id, + num_samples=num_samples, shuffle=shuffle) + else: + data_set = de.Cifar100Dataset(data_dir, num_shards=rank_size, shard_id=rank_id) + + input_columns = ["fine_label"] + output_columns = ["label"] + data_set = data_set.rename(input_columns=input_columns, output_columns=output_columns) + data_set = data_set.project(["image", "label"]) + + rescale = 1.0 / 255.0 + shift = 0.0 + + # define map operations + random_crop_op = vision.RandomCrop((32, 32), (4, 4, 4, 4)) # padding_mode default CONSTANT + random_horizontal_op = vision.RandomHorizontalFlip() + resize_op = vision.Resize(image_size) # interpolation default BILINEAR + rescale_op = vision.Rescale(rescale, shift) + normalize_op = vision.Normalize((0.4465, 0.4822, 0.4914), (0.2010, 0.1994, 0.2023)) + changeswap_op = vision.HWC2CHW() + type_cast_op = C.TypeCast(mstype.int32) + + c_trans = [] + if training: + c_trans = [random_crop_op, random_horizontal_op] + c_trans += [resize_op, rescale_op, normalize_op, + changeswap_op] + + # apply map operations on images + data_set = data_set.map(input_columns="label", operations=type_cast_op) + data_set = data_set.map(input_columns="image", operations=c_trans) + + # apply repeat operations + data_set = data_set.repeat(repeat_num) + + # apply shuffle operations + # data_set = data_set.shuffle(buffer_size=1000) + + # apply batch operations + data_set = data_set.batch(batch_size=batch_size, drop_remainder=True) + + return data_set diff --git a/example/membership_inference_demo/vgg/linear_warmup.py b/example/membership_inference_demo/vgg/linear_warmup.py new file mode 100644 index 0000000000000000000000000000000000000000..716526046fc6ca82e808f8df54cf0f5478e09967 --- /dev/null +++ b/example/membership_inference_demo/vgg/linear_warmup.py @@ -0,0 +1,23 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +linear warm up learning rate. 
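+The learning rate increases linearly from init_lr to base_lr over warmup_steps steps.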
+""" + + +def linear_warmup_lr(current_step, warmup_steps, base_lr, init_lr): + lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps) + lr = float(init_lr) + lr_inc*current_step + return lr diff --git a/example/membership_inference_demo/vgg/utils/util.py b/example/membership_inference_demo/vgg/utils/util.py new file mode 100644 index 0000000000000000000000000000000000000000..6f84045a8981a293eceab9cb77c20c66a339cf59 --- /dev/null +++ b/example/membership_inference_demo/vgg/utils/util.py @@ -0,0 +1,36 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Util class or function.""" + + +def get_param_groups(network): + """Param groups for optimizer.""" + decay_params = [] + no_decay_params = [] + for x in network.trainable_params(): + parameter_name = x.name + if parameter_name.endswith('.bias'): + # all bias not using weight decay + no_decay_params.append(x) + elif parameter_name.endswith('.gamma'): + # bn weight bias not using weight decay, be carefully for now x not include BN + no_decay_params.append(x) + elif parameter_name.endswith('.beta'): + # bn weight bias not using weight decay, be carefully for now x not include BN + no_decay_params.append(x) + else: + decay_params.append(x) + + return [{'params': no_decay_params, 'weight_decay': 0.0}, {'params': decay_params}] diff --git a/example/membership_inference_demo/vgg/utils/var_init.py b/example/membership_inference_demo/vgg/utils/var_init.py new file mode 100644 index 0000000000000000000000000000000000000000..053e6d03f0635da2db81a7730bf7922d42f898ac --- /dev/null +++ b/example/membership_inference_demo/vgg/utils/var_init.py @@ -0,0 +1,214 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Initialize. +""" +import math +from functools import reduce +import numpy as np +import mindspore.nn as nn +from mindspore.common import initializer as init + +def _calculate_gain(nonlinearity, param=None): + r""" + Return the recommended gain value for the given nonlinearity function. 
+
+    The values are as follows:
+    ================= ====================================================
+    nonlinearity      gain
+    ================= ====================================================
+    Linear / Identity :math:`1`
+    Conv{1,2,3}D      :math:`1`
+    Sigmoid           :math:`1`
+    Tanh              :math:`\frac{5}{3}`
+    ReLU              :math:`\sqrt{2}`
+    Leaky Relu        :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}`
+    ================= ====================================================
+
+    Args:
+        nonlinearity: the non-linear function
+        param: optional parameter for the non-linear function
+
+    Examples:
+        >>> gain = _calculate_gain('leaky_relu', 0.2)  # leaky_relu with negative_slope=0.2
+    """
+    linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d']
+    if nonlinearity in linear_fns or nonlinearity == 'sigmoid':
+        return 1
+    if nonlinearity == 'tanh':
+        return 5.0 / 3
+    if nonlinearity == 'relu':
+        return math.sqrt(2.0)
+    if nonlinearity == 'leaky_relu':
+        if param is None:
+            negative_slope = 0.01
+        elif not isinstance(param, bool) and isinstance(param, int) or isinstance(param, float):
+            negative_slope = param
+        else:
+            raise ValueError("negative_slope {} not a valid number".format(param))
+        return math.sqrt(2.0 / (1 + negative_slope**2))
+
+    raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
+
+def _assignment(arr, num):
+    """Assign the value of `num` to `arr`."""
+    if arr.shape == ():
+        arr = arr.reshape((1))
+        arr[:] = num
+        arr = arr.reshape(())
+    else:
+        if isinstance(num, np.ndarray):
+            arr[:] = num[:]
+        else:
+            arr[:] = num
+    return arr
+
+def _calculate_in_and_out(arr):
+    """
+    Calculate n_in and n_out.
+
+    Args:
+        arr (Array): Input array.
+
+    Returns:
+        Tuple, a tuple with two elements, the first element is `n_in` and the second element is `n_out`.
+    """
+    dim = len(arr.shape)
+    if dim < 2:
+        raise ValueError("If initialize data with xavier uniform, the dimension of data must be greater than 1.")
+
+    n_in = arr.shape[1]
+    n_out = arr.shape[0]
+
+    if dim > 2:
+        counter = reduce(lambda x, y: x*y, arr.shape[2:])
+        n_in *= counter
+        n_out *= counter
+    return n_in, n_out
+
+def _select_fan(array, mode):
+    mode = mode.lower()
+    valid_modes = ['fan_in', 'fan_out']
+    if mode not in valid_modes:
+        raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes))
+
+    fan_in, fan_out = _calculate_in_and_out(array)
+    return fan_in if mode == 'fan_in' else fan_out
+
+class KaimingInit(init.Initializer):
+    r"""
+    Base Class. Initialize the array with He kaiming algorithm.
+
+    Args:
+        a: the negative slope of the rectifier used after this layer (only
+            used with ``'leaky_relu'``)
+        mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
+            preserves the magnitude of the variance of the weights in the
+            forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
+            backwards pass.
+        nonlinearity: the non-linear function, recommended to use only with
+            ``'relu'`` or ``'leaky_relu'`` (default).
+    """
+    def __init__(self, a=0, mode='fan_in', nonlinearity='leaky_relu'):
+        super(KaimingInit, self).__init__()
+        self.mode = mode
+        self.gain = _calculate_gain(nonlinearity, a)
+    def _initialize(self, arr):
+        pass
+
+
+class KaimingUniform(KaimingInit):
+    r"""
+    Initialize the array with He kaiming uniform algorithm. The resulting tensor will
+    have values sampled from :math:`\mathcal{U}(-\text{bound}, \text{bound})` where
+
+    .. math::
+        \text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan\_mode}}}
+
+    Input:
+        arr (Array): The array to be assigned.
+
+    Returns:
+        Array, assigned array.
+
+    Examples:
+        >>> w = np.empty((3, 5))
+        >>> KaimingUniform(w, mode='fan_in', nonlinearity='relu')
+    """
+
+    def _initialize(self, arr):
+        fan = _select_fan(arr, self.mode)
+        bound = math.sqrt(3.0)*self.gain / math.sqrt(fan)
+        np.random.seed(0)
+        data = np.random.uniform(-bound, bound, arr.shape)
+
+        _assignment(arr, data)
+
+
+class KaimingNormal(KaimingInit):
+    r"""
+    Initialize the array with He kaiming normal algorithm. The resulting tensor will
+    have values sampled from :math:`\mathcal{N}(0, \text{std}^2)` where
+
+    .. math::
+        \text{std} = \frac{\text{gain}}{\sqrt{\text{fan\_mode}}}
+
+    Input:
+        arr (Array): The array to be assigned.
+
+    Returns:
+        Array, assigned array.
+
+    Examples:
+        >>> w = np.empty((3, 5))
+        >>> KaimingNormal(w, mode='fan_out', nonlinearity='relu')
+    """
+
+    def _initialize(self, arr):
+        fan = _select_fan(arr, self.mode)
+        std = self.gain / math.sqrt(fan)
+        np.random.seed(0)
+        data = np.random.normal(0, std, arr.shape)
+
+        _assignment(arr, data)
+
+
+def default_recurisive_init(custom_cell):
+    """default_recurisive_init"""
+    for _, cell in custom_cell.cells_and_names():
+        if isinstance(cell, nn.Conv2d):
+            cell.weight.default_input = init.initializer(KaimingUniform(a=math.sqrt(5)),
+                                                         cell.weight.shape,
+                                                         cell.weight.dtype)
+            if cell.bias is not None:
+                fan_in, _ = _calculate_in_and_out(cell.weight)
+                bound = 1 / math.sqrt(fan_in)
+                np.random.seed(0)
+                cell.bias.default_input = init.initializer(init.Uniform(bound),
+                                                           cell.bias.shape,
+                                                           cell.bias.dtype)
+        elif isinstance(cell, nn.Dense):
+            cell.weight.default_input = init.initializer(KaimingUniform(a=math.sqrt(5)),
+                                                         cell.weight.shape,
+                                                         cell.weight.dtype)
+            if cell.bias is not None:
+                fan_in, _ = _calculate_in_and_out(cell.weight)
+                bound = 1 / math.sqrt(fan_in)
+                np.random.seed(0)
+                cell.bias.default_input = init.initializer(init.Uniform(bound),
+                                                           cell.bias.shape,
+                                                           cell.bias.dtype)
+        elif isinstance(cell, (nn.BatchNorm2d, nn.BatchNorm1d)):
+            pass
diff --git a/example/membership_inference_demo/vgg/vgg.py b/example/membership_inference_demo/vgg/vgg.py
new file mode 100644
index 0000000000000000000000000000000000000000..c429ea16c02474290a74df360039373ae565dd34
--- /dev/null
+++ b/example/membership_inference_demo/vgg/vgg.py
@@ -0,0 +1,142 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+Image classification.
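+VGG network definitions used by the membership inference example.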
+""" +import math +import mindspore.nn as nn +import mindspore.common.dtype as mstype +from mindspore.common import initializer as init +from mindspore.common.initializer import initializer +from .utils.var_init import default_recurisive_init, KaimingNormal + + +def _make_layer(base, args, batch_norm): + """Make stage network of VGG.""" + layers = [] + in_channels = 3 + for v in base: + if v == 'M': + layers += [nn.MaxPool2d(kernel_size=2, stride=2)] + else: + weight_shape = (v, in_channels, 3, 3) + weight = initializer('XavierUniform', shape=weight_shape, dtype=mstype.float32).to_tensor() + if args.initialize_mode == "KaimingNormal": + weight = 'normal' + conv2d = nn.Conv2d(in_channels=in_channels, + out_channels=v, + kernel_size=3, + padding=args.padding, + pad_mode=args.pad_mode, + has_bias=args.has_bias, + weight_init=weight) + if batch_norm: + layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU()] + else: + layers += [conv2d, nn.ReLU()] + in_channels = v + return nn.SequentialCell(layers) + + +class Vgg(nn.Cell): + """ + VGG network definition. + + Args: + base (list): Configuration for different layers, mainly the channel number of Conv layer. + num_classes (int): Class numbers. Default: 1000. + batch_norm (bool): Whether to do the batchnorm. Default: False. + batch_size (int): Batch size. Default: 1. + + Returns: + Tensor, infer output tensor. + + Examples: + >>> Vgg([64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], + >>> num_classes=1000, batch_norm=False, batch_size=1) + """ + + def __init__(self, base, num_classes=1000, batch_norm=False, batch_size=1, args=None, phase="train"): + super(Vgg, self).__init__() + _ = batch_size + self.layers = _make_layer(base, args, batch_norm=batch_norm) + self.flatten = nn.Flatten() + dropout_ratio = 0.5 + if not args.has_dropout or phase == "test": + dropout_ratio = 1.0 + self.classifier = nn.SequentialCell([ + nn.Dense(512*7*7, 4096), + nn.ReLU(), + nn.Dropout(dropout_ratio), + nn.Dense(4096, 4096), + nn.ReLU(), + nn.Dropout(dropout_ratio), + nn.Dense(4096, num_classes)]) + if args.initialize_mode == "KaimingNormal": + default_recurisive_init(self) + self.custom_init_weight() + + def construct(self, x): + x = self.layers(x) + x = self.flatten(x) + x = self.classifier(x) + return x + + def custom_init_weight(self): + """ + Init the weight of Conv2d and Dense in the net. + """ + for _, cell in self.cells_and_names(): + if isinstance(cell, nn.Conv2d): + cell.weight.default_input = init.initializer( + KaimingNormal(a=math.sqrt(5), mode='fan_out', nonlinearity='relu'), + cell.weight.shape, cell.weight.dtype) + if cell.bias is not None: + cell.bias.default_input = init.initializer( + 'zeros', cell.bias.shape, cell.bias.dtype) + elif isinstance(cell, nn.Dense): + cell.weight.default_input = init.initializer( + init.Normal(0.01), cell.weight.shape, cell.weight.dtype) + if cell.bias is not None: + cell.bias.default_input = init.initializer( + 'zeros', cell.bias.shape, cell.bias.dtype) + + +cfg = { + '11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], + '13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], + '16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], + '19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], +} + + +def vgg16(num_classes=1000, args=None, phase="train"): + """ + Get Vgg16 neural network with batch normalization. 
+ + Args: + num_classes (int): Class numbers. Default: 1000. + args(namespace): param for net init. + phase(str): train or test mode. + + Returns: + Cell, cell instance of Vgg16 neural network with batch normalization. + + Examples: + >>> vgg16(num_classes=1000, args=args) + """ + + net = Vgg(cfg['16'], num_classes=num_classes, args=args, batch_norm=args.batch_norm, phase=phase) + return net diff --git a/example/membership_inference_demo/vgg/warmup_cosine_annealing_lr.py b/example/membership_inference_demo/vgg/warmup_cosine_annealing_lr.py new file mode 100644 index 0000000000000000000000000000000000000000..fe1ded41ccbd3ab893f8022ba020844dbc4c7be5 --- /dev/null +++ b/example/membership_inference_demo/vgg/warmup_cosine_annealing_lr.py @@ -0,0 +1,40 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +warm up cosine annealing learning rate. +""" +import math +import numpy as np + +from .linear_warmup import linear_warmup_lr + + +def warmup_cosine_annealing_lr(lr, steps_per_epoch, warmup_epochs, max_epoch, t_max, eta_min=0): + """warm up cosine annealing learning rate.""" + base_lr = lr + warmup_init_lr = 0 + total_steps = int(max_epoch*steps_per_epoch) + warmup_steps = int(warmup_epochs*steps_per_epoch) + + lr_each_step = [] + for i in range(total_steps): + last_epoch = i // steps_per_epoch + if i < warmup_steps: + lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr) + else: + lr = eta_min + (base_lr - eta_min)*(1. + math.cos(math.pi*last_epoch / t_max)) / 2 + lr_each_step.append(lr) + + return np.array(lr_each_step).astype(np.float32) diff --git a/example/membership_inference_demo/vgg/warmup_step_lr.py b/example/membership_inference_demo/vgg/warmup_step_lr.py new file mode 100644 index 0000000000000000000000000000000000000000..1672d8c28b4e1a0eb5418c1323abce62b63445dd --- /dev/null +++ b/example/membership_inference_demo/vgg/warmup_step_lr.py @@ -0,0 +1,84 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +warm up step learning rate. 
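+After a linear warmup, the learning rate is multiplied by gamma at each milestone epoch.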
+""" +from collections import Counter +import numpy as np + +from .linear_warmup import linear_warmup_lr + + +def lr_steps(global_step, lr_init, lr_max, warmup_epochs, total_epochs, steps_per_epoch): + """Set learning rate.""" + lr_each_step = [] + total_steps = steps_per_epoch*total_epochs + warmup_steps = steps_per_epoch*warmup_epochs + if warmup_steps != 0: + inc_each_step = (float(lr_max) - float(lr_init)) / float(warmup_steps) + else: + inc_each_step = 0 + for i in range(total_steps): + if i < warmup_steps: + lr_value = float(lr_init) + inc_each_step*float(i) + else: + base = (1.0 - (float(i) - float(warmup_steps)) / (float(total_steps) - float(warmup_steps))) + lr_value = float(lr_max)*base*base + if lr_value < 0.0: + lr_value = 0.0 + lr_each_step.append(lr_value) + + current_step = global_step + lr_each_step = np.array(lr_each_step).astype(np.float32) + learning_rate = lr_each_step[current_step:] + + return learning_rate + + +def warmup_step_lr(lr, lr_epochs, steps_per_epoch, warmup_epochs, max_epoch, gamma=0.1): + """warmup_step_lr""" + base_lr = lr + warmup_init_lr = 0 + total_steps = int(max_epoch*steps_per_epoch) + warmup_steps = int(warmup_epochs*steps_per_epoch) + milestones = lr_epochs + milestones_steps = [] + for milestone in milestones: + milestones_step = milestone*steps_per_epoch + milestones_steps.append(milestones_step) + + lr_each_step = [] + lr = base_lr + milestones_steps_counter = Counter(milestones_steps) + for i in range(total_steps): + if i < warmup_steps: + lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr) + else: + lr = lr*gamma**milestones_steps_counter[i] + lr_each_step.append(lr) + + return np.array(lr_each_step).astype(np.float32) + + +def multi_step_lr(lr, milestones, steps_per_epoch, max_epoch, gamma=0.1): + return warmup_step_lr(lr, milestones, steps_per_epoch, 0, max_epoch, gamma=gamma) + + +def step_lr(lr, epoch_size, steps_per_epoch, max_epoch, gamma=0.1): + lr_epochs = [] + for i in range(1, max_epoch): + if i % epoch_size == 0: + lr_epochs.append(i) + return multi_step_lr(lr, lr_epochs, steps_per_epoch, max_epoch, gamma=gamma) diff --git a/mindarmour/diff_privacy/evaluation/attacker.py b/mindarmour/diff_privacy/evaluation/attacker.py index 1e70eef4a5989ab53c00a32ec8320760ea2cd744..f4c1c2476269c2f23d793e593924b942a85b7e19 100644 --- a/mindarmour/diff_privacy/evaluation/attacker.py +++ b/mindarmour/diff_privacy/evaluation/attacker.py @@ -32,7 +32,7 @@ def _attack_knn(features, labels, param_grid): param_grid (dict): Setting of GridSearchCV. Returns: - sklearn.neighbors.KNeighborsClassifier, trained model. + sklearn.model_selection.GridSearchCV, trained model. """ knn_model = KNeighborsClassifier() knn_model = GridSearchCV( @@ -53,9 +53,9 @@ def _attack_lr(features, labels, param_grid): param_grid (dict): Setting of GridSearchCV. Returns: - sklearn.linear_model.LogisticRegression, trained model. + sklearn.model_selection.GridSearchCV, trained model. """ - lr_model = LogisticRegression(C=1.0, penalty="l2") + lr_model = LogisticRegression(C=1.0, penalty="l2", max_iter=1000) lr_model = GridSearchCV( lr_model, param_grid=param_grid, cv=3, n_jobs=1, iid=False, verbose=0, @@ -74,7 +74,7 @@ def _attack_mlpc(features, labels, param_grid): param_grid (dict): Setting of GridSearchCV. Returns: - sklearn.neural_network.MLPClassifier, trained model. + sklearn.model_selection.GridSearchCV, trained model. 
""" mlpc_model = MLPClassifier(random_state=1, max_iter=300) mlpc_model = GridSearchCV( @@ -95,7 +95,7 @@ def _attack_rf(features, labels, random_grid): random_grid (dict): Setting of RandomizedSearchCV. Returns: - sklearn.ensemble.RandomForestClassifier, trained model. + sklearn.model_selection.RandomizedSearchCV, trained model. """ rf_model = RandomForestClassifier(max_depth=2, random_state=0) rf_model = RandomizedSearchCV( diff --git a/mindarmour/diff_privacy/evaluation/membership_inference.py b/mindarmour/diff_privacy/evaluation/membership_inference.py new file mode 100644 index 0000000000000000000000000000000000000000..3e53c8f1cba18d6e4f7f6460e306fea0eaad96f6 --- /dev/null +++ b/mindarmour/diff_privacy/evaluation/membership_inference.py @@ -0,0 +1,197 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Membership Inference +""" + +import numpy as np + +import mindspore as ms +from mindspore.train import Model +import mindspore.nn as nn +import mindspore.context as context +from mindspore import Tensor +from mindarmour.diff_privacy.evaluation.attacker import get_attack_model + +def _eval_info(pred, truth, option): + """ + Calculate the performance according to pred and truth. + + Args: + pred (numpy.ndarray): Predictions for each sample. + truth (numpy.ndarray): Ground truth for each sample. + option(str): Type of evaluation indicators; Possible + values are 'precision', 'accuracy' and 'recall'. + + Returns: + float32, Calculated evaluation results. + + Raises: + ValueError, size of parameter pred or truth is 0. + ValueError, value of parameter option must be in ["precision", "accuracy", "recall"]. + """ + if pred.size == 0 || truth.size == 0: + raise ValueError("Size of pred or truth is 0.") + + if option == "accuracy": + count = np.sum(pred == truth) + return count / len(pred) + if option == "precision": + count = np.sum(pred & truth) + if np.sum(pred) == 0: + return -1 + return count / np.sum(pred) + if option == "recall": + count = np.sum(pred & truth) + if np.sum(truth) == 0: + return -1 + return count / np.sum(truth) + + raise ValueError("The metric value {} is undefined.".format(option)) + + +class MembershipInference: + """ + Evaluation proposed by Shokri, Stronati, Song and Shmatikov is a grey-box attack. + The attack requires obtain loss or logits results of training samples. + + References: Reza Shokri, Marco Stronati, Congzheng Song, Vitaly Shmatikov. + Membership Inference Attacks against Machine Learning Models. 2017. + arXiv:1610.05820v2 `_ + + Args: + model (Model): Target model. + + Examples: + >>> # ds_train, eval_train are non-overlapping datasets from training dataset. + >>> # eval_train, eval_test are non-overlapping datasets from test dataset. 
+        >>> model = Model(network=net, loss_fn=loss, optimizer=opt, metrics={'acc', 'loss'})
+        >>> inference_model = MembershipInference(model)
+        >>> config = [{"method": "KNN", "params": {"n_neighbors": [3, 5, 7]}}]
+        >>> inference_model.train(ds_train, ds_test, config)
+        >>> metrics = ["precision", "recall", "accuracy"]
+        >>> result = inference_model.eval(eval_train, eval_test, metrics)
+
+    Raises:
+        TypeError: If type of model is not mindspore.train.Model.
+    """
+
+    def __init__(self, model):
+        if not isinstance(model, Model):
+            raise TypeError("Type of model must be {}, but got {}.".format(type(Model), type(model)))
+        self.model = model
+        self.attack_list = []
+
+    def train(self, dataset_train, dataset_test, attack_config):
+        """
+        Depending on the configuration, use the incoming datasets to train the attack models.
+        Save the attack models to self.attack_list.
+
+        Args:
+            dataset_train (mindspore.dataset): The training dataset for the target model.
+            dataset_test (mindspore.dataset): The test set for the target model.
+            attack_config (list): Parameter setting for the attack models.
+
+        Raises:
+            ValueError: If the method in attack_config is not in ["LR", "KNN", "RF", "MLPC"].
+        """
+        features, labels = self._transform(dataset_train, dataset_test)
+        for config in attack_config:
+            self.attack_list.append(get_attack_model(features, labels, config))
+
+    def eval(self, dataset_train, dataset_test, metrics):
+        """
+        Evaluate the privacy leakage of the target model.
+        Evaluation indicators shall be specified by metrics.
+
+        Args:
+            dataset_train (mindspore.dataset): The training dataset for the target model.
+            dataset_test (mindspore.dataset): The test dataset for the target model.
+            metrics (Union[list, tuple]): Evaluation indicators. The value of metrics
+                must be in ["precision", "accuracy", "recall"].
+
+        Returns:
+            list, Each element contains an evaluation indicator for the attack model.
+        """
+        result = []
+        features, labels = self._transform(dataset_train, dataset_test)
+        for attacker in self.attack_list:
+            pred = attacker.predict(features)
+            item = {}
+            for option in metrics:
+                item[option] = _eval_info(pred, labels, option)
+            result.append(item)
+        return result
+
+    def _transform(self, dataset_train, dataset_test):
+        """
+        Generate the corresponding loss_logits features and new labels, and return them after shuffling.
+
+        Args:
+            dataset_train: The training set for the target model.
+            dataset_test: The test set for the target model.
+
+        Returns:
+            - numpy.ndarray, Loss_logits features for each sample. Shape is (N, C).
+              N is the number of samples. C = 1 + dim(logits).
+            - numpy.ndarray, Labels for each sample, Shape is (N,).
+        """
+        features_train, labels_train = self._generate(dataset_train, 1)
+        features_test, labels_test = self._generate(dataset_test, 0)
+        features = np.vstack((features_train, features_test))
+        labels = np.hstack((labels_train, labels_test))
+        shuffle_index = np.array(range(len(labels)))
+        np.random.shuffle(shuffle_index)
+        features = features[shuffle_index]
+        labels = labels[shuffle_index]
+        return features, labels
+
+    def _generate(self, dataset_x, label):
+        """
+        Return loss_logits features and labels for training the attack model.
+
+        Args:
+            dataset_x (mindspore.dataset): The dataset from which features are generated.
+            label (int32): Whether dataset_x belongs to the target model's training set.
+
+        Returns:
+            - numpy.ndarray, Loss_logits features for each sample. Shape is (N, C).
+              N is the number of samples. C = 1 + dim(logits).
+            - numpy.ndarray, Labels for each sample, Shape is (N,).
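+              Label 1 marks samples from the target model's training set; label 0 marks test-set samples.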
+ """ + if context.get_context("device_target") != "Ascend": + raise RuntimeError("The target device must be Ascend, " + "but current is {}.".format(context.get_context("device_target"))) + loss_logits = np.array([]) + for batch in dataset_x.create_dict_iterator(): + batch_data = Tensor(batch['image'], ms.float32) + batch_labels = Tensor(batch['label'], ms.int32) + batch_logits = self.model.predict(batch_data) + loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, is_grad=False, reduction=None) + batch_loss = loss(batch_logits, batch_labels).asnumpy() + batch_logits = batch_logits.asnumpy() + + batch_feature = np.hstack((batch_loss.reshape(-1, 1), batch_logits)) + if loss_logits.size == 0: + loss_logits = batch_feature + else: + loss_logits = np.vstack((loss_logits, batch_feature)) + + if label == 1: + labels = np.ones(len(loss_logits), np.int32) + elif label == 0: + labels = np.zeros(len(loss_logits), np.int32) + else: + raise ValueError("The value of label must be 0 or 1, but got {}.".format(label)) + return loss_logits, labels diff --git a/tests/ut/python/diff_privacy/test_membership_inference.py b/tests/ut/python/diff_privacy/test_membership_inference.py new file mode 100644 index 0000000000000000000000000000000000000000..c11d45a4c46b2d74eac4f15a02e8f6f9958bb079 --- /dev/null +++ b/tests/ut/python/diff_privacy/test_membership_inference.py @@ -0,0 +1,111 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +membership inference test +""" +import os +import sys + +import pytest + +import numpy as np + +import mindspore.dataset as ds +from mindspore import nn +from mindspore.train import Model + +from mindarmour.diff_privacy.evaluation.membership_inference import MembershipInference + +sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../")) +from defenses.mock_net import Net + + +def dataset_generator(batch_size, batches): + """mock training data.""" + data = np.random.randn(batches*batch_size, 1, 32, 32).astype( + np.float32) + label = np.random.randint(0, 10, batches*batch_size).astype(np.int32) + for i in range(batches): + yield data[i*batch_size:(i + 1)*batch_size],\ + label[i*batch_size:(i + 1)*batch_size] + + +@pytest.mark.level0 +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.env_onecard +@pytest.mark.component_mindarmour +def test_get_membership_inference_object(): + net = Net() + loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True) + opt = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9) + model = Model(network=net, loss_fn=loss, optimizer=opt) + inference_model = MembershipInference(model) + assert isinstance(inference_model, MembershipInference) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.env_onecard +@pytest.mark.component_mindarmour +def test_membership_inference_object_train(): + net = Net() + loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True) + opt = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9) + model = Model(network=net, loss_fn=loss, optimizer=opt) + inference_model = MembershipInference(model) + assert isinstance(inference_model, MembershipInference) + + config = [{ + "method": "KNN", + "params": { + "n_neighbors": [3, 5, 7], + } + }] + batch_size = 16 + batches = 1 + ds_train = ds.GeneratorDataset(dataset_generator(batch_size, batches), + ["image", "label"]) + ds_test = ds.GeneratorDataset(dataset_generator(batch_size, batches), + ["image", "label"]) + ds_train.set_dataset_size(batch_size*batches) + ds_test.set_dataset_size((batch_size*batches)) + inference_model.train(ds_train, ds_test, config) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.env_onecard +@pytest.mark.component_mindarmour +def test_membership_inference_eval(): + net = Net() + loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True) + opt = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9) + model = Model(network=net, loss_fn=loss, optimizer=opt) + inference_model = MembershipInference(model) + assert isinstance(inference_model, MembershipInference) + + batch_size = 16 + batches = 1 + eval_train = ds.GeneratorDataset(dataset_generator(batch_size, batches), + ["image", "label"]) + eval_test = ds.GeneratorDataset(dataset_generator(batch_size, batches), + ["image", "label"]) + eval_train.set_dataset_size(batch_size * batches) + eval_test.set_dataset_size((batch_size * batches)) + + metrics = ["precision", "accuracy", "recall"] + inference_model.eval(eval_train, eval_test, metrics)