提交 4801beb1 编写于 作者: N nhzlx

add arguments for trt config

上级 202e0a1e
...@@ -37,12 +37,20 @@ TEST(Analyzer, analysis_without_tensorrt) { ...@@ -37,12 +37,20 @@ TEST(Analyzer, analysis_without_tensorrt) {
TEST(Analyzer, analysis_with_tensorrt) { TEST(Analyzer, analysis_with_tensorrt) {
FLAGS_IA_enable_tensorrt_subgraph_engine = true; FLAGS_IA_enable_tensorrt_subgraph_engine = true;
Argument argument; Argument argument;
int* minimum_subgraph_size = new int(0);
int* max_batch_size = new int(3);
int* workspace_size = new int(1 << 20);
std::string* precision_mode = new std::string("FP32");
argument.Set<int>("minimum_subgraph_size", minimum_subgraph_size);
argument.Set<int>("max_batch_size", max_batch_size);
argument.Set<int>("workspace_size", workspace_size);
argument.Set<std::string>("precision_mode", precision_mode);
argument.fluid_model_dir.reset(new std::string(FLAGS_inference_model_dir)); argument.fluid_model_dir.reset(new std::string(FLAGS_inference_model_dir));
Analyzer analyser; Analyzer analyser;
analyser.Run(&argument); analyser.Run(&argument);
} }
void TestWord2vecPrediction(const std::string &model_path) { void TestWord2vecPrediction(const std::string& model_path) {
NativeConfig config; NativeConfig config;
config.model_dir = model_path; config.model_dir = model_path;
config.use_gpu = false; config.use_gpu = false;
...@@ -73,8 +81,8 @@ void TestWord2vecPrediction(const std::string &model_path) { ...@@ -73,8 +81,8 @@ void TestWord2vecPrediction(const std::string &model_path) {
// The outputs' buffers are in CPU memory. // The outputs' buffers are in CPU memory.
for (size_t i = 0; i < std::min(5UL, num_elements); i++) { for (size_t i = 0; i < std::min(5UL, num_elements); i++) {
LOG(INFO) << "data: " LOG(INFO) << "data: "
<< static_cast<float *>(outputs.front().data.data())[i]; << static_cast<float*>(outputs.front().data.data())[i];
PADDLE_ENFORCE(static_cast<float *>(outputs.front().data.data())[i], PADDLE_ENFORCE(static_cast<float*>(outputs.front().data.data())[i],
result[i]); result[i]);
} }
} }
......
...@@ -309,7 +309,7 @@ void SubGraphFuse::operator()() { ReplaceNodesWithSubGraphs(); } ...@@ -309,7 +309,7 @@ void SubGraphFuse::operator()() { ReplaceNodesWithSubGraphs(); }
void SubGraphFuse::ReplaceNodesWithSubGraphs() { void SubGraphFuse::ReplaceNodesWithSubGraphs() {
auto subgraphs = SubGraphSplitter(graph_, node_inside_subgraph_teller_)(); auto subgraphs = SubGraphSplitter(graph_, node_inside_subgraph_teller_)();
for (auto &subgraph : subgraphs) { for (auto &subgraph : subgraphs) {
if (subgraph.size() <= argument_->Get<int>("minimun_subgraph_size")) if (subgraph.size() <= argument_->Get<int>("minimum_subgraph_size"))
continue; continue;
std::unordered_set<Node *> subgraph_uniq(subgraph.begin(), subgraph.end()); std::unordered_set<Node *> subgraph_uniq(subgraph.begin(), subgraph.end());
// replace this sub-graph with the first node. Two steps: 1. Create a Block // replace this sub-graph with the first node. Two steps: 1. Create a Block
......
...@@ -68,7 +68,7 @@ TEST(SubGraphSplitter, Fuse) { ...@@ -68,7 +68,7 @@ TEST(SubGraphSplitter, Fuse) {
auto dfg = ProgramDescToDFG(desc); auto dfg = ProgramDescToDFG(desc);
Argument argument; Argument argument;
int* minmum_subgraph_size = new int(3); int* minmum_subgraph_size = new int(3);
argument.Set<int>("minmum_subgraph_size", minmum_subgraph_size); argument.Set<int>("minimum_subgraph_size", minmum_subgraph_size);
size_t count0 = dfg.nodes.size(); size_t count0 = dfg.nodes.size();
......
...@@ -36,6 +36,14 @@ TEST(TensorRTSubGraphPass, main) { ...@@ -36,6 +36,14 @@ TEST(TensorRTSubGraphPass, main) {
}; };
Argument argument(FLAGS_inference_model_dir); Argument argument(FLAGS_inference_model_dir);
int* minimum_subgraph_size = new int(0);
int* max_batch_size = new int(3);
int* workspace_size = new int(1 << 20);
std::string* precision_mode = new std::string("FP32");
argument.Set<int>("minimum_subgraph_size", minimum_subgraph_size);
argument.Set<int>("max_batch_size", max_batch_size);
argument.Set<int>("workspace_size", workspace_size);
argument.Set<std::string>("precision_mode", precision_mode);
DFG_GraphvizDrawPass::Config config{FLAGS_dot_dir, "origin"}; DFG_GraphvizDrawPass::Config config{FLAGS_dot_dir, "origin"};
DFG_GraphvizDrawPass::Config config1{FLAGS_dot_dir, "fusion"}; DFG_GraphvizDrawPass::Config config1{FLAGS_dot_dir, "fusion"};
......
...@@ -94,7 +94,7 @@ class TensorRTSubgraphPredictor : public NativePaddlePredictor { ...@@ -94,7 +94,7 @@ class TensorRTSubgraphPredictor : public NativePaddlePredictor {
int* max_batch_size = new int(config_.max_batch_size); int* max_batch_size = new int(config_.max_batch_size);
int* workspace_size = new int(config_.workspace_size); int* workspace_size = new int(config_.workspace_size);
std::string* precision_mode = new std::string(config_.precision_mode); std::string* precision_mode = new std::string(config_.precision_mode);
argument.Set<int>("minimun_subgraph_size", minimum_subgraph_size); argument.Set<int>("minimum_subgraph_size", minimum_subgraph_size);
argument.Set<int>("max_batch_size", max_batch_size); argument.Set<int>("max_batch_size", max_batch_size);
argument.Set<int>("workspace_size", workspace_size); argument.Set<int>("workspace_size", workspace_size);
argument.Set<std::string>("precision_mode", precision_mode); argument.Set<std::string>("precision_mode", precision_mode);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册