From c7c25067338dc147c5b6b282ce34205f4bfee373 Mon Sep 17 00:00:00 2001
From: tensor-tang
Date: Mon, 27 Aug 2018 13:12:33 +0800
Subject: [PATCH] add forward implementation

---
 .../operators/fusion_seq_concat_fc_op.cc      | 318 +++++-------------
 1 file changed, 83 insertions(+), 235 deletions(-)

diff --git a/paddle/fluid/operators/fusion_seq_concat_fc_op.cc b/paddle/fluid/operators/fusion_seq_concat_fc_op.cc
index 810df3c3e..203ebaf3e 100644
--- a/paddle/fluid/operators/fusion_seq_concat_fc_op.cc
+++ b/paddle/fluid/operators/fusion_seq_concat_fc_op.cc
@@ -25,30 +25,15 @@ namespace operators {
 void FusionSeqConcatFCOp::InferShape(framework::InferShapeContext* ctx) const {
   PADDLE_ENFORCE(ctx->HasInput("X"),
                  "Input(X) of FusionSeqConcatFC should not be null.");
-  PADDLE_ENFORCE(ctx->HasInput("C0"),
-                 "Input(C0) of FusionSeqConcatFC should not be null.");
-  PADDLE_ENFORCE(ctx->HasInput("LSTMWeight"),
-                 "Input(LSTMWeight) of FusionSeqConcatFC should not be null.");
-  PADDLE_ENFORCE(ctx->HasInput("LSTMBias"),
-                 "Input(LSTMBias) of FusionSeqConcatFC should not be null.");
-  PADDLE_ENFORCE(
-      ctx->HasInput("AttentionWeight"),
-      "Input(AttentionWeight) of FusionSeqConcatFC should not be null.");
-
-  PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
-                 "Output(Hidden) of FusionSeqConcatFC should not be null.");
-  PADDLE_ENFORCE(ctx->HasOutput("Cell"),
-                 "Output(Cell) of FusionSeqConcatFC should not be null.");
-  PADDLE_ENFORCE(
-      ctx->HasOutput("AttentionedX"),
-      "Output(AttentionedX) of FusionSeqConcatFC should not be null.");
-  PADDLE_ENFORCE(
-      ctx->HasOutput("AttentionFCOut"),
-      "Output(AttentionFCOut) of FusionSeqConcatFC should not be null.");
-  PADDLE_ENFORCE(ctx->HasOutput("LSTMX"),
-                 "Output(LSTMX) of FusionSeqConcatFC should not be null.");
-  PADDLE_ENFORCE(ctx->HasOutput("LSTMOUT"),
-                 "Output(LSTMOUT) of FusionSeqConcatFC should not be null.");
+  PADDLE_ENFORCE(ctx->HasInput("FCWeight"),
+                 "Input(FCWeight) of FusionSeqConcatFC should not be null.");
+
+  PADDLE_ENFORCE(ctx->HasOutput("Out"),
+                 "Output(Out) of FusionSeqConcatFC should not be null.");
+  PADDLE_ENFORCE(ctx->HasOutput("FCOut"),
+                 "Output(FCOut) of FusionSeqConcatFC should not be null.");
+
+  // need check fc height = all inputs width sum
 
   auto x_dims = ctx->GetInputDim("X");
   const int M = x_dims[1];
@@ -120,6 +105,9 @@ void FusionSeqConcatFCOp::InferShape(framework::InferShapeContext* ctx) const {
   // AttentionFCOut should be reshape as (maxseqlen,1) in runtime
   ctx->ShareLoD("X", "Hidden");
   ctx->ShareLoD("X", "Cell");
+
+  ctx->SetOutputDim("Out", out_dims);
+  ctx->ShareLoD("X", /*->*/ "Out");
 }
 
 framework::OpKernelType FusionSeqConcatFCOp::GetExpectedKernelType(
@@ -131,95 +119,37 @@ framework::OpKernelType FusionSeqConcatFCOp::GetExpectedKernelType(
 
 void FusionSeqConcatFCOpMaker::Make() {
   AddInput("X",
-           "(LoDTensor) the input is a LodTensor, which support "
-           "variable-time length input sequence. The underlying tensor in "
-           "this LoDTensor is a matrix with shape (T X M), where T is the "
-           "total time steps in this mini-batch, M is the dim size of x.");
-  AddInput("C0",
-           "(Tensor) LSTM C0"
-           "This is a tensor with shape (N x D), where N is the batch size, D "
-           "is the gate size."
-           "C0 is necessary because of attention.");
-  AddInput("H0",
-           "(Tensor, optional) LSTM H0"
-           "This is a tensor with shape (N x D), where N is the "
-           "batch size and D is the gate size.")
-      .AsDispensable();
-  AddInput("AttentionWeight",
-           "(Tensor) the weights of attention fc. Always relu the fc result."
- "The shape is ((M+D) x 1), where M is the dim size of x, D is the " - "gate size of LSTM."); - AddInput("AttentionBias", - "(Tensor, optional) the bias of attention fc." - "The shape is (1 x 1)") - .AsDispensable(); - AddInput("AttentionScalar", - "(Tensor, optional) the scalar on the result of attentioned fc. " - "Always relu the Scalar." - "The shape is (1 x 1)") - .AsDispensable(); - AddInput("AttentionScalarBias", - "(Tensor, optional) the scalar bias of attention fc." - "The shape is (1 x 1)") - .AsDispensable(); - AddInput("LSTMWeight", - "(Tensor) the combined weight of LSTM" - " - The shape is ((D+M) x 4D), where D is the hidden gate size, M " - "is the dim size of x" - " - Weight = {W_forget, W_input, W_output, W_cell}"); - AddInput("LSTMBias", - "(Tensor) the combined bias of LSTM, shape (1x4D)." - "Note: we should add the bias of hidden and context accorindg to " - "the same gate: " - "{B_forget, B_input, B_output, B_cell}"); - AddOutput("Hidden", - "(LoDTensor) (same as LSTMOp) the hidden state of LSTM operator. " - "The shape is (T x D), and lod is the same with the `Input`."); - AddOutput("Cell", - "(LoDTensor) (same as LSTMOp) the cell state of LSTM operator. " - "The shape is (T x D), and lod is the same with the `Input`."); - AddOutput("AttentionedX", - "(Tensor) shape is (T x 1), the result after X * AttentionWeight," - " where T is the total time steps in this mini-batch," - " D is the hidden size.") - .AsIntermediate(); - AddOutput("AttentionFCOut", - "(Tensor) (max_seq_len, 1), compute at each step.") - .AsIntermediate(); - AddOutput("LSTMX", - "(Tensor) the input X of LSTM for each step." - "Shape is (1 x M), where M is the x frame size") - .AsIntermediate(); + "(LoDTensor) input LodDTensors, the first one must be have ref lod " + "for sequence expand, and the rest input should have same lod.") + .AsDuplicable(); + AddInput("FCWeight", "(Tensor) the weights of fc."); + AddInput("FCBias", "(Tensor, optional) the bias of fc.").AsDispensable(); + AddOutput("Out", "(LoDTensor) Output LodTensor."); AddOutput( - "LSTMOUT", - "(Tensor) the output of LSTM X(1*(D+M))* weight((D+M)*4D) for each step." - "Shape is (1 x 4D), where M is the x frame size") + "FCOut", + "(Tensor) the intermediate tensor to keep the result of fc." + "Shape is (N x D), where N is the batch size, D is the output dim of fc") .AsIntermediate(); - AddAttr("gate_activation", - "(string, default: sigmoid)" - "The activation for input gate, forget gate and output " - "gate, `sigmoid` by default.") - .SetDefault("sigmoid") - .InEnum({"sigmoid", "tanh", "relu", "identity"}); - AddAttr("cell_activation", - "(string, default: tanh)" - "The activation for cell output, `tanh` by defalut.") - .SetDefault("tanh") - .InEnum({"sigmoid", "tanh", "relu", "identity"}); - AddAttr("candidate_activation", - "(string, default: tanh)" - "The activation for candidate hidden state, " - "`tanh` by default.") - .SetDefault("tanh") + AddAttr("fc_activation", + "(string, default: identity)" + "The activation for the result of fc." + "`identity` by default.") + .SetDefault("identity") .InEnum({"sigmoid", "tanh", "relu", "identity"}); AddComment(R"DOC( Fusion Sequence expand + concat + fc Operator. -Only support seq_expand ref_level=0, +All below conditions should be meet: -and the ref lod of seq_expand level is the first input of concat, +The ref_level of seq_expand should be 0. -and the other inputs should have same lod and same batch size of ref lod. +The ref lod of seq_expand level is the first input of concat. 
+
+The other inputs should have the same lod and the same batch size as the ref lod.
+
+The sequence length of the other inputs should be 1.
+
+The concat axis should be 1.
 
 )DOC");
 }
 
@@ -257,150 +187,68 @@ class FusionSeqConcatFCKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
     using DeviceContext = paddle::platform::CPUDeviceContext;
+    auto ins = ctx.MultiInput<LoDTensor>("X");
+    auto* w = ctx.Input<Tensor>("FCWeight");
+    auto* b = ctx.Input<Tensor>("FCBias");
-    auto* x = ctx.Input<LoDTensor>("X");
-    auto* h0 = ctx.Input<Tensor>("H0");
-    auto* c0 = ctx.Input<Tensor>("C0");
-    auto* atten_w = ctx.Input<Tensor>("AttentionWeight");
-    auto* atten_b = ctx.Input<Tensor>("AttentionBias");
-    auto* atten_scalar = ctx.Input<Tensor>("AttentionScalar");
-    auto* atten_scalar_bias = ctx.Input<Tensor>("AttentionScalarBias");
-    auto* lstm_w = ctx.Input<Tensor>("LSTMWeight");
-    auto* lstm_b = ctx.Input<Tensor>("LSTMBias");
-
-    auto* hidden_out = ctx.Output<LoDTensor>("Hidden");
-    auto* cell_out = ctx.Output<LoDTensor>("Cell");
-    auto* atted_x = ctx.Output<Tensor>("AttentionedX");
-    auto* fc_out = ctx.Output<Tensor>("AttentionFCOut");
-    auto* lstm_x = ctx.Output<Tensor>("LSTMX");
-    auto* lstm_out = ctx.Output<Tensor>("LSTMOUT");
-
-    // some shape should be reshape here since infershape can not get lod info
-    auto x_lod = x->lod();
-    const int N = x_lod[0].size() - 1;  // batch size
-    auto x_dims = x->dims();            // T x M
-    auto w_dims = lstm_w->dims();       // (D+M) x 4D
-    const int total_T = x_dims[0];
-    const int M = x_dims[1];      // x frame size
-    const int D = w_dims[1] / 4;  // gate frame size
-    const int D2 = D * 2;
-    const int D3 = D * 3;
-    const int D4 = w_dims[1];
-    int max_seq_len = x_lod[0][1];
-    for (int i = 1; i < N; ++i) {
-      int len = x_lod[0][i + 1] - x_lod[0][i];
-      max_seq_len = max_seq_len < len ? len : max_seq_len;
-    }
-    PADDLE_ENFORCE_EQ(x_lod.size(), 1, "Input(X)'s lod size must be 1.");
-    PADDLE_ENFORCE_EQ(c0->dims()[0], N, "C0 dims should be %d x %d.", N, D);
-    fc_out->Resize({max_seq_len, 1});
-
-    std::function<void(const int, const T*, T*)> act_gate, act_cell, act_cand;
-    auto& act_gate_str = ctx.Attr<std::string>("gate_activation");
-    auto& act_cell_str = ctx.Attr<std::string>("cell_activation");
-    auto& act_cand_str = ctx.Attr<std::string>("candidate_activation");
+    auto* out = ctx.Output<LoDTensor>("Out");
+    auto* fc_out = ctx.Output<Tensor>("FCOut");
+
+    std::function<void(const int, const T*, T*)> fc_act;
+    auto& fc_act_str = ctx.Attr<std::string>("fc_activation");
     if (platform::jit::MayIUse(platform::jit::avx)) {
       math::VecActivations<T, platform::jit::avx> act_functor;
-      act_gate = act_functor(act_gate_str);
-      act_cell = act_functor(act_cell_str);
-      act_cand = act_functor(act_cand_str);
+      fc_act = act_functor(fc_act_str);
     } else {
       math::VecActivations<T, platform::jit::isa_any> act_functor;
-      act_gate = act_functor(act_gate_str);
-      act_cell = act_functor(act_cell_str);
-      act_cand = act_functor(act_cand_str);
+      fc_act = act_functor(fc_act_str);
     }
 
-    const T* x_data = x->data<T>();
-    const T* h0_data = h0 ? h0->data<T>() : NULL;
-    const T* c0_data = c0->data<T>();
-    const T* lstm_w_data = lstm_w->data<T>();
-    const T* lstm_b_data = lstm_b->data<T>();
-    const T* atten_w_data = atten_w->data<T>();
-    const T* atten_b_data = atten_b ? atten_b->data<T>() : NULL;
-    const T* atten_scalar_data = atten_scalar ? atten_scalar->data<T>() : NULL;
-    const T* atten_scalar_bias_data =
-        atten_scalar_bias ? atten_scalar_bias->data<T>() : NULL;
-
-    T* hidden_out_data = hidden_out->mutable_data<T>(ctx.GetPlace());
-    T* cell_out_data = cell_out->mutable_data<T>(ctx.GetPlace());
-    T* atted_x_data = atted_x->mutable_data<T>(ctx.GetPlace());
+    PADDLE_ENFORCE_GT(ins.size(), 1, "Input(X)'s size must be larger than 1.");
+    auto* ref_in = ins[0];
+    auto ref_in_lod = ref_in->lod();
+    const int N = ref_in_lod[0].size() - 1;
+    auto ref_in_dims = ref_in->dims();  // T x M0
+    auto w_dims = w->dims();            // (M0+M1+M2+..) x D
+    const int total_T = ref_in_dims[0];
+    const int M0 = ref_in_dims[1];
+    const int M1 = ins[1]->dims()[1];
+    const int D = w_dims[1];
+
+    const T* ref_in_data =
+        ref_in->data<T>();  // size should be checked at infershape
+    const T* in1_data = ins[1]->data<T>();
+    const T* w_data = w->data<T>();
+    T* out_data = out->mutable_data<T>(ctx.GetPlace());
     T* fc_out_data = fc_out->mutable_data<T>(ctx.GetPlace());
-    T* lstm_x_data = lstm_x->mutable_data<T>(ctx.GetPlace());
-    T* lstm_out_data = lstm_out->mutable_data<T>(ctx.GetPlace());
 
-    // x(TxM) * fc (Mx1) part of atten_wgt(M+D)x1
     auto blas = math::GetBlas<DeviceContext, T>(ctx);
-    math::FCCompute<DeviceContext, T>(blas, total_T, 1, M, x_data, atten_w_data,
-                                      atted_x_data, atten_b_data);
-
-    const T* cur_atten_x_data = atted_x_data;
-    const T* cur_x_data = x_data;
-    const T* prev_cell_data = NULL;
-    const T* prev_hidden_data = NULL;
-    T* cur_cell_out_data = cell_out_data;
-    T* cur_hidden_out_data = hidden_out_data;
+    math::FCCompute<DeviceContext, T>(blas, total_T, D, M0, ref_in_data, w_data,
+                                      out_data, b ? b->data<T>() : NULL);
+    w_data = w_data + M0 * D;
+
+    // the first of the remaining inputs writes FCOut directly
+    blas.MatMul(N, D, M1, in1_data, w_data, fc_out_data);
+    w_data = w_data + M1 * D;
+    for (size_t i = 2; i < ins.size(); ++i) {
+      // the rest accumulate into FCOut
+      const T* in_data = ins[i]->data<T>();
+      const int K = ins[i]->dims()[1];
+      blas.GEMM(CblasNoTrans, CblasNoTrans, N, D, K, static_cast<T>(1), in_data,
+                K, w_data, D, static_cast<T>(1), fc_out_data, D);
+      w_data = w_data + K * D;
+    }
+
+    T* cur_out_data = out_data;
     for (int i = 0; i < N; ++i) {
-      int seq_len = x_lod[0][i + 1] - x_lod[0][i];
-      prev_cell_data = c0_data + i * D;
-      prev_hidden_data = h0_data ? h0_data + i * D : NULL;
+      int seq_len = ref_in_lod[0][i + 1] - ref_in_lod[0][i];
+      T* src = fc_out_data + i * D;
       for (int step = 0; step < seq_len; ++step) {
-        /// 1. compute attention vector
-        // 1a. prev_cell(1xD) * fc(D) rest part of atten_wgt
-        T prev_cell_bias = blas.DOT(D, prev_cell_data, atten_w_data + M);
-        // 1b. add cell bias and relu
-        bias_relu(seq_len, cur_atten_x_data, &prev_cell_bias, fc_out_data);
-        // 1c. fc scalar
-        if (atten_scalar_data) {
-          blas.SCAL(seq_len, *atten_scalar_data, fc_out_data);
-          bias_relu(seq_len, fc_out_data, atten_scalar_bias_data,
-                    fc_out_data);
-        }
-        // 1d. softmax
-        vec_softmax(seq_len, fc_out_data, fc_out_data);
-        // mul x(seq_len*M) and sum pool
-        math::FCCompute<DeviceContext, T>(blas, 1, M, seq_len, fc_out_data,
-                                          cur_x_data, lstm_x_data);
-
-        /// 2. compute LSTM step
-        // lstm weight : concat[forget , input , output , tilde]
-        // shape : (D + M) x (4 * D)
-        // fc inputX(1xM) * weightX(M*(4D)) => 1 x 4D
-        blas.MatMul(1, D4, M, lstm_x_data, lstm_w_data + D * D4, lstm_out_data);
-        if (prev_hidden_data) {
-          blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D4, D, static_cast<T>(1),
-                    prev_hidden_data, D, lstm_w_data, D4, static_cast<T>(1),
-                    lstm_out_data, D4);
-        }
-        // since input is 1xM, so can use add bias
-        blas.VADD(D4, lstm_b_data, lstm_out_data, lstm_out_data);
-
-        // gate act: sigmoid
-        act_gate(D3, lstm_out_data, lstm_out_data);
-        // candicate act: tanh
-        act_cand(D, lstm_out_data + D3, lstm_out_data + D3);
-
-        // a = forget * prev_cell
-        blas.VMUL(D, lstm_out_data, prev_cell_data, lstm_out_data);
-
-        // b = input * tilde
-        blas.VMUL(D, lstm_out_data + D, lstm_out_data + D3, lstm_out_data + D);
-
-        // cell_out = a + b
-        blas.VADD(D, lstm_out_data, lstm_out_data + D, cur_cell_out_data);
-
-        // state act tanh(cell_out) * output_gate
-        act_cell(D, cur_cell_out_data, lstm_out_data);
-        blas.VMUL(D, lstm_out_data, lstm_out_data + D2, cur_hidden_out_data);
-
-        prev_hidden_data = cur_hidden_out_data;
-        prev_cell_data = cur_cell_out_data;
-        cur_cell_out_data = cur_cell_out_data + D;
-        cur_hidden_out_data = cur_hidden_out_data + D;
+        blas.VADD(D, cur_out_data, src, cur_out_data);
+        cur_out_data = cur_out_data + D;
       }
-      cur_x_data = cur_x_data + seq_len * M;
-      cur_atten_x_data = cur_atten_x_data + seq_len;
     }
+
+    auto out_dims = out->dims();
+    fc_act(out_dims[0] * out_dims[1], out_data, out_data);
   }
 };
-- 
GitLab
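
Note for readers of this diff: the sketch below is not part of the patch. It is a minimal, naive C++ reference of what the fused forward pass above is intended to compute, and may be useful when checking the kernel. All names in it (SeqConcatFCRef, lod, widths, act) are illustrative assumptions, not Paddle APIs. For each time step t that belongs to sequence i of the reference input, the op concatenates that step's row of the first input with one row per sequence from every other input, multiplies by FCWeight, adds FCBias, and applies fc_activation.

#include <cstddef>
#include <vector>

// Naive reference for the fused "sequence expand + concat + fc" forward pass:
//   out[t, :] = act(concat(x0[t, :], x1[i, :], ..., xk[i, :]) * W + b)
// where time step t belongs to sequence i of the reference input x0.
inline std::vector<float> SeqConcatFCRef(
    const std::vector<float>& x0,               // ref input, total_T x M0, row major
    const std::vector<std::size_t>& lod,        // ref lod: sequence i covers rows [lod[i], lod[i+1])
    const std::vector<std::vector<float>>& xs,  // other inputs, each N x M_j, one row per sequence
    const std::vector<int>& widths,             // {M0, M1, ..., Mk}
    const std::vector<float>& w,                // FCWeight, (M0 + M1 + ... + Mk) x D, row major
    const std::vector<float>& b,                // FCBias, size D (may be empty)
    int D, float (*act)(float)) {
  const int N = static_cast<int>(lod.size()) - 1;        // batch size
  const std::size_t total_T = lod[N];                    // total time steps
  std::vector<float> out(total_T * D, 0.f);

  for (int i = 0; i < N; ++i) {                          // each sequence
    for (std::size_t t = lod[i]; t < lod[i + 1]; ++t) {  // each step of that sequence
      for (int d = 0; d < D; ++d) {
        float sum = b.empty() ? 0.f : b[d];
        int w_row = 0;
        // reference input contribution: varies per step
        // (the kernel covers this with one FCCompute over all total_T rows)
        for (int m = 0; m < widths[0]; ++m, ++w_row) {
          sum += x0[t * widths[0] + m] * w[w_row * D + d];
        }
        // other inputs: one row per sequence, implicitly "expanded" over its steps
        // (the kernel accumulates these into FCOut once, then VADDs per step)
        for (std::size_t j = 0; j < xs.size(); ++j) {
          const int Mj = widths[j + 1];
          for (int m = 0; m < Mj; ++m, ++w_row) {
            sum += xs[j][i * Mj + m] * w[w_row * D + d];
          }
        }
        out[t * D + d] = act(sum);                       // fc_activation
      }
    }
  }
  return out;
}

This also makes the kernel's strategy visible: the reference block of FCWeight is applied with a single GEMM over all total_T rows, the per-sequence blocks are accumulated once into FCOut (N x D), and the final VADD loop "sequence expands" FCOut by adding row i to every step of sequence i; the w_data pointer bumps by M_j * D walk through the row blocks of FCWeight, which is why no concatenated input tensor ever needs to be materialized.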