提交 4801beb1 编写于 作者: N nhzlx

add arguments for trt config

上级 202e0a1e
......@@ -37,12 +37,20 @@ TEST(Analyzer, analysis_without_tensorrt) {
TEST(Analyzer, analysis_with_tensorrt) {
  // Route analysis through the TensorRT sub-graph engine for this test.
  FLAGS_IA_enable_tensorrt_subgraph_engine = true;
  Argument argument;
  // Populate the TRT configuration knobs the sub-graph passes read.
  // Values are heap-allocated; Argument::Set presumably takes ownership
  // (matches usage elsewhere in this commit — TODO confirm the contract).
  argument.Set<int>("minimum_subgraph_size", new int(0));
  argument.Set<int>("max_batch_size", new int(3));
  argument.Set<int>("workspace_size", new int(1 << 20));
  argument.Set<std::string>("precision_mode", new std::string("FP32"));
  argument.fluid_model_dir.reset(new std::string(FLAGS_inference_model_dir));
  Analyzer analyzer;
  analyzer.Run(&argument);
}
void TestWord2vecPrediction(const std::string &model_path) {
void TestWord2vecPrediction(const std::string& model_path) {
NativeConfig config;
config.model_dir = model_path;
config.use_gpu = false;
......@@ -73,8 +81,8 @@ void TestWord2vecPrediction(const std::string &model_path) {
// The outputs' buffers are in CPU memory.
for (size_t i = 0; i < std::min(5UL, num_elements); i++) {
LOG(INFO) << "data: "
<< static_cast<float *>(outputs.front().data.data())[i];
PADDLE_ENFORCE(static_cast<float *>(outputs.front().data.data())[i],
<< static_cast<float*>(outputs.front().data.data())[i];
PADDLE_ENFORCE(static_cast<float*>(outputs.front().data.data())[i],
result[i]);
}
}
......
......@@ -309,7 +309,7 @@ void SubGraphFuse::operator()() { ReplaceNodesWithSubGraphs(); }
void SubGraphFuse::ReplaceNodesWithSubGraphs() {
auto subgraphs = SubGraphSplitter(graph_, node_inside_subgraph_teller_)();
for (auto &subgraph : subgraphs) {
if (subgraph.size() <= argument_->Get<int>("minimun_subgraph_size"))
if (subgraph.size() <= argument_->Get<int>("minimum_subgraph_size"))
continue;
std::unordered_set<Node *> subgraph_uniq(subgraph.begin(), subgraph.end());
// replace this sub-graph with the first node. Two steps: 1. Create a Block
......
......@@ -68,7 +68,7 @@ TEST(SubGraphSplitter, Fuse) {
auto dfg = ProgramDescToDFG(desc);
Argument argument;
int* minmum_subgraph_size = new int(3);
argument.Set<int>("minmum_subgraph_size", minmum_subgraph_size);
argument.Set<int>("minimum_subgraph_size", minmum_subgraph_size);
size_t count0 = dfg.nodes.size();
......
......@@ -36,6 +36,14 @@ TEST(TensorRTSubGraphPass, main) {
};
Argument argument(FLAGS_inference_model_dir);
int* minimum_subgraph_size = new int(0);
int* max_batch_size = new int(3);
int* workspace_size = new int(1 << 20);
std::string* precision_mode = new std::string("FP32");
// Fix: the key must be spelled "minimum_subgraph_size" to match the
// argument_->Get<int>("minimum_subgraph_size") lookup performed by
// SubGraphFuse::ReplaceNodesWithSubGraphs; the previous spelling
// "minimun_subgraph_size" would never be found by that consumer.
argument.Set<int>("minimum_subgraph_size", minimum_subgraph_size);
argument.Set<int>("max_batch_size", max_batch_size);
argument.Set<int>("workspace_size", workspace_size);
argument.Set<std::string>("precision_mode", precision_mode);
DFG_GraphvizDrawPass::Config config{FLAGS_dot_dir, "origin"};
DFG_GraphvizDrawPass::Config config1{FLAGS_dot_dir, "fusion"};
......
......@@ -94,7 +94,7 @@ class TensorRTSubgraphPredictor : public NativePaddlePredictor {
int* max_batch_size = new int(config_.max_batch_size);
int* workspace_size = new int(config_.workspace_size);
std::string* precision_mode = new std::string(config_.precision_mode);
argument.Set<int>("minimun_subgraph_size", minimum_subgraph_size);
argument.Set<int>("minimum_subgraph_size", minimum_subgraph_size);
argument.Set<int>("max_batch_size", max_batch_size);
argument.Set<int>("workspace_size", workspace_size);
argument.Set<std::string>("precision_mode", precision_mode);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册