// schedule_desc_test.cc
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/cinn/ir/schedule_desc.h"

#include <glog/logging.h>
#include <gtest/gtest.h>

#include "paddle/cinn/cinn.h"
#include "paddle/cinn/common/context.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/ir_schedule.h"
#include "paddle/cinn/lang/lower.h"
#include "paddle/cinn/optim/ir_copy.h"
#include "paddle/cinn/utils/string.h"
#include "paddle/cinn/utils/type_defs.h"

namespace cinn {
namespace ir {

// Return lowerd ir AST for example functions used in this test
33 34 35 36 37
std::vector<ir::LoweredFunc> LowerCompute(
    const std::vector<int>& shape,
    const Target& target,
    bool need_c = false,
    const std::string& operation = "elementwise-copy") {
38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68
  CHECK(shape.size() == 2 || shape.size() == 3) << "shape should be 2 or 3";
  std::vector<Expr> domain;
  for (auto i = 0; i < shape.size(); ++i) {
    domain.emplace_back(shape[i]);
  }

  Placeholder<float> A("A", domain);
  ir::Tensor B, C;

  if (operation == "elementwise-copy") {
    if (domain.size() == 2) {
      B = Compute(
          domain, [&A](Var i, Var j) { return A(i, j); }, "B");
      C = Compute(
          domain, [&B](Var i, Var j) { return B(i, j); }, "C");
    } else {
      B = Compute(
          domain, [&A](Var i, Var j, Var k) { return A(i, j, k); }, "B");
      C = Compute(
          domain, [&B](Var i, Var j, Var k) { return B(i, j, k); }, "C");
    }
  }

  if (operation == "elementwise-add_const") {
    if (domain.size() == 2) {
      B = Compute(
          domain, [&A](Var i, Var j) { return A(i, j) * Expr(2.f); }, "B");
      C = Compute(
          domain, [&B](Var i, Var j) { return B(i, j) + Expr(1.f); }, "C");
    } else {
      B = Compute(
69 70 71
          domain,
          [&A](Var i, Var j, Var k) { return A(i, j, k) * Expr(2.f); },
          "B");
72
      C = Compute(
73 74 75
          domain,
          [&B](Var i, Var j, Var k) { return B(i, j, k) + Expr(1.f); },
          "C");
76 77 78 79
    }
  }

  if (need_c) {
80 81 82 83 84 85 86 87
    return cinn::lang::LowerVec("test_func",
                                CreateStages({A, B, C}),
                                {A, C},
                                {},
                                {},
                                nullptr,
                                target,
                                true);
88 89
  }

90 91
  return cinn::lang::LowerVec(
      "test_func", CreateStages({A, B}), {A, B}, {}, {}, nullptr, target, true);
92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127
}

// Build a fresh IRSchedule whose module holds an IRCopy of every lowered
// function body, so scheduling operates on copies rather than on the bodies
// stored in `lowered_funcs`.
IRSchedule MakeIRSchedule(const std::vector<ir::LoweredFunc>& lowered_funcs) {
  std::vector<Expr> copied_bodies;
  copied_bodies.reserve(lowered_funcs.size());
  for (const auto& lowered_fn : lowered_funcs) {
    copied_bodies.emplace_back(optim::IRCopy(lowered_fn->body));
  }
  return ir::IRSchedule(ir::ModuleExpr(copied_bodies));
}

// Generate C source code for a transformed ModuleExpr.
//
// For each index i, the i-th expr of `module_expr` is IRCopy'd into the body
// of a copy of `lowered_funcs[i]`. The resulting functions are built into a
// module and compiled by CodeGenC (CImpl output, builtin codes not inlined).
// The number of exprs must match the number of lowered functions.
std::string SourceCodeGen(const ModuleExpr& module_expr,
                          const std::vector<ir::LoweredFunc>& lowered_funcs,
                          const Target& target) {
  auto exprs = module_expr.GetExprs();
  CHECK_EQ(exprs.size(), lowered_funcs.size()) << "size of func is not equal";
  // Work on copies so the caller's lowered_funcs keep their original bodies.
  std::vector<ir::LoweredFunc> updated_funcs = optim::IRCopy(lowered_funcs);
  Module::Builder builder("test_module", target);
  for (size_t i = 0; i < lowered_funcs.size(); ++i) {
    updated_funcs[i]->body = optim::IRCopy(exprs.at(i));
    builder.AddFunction(updated_funcs[i]);
  }
  auto module = builder.Build();
  CodeGenC codegen(target);
  codegen.SetInlineBuiltinCodes(false);
  return codegen.Compile(module, CodeGenC::OutputKind::CImpl);
}

class TestScheduleDesc : public ::testing::Test {
 public:
  Target target = common::DefaultHostTarget();
  std::vector<ir::LoweredFunc> lowered_funcs;
  ScheduleDesc trace;
  void SetUp() override { Context::Global().ResetNameId(); }

128 129
  void CheckTracingOutputs(const std::vector<Expr>& base,
                           const ScheduleDesc& trace_desc) {
130 131
    Context::Global().ResetNameId();
    ir::IRSchedule replay_sch = MakeIRSchedule(lowered_funcs);
132 133
    auto traced_outputs =
        ScheduleDesc::ReplayWithProto(trace_desc.ToProto(), &replay_sch);
134 135
    ASSERT_EQ(base.size(), traced_outputs.size());
    for (auto i = 0; i < base.size(); ++i) {
136 137
      ASSERT_EQ(utils::GetStreamCnt(base.at(i)),
                utils::GetStreamCnt(traced_outputs.at(i)));
138 139 140
    }
  }

141 142
  void CheckReplayResult(const ir::IRSchedule& ir_sch,
                         const ScheduleDesc& trace_desc) {
143 144 145 146 147 148 149 150 151 152
    Context::Global().ResetNameId();
    ir::IRSchedule replay_sch = MakeIRSchedule(lowered_funcs);
    trace_desc.Replay(&replay_sch);

    // check the equality of module expr between original schedule
    // and the schedule generated by replaying with tracing ScheduleDesc
    auto lhs_exprs = ir_sch.GetModule().GetExprs();
    auto rhs_exprs = replay_sch.GetModule().GetExprs();
    ASSERT_EQ(lhs_exprs.size(), rhs_exprs.size());
    for (auto i = 0; i < lhs_exprs.size(); ++i) {
153 154
      ASSERT_EQ(utils::GetStreamCnt(lhs_exprs.at(i)),
                utils::GetStreamCnt(rhs_exprs.at(i)));
155 156 157
    }

    // check the equality of source code between them
158 159 160 161
    ASSERT_EQ(
        utils::Trim(SourceCodeGen(ir_sch.GetModule(), lowered_funcs, target)),
        utils::Trim(
            SourceCodeGen(replay_sch.GetModule(), lowered_funcs, target)));
162 163 164 165
  }
};

// Fuse + SamplePerfectTile + Split on block "B", done twice: first fusing
// loops {0, 1} by block name, then fusing all loops returned by GetLoops.
// Both the manually appended trace and the trace recorded by ir_sch must
// replay to the same outputs and the same final module.
TEST_F(TestScheduleDesc, Append_Replay) {
  lowered_funcs = LowerCompute({32, 32}, target);
  ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs);

