提交 dfb30f04 编写于 作者: L liuluobin

Fixed exception detection and added log printing

上级 dcc64ddd
...@@ -21,6 +21,11 @@ from sklearn.ensemble import RandomForestClassifier ...@@ -21,6 +21,11 @@ from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import RandomizedSearchCV from sklearn.model_selection import RandomizedSearchCV
from mindarmour.utils.logger import LogUtil
LOGGER = LogUtil.get_instance()
TAG = "Attacker"
def _attack_knn(features, labels, param_grid): def _attack_knn(features, labels, param_grid):
""" """
...@@ -138,4 +143,7 @@ def get_attack_model(features, labels, config): ...@@ -138,4 +143,7 @@ def get_attack_model(features, labels, config):
return _attack_mlpc(features, labels, config["params"]) return _attack_mlpc(features, labels, config["params"])
if method == "rf": if method == "rf":
return _attack_rf(features, labels, config["params"]) return _attack_rf(features, labels, config["params"])
return None
msg = "Method {} is not supported.".format(config["method"])
LOGGER.error(TAG, msg)
raise ValueError(msg)
...@@ -24,6 +24,11 @@ import mindspore.nn as nn ...@@ -24,6 +24,11 @@ import mindspore.nn as nn
import mindspore.context as context import mindspore.context as context
from mindspore import Tensor from mindspore import Tensor
from mindarmour.diff_privacy.evaluation.attacker import get_attack_model from mindarmour.diff_privacy.evaluation.attacker import get_attack_model
from mindarmour.utils.logger import LogUtil
LOGGER = LogUtil.get_instance()
TAG = "MembershipInference"
def _eval_info(pred, truth, option): def _eval_info(pred, truth, option):
""" """
...@@ -43,7 +48,9 @@ def _eval_info(pred, truth, option): ...@@ -43,7 +48,9 @@ def _eval_info(pred, truth, option):
ValueError, value of parameter option must be in ["precision", "accuracy", "recall"]. ValueError, value of parameter option must be in ["precision", "accuracy", "recall"].
""" """
if pred.size == 0 or truth.size == 0: if pred.size == 0 or truth.size == 0:
raise ValueError("Size of pred or truth is 0.") msg = "Size of pred or truth is 0."
LOGGER.error(TAG, msg)
raise ValueError(msg)
if option == "accuracy": if option == "accuracy":
count = np.sum(pred == truth) count = np.sum(pred == truth)
...@@ -59,7 +66,9 @@ def _eval_info(pred, truth, option): ...@@ -59,7 +66,9 @@ def _eval_info(pred, truth, option):
return -1 return -1
return count / np.sum(truth) return count / np.sum(truth)
raise ValueError("The metric value {} is undefined.".format(option)) msg = "The metric value {} is undefined.".format(option)
LOGGER.error(TAG, msg)
raise ValueError(msg)
class MembershipInference: class MembershipInference:
...@@ -91,7 +100,10 @@ class MembershipInference: ...@@ -91,7 +100,10 @@ class MembershipInference:
def __init__(self, model): def __init__(self, model):
if not isinstance(model, Model): if not isinstance(model, Model):
raise TypeError("Type of parameter 'model' must be Model, but got {}.".format(type(model))) msg = "Type of parameter 'model' must be Model, but got {}.".format(type(model))
LOGGER.error(TAG, msg)
raise TypeError(msg)
self.model = model self.model = model
self.method_list = ["knn", "lr", "mlp", "rf"] self.method_list = ["knn", "lr", "mlp", "rf"]
self.attack_list = [] self.attack_list = []
...@@ -117,26 +129,34 @@ class MembershipInference: ...@@ -117,26 +129,34 @@ class MembershipInference:
ValueError: If the method(case insensitive) in attack_config is not in ["lr", "knn", "rf", "mlp"]. ValueError: If the method(case insensitive) in attack_config is not in ["lr", "knn", "rf", "mlp"].
""" """
if not isinstance(dataset_train, Dataset): if not isinstance(dataset_train, Dataset):
raise TypeError("Type of parameter 'dataset_train' must be Dataset, " msg = "Type of parameter 'dataset_train' must be Dataset, but got {}".format(type(dataset_train))
"but got {}".format(type(dataset_train))) LOGGER.error(TAG, msg)
raise TypeError(msg)
if not isinstance(dataset_test, Dataset): if not isinstance(dataset_test, Dataset):
raise TypeError("Type of parameter 'test_train' must be Dataset, " msg = "Type of parameter 'test_train' must be Dataset, but got {}".format(type(dataset_train))
"but got {}".format(type(dataset_train))) LOGGER.error(TAG, msg)
raise TypeError(msg)
if not isinstance(attack_config, list): if not isinstance(attack_config, list):
raise TypeError("Type of parameter 'attack_config' must be list, " msg = "Type of parameter 'attack_config' must be list, but got {}.".format(type(attack_config))
"but got {}.".format(type(attack_config))) LOGGER.error(TAG, msg)
raise TypeError(msg)
for config in attack_config: for config in attack_config:
if not isinstance(config, dict): if not isinstance(config, dict):
raise TypeError("Type of each config in 'attack_config' must be dict, " msg = "Type of each config in 'attack_config' must be dict, but got {}.".format(type(config))
"but got {}.".format(type(config))) LOGGER.error(TAG, msg)
raise TypeError(msg)
if {"params", "method"} != set(config.keys()): if {"params", "method"} != set(config.keys()):
raise KeyError("Each config in attack_config must have keys 'method' and 'params', " msg = "Each config in attack_config must have keys 'method' and 'params'," \
"but your key value is {}.".format(set(config.keys()))) "but your key value is {}.".format(set(config.keys()))
LOGGER.error(TAG, msg)
raise KeyError(msg)
if str.lower(config["method"]) not in self.method_list: if str.lower(config["method"]) not in self.method_list:
raise ValueError("Method {} is not support.".format(config["method"])) msg = "Method {} is not support.".format(config["method"])
LOGGER.error(TAG, msg)
raise ValueError(msg)
features, labels = self._transform(dataset_train, dataset_test) features, labels = self._transform(dataset_train, dataset_test)
for config in attack_config: for config in attack_config:
...@@ -157,22 +177,26 @@ class MembershipInference: ...@@ -157,22 +177,26 @@ class MembershipInference:
list, Each element contains an evaluation indicator for the attack model. list, Each element contains an evaluation indicator for the attack model.
""" """
if not isinstance(dataset_train, Dataset): if not isinstance(dataset_train, Dataset):
raise TypeError("Type of parameter 'dataset_train' must be Dataset, " msg = "Type of parameter 'dataset_train' must be Dataset, but got {}".format(type(dataset_train))
"but got {}".format(type(dataset_train))) LOGGER.error(TAG, msg)
raise TypeError(msg)
if not isinstance(dataset_test, Dataset): if not isinstance(dataset_test, Dataset):
raise TypeError("Type of parameter 'test_train' must be Dataset, " msg = "Type of parameter 'test_train' must be Dataset, but got {}".format(type(dataset_train))
"but got {}".format(type(dataset_train))) LOGGER.error(TAG, msg)
raise TypeError(msg)
if not isinstance(metrics, (list, tuple)): if not isinstance(metrics, (list, tuple)):
raise TypeError("Type of parameter 'config' must be Union[list, tuple], but got " msg = "Type of parameter 'config' must be Union[list, tuple], but got {}.".format(type(metrics))
"{}.".format(type(metrics))) LOGGER.error(TAG, msg)
raise TypeError(msg)
metrics = set(metrics) metrics = set(metrics)
metrics_list = {"precision", "accruacy", "recall"} metrics_list = {"precision", "accruacy", "recall"}
if metrics > metrics_list: if metrics > metrics_list:
raise ValueError("Element in 'metrics' must be in {}, but got " msg = "Element in 'metrics' must be in {}, but got {}.".format(metrics_list, metrics)
"{}.".format(metrics_list, metrics)) LOGGER.error(TAG, msg)
raise ValueError(msg)
result = [] result = []
features, labels = self._transform(dataset_train, dataset_test) features, labels = self._transform(dataset_train, dataset_test)
...@@ -221,8 +245,10 @@ class MembershipInference: ...@@ -221,8 +245,10 @@ class MembershipInference:
- numpy.ndarray, Labels for each sample, Shape is (N,). - numpy.ndarray, Labels for each sample, Shape is (N,).
""" """
if context.get_context("device_target") != "Ascend": if context.get_context("device_target") != "Ascend":
raise RuntimeError("The target device must be Ascend, " msg = "The target device must be Ascend, " \
"but current is {}.".format(context.get_context("device_target"))) "but current is {}.".format(context.get_context("device_target"))
LOGGER.error(TAG, msg)
raise RuntimeError(msg)
loss_logits = np.array([]) loss_logits = np.array([])
for batch in dataset_x.create_dict_iterator(): for batch in dataset_x.create_dict_iterator():
batch_data = Tensor(batch['image'], ms.float32) batch_data = Tensor(batch['image'], ms.float32)
...@@ -243,5 +269,7 @@ class MembershipInference: ...@@ -243,5 +269,7 @@ class MembershipInference:
elif label == 0: elif label == 0:
labels = np.zeros(len(loss_logits), np.int32) labels = np.zeros(len(loss_logits), np.int32)
else: else:
raise ValueError("The value of label must be 0 or 1, but got {}.".format(label)) msg = "The value of label must be 0 or 1, but got {}.".format(label)
LOGGER.error(TAG, msg)
raise ValueError(msg)
return loss_logits, labels return loss_logits, labels
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册