From b78ab87bd31929770ccddb57160781f7e05e73ec Mon Sep 17 00:00:00 2001 From: xuezhong Date: Wed, 30 Jan 2019 16:37:14 +0000 Subject: [PATCH] refine code --- python/paddle/fluid/layers/nn.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 0a6c18669..e1387cec1 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -5768,7 +5768,7 @@ def softmax_with_cross_entropy(logits, def sampled_softmax_with_cross_entropy(logits, label, num_samples, - num_true=num_true, + num_true=1, remove_accidental_hits=True, use_custom_samples=False, custom_samples=None, @@ -5865,15 +5865,19 @@ def sampled_softmax_with_cross_entropy(logits, 'num_samples': num_samples, 'seed': seed }) + loss = helper.create_variable_for_type_inference(dtype=logits.dtype) + softmax = helper.create_variable_for_type_inference(dtype=logits.dtype) helper.append_op( type='softmax_with_cross_entropy', - inputs={ - 'Logits': sampled_logits, - 'Label': sampled_label, + inputs={'Logits': sampled_logits, + 'Label': sampled_label}, + outputs={'Softmax': softmax, + 'Loss': loss}, + attrs={ 'soft_label': False, - }, - outputs={'loss': samples, }) - + 'ignore_index': False, + 'numeric_stable_mode': False + }) return outputs / num_true -- GitLab