  auto fused = ir_sch.Fuse("B", {0, 1});
  trace.Append(ScheduleDesc::Step("FuseWithName",
                                  {},
                                  {{"block_name", std::string("B")},
                                   {"loops_index", std::vector<int>({0, 1})}},
                                  {fused}));
  auto sample = ir_sch.SamplePerfectTile(fused, 2, 1, {4, -1});
  trace.Append(ScheduleDesc::Step("SamplePerfectTile",
                                  {{"loop", std::vector<Expr>({fused})}},
                                  {{"n", 2},
                                   {"max_innermost_factor", 1},
                                   {"decision", std::vector<int>{4, -1}}},
                                  sample));
  auto splited = ir_sch.Split(fused, sample);
  trace.Append(ScheduleDesc::Step(
      "Split",
      {{"loop", std::vector<Expr>({fused})}, {"factors", sample}},
      {},
      splited));

  auto loops = ir_sch.GetLoops("B");
  trace.Append(ScheduleDesc::Step(
      "GetLoopsWithName", {}, {{"block_name", std::string("B")}}, loops));
  fused = ir_sch.Fuse(loops);
  trace.Append(ScheduleDesc::Step("Fuse", {{"loops", loops}}, {}, {fused}));
  sample = ir_sch.SamplePerfectTile(fused, 2, 1, {256, -1});
  trace.Append(ScheduleDesc::Step("SamplePerfectTile",
                                  {{"loop", std::vector<Expr>({fused})}},
                                  {{"n", 2},
                                   {"max_innermost_factor", 1},
                                   {"decision", std::vector<int>{256, -1}}},
                                  sample));
  splited = ir_sch.Split(fused, sample);
  trace.Append(ScheduleDesc::Step(
      "Split",
      {{"loop", std::vector<Expr>({fused})}, {"factors", sample}},
      {},
      splited));

  // check the equality of results between the ir_sch and replaying of trace
  CheckTracingOutputs(splited, trace);
  CheckReplayResult(ir_sch, trace);
  // check the equality of results between the ir_sch and replaying of its trace
  CheckTracingOutputs(splited, ir_sch.GetTraceDesc());
  CheckReplayResult(ir_sch, ir_sch.GetTraceDesc());
}

// Test cases with the `StepKind` prefix check the correctness of the
// StepKindInfo registered for each step kind.
// GetAllBlocks: replaying the trace must return the same blocks.
TEST_F(TestScheduleDesc, StepKind_GetAllBlocks) {
  lowered_funcs = LowerCompute({32, 32}, target);
  ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs);

  auto all_blocks = ir_sch.GetAllBlocks();
  trace.Append(ScheduleDesc::Step("GetAllBlocks", {}, {}, {all_blocks}));
  CheckTracingOutputs(all_blocks, trace);
  CheckTracingOutputs(all_blocks, ir_sch.GetTraceDesc());
}

// GetChildBlocks: after computing B at C's second loop, query the root block
// of B's second loop and its child blocks; the replay must match.
TEST_F(TestScheduleDesc, StepKind_GetChildBlocks) {
  lowered_funcs = LowerCompute({32, 32, 64}, target, true);
  ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs);

  auto block_b = ir_sch.GetBlock("B");
  trace.Append(ScheduleDesc::Step(
      "GetBlock", {}, {{"block_name", std::string("B")}}, {block_b}));
  auto loops = ir_sch.GetLoops("C");
  trace.Append(ScheduleDesc::Step(
      "GetLoopsWithName", {}, {{"block_name", std::string("C")}}, loops));
  ir_sch.ComputeAt(block_b, loops[1]);
  trace.Append(ScheduleDesc::Step("ComputeAt",
                                  {{"block", std::vector<Expr>({block_b})},
                                   {"loop", std::vector<Expr>({loops[1]})}},
                                  {{"keep_unit_loops", false}},
                                  {}));
  loops = ir_sch.GetLoops("B");
  trace.Append(ScheduleDesc::Step(
      "GetLoopsWithName", {}, {{"block_name", std::string("B")}}, loops));
  auto root_block = ir_sch.GetRootBlock(loops[1]);
  trace.Append(ScheduleDesc::Step("GetRootBlock",
                                  {{"expr", std::vector<Expr>({loops[1]})}},
                                  {},
                                  {root_block}));
  auto childblocks = ir_sch.GetChildBlocks(root_block);
  trace.Append(ScheduleDesc::Step("GetChildBlocks",
                                  {{"expr", std::vector<Expr>({root_block})}},
                                  {},
                                  childblocks));
  CheckTracingOutputs(childblocks, trace);
  CheckTracingOutputs(childblocks, ir_sch.GetTraceDesc());
}

// GetLoops given a block Expr: the replay must return the same loops.
TEST_F(TestScheduleDesc, StepKind_GetLoops) {
  lowered_funcs = LowerCompute({32, 32}, target);
  ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs);

  auto block_b = ir_sch.GetBlock("B");
  trace.Append(ScheduleDesc::Step(
      "GetBlock", {}, {{"block_name", std::string("B")}}, {block_b}));
  auto loops = ir_sch.GetLoops(block_b);
  trace.Append(ScheduleDesc::Step(
      "GetLoops", {{"block", std::vector<Expr>({block_b})}}, {}, loops));
  CheckTracingOutputs(loops, trace);
  CheckTracingOutputs(loops, ir_sch.GetTraceDesc());
}

// GetLoops given a block name: the replay must return the same loops.
TEST_F(TestScheduleDesc, StepKind_GetLoopsWithName) {
  lowered_funcs = LowerCompute({32, 32}, target);
  ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs);

  auto loops = ir_sch.GetLoops("B");
  trace.Append(ScheduleDesc::Step(
      "GetLoopsWithName", {}, {{"block_name", std::string("B")}}, loops));
  CheckTracingOutputs(loops, trace);
  CheckTracingOutputs(loops, ir_sch.GetTraceDesc());
}

