Commit 27633216 authored by nhzlx

fix comments

Parent commit: 0514882b
```diff
@@ -37,14 +37,10 @@ TEST(Analyzer, analysis_without_tensorrt) {
 TEST(Analyzer, analysis_with_tensorrt) {
   FLAGS_IA_enable_tensorrt_subgraph_engine = true;
   Argument argument;
-  int* minimum_subgraph_size = new int(0);
-  int* max_batch_size = new int(3);
-  int* workspace_size = new int(1 << 20);
-  std::string* precision_mode = new std::string("FP32");
-  argument.Set<int>("minimum_subgraph_size", minimum_subgraph_size);
-  argument.Set<int>("max_batch_size", max_batch_size);
-  argument.Set<int>("workspace_size", workspace_size);
-  argument.Set<std::string>("precision_mode", precision_mode);
+  argument.Set<int>("minimum_subgraph_size", new int(0));
+  argument.Set<int>("max_batch_size", new int(3));
+  argument.Set<int>("workspace_size", new int(1 << 20));
+  argument.Set<std::string>("precision_mode", new std::string("FP32"));
   argument.fluid_model_dir.reset(new std::string(FLAGS_inference_model_dir));
   Analyzer analyser;
   analyser.Run(&argument);
```
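The new call pattern passes freshly allocated objects straight into `argument.Set<T>(key, new T(...))`, which only works because `Argument::Set` takes ownership of the raw pointer. Below is a minimal, self-contained sketch of such an ownership-taking attribute store; it is an illustration under that assumption, not Paddle's actual `Argument` implementation.

```cpp
#include <cassert>
#include <map>
#include <memory>
#include <string>

// Sketch only: a type-erased key/value store whose Set<T> owns the pointer
// it is given, matching the call style argument.Set<int>("key", new int(3)).
class ArgumentSketch {
 public:
  template <typename T>
  void Set(const std::string& key, T* data) {
    // shared_ptr<void> with a typed deleter frees the value correctly
    // even though the static type is erased in the map.
    attrs_[key] = std::shared_ptr<void>(
        data, [](void* p) { delete static_cast<T*>(p); });
  }

  template <typename T>
  T& Get(const std::string& key) {
    assert(attrs_.count(key));  // caller must Set before Get
    return *static_cast<T*>(attrs_.at(key).get());
  }

 private:
  std::map<std::string, std::shared_ptr<void>> attrs_;
};

int main() {
  ArgumentSketch argument;
  argument.Set<int>("max_batch_size", new int(3));
  argument.Set<std::string>("precision_mode", new std::string("FP32"));
  assert(argument.Get<int>("max_batch_size") == 3);
  assert(argument.Get<std::string>("precision_mode") == "FP32");
  return 0;
}
```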
```diff
@@ -99,6 +99,7 @@ void DataFlowGraphToFluidPass::AddFluidOp(Node *node) {
 void CreateTrtEngineOp(Node *node, Argument *argument,
                        framework::proto::BlockDesc *block) {
+  PADDLE_ENFORCE(argument->main_dfg.get());
   const DataFlowGraph &graph = *(argument->main_dfg);
   static int counter{0};
   PADDLE_ENFORCE(node->IsFunctionBlock());
```
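The added `PADDLE_ENFORCE(argument->main_dfg.get())` makes the pass fail fast with a diagnostic instead of dereferencing a null `main_dfg` on the very next line. A hedged sketch of an enforce-style precondition check follows; Paddle's real macro carries more context, this only shows the fail-fast idea.

```cpp
#include <sstream>
#include <stdexcept>

// Sketch only: abort the current operation with the failed condition and
// source location instead of running into undefined behavior later.
#define ENFORCE_SKETCH(cond)                                            \
  do {                                                                  \
    if (!(cond)) {                                                      \
      std::ostringstream oss;                                           \
      oss << "Enforce failed: " << #cond << " at " << __FILE__ << ":"   \
          << __LINE__;                                                  \
      throw std::runtime_error(oss.str());                              \
    }                                                                   \
  } while (0)

int main() {
  int* main_dfg = nullptr;  // stand-in for an unset argument field
  try {
    ENFORCE_SKETCH(main_dfg != nullptr);  // throws instead of crashing
  } catch (const std::runtime_error& e) {
    // e.what() pinpoints the failed condition and its location.
  }
  return 0;
}
```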
```diff
@@ -67,8 +67,7 @@ TEST(SubGraphSplitter, Fuse) {
   auto desc = LoadProgramDesc(FLAGS_inference_model_dir + "/__model__");
   auto dfg = ProgramDescToDFG(desc);
   Argument argument;
-  int* minmum_subgraph_size = new int(3);
-  argument.Set<int>("minimum_subgraph_size", minmum_subgraph_size);
+  argument.Set<int>("minimum_subgraph_size", new int(3));
   size_t count0 = dfg.nodes.size();
```
```diff
@@ -36,14 +36,10 @@ TEST(TensorRTSubGraphPass, main) {
   };
   Argument argument(FLAGS_inference_model_dir);
-  int* minimum_subgraph_size = new int(0);
-  int* max_batch_size = new int(3);
-  int* workspace_size = new int(1 << 20);
-  std::string* precision_mode = new std::string("FP32");
-  argument.Set<int>("minimum_subgraph_size", minimum_subgraph_size);
-  argument.Set<int>("max_batch_size", max_batch_size);
-  argument.Set<int>("workspace_size", workspace_size);
-  argument.Set<std::string>("precision_mode", precision_mode);
+  argument.Set<int>("minimum_subgraph_size", new int(0));
+  argument.Set<int>("max_batch_size", new int(3));
+  argument.Set<int>("workspace_size", new int(1 << 20));
+  argument.Set<std::string>("precision_mode", new std::string("FP32"));
   DFG_GraphvizDrawPass::Config config{FLAGS_dot_dir, "origin"};
   DFG_GraphvizDrawPass::Config config1{FLAGS_dot_dir, "fusion"};
```
```diff
@@ -90,14 +90,12 @@ class TensorRTSubgraphPredictor : public NativePaddlePredictor {
   // Analyze inference_program
   Argument argument;
-  int* minimum_subgraph_size = new int(config_.minimum_subgraph_size);
-  int* max_batch_size = new int(config_.max_batch_size);
-  int* workspace_size = new int(config_.workspace_size);
-  std::string* precision_mode = new std::string(config_.precision_mode);
-  argument.Set<int>("minimum_subgraph_size", minimum_subgraph_size);
-  argument.Set<int>("max_batch_size", max_batch_size);
-  argument.Set<int>("workspace_size", workspace_size);
-  argument.Set<std::string>("precision_mode", precision_mode);
+  argument.Set<int>("minimum_subgraph_size",
+                    new int(config_.minimum_subgraph_size));
+  argument.Set<int>("max_batch_size", new int(config_.max_batch_size));
+  argument.Set<int>("workspace_size", new int(config_.workspace_size));
+  argument.Set<std::string>("precision_mode",
+                            new std::string(config_.precision_mode));
   if (!config_.model_dir.empty()) {
     argument.fluid_model_dir.reset(new std::string(config_.model_dir));
```
```diff
@@ -153,8 +153,7 @@ struct TensorRTConfig : public NativeConfig {
   // We transform the Ops that can be converted into TRT layer in the model,
   // and aggregate these Ops into subgraphs for TRT execution.
   // We set this variable to control the minimum number of nodes in the
-  // subgraph, 3 as
-  // default value.
+  // subgraph, 3 as default value.
   int minimum_subgraph_size = 3;
   // Reserved configuration
   // We just support "FP32" now, "FP16" and "INT8" will be supported.
```
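For context, here is a self-contained sketch of how a caller might fill in the knobs this commit plumbs through. The struct below mirrors only the fields visible in this diff; the base-class members, defaults, and the model path are simplified assumptions, not the real `NativeConfig`/`TensorRTConfig` API.

```cpp
#include <iostream>
#include <string>

// Simplified stand-ins mirroring only the fields this commit touches.
struct NativeConfigSketch {
  std::string model_dir;
};

struct TensorRTConfigSketch : public NativeConfigSketch {
  int minimum_subgraph_size = 3;  // smallest subgraph worth offloading to TRT
  int max_batch_size = 1;
  int workspace_size = 1 << 20;
  std::string precision_mode = "FP32";  // "FP16"/"INT8" reserved for later
};

int main() {
  TensorRTConfigSketch config;
  config.model_dir = "./model_dir";   // hypothetical path
  config.max_batch_size = 3;
  config.workspace_size = 1 << 20;    // 1 MiB of TRT scratch space
  std::cout << "precision: " << config.precision_mode << "\n";
  return 0;
}
```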
```diff
@@ -33,7 +33,7 @@ class TensorRTEngineOpMaker : public framework::OpProtoAndCheckerMaker {
     AddAttr<std::string>("subgraph", "the subgraph.");
     AddAttr<std::string>("engine_uniq_key", "unique key for the TRT engine.");
     AddAttr<int>("max_batch_size", "the maximum batch size.");
-    AddAttr<int>("workspace_size", "the maximum batch size.");
+    AddAttr<int>("workspace_size", "the workspace size.");
     AddComment("TensorRT engine operator.");
   }
 };
```