提交 16062255 编写于 作者: L liuluobin

remove parameter 'iid' in GridSearchCV. Increase support for parallel...

remove parameter 'iid' in GridSearchCV. Increase support for parallel training. Modify ut case. Add attack config verfication.
上级 fd3eb11b
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Verify attack config
"""
from mindarmour.utils.logger import LogUtil
LOGGER = LogUtil.get_instance()
TAG = "check_params"
def _is_positive_int(item):
"""
Verify that the value is a positive integer.
"""
if not isinstance(item, int) or item <= 0:
return False
return True
def _is_non_negative_int(item):
"""
Verify that the value is a non-negative integer.
"""
if not isinstance(item, int) or item < 0:
return False
return True
def _is_positive_float(item):
"""
Verify that value is a positive number.
"""
if not isinstance(item, (int, float)) or item <= 0:
return False
return True
def _is_non_negative_float(item):
"""
Verify that value is a non-negative number.
"""
if not isinstance(item, (int, float)) or item < 0:
return False
return True
def _is_positive_int_tuple(item):
"""
Verify that the input parameter is a positive integer tuple.
"""
if not isinstance(item, tuple):
return False
for i in item:
if not _is_positive_int(i):
return False
return True
def _is_dict(item):
"""
Check whether the type is dict.
"""
return isinstance(item, dict)
VALID_PARAMS_DICT = {
"knn": {
"n_neighbors": [_is_positive_int],
"weights": [{"uniform", "distance"}],
"algorithm": [{"auto", "ball_tree", "kd_tree", "brute"}],
"leaf_size": [_is_positive_int],
"p": [_is_positive_int],
"metric": None,
"metric_params": None,
},
"lr": {
"penalty": [{"l1", "l2", "elasticnet", "none"}],
"dual": [{True, False}],
"tol": [_is_positive_float],
"C": [_is_positive_float],
"fit_intercept": [{True, False}],
"intercept_scaling": [_is_positive_float],
"class_weight": [{"balanced", None}, _is_dict],
"random_state": None,
"solver": [{"newton-cg", "lbfgs", "liblinear", "sag", "saga"}]
},
"mlp": {
"hidden_layer_sizes": [_is_positive_int_tuple],
"activation": [{"identity", "logistic", "tanh", "relu"}],
"solver": {"lbfgs", "sgd", "adam"},
"alpha": [_is_positive_float],
"batch_size": [{"auto"}, _is_positive_int],
"learning_rate": [{"constant", "invscaling", "adaptive"}],
"learning_rate_init": [_is_positive_float],
"power_t": [_is_positive_float],
"max_iter": [_is_positive_int],
"shuffle": [{True, False}],
"random_state": None,
"tol": [_is_positive_float],
"verbose": [{True, False}],
"warm_start": [{True, False}],
"momentum": [_is_positive_float],
"nesterovs_momentum": [{True, False}],
"early_stopping": [{True, False}],
"validation_fraction": [_is_positive_float],
"beta_1": [_is_positive_float],
"beta_2": [_is_positive_float],
"epsilon": [_is_positive_float],
"n_iter_no_change": [_is_positive_int],
"max_fun": [_is_positive_int]
},
"rf": {
"n_estimators": [_is_positive_int],
"criterion": [{"gini", "entropy"}],
"max_depth": [_is_positive_int],
"min_samples_split": [_is_positive_float],
"min_samples_leaf": [_is_positive_float],
"min_weight_fraction_leaf": [_is_non_negative_float],
"max_features": [{"auto", "sqrt", "log2", None}, _is_positive_float],
"max_leaf_nodes": [_is_positive_int, {None}],
"min_impurity_decrease": {_is_non_negative_float},
"min_impurity_split": [{None}, _is_positive_float],
"bootstrap": [{True, False}],
"oob_scroe": [{True, False}],
"n_jobs": [_is_positive_int, {None}],
"random_state": None,
"verbose": [_is_non_negative_int],
"warm_start": [{True, False}],
"class_weight": None,
"ccp_alpha": [_is_non_negative_float],
"max_samples": [_is_positive_float]
}
}
def _check_config(config_list, check_params):
"""
Verify that config_list is valid.
Check_params is the valid value range of the parameter.
"""
if not isinstance(config_list, (list, tuple)):
msg = "Type of parameter 'config_list' must be list, but got {}.".format(type(config_list))
LOGGER.error(TAG, msg)
raise TypeError(msg)
for config in config_list:
if not isinstance(config, dict):
msg = "Type of each config in config_list must be dict, but got {}.".format(type(config))
LOGGER.error(TAG, msg)
raise TypeError(msg)
if set(config.keys()) != {"params", "method"}:
msg = "Keys of each config in config_list must be {}," \
"but got {}.".format({'method', 'params'}, set(config.keys()))
LOGGER.error(TAG, msg)
raise KeyError(msg)
method = str.lower(config["method"])
params = config["params"]
if method not in check_params.keys():
msg = "Method {} is not supported.".format(method)
LOGGER.error(TAG, msg)
raise ValueError(msg)
if not params.keys() <= check_params[method].keys():
msg = "Params in method {} is not accepted, the parameters " \
"that can be set are {}.".format(method, set(check_params[method].keys()))
LOGGER.error(TAG, msg)
raise KeyError(msg)
for param_key in params.keys():
param_value = params[param_key]
candidate_values = check_params[method][param_key]
if not isinstance(param_value, list):
msg = "The parameter '{}' in method '{}' setting must within the range of " \
"changeable parameters.".format(param_key, method)
LOGGER.error(TAG, msg)
raise ValueError(msg)
if candidate_values is None:
continue
for item_value in param_value:
flag = False
for candidate_value in candidate_values:
if isinstance(candidate_value, set) and item_value in candidate_value:
flag = True
break
elif candidate_value(item_value):
flag = True
break
if not flag:
msg = "Setting of parmeter {} in method {} is invalid".format(param_key, method)
raise ValueError(msg)
def check_config_params(config_list):
"""
External interfaces to verify attack config.
"""
_check_config(config_list, VALID_PARAMS_DICT)
...@@ -27,7 +27,7 @@ LOGGER = LogUtil.get_instance() ...@@ -27,7 +27,7 @@ LOGGER = LogUtil.get_instance()
TAG = "Attacker" TAG = "Attacker"
def _attack_knn(features, labels, param_grid): def _attack_knn(features, labels, param_grid, n_jobs):
""" """
Train and return a KNN model. Train and return a KNN model.
...@@ -35,20 +35,21 @@ def _attack_knn(features, labels, param_grid): ...@@ -35,20 +35,21 @@ def _attack_knn(features, labels, param_grid):
features (numpy.ndarray): Loss and logits characteristics of each sample. features (numpy.ndarray): Loss and logits characteristics of each sample.
labels (numpy.ndarray): Labels of each sample whether belongs to training set. labels (numpy.ndarray): Labels of each sample whether belongs to training set.
param_grid (dict): Setting of GridSearchCV. param_grid (dict): Setting of GridSearchCV.
n_jobs (int): Number of jobs run in parallel. -1 means using all processors,
otherwise the value of n_jobs must be a positive integer.
Returns: Returns:
sklearn.model_selection.GridSearchCV, trained model. sklearn.model_selection.GridSearchCV, trained model.
""" """
knn_model = KNeighborsClassifier() knn_model = KNeighborsClassifier()
knn_model = GridSearchCV( knn_model = GridSearchCV(
knn_model, param_grid=param_grid, cv=3, n_jobs=1, iid=False, knn_model, param_grid=param_grid, cv=3, n_jobs=n_jobs, verbose=0,
verbose=0,
) )
knn_model.fit(X=features, y=labels) knn_model.fit(X=features, y=labels)
return knn_model return knn_model
def _attack_lr(features, labels, param_grid): def _attack_lr(features, labels, param_grid, n_jobs):
""" """
Train and return a LR model. Train and return a LR model.
...@@ -56,20 +57,21 @@ def _attack_lr(features, labels, param_grid): ...@@ -56,20 +57,21 @@ def _attack_lr(features, labels, param_grid):
features (numpy.ndarray): Loss and logits characteristics of each sample. features (numpy.ndarray): Loss and logits characteristics of each sample.
labels (numpy.ndarray): Labels of each sample whether belongs to training set. labels (numpy.ndarray): Labels of each sample whether belongs to training set.
param_grid (dict): Setting of GridSearchCV. param_grid (dict): Setting of GridSearchCV.
n_jobs (int): Number of jobs run in parallel. -1 means using all processors,
otherwise the value of n_jobs must be a positive integer.
Returns: Returns:
sklearn.model_selection.GridSearchCV, trained model. sklearn.model_selection.GridSearchCV, trained model.
""" """
lr_model = LogisticRegression(C=1.0, penalty="l2", max_iter=1000) lr_model = LogisticRegression(C=1.0, penalty="l2", max_iter=300)
lr_model = GridSearchCV( lr_model = GridSearchCV(
lr_model, param_grid=param_grid, cv=3, n_jobs=1, iid=False, lr_model, param_grid=param_grid, cv=3, n_jobs=n_jobs, verbose=0,
verbose=0,
) )
lr_model.fit(X=features, y=labels) lr_model.fit(X=features, y=labels)
return lr_model return lr_model
def _attack_mlpc(features, labels, param_grid): def _attack_mlpc(features, labels, param_grid, n_jobs):
""" """
Train and return a MLPC model. Train and return a MLPC model.
...@@ -77,20 +79,21 @@ def _attack_mlpc(features, labels, param_grid): ...@@ -77,20 +79,21 @@ def _attack_mlpc(features, labels, param_grid):
features (numpy.ndarray): Loss and logits characteristics of each sample. features (numpy.ndarray): Loss and logits characteristics of each sample.
labels (numpy.ndarray): Labels of each sample whether belongs to training set. labels (numpy.ndarray): Labels of each sample whether belongs to training set.
param_grid (dict): Setting of GridSearchCV. param_grid (dict): Setting of GridSearchCV.
n_jobs (int): Number of jobs run in parallel. -1 means using all processors,
otherwise the value of n_jobs must be a positive integer.
Returns: Returns:
sklearn.model_selection.GridSearchCV, trained model. sklearn.model_selection.GridSearchCV, trained model.
""" """
mlpc_model = MLPClassifier(random_state=1, max_iter=300) mlpc_model = MLPClassifier(random_state=1, max_iter=300)
mlpc_model = GridSearchCV( mlpc_model = GridSearchCV(
mlpc_model, param_grid=param_grid, cv=3, n_jobs=1, iid=False, mlpc_model, param_grid=param_grid, cv=3, n_jobs=n_jobs, verbose=0,
verbose=0,
) )
mlpc_model.fit(features, labels) mlpc_model.fit(features, labels)
return mlpc_model return mlpc_model
def _attack_rf(features, labels, random_grid): def _attack_rf(features, labels, random_grid, n_jobs):
""" """
Train and return a RF model. Train and return a RF model.
...@@ -98,20 +101,22 @@ def _attack_rf(features, labels, random_grid): ...@@ -98,20 +101,22 @@ def _attack_rf(features, labels, random_grid):
features (numpy.ndarray): Loss and logits characteristics of each sample. features (numpy.ndarray): Loss and logits characteristics of each sample.
labels (numpy.ndarray): Labels of each sample whether belongs to training set. labels (numpy.ndarray): Labels of each sample whether belongs to training set.
random_grid (dict): Setting of RandomizedSearchCV. random_grid (dict): Setting of RandomizedSearchCV.
n_jobs (int): Number of jobs run in parallel. -1 means using all processors,
otherwise the value of n_jobs must be a positive integer.
Returns: Returns:
sklearn.model_selection.RandomizedSearchCV, trained model. sklearn.model_selection.RandomizedSearchCV, trained model.
""" """
rf_model = RandomForestClassifier(max_depth=2, random_state=0) rf_model = RandomForestClassifier(max_depth=2, random_state=0)
rf_model = RandomizedSearchCV( rf_model = RandomizedSearchCV(
rf_model, param_distributions=random_grid, n_iter=7, cv=3, n_jobs=1, rf_model, param_distributions=random_grid, n_iter=7, cv=3, n_jobs=n_jobs,
iid=False, verbose=0, verbose=0,
) )
rf_model.fit(features, labels) rf_model.fit(features, labels)
return rf_model return rf_model
def get_attack_model(features, labels, config): def get_attack_model(features, labels, config, n_jobs=-1):
""" """
Get trained attack model specify by config. Get trained attack model specify by config.
...@@ -123,6 +128,8 @@ def get_attack_model(features, labels, config): ...@@ -123,6 +128,8 @@ def get_attack_model(features, labels, config):
params of each method must within the range of changeable parameters. params of each method must within the range of changeable parameters.
Tips of params implement can be found in Tips of params implement can be found in
"https://scikit-learn.org/0.16/modules/generated/sklearn.grid_search.GridSearchCV.html". "https://scikit-learn.org/0.16/modules/generated/sklearn.grid_search.GridSearchCV.html".
n_jobs (int): Number of jobs run in parallel. -1 means using all processors,
otherwise the value of n_jobs must be a positive integer.
Returns: Returns:
sklearn.BaseEstimator, trained model specify by config["method"]. sklearn.BaseEstimator, trained model specify by config["method"].
...@@ -136,13 +143,13 @@ def get_attack_model(features, labels, config): ...@@ -136,13 +143,13 @@ def get_attack_model(features, labels, config):
method = str.lower(config["method"]) method = str.lower(config["method"])
if method == "knn": if method == "knn":
return _attack_knn(features, labels, config["params"]) return _attack_knn(features, labels, config["params"], n_jobs)
if method == "lr": if method == "lr":
return _attack_lr(features, labels, config["params"]) return _attack_lr(features, labels, config["params"], n_jobs)
if method == "mlp": if method == "mlp":
return _attack_mlpc(features, labels, config["params"]) return _attack_mlpc(features, labels, config["params"], n_jobs)
if method == "rf": if method == "rf":
return _attack_rf(features, labels, config["params"]) return _attack_rf(features, labels, config["params"], n_jobs)
msg = "Method {} is not supported.".format(config["method"]) msg = "Method {} is not supported.".format(config["method"])
LOGGER.error(TAG, msg) LOGGER.error(TAG, msg)
......
...@@ -15,14 +15,16 @@ ...@@ -15,14 +15,16 @@
Membership Inference Membership Inference
""" """
from multiprocessing import cpu_count
import numpy as np import numpy as np
import mindspore as ms import mindspore as ms
from mindspore.train import Model from mindspore.train import Model
from mindspore.dataset.engine import Dataset from mindspore.dataset.engine import Dataset
from mindspore import Tensor from mindspore import Tensor
from mindarmour.diff_privacy.evaluation.attacker import get_attack_model
from mindarmour.utils.logger import LogUtil from mindarmour.utils.logger import LogUtil
from .attacker import get_attack_model
from ._check_config import check_config_params
LOGGER = LogUtil.get_instance() LOGGER = LogUtil.get_instance()
TAG = "MembershipInference" TAG = "MembershipInference"
...@@ -101,13 +103,15 @@ class MembershipInference: ...@@ -101,13 +103,15 @@ class MembershipInference:
Args: Args:
model (Model): Target model. model (Model): Target model.
n_jobs (int): Number of jobs run in parallel. -1 means using all processors,
otherwise the value of n_jobs must be a positive integer.
Examples: Examples:
>>> train_1, train_2 are non-overlapping datasets from training dataset of target model. >>> train_1, train_2 are non-overlapping datasets from training dataset of target model.
>>> test_1, test_2 are non-overlapping datasets from test dataset of target model. >>> test_1, test_2 are non-overlapping datasets from test dataset of target model.
>>> We use train_1, test_1 to train attack model, and use train_2, test_2 to evaluate attack model. >>> We use train_1, test_1 to train attack model, and use train_2, test_2 to evaluate attack model.
>>> model = Model(network=net, loss_fn=loss, optimizer=opt, metrics={'acc', 'loss'}) >>> model = Model(network=net, loss_fn=loss, optimizer=opt, metrics={'acc', 'loss'})
>>> inference_model = MembershipInference(model) >>> inference_model = MembershipInference(model, n_jobs=-1)
>>> config = [{"method": "KNN", "params": {"n_neighbors": [3, 5, 7]}}] >>> config = [{"method": "KNN", "params": {"n_neighbors": [3, 5, 7]}}]
>>> inference_model.train(train_1, test_1, config) >>> inference_model.train(train_1, test_1, config)
>>> metrics = ["precision", "recall", "accuracy"] >>> metrics = ["precision", "recall", "accuracy"]
...@@ -115,15 +119,26 @@ class MembershipInference: ...@@ -115,15 +119,26 @@ class MembershipInference:
Raises: Raises:
TypeError: If type of model is not mindspore.train.Model. TypeError: If type of model is not mindspore.train.Model.
TypeError: If type of n_jobs is not int.
ValueError: The value of n_jobs is neither -1 nor a positive integer.
""" """
def __init__(self, model): def __init__(self, model, n_jobs=-1):
if not isinstance(model, Model): if not isinstance(model, Model):
msg = "Type of parameter 'model' must be Model, but got {}.".format(type(model)) msg = "Type of parameter 'model' must be Model, but got {}.".format(type(model))
LOGGER.error(TAG, msg) LOGGER.error(TAG, msg)
raise TypeError(msg) raise TypeError(msg)
if not isinstance(n_jobs, int):
msg = "Type of parameter 'n_jobs' must be int, but got {}".format(type(n_jobs))
LOGGER.error(TAG, msg)
raise TypeError(msg)
if not (n_jobs == -1 or n_jobs > 0):
msg = "Value of n_jobs must be either -1 or positive integer, but got {}.".format(n_jobs)
LOGGER.error(TAG, msg)
raise ValueError(msg)
self.model = model self.model = model
self.n_jobs = min(n_jobs, cpu_count())
self.method_list = ["knn", "lr", "mlp", "rf"] self.method_list = ["knn", "lr", "mlp", "rf"]
self.attack_list = [] self.attack_list = []
...@@ -162,24 +177,13 @@ class MembershipInference: ...@@ -162,24 +177,13 @@ class MembershipInference:
LOGGER.error(TAG, msg) LOGGER.error(TAG, msg)
raise TypeError(msg) raise TypeError(msg)
for config in attack_config: check_config_params(attack_config) # Verify attack config.
if not isinstance(config, dict):
msg = "Type of each config in 'attack_config' must be dict, but got {}.".format(type(config))
LOGGER.error(TAG, msg)
raise TypeError(msg)
if {"params", "method"} != set(config.keys()):
msg = "Each config in attack_config must have keys 'method' and 'params'," \
"but your key value is {}.".format(set(config.keys()))
LOGGER.error(TAG, msg)
raise KeyError(msg)
if str.lower(config["method"]) not in self.method_list:
msg = "Method {} is not support.".format(config["method"])
LOGGER.error(TAG, msg)
raise ValueError(msg)
features, labels = self._transform(dataset_train, dataset_test) features, labels = self._transform(dataset_train, dataset_test)
for config in attack_config: for config in attack_config:
self.attack_list.append(get_attack_model(features, labels, config)) self.attack_list.append(get_attack_model(features, labels, config, n_jobs=self.n_jobs))
def eval(self, dataset_train, dataset_test, metrics): def eval(self, dataset_train, dataset_test, metrics):
""" """
......
...@@ -35,7 +35,7 @@ def test_get_knn_model(): ...@@ -35,7 +35,7 @@ def test_get_knn_model():
"n_neighbors": [3], "n_neighbors": [3],
} }
} }
knn_attacker = get_attack_model(features, labels, config_knn) knn_attacker = get_attack_model(features, labels, config_knn, -1)
pred = knn_attacker.predict(features) pred = knn_attacker.predict(features)
assert pred is not None assert pred is not None
...@@ -54,7 +54,7 @@ def test_get_lr_model(): ...@@ -54,7 +54,7 @@ def test_get_lr_model():
"C": np.logspace(-4, 2, 10), "C": np.logspace(-4, 2, 10),
} }
} }
lr_attacker = get_attack_model(features, labels, config_lr) lr_attacker = get_attack_model(features, labels, config_lr, -1)
pred = lr_attacker.predict(features) pred = lr_attacker.predict(features)
assert pred is not None assert pred is not None
...@@ -75,7 +75,7 @@ def test_get_mlp_model(): ...@@ -75,7 +75,7 @@ def test_get_mlp_model():
"alpha": [0.0001, 0.001, 0.01], "alpha": [0.0001, 0.001, 0.01],
} }
} }
mlpc_attacker = get_attack_model(features, labels, config_mlpc) mlpc_attacker = get_attack_model(features, labels, config_mlpc, -1)
pred = mlpc_attacker.predict(features) pred = mlpc_attacker.predict(features)
assert pred is not None assert pred is not None
...@@ -98,6 +98,6 @@ def test_get_rf_model(): ...@@ -98,6 +98,6 @@ def test_get_rf_model():
"min_samples_leaf": [1, 2, 4], "min_samples_leaf": [1, 2, 4],
} }
} }
rf_attacker = get_attack_model(features, labels, config_rf) rf_attacker = get_attack_model(features, labels, config_rf, -1)
pred = rf_attacker.predict(features) pred = rf_attacker.predict(features)
assert pred is not None assert pred is not None
...@@ -24,6 +24,7 @@ import numpy as np ...@@ -24,6 +24,7 @@ import numpy as np
import mindspore.dataset as ds import mindspore.dataset as ds
from mindspore import nn from mindspore import nn
from mindspore.train import Model from mindspore.train import Model
import mindspore.context as context
from mindarmour.diff_privacy.evaluation.membership_inference import MembershipInference from mindarmour.diff_privacy.evaluation.membership_inference import MembershipInference
...@@ -31,6 +32,8 @@ sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../")) ...@@ -31,6 +32,8 @@ sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../"))
from defenses.mock_net import Net from defenses.mock_net import Net
context.set_context(mode=context.GRAPH_MODE)
def dataset_generator(batch_size, batches): def dataset_generator(batch_size, batches):
"""mock training data.""" """mock training data."""
data = np.random.randn(batches*batch_size, 1, 32, 32).astype( data = np.random.randn(batches*batch_size, 1, 32, 32).astype(
...@@ -51,7 +54,7 @@ def test_get_membership_inference_object(): ...@@ -51,7 +54,7 @@ def test_get_membership_inference_object():
loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True) loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
opt = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9) opt = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
model = Model(network=net, loss_fn=loss, optimizer=opt) model = Model(network=net, loss_fn=loss, optimizer=opt)
inference_model = MembershipInference(model) inference_model = MembershipInference(model, -1)
assert isinstance(inference_model, MembershipInference) assert isinstance(inference_model, MembershipInference)
...@@ -65,7 +68,7 @@ def test_membership_inference_object_train(): ...@@ -65,7 +68,7 @@ def test_membership_inference_object_train():
loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True) loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
opt = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9) opt = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
model = Model(network=net, loss_fn=loss, optimizer=opt) model = Model(network=net, loss_fn=loss, optimizer=opt)
inference_model = MembershipInference(model) inference_model = MembershipInference(model, -1)
assert isinstance(inference_model, MembershipInference) assert isinstance(inference_model, MembershipInference)
config = [{ config = [{
...@@ -95,7 +98,7 @@ def test_membership_inference_eval(): ...@@ -95,7 +98,7 @@ def test_membership_inference_eval():
loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True) loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
opt = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9) opt = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
model = Model(network=net, loss_fn=loss, optimizer=opt) model = Model(network=net, loss_fn=loss, optimizer=opt)
inference_model = MembershipInference(model) inference_model = MembershipInference(model, -1)
assert isinstance(inference_model, MembershipInference) assert isinstance(inference_model, MembershipInference)
batch_size = 16 batch_size = 16
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册