// GetBlock by name: the replay must return the same block.
TEST_F(TestScheduleDesc, StepKind_GetBlock) {
  lowered_funcs = LowerCompute({32, 32, 32}, target);
  ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs);

  auto block_b = ir_sch.GetBlock("B");
  trace.Append(ScheduleDesc::Step(
      "GetBlock", {}, {{"block_name", std::string("B")}}, {block_b}));
  CheckTracingOutputs({block_b}, trace);
  CheckTracingOutputs({block_b}, ir_sch.GetTraceDesc());
}
// TODO(SunNy820828449): fix in the future. Since the change to split var
// naming, this case has some problems, so it is disabled for now.
/*
TEST_F(TestScheduleDesc, StepKind_Split) {
  lowered_funcs                         = LowerCompute({32, 32, 32}, target);
  ir::IRSchedule ir_sch_split_base      = MakeIRSchedule(lowered_funcs);
  ir::IRSchedule ir_sch_split           = MakeIRSchedule(lowered_funcs);
  ir::IRSchedule ir_sch_split_with_name = MakeIRSchedule(lowered_funcs);

  // test split with inputs of Expr
  auto loops = ir_sch_split_base.GetLoops("B");
  trace.Append(ScheduleDesc::Step(
      "GetLoopsWithName", {}, {{"block_name", std::string("B")}}, loops));
  auto sample =
      ir_sch_split_base.SamplePerfectTile(loops.front(), 2, 1, {4, -1});
  trace.Append(ScheduleDesc::Step(
      "SamplePerfectTile",
      {{"loop", std::vector<Expr>({loops.front()})}},
      {{"n", 2},
       {"max_innermost_factor", 1},
       {"decision", std::vector<int>{4, -1}}},
      sample));
  auto splited = ir_sch_split_base.Split(loops.front(), sample);
  trace.Append(ScheduleDesc::Step(
      "Split",
      {{"loop", std::vector<Expr>({loops.front()})}, {"factors", sample}},
      {},
      splited));
  CheckTracingOutputs(splited, trace);
  CheckTracingOutputs(splited, ir_sch_split_base.GetTraceDesc());

  // test split with inputs of int
  loops   = ir_sch_split.GetLoops("B");
  splited = ir_sch_split.Split(loops.front(), {4, -1});
  CheckTracingOutputs(splited, trace);
  CheckTracingOutputs(splited, ir_sch_split.GetTraceDesc());

  // test split with block name and inputs of int
  splited = ir_sch_split_with_name.Split("B", 0, {4, -1});
  CheckTracingOutputs(splited, trace);
  CheckTracingOutputs(splited, ir_sch_split_with_name.GetTraceDesc());
}
*/
// Fuse given a list of loop Exprs: the replay must return the same fused loop.
TEST_F(TestScheduleDesc, StepKind_Fuse) {
  lowered_funcs = LowerCompute({32, 32, 64}, target);
  ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs);

  auto loops = ir_sch.GetLoops("B");
  trace.Append(ScheduleDesc::Step(
      "GetLoopsWithName", {}, {{"block_name", std::string("B")}}, loops));
  auto fused = ir_sch.Fuse(loops);
  trace.Append(ScheduleDesc::Step("Fuse", {{"loops", loops}}, {}, {fused}));
  CheckTracingOutputs({fused}, trace);
  CheckTracingOutputs({fused}, ir_sch.GetTraceDesc());
}

// Fuse given a block name and loop indices: the replay must match.
TEST_F(TestScheduleDesc, StepKind_FuseWithName) {
  lowered_funcs = LowerCompute({32, 32, 64}, target);
  ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs);

  auto fused = ir_sch.Fuse("B", {0, 1, 2});
  trace.Append(
      ScheduleDesc::Step("FuseWithName",
                         {},
                         {{"block_name", std::string("B")},
                          {"loops_index", std::vector<int>({0, 1, 2})}},
                         {fused}));
  CheckTracingOutputs({fused}, trace);
  CheckTracingOutputs({fused}, ir_sch.GetTraceDesc());
}

// Fuse given a block Expr and loop indices: the replay must match.
TEST_F(TestScheduleDesc, StepKind_FuseWithBlock) {
  lowered_funcs = LowerCompute({32, 32, 64}, target);
  ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs);

  auto block_b = ir_sch.GetBlock("B");
  trace.Append(ScheduleDesc::Step(
      "GetBlock", {}, {{"block_name", std::string("B")}}, {block_b}));
  auto fused = ir_sch.Fuse(block_b, {0, 1, 2});
  trace.Append(
      ScheduleDesc::Step("FuseWithBlock",
                         {{"block", std::vector<Expr>({block_b})}},
                         {{"loops_index", std::vector<int>({0, 1, 2})}},
                         {fused}));
  CheckTracingOutputs({fused}, trace);
  CheckTracingOutputs({fused}, ir_sch.GetTraceDesc());
}

// ComputeAt of block B at the second loop of C: the replayed module must
// match ir_sch's module.
TEST_F(TestScheduleDesc, StepKind_ComputeAt) {
  lowered_funcs = LowerCompute({32, 32, 64}, target, true);
  ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs);

  auto block_b = ir_sch.GetBlock("B");
  trace.Append(ScheduleDesc::Step(
      "GetBlock", {}, {{"block_name", std::string("B")}}, {block_b}));
  auto loops = ir_sch.GetLoops("C");
  trace.Append(ScheduleDesc::Step(
      "GetLoopsWithName", {}, {{"block_name", std::string("C")}}, loops));
  ir_sch.ComputeAt(block_b, loops[1]);
  trace.Append(ScheduleDesc::Step("ComputeAt",
                                  {{"block", std::vector<Expr>({block_b})},
                                   {"loop", std::vector<Expr>({loops[1]})}},
                                  {{"keep_unit_loops", false}},
                                  {}));
  CheckReplayResult(ir_sch, trace);
  CheckReplayResult(ir_sch, ir_sch.GetTraceDesc());
}

// SimpleComputeAt of block B at the third loop of C: the replayed module must
// match ir_sch's module.
TEST_F(TestScheduleDesc, StepKind_SimpleComputeAt) {
  lowered_funcs = LowerCompute({32, 32, 64}, target, true);
  ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs);

  auto block_b = ir_sch.GetBlock("B");
  trace.Append(ScheduleDesc::Step(
      "GetBlock", {}, {{"block_name", std::string("B")}}, {block_b}));
  auto loops = ir_sch.GetLoops("C");
  trace.Append(ScheduleDesc::Step(
      "GetLoopsWithName", {}, {{"block_name", std::string("C")}}, loops));
  ir_sch.SimpleComputeAt(block_b, loops[2]);
  trace.Append(ScheduleDesc::Step("SimpleComputeAt",
                                  {{"block", std::vector<Expr>({block_b})},
                                   {"loop", std::vector<Expr>({loops[2]})}},
                                  {{"keep_unit_loops", false}},
                                  {}));
  CheckReplayResult(ir_sch, trace);
  CheckReplayResult(ir_sch, ir_sch.GetTraceDesc());
}

