diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt
index de6ff29c6f8edbcf930546ff157a1c226e1311db..b4eca5bd9ccf7712aecca90add924bd4e3ed3187 100644
--- a/paddle/fluid/operators/CMakeLists.txt
+++ b/paddle/fluid/operators/CMakeLists.txt
@@ -245,6 +245,7 @@ op_library(max_sequence_len_op DEPS lod_rank_table)
 op_library(sequence_conv_op DEPS context_project)
 op_library(sequence_pool_op DEPS sequence_pooling)
 op_library(lstm_op DEPS sequence2batch lstm_compute)
+op_library(hierarchical_sigmoid_op DEPS matrix_bit_code)
 op_library(lstmp_op DEPS sequence2batch lstm_compute)
 op_library(gru_op DEPS sequence2batch gru_compute)
 op_library(recurrent_op DEPS executor)
diff --git a/paddle/fluid/operators/hierarchical_sigmoid_op.cc b/paddle/fluid/operators/hierarchical_sigmoid_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..499e641ff02a21fabe2a06024e316d3825f473b0
--- /dev/null
+++ b/paddle/fluid/operators/hierarchical_sigmoid_op.cc
@@ -0,0 +1,163 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/hierarchical_sigmoid_op.h"
+#include <vector>
+
+namespace paddle {
+namespace operators {
+
+/**
+ * Organize the classes into a binary tree. At each node, a sigmoid function
+ * is used to calculate the probability of belonging to the right branch.
+ * This idea is from "F. Morin, Y. Bengio (AISTATS 05):
+ * Hierarchical Probabilistic Neural Network Language Model."
+ *
+ * Here we use a simple way of making the binary tree.
+ * Assuming the number of classes C = 6,
+ * the classes are organized as a binary tree in the following way:
+ *
+ * @code{.py}
+ *   *-*-*- 2
+ *   | | |- 3
+ *   | |
+ *   | |-*- 4
+ *   |   |- 5
+ *   |
+ *   |-*- 0
+ *     |- 1
+ * @endcode
+ *
+ * where * indicates an internal node, and each leaf node represents a class.
+ * - Node 0 ... C-2 are internal nodes.
+ * - Node C-1 ... 2C-2 are leaf nodes.
+ * - Class c is represented by leaf node \f$c+C-1\f$.
+ *
+ * We assign an id to each node:
+ * - the id of the root is 0;
+ * - the left child of node i is 2*i+1;
+ * - the right child of node i is 2*i+2.
+ *
+ * It's easy to see that:
+ * - the parent of node i is \f$\left\lfloor(i-1)/2\right\rfloor\f$;
+ * - the j-th level ancestor of node i is
+ *   \f$\left\lfloor(i+1)/2^{j+1}\right\rfloor - 1\f$;
+ * - node i is a left child of its parent iff \f$(i-1)\%2==0\f$.
+ */
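The node-id arithmetic above is easy to check in a few lines of standalone Python (an illustrative sketch, not part of the patch; names are ad hoc):

    C = 6                                    # number of classes

    def parent(i):
        return (i - 1) // 2

    def ancestor(i, j):                      # j-th level ancestor of node i
        return (i + 1) // 2 ** (j + 1) - 1

    def is_left_child(i):
        return (i - 1) % 2 == 0

    leaf = 2 + C - 1                         # leaf node of class 2 is node 7
    assert [ancestor(leaf, j) for j in range(3)] == [3, 1, 0]
    assert parent(leaf) == 3 and is_left_child(leaf)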
+
+class HierarchicalSigmoidOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Ids"), "Input(Ids) should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("W"), "Input(W) should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("PreOut"),
+                   "Output(PreOut) should not be null.");
+    const int64_t batch_size = ctx->GetInputDim("X")[0];
+    std::vector<int64_t> output_shape({batch_size, 1});
+    ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(
+        framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
+        ctx.GetPlace());
+  }
+};
+
+template <typename AttrType>
+class HierarchicalSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("X",
+             "(Tensor, required) The input tensor with shape [N, D], "
+             "where N is the size of the mini-batch and D is the "
+             "embedding size.");
+    AddInput("W",
+             "(Tensor, required) The parameters of the hierarchical "
+             "sigmoid operator, a 2-D tensor with shape "
+             "[num_classes - 1, D].");
+    AddInput("Ids",
+             "(Tensor, required) The labels of the training data, with "
+             "shape [1, N].");
+    AddInput("Bias",
+             "(Tensor, optional) The bias applied to the output, with "
+             "shape [1, num_classes - 1].");
+    AddOutput("Out",
+              "(Tensor, required) The output of the hierarchical sigmoid "
+              "operator, with shape [N, 1].");
+    AddOutput("PreOut",
+              "(Tensor, required) An intermediate 2-D tensor with shape "
+              "[batch_size, code_length].")
+        .AsIntermediate();
+    AddAttr<AttrType>("num_classes",
+                      "(int, required) The number of classes.")
+        .SetDefault(2);
+    AddComment(R"DOC(
+The hierarchical sigmoid operator organizes the classes into a binary tree.
+At each node, a sigmoid function is used to calculate the probability of
+belonging to the right branch. This idea is from
+"F. Morin, Y. Bengio (AISTATS 05):
+Hierarchical Probabilistic Neural Network Language Model."
+)DOC");
+  }
+};
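For reference, per sample i the forward kernel below accumulates (my restatement of the kernel code, with \f$b_{ij}\f$ the j-th bit of sample i's code and \f$\mathrm{idx}(i,j)\f$ the corresponding internal node):

\f[
  z_{ij} = \mathrm{Bias}_{0,\,\mathrm{idx}(i,j)} + W_{\mathrm{idx}(i,j)} \cdot X_i,
  \qquad
  \mathrm{Out}_i = \sum_{j} \big[ \log(1 + e^{z_{ij}}) - b_{ij}\, z_{ij} \big],
\f]

which equals \f$-\sum_j \big[ b_{ij}\log\sigma(z_{ij}) + (1-b_{ij})\log(1-\sigma(z_{ij})) \big]\f$, i.e. the sum of binary cross-entropies along the path from the root to the label's leaf.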
+
+class HierarchicalSigmoidGradOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("W"), "Input(W) should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Ids"), "Input(Ids) should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("PreOut"),
+                   "Input(PreOut) should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("W")),
+                   "Output(W@GRAD) should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
+                   "Output(X@GRAD) should not be null.");
+    if (ctx->HasOutput(framework::GradVarName("Bias"))) {
+      ctx->SetOutputDim(framework::GradVarName("Bias"),
+                        ctx->GetInputDim("Bias"));
+    }
+    ctx->SetOutputDim(framework::GradVarName("W"), ctx->GetInputDim("W"));
+    ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(
+        framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
+        ctx.GetPlace());
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OPERATOR(hierarchical_sigmoid, ops::HierarchicalSigmoidOp,
+                  ops::HierarchicalSigmoidOpMaker<int>,
+                  paddle::framework::DefaultGradOpDescMaker<true>);
+REGISTER_OPERATOR(hierarchical_sigmoid_grad, ops::HierarchicalSigmoidGradOp);
+REGISTER_OP_CPU_KERNEL(hierarchical_sigmoid,
+                       ops::HierarchicalSigmoidOpKernel<
+                           paddle::platform::CPUDeviceContext, float>);
+REGISTER_OP_CPU_KERNEL(hierarchical_sigmoid_grad,
+                       ops::HierarchicalSigmoidGradOpKernel<
+                           paddle::platform::CPUDeviceContext, float>);
diff --git a/paddle/fluid/operators/hierarchical_sigmoid_op.h b/paddle/fluid/operators/hierarchical_sigmoid_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..5efac8804e5bb4bbae8ac67397451e459fca835e
--- /dev/null
+++ b/paddle/fluid/operators/hierarchical_sigmoid_op.h
@@ -0,0 +1,123 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include <iostream>
+#include <vector>
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/clip_op.h"
+#include "paddle/fluid/operators/math/math_function.h"
+#include "paddle/fluid/operators/math/matrix_bit_code.h"
+#include "paddle/fluid/platform/transform.h"
+namespace paddle {
+namespace operators {
+
+template <typename T, int MajorType = Eigen::RowMajor,
+          typename IndexType = Eigen::DenseIndex>
+using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
+using platform::Transform;
+
+template <typename DeviceContext, typename T>
+class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* in = ctx.Input<framework::Tensor>("X");
+    auto* w = ctx.Input<framework::Tensor>("W");
+    auto* ids = ctx.Input<framework::Tensor>("Ids");
+    auto* bias = ctx.Input<framework::Tensor>("Bias");
+    auto* out = ctx.Output<framework::Tensor>("Out");
+    auto* pre_out = ctx.Output<framework::Tensor>("PreOut");
+    size_t num_classes = static_cast<size_t>(ctx.Attr<int>("num_classes"));
+    int64_t code_length = math::FindLastSet(num_classes - 1);
+    int64_t batch_size = in->dims()[0];
+    framework::Tensor sum;
+    math::SetConstant<DeviceContext, T> zero;
+    auto& dev_ctx = ctx.template device_context<DeviceContext>();
+    auto pre_out_data = pre_out->mutable_data<T>(
+        framework::make_ddim({batch_size, code_length}), ctx.GetPlace());
+    auto pre_out_mat = EigenMatrix<T>::From(*pre_out);
+    zero(dev_ctx, pre_out, static_cast<T>(0.0));
+    auto& place =
+        *ctx.template device_context<DeviceContext>().eigen_device();
+    math::RowwiseSum<DeviceContext, T> row_sum;
+    math::MatrixBitCodeFunctor<T> bit_code(num_classes, ids->data<int64_t>());
+
+    std::vector<int64_t> sum_dims({batch_size, 1UL});
+    sum.mutable_data<T>(framework::make_ddim(sum_dims), ctx.GetPlace());
+    auto sum_mat = EigenMatrix<T>::From(sum);
+    out->mutable_data<T>(ctx.GetPlace());
+    auto out_mat = framework::EigenVector<T>::Flatten(*out);
+    if (bias) {
+      bit_code.Add(pre_out, *bias);
+    }
+    bit_code.Mul(pre_out, *w, *in);
+    // clip the matrix into (-40, 40)
+    Transform<DeviceContext> trans;
+    trans(ctx.template device_context<DeviceContext>(), pre_out_data,
+          pre_out_data + pre_out->numel(), pre_out_data,
+          ClipFunctor<T>(static_cast<T>(-40.0), static_cast<T>(40.0)));
+    bit_code.Sum(*pre_out, out, static_cast<T>(-1));
+    // softrelu with a threshold of 40.0
+    trans(ctx.template device_context<DeviceContext>(), pre_out_data,
+          pre_out_data + pre_out->numel(), pre_out_data,
+          ClipFunctor<T>(static_cast<T>(-40.0), static_cast<T>(40.0)));
+    pre_out_mat.device(place) = (static_cast<T>(1.0) + pre_out_mat.exp()).log();
+    row_sum(dev_ctx, *pre_out, &sum);
+    out_mat.device(place) = sum_mat + out_mat;
+  }
+};
+
+template <typename DeviceContext, typename T>
+class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* in = ctx.Input<framework::Tensor>("X");
+    auto* w = ctx.Input<framework::Tensor>("W");
+    auto* in_grad = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
+    auto* w_grad = ctx.Output<framework::Tensor>(framework::GradVarName("W"));
+    auto* bias_grad =
+        ctx.Output<framework::Tensor>(framework::GradVarName("Bias"));
+    auto* ids = ctx.Input<framework::Tensor>("Ids");
+    auto* pre_out = ctx.Input<framework::Tensor>("PreOut");
+    auto* out_grad =
+        ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
+
+    size_t num_classes = static_cast<size_t>(ctx.Attr<int>("num_classes"));
+    int64_t code_length = math::FindLastSet(num_classes - 1);
+    int64_t batch_size = in->dims()[0];
+    framework::Tensor pre_out_grad;
+    pre_out_grad.mutable_data<T>(
+        framework::make_ddim({batch_size, code_length}), ctx.GetPlace());
+    auto& place =
+        *ctx.template device_context<DeviceContext>().eigen_device();
+    auto pre_out_mat = EigenMatrix<T>::From(*pre_out);
+    auto pre_out_grad_mat = EigenMatrix<T>::From(pre_out_grad);
+    math::MatrixBitCodeFunctor<T> bit_code(num_classes, ids->data<int64_t>());
+    // softrelu derivative: PreOut holds y = log(1 + exp(z)), so dy/dz = 1 - exp(-y)
+    bit_code.OutGrad(&pre_out_grad, *out_grad);
+    pre_out_grad_mat.device(place) =
+        pre_out_grad_mat *
+        (static_cast<T>(1.0) - static_cast<T>(1.0) / pre_out_mat.exp());
+    bit_code.Sub(&pre_out_grad);
+    if (bias_grad) {
+      bias_grad->mutable_data<T>(ctx.GetPlace());
+      bit_code.AddGrad(pre_out_grad, bias_grad);
+    }
+    in_grad->mutable_data<T>(ctx.GetPlace());
+    w_grad->mutable_data<T>(ctx.GetPlace());
+    bit_code.MulGradWeight(pre_out_grad, w_grad, *in);
+    bit_code.MulGradError(pre_out_grad, *w, in_grad);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
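A note on the gradient kernel's `1.0 - 1.0 / pre_out_mat.exp()` factor (my derivation, not text from the patch): PreOut stores the softrelu output \f$y = \log(1 + e^z)\f$ rather than the pre-activation z, so

\f[
  \frac{dy}{dz} = \frac{e^z}{1 + e^z} = 1 - e^{-y} = \sigma(z),
\f]

and `bit_code.Sub` then applies the \f$-b_{ij}\f$ term of the per-node cross-entropy derivative \f$\sigma(z_{ij}) - b_{ij}\f$.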
diff --git a/paddle/fluid/operators/math/CMakeLists.txt b/paddle/fluid/operators/math/CMakeLists.txt
index 53a478c1ac0bdf8c0a3f3721161779ef10cb14f8..bc788ef5e9f84c00bb7abb65997ad68182efec62 100644
--- a/paddle/fluid/operators/math/CMakeLists.txt
+++ b/paddle/fluid/operators/math/CMakeLists.txt
@@ -51,6 +51,7 @@ math_library(sequence_padding)
 math_library(sequence_pooling DEPS math_function)
 math_library(sequence_scale)
 math_library(softmax DEPS math_function)
+math_library(matrix_bit_code)
 math_library(unpooling)
 math_library(vol2col)
diff --git a/paddle/fluid/operators/math/math_function_impl.h b/paddle/fluid/operators/math/math_function_impl.h
index b9bd49d77d935e985705f78402ffe1ea90f24cb3..895a7019aa10e5d9bb8f0c17e433a4344eac3bf4 100644
--- a/paddle/fluid/operators/math/math_function_impl.h
+++ b/paddle/fluid/operators/math/math_function_impl.h
@@ -155,7 +155,7 @@ class RowwiseSum {
     PADDLE_ENFORCE_EQ(in_dims.size(), 2U);
     auto height = in_dims[0];
     auto size = in_dims[1];
-    PADDLE_ENFORCE_EQ(out->numel(), size);
+    PADDLE_ENFORCE_EQ(out->numel(), height);
 
     T* out_buf = out->mutable_data<T>(out->place());
     const T* in_buf = input.data<T>();
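The math_function_impl.h hunk fixes an inverted shape check: a rowwise sum of a [height, size] matrix produces one value per row. A two-line numpy illustration (mine, not part of the patch):

    import numpy as np
    m = np.arange(6.0).reshape(2, 3)      # height = 2, size = 3
    assert m.sum(axis=1).shape == (2,)    # output has `height` elements, not `size`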
diff --git a/paddle/fluid/operators/math/matrix_bit_code.cc b/paddle/fluid/operators/math/matrix_bit_code.cc
new file mode 100644
index 0000000000000000000000000000000000000000..ea708eb9717737d7b3327e11b59dd8e9b1d87bc8
--- /dev/null
+++ b/paddle/fluid/operators/math/matrix_bit_code.cc
@@ -0,0 +1,211 @@
+/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/math/matrix_bit_code.h"
+#include <iostream>
+namespace paddle {
+namespace operators {
+namespace math {
+
+/**
+ * A CodeTable class should support 3 functions:
+ *
+ * size_t size()
+ *   return the number of ids
+ *
+ * int get_max_code_length()
+ *   return the maximal code length
+ *
+ * Code operator()(size_t i)
+ *   return the i-th code. The Code class is described below.
+ *
+ * A Code class should support 3 functions:
+ *
+ * int get_length()
+ *   return the length of the code
+ *
+ * size_t calc_index(int bit)
+ *   bit ranges from 0 to get_length() - 1;
+ *   return the index of the (1+bit)-th level parent
+ *
+ * bool calc_bit(int bit)
+ *   return true if the bit-th level parent is the right child of the
+ *   (1+bit)-th level parent
+ */
+template <typename T>
+void MatrixBitCodeFunctor<T>::Add(framework::Tensor* tmat,
+                                  const framework::Tensor& vec) {
+  SimpleCodeTable code_table(num_classes_);
+  size_t batch_size = tmat->dims()[0];
+  size_t width = tmat->dims()[1];
+  for (size_t i = 0; i < batch_size; ++i) {
+    auto code = code_table(static_cast<size_t>(ids_[i]));
+    int code_length = code.get_length();
+    for (int j = 0; j < code_length; ++j) {
+      size_t index = code.calc_index(j);
+      tmat->data<T>()[i * width + j] += vec.data<T>()[index];
+    }
+  }
+}
+
+template <typename T>
+void MatrixBitCodeFunctor<T>::AddGrad(const framework::Tensor& tmat,
+                                      framework::Tensor* vec) {
+  SimpleCodeTable code_table(num_classes_);
+  size_t batch_size = tmat.dims()[0];
+  size_t width = tmat.dims()[1];
+  for (size_t i = 0; i < batch_size; ++i) {
+    auto code = code_table(static_cast<size_t>(ids_[i]));
+    int code_length = code.get_length();
+    for (int j = 0; j < code_length; ++j) {
+      size_t index = code.calc_index(j);
+      vec->data<T>()[index] += tmat.data<T>()[i * width + j];
+    }
+  }
+}
+
+template <typename T>
+void MatrixBitCodeFunctor<T>::Sum(const framework::Tensor& tmat,
+                                  framework::Tensor* sum, T scale_sum) {
+  SimpleCodeTable code_table(num_classes_);
+  size_t num_samples = tmat.dims()[0];
+  size_t o_width = tmat.dims()[1];
+  for (size_t i = 0; i < num_samples; ++i) {
+    T sm = static_cast<T>(0.0);
+    auto code = code_table(static_cast<size_t>(ids_[i]));
+    int code_length = code.get_length();
+    for (int j = 0; j < code_length; ++j) {
+      if (code.calc_bit(j)) {
+        sm += tmat.data<T>()[i * o_width + j];
+      }
+    }
+    sum->data<T>()[i] = scale_sum * sm;
+  }
+}
+
+template <typename T>
+void MatrixBitCodeFunctor<T>::Mul(framework::Tensor* tmat,
+                                  const framework::Tensor& weight,
+                                  const framework::Tensor& input) {
+  SimpleCodeTable code_table(num_classes_);
+  size_t num_samples = tmat->dims()[0];
+  size_t tmat_width = tmat->dims()[1];
+  size_t input_width = input.dims()[1];
+  size_t weight_width = weight.dims()[1];
+  auto tmat_value = tmat->data<T>();
+  auto weight_value = weight.data<T>();
+  auto input_value = input.data<T>();
+  for (size_t i = 0; i < num_samples; ++i) {
+    auto code = code_table(static_cast<size_t>(ids_[i]));
+    int code_length = code.get_length();
+    for (int j = 0; j < code_length; ++j) {
+      size_t index = code.calc_index(j);
+      T sum = static_cast<T>(0.0);
+      for (size_t k = 0; k < input_width; ++k) {
+        sum += weight_value[weight_width * index + k] *
+               input_value[input_width * i + k];
+      }
+      tmat_value[i * tmat_width + j] += sum;
+    }
+  }
+}
+
+template <typename T>
+void MatrixBitCodeFunctor<T>::MulGradWeight(const framework::Tensor& tmat,
+                                            framework::Tensor* weight,
+                                            const framework::Tensor& input) {
+  SimpleCodeTable code_table(num_classes_);
+  size_t num_samples = tmat.dims()[0];
+  size_t input_width = input.dims()[1];
+  size_t tmat_width = tmat.dims()[1];
+  size_t weight_width = weight->dims()[1];
+  auto tmat_value = tmat.data<T>();
+  auto weight_value = weight->data<T>();
+  auto input_value = input.data<T>();
+  for (size_t i = 0; i < num_samples; ++i) {
+    auto code = code_table(static_cast<size_t>(ids_[i]));
+    int code_length = code.get_length();
+    for (int j = 0; j < code_length; ++j) {
+      size_t index = code.calc_index(j);
+
+      for (size_t k = 0; k < input_width; ++k) {
+        weight_value[weight_width * index + k] +=
+            tmat_value[i * tmat_width + j] * input_value[input_width * i + k];
+      }
+    }
+  }
+}
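To make the indexing in Mul and MulGradWeight concrete, here is a small numpy restatement (illustrative only; `calc_index` and the code-length computation mirror the SimpleCode struct defined in the header below):

    import numpy as np

    def calc_index(word_id, num_classes, j):
        # Python mirror of SimpleCode::calc_index for the j-th path node.
        return ((word_id + num_classes) >> (j + 1)) - 1

    def mul(pre_out, w, x, ids, num_classes):
        # MatrixBitCodeFunctor::Mul -- pre_out[i, j] += w[idx(i, j)] . x[i]
        for i, word_id in enumerate(ids):
            code_length = int(word_id + num_classes).bit_length() - 1
            for j in range(code_length):
                pre_out[i, j] += w[calc_index(word_id, num_classes, j)].dot(x[i])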
+
+template <typename T>
+void MatrixBitCodeFunctor<T>::MulGradError(const framework::Tensor& tmat,
+                                           const framework::Tensor& weight,
+                                           framework::Tensor* input) {
+  SimpleCodeTable code_table(num_classes_);
+  size_t num_samples = tmat.dims()[0];
+  size_t tmat_width = tmat.dims()[1];
+  size_t input_width = input->dims()[1];
+  size_t weight_width = weight.dims()[1];
+  auto tmat_value = tmat.data<T>();
+  auto weight_value = weight.data<T>();
+  auto input_value = input->data<T>();
+
+  for (size_t i = 0; i < num_samples; ++i) {
+    auto code = code_table(static_cast<size_t>(ids_[i]));
+    int code_length = code.get_length();
+    for (int j = 0; j < code_length; ++j) {
+      size_t index = code.calc_index(j);
+
+      for (size_t k = 0; k < input_width; ++k) {
+        input_value[input_width * i + k] +=
+            tmat_value[i * tmat_width + j] *
+            weight_value[weight_width * index + k];
+      }
+    }
+  }
+}
+
+template <typename T>
+void MatrixBitCodeFunctor<T>::Sub(framework::Tensor* tmat) {
+  SimpleCodeTable code_table(num_classes_);
+  size_t num_samples = tmat->dims()[0];
+  size_t o_width = tmat->dims()[1];
+  for (size_t i = 0; i < num_samples; ++i) {
+    auto code = code_table(static_cast<size_t>(ids_[i]));
+    int code_length = code.get_length();
+    for (int j = 0; j < code_length; ++j) {
+      if (code.calc_bit(j)) {
+        tmat->data<T>()[i * o_width + j] -= 1;
+      }
+    }
+  }
+}
+
+template <typename T>
+void MatrixBitCodeFunctor<T>::OutGrad(framework::Tensor* tmat,
+                                      const framework::Tensor& input) {
+  size_t num_samples = tmat->dims()[0];
+  size_t code_length = tmat->dims()[1];
+  for (size_t i = 0; i < num_samples; ++i) {
+    for (size_t j = 0; j < code_length; ++j) {
+      tmat->data<T>()[i * code_length + j] = input.data<T>()[i];
+    }
+  }
+}
+
+template class MatrixBitCodeFunctor<float>;
+template class MatrixBitCodeFunctor<double>;
+
+}  // namespace math
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/fluid/operators/math/matrix_bit_code.h b/paddle/fluid/operators/math/matrix_bit_code.h
new file mode 100644
index 0000000000000000000000000000000000000000..43820810e1a80cb4764d802c7c355a8fcc4ab053
--- /dev/null
+++ b/paddle/fluid/operators/math/matrix_bit_code.h
@@ -0,0 +1,113 @@
+/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include <cstddef>
+#include <type_traits>
+#include "paddle/fluid/framework/eigen.h"
+#include "paddle/fluid/framework/tensor.h"
+#include "paddle/fluid/platform/device_context.h"
+
+namespace paddle {
+namespace operators {
+namespace math {
+
+/**
+ * Return the 1-based index of the highest bit set.
+ *
+ * For x > 0:
+ * \f[
+ *    FindLastSet(x) = 1 + \lfloor \log_{2}x \rfloor
+ * \f]
+ */
+inline constexpr size_t FindLastSet(size_t x) {
+  return std::is_same<size_t, unsigned int>::value
+             ? (x ? 8 * sizeof(x) - __builtin_clz(x) : 0)
+             : (std::is_same<size_t, unsigned long>::value  // NOLINT
+                    ? (x ? 8 * sizeof(x) - __builtin_clzl(x) : 0)
+                    : (x ? 8 * sizeof(x) - __builtin_clzll(x) : 0));
+}
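FindLastSet behaves like Python's int.bit_length(), a handy oracle when writing tests (my observation, not part of the patch):

    assert (0).bit_length() == 0   # FindLastSet(0) == 0
    assert (5).bit_length() == 3   # FindLastSet(5) == 3, so 6 classes give code_length 3
    assert (8).bit_length() == 4   # FindLastSet(8) == 4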
+
+struct SimpleCode {
+  SimpleCode(size_t code, size_t num_classes) : c_(code + num_classes) {}
+  inline size_t calc_index(int bit) const { return (c_ >> (bit + 1)) - 1; }
+  inline bool calc_bit(int bit) const { return c_ & (1 << bit); }
+  inline int get_length() const { return FindLastSet(c_) - 1; }
+
+ private:
+  size_t c_;
+};
+
+struct SimpleCodeTable {
+  explicit SimpleCodeTable(size_t num_classes) : num_classes_(num_classes) {}
+  SimpleCode operator()(size_t code) const {
+    return SimpleCode(code, num_classes_);
+  }
+  size_t size() const { return num_classes_; }
+  int get_max_code_length() const { return FindLastSet(num_classes_ - 1); }
+
+ private:
+  size_t num_classes_;
+};
+
+template <typename T>
+class MatrixBitCodeFunctor {
+ public:
+  explicit MatrixBitCodeFunctor(size_t num_classes, const int64_t* ids)
+      : num_classes_(num_classes), ids_(ids) {}
+  /* For j < code_length:
+       tmat(i, j) += vec(0, index(i, j))
+   */
+  void Add(framework::Tensor* tmat, const framework::Tensor& vec);
+
+  /* For j < code_length:
+       vec(0, index(i, j)) += tmat(i, j)
+   */
+  void AddGrad(const framework::Tensor& tmat, framework::Tensor* vec);
+
+  /* For j < code_length:
+       sum(i, 0) = scale_sum * \sum_j bit(i, j) * tmat(i, j)
+   */
+  void Sum(const framework::Tensor& tmat, framework::Tensor* sum, T scale_sum);
+
+  /* For j < code_length:
+       tmat(i, j) -= bit(i, j)
+   */
+  void Sub(framework::Tensor* tmat);
+  /* For j < code_length:
+       tmat(i, j) += weight.row(index(i, j)) * input.row(i)
+   */
+  void Mul(framework::Tensor* tmat, const framework::Tensor& weight,
+           const framework::Tensor& input);
+
+  /* For j < code_length:
+       weight.row(index(i, j)) += tmat(i, j) * input.row(i)
+   */
+  void MulGradWeight(const framework::Tensor& tmat, framework::Tensor* weight,
+                     const framework::Tensor& input);
+  /* For j < code_length:
+       input.row(i) += tmat(i, j) * weight.row(index(i, j))
+   */
+  void MulGradError(const framework::Tensor& tmat,
+                    const framework::Tensor& weight, framework::Tensor* input);
+  /* For j < code_length:
+       tmat(i, j) = input(i)
+   */
+  void OutGrad(framework::Tensor* tmat, const framework::Tensor& input);
+
+  size_t num_classes_;
+  const int64_t* ids_;
+};
+}  // namespace math
+}  // namespace operators
+}  // namespace paddle
diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py
index 3bb9fd04038b2a0ee98cbfb17b5c0ea0b8b06789..70858d477ffdfc35b9c9d9f3daeb7fef1c5d9492 100644
--- a/python/paddle/fluid/layers/nn.py
+++ b/python/paddle/fluid/layers/nn.py
@@ -66,6 +66,7 @@ __all__ = [
     'transpose',
     'im2sequence',
     'nce',
+    'hsigmoid',
     'beam_search',
     'row_conv',
     'multiplex',
@@ -2986,6 +2987,78 @@ def nce(input,
     return cost / (num_neg_samples + 1)
 
 
+def hsigmoid(input, label, num_classes=2, param_attr=None, bias_attr=None):
+    """
+    The hierarchical sigmoid operator is used to accelerate the training
+    process of language models. It organizes the classes into a complete
+    binary tree: each leaf node represents a class (a word) and each
+    internal node acts as a binary classifier. For each word there is a
+    unique path from the root to its leaf node; hsigmoid calculates the
+    cost at each internal node on that path (including the root) and sums
+    them to obtain the total cost. This reduces the complexity from O(N)
+    to O(log N), where N is the size of the word dictionary. The idea is
+    from "F. Morin, Y. Bengio (AISTATS 05): Hierarchical Probabilistic
+    Neural Network Language Model".
+
+    Args:
+        input (Variable): The input tensor with shape [N, D], where N is
+            the size of the mini-batch and D is the embedding size.
+        label (Variable): The labels of the training data, a tensor with
+            shape [1, N].
+        num_classes (int, default 2): The number of classes, must be no
+            less than 2.
+        param_attr (ParamAttr|list of ParamAttr, default None): The
+            parameter attribute for the learnable weights of this layer.
+        bias_attr (ParamAttr|list of ParamAttr, default None): The
+            parameter attribute for the bias of this layer. If it is set
+            to None, no bias will be added to the output units.
+
+    Returns:
+        Variable: The cost of the hierarchical sigmoid operator, with
+            shape [N, 1].
+
+    Examples:
+
+        .. code-block:: python
+
+            x = fluid.layers.data(name='x', shape=[3, 2],
+                                  dtype='float32')
+            y = fluid.layers.data(name='y', shape=[1, 3],
+                                  dtype='int64')
+            out = fluid.layers.hsigmoid(input=x, label=y, num_classes=2)
+    """
+
+    helper = LayerHelper('hierarchical_sigmoid', **locals())
+    dtype = helper.input_dtype()
+    out = helper.create_tmp_variable(dtype)
+    pre_out = helper.create_tmp_variable(dtype)
+    dim = input.shape[1]
+    if num_classes < 2:
+        raise ValueError("num_classes must be no less than 2.")
+    if input.shape[0] != label.shape[1]:
+        raise ValueError(
+            "input's first dimension and label's second dimension must "
+            "both equal the batch size.")
+    weights = helper.create_parameter(
+        attr=helper.param_attr,
+        shape=[num_classes - 1, dim],
+        is_bias=False,
+        dtype=input.dtype)
+    bias = helper.create_parameter(
+        attr=helper.bias_attr,
+        shape=[1, num_classes - 1],
+        is_bias=True,
+        dtype=input.dtype)
+
+    helper.append_op(
+        type="hierarchical_sigmoid",
+        inputs={"X": input,
+                "W": weights,
+                "Ids": label,
+                "Bias": bias},
+        outputs={"Out": out,
+                 "PreOut": pre_out},
+        attrs={"num_classes": num_classes})
+    return out
+
+
 def transpose(x, perm, name=None):
     """
     **transpose Layer**
@@ -4009,8 +4082,7 @@ def random_crop(input, shape, seed=1):
             attrs={
                 "dtype": seed.dtype,
                 "shape": [1],
-                "value": float(seed_value),
-                "force_cpu": True
+                "value": float(seed_value)
             })
     elif not isinstance(seed, Variable):
         raise ValueError("'seed' must be a Variable or an int.")
diff --git a/python/paddle/fluid/tests/unittests/test_hsigmoid_op.py b/python/paddle/fluid/tests/unittests/test_hsigmoid_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..178f56aeb81ebe575a940d2efc3d2d6ef762ed6b
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_hsigmoid_op.py
@@ -0,0 +1,109 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+import numpy as np
+from op_test import OpTest
+import math
+
+
+def find_last_set(num):
+    # 1-based index of the highest set bit, matching math::FindLastSet.
+    return 1 + int(math.floor(math.log(num, 2)))
+
+
+class CodeTable(object):
+    def __init__(self, num_classes, code):
+        self.c = num_classes + code
+
+    def cal_index(self, bit):
+        return (self.c >> (bit + 1)) - 1
+
+    def get_length(self):
+        return find_last_set(self.c) - 1
+
+    def cal_bit(self, bit):
+        return self.c & (1 << bit)
+
+
+def hsigmoid(x, w, ids, bias, num_classes):
+    # pre_output has shape [batch_size, code_length], where
+    # code_length = find_last_set(num_classes - 1).
+    batch_size = x.shape[0]
+    code_length = find_last_set(num_classes - 1)
+    pre_output = np.zeros((batch_size, code_length))
+    pre_sum = np.zeros((batch_size, 1))
+    out = np.zeros((batch_size, 1)).astype("float32")
+    # pre_output += code(bias)
+    for i in xrange(batch_size):
+        code_table = CodeTable(num_classes, ids[i])
+        length = code_table.get_length()
+        for j in xrange(length):
+            idx = code_table.cal_index(j)
+            pre_output[i][j] += bias[0][idx]
+    # pre_output += code(w) * x
+    for i in xrange(batch_size):
+        code_table = CodeTable(num_classes, ids[i])
+        length = code_table.get_length()
+        for j in xrange(length):
+            idx = code_table.cal_index(j)
+            s = 0.0
+            for l in xrange(x.shape[1]):
+                s += w[idx][l] * x[i][l]
+            pre_output[i][j] += s
+    # clip to [-40.0, 40.0]
+    pre_output = np.clip(pre_output, -40.0, 40.0)
+    # out(i, 0) = -1 * \sum_j bit(i, j) * pre_output(i, j)
+    for i in xrange(batch_size):
+        code_table = CodeTable(num_classes, ids[i])
+        length = code_table.get_length()
+        s = 0.0
+        for j in xrange(length):
+            if code_table.cal_bit(j):
+                s += pre_output[i][j]
+        out[i] = -1.0 * s
+    # soft relu
+    pre_output = np.clip(pre_output, -40.0, 40.0)
+    pre_output = np.log(1 + np.exp(pre_output))
+    pre_sum = pre_output.sum(1).reshape((batch_size, 1))
+    out += pre_sum
+    return out
+
+
+class TestHSigmoidOp(OpTest):
+    def setUp(self):
+        self.op_type = "hierarchical_sigmoid"
+        num_classes = 6
+        embedding_size = 10
+        batch_size = 5
+        x = np.random.random((batch_size, embedding_size)).astype("float32")
+        w = np.random.random(
+            (num_classes - 1, embedding_size)).astype("float32")
+        ids = np.random.randint(0, num_classes, batch_size).astype("int64")
+        bias = np.random.random((1, num_classes - 1)).astype("float32")
+        self.inputs = {'X': x, 'W': w, 'Ids': ids, 'Bias': bias}
+        self.attrs = {'num_classes': num_classes}
+        out = hsigmoid(x, w, ids, bias, num_classes)
+        self.outputs = {'Out': out}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['X', 'W', 'Bias'], 'Out', no_grad_set=set('Ids'))
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py
index 60dc1f83fc32e2551eb2a04ef35f1c8a0ffec769..5ac518478804350f46ad8bccb40e8a8d043d07f8 100644
--- a/python/paddle/fluid/tests/unittests/test_layers.py
+++ b/python/paddle/fluid/tests/unittests/test_layers.py
@@ -173,6 +173,16 @@ class TestBook(unittest.TestCase):
                 x=dat, label=lbl))
         print(str(program))
 
+    def test_hsigmoid(self):
+        program = Program()
+        with program_guard(program):
+            x = layers.data(name='x', shape=[2, 2], dtype='float32')
+            y = layers.data(name='y', shape=[1, 3], dtype='int64')
+            self.assertIsNotNone(
+                layers.hsigmoid(
+                    input=x, label=y, num_classes=2))
+        print(str(program))
+
     def test_sequence_expand(self):
         program = Program()
         with program_guard(program):
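As a final sanity check on the math (a standalone sketch, not part of the patch), finite differences on the per-sample cost confirm d cost / d z_j = sigmoid(z_j) - bit_j, which is the gradient the C++ kernel assembles via the softrelu-derivative factor and bit_code.Sub:

    import numpy as np

    def cost(z, bits):
        # per-sample hsigmoid cost given path pre-activations z and path bits
        return np.sum(np.log1p(np.exp(z)) - bits * z)

    z = np.array([0.3, -1.2, 2.0])
    bits = np.array([1.0, 0.0, 1.0])
    eps = 1e-6
    num = np.array([(cost(z + eps * e, bits) - cost(z - eps * e, bits)) / (2 * eps)
                    for e in np.eye(3)])
    assert np.allclose(num, 1.0 / (1.0 + np.exp(-z)) - bits, atol=1e-4)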