提交 fd3eb11b 编写于 作者: L liuluobin

fixed softmax_cross_entropy return NaN value

上级 3868fd41
......@@ -82,7 +82,12 @@ def _softmax_cross_entropy(logits, labels):
"""
labels = np.eye(logits.shape[1])[labels].astype(np.int32)
logits = np.exp(logits) / np.sum(np.exp(logits), axis=1, keepdims=True)
return -1*np.sum(labels*np.log(logits), axis=1)
loss = -1*np.sum(labels*np.log(logits), axis=1)
nan_index = np.isnan(loss)
if np.any(nan_index):
loss[nan_index] = 0
return loss
class MembershipInference:
......@@ -243,6 +248,7 @@ class MembershipInference:
np.random.shuffle(shuffle_index)
features = features[shuffle_index]
labels = labels[shuffle_index]
return features, labels
def _generate(self, dataset_x, label):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册