pass.cc 7.2 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/ir/pass/pass.h"
16

17 18 19
#include "paddle/ir/core/ir_context.h"
#include "paddle/ir/core/operation.h"
#include "paddle/ir/core/program.h"
20
#include "paddle/ir/core/region.h"
21 22 23 24 25 26
#include "paddle/ir/pass/pass_adaptor.h"
#include "paddle/ir/pass/pass_instrumentation.h"
#include "paddle/ir/pass/pass_manager.h"

namespace ir {

27 28 29 30 31 32 33
//===----------------------------------------------------------------------===//
// Pass
//===----------------------------------------------------------------------===//
Pass::~Pass() = default;

/// Default applicability check: a pass targets only operations that own at
/// least one region, i.e. operations that can contain nested IR.
bool Pass::CanApplyOn(Operation* op) const { return 0 < op->num_regions(); }

//----------------------------------------------------------------------------------------------//
// PassAdaptor
//----------------------------------------------------------------------------------------------//
37
void detail::PassAdaptor::Run(Operation* op, uint8_t opt_level, bool verify) {
38 39 40
  RunImpl(op, opt_level, verify);
}

41
void detail::PassAdaptor::RunImpl(Operation* op,
42 43
                                  uint8_t opt_level,
                                  bool verify) {
44 45 46 47 48 49 50 51 52 53 54 55 56 57
  auto last_am = analysis_manager();

  for (size_t i = 0; i < op->num_regions(); ++i) {
    auto& region = op->GetRegion(i);
    for (auto it = region.begin(); it != region.end(); ++it) {
      auto* block = *it;
      for (auto it = block->begin(); it != block->end(); ++it) {
        auto* op = *it;
        AnalysisManagerHolder am(op, last_am.GetPassInstrumentor());
        if (!RunPipeline(*pm_, op, am, opt_level, verify))
          return SignalPassFailure();
      }
    }
  }
58 59 60 61
  return;
}

bool detail::PassAdaptor::RunPipeline(const PassManager& pm,
62
                                      Operation* op,
63 64 65 66 67 68 69 70 71
                                      AnalysisManager am,
                                      uint8_t opt_level,
                                      bool verify) {
  auto* instrumentor = am.GetPassInstrumentor();
  if (instrumentor) {
    instrumentor->RunBeforePipeline(op);
  }

  for (auto& pass : pm.passes()) {
72
    if (pass->CanApplyOn(op)) {
73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91
      if (!RunPass(pass.get(), op, am, opt_level, verify)) {
        return false;
      }
    }
  }

  if (instrumentor) {
    instrumentor->RunAfterPipeline(op);
  }

  // Apply pass manager on all nested ir.
  if (!RunPass(pm.pass_adaptor_.get(), op, am, opt_level, verify)) {
    return false;
  }

  return true;
}

bool detail::PassAdaptor::RunPass(Pass* pass,
92
                                  Operation* op,
93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114
                                  AnalysisManager am,
                                  uint8_t opt_level,
                                  bool verify) {
  if (opt_level < pass->pass_info().opt_level) return true;

  pass->pass_state_ = PassExecutionState(op, am);

  PassInstrumentor* instrumentor = am.GetPassInstrumentor();

  if (auto* adaptor = dynamic_cast<PassAdaptor*>(pass)) {
    adaptor->Run(op, opt_level, verify);
  } else {
    if (instrumentor) instrumentor->RunBeforePass(pass, op);
    pass->Run(op);
    if (instrumentor) instrumentor->RunAfterPass(pass, op);
  }

  bool pass_failed = pass->pass_state().pass_failed;

  // TODO(liuyuanle): Support verification of operation
  if (!pass_failed && verify) {
    // bool verify_recursively = !dynamic_cast<PassAdaptor*>(pass);
115
    // pass_failed = ir::Verify(op, verify_recursively);
116 117 118 119 120 121 122 123
  }

  return !pass_failed;
}

//----------------------------------------------------------------------------------------------//
// PassManager
//----------------------------------------------------------------------------------------------//
124
PassManager::PassManager(IrContext* context, uint8_t opt_level)
125 126 127 128
    : context_(context), opt_level_(opt_level) {
  pass_adaptor_ = std::make_unique<detail::PassAdaptor>(this);
}

129
/// Initializes every registered pass against this manager's IrContext. If
/// that succeeds, runs the pipeline on the program's module operation.
/// Returns false if initialization or the pipeline fails.
bool PassManager::Run(Program* program) {
  const bool initialized = Initialize(context_);
  return initialized && Run(program->module_op());
}

/// Runs this manager's pipeline on `op`, using this manager's opt_level and
/// verify settings. A fresh analysis manager is created for `op`; it reports
/// to this manager's instrumentor, which may be null if no instrumentation
/// was added.
bool PassManager::Run(Operation* op) {
  AnalysisManagerHolder analysis_holder(op, instrumentor_.get());
  return detail::PassAdaptor::RunPipeline(
      *this, op, analysis_holder, opt_level_, verify_);
}

/// Calls Initialize(context) on each registered pass, in the order returned
/// by passes(), and stops at the first pass that reports failure.
/// Returns true only if every pass initialized successfully.
bool PassManager::Initialize(IrContext* context) {
  for (const auto& registered : passes()) {
    const bool ok = registered->Initialize(context);
    if (!ok) {
      return false;
    }
  }
  return true;
}

/// Registers `pi` with this manager's instrumentor. The instrumentor is
/// created lazily on the first call.
void PassManager::AddInstrumentation(std::unique_ptr<PassInstrumentation> pi) {
  if (instrumentor_ == nullptr) {
    instrumentor_ = std::make_unique<PassInstrumentor>();
  }
  instrumentor_->AddInstrumentation(std::move(pi));
}

//----------------------------------------------------------------------------------------------//
// PassInstrumentor
//----------------------------------------------------------------------------------------------//
namespace detail {
struct PassInstrumentorImpl {
  // TODO(wilber): Support multi-thread.
  std::vector<std::unique_ptr<PassInstrumentation>> instrumentations;
};
}  // namespace detail

/// Creates an instrumentor with no registered instrumentations.
PassInstrumentor::PassInstrumentor()
    : impl_(new detail::PassInstrumentorImpl()) {}

// Defined out of line, here, where detail::PassInstrumentorImpl is a complete
// type (presumably required by the type of impl_ — confirm against header).
PassInstrumentor::~PassInstrumentor() = default;

/// Notifies every instrumentation that a pipeline is about to run on `op`.
/// Hooks fire in registration order. Operations without regions are not
/// instrumented.
void PassInstrumentor::RunBeforePipeline(Operation* op) {
  if (op->num_regions() != 0) {
    for (const auto& instrumentation : impl_->instrumentations) {
      instrumentation->RunBeforePipeline(op);
    }
  }
}

/// Notifies every instrumentation that a pipeline has finished on `op`.
/// Hooks fire in reverse registration order, mirroring RunBeforePipeline.
/// Operations without regions are not instrumented.
void PassInstrumentor::RunAfterPipeline(Operation* op) {
  if (op->num_regions() != 0) {
    auto& instrumentations = impl_->instrumentations;
    for (auto idx = instrumentations.size(); idx > 0; --idx) {
      instrumentations[idx - 1]->RunAfterPipeline(op);
    }
  }
}

/// Notifies every instrumentation that `pass` is about to run on `op`.
/// Hooks fire in registration order. Operations without regions are not
/// instrumented.
void PassInstrumentor::RunBeforePass(Pass* pass, Operation* op) {
  if (op->num_regions() != 0) {
    for (const auto& instrumentation : impl_->instrumentations) {
      instrumentation->RunBeforePass(pass, op);
    }
  }
}

/// Notifies every instrumentation that `pass` has finished running on `op`.
/// Hooks fire in reverse registration order, mirroring RunBeforePass.
/// Operations without regions are not instrumented.
void PassInstrumentor::RunAfterPass(Pass* pass, Operation* op) {
  if (op->num_regions() != 0) {
    auto& instrumentations = impl_->instrumentations;
    for (auto idx = instrumentations.size(); idx > 0; --idx) {
      instrumentations[idx - 1]->RunAfterPass(pass, op);
    }
  }
}

void PassInstrumentor::RunBeforeAnalysis(const std::string& name,
205 206 207
                                         TypeId id,
                                         Operation* op) {
  if (op->num_regions() == 0) return;
208 209 210 211 212 213
  for (auto& instr : impl_->instrumentations) {
    instr->RunBeforeAnalysis(name, id, op);
  }
}

void PassInstrumentor::RunAfterAnalysis(const std::string& name,
214 215 216
                                        TypeId id,
                                        Operation* op) {
  if (op->num_regions() == 0) return;
217 218 219 220 221 222 223 224 225 226 227 228 229
  for (auto it = impl_->instrumentations.rbegin();
       it != impl_->instrumentations.rend();
       ++it) {
    (*it)->RunBeforeAnalysis(name, id, op);
  }
}

/// Takes ownership of `pi` and appends it to the registered instrumentations.
/// "Before" hooks run in this order; "after" hooks run in reverse.
void PassInstrumentor::AddInstrumentation(
    std::unique_ptr<PassInstrumentation> pi) {
  auto& registry = impl_->instrumentations;
  registry.push_back(std::move(pi));
}

}  // namespace ir