From 27536a322ba9d5127374e70e36da4d79166be5da Mon Sep 17 00:00:00 2001
From: Wilber
Date: Mon, 28 Feb 2022 10:44:46 +0800
Subject: [PATCH] infrt add trt engine (#39885)

---
 paddle/fluid/platform/dynload/tensorrt.h      |  12 +-
 paddle/infrt/CMakeLists.txt                   |   1 +
 paddle/infrt/backends/CMakeLists.txt          |   3 +
 paddle/infrt/backends/tensorrt/CMakeLists.txt |   3 +
 .../backends/tensorrt/test_trt_engine.cc      | 254 ++++++++++++
 paddle/infrt/backends/tensorrt/trt_engine.cc  | 365 ++++++++++++++++++
 paddle/infrt/backends/tensorrt/trt_engine.h   | 114 ++++++
 paddle/infrt/backends/tensorrt/trt_options.h  |  94 +++++
 paddle/infrt/backends/tensorrt/trt_utils.h    | 147 +++++++
 paddle/infrt/kernel/phi/CMakeLists.txt        |   4 +
 .../infershaped/infershape_launchers_test.cc  |   2 +-
 tools/infrt/get_phi_kernel_info.py            |   2 +-
 12 files changed, 993 insertions(+), 8 deletions(-)
 create mode 100644 paddle/infrt/backends/CMakeLists.txt
 create mode 100644 paddle/infrt/backends/tensorrt/CMakeLists.txt
 create mode 100644 paddle/infrt/backends/tensorrt/test_trt_engine.cc
 create mode 100644 paddle/infrt/backends/tensorrt/trt_engine.cc
 create mode 100644 paddle/infrt/backends/tensorrt/trt_engine.h
 create mode 100644 paddle/infrt/backends/tensorrt/trt_options.h
 create mode 100644 paddle/infrt/backends/tensorrt/trt_utils.h

diff --git a/paddle/fluid/platform/dynload/tensorrt.h b/paddle/fluid/platform/dynload/tensorrt.h
index bc29a04720..c2d7eef582 100644
--- a/paddle/fluid/platform/dynload/tensorrt.h
+++ b/paddle/fluid/platform/dynload/tensorrt.h
@@ -37,7 +37,7 @@ void* GetTensorRtPluginHandle();
 extern std::once_flag tensorrt_plugin_dso_flag;
 extern void* tensorrt_plugin_dso_handle;
 
-#define DECLARE_DYNAMIC_LOAD_TENSORRT_POINTER_WRAP(__name)  \
-  struct DynLoad__##__name {                                 \
-    template <typename... Args>                              \
-    void* operator()(Args... args) {                         \
+#define DECLARE_DYNAMIC_LOAD_TENSORRT_POINTER_WRAP_(__name) \
+  struct DynLoad__##__name {                                 \
+    template <typename... Args>                              \
+    void* operator()(Args... args) {                         \
@@ -55,7 +55,7 @@ extern void* tensorrt_plugin_dso_handle;
   };                                                         \
   extern DynLoad__##__name __name
 
-#define DECLARE_DYNAMIC_LOAD_TENSORRT_NON_POINTER_WRAP(__name)       \
-  struct DynLoad__##__name {                                         \
-    template <typename... Args>                                      \
-    auto operator()(Args... args) -> DECLARE_TYPE(__name, args...) { \
+#define DECLARE_DYNAMIC_LOAD_TENSORRT_NON_POINTER_WRAP_(__name)      \
+  struct DynLoad__##__name {                                         \
+    template <typename... Args>                                      \
+    auto operator()(Args... args) -> DECLARE_TYPE(__name, args...) { \
@@ -72,7 +72,7 @@ extern void* tensorrt_plugin_dso_handle;
   };                                                                  \
   extern DynLoad__##__name __name
 
-#define DECLARE_DYNAMIC_LOAD_TENSORRT_PLUGIN_WRAP(__name)            \
-  struct DynLoad__##__name {                                         \
-    template <typename... Args>                                      \
-    auto operator()(Args... args) -> DECLARE_TYPE(__name, args...)   \
+#define DECLARE_DYNAMIC_LOAD_TENSORRT_PLUGIN_WRAP_(__name)           \
+  struct DynLoad__##__name {                                         \
+    template <typename... Args>                                      \
+    auto operator()(Args... args) -> DECLARE_TYPE(__name, args...)   \
+        {                                                             \
@@ -109,10 +109,10 @@ extern void* tensorrt_plugin_dso_handle;
 #define TENSORRT_PLUGIN_RAND_ROUTINE_EACH(__macro) \
   __macro(initLibNvInferPlugins);
 
-TENSORRT_RAND_ROUTINE_EACH_POINTER(DECLARE_DYNAMIC_LOAD_TENSORRT_POINTER_WRAP)
+TENSORRT_RAND_ROUTINE_EACH_POINTER(DECLARE_DYNAMIC_LOAD_TENSORRT_POINTER_WRAP_)
 TENSORRT_RAND_ROUTINE_EACH_NON_POINTER(
-    DECLARE_DYNAMIC_LOAD_TENSORRT_NON_POINTER_WRAP)
-TENSORRT_PLUGIN_RAND_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_TENSORRT_PLUGIN_WRAP)
+    DECLARE_DYNAMIC_LOAD_TENSORRT_NON_POINTER_WRAP_)
+TENSORRT_PLUGIN_RAND_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_TENSORRT_PLUGIN_WRAP_)
 
 #endif  // end of NV_TENSORRT_MAJOR
diff --git a/paddle/infrt/CMakeLists.txt b/paddle/infrt/CMakeLists.txt
index f2a78db558..dc22eecc99 100644
--- a/paddle/infrt/CMakeLists.txt
+++ b/paddle/infrt/CMakeLists.txt
@@ -74,6 +74,7 @@ endif()
 
 add_subdirectory(api)
+add_subdirectory(backends)
 add_subdirectory(common)
 add_subdirectory(dialect)
 add_subdirectory(host_context)
diff --git a/paddle/infrt/backends/CMakeLists.txt b/paddle/infrt/backends/CMakeLists.txt
new file mode 100644
index 0000000000..b639f89292
--- /dev/null
+++ b/paddle/infrt/backends/CMakeLists.txt
@@ -0,0 +1,3 @@
+if (INFRT_WITH_PHI AND WITH_GPU AND WITH_TENSORRT)
+  add_subdirectory(tensorrt)
+endif()
diff --git a/paddle/infrt/backends/tensorrt/CMakeLists.txt b/paddle/infrt/backends/tensorrt/CMakeLists.txt
new file mode 100644
index 0000000000..cc20c9a2e1
--- /dev/null
+++ b/paddle/infrt/backends/tensorrt/CMakeLists.txt
@@ -0,0 +1,3 @@
+cc_library(infrt_trt SRCS trt_engine.cc DEPS glog phi_dynload_cuda phi)
+
+cc_test_tiny(test_infrt_trt SRCS test_trt_engine.cc DEPS infrt_trt phi_dynload_cuda tensorrt_converter)
diff --git a/paddle/infrt/backends/tensorrt/test_trt_engine.cc b/paddle/infrt/backends/tensorrt/test_trt_engine.cc
new file mode 100644
index 0000000000..54b7bc3e8a
--- /dev/null
+++ b/paddle/infrt/backends/tensorrt/test_trt_engine.cc
@@ -0,0 +1,254 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
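+
+// The two tests below build the same tiny network,
+//
+//   input [N, 3, H, W] -> sigmoid -> split (on the channel axis)
+//                                      |-> output1 [N, 1, H, W]
+//                                      |-> output2 [N, 2, H, W]
+//
+// once with static shapes and once with a dynamic-shape profile, then compare
+// every output element against the CPU-side sigmoid() reference defined in
+// this file.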
+
+#include <NvInfer.h>
+
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+#include "glog/logging.h"
+#include "gtest/gtest.h"
+#include "paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h"
+#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
+#include "paddle/fluid/memory/allocation/allocator_facade.h"
+#include "paddle/fluid/memory/memcpy.h"
+#include "paddle/infrt/backends/tensorrt/trt_engine.h"
+#include "paddle/infrt/backends/tensorrt/trt_options.h"
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/common/data_type.h"
+#include "paddle/phi/common/place.h"
+#include "paddle/phi/core/allocator.h"
+#include "paddle/phi/core/ddim.h"
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/meta_tensor.h"
+
+namespace infrt {
+namespace backends {
+namespace tensorrt {
+
+const char* model_input = "model_input";
+const char* model_output = "model_output1";
+const char* model_output2 = "model_output2";
+
+TrtUniquePtr<nvinfer1::INetworkDefinition> ConstructNetwork(
+    nvinfer1::IBuilder* builder, nvinfer1::Dims dims, bool is_static_shape) {
+  TrtUniquePtr<nvinfer1::INetworkDefinition> network;
+  if (is_static_shape) {
+    network.reset(builder->createNetworkV2(0U));
+  } else {
+    auto networkFlags =
+        1U << static_cast<uint32_t>(
+            nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
+    network.reset(builder->createNetworkV2(networkFlags));
+  }
+
+  ITensor* data =
+      network->addInput(model_input, nvinfer1::DataType::kFLOAT, dims);
+  CHECK_NOTNULL(data);
+  IActivationLayer* act =
+      network->addActivation(*data, ActivationType::kSIGMOID);
+  CHECK_NOTNULL(act);
+  auto* act_out = act->getOutput(0);
+  std::vector<int> output_length{1, 2};
+  int axis;
+  nvinfer1::IPluginV2Layer* split_layer;
+  if (is_static_shape) {
+    axis = 0;
+    paddle::inference::tensorrt::plugin::SplitPlugin plugin(
+        axis, output_length, false);
+    split_layer = network->addPluginV2(&act_out, 1, plugin);
+  } else {
+    axis = 1;
+    paddle::inference::tensorrt::plugin::SplitPluginDynamic plugin(
+        axis, output_length, false);
+    split_layer = network->addPluginV2(&act_out, 1, plugin);
+  }
+
+  split_layer->getOutput(0)->setName(model_output);
+  split_layer->getOutput(1)->setName(model_output2);
+  network->markOutput(*split_layer->getOutput(0));
+  network->markOutput(*split_layer->getOutput(1));
+  return network;
+}
+
+// sigmoid(x) = 1 / (1 + exp(-x))
+inline float sigmoid(float x) { return 1.f / (1.f + exp(-1 * x)); }
+
+TEST(trt, run_static) {
+  TRTEngine static_trt_engine(0);
+  auto net = ConstructNetwork(
+      static_trt_engine.GetTrtBuilder(), nvinfer1::Dims3{3, 28, 28}, true);
+  BuildOptions static_build_options;
+  static_build_options.max_batch = 4;
+  static_trt_engine.Build(std::move(net), static_build_options);
+  InferenceOptions inference_options;
+  inference_options.batch = 2;
+
+  phi::GPUPlace place;
+  phi::GPUContext context;
+  context.PartialInitWithoutAllocator();
+  context.SetAllocator(paddle::memory::allocation::AllocatorFacade::Instance()
+                           .GetAllocator(place, context.stream())
+                           .get());
+  context.PartialInitWithAllocator();
+
+  phi::DenseTensorMeta meta(
+      phi::DataType::FLOAT32,
+      phi::make_ddim({inference_options.batch, 3, 28, 28}));
+  phi::DenseTensor input;
+  input.set_meta(meta);
+  context.Alloc<float>(&input, input.numel() * sizeof(float));
+  std::vector<float> host_data(inference_options.batch * 3 * 28 * 28, 0);
+  for (size_t i = 0; i < host_data.size(); ++i) {
+    host_data[i] = i % 100 * 0.016f;
+  }
+  paddle::memory::Copy(place,
+                       input.data<float>(),
+                       phi::CPUPlace(),
+                       host_data.data(),
+                       sizeof(float) * host_data.size(),
+                       context.stream());
+
+  std::unordered_map<std::string, phi::DenseTensor*> inputs;
+  inputs.emplace(std::make_pair(model_input, &input));
+  phi::DenseTensor output, output2;
+  std::unordered_map<std::string, phi::DenseTensor*> outputs;
+  outputs.emplace(std::make_pair(model_output, &output));
+  outputs.emplace(std::make_pair(model_output2, &output2));
+
+  static_trt_engine.SetUpInference(inference_options, inputs, &outputs);
+  static_trt_engine.GetEngineInfo();
+  static_trt_engine.Run(context);
+
+  std::vector<float> output_data1(inference_options.batch * 1 * 28 * 28, 0);
+  std::vector<float> output_data2(inference_options.batch * 2 * 28 * 28, 0);
+  paddle::memory::Copy(phi::CPUPlace(),
+                       output_data1.data(),
+                       place,
+                       output.data<float>(),
+                       sizeof(float) * output_data1.size(),
+                       context.stream());
+  paddle::memory::Copy(phi::CPUPlace(),
+                       output_data2.data(),
+                       place,
+                       output2.data<float>(),
+                       sizeof(float) * output_data2.size(),
+                       context.stream());
+  cudaStreamSynchronize(context.stream());
+
+  for (size_t i = 0; i < host_data.size(); ++i) {
+    int w = i % 28;
+    int h = (i / 28) % 28;
+    int c = i / (28 * 28) % 3;
+    int n = i / (28 * 28 * 3);
+    if (c == 0) {
+      CHECK_NEAR(
+          sigmoid(host_data[i]), output_data1[n * 28 * 28 + h * 28 + w], 1e-5);
+    } else {
+      CHECK_NEAR(sigmoid(host_data[i]),
+                 output_data2[n * 28 * 28 * 2 + (c - 1) * 28 * 28 + h * 28 + w],
+                 1e-5);
+    }
+  }
+}
+
+TEST(trt, run_dynamic) {
+  TRTEngine engine(0);
+  auto net = ConstructNetwork(
+      engine.GetTrtBuilder(), nvinfer1::Dims4{-1, 3, -1, -1}, false);
+  BuildOptions build_options;
+  build_options.max_batch = 4;
+  build_options.workspace = 32;
+  // build_options.fp16 = true;
+  std::vector<int32_t> min_shape{1, 3, 16, 16};
+  std::vector<int32_t> opt_shape{2, 3, 28, 28};
+  std::vector<int32_t> max_shape{4, 3, 28, 28};
+  build_options.shapes[model_input][0] = min_shape;
+  build_options.shapes[model_input][1] = opt_shape;
+  build_options.shapes[model_input][2] = max_shape;
+  engine.Build(std::move(net), build_options);
+
+  InferenceOptions inference_options;
+  inference_options.batch = 2;
+
+  phi::GPUPlace place;
+  phi::GPUContext context;
+  context.PartialInitWithoutAllocator();
+  context.SetAllocator(paddle::memory::allocation::AllocatorFacade::Instance()
+                           .GetAllocator(place, context.stream())
+                           .get());
+  context.PartialInitWithAllocator();
+
+  phi::DenseTensorMeta meta(
+      phi::DataType::FLOAT32,
+      phi::make_ddim({inference_options.batch, 3, 16, 16}));
+  phi::DenseTensor input, output, output2;
+  input.set_meta(meta);
+  context.Alloc<float>(&input, input.numel() * sizeof(float));
+  std::vector<float> host_data(inference_options.batch * 3 * 16 * 16, 0);
+  for (size_t i = 0; i < host_data.size(); ++i) {
+    host_data[i] = i % 100 * 0.016f;
+  }
+  paddle::memory::Copy(place,
+                       input.data<float>(),
+                       phi::CPUPlace(),
+                       host_data.data(),
+                       sizeof(float) * host_data.size(),
+                       context.stream());
+
+  std::unordered_map<std::string, phi::DenseTensor*> inputs;
+  std::unordered_map<std::string, phi::DenseTensor*> outputs;
+  inputs.emplace(std::make_pair(model_input, &input));
+  outputs.emplace(std::make_pair(model_output, &output));
+  outputs.emplace(std::make_pair(model_output2, &output2));
+
+  engine.SetUpInference(inference_options, inputs, &outputs);
+  engine.GetEngineInfo();
+  engine.Run(context);
+
+  std::vector<float> output_data1(inference_options.batch * 1 * 16 * 16, 0);
+  std::vector<float> output_data2(inference_options.batch * 2 * 16 * 16, 0);
+  paddle::memory::Copy(phi::CPUPlace(),
+                       output_data1.data(),
+                       place,
+                       output.data<float>(),
+                       sizeof(float) * output_data1.size(),
+                       context.stream());
+  paddle::memory::Copy(phi::CPUPlace(),
+                       output_data2.data(),
+                       place,
+                       output2.data<float>(),
+                       sizeof(float) * output_data2.size(),
+                       context.stream());
+  cudaStreamSynchronize(context.stream());
+
+  for (size_t i = 0; i < host_data.size(); ++i) {
+    int w = i % 16;
+    int h = (i / 16) % 16;
+    int c = i / (16 * 16) % 3;
+    int n = i / (16 * 16 * 3);
+    if (c == 0) {
+      CHECK_NEAR(
+          sigmoid(host_data[i]), output_data1[n * 16 * 16 + h * 16 + w], 1e-5);
+    } else {
+      CHECK_NEAR(sigmoid(host_data[i]),
+                 output_data2[n * 16 * 16 * 2 + (c - 1) * 16 * 16 + h * 16 + w],
+                 1e-5);
+    }
+  }
+}
+
+}  // namespace tensorrt
+}  // namespace backends
+}  // namespace infrt
diff --git a/paddle/infrt/backends/tensorrt/trt_engine.cc b/paddle/infrt/backends/tensorrt/trt_engine.cc
new file mode 100644
index 0000000000..a204fe42b4
--- /dev/null
+++ b/paddle/infrt/backends/tensorrt/trt_engine.cc
@@ -0,0 +1,365 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/infrt/backends/tensorrt/trt_engine.h"
+
+#include <NvInferRuntime.h>
+#include <NvInferRuntimeCommon.h>
+#include "glog/logging.h"
+#include "paddle/phi/backends/dynload/tensorrt.h"
+#include "paddle/phi/backends/gpu/gpu_info.h"
+#include "paddle/phi/core/ddim.h"
+
+namespace infrt {
+namespace backends {
+namespace tensorrt {
+
+// The following two APIs are implemented in TensorRT's header file, so they
+// cannot be loaded from the dynamic library. We therefore provide our own
+// implementations that call the factory symbols exported by the dynamic
+// library directly.
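+// For example, rather than the inline nvinfer1::createInferBuilder(logger),
+// we call the factory symbol the library actually exports:
+//   phi::dynload::createInferBuilder_INTERNAL(&logger, NV_TENSORRT_VERSION)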
+static nvinfer1::IBuilder* createInferBuilder(
+    nvinfer1::ILogger& logger) {  // NOLINT
+  return static_cast<nvinfer1::IBuilder*>(
+      phi::dynload::createInferBuilder_INTERNAL(&logger, NV_TENSORRT_VERSION));
+}
+static nvinfer1::IRuntime* createInferRuntime(
+    nvinfer1::ILogger& logger) {  // NOLINT
+  return static_cast<nvinfer1::IRuntime*>(
+      phi::dynload::createInferRuntime_INTERNAL(&logger, NV_TENSORRT_VERSION));
+}
+
+TRTEngine::TRTEngine(int device_id) : device_id_(device_id) {
+  FreshDeviceId();
+  logger_.reset(new TrtLogger());
+  builder_.reset(createInferBuilder(logger_->GetTrtLogger()));
+  phi::dynload::initLibNvInferPlugins(&logger_->GetTrtLogger(), "");
+}
+
+nvinfer1::IBuilder* TRTEngine::GetTrtBuilder() {
+  CHECK_NOTNULL(builder_);
+  return builder_.get();
+}
+
+void TRTEngine::Build(TrtUniquePtr<nvinfer1::INetworkDefinition> network,
+                      const BuildOptions& build_options) {
+  FreshDeviceId();
+  ModelToBuildEnv(std::move(network), build_options);
+  CHECK_NOTNULL(engine_);
+}
+
+bool TRTEngine::ModelToBuildEnv(
+    TrtUniquePtr<nvinfer1::INetworkDefinition> network,
+    const BuildOptions& build) {
+  CHECK_NOTNULL(builder_);
+  std::swap(network, network_);
+  CHECK_NOTNULL(network_);
+  // ModelToNetwork(network_, logger);
+  NetworkToEngine(build);
+  return true;
+}
+
+bool TRTEngine::NetworkToEngine(const BuildOptions& build) {
+  TrtUniquePtr<IBuilderConfig> config{builder_->createBuilderConfig()};
+  CHECK_NOTNULL(config);
+  CHECK(SetupNetworkAndConfig(build, *network_, *config));
+
+#if IS_TRT_VERSION_LT(8000)
+  engine_.reset(builder_->buildEngineWithConfig(*network_, *config));
+#else
+  serialized_engine_.reset(
+      builder_->buildSerializedNetwork(*network_, *config));
+  CHECK_NOTNULL(serialized_engine_);
+
+  TrtUniquePtr<IRuntime> runtime{createInferRuntime(logger_->GetTrtLogger())};
+  CHECK_NOTNULL(runtime);
+  engine_.reset(runtime->deserializeCudaEngine(serialized_engine_->data(),
+                                               serialized_engine_->size()));
+  CHECK_NOTNULL(engine_);
+#endif
+  return true;
+}
+
+bool TRTEngine::SetupNetworkAndConfig(const BuildOptions& build,
+                                      INetworkDefinition& network,
+                                      IBuilderConfig& config) {
+  builder_->setMaxBatchSize(build.max_batch);
+  // TODO(wilber): handle the one engine - multiple execution contexts case.
+  IOptimizationProfile* profile{nullptr};
+  if (!build.shapes.empty()) {
+    profile = builder_->createOptimizationProfile();
+    CHECK_NOTNULL(profile);
+  }
+
+  // Set formats and data types of inputs.
+  for (int32_t i = 0; i < network.getNbInputs(); ++i) {
+    auto* input = network.getInput(i);
+    if (!build.input_formats.empty()) {
+      input->setType(build.input_formats[i].first);
+      input->setAllowedFormats(build.input_formats[i].second);
+    } else {
+      switch (input->getType()) {
+        case DataType::kINT32:
+        case DataType::kBOOL:
+        case DataType::kHALF:
+          // Leave these as is.
+          break;
+        case DataType::kFLOAT:
+        case DataType::kINT8:
+          // The user did not specify a floating-point format. Default to
+          // kFLOAT.
+          input->setType(DataType::kFLOAT);
+          break;
+      }
+      input->setAllowedFormats(1U << static_cast<int>(TensorFormat::kLINEAR));
+    }
+
+    if (profile) {
+      Dims dims = input->getDimensions();
+      // TODO(wilber): shape tensor.
+      const bool is_dynamic_input = std::any_of(
+          dims.d, dims.d + dims.nbDims, [](int dim) { return dim == -1; });
+      if (is_dynamic_input) {
+        is_dynamic_shape_ = true;
+        auto shape = build.shapes.find(input->getName());
+
+        // If no shape is provided for a dynamic input, fail hard.
+        if (shape == build.shapes.end()) {
+          // TODO(wilber): add information.
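+          // A sketch of the missing diagnostic could be, e.g.:
+          //   LOG(FATAL) << "TRT dynamic input " << input->getName()
+          //              << " has no min/opt/max shapes in BuildOptions";
+          // (illustrative suggestion, not part of the original change)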
+          CHECK(false);
+        }
+        LOG(INFO) << "Run Paddle-TRT Dynamic Shape mode.";
+        std::vector<int32_t> profile_dims{};
+        profile_dims =
+            shape->second[static_cast<size_t>(OptProfileSelector::kMIN)];
+        CHECK(profile->setDimensions(input->getName(),
+                                     OptProfileSelector::kMIN,
+                                     VecToDims(profile_dims)));
+        profile_dims =
+            shape->second[static_cast<size_t>(OptProfileSelector::kOPT)];
+        CHECK(profile->setDimensions(input->getName(),
+                                     OptProfileSelector::kOPT,
+                                     VecToDims(profile_dims)));
+        profile_dims =
+            shape->second[static_cast<size_t>(OptProfileSelector::kMAX)];
+        CHECK(profile->setDimensions(input->getName(),
+                                     OptProfileSelector::kMAX,
+                                     VecToDims(profile_dims)));
+      }
+    }
+  }
+
+  if (profile && is_dynamic_shape_) {
+    CHECK(profile->isValid());  // Required optimization profile is invalid.
+    CHECK_NE(config.addOptimizationProfile(profile), -1);
+  }
+
+  // Set formats and data types of outputs.
+  for (int32_t i = 0, n = network.getNbOutputs(); i < n; i++) {
+    auto* output = network.getOutput(i);
+    if (!build.output_formats.empty()) {
+      // int outputFormatIndex = broadcastOutputFormats ? 0 : i;
+      output->setType(build.output_formats[i].first);
+      output->setAllowedFormats(build.output_formats[i].second);
+    } else {
+      output->setAllowedFormats(1U << static_cast<int>(TensorFormat::kLINEAR));
+    }
+  }
+
+  config.setMaxWorkspaceSize(static_cast<size_t>(build.workspace) << 20);
+
+  if (build.fp16) {
+    config.setFlag(BuilderFlag::kFP16);
+    bool support_fp16 = builder_->platformHasFastFp16();
+    if (support_fp16) {
+      LOG(INFO) << "Run INFRT-TRT FP16 mode";
+    } else {
+      LOG(INFO) << "FP16 mode was requested, but the hardware does not "
+                   "support FP16 speed-up; FP32 is used instead.";
+    }
+  }
+
+  if (build.tf32) {
+    config.setFlag(BuilderFlag::kTF32);
+    bool support_tf32 = builder_->platformHasTf32();
+    if (support_tf32) {
+      LOG(INFO) << "Run INFRT-TRT TF32 mode";
+    } else {
+      LOG(INFO) << "TF32 mode was requested, but the hardware does not "
+                   "support TF32 speed-up; FP32 is used instead.";
+    }
+  }
+
+  // TODO(wilber): other precisions.
+
+  // TODO(wilber): precision config.
+  switch (build.precision_constraints) {
+    case PrecisionConstraints::kNONE:
+      // It's the default for TensorRT.
+      break;
+    case PrecisionConstraints::kOBEY:
+      config.setFlag(BuilderFlag::kOBEY_PRECISION_CONSTRAINTS);
+      break;
+    case PrecisionConstraints::kPREFER:
+      config.setFlag(BuilderFlag::kPREFER_PRECISION_CONSTRAINTS);
+      break;
+  }
+
+  // TODO(TRT): DLA config.
+
+  // TODO(TRT): int8 config.
+  // TODO(TRT): support int8
+  if (build.int8) {
+    assert(false);  // Not implemented yet.
+    config.setFlag(BuilderFlag::kINT8);
+    bool support_int8 = builder_->platformHasFastInt8();
+    if (support_int8) {
+      LOG(INFO) << "Run INFRT-TRT INT8 mode";
+    }
+  }
+
+  // TODO(TRT): calib config.
+
+  // TODO(TRT): sparse config.
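+  // With TensorRT >= 8 a sparsity pass could look like:
+  //   if (build.sparsity == SparsityFlag::kENABLE)
+  //     config.setFlag(BuilderFlag::kSPARSE_WEIGHTS);
+  // (illustrative only: BuildOptions carries no `sparsity` field yet,
+  // although SparsityFlag is already declared in trt_options.h)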
+
+  return true;
+}
+
+bool TRTEngine::SetUpInference(
+    const InferenceOptions& inference,
+    const std::unordered_map<std::string, phi::DenseTensor*>& inputs,
+    std::unordered_map<std::string, phi::DenseTensor*>* outputs) {
+  // TODO(wilber): for now only one execution context is created.
+  FreshDeviceId();
+  CHECK(engine_ != nullptr);
+  nvinfer1::IExecutionContext* ec = engine_->createExecutionContext();
+  CHECK(ec != nullptr);
+  contexts_.emplace_back(ec);
+  bindings_.emplace_back(new Bindings());
+
+  for (const auto& it : inputs) {
+    const int bind_index = engine_->getBindingIndex(it.first.c_str());
+    bindings_.front()->AddBinding(
+        bind_index, it.first, true, it.second, nvinfer1::DataType::kFLOAT);
+  }
+  for (auto& it : *outputs) {
+    const int bind_index = engine_->getBindingIndex(it.first.c_str());
+    bindings_.front()->AddBinding(
+        bind_index, it.first, false, it.second, nvinfer1::DataType::kFLOAT);
+  }
+
+  return true;
+}
+
+void TRTEngine::Run(const phi::GPUContext& ctx) {
+  if (is_dynamic_shape_) {
+    DynamicRun(ctx);
+  } else {
+    StaticRun(ctx);
+  }
+}
+
+void TRTEngine::StaticRun(const phi::GPUContext& ctx) {
+  const int num_bindings = engine_->getNbBindings();
+  std::vector<void*> buffers(num_bindings, nullptr);
+
+  int runtime_batch = -1;
+  auto input_binds = bindings_.front()->GetInputBindings();
+  for (auto bind : input_binds) {
+    const int bind_index = engine_->getBindingIndex(bind.name.c_str());
+    buffers[bind_index] = const_cast<void*>(
+        static_cast<const void*>(bind.buffer->data<float>()));
+    if (runtime_batch != -1) {
+      CHECK_EQ(runtime_batch, phi::vectorize(bind.buffer->dims())[0]);
+    }
+    runtime_batch = bind.buffer->dims()[0];
+  }
+
+  auto output_binds = bindings_.front()->GetOutputBindings();
+  for (auto bind : output_binds) {
+    const int bind_index = engine_->getBindingIndex(bind.name.c_str());
+    std::vector<int64_t> ddim;
+    auto dims = engine_->getBindingDimensions(bind_index);
+    ddim.push_back(runtime_batch);
+    for (int i = 0; i < dims.nbDims; ++i) {
+      ddim.push_back(dims.d[i]);
+    }
+    bind.buffer->Resize(phi::make_ddim(ddim));
+    ctx.Alloc<float>(bind.buffer, sizeof(float) * bind.buffer->numel());
+    buffers[bind_index] = static_cast<void*>(bind.buffer->data<float>());
+  }
+
+  contexts_.front()->enqueue(
+      runtime_batch, buffers.data(), ctx.stream(), nullptr);
+}
+
+void TRTEngine::DynamicRun(const phi::GPUContext& ctx) {
+  const int num_bindings = engine_->getNbBindings();
+  std::vector<void*> buffers(num_bindings, nullptr);
+
+  auto input_binds = bindings_.front()->GetInputBindings();
+  for (auto bind : input_binds) {
+    const int bind_index = engine_->getBindingIndex(bind.name.c_str());
+    buffers[bind_index] = const_cast<void*>(
+        static_cast<const void*>(bind.buffer->data<float>()));
+    nvinfer1::Dims trt_dims;
+    trt_dims.nbDims = bind.buffer->dims().size();
+
+    for (int i = 0; i < trt_dims.nbDims; ++i) {
+      trt_dims.d[i] = bind.buffer->dims()[i];
+    }
+    contexts_.front()->setBindingDimensions(bind_index, trt_dims);
+  }
+
+  CHECK(contexts_.front()->allInputDimensionsSpecified());
+
+  auto output_binds = bindings_.front()->GetOutputBindings();
+  for (auto bind : output_binds) {
+    const int bind_index = engine_->getBindingIndex(bind.name.c_str());
+    auto dims = contexts_.front()->getBindingDimensions(bind_index);
+    std::vector<int64_t> ddim(dims.nbDims);
+    for (int i = 0; i < dims.nbDims; ++i) {
+      ddim[i] = dims.d[i];
+    }
+    bind.buffer->Resize(phi::make_ddim(ddim));
+    ctx.Alloc<float>(bind.buffer, sizeof(float) * bind.buffer->numel());
+    buffers[bind_index] = static_cast<void*>(bind.buffer->data<float>());
+  }
+
+  contexts_.front()->enqueueV2(buffers.data(), ctx.stream(), nullptr);
+}
+
+void TRTEngine::FreshDeviceId() {
+  int count;
+  cudaGetDeviceCount(&count);
+  CHECK_LT(device_id_, count);
+  phi::backends::gpu::SetDeviceId(device_id_);
+}
+
+void TRTEngine::GetEngineInfo() {
+#if IS_TRT_VERSION_GE(8200)
+  LOG(INFO) << "====== engine info ======";
+  std::unique_ptr<nvinfer1::IEngineInspector> infer_inspector(
+      engine_->createEngineInspector());
+  infer_inspector->setExecutionContext(contexts_.front().get());
+  LOG(INFO) << infer_inspector->getEngineInformation(
+      nvinfer1::LayerInformationFormat::kONELINE);
+  LOG(INFO) << "====== engine info end ======";
+#else
+  LOG(INFO) << "The engine inspector needs TensorRT 8.2 or later.";
+#endif
+}
+
+}  // namespace tensorrt
+}  // namespace backends
+}  // namespace infrt
diff --git a/paddle/infrt/backends/tensorrt/trt_engine.h b/paddle/infrt/backends/tensorrt/trt_engine.h
new file mode 100644
index 0000000000..f72bdaf3ac
--- /dev/null
+++ b/paddle/infrt/backends/tensorrt/trt_engine.h
@@ -0,0 +1,114 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+#include "paddle/infrt/backends/tensorrt/trt_options.h"
+#include "paddle/infrt/backends/tensorrt/trt_utils.h"
+#include "paddle/phi/backends/dynload/tensorrt.h"
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/common/place.h"
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace infrt {
+namespace backends {
+namespace tensorrt {
+using namespace nvinfer1;  // NOLINT
+
+// The TensorRT programming model is as follows:
+// 1. The build phase:
+//    IBuilder* builder = createInferBuilder(&logger_);
+// 2. Create a network definition:
+//    INetworkDefinition* network = builder->createNetworkV2(...);
+// 3. Build the network:
+//    network->AddLayer(...)
+// 4. Configure the network:
+//    IBuilderConfig* config = builder->createBuilderConfig();
+//    config->setMaxWorkspaceSize(...)
+// 5. Build the serialized plan and deserialize it into a cuda engine:
+//    IHostMemory* serialized_model = builder->buildSerializedNetwork(...);
+//    IRuntime* runtime = createInferRuntime(&logger_);
+//    ICudaEngine* engine = runtime->deserializeCudaEngine(...);
+// 6. Get an execution context:
+//    IExecutionContext* exec_context = engine->createExecutionContext();
+// 7. Set input data:
+//    int32_t input_index = engine->getBindingIndex("input");
+//    int32_t output_index = engine->getBindingIndex("output");
+//    void* buffers[2];
+//    buffers[input_index] = input_buffer;
+//    buffers[output_index] = output_buffer;
+// 8. Perform inference:
+//    exec_context->enqueueV2(buffers, stream, nullptr);
+//
+// We have encapsulated this logic; please use the following programming model:
+//
+// TRTEngine trt_engine;
+// trt_engine.Build(...);
+// trt_engine.SetUpInference(...);
+// trt_engine.Run(...);
+class TRTEngine {
+ public:
+  explicit TRTEngine(int device_id);
+
+  nvinfer1::IBuilder* GetTrtBuilder();
+
+  // TODO(wilber): Modify signature after infrt-trt ready.
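+  // Consumes the network definition and builds an ICudaEngine from it
+  // according to build_options; the engine is cached inside this object
+  // rather than returned.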
+  void Build(TrtUniquePtr<nvinfer1::INetworkDefinition> network,
+             const BuildOptions& build_options);
+
+  // TODO(wilber): Modify signature after infrt-trt ready.
+  void Run(const phi::GPUContext& ctx);
+
+  // TODO(wilber): How to support multiple execution contexts?
+  bool SetUpInference(
+      const InferenceOptions& inference,
+      const std::unordered_map<std::string, phi::DenseTensor*>& inputs,
+      std::unordered_map<std::string, phi::DenseTensor*>* outputs);
+
+  void GetEngineInfo();
+
+ private:
+  void FreshDeviceId();
+
+  bool SetupNetworkAndConfig(const BuildOptions& build,
+                             INetworkDefinition& network,  // NOLINT
+                             IBuilderConfig& config);      // NOLINT
+
+  bool NetworkToEngine(const BuildOptions& build);
+
+  bool ModelToBuildEnv(TrtUniquePtr<nvinfer1::INetworkDefinition> network,
+                       const BuildOptions& build);
+
+  void StaticRun(const phi::GPUContext& ctx);
+
+  void DynamicRun(const phi::GPUContext& ctx);
+
+ private:
+  std::unique_ptr<TrtLogger> logger_{nullptr};
+  TrtUniquePtr<nvinfer1::IBuilder> builder_{nullptr};
+  TrtUniquePtr<nvinfer1::INetworkDefinition> network_{nullptr};
+  std::unique_ptr<nvinfer1::IHostMemory> serialized_engine_{nullptr};
+  TrtUniquePtr<nvinfer1::ICudaEngine> engine_{nullptr};
+  std::vector<TrtUniquePtr<nvinfer1::IExecutionContext>> contexts_;
+  std::vector<std::unique_ptr<Bindings>> bindings_;
+  int device_id_{0};
+  bool is_dynamic_shape_{false};
+};
+
+}  // namespace tensorrt
+}  // namespace backends
+}  // namespace infrt
diff --git a/paddle/infrt/backends/tensorrt/trt_options.h b/paddle/infrt/backends/tensorrt/trt_options.h
new file mode 100644
index 0000000000..d5190f5e62
--- /dev/null
+++ b/paddle/infrt/backends/tensorrt/trt_options.h
@@ -0,0 +1,94 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <array>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include <NvInfer.h>
+
+namespace infrt {
+namespace backends {
+namespace tensorrt {
+
+// Build default params
+constexpr int32_t max_batch_not_provided{0};
+constexpr int32_t default_workspace{16};
+// Inference default params
+constexpr int32_t default_batch{1};
+constexpr int32_t batch_not_provided{0};
+
+enum class PrecisionConstraints { kNONE, kOBEY, kPREFER };
+
+enum class SparsityFlag { kDISABLE, kENABLE, kFORCE };
+
+using ShapeRange =
+    std::array<std::vector<int32_t>,
+               nvinfer1::EnumMax<nvinfer1::OptProfileSelector>()>;
+
+using IOFormat = std::pair<nvinfer1::DataType, nvinfer1::TensorFormats>;
+
+struct BuildOptions {
+  // Set the max batch size.
+  int32_t max_batch{max_batch_not_provided};
+
+  // Set the workspace size in megabytes (default = 16).
+  int32_t workspace{default_workspace};
+
+  // Enable tf32 precision, in addition to fp32 (default = disabled).
+  bool tf32{false};
+
+  // Enable fp16 precision, in addition to fp32 (default = disabled).
+  bool fp16{false};
+
+  // Enable int8 precision, in addition to fp32 (default = disabled).
+  bool int8{false};
+
+  // Control precision constraints (default = none).
+  // Precision constraints: = none, obey, prefer
+  //   none   = no constraints
+  //   prefer = meet precision constraints if possible
+  //   obey   = meet precision constraints or fail otherwise
+  PrecisionConstraints precision_constraints{PrecisionConstraints::kNONE};
+
+  // Save the serialized engine.
+  bool save{false};
+
+  // Load a serialized engine.
+  bool load{false};
+
+  // Build with dynamic shapes, using a profile with the min, opt and max
+  // shapes provided.
+  std::unordered_map<std::string, ShapeRange> shapes;
+
+  // Type and format of each input tensor (default = all inputs in fp32:chw).
+  std::vector<IOFormat> input_formats;
+
+  // Type and format of each output tensor (default = all outputs in
+  // fp32:chw).
+  std::vector<IOFormat> output_formats;
+};
+
+struct InferenceOptions {
+  int32_t batch{batch_not_provided};
+  std::unordered_map<std::string, std::vector<int32_t>> shapes;
+};
+
+}  // namespace tensorrt
+}  // namespace backends
+}  // namespace infrt
diff --git a/paddle/infrt/backends/tensorrt/trt_utils.h b/paddle/infrt/backends/tensorrt/trt_utils.h
new file mode 100644
index 0000000000..4b129af1d5
--- /dev/null
+++ b/paddle/infrt/backends/tensorrt/trt_utils.h
@@ -0,0 +1,147 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <algorithm>
+#include <cassert>
+#include <functional>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include <NvInfer.h>
+#include <NvInferRuntime.h>
+#include <NvInferRuntimeCommon.h>
+#include "glog/logging.h"
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace infrt {
+namespace backends {
+namespace tensorrt {
+
+#define IS_TRT_VERSION_GE(version)                       \
+  ((NV_TENSORRT_MAJOR * 1000 + NV_TENSORRT_MINOR * 100 + \
+    NV_TENSORRT_PATCH * 10 + NV_TENSORRT_BUILD) >= version)
+
+#define IS_TRT_VERSION_LT(version)                       \
+  ((NV_TENSORRT_MAJOR * 1000 + NV_TENSORRT_MINOR * 100 + \
+    NV_TENSORRT_PATCH * 10 + NV_TENSORRT_BUILD) < version)
+
+#define TRT_VERSION                                    \
+  NV_TENSORRT_MAJOR * 1000 + NV_TENSORRT_MINOR * 100 + \
+      NV_TENSORRT_PATCH * 10 + NV_TENSORRT_BUILD
+
+inline nvinfer1::Dims VecToDims(const std::vector<int>& vec) {
+  int limit = static_cast<int>(nvinfer1::Dims::MAX_DIMS);
+  if (static_cast<int>(vec.size()) > limit) {
+    assert(false);
+  }
+  // Pick the first nvinfer1::Dims::MAX_DIMS elements.
+  nvinfer1::Dims dims{std::min(static_cast<int>(vec.size()), limit), {}};
+  std::copy_n(vec.begin(), dims.nbDims, std::begin(dims.d));
+  return dims;
+}
+
+template <typename T>
+struct TrtDestroyer {
+  void operator()(T* t) { t->destroy(); }
+};
+
+template <typename T>
+using TrtUniquePtr = std::unique_ptr<T, TrtDestroyer<T>>;
+
+class TrtLogger : public nvinfer1::ILogger {
+ public:
+  void log(nvinfer1::ILogger::Severity severity,
+           const char* msg) noexcept override {
+    switch (severity) {
+      case Severity::kVERBOSE:
+        VLOG(3) << msg;
+        break;
+      case Severity::kINFO:
+        VLOG(2) << msg;
+        break;
+      case Severity::kWARNING:
+        LOG(WARNING) << msg;
+        break;
+      case Severity::kINTERNAL_ERROR:
+      case Severity::kERROR:
+        LOG(ERROR) << msg;
+        break;
+      default:
+        break;
+    }
+  }
+  nvinfer1::ILogger& GetTrtLogger() noexcept { return *this; }
+  ~TrtLogger() override = default;
+};
+
+struct Binding {
+  bool is_input{false};
+  nvinfer1::DataType data_type{nvinfer1::DataType::kFLOAT};
+  phi::DenseTensor* buffer{nullptr};
+  std::string name;
+};
+
+class Bindings {
+ public:
+  Bindings() = default;
+
+  void AddBinding(int32_t b,
+                  const std::string& name,
+                  bool is_input,
+                  phi::DenseTensor* buffer,
+                  nvinfer1::DataType data_type) {
+    while (bindings_.size() <= static_cast<size_t>(b)) {
+      bindings_.emplace_back();
+    }
+    names_[name] = b;
+    bindings_[b].buffer = buffer;
+    bindings_[b].is_input = is_input;
+    bindings_[b].data_type = data_type;
+    bindings_[b].name = name;
+  }
+
+  std::vector<Binding> GetInputBindings() {
+    return GetBindings([](const Binding& b) -> bool { return b.is_input; });
+  }
+
+  std::vector<Binding> GetOutputBindings() {
+    return GetBindings([](const Binding& b) -> bool { return !b.is_input; });
+  }
+
+  std::vector<Binding> GetBindings() {
+    return GetBindings([](const Binding& b) -> bool { return true; });
+  }
+
+  std::vector<Binding> GetBindings(
+      std::function<bool(const Binding&)> predicate) {
+    std::vector<Binding> bindings;
+    for (const auto& b : bindings_) {
+      if (predicate(b)) {
+        bindings.push_back(b);
+      }
+    }
+    return bindings;
+  }
+
+ private:
+  std::unordered_map<std::string, int32_t> names_;
+  std::vector<Binding> bindings_;
+};
+
+}  // namespace tensorrt
+}  // namespace backends
+}  // namespace infrt
diff --git a/paddle/infrt/kernel/phi/CMakeLists.txt b/paddle/infrt/kernel/phi/CMakeLists.txt
index 30a2621f4a..7055c0c06d 100644
--- a/paddle/infrt/kernel/phi/CMakeLists.txt
+++ b/paddle/infrt/kernel/phi/CMakeLists.txt
@@ -18,6 +18,10 @@ set(wrapped_infermeta_source_file ${CMAKE_SOURCE_DIR}/paddle/phi/infermeta/gener
 
 add_custom_command(
   OUTPUT ${infrt_register_phi_kernels_gen_source_file}
+  COMMAND sh ${infrt_register_phi_kernels_gen_file}
+  DEPENDS ${wrapped_infermeta_header_file} ${wrapped_infermeta_source_file}
+  VERBATIM)
+add_custom_target(infrt_register_phi_kernel
   COMMAND sh ${infrt_register_phi_kernels_gen_file}
   DEPENDS ${wrapped_infermeta_header_file} ${wrapped_infermeta_source_file}
   COMMENT "infrt generate ${infrt_register_phi_kernels_gen_source_file}"
diff --git a/paddle/infrt/kernel/phi/infershaped/infershape_launchers_test.cc b/paddle/infrt/kernel/phi/infershaped/infershape_launchers_test.cc
index 331ebcfb4a..2161e98fac 100644
--- a/paddle/infrt/kernel/phi/infershaped/infershape_launchers_test.cc
+++ b/paddle/infrt/kernel/phi/infershaped/infershape_launchers_test.cc
@@ -54,7 +54,7 @@ TEST(ElementwiseAdd, launcher_registry) {
   host_context::KernelRegistry registry;
   RegisterInferShapeLaunchers(&registry);
   ASSERT_GE(registry.size(), 1UL);
-  auto creator = registry.GetKernel("add.cpu.any.fp32");
+  auto creator = registry.GetKernel("pten.add.cpu.any.fp32");
 
   const phi::DDim dims({1, 2});
   const phi::DataType dtype{phi::DataType::FLOAT32};
diff --git a/tools/infrt/get_phi_kernel_info.py b/tools/infrt/get_phi_kernel_info.py
index b0c834718b..f3e9f345da 100644
--- a/tools/infrt/get_phi_kernel_info.py
+++ b/tools/infrt/get_phi_kernel_info.py
@@ -219,7 +219,7 @@ def gen_register_info(resources: List[List[str]]):
         for ir_dtype, origin_dtype in zip(ir_dtypes, origin_dtypes):
             kernel_func = gen_kernel_func(update_item[3], ctx_name,
                                           origin_dtype)
-            ir_name = '.'.join(
+            ir_name = 'pten.' + '.'.join(
                 [it.lower() for it in update_item[:3]]) + "." + ir_dtype
             res += f"""
 registry->AddKernel("{ir_name}","""
-- 
GitLab