make_trt_net.h 2.3 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88
/**
 * \file src/tensorrt/test/make_trt_net.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "megbrain/opr/blas.h"
#include "megbrain/opr/dnn/convolution.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/plugin/profiler.h"
#include "megbrain/test/helper.h"
#include "megbrain/utils/debug.h"

#if MGB_ENABLE_TENSOR_RT

#include "megbrain/tensorrt/tensorrt_opr.h"

#include <random>

using namespace mgb;
using namespace opr;
using namespace nvinfer1;

//! Shorthand alias so code including this header can write
//! TensorRTUniquePtr<T> instead of spelling out intl::TensorRTUniquePtr<T>.
template <class T>
using TensorRTUniquePtr =
        intl::TensorRTUniquePtr<T>;

namespace mgb{
namespace opr{
namespace intl{

/*!
 * \brief test fixture that holds a MegBrain ComputingGraph and can also
 *      build a TensorRT network through create_trt_network()
 *
 * NOTE(review): the constructor and create_trt_network() are defined outside
 * this header; descriptions below that go beyond the declarations are
 * inferred from member names and should be confirmed against the .cpp file.
 */
struct SimpleTensorRTNetwork {
    //! generator used to produce host tensors (default template arguments)
    HostTensorGenerator<> gen;
    //! host-side tensors; presumably input / weight / bias -- TODO confirm
    std::shared_ptr<HostTensorND> host_x, host_w, host_b;
    //! MegBrain computing graph owning the symbolic vars below
    std::shared_ptr<ComputingGraph> graph;
    //! symbolic vars in `graph`; presumably network input and output
    SymbolVar x, y;

    //! NOTE(review): presumably a host copy of a reference result -- confirm
    HostTensorND host_z1;

    SimpleTensorRTNetwork();

    /*!
     * \brief build a TensorRT network
     * \param has_batch_dim whether the network is created with a batch
     *      dimension (exact meaning defined by the out-of-line definition)
     * \return TensorRT builder and network definition as raw pointers.
     *      NOTE(review): ownership/release responsibility is not visible
     *      here; the caller presumably must release them -- confirm.
     *
     * INetworkDefinition is qualified explicitly (matching IBuilder) so the
     * declaration does not depend on a header-scope `using namespace`.
     */
    std::pair<nvinfer1::IBuilder*, nvinfer1::INetworkDefinition*>
    create_trt_network(bool has_batch_dim);
};

/*!
 * \brief test fixture that holds a MegBrain ComputingGraph (with additional
 *      quantized vars) and can also build a TensorRT network through
 *      create_trt_network()
 *
 * NOTE(review): the constructor and create_trt_network() are defined outside
 * this header; descriptions below that go beyond the declarations are
 * inferred from member names and should be confirmed against the .cpp file.
 */
struct SimpleQuantizedTensorRTNetwork {
    //! uniform Float32 generator for weights; the brace arguments are
    //! presumably the (low, high) bounds -- TODO confirm.
    //! NOTE(review): 127 looks like the int8 maximum; confirm the intent of
    //! the 1.1 scale factor.
    HostTensorGenerator<dtype::Float32, RandomDistribution::UNIFORM> weight_gen{
            1*1.1f, 127*1.1f};
    //! uniform Float32 generator with the same base bounds scaled by 1.2
    //! instead of 1.1; presumably (low, high) -- TODO confirm
    HostTensorGenerator<dtype::Float32, RandomDistribution::UNIFORM> range_gen{
            1*1.2f, 127*1.2f};
    //! host-side tensors; presumably input / weight / bias -- TODO confirm
    std::shared_ptr<HostTensorND> host_x, host_w, host_b;
    //! MegBrain computing graph owning the symbolic vars below
    std::shared_ptr<ComputingGraph> graph;
    //! symbolic vars in `graph`; presumably network input and output
    SymbolVar x, y;
    //! presumably the quantized counterparts of x and y -- TODO confirm
    SymbolVar quantized_x, quantized_y;

    SimpleQuantizedTensorRTNetwork();

    /*!
     * \brief build a TensorRT network
     * \param has_batch_dim whether the network is created with a batch
     *      dimension (exact meaning defined by the out-of-line definition)
     * \return TensorRT builder and network definition as raw pointers.
     *      NOTE(review): ownership/release responsibility is not visible
     *      here; the caller presumably must release them -- confirm.
     *
     * INetworkDefinition is qualified explicitly (matching IBuilder) so the
     * declaration does not depend on a header-scope `using namespace`.
     */
    std::pair<nvinfer1::IBuilder*, nvinfer1::INetworkDefinition*>
    create_trt_network(bool has_batch_dim);
};

/*!
 * \brief test fixture that holds a MegBrain ComputingGraph with two input
 *      vars and can also build a TensorRT network through
 *      create_trt_network()
 *
 * NOTE(review): the constructor and create_trt_network() are defined outside
 * this header; the "concat + conv" structure suggested by the name is not
 * visible here -- confirm against the .cpp file.
 */
struct ConcatConvTensorRTNetwork {
    //! generator used to produce host tensors (default template arguments)
    HostTensorGenerator<> gen;
    //! host-side tensors; presumably two inputs, a combined input, weight
    //! and bias -- TODO confirm
    std::shared_ptr<HostTensorND> host_x0, host_x1, host_x, host_w, host_b;
    //! MegBrain computing graph owning the symbolic vars below
    std::shared_ptr<ComputingGraph> graph;
    //! symbolic vars in `graph`; presumably two inputs and the output
    SymbolVar x0, x1, y;

    //! NOTE(review): presumably a host copy of a reference result -- confirm
    HostTensorND host_z1;

    ConcatConvTensorRTNetwork();

    /*!
     * \brief build a TensorRT network
     * \param has_batch_dim whether the network is created with a batch
     *      dimension (exact meaning defined by the out-of-line definition)
     * \return TensorRT builder and network definition as raw pointers.
     *      NOTE(review): ownership/release responsibility is not visible
     *      here; the caller presumably must release them -- confirm.
     *
     * INetworkDefinition is qualified explicitly (matching IBuilder) so the
     * declaration does not depend on a header-scope `using namespace`.
     */
    std::pair<nvinfer1::IBuilder*, nvinfer1::INetworkDefinition*>
    create_trt_network(bool has_batch_dim);
};

}  // namespace intl
}  // namespace opr
}  // namespace mgb


#endif  // MGB_ENABLE_TENSOR_RT

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}