/**
 * \file src/tensorrt/test/make_trt_net.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "megbrain/opr/blas.h"
#include "megbrain/opr/dnn/convolution.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/tensor_manip.h"

#include "megbrain/opr/basic_arith.h"
#include "megbrain/plugin/profiler.h"
#include "megbrain/test/helper.h"
#include "megbrain/utils/debug.h"

#if MGB_ENABLE_TENSOR_RT
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
#include "make_trt_net.h"
#include "megbrain/tensorrt/tensorrt_opr.h"

#include <NvInferPlugin.h>
#include <random>

using namespace mgb;
using namespace opr;
using namespace nvinfer1;

/*!
 * Build the MegBrain reference graph y = Convolution(x, w) + b, where
 * x: (5, 23, 28, 28), w: (32, 23, 3, 3) and b: (1, 32, 1, 1) are host tensors
 * produced by gen() and fed through Host2DeviceCopy.
 */
intl::SimpleTensorRTNetwork::SimpleTensorRTNetwork() {
    host_x = gen({5, 23, 28, 28});
    host_w = gen({32, 23, 3, 3});
    host_b = gen({1, 32, 1, 1});

    graph = ComputingGraph::make();
    x = Host2DeviceCopy::make(*graph, host_x);
    auto w = Host2DeviceCopy::make(*graph, host_w);
    auto b = Host2DeviceCopy::make(*graph, host_b);
    // Default Convolution param; expected to match the unpadded, stride-1
    // TensorRT convolution built by create_trt_network() -- confirm.
    auto y0 = opr::Convolution::make(x, w);
    y = y0 + b;
}

