diff --git a/paddle/fluid/operators/fusion_seq_concat_fc_op.cc b/paddle/fluid/operators/fusion_seq_concat_fc_op.cc
index 203ebaf3e28e7972a166d45b296e2e043213b41b..f61c822abf642c7975b3b3b53eb91c6bfec03fa2 100644
--- a/paddle/fluid/operators/fusion_seq_concat_fc_op.cc
+++ b/paddle/fluid/operators/fusion_seq_concat_fc_op.cc
@@ -23,91 +23,36 @@ namespace paddle {
 namespace operators {
 
 void FusionSeqConcatFCOp::InferShape(framework::InferShapeContext* ctx) const {
-  PADDLE_ENFORCE(ctx->HasInput("X"),
-                 "Input(X) of FusionSeqConcatFC should not be null.");
+  PADDLE_ENFORCE_GT(ctx->Inputs("X").size(), 1UL,
+                    "Inputs(X) of FusionSeqConcatFCOp should larger than 1.");
   PADDLE_ENFORCE(ctx->HasInput("FCWeight"),
                  "Input(FCWeight) of FusionSeqConcatFC should not be null.");
-
   PADDLE_ENFORCE(ctx->HasOutput("Out"),
                  "Output(Out) of FusionSeqConcatFC should not be null.");
   PADDLE_ENFORCE(ctx->HasOutput("FCOut"),
                  "Output(FCOut) of FusionSeqConcatFC should not be null.");
-  // need check fc height = all inputs width sum
-
-  auto x_dims = ctx->GetInputDim("X");
-  const int M = x_dims[1];
-  PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2.");
-
-  auto w_dims = ctx->GetInputDim("LSTMWeight");
-  const int D = w_dims[1] / 4;
-  PADDLE_ENFORCE_EQ(w_dims.size(), 2, "Input(LSTMWeight)'s rank must be 2.");
-  PADDLE_ENFORCE_EQ(w_dims[0], D + M,
-                    "LSTMWeight dims should be (%d + %d) * %d.", D + M, 4 * D);
-
-  auto b_dims = ctx->GetInputDim("LSTMBias");
-  PADDLE_ENFORCE_EQ(b_dims.size(), 2, "Input(LSTMBias)'s rank must be 2.");
-  PADDLE_ENFORCE_EQ(b_dims[0], 1, "LSTMBias dims should be 1 x %d.", 4 * D);
-  PADDLE_ENFORCE_EQ(b_dims[1], 4 * D, "LSTMBias dims should be 1 x %d.", 4 * D);
-
-  auto c_dims = ctx->GetInputDim("C0");
-  PADDLE_ENFORCE_EQ(c_dims.size(), 2, "Input(C0)'s rank must be 2.");
-  PADDLE_ENFORCE_EQ(c_dims[1], D, "C0 dims should be N x %d.", D);
-  if (ctx->HasInput("H0")) {
-    auto h_dims = ctx->GetInputDim("H0");
-    PADDLE_ENFORCE(h_dims == c_dims,
-                   "The dimension of Input(H0) and Input(C0) "
-                   "should be the same.");
-  }
-
-  auto atten_w_dims = ctx->GetInputDim("AttentionWeight");
-  PADDLE_ENFORCE_EQ(atten_w_dims.size(), 2,
-                    "Input(AttentionWeight)'s rank must be 2.");
-  PADDLE_ENFORCE_EQ(atten_w_dims[0], M + D,
-                    "AttentionWeight shapes must be (%d + %d) * 1.", M, D);
-  PADDLE_ENFORCE_EQ(atten_w_dims[1], 1,
-                    "AttentionWeight shapes must be (%d + %d) * 1.", M, D);
-  if (ctx->HasInput("AttentionBias")) {
-    auto atten_b_dims = ctx->GetInputDim("AttentionBias");
-    PADDLE_ENFORCE_EQ(atten_b_dims.size(), 2,
-                      "Input(AttentionBias)'s rank must be 2.");
-    PADDLE_ENFORCE_EQ(atten_b_dims[0], 1,
-                      "AttentionBias shapes must be 1 * 1.");
-    PADDLE_ENFORCE_EQ(atten_b_dims[1], 1,
-                      "AttentionBias shapes must be 1 * 1.");
+  auto ins_dims = ctx->GetInputsDim("X");
+  auto w_dims = ctx->GetInputDim("FCWeight");  // (M0+M1+M2+..) x D
+  PADDLE_ENFORCE_EQ(w_dims.size(), 2UL, "Input(FCWeight)'s rank must be 2.");
+  const int D = w_dims[1];
+  int sum = ins_dims[0][1];
+  for (size_t i = 1; i < ins_dims.size(); ++i) {
+    sum += ins_dims[i][1];
   }
-
-  if (ctx->HasInput("AttentionScalar")) {
-    auto dims = ctx->GetInputDim("AttentionScalar");
-    PADDLE_ENFORCE_EQ(dims.size(), 2,
-                      "Input(AttentionScalar)'s rank must be 2.");
-    PADDLE_ENFORCE_EQ(dims[0], 1, "AttentionScalar shapes must be 1 * 1.");
-    PADDLE_ENFORCE_EQ(dims[1], 1, "AttentionScalar shapes must be 1 * 1.");
+  PADDLE_ENFORCE_EQ(sum, w_dims[0],
+                    "FC height should be sum of all inputs width.");
+  if (ctx->HasInput("FCBias")) {
+    auto b_dims = ctx->GetInputDim("FCBias");
+    PADDLE_ENFORCE_EQ(b_dims.size(), 2, "Input(FCBias)'s rank must be 2.");
+    PADDLE_ENFORCE_EQ(b_dims[0], 1, "FCBias shapes must be 1 * %d.", D);
+    PADDLE_ENFORCE_EQ(b_dims[1], D, "FCBias shapes must be 1 * %d.", D);
   }
-  if (ctx->HasInput("AttentionScalarBias")) {
-    auto dims = ctx->GetInputDim("AttentionScalarBias");
-    PADDLE_ENFORCE(
-        ctx->HasInput("AttentionScalar"),
-        "AttentionScalar should not be null when have AttentionScalarBias.");
-    PADDLE_ENFORCE_EQ(dims.size(), 2,
-                      "Input(AttentionScalarBias)'s rank must be 2.");
-    PADDLE_ENFORCE_EQ(dims[0], 1, "AttentionScalarBias shapes must be 1 * 1.");
-    PADDLE_ENFORCE_EQ(dims[1], 1, "AttentionScalarBias shapes must be 1 * 1.");
-  }
-
-  framework::DDim out_dims({x_dims[0], D});
-  ctx->SetOutputDim("Hidden", out_dims);
-  ctx->SetOutputDim("Cell", out_dims);
-  ctx->SetOutputDim("AttentionedX", {x_dims[0], 1});
-  ctx->SetOutputDim("LSTMX", {1, M});
-  ctx->SetOutputDim("LSTMOUT", {1, 4 * D});
-  // AttentionFCOut should be reshape as (maxseqlen,1) in runtime
-  ctx->ShareLoD("X", "Hidden");
-  ctx->ShareLoD("X", "Cell");
-
-  ctx->SetOutputDim("Out", out_dims);
-  ctx->ShareLoD("X", /*->*/ "Out");
+  ctx->SetOutputDim("Out", {ins_dims[0][0], D});
+  // fcout should be reshape when run since can not get lod in infershape
+  // explicit share the ref lod
+  ctx->ShareLoD("X", "Out", 0);
 }
 
 framework::OpKernelType FusionSeqConcatFCOp::GetExpectedKernelType(
@@ -154,46 +99,46 @@ The concat axis should be 1.
 )DOC");
 }
 
-// y[i] = (x[i] + bias[0]) > 0 ? (x[i] + bias[0]) : 0;
-template <typename T>
-inline void bias_relu(const int n, const T* x, const T* bias, T* y) {
-  if (bias) {
-    math::vec_add_bias(n, *bias, x, y);
-    math::vec_relu(n, y, y);
-  } else {
-    math::vec_relu(n, x, y);
-  }
-}
-
-template <typename T>
-inline void vec_softmax(const int n, const T* x, T* y) {
-  T scalar = x[0];
-  // max
-  for (int i = 1; i < n; ++i) {
-    scalar = scalar < x[i] ? x[i] : scalar;
-  }
-  math::vec_add_bias(n, -scalar, x, y);  // sub
-  math::vec_exp(n, y, y);                // exp
-  // sum
-  scalar = T(0);
-  for (int i = 0; i < n; ++i) {
-    scalar += y[i];
-  }
-  math::vec_scal(n, static_cast<T>(1) / scalar, y);  // scale
-}
-
 template <typename T>
 class FusionSeqConcatFCKernel : public framework::OpKernel<T> {
  public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    using DeviceContext = paddle::platform::CPUDeviceContext;
-    auto* ins = ctx.Input<LoDTensor>("X");
+    auto ins = ctx.MultiInput<LoDTensor>("X");
    auto* w = ctx.Input<Tensor>("FCWeight");
    auto* b = ctx.Input<Tensor>("FCBias");
-
    auto* out = ctx.Output<LoDTensor>("Out");
    auto* fc_out = ctx.Output<Tensor>("FCOUT");
+    auto* ref_in = ins[0];
+    auto ref_lod = ref_in->lod();
+    auto in1_lod = ins[1]->lod();
+    auto ref_dims = ref_in->dims();  // T x M0
+    auto in1_dims = ins[1]->dims();  // N x M1
+    auto w_dims = w->dims();
+    const int N = ref_lod[0].size() - 1;
+    const int total_T = ref_dims[0];
+    const int M0 = ref_dims[1];
+    const int M1 = in1_dims[1];
+    const int D = w_dims[1];
+
+    // some check and fcout should be reshape here
+    // since infershape can not get lod info
+    PADDLE_ENFORCE_EQ(ref_lod.size(), 1UL, "Only support input lod size is 1.");
+    PADDLE_ENFORCE_EQ(in1_lod.size(), 1UL, "Only support input lod size is 1.");
+    PADDLE_ENFORCE_EQ(in1_lod[0].size() - 1, N,
+                      "Batch size of all inputs should be equal.");
+    PADDLE_ENFORCE_EQ(in1_lod[0][N], N,
+                      "Seq_length of other inputs should be 1.");
+    PADDLE_ENFORCE_EQ(in1_dims[0], N, "input height should be batch size.");
+    for (size_t i = 2; i < ins.size(); ++i) {
+      PADDLE_ENFORCE_EQ(ins[i]->dims()[0], N,
+                        "All other inputs height should be equal");
+      PADDLE_ENFORCE_EQ(ins[i]->lod(), in1_lod,
+                        "All other inputs should have same lod");
+    }
+    fc_out->Resize({N, D});
+
    std::function<void(const int, const T*, T*)> fc_act;
    auto& fc_act_str = ctx.Attr<std::string>("fc_activation");
    if (platform::jit::MayIUse(platform::jit::avx)) {
@@ -204,19 +149,7 @@ class FusionSeqConcatFCKernel : public framework::OpKernel<T> {
      fc_act = act_functor(fc_act_str);
    }
 
-    PADDLE_ENFORCE_GT(ins.size(), 1, "Input(X)'s size must larger than 1.");
-    auto* ref_in = ins[0];
-    auto ref_in_lod = ref_in->lod();
-    const int N = ref_in_lod[0].size() - 1;
-    auto ref_in_dims = ref_in->dims();  // T x M0
-    auto w_dims = w->dims();            // (M0+M1+M2+..) x D
-    const int total_T = ref_in_dims[0];
-    const int M0 = ref_in_dims[1];
-    const int M1 = ins[1]->dims()[1];
-    const int D = w_dims[1];
-
-    const T* ref_in_data =
-        ref_in->data<T>();  // size should be check at infershape
+    const T* ref_in_data = ref_in->data<T>();
    const T* in1_data = ins[1]->data<T>();
    const T* w_data = w->data<T>();
    T* out_data = out->mutable_data<T>(ctx.GetPlace());
@@ -226,11 +159,10 @@ class FusionSeqConcatFCKernel : public framework::OpKernel<T> {
    math::FCCompute<DeviceContext, T>(blas, total_T, D, M0, ref_in_data, w_data,
                                      out_data, b ? b->data<T>() : NULL);
    w_data = w_data + M0 * D;
-    // first one use write on
    blas.MatMul(N, D, M1, in1_data, w_data, fc_out_data);
    w_data = w_data + M1 * D;
-    for (int i = 2; i < ins.size(); ++i) {
+    for (size_t i = 2; i < ins.size(); ++i) {
      // add on
      const T* in_data = ins[i]->data<T>();
      const int K = ins[i]->dims()[1];
@@ -240,7 +172,7 @@ class FusionSeqConcatFCKernel : public framework::OpKernel<T> {
    }
 
    for (int i = 0; i < N; ++i) {
-      int seq_len = ref_in_lod[0][i + 1] - ref_in_lod[0][i];
+      int seq_len = ref_lod[0][i + 1] - ref_lod[0][i];
      T* src = fc_out_data + i * D;
      for (int step = 0; step < seq_len; ++step) {
        blas.VADD(D, out_data, src, out_data);
@@ -248,7 +180,7 @@ class FusionSeqConcatFCKernel : public framework::OpKernel<T> {
      }
    }
 
-    fc_act(out_dims[0] * out_dims[1], out_data, out_data);
+    fc_act(total_T * D, out_data, out_data);
  }
 };
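
For review convenience, below is a minimal standalone sketch of the math the fused kernel above computes: for every time step t belonging to sequence i, Out[t] = act(concat(X0[t], X1[i], X2[i], ...) * W + b), with W split row-wise per input exactly as the kernel advances w_data by M0 * D, M1 * D, and so on. This is an illustration only, not Paddle code: the function name and its flattened std::vector interface are hypothetical, the activation is assumed to be relu, and plain loops replace the BLAS/JIT paths.

// Reference (unoptimized) version of the fused seq-concat + FC computation.
// Assumptions: single-level LoD on xs[0], row-major flattened buffers, relu.
#include <cstddef>
#include <vector>

std::vector<float> SeqConcatFCRef(
    const std::vector<std::vector<float>>& xs,  // xs[0]: T x M0, xs[j>0]: N x Mj
    const std::vector<int>& widths,             // {M0, M1, ...}
    const std::vector<size_t>& lod,             // lod of xs[0], size N + 1
    const std::vector<float>& w,                // (M0 + M1 + ...) x D
    const std::vector<float>& b,                // 1 x D, may be empty
    int D) {
  const size_t N = lod.size() - 1;   // batch size (number of sequences)
  const size_t T = lod.back();       // total number of time steps
  std::vector<float> out(T * D, 0.f);
  for (size_t i = 0; i < N; ++i) {                  // each sequence
    for (size_t t = lod[i]; t < lod[i + 1]; ++t) {  // each step of sequence i
      int w_row = 0;
      for (size_t j = 0; j < xs.size(); ++j) {      // concat along axis 1
        // xs[0] is indexed per time step, the other inputs per sequence.
        const float* x_row = &xs[j][(j == 0 ? t : i) * widths[j]];
        for (int k = 0; k < widths[j]; ++k, ++w_row) {
          for (int d = 0; d < D; ++d) {
            out[t * D + d] += x_row[k] * w[w_row * D + d];
          }
        }
      }
      for (int d = 0; d < D; ++d) {
        if (!b.empty()) out[t * D + d] += b[d];
        out[t * D + d] = out[t * D + d] > 0.f ? out[t * D + d] : 0.f;  // relu
      }
    }
  }
  return out;
}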