diff --git a/paddle/fluid/inference/analysis/CMakeLists.txt b/paddle/fluid/inference/analysis/CMakeLists.txt index 67d355d10d3c9e11b59c9ce9d208826523095459..27fe575cb6167a726ff92a8f3d2e47b6f536ba39 100644 --- a/paddle/fluid/inference/analysis/CMakeLists.txt +++ b/paddle/fluid/inference/analysis/CMakeLists.txt @@ -6,9 +6,11 @@ cc_library(analysis SRCS pass_manager.cc dot.cc node.cc data_flow_graph.cc graph tensorrt_subgraph_node_mark_pass.cc analyzer.cc helper.cc + model_store_pass.cc DEPS framework_proto proto_desc) cc_test(test_node SRCS node_tester.cc DEPS analysis) cc_test(test_dot SRCS dot_tester.cc DEPS analysis) +cc_binary(inference_analyzer SRCS analyzer_main.cc DEPS analysis) set(PYTHON_TESTS_DIR ${PADDLE_BINARY_DIR}/python/paddle/fluid/tests) @@ -40,3 +42,4 @@ inference_analysis_test(test_tensorrt_subgraph_pass SRCS tensorrt_subgraph_pass_ inference_analysis_test(test_pass_manager SRCS pass_manager_tester.cc) inference_analysis_test(test_tensorrt_subgraph_node_mark_pass SRCS tensorrt_subgraph_node_mark_pass_tester.cc) inference_analysis_test(test_analyzer SRCS analyzer_tester.cc) +inference_analysis_test(test_model_store_pass SRCS model_store_pass_tester.cc) diff --git a/paddle/fluid/inference/analysis/analyzer.cc b/paddle/fluid/inference/analysis/analyzer.cc index b3a1075e5adf4a24bf32017574c061f36c46ba8c..98bdfcc00b9f0e8f40dfc92e4021b2bd6fb19313 100644 --- a/paddle/fluid/inference/analysis/analyzer.cc +++ b/paddle/fluid/inference/analysis/analyzer.cc @@ -17,6 +17,7 @@ #include "paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.h" #include "paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.h" #include "paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h" +#include "paddle/fluid/inference/analysis/model_store_pass.h" #include "paddle/fluid/inference/analysis/pass_manager.h" #include "paddle/fluid/inference/analysis/tensorrt_subgraph_node_mark_pass.h" #include "paddle/fluid/inference/analysis/tensorrt_subgraph_pass.h" @@ -29,6 +30,9 @@ DEFINE_bool(inference_analysis_enable_tensorrt_subgraph_engine, false, DEFINE_string(inference_analysis_graphviz_log_root, "./", "Graphviz debuger for data flow graphs."); +DEFINE_string(inference_analysis_output_storage_path, "", + "optimized model output path"); + namespace inference { namespace analysis { @@ -47,6 +51,9 @@ class DfgPassManagerImpl final : public DfgPassManager { AddPass("tensorrt-subgraph", new TensorRTSubGraphPass(trt_teller)); } AddPass("data-flow-graph-to-fluid", new DataFlowGraphToFluidPass); + if (!FLAGS_inference_analysis_output_storage_path.empty()) { + AddPass("model-store-pass", new ModelStorePass); + } } std::string repr() const override { return "dfg-pass-manager"; } diff --git a/paddle/fluid/inference/analysis/analyzer.h b/paddle/fluid/inference/analysis/analyzer.h index 0132bf5b9c6552391aaa19542669487f42b685a7..c82fdfff86c91b4e07e3c1b80987d3d8d796ad23 100644 --- a/paddle/fluid/inference/analysis/analyzer.h +++ b/paddle/fluid/inference/analysis/analyzer.h @@ -16,28 +16,23 @@ limitations under the License. */ /* * This file contains Analyzer, an class that exposed as a library that analyze - * and optimize - * Fluid ProgramDesc for inference. Similar to LLVM, it has multiple flags to - * control whether - * an process is applied on the program. + * and optimize Fluid ProgramDesc for inference. Similar to LLVM, it has + * multiple flags to + * control whether an process is applied on the program. * * The processes are called Passes in analysis, the Passes are placed in a - * pipeline, the first - * Pass is the FluidToDataFlowGraphPass which transforms a Fluid ProgramDesc to - * a data flow - * graph, the last Pass is DataFlowGraphToFluidPass which transforms a data flow - * graph to a - * Fluid ProgramDesc. The passes in the middle of the pipeline can be any Passes - * which take a - * node or data flow graph as input. + * pipeline, the first Pass is the FluidToDataFlowGraphPass which transforms a + * Fluid ProgramDesc to + * a data flow graph, the last Pass is DataFlowGraphToFluidPass which transforms + * a data flow graph to a Fluid ProgramDesc. The passes in the middle of the + * pipeline can be any Passes + * which take a node or data flow graph as input. * * The Analyzer can be used in two methods, the first is a executable file which - * can be used to - * pre-process the inference model and can be controlled by passing difference - * command flags; + * can be used to pre-process the inference model and can be controlled by + * passing difference command flags; * the other way is to compose inside the inference API as a runtime pre-process - * phase in the - * inference service. + * phase in the inference service. */ #include <gflags/gflags.h> @@ -50,6 +45,7 @@ namespace paddle { // flag if not available. DECLARE_bool(inference_analysis_enable_tensorrt_subgraph_engine); DECLARE_string(inference_analysis_graphviz_log_root); +DECLARE_string(inference_analysis_output_storage_path); namespace inference { namespace analysis { diff --git a/paddle/fluid/inference/analysis/analyzer_main.cc b/paddle/fluid/inference/analysis/analyzer_main.cc new file mode 100644 index 0000000000000000000000000000000000000000..5e1fe3eb797cdced56a61aa2db0c3d18601824f8 --- /dev/null +++ b/paddle/fluid/inference/analysis/analyzer_main.cc @@ -0,0 +1,33 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/* + * This file implements analysizer -- an executation help to analyze and + * optimize trained model. + */ +#include "paddle/fluid/inference/analysis/analyzer.h" +#include <gflags/gflags.h> +#include <glog/logging.h> + +int main(int argc, char** argv) { + google::ParseCommandLineFlags(&argc, &argv, true); + using paddle::inference::analysis::Analyzer; + using paddle::inference::analysis::Argument; + + Argument argument; + Analyzer analyzer; + analyzer.Run(&argument); + + return 0; +} diff --git a/paddle/fluid/inference/analysis/analyzer_tester.cc b/paddle/fluid/inference/analysis/analyzer_tester.cc index 25a440e7e71fddb38cc515f99d15231675a8172e..24bfb3993cf569561980006b6627b56327dd0f67 100644 --- a/paddle/fluid/inference/analysis/analyzer_tester.cc +++ b/paddle/fluid/inference/analysis/analyzer_tester.cc @@ -20,14 +20,18 @@ namespace paddle { namespace inference { namespace analysis { -TEST_F(DFG_Tester, analysis_without_tensorrt) { +TEST(Analyzer, analysis_without_tensorrt) { FLAGS_inference_analysis_enable_tensorrt_subgraph_engine = false; + Argument argument; + argument.fluid_model_dir.reset(new std::string(FLAGS_inference_model_dir)); Analyzer analyser; analyser.Run(&argument); } -TEST_F(DFG_Tester, analysis_with_tensorrt) { +TEST(Analyzer, analysis_with_tensorrt) { FLAGS_inference_analysis_enable_tensorrt_subgraph_engine = true; + Argument argument; + argument.fluid_model_dir.reset(new std::string(FLAGS_inference_model_dir)); Analyzer analyser; analyser.Run(&argument); } diff --git a/paddle/fluid/inference/analysis/argument.h b/paddle/fluid/inference/analysis/argument.h index 6d316f20bff7a68754b0afec6463bd5d7579227f..9e1c2e45865a56efb60d4ec632ff3c52e23fedde 100644 --- a/paddle/fluid/inference/analysis/argument.h +++ b/paddle/fluid/inference/analysis/argument.h @@ -36,6 +36,16 @@ namespace analysis { * All the fields should be registered here for clearness. */ struct Argument { + Argument() = default; + explicit Argument(const std::string& fluid_model_dir) + : fluid_model_dir(new std::string(fluid_model_dir)) {} + // The directory of the trained model. + std::unique_ptr<std::string> fluid_model_dir; + // The path of `__model__` and `param`, this is used when the file name of + // model and param is changed. + std::unique_ptr<std::string> fluid_model_program_path; + std::unique_ptr<std::string> fluid_model_param_path; + // The graph that process by the Passes or PassManagers. std::unique_ptr<DataFlowGraph> main_dfg; @@ -44,6 +54,9 @@ struct Argument { // The processed program desc. std::unique_ptr<framework::proto::ProgramDesc> transformed_program_desc; + + // The output storage path of ModelStorePass. + std::unique_ptr<std::string> model_output_store_path; }; #define UNLIKELY(condition) __builtin_expect(static_cast<bool>(condition), 0) diff --git a/paddle/fluid/inference/analysis/data_flow_graph_tester.cc b/paddle/fluid/inference/analysis/data_flow_graph_tester.cc index 7912f8d7f17ae3c79e8f73f36b7095fd52c9ac86..a881262665f156812da9e1576aa29b05fc398499 100644 --- a/paddle/fluid/inference/analysis/data_flow_graph_tester.cc +++ b/paddle/fluid/inference/analysis/data_flow_graph_tester.cc @@ -20,7 +20,7 @@ namespace inference { namespace analysis { TEST(DataFlowGraph, BFS) { - auto desc = LoadProgramDesc(); + auto desc = LoadProgramDesc(FLAGS_inference_model_dir + "/__model__"); auto dfg = ProgramDescToDFG(desc); dfg.Build(); @@ -44,7 +44,7 @@ TEST(DataFlowGraph, BFS) { } TEST(DataFlowGraph, DFS) { - auto desc = LoadProgramDesc(); + auto desc = LoadProgramDesc(FLAGS_inference_model_dir + "/__model__"); auto dfg = ProgramDescToDFG(desc); dfg.Build(); GraphTraits<DataFlowGraph> trait(&dfg); diff --git a/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass_tester.cc b/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass_tester.cc index d8fc5e580a98f76233f01fdc4d7987311f78ee45..4ef381db295b986b91173a728b6d98640f6f4f51 100644 --- a/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass_tester.cc +++ b/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass_tester.cc @@ -26,21 +26,21 @@ namespace paddle { namespace inference { namespace analysis { -TEST_F(DFG_Tester, Test) { - DataFlowGraph graph; +TEST(DataFlowGraph, Test) { + Argument argument(FLAGS_inference_model_dir); FluidToDataFlowGraphPass pass0; DataFlowGraphToFluidPass pass1; ASSERT_TRUE(pass0.Initialize(&argument)); ASSERT_TRUE(pass1.Initialize(&argument)); - pass0.Run(&graph); - pass1.Run(&graph); + pass0.Run(argument.main_dfg.get()); + pass1.Run(argument.main_dfg.get()); pass0.Finalize(); pass1.Finalize(); - LOG(INFO) << graph.nodes.size(); + LOG(INFO) << argument.main_dfg->nodes.size(); } }; // namespace analysis diff --git a/paddle/fluid/inference/analysis/dfg_graphviz_draw_pass_tester.cc b/paddle/fluid/inference/analysis/dfg_graphviz_draw_pass_tester.cc index 65842b1e850953e77e3d4d28416609be271af9f1..928be7917047382d9b86294f6039b26b0ebf6f49 100644 --- a/paddle/fluid/inference/analysis/dfg_graphviz_draw_pass_tester.cc +++ b/paddle/fluid/inference/analysis/dfg_graphviz_draw_pass_tester.cc @@ -23,12 +23,18 @@ namespace paddle { namespace inference { namespace analysis { -TEST_F(DFG_Tester, dfg_graphviz_draw_pass_tester) { - auto dfg = ProgramDescToDFG(*argument.origin_program_desc); +TEST(DFG_GraphvizDrawPass, dfg_graphviz_draw_pass_tester) { + Argument argument(FLAGS_inference_model_dir); + FluidToDataFlowGraphPass pass0; + ASSERT_TRUE(pass0.Initialize(&argument)); + pass0.Run(argument.main_dfg.get()); + + // auto dfg = ProgramDescToDFG(*argument.origin_program_desc); + DFG_GraphvizDrawPass::Config config("./", "test"); DFG_GraphvizDrawPass pass(config); pass.Initialize(&argument); - pass.Run(&dfg); + pass.Run(argument.main_dfg.get()); // test content std::ifstream file("./0-graph_test.dot"); diff --git a/paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.cc b/paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.cc index 88fdf8c9cb4ce5369d70d416bbcfe6a4c7f23a98..511631d3e067f14bc1230d9e4b4d92dbe604e1d4 100644 --- a/paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.cc +++ b/paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.cc @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include <glog/logging.h> #include <string> #include <vector> @@ -25,8 +26,20 @@ namespace analysis { bool FluidToDataFlowGraphPass::Initialize(Argument *argument) { ANALYSIS_ARGUMENT_CHECK_FIELD(argument); - ANALYSIS_ARGUMENT_CHECK_FIELD(argument->origin_program_desc); - PADDLE_ENFORCE(argument); + if (argument->origin_program_desc) { + LOG(WARNING) << "argument's origin_program_desc is already set, might " + "duplicate called"; + } + if (!argument->fluid_model_program_path) { + ANALYSIS_ARGUMENT_CHECK_FIELD(argument->fluid_model_dir); + argument->fluid_model_program_path.reset( + new std::string(*argument->fluid_model_dir + "/__model__")); + } + ANALYSIS_ARGUMENT_CHECK_FIELD(argument->fluid_model_program_path); + auto program = LoadProgramDesc(*argument->fluid_model_program_path); + argument->origin_program_desc.reset( + new framework::proto::ProgramDesc(program)); + if (!argument->main_dfg) { argument->main_dfg.reset(new DataFlowGraph); } diff --git a/paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass_tester.cc b/paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass_tester.cc index dadb84059d21adab44159a6145b345460663cb96..d218dcd05015aa4636c16569de4addf4936c8cd5 100644 --- a/paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass_tester.cc +++ b/paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass_tester.cc @@ -21,8 +21,9 @@ namespace paddle { namespace inference { namespace analysis { -TEST_F(DFG_Tester, Init) { +TEST(FluidToDataFlowGraphPass, Test) { FluidToDataFlowGraphPass pass; + Argument argument(FLAGS_inference_model_dir); pass.Initialize(&argument); pass.Run(argument.main_dfg.get()); // Analysis is sensitive to ProgramDesc, careful to change the original model. diff --git a/paddle/fluid/inference/analysis/helper.h b/paddle/fluid/inference/analysis/helper.h index f1064cd20f28092d80d3fd23a862da080b6cc2f3..a0f912b251d5ea29594a7f601d5b2bce91201790 100644 --- a/paddle/fluid/inference/analysis/helper.h +++ b/paddle/fluid/inference/analysis/helper.h @@ -15,6 +15,7 @@ limitations under the License. */ #pragma once #include <cstdio> +#include <fstream> #include <string> #include <typeindex> #include <unordered_map> @@ -136,6 +137,20 @@ static void ExecShellCommand(const std::string &cmd, std::string *message) { } } +static framework::proto::ProgramDesc LoadProgramDesc( + const std::string &model_path) { + std::ifstream fin(model_path, std::ios::in | std::ios::binary); + PADDLE_ENFORCE(fin.is_open(), "Cannot open file %s", model_path); + fin.seekg(0, std::ios::end); + std::string buffer(fin.tellg(), ' '); + fin.seekg(0, std::ios::beg); + fin.read(&buffer[0], buffer.size()); + fin.close(); + framework::proto::ProgramDesc program_desc; + program_desc.ParseFromString(buffer); + return program_desc; +} + } // namespace analysis } // namespace inference } // namespace paddle diff --git a/paddle/fluid/inference/analysis/model_store_pass.cc b/paddle/fluid/inference/analysis/model_store_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..db7be3c0cde12c90ca698c13d4f3564d8b66ee40 --- /dev/null +++ b/paddle/fluid/inference/analysis/model_store_pass.cc @@ -0,0 +1,61 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/inference/analysis/model_store_pass.h" +#include <stdio.h> +#include <stdlib.h> +#include "paddle/fluid/inference/analysis/analyzer.h" +#include "paddle/fluid/inference/analysis/argument.h" + +namespace paddle { +namespace inference { +namespace analysis { + +void ModelStorePass::Run(DataFlowGraph *x) { + if (!argument_->fluid_model_param_path) { + PADDLE_ENFORCE_NOT_NULL(argument_->fluid_model_dir); + argument_->fluid_model_param_path.reset( + new std::string(*argument_->fluid_model_dir + "param")); + } + PADDLE_ENFORCE_NOT_NULL(argument_->model_output_store_path); + // Directly copy param file to destination. + std::stringstream ss; + // NOTE these commands only works on linux. + ss << "mkdir -p " << *argument_->model_output_store_path; + LOG(INFO) << "run command: " << ss.str(); + PADDLE_ENFORCE_EQ(system(ss.str().c_str()), 0); + ss.str(""); + + ss << "cp " << *argument_->fluid_model_dir << "/*" + << " " << *argument_->model_output_store_path; + LOG(INFO) << "run command: " << ss.str(); + PADDLE_ENFORCE_EQ(system(ss.str().c_str()), 0); + + // Store program + PADDLE_ENFORCE_NOT_NULL(argument_->transformed_program_desc, + "program desc is not transformed, should call " + "DataFlowGraphToFluidPass first."); + const std::string program_output_path = + *argument_->model_output_store_path + "/__model__"; + std::ofstream file(program_output_path, std::ios::binary); + PADDLE_ENFORCE(file.is_open(), "failed to open %s to write.", + program_output_path); + const std::string serialized_message = + argument_->transformed_program_desc->SerializeAsString(); + file.write(serialized_message.c_str(), serialized_message.size()); +} + +} // namespace analysis +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/analysis/model_store_pass.h b/paddle/fluid/inference/analysis/model_store_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..713e8783eac3e9294dd22622e42deb50fd432082 --- /dev/null +++ b/paddle/fluid/inference/analysis/model_store_pass.h @@ -0,0 +1,51 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/* + * This file defines ModelStorePass, which store the runtime DFG to a Paddle + * model in the disk, and that model can be reloaded for prediction. + */ + +#include "paddle/fluid/inference/analysis/pass.h" + +namespace paddle { +namespace inference { +namespace analysis { + +class ModelStorePass : public DataFlowGraphPass { + public: + bool Initialize(Argument* argument) override { + if (!argument) { + LOG(ERROR) << "invalid argument"; + return false; + } + argument_ = argument; + return true; + } + + void Run(DataFlowGraph* x) override; + + std::string repr() const override { return "DFG-store-pass"; } + std::string description() const override { + return R"DD(This file defines ModelStorePass, which store the runtime DFG to a Paddle + model in the disk, and that model can be reloaded for prediction again.)DD"; + } + + private: + Argument* argument_{nullptr}; +}; + +} // namespace analysis +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/analysis/model_store_pass_tester.cc b/paddle/fluid/inference/analysis/model_store_pass_tester.cc new file mode 100644 index 0000000000000000000000000000000000000000..5f3526dd504e77e58d79b4f675db86a22fd0f26b --- /dev/null +++ b/paddle/fluid/inference/analysis/model_store_pass_tester.cc @@ -0,0 +1,43 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/inference/analysis/model_store_pass.h" + +#include <gflags/gflags.h> +#include <gtest/gtest.h> +#include "paddle/fluid/inference/analysis/analyzer.h" + +namespace paddle { +namespace inference { +namespace analysis { + +DEFINE_string(inference_model_dir, "", "Model path"); + +TEST(DFG_StorePass, test) { + Analyzer analyzer; + Argument argument(FLAGS_inference_model_dir); + argument.model_output_store_path.reset( + new std::string("./_dfg_store_pass_tmp")); + // disable storage in alalyzer + FLAGS_inference_analysis_output_storage_path = ""; + analyzer.Run(&argument); + + ModelStorePass pass; + pass.Initialize(&argument); + pass.Run(argument.main_dfg.get()); +} + +} // namespace analysis +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/analysis/pass.h b/paddle/fluid/inference/analysis/pass.h index 6b4dbb3bb5ddd9f15f26758beef1d1b5bbf49142..6806f9ff7dada2c1e2328e1ffbfd225afefcf474 100644 --- a/paddle/fluid/inference/analysis/pass.h +++ b/paddle/fluid/inference/analysis/pass.h @@ -50,6 +50,7 @@ class Pass { // Create a debugger Pass that draw the DFG by graphviz toolkit. virtual Pass *CreateGraphvizDebugerPass() const { return nullptr; } + virtual void Run() { LOG(FATAL) << "not valid"; } // Run on a single Node. virtual void Run(Node *x) { LOG(FATAL) << "not valid"; } // Run on a single Function. diff --git a/paddle/fluid/inference/analysis/pass_manager_tester.cc b/paddle/fluid/inference/analysis/pass_manager_tester.cc index dac1c509d728114bd24a2ea1150c407646026fd4..13423e4837e12a96e7a5dfc9ca3f59bf8b14746a 100644 --- a/paddle/fluid/inference/analysis/pass_manager_tester.cc +++ b/paddle/fluid/inference/analysis/pass_manager_tester.cc @@ -56,7 +56,7 @@ class TestNodePass final : public NodePass { std::string description() const override { return "some doc"; } }; -TEST_F(DFG_Tester, DFG_pass_manager) { +TEST(PassManager, DFG_pass_manager) { TestDfgPassManager manager; DFG_GraphvizDrawPass::Config config("./", "dfg.dot"); @@ -64,12 +64,15 @@ TEST_F(DFG_Tester, DFG_pass_manager) { manager.Register("graphviz", new DFG_GraphvizDrawPass(config)); manager.Register("dfg-to-fluid", new DataFlowGraphToFluidPass); + Argument argument(FLAGS_inference_model_dir); + ASSERT_TRUE(&argument); ASSERT_TRUE(manager.Initialize(&argument)); manager.RunAll(); } -TEST_F(DFG_Tester, Node_pass_manager) { +TEST(PassManager, Node_pass_manager) { + Argument argument(FLAGS_inference_model_dir); // Pre-process: initialize the DFG with the ProgramDesc first. FluidToDataFlowGraphPass pass0; pass0.Initialize(&argument); diff --git a/paddle/fluid/inference/analysis/subgraph_splitter_tester.cc b/paddle/fluid/inference/analysis/subgraph_splitter_tester.cc index 67dd4da54b95add703428e1fded61065f60353e8..39cc433b40fad17f4f12359d4e907a250a88bd63 100644 --- a/paddle/fluid/inference/analysis/subgraph_splitter_tester.cc +++ b/paddle/fluid/inference/analysis/subgraph_splitter_tester.cc @@ -31,8 +31,8 @@ SubGraphSplitter::NodeInsideSubgraphTeller teller = [](const Node* node) { return false; }; -TEST_F(DFG_Tester, Split) { - auto desc = LoadProgramDesc(); +TEST(SubGraphSplitter, Split) { + auto desc = LoadProgramDesc(FLAGS_inference_model_dir + "/__model__"); auto dfg = ProgramDescToDFG(desc); LOG(INFO) << "spliter\n" << dfg.DotString(); @@ -63,8 +63,8 @@ TEST_F(DFG_Tester, Split) { ASSERT_EQ(subgraphs.back().size(), 6UL); } -TEST_F(DFG_Tester, Fuse) { - auto desc = LoadProgramDesc(); +TEST(SubGraphSplitter, Fuse) { + auto desc = LoadProgramDesc(FLAGS_inference_model_dir + "/__model__"); auto dfg = ProgramDescToDFG(desc); size_t count0 = dfg.nodes.size(); diff --git a/paddle/fluid/inference/analysis/tensorrt_subgraph_node_mark_pass_tester.cc b/paddle/fluid/inference/analysis/tensorrt_subgraph_node_mark_pass_tester.cc index a6c15e848b99ca318f4583e3d4b88345fe8e5ebc..c1d932878e559180af987594535959afdf475587 100644 --- a/paddle/fluid/inference/analysis/tensorrt_subgraph_node_mark_pass_tester.cc +++ b/paddle/fluid/inference/analysis/tensorrt_subgraph_node_mark_pass_tester.cc @@ -22,11 +22,11 @@ namespace paddle { namespace inference { namespace analysis { -TEST_F(DFG_Tester, tensorrt_subgraph_node_mark_pass) { +TEST(TensorRTSubgraphNodeMarkPass, test) { // init FluidToDataFlowGraphPass pass; + Argument argument(FLAGS_inference_model_dir); ASSERT_TRUE(pass.Initialize(&argument)); - argument.main_dfg.reset(new DataFlowGraph); pass.Run(argument.main_dfg.get()); TensorRTSubgraphNodeMarkPass::teller_t teller = [](const Node* node) { @@ -41,7 +41,7 @@ TEST_F(DFG_Tester, tensorrt_subgraph_node_mark_pass) { for (auto& node : argument.main_dfg->nodes.nodes()) { counter += node->attr(ATTR_supported_by_tensorrt).Bool(); } - + ASSERT_EQ(counter, 2); LOG(INFO) << counter << " nodes marked"; } diff --git a/paddle/fluid/inference/analysis/tensorrt_subgraph_pass_tester.cc b/paddle/fluid/inference/analysis/tensorrt_subgraph_pass_tester.cc index 1d749d3fa3f39b351ccee6ebeb82467f7220a0b6..67a5af83d89b771536ea11be51b35244ff5c09d6 100644 --- a/paddle/fluid/inference/analysis/tensorrt_subgraph_pass_tester.cc +++ b/paddle/fluid/inference/analysis/tensorrt_subgraph_pass_tester.cc @@ -25,7 +25,7 @@ namespace analysis { DEFINE_string(dot_dir, "./", ""); -TEST_F(DFG_Tester, tensorrt_single_pass) { +TEST(TensorRTSubGraphPass, main) { std::unordered_set<std::string> teller_set( {"elementwise_add", "mul", "sigmoid"}); SubGraphSplitter::NodeInsideSubgraphTeller teller = [&](const Node* node) { @@ -35,7 +35,8 @@ TEST_F(DFG_Tester, tensorrt_single_pass) { return false; }; - LOG(INFO) << "init"; + Argument argument(FLAGS_inference_model_dir); + DFG_GraphvizDrawPass::Config config{FLAGS_dot_dir, "origin"}; DFG_GraphvizDrawPass::Config config1{FLAGS_dot_dir, "fusion"}; @@ -44,13 +45,11 @@ TEST_F(DFG_Tester, tensorrt_single_pass) { FluidToDataFlowGraphPass pass0; TensorRTSubGraphPass trt_pass(std::move(teller)); - LOG(INFO) << "Initialize"; dfg_pass.Initialize(&argument); dfg_pass1.Initialize(&argument); pass0.Initialize(&argument); trt_pass.Initialize(&argument); - LOG(INFO) << "Run"; argument.main_dfg.reset(new DataFlowGraph); pass0.Run(argument.main_dfg.get()); dfg_pass.Run(argument.main_dfg.get()); diff --git a/paddle/fluid/inference/analysis/ut_helper.h b/paddle/fluid/inference/analysis/ut_helper.h index ce1191a567a4198f003520c40bf02487c48c56eb..1073a6f686eaeeaaae2d93ab044149b7df518085 100644 --- a/paddle/fluid/inference/analysis/ut_helper.h +++ b/paddle/fluid/inference/analysis/ut_helper.h @@ -20,7 +20,7 @@ limitations under the License. */ #include "paddle/fluid/framework/executor.h" #include "paddle/fluid/inference/analysis/data_flow_graph.h" #include "paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h" -#include "paddle/fluid/inference/analysis/ut_helper.h" +#include "paddle/fluid/inference/analysis/helper.h" namespace paddle { namespace inference { @@ -32,27 +32,12 @@ namespace analysis { DEFINE_string(inference_model_dir, "", "inference test model dir"); -static framework::proto::ProgramDesc LoadProgramDesc( - const std::string& model_dir = FLAGS_inference_model_dir) { - std::string msg; - std::string net_file = FLAGS_inference_model_dir + "/__model__"; - std::ifstream fin(net_file, std::ios::in | std::ios::binary); - PADDLE_ENFORCE(static_cast<bool>(fin), "Cannot open file %s", net_file); - fin.seekg(0, std::ios::end); - msg.resize(fin.tellg()); - fin.seekg(0, std::ios::beg); - fin.read(&(msg.at(0)), msg.size()); - fin.close(); - framework::proto::ProgramDesc program_desc; - program_desc.ParseFromString(msg); - return program_desc; -} - static DataFlowGraph ProgramDescToDFG( const framework::proto::ProgramDesc& desc) { DataFlowGraph graph; FluidToDataFlowGraphPass pass; Argument argument; + argument.fluid_model_dir.reset(new std::string(FLAGS_inference_model_dir)); argument.origin_program_desc.reset(new framework::proto::ProgramDesc(desc)); pass.Initialize(&argument); pass.Run(&graph); @@ -63,7 +48,7 @@ static DataFlowGraph ProgramDescToDFG( class DFG_Tester : public ::testing::Test { protected: void SetUp() override { - auto desc = LoadProgramDesc(FLAGS_inference_model_dir); + auto desc = LoadProgramDesc(FLAGS_inference_model_dir + "/__model__"); argument.origin_program_desc.reset(new framework::proto::ProgramDesc(desc)); } diff --git a/paddle/fluid/inference/api/api_tensorrt_subgraph_engine.cc b/paddle/fluid/inference/api/api_tensorrt_subgraph_engine.cc index c0891e9c281961fa03d278a0f5c676f92672c419..45b5a7638b7dc6a54bbd905766fd5c284cb6aea1 100644 --- a/paddle/fluid/inference/api/api_tensorrt_subgraph_engine.cc +++ b/paddle/fluid/inference/api/api_tensorrt_subgraph_engine.cc @@ -90,6 +90,18 @@ class TensorRTSubgraphPredictor : public NativePaddlePredictor { void OptimizeInferenceProgram() { // Analyze inference_program Argument argument; + if (!config_.model_dir.empty()) { + argument.fluid_model_dir.reset(new std::string(config_.model_dir)); + } else { + PADDLE_ENFORCE( + !config_.param_file.empty(), + "Either model_dir or (param_file, prog_file) should be set."); + PADDLE_ENFORCE(!config_.prog_file.empty()); + argument.fluid_model_program_path.reset( + new std::string(config_.prog_file)); + argument.fluid_model_param_path.reset( + new std::string(config_.param_file)); + } argument.origin_program_desc.reset( new ProgramDesc(*inference_program_->Proto())); Singleton<Analyzer>::Global().Run(&argument);