From 916f42bcbf7bc308f2135be5f341b8628cc883dc Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Tue, 11 Sep 2018 18:00:20 +0800 Subject: [PATCH] refine fusion gru infershape --- paddle/fluid/operators/fusion_gru_op.cc | 65 +++++++++++++++++++------ 1 file changed, 49 insertions(+), 16 deletions(-) diff --git a/paddle/fluid/operators/fusion_gru_op.cc b/paddle/fluid/operators/fusion_gru_op.cc index 916f84cb4a7..bcdcb2ac4da 100644 --- a/paddle/fluid/operators/fusion_gru_op.cc +++ b/paddle/fluid/operators/fusion_gru_op.cc @@ -15,6 +15,7 @@ limitations under the License. */ #include "paddle/fluid/operators/fusion_gru_op.h" #include // for memcpy #include +#include "paddle/fluid/framework/shape_runtime_infer.h" #include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/cpu_vec.h" #include "paddle/fluid/operators/math/fc_compute.h" @@ -25,14 +26,46 @@ namespace paddle { namespace operators { void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const { - PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of GRU should not be null."); - PADDLE_ENFORCE(ctx->HasInput("WeightX"), - "Input(WeightX) of GRU should not be null."); - PADDLE_ENFORCE(ctx->HasInput("WeightH"), - "Input(WeightH) of GRU should not be null."); - PADDLE_ENFORCE(ctx->HasOutput("XX"), "Output(XX) of GRU should not be null."); - PADDLE_ENFORCE(ctx->HasOutput("Hidden"), - "Output(Hidden) of GRU should not be null."); + auto* runtime_ctx = dynamic_cast(ctx); + if (runtime_ctx == nullptr) { + LOG(FATAL) << "Should have runtime infer context"; + } + const auto& ins = runtime_ctx->OpBase().Inputs(); + const auto& outs = runtime_ctx->OpBase().Outputs(); + const auto& scope = runtime_ctx->InferScope(); + const auto ins_end = ins.end(); + const auto outs_end = outs.end(); + auto fair_input = [&](const std::string& name) -> bool { + auto it = ins.find(name); + if (it == ins_end) { + return false; + } + const auto& in = it->second; + if (in.size() != 1 || in[0] == framework::kEmptyVarName) { + return false; + } + return scope.FindVar(in[0]) != nullptr; + }; + auto fair_output = [&](const std::string& name) -> bool { + auto it = outs.find(name); + if (it == outs_end) { + return false; + } + const auto& out = it->second; + if (out.size() != 1 || out[0] == framework::kEmptyVarName) { + return false; + } + return scope.FindVar(out[0]) != nullptr; + }; + + PADDLE_ENFORCE(fair_input("X"), "Assert only one Input(X) of GRU."); + PADDLE_ENFORCE(fair_input("WeightX"), + "Assert only one Input(WeightX) of GRU."); + PADDLE_ENFORCE(fair_input("WeightH"), + "Assert only one Input(WeightH) of GRU."); + PADDLE_ENFORCE(fair_output("XX"), "Assert only one Output(XX) of GRU."); + PADDLE_ENFORCE(fair_output("Hidden"), + "Assert only one Output(Hidden) of GRU."); auto x_dims = ctx->GetInputDim("X"); PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2."); @@ -58,12 +91,12 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const { "should be 3 * %d.", frame_size); - if (ctx->HasInput("H0")) { + if (fair_input("H0")) { auto h0_dims = ctx->GetInputDim("H0"); PADDLE_ENFORCE_EQ(h0_dims[1], frame_size, "The width of H0 must be equal to frame_size."); } - if (ctx->HasInput("Bias")) { + if (fair_input("Bias")) { auto b_dims = ctx->GetInputDim("Bias"); PADDLE_ENFORCE_EQ(b_dims.size(), 2, "The rank of Input(Bias) should be 2."); PADDLE_ENFORCE_EQ(b_dims[0], 1, @@ -79,12 +112,12 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const { xx_width = wx_dims[1]; } else { xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1]; - PADDLE_ENFORCE(ctx->HasOutput("ReorderedH0"), - "Output(ReorderedH0) of GRU should not be null."); - PADDLE_ENFORCE(ctx->HasOutput("BatchedInput"), - "Output(BatchedInput) of GRU should not be null."); - PADDLE_ENFORCE(ctx->HasOutput("BatchedOut"), - "Output(BatchedOut) of GRU should not be null."); + PADDLE_ENFORCE(fair_output("ReorderedH0"), + "Assert only one Output(ReorderedH0) of GRU."); + PADDLE_ENFORCE(fair_output("BatchedInput"), + "Assert only one Output(BatchedInput) of GRU."); + PADDLE_ENFORCE(fair_output("BatchedOut"), + "Assert only one Output(BatchedOut) of GRU."); ctx->SetOutputDim("BatchedInput", {x_dims[0], wx_dims[1]}); ctx->SetOutputDim("BatchedOut", out_dims); } -- GitLab