diff --git a/mindarmour/diff_privacy/evaluation/membership_inference.py b/mindarmour/diff_privacy/evaluation/membership_inference.py index c27882d9858776516402bed6d2feed6b0cc52eb0..a91c5fb89faafa19c85f4ac75070992bde299d8b 100755 --- a/mindarmour/diff_privacy/evaluation/membership_inference.py +++ b/mindarmour/diff_privacy/evaluation/membership_inference.py @@ -82,7 +82,12 @@ def _softmax_cross_entropy(logits, labels): """ labels = np.eye(logits.shape[1])[labels].astype(np.int32) logits = np.exp(logits) / np.sum(np.exp(logits), axis=1, keepdims=True) - return -1*np.sum(labels*np.log(logits), axis=1) + loss = -1*np.sum(labels*np.log(logits), axis=1) + + nan_index = np.isnan(loss) + if np.any(nan_index): + loss[nan_index] = 0 + return loss class MembershipInference: @@ -243,6 +248,7 @@ class MembershipInference: np.random.shuffle(shuffle_index) features = features[shuffle_index] labels = labels[shuffle_index] + return features, labels def _generate(self, dataset_x, label):