未验证 提交 dcfbc6a6 编写于 作者: Y Yan Chunwei 提交者: GitHub

inference analyzer as bin (#12450)

上级 31a2c876
...@@ -6,9 +6,11 @@ cc_library(analysis SRCS pass_manager.cc dot.cc node.cc data_flow_graph.cc graph ...@@ -6,9 +6,11 @@ cc_library(analysis SRCS pass_manager.cc dot.cc node.cc data_flow_graph.cc graph
tensorrt_subgraph_node_mark_pass.cc tensorrt_subgraph_node_mark_pass.cc
analyzer.cc analyzer.cc
helper.cc helper.cc
model_store_pass.cc
DEPS framework_proto proto_desc) DEPS framework_proto proto_desc)
cc_test(test_node SRCS node_tester.cc DEPS analysis) cc_test(test_node SRCS node_tester.cc DEPS analysis)
cc_test(test_dot SRCS dot_tester.cc DEPS analysis) cc_test(test_dot SRCS dot_tester.cc DEPS analysis)
cc_binary(inference_analyzer SRCS analyzer_main.cc DEPS analysis)
set(PYTHON_TESTS_DIR ${PADDLE_BINARY_DIR}/python/paddle/fluid/tests) set(PYTHON_TESTS_DIR ${PADDLE_BINARY_DIR}/python/paddle/fluid/tests)
...@@ -40,3 +42,4 @@ inference_analysis_test(test_tensorrt_subgraph_pass SRCS tensorrt_subgraph_pass_ ...@@ -40,3 +42,4 @@ inference_analysis_test(test_tensorrt_subgraph_pass SRCS tensorrt_subgraph_pass_
inference_analysis_test(test_pass_manager SRCS pass_manager_tester.cc) inference_analysis_test(test_pass_manager SRCS pass_manager_tester.cc)
inference_analysis_test(test_tensorrt_subgraph_node_mark_pass SRCS tensorrt_subgraph_node_mark_pass_tester.cc) inference_analysis_test(test_tensorrt_subgraph_node_mark_pass SRCS tensorrt_subgraph_node_mark_pass_tester.cc)
inference_analysis_test(test_analyzer SRCS analyzer_tester.cc) inference_analysis_test(test_analyzer SRCS analyzer_tester.cc)
inference_analysis_test(test_model_store_pass SRCS model_store_pass_tester.cc)
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include "paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.h" #include "paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.h"
#include "paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.h" #include "paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.h"
#include "paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h" #include "paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h"
#include "paddle/fluid/inference/analysis/model_store_pass.h"
#include "paddle/fluid/inference/analysis/pass_manager.h" #include "paddle/fluid/inference/analysis/pass_manager.h"
#include "paddle/fluid/inference/analysis/tensorrt_subgraph_node_mark_pass.h" #include "paddle/fluid/inference/analysis/tensorrt_subgraph_node_mark_pass.h"
#include "paddle/fluid/inference/analysis/tensorrt_subgraph_pass.h" #include "paddle/fluid/inference/analysis/tensorrt_subgraph_pass.h"
...@@ -29,6 +30,9 @@ DEFINE_bool(inference_analysis_enable_tensorrt_subgraph_engine, false, ...@@ -29,6 +30,9 @@ DEFINE_bool(inference_analysis_enable_tensorrt_subgraph_engine, false,
DEFINE_string(inference_analysis_graphviz_log_root, "./", DEFINE_string(inference_analysis_graphviz_log_root, "./",
"Graphviz debuger for data flow graphs."); "Graphviz debuger for data flow graphs.");
DEFINE_string(inference_analysis_output_storage_path, "",
"optimized model output path");
namespace inference { namespace inference {
namespace analysis { namespace analysis {
...@@ -47,6 +51,9 @@ class DfgPassManagerImpl final : public DfgPassManager { ...@@ -47,6 +51,9 @@ class DfgPassManagerImpl final : public DfgPassManager {
AddPass("tensorrt-subgraph", new TensorRTSubGraphPass(trt_teller)); AddPass("tensorrt-subgraph", new TensorRTSubGraphPass(trt_teller));
} }
AddPass("data-flow-graph-to-fluid", new DataFlowGraphToFluidPass); AddPass("data-flow-graph-to-fluid", new DataFlowGraphToFluidPass);
if (!FLAGS_inference_analysis_output_storage_path.empty()) {
AddPass("model-store-pass", new ModelStorePass);
}
} }
std::string repr() const override { return "dfg-pass-manager"; } std::string repr() const override { return "dfg-pass-manager"; }
......
...@@ -16,28 +16,23 @@ limitations under the License. */ ...@@ -16,28 +16,23 @@ limitations under the License. */
/* /*
* This file contains Analyzer, an class that exposed as a library that analyze * This file contains Analyzer, an class that exposed as a library that analyze
* and optimize * and optimize Fluid ProgramDesc for inference. Similar to LLVM, it has
* Fluid ProgramDesc for inference. Similar to LLVM, it has multiple flags to * multiple flags to
* control whether * control whether an process is applied on the program.
* an process is applied on the program.
* *
* The processes are called Passes in analysis, the Passes are placed in a * The processes are called Passes in analysis, the Passes are placed in a
* pipeline, the first * pipeline, the first Pass is the FluidToDataFlowGraphPass which transforms a
* Pass is the FluidToDataFlowGraphPass which transforms a Fluid ProgramDesc to * Fluid ProgramDesc to
* a data flow * a data flow graph, the last Pass is DataFlowGraphToFluidPass which transforms
* graph, the last Pass is DataFlowGraphToFluidPass which transforms a data flow * a data flow graph to a Fluid ProgramDesc. The passes in the middle of the
* graph to a * pipeline can be any Passes
* Fluid ProgramDesc. The passes in the middle of the pipeline can be any Passes * which take a node or data flow graph as input.
* which take a
* node or data flow graph as input.
* *
* The Analyzer can be used in two methods, the first is a executable file which * The Analyzer can be used in two methods, the first is a executable file which
* can be used to * can be used to pre-process the inference model and can be controlled by
* pre-process the inference model and can be controlled by passing difference * passing difference command flags;
* command flags;
* the other way is to compose inside the inference API as a runtime pre-process * the other way is to compose inside the inference API as a runtime pre-process
* phase in the * phase in the inference service.
* inference service.
*/ */
#include <gflags/gflags.h> #include <gflags/gflags.h>
...@@ -50,6 +45,7 @@ namespace paddle { ...@@ -50,6 +45,7 @@ namespace paddle {
// flag if not available. // flag if not available.
DECLARE_bool(inference_analysis_enable_tensorrt_subgraph_engine); DECLARE_bool(inference_analysis_enable_tensorrt_subgraph_engine);
DECLARE_string(inference_analysis_graphviz_log_root); DECLARE_string(inference_analysis_graphviz_log_root);
DECLARE_string(inference_analysis_output_storage_path);
namespace inference { namespace inference {
namespace analysis { namespace analysis {
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*
* This file implements analysizer -- an executation help to analyze and
* optimize trained model.
*/
#include "paddle/fluid/inference/analysis/analyzer.h"
#include <gflags/gflags.h>
#include <glog/logging.h>
int main(int argc, char** argv) {
google::ParseCommandLineFlags(&argc, &argv, true);
using paddle::inference::analysis::Analyzer;
using paddle::inference::analysis::Argument;
Argument argument;
Analyzer analyzer;
analyzer.Run(&argument);
return 0;
}
...@@ -20,14 +20,18 @@ namespace paddle { ...@@ -20,14 +20,18 @@ namespace paddle {
namespace inference { namespace inference {
namespace analysis { namespace analysis {
TEST_F(DFG_Tester, analysis_without_tensorrt) { TEST(Analyzer, analysis_without_tensorrt) {
FLAGS_inference_analysis_enable_tensorrt_subgraph_engine = false; FLAGS_inference_analysis_enable_tensorrt_subgraph_engine = false;
Argument argument;
argument.fluid_model_dir.reset(new std::string(FLAGS_inference_model_dir));
Analyzer analyser; Analyzer analyser;
analyser.Run(&argument); analyser.Run(&argument);
} }
TEST_F(DFG_Tester, analysis_with_tensorrt) { TEST(Analyzer, analysis_with_tensorrt) {
FLAGS_inference_analysis_enable_tensorrt_subgraph_engine = true; FLAGS_inference_analysis_enable_tensorrt_subgraph_engine = true;
Argument argument;
argument.fluid_model_dir.reset(new std::string(FLAGS_inference_model_dir));
Analyzer analyser; Analyzer analyser;
analyser.Run(&argument); analyser.Run(&argument);
} }
......
...@@ -36,6 +36,16 @@ namespace analysis { ...@@ -36,6 +36,16 @@ namespace analysis {
* All the fields should be registered here for clearness. * All the fields should be registered here for clearness.
*/ */
struct Argument { struct Argument {
Argument() = default;
explicit Argument(const std::string& fluid_model_dir)
: fluid_model_dir(new std::string(fluid_model_dir)) {}
// The directory of the trained model.
std::unique_ptr<std::string> fluid_model_dir;
// The path of `__model__` and `param`, this is used when the file name of
// model and param is changed.
std::unique_ptr<std::string> fluid_model_program_path;
std::unique_ptr<std::string> fluid_model_param_path;
// The graph that process by the Passes or PassManagers. // The graph that process by the Passes or PassManagers.
std::unique_ptr<DataFlowGraph> main_dfg; std::unique_ptr<DataFlowGraph> main_dfg;
...@@ -44,6 +54,9 @@ struct Argument { ...@@ -44,6 +54,9 @@ struct Argument {
// The processed program desc. // The processed program desc.
std::unique_ptr<framework::proto::ProgramDesc> transformed_program_desc; std::unique_ptr<framework::proto::ProgramDesc> transformed_program_desc;
// The output storage path of ModelStorePass.
std::unique_ptr<std::string> model_output_store_path;
}; };
#define UNLIKELY(condition) __builtin_expect(static_cast<bool>(condition), 0) #define UNLIKELY(condition) __builtin_expect(static_cast<bool>(condition), 0)
......
...@@ -20,7 +20,7 @@ namespace inference { ...@@ -20,7 +20,7 @@ namespace inference {
namespace analysis { namespace analysis {
TEST(DataFlowGraph, BFS) { TEST(DataFlowGraph, BFS) {
auto desc = LoadProgramDesc(); auto desc = LoadProgramDesc(FLAGS_inference_model_dir + "/__model__");
auto dfg = ProgramDescToDFG(desc); auto dfg = ProgramDescToDFG(desc);
dfg.Build(); dfg.Build();
...@@ -44,7 +44,7 @@ TEST(DataFlowGraph, BFS) { ...@@ -44,7 +44,7 @@ TEST(DataFlowGraph, BFS) {
} }
TEST(DataFlowGraph, DFS) { TEST(DataFlowGraph, DFS) {
auto desc = LoadProgramDesc(); auto desc = LoadProgramDesc(FLAGS_inference_model_dir + "/__model__");
auto dfg = ProgramDescToDFG(desc); auto dfg = ProgramDescToDFG(desc);
dfg.Build(); dfg.Build();
GraphTraits<DataFlowGraph> trait(&dfg); GraphTraits<DataFlowGraph> trait(&dfg);
......
...@@ -26,21 +26,21 @@ namespace paddle { ...@@ -26,21 +26,21 @@ namespace paddle {
namespace inference { namespace inference {
namespace analysis { namespace analysis {
TEST_F(DFG_Tester, Test) { TEST(DataFlowGraph, Test) {
DataFlowGraph graph; Argument argument(FLAGS_inference_model_dir);
FluidToDataFlowGraphPass pass0; FluidToDataFlowGraphPass pass0;
DataFlowGraphToFluidPass pass1; DataFlowGraphToFluidPass pass1;
ASSERT_TRUE(pass0.Initialize(&argument)); ASSERT_TRUE(pass0.Initialize(&argument));
ASSERT_TRUE(pass1.Initialize(&argument)); ASSERT_TRUE(pass1.Initialize(&argument));
pass0.Run(&graph); pass0.Run(argument.main_dfg.get());
pass1.Run(&graph); pass1.Run(argument.main_dfg.get());
pass0.Finalize(); pass0.Finalize();
pass1.Finalize(); pass1.Finalize();
LOG(INFO) << graph.nodes.size(); LOG(INFO) << argument.main_dfg->nodes.size();
} }
}; // namespace analysis }; // namespace analysis
......
...@@ -23,12 +23,18 @@ namespace paddle { ...@@ -23,12 +23,18 @@ namespace paddle {
namespace inference { namespace inference {
namespace analysis { namespace analysis {
TEST_F(DFG_Tester, dfg_graphviz_draw_pass_tester) { TEST(DFG_GraphvizDrawPass, dfg_graphviz_draw_pass_tester) {
auto dfg = ProgramDescToDFG(*argument.origin_program_desc); Argument argument(FLAGS_inference_model_dir);
FluidToDataFlowGraphPass pass0;
ASSERT_TRUE(pass0.Initialize(&argument));
pass0.Run(argument.main_dfg.get());
// auto dfg = ProgramDescToDFG(*argument.origin_program_desc);
DFG_GraphvizDrawPass::Config config("./", "test"); DFG_GraphvizDrawPass::Config config("./", "test");
DFG_GraphvizDrawPass pass(config); DFG_GraphvizDrawPass pass(config);
pass.Initialize(&argument); pass.Initialize(&argument);
pass.Run(&dfg); pass.Run(argument.main_dfg.get());
// test content // test content
std::ifstream file("./0-graph_test.dot"); std::ifstream file("./0-graph_test.dot");
......
...@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <glog/logging.h>
#include <string> #include <string>
#include <vector> #include <vector>
...@@ -25,8 +26,20 @@ namespace analysis { ...@@ -25,8 +26,20 @@ namespace analysis {
bool FluidToDataFlowGraphPass::Initialize(Argument *argument) { bool FluidToDataFlowGraphPass::Initialize(Argument *argument) {
ANALYSIS_ARGUMENT_CHECK_FIELD(argument); ANALYSIS_ARGUMENT_CHECK_FIELD(argument);
ANALYSIS_ARGUMENT_CHECK_FIELD(argument->origin_program_desc); if (argument->origin_program_desc) {
PADDLE_ENFORCE(argument); LOG(WARNING) << "argument's origin_program_desc is already set, might "
"duplicate called";
}
if (!argument->fluid_model_program_path) {
ANALYSIS_ARGUMENT_CHECK_FIELD(argument->fluid_model_dir);
argument->fluid_model_program_path.reset(
new std::string(*argument->fluid_model_dir + "/__model__"));
}
ANALYSIS_ARGUMENT_CHECK_FIELD(argument->fluid_model_program_path);
auto program = LoadProgramDesc(*argument->fluid_model_program_path);
argument->origin_program_desc.reset(
new framework::proto::ProgramDesc(program));
if (!argument->main_dfg) { if (!argument->main_dfg) {
argument->main_dfg.reset(new DataFlowGraph); argument->main_dfg.reset(new DataFlowGraph);
} }
......
...@@ -21,8 +21,9 @@ namespace paddle { ...@@ -21,8 +21,9 @@ namespace paddle {
namespace inference { namespace inference {
namespace analysis { namespace analysis {
TEST_F(DFG_Tester, Init) { TEST(FluidToDataFlowGraphPass, Test) {
FluidToDataFlowGraphPass pass; FluidToDataFlowGraphPass pass;
Argument argument(FLAGS_inference_model_dir);
pass.Initialize(&argument); pass.Initialize(&argument);
pass.Run(argument.main_dfg.get()); pass.Run(argument.main_dfg.get());
// Analysis is sensitive to ProgramDesc, careful to change the original model. // Analysis is sensitive to ProgramDesc, careful to change the original model.
......
...@@ -15,6 +15,7 @@ limitations under the License. */ ...@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once #pragma once
#include <cstdio> #include <cstdio>
#include <fstream>
#include <string> #include <string>
#include <typeindex> #include <typeindex>
#include <unordered_map> #include <unordered_map>
...@@ -136,6 +137,20 @@ static void ExecShellCommand(const std::string &cmd, std::string *message) { ...@@ -136,6 +137,20 @@ static void ExecShellCommand(const std::string &cmd, std::string *message) {
} }
} }
static framework::proto::ProgramDesc LoadProgramDesc(
const std::string &model_path) {
std::ifstream fin(model_path, std::ios::in | std::ios::binary);
PADDLE_ENFORCE(fin.is_open(), "Cannot open file %s", model_path);
fin.seekg(0, std::ios::end);
std::string buffer(fin.tellg(), ' ');
fin.seekg(0, std::ios::beg);
fin.read(&buffer[0], buffer.size());
fin.close();
framework::proto::ProgramDesc program_desc;
program_desc.ParseFromString(buffer);
return program_desc;
}
} // namespace analysis } // namespace analysis
} // namespace inference } // namespace inference
} // namespace paddle } // namespace paddle
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/analysis/model_store_pass.h"
#include <stdio.h>
#include <stdlib.h>
#include "paddle/fluid/inference/analysis/analyzer.h"
#include "paddle/fluid/inference/analysis/argument.h"
namespace paddle {
namespace inference {
namespace analysis {
void ModelStorePass::Run(DataFlowGraph *x) {
if (!argument_->fluid_model_param_path) {
PADDLE_ENFORCE_NOT_NULL(argument_->fluid_model_dir);
argument_->fluid_model_param_path.reset(
new std::string(*argument_->fluid_model_dir + "param"));
}
PADDLE_ENFORCE_NOT_NULL(argument_->model_output_store_path);
// Directly copy param file to destination.
std::stringstream ss;
// NOTE these commands only works on linux.
ss << "mkdir -p " << *argument_->model_output_store_path;
LOG(INFO) << "run command: " << ss.str();
PADDLE_ENFORCE_EQ(system(ss.str().c_str()), 0);
ss.str("");
ss << "cp " << *argument_->fluid_model_dir << "/*"
<< " " << *argument_->model_output_store_path;
LOG(INFO) << "run command: " << ss.str();
PADDLE_ENFORCE_EQ(system(ss.str().c_str()), 0);
// Store program
PADDLE_ENFORCE_NOT_NULL(argument_->transformed_program_desc,
"program desc is not transformed, should call "
"DataFlowGraphToFluidPass first.");
const std::string program_output_path =
*argument_->model_output_store_path + "/__model__";
std::ofstream file(program_output_path, std::ios::binary);
PADDLE_ENFORCE(file.is_open(), "failed to open %s to write.",
program_output_path);
const std::string serialized_message =
argument_->transformed_program_desc->SerializeAsString();
file.write(serialized_message.c_str(), serialized_message.size());
}
} // namespace analysis
} // namespace inference
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*
* This file defines ModelStorePass, which store the runtime DFG to a Paddle
* model in the disk, and that model can be reloaded for prediction.
*/
#include "paddle/fluid/inference/analysis/pass.h"
namespace paddle {
namespace inference {
namespace analysis {
class ModelStorePass : public DataFlowGraphPass {
public:
bool Initialize(Argument* argument) override {
if (!argument) {
LOG(ERROR) << "invalid argument";
return false;
}
argument_ = argument;
return true;
}
void Run(DataFlowGraph* x) override;
std::string repr() const override { return "DFG-store-pass"; }
std::string description() const override {
return R"DD(This file defines ModelStorePass, which store the runtime DFG to a Paddle
model in the disk, and that model can be reloaded for prediction again.)DD";
}
private:
Argument* argument_{nullptr};
};
} // namespace analysis
} // namespace inference
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/analysis/model_store_pass.h"
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include "paddle/fluid/inference/analysis/analyzer.h"
namespace paddle {
namespace inference {
namespace analysis {
DEFINE_string(inference_model_dir, "", "Model path");
TEST(DFG_StorePass, test) {
Analyzer analyzer;
Argument argument(FLAGS_inference_model_dir);
argument.model_output_store_path.reset(
new std::string("./_dfg_store_pass_tmp"));
// disable storage in alalyzer
FLAGS_inference_analysis_output_storage_path = "";
analyzer.Run(&argument);
ModelStorePass pass;
pass.Initialize(&argument);
pass.Run(argument.main_dfg.get());
}
} // namespace analysis
} // namespace inference
} // namespace paddle
...@@ -50,6 +50,7 @@ class Pass { ...@@ -50,6 +50,7 @@ class Pass {
// Create a debugger Pass that draw the DFG by graphviz toolkit. // Create a debugger Pass that draw the DFG by graphviz toolkit.
virtual Pass *CreateGraphvizDebugerPass() const { return nullptr; } virtual Pass *CreateGraphvizDebugerPass() const { return nullptr; }
virtual void Run() { LOG(FATAL) << "not valid"; }
// Run on a single Node. // Run on a single Node.
virtual void Run(Node *x) { LOG(FATAL) << "not valid"; } virtual void Run(Node *x) { LOG(FATAL) << "not valid"; }
// Run on a single Function. // Run on a single Function.
......
...@@ -56,7 +56,7 @@ class TestNodePass final : public NodePass { ...@@ -56,7 +56,7 @@ class TestNodePass final : public NodePass {
std::string description() const override { return "some doc"; } std::string description() const override { return "some doc"; }
}; };
TEST_F(DFG_Tester, DFG_pass_manager) { TEST(PassManager, DFG_pass_manager) {
TestDfgPassManager manager; TestDfgPassManager manager;
DFG_GraphvizDrawPass::Config config("./", "dfg.dot"); DFG_GraphvizDrawPass::Config config("./", "dfg.dot");
...@@ -64,12 +64,15 @@ TEST_F(DFG_Tester, DFG_pass_manager) { ...@@ -64,12 +64,15 @@ TEST_F(DFG_Tester, DFG_pass_manager) {
manager.Register("graphviz", new DFG_GraphvizDrawPass(config)); manager.Register("graphviz", new DFG_GraphvizDrawPass(config));
manager.Register("dfg-to-fluid", new DataFlowGraphToFluidPass); manager.Register("dfg-to-fluid", new DataFlowGraphToFluidPass);
Argument argument(FLAGS_inference_model_dir);
ASSERT_TRUE(&argument); ASSERT_TRUE(&argument);
ASSERT_TRUE(manager.Initialize(&argument)); ASSERT_TRUE(manager.Initialize(&argument));
manager.RunAll(); manager.RunAll();
} }
TEST_F(DFG_Tester, Node_pass_manager) { TEST(PassManager, Node_pass_manager) {
Argument argument(FLAGS_inference_model_dir);
// Pre-process: initialize the DFG with the ProgramDesc first. // Pre-process: initialize the DFG with the ProgramDesc first.
FluidToDataFlowGraphPass pass0; FluidToDataFlowGraphPass pass0;
pass0.Initialize(&argument); pass0.Initialize(&argument);
......
...@@ -31,8 +31,8 @@ SubGraphSplitter::NodeInsideSubgraphTeller teller = [](const Node* node) { ...@@ -31,8 +31,8 @@ SubGraphSplitter::NodeInsideSubgraphTeller teller = [](const Node* node) {
return false; return false;
}; };
TEST_F(DFG_Tester, Split) { TEST(SubGraphSplitter, Split) {
auto desc = LoadProgramDesc(); auto desc = LoadProgramDesc(FLAGS_inference_model_dir + "/__model__");
auto dfg = ProgramDescToDFG(desc); auto dfg = ProgramDescToDFG(desc);
LOG(INFO) << "spliter\n" << dfg.DotString(); LOG(INFO) << "spliter\n" << dfg.DotString();
...@@ -63,8 +63,8 @@ TEST_F(DFG_Tester, Split) { ...@@ -63,8 +63,8 @@ TEST_F(DFG_Tester, Split) {
ASSERT_EQ(subgraphs.back().size(), 6UL); ASSERT_EQ(subgraphs.back().size(), 6UL);
} }
TEST_F(DFG_Tester, Fuse) { TEST(SubGraphSplitter, Fuse) {
auto desc = LoadProgramDesc(); auto desc = LoadProgramDesc(FLAGS_inference_model_dir + "/__model__");
auto dfg = ProgramDescToDFG(desc); auto dfg = ProgramDescToDFG(desc);
size_t count0 = dfg.nodes.size(); size_t count0 = dfg.nodes.size();
......
...@@ -22,11 +22,11 @@ namespace paddle { ...@@ -22,11 +22,11 @@ namespace paddle {
namespace inference { namespace inference {
namespace analysis { namespace analysis {
TEST_F(DFG_Tester, tensorrt_subgraph_node_mark_pass) { TEST(TensorRTSubgraphNodeMarkPass, test) {
// init // init
FluidToDataFlowGraphPass pass; FluidToDataFlowGraphPass pass;
Argument argument(FLAGS_inference_model_dir);
ASSERT_TRUE(pass.Initialize(&argument)); ASSERT_TRUE(pass.Initialize(&argument));
argument.main_dfg.reset(new DataFlowGraph);
pass.Run(argument.main_dfg.get()); pass.Run(argument.main_dfg.get());
TensorRTSubgraphNodeMarkPass::teller_t teller = [](const Node* node) { TensorRTSubgraphNodeMarkPass::teller_t teller = [](const Node* node) {
...@@ -41,7 +41,7 @@ TEST_F(DFG_Tester, tensorrt_subgraph_node_mark_pass) { ...@@ -41,7 +41,7 @@ TEST_F(DFG_Tester, tensorrt_subgraph_node_mark_pass) {
for (auto& node : argument.main_dfg->nodes.nodes()) { for (auto& node : argument.main_dfg->nodes.nodes()) {
counter += node->attr(ATTR_supported_by_tensorrt).Bool(); counter += node->attr(ATTR_supported_by_tensorrt).Bool();
} }
ASSERT_EQ(counter, 2);
LOG(INFO) << counter << " nodes marked"; LOG(INFO) << counter << " nodes marked";
} }
......
...@@ -25,7 +25,7 @@ namespace analysis { ...@@ -25,7 +25,7 @@ namespace analysis {
DEFINE_string(dot_dir, "./", ""); DEFINE_string(dot_dir, "./", "");
TEST_F(DFG_Tester, tensorrt_single_pass) { TEST(TensorRTSubGraphPass, main) {
std::unordered_set<std::string> teller_set( std::unordered_set<std::string> teller_set(
{"elementwise_add", "mul", "sigmoid"}); {"elementwise_add", "mul", "sigmoid"});
SubGraphSplitter::NodeInsideSubgraphTeller teller = [&](const Node* node) { SubGraphSplitter::NodeInsideSubgraphTeller teller = [&](const Node* node) {
...@@ -35,7 +35,8 @@ TEST_F(DFG_Tester, tensorrt_single_pass) { ...@@ -35,7 +35,8 @@ TEST_F(DFG_Tester, tensorrt_single_pass) {
return false; return false;
}; };
LOG(INFO) << "init"; Argument argument(FLAGS_inference_model_dir);
DFG_GraphvizDrawPass::Config config{FLAGS_dot_dir, "origin"}; DFG_GraphvizDrawPass::Config config{FLAGS_dot_dir, "origin"};
DFG_GraphvizDrawPass::Config config1{FLAGS_dot_dir, "fusion"}; DFG_GraphvizDrawPass::Config config1{FLAGS_dot_dir, "fusion"};
...@@ -44,13 +45,11 @@ TEST_F(DFG_Tester, tensorrt_single_pass) { ...@@ -44,13 +45,11 @@ TEST_F(DFG_Tester, tensorrt_single_pass) {
FluidToDataFlowGraphPass pass0; FluidToDataFlowGraphPass pass0;
TensorRTSubGraphPass trt_pass(std::move(teller)); TensorRTSubGraphPass trt_pass(std::move(teller));
LOG(INFO) << "Initialize";
dfg_pass.Initialize(&argument); dfg_pass.Initialize(&argument);
dfg_pass1.Initialize(&argument); dfg_pass1.Initialize(&argument);
pass0.Initialize(&argument); pass0.Initialize(&argument);
trt_pass.Initialize(&argument); trt_pass.Initialize(&argument);
LOG(INFO) << "Run";
argument.main_dfg.reset(new DataFlowGraph); argument.main_dfg.reset(new DataFlowGraph);
pass0.Run(argument.main_dfg.get()); pass0.Run(argument.main_dfg.get());
dfg_pass.Run(argument.main_dfg.get()); dfg_pass.Run(argument.main_dfg.get());
......
...@@ -20,7 +20,7 @@ limitations under the License. */ ...@@ -20,7 +20,7 @@ limitations under the License. */
#include "paddle/fluid/framework/executor.h" #include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/inference/analysis/data_flow_graph.h" #include "paddle/fluid/inference/analysis/data_flow_graph.h"
#include "paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h" #include "paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h"
#include "paddle/fluid/inference/analysis/ut_helper.h" #include "paddle/fluid/inference/analysis/helper.h"
namespace paddle { namespace paddle {
namespace inference { namespace inference {
...@@ -32,27 +32,12 @@ namespace analysis { ...@@ -32,27 +32,12 @@ namespace analysis {
DEFINE_string(inference_model_dir, "", "inference test model dir"); DEFINE_string(inference_model_dir, "", "inference test model dir");
static framework::proto::ProgramDesc LoadProgramDesc(
const std::string& model_dir = FLAGS_inference_model_dir) {
std::string msg;
std::string net_file = FLAGS_inference_model_dir + "/__model__";
std::ifstream fin(net_file, std::ios::in | std::ios::binary);
PADDLE_ENFORCE(static_cast<bool>(fin), "Cannot open file %s", net_file);
fin.seekg(0, std::ios::end);
msg.resize(fin.tellg());
fin.seekg(0, std::ios::beg);
fin.read(&(msg.at(0)), msg.size());
fin.close();
framework::proto::ProgramDesc program_desc;
program_desc.ParseFromString(msg);
return program_desc;
}
static DataFlowGraph ProgramDescToDFG( static DataFlowGraph ProgramDescToDFG(
const framework::proto::ProgramDesc& desc) { const framework::proto::ProgramDesc& desc) {
DataFlowGraph graph; DataFlowGraph graph;
FluidToDataFlowGraphPass pass; FluidToDataFlowGraphPass pass;
Argument argument; Argument argument;
argument.fluid_model_dir.reset(new std::string(FLAGS_inference_model_dir));
argument.origin_program_desc.reset(new framework::proto::ProgramDesc(desc)); argument.origin_program_desc.reset(new framework::proto::ProgramDesc(desc));
pass.Initialize(&argument); pass.Initialize(&argument);
pass.Run(&graph); pass.Run(&graph);
...@@ -63,7 +48,7 @@ static DataFlowGraph ProgramDescToDFG( ...@@ -63,7 +48,7 @@ static DataFlowGraph ProgramDescToDFG(
class DFG_Tester : public ::testing::Test { class DFG_Tester : public ::testing::Test {
protected: protected:
void SetUp() override { void SetUp() override {
auto desc = LoadProgramDesc(FLAGS_inference_model_dir); auto desc = LoadProgramDesc(FLAGS_inference_model_dir + "/__model__");
argument.origin_program_desc.reset(new framework::proto::ProgramDesc(desc)); argument.origin_program_desc.reset(new framework::proto::ProgramDesc(desc));
} }
......
...@@ -90,6 +90,18 @@ class TensorRTSubgraphPredictor : public NativePaddlePredictor { ...@@ -90,6 +90,18 @@ class TensorRTSubgraphPredictor : public NativePaddlePredictor {
void OptimizeInferenceProgram() { void OptimizeInferenceProgram() {
// Analyze inference_program // Analyze inference_program
Argument argument; Argument argument;
if (!config_.model_dir.empty()) {
argument.fluid_model_dir.reset(new std::string(config_.model_dir));
} else {
PADDLE_ENFORCE(
!config_.param_file.empty(),
"Either model_dir or (param_file, prog_file) should be set.");
PADDLE_ENFORCE(!config_.prog_file.empty());
argument.fluid_model_program_path.reset(
new std::string(config_.prog_file));
argument.fluid_model_param_path.reset(
new std::string(config_.param_file));
}
argument.origin_program_desc.reset( argument.origin_program_desc.reset(
new ProgramDesc(*inference_program_->Proto())); new ProgramDesc(*inference_program_->Proto()));
Singleton<Analyzer>::Global().Run(&argument); Singleton<Analyzer>::Global().Run(&argument);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册