提交 7dc09aed 编写于 作者: L liuluobin

Append description for get_attacker_model and train

上级 ce15e781
......@@ -22,9 +22,6 @@ from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import RandomizedSearchCV
method_list = ["lr", "knn", "rf", "mlp"]
def _attack_knn(features, labels, param_grid):
"""
Train and return a KNN model.
......@@ -117,9 +114,19 @@ def get_attack_model(features, labels, config):
features (numpy.ndarray): Loss and logits characteristics of each sample.
labels (numpy.ndarray): Labels of each sample whether belongs to training set.
config (dict): Config of attacker, with key in ["method", "params"].
The format is {"method": "knn", "params": {"n_neighbors": [3, 5, 7]}},
params of each method must within the range of changeable parameters.
Tips of params implement can be found in
"https://scikit-learn.org/0.16/modules/generated/sklearn.grid_search.GridSearchCV.html".
Returns:
sklearn.BaseEstimator, trained model specify by config["method"].
Examples:
>>> features = np.random.randn(10, 10)
>>> labels = np.random.randint(0, 2, 10)
>>> config = {"method": "knn", "params": {"n_neighbors": [3, 5, 7]}}
>>> attack_model = get_attack_model(features, labels, config)
"""
method = str.lower(config["method"])
......
......@@ -23,7 +23,7 @@ from mindspore.dataset.engine import Dataset
import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor
from mindarmour.diff_privacy.evaluation.attacker import get_attack_model, method_list
from mindarmour.diff_privacy.evaluation.attacker import get_attack_model
def _eval_info(pred, truth, option):
"""
......@@ -67,22 +67,23 @@ class MembershipInference:
Evaluation proposed by Shokri, Stronati, Song and Shmatikov is a grey-box attack.
The attack requires obtain loss or logits results of training samples.
References: Reza Shokri, Marco Stronati, Congzheng Song, Vitaly Shmatikov.
References: `Reza Shokri, Marco Stronati, Congzheng Song, Vitaly Shmatikov.
Membership Inference Attacks against Machine Learning Models. 2017.
arXiv:1610.05820v2 <https://arxiv.org/abs/1610.05820v2>`_
<https://arxiv.org/abs/1610.05820v2>`_
Args:
model (Model): Target model.
Examples:
>>> # ds_train, eval_train are non-overlapping datasets from training dataset.
>>> # eval_train, eval_test are non-overlapping datasets from test dataset.
>>> train_1, train_2 are non-overlapping datasets from training dataset of target model.
>>> test_1, test_2 are non-overlapping datasets from test dataset of target model.
>>> We use train_1, test_1 to train attack model, and use train_2, test_2 to evaluate attack model.
>>> model = Model(network=net, loss_fn=loss, optimizer=opt, metrics={'acc', 'loss'})
>>> inference_model = MembershipInference(model)
>>> config = [{"method": "KNN", "params": {"n_neighbors": [3, 5, 7]}}]
>>> inference_model.train(ds_train, ds_test, config)
>>> inference_model.train(train_1, test_1, config)
>>> metrics = ["precision", "recall", "accuracy"]
>>> result = inference_model.eval(eval_train, eval_test, metrics)
>>> result = inference_model.eval(train_2, test_2, metrics)
Raises:
TypeError: If type of model is not mindspore.train.Model.
......@@ -92,6 +93,7 @@ class MembershipInference:
if not isinstance(model, Model):
raise TypeError("Type of parameter 'model' must be Model, but got {}.".format(type(model)))
self.model = model
self.method_list = ["knn", "lr", "mlp", "rf"]
self.attack_list = []
def train(self, dataset_train, dataset_test, attack_config):
......@@ -102,7 +104,13 @@ class MembershipInference:
Args:
dataset_train (mindspore.dataset): The training dataset for the target model.
dataset_test (mindspore.dataset): The test set for the target model.
attack_config (list): Parameter setting for the attack model.
attack_config (list): Parameter setting for the attack model. The format is
[{"method": "knn", "params": {"n_neighbors": [3, 5, 7]}},
{"method": "lr", "params": {"C": np.logspace(-4, 2, 10)}}].
The support methods list is in self.method_list, and the params of each method
must within the range of changeable parameters. Tips of params implement
can be found in
"https://scikit-learn.org/0.16/modules/generated/sklearn.grid_search.GridSearchCV.html".
Raises:
KeyError: If each config in attack_config doesn't have keys {"method", "params"}
......@@ -120,7 +128,7 @@ class MembershipInference:
if {"params", "method"} != set(config.keys()):
raise KeyError("Each config in attack_config must have keys 'method' and 'params', "
"but your key value is {}.".format(set(config.keys())))
if str.lower(config["method"]) not in method_list:
if str.lower(config["method"]) not in self.method_list:
raise ValueError("Method {} is not support.".format(config["method"]))
features, labels = self._transform(dataset_train, dataset_test)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册