提交 d99219c2 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!105 fixed softmax_cross_entropy return NaN value

Merge pull request !105 from liuluobin/master
...@@ -82,7 +82,12 @@ def _softmax_cross_entropy(logits, labels): ...@@ -82,7 +82,12 @@ def _softmax_cross_entropy(logits, labels):
""" """
labels = np.eye(logits.shape[1])[labels].astype(np.int32) labels = np.eye(logits.shape[1])[labels].astype(np.int32)
logits = np.exp(logits) / np.sum(np.exp(logits), axis=1, keepdims=True) logits = np.exp(logits) / np.sum(np.exp(logits), axis=1, keepdims=True)
return -1*np.sum(labels*np.log(logits), axis=1) loss = -1*np.sum(labels*np.log(logits), axis=1)
nan_index = np.isnan(loss)
if np.any(nan_index):
loss[nan_index] = 0
return loss
class MembershipInference: class MembershipInference:
...@@ -243,6 +248,7 @@ class MembershipInference: ...@@ -243,6 +248,7 @@ class MembershipInference:
np.random.shuffle(shuffle_index) np.random.shuffle(shuffle_index)
features = features[shuffle_index] features = features[shuffle_index]
labels = labels[shuffle_index] labels = labels[shuffle_index]
return features, labels return features, labels
def _generate(self, dataset_x, label): def _generate(self, dataset_x, label):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册