// ReverseComputeAt of block C at the second loop of B: the replayed module
// must match ir_sch's module.
TEST_F(TestScheduleDesc, StepKind_ReverseComputeAt) {
  lowered_funcs = LowerCompute({32, 32, 64}, target, true);
  ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs);

  auto block_c = ir_sch.GetBlock("C");
  trace.Append(ScheduleDesc::Step(
      "GetBlock", {}, {{"block_name", std::string("C")}}, {block_c}));
  auto loops = ir_sch.GetLoops("B");
  trace.Append(ScheduleDesc::Step(
      "GetLoopsWithName", {}, {{"block_name", std::string("B")}}, loops));
  ir_sch.ReverseComputeAt(block_c, loops[1]);
  trace.Append(ScheduleDesc::Step("ReverseComputeAt",
                                  {{"block", std::vector<Expr>({block_c})},
                                   {"loop", std::vector<Expr>({loops[1]})}},
                                  {{"keep_unit_loops", false}},
                                  {}));
  CheckReplayResult(ir_sch, trace);
  CheckReplayResult(ir_sch, ir_sch.GetTraceDesc());
}

// GetRootBlock of B's second loop: the replay must return the same block.
TEST_F(TestScheduleDesc, StepKind_GetRootBlock) {
  lowered_funcs = LowerCompute({32, 64}, target);
  ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs);

  auto loops = ir_sch.GetLoops("B");
  trace.Append(ScheduleDesc::Step(
      "GetLoopsWithName", {}, {{"block_name", std::string("B")}}, loops));
  auto root_b = ir_sch.GetRootBlock(loops[1]);
  trace.Append(ScheduleDesc::Step(
      "GetRootBlock", {{"expr", std::vector<Expr>({loops[1]})}}, {}, {root_b}));
  CheckTracingOutputs({root_b}, trace);
  CheckTracingOutputs({root_b}, ir_sch.GetTraceDesc());
}

// CacheRead of block B's read buffer 0 into "local" memory: both the returned
// cache block and the final module must match after replay.
TEST_F(TestScheduleDesc, StepKind_CacheRead) {
  lowered_funcs =
      LowerCompute({32, 64}, target, false, "elementwise-add_const");
  ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs);

  auto block_b = ir_sch.GetBlock("B");
  trace.Append(ScheduleDesc::Step(
      "GetBlock", {}, {{"block_name", std::string("B")}}, {block_b}));
  auto a_cache = ir_sch.CacheRead(block_b, 0, "local");
  trace.Append(ScheduleDesc::Step(
      "CacheRead",
      {{"block", std::vector<Expr>({block_b})}},
      {{"read_buffer_index", 0}, {"memory_type", std::string("local")}},
      {a_cache}));
  CheckTracingOutputs({a_cache}, trace);
  CheckTracingOutputs({a_cache}, ir_sch.GetTraceDesc());
  CheckReplayResult(ir_sch, trace);
  CheckReplayResult(ir_sch, ir_sch.GetTraceDesc());
}

// CacheWrite of block B's write buffer 0 into "local" memory: both the
// returned cache block and the final module must match after replay.
TEST_F(TestScheduleDesc, StepKind_CacheWrite) {
  lowered_funcs =
      LowerCompute({32, 64}, target, false, "elementwise-add_const");
  ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs);

  auto block_b = ir_sch.GetBlock("B");
  trace.Append(ScheduleDesc::Step(
      "GetBlock", {}, {{"block_name", std::string("B")}}, {block_b}));
  auto b_cache = ir_sch.CacheWrite(block_b, 0, "local");
  trace.Append(ScheduleDesc::Step(
      "CacheWrite",
      {{"block", std::vector<Expr>({block_b})}},
      {{"write_buffer_index", 0}, {"memory_type", std::string("local")}},
      {b_cache}));
  CheckTracingOutputs({b_cache}, trace);
  CheckTracingOutputs({b_cache}, ir_sch.GetTraceDesc());
  CheckReplayResult(ir_sch, trace);
  CheckReplayResult(ir_sch, ir_sch.GetTraceDesc());
}

// SyncThreads: cache-write B and C to "local", then insert SyncThreads on
// block C with after_node=false and on block B with the default argument
// (traced as after_node=true). The replayed module must match.
TEST_F(TestScheduleDesc, StepKind_SyncThreads) {
  lowered_funcs = LowerCompute({64, 32}, target, true, "elementwise-add_const");
  ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs);

  auto block_b = ir_sch.GetBlock("B");
  trace.Append(ScheduleDesc::Step(
      "GetBlock", {}, {{"block_name", std::string("B")}}, {block_b}));
  auto b_cache = ir_sch.CacheWrite(block_b, 0, "local");
  trace.Append(ScheduleDesc::Step(
      "CacheWrite",
      {{"block", std::vector<Expr>({block_b})}},
      {{"write_buffer_index", 0}, {"memory_type", std::string("local")}},
      {b_cache}));
  auto block_c = ir_sch.GetBlock("C");
  trace.Append(ScheduleDesc::Step(
      "GetBlock", {}, {{"block_name", std::string("C")}}, {block_c}));
  auto c_cache = ir_sch.CacheWrite(block_c, 0, "local");
  trace.Append(ScheduleDesc::Step(
      "CacheWrite",
      {{"block", std::vector<Expr>({block_c})}},
      {{"write_buffer_index", 0}, {"memory_type", std::string("local")}},
      {c_cache}));
  // Re-fetch C after the cache write rather than reusing the earlier Expr.
  block_c = ir_sch.GetBlock("C");
  trace.Append(ScheduleDesc::Step(
      "GetBlock", {}, {{"block_name", std::string("C")}}, {block_c}));
  ir_sch.SyncThreads(block_c, false);
  trace.Append(ScheduleDesc::Step("SyncThreads",
                                  {{"ir_node", std::vector<Expr>({block_c})}},
                                  {{"after_node", false}},
                                  {}));
  block_b = ir_sch.GetBlock("B");
  trace.Append(ScheduleDesc::Step(
      "GetBlock", {}, {{"block_name", std::string("B")}}, {block_b}));
  ir_sch.SyncThreads(block_b);
  trace.Append(ScheduleDesc::Step("SyncThreads",
                                  {{"ir_node", std::vector<Expr>({block_b})}},
                                  {{"after_node", true}},
                                  {}));

  CheckReplayResult(ir_sch, trace);
  CheckReplayResult(ir_sch, ir_sch.GetTraceDesc());
}

