dead_code_elimination_pass.cc 2.5 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15
#include "paddle/ir/transforms/dead_code_elimination_pass.h"
16

17
#include "paddle/ir/core/builtin_op.h"
18
#include "paddle/ir/core/program.h"
19 20 21 22 23 24 25
#include "paddle/ir/pass/pass.h"

namespace {

// TODO(wilber): After support SideEffectTrait, Only NoSideEffectTrait op can be
// removed by dce pass.
// Now just a naive implementation.
26
class DeadCodeEliminationPass : public ir::Pass {
27
 public:
28
  DeadCodeEliminationPass() : ir::Pass("DeadCodeEliminationPass", 0) {}
29 30 31

  void Run(ir::Operation *op) override {
    auto module_op = op->dyn_cast<ir::ModuleOp>();
W
Wilber 已提交
32
    IR_ENFORCE(module_op, "DcePass should run on module op.");
33
    auto *block = module_op.block();
34
    std::vector<ir::Operation *> erased_op;
35
    for (auto &op : *block) {
36
      // TODO(wilber): Support NoSideEffect trait.
37
      // if (!op->HasTrait<NoSideEffect>()) continue;
38 39

      bool use_empty = true;
40 41
      for (uint32_t i = 0; i < op->num_results(); ++i) {
        use_empty &= op->result(i).use_empty();
42
      }
W
Wilber 已提交
43
      // TODO(wilber): Support Terminator trait.
44 45
      if (use_empty && op->name() != "pd.fetch") {
        erased_op.push_back(op);
46 47 48
      }
    }

49 50 51 52 53 54 55 56 57
    for (auto *op : erased_op) {
      if (op->dyn_cast<ir::GetParameterOp>()) {
        // Delete parameter from program.
        ir::GetParameterOp get_parameter_op =
            op->dyn_cast<ir::GetParameterOp>();
        get_parameter_op->GetParentProgram()->parameters().erase(
            get_parameter_op->attributes()
                .at(get_parameter_op.attributes_name[0])
                .dyn_cast<ir::StrAttribute>()
58
                .AsString());
59 60 61
      }
      block->erase(*op);
    }
62 63 64 65 66 67 68 69 70 71 72
  }

  bool CanApplyOn(ir::Operation *op) const override {
    return op->name() == "builtin.module" && op->num_regions() > 0;
  }
};

}  // namespace

namespace ir {

73 74 75
// Factory for the dead code elimination pass.
//
// Returns a newly constructed pass; ownership is transferred to the caller.
std::unique_ptr<Pass> CreateDeadCodeEliminationPass() {
  std::unique_ptr<Pass> dce_pass = std::make_unique<DeadCodeEliminationPass>();
  return dce_pass;
}
76 77

}  // namespace ir