diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index e15fdc82572d543dfc18a71b2b98e4ee59275a44..c40f6033419a2425d9996eb9a4584fc9cd1a70e3 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -98,7 +98,7 @@ paddle.fluid.layers.sequence_reshape ArgSpec(args=['input', 'new_dim'], varargs= paddle.fluid.layers.transpose ArgSpec(args=['x', 'perm', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.im2sequence ArgSpec(args=['input', 'filter_size', 'stride', 'padding', 'input_image_size', 'out_stride', 'name'], varargs=None, keywords=None, defaults=(1, 1, 0, None, 1, None)) paddle.fluid.layers.nce ArgSpec(args=['input', 'label', 'num_total_classes', 'sample_weight', 'param_attr', 'bias_attr', 'num_neg_samples', 'name', 'sampler', 'custom_dist', 'seed', 'is_sparse'], varargs=None, keywords=None, defaults=(None, None, None, None, None, 'uniform', None, 0, False)) -paddle.fluid.layers.hsigmoid ArgSpec(args=['input', 'label', 'num_classes', 'param_attr', 'bias_attr', 'name'], varargs=None, keywords=None, defaults=(None, None, None)) +paddle.fluid.layers.hsigmoid ArgSpec(args=['input', 'label', 'num_classes', 'param_attr', 'bias_attr', 'name', 'path_table', 'path_code', 'is_custom', 'is_sparse'], varargs=None, keywords=None, defaults=(None, None, None, None, None, False, False)) paddle.fluid.layers.beam_search ArgSpec(args=['pre_ids', 'pre_scores', 'ids', 'scores', 'beam_size', 'end_id', 'level', 'name'], varargs=None, keywords=None, defaults=(0, None)) paddle.fluid.layers.row_conv ArgSpec(args=['input', 'future_context_size', 'param_attr', 'act'], varargs=None, keywords=None, defaults=(None, None)) paddle.fluid.layers.multiplex ArgSpec(args=['inputs', 'index'], varargs=None, keywords=None, defaults=None) diff --git a/paddle/fluid/framework/selected_rows.h b/paddle/fluid/framework/selected_rows.h index 55ca02038e083da4f8984f70fecf4ca2d878088e..44384082dbaf7a8d654e8461da87009bde33a3d5 100644 --- a/paddle/fluid/framework/selected_rows.h +++ b/paddle/fluid/framework/selected_rows.h @@ -120,8 +120,22 @@ class SelectedRows { */ int64_t AutoGrownIndex(int64_t key, bool auto_grown, bool is_test = false); - void SyncIndex(); + /* + * @brief Get the index of the key from id_to_index_ map. + */ + inline int64_t GetIndexFromId(int64_t key) { + auto iter = id_to_index_.find(key); + if (iter == id_to_index_.end()) { + return -1; + } else { + return iter->second; + } + } + void SyncIndex(); + /* + * @brief Get complete Dims before + */ DDim GetCompleteDims() const { std::vector dims = vectorize(value_->dims()); dims[0] = height_; @@ -133,9 +147,10 @@ class SelectedRows { // SelectedRows are simply concated when adding together. Until a // SelectedRows add a Tensor, will the duplicate rows be handled. Vector rows_; - std::unordered_map id_to_index_; + std::unordered_map + id_to_index_; // should not be used when rows_ has duplicate member std::unique_ptr value_{nullptr}; - int64_t height_; + int64_t height_; // height indicates the underline tensor's height std::unique_ptr rwlock_{nullptr}; }; diff --git a/paddle/fluid/operators/hierarchical_sigmoid_op.cc b/paddle/fluid/operators/hierarchical_sigmoid_op.cc index dadd054b9a6f8d44f4e5832888052bffde34c827..972dcf5494e9acd47e7ff615db45f056a43724a6 100644 --- a/paddle/fluid/operators/hierarchical_sigmoid_op.cc +++ b/paddle/fluid/operators/hierarchical_sigmoid_op.cc @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/hierarchical_sigmoid_op.h" +#include #include - namespace paddle { namespace operators { @@ -70,13 +70,14 @@ class HierarchicalSigmoidOp : public framework::OperatorWithKernel { const int64_t batch_size = ctx->GetInputDim("X")[0]; std::vector output_shape({batch_size, 1}); ctx->SetOutputDim("Out", framework::make_ddim(output_shape)); + ctx->ShareLoD("X", /*->*/ "Out"); } protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return framework::OpKernelType( - framework::ToDataType(ctx.Input("X")->type()), + framework::ToDataType(ctx.Input("X")->type()), ctx.GetPlace()); } }; @@ -86,27 +87,40 @@ class HierarchicalSigmoidOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { AddInput("X", - "(Tensor, required) The input tensor with shape [N, D], " + "(LoDTensor, required) The input tensor with shape [N, D], " "where N is the size of mini-batch, and D is the feature size."); AddInput("W", - "(Tensor, required), The parameters of hierarchical " + "(LoDTensor, required), The parameters of hierarchical " "sigmoid operator, each of them is a 2-D tensor, the shape is" - "[num_classes - 1, D]."); + "[K, D]. Which K is the num of non-leaf node in Path Tree"); AddInput("Label", - "(Tensor, required), The labels of training data. It's a" + "(LoDTensor, required), The labels of training data. It's a" "tensor with shape [N, 1]."); + AddInput("PTable", + "(LoDTensor, optional), The Path Table from root to current word" + "it should have shape like [N, L], L is the length of the Path") + .AsDispensable(); + AddInput( + "PathCode", + "(LoDTensor, optional), The Code on each Node of the Path from root " + "to current word" + "it should have shape like [N, L], L is the length of the Path") + .AsDispensable(); AddInput("Bias", - "(Tensor, optional), The bias is a tensor with shape" - "[1, num_classes - 1]."); - AddOutput("Out", - "(Tensor, required) The output of hierarchical sigmoid operator." - "The shape is [N, 1]."); + "(LoDTensor, optional), The bias is a tensor with shape or " + "[num_classes, 1]" + "[num_classes - 1, 1].") + .AsDispensable(); + AddOutput( + "Out", + "(LoDTensor, required) The output of hierarchical sigmoid operator." + "The shape is [N, 1]."); AddOutput("PreOut", - "(Tensor, required) A intermedia 2-D tensor with shape " + "(LoDTensor, required) A intermedia 2-D tensor with shape " "[batch_size, code_length], where code_length represents the " "maximum path length from root to leaf nodes.") .AsIntermediate(); - AddAttr("num_classes", "(int, required), The number of classes") + AddAttr("num_classes", "(int, optional), The number of classes") .SetDefault(2); AddComment(R"DOC( The hierarchical sigmoid operator organize the classes into a binary tree. @@ -115,6 +129,10 @@ belonging to the right branch. This idea is from "F. Morin, Y. Bengio (AISTATS 05): Hierarchical Probabilistic Neural Network Language Model." )DOC"); + AddAttr("is_sparse", + "(boolean, default false) " + "Sparse update.") + .SetDefault(false); } }; @@ -124,16 +142,21 @@ class HierarchicalSigmoidGradOp : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("W"), "Input(W) should not be null."); PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should not be null."); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), + "Input(Out@Grad) should not be null"); PADDLE_ENFORCE(ctx->HasInput("PreOut"), "Input(Preout) should not be null."); PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("W")), - "Output(W@Grad should not be null.)"); - PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X"))); - if (ctx->HasOutput(framework::GradVarName("Bias"))) { - ctx->SetOutputDim(framework::GradVarName("Bias"), - ctx->GetInputDim("Bias")); + "Output(W@Grad should not be null."); + PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")), + "Output(X@Grad should not be null."); + if (!ctx->Attrs().Get("is_sparse")) { + if (ctx->HasOutput(framework::GradVarName("Bias"))) { + ctx->SetOutputDim(framework::GradVarName("Bias"), + ctx->GetInputDim("Bias")); + } + ctx->SetOutputDim(framework::GradVarName("W"), ctx->GetInputDim("W")); } - ctx->SetOutputDim(framework::GradVarName("W"), ctx->GetInputDim("W")); ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); } @@ -141,11 +164,55 @@ class HierarchicalSigmoidGradOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return framework::OpKernelType( - framework::ToDataType(ctx.Input("X")->type()), + framework::ToDataType(ctx.Input("X")->type()), ctx.GetPlace()); } }; +class HierarchicalSigmoidGradOpGradVarTypeInference + : public framework::VarTypeInference { + public: + void operator()(const framework::OpDesc& op_desc, + framework::BlockDesc* block) const override { + auto w_grad_var_name = op_desc.Output(framework::GradVarName("W")).front(); + auto bias_grad_var_name_vec = + op_desc.Output(framework::GradVarName("Bias")); + std::string bias_grad_var_name; + bool hasBias = false; + if (bias_grad_var_name_vec.size()) { + hasBias = true; + bias_grad_var_name = + op_desc.Output(framework::GradVarName("Bias")).front(); + } + auto attr = op_desc.GetAttr("is_sparse"); + bool is_sparse = boost::get(attr); + if (is_sparse) { + VLOG(30) << "hierarchical_sigmoid_grad op " << framework::GradVarName("W") + << " is set to SelectedRows"; + block->Var(w_grad_var_name) + ->SetType(framework::proto::VarType::SELECTED_ROWS); + if (hasBias) { + VLOG(30) << "hierarchical_sigmoid_grad op " + << framework::GradVarName("Bias") << " is set to SelectedRows"; + block->Var(bias_grad_var_name) + ->SetType(framework::proto::VarType::SELECTED_ROWS); + } + } else { + VLOG(30) << "hierarchical_sigmoid_grad op " << framework::GradVarName("W") + << " is set to LoDTensor"; + block->Var(w_grad_var_name) + ->SetType(framework::proto::VarType::LOD_TENSOR); + if (hasBias) { + VLOG(30) << "hierarchical_sigmoid_grad op " + << framework::GradVarName("Bias") << " is set to LoDTensor"; + block->Var(bias_grad_var_name) + ->SetType(framework::proto::VarType::LOD_TENSOR); + } + } + block->Var(w_grad_var_name)->SetDataType(block->Var("W")->GetDataType()); + } +}; + } // namespace operators } // namespace paddle @@ -153,7 +220,8 @@ namespace ops = paddle::operators; REGISTER_OPERATOR(hierarchical_sigmoid, ops::HierarchicalSigmoidOp, ops::HierarchicalSigmoidOpMaker, paddle::framework::DefaultGradOpDescMaker); -REGISTER_OPERATOR(hierarchical_sigmoid_grad, ops::HierarchicalSigmoidGradOp); +REGISTER_OPERATOR(hierarchical_sigmoid_grad, ops::HierarchicalSigmoidGradOp, + ops::HierarchicalSigmoidGradOpGradVarTypeInference); REGISTER_OP_CPU_KERNEL( hierarchical_sigmoid, ops::HierarchicalSigmoidOpKernel, diff --git a/paddle/fluid/operators/hierarchical_sigmoid_op.h b/paddle/fluid/operators/hierarchical_sigmoid_op.h index 79980cda53befc2bce3cbd79a15da58b39c922ad..07ff8f947e59d2954783e2ba537bfce3cb320f22 100644 --- a/paddle/fluid/operators/hierarchical_sigmoid_op.h +++ b/paddle/fluid/operators/hierarchical_sigmoid_op.h @@ -14,12 +14,16 @@ limitations under the License. */ #pragma once #include +#include #include +#include "paddle/fluid/framework/mixed_vector.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/clip_op.h" +#include "paddle/fluid/operators/detail/safe_ref.h" #include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/matrix_bit_code.h" #include "paddle/fluid/platform/transform.h" + namespace paddle { namespace operators { @@ -28,20 +32,38 @@ template ; using platform::Transform; +static std::vector PathToRows(const framework::LoDTensor& path) { + std::set rows; + for (int64_t i = 0; i < path.numel(); ++i) { + int64_t row = path.data()[i]; + if (row < 0) { + continue; + } + rows.emplace(row); + } + return std::vector(rows.begin(), rows.end()); +} template class HierarchicalSigmoidOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - auto* in = ctx.Input("X"); - auto* w = ctx.Input("W"); - auto* label = ctx.Input("Label"); - auto* bias = ctx.Input("Bias"); - auto* out = ctx.Output("Out"); - auto* pre_out = ctx.Output("PreOut"); + auto& in = detail::Ref(ctx.Input("X")); + auto& w = detail::Ref(ctx.Input("W")); + auto* path = ctx.Input("PTable"); + auto* code = ctx.Input("PathCode"); + auto& label = detail::Ref(ctx.Input("Label")); + auto* bias = ctx.Input("Bias"); + auto* out = ctx.Output("Out"); + auto* pre_out = ctx.Output("PreOut"); size_t num_classes = static_cast(ctx.Attr("num_classes")); - int64_t code_length = math::FindLastSet(num_classes - 1); - int64_t batch_size = in->dims()[0]; - framework::Tensor sum; + bool is_custom = false; + if (path) { + is_custom = true; + } + int64_t code_length = + path ? path->dims()[1] : math::FindLastSet(num_classes - 1); + int64_t batch_size = in.dims()[0]; + framework::LoDTensor sum; auto& dev_ctx = ctx.template device_context(); auto* pre_out_data = pre_out->mutable_data( framework::make_ddim({batch_size, code_length}), ctx.GetPlace()); @@ -52,7 +74,15 @@ class HierarchicalSigmoidOpKernel : public framework::OpKernel { zero(dev_ctx, pre_out, static_cast(0.0)); auto& place = *ctx.template device_context().eigen_device(); math::RowwiseSum row_sum; - math::MatrixBitCodeFunctor bit_code(num_classes, label->data()); + + std::unique_ptr> bit_code; + if (!is_custom) { + bit_code.reset(new math::MatrixBitCodeFunctor(num_classes, + label.data())); + } else { + bit_code.reset(new math::MatrixBitCodeFunctor(*path, *code, + label.data())); + } std::vector sum_dims({batch_size, 1UL}); sum.mutable_data(framework::make_ddim(sum_dims), ctx.GetPlace()); @@ -60,15 +90,15 @@ class HierarchicalSigmoidOpKernel : public framework::OpKernel { out->mutable_data(ctx.GetPlace()); auto out_mat = framework::EigenVector::Flatten(*out); if (bias) { - bit_code.Add(pre_out, *bias); + bit_code->Add(*bias, pre_out); } - bit_code.Mul(pre_out, *w, *in); + bit_code->Mul(pre_out, w, in); // clip to [-40, 40] Transform trans; trans(ctx.template device_context(), pre_out_data, pre_out_data + pre_out->numel(), pre_out_data, ClipFunctor(static_cast(-40.0), static_cast(40.0))); - bit_code.Sum(*pre_out, out, static_cast(-1)); + bit_code->Sum(*pre_out, out, static_cast(-1)); // use softrelu to calculate cross entropy pre_out_mat.device(place) = (static_cast(1.0) + pre_out_mat.exp()).log(); row_sum(dev_ctx, *pre_out, &sum); @@ -84,50 +114,103 @@ template class HierarchicalSigmoidGradOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - auto* in = ctx.Input("X"); - auto* w = ctx.Input("W"); - auto* in_grad = ctx.Output(framework::GradVarName("X")); - auto* w_grad = ctx.Output(framework::GradVarName("W")); - auto* bias_grad = - ctx.Output(framework::GradVarName("Bias")); - auto* label = ctx.Input("Label"); - auto* pre_out = ctx.Input("PreOut"); - auto* out_grad = - ctx.Input(framework::GradVarName("Out")); - framework::Tensor pre_out_grad; - - pre_out_grad.mutable_data(pre_out->dims(), ctx.GetPlace()); - in_grad->mutable_data(ctx.GetPlace()); - w_grad->mutable_data(ctx.GetPlace()); + auto& in = detail::Ref(ctx.Input("X")); + auto& w = detail::Ref(ctx.Input("W")); + auto* path = ctx.Input("PTable"); + auto* code = ctx.Input("PathCode"); + auto* bias = ctx.Input("Bias"); + auto* in_grad = + ctx.Output(framework::GradVarName("X")); + bool is_sparse = ctx.Attr("is_sparse"); auto& dev_ctx = ctx.template device_context(); math::SetConstant zero; + auto& label = detail::Ref(ctx.Input("Label")); + auto& pre_out = detail::Ref(ctx.Input("PreOut")); + auto& out_grad = detail::Ref( + ctx.Input(framework::GradVarName("Out"))); + framework::LoDTensor pre_out_grad; + + pre_out_grad.mutable_data(pre_out.dims(), ctx.GetPlace()); + in_grad->mutable_data(ctx.GetPlace()); zero(dev_ctx, in_grad, static_cast(0.0)); - zero(dev_ctx, w_grad, static_cast(0.0)); size_t num_classes = static_cast(ctx.Attr("num_classes")); - math::MatrixBitCodeFunctor bit_code(num_classes, label->data()); + + bool is_custom = false; + if (path) { + is_custom = true; + } + + std::unique_ptr> bit_code; + if (!is_custom) { + bit_code.reset(new math::MatrixBitCodeFunctor(num_classes, + label.data())); + } else { + bit_code.reset(new math::MatrixBitCodeFunctor(*path, *code, + label.data())); + } auto& place = *ctx.template device_context().eigen_device(); - auto pre_out_mat = EigenMatrix::From(*pre_out); + auto pre_out_mat = EigenMatrix::From(pre_out); auto pre_out_grad_mat = EigenMatrix::From(pre_out_grad); - auto out_grad_mat = EigenMatrix::From(*out_grad); + auto out_grad_mat = EigenMatrix::From(out_grad); + Eigen::array bcast{1, static_cast(pre_out_grad.dims()[1])}; // softrelu derivative pre_out_grad_mat.device(place) = static_cast(1.0) - static_cast(1.0) / pre_out_mat.exp(); - bit_code.Sub(&pre_out_grad); // the gradient of clip(w * x + b) + bit_code->Sub(&pre_out_grad); // the gradient of clip(w * x + b) pre_out_grad_mat.device(place) = pre_out_grad_mat * out_grad_mat.broadcast(bcast); // TODO(guosheng): multiply pre_out_grad with subgradient of clipping to // be consistent with the clipping in forward. - if (bias_grad) { - bias_grad->mutable_data(ctx.GetPlace()); - zero(dev_ctx, bias_grad, static_cast(0.0)); - bit_code.AddGrad(pre_out_grad, bias_grad); + + if (!is_sparse) { + auto* bias_grad = + ctx.Output(framework::GradVarName("Bias")); + if (bias_grad) { + bias_grad->mutable_data(ctx.GetPlace()); + zero(dev_ctx, bias_grad, static_cast(0.0)); + bit_code->AddGrad(pre_out_grad, bias_grad); + } + auto* w_grad = + ctx.Output(framework::GradVarName("W")); + w_grad->mutable_data(ctx.GetPlace()); + zero(dev_ctx, w_grad, static_cast(0.0)); + bit_code->MulGradWeight(pre_out_grad, w_grad, in); + } else { + framework::Vector real_rows = PathToRows(*path); + auto* w_grad = + ctx.Output(framework::GradVarName("W")); + w_grad->set_rows(real_rows); + // Build a map of id -> row_index to speed up finding the index of one id + w_grad->SyncIndex(); + w_grad->set_height(w.dims()[0]); + auto* w_grad_value = w_grad->mutable_value(); + framework::DDim temp_dim(w.dims()); + set(temp_dim, 0, real_rows.size()); + + w_grad_value->mutable_data(temp_dim, ctx.GetPlace()); + zero(dev_ctx, w_grad_value, static_cast(0.0)); + auto* bias_grad = + ctx.Output(framework::GradVarName("Bias")); + if (bias_grad) { + bias_grad->set_rows(real_rows); + // build ids -> rows index map + bias_grad->SyncIndex(); + bias_grad->set_height(bias->dims()[0]); + auto* bias_grad_value = bias_grad->mutable_value(); + std::vector dims = {static_cast(real_rows.size()), + bias->dims()[1]}; + bias_grad_value->mutable_data(framework::make_ddim(dims), + ctx.GetPlace()); + zero(dev_ctx, bias_grad_value, static_cast(0.0)); + bit_code->AddGrad(pre_out_grad, bias_grad); + } + bit_code->MulGradWeight(pre_out_grad, w_grad, in); } - bit_code.MulGradWeight(pre_out_grad, w_grad, *in); - bit_code.MulGradError(pre_out_grad, *w, in_grad); + bit_code->MulGradError(pre_out_grad, w, in_grad); } }; diff --git a/paddle/fluid/operators/math/matrix_bit_code.cc b/paddle/fluid/operators/math/matrix_bit_code.cc index 1e56e297396c6e37867a53f039478191f0caf08e..71b9293eeded77553ca06a8574cca3941fa36b6a 100644 --- a/paddle/fluid/operators/math/matrix_bit_code.cc +++ b/paddle/fluid/operators/math/matrix_bit_code.cc @@ -19,16 +19,15 @@ namespace operators { namespace math { template -void MatrixBitCodeFunctor::Add(framework::Tensor* tmat, - const framework::Tensor& vec) { - SimpleCodeTable code_table(num_classes_); +void MatrixBitCodeFunctor::Add(const framework::Tensor& vec, + framework::Tensor* tmat) { size_t batch_size = tmat->dims()[0]; size_t width = tmat->dims()[1]; for (size_t i = 0; i < batch_size; ++i) { - auto code = code_table(static_cast(ids_[i])); - int code_length = code.get_length(); + auto code = code_table_->get_code(i); + int code_length = code->get_length(); for (int j = 0; j < code_length; ++j) { - size_t index = code.calc_index(j); + size_t index = code->calc_index(j); tmat->data()[i * width + j] += vec.data()[index]; } } @@ -37,31 +36,46 @@ void MatrixBitCodeFunctor::Add(framework::Tensor* tmat, template void MatrixBitCodeFunctor::AddGrad(const framework::Tensor& tmat, framework::Tensor* vec) { - SimpleCodeTable code_table(num_classes_); size_t batch_size = tmat.dims()[0]; size_t width = tmat.dims()[1]; for (size_t i = 0; i < batch_size; ++i) { - auto code = code_table(static_cast(ids_[i])); - int code_length = code.get_length(); + auto code = code_table_->get_code(i); + int code_length = code->get_length(); for (int j = 0; j < code_length; ++j) { - size_t index = code.calc_index(j); + size_t index = code->calc_index(j); vec->data()[index] += tmat.data()[i * width + j]; } } } +template +void MatrixBitCodeFunctor::AddGrad(const framework::Tensor& tmat, + framework::SelectedRows* vec) { + size_t batch_size = tmat.dims()[0]; + size_t width = tmat.dims()[1]; + for (size_t i = 0; i < batch_size; ++i) { + auto code = code_table_->get_code(i); + int code_length = code->get_length(); + for (int j = 0; j < code_length; ++j) { + size_t index = code->calc_index(j); + int64_t row_index = vec->GetIndexFromId(static_cast(index)); + vec->mutable_value()->data()[row_index] += + tmat.data()[i * width + j]; + } + } +} + template void MatrixBitCodeFunctor::Sum(const framework::Tensor& tmat, framework::Tensor* sum, T scale_sum) { - SimpleCodeTable code_table(num_classes_); size_t num_samples = tmat.dims()[0]; size_t o_width = tmat.dims()[1]; for (size_t i = 0; i < num_samples; ++i) { T sm = static_cast(0.0); - auto code = code_table(static_cast(ids_[i])); - int code_length = code.get_length(); + auto code = code_table_->get_code(i); + int code_length = code->get_length(); for (int j = 0; j < code_length; ++j) { - if (code.calc_bit(j)) { + if (code->calc_bit(j)) { // calc_bit starts from right most bit, while data in tmat[i] is in the // reverse order. sm += tmat.data()[i * o_width + j]; @@ -75,7 +89,6 @@ template void MatrixBitCodeFunctor::Mul(framework::Tensor* tmat, const framework::Tensor& weight, const framework::Tensor& input) { - SimpleCodeTable code_table(num_classes_); size_t num_samples = tmat->dims()[0]; size_t tmat_width = tmat->dims()[1]; size_t input_width = input.dims()[1]; @@ -84,10 +97,10 @@ void MatrixBitCodeFunctor::Mul(framework::Tensor* tmat, auto weight_value = weight.data(); auto input_value = input.data(); for (size_t i = 0; i < num_samples; ++i) { - auto code = code_table(static_cast(ids_[i])); - int code_length = code.get_length(); + auto code = code_table_->get_code(i); + int code_length = code->get_length(); for (int j = 0; j < code_length; ++j) { - size_t index = code.calc_index(j); + size_t index = code->calc_index(j); T sum = static_cast(0.0); for (size_t k = 0; k < input_width; ++k) { sum += weight_value[weight_width * index + k] * @@ -102,7 +115,6 @@ template void MatrixBitCodeFunctor::MulGradWeight(const framework::Tensor& tmat, framework::Tensor* weight, const framework::Tensor& input) { - SimpleCodeTable code_table(num_classes_); size_t num_samples = tmat.dims()[0]; size_t input_width = input.dims()[1]; size_t tmat_width = tmat.dims()[1]; @@ -111,10 +123,10 @@ void MatrixBitCodeFunctor::MulGradWeight(const framework::Tensor& tmat, auto weight_value = weight->data(); auto input_value = input.data(); for (size_t i = 0; i < num_samples; ++i) { - auto code = code_table(static_cast(ids_[i])); - int code_length = code.get_length(); + auto code = code_table_->get_code(i); + int code_length = code->get_length(); for (int j = 0; j < code_length; ++j) { - size_t index = code.calc_index(j); + size_t index = code->calc_index(j); for (size_t k = 0; k < input_width; ++k) { weight_value[weight_width * index + k] += @@ -124,11 +136,35 @@ void MatrixBitCodeFunctor::MulGradWeight(const framework::Tensor& tmat, } } +template +void MatrixBitCodeFunctor::MulGradWeight(const framework::Tensor& tmat, + framework::SelectedRows* weight, + const framework::Tensor& input) { + size_t num_samples = tmat.dims()[0]; + size_t input_width = input.dims()[1]; + size_t tmat_width = tmat.dims()[1]; + size_t weight_width = weight->value().dims()[1]; + auto tmat_value = tmat.data(); + auto weight_value = weight->mutable_value()->data(); + auto input_value = input.data(); + for (size_t i = 0; i < num_samples; ++i) { + auto code = code_table_->get_code(i); + int code_length = code->get_length(); + for (int j = 0; j < code_length; ++j) { + size_t index = code->calc_index(j); + for (size_t k = 0; k < input_width; ++k) { + int64_t row_index = weight->GetIndexFromId(static_cast(index)); + weight_value[row_index * weight_width + k] += + tmat_value[i * tmat_width + j] * input_value[input_width * i + k]; + } + } + } +} + template void MatrixBitCodeFunctor::MulGradError(const framework::Tensor& tmat, const framework::Tensor& weight, framework::Tensor* input) { - SimpleCodeTable code_table(num_classes_); size_t num_samples = tmat.dims()[0]; size_t tmat_width = tmat.dims()[1]; size_t input_width = input->dims()[1]; @@ -138,10 +174,10 @@ void MatrixBitCodeFunctor::MulGradError(const framework::Tensor& tmat, auto input_value = input->data(); for (size_t i = 0; i < num_samples; ++i) { - auto code = code_table(static_cast(ids_[i])); - int code_length = code.get_length(); + auto code = code_table_->get_code(i); + int code_length = code->get_length(); for (int j = 0; j < code_length; ++j) { - size_t index = code.calc_index(j); + size_t index = code->calc_index(j); for (size_t k = 0; k < input_width; ++k) { input_value[input_width * i + k] += @@ -154,14 +190,13 @@ void MatrixBitCodeFunctor::MulGradError(const framework::Tensor& tmat, template void MatrixBitCodeFunctor::Sub(framework::Tensor* tmat) { - SimpleCodeTable code_table(num_classes_); size_t num_samples = tmat->dims()[0]; size_t o_width = tmat->dims()[1]; for (size_t i = 0; i < num_samples; ++i) { - auto code = code_table(static_cast(ids_[i])); - int code_length = code.get_length(); + auto code = code_table_->get_code(i); + int code_length = code->get_length(); for (int j = 0; j < code_length; ++j) { - if (code.calc_bit(j)) { + if (code->calc_bit(j)) { tmat->data()[i * o_width + j] -= 1; } } diff --git a/paddle/fluid/operators/math/matrix_bit_code.h b/paddle/fluid/operators/math/matrix_bit_code.h index c329b8b6113e847ec1c57e63258a18b6f65d9396..c30bb52641e865efe57659a551bc4b493634c6b9 100644 --- a/paddle/fluid/operators/math/matrix_bit_code.h +++ b/paddle/fluid/operators/math/matrix_bit_code.h @@ -14,6 +14,8 @@ limitations under the License. */ #pragma once #include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/platform/device_context.h" @@ -92,9 +94,27 @@ inline int clz(const T& value) { inline size_t FindLastSet(size_t x) { return sizeof(size_t) * 8 - clz(x); } #endif // !_WIN32 +// set a code interface to create multiple code +class Code { + public: + virtual ~Code() {} + virtual size_t calc_index(int bit) const = 0; + virtual bool calc_bit(int bit) const = 0; + virtual int get_length() const = 0; +}; +// set a CodeTable interface to create multiple code table +class CodeTable { + public: + virtual std::unique_ptr get_code(int64_t code) const = 0; + virtual size_t size() const = 0; + virtual int get_max_code_length() const = 0; + virtual ~CodeTable() {} +}; -struct SimpleCode { - SimpleCode(size_t code, size_t num_classes) : c_(code + num_classes) {} +class SimpleCode : public Code { + public: + SimpleCode(size_t code, size_t num_classes, const int64_t* ids) + : c_(static_cast(ids[code]) + num_classes) {} /** * Here the id of root shoud be 1 rather than 0, thus the encoding of class c * is `c + num_classes` and all siblings can get the same weight indice using @@ -104,41 +124,121 @@ struct SimpleCode { * Binary classification path is the suffixes of encoding, thus leave out the * left most bit in calc_bit. */ - inline size_t calc_index(int bit) const { return (c_ >> (bit + 1)) - 1; } - inline bool calc_bit(int bit) const { return c_ & (1 << bit); } - inline int get_length() const { return FindLastSet(c_) - 1; } + size_t calc_index(int bit) const { return (c_ >> (bit + 1)) - 1; } + bool calc_bit(int bit) const { return c_ & (1 << bit); } + int get_length() const { return FindLastSet(c_) - 1; } private: size_t c_; }; -struct SimpleCodeTable { - explicit SimpleCodeTable(size_t num_classes) : num_classes_(num_classes) {} - SimpleCode operator()(size_t code) const { - return SimpleCode(code, num_classes_); +template +class CustomCode : public Code { + public: + CustomCode(const framework::Tensor& ptable, const framework::Tensor& pcode, + const int64_t* ids, int index) + : ids_(ids), index_(index) { + ptable_ = ptable.Slice(index, index + 1); + pcode_ = pcode.Slice(index, index + 1); + } + /** + * Here the id of root shoud be 1 rather than 0, thus the encoding of class c + * is `c + num_classes` and all siblings can get the same weight indice using + * prefixes. + * Weight index is the prefixes of encoding, thus leave out the right most + * bit in calc_index. + * Binary classification path is the suffixes of encoding, thus leave out the + * left most bit in calc_bit. + */ + size_t calc_index(int bit) const { return ptable_.data()[bit]; } + bool calc_bit(int bit) const { return pcode_.data()[bit]; } + int get_length() const { + int length = 0; + + for (int i = 0; i < static_cast(ptable_.dims()[1]); i++) { + if (ptable_.data()[i] >= 0) { + length++; + } else { + return length; + } + } + return length; + } + + private: + framework::Tensor ptable_; + framework::Tensor pcode_; + const int64_t* ids_; + const int index_; +}; + +class SimpleCodeTable : public CodeTable { + public: + SimpleCodeTable(size_t num_classes, const int64_t* ids) + : num_classes_(num_classes), ids_(ids) {} + std::unique_ptr get_code(int64_t code) const { + std::unique_ptr coder(new SimpleCode(code, num_classes_, ids_)); + return coder; } size_t size() const { return num_classes_; } int get_max_code_length() const { return FindLastSet(num_classes_ - 1); } private: size_t num_classes_; + const int64_t* ids_; +}; + +template +class CustomCodeTable : public CodeTable { + public: + CustomCodeTable(const framework::Tensor& ptable, + const framework::Tensor& pcode, const int64_t* ids) + : ptable_(ptable), pcode_(pcode), ids_(ids) {} + + std::unique_ptr get_code(int64_t code) const { + std::unique_ptr coder(new CustomCode(ptable_, pcode_, ids_, code)); + return coder; + } + + size_t size() const { return static_cast(ptable_.dims()[1]); } + int get_max_code_length() const { + return static_cast(ptable_.dims()[1]); + } + + private: + const framework::Tensor& ptable_; + const framework::Tensor& pcode_; + const int64_t* ids_; }; template class MatrixBitCodeFunctor { public: - explicit MatrixBitCodeFunctor(size_t num_classes, const int64_t* ids) - : num_classes_(num_classes), ids_(ids) {} + MatrixBitCodeFunctor(size_t num_classes, const int64_t* ids) + : num_classes_(num_classes), + ids_(ids), + code_table_(new SimpleCodeTable(num_classes, ids)) {} + + MatrixBitCodeFunctor(const framework::Tensor& ptable, + const framework::Tensor& pcode, const int64_t* ids) + : num_classes_(static_cast(ptable.dims()[1])), + ids_(ids), + code_table_(new CustomCodeTable(ptable, pcode, ids)) {} /* For j < code_length tmat(i, j) += vec(0, index(i, j)) */ - void Add(framework::Tensor* tmat, const framework::Tensor& vec); + void Add(const framework::Tensor& vec, framework::Tensor* tmat); /* For j < code_length vec(0, index(i, j)) += tmat(i, j) */ void AddGrad(const framework::Tensor& tmat, framework::Tensor* vec); + /* For selected rows For j < code_length + vec(0, index(i, j)) += tmat(i, j) + */ + void AddGrad(const framework::Tensor& tmat, framework::SelectedRows* vec); + /* For j < code_length sum(i, 0) = \sum_j bit(i, j) * tmat(i, j) */ @@ -159,6 +259,12 @@ class MatrixBitCodeFunctor { */ void MulGradWeight(const framework::Tensor& tmat, framework::Tensor* weight, const framework::Tensor& input); + /* For SelectedRows Weight, For index(i, j) >= 0: + weight.row(index(i, j)) += tmat(i, j) * input.row(i) + */ + void MulGradWeight(const framework::Tensor& tmat, + framework::SelectedRows* weight, + const framework::Tensor& input); /* For j < code_length input.row(i) += tmat(i, j) * weight.row(index(i, j)) */ @@ -167,6 +273,7 @@ class MatrixBitCodeFunctor { size_t num_classes_; const int64_t* ids_; + std::unique_ptr code_table_; }; } // namespace math } // namespace operators diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 48f571a7cc1374955460ffe22e1123baf27d87f6..4df74edfcebe4e8da7172c89f3958f3df2fd2c1f 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -4587,27 +4587,43 @@ def hsigmoid(input, num_classes, param_attr=None, bias_attr=None, - name=None): + name=None, + path_table=None, + path_code=None, + is_custom=False, + is_sparse=False): """ The hierarchical sigmoid operator is used to accelerate the training process of language model. This operator organizes the classes into a - complete binary tree, each leaf node represents a class(a word) and each + complete binary tree, or you can use is_custom to pass your own tree to + implement hierarchical. Each leaf node represents a class(a word) and each internal node acts as a binary classifier. For each word there's a unique path from root to it's leaf node, hsigmoid calculate the cost for each internal node on the path, and sum them to get a total cost. hsigmoid can achive a acceleration from :math:`O(N)` to :math:`O(logN)`, where :math:`N` represents the size of word dict. - Refer to `Hierarchical Probabilistic Neural Network Language Model + Using default tree you can Refer to `Hierarchical Probabilistic Neural Network Language Model `_ + And if you want to use the costumed tree by set 'is_custom' as true you may need to do following things first: + 1. using your word dict to build a binary tree, each leaf node should be an word of your word dict + 2. build a dict to store word_id -> word's leaf to root path, we call it path_table. + 3. build a dict to store word_id -> code of word's leaf to root path, we call it path_code. Code + means label of each binary classification, using 1 indicate true, 0 indicate false. + 4. now, each word should has its path and code along the path, you can pass a batch of path and code + related to the same batch of inputs. + + Args: input (Variable): The input tensor variable with shape :math:`[N \\times D]`, where :math:`N` is the size of mini-batch, and :math:`D` is the feature size. label (Variable): The tensor variable contains labels of training data. It's a tensor with shape is :math:`[N \\times 1]`. - num_classes: (int), The number of classes, must not be less than 2. + num_classes: (int), The number of classes, must not be less than 2. with default tree this has to be set, + it should never be None under is_custom=False, but while is_custom is true, it should be non leaf num + which indicates the num of classes using by binary classify. param_attr (ParamAttr|None): The parameter attribute for learnable parameters/weights of hsigmoid. If it is set to None or one attribute of ParamAttr, hsigmoid will create ParamAttr as param_attr. If the Initializer of the param_attr @@ -4619,9 +4635,19 @@ def hsigmoid(input, is not set, the bias is initialized zero. Default: None. name (str|None): A name for this layer(optional). If set None, the layer will be named automatically. Default: None. + path_table: (Variable|None) this variable can store each batch of samples' path to root, + it should be in leaf -> root order + path_table should have the same shape with path_code, and for each sample i path_table[i] indicates a np.array like + structure and each element in this array is indexes in parent nodes' Weight Matrix. + path_code: (Variable|None) this variable can store each batch of samples' code, + each code consist with every code of parent nodes. it should be in leaf -> root order + is_custom: (bool|False)using user defined binary tree instead of default complete binary tree, if costum is + set you need to set path_table/path_code/num_classes, otherwise num_classes should be set + is_sparse: (bool|False)using sparse update instead of dense update, if set, the gradient + of W and input will be sparse. Returns: - Out: (Tensor) The cost of hierarchical sigmoid operator. the shape is [N, 1] + Out: (LodTensor) The cost of hierarchical sigmoid operator. the shape is [N, 1] Examples: @@ -4637,27 +4663,62 @@ def hsigmoid(input, out = helper.create_variable_for_type_inference(dtype) pre_out = helper.create_variable_for_type_inference(dtype) dim = input.shape[1] - if num_classes < 2: - raise ValueError("num_classes must not be less than 2.") - weights = helper.create_parameter( - attr=helper.param_attr, - shape=[num_classes - 1, dim], - is_bias=False, - dtype=input.dtype) - inputs = {"X": input, "W": weights, "Label": label} - if helper.bias_attr: - bias = helper.create_parameter( - attr=helper.bias_attr, - shape=[1, num_classes - 1], - is_bias=True, + if ((num_classes is None) or (num_classes < 2)) and (not is_custom): + raise ValueError( + "num_classes must not be less than 2 with default tree") + + if (is_custom) and (path_code is None): + raise ValueError("path_code should not be None with costum tree") + elif (is_custom) and (path_table is None): + raise ValueError("path_table should not be None with costum tree") + elif (is_custom) and (num_classes is None): + raise ValueError("num_classes should not be None with costum tree") + else: + pass + + weights = None + + if not is_custom: + weights = helper.create_parameter( + attr=helper.param_attr, + shape=[num_classes - 1, dim], + is_bias=False, dtype=input.dtype) - inputs['Bias'] = bias + else: + weights = helper.create_parameter( + attr=helper.param_attr, + shape=[num_classes, dim], + is_bias=False, + dtype=input.dtype) + inputs = { + "X": input, + "W": weights, + "PTable": path_table, + "PathCode": path_code, + "Label": label + } + if helper.bias_attr: + if not is_custom: + bias = helper.create_parameter( + attr=helper.bias_attr, + shape=[num_classes - 1, 1], + is_bias=True, + dtype=input.dtype) + inputs['Bias'] = bias + else: + bias = helper.create_parameter( + attr=helper.bias_attr, + shape=[num_classes, 1], + is_bias=True, + dtype=input.dtype) + inputs['Bias'] = bias helper.append_op( type="hierarchical_sigmoid", inputs=inputs, outputs={"Out": out, "PreOut": pre_out}, - attrs={"num_classes": num_classes}) + attrs={"num_classes": num_classes, + "is_sparse": is_sparse}) return out diff --git a/python/paddle/fluid/tests/unittests/test_hsigmoid_op.py b/python/paddle/fluid/tests/unittests/test_hsigmoid_op.py index 6948ae30023a75d4735db1c78466e89e28640c9e..2a6c93f75fad53440a2db64e4f34c9a5c22c654e 100644 --- a/python/paddle/fluid/tests/unittests/test_hsigmoid_op.py +++ b/python/paddle/fluid/tests/unittests/test_hsigmoid_op.py @@ -16,6 +16,8 @@ from __future__ import print_function import unittest import numpy as np +import paddle.fluid.core as core +import paddle.fluid as fluid import math from op_test import OpTest @@ -40,6 +42,29 @@ class CodeTable(object): return self.c & (1 << bit) +class CodeTableWithCustomTree(object): + def __init__(self, path_table, path_code, index): + self.ptable_ = path_table + self.pcode_ = path_code + self.index_ = index + + def cal_index(self, bit): + return self.ptable_[self.index_][bit] + + def get_length(self): + length = 0 + for ele in self.ptable_[self.index_]: # find the first -1 to stop trace + + if ele >= 0: + length = length + 1 + else: + return length + return length + + def cal_bit(self, bit): + return self.pcode_[self.index_][bit] + + def hsigmoid(x, w, label, bias, num_classes): batch_size = x.shape[0] code_length = find_latest_set(num_classes - 1) @@ -52,7 +77,7 @@ def hsigmoid(x, w, label, bias, num_classes): length = code_table.get_length() for j in range(length): idx = code_table.cal_index(j) - pre_output[i][j] += bias[0][idx] + pre_output[i][j] += bias[idx][0] for i in range(batch_size): code_table = CodeTable(num_classes, label[i]) length = code_table.get_length() @@ -77,17 +102,58 @@ def hsigmoid(x, w, label, bias, num_classes): return pre_output, out +def hsigmoidWithCustomTree(x, w, path_table, path_code, label, bias, + num_classes): + batch_size = x.shape[0] + code_length = len(path_table[0]) + code_table = [0 for _ in range(code_length)] + # init pre_out with shape [N, code_length] + pre_output = np.zeros((batch_size, code_length)) + pre_sum = np.zeros((batch_size, 1)) + out = np.zeros((batch_size, 1)).astype("float32") + if isinstance(bias, np.ndarray): + for i in range(batch_size): + code_table = CodeTableWithCustomTree(path_table, path_code, i) + length = code_table.get_length() + for j in range(length): + idx = code_table.cal_index(j) + pre_output[i][j] += bias[idx][0] + for i in range(batch_size): + code_table = CodeTableWithCustomTree(path_table, path_code, i) + length = code_table.get_length() + for j in range(length): + idx = code_table.cal_index(j) + pre_output[i][j] += np.dot(w[idx], x[i]) + # clip[-40.0, 40.0] + pre_output = np.clip(pre_output, -40.0, 40.0) + # out(i, 0) = \sum_j bit(i, j) * preout(i, j) + for i in range(batch_size): + code_table = CodeTableWithCustomTree(path_table, path_code, i) + length = code_table.get_length() + sum = 0.0 + for j in range(length): + if code_table.cal_bit(j): + sum += pre_output[i][j] + out[i] = -1.0 * sum + # soft relu + pre_output = np.log(1 + np.exp(pre_output)) + pre_sum = pre_output.sum(1).reshape((batch_size, 1)) + out += pre_sum + return pre_output, out + + class TestHSigmoidOp(OpTest): def setUp(self): self.op_type = "hierarchical_sigmoid" num_classes = 6 feature_size = 8 batch_size = 4 - x = np.random.random((batch_size, feature_size)).astype("float32") - w = np.random.random((num_classes - 1, feature_size)).astype("float32") + x = np.random.random((batch_size, feature_size)).astype("float32") * 2 + w = np.random.random( + (num_classes - 1, feature_size)).astype("float32") * 2 label = np.random.randint(0, num_classes, (batch_size, 1)) - bias = np.random.random((1, num_classes - 1)).astype("float32") - self.attrs = {'num_classes': num_classes} + bias = np.random.random((num_classes - 1, 1)).astype("float32") + self.attrs = {'num_classes': num_classes, 'is_sparse': False} self.inputs = {'X': x, 'W': w, 'Label': label, 'Bias': bias} pre_output, out = hsigmoid(x, w, label, bias, num_classes) self.outputs = {'PreOut': pre_output, 'Out': out} @@ -99,5 +165,185 @@ class TestHSigmoidOp(OpTest): self.check_grad(['Bias', 'X', 'W'], ['Out'], no_grad_set=set('Label')) +class TestHSigmoidOpSparse(OpTest): + def setUp(self): + self.op_type = "hierarchical_sigmoid" + num_classes = 6 #using 1,2,3,4,5,6 to build a huffman tree and select 1,2,5,6 as sample + feature_size = 8 + batch_size = 4 + x = np.random.random((batch_size, feature_size)).astype("float32") + w = np.random.random((num_classes - 1, feature_size)).astype("float32") + label = np.array([0, 1, 4, 5]) + path_table = np.array( + [(0, 2, -1, -1, -1), (0, 1, 3, -1, -1), (0, 1, 4, -1, -1), + (0, 2, -1, -1, + -1)]) #np.array to store 1,2,5,6s' non-leaf path(root -> leaf) + path_code = np.array([(0, 0, -1, -1, -1), (1, 1, 1, -1, -1), ( + 1, 0, 0, -1, -1), (0, 1, -1, -1, -1)]) #np.array to store + bias = np.random.random((num_classes - 1, 1)).astype("float32") + self.attrs = {'num_classes': num_classes, 'is_sparse': True} + self.inputs = { + 'X': x, + 'W': w, + 'PTable': path_table, + 'PathCode': path_code, + 'Label': label, + 'Bias': bias + } + pre_output, out = hsigmoidWithCustomTree(x, w, path_table, path_code, + label, bias, num_classes) + self.outputs = {'PreOut': pre_output, 'Out': out} + + def test_check_output(self): + self.check_output() + + +class TestHSigmoidOpWithSparseGrad(unittest.TestCase): + def hs_net_conf(self, is_sparse): + input_word = fluid.layers.data(name="x", shape=[1], dtype='int64') + path_table = fluid.layers.data( + name='path_table', shape=[3], dtype='int64') + path_code = fluid.layers.data( + name='path_code', shape=[3], dtype='int64') + label = fluid.layers.data(name='label', shape=[1], dtype='int64') + + data_list = [input_word, path_table, path_code, label] + + emb = fluid.layers.embedding( + input=input_word, + is_sparse=is_sparse, + size=[3, 3], + param_attr=fluid.ParamAttr(initializer=fluid.initializer.Normal( + scale=1 / math.sqrt(3)))) + + cost = fluid.layers.hsigmoid( + input=emb, + label=label, + bias_attr=True, + num_classes=3, + path_table=path_table, + path_code=path_code, + is_custom=True, + is_sparse=is_sparse) + + avg_cost = fluid.layers.reduce_mean(cost) + + return avg_cost, data_list + + def training_test(self, is_sparse): + with fluid.program_guard(fluid.Program(), fluid.Program()): + start_up = fluid.default_startup_program() + start_up.random_seed = 1 # Fix random seed + x = np.arange(6).reshape(6) + path_table = np.array([(1, 2, -1), (1, 2, -1)]) + path_code = np.array([(1, 0, -1), (0, 0, -1)]) + label = np.array([1, 4]) + + loss, data_list = self.hs_net_conf(is_sparse) + optimizer = fluid.optimizer.SGD(learning_rate=1e-3) + optimizer.minimize(loss) + + main_program = fluid.default_main_program() + place = fluid.CPUPlace() + feeder = fluid.DataFeeder(feed_list=data_list, place=place) + exe = fluid.Executor(place) + + exe.run(start_up) + result = list() + for i in range(10): + data = [([[x[i % 2]]], [list(path_table[i % 2])], + [list(path_code[i % 2])], [label[i % 2]])] + + loss_val = exe.run(main_program, + feed=feeder.feed(data), + fetch_list=[loss]) + result.append(loss_val) + return result + + def test_hs_grad_with_sparse(self): + dense_result = self.training_test(is_sparse=False) + sparse_result = self.training_test(is_sparse=True) + assert (dense_result == sparse_result) + + +class TestHSigmoidOpWithCostumTree(OpTest): + def setUp(self): + self.op_type = "hierarchical_sigmoid" + num_classes = 6 #using 1,2,3,4,5,6 to build a huffman tree and select 1,2,5,6 as sample + feature_size = 8 + batch_size = 4 + x = np.random.random((batch_size, feature_size)).astype("float32") * 2 + w = np.random.random( + (num_classes - 1, feature_size)).astype("float32") * 2 + label = np.array([0, 1, 4, 5]) + path_table = np.array( + [(0, 2, -1, -1, -1), (0, 1, 3, -1, -1), (0, 1, 4, -1, -1), + (0, 2, -1, -1, + -1)]) #np.array to store 1,2,5,6s' non-leaf path(root -> leaf) + path_code = np.array([(0, 0, -1, -1, -1), (1, 1, 1, -1, -1), ( + 1, 0, 0, -1, -1), (0, 1, -1, -1, -1)]) #np.array to store + bias = np.random.random((num_classes - 1, 1)).astype("float32") + self.attrs = {'num_classes': num_classes, 'is_sparse': False} + self.inputs = { + 'X': x, + 'W': w, + 'PTable': path_table, + 'PathCode': path_code, + 'Label': label, + 'Bias': bias + } + pre_output, out = hsigmoidWithCustomTree(x, w, path_table, path_code, + label, bias, num_classes) + self.outputs = {'PreOut': pre_output, 'Out': out} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['Bias', 'X', 'W'], ['Out'], no_grad_set=set('Label')) + + +class TestHSigmoidOpWithCostumTreeWithoutBias(OpTest): + def setUp(self): + self.op_type = "hierarchical_sigmoid" + num_classes = 6 #using 1,2,3,4,5,6 to build a huffman tree and select 1,2,5,6 as sample + feature_size = 8 + batch_size = 4 + x = np.random.random((batch_size, feature_size)).astype("float32") * 2 + w = np.random.random( + (num_classes - 1, feature_size)).astype("float32") * 2 + label = np.array([0, 1, 4, 5]) + path_table = np.array( + [(0, 2, -1, -1, -1), (0, 1, 3, -1, -1), (0, 1, 4, -1, -1), + (0, 2, -1, -1, + -1)]) #np.array to store 1,2,5,6s' non-leaf path(root -> leaf) + path_code = np.array([(0, 0, -1, -1, -1), (1, 1, 1, -1, -1), ( + 1, 0, 0, -1, -1), (0, 1, -1, -1, -1)]) #np.array to store + # bias = np.random.random((num_classes - 1, 1)).astype("float32") + self.attrs = {'num_classes': num_classes, 'is_sparse': False} + self.inputs = { + 'X': x, + 'W': w, + 'PTable': path_table, + 'PathCode': path_code, + 'Label': label, + } + pre_output, out = hsigmoidWithCustomTree( + x=x, + w=w, + path_table=path_table, + path_code=path_code, + label=label, + bias=None, + num_classes=num_classes) + self.outputs = {'PreOut': pre_output, 'Out': out} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X', 'W'], ['Out'], no_grad_set=set('Label')) + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index 559c9cda4812e2c099f25b31dffd823a2fa7620d..541160771152dd2ebc8a782863bb4ad3643892e5 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -185,6 +185,25 @@ class TestBook(unittest.TestCase): input=x, label=y, num_classes=2)) print(str(program)) + # test hsigmod with custom tree structure + program2 = Program() + with program_guard(program2): + x2 = layers.data(name='x2', shape=[4, 8], dtype='float32') + y2 = layers.data(name='y2', shape=[4], dtype='int64') + path_table = layers.data( + name='path_table', shape=[4, 6], dtype='int64') + path_code = layers.data( + name='path_code', shape=[4, 6], dtype='int64') + self.assertIsNotNone( + layers.hsigmoid( + input=x2, + label=y2, + num_classes=6, + path_table=path_table, + path_code=path_code, + is_custom=True)) + print(str(program2)) + def test_sequence_expand(self): program = Program() with program_guard(program):