#pragma once

#include "megbrain/test/helper.h"

#include "megbrain/gopt/framework.h"
#include "megbrain/opr/basic_arith_wrapper.h"
#include "megbrain/opr/blas.h"
#include "megbrain/opr/dnn/convolution.h"
#include "megbrain/opr/dnn/pooling.h"
#include "megbrain/opr/imgproc.h"
#include "megbrain/opr/nn_int.h"
#include "megbrain/opr/tensor_gen.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/opr/utility.h"

namespace mgb {
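// Small helper for building test computing graphs; inputs are filled with
// uniform random values in [-0.01, 0.01] (see `gen` below).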
class Network {
private:
    HostTensorGenerator<dtype::Float32, RandomDistribution::UNIFORM> gen{-0.01, 0.01};
    CompNode cn;

public:
    std::shared_ptr<ComputingGraph> graph = ComputingGraph::make();
    Network(CompNode cn_) : cn{cn_} {}
    ~Network() noexcept = default;
    using KernSize = SmallVector<size_t, 2>;
    using Stride = SmallVector<size_t, 2>;
    using Padding = SmallVector<size_t, 2>;
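    // add_var: graph input fed from the host via Host2DeviceCopy;
    // add_cvar: constant input held on the device as a SharedDeviceTensor.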
    SymbolVar add_var(const char* name, const TensorShape& shp = {1}) {
        return opr::Host2DeviceCopy::make(*graph, gen(shp), cn).rename(name);
    }
    SymbolVar add_cvar(const char* name, const TensorShape& shp = {1}) {
        return opr::SharedDeviceTensor::make(*graph, *gen(shp), cn).rename(name);
    }

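    // 2D convolution producing `out_dtype`, optionally followed by a ReLU.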
    SymbolVar add_conv(
            SymbolVar f, size_t output_channels, KernSize kern_size,
            DType out_dtype = dtype::Float32(), bool has_relu = true,
            Stride stride = {1, 1}, Padding padding = {0, 0});
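    // Grouped 2D convolution with `groups` groups.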
    SymbolVar add_group_conv(
            SymbolVar f, size_t output_channels, size_t groups, KernSize kern_size,
            DType out_dtype = dtype::Float32(), bool has_relu = true,
            Stride stride = {1, 1}, Padding padding = {0, 0});
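    // Transposed convolution; `ratio` is presumably the upsampling factor.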
    SymbolVar add_deconv(
            SymbolVar f, size_t ratio, size_t output_channels, DType out_dtype);
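    // Element-wise op over `inps` (mode defaults to ADD) producing `out_dtype`.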
    SymbolVar add_elemwise(
            const SymbolVarArray inps, DType out_dtype = dtype::Float32(),
            opr::Elemwise::Param::Mode mode = opr::Elemwise::Param::Mode::ADD);
    using Window = SmallVector<size_t, 2>;
    SymbolVar add_pooling(
            SymbolVar f, Window window, Stride stride = {1, 1},
            Padding padding = {0, 0},
            opr::Pooling::Param::Mode mode = opr::Pooling::Param::Mode::MAX);
    SymbolVar add_type_cvt(SymbolVar f, DType out_dtype = dtype::Float32());
    SymbolVar add_concat(SymbolVar f, SymbolVar g, int axis = 0);
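    // Tensor-manipulation helpers; parameters other than the input var are
    // fixed inside the implementation.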
    SymbolVar add_dimshuffle(SymbolVar f, std::vector<int> pattern);
    SymbolVar add_axisaddremove(SymbolVar f);
    SymbolVar add_subtensor(SymbolVar f);
    SymbolVar add_reshape(SymbolVar f);
    SymbolVar add_broadcast(SymbolVar f);
    SymbolVar add_copy(SymbolVar f);
};
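
// Usage sketch (hypothetical shapes; assumes an "xpu0" comp node is available):
//
//     Network network{CompNode::load("xpu0")};
//     auto data = network.add_var("data", {1, 3, 32, 32});
//     auto conv = network.add_conv(data, 8, {3, 3}, dtype::Float32(), true, {1, 1}, {1, 1});
//     auto pool = network.add_pooling(conv, {2, 2}, {2, 2});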

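// Residual-style block used by make_resnet18; `has_proj` adds a projection
// on the shortcut path.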
SymbolVar create_block(
        Network& network, SymbolVar f, size_t stride, size_t num_outputs1,
        bool has_proj = false, DType out_dtype = dtype::Float32());

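// Builds a ResNet-18-style test graph with the given batch size and output dtype.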
SymbolVar make_resnet18(
        Network& network, size_t batch = 16, DType out_dtype = dtype::Float32());

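// Builds a multi-output (detection-style) test graph and returns its output vars.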
SymbolVarArray make_det(
        Network& network, size_t batch = 16, DType out_dtype = dtype::Float32());

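// MobileNetV2-style inverted-residual bottleneck; `t` is the expansion factor.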
SymbolVar bottleneck(
        Network& network, SymbolVar f, size_t input_channels, size_t channels, size_t t,
        size_t stride, DType out_dtype = dtype::Float32());

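// Stack of `stages` bottleneck blocks; `s` and `t` follow the MobileNetV2
// stride/expansion notation.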
SymbolVar bottleneck_group(
        Network& network, SymbolVar f, size_t input_channels, size_t channels,
        size_t stages, size_t s, size_t t, DType out_dtype = dtype::Float32());

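// Builds a MobileNetV2-style test graph with the given batch size and output dtype.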
SymbolVar make_mobilenet_v2(
        Network& network, size_t batch = 1, DType out_dtype = dtype::Float32());

}  // namespace mgb

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}