未验证 提交 da3f7668 编写于 作者: G Guo Sheng 提交者: GitHub

Merge pull request #12088 from guoshengCS/complete-hsigmoid

Complete hsigmoid_op
......@@ -259,6 +259,7 @@ op_library(max_sequence_len_op DEPS lod_rank_table)
op_library(sequence_conv_op DEPS context_project)
op_library(sequence_pool_op DEPS sequence_pooling)
op_library(lstm_op DEPS sequence2batch lstm_compute)
op_library(hierarchical_sigmoid_op DEPS matrix_bit_code)
op_library(lstmp_op DEPS sequence2batch lstm_compute)
op_library(gru_op DEPS sequence2batch gru_compute)
op_library(recurrent_op DEPS executor)
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/hierarchical_sigmoid_op.h"
#include <vector>
namespace paddle {
namespace operators {
/**
* Organize the classes into a binary tree. At each node, a sigmoid function
* is used to calculate the probability of belonging to the right branch.
* This idea is from "F. Morin, Y. Bengio (AISTATS 05):
* Hierarchical Probabilistic Neural Network Language Model."
*
* Here we uses a simple way of making the binary tree.
* Assuming the number of classes C = 6,
* The classes are organized as a binary tree in the following way:
*
* @code{.py}
* *-*-*- 2
* | | |- 3
* | |
* | |-*- 4
* | |- 5
* |
* |-*- 0
* |- 1
* @endcode
*
* where * indicates an internal node, and each leaf node represents a class.
* - Node 0 ... C-2 are internal nodes.
* - Node C-1 ... 2C-2 are leaf nodes.
* - Class c is represented by leaf node \f$c+C-1\f$.
*
* We assign an id for each node:
* - the id of root be 0.
* - the left child of a node i is 2*i+1.
* - the right child of a node i is 2*i+2.
*
* It's easy to see that:
* - the parent of node i is \f$\left\lfloor(i-1)/2\right\rfloor\f$.
* - the j-th level ancestor of node i is
* \f$\left\lfloor(i+1)/2^{j+1}\right\rfloor - 1\f$.
* - A node i is a left child of its parent if \f$(i-1)\%2==0\f$.
*
*/
class HierarchicalSigmoidOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should not be null.");
PADDLE_ENFORCE(ctx->HasInput("W"), "Input(W) should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("PreOut"),
"Output(PreOut) should not be null.");
const int64_t batch_size = ctx->GetInputDim("X")[0];
std::vector<int64_t> output_shape({batch_size, 1});
ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
ctx.GetPlace());
}
};
template <typename AttrType>
class HierarchicalSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"(Tensor, required) The input tensor with shape [N, D], "
"where N is the size of mini-batch, and D is the feature size.");
AddInput("W",
"(Tensor, required), The parameters of hierarchical "
"sigmoid operator, each of them is a 2-D tensor, the shape is"
"[num_classes - 1, D].");
AddInput("Label",
"(Tensor, required), The labels of training data. It's a"
"tensor with shape [N, 1].");
AddInput("Bias",
"(Tensor, optional), The bias is a tensor with shape"
"[1, num_classes - 1].");
AddOutput("Out",
"(Tensor, required) The output of hierarchical sigmoid operator."
"The shape is [N, 1].");
AddOutput("PreOut",
"(Tensor, required) A intermedia 2-D tensor with shape "
"[batch_size, code_length], where code_length represents the "
"maximum path length from root to leaf nodes.")
.AsIntermediate();
AddAttr<AttrType>("num_classes", "(int, required), The number of classes")
.SetDefault(2);
AddComment(R"DOC(
The hierarchical sigmoid operator organize the classes into a binary tree.
At each node, a sigmoid function is used to calculate the probability of
belonging to the right branch. This idea is from
"F. Morin, Y. Bengio (AISTATS 05):
Hierarchical Probabilistic Neural Network Language Model."
)DOC");
}
};
class HierarchicalSigmoidGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("W"), "Input(W) should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should not be null.");
PADDLE_ENFORCE(ctx->HasInput("PreOut"),
"Input(Preout) should not be null.");
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("W")),
"Output(W@Grad should not be null.)");
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")));
if (ctx->HasOutput(framework::GradVarName("Bias"))) {
ctx->SetOutputDim(framework::GradVarName("Bias"),
ctx->GetInputDim("Bias"));
}
ctx->SetOutputDim(framework::GradVarName("W"), ctx->GetInputDim("W"));
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
ctx.GetPlace());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(hierarchical_sigmoid, ops::HierarchicalSigmoidOp,
ops::HierarchicalSigmoidOpMaker<int>,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(hierarchical_sigmoid_grad, ops::HierarchicalSigmoidGradOp);
REGISTER_OP_CPU_KERNEL(
hierarchical_sigmoid,
ops::HierarchicalSigmoidOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::HierarchicalSigmoidOpKernel<paddle::platform::CPUDeviceContext,
double>);
REGISTER_OP_CPU_KERNEL(
hierarchical_sigmoid_grad,
ops::HierarchicalSigmoidGradOpKernel<paddle::platform::CPUDeviceContext,
float>,
ops::HierarchicalSigmoidGradOpKernel<paddle::platform::CPUDeviceContext,
double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <iostream>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/clip_op.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/matrix_bit_code.h"
#include "paddle/fluid/platform/transform.h"
namespace paddle {
namespace operators {
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
using platform::Transform;
template <typename DeviceContext, typename T>
class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in = ctx.Input<framework::Tensor>("X");
auto* w = ctx.Input<framework::Tensor>("W");
auto* label = ctx.Input<framework::Tensor>("Label");
auto* bias = ctx.Input<framework::Tensor>("Bias");
auto* out = ctx.Output<framework::Tensor>("Out");
auto* pre_out = ctx.Output<framework::Tensor>("PreOut");
size_t num_classes = static_cast<size_t>(ctx.Attr<int>("num_classes"));
int64_t code_length = math::FindLastSet(num_classes - 1);
int64_t batch_size = in->dims()[0];
framework::Tensor sum;
auto& dev_ctx = ctx.template device_context<DeviceContext>();
auto* pre_out_data = pre_out->mutable_data<T>(
framework::make_ddim({batch_size, code_length}), ctx.GetPlace());
auto pre_out_mat = EigenMatrix<T>::From(*pre_out);
// Not all class(leaf) nodes' path lengths equal code_length, thus init as
// 0s can avoid out of path's loss.
math::SetConstant<DeviceContext, T> zero;
zero(dev_ctx, pre_out, static_cast<T>(0.0));
auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
math::RowwiseSum<DeviceContext, T> row_sum;
math::MatrixBitCodeFunctor<T> bit_code(num_classes, label->data<int64_t>());
std::vector<int64_t> sum_dims({batch_size, 1UL});
sum.mutable_data<T>(framework::make_ddim(sum_dims), ctx.GetPlace());
auto sum_mat = EigenMatrix<T>::From(sum);
out->mutable_data<T>(ctx.GetPlace());
auto out_mat = framework::EigenVector<T>::Flatten(*out);
if (bias) {
bit_code.Add(pre_out, *bias);
}
bit_code.Mul(pre_out, *w, *in);
// clip to [-40, 40]
Transform<DeviceContext> trans;
trans(ctx.template device_context<DeviceContext>(), pre_out_data,
pre_out_data + pre_out->numel(), pre_out_data,
ClipFunctor<T>(static_cast<T>(-40.0), static_cast<T>(40.0)));
bit_code.Sum(*pre_out, out, static_cast<T>(-1));
// use softrelu to calculate cross entropy
pre_out_mat.device(place) = (static_cast<T>(1.0) + pre_out_mat.exp()).log();
row_sum(dev_ctx, *pre_out, &sum);
// TODO(guosheng): Subtract the out of path's loss, since not all
// class(leaf) nodes' path lengths equal code_length. But it won't break the
// gradient check since both have the out of path's loss and will cancel out
// each other.
out_mat.device(place) = sum_mat + out_mat;
}
};
template <typename DeviceContext, typename T>
class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in = ctx.Input<framework::Tensor>("X");
auto* w = ctx.Input<framework::Tensor>("W");
auto* in_grad = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
auto* w_grad = ctx.Output<framework::Tensor>(framework::GradVarName("W"));
auto* bias_grad =
ctx.Output<framework::Tensor>(framework::GradVarName("Bias"));
auto* label = ctx.Input<framework::Tensor>("Label");
auto* pre_out = ctx.Input<framework::Tensor>("PreOut");
auto* out_grad =
ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
framework::Tensor pre_out_grad;
pre_out_grad.mutable_data<T>(pre_out->dims(), ctx.GetPlace());
in_grad->mutable_data<T>(ctx.GetPlace());
w_grad->mutable_data<T>(ctx.GetPlace());
auto& dev_ctx = ctx.template device_context<DeviceContext>();
math::SetConstant<DeviceContext, T> zero;
zero(dev_ctx, in_grad, static_cast<T>(0.0));
zero(dev_ctx, w_grad, static_cast<T>(0.0));
size_t num_classes = static_cast<size_t>(ctx.Attr<int>("num_classes"));
math::MatrixBitCodeFunctor<T> bit_code(num_classes, label->data<int64_t>());
auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
auto pre_out_mat = EigenMatrix<T>::From(*pre_out);
auto pre_out_grad_mat = EigenMatrix<T>::From(pre_out_grad);
auto out_grad_mat = EigenMatrix<T>::From(*out_grad);
Eigen::array<int, 2> bcast({{1, static_cast<int>(pre_out_grad.dims()[1])}});
// softrelu derivative
pre_out_grad_mat.device(place) =
static_cast<T>(1.0) - static_cast<T>(1.0) / pre_out_mat.exp();
bit_code.Sub(&pre_out_grad); // the gradient of clip(w * x + b)
pre_out_grad_mat.device(place) =
pre_out_grad_mat * out_grad_mat.broadcast(bcast);
// TODO(guosheng): multiply pre_out_grad with subgradient of clipping to
// be consistent with the clipping in forward.
if (bias_grad) {
bias_grad->mutable_data<T>(ctx.GetPlace());
zero(dev_ctx, bias_grad, static_cast<T>(0.0));
bit_code.AddGrad(pre_out_grad, bias_grad);
}
bit_code.MulGradWeight(pre_out_grad, w_grad, *in);
bit_code.MulGradError(pre_out_grad, *w, in_grad);
}
};
} // namespace operators
} // namespace paddle
......@@ -51,6 +51,7 @@ math_library(sequence_padding)
math_library(sequence_pooling DEPS math_function)
math_library(sequence_scale)
math_library(softmax DEPS math_function)
math_library(matrix_bit_code)
math_library(unpooling)
math_library(vol2col)
......
......@@ -155,7 +155,7 @@ class RowwiseSum<platform::CPUDeviceContext, T> {
PADDLE_ENFORCE_EQ(in_dims.size(), 2U);
auto height = in_dims[0];
auto size = in_dims[1];
PADDLE_ENFORCE_EQ(out->numel(), size);
PADDLE_ENFORCE_EQ(out->numel(), height);
T* out_buf = out->mutable_data<T>(out->place());
const T* in_buf = input.data<T>();
......
/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/matrix_bit_code.h"
#include <iostream>
namespace paddle {
namespace operators {
namespace math {
template <typename T>
void MatrixBitCodeFunctor<T>::Add(framework::Tensor* tmat,
const framework::Tensor& vec) {
SimpleCodeTable code_table(num_classes_);
size_t batch_size = tmat->dims()[0];
size_t width = tmat->dims()[1];
for (size_t i = 0; i < batch_size; ++i) {
auto code = code_table(static_cast<size_t>(ids_[i]));
int code_length = code.get_length();
for (int j = 0; j < code_length; ++j) {
size_t index = code.calc_index(j);
tmat->data<T>()[i * width + j] += vec.data<T>()[index];
}
}
}
template <typename T>
void MatrixBitCodeFunctor<T>::AddGrad(const framework::Tensor& tmat,
framework::Tensor* vec) {
SimpleCodeTable code_table(num_classes_);
size_t batch_size = tmat.dims()[0];
size_t width = tmat.dims()[1];
for (size_t i = 0; i < batch_size; ++i) {
auto code = code_table(static_cast<size_t>(ids_[i]));
int code_length = code.get_length();
for (int j = 0; j < code_length; ++j) {
size_t index = code.calc_index(j);
vec->data<T>()[index] += tmat.data<T>()[i * width + j];
}
}
}
template <typename T>
void MatrixBitCodeFunctor<T>::Sum(const framework::Tensor& tmat,
framework::Tensor* sum, T scale_sum) {
SimpleCodeTable code_table(num_classes_);
size_t num_samples = tmat.dims()[0];
size_t o_width = tmat.dims()[1];
for (size_t i = 0; i < num_samples; ++i) {
T sm = static_cast<T>(0.0);
auto code = code_table(static_cast<size_t>(ids_[i]));
int code_length = code.get_length();
for (int j = 0; j < code_length; ++j) {
if (code.calc_bit(j)) {
// calc_bit starts from right most bit, while data in tmat[i] is in the
// reverse order.
sm += tmat.data<T>()[i * o_width + j];
}
}
sum->data<T>()[i] = scale_sum * sm;
}
}
template <typename T>
void MatrixBitCodeFunctor<T>::Mul(framework::Tensor* tmat,
const framework::Tensor& weight,
const framework::Tensor& input) {
SimpleCodeTable code_table(num_classes_);
size_t num_samples = tmat->dims()[0];
size_t tmat_width = tmat->dims()[1];
size_t input_width = input.dims()[1];
size_t weight_width = weight.dims()[1];
auto tmat_value = tmat->data<T>();
auto weight_value = weight.data<T>();
auto input_value = input.data<T>();
for (size_t i = 0; i < num_samples; ++i) {
auto code = code_table(static_cast<size_t>(ids_[i]));
int code_length = code.get_length();
for (int j = 0; j < code_length; ++j) {
size_t index = code.calc_index(j);
T sum = static_cast<T>(0.0);
for (size_t k = 0; k < input_width; ++k) {
sum += weight_value[weight_width * index + k] *
input_value[input_width * i + k];
}
tmat_value[i * tmat_width + j] += sum;
}
}
}
template <typename T>
void MatrixBitCodeFunctor<T>::MulGradWeight(const framework::Tensor& tmat,
framework::Tensor* weight,
const framework::Tensor& input) {
SimpleCodeTable code_table(num_classes_);
size_t num_samples = tmat.dims()[0];
size_t input_width = input.dims()[1];
size_t tmat_width = tmat.dims()[1];
size_t weight_width = weight->dims()[1];
auto tmat_value = tmat.data<T>();
auto weight_value = weight->data<T>();
auto input_value = input.data<T>();
for (size_t i = 0; i < num_samples; ++i) {
auto code = code_table(static_cast<size_t>(ids_[i]));
int code_length = code.get_length();
for (int j = 0; j < code_length; ++j) {
size_t index = code.calc_index(j);
for (size_t k = 0; k < input_width; ++k) {
weight_value[weight_width * index + k] +=
tmat_value[i * tmat_width + j] * input_value[input_width * i + k];
}
}
}
}
template <typename T>
void MatrixBitCodeFunctor<T>::MulGradError(const framework::Tensor& tmat,
const framework::Tensor& weight,
framework::Tensor* input) {
SimpleCodeTable code_table(num_classes_);
size_t num_samples = tmat.dims()[0];
size_t tmat_width = tmat.dims()[1];
size_t input_width = input->dims()[1];
size_t weight_width = weight.dims()[1];
auto tmat_value = tmat.data<T>();
auto weight_value = weight.data<T>();
auto input_value = input->data<T>();
for (size_t i = 0; i < num_samples; ++i) {
auto code = code_table(static_cast<size_t>(ids_[i]));
int code_length = code.get_length();
for (int j = 0; j < code_length; ++j) {
size_t index = code.calc_index(j);
for (size_t k = 0; k < input_width; ++k) {
input_value[input_width * i + k] +=
tmat_value[i * tmat_width + j] *
weight_value[weight_width * index + k];
}
}
}
}
template <typename T>
void MatrixBitCodeFunctor<T>::Sub(framework::Tensor* tmat) {
SimpleCodeTable code_table(num_classes_);
size_t num_samples = tmat->dims()[0];
size_t o_width = tmat->dims()[1];
for (size_t i = 0; i < num_samples; ++i) {
auto code = code_table(static_cast<size_t>(ids_[i]));
int code_length = code.get_length();
for (int j = 0; j < code_length; ++j) {
if (code.calc_bit(j)) {
tmat->data<T>()[i * o_width + j] -= 1;
}
}
}
}
template class MatrixBitCodeFunctor<float>;
template class MatrixBitCodeFunctor<double>;
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/device_context.h"
namespace paddle {
namespace operators {
namespace math {
/**
* SimpleCodeTable class should support 3 functions:
*
* size_t size()
* return the number of ids
*
* int get_max_code_length()
* return the maximal code length
*
* SimpleCode operator()(size_t i)
* return the i-th code. Code class is descriebed below.
*
* SimpleCode class should support 3 functions:
*
* int get_length()
* return the length of the code
*
* size_t cal_index(int bit)
* bit ranges from 0 to get_length() - 1
* return the index for the (1+bit) level parent
*
* bool calc_bit(int bit)
* return true if the bit level parent is the right child of (1+bit) level
* parent
*
*/
/**
* return the 1-based index of the highest bit set
*
* for x > 0:
* \f[
* FindLastSet(x) = 1 + \floor*{\log_{2}x}
* \f]
*/
inline constexpr size_t FindLastSet(size_t x) {
return std::is_same<size_t, unsigned int>::value
? (x ? 8 * sizeof(x) - __builtin_clz(x) : 0)
: (std::is_same<size_t, unsigned long>::value // NOLINT
? (x ? 8 * sizeof(x) - __builtin_clzl(x) : 0)
: (x ? 8 * sizeof(x) - __builtin_clzll(x) : 0));
}
struct SimpleCode {
SimpleCode(size_t code, size_t num_classes) : c_(code + num_classes) {}
/**
* Here the id of root shoud be 1 rather than 0, thus the encoding of class c
* is `c + num_classes` and all siblings can get the same weight indice using
* prefixes.
* Weight index is the prefixes of encoding, thus leave out the right most
* bit in calc_index.
* Binary classification path is the suffixes of encoding, thus leave out the
* left most bit in calc_bit.
*/
inline size_t calc_index(int bit) const { return (c_ >> (bit + 1)) - 1; }
inline bool calc_bit(int bit) const { return c_ & (1 << bit); }
inline int get_length() const { return FindLastSet(c_) - 1; }
private:
size_t c_;
};
struct SimpleCodeTable {
explicit SimpleCodeTable(size_t num_classes) : num_classes_(num_classes) {}
SimpleCode operator()(size_t code) const {
return SimpleCode(code, num_classes_);
}
size_t size() const { return num_classes_; }
int get_max_code_length() const { return FindLastSet(num_classes_ - 1); }
private:
size_t num_classes_;
};
template <typename T>
class MatrixBitCodeFunctor {
public:
explicit MatrixBitCodeFunctor(size_t num_classes, const int64_t* ids)
: num_classes_(num_classes), ids_(ids) {}
/* For j < code_length
tmat(i, j) += vec(0, index(i, j))
*/
void Add(framework::Tensor* tmat, const framework::Tensor& vec);
/* For j < code_length
vec(0, index(i, j)) += tmat(i, j)
*/
void AddGrad(const framework::Tensor& tmat, framework::Tensor* vec);
/* For j < code_length
sum(i, 0) = \sum_j bit(i, j) * tmat(i, j)
*/
void Sum(const framework::Tensor& tmat, framework::Tensor* sum, T scale_sum);
/* For j < code_length
tmat(i, j) -= bit(i, j)
*/
void Sub(framework::Tensor* tmat);
/* For j < code_length
input.row(i) += tmat(i, j) * weight.row(index(i, j))
*/
void Mul(framework::Tensor* tmat, const framework::Tensor& weight,
const framework::Tensor& input);
/* For index(i, j) >= 0:
weight.row(index(i, j)) += tmat(i, j) * input.row(i)
*/
void MulGradWeight(const framework::Tensor& tmat, framework::Tensor* weight,
const framework::Tensor& input);
/* For j < code_length
input.row(i) += tmat(i, j) * weight.row(index(i, j))
*/
void MulGradError(const framework::Tensor& tmat,
const framework::Tensor& weight, framework::Tensor* input);
size_t num_classes_;
const int64_t* ids_;
};
} // namespace math
} // namespace operators
} // namespace paddle
......@@ -85,6 +85,7 @@ __all__ = [
'transpose',
'im2sequence',
'nce',
'hsigmoid',
'beam_search',
'row_conv',
'multiplex',
......@@ -3871,6 +3872,74 @@ def nce(input,
return cost / (num_neg_samples + 1)
def hsigmoid(input, label, num_classes, param_attr=None, bias_attr=None):
"""
The hierarchical sigmoid operator is used to accelerate the training
process of language model. This operator organizes the classes into a
complete binary tree, each leaf node represents a class(a word) and each
internal node acts as a binary classifier. For each word there's a unique
path from root to it's leaf node, hsigmoid calculate the cost for each
internal node on the path, and sum them to get a total cost. hsigmoid can
achive a acceleration from :math:`O(N)` to :math:`O(logN)`, where :math:`N`
represents the size of word dict.
Refer to `Hierarchical Probabilistic Neural Network Language Model
<http://www.iro.umontreal.ca/~lisa/pointeurs/hierarchical-nnlm-aistats05.pdf>`_
Args:
input (Variable): The input tensor variable with shape
:math:`[N \\times D]`, where :math:`N` is the size of mini-batch,
and :math:`D` is the feature size.
label (Variable): The tensor variable contains labels of training data.
It's a tensor with shape is :math:`[N \\times 1]`.
num_classes: (int), The number of classes, must not be less than 2.
param_attr (ParamAttr|list of ParamAttr, default None): The parameter
attribute for learnable parameters/weights of this layer.
bias_attr (ParamAttr|list of ParamAttr, default None): The parameter
attribute for the bias of this layer. If it is set to False, no
bias will be applied.
Returns:
Out: (Tensor) The cost of hierarchical sigmoid operator. the shape is [N, 1]
Examples:
.. code-block:: python
x = fluid.layers.data(name='x', shape=[2], dtype='float32')
y = fluid.layers.data(name='y', shape=[1], dtype='int64')
out = fluid.layers.hsigmoid(input=x, label=y, num_classes=6)
"""
helper = LayerHelper('hierarchical_sigmoid', **locals())
dtype = helper.input_dtype()
out = helper.create_tmp_variable(dtype)
pre_out = helper.create_tmp_variable(dtype)
dim = input.shape[1]
if num_classes < 2:
raise ValueError("num_classes must not be less than 2.")
weights = helper.create_parameter(
attr=helper.param_attr,
shape=[num_classes - 1, dim],
is_bias=False,
dtype=input.dtype)
inputs = {"X": input, "W": weights, "Label": label}
if helper.bias_attr:
bias = helper.create_parameter(
attr=helper.bias_attr,
shape=[1, num_classes - 1],
is_bias=True,
dtype=input.dtype)
inputs['Bias'] = bias
helper.append_op(
type="hierarchical_sigmoid",
inputs=inputs,
outputs={"Out": out,
"PreOut": pre_out},
attrs={"num_classes": num_classes})
return out
def transpose(x, perm, name=None):
"""
Permute the dimensions of `input` according to `perm`.
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import math
from op_test import OpTest
def find_latest_set(num):
return 1 + int(math.floor(math.log(num, 2)))
class CodeTable(object):
def __init__(self, num_classes, code):
self.c = num_classes + code
def cal_index(self, bit):
return (self.c >> (bit + 1)) - 1
def get_length(self):
return find_latest_set(self.c) - 1
def cal_bit(self, bit):
return self.c & (1 << bit)
def hsigmoid(x, w, label, bias, num_classes):
batch_size = x.shape[0]
code_length = find_latest_set(num_classes - 1)
code_table = [0 for _ in range(code_length)]
pre_output = np.zeros((batch_size, code_length))
pre_sum = np.zeros((batch_size, 1))
out = np.zeros((batch_size, 1)).astype("float32")
for i in range(batch_size):
code_table = CodeTable(num_classes, label[i])
length = code_table.get_length()
for j in range(length):
idx = code_table.cal_index(j)
pre_output[i][j] += bias[0][idx]
for i in range(batch_size):
code_table = CodeTable(num_classes, label[i])
length = code_table.get_length()
for j in range(length):
idx = code_table.cal_index(j)
pre_output[i][j] += np.dot(w[idx], x[i])
# clip[-40.0, 40.0]
pre_output = np.clip(pre_output, -40.0, 40.0)
# out(i, 0) = \sum_j bit(i, j) * preout(i, j)
for i in range(batch_size):
code_table = CodeTable(num_classes, label[i])
length = code_table.get_length()
sum = 0.0
for j in range(length):
if code_table.cal_bit(j):
sum += pre_output[i][j]
out[i] = -1.0 * sum
# soft relu
pre_output = np.log(1 + np.exp(pre_output))
pre_sum = pre_output.sum(1).reshape((batch_size, 1))
out += pre_sum
return pre_output, out
class TestHSigmoidOp(OpTest):
def setUp(self):
self.op_type = "hierarchical_sigmoid"
num_classes = 6
feature_size = 8
batch_size = 4
x = np.random.random((batch_size, feature_size)).astype("float32")
w = np.random.random((num_classes - 1, feature_size)).astype("float32")
label = np.random.randint(0, num_classes, (batch_size, 1))
bias = np.random.random((1, num_classes - 1)).astype("float32")
self.attrs = {'num_classes': num_classes}
self.inputs = {'X': x, 'W': w, 'Label': label, 'Bias': bias}
pre_output, out = hsigmoid(x, w, label, bias, num_classes)
self.outputs = {'PreOut': pre_output, 'Out': out}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['Bias', 'X', 'W'], ['Out'], no_grad_set=set('Label'))
if __name__ == '__main__':
unittest.main()
......@@ -174,6 +174,16 @@ class TestBook(unittest.TestCase):
x=dat, label=lbl))
print(str(program))
def test_hsigmoid(self):
program = Program()
with program_guard(program):
x = layers.data(name='x', shape=[2], dtype='float32')
y = layers.data(name='y', shape=[2], dtype='int64')
self.assertIsNotNone(
layers.hsigmoid(
input=x, label=y, num_classes=2))
print(str(program))
def test_sequence_expand(self):
program = Program()
with program_guard(program):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册