From 0438b60462350c04618a2c45f87f4fc75b401e55 Mon Sep 17 00:00:00 2001 From: zlsh80826 Date: Wed, 21 Jul 2021 20:36:14 +0800 Subject: [PATCH] [Paddle-TRT] upgrade test_tensorrt to trt8 (#34294) * upgrade test_tensorrt to trt8 * format --- paddle/fluid/inference/tensorrt/test_tensorrt.cc | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/test_tensorrt.cc b/paddle/fluid/inference/tensorrt/test_tensorrt.cc index 36a25e27d7..2f5b75c102 100644 --- a/paddle/fluid/inference/tensorrt/test_tensorrt.cc +++ b/paddle/fluid/inference/tensorrt/test_tensorrt.cc @@ -16,13 +16,15 @@ limitations under the License. */ #include <glog/logging.h> #include <gtest/gtest.h> #include "NvInfer.h" +#include "paddle/fluid/inference/tensorrt/helper.h" #include "paddle/fluid/platform/dynload/tensorrt.h" namespace dy = paddle::platform::dynload; class Logger : public nvinfer1::ILogger { public: - void log(nvinfer1::ILogger::Severity severity, const char* msg) override { + void log(nvinfer1::ILogger::Severity severity, + const char* msg) TRT_NOEXCEPT override { switch (severity) { case Severity::kINFO: LOG(INFO) << msg; @@ -74,10 +76,11 @@ nvinfer1::IHostMemory* CreateNetwork() { Logger logger; // Create the engine. nvinfer1::IBuilder* builder = createInferBuilder(&logger); + auto config = builder->createBuilderConfig(); ScopedWeights weights(2.); ScopedWeights bias(3.); - nvinfer1::INetworkDefinition* network = builder->createNetwork(); + nvinfer1::INetworkDefinition* network = builder->createNetworkV2(0U); // Add the input auto input = network->addInput(kInputTensor, nvinfer1::DataType::kFLOAT, nvinfer1::Dims3{1, 1, 1}); @@ -91,8 +94,8 @@ nvinfer1::IHostMemory* CreateNetwork() { network->markOutput(*output); // Build the engine. 
builder->setMaxBatchSize(1); - builder->setMaxWorkspaceSize(1 << 10); - auto engine = builder->buildCudaEngine(*network); + config->setMaxWorkspaceSize(1 << 10); + auto engine = builder->buildEngineWithConfig(*network, *config); EXPECT_NE(engine, nullptr); // Serialize the engine to create a model, then close. nvinfer1::IHostMemory* model = engine->serialize(); -- GitLab