From 8c171902798a9325e0efe01e81c7d6c44ad7119f Mon Sep 17 00:00:00 2001 From: nhzlx Date: Thu, 14 Feb 2019 09:10:41 +0000 Subject: [PATCH] 2. TRTEngine using stream only when execute. --- .../inference/tensorrt/convert/ut_helper.h | 6 ++-- paddle/fluid/inference/tensorrt/engine.cc | 33 +++--------------- paddle/fluid/inference/tensorrt/engine.h | 21 +++++------- .../fluid/inference/tensorrt/test_engine.cc | 10 +++--- .../operators/tensorrt/tensorrt_engine_op.h | 34 +++++++------------ 5 files changed, 31 insertions(+), 73 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/convert/ut_helper.h b/paddle/fluid/inference/tensorrt/convert/ut_helper.h index 3298a103a2..c02a6d8da3 100644 --- a/paddle/fluid/inference/tensorrt/convert/ut_helper.h +++ b/paddle/fluid/inference/tensorrt/convert/ut_helper.h @@ -79,7 +79,7 @@ class TRTConvertValidation { if_add_batch_(if_add_batch), max_batch_size_(max_batch_size) { PADDLE_ENFORCE_EQ(cudaStreamCreate(&stream_), 0); - engine_.reset(new TensorRTEngine(max_batch_size, workspace_size, stream_)); + engine_.reset(new TensorRTEngine(max_batch_size, workspace_size)); engine_->InitNetwork(); } @@ -192,9 +192,7 @@ class TRTConvertValidation { } // Execute TRT. - engine_->Execute(batch_size, buffers); - - cudaStreamSynchronize(engine_->stream()); + engine_->Execute(batch_size, &buffers, stream_); ASSERT_FALSE(op_desc_->OutputArgumentNames().empty()); int index = 0; diff --git a/paddle/fluid/inference/tensorrt/engine.cc b/paddle/fluid/inference/tensorrt/engine.cc index 1d07b373da..805f047c96 100644 --- a/paddle/fluid/inference/tensorrt/engine.cc +++ b/paddle/fluid/inference/tensorrt/engine.cc @@ -32,39 +32,14 @@ void TensorRTEngine::Build(const DescType &paddle_model) { PADDLE_ENFORCE(false, "not implemented"); } -void TensorRTEngine::Execute(int batch_size, std::vector &buffers) { +void TensorRTEngine::Execute(int batch_size, std::vector *buffers, + cudaStream_t stream) { batch_size_ = batch_size; - infer_context_->enqueue(batch_size, buffers.data(), stream_, nullptr); - cudaStreamSynchronize(stream_); + infer_context_->enqueue(batch_size, buffers->data(), stream, nullptr); + cudaStreamSynchronize(stream); SetRuntimeBatch(batch_size); } -void TensorRTEngine::Execute(int batch_size) { - batch_size_ = batch_size; - std::vector buffers; - for (auto &buf : buffers_) { - PADDLE_ENFORCE_NOT_NULL(buf.buffer, "buffer should be allocated"); - PADDLE_ENFORCE_GT(buf.max_size, 0); - PADDLE_ENFORCE(buf.device == DeviceType::GPU); - buffers.push_back(buf.buffer); - } - infer_context_->enqueue(batch_size, buffers.data(), stream_, nullptr); - cudaStreamSynchronize(stream_); - SetRuntimeBatch(batch_size); -} - -TensorRTEngine::~TensorRTEngine() { - cudaStreamSynchronize(stream_); - // clean buffer - for (auto &buf : buffers_) { - if (buf.device == DeviceType::GPU && buf.buffer != nullptr) { - PADDLE_ENFORCE_EQ(0, cudaFree(buf.buffer)); - buf.buffer = nullptr; - buf.max_size = 0; - } - } -} - void TensorRTEngine::FreezeNetwork() { VLOG(3) << "TRT to freeze network"; PADDLE_ENFORCE(infer_builder_ != nullptr, diff --git a/paddle/fluid/inference/tensorrt/engine.h b/paddle/fluid/inference/tensorrt/engine.h index 3955983658..e1005e9b03 100644 --- a/paddle/fluid/inference/tensorrt/engine.h +++ b/paddle/fluid/inference/tensorrt/engine.h @@ -37,7 +37,9 @@ class TRTInt8Calibrator; * There are two alternative ways to use it, one is to build from a paddle * protobuf model, another way is to manully construct the network. */ -class TensorRTEngine : public EngineBase { +class TensorRTEngine { + using DescType = ::paddle::framework::proto::BlockDesc; + public: // Weight is model parameter. class Weight { @@ -56,24 +58,22 @@ class TensorRTEngine : public EngineBase { nvinfer1::Weights w_; }; - TensorRTEngine(int max_batch, int max_workspace, cudaStream_t stream, - bool enable_int8 = false, + TensorRTEngine(int max_batch, int max_workspace, bool enable_int8 = false, TRTInt8Calibrator* calibrator = nullptr, nvinfer1::ILogger& logger = NaiveLogger::Global()) : max_batch_(max_batch), max_workspace_(max_workspace), - stream_(stream), enable_int8_(enable_int8), calibrator_(calibrator), logger_(logger) {} - virtual ~TensorRTEngine(); + ~TensorRTEngine() {} // TODO(Superjomn) implement it later when graph segmentation is supported. - void Build(const DescType& paddle_model) override; + void Build(const DescType& paddle_model); - void Execute(int batch_size) override; - void Execute(int batch_size, std::vector& buffers); + void Execute(int batch_size, std::vector* buffers, + cudaStream_t stream); // Initialize the inference network, so that TensorRT layers can add to this // network. @@ -98,8 +98,6 @@ class TensorRTEngine : public EngineBase { // Check if the ITensor has been declared bool HasDeclared(const std::string& name); - cudaStream_t stream() { return stream_; } - void SetITensor(const std::string& name, nvinfer1::ITensor* tensor); // Get an ITensor called name. nvinfer1::ITensor* GetITensor(const std::string& name); @@ -127,8 +125,6 @@ class TensorRTEngine : public EngineBase { // the max memory size the engine uses int max_workspace_; - cudaStream_t stream_; - bool enable_int8_; TRTInt8Calibrator* calibrator_; // batch size of the current data, will be updated each Executation. @@ -136,7 +132,6 @@ class TensorRTEngine : public EngineBase { nvinfer1::ILogger& logger_; - std::vector buffers_; // max data size for the buffers. std::unordered_map buffer_sizes_; std::unordered_map diff --git a/paddle/fluid/inference/tensorrt/test_engine.cc b/paddle/fluid/inference/tensorrt/test_engine.cc index 961b24960b..784290fa44 100644 --- a/paddle/fluid/inference/tensorrt/test_engine.cc +++ b/paddle/fluid/inference/tensorrt/test_engine.cc @@ -31,7 +31,7 @@ class TensorRTEngineTest : public ::testing::Test { void SetUp() override { ctx_ = new platform::CUDADeviceContext(platform::CUDAPlace(0)); - engine_ = new TensorRTEngine(10, 1 << 10, ctx_->stream()); + engine_ = new TensorRTEngine(10, 1 << 10); engine_->InitNetwork(); } @@ -88,7 +88,7 @@ TEST_F(TensorRTEngineTest, add_layer) { buffers[1] = reinterpret_cast(y_gpu_data); LOG(INFO) << "to execute"; - engine_->Execute(1, buffers); + engine_->Execute(1, &buffers, ctx_->stream()); LOG(INFO) << "to get output"; GetOutput(&y_cpu); @@ -128,7 +128,7 @@ TEST_F(TensorRTEngineTest, add_layer_multi_dim) { buffers[0] = reinterpret_cast(x_v_gpu_data); buffers[1] = reinterpret_cast(y_gpu_data); - engine_->Execute(1, buffers); + engine_->Execute(1, &buffers, ctx_->stream()); LOG(INFO) << "to get output"; GetOutput(&y_cpu); @@ -175,7 +175,7 @@ TEST_F(TensorRTEngineTest, test_conv2d) { buffers[0] = reinterpret_cast(x_v_gpu_data); buffers[1] = reinterpret_cast(y_gpu_data); - engine_->Execute(2, buffers); + engine_->Execute(2, &buffers, ctx_->stream()); LOG(INFO) << "to get output"; GetOutput(&y_cpu); @@ -214,7 +214,7 @@ TEST_F(TensorRTEngineTest, test_pool2d) { buffers[0] = reinterpret_cast(x_v_gpu_data); buffers[1] = reinterpret_cast(y_gpu_data); - engine_->Execute(2, buffers); + engine_->Execute(2, &buffers, ctx_->stream()); LOG(INFO) << "to get output"; GetOutput(&y_cpu); diff --git a/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h b/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h index d3efea2812..33bbb6f165 100644 --- a/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h +++ b/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h @@ -142,10 +142,6 @@ class TensorRTEngineOp : public framework::OperatorBase { LOG_FIRST_N(INFO, 1) << "The TRT engine: " << engine_key_ << " is running calibration trt int8... "; int runtime_batch = 1; - platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); - auto &dev_ctx = *pool.Get(dev_place); - auto stream = - reinterpret_cast(dev_ctx).stream(); if (!Singleton::Global().Has(engine_key_)) { TRTCalibratorEngine *calib_res = Singleton::Global().Create(engine_key_); @@ -162,10 +158,10 @@ class TensorRTEngineOp : public framework::OperatorBase { calib_buffers, runtime_batch, engine_key_, dev_place)); calib_res->thr_.reset(new std::thread([&]() { calib_res->engine_.reset( - new TensorRTEngine(max_batch_size_, workspace_size_, stream, - enable_int8_, calib_res->calib_.get())); + new TensorRTEngine(max_batch_size_, workspace_size_, enable_int8_, + calib_res->calib_.get())); VLOG(3) << "start the calib trt engine thread"; - Prepare(scope, dev_place, calib_res->engine_.get()); + Prepare(scope, calib_res->engine_.get()); })); } @@ -253,22 +249,17 @@ class TensorRTEngineOp : public framework::OperatorBase { PADDLE_ENFORCE_LE(runtime_batch, max_batch_size_); // Execute the engine. - engine->Execute(runtime_batch, buffers); + engine->Execute(runtime_batch, &buffers, stream); cudaStreamSynchronize(stream); } TensorRTEngine *GetEngine(const framework::Scope &scope, const platform::Place &dev_place) const { - platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); - auto &dev_ctx = *pool.Get(dev_place); - auto stream = - reinterpret_cast(dev_ctx).stream(); if (trt_engine_.get() == nullptr) { trt_engine_.reset(new TensorRTEngine(max_batch_size_, workspace_size_, - stream, enable_int8_, - calibrator_.get())); + enable_int8_, calibrator_.get())); if (true) { - Prepare(scope, dev_place, trt_engine_.get()); + Prepare(scope, trt_engine_.get()); } else { // create static engine } @@ -276,20 +267,19 @@ class TensorRTEngineOp : public framework::OperatorBase { return trt_engine_.get(); } - void Prepare(const framework::Scope &scope, const platform::Place &dev_place, - TensorRTEngine *engine) const { + void Prepare(const framework::Scope &scope, TensorRTEngine *engine) const { LOG(INFO) << "Prepare TRT engine (Optimize model structure, Select OP " "kernel etc). This process may cost a lot of time."; framework::proto::BlockDesc block_desc; block_desc.ParseFromString(Attr("subgraph")); - - std::vector output_maps = - Attr>("output_name_mapping"); + framework::BlockDesc block(nullptr /*programdesc*/, &block_desc); engine->InitNetwork(); - framework::BlockDesc block(nullptr /*programdesc*/, &block_desc); VLOG(4) << "parsed var size " << block.AllVars().size(); + std::vector output_maps = + Attr>("output_name_mapping"); + // Add inputs VLOG(4) << "declare inputs"; for (auto &input : Inputs("Xs")) { @@ -306,12 +296,12 @@ class TensorRTEngineOp : public framework::OperatorBase { PADDLE_ENFORCE(var, "no variable called %s", input); PADDLE_ENFORCE_EQ(var->GetType(), FluidDT::VarType_Type_LOD_TENSOR, "TensorRT engine only takes LoDTensor as input"); - engine->DeclareInput( input, FluidDataType2TRT( var->Proto()->type().lod_tensor().tensor().data_type()), Vec2TRT_Dims(t_shape)); } + inference::Singleton::Global() .ConvertBlock(block_desc, param_names_, scope, engine); -- GitLab