From 8bb824bb93629fbf69d7e93ffc0dca85e726300c Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Wed, 12 Sep 2018 00:06:58 +0800 Subject: [PATCH] refine infershape hasinput and hasoutput --- paddle/fluid/framework/operator.cc | 274 ++++++++++-------- paddle/fluid/framework/shape_runtime_infer.h | 86 ------ paddle/fluid/operators/attention_lstm_op.cc | 35 ++- paddle/fluid/operators/fusion_gru_op.cc | 22 +- .../operators/fusion_infershape_define.h | 60 ---- paddle/fluid/operators/fusion_lstm_op.cc | 31 +- 6 files changed, 197 insertions(+), 311 deletions(-) delete mode 100644 paddle/fluid/framework/shape_runtime_infer.h delete mode 100644 paddle/fluid/operators/fusion_infershape_define.h diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 36025db7b..bbd141cb3 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -21,7 +21,6 @@ limitations under the License. */ #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/shape_inference.h" -#include "paddle/fluid/framework/shape_runtime_infer.h" #include "paddle/fluid/framework/var_type.h" #include "paddle/fluid/platform/profiler.h" @@ -459,147 +458,184 @@ bool OpSupportGPU(const std::string& op_type) { return false; } -bool RuntimeInferShapeContext::HasInput(const std::string& name) const { - if (!op_.HasInputs(name)) { - return false; - } - auto& ins = Inputs(name); - size_t length = ins.size(); - if (length == 0) { - return false; - } - PADDLE_ENFORCE_EQ(length, 1UL, - "Input %s should not have more than one inputs", name); - auto ipt = ins[0]; - auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt); - return var != nullptr; -} +class RuntimeInferShapeContext : public InferShapeContext { + public: + RuntimeInferShapeContext(const OperatorBase& op, const Scope& scope) + : op_(op), scope_(scope) {} -bool RuntimeInferShapeContext::HasOutput(const std::string& name) const { - if (!op_.HasOutputs(name)) { - return false; - } - auto& outs = Outputs(name); - size_t length = outs.size(); - if (length == 0) { - return false; - } - PADDLE_ENFORCE_EQ(length, 1UL, - "Output %s should not have more than one inputs", name); - auto ipt = outs[0]; - auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt); - return var != nullptr; -} + bool HasInput(const std::string& name) const override { + // has only one input + const auto& ins = op_.Inputs(); + auto it = ins.find(name); + if (it == ins.end()) { + return false; + } + const auto& in = it->second; -bool RuntimeInferShapeContext::HasInputs(const std::string& name) const { - if (!op_.HasInputs(name)) { - return false; - } - auto inputs = op_.Inputs(name); - if (inputs.empty()) { - return false; - } - for (auto& input : inputs) { - if (scope_.FindVar(input) == nullptr) { + if (in.size() != 1 || in[0] == kEmptyVarName) { return false; } + return scope_.FindVar(in[0]) != nullptr; } - return true; -} -bool RuntimeInferShapeContext::HasOutputs(const std::string& name) const { - if (!op_.HasOutputs(name)) { - return false; + bool HasOutput(const std::string& name) const override { + // has only one output + const auto& outs = op_.Outputs(); + auto it = outs.find(name); + if (it == outs.end()) { + return false; + } + const auto& out = it->second; + if (out.size() != 1 || out[0] == kEmptyVarName) { + return false; + } + return scope_.FindVar(out[0]) != nullptr; } - auto outputs = op_.Outputs(name); - if (outputs.empty()) { - return false; + + bool HasInputs(const std::string& name) const override { + if (!op_.HasInputs(name)) { + return false; + } + auto inputs = op_.Inputs(name); + if (inputs.empty()) { + return false; + } + for (auto& input : inputs) { + if (scope_.FindVar(input) == nullptr) { + return false; + } + } + return true; } - for (auto& output : outputs) { - if (scope_.FindVar(output) == nullptr) { + + bool HasOutputs(const std::string& name) const override { + if (!op_.HasOutputs(name)) { + return false; + } + auto outputs = op_.Outputs(name); + if (outputs.empty()) { return false; } + for (auto& output : outputs) { + if (scope_.FindVar(output) == nullptr) { + return false; + } + } + return true; } - return true; -} -void RuntimeInferShapeContext::ShareLoD(const std::string& in, - const std::string& out, size_t i, - size_t j) const { - PADDLE_ENFORCE_LT(i, Inputs(in).size()); - PADDLE_ENFORCE_LT(j, Outputs(out).size()); - Variable* in_var = scope_.FindVar(Inputs(in)[i]); - Variable* out_var = scope_.FindVar(Outputs(out)[j]); - if (!in_var->IsType()) return; - PADDLE_ENFORCE(out_var->IsType(), - "The %d-th output of Output(%s) must be LoDTensor.", j, out); - auto in_tensor = in_var->Get(); - auto* out_tensor = out_var->GetMutable(); - out_tensor->set_lod(in_tensor.lod()); + AttrReader Attrs() const override { return AttrReader(op_.Attrs()); } + + const std::vector& Inputs( + const std::string& name) const override { + return op_.Inputs(name); + } + + const std::vector& Outputs( + const std::string& name) const override { + return op_.Outputs(name); + } + + void ShareLoD(const std::string& in, const std::string& out, size_t i = 0, + size_t j = 0) const override { + PADDLE_ENFORCE_LT(i, Inputs(in).size()); + PADDLE_ENFORCE_LT(j, Outputs(out).size()); + Variable* in_var = scope_.FindVar(Inputs(in)[i]); + Variable* out_var = scope_.FindVar(Outputs(out)[j]); + if (!in_var->IsType()) return; + PADDLE_ENFORCE(out_var->IsType(), + "The %d-th output of Output(%s) must be LoDTensor.", j, out); + auto in_tensor = in_var->Get(); + auto* out_tensor = out_var->GetMutable(); + out_tensor->set_lod(in_tensor.lod()); // TODO(dzhwinter) : reuse ShareLoD in most operators. // Need to call ShareLayout explicitly in sequence related ops. // Shall we have a better method to shared info between in/out Tensor? #ifdef PADDLE_WITH_MKLDNN - // Fix me: ugly workaround below - // Correct solution: - // set_layout() should NOT be called here (i.e. ShareLoD). Instead, - // layout of output tensor should be set "manually" in Compute() - // of each OPKernel. The reason layout should NOT be shared between - // input and output "automatically" (now by InferShape()->ShareLoD()) - // is that layout transform may occur after InferShape(). - // Workaround: - // Skip set_layout() when input layout is kMKLDNN - // This is to avoid kMKLDNN is populated wrongly into a non-MKLDNN - // OPKernel. In all MKLDNN OPkernel, set_layout(kMKLDNN) should be called - // in Compute() - if (in_tensor.layout() != DataLayout::kMKLDNN) + // Fix me: ugly workaround below + // Correct solution: + // set_layout() should NOT be called here (i.e. ShareLoD). Instead, + // layout of output tensor should be set "manually" in Compute() + // of each OPKernel. The reason layout should NOT be shared between + // input and output "automatically" (now by InferShape()->ShareLoD()) + // is that layout transform may occur after InferShape(). + // Workaround: + // Skip set_layout() when input layout is kMKLDNN + // This is to avoid kMKLDNN is populated wrongly into a non-MKLDNN + // OPKernel. In all MKLDNN OPkernel, set_layout(kMKLDNN) should be called + // in Compute() + if (in_tensor.layout() != DataLayout::kMKLDNN) #endif + out_tensor->set_layout(in_tensor.layout()); + } + + void ShareLayout(const std::string& in, const std::string& out, size_t i = 0, + size_t j = 0) const { + PADDLE_ENFORCE_LT(i, Inputs(in).size()); + PADDLE_ENFORCE_LT(j, Outputs(out).size()); + Variable* in_var = scope_.FindVar(Inputs(in)[i]); + Variable* out_var = scope_.FindVar(Outputs(out)[j]); + if (!in_var->IsType()) return; + PADDLE_ENFORCE(out_var->IsType(), + "The %d-th output of Output(%s) must be LoDTensor.", j, out); + auto in_tensor = in_var->Get(); + auto* out_tensor = out_var->GetMutable(); out_tensor->set_layout(in_tensor.layout()); -} + } -void RuntimeInferShapeContext::ShareLayout(const std::string& in, - const std::string& out, size_t i, - size_t j) const { - PADDLE_ENFORCE_LT(i, Inputs(in).size()); - PADDLE_ENFORCE_LT(j, Outputs(out).size()); - Variable* in_var = scope_.FindVar(Inputs(in)[i]); - Variable* out_var = scope_.FindVar(Outputs(out)[j]); - if (!in_var->IsType()) return; - PADDLE_ENFORCE(out_var->IsType(), - "The %d-th output of Output(%s) must be LoDTensor.", j, out); - auto in_tensor = in_var->Get(); - auto* out_tensor = out_var->GetMutable(); - out_tensor->set_layout(in_tensor.layout()); -} - -DDim RuntimeInferShapeContext::GetDim(const std::string& name) const { - Variable* var = scope_.FindVar(name); - PADDLE_ENFORCE_NOT_NULL(var); - if (var->IsType()) { - return var->Get().dims(); - } else if (var->IsType()) { - return var->Get().GetCompleteDims(); - } else { - PADDLE_THROW( - "Only LoDTensor/SelectedRows support 'GetDim', but Variable %s's " - "type_id is %s.", - name, var->Type().name()); + bool IsRuntime() const override { return true; } + + protected: + DDim GetDim(const std::string& name) const override { + Variable* var = scope_.FindVar(name); + PADDLE_ENFORCE_NOT_NULL(var); + if (var->IsType()) { + return var->Get().dims(); + } else if (var->IsType()) { + return var->Get().GetCompleteDims(); + } else { + PADDLE_THROW( + "Only LoDTensor/SelectedRows support 'GetDim', but Variable %s's " + "type_id is %s.", + name, var->Type().name()); + } } -} -void RuntimeInferShapeContext::SetDim(const std::string& name, - const DDim& dim) { - Variable* var = scope_.FindVar(name); - if (var->IsType()) { - var->GetMutable()->Resize(dim); - } else if (var->IsType()) { - var->GetMutable()->set_height(dim[0]); - } else { - PADDLE_THROW("Variable %s type_id %s, expect LoDTensor/SelectedRows.", name, - var->Type().name()); + std::vector GetRepeatedDims(const std::string& name) const override { + PADDLE_THROW("Only compile time support this method"); } -} + + void SetDim(const std::string& name, const DDim& dim) override { + Variable* var = scope_.FindVar(name); + if (var->IsType()) { + var->GetMutable()->Resize(dim); + } else if (var->IsType()) { + var->GetMutable()->set_height(dim[0]); + } else { + PADDLE_THROW("Variable %s type_id %s, expect LoDTensor/SelectedRows.", + name, var->Type().name()); + } + } + + void SetRepeatedDims(const std::string& name, + const std::vector& dims) override { + PADDLE_THROW("Only compile time support this method"); + } + + proto::VarType::Type GetVarType(const std::string& name) const override { + auto* var = scope_.FindVar(name); + return ToVarType(var->Type()); + } + + InferShapeVarPtr GetVarPtr(const std::string& name) override { + return scope_.FindVar(name); + } + + private: + const OperatorBase& op_; + const Scope& scope_; +}; static void CheckTensorNANOrInf(const std::string& name, const framework::Tensor& tensor) { diff --git a/paddle/fluid/framework/shape_runtime_infer.h b/paddle/fluid/framework/shape_runtime_infer.h deleted file mode 100644 index 04d4e33f7..000000000 --- a/paddle/fluid/framework/shape_runtime_infer.h +++ /dev/null @@ -1,86 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include -#include -#include "paddle/fluid/framework/operator.h" -#include "paddle/fluid/framework/shape_inference.h" -#include "paddle/fluid/framework/var_type.h" - -namespace paddle { -namespace framework { - -class RuntimeInferShapeContext : public InferShapeContext { - public: - RuntimeInferShapeContext(const OperatorBase& op, const Scope& scope) - : op_(op), scope_(scope) {} - - bool HasInput(const std::string& name) const override; - bool HasOutput(const std::string& name) const override; - bool HasInputs(const std::string& name) const override; - bool HasOutputs(const std::string& name) const override; - - const OperatorBase& OpBase() const { return op_; } - - const Scope& InferScope() const { return scope_; } - AttrReader Attrs() const override { return AttrReader(op_.Attrs()); } - - const std::vector& Inputs( - const std::string& name) const override { - return op_.Inputs(name); - } - - const std::vector& Outputs( - const std::string& name) const override { - return op_.Outputs(name); - } - - void ShareLoD(const std::string& in, const std::string& out, size_t i = 0, - size_t j = 0) const override; - - void ShareLayout(const std::string& in, const std::string& out, size_t i = 0, - size_t j = 0) const; - - bool IsRuntime() const override { return true; } - - protected: - DDim GetDim(const std::string& name) const override; - void SetDim(const std::string& name, const DDim& dim) override; - - std::vector GetRepeatedDims(const std::string& name) const override { - PADDLE_THROW("Only compile time support this method"); - } - void SetRepeatedDims(const std::string& name, - const std::vector& dims) override { - PADDLE_THROW("Only compile time support this method"); - } - - proto::VarType::Type GetVarType(const std::string& name) const override { - auto* var = scope_.FindVar(name); - return ToVarType(var->Type()); - } - - InferShapeVarPtr GetVarPtr(const std::string& name) override { - return scope_.FindVar(name); - } - - private: - const OperatorBase& op_; - const Scope& scope_; -}; - -} // namespace framework -} // namespace paddle diff --git a/paddle/fluid/operators/attention_lstm_op.cc b/paddle/fluid/operators/attention_lstm_op.cc index 7531aa9a4..9b943440a 100644 --- a/paddle/fluid/operators/attention_lstm_op.cc +++ b/paddle/fluid/operators/attention_lstm_op.cc @@ -14,7 +14,6 @@ limitations under the License. */ #include "paddle/fluid/operators/attention_lstm_op.h" #include -#include "paddle/fluid/operators/fusion_infershape_define.h" #include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/cpu_vec.h" #include "paddle/fluid/operators/math/fc_compute.h" @@ -24,28 +23,28 @@ namespace paddle { namespace operators { void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { - FUSION_INFERSHAPE_INIT; - PADDLE_ENFORCE(fair_input("X"), "Assert only one Input(X) of AttentionLSTM."); - PADDLE_ENFORCE(fair_input("C0"), + PADDLE_ENFORCE(ctx->HasInput("X"), + "Assert only one Input(X) of AttentionLSTM."); + PADDLE_ENFORCE(ctx->HasInput("C0"), "Assert only one Input(C0) of AttentionLSTM."); - PADDLE_ENFORCE(fair_input("LSTMWeight"), + PADDLE_ENFORCE(ctx->HasInput("LSTMWeight"), "Assert only one Input(LSTMWeight) of AttentionLSTM."); - PADDLE_ENFORCE(fair_input("LSTMBias"), + PADDLE_ENFORCE(ctx->HasInput("LSTMBias"), "Assert only one Input(LSTMBias) of AttentionLSTM."); - PADDLE_ENFORCE(fair_input("AttentionWeight"), + PADDLE_ENFORCE(ctx->HasInput("AttentionWeight"), "Assert only one Input(AttentionWeight) of AttentionLSTM."); - PADDLE_ENFORCE(fair_output("Hidden"), + PADDLE_ENFORCE(ctx->HasOutput("Hidden"), "Assert only one Output(Hidden) of AttentionLSTM."); - PADDLE_ENFORCE(fair_output("Cell"), + PADDLE_ENFORCE(ctx->HasOutput("Cell"), "Assert only one Output(Cell) of AttentionLSTM."); - PADDLE_ENFORCE(fair_output("AttentionedX"), + PADDLE_ENFORCE(ctx->HasOutput("AttentionedX"), "Assert only one Output(AttentionedX) of AttentionLSTM."); - PADDLE_ENFORCE(fair_output("AttentionFCOut"), + PADDLE_ENFORCE(ctx->HasOutput("AttentionFCOut"), "Assert only one Output(AttentionFCOut) of AttentionLSTM."); - PADDLE_ENFORCE(fair_output("LSTMX"), + PADDLE_ENFORCE(ctx->HasOutput("LSTMX"), "Assert only one Output(LSTMX) of AttentionLSTM."); - PADDLE_ENFORCE(fair_output("LSTMOUT"), + PADDLE_ENFORCE(ctx->HasOutput("LSTMOUT"), "Assert only one Output(LSTMOUT) of AttentionLSTM."); auto x_dims = ctx->GetInputDim("X"); @@ -66,7 +65,7 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { auto c_dims = ctx->GetInputDim("C0"); PADDLE_ENFORCE_EQ(c_dims.size(), 2, "Input(C0)'s rank must be 2."); PADDLE_ENFORCE_EQ(c_dims[1], D, "C0 dims should be N x %d.", D); - if (fair_input("H0")) { + if (ctx->HasInput("H0")) { auto h_dims = ctx->GetInputDim("H0"); PADDLE_ENFORCE(h_dims == c_dims, "The dimension of Input(H0) and Input(C0) " @@ -80,7 +79,7 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { "AttentionWeight shapes must be (%d + %d) * 1.", M, D); PADDLE_ENFORCE_EQ(atten_w_dims[1], 1, "AttentionWeight shapes must be (%d + %d) * 1.", M, D); - if (fair_input("AttentionBias")) { + if (ctx->HasInput("AttentionBias")) { auto atten_b_dims = ctx->GetInputDim("AttentionBias"); PADDLE_ENFORCE_EQ(atten_b_dims.size(), 2, "Input(AttentionBias)'s rank must be 2."); @@ -90,7 +89,7 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { "AttentionBias shapes must be 1 * 1."); } - if (fair_input("AttentionScalar")) { + if (ctx->HasInput("AttentionScalar")) { auto dims = ctx->GetInputDim("AttentionScalar"); PADDLE_ENFORCE_EQ(dims.size(), 2, "Input(AttentionScalar)'s rank must be 2."); @@ -98,10 +97,10 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { PADDLE_ENFORCE_EQ(dims[1], 1, "AttentionScalar shapes must be 1 * 1."); } - if (fair_input("AttentionScalarBias")) { + if (ctx->HasInput("AttentionScalarBias")) { auto dims = ctx->GetInputDim("AttentionScalarBias"); PADDLE_ENFORCE( - fair_input("AttentionScalar"), + ctx->HasInput("AttentionScalar"), "AttentionScalar should not be null when have AttentionScalarBias."); PADDLE_ENFORCE_EQ(dims.size(), 2, "Input(AttentionScalarBias)'s rank must be 2."); diff --git a/paddle/fluid/operators/fusion_gru_op.cc b/paddle/fluid/operators/fusion_gru_op.cc index b10d311f0..31e87d911 100644 --- a/paddle/fluid/operators/fusion_gru_op.cc +++ b/paddle/fluid/operators/fusion_gru_op.cc @@ -15,7 +15,6 @@ limitations under the License. */ #include "paddle/fluid/operators/fusion_gru_op.h" #include // for memcpy #include -#include "paddle/fluid/operators/fusion_infershape_define.h" #include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/cpu_vec.h" #include "paddle/fluid/operators/math/fc_compute.h" @@ -26,14 +25,13 @@ namespace paddle { namespace operators { void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const { - FUSION_INFERSHAPE_INIT; - PADDLE_ENFORCE(fair_input("X"), "Assert only one Input(X) of GRU."); - PADDLE_ENFORCE(fair_input("WeightX"), + PADDLE_ENFORCE(ctx->HasInput("X"), "Assert only one Input(X) of GRU."); + PADDLE_ENFORCE(ctx->HasInput("WeightX"), "Assert only one Input(WeightX) of GRU."); - PADDLE_ENFORCE(fair_input("WeightH"), + PADDLE_ENFORCE(ctx->HasInput("WeightH"), "Assert only one Input(WeightH) of GRU."); - PADDLE_ENFORCE(fair_output("XX"), "Assert only one Output(XX) of GRU."); - PADDLE_ENFORCE(fair_output("Hidden"), + PADDLE_ENFORCE(ctx->HasOutput("XX"), "Assert only one Output(XX) of GRU."); + PADDLE_ENFORCE(ctx->HasOutput("Hidden"), "Assert only one Output(Hidden) of GRU."); auto x_dims = ctx->GetInputDim("X"); @@ -60,12 +58,12 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const { "should be 3 * %d.", frame_size); - if (fair_input("H0")) { + if (ctx->HasInput("H0")) { auto h0_dims = ctx->GetInputDim("H0"); PADDLE_ENFORCE_EQ(h0_dims[1], frame_size, "The width of H0 must be equal to frame_size."); } - if (fair_input("Bias")) { + if (ctx->HasInput("Bias")) { auto b_dims = ctx->GetInputDim("Bias"); PADDLE_ENFORCE_EQ(b_dims.size(), 2, "The rank of Input(Bias) should be 2."); PADDLE_ENFORCE_EQ(b_dims[0], 1, @@ -81,11 +79,11 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const { xx_width = wx_dims[1]; } else { xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1]; - PADDLE_ENFORCE(fair_output("ReorderedH0"), + PADDLE_ENFORCE(ctx->HasOutput("ReorderedH0"), "Assert only one Output(ReorderedH0) of GRU."); - PADDLE_ENFORCE(fair_output("BatchedInput"), + PADDLE_ENFORCE(ctx->HasOutput("BatchedInput"), "Assert only one Output(BatchedInput) of GRU."); - PADDLE_ENFORCE(fair_output("BatchedOut"), + PADDLE_ENFORCE(ctx->HasOutput("BatchedOut"), "Assert only one Output(BatchedOut) of GRU."); ctx->SetOutputDim("BatchedInput", {x_dims[0], wx_dims[1]}); ctx->SetOutputDim("BatchedOut", out_dims); diff --git a/paddle/fluid/operators/fusion_infershape_define.h b/paddle/fluid/operators/fusion_infershape_define.h deleted file mode 100644 index 89521672b..000000000 --- a/paddle/fluid/operators/fusion_infershape_define.h +++ /dev/null @@ -1,60 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#ifndef PADDLE_FLUID_OPERATORS_FUSION_INFERSHAPE_DEFINE_H_ -#define PADDLE_FLUID_OPERATORS_FUSION_INFERSHAPE_DEFINE_H_ - -#include -#include "paddle/fluid/framework/shape_runtime_infer.h" - -namespace paddle { -namespace operators { - -#define FUSION_INFERSHAPE_INIT \ - auto* runtime_ctx = dynamic_cast(ctx); \ - if (runtime_ctx == nullptr) { \ - LOG(FATAL) << "Should have runtime infer context"; \ - } \ - const auto& ins = runtime_ctx->OpBase().Inputs(); \ - const auto& outs = runtime_ctx->OpBase().Outputs(); \ - const auto& scope = runtime_ctx->InferScope(); \ - const auto ins_end = ins.end(); \ - const auto outs_end = outs.end(); \ - auto fair_input = [&](const std::string& name) -> bool { \ - auto it = ins.find(name); \ - if (it == ins_end) { \ - return false; \ - } \ - const auto& in = it->second; \ - if (in.size() != 1 || in[0] == framework::kEmptyVarName) { \ - return false; \ - } \ - return scope.FindVar(in[0]) != nullptr; \ - }; \ - auto fair_output = [&](const std::string& name) -> bool { \ - auto it = outs.find(name); \ - if (it == outs_end) { \ - return false; \ - } \ - const auto& out = it->second; \ - if (out.size() != 1 || out[0] == framework::kEmptyVarName) { \ - return false; \ - } \ - return scope.FindVar(out[0]) != nullptr; \ - } - -} // namespace operators -} // namespace paddle - -#endif // PADDLE_FLUID_OPERATORS_FUSION_INFERSHAPE_DEFINE_H_ diff --git a/paddle/fluid/operators/fusion_lstm_op.cc b/paddle/fluid/operators/fusion_lstm_op.cc index 08af98f85..55e465e3a 100644 --- a/paddle/fluid/operators/fusion_lstm_op.cc +++ b/paddle/fluid/operators/fusion_lstm_op.cc @@ -14,7 +14,6 @@ limitations under the License. */ #include "paddle/fluid/operators/fusion_lstm_op.h" #include -#include "paddle/fluid/operators/fusion_infershape_define.h" #include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/cpu_vec.h" #include "paddle/fluid/operators/math/fc_compute.h" @@ -25,23 +24,23 @@ namespace paddle { namespace operators { void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { - FUSION_INFERSHAPE_INIT; - PADDLE_ENFORCE(fair_input("X"), "Assert only one Input(X) of LSTM."); - PADDLE_ENFORCE(fair_input("WeightX"), + PADDLE_ENFORCE(ctx->HasInput("X"), "Assert only one Input(X) of LSTM."); + PADDLE_ENFORCE(ctx->HasInput("WeightX"), "Assert only one Input(WeightX) of LSTM."); - PADDLE_ENFORCE(fair_input("WeightH"), + PADDLE_ENFORCE(ctx->HasInput("WeightH"), "Assert only one Input(WeightH) of LSTM."); - PADDLE_ENFORCE(fair_input("Bias"), "Assert only one Input(Bias) of LSTM."); - PADDLE_ENFORCE(fair_output("XX"), "Assert only one Output(XX) of LSTM."); - PADDLE_ENFORCE(fair_output("Hidden"), + PADDLE_ENFORCE(ctx->HasInput("Bias"), "Assert only one Input(Bias) of LSTM."); + PADDLE_ENFORCE(ctx->HasOutput("XX"), "Assert only one Output(XX) of LSTM."); + PADDLE_ENFORCE(ctx->HasOutput("Hidden"), "Assert only one Output(Hidden) of LSTM."); - PADDLE_ENFORCE(fair_output("Cell"), "Assert only one Output(Cell) of LSTM."); + PADDLE_ENFORCE(ctx->HasOutput("Cell"), + "Assert only one Output(Cell) of LSTM."); auto x_dims = ctx->GetInputDim("X"); PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2."); - if (fair_input("H0")) { - PADDLE_ENFORCE(fair_input("C0"), + if (ctx->HasInput("H0")) { + PADDLE_ENFORCE(ctx->HasInput("C0"), "Input(Cell) and Input(Hidden) of LSTM should not " "be null at the same time."); auto h_dims = ctx->GetInputDim("H0"); @@ -93,15 +92,15 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { xx_width = wx_dims[1]; } else { xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1]; - PADDLE_ENFORCE(fair_output("BatchedInput"), + PADDLE_ENFORCE(ctx->HasOutput("BatchedInput"), "Assert only one Output(BatchedInput) of LSTM."); - PADDLE_ENFORCE(fair_output("BatchedHidden"), + PADDLE_ENFORCE(ctx->HasOutput("BatchedHidden"), "Assert only one Output(BatchedHidden) of LSTM."); - PADDLE_ENFORCE(fair_output("BatchedCell"), + PADDLE_ENFORCE(ctx->HasOutput("BatchedCell"), "Assert only one Output(BatchedCell) of LSTM."); - PADDLE_ENFORCE(fair_output("ReorderedH0"), + PADDLE_ENFORCE(ctx->HasOutput("ReorderedH0"), "Assert only one Output(ReorderedH0) of LSTM"); - PADDLE_ENFORCE(fair_output("ReorderedC0"), + PADDLE_ENFORCE(ctx->HasOutput("ReorderedC0"), "Assert only one Output(ReorderedC0) of LSTM."); ctx->SetOutputDim("BatchedInput", {x_dims[0], wx_dims[1]}); ctx->SetOutputDim("BatchedHidden", out_dims); -- GitLab