// program_pass.h
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "paddle/cinn/frontend/syntax.h"
#include "paddle/cinn/utils/registry.h"

namespace cinn {
namespace frontend {

class ProgramPass {
 public:
31
  explicit ProgramPass(const std::string& name) : name_(name) {}
32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86

  /**
   * \brief Apply a sequence of passes on a program.
   * @param prog The input program to apply passes on.
   * @param passes The sequence of pass.
   * @return The program after being modified by the passes.
   */
  static void Apply(Program* prog,
                    const std::unordered_set<std::string>& fetch_ids,
                    const common::Target& target,
                    const std::vector<std::string>& passes);

  const std::string& name() const { return name_; }

 protected:
  virtual void ApplyImpl(Program* prog,
                         const std::unordered_set<std::string>& fetch_ids,
                         const common::Target& target) {}
  virtual void ApplyImpl(Program* prog,
                         const std::unordered_set<std::string>& fetch_ids,
                         const common::Target& target) const {
    return const_cast<ProgramPass*>(this)->ApplyImpl(prog, fetch_ids, target);
  }

  virtual void Clear() = 0;

 private:
  std::string name_;
};

/**
 * \brief Registry of ProgramPass instances, keyed by pass name.
 */
class ProgramPassRegistry : public Registry<ProgramPass> {
 public:
  /// Returns the process-wide registry instance (a function-local static).
  static ProgramPassRegistry* Global() {
    static ProgramPassRegistry x;
    return &x;
  }

  /**
   * \brief Look up a registered pass by name.
   * Fails a CHECK when no pass is registered under `name`.
   */
  inline const ProgramPass* Get(const std::string& name) {
    const ProgramPass* pass = Registry<ProgramPass>::Find(name);
    CHECK(pass) << "Pass [" << name << "] is not registered";
    return pass;
  }

  /**
   * \brief Register `pass` under `name`, or return the pass already
   * registered under that name. Runs under `registering_mutex`.
   * @return The pass stored in the registry for `name`.
   */
  inline ProgramPass* __REGISTER__(const std::string& name, ProgramPass* pass) {
    std::lock_guard<std::mutex> guard(registering_mutex);
    auto it = fmap_.find(name);
    if (it != fmap_.end()) {
      // NOTE(review): the duplicate `pass` is neither stored nor freed here, so
      // a heap-allocated duplicate (as created by CINN_REGISTER_PROGRAM_PASS)
      // leaks unless the caller frees it.
      return it->second;
    }

    fmap_[name] = pass;
    const_list_.push_back(pass);
    entry_list_.push_back(pass);
    return pass;
  }

  /**
   * \brief Same contract as __REGISTER__: register `pass` unless `name` is
   * already taken, and return the registered pass.
   */
  inline ProgramPass* __REGISTER_OR_GET__(const std::string& name,
                                          ProgramPass* pass) {
    // __REGISTER__ already returns the existing entry when `name` is taken.
    // Delegating avoids reading `fmap_` without holding `registering_mutex`,
    // which would race with a concurrent registration.
    return __REGISTER__(name, pass);
  }

 private:
  ProgramPassRegistry() = default;
  CINN_DISALLOW_COPY_AND_ASSIGN(ProgramPassRegistry);
};

/**
 * @def CINN_REGISTER_PROGRAM_PASS
 * \brief Register a new program pass with ProgramPassRegistry::Global().
 *
 * @param PassType Identifier naming the pass. It is stringized (`#PassType`)
 *        to form both the registry key and the name passed to the pass
 *        constructor.
 * @param PassClass The pass class, derived from ProgramPass. Pass a type name,
 *        not an instance: the macro constructs it as `new PassClass{"<PassType>"}`.
 *
 * If a pass is already registered under the same name, the existing pass is
 * kept and returned. The newly allocated object is then not stored.
 *
 * \code
 *  CINN_REGISTER_PROGRAM_PASS(decompose, DecomposerPass);
 * \endcode
 */
#define CINN_REGISTER_PROGRAM_PASS(PassType, PassClass)                     \
  static ::cinn::frontend::ProgramPass* __make_##PassType##__ =             \
      ::cinn::frontend::ProgramPassRegistry::Global()->__REGISTER_OR_GET__( \
          #PassType, new PassClass{#PassType})

}  // namespace frontend
}  // namespace cinn