timvx_executor.hpp 5.1 KB
Newer Older
B
BUG1989 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * License); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * AS IS BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*
 * Copyright (c) 2021, Open AI Lab
 * Author: lswang@openailab.com
 */

K
kalcohol 已提交
25
#pragma once
B
BUG1989 已提交
26

N
nihui 已提交
27
extern "C" {
K
kalcohol 已提交
28 29 30 31 32 33 34
#include "device/device.h"
#include "graph/tensor.h"
#include "graph/node.h"
#include "graph/graph.h"
#include "graph/subgraph.h"
#include "operator/op.h"
#include "utility/log.h"
35 36

#include "timvx_dump.h"
B
BUG1989 已提交
37 38 39 40 41 42 43 44
}

#include <algorithm>
#include <cstdint>
#include <iomanip>
#include <iostream>
#include <map>
#include <memory>
#include <tuple>
#include <vector>
B
BowShotDS 已提交
45 46
#include <cmath>

K
kalcohol 已提交
47
#include "convolution_param.h"
48
#include "deconv_param.h"
K
kalcohol 已提交
49 50

#include "tim/vx/tensor.h"
B
BUG1989 已提交
51 52 53 54 55 56

#include "tim/vx/context.h"
#include "tim/vx/graph.h"
#include "tim/vx/operation.h"

#include "tim/vx/ops/activations.h"
B
BowShotDS 已提交
57
#include "tim/vx/ops/batchnorm.h"
H
huoshuai-dot 已提交
58
#include "tim/vx/ops/clip.h"
B
BUG1989 已提交
59 60
#include "tim/vx/ops/concat.h"
#include "tim/vx/ops/conv2d.h"
61
#include "tim/vx/ops/deconv.h"
B
BowShotDS 已提交
62
#include "tim/vx/ops/depth2space.h"
B
BUG1989 已提交
63 64
#include "tim/vx/ops/elementwise.h"
#include "tim/vx/ops/fullyconnected.h"
B
BUG1989 已提交
65
#include "tim/vx/ops/gather.h"
wangq3802's avatar
wangq3802 已提交
66
#include "tim/vx/ops/groupedconv2d.h"
67
#include "tim/vx/ops/instancenormalization.h"
wangq3802's avatar
wangq3802 已提交
68
#include "tim/vx/ops/pad.h"
B
BUG1989 已提交
69
#include "tim/vx/ops/pool2d.h"
wangq3802's avatar
wangq3802 已提交
70
#include "tim/vx/ops/reduce.h"
B
BUG1989 已提交
71
#include "tim/vx/ops/reshape.h"
B
BUG1989 已提交
72
#include "tim/vx/ops/resize.h"
73
#include "tim/vx/ops/simple_operations.h"
B
BUG1989 已提交
74
#include "tim/vx/ops/slice.h"
B
BUG1989 已提交
75
#include "tim/vx/ops/softmax.h"
B
BUG1989 已提交
76
#include "tim/vx/ops/space2depth.h"
B
BowShotDS 已提交
77
#include "tim/vx/ops/split.h"
B
BowShotDS 已提交
78
#include "tim/vx/ops/stridedslice.h"
B
BowShotDS 已提交
79
#include "tim/vx/ops/transpose.h"
80 81
#include "tim/vx/ops/spatial_transformer.h"
#include "tim/vx/ops/l2normalization.h"
82
#include "tim/vx/ops/layernormalization.h"
B
BUG1989 已提交
83

N
nihui 已提交
84 85 86 87 88 89 90 91 92
/*
 * Special-case tags used when mapping IR tensors to TIM-VX tensors.
 *
 * NOTE(review): these integer tags appear to be the values passed as the
 * `spec_type` argument of VXEngine::VXTensorMap(); the exact treatment of
 * each tag lives in the executor implementation, not in this header —
 * confirm there before relying on a specific meaning.
 *
 * Kept as preprocessor macros (not constexpr) so any existing `#if`/`#ifdef`
 * users of these names keep working.
 */
#define SPEC_TYPE_CONV      1
#define SPEC_TYPE_CONV_BIAS 2
#define SPEC_TYPE_DWCONV    3
#define SPEC_TYPE_INTERP    4
#define SPEC_TYPE_OUTPUT    5
#define SPEC_TYPE_PRELU     6
#define SPEC_TYPE_SLICE     7
#define SPEC_TYPE_RESHAPE   8
#define SPEC_TYPE_INPUT     9
#define SPEC_TYPE_DW_DECONV 10
B
BUG1989 已提交
94

