Commit b78ab87b authored by xuezhong

refine code

Parent 3c8aa787
@@ -5768,7 +5768,7 @@ def softmax_with_cross_entropy(logits,
 def sampled_softmax_with_cross_entropy(logits,
                                        label,
                                        num_samples,
-                                       num_true=num_true,
+                                       num_true=1,
                                        remove_accidental_hits=True,
                                        use_custom_samples=False,
                                        custom_samples=None,
@@ -5865,15 +5865,19 @@ def sampled_softmax_with_cross_entropy(logits,
             'num_samples': num_samples,
             'seed': seed
         })
+    loss = helper.create_variable_for_type_inference(dtype=logits.dtype)
+    softmax = helper.create_variable_for_type_inference(dtype=logits.dtype)
     helper.append_op(
         type='softmax_with_cross_entropy',
-        inputs={
-            'Logits': sampled_logits,
-            'Label': sampled_label,
-        },
+        inputs={'Logits': sampled_logits,
+                'Label': sampled_label},
+        outputs={'Softmax': softmax,
+                 'Loss': loss},
         attrs={
             'soft_label': False,
-        },
-        outputs={'loss': samples, })
+            'ignore_index': False,
+            'numeric_stable_mode': False
+        })
     return outputs / num_true
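For context, a minimal usage sketch of the layer this commit touches. This assumes the paddle.fluid 1.x API of the surrounding PR, where the function is exposed as fluid.layers.sampled_softmax_with_cross_entropy; the input names, shapes, and num_samples value below are illustrative, not taken from the commit.

import paddle.fluid as fluid

# Illustrative inputs: 1000-class logits and a single integer label per example.
logits = fluid.layers.data(name='logits', shape=[1000], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')

# After this commit, num_true defaults to 1 (the previous default,
# num_true=num_true, was self-referential and could not evaluate),
# so single-label callers can omit it. num_samples=25 is an arbitrary choice.
loss = fluid.layers.sampled_softmax_with_cross_entropy(
    logits=logits, label=label, num_samples=25)

The second hunk also wires both declared outputs of the softmax_with_cross_entropy op, Softmax and Loss, into variables created via create_variable_for_type_inference, instead of the previous lowercase 'loss' key pointing at the wrong variable.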