// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <cassert>
#include <cstdint>
#include <functional>
#include <iostream>
#include <memory>
#include <utility>
#include <vector>

#include "paddle/ir/core/program.h"

24 25 26 27
namespace ir {

class IrContext;
class Operation;
28
class Program;
29
class Pass;
30 31
class PassInstrumentation;
class PassInstrumentor;
32 33 34 35 36

namespace detail {
class PassAdaptor;
}

37
class IR_API PassManager {
38
 public:
39
  explicit PassManager(IrContext *context, uint8_t opt_level = 2);
40 41 42

  ~PassManager() = default;

43
  const std::vector<std::unique_ptr<Pass>> &passes() const { return passes_; }
44

45
  bool Empty() const { return passes_.empty(); }
46

47
  IrContext *context() const { return context_; }
48

49
  bool Run(Program *program);
50 51 52 53 54

  void AddPass(std::unique_ptr<Pass> pass) {
    passes_.emplace_back(std::move(pass));
  }

55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93
  class IRPrinterOption {
   public:
    using PrintCallBack = std::function<void(std::ostream &)>;

    explicit IRPrinterOption(
        const std::function<bool(Pass *, Operation *)> &enable_print_before =
            [](Pass *, Operation *) { return true; },
        const std::function<bool(Pass *, Operation *)> &enable_print_after =
            [](Pass *, Operation *) { return true; },
        bool print_module = true,
        bool print_on_change = true,
        std::ostream &os = std::cout)
        : enable_print_before_(enable_print_before),
          enable_print_after_(enable_print_after),
          print_module_(print_module),
          print_on_change_(print_on_change),
          os(os) {
      assert((enable_print_before_ || enable_print_after_) &&
             "expected at least one valid filter function");
    }

    ~IRPrinterOption() = default;

    void PrintBeforeIfEnabled(Pass *pass,
                              Operation *op,
                              const PrintCallBack &print_callback) {
      if (enable_print_before_ && enable_print_before_(pass, op)) {
        print_callback(os);
      }
    }

    void PrintAfterIfEnabled(Pass *pass,
                             Operation *op,
                             const PrintCallBack &print_callback) {
      if (enable_print_after_ && enable_print_after_(pass, op)) {
        print_callback(os);
      }
    }

94
    bool print_module() const { return print_module_; }
95

96
    bool print_on_change() const { return print_on_change_; }
97 98 99 100 101 102 103 104

   private:
    // The enable_print_before_ and enable_print_after_ can be used to specify
    // the pass to be printed. The default is to print all passes.
    std::function<bool(Pass *, Operation *)> enable_print_before_;
    std::function<bool(Pass *, Operation *)> enable_print_after_;

    bool print_module_;
105

106 107 108 109 110 111 112
    bool print_on_change_;

    std::ostream &os;

    // TODO(liuyuanle): Add flags to control printing behavior.
  };

113 114
  void EnableIRPrinting(std::unique_ptr<IRPrinterOption> option =
                            std::make_unique<IRPrinterOption>());
115

116 117
  void EnablePassTiming(bool print_module = true);

118
  void AddInstrumentation(std::unique_ptr<PassInstrumentation> pi);
119

120
 private:
121 122 123
  bool Initialize(IrContext *context);

  bool Run(Operation *op);
124 125

 private:
126
  IrContext *context_;
127 128 129

  uint8_t opt_level_;

130 131
  bool verify_{true};

132 133 134 135
  std::vector<std::unique_ptr<Pass>> passes_;

  std::unique_ptr<Pass> pass_adaptor_;

136 137
  std::unique_ptr<PassInstrumentor> instrumentor_;

138
  // For access member of pass_adaptor_.
139 140 141 142
  friend class detail::PassAdaptor;
};

}  // namespace ir