Unverified commit f857e079, authored by xuezhong, committed by GitHub

Merge pull request #15893 from xuezhong/add_sample_logits_op

fix bug for sampled softmax
@@ -5921,6 +5921,8 @@ def sampled_softmax_with_cross_entropy(logits,
     sampled_logits \
         = helper.create_variable_for_type_inference(dtype=logits.dtype)
     sampled_label = helper.create_variable_for_type_inference(dtype='int64')
+    sampled_softlabel = helper.create_variable_for_type_inference(
+        dtype=logits.dtype)
     helper.append_op(
         type='sample_logits',
@@ -5945,14 +5947,20 @@ def sampled_softmax_with_cross_entropy(logits,
         })
     loss = helper.create_variable_for_type_inference(dtype=logits.dtype)
     softmax = helper.create_variable_for_type_inference(dtype=logits.dtype)
+    helper.append_op(
+        type='one_hot',
+        inputs={'X': sampled_label},
+        attrs={'depth': num_samples + 1},
+        outputs={'Out': sampled_softlabel})
     helper.append_op(
         type='softmax_with_cross_entropy',
         inputs={'Logits': sampled_logits,
-                'Label': sampled_label},
+                'Label': sampled_softlabel},
         outputs={'Softmax': softmax,
                  'Loss': loss},
         attrs={
-            'soft_label': False,
+            'soft_label': True,
             'ignore_index': False,
             'numeric_stable_mode': False
         })
...
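As the diff shows, the fix expands the sampled integer labels into one-hot soft labels over the num_samples + 1 sampled classes (via the added one_hot op) and switches softmax_with_cross_entropy to soft_label=True. The following NumPy sketch is not part of the PR; it is a minimal illustration of that computation, with the function name sampled_soft_label_cross_entropy and the toy shapes chosen purely for demonstration.

# Minimal sketch (assumption: plain NumPy stand-in, not the PaddlePaddle ops)
import numpy as np

def sampled_soft_label_cross_entropy(sampled_logits, sampled_label, num_samples):
    # sampled_logits: [batch, num_samples + 1]; sampled_label: [batch] int64
    depth = num_samples + 1
    # one-hot with depth = num_samples + 1, mirroring the added one_hot op
    soft_label = np.eye(depth)[sampled_label]                 # [batch, depth]
    # numerically stable softmax over the sampled classes
    shifted = sampled_logits - sampled_logits.max(axis=1, keepdims=True)
    softmax = np.exp(shifted) / np.exp(shifted).sum(axis=1, keepdims=True)
    # soft-label cross entropy: -sum(label * log(softmax)) per row
    loss = -(soft_label * np.log(softmax)).sum(axis=1)
    return softmax, loss

# toy usage: batch = 4, num_samples = 5 (hypothetical values)
logits = np.random.randn(4, 6)
label = np.random.randint(0, 6, size=4)
softmax_out, loss_out = sampled_soft_label_cross_entropy(logits, label, num_samples=5)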