diff --git a/paddle/fluid/inference/tensorrt/engine.cc b/paddle/fluid/inference/tensorrt/engine.cc
index 596e0fe9da3d272ecb1c0f8dbef09a75d08a4b1a..5198da84a461a25c21987f38fde3df50f567959f 100644
--- a/paddle/fluid/inference/tensorrt/engine.cc
+++ b/paddle/fluid/inference/tensorrt/engine.cc
@@ -26,6 +26,8 @@ namespace paddle {
 namespace inference {
 namespace tensorrt {
 
+int TensorRTEngine::runtime_batch_ = 1;
+
 void TensorRTEngine::Build(const DescType& paddle_model) {
   PADDLE_ENFORCE(false, "not implemented");
 }
@@ -40,6 +42,7 @@ void TensorRTEngine::Execute(int batch_size) {
   }
   infer_context_->enqueue(batch_size, buffers.data(), *stream_, nullptr);
   cudaStreamSynchronize(*stream_);
+  SetRuntimeBatch(batch_size);
 }
 
 TensorRTEngine::~TensorRTEngine() {
@@ -76,14 +79,15 @@ void TensorRTEngine::FreezeNetwork() {
       auto dims = infer_engine_->getBindingDimensions(slot_offset);
       item.second = kDataTypeSize[static_cast<int>(
                         infer_engine_->getBindingDataType(slot_offset))] *
-                    analysis::AccuDims(dims.d, dims.nbDims);
+                    analysis::AccuDims(dims.d, dims.nbDims) * max_batch_;
     }
     auto& buf = buffer(item.first);
     CHECK(buf.buffer == nullptr);  // buffer should be allocated only once.
-    PADDLE_ENFORCE_EQ(0, cudaMalloc(&buf.buffer, item.second));
+    PADDLE_ENFORCE_EQ(0, cudaMalloc(&buf.buffer, item.second * max_batch_));
     VLOG(4) << "buffer malloc " << item.first << " " << item.second << " "
             << buf.buffer;
-    buf.size = buf.max_size = item.second;
+    buf.size = item.second;
+    buf.max_size = item.second * max_batch_;
     buf.device = DeviceType::GPU;
   }
 }
@@ -98,7 +102,7 @@ nvinfer1::ITensor* TensorRTEngine::DeclareInput(const std::string& name,
   auto* input = infer_network_->addInput(name.c_str(), dtype, dims);
   PADDLE_ENFORCE(input, "infer network add input %s failed", name);
   buffer_sizes_[name] = kDataTypeSize[static_cast<int>(dtype)] *
-                        analysis::AccuDims(dims.d, dims.nbDims);
+                        analysis::AccuDims(dims.d, dims.nbDims) * max_batch_;
   PADDLE_ENFORCE(input->isNetworkInput());
   TensorRTEngine::SetITensor(name, input);
   return input;
@@ -139,30 +143,40 @@ void* TensorRTEngine::GetOutputInGPU(const std::string& name) {
   return buffer(name).buffer;
 }
 
-void TensorRTEngine::GetOutputInGPU(const std::string& name, void* dst,
-                                    size_t max_size) {
+void TensorRTEngine::GetOutputInGPU(const std::string& name, void* dst) {
   // determine data size
+  auto* output = TensorRTEngine::GetITensor(name);
+  nvinfer1::Dims dims = output->getDimensions();
+  auto dim_size = analysis::AccuDims(dims.d, dims.nbDims);
+  size_t dst_size = dim_size * runtime_batch_ *
+                    kDataTypeSize[static_cast<int>(output->getType())];
+
   auto it = buffer_sizes_.find(name);
   PADDLE_ENFORCE(it != buffer_sizes_.end());
   PADDLE_ENFORCE_GT(it->second, 0);
-  PADDLE_ENFORCE_GE(max_size, it->second);
+  PADDLE_ENFORCE_LE(dst_size, it->second);
   auto& buf = buffer(name);
   PADDLE_ENFORCE_NOT_NULL(buf.buffer, "buffer should be allocated before");
-  PADDLE_ENFORCE_EQ(cudaMemcpyAsync(dst, buf.buffer, it->second,
+  PADDLE_ENFORCE_EQ(cudaMemcpyAsync(dst, buf.buffer, dst_size,
                                     cudaMemcpyDeviceToDevice, *stream_),
                     0);
 }
 
-void TensorRTEngine::GetOutputInCPU(const std::string& name, void* dst,
-                                    size_t max_size) {
+void TensorRTEngine::GetOutputInCPU(const std::string& name, void* dst) {
   // determine data size
+
+  auto* output = TensorRTEngine::GetITensor(name);
+  nvinfer1::Dims dims = output->getDimensions();
+  auto dim_size = analysis::AccuDims(dims.d, dims.nbDims);
+  size_t dst_size = dim_size * runtime_batch_ *
+                    kDataTypeSize[static_cast<int>(output->getType())];
   auto it = buffer_sizes_.find(name);
   PADDLE_ENFORCE(it != buffer_sizes_.end());
   PADDLE_ENFORCE_GT(it->second, 0);
-  PADDLE_ENFORCE_GE(max_size, it->second);
+  PADDLE_ENFORCE_LE(dst_size, it->second);
   auto& buf = buffer(name);
   PADDLE_ENFORCE_NOT_NULL(buf.buffer, "buffer should be allocated before");
-  PADDLE_ENFORCE_EQ(0, cudaMemcpyAsync(dst, buf.buffer, it->second,
+  PADDLE_ENFORCE_EQ(0, cudaMemcpyAsync(dst, buf.buffer, dst_size,
                                        cudaMemcpyDeviceToHost, *stream_));
 }
 
@@ -207,6 +221,12 @@ nvinfer1::ITensor* TensorRTEngine::GetITensor(const std::string& name) {
   return itensor_map_[name];
 }
 
+void TensorRTEngine::SetRuntimeBatch(size_t batch_size) {
+  runtime_batch_ = batch_size;
+}
+
+int TensorRTEngine::GetRuntimeBatch() { return runtime_batch_; }
+
 }  // namespace tensorrt
 }  // namespace inference
 }  // namespace paddle
diff --git a/paddle/fluid/inference/tensorrt/engine.h b/paddle/fluid/inference/tensorrt/engine.h
index b06a9bbc6758ae9410b2fce99ef2b1a9e7ab98c0..ed6d4fd11c462e9c039c7662d2f8a5fd6371f486 100644
--- a/paddle/fluid/inference/tensorrt/engine.h
+++ b/paddle/fluid/inference/tensorrt/engine.h
@@ -104,10 +104,10 @@ class TensorRTEngine : public EngineBase {
   // Return the output's GPU memory address without copy.
   void* GetOutputInGPU(const std::string& name);
   // Copy data into dst inside the GPU device.
-  void GetOutputInGPU(const std::string& name, void* dst, size_t max_size);
+  void GetOutputInGPU(const std::string& name, void* dst);
   // LOW EFFICENCY! Get output to CPU, this will trigger a memory copy from GPU
   // to CPU.
-  void GetOutputInCPU(const std::string& name, void* dst, size_t max_size);
+  void GetOutputInCPU(const std::string& name, void* dst);
   // Fill an ITensor into map itensor_map_.
   void SetITensor(const std::string& name, nvinfer1::ITensor* tensor);
   // Get an ITensor called name.
@@ -115,10 +115,14 @@ class TensorRTEngine : public EngineBase {
 
   nvinfer1::ICudaEngine* engine() { return infer_engine_.get(); }
   nvinfer1::INetworkDefinition* network() { return infer_network_.get(); }
+  void SetRuntimeBatch(size_t batch_size);
+  int GetRuntimeBatch();
 
  private:
   // the max batch size
   int max_batch_;
+  // the runtime batch size
+  static int runtime_batch_;
   // the max memory size the engine uses
   int max_workspace_;
   cudaStream_t* stream_;
diff --git a/paddle/fluid/inference/tensorrt/test_engine.cc b/paddle/fluid/inference/tensorrt/test_engine.cc
index e635f0f87d577a1f1ac74687ee60f762be525418..d3387939b2f10fbc06036a5d33311b6015ff1562 100644
--- a/paddle/fluid/inference/tensorrt/test_engine.cc
+++ b/paddle/fluid/inference/tensorrt/test_engine.cc
@@ -28,7 +28,7 @@ class TensorRTEngineTest : public ::testing::Test {
  protected:
   void SetUp() override {
     ASSERT_EQ(0, cudaStreamCreate(&stream_));
-    engine_ = new TensorRTEngine(1, 1 << 10, &stream_);
+    engine_ = new TensorRTEngine(10, 1 << 10, &stream_);
     engine_->InitNetwork();
   }
 
@@ -71,7 +71,7 @@ TEST_F(TensorRTEngineTest, add_layer) {
 
   LOG(INFO) << "to get output";
   float y_cpu;
-  engine_->GetOutputInCPU("y", &y_cpu, sizeof(float));
+  engine_->GetOutputInCPU("y", &y_cpu);
 
   LOG(INFO) << "to checkout output";
   ASSERT_EQ(y_cpu, x_v * 2 + 3);
@@ -103,11 +103,44 @@ TEST_F(TensorRTEngineTest, add_layer_multi_dim) {
 
   LOG(INFO) << "to get output";
   float y_cpu[2] = {-1., -1.};
-  engine_->GetOutputInCPU("y", &y_cpu[0], sizeof(float) * 2);
+  engine_->GetOutputInCPU("y", &y_cpu[0]);
   ASSERT_EQ(y_cpu[0], 4.5);
   ASSERT_EQ(y_cpu[1], 14.5);
 }
 
+TEST_F(TensorRTEngineTest, test_conv2d_temp) {
+  // Weight in CPU memory.
+  float raw_weight[9] = {1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0};
+  float raw_bias[1] = {0};
+
+  TensorRTEngine::Weight weight(nvinfer1::DataType::kFLOAT, raw_weight, 9);
+  TensorRTEngine::Weight bias(nvinfer1::DataType::kFLOAT, raw_bias, 1);
+  auto* x = engine_->DeclareInput("x", nvinfer1::DataType::kFLOAT,
+                                  nvinfer1::Dims3{1, 3, 3});
+  auto* conv_layer =
+      TRT_ENGINE_ADD_LAYER(engine_, Convolution, *x, 1, nvinfer1::DimsHW{3, 3},
+                           weight.get(), bias.get());
+  PADDLE_ENFORCE(conv_layer != nullptr);
+  conv_layer->setStride(nvinfer1::DimsHW{1, 1});
+  conv_layer->setPadding(nvinfer1::DimsHW{1, 1});
+
+  engine_->DeclareOutput(conv_layer, 0, "y");
+  engine_->FreezeNetwork();
+  ASSERT_EQ(engine_->engine()->getNbBindings(), 2);
+
+  float x_v[18] = {1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
+                   1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0};
+  engine_->SetInputFromCPU("x", reinterpret_cast<void*>(&x_v),
+                           18 * sizeof(float));
+  engine_->Execute(2);
+
+  LOG(INFO) << "to get output";
+  float* y_cpu = new float[18];
+  engine_->GetOutputInCPU("y", &y_cpu[0]);
+  ASSERT_EQ(y_cpu[0], 4.0);
+  ASSERT_EQ(y_cpu[1], 6.0);
+}
+
 }  // namespace tensorrt
 }  // namespace inference
 }  // namespace paddle