pass.cc 7.2 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/ir/pass/pass.h"

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "paddle/ir/core/ir_context.h"
#include "paddle/ir/core/operation.h"
#include "paddle/ir/core/program.h"
#include "paddle/ir/core/region.h"
#include "paddle/ir/pass/pass_adaptor.h"
#include "paddle/ir/pass/pass_instrumentation.h"
#include "paddle/ir/pass/pass_manager.h"

namespace ir {

26 27 28 29 30 31 32
//===----------------------------------------------------------------------===//
// Pass
//===----------------------------------------------------------------------===//
Pass::~Pass() = default;

// Default applicability rule: a pass is only run on operations that own at
// least one region (i.e. operations that can contain nested IR).
bool Pass::CanApplyOn(Operation* op) const {
  const bool has_regions = op->num_regions() > 0;
  return has_regions;
}

33 34 35
//----------------------------------------------------------------------------------------------//
// PassAdaptor
//----------------------------------------------------------------------------------------------//
36
// Entry point invoked by RunPass when the pass being executed is an adaptor;
// all of the actual work is delegated to RunImpl.
void detail::PassAdaptor::Run(Operation* op, uint8_t opt_level, bool verify) {
  return RunImpl(op, opt_level, verify);
}

40
void detail::PassAdaptor::RunImpl(Operation* op,
41 42
                                  uint8_t opt_level,
                                  bool verify) {
43 44 45 46 47 48 49 50 51 52 53 54 55 56
  auto last_am = analysis_manager();

  for (size_t i = 0; i < op->num_regions(); ++i) {
    auto& region = op->GetRegion(i);
    for (auto it = region.begin(); it != region.end(); ++it) {
      auto* block = *it;
      for (auto it = block->begin(); it != block->end(); ++it) {
        auto* op = *it;
        AnalysisManagerHolder am(op, last_am.GetPassInstrumentor());
        if (!RunPipeline(*pm_, op, am, opt_level, verify))
          return SignalPassFailure();
      }
    }
  }
57 58 59 60
  return;
}

bool detail::PassAdaptor::RunPipeline(const PassManager& pm,
61
                                      Operation* op,
62 63 64 65 66 67 68 69 70
                                      AnalysisManager am,
                                      uint8_t opt_level,
                                      bool verify) {
  auto* instrumentor = am.GetPassInstrumentor();
  if (instrumentor) {
    instrumentor->RunBeforePipeline(op);
  }

  for (auto& pass : pm.passes()) {
71
    if (pass->CanApplyOn(op)) {
72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90
      if (!RunPass(pass.get(), op, am, opt_level, verify)) {
        return false;
      }
    }
  }

  if (instrumentor) {
    instrumentor->RunAfterPipeline(op);
  }

  // Apply pass manager on all nested ir.
  if (!RunPass(pm.pass_adaptor_.get(), op, am, opt_level, verify)) {
    return false;
  }

  return true;
}

bool detail::PassAdaptor::RunPass(Pass* pass,
91
                                  Operation* op,
92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113
                                  AnalysisManager am,
                                  uint8_t opt_level,
                                  bool verify) {
  if (opt_level < pass->pass_info().opt_level) return true;

  pass->pass_state_ = PassExecutionState(op, am);

  PassInstrumentor* instrumentor = am.GetPassInstrumentor();

  if (auto* adaptor = dynamic_cast<PassAdaptor*>(pass)) {
    adaptor->Run(op, opt_level, verify);
  } else {
    if (instrumentor) instrumentor->RunBeforePass(pass, op);
    pass->Run(op);
    if (instrumentor) instrumentor->RunAfterPass(pass, op);
  }

  bool pass_failed = pass->pass_state().pass_failed;

  // TODO(liuyuanle): Support verification of operation
  if (!pass_failed && verify) {
    // bool verify_recursively = !dynamic_cast<PassAdaptor*>(pass);
114
    // pass_failed = ir::Verify(op, verify_recursively);
115 116 117 118 119 120 121 122
  }

  return !pass_failed;
}

//----------------------------------------------------------------------------------------------//
// PassManager
//----------------------------------------------------------------------------------------------//
123
// Creates a pass manager bound to `context` that runs passes whose required
// optimization level does not exceed `opt_level`.
PassManager::PassManager(IrContext* context, uint8_t opt_level)
    : context_(context), opt_level_(opt_level) {
  // The adaptor receives a pointer to this manager; PassAdaptor::RunImpl uses
  // it (`*pm_`) to re-run this manager's pipeline on nested operations.
  auto adaptor = std::make_unique<detail::PassAdaptor>(this);
  pass_adaptor_ = std::move(adaptor);
}

128
// Initializes every registered pass against this manager's context and, only
// if all of them succeed, runs the pipeline on the program's module op.
// Returns false if initialization or any pass fails.
bool PassManager::Run(Program* program) {
  return Initialize(context_) && Run(program->module_op());
}
134

135
// Runs the pipeline rooted at `op`. Unlike Run(Program*), this overload does
// not call Initialize() on the registered passes.
bool PassManager::Run(Operation* op) {
  // Root analysis manager for this run. The instrumentor may be null when no
  // instrumentation was ever added.
  AnalysisManagerHolder root_analysis(op, instrumentor_.get());
  return detail::PassAdaptor::RunPipeline(
      *this, op, root_analysis, opt_level_, verify_);
}

142
// Calls Pass::Initialize(context) on each registered pass in order, stopping
// at and reporting the first failure.
bool PassManager::Initialize(IrContext* context) {
  for (const auto& registered : passes()) {
    const bool ok = registered->Initialize(context);
    if (!ok) {
      return false;
    }
  }
  return true;
}

// Registers an instrumentation, lazily creating the instrumentor the first
// time one is added. Ownership of `pi` is transferred to the instrumentor.
void PassManager::AddInstrumentation(std::unique_ptr<PassInstrumentation> pi) {
  if (instrumentor_ == nullptr) {
    instrumentor_ = std::make_unique<PassInstrumentor>();
  }
  instrumentor_->AddInstrumentation(std::move(pi));
}

//----------------------------------------------------------------------------------------------//
// PassInstrumentor
//----------------------------------------------------------------------------------------------//
namespace detail {
// Private state behind PassInstrumentor::impl_, defined only in this file.
struct PassInstrumentorImpl {
  // TODO(wilber): Support multi-thread.
  // Owned instrumentations in registration order (AddInstrumentation appends).
  // The Run*Before* hooks walk this list front-to-back and the Run*After*
  // hooks back-to-front. Access is not synchronized.
  std::vector<std::unique_ptr<PassInstrumentation>> instrumentations;
};
}  // namespace detail

// Creates an instrumentor with no registered instrumentations. Uses
// make_unique instead of a naked `new` so ownership is established in one
// exception-safe step.
PassInstrumentor::PassInstrumentor()
    : impl_(std::make_unique<detail::PassInstrumentorImpl>()) {}

// Defaulted here, after detail::PassInstrumentorImpl is fully defined, so that
// destroying impl_ sees the complete type (presumably impl_ is a smart pointer
// to a type only forward-declared in the header -- confirm in pass.h).
PassInstrumentor::~PassInstrumentor() = default;

171 172
// Notifies instrumentations, in registration order, that a pipeline is about
// to run on `op`. Operations without regions are not reported.
void PassInstrumentor::RunBeforePipeline(Operation* op) {
  if (op->num_regions() != 0) {
    for (const auto& instrumentation : impl_->instrumentations) {
      instrumentation->RunBeforePipeline(op);
    }
  }
}

178 179
// Notifies instrumentations, in reverse registration order (mirroring
// RunBeforePipeline), that a pipeline finished on `op`. Operations without
// regions are not reported.
void PassInstrumentor::RunAfterPipeline(Operation* op) {
  if (op->num_regions() == 0) {
    return;
  }
  auto& instrs = impl_->instrumentations;
  for (size_t idx = instrs.size(); idx > 0; --idx) {
    instrs[idx - 1]->RunAfterPipeline(op);
  }
}

187 188
// Notifies instrumentations, in registration order, that `pass` is about to
// run on `op`. Operations without regions are not reported.
void PassInstrumentor::RunBeforePass(Pass* pass, Operation* op) {
  if (op->num_regions() != 0) {
    for (const auto& instrumentation : impl_->instrumentations) {
      instrumentation->RunBeforePass(pass, op);
    }
  }
}

194 195
// Notifies instrumentations, in reverse registration order (mirroring
// RunBeforePass), that `pass` finished on `op`. Operations without regions are
// not reported.
void PassInstrumentor::RunAfterPass(Pass* pass, Operation* op) {
  if (op->num_regions() == 0) {
    return;
  }
  auto& instrs = impl_->instrumentations;
  for (size_t idx = instrs.size(); idx > 0; --idx) {
    instrs[idx - 1]->RunAfterPass(pass, op);
  }
}

void PassInstrumentor::RunBeforeAnalysis(const std::string& name,
204 205 206
                                         TypeId id,
                                         Operation* op) {
  if (op->num_regions() == 0) return;
207 208 209 210 211 212
  for (auto& instr : impl_->instrumentations) {
    instr->RunBeforeAnalysis(name, id, op);
  }
}

void PassInstrumentor::RunAfterAnalysis(const std::string& name,
213 214 215
                                        TypeId id,
                                        Operation* op) {
  if (op->num_regions() == 0) return;
216 217 218 219 220 221 222 223 224 225 226 227 228
  for (auto it = impl_->instrumentations.rbegin();
       it != impl_->instrumentations.rend();
       ++it) {
    (*it)->RunBeforeAnalysis(name, id, op);
  }
}

// Appends `pi` to the registered instrumentations, taking ownership of it.
// Registration order determines hook order (before-hooks forward, after-hooks
// in reverse).
void PassInstrumentor::AddInstrumentation(
    std::unique_ptr<PassInstrumentation> pi) {
  impl_->instrumentations.push_back(std::move(pi));
}

}  // namespace ir