未验证 提交 3a255881 编写于 作者: C chajchaj 提交者: GitHub

fix use_softmax=False does not work, test=develop (#32035)

上级 1f8834ad
......@@ -1388,6 +1388,8 @@ def cross_entropy(input,
"should be '-100', but received %s, which is not allowed." %
ignore_index)
softmax_switch = use_softmax
input_dims = len(list(input.shape))
label_dims = len(list(label.shape))
if input_dims - 1 != label_dims and input_dims != label_dims:
......@@ -1400,7 +1402,7 @@ def cross_entropy(input,
_, out = core.ops.softmax_with_cross_entropy(
input, label, 'soft_label', soft_label, 'ignore_index',
ignore_index, 'numeric_stable_mode', True, 'axis', axis,
'use_softmax', use_softmax)
'softmax_switch', softmax_switch)
if weight is not None:
......@@ -1482,7 +1484,7 @@ def cross_entropy(input,
'ignore_index': ignore_index,
'numeric_stable_mode': True,
'axis': axis,
'use_softmax': use_softmax
'softmax_switch': softmax_switch
}
helper = LayerHelper('softmax_with_cross_entropy', **locals())
softmax = helper.create_variable_for_type_inference(dtype=input.dtype)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册