N
nihui 已提交
95 96
/*
 * Lookup tables from a uint32_t key to shared TIM-VX objects.
 *
 * NOTE(review): the "irt2vx" naming suggests the key is an IR tensor index
 * (for Tensors) / IR node index (for Operations) — TODO confirm against the
 * code that fills VXEngine::vx_tensor_map and VXEngine::vx_node_map.
 *
 * Requires <map>, <memory> and <cstdint> to be included directly rather than
 * relying on transitive includes from the TIM-VX headers.
 */
using dict_irt2vxt = std::map<uint32_t, std::shared_ptr<tim::vx::Tensor>>;
using dict_irt2vxo = std::map<uint32_t, std::shared_ptr<tim::vx::Operation>>;
B
BUG1989 已提交
97 98 99 100 101 102 103 104 105 106 107 108 109

class VXEngine
{
public:
    VXEngine();
    ~VXEngine() = default;

    int VXEnginePreRun(struct subgraph* subgraph);
    int VXEngineRun(struct subgraph* subgraph);
    void VXEnginePostRun();

private:
    int Build(struct subgraph* subgraph);
110
    int VXTensorMap(struct graph* ir_graph, int ir_tensor_idx, int spec_type);
K
kalcohol 已提交
111

B
BowShotDS 已提交
112
    bool AddBatchNormNode(struct node* ir_node);
K
kalcohol 已提交
113 114 115
    bool AddClipNode(struct node* ir_node);
    bool AddConcatNode(struct node* ir_node);
    bool AddConvolutionNode(struct node* ir_node);
B
BowShotDS 已提交
116
    bool AddCropNode(struct node* ir_node);
117
    bool AddDeconvNode(struct node* ir_node);
K
kalcohol 已提交
118 119 120 121 122 123 124 125
    bool AddDepthToSpaceNode(struct node* ir_node);
    bool AddDropoutNode(struct node* ir_node);
    bool AddEltwiseNode(struct node* ir_node);
    bool AddEluNode(struct node* ir_node);
    bool AddFlattenNode(struct node* ir_node);
    bool AddFullyConnectionNode(struct node* node);
    bool AddGatherNode(struct node* node);
    bool AddHardSwishNode(struct node* node);
126
    bool AddInstanceNormNode(struct node* node);
K
kalcohol 已提交
127
    bool AddInterpNode(struct node* ir_node);
B
BUG1989 已提交
128
    bool AddMishNode(struct node* ir_node);
wangq3802's avatar
wangq3802 已提交
129
    bool AddPadNode(struct node* ir_node);
K
kalcohol 已提交
130 131 132
    bool AddPermuteNode(struct node* ir_node);
    bool AddPoolingNode(struct node* ir_node);
    bool AddPReluNode(struct node* ir_node);
wangq3802's avatar
wangq3802 已提交
133
    bool AddReduceNode(struct node* ir_node);
K
kalcohol 已提交
134 135 136
    bool AddReluNode(struct node* ir_node);
    bool AddRelu1Node(struct node* ir_node);
    bool AddReshapeNode(struct node* ir_node);
137 138
    bool AddResizeNode(struct node* ir_node);
    bool AddScaleNode(struct node* ir_node);
K
kalcohol 已提交
139 140 141 142
    bool AddSigmoidNode(struct node* ir_node);
    bool AddSliceNode(struct node* ir_node);
    bool AddSoftmaxNode(struct node* ir_node);
    bool AddSpaceToDepthNode(struct node* ir_node);
B
BowShotDS 已提交
143
    bool AddSplitNode(struct node* ir_node);
K
kalcohol 已提交
144 145 146
    bool AddTanhNode(struct node* ir_node);
    bool AddTransposeNode(struct node* ir_node);
    bool AddUpsampleNode(struct node* ir_node);
147 148
    bool AddSpatialtransformerNode(struct node* ir_node);
    bool AddL2normalizationNode(struct node* ir_node);
149 150
    bool AddGeluNode(struct node* ir_node);
    bool AddLayerNormNode(struct node* ir_node);
B
BUG1989 已提交
151

B
BUG1989 已提交
152 153 154 155
public:
    std::shared_ptr<tim::vx::Context> context;
    std::shared_ptr<tim::vx::Graph> graph;
    std::shared_ptr<tim::vx::Operation> ops;
156
    std::vector<char> nbg_buffer;
B
BUG1989 已提交
157 158

private:
N
nihui 已提交
159 160
    dict_irt2vxt vx_tensor_map;
    dict_irt2vxo vx_node_map;
B
BUG1989 已提交
161
};