未验证 提交 fef94eac 编写于 作者: C chajchaj 提交者: GitHub

fix cross_entropy bug of the axis parameter in log_softmax (#27311)

上级 d28162b9
...@@ -26,7 +26,7 @@ def stable_softmax(x): ...@@ -26,7 +26,7 @@ def stable_softmax(x):
return exps / np.sum(exps) return exps / np.sum(exps)
def log_softmax(x, axis=-1): def log_softmax(x, axis=1):
softmax_out = np.apply_along_axis(stable_softmax, axis, x) softmax_out = np.apply_along_axis(stable_softmax, axis, x)
return np.log(softmax_out) return np.log(softmax_out)
......
...@@ -1093,7 +1093,7 @@ def cross_entropy(input, ...@@ -1093,7 +1093,7 @@ def cross_entropy(input,
" 'none', but received %s, which is not allowed." % reduction) " 'none', but received %s, which is not allowed." % reduction)
#step 1. log_softmax #step 1. log_softmax
log_softmax_out = paddle.nn.functional.log_softmax(input) log_softmax_out = paddle.nn.functional.log_softmax(input, axis=1)
if weight is not None and not isinstance(weight, Variable): if weight is not None and not isinstance(weight, Variable):
raise ValueError( raise ValueError(
"The weight' is not a Variable, please convert to Variable.") "The weight' is not a Variable, please convert to Variable.")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册