From 2d57158e2ba24e53e98f3df5da44d48aa5d82878 Mon Sep 17 00:00:00 2001
From: Yan Chunwei
Date: Wed, 25 Apr 2018 20:29:25 +0800
Subject: [PATCH] fea/init tensorrt engine (#10003)

---
 paddle/fluid/inference/engine.h                |  53 +++++++
 .../fluid/inference/tensorrt/CMakeLists.txt    |   5 +-
 paddle/fluid/inference/tensorrt/engine.cc      | 134 ++++++++++++++++
 paddle/fluid/inference/tensorrt/engine.h       | 144 ++++++++++++++++++
 paddle/fluid/inference/tensorrt/helper.h       |  88 +++++++++++
 .../fluid/inference/tensorrt/test_engine.cc    |  83 ++++++++++
 .../fluid/inference/tensorrt/test_tensorrt.cc  |  18 +--
 7 files changed, 515 insertions(+), 10 deletions(-)
 create mode 100644 paddle/fluid/inference/engine.h
 create mode 100644 paddle/fluid/inference/tensorrt/engine.cc
 create mode 100644 paddle/fluid/inference/tensorrt/engine.h
 create mode 100644 paddle/fluid/inference/tensorrt/helper.h
 create mode 100644 paddle/fluid/inference/tensorrt/test_engine.cc

diff --git a/paddle/fluid/inference/engine.h b/paddle/fluid/inference/engine.h
new file mode 100644
index 00000000000..0633c052e4d
--- /dev/null
+++ b/paddle/fluid/inference/engine.h
@@ -0,0 +1,53 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "paddle/fluid/framework/framework.pb.h"
+
+namespace paddle {
+namespace inference {
+
+/*
+ * EngineBase is the base class of all inference engines. An inference engine
+ * takes a Paddle program as input and outputs the result in fluid Tensor
+ * format. It can be used to optimize the performance of computation
+ * sub-blocks, for example, by breaking the original block down into
+ * sub-blocks and executing each sub-block on a different engine.
+ *
+ * For example:
+ *   At inference time, most of a ResNet50 model can be put into a subgraph
+ *   and run on a TensorRT engine.
+ *
+ * There are several engines such as TensorRT and other frameworks, so
+ * EngineBase is introduced to give a unified interface for all the different
+ * engine implementations.
+ */
+class EngineBase {
+ public:
+  using DescType = ::paddle::framework::proto::BlockDesc;
+
+  // Build the model and do some preparation, for example, in TensorRT, run
+  // createInferBuilder, buildCudaEngine.
+  virtual void Build(const DescType& paddle_model) = 0;
+
+  // Execute the engine; this will run the inference network.
+  virtual void Execute(int batch_size) = 0;
+
+  virtual ~EngineBase() {}
+
+};  // class EngineBase
+
+}  // namespace inference
+}  // namespace paddle
diff --git a/paddle/fluid/inference/tensorrt/CMakeLists.txt b/paddle/fluid/inference/tensorrt/CMakeLists.txt
index e39c0daac76..4b5866ad5dd 100644
--- a/paddle/fluid/inference/tensorrt/CMakeLists.txt
+++ b/paddle/fluid/inference/tensorrt/CMakeLists.txt
@@ -1 +1,4 @@
-nv_test(test_tensorrt SRCS test_tensorrt.cc DEPS dynload_cuda device_context dynamic_loader)
+if(WITH_TESTING)
+  nv_test(test_tensorrt SRCS test_tensorrt.cc DEPS dynload_cuda device_context dynamic_loader)
+  nv_test(test_tensorrt_engine SRCS test_engine.cc engine.cc DEPS dynload_cuda)
+endif()
diff --git a/paddle/fluid/inference/tensorrt/engine.cc b/paddle/fluid/inference/tensorrt/engine.cc
new file mode 100644
index 00000000000..276502e4999
--- /dev/null
+++ b/paddle/fluid/inference/tensorrt/engine.cc
@@ -0,0 +1,134 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/inference/tensorrt/engine.h"
+
+#include <NvInfer.h>
+#include <cuda.h>
+#include <glog/logging.h>
+#include "paddle/fluid/inference/tensorrt/helper.h"
+#include "paddle/fluid/platform/enforce.h"
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+
+void TensorRTEngine::Build(const DescType& paddle_model) {
+  PADDLE_ENFORCE(false, "not implemented");
+}
+
+void TensorRTEngine::Execute(int batch_size) {
+  infer_context_->enqueue(batch_size, buffers_.data(), *stream_, nullptr);
+  cudaStreamSynchronize(*stream_);
+}
+
+TensorRTEngine::~TensorRTEngine() {
+  // release the GPU buffers
+  for (auto& buffer : buffers_) {
+    if (buffer != nullptr) {
+      PADDLE_ENFORCE_EQ(0, cudaFree(buffer));
+      buffer = nullptr;
+    }
+  }
+}
+
+void TensorRTEngine::FreezeNetwork() {
+  PADDLE_ENFORCE(infer_builder_ != nullptr,
+                 "Call InitNetwork first to initialize network.");
+  PADDLE_ENFORCE(infer_network_ != nullptr,
+                 "Call InitNetwork first to initialize network.");
+  // build engine.
+  infer_builder_->setMaxBatchSize(max_batch_);
+  infer_builder_->setMaxWorkspaceSize(max_workspace_);
+
+  infer_engine_.reset(infer_builder_->buildCudaEngine(*infer_network_));
+  PADDLE_ENFORCE(infer_engine_ != nullptr, "build cuda engine failed!");
+
+  infer_context_.reset(infer_engine_->createExecutionContext());
+
+  // allocate GPU buffers.
+  buffers_.resize(buffer_sizes_.size(), nullptr);
+  for (auto& item : buffer_sizes_) {
+    if (item.second == 0) {
+      auto slot_offset = infer_engine_->getBindingIndex(item.first.c_str());
+      item.second = kDataTypeSize[static_cast<int>(
+                        infer_engine_->getBindingDataType(slot_offset))] *
+                    AccumDims(infer_engine_->getBindingDimensions(slot_offset));
+    }
+    PADDLE_ENFORCE_EQ(0, cudaMalloc(&buffer(item.first), item.second));
+  }
+}
+
+nvinfer1::ITensor* TensorRTEngine::DeclareInput(const std::string& name,
+                                                nvinfer1::DataType dtype,
+                                                const nvinfer1::Dims& dim) {
+  PADDLE_ENFORCE_EQ(0, buffer_sizes_.count(name), "duplicate input name %s",
+                    name);
+
+  PADDLE_ENFORCE(infer_network_ != nullptr, "should call InitNetwork first");
+  auto* input = infer_network_->addInput(name.c_str(), dtype, dim);
+  PADDLE_ENFORCE(input, "infer network add input %s failed", name);
+
+  buffer_sizes_[name] = kDataTypeSize[static_cast<int>(dtype)] * AccumDims(dim);
+  return input;
+}
+
+void TensorRTEngine::DeclareOutput(const nvinfer1::ILayer* layer, int offset,
+                                   const std::string& name) {
+  PADDLE_ENFORCE_EQ(0, buffer_sizes_.count(name), "duplicate output name %s",
+                    name);
+
+  auto* output = layer->getOutput(offset);
+  PADDLE_ENFORCE(output != nullptr);
+  output->setName(name.c_str());
+  infer_network_->markOutput(*output);
+  // An output buffer's size can only be decided later, so set it to zero here
+  // to mark it; it will be reset later.
+  buffer_sizes_[name] = 0;
+}
+
+void* TensorRTEngine::GetOutputInGPU(const std::string& name) {
+  return buffer(name);
+}
+
+void TensorRTEngine::GetOutputInCPU(const std::string& name, void* dst,
+                                    size_t max_size) {
+  // determine data size
+  auto it = buffer_sizes_.find(name);
+  PADDLE_ENFORCE(it != buffer_sizes_.end());
+  PADDLE_ENFORCE_GT(it->second, 0);
+  PADDLE_ENFORCE_GE(max_size, it->second);
+
+  PADDLE_ENFORCE_EQ(0, cudaMemcpyAsync(dst, buffer(name), it->second,
+                                       cudaMemcpyDeviceToHost, *stream_));
+}
+
+void*& TensorRTEngine::buffer(const std::string& name) {
+  PADDLE_ENFORCE(infer_engine_ != nullptr, "call FreezeNetwork first.");
+  auto it = buffer_sizes_.find(name);
+  PADDLE_ENFORCE(it != buffer_sizes_.end());
+  auto slot_offset = infer_engine_->getBindingIndex(name.c_str());
+  return buffers_[slot_offset];
+}
+
+void TensorRTEngine::SetInputFromCPU(const std::string& name, void* data,
+                                     size_t size) {
+  void* buf = buffer(name);
+  PADDLE_ENFORCE_EQ(
+      0, cudaMemcpyAsync(buf, data, size, cudaMemcpyHostToDevice, *stream_));
+}
+
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
diff --git a/paddle/fluid/inference/tensorrt/engine.h b/paddle/fluid/inference/tensorrt/engine.h
new file mode 100644
index 00000000000..ff853455b8b
--- /dev/null
+++ b/paddle/fluid/inference/tensorrt/engine.h
@@ -0,0 +1,144 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <NvInfer.h>
+#include <memory>
+#include <unordered_map>
+#include "paddle/fluid/inference/engine.h"
+#include "paddle/fluid/inference/tensorrt/helper.h"
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+
+/*
+ * TensorRT Engine.
+ *
+ * There are two alternative ways to use it: one is to build the engine from a
+ * Paddle protobuf model, the other is to manually construct the network.
+ */
+class TensorRTEngine : public EngineBase {
+ public:
+  // Weight is model parameter.
+  class Weight {
+   public:
+    Weight(nvinfer1::DataType dtype, void* value, int num_elem) {
+      w_.type = dtype;
+      w_.values = value;
+      w_.count = num_elem;
+    }
+    const nvinfer1::Weights& get() { return w_; }
+
+   private:
+    nvinfer1::Weights w_;
+  };
+
+  TensorRTEngine(int max_batch, int max_workspace, cudaStream_t* stream,
+                 nvinfer1::ILogger& logger = NaiveLogger::Global())
+      : max_batch_(max_batch),
+        max_workspace_(max_workspace),
+        stream_(stream),
+        logger_(logger) {}
+
+  virtual ~TensorRTEngine();
+
+  // TODO(Superjomn) implement it later when graph segmentation is supported.
+  virtual void Build(const DescType& paddle_model) override;
+
+  virtual void Execute(int batch_size) override;
+
+  // Initialize the inference network, so that TensorRT layers can be added to
+  // this network.
+  void InitNetwork() {
+    infer_builder_.reset(createInferBuilder(logger_));
+    infer_network_.reset(infer_builder_->createNetwork());
+  }
+  // After finishing adding ops, freeze this network and create the execution
+  // environment.
+  void FreezeNetwork();
+
+  // Add an input and set its name, data type and dimension.
+  nvinfer1::ITensor* DeclareInput(const std::string& name,
+                                  nvinfer1::DataType dtype,
+                                  const nvinfer1::Dims& dim);
+  // Set the offset-th output from a layer as the network's output, and set its
+  // name.
+  void DeclareOutput(const nvinfer1::ILayer* layer, int offset,
+                     const std::string& name);
+
+  // GPU memory address for an ITensor with specific name. One can operate on
+  // this memory directly for acceleration, for example, output the converted
+  // data directly to the buffer to save the data copy overhead.
+  // NOTE this should be used after calling `FreezeNetwork`.
+  void*& buffer(const std::string& name);
+
+  // Fill an input from CPU memory with name and size.
+  void SetInputFromCPU(const std::string& name, void* data, size_t size);
+  // TODO(Superjomn) is this method necessary given that buffer(xxx) can be
+  // accessed directly? Fill an input from GPU memory with name and size.
+  void SetInputFromGPU(const std::string& name, void* data, size_t size);
+  // Get an output called name; the output of TensorRT is in GPU memory, so
+  // this method just returns the output's GPU memory address.
+  void* GetOutputInGPU(const std::string& name);
+  // LOW EFFICIENCY! Get output to CPU; this will trigger a memory copy from
+  // GPU to CPU.
+  void GetOutputInCPU(const std::string& name, void* dst, size_t max_size);
+
+  nvinfer1::ICudaEngine* engine() { return infer_engine_.get(); }
+  nvinfer1::INetworkDefinition* network() { return infer_network_.get(); }
+
+ private:
+  // the max batch size
+  int max_batch_;
+  // the max memory size the engine uses
+  int max_workspace_;
+  cudaStream_t* stream_;
+  nvinfer1::ILogger& logger_;
+
+  std::vector<void*> buffers_;
+  // max data size for the buffers.
+  std::unordered_map<std::string, size_t> buffer_sizes_;
+
+  // TensorRT related internal members
+  template <typename T>
+  struct Destroyer {
+    void operator()(T* x) { x->destroy(); }
+  };
+  template <typename T>
+  using infer_ptr = std::unique_ptr<T, Destroyer<T>>;
+  infer_ptr<nvinfer1::IBuilder> infer_builder_;
+  infer_ptr<nvinfer1::INetworkDefinition> infer_network_;
+  infer_ptr<nvinfer1::ICudaEngine> infer_engine_;
+  infer_ptr<nvinfer1::IExecutionContext> infer_context_;
+};  // class TensorRTEngine
+
+// Add a layer__ into engine__ with args ARGS.
+// For example:
+//
+//   TRT_ENGINE_ADD_LAYER(xxx, FullyConnected, input, dim, weights, bias)
+//
+// will add a fully connected layer into the engine.
+//
+// Reference:
+// https://docs.nvidia.com/deeplearning/sdk/tensorrt-developer-guide/index.html#charRNN_define_network
+//
+// TensorRT has too many layers, so it is not wise to add a member function for
+// each of them; a macro like this is more extensible when the underlying
+// TensorRT library adds support for new layers.
+#define TRT_ENGINE_ADD_LAYER(engine__, layer__, ARGS...) \
+  engine__->network()->add##layer__(ARGS);
+
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
diff --git a/paddle/fluid/inference/tensorrt/helper.h b/paddle/fluid/inference/tensorrt/helper.h
new file mode 100644
index 00000000000..796283d325c
--- /dev/null
+++ b/paddle/fluid/inference/tensorrt/helper.h
@@ -0,0 +1,88 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <NvInfer.h>
+#include <cuda.h>
+#include <glog/logging.h>
+#include "paddle/fluid/platform/dynload/tensorrt.h"
+#include "paddle/fluid/platform/enforce.h"
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+
+namespace dy = paddle::platform::dynload;
+
+static size_t AccumDims(nvinfer1::Dims dims) {
+  size_t num = dims.nbDims == 0 ? 0 : 1;
+  for (int i = 0; i < dims.nbDims; i++) {
+    PADDLE_ENFORCE_GT(dims.d[i], 0);
+    num *= dims.d[i];
+  }
+  return num;
+}
+
+// TensorRT data type to size
+const int kDataTypeSize[] = {
+    4,  // kFLOAT
+    2,  // kHALF
+    1,  // kINT8
+    4   // kINT32
+};
+
+// The following two APIs are implemented in TensorRT's header file and cannot
+// be loaded from the dynamic library, so we provide our own implementations
+// that call the corresponding symbols in the dynamic library directly.
+static nvinfer1::IBuilder* createInferBuilder(nvinfer1::ILogger& logger) {
+  return static_cast<nvinfer1::IBuilder*>(
+      dy::createInferBuilder_INTERNAL(&logger, NV_TENSORRT_VERSION));
+}
+static nvinfer1::IRuntime* createInferRuntime(nvinfer1::ILogger& logger) {
+  return static_cast<nvinfer1::IRuntime*>(
+      dy::createInferRuntime_INTERNAL(&logger, NV_TENSORRT_VERSION));
+}
+
+// A logger for creating the TensorRT infer builder.
+class NaiveLogger : public nvinfer1::ILogger {
+ public:
+  void log(nvinfer1::ILogger::Severity severity, const char* msg) override {
+    switch (severity) {
+      case Severity::kINFO:
+        LOG(INFO) << msg;
+        break;
+      case Severity::kWARNING:
+        LOG(WARNING) << msg;
+        break;
+      case Severity::kINTERNAL_ERROR:
+      case Severity::kERROR:
+        LOG(ERROR) << msg;
+        break;
+      default:
+        break;
+    }
+  }
+
+  static nvinfer1::ILogger& Global() {
+    static nvinfer1::ILogger* x = new NaiveLogger;
+    return *x;
+  }
+
+  virtual ~NaiveLogger() override {}
+};
+
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
diff --git a/paddle/fluid/inference/tensorrt/test_engine.cc b/paddle/fluid/inference/tensorrt/test_engine.cc
new file mode 100644
index 00000000000..f3dbdf11f21
--- /dev/null
+++ b/paddle/fluid/inference/tensorrt/test_engine.cc
@@ -0,0 +1,83 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/inference/tensorrt/engine.h"
+
+#include <cuda.h>
+#include <cuda_runtime_api.h>
+#include <glog/logging.h>
+#include <gtest/gtest.h>
+
+#include "paddle/fluid/platform/enforce.h"
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+
+class TensorRTEngineTest : public ::testing::Test {
+ protected:
+  void SetUp() override {
+    ASSERT_EQ(0, cudaStreamCreate(&stream_));
+    engine_ = new TensorRTEngine(1, 1 << 10, &stream_);
+    engine_->InitNetwork();
+  }
+
+  void TearDown() override {
+    delete engine_;
+    cudaStreamDestroy(stream_);
+  }
+
+ protected:
+  TensorRTEngine* engine_;
+  cudaStream_t stream_;
+};
+
+TEST_F(TensorRTEngineTest, add_layer) {
+  const int size = 1;
+
+  float raw_weight[size] = {2.};  // Weight in CPU memory.
+  float raw_bias[size] = {3.};
+
+  LOG(INFO) << "create weights";
+  TensorRTEngine::Weight weight(nvinfer1::DataType::kFLOAT, raw_weight, size);
+  TensorRTEngine::Weight bias(nvinfer1::DataType::kFLOAT, raw_bias, size);
+  auto* x = engine_->DeclareInput("x", nvinfer1::DataType::kFLOAT,
+                                  nvinfer1::DimsCHW{1, 1, 1});
+  auto* fc_layer = TRT_ENGINE_ADD_LAYER(engine_, FullyConnected, *x, size,
+                                        weight.get(), bias.get());
+  PADDLE_ENFORCE(fc_layer != nullptr);
+
+  engine_->DeclareOutput(fc_layer, 0, "y");
+  LOG(INFO) << "freeze network";
+  engine_->FreezeNetwork();
+  ASSERT_EQ(engine_->engine()->getNbBindings(), 2);
+
+  // fill in real data
+  float x_v = 1234;
+  engine_->SetInputFromCPU("x", (void*)&x_v, 1 * sizeof(float));
+  LOG(INFO) << "to execute";
+  engine_->Execute(1);
+
+  LOG(INFO) << "to get output";
+  float y_cpu;
+  engine_->GetOutputInCPU("y", &y_cpu, sizeof(float));
+
+  LOG(INFO) << "to check output";
+  ASSERT_EQ(y_cpu, x_v * 2 + 3);
+}
+
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
diff --git a/paddle/fluid/inference/tensorrt/test_tensorrt.cc b/paddle/fluid/inference/tensorrt/test_tensorrt.cc
index a81a708e7a7..aed5b5e1a22 100644
--- a/paddle/fluid/inference/tensorrt/test_tensorrt.cc
+++ b/paddle/fluid/inference/tensorrt/test_tensorrt.cc
@@ -1,16 +1,16 @@
 /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 
-   Licensed under the Apache License, Version 2.0 (the "License");
-   you may not use this file except in compliance with the License.
-   You may obtain a copy of the License at
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
 
-   http://www.apache.org/licenses/LICENSE-2.0
+http://www.apache.org/licenses/LICENSE-2.0
 
-   Unless required by applicable law or agreed to in writing, software
-   distributed under the License is distributed on an "AS IS" BASIS,
-   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-   See the License for the specific language governing permissions and
-   limitations under the License. */
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
 
 #include
 #include
--
GitLab
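
For readers of this patch, the intended call sequence of the new TensorRTEngine API is easiest to see outside the diff. The sketch below is not part of the patch; it simply mirrors the add_layer unit test above. The helper name RunFcExample is illustrative only, and the sketch assumes the TensorRT and CUDA headers from this change are available.

// Minimal usage sketch of the API added in this patch (mirrors test_engine.cc).
#include <cuda_runtime_api.h>

#include "paddle/fluid/inference/tensorrt/engine.h"

namespace trt = paddle::inference::tensorrt;

void RunFcExample() {  // hypothetical helper, not part of the patch
  cudaStream_t stream;
  cudaStreamCreate(&stream);

  // Max batch size 1 and a 1 KB workspace, the same limits the test uses.
  auto* engine = new trt::TensorRTEngine(1, 1 << 10, &stream);
  engine->InitNetwork();

  // Manually build y = 2 * x + 3 from a single fully connected layer.
  float w = 2.f, b = 3.f;
  trt::TensorRTEngine::Weight weight(nvinfer1::DataType::kFLOAT, &w, 1);
  trt::TensorRTEngine::Weight bias(nvinfer1::DataType::kFLOAT, &b, 1);
  auto* x = engine->DeclareInput("x", nvinfer1::DataType::kFLOAT,
                                 nvinfer1::DimsCHW{1, 1, 1});
  auto* fc = TRT_ENGINE_ADD_LAYER(engine, FullyConnected, *x, 1, weight.get(),
                                  bias.get());
  engine->DeclareOutput(fc, 0, "y");
  engine->FreezeNetwork();  // builds the CUDA engine and allocates GPU buffers

  // Feed input, run, and copy the result back to the host.
  float x_v = 10.f, y_v = 0.f;
  engine->SetInputFromCPU("x", &x_v, sizeof(float));
  engine->Execute(1);
  engine->GetOutputInCPU("y", &y_v, sizeof(float));
  cudaStreamSynchronize(stream);  // GetOutputInCPU copies asynchronously

  // y_v is now 2 * 10 + 3 = 23.
  delete engine;
  cudaStreamDestroy(stream);
}

Note that the engine only stores a pointer to the caller-owned CUDA stream, so the stream must outlive the engine, as it does in the test fixture.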