提交 3a3ff173 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!98 Fixed exception detection and added log printing.

Merge pull request !98 from liuluobin/master
......@@ -21,6 +21,11 @@ from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import RandomizedSearchCV
from mindarmour.utils.logger import LogUtil
LOGGER = LogUtil.get_instance()
TAG = "Attacker"
def _attack_knn(features, labels, param_grid):
"""
......@@ -138,4 +143,7 @@ def get_attack_model(features, labels, config):
return _attack_mlpc(features, labels, config["params"])
if method == "rf":
return _attack_rf(features, labels, config["params"])
return None
msg = "Method {} is not supported.".format(config["method"])
LOGGER.error(TAG, msg)
raise ValueError(msg)
......@@ -24,6 +24,11 @@ import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor
from mindarmour.diff_privacy.evaluation.attacker import get_attack_model
from mindarmour.utils.logger import LogUtil
LOGGER = LogUtil.get_instance()
TAG = "MembershipInference"
def _eval_info(pred, truth, option):
"""
......@@ -43,7 +48,9 @@ def _eval_info(pred, truth, option):
ValueError, value of parameter option must be in ["precision", "accuracy", "recall"].
"""
if pred.size == 0 or truth.size == 0:
raise ValueError("Size of pred or truth is 0.")
msg = "Size of pred or truth is 0."
LOGGER.error(TAG, msg)
raise ValueError(msg)
if option == "accuracy":
count = np.sum(pred == truth)
......@@ -59,7 +66,9 @@ def _eval_info(pred, truth, option):
return -1
return count / np.sum(truth)
raise ValueError("The metric value {} is undefined.".format(option))
msg = "The metric value {} is undefined.".format(option)
LOGGER.error(TAG, msg)
raise ValueError(msg)
class MembershipInference:
......@@ -91,7 +100,10 @@ class MembershipInference:
def __init__(self, model):
if not isinstance(model, Model):
raise TypeError("Type of parameter 'model' must be Model, but got {}.".format(type(model)))
msg = "Type of parameter 'model' must be Model, but got {}.".format(type(model))
LOGGER.error(TAG, msg)
raise TypeError(msg)
self.model = model
self.method_list = ["knn", "lr", "mlp", "rf"]
self.attack_list = []
......@@ -117,26 +129,34 @@ class MembershipInference:
ValueError: If the method(case insensitive) in attack_config is not in ["lr", "knn", "rf", "mlp"].
"""
if not isinstance(dataset_train, Dataset):
raise TypeError("Type of parameter 'dataset_train' must be Dataset, "
"but got {}".format(type(dataset_train)))
msg = "Type of parameter 'dataset_train' must be Dataset, but got {}".format(type(dataset_train))
LOGGER.error(TAG, msg)
raise TypeError(msg)
if not isinstance(dataset_test, Dataset):
raise TypeError("Type of parameter 'test_train' must be Dataset, "
"but got {}".format(type(dataset_train)))
msg = "Type of parameter 'test_train' must be Dataset, but got {}".format(type(dataset_train))
LOGGER.error(TAG, msg)
raise TypeError(msg)
if not isinstance(attack_config, list):
raise TypeError("Type of parameter 'attack_config' must be list, "
"but got {}.".format(type(attack_config)))
msg = "Type of parameter 'attack_config' must be list, but got {}.".format(type(attack_config))
LOGGER.error(TAG, msg)
raise TypeError(msg)
for config in attack_config:
if not isinstance(config, dict):
raise TypeError("Type of each config in 'attack_config' must be dict, "
"but got {}.".format(type(config)))
msg = "Type of each config in 'attack_config' must be dict, but got {}.".format(type(config))
LOGGER.error(TAG, msg)
raise TypeError(msg)
if {"params", "method"} != set(config.keys()):
raise KeyError("Each config in attack_config must have keys 'method' and 'params', "
"but your key value is {}.".format(set(config.keys())))
msg = "Each config in attack_config must have keys 'method' and 'params'," \
"but your key value is {}.".format(set(config.keys()))
LOGGER.error(TAG, msg)
raise KeyError(msg)
if str.lower(config["method"]) not in self.method_list:
raise ValueError("Method {} is not support.".format(config["method"]))
msg = "Method {} is not support.".format(config["method"])
LOGGER.error(TAG, msg)
raise ValueError(msg)
features, labels = self._transform(dataset_train, dataset_test)
for config in attack_config:
......@@ -157,22 +177,26 @@ class MembershipInference:
list, Each element contains an evaluation indicator for the attack model.
"""
if not isinstance(dataset_train, Dataset):
raise TypeError("Type of parameter 'dataset_train' must be Dataset, "
"but got {}".format(type(dataset_train)))
msg = "Type of parameter 'dataset_train' must be Dataset, but got {}".format(type(dataset_train))
LOGGER.error(TAG, msg)
raise TypeError(msg)
if not isinstance(dataset_test, Dataset):
raise TypeError("Type of parameter 'test_train' must be Dataset, "
"but got {}".format(type(dataset_train)))
msg = "Type of parameter 'test_train' must be Dataset, but got {}".format(type(dataset_train))
LOGGER.error(TAG, msg)
raise TypeError(msg)
if not isinstance(metrics, (list, tuple)):
raise TypeError("Type of parameter 'config' must be Union[list, tuple], but got "
"{}.".format(type(metrics)))
msg = "Type of parameter 'config' must be Union[list, tuple], but got {}.".format(type(metrics))
LOGGER.error(TAG, msg)
raise TypeError(msg)
metrics = set(metrics)
metrics_list = {"precision", "accruacy", "recall"}
if metrics > metrics_list:
raise ValueError("Element in 'metrics' must be in {}, but got "
"{}.".format(metrics_list, metrics))
msg = "Element in 'metrics' must be in {}, but got {}.".format(metrics_list, metrics)
LOGGER.error(TAG, msg)
raise ValueError(msg)
result = []
features, labels = self._transform(dataset_train, dataset_test)
......@@ -221,8 +245,10 @@ class MembershipInference:
- numpy.ndarray, Labels for each sample, Shape is (N,).
"""
if context.get_context("device_target") != "Ascend":
raise RuntimeError("The target device must be Ascend, "
"but current is {}.".format(context.get_context("device_target")))
msg = "The target device must be Ascend, " \
"but current is {}.".format(context.get_context("device_target"))
LOGGER.error(TAG, msg)
raise RuntimeError(msg)
loss_logits = np.array([])
for batch in dataset_x.create_dict_iterator():
batch_data = Tensor(batch['image'], ms.float32)
......@@ -243,5 +269,7 @@ class MembershipInference:
elif label == 0:
labels = np.zeros(len(loss_logits), np.int32)
else:
raise ValueError("The value of label must be 0 or 1, but got {}.".format(label))
msg = "The value of label must be 0 or 1, but got {}.".format(label)
LOGGER.error(TAG, msg)
raise ValueError(msg)
return loss_logits, labels
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册