diff --git a/paddle/fluid/operators/hierarchical_sigmoid_op.cc b/paddle/fluid/operators/hierarchical_sigmoid_op.cc
index 147374bc5465dd11871afc93b08ec8a7c6d6b1a8..dadd054b9a6f8d44f4e5832888052bffde34c827 100644
--- a/paddle/fluid/operators/hierarchical_sigmoid_op.cc
+++ b/paddle/fluid/operators/hierarchical_sigmoid_op.cc
@@ -86,25 +86,25 @@ class HierarchicalSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   void Make() override {
     AddInput("X",
-             "(Tensor, required) The input Tensor, which the shape is"
-             "[N, D], which N is the size of mini-batch,"
-             "D is the embded size");
+             "(Tensor, required) The input tensor with shape [N, D], "
+             "where N is the size of mini-batch, and D is the feature size.");
     AddInput("W",
              "(Tensor, required), The parameters of hierarchical "
-             "sigmoid operator, each of them is s a 2-D tensor, the shape is"
-             "[num_classes - 1, D]");
+             "sigmoid operator. Each of them is a 2-D tensor with shape "
+             "[num_classes - 1, D].");
     AddInput("Label",
              "(Tensor, required), The labels of training data. It's a"
-             "1-D tensor, which the shape is [N, 1]");
+             " tensor with shape [N, 1].");
     AddInput("Bias",
              "(Tensor, optional), The bias is a tensor with shape"
-             "[1, num_classes - 1]");
+             " [1, num_classes - 1].");
     AddOutput("Out",
               "(Tensor, required) The output of hierarchical sigmoid operator."
-              "the shape is [N, 1]");
+              " The shape is [N, 1].");
     AddOutput("PreOut",
-              "(Tensor, required) A intermedia 2-D Tensor, which the shape is "
-              "[batch_size, code_length]")
+              "(Tensor, required) An intermediate 2-D tensor with shape "
+              "[batch_size, code_length], where code_length represents the "
+              "maximum path length from root to leaf nodes.")
         .AsIntermediate();
     AddAttr<int>("num_classes", "(int, required), The number of classes")
         .SetDefault(2);
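To make the shapes above concrete: with the binary-tree coding defined in matrix_bit_code.h further down, code_length (the "maximum path length from root to leaf nodes" mentioned for PreOut) works out to ceil(log2(num_classes)). A minimal NumPy sketch of the shapes this operator wires together; the concrete sizes are illustrative only, not part of the patch:

import numpy as np

# Illustrative sizes only.
N, D, num_classes = 4, 5, 6
# Mirrors FindLastSet(num_classes - 1): length of the longest root-to-leaf
# path, i.e. ceil(log2(num_classes)).
code_length = (num_classes - 1).bit_length()

x = np.random.rand(N, D).astype('float32')                   # X: [N, D]
w = np.random.rand(num_classes - 1, D).astype('float32')     # W: [num_classes - 1, D]
label = np.random.randint(0, num_classes, (N, 1))            # Label: [N, 1]
bias = np.random.rand(1, num_classes - 1).astype('float32')  # Bias: [1, num_classes - 1]
# Out has shape [N, 1]; PreOut has shape [N, code_length].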
diff --git a/paddle/fluid/operators/hierarchical_sigmoid_op.h b/paddle/fluid/operators/hierarchical_sigmoid_op.h
index e189abf0b5ff9bf1a6256cef9e7acc21fc86df65..ec8eac9d01d39b034d6dcef35e032789950e4e6c 100644
--- a/paddle/fluid/operators/hierarchical_sigmoid_op.h
+++ b/paddle/fluid/operators/hierarchical_sigmoid_op.h
@@ -44,9 +44,11 @@ class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> {
     framework::Tensor sum;
     math::SetConstant<DeviceContext, T> zero;
     auto& dev_ctx = ctx.template device_context<DeviceContext>();
-    auto pre_out_data = pre_out->mutable_data<T>(
+    auto* pre_out_data = pre_out->mutable_data<T>(
         framework::make_ddim({batch_size, code_length}), ctx.GetPlace());
     auto pre_out_mat = EigenMatrix<T>::From(*pre_out);
+    // Not all class(leaf) nodes' path lengths equal code_length, so initialize
+    // pre_out with 0s to avoid adding loss from nodes that are out of the path.
     zero(dev_ctx, pre_out, static_cast<T>(0.0));
     auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
     math::RowwiseSum<DeviceContext, T> row_sum;
@@ -61,16 +63,13 @@ class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> {
       bit_code.Add(pre_out, *bias);
     }
     bit_code.Mul(pre_out, *w, *in);
-    // clip the matrix with (-40, 40)
+    // clip to [-40, 40]
     Transform<DeviceContext> trans;
     trans(ctx.template device_context<DeviceContext>(), pre_out_data,
           pre_out_data + pre_out->numel(), pre_out_data,
           ClipFunctor<T>(static_cast<T>(-40.0), static_cast<T>(40.0)));
     bit_code.Sum(*pre_out, out, static_cast<T>(-1));
-    // softrelu with threshold is 40.0
-    trans(ctx.template device_context<DeviceContext>(), pre_out_data,
-          pre_out_data + pre_out->numel(), pre_out_data,
-          ClipFunctor<T>(static_cast<T>(-40.0), static_cast<T>(40.0)));
+    // use softrelu to calculate cross entropy
     pre_out_mat.device(place) = (static_cast<T>(1.0) + pre_out_mat.exp()).log();
     row_sum(dev_ctx, *pre_out, &sum);
     out_mat.device(place) = sum_mat + out_mat;
@@ -102,14 +101,16 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> {
     auto pre_out_mat = EigenMatrix<T>::From(*pre_out);
     auto pre_out_grad_mat = EigenMatrix<T>::From(pre_out_grad);
     math::MatrixBitCodeFunctor<T> bit_code(num_classes, label->data<int64_t>());
-    // softrelu derivative
-    Eigen::array<int, 2> bcast({1, static_cast<int>(pre_out_grad.dims()[1])});
+    Eigen::array<int, 2> bcast({{1, static_cast<int>(pre_out_grad.dims()[1])}});
     auto out_grad_mat = EigenMatrix<T>::From(*out_grad);
     pre_out_grad_mat = out_grad_mat.broadcast(bcast);
     pre_out_grad_mat.device(place) =
         pre_out_grad_mat *
-        (static_cast<T>(1.0) - static_cast<T>(1.0) / pre_out_mat.exp());
+        (static_cast<T>(1.0) -
+         static_cast<T>(1.0) / pre_out_mat.exp());  // softrelu derivative
     bit_code.Sub(&pre_out_grad);
+    // TODO(guosheng): multiply pre_out_grad with subgradient of clipping to
+    // be consistent with the clipping in forward.
     if (bias_grad) {
       bias_grad->mutable_data<T>(ctx.GetPlace());
       bit_code.AddGrad(pre_out_grad, bias_grad);
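One detail worth spelling out in the gradient kernel above: by the time the backward pass runs, PreOut already holds softrelu(x) = log(1 + exp(x)) of the clipped pre-activations from the forward pass, so the factor 1 - 1/exp(PreOut) equals sigmoid(x), which is exactly the softrelu derivative the inline comment refers to. A quick NumPy check of that identity (illustrative only, not part of the patch):

import numpy as np

x = np.random.uniform(-5.0, 5.0, size=(3, 4))    # stand-in for the clipped pre-activations
pre_out = np.log(1.0 + np.exp(x))                # what PreOut holds after the forward pass
backward_factor = 1.0 - 1.0 / np.exp(pre_out)    # the factor used in the grad kernel
softrelu_grad = 1.0 / (1.0 + np.exp(-x))         # sigmoid(x) = d/dx log(1 + exp(x))
assert np.allclose(backward_factor, softrelu_grad)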
diff --git a/paddle/fluid/operators/math/matrix_bit_code.h b/paddle/fluid/operators/math/matrix_bit_code.h
index e5027de168fd99180b8163bd4ed99d999b59598e..b911ce2397c9d44337a2dbf77864d1e10429a8e7 100644
--- a/paddle/fluid/operators/math/matrix_bit_code.h
+++ b/paddle/fluid/operators/math/matrix_bit_code.h
@@ -65,12 +65,24 @@ inline constexpr size_t FindLastSet(size_t x) {

 struct SimpleCode {
   SimpleCode(size_t code, size_t num_classes) : c_(code + num_classes) {}
+  /**
+   * calc_index should make sure that all siblings have the same weight index.
+   * Which weight index a node maps to does not matter. To satisfy this, the
+   * id of the root should be 1, the left child of node i should be 2*i, and
+   * the right child of node i should be 2*i + 1.
+   */
   inline size_t calc_index(int bit) const { return (c_ >> (bit + 1)) - 1; }
+  /**
+   * calc_bit uses the rightmost bits, while calc_index uses the leftmost
+   * bits. They are not the same, which is why it does not matter which
+   * weight index calc_index maps to.
+   */
   inline bool calc_bit(int bit) const { return c_ & (1 << bit); }
   inline int get_length() const { return FindLastSet(c_) - 1; }

  private:
-  size_t c_;
+  size_t c_;  // Here the id of the root is 1 rather than 0, so the id of
+              // class c is `c + num_classes`.
 };

 struct SimpleCodeTable {
@@ -83,7 +95,6 @@ struct SimpleCodeTable {

  private:
   size_t num_classes_;
-  int max_code_length_;
 };

 template <typename T>
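To illustrate the comments added to SimpleCode above: class c is stored as tree node c + num_classes, the root is node 1, and node i has children 2*i and 2*i + 1, so calc_index returns the same weight row for two sibling leaves while calc_bit tells them apart. A small Python mirror of the same bit tricks (PySimpleCode and find_last_set are hypothetical names used only for illustration, not part of the patch):

def find_last_set(x):
    # Counterpart of FindLastSet: 1-based index of the highest set bit.
    return x.bit_length()

class PySimpleCode(object):
    def __init__(self, code, num_classes):
        # Class `code` is stored as tree node `code + num_classes`;
        # the root is node 1, node i has children 2*i and 2*i + 1.
        self.c = code + num_classes

    def calc_index(self, bit):
        # Ancestor node (bit + 1) levels above the leaf, minus 1: the weight
        # row used at that level, shared by both of that ancestor's children.
        return (self.c >> (bit + 1)) - 1

    def calc_bit(self, bit):
        # Which branch was taken below that ancestor (False = left, True = right).
        return bool(self.c & (1 << bit))

    def get_length(self):
        # Number of internal nodes on the path from the root to this leaf.
        return find_last_set(self.c) - 1

# Classes 0 and 1 with num_classes = 6 are the sibling leaves 6 and 7.
left, right = PySimpleCode(0, 6), PySimpleCode(1, 6)
assert left.get_length() == right.get_length() == 2
assert left.calc_index(0) == right.calc_index(0) == 2   # same weight row (parent node 3)
assert left.calc_bit(0) != right.calc_bit(0)            # distinguished by the branch bit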
diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py
index 925700d7368ae31e7b697ca3b82115e3b900d21c..28ff31d6f09515638c760176f1a3566dd4bb92d4 100644
--- a/python/paddle/fluid/layers/nn.py
+++ b/python/paddle/fluid/layers/nn.py
@@ -3858,29 +3858,32 @@ def nce(input,
     return cost / (num_neg_samples + 1)


-def hsigmoid(input, label, num_classes=2, param_attr=None, bias_attr=None):
+def hsigmoid(input, label, num_classes, param_attr=None, bias_attr=None):
     """
     The hierarchical sigmoid operator is used to accelerate the training
     process of language model. This operator organizes the classes into a
-    complete binary tree, each leaf node represents a class(a word) and each internal
-    node acts likea binary classifier. For each word there's a unique path from root
-    to it's leaf node, hsigmoid calculate the cost for each internal node on the path
-    (include root), and sum them to get a total cost. hsigmoid can achive a acceleration
-    from N to logN, for which N represents the size of word dict. This idea is from "F.
-    Morin, Y. Bengio(AISTATS 05): Hierarchical Probabilistic Neural Network Language Model.
-
+    complete binary tree, where each leaf node represents a class (a word) and
+    each internal node acts as a binary classifier. For each word there is a
+    unique path from the root to its leaf node, and hsigmoid sums the costs of
+    the internal nodes on that path to get the total cost. hsigmoid can thus
+    reduce the cost from :math:`O(N)` to :math:`O(logN)`, where :math:`N`
+    represents the size of the word dictionary.
+
+    Refer to "Hierarchical Probabilistic Neural Network Language Model"
+    (F. Morin, Y. Bengio; AISTATS 2005).
+
     Args:
-        input (Variable): (Tensor) The input Tensor, which the shape is
-            [N * D], which N is the size of mini-batch,D is the embded size
-        label (Variable): (Tensor), The labels of training data. It's a
-            1-D tensor, which the shape is [1, N]
-        num_classes: (int, default 2), The number of classes, must be lager or
-            equal than 2.
+        input (Variable): The input tensor variable with shape
+            :math:`[N \\times D]`, where :math:`N` is the size of mini-batch,
+            and :math:`D` is the feature size.
+        label (Variable): The tensor variable containing the labels of training
+            data. It's a tensor with shape :math:`[N \\times 1]`.
+        num_classes (int): The number of classes, which must not be less than 2.
         param_attr (ParamAttr|list of ParamAttr, default None): The parameter
             attribute for learnable parameters/weights of this layer.
         bias_attr (ParamAttr|list of ParamAttr, default None): The parameter
-            attribute for the bias of this layer. If it is set to None, no bias
-            will be added to the output units.
+            attribute for the bias of this layer. If it is set to False, no
+            bias will be applied.

     Returns:
         Out: (Tensor) The cost of hierarchical sigmoid operator. the shape is [N, 1]
@@ -3889,11 +3892,9 @@ def hsigmoid(input, label, num_classes=2, param_attr=None, bias_attr=None):

         .. code-block:: python

-            x = fluid.layers.data(name='x', shape=[3, 2],
-                                  dtype='float32')
-            y = fluid.layers.data(name='y', shape=[1, 3],
-                                  dtype='int64')
-            out = fluid.layers.hsigmoid(input=x, label=y, num_classes=2)
+            x = fluid.layers.data(name='x', shape=[2], dtype='float32')
+            y = fluid.layers.data(name='y', shape=[1], dtype='int64')
+            out = fluid.layers.hsigmoid(input=x, label=y, num_classes=6)
     """

     helper = LayerHelper('hierarchical_sigmoid', **locals())
@@ -3902,7 +3903,7 @@ def hsigmoid(input, label, num_classes=2, param_attr=None, bias_attr=None):
     pre_out = helper.create_tmp_variable(dtype)
     dim = input.shape[1]
     if num_classes < 2:
-        raise ValueError("num_classes must be lager or equal than 2.")
+        raise ValueError("num_classes must not be less than 2.")
     weights = helper.create_parameter(
         attr=helper.param_attr,
         shape=[num_classes - 1, dim],
diff --git a/python/paddle/fluid/tests/unittests/test_hsigmoid_op.py b/python/paddle/fluid/tests/unittests/test_hsigmoid_op.py
index da58b8e62645aaece69ae77316e770a8cd35d2e0..000c7263d604dbedee84f56f74eb60c32dfe08c7 100644
--- a/python/paddle/fluid/tests/unittests/test_hsigmoid_op.py
+++ b/python/paddle/fluid/tests/unittests/test_hsigmoid_op.py
@@ -55,10 +55,7 @@ def hsigmoid(x, w, label, bias, num_classes):
         length = code_table.get_length()
         for k in range(length):
             idx = code_table.cal_index(k)
-            sum = 0.0
-            for l in range(x.shape[1]):
-                sum += w[idx][l] * x[j][l]
-            pre_output[j][k] += sum
+            pre_output[j][k] += np.dot(w[idx], x[j])
     # clip[-40.0, 40.0]
     pre_output = np.clip(pre_output, -40.0, 40.0)
     # out(i, 0) = \sum_j bit(i, j) * preout(i, j)
@@ -71,7 +68,6 @@ def hsigmoid(x, w, label, bias, num_classes):
             sum += pre_output[i][j]
         out[i] = -1.0 * sum
     # soft relu
-    np.clip(pre_output, -40.0, 40.0)
     pre_output = np.log(1 + np.exp(pre_output))
     pre_sum = pre_output.sum(1).reshape((batch_size, 1))
     out += pre_sum
@@ -81,11 +77,11 @@ def hsigmoid(x, w, label, bias, num_classes):
 class TestHSigmoidOp(OpTest):
     def setUp(self):
         self.op_type = "hierarchical_sigmoid"
-        num_classes = 4
-        embded_size = 1
-        batch_size = 1
-        x = np.random.random((batch_size, embded_size)).astype("float32")
-        w = np.random.random((num_classes - 1, embded_size)).astype("float32")
+        num_classes = 6
+        feature_size = 5
+        batch_size = 4
+        x = np.random.random((batch_size, feature_size)).astype("float32")
+        w = np.random.random((num_classes - 1, feature_size)).astype("float32")
         label = np.random.randint(0, num_classes, batch_size)
         bias = np.random.random((1, num_classes - 1)).astype("float32")
         self.attrs = {'num_classes': num_classes}
@@ -97,7 +93,7 @@ class TestHSigmoidOp(OpTest):
         self.check_output()

     def test_check_grad(self):
-        self.check_grad(['Bias', 'X', 'W'], 'Out', no_grad_set=set('Label'))
+        self.check_grad(['Bias', 'X', 'W'], ['Out'], no_grad_set=set('Label'))


 if __name__ == '__main__':
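For completeness, the updated docstring example can be exercised end to end roughly as follows. This is a minimal sketch assuming the usual fluid Executor workflow; the CPUPlace, the random feed data, and the printed shape are illustrative and not part of this patch:

import numpy as np
import paddle.fluid as fluid

x = fluid.layers.data(name='x', shape=[2], dtype='float32')
y = fluid.layers.data(name='y', shape=[1], dtype='int64')
out = fluid.layers.hsigmoid(input=x, label=y, num_classes=6)

place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())

x_data = np.random.rand(4, 2).astype('float32')
y_data = np.random.randint(0, 6, (4, 1)).astype('int64')
cost, = exe.run(fluid.default_main_program(),
                feed={'x': x_data, 'y': y_data},
                fetch_list=[out])
print(cost.shape)  # (4, 1)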