diff --git a/paddle/operators/nce_op.cc b/paddle/operators/nce_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..952da10434df01a10fc713a017084d315a2a59d3
--- /dev/null
+++ b/paddle/operators/nce_op.cc
@@ -0,0 +1,186 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/operators/nce_op.h"
+
+namespace paddle {
+namespace operators {
+
+using framework::Tensor;
+
+class NCEOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("Input"));
+    PADDLE_ENFORCE(ctx->HasInput("Label"));
+    PADDLE_ENFORCE(ctx->HasInput("Weight"));
+    PADDLE_ENFORCE(ctx->HasOutput("Cost"));
+    PADDLE_ENFORCE(ctx->HasOutput("SampleLogits"));
+    PADDLE_ENFORCE(ctx->HasOutput("SampleLabels"));
+
+    auto x_dims = ctx->GetInputDim("Input");
+    auto label_dims = ctx->GetInputDim("Label");
+    PADDLE_ENFORCE_EQ(x_dims[0], label_dims[0]);
+    int num_true_classes = label_dims.size() == 2 ? label_dims[1] : 1;
+    if (ctx->HasInput("Bias")) {
+      PADDLE_ENFORCE_EQ(ctx->GetInputDim("Weight")[0],
+                        ctx->GetInputDim("Bias")[0]);
+    }
+    auto num_neg_samples = ctx->Attrs().Get<int>("num_neg_samples");
+    auto num_total_classes = ctx->Attrs().Get<int>("num_total_classes");
+    std::vector<int> custom_neg_classes =
+        ctx->Attrs().Get<std::vector<int>>("custom_neg_classes");
+    PADDLE_ENFORCE_EQ(num_total_classes, ctx->GetInputDim("Weight")[0]);
+    if (custom_neg_classes.size() > 0) {
+      PADDLE_ENFORCE_EQ(custom_neg_classes.size(),
+                        static_cast<size_t>(num_neg_samples));
+    }
+    // set dims of output(Cost)
+    std::vector<int64_t> out_dims;
+    out_dims.push_back(x_dims[0]);
+    out_dims.push_back(1);
+    ctx->SetOutputDim("Cost", framework::make_ddim(out_dims));
+
+    // set dims of output(SampleLogits) and output(SampleLabels)
+    std::vector<int64_t> sample_out_dims;
+    sample_out_dims.push_back(x_dims[0]);
+    sample_out_dims.push_back(num_neg_samples + num_true_classes);
+    ctx->SetOutputDim("SampleLogits", framework::make_ddim(sample_out_dims));
+    ctx->SetOutputDim("SampleLabels", framework::make_ddim(sample_out_dims));
+  }
+
+ protected:
+  framework::OpKernelType GetKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(
+        framework::ToDataType(ctx.Input<Tensor>("Input")->type()),
+        ctx.device_context());
+  }
+};
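As a reading aid (not part of the patch): a minimal NumPy sketch of the output shapes the `InferShape` above computes, with hypothetical sizes.

```python
import numpy as np

# Hypothetical sizes, mirroring the InferShape logic above.
batch_size, dim = 5, 4      # Input:  [batch_size, dim]
num_true_classes = 1        # Label:  [batch_size, num_true_classes]
num_neg_samples = 2         # "num_neg_samples" attribute

# Cost is one scalar per sample.
cost_shape = (batch_size, 1)
# SampleLogits/SampleLabels hold the true and sampled-negative columns
# side by side in each row.
sample_shape = (batch_size, num_neg_samples + num_true_classes)

assert cost_shape == (5, 1)
assert sample_shape == (5, 3)
```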
" + "If you have a variable number of target classes, " + "you can pad them out to a constant number by either repeating them" + " or by padding with an otherwise unused class.)"); + AddInput("Weight", + "(Tensor) A tensor of shape [num_class, dim]. 'num_class' is the " + "total number of class."); + AddInput( + "Bias", + "(Tensor) A tensor of shape [num_class, 1]. 'num_class' is the total " + "number of class. It is a dispensable input.") + .AsDispensable(); + AddInput("SampleWeight", + "(Tensor) A tensor of shape [batch_size, 1] storing a weight for " + "each sample. And it is a dispensable input. The default value of " + "sample is 1.") + .AsDispensable(); + AddOutput("Cost", + "(Tensor) A tensor of shape [batch_size, 1]. Cost of samples."); + AddOutput("SampleLogits", + "An intermediate tensor of shape[batch_size, num_neg_samples + " + "num_pos_samples]." + "This tensor is output of forward kernel and used in backward " + "kernel to compute grads." + "Given X is the dot product of input tensor and sampled labels' " + "weights." + "Then 'SampleLogits' is sigmoid(X).") + .AsIntermediate(); + AddOutput("SampleLabels", + "An intermediate tensor of shape[batch_size, num_neg_samples + " + "num_pos_samples]." + "This tensor is output of forward kernel and used in backward " + "kernel to compute grads." + "") + .AsIntermediate(); + AddAttr("num_total_classes", + "Total number of classes in all samples."); + AddAttr("num_neg_samples", + "The number of negative classes. The default value is 10.") + .SetDefault(10); + AddAttr>("custom_neg_classes", + "This attribute only be used in unitest. Classes " + "in this list wiil be used as negative classes " + "for every samples. Under normal conditions, " + "user should avoid setting this attribute."); + AddComment(R"DOC( +Compute and return the noise-contrastive estimation training loss. +See [Noise-contrastive estimation: A new estimation principle for unnormalized statistical models](http://www.jmlr.org/proceedings/papers/v9/gutmann10a/gutmann10a.pdf). +By default this operator uses a uniform distribution for sampling. 
+)DOC"); + } +}; + +class NCEOpGrad : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Input")); + PADDLE_ENFORCE(ctx->HasInput("Weight")); + PADDLE_ENFORCE(ctx->HasInput("Cost")); + PADDLE_ENFORCE(ctx->HasInput("SampleLogits")); + PADDLE_ENFORCE(ctx->HasInput("SampleLabels")); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Cost")), + "The input(Out@GRAD) should not be null."); + + auto x_dims = ctx->GetInputDim("Input"); + auto x_grad_name = framework::GradVarName("Input"); + if (ctx->HasOutput(x_grad_name)) { + ctx->SetOutputDim(x_grad_name, x_dims); + } + + auto w_dims = ctx->GetInputDim("Weight"); + auto w_grad_name = framework::GradVarName("Weight"); + if (ctx->HasOutput(w_grad_name)) { + ctx->SetOutputDim(w_grad_name, w_dims); + } + + auto bias_grad_name = framework::GradVarName("Bias"); + if (ctx->HasOutput(bias_grad_name)) { + auto bias_dims = ctx->GetInputDim("Bias"); + ctx->SetOutputDim(bias_grad_name, bias_dims); + } + } + + protected: + framework::OpKernelType GetKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("Input")->type()), + ctx.device_context()); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP(nce, ops::NCEOp, ops::NCEOpMaker, nce_grad, ops::NCEOpGrad); +REGISTER_OP_CPU_KERNEL(nce, ops::NCEKernel, + ops::NCEKernel); +REGISTER_OP_CPU_KERNEL(nce_grad, + ops::NCEGradKernel, + ops::NCEGradKernel); diff --git a/paddle/operators/nce_op.h b/paddle/operators/nce_op.h new file mode 100644 index 0000000000000000000000000000000000000000..ea92a797fe18e218be602e019f3fda6bc0b05f33 --- /dev/null +++ b/paddle/operators/nce_op.h @@ -0,0 +1,211 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once + +#include +#include +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" +#include "unsupported/Eigen/CXX11/Tensor" +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +using EigenMatrix = framework::EigenMatrix; + +template +void PrepareSamples(const framework::ExecutionContext& context) { + auto label = context.Input("Label"); + const int64_t* label_data = label->data(); + auto label_dims = label->dims(); + int num_total_classes = context.Attr("num_total_classes"); + // for unitest + std::vector custom_neg_classes = + context.Attr>("custom_neg_classes"); + // random machine + std::random_device rd; + std::mt19937 rng(rd()); + std::uniform_int_distribution rand(0, num_total_classes - 1); + + auto sample_labels = context.Output("SampleLabels"); + auto sample_labels_dims = sample_labels->dims(); + int64_t* sample_labels_data = + sample_labels->mutable_data(context.GetPlace()); + + int num_label = label_dims.size() == 2 ? 
+
+template <typename Place, typename T>
+class NCEKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    PrepareSamples<Place, T>(context);
+    auto sample_labels = context.Output<Tensor>("SampleLabels");
+    const int64_t* sample_labels_data = sample_labels->data<int64_t>();
+    auto sample_out = context.Output<Tensor>("SampleLogits");
+    T* sample_out_data = sample_out->mutable_data<T>(context.GetPlace());
+    auto label = context.Input<Tensor>("Label");
+    auto sample_weight = context.Input<Tensor>("SampleWeight");
+    const T* sample_weight_data = nullptr;
+    if (sample_weight != nullptr) {
+      sample_weight_data = sample_weight->data<T>();
+    }
+    auto out = context.Output<Tensor>("Cost");
+    T* out_data = out->mutable_data<T>(context.GetPlace());
+    int num_neg_samples = context.Attr<int>("num_neg_samples");
+    int num_total_classes = context.Attr<int>("num_total_classes");
+    int num_true_class = 1;
+    if (label != nullptr) {
+      num_true_class = label->dims()[1];
+    }
+    T b = 1. / num_total_classes * num_neg_samples;
+    // forward bias
+    auto bias = context.Input<Tensor>("Bias");
+    if (bias != nullptr) {
+      const T* bias_data = bias->data<T>();
+      for (size_t i = 0; i < sample_labels->numel(); ++i) {
+        sample_out_data[i] = bias_data[sample_labels_data[i]];
+      }
+    } else {
+      for (size_t i = 0; i < sample_labels->numel(); ++i) {
+        sample_out_data[i] = 0;
+      }
+    }
+    // forward mul
+    auto input_mat = EigenMatrix<T>::From(*(context.Input<Tensor>("Input")));
+    auto weight_mat = EigenMatrix<T>::From(*(context.Input<Tensor>("Weight")));
+    for (size_t i = 0; i < sample_labels->numel(); ++i) {
+      Eigen::Tensor<T, 0, Eigen::RowMajor, Eigen::DenseIndex> result =
+          (input_mat.chip((int)(i / sample_labels->dims()[1]), 0) *
+           weight_mat.chip(sample_labels_data[i], 0))
+              .sum();
+      sample_out_data[i] += result(0);
+      sample_out_data[i] = (1. / (1. + exp(-sample_out_data[i])));
+    }
+    // forward cost
+    for (size_t i = 0; i < sample_labels->dims()[0]; ++i) {
+      size_t j = 0;
+      out_data[i] = 0;
+      T w = sample_weight == nullptr ? 1. : sample_weight_data[i];
+      // for true classes
+      for (; j < num_true_class; ++j) {
+        T o = sample_out_data[i * sample_out->dims()[1] + j];
+        T cost = -log(o / (o + b));
+        out_data[i] += w * cost;
+      }
+      // for sampled neg classes
+      for (; j < sample_labels->dims()[1]; ++j) {
+        T o = sample_out_data[i * sample_out->dims()[1] + j];
+        T cost = -log(b / (o + b));
+        out_data[i] += w * cost;
+      }
+    }
+  }
+};
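Differentiating the two per-column costs above with respect to the pre-sigmoid logit s, with o = sigmoid(s), gives exactly the two branches of sample_grad_data in the backward kernel below (before scaling by the sample weight w and the incoming Cost gradient). A LaTeX sketch of the derivation result:

```latex
% True-class column:    cost = -\log\big(o / (o + b)\big)
% Negative column:      cost = -\log\big(b / (o + b)\big)
% With o = \sigma(s) and \frac{\partial o}{\partial s} = o(1 - o):
\frac{\partial\,\mathrm{cost}}{\partial s}
  = \frac{b}{o + b}\,(o - 1) \quad \text{(true class)}, \qquad
\frac{\partial\,\mathrm{cost}}{\partial s}
  = \frac{o\,(1 - o)}{o + b} \quad \text{(sampled negative)}
```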
+
+template <typename Place, typename T>
+class NCEGradKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto d_out = context.Input<Tensor>(framework::GradVarName("Cost"));
+    const T* d_out_data = d_out->data<T>();
+    auto label = context.Input<Tensor>("Label");
+    auto sample_out = context.Input<Tensor>("SampleLogits");
+    const T* sample_out_data = sample_out->data<T>();
+    auto sample_labels = context.Input<Tensor>("SampleLabels");
+    const int64_t* sample_labels_data = sample_labels->data<int64_t>();
+    auto sample_weight = context.Input<Tensor>("SampleWeight");
+    const T* sample_weight_data = nullptr;
+    if (sample_weight != nullptr) {
+      sample_weight_data = sample_weight->data<T>();
+    }
+    int num_neg_samples = context.Attr<int>("num_neg_samples");
+    int num_total_classes = context.Attr<int>("num_total_classes");
+    int num_true_class = 1;
+    if (label != nullptr) {
+      num_true_class = label->dims()[1];
+    }
+    T b = 1. / num_total_classes * num_neg_samples;
+    Tensor sample_grad;  // tmp tensor
+    T* sample_grad_data =
+        sample_grad.mutable_data<T>(sample_labels->dims(), context.GetPlace());
+    // backward cost
+    for (size_t i = 0; i < sample_labels->numel(); ++i) {
+      T o = sample_out_data[i];
+      T w = sample_weight == nullptr
+                ? 1
+                : sample_weight_data[i / sample_labels->dims()[1]];
+      sample_grad_data[i] = (i % sample_labels->dims()[1]) < num_true_class
+                                ? w * (b / (o + b)) * (o - 1)
+                                : w * (o * (1 - o) / (o + b));
+      sample_grad_data[i] *= d_out_data[i / sample_labels->dims()[1]];
+    }
+    // get d_bias
+    auto d_bias = context.Output<Tensor>(framework::GradVarName("Bias"));
+    if (d_bias != nullptr) {
+      T* d_bias_data = d_bias->mutable_data<T>(context.GetPlace());
+      std::fill(d_bias_data, d_bias_data + d_bias->numel(), 0.0);
+      for (size_t i = 0; i < sample_labels->numel(); ++i) {
+        d_bias_data[sample_labels_data[i]] += sample_grad_data[i];
+      }
+    }
+    // get d_w
+    auto d_w = context.Output<Tensor>(framework::GradVarName("Weight"));
+    if (d_w != nullptr) {
+      auto d_w_data = d_w->mutable_data<T>(context.GetPlace());
+      std::fill(d_w_data, d_w_data + d_w->numel(), 0.0);
+      auto d_w_matrix = EigenMatrix<T>::From(*d_w);
+      auto x_matrix = EigenMatrix<T>::From(*(context.Input<Tensor>("Input")));
+      for (size_t i = 0; i < sample_labels->numel(); ++i) {
+        d_w_matrix.chip(sample_labels_data[i], 0) +=
+            x_matrix.chip((int)(i / sample_labels->dims()[1]), 0) *
+            sample_grad_data[i];
+      }
+    }
+    // get d_x
+    auto d_x = context.Output<Tensor>(framework::GradVarName("Input"));
+    if (d_x != nullptr) {
+      T* d_x_data = d_x->mutable_data<T>(context.GetPlace());
+      // zero-initialize before accumulating, as done for d_bias and d_w above
+      std::fill(d_x_data, d_x_data + d_x->numel(), 0.0);
+      auto d_x_matrix = EigenMatrix<T>::From(*d_x);
+      auto w_matrix = EigenMatrix<T>::From(*(context.Input<Tensor>("Weight")));
+      for (size_t i = 0; i < sample_labels->numel(); ++i) {
+        d_x_matrix.chip((int)(i / sample_labels->dims()[1]), 0) +=
+            w_matrix.chip(sample_labels_data[i], 0) * sample_grad_data[i];
+      }
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
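A quick, independent finite-difference check of the two derivative expressions above; a minimal self-contained sketch with hypothetical values, not part of the patch or of the unit test below:

```python
import numpy as np

# Check d(cost)/d(logit) used in NCEGradKernel, with o = sigmoid(s), b = k/V.
b = 0.5   # hypothetical noise score k / V
s = 0.3   # hypothetical logit

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def cost_true(s):
    o = sigmoid(s)
    return -np.log(o / (o + b))

def cost_neg(s):
    o = sigmoid(s)
    return -np.log(b / (o + b))

o = sigmoid(s)
analytic_true = (b / (o + b)) * (o - 1)   # true-class branch in the kernel
analytic_neg = o * (1 - o) / (o + b)      # sampled-negative branch

eps = 1e-6  # central differences
numeric_true = (cost_true(s + eps) - cost_true(s - eps)) / (2 * eps)
numeric_neg = (cost_neg(s + eps) - cost_neg(s - eps)) / (2 * eps)

assert abs(analytic_true - numeric_true) < 1e-6
assert abs(analytic_neg - numeric_neg) < 1e-6
```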
diff --git a/python/paddle/v2/fluid/tests/test_nce.py b/python/paddle/v2/fluid/tests/test_nce.py
new file mode 100644
index 0000000000000000000000000000000000000000..8aeba69769525935c26576ec50035ed50d2ce44f
--- /dev/null
+++ b/python/paddle/v2/fluid/tests/test_nce.py
@@ -0,0 +1,98 @@
+import unittest
+import numpy as np
+from op_test import OpTest
+
+
+def nce(input, weight, bias, sample_weight, labels, num_classes,
+        num_sample_class):
+    samples = []
+    sample_labels = []
+    batch_size = input.shape[0]
+    num_true_class = labels.shape[1]
+    for i in range(batch_size):
+        w = 1 if sample_weight is None else sample_weight[i]
+        for label in labels[i]:
+            samples.append((i, label, True, w))
+            sample_labels.append(label)
+        for num in range(num_sample_class):
+            samples.append((i, num, False, w))
+            sample_labels.append(num)
+    # forward bias
+    sample_out = np.zeros(len(samples)).astype(np.float32)
+    if bias is not None:
+        for i in range(len(samples)):
+            sample_out[i] = bias[samples[i][1]]
+    # forward weight
+    for i in range(len(samples)):
+        sample_out[i] += np.dot(input[samples[i][0]], weight[samples[i][1]])
+
+    # forward activation
+    sample_out = 1.0 / (1.0 + np.exp(-sample_out))
+    # forward cost
+    out = np.zeros(batch_size).astype(np.float32)
+    b = 1.0 / num_classes * num_sample_class
+    for i in range(len(samples)):
+        o = sample_out[i]
+        cost = -np.log(o / (o + b)) if samples[i][2] else -np.log(b / (o + b))
+        out[samples[i][0]] += cost * samples[i][3]
+    return (out[:, np.newaxis], np.array(sample_out).reshape(
+        batch_size, num_sample_class + num_true_class),
+            np.array(sample_labels).reshape(batch_size,
+                                            num_sample_class + num_true_class))
+
+
+class TestNCE(OpTest):
+    def generate_data(self, dim, batch_size, num_classes, num_true_class,
+                      num_neg_samples):
+        input = np.random.randn(batch_size, dim).astype(np.float32)
+        weight = np.random.randn(num_classes, dim).astype(np.float32)
+        bias = np.random.randn(num_classes).astype(np.float32)
+        sample_weight = np.random.randn(batch_size).astype(np.float32)
+        labels = np.random.randint(0, num_classes,
+                                   (batch_size, num_true_class))
+        self.attrs = {
+            'num_total_classes': num_classes,
+            'num_neg_samples': num_neg_samples,
+            'custom_neg_classes': range(num_neg_samples)
+        }
+        self.inputs = {
+            'Input': input,
+            'Label': labels,
+            'Weight': weight,
+            'Bias': bias,
+            'SampleWeight': sample_weight
+        }
+
+    def set_data(self):
+        self.generate_data(5, 5, 4, 1, 2)
+
+    def compute(self):
+        out = nce(self.inputs['Input'], self.inputs['Weight'],
+                  self.inputs['Bias'], self.inputs['SampleWeight'],
+                  self.inputs['Label'], self.attrs['num_total_classes'],
+                  self.attrs['num_neg_samples'])
+        self.outputs = {
+            'Cost': out[0],
+            'SampleLogits': out[1],
+            'SampleLabels': out[2]
+        }
+
+    def setUp(self):
+        self.op_type = 'nce'
+        self.set_data()
+        self.compute()
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(
+            ["Input", "Weight", "Bias"], "Cost", max_relative_error=0.02)
+
+
+class TestNCECase1(TestNCE):
+    def set_data(self):
+        self.generate_data(10, 20, 10, 2, 5)
+
+
+if __name__ == '__main__':
+    unittest.main()
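As a last sanity check on the cost formula that the reference nce() helper implements, a toy evaluation with hand-picked numbers (hypothetical values, not drawn from the test):

```python
import numpy as np

# One sample with one true class and one sampled negative.
# With num_total_classes V = 4 and num_neg_samples k = 2, b = k / V = 0.5.
b = 2.0 / 4.0
o_true, o_neg = 0.7, 0.3  # hypothetical sigmoid outputs (SampleLogits)

# cost = -log(o/(o+b)) for the true class, -log(b/(o+b)) for the negative.
cost = -np.log(o_true / (o_true + b)) - np.log(b / (o_neg + b))
print(round(float(cost), 3))  # 1.009
```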