diff --git a/mindarmour/diff_privacy/evaluation/membership_inference.py b/mindarmour/diff_privacy/evaluation/membership_inference.py index 4240f29bd22f11f53daa913e4ae249636569ddf6..2841af3cb9ddf51ad69ae6c8136e55ff86b4c80e 100755 --- a/mindarmour/diff_privacy/evaluation/membership_inference.py +++ b/mindarmour/diff_privacy/evaluation/membership_inference.py @@ -124,7 +124,14 @@ class MembershipInference: raise TypeError("Type of parameter 'test_train' must be Dataset, " "but got {}".format(type(dataset_train))) + if not isinstance(attack_config, list): + raise TypeError("Type of parameter 'attack_config' must be list, " + "but got {}.".format(type(attack_config))) + for config in attack_config: + if not isinstance(config, dict): + raise TypeError("Type of each config in 'attack_config' must be dict, " + "but got {}.".format(type(config))) if {"params", "method"} != set(config.keys()): raise KeyError("Each config in attack_config must have keys 'method' and 'params', " "but your key value is {}.".format(set(config.keys())))