提交 ce15e781 编写于 作者: L liuluobin

Append the parameter verification of class MembershipInference.

上级 29e303a8
...@@ -27,7 +27,7 @@ import mindspore.nn as nn ...@@ -27,7 +27,7 @@ import mindspore.nn as nn
from mindspore import Tensor from mindspore import Tensor
from mindspore import context from mindspore import context
from mindspore.nn.optim.momentum import Momentum from mindspore.nn.optim.momentum import Momentum
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
from mindspore.train.model import Model from mindspore.train.model import Model
from mindspore.train.serialization import load_param_into_net, load_checkpoint from mindspore.train.serialization import load_param_into_net, load_checkpoint
from mindarmour.utils import LogUtil from mindarmour.utils import LogUtil
...@@ -187,12 +187,13 @@ if __name__ == '__main__': ...@@ -187,12 +187,13 @@ if __name__ == '__main__':
amp_level="O2", keep_batchnorm_fp32=False, loss_scale_manager=None) amp_level="O2", keep_batchnorm_fp32=False, loss_scale_manager=None)
# checkpoint save # checkpoint save
callbacks = [LossMonitor()]
if args.rank_save_ckpt_flag: if args.rank_save_ckpt_flag:
ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval*args.steps_per_epoch, ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval*args.steps_per_epoch,
keep_checkpoint_max=args.ckpt_save_max) keep_checkpoint_max=args.ckpt_save_max)
ckpt_cb = ModelCheckpoint(config=ckpt_config, ckpt_cb = ModelCheckpoint(config=ckpt_config,
directory=args.outputs_dir, directory=args.outputs_dir,
prefix='{}'.format(args.rank)) prefix='{}'.format(args.rank))
callbacks = ckpt_cb callbacks.append(ckpt_cb)
model.train(args.max_epoch, dataset, callbacks=callbacks) model.train(args.max_epoch, dataset, callbacks=callbacks)
...@@ -22,6 +22,9 @@ from sklearn.model_selection import GridSearchCV ...@@ -22,6 +22,9 @@ from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import RandomizedSearchCV from sklearn.model_selection import RandomizedSearchCV
method_list = ["lr", "knn", "rf", "mlp"]
def _attack_knn(features, labels, param_grid): def _attack_knn(features, labels, param_grid):
""" """
Train and return a KNN model. Train and return a KNN model.
...@@ -119,12 +122,13 @@ def get_attack_model(features, labels, config): ...@@ -119,12 +122,13 @@ def get_attack_model(features, labels, config):
sklearn.BaseEstimator, trained model specify by config["method"]. sklearn.BaseEstimator, trained model specify by config["method"].
""" """
method = str.lower(config["method"]) method = str.lower(config["method"])
if method == "knn": if method == "knn":
return _attack_knn(features, labels, config["params"]) return _attack_knn(features, labels, config["params"])
if method in ["lr", "logitic regression"]: if method == "lr":
return _attack_lr(features, labels, config["params"]) return _attack_lr(features, labels, config["params"])
if method == "mlp": if method == "mlp":
return _attack_mlpc(features, labels, config["params"]) return _attack_mlpc(features, labels, config["params"])
if method in ["rf", "random forest"]: if method == "rf":
return _attack_rf(features, labels, config["params"]) return _attack_rf(features, labels, config["params"])
raise ValueError("Method {} is not support.".format(config["method"])) return None
...@@ -19,10 +19,11 @@ import numpy as np ...@@ -19,10 +19,11 @@ import numpy as np
import mindspore as ms import mindspore as ms
from mindspore.train import Model from mindspore.train import Model
from mindspore.dataset.engine import Dataset
import mindspore.nn as nn import mindspore.nn as nn
import mindspore.context as context import mindspore.context as context
from mindspore import Tensor from mindspore import Tensor
from mindarmour.diff_privacy.evaluation.attacker import get_attack_model from mindarmour.diff_privacy.evaluation.attacker import get_attack_model, method_list
def _eval_info(pred, truth, option): def _eval_info(pred, truth, option):
""" """
...@@ -89,7 +90,7 @@ class MembershipInference: ...@@ -89,7 +90,7 @@ class MembershipInference:
def __init__(self, model): def __init__(self, model):
if not isinstance(model, Model): if not isinstance(model, Model):
raise TypeError("Type of model must be {}, but got {}.".format(type(Model), type(model))) raise TypeError("Type of parameter 'model' must be Model, but got {}.".format(type(model)))
self.model = model self.model = model
self.attack_list = [] self.attack_list = []
...@@ -104,8 +105,24 @@ class MembershipInference: ...@@ -104,8 +105,24 @@ class MembershipInference:
attack_config (list): Parameter setting for the attack model. attack_config (list): Parameter setting for the attack model.
Raises: Raises:
ValueError: If the method in attack_config is not in ["LR", "KNN", "RF", "MLPC"]. KeyError: If each config in attack_config doesn't have keys {"method", "params"}
ValueError: If the method(case insensitive) in attack_config is not in ["lr", "knn", "rf", "mlp"].
""" """
if not isinstance(dataset_train, Dataset):
raise TypeError("Type of parameter 'dataset_train' must be Dataset, "
"but got {}".format(type(dataset_train)))
if not isinstance(dataset_test, Dataset):
raise TypeError("Type of parameter 'test_train' must be Dataset, "
"but got {}".format(type(dataset_train)))
for config in attack_config:
if {"params", "method"} != set(config.keys()):
raise KeyError("Each config in attack_config must have keys 'method' and 'params', "
"but your key value is {}.".format(set(config.keys())))
if str.lower(config["method"]) not in method_list:
raise ValueError("Method {} is not support.".format(config["method"]))
features, labels = self._transform(dataset_train, dataset_test) features, labels = self._transform(dataset_train, dataset_test)
for config in attack_config: for config in attack_config:
self.attack_list.append(get_attack_model(features, labels, config)) self.attack_list.append(get_attack_model(features, labels, config))
...@@ -124,6 +141,24 @@ class MembershipInference: ...@@ -124,6 +141,24 @@ class MembershipInference:
Returns: Returns:
list, Each element contains an evaluation indicator for the attack model. list, Each element contains an evaluation indicator for the attack model.
""" """
if not isinstance(dataset_train, Dataset):
raise TypeError("Type of parameter 'dataset_train' must be Dataset, "
"but got {}".format(type(dataset_train)))
if not isinstance(dataset_test, Dataset):
raise TypeError("Type of parameter 'test_train' must be Dataset, "
"but got {}".format(type(dataset_train)))
if not isinstance(metrics, (list, tuple)):
raise TypeError("Type of parameter 'config' must be Union[list, tuple], but got "
"{}.".format(type(metrics)))
metrics = set(metrics)
metrics_list = {"precision", "accruacy", "recall"}
if metrics > metrics_list:
raise ValueError("Element in 'metrics' must be in {}, but got "
"{}.".format(metrics_list, metrics))
result = [] result = []
features, labels = self._transform(dataset_train, dataset_test) features, labels = self._transform(dataset_train, dataset_test)
for attacker in self.attack_list: for attacker in self.attack_list:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册