diff --git a/paddle/operators/nce_op.cc b/paddle/operators/nce_op.cc index bb9346b134c88ddffb54da5c9fa42ed4cfabd51f..952da10434df01a10fc713a017084d315a2a59d3 100644 --- a/paddle/operators/nce_op.cc +++ b/paddle/operators/nce_op.cc @@ -41,11 +41,11 @@ class NCEOp : public framework::OperatorWithKernel { } auto num_neg_samples = ctx->Attrs().Get("num_neg_samples"); auto num_total_classes = ctx->Attrs().Get("num_total_classes"); - std::vector sampled_labels = - ctx->Attrs().Get>("sampled_labels"); + std::vector custom_neg_classes = + ctx->Attrs().Get>("custom_neg_classes"); PADDLE_ENFORCE_EQ(num_total_classes, ctx->GetInputDim("Weight")[0]); - if (sampled_labels.size() > 0) { - PADDLE_ENFORCE_EQ(sampled_labels.size(), + if (custom_neg_classes.size() > 0) { + PADDLE_ENFORCE_EQ(custom_neg_classes.size(), static_cast(num_neg_samples)); } // set dims of output(Out) diff --git a/paddle/operators/nce_op.h b/paddle/operators/nce_op.h index 8df20f432dadebf9815edd9f9c35bc60983ed07c..ea92a797fe18e218be602e019f3fda6bc0b05f33 100644 --- a/paddle/operators/nce_op.h +++ b/paddle/operators/nce_op.h @@ -33,14 +33,14 @@ void PrepareSamples(const framework::ExecutionContext& context) { auto label = context.Input("Label"); const int64_t* label_data = label->data(); auto label_dims = label->dims(); - int num_classes = context.Attr("num_classes"); + int num_total_classes = context.Attr("num_total_classes"); // for unitest std::vector custom_neg_classes = context.Attr>("custom_neg_classes"); // random machine std::random_device rd; std::mt19937 rng(rd()); - std::uniform_int_distribution rand(0, num_classes - 1); + std::uniform_int_distribution rand(0, num_total_classes - 1); auto sample_labels = context.Output("SampleLabels"); auto sample_labels_dims = sample_labels->dims(); @@ -84,13 +84,13 @@ class NCEKernel : public framework::OpKernel { } auto out = context.Output("Cost"); T* out_data = out->mutable_data(context.GetPlace()); - int num_smalped_classes = context.Attr("num_sampled_classes"); - int num_classes = context.Attr("num_classes"); + int num_neg_samples = context.Attr("num_neg_samples"); + int num_total_classes = context.Attr("num_total_classes"); int num_true_class = 1; if (label != nullptr) { num_true_class = label->dims()[1]; } - T b = 1. / num_classes * num_smalped_classes; + T b = 1. / num_total_classes * num_neg_samples; // forward bias auto bias = context.Input("Bias"); if (bias != nullptr) { @@ -151,13 +151,13 @@ class NCEGradKernel : public framework::OpKernel { if (sample_weight != nullptr) { sample_weight_data = sample_weight->data(); } - int num_smalped_classes = context.Attr("num_sampled_classes"); - int num_classes = context.Attr("num_classes"); + int num_neg_samples = context.Attr("num_neg_samples"); + int num_total_classes = context.Attr("num_total_classes"); int num_true_class = 1; if (label != nullptr) { num_true_class = label->dims()[1]; } - T b = 1. / num_classes * num_smalped_classes; + T b = 1. / num_total_classes * num_neg_samples; Tensor sample_grad; // tmp tensor T* sample_grad_data = sample_grad.mutable_data(sample_labels->dims(), context.GetPlace()); diff --git a/python/paddle/v2/fluid/tests/test_nce.py b/python/paddle/v2/fluid/tests/test_nce.py index 6cbf468e0a983c4b1440dd9e268551c1d1241855..8aeba69769525935c26576ec50035ed50d2ce44f 100644 --- a/python/paddle/v2/fluid/tests/test_nce.py +++ b/python/paddle/v2/fluid/tests/test_nce.py @@ -35,7 +35,7 @@ def nce(input, weight, bias, sample_weight, labels, num_classes, o = sample_out[i] cost = -np.log(o / (o + b)) if samples[i][2] else -np.log(b / (o + b)) out[samples[i][0]] += cost * samples[i][3] - return (out, np.array(sample_out).reshape( + return (out[:, np.newaxis], np.array(sample_out).reshape( batch_size, num_sample_class + num_true_class), np.array(sample_labels).reshape(batch_size, num_sample_class + num_true_class)) @@ -43,16 +43,16 @@ def nce(input, weight, bias, sample_weight, labels, num_classes, class TestNCE(OpTest): def generate_data(self, dim, batch_size, num_classes, num_true_class, - num_sampled_classes): + num_neg_samples): input = np.random.randn(batch_size, dim).astype(np.float32) weight = np.random.randn(num_classes, dim).astype(np.float32) bias = np.random.randn(num_classes).astype(np.float32) sample_weight = np.random.randn(batch_size).astype(np.float32) labels = np.random.randint(0, num_classes, (batch_size, num_true_class)) self.attrs = { - 'num_classes': num_classes, - 'num_sampled_classes': num_sampled_classes, - 'sampled_labels': range(num_sampled_classes) + 'num_total_classes': num_classes, + 'num_neg_samples': num_neg_samples, + 'custom_neg_classes': range(num_neg_samples) } self.inputs = { 'Input': input, @@ -68,8 +68,8 @@ class TestNCE(OpTest): def compute(self): out = nce(self.inputs['Input'], self.inputs['Weight'], self.inputs['Bias'], self.inputs['SampleWeight'], - self.inputs['Label'], self.attrs['num_classes'], - self.attrs['num_sampled_classes']) + self.inputs['Label'], self.attrs['num_total_classes'], + self.attrs['num_neg_samples']) self.outputs = { 'Cost': out[0], 'SampleLogits': out[1],