未验证 提交 b4555028 编写于 作者: Y Yu Yang 提交者: GitHub

Merge pull request #7723 from reyoung/feature/wrap_nce_loss

Wrap NCE to python
......@@ -124,7 +124,8 @@ class NCEOpMaker : public framework::OpProtoAndCheckerMaker {
"This attribute only be used in unitest. Classes "
"in this list wiil be used as negative classes "
"for every samples. Under normal conditions, "
"user should avoid setting this attribute.");
"user should avoid setting this attribute.")
.SetDefault({});
AddComment(R"DOC(
Compute and return the noise-contrastive estimation training loss.
See [Noise-contrastive estimation: A new estimation principle for unnormalized statistical models](http://www.jmlr.org/proceedings/papers/v9/gutmann10a/gutmann10a.pdf).
......
......@@ -197,7 +197,8 @@ class NCEGradKernel : public framework::OpKernel<T> {
// get d_x
auto d_x = context.Output<Tensor>(framework::GradVarName("Input"));
if (d_x != nullptr) {
d_x->mutable_data<T>(context.GetPlace());
auto* d_x_data = d_x->mutable_data<T>(context.GetPlace());
std::fill(d_x_data, d_x_data + d_x->numel(), 0.0);
auto d_x_matrix = EigenMatrix<T>::From(*d_x);
auto w_matrix = EigenMatrix<T>::From(*(context.Input<Tensor>("Weight")));
for (int64_t i = 0; i < sample_labels->numel(); ++i) {
......
......@@ -19,6 +19,7 @@ from ..layer_helper import LayerHelper
from ..initializer import Normal, Constant
from ..framework import Variable
from ..param_attr import ParamAttr
from layer_function_generator import autodoc
from tensor import concat
__all__ = [
......@@ -57,6 +58,7 @@ __all__ = [
'warpctc',
'sequence_reshape',
'transpose',
'nce',
]
......@@ -2190,6 +2192,61 @@ def sequence_reshape(input, new_dim):
return out
@autodoc()
def nce(input,
label,
num_total_classes,
sample_weight=None,
param_attr=None,
bias_attr=None,
num_neg_samples=None):
helper = LayerHelper('nce', **locals())
assert isinstance(input, Variable)
dim = input.shape[1]
assert isinstance(label, Variable)
num_true_class = label.shape[1]
w = helper.create_parameter(
attr=helper.param_attr,
shape=[num_total_classes, dim],
is_bias=False,
dtype=input.dtype)
b = helper.create_parameter(
attr=helper.bias_attr,
shape=[num_total_classes, 1],
is_bias=True,
dtype=input.dtype)
cost = helper.create_tmp_variable(dtype=input.dtype)
sample_logits = helper.create_tmp_variable(dtype=input.dtype)
sample_labels = helper.create_tmp_variable(dtype=label.dtype)
if num_neg_samples is None:
num_neg_samples = 10
else:
num_neg_samples = int(num_neg_samples)
attrs = {
'num_total_classes': int(num_total_classes),
'num_neg_samples': num_neg_samples
}
helper.append_op(
type='nce',
inputs={
'Input': input,
'Label': label,
'Weight': w,
'Bias': b,
'SampleWeight': sample_weight if sample_weight is not None else []
},
outputs={
'Cost': cost,
'SampleLogits': sample_logits,
'SampleLabels': sample_labels
},
attrs=attrs)
return cost / (num_neg_samples + 1)
def transpose(x, perm, name=None):
"""
**transpose Layer**
......
......@@ -17,8 +17,9 @@ import unittest
import paddle.v2.fluid.layers as layers
import paddle.v2.fluid.nets as nets
from paddle.v2.fluid.framework import Program, program_guard
from paddle.v2.fluid.framework import Program, program_guard, default_main_program
from paddle.v2.fluid.param_attr import ParamAttr
import decorators
class TestBook(unittest.TestCase):
......@@ -225,6 +226,41 @@ class TestBook(unittest.TestCase):
self.assertIsNotNone(out)
print(str(program))
@decorators.prog_scope()
def test_nce(self):
window_size = 5
words = []
for i in xrange(window_size):
words.append(
layers.data(
name='word_{0}'.format(i), shape=[1], dtype='int64'))
dict_size = 10000
label_word = int(window_size / 2) + 1
embs = []
for i in xrange(window_size):
if i == label_word:
continue
emb = layers.embedding(
input=words[i],
size=[dict_size, 32],
param_attr='emb.w',
is_sparse=True)
embs.append(emb)
embs = layers.concat(input=embs, axis=1)
loss = layers.nce(input=embs,
label=words[label_word],
num_total_classes=dict_size,
param_attr='nce.w',
bias_attr='nce.b')
avg_loss = layers.mean(x=loss)
self.assertIsNotNone(avg_loss)
print(str(default_main_program()))
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册