提交 af8cb820 编写于 作者: Y Yang Yu

Fix bug of nce_op

* also div num_samples when return cost of nce_op
上级 f035f327
......@@ -197,7 +197,8 @@ class NCEGradKernel : public framework::OpKernel<T> {
// get d_x
auto d_x = context.Output<Tensor>(framework::GradVarName("Input"));
if (d_x != nullptr) {
d_x->mutable_data<T>(context.GetPlace());
auto* d_x_data = d_x->mutable_data<T>(context.GetPlace());
std::fill(d_x_data, d_x_data + d_x->numel(), 0.0);
auto d_x_matrix = EigenMatrix<T>::From(*d_x);
auto w_matrix = EigenMatrix<T>::From(*(context.Input<Tensor>("Weight")));
for (int64_t i = 0; i < sample_labels->numel(); ++i) {
......
......@@ -2001,9 +2001,15 @@ def nce(input,
sample_logits = helper.create_tmp_variable(dtype=input.dtype)
sample_labels = helper.create_tmp_variable(dtype=label.dtype)
attrs = {'num_total_classes': int(num_total_classes)}
if num_neg_samples is not None:
attrs['num_neg_samples'] = int(num_neg_samples)
if num_neg_samples is None:
num_neg_samples = 10
else:
num_neg_samples = int(num_neg_samples)
attrs = {
'num_total_classes': int(num_total_classes),
'num_neg_samples': num_neg_samples
}
helper.append_op(
type='nce',
......@@ -2020,4 +2026,4 @@ def nce(input,
'SampleLabels': sample_labels
},
attrs=attrs)
return cost
return cost / (num_neg_samples + 1)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册