From a5556d44175931682bb049451639948c0da7ed6e Mon Sep 17 00:00:00 2001
From: tensor-tang
Date: Tue, 11 Sep 2018 17:49:54 +0800
Subject: [PATCH] refine attentionlstm infershape

---
 paddle/fluid/operators/attention_lstm_op.cc | 88 ++++++++++++++-------
 1 file changed, 60 insertions(+), 28 deletions(-)

diff --git a/paddle/fluid/operators/attention_lstm_op.cc b/paddle/fluid/operators/attention_lstm_op.cc
index 39b0c8569..ac4ddb550 100644
--- a/paddle/fluid/operators/attention_lstm_op.cc
+++ b/paddle/fluid/operators/attention_lstm_op.cc
@@ -14,6 +14,7 @@ limitations under the License. */
 
 #include "paddle/fluid/operators/attention_lstm_op.h"
 #include <string>
+#include "paddle/fluid/framework/shape_runtime_infer.h"
 #include "paddle/fluid/operators/math/blas.h"
 #include "paddle/fluid/operators/math/cpu_vec.h"
 #include "paddle/fluid/operators/math/fc_compute.h"
@@ -23,29 +24,60 @@ namespace paddle {
 namespace operators {
 
 void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
-  PADDLE_ENFORCE(ctx->HasInput("X"),
-                 "Input(X) of AttentionLSTM should not be null.");
-  PADDLE_ENFORCE(ctx->HasInput("C0"),
-                 "Input(C0) of AttentionLSTM should not be null.");
-  PADDLE_ENFORCE(ctx->HasInput("LSTMWeight"),
-                 "Input(LSTMWeight) of AttentionLSTM should not be null.");
-  PADDLE_ENFORCE(ctx->HasInput("LSTMBias"),
-                 "Input(LSTMBias) of AttentionLSTM should not be null.");
-  PADDLE_ENFORCE(ctx->HasInput("AttentionWeight"),
-                 "Input(AttentionWeight) of AttentionLSTM should not be null.");
-
-  PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
-                 "Output(Hidden) of AttentionLSTM should not be null.");
-  PADDLE_ENFORCE(ctx->HasOutput("Cell"),
-                 "Output(Cell) of AttentionLSTM should not be null.");
-  PADDLE_ENFORCE(ctx->HasOutput("AttentionedX"),
-                 "Output(AttentionedX) of AttentionLSTM should not be null.");
-  PADDLE_ENFORCE(ctx->HasOutput("AttentionFCOut"),
-                 "Output(AttentionFCOut) of AttentionLSTM should not be null.");
-  PADDLE_ENFORCE(ctx->HasOutput("LSTMX"),
-                 "Output(LSTMX) of AttentionLSTM should not be null.");
-  PADDLE_ENFORCE(ctx->HasOutput("LSTMOUT"),
-                 "Output(LSTMOUT) of AttentionLSTM should not be null.");
+  auto* runtime_ctx = dynamic_cast<framework::RuntimeInferShapeContext*>(ctx);
+  if (runtime_ctx == nullptr) {
+    LOG(FATAL) << "Should have runtime infer context";
+  }
+  const auto& ins = runtime_ctx->OpBase().Inputs();
+  const auto& outs = runtime_ctx->OpBase().Outputs();
+  const auto& scope = runtime_ctx->InferScope();
+  const auto ins_end = ins.end();
+  const auto outs_end = outs.end();
+  auto fair_input = [&](const std::string& name) -> bool {
+    auto it = ins.find(name);
+    if (it == ins_end) {
+      return false;
+    }
+    const auto& in = it->second;
+    if (in.size() != 1 || in[0] == framework::kEmptyVarName) {
+      return false;
+    }
+    return scope.FindVar(in[0]) != nullptr;
+  };
+  auto fair_output = [&](const std::string& name) -> bool {
+    auto it = outs.find(name);
+    if (it == outs_end) {
+      return false;
+    }
+    const auto& out = it->second;
+    if (out.size() != 1 || out[0] == framework::kEmptyVarName) {
+      return false;
+    }
+    return scope.FindVar(out[0]) != nullptr;
+  };
+
+  PADDLE_ENFORCE(fair_input("X"), "Assert only one Input(X) of AttentionLSTM.");
+  PADDLE_ENFORCE(fair_input("C0"),
+                 "Assert only one Input(C0) of AttentionLSTM.");
+  PADDLE_ENFORCE(fair_input("LSTMWeight"),
+                 "Assert only one Input(LSTMWeight) of AttentionLSTM.");
+  PADDLE_ENFORCE(fair_input("LSTMBias"),
+                 "Assert only one Input(LSTMBias) of AttentionLSTM.");
+  PADDLE_ENFORCE(fair_input("AttentionWeight"),
+                 "Assert only one Input(AttentionWeight) of AttentionLSTM.");
+
PADDLE_ENFORCE(fair_output("Hidden"), + "Assert only one Output(Hidden) of AttentionLSTM."); + PADDLE_ENFORCE(fair_output("Cell"), + "Assert only one Output(Cell) of AttentionLSTM."); + PADDLE_ENFORCE(fair_output("AttentionedX"), + "Assert only one Output(AttentionedX) of AttentionLSTM."); + PADDLE_ENFORCE(fair_output("AttentionFCOut"), + "Assert only one Output(AttentionFCOut) of AttentionLSTM."); + PADDLE_ENFORCE(fair_output("LSTMX"), + "Assert only one Output(LSTMX) of AttentionLSTM."); + PADDLE_ENFORCE(fair_output("LSTMOUT"), + "Assert only one Output(LSTMOUT) of AttentionLSTM."); auto x_dims = ctx->GetInputDim("X"); const int M = x_dims[1]; @@ -65,7 +97,7 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { auto c_dims = ctx->GetInputDim("C0"); PADDLE_ENFORCE_EQ(c_dims.size(), 2, "Input(C0)'s rank must be 2."); PADDLE_ENFORCE_EQ(c_dims[1], D, "C0 dims should be N x %d.", D); - if (ctx->HasInput("H0")) { + if (fair_input("H0")) { auto h_dims = ctx->GetInputDim("H0"); PADDLE_ENFORCE(h_dims == c_dims, "The dimension of Input(H0) and Input(C0) " @@ -79,7 +111,7 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { "AttentionWeight shapes must be (%d + %d) * 1.", M, D); PADDLE_ENFORCE_EQ(atten_w_dims[1], 1, "AttentionWeight shapes must be (%d + %d) * 1.", M, D); - if (ctx->HasInput("AttentionBias")) { + if (fair_input("AttentionBias")) { auto atten_b_dims = ctx->GetInputDim("AttentionBias"); PADDLE_ENFORCE_EQ(atten_b_dims.size(), 2, "Input(AttentionBias)'s rank must be 2."); @@ -89,7 +121,7 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { "AttentionBias shapes must be 1 * 1."); } - if (ctx->HasInput("AttentionScalar")) { + if (fair_input("AttentionScalar")) { auto dims = ctx->GetInputDim("AttentionScalar"); PADDLE_ENFORCE_EQ(dims.size(), 2, "Input(AttentionScalar)'s rank must be 2."); @@ -97,10 +129,10 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { PADDLE_ENFORCE_EQ(dims[1], 1, "AttentionScalar shapes must be 1 * 1."); } - if (ctx->HasInput("AttentionScalarBias")) { + if (fair_input("AttentionScalarBias")) { auto dims = ctx->GetInputDim("AttentionScalarBias"); PADDLE_ENFORCE( - ctx->HasInput("AttentionScalar"), + fair_input("AttentionScalar"), "AttentionScalar should not be null when have AttentionScalarBias."); PADDLE_ENFORCE_EQ(dims.size(), 2, "Input(AttentionScalarBias)'s rank must be 2."); -- GitLab