diff --git a/paddle/operators/hierarchical_sigmoid_op.cc b/paddle/operators/hierarchical_sigmoid_op.cc index 1f77ff1268a6a489d2c5377762fb837492737288..9b7af92662e403c94964b49be206a73e2b26faad 100644 --- a/paddle/operators/hierarchical_sigmoid_op.cc +++ b/paddle/operators/hierarchical_sigmoid_op.cc @@ -83,19 +83,18 @@ class HierarchicalSigmoidOpMaker : public framework::OpProtoAndCheckerMaker { : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "(TensorArray, required) The input array. Each Tensor has the " - "same shape with [N * D]." - .AsDuplicable(); + "same shape with [N * D].") + .AsDuplicable(); AddInput("Label", "(Tensor, required), The labels of training data. It's a" "1-D tensor."); AddInput("Bias", "(Tensor, optional), The bias is a 1-D tensor, " "which is applied to the output"); - AddOutput("Out", - "(Tensor, required) The output of hierarchical sigmoid operator."); - AddAttr("num_classes", - "(int, required)", - "The number of classes"); + AddOutput( + "Out", + "(Tensor, required) The output of hierarchical sigmoid operator."); + AddAttr("num_classes", "(int, required)", "The number of classes"); AddComment(R"DOC( The hierarchical sigmoid operator organize the classes into a binary tree. At each node, a sigmoid function is used to caculate the probability of diff --git a/paddle/operators/hierarchical_sigmoid_op.h b/paddle/operators/hierarchical_sigmoid_op.h index 8a753605d6ec6672046dad397979777f180ed666..11a553a40392c783c3b74efe72c820ca5314852c 100644 --- a/paddle/operators/hierarchical_sigmoid_op.h +++ b/paddle/operators/hierarchical_sigmoid_op.h @@ -22,7 +22,21 @@ template class HierarchicalSigmoidOpKernel : public framework::OpKernel { public: - void Compute(const framework::ExecutionContext& ctx) const override {} + void Compute(const framework::ExecutionContext& ctx) const override { + auto ins = ctx.MultiInput("X"); + auto* label = ctx.Input("Label"); + auto* bias = ctx.Input("Bias"); + size_t num_classes = static_cast(ctx.Attr("num_classes")); + int64_t batch_size = ins[0]->dims()[0]; + int64_t size = ins.size(); + framework::Tensor pre_out; + std::vector pre_out_dims({batch_size, size}); + pre_out.mutable_data(framework::make_ddim(pre_out_dims), ctx.GetPlace()); + + if (bias != NULL) { + math::AddByBitCode(num_classes, *label, pre_out, *bias); + } + } }; template diff --git a/paddle/operators/math/matrix_bit_code.cc b/paddle/operators/math/matrix_bit_code.cc index 3f1dbbf39943794a784033ba022972a5f67abeec..30c2ffc2cfdacc89a01721265137eb3cfb496af8 100644 --- a/paddle/operators/math/matrix_bit_code.cc +++ b/paddle/operators/math/matrix_bit_code.cc @@ -50,7 +50,7 @@ namespace math { for j < codeLength: op(a(i, j), b(0, index(i, j))) */ -template +template static void AddByBitCodeT(Op op, CodeTable code_table, const framework::Tensor& codes, framework::Tensor& a, framework::Tensor& b) { @@ -72,11 +72,11 @@ static void AddByBitCodeT(Op op, CodeTable code_table, /* For j < codeLength: a(i, j) += b(0, index(i, j)) */ -template +template void AddByBitCode(size_t num_classes, const framework::Tensor& codes, framework::Tensor& a, const framework::Tensor& b) { auto op = [](T& t, T& v) { t += v; }; - AddByBitCodeT(op, SimpleCodeTable(num_classes), codes, a, b); + AddByBitCodeT(op, SimpleCodeTable(num_classes), codes, a, b); } } // namespace math diff --git a/paddle/operators/math/matrix_bit_code.h b/paddle/operators/math/matrix_bit_code.h index a0dd89ebe018e7d11afa6bbe743d11f76cb3d031..bb0599aa17716eb593b05ef7526c663bef6a2c98 100644 --- a/paddle/operators/math/matrix_bit_code.h +++ b/paddle/operators/math/matrix_bit_code.h @@ -59,6 +59,10 @@ struct SimpleCodeTable { int max_code_length_; }; +template +void AddByBitCode(size_t num_classes, const framework::Tensor& codes, + framework::Tensor& a, const framework::Tensor& b); + } // namespace math } // namespace operators } // namespace paddle