提交 3d887767 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!90 Append the parameter verification of class MembershipInference.

Merge pull request !90 from liuluobin/master
......@@ -27,7 +27,7 @@ import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.nn.optim.momentum import Momentum
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
from mindspore.train.model import Model
from mindspore.train.serialization import load_param_into_net, load_checkpoint
from mindarmour.utils import LogUtil
......@@ -187,12 +187,13 @@ if __name__ == '__main__':
amp_level="O2", keep_batchnorm_fp32=False, loss_scale_manager=None)
# checkpoint save
callbacks = [LossMonitor()]
if args.rank_save_ckpt_flag:
ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval*args.steps_per_epoch,
keep_checkpoint_max=args.ckpt_save_max)
ckpt_cb = ModelCheckpoint(config=ckpt_config,
directory=args.outputs_dir,
prefix='{}'.format(args.rank))
callbacks = ckpt_cb
callbacks.append(ckpt_cb)
model.train(args.max_epoch, dataset, callbacks=callbacks)
......@@ -22,6 +22,9 @@ from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import RandomizedSearchCV
method_list = ["lr", "knn", "rf", "mlp"]
def _attack_knn(features, labels, param_grid):
"""
Train and return a KNN model.
......@@ -119,12 +122,13 @@ def get_attack_model(features, labels, config):
sklearn.BaseEstimator, trained model specify by config["method"].
"""
method = str.lower(config["method"])
if method == "knn":
return _attack_knn(features, labels, config["params"])
if method in ["lr", "logitic regression"]:
if method == "lr":
return _attack_lr(features, labels, config["params"])
if method == "mlp":
return _attack_mlpc(features, labels, config["params"])
if method in ["rf", "random forest"]:
if method == "rf":
return _attack_rf(features, labels, config["params"])
raise ValueError("Method {} is not support.".format(config["method"]))
return None
......@@ -19,10 +19,11 @@ import numpy as np
import mindspore as ms
from mindspore.train import Model
from mindspore.dataset.engine import Dataset
import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor
from mindarmour.diff_privacy.evaluation.attacker import get_attack_model
from mindarmour.diff_privacy.evaluation.attacker import get_attack_model, method_list
def _eval_info(pred, truth, option):
"""
......@@ -89,7 +90,7 @@ class MembershipInference:
def __init__(self, model):
if not isinstance(model, Model):
raise TypeError("Type of model must be {}, but got {}.".format(type(Model), type(model)))
raise TypeError("Type of parameter 'model' must be Model, but got {}.".format(type(model)))
self.model = model
self.attack_list = []
......@@ -104,8 +105,24 @@ class MembershipInference:
attack_config (list): Parameter setting for the attack model.
Raises:
ValueError: If the method in attack_config is not in ["LR", "KNN", "RF", "MLPC"].
KeyError: If each config in attack_config doesn't have keys {"method", "params"}
ValueError: If the method(case insensitive) in attack_config is not in ["lr", "knn", "rf", "mlp"].
"""
if not isinstance(dataset_train, Dataset):
raise TypeError("Type of parameter 'dataset_train' must be Dataset, "
"but got {}".format(type(dataset_train)))
if not isinstance(dataset_test, Dataset):
raise TypeError("Type of parameter 'test_train' must be Dataset, "
"but got {}".format(type(dataset_train)))
for config in attack_config:
if {"params", "method"} != set(config.keys()):
raise KeyError("Each config in attack_config must have keys 'method' and 'params', "
"but your key value is {}.".format(set(config.keys())))
if str.lower(config["method"]) not in method_list:
raise ValueError("Method {} is not support.".format(config["method"]))
features, labels = self._transform(dataset_train, dataset_test)
for config in attack_config:
self.attack_list.append(get_attack_model(features, labels, config))
......@@ -124,6 +141,24 @@ class MembershipInference:
Returns:
list, Each element contains an evaluation indicator for the attack model.
"""
if not isinstance(dataset_train, Dataset):
raise TypeError("Type of parameter 'dataset_train' must be Dataset, "
"but got {}".format(type(dataset_train)))
if not isinstance(dataset_test, Dataset):
raise TypeError("Type of parameter 'test_train' must be Dataset, "
"but got {}".format(type(dataset_train)))
if not isinstance(metrics, (list, tuple)):
raise TypeError("Type of parameter 'config' must be Union[list, tuple], but got "
"{}.".format(type(metrics)))
metrics = set(metrics)
metrics_list = {"precision", "accruacy", "recall"}
if metrics > metrics_list:
raise ValueError("Element in 'metrics' must be in {}, but got "
"{}.".format(metrics_list, metrics))
result = []
features, labels = self._transform(dataset_train, dataset_test)
for attacker in self.attack_list:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册