Commit 27633216 authored by nhzlx

fix comments

Parent 0514882b
......
@@ -37,14 +37,10 @@ TEST(Analyzer, analysis_without_tensorrt) {
 TEST(Analyzer, analysis_with_tensorrt) {
   FLAGS_IA_enable_tensorrt_subgraph_engine = true;
   Argument argument;
-  int* minimum_subgraph_size = new int(0);
-  int* max_batch_size = new int(3);
-  int* workspace_size = new int(1 << 20);
-  std::string* precision_mode = new std::string("FP32");
-  argument.Set<int>("minimum_subgraph_size", minimum_subgraph_size);
-  argument.Set<int>("max_batch_size", max_batch_size);
-  argument.Set<int>("workspace_size", workspace_size);
-  argument.Set<std::string>("precision_mode", precision_mode);
+  argument.Set<int>("minimum_subgraph_size", new int(0));
+  argument.Set<int>("max_batch_size", new int(3));
+  argument.Set<int>("workspace_size", new int(1 << 20));
+  argument.Set<std::string>("precision_mode", new std::string("FP32"));
   argument.fluid_model_dir.reset(new std::string(FLAGS_inference_model_dir));
   Analyzer analyser;
   analyser.Run(&argument);
......
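Note: the new form in this hunk hands freshly allocated values straight to Argument::Set, so the attribute map is expected to take ownership of the raw pointers and free them later. The snippet below is a minimal, self-contained sketch of such an ownership-taking key/value store; the class name AttrStore and its internals are hypothetical and only mirror the call pattern shown in the diff, they are not PaddlePaddle's actual Argument implementation.

#include <cassert>
#include <map>
#include <memory>
#include <string>

// Hypothetical sketch: a type-erased attribute store whose Set<T> takes
// ownership of a heap-allocated value, matching calls such as
// argument.Set<int>("max_batch_size", new int(3)).
class AttrStore {
 public:
  template <typename T>
  void Set(const std::string &key, T *value) {
    // Wrap the raw pointer immediately; the typed deleter frees it when the
    // store (or a replacement value) goes away, so the inline `new` cannot leak.
    attrs_[key] = std::shared_ptr<void>(
        value, [](void *p) { delete static_cast<T *>(p); });
  }

  template <typename T>
  T &Get(const std::string &key) const {
    auto it = attrs_.find(key);
    assert(it != attrs_.end() && "attribute was never set");
    return *static_cast<T *>(it->second.get());
  }

 private:
  std::map<std::string, std::shared_ptr<void>> attrs_;
};

int main() {
  AttrStore argument;
  argument.Set<int>("max_batch_size", new int(3));
  argument.Set<std::string>("precision_mode", new std::string("FP32"));
  assert(argument.Get<int>("max_batch_size") == 3);
  assert(argument.Get<std::string>("precision_mode") == "FP32");
  return 0;
}

Because the stored value is type-erased, a later Get<T> has to use the same T that was used at Set time; that is why the test and pass hunks below consistently store "max_batch_size" and "workspace_size" as int and "precision_mode" as std::string.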
......
@@ -99,6 +99,7 @@ void DataFlowGraphToFluidPass::AddFluidOp(Node *node) {
 void CreateTrtEngineOp(Node *node, Argument *argument,
                        framework::proto::BlockDesc *block) {
   PADDLE_ENFORCE(argument->main_dfg.get());
+  const DataFlowGraph &graph = *(argument->main_dfg);
   static int counter{0};
   PADDLE_ENFORCE(node->IsFunctionBlock());
......
......
@@ -67,8 +67,7 @@ TEST(SubGraphSplitter, Fuse) {
   auto desc = LoadProgramDesc(FLAGS_inference_model_dir + "/__model__");
   auto dfg = ProgramDescToDFG(desc);
   Argument argument;
-  int* minmum_subgraph_size = new int(3);
-  argument.Set<int>("minimum_subgraph_size", minmum_subgraph_size);
+  argument.Set<int>("minimum_subgraph_size", new int(3));
   size_t count0 = dfg.nodes.size();
......
......
@@ -36,14 +36,10 @@ TEST(TensorRTSubGraphPass, main) {
   };
   Argument argument(FLAGS_inference_model_dir);
-  int* minimum_subgraph_size = new int(0);
-  int* max_batch_size = new int(3);
-  int* workspace_size = new int(1 << 20);
-  std::string* precision_mode = new std::string("FP32");
-  argument.Set<int>("minimum_subgraph_size", minimum_subgraph_size);
-  argument.Set<int>("max_batch_size", max_batch_size);
-  argument.Set<int>("workspace_size", workspace_size);
-  argument.Set<std::string>("precision_mode", precision_mode);
+  argument.Set<int>("minimum_subgraph_size", new int(0));
+  argument.Set<int>("max_batch_size", new int(3));
+  argument.Set<int>("workspace_size", new int(1 << 20));
+  argument.Set<std::string>("precision_mode", new std::string("FP32"));
   DFG_GraphvizDrawPass::Config config{FLAGS_dot_dir, "origin"};
   DFG_GraphvizDrawPass::Config config1{FLAGS_dot_dir, "fusion"};
......
......
@@ -90,14 +90,12 @@ class TensorRTSubgraphPredictor : public NativePaddlePredictor {
     // Analyze inference_program
     Argument argument;
-    int* minimum_subgraph_size = new int(config_.minimum_subgraph_size);
-    int* max_batch_size = new int(config_.max_batch_size);
-    int* workspace_size = new int(config_.workspace_size);
-    std::string* precision_mode = new std::string(config_.precision_mode);
-    argument.Set<int>("minimum_subgraph_size", minimum_subgraph_size);
-    argument.Set<int>("max_batch_size", max_batch_size);
-    argument.Set<int>("workspace_size", workspace_size);
-    argument.Set<std::string>("precision_mode", precision_mode);
+    argument.Set<int>("minimum_subgraph_size",
+                      new int(config_.minimum_subgraph_size));
+    argument.Set<int>("max_batch_size", new int(config_.max_batch_size));
+    argument.Set<int>("workspace_size", new int(config_.workspace_size));
+    argument.Set<std::string>("precision_mode",
+                              new std::string(config_.precision_mode));
     if (!config_.model_dir.empty()) {
       argument.fluid_model_dir.reset(new std::string(config_.model_dir));
......
......
@@ -153,8 +153,7 @@ struct TensorRTConfig : public NativeConfig {
   // We transform the Ops that can be converted into TRT layer in the model,
   // and aggregate these Ops into subgraphs for TRT execution.
   // We set this variable to control the minimum number of nodes in the
-  // subgraph, 3 as
-  // default value.
+  // subgraph, 3 as default value.
   int minimum_subgraph_size = 3;
   // Reserved configuration
   // We just support "FP32" now, "FP16" and "INT8" will be supported.
......
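Note: the TensorRTConfig fields touched above are the knobs a user fills before building the TensorRT-enabled predictor; the TensorRTSubgraphPredictor hunk earlier shows how the predictor then copies them into the analysis Argument. The sketch below shows roughly how the fields might be set; the include path, namespace, and the final predictor-creation call are assumptions, only the field names and defaults come from the diff.

// Rough usage sketch for the fields shown above. The include path is an
// assumption; the field names (model_dir, minimum_subgraph_size,
// max_batch_size, workspace_size, precision_mode) come from this commit.
#include <string>

#include "paddle/contrib/inference/paddle_inference_api.h"  // assumed path

int main() {
  paddle::TensorRTConfig config;
  config.model_dir = "./mobilenet";   // hypothetical model directory
  config.minimum_subgraph_size = 3;   // skip fusing subgraphs smaller than this
  config.max_batch_size = 3;          // largest batch the TRT engine must handle
  config.workspace_size = 1 << 20;    // 1 MiB of TensorRT workspace memory
  config.precision_mode = "FP32";     // only "FP32" is supported for now
  // The config would then be handed to the TensorRT-enabled predictor,
  // e.g. through CreatePaddlePredictor; the exact factory signature is
  // deliberately omitted here.
  return 0;
}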
......
@@ -33,7 +33,7 @@ class TensorRTEngineOpMaker : public framework::OpProtoAndCheckerMaker {
     AddAttr<std::string>("subgraph", "the subgraph.");
     AddAttr<std::string>("engine_uniq_key", "unique key for the TRT engine.");
     AddAttr<int>("max_batch_size", "the maximum batch size.");
-    AddAttr<int>("workspace_size", "the maximum batch size.");
+    AddAttr<int>("workspace_size", "the workspace size.");
     AddComment("TensorRT engine operator.");
   }
 };
......
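Note: the attributes declared by TensorRTEngineOpMaker above (subgraph, engine_uniq_key, max_batch_size, workspace_size) are what the analysis pass has to fill in on the generated engine op. Purely as an illustration, a hypothetical helper that populates such an op description might look like the following; the op type string, include path, and the framework::OpDesc calls are assumptions, not part of this diff.

// Hypothetical sketch only: op type name, header path, and values are
// assumptions; the attribute names match TensorRTEngineOpMaker above.
#include <string>

#include "paddle/fluid/framework/op_desc.h"  // assumed path

void FillTrtEngineOpDesc(paddle::framework::OpDesc *desc,
                         const std::string &serialized_subgraph,
                         const std::string &engine_key) {
  desc->SetType("tensorrt_engine");                // assumed op type name
  desc->SetAttr("subgraph", serialized_subgraph);  // serialized block to run in TRT
  desc->SetAttr("engine_uniq_key", engine_key);    // cache key for the built engine
  desc->SetAttr("max_batch_size", 3);              // must cover the largest batch fed in
  desc->SetAttr("workspace_size", 1 << 20);        // scratch memory handed to TensorRT
}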