M
Megvii Engine Team 已提交
47 48
std::pair<nvinfer1::IBuilder*, INetworkDefinition*> intl::SimpleTensorRTNetwork::
        create_trt_network(bool has_batch_dim) {
49
    CompNode::load("xpu0").activate();
50 51 52 53 54 55 56 57 58 59 60 61 62
    Weights wt_filter{DataType::kFLOAT, nullptr, 0},
            wt_bias{DataType::kFLOAT, nullptr, 0};
    wt_filter.type = DataType::kFLOAT;
    wt_bias.type = DataType::kFLOAT;
    wt_filter.values = host_w->raw_ptr();
    wt_bias.values = host_b->raw_ptr();
    wt_filter.count = host_w->shape().total_nr_elems();
    wt_bias.count = host_b->shape().total_nr_elems();
    auto builder = createInferBuilder(TensorRTOpr::Logger::instance());
#if NV_TENSOR_RT_VERSION >= 6001
    nvinfer1::NetworkDefinitionCreationFlags flags;
    ::memset(&flags, 0, sizeof(nvinfer1::NetworkDefinitionCreationFlags));
    if (has_batch_dim)
M
Megvii Engine Team 已提交
63 64
        flags = 1 << static_cast<int>(
                        nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
65 66 67 68 69 70 71
    auto network = builder->createNetworkV2(flags);
#else
    auto network = builder->createNetwork();
#endif
    nvinfer1::ITensor* data;
#if NV_TENSOR_RT_VERSION >= 6001
    if (has_batch_dim) {
M
Megvii Engine Team 已提交
72
        data = network->addInput("data", DataType::kFLOAT, Dims4{5, 23, 28, 28});
73 74 75 76 77 78 79 80 81 82
    } else {
        data = network->addInput("data", DataType::kFLOAT, Dims3{23, 28, 28});
    }
    {
        nvinfer1::TensorFormats formats =
                1 << static_cast<int>(nvinfer1::TensorFormat::kLINEAR);
        data->setAllowedFormats(formats);
    }
#else
    if (has_batch_dim) {
M
Megvii Engine Team 已提交
83
        data = network->addInput("data", DataType::kFLOAT, DimsNCHW{5, 23, 28, 28});
84 85 86 87 88
    } else {
        data = network->addInput("data", DataType::kFLOAT, DimsCHW{23, 28, 28});
    }
#endif
    mgb_assert(data != nullptr, "data is invalid");
M
Megvii Engine Team 已提交
89
    auto conv1 = network->addConvolution(*data, 32, DimsHW{3, 3}, wt_filter, wt_bias);
90 91 92 93 94 95 96 97 98 99 100 101 102 103 104
    mgb_assert(conv1 != nullptr, "conv1 is invalid");
    conv1->setStride(DimsHW{1, 1});
    conv1->getOutput(0)->setName("prob");
    network->markOutput(*conv1->getOutput(0));
#if NV_TENSOR_RT_VERSION >= 6001
    {
        nvinfer1::TensorFormats formats =
                1 << static_cast<int>(nvinfer1::TensorFormat::kLINEAR);
        conv1->getOutput(0)->setAllowedFormats(formats);
    }
#endif

    return std::make_pair(builder, network);
}

/*!
 * Build the MegBrain reference graph for the batched test: the (23, 28, 28)
 * input x is SUM-reduced over axis 0, then over axis 1, and the result is
 * reshaped to (1, 28).
 */
intl::BatchedTensorRTNetwork::BatchedTensorRTNetwork() {
    host_x = gen({23, 28, 28});

    graph = ComputingGraph::make();
    x = Host2DeviceCopy::make(*graph, host_x);

    // Both reductions are SUMs with the DEFAULT data type; only the axis differs.
    auto sum_param = [](int axis) {
        return opr::Reduce::Param{
                Reduce::Mode::SUM, axis, Reduce::Param::DataType::DEFAULT};
    };
    auto summed = opr::Reduce::make(x, sum_param(0));
    summed = opr::Reduce::make(summed, sum_param(1));

    TensorShape out_shape{1, 28};
    y = opr::Reshape::make(summed, out_shape);
}

M
Megvii Engine Team 已提交
118 119
std::pair<nvinfer1::IBuilder*, INetworkDefinition*> intl::BatchedTensorRTNetwork::
        create_trt_network(bool has_batch_dim) {
120 121 122 123 124 125
    CompNode::load("xpu0").activate();
    auto builder = createInferBuilder(TensorRTOpr::Logger::instance());
#if NV_TENSOR_RT_VERSION >= 6001
    nvinfer1::NetworkDefinitionCreationFlags flags;
    ::memset(&flags, 0, sizeof(nvinfer1::NetworkDefinitionCreationFlags));
    if (has_batch_dim)
M
Megvii Engine Team 已提交
126 127
        flags = 1 << static_cast<int>(
                        nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
128 129 130 131 132 133 134
    auto network = builder->createNetworkV2(flags);
#else
    auto network = builder->createNetwork();
#endif
    nvinfer1::ITensor* data;
#if NV_TENSOR_RT_VERSION >= 6001
    if (has_batch_dim) {
M
Megvii Engine Team 已提交
135
        data = network->addInput("data", DataType::kFLOAT, Dims4{1, 23, 28, 28});
136 137 138 139 140 141 142 143 144 145
    } else {
        data = network->addInput("data", DataType::kFLOAT, Dims3{23, 28, 28});
    }
    {
        nvinfer1::TensorFormats formats =
                1 << static_cast<int>(nvinfer1::TensorFormat::kLINEAR);
        data->setAllowedFormats(formats);
    }
#else
    if (has_batch_dim) {
M
Megvii Engine Team 已提交
146
        data = network->addInput("data", DataType::kFLOAT, DimsNCHW{1, 23, 28, 28});
147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166
    } else {
        data = network->addInput("data", DataType::kFLOAT, DimsCHW{23, 28, 28});
    }
#endif
    mgb_assert(data != nullptr, "data is invalid");
    auto reduce1 = network->addReduce(*data, nvinfer1::ReduceOperation::kSUM, 3, false);
    mgb_assert(reduce1 != nullptr, "reduce1 is invalid");
    reduce1->getOutput(0)->setName("prob");
    network->markOutput(*reduce1->getOutput(0));
#if NV_TENSOR_RT_VERSION >= 6001
    {
        nvinfer1::TensorFormats formats =
                1 << static_cast<int>(nvinfer1::TensorFormat::kLINEAR);
        reduce1->getOutput(0)->setAllowedFormats(formats);
    }
#endif

    return std::make_pair(builder, network);
}

167 168 169 170 171 172
intl::SimpleQuantizedTensorRTNetwork::SimpleQuantizedTensorRTNetwork() {
    host_x = range_gen({32, 8, 28, 28});
    host_w = weight_gen({8, 8, 3, 3});
    host_b = range_gen({1, 8, 1, 1});

    {
173 174
        void* w_ptr = host_w->raw_ptr();
        float* ptr = reinterpret_cast<float*>(w_ptr);
M
Megvii Engine Team 已提交
175 176
        ptr[0] = -127 * 1.1f;
        ptr[1] = 127 * 1.1f;
177 178 179
    }

    graph = ComputingGraph::make();
M
Megvii Engine Team 已提交
180
    auto mkvar = [this](const char* name, const std::shared_ptr<HostTensorND>& host_ts,
181 182
                        const DType& dtype) {
        return opr::TypeCvt::make(
M
Megvii Engine Team 已提交
183
                opr::Host2DeviceCopy::make(*graph, host_ts).rename(name), dtype);
184
    };
M
Megvii Engine Team 已提交
185
    auto mkcvar = [this](const char* name, const std::shared_ptr<HostTensorND>& host_ts,
186 187
                         const DType& dtype) {
        return opr::TypeCvt::make(
M
Megvii Engine Team 已提交
188
                opr::SharedDeviceTensor::make(*graph, *host_ts).rename(name), dtype);
189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204
    };

    x = mkvar("x", host_x, dtype::Float32());
    quantized_x = mkvar("quantized_x", host_x, dtype::QuantizedS8(1.2f));
    auto float_w = mkcvar("float_w", host_w, dtype::Float32()),
         float_b = mkcvar("float_b", host_b, dtype::Float32()),
         w = opr::TypeCvt::make(float_w, dtype::QuantizedS8(1.1f)),
         b = opr::TypeCvt::make(float_b, dtype::QuantizedS32(1.2f * 1.1f));

    {
        auto xshp = opr::GetVarShape::make(quantized_x);

        auto cv = [this](int v) { return quantized_x.make_scalar(v); };
        auto sub = [&xshp, &cv](int idx) {
            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
        };
M
Megvii Engine Team 已提交
205
        auto tshp = opr::Concat::make({sub(0), sub(1) / 4, cv(4), sub(2), sub(3)}, 0);
206 207 208 209 210 211 212 213 214 215 216
        quantized_x = opr::Reshape::make(quantized_x, tshp);
        quantized_x = opr::Dimshuffle::make(quantized_x, {0, 1, 3, 4, 2});
    }

    {
        auto wshp = opr::GetVarShape::make(w);

        auto cv = [&w](int v) { return w.make_scalar(v); };
        auto sub = [&wshp, &cv](int idx) {
            return opr::IndexAt::make(wshp, {{0, cv(idx)}});
        };
M
Megvii Engine Team 已提交
217
        auto tshp = opr::Concat::make({sub(0), sub(1) / 4, cv(4), sub(2), sub(3)}, 0);
218 219 220 221 222 223 224 225 226 227 228
        w = opr::Reshape::make(w, tshp);
        w = opr::Dimshuffle::make(w, {0, 1, 3, 4, 2});
    }

    {
        auto bshp = opr::GetVarShape::make(b);

        auto cv = [&b](int v) { return b.make_scalar(v); };
        auto sub = [&bshp, &cv](int idx) {
            return opr::IndexAt::make(bshp, {{0, cv(idx)}});
        };
M
Megvii Engine Team 已提交
229
        auto tshp = opr::Concat::make({sub(0), sub(1) / 4, cv(4), sub(2), sub(3)}, 0);
230 231 232 233 234 235 236 237 238 239
        b = opr::Reshape::make(b, tshp);
        b = opr::Dimshuffle::make(b, {0, 1, 3, 4, 2});
    }

    opr::ConvBias::Param param;
    param.format = opr::ConvBias::Param::Format::NCHW4;
    param.nonlineMode = opr::ConvBias::Param::NonlineMode::IDENTITY;
    param.stride_h = param.stride_w = 1;
    param.pad_h = param.pad_w = 1;

M
Megvii Engine Team 已提交
240 241
    quantized_y = opr::ConvBias::make(
            quantized_x, w, b, param, {}, OperatorNodeConfig{dtype::QuantizedS8(1.1f)});
242
    param.format = opr::ConvBias::Param::Format::NCHW;
M
Megvii Engine Team 已提交
243 244
    y = opr::ConvBias::make(
            x, float_w, float_b, param, {}, OperatorNodeConfig{dtype::Float32()});
245 246 247 248 249 250 251 252 253 254 255 256 257

    auto yshp = opr::GetVarShape::make(quantized_y);

    auto cv = [this](int v) { return quantized_y.make_scalar(v); };
    auto sub = [&yshp, &cv](int idx) {
        return opr::IndexAt::make(yshp, {{0, cv(idx)}});
    };
    auto tshp = opr::Concat::make({sub(0), sub(1) * 4, sub(2), sub(3)}, 0);
    quantized_y = opr::Dimshuffle::make(quantized_y, {0, 1, 4, 2, 3});
    quantized_y = opr::Reshape::make(quantized_y, tshp);
    quantized_y = TypeCvt::make(quantized_y, dtype::Float32());
}

M
Megvii Engine Team 已提交
258 259
std::pair<nvinfer1::IBuilder*, INetworkDefinition*> intl::
        SimpleQuantizedTensorRTNetwork::create_trt_network(bool has_batch_dim) {
260
    CompNode::load("xpu0").activate();
261 262 263 264 265 266 267 268 269 270 271 272 273
    Weights wt_filter{DataType::kFLOAT, nullptr, 0},
            wt_bias{DataType::kFLOAT, nullptr, 0};
    wt_filter.type = DataType::kFLOAT;
    wt_bias.type = DataType::kFLOAT;
    wt_filter.values = host_w->raw_ptr();
    wt_bias.values = host_b->raw_ptr();
    wt_filter.count = host_w->shape().total_nr_elems();
    wt_bias.count = host_b->shape().total_nr_elems();
    auto builder = createInferBuilder(TensorRTOpr::Logger::instance());
#if NV_TENSOR_RT_VERSION >= 6001
    nvinfer1::NetworkDefinitionCreationFlags flags;
    ::memset(&flags, 0, sizeof(nvinfer1::NetworkDefinitionCreationFlags));
    if (has_batch_dim)
M
Megvii Engine Team 已提交
274 275
        flags = 1 << static_cast<int>(
                        nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
276 277 278 279 280 281 282
    auto network = builder->createNetworkV2(flags);
#else
    auto network = builder->createNetwork();
#endif
    nvinfer1::ITensor* data;
#if NV_TENSOR_RT_VERSION >= 6001
    if (has_batch_dim) {
M
Megvii Engine Team 已提交
283
        data = network->addInput("data", DataType::kFLOAT, Dims4{32, 8, 28, 28});
284 285 286 287 288 289 290 291 292 293
    } else {
        data = network->addInput("data", DataType::kFLOAT, Dims3{8, 28, 28});
    }
    {
        nvinfer1::TensorFormats formats =
                1 << static_cast<int>(nvinfer1::TensorFormat::kLINEAR);
        data->setAllowedFormats(formats);
    }
#else
    if (has_batch_dim) {
M
Megvii Engine Team 已提交
294
        data = network->addInput("data", DataType::kFLOAT, DimsNCHW{32, 8, 28, 28});
295 296 297 298 299 300 301
    } else {
        data = network->addInput("data", DataType::kFLOAT, DimsCHW{8, 28, 28});
    }
#endif
    data->setDynamicRange(-127.f * 1.2f, 127.f * 1.2f);
    mgb_assert(data != nullptr, "data is invalid");
    auto add_conv = [&](const char* name, nvinfer1::ITensor* inp) {
M
Megvii Engine Team 已提交
302
        auto conv = network->addConvolution(*inp, 8, DimsHW{3, 3}, wt_filter, wt_bias);
303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333
        mgb_assert(conv != nullptr, "conv1 is invalid");
        conv->setName(name);
        conv->setStride(DimsHW{1, 1});
        conv->setPadding(DimsHW{1, 1});
        conv->getOutput(0)->setDynamicRange(-127.f * 1.1f, 127.f * 1.1f);
        // conv->setPrecision(nvinfer1::DataType::kINT8);
        return conv->getOutput(0);
    };
    auto out = add_conv("conv1", data);
    out->setName("prob");
#if NV_TENSOR_RT_VERSION >= 6001
    {
        nvinfer1::TensorFormats formats =
                1 << static_cast<int>(nvinfer1::TensorFormat::kLINEAR);
        out->setAllowedFormats(formats);
    }
#endif
    network->markOutput(*out);

    return std::make_pair(builder, network);
}

/*!
 * Build the MegBrain reference graph y = Convolution(Concat(x0, x1, axis=1), w) + b.
 * x0 and x1 are (5, 23, 14, 28), w is (32, 46, 3, 3) and b is (1, 32, 1, 1).
 * All four are host tensors produced by gen() and fed through Host2DeviceCopy.
 */
intl::ConcatConvTensorRTNetwork::ConcatConvTensorRTNetwork() {
    host_x0 = gen({5, 23, 14, 28});
    host_x1 = gen({5, 23, 14, 28});
    host_w = gen({32, 46, 3, 3});
    host_b = gen({1, 32, 1, 1});

    graph = ComputingGraph::make();
    x0 = Host2DeviceCopy::make(*graph, host_x0);
    x1 = Host2DeviceCopy::make(*graph, host_x1);
    // Concatenating on axis 1 gives 23 + 23 == 46 channels, matching host_w.
    auto y0 = opr::Concat::make({x0, x1}, 1);
    auto w = Host2DeviceCopy::make(*graph, host_w);
    auto b = Host2DeviceCopy::make(*graph, host_b);
    auto y1 = opr::Convolution::make(y0, w);
    y = y1 + b;
}

M
Megvii Engine Team 已提交
339 340
std::pair<nvinfer1::IBuilder*, INetworkDefinition*> intl::ConcatConvTensorRTNetwork::
        create_trt_network(bool has_batch_dim) {
341
    CompNode::load("xpu0").activate();
342 343 344 345
    auto builder = createInferBuilder(TensorRTOpr::Logger::instance());
#if NV_TENSOR_RT_VERSION >= 6001
    nvinfer1::NetworkDefinitionCreationFlags flags;
    ::memset(&flags, 0, sizeof(nvinfer1::NetworkDefinitionCreationFlags));
M
Megvii Engine Team 已提交
346 347 348
    if (has_batch_dim)
        flags = 1 << static_cast<int>(
                        nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
349 350 351 352 353 354 355
    auto network = builder->createNetworkV2(flags);
#else
    auto network = builder->createNetwork();
#endif
    ITensor *data0, *data1;
#if NV_TENSOR_RT_VERSION >= 6001
    if (has_batch_dim) {
M
Megvii Engine Team 已提交
356 357
        data0 = network->addInput("x0", DataType::kFLOAT, Dims4{5, 23, 14, 28});
        data1 = network->addInput("x1", DataType::kFLOAT, Dims4{5, 23, 14, 28});
358 359 360 361 362 363 364 365 366 367 368 369
    } else {
        data0 = network->addInput("x0", DataType::kFLOAT, Dims3{23, 14, 28});
        data1 = network->addInput("x1", DataType::kFLOAT, Dims3{23, 14, 28});
    }
    {
        nvinfer1::TensorFormats formats =
                1 << static_cast<int>(nvinfer1::TensorFormat::kLINEAR);
        data0->setAllowedFormats(formats);
        data1->setAllowedFormats(formats);
    }
#else
    if (has_batch_dim) {
M
Megvii Engine Team 已提交
370 371
        data0 = network->addInput("x0", DataType::kFLOAT, DimsNCHW{5, 23, 14, 28});
        data1 = network->addInput("x1", DataType::kFLOAT, DimsNCHW{5, 23, 14, 28});
372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390
    } else {
        data0 = network->addInput("x0", DataType::kFLOAT, DimsCHW{23, 14, 28});
        data1 = network->addInput("x1", DataType::kFLOAT, DimsCHW{23, 14, 28});
    }
#endif
    ITensor* inputTensors[] = {data0, data1};
    auto concat = network->addConcatenation(inputTensors, 2);
    mgb_assert(concat != nullptr, "concat is null!");
    concat->setName("concat0");
    if (has_batch_dim) {
        concat->setAxis(1);
    } else {
        concat->setAxis(0);
    }

    Weights wt_filter{DataType::kFLOAT, host_w->raw_ptr(), 0},
            wt_bias{DataType::kFLOAT, host_b->raw_ptr(), 0};
    wt_filter.count = host_w->shape().total_nr_elems();
    wt_bias.count = host_b->shape().total_nr_elems();
M
Megvii Engine Team 已提交
391 392
    auto conv1 = network->addConvolution(
            *concat->getOutput(0), 32, DimsHW{3, 3}, wt_filter, wt_bias);
393 394 395 396 397 398 399 400 401 402 403 404 405 406 407
    mgb_assert(conv1 != nullptr, "conv1 is invalid");
    conv1->setName("conv1");
    conv1->setStride(DimsHW{1, 1});
    conv1->getOutput(0)->setName("convOut");
    network->markOutput(*conv1->getOutput(0));
#if NV_TENSOR_RT_VERSION >= 6001
    {
        nvinfer1::TensorFormats formats =
                1 << static_cast<int>(nvinfer1::TensorFormat::kLINEAR);
        conv1->getOutput(0)->setAllowedFormats(formats);
    }
#endif
    return std::make_pair(builder, network);
}

/*!
 * Build the MegBrain reference graph for the FlattenConcat_TRT plugin test.
 * x0 is (2, 2, 2, 2) and y0 is (2, 3, 2, 2). Each is reshaped so that every
 * sample's elements lie along axis 1, and the two are concatenated on that
 * axis into z.
 */
intl::ReshapeConcatTensorRTNetwork::ReshapeConcatTensorRTNetwork() {
    host_x0 = gen({2, 2, 2, 2});
    host_y0 = gen({2, 3, 2, 2});

    graph = ComputingGraph::make();
    x0 = Host2DeviceCopy::make(*graph, host_x0);
    y0 = Host2DeviceCopy::make(*graph, host_y0);

    // 2 * 2 * 2 == 8 and 3 * 2 * 2 == 12 elements per sample, respectively.
    TensorShape x_flat_shape{2, 8, 1, 1};
    TensorShape y_flat_shape{2, 12, 1, 1};
    auto x_flat = opr::Reshape::make(x0, x_flat_shape);
    auto y_flat = opr::Reshape::make(y0, y_flat_shape);
    z = opr::Concat::make({x_flat, y_flat}, 1);
}

/*!
 * Build the TensorRT counterpart of ReshapeConcatTensorRTNetwork. Two float
 * inputs "x0" and "y0" are fed to a FlattenConcat_TRT plugin layer
 * (axis = 1, ignoreBatch = false), obtained from the plugin registry after
 * initLibNvInferPlugins().
 *
 * \param has_batch_dim if true, the inputs are declared with an explicit
 *      leading batch dim of 2 (and, for TensorRT >= 6.0.1, the network is
 *      created with kEXPLICIT_BATCH); otherwise they are (2, 2, 2) / (3, 2, 2)
 * \return the builder and the network; neither is destroyed here, so
 *      releasing them is left to the caller
 */
std::pair<nvinfer1::IBuilder*, INetworkDefinition*> intl::ReshapeConcatTensorRTNetwork::
        create_trt_network(bool has_batch_dim) {
    initLibNvInferPlugins(&TensorRTOpr::Logger::instance(), "");

    CompNode::load("xpu0").activate();
    auto builder = createInferBuilder(TensorRTOpr::Logger::instance());
    mgb_assert(builder != nullptr, "failed to create TensorRT builder");
#if NV_TENSOR_RT_VERSION >= 6001
    nvinfer1::NetworkDefinitionCreationFlags flags = 0;
    if (has_batch_dim) {
        flags = 1 << static_cast<int>(
                        nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
    }
    auto network = builder->createNetworkV2(flags);
#else
    auto network = builder->createNetwork();
#endif
    mgb_assert(network != nullptr, "failed to create TensorRT network");

    nvinfer1::ITensor* data0 = nullptr;
    nvinfer1::ITensor* data1 = nullptr;
#if NV_TENSOR_RT_VERSION >= 6001
    if (has_batch_dim) {
        data0 = network->addInput("x0", DataType::kFLOAT, Dims4{2, 2, 2, 2});
        data1 = network->addInput("y0", DataType::kFLOAT, Dims4{2, 3, 2, 2});
    } else {
        data0 = network->addInput("x0", DataType::kFLOAT, Dims3{2, 2, 2});
        data1 = network->addInput("y0", DataType::kFLOAT, Dims3{3, 2, 2});
    }
#else
    if (has_batch_dim) {
        data0 = network->addInput("x0", DataType::kFLOAT, DimsNCHW{2, 2, 2, 2});
        data1 = network->addInput("y0", DataType::kFLOAT, DimsNCHW{2, 3, 2, 2});
    } else {
        data0 = network->addInput("x0", DataType::kFLOAT, DimsCHW{2, 2, 2});
        data1 = network->addInput("y0", DataType::kFLOAT, DimsCHW{3, 2, 2});
    }
#endif
    // Must be checked before the first dereference below.
    mgb_assert(data0 != nullptr, "x0 is invalid");
    mgb_assert(data1 != nullptr, "y0 is invalid");
#if NV_TENSOR_RT_VERSION >= 6001
    {
        nvinfer1::TensorFormats formats =
                1 << static_cast<int>(nvinfer1::TensorFormat::kLINEAR);
        data0->setAllowedFormats(formats);
        data1->setAllowedFormats(formats);
    }
#endif

    int axis = 1;
    // NOTE(review): this field is declared kINT32 but backed by a 1-byte bool;
    // confirm the FlattenConcat creator reads it as bool rather than int32.
    bool ignoreBatch = false;
    nvinfer1::PluginField fields[2] = {
            nvinfer1::PluginField{"axis", &axis, nvinfer1::PluginFieldType::kINT32, 1},
            nvinfer1::PluginField{
                    "ignoreBatch", &ignoreBatch, nvinfer1::PluginFieldType::kINT32, 1},
    };
    nvinfer1::PluginFieldCollection fc{2, fields};

    // The registry lookup yields nullptr when the plugin is not registered.
    auto creator = getPluginRegistry()->getPluginCreator("FlattenConcat_TRT", "1", "");
    mgb_assert(creator != nullptr, "FlattenConcat_TRT plugin creator not found");
    auto raw_plugin = creator->createPlugin("FlattenConcat_TRT", &fc);
    mgb_assert(raw_plugin != nullptr, "failed to create FlattenConcat_TRT plugin");
    TensorRTUniquePtr<nvinfer1::IPluginV2> plugin(raw_plugin);

    ITensor* inputTensors[] = {data0, data1};
    auto flt_cct = network->addPluginV2(inputTensors, 2, *plugin);
    mgb_assert(flt_cct != nullptr, "FlattenConcat_TRT is invalid");
    network->markOutput(*flt_cct->getOutput(0));
#if NV_TENSOR_RT_VERSION >= 6001
    {
        nvinfer1::TensorFormats formats =
                1 << static_cast<int>(nvinfer1::TensorFormat::kLINEAR);
        flt_cct->getOutput(0)->setAllowedFormats(formats);
    }
#endif
    return std::make_pair(builder, network);
}

#pragma GCC diagnostic pop
#endif  // MGB_ENABLE_TENSOR_RT

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}