Commit 3e46ec41 authored by: W weixing02

add hsigmoid

......@@ -245,6 +245,7 @@ op_library(max_sequence_len_op DEPS lod_rank_table)
op_library(sequence_conv_op DEPS context_project)
op_library(sequence_pool_op DEPS sequence_pooling)
op_library(lstm_op DEPS sequence2batch lstm_compute)
op_library(hierarchical_sigmoid_op DEPS matrix_bit_code)
op_library(lstmp_op DEPS sequence2batch lstm_compute)
op_library(gru_op DEPS sequence2batch gru_compute)
op_library(recurrent_op DEPS executor)
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/hierarchical_sigmoid_op.h"
#include <vector>
namespace paddle {
namespace operators {
/**
* Organize the classes into a binary tree. At each node, a sigmoid function
* is used to calculate the probability of belonging to the right branch.
* This idea is from "F. Morin, Y. Bengio (AISTATS 05):
* Hierarchical Probabilistic Neural Network Language Model."
*
* Here we use a simple way of building the binary tree.
* Assuming the number of classes is C = 6,
* the classes are organized as a binary tree in the following way:
*
* @code{.py}
* *-*-*- 2
* | | |- 3
* | |
* | |-*- 4
* | |- 5
* |
* |-*- 0
* |- 1
* @endcode
*
* where * indicates an internal node, and each leaf node represents a class.
* - Nodes 0 ... C-2 are internal nodes.
* - Nodes C-1 ... 2C-2 are leaf nodes.
* - Class c is represented by leaf node \f$c+C-1\f$.
*
* We assign an id to each node:
* - the id of the root is 0.
* - the left child of a node i is 2*i+1.
* - the right child of a node i is 2*i+2.
*
* It's easy to see that:
* - the parent of node i is \f$\left\lfloor(i-1)/2\right\rfloor\f$.
* - the j-th level ancestor of node i (j = 0 gives the parent) is
* \f$\left\lfloor(i+1)/2^{j+1}\right\rfloor - 1\f$.
* - A node i is a left child of its parent if \f$(i-1)\%2==0\f$.
*
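* For example, with C = 6 and i = 4: the parent of node 4 is
* \f$\left\lfloor(4-1)/2\right\rfloor = 1\f$, its children are 2*4+1 = 9
* and 2*4+2 = 10, and it is a right child since (4-1)%2 == 1.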
*/
class HierarchicalSigmoidOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Ids"), "Input(Ids) should not be null.");
PADDLE_ENFORCE(ctx->HasInput("W"), "Input(W) should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("PreOut"),
"Output(PreOut) should not be null.");
const int64_t batch_size = ctx->GetInputDim("X")[0];
std::vector<int64_t> output_shape({batch_size, 1});
ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
ctx.GetPlace());
}
};
template <typename AttrType>
class HierarchicalSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"(Tensor, required) The input Tensor, which the shape is"
"[N * D], which N is the size of mini-batch,"
"D is the embded size");
AddInput("W",
"(Tensor, required), The parameters of hierarchical "
"sigmoid operator, each of them is s a 3-D tensor, the shape is"
"[num_classes - 1, D]");
AddInput("Ids",
"(Tensor, required), The labels of training data. It's a"
"1-D tensor, which the shape is [1, N]");
AddInput("Bias",
"(Tensor, optional), The bias is a 1-D tensor, "
"which is applied to the output, the shape is"
"[1, num_classes -1]");
AddOutput("Out",
"(Tensor, required) The output of hierarchical sigmoid operator."
"the shape is [N, 1]");
AddOutput("PreOut",
"(Tensor, required) A intermedia 2-D Tensor, which the shape is "
"[batch_size, code_length]")
.AsIntermediate();
AddAttr<AttrType>("num_classes", "(int, required), The number of classes")
.SetDefault(2);
AddComment(R"DOC(
The hierarchical sigmoid operator organizes the classes into a binary tree.
At each node, a sigmoid function is used to calculate the probability of
belonging to the right branch. This idea is from
"F. Morin, Y. Bengio (AISTATS 05):
Hierarchical Probabilistic Neural Network Language Model."
)DOC");
}
};
class HierarchicalSigmoidGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("W"), "Input(W) should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Ids"), "Input(Ids) should not be null.");
PADDLE_ENFORCE(ctx->HasInput("PreOut"),
"Input(Preout) should not be null.");
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("W")),
"Output(W@Grad should not be null.)");
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")));
if (ctx->HasOutput(framework::GradVarName("Bias"))) {
ctx->SetOutputDim(framework::GradVarName("Bias"),
ctx->GetInputDim("Bias"));
}
ctx->SetOutputDim(framework::GradVarName("W"), ctx->GetInputDim("W"));
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
ctx.GetPlace());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(hierarchical_sigmoid, ops::HierarchicalSigmoidOp,
ops::HierarchicalSigmoidOpMaker<int>,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(hierarchical_sigmoid_grad, ops::HierarchicalSigmoidGradOp);
REGISTER_OP_CPU_KERNEL(hierarchical_sigmoid,
ops::HierarchicalSigmoidOpKernel<
paddle::platform::CPUDeviceContext, float>);
REGISTER_OP_CPU_KERNEL(hierarchical_sigmoid_grad,
ops::HierarchicalSigmoidGradOpKernel<
paddle::platform::CPUDeviceContext, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <iostream>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/clip_op.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/matrix_bit_code.h"
#include "paddle/fluid/platform/transform.h"
namespace paddle {
namespace operators {
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
using platform::Transform;
template <typename DeviceContext, typename T>
class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in = ctx.Input<framework::Tensor>("X");
auto* w = ctx.Input<framework::Tensor>("W");
auto* ids = ctx.Input<framework::Tensor>("Ids");
auto* bias = ctx.Input<framework::Tensor>("Bias");
auto* out = ctx.Output<framework::Tensor>("Out");
auto* pre_out = ctx.Output<framework::Tensor>("PreOut");
size_t num_classes = static_cast<size_t>(ctx.Attr<int>("num_classes"));
int64_t code_length = math::FindLastSet(num_classes - 1);
int64_t batch_size = in->dims()[0];
framework::Tensor sum;
math::SetConstant<DeviceContext, T> zero;
auto& dev_ctx = ctx.template device_context<DeviceContext>();
auto pre_out_data = pre_out->mutable_data<T>(
framework::make_ddim({batch_size, code_length}), ctx.GetPlace());
auto pre_out_mat = EigenMatrix<T>::From(*pre_out);
zero(dev_ctx, pre_out, static_cast<T>(0.0));
auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
math::RowwiseSum<DeviceContext, T> row_sum;
math::MatrixBitCodeFunctor<T> bit_code(num_classes, ids->data<int64_t>());
std::vector<int64_t> sum_dims({batch_size, 1UL});
sum.mutable_data<T>(framework::make_ddim(sum_dims), ctx.GetPlace());
auto sum_mat = EigenMatrix<T>::From(sum);
out->mutable_data<T>(ctx.GetPlace());
auto out_mat = framework::EigenVector<T>::Flatten(*out);
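// Computed below: out(i) = \sum_j softrelu(preout(i, j))
//                          - \sum_{j : bit(i, j) == 1} preout(i, j),
// where softrelu(x) = log(1 + exp(x)) and preout is clipped to [-40, 40].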
if (bias) {
bit_code.Add(pre_out, *bias);
}
bit_code.Mul(pre_out, *w, *in);
// clip pre_out into the range [-40.0, 40.0]
Transform<DeviceContext> trans;
trans(ctx.template device_context<DeviceContext>(), pre_out_data,
pre_out_data + pre_out->numel(), pre_out_data,
ClipFunctor<T>(static_cast<T>(-40.0), static_cast<T>(40.0)));
bit_code.Sum(*pre_out, out, static_cast<T>(-1));
// soft relu with threshold 40.0: preout = log(1 + exp(preout))
trans(ctx.template device_context<DeviceContext>(), pre_out_data,
pre_out_data + pre_out->numel(), pre_out_data,
ClipFunctor<T>(static_cast<T>(-40.0), static_cast<T>(40.0)));
pre_out_mat.device(place) = (static_cast<T>(1.0) + pre_out_mat.exp()).log();
row_sum(dev_ctx, *pre_out, &sum);
out_mat.device(place) = sum_mat + out_mat;
}
};
template <typename DeviceContext, typename T>
class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in = ctx.Input<framework::Tensor>("X");
auto* w = ctx.Input<framework::Tensor>("W");
auto* in_grad = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
auto* w_grad = ctx.Output<framework::Tensor>(framework::GradVarName("W"));
auto* bias_grad =
ctx.Output<framework::Tensor>(framework::GradVarName("Bias"));
auto* ids = ctx.Input<framework::Tensor>("Ids");
auto* pre_out = ctx.Input<framework::Tensor>("PreOut");
auto* out_grad =
ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
size_t num_classes = static_cast<size_t>(ctx.Attr<int>("num_classes"));
int64_t code_length = math::FindLastSet(num_classes - 1);
int64_t batch_size = in->dims()[0];
framework::Tensor pre_out_grad;
pre_out_grad.mutable_data<T>(
framework::make_ddim({batch_size, code_length}), ctx.GetPlace());
auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
auto pre_out_mat = EigenMatrix<T>::From(*pre_out);
auto pre_out_grad_mat = EigenMatrix<T>::From(pre_out_grad);
math::MatrixBitCodeFunctor<T> bit_code(num_classes, ids->data<int64_t>());
// softrelu derivative
bit_code.OutGrad(&pre_out_grad, *out_grad);
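// pre_out holds softrelu(x) = log(1 + exp(x)), so dsoftrelu/dx can be
// written as 1 - exp(-pre_out); that factor is applied below.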
pre_out_grad_mat.device(place) =
pre_out_grad_mat *
(static_cast<T>(1.0) - static_cast<T>(1.0) / pre_out_mat.exp());
bit_code.Sub(&pre_out_grad);
if (bias_grad) {
bias_grad->mutable_data<T>(ctx.GetPlace());
bit_code.AddGrad(pre_out_grad, bias_grad);
}
in_grad->mutable_data<T>(ctx.GetPlace());
w_grad->mutable_data<T>(ctx.GetPlace());
bit_code.MulGradWeight(pre_out_grad, w_grad, *in);
bit_code.MulGradError(pre_out_grad, *w, in_grad);
}
};
} // namespace operators
} // namespace paddle
......@@ -51,6 +51,7 @@ math_library(sequence_padding)
math_library(sequence_pooling DEPS math_function)
math_library(sequence_scale)
math_library(softmax DEPS math_function)
math_library(matrix_bit_code)
math_library(unpooling)
math_library(vol2col)
......
......@@ -155,7 +155,7 @@ class RowwiseSum<platform::CPUDeviceContext, T> {
PADDLE_ENFORCE_EQ(in_dims.size(), 2U);
auto height = in_dims[0];
auto size = in_dims[1];
PADDLE_ENFORCE_EQ(out->numel(), size);
PADDLE_ENFORCE_EQ(out->numel(), height);
T* out_buf = out->mutable_data<T>(out->place());
const T* in_buf = input.data<T>();
......
/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/matrix_bit_code.h"
#include <iostream>
namespace paddle {
namespace operators {
namespace math {
/**
* A CodeTable class should support 3 functions:
*
* size_t size()
* return the number of ids
*
* int get_max_code_length()
* return the maximal code length
*
* Code operator()(size_t i)
* return the i-th code. The Code class is described below.
*
* A Code class should support 3 functions:
*
* int get_length()
* return the length of the code
*
* size_t calc_index(int bit)
* bit ranges from 0 to get_length() - 1;
* return the index of the (1+bit)-th level parent
*
* bool calc_bit(int bit)
* return true if the bit-th level parent is the right child of the
* (1+bit)-th level parent
*
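* For example, with num_classes = 6 and id = 0: the code value is
* 0 + 6 = 6 = 0b110, so get_length() = 2, calc_bit(0) = 0 and
* calc_bit(1) = 1.
*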
*/
template <typename T>
void MatrixBitCodeFunctor<T>::Add(framework::Tensor* tmat,
const framework::Tensor& vec) {
SimpleCodeTable code_table(num_classes_);
size_t batch_size = tmat->dims()[0];
size_t width = tmat->dims()[1];
for (size_t i = 0; i < batch_size; ++i) {
auto code = code_table(static_cast<size_t>(ids_[i]));
int code_length = code.get_length();
for (int j = 0; j < code_length; ++j) {
size_t index = code.calc_index(j);
tmat->data<T>()[i * width + j] += vec.data<T>()[index];
}
}
}
template <typename T>
void MatrixBitCodeFunctor<T>::AddGrad(const framework::Tensor& tmat,
framework::Tensor* vec) {
SimpleCodeTable code_table(num_classes_);
size_t batch_size = tmat.dims()[0];
size_t width = tmat.dims()[1];
for (size_t i = 0; i < batch_size; ++i) {
auto code = code_table(static_cast<size_t>(ids_[i]));
int code_length = code.get_length();
for (int j = 0; j < code_length; ++j) {
size_t index = code.calc_index(j);
vec->data<T>()[index] += tmat.data<T>()[i * width + j];
}
}
}
template <typename T>
void MatrixBitCodeFunctor<T>::Sum(const framework::Tensor& tmat,
framework::Tensor* sum, T scale_sum) {
SimpleCodeTable code_table(num_classes_);
size_t num_samples = tmat.dims()[0];
size_t o_width = tmat.dims()[1];
for (size_t i = 0; i < num_samples; ++i) {
T sm = static_cast<T>(0.0);
auto code = code_table(static_cast<size_t>(ids_[i]));
int code_length = code.get_length();
for (int j = 0; j < code_length; ++j) {
if (code.calc_bit(j)) {
sm += tmat.data<T>()[i * o_width + j];
}
}
sum->data<T>()[i] = scale_sum * sm;
}
}
template <typename T>
void MatrixBitCodeFunctor<T>::Mul(framework::Tensor* tmat,
const framework::Tensor& weight,
const framework::Tensor& input) {
SimpleCodeTable code_table(num_classes_);
size_t num_samples = tmat->dims()[0];
size_t tmat_width = tmat->dims()[1];
size_t input_width = input.dims()[1];
size_t weight_width = weight.dims()[1];
auto tmat_value = tmat->data<T>();
auto weight_value = weight.data<T>();
auto input_value = input.data<T>();
for (size_t i = 0; i < num_samples; ++i) {
auto code = code_table(static_cast<size_t>(ids_[i]));
int code_length = code.get_length();
for (int j = 0; j < code_length; ++j) {
size_t index = code.calc_index(j);
T sum = static_cast<T>(0.0);
for (size_t k = 0; k < input_width; ++k) {
sum += weight_value[weight_width * index + k] *
input_value[input_width * i + k];
}
tmat_value[i * tmat_width + j] += sum;
}
}
}
template <typename T>
void MatrixBitCodeFunctor<T>::MulGradWeight(const framework::Tensor& tmat,
framework::Tensor* weight,
const framework::Tensor& input) {
SimpleCodeTable code_table(num_classes_);
size_t num_samples = tmat.dims()[0];
size_t input_width = input.dims()[1];
size_t tmat_width = tmat.dims()[1];
size_t weight_width = weight->dims()[1];
auto tmat_value = tmat.data<T>();
auto weight_value = weight->data<T>();
auto input_value = input.data<T>();
for (size_t i = 0; i < num_samples; ++i) {
auto code = code_table(static_cast<size_t>(ids_[i]));
int code_length = code.get_length();
for (int j = 0; j < code_length; ++j) {
size_t index = code.calc_index(j);
for (size_t k = 0; k < input_width; ++k) {
weight_value[weight_width * index + k] +=
tmat_value[i * tmat_width + j] * input_value[input_width * i + k];
}
}
}
}
template <typename T>
void MatrixBitCodeFunctor<T>::MulGradError(const framework::Tensor& tmat,
const framework::Tensor& weight,
framework::Tensor* input) {
SimpleCodeTable code_table(num_classes_);
size_t num_samples = tmat.dims()[0];
size_t tmat_width = tmat.dims()[1];
size_t input_width = input->dims()[1];
size_t weight_width = weight.dims()[1];
auto tmat_value = tmat.data<T>();
auto weight_value = weight.data<T>();
auto input_value = input->data<T>();
for (size_t i = 0; i < num_samples; ++i) {
auto code = code_table(static_cast<size_t>(ids_[i]));
int code_length = code.get_length();
for (int j = 0; j < code_length; ++j) {
size_t index = code.calc_index(j);
for (size_t k = 0; k < input_width; ++k) {
input_value[input_width * i + k] +=
tmat_value[i * tmat_width + j] *
weight_value[weight_width * index + k];
}
}
}
}
template <typename T>
void MatrixBitCodeFunctor<T>::Sub(framework::Tensor* tmat) {
SimpleCodeTable code_table(num_classes_);
size_t num_samples = tmat->dims()[0];
size_t o_width = tmat->dims()[1];
for (size_t i = 0; i < num_samples; ++i) {
auto code = code_table(static_cast<size_t>(ids_[i]));
int code_length = code.get_length();
for (int j = 0; j < code_length; ++j) {
if (code.calc_bit(j)) {
tmat->data<T>()[i * o_width + j] -= 1;
}
}
}
}
template <typename T>
void MatrixBitCodeFunctor<T>::OutGrad(framework::Tensor* tmat,
const framework::Tensor& input) {
size_t num_samples = tmat->dims()[0];
size_t code_length = tmat->dims()[1];
for (size_t i = 0; i < num_samples; ++i)
for (size_t j = 0; j < code_length; ++j) {
tmat->data<T>()[i * code_length + j] = input.data<T>()[i];
}
}
template class MatrixBitCodeFunctor<float>;
template class MatrixBitCodeFunctor<double>;
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/device_context.h"
namespace paddle {
namespace operators {
namespace math {
/**
* Return the 1-based index of the highest set bit.
*
* For x > 0:
* \f[
* FindLastSet(x) = 1 + \lfloor \log_2 x \rfloor
* \f]
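*
* For example, FindLastSet(5) = 3, since 5 = 0b101 and its highest set
* bit is in position 3 (1-based).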
*/
inline constexpr size_t FindLastSet(size_t x) {
return std::is_same<size_t, unsigned int>::value
? (x ? 8 * sizeof(x) - __builtin_clz(x) : 0)
: (std::is_same<size_t, unsigned long>::value // NOLINT
? (x ? 8 * sizeof(x) - __builtin_clzl(x) : 0)
: (x ? 8 * sizeof(x) - __builtin_clzll(x) : 0));
}
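// SimpleCode stores a class id plus num_classes (c_ = code + num_classes);
// the bits of c_, read from the lowest bit upwards, trace the path from the
// class's leaf to the root. For example, with num_classes = 6 and code = 2:
// c_ = 8 = 0b1000, so get_length() = 3, and calc_index(0) = 3,
// calc_index(1) = 1, calc_index(2) = 0 (the root).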
struct SimpleCode {
SimpleCode(size_t code, size_t num_classes) : c_(code + num_classes) {}
inline size_t calc_index(int bit) const { return (c_ >> (bit + 1)) - 1; }
inline bool calc_bit(int bit) const { return c_ & (1 << bit); }
inline int get_length() const { return FindLastSet(c_) - 1; }
private:
size_t c_;
};
struct SimpleCodeTable {
explicit SimpleCodeTable(size_t num_classes) : num_classes_(num_classes) {}
SimpleCode operator()(size_t code) const {
return SimpleCode(code, num_classes_);
}
size_t size() const { return num_classes_; }
int get_max_code_length() const { return FindLastSet(num_classes_ - 1); }
private:
size_t num_classes_;
int max_code_length_;
};
template <typename T>
class MatrixBitCodeFunctor {
public:
explicit MatrixBitCodeFunctor(size_t num_classes, const int64_t* ids)
: num_classes_(num_classes), ids_(ids) {}
/* For j < code_length
tmat(i, j) += vec(0, index(i, j))
*/
void Add(framework::Tensor* tmat, const framework::Tensor& vec);
/* For j < code_length
vec(0, index(i, j)) += tmat(i, j)
*/
void AddGrad(const framework::Tensor& tmat, framework::Tensor* vec);
/* For j < code_length
sum(i, 0) = \sum_j bit(i, j) * tmat(i, j)
*/
void Sum(const framework::Tensor& tmat, framework::Tensor* sum, T scale_sum);
/* For j < code_length
tmat(i, j) -= bit(i, j)
*/
void Sub(framework::Tensor* tmat);
/* For j < code_length
input.row(i) += tmat(i, j) * weight.row(index(i, j))
*/
void Mul(framework::Tensor* tmat, const framework::Tensor& weight,
const framework::Tensor& input);
/* For index(i, j) >= 0:
weight.row(index(i, j)) += tmat(i, j) * input.row(i)
*/
void MulGradWeight(const framework::Tensor& tmat, framework::Tensor* weight,
const framework::Tensor& input);
/* For j < code_length
input.row(i) += tmat(i, j) * weight.row(index(i, j))
*/
void MulGradError(const framework::Tensor& tmat,
const framework::Tensor& weight, framework::Tensor* input);
/* For j < code_length
tmat(i, j) == input(i)
*/
void OutGrad(framework::Tensor* tmat, const framework::Tensor& input);
size_t num_classes_;
const int64_t* ids_;
};
} // namespace math
} // namespace operators
} // namespace paddle
......@@ -66,6 +66,7 @@ __all__ = [
'transpose',
'im2sequence',
'nce',
'hsigmoid',
'beam_search',
'row_conv',
'multiplex',
......@@ -2986,6 +2987,78 @@ def nce(input,
return cost / (num_neg_samples + 1)
def hsigmoid(input, label, num_classes=2, param_attr=None, bias_attr=None):
    """
    The hierarchical sigmoid operator is used to accelerate the training
    process of language models. This operator organizes the classes into a
    complete binary tree: each leaf node represents a class (a word) and
    each internal node acts like a binary classifier. For each word there is
    a unique path from the root to its leaf node; hsigmoid calculates the
    cost for each internal node on the path (including the root) and sums
    them up to get the total cost. hsigmoid reduces the computational
    complexity from N to logN, where N is the size of the word dictionary.
    This idea is from "F. Morin, Y. Bengio (AISTATS 05): Hierarchical
    Probabilistic Neural Network Language Model."

    Args:
        input (Variable): (Tensor) The input tensor, whose shape is [N, D],
            where N is the size of the mini-batch and D is the embedding
            size.
        label (Variable): (Tensor) The labels of the training data, whose
            shape is [1, N].
        num_classes (int, default 2): The number of classes; must be larger
            than or equal to 2.
        param_attr (ParamAttr|list of ParamAttr, default None): The parameter
            attribute for learnable parameters/weights of this layer.
        bias_attr (ParamAttr|list of ParamAttr, default None): The parameter
            attribute for the bias of this layer. If it is set to None, no
            bias will be added to the output units.

    Returns:
        Out: (Tensor) The cost of the hierarchical sigmoid operator, whose
            shape is [N, 1].

    Examples:
        .. code-block:: python

            x = fluid.layers.data(name='x', shape=[3, 2],
                                  dtype='float32')
            y = fluid.layers.data(name='y', shape=[1, 3],
                                  dtype='int64')
            out = fluid.layers.hsigmoid(input=x, label=y, num_classes=2)
    """
    helper = LayerHelper('hierarchical_sigmoid', **locals())
    dtype = helper.input_dtype()
    out = helper.create_tmp_variable(dtype)
    pre_out = helper.create_tmp_variable(dtype)
    dim = input.shape[1]
    if num_classes < 2:
        raise ValueError("num_classes must be larger than or equal to 2.")
    if input.shape[0] != label.shape[1]:
        raise ValueError(
            "input's first dimension and label's second dimension must be "
            "equal; they both equal the batch size.")
    weights = helper.create_parameter(
        attr=helper.param_attr,
        shape=[num_classes - 1, dim],
        is_bias=False,
        dtype=input.dtype)
    bias = helper.create_parameter(
        attr=helper.bias_attr,
        shape=[1, num_classes - 1],
        is_bias=True,
        dtype=input.dtype)
    helper.append_op(
        type="hierarchical_sigmoid",
        inputs={"X": input,
                "W": weights,
                "Ids": label,
                "Bias": bias},
        outputs={"Out": out,
                 "PreOut": pre_out},
        attrs={"num_classes": num_classes})
    return out

def transpose(x, perm, name=None):
"""
**transpose Layer**
......@@ -4009,8 +4082,7 @@ def random_crop(input, shape, seed=1):
            attrs={
                "dtype": seed.dtype,
                "shape": [1],
-               "value": float(seed_value),
-               "force_cpu": True
+               "value": float(seed_value)
            })
    elif not isinstance(seed, Variable):
        raise ValueError("'seed' must be a Variable or an int.")
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from op_test import OpTest
import math
def find_latest_set(num):
    return 1 + int(math.floor(math.log(num, 2)))


class CodeTable(object):
    def __init__(self, num_classes, code):
        self.c = num_classes + code

    def cal_index(self, bit):
        return (self.c >> (bit + 1)) - 1

    def get_length(self):
        return find_latest_set(self.c) - 1

    def cal_bit(self, bit):
        return self.c & (1 << bit)

def hsigmoid(x, w, ids, bias, num_classes):
    # initialize pre_output with shape (batch_size, code_length)
    batch_size = x.shape[0]
    code_length = find_latest_set(num_classes - 1)
    code_table = [0 for _ in range(code_length)]
    pre_output = np.zeros((batch_size, code_length))
    pre_sum = np.zeros((batch_size, 1))
    out = np.zeros((batch_size, 1)).astype("float32")
    # pre_output += code(bias)
    for i in xrange(batch_size):
        code_table = CodeTable(num_classes, ids[i])
        length = code_table.get_length()
        for j in xrange(length):
            idx = code_table.cal_index(j)
            pre_output[i][j] += bias[0][idx]
    # pre_output += code(w) * x
    for i in xrange(batch_size):
        for j in xrange(batch_size):
            code_table = CodeTable(num_classes, ids[j])
            length = code_table.get_length()
            for k in xrange(length):
                idx = code_table.cal_index(k)
                sum = 0.0
                for l in xrange(x.shape[1]):
                    sum += w[i][idx][l] * x[j][l]
                pre_output[j][k] += sum
    # clip to [-40.0, 40.0]; np.clip returns a new array
    pre_output = np.clip(pre_output, -40.0, 40.0)
    # out(i, 0) = -1 * \sum_j bit(i, j) * pre_output(i, j)
    for i in xrange(batch_size):
        code_table = CodeTable(num_classes, ids[i])
        length = code_table.get_length()
        sum = 0.0
        for j in xrange(length):
            if code_table.cal_bit(j):
                sum += pre_output[i][j]
        out[i] = -1.0 * sum
    # soft relu
    pre_output = np.log(1 + np.exp(pre_output))
    pre_sum = pre_output.sum(1).reshape((batch_size, 1))
    out += pre_sum
    return out

class TestHSigmoidOp(OpTest):
    def setUp(self):
        self.op_type = "hierarchical_sigmoid"
        num_classes = 6
        embedding_size = 10
        batch_size = 5
        x = np.random.random((batch_size, embedding_size)).astype("float32")
        w = np.random.random(
            (batch_size, num_classes - 1, embedding_size)).astype("float32")
        ids = np.random.randint(0, num_classes, batch_size)
        bias = np.random.random((1, num_classes - 1)).astype("float32")
        self.inputs = {'X': x, 'W': w, 'Ids': ids, 'Bias': bias}
        self.attrs = {'num_classes': num_classes}
        out = hsigmoid(x, w, ids, bias, num_classes)
        self.outputs = {'Out': out}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(['X', 'W', 'Bias'], 'Out', no_grad_set=set(['Ids']))


if __name__ == '__main__':
    unittest.main()
......@@ -173,6 +173,16 @@ class TestBook(unittest.TestCase):
                x=dat, label=lbl))
        print(str(program))

    def test_hsigmoid(self):
        program = Program()
        with program_guard(program):
            x = layers.data(name='x', shape=[2, 2], dtype='float32')
            y = layers.data(name='y', shape=[1, 3], dtype='int64')
            self.assertIsNotNone(
                layers.hsigmoid(
                    input=x, label=y, num_classes=2))
        print(str(program))

    def test_sequence_expand(self):
        program = Program()
        with program_guard(program):
......