AddAttr<int>("num_classes","Total number of classes.");
AddAttr<int>("num_sampled_classes","The number of negative classes.")
.SetDefault(10);
AddAttr<std::vector<int>>("sampled_labels","");
AddComment(R"DOC(
Expand input(X) according to LOD of input(Y).
Computes and returns the noise-contrastive estimation training loss.
See [Noise-contrastive estimation: A new estimation principle for unnormalized statistical models](http://www.jmlr.org/proceedings/papers/v9/gutmann10a/gutmann10a.pdf).
By default this uses a uniform distribution for sampling.
The number of target classes per example should be same. If you have a variable number of target classes, you can pad them out to a constant number by either repeating them or by padding with an otherwise unused class.
)DOC");
}
};
...
...
@@ -82,32 +112,41 @@ class NCEOpGrad : public framework::OperatorWithKernel {