提交 b78ab87b 编写于 作者: X xuezhong

refine code

上级 3c8aa787
......@@ -5768,7 +5768,7 @@ def softmax_with_cross_entropy(logits,
def sampled_softmax_with_cross_entropy(logits,
label,
num_samples,
num_true=num_true,
num_true=1,
remove_accidental_hits=True,
use_custom_samples=False,
custom_samples=None,
......@@ -5865,15 +5865,19 @@ def sampled_softmax_with_cross_entropy(logits,
'num_samples': num_samples,
'seed': seed
})
loss = helper.create_variable_for_type_inference(dtype=logits.dtype)
softmax = helper.create_variable_for_type_inference(dtype=logits.dtype)
helper.append_op(
type='softmax_with_cross_entropy',
inputs={
'Logits': sampled_logits,
'Label': sampled_label,
inputs={'Logits': sampled_logits,
'Label': sampled_label},
outputs={'Softmax': softmax,
'Loss': loss},
attrs={
'soft_label': False,
},
outputs={'loss': samples, })
'ignore_index': False,
'numeric_stable_mode': False
})
return outputs / num_true
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册