From 64fe9bcc5c1dcbf90f54cb649f40c4e2a1f19ff0 Mon Sep 17 00:00:00 2001 From: dangqingqing Date: Mon, 23 Oct 2017 17:51:17 +0800 Subject: [PATCH] Update lstm comments and fix bug. --- paddle/framework/CMakeLists.txt | 3 ++- paddle/operators/CMakeLists.txt | 2 +- paddle/operators/lstm_op.cc | 18 +++++++++--------- paddle/operators/lstm_op.h | 6 ++---- paddle/operators/math/sequence2batch.cc | 2 ++ 5 files changed, 16 insertions(+), 15 deletions(-) diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index 6e32a1c99b..85752f5d6b 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -20,7 +20,8 @@ proto_library(framework_proto SRCS framework.proto) cc_library(attribute SRCS attribute.cc DEPS framework_proto) cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS attribute ddim op_info) -cc_test(program_desc_test SRCS program_desc_test.cc DEPS proto_desc) +cc_test(program_desc_test SRCS program_desc_test.cc DEPS proto_desc +device_context) cc_library(op_proto_maker SRCS op_proto_maker.cc DEPS framework_proto attribute) cc_test(op_proto_maker_test SRCS op_proto_maker_test.cc DEPS op_proto_maker) cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto) diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index 0c53ed3cdc..f97bc837dc 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -127,7 +127,7 @@ op_library(softmax_with_cross_entropy_op DEPS cross_entropy softmax) op_library(sum_op DEPS net_op) op_library(pool_op DEPS pooling) op_library(pool_with_index_op DEPS pooling) -op_library(lstm_op DEPS sequence2batch lstm_compute math_function) +op_library(lstm_op DEPS sequence2batch lstm_compute) list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS}) foreach(src ${GENERAL_OPS}) diff --git a/paddle/operators/lstm_op.cc b/paddle/operators/lstm_op.cc index 222aeeace5..0a089b7c2d 100644 --- a/paddle/operators/lstm_op.cc +++ b/paddle/operators/lstm_op.cc @@ -98,18 +98,18 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker { "batch size. `H0` and `C0` can be NULL but only at the same time"); AddInput("Weight", "(Tensor) the learnable hidden-hidden weights." - " - The shape is (D x 4*D), where D is the hidden size. " - " - Weight = {W_ih, W_fh, W_ch, W_oh}"); + " - The shape is (D x 4D), where D is the hidden size. " + " - Weight = {W_ch, W_ih, W_fh, W_oh}"); AddInput("Bias", "(Tensor) the learnable weights, which contains two parts: " "input-hidden bias weight and peephole connections weight if " - "seting `usePeepholes` True. " + "setting `usePeepholes` True. " "1. `usePeepholes = False` " - " - The shape is (1 x 4*D). " - " - Bias = {b_i, b_f, b_c, b_o}." + " - The shape is (1 x 4D). " + " - Bias = {b_c, b_i, b_f, b_o}." "2. `usePeepholes = True` " - " - The shape is (1 x 7*D). " - " - Bias = {b_i, b_f, b_c, b_o, W_ic, W_fc, W_oc}."); + " - The shape is (1 x 7D). " + " - Bias = {b_c, b_i, b_f, b_o, W_ic, W_fc, W_oc}."); AddOutput("BatchGate", "(LoDTensor) This LoDTensor contains input gate, forget gate " "and output gate after the nonlinear computation. This " @@ -184,8 +184,8 @@ Set `usePeepholes` False to disable peephole connection [2]. The formula is omitted here. @note These \f$W_{xi}x_{t}, W_{xf}x_{t}, W_{xc}x_{t}, W_{xo}x_{t}\f$ -operations on the input x_{t} were NOT included in this operator. The -users can choose to use fully-connect operator before LSTM operator. +operations on the input x_{t} were NOT included in this operator. +Users can choose to use fully-connect operator before LSTM operator. [1] Hasim Sak, Andrew Senior, and Francoise Beaufays. Long short-term memory recurrent neural network architectures for large scale acoustic modeling. diff --git a/paddle/operators/lstm_op.h b/paddle/operators/lstm_op.h index 5e10036707..b3e3db9726 100644 --- a/paddle/operators/lstm_op.h +++ b/paddle/operators/lstm_op.h @@ -76,14 +76,12 @@ class LSTMKernel : public framework::OpKernel { lstm_value.checkOg = lstm_value.checkFg + frame_size; lstm_value.prevStateValue = nullptr; - framework::LoDTensor batch_out; + framework::LoDTensor batch_out, batch_cell, batch_cell_pre_act; batch_out.mutable_data(dims, ctx.GetPlace()); - framework::LoDTensor batch_cell; batch_cell.mutable_data(dims, ctx.GetPlace()); - framework::LoDTensor batch_cell_pre_act; batch_cell_pre_act.mutable_data(dims, ctx.GetPlace()); - auto& batch_starts = batch_gate->lod()[0]; + auto batch_starts = batch_gate->lod()[0]; size_t num_batch = batch_starts.size() - 1; auto gate_act = ctx.Attr("gateActivation"); auto cell_act = ctx.Attr("cellActivation"); diff --git a/paddle/operators/math/sequence2batch.cc b/paddle/operators/math/sequence2batch.cc index 00de56f7cd..10c6e105b9 100644 --- a/paddle/operators/math/sequence2batch.cc +++ b/paddle/operators/math/sequence2batch.cc @@ -51,6 +51,8 @@ class CopyMatrixRowsFunctor { template class CopyMatrixRowsFunctor; template class CopyMatrixRowsFunctor; +template class LoDTensor2BatchFunctor; +template class LoDTensor2BatchFunctor; template class Batch2LoDTensorFunctor; template class Batch2LoDTensorFunctor; -- GitLab