main.py 4.4 KB
Newer Older
L
liuluobin 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Examples of membership inference
"""
import argparse
import sys

from vgg.vgg import vgg16
from vgg.config import cifar_cfg as cfg
from vgg.utils.util import get_param_groups
from vgg.dataset import vgg_create_dataset100

import numpy as np

from mindspore.train import Model
from mindspore.train.serialization import load_param_into_net, load_checkpoint
import mindspore.nn as nn
from mindarmour.diff_privacy.evaluation.membership_inference import MembershipInference
from mindarmour.utils import LogUtil
# Project logger obtained from MindArmour's LogUtil. The name coincides with the
# standard-library `logging` module, but here it refers to the LogUtil instance.
logging = LogUtil.get_instance()
# NOTE(review): 20 presumably selects INFO, as in the standard logging levels —
# confirm against LogUtil.set_level.
logging.set_level(20)

# NOTE(review): this statement runs after the imports above have already been
# executed, so it cannot influence how they were resolved. The path is relative,
# so it is interpreted against the current working directory.
sys.path.append("../../")

# Tag passed as the first argument of every logging.info() call in this script.
TAG = "membership inference example"


if __name__ == "__main__":
    # Script entry point:
    #   1. parse the command line and copy the VGG settings from the config,
    #   2. rebuild VGG16 and load the pre-trained checkpoint into it,
    #   3. train the membership-inference attack models on train/test splits,
    #   4. evaluate them and print the metrics of each attack method.
    parser = argparse.ArgumentParser("main case arg parser.")
    # NOTE(review): --device_target is parsed but never read in this block —
    # confirm whether the execution context should be configured from it.
    parser.add_argument("--device_target", type=str, default="Ascend",
                        choices=["Ascend"])
    parser.add_argument("--data_path", type=str, required=True,
                        help="Data home path for Cifar100.")
    parser.add_argument("--pre_trained", type=str, required=True,
                        help="Checkpoint path.")
    args = parser.parse_args()

    # Mirror the network settings from the project config onto `args`, because
    # vgg16() below is handed the whole `args` namespace.
    for cfg_key in ("num_classes", "batch_norm", "has_dropout", "has_bias",
                    "initialize_mode", "padding", "pad_mode", "weight_decay",
                    "loss_scale"):
        setattr(args, cfg_key, getattr(cfg, cfg_key))

    # load the pretrained model
    net = vgg16(args.num_classes, args)
    loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
    opt = nn.Momentum(params=get_param_groups(net), learning_rate=0.1, momentum=0.9,
                      weight_decay=args.weight_decay, loss_scale=args.loss_scale)
    load_param_into_net(net, load_checkpoint(args.pre_trained))
    model = Model(network=net, loss_fn=loss, optimizer=opt)
    logging.info(TAG, "The model is loaded.")
    attacker = MembershipInference(model)

    # One entry per attack method; each "params" dict maps a hyper-parameter
    # name to the list of candidate values handed to the attacker.
    config = [
        {
            "method": "knn",
            "params": {
                "n_neighbors": [3, 5, 7],
            },
        },
        {
            "method": "lr",
            "params": {
                "C": np.logspace(-4, 2, 10),
            },
        },
        {
            "method": "mlp",
            "params": {
                "hidden_layer_sizes": [(64,), (32, 32)],
                "solver": ["adam"],
                "alpha": [0.0001, 0.001, 0.01],
            },
        },
        {
            "method": "rf",
            "params": {
                "n_estimators": [100],
                # "auto" is intentionally not listed. Assuming the "rf" method
                # is backed by scikit-learn's RandomForestClassifier (confirm):
                # "auto" was only an alias of "sqrt" there and has been removed
                # in newer releases, so listing both duplicated candidates at
                # best and raised an invalid-parameter error at worst.
                "max_features": ["sqrt"],
                "max_depth": [5, 10, 20, None],
                "min_samples_split": [2, 5, 10],
                "min_samples_leaf": [1, 2, 4],
            },
        },
    ]

    # load and split dataset
    # Both datasets share every argument except `training`, which is left at
    # its default for the first one and set to False for the second one.
    dataset_kwargs = dict(data_home=args.data_path, image_size=(224, 224),
                          batch_size=64, num_samples=10000, shuffle=False)
    train_dataset = vgg_create_dataset100(**dataset_kwargs)
    test_dataset = vgg_create_dataset100(training=False, **dataset_kwargs)
    # 80% of each dataset is used to fit the attack models, 20% to evaluate them.
    train_train, eval_train = train_dataset.split([0.8, 0.2])
    train_test, eval_test = test_dataset.split([0.8, 0.2])
    logging.info(TAG, "Data loading is complete.")

    logging.info(TAG, "Start training the inference model.")
    attacker.train(train_train, train_test, config)
    logging.info(TAG, "The inference model is training complete.")

    logging.info(TAG, "Start the evaluation phase")
    metrics = ["precision", "accuracy", "recall"]
    result = attacker.eval(eval_train, eval_test, metrics)

    # Show the metrics for each attack method; results are paired with the
    # config entries by position.
    for method_cfg, method_result in zip(config, result):
        print("Method: {}, {}".format(method_cfg["method"], method_result))