// pass_instrumentation.h
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <memory>
#include <string>

#include "paddle/ir/core/type_id.h"

namespace ir {

class Operation;
class Pass;

namespace detail {
struct PassInstrumentorImpl;
}  // namespace detail

class PassInstrumentation {
 public:
  PassInstrumentation() = default;
  virtual ~PassInstrumentation() = default;

35
  // A callback to run before a pass pipeline is executed.
36
  virtual void RunBeforePipeline(Operation* op) {}
37

38
  // A callback to run after a pass pipeline is executed.
39
  virtual void RunAfterPipeline(Operation* op) {}
40

41
  // A callback to run before a pass is executed.
42
  virtual void RunBeforePass(Pass* pass, Operation* op) {}
43

44
  // A callback to run after a pass is executed.
45
  virtual void RunAfterPass(Pass* pass, Operation* op) {}
46

47
  // A callback to run before a analysis is executed.
48
  virtual void RunBeforeAnalysis(const std::string& name,
49 50
                                 TypeId id,
                                 Operation* op) {}
51

52
  // A callback to run after a analysis is executed.
53
  virtual void RunAfterAnalysis(const std::string& name,
54 55
                                TypeId id,
                                Operation* op) {}
56 57 58 59 60 61 62 63 64 65 66 67 68
};

/// This class holds a collection of PassInstrumentation obejcts, and invokes
/// their respective callbacks.
class PassInstrumentor {
 public:
  PassInstrumentor();
  ~PassInstrumentor();
  PassInstrumentor(PassInstrumentor&&) = delete;
  PassInstrumentor(const PassInstrumentor&) = delete;

  void AddInstrumentation(std::unique_ptr<PassInstrumentation> pi);

69
  void RunBeforePipeline(Operation* op);
70

71
  void RunAfterPipeline(Operation* op);
72

73
  void RunBeforePass(Pass* pass, Operation* op);
74

75
  void RunAfterPass(Pass* pass, Operation* op);
76

77
  void RunBeforeAnalysis(const std::string& name, TypeId id, Operation* op);
78

79
  void RunAfterAnalysis(const std::string& name, TypeId id, Operation* op);
80 81 82 83 84 85 86 87

  // TODO(wilber): Add other hooks.

 private:
  std::unique_ptr<detail::PassInstrumentorImpl> impl_;
};

}  // namespace ir