未验证 提交 dcfbc6a6 编写于 作者: Y Yan Chunwei 提交者: GitHub

inference analyzer as bin (#12450)

上级 31a2c876
......@@ -6,9 +6,11 @@ cc_library(analysis SRCS pass_manager.cc dot.cc node.cc data_flow_graph.cc graph
tensorrt_subgraph_node_mark_pass.cc
analyzer.cc
helper.cc
model_store_pass.cc
DEPS framework_proto proto_desc)
cc_test(test_node SRCS node_tester.cc DEPS analysis)
cc_test(test_dot SRCS dot_tester.cc DEPS analysis)
cc_binary(inference_analyzer SRCS analyzer_main.cc DEPS analysis)
set(PYTHON_TESTS_DIR ${PADDLE_BINARY_DIR}/python/paddle/fluid/tests)
......@@ -40,3 +42,4 @@ inference_analysis_test(test_tensorrt_subgraph_pass SRCS tensorrt_subgraph_pass_
inference_analysis_test(test_pass_manager SRCS pass_manager_tester.cc)
inference_analysis_test(test_tensorrt_subgraph_node_mark_pass SRCS tensorrt_subgraph_node_mark_pass_tester.cc)
inference_analysis_test(test_analyzer SRCS analyzer_tester.cc)
inference_analysis_test(test_model_store_pass SRCS model_store_pass_tester.cc)
......@@ -17,6 +17,7 @@
#include "paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.h"
#include "paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.h"
#include "paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h"
#include "paddle/fluid/inference/analysis/model_store_pass.h"
#include "paddle/fluid/inference/analysis/pass_manager.h"
#include "paddle/fluid/inference/analysis/tensorrt_subgraph_node_mark_pass.h"
#include "paddle/fluid/inference/analysis/tensorrt_subgraph_pass.h"
......@@ -29,6 +30,9 @@ DEFINE_bool(inference_analysis_enable_tensorrt_subgraph_engine, false,
DEFINE_string(inference_analysis_graphviz_log_root, "./",
"Graphviz debuger for data flow graphs.");
DEFINE_string(inference_analysis_output_storage_path, "",
"optimized model output path");
namespace inference {
namespace analysis {
......@@ -47,6 +51,9 @@ class DfgPassManagerImpl final : public DfgPassManager {
AddPass("tensorrt-subgraph", new TensorRTSubGraphPass(trt_teller));
}
AddPass("data-flow-graph-to-fluid", new DataFlowGraphToFluidPass);
if (!FLAGS_inference_analysis_output_storage_path.empty()) {
AddPass("model-store-pass", new ModelStorePass);
}
}
std::string repr() const override { return "dfg-pass-manager"; }
......
......@@ -16,28 +16,23 @@ limitations under the License. */
/*
* This file contains Analyzer, an class that exposed as a library that analyze
* and optimize
* Fluid ProgramDesc for inference. Similar to LLVM, it has multiple flags to
* control whether
* an process is applied on the program.
* and optimize Fluid ProgramDesc for inference. Similar to LLVM, it has
* multiple flags to
* control whether an process is applied on the program.
*
* The processes are called Passes in analysis, the Passes are placed in a
* pipeline, the first
* Pass is the FluidToDataFlowGraphPass which transforms a Fluid ProgramDesc to
* a data flow
* graph, the last Pass is DataFlowGraphToFluidPass which transforms a data flow
* graph to a
* Fluid ProgramDesc. The passes in the middle of the pipeline can be any Passes
* which take a
* node or data flow graph as input.
* pipeline, the first Pass is the FluidToDataFlowGraphPass which transforms a
* Fluid ProgramDesc to
* a data flow graph, the last Pass is DataFlowGraphToFluidPass which transforms
* a data flow graph to a Fluid ProgramDesc. The passes in the middle of the
* pipeline can be any Passes
* which take a node or data flow graph as input.
*
* The Analyzer can be used in two methods, the first is a executable file which
* can be used to
* pre-process the inference model and can be controlled by passing difference
* command flags;
* can be used to pre-process the inference model and can be controlled by
* passing difference command flags;
* the other way is to compose inside the inference API as a runtime pre-process
* phase in the
* inference service.
* phase in the inference service.
*/
#include <gflags/gflags.h>
......@@ -50,6 +45,7 @@ namespace paddle {
// flag if not available.
DECLARE_bool(inference_analysis_enable_tensorrt_subgraph_engine);
DECLARE_string(inference_analysis_graphviz_log_root);
DECLARE_string(inference_analysis_output_storage_path);
namespace inference {
namespace analysis {
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*
* This file implements analysizer -- an executation help to analyze and
* optimize trained model.
*/
#include "paddle/fluid/inference/analysis/analyzer.h"
#include <gflags/gflags.h>
#include <glog/logging.h>
int main(int argc, char** argv) {
google::ParseCommandLineFlags(&argc, &argv, true);
using paddle::inference::analysis::Analyzer;
using paddle::inference::analysis::Argument;
Argument argument;
Analyzer analyzer;
analyzer.Run(&argument);
return 0;
}
......@@ -20,14 +20,18 @@ namespace paddle {
namespace inference {
namespace analysis {
TEST_F(DFG_Tester, analysis_without_tensorrt) {
TEST(Analyzer, analysis_without_tensorrt) {
FLAGS_inference_analysis_enable_tensorrt_subgraph_engine = false;
Argument argument;
argument.fluid_model_dir.reset(new std::string(FLAGS_inference_model_dir));
Analyzer analyser;
analyser.Run(&argument);
}
TEST_F(DFG_Tester, analysis_with_tensorrt) {
TEST(Analyzer, analysis_with_tensorrt) {
FLAGS_inference_analysis_enable_tensorrt_subgraph_engine = true;
Argument argument;
argument.fluid_model_dir.reset(new std::string(FLAGS_inference_model_dir));
Analyzer analyser;
analyser.Run(&argument);
}
......
......@@ -36,6 +36,16 @@ namespace analysis {
* All the fields should be registered here for clearness.
*/
struct Argument {
Argument() = default;
explicit Argument(const std::string& fluid_model_dir)
: fluid_model_dir(new std::string(fluid_model_dir)) {}
// The directory of the trained model.
std::unique_ptr<std::string> fluid_model_dir;
// The path of `__model__` and `param`, this is used when the file name of
// model and param is changed.
std::unique_ptr<std::string> fluid_model_program_path;
std::unique_ptr<std::string> fluid_model_param_path;
// The graph that process by the Passes or PassManagers.
std::unique_ptr<DataFlowGraph> main_dfg;
......@@ -44,6 +54,9 @@ struct Argument {
// The processed program desc.
std::unique_ptr<framework::proto::ProgramDesc> transformed_program_desc;
// The output storage path of ModelStorePass.
std::unique_ptr<std::string> model_output_store_path;
};
#define UNLIKELY(condition) __builtin_expect(static_cast<bool>(condition), 0)
......
......@@ -20,7 +20,7 @@ namespace inference {
namespace analysis {
TEST(DataFlowGraph, BFS) {
auto desc = LoadProgramDesc();
auto desc = LoadProgramDesc(FLAGS_inference_model_dir + "/__model__");
auto dfg = ProgramDescToDFG(desc);
dfg.Build();
......@@ -44,7 +44,7 @@ TEST(DataFlowGraph, BFS) {
}
TEST(DataFlowGraph, DFS) {
auto desc = LoadProgramDesc();
auto desc = LoadProgramDesc(FLAGS_inference_model_dir + "/__model__");
auto dfg = ProgramDescToDFG(desc);
dfg.Build();
GraphTraits<DataFlowGraph> trait(&dfg);
......
......@@ -26,21 +26,21 @@ namespace paddle {
namespace inference {
namespace analysis {
TEST_F(DFG_Tester, Test) {
DataFlowGraph graph;
TEST(DataFlowGraph, Test) {
Argument argument(FLAGS_inference_model_dir);
FluidToDataFlowGraphPass pass0;
DataFlowGraphToFluidPass pass1;
ASSERT_TRUE(pass0.Initialize(&argument));
ASSERT_TRUE(pass1.Initialize(&argument));
pass0.Run(&graph);
pass1.Run(&graph);
pass0.Run(argument.main_dfg.get());
pass1.Run(argument.main_dfg.get());
pass0.Finalize();
pass1.Finalize();
LOG(INFO) << graph.nodes.size();
LOG(INFO) << argument.main_dfg->nodes.size();
}
}; // namespace analysis
......
......@@ -23,12 +23,18 @@ namespace paddle {
namespace inference {
namespace analysis {
TEST_F(DFG_Tester, dfg_graphviz_draw_pass_tester) {
auto dfg = ProgramDescToDFG(*argument.origin_program_desc);
TEST(DFG_GraphvizDrawPass, dfg_graphviz_draw_pass_tester) {
Argument argument(FLAGS_inference_model_dir);
FluidToDataFlowGraphPass pass0;
ASSERT_TRUE(pass0.Initialize(&argument));
pass0.Run(argument.main_dfg.get());
// auto dfg = ProgramDescToDFG(*argument.origin_program_desc);
DFG_GraphvizDrawPass::Config config("./", "test");
DFG_GraphvizDrawPass pass(config);
pass.Initialize(&argument);
pass.Run(&dfg);
pass.Run(argument.main_dfg.get());
// test content
std::ifstream file("./0-graph_test.dot");
......
......@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <glog/logging.h>
#include <string>
#include <vector>
......@@ -25,8 +26,20 @@ namespace analysis {
bool FluidToDataFlowGraphPass::Initialize(Argument *argument) {
ANALYSIS_ARGUMENT_CHECK_FIELD(argument);
ANALYSIS_ARGUMENT_CHECK_FIELD(argument->origin_program_desc);
PADDLE_ENFORCE(argument);
if (argument->origin_program_desc) {
LOG(WARNING) << "argument's origin_program_desc is already set, might "
"duplicate called";
}
if (!argument->fluid_model_program_path) {
ANALYSIS_ARGUMENT_CHECK_FIELD(argument->fluid_model_dir);
argument->fluid_model_program_path.reset(
new std::string(*argument->fluid_model_dir + "/__model__"));
}
ANALYSIS_ARGUMENT_CHECK_FIELD(argument->fluid_model_program_path);
auto program = LoadProgramDesc(*argument->fluid_model_program_path);
argument->origin_program_desc.reset(
new framework::proto::ProgramDesc(program));
if (!argument->main_dfg) {
argument->main_dfg.reset(new DataFlowGraph);
}
......
......@@ -21,8 +21,9 @@ namespace paddle {
namespace inference {
namespace analysis {
TEST_F(DFG_Tester, Init) {
TEST(FluidToDataFlowGraphPass, Test) {
FluidToDataFlowGraphPass pass;
Argument argument(FLAGS_inference_model_dir);
pass.Initialize(&argument);
pass.Run(argument.main_dfg.get());
// Analysis is sensitive to ProgramDesc, careful to change the original model.
......
......@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once
#include <cstdio>
#include <fstream>
#include <string>
#include <typeindex>
#include <unordered_map>
......@@ -136,6 +137,20 @@ static void ExecShellCommand(const std::string &cmd, std::string *message) {
}
}
static framework::proto::ProgramDesc LoadProgramDesc(
const std::string &model_path) {
std::ifstream fin(model_path, std::ios::in | std::ios::binary);
PADDLE_ENFORCE(fin.is_open(), "Cannot open file %s", model_path);
fin.seekg(0, std::ios::end);
std::string buffer(fin.tellg(), ' ');
fin.seekg(0, std::ios::beg);
fin.read(&buffer[0], buffer.size());
fin.close();
framework::proto::ProgramDesc program_desc;
program_desc.ParseFromString(buffer);
return program_desc;
}
} // namespace analysis
} // namespace inference
} // namespace paddle
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/analysis/model_store_pass.h"
#include <stdio.h>
#include <stdlib.h>
#include "paddle/fluid/inference/analysis/analyzer.h"
#include "paddle/fluid/inference/analysis/argument.h"
namespace paddle {
namespace inference {
namespace analysis {
void ModelStorePass::Run(DataFlowGraph *x) {
if (!argument_->fluid_model_param_path) {
PADDLE_ENFORCE_NOT_NULL(argument_->fluid_model_dir);
argument_->fluid_model_param_path.reset(
new std::string(*argument_->fluid_model_dir + "param"));
}
PADDLE_ENFORCE_NOT_NULL(argument_->model_output_store_path);
// Directly copy param file to destination.
std::stringstream ss;
// NOTE these commands only works on linux.
ss << "mkdir -p " << *argument_->model_output_store_path;
LOG(INFO) << "run command: " << ss.str();
PADDLE_ENFORCE_EQ(system(ss.str().c_str()), 0);
ss.str("");
ss << "cp " << *argument_->fluid_model_dir << "/*"
<< " " << *argument_->model_output_store_path;
LOG(INFO) << "run command: " << ss.str();
PADDLE_ENFORCE_EQ(system(ss.str().c_str()), 0);
// Store program
PADDLE_ENFORCE_NOT_NULL(argument_->transformed_program_desc,
"program desc is not transformed, should call "
"DataFlowGraphToFluidPass first.");
const std::string program_output_path =
*argument_->model_output_store_path + "/__model__";
std::ofstream file(program_output_path, std::ios::binary);
PADDLE_ENFORCE(file.is_open(), "failed to open %s to write.",
program_output_path);
const std::string serialized_message =
argument_->transformed_program_desc->SerializeAsString();
file.write(serialized_message.c_str(), serialized_message.size());
}
} // namespace analysis
} // namespace inference
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*
* This file defines ModelStorePass, which store the runtime DFG to a Paddle
* model in the disk, and that model can be reloaded for prediction.
*/
#include "paddle/fluid/inference/analysis/pass.h"
namespace paddle {
namespace inference {
namespace analysis {
class ModelStorePass : public DataFlowGraphPass {
public:
bool Initialize(Argument* argument) override {
if (!argument) {
LOG(ERROR) << "invalid argument";
return false;
}
argument_ = argument;
return true;
}
void Run(DataFlowGraph* x) override;
std::string repr() const override { return "DFG-store-pass"; }
std::string description() const override {
return R"DD(This file defines ModelStorePass, which store the runtime DFG to a Paddle
model in the disk, and that model can be reloaded for prediction again.)DD";
}
private:
Argument* argument_{nullptr};
};
} // namespace analysis
} // namespace inference
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/analysis/model_store_pass.h"
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include "paddle/fluid/inference/analysis/analyzer.h"
namespace paddle {
namespace inference {
namespace analysis {
DEFINE_string(inference_model_dir, "", "Model path");
TEST(DFG_StorePass, test) {
Analyzer analyzer;
Argument argument(FLAGS_inference_model_dir);
argument.model_output_store_path.reset(
new std::string("./_dfg_store_pass_tmp"));
// disable storage in alalyzer
FLAGS_inference_analysis_output_storage_path = "";
analyzer.Run(&argument);
ModelStorePass pass;
pass.Initialize(&argument);
pass.Run(argument.main_dfg.get());
}
} // namespace analysis
} // namespace inference
} // namespace paddle
......@@ -50,6 +50,7 @@ class Pass {
// Create a debugger Pass that draw the DFG by graphviz toolkit.
virtual Pass *CreateGraphvizDebugerPass() const { return nullptr; }
virtual void Run() { LOG(FATAL) << "not valid"; }
// Run on a single Node.
virtual void Run(Node *x) { LOG(FATAL) << "not valid"; }
// Run on a single Function.
......
......@@ -56,7 +56,7 @@ class TestNodePass final : public NodePass {
std::string description() const override { return "some doc"; }
};
TEST_F(DFG_Tester, DFG_pass_manager) {
TEST(PassManager, DFG_pass_manager) {
TestDfgPassManager manager;
DFG_GraphvizDrawPass::Config config("./", "dfg.dot");
......@@ -64,12 +64,15 @@ TEST_F(DFG_Tester, DFG_pass_manager) {
manager.Register("graphviz", new DFG_GraphvizDrawPass(config));
manager.Register("dfg-to-fluid", new DataFlowGraphToFluidPass);
Argument argument(FLAGS_inference_model_dir);
ASSERT_TRUE(&argument);
ASSERT_TRUE(manager.Initialize(&argument));
manager.RunAll();
}
TEST_F(DFG_Tester, Node_pass_manager) {
TEST(PassManager, Node_pass_manager) {
Argument argument(FLAGS_inference_model_dir);
// Pre-process: initialize the DFG with the ProgramDesc first.
FluidToDataFlowGraphPass pass0;
pass0.Initialize(&argument);
......
......@@ -31,8 +31,8 @@ SubGraphSplitter::NodeInsideSubgraphTeller teller = [](const Node* node) {
return false;
};
TEST_F(DFG_Tester, Split) {
auto desc = LoadProgramDesc();
TEST(SubGraphSplitter, Split) {
auto desc = LoadProgramDesc(FLAGS_inference_model_dir + "/__model__");
auto dfg = ProgramDescToDFG(desc);
LOG(INFO) << "spliter\n" << dfg.DotString();
......@@ -63,8 +63,8 @@ TEST_F(DFG_Tester, Split) {
ASSERT_EQ(subgraphs.back().size(), 6UL);
}
TEST_F(DFG_Tester, Fuse) {
auto desc = LoadProgramDesc();
TEST(SubGraphSplitter, Fuse) {
auto desc = LoadProgramDesc(FLAGS_inference_model_dir + "/__model__");
auto dfg = ProgramDescToDFG(desc);
size_t count0 = dfg.nodes.size();
......
......@@ -22,11 +22,11 @@ namespace paddle {
namespace inference {
namespace analysis {
TEST_F(DFG_Tester, tensorrt_subgraph_node_mark_pass) {
TEST(TensorRTSubgraphNodeMarkPass, test) {
// init
FluidToDataFlowGraphPass pass;
Argument argument(FLAGS_inference_model_dir);
ASSERT_TRUE(pass.Initialize(&argument));
argument.main_dfg.reset(new DataFlowGraph);
pass.Run(argument.main_dfg.get());
TensorRTSubgraphNodeMarkPass::teller_t teller = [](const Node* node) {
......@@ -41,7 +41,7 @@ TEST_F(DFG_Tester, tensorrt_subgraph_node_mark_pass) {
for (auto& node : argument.main_dfg->nodes.nodes()) {
counter += node->attr(ATTR_supported_by_tensorrt).Bool();
}
ASSERT_EQ(counter, 2);
LOG(INFO) << counter << " nodes marked";
}
......
......@@ -25,7 +25,7 @@ namespace analysis {
DEFINE_string(dot_dir, "./", "");
TEST_F(DFG_Tester, tensorrt_single_pass) {
TEST(TensorRTSubGraphPass, main) {
std::unordered_set<std::string> teller_set(
{"elementwise_add", "mul", "sigmoid"});
SubGraphSplitter::NodeInsideSubgraphTeller teller = [&](const Node* node) {
......@@ -35,7 +35,8 @@ TEST_F(DFG_Tester, tensorrt_single_pass) {
return false;
};
LOG(INFO) << "init";
Argument argument(FLAGS_inference_model_dir);
DFG_GraphvizDrawPass::Config config{FLAGS_dot_dir, "origin"};
DFG_GraphvizDrawPass::Config config1{FLAGS_dot_dir, "fusion"};
......@@ -44,13 +45,11 @@ TEST_F(DFG_Tester, tensorrt_single_pass) {
FluidToDataFlowGraphPass pass0;
TensorRTSubGraphPass trt_pass(std::move(teller));
LOG(INFO) << "Initialize";
dfg_pass.Initialize(&argument);
dfg_pass1.Initialize(&argument);
pass0.Initialize(&argument);
trt_pass.Initialize(&argument);
LOG(INFO) << "Run";
argument.main_dfg.reset(new DataFlowGraph);
pass0.Run(argument.main_dfg.get());
dfg_pass.Run(argument.main_dfg.get());
......
......@@ -20,7 +20,7 @@ limitations under the License. */
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/inference/analysis/data_flow_graph.h"
#include "paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h"
#include "paddle/fluid/inference/analysis/ut_helper.h"
#include "paddle/fluid/inference/analysis/helper.h"
namespace paddle {
namespace inference {
......@@ -32,27 +32,12 @@ namespace analysis {
DEFINE_string(inference_model_dir, "", "inference test model dir");
static framework::proto::ProgramDesc LoadProgramDesc(
const std::string& model_dir = FLAGS_inference_model_dir) {
std::string msg;
std::string net_file = FLAGS_inference_model_dir + "/__model__";
std::ifstream fin(net_file, std::ios::in | std::ios::binary);
PADDLE_ENFORCE(static_cast<bool>(fin), "Cannot open file %s", net_file);
fin.seekg(0, std::ios::end);
msg.resize(fin.tellg());
fin.seekg(0, std::ios::beg);
fin.read(&(msg.at(0)), msg.size());
fin.close();
framework::proto::ProgramDesc program_desc;
program_desc.ParseFromString(msg);
return program_desc;
}
static DataFlowGraph ProgramDescToDFG(
const framework::proto::ProgramDesc& desc) {
DataFlowGraph graph;
FluidToDataFlowGraphPass pass;
Argument argument;
argument.fluid_model_dir.reset(new std::string(FLAGS_inference_model_dir));
argument.origin_program_desc.reset(new framework::proto::ProgramDesc(desc));
pass.Initialize(&argument);
pass.Run(&graph);
......@@ -63,7 +48,7 @@ static DataFlowGraph ProgramDescToDFG(
class DFG_Tester : public ::testing::Test {
protected:
void SetUp() override {
auto desc = LoadProgramDesc(FLAGS_inference_model_dir);
auto desc = LoadProgramDesc(FLAGS_inference_model_dir + "/__model__");
argument.origin_program_desc.reset(new framework::proto::ProgramDesc(desc));
}
......
......@@ -90,6 +90,18 @@ class TensorRTSubgraphPredictor : public NativePaddlePredictor {
void OptimizeInferenceProgram() {
// Analyze inference_program
Argument argument;
if (!config_.model_dir.empty()) {
argument.fluid_model_dir.reset(new std::string(config_.model_dir));
} else {
PADDLE_ENFORCE(
!config_.param_file.empty(),
"Either model_dir or (param_file, prog_file) should be set.");
PADDLE_ENFORCE(!config_.prog_file.empty());
argument.fluid_model_program_path.reset(
new std::string(config_.prog_file));
argument.fluid_model_param_path.reset(
new std::string(config_.param_file));
}
argument.origin_program_desc.reset(
new ProgramDesc(*inference_program_->Proto()));
Singleton<Analyzer>::Global().Run(&argument);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册