// SetBuffer of block B to "shared" memory with fixed=true: the replayed
// module must match.
TEST_F(TestScheduleDesc, StepKind_SetBuffer) {
  lowered_funcs =
      LowerCompute({32, 64}, target, false, "elementwise-add_const");
  ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs);

  auto block_b = ir_sch.GetBlock("B");
  trace.Append(ScheduleDesc::Step(
      "GetBlock", {}, {{"block_name", std::string("B")}}, {block_b}));
  ir_sch.SetBuffer(block_b, "shared", true);
  trace.Append(ScheduleDesc::Step(
      "SetBuffer",
      {{"block", std::vector<Expr>({block_b})}},
      {{"memory_type", std::string("shared")}, {"fixed", true}},
      {}));
  CheckReplayResult(ir_sch, trace);
  CheckReplayResult(ir_sch, ir_sch.GetTraceDesc());
}

// Reorder given loop Exprs: split two of B's loops (producing five loops),
// then reorder loops[4] and loops[0]. The replayed module must match.
TEST_F(TestScheduleDesc, StepKind_Reorder) {
  lowered_funcs = LowerCompute({32, 64, 12}, target);
  ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs);

  auto loops = ir_sch.GetLoops("B");
  trace.Append(ScheduleDesc::Step(
      "GetLoopsWithName", {}, {{"block_name", std::string("B")}}, loops));
  auto sample = ir_sch.SamplePerfectTile(loops[0], 2, 1, {-1, 4});
  trace.Append(ScheduleDesc::Step("SamplePerfectTile",
                                  {{"loop", std::vector<Expr>({loops[0]})}},
                                  {{"n", 2},
                                   {"max_innermost_factor", 1},
                                   {"decision", std::vector<int>{-1, 4}}},
                                  sample));
  auto splited = ir_sch.Split(loops[0], sample);
  trace.Append(ScheduleDesc::Step(
      "Split",
      {{"loop", std::vector<Expr>({loops[0]})}, {"factors", sample}},
      {},
      splited));

  loops = ir_sch.GetLoops("B");
  trace.Append(ScheduleDesc::Step(
      "GetLoopsWithName", {}, {{"block_name", std::string("B")}}, loops));
  sample = ir_sch.SamplePerfectTile(loops[2], 2, 1, {-1, 2});
  trace.Append(ScheduleDesc::Step("SamplePerfectTile",
                                  {{"loop", std::vector<Expr>({loops[2]})}},
                                  {{"n", 2},
                                   {"max_innermost_factor", 1},
                                   {"decision", std::vector<int>{-1, 2}}},
                                  sample));
  splited = ir_sch.Split(loops[2], sample);
  trace.Append(ScheduleDesc::Step(
      "Split",
      {{"loop", std::vector<Expr>({loops[2]})}, {"factors", sample}},
      {},
      splited));

  loops = ir_sch.GetLoops("B");
  trace.Append(ScheduleDesc::Step(
      "GetLoopsWithName", {}, {{"block_name", std::string("B")}}, loops));
  Expr ret = ir_sch.Reorder({loops[4], loops[0]});
  trace.Append(
      ScheduleDesc::Step("Reorder",
                         {{"loops", std::vector<Expr>({loops[4], loops[0]})}},
                         {},
                         {ret}));
  CheckReplayResult(ir_sch, trace);
  CheckReplayResult(ir_sch, ir_sch.GetTraceDesc());
}

// Reorder given a block Expr and loop indices: split two of B's loops
// (producing five loops), then reorder them via block_b. The block overload is
// called on ir_sch (as FuseWithBlock does) so that both the manual trace and
// ir_sch's recorded trace exercise the ReorderWithBlock step kind.
TEST_F(TestScheduleDesc, StepKind_ReorderWithBlock) {
  lowered_funcs = LowerCompute({32, 32, 64}, target);
  ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs);
  auto loops = ir_sch.GetLoops("B");
  trace.Append(ScheduleDesc::Step(
      "GetLoopsWithName", {}, {{"block_name", std::string("B")}}, loops));
  auto sample = ir_sch.SamplePerfectTile(loops[0], 2, 1, {-1, 4});
  trace.Append(ScheduleDesc::Step("SamplePerfectTile",
                                  {{"loop", std::vector<Expr>({loops[0]})}},
                                  {{"n", 2},
                                   {"max_innermost_factor", 1},
                                   {"decision", std::vector<int>{-1, 4}}},
                                  sample));
  auto splited = ir_sch.Split(loops[0], sample);
  trace.Append(ScheduleDesc::Step(
      "Split",
      {{"loop", std::vector<Expr>({loops[0]})}, {"factors", sample}},
      {},
      splited));

  loops = ir_sch.GetLoops("B");
  trace.Append(ScheduleDesc::Step(
      "GetLoopsWithName", {}, {{"block_name", std::string("B")}}, loops));
  sample = ir_sch.SamplePerfectTile(loops[2], 2, 1, {-1, 2});
  trace.Append(ScheduleDesc::Step("SamplePerfectTile",
                                  {{"loop", std::vector<Expr>({loops[2]})}},
                                  {{"n", 2},
                                   {"max_innermost_factor", 1},
                                   {"decision", std::vector<int>{-1, 2}}},
                                  sample));
  splited = ir_sch.Split(loops[2], sample);
  trace.Append(ScheduleDesc::Step(
      "Split",
      {{"loop", std::vector<Expr>({loops[2]})}, {"factors", sample}},
      {},
      splited));

  auto block_b = ir_sch.GetBlock("B");
  trace.Append(ScheduleDesc::Step(
      "GetBlock", {}, {{"block_name", std::string("B")}}, {block_b}));
  Expr ret = ir_sch.Reorder(block_b, {2, 3, 1, 4, 0});
  trace.Append(
      ScheduleDesc::Step("ReorderWithBlock",
                         {{"block", std::vector<Expr>({block_b})}},
                         {{"loops_index", std::vector<int>({2, 3, 1, 4, 0})}},
                         {ret}));
  CheckReplayResult(ir_sch, trace);
  CheckReplayResult(ir_sch, ir_sch.GetTraceDesc());
}

