// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/cinn/optim/ir_simplify.h"

#include <absl/container/flat_hash_map.h>
#include <ginac/ginac.h>
#include <glog/logging.h>

#include <map>
#include <string>

#include "paddle/cinn/common/arithmatic.h"
#include "paddle/cinn/common/cas.h"
#include "paddle/cinn/common/ir_util.h"
#include "paddle/cinn/ir/op/ir_operators.h"
#include "paddle/cinn/ir/tensor.h"
#include "paddle/cinn/ir/utils/ir_mutator.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/ir/utils/ir_visitor.h"
#include "paddle/cinn/optim/cast_simplify.h"
#include "paddle/cinn/utils/string.h"

namespace cinn {
namespace optim {
using namespace ir;  // NOLINT
using common::ExprToGinacConverter;
using utils::GetStreamCnt;
using utils::Replace;

namespace {

44 45 46 47 48 49 50 51
//! Simplify some sub-expression in the `expr`. Due to the simplify strategy
//! just fit several kinds of IR noedes, we partition the original expression to
//! several sub-expression those supported by simplify, and process each of
//! them.
void PartialSimplify(
    Expr* expr,
    const absl::flat_hash_map<std::string, common::CasInterval>& var_intervals =
        {}) {
52 53 54 55
  *expr = common::AutoSimplify(*expr, var_intervals);
}

//! Simplify the expression but Load.
56
struct SimplifyNoPureMathMutator : public ir::IRMutator<ir::Expr*> {
57
  common::cas_intervals_t& var_intervals;
58
  explicit SimplifyNoPureMathMutator(
59
      common::cas_intervals_t& var_intervals)  // NOLINT
60
      : var_intervals(var_intervals) {}
61 62 63 64 65

  void operator()(Expr* x) { ir::IRMutator<ir::Expr*>::Visit(x, x); }

  using ir::IRMutator<>::Visit;

66 67 68 69
#define __(op__)                                    \
  void Visit(const op__* op, Expr* expr) override { \
    PartialSimplify(expr, var_intervals);           \
  }
70 71 72 73 74 75 76 77 78 79

  __(Add)
  __(Mul)
  __(Sub)
  __(Div)
  __(Min)
  __(Max)
#undef __

  void Visit(const PolyFor* op, Expr* expr) override {
80
    auto* node = expr->As<ir::PolyFor>();
81 82 83 84 85 86 87 88 89
    node->condition = common::SolveInequality(op->condition, op->iterator);

    Visit(&node->body, &node->body);
  }

  void Visit(const For* op, Expr* expr) override {
    auto* node = expr->As<ir::For>();
    Visit(&node->min, &node->min);
    Visit(&node->extent, &node->extent);
90
    auto* min_i = op->min.As<IntImm>();
91 92
    auto* extent_i = op->extent.As<IntImm>();
    if (min_i && extent_i && extent_i->value > min_i->value) {
93 94 95
      var_intervals.emplace(
          op->loop_var->name,
          common::CasInterval{min_i->value, extent_i->value - 1});
96
    } else {
97 98
      var_intervals.emplace(op->loop_var->name,
                            common::CasInterval{op->min, op->extent - 1});
99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127
    }

    Visit(&node->body, &node->body);
    if (min_i && extent_i) {
      var_intervals.erase(op->loop_var->name);
    }
  }

  void Visit(const _Tensor_* op, Expr* expr) override {
    auto* node = expr->As<ir::_Tensor_>();

    for (auto& e : node->shape) {
      PartialSimplify(&e, var_intervals);
    }
    for (auto& e : node->domain) {
      PartialSimplify(&e, var_intervals);
    }
  }
};

//! Simplifies the index expressions of every Load, using the ranges of the
//! enclosing For loops whose bounds are integer constants.
struct SimplifyLoadMutator : public ir::IRMutator<ir::Expr*> {
  void operator()(Expr* x) { ir::IRMutator<ir::Expr*>::Visit(x, x); }

  //! Pure-math indices go straight to the CAS simplifier; other indices are
  //! handled by SimplifyNoPureMathMutator, which simplifies only their
  //! arithmetic sub-expressions.
  void Visit(const Load* expr, Expr* op) override {
    auto* node = op->As<Load>();
    for (auto& idx : node->indices) {
      if (common::IsPureMath(idx)) {
        PartialSimplify(&idx, var_intervals_);
      } else {
        SimplifyNoPureMathMutator mutator(var_intervals_);
        mutator(&idx);
      }
    }
  }

  //! Registers [min, extent - 1] for the loop variable when both bounds are
  //! integer constants with extent > min, then visits the body and extent.
  void Visit(const For* op, Expr* expr) override {
    auto* min_i = op->min.As<IntImm>();
    auto* extent_i = op->extent.As<IntImm>();
    const std::string var_name = op->loop_var->name;

    // Only remove an entry this loop inserted: nothing is inserted when
    // extent <= min, and emplace() keeps an existing same-name entry that
    // belongs to an enclosing loop.
    bool inserted = false;
    if (min_i && extent_i && extent_i->value > min_i->value) {
      inserted = var_intervals_
                     .emplace(var_name,
                              common::CasInterval{min_i->value,
                                                  extent_i->value - 1})
                     .second;
    }

    auto* node = expr->As<For>();

    operator()(&node->body);
    operator()(&node->extent);

    if (inserted) {
      var_intervals_.erase(var_name);
    }
  }

  //! Ranges of the enclosing constant-bound loop variables, keyed by name.
  common::cas_intervals_t var_intervals_;
};

//! Simplifies the index expressions of every Store, using the ranges of the
//! enclosing For loops whose bounds are integer constants.
struct SimplifyStoreMutator : public ir::IRMutator<ir::Expr*> {
  void operator()(Expr* x) { ir::IRMutator<ir::Expr*>::Visit(x, x); }

  //! Pure-math indices go straight to the CAS simplifier; other indices are
  //! handled by SimplifyNoPureMathMutator. The stored value is not visited.
  void Visit(const Store* expr, Expr* op) override {
    auto* node = op->As<Store>();

    for (auto& idx : node->indices) {
      if (common::IsPureMath(idx)) {
        PartialSimplify(&idx, var_intervals_);
      } else {
        SimplifyNoPureMathMutator mutator(var_intervals_);
        mutator(&idx);
      }
    }
  }

  //! Registers [min, extent - 1] for the loop variable when both bounds are
  //! integer constants with extent > min, then visits the body and extent.
  void Visit(const For* op, Expr* expr) override {
    auto* min_i = op->min.As<IntImm>();
    auto* extent_i = op->extent.As<IntImm>();
    const std::string var_name = op->loop_var->name;

    // Require extent > min, like SimplifyLoadMutator does: otherwise the
    // interval [min, extent - 1] is inverted (l > r) for a zero-trip loop.
    // Only an entry this loop actually inserted is removed afterwards.
    bool inserted = false;
    if (min_i && extent_i && extent_i->value > min_i->value) {
      inserted = var_intervals_
                     .emplace(var_name,
                              common::CasInterval{min_i->value,
                                                  extent_i->value - 1})
                     .second;
    }

    auto* node = expr->As<For>();

    operator()(&node->body);
    operator()(&node->extent);

    if (inserted) {
      var_intervals_.erase(var_name);
    }
  }

  //! Ranges of the enclosing constant-bound loop variables, keyed by name.
  common::cas_intervals_t var_intervals_;
};

//! Simplifies the base and stride of Ramp nodes, and folds the sum of two
//! Ramps with the same lane count into a single Ramp.
struct SimplifyRampMutator : public ir::IRMutator<Expr*> {
  void operator()(Expr* x) { ir::IRMutator<ir::Expr*>::Visit(x, x); }

  //! Requires base and stride to be pure math (CHECK-fails otherwise), then
  //! simplifies both.
  void Visit(const Ramp* op, Expr* expr) override {
    auto* node = expr->As<ir::Ramp>();

    // Keep a space between the printed expression and the message text.
    CHECK(common::IsPureMath(node->base))
        << node->base << " is not a pure math!";
    CHECK(common::IsPureMath(node->stride))
        << node->stride << " is not a pure math!";

    PartialSimplify(&node->base);
    PartialSimplify(&node->stride);
  }

  //! ramp + ramp:
  //!   Ramp(b1, s1, n) + Ramp(b2, s2, n) -> Ramp(b1 + b2, s1 + s2, n)
  //! NOTE(review): this override does not visit the Add's operands, so Ramps
  //! and Adds nested below an Add are not reached by this mutator.
  void Visit(const Add* op, Expr* expr) override {
    auto* node = expr->As<ir::Add>();
    Expr a = node->a();
    Expr b = node->b();
    auto a_ramp = a.As<ir::Ramp>();
    auto b_ramp = b.As<ir::Ramp>();

    if (a_ramp && b_ramp && a_ramp->lanes == b_ramp->lanes) {
      Expr base_add = common::AutoSimplify(a_ramp->base + b_ramp->base);
      Expr stride_add = common::AutoSimplify(a_ramp->stride + b_ramp->stride);
      *expr = ir::Ramp::Make(base_add, stride_add, a_ramp->lanes);
    }
  }
};

//! Simplifies IfThenElse conditions. A condition that becomes an integer
//! constant is folded away: the node is replaced by the branch that runs (or
//! by an empty Block when the condition is false and there is no else branch).
//! Simplification then continues inside whatever remains.
struct SimplifyIfThenElseMutator : public ir::IRMutator<> {
  void operator()(Expr* x) { ir::IRMutator<>::Visit(x, x); }

  using ir::IRMutator<>::Visit;

  void Visit(const IfThenElse* op, Expr* expr) override {
    auto* node = expr->As<ir::IfThenElse>();
    node->condition = common::AutoSimplify(node->condition);

    auto* condition_int = node->condition.As<ir::IntImm>();
    auto* condition_uint = node->condition.As<ir::UIntImm>();
    if (condition_int || condition_uint) {
      const bool take_true_case = condition_int ? condition_int->value != 0
                                                : condition_uint->value != 0;

      // Copy the surviving branch out before overwriting `*expr`: the
      // assignment may release this IfThenElse node, after which `node` and
      // `op` must not be dereferenced.
      Expr replacement;
      if (take_true_case) {
        replacement = node->true_case;
      } else if (node->false_case.defined()) {
        replacement = node->false_case;
      } else {
        // null condition
        replacement = ir::Block::Make({});
      }
      *expr = replacement;

      // The surviving branch may itself be (or contain) a foldable
      // IfThenElse, e.g. in an else-if chain, so keep simplifying from the
      // new root.
      if (expr->defined()) Visit(expr, expr);
      return;
    }

    if (node->true_case.defined()) Visit(&node->true_case, &node->true_case);
    if (node->false_case.defined()) {
      Visit(&node->false_case, &node->false_case);
    }
  }
};

//! Rewrites every FracOp into an ir::Div of the same two operands. Operands
//! are converted first, so nested fractions are replaced bottom-up.
struct ReplaceFracWithDivMutator : public ir::IRMutator<> {
  void operator()(Expr* x) { ir::IRMutator<>::Visit(x, x); }

  void Visit(const FracOp* op, Expr* expr) override {
    auto* frac = expr->As<ir::FracOp>();

    Expr* numerator = &frac->operand(0);
    Expr* denominator = &frac->operand(1);
    ir::IRMutator<>::Visit(numerator, numerator);
    ir::IRMutator<>::Visit(denominator, denominator);

    *expr = ir::Div::Make(*numerator, *denominator);
  }
};

//! Flattens nested ir::Block nodes. A Block whose single statement is another
//! Block is replaced by that inner Block. Blocks appearing as statements of a
//! Block are spliced into the parent's statement list.
struct SimplifyBlocksMutator : public ir::IRMutator<> {
  SimplifyBlocksMutator() {}

  void operator()(Expr* x) { ir::IRMutator<ir::Expr*>::Visit(x, x); }

  using ir::IRMutator<>::Visit;

  void Visit(const Block* op, Expr* expr) override {
    auto* block = expr->As<ir::Block>();

    const bool wraps_single_block =
        block->stmts.size() == 1 && block->stmts[0].As<ir::Block>();
    if (wraps_single_block) {
      VLOG(6) << "Simplify size-1 ir::Block";
      *expr = block->stmts[0];
      Visit(expr, expr);
      return;
    }

    // Simplify the children first, so any nested Block is already flattened
    // by the time it is spliced into this one.
    for (auto& stmt : block->stmts) {
      Visit(&stmt, &stmt);
    }

    std::vector<Expr> flattened;
    for (auto& stmt : block->stmts) {
      auto* inner = stmt.As<ir::Block>();
      if (!inner) {
        flattened.push_back(stmt);
        continue;
      }
      VLOG(6) << "Simplify ir::Block inside ir::Block";
      flattened.insert(
          flattened.end(), inner->stmts.begin(), inner->stmts.end());
    }
    expr->As<ir::Block>()->stmts = flattened;
  }
};

//! Removes For loops that run exactly once. Such a loop has constant bounds
//! with extent - min == 1. It is replaced by its body, and every use of its
//! loop variable inside that body becomes the constant `min`.
struct SimplifyForLoopsMutator : public ir::IRMutator<> {
  //! Loop variables currently being substituted, keyed by name. Only the
  //! lower bound `l` of each interval is read.
  absl::flat_hash_map<std::string, common::CasInterval> var_intervals;

  SimplifyForLoopsMutator() {}

  void operator()(Expr* x) { ir::IRMutator<ir::Expr*>::Visit(x, x); }

  using ir::IRMutator<>::Visit;

  void Visit(const For* op, Expr* expr) override {
    auto* loop = expr->As<ir::For>();
    Visit(&loop->min, &loop->min);
    Visit(&loop->extent, &loop->extent);

    auto* lower = loop->min.As<IntImm>();
    auto* upper = loop->extent.As<IntImm>();
    const bool single_iteration = lower && upper &&
                                  upper->value > lower->value &&
                                  upper->value - lower->value == 1;
    if (!single_iteration) {
      Visit(&loop->body, &loop->body);
      return;
    }

    VLOG(6) << "Simplify current For Loop";
    std::string var_name = loop->loop_var->name;
    var_intervals.emplace(
        var_name, common::CasInterval{lower->value, upper->value - 1});

    // This assignment may release the For node; only `var_name` is used
    // after it.
    *expr = loop->body;

    Visit(expr, expr);
    var_intervals.erase(var_name);
  }

  //! Replaces a loop variable that is being substituted with its constant
  //! lower bound.
  void Visit(const _Var_* op, Expr* expr) override {
    auto* var = expr->As<ir::_Var_>();
    auto found = var_intervals.find(var->name);
    if (found == var_intervals.end()) return;
    *expr = Expr(found->second.l);
  }
};

}  // namespace

//! Runs the simplification pipeline over `expr` in place, in this order:
//! cast simplification, Ramp simplification/folding, Load indices, Store
//! indices, constant-condition IfThenElse folding, arithmetic inside
//! non-pure-math expressions, and finally FracOp -> Div replacement.
void Simplify(Expr* expr) {
  VLOG(3) << "Begin Simplify " << *expr;
  optim::CastSimplify(expr);

  SimplifyRampMutator ramp_pass;
  ramp_pass(expr);

  SimplifyLoadMutator load_pass;
  load_pass(expr);

  SimplifyStoreMutator store_pass;
  store_pass(expr);

  SimplifyIfThenElseMutator if_then_else_pass;
  if_then_else_pass(expr);

  common::cas_intervals_t var_intervals;
  SimplifyNoPureMathMutator arithmetic_pass(var_intervals);
  arithmetic_pass(expr);

  ReplaceFracWithDivMutator frac_pass;
  frac_pass(expr);
}

//! Replaces For loops that run exactly once by their bodies, substituting the
//! loop variable with its constant lower bound.
void SimplifyForLoops(Expr* expr) {
  SimplifyForLoopsMutator mutator;
  mutator(expr);
}

//! Flattens nested ir::Block nodes inside `expr`.
void SimplifyBlocks(Expr* expr) {
  SimplifyBlocksMutator mutator;
  mutator(expr);
}

}  // namespace optim
}  // namespace cinn