提交 64fe9bcc 编写于 作者: D dangqingqing

Update lstm comments and fix bug.

上级 34aac18c
......@@ -20,7 +20,8 @@ proto_library(framework_proto SRCS framework.proto)
cc_library(attribute SRCS attribute.cc DEPS framework_proto)
cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS attribute ddim op_info)
cc_test(program_desc_test SRCS program_desc_test.cc DEPS proto_desc)
cc_test(program_desc_test SRCS program_desc_test.cc DEPS proto_desc
device_context)
cc_library(op_proto_maker SRCS op_proto_maker.cc DEPS framework_proto attribute)
cc_test(op_proto_maker_test SRCS op_proto_maker_test.cc DEPS op_proto_maker)
cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto)
......
......@@ -127,7 +127,7 @@ op_library(softmax_with_cross_entropy_op DEPS cross_entropy softmax)
op_library(sum_op DEPS net_op)
op_library(pool_op DEPS pooling)
op_library(pool_with_index_op DEPS pooling)
op_library(lstm_op DEPS sequence2batch lstm_compute math_function)
op_library(lstm_op DEPS sequence2batch lstm_compute)
list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS})
foreach(src ${GENERAL_OPS})
......
......@@ -98,18 +98,18 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker {
"batch size. `H0` and `C0` can be NULL but only at the same time");
AddInput("Weight",
"(Tensor) the learnable hidden-hidden weights."
" - The shape is (D x 4*D), where D is the hidden size. "
" - Weight = {W_ih, W_fh, W_ch, W_oh}");
" - The shape is (D x 4D), where D is the hidden size. "
" - Weight = {W_ch, W_ih, W_fh, W_oh}");
AddInput("Bias",
"(Tensor) the learnable weights, which contains two parts: "
"input-hidden bias weight and peephole connections weight if "
"seting `usePeepholes` True. "
"setting `usePeepholes` True. "
"1. `usePeepholes = False` "
" - The shape is (1 x 4*D). "
" - Bias = {b_i, b_f, b_c, b_o}."
" - The shape is (1 x 4D). "
" - Bias = {b_c, b_i, b_f, b_o}."
"2. `usePeepholes = True` "
" - The shape is (1 x 7*D). "
" - Bias = {b_i, b_f, b_c, b_o, W_ic, W_fc, W_oc}.");
" - The shape is (1 x 7D). "
" - Bias = {b_c, b_i, b_f, b_o, W_ic, W_fc, W_oc}.");
AddOutput("BatchGate",
"(LoDTensor) This LoDTensor contains input gate, forget gate "
"and output gate after the nonlinear computation. This "
......@@ -184,8 +184,8 @@ Set `usePeepholes` False to disable peephole connection [2]. The formula
is omitted here.
@note These \f$W_{xi}x_{t}, W_{xf}x_{t}, W_{xc}x_{t}, W_{xo}x_{t}\f$
operations on the input x_{t} were NOT included in this operator. The
users can choose to use fully-connect operator before LSTM operator.
operations on the input x_{t} were NOT included in this operator.
Users can choose to use fully-connect operator before LSTM operator.
[1] Hasim Sak, Andrew Senior, and Francoise Beaufays. Long short-term memory
recurrent neural network architectures for large scale acoustic modeling.
......
......@@ -76,14 +76,12 @@ class LSTMKernel : public framework::OpKernel<T> {
lstm_value.checkOg = lstm_value.checkFg + frame_size;
lstm_value.prevStateValue = nullptr;
framework::LoDTensor batch_out;
framework::LoDTensor batch_out, batch_cell, batch_cell_pre_act;
batch_out.mutable_data<T>(dims, ctx.GetPlace());
framework::LoDTensor batch_cell;
batch_cell.mutable_data<T>(dims, ctx.GetPlace());
framework::LoDTensor batch_cell_pre_act;
batch_cell_pre_act.mutable_data<T>(dims, ctx.GetPlace());
auto& batch_starts = batch_gate->lod()[0];
auto batch_starts = batch_gate->lod()[0];
size_t num_batch = batch_starts.size() - 1;
auto gate_act = ctx.Attr<std::string>("gateActivation");
auto cell_act = ctx.Attr<std::string>("cellActivation");
......
......@@ -51,6 +51,8 @@ class CopyMatrixRowsFunctor<platform::CPUPlace, T> {
template class CopyMatrixRowsFunctor<platform::CPUPlace, float>;
template class CopyMatrixRowsFunctor<platform::CPUPlace, double>;
template class LoDTensor2BatchFunctor<platform::CPUPlace, float>;
template class LoDTensor2BatchFunctor<platform::CPUPlace, double>;
template class Batch2LoDTensorFunctor<platform::CPUPlace, float>;
template class Batch2LoDTensorFunctor<platform::CPUPlace, double>;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册