// Reorder given a block name and loop indices: split two of B's loops
// (producing five loops), then reorder them by name. The replayed module must
// match.
TEST_F(TestScheduleDesc, StepKind_ReorderWithName) {
  lowered_funcs = LowerCompute({32, 32, 64}, target);
  ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs);

  auto loops = ir_sch.GetLoops("B");
  trace.Append(ScheduleDesc::Step(
      "GetLoopsWithName", {}, {{"block_name", std::string("B")}}, loops));
  auto sample = ir_sch.SamplePerfectTile(loops[0], 2, 1, {-1, 4});
  trace.Append(ScheduleDesc::Step("SamplePerfectTile",
                                  {{"loop", std::vector<Expr>({loops[0]})}},
                                  {{"n", 2},
                                   {"max_innermost_factor", 1},
                                   {"decision", std::vector<int>{-1, 4}}},
                                  sample));
  auto splited = ir_sch.Split(loops[0], sample);
  trace.Append(ScheduleDesc::Step(
      "Split",
      {{"loop", std::vector<Expr>({loops[0]})}, {"factors", sample}},
      {},
      splited));

  loops = ir_sch.GetLoops("B");
  trace.Append(ScheduleDesc::Step(
      "GetLoopsWithName", {}, {{"block_name", std::string("B")}}, loops));
  sample = ir_sch.SamplePerfectTile(loops[2], 2, 1, {-1, 2});
  trace.Append(ScheduleDesc::Step("SamplePerfectTile",
                                  {{"loop", std::vector<Expr>({loops[2]})}},
                                  {{"n", 2},
                                   {"max_innermost_factor", 1},
                                   {"decision", std::vector<int>{-1, 2}}},
                                  sample));
  splited = ir_sch.Split(loops[2], sample);
  trace.Append(ScheduleDesc::Step(
      "Split",
      {{"loop", std::vector<Expr>({loops[2]})}, {"factors", sample}},
      {},
      splited));

  Expr ret = ir_sch.Reorder("B", {4, 2, 3, 1, 0});
  trace.Append(
      ScheduleDesc::Step("ReorderWithName",
                         {},
                         {{"block_name", std::string("B")},
                          {"loops_index", std::vector<int>({4, 2, 3, 1, 0})}},
                         {ret}));
  CheckReplayResult(ir_sch, trace);
  CheckReplayResult(ir_sch, ir_sch.GetTraceDesc());
}

// Parallel on B's outermost loop: the replayed module must match.
TEST_F(TestScheduleDesc, StepKind_Parallel) {
  lowered_funcs = LowerCompute({32, 64}, target);
  ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs);

  auto loops = ir_sch.GetLoops("B");
  trace.Append(ScheduleDesc::Step(
      "GetLoopsWithName", {}, {{"block_name", std::string("B")}}, loops));
  ir_sch.Parallel(loops[0]);
  trace.Append(ScheduleDesc::Step(
      "Parallel", {{"loop", std::vector<Expr>({loops[0]})}}, {}, {}));
  CheckReplayResult(ir_sch, trace);
  CheckReplayResult(ir_sch, ir_sch.GetTraceDesc());
}

// Tracing test for the Vectorize primitive: vectorizes the second loop that
// GetLoops returns for block "B" with factor 16, records the matching steps
// by hand, and checks both the hand-built trace and ir_sch's own trace
// (GetTraceDesc) with CheckReplayResult.
TEST_F(TestScheduleDesc, StepKind_Vectorize) {
  lowered_funcs = LowerCompute({32, 64}, target);
  ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs);

  auto loops = ir_sch.GetLoops("B");
  trace.Append(ScheduleDesc::Step(
      "GetLoopsWithName", {}, {{"block_name", std::string("B")}}, loops));
  ir_sch.Vectorize(loops[1], 16);
  trace.Append(ScheduleDesc::Step("Vectorize",
                                  {{"loop", std::vector<Expr>({loops[1]})}},
                                  {{"factor", 16}},
                                  {}));
  CheckReplayResult(ir_sch, trace);
  CheckReplayResult(ir_sch, ir_sch.GetTraceDesc());
}

// Tracing test for the Unroll primitive: unrolls the second loop that
// GetLoops returns for block "B" (extent 2 in the {32, 2} compute), records
// the matching steps by hand, and checks both the hand-built trace and
// ir_sch's own trace (GetTraceDesc) with CheckReplayResult.
TEST_F(TestScheduleDesc, StepKind_Unroll) {
  lowered_funcs = LowerCompute({32, 2}, target);
  ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs);

  auto loops = ir_sch.GetLoops("B");
  trace.Append(ScheduleDesc::Step(
      "GetLoopsWithName", {}, {{"block_name", std::string("B")}}, loops));
  ir_sch.Unroll(loops[1]);
  trace.Append(ScheduleDesc::Step(
      "Unroll", {{"loop", std::vector<Expr>({loops[1]})}}, {}, {}));
  CheckReplayResult(ir_sch, trace);
  CheckReplayResult(ir_sch, ir_sch.GetTraceDesc());
}

// Tracing test for the ComputeInline primitive: inlines block "B" of the
// "elementwise-add_const" example, records the matching steps by hand, and
// checks both the hand-built trace and ir_sch's own trace (GetTraceDesc)
// with CheckReplayResult.
TEST_F(TestScheduleDesc, StepKind_ComputeInline) {
  lowered_funcs =
      LowerCompute({32, 32, 32}, target, true, "elementwise-add_const");
  ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs);

  auto block_b = ir_sch.GetBlock("B");
  trace.Append(ScheduleDesc::Step(
      "GetBlock", {}, {{"block_name", std::string("B")}}, {block_b}));
  ir_sch.ComputeInline(block_b);
  trace.Append(
      ScheduleDesc::Step("ComputeInline",
                         {{"schedule_block", std::vector<Expr>({block_b})}},
                         {},
                         {}));
  CheckReplayResult(ir_sch, trace);
  CheckReplayResult(ir_sch, ir_sch.GetTraceDesc());
}

// Tracing test for the ReverseComputeInline primitive: reverse-inlines block
// "C" of the "elementwise-add_const" example, records the matching steps by
// hand, and checks both the hand-built trace and ir_sch's own trace
// (GetTraceDesc) with CheckReplayResult.
TEST_F(TestScheduleDesc, StepKind_ReverseComputeInline) {
  lowered_funcs =
      LowerCompute({32, 32, 32}, target, true, "elementwise-add_const");
  ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs);

  auto block_c = ir_sch.GetBlock("C");
  trace.Append(ScheduleDesc::Step(
      "GetBlock", {}, {{"block_name", std::string("C")}}, {block_c}));
  ir_sch.ReverseComputeInline(block_c);
  trace.Append(
      ScheduleDesc::Step("ReverseComputeInline",
                         {{"schedule_block", std::vector<Expr>({block_c})}},
                         {},
                         {}));
  CheckReplayResult(ir_sch, trace);
  CheckReplayResult(ir_sch, ir_sch.GetTraceDesc());
}

