提交 29262ab2 编写于 作者: W wanghaoshuang

Fix unitest.

上级 76a65a83
......@@ -41,11 +41,11 @@ class NCEOp : public framework::OperatorWithKernel {
}
auto num_neg_samples = ctx->Attrs().Get<int>("num_neg_samples");
auto num_total_classes = ctx->Attrs().Get<int>("num_total_classes");
std::vector<int> sampled_labels =
ctx->Attrs().Get<std::vector<int>>("sampled_labels");
std::vector<int> custom_neg_classes =
ctx->Attrs().Get<std::vector<int>>("custom_neg_classes");
PADDLE_ENFORCE_EQ(num_total_classes, ctx->GetInputDim("Weight")[0]);
if (sampled_labels.size() > 0) {
PADDLE_ENFORCE_EQ(sampled_labels.size(),
if (custom_neg_classes.size() > 0) {
PADDLE_ENFORCE_EQ(custom_neg_classes.size(),
static_cast<size_t>(num_neg_samples));
}
// set dims of output(Out)
......
......@@ -33,14 +33,14 @@ void PrepareSamples(const framework::ExecutionContext& context) {
auto label = context.Input<Tensor>("Label");
const int64_t* label_data = label->data<int64_t>();
auto label_dims = label->dims();
int num_classes = context.Attr<int>("num_classes");
int num_total_classes = context.Attr<int>("num_total_classes");
// for unitest
std::vector<int> custom_neg_classes =
context.Attr<std::vector<int>>("custom_neg_classes");
// random machine
std::random_device rd;
std::mt19937 rng(rd());
std::uniform_int_distribution<int> rand(0, num_classes - 1);
std::uniform_int_distribution<int> rand(0, num_total_classes - 1);
auto sample_labels = context.Output<Tensor>("SampleLabels");
auto sample_labels_dims = sample_labels->dims();
......@@ -84,13 +84,13 @@ class NCEKernel : public framework::OpKernel<T> {
}
auto out = context.Output<Tensor>("Cost");
T* out_data = out->mutable_data<T>(context.GetPlace());
int num_smalped_classes = context.Attr<int>("num_sampled_classes");
int num_classes = context.Attr<int>("num_classes");
int num_neg_samples = context.Attr<int>("num_neg_samples");
int num_total_classes = context.Attr<int>("num_total_classes");
int num_true_class = 1;
if (label != nullptr) {
num_true_class = label->dims()[1];
}
T b = 1. / num_classes * num_smalped_classes;
T b = 1. / num_total_classes * num_neg_samples;
// forward bias
auto bias = context.Input<Tensor>("Bias");
if (bias != nullptr) {
......@@ -151,13 +151,13 @@ class NCEGradKernel : public framework::OpKernel<T> {
if (sample_weight != nullptr) {
sample_weight_data = sample_weight->data<T>();
}
int num_smalped_classes = context.Attr<int>("num_sampled_classes");
int num_classes = context.Attr<int>("num_classes");
int num_neg_samples = context.Attr<int>("num_neg_samples");
int num_total_classes = context.Attr<int>("num_total_classes");
int num_true_class = 1;
if (label != nullptr) {
num_true_class = label->dims()[1];
}
T b = 1. / num_classes * num_smalped_classes;
T b = 1. / num_total_classes * num_neg_samples;
Tensor sample_grad; // tmp tensor
T* sample_grad_data =
sample_grad.mutable_data<T>(sample_labels->dims(), context.GetPlace());
......
......@@ -35,7 +35,7 @@ def nce(input, weight, bias, sample_weight, labels, num_classes,
o = sample_out[i]
cost = -np.log(o / (o + b)) if samples[i][2] else -np.log(b / (o + b))
out[samples[i][0]] += cost * samples[i][3]
return (out, np.array(sample_out).reshape(
return (out[:, np.newaxis], np.array(sample_out).reshape(
batch_size, num_sample_class + num_true_class),
np.array(sample_labels).reshape(batch_size,
num_sample_class + num_true_class))
......@@ -43,16 +43,16 @@ def nce(input, weight, bias, sample_weight, labels, num_classes,
class TestNCE(OpTest):
def generate_data(self, dim, batch_size, num_classes, num_true_class,
num_sampled_classes):
num_neg_samples):
input = np.random.randn(batch_size, dim).astype(np.float32)
weight = np.random.randn(num_classes, dim).astype(np.float32)
bias = np.random.randn(num_classes).astype(np.float32)
sample_weight = np.random.randn(batch_size).astype(np.float32)
labels = np.random.randint(0, num_classes, (batch_size, num_true_class))
self.attrs = {
'num_classes': num_classes,
'num_sampled_classes': num_sampled_classes,
'sampled_labels': range(num_sampled_classes)
'num_total_classes': num_classes,
'num_neg_samples': num_neg_samples,
'custom_neg_classes': range(num_neg_samples)
}
self.inputs = {
'Input': input,
......@@ -68,8 +68,8 @@ class TestNCE(OpTest):
def compute(self):
out = nce(self.inputs['Input'], self.inputs['Weight'],
self.inputs['Bias'], self.inputs['SampleWeight'],
self.inputs['Label'], self.attrs['num_classes'],
self.attrs['num_sampled_classes'])
self.inputs['Label'], self.attrs['num_total_classes'],
self.attrs['num_neg_samples'])
self.outputs = {
'Cost': out[0],
'SampleLogits': out[1],
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册