From a25c3aeba6d1a370e4b361f26f0e112fd00e7c4e Mon Sep 17 00:00:00 2001
From: Yancey1989
Date: Fri, 17 Nov 2017 10:31:44 +0800
Subject: [PATCH] add forward

---
 paddle/operators/hierarchical_sigmoid_op.cc | 13 ++++++-------
 paddle/operators/hierarchical_sigmoid_op.h  | 16 +++++++++++++++-
 paddle/operators/math/matrix_bit_code.cc    |  6 +++---
 paddle/operators/math/matrix_bit_code.h     |  4 ++++
 4 files changed, 28 insertions(+), 11 deletions(-)

diff --git a/paddle/operators/hierarchical_sigmoid_op.cc b/paddle/operators/hierarchical_sigmoid_op.cc
index 1f77ff126..9b7af9266 100644
--- a/paddle/operators/hierarchical_sigmoid_op.cc
+++ b/paddle/operators/hierarchical_sigmoid_op.cc
@@ -83,19 +83,18 @@ class HierarchicalSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X",
              "(TensorArray, required) The input array. Each Tensor has the "
-             "same shape with [N * D]."
-        .AsDuplicable();
+             "same shape with [N * D].")
+        .AsDuplicable();
     AddInput("Label",
              "(Tensor, required), The labels of training data. It's a"
              "1-D tensor.");
     AddInput("Bias",
              "(Tensor, optional), The bias is a 1-D tensor, "
             "which is applied to the output");
-    AddOutput("Out",
-              "(Tensor, required) The output of hierarchical sigmoid operator.");
-    AddAttr("num_classes",
-            "(int, required)",
-            "The number of classes");
+    AddOutput(
+        "Out",
+        "(Tensor, required) The output of hierarchical sigmoid operator.");
+    AddAttr("num_classes", "(int, required)", "The number of classes");
     AddComment(R"DOC(
 The hierarchical sigmoid operator organize the classes into a binary tree.
 At each node, a sigmoid function is used to caculate the probability of
diff --git a/paddle/operators/hierarchical_sigmoid_op.h b/paddle/operators/hierarchical_sigmoid_op.h
index 8a753605d..11a553a40 100644
--- a/paddle/operators/hierarchical_sigmoid_op.h
+++ b/paddle/operators/hierarchical_sigmoid_op.h
@@ -22,7 +22,21 @@
 template
 class HierarchicalSigmoidOpKernel : public framework::OpKernel {
  public:
-  void Compute(const framework::ExecutionContext& ctx) const override {}
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto ins = ctx.MultiInput("X");
+    auto* label = ctx.Input("Label");
+    auto* bias = ctx.Input("Bias");
+    size_t num_classes = static_cast(ctx.Attr("num_classes"));
+    int64_t batch_size = ins[0]->dims()[0];
+    int64_t size = ins.size();
+    framework::Tensor pre_out;
+    std::vector pre_out_dims({batch_size, size});
+    pre_out.mutable_data(framework::make_ddim(pre_out_dims), ctx.GetPlace());
+
+    if (bias != NULL) {
+      math::AddByBitCode(num_classes, *label, pre_out, *bias);
+    }
+  }
 };
 
 template
diff --git a/paddle/operators/math/matrix_bit_code.cc b/paddle/operators/math/matrix_bit_code.cc
index 3f1dbbf39..30c2ffc2c 100644
--- a/paddle/operators/math/matrix_bit_code.cc
+++ b/paddle/operators/math/matrix_bit_code.cc
@@ -50,7 +50,7 @@ namespace math {
     for j < codeLength:
       op(a(i, j), b(0, index(i, j)))
 */
-template
+template
 static void AddByBitCodeT(Op op, CodeTable code_table,
                           const framework::Tensor& codes, framework::Tensor& a,
                           framework::Tensor& b) {
@@ -72,11 +72,11 @@
 /* For j < codeLength:
      a(i, j) += b(0, index(i, j))
 */
-template
+template
 void AddByBitCode(size_t num_classes, const framework::Tensor& codes,
                   framework::Tensor& a, const framework::Tensor& b) {
   auto op = [](T& t, T& v) { t += v; };
-  AddByBitCodeT(op, SimpleCodeTable(num_classes), codes, a, b);
+  AddByBitCodeT(op, SimpleCodeTable(num_classes), codes, a, b);
 }
 
 }  // namespace math
diff --git a/paddle/operators/math/matrix_bit_code.h b/paddle/operators/math/matrix_bit_code.h
index a0dd89ebe..bb0599aa1 100644
--- a/paddle/operators/math/matrix_bit_code.h
+++ b/paddle/operators/math/matrix_bit_code.h
@@ -59,6 +59,10 @@ struct SimpleCodeTable {
   int max_code_length_;
 };
 
+template
+void AddByBitCode(size_t num_classes, const framework::Tensor& codes,
+                  framework::Tensor& a, const framework::Tensor& b);
+
 }  // namespace math
 }  // namespace operators
 }  // namespace paddle
-- 
GitLab
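
The forward kernel added in hierarchical_sigmoid_op.h allocates a pre_out buffer of shape [batch_size, size] and, when a Bias input is present, folds the bias into it through math::AddByBitCode, whose contract the patch states as "for j < codeLength: a(i, j) += b(0, index(i, j))". The minimal, standalone C++ sketch below illustrates that bit-code addition outside the framework. It is not part of the patch: SimpleCodeSketch and add_by_bit_code_sketch are hypothetical names, plain std::vector stands in for framework::Tensor, and the complete-binary-tree layout (leaf code = label + num_classes, node index = (code >> (j + 1)) - 1) is an assumption in the spirit of classic hierarchical softmax, since the internals of SimpleCodeTable are not shown in this diff.

// Standalone sketch of "for j < codeLength: a(i, j) += b(0, index(i, j))".
// Assumed layout only; the real SimpleCodeTable may differ.
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

// Hypothetical stand-in for the per-sample view of SimpleCodeTable.
struct SimpleCodeSketch {
  SimpleCodeSketch(std::size_t label, std::size_t num_classes)
      : c_(label + num_classes) {}
  // Number of internal nodes on the path from the root to this leaf.
  int get_length() const {
    int highest = 0;
    for (std::size_t v = c_; v > 0; v >>= 1) ++highest;
    return highest - 1;
  }
  // Row of the bias/weight table associated with the j-th node on the path.
  std::size_t calc_index(int j) const { return (c_ >> (j + 1)) - 1; }

 private:
  std::size_t c_;
};

// a(i, j) += b(index(i, j)) for every sample i and every node j on its path,
// i.e. the per-row effect of math::AddByBitCode on pre_out when Bias is given.
void add_by_bit_code_sketch(std::size_t num_classes,
                            const std::vector<int64_t>& labels,
                            std::vector<std::vector<double>>* a,
                            const std::vector<double>& b) {
  for (std::size_t i = 0; i < labels.size(); ++i) {
    SimpleCodeSketch code(static_cast<std::size_t>(labels[i]), num_classes);
    const int len = code.get_length();
    for (int j = 0; j < len; ++j) {
      (*a)[i][j] += b[code.calc_index(j)];
    }
  }
}

int main() {
  const std::size_t num_classes = 4;
  std::vector<int64_t> labels = {2, 0};
  // One column per code bit; the maximum code length for num_classes = 4
  // under this layout is 2.
  std::vector<std::vector<double>> pre_out(labels.size(),
                                           std::vector<double>(2, 0.0));
  std::vector<double> bias = {0.1, 0.2, 0.3};  // one entry per internal node
  add_by_bit_code_sketch(num_classes, labels, &pre_out, bias);
  for (std::size_t i = 0; i < pre_out.size(); ++i) {
    std::printf("row %zu: %.1f %.1f\n", i, pre_out[i][0], pre_out[i][1]);
  }
  return 0;
}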
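
A small worked example under the same assumed layout: with num_classes = 4 and label 2, the leaf code is 2 + 4 = 6 (binary 110), so the path crosses two internal nodes, with calc_index(0) = (6 >> 1) - 1 = 2 and calc_index(1) = (6 >> 2) - 1 = 0. The bias entries b[2] and b[0] are therefore added into columns 0 and 1 of that sample's pre_out row, which is the per-sample effect the if (bias != NULL) branch of the new Compute delegates to math::AddByBitCode.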