// Tracing test for the Bind primitive: binds the first loop that GetLoops
// returns for block "B" to "blockIdx.x", records the matching steps by hand,
// and checks both the hand-built trace and ir_sch's own trace (GetTraceDesc)
// with CheckReplayResult.
TEST_F(TestScheduleDesc, StepKind_Bind) {
  lowered_funcs = LowerCompute({32, 128}, target);
  ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs);

  auto loops = ir_sch.GetLoops("B");
  trace.Append(ScheduleDesc::Step(
      "GetLoopsWithName", {}, {{"block_name", std::string("B")}}, loops));
  ir_sch.Bind(loops[0], "blockIdx.x");
  trace.Append(ScheduleDesc::Step("Bind",
                                  {{"loop", std::vector<Expr>({loops[0]})}},
                                  {{"thread_axis", std::string("blockIdx.x")}},
                                  {}));
  CheckReplayResult(ir_sch, trace);
  CheckReplayResult(ir_sch, ir_sch.GetTraceDesc());
}

// Tracing test for the Rfactor primitive on the reduction
// C[i, j] = sum_k A[i, k] * B[k, j] (M=32, N=2, K=16). It applies Rfactor to
// loops[2] with rf_axis 0, then checks the recorded output tensor
// (CheckTracingOutputs) and the replay result (CheckReplayResult) for both the
// hand-built trace and ir_sch's own trace (GetTraceDesc).
TEST_F(TestScheduleDesc, StepKind_Rfactor) {
  Expr M(32);
  Expr N(2);
  Expr K(16);

  Placeholder<float> A("A", {M, K});
  Placeholder<float> B("B", {K, N});
  Var k(16, "k0");
  auto C = Compute(
      {M, N},
      [&](Var i, Var j) { return lang::ReduceSum(A(i, k) * B(k, j), {k}); },
      "C");

  lowered_funcs = cinn::lang::LowerVec("test_rfactor",
                                       CreateStages({A, B, C}),
                                       {A, B, C},
                                       {},
                                       {},
                                       nullptr,
                                       target,
                                       true);

  // NOTE(review): the global name-id counter is reset both before and after
  // MakeIRSchedule. Presumably this keeps names generated later (e.g. the
  // rfactor tensor) identical between the traced run and the replay; confirm.
  cinn::common::Context::Global().ResetNameId();
  ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs);
  cinn::common::Context::Global().ResetNameId();

  auto loops = ir_sch.GetLoops("C");
  trace.Append(ScheduleDesc::Step(
      "GetLoopsWithName", {}, {{"block_name", std::string("C")}}, loops));
  auto new_rf_tensor = ir_sch.Rfactor(loops[2], 0);
  trace.Append(ScheduleDesc::Step("Rfactor",
                                  {{"rf_loop", std::vector<Expr>({loops[2]})}},
                                  {{"rf_axis", 0}},
                                  {new_rf_tensor}));
  CheckTracingOutputs({new_rf_tensor}, trace);
  CheckTracingOutputs({new_rf_tensor}, ir_sch.GetTraceDesc());
  CheckReplayResult(ir_sch, trace);
  CheckReplayResult(ir_sch, ir_sch.GetTraceDesc());
}

// Tracing test for MergeExprs. It builds two schedules from identical copies
// of the same lowered body. On the first it calls MergeExprs directly; on the
// second it replays the recorded "MergeExprs" step. The resulting expressions
// are then compared pairwise through utils::GetStreamCnt.
TEST_F(TestScheduleDesc, StepKind_MergeExprs) {
  auto funcs_0 = LowerCompute({32, 128}, target);
  // NOTE(review): funcs_1 is lowered but never used below. Confirm whether it
  // was meant to supply the second expression or can be removed. It is kept
  // in case the extra lowering matters (e.g. for global name-id state).
  auto funcs_1 =
      LowerCompute({32, 32, 32}, target, true, "elementwise-add_const");

  ir::IRSchedule ir_sch = ir::IRSchedule(ir::ModuleExpr(
      {optim::IRCopy(funcs_0[0]->body), optim::IRCopy(funcs_0[0]->body)}));
  ir_sch.MergeExprs();
  trace.Append(ScheduleDesc::Step("MergeExprs", {}, {}, {}));

  ir::IRSchedule replay_sch = ir::IRSchedule(ir::ModuleExpr(
      {optim::IRCopy(funcs_0[0]->body), optim::IRCopy(funcs_0[0]->body)}));
  trace.Replay(&replay_sch);

  auto lhs_exprs = ir_sch.GetModule().GetExprs();
  auto rhs_exprs = replay_sch.GetModule().GetExprs();
  ASSERT_EQ(lhs_exprs.size(), rhs_exprs.size());
  // An unsigned index avoids a signed/unsigned comparison against size().
  for (size_t i = 0; i < lhs_exprs.size(); ++i) {
    ASSERT_EQ(utils::GetStreamCnt(lhs_exprs.at(i)),
              utils::GetStreamCnt(rhs_exprs.at(i)));
  }
}

// Tracing test for Annotate. It attaches int, bool, float and string
// attributes (k1..k4) to block "B" and records an Annotate*Attr step for each.
// It then checks both the hand-built trace and ir_sch's own trace
// (GetTraceDesc) with CheckReplayResult.
TEST_F(TestScheduleDesc, StepKind_Annotate) {
  lowered_funcs = LowerCompute({32, 128}, target);
  ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs);

  // Fetches block "B" from ir_sch and records the matching GetBlock step.
  // The block is re-fetched before every annotation. NOTE(review): this is
  // presumably because Annotate replaces the block node, leaving an older
  // handle stale; confirm.
  auto get_block_b = [&]() {
    auto block = ir_sch.GetBlock("B");
    trace.Append(ScheduleDesc::Step(
        "GetBlock", {}, {{"block_name", std::string("B")}}, {block}));
    return block;
  };

  auto block_b = get_block_b();
  ir_sch.Annotate(block_b, "k1", int(64));
  trace.Append(
      ScheduleDesc::Step("AnnotateIntAttr",
                         {{"block", std::vector<Expr>({block_b})}},
                         {{"key", std::string("k1")}, {"value", int(64)}},
                         {}));

  block_b = get_block_b();
  ir_sch.Annotate(block_b, "k2", bool(true));
  trace.Append(
      ScheduleDesc::Step("AnnotateBoolAttr",
                         {{"block", std::vector<Expr>({block_b})}},
                         {{"key", std::string("k2")}, {"value", bool(true)}},
                         {}));

  block_b = get_block_b();
  ir_sch.Annotate(block_b, "k3", float(2.0));
  trace.Append(
      ScheduleDesc::Step("AnnotateFloatAttr",
                         {{"block", std::vector<Expr>({block_b})}},
                         {{"key", std::string("k3")}, {"value", float(2.0)}},
                         {}));

  block_b = get_block_b();
  ir_sch.Annotate(block_b, "k4", std::string("v4"));
  trace.Append(ScheduleDesc::Step(
      "AnnotateStringAttr",
      {{"block", std::vector<Expr>({block_b})}},
      {{"key", std::string("k4")}, {"value", std::string("v4")}},
      {}));

  CheckReplayResult(ir_sch, trace);
  CheckReplayResult(ir_sch, ir_sch.GetTraceDesc());
}

