提交 7dc09aed 编写于 作者: L liuluobin

Append description for get_attacker_model and train

上级 ce15e781
...@@ -22,9 +22,6 @@ from sklearn.model_selection import GridSearchCV ...@@ -22,9 +22,6 @@ from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import RandomizedSearchCV from sklearn.model_selection import RandomizedSearchCV
method_list = ["lr", "knn", "rf", "mlp"]
def _attack_knn(features, labels, param_grid): def _attack_knn(features, labels, param_grid):
""" """
Train and return a KNN model. Train and return a KNN model.
...@@ -117,9 +114,19 @@ def get_attack_model(features, labels, config): ...@@ -117,9 +114,19 @@ def get_attack_model(features, labels, config):
features (numpy.ndarray): Loss and logits characteristics of each sample. features (numpy.ndarray): Loss and logits characteristics of each sample.
labels (numpy.ndarray): Labels of each sample whether belongs to training set. labels (numpy.ndarray): Labels of each sample whether belongs to training set.
config (dict): Config of attacker, with key in ["method", "params"]. config (dict): Config of attacker, with key in ["method", "params"].
The format is {"method": "knn", "params": {"n_neighbors": [3, 5, 7]}},
params of each method must within the range of changeable parameters.
Tips of params implement can be found in
"https://scikit-learn.org/0.16/modules/generated/sklearn.grid_search.GridSearchCV.html".
Returns: Returns:
sklearn.BaseEstimator, trained model specify by config["method"]. sklearn.BaseEstimator, trained model specify by config["method"].
Examples:
>>> features = np.random.randn(10, 10)
>>> labels = np.random.randint(0, 2, 10)
>>> config = {"method": "knn", "params": {"n_neighbors": [3, 5, 7]}}
>>> attack_model = get_attack_model(features, labels, config)
""" """
method = str.lower(config["method"]) method = str.lower(config["method"])
......
...@@ -23,7 +23,7 @@ from mindspore.dataset.engine import Dataset ...@@ -23,7 +23,7 @@ from mindspore.dataset.engine import Dataset
import mindspore.nn as nn import mindspore.nn as nn
import mindspore.context as context import mindspore.context as context
from mindspore import Tensor from mindspore import Tensor
from mindarmour.diff_privacy.evaluation.attacker import get_attack_model, method_list from mindarmour.diff_privacy.evaluation.attacker import get_attack_model
def _eval_info(pred, truth, option): def _eval_info(pred, truth, option):
""" """
...@@ -67,22 +67,23 @@ class MembershipInference: ...@@ -67,22 +67,23 @@ class MembershipInference:
Evaluation proposed by Shokri, Stronati, Song and Shmatikov is a grey-box attack. Evaluation proposed by Shokri, Stronati, Song and Shmatikov is a grey-box attack.
The attack requires obtain loss or logits results of training samples. The attack requires obtain loss or logits results of training samples.
References: Reza Shokri, Marco Stronati, Congzheng Song, Vitaly Shmatikov. References: `Reza Shokri, Marco Stronati, Congzheng Song, Vitaly Shmatikov.
Membership Inference Attacks against Machine Learning Models. 2017. Membership Inference Attacks against Machine Learning Models. 2017.
arXiv:1610.05820v2 <https://arxiv.org/abs/1610.05820v2>`_ <https://arxiv.org/abs/1610.05820v2>`_
Args: Args:
model (Model): Target model. model (Model): Target model.
Examples: Examples:
>>> # ds_train, eval_train are non-overlapping datasets from training dataset. >>> train_1, train_2 are non-overlapping datasets from training dataset of target model.
>>> # eval_train, eval_test are non-overlapping datasets from test dataset. >>> test_1, test_2 are non-overlapping datasets from test dataset of target model.
>>> We use train_1, test_1 to train attack model, and use train_2, test_2 to evaluate attack model.
>>> model = Model(network=net, loss_fn=loss, optimizer=opt, metrics={'acc', 'loss'}) >>> model = Model(network=net, loss_fn=loss, optimizer=opt, metrics={'acc', 'loss'})
>>> inference_model = MembershipInference(model) >>> inference_model = MembershipInference(model)
>>> config = [{"method": "KNN", "params": {"n_neighbors": [3, 5, 7]}}] >>> config = [{"method": "KNN", "params": {"n_neighbors": [3, 5, 7]}}]
>>> inference_model.train(ds_train, ds_test, config) >>> inference_model.train(train_1, test_1, config)
>>> metrics = ["precision", "recall", "accuracy"] >>> metrics = ["precision", "recall", "accuracy"]
>>> result = inference_model.eval(eval_train, eval_test, metrics) >>> result = inference_model.eval(train_2, test_2, metrics)
Raises: Raises:
TypeError: If type of model is not mindspore.train.Model. TypeError: If type of model is not mindspore.train.Model.
...@@ -92,6 +93,7 @@ class MembershipInference: ...@@ -92,6 +93,7 @@ class MembershipInference:
if not isinstance(model, Model): if not isinstance(model, Model):
raise TypeError("Type of parameter 'model' must be Model, but got {}.".format(type(model))) raise TypeError("Type of parameter 'model' must be Model, but got {}.".format(type(model)))
self.model = model self.model = model
self.method_list = ["knn", "lr", "mlp", "rf"]
self.attack_list = [] self.attack_list = []
def train(self, dataset_train, dataset_test, attack_config): def train(self, dataset_train, dataset_test, attack_config):
...@@ -102,7 +104,13 @@ class MembershipInference: ...@@ -102,7 +104,13 @@ class MembershipInference:
Args: Args:
dataset_train (mindspore.dataset): The training dataset for the target model. dataset_train (mindspore.dataset): The training dataset for the target model.
dataset_test (mindspore.dataset): The test set for the target model. dataset_test (mindspore.dataset): The test set for the target model.
attack_config (list): Parameter setting for the attack model. attack_config (list): Parameter setting for the attack model. The format is
[{"method": "knn", "params": {"n_neighbors": [3, 5, 7]}},
{"method": "lr", "params": {"C": np.logspace(-4, 2, 10)}}].
The support methods list is in self.method_list, and the params of each method
must within the range of changeable parameters. Tips of params implement
can be found in
"https://scikit-learn.org/0.16/modules/generated/sklearn.grid_search.GridSearchCV.html".
Raises: Raises:
KeyError: If each config in attack_config doesn't have keys {"method", "params"} KeyError: If each config in attack_config doesn't have keys {"method", "params"}
...@@ -120,7 +128,7 @@ class MembershipInference: ...@@ -120,7 +128,7 @@ class MembershipInference:
if {"params", "method"} != set(config.keys()): if {"params", "method"} != set(config.keys()):
raise KeyError("Each config in attack_config must have keys 'method' and 'params', " raise KeyError("Each config in attack_config must have keys 'method' and 'params', "
"but your key value is {}.".format(set(config.keys()))) "but your key value is {}.".format(set(config.keys())))
if str.lower(config["method"]) not in method_list: if str.lower(config["method"]) not in self.method_list:
raise ValueError("Method {} is not support.".format(config["method"])) raise ValueError("Method {} is not support.".format(config["method"]))
features, labels = self._transform(dataset_train, dataset_test) features, labels = self._transform(dataset_train, dataset_test)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册