Unverified commit 0438b604, authored by zlsh80826, committed by GitHub

[Paddle-TRT] upgrade test_tensorrt to trt8 (#34294)

* upgrade test_tensorrt to trt8

* format
Parent 6fc33a0c
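This change ports the unit test to the TensorRT 8 builder API: the network is now created with createNetworkV2(), build-time options such as the workspace size move from IBuilder onto an IBuilderConfig, and the engine is built with buildEngineWithConfig() instead of the removed buildCudaEngine(). A minimal sketch of that flow is shown below, assuming direct use of the stock NvInfer.h entry points (the test itself routes createInferBuilder through paddle::platform::dynload and passes the logger by pointer); the helper name BuildSerializedModel is illustrative only, and error handling and teardown are omitted.

#include <NvInfer.h>

// Sketch of the TensorRT 8 build path this diff moves to (not the test's
// exact code): build options live on IBuilderConfig and the engine is built
// from network + config.
nvinfer1::IHostMemory* BuildSerializedModel(nvinfer1::ILogger& logger) {
  nvinfer1::IBuilder* builder = nvinfer1::createInferBuilder(logger);
  // createNetwork() no longer exists in TRT 8; flag 0U keeps implicit batch.
  nvinfer1::INetworkDefinition* network = builder->createNetworkV2(0U);
  nvinfer1::IBuilderConfig* config = builder->createBuilderConfig();
  config->setMaxWorkspaceSize(1 << 10);  // was IBuilder::setMaxWorkspaceSize()
  // ... add layers and mark outputs here, as the test does ...
  // buildCudaEngine(*network) was removed; build from network + config instead.
  nvinfer1::ICudaEngine* engine = builder->buildEngineWithConfig(*network, *config);
  if (engine == nullptr) return nullptr;  // an empty sketch network will not build
  return engine->serialize();
}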
@@ -16,13 +16,15 @@ limitations under the License. */
 #include <glog/logging.h>
 #include <gtest/gtest.h>
 #include "NvInfer.h"
+#include "paddle/fluid/inference/tensorrt/helper.h"
 #include "paddle/fluid/platform/dynload/tensorrt.h"
 namespace dy = paddle::platform::dynload;
 class Logger : public nvinfer1::ILogger {
  public:
-  void log(nvinfer1::ILogger::Severity severity, const char* msg) override {
+  void log(nvinfer1::ILogger::Severity severity,
+           const char* msg) TRT_NOEXCEPT override {
     switch (severity) {
       case Severity::kINFO:
         LOG(INFO) << msg;
@@ -74,10 +76,11 @@ nvinfer1::IHostMemory* CreateNetwork() {
   Logger logger;
   // Create the engine.
   nvinfer1::IBuilder* builder = createInferBuilder(&logger);
+  auto config = builder->createBuilderConfig();
   ScopedWeights weights(2.);
   ScopedWeights bias(3.);
-  nvinfer1::INetworkDefinition* network = builder->createNetwork();
+  nvinfer1::INetworkDefinition* network = builder->createNetworkV2(0U);
   // Add the input
   auto input = network->addInput(kInputTensor, nvinfer1::DataType::kFLOAT,
                                  nvinfer1::Dims3{1, 1, 1});
@@ -91,8 +94,8 @@ nvinfer1::IHostMemory* CreateNetwork() {
   network->markOutput(*output);
   // Build the engine.
   builder->setMaxBatchSize(1);
-  builder->setMaxWorkspaceSize(1 << 10);
-  auto engine = builder->buildCudaEngine(*network);
+  config->setMaxWorkspaceSize(1 << 10);
+  auto engine = builder->buildEngineWithConfig(*network, *config);
   EXPECT_NE(engine, nullptr);
   // Serialize the engine to create a model, then close.
   nvinfer1::IHostMemory* model = engine->serialize();
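The first hunk also reflows the ILogger::log override and adds paddle/fluid/inference/tensorrt/helper.h, which presumably provides the TRT_NOEXCEPT macro used there: TensorRT 8 declares the ILogger interface methods noexcept, so overrides must carry the specifier, while older TensorRT headers do not. A sketch of how such a compatibility macro is commonly defined (an assumption, not necessarily Paddle's exact definition):

#include <NvInferVersion.h>  // provides NV_TENSORRT_MAJOR

// Assumed definition: expand to `noexcept` only when the TensorRT 8+ headers
// are in use, so the same override signature compiles against both versions.
#ifndef TRT_NOEXCEPT
#if NV_TENSORRT_MAJOR >= 8
#define TRT_NOEXCEPT noexcept
#else
#define TRT_NOEXCEPT
#endif
#endif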