// Tracing test for Unannotate. It adds int (k1) and bool (k2) attributes to
// block "B", then removes both, recording every step by hand. It then checks
// both the hand-built trace and ir_sch's own trace (GetTraceDesc) with
// CheckReplayResult.
TEST_F(TestScheduleDesc, StepKind_Unannotate) {
  lowered_funcs = LowerCompute({32, 128}, target);
  ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs);

  // Fetches block "B" from ir_sch and records the matching GetBlock step.
  // The block is re-fetched before every (un)annotation. NOTE(review): this is
  // presumably because those primitives replace the block node, leaving an
  // older handle stale; confirm.
  auto get_block_b = [&]() {
    auto block = ir_sch.GetBlock("B");
    trace.Append(ScheduleDesc::Step(
        "GetBlock", {}, {{"block_name", std::string("B")}}, {block}));
    return block;
  };

  auto block_b = get_block_b();
  ir_sch.Annotate(block_b, "k1", int(64));
  trace.Append(
      ScheduleDesc::Step("AnnotateIntAttr",
                         {{"block", std::vector<Expr>({block_b})}},
                         {{"key", std::string("k1")}, {"value", int(64)}},
                         {}));

  block_b = get_block_b();
  ir_sch.Annotate(block_b, "k2", bool(true));
  trace.Append(
      ScheduleDesc::Step("AnnotateBoolAttr",
                         {{"block", std::vector<Expr>({block_b})}},
                         {{"key", std::string("k2")}, {"value", bool(true)}},
                         {}));

  block_b = get_block_b();
  ir_sch.Unannotate(block_b, "k1");
  trace.Append(ScheduleDesc::Step("Unannotate",
                                  {{"block", std::vector<Expr>({block_b})}},
                                  {{"key", std::string("k1")}},
                                  {}));

  block_b = get_block_b();
  ir_sch.Unannotate(block_b, "k2");
  trace.Append(ScheduleDesc::Step("Unannotate",
                                  {{"block", std::vector<Expr>({block_b})}},
                                  {{"key", std::string("k2")}},
                                  {}));

  CheckReplayResult(ir_sch, trace);
  CheckReplayResult(ir_sch, ir_sch.GetTraceDesc());
}

// Tracing test for SamplePerfectTile. It samples a 2-way tiling of the first
// loop of block "B" (max_innermost_factor 64) and records the sampled values
// as the step's "decision". It then checks the traced outputs
// (CheckTracingOutputs) and the replay result (CheckReplayResult) for both
// the hand-built trace and ir_sch's own trace (GetTraceDesc).
TEST_F(TestScheduleDesc, StepKind_SamplePerfectTile) {
  Expr M(1024);
  Var n(1, "n");

  Placeholder<int> A("A", {M});
  auto B = Compute(
      {M}, [&](Expr i) { return A(i) + n; }, "B");
  lowered_funcs = cinn::lang::LowerVec("test_sample_perfect_tile",
                                       CreateStages({A, B}),
                                       {A, B},
                                       {},
                                       {},
                                       nullptr,
                                       target,
                                       true);

  ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs);
  auto loops = ir_sch.GetLoops("B");
  trace.Append(ScheduleDesc::Step(
      "GetLoopsWithName", {}, {{"block_name", std::string("B")}}, loops));

  auto result = ir_sch.SamplePerfectTile(loops[0], 2, 64);
  // Convert the sampled tile factors to plain ints so they can be stored as
  // the step's "decision" attribute.
  std::vector<int> decision;
  decision.reserve(result.size());
  std::transform(result.begin(),
                 result.end(),
                 std::back_inserter(decision),
                 [](const Expr& x) { return x.as_int32(); });
  trace.Append(ScheduleDesc::Step(
      "SamplePerfectTile",
      {{"loop", std::vector<Expr>({loops[0]})}},
      {{"n", 2}, {"max_innermost_factor", 64}, {"decision", decision}},
      result));
  CheckTracingOutputs(result, trace);
  CheckTracingOutputs(result, ir_sch.GetTraceDesc());
  CheckReplayResult(ir_sch, trace);
  CheckReplayResult(ir_sch, ir_sch.GetTraceDesc());
}

// Tracing test for SampleCategorical. It samples one value from candidates
// {1, 2, 3} with weights {1, 2, 3} and records the sampled value as the step's
// "decision". It then checks the traced output (CheckTracingOutputs) and the
// replay result (CheckReplayResult) for both the hand-built trace and
// ir_sch's own trace (GetTraceDesc).
TEST_F(TestScheduleDesc, StepKind_SampleCategorical) {
  lowered_funcs = LowerCompute({32, 32, 64}, target, true);
  ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs);

  // The sampling call and the recorded step share the same vectors, so the
  // traced attributes cannot drift from the arguments actually used.
  const std::vector<int> candidates = {1, 2, 3};
  const std::vector<float> probs = {1.0f, 2.0f, 3.0f};
  Expr ret = ir_sch.SampleCategorical(candidates, probs);
  std::vector<int> decision = {ret.as_int32()};
  trace.Append(ScheduleDesc::Step("SampleCategorical",
                                  {},
                                  {{"candidates", candidates},
                                   {"probs", probs},
                                   {"decision", decision}},
                                  {ret}));
  CheckTracingOutputs({ret}, trace);
  CheckTracingOutputs({ret}, ir_sch.GetTraceDesc());
  CheckReplayResult(ir_sch, trace);
  CheckReplayResult(ir_sch, ir_sch.GetTraceDesc());
}

}  // namespace ir
}  // namespace cinn