Commit 59db5eff authored by nhzlx

Add TensorRT unit tests and refine the interface.

test=release/1.0.0
Parent 644bad1d
@@ -37,12 +37,16 @@ TEST(Analyzer, analysis_without_tensorrt) {
TEST(Analyzer, analysis_with_tensorrt) {
FLAGS_IA_enable_tensorrt_subgraph_engine = true;
Argument argument;
argument.Set<int>("minimum_subgraph_size", new int(0));
argument.Set<int>("max_batch_size", new int(3));
argument.Set<int>("workspace_size", new int(1 << 20));
argument.Set<std::string>("precision_mode", new std::string("FP32"));
argument.fluid_model_dir.reset(new std::string(FLAGS_inference_model_dir));
Analyzer analyser;
analyser.Run(&argument);
}
void TestWord2vecPrediction(const std::string& model_path) {
NativeConfig config;
config.model_dir = model_path;
config.use_gpu = false;
@@ -73,8 +77,8 @@ void TestWord2vecPrediction(const std::string &model_path) {
// The outputs' buffers are in CPU memory.
for (size_t i = 0; i < std::min(5UL, num_elements); i++) {
LOG(INFO) << "data: "
<< static_cast<float*>(outputs.front().data.data())[i];
PADDLE_ENFORCE(static_cast<float*>(outputs.front().data.data())[i],
result[i]);
}
}
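With the refined interface, the Analyzer no longer reads the TensorRT limits from global gflags; the analysis_with_tensorrt test above stashes them in the Argument key/value store before running the passes. A minimal sketch of that calling pattern outside the test harness, assuming only the Set<T>/fluid_model_dir interface shown above (the helper name and the concrete values are illustrative):

```cpp
#include <string>

#include "paddle/fluid/inference/analysis/analyzer.h"
#include "paddle/fluid/inference/analysis/argument.h"

namespace paddle {
namespace inference {
namespace analysis {

// Hypothetical helper: prepare an Argument with the TensorRT settings the
// subgraph passes now expect, then run the full analysis pipeline.
void RunAnalysisWithTrt(const std::string &model_dir) {
  Argument argument;
  // The Argument store takes ownership of the raw pointers, mirroring the
  // usage in the test above.
  argument.Set<int>("minimum_subgraph_size", new int(3));
  argument.Set<int>("max_batch_size", new int(8));
  argument.Set<int>("workspace_size", new int(1 << 20));
  argument.Set<std::string>("precision_mode", new std::string("FP32"));
  argument.fluid_model_dir.reset(new std::string(model_dir));

  // FLAGS_IA_enable_tensorrt_subgraph_engine must also be switched on, as in
  // the test, for the TensorRT subgraph pass to run.
  Analyzer analyzer;
  analyzer.Run(&argument);
}

}  // namespace analysis
}  // namespace inference
}  // namespace paddle
```

Any pass that needs these values can later pull them back with argument->Get<int>(...), as CreateTrtEngineOp and SubGraphFuse do below.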
......
@@ -97,8 +97,10 @@ void DataFlowGraphToFluidPass::AddFluidOp(Node *node) {
}
}
void CreateTrtEngineOp(Node *node, Argument *argument,
framework::proto::BlockDesc *block) {
PADDLE_ENFORCE(argument->main_dfg.get());
const DataFlowGraph &graph = *(argument->main_dfg);
static int counter{0};
PADDLE_ENFORCE(node->IsFunctionBlock());
framework::OpDesc desc;
@@ -204,7 +206,10 @@ void CreateTrtEngineOp(Node *node, const DataFlowGraph &graph,
PADDLE_ENFORCE(!block->vars().empty(), "the block has no var-desc");
// Set attrs
SetAttr(desc.Proto(), "subgraph", block->SerializeAsString());
SetAttr(desc.Proto(), "max_batch_size", argument->Get<int>("max_batch_size"));
SetAttr(desc.Proto(), "workspace_size", argument->Get<int>("workspace_size"));
SetAttr(desc.Proto(), "engine_uniq_key", "trt-" + std::to_string(counter++)); SetAttr(desc.Proto(), "engine_uniq_key", "trt-" + std::to_string(counter++));
SetAttr(desc.Proto(), "parameters", ExtractParameters(graph.nodes.nodes())); SetAttr(desc.Proto(), "parameters", ExtractParameters(graph.nodes.nodes()));
SetAttr(desc.Proto(), "output_name_mapping", output_mapping); SetAttr(desc.Proto(), "output_name_mapping", output_mapping);
...@@ -248,7 +253,7 @@ void DataFlowGraphToFluidPass::AddEngineOp(Node *node) { ...@@ -248,7 +253,7 @@ void DataFlowGraphToFluidPass::AddEngineOp(Node *node) {
*block_desc.Proto()->mutable_vars() = *block_desc.Proto()->mutable_vars() =
argument_->origin_program_desc->blocks(0).vars(); argument_->origin_program_desc->blocks(0).vars();
PADDLE_ENFORCE(!block_desc.Proto()->vars().empty()); PADDLE_ENFORCE(!block_desc.Proto()->vars().empty());
CreateTrtEngineOp(node, *argument_->main_dfg, block_desc.Proto()); CreateTrtEngineOp(node, argument_, block_desc.Proto());
auto *main_block = desc_->mutable_blocks(framework::kRootBlockIndex); auto *main_block = desc_->mutable_blocks(framework::kRootBlockIndex);
auto *op = main_block->add_ops(); auto *op = main_block->add_ops();
PADDLE_ENFORCE(!node->pb_msg().empty(), "failed to set desc for block"); PADDLE_ENFORCE(!node->pb_msg().empty(), "failed to set desc for block");
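The effect of this change is that the generated tensorrt_engine op now carries max_batch_size and workspace_size as op attributes copied out of the Argument, instead of the kernel consulting global flags. A condensed sketch of that hand-off, assuming only the SetAttr and Argument::Get<int> calls used above (the standalone function and the helper.h location of SetAttr are assumptions; the real logic lives inside CreateTrtEngineOp):

```cpp
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/inference/analysis/argument.h"
#include "paddle/fluid/inference/analysis/helper.h"  // assumed location of SetAttr

namespace paddle {
namespace inference {
namespace analysis {

// Illustrative only: copy the TensorRT limits stored on the Argument onto the
// generated engine op as attributes.
void CopyTrtSettingsToOp(Argument *argument, framework::proto::OpDesc *op) {
  // Values placed on the Argument by the caller...
  int max_batch_size = argument->Get<int>("max_batch_size");
  int workspace_size = argument->Get<int>("workspace_size");
  // ...become op attributes, read back at run time through
  // context.Attr<int>("max_batch_size") and context.Attr<int>("workspace_size").
  SetAttr(op, "max_batch_size", max_batch_size);
  SetAttr(op, "workspace_size", workspace_size);
}

}  // namespace analysis
}  // namespace inference
}  // namespace paddle
```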
......
@@ -309,6 +309,8 @@ void SubGraphFuse::operator()() { ReplaceNodesWithSubGraphs(); }
void SubGraphFuse::ReplaceNodesWithSubGraphs() {
auto subgraphs = SubGraphSplitter(graph_, node_inside_subgraph_teller_)();
for (auto &subgraph : subgraphs) {
if (subgraph.size() <= argument_->Get<int>("minimum_subgraph_size"))
continue;
std::unordered_set<Node *> subgraph_uniq(subgraph.begin(), subgraph.end());
// replace this sub-graph with the first node. Two steps: 1. Create a Block
// Node that contains this subgraph 2. Mark the nodes inside the sub-graph
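The early continue added above is the only behavioural change in this pass: candidate subgraphs whose node count does not exceed minimum_subgraph_size are left untouched as ordinary fluid ops. A standalone illustration of that gating rule in plain C++ (not Paddle code), included only to make the <= comparison explicit:

```cpp
#include <cstddef>
#include <vector>

// Keep only the candidate subgraphs strictly larger than minimum_subgraph_size;
// smaller ones, and ones of exactly that size, are skipped, matching the check
// in ReplaceNodesWithSubGraphs.
template <typename Node>
std::vector<std::vector<Node *>> FilterFusableSubgraphs(
    const std::vector<std::vector<Node *>> &candidates,
    int minimum_subgraph_size) {
  std::vector<std::vector<Node *>> kept;
  for (const auto &subgraph : candidates) {
    if (subgraph.size() <= static_cast<size_t>(minimum_subgraph_size)) continue;
    kept.push_back(subgraph);
  }
  return kept;
}
```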
......
@@ -20,6 +20,7 @@ limitations under the License. */
#include <vector>
#include "paddle/fluid/inference/analysis/argument.h"
#include "paddle/fluid/inference/analysis/data_flow_graph.h" #include "paddle/fluid/inference/analysis/data_flow_graph.h"
#include "paddle/fluid/inference/analysis/node.h" #include "paddle/fluid/inference/analysis/node.h"
...@@ -63,8 +64,11 @@ class SubGraphFuse { ...@@ -63,8 +64,11 @@ class SubGraphFuse {
public: public:
using NodeInsideSubgraphTeller = SubGraphSplitter::NodeInsideSubgraphTeller; using NodeInsideSubgraphTeller = SubGraphSplitter::NodeInsideSubgraphTeller;
SubGraphFuse(DataFlowGraph *graph, const NodeInsideSubgraphTeller &teller) SubGraphFuse(DataFlowGraph *graph, const NodeInsideSubgraphTeller &teller,
: graph_(graph), node_inside_subgraph_teller_(teller) {} Argument *argument)
: graph_(graph),
node_inside_subgraph_teller_(teller),
argument_(argument) {}
// The main method which run all the logic.
void operator()();
@@ -76,6 +80,7 @@ class SubGraphFuse {
private:
DataFlowGraph *graph_;
NodeInsideSubgraphTeller node_inside_subgraph_teller_;
Argument *argument_;
};
} // namespace analysis
......
@@ -66,10 +66,12 @@ TEST(SubGraphSplitter, Split) {
TEST(SubGraphSplitter, Fuse) {
auto desc = LoadProgramDesc(FLAGS_inference_model_dir + "/__model__");
auto dfg = ProgramDescToDFG(desc);
Argument argument;
argument.Set<int>("minimum_subgraph_size", new int(3));
size_t count0 = dfg.nodes.size();
SubGraphFuse fuse(&dfg, teller, &argument);
fuse();
int count1 = 0;
......
@@ -24,7 +24,7 @@ TensorRTSubGraphPass::TensorRTSubGraphPass(
: node_inside_subgraph_teller_(teller) {}
void TensorRTSubGraphPass::Run(DataFlowGraph *graph) {
SubGraphFuse(graph, node_inside_subgraph_teller_, argument_)();
VLOG(4) << "debug info "
<< graph->HumanReadableInfo(false /*show_values*/,
true /*show_functions*/);
......
@@ -33,7 +33,10 @@ class TensorRTSubGraphPass : public DataFlowGraphPass {
explicit TensorRTSubGraphPass(const NodeInsideSubgraphTeller& teller);
bool Initialize(Argument* argument) override {
argument_ = argument;
return true;
}
// This class get a sub-graph as input and determine whether to transform this
// sub-graph into TensorRT.
@@ -46,6 +49,7 @@ class TensorRTSubGraphPass : public DataFlowGraphPass {
private:
NodeInsideSubgraphTeller node_inside_subgraph_teller_;
Argument* argument_;
};
} // namespace analysis
......
@@ -36,6 +36,10 @@ TEST(TensorRTSubGraphPass, main) {
};
Argument argument(FLAGS_inference_model_dir);
argument.Set<int>("minimum_subgraph_size", new int(0));
argument.Set<int>("max_batch_size", new int(3));
argument.Set<int>("workspace_size", new int(1 << 20));
argument.Set<std::string>("precision_mode", new std::string("FP32"));
DFG_GraphvizDrawPass::Config config{FLAGS_dot_dir, "origin"};
DFG_GraphvizDrawPass::Config config1{FLAGS_dot_dir, "fusion"};
......
@@ -35,8 +35,6 @@ class TensorRTSubgraphPredictor : public NativePaddlePredictor {
bool Init(const std::shared_ptr<framework::Scope>& parent_scope) {
FLAGS_IA_enable_tensorrt_subgraph_engine = true;
VLOG(3) << "Predictor::init()";
FLAGS_tensorrt_max_batch_size = config_.max_batch_size;
FLAGS_tensorrt_workspace_size = config_.workspace_size;
if (config_.use_gpu) {
place_ = paddle::platform::CUDAPlace(config_.device);
} else {
@@ -92,6 +90,14 @@ class TensorRTSubgraphPredictor : public NativePaddlePredictor {
void OptimizeInferenceProgram() {
// Analyze inference_program
Argument argument;
argument.Set<int>("minimum_subgraph_size",
new int(config_.minimum_subgraph_size));
argument.Set<int>("max_batch_size", new int(config_.max_batch_size));
argument.Set<int>("workspace_size", new int(config_.workspace_size));
argument.Set<std::string>("precision_mode",
new std::string(config_.precision_mode));
if (!config_.model_dir.empty()) {
argument.fluid_model_dir.reset(new std::string(config_.model_dir));
} else {
......
@@ -194,6 +194,14 @@ struct MixedRTConfig : public NativeConfig {
// For workspace_size, refer it from here:
// https://docs.nvidia.com/deeplearning/sdk/tensorrt-developer-guide/index.html#troubleshooting
int workspace_size{1 << 30};
// Ops that can be converted to TensorRT layers are aggregated into
// subgraphs that run on the TRT engine. This variable controls the
// minimum number of nodes such a subgraph must contain; the default is 3.
int minimum_subgraph_size = 3;
// Reserved configuration.
// Only "FP32" is supported now; "FP16" and "INT8" will be supported later.
std::string precision_mode = "FP32";
};
// NOTE WIP, not stable yet.
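From the API side, the two new knobs sit next to the existing TensorRT ones on MixedRTConfig. A sketch of a typical configuration, following the predictor creation pattern used in trt_models_tester.cc added by this commit (the helper function and concrete values are illustrative):

```cpp
#include <memory>
#include <string>

#include "paddle/fluid/inference/api/paddle_inference_api.h"

// Hypothetical helper building a mixed TensorRT predictor with the new fields.
std::unique_ptr<paddle::PaddlePredictor> MakeTrtPredictor(
    const std::string &model_dir) {
  paddle::contrib::MixedRTConfig config;
  config.model_dir = model_dir;
  config.use_gpu = true;
  config.device = 0;
  config.fraction_of_gpu_memory = 0.2;
  config.max_batch_size = 4;
  config.workspace_size = 1 << 20;   // TRT builder workspace, in bytes
  config.minimum_subgraph_size = 3;  // leave tiny subgraphs to fluid
  config.precision_mode = "FP32";    // only FP32 is supported for now
  return paddle::CreatePaddlePredictor<paddle::contrib::MixedRTConfig,
                                       paddle::PaddleEngineKind::kAutoMixedTensorRT>(
      config);
}
```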
......
@@ -85,3 +85,13 @@ if (WITH_ANAKIN AND WITH_MKL) # only needed in CI
DEPS inference_anakin_api_shared dynload_cuda SERIAL)
endif()
endif()
if(WITH_GPU AND TENSORRT_FOUND)
set(TRT_MODEL_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/trt")
if (NOT EXISTS ${TRT_MODEL_INSTALL_DIR})
inference_download_and_uncompress(${TRT_MODEL_INSTALL_DIR} ${INFERENCE_URL}/tensorrt_test "trt_test_models.tar.gz")
endif()
cc_test(test_trt_models SRCS trt_models_tester.cc
ARGS --dirname=${TRT_MODEL_INSTALL_DIR}/trt_test_models
DEPS paddle_inference_tensorrt_subgraph_engine)
endif()
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gflags/gflags.h>
#include <glog/logging.h>
#include <gtest/gtest.h>
#include "paddle/fluid/inference/analysis/analyzer.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h"
namespace paddle {
using paddle::contrib::MixedRTConfig;
DEFINE_string(dirname, "", "Directory of the inference model.");
NativeConfig GetConfigNative() {
NativeConfig config;
config.model_dir = FLAGS_dirname;
// LOG(INFO) << "dirname " << config.model_dir;
config.fraction_of_gpu_memory = 0.45;
config.use_gpu = true;
config.device = 0;
return config;
}
MixedRTConfig GetConfigTRT() {
MixedRTConfig config;
config.model_dir = FLAGS_dirname;
config.use_gpu = true;
config.fraction_of_gpu_memory = 0.2;
config.device = 0;
config.max_batch_size = 3;
return config;
}
void CompareTensorRTWithFluid(int batch_size, std::string model_dirname) {
NativeConfig config0 = GetConfigNative();
config0.model_dir = model_dirname;
MixedRTConfig config1 = GetConfigTRT();
config1.model_dir = model_dirname;
config1.max_batch_size = batch_size;
auto predictor0 =
CreatePaddlePredictor<NativeConfig, PaddleEngineKind::kNative>(config0);
auto predictor1 =
CreatePaddlePredictor<MixedRTConfig,
PaddleEngineKind::kAutoMixedTensorRT>(config1);
// Prepare inputs
int height = 224;
int width = 224;
float *data = new float[batch_size * 3 * height * width];
memset(data, 0, sizeof(float) * (batch_size * 3 * height * width));
data[0] = 1.0f;
// Prepare inputs
PaddleTensor tensor;
tensor.name = "input_0";
tensor.shape = std::vector<int>({batch_size, 3, height, width});
tensor.data = PaddleBuf(static_cast<void *>(data),
sizeof(float) * (batch_size * 3 * height * width));
tensor.dtype = PaddleDType::FLOAT32;
std::vector<PaddleTensor> paddle_tensor_feeds(1, tensor);
// Prepare outputs
std::vector<PaddleTensor> outputs0;
std::vector<PaddleTensor> outputs1;
CHECK(predictor0->Run(paddle_tensor_feeds, &outputs0));
CHECK(predictor1->Run(paddle_tensor_feeds, &outputs1, batch_size));
// Get output.
ASSERT_EQ(outputs0.size(), 1UL);
ASSERT_EQ(outputs1.size(), 1UL);
const size_t num_elements = outputs0.front().data.length() / sizeof(float);
const size_t num_elements1 = outputs1.front().data.length() / sizeof(float);
EXPECT_EQ(num_elements, num_elements1);
auto *data0 = static_cast<float *>(outputs0.front().data.data());
auto *data1 = static_cast<float *>(outputs1.front().data.data());
ASSERT_GT(num_elements, 0UL);
for (size_t i = 0; i < std::min(num_elements, num_elements1); i++) {
EXPECT_NEAR(data0[i], data1[i], 1e-3);
}
}
TEST(trt_models_test, main) {
std::vector<std::string> infer_models = {"mobilenet", "resnet50",
"resnext50"};
for (auto &model_dir : infer_models) {
CompareTensorRTWithFluid(1, FLAGS_dirname + "/" + model_dir);
}
}
} // namespace paddle
@@ -22,8 +22,6 @@
namespace paddle {
DEFINE_int32(tensorrt_engine_batch_size, 1, "the batch_size of TensorRT");
DEFINE_int32(tensorrt_max_batch_size, 1, "TensorRT maximum batch size");
DEFINE_int32(tensorrt_workspace_size, 16 << 20, "TensorRT workspace size");
namespace operators {
@@ -34,6 +32,8 @@ class TensorRTEngineOpMaker : public framework::OpProtoAndCheckerMaker {
AddOutput("Ys", "A list of outputs").AsDuplicable();
AddAttr<std::string>("subgraph", "the subgraph.");
AddAttr<std::string>("engine_uniq_key", "unique key for the TRT engine.");
AddAttr<int>("max_batch_size", "the maximum batch size.");
AddAttr<int>("workspace_size", "the workspace size.");
AddComment("TensorRT engine operator.");
}
};
......
@@ -28,8 +28,6 @@
namespace paddle {
DECLARE_int32(tensorrt_engine_batch_size);
DECLARE_int32(tensorrt_max_batch_size);
DECLARE_int32(tensorrt_workspace_size);
namespace operators {
@@ -92,14 +90,14 @@ class TensorRTEngineKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto engine_name = context.Attr<std::string>("engine_uniq_key");
int max_batch_size = context.Attr<int>("max_batch_size");
if (!Singleton<TRT_EngineManager>::Global().HasEngine(engine_name)) {
Prepare(context);
}
auto* engine = Singleton<TRT_EngineManager>::Global().Get(engine_name);
auto input_names = context.op().Inputs("Xs");
PADDLE_ENFORCE(!input_names.empty(), "should pass more than one inputs");
PADDLE_ENFORCE_LE(FLAGS_tensorrt_engine_batch_size, max_batch_size);
std::vector<std::string> output_maps =
context.Attr<std::vector<std::string>>("output_name_mapping");
@@ -173,8 +171,9 @@ class TensorRTEngineKernel : public framework::OpKernel<T> {
// Get the ProgramDesc and pass to convert.
framework::proto::BlockDesc block_desc;
block_desc.ParseFromString(context.Attr<std::string>("subgraph"));
int max_batch_size = context.Attr<int>("max_batch_size");
int workspace_size = context.Attr<int>("workspace_size");
auto params = context.Attr<std::vector<std::string>>("parameters");
std::unordered_set<std::string> parameters;
for (const auto& param : params) {
@@ -186,7 +185,7 @@ class TensorRTEngineKernel : public framework::OpKernel<T> {
// TODO(Superjomn) replace this with a different stream
auto* engine = Singleton<TRT_EngineManager>::Global().Create(
max_batch_size, workspace_size, nullptr /*engine hold its own stream*/,
context.Attr<std::string>("engine_uniq_key"),
boost::get<platform::CUDAPlace>(context.GetPlace()).device);
......
@@ -58,8 +58,6 @@ void AddTensorToBlockDesc(framework::proto::BlockDesc* block,
using inference::analysis::SetAttr;
TEST(TensorRTEngineOp, manual) {
FLAGS_tensorrt_engine_batch_size = 2;
FLAGS_tensorrt_max_batch_size = 2;
framework::ProgramDesc program;
auto* block_ = program.Proto()->add_blocks();
block_->set_idx(0);
@@ -101,6 +99,8 @@ TEST(TensorRTEngineOp, manual) {
engine_op_desc.SetOutput("Ys", std::vector<std::string>({"z0"}));
SetAttr<std::string>(engine_op_desc.Proto(), "subgraph",
block_->SerializeAsString());
SetAttr<int>(engine_op_desc.Proto(), "max_batch_size", 2);
SetAttr<int>(engine_op_desc.Proto(), "workspace_size", 2 << 10);
SetAttr<std::string>(engine_op_desc.Proto(), "engine_uniq_key", "a_engine"); SetAttr<std::string>(engine_op_desc.Proto(), "engine_uniq_key", "a_engine");
SetAttr<std::vector<std::string>>(engine_op_desc.Proto(), "parameters", SetAttr<std::vector<std::string>>(engine_op_desc.Proto(), "parameters",
std::vector<std::string>({})); std::vector<std::string>({}));
...@@ -129,8 +129,6 @@ TEST(TensorRTEngineOp, manual) { ...@@ -129,8 +129,6 @@ TEST(TensorRTEngineOp, manual) {
} }
void Execute(int batch_size, int input_dim, int output_dim, int nlayers = 1) { void Execute(int batch_size, int input_dim, int output_dim, int nlayers = 1) {
FLAGS_tensorrt_engine_batch_size = batch_size;
FLAGS_tensorrt_max_batch_size = batch_size;
framework::ProgramDesc program;
framework::Scope scope;
platform::CUDAPlace place;
@@ -195,8 +193,8 @@ void Execute(int batch_size, int input_dim, int output_dim, int nlayers = 1) {
SetAttr<std::string>(engine_op_desc.Proto(), "subgraph",
block_->SerializeAsString());
SetAttr<int>(engine_op_desc.Proto(), "max_batch_size", batch_size);
SetAttr<int>(engine_op_desc.Proto(), "workspace_size", 2 << 10);
SetAttr<std::vector<std::string>>(
engine_op_desc.Proto(), "parameters",
std::vector<std::string>({"y0", "y1", "y2", "y3"}));
......