提交 af8cb820 编写于 作者: Y Yang Yu

Fix bug of nce_op

* also div num_samples when return cost of nce_op
上级 f035f327
...@@ -197,7 +197,8 @@ class NCEGradKernel : public framework::OpKernel<T> { ...@@ -197,7 +197,8 @@ class NCEGradKernel : public framework::OpKernel<T> {
// get d_x // get d_x
auto d_x = context.Output<Tensor>(framework::GradVarName("Input")); auto d_x = context.Output<Tensor>(framework::GradVarName("Input"));
if (d_x != nullptr) { if (d_x != nullptr) {
d_x->mutable_data<T>(context.GetPlace()); auto* d_x_data = d_x->mutable_data<T>(context.GetPlace());
std::fill(d_x_data, d_x_data + d_x->numel(), 0.0);
auto d_x_matrix = EigenMatrix<T>::From(*d_x); auto d_x_matrix = EigenMatrix<T>::From(*d_x);
auto w_matrix = EigenMatrix<T>::From(*(context.Input<Tensor>("Weight"))); auto w_matrix = EigenMatrix<T>::From(*(context.Input<Tensor>("Weight")));
for (int64_t i = 0; i < sample_labels->numel(); ++i) { for (int64_t i = 0; i < sample_labels->numel(); ++i) {
......
...@@ -2001,9 +2001,15 @@ def nce(input, ...@@ -2001,9 +2001,15 @@ def nce(input,
sample_logits = helper.create_tmp_variable(dtype=input.dtype) sample_logits = helper.create_tmp_variable(dtype=input.dtype)
sample_labels = helper.create_tmp_variable(dtype=label.dtype) sample_labels = helper.create_tmp_variable(dtype=label.dtype)
attrs = {'num_total_classes': int(num_total_classes)} if num_neg_samples is None:
if num_neg_samples is not None: num_neg_samples = 10
attrs['num_neg_samples'] = int(num_neg_samples) else:
num_neg_samples = int(num_neg_samples)
attrs = {
'num_total_classes': int(num_total_classes),
'num_neg_samples': num_neg_samples
}
helper.append_op( helper.append_op(
type='nce', type='nce',
...@@ -2020,4 +2026,4 @@ def nce(input, ...@@ -2020,4 +2026,4 @@ def nce(input,
'SampleLabels': sample_labels 'SampleLabels': sample_labels
}, },
attrs=attrs) attrs=attrs)
return cost return cost / (num_neg_samples + 1)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册