diff --git a/paddle/framework/lod_tensor_test.cu b/paddle/framework/lod_tensor_test.cu index c79c4d0c721f9e568c937cb9e524e925fcdc83d0..5b90fbfca7f6bec4f2c862d0ff18dfd7cf39e181 100644 --- a/paddle/framework/lod_tensor_test.cu +++ b/paddle/framework/lod_tensor_test.cu @@ -36,8 +36,8 @@ TEST(LoDTensor, LoDInGPU) { lod_tensor.mutable_data(place); lod_tensor.set_lod(src_lod); - CHECK_EQ(lod_tensor.lod_element(0, 2).first, 4UL); - CHECK_EQ(lod_tensor.lod_element(0, 4).first, 8UL); + EXPECT_EQ(lod_tensor.lod_element(0, 2).first, 4UL); + EXPECT_EQ(lod_tensor.lod_element(0, 4).first, 8UL); auto lod = lod_tensor.lod(); @@ -45,6 +45,6 @@ TEST(LoDTensor, LoDInGPU) { cudaDeviceSynchronize(); for (size_t i = 0; i < src_lod[0].size(); ++i) { - CHECK_EQ(lod[0].data()[i], src_lod[0].data()[i] * 2); + EXPECT_EQ(lod[0].data()[i], src_lod[0].data()[i] * 2); } -} \ No newline at end of file +} diff --git a/paddle/operators/lstm_op.cc b/paddle/operators/lstm_op.cc index 0a089b7c2dc1e05224525bc4fe5399ec39036d01..94342d940704d850a2a45c281a3d88de5a132753 100644 --- a/paddle/operators/lstm_op.cc +++ b/paddle/operators/lstm_op.cc @@ -21,7 +21,6 @@ class LSTMOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - protected: void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("Input"), "Input(Input) of LSTM should not be null."); @@ -29,9 +28,13 @@ class LSTMOp : public framework::OperatorWithKernel { "Output(Hidden) of LSTM should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Cell"), "Output(Cell) of LSTM should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("BatchGate"), + "Output(BatchGate) of LSTM should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("BatchCellPreAct"), + "Output(BatchGate) of LSTM should not be null."); - auto x_dims = ctx->GetInputDim("Input"); - PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2."); + auto in_dims = ctx->GetInputDim("Input"); + PADDLE_ENFORCE_EQ(in_dims.size(), 2, "Input(X)'s rank must be 2."); if (ctx->HasInput("H0")) { PADDLE_ENFORCE(ctx->HasInput("C0"), @@ -44,7 +47,7 @@ class LSTMOp : public framework::OperatorWithKernel { "should be the same."); } - int frame_size = x_dims[1] / 4; + int frame_size = in_dims[1] / 4; auto w_dims = ctx->GetInputDim("Weight"); PADDLE_ENFORCE_EQ(w_dims.size(), 2, "The rank of Input(Weight) should be 2."); @@ -71,12 +74,21 @@ class LSTMOp : public framework::OperatorWithKernel { "4 * %d if disable peepholes connection", frame_size); } - ctx->SetOutputDim("Hidden", {x_dims[0], frame_size}); - ctx->SetOutputDim("Cell", {x_dims[0], frame_size}); - ctx->SetOutputDim("BatchGate", x_dims); + framework::DDim out_dims({in_dims[0], frame_size}); + ctx->SetOutputDim("Hidden", out_dims); + ctx->SetOutputDim("Cell", out_dims); + ctx->SetOutputDim("BatchGate", in_dims); + ctx->SetOutputDim("BatchCellPreAct", out_dims); ctx->ShareLoD("Input", "Hidden"); ctx->ShareLoD("Input", "Cell"); } + + protected: + framework::DataType IndicateDataType( + const framework::ExecutionContext& ctx) const override { + return framework::ToDataType( + ctx.Input("Input")->type()); + } }; class LSTMOpMaker : public framework::OpProtoAndCheckerMaker { @@ -86,16 +98,18 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker { AddInput("Input", "(LoDTensor) the first input is a LodTensor, which support " "variable-time length input sequence. The underlying tensor in " - "this LoDTensor is a matrix with shape (T X 4D), where, T is the " + "this LoDTensor is a matrix with shape (T X 4D), where T is the " "total time steps in this mini-batch, D is the hidden size."); AddInput("H0", "(Tensor, optional) the initial hidden state is an optional " "input. This is a tensor with shape (N x D), where N is the " - "batch size, D is the hidden size."); + "batch size, D is the hidden size.") + .AsDispensable(); AddInput("C0", "(Tensor, optional) the initial cell state is an optional " "input. This is a tensor with shape (N x D), where N is the " - "batch size. `H0` and `C0` can be NULL but only at the same time"); + "batch size. `H0` and `C0` can be NULL but only at the same time") + .AsDispensable(); AddInput("Weight", "(Tensor) the learnable hidden-hidden weights." " - The shape is (D x 4D), where D is the hidden size. " @@ -109,22 +123,27 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker { " - Bias = {b_c, b_i, b_f, b_o}." "2. `usePeepholes = True` " " - The shape is (1 x 7D). " - " - Bias = {b_c, b_i, b_f, b_o, W_ic, W_fc, W_oc}."); + " - Bias = {b_c, b_i, b_f, b_o, W_ic, W_fc, W_oc}.") + .AsDispensable(); + AddOutput("Hidden", + "(LoDTensor) the hidden state of LSTM operator. " + "The shape is (T x D), and lod is the same with the `Input`."); + AddOutput("Cell", + "(LoDTensor) the cell state of LSTM operator. " + "The shape is (T x D), and lod is the same with the `Input`."); AddOutput("BatchGate", "(LoDTensor) This LoDTensor contains input gate, forget gate " "and output gate after the nonlinear computation. This " "LoDTensor has the same shape with the reorganized input, which " - "was also be called batch input. The LoD size is 2. The first " + "is also be called batch input. The LoD size is 2. The first " "LoD is the batch offsets and the second LoD contains the " "indexes, which denote the position of reorganized sequence " "in the raw input.") .AsIntermediate(); - AddOutput("Hidden", - "(LoDTensor) the hidden state lod tensor of LSTM operator. " - "The shape and lod is the same with the `Input`."); - AddOutput("Cell", - "(LoDTensor) the cell state lod tensor of LSTM operator. " - "The shape and lod is the same with the `Input`."); + AddOutput("BatchCellPreAct", + "(LoDTensor) This LoDTensor is got in the forward and used " + "in the backward.") + .AsIntermediate(); AddAttr("usePeepholes", "(bool, defalut: True) " "whether to enable diagonal/peephole connections.") @@ -202,15 +221,37 @@ class LSTMGradOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - protected: void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Hidden")), - "Input(Hidden@GRAD) should not be null"); - PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Cell")), - "Input(Cell@GRAD) should not be null"); - ctx->SetOutputDim(framework::GradVarName("Weight"), - ctx->GetInputDim("Weight")); - ctx->SetOutputDim(framework::GradVarName("Bias"), ctx->GetInputDim("Bias")); + PADDLE_ENFORCE(ctx->HasInput("Input"), + "Input(Input) of LSTM should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Hidden"), + "Input(Hidden) of LSTM should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Cell"), + "Input(Cell) of LSTM should not be null."); + + PADDLE_ENFORCE(ctx->HasInput("BatchGate"), + "Input(BatchGate) of LSTM should not be null."); + PADDLE_ENFORCE(ctx->HasInput("BatchCellPreAct"), + "Input(BatchGate) of LSTM should not be null."); + + auto in_g_name = framework::GradVarName("Input"); + if (ctx->HasOutput(in_g_name)) + ctx->SetOutputDim(in_g_name, ctx->GetInputDim("Input")); + + auto w_g_name = framework::GradVarName("Weight"); + if (ctx->HasOutput(w_g_name)) + ctx->SetOutputDim(w_g_name, ctx->GetInputDim("Weight")); + + auto b_g_name = framework::GradVarName("Bias"); + if (ctx->HasOutput(b_g_name)) + ctx->SetOutputDim(b_g_name, ctx->GetInputDim("Bias")); + } + + protected: + framework::DataType IndicateDataType( + const framework::ExecutionContext& ctx) const override { + return framework::ToDataType( + ctx.Input("Input")->type()); } }; diff --git a/paddle/operators/lstm_op.h b/paddle/operators/lstm_op.h index 0af5694c48fcb4437e3acd422606de013bb2e145..af088b80b4283cf221a1dff74546d73d977fada3 100644 --- a/paddle/operators/lstm_op.h +++ b/paddle/operators/lstm_op.h @@ -21,8 +21,9 @@ limitations under the License. */ namespace paddle { namespace operators { -using framework::LoDTensor; -using framework::Tensor; +using LoDTensor = framework::LoDTensor; +using Tensor = framework::Tensor; + template using EigenMatrix = framework::EigenMatrix; @@ -31,15 +32,15 @@ template class LSTMKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - auto* input = ctx.Input("Input"); - auto* weight = ctx.Input("Weight"); - auto* bias = ctx.Input("Bias"); + auto* input = ctx.Input("Input"); + auto* weight = ctx.Input("Weight"); + auto* bias = ctx.Input("Bias"); - auto* batch_gate = ctx.Output("BatchGate"); + auto* batch_gate = ctx.Output("BatchGate"); batch_gate->mutable_data(ctx.GetPlace()); - auto* hidden_out = ctx.Output("Hidden"); + auto* hidden_out = ctx.Output("Hidden"); hidden_out->mutable_data(ctx.GetPlace()); - auto* cell_out = ctx.Output("Cell"); + auto* cell_out = ctx.Output("Cell"); cell_out->mutable_data(ctx.GetPlace()); // Now the function ShareLoD in InferShape is not implemented. @@ -49,7 +50,8 @@ class LSTMKernel : public framework::OpKernel { bool is_reverse = ctx.Attr("isReverse"); math::LoDTensor2BatchFunctor to_batch; - to_batch(ctx.device_context(), *input, *batch_gate, is_reverse); + auto& device_ctx = ctx.device_context(); + to_batch(device_ctx, *input, *batch_gate, true, is_reverse); auto in_dims = input->dims(); int frame_size = static_cast(in_dims[1] / 4); @@ -69,17 +71,26 @@ class LSTMKernel : public framework::OpKernel { } math::LstmMetaValue lstm_value; - T* bias_data = const_cast(bias->data()); - // the code style in LstmMetaValue will be updated later. - lstm_value.checkIg = bias_data + 4 * frame_size; - lstm_value.checkFg = lstm_value.checkIg + frame_size; - lstm_value.checkOg = lstm_value.checkFg + frame_size; + if (bias) { + T* bias_data = const_cast(bias->data()); + // the code style in LstmMetaValue will be updated later. + + lstm_value.checkIg = bias_data + 4 * frame_size; + lstm_value.checkFg = lstm_value.checkIg + frame_size; + lstm_value.checkOg = lstm_value.checkFg + frame_size; + } else { + lstm_value.checkIg = nullptr; + lstm_value.checkFg = nullptr; + lstm_value.checkOg = nullptr; + } lstm_value.prevStateValue = nullptr; - framework::LoDTensor batch_out, batch_cell, batch_cell_pre_act; - batch_out.mutable_data(dims, ctx.GetPlace()); + // Use the local variable as here. + LoDTensor batch_hidden, batch_cell; + auto* batch_cell_pre_act = ctx.Output("BatchCellPreAct"); + batch_hidden.mutable_data(dims, ctx.GetPlace()); batch_cell.mutable_data(dims, ctx.GetPlace()); - batch_cell_pre_act.mutable_data(dims, ctx.GetPlace()); + batch_cell_pre_act->mutable_data(dims, ctx.GetPlace()); auto batch_starts = batch_gate->lod()[0]; size_t num_batch = batch_starts.size() - 1; @@ -92,18 +103,18 @@ class LSTMKernel : public framework::OpKernel { int bend = static_cast(batch_starts[n + 1]); Tensor gate_t = batch_gate->Slice(bstart, bend); - Tensor out_t = batch_out.Slice(bstart, bend); + Tensor out_t = batch_hidden.Slice(bstart, bend); Tensor cell_t = batch_cell.Slice(bstart, bend); - Tensor cell_pre_act_t = batch_cell_pre_act.Slice(bstart, bend); + Tensor cell_pre_act_t = batch_cell_pre_act->Slice(bstart, bend); int cur_batch_size = bend - bstart; if (n != 0) { int pre_h_start = static_cast(batch_starts[n - 1]); int pre_h_end = pre_h_start + cur_batch_size; - auto pre_hidden_t = batch_out.Slice(pre_h_start, pre_h_end); - math::matmul(ctx.device_context(), pre_hidden_t, false, - *weight, false, static_cast(1.0), &gate_t, + auto pre_hidden_t = batch_hidden.Slice(pre_h_start, pre_h_end); + math::matmul(device_ctx, pre_hidden_t, false, *weight, false, + static_cast(1.0), &gate_t, static_cast(1.0)); } // else if : FIXME support the initial hidden and cell @@ -112,27 +123,186 @@ class LSTMKernel : public framework::OpKernel { lstm_value.outputValue = out_t.data(); lstm_value.stateValue = cell_t.data(); lstm_value.stateActiveValue = cell_pre_act_t.data(); - math::LstmUnitFunctor::compute(ctx.device_context(), lstm_value, + math::LstmUnitFunctor::compute(device_ctx, lstm_value, frame_size, cur_batch_size, gate_act, cell_act, cand_act); lstm_value.prevStateValue = lstm_value.stateValue; } math::Batch2LoDTensorFunctor to_seq; - batch_out.set_lod(batch_gate->lod()); + batch_hidden.set_lod(batch_gate->lod()); // restore the output hidden in LoDTensor from the batch hidden - to_seq(ctx.device_context(), batch_out, *hidden_out); + to_seq(device_ctx, batch_hidden, *hidden_out); batch_cell.set_lod(batch_gate->lod()); // restore the output cell state in LoDTensor from the batch cell - to_seq(ctx.device_context(), batch_cell, *cell_out); + to_seq(device_ctx, batch_cell, *cell_out); } }; template class LSTMGradKernel : public framework::OpKernel { public: - void Compute(const framework::ExecutionContext& ctx) const override {} + void Compute(const framework::ExecutionContext& ctx) const override { + auto* input = ctx.Input("Input"); + auto* weight = ctx.Input("Weight"); + auto* bias = ctx.Input("Bias"); + + auto* hidden_out = ctx.Input("Hidden"); + auto* cell_out = ctx.Input("Cell"); + + auto* batch_gate = ctx.Input("BatchGate"); + auto* batch_cell_pre_act = ctx.Input("BatchCellPreAct"); + + auto* hidden_g = ctx.Input(framework::GradVarName("Hidden")); + + auto* in_g = ctx.Output(framework::GradVarName("Input")); + auto* weight_g = ctx.Output(framework::GradVarName("Weight")); + auto* bias_g = ctx.Output(framework::GradVarName("Bias")); + + auto& device_ctx = ctx.device_context(); + math::SetConstant zero; + if (weight_g) { + weight_g->mutable_data(ctx.GetPlace()); + zero(device_ctx, weight_g, static_cast(0.0)); + } + + auto in_dims = input->dims(); + auto out_dims = hidden_g->dims(); + int frame_size = static_cast(in_dims[1] / 4); + PADDLE_ENFORCE_EQ(frame_size, out_dims[1]); + + math::LstmMetaValue lstm_value; + if (bias) { + T* bias_data = const_cast(bias->data()); + lstm_value.checkIg = bias_data + 4 * frame_size; + lstm_value.checkFg = lstm_value.checkIg + frame_size; + lstm_value.checkOg = lstm_value.checkFg + frame_size; + } else { + lstm_value.checkIg = nullptr; + lstm_value.checkFg = nullptr; + lstm_value.checkOg = nullptr; + } + + math::LstmMetaGrad lstm_grad; + if (bias && bias_g) { + T* bias_g_data = const_cast(bias_g->mutable_data(ctx.GetPlace())); + zero(device_ctx, bias_g, static_cast(0.0)); + lstm_grad.checkIgGrad = bias_g_data + 4 * frame_size; + lstm_grad.checkFgGrad = lstm_grad.checkIgGrad + frame_size; + lstm_grad.checkOgGrad = lstm_grad.checkFgGrad + frame_size; + } else { + lstm_grad.checkIgGrad = nullptr; + lstm_grad.checkFgGrad = nullptr; + lstm_grad.checkOgGrad = nullptr; + } + + math::LoDTensor2BatchFunctor to_batch; + + // use the local variable as here. + LoDTensor batch_hidden; + batch_hidden.mutable_data(out_dims, ctx.GetPlace()); + batch_hidden.set_lod(batch_gate->lod()); + to_batch(device_ctx, *hidden_out, batch_hidden, false); + + LoDTensor batch_hidden_g; + batch_hidden_g.mutable_data(out_dims, ctx.GetPlace()); + batch_hidden_g.set_lod(batch_gate->lod()); + to_batch(device_ctx, *hidden_g, batch_hidden_g, false); + + LoDTensor batch_cell; + batch_cell.mutable_data(out_dims, ctx.GetPlace()); + batch_cell.set_lod(batch_gate->lod()); + to_batch(device_ctx, *cell_out, batch_cell, false); + + LoDTensor batch_cell_g; + batch_cell_g.mutable_data(out_dims, ctx.GetPlace()); + batch_cell_g.set_lod(batch_gate->lod()); + // TODO(qingqing) support the case output cell has gradient. + // to_batch(device_ctx, *cell_g, batch_cell_g, false); + zero(device_ctx, &batch_cell_g, static_cast(0.0)); + + LoDTensor batch_gate_g; + batch_gate_g.mutable_data(batch_gate->dims(), ctx.GetPlace()); + batch_gate_g.set_lod(batch_gate->lod()); + + auto gate_act = ctx.Attr("gateActivation"); + auto cell_act = ctx.Attr("cellActivation"); + auto cand_act = ctx.Attr("candidateActivation"); + + auto batch_starts = batch_gate->lod()[0]; + size_t num_batch = batch_starts.size() - 1; + for (int n = static_cast(num_batch) - 1; n >= 0; n--) { + int bstart = static_cast(batch_starts[n]); + int bend = static_cast(batch_starts[n + 1]); + + Tensor gate = batch_gate->Slice(bstart, bend); + Tensor cell = batch_cell.Slice(bstart, bend); + Tensor cell_pre_act = batch_cell_pre_act->Slice(bstart, bend); + lstm_value.gateValue = gate.data(); + lstm_value.stateValue = cell.data(); + lstm_value.stateActiveValue = cell_pre_act.data(); + + Tensor out_g = batch_hidden_g.Slice(bstart, bend); + Tensor gate_g = batch_gate_g.Slice(bstart, bend); + Tensor cell_g = batch_cell_g.Slice(bstart, bend); + lstm_grad.stateGrad = cell_g.data(); + lstm_grad.gateGrad = gate_g.data(); + lstm_grad.outputGrad = out_g.data(); + + if (n) { + int bstart_pre = static_cast(batch_starts[n - 1]); + Tensor cell_pre = batch_cell.Slice(bstart_pre, bstart); + Tensor cell_pre_g = batch_cell_g.Slice(bstart_pre, bstart); + lstm_value.prevStateValue = cell_pre.data(); + lstm_grad.prevStateGrad = cell_pre_g.data(); + } else { + lstm_value.prevStateValue = nullptr; + lstm_grad.prevStateGrad = nullptr; + } + + int cur_batch_size = bend - bstart; + math::LstmUnitGradFunctor::compute( + device_ctx, lstm_value, lstm_grad, frame_size, cur_batch_size, + gate_act, cell_act, cand_act); + + if (n != 0) { + int pre_h_start = static_cast(batch_starts[n - 1]); + int pre_h_end = pre_h_start + cur_batch_size; + auto pre_hidden_g = batch_hidden_g.Slice(pre_h_start, pre_h_end); + math::matmul(device_ctx, gate_g, false, *weight, true, + static_cast(1.0), &pre_hidden_g, + static_cast(1.0)); + if (weight_g) { + /* backward weight */ + auto pre_hidden = batch_hidden.Slice(pre_h_start, pre_h_end); + math::matmul(device_ctx, pre_hidden, true, gate_g, false, + static_cast(1.0), weight_g, + static_cast(1.0)); + } + } + } + + math::Batch2LoDTensorFunctor to_seq; + if (in_g) { + /* backward data */ + in_g->mutable_data(ctx.GetPlace()); + to_seq(device_ctx, batch_gate_g, *in_g); + } + if (bias && bias_g) { + /* backward bias */ + int m = static_cast(batch_gate_g.dims()[0]); + int n = static_cast(batch_gate_g.dims()[1]); + + Tensor ones; + ones.mutable_data({m}, ctx.GetPlace()); + math::SetConstant set; + set(device_ctx, &ones, static_cast(1.0)); + + math::gemv(device_ctx, true, m, n, 1., batch_gate_g.data(), + ones.data(), 0., bias_g->data()); + } + } }; } // namespace operators diff --git a/paddle/operators/math/detail/lstm_cpu_kernel.h b/paddle/operators/math/detail/lstm_cpu_kernel.h index 74d51d7bc9b91f4c8088384d77183131f57aafab..d0ed55ea168bc3e701c421c51d662c646e475351 100644 --- a/paddle/operators/math/detail/lstm_cpu_kernel.h +++ b/paddle/operators/math/detail/lstm_cpu_kernel.h @@ -26,10 +26,7 @@ namespace detail { template void naive_lstm_forward_one_sequence(Op op, LstmMetaValue value, - int frameSize, - activation_mode_t active_node, - activation_mode_t active_gate, - activation_mode_t active_state) { + int frameSize) { T rValueIn; T rValueIg; T rValueFg; @@ -60,10 +57,8 @@ void naive_lstm_forward_one_sequence(Op op, LstmMetaValue value, rPrevState = value.prevStateValue[i]; } - hppl::cpu::ForwardAct act; op(rValueIn, rValueIg, rValueFg, rValueOg, rPrevState, rState, rStateAtv, - rOut, rCheckI, rCheckF, rCheckO, act(active_node), act(active_gate), - act(active_state)); + rOut, rCheckI, rCheckF, rCheckO); valueIn[i] = rValueIn; valueIg[i] = rValueIg; @@ -77,10 +72,7 @@ void naive_lstm_forward_one_sequence(Op op, LstmMetaValue value, template void naive_lstm_backward_one_sequence(Op op, LstmMetaValue value, - LstmMetaGrad grad, int frameSize, - activation_mode_t active_node, - activation_mode_t active_gate, - activation_mode_t active_state) { + LstmMetaGrad grad, int frameSize) { T rValueIn; T rValueIg; T rValueFg; @@ -127,11 +119,10 @@ void naive_lstm_backward_one_sequence(Op op, LstmMetaValue value, rPrevState = value.prevStateValue[i]; } - hppl::cpu::BackwardAct act; op(rValueIn, rValueIg, rValueFg, rValueOg, rGradIn, rGradIg, rGradFg, rGradOg, rPrevState, rPrevStateGrad, rState, rStateGrad, rStateAtv, rOutputGrad, rCheckI, rCheckF, rCheckO, rCheckIGrad, rCheckFGrad, - rCheckOGrad, act(active_node), act(active_gate), act(active_state)); + rCheckOGrad); gradIn[i] = rGradIn; gradIg[i] = rGradIg; @@ -283,8 +274,7 @@ void cpu_lstm_forward(Op op, LstmMetaValue value, int frameSize, avx_lstm_forward_one_sequence(op, value, frameSize, active_node, active_gate, active_state); } else { - naive_lstm_forward_one_sequence(op, value, frameSize, active_node, - active_gate, active_state); + naive_lstm_forward_one_sequence(op, value, frameSize); } } @@ -297,8 +287,7 @@ void cpu_lstm_backward(Op op, LstmMetaValue value, LstmMetaGrad grad, avx_lstm_backward_one_sequence(op, value, grad, frameSize, active_node, active_gate, active_state); } else { - naive_lstm_backward_one_sequence(op, value, grad, frameSize, active_node, - active_gate, active_state); + naive_lstm_backward_one_sequence(op, value, grad, frameSize); } } diff --git a/paddle/operators/math/detail/lstm_gpu_kernel.h b/paddle/operators/math/detail/lstm_gpu_kernel.h index 9573eaefb6a9d678ef70f2e2bffdc6a3011b21ea..c06f164f84a92d31f89901e2656bdb8e69c533b7 100644 --- a/paddle/operators/math/detail/lstm_gpu_kernel.h +++ b/paddle/operators/math/detail/lstm_gpu_kernel.h @@ -32,9 +32,7 @@ namespace detail { */ template __global__ void KeLstmForward(Op op, LstmMetaValue value, int frameSize, - int batchSize, activation_mode_t active_node, - activation_mode_t active_gate, - activation_mode_t active_state) { + int batchSize) { const int frameIdx = blockIdx.x * blockDim.x + threadIdx.x; if (frameIdx >= frameSize) return; @@ -70,10 +68,8 @@ __global__ void KeLstmForward(Op op, LstmMetaValue value, int frameSize, rPrevState = value.prevStateValue[frameIdx]; } - hppl::gpu::ForwardAct act; op(rValueIn, rValueIg, rValueFg, rValueOg, rPrevState, rState, rStateAtv, - rOut, rCheckI, rCheckF, rCheckO, act(active_node), act(active_gate), - act(active_state)); + rOut, rCheckI, rCheckF, rCheckO); value.gateValue[frameIdx] = rValueIn; value.gateValue[frameIdx + frameSize] = rValueIg; @@ -92,9 +88,7 @@ __global__ void KeLstmForward(Op op, LstmMetaValue value, int frameSize, template __global__ void KeLstmBackward(Op op, LstmMetaValue value, LstmMetaGrad grad, int frameSize, - int batchSize, activation_mode_t active_node, - activation_mode_t active_gate, - activation_mode_t active_state) { + int batchSize) { const int frameIdx = blockIdx.x * blockDim.x + threadIdx.x; if (frameIdx >= frameSize) return; @@ -145,11 +139,9 @@ __global__ void KeLstmBackward(Op op, LstmMetaValue value, rPrevState = value.prevStateValue[frameIdx]; } - hppl::gpu::BackwardAct act; op(rValueIn, rValueIg, rValueFg, rValueOg, rGradIn, rGradIg, rGradFg, rGradOg, rPrevState, rPrevStateGrad, rState, rStateGrad, rStateAtv, rOutputGrad, - rCheckI, rCheckF, rCheckO, rCheckIGrad, rCheckFGrad, rCheckOGrad, - act(active_node), act(active_gate), act(active_state)); + rCheckI, rCheckF, rCheckO, rCheckIGrad, rCheckFGrad, rCheckOGrad); grad.gateGrad[frameIdx] = rGradIn; grad.gateGrad[frameIdx + frameSize] = rGradIg; @@ -205,13 +197,11 @@ void gpu_lstm_forward(const platform::DeviceContext& context, Op op, if (batchSize == 1) { KeLstmForward<<>>( - op, value, frameSize, batchSize, active_node, active_gate, - active_state); + op, value, frameSize, batchSize); } else { KeLstmForward<<>>( - op, value, frameSize, batchSize, active_node, active_gate, - active_state); + op, value, frameSize, batchSize); } } @@ -240,13 +230,11 @@ void gpu_lstm_backward(const platform::DeviceContext& context, Op op, if (batchSize == 1) { KeLstmBackward<<>>( - op, value, grad, frameSize, batchSize, active_node, active_gate, - active_state); + op, value, grad, frameSize, batchSize); } else { KeLstmBackward<<>>( - op, value, grad, frameSize, batchSize, active_node, active_gate, - active_state); + op, value, grad, frameSize, batchSize); } } diff --git a/paddle/operators/math/detail/lstm_kernel.h b/paddle/operators/math/detail/lstm_kernel.h index 6f3ead2397d5131b4468d0ad288513cedb289594..461039a4d51a2b9b8a55d3101bdf4c511907597e 100644 --- a/paddle/operators/math/detail/lstm_kernel.h +++ b/paddle/operators/math/detail/lstm_kernel.h @@ -24,15 +24,29 @@ namespace detail { namespace forward { +template +DEVICE inline T sigmoid(const T a) { + const T min = SIGMOID_THRESHOLD_MIN; + const T max = SIGMOID_THRESHOLD_MAX; + T tmp = (a < min) ? min : ((a > max) ? max : a); + return static_cast(1.0) / (static_cast(1.0) + exp(-tmp)); +} + +template +DEVICE inline T tanh(const T a) { + T tmp = -2.0 * a; + tmp = (tmp > EXP_MAX_INPUT) ? EXP_MAX_INPUT : tmp; + return (2.0 / (1.0 + exp(tmp))) - 1.0; +} + template class lstm { public: HOSTDEVICE void operator()(T &valueIn, T &valueIg, T &valueFg, T &valueOg, T &prevState, T &state, T &stateAtv, T &output, - T &checkI, T &checkF, T &checkO, - typename hppl::ForwardActType::type actInput, - typename hppl::ForwardActType::type actGate, - typename hppl::ForwardActType::type actState) { + T &checkI, T &checkF, T &checkO) { +#if 0 + // TODO(qingqing) support to activation speficed by users valueIn = actInput(valueIn); valueIg = actGate(valueIg + prevState * checkI); valueFg = actGate(valueFg + prevState * checkF); @@ -40,6 +54,15 @@ class lstm { valueOg = actGate(valueOg + state * checkO); stateAtv = actState(state); output = valueOg * stateAtv; +#else + valueIn = tanh(valueIn); + valueIg = sigmoid(valueIg + prevState * checkI); + valueFg = sigmoid(valueFg + prevState * checkF); + state = valueIn * valueIg + prevState * valueFg; + valueOg = sigmoid(valueOg + state * checkO); + stateAtv = tanh(state); + output = valueOg * stateAtv; +#endif } #ifndef __NVCC__ #ifndef __AVX__ // If not compiled with AVX instructs. Disable AVX by default @@ -72,6 +95,16 @@ class lstm { namespace backward { +template +DEVICE inline T sigmoid(const T a, const T b) { + return a * b * (1.0 - b); +} + +template +DEVICE inline T tanh(const T a, const T b) { + return a * (1.0 - b * b); +} + template class lstm { public: @@ -80,10 +113,9 @@ class lstm { T &prevState, T &prevStateGrad, T &state, T &stateGrad, T &stateAtv, T &outputGrad, T &checkI, T &checkF, T &checkO, T &checkIGrad, - T &checkFGrad, T &checkOGrad, - typename hppl::BackwardActType::type actInput, - typename hppl::BackwardActType::type actGate, - typename hppl::BackwardActType::type actState) { + T &checkFGrad, T &checkOGrad) { +#if 0 + // TODO(qingqing) support to activation speficed by users gradOg = actGate(outputGrad * stateAtv, valueOg); stateGrad += actState(outputGrad * valueOg, stateAtv) + gradOg * checkO; gradIn = actInput(stateGrad * valueIg, valueIn); @@ -93,6 +125,17 @@ class lstm { checkIGrad = gradIg * prevState; checkFGrad = gradFg * prevState; checkOGrad = gradOg * state; +#else + gradOg = sigmoid(outputGrad * stateAtv, valueOg); + stateGrad += tanh(outputGrad * valueOg, stateAtv) + gradOg * checkO; + gradIn = tanh(stateGrad * valueIg, valueIn); + gradIg = sigmoid(stateGrad * valueIn, valueIg); + gradFg = sigmoid(stateGrad * prevState, valueFg); + prevStateGrad = gradIg * checkI + gradFg * checkF + stateGrad * valueFg; + checkIGrad = gradIg * prevState; + checkFGrad = gradFg * prevState; + checkOGrad = gradOg * state; +#endif } #ifndef __NVCC__ #ifndef __AVX__ // If not compiled with AVX instructs. Disable AVX by default diff --git a/paddle/operators/math/math_function.cc b/paddle/operators/math/math_function.cc index aad1357598c629a4edfe0ad9b23f0241093a2522..2a9c09a0f16b71473e21765ab9253eb7b8bcf28c 100644 --- a/paddle/operators/math/math_function.cc +++ b/paddle/operators/math/math_function.cc @@ -211,6 +211,26 @@ void batched_gemm( } #endif +template <> +void gemv(const platform::DeviceContext& context, + const bool trans_a, const int M, + const int N, const float alpha, + const float* A, const float* B, + const float beta, float* C) { + CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans; + cblas_sgemv(CblasRowMajor, transA, M, N, alpha, A, N, B, 1, beta, C, 1); +} + +template <> +void gemv(const platform::DeviceContext& context, + const bool trans_a, const int M, + const int N, const double alpha, + const double* A, const double* B, + const double beta, double* C) { + CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans; + cblas_dgemv(CblasRowMajor, transA, M, N, alpha, A, N, B, 1, beta, C, 1); +} + template struct SetConstant; } // namespace math diff --git a/paddle/operators/math/math_function.cu b/paddle/operators/math/math_function.cu index 5583683c6e12b88ba81015aef9161913de261ef2..e6fd8bf235b8539702ca2c5b39e305cb1becf5cb 100644 --- a/paddle/operators/math/math_function.cu +++ b/paddle/operators/math/math_function.cu @@ -203,6 +203,33 @@ void batched_gemm( &beta, C, ldc, strideC, batchCount)); } +template <> +void gemv(const platform::DeviceContext& context, + const bool trans_a, const int M, + const int N, const float alpha, + const float* A, const float* B, + const float beta, float* C) { + cublasOperation_t cuTransA = (trans_a == false) ? CUBLAS_OP_T : CUBLAS_OP_N; + + PADDLE_ENFORCE(platform::dynload::cublasSgemv( + reinterpret_cast(context) + .cublas_handle(), + cuTransA, N, M, &alpha, A, N, B, 1, &beta, C, 1)); +} + +template <> +void gemv(const platform::DeviceContext& context, + const bool trans_a, const int M, + const int N, const double alpha, + const double* A, const double* B, + const double beta, double* C) { + cublasOperation_t cuTransA = (trans_a == false) ? CUBLAS_OP_T : CUBLAS_OP_N; + PADDLE_ENFORCE(platform::dynload::cublasDgemv( + reinterpret_cast(context) + .cublas_handle(), + cuTransA, N, M, &alpha, A, N, B, 1, &beta, C, 1)); +} + template struct SetConstant; } // namespace math diff --git a/paddle/operators/math/math_function.h b/paddle/operators/math/math_function.h index 9777ebfd156709a370be2cb4ba0077ac7c6735fb..3bb5aa0332c7e2a63d20b91893c03ccd468dd863 100644 --- a/paddle/operators/math/math_function.h +++ b/paddle/operators/math/math_function.h @@ -93,6 +93,11 @@ void batched_gemm(const platform::DeviceContext& context, const T* A, const T* B, const T beta, T* C, const int batchCount, const int strideA, const int strideB); +template +void gemv(const platform::DeviceContext& context, const bool trans_a, + const int M, const int N, const T alpha, const T* A, const T* B, + const T beta, T* C); + template struct SetConstant { void operator()(const platform::DeviceContext& context, diff --git a/paddle/operators/math/math_function_test.cc b/paddle/operators/math/math_function_test.cc index 3b9f92e7ae5f34dd0fb1ba8fb0c67ff5ae1628c4..7d84ad9aadb2892db0d0ee9cab428dc5036614e9 100644 --- a/paddle/operators/math/math_function_test.cc +++ b/paddle/operators/math/math_function_test.cc @@ -89,3 +89,53 @@ TEST(math_function, zero) { EXPECT_EQ(t[2], 1); EXPECT_EQ(t[3], 1); } + +template +void GemvTest(int m, int n, bool trans) { + paddle::framework::Tensor mat_a; + paddle::framework::Tensor vec_b; + paddle::framework::Tensor vec_c; + auto* cpu_place = new paddle::platform::CPUPlace(); + int b_num = trans ? m : n; + int c_num = trans ? n : m; + + T* data_a = mat_a.mutable_data({m, n}, *cpu_place); + T* data_b = vec_b.mutable_data({b_num}, *cpu_place); + T* data_c = vec_c.mutable_data({c_num}, *cpu_place); + for (int i = 0; i < mat_a.numel(); ++i) { + data_a[i] = static_cast(i); + } + for (int i = 0; i < vec_b.numel(); ++i) { + data_b[i] = static_cast(i); + } + + paddle::platform::CPUDeviceContext context(*cpu_place); + paddle::operators::math::gemv( + context, trans, static_cast(m), static_cast(n), 1., data_a, + data_b, 0., data_c); + + if (!trans) { + for (int i = 0; i < m; ++i) { + T sum = 0.0; + for (int j = 0; j < n; ++j) { + sum += data_a[i * n + j] * data_b[j]; + } + ASSERT_FLOAT_EQ(data_c[i], sum); + } + } else { + for (int i = 0; i < n; ++i) { + T sum = 0.0; + for (int j = 0; j < m; ++j) { + sum += data_a[j * n + i] * data_b[j]; + } + ASSERT_FLOAT_EQ(data_c[i], sum); + } + } +} + +TEST(math_function, gemv) { + GemvTest(3, 13, false); + GemvTest(4, 5, false); + GemvTest(12, 7, true); + GemvTest(7, 9, true); +} diff --git a/paddle/operators/math/math_function_test.cu b/paddle/operators/math/math_function_test.cu index 8b22c71552a65044cbd02441fb35c1eafe0173dc..780d17ffc6539c5f4d67ebab5476d6f646840b41 100644 --- a/paddle/operators/math/math_function_test.cu +++ b/paddle/operators/math/math_function_test.cu @@ -177,3 +177,65 @@ TEST(math_function, gemm_trans_cublas) { EXPECT_EQ(input3_ptr[7], 99); delete gpu_place; } + +template +void GemvTest(int m, int n, bool trans) { + paddle::framework::Tensor mat_a; + paddle::framework::Tensor vec_b; + paddle::framework::Tensor vec_c; + auto* cpu_place = new paddle::platform::CPUPlace(); + + T* data_a = mat_a.mutable_data({m, n}, *cpu_place); + T* data_b = vec_b.mutable_data({trans ? m : n}, *cpu_place); + T* data_c = vec_c.mutable_data({trans ? n : m}, *cpu_place); + + auto* gpu_place = new paddle::platform::GPUPlace(0); + paddle::framework::Tensor g_mat_a; + paddle::framework::Tensor g_vec_b; + paddle::framework::Tensor g_vec_c; + T* g_data_a = g_mat_a.mutable_data(mat_a.dims(), *gpu_place); + T* g_data_b = g_vec_b.mutable_data(vec_b.dims(), *gpu_place); + T* g_data_c = g_vec_c.mutable_data(vec_c.dims(), *gpu_place); + + for (int i = 0; i < mat_a.numel(); ++i) { + data_a[i] = static_cast(i); + } + for (int i = 0; i < vec_b.numel(); ++i) { + data_b[i] = static_cast(i); + } + + paddle::platform::CUDADeviceContext context(*gpu_place); + g_mat_a.CopyFrom(mat_a, *gpu_place, context); + g_vec_b.CopyFrom(vec_b, *gpu_place, context); + + paddle::operators::math::gemv( + context, trans, static_cast(m), static_cast(n), 1., g_data_a, + g_data_b, 0., g_data_c); + + vec_c.CopyFrom(g_vec_c, paddle::platform::CPUPlace(), context); + + if (!trans) { + for (int i = 0; i < m; ++i) { + T sum = 0.0; + for (int j = 0; j < n; ++j) { + sum += data_a[i * n + j] * data_b[j]; + } + ASSERT_FLOAT_EQ(data_c[i], sum); + } + } else { + for (int i = 0; i < n; ++i) { + T sum = 0.0; + for (int j = 0; j < m; ++j) { + sum += data_a[j * n + i] * data_b[j]; + } + ASSERT_FLOAT_EQ(data_c[i], sum); + } + } +} + +TEST(math_function, gemv) { + GemvTest(3, 13, false); + GemvTest(3, 13, false); + GemvTest(3, 13, true); + GemvTest(3, 13, true); +} diff --git a/paddle/operators/math/sequence2batch.h b/paddle/operators/math/sequence2batch.h index 03cd018e46e90c9bbe689c9686377e0e998ee513..b1ba35a6d4a891e9152ac2088bc76e3969be6405 100644 --- a/paddle/operators/math/sequence2batch.h +++ b/paddle/operators/math/sequence2batch.h @@ -53,7 +53,18 @@ class LoDTensor2BatchFunctor { public: void operator()(const platform::DeviceContext& context, const framework::LoDTensor& lod_tensor, - framework::LoDTensor& batch, bool is_reverse) const { + framework::LoDTensor& batch, bool is_cal_batch_lod, + bool is_reverse = false) const { + if (!is_cal_batch_lod) { + auto lods = batch.lod(); + PADDLE_ENFORCE_EQ(lods.size(), 2UL); + PADDLE_ENFORCE_EQ(lods[1].size(), + static_cast(lod_tensor.dims()[0])); + CopyMatrixRowsFunctor to_batch; + to_batch(context, lod_tensor, lods[1].data(), batch, true); + return; + } + auto lods = lod_tensor.lod(); PADDLE_ENFORCE_EQ(lods.size(), 1UL, "Only support one level sequence now."); auto lod = lods[0]; @@ -101,10 +112,10 @@ class LoDTensor2BatchFunctor { size_t* batch_starts = batch_lods[0].data(); size_t* seq2batch_idx = batch_lods[1].data(); batch_starts[0] = 0; - for (size_t n = 0; n < num_batch; n++) { + for (int n = 0; n < num_batch; n++) { auto batch_id = static_cast(batch_starts[n]); for (size_t i = 0; i < seq_info.size(); ++i) { - size_t seq_len = seq_info[i].length; + int seq_len = seq_info[i].length; int start = seq_info[i].start; if (n < seq_len) { seq2batch_idx[batch_id] = @@ -132,11 +143,8 @@ class Batch2LoDTensorFunctor { auto in_lod = batch.lod(); PADDLE_ENFORCE_EQ(in_lod.size(), 2UL, "The LoD size of input `batch` should be 2."); - auto out_lod = lod_tensor.lod()[0]; - auto num = out_lod[out_lod.size() - 1]; - PADDLE_ENFORCE_EQ(num, lod_tensor.dims()[0]); - PADDLE_ENFORCE_EQ(num, in_lod[1].size()); - PADDLE_ENFORCE_EQ(num, batch.dims()[0]); + PADDLE_ENFORCE_EQ(in_lod[1].size(), + static_cast(lod_tensor.dims()[0])); CopyMatrixRowsFunctor to_seq; size_t* index = in_lod[1].data(); to_seq(context, batch, index, lod_tensor, false); diff --git a/python/paddle/v2/framework/tests/test_lstm_op.py b/python/paddle/v2/framework/tests/test_lstm_op.py index 93a4e450e916716e27573d192bace73f271733de..ff75160083f2936dd653a8396254bf16d1752ffa 100644 --- a/python/paddle/v2/framework/tests/test_lstm_op.py +++ b/python/paddle/v2/framework/tests/test_lstm_op.py @@ -52,7 +52,7 @@ def lstm( g = np.dot(h_pre, w_h) # 1 x 4D g = g + x g = np.reshape(g, (1, g.size)) - c_tmp, g_i, g_f, g_o = np.split(g, 4, axis=1) + c, g_i, g_f, g_o = np.split(g, 4, axis=1) if w_c is None: g_i = act_gate(g_i) # 1 x D g_f = act_gate(g_f) # 1 x D @@ -60,7 +60,7 @@ def lstm( w_ic, w_fc, w_oc = np.split(w_c, 3, axis=1) g_i = act_gate(g_i + w_ic * c_pre) # 1 x D g_f = act_gate(g_f + w_fc * c_pre) # 1 x D - c = g_f * c_pre + g_i * act_cand(c_tmp) # 1 x D + c = g_f * c_pre + g_i * act_cand(c) # 1 x D if w_c is None: g_o = act_gate(g_o) # 1 x D @@ -68,8 +68,7 @@ def lstm( _, _, w_oc = np.split(w_c, 3, axis=1) g_o = act_gate(g_o + w_oc * c) # 1 x D h = g_o * act_cell(c) - bg = np.concatenate((act_cand(c_tmp), g_i, g_f, g_o), axis=1) - return h, c, bg + return h, c def _reverse(x, lod): y = np.zeros_like(x) @@ -82,7 +81,6 @@ def lstm( batch_size = len(offset) - 1 hidden = [] cell = [] - gate = [] input = _reverse(input, offset) if is_reverse else input if w_b is not None: input = input + np.tile(w_b, (offset[-1], 1)) @@ -94,96 +92,109 @@ def lstm( c_pre = c0[i] # 1 x D for j in range(seq_len): # compute one step - h_pre, c_pre, g_pre = _step(x[j], w_h, w_c, h_pre, c_pre, act_gate, - act_cell, act_cand) + h_pre, c_pre = _step(x[j], w_h, w_c, h_pre, c_pre, act_gate, + act_cell, act_cand) hidden.append(h_pre.flatten()) cell.append(c_pre.flatten()) - gate.append(g_pre.flatten()) - hidden = np.array(hidden).astype("float64") - cell = np.array(cell).astype("float64") - gate = np.array(gate).astype("float64") + hidden = np.array(hidden).astype('float64') + cell = np.array(cell).astype('float64') hidden = _reverse(hidden, offset) if is_reverse else hidden cell = _reverse(cell, offset) if is_reverse else cell - assert gate.shape == input.shape assert hidden.shape == (input.shape[0], input.shape[1] / 4) assert cell.shape == (input.shape[0], input.shape[1] / 4) - return hidden, cell, gate + return hidden, cell class TestLstmOp(OpTest): - def set_data(self): - self.lod = [[0, 2, 6, 9]] - self.D = 64 - self.sort_idx = [2, 6, 0, 3, 7, 1, 4, 8, 5] + def set_argument(self): + self.lod = [[0, 2, 5, 7]] + self.D = 16 - self.act_gate = "sigmoid" - self.act_cell = "tanh" - self.act_cand = "tanh" + self.act_gate = 'sigmoid' + self.act_cell = 'tanh' + self.act_cand = 'tanh' + self.has_initial_state = True self.is_reverse = False def setUp(self): - self.set_data() - self.op_type = "lstm" + self.set_argument() + self.op_type = 'lstm' T = self.lod[0][-1] N = len(self.lod[0]) - 1 - x = np.random.normal(size=(T, 4 * self.D)).astype("float64") - h0 = np.zeros((N, self.D)).astype("float64") - c0 = np.zeros((N, self.D)).astype("float64") - w = np.random.normal(size=(self.D, 4 * self.D)).astype("float64") - b = np.random.normal(size=(1, 7 * self.D)).astype("float64") + x = np.random.normal(size=(T, 4 * self.D)).astype('float64') + h0 = np.zeros((N, self.D)).astype('float64') + c0 = np.zeros((N, self.D)).astype('float64') + w = np.random.normal(size=(self.D, 4 * self.D)).astype('float64') + b = np.random.normal(size=(1, 7 * self.D)).astype('float64') w_b = b[:, 0:4 * self.D] w_c = b[:, 4 * self.D:] - h, c, g = lstm(x, self.lod, h0, c0, w, w_b, w_c, self.is_reverse, - ACTVATION[self.act_gate], ACTVATION[self.act_cell], - ACTVATION[self.act_cand]) - - g_sort = np.zeros_like(x) - for i, j in enumerate(self.sort_idx): - g_sort[i, :] = g[j, :] - - self.inputs = { - 'Input': (x, self.lod), - 'H0': h0, - 'C0': c0, - 'Weight': w, - 'Bias': b - } + h, c = lstm(x, self.lod, h0, c0, w, w_b, w_c, self.is_reverse, + ACTVATION[self.act_gate], ACTVATION[self.act_cell], + ACTVATION[self.act_cand]) + + self.inputs = {'Input': (x, self.lod), 'Weight': w, 'Bias': b} + if self.has_initial_state: + self.inputs['H0'] = h0 + self.inputs['C0'] = c0 + self.outputs = { 'Hidden': (h, self.lod), 'Cell': (c, self.lod), - 'BatchGate': g_sort } self.attrs = { 'usePeepholes': True, 'isReverse': self.is_reverse, - 'gateActivation': 'sigmoid', - 'cellActivation': 'tanh', - 'candidateActivation': 'tanh' + 'gateActivation': self.act_gate, + 'cellActivation': self.act_cell, + 'candidateActivation': self.act_cand } def test_check_output(self): - self.check_output() + self.check_output(atol=1e-8) + + #TODO(qingqing) add more unit testing case + def test_check_grad(self): + # TODO(qingqing) remove folowing lines after the check_grad is refined. + N = len(self.lod[0]) - 1 + self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64') + self.outputs['BatchCellPreAct'] = np.zeros( + (N, self.D)).astype('float64') + self.check_grad( + ['Input', 'Weight', 'Bias'], ['Hidden'], max_relative_error=5e-4) + + +class TestLstmOpHasNoInitial(TestLstmOp): + def set_argument(self): + self.lod = [[0, 2, 5, 7]] + self.D = 16 + + self.act_gate = 'sigmoid' + self.act_cell = 'tanh' + self.act_cand = 'tanh' + + self.has_initial_state = False + self.is_reverse = True class TestLstmOpRerverse(TestLstmOp): - def set_data(self): - self.lod = [[0, 2, 6, 9]] - self.D = 64 - self.sort_idx = [2, 6, 0, 3, 7, 1, 4, 8, 5] + def set_argument(self): + self.lod = [[0, 2, 5, 7]] + self.D = 16 - self.act_gate = "sigmoid" - self.act_cell = "tanh" - self.act_cand = "tanh" + self.act_gate = 'sigmoid' + self.act_cell = 'tanh' + self.act_cand = 'tanh' + self.has_initial_state = True self.is_reverse = True -if __name__ == "__main__": +if __name__ == '__main__': unittest.main()