/**
 * \file src/tensorrt/test/make_trt_net.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "megbrain/opr/blas.h"
#include "megbrain/opr/dnn/convolution.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/plugin/profiler.h"
#include "megbrain/test/helper.h"
#include "megbrain/utils/debug.h"

#if MGB_ENABLE_TENSOR_RT

#include "megbrain/tensorrt/tensorrt_opr.h"

#include <random>

// NOTE(review): these using-directives are at global namespace scope in a
// header, so every translation unit that includes this file inherits them.
// Test sources may depend on that, so they are kept as-is.
// Order matters: `opr` is the nested namespace mgb::opr and is only found
// through the preceding `using namespace mgb;`.
using namespace mgb;
using namespace opr;
using namespace nvinfer1;

//! Global-scope alias for intl::TensorRTUniquePtr, letting includers name it
//! without the `intl::` qualifier. The aliased template is declared outside
//! this header (presumably in megbrain/tensorrt/tensorrt_opr.h — TODO confirm).
template <typename T>
using TensorRTUniquePtr = intl::TensorRTUniquePtr<T>;

M
Megvii Engine Team 已提交
33 34 35
namespace mgb {
namespace opr {
namespace intl {
36 37 38 39 40 41 42 43 44 45 46

struct SimpleTensorRTNetwork {
    HostTensorGenerator<> gen;
    std::shared_ptr<HostTensorND> host_x, host_w, host_b;
    std::shared_ptr<ComputingGraph> graph;
    SymbolVar x, y;

    HostTensorND host_z1;

    SimpleTensorRTNetwork();

M
Megvii Engine Team 已提交
47 48
    std::pair<nvinfer1::IBuilder*, INetworkDefinition*> create_trt_network(
            bool has_batch_dim);
49 50
};

51 52 53 54 55 56 57 58 59 60
struct BatchedTensorRTNetwork {
    HostTensorGenerator<> gen;
    std::shared_ptr<HostTensorND> host_x, host_w, host_b;
    std::shared_ptr<ComputingGraph> graph;
    SymbolVar x, y;

    HostTensorND host_z1;

    BatchedTensorRTNetwork();

M
Megvii Engine Team 已提交
61 62
    std::pair<nvinfer1::IBuilder*, INetworkDefinition*> create_trt_network(
            bool has_batch_dim);
63 64
};

65 66
struct SimpleQuantizedTensorRTNetwork {
    HostTensorGenerator<dtype::Float32, RandomDistribution::UNIFORM> weight_gen{
M
Megvii Engine Team 已提交
67
            1 * 1.1f, 127 * 1.1f};
68
    HostTensorGenerator<dtype::Float32, RandomDistribution::UNIFORM> range_gen{
M
Megvii Engine Team 已提交
69
            1 * 1.2f, 127 * 1.2f};
70 71 72 73 74 75 76
    std::shared_ptr<HostTensorND> host_x, host_w, host_b;
    std::shared_ptr<ComputingGraph> graph;
    SymbolVar x, y;
    SymbolVar quantized_x, quantized_y;

    SimpleQuantizedTensorRTNetwork();

M
Megvii Engine Team 已提交
77 78
    std::pair<nvinfer1::IBuilder*, INetworkDefinition*> create_trt_network(
            bool has_batch_dim);
79 80 81 82 83 84 85 86 87 88 89 90
};

struct ConcatConvTensorRTNetwork {
    HostTensorGenerator<> gen;
    std::shared_ptr<HostTensorND> host_x0, host_x1, host_x, host_w, host_b;
    std::shared_ptr<ComputingGraph> graph;
    SymbolVar x0, x1, y;

    HostTensorND host_z1;

    ConcatConvTensorRTNetwork();

M
Megvii Engine Team 已提交
91 92
    std::pair<nvinfer1::IBuilder*, INetworkDefinition*> create_trt_network(
            bool has_batch_dim);
93 94
};

95 96 97 98 99 100 101 102 103 104 105 106
struct ReshapeConcatTensorRTNetwork {
    HostTensorGenerator<> gen;
    std::shared_ptr<HostTensorND> host_x0, host_y0;
    std::shared_ptr<ComputingGraph> graph;
    SymbolVar x0, y0, z;

    ReshapeConcatTensorRTNetwork();

    std::pair<nvinfer1::IBuilder*, INetworkDefinition*> create_trt_network(
            bool has_batch_dim);
};

107 108 109 110 111 112 113
}  // namespace intl
}  // namespace opr
}  // namespace mgb

#endif  // MGB_ENABLE_TENSOR_RT

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}