// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <map>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>
#include "lite/core/mir/elimination/control_flow_op_unused_inputs_and_outputs_eliminate_pass.h"
#include "lite/core/mir/generate_program_pass.h"
#include "lite/core/mir/pass_manager.h"
#include "lite/core/mir/pass_utils.h"
#include "lite/core/mir/ssa_graph.h"
#include "lite/core/mir/static_kernel_pick_pass.h"
#include "lite/core/mir/type_target_cast_pass.h"
#include "lite/core/program.h"
#include "lite/core/types.h"
#include "lite/model_parser/model_parser.h"

namespace paddle {
namespace lite {

/*
 * lite::Optimizer optimizes a program. It utilizes the MIR passes to analyze
 * the program and exports an optimized program.
 */
// TODO(hong1986032) Support the following passes for the subblocks
const std::set<std::string> kSubblockUnsupportedPasses(
    {"memory_optimize_pass"});
class Optimizer {
 public:
  Optimizer() {}

  Optimizer(Program&& program, const std::vector<Place>& valid_places) {
    program_ = &program;
    valid_places_ = valid_places;
    CHECK(!valid_places.empty()) << "At least one valid_place should be set";

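    // Take the target, precision, and data layout into account when
    // scoring candidate kernels for each op.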
    core::KernelPickFactor factor;
    factor.ConsiderTarget();
    factor.ConsiderPrecision();
    factor.ConsiderDataLayout();

    Run(std::move(program), valid_places, factor, {});
  }

  void Run(Program&& program,
           const std::vector<Place>& valid_places,
           core::KernelPickFactor kernel_pick_factor,
           const std::vector<std::string>& passes = {}) {
    program_ = &program;
    valid_places_ = valid_places;
    CHECK(!valid_places.empty()) << "At least one valid_place should be set";
    CHECK(graphs_.empty()) << "Duplicate optimization found";

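    // Build one SSA graph for each block of the program; the passes below
    // are applied to every graph (except the subblock-unsupported ones).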
    auto block_size = program.block_size();
    for (size_t block_idx = 0; block_idx < block_size; ++block_idx) {
      std::unique_ptr<mir::SSAGraph> graph;
      graph.reset(new mir::SSAGraph);
      graph->Build(program, valid_places, block_idx);
      graph->SetValidPlaces(valid_places);
      graphs_.emplace_back(std::move(graph));
    }

    SpecifyKernelPickTactic(kernel_pick_factor);
    InitTargetTypeTransformPass();
    InitControlFlowOpUnusedInputsAndOutputsEliminatePass();

    if (passes.empty() || passes.size() == 1) {
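      // The default pipeline: fusion passes first, then subgraph
      // partitioning, kernel picking, variable place inference, the
      // target/precision/layout cast passes, runtime context assignment,
      // and finally memory optimization.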
      std::vector<std::string> passes_local{
          {"lite_quant_dequant_fuse_pass",         //
           "weight_quantization_preprocess_pass",  //
           "lite_conv_elementwise_fuse_pass",      // conv-elemwise-bn
           "lite_conv_bn_fuse_pass",               //
           "lite_conv_elementwise_fuse_pass",      // conv-bn-elemwise
           "lite_conv_conv_fuse_pass",             //
           // TODO(Superjomn) Refine the fusion related design to select fusion
           // kernels for devices automatically.
           "lite_conv_activation_fuse_pass",              //
           "lite_var_conv_2d_activation_fuse_pass",       //
           "lite_match_matrix_activation_fuse_pass",      //
           "lite_fc_fuse_pass",                           //
           "lite_shuffle_channel_fuse_pass",              //
           "lite_transpose_softmax_transpose_fuse_pass",  //
           "lite_interpolate_fuse_pass",                  //
           "identity_scale_eliminate_pass",               //
           "elementwise_mul_constant_eliminate_pass",     //
           "lite_sequence_pool_concat_fuse_pass",         //
           "lite_scale_activation_fuse_pass",             //
#if (defined LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) || (defined LITE_WITH_CUDA) || \
    (defined LITE_WITH_ARM)
           "lite_elementwise_activation_fuse_pass",  //
#endif
           "identity_dropout_eliminate_pass",
           "__xpu__resnet_fuse_pass",
           "__xpu__resnet_cbam_fuse_pass",
           "__xpu__mmdnn_fuse_pass",
           "__xpu__multi_encoder_fuse_pass",
           "__xpu__embedding_with_eltwise_add_fuse_pass",
           "__xpu__fc_fuse_pass",
           "quantized_op_attributes_inference_pass",  // Only for fully
                                                      // quantized model, infer
                                                      // the output scale and
                                                      // fix the attribute
                                                      // 'enable_int8' for all
                                                      // of the quantized ops.
           "npu_subgraph_pass",
           "huawei_ascend_npu_subgraph_pass",
           "xpu_subgraph_pass",
           "bm_subgraph_pass",
           "apu_subgraph_pass",
           "rknpu_subgraph_pass",
           "mlu_subgraph_pass",
           "control_flow_op_unused_inputs_and_outputs_eliminate_pass",
           "static_kernel_pick_pass",  // pick original kernel from graph

           "remove_tf_redundant_ops_pass",
           "variable_place_inference_pass",  // inference arg/var's
132 133

           "mlu_postprocess_pass",
134 135 136 137 138
           // info(target/precision/layout/device)
           // using kernel info
           "argument_type_display_pass",  // debug pass: show arg-type-node's
                                          // info
                                          // (target/precision/layout/device)

           "type_target_cast_pass",  // add io_copy/io_copy_once if meet
                                     // different targets when last and next
                                     // node
           "variable_place_inference_pass",  //
           "argument_type_display_pass",     //

           "io_copy_kernel_pick_pass",    //
           "argument_type_display_pass",  //

           "variable_place_inference_pass",  //
           "argument_type_display_pass",     //

           "type_precision_cast_pass",       //
           "variable_place_inference_pass",  //
           "argument_type_display_pass",     //

           "type_layout_cast_pass",  // add layout/layout_once op if meet
                                     // different layout when last and next node
           "argument_type_display_pass",  //

           "variable_place_inference_pass",  //
           "argument_type_display_pass",

           "runtime_context_assign_pass",
           "argument_type_display_pass",

           "memory_optimize_pass"}};

      if (passes.size() == 1) {
        // multi_stream_analysis_pass must be inserted before
        // runtime_context_assign_pass
        const std::string msa_pass{"multi_stream_analysis_pass"};
        const std::string depend_pass{"runtime_context_assign_pass"};
        if (passes[0] == msa_pass) {
          auto iter =
              std::find(passes_local.begin(), passes_local.end(), depend_pass);
          if (iter != passes_local.end()) {
            passes_local.insert(iter, msa_pass);
          } else {
            CHECK(false) << "Cannot find " << depend_pass;
          }
        } else {
          passes_local.push_back(passes[0]);
        }
      }
      RunPasses(passes_local);
    } else {
      RunPasses(passes);
    }
    exec_scope_ = program.exec_scope();
  }

  const Scope* exec_scope() const { return exec_scope_; }

  // Generate a new program based on the MIR graphs.
  std::unique_ptr<RuntimeProgram> GenRuntimeProgram() {
    auto pass = mir::PassManager::Global().LookUp<mir::GenerateProgramPass>(
        "generate_program_pass");
    for (auto& graph : graphs_) {
      pass->Apply(graph);
    }
    auto program = pass->GenProgram();
    CHECK(exec_scope_);
    program->set_exec_scope(exec_scope_);
    return program;
  }

  void InitTargetTypeTransformPass() {
    auto* pass =
        mir::PassManager::Global().LookUp<mir::TypeTargetTransformPass>(
            "type_target_cast_pass");
    CHECK(pass);
    CHECK(!valid_places_.empty());
    pass->SetValidPlaces(valid_places_);
  }

  void InitControlFlowOpUnusedInputsAndOutputsEliminatePass() {
    auto* pass =
        mir::PassManager::Global()
            .LookUp<mir::ControlFlowOpUnusedInputsAndOutputsEliminatePass>(
                "control_flow_op_unused_inputs_and_outputs_eliminate_pass");
    CHECK(pass);
    CHECK(!graphs_.empty());
    pass->SetAllGraphs(&graphs_);
  }

  // Generate C++ code which combines the inference program, model and weights.
  void GenCode(const std::string& code_dir);

  const mir::SSAGraph& ssa_graph(int block_idx = kRootBlockIdx) const {
    CHECK(!graphs_.empty());
    CHECK(graphs_[block_idx]);
    return *graphs_[block_idx];
  }

  mir::SSAGraph* mutable_ssa_graph(int block_idx = kRootBlockIdx) {
    CHECK(!graphs_.empty());
    CHECK(graphs_[block_idx]);
    return graphs_[block_idx].get();
  }

  Scope* exec_scope() { return exec_scope_; }

 protected:
  void SpecifyKernelPickTactic(core::KernelPickFactor factor);

  // Specify the passes and run them.
  void RunPasses(const std::vector<std::string>& passes) {
    for (auto& x : passes) {
      LOG(INFO) << "== Running pass: " << x;
      mir::Pass* pass = mir::PassManager::Global().LookUp(x);
      if (!pass) {
        LOG(INFO) << "   - Skip " << x << " because the pass isn't found.";
        continue;
      }
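      // A pass is applied only when it matches the targets derived from
      // valid_places_ and its bound kernels (see PassMatchesTarget and
      // PassMatchesKernels).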
      std::set<TargetType> targets;
      for (const auto& place : valid_places_) {
        targets.insert(place.target);
      }
      bool matched =
          PassMatchesTarget(*pass, targets) && PassMatchesKernels(*pass);
      if (!matched) {
        LOG(INFO) << "   - Skip " << x
                  << " because the target or kernel does not match.";
      } else {
        // Check whether the pass supports processing subblocks
        if (kSubblockUnsupportedPasses.count(x)) {
          pass->Apply(graphs_[kRootBlockIdx]);
        } else {
          for (auto& graph : graphs_) {
            pass->Apply(graph);
          }
        }
        LOG(INFO) << "== Finished running: " << x;
      }
    }
  }

 private:
  std::vector<std::unique_ptr<mir::SSAGraph>> graphs_;
  std::vector<Place> valid_places_;
  Scope* exec_scope_{};
  Program* program_{};
};

}  // namespace lite
}  // namespace paddle