From 3d276277df1b1f8b216cae246d5cdc4f6dd02028 Mon Sep 17 00:00:00 2001 From: wanghaoshuang Date: Wed, 8 Nov 2017 14:17:38 +0800 Subject: [PATCH] Add nce op 1. Add nce forward and backward kernel for CPU --- paddle/operators/nce_op.cc | 120 +++++++++++++++++++++ paddle/operators/nce_op.h | 210 +++++++++++++++++++++++++++++++++++++ 2 files changed, 330 insertions(+) create mode 100644 paddle/operators/nce_op.cc create mode 100644 paddle/operators/nce_op.h diff --git a/paddle/operators/nce_op.cc b/paddle/operators/nce_op.cc new file mode 100644 index 0000000000..afd61b8851 --- /dev/null +++ b/paddle/operators/nce_op.cc @@ -0,0 +1,120 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/operators/nce_op.h" + +namespace paddle { +namespace operators { + +using framework::Tensor; + +class NCEOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X")); + PADDLE_ENFORCE(ctx->HasInput("Label")); + PADDLE_ENFORCE(ctx->HasInput("W")); + PADDLE_ENFORCE(ctx->HasOutput("Out")); + PADDLE_ENFORCE(ctx->HasOutput("SampleLogits")); + PADDLE_ENFORCE(ctx->HasOutput("SampleLabels")); + + auto x_dims = ctx->GetInputDim("X"); + auto label_dims = ctx->GetInputDim("Label"); + PADDLE_ENFORCE_EQ(x_dims[0], label_dims[0]); + if (ctx->HasInput("B")) { + PADDLE_ENFORCE_EQ(ctx->GetInputDim("W")[0], ctx->GetInputDim("B")[0]); + } + int num_sampled_classes = ctx->Attrs().Get("num_sampled_classes"); + int num_classes = ctx->Attrs().Get("num_classes"); + PADDLE_ENFORCE_EQ(num_classes, ctx->GetInputDim("W")[0]); + PADDLE_ENFORCE_LT(num_sampled_classes, num_classes); + + // set dims of output(Out) + std::vector out_dims(1); + out_dims.push_back(x_dims[0]); + ctx->SetOutputDim("Out", framework::make_ddim(out_dims)); + + // set dims of output(SampleOut) + std::vector sample_out_dims(2); + sample_out_dims.push_back(x_dims[0]); + sample_out_dims.push_back(num_sampled_classes + 1); + ctx->SetOutputDim("SampleLogits", framework::make_ddim(sample_out_dims)); + ctx->SetOutputDim("SampleLabels", framework::make_ddim(sample_out_dims)); + } +}; + +class NCEOpMaker : public framework::OpProtoAndCheckerMaker { + public: + NCEOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", ""); + AddInput("Label", ""); + AddInput("W", ""); + AddInput("B", ""); + AddInput("SampleWeight", ""); + AddOutput("Out", ""); + AddOutput("SampleLogits", ""); + AddOutput("SampleLabels", ""); + AddAttr("num_classes", ""); + AddAttr("num_sampled_classes", "").SetDefault(10); + AddComment(R"DOC( +Expand input(X) according to LOD of input(Y). + +)DOC"); + } +}; + +class NCEOpGrad : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X")); + PADDLE_ENFORCE(ctx->HasInput("W")); + PADDLE_ENFORCE(ctx->HasInput("Out")); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), + "The input(Out@GRAD) should not be null"); + + auto x_dims = ctx->GetInputDim("X"); + auto x_grad_name = framework::GradVarName("X"); + if (ctx->HasOutput(x_grad_name)) { + ctx->SetOutputDim(x_grad_name, x_dims); + } + + auto w_dims = ctx->GetInputDim("W"); + auto w_grad_name = framework::GradVarName("W"); + if (ctx->HasOutput(w_grad_name)) { + ctx->SetOutputDim(w_grad_name, w_dims); + } + + auto bias_grad_name = framework::GradVarName("B"); + if (ctx->HasOutput(bias_grad_name)) { + auto bias_dims = ctx->GetInputDim("B"); + ctx->SetOutputDim(bias_grad_name, bias_dims); + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP(nce, ops::NCEOp, ops::NCEOpMaker, nce_grad, ops::NCEOpGrad); +REGISTER_OP_CPU_KERNEL(nce, ops::NCEKernel); +REGISTER_OP_CPU_KERNEL(nce_grad, + ops::NCEGradKernel); diff --git a/paddle/operators/nce_op.h b/paddle/operators/nce_op.h new file mode 100644 index 0000000000..ce1717c9b0 --- /dev/null +++ b/paddle/operators/nce_op.h @@ -0,0 +1,210 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once + +#include +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" +#include "paddle/memory/memcpy.h" +#include "unsupported/Eigen/CXX11/Tensor" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +using EigenMatrix = framework::EigenMatrix; + +template +void PrepareSamples(const framework::ExecutionContext& context) { + auto label = context.Input("Label"); + const T* label_data = label->data(); + auto label_dims = label->dims(); + int num_classes = context.Attr("num_classes"); + // random machine + std::random_device rd; + std::mt19937 rng(rd()); + std::uniform_int_distribution rand(0, num_classes - 1); + + auto sample_labels = context.Output("SampleLabels"); + auto sample_labels_dims = sample_labels->dims(); + int* sample_labels_data = + sample_labels->mutable_data(context.GetPlace()); + + int num_label = label_dims.size() == 2 ? label_dims[1] : 1; + for (size_t i = 0; i < label_dims[0]; ++i) { + int j = 0; + for (; j < num_label; ++j) { + sample_labels_data[sample_labels_dims[1] * i + j] = + label_data[i * num_label + j]; + } + for (; j < sample_labels_dims[1]; ++j) { + int id = rand(rng); + sample_labels_data[sample_labels_dims[1] * i + j] = id; + } + } +} + +template +class NCEKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + PrepareSamples(context); + auto sample_labels = context.Output("SampleLabels"); + const int* sample_labels_data = sample_labels->data(); + auto sample_out = context.Output("SampleLogits"); + T* sample_out_data = sample_out->mutable_data(context.GetPlace()); + auto label = context.Input("Label"); + auto sample_weight = context.Input("SampleWeight"); + const T* sample_weight_data = nullptr; + if (sample_weight != nullptr) { + sample_weight_data = sample_weight->data(); + } + auto out = context.Output("Out"); + T* out_data = out->mutable_data(context.GetPlace()); + int num_smalped_classes = context.Attr("num_sampled_classes"); + int num_classes = context.Attr("num_classes"); + int num_true_class = 1; + if (label != nullptr) { + num_true_class = label->dims()[1]; + } + T b = 1. / num_classes * num_smalped_classes; + + // forward bias + auto bias = context.Input("B"); + if (bias != nullptr) { + const T* bias_data = bias->data(); + for (size_t i = 0; i < sample_labels->numel(); ++i) { + sample_out_data[i] = bias_data[sample_labels_data[i]]; + } + } else { + for (size_t i = 0; i < sample_labels->numel(); ++i) { + sample_out_data[i] = 0; + } + } + + // forward mul + auto input_mat = EigenMatrix::From(*(context.Input("X"))); + auto weight_mat = EigenMatrix::From(*(context.Input("W"))); + for (size_t i = 0; i < sample_labels->numel(); ++i) { + // sample_out_data[i] += (input_mat.chip((int)(i / + // sample_labels->dims()[1]), 0) * weight_mat.chip(sample_labels_data[i], + // 0)).sum(); + Eigen::Tensor result = + (input_mat.chip((int)(i / sample_labels->dims()[1]), 0) * + weight_mat.chip(sample_labels_data[i], 0)) + .sum(); + sample_out_data[i] += result(0); + // activation_->forward + sample_out_data[i] = (1 / 1 + (sample_out_data[i])); + } + + // forward cost + for (size_t i = 0; i < sample_labels->dims()[0]; ++i) { + size_t j = 0; + T w = sample_weight == nullptr ? 1 : sample_weight_data[i]; + // for true classes + for (; j < num_true_class; ++j) { + T o = sample_out_data[i * sample_out->dims()[1] + j]; + T cost = -log(o / (o + b)); + out_data[i] += w * cost; + } + // for sampled neg classes + for (; j < sample_labels->dims()[1]; ++j) { + T o = sample_out_data[i * sample_out->dims()[1] + j]; + T cost = -log(b / (o + b)); + out_data[i] += w * cost; + } + } + } +}; + +template +class NCEGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto label = context.Input("Label"); + auto sample_out = context.Input("SampleLogits"); + const T* sample_out_data = sample_out->data(); + auto sample_labels = context.Input("SampleLabels"); + const int* sample_labels_data = sample_labels->data(); + auto sample_weight = context.Input("SampleWeight"); + const T* sample_weight_data = nullptr; + if (sample_weight != nullptr) { + sample_weight_data = sample_weight->data(); + } + int num_smalped_classes = context.Attr("num_sampled_classes"); + int num_classes = context.Attr("num_classes"); + int num_true_class = 1; + if (label != nullptr) { + num_true_class = label->dims()[1]; + } + T b = 1. / num_classes * num_smalped_classes; + + Tensor sample_grad; // tmp tensor + T* sample_grad_data = + sample_grad.mutable_data(sample_labels->dims(), context.GetPlace()); + + // backward cost + for (size_t i = 0; i < sample_labels->numel(); ++i) { + T o = sample_out_data[i]; + T w = sample_weight == nullptr + ? 1 + : sample_weight_data[i / sample_labels->dims()[1]]; + sample_grad_data[i] = (i % sample_labels->dims()[1]) < num_true_class + ? -w * b / (o * (o + b)) + : w / (o + b); + // sigmoid->backward + sample_grad_data[i] = + (o > 0) ? sample_grad_data[i] : ((o < 0) ? -sample_grad_data[i] : 0); + } + + // get d_bias + auto d_bias = context.Output(framework::GradVarName("B")); + if (d_bias != nullptr) { + T* d_bias_data = d_bias->mutable_data(context.GetPlace()); + for (size_t i = 0; i < sample_labels->numel(); ++i) { + d_bias_data[sample_labels_data[i]] += sample_grad_data[i]; + } + } + // get d_w + auto d_w = context.Output(framework::GradVarName("W")); + if (d_w != nullptr) { + auto d_w_matrix = EigenMatrix::From(*d_w); + auto x_matrix = EigenMatrix::From(*(context.Input("X"))); + for (size_t i = 0; i < sample_labels->numel(); ++i) { + d_w_matrix.chip(sample_labels_data[i], 0) = + x_matrix.chip((int)(i / sample_labels->dims()[1]), 0) * + sample_grad_data[i]; + } + } + + // get d_x + auto d_x = context.Output(framework::GradVarName("X")); + if (d_x != nullptr) { + auto d_x_matrix = EigenMatrix::From(*d_x); + auto w_matrix = EigenMatrix::From(*(context.Input("W"))); + for (size_t i = 0; i < sample_labels->numel(); ++i) { + d_x_matrix.chip((int)(i / sample_labels->dims()[1]), 0) += + w_matrix.chip(sample_labels_data[i], 0) * sample_grad_data[i]; + } + } + } +}; + +} // namespace operators +} // namespace paddle -- GitLab