// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/ir/transforms/dce.h"

#include <cstdint>
#include <memory>
#include <vector>

#include "paddle/ir/core/builtin_op.h"
#include "paddle/ir/pass/pass.h"

namespace {

// TODO(wilber): After support SideEffectTrait, Only NoSideEffectTrait op can be
// removed by dce pass.
// Now just a naive implementation.
W
Wilber 已提交
25
class DcePass : public ir::Pass {
26
 public:
W
Wilber 已提交
27
  DcePass() : ir::Pass("DcePass", 0) {}
28 29 30

  void Run(ir::Operation *op) override {
    auto module_op = op->dyn_cast<ir::ModuleOp>();
W
Wilber 已提交
31
    IR_ENFORCE(module_op, "DcePass should run on module op.");
32 33 34 35 36 37 38 39 40 41
    auto *block = module_op.block();
    std::vector<ir::Operation> erased_op;
    for (auto it = block->begin(); it != block->end(); ++it) {
      // TODO(wilber): Support NoSideEffect trait.
      // if (!(*it)->HasTrait<NoSideEffect>()) continue;

      bool use_empty = true;
      for (uint32_t i = 0; i < (*it)->num_results(); ++i) {
        use_empty &= (*it)->result(i).use_empty();
      }
W
Wilber 已提交
42
      // TODO(wilber): Support Terminator trait.
43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59
      if (use_empty && (*it)->name() != "pd.fetch") {
        erased_op.push_back(**it);
      }
    }

    for (auto ep : erased_op) block->erase(ep);
  }

  bool CanApplyOn(ir::Operation *op) const override {
    return op->name() == "builtin.module" && op->num_regions() > 0;
  }
};

}  // namespace

namespace ir {

W
Wilber 已提交
60
std::unique_ptr<Pass> CreateDcePass() { return std::make_unique<DcePass>(); }
61 62

}  // namespace ir