From 8a1abe54d797de7c4f17ab92d2268c3cebf83b66 Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Tue, 11 Sep 2018 18:30:49 +0800 Subject: [PATCH] clean fusion infershape code --- paddle/fluid/operators/attention_lstm_op.cc | 35 +---------- paddle/fluid/operators/fusion_gru_op.cc | 35 +---------- .../operators/fusion_infershape_define.h | 60 +++++++++++++++++++ paddle/fluid/operators/fusion_lstm_op.cc | 35 +---------- 4 files changed, 66 insertions(+), 99 deletions(-) create mode 100644 paddle/fluid/operators/fusion_infershape_define.h diff --git a/paddle/fluid/operators/attention_lstm_op.cc b/paddle/fluid/operators/attention_lstm_op.cc index ac4ddb55025..7531aa9a464 100644 --- a/paddle/fluid/operators/attention_lstm_op.cc +++ b/paddle/fluid/operators/attention_lstm_op.cc @@ -14,7 +14,7 @@ limitations under the License. */ #include "paddle/fluid/operators/attention_lstm_op.h" #include -#include "paddle/fluid/framework/shape_runtime_infer.h" +#include "paddle/fluid/operators/fusion_infershape_define.h" #include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/cpu_vec.h" #include "paddle/fluid/operators/math/fc_compute.h" @@ -24,38 +24,7 @@ namespace paddle { namespace operators { void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { - auto* runtime_ctx = dynamic_cast(ctx); - if (runtime_ctx == nullptr) { - LOG(FATAL) << "Should have runtime infer context"; - } - const auto& ins = runtime_ctx->OpBase().Inputs(); - const auto& outs = runtime_ctx->OpBase().Outputs(); - const auto& scope = runtime_ctx->InferScope(); - const auto ins_end = ins.end(); - const auto outs_end = outs.end(); - auto fair_input = [&](const std::string& name) -> bool { - auto it = ins.find(name); - if (it == ins_end) { - return false; - } - const auto& in = it->second; - if (in.size() != 1 || in[0] == framework::kEmptyVarName) { - return false; - } - return scope.FindVar(in[0]) != nullptr; - }; - auto fair_output = [&](const std::string& name) -> bool { - auto it = outs.find(name); - if (it == outs_end) { - return false; - } - const auto& out = it->second; - if (out.size() != 1 || out[0] == framework::kEmptyVarName) { - return false; - } - return scope.FindVar(out[0]) != nullptr; - }; - + FUSION_INFERSHAPE_INIT; PADDLE_ENFORCE(fair_input("X"), "Assert only one Input(X) of AttentionLSTM."); PADDLE_ENFORCE(fair_input("C0"), "Assert only one Input(C0) of AttentionLSTM."); diff --git a/paddle/fluid/operators/fusion_gru_op.cc b/paddle/fluid/operators/fusion_gru_op.cc index bcdcb2ac4da..b10d311f050 100644 --- a/paddle/fluid/operators/fusion_gru_op.cc +++ b/paddle/fluid/operators/fusion_gru_op.cc @@ -15,7 +15,7 @@ limitations under the License. */ #include "paddle/fluid/operators/fusion_gru_op.h" #include // for memcpy #include -#include "paddle/fluid/framework/shape_runtime_infer.h" +#include "paddle/fluid/operators/fusion_infershape_define.h" #include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/cpu_vec.h" #include "paddle/fluid/operators/math/fc_compute.h" @@ -26,38 +26,7 @@ namespace paddle { namespace operators { void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const { - auto* runtime_ctx = dynamic_cast(ctx); - if (runtime_ctx == nullptr) { - LOG(FATAL) << "Should have runtime infer context"; - } - const auto& ins = runtime_ctx->OpBase().Inputs(); - const auto& outs = runtime_ctx->OpBase().Outputs(); - const auto& scope = runtime_ctx->InferScope(); - const auto ins_end = ins.end(); - const auto outs_end = outs.end(); - auto fair_input = [&](const std::string& name) -> bool { - auto it = ins.find(name); - if (it == ins_end) { - return false; - } - const auto& in = it->second; - if (in.size() != 1 || in[0] == framework::kEmptyVarName) { - return false; - } - return scope.FindVar(in[0]) != nullptr; - }; - auto fair_output = [&](const std::string& name) -> bool { - auto it = outs.find(name); - if (it == outs_end) { - return false; - } - const auto& out = it->second; - if (out.size() != 1 || out[0] == framework::kEmptyVarName) { - return false; - } - return scope.FindVar(out[0]) != nullptr; - }; - + FUSION_INFERSHAPE_INIT; PADDLE_ENFORCE(fair_input("X"), "Assert only one Input(X) of GRU."); PADDLE_ENFORCE(fair_input("WeightX"), "Assert only one Input(WeightX) of GRU."); diff --git a/paddle/fluid/operators/fusion_infershape_define.h b/paddle/fluid/operators/fusion_infershape_define.h new file mode 100644 index 00000000000..89521672b0a --- /dev/null +++ b/paddle/fluid/operators/fusion_infershape_define.h @@ -0,0 +1,60 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifndef PADDLE_FLUID_OPERATORS_FUSION_INFERSHAPE_DEFINE_H_ +#define PADDLE_FLUID_OPERATORS_FUSION_INFERSHAPE_DEFINE_H_ + +#include +#include "paddle/fluid/framework/shape_runtime_infer.h" + +namespace paddle { +namespace operators { + +#define FUSION_INFERSHAPE_INIT \ + auto* runtime_ctx = dynamic_cast(ctx); \ + if (runtime_ctx == nullptr) { \ + LOG(FATAL) << "Should have runtime infer context"; \ + } \ + const auto& ins = runtime_ctx->OpBase().Inputs(); \ + const auto& outs = runtime_ctx->OpBase().Outputs(); \ + const auto& scope = runtime_ctx->InferScope(); \ + const auto ins_end = ins.end(); \ + const auto outs_end = outs.end(); \ + auto fair_input = [&](const std::string& name) -> bool { \ + auto it = ins.find(name); \ + if (it == ins_end) { \ + return false; \ + } \ + const auto& in = it->second; \ + if (in.size() != 1 || in[0] == framework::kEmptyVarName) { \ + return false; \ + } \ + return scope.FindVar(in[0]) != nullptr; \ + }; \ + auto fair_output = [&](const std::string& name) -> bool { \ + auto it = outs.find(name); \ + if (it == outs_end) { \ + return false; \ + } \ + const auto& out = it->second; \ + if (out.size() != 1 || out[0] == framework::kEmptyVarName) { \ + return false; \ + } \ + return scope.FindVar(out[0]) != nullptr; \ + } + +} // namespace operators +} // namespace paddle + +#endif // PADDLE_FLUID_OPERATORS_FUSION_INFERSHAPE_DEFINE_H_ diff --git a/paddle/fluid/operators/fusion_lstm_op.cc b/paddle/fluid/operators/fusion_lstm_op.cc index ae9d5d78ae6..08af98f8506 100644 --- a/paddle/fluid/operators/fusion_lstm_op.cc +++ b/paddle/fluid/operators/fusion_lstm_op.cc @@ -14,7 +14,7 @@ limitations under the License. */ #include "paddle/fluid/operators/fusion_lstm_op.h" #include -#include "paddle/fluid/framework/shape_runtime_infer.h" +#include "paddle/fluid/operators/fusion_infershape_define.h" #include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/cpu_vec.h" #include "paddle/fluid/operators/math/fc_compute.h" @@ -25,38 +25,7 @@ namespace paddle { namespace operators { void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { - auto* runtime_ctx = dynamic_cast(ctx); - if (runtime_ctx == nullptr) { - LOG(FATAL) << "Should have runtime infer context"; - } - const auto& ins = runtime_ctx->OpBase().Inputs(); - const auto& outs = runtime_ctx->OpBase().Outputs(); - const auto& scope = runtime_ctx->InferScope(); - const auto ins_end = ins.end(); - const auto outs_end = outs.end(); - auto fair_input = [&](const std::string& name) -> bool { - auto it = ins.find(name); - if (it == ins_end) { - return false; - } - const auto& in = it->second; - if (in.size() != 1 || in[0] == framework::kEmptyVarName) { - return false; - } - return scope.FindVar(in[0]) != nullptr; - }; - auto fair_output = [&](const std::string& name) -> bool { - auto it = outs.find(name); - if (it == outs_end) { - return false; - } - const auto& out = it->second; - if (out.size() != 1 || out[0] == framework::kEmptyVarName) { - return false; - } - return scope.FindVar(out[0]) != nullptr; - }; - + FUSION_INFERSHAPE_INIT; PADDLE_ENFORCE(fair_input("X"), "Assert only one Input(X) of LSTM."); PADDLE_ENFORCE(fair_input("WeightX"), "Assert only one Input(WeightX) of LSTM."); -- GitLab