提交 55c31a33 编写于 作者: L liuluobin

Membership inference feature

上级 6cda5e2e
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Eval"""
import os
import argparse
import datetime
import mindspore.nn as nn
from mindspore import context
from mindspore.nn.optim.momentum import Momentum
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.common import dtype as mstype
from mindarmour.utils import LogUtil
from vgg.vgg import vgg16
from vgg.dataset import vgg_create_dataset100
from vgg.config import cifar_cfg as cfg
class ParameterReduce(nn.Cell):
"""ParameterReduce"""
def __init__(self):
super(ParameterReduce, self).__init__()
self.cast = P.Cast()
self.reduce = P.AllReduce()
def construct(self, x):
one = self.cast(F.scalar_to_array(1.0), mstype.float32)
out = x*one
ret = self.reduce(out)
return ret
def parse_args(cloud_args=None):
"""parse_args"""
parser = argparse.ArgumentParser('mindspore classification test')
parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'],
help='device where the code will be implemented. (Default: Ascend)')
# dataset related
parser.add_argument('--data_path', type=str, default='', help='eval data dir')
parser.add_argument('--per_batch_size', default=32, type=int, help='batch size for per npu')
# network related
parser.add_argument('--graph_ckpt', type=int, default=1, help='graph ckpt or feed ckpt')
parser.add_argument('--pre_trained', default='', type=str, help='fully path of pretrained model to load. '
'If it is a direction, it will test all ckpt')
# logging related
parser.add_argument('--log_path', type=str, default='outputs/', help='path to save log')
parser.add_argument('--rank', type=int, default=0, help='local rank of distributed')
parser.add_argument('--group_size', type=int, default=1, help='world size of distributed')
args_opt = parser.parse_args()
args_opt = merge_args(args_opt, cloud_args)
args_opt.image_size = cfg.image_size
args_opt.num_classes = cfg.num_classes
args_opt.per_batch_size = cfg.batch_size
args_opt.momentum = cfg.momentum
args_opt.weight_decay = cfg.weight_decay
args_opt.buffer_size = cfg.buffer_size
args_opt.pad_mode = cfg.pad_mode
args_opt.padding = cfg.padding
args_opt.has_bias = cfg.has_bias
args_opt.batch_norm = cfg.batch_norm
args_opt.initialize_mode = cfg.initialize_mode
args_opt.has_dropout = cfg.has_dropout
args_opt.image_size = list(map(int, args_opt.image_size.split(',')))
return args_opt
def merge_args(args, cloud_args):
"""merge_args"""
args_dict = vars(args)
if isinstance(cloud_args, dict):
for key in cloud_args.keys():
val = cloud_args[key]
if key in args_dict and val:
arg_type = type(args_dict[key])
if arg_type is not type(None):
val = arg_type(val)
args_dict[key] = val
return args
def test(cloud_args=None):
"""test"""
args = parse_args(cloud_args)
context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
device_target=args.device_target, save_graphs=False)
if os.getenv('DEVICE_ID', "not_set").isdigit():
context.set_context(device_id=int(os.getenv('DEVICE_ID')))
args.outputs_dir = os.path.join(args.log_path,
datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
args.logger = LogUtil.get_instance()
args.logger.set_level(20)
net = vgg16(num_classes=args.num_classes, args=args)
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, args.momentum,
weight_decay=args.weight_decay)
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False)
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'})
param_dict = load_checkpoint(args.pre_trained)
load_param_into_net(net, param_dict)
net.set_train(False)
dataset_test = vgg_create_dataset100(args.data_path, args.image_size, args.per_batch_size, training=False)
res = model.eval(dataset_test)
print("result: ", res)
if __name__ == "__main__":
test()
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Examples of membership inference
"""
import argparse
import sys
from vgg.vgg import vgg16
from vgg.config import cifar_cfg as cfg
from vgg.utils.util import get_param_groups
from vgg.dataset import vgg_create_dataset100
import numpy as np
from mindspore.train import Model
from mindspore.train.serialization import load_param_into_net, load_checkpoint
import mindspore.nn as nn
from mindarmour.diff_privacy.evaluation.membership_inference import MembershipInference
from mindarmour.utils import LogUtil
logging = LogUtil.get_instance()
logging.set_level(20)
sys.path.append("../../")
TAG = "membership inference example"
if __name__ == "__main__":
parser = argparse.ArgumentParser("main case arg parser.")
parser.add_argument("--device_target", type=str, default="Ascend",
choices=["Ascend"])
parser.add_argument("--data_path", type=str, required=True,
help="Data home path for Cifar100.")
parser.add_argument("--pre_trained", type=str, required=True,
help="Checkpoint path.")
args = parser.parse_args()
args.num_classes = cfg.num_classes
args.batch_norm = cfg.batch_norm
args.has_dropout = cfg.has_dropout
args.has_bias = cfg.has_bias
args.initialize_mode = cfg.initialize_mode
args.padding = cfg.padding
args.pad_mode = cfg.pad_mode
args.weight_decay = cfg.weight_decay
args.loss_scale = cfg.loss_scale
# load the pretrained model
net = vgg16(args.num_classes, args)
loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
opt = nn.Momentum(params=get_param_groups(net), learning_rate=0.1, momentum=0.9,
weight_decay=args.weight_decay, loss_scale=args.loss_scale)
load_param_into_net(net, load_checkpoint(args.pre_trained))
model = Model(network=net, loss_fn=loss, optimizer=opt)
logging.info(TAG, "The model is loaded.")
attacker = MembershipInference(model)
config = [
{
"method": "knn",
"params": {
"n_neighbors": [3, 5, 7]
}
},
{
"method": "lr",
"params": {
"C": np.logspace(-4, 2, 10)
}
},
{
"method": "mlp",
"params": {
"hidden_layer_sizes": [(64,), (32, 32)],
"solver": ["adam"],
"alpha": [0.0001, 0.001, 0.01]
}
},
{
"method": "rf",
"params": {
"n_estimators": [100],
"max_features": ["auto", "sqrt"],
"max_depth": [5, 10, 20, None],
"min_samples_split": [2, 5, 10],
"min_samples_leaf": [1, 2, 4]
}
}
]
# load and split dataset
train_dataset = vgg_create_dataset100(data_home=args.data_path, image_size=(224, 224),
batch_size=64, num_samples=10000, shuffle=False)
test_dataset = vgg_create_dataset100(data_home=args.data_path, image_size=(224, 224),
batch_size=64, num_samples=10000, shuffle=False, training=False)
train_train, eval_train = train_dataset.split([0.8, 0.2])
train_test, eval_test = test_dataset.split([0.8, 0.2])
logging.info(TAG, "Data loading is complete.")
logging.info(TAG, "Start training the inference model.")
attacker.train(train_train, train_test, config)
logging.info(TAG, "The inference model is training complete.")
logging.info(TAG, "Start the evaluation phase")
metrics = ["precision", "accuracy", "recall"]
result = attacker.eval(eval_train, eval_test, metrics)
# Show the metrics for each attack method.
count = len(config)
for i in range(count):
print("Method: {}, {}".format(config[i]["method"], result[i]))
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
#################train vgg16 example on cifar10########################
python train.py --data_path=$DATA_HOME --device_id=$DEVICE_ID
"""
import argparse
import datetime
import os
import random
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.nn.optim.momentum import Momentum
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from mindspore.train.model import Model
from mindspore.train.serialization import load_param_into_net, load_checkpoint
from mindarmour.utils import LogUtil
from vgg.dataset import vgg_create_dataset100
from vgg.warmup_step_lr import warmup_step_lr
from vgg.warmup_cosine_annealing_lr import warmup_cosine_annealing_lr
from vgg.warmup_step_lr import lr_steps
from vgg.utils.util import get_param_groups
from vgg.vgg import vgg16
from vgg.config import cifar_cfg as cfg
TAG = "train"
random.seed(1)
np.random.seed(1)
def parse_args(cloud_args=None):
"""parameters"""
parser = argparse.ArgumentParser('mindspore classification training')
parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'],
help='device where the code will be implemented. (Default: Ascend)')
parser.add_argument('--device_id', type=int, default=1, help='device id of GPU or Ascend. (Default: None)')
# dataset related
parser.add_argument('--data_path', type=str, default='', help='train data dir')
# network related
parser.add_argument('--pre_trained', default='', type=str, help='model_path, local pretrained model to load')
parser.add_argument('--lr_gamma', type=float, default=0.1,
help='decrease lr by a factor of exponential lr_scheduler')
parser.add_argument('--eta_min', type=float, default=0., help='eta_min in cosine_annealing scheduler')
parser.add_argument('--T_max', type=int, default=150, help='T-max in cosine_annealing scheduler')
# logging and checkpoint related
parser.add_argument('--log_interval', type=int, default=100, help='logging interval')
parser.add_argument('--ckpt_path', type=str, default='outputs/', help='checkpoint save location')
parser.add_argument('--ckpt_interval', type=int, default=2, help='ckpt_interval')
parser.add_argument('--is_save_on_master', type=int, default=1, help='save ckpt on master or all rank')
args_opt = parser.parse_args()
args_opt = merge_args(args_opt, cloud_args)
args_opt.rank = 0
args_opt.group_size = 1
args_opt.label_smooth = cfg.label_smooth
args_opt.label_smooth_factor = cfg.label_smooth_factor
args_opt.lr_scheduler = cfg.lr_scheduler
args_opt.loss_scale = cfg.loss_scale
args_opt.max_epoch = cfg.max_epoch
args_opt.warmup_epochs = cfg.warmup_epochs
args_opt.lr = cfg.lr
args_opt.lr_init = cfg.lr_init
args_opt.lr_max = cfg.lr_max
args_opt.momentum = cfg.momentum
args_opt.weight_decay = cfg.weight_decay
args_opt.per_batch_size = cfg.batch_size
args_opt.num_classes = cfg.num_classes
args_opt.buffer_size = cfg.buffer_size
args_opt.ckpt_save_max = cfg.keep_checkpoint_max
args_opt.pad_mode = cfg.pad_mode
args_opt.padding = cfg.padding
args_opt.has_bias = cfg.has_bias
args_opt.batch_norm = cfg.batch_norm
args_opt.initialize_mode = cfg.initialize_mode
args_opt.has_dropout = cfg.has_dropout
args_opt.lr_epochs = list(map(int, cfg.lr_epochs.split(',')))
args_opt.image_size = list(map(int, cfg.image_size.split(',')))
return args_opt
def merge_args(args_opt, cloud_args):
"""dictionary"""
args_dict = vars(args_opt)
if isinstance(cloud_args, dict):
for key_arg in cloud_args.keys():
val = cloud_args[key_arg]
if key_arg in args_dict and val:
arg_type = type(args_dict[key_arg])
if arg_type is not None:
val = arg_type(val)
args_dict[key_arg] = val
return args_opt
if __name__ == '__main__':
args = parse_args()
device_num = int(os.environ.get("DEVICE_NUM", 1))
context.set_context(device_id=args.device_id)
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
# select for master rank save ckpt or all rank save, compatiable for model parallel
args.rank_save_ckpt_flag = 0
if args.is_save_on_master:
if args.rank == 0:
args.rank_save_ckpt_flag = 1
else:
args.rank_save_ckpt_flag = 1
# logger
args.outputs_dir = os.path.join(args.ckpt_path,
datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
args.logger = LogUtil.get_instance()
args.logger.set_level(20)
# load train data set
dataset = vgg_create_dataset100(args.data_path, args.image_size, args.per_batch_size, args.rank, args.group_size)
batch_num = dataset.get_dataset_size()
args.steps_per_epoch = dataset.get_dataset_size()
# network
args.logger.info(TAG, 'start create network')
# get network and init
network = vgg16(args.num_classes, args)
# pre_trained
if args.pre_trained:
load_param_into_net(network, load_checkpoint(args.pre_trained))
# lr scheduler
if args.lr_scheduler == 'exponential':
lr = warmup_step_lr(args.lr,
args.lr_epochs,
args.steps_per_epoch,
args.warmup_epochs,
args.max_epoch,
gamma=args.lr_gamma,
)
elif args.lr_scheduler == 'cosine_annealing':
lr = warmup_cosine_annealing_lr(args.lr,
args.steps_per_epoch,
args.warmup_epochs,
args.max_epoch,
args.T_max,
args.eta_min)
elif args.lr_scheduler == 'step':
lr = lr_steps(0, lr_init=args.lr_init, lr_max=args.lr_max, warmup_epochs=args.warmup_epochs,
total_epochs=args.max_epoch, steps_per_epoch=batch_num)
else:
raise NotImplementedError(args.lr_scheduler)
# optimizer
opt = Momentum(params=get_param_groups(network),
learning_rate=Tensor(lr),
momentum=args.momentum,
weight_decay=args.weight_decay,
loss_scale=args.loss_scale)
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False)
model = Model(network, loss_fn=loss, optimizer=opt, metrics={'acc'},
amp_level="O2", keep_batchnorm_fp32=False, loss_scale_manager=None)
# checkpoint save
if args.rank_save_ckpt_flag:
ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval*args.steps_per_epoch,
keep_checkpoint_max=args.ckpt_save_max)
ckpt_cb = ModelCheckpoint(config=ckpt_config,
directory=args.outputs_dir,
prefix='{}'.format(args.rank))
callbacks = ckpt_cb
model.train(args.max_epoch, dataset, callbacks=callbacks)
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the License);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# httpwww.apache.orglicensesLICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting, will be used in train.py and eval.py
"""
from easydict import EasyDict as edict
# config for vgg16, cifar100
cifar_cfg = edict({
"num_classes": 100,
"lr": 0.01,
"lr_init": 0.01,
"lr_max": 0.1,
"lr_epochs": '30,60,90,120',
"lr_scheduler": "step",
"warmup_epochs": 5,
"batch_size": 64,
"max_epoch": 100,
"momentum": 0.9,
"weight_decay": 5e-4,
"loss_scale": 1.0,
"label_smooth": 0,
"label_smooth_factor": 0,
"buffer_size": 10,
"image_size": '224,224',
"pad_mode": 'same',
"padding": 0,
"has_bias": False,
"batch_norm": True,
"keep_checkpoint_max": 10,
"initialize_mode": "XavierUniform",
"has_dropout": False
})
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""define loss function for network"""
from mindspore.nn.loss.loss import _Loss
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore import Tensor
from mindspore.common import dtype as mstype
import mindspore.nn as nn
class CrossEntropy(_Loss):
"""the redefined loss function with SoftmaxCrossEntropyWithLogits"""
def __init__(self, smooth_factor=0., num_classes=1001):
super(CrossEntropy, self).__init__()
self.onehot = P.OneHot()
self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32)
self.ce = nn.SoftmaxCrossEntropyWithLogits()
self.mean = P.ReduceMean(False)
def construct(self, logit, label):
one_hot_label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value)
loss = self.ce(logit, one_hot_label)
loss = self.mean(loss, 0)
return loss
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
dataset processing.
"""
import os
from mindspore.common import dtype as mstype
import mindspore.dataset as de
import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.transforms.vision.c_transforms as vision
def vgg_create_dataset100(data_home, image_size, batch_size, rank_id=0, rank_size=1, repeat_num=1,
training=True, num_samples=None, shuffle=True):
"""Data operations."""
de.config.set_seed(1)
data_dir = os.path.join(data_home, "train")
if not training:
data_dir = os.path.join(data_home, "test")
if num_samples is not None:
data_set = de.Cifar100Dataset(data_dir, num_shards=rank_size, shard_id=rank_id,
num_samples=num_samples, shuffle=shuffle)
else:
data_set = de.Cifar100Dataset(data_dir, num_shards=rank_size, shard_id=rank_id)
input_columns = ["fine_label"]
output_columns = ["label"]
data_set = data_set.rename(input_columns=input_columns, output_columns=output_columns)
data_set = data_set.project(["image", "label"])
rescale = 1.0 / 255.0
shift = 0.0
# define map operations
random_crop_op = vision.RandomCrop((32, 32), (4, 4, 4, 4)) # padding_mode default CONSTANT
random_horizontal_op = vision.RandomHorizontalFlip()
resize_op = vision.Resize(image_size) # interpolation default BILINEAR
rescale_op = vision.Rescale(rescale, shift)
normalize_op = vision.Normalize((0.4465, 0.4822, 0.4914), (0.2010, 0.1994, 0.2023))
changeswap_op = vision.HWC2CHW()
type_cast_op = C.TypeCast(mstype.int32)
c_trans = []
if training:
c_trans = [random_crop_op, random_horizontal_op]
c_trans += [resize_op, rescale_op, normalize_op,
changeswap_op]
# apply map operations on images
data_set = data_set.map(input_columns="label", operations=type_cast_op)
data_set = data_set.map(input_columns="image", operations=c_trans)
# apply repeat operations
data_set = data_set.repeat(repeat_num)
# apply shuffle operations
# data_set = data_set.shuffle(buffer_size=1000)
# apply batch operations
data_set = data_set.batch(batch_size=batch_size, drop_remainder=True)
return data_set
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
linear warm up learning rate.
"""
def linear_warmup_lr(current_step, warmup_steps, base_lr, init_lr):
lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps)
lr = float(init_lr) + lr_inc*current_step
return lr
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Util class or function."""
def get_param_groups(network):
"""Param groups for optimizer."""
decay_params = []
no_decay_params = []
for x in network.trainable_params():
parameter_name = x.name
if parameter_name.endswith('.bias'):
# all bias not using weight decay
no_decay_params.append(x)
elif parameter_name.endswith('.gamma'):
# bn weight bias not using weight decay, be carefully for now x not include BN
no_decay_params.append(x)
elif parameter_name.endswith('.beta'):
# bn weight bias not using weight decay, be carefully for now x not include BN
no_decay_params.append(x)
else:
decay_params.append(x)
return [{'params': no_decay_params, 'weight_decay': 0.0}, {'params': decay_params}]
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Initialize.
"""
import math
from functools import reduce
import numpy as np
import mindspore.nn as nn
from mindspore.common import initializer as init
def _calculate_gain(nonlinearity, param=None):
r"""
Return the recommended gain value for the given nonlinearity function.
The values are as follows:
================= ====================================================
nonlinearity gain
================= ====================================================
Linear / Identity :math:`1`
Conv{1,2,3}D :math:`1`
Sigmoid :math:`1`
Tanh :math:`\frac{5}{3}`
ReLU :math:`\sqrt{2}`
Leaky Relu :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}`
================= ====================================================
Args:
nonlinearity: the non-linear function
param: optional parameter for the non-linear function
Examples:
>>> gain = calculate_gain('leaky_relu', 0.2) # leaky_relu with negative_slope=0.2
"""
linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d']
if nonlinearity in linear_fns or nonlinearity == 'sigmoid':
return 1
if nonlinearity == 'tanh':
return 5.0 / 3
if nonlinearity == 'relu':
return math.sqrt(2.0)
if nonlinearity == 'leaky_relu':
if param is None:
negative_slope = 0.01
elif not isinstance(param, bool) and isinstance(param, int) or isinstance(param, float):
negative_slope = param
else:
raise ValueError("negative_slope {} not a valid number".format(param))
return math.sqrt(2.0 / (1 + negative_slope**2))
raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
def _assignment(arr, num):
"""Assign the value of `num` to `arr`."""
if arr.shape == ():
arr = arr.reshape((1))
arr[:] = num
arr = arr.reshape(())
else:
if isinstance(num, np.ndarray):
arr[:] = num[:]
else:
arr[:] = num
return arr
def _calculate_in_and_out(arr):
"""
Calculate n_in and n_out.
Args:
arr (Array): Input array.
Returns:
Tuple, a tuple with two elements, the first element is `n_in` and the second element is `n_out`.
"""
dim = len(arr.shape)
if dim < 2:
raise ValueError("If initialize data with xavier uniform, the dimension of data must greater than 1.")
n_in = arr.shape[1]
n_out = arr.shape[0]
if dim > 2:
counter = reduce(lambda x, y: x*y, arr.shape[2:])
n_in *= counter
n_out *= counter
return n_in, n_out
def _select_fan(array, mode):
mode = mode.lower()
valid_modes = ['fan_in', 'fan_out']
if mode not in valid_modes:
raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes))
fan_in, fan_out = _calculate_in_and_out(array)
return fan_in if mode == 'fan_in' else fan_out
class KaimingInit(init.Initializer):
r"""
Base Class. Initialize the array with He kaiming algorithm.
Args:
a: the negative slope of the rectifier used after this layer (only
used with ``'leaky_relu'``)
mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
preserves the magnitude of the variance of the weights in the
forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
backwards pass.
nonlinearity: the non-linear function, recommended to use only with
``'relu'`` or ``'leaky_relu'`` (default).
"""
def __init__(self, a=0, mode='fan_in', nonlinearity='leaky_relu'):
super(KaimingInit, self).__init__()
self.mode = mode
self.gain = _calculate_gain(nonlinearity, a)
def _initialize(self, arr):
pass
class KaimingUniform(KaimingInit):
r"""
Initialize the array with He kaiming uniform algorithm. The resulting tensor will
have values sampled from :math:`\mathcal{U}(-\text{bound}, \text{bound})` where
.. math::
\text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan\_mode}}}
Input:
arr (Array): The array to be assigned.
Returns:
Array, assigned array.
Examples:
>>> w = np.empty(3, 5)
>>> KaimingUniform(w, mode='fan_in', nonlinearity='relu')
"""
def _initialize(self, arr):
fan = _select_fan(arr, self.mode)
bound = math.sqrt(3.0)*self.gain / math.sqrt(fan)
np.random.seed(0)
data = np.random.uniform(-bound, bound, arr.shape)
_assignment(arr, data)
class KaimingNormal(KaimingInit):
r"""
Initialize the array with He kaiming normal algorithm. The resulting tensor will
have values sampled from :math:`\mathcal{N}(0, \text{std}^2)` where
.. math::
\text{std} = \frac{\text{gain}}{\sqrt{\text{fan\_mode}}}
Input:
arr (Array): The array to be assigned.
Returns:
Array, assigned array.
Examples:
>>> w = np.empty(3, 5)
>>> KaimingNormal(w, mode='fan_out', nonlinearity='relu')
"""
def _initialize(self, arr):
fan = _select_fan(arr, self.mode)
std = self.gain / math.sqrt(fan)
np.random.seed(0)
data = np.random.normal(0, std, arr.shape)
_assignment(arr, data)
def default_recurisive_init(custom_cell):
"""default_recurisive_init"""
for _, cell in custom_cell.cells_and_names():
if isinstance(cell, nn.Conv2d):
cell.weight.default_input = init.initializer(KaimingUniform(a=math.sqrt(5)),
cell.weight.shape,
cell.weight.dtype)
if cell.bias is not None:
fan_in, _ = _calculate_in_and_out(cell.weight)
bound = 1 / math.sqrt(fan_in)
np.random.seed(0)
cell.bias.default_input = init.initializer(init.Uniform(bound),
cell.bias.shape,
cell.bias.dtype)
elif isinstance(cell, nn.Dense):
cell.weight.default_input = init.initializer(KaimingUniform(a=math.sqrt(5)),
cell.weight.shape,
cell.weight.dtype)
if cell.bias is not None:
fan_in, _ = _calculate_in_and_out(cell.weight)
bound = 1 / math.sqrt(fan_in)
np.random.seed(0)
cell.bias.default_input = init.initializer(init.Uniform(bound),
cell.bias.shape,
cell.bias.dtype)
elif isinstance(cell, (nn.BatchNorm2d, nn.BatchNorm1d)):
pass
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Image classifiation.
"""
import math
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore.common import initializer as init
from mindspore.common.initializer import initializer
from .utils.var_init import default_recurisive_init, KaimingNormal
def _make_layer(base, args, batch_norm):
"""Make stage network of VGG."""
layers = []
in_channels = 3
for v in base:
if v == 'M':
layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
else:
weight_shape = (v, in_channels, 3, 3)
weight = initializer('XavierUniform', shape=weight_shape, dtype=mstype.float32).to_tensor()
if args.initialize_mode == "KaimingNormal":
weight = 'normal'
conv2d = nn.Conv2d(in_channels=in_channels,
out_channels=v,
kernel_size=3,
padding=args.padding,
pad_mode=args.pad_mode,
has_bias=args.has_bias,
weight_init=weight)
if batch_norm:
layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU()]
else:
layers += [conv2d, nn.ReLU()]
in_channels = v
return nn.SequentialCell(layers)
class Vgg(nn.Cell):
"""
VGG network definition.
Args:
base (list): Configuration for different layers, mainly the channel number of Conv layer.
num_classes (int): Class numbers. Default: 1000.
batch_norm (bool): Whether to do the batchnorm. Default: False.
batch_size (int): Batch size. Default: 1.
Returns:
Tensor, infer output tensor.
Examples:
>>> Vgg([64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
>>> num_classes=1000, batch_norm=False, batch_size=1)
"""
def __init__(self, base, num_classes=1000, batch_norm=False, batch_size=1, args=None, phase="train"):
super(Vgg, self).__init__()
_ = batch_size
self.layers = _make_layer(base, args, batch_norm=batch_norm)
self.flatten = nn.Flatten()
dropout_ratio = 0.5
if not args.has_dropout or phase == "test":
dropout_ratio = 1.0
self.classifier = nn.SequentialCell([
nn.Dense(512*7*7, 4096),
nn.ReLU(),
nn.Dropout(dropout_ratio),
nn.Dense(4096, 4096),
nn.ReLU(),
nn.Dropout(dropout_ratio),
nn.Dense(4096, num_classes)])
if args.initialize_mode == "KaimingNormal":
default_recurisive_init(self)
self.custom_init_weight()
def construct(self, x):
x = self.layers(x)
x = self.flatten(x)
x = self.classifier(x)
return x
def custom_init_weight(self):
"""
Init the weight of Conv2d and Dense in the net.
"""
for _, cell in self.cells_and_names():
if isinstance(cell, nn.Conv2d):
cell.weight.default_input = init.initializer(
KaimingNormal(a=math.sqrt(5), mode='fan_out', nonlinearity='relu'),
cell.weight.shape, cell.weight.dtype)
if cell.bias is not None:
cell.bias.default_input = init.initializer(
'zeros', cell.bias.shape, cell.bias.dtype)
elif isinstance(cell, nn.Dense):
cell.weight.default_input = init.initializer(
init.Normal(0.01), cell.weight.shape, cell.weight.dtype)
if cell.bias is not None:
cell.bias.default_input = init.initializer(
'zeros', cell.bias.shape, cell.bias.dtype)
cfg = {
'11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
'13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
'16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
'19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}
def vgg16(num_classes=1000, args=None, phase="train"):
"""
Get Vgg16 neural network with batch normalization.
Args:
num_classes (int): Class numbers. Default: 1000.
args(namespace): param for net init.
phase(str): train or test mode.
Returns:
Cell, cell instance of Vgg16 neural network with batch normalization.
Examples:
>>> vgg16(num_classes=1000, args=args)
"""
net = Vgg(cfg['16'], num_classes=num_classes, args=args, batch_norm=args.batch_norm, phase=phase)
return net
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
warm up cosine annealing learning rate.
"""
import math
import numpy as np
from .linear_warmup import linear_warmup_lr
def warmup_cosine_annealing_lr(lr, steps_per_epoch, warmup_epochs, max_epoch, t_max, eta_min=0):
"""warm up cosine annealing learning rate."""
base_lr = lr
warmup_init_lr = 0
total_steps = int(max_epoch*steps_per_epoch)
warmup_steps = int(warmup_epochs*steps_per_epoch)
lr_each_step = []
for i in range(total_steps):
last_epoch = i // steps_per_epoch
if i < warmup_steps:
lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr)
else:
lr = eta_min + (base_lr - eta_min)*(1. + math.cos(math.pi*last_epoch / t_max)) / 2
lr_each_step.append(lr)
return np.array(lr_each_step).astype(np.float32)
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
warm up step learning rate.
"""
from collections import Counter
import numpy as np
from .linear_warmup import linear_warmup_lr
def lr_steps(global_step, lr_init, lr_max, warmup_epochs, total_epochs, steps_per_epoch):
"""Set learning rate."""
lr_each_step = []
total_steps = steps_per_epoch*total_epochs
warmup_steps = steps_per_epoch*warmup_epochs
if warmup_steps != 0:
inc_each_step = (float(lr_max) - float(lr_init)) / float(warmup_steps)
else:
inc_each_step = 0
for i in range(total_steps):
if i < warmup_steps:
lr_value = float(lr_init) + inc_each_step*float(i)
else:
base = (1.0 - (float(i) - float(warmup_steps)) / (float(total_steps) - float(warmup_steps)))
lr_value = float(lr_max)*base*base
if lr_value < 0.0:
lr_value = 0.0
lr_each_step.append(lr_value)
current_step = global_step
lr_each_step = np.array(lr_each_step).astype(np.float32)
learning_rate = lr_each_step[current_step:]
return learning_rate
def warmup_step_lr(lr, lr_epochs, steps_per_epoch, warmup_epochs, max_epoch, gamma=0.1):
"""warmup_step_lr"""
base_lr = lr
warmup_init_lr = 0
total_steps = int(max_epoch*steps_per_epoch)
warmup_steps = int(warmup_epochs*steps_per_epoch)
milestones = lr_epochs
milestones_steps = []
for milestone in milestones:
milestones_step = milestone*steps_per_epoch
milestones_steps.append(milestones_step)
lr_each_step = []
lr = base_lr
milestones_steps_counter = Counter(milestones_steps)
for i in range(total_steps):
if i < warmup_steps:
lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr)
else:
lr = lr*gamma**milestones_steps_counter[i]
lr_each_step.append(lr)
return np.array(lr_each_step).astype(np.float32)
def multi_step_lr(lr, milestones, steps_per_epoch, max_epoch, gamma=0.1):
return warmup_step_lr(lr, milestones, steps_per_epoch, 0, max_epoch, gamma=gamma)
def step_lr(lr, epoch_size, steps_per_epoch, max_epoch, gamma=0.1):
lr_epochs = []
for i in range(1, max_epoch):
if i % epoch_size == 0:
lr_epochs.append(i)
return multi_step_lr(lr, lr_epochs, steps_per_epoch, max_epoch, gamma=gamma)
......@@ -32,7 +32,7 @@ def _attack_knn(features, labels, param_grid):
param_grid (dict): Setting of GridSearchCV.
Returns:
sklearn.neighbors.KNeighborsClassifier, trained model.
sklearn.model_selection.GridSearchCV, trained model.
"""
knn_model = KNeighborsClassifier()
knn_model = GridSearchCV(
......@@ -53,9 +53,9 @@ def _attack_lr(features, labels, param_grid):
param_grid (dict): Setting of GridSearchCV.
Returns:
sklearn.linear_model.LogisticRegression, trained model.
sklearn.model_selection.GridSearchCV, trained model.
"""
lr_model = LogisticRegression(C=1.0, penalty="l2")
lr_model = LogisticRegression(C=1.0, penalty="l2", max_iter=1000)
lr_model = GridSearchCV(
lr_model, param_grid=param_grid, cv=3, n_jobs=1, iid=False,
verbose=0,
......@@ -74,7 +74,7 @@ def _attack_mlpc(features, labels, param_grid):
param_grid (dict): Setting of GridSearchCV.
Returns:
sklearn.neural_network.MLPClassifier, trained model.
sklearn.model_selection.GridSearchCV, trained model.
"""
mlpc_model = MLPClassifier(random_state=1, max_iter=300)
mlpc_model = GridSearchCV(
......@@ -95,7 +95,7 @@ def _attack_rf(features, labels, random_grid):
random_grid (dict): Setting of RandomizedSearchCV.
Returns:
sklearn.ensemble.RandomForestClassifier, trained model.
sklearn.model_selection.RandomizedSearchCV, trained model.
"""
rf_model = RandomForestClassifier(max_depth=2, random_state=0)
rf_model = RandomizedSearchCV(
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Membership Inference
"""
import numpy as np
import mindspore as ms
from mindspore.train import Model
import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor
from mindarmour.diff_privacy.evaluation.attacker import get_attack_model
def _eval_info(pred, truth, option):
"""
Calculate the performance according to pred and truth.
Args:
pred (numpy.ndarray): Predictions for each sample.
truth (numpy.ndarray): Ground truth for each sample.
option(str): Type of evaluation indicators; Possible
values are 'precision', 'accuracy' and 'recall'.
Returns:
float32, Calculated evaluation results.
Raises:
ValueError, size of parameter pred or truth is 0.
ValueError, value of parameter option must be in ["precision", "accuracy", "recall"].
"""
if pred.size == 0 || truth.size == 0:
raise ValueError("Size of pred or truth is 0.")
if option == "accuracy":
count = np.sum(pred == truth)
return count / len(pred)
if option == "precision":
count = np.sum(pred & truth)
if np.sum(pred) == 0:
return -1
return count / np.sum(pred)
if option == "recall":
count = np.sum(pred & truth)
if np.sum(truth) == 0:
return -1
return count / np.sum(truth)
raise ValueError("The metric value {} is undefined.".format(option))
class MembershipInference:
"""
Evaluation proposed by Shokri, Stronati, Song and Shmatikov is a grey-box attack.
The attack requires obtain loss or logits results of training samples.
References: Reza Shokri, Marco Stronati, Congzheng Song, Vitaly Shmatikov.
Membership Inference Attacks against Machine Learning Models. 2017.
arXiv:1610.05820v2 <https://arxiv.org/abs/1610.05820v2>`_
Args:
model (Model): Target model.
Examples:
>>> # ds_train, eval_train are non-overlapping datasets from training dataset.
>>> # eval_train, eval_test are non-overlapping datasets from test dataset.
>>> model = Model(network=net, loss_fn=loss, optimizer=opt, metrics={'acc', 'loss'})
>>> inference_model = MembershipInference(model)
>>> config = [{"method": "KNN", "params": {"n_neighbors": [3, 5, 7]}}]
>>> inference_model.train(ds_train, ds_test, config)
>>> metrics = ["precision", "recall", "accuracy"]
>>> result = inference_model.eval(eval_train, eval_test, metrics)
Raises:
TypeError: If type of model is not mindspore.train.Model.
"""
def __init__(self, model):
if not isinstance(model, Model):
raise TypeError("Type of model must be {}, but got {}.".format(type(Model), type(model)))
self.model = model
self.attack_list = []
def train(self, dataset_train, dataset_test, attack_config):
"""
Depending on the configuration, use the incoming data set to train the attack model.
Save the attack model to self.attack_list.
Args:
dataset_train (mindspore.dataset): The training dataset for the target model.
dataset_test (mindspore.dataset): The test set for the target model.
attack_config (list): Parameter setting for the attack model.
Raises:
ValueError: If the method in attack_config is not in ["LR", "KNN", "RF", "MLPC"].
"""
features, labels = self._transform(dataset_train, dataset_test)
for config in attack_config:
self.attack_list.append(get_attack_model(features, labels, config))
def eval(self, dataset_train, dataset_test, metrics):
"""
Evaluate the different privacy of the target model.
Evaluation indicators shall be specified by metrics.
Args:
dataset_train (mindspore.dataset): The training dataset for the target model.
dataset_test (mindspore.dataset): The test dataset for the target model.
metrics (Union[list, tuple]): Evaluation indicators. The value of metrics
must be in ["precision", "accuracy", "recall"]. Default: ["precision"].
Returns:
list, Each element contains an evaluation indicator for the attack model.
"""
result = []
features, labels = self._transform(dataset_train, dataset_test)
for attacker in self.attack_list:
pred = attacker.predict(features)
item = {}
for option in metrics:
item[option] = _eval_info(pred, labels, option)
result.append(item)
return result
def _transform(self, dataset_train, dataset_test):
"""
Generate corresponding loss_logits feature and new label, and return after shuffle.
Args:
dataset_train: The training set for the target model.
dataset_test: The test set for the target model.
Returns:
- numpy.ndarray, Loss_logits features for each sample. Shape is (N, C).
N is the number of sample. C = 1 + dim(logits).
- numpy.ndarray, Labels for each sample, Shape is (N,).
"""
features_train, labels_train = self._generate(dataset_train, 1)
features_test, labels_test = self._generate(dataset_test, 0)
features = np.vstack((features_train, features_test))
labels = np.hstack((labels_train, labels_test))
shuffle_index = np.array(range(len(labels)))
np.random.shuffle(shuffle_index)
features = features[shuffle_index]
labels = labels[shuffle_index]
return features, labels
def _generate(self, dataset_x, label):
"""
Return a loss_logits features and labels for training attack model.
Args:
dataset_x (mindspore.dataset): The dataset to be generate.
label (int32): Whether dataset_x belongs to the target model.
Returns:
- numpy.ndarray, Loss_logits features for each sample. Shape is (N, C).
N is the number of sample. C = 1 + dim(logits).
- numpy.ndarray, Labels for each sample, Shape is (N,).
"""
if context.get_context("device_target") != "Ascend":
raise RuntimeError("The target device must be Ascend, "
"but current is {}.".format(context.get_context("device_target")))
loss_logits = np.array([])
for batch in dataset_x.create_dict_iterator():
batch_data = Tensor(batch['image'], ms.float32)
batch_labels = Tensor(batch['label'], ms.int32)
batch_logits = self.model.predict(batch_data)
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, is_grad=False, reduction=None)
batch_loss = loss(batch_logits, batch_labels).asnumpy()
batch_logits = batch_logits.asnumpy()
batch_feature = np.hstack((batch_loss.reshape(-1, 1), batch_logits))
if loss_logits.size == 0:
loss_logits = batch_feature
else:
loss_logits = np.vstack((loss_logits, batch_feature))
if label == 1:
labels = np.ones(len(loss_logits), np.int32)
elif label == 0:
labels = np.zeros(len(loss_logits), np.int32)
else:
raise ValueError("The value of label must be 0 or 1, but got {}.".format(label))
return loss_logits, labels
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
membership inference test
"""
import os
import sys
import pytest
import numpy as np
import mindspore.dataset as ds
from mindspore import nn
from mindspore.train import Model
from mindarmour.diff_privacy.evaluation.membership_inference import MembershipInference
sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../"))
from defenses.mock_net import Net
def dataset_generator(batch_size, batches):
"""mock training data."""
data = np.random.randn(batches*batch_size, 1, 32, 32).astype(
np.float32)
label = np.random.randint(0, 10, batches*batch_size).astype(np.int32)
for i in range(batches):
yield data[i*batch_size:(i + 1)*batch_size],\
label[i*batch_size:(i + 1)*batch_size]
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_get_membership_inference_object():
net = Net()
loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
opt = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
model = Model(network=net, loss_fn=loss, optimizer=opt)
inference_model = MembershipInference(model)
assert isinstance(inference_model, MembershipInference)
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_membership_inference_object_train():
net = Net()
loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
opt = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
model = Model(network=net, loss_fn=loss, optimizer=opt)
inference_model = MembershipInference(model)
assert isinstance(inference_model, MembershipInference)
config = [{
"method": "KNN",
"params": {
"n_neighbors": [3, 5, 7],
}
}]
batch_size = 16
batches = 1
ds_train = ds.GeneratorDataset(dataset_generator(batch_size, batches),
["image", "label"])
ds_test = ds.GeneratorDataset(dataset_generator(batch_size, batches),
["image", "label"])
ds_train.set_dataset_size(batch_size*batches)
ds_test.set_dataset_size((batch_size*batches))
inference_model.train(ds_train, ds_test, config)
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_membership_inference_eval():
net = Net()
loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
opt = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
model = Model(network=net, loss_fn=loss, optimizer=opt)
inference_model = MembershipInference(model)
assert isinstance(inference_model, MembershipInference)
batch_size = 16
batches = 1
eval_train = ds.GeneratorDataset(dataset_generator(batch_size, batches),
["image", "label"])
eval_test = ds.GeneratorDataset(dataset_generator(batch_size, batches),
["image", "label"])
eval_train.set_dataset_size(batch_size * batches)
eval_test.set_dataset_size((batch_size * batches))
metrics = ["precision", "accuracy", "recall"]
inference_model.eval(eval_train, eval_test, metrics)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册