未验证 提交 f857e079 编写于 作者: X xuezhong 提交者: GitHub

Merge pull request #15893 from xuezhong/add_sample_logits_op

fix bug for sampled softmax
......@@ -5921,6 +5921,8 @@ def sampled_softmax_with_cross_entropy(logits,
sampled_logits \
= helper.create_variable_for_type_inference(dtype=logits.dtype)
sampled_label = helper.create_variable_for_type_inference(dtype='int64')
sampled_softlabel = helper.create_variable_for_type_inference(
dtype=logits.dtype)
helper.append_op(
type='sample_logits',
......@@ -5945,14 +5947,20 @@ def sampled_softmax_with_cross_entropy(logits,
})
loss = helper.create_variable_for_type_inference(dtype=logits.dtype)
softmax = helper.create_variable_for_type_inference(dtype=logits.dtype)
helper.append_op(
type='one_hot',
inputs={'X': sampled_label},
attrs={'depth': num_samples + 1},
outputs={'Out': sampled_softlabel})
helper.append_op(
type='softmax_with_cross_entropy',
inputs={'Logits': sampled_logits,
'Label': sampled_label},
'Label': sampled_softlabel},
outputs={'Softmax': softmax,
'Loss': loss},
attrs={
'soft_label': False,
'soft_label': True,
'ignore_index': False,
'numeric_stable_mode': False
})
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册