diff --git a/paddle/operators/hierarchical_sigmoid_op.cc b/paddle/operators/hierarchical_sigmoid_op.cc
index f81f3d34d1931ce2e3231a3fd60b6dda434e86dd..063f8576e660a62953f26bffb73a597d0f856ebe 100644
--- a/paddle/operators/hierarchical_sigmoid_op.cc
+++ b/paddle/operators/hierarchical_sigmoid_op.cc
@@ -60,12 +60,12 @@ class HierarchicalSigmoidOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
   void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInputs("X"), "Inputs(X) should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
     PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should not be null.");
     PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null.");
-    const int64_t batch_size = ctx->GetInputsDim("X")[0][0];
-    const int64_t size = ctx->GetInputsDim("X").size();
-    std::vector<int64_t> output_shape({batch_size, size});
+    const int64_t batch_size = ctx->GetInputDim("X")[0];
+    const int64_t num_classes = ctx->Attrs().Get<int>("num_classes");
+    std::vector<int64_t> output_shape({batch_size, num_classes - 1});
     ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
   }
 };
@@ -82,22 +81,23 @@ class HierarchicalSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
                                framework::OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X",
-             "(TensorArray, required) The input array. Each Tensor has the "
-             "same shape with [N * D].")
-        .AsDuplicable();
+             "(Tensor, required) The input tensor with shape [N, D], "
+             "where N is the size of the mini-batch and D is the "
+             "embedding size.");
     AddInput("Parameters",
              "(Tensor, required), The parameters of hierarchical "
-             "sigmoid operator, each of them is s a 2-D tensor.")
-        .AsDuplicable();
+             "sigmoid operator, a 3-D tensor with shape "
+             "[N, num_classes - 1, D].");
     AddInput("Label",
              "(Tensor, required), The labels of training data. It's a "
-             "1-D tensor.");
+             "1-D tensor with shape [1, N].");
     AddInput("Bias",
              "(Tensor, optional), The bias is a 1-D tensor, "
-             "which is applied to the output.");
+             "which is applied to the output, with shape "
+             "[1, num_classes - 1].");
-    AddOutput(
-        "Out",
-        "(Tensor, required) The output of hierarchical sigmoid operator.");
+    AddOutput("Out",
+              "(Tensor, required) The output of hierarchical sigmoid "
+              "operator, with shape [N, 1].");
     AddAttr<int>("num_classes", "(int, required) The number of classes.");
     AddComment(R"DOC(
The hierarchical sigmoid operator organizes the classes into a binary tree.
)DOC");
   }
 };
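
A note on the tree layout the shapes above rely on: each label is addressed
as a leaf of a complete binary tree built over num_classes classes, so one
sample touches at most FindLastSet(num_classes - 1) internal nodes, and
Parameters/Bias keep one row per internal node (hence the num_classes - 1
dimension). Below is a minimal Python mirror of that encoding; SimpleCode
and find_last_set are illustrative stand-ins for whatever code table
matrix_bit_code actually defines, not the operator's API:

    def find_last_set(x):
        # 1-based position of the most significant set bit; 0 for x == 0
        # (mirrors math::FindLastSet).
        return x.bit_length()


    class SimpleCode(object):
        # Illustrative stand-in for the code table: label `code` is
        # addressed as leaf (code + num_classes) of a complete binary
        # tree, and the path to it visits get_length() internal nodes.
        def __init__(self, code, num_classes):
            self.c = code + num_classes

        def get_length(self):
            return find_last_set(self.c) - 1

        def calc_index(self, bit):
            # Which row of Parameters/Bias step `bit` of the path uses.
            return (self.c >> (bit + 1)) - 1

        def calc_bit(self, bit):
            # Whether the path takes the "1" branch at step `bit`.
            return bool(self.c & (1 << bit))
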
diff --git a/paddle/operators/hierarchical_sigmoid_op.h b/paddle/operators/hierarchical_sigmoid_op.h
index 186c76793233ecb21db17f7040a202ab4e3f480c..e3f0bcacd8b556a73ecc57e80a99d981acb27a21 100644
--- a/paddle/operators/hierarchical_sigmoid_op.h
+++ b/paddle/operators/hierarchical_sigmoid_op.h
@@ -28,8 +28,8 @@ template <typename Place, typename T>
 class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-    auto ins = ctx.MultiInput<framework::Tensor>("X");
-    auto params = ctx.MultiInput<framework::Tensor>("Parameters");
+    auto* in = ctx.Input<framework::Tensor>("X");
+    auto* param = ctx.Input<framework::Tensor>("Parameters");
     auto* label = ctx.Input<framework::Tensor>("Label");
     auto* bias = ctx.Input<framework::Tensor>("Bias");
     auto* out = ctx.Output<framework::Tensor>("Out");
@@ -56,8 +56,9 @@ class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> {
       math::AddByBitCode<T>(num_classes, *label, pre_out, *bias);
     }
 
-    for (size_t i = 0; i < ins.size(); ++i) {
-      math::MulByBitCode<T>(num_classes, *label, pre_out, *params[i], *ins[i]);
+    for (int64_t i = 0; i < in->dims()[0]; ++i) {
+      math::MulByBitCode<T>(num_classes, *label, pre_out,
+                            param->Slice(i, i + 1), in->Slice(i, i + 1));
     }
     // clip the matrix with (-40, 40)
     pre_out_mat.device(place) =
@@ -79,11 +80,10 @@ template <typename Place, typename T>
 class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-    auto ins = ctx.MultiInput<framework::Tensor>("X");
-    auto ins_grad =
-        ctx.MultiOutput<framework::Tensor>(framework::GradVarName("X"));
-    auto params = ctx.MultiOutput<framework::Tensor>(
-        framework::GradVarName("Parameters"));
+    auto* in = ctx.Input<framework::Tensor>("X");
+    auto* in_grad = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
+    auto* params =
+        ctx.Output<framework::Tensor>(framework::GradVarName("Parameters"));
     auto* bias = ctx.Output<framework::Tensor>(framework::GradVarName("Bias"));
     auto* label = ctx.Output<framework::Tensor>(framework::GradVarName("Label"));
 
@@ -92,7 +92,7 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> {
     framework::Tensor pre_out;
     auto place = ctx.GetEigenDevice<Place>();
     auto& dev_ctx = ctx.device_context();
-    int64_t batch_size = ins_grad.size();
+    int64_t batch_size = in_grad->dims()[0];
     int64_t code_length = math::FindLastSet(num_classes - 1);
     auto pre_out_mat = EigenMatrix<T>::From(pre_out);
@@ -111,11 +111,12 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> {
       math::AddByBitCodeGrad<T>(num_classes, *label, pre_out, *bias);
     }
 
-    for (size_t i = 0; i < ins_grad.size(); ++i) {
-      math::MulByBitCodeGradWeight<T>(num_classes, *label, pre_out, *params[i],
-                                      *ins[i]);
-      math::MulByBitCodeGradError<T>(num_classes, *label, pre_out, *params[i],
-                                     *ins_grad[i]);
+    for (int64_t i = 0; i < in_grad->dims()[0]; ++i) {
+      math::MulByBitCodeGradWeight<T>(num_classes, *label, pre_out,
+                                      params->Slice(i, i + 1),
+                                      in->Slice(i, i + 1));
+      math::MulByBitCodeGradError<T>(num_classes, *label, pre_out,
+                                     params->Slice(i, i + 1),
+                                     in_grad->Slice(i, i + 1));
     }
   }
 };
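
Putting the forward kernel's steps together (bias add, per-sample dot
products, the (-40, 40) clip, then the loss reduction), here is a hedged
NumPy sketch of the whole forward pass, using the SimpleCode stand-in from
the earlier note. Only the clip range is visible in the diff; the final
soft-relu-minus-taken-branches reduction is the standard hierarchical
sigmoid loss and is an assumption about the parts of the kernel the hunks
do not show:

    import numpy as np

    def hsigmoid_forward(x, param, label, bias, num_classes):
        # x: [N, D], param: [N, num_classes - 1, D], label: [N],
        # bias: [1, num_classes - 1]  ->  out: [N, 1]
        batch_size = x.shape[0]
        code_length = find_last_set(num_classes - 1)
        pre_out = np.zeros((batch_size, code_length))
        out = np.zeros((batch_size, 1))
        for i in range(batch_size):
            code = SimpleCode(int(label[i]), num_classes)
            length = code.get_length()
            for j in range(length):
                idx = code.calc_index(j)
                # AddByBitCode (bias term) + MulByBitCode (dot product)
                pre_out[i, j] = bias[0, idx] + np.dot(param[i, idx], x[i])
            # clip the matrix with (-40, 40), as the kernel does
            pre_out[i] = np.clip(pre_out[i], -40.0, 40.0)
            # soft relu over every visited node, minus the pre-activations
            # on the branches the label's path actually takes
            out[i, 0] = np.log(1.0 + np.exp(pre_out[i, :length])).sum()
            for j in range(length):
                if code.calc_bit(j):
                    out[i, 0] -= pre_out[i, j]
        return out
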
diff --git a/paddle/operators/math/matrix_bit_code.cc b/paddle/operators/math/matrix_bit_code.cc
index 996e0b819f683a2a1e4157d37309b0effd4d2066..df98851054779fe680b97ff5c28a6635a807080c 100644
--- a/paddle/operators/math/matrix_bit_code.cc
+++ b/paddle/operators/math/matrix_bit_code.cc
@@ -52,19 +52,20 @@ namespace math {
  */
 template <class CodeTable, class Op, typename T>
 static void AddByBitCodeT(Op op, CodeTable code_table,
-                          const framework::Tensor& codes, framework::Tensor& a,
-                          const framework::Tensor& b) {
+                          const framework::Tensor& codes,
+                          framework::Tensor& tmat,
+                          const framework::Tensor& vec) {
   size_t num_classes = code_table.size();
   size_t max_code_length = code_table.get_max_code_length();
-  size_t num_sample = a.dims()[0];
-  size_t width = a.dims()[1];
+  size_t num_sample = tmat.dims()[0];
+  size_t width = tmat.dims()[1];
   for (size_t i = 0; i < num_sample; ++i) {
     auto code = code_table(codes.data<int64_t>()[i]);
     int code_length = code.get_length();
     for (int j = 0; j < code_length; ++j) {
       size_t index = code.calc_index(j);
-      op(a.data<T>()[i * width + j], b.data<T>()[index]);
+      op(tmat.data<T>()[i * width + j], vec.data<T>()[index]);
     }
   }
 }
diff --git a/python/paddle/v2/fluid/tests/test_hsigmoid_op.py b/python/paddle/v2/fluid/tests/test_hsigmoid_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..25c13aabe9771fe18d0b95fbc65541aac5aad183
--- /dev/null
+++ b/python/paddle/v2/fluid/tests/test_hsigmoid_op.py
@@ -0,0 +1,34 @@
+import unittest
+import numpy as np
+from op_test import OpTest
+
+
+class TestHSigmoidOp(OpTest):
+    def setUp(self):
+        self.op_type = "hierarchical_sigmoid"
+        num_classes = 6
+        embed_size = 10
+        batch_size = 5
+        x = np.random.random((batch_size, embed_size)).astype("float32")
+        parameter = np.random.random(
+            (batch_size, num_classes - 1, embed_size)).astype("float32")
+        label = np.random.randint(0, num_classes, batch_size).astype("int64")
+        bias = np.random.random((1, num_classes - 1)).astype("float32")
+        self.inputs = {
+            'X': x,
+            'Parameters': parameter,
+            'Label': label,
+            'Bias': bias
+        }
+        self.attrs = {'num_classes': num_classes}
+        self.outputs = {'Out': label}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Out')
+
+
+if __name__ == '__main__':
+    unittest.main()
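
One caveat about the new test: self.outputs = {'Out': label} is a
placeholder, so check_output would compare the operator against the raw
labels instead of a real expected value. Assuming the forward sketch above
actually matches the kernel, setUp() could compute the expectation
directly:

    # Hypothetical replacement for the placeholder expected output,
    # reusing the hsigmoid_forward sketch from the note above.
    out = hsigmoid_forward(x, parameter, label, bias, num_classes)
    self.outputs = {'Out': out.astype("float32")}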