From 09d32b068cbdf65f93e98f7b357dbc7e90f11734 Mon Sep 17 00:00:00 2001
From: wanghaoshuang
Date: Thu, 16 Nov 2017 00:01:55 +0800
Subject: [PATCH] Add unit test and comments.

---
 paddle/operators/nce_op.cc                    | 115 +++++++++++++------
 paddle/operators/nce_op.h                     |  79 +++++++------
 python/paddle/v2/framework/tests/test_nce.py |  96 ++++++++++++++++
 3 files changed, 212 insertions(+), 78 deletions(-)
 create mode 100644 python/paddle/v2/framework/tests/test_nce.py

diff --git a/paddle/operators/nce_op.cc b/paddle/operators/nce_op.cc
index afd61b8851..c365d5d922 100644
--- a/paddle/operators/nce_op.cc
+++ b/paddle/operators/nce_op.cc
@@ -23,57 +23,87 @@ class NCEOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
- protected:
   void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("X"));
+    PADDLE_ENFORCE(ctx->HasInput("Input"));
     PADDLE_ENFORCE(ctx->HasInput("Label"));
-    PADDLE_ENFORCE(ctx->HasInput("W"));
-    PADDLE_ENFORCE(ctx->HasOutput("Out"));
+    PADDLE_ENFORCE(ctx->HasInput("Weight"));
+    PADDLE_ENFORCE(ctx->HasOutput("Cost"));
     PADDLE_ENFORCE(ctx->HasOutput("SampleLogits"));
     PADDLE_ENFORCE(ctx->HasOutput("SampleLabels"));
-    auto x_dims = ctx->GetInputDim("X");
+    auto x_dims = ctx->GetInputDim("Input");
     auto label_dims = ctx->GetInputDim("Label");
     PADDLE_ENFORCE_EQ(x_dims[0], label_dims[0]);
-    if (ctx->HasInput("B")) {
-      PADDLE_ENFORCE_EQ(ctx->GetInputDim("W")[0], ctx->GetInputDim("B")[0]);
+    int num_true_classes = label_dims.size() == 2 ? label_dims[1] : 1;
+    if (ctx->HasInput("Bias")) {
+      PADDLE_ENFORCE_EQ(ctx->GetInputDim("Weight")[0],
+                        ctx->GetInputDim("Bias")[0]);
     }
-    int num_sampled_classes = ctx->Attrs().Get<int>("num_sampled_classes");
-    int num_classes = ctx->Attrs().Get<int>("num_classes");
-    PADDLE_ENFORCE_EQ(num_classes, ctx->GetInputDim("W")[0]);
+    auto num_sampled_classes = ctx->Attrs().Get<int>("num_sampled_classes");
+    auto num_classes = ctx->Attrs().Get<int>("num_classes");
+    std::vector<int> sampled_labels =
+        ctx->Attrs().Get<std::vector<int>>("sampled_labels");
+    PADDLE_ENFORCE_EQ(num_classes, ctx->GetInputDim("Weight")[0]);
     PADDLE_ENFORCE_LT(num_sampled_classes, num_classes);
-
+    if (sampled_labels.size() > 0) {
+      PADDLE_ENFORCE_EQ(sampled_labels.size(),
+                        static_cast<size_t>(num_sampled_classes));
+    }
     // set dims of output(Out)
-    std::vector<int64_t> out_dims(1);
+    std::vector<int64_t> out_dims;
     out_dims.push_back(x_dims[0]);
-    ctx->SetOutputDim("Out", framework::make_ddim(out_dims));
+    ctx->SetOutputDim("Cost", framework::make_ddim(out_dims));
     // set dims of output(SampleOut)
-    std::vector<int64_t> sample_out_dims(2);
+    std::vector<int64_t> sample_out_dims;
     sample_out_dims.push_back(x_dims[0]);
-    sample_out_dims.push_back(num_sampled_classes + 1);
+    sample_out_dims.push_back(num_sampled_classes + num_true_classes);
     ctx->SetOutputDim("SampleLogits", framework::make_ddim(sample_out_dims));
     ctx->SetOutputDim("SampleLabels", framework::make_ddim(sample_out_dims));
   }
+
+ protected:
+  framework::OpKernelType GetKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(
+        framework::ToDataType(ctx.Input<Tensor>("Input")->type()),
+        ctx.device_context());
+  }
 };

 class NCEOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   NCEOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
-    AddInput("X", "");
-    AddInput("Label", "");
-    AddInput("W", "");
-    AddInput("B", "");
-    AddInput("SampleWeight", "");
-    AddOutput("Out", "");
-    AddOutput("SampleLogits", "");
AddOutput("SampleLabels", ""); - AddAttr("num_classes", ""); - AddAttr("num_sampled_classes", "").SetDefault(10); + AddInput("Input", "(Tensor) A tensor of shape [batch_size, dim]."); + AddInput("Label", + "(Tensor) A tensor of shape [batch_size, num_true_class]. " + "'num_true_class' is the number of target class in each sample."); + AddInput("Weight", + "(Tensor) A tensor of shape [num_class, dim]. 'num_class' is the " + "total number of class."); + AddInput("Bias", + "(Tensor) A tensor of shape [num_class]. 'num_class' is the total " + "number of class. It is a dispensable input.") + .AsDispensable(); + AddInput("SampleWeight", + "(Tensor) A tensor of shape [batch_size] storing a weight for " + "each sample. And it is a dispensable input. The default value of " + "sample is 1.") + .AsDispensable(); + AddOutput("Cost", + "(Tensor) A tensor of shape [batch_size]. Cost of samples."); + AddOutput("SampleLogits", "An intermediate tensor.").AsIntermediate(); + AddOutput("SampleLabels", "An intermediate tensor.").AsIntermediate(); + AddAttr("num_classes", "Total number of classes."); + AddAttr("num_sampled_classes", "The number of negative classes.") + .SetDefault(10); + AddAttr>("sampled_labels", ""); AddComment(R"DOC( -Expand input(X) according to LOD of input(Y). - +Computes and returns the noise-contrastive estimation training loss. +See [Noise-contrastive estimation: A new estimation principle for unnormalized statistical models](http://www.jmlr.org/proceedings/papers/v9/gutmann10a/gutmann10a.pdf). +By default this uses a uniform distribution for sampling. +The number of target classes per example should be same. If you have a variable number of target classes, you can pad them out to a constant number by either repeating them or by padding with an otherwise unused class. 
)DOC"); } }; @@ -82,32 +112,41 @@ class NCEOpGrad : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - protected: void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("X")); - PADDLE_ENFORCE(ctx->HasInput("W")); - PADDLE_ENFORCE(ctx->HasInput("Out")); - PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), + PADDLE_ENFORCE(ctx->HasInput("Input")); + PADDLE_ENFORCE(ctx->HasInput("Weight")); + PADDLE_ENFORCE(ctx->HasInput("Cost")); + PADDLE_ENFORCE(ctx->HasInput("SampleLogits")); + PADDLE_ENFORCE(ctx->HasInput("SampleLabels")); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Cost")), "The input(Out@GRAD) should not be null"); - auto x_dims = ctx->GetInputDim("X"); - auto x_grad_name = framework::GradVarName("X"); + auto x_dims = ctx->GetInputDim("Input"); + auto x_grad_name = framework::GradVarName("Input"); if (ctx->HasOutput(x_grad_name)) { ctx->SetOutputDim(x_grad_name, x_dims); } - auto w_dims = ctx->GetInputDim("W"); - auto w_grad_name = framework::GradVarName("W"); + auto w_dims = ctx->GetInputDim("Weight"); + auto w_grad_name = framework::GradVarName("Weight"); if (ctx->HasOutput(w_grad_name)) { ctx->SetOutputDim(w_grad_name, w_dims); } - auto bias_grad_name = framework::GradVarName("B"); + auto bias_grad_name = framework::GradVarName("Bias"); if (ctx->HasOutput(bias_grad_name)) { - auto bias_dims = ctx->GetInputDim("B"); + auto bias_dims = ctx->GetInputDim("Bias"); ctx->SetOutputDim(bias_grad_name, bias_dims); } } + + protected: + framework::OpKernelType GetKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("Input")->type()), + ctx.device_context()); + } }; } // namespace operators diff --git a/paddle/operators/nce_op.h b/paddle/operators/nce_op.h index ce1717c9b0..3017bccdca 100644 --- a/paddle/operators/nce_op.h +++ b/paddle/operators/nce_op.h @@ -14,12 +14,11 @@ #pragma once +#include #include #include "paddle/framework/eigen.h" #include "paddle/framework/op_registry.h" -#include "paddle/memory/memcpy.h" #include "unsupported/Eigen/CXX11/Tensor" - namespace paddle { namespace operators { @@ -32,9 +31,12 @@ using EigenMatrix = framework::EigenMatrix; template void PrepareSamples(const framework::ExecutionContext& context) { auto label = context.Input("Label"); - const T* label_data = label->data(); + const int64_t* label_data = label->data(); auto label_dims = label->dims(); int num_classes = context.Attr("num_classes"); + // for unitest + std::vector sampled_labels = + context.Attr>("sampled_labels"); // random machine std::random_device rd; std::mt19937 rng(rd()); @@ -42,19 +44,24 @@ void PrepareSamples(const framework::ExecutionContext& context) { auto sample_labels = context.Output("SampleLabels"); auto sample_labels_dims = sample_labels->dims(); - int* sample_labels_data = - sample_labels->mutable_data(context.GetPlace()); + int64_t* sample_labels_data = + sample_labels->mutable_data(context.GetPlace()); int num_label = label_dims.size() == 2 ? 
   int num_label = label_dims.size() == 2 ? label_dims[1] : 1;
+  int index = 0;
   for (size_t i = 0; i < label_dims[0]; ++i) {
     int j = 0;
     for (; j < num_label; ++j) {
-      sample_labels_data[sample_labels_dims[1] * i + j] =
-          label_data[i * num_label + j];
+      sample_labels_data[index++] = label_data[i * num_label + j];
     }
-    for (; j < sample_labels_dims[1]; ++j) {
-      int id = rand(rng);
-      sample_labels_data[sample_labels_dims[1] * i + j] = id;
+    if (sampled_labels.size() > 0) {
+      for (auto label : sampled_labels) {
+        sample_labels_data[index++] = label;
+      }
+    } else {
+      for (; j < sample_labels_dims[1]; ++j) {
+        sample_labels_data[index++] = rand(rng);
+      }
     }
   }
 }
@@ -65,7 +72,7 @@ class NCEKernel : public framework::OpKernel<T> {
   void Compute(const framework::ExecutionContext& context) const override {
     PrepareSamples<Place, T>(context);
     auto sample_labels = context.Output<Tensor>("SampleLabels");
-    const int* sample_labels_data = sample_labels->data<int>();
+    const int64_t* sample_labels_data = sample_labels->data<int64_t>();
     auto sample_out = context.Output<Tensor>("SampleLogits");
     T* sample_out_data = sample_out->mutable_data<T>(context.GetPlace());
     auto label = context.Input<Tensor>("Label");
@@ -74,7 +81,7 @@ class NCEKernel : public framework::OpKernel<T> {
     if (sample_weight != nullptr) {
       sample_weight_data = sample_weight->data<T>();
     }
-    auto out = context.Output<Tensor>("Out");
+    auto out = context.Output<Tensor>("Cost");
     T* out_data = out->mutable_data<T>(context.GetPlace());
     int num_smalped_classes = context.Attr<int>("num_sampled_classes");
     int num_classes = context.Attr<int>("num_classes");
@@ -83,9 +90,8 @@ class NCEKernel : public framework::OpKernel<T> {
       num_true_class = label->dims()[1];
     }
     T b = 1. / num_classes * num_smalped_classes;
-    // forward bias
-    auto bias = context.Input<Tensor>("B");
+    auto bias = context.Input<Tensor>("Bias");
     if (bias != nullptr) {
       const T* bias_data = bias->data<T>();
       for (size_t i = 0; i < sample_labels->numel(); ++i) {
@@ -96,27 +102,23 @@ class NCEKernel : public framework::OpKernel<T> {
         sample_out_data[i] = 0;
       }
     }
-    // forward mul
-    auto input_mat = EigenMatrix<T>::From(*(context.Input<Tensor>("X")));
-    auto weight_mat = EigenMatrix<T>::From(*(context.Input<Tensor>("W")));
+    auto input_mat = EigenMatrix<T>::From(*(context.Input<Tensor>("Input")));
+    auto weight_mat = EigenMatrix<T>::From(*(context.Input<Tensor>("Weight")));
     for (size_t i = 0; i < sample_labels->numel(); ++i) {
-      // sample_out_data[i] += (input_mat.chip((int)(i /
-      // sample_labels->dims()[1]), 0) * weight_mat.chip(sample_labels_data[i],
-      // 0)).sum();
       Eigen::Tensor<T, 0, Eigen::RowMajor, Eigen::DenseIndex> result =
           (input_mat.chip((int)(i / sample_labels->dims()[1]), 0) *
            weight_mat.chip(sample_labels_data[i], 0))
              .sum();
       sample_out_data[i] += result(0);
       // activation_->forward
-      sample_out_data[i] = (1 / 1 + (sample_out_data[i]));
+      sample_out_data[i] = (1. / (1. + exp(-sample_out_data[i])));
     }
-    // forward cost
     for (size_t i = 0; i < sample_labels->dims()[0]; ++i) {
       size_t j = 0;
-      T w = sample_weight == nullptr ? 1 : sample_weight_data[i];
+      out_data[i] = 0;
+      T w = sample_weight == nullptr ? 1. : sample_weight_data[i];
       // for true classes
       for (; j < num_true_class; ++j) {
         T o = sample_out_data[i * sample_out->dims()[1] + j];
@@ -137,11 +139,13 @@ template <typename Place, typename T>
 class NCEGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
+    auto d_out = context.Input<Tensor>(framework::GradVarName("Cost"));
+    const T* d_out_data = d_out->data<T>();
     auto label = context.Input<Tensor>("Label");
     auto sample_out = context.Input<Tensor>("SampleLogits");
     const T* sample_out_data = sample_out->data<T>();
     auto sample_labels = context.Input<Tensor>("SampleLabels");
-    const int* sample_labels_data = sample_labels->data<int>();
+    const int64_t* sample_labels_data = sample_labels->data<int64_t>();
     auto sample_weight = context.Input<Tensor>("SampleWeight");
     const T* sample_weight_data = nullptr;
     if (sample_weight != nullptr) {
@@ -154,11 +158,9 @@ class NCEGradKernel : public framework::OpKernel<T> {
       num_true_class = label->dims()[1];
     }
     T b = 1. / num_classes * num_smalped_classes;
-
     Tensor sample_grad;  // tmp tensor
     T* sample_grad_data =
         sample_grad.mutable_data<T>(sample_labels->dims(), context.GetPlace());
-    // backward cost
     for (size_t i = 0; i < sample_labels->numel(); ++i) {
       T o = sample_out_data[i];
@@ -166,15 +168,12 @@ class NCEGradKernel : public framework::OpKernel<T> {
                 ? 1
                 : sample_weight_data[i / sample_labels->dims()[1]];
       sample_grad_data[i] = (i % sample_labels->dims()[1]) < num_true_class
-                                ? -w * b / (o * (o + b))
-                                : w / (o + b);
-      // sigmoid->backward
-      sample_grad_data[i] =
-          (o > 0) ? sample_grad_data[i] : ((o < 0) ? -sample_grad_data[i] : 0);
+                                ? w * (b / (o + b)) * (o - 1)
+                                : w * (o * (1 - o) / (o + b));
+      sample_grad_data[i] *= d_out_data[i / sample_labels->dims()[1]];
     }
-
-    auto d_bias = context.Output<Tensor>(framework::GradVarName("B"));
+    auto d_bias = context.Output<Tensor>(framework::GradVarName("Bias"));
     if (d_bias != nullptr) {
       T* d_bias_data = d_bias->mutable_data<T>(context.GetPlace());
       for (size_t i = 0; i < sample_labels->numel(); ++i) {
@@ -182,22 +181,23 @@ class NCEGradKernel : public framework::OpKernel<T> {
       }
     }
     // get d_w
-    auto d_w = context.Output<Tensor>(framework::GradVarName("W"));
+    auto d_w = context.Output<Tensor>(framework::GradVarName("Weight"));
     if (d_w != nullptr) {
+      d_w->mutable_data<T>(context.GetPlace());
       auto d_w_matrix = EigenMatrix<T>::From(*d_w);
-      auto x_matrix = EigenMatrix<T>::From(*(context.Input<Tensor>("X")));
+      auto x_matrix = EigenMatrix<T>::From(*(context.Input<Tensor>("Input")));
       for (size_t i = 0; i < sample_labels->numel(); ++i) {
-        d_w_matrix.chip(sample_labels_data[i], 0) =
+        d_w_matrix.chip(sample_labels_data[i], 0) +=
             x_matrix.chip((int)(i / sample_labels->dims()[1]), 0) *
             sample_grad_data[i];
       }
     }
-    // get d_x
-    auto d_x = context.Output<Tensor>(framework::GradVarName("X"));
+    auto d_x = context.Output<Tensor>(framework::GradVarName("Input"));
    if (d_x != nullptr) {
+      d_x->mutable_data<T>(context.GetPlace());
       auto d_x_matrix = EigenMatrix<T>::From(*d_x);
-      auto w_matrix = EigenMatrix<T>::From(*(context.Input<Tensor>("W")));
+      auto w_matrix = EigenMatrix<T>::From(*(context.Input<Tensor>("Weight")));
       for (size_t i = 0; i < sample_labels->numel(); ++i) {
         d_x_matrix.chip((int)(i / sample_labels->dims()[1]), 0) +=
             w_matrix.chip(sample_labels_data[i], 0) * sample_grad_data[i];
@@ -205,6 +205,5 @@ class NCEGradKernel : public framework::OpKernel<T> {
       }
     }
   }
 };
-
 }  // namespace operators
 }  // namespace paddle
diff --git a/python/paddle/v2/framework/tests/test_nce.py b/python/paddle/v2/framework/tests/test_nce.py
new file mode 100644
index 0000000000..8b1e7a6bb5
--- /dev/null
+++ b/python/paddle/v2/framework/tests/test_nce.py
@@ -0,0 +1,96 @@
+import unittest
+import numpy as np
+from op_test import OpTest
+
+
+def nce(input, weight, bias, sample_weight, labels, num_classes,
+        num_sample_class):
+    samples = []
+    sample_labels = []
+    batch_size = input.shape[0]
+    num_true_class = labels.shape[1]
+    for i in range(batch_size):
+        w = 1 if sample_weight is None else sample_weight[i]
+        for label in labels[i]:
+            samples.append((i, label, True, w))
+            sample_labels.append(label)
+        for num in range(num_sample_class):
+            samples.append((i, num, False, w))
+            sample_labels.append(num)
+    # forward bias
+    sampleOut = np.zeros(len(samples)).astype(np.float32)
+    if bias is not None:
+        for i in range(len(samples)):
+            sampleOut[i] = bias[samples[i][1]]
+    # forward weight
+    for i in range(len(samples)):
+        sampleOut[i] += np.dot(input[samples[i][0]], weight[samples[i][1]])
+
+    # forward activation
+    sampleOut = 1.0 / (1.0 + np.exp(-sampleOut))
+    # forward cost
+    out = np.zeros(batch_size).astype(np.float32)
+    b = 1.0 / num_classes * num_sample_class
+    for i in range(len(samples)):
+        o = sampleOut[i]
+        cost = -np.log(o / (o + b)) if samples[i][2] else -np.log(b / (o + b))
+        out[samples[i][0]] += cost * samples[i][3]
+    return (out,
+            np.array(sampleOut).reshape(batch_size,
+                                        num_sample_class + num_true_class),
+            np.array(sample_labels).reshape(batch_size,
+                                            num_sample_class + num_true_class))
+
+
+class TestNCE(OpTest):
+    def generate_data(self, dim, batch_size, num_classes, num_true_class,
+                      num_sampled_classes):
+        input = np.random.randn(batch_size, dim).astype(np.float32)
+        weight = np.random.randn(num_classes, dim).astype(np.float32)
+        bias = np.random.randn(num_classes).astype(np.float32)
+        sample_weight = np.random.randn(batch_size).astype(np.float32)
+        labels = np.random.randint(0, num_classes,
+                                   (batch_size, num_true_class))
+        self.attrs = {
+            'num_classes': num_classes,
+            'num_sampled_classes': num_sampled_classes,
+            'sampled_labels': range(num_sampled_classes)
+        }
+        # The keys must match the input/output names declared by the C++ op:
+        # "Input", "Label", "Weight", "Bias", "SampleWeight" and "Cost".
+        self.inputs = {
+            'Input': input,
+            'Label': labels,
+            'Weight': weight,
+            'Bias': bias,
+            'SampleWeight': sample_weight
+        }
+
+    def set_data(self):
+        self.generate_data(5, 5, 4, 1, 2)
+
+    def compute(self):
+        out = nce(self.inputs['Input'], self.inputs['Weight'],
+                  self.inputs['Bias'], self.inputs['SampleWeight'],
+                  self.inputs['Label'], self.attrs['num_classes'],
+                  self.attrs['num_sampled_classes'])
+        self.outputs = {
+            'Cost': out[0],
+            'SampleLogits': out[1],
+            'SampleLabels': out[2]
+        }
+
+    def setUp(self):
+        self.op_type = 'nce'
+        self.set_data()
+        self.compute()
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(
+            ["Input", "Weight", "Bias"], "Cost", max_relative_error=0.02)
+
+
+class TestNCECase1(TestNCE):
+    def set_data(self):
+        self.generate_data(10, 20, 10, 2, 5)
+
+
+if __name__ == '__main__':
+    unittest.main()
-- 
GitLab
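
Note on the gradient formulas used above: the rewritten NCEGradKernel back-propagates through the sigmoid in closed form, applying (b / (o + b)) * (o - 1) for a true class and o * (1 - o) / (o + b) for a sampled class, where o is the sigmoid of the sample logit and b = num_sampled_classes / num_classes. The snippet below is a minimal standalone NumPy sketch (not part of the patch) that checks these formulas against a finite difference of the forward cost used in test_nce.py; the helper names forward_cost/analytic_grad and the defaults num_classes=4, num_sampled=2 (matching TestNCE's generate_data(5, 5, 4, 1, 2)) are illustrative assumptions.

import numpy as np


def forward_cost(logit, is_true_class, num_classes=4, num_sampled=2):
    # Same quantities the kernel uses: o = sigmoid(logit), b = k / K.
    o = 1.0 / (1.0 + np.exp(-logit))
    b = 1.0 / num_classes * num_sampled
    return -np.log(o / (o + b)) if is_true_class else -np.log(b / (o + b))


def analytic_grad(logit, is_true_class, num_classes=4, num_sampled=2):
    # d(cost)/d(logit), matching NCEGradKernel:
    #   true class:    (b / (o + b)) * (o - 1)
    #   sampled class:  o * (1 - o) / (o + b)
    o = 1.0 / (1.0 + np.exp(-logit))
    b = 1.0 / num_classes * num_sampled
    return (b / (o + b)) * (o - 1) if is_true_class else o * (1 - o) / (o + b)


if __name__ == '__main__':
    eps = 1e-5
    for is_true in (True, False):
        for logit in (-1.3, 0.2, 2.1):
            numeric = (forward_cost(logit + eps, is_true) -
                       forward_cost(logit - eps, is_true)) / (2.0 * eps)
            assert abs(numeric - analytic_grad(logit, is_true)) < 1e-4
    print('gradient formulas match finite differences')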