Unverified commit 12e1719f, authored by Jiabin Yang, committed by GitHub

Merge pull request #14352 from JiabinYang/enhance_hierachical_sigmod_op

Enhance hierarchical sigmoid op
@@ -98,7 +98,7 @@ paddle.fluid.layers.sequence_reshape ArgSpec(args=['input', 'new_dim'], varargs=
paddle.fluid.layers.transpose ArgSpec(args=['x', 'perm', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.im2sequence ArgSpec(args=['input', 'filter_size', 'stride', 'padding', 'input_image_size', 'out_stride', 'name'], varargs=None, keywords=None, defaults=(1, 1, 0, None, 1, None))
paddle.fluid.layers.nce ArgSpec(args=['input', 'label', 'num_total_classes', 'sample_weight', 'param_attr', 'bias_attr', 'num_neg_samples', 'name', 'sampler', 'custom_dist', 'seed', 'is_sparse'], varargs=None, keywords=None, defaults=(None, None, None, None, None, 'uniform', None, 0, False))
paddle.fluid.layers.hsigmoid ArgSpec(args=['input', 'label', 'num_classes', 'param_attr', 'bias_attr', 'name', 'path_table', 'path_code', 'is_custom', 'is_sparse'], varargs=None, keywords=None, defaults=(None, None, None, None, None, False, False))
paddle.fluid.layers.beam_search ArgSpec(args=['pre_ids', 'pre_scores', 'ids', 'scores', 'beam_size', 'end_id', 'level', 'name'], varargs=None, keywords=None, defaults=(0, None))
paddle.fluid.layers.row_conv ArgSpec(args=['input', 'future_context_size', 'param_attr', 'act'], varargs=None, keywords=None, defaults=(None, None))
paddle.fluid.layers.multiplex ArgSpec(args=['inputs', 'index'], varargs=None, keywords=None, defaults=None)
......
@@ -120,8 +120,22 @@ class SelectedRows {
   */
  int64_t AutoGrownIndex(int64_t key, bool auto_grown, bool is_test = false);

  /*
   * @brief Get the index of the key from id_to_index_ map.
   */
  inline int64_t GetIndexFromId(int64_t key) {
    auto iter = id_to_index_.find(key);
    if (iter == id_to_index_.end()) {
      return -1;
    } else {
      return iter->second;
    }
  }

  void SyncIndex();
  /*
   * @brief Get complete Dims before
   */
  DDim GetCompleteDims() const {
    std::vector<int64_t> dims = vectorize(value_->dims());
    dims[0] = height_;
@@ -133,9 +147,10 @@ class SelectedRows {
  // SelectedRows are simply concated when adding together. Until a
  // SelectedRows add a Tensor, will the duplicate rows be handled.
  Vector<int64_t> rows_;
  std::unordered_map<int64_t, int64_t>
      id_to_index_;  // should not be used when rows_ has duplicate member
  std::unique_ptr<Tensor> value_{nullptr};
  int64_t height_;  // height indicates the underlying tensor's height
  std::unique_ptr<RWLock> rwlock_{nullptr};
};
......
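For context on the new `GetIndexFromId`/`SyncIndex` pair above: `SyncIndex` builds the `id_to_index_` map from `rows_`, and `GetIndexFromId` is then a constant-time lookup that returns -1 for ids not stored. A minimal Python sketch of that contract (names here are illustrative, not Paddle API):

```python
# Hypothetical sketch of the id -> local-row-index lookup a SelectedRows-like
# structure provides; illustrative names only.
class SelectedRowsSketch:
    def __init__(self, rows, height):
        self.rows = rows          # global row ids actually stored, e.g. [3, 7, 9]
        self.height = height      # height of the underlying dense tensor
        self.id_to_index = {}

    def sync_index(self):
        # rows must not contain duplicates, matching the C++ comment above
        self.id_to_index = {row_id: i for i, row_id in enumerate(self.rows)}

    def get_index_from_id(self, key):
        return self.id_to_index.get(key, -1)

srows = SelectedRowsSketch(rows=[3, 7, 9], height=10)
srows.sync_index()
assert srows.get_index_from_id(7) == 1   # stored at local index 1
assert srows.get_index_from_id(5) == -1  # id not present
```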
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/hierarchical_sigmoid_op.h"
#include <string>
#include <vector>

namespace paddle {
namespace operators {
@@ -70,13 +70,14 @@ class HierarchicalSigmoidOp : public framework::OperatorWithKernel {
    const int64_t batch_size = ctx->GetInputDim("X")[0];
    std::vector<int64_t> output_shape({batch_size, 1});
    ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
    ctx->ShareLoD("X", /*->*/ "Out");
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(
        framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
        ctx.GetPlace());
  }
};
@@ -86,27 +87,40 @@ class HierarchicalSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X",
             "(LoDTensor, required) The input tensor with shape [N, D], "
             "where N is the size of mini-batch, and D is the feature size.");
    AddInput("W",
             "(LoDTensor, required), The parameters of hierarchical "
             "sigmoid operator, each of them is a 2-D tensor, the shape is "
             "[K, D], where K is the number of non-leaf nodes in the path tree.");
    AddInput("Label",
             "(LoDTensor, required), The labels of training data. It's a "
             "tensor with shape [N, 1].");
    AddInput("PTable",
             "(LoDTensor, optional), The Path Table from root to current word; "
             "it should have shape like [N, L], where L is the length of the path.")
        .AsDispensable();
    AddInput(
        "PathCode",
        "(LoDTensor, optional), The Code on each node of the path from root "
        "to current word; it should have shape like [N, L], where L is the "
        "length of the path.")
        .AsDispensable();
    AddInput("Bias",
             "(LoDTensor, optional), The bias is a tensor with shape "
             "[num_classes, 1] or [num_classes - 1, 1].")
        .AsDispensable();
    AddOutput(
        "Out",
        "(LoDTensor, required) The output of hierarchical sigmoid operator. "
        "The shape is [N, 1].");
    AddOutput("PreOut",
              "(LoDTensor, required) An intermediate 2-D tensor with shape "
              "[batch_size, code_length], where code_length represents the "
              "maximum path length from root to leaf nodes.")
        .AsIntermediate();
    AddAttr<AttrType>("num_classes", "(int, optional), The number of classes")
        .SetDefault(2);
    AddComment(R"DOC(
The hierarchical sigmoid operator organizes the classes into a binary tree.
@@ -115,6 +129,10 @@ belonging to the right branch. This idea is from
"F. Morin, Y. Bengio (AISTATS 05):
Hierarchical Probabilistic Neural Network Language Model."
)DOC");
    AddAttr<bool>("is_sparse",
                  "(boolean, default false) "
                  "Sparse update.")
        .SetDefault(false);
  }
};
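To make the `PTable`/`PathCode` inputs concrete: each row holds one sample's root-to-leaf path, padded with -1 to the common length L, matching the [N, L] layout described above. The toy arrays below are a hypothetical illustration, not taken from the commit:

```python
import numpy as np

# Illustrative only: a custom tree over 4 words with 3 non-leaf nodes
# (indices 0..2 into W). Each row is one sample's path, -1 padded.
path_table = np.array([(0, 1, -1),   # word 0: root(0) -> node 1 -> leaf
                       (0, 1, -1),   # word 1: same internal nodes
                       (0, 2, -1),   # word 2: root(0) -> node 2 -> leaf
                       (0, 2, -1)])  # word 3
# path_code holds the binary classification label at each node on the path.
path_code = np.array([(0, 0, -1),
                      (0, 1, -1),
                      (1, 0, -1),
                      (1, 1, -1)])
```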
@@ -124,16 +142,21 @@ class HierarchicalSigmoidGradOp : public framework::OperatorWithKernel {
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("W"), "Input(W) should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should not be null.");
    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
                   "Input(Out@Grad) should not be null");
    PADDLE_ENFORCE(ctx->HasInput("PreOut"),
                   "Input(Preout) should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("W")),
                   "Output(W@Grad) should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
                   "Output(X@Grad) should not be null.");
    if (!ctx->Attrs().Get<bool>("is_sparse")) {
      if (ctx->HasOutput(framework::GradVarName("Bias"))) {
        ctx->SetOutputDim(framework::GradVarName("Bias"),
                          ctx->GetInputDim("Bias"));
      }
      ctx->SetOutputDim(framework::GradVarName("W"), ctx->GetInputDim("W"));
    }
    ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
  }
@@ -141,11 +164,55 @@ class HierarchicalSigmoidGradOp : public framework::OperatorWithKernel {
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(
        framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
        ctx.GetPlace());
  }
};
class HierarchicalSigmoidGradOpGradVarTypeInference
    : public framework::VarTypeInference {
 public:
  void operator()(const framework::OpDesc& op_desc,
                  framework::BlockDesc* block) const override {
    auto w_grad_var_name = op_desc.Output(framework::GradVarName("W")).front();
    auto bias_grad_var_name_vec =
        op_desc.Output(framework::GradVarName("Bias"));
    std::string bias_grad_var_name;
    bool hasBias = false;
    if (bias_grad_var_name_vec.size()) {
      hasBias = true;
      bias_grad_var_name =
          op_desc.Output(framework::GradVarName("Bias")).front();
    }
    auto attr = op_desc.GetAttr("is_sparse");
    bool is_sparse = boost::get<bool>(attr);
    if (is_sparse) {
      VLOG(30) << "hierarchical_sigmoid_grad op " << framework::GradVarName("W")
               << " is set to SelectedRows";
      block->Var(w_grad_var_name)
          ->SetType(framework::proto::VarType::SELECTED_ROWS);
      if (hasBias) {
        VLOG(30) << "hierarchical_sigmoid_grad op "
                 << framework::GradVarName("Bias") << " is set to SelectedRows";
        block->Var(bias_grad_var_name)
            ->SetType(framework::proto::VarType::SELECTED_ROWS);
      }
    } else {
      VLOG(30) << "hierarchical_sigmoid_grad op " << framework::GradVarName("W")
               << " is set to LoDTensor";
      block->Var(w_grad_var_name)
          ->SetType(framework::proto::VarType::LOD_TENSOR);
      if (hasBias) {
        VLOG(30) << "hierarchical_sigmoid_grad op "
                 << framework::GradVarName("Bias") << " is set to LoDTensor";
        block->Var(bias_grad_var_name)
            ->SetType(framework::proto::VarType::LOD_TENSOR);
      }
    }
    block->Var(w_grad_var_name)->SetDataType(block->Var("W")->GetDataType());
  }
};
}  // namespace operators
}  // namespace paddle

@@ -153,7 +220,8 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(hierarchical_sigmoid, ops::HierarchicalSigmoidOp,
                  ops::HierarchicalSigmoidOpMaker<int>,
                  paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(hierarchical_sigmoid_grad, ops::HierarchicalSigmoidGradOp,
                  ops::HierarchicalSigmoidGradOpGradVarTypeInference);
REGISTER_OP_CPU_KERNEL(
    hierarchical_sigmoid,
    ops::HierarchicalSigmoidOpKernel<paddle::platform::CPUDeviceContext, float>,
......
@@ -14,12 +14,16 @@ limitations under the License. */
#pragma once

#include <iostream>
#include <set>
#include <vector>

#include "paddle/fluid/framework/mixed_vector.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/clip_op.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/matrix_bit_code.h"
#include "paddle/fluid/platform/transform.h"

namespace paddle {
namespace operators {

@@ -28,20 +32,38 @@ template <typename T, int MajorType = Eigen::RowMajor,
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
using platform::Transform;
static std::vector<int64_t> PathToRows(const framework::LoDTensor& path) {
  std::set<int64_t> rows;
  for (int64_t i = 0; i < path.numel(); ++i) {
    int64_t row = path.data<int64_t>()[i];
    if (row < 0) {
      continue;
    }
    rows.emplace(row);
  }
  return std::vector<int64_t>(rows.begin(), rows.end());
}
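`PathToRows` gathers the distinct non-negative node ids appearing anywhere in the padded path table; these become the rows of the sparse `W@Grad`. A Python equivalent, assuming the same -1 padding convention:

```python
import numpy as np

def path_to_rows(path):
    """Python analogue of PathToRows: unique non-negative ids, sorted."""
    return sorted({int(row) for row in np.asarray(path).ravel() if row >= 0})

path_table = np.array([(0, 2, -1, -1, -1), (0, 1, 3, -1, -1),
                       (0, 1, 4, -1, -1), (0, 2, -1, -1, -1)])
assert path_to_rows(path_table) == [0, 1, 2, 3, 4]
```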
template <typename DeviceContext, typename T>
class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto& in = detail::Ref(ctx.Input<framework::LoDTensor>("X"));
    auto& w = detail::Ref(ctx.Input<framework::LoDTensor>("W"));
    auto* path = ctx.Input<framework::LoDTensor>("PTable");
    auto* code = ctx.Input<framework::LoDTensor>("PathCode");
    auto& label = detail::Ref(ctx.Input<framework::LoDTensor>("Label"));
    auto* bias = ctx.Input<framework::LoDTensor>("Bias");
    auto* out = ctx.Output<framework::LoDTensor>("Out");
    auto* pre_out = ctx.Output<framework::LoDTensor>("PreOut");
    size_t num_classes = static_cast<size_t>(ctx.Attr<int>("num_classes"));
    bool is_custom = false;
    if (path) {
      is_custom = true;
    }
    int64_t code_length =
        path ? path->dims()[1] : math::FindLastSet(num_classes - 1);
    int64_t batch_size = in.dims()[0];
    framework::LoDTensor sum;
    auto& dev_ctx = ctx.template device_context<DeviceContext>();
    auto* pre_out_data = pre_out->mutable_data<T>(
        framework::make_ddim({batch_size, code_length}), ctx.GetPlace());
@@ -52,7 +74,15 @@ class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> {
    zero(dev_ctx, pre_out, static_cast<T>(0.0));
    auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
    math::RowwiseSum<DeviceContext, T> row_sum;

    std::unique_ptr<math::MatrixBitCodeFunctor<T>> bit_code;
    if (!is_custom) {
      bit_code.reset(new math::MatrixBitCodeFunctor<T>(num_classes,
                                                       label.data<int64_t>()));
    } else {
      bit_code.reset(new math::MatrixBitCodeFunctor<T>(*path, *code,
                                                       label.data<int64_t>()));
    }

    std::vector<int64_t> sum_dims({batch_size, 1UL});
    sum.mutable_data<T>(framework::make_ddim(sum_dims), ctx.GetPlace());
@@ -60,15 +90,15 @@ class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> {
    out->mutable_data<T>(ctx.GetPlace());
    auto out_mat = framework::EigenVector<T>::Flatten(*out);
    if (bias) {
      bit_code->Add(*bias, pre_out);
    }
    bit_code->Mul(pre_out, w, in);
    // clip to [-40, 40]
    Transform<DeviceContext> trans;
    trans(ctx.template device_context<DeviceContext>(), pre_out_data,
          pre_out_data + pre_out->numel(), pre_out_data,
          ClipFunctor<T>(static_cast<T>(-40.0), static_cast<T>(40.0)));
    bit_code->Sum(*pre_out, out, static_cast<T>(-1));
    // use softrelu to calculate cross entropy
    pre_out_mat.device(place) = (static_cast<T>(1.0) + pre_out_mat.exp()).log();
    row_sum(dev_ctx, *pre_out, &sum);
@@ -84,50 +114,103 @@ template <typename DeviceContext, typename T>
class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto& in = detail::Ref(ctx.Input<framework::LoDTensor>("X"));
    auto& w = detail::Ref(ctx.Input<framework::LoDTensor>("W"));
    auto* path = ctx.Input<framework::LoDTensor>("PTable");
    auto* code = ctx.Input<framework::LoDTensor>("PathCode");
    auto* bias = ctx.Input<framework::LoDTensor>("Bias");
    auto* in_grad =
        ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
    bool is_sparse = ctx.Attr<bool>("is_sparse");
    auto& dev_ctx = ctx.template device_context<DeviceContext>();
    math::SetConstant<DeviceContext, T> zero;
    auto& label = detail::Ref(ctx.Input<framework::LoDTensor>("Label"));
    auto& pre_out = detail::Ref(ctx.Input<framework::LoDTensor>("PreOut"));
    auto& out_grad = detail::Ref(
        ctx.Input<framework::LoDTensor>(framework::GradVarName("Out")));
    framework::LoDTensor pre_out_grad;

    pre_out_grad.mutable_data<T>(pre_out.dims(), ctx.GetPlace());
    in_grad->mutable_data<T>(ctx.GetPlace());
    zero(dev_ctx, in_grad, static_cast<T>(0.0));

    size_t num_classes = static_cast<size_t>(ctx.Attr<int>("num_classes"));

    bool is_custom = false;
    if (path) {
      is_custom = true;
    }

    std::unique_ptr<math::MatrixBitCodeFunctor<T>> bit_code;
    if (!is_custom) {
      bit_code.reset(new math::MatrixBitCodeFunctor<T>(num_classes,
                                                       label.data<int64_t>()));
    } else {
      bit_code.reset(new math::MatrixBitCodeFunctor<T>(*path, *code,
                                                       label.data<int64_t>()));
    }

    auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
    auto pre_out_mat = EigenMatrix<T>::From(pre_out);
    auto pre_out_grad_mat = EigenMatrix<T>::From(pre_out_grad);
    auto out_grad_mat = EigenMatrix<T>::From(out_grad);

    Eigen::array<int, 2> bcast{1, static_cast<int>(pre_out_grad.dims()[1])};

    // softrelu derivative
    pre_out_grad_mat.device(place) =
        static_cast<T>(1.0) - static_cast<T>(1.0) / pre_out_mat.exp();
    bit_code->Sub(&pre_out_grad);  // the gradient of clip(w * x + b)
    pre_out_grad_mat.device(place) =
        pre_out_grad_mat * out_grad_mat.broadcast(bcast);
    // TODO(guosheng): multiply pre_out_grad with subgradient of clipping to
    // be consistent with the clipping in forward.

    if (!is_sparse) {
      auto* bias_grad =
          ctx.Output<framework::LoDTensor>(framework::GradVarName("Bias"));
      if (bias_grad) {
        bias_grad->mutable_data<T>(ctx.GetPlace());
        zero(dev_ctx, bias_grad, static_cast<T>(0.0));
        bit_code->AddGrad(pre_out_grad, bias_grad);
      }
      auto* w_grad =
          ctx.Output<framework::LoDTensor>(framework::GradVarName("W"));
      w_grad->mutable_data<T>(ctx.GetPlace());
      zero(dev_ctx, w_grad, static_cast<T>(0.0));
      bit_code->MulGradWeight(pre_out_grad, w_grad, in);
    } else {
      framework::Vector<int64_t> real_rows = PathToRows(*path);
      auto* w_grad =
          ctx.Output<framework::SelectedRows>(framework::GradVarName("W"));
      w_grad->set_rows(real_rows);
      // Build a map of id -> row_index to speed up finding the index of one id
      w_grad->SyncIndex();
      w_grad->set_height(w.dims()[0]);
      auto* w_grad_value = w_grad->mutable_value();
      framework::DDim temp_dim(w.dims());
      set(temp_dim, 0, real_rows.size());
      w_grad_value->mutable_data<T>(temp_dim, ctx.GetPlace());
      zero(dev_ctx, w_grad_value, static_cast<T>(0.0));
      auto* bias_grad =
          ctx.Output<framework::SelectedRows>(framework::GradVarName("Bias"));
      if (bias_grad) {
        bias_grad->set_rows(real_rows);
        // build ids -> rows index map
        bias_grad->SyncIndex();
        bias_grad->set_height(bias->dims()[0]);
        auto* bias_grad_value = bias_grad->mutable_value();
        std::vector<int64_t> dims = {static_cast<int64_t>(real_rows.size()),
                                     bias->dims()[1]};
        bias_grad_value->mutable_data<T>(framework::make_ddim(dims),
                                         ctx.GetPlace());
        zero(dev_ctx, bias_grad_value, static_cast<T>(0.0));
        bit_code->AddGrad(pre_out_grad, bias_grad);
      }
      bit_code->MulGradWeight(pre_out_grad, w_grad, in);
    }
    bit_code->MulGradError(pre_out_grad, w, in_grad);
  }
};
......
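The sparse branch of the grad kernel above stores `W@Grad` as SelectedRows: a compact value tensor with one row per non-leaf node touched by the batch's paths, addressed through the id-to-index map built by `SyncIndex`. A rough numpy sketch of that layout, with illustrative values only:

```python
import numpy as np

# Only the nodes named in real_rows get gradient rows, so the SelectedRows
# value is a compact [len(real_rows), D] tensor plus the list of row ids.
D = 8
real_rows = [0, 1, 4]                    # PathToRows(path_table) analogue
w_grad_value = np.zeros((len(real_rows), D), dtype=np.float32)
row_index = {r: i for i, r in enumerate(real_rows)}  # SyncIndex analogue

# scatter a gradient contribution for non-leaf node id 4:
w_grad_value[row_index[4]] += 0.1 * np.ones(D, dtype=np.float32)
```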
@@ -19,16 +19,15 @@ namespace operators {
namespace math {

template <typename T>
void MatrixBitCodeFunctor<T>::Add(const framework::Tensor& vec,
                                  framework::Tensor* tmat) {
  size_t batch_size = tmat->dims()[0];
  size_t width = tmat->dims()[1];
  for (size_t i = 0; i < batch_size; ++i) {
    auto code = code_table_->get_code(i);
    int code_length = code->get_length();
    for (int j = 0; j < code_length; ++j) {
      size_t index = code->calc_index(j);
      tmat->data<T>()[i * width + j] += vec.data<T>()[index];
    }
  }
@@ -37,31 +36,46 @@ void MatrixBitCodeFunctor<T>::Add(framework::Tensor* tmat,
template <typename T>
void MatrixBitCodeFunctor<T>::AddGrad(const framework::Tensor& tmat,
                                      framework::Tensor* vec) {
  size_t batch_size = tmat.dims()[0];
  size_t width = tmat.dims()[1];
  for (size_t i = 0; i < batch_size; ++i) {
    auto code = code_table_->get_code(i);
    int code_length = code->get_length();
    for (int j = 0; j < code_length; ++j) {
      size_t index = code->calc_index(j);
      vec->data<T>()[index] += tmat.data<T>()[i * width + j];
    }
  }
}
template <typename T>
void MatrixBitCodeFunctor<T>::AddGrad(const framework::Tensor& tmat,
                                      framework::SelectedRows* vec) {
  size_t batch_size = tmat.dims()[0];
  size_t width = tmat.dims()[1];
  for (size_t i = 0; i < batch_size; ++i) {
    auto code = code_table_->get_code(i);
    int code_length = code->get_length();
    for (int j = 0; j < code_length; ++j) {
      size_t index = code->calc_index(j);
      int64_t row_index = vec->GetIndexFromId(static_cast<int64_t>(index));
      vec->mutable_value()->data<T>()[row_index] +=
          tmat.data<T>()[i * width + j];
    }
  }
}
template <typename T>
void MatrixBitCodeFunctor<T>::Sum(const framework::Tensor& tmat,
                                  framework::Tensor* sum, T scale_sum) {
  size_t num_samples = tmat.dims()[0];
  size_t o_width = tmat.dims()[1];
  for (size_t i = 0; i < num_samples; ++i) {
    T sm = static_cast<T>(0.0);
    auto code = code_table_->get_code(i);
    int code_length = code->get_length();
    for (int j = 0; j < code_length; ++j) {
      if (code->calc_bit(j)) {
        // calc_bit starts from right most bit, while data in tmat[i] is in the
        // reverse order.
        sm += tmat.data<T>()[i * o_width + j];
@@ -75,7 +89,6 @@ template <typename T>
void MatrixBitCodeFunctor<T>::Mul(framework::Tensor* tmat,
                                  const framework::Tensor& weight,
                                  const framework::Tensor& input) {
  size_t num_samples = tmat->dims()[0];
  size_t tmat_width = tmat->dims()[1];
  size_t input_width = input.dims()[1];
@@ -84,10 +97,10 @@ void MatrixBitCodeFunctor<T>::Mul(framework::Tensor* tmat,
  auto weight_value = weight.data<T>();
  auto input_value = input.data<T>();
  for (size_t i = 0; i < num_samples; ++i) {
    auto code = code_table_->get_code(i);
    int code_length = code->get_length();
    for (int j = 0; j < code_length; ++j) {
      size_t index = code->calc_index(j);
      T sum = static_cast<T>(0.0);
      for (size_t k = 0; k < input_width; ++k) {
        sum += weight_value[weight_width * index + k] *
@@ -102,7 +115,6 @@ template <typename T>
void MatrixBitCodeFunctor<T>::MulGradWeight(const framework::Tensor& tmat,
                                            framework::Tensor* weight,
                                            const framework::Tensor& input) {
  size_t num_samples = tmat.dims()[0];
  size_t input_width = input.dims()[1];
  size_t tmat_width = tmat.dims()[1];
@@ -111,10 +123,10 @@ void MatrixBitCodeFunctor<T>::MulGradWeight(const framework::Tensor& tmat,
  auto weight_value = weight->data<T>();
  auto input_value = input.data<T>();
  for (size_t i = 0; i < num_samples; ++i) {
    auto code = code_table_->get_code(i);
    int code_length = code->get_length();
    for (int j = 0; j < code_length; ++j) {
      size_t index = code->calc_index(j);
      for (size_t k = 0; k < input_width; ++k) {
        weight_value[weight_width * index + k] +=
@@ -124,11 +136,35 @@ void MatrixBitCodeFunctor<T>::MulGradWeight(const framework::Tensor& tmat,
  }
}
template <typename T>
void MatrixBitCodeFunctor<T>::MulGradWeight(const framework::Tensor& tmat,
                                            framework::SelectedRows* weight,
                                            const framework::Tensor& input) {
  size_t num_samples = tmat.dims()[0];
  size_t input_width = input.dims()[1];
  size_t tmat_width = tmat.dims()[1];
  size_t weight_width = weight->value().dims()[1];
  auto tmat_value = tmat.data<T>();
  auto weight_value = weight->mutable_value()->data<T>();
  auto input_value = input.data<T>();
  for (size_t i = 0; i < num_samples; ++i) {
    auto code = code_table_->get_code(i);
    int code_length = code->get_length();
    for (int j = 0; j < code_length; ++j) {
      size_t index = code->calc_index(j);
      for (size_t k = 0; k < input_width; ++k) {
        int64_t row_index = weight->GetIndexFromId(static_cast<int64_t>(index));
        weight_value[row_index * weight_width + k] +=
            tmat_value[i * tmat_width + j] * input_value[input_width * i + k];
      }
    }
  }
}
template <typename T>
void MatrixBitCodeFunctor<T>::MulGradError(const framework::Tensor& tmat,
                                           const framework::Tensor& weight,
                                           framework::Tensor* input) {
  size_t num_samples = tmat.dims()[0];
  size_t tmat_width = tmat.dims()[1];
  size_t input_width = input->dims()[1];
@@ -138,10 +174,10 @@ void MatrixBitCodeFunctor<T>::MulGradError(const framework::Tensor& tmat,
  auto input_value = input->data<T>();
  for (size_t i = 0; i < num_samples; ++i) {
    auto code = code_table_->get_code(i);
    int code_length = code->get_length();
    for (int j = 0; j < code_length; ++j) {
      size_t index = code->calc_index(j);
      for (size_t k = 0; k < input_width; ++k) {
        input_value[input_width * i + k] +=
@@ -154,14 +190,13 @@ void MatrixBitCodeFunctor<T>::MulGradError(const framework::Tensor& tmat,
template <typename T>
void MatrixBitCodeFunctor<T>::Sub(framework::Tensor* tmat) {
  size_t num_samples = tmat->dims()[0];
  size_t o_width = tmat->dims()[1];
  for (size_t i = 0; i < num_samples; ++i) {
    auto code = code_table_->get_code(i);
    int code_length = code->get_length();
    for (int j = 0; j < code_length; ++j) {
      if (code->calc_bit(j)) {
        tmat->data<T>()[i * o_width + j] -= 1;
      }
    }
......
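The core of the functor is `Mul`, which fills `pre_out(i, j)` with the inner product of sample i's features and the weight row of the j-th node on its path. A numpy analogue for the custom-table case, assuming the -1 padding convention used throughout this PR:

```python
import numpy as np

def bit_code_mul(pre_out, w, x, path_table):
    """Numpy analogue of MatrixBitCodeFunctor::Mul with a custom path table."""
    for i in range(x.shape[0]):
        for j, idx in enumerate(path_table[i]):
            if idx < 0:          # -1 padding ends the path
                break
            pre_out[i][j] += np.dot(w[idx], x[i])
    return pre_out
```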
@@ -14,6 +14,8 @@ limitations under the License. */
#pragma once

#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/device_context.h"

@@ -92,9 +94,27 @@ inline int clz(const T& value) {
inline size_t FindLastSet(size_t x) { return sizeof(size_t) * 8 - clz(x); }
#endif  // !_WIN32
// set a code interface to create multiple code
class Code {
 public:
  virtual ~Code() {}
  virtual size_t calc_index(int bit) const = 0;
  virtual bool calc_bit(int bit) const = 0;
  virtual int get_length() const = 0;
};

// set a CodeTable interface to create multiple code table
class CodeTable {
 public:
  virtual std::unique_ptr<Code> get_code(int64_t code) const = 0;
  virtual size_t size() const = 0;
  virtual int get_max_code_length() const = 0;
  virtual ~CodeTable() {}
};
class SimpleCode : public Code {
 public:
  SimpleCode(size_t code, size_t num_classes, const int64_t* ids)
      : c_(static_cast<size_t>(ids[code]) + num_classes) {}
  /**
   * Here the id of root should be 1 rather than 0, thus the encoding of class c
   * is `c + num_classes` and all siblings can get the same weight indices using
@@ -104,41 +124,121 @@ struct SimpleCode {
   * Binary classification path is the suffixes of encoding, thus leave out the
   * left most bit in calc_bit.
   */
  size_t calc_index(int bit) const { return (c_ >> (bit + 1)) - 1; }
  bool calc_bit(int bit) const { return c_ & (1 << bit); }
  int get_length() const { return FindLastSet(c_) - 1; }

 private:
  size_t c_;
};
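The `SimpleCode` arithmetic can be checked by hand: for class c the code is `c + num_classes`, its bit length minus one is the path length, right-shifted prefixes give weight indices, and the low bits give branch labels. A quick Python verification for num_classes = 6, class 1:

```python
# c_ = 1 + 6 = 7 = 0b111, so get_length() = 2 (the leading bit is dropped).
c = 1 + 6
get_length = c.bit_length() - 1                               # FindLastSet(c_) - 1
calc_index = [(c >> (j + 1)) - 1 for j in range(get_length)]  # weight rows
calc_bit = [bool(c & (1 << j)) for j in range(get_length)]    # branch labels

assert get_length == 2
assert calc_index == [2, 0]        # nodes visited, nearest the leaf first
assert calc_bit == [True, True]
```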
template <typename T>
class CustomCode : public Code {
 public:
  CustomCode(const framework::Tensor& ptable, const framework::Tensor& pcode,
             const int64_t* ids, int index)
      : ids_(ids), index_(index) {
    ptable_ = ptable.Slice(index, index + 1);
    pcode_ = pcode.Slice(index, index + 1);
  }
  /**
   * Here the id of root should be 1 rather than 0, thus the encoding of class c
   * is `c + num_classes` and all siblings can get the same weight indices using
   * prefixes.
   * Weight index is the prefixes of encoding, thus leave out the right most
   * bit in calc_index.
   * Binary classification path is the suffixes of encoding, thus leave out the
   * left most bit in calc_bit.
   */
  size_t calc_index(int bit) const { return ptable_.data<T>()[bit]; }
  bool calc_bit(int bit) const { return pcode_.data<T>()[bit]; }
  int get_length() const {
    int length = 0;
    for (int i = 0; i < static_cast<int>(ptable_.dims()[1]); i++) {
      if (ptable_.data<T>()[i] >= 0) {
        length++;
      } else {
        return length;
      }
    }
    return length;
  }

 private:
  framework::Tensor ptable_;
  framework::Tensor pcode_;
  const int64_t* ids_;
  const int index_;
};

class SimpleCodeTable : public CodeTable {
 public:
  SimpleCodeTable(size_t num_classes, const int64_t* ids)
      : num_classes_(num_classes), ids_(ids) {}
  std::unique_ptr<Code> get_code(int64_t code) const {
    std::unique_ptr<Code> coder(new SimpleCode(code, num_classes_, ids_));
    return coder;
  }
  size_t size() const { return num_classes_; }
  int get_max_code_length() const { return FindLastSet(num_classes_ - 1); }

 private:
  size_t num_classes_;
  const int64_t* ids_;
};

template <typename T>
class CustomCodeTable : public CodeTable {
 public:
  CustomCodeTable(const framework::Tensor& ptable,
                  const framework::Tensor& pcode, const int64_t* ids)
      : ptable_(ptable), pcode_(pcode), ids_(ids) {}

  std::unique_ptr<Code> get_code(int64_t code) const {
    std::unique_ptr<Code> coder(new CustomCode<T>(ptable_, pcode_, ids_, code));
    return coder;
  }

  size_t size() const { return static_cast<size_t>(ptable_.dims()[1]); }
  int get_max_code_length() const {
    return static_cast<size_t>(ptable_.dims()[1]);
  }

 private:
  const framework::Tensor& ptable_;
  const framework::Tensor& pcode_;
  const int64_t* ids_;
};
template <typename T>
class MatrixBitCodeFunctor {
 public:
  MatrixBitCodeFunctor(size_t num_classes, const int64_t* ids)
      : num_classes_(num_classes),
        ids_(ids),
        code_table_(new SimpleCodeTable(num_classes, ids)) {}

  MatrixBitCodeFunctor(const framework::Tensor& ptable,
                       const framework::Tensor& pcode, const int64_t* ids)
      : num_classes_(static_cast<size_t>(ptable.dims()[1])),
        ids_(ids),
        code_table_(new CustomCodeTable<int64_t>(ptable, pcode, ids)) {}
  /* For j < code_length
       tmat(i, j) += vec(0, index(i, j))
  */
  void Add(const framework::Tensor& vec, framework::Tensor* tmat);

  /* For j < code_length
       vec(0, index(i, j)) += tmat(i, j)
  */
  void AddGrad(const framework::Tensor& tmat, framework::Tensor* vec);

  /* For selected rows, for j < code_length
       vec(0, index(i, j)) += tmat(i, j)
  */
  void AddGrad(const framework::Tensor& tmat, framework::SelectedRows* vec);

  /* For j < code_length
       sum(i, 0) = \sum_j bit(i, j) * tmat(i, j)
  */
@@ -159,6 +259,12 @@ class MatrixBitCodeFunctor {
  */
  void MulGradWeight(const framework::Tensor& tmat, framework::Tensor* weight,
                     const framework::Tensor& input);
  /* For SelectedRows weight, for index(i, j) >= 0:
       weight.row(index(i, j)) += tmat(i, j) * input.row(i)
  */
  void MulGradWeight(const framework::Tensor& tmat,
                     framework::SelectedRows* weight,
                     const framework::Tensor& input);
  /* For j < code_length
       input.row(i) += tmat(i, j) * weight.row(index(i, j))
  */
@@ -167,6 +273,7 @@ class MatrixBitCodeFunctor {
  size_t num_classes_;
  const int64_t* ids_;
  std::unique_ptr<CodeTable> code_table_;
};
}  // namespace math
}  // namespace operators
......
@@ -4587,27 +4587,43 @@ def hsigmoid(input,
             num_classes,
             param_attr=None,
             bias_attr=None,
             name=None,
             path_table=None,
             path_code=None,
             is_custom=False,
             is_sparse=False):
    """
    The hierarchical sigmoid operator is used to accelerate the training
    process of language model. This operator organizes the classes into a
    complete binary tree, or you can use is_custom to pass your own tree to
    implement the hierarchy. Each leaf node represents a class (a word) and each
    internal node acts as a binary classifier. For each word there's a unique
    path from root to its leaf node, hsigmoid calculates the cost for each
    internal node on the path, and sums them to get a total cost. hsigmoid can
    achieve an acceleration from :math:`O(N)` to :math:`O(logN)`, where :math:`N`
    represents the size of word dict.

    For the default tree you can refer to `Hierarchical Probabilistic Neural Network Language Model
    <http://www.iro.umontreal.ca/~lisa/pointeurs/hierarchical-nnlm-aistats05.pdf>`_

    And if you want to use a custom tree by setting 'is_custom' as true, you may need to do the following things first:
    1. use your word dict to build a binary tree, each leaf node should be a word in your word dict
    2. build a dict to store word_id -> word's leaf-to-root path, we call it path_table.
    3. build a dict to store word_id -> code of word's leaf-to-root path, we call it path_code. Code
       means the label of each binary classification, using 1 to indicate true, 0 to indicate false.
    4. now, each word should have its path and code along the path, and you can pass a batch of paths and codes
       related to the same batch of inputs.

    Args:
        input (Variable): The input tensor variable with shape
            :math:`[N \\times D]`, where :math:`N` is the size of mini-batch,
            and :math:`D` is the feature size.
        label (Variable): The tensor variable contains labels of training data.
            It's a tensor with shape is :math:`[N \\times 1]`.
        num_classes: (int), The number of classes, must not be less than 2 with the default
            tree; it should never be None under is_custom=False. While is_custom is true,
            it should be the number of non-leaf nodes, which indicates the number of
            classes used by the binary classifiers.
        param_attr (ParamAttr|None): The parameter attribute for learnable parameters/weights
            of hsigmoid. If it is set to None or one attribute of ParamAttr, hsigmoid
            will create ParamAttr as param_attr. If the Initializer of the param_attr
@@ -4619,9 +4635,19 @@ def hsigmoid(input,
            is not set, the bias is initialized zero. Default: None.
        name (str|None): A name for this layer (optional). If set None, the layer
            will be named automatically. Default: None.
        path_table: (Variable|None) this variable can store each batch of samples' path to root,
            it should be in leaf -> root order.
            path_table should have the same shape with path_code, and for each sample i,
            path_table[i] indicates a np.array-like structure in which each element is an
            index into the parent nodes' weight matrix.
        path_code: (Variable|None) this variable can store each batch of samples' code,
            each code consists of every code of the parent nodes. it should be in leaf -> root order.
        is_custom: (bool|False) use a user-defined binary tree instead of the default complete
            binary tree. if is_custom is set you need to set path_table/path_code/num_classes,
            otherwise num_classes should be set.
        is_sparse: (bool|False) use sparse update instead of dense update. if set, the gradient
            of W and input will be sparse.

    Returns:
        Out: (LoDTensor) The cost of hierarchical sigmoid operator. the shape is [N, 1]

    Examples:
@@ -4637,27 +4663,62 @@ def hsigmoid(input,
    out = helper.create_variable_for_type_inference(dtype)
    pre_out = helper.create_variable_for_type_inference(dtype)
    dim = input.shape[1]
    if ((num_classes is None) or (num_classes < 2)) and (not is_custom):
        raise ValueError(
            "num_classes must not be less than 2 with default tree")

    if (is_custom) and (path_code is None):
        raise ValueError("path_code should not be None with custom tree")
    elif (is_custom) and (path_table is None):
        raise ValueError("path_table should not be None with custom tree")
    elif (is_custom) and (num_classes is None):
        raise ValueError("num_classes should not be None with custom tree")
    else:
        pass

    weights = None

    if not is_custom:
        weights = helper.create_parameter(
            attr=helper.param_attr,
            shape=[num_classes - 1, dim],
            is_bias=False,
            dtype=input.dtype)
    else:
        weights = helper.create_parameter(
            attr=helper.param_attr,
            shape=[num_classes, dim],
            is_bias=False,
            dtype=input.dtype)
    inputs = {
        "X": input,
        "W": weights,
        "PTable": path_table,
        "PathCode": path_code,
        "Label": label
    }
    if helper.bias_attr:
        if not is_custom:
            bias = helper.create_parameter(
                attr=helper.bias_attr,
                shape=[num_classes - 1, 1],
                is_bias=True,
                dtype=input.dtype)
            inputs['Bias'] = bias
        else:
            bias = helper.create_parameter(
                attr=helper.bias_attr,
                shape=[num_classes, 1],
                is_bias=True,
                dtype=input.dtype)
            inputs['Bias'] = bias
    helper.append_op(
        type="hierarchical_sigmoid",
        inputs=inputs,
        outputs={"Out": out,
                 "PreOut": pre_out},
        attrs={"num_classes": num_classes,
               "is_sparse": is_sparse})
    return out
......
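A minimal usage sketch of the extended layer API with a custom tree built as described in the docstring above (shapes and values are assumed for illustration; the unit tests below contain a runnable end-to-end version):

```python
import paddle.fluid as fluid

# Assumed shapes: 8-dim features, paths of length 3 over 3 non-leaf nodes.
emb = fluid.layers.data(name='x', shape=[8], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
path_table = fluid.layers.data(name='path_table', shape=[3], dtype='int64')
path_code = fluid.layers.data(name='path_code', shape=[3], dtype='int64')

cost = fluid.layers.hsigmoid(
    input=emb,
    label=label,
    num_classes=3,          # number of non-leaf nodes when is_custom=True
    path_table=path_table,
    path_code=path_code,
    is_custom=True,
    is_sparse=True)         # sparse gradients for W and the input
```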
@@ -16,6 +16,8 @@ from __future__ import print_function
import unittest
import numpy as np
import paddle.fluid.core as core
import paddle.fluid as fluid
import math
from op_test import OpTest

@@ -40,6 +42,29 @@ class CodeTable(object):
        return self.c & (1 << bit)
class CodeTableWithCustomTree(object):
    def __init__(self, path_table, path_code, index):
        self.ptable_ = path_table
        self.pcode_ = path_code
        self.index_ = index

    def cal_index(self, bit):
        return self.ptable_[self.index_][bit]

    def get_length(self):
        length = 0
        for ele in self.ptable_[self.index_]:  # find the first -1 to stop trace
            if ele >= 0:
                length = length + 1
            else:
                return length
        return length

    def cal_bit(self, bit):
        return self.pcode_[self.index_][bit]
def hsigmoid(x, w, label, bias, num_classes):
    batch_size = x.shape[0]
    code_length = find_latest_set(num_classes - 1)
@@ -52,7 +77,7 @@ def hsigmoid(x, w, label, bias, num_classes):
        length = code_table.get_length()
        for j in range(length):
            idx = code_table.cal_index(j)
            pre_output[i][j] += bias[idx][0]
    for i in range(batch_size):
        code_table = CodeTable(num_classes, label[i])
        length = code_table.get_length()
@@ -77,17 +102,58 @@ def hsigmoid(x, w, label, bias, num_classes):
    return pre_output, out
def hsigmoidWithCustomTree(x, w, path_table, path_code, label, bias,
                           num_classes):
    batch_size = x.shape[0]
    code_length = len(path_table[0])
    code_table = [0 for _ in range(code_length)]
    # init pre_out with shape [N, code_length]
    pre_output = np.zeros((batch_size, code_length))
    pre_sum = np.zeros((batch_size, 1))
    out = np.zeros((batch_size, 1)).astype("float32")
    if isinstance(bias, np.ndarray):
        for i in range(batch_size):
            code_table = CodeTableWithCustomTree(path_table, path_code, i)
            length = code_table.get_length()
            for j in range(length):
                idx = code_table.cal_index(j)
                pre_output[i][j] += bias[idx][0]
    for i in range(batch_size):
        code_table = CodeTableWithCustomTree(path_table, path_code, i)
        length = code_table.get_length()
        for j in range(length):
            idx = code_table.cal_index(j)
            pre_output[i][j] += np.dot(w[idx], x[i])
    # clip[-40.0, 40.0]
    pre_output = np.clip(pre_output, -40.0, 40.0)
    # out(i, 0) = \sum_j bit(i, j) * preout(i, j)
    for i in range(batch_size):
        code_table = CodeTableWithCustomTree(path_table, path_code, i)
        length = code_table.get_length()
        sum = 0.0
        for j in range(length):
            if code_table.cal_bit(j):
                sum += pre_output[i][j]
        out[i] = -1.0 * sum
    # soft relu
    pre_output = np.log(1 + np.exp(pre_output))
    pre_sum = pre_output.sum(1).reshape((batch_size, 1))
    out += pre_sum
    return pre_output, out
class TestHSigmoidOp(OpTest):
    def setUp(self):
        self.op_type = "hierarchical_sigmoid"
        num_classes = 6
        feature_size = 8
        batch_size = 4
        x = np.random.random((batch_size, feature_size)).astype("float32") * 2
        w = np.random.random(
            (num_classes - 1, feature_size)).astype("float32") * 2
        label = np.random.randint(0, num_classes, (batch_size, 1))
        bias = np.random.random((num_classes - 1, 1)).astype("float32")
        self.attrs = {'num_classes': num_classes, 'is_sparse': False}
        self.inputs = {'X': x, 'W': w, 'Label': label, 'Bias': bias}
        pre_output, out = hsigmoid(x, w, label, bias, num_classes)
        self.outputs = {'PreOut': pre_output, 'Out': out}
@@ -99,5 +165,185 @@ class TestHSigmoidOp(OpTest):
        self.check_grad(['Bias', 'X', 'W'], ['Out'], no_grad_set=set('Label'))
class TestHSigmoidOpSparse(OpTest):
    def setUp(self):
        self.op_type = "hierarchical_sigmoid"
        num_classes = 6  # using 1,2,3,4,5,6 to build a huffman tree and select 1,2,5,6 as sample
        feature_size = 8
        batch_size = 4
        x = np.random.random((batch_size, feature_size)).astype("float32")
        w = np.random.random((num_classes - 1, feature_size)).astype("float32")
        label = np.array([0, 1, 4, 5])
        path_table = np.array(
            [(0, 2, -1, -1, -1), (0, 1, 3, -1, -1), (0, 1, 4, -1, -1),
             (0, 2, -1, -1,
              -1)])  # np.array to store 1,2,5,6s' non-leaf path(root -> leaf)
        path_code = np.array([(0, 0, -1, -1, -1), (1, 1, 1, -1, -1), (
            1, 0, 0, -1, -1), (0, 1, -1, -1, -1)])  # np.array to store
        bias = np.random.random((num_classes - 1, 1)).astype("float32")
        self.attrs = {'num_classes': num_classes, 'is_sparse': True}
        self.inputs = {
            'X': x,
            'W': w,
            'PTable': path_table,
            'PathCode': path_code,
            'Label': label,
            'Bias': bias
        }
        pre_output, out = hsigmoidWithCustomTree(x, w, path_table, path_code,
                                                 label, bias, num_classes)
        self.outputs = {'PreOut': pre_output, 'Out': out}

    def test_check_output(self):
        self.check_output()


class TestHSigmoidOpWithSparseGrad(unittest.TestCase):
    def hs_net_conf(self, is_sparse):
        input_word = fluid.layers.data(name="x", shape=[1], dtype='int64')
        path_table = fluid.layers.data(
            name='path_table', shape=[3], dtype='int64')
        path_code = fluid.layers.data(
            name='path_code', shape=[3], dtype='int64')
        label = fluid.layers.data(name='label', shape=[1], dtype='int64')
        data_list = [input_word, path_table, path_code, label]

        emb = fluid.layers.embedding(
            input=input_word,
            is_sparse=is_sparse,
            size=[3, 3],
            param_attr=fluid.ParamAttr(initializer=fluid.initializer.Normal(
                scale=1 / math.sqrt(3))))

        cost = fluid.layers.hsigmoid(
            input=emb,
            label=label,
            bias_attr=True,
            num_classes=3,
            path_table=path_table,
            path_code=path_code,
            is_custom=True,
            is_sparse=is_sparse)

        avg_cost = fluid.layers.reduce_mean(cost)

        return avg_cost, data_list

    def training_test(self, is_sparse):
        with fluid.program_guard(fluid.Program(), fluid.Program()):
            start_up = fluid.default_startup_program()
            start_up.random_seed = 1  # Fix random seed
            x = np.arange(6).reshape(6)
            path_table = np.array([(1, 2, -1), (1, 2, -1)])
            path_code = np.array([(1, 0, -1), (0, 0, -1)])
            label = np.array([1, 4])

            loss, data_list = self.hs_net_conf(is_sparse)
            optimizer = fluid.optimizer.SGD(learning_rate=1e-3)
            optimizer.minimize(loss)

            main_program = fluid.default_main_program()
            place = fluid.CPUPlace()
            feeder = fluid.DataFeeder(feed_list=data_list, place=place)
            exe = fluid.Executor(place)

            exe.run(start_up)
            result = list()
            for i in range(10):
                data = [([[x[i % 2]]], [list(path_table[i % 2])],
                         [list(path_code[i % 2])], [label[i % 2]])]

                loss_val = exe.run(main_program,
                                   feed=feeder.feed(data),
                                   fetch_list=[loss])
                result.append(loss_val)
        return result

    def test_hs_grad_with_sparse(self):
        dense_result = self.training_test(is_sparse=False)
        sparse_result = self.training_test(is_sparse=True)
        assert (dense_result == sparse_result)


class TestHSigmoidOpWithCostumTree(OpTest):
    def setUp(self):
        self.op_type = "hierarchical_sigmoid"
        num_classes = 6  # using 1,2,3,4,5,6 to build a huffman tree and select 1,2,5,6 as sample
        feature_size = 8
        batch_size = 4
        x = np.random.random((batch_size, feature_size)).astype("float32") * 2
        w = np.random.random(
            (num_classes - 1, feature_size)).astype("float32") * 2
        label = np.array([0, 1, 4, 5])
        path_table = np.array(
            [(0, 2, -1, -1, -1), (0, 1, 3, -1, -1), (0, 1, 4, -1, -1),
             (0, 2, -1, -1,
              -1)])  # np.array to store 1,2,5,6s' non-leaf path(root -> leaf)
        path_code = np.array([(0, 0, -1, -1, -1), (1, 1, 1, -1, -1), (
            1, 0, 0, -1, -1), (0, 1, -1, -1, -1)])  # np.array to store
        bias = np.random.random((num_classes - 1, 1)).astype("float32")
        self.attrs = {'num_classes': num_classes, 'is_sparse': False}
        self.inputs = {
            'X': x,
            'W': w,
            'PTable': path_table,
            'PathCode': path_code,
            'Label': label,
            'Bias': bias
        }
        pre_output, out = hsigmoidWithCustomTree(x, w, path_table, path_code,
                                                 label, bias, num_classes)
        self.outputs = {'PreOut': pre_output, 'Out': out}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(['Bias', 'X', 'W'], ['Out'], no_grad_set=set('Label'))


class TestHSigmoidOpWithCostumTreeWithoutBias(OpTest):
    def setUp(self):
        self.op_type = "hierarchical_sigmoid"
        num_classes = 6  # using 1,2,3,4,5,6 to build a huffman tree and select 1,2,5,6 as sample
        feature_size = 8
        batch_size = 4
        x = np.random.random((batch_size, feature_size)).astype("float32") * 2
        w = np.random.random(
            (num_classes - 1, feature_size)).astype("float32") * 2
        label = np.array([0, 1, 4, 5])
        path_table = np.array(
            [(0, 2, -1, -1, -1), (0, 1, 3, -1, -1), (0, 1, 4, -1, -1),
             (0, 2, -1, -1,
              -1)])  # np.array to store 1,2,5,6s' non-leaf path(root -> leaf)
        path_code = np.array([(0, 0, -1, -1, -1), (1, 1, 1, -1, -1), (
            1, 0, 0, -1, -1), (0, 1, -1, -1, -1)])  # np.array to store
        # bias = np.random.random((num_classes - 1, 1)).astype("float32")
        self.attrs = {'num_classes': num_classes, 'is_sparse': False}
        self.inputs = {
            'X': x,
            'W': w,
            'PTable': path_table,
            'PathCode': path_code,
            'Label': label,
        }
        pre_output, out = hsigmoidWithCustomTree(
            x=x,
            w=w,
            path_table=path_table,
            path_code=path_code,
            label=label,
            bias=None,
            num_classes=num_classes)
        self.outputs = {'PreOut': pre_output, 'Out': out}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(['X', 'W'], ['Out'], no_grad_set=set('Label'))
if __name__ == '__main__':
    unittest.main()
@@ -185,6 +185,25 @@ class TestBook(unittest.TestCase):
                input=x, label=y, num_classes=2))
        print(str(program))
        # test hsigmoid with custom tree structure
        program2 = Program()
        with program_guard(program2):
            x2 = layers.data(name='x2', shape=[4, 8], dtype='float32')
            y2 = layers.data(name='y2', shape=[4], dtype='int64')
            path_table = layers.data(
                name='path_table', shape=[4, 6], dtype='int64')
            path_code = layers.data(
                name='path_code', shape=[4, 6], dtype='int64')
            self.assertIsNotNone(
                layers.hsigmoid(
                    input=x2,
                    label=y2,
                    num_classes=6,
                    path_table=path_table,
                    path_code=path_code,
                    is_custom=True))
            print(str(program2))
    def test_sequence_expand(self):
        program = Program()
        with program_guard(program):
......