未验证 提交 0438b604 编写于 作者: Z zlsh80826 提交者: GitHub

[Paddle-TRT] upgrade test_tensorrt to trt8 (#34294)

* upgrade test_tensorrt to trt8

* format
上级 6fc33a0c
...@@ -16,13 +16,15 @@ limitations under the License. */ ...@@ -16,13 +16,15 @@ limitations under the License. */
#include <glog/logging.h> #include <glog/logging.h>
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "NvInfer.h" #include "NvInfer.h"
#include "paddle/fluid/inference/tensorrt/helper.h"
#include "paddle/fluid/platform/dynload/tensorrt.h" #include "paddle/fluid/platform/dynload/tensorrt.h"
namespace dy = paddle::platform::dynload; namespace dy = paddle::platform::dynload;
class Logger : public nvinfer1::ILogger { class Logger : public nvinfer1::ILogger {
public: public:
void log(nvinfer1::ILogger::Severity severity, const char* msg) override { void log(nvinfer1::ILogger::Severity severity,
const char* msg) TRT_NOEXCEPT override {
switch (severity) { switch (severity) {
case Severity::kINFO: case Severity::kINFO:
LOG(INFO) << msg; LOG(INFO) << msg;
...@@ -74,10 +76,11 @@ nvinfer1::IHostMemory* CreateNetwork() { ...@@ -74,10 +76,11 @@ nvinfer1::IHostMemory* CreateNetwork() {
Logger logger; Logger logger;
// Create the engine. // Create the engine.
nvinfer1::IBuilder* builder = createInferBuilder(&logger); nvinfer1::IBuilder* builder = createInferBuilder(&logger);
auto config = builder->createBuilderConfig();
ScopedWeights weights(2.); ScopedWeights weights(2.);
ScopedWeights bias(3.); ScopedWeights bias(3.);
nvinfer1::INetworkDefinition* network = builder->createNetwork(); nvinfer1::INetworkDefinition* network = builder->createNetworkV2(0U);
// Add the input // Add the input
auto input = network->addInput(kInputTensor, nvinfer1::DataType::kFLOAT, auto input = network->addInput(kInputTensor, nvinfer1::DataType::kFLOAT,
nvinfer1::Dims3{1, 1, 1}); nvinfer1::Dims3{1, 1, 1});
...@@ -91,8 +94,8 @@ nvinfer1::IHostMemory* CreateNetwork() { ...@@ -91,8 +94,8 @@ nvinfer1::IHostMemory* CreateNetwork() {
network->markOutput(*output); network->markOutput(*output);
// Build the engine. // Build the engine.
builder->setMaxBatchSize(1); builder->setMaxBatchSize(1);
builder->setMaxWorkspaceSize(1 << 10); config->setMaxWorkspaceSize(1 << 10);
auto engine = builder->buildCudaEngine(*network); auto engine = builder->buildEngineWithConfig(*network, *config);
EXPECT_NE(engine, nullptr); EXPECT_NE(engine, nullptr);
// Serialize the engine to create a model, then close. // Serialize the engine to create a model, then close.
nvinfer1::IHostMemory* model = engine->serialize(); nvinfer1::IHostMemory* model = engine->serialize();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册