From 4ee069fdba7f67d98229848931f059b620505fdd Mon Sep 17 00:00:00 2001 From: guosheng Date: Thu, 12 Jul 2018 12:57:48 +0800 Subject: [PATCH] Fix the HierarchicalSigmoidGradOpKernel and refine the codes. Now hsigmoid_op is same with V2 implementation and can pass gradient check. --- .../fluid/operators/hierarchical_sigmoid_op.h | 39 ++++++++++++------- .../fluid/operators/math/matrix_bit_code.cc | 2 + paddle/fluid/operators/math/matrix_bit_code.h | 19 ++++----- .../fluid/tests/unittests/test_hsigmoid_op.py | 19 +++++---- 4 files changed, 43 insertions(+), 36 deletions(-) diff --git a/paddle/fluid/operators/hierarchical_sigmoid_op.h b/paddle/fluid/operators/hierarchical_sigmoid_op.h index ec8eac9d01d..64096a717b1 100644 --- a/paddle/fluid/operators/hierarchical_sigmoid_op.h +++ b/paddle/fluid/operators/hierarchical_sigmoid_op.h @@ -42,13 +42,13 @@ class HierarchicalSigmoidOpKernel : public framework::OpKernel { int64_t code_length = math::FindLastSet(num_classes - 1); int64_t batch_size = in->dims()[0]; framework::Tensor sum; - math::SetConstant zero; auto& dev_ctx = ctx.template device_context(); auto* pre_out_data = pre_out->mutable_data( framework::make_ddim({batch_size, code_length}), ctx.GetPlace()); auto pre_out_mat = EigenMatrix::From(*pre_out); // Not all class(leaf) nodes' path lengths equal code_length, thus init as // 0s can avoid out of path's loss. + math::SetConstant zero; zero(dev_ctx, pre_out, static_cast(0.0)); auto& place = *ctx.template device_context().eigen_device(); math::RowwiseSum row_sum; @@ -72,6 +72,10 @@ class HierarchicalSigmoidOpKernel : public framework::OpKernel { // use softrelu to calculate cross entropy pre_out_mat.device(place) = (static_cast(1.0) + pre_out_mat.exp()).log(); row_sum(dev_ctx, *pre_out, &sum); + // TODO(guosheng): Subtract the out of path's loss, since not all + // class(leaf) nodes' path lengths equal code_length. But it won't break the + // gradient check since both have the out of path's loss and will cancel out + // each other. out_mat.device(place) = sum_mat + out_mat; } }; @@ -90,33 +94,38 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel { auto* pre_out = ctx.Input("PreOut"); auto* out_grad = ctx.Input(framework::GradVarName("Out")); + framework::Tensor pre_out_grad; + + pre_out_grad.mutable_data(pre_out->dims(), ctx.GetPlace()); + in_grad->mutable_data(ctx.GetPlace()); + w_grad->mutable_data(ctx.GetPlace()); + auto& dev_ctx = ctx.template device_context(); + math::SetConstant zero; + zero(dev_ctx, in_grad, static_cast(0.0)); + zero(dev_ctx, w_grad, static_cast(0.0)); size_t num_classes = static_cast(ctx.Attr("num_classes")); - int64_t code_length = math::FindLastSet(num_classes - 1); - int64_t batch_size = in->dims()[0]; - framework::Tensor pre_out_grad; - pre_out_grad.mutable_data( - framework::make_ddim({batch_size, code_length}), ctx.GetPlace()); + math::MatrixBitCodeFunctor bit_code(num_classes, label->data()); + auto& place = *ctx.template device_context().eigen_device(); auto pre_out_mat = EigenMatrix::From(*pre_out); auto pre_out_grad_mat = EigenMatrix::From(pre_out_grad); - math::MatrixBitCodeFunctor bit_code(num_classes, label->data()); - Eigen::array bcast({{1, static_cast(pre_out_grad.dims()[1])}}); auto out_grad_mat = EigenMatrix::From(*out_grad); - pre_out_grad_mat = out_grad_mat.broadcast(bcast); + Eigen::array bcast({{1, static_cast(pre_out_grad.dims()[1])}}); + + // softrelu derivative + pre_out_grad_mat.device(place) = + static_cast(1.0) - static_cast(1.0) / pre_out_mat.exp(); + bit_code.Sub(&pre_out_grad); // the gradient of clip(w * x + b) pre_out_grad_mat.device(place) = - pre_out_grad_mat * - (static_cast(1.0) - - static_cast(1.0) / pre_out_mat.exp()); // softrelu derivative - bit_code.Sub(&pre_out_grad); + pre_out_grad_mat * out_grad_mat.broadcast(bcast); // TODO(guosheng): multiply pre_out_grad with subgradient of clipping to // be consistent with the clipping in forward. if (bias_grad) { bias_grad->mutable_data(ctx.GetPlace()); + zero(dev_ctx, bias_grad, static_cast(0.0)); bit_code.AddGrad(pre_out_grad, bias_grad); } - in_grad->mutable_data(ctx.GetPlace()); - w_grad->mutable_data(ctx.GetPlace()); bit_code.MulGradWeight(pre_out_grad, w_grad, *in); bit_code.MulGradError(pre_out_grad, *w, in_grad); } diff --git a/paddle/fluid/operators/math/matrix_bit_code.cc b/paddle/fluid/operators/math/matrix_bit_code.cc index 7d4955c6a09..1e56e297396 100644 --- a/paddle/fluid/operators/math/matrix_bit_code.cc +++ b/paddle/fluid/operators/math/matrix_bit_code.cc @@ -62,6 +62,8 @@ void MatrixBitCodeFunctor::Sum(const framework::Tensor& tmat, int code_length = code.get_length(); for (int j = 0; j < code_length; ++j) { if (code.calc_bit(j)) { + // calc_bit starts from right most bit, while data in tmat[i] is in the + // reverse order. sm += tmat.data()[i * o_width + j]; } } diff --git a/paddle/fluid/operators/math/matrix_bit_code.h b/paddle/fluid/operators/math/matrix_bit_code.h index b911ce2397c..5454d58f371 100644 --- a/paddle/fluid/operators/math/matrix_bit_code.h +++ b/paddle/fluid/operators/math/matrix_bit_code.h @@ -66,23 +66,20 @@ inline constexpr size_t FindLastSet(size_t x) { struct SimpleCode { SimpleCode(size_t code, size_t num_classes) : c_(code + num_classes) {} /** - * calc_index should make sure that all siblings have the same weight indice. - * As for which weight index it maps to, it doesn't matter. To satisfy this, - * the id of root should be 1, and the left child of a node i is 2*i, the - * right child of a node i is 2*i+1. + * Here the id of root shoud be 1 rather than 0, thus the encoding of class c + * is `c + num_classes` and all siblings can get the same weight indice using + * prefixes. + * Weight index is the prefixes of encoding, thus leave out the right most + * bit in calc_index. + * Binary classification path is the suffixes of encoding, thus leave out the + * left most bit in calc_bit. */ inline size_t calc_index(int bit) const { return (c_ >> (bit + 1)) - 1; } - /** - * calc_bit uses the right most bits, while calc_index uses the left most - * bits. They are not the same, and that's why we say it doesn't matter which - * weight index calc_index maps to. - */ inline bool calc_bit(int bit) const { return c_ & (1 << bit); } inline int get_length() const { return FindLastSet(c_) - 1; } private: - size_t c_; // Here the id of root is 1 rather than 0, thus the id of class c - // is `c + num_classes`. + size_t c_; }; struct SimpleCodeTable { diff --git a/python/paddle/fluid/tests/unittests/test_hsigmoid_op.py b/python/paddle/fluid/tests/unittests/test_hsigmoid_op.py index 000c7263d60..d090960c84e 100644 --- a/python/paddle/fluid/tests/unittests/test_hsigmoid_op.py +++ b/python/paddle/fluid/tests/unittests/test_hsigmoid_op.py @@ -37,7 +37,6 @@ class CodeTable(object): def hsigmoid(x, w, label, bias, num_classes): - global pre_output batch_size = x.shape[0] code_length = find_latest_set(num_classes - 1) code_table = [0 for _ in range(code_length)] @@ -50,12 +49,12 @@ def hsigmoid(x, w, label, bias, num_classes): for j in range(length): idx = code_table.cal_index(j) pre_output[i][j] += bias[0][idx] - for j in range(batch_size): - code_table = CodeTable(num_classes, label[j]) + for i in range(batch_size): + code_table = CodeTable(num_classes, label[i]) length = code_table.get_length() - for k in range(length): - idx = code_table.cal_index(k) - pre_output[j][k] = np.dot(w[idx], x[j]) + for j in range(length): + idx = code_table.cal_index(j) + pre_output[i][j] += np.dot(w[idx], x[i]) # clip[-40.0, 40.0] pre_output = np.clip(pre_output, -40.0, 40.0) # out(i, 0) = \sum_j bit(i, j) * preout(i, j) @@ -71,22 +70,22 @@ def hsigmoid(x, w, label, bias, num_classes): pre_output = np.log(1 + np.exp(pre_output)) pre_sum = pre_output.sum(1).reshape((batch_size, 1)) out += pre_sum - return out + return pre_output, out class TestHSigmoidOp(OpTest): def setUp(self): self.op_type = "hierarchical_sigmoid" num_classes = 6 - feature_size = 5 + feature_size = 8 batch_size = 4 x = np.random.random((batch_size, feature_size)).astype("float32") w = np.random.random((num_classes - 1, feature_size)).astype("float32") - label = np.random.randint(0, num_classes, batch_size) + label = np.random.randint(0, num_classes, (batch_size, 1)) bias = np.random.random((1, num_classes - 1)).astype("float32") self.attrs = {'num_classes': num_classes} self.inputs = {'X': x, 'W': w, 'Label': label, 'Bias': bias} - out = hsigmoid(x, w, label, bias, num_classes) + pre_output, out = hsigmoid(x, w, label, bias, num_classes) self.outputs = {'PreOut': pre_output, 'Out': out} def test_check_output(self): -- GitLab