Unverified commit 489a64ef, authored by 努力努力在努力丶, committed by GitHub

fix cross_entropy when run static graph mode of mlu and npu (#40621)

Parent cb8afc24
@@ -1818,12 +1818,16 @@ def cross_entropy(input,
         helper = LayerHelper('softmax_with_cross_entropy', **locals())
         softmax = helper.create_variable_for_type_inference(dtype=input.dtype)
         out = helper.create_variable_for_type_inference(dtype=input.dtype)
+        outputs = {'Softmax': softmax, 'Loss': out}
+        if core.is_compiled_with_npu() or core.is_compiled_with_mlu():
+            backprop = helper.create_variable_for_type_inference(dtype=input.dtype)
+            outputs['Backprop'] = backprop
         helper.append_op(
             type='softmax_with_cross_entropy',
             inputs={'Logits': input,
                     'Label': label},
-            outputs={'Softmax': softmax,
-                     'Loss': out},
+            outputs=outputs,
             attrs=attrs)
         if weight is not None:
......
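For context, the added 'Backprop' output only matters on the static-graph code path, where cross_entropy builds a softmax_with_cross_entropy op through LayerHelper, and only when Paddle is compiled with NPU or MLU support. Below is a minimal sketch of how that static-graph path is exercised; the tensor names, shapes, random data, and the use of CPUPlace are illustrative assumptions and are not part of this commit.

# Assumed usage sketch: exercising the static-graph branch patched above.
# Under paddle.enable_static(), F.cross_entropy appends a
# softmax_with_cross_entropy op; on NPU/MLU builds that op now also
# receives the 'Backprop' output added by this fix.
import numpy as np
import paddle
import paddle.nn.functional as F

paddle.enable_static()

main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    logits = paddle.static.data(name='logits', shape=[4, 10], dtype='float32')
    label = paddle.static.data(name='label', shape=[4], dtype='int64')
    loss = F.cross_entropy(logits, label)  # static-graph branch shown in the diff

# CPUPlace is used only so this sketch runs anywhere; the new branch itself
# is taken when Paddle is compiled for NPU or MLU.
exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(startup_prog)
out, = exe.run(main_prog,
               feed={'logits': np.random.rand(4, 10).astype('float32'),
                     'label': np.random.randint(0, 10, [4]).astype('int64')},
               fetch_list=[loss])
print(out)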