ir_schedule.cc 97.2 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/cinn/ir/ir_schedule.h"

#include <math.h>

#include <algorithm>
#include <functional>
#include <iostream>
#include <memory>
#include <numeric>
#include <random>
#include <set>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "paddle/cinn/common/cas.h"
#include "paddle/cinn/common/common.h"
#include "paddle/cinn/common/ir_util.h"
#include "paddle/cinn/ir/collect_ir_nodes.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_mutator.h"
#include "paddle/cinn/ir/ir_operators.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/ir_schedule_util.h"
#include "paddle/cinn/ir/ir_visitor.h"
#include "paddle/cinn/lang/compute.h"
#include "paddle/cinn/optim/ir_copy.h"
#include "paddle/cinn/optim/ir_simplify.h"
#include "paddle/cinn/optim/replace_var_with_expr.h"
#include "paddle/cinn/utils/string.h"

namespace cinn {
namespace ir {

/**
 * A struct helps to implement Schedule primitives.
 */
class ScheduleImpl {
 public:
  ScheduleImpl() = default;
  explicit ScheduleImpl(const ModuleExpr& module_expr, bool debug_flag = false)
      : module_expr_(module_expr), debug_flag_(debug_flag) {}
55 56
  explicit ScheduleImpl(ModuleExpr&& module_expr)
      : module_expr_(std::move(module_expr)) {}
57 58 59 60 61 62 63 64 65

  //! Set the debug flag.
  void SetDebugFlag(bool debug_flag) { debug_flag_ = debug_flag; }

  //! Get the ModuleExpr stored in ScheduleImpl.
  const ModuleExpr& GetModule() const { return module_expr_; }

  void MergeExprs();

66 67 68
  void SetExprs(const std::vector<Expr>& exprs) {
    module_expr_.SetExprs(exprs);
  }
69 70 71 72 73 74 75 76 77

  bool HasBlock(const std::string& block_name) const;

  std::vector<Expr> GetLoops(const Expr& block) const;
  std::vector<Expr> GetLoops(const std::string& block_name) const;
  std::vector<Expr> GetAllBlocks() const;
  std::vector<Expr> GetChildBlocks(const Expr& expr) const;
  Expr GetBlock(const std::string& block_name) const;
  std::vector<Expr> Split(const Expr& loop, const std::vector<int>& factors);
78 79 80 81 82
  std::vector<Expr> SamplePerfectTile(
      utils::LinearRandomEngine::StateType* rand_seed,
      const Expr& loop,
      int n,
      int max_innermost_factor);
83 84 85 86 87
  Expr Fuse(const std::vector<Expr>& loops);
  Expr Fuse(const std::string& block_name, const std::vector<int>& loops_index);
  Expr Fuse(const Expr& block, const std::vector<int>& loops_index);
  void ComputeAt(const Expr& block, const Expr& loop, bool keep_unit_loops);
  void SimpleComputeAt(const Expr& block, const Expr& loop);
88 89 90
  void ReverseComputeAt(const Expr& block,
                        const Expr& loop,
                        bool keep_unit_loops);
91
  Expr GetRootBlock(const Expr& expr) const;
92 93 94 95 96 97
  Expr CacheRead(const Expr& block,
                 int read_buffer_index,
                 const std::string& memory_type);
  Expr CacheWrite(const Expr& block,
                  int write_buffer_index,
                  const std::string& memory_type);
98
  void SyncThreads(const Expr& ir_node, bool after_node = true);
99 100 101
  void SetBuffer(Expr& block,
                 const std::string& memory_type,
                 bool fixed = false);
102
  Expr Reorder(const std::vector<Expr>& loops);
103 104
  Expr Reorder(const std::string& block_name,
               const std::vector<int>& loops_index);
105 106 107 108 109 110 111 112 113 114 115 116 117
  Expr Reorder(const Expr& block, const std::vector<int>& loops_index);
  DeviceAPI GetDeviceAPI() const;
  void MutateForType(const Expr& loop, ForType for_type, int factor = -1);
  void Parallel(const Expr& loop);
  void Vectorize(const Expr& loop, int factor);
  void Unroll(const Expr& loop);
  void ComputeInline(const Expr& schedule_block);
  void ReverseComputeInline(const Expr& schedule_block);
  void Bind(const Expr& loop, const std::string& thread_axis);
  Expr Rfactor(const Expr& rf_loop, int rf_axis);
  Expr AddUnitLoop(const Expr& block) const;
  void Annotate(const Expr& block, const std::string& key, const attr_t& value);
  void Unannotate(Expr& block, const std::string& key);
118 119
  void FlattenLoops(const std::vector<Expr>& loops,
                    const bool force_flat = false);
120
  void CopyTransformAndLoopInfo(const Expr& block, const Expr& block_target);
121 122
  void CopyTransformAndLoopInfo(const std::string& block_name,
                                const std::string& block_target_name);
123 124 125 126 127 128 129 130 131 132 133
  Expr SampleCategorical(utils::LinearRandomEngine::StateType* rand_seed,
                         const std::vector<int>& candidates,
                         const std::vector<float>& probs);

 private:
  void Replace(const Expr& src_sref, const Expr& tgt_stmt);

  ModuleExpr module_expr_;
  bool debug_flag_{false};
};

134 135 136 137
std::vector<Expr> ScheduleImpl::Split(const Expr& loop,
                                      const std::vector<int>& factors) {
  CHECK(loop.As<ir::For>())
      << "Expr param of Split must be For node! Please check.";
138
  auto* for_node = loop.As<ir::For>();
139 140 141 142
  CHECK(common::is_zero(for_node->min))
      << "The For node must start with 0! Please check.";
  CHECK(for_node->extent.is_constant())
      << "The For node's extent must be constant! Please check.";
143 144
  int tot_extent = for_node->extent.get_constant();

145 146 147
  VLOG(3) << "Try Split loop from (" << for_node->loop_var->name << ", 0, "
          << tot_extent << ") to (" << cinn::utils::Join(factors, ", ")
          << ") at loop:\n"
148 149 150
          << loop;

  auto processed_factors = ValidateFactors(factors, tot_extent);
151 152 153 154
  int prod_size = std::accumulate(processed_factors.begin(),
                                  processed_factors.end(),
                                  1,
                                  std::multiplies<int>());
155 156 157 158
  std::vector<Var> new_loop_vars;
  Expr substitute_value(0);
  for (int i = 0; i < processed_factors.size(); ++i) {
    Var temp_var(common::UniqName(for_node->loop_var->name));
159 160
    substitute_value =
        Expr(temp_var) + substitute_value * Expr(processed_factors[i]);
161 162 163
    new_loop_vars.push_back(temp_var);
  }
  substitute_value = common::AutoSimplify(substitute_value);
164
  Expr new_node = optim::IRCopy(for_node->body);
165 166 167 168
  ReplaceExpr(&new_node, {for_node->loop_var}, {substitute_value});
  std::vector<Expr> splited_loops;
  splited_loops.resize(processed_factors.size());
  if (tot_extent < prod_size) {
169 170
    new_node = IfThenElse::Make(LT::Make(substitute_value, for_node->extent),
                                new_node);
171 172 173
  }
  for (int i = processed_factors.size() - 1; i >= 0; i--) {
    if (!new_node.As<ir::Block>()) new_node = Block::Make({new_node});
174 175 176 177 178 179
    new_node = For::Make(new_loop_vars[i],
                         Expr(0),
                         Expr(processed_factors[i]),
                         for_node->for_type(),
                         for_node->device_api,
                         new_node);
180 181 182 183 184 185 186 187 188 189 190 191
    splited_loops[i] = new_node;
  }

  this->Replace(loop, new_node);
  VLOG(3) << "After Split, ir is:\n" << splited_loops.at(0);
  return splited_loops;
}

/**
 * Fuse a chain of directly nested For loops into a single loop.
 *
 * Every element of `loops` must be a For node, and each loop after the first
 * must be the only statement in the Block body of the previous one. The fused
 * loop runs from 0 to the product of all extents; each original loop variable
 * is replaced in a copy of the innermost body by the matching Div/Mod
 * expression of the fused variable.
 *
 * @param loops The loops to fuse, outermost first; must not be empty.
 * @return The new fused For node, which replaces loops[0] in the module.
 */
Expr ScheduleImpl::Fuse(const std::vector<Expr>& loops) {
  VLOG(3) << "Tring to fuse:\n" << cinn::utils::Join(loops, "\n");
  std::vector<const ir::For*> for_nodes;
  std::vector<Var> loop_vars;
  CHECK(!loops.empty())
      << "The loops param of Fuse should not be empty! Please check.";

  for (const Expr& it_loop : loops) {
    const ir::For* current_for = it_loop.As<ir::For>();
    CHECK(current_for) << "Expr param of Fuse must be For node! Please check.";
    if (!for_nodes.empty()) {
      // The current loop has to be the sole child of the previous one.
      const auto* parent_body = for_nodes.back()->body.As<ir::Block>();
      CHECK(parent_body) << "The body of for node is not Block!";
      CHECK_EQ(parent_body->stmts.size(), 1U)
          << "The Block'size of for node is not 1!";
      CHECK_EQ(parent_body->stmts[0], it_loop)
          << "The For nodes in loops param of Fuse must be adjacent! Please "
             "check.";
    }
    for_nodes.push_back(current_for);
    loop_vars.push_back(current_for->loop_var);
  }

  // The fused variable is named "<v0>_<v1>_..._fused".
  const int loops_number = for_nodes.size();
  std::string suffix = for_nodes[0]->loop_var->name;
  for (int i = 1; i < loops_number; ++i) {
    suffix += "_" + for_nodes[i]->loop_var->name;
  }
  suffix += "_fused";
  Var fused_var(suffix);

  // Peel the original indices off the fused variable, innermost first.
  std::vector<Expr> substitute_value(loops_number);
  Expr fused_expr(fused_var);
  for (int i = loops_number - 1; i > 0; --i) {
    const Expr& extent = for_nodes[i]->extent;
    substitute_value[i] = Mod::Make(fused_expr, extent);
    fused_expr = Div::Make(fused_expr, extent);
  }
  substitute_value[0] = fused_expr;

  Expr fused_body = optim::IRCopy(for_nodes.back()->body);
  ReplaceExpr(&fused_body, loop_vars, substitute_value);
  optim::Simplify(&fused_body);

  Expr fused_extent(1);
  for (const ir::For* node : for_nodes) {
    fused_extent = fused_extent * node->extent;
  }
  fused_extent = common::AutoSimplify(fused_extent);

  if (!fused_body.As<ir::Block>()) fused_body = Block::Make({fused_body});
  Expr new_stmt = For::Make(fused_var,
                            Expr(0),
                            fused_extent,
                            for_nodes[0]->for_type(),
                            for_nodes[0]->device_api,
                            fused_body);
  this->Replace(loops[0], new_stmt);

  VLOG(3) << "After fuse, ir is:\n" << new_stmt;
  return new_stmt;
}

249 250
/**
 * Fuse loops of the block named `block_name`, selected by index.
 *
 * @param block_name  Name of the block whose loops are looked up via GetLoops.
 * @param loops_index Consecutive indices into those loops; each must lie in
 *                    [0, number of loops).
 * @return The fused loop produced by Fuse(const std::vector<Expr>&).
 */
Expr ScheduleImpl::Fuse(const std::string& block_name,
                        const std::vector<int>& loops_index) {
  const std::vector<Expr> all_loops = this->GetLoops(block_name);

  // The indices must form a run i, i+1, i+2, ...
  for (size_t pos = 1; pos < loops_index.size(); ++pos) {
    CHECK_EQ(loops_index[pos - 1] + 1, loops_index[pos])
        << "Loops index in Fuse shoule be continuous!";
  }

  const int total_loops = static_cast<int>(all_loops.size());
  std::vector<Expr> loops_expr;
  loops_expr.reserve(loops_index.size());
  for (const int index : loops_index) {
    CHECK_LT(index, total_loops)
        << "The loop index in Fuse should be less than total loop's number.";
    CHECK_GE(index, 0) << "The loop index in Fuse should be >= 0.";
    loops_expr.emplace_back(all_loops[index]);
  }
  return this->Fuse(loops_expr);
}

268 269
/**
 * Fuse loops of `block`, selected by index.
 *
 * @param block       The block whose loops are looked up via GetLoops.
 * @param loops_index Consecutive indices into those loops; each must lie in
 *                    [0, number of loops).
 * @return The fused loop produced by Fuse(const std::vector<Expr>&).
 */
Expr ScheduleImpl::Fuse(const Expr& block,
                        const std::vector<int>& loops_index) {
  const std::vector<Expr> all_loops = this->GetLoops(block);

  // The indices must form a run i, i+1, i+2, ...
  for (size_t pos = 1; pos < loops_index.size(); ++pos) {
    CHECK_EQ(loops_index[pos - 1] + 1, loops_index[pos])
        << "Loops index in Fuse shoule be continuous!";
  }

  const int total_loops = static_cast<int>(all_loops.size());
  std::vector<Expr> loops_expr;
  loops_expr.reserve(loops_index.size());
  for (const int index : loops_index) {
    CHECK_LT(index, total_loops)
        << "The loop index in Fuse should be less than total loop's number.";
    CHECK_GE(index, 0) << "The loop index in Fuse should be >= 0.";
    loops_expr.emplace_back(all_loops[index]);
  }
  return this->Fuse(loops_expr);
}

287 288 289
void ScheduleImpl::MutateForType(const Expr& loop,
                                 ForType for_type,
                                 int factor) {
290 291
  auto* for_node = loop.As<ir::For>();
  CHECK(for_node) << "loop param must be For node! Please check.";
292 293 294 295 296
  CHECK(for_node->is_serial())
      << "loop is not serial, current forloop type is "
      << static_cast<int>(for_node->for_type()) << ", and it cannot become "
      << static_cast<int>(for_type);
  auto loop_copy = optim::IRCopy(loop);
297 298 299 300 301 302 303 304 305 306 307 308 309
  auto* new_for_node = loop_copy.As<ir::For>();
  CHECK(new_for_node);
  new_for_node->set_for_type(for_type);
  if (new_for_node->is_vectorized()) {
    VectorizeInfo vec_info(0, factor);
    new_for_node->set_vectorize_info(vec_info);
  } else if (new_for_node->is_binded()) {
    BindInfo bind_info(for_type, factor, DeviceAPI::GPU);
    new_for_node->set_bind_info(bind_info);
  }
  this->Replace(loop, loop_copy);
}

310 311 312
void ScheduleImpl::Parallel(const Expr& loop) {
  MutateForType(loop, ForType::Parallel);
}
313 314 315 316 317 318

void ScheduleImpl::Vectorize(const Expr& loop, int factor) {
  CHECK_GT(factor, 0) << "vectorize factor should be more than 0";
  MutateForType(loop, ForType::Vectorized, factor);
}

319 320 321
void ScheduleImpl::Unroll(const Expr& loop) {
  MutateForType(loop, ForType::Unrolled);
}
322 323

void ScheduleImpl::Bind(const Expr& loop, const std::string& thread_axis) {
324 325 326 327 328 329 330 331
  static std::set<std::string> thread_axes = {"blockIdx.x",
                                              "blockIdx.y",
                                              "blockIdx.z",
                                              "threadIdx.x",
                                              "threadIdx.y",
                                              "threadIdx.z"};
  CHECK(thread_axes.count(thread_axis))
      << "thread_axis " << thread_axis << " is not supported";
332 333 334 335 336 337 338 339 340 341 342
  int offset = thread_axis.back() - 'x';
  if (thread_axis[0] == 'b') {
    MutateForType(loop, ForType::GPUBlock, offset);
  } else {
    MutateForType(loop, ForType::GPUThread, offset);
  }
}

// The struct used to mutate the new rfactor forloop and its schedule block.
//
// Applied to a copy of the loop nest that contains `rf_loop`, it
//   * removes the original rfactor loop and inserts a serial loop over a new
//     variable "rf_<old loop var>" at depth `rf_axis`,
//   * rebinds the block iter var that was bound to the old loop var (or
//     appends a new iter var when none was),
//   * renames the schedule block and the tensor it stores to "rf_<name>" and
//     inserts the rfactor index/extent at position `rf_axis` of the store
//     indices, tensor shape and tensor domain.
struct RfMutator : public ir::IRMutator<> {
 public:
  RfMutator(const Expr& rf_loop, const int& rf_axis)
      : rf_loop_(rf_loop), rf_axis_(rf_axis) {}

  //! Run the mutation on `expr`. `rf_loop` must be a For node.
  void operator()(Expr* expr) {
    auto* rf_for = rf_loop_.As<For>();
    CHECK(rf_for);
    old_rf_loop_var_ = rf_for->loop_var;
    new_rf_loop_var_ = Var("rf_" + old_rf_loop_var_->name);
    IRMutator::Visit(expr, expr);
  }

  //! The renamed tensor recorded while visiting the Store; only meaningful
  //! after operator() has run.
  Tensor GetNewRfTensor() { return new_rf_tensor_; }

  void Visit(const ScheduleBlockRealize* op, Expr* expr) override {
    // modify iter_vars and iter_values
    auto* node = expr->As<ScheduleBlockRealize>();
    CHECK(node);
    auto* schedule_block = node->schedule_block.As<ScheduleBlock>();
    CHECK(schedule_block);
    old_output_name_ = schedule_block->name;
    find_tensor_ = false;
    auto& block_vars = schedule_block->iter_vars;
    auto& iter_values = node->iter_values;
    CHECK(old_rf_loop_var_.defined());
    CHECK(new_rf_loop_var_.defined());
    CHECK_EQ(iter_values.size(), block_vars.size());
    int rf_index = -1;
    for (int i = 0; i < static_cast<int>(iter_values.size()); ++i) {
      // substitute the old rfactor loop var to new rfactor loop var
      if (ContainVar({iter_values[i]}, old_rf_loop_var_->name)) {
        CHECK_EQ(rf_index, -1)
            << "only one block var can bind the rfactor loop var";
        CHECK(iter_values[i].As<_Var_>())
            << "rfactor loop var not support composite bindings";
        rf_index = i;
        optim::ReplaceVarWithExpr(
            &iter_values[i], old_rf_loop_var_, new_rf_loop_var_);
        new_rf_itervar_ = block_vars[i];
      }
    }
    // create new rfactor block var if not exist
    if (rf_index == -1) {
      new_rf_itervar_ =
          Var(cinn::UniqName("i" + std::to_string(block_vars.size())));
      iter_values.push_back(new_rf_loop_var_);
      block_vars.push_back(new_rf_itervar_);
    }
    IRMutator::Visit(&node->schedule_block, &node->schedule_block);
    CHECK(find_tensor_)
        << "not find the store tensor with the schedule block name "
        << old_output_name_;
    schedule_block->name = "rf_" + old_output_name_;
  }

  void Visit(const Load* op, Expr* expr) override {
    // insert the new rfactor indice if not exist
    auto* node = expr->As<Load>();
    CHECK(node);
    auto* tensor = node->tensor.As<_Tensor_>();
    CHECK(tensor);
    // Only loads of the already renamed output tensor get the extra index.
    if (tensor->name == "rf_" + old_output_name_) {
      int size = node->indices.size();
      CHECK_LE(rf_axis_, size)
          << "rf_axis should not be greater than indice size " << size;
      CHECK(new_rf_itervar_.defined());
      CHECK(!ContainVar(node->indices, new_rf_itervar_->name))
          << "original output tensor " << old_output_name_
          << " should not have the new rfactor index " << new_rf_itervar_;
      node->indices.insert(node->indices.begin() + rf_axis_, new_rf_itervar_);
    }
  }

  void Visit(const Store* op, Expr* expr) override {
    // insert the new rfactor indice if not exist
    auto* node = expr->As<Store>();
    CHECK(node);
    auto* tensor = node->tensor.As<_Tensor_>();
    CHECK(tensor);
    if (tensor->name == old_output_name_) {
      find_tensor_ = true;
      tensor->name = "rf_" + tensor->name;
      int size = node->indices.size();
      CHECK_LE(rf_axis_, size)
          << "rf_axis should not be greater than indice size " << size;
      // Same guard as in Visit(Load): the iter var is dereferenced below.
      CHECK(new_rf_itervar_.defined());
      CHECK(!ContainVar(node->indices, new_rf_itervar_->name))
          << "original output tensor " << old_output_name_
          << " should not have the new rfactor index " << new_rf_itervar_;
      node->indices.insert(node->indices.begin() + rf_axis_, new_rf_itervar_);
      auto* rf_for = rf_loop_.As<For>();
      CHECK(rf_for);
      CHECK(is_zero(rf_for->min)) << "rfactor loop's min should be zero";
      auto extent = common::AutoSimplify(rf_for->extent);
      auto& shape = tensor->shape;
      auto& domain = tensor->domain;
      CHECK_LE(rf_axis_, shape.size())
          << "rf_axis should not be greater than tensor shape size "
          << shape.size();
      CHECK_LE(rf_axis_, domain.size())
          << "rf_axis should not be greater than tensor domain size "
          << domain.size();
      shape.insert(shape.begin() + rf_axis_, extent);
      domain.insert(domain.begin() + rf_axis_, extent);
      if (tensor->buffer.defined()) {
        // Rename and reshape the buffer only once. The check must look for
        // the "rf_" prefix added below; find_first_of("rf") would match any
        // name that merely contains an 'r' or an 'f' and skip the update.
        const std::string& buffer_name = tensor->buffer->name;
        const bool already_prefixed = buffer_name.compare(0, 3, "rf_") == 0;
        if (!already_prefixed) {
          tensor->buffer->name = "rf_" + tensor->buffer->name;
          tensor->buffer->shape = shape;
        }
      }
      new_rf_tensor_ = Tensor(tensor);
    }
    IRMutator::Visit(&node->value, &node->value);
  }

  void Visit(const For* op, Expr* expr) override {
    auto* node = expr->As<For>();
    CHECK(node);
    depth++;
    auto* rf_for = rf_loop_.As<For>();
    CHECK(rf_for);
    // erase the original rfactor forloop
    if (node->loop_var->name == old_rf_loop_var_->name) {
      auto body = node->body.As<Block>();
      if (body && body->stmts.size() == 1) {
        *expr = body->stmts[0];
      } else {
        *expr = node->body;
      }
      IRMutator::Visit(expr, expr);
    } else {
      IRMutator::Visit(&node->body, &node->body);
    }
    if (rf_axis_ == 0 && depth == rf_axis_) {
      // insert new rfactor forloop in the rf_axis as serial loop
      *expr = For::Make(new_rf_loop_var_,
                        rf_for->min,
                        rf_for->extent,
                        ForType::Serial,
                        rf_for->device_api,
                        Block::Make({*expr}));
    } else if (depth == rf_axis_ - 1) {
      // insert new rfactor forloop in the rf_axis as serial loop
      node->body = Block::Make({For::Make(new_rf_loop_var_,
                                          rf_for->min,
                                          rf_for->extent,
                                          ForType::Serial,
                                          rf_for->device_api,
                                          node->body)});
    }
    depth--;
  }

 private:
  Expr rf_loop_;
  Var old_rf_loop_var_;
  Var new_rf_loop_var_;
  int rf_axis_;
  // Nesting depth of the For currently being visited; the outermost is 0.
  int depth = -1;
  bool find_tensor_ = false;
  std::string old_output_name_;
  Var new_rf_itervar_;
  Tensor new_rf_tensor_;
};

// The struct used to mutate final write-back forloop and schedule block.
//
// Applied to a copy of the loop nest that contains `rf_loop`, it
//   * removes every loop visited after the "_init" block except the rfactor
//     loop, and drops the block iter vars bound to the removed loops,
//   * replaces the second operand of the Add/Mul/Min/Max in the store value
//     by a load of `new_rf_tensor`, indexed by the store indices with the
//     rfactor iter var inserted at position `rf_axis`.
struct FinalMutator : public ir::IRMutator<> {
 public:
  FinalMutator(const Expr& rf_loop,
               const int& rf_axis,
               const Tensor& new_rf_tensor)
      : rf_loop_(rf_loop), rf_axis_(rf_axis), new_rf_tensor_(new_rf_tensor) {}

  //! Run the mutation on `expr`. `rf_loop` must be a For node.
  void operator()(Expr* expr) {
    auto* rf_for = rf_loop_.As<For>();
    CHECK(rf_for);
    old_rf_loop_var_ = rf_for->loop_var;
    IRMutator::Visit(expr, expr);
  }

  void Visit(const ScheduleBlockRealize* op, Expr* expr) override {
    auto* node = expr->As<ScheduleBlockRealize>();
    CHECK(node);
    auto* schedule_block = node->schedule_block.As<ScheduleBlock>();
    CHECK(schedule_block);
    auto& iter_vars = schedule_block->iter_vars;
    auto& iter_values = node->iter_values;
    output_name_ = schedule_block->name;
    visit_init_block_ = output_name_.rfind("_init") != std::string::npos;
    if (!visit_init_block_) {
      for (size_t i = 0; i < iter_values.size(); ++i) {
        if (ContainVar({iter_values[i]}, old_rf_loop_var_->name)) {
          // record the rfactor loop var's block var
          CHECK(iter_values[i].As<_Var_>())
              << "not support complex reduce bindings: " << iter_values[i];
          old_rf_iter_var_ = iter_vars[i];
          break;
        }
      }
    }
    IRMutator::Visit(&node->schedule_block, &node->schedule_block);
    // modify iter_vars and iter_values, erase other reduce block vars and
    // values
    //
    // Index-based on purpose: erasing from a vector invalidates iterators at
    // and after the erased position, so the index is only advanced when
    // nothing was erased.
    size_t idx = 0;
    while (idx < iter_values.size()) {
      bool bound_to_erased_loop = false;
      for (const auto& erase_var : erase_reduce_loopvars_) {
        if (ContainVar({iter_values[idx]}, erase_var)) {
          bound_to_erased_loop = true;
          break;
        }
      }
      if (!bound_to_erased_loop) {
        ++idx;
        continue;
      }
      CHECK(iter_values[idx].As<_Var_>())
          << "not support complex reduce bindings: " << iter_values[idx];
      iter_vars.erase(iter_vars.begin() + idx);
      iter_values.erase(iter_values.begin() + idx);
    }
  }

  // currently only support reduce_sum, reduce_mul, reduce_min and reduce_max
  // In each case operand b becomes a load from the rfactor tensor.
  void Visit(const Add* op, Expr* expr) override {
    auto* node = expr->As<Add>();
    CHECK(node);
    auto& oper_b = node->b();
    oper_b = Load::Make(new_rf_tensor_, new_rf_indice_);
  }

  void Visit(const Mul* op, Expr* expr) override {
    auto* node = expr->As<Mul>();
    CHECK(node);
    auto& oper_b = node->b();
    oper_b = Load::Make(new_rf_tensor_, new_rf_indice_);
  }

  void Visit(const Min* op, Expr* expr) override {
    auto* node = expr->As<Min>();
    CHECK(node);
    auto& oper_b = node->b();
    oper_b = Load::Make(new_rf_tensor_, new_rf_indice_);
  }

  void Visit(const Max* op, Expr* expr) override {
    auto* node = expr->As<Max>();
    CHECK(node);
    auto& oper_b = node->b();
    oper_b = Load::Make(new_rf_tensor_, new_rf_indice_);
  }

  void Visit(const Store* op, Expr* expr) override {
    // insert the new rfactor indice if not exist
    auto* node = expr->As<Store>();
    CHECK(node);
    auto* tensor = node->tensor.As<_Tensor_>();
    CHECK(tensor);
    CHECK_EQ(tensor->name, output_name_)
        << "store name should be same with the schedule block name";
    if (!visit_init_block_) {
      // Indices used to load the rfactor tensor: the store indices plus the
      // rfactor iter var at position rf_axis_.
      new_rf_indice_ = node->indices;
      CHECK_LE(rf_axis_, new_rf_indice_.size())
          << "rf_axis_ should not be greater than tensor indice size "
          << new_rf_indice_.size();
      CHECK(old_rf_iter_var_.defined());
      new_rf_indice_.insert(new_rf_indice_.begin() + rf_axis_,
                            old_rf_iter_var_);
      IRMutator::Visit(&node->value, &node->value);
    }
  }

  void Visit(const For* op, Expr* expr) override {
    auto* node = expr->As<For>();
    CHECK(node);
    // erase the reduce forloops after the init block except the rfactor loop
    if (visit_init_block_ && node->loop_var->name != old_rf_loop_var_->name) {
      erase_reduce_loopvars_.insert(node->loop_var->name);
      auto body = node->body.As<Block>();
      if (body && body->stmts.size() == 1) {
        *expr = body->stmts[0];
      } else {
        *expr = node->body;
      }
      IRMutator::Visit(expr, expr);
    } else {
      IRMutator::Visit(&node->body, &node->body);
    }
  }

 private:
  Expr rf_loop_;
  int rf_axis_;
  Var old_rf_loop_var_;
  Var old_rf_iter_var_;
  std::string output_name_;
  // collect reduce loop vars except rfactor loop var
  std::set<std::string> erase_reduce_loopvars_;
  // True once the most recently visited schedule block name contains "_init".
  bool visit_init_block_ = false;
  Tensor new_rf_tensor_;
  std::vector<Expr> new_rf_indice_;
};

// The struct used to create all stmts after rfactor transformation.
//
// CreateRfAllStmts() replaces the body of the root schedule block with two
// loop nests built from copies of the original one: the rfactor nest produced
// by RfMutator followed by the write-back nest produced by FinalMutator.
struct RfCreater : public ir::IRMutator<> {
 public:
  RfCreater(const Expr& root, const Expr& rf_loop, const int& rf_axis)
      : root_(root), rf_loop_(rf_loop), rf_axis_(rf_axis) {}
  void operator()(Expr* expr) { IRMutator::Visit(expr, expr); }

  //! Build both loop nests, install them as the root block body and return
  //! the new rfactor tensor.
  Expr CreateRfAllStmts() {
    auto* realize = root_.As<ScheduleBlockRealize>();
    CHECK(realize);
    auto* root_block = realize->schedule_block.As<ScheduleBlock>();
    CHECK(root_block);

    // The root body must be a single For, possibly wrapped in a Block that
    // holds exactly one statement.
    Expr root_loop = optim::IRCopy(root_block->body);
    if (auto* wrapper = root_loop.As<Block>()) {
      CHECK_EQ(wrapper->stmts.size(), 1U)
          << "rfactor root should only have one block stmt";
      root_loop = wrapper->stmts[0];
    }
    CHECK(root_loop.As<For>());
    CHECK(rf_loop_.As<For>());

    // create new rfactor forloops
    Expr rfactor_nest = optim::IRCopy(root_loop);
    RfMutator rf_mutator(rf_loop_, rf_axis_);
    rf_mutator(&rfactor_nest);
    VLOG(3) << "After RfMutator, new rf_forloop is\n" << rfactor_nest;
    auto new_rf_tensor = rf_mutator.GetNewRfTensor();

    // create final write-back forloops
    Expr write_back_nest = optim::IRCopy(root_loop);
    FinalMutator final_mutator(rf_loop_, rf_axis_, new_rf_tensor);
    final_mutator(&write_back_nest);
    VLOG(3) << "After FinalMuator, final write-back forloop is\n"
            << write_back_nest;

    // combine the new created rfactor forloops with the final write-back
    // forloops and replace
    root_block->body = Block::Make({rfactor_nest, write_back_nest});
    return new_rf_tensor;
  }

  Expr root_;
  Expr rf_loop_;
  int rf_axis_;
};

/**
 * Apply the rfactor transformation to `rf_loop`.
 *
 * The arguments are validated with CHECKRfactorValidation, then RfCreater
 * rewrites the body of the root block that contains the loop.
 *
 * @param rf_loop The loop to factor out.
 * @param rf_axis Position at which the new rfactor axis is inserted.
 * @return The newly created rfactor tensor.
 */
Expr ScheduleImpl::Rfactor(const Expr& rf_loop, int rf_axis) {
  CHECKRfactorValidation(rf_loop, rf_axis);
  // RfCreater keeps its own copy of the root ScheduleBlockRealize handle.
  RfCreater rf_create(GetRootBlock(rf_loop), rf_loop, rf_axis);
  return rf_create.CreateRfAllStmts();
}

// Rewrites a copy of `root` for a cache-read: the cache block is inserted at
// the position recorded in `info`, and loads of info->read_tensor are
// redirected to info->write_tensor.
struct CacheReadRewriter : public ir::IRMutator<> {
 public:
  //! Return the rewritten copy of `root`; `root` itself is not modified here.
  static Expr Rewrite(const Expr& root, CacheBlockInfo* info) {
    CacheReadRewriter rewriter(root, info);
    Expr rewritten_root = optim::IRCopy(root);
    rewriter(&rewritten_root);
    return rewritten_root;
  }

  void operator()(Expr* expr) { IRMutator::Visit(expr, expr); }

 private:
  explicit CacheReadRewriter(const Expr& root, CacheBlockInfo* info)
      : root_(root), info_(info) {}

  void Visit(const ir::Block* expr, Expr* op) override {
    // Decide before descending whether this is the insertion block.
    const bool is_insertion_block = (*op == info_->loc_block);
    IRMutator::Visit(expr, op);
    if (is_insertion_block) {
      auto& stmts = op->As<Block>()->stmts;
      stmts.insert(stmts.begin() + info_->loc_pos, info_->cache_block);
    }
  }

  void Visit(const ir::Load* expr, Expr* op) override {
    const bool loads_read_tensor = (expr->tensor == Expr(info_->read_tensor));
    IRMutator::Visit(expr, op);
    if (loads_read_tensor) {
      op->As<Load>()->tensor = Expr(info_->write_tensor);
    }
  }

 private:
  /*! \brief The parent scope of the insertion */
  const Expr& root_;
  /*! \brief The info for inserting cache stage */
  CacheBlockInfo* info_;
};

// Rewrites the IR for a cache-write. The mutator first runs over
// info->cache_block with mutate_cache_block set, then over a copy of `root`
// with it cleared; the flag controls which read/write tensor swaps are made.
struct CacheWriteRewriter : public ir::IRMutator<> {
 public:
  //! Mutate info->cache_block in place and return the rewritten copy of
  //! `root`.
  static Expr Rewrite(const Expr& root, CacheBlockInfo* info) {
    CacheWriteRewriter rewriter(root, info);
    Expr rewritten_root = optim::IRCopy(root);
    rewriter.mutate_cache_block = true;
    rewriter(&info->cache_block);
    rewriter.mutate_cache_block = false;
    rewriter(&rewritten_root);

    // Stores that now target read_tensor.
    auto stores_to_read_tensor = ir::CollectIRNodesWithoutTensor(
        rewritten_root,
        [&](const Expr* x) {
          const auto* store = x->As<Store>();
          return store && (store->tensor == Expr(info->read_tensor));
        },
        true);
    if (!stores_to_read_tensor.empty()) {
      // Inside the first such store, loads of write_tensor are pointed back
      // at read_tensor.
      auto loads_of_write_tensor = ir::CollectIRNodesWithoutTensor(
          (*stores_to_read_tensor.begin()), [&](const Expr* x) {
            const auto* load = x->As<Load>();
            return load && (load->tensor == Expr(info->write_tensor));
          });
      for (auto load_ir : loads_of_write_tensor) {
        load_ir.As<Load>()->tensor = Expr(info->read_tensor);
      }
    }
    return rewritten_root;
  }

  void operator()(Expr* expr) { IRMutator::Visit(expr, expr); }

 private:
  explicit CacheWriteRewriter(const Expr& root, CacheBlockInfo* info)
      : root_(root), info_(info) {}

  void Visit(const ir::Block* expr, Expr* op) override {
    // Decide before descending whether this is the insertion block.
    const bool is_insertion_block = (*op == info_->loc_block);
    IRMutator::Visit(expr, op);
    if (is_insertion_block) {
      auto& stmts = op->As<Block>()->stmts;
      stmts.insert(stmts.begin() + info_->loc_pos, info_->cache_block);
    }
  }

  void Visit(const ir::ScheduleBlock* expr, Expr* op) override {
    // Swap the block name between the write tensor's and the read tensor's.
    auto* schedule_block = op->As<ScheduleBlock>();
    if (schedule_block->name == info_->write_tensor->name) {
      schedule_block->name = info_->read_tensor->name;
    } else if (schedule_block->name == info_->read_tensor->name) {
      schedule_block->name = info_->write_tensor->name;
    }
    IRMutator::Visit(expr, op);
  }

  void Visit(const ir::Load* expr, Expr* op) override {
    IRMutator::Visit(expr, op);
    // Loads are only swapped while the cache block itself is being mutated.
    if (!mutate_cache_block) return;
    auto* load = op->As<Load>();
    if (load->tensor == Expr(info_->write_tensor)) {
      load->tensor = Expr(info_->read_tensor);
    } else if (load->tensor == Expr(info_->read_tensor)) {
      load->tensor = Expr(info_->write_tensor);
    }
  }

  void Visit(const ir::Store* expr, Expr* op) override {
    IRMutator::Visit(expr, op);
    auto* store = op->As<Store>();
    if (store->tensor == Expr(info_->write_tensor)) {
      store->tensor = Expr(info_->read_tensor);
    } else if (mutate_cache_block &&
               store->tensor == Expr(info_->read_tensor)) {
      store->tensor = Expr(info_->write_tensor);
    }
  }

 private:
  /*! \brief The parent scope of the insertion */
  const Expr& root_;
  /*! \brief The info for inserting cache stage */
  CacheBlockInfo* info_;
  /*! \brief Are we mutating the cache tensor's block */
  bool mutate_cache_block{true};
};

//! Walks an expression and wraps the body of every ScheduleBlock into an
//! ir::Block whenever the body is not an ir::Block already.
struct ChangeBodyToBlock : public ir::IRMutator<> {
 public:
  //! Convenience entry point: builds a mutator and runs it over `expr`.
  static void Change(Expr* expr) {
    ChangeBodyToBlock body_wrapper;
    body_wrapper(expr);
  }

  void operator()(Expr* expr) { IRMutator::Visit(expr, expr); }

 private:
  void Visit(const ir::ScheduleBlock* expr, Expr* op) override {
    auto* sch_block = op->As<ScheduleBlock>();
    const bool body_is_block = sch_block->body.As<Block>() != nullptr;
    if (!body_is_block) {
      // Wrap the single statement so later passes can insert siblings.
      sch_block->body = Block::Make({sch_block->body});
    }
    IRMutator::Visit(expr, op);
  }
};

/**
 * Returns the device_api of a For node found in the first expr of the module.
 *
 * The For nodes are collected from `exprs.front()` and the first element of
 * the collected set is used. Aborts (CHECK) when the module holds no exprs or
 * when the first expr contains no For node.
 */
DeviceAPI ScheduleImpl::GetDeviceAPI() const {
  auto exprs = this->GetModule().GetExprs();
  // front() on an empty vector is undefined behaviour; fail loudly instead.
  CHECK(!exprs.empty())
      << "GetDeviceAPI requires the module to hold at least one expr.";
  auto find_for_nodes = ir::CollectIRNodesWithoutTensor(
      exprs.front(), [&](const Expr* x) { return x->As<ir::For>(); }, true);
  CHECK(!find_for_nodes.empty());
  return (*find_for_nodes.begin()).As<ir::For>()->device_api;
}

845 846 847
/**
 * Inserts a cache stage for the `read_tensor_index`-th tensor read by `block`.
 * @param block The ScheduleBlockRealize whose read is cached.
 * @param read_tensor_index Index of the read access to cache.
 * @param memory_type Memory type passed on to the cache tensor and block.
 * @return The newly created cache block.
 */
Expr ScheduleImpl::CacheRead(const Expr& block,
                             int read_tensor_index,
                             const std::string& memory_type) {
  CHECK(block.As<ScheduleBlockRealize>());
  Expr root = GetRootBlock(block);
  ChangeBodyToBlock::Change(&root);

  // The access to cache must be a Load.
  Expr read_expr = GetNthAccessExpr(block, read_tensor_index, false);
  CHECK(read_expr.As<ir::Load>());
  auto* load_node = read_expr.As<ir::Load>();
  auto tensor_indices = load_node->indices;

  // The cache tensor is written by the new block and is what gets allocated.
  CacheBlockInfo info;
  info.read_tensor = load_node->tensor.as_tensor_ref();
  info.write_tensor = MakeCacheTensor(info.read_tensor, memory_type);
  info.alloc = info.write_tensor;

  auto read_ranges =
      CalculateTensorRegions(block, tensor_indices, info.read_tensor, root);
  auto cache_block =
      MakeCacheBlock(read_ranges, &info, memory_type, this->GetDeviceAPI());
  FindInsertionPoint(root, &info, false);
  auto new_root = CacheReadRewriter::Rewrite(root, &info);

  // Swap the body of the old root block for the rewritten one.
  const Expr& old_body =
      root.As<ScheduleBlockRealize>()->schedule_block.As<ScheduleBlock>()->body;
  const Expr& new_body = new_root.As<ScheduleBlockRealize>()
                             ->schedule_block.As<ScheduleBlock>()
                             ->body;
  this->Replace(old_body, new_body);
  return cache_block;
}

873 874 875
/**
 * Inserts a cache stage for the `write_buffer_index`-th tensor written by
 * `block`.
 * @param block The ScheduleBlockRealize whose write is cached.
 * @param write_buffer_index Index of the write access to cache.
 * @param memory_type Memory type passed on to the cache tensor and block.
 * @return The single non-root ScheduleBlockRealize found under the root whose
 * tensor carries the cache tensor's name.
 */
Expr ScheduleImpl::CacheWrite(const Expr& block,
                              int write_buffer_index,
                              const std::string& memory_type) {
  CHECK(block.As<ScheduleBlockRealize>());
  Expr root = GetRootBlock(block);
  ChangeBodyToBlock::Change(&root);

  // The access to cache must be a Store.
  Expr write_expr = GetNthAccessExpr(block, write_buffer_index, true);
  CHECK(write_expr.As<ir::Store>());
  auto* store_node = write_expr.As<ir::Store>();
  Tensor write_tensor = store_node->tensor.as_tensor_ref();
  auto tensor_indices = store_node->indices;

  // Here the cache tensor takes the "read" role and is what gets allocated.
  CacheBlockInfo info;
  info.read_tensor = MakeCacheTensor(write_tensor, memory_type);
  info.write_tensor = write_tensor;
  info.alloc = info.read_tensor;

  auto write_ranges =
      CalculateTensorRegions(block, tensor_indices, info.write_tensor, root);
  auto new_block =
      MakeCacheBlock(write_ranges, &info, memory_type, this->GetDeviceAPI());
  FindInsertionPoint(root, &info, true);

  auto new_root = CacheWriteRewriter::Rewrite(root, &info);
  const Expr& old_body =
      root.As<ScheduleBlockRealize>()->schedule_block.As<ScheduleBlock>()->body;
  const Expr& new_body = new_root.As<ScheduleBlockRealize>()
                             ->schedule_block.As<ScheduleBlock>()
                             ->body;
  this->Replace(old_body, new_body);

  // Look up the non-root block realize whose tensor has the cache tensor's
  // name.
  auto find_cache_block = ir::CollectIRNodesWithoutTensor(
      root,
      [&](const Expr* x) {
        return x->As<ir::ScheduleBlockRealize>() &&
               !x->As<ir::ScheduleBlockRealize>()->iter_values.empty() &&
               GetTensor(*x)->name == info.read_tensor->name;
      },
      true);

  CHECK(info.write_tensor->buffer.defined());

  // Replace buffer: every other tensor that shares the written tensor's
  // buffer (matched by buffer name) is re-bound to the cache tensor's buffer.
  auto all_tensors = ir::CollectIRNodesWithoutTensor(root, [&](const Expr* x) {
    return x->as_tensor() && x->as_tensor()->buffer.defined();
  });

  for (Expr tensor_expr : all_tensors) {
    auto* candidate = tensor_expr.as_tensor();
    if (candidate->name == info.write_tensor->name) continue;
    if (!candidate->buffer.defined()) continue;
    if (candidate->buffer->name != info.write_tensor->buffer->name) continue;
    candidate->Bind(info.read_tensor->buffer);
  }

  CHECK_EQ(find_cache_block.size(), 1U);

  return *find_cache_block.begin();
}

//! Mutator that inserts `insert_node` directly before or directly after the
//! first statement equal to `ir_node` that it encounters.
struct InsertExpr : public ir::IRMutator<> {
 public:
  //! Runs the insertion over `expr`. `after_node` selects whether
  //! `insert_node` goes after (true) or before (false) `ir_node`.
  static void Insert(const Expr& ir_node,
                     const Expr& insert_node,
                     bool after_node,
                     Expr* expr) {
    InsertExpr inserter(ir_node, insert_node, after_node);
    inserter(expr);
  }

  void operator()(Expr* expr) { IRMutator::Visit(expr, expr); }

 private:
  explicit InsertExpr(const Expr& ir_node,
                      const Expr& insert_node,
                      bool after_node)
      : ir_node_(ir_node), insert_node_(insert_node), after_node_(after_node) {}

  // When `ir_node_` is a direct statement of this block, insert next to it and
  // stop; otherwise keep descending.
  void Visit(const ir::Block* expr, Expr* op) override {
    const int num_stmts = expr->stmts.size();
    for (int idx = 0; idx < num_stmts; ++idx) {
      if (!(expr->stmts[idx] == ir_node_)) continue;
      auto& stmts = op->As<ir::Block>()->stmts;
      const int insert_at = after_node_ ? idx + 1 : idx;
      stmts.insert(stmts.begin() + insert_at, insert_node_);
      return;
    }
    IRMutator::Visit(expr, op);
  }

  // When `ir_node_` is the whole body of this loop, replace the body with a
  // two-statement block holding the old body and the inserted node.
  void Visit(const ir::For* expr, Expr* op) override {
    if (!(expr->body == ir_node_)) {
      IRMutator::Visit(expr, op);
      return;
    }
    auto* for_node = op->As<ir::For>();
    if (after_node_) {
      for_node->body = ir::Block::Make({for_node->body, insert_node_});
    } else {
      for_node->body = ir::Block::Make({insert_node_, for_node->body});
    }
  }

 private:
  const Expr& ir_node_;
  const Expr& insert_node_;
  bool after_node_;
};

void ScheduleImpl::SyncThreads(const Expr& ir_node, bool after_node) {
  CHECK(ir_node.As<ScheduleBlockRealize>() || ir_node.As<ir::For>());
  auto root = GetRootBlock(ir_node);
  ChangeBodyToBlock::Change(&root);
  Expr sync_threads = runtime::IntrinsicCall(Void(), "__syncthreads", {});
  InsertExpr::Insert(ir_node, sync_threads, after_node, &root);
  return;
}

/**
 * Replace a For node to another For node.
 * @param src_sref The For node to be changed.
 * @param tgt_stmt The For node we want.
 */
void ScheduleImpl::Replace(const Expr& src_sref, const Expr& tgt_stmt) {
997 998 999 1000
  CHECK(src_sref.As<ir::For>() || src_sref.As<ir::Block>() ||
        src_sref.As<ir::ScheduleBlockRealize>());
  CHECK(tgt_stmt.As<ir::For>() || tgt_stmt.As<ir::Block>() ||
        tgt_stmt.As<ir::ScheduleBlockRealize>());
1001 1002 1003 1004
  if (src_sref == tgt_stmt) {
    return;
  }
  struct ForLoopMutator : public ir::IRMutator<> {
1005 1006
    ForLoopMutator(const Expr& source, const Expr& target)
        : source_(source), target_(target) {}
1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050

    void operator()(Expr* expr) { ir::IRMutator<>::Visit(expr, expr); }

    void Visit(const ir::For* op, Expr* expr) override {
      if (*expr == source_) {
        *expr = target_;
        return;
      }
      ir::IRMutator<>::Visit(op, expr);
    }

    void Visit(const ir::ScheduleBlockRealize* op, Expr* expr) override {
      if (*expr == source_) {
        *expr = target_;
        return;
      }
      ir::IRMutator<>::Visit(op, expr);
    }

    void Visit(const ir::Block* op, Expr* expr) override {
      if (*expr == source_) {
        *expr = target_;
        return;
      }
      ir::IRMutator<>::Visit(op, expr);
    }

    const Expr& source_;
    const Expr& target_;
  };
  auto exprs = module_expr_.GetExprs();
  ForLoopMutator mutator(src_sref, tgt_stmt);
  for (auto& i : exprs) {
    mutator(&i);
  }
}

/**
 * Reorders the given loops.
 * @param loops The loops in their desired order.
 * @return The newly constructed loop that replaced the top of the reordered
 * range, or a null Expr when fewer than two loops are given.
 */
Expr ScheduleImpl::Reorder(const std::vector<Expr>& loops) {
  // With zero or one loop there is nothing to reorder.
  if (loops.size() < 2) {
    return Expr{nullptr};
  }
  VLOG(4) << "Before Reorder, ir is:\n" << loops[0];

  std::set<Expr, CompExpr> loop_set = CollectLoopsToSet(loops);
  auto range_ends = GetBoundaryOfReorderRange(loop_set);
  Expr outer_end = range_ends.first;
  Expr inner_end = range_ends.second;

  // Everything between the two ends takes part in the rebuilt chain.
  std::vector<Expr> loop_chain = GetLoopsInRange(outer_end, inner_end);
  std::vector<Expr> if_nodes = GetIfThenElseInRange(outer_end, inner_end);
  Expr reordered = ConstructNewLoopChain(loop_chain, loops, loop_set, if_nodes);
  this->Replace(outer_end, reordered);

  VLOG(4) << "After Reorder, ir is:\n" << reordered;
  return reordered;
}

1063 1064
/**
 * Reorders loops of the block named `block_name`, selected by index.
 * @param block_name Name of the block whose loops are looked up.
 * @param loops_index Indices into the block's loops, in the desired order.
 * Each index must lie in [0, number of loops).
 * @return The result of Reorder on the selected loops.
 */
Expr ScheduleImpl::Reorder(const std::string& block_name,
                           const std::vector<int>& loops_index) {
  std::vector<Expr> all_loops = this->GetLoops(block_name);

  // Translate every index into the loop it refers to, validating as we go.
  std::vector<Expr> selected_loops;
  selected_loops.reserve(loops_index.size());
  for (int i : loops_index) {
    CHECK_LT(i, (int)all_loops.size())
        << "The loop index in Reorder should be less than total loop's number.";
    CHECK_GE(i, 0) << "The loop index in Reorder should be >= 0.";
    selected_loops.push_back(all_loops[i]);
  }
  return this->Reorder(selected_loops);
}

1077 1078
/**
 * Reorders loops of `block`, selected by index.
 * @param block The block whose loops are looked up.
 * @param loops_index Indices into the block's loops, in the desired order.
 * Each index must lie in [0, number of loops).
 * @return The result of Reorder on the selected loops.
 */
Expr ScheduleImpl::Reorder(const Expr& block,
                           const std::vector<int>& loops_index) {
  std::vector<Expr> all_loops = this->GetLoops(block);

  // Translate every index into the loop it refers to, validating as we go.
  std::vector<Expr> selected_loops;
  selected_loops.reserve(loops_index.size());
  for (int i : loops_index) {
    CHECK_LT(i, (int)all_loops.size())
        << "The loop index in Reorder should be less than total loop's number.";
    CHECK_GE(i, 0) << "The loop index in Reorder should be >= 0.";
    selected_loops.push_back(all_loops[i]);
  }
  return this->Reorder(selected_loops);
}

/**
 * Finds the root ScheduleBlockRealize of the module expr that contains `expr`.
 *
 * Each module expr is searched for a node of the same node type that compares
 * equal to `expr`. The containing module expr must be a Block holding exactly
 * one ScheduleBlockRealize statement; that statement is returned.
 * Aborts when the module has no exprs or when `expr` is not found.
 */
Expr ScheduleImpl::GetRootBlock(const Expr& expr) const {
  auto exprs = this->GetModule().GetExprs();
  // The fatal message below prints exprs[0]; make sure it exists.
  CHECK(!exprs.empty())
      << "GetRootBlock requires the module to hold at least one expr.";
  for (auto& it_expr : exprs) {
    auto find_expr = ir::CollectIRNodesWithoutTensor(
        it_expr,
        [&](const Expr* x) {
          return x->node_type() == expr.node_type() && *x == expr;
        },
        true);
    if (!find_expr.empty()) {
      CHECK(it_expr.As<ir::Block>());
      CHECK_EQ(it_expr.As<ir::Block>()->stmts.size(), 1U);
      CHECK(it_expr.As<ir::Block>()->stmts[0].As<ir::ScheduleBlockRealize>());
      return it_expr.As<ir::Block>()->stmts[0];
    }
  }
  LOG(FATAL) << "Didn't find expr \n"
             << expr << "in ScheduleImpl:\n"
             << exprs[0];
}

// Helper that rebuilds a loop so that a given block (wrapped in freshly
// created loops) is placed inside it; the result replaces the old For node.
struct LoopReconstructor : public ir::IRMutator<> {
 public:
  explicit LoopReconstructor(const Expr& root,
                             const Expr& block,
                             const Expr& loop)
      : root_(root), block_(block), loop_(loop) {}

  void operator()(Expr* expr) { IRMutator::Visit(expr, expr); }

  /*! \brief Builds `new_block_` and `new_loop_`.
   *
   * One new serial loop is created per iter range, except that ranges whose
   * extent equals 1 get no loop unless `keep_unit_loops` is set. The new
   * loop nest is then inserted into a copy of `loop_`.
   *
   * \param iter_ranges The ranges the moved block has to iterate over.
   * \param keep_unit_loops Whether loops with extent 1 are still created.
   * \param inserted_pos The position index of the new_loop_ body `stmts` to be
   * inserted:
   *        - `index = -1` means inserted into the tail
   *        - otherwise, it should be a index between [0, stmts size)
   * \return The names of the newly created loop vars joined with ",".
   */
  std::string MakeNewLoop(const std::vector<IterRange>& iter_ranges,
                          bool keep_unit_loops,
                          int inserted_pos = -1) {
    const int n_iters = iter_ranges.size();
    std::vector<Var> loop_vars;
    std::vector<Expr> loop_extents;
    std::vector<Expr> iter_values;
    std::vector<std::string> new_var_names;
    loop_vars.reserve(n_iters);
    loop_extents.reserve(n_iters);
    iter_values.reserve(n_iters);

    for (const auto& range : iter_ranges) {
      if (keep_unit_loops || range.extent != Expr(1)) {
        // The var name is numbered by how many loops were created so far.
        std::string var_name =
            common::UniqName("ax" + std::to_string(loop_vars.size()));
        new_var_names.push_back(var_name);
        Var var(var_name, Int(32));
        loop_vars.push_back(var);
        loop_extents.push_back(range.extent);
        iter_values.push_back(common::AutoSimplify(range.min) + var);
      } else {
        // No loop for this range: the iter value is just the range's min.
        iter_values.push_back(common::AutoSimplify(range.min));
      }
    }

    auto schedule_block_node =
        block_.As<ir::ScheduleBlockRealize>()->schedule_block;
    new_block_ = ScheduleBlockRealize::Make(std::move(iter_values),
                                            std::move(schedule_block_node));

    // Wrap the block in the new loops, innermost loop first.
    Expr loop_body = new_block_;
    for (size_t remaining = loop_vars.size(); remaining > 0; --remaining) {
      const size_t level = remaining - 1;
      Var level_var = loop_vars[level];
      Expr level_extent = loop_extents[level];
      if (!loop_body.As<ir::Block>()) loop_body = Block::Make({loop_body});
      loop_body = For::Make(level_var,
                            Expr(0),
                            level_extent,
                            ForType::Serial,
                            loop_.As<ir::For>()->device_api,
                            std::move(loop_body));
    }
    new_loop_ = optim::IRCopy(loop_);

    // Replace the copied Tensor object with the original Tensor object,
    // to ensure that the same Tensor in a AST is the same object.
    std::unordered_map<std::string, ir::Expr> tensors_map;
    ir::CollectIRNodesWithoutTensor(loop_, [&tensors_map](const Expr* x) {
      if (!x->as_tensor()) {
        return false;
      }
      tensors_map.insert({x->as_tensor()->name, *x});
      return true;
    });

    auto copied_stores = ir::CollectIRNodesWithoutTensor(
        new_loop_, [](const Expr* x) { return x->As<ir::Store>(); });
    for (Expr store_expr : copied_stores) {
      auto* store_node = store_expr.As<ir::Store>();
      store_node->tensor =
          tensors_map.at(store_node->tensor.as_tensor()->name);
    }

    auto copied_loads = ir::CollectIRNodesWithoutTensor(
        new_loop_, [](const Expr* x) { return x->As<ir::Load>(); });
    for (Expr load_expr : copied_loads) {
      auto* load_node = load_expr.As<ir::Load>();
      load_node->tensor = tensors_map.at(load_node->tensor.as_tensor()->name);
    }

    InsertBlock(new_loop_, loop_body, inserted_pos);
    return utils::Join(new_var_names, ",");
  }

  // All members below are public: callers read and write them directly.
  /*! \brief The root block */
  Expr root_;
  /*! \brief The given block to be moved */
  Expr block_;
  /*! \brief The given loop the block and its loop nest to be put under */
  Expr loop_;
  /*! \brief The new loop to replace the original loop */
  Expr new_loop_{nullptr};
  /*! \brief The new block realize to the moved block */
  Expr new_block_{nullptr};
  /*! \brief The plan to remove the given block by replacing this loop/block in
   * the AST */
  Expr source_expr{nullptr};
  /*! \brief The plan to remove the given block by replacing to this loop/block
   * in the AST */
  Expr target_expr{nullptr};
};

//! Rewrites every Store/Load whose tensor is named `tensor_name`: the tensor's
//! shape and domain and its buffer's shape become {1}, and the access indices
//! become {0}.
struct FixLocalBufferSize : public ir::IRMutator<> {
 public:
  // `explicit` keeps a plain string from silently converting into a mutator.
  explicit FixLocalBufferSize(const std::string& tensor_name)
      : tensor_name_(tensor_name) {}

  void operator()(Expr* expr) { IRMutator::Visit(expr, expr); }

 private:
  void Visit(const ir::Store* expr, Expr* op) override {
    auto* store = op->As<Store>();
    if (ShrinkIfTarget(&store->tensor)) {
      store->indices = {Expr(0)};
    }
    IRMutator::Visit(expr, op);
  }

  void Visit(const ir::Load* expr, Expr* op) override {
    auto* load = op->As<Load>();
    if (ShrinkIfTarget(&load->tensor)) {
      load->indices = {Expr(0)};
    }
    IRMutator::Visit(expr, op);
  }

  // Shrinks the tensor behind `tensor_expr` to a single element when its name
  // matches `tensor_name_`. Returns whether the tensor matched.
  bool ShrinkIfTarget(Expr* tensor_expr) {
    auto* tensor = tensor_expr->As<_Tensor_>();
    if (tensor->name != tensor_name_) {
      return false;
    }
    // The buffer is dereferenced below; fail with a message rather than
    // dereferencing an undefined buffer.
    CHECK(tensor->buffer.defined())
        << "Tensor " << tensor_name_ << " has no buffer to resize.";
    tensor->shape = {Expr(1)};
    tensor->domain = {Expr(1)};
    tensor->buffer->shape = {Expr(1)};
    return true;
  }

  std::string tensor_name_;
};

1247 1248 1249
/**
 * Gives the tensor stored by `block` a new buffer of the given memory type and
 * binds that buffer to every tensor of the same name (and to its
 * "__reduce_init" companion) across all module exprs.
 * @param block A ScheduleBlockRealize containing exactly one Store.
 * @param memory_type The memory type of the new buffer.
 * @param fixed When true and `memory_type` is "local", FixLocalBufferSize is
 * additionally run over the root block, using the schedule block's name.
 */
void ScheduleImpl::SetBuffer(Expr& block,
                             const std::string& memory_type,
                             bool fixed) {
  CHECK(block.As<ir::ScheduleBlockRealize>());
  auto find_tensor = ir::CollectIRNodesWithoutTensor(
      block, [&](const Expr* x) { return x->As<ir::Store>(); }, true);
  CHECK_EQ(find_tensor.size(), 1U)
      << "One block should only have one Store node!(except for root block)";
  auto& tensor = (*find_tensor.begin()).As<ir::Store>()->tensor;

  const auto tensor_name = tensor.as_tensor_ref()->name;
  const auto reduce_init_name = tensor_name + "__reduce_init";
  tensor.as_tensor_ref()->WithBuffer(memory_type,
                                     "_" + tensor_name + "_temp_buffer");

  // Share the freshly created buffer with every matching tensor.
  auto exprs = this->GetModule().GetExprs();
  for (auto& it_expr : exprs) {
    auto same_name_tensors =
        ir::CollectIRNodesWithoutTensor(it_expr, [&](const Expr* x) {
          return x->as_tensor() && (x->as_tensor()->name == tensor_name ||
                                    x->as_tensor()->name == reduce_init_name);
        });
    for (auto& t : same_name_tensors) {
      CHECK(t.as_tensor());
      t.as_tensor_ref()->Bind(tensor.as_tensor_ref()->buffer);
    }
  }

  // if buffer type == "local"
  if (memory_type != "local" || !fixed) {
    return;
  }
  FixLocalBufferSize mutator(block.As<ir::ScheduleBlockRealize>()
                                 ->schedule_block.As<ir::ScheduleBlock>()
                                 ->name);
  auto root = GetRootBlock(block);
  mutator(&root);
}

void ScheduleImpl::MergeExprs() {
  auto exprs = this->GetModule().GetExprs();
  if (exprs.size() == 1U) return;
  CHECK(exprs[0].As<ir::Block>());
  CHECK_EQ(exprs[0].As<ir::Block>()->stmts.size(), 1U);
  CHECK(exprs[0].As<ir::Block>()->stmts[0].As<ir::ScheduleBlockRealize>());
1290 1291 1292 1293 1294
  CHECK(exprs[0]
            .As<ir::Block>()
            ->stmts[0]
            .As<ir::ScheduleBlockRealize>()
            ->schedule_block.As<ir::ScheduleBlock>());
1295
  std::vector<Expr> merged_block;
1296 1297 1298 1299 1300 1301
  merged_block.push_back(exprs[0]
                             .As<ir::Block>()
                             ->stmts[0]
                             .As<ir::ScheduleBlockRealize>()
                             ->schedule_block.As<ir::ScheduleBlock>()
                             ->body);
1302 1303 1304 1305 1306
  VLOG(3) << "Before merging, exprs[0] is : " << exprs[0];
  for (int i = 1; i < exprs.size(); ++i) {
    auto root_block = ir::CollectIRNodesWithoutTensor(
        exprs[i],
        [&](const Expr* x) {
1307 1308
          return x->As<ir::ScheduleBlockRealize>() &&
                 x->As<ir::ScheduleBlockRealize>()->iter_values.empty();
1309 1310 1311 1312
        },
        true);
    CHECK_EQ(root_block.size(), 1U);
    for (auto& it_block : root_block) {
1313 1314 1315
      auto& block_body = it_block.As<ir::ScheduleBlockRealize>()
                             ->schedule_block.As<ir::ScheduleBlock>()
                             ->body;
1316 1317 1318 1319 1320 1321 1322
      merged_block.push_back(block_body);
    }
  }
  for (auto& block : merged_block) {
    VLOG(3) << "in merged_block, it has " << block;
  }
  auto merged_expr = ir::Block::Make(merged_block);
1323 1324 1325 1326 1327 1328
  exprs[0]
      .As<ir::Block>()
      ->stmts[0]
      .As<ir::ScheduleBlockRealize>()
      ->schedule_block.As<ir::ScheduleBlock>()
      ->body = merged_expr;
1329 1330 1331 1332 1333
  VLOG(3) << "After merging, exprs[0] is : " << exprs[0];
  exprs.erase(exprs.begin() + 1, exprs.end());
  this->SetExprs(exprs);
}

1334 1335 1336
void ScheduleImpl::ComputeAt(const Expr& block,
                             const Expr& loop,
                             bool keep_unit_loops) {
1337 1338 1339 1340 1341 1342 1343 1344 1345 1346
  CHECK(block.As<ir::ScheduleBlockRealize>());
  CHECK(loop.As<ir::For>());
  Expr root = this->GetRootBlock(block);

  VLOG(3) << "Begin ComputeAt of loop:\n" << loop << "\nat block:\n" << root;

  auto producers = GetProducers(block, root);
  auto consumers = GetConsumers(block, root);
  CheckComputeAtValidation(block, loop, root);
  LoopReconstructor reconstructor(root, block, loop);
1347 1348
  LeafBlockRemovalPlan remove_plan(
      block, &reconstructor.source_expr, &reconstructor.target_expr);
1349
  remove_plan(&root);
1350 1351 1352 1353 1354 1355
  auto iter_ranges = CalculateRequiredRegions(block, loop, root, consumers);
  std::string new_var_names =
      reconstructor.MakeNewLoop(iter_ranges, keep_unit_loops, 0);
  auto sch_block_expr = block.As<ir::ScheduleBlockRealize>()->schedule_block;
  sch_block_expr.As<ir::ScheduleBlock>()->attrs.emplace(
      ir::attr::compute_at_extra_var, new_var_names);
1356 1357 1358 1359 1360 1361 1362 1363 1364 1365
  this->Replace(reconstructor.source_expr, reconstructor.target_expr);
  this->Replace(reconstructor.loop_, reconstructor.new_loop_);

  VLOG(3) << "After SimpleComputeAt, ir is:\n" << reconstructor.new_loop_;
}

void ScheduleImpl::SimpleComputeAt(const Expr& block, const Expr& loop) {
  CHECK(block.As<ir::ScheduleBlockRealize>());
  CHECK(loop.As<ir::For>());
  std::vector<Expr> block_loops = this->GetLoops(block);
1366 1367
  Expr root = this->GetRootBlock(block);
  auto loops = GetLoopsOfExpr(loop, root);
1368

1369 1370 1371
  VLOG(3) << "Begin SimpleComputeAt of loop:\n"
          << loop << "\nat block:\n"
          << root;
1372

1373
  auto this_loop = loop;
1374 1375 1376 1377 1378
  auto block_name = GetTensor(block)->name;
  auto this_block = block;
  if (GetLoopExtent(loops[0]) == 1 && GetLoopExtent(block_loops[0]) != 1) {
    this->Split(block_loops[0], {1, -1});
    this_block = this->GetBlock(block_name);
1379 1380
  } else if (GetLoopExtent(loops[0]) != 1 &&
             GetLoopExtent(block_loops[0]) == 1) {
1381
    auto splited = this->Split(loops[0], {1, -1});
1382
    this_loop = splited[1];
1383 1384 1385
  }

  block_loops = this->GetLoops(this_block);
1386 1387
  root = this->GetRootBlock(this_block);
  loops = GetLoopsOfExpr(this_loop, root);
1388 1389 1390 1391 1392 1393 1394

  CHECK_LE(loops.size(), block_loops.size());

  std::vector<Var> replaced_var;
  std::vector<Expr> substitute_expr;
  for (int i = 0; i < loops.size(); ++i) {
    CHECK_EQ(GetLoopExtent(loops[i]), GetLoopExtent(block_loops[i]));
1395 1396 1397 1398
    if (block_loops[i].As<ir::For>()->bind_info().valid() &&
        !loops[i].As<ir::For>()->bind_info().valid()) {
      loops[i].As<ir::For>()->set_bind_info(
          block_loops[i].As<ir::For>()->bind_info());
1399 1400 1401 1402 1403
    }
    replaced_var.push_back(block_loops[i].As<ir::For>()->loop_var);
    substitute_expr.push_back(Expr(loops[i].As<ir::For>()->loop_var));
  }

1404 1405 1406
  Expr result = loops.size() < block_loops.size()
                    ? optim::IRCopy(block_loops[loops.size()])
                    : optim::IRCopy(this_block);
1407 1408 1409 1410 1411 1412
  Expr new_loop = optim::IRCopy(this_loop);

  // Get the body of block_loop under the same loops
  auto body = block_loops.at(loops.size() - 1).As<ir::For>()->body;
  // collect if
  auto if_checker = [](const Expr* x) { return x->As<ir::IfThenElse>(); };
1413
  auto if_set = ir::CollectIRNodesWithoutTensor(body, if_checker);
1414 1415 1416
  for (auto if_expr : if_set) {
    auto checker = [block_name](const Expr* x) {
      return x->As<ir::ScheduleBlockRealize>() &&
1417 1418 1419
             x->As<ir::ScheduleBlockRealize>()
                     ->schedule_block.As<ScheduleBlock>()
                     ->name == block_name;
1420 1421
    };
    if (ir::CollectIRNodesWithoutTensor(if_expr, checker, true).size() > 0) {
1422 1423
      result =
          IfThenElse::Make(if_expr.As<ir::IfThenElse>()->condition, result);
1424 1425 1426 1427 1428 1429 1430
      break;
    }
  }

  ReplaceExpr(&result, replaced_var, substitute_expr);
  // When there are two identical IfThenElse
  if (new_loop.As<ir::For>() && new_loop.As<ir::For>()->body.As<ir::Block>() &&
1431 1432 1433 1434
      new_loop.As<ir::For>()
          ->body.As<ir::Block>()
          ->stmts[0]
          .As<ir::IfThenElse>()) {
1435 1436
    auto if_then_else = new_loop.As<ir::For>()->body.As<ir::Block>()->stmts[0];
    if (result.As<ir::IfThenElse>() &&
1437 1438 1439 1440 1441 1442 1443 1444 1445 1446 1447 1448
        if_then_else.As<ir::IfThenElse>()->condition ==
            result.As<ir::IfThenElse>()->condition) {
      new_loop.As<ir::For>()
          ->body.As<ir::Block>()
          ->stmts[0]
          .As<ir::IfThenElse>()
          ->true_case = ir::Block::Make({result.As<ir::IfThenElse>()->true_case,
                                         new_loop.As<ir::For>()
                                             ->body.As<ir::Block>()
                                             ->stmts[0]
                                             .As<ir::IfThenElse>()
                                             ->true_case});
1449
    } else {
1450 1451
      std::vector<ir::Expr>::iterator pos =
          new_loop.As<ir::For>()->body.As<ir::Block>()->stmts.begin();
1452 1453 1454
      new_loop.As<ir::For>()->body.As<ir::Block>()->stmts.insert(pos, result);
    }
  } else {
1455 1456
    new_loop.As<ir::For>()->body =
        ir::Block::Make({result, new_loop.As<ir::For>()->body});
1457 1458 1459 1460 1461 1462
  }

  Expr source_expr{nullptr};
  Expr target_expr{nullptr};

  LeafBlockRemovalPlan remove_plan(
1463 1464 1465
      result.As<ir::For>() ? block_loops[loops.size()] : this_block,
      &source_expr,
      &target_expr);
1466 1467 1468 1469 1470 1471 1472 1473
  remove_plan(&root);

  this->Replace(source_expr, target_expr);
  this->Replace(this_loop, new_loop);

  VLOG(3) << "After SimpleComputeAt, ir is:\n" << new_loop;
}

1474 1475 1476
void ScheduleImpl::ReverseComputeAt(const Expr& block,
                                    const Expr& loop,
                                    bool keep_unit_loops) {
1477 1478
  CHECK(block.As<ir::ScheduleBlockRealize>());
  CHECK(loop.As<ir::For>());
1479
  Expr root = this->GetRootBlock(block);
1480 1481 1482 1483
  auto producers = GetProducers(block, root);
  auto consumers = GetConsumers(block, root);
  CheckComputeAtValidation(block, loop, root);
  LoopReconstructor reconstructor(root, block, loop);
1484 1485
  LeafBlockRemovalPlan remove_plan(
      block, &reconstructor.source_expr, &reconstructor.target_expr);
1486
  remove_plan(&root);
1487 1488 1489 1490 1491 1492 1493
  auto iter_ranges =
      CalculateRequiredRegions(block, loop, root, producers, false);
  std::string new_var_names =
      reconstructor.MakeNewLoop(iter_ranges, keep_unit_loops, -1);
  auto sch_block_expr = block.As<ir::ScheduleBlockRealize>()->schedule_block;
  sch_block_expr.As<ir::ScheduleBlock>()->attrs.emplace(
      ir::attr::reverse_compute_at_extra_var, new_var_names);
1494 1495 1496 1497 1498 1499 1500 1501 1502 1503 1504 1505 1506 1507 1508 1509 1510 1511
  this->Replace(reconstructor.source_expr, reconstructor.target_expr);
  this->Replace(reconstructor.loop_, reconstructor.new_loop_);
  return;
}

// Runs the mutator over the replacement statement first, then over `expr`.
void BaseInliner::operator()(Expr* expr) {
  Expr* replacement = &tgt_stmt;
  IRMutator::Visit(replacement, replacement);
  IRMutator::Visit(expr, expr);
}

// Replaces the block with `tgt_stmt` when it equals `src_stmt`; any other
// block is visited as usual.
void BaseInliner::Visit(const ir::Block* expr, Expr* op) {
  if (!(*op == src_stmt)) {
    IRMutator::Visit(expr, op);
    return;
  }
  *op = tgt_stmt;
}

1512 1513
bool BaseInliner::UpdateAndCheckIndexVars(const std::vector<Expr>& indices,
                                          int expected_ndim) {
1514 1515 1516 1517 1518 1519 1520 1521 1522 1523 1524 1525 1526 1527 1528 1529 1530 1531 1532 1533 1534 1535 1536 1537 1538 1539 1540 1541 1542 1543 1544 1545 1546 1547 1548 1549 1550 1551 1552 1553 1554 1555 1556 1557
  int n = indices.size();
  if (n != expected_ndim) {
    return false;
  }
  std::vector<Var> result;
  result.reserve(n);
  for (auto& i : indices) {
    if (i.as_var()) {
      result.push_back(i.as_var_ref());
    } else {
      return false;
    }
  }
  int n_distinct = std::set<Var, CompVar>(result.begin(), result.end()).size();
  if (n != n_distinct) {
    return false;
  }
  if (idx_vars_.empty()) {
    idx_vars_ = std::move(result);
  } else {
    if (idx_vars_.size() != result.size()) return false;
    for (int i = 0; i < result.size(); ++i) {
      if (Expr(idx_vars_[i]) != Expr(result[i])) return false;
    }
  }
  return true;
}

/**
 * Sets the substitution that maps each recorded index var to the index
 * expression at the same position of `indices`.
 *
 * The previous substitution is discarded first, so the var/expr lists always
 * describe exactly one access and stay as long as `idx_vars_`.
 * @param indices One expression per recorded index var.
 */
void BaseInliner::SetIndexSubstitution(const std::vector<Expr>& indices) {
  CHECK_EQ(indices.size(), idx_vars_.size());
  // Without clearing, each call would append to the pairs of earlier calls.
  // NOTE(review): how the consumer of these lists treats duplicate vars is
  // not visible here -- confirm stale pairs were never relied upon.
  idx_sub_var_.clear();
  idx_sub_expr_.clear();
  const int n = idx_vars_.size();
  idx_sub_var_.reserve(n);
  idx_sub_expr_.reserve(n);
  for (int i = 0; i < n; ++i) {
    idx_sub_var_.push_back(idx_vars_[i]);
    idx_sub_expr_.push_back(indices[i]);
  }
}

bool ComputeInliner::BodyPatternAllowInline() {
  if (!inlined_store_.defined()) {
    return false;
  }
  CHECK(inlined_store_.As<Store>());
1558 1559
  auto find_vars = ir::CollectIRNodesWithoutTensor(
      inlined_store_, [&](const Expr* x) { return x->as_var(); });
1560 1561 1562 1563 1564 1565 1566 1567 1568 1569 1570 1571 1572 1573 1574 1575 1576 1577 1578 1579 1580 1581 1582 1583 1584 1585 1586 1587
  std::set<Var, CompVar> vars_set;
  for (auto& i : find_vars) vars_set.insert(i.as_var_ref());
  int n_vars = vars_set.size();
  if (!UpdateAndCheckIndexVars(inlined_store_.As<Store>()->indices, n_vars)) {
    return false;
  }
  return true;
}

// A load of the inlined tensor (matched by name) is replaced by the inlined
// value; every other load is visited as usual.
void ComputeInliner::Visit(const ir::Load* expr, Expr* op) {
  const bool loads_inlined_tensor =
      (expr->tensor).as_tensor_ref()->name == inlined_tensor_->name;
  if (!loads_inlined_tensor) {
    IRMutator::Visit(expr, op);
    return;
  }
  *op = ReplaceInlinedTensor(op);
}

//! Builds the expression that replaces a 'Load' of the inlined tensor: a copy
//! of the inlined store's value with the index substitution of this load
//! applied.
Expr ComputeInliner::ReplaceInlinedTensor(Expr* load) {
  CHECK(load->As<ir::Load>());
  const auto& load_indices = load->As<ir::Load>()->indices;
  SetIndexSubstitution(load_indices);
  Expr inlined_value = optim::IRCopy(inlined_store_.As<Store>()->value);
  ReplaceExpr(&inlined_value, idx_sub_var_, idx_sub_expr_);
  return inlined_value;
}

void ScheduleImpl::ComputeInline(const Expr& schedule_block) {
  CHECK(schedule_block.As<ir::ScheduleBlockRealize>());
1588
  Expr root = this->GetRootBlock(schedule_block);
1589 1590 1591 1592
  Expr store = CheckComputeInlineValidationAndGetStore(schedule_block, root);
  ComputeInliner inliner(store.As<ir::Store>()->tensor.as_tensor_ref(), store);
  CHECK(inliner.BodyPatternAllowInline());
  // Create a plan that removes the block to be inlined
1593 1594
  LeafBlockRemovalPlan remove_plan(
      schedule_block, &inliner.src_stmt, &inliner.tgt_stmt);
1595 1596 1597 1598 1599 1600 1601
  remove_plan(&root);
  inliner(&root);
  return;
}

bool ComputeInlineChecker::Check() {
  Expr root = ir_schedule_.GetRootBlock(block_);
1602
  store_ = CheckComputeInlineValidationAndGetStore(block_, root);
1603 1604 1605 1606 1607 1608 1609 1610 1611 1612 1613 1614 1615 1616 1617 1618 1619 1620 1621 1622 1623 1624 1625
  IRMutator::Visit(&root, &root);
  return !should_skip_;
}

//! Set the buffer of `block_` to "shared" and call SyncThreads on the last
//! loop returned by GetLoops(block_).
void ComputeInlineChecker::BuildDataDependency() {
  ir_schedule_.SetBuffer(block_, "shared", true);

  auto enclosing_loops = ir_schedule_.GetLoops(block_);
  ir_schedule_.SyncThreads(enclosing_loops.back(), true);
}

bool ReverseComputeInliner::BodyPatternAllowInline() {
  if (!inlined_store_.defined()) {
    return false;
  }
  if (!inlined_load_.defined()) {
    return false;
  }
  if (!target_store_.defined()) {
    return false;
  }
  CHECK(inlined_store_.As<Store>());
  CHECK(inlined_load_.As<Load>());
  CHECK(target_store_.As<Store>());
1626 1627
  auto find_vars = ir::CollectIRNodesWithoutTensor(
      inlined_store_, [&](const Expr* x) { return x->as_var(); });
1628 1629 1630 1631 1632 1633 1634 1635 1636 1637 1638 1639 1640 1641 1642 1643 1644 1645 1646 1647 1648 1649 1650 1651 1652 1653 1654 1655 1656 1657 1658 1659 1660 1661 1662 1663 1664 1665 1666 1667 1668 1669 1670 1671 1672 1673 1674 1675 1676 1677
  std::set<Var, CompVar> vars_set;
  for (auto& i : find_vars) vars_set.insert(i.as_var_ref());
  int n_vars = vars_set.size();
  if (!UpdateAndCheckIndexVars(inlined_store_.As<Store>()->indices, n_vars)) {
    return false;
  }
  return true;
}

//! Visit a Load node. A load of the inlined tensor is replaced by the value
//! expression of the inlined store; any other load is forwarded to the base
//! mutator.
void ReverseComputeInliner::Visit(const ir::Load* expr, Expr* op) {
  const bool reads_inlined_tensor =
      (expr->tensor).as_tensor_ref()->name == inlined_tensor_->name;
  if (!reads_inlined_tensor) {
    IRMutator::Visit(expr, op);
    return;
  }
  *op = inlined_store_.As<Store>()->value;
}

//! Visit a Store node. A store into the inlined tensor is replaced by the
//! result of ReplaceTargetTensor(); any other store is forwarded to the base
//! mutator.
void ReverseComputeInliner::Visit(const ir::Store* expr, Expr* op) {
  const bool writes_inlined_tensor =
      (expr->tensor).as_tensor_ref()->name == inlined_tensor_->name;
  if (!writes_inlined_tensor) {
    IRMutator::Visit(expr, op);
    return;
  }
  *op = ReplaceTargetTensor(op);
}

//! Record the index substitution for `load` and return a copy of the inlined
//! store's value. Unlike ComputeInliner::ReplaceInlinedTensor, no ReplaceExpr
//! is applied to the copy here.
Expr ReverseComputeInliner::ReplaceInlinedTensor(Expr* load) {
  auto* load_node = load->As<ir::Load>();
  CHECK(load_node);
  SetIndexSubstitution(load_node->indices);
  return optim::IRCopy(inlined_store_.As<Store>()->value);
}

//! Build the replacement for a store into the inlined tensor: a copy of
//! `target_store_` in which every index variable of `inlined_load_` is
//! substituted with the corresponding entry of `idx_vars_`.
//! The `store` argument itself is not read.
Expr ReverseComputeInliner::ReplaceTargetTensor(Expr* store) {
  const auto& load_indices = inlined_load_.As<ir::Load>()->indices;
  CHECK_EQ(load_indices.size(), idx_vars_.size());

  const size_t num_vars = idx_vars_.size();
  idx_sub_var_.reserve(num_vars);
  idx_sub_expr_.reserve(num_vars);
  // Pairs are appended; entries left by an earlier call are not cleared.
  for (size_t pos = 0; pos < num_vars; ++pos) {
    idx_sub_var_.emplace_back(load_indices[pos].as_var_ref());
    idx_sub_expr_.emplace_back(idx_vars_[pos]);
  }

  Expr rewritten_store = optim::IRCopy(target_store_);
  ReplaceExpr(&rewritten_store, idx_sub_var_, idx_sub_expr_);
  return rewritten_store;
}

void ScheduleImpl::ReverseComputeInline(const Expr& schedule_block) {
1678 1679 1680 1681
  Expr root = this->GetRootBlock(schedule_block);
  auto exprs =
      CheckReverseComputeInlineValidationAndGetExprs(schedule_block, root);
  Expr inlined_load = std::get<0>(exprs);
1682
  Expr inlined_store = std::get<1>(exprs);
1683
  Expr target_store = std::get<2>(exprs);
1684
  ReverseComputeInliner inliner(
1685 1686 1687 1688
      inlined_store.As<ir::Store>()->tensor.as_tensor_ref(),
      inlined_store,
      inlined_load,
      target_store);
1689 1690
  CHECK(inliner.BodyPatternAllowInline());
  // Create a plan that removes the block to be inlined
1691 1692
  LeafBlockRemovalPlan remove_plan(
      schedule_block, &inliner.src_stmt, &inliner.tgt_stmt);
1693 1694 1695 1696 1697 1698 1699 1700 1701 1702 1703 1704 1705 1706 1707 1708
  remove_plan(&root);
  inliner(&root);
  inliner(&root);
}

/**
 * Mutator that searches for the IR node directly holding the
 * ScheduleBlockRealize whose schedule block is named `block_name`.
 *
 * Three kinds of holder are recognised: a Block having the realize among its
 * statements, a For whose body is the realize, and a ScheduleBlock whose body
 * is the realize. The first holder found is stored in `target_` as the
 * `Expr*` slot received from the traversal, so the expression passed to
 * operator() must stay alive for as long as `target_` is used. `target_`
 * remains nullptr when nothing matches.
 */
struct FindBlockParent : public ir::IRMutator<> {
 public:
  FindBlockParent(const std::string& block_name) : block_name_(block_name) {}

  //! Start the search from `expr`.
  void operator()(Expr* expr) { IRMutator::Visit(expr, expr); }

 private:
  //! True when `node` is a ScheduleBlockRealize of the block searched for.
  bool IsWantedRealize(const Expr& node) const {
    const auto* realize = node.As<ir::ScheduleBlockRealize>();
    return realize &&
           realize->schedule_block.As<ir::ScheduleBlock>()->name ==
               block_name_;
  }

  void Visit(const ir::Block* expr, Expr* op) override {
    if (target_) return;
    for (auto& stmt : expr->stmts) {
      if (IsWantedRealize(stmt)) {
        target_ = op;
        return;
      }
    }
    IRMutator::Visit(expr, op);
  }

  void Visit(const ir::For* expr, Expr* op) override {
    if (target_) return;
    if (IsWantedRealize(expr->body)) {
      target_ = op;
      return;
    }
    IRMutator::Visit(expr, op);
  }

  void Visit(const ir::ScheduleBlock* expr, Expr* op) override {
    if (target_) return;
    if (IsWantedRealize(expr->body)) {
      target_ = op;
      return;
    }
    IRMutator::Visit(expr, op);
  }

  std::string block_name_;

 public:
  //! Slot of the node that holds the wanted realize, or nullptr.
  ir::Expr* target_{nullptr};
};

/**
 * Wrap the given schedule block in a new serial loop with min 0 and extent 1.
 *
 * The node holding the block is located with FindBlockParent. Depending on
 * whether that holder is a Block, a For or a ScheduleBlock, either the
 * matching statement or the holder's body is replaced by the new loop.
 *
 * @param block A ScheduleBlockRealize whose schedule_block is a
 * ScheduleBlock.
 * @return The newly created For expression. The function aborts through
 * CHECK / LOG(FATAL) when no holder can be found.
 */
Expr ScheduleImpl::AddUnitLoop(const Expr& block) const {
  auto exprs = module_expr_.GetExprs();
  CHECK(block.As<ir::ScheduleBlockRealize>());
  CHECK(block.As<ir::ScheduleBlockRealize>()
            ->schedule_block.As<ir::ScheduleBlock>());
  std::string block_name = block.As<ir::ScheduleBlockRealize>()
                               ->schedule_block.As<ir::ScheduleBlock>()
                               ->name;

  FindBlockParent visitor(block_name);
  // Iterate by reference: the visitor may store the address it is given in
  // `target_` (when the top-level expression itself is the holder). A
  // by-value loop variable would be destroyed on leaving the loop and leave
  // `target_` dangling; elements of `exprs` live until the function returns.
  for (auto& expr : exprs) {
    visitor(&expr);
    if (visitor.target_) {
      break;
    }
  }

  CHECK(visitor.target_) << ", block name : " << block_name << "\n" << exprs;

  // Builds `for (ix = 0; ix < 1; ++ix) { body }` with a unique loop variable.
  auto make_unit_loop = [](const Expr& body) {
    return ir::For::Make(ir::Var(common::UniqName("ix")),
                         ir::Expr(0),
                         ir::Expr(1),
                         ir::ForType::Serial,
                         ir::DeviceAPI::UNK,
                         ir::Block::Make({body}));
  };

  if (auto* parent_block = visitor.target_->As<ir::Block>()) {
    for (auto& stmt : parent_block->stmts) {
      auto* realize = stmt.As<ir::ScheduleBlockRealize>();
      if (realize &&
          realize->schedule_block.As<ir::ScheduleBlock>()->name ==
              block_name) {
        Expr loop = make_unit_loop(GetBlock(block_name));
        stmt = loop;
        return loop;
      }
    }
  } else if (auto* parent_for = visitor.target_->As<ir::For>()) {
    Expr loop = make_unit_loop(parent_for->body);
    parent_for->body = loop;
    return loop;
  } else if (auto* parent_schedule_block =
                 visitor.target_->As<ir::ScheduleBlock>()) {
    Expr loop = make_unit_loop(parent_schedule_block->body);
    parent_schedule_block->body = loop;
    return loop;
  } else {
    LOG(FATAL) << "Can't find block's parent!";
  }
  LOG(FATAL) << "Shouldn't reach code here in AddUnitLoop";
  return Expr{nullptr};
}

std::vector<Expr> ScheduleImpl::GetLoops(const Expr& block) const {
  std::vector<Expr> result;
  auto exprs = module_expr_.GetExprs();
  CHECK(block.As<ir::ScheduleBlockRealize>());
1820 1821 1822 1823 1824
  CHECK(block.As<ir::ScheduleBlockRealize>()
            ->schedule_block.As<ir::ScheduleBlock>());
  std::string block_name = block.As<ir::ScheduleBlockRealize>()
                               ->schedule_block.As<ir::ScheduleBlock>()
                               ->name;
1825 1826 1827 1828 1829

  for (auto& it_expr : exprs) {
    ir::FindLoopsVisitor visitor(block);
    auto find_loops = visitor(&it_expr);
    if (!find_loops.empty()) {
1830 1831 1832
      if (!result.empty())
        LOG(FATAL) << "Find block with name: \n"
                   << block_name << " appeared in more than one AST!";
1833 1834 1835 1836 1837 1838 1839 1840 1841 1842 1843
      result = find_loops;
    }
  }

  if (result.empty()) {
    result.push_back(AddUnitLoop(block));
  }
  return result;
}

std::vector<Expr> ScheduleImpl::GetLoops(const std::string& block_name) const {
1844
  Expr block = this->GetBlock(block_name);
1845 1846 1847 1848 1849 1850 1851 1852 1853 1854 1855 1856 1857 1858 1859 1860 1861 1862 1863 1864 1865 1866 1867 1868 1869 1870 1871 1872 1873 1874 1875 1876
  std::vector<Expr> result = this->GetLoops(block);
  return result;
}

std::vector<Expr> ScheduleImpl::GetAllBlocks() const {
  std::vector<Expr> result;
  auto exprs = module_expr_.GetExprs();
  for (auto& it_expr : exprs) {
    ir::FindBlocksVisitor visitor;
    auto find_blocks = visitor(&it_expr);
    result.insert(result.end(), find_blocks.begin(), find_blocks.end());
  }
  for (auto& it_expr : exprs) {
    VLOG(3) << "it_expr is : " << it_expr;
  }
  CHECK(!result.empty()) << "Didn't find blocks in expr.";
  return result;
}

//! Return the blocks FindBlocksVisitor finds under `expr`, which must be a
//! ScheduleBlockRealize or a For.
std::vector<Expr> ScheduleImpl::GetChildBlocks(const Expr& expr) const {
  const bool is_supported_node =
      expr.As<ir::ScheduleBlockRealize>() || expr.As<ir::For>();
  CHECK(is_supported_node);

  ir::FindBlocksVisitor visitor;
  std::vector<Expr> child_blocks = visitor(&expr);
  return child_blocks;
}

bool ScheduleImpl::HasBlock(const std::string& block_name) const {
  auto exprs = module_expr_.GetExprs();
  for (auto& it_expr : exprs) {
    ir::FindBlocksVisitor visitor(block_name);
    auto find_blocks = visitor(&it_expr);
    if (!find_blocks.empty()) {
1877 1878
      CHECK_EQ(find_blocks.size(), 1U)
          << "There should not be more than 1 block with identical name!";
1879 1880 1881 1882 1883 1884 1885 1886 1887 1888 1889 1890 1891
      return true;
    }
  }
  return false;
}

/**
 * Find the block named `block_name` in the module.
 *
 * The expressions of the module are searched in order; the first expression
 * with a hit must contain exactly one such block, which is returned.
 *
 * @param block_name Name of the schedule block to look for.
 * @return The matching block. When no expression contains the block,
 * LOG(FATAL) is raised.
 */
Expr ScheduleImpl::GetBlock(const std::string& block_name) const {
  Expr result;
  auto exprs = module_expr_.GetExprs();
  for (auto& it_expr : exprs) {
    ir::FindBlocksVisitor visitor(block_name);
    auto find_blocks = visitor(&it_expr);
    if (!find_blocks.empty()) {
      CHECK_EQ(find_blocks.size(), 1U)
          << "There should not be more than 1 block with identical name!";
      result = find_blocks[0];
      return result;
    }
  }
  LOG(FATAL) << "Didn't find a block with name " << block_name
             << " in this ModuleExpr!";
  // Not expected to be reached; gives every path of this non-void function a
  // return statement, as AddUnitLoop does after its own LOG(FATAL).
  return result;
}

1902 1903 1904
void ScheduleImpl::Annotate(const Expr& block,
                            const std::string& key,
                            const attr_t& value) {
1905
  CHECK(block.As<ir::ScheduleBlockRealize>());
1906 1907 1908 1909 1910
  CHECK(block.As<ir::ScheduleBlockRealize>()
            ->schedule_block.As<ir::ScheduleBlock>());
  auto copied_block = optim::IRCopy(block);
  auto* schedule_block = copied_block.As<ir::ScheduleBlockRealize>()
                             ->schedule_block.As<ir::ScheduleBlock>();
1911 1912 1913 1914 1915 1916
  schedule_block->attrs.emplace(key, value);
  this->Replace(block, copied_block);
}

//! Remove the attribute `ann_key` from the schedule block, in place. A
//! warning is logged when the block carries no such attribute.
void ScheduleImpl::Unannotate(Expr& block, const std::string& ann_key) {
  auto* realize = block.As<ir::ScheduleBlockRealize>();
  CHECK(realize);
  auto* schedule_block = realize->schedule_block.As<ir::ScheduleBlock>();
  CHECK(schedule_block);

  // erase(key) reports how many entries were removed; zero means the key was
  // never there.
  if (schedule_block->attrs.erase(ann_key) == 0) {
    LOG(WARNING) << "Can't find annotation with key: " << ann_key;
  }
}

1929 1930
/**
 * Replace a nest of loops by one loop over `flat_i` whose extent is the
 * product of the extents of all given loops.
 *
 * The new loop takes its for-type, device API and body from the last loop in
 * `loops`. Every schedule block found in that body is rewritten so that it
 * has a single iter var `_flat_i` bound to `flat_i`, and the Store/Load
 * nodes on address tensors get a single index. Finally `loops[0]` is
 * replaced by the new loop.
 *
 * @param loops Loops to merge; must not be empty. Presumably ordered from
 * outermost to innermost, each with a constant int32 extent — TODO confirm
 * against callers.
 * @param flat_tensor When true, a tensor whose element count equals the
 * flattened extent is indexed directly by `_flat_i` even if its indices are
 * not exactly the block's iter vars.
 */
void ScheduleImpl::FlattenLoops(const std::vector<Expr>& loops,
                                const bool flat_tensor) {
1931 1932 1933 1934 1935 1936 1937 1938 1939 1940 1941 1942 1943 1944 1945
  CHECK_GT(loops.size(), 0) << "Loops can't be empty!";
  VLOG(4) << "Before FlattenLoops, ir is:\n" << loops[0];
  // compute loop
  // Walk from the last loop to the first: strides[idx] becomes the product of
  // the extents of the loops after loops[idx], `extent` the overall product.
  int extent = 1;
  std::vector<int> strides;
  std::vector<ir::Var> loop_vars(loops.size());
  for (int idx = loops.size() - 1; idx >= 0; --idx) {
    strides.insert(strides.begin(), extent);
    extent *= loops[idx].As<ir::For>()->extent.as_int32();
    loop_vars[idx] = loops[idx].As<ir::For>()->loop_var;
  }
  CHECK_EQ(loops.size(), strides.size());

  // create new loop.
  // `var` is the loop variable of the new loop, `_var` the single iter var
  // given to every rewritten schedule block.
  auto last = loops.back().As<ir::For>();
1946
  auto var = ir::Var("flat_i");
1947
  auto _var = ir::Var("_flat_i");
1948 1949 1950 1951 1952 1953
  auto loop = ir::For::Make(var,
                            ir::Expr(0),
                            ir::Expr(extent),
                            last->for_type(),
                            last->device_api,
                            last->body);
1954 1955 1956 1957 1958 1959 1960 1961 1962 1963

  // map loop var to old loop var.
  // For each old loop variable, build the expression in `_flat_i` that
  // recovers it: divide the remaining index by the loop's stride, then keep
  // the remainder for the following loops. A stride of 1 takes the remaining
  // index as is.
  auto _iter = ir::Expr(_var);
  std::unordered_map<std::string, ir::Expr> loops_to_flat_var_map;
  for (int idx = 0; idx < strides.size(); ++idx) {
    if (strides[idx] == 1) {
      // flat_i_to_loop_var.push_back(_iter);
      loops_to_flat_var_map[loops[idx].As<ir::For>()->loop_var->name] = _iter;
    } else {
      // flat_i_to_loop_var.push_back(_iter / Expr(strides[idx]));
1964 1965 1966
      loops_to_flat_var_map[loops[idx].As<ir::For>()->loop_var->name] =
          _iter / Expr(strides[idx]);
      _iter = _iter % Expr(strides[idx]);
1967 1968 1969 1970
    }
  }

  ir::FindBlocksVisitor visitor;
1971 1972 1973
  auto blocks = visitor(&last->body);
  // True when `indexs` is exactly the list of `loop_vars`, position by
  // position (same length, every index a plain variable with the same name).
  auto can_do_flat = [](const std::vector<Expr>& indexs,
                        const std::vector<Var>& loop_vars) {
1974 1975 1976 1977 1978 1979 1980 1981 1982 1983 1984 1985 1986 1987 1988 1989 1990 1991 1992
    if (indexs.size() != loop_vars.size()) {
      return false;
    }

    for (int idx = 0; idx < indexs.size(); ++idx) {
      if (!indexs[idx].as_var()) {
        return false;
      } else {
        auto var = indexs[idx].as_var_ref();
        if (var->name != loop_vars[idx]->name) {
          return false;
        }
      }
    }
    return true;
  };

  // change blocks iter value/iter var
  for (auto& block : blocks) {
1993
    auto block_realize = block.As<ir::ScheduleBlockRealize>();
1994 1995 1996 1997 1998 1999 2000 2001 2002
    auto schedule_block = block_realize->schedule_block.As<ir::ScheduleBlock>();

    // checkout loops in orders.
    // Each iter value must be either the loop variable at the same position
    // or the integer constant 0. NOTE(review): `var_names` is never used.
    std::vector<std::string> var_names = {};
    CHECK_GE(block_realize->iter_values.size(), loop_vars.size())
        << "the number of iter bind values must be more than loop vars!";
    for (int idx = 0; idx < block_realize->iter_values.size(); ++idx) {
      auto& iter = block_realize->iter_values[idx];
      if (iter.is_var()) {
2003 2004
        CHECK_EQ(iter.as_var_ref()->name, loop_vars[idx]->name)
            << "loops is not the same order with tensor!";
2005 2006 2007 2008 2009 2010
      } else {
        CHECK(iter.As<IntImm>());
        CHECK_EQ(iter.as_int32(), 0);
      }
    }

2011 2012 2013
    // Gather every Store and Load node of the block body.
    auto exprs = ir::CollectIRNodesInOrder(
        schedule_block->body,
        [&](const Expr* x) { return x->As<ir::Store>() || x->As<ir::Load>(); });
2014 2015 2016 2017 2018 2019 2020 2021 2022 2023
    // reverse exprs from last to first.
    std::reverse(std::begin(exprs), std::end(exprs));

    // Parallel lists: var_to_replace[k] (an old iter var of the block) is to
    // be substituted by flat_i_to_loop_var[k].
    std::vector<ir::Var> var_to_replace;
    std::vector<ir::Expr> flat_i_to_loop_var;
    // if iter var is more than flat i to loop, there exist dim = 1.
    for (int idx = 0; idx < block_realize->iter_values.size(); ++idx) {
      if (block_realize->iter_values[idx].is_var()) {
        var_to_replace.push_back(schedule_block->iter_vars[idx]);
        auto var_name = block_realize->iter_values[idx].as_var_ref()->name;
2024 2025
        CHECK(loops_to_flat_var_map.count(var_name))
            << "Can't find var name : " << var_name;
2026 2027 2028 2029 2030 2031 2032 2033 2034 2035 2036 2037 2038 2039 2040 2041
        flat_i_to_loop_var.push_back(loops_to_flat_var_map[var_name]);
      } else {
        CHECK_EQ(block_realize->iter_values[idx].as_int32(), 0);
        // insert var -> 0, to replace var to 0.
        var_to_replace.push_back(schedule_block->iter_vars[idx]);
        flat_i_to_loop_var.push_back(Expr(0));
      }
    }
    CHECK_EQ(var_to_replace.size(), flat_i_to_loop_var.size());

    // Rewrite the indices of every Store/Load on an address tensor.
    for (auto expr : exprs) {
      if (expr.As<ir::Store>()) {
        auto store = expr.As<ir::Store>();
        if (store->is_addr_tensor()) {
          auto t = store->tensor.as_tensor_ref();
          CHECK(!t->reduce_axis.size());
2042 2043 2044 2045 2046 2047 2048 2049 2050
          // Number of elements of the tensor (product of its shape).
          auto tsize = std::accumulate(t->shape.begin(),
                                       t->shape.end(),
                                       1,
                                       [](const int sum, const Expr& expr) {
                                         return sum * expr.as_int32();
                                       });
          // The tensor cannot be indexed directly by `_flat_i` when its
          // element count differs from the flattened extent, or when
          // flat_tensor is off and its indices are not the plain iter vars.
          if ((!flat_tensor &&
               !can_do_flat(store->indices, schedule_block->iter_vars)) ||
              extent != tsize) {
2051 2052 2053 2054 2055 2056 2057 2058 2059 2060 2061 2062 2063 2064 2065 2066 2067 2068 2069
            // just replace indexs
            for (auto& indice : store->indices) {
              if (!indice.is_var()) {
                continue;
              }
              ReplaceExpr(&indice, var_to_replace, flat_i_to_loop_var);
            }
            // compute index and flat tensor.
            // The indices are collapsed into the one expression returned by
            // index().
            store->indices = {store->index()};
            continue;
          }
          // update var and shape
          store->indices = {Expr(_var)};
        }
      } else {
        // Same treatment as above, for a Load node.
        auto load = expr.As<ir::Load>();
        if (load->is_addr_tensor()) {
          auto t = load->tensor.as_tensor_ref();
          CHECK(!t->reduce_axis.size());
2070 2071 2072 2073 2074 2075 2076 2077 2078
          auto tsize = std::accumulate(t->shape.begin(),
                                       t->shape.end(),
                                       1,
                                       [](const int sum, const Expr& expr) {
                                         return sum * expr.as_int32();
                                       });
          if ((!flat_tensor &&
               !can_do_flat(load->indices, schedule_block->iter_vars)) ||
              extent != tsize) {
2079 2080 2081 2082 2083 2084 2085 2086 2087 2088 2089 2090 2091 2092 2093 2094 2095 2096 2097
            // just replace indexs
            for (auto& indice : load->indices) {
              if (!indice.is_var()) {
                continue;
              }
              ReplaceExpr(&indice, var_to_replace, flat_i_to_loop_var);
            }
            // compute index and flat tensor.
            load->indices = {load->index()};
            continue;
          }
          // update var and shape
          load->indices = {Expr(_var)};
        }
      }
    }
    // Substitute the old iter vars everywhere else in the block body.
    ReplaceExpr(&schedule_block->body, var_to_replace, flat_i_to_loop_var);

    // update iter values
2098
    auto iter = ir::Expr(var);
2099 2100 2101 2102
    block_realize->iter_values = {iter};

    // update iter_vars
    schedule_block->iter_vars = {_var};
2103 2104
    CHECK_EQ(block_realize->iter_values.size(),
             schedule_block->iter_vars.size());
2105 2106 2107 2108 2109 2110
  }

  // Put the flattened loop in place of the first loop of the nest.
  this->Replace(loops[0], loop);
  VLOG(4) << "After FlattenLoops, ir is:\n" << loop;
}

2111 2112 2113
void ScheduleImpl::CopyTransformAndLoopInfo(
    const std::string& block_name, const std::string& block_target_name) {
  auto block = this->GetBlock(block_name);
2114 2115 2116 2117
  auto block_target = this->GetBlock(block_target_name);
  this->CopyTransformAndLoopInfo(block, block_target);
}

2118 2119
/**
 * Make `block` use the iter-value bindings and the surrounding loops of
 * `block_target` for as many leading iter vars as are compatible.
 *
 * Two iter vars at the same position are compatible when both upper bounds
 * are constants of equal value and neither var is a reduce axis; matching
 * stops at the first incompatible position. The loops of `block_target`
 * whose variables occur in the taken-over bindings are copied, `block` (or
 * the loop carrying its first non-taken-over iter value) is placed inside
 * them, and the result replaces the first loop of `block`.
 *
 * The module must hold exactly one expression, and at least the first iter
 * var must be compatible; otherwise the function aborts.
 *
 * @param block ScheduleBlockRealize to be rewritten.
 * @param block_target ScheduleBlockRealize whose loop structure is copied.
 */
void ScheduleImpl::CopyTransformAndLoopInfo(const Expr& block,
                                            const Expr& block_target) {
2120 2121 2122 2123
  CHECK(block.As<ir::ScheduleBlockRealize>());
  CHECK(block_target.As<ir::ScheduleBlockRealize>());
  auto exprs = this->GetModule().GetExprs();
  CHECK_EQ(exprs.size(), 1U);
2124 2125 2126 2127 2128 2129 2130
  auto expr = exprs[0];
  auto vars = block.As<ir::ScheduleBlockRealize>()
                  ->schedule_block.As<ir::ScheduleBlock>()
                  ->iter_vars;
  auto vars_target = block_target.As<ir::ScheduleBlockRealize>()
                         ->schedule_block.As<ir::ScheduleBlock>()
                         ->iter_vars;
2131
  auto old_iter_values = block.As<ir::ScheduleBlockRealize>()->iter_values;
2132 2133
  auto iter_values_target =
      block_target.As<ir::ScheduleBlockRealize>()->iter_values;
2134 2135
  // Take over the target's iter values for the leading compatible iter vars.
  std::vector<Expr> new_iter_values;
  for (int i = 0; i < vars.size() && i < vars_target.size(); ++i) {
2136 2137 2138 2139 2140 2141 2142
    CHECK(vars[i]->upper_bound.defined() &&
          vars_target[i]->upper_bound.defined());
    if (vars[i]->upper_bound.is_constant() &&
        vars_target[i]->upper_bound.is_constant() &&
        vars[i]->upper_bound.get_constant() ==
            vars_target[i]->upper_bound.get_constant() &&
        !vars[i]->is_reduce_axis && !vars_target[i]->is_reduce_axis) {
2143 2144 2145 2146 2147 2148 2149
      new_iter_values.push_back(iter_values_target[i]);
      VLOG(3) << "new_iter_values.push_back " << iter_values_target[i];
    } else
      break;
  }

  if (new_iter_values.empty())
2150 2151 2152 2153
    LOG(FATAL) << "Cannot CopyTransformAndLoopInfo since shape[0] of source "
                  "and target is not equal! "
               << vars[0]->upper_bound << " v.s "
               << vars_target[0]->upper_bound;
2154 2155 2156 2157

  // Number of leading iter values taken from the target.
  int changed_loop_num = new_iter_values.size();
  // Names of all variables occurring in the taken-over iter values. The
  // collection call is used only for the side effect inside its predicate.
  std::set<std::string> used_target_loop_vars;
  for (auto& iter_val : new_iter_values) {
2158 2159 2160 2161 2162
    auto find_partial_loop =
        ir::CollectIRNodesWithoutTensor(iter_val, [&](const Expr* x) {
          if (x->as_var()) used_target_loop_vars.insert(x->as_var_ref()->name);
          return x->as_var();
        });
2163 2164 2165 2166 2167 2168 2169 2170
  }
  CHECK(!used_target_loop_vars.empty());
  // For every such variable, find in a copy of the module expression the one
  // For loop that declares it and contains block_target.
  std::vector<Expr> used_target_loops;
  auto expr_copy = optim::IRCopy(expr);
  for (auto& var : used_target_loop_vars) {
    auto find_loop_var = ir::CollectIRNodesWithoutTensor(
        expr_copy,
        [&](const Expr* x) {
2171 2172
          return x->As<ir::For>() && x->As<ir::For>()->loop_var->name == var &&
                 Contains(*x, block_target);
2173 2174 2175 2176 2177 2178
        },
        true);
    CHECK_EQ(find_loop_var.size(), 1U);
    used_target_loops.push_back(*find_loop_var.begin());
    VLOG(3) << "used_target_loops push_back " << used_target_loops.back();
  }
2179 2180 2181 2182
  // Order the loops by the length of their printed form, longest first.
  // NOTE(review): presumably this puts enclosing loops before the loops
  // nested in them — confirm.
  std::sort(
      used_target_loops.begin(), used_target_loops.end(), [&](Expr i, Expr j) {
        return (utils::GetStreamCnt(i).size() > utils::GetStreamCnt(j).size());
      });
2183 2184 2185 2186 2187 2188 2189 2190
  // The iter values that were not taken over keep their old binding, which
  // must be a plain variable.
  for (int i = new_iter_values.size(); i < old_iter_values.size(); ++i) {
    CHECK(old_iter_values[i].as_var());
    new_iter_values.push_back(old_iter_values[i]);
  }
  Expr new_loop;
  VLOG(3) << "changed_loop_num is : " << changed_loop_num;
  VLOG(3) << "old_iter_values.size() is : " << old_iter_values.size();
  if (changed_loop_num >= (int)old_iter_values.size()) {
2191
    // Every iter value was taken over: the new body is a copy of the block
    // itself with the new bindings.
    new_loop = optim::IRCopy(block);
2192 2193 2194
    new_loop.As<ir::ScheduleBlockRealize>()->iter_values = new_iter_values;
  } else {
    // Otherwise the new body is a copy of the loop that declares the first
    // iter value not taken over; the single schedule block found inside that
    // copy receives the new bindings.
    CHECK(old_iter_values[changed_loop_num].as_var());
2195
    auto old_var = old_iter_values[changed_loop_num].as_var_ref();
2196 2197 2198
    auto find_partial_loop = ir::CollectIRNodesWithoutTensor(
        expr,
        [&](const Expr* x) {
2199 2200 2201
          return x->As<ir::For>() &&
                 x->As<ir::For>()->loop_var->name == old_var->name &&
                 Contains(*x, block);
2202 2203 2204
        },
        true);
    CHECK_EQ(find_partial_loop.size(), 1U);
2205
    new_loop = optim::IRCopy(*find_partial_loop.begin());
2206
    auto find_schedule_block = ir::CollectIRNodesWithoutTensor(
2207 2208 2209
        new_loop,
        [&](const Expr* x) { return x->As<ir::ScheduleBlockRealize>(); },
        true);
2210
    CHECK_EQ(find_schedule_block.size(), 1U);
2211
    Expr sch_block = (*find_schedule_block.begin());
2212 2213 2214 2215 2216 2217 2218
    sch_block.As<ir::ScheduleBlockRealize>()->iter_values = new_iter_values;
  }
  VLOG(3) << "new_loop is : " << new_loop;
  CHECK(!used_target_loops.empty());
  Expr res;
  if (used_target_loops.size() == 1) {
    // One target loop: rebuild it with the same attributes around new_loop.
    auto for_loop = used_target_loops[0].As<ir::For>();
2219
    res = For::Make(for_loop->loop_var,
2220 2221 2222 2223 2224 2225 2226 2227
                    for_loop->min,
                    for_loop->extent,
                    for_loop->for_type(),
                    for_loop->device_api,
                    new_loop,
                    for_loop->vectorize_info(),
                    for_loop->bind_info());
  } else {
2228 2229
    // Several target loops: put new_loop into the body of the last loop of
    // the sorted list and use the first one as the result. Both were
    // collected from `expr_copy`.
    Expr outer_loop = used_target_loops.front();
    Expr inner_loop = used_target_loops.back();
2230
    inner_loop.As<ir::For>()->body = Block::Make({new_loop});
2231
    res = outer_loop;
2232 2233 2234 2235 2236 2237 2238
  }
  VLOG(3) << "res is : " << res;
  // Replace the first loop of `block` with the rebuilt nest.
  std::vector<Expr> all_loops = this->GetLoops(block);
  CHECK(!all_loops.empty());
  this->Replace(all_loops[0], res);
}

2239 2240 2241 2242 2243 2244 2245
/**
 * Sample `n` tile factors for a loop.
 *
 * The innermost factor is drawn from the divisors of the loop extent that do
 * not exceed `max_innermost_factor`; the other n - 1 factors come from
 * SampleTile() applied to the remaining quotient.
 *
 * @param rand_seed State of the random engine, passed on to the samplers.
 * @param loop A For loop that starts from 0.
 * @param n Number of factors to produce, at least 2.
 * @param max_innermost_factor Upper limit of the innermost factor, at least
 * 1.
 * @return The n factors as Expr, the innermost factor last.
 */
std::vector<Expr> ScheduleImpl::SamplePerfectTile(
    utils::LinearRandomEngine::StateType* rand_seed,
    const Expr& loop,
    int n,
    int max_innermost_factor) {
  CHECK(loop.As<ir::For>())
      << "Expr param of SamplePerfectTile should be a For loop";
  CHECK_GE(n, 2) << "The number of tile factors should be at least 2";
  CHECK_GE(max_innermost_factor, 1)
      << "The max innermost factor should be at least 1";
  CHECK(common::is_zero(loop.As<ir::For>()->min))
      << "The For loop should start from 0";

  const int loop_extent = GetLoopExtent(loop);

  // Candidate innermost factors, listed from the largest down to 1.
  std::vector<int> candidates;
  for (int factor = max_innermost_factor; factor >= 1; --factor) {
    if (loop_extent % factor != 0) continue;
    candidates.push_back(factor);
  }
  CHECK(!candidates.empty()) << "No innermost factor found";

  // The innermost factor is sampled before the outer factors.
  const int innermost_factor =
      candidates[utils::SampleUniformInt(0, candidates.size(), rand_seed)];
  auto outer_factors =
      SampleTile(rand_seed, n - 1, loop_extent / innermost_factor);

  std::vector<Expr> tile_exprs;
  for (auto& factor : outer_factors) {
    tile_exprs.push_back(Expr(factor));
  }
  tile_exprs.push_back(Expr(innermost_factor));
  return tile_exprs;
}

2270 2271 2272 2273
//! Pick one of `candidates`: the position is the value returned by
//! SampleDiscreteFromDistribution(probs, rand_seed). `candidates` and `probs`
//! must have the same size. The chosen candidate is returned as an Expr.
Expr ScheduleImpl::SampleCategorical(
    utils::LinearRandomEngine::StateType* rand_seed,
    const std::vector<int>& candidates,
    const std::vector<float>& probs) {
  // Both lists have to describe the same set of choices.
  CHECK_EQ(candidates.size(), probs.size())
      << "candidates and probs must have same size.";

  const int picked = utils::SampleDiscreteFromDistribution(probs, rand_seed);
  auto chosen = candidates[picked];
  return Expr(chosen);
}

//! Default constructor: members are default-initialized, no implementation
//! object is created.
IRSchedule::IRSchedule() = default;

2285 2286 2287
//! Construct from a module expression (passed on to ScheduleImpl together
//! with `debug_flag`) and initialise the random seed from `rand_seed`.
IRSchedule::IRSchedule(const ModuleExpr& module_expr,
                       utils::LinearRandomEngine::StateType rand_seed,
                       bool debug_flag)
    : impl_(std::make_unique<ScheduleImpl>(module_expr, debug_flag)) {
  InitSeed(rand_seed);
}

2292 2293 2294 2295 2296
//! Construct by taking over an existing module expression and schedule
//! trace; the random seed is initialised from `rand_seed`.
IRSchedule::IRSchedule(ir::ModuleExpr&& mod_expr,
                       ScheduleDesc&& trace,
                       utils::LinearRandomEngine::StateType rand_seed)
    : impl_(std::make_unique<ScheduleImpl>(std::move(mod_expr))),
      trace_(std::move(trace)) {
  InitSeed(rand_seed);
}

//! Copy constructor: the module of `other` is copied with IRCopy, the trace
//! is copied, and the seed is obtained from other.ForkSeed().
IRSchedule::IRSchedule(const IRSchedule& other)
    : impl_(std::make_unique<ScheduleImpl>(optim::IRCopy(other.GetModule()))),
      trace_(other.trace_) {
  InitSeed(other.ForkSeed());
}

//! Copy assignment: same steps as the copy constructor — IRCopy of the
//! source module, copy of the trace, seed taken from src.ForkSeed().
IRSchedule& IRSchedule::operator=(const IRSchedule& src) {
  auto copied_impl =
      std::make_unique<ScheduleImpl>(optim::IRCopy(src.GetModule()));
  impl_ = std::move(copied_impl);
  trace_ = src.trace_;
  InitSeed(src.ForkSeed());
  return *this;
}

2313 2314
//! Move constructor: takes over the implementation object and the trace of
//! `other`, then initialises the seed from other.ForkSeed().
IRSchedule::IRSchedule(IRSchedule&& other)
    : impl_(std::move(other.impl_)), trace_(std::move(other.trace_)) {
  InitSeed(other.ForkSeed());
}

//! Move assignment: takes over the implementation object and the trace of
//! `src`, then re-initialises the seed from src.ForkSeed().
IRSchedule& IRSchedule::operator=(IRSchedule&& src) {
  impl_ = std::move(src.impl_);
  trace_ = std::move(src.trace_);
  InitSeed(src.ForkSeed());
  return *this;
}

//! Destructor, defined in this translation unit; performs no work beyond
//! destroying the members.
IRSchedule::~IRSchedule() = default;

//! Store the seed after passing it through
//! LinearRandomEngine::NormalizeState().
void IRSchedule::InitSeed(utils::LinearRandomEngine::StateType rand_seed) {
  rand_seed_ = utils::LinearRandomEngine::NormalizeState(rand_seed);
}

2331 2332 2333
//! Return the state produced by utils::ForkRandomState() for this
//! schedule's seed, which is handed over by address.
utils::LinearRandomEngine::StateType IRSchedule::ForkSeed() const {
  auto forked_state = utils::ForkRandomState(&rand_seed_);
  return forked_state;
}
2334 2335 2336 2337 2338 2339 2340 2341 2342 2343 2344 2345 2346 2347 2348 2349 2350 2351 2352 2353 2354 2355 2356

//! Forward the expressions to the implementation. Not recorded in the trace.
void IRSchedule::SetExprs(const std::vector<Expr>& exprs) {
  impl_->SetExprs(exprs);
}

//! Return the module held by the implementation. Not recorded in the trace.
const ModuleExpr& IRSchedule::GetModule() const {
  const ModuleExpr& module = impl_->GetModule();
  return module;
}

//! Ask the implementation whether a block named `block_name` exists. Not
//! recorded in the trace.
bool IRSchedule::HasBlock(const std::string& block_name) const {
  const bool found = impl_->HasBlock(block_name);
  return found;
}

void IRSchedule::MergeExprs() {
  impl_->MergeExprs();
  trace_.Append(ScheduleDesc::Step("MergeExprs", {}, {}, {}));
}

std::vector<Expr> IRSchedule::GetLoops(const Expr& block) const {
  auto results = impl_->GetLoops(block);
2357 2358
  trace_.Append(ScheduleDesc::Step(
      "GetLoops", {{"block", std::vector<Expr>({block})}}, {}, results));
2359 2360 2361 2362 2363
  return results;
}

std::vector<Expr> IRSchedule::GetLoops(const std::string& block_name) const {
  auto results = impl_->GetLoops(block_name);
2364 2365
  trace_.Append(ScheduleDesc::Step(
      "GetLoopsWithName", {}, {{"block_name", block_name}}, results));
2366 2367 2368 2369 2370 2371 2372 2373 2374 2375 2376
  return results;
}

std::vector<Expr> IRSchedule::GetAllBlocks() const {
  auto results = impl_->GetAllBlocks();
  trace_.Append(ScheduleDesc::Step("GetAllBlocks", {}, {}, results));
  return results;
}

std::vector<Expr> IRSchedule::GetChildBlocks(const Expr& expr) const {
  auto results = impl_->GetChildBlocks(expr);
2377 2378
  trace_.Append(ScheduleDesc::Step(
      "GetChildBlocks", {{"expr", std::vector<Expr>({expr})}}, {}, results));
2379 2380 2381 2382 2383
  return results;
}

// Fetches the block named `block_name` from the implementation and records a
// "GetBlock" step with the name as attribute and the block as output.
Expr IRSchedule::GetBlock(const std::string& block_name) const {
  auto block = impl_->GetBlock(block_name);
  ScheduleDesc::Step step(
      "GetBlock", {}, {{"block_name", block_name}}, {block});
  trace_.Append(std::move(step));
  return block;
}

2389 2390 2391 2392 2393
std::vector<Expr> IRSchedule::Split(const Expr& loop,
                                    const std::vector<int>& factors) {
  std::vector<Expr> decision = SamplePerfectTile(
      loop, factors.size(), loop.As<ir::For>()->extent.as_int32(), factors);
  auto results = Split(loop, decision);
2394 2395 2396
  return results;
}

2397 2398 2399
std::vector<Expr> IRSchedule::Split(const std::string& block_name,
                                    int loop_index,
                                    const std::vector<int>& factors) {
2400 2401
  std::vector<Expr> all_loops = this->GetLoops(block_name);
  Expr loop_expr;
2402 2403
  CHECK_LT(loop_index, (int)all_loops.size())
      << "The loop index in Split should be less than total loop's number.";
2404 2405 2406 2407 2408 2409
  CHECK_GE(loop_index, 0) << "The loop index in Split should be >= 0.";
  loop_expr = all_loops[loop_index];

  return this->Split(loop_expr, factors);
}

2410 2411
std::vector<Expr> IRSchedule::Split(const Expr& loop,
                                    const std::vector<Expr>& factors) {
2412
  std::vector<int> int_factors;
2413 2414 2415 2416
  std::transform(factors.begin(),
                 factors.end(),
                 std::back_inserter(int_factors),
                 [](Expr x) { return x.as_int32(); });
2417
  auto results = impl_->Split(loop, int_factors);
2418 2419 2420 2421 2422
  trace_.Append(ScheduleDesc::Step(
      "Split",
      {{"loop", std::vector<Expr>({loop})}, {"factors", factors}},
      {},
      results));
2423 2424 2425 2426 2427 2428 2429 2430 2431
  return results;
}

// Fuses `loops` through the implementation and records a "Fuse" step with
// the loops as input and the fused result as output.
Expr IRSchedule::Fuse(const std::vector<Expr>& loops) {
  auto fused = impl_->Fuse(loops);
  ScheduleDesc::Step step("Fuse", {{"loops", loops}}, {}, {fused});
  trace_.Append(std::move(step));
  return fused;
}

2432 2433
// Fuses the loops selected by `loops_index` for the block named
// `block_name`, and records a "FuseWithName" step with both arguments as
// attributes and the fused result as output.
Expr IRSchedule::Fuse(const std::string& block_name,
                      const std::vector<int>& loops_index) {
  auto fused = impl_->Fuse(block_name, loops_index);
  ScheduleDesc::Step step(
      "FuseWithName",
      {},
      {{"block_name", block_name}, {"loops_index", loops_index}},
      {fused});
  trace_.Append(std::move(step));
  return fused;
}

// Fuses the loops selected by `loops_index` for `block`, and records a
// "FuseWithBlock" step with the block as input, the indices as attribute and
// the fused result as output.
Expr IRSchedule::Fuse(const Expr& block, const std::vector<int>& loops_index) {
  auto fused = impl_->Fuse(block, loops_index);
  ScheduleDesc::Step step("FuseWithBlock",
                          {{"block", std::vector<Expr>({block})}},
                          {{"loops_index", loops_index}},
                          {fused});
  trace_.Append(std::move(step));
  return fused;
}

2452 2453 2454
void IRSchedule::ComputeAt(const Expr& block,
                           const Expr& loop,
                           bool keep_unit_loops) {
2455 2456
  impl_->ComputeAt(block, loop, keep_unit_loops);
  trace_.Append(ScheduleDesc::Step("ComputeAt",
2457 2458
                                   {{"block", std::vector<Expr>({block})},
                                    {"loop", std::vector<Expr>({loop})}},
2459 2460 2461 2462 2463 2464
                                   {{"keep_unit_loops", keep_unit_loops}},
                                   {}));
}

void IRSchedule::SimpleComputeAt(const Expr& block, const Expr& loop) {
  impl_->SimpleComputeAt(block, loop);
2465 2466 2467 2468 2469
  trace_.Append(ScheduleDesc::Step("SimpleComputeAt",
                                   {{"block", std::vector<Expr>({block})},
                                    {"loop", std::vector<Expr>({loop})}},
                                   {},
                                   {}));
2470 2471
}

2472 2473 2474
void IRSchedule::ReverseComputeAt(const Expr& block,
                                  const Expr& loop,
                                  bool keep_unit_loops) {
2475 2476
  impl_->ReverseComputeAt(block, loop, keep_unit_loops);
  trace_.Append(ScheduleDesc::Step("ReverseComputeAt",
2477 2478
                                   {{"block", std::vector<Expr>({block})},
                                    {"loop", std::vector<Expr>({loop})}},
2479 2480 2481 2482 2483 2484
                                   {{"keep_unit_loops", keep_unit_loops}},
                                   {}));
}

// Fetches the root block of `expr` from the implementation and records a
// "GetRootBlock" step with `expr` as input and the root block as output.
Expr IRSchedule::GetRootBlock(const Expr& expr) const {
  auto root_block = impl_->GetRootBlock(expr);
  ScheduleDesc::Step step("GetRootBlock",
                          {{"expr", std::vector<Expr>({expr})}},
                          {},
                          {root_block});
  trace_.Append(std::move(step));
  return root_block;
}

2490 2491 2492
// Applies CacheRead(block, read_buffer_index, memory_type) on the
// implementation and records a "CacheRead" step with the block as input, the
// index and memory type as attributes, and the returned Expr as output.
Expr IRSchedule::CacheRead(const Expr& block,
                           int read_buffer_index,
                           const std::string& memory_type) {
  auto cache_block = impl_->CacheRead(block, read_buffer_index, memory_type);
  ScheduleDesc::Step step(
      "CacheRead",
      {{"block", std::vector<Expr>({block})}},
      {{"read_buffer_index", read_buffer_index}, {"memory_type", memory_type}},
      {cache_block});
  trace_.Append(std::move(step));
  return cache_block;
}

2502 2503 2504
// Applies CacheWrite(block, write_buffer_index, memory_type) on the
// implementation and records a "CacheWrite" step with the block as input,
// the index and memory type as attributes, and the returned Expr as output.
Expr IRSchedule::CacheWrite(const Expr& block,
                            int write_buffer_index,
                            const std::string& memory_type) {
  auto cache_block = impl_->CacheWrite(block, write_buffer_index, memory_type);
  ScheduleDesc::Step step("CacheWrite",
                          {{"block", std::vector<Expr>({block})}},
                          {{"write_buffer_index", write_buffer_index},
                           {"memory_type", memory_type}},
                          {cache_block});
  trace_.Append(std::move(step));
  return cache_block;
}

void IRSchedule::SyncThreads(const Expr& ir_node, bool after_node) {
  impl_->SyncThreads(ir_node, after_node);
2516 2517 2518 2519
  trace_.Append(ScheduleDesc::Step("SyncThreads",
                                   {{"ir_node", std::vector<Expr>({ir_node})}},
                                   {{"after_node", after_node}},
                                   {}));
2520 2521
}

2522 2523 2524
// Applies SetBuffer(block, memory_type, fixed) on the implementation and
// records a "SetBuffer" step with the block as input and the memory type and
// `fixed` flag as attributes.
void IRSchedule::SetBuffer(Expr& block,
                           const std::string& memory_type,
                           bool fixed) {
  impl_->SetBuffer(block, memory_type, fixed);
  // The input vector is built after the implementation call, as before.
  const std::vector<Expr> block_input({block});
  trace_.Append(
      ScheduleDesc::Step("SetBuffer",
                         {{"block", block_input}},
                         {{"memory_type", memory_type}, {"fixed", fixed}},
                         {}));
}

// Reorders `loops` through the implementation and records a "Reorder" step
// with the loops as input and the returned Expr as output.
Expr IRSchedule::Reorder(const std::vector<Expr>& loops) {
  Expr reordered = impl_->Reorder(loops);
  ScheduleDesc::Step step("Reorder", {{"loops", loops}}, {}, {reordered});
  trace_.Append(std::move(step));
  return reordered;
}

2539 2540
// Reorders the loops selected by `loops_index` for the block named
// `block_name`, and records a "ReorderWithName" step with both arguments as
// attributes and the returned Expr as output.
Expr IRSchedule::Reorder(const std::string& block_name,
                         const std::vector<int>& loops_index) {
  Expr reordered = impl_->Reorder(block_name, loops_index);
  ScheduleDesc::Step step(
      "ReorderWithName",
      {},
      {{"block_name", block_name}, {"loops_index", loops_index}},
      {reordered});
  trace_.Append(std::move(step));
  return reordered;
}

2550 2551
// Reorders the loops selected by `loops_index` for `block`, and records a
// "ReorderWithBlock" step with the block as input, the indices as attribute
// and the returned Expr as output.
Expr IRSchedule::Reorder(const Expr& block,
                         const std::vector<int>& loops_index) {
  Expr reordered = impl_->Reorder(block, loops_index);
  ScheduleDesc::Step step("ReorderWithBlock",
                          {{"block", std::vector<Expr>({block})}},
                          {{"loops_index", loops_index}},
                          {reordered});
  trace_.Append(std::move(step));
  return reordered;
}

void IRSchedule::Parallel(const Expr& loop) {
  impl_->Parallel(loop);
2562 2563
  trace_.Append(ScheduleDesc::Step(
      "Parallel", {{"loop", std::vector<Expr>({loop})}}, {}, {}));
2564 2565 2566 2567
}

void IRSchedule::Vectorize(const Expr& loop, int factor) {
  impl_->Vectorize(loop, factor);
2568 2569 2570 2571
  trace_.Append(ScheduleDesc::Step("Vectorize",
                                   {{"loop", std::vector<Expr>({loop})}},
                                   {{"factor", factor}},
                                   {}));
2572 2573 2574 2575
}

void IRSchedule::Unroll(const Expr& loop) {
  impl_->Unroll(loop);
2576 2577
  trace_.Append(ScheduleDesc::Step(
      "Unroll", {{"loop", std::vector<Expr>({loop})}}, {}, {}));
2578 2579 2580 2581
}

void IRSchedule::ComputeInline(const Expr& schedule_block) {
  impl_->ComputeInline(schedule_block);
2582 2583 2584 2585 2586
  trace_.Append(ScheduleDesc::Step(
      "ComputeInline",
      {{"schedule_block", std::vector<Expr>({schedule_block})}},
      {},
      {}));
2587 2588 2589 2590
}

void IRSchedule::ReverseComputeInline(const Expr& schedule_block) {
  impl_->ReverseComputeInline(schedule_block);
2591 2592 2593 2594 2595
  trace_.Append(ScheduleDesc::Step(
      "ReverseComputeInline",
      {{"schedule_block", std::vector<Expr>({schedule_block})}},
      {},
      {}));
2596 2597 2598 2599
}

void IRSchedule::Bind(const Expr& loop, const std::string& thread_axis) {
  impl_->Bind(loop, thread_axis);
2600 2601 2602 2603
  trace_.Append(ScheduleDesc::Step("Bind",
                                   {{"loop", std::vector<Expr>({loop})}},
                                   {{"thread_axis", thread_axis}},
                                   {}));
2604 2605 2606 2607
}

// Applies Rfactor(rf_loop, rf_axis) on the implementation and records an
// "Rfactor" step with the loop as input, `rf_axis` as attribute and the
// returned Expr as output.
Expr IRSchedule::Rfactor(const Expr& rf_loop, int rf_axis) {
  auto rf_result = impl_->Rfactor(rf_loop, rf_axis);
  ScheduleDesc::Step step("Rfactor",
                          {{"rf_loop", std::vector<Expr>({rf_loop})}},
                          {{"rf_axis", rf_axis}},
                          {rf_result});
  trace_.Append(std::move(step));
  return rf_result;
}

2615 2616 2617
void IRSchedule::Annotate(const Expr& block,
                          const std::string& key,
                          const attr_t& value) {
2618 2619
  impl_->Annotate(block, key, value);

2620 2621 2622 2623 2624 2625 2626 2627
#define TRACE_ANNOTATE_ITEM(data_type, step_name)               \
  if (absl::holds_alternative<data_type>(value)) {              \
    trace_.Append(ScheduleDesc::Step(                           \
        #step_name,                                             \
        {{"block", std::vector<Expr>({block})}},                \
        {{"key", key}, {"value", absl::get<data_type>(value)}}, \
        {}));                                                   \
    return;                                                     \
2628 2629 2630 2631 2632 2633 2634 2635 2636 2637 2638 2639
  }
  TRACE_ANNOTATE_ITEM(int, AnnotateIntAttr)
  TRACE_ANNOTATE_ITEM(bool, AnnotateBoolAttr)
  TRACE_ANNOTATE_ITEM(float, AnnotateFloatAttr)
  TRACE_ANNOTATE_ITEM(std::string, AnnotateStringAttr)
#undef TRACE_ANNOTATE_ITEM

  LOG(FATAL) << "Value of attribute:" << key << " input unsupported data type";
}

// Applies Unannotate(block, key) on the implementation and records an
// "Unannotate" step with the block as input and `key` as attribute.
void IRSchedule::Unannotate(Expr& block, const std::string& key) {
  impl_->Unannotate(block, key);
  // The input vector is built after the implementation call, as before.
  const std::vector<Expr> block_input({block});
  trace_.Append(ScheduleDesc::Step(
      "Unannotate", {{"block", block_input}}, {{"key", key}}, {}));
}

2646 2647
void IRSchedule::FlattenLoops(const std::vector<Expr>& loops,
                              const bool force_flat) {
2648
  impl_->FlattenLoops(loops, force_flat);
2649 2650 2651 2652
  trace_.Append(ScheduleDesc::Step("FlattenLoops",
                                   {{"loop", std::vector<Expr>({loops})}},
                                   {{"force_flat", force_flat}},
                                   {}));
2653 2654
}

2655 2656
void IRSchedule::CopyTransformAndLoopInfo(const Expr& block,
                                          const Expr& block_target) {
2657
  impl_->CopyTransformAndLoopInfo(block, block_target);
2658 2659
  // don't support to trace, because we can't ensure both blocks are from the
  // same ModuleExpr
2660 2661
}

2662 2663
void IRSchedule::CopyTransformAndLoopInfo(
    const std::string& block_name, const std::string& block_target_name) {
2664
  impl_->CopyTransformAndLoopInfo(block_name, block_target_name);
2665 2666
  // don't support to trace, because we can't ensure both blocks are from the
  // same ModuleExpr
2667 2668
}

2669 2670 2671 2672 2673
std::vector<Expr> IRSchedule::SamplePerfectTile(
    const Expr& loop,
    int n,
    int max_innermost_factor,
    const std::vector<int>& decision) {
2674 2675 2676
  std::vector<Expr> factors;
  std::vector<int> new_decision;
  if (decision.empty()) {
2677 2678 2679 2680 2681 2682
    factors =
        impl_->SamplePerfectTile(&rand_seed_, loop, n, max_innermost_factor);
    std::transform(factors.begin(),
                   factors.end(),
                   std::back_inserter(new_decision),
                   [](Expr x) { return x.as_int32(); });
2683 2684
  } else {
    new_decision = decision;
2685 2686 2687 2688
    std::transform(decision.begin(),
                   decision.end(),
                   std::back_inserter(factors),
                   [](int x) { return Expr(x); });
2689 2690 2691 2692
  }
  trace_.Append(
      ScheduleDesc::Step("SamplePerfectTile",
                         {{"loop", std::vector<Expr>({loop})}},
2693 2694 2695
                         {{"n", n},
                          {"max_innermost_factor", max_innermost_factor},
                          {"decision", new_decision}},
2696 2697 2698 2699
                         factors));
  return factors;
}

2700 2701 2702
void IRSchedule::TagPostSchedule() {
  trace_.Append(ScheduleDesc::Step("TagPostSchedule", {}, {}, {}));
}
2703 2704 2705 2706 2707 2708 2709 2710 2711 2712 2713 2714 2715 2716 2717

// Picks a value and records a "SampleCategorical" step.
//
// With an empty `decision`, the value comes from the implementation's
// SampleCategorical (driven by this schedule's random state) and its int32
// value becomes the recorded decision. With a non-empty `decision`, nothing
// is sampled and the result is the last element of `decision`.
//
// The step records `candidates`, `probs` and the decision as attributes and
// the result as output.
Expr IRSchedule::SampleCategorical(const std::vector<int>& candidates,
                                   const std::vector<float>& probs,
                                   const std::vector<int>& decision) {
  Expr chosen;
  std::vector<int> recorded_decision;
  if (!decision.empty()) {
    recorded_decision = decision;
    // Only the last entry of the decision determines the result.
    chosen = Expr(recorded_decision.back());
  } else {
    chosen = impl_->SampleCategorical(&rand_seed_, candidates, probs);
    recorded_decision.push_back(chosen.as_int32());
  }
  ScheduleDesc::Step step("SampleCategorical",
                          {},
                          {{"candidates", candidates},
                           {"probs", probs},
                           {"decision", recorded_decision}},
                          {chosen});
  trace_.Append(std::move(step));
  return chosen;
}

}  // namespace ir
}  // namespace cinn