From ee13b396f2d0a3a9e677c221a0344f1fbf2caf0e Mon Sep 17 00:00:00 2001
From: weixing02
Date: Fri, 15 Jun 2018 06:57:30 +0000
Subject: [PATCH] fix some errors

---
 .../operators/hierarchical_sigmoid_op.cc      | 34 +++++++++--------
 .../fluid/operators/hierarchical_sigmoid_op.h | 12 +++---
 .../fluid/operators/math/matrix_bit_code.cc   | 37 -------------------
 paddle/fluid/operators/math/matrix_bit_code.h | 32 +++++++++++++---
 python/paddle/fluid/layers/nn.py              | 19 +++++-----
 .../fluid/tests/unittests/test_hsigmoid_op.py | 18 ++++-----
 .../fluid/tests/unittests/test_layers.py      |  4 +-
 7 files changed, 73 insertions(+), 83 deletions(-)

diff --git a/paddle/fluid/operators/hierarchical_sigmoid_op.cc b/paddle/fluid/operators/hierarchical_sigmoid_op.cc
index 499e641ff0..119c437f90 100644
--- a/paddle/fluid/operators/hierarchical_sigmoid_op.cc
+++ b/paddle/fluid/operators/hierarchical_sigmoid_op.cc
@@ -62,7 +62,7 @@ class HierarchicalSigmoidOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
   void InferShape(framework::InferShapeContext* ctx) const override {
     PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("Ids"), "Input(Ids) should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should not be null.");
     PADDLE_ENFORCE(ctx->HasInput("W"), "Input(W) should not be null.");
     PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null.");
     PADDLE_ENFORCE(ctx->HasOutput("PreOut"),
@@ -87,19 +87,18 @@ class HierarchicalSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
   void Make() override {
     AddInput("X",
              "(Tensor, required) The input Tensor, which the shape is"
-             "[N * D], which N is the size of mini-batch,"
+             "[N, D], in which N is the size of the mini-batch,"
              "D is the embded size");
     AddInput("W",
              "(Tensor, required), The parameters of hierarchical "
-             "sigmoid operator, each of them is s a 3-D tensor, the shape is"
+             "sigmoid operator, each of them is a 2-D tensor, the shape is"
              "[num_classes - 1, D]");
-    AddInput("Ids",
+    AddInput("Label",
              "(Tensor, required), The labels of training data. It's a"
              "1-D tensor, which the shape is [1, N]");
     AddInput("Bias",
-             "(Tensor, optional), The bias is a 1-D tensor, "
-             "which is applied to the output, the shape is"
-             "[1, num_classes -1]");
+             "(Tensor, optional), The bias is a tensor with shape "
+             "[1, num_classes - 1]");
     AddOutput("Out",
               "(Tensor, required) The output of hierarchical sigmoid operator."
               "the shape is [N, 1]");
@@ -111,7 +110,7 @@ class HierarchicalSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
         .SetDefault(2);
     AddComment(R"DOC(
 The hierarchical sigmoid operator organize the classes into a binary tree.
-At each node, a sigmoid function is used to caculate the probability of
+At each node, a sigmoid function is used to calculate the probability of
 belonging to the right branch. This idea is from
 "F. Morin, Y. Bengio (AISTATS 05):
 Hierarchical Probabilistic Neural Network Language Model."
@@ -124,7 +123,7 @@ class HierarchicalSigmoidGradOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
   void InferShape(framework::InferShapeContext* ctx) const override {
     PADDLE_ENFORCE(ctx->HasInput("W"), "Input(W) should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("Ids"), "Input(Ids) should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should not be null.");
     PADDLE_ENFORCE(ctx->HasInput("PreOut"),
                    "Input(Preout) should not be null.");
     PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("W")),
@@ -155,9 +154,14 @@ REGISTER_OPERATOR(hierarchical_sigmoid, ops::HierarchicalSigmoidOp,
                   ops::HierarchicalSigmoidOpMaker,
                   paddle::framework::DefaultGradOpDescMaker<true>);
 REGISTER_OPERATOR(hierarchical_sigmoid_grad, ops::HierarchicalSigmoidGradOp);
-REGISTER_OP_CPU_KERNEL(hierarchical_sigmoid,
-                       ops::HierarchicalSigmoidOpKernel<
-                           paddle::platform::CPUDeviceContext, float>);
-REGISTER_OP_CPU_KERNEL(hierarchical_sigmoid_grad,
-                       ops::HierarchicalSigmoidGradOpKernel<
-                           paddle::platform::CPUDeviceContext, float>);
+REGISTER_OP_CPU_KERNEL(
+    hierarchical_sigmoid,
+    ops::HierarchicalSigmoidOpKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::HierarchicalSigmoidOpKernel<paddle::platform::CPUDeviceContext,
+                                     double>);
+REGISTER_OP_CPU_KERNEL(
+    hierarchical_sigmoid_grad,
+    ops::HierarchicalSigmoidGradOpKernel<paddle::platform::CPUDeviceContext,
+                                         float>,
+    ops::HierarchicalSigmoidGradOpKernel<paddle::platform::CPUDeviceContext,
+                                         double>);
diff --git a/paddle/fluid/operators/hierarchical_sigmoid_op.h b/paddle/fluid/operators/hierarchical_sigmoid_op.h
index 5efac8804e..e189abf0b5 100644
--- a/paddle/fluid/operators/hierarchical_sigmoid_op.h
+++ b/paddle/fluid/operators/hierarchical_sigmoid_op.h
@@ -34,7 +34,7 @@ class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> {
   void Compute(const framework::ExecutionContext& ctx) const override {
     auto* in = ctx.Input<framework::Tensor>("X");
     auto* w = ctx.Input<framework::Tensor>("W");
-    auto* ids = ctx.Input<framework::Tensor>("Ids");
+    auto* label = ctx.Input<framework::Tensor>("Label");
     auto* bias = ctx.Input<framework::Tensor>("Bias");
     auto* out = ctx.Output<framework::Tensor>("Out");
     auto* pre_out = ctx.Output<framework::Tensor>("PreOut");
@@ -50,7 +50,7 @@ class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> {
     zero(dev_ctx, pre_out, static_cast<T>(0.0));
     auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
     math::RowwiseSum<DeviceContext, T> row_sum;
-    math::MatrixBitCodeFunctor<T> bit_code(num_classes, ids->data<int64_t>());
+    math::MatrixBitCodeFunctor<T> bit_code(num_classes, label->data<int64_t>());

     std::vector<int64_t> sum_dims({batch_size, 1UL});
     sum.mutable_data<T>(framework::make_ddim(sum_dims), ctx.GetPlace());
@@ -87,7 +87,7 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> {
     auto* w_grad = ctx.Output<framework::Tensor>(framework::GradVarName("W"));
     auto* bias_grad =
         ctx.Output<framework::Tensor>(framework::GradVarName("Bias"));
-    auto* ids = ctx.Input<framework::Tensor>("Ids");
+    auto* label = ctx.Input<framework::Tensor>("Label");
     auto* pre_out = ctx.Input<framework::Tensor>("PreOut");
     auto* out_grad =
         ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
@@ -101,9 +101,11 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> {
     auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
     auto pre_out_mat = EigenMatrix<T>::From(*pre_out);
     auto pre_out_grad_mat = EigenMatrix<T>::From(pre_out_grad);
-    math::MatrixBitCodeFunctor<T> bit_code(num_classes, ids->data<int64_t>());
+    math::MatrixBitCodeFunctor<T> bit_code(num_classes, label->data<int64_t>());
     // softrelu derivative
-    bit_code.OutGrad(&pre_out_grad, *out_grad);
+    Eigen::array<int, 2> bcast({1, static_cast<int>(pre_out_grad.dims()[1])});
+    auto out_grad_mat = EigenMatrix<T>::From(*out_grad);
+    pre_out_grad_mat = out_grad_mat.broadcast(bcast);
     pre_out_grad_mat.device(place) =
         pre_out_grad_mat * (static_cast<T>(1.0) -
                             static_cast<T>(1.0) / pre_out_mat.exp());
diff --git a/paddle/fluid/operators/math/matrix_bit_code.cc b/paddle/fluid/operators/math/matrix_bit_code.cc
index ea708eb971..7d4955c6a0 100644
--- a/paddle/fluid/operators/math/matrix_bit_code.cc
+++ b/paddle/fluid/operators/math/matrix_bit_code.cc
@@ -18,32 +18,6 @@ namespace paddle {
 namespace operators {
 namespace math {

-/**
- * CodeTable class should support 3 functions:
- *
- * size_t size()
- *   return the number of ids
- *
- * int getMaxCodeLength()
- *   return the maximal code length
- *
- * Code operator()(size_t i)
- *   return the i-th code. Code class is descriebed below.
- *
- * Code class should support 3 functions:
- *
- * int getLength()
- *   return the length of the code
- *
- * bool calcIndex(int bit)
- *   bit ranges from 0 to getLength() - 1
- *   return the index for the (1+bit) level parent
- *
- * bool calcBit(int bit)
- *   return true if the bit level parent is the right child of (1+bit) level
- *   parent
- *
- */
 template <typename T>
 void MatrixBitCodeFunctor<T>::Add(framework::Tensor* tmat,
                                   const framework::Tensor& vec) {
@@ -192,17 +166,6 @@ void MatrixBitCodeFunctor<T>::Sub(framework::Tensor* tmat) {
   }
 }

-template <typename T>
-void MatrixBitCodeFunctor<T>::OutGrad(framework::Tensor* tmat,
-                                      const framework::Tensor& input) {
-  size_t num_samples = tmat->dims()[0];
-  size_t code_length = tmat->dims()[1];
-  for (size_t i = 0; i < num_samples; ++i)
-    for (size_t j = 0; j < code_length; ++j) {
-      tmat->data<T>()[i * code_length + j] = input.data<T>()[i];
-    }
-}
-
 template class MatrixBitCodeFunctor<float>;
 template class MatrixBitCodeFunctor<double>;

diff --git a/paddle/fluid/operators/math/matrix_bit_code.h b/paddle/fluid/operators/math/matrix_bit_code.h
index 43820810e1..e5027de168 100644
--- a/paddle/fluid/operators/math/matrix_bit_code.h
+++ b/paddle/fluid/operators/math/matrix_bit_code.h
@@ -20,13 +20,39 @@ limitations under the License. */

 namespace paddle {
 namespace operators {
 namespace math {
+/**
+ * SimpleCodeTable class should support 3 functions:
+ *
+ * size_t size()
+ *   return the number of ids
+ *
+ * int get_max_code_length()
+ *   return the maximal code length
+ *
+ * SimpleCode operator()(size_t i)
+ *   return the i-th code. Code class is described below.
+ *
+ * SimpleCode class should support 3 functions:
+ *
+ * int get_length()
+ *   return the length of the code
+ *
+ * size_t cal_index(int bit)
+ *   bit ranges from 0 to get_length() - 1
+ *   return the index for the (1+bit) level parent
+ *
+ * bool calc_bit(int bit)
+ *   return true if the bit level parent is the right child of (1+bit) level
+ *   parent
+ *
+ */
 /**
  * return the 1-based index of the highest bit set
  *
  * for x > 0:
  * \f[
- *    findLastSet(x) = 1 + \floor*{\log_{2}x}
+ *    FindLastSet(x) = 1 + \floor*{\log_{2}x}
  * \f]
  */
 inline constexpr size_t FindLastSet(size_t x) {
@@ -100,10 +126,6 @@ class MatrixBitCodeFunctor {
   */
   void MulGradError(const framework::Tensor& tmat,
                     const framework::Tensor& weight, framework::Tensor* input);
-  /* For j < code_length
-       tmat(i, j) == input(i)
-  */
-  void OutGrad(framework::Tensor* tmat, const framework::Tensor& input);

   size_t num_classes_;
   const int64_t* ids_;
diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py
index c3ff9b7725..ac3ba4174f 100644
--- a/python/paddle/fluid/layers/nn.py
+++ b/python/paddle/fluid/layers/nn.py
@@ -3571,18 +3571,17 @@ def hsigmoid(input, label, num_classes=2, param_attr=None, bias_attr=None):
         shape=[num_classes - 1, dim],
         is_bias=False,
         dtype=input.dtype)
-    bias = helper.create_parameter(
-        attr=helper.bias_attr,
-        shape=[1, num_classes - 1],
-        is_bias=True,
-        dtype=input.dtype)
-
+    inputs = {"X": input, "W": weights, "Label": label}
+    if helper.bias_attr:
+        bias = helper.create_parameter(
+            attr=helper.bias_attr,
+            shape=[1, num_classes - 1],
+            is_bias=True,
+            dtype=input.dtype)
+        inputs['Bias'] = bias
     helper.append_op(
         type="hierarchical_sigmoid",
-        inputs={"X": input,
-                "W": weights,
-                "Ids": label,
-                "Bias": bias},
+        inputs=inputs,
         outputs={"Out": out,
                  "PreOut": pre_out},
         attrs={"num_classes": num_classes})
diff --git a/python/paddle/fluid/tests/unittests/test_hsigmoid_op.py b/python/paddle/fluid/tests/unittests/test_hsigmoid_op.py
index 226ce8b904..da58b8e626 100644
--- a/python/paddle/fluid/tests/unittests/test_hsigmoid_op.py
+++ b/python/paddle/fluid/tests/unittests/test_hsigmoid_op.py
@@ -36,7 +36,7 @@ class CodeTable(object):
         return self.c & (1 << bit)


-def hsigmoid(x, w, ids, bias, num_classes):
+def hsigmoid(x, w, label, bias, num_classes):
     global pre_output
     batch_size = x.shape[0]
     code_length = find_latest_set(num_classes - 1)
@@ -45,13 +45,13 @@ def hsigmoid(x, w, ids, bias, num_classes):
     pre_sum = np.zeros((batch_size, 1))
     out = np.zeros((batch_size, 1)).astype("float32")
     for i in range(batch_size):
-        code_table = CodeTable(num_classes, ids[i])
+        code_table = CodeTable(num_classes, label[i])
         length = code_table.get_length()
         for j in range(length):
             idx = code_table.cal_index(j)
             pre_output[i][j] += bias[0][idx]
     for j in range(batch_size):
-        code_table = CodeTable(num_classes, ids[j])
+        code_table = CodeTable(num_classes, label[j])
         length = code_table.get_length()
         for k in range(length):
             idx = code_table.cal_index(k)
@@ -60,10 +60,10 @@ def hsigmoid(x, w, ids, bias, num_classes):
                 sum += w[idx][l] * x[j][l]
             pre_output[j][k] += sum
     # clip[-40.0, 40.0]
-    np.clip(pre_output, -40.0, 40.0)
+    pre_output = np.clip(pre_output, -40.0, 40.0)
     # out(i, 0) = \sum_j bit(i, j) * preout(i, j)
     for i in range(batch_size):
-        code_table = CodeTable(num_classes, ids[i])
+        code_table = CodeTable(num_classes, label[i])
         length = code_table.get_length()
         sum = 0.0
         for j in range(length):
@@ -86,18 +86,18 @@ class TestHSigmoidOp(OpTest):
         batch_size = 1
         x = np.random.random((batch_size, embded_size)).astype("float32")
         w = np.random.random((num_classes - 1, embded_size)).astype("float32")
-        ids = np.random.randint(0, num_classes, batch_size)
+        label = np.random.randint(0, num_classes, batch_size)
         bias = np.random.random((1, num_classes - 1)).astype("float32")

         self.attrs = {'num_classes': num_classes}
-        self.inputs = {'X': x, 'W': w, 'Ids': ids, 'Bias': bias}
-        out = hsigmoid(x, w, ids, bias, num_classes)
+        self.inputs = {'X': x, 'W': w, 'Label': label, 'Bias': bias}
+        out = hsigmoid(x, w, label, bias, num_classes)
         self.outputs = {'PreOut': pre_output, 'Out': out}

     def test_check_output(self):
         self.check_output()

     def test_check_grad(self):
-        self.check_grad(['Bias', 'X', 'W'], 'Out', no_grad_set=set('Ids'))
+        self.check_grad(['Bias', 'X', 'W'], 'Out', no_grad_set=set('Label'))

 if __name__ == '__main__':
diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py
index f6e516bbe7..f5b305a025 100644
--- a/python/paddle/fluid/tests/unittests/test_layers.py
+++ b/python/paddle/fluid/tests/unittests/test_layers.py
@@ -176,8 +176,8 @@ class TestBook(unittest.TestCase):
     def test_hsigmoid(self):
         program = Program()
         with program_guard(program):
-            x = layers.data(name='x', shape=[2, 2], dtype='float32')
-            y = layers.data(name='y', shape=[1, 2], dtype='int64')
+            x = layers.data(name='x', shape=[2], dtype='float32')
+            y = layers.data(name='y', shape=[2], dtype='int64')
             self.assertIsNotNone(
                 layers.hsigmoid(
                     input=x, label=y, num_classes=2))
--
GitLab
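
Beyond the Ids-to-Label rename, the main kernel change in this patch drops MatrixBitCodeFunctor::OutGrad and instead broadcasts the [N, 1] gradient of Out across the code-length dimension with Eigen, then scales it by the factor (1 - 1 / exp(PreOut)) that the kernel's own "// softrelu derivative" comment names. A minimal NumPy sketch of that step follows; the array shapes and random values are illustrative assumptions, not data from the patch.

    import numpy as np

    # Illustrative shapes only: N samples, C = code_length internal tree nodes.
    N, C = 4, 3
    pre_out = np.random.rand(N, C).astype("float32")   # saved "PreOut" activations
    out_grad = np.random.rand(N, 1).astype("float32")  # gradient w.r.t. the [N, 1] output

    # Step 1: tile the per-sample output gradient across the code dimension
    # (what the new Eigen broadcast does in place of bit_code.OutGrad).
    pre_out_grad = np.broadcast_to(out_grad, (N, C)).copy()

    # Step 2: multiply by the softrelu derivative written in terms of the saved
    # activation, matching the kernel's (1 - 1 / exp(pre_out)) factor.
    pre_out_grad *= 1.0 - 1.0 / np.exp(pre_out)

    print(pre_out_grad.shape)  # (4, 3)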
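On the Python side, fluid.layers.hsigmoid now feeds its label tensor to the op's "Label" input and only attaches "Bias" when a bias parameter is actually created. Below is a build-only usage sketch mirroring the updated test_layers case; it assumes the 2018-era paddle.fluid API, and the shapes and num_classes come from that test rather than any recommended configuration.

    import paddle.fluid as fluid

    # Build-only sketch: construct a program that uses the updated layer.
    # Nothing is executed here.
    program = fluid.Program()
    with fluid.program_guard(program):
        x = fluid.layers.data(name='x', shape=[2], dtype='float32')
        y = fluid.layers.data(name='y', shape=[2], dtype='int64')
        # The wrapper now wires `label` to the op's "Label" input and appends a
        # "Bias" input only when helper.bias_attr yields a bias parameter.
        cost = fluid.layers.hsigmoid(input=x, label=y, num_classes=2)

    print(cost.shape)  # expected (-1, 1): one loss value per sample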