From d7345959789a22437c7065078cc3a7d457c7e70e Mon Sep 17 00:00:00 2001 From: Yan Chunwei Date: Mon, 18 Jun 2018 17:16:06 +0800 Subject: [PATCH] Feature/pass manager (#11440) --- .../fluid/inference/analysis/CMakeLists.txt | 39 +++--- paddle/fluid/inference/analysis/argument.cc | 15 +++ paddle/fluid/inference/analysis/argument.h | 53 ++++++++ .../inference/analysis/data_flow_graph.cc | 15 +-- .../analysis/data_flow_graph_to_fluid_pass.cc | 77 ++++++++++++ .../analysis/data_flow_graph_to_fluid_pass.h | 59 +++++++++ .../data_flow_graph_to_fluid_pass_tester.cc | 5 +- .../analysis/dfg_graphviz_draw_pass.cc | 54 ++++++++ .../analysis/dfg_graphviz_draw_pass.h | 41 ++++--- .../analysis/dfg_graphviz_draw_pass_tester.cc | 10 +- .../analysis/fluid_to_data_flow_graph_pass.cc | 22 ++-- .../analysis/fluid_to_data_flow_graph_pass.h | 11 +- .../fluid_to_data_flow_graph_pass_tester.cc | 6 +- paddle/fluid/inference/analysis/helper.h | 1 + paddle/fluid/inference/analysis/node.cc | 3 + paddle/fluid/inference/analysis/node.h | 23 ++-- paddle/fluid/inference/analysis/pass.h | 27 ++-- .../fluid/inference/analysis/pass_manager.cc | 44 +++++++ .../fluid/inference/analysis/pass_manager.h | 116 ++++++++++++++++++ .../inference/analysis/pass_manager_tester.cc | 85 +++++++++++++ .../analysis/subgraph_splitter_tester.cc | 45 +++++-- .../analysis/tensorrt_subgraph_pass.cc | 33 +++++ .../analysis/tensorrt_subgraph_pass.h | 47 +++++++ .../analysis/tensorrt_subgraph_pass_tester.cc | 71 +++++++++++ paddle/fluid/inference/analysis/ut_helper.h | 34 +++-- .../operators/tensorrt_engine_op_test.cc | 2 +- 26 files changed, 830 insertions(+), 108 deletions(-) create mode 100644 paddle/fluid/inference/analysis/argument.cc create mode 100644 paddle/fluid/inference/analysis/argument.h create mode 100644 paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc create mode 100644 paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.h create mode 100644 paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.cc create mode 100644 paddle/fluid/inference/analysis/pass_manager.cc create mode 100644 paddle/fluid/inference/analysis/pass_manager.h create mode 100644 paddle/fluid/inference/analysis/pass_manager_tester.cc create mode 100644 paddle/fluid/inference/analysis/tensorrt_subgraph_pass.cc create mode 100644 paddle/fluid/inference/analysis/tensorrt_subgraph_pass.h create mode 100644 paddle/fluid/inference/analysis/tensorrt_subgraph_pass_tester.cc diff --git a/paddle/fluid/inference/analysis/CMakeLists.txt b/paddle/fluid/inference/analysis/CMakeLists.txt index 5083578444..2bb2c8135d 100644 --- a/paddle/fluid/inference/analysis/CMakeLists.txt +++ b/paddle/fluid/inference/analysis/CMakeLists.txt @@ -1,23 +1,32 @@ set(FLUID_CORE_MODULES proto_desc memory lod_tensor executor init) -cc_library(analysis SRCS dot.cc node.cc data_flow_graph.cc graph_traits.cc subgraph_splitter.cc fluid_to_data_flow_graph_pass.cc - DEPS paddle_fluid) +cc_library(analysis SRCS pass_manager.cc dot.cc node.cc data_flow_graph.cc graph_traits.cc subgraph_splitter.cc + fluid_to_data_flow_graph_pass.cc + data_flow_graph_to_fluid_pass.cc + tensorrt_subgraph_pass.cc + dfg_graphviz_draw_pass.cc + DEPS framework_proto) cc_test(test_node SRCS node_tester.cc DEPS analysis) cc_test(test_dot SRCS dot_tester.cc DEPS analysis) set(PYTHON_TESTS_DIR ${PADDLE_BINARY_DIR}/python/paddle/fluid/tests) -cc_test(test_data_flow_graph SRCS data_flow_graph_tester.cc DEPS analysis ${FLUID_CORE_MODULES} paddle_fluid - ARGS --inference_model_dir=${PYTHON_TESTS_DIR}/book/word2vec.inference.model) -set_tests_properties(test_data_flow_graph PROPERTIES DEPENDS test_word2vec) +function (inference_analysis_test TARGET) + set(options "") + set(oneValueArgs "") + set(multiValueArgs SRCS) + cmake_parse_arguments(analysis_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) -cc_test(test_subgraph_splitter - SRCS subgraph_splitter_tester.cc - DEPS analysis paddle_fluid tensor - ARGS --inference_model_dir=${PYTHON_TESTS_DIR}/book/word2vec.inference.model) -set_tests_properties(test_subgraph_splitter PROPERTIES DEPENDS test_word2vec) + cc_test(${TARGET} + SRCS "${analysis_test_SRCS}" + DEPS analysis + ARGS --inference_model_dir=${PYTHON_TESTS_DIR}/book/word2vec.inference.model --fraction_of_gpu_memory_to_use=0.5) + set_tests_properties(${TARGET} PROPERTIES DEPENDS test_word2vec) +endfunction(inference_analysis_test) -cc_test(test_dfg_graphviz_draw_pass - SRCS dfg_graphviz_draw_pass_tester.cc - DEPS analysis - ARGS --inference_model_dir=${PYTHON_TESTS_DIR}/book/word2vec.inference.model) -set_tests_properties(test_dfg_graphviz_draw_pass PROPERTIES DEPENDS test_word2vec) +inference_analysis_test(test_data_flow_graph SRCS data_flow_graph_tester.cc) +inference_analysis_test(test_data_flow_graph_to_fluid_pass SRCS data_flow_graph_to_fluid_pass_tester.cc) +inference_analysis_test(test_fluid_to_data_flow_graph_pass SRCS fluid_to_data_flow_graph_pass_tester.cc) +inference_analysis_test(test_subgraph_splitter SRCS subgraph_splitter_tester.cc) +inference_analysis_test(test_dfg_graphviz_draw_pass SRCS dfg_graphviz_draw_pass_tester.cc) +#inference_analysis_test(test_tensorrt_subgraph_pass SRCS tensorrt_subgraph_pass_tester.cc) +inference_analysis_test(test_pass_manager SRCS pass_manager_tester.cc) diff --git a/paddle/fluid/inference/analysis/argument.cc b/paddle/fluid/inference/analysis/argument.cc new file mode 100644 index 0000000000..cb0263d5d9 --- /dev/null +++ b/paddle/fluid/inference/analysis/argument.cc @@ -0,0 +1,15 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/inference/analysis/argument.h" diff --git a/paddle/fluid/inference/analysis/argument.h b/paddle/fluid/inference/analysis/argument.h new file mode 100644 index 0000000000..7d7131ed7a --- /dev/null +++ b/paddle/fluid/inference/analysis/argument.h @@ -0,0 +1,53 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/* + * This file defines the class Argument, which is the input and output of the + * analysis module. All the fields that needed either by Passes or PassManagers + * are contained in Argument. + * + * TODO(Superjomn) Find some way better to contain the fields when it grow too + * big. + */ + +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/inference/analysis/data_flow_graph.h" + +namespace paddle { +namespace inference { +namespace analysis { + +/* + * The argument definition of both Pass and PassManagers. + * + * All the fields should be registered here for clearness. + */ +struct Argument { + // The graph that process by the Passes or PassManagers. + std::unique_ptr main_dfg; + + // The original program desc. + std::unique_ptr origin_program_desc; +}; + +#define UNLIKELY(condition) __builtin_expect(static_cast(condition), 0) +#define ANALYSIS_ARGUMENT_CHECK_FIELD(field__) \ + if (!UNLIKELY(field__)) { \ + LOG(ERROR) << "field " << #field__ << " should be set."; \ + return false; \ + } + +} // namespace analysis +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/analysis/data_flow_graph.cc b/paddle/fluid/inference/analysis/data_flow_graph.cc index 4220451e3c..c30a7c26ce 100644 --- a/paddle/fluid/inference/analysis/data_flow_graph.cc +++ b/paddle/fluid/inference/analysis/data_flow_graph.cc @@ -14,6 +14,7 @@ limitations under the License. */ #include "paddle/fluid/inference/analysis/data_flow_graph.h" #include "paddle/fluid/inference/analysis/dot.h" +#include "paddle/fluid/inference/analysis/node.h" namespace paddle { namespace inference { @@ -57,19 +58,7 @@ std::string DataFlowGraph::DotString() const { // Add nodes for (size_t i = 0; i < nodes.size(); i++) { const Node &node = nodes.Get(i); - switch (node.type()) { - case Node::Type::kValue: - dot.AddNode(node.repr(), node.dot_attrs()); - break; - case Node::Type::kFunction: - dot.AddNode(node.repr(), node.dot_attrs()); - break; - case Node::Type::kFunctionBlock: - dot.AddNode(node.repr(), node.dot_attrs()); - break; - default: - PADDLE_THROW("unsupported Node type %d", static_cast(node.type())); - } + dot.AddNode(node.repr(), node.dot_attrs()); } // Add edges diff --git a/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc b/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc new file mode 100644 index 0000000000..f7d4cca213 --- /dev/null +++ b/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc @@ -0,0 +1,77 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.h" +#include "paddle/fluid/framework/proto_desc.h" + +namespace paddle { +namespace inference { +namespace analysis { + +bool DataFlowGraphToFluidPass::Initialize(Argument* argument) { + ANALYSIS_ARGUMENT_CHECK_FIELD(argument) + ANALYSIS_ARGUMENT_CHECK_FIELD(argument->origin_program_desc) + desc_ = argument->origin_program_desc.get(); + // Here some logic from program_desc.cc and will not add new interfaces into + // framework::ProgramDesc class, use some UT to assure the correctness. + auto* block = desc_->mutable_blocks()->Add(); + block->set_idx(framework::kRootBlockIndex); + block->set_parent_idx(framework::kNoneBlockIndex); + return true; +} + +bool DataFlowGraphToFluidPass::Finalize() { return true; } + +void DataFlowGraphToFluidPass::Run(DataFlowGraph* graph) { + auto traits = GraphTraits(graph); + for (auto it = traits.nodes().begin(); it != traits.nodes().end(); ++it) { + if (it->deleted()) continue; + switch (it->type()) { + case Node::Type::kFunction: + LOG(INFO) << "add function " << it->name(); + AddFluidOp(&(*it)); + break; + case Node::Type::kFunctionBlock: + AddEngineOp(&(*it)); + break; + default: + continue; + } + } +} + +void DataFlowGraphToFluidPass::AddFluidOp(Node* node) { + LOG(INFO) << "processing func " << node->name(); + auto* ori_op = static_cast(node->pb_desc()); + // currently only the main block is analyzed. + auto* main_block = desc_->mutable_blocks(framework::kRootBlockIndex); + auto* op = main_block->add_ops(); + LOG(INFO) << "to copy the op"; + *op = *ori_op; // copy the attributes, by default, these will not be changed + // by analysis phrase. + // The inputs and outputs of the existing ops are not changed by tensorrt + // subgraph pass. + // NOTE It might be changed by other passes in the long run. +} + +void DataFlowGraphToFluidPass::AddEngineOp(Node* node) { + // auto* ori_op = static_cast(node->extra_info()); + // auto* main_block = desc_->mutable_blocks(framework::kRootBlockIndex); + // auto* op = main_block->add_ops(); + // TODO(Superjomn) Here need to expose some arguments for default setting. +} + +} // namespace analysis +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.h b/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.h new file mode 100644 index 0000000000..cbb05f622c --- /dev/null +++ b/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.h @@ -0,0 +1,59 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +/* + * This file implements the transformation from fluid ProgramDesc to data flow + * graph. + */ + +#pragma once + +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/inference/analysis/data_flow_graph.h" +#include "paddle/fluid/inference/analysis/pass.h" + +namespace paddle { +namespace inference { +namespace analysis { +class DataFlowGraphToFluidPass final : public DataFlowGraphPass { + public: + DataFlowGraphToFluidPass() = default; + + bool Initialize(Argument *argument) override; + bool Finalize() override; + + void Run(DataFlowGraph *graph) override; + + std::string repr() const override { return "DFG to fluid"; } + std::string description() const override { + return "Transform a DFG to a Fluid ProgramDesc"; + } + + Pass *CreatePrinterPass(std::ostream &os, + const std::string &banner) const override { + return nullptr; + } + + protected: + // Add a Fluid Op into the ProgramDesc. + void AddFluidOp(Node *node); + // Add a EngineOp into the ProgramDesc. + void AddEngineOp(Node *node); + + private: + framework::proto::ProgramDesc *desc_; +}; +} // namespace analysis +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass_tester.cc b/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass_tester.cc index dcee75cee5..d8fc5e580a 100644 --- a/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass_tester.cc +++ b/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass_tester.cc @@ -27,13 +27,12 @@ namespace inference { namespace analysis { TEST_F(DFG_Tester, Test) { - framework::proto::ProgramDesc new_desc; DataFlowGraph graph; FluidToDataFlowGraphPass pass0; DataFlowGraphToFluidPass pass1; - pass0.Initialize(desc); - pass1.Initialize(&new_desc); + ASSERT_TRUE(pass0.Initialize(&argument)); + ASSERT_TRUE(pass1.Initialize(&argument)); pass0.Run(&graph); pass1.Run(&graph); diff --git a/paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.cc b/paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.cc new file mode 100644 index 0000000000..afffb3feb0 --- /dev/null +++ b/paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.cc @@ -0,0 +1,54 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.h" + +namespace paddle { +namespace inference { +namespace analysis { + +void DFG_GraphvizDrawPass::Run(DataFlowGraph *graph) { + auto content = Draw(graph); + std::ofstream file(GenDotPath()); + file.write(content.c_str(), content.size()); + file.close(); + LOG(INFO) << "draw dot to " << GenDotPath(); +} + +std::string DFG_GraphvizDrawPass::Draw(DataFlowGraph *graph) { + Dot dot; + // Add nodes + for (size_t i = 0; i < graph->nodes.size(); i++) { + const Node &node = graph->nodes.Get(i); + if (config_.display_deleted_node || !node.deleted()) { + dot.AddNode(node.repr(), node.dot_attrs()); + } + } + // Add edges + for (size_t i = 0; i < graph->nodes.size(); i++) { + const Node &node = graph->nodes.Get(i); + if (!config_.display_deleted_node && node.deleted()) continue; + for (auto &in : node.inlinks) { + if (!config_.display_deleted_node && in->deleted()) continue; + for (auto &in : node.inlinks) { + dot.AddEdge(in->repr(), node.repr(), {}); + } + } + } + return dot.Build(); +} + +} // namespace analysis +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.h b/paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.h index 41d4475382..93ebff59ae 100644 --- a/paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.h +++ b/paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.h @@ -21,6 +21,7 @@ limitations under the License. */ #include #include +#include "paddle/fluid/inference/analysis/dot.h" #include "paddle/fluid/inference/analysis/pass.h" namespace paddle { @@ -32,35 +33,39 @@ namespace analysis { */ class DFG_GraphvizDrawPass : public DataFlowGraphPass { public: - DFG_GraphvizDrawPass(const std::string& dir, const std::string& id) - : dir_(dir), id_(id) {} - - bool Initialize() override { return Pass::Initialize(); } - void Run(DataFlowGraph* graph) override { - auto content = Draw(graph); - std::ofstream file(GenDotPath()); - file.write(content.c_str(), content.size()); - file.close(); - LOG(INFO) << "draw dot to " << GenDotPath(); - } + struct Config { + Config(const std::string &dir, const std::string &id, + bool display_deleted_node = false) + : dir(dir), id(id), display_deleted_node(display_deleted_node) {} + + // The directory to store the .dot or .png files. + const std::string dir; + // The identifier for this dot file. + const std::string id; + // Whether to display deleted nodes, default false. + const bool display_deleted_node; + }; + + DFG_GraphvizDrawPass(const Config &config) : config_(config) {} + bool Initialize(Argument *argument) override { return true; } + void Run(DataFlowGraph *graph) override; bool Finalize() override { return Pass::Finalize(); } - Pass* CreatePrinterPass(std::ostream& os, - const std::string& banner) const override { - return nullptr; + std::string repr() const override { return "DFG graphviz drawer"; } + std::string description() const override { + return "Debug a DFG by draw with graphviz"; } private: // Path of the dot file to output. std::string GenDotPath() const { - return dir_ + "/" + "graph_" + id_ + ".dot"; + return config_.dir + "/" + "graph_" + config_.id + ".dot"; } - std::string Draw(DataFlowGraph* graph) { return graph->DotString(); } + std::string Draw(DataFlowGraph *graph); - std::string dir_; - std::string id_; + Config config_; }; } // namespace analysis diff --git a/paddle/fluid/inference/analysis/dfg_graphviz_draw_pass_tester.cc b/paddle/fluid/inference/analysis/dfg_graphviz_draw_pass_tester.cc index 3fc1cc18b8..f4b5c5fd22 100644 --- a/paddle/fluid/inference/analysis/dfg_graphviz_draw_pass_tester.cc +++ b/paddle/fluid/inference/analysis/dfg_graphviz_draw_pass_tester.cc @@ -24,9 +24,10 @@ namespace inference { namespace analysis { TEST_F(DFG_Tester, dfg_graphviz_draw_pass_tester) { - auto dfg = ProgramDescToDFG(desc); - DFG_GraphvizDrawPass pass("./", "test"); - pass.Initialize(); + auto dfg = ProgramDescToDFG(*argument.origin_program_desc); + DFG_GraphvizDrawPass::Config config("./", "test"); + DFG_GraphvizDrawPass pass(config); + pass.Initialize(&argument); pass.Run(&dfg); // test content @@ -38,7 +39,8 @@ TEST_F(DFG_Tester, dfg_graphviz_draw_pass_tester) { while (std::getline(file, line)) { no++; } - ASSERT_EQ(no, 82); + // DFG is sensitive to ProgramDesc, be careful to change the existing models. + ASSERT_EQ(no, 112); } } // namespace analysis diff --git a/paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.cc b/paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.cc index 9f67c989cc..5f62eef528 100644 --- a/paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.cc +++ b/paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.cc @@ -21,19 +21,23 @@ namespace paddle { namespace inference { namespace analysis { -FluidToDataFlowGraphPass::FluidToDataFlowGraphPass() {} - -bool FluidToDataFlowGraphPass::Initialize() { return Pass::Initialize(); } - -bool FluidToDataFlowGraphPass::Initialize( - const framework::proto::ProgramDesc &desc) { - desc_ = &desc; +bool FluidToDataFlowGraphPass::Initialize(Argument *argument) { + ANALYSIS_ARGUMENT_CHECK_FIELD(argument); + ANALYSIS_ARGUMENT_CHECK_FIELD(argument->origin_program_desc); + PADDLE_ENFORCE(argument); + if (!argument->main_dfg) { + LOG(INFO) << "Init DFG"; + argument->main_dfg.reset(new DataFlowGraph); + } + desc_ = argument->origin_program_desc.get(); return true; } bool FluidToDataFlowGraphPass::Finalize() { return Pass::Finalize(); } void FluidToDataFlowGraphPass::Run(DataFlowGraph *graph) { + PADDLE_ENFORCE(graph); + PADDLE_ENFORCE(desc_); // insert vars std::unordered_map var2id; auto &main_block = desc_->blocks(framework::kRootBlockIndex); @@ -41,7 +45,7 @@ void FluidToDataFlowGraphPass::Run(DataFlowGraph *graph) { const auto &var = main_block.vars(i); auto *v = graph->nodes.Create(Node::Type::kValue); v->SetName(var.name()); - v->SetExtraInfo(const_cast(static_cast(&var))); + v->SetPbDesc(const_cast(static_cast(&var))); var2id[var.name()] = v->id(); } for (int i = 0; i < main_block.ops_size(); i++) { @@ -51,7 +55,7 @@ void FluidToDataFlowGraphPass::Run(DataFlowGraph *graph) { static_cast(o)->SetFuncType(op.type()); // Link to the original protobuf message's memory, make it easier to // generate from a data flow graph to fluid ProgramDesc. - o->SetExtraInfo(const_cast(static_cast(&op))); + o->SetPbDesc(const_cast(static_cast(&op))); // set inputs and outputs // TODO(Superjomn) make sure the InputNames is the real variable name. for (int j = 0; j < op.inputs_size(); j++) { diff --git a/paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h b/paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h index 33517e57be..176faf0220 100644 --- a/paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h +++ b/paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h @@ -34,13 +34,18 @@ namespace analysis { */ class FluidToDataFlowGraphPass final : public DataFlowGraphPass { public: - FluidToDataFlowGraphPass(); - bool Initialize() override; - bool Initialize(const framework::proto::ProgramDesc &desc) override; + FluidToDataFlowGraphPass() = default; + + bool Initialize(Argument *argument) override; bool Finalize() override; void Run(DataFlowGraph *graph) override; + std::string repr() const override { return "fluid-to-data-flow-graph"; } + std::string description() const override { + return "transform a fluid ProgramDesc to a data flow graph."; + } + Pass *CreatePrinterPass(std::ostream &os, const std::string &banner) const override; diff --git a/paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass_tester.cc b/paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass_tester.cc index 817d32c92c..cfbbc284e4 100644 --- a/paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass_tester.cc +++ b/paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass_tester.cc @@ -23,11 +23,11 @@ namespace analysis { TEST_F(DFG_Tester, Init) { FluidToDataFlowGraphPass pass; - pass.Initialize(); - pass.Initialize(desc); + pass.Initialize(&argument); DataFlowGraph graph; pass.Run(&graph); - ASSERT_GT(graph.nodes.size(), 0); + // Analysis is sensitive to ProgramDesc, careful to change the original model. + ASSERT_EQ(graph.nodes.size(), 37); pass.Finalize(); LOG(INFO) << '\n' << graph.DotString(); } diff --git a/paddle/fluid/inference/analysis/helper.h b/paddle/fluid/inference/analysis/helper.h index 58eb0e715c..f0039e1131 100644 --- a/paddle/fluid/inference/analysis/helper.h +++ b/paddle/fluid/inference/analysis/helper.h @@ -62,6 +62,7 @@ struct DataTypeNamer { SET_TYPE(int); SET_TYPE(bool); SET_TYPE(float); + SET_TYPE(void *); } std::unordered_map inlinks; // Output links. std::vector outlinks; // A helper class to maintain the status from Pass. - // TODO(superjomn) add a checker here to ensure the T is primary. struct Attr { // NOTE T should be a primary type or a struct combined by several primary // types. // NOTE the STL containers should not use here. // Some usages - // Attr attr; - // T data; - // attr.data.assign((char*)data, sizeof(data)); + // Attr attr; + // attr.Bool() = true; bool &Bool() { return As(); } float &Float() { return As(); } int32_t &Int32() { return As(); } int64_t &Int64() { return As(); } + void *&Pointer() { return As(); } private: template @@ -130,6 +131,7 @@ class Node { size_t type_hash_{std::numeric_limits::max()}; }; + // Type checks. bool IsFunction() const { return type_ == Node::Type::kFunction; } bool IsValue() const { return type_ == Node::Type::kValue; } bool IsFunctionBlock() const { return type_ == Node::Type::kFunctionBlock; } @@ -148,9 +150,6 @@ class Node { Type type_{Type::kNone}; // Mark this node is deleted by some pass. bool deleted_{false}; - - void *extra_info_; - mutable std::unordered_map attrs_; }; diff --git a/paddle/fluid/inference/analysis/pass.h b/paddle/fluid/inference/analysis/pass.h index aa0e8667b5..65632b7491 100644 --- a/paddle/fluid/inference/analysis/pass.h +++ b/paddle/fluid/inference/analysis/pass.h @@ -19,6 +19,7 @@ limitations under the License. */ #include #include "paddle/fluid/framework/framework.pb.h" +#include "paddle/fluid/inference/analysis/argument.h" #include "paddle/fluid/inference/analysis/data_flow_graph.h" #include "paddle/fluid/inference/analysis/helper.h" #include "paddle/fluid/inference/analysis/node.h" @@ -30,19 +31,24 @@ namespace analysis { class Pass { public: Pass() = default; - virtual ~Pass() {} + virtual ~Pass() = default; // Virtual method overridden by subclasses to do only necessary initialization // before any pass is run. - virtual bool Initialize() { return false; } + // virtual bool Initialize() { return false; } // There is some passes such as FlowToDataFlowGraphPass that needs a // ProgramDesc. Here use the native ProgramDesc ProtoBuf message, so that it // only couple with the proto file. - virtual bool Initialize(const framework::proto::ProgramDesc &desc) { - return false; - } + // virtual bool Initialize(const framework::proto::ProgramDesc &desc) { return + // false; } // There are some Passes such as DataFlowGraphToFluidPass that will output a // ProgramDesc. - virtual bool Initialize(framework::proto::ProgramDesc *desc) { return false; } + // virtual bool Initialize(framework::proto::ProgramDesc *desc) { return + // false; } + + // Mutable Pass. + virtual bool Initialize(Argument *argument) { return false; } + // Readonly Pass. + virtual bool Initialize(const Argument &argument) { return false; } // Virtual method overriden by subclasses to do any necessary clean up after // all passes have run. @@ -50,7 +56,9 @@ class Pass { // Get a Pass appropriate to print the Node this pass operates on. virtual Pass *CreatePrinterPass(std::ostream &os, - const std::string &banner) const = 0; + const std::string &banner) const { + return nullptr; + } // Run on a single Node. virtual void Run(Node *x) { LOG(FATAL) << "not valid"; } @@ -60,6 +68,11 @@ class Pass { virtual void Run(FunctionBlock *x) { LOG(FATAL) << "not valid"; } // Run on a single DataFlowGraph. virtual void Run(DataFlowGraph *x) { LOG(FATAL) << "not valid"; } + + // Human-readable short representation. + virtual std::string repr() const = 0; + // Human-readable long description. + virtual std::string description() const = 0; }; // NodePass process on any Node types. diff --git a/paddle/fluid/inference/analysis/pass_manager.cc b/paddle/fluid/inference/analysis/pass_manager.cc new file mode 100644 index 0000000000..b17c0e0d72 --- /dev/null +++ b/paddle/fluid/inference/analysis/pass_manager.cc @@ -0,0 +1,44 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/inference/analysis/pass_manager.h" +#include "paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h" + +namespace paddle { +namespace inference { +namespace analysis { + +void DfgPassManager::RunAll() { + PADDLE_ENFORCE(argument_); + for (auto& pass : data_) { + VLOG(4) << "Running pass [" << pass->repr() << "]"; + pass->Run(argument_->main_dfg.get()); + } +} + +void NodePassManager::RunAll() { + PADDLE_ENFORCE(argument_); + PADDLE_ENFORCE(argument_->main_dfg.get()); + auto trait = + GraphTraits(argument_->main_dfg.get()).nodes_in_DFS(); + for (auto& node : trait) { + for (auto& pass : data_) { + pass->Run(&node); + } + } +} + +} // namespace analysis +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/analysis/pass_manager.h b/paddle/fluid/inference/analysis/pass_manager.h new file mode 100644 index 0000000000..7841c4b9d0 --- /dev/null +++ b/paddle/fluid/inference/analysis/pass_manager.h @@ -0,0 +1,116 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +/* + * This file defines the logic of pass management. The analysis for inference is + * a pipeline of Passes, a PassManager is a agency that helps to manage the + * executation of the Passes. + * + * There are two modes of Passes, the first one is called NodePass and takes + * an Node as input and output; the second one is called DFGPass and takes a + * DFG(Data Flow Graph) as input and output. It is hard to put all the passes in + * the same pipeline, there are two kinds of PassManagers, both takes a DFG as + * input and output a DFG, but the Passes inside are different: + * + * 1. NodePassManager: the passes inside are all NodePasses, it can have + * different graph trivial algorithm, for example, DFS_NodePassManager will + * trigger the passes in depth first order; + * 2. DfgPassManager: the passes inside are all DfgPasses. + */ + +#pragma once + +#include +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/inference/analysis/pass.h" + +namespace paddle { +namespace inference { +namespace analysis { + +/* + * PassManager is the base class for all pass managers, a pass manager has + * several Pass-es registered, and execute them in the linear order. + */ +class PassManager : public OrderedRegistry { + public: + PassManager() = default; + // Call all the passes' Initialize methods. The desc and data_flow_graph are + // globally shared, so pass them as the arguemnts for all the pass managers. + virtual bool Initialize(const Argument& argument) { return false; } + + virtual bool Initialize(Argument* argument) { + argument_ = argument; + for (auto& pass : data_) { + LOG(INFO) << "Initializing pass " << pass->repr(); + if (!pass->Initialize(argument)) { + LOG(ERROR) << "Failed to initialize pass [" << pass->repr() << "]"; + return false; + } + } + return true; + } + + // Call all the passes' Finalize methods. + virtual bool Finalize() { + for (auto& pass : data_) { + if (!pass->Finalize()) { + LOG(ERROR) << "Failed to finalize pass [" << pass->repr() << "]"; + return false; + } + } + return true; + } + + // Run all the passes. + virtual void RunAll() = 0; + + // Short identifier. + virtual std::string repr() const = 0; + // Long description. + virtual std::string description() const = 0; + + virtual ~PassManager() = default; + + protected: + Argument* argument_{nullptr}; +}; + +/* + * A pass manager that process a DFG. + */ +class DfgPassManager : public PassManager { + public: + DfgPassManager() = default; + + void RunAll() override; + + virtual ~DfgPassManager() = default; +}; + +/* + * A pass manager that process a Node each time. + */ +class NodePassManager : public PassManager { + public: + NodePassManager() = default; + + void RunAll() override; + + virtual ~NodePassManager() = default; +}; + +} // namespace analysis +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/analysis/pass_manager_tester.cc b/paddle/fluid/inference/analysis/pass_manager_tester.cc new file mode 100644 index 0000000000..7af6a19951 --- /dev/null +++ b/paddle/fluid/inference/analysis/pass_manager_tester.cc @@ -0,0 +1,85 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/inference/analysis/pass_manager.h" +#include "paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.h" +#include "paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.h" +#include "paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h" +#include "paddle/fluid/inference/analysis/ut_helper.h" + +#include + +namespace paddle { +namespace inference { +namespace analysis { + +class TestDfgPassManager final : public DfgPassManager { + public: + TestDfgPassManager() = default; + virtual ~TestDfgPassManager() = default; + // Short identifier. + std::string repr() const override { return "test-pass-manager"; } + // Long description. + std::string description() const override { return "test doc"; } +}; + +class TestNodePassManager final : public NodePassManager { + public: + virtual ~TestNodePassManager() = default; + + std::string repr() const override { return "test-node-pass-manager"; } + std::string description() const override { return "test doc"; } +}; + +class TestNodePass final : public NodePass { + public: + virtual ~TestNodePass() = default; + + bool Initialize(Argument* argument) override { return true; } + + void Run(Node* node) override { + LOG(INFO) << "- Processing node " << node->repr(); + } + + std::string repr() const override { return "test-node"; } + std::string description() const override { return "some doc"; } +}; + +TEST_F(DFG_Tester, DFG_pass_manager) { + TestDfgPassManager manager; + DFG_GraphvizDrawPass::Config config("./", "dfg.dot"); + + manager.Register("fluid-to-flow-graph", new FluidToDataFlowGraphPass); + manager.Register("graphviz", new DFG_GraphvizDrawPass(config)); + manager.Register("dfg-to-fluid", new DataFlowGraphToFluidPass); + + ASSERT_TRUE(manager.Initialize(&argument)); + manager.RunAll(); +} + +TEST_F(DFG_Tester, Node_pass_manager) { + // Pre-process: initialize the DFG with the ProgramDesc first. + FluidToDataFlowGraphPass pass0; + pass0.Initialize(&argument); + pass0.Run(argument.main_dfg.get()); + + TestNodePassManager manager; + manager.Register("test-node-pass", new TestNodePass); + ASSERT_TRUE(manager.Initialize(&argument)); + manager.RunAll(); +} + +} // namespace analysis +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/analysis/subgraph_splitter_tester.cc b/paddle/fluid/inference/analysis/subgraph_splitter_tester.cc index 0644c0db12..8134494f8b 100644 --- a/paddle/fluid/inference/analysis/subgraph_splitter_tester.cc +++ b/paddle/fluid/inference/analysis/subgraph_splitter_tester.cc @@ -19,22 +19,23 @@ namespace paddle { namespace inference { namespace analysis { +SubGraphSplitter::NodeInsideSubgraphTeller teller = [](const Node* node) { + if (node->type() != Node::Type::kFunction) return false; + const auto* func = static_cast(node); + if (func->func_type() == "elementwise_add" || func->func_type() == "relu" || + func->func_type() == "conv2d" || func->func_type() == "mul" || + func->func_type() == "sigmoid" || func->func_type() == "softmax") { + LOG(INFO) << "sub-graph marked " << node->repr(); + return true; + } + return false; +}; + TEST_F(DFG_Tester, Split) { auto desc = LoadProgramDesc(); auto dfg = ProgramDescToDFG(desc); LOG(INFO) << "spliter\n" << dfg.DotString(); - SubGraphSplitter::NodeInsideSubgraphTeller teller = [](const Node* node) { - if (node->type() != Node::Type::kFunction) return false; - const auto* func = static_cast(node); - if (func->func_type() == "elementwise_add" || func->func_type() == "relu" || - func->func_type() == "conv2d" || func->func_type() == "mul" || - func->func_type() == "sigmoid" || func->func_type() == "softmax") { - LOG(INFO) << "sub-graph marked " << node->repr(); - return true; - } - return false; - }; ASSERT_GT(dfg.nodes.size(), 5UL); auto subgraphs = SubGraphSplitter(&dfg, teller)(); @@ -62,6 +63,28 @@ TEST_F(DFG_Tester, Split) { ASSERT_EQ(subgraphs.back().size(), 6UL); } +TEST_F(DFG_Tester, Fuse) { + auto desc = LoadProgramDesc(); + auto dfg = ProgramDescToDFG(desc); + + size_t count0 = dfg.nodes.size(); + + SubGraphFuse fuse(&dfg, teller); + fuse(); + + int count1 = 0; + for (auto& node : dfg.nodes.nodes()) { + if (node->deleted()) { + LOG(INFO) << "deleted " << node->repr(); + } + count1 += node->deleted(); + } + + // At least one nodes should be deleted. + ASSERT_EQ(dfg.nodes.size(), count0 + 1); // added a new FunctionBlock + ASSERT_EQ(6UL, count1); +} + } // namespace analysis } // namespace inference } // namespace paddle diff --git a/paddle/fluid/inference/analysis/tensorrt_subgraph_pass.cc b/paddle/fluid/inference/analysis/tensorrt_subgraph_pass.cc new file mode 100644 index 0000000000..b75df33b71 --- /dev/null +++ b/paddle/fluid/inference/analysis/tensorrt_subgraph_pass.cc @@ -0,0 +1,33 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/inference/analysis/tensorrt_subgraph_pass.h" +#include "paddle/fluid/inference/analysis/subgraph_splitter.h" + +namespace paddle { +namespace inference { +namespace analysis { + +TensorRTSubGraphPass::TensorRTSubGraphPass( + const TensorRTSubGraphPass::NodeInsideSubgraphTeller &teller) + : node_inside_subgraph_teller_(teller) {} + +void TensorRTSubGraphPass::Run(DataFlowGraph *graph) { + SubGraphFuse(graph, node_inside_subgraph_teller_); +} + +} // analysis +} // inference + +} // paddle diff --git a/paddle/fluid/inference/analysis/tensorrt_subgraph_pass.h b/paddle/fluid/inference/analysis/tensorrt_subgraph_pass.h new file mode 100644 index 0000000000..79e9e2bcc9 --- /dev/null +++ b/paddle/fluid/inference/analysis/tensorrt_subgraph_pass.h @@ -0,0 +1,47 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/fluid/inference/analysis/node.h" +#include "paddle/fluid/inference/analysis/pass.h" +#include "paddle/fluid/inference/analysis/subgraph_splitter.h" + +namespace paddle { +namespace inference { +namespace analysis { + +/* + * Parse the graph and replace TensorRT supported nodes with SubGraphNode + */ +class TensorRTSubGraphPass : public DataFlowGraphPass { + public: + // Tell whether to transform a sub-graph into TensorRT. + using NodeInsideSubgraphTeller = SubGraphFuse::NodeInsideSubgraphTeller; + + TensorRTSubGraphPass(const NodeInsideSubgraphTeller& teller); + + bool Initialize(Argument* argument) override { return true; } + + // This class get a sub-graph as input and determine whether to transform this + // sub-graph into TensorRT. + void Run(DataFlowGraph* graph) override; + + private: + NodeInsideSubgraphTeller node_inside_subgraph_teller_; +}; + +} // namespace analysis +} // namespace inference +} // paddle diff --git a/paddle/fluid/inference/analysis/tensorrt_subgraph_pass_tester.cc b/paddle/fluid/inference/analysis/tensorrt_subgraph_pass_tester.cc new file mode 100644 index 0000000000..d12dcf0d0f --- /dev/null +++ b/paddle/fluid/inference/analysis/tensorrt_subgraph_pass_tester.cc @@ -0,0 +1,71 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/inference/analysis/tensorrt_subgraph_pass.h" + +#include +#include +#include "paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.h" +#include "paddle/fluid/inference/analysis/ut_helper.h" + +namespace paddle { +namespace inference { +namespace analysis { + +DEFINE_string(model_dir, "", "inference test model dir"); + +TEST(TensorRTSubGraph, single_pass) { + auto desc = LoadProgramDesc(); + auto dfg = ProgramDescToDFG(desc); + + SubGraphSplitter::NodeInsideSubgraphTeller teller = [](const Node* node) { + if (node->type() != Node::Type::kFunction) return false; + const auto* func = static_cast(node); + if (func->func_type() == "elementwise_add" || func->func_type() == "relu" || + func->func_type() == "conv2d" || func->func_type() == "mul" || + func->func_type() == "sigmoid" || func->func_type() == "softmax") { + LOG(INFO) << "sub-graph marked " << node->repr(); + return true; + } + return false; + }; + + DFG_GraphvizDrawPass::Config config{"./", "test"}; + DFG_GraphvizDrawPass dfg_pass(config); + dfg_pass.Initialize(); + + DFG_GraphvizDrawPass dfg_pass1(config); + dfg_pass1.Initialize(); + + dfg_pass.Run(&dfg); + + TensorRTSubGraphPass trt_pass(std::move(teller)); + trt_pass.Initialize(); + + trt_pass.Run(&dfg); + + dfg_pass1.Run(&dfg); + + // Check the TRT op's block desc + for (auto node : dfg.nodes.nodes()) { + if (node->IsFunctionBlock()) { + } + } +} + +TEST(TensorRTSubGraph, pass_manager) {} + +} // namespace analysis +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/analysis/ut_helper.h b/paddle/fluid/inference/analysis/ut_helper.h index 722fa99a48..ce1191a567 100644 --- a/paddle/fluid/inference/analysis/ut_helper.h +++ b/paddle/fluid/inference/analysis/ut_helper.h @@ -15,33 +15,46 @@ limitations under the License. */ #pragma once #include #include +#include #include #include "paddle/fluid/framework/executor.h" #include "paddle/fluid/inference/analysis/data_flow_graph.h" #include "paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h" #include "paddle/fluid/inference/analysis/ut_helper.h" -#include "paddle/fluid/inference/io.h" namespace paddle { namespace inference { + +// Read ProgramDesc from a __model__ file, defined in io.cc +extern void ReadBinaryFile(const std::string& filename, std::string* contents); + namespace analysis { DEFINE_string(inference_model_dir, "", "inference test model dir"); static framework::proto::ProgramDesc LoadProgramDesc( const std::string& model_dir = FLAGS_inference_model_dir) { - paddle::platform::CPUPlace place; - paddle::framework::Executor executor(place); - paddle::framework::Scope scope; - auto program = Load(&executor, &scope, model_dir); - return *program->Proto(); + std::string msg; + std::string net_file = FLAGS_inference_model_dir + "/__model__"; + std::ifstream fin(net_file, std::ios::in | std::ios::binary); + PADDLE_ENFORCE(static_cast(fin), "Cannot open file %s", net_file); + fin.seekg(0, std::ios::end); + msg.resize(fin.tellg()); + fin.seekg(0, std::ios::beg); + fin.read(&(msg.at(0)), msg.size()); + fin.close(); + framework::proto::ProgramDesc program_desc; + program_desc.ParseFromString(msg); + return program_desc; } static DataFlowGraph ProgramDescToDFG( const framework::proto::ProgramDesc& desc) { DataFlowGraph graph; FluidToDataFlowGraphPass pass; - pass.Initialize(desc); + Argument argument; + argument.origin_program_desc.reset(new framework::proto::ProgramDesc(desc)); + pass.Initialize(&argument); pass.Run(&graph); pass.Finalize(); return graph; @@ -49,9 +62,12 @@ static DataFlowGraph ProgramDescToDFG( class DFG_Tester : public ::testing::Test { protected: - void SetUp() override { desc = LoadProgramDesc(FLAGS_inference_model_dir); } + void SetUp() override { + auto desc = LoadProgramDesc(FLAGS_inference_model_dir); + argument.origin_program_desc.reset(new framework::proto::ProgramDesc(desc)); + } - framework::proto::ProgramDesc desc; + Argument argument; }; } // namespace analysis diff --git a/paddle/fluid/operators/tensorrt_engine_op_test.cc b/paddle/fluid/operators/tensorrt_engine_op_test.cc index 85330958cd..3a2fef4805 100644 --- a/paddle/fluid/operators/tensorrt_engine_op_test.cc +++ b/paddle/fluid/operators/tensorrt_engine_op_test.cc @@ -240,7 +240,7 @@ void Execute(int batch_size, int input_dim, int output_dim, int nlayers = 1) { } // Test with a larger FC layer. -TEST(TensorRTEngineOp, fc) { Execute(40, 256, 256); } +TEST(TensorRTEngineOp, fc) { Execute(40, 28, 28); } } // namespace operators } // namespace paddle -- GitLab