// schedule_desc_test.cc
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/cinn/ir/schedule_desc.h"

#include <glog/logging.h>
#include <gtest/gtest.h>

#include "paddle/cinn/cinn.h"
#include "paddle/cinn/common/context.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/ir_schedule.h"
#include "paddle/cinn/lang/lower.h"
#include "paddle/cinn/optim/ir_copy.h"
#include "paddle/cinn/utils/string.h"
#include "paddle/cinn/utils/type_defs.h"

namespace cinn {
namespace ir {

// Return lowerd ir AST for example functions used in this test
33 34 35 36 37
std::vector<ir::LoweredFunc> LowerCompute(
    const std::vector<int>& shape,
    const Target& target,
    bool need_c = false,
    const std::string& operation = "elementwise-copy") {
38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68
  CHECK(shape.size() == 2 || shape.size() == 3) << "shape should be 2 or 3";
  std::vector<Expr> domain;
  for (auto i = 0; i < shape.size(); ++i) {
    domain.emplace_back(shape[i]);
  }

  Placeholder<float> A("A", domain);
  ir::Tensor B, C;

  if (operation == "elementwise-copy") {
    if (domain.size() == 2) {
      B = Compute(
          domain, [&A](Var i, Var j) { return A(i, j); }, "B");
      C = Compute(
          domain, [&B](Var i, Var j) { return B(i, j); }, "C");
    } else {
      B = Compute(
          domain, [&A](Var i, Var j, Var k) { return A(i, j, k); }, "B");
      C = Compute(
          domain, [&B](Var i, Var j, Var k) { return B(i, j, k); }, "C");
    }
  }

  if (operation == "elementwise-add_const") {
    if (domain.size() == 2) {
      B = Compute(
          domain, [&A](Var i, Var j) { return A(i, j) * Expr(2.f); }, "B");
      C = Compute(
          domain, [&B](Var i, Var j) { return B(i, j) + Expr(1.f); }, "C");
    } else {
      B = Compute(
69 70 71
          domain,
          [&A](Var i, Var j, Var k) { return A(i, j, k) * Expr(2.f); },
          "B");
72
      C = Compute(
73 74 75
          domain,
          [&B](Var i, Var j, Var k) { return B(i, j, k) + Expr(1.f); },
          "C");
76 77 78 79
    }
  }

  if (need_c) {
80 81 82 83 84 85 86 87
    return cinn::lang::LowerVec("test_func",
                                CreateStages({A, B, C}),
                                {A, C},
                                {},
                                {},
                                nullptr,
                                target,
                                true);
88 89
  }

90 91
  return cinn::lang::LowerVec(
      "test_func", CreateStages({A, B}), {A, B}, {}, {}, nullptr, target, true);
92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127
}

// Build a fresh IRSchedule whose module holds an optim::IRCopy of every
// lowered function body, so scheduling works on copies of the ASTs rather
// than on the bodies stored in `lowered_funcs`.
IRSchedule MakeIRSchedule(const std::vector<ir::LoweredFunc>& lowered_funcs) {
  std::vector<Expr> copied_bodies;
  copied_bodies.reserve(lowered_funcs.size());
  for (const auto& lowered : lowered_funcs) {
    copied_bodies.push_back(optim::IRCopy(lowered->body));
  }
  return ir::IRSchedule(ir::ModuleExpr(copied_bodies));
}

// Generate C source code for a transformed ModuleExpr.
//
// The i-th expression of `module_expr` becomes the body of a copy of the i-th
// function in `lowered_funcs`; the copies are assembled into a module named
// "test_module" and compiled with CodeGenC to a C implementation string.
std::string SourceCodeGen(const ModuleExpr& module_expr,
                          const std::vector<ir::LoweredFunc>& lowered_funcs,
                          const Target& target) {
  auto exprs = module_expr.GetExprs();
  CHECK_EQ(exprs.size(), lowered_funcs.size())
      << "module_expr and lowered_funcs must contain the same number of "
         "functions";
  // Work on copies so replacing bodies below leaves the caller's functions
  // untouched.
  std::vector<ir::LoweredFunc> updated_funcs = optim::IRCopy(lowered_funcs);
  Module::Builder builder("test_module", target);
  for (size_t i = 0; i < lowered_funcs.size(); ++i) {
    updated_funcs[i]->body = optim::IRCopy(exprs.at(i));
    builder.AddFunction(updated_funcs[i]);
  }
  auto module = builder.Build();
  CodeGenC codegen(target);
  // NOTE(review): presumably keeps builtin helper code out of the output so
  // only the test functions are compared — confirm against CodeGenC.
  codegen.SetInlineBuiltinCodes(false);
  return codegen.Compile(module, CodeGenC::OutputKind::CImpl);
}

class TestScheduleDesc : public ::testing::Test {
 public:
  Target target = common::DefaultHostTarget();
  std::vector<ir::LoweredFunc> lowered_funcs;
  ScheduleDesc trace;
  void SetUp() override { Context::Global().ResetNameId(); }

128 129
  void CheckTracingOutputs(const std::vector<Expr>& base,
                           const ScheduleDesc& trace_desc) {
130 131
    Context::Global().ResetNameId();
    ir::IRSchedule replay_sch = MakeIRSchedule(lowered_funcs);
132 133
    auto traced_outputs =
        ScheduleDesc::ReplayWithProto(trace_desc.ToProto(), &replay_sch);
134 135
    ASSERT_EQ(base.size(), traced_outputs.size());
    for (auto i = 0; i < base.size(); ++i) {
136 137
      ASSERT_EQ(utils::GetStreamCnt(base.at(i)),
                utils::GetStreamCnt(traced_outputs.at(i)));
138 139 140
    }
  }

141 142
  void CheckReplayResult(const ir::IRSchedule& ir_sch,
                         const ScheduleDesc& trace_desc) {
143 144 145 146 147 148 149 150 151 152
    Context::Global().ResetNameId();
    ir::IRSchedule replay_sch = MakeIRSchedule(lowered_funcs);
    trace_desc.Replay(&replay_sch);

    // check the equality of module expr between original schedule
    // and the schedule generated by replaying with tracing ScheduleDesc
    auto lhs_exprs = ir_sch.GetModule().GetExprs();
    auto rhs_exprs = replay_sch.GetModule().GetExprs();
    ASSERT_EQ(lhs_exprs.size(), rhs_exprs.size());
    for (auto i = 0; i < lhs_exprs.size(); ++i) {
153 154
      ASSERT_EQ(utils::GetStreamCnt(lhs_exprs.at(i)),
                utils::GetStreamCnt(rhs_exprs.at(i)));
155 156 157
    }

    // check the equality of source code between them
158 159 160 161
    ASSERT_EQ(
        utils::Trim(SourceCodeGen(ir_sch.GetModule(), lowered_funcs, target)),
        utils::Trim(
            SourceCodeGen(replay_sch.GetModule(), lowered_funcs, target)));
162 163 164 165
  }
};

// Record a fuse/sample/split sequence by hand with Append() and check that
// both the hand-built trace and the schedule's own trace replay to the same
// split loops and the same module.
TEST_F(TestScheduleDesc, Append_Replay) {
  lowered_funcs = LowerCompute({32, 32}, target);
  ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs);

  auto fused = ir_sch.Fuse("B", {0, 1});
  trace.Append(ScheduleDesc::Step("FuseWithName",
                                  {},
                                  {{"block_name", std::string("B")},
                                   {"loops_index", std::vector<int>({0, 1})}},
                                  {fused}));
  auto sample = ir_sch.SamplePerfectTile(fused, 2, 1, {4, -1});
  trace.Append(ScheduleDesc::Step("SamplePerfectTile",
                                  {{"loop", std::vector<Expr>({fused})}},
                                  {{"n", 2},
                                   {"max_innermost_factor", 1},
                                   {"decision", std::vector<int>{4, -1}}},
                                  sample));
  auto splited = ir_sch.Split(fused, sample);
  trace.Append(ScheduleDesc::Step(
      "Split",
      {{"loop", std::vector<Expr>({fused})}, {"factors", sample}},
      {},
      splited));

  auto loops = ir_sch.GetLoops("B");
  trace.Append(ScheduleDesc::Step(
      "GetLoopsWithName", {}, {{"block_name", std::string("B")}}, loops));
  fused = ir_sch.Fuse(loops);
  trace.Append(ScheduleDesc::Step("Fuse", {{"loops", loops}}, {}, {fused}));
  sample = ir_sch.SamplePerfectTile(fused, 2, 1, {256, -1});
  trace.Append(ScheduleDesc::Step("SamplePerfectTile",
                                  {{"loop", std::vector<Expr>({fused})}},
                                  {{"n", 2},
                                   {"max_innermost_factor", 1},
                                   {"decision", std::vector<int>{256, -1}}},
                                  sample));
  splited = ir_sch.Split(fused, sample);
  trace.Append(ScheduleDesc::Step(
      "Split",
      {{"loop", std::vector<Expr>({fused})}, {"factors", sample}},
      {},
      splited));

  // check the equality of results between the ir_sch and replaying of trace
  CheckTracingOutputs(splited, trace);
  CheckReplayResult(ir_sch, trace);
  // check the equality of results between the ir_sch and replaying of its trace
  CheckTracingOutputs(splited, ir_sch.GetTraceDesc());
  CheckReplayResult(ir_sch, ir_sch.GetTraceDesc());
}

// Test cases with the `StepKind` prefix check that each step kind's
// StepKindInfo is registered correctly.
// A traced GetAllBlocks step must replay to the same list of blocks.
TEST_F(TestScheduleDesc, StepKind_GetAllBlocks) {
  lowered_funcs = LowerCompute({32, 32}, target);
  ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs);

  auto all_blocks = ir_sch.GetAllBlocks();
  trace.Append(ScheduleDesc::Step("GetAllBlocks", {}, {}, {all_blocks}));
  CheckTracingOutputs(all_blocks, trace);
  CheckTracingOutputs(all_blocks, ir_sch.GetTraceDesc());
}

// After computing B at C's second loop, GetChildBlocks of the root block
// reached from B's loops must replay to the same child blocks.
TEST_F(TestScheduleDesc, StepKind_GetChildBlocks) {
  lowered_funcs = LowerCompute({32, 32, 64}, target, true);
  ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs);

  auto block_b = ir_sch.GetBlock("B");
  trace.Append(ScheduleDesc::Step(
      "GetBlock", {}, {{"block_name", std::string("B")}}, {block_b}));
  auto loops = ir_sch.GetLoops("C");
  trace.Append(ScheduleDesc::Step(
      "GetLoopsWithName", {}, {{"block_name", std::string("C")}}, loops));
  ir_sch.ComputeAt(block_b, loops[1]);
  trace.Append(ScheduleDesc::Step("ComputeAt",
                                  {{"block", std::vector<Expr>({block_b})},
                                   {"loop", std::vector<Expr>({loops[1]})}},
                                  {{"keep_unit_loops", false}},
                                  {}));
  loops = ir_sch.GetLoops("B");
  trace.Append(ScheduleDesc::Step(
      "GetLoopsWithName", {}, {{"block_name", std::string("B")}}, loops));
  auto root_block = ir_sch.GetRootBlock(loops[1]);
  trace.Append(ScheduleDesc::Step("GetRootBlock",
                                  {{"expr", std::vector<Expr>({loops[1]})}},
                                  {},
                                  {root_block}));
  auto childblocks = ir_sch.GetChildBlocks(root_block);
  trace.Append(ScheduleDesc::Step("GetChildBlocks",
                                  {{"expr", std::vector<Expr>({root_block})}},
                                  {},
                                  childblocks));
  CheckTracingOutputs(childblocks, trace);
  CheckTracingOutputs(childblocks, ir_sch.GetTraceDesc());
}

// A traced GetLoops step (taking a block Expr) must replay to the same loops.
TEST_F(TestScheduleDesc, StepKind_GetLoops) {
  lowered_funcs = LowerCompute({32, 32}, target);
  ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs);

  auto block_b = ir_sch.GetBlock("B");
  trace.Append(ScheduleDesc::Step(
      "GetBlock", {}, {{"block_name", std::string("B")}}, {block_b}));
  auto loops = ir_sch.GetLoops(block_b);
  trace.Append(ScheduleDesc::Step(
      "GetLoops", {{"block", std::vector<Expr>({block_b})}}, {}, loops));
  CheckTracingOutputs(loops, trace);
  CheckTracingOutputs(loops, ir_sch.GetTraceDesc());
}

// A traced GetLoopsWithName step (taking a block name) must replay to the
// same loops.
TEST_F(TestScheduleDesc, StepKind_GetLoopsWithName) {
  lowered_funcs = LowerCompute({32, 32}, target);
  ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs);

  auto loops = ir_sch.GetLoops("B");
  trace.Append(ScheduleDesc::Step(
      "GetLoopsWithName", {}, {{"block_name", std::string("B")}}, loops));
  CheckTracingOutputs(loops, trace);
  CheckTracingOutputs(loops, ir_sch.GetTraceDesc());
}

// A traced GetBlock step must replay to the same block.
TEST_F(TestScheduleDesc, StepKind_GetBlock) {
  lowered_funcs = LowerCompute({32, 32, 32}, target);
  ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs);

  auto block_b = ir_sch.GetBlock("B");
  trace.Append(ScheduleDesc::Step(
      "GetBlock", {}, {{"block_name", std::string("B")}}, {block_b}));
  CheckTracingOutputs({block_b}, trace);
  CheckTracingOutputs({block_b}, ir_sch.GetTraceDesc());
}
// TODO: Re-enable this test. It has been failing since the change to how
// Split names its new loop variables.
/*
TEST_F(TestScheduleDesc, StepKind_Split) {
  lowered_funcs = LowerCompute({32, 32, 32}, target);
  ir::IRSchedule ir_sch_split_base = MakeIRSchedule(lowered_funcs);
  ir::IRSchedule ir_sch_split = MakeIRSchedule(lowered_funcs);
  ir::IRSchedule ir_sch_split_with_name = MakeIRSchedule(lowered_funcs);

  // test split with inputs of Expr
  auto loops = ir_sch_split_base.GetLoops("B");
  trace.Append(ScheduleDesc::Step(
      "GetLoopsWithName", {}, {{"block_name", std::string("B")}}, loops));
  auto sample =
      ir_sch_split_base.SamplePerfectTile(loops.front(), 2, 1, {4, -1});
  trace.Append(ScheduleDesc::Step("SamplePerfectTile",
                                  {{"loop",
                                    std::vector<Expr>({loops.front()})}},
                                  {{"n", 2},
                                   {"max_innermost_factor", 1},
                                   {"decision", std::vector<int>{4, -1}}},
                                  sample));
  auto splited = ir_sch_split_base.Split(loops.front(), sample);
  trace.Append(ScheduleDesc::Step(
      "Split",
      {{"loop", std::vector<Expr>({loops.front()})}, {"factors", sample}},
      {},
      splited));
  CheckTracingOutputs(splited, trace);
  CheckTracingOutputs(splited, ir_sch_split_base.GetTraceDesc());

  // test split with inputs of int
  loops   = ir_sch_split.GetLoops("B");
  splited = ir_sch_split.Split(loops.front(), {4, -1});
  CheckTracingOutputs(splited, trace);
  CheckTracingOutputs(splited, ir_sch_split.GetTraceDesc());

  // test split with block name and inputs of int
  splited = ir_sch_split_with_name.Split("B", 0, {4, -1});
  CheckTracingOutputs(splited, trace);
  CheckTracingOutputs(splited, ir_sch_split_with_name.GetTraceDesc());
}
*/
// A traced Fuse step (taking loop Exprs) must replay to the same fused loop.
TEST_F(TestScheduleDesc, StepKind_Fuse) {
  lowered_funcs = LowerCompute({32, 32, 64}, target);
  ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs);

  auto loops = ir_sch.GetLoops("B");
  trace.Append(ScheduleDesc::Step(
      "GetLoopsWithName", {}, {{"block_name", std::string("B")}}, loops));
  auto fused = ir_sch.Fuse(loops);
  trace.Append(ScheduleDesc::Step("Fuse", {{"loops", loops}}, {}, {fused}));
  CheckTracingOutputs({fused}, trace);
  CheckTracingOutputs({fused}, ir_sch.GetTraceDesc());
}

// A traced FuseWithName step (block name + loop indices) must replay to the
// same fused loop.
TEST_F(TestScheduleDesc, StepKind_FuseWithName) {
  lowered_funcs = LowerCompute({32, 32, 64}, target);
  ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs);

  auto fused = ir_sch.Fuse("B", {0, 1, 2});
  trace.Append(
      ScheduleDesc::Step("FuseWithName",
                         {},
                         {{"block_name", std::string("B")},
                          {"loops_index", std::vector<int>({0, 1, 2})}},
                         {fused}));
  CheckTracingOutputs({fused}, trace);
  CheckTracingOutputs({fused}, ir_sch.GetTraceDesc());
}

// A traced FuseWithBlock step (block Expr + loop indices) must replay to the
// same fused loop.
TEST_F(TestScheduleDesc, StepKind_FuseWithBlock) {
  lowered_funcs = LowerCompute({32, 32, 64}, target);
  ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs);

  auto block_b = ir_sch.GetBlock("B");
  trace.Append(ScheduleDesc::Step(
      "GetBlock", {}, {{"block_name", std::string("B")}}, {block_b}));
  auto fused = ir_sch.Fuse(block_b, {0, 1, 2});
  trace.Append(
      ScheduleDesc::Step("FuseWithBlock",
                         {{"block", std::vector<Expr>({block_b})}},
                         {{"loops_index", std::vector<int>({0, 1, 2})}},
                         {fused}));
  CheckTracingOutputs({fused}, trace);
  CheckTracingOutputs({fused}, ir_sch.GetTraceDesc());
}

// Computing B at C's second loop must replay to the same module and source.
TEST_F(TestScheduleDesc, StepKind_ComputeAt) {
  lowered_funcs = LowerCompute({32, 32, 64}, target, true);
  ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs);

  auto block_b = ir_sch.GetBlock("B");
  trace.Append(ScheduleDesc::Step(
      "GetBlock", {}, {{"block_name", std::string("B")}}, {block_b}));
  auto loops = ir_sch.GetLoops("C");
  trace.Append(ScheduleDesc::Step(
      "GetLoopsWithName", {}, {{"block_name", std::string("C")}}, loops));
  ir_sch.ComputeAt(block_b, loops[1]);
  trace.Append(ScheduleDesc::Step("ComputeAt",
                                  {{"block", std::vector<Expr>({block_b})},
                                   {"loop", std::vector<Expr>({loops[1]})}},
                                  {{"keep_unit_loops", false}},
                                  {}));
  CheckReplayResult(ir_sch, trace);
  CheckReplayResult(ir_sch, ir_sch.GetTraceDesc());
}

// SimpleComputeAt of B at C's innermost (third) loop must replay to the same
// module and source.
TEST_F(TestScheduleDesc, StepKind_SimpleComputeAt) {
  lowered_funcs = LowerCompute({32, 32, 64}, target, true);
  ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs);

  auto block_b = ir_sch.GetBlock("B");
  trace.Append(ScheduleDesc::Step(
      "GetBlock", {}, {{"block_name", std::string("B")}}, {block_b}));
  auto loops = ir_sch.GetLoops("C");
  trace.Append(ScheduleDesc::Step(
      "GetLoopsWithName", {}, {{"block_name", std::string("C")}}, loops));
  ir_sch.SimpleComputeAt(block_b, loops[2]);
  trace.Append(ScheduleDesc::Step("SimpleComputeAt",
                                  {{"block", std::vector<Expr>({block_b})},
                                   {"loop", std::vector<Expr>({loops[2]})}},
                                  {{"keep_unit_loops", false}},
                                  {}));
  CheckReplayResult(ir_sch, trace);
  CheckReplayResult(ir_sch, ir_sch.GetTraceDesc());
}

// ReverseComputeAt of C at B's second loop must replay to the same module and
// source.
TEST_F(TestScheduleDesc, StepKind_ReverseComputeAt) {
  lowered_funcs = LowerCompute({32, 32, 64}, target, true);
  ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs);

  auto block_c = ir_sch.GetBlock("C");
  trace.Append(ScheduleDesc::Step(
      "GetBlock", {}, {{"block_name", std::string("C")}}, {block_c}));
  auto loops = ir_sch.GetLoops("B");
  trace.Append(ScheduleDesc::Step(
      "GetLoopsWithName", {}, {{"block_name", std::string("B")}}, loops));
  ir_sch.ReverseComputeAt(block_c, loops[1]);
  trace.Append(ScheduleDesc::Step("ReverseComputeAt",
                                  {{"block", std::vector<Expr>({block_c})},
                                   {"loop", std::vector<Expr>({loops[1]})}},
                                  {{"keep_unit_loops", false}},
                                  {}));
  CheckReplayResult(ir_sch, trace);
  CheckReplayResult(ir_sch, ir_sch.GetTraceDesc());
}

// A traced GetRootBlock step (from one of B's loops) must replay to the same
// root block.
TEST_F(TestScheduleDesc, StepKind_GetRootBlock) {
  lowered_funcs = LowerCompute({32, 64}, target);
  ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs);

  auto loops = ir_sch.GetLoops("B");
  trace.Append(ScheduleDesc::Step(
      "GetLoopsWithName", {}, {{"block_name", std::string("B")}}, loops));
  auto root_b = ir_sch.GetRootBlock(loops[1]);
  trace.Append(ScheduleDesc::Step(
      "GetRootBlock", {{"expr", std::vector<Expr>({loops[1]})}}, {}, {root_b}));
  CheckTracingOutputs({root_b}, trace);
  CheckTracingOutputs({root_b}, ir_sch.GetTraceDesc());
}

// A CacheRead of B's first read buffer into "local" memory must replay to the
// same cache block, module and source.
TEST_F(TestScheduleDesc, StepKind_CacheRead) {
  lowered_funcs =
      LowerCompute({32, 64}, target, false, "elementwise-add_const");
  ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs);

  auto block_b = ir_sch.GetBlock("B");
  trace.Append(ScheduleDesc::Step(
      "GetBlock", {}, {{"block_name", std::string("B")}}, {block_b}));
  auto a_cache = ir_sch.CacheRead(block_b, 0, "local");
  trace.Append(ScheduleDesc::Step(
      "CacheRead",
      {{"block", std::vector<Expr>({block_b})}},
      {{"read_buffer_index", 0}, {"memory_type", std::string("local")}},
      {a_cache}));
  CheckTracingOutputs({a_cache}, trace);
  CheckTracingOutputs({a_cache}, ir_sch.GetTraceDesc());
  CheckReplayResult(ir_sch, trace);
  CheckReplayResult(ir_sch, ir_sch.GetTraceDesc());
}

// A CacheWrite of B's first write buffer into "local" memory must replay to
// the same cache block, module and source.
TEST_F(TestScheduleDesc, StepKind_CacheWrite) {
  lowered_funcs =
      LowerCompute({32, 64}, target, false, "elementwise-add_const");
  ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs);

  auto block_b = ir_sch.GetBlock("B");
  trace.Append(ScheduleDesc::Step(
      "GetBlock", {}, {{"block_name", std::string("B")}}, {block_b}));
  auto b_cache = ir_sch.CacheWrite(block_b, 0, "local");
  trace.Append(ScheduleDesc::Step(
      "CacheWrite",
      {{"block", std::vector<Expr>({block_b})}},
      {{"write_buffer_index", 0}, {"memory_type", std::string("local")}},
      {b_cache}));
  CheckTracingOutputs({b_cache}, trace);
  CheckTracingOutputs({b_cache}, ir_sch.GetTraceDesc());
  CheckReplayResult(ir_sch, trace);
  CheckReplayResult(ir_sch, ir_sch.GetTraceDesc());
}

// After cache-writing B and C to "local", inserting SyncThreads before C
// (after_node = false) and after B (default, traced as after_node = true)
// must replay to the same module and source.
TEST_F(TestScheduleDesc, StepKind_SyncThreads) {
  lowered_funcs = LowerCompute({64, 32}, target, true, "elementwise-add_const");
  ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs);

  auto block_b = ir_sch.GetBlock("B");
  trace.Append(ScheduleDesc::Step(
      "GetBlock", {}, {{"block_name", std::string("B")}}, {block_b}));
  auto b_cache = ir_sch.CacheWrite(block_b, 0, "local");
  trace.Append(ScheduleDesc::Step(
      "CacheWrite",
      {{"block", std::vector<Expr>({block_b})}},
      {{"write_buffer_index", 0}, {"memory_type", std::string("local")}},
      {b_cache}));
  auto block_c = ir_sch.GetBlock("C");
  trace.Append(ScheduleDesc::Step(
      "GetBlock", {}, {{"block_name", std::string("C")}}, {block_c}));
  auto c_cache = ir_sch.CacheWrite(block_c, 0, "local");
  trace.Append(ScheduleDesc::Step(
      "CacheWrite",
      {{"block", std::vector<Expr>({block_c})}},
      {{"write_buffer_index", 0}, {"memory_type", std::string("local")}},
      {c_cache}));
  block_c = ir_sch.GetBlock("C");
  trace.Append(ScheduleDesc::Step(
      "GetBlock", {}, {{"block_name", std::string("C")}}, {block_c}));
  ir_sch.SyncThreads(block_c, false);
  trace.Append(ScheduleDesc::Step("SyncThreads",
                                  {{"ir_node", std::vector<Expr>({block_c})}},
                                  {{"after_node", false}},
                                  {}));
  block_b = ir_sch.GetBlock("B");
  trace.Append(ScheduleDesc::Step(
      "GetBlock", {}, {{"block_name", std::string("B")}}, {block_b}));
  ir_sch.SyncThreads(block_b);
  trace.Append(ScheduleDesc::Step("SyncThreads",
                                  {{"ir_node", std::vector<Expr>({block_b})}},
                                  {{"after_node", true}},
                                  {}));

  CheckReplayResult(ir_sch, trace);
  CheckReplayResult(ir_sch, ir_sch.GetTraceDesc());
}

// Setting B's buffer to "shared" memory (fixed = true) must replay to the
// same module and source.
TEST_F(TestScheduleDesc, StepKind_SetBuffer) {
  lowered_funcs =
      LowerCompute({32, 64}, target, false, "elementwise-add_const");
  ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs);

  auto block_b = ir_sch.GetBlock("B");
  trace.Append(ScheduleDesc::Step(
      "GetBlock", {}, {{"block_name", std::string("B")}}, {block_b}));
  ir_sch.SetBuffer(block_b, "shared", true);
  trace.Append(ScheduleDesc::Step(
      "SetBuffer",
      {{"block", std::vector<Expr>({block_b})}},
      {{"memory_type", std::string("shared")}, {"fixed", true}},
      {}));
  CheckReplayResult(ir_sch, trace);
  CheckReplayResult(ir_sch, ir_sch.GetTraceDesc());
}

// Split two of B's loops, then Reorder two of the resulting loops given as
// Exprs; the replay must produce the same module and source.
TEST_F(TestScheduleDesc, StepKind_Reorder) {
  lowered_funcs = LowerCompute({32, 64, 12}, target);
  ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs);

  auto loops = ir_sch.GetLoops("B");
  trace.Append(ScheduleDesc::Step(
      "GetLoopsWithName", {}, {{"block_name", std::string("B")}}, loops));
  auto sample = ir_sch.SamplePerfectTile(loops[0], 2, 1, {-1, 4});
  trace.Append(ScheduleDesc::Step("SamplePerfectTile",
                                  {{"loop", std::vector<Expr>({loops[0]})}},
                                  {{"n", 2},
                                   {"max_innermost_factor", 1},
                                   {"decision", std::vector<int>{-1, 4}}},
                                  sample));
  auto splited = ir_sch.Split(loops[0], sample);
  trace.Append(ScheduleDesc::Step(
      "Split",
      {{"loop", std::vector<Expr>({loops[0]})}, {"factors", sample}},
      {},
      splited));

  loops = ir_sch.GetLoops("B");
  trace.Append(ScheduleDesc::Step(
      "GetLoopsWithName", {}, {{"block_name", std::string("B")}}, loops));
  sample = ir_sch.SamplePerfectTile(loops[2], 2, 1, {-1, 2});
  trace.Append(ScheduleDesc::Step("SamplePerfectTile",
                                  {{"loop", std::vector<Expr>({loops[2]})}},
                                  {{"n", 2},
                                   {"max_innermost_factor", 1},
                                   {"decision", std::vector<int>{-1, 2}}},
                                  sample));
  splited = ir_sch.Split(loops[2], sample);
  trace.Append(ScheduleDesc::Step(
      "Split",
      {{"loop", std::vector<Expr>({loops[2]})}, {"factors", sample}},
      {},
      splited));

  loops = ir_sch.GetLoops("B");
  trace.Append(ScheduleDesc::Step(
      "GetLoopsWithName", {}, {{"block_name", std::string("B")}}, loops));
  Expr ret = ir_sch.Reorder({loops[4], loops[0]});
  trace.Append(
      ScheduleDesc::Step("Reorder",
                         {{"loops", std::vector<Expr>({loops[4], loops[0]})}},
                         {},
                         {ret}));
  CheckReplayResult(ir_sch, trace);
  CheckReplayResult(ir_sch, ir_sch.GetTraceDesc());
}

// Split two of B's loops, then reorder all five loops of block B given as a
// block Expr plus loop indices; the replay must produce the same module and
// source.
TEST_F(TestScheduleDesc, StepKind_ReorderWithBlock) {
  lowered_funcs = LowerCompute({32, 32, 64}, target);
  ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs);

  auto loops = ir_sch.GetLoops("B");
  trace.Append(ScheduleDesc::Step(
      "GetLoopsWithName", {}, {{"block_name", std::string("B")}}, loops));
  auto sample = ir_sch.SamplePerfectTile(loops[0], 2, 1, {-1, 4});
  trace.Append(ScheduleDesc::Step("SamplePerfectTile",
                                  {{"loop", std::vector<Expr>({loops[0]})}},
                                  {{"n", 2},
                                   {"max_innermost_factor", 1},
                                   {"decision", std::vector<int>{-1, 4}}},
                                  sample));
  auto splited = ir_sch.Split(loops[0], sample);
  trace.Append(ScheduleDesc::Step(
      "Split",
      {{"loop", std::vector<Expr>({loops[0]})}, {"factors", sample}},
      {},
      splited));

  loops = ir_sch.GetLoops("B");
  trace.Append(ScheduleDesc::Step(
      "GetLoopsWithName", {}, {{"block_name", std::string("B")}}, loops));
  sample = ir_sch.SamplePerfectTile(loops[2], 2, 1, {-1, 2});
  trace.Append(ScheduleDesc::Step("SamplePerfectTile",
                                  {{"loop", std::vector<Expr>({loops[2]})}},
                                  {{"n", 2},
                                   {"max_innermost_factor", 1},
                                   {"decision", std::vector<int>{-1, 2}}},
                                  sample));
  splited = ir_sch.Split(loops[2], sample);
  trace.Append(ScheduleDesc::Step(
      "Split",
      {{"loop", std::vector<Expr>({loops[2]})}, {"factors", sample}},
      {},
      splited));

  auto block_b = ir_sch.GetBlock("B");
  trace.Append(ScheduleDesc::Step(
      "GetBlock", {}, {{"block_name", std::string("B")}}, {block_b}));
  Expr ret = ir_sch.Reorder("B", {2, 3, 1, 4, 0});
  trace.Append(
      ScheduleDesc::Step("ReorderWithBlock",
                         {{"block", std::vector<Expr>({block_b})}},
                         {{"loops_index", std::vector<int>({2, 3, 1, 4, 0})}},
                         {ret}));
  CheckReplayResult(ir_sch, trace);
  CheckReplayResult(ir_sch, ir_sch.GetTraceDesc());
}

// Split two of B's loops, then reorder all five loops of block B given by
// block name plus loop indices; the replay must produce the same module and
// source.
TEST_F(TestScheduleDesc, StepKind_ReorderWithName) {
  lowered_funcs = LowerCompute({32, 32, 64}, target);
  ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs);

  auto loops = ir_sch.GetLoops("B");
  trace.Append(ScheduleDesc::Step(
      "GetLoopsWithName", {}, {{"block_name", std::string("B")}}, loops));
  auto sample = ir_sch.SamplePerfectTile(loops[0], 2, 1, {-1, 4});
  trace.Append(ScheduleDesc::Step("SamplePerfectTile",
                                  {{"loop", std::vector<Expr>({loops[0]})}},
                                  {{"n", 2},
                                   {"max_innermost_factor", 1},
                                   {"decision", std::vector<int>{-1, 4}}},
                                  sample));
  auto splited = ir_sch.Split(loops[0], sample);
  trace.Append(ScheduleDesc::Step(
      "Split",
      {{"loop", std::vector<Expr>({loops[0]})}, {"factors", sample}},
      {},
      splited));

  loops = ir_sch.GetLoops("B");
  trace.Append(ScheduleDesc::Step(
      "GetLoopsWithName", {}, {{"block_name", std::string("B")}}, loops));
  sample = ir_sch.SamplePerfectTile(loops[2], 2, 1, {-1, 2});
  trace.Append(ScheduleDesc::Step("SamplePerfectTile",
                                  {{"loop", std::vector<Expr>({loops[2]})}},
                                  {{"n", 2},
                                   {"max_innermost_factor", 1},
                                   {"decision", std::vector<int>{-1, 2}}},
                                  sample));
  splited = ir_sch.Split(loops[2], sample);
  trace.Append(ScheduleDesc::Step(
      "Split",
      {{"loop", std::vector<Expr>({loops[2]})}, {"factors", sample}},
      {},
      splited));

  Expr ret = ir_sch.Reorder("B", {4, 2, 3, 1, 0});
  trace.Append(
      ScheduleDesc::Step("ReorderWithName",
                         {},
                         {{"block_name", std::string("B")},
                          {"loops_index", std::vector<int>({4, 2, 3, 1, 0})}},
                         {ret}));
  CheckReplayResult(ir_sch, trace);
  CheckReplayResult(ir_sch, ir_sch.GetTraceDesc());
}

// Marking B's outer loop Parallel must replay to the same module and source.
TEST_F(TestScheduleDesc, StepKind_Parallel) {
  lowered_funcs = LowerCompute({32, 64}, target);
  ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs);

  auto loops = ir_sch.GetLoops("B");
  trace.Append(ScheduleDesc::Step(
      "GetLoopsWithName", {}, {{"block_name", std::string("B")}}, loops));
  ir_sch.Parallel(loops[0]);
  trace.Append(ScheduleDesc::Step(
      "Parallel", {{"loop", std::vector<Expr>({loops[0]})}}, {}, {}));
  CheckReplayResult(ir_sch, trace);
  CheckReplayResult(ir_sch, ir_sch.GetTraceDesc());
}

// Records a "Vectorize" step (factor 16) on the second loop of block "B"
// (2-D 32x64 elementwise-copy example) and checks, via CheckReplayResult, both
// the hand-built trace and the trace returned by ir_sch.GetTraceDesc().
TEST_F(TestScheduleDesc, StepKind_Vectorize) {
  lowered_funcs = LowerCompute({32, 64}, target);
  ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs);

  auto loops = ir_sch.GetLoops("B");
  // Guard the loops[1] access below instead of indexing blindly.
  ASSERT_GE(loops.size(), 2U);
  trace.Append(ScheduleDesc::Step(
      "GetLoopsWithName", {}, {{"block_name", std::string("B")}}, loops));

  ir_sch.Vectorize(loops[1], 16);
  trace.Append(ScheduleDesc::Step("Vectorize",
                                  {{"loop", std::vector<Expr>({loops[1]})}},
                                  {{"factor", 16}},
                                  {}));

  CheckReplayResult(ir_sch, trace);
  CheckReplayResult(ir_sch, ir_sch.GetTraceDesc());
}

// Records an "Unroll" step on the second loop of block "B" (2-D 32x2
// elementwise-copy example) and checks, via CheckReplayResult, both the
// hand-built trace and the trace returned by ir_sch.GetTraceDesc().
TEST_F(TestScheduleDesc, StepKind_Unroll) {
  lowered_funcs = LowerCompute({32, 2}, target);
  ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs);

  auto loops = ir_sch.GetLoops("B");
  // Guard the loops[1] access below instead of indexing blindly.
  ASSERT_GE(loops.size(), 2U);
  trace.Append(ScheduleDesc::Step(
      "GetLoopsWithName", {}, {{"block_name", std::string("B")}}, loops));

  ir_sch.Unroll(loops[1]);
  trace.Append(ScheduleDesc::Step(
      "Unroll", {{"loop", std::vector<Expr>({loops[1]})}}, {}, {}));

  CheckReplayResult(ir_sch, trace);
  CheckReplayResult(ir_sch, ir_sch.GetTraceDesc());
}

// Records a "ComputeInline" step on block "B" of the 3-D (32x32x32)
// "elementwise-add_const" example and checks, via CheckReplayResult, both the
// hand-built trace and the trace returned by ir_sch.GetTraceDesc().
TEST_F(TestScheduleDesc, StepKind_ComputeInline) {
  lowered_funcs =
      LowerCompute({32, 32, 32}, target, true, "elementwise-add_const");
  ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs);

  auto block_b = ir_sch.GetBlock("B");
  trace.Append(ScheduleDesc::Step(
      "GetBlock", {}, {{"block_name", std::string("B")}}, {block_b}));

  ir_sch.ComputeInline(block_b);
  trace.Append(
      ScheduleDesc::Step("ComputeInline",
                         {{"schedule_block", std::vector<Expr>({block_b})}},
                         {},
                         {}));

  CheckReplayResult(ir_sch, trace);
  CheckReplayResult(ir_sch, ir_sch.GetTraceDesc());
}

// Records a "ReverseComputeInline" step on block "C" of the 3-D (32x32x32)
// "elementwise-add_const" example and checks, via CheckReplayResult, both the
// hand-built trace and the trace returned by ir_sch.GetTraceDesc().
TEST_F(TestScheduleDesc, StepKind_ReverseComputeInline) {
  lowered_funcs =
      LowerCompute({32, 32, 32}, target, true, "elementwise-add_const");
  ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs);

  auto block_c = ir_sch.GetBlock("C");
  trace.Append(ScheduleDesc::Step(
      "GetBlock", {}, {{"block_name", std::string("C")}}, {block_c}));

  ir_sch.ReverseComputeInline(block_c);
  trace.Append(
      ScheduleDesc::Step("ReverseComputeInline",
                         {{"schedule_block", std::vector<Expr>({block_c})}},
                         {},
                         {}));

  CheckReplayResult(ir_sch, trace);
  CheckReplayResult(ir_sch, ir_sch.GetTraceDesc());
}

// Records a "Bind" step attaching the outermost loop of block "B" (2-D 32x128
// elementwise-copy example) to the thread axis "blockIdx.x", then checks, via
// CheckReplayResult, both the hand-built trace and ir_sch.GetTraceDesc().
TEST_F(TestScheduleDesc, StepKind_Bind) {
  lowered_funcs = LowerCompute({32, 128}, target);
  ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs);

  auto loops = ir_sch.GetLoops("B");
  // Guard the loops[0] access below instead of indexing blindly.
  ASSERT_FALSE(loops.empty());
  trace.Append(ScheduleDesc::Step(
      "GetLoopsWithName", {}, {{"block_name", std::string("B")}}, loops));

  ir_sch.Bind(loops[0], "blockIdx.x");
  trace.Append(ScheduleDesc::Step("Bind",
                                  {{"loop", std::vector<Expr>({loops[0]})}},
                                  {{"thread_axis", std::string("blockIdx.x")}},
                                  {}));

  CheckReplayResult(ir_sch, trace);
  CheckReplayResult(ir_sch, ir_sch.GetTraceDesc());
}

// Builds a small matmul-style reduction C[i, j] = sum_k A[i, k] * B[k, j]
// (M=32, N=2, K=16), records an "Rfactor" step on the third loop of block "C"
// with rf_axis 0, and checks both the produced tensor (CheckTracingOutputs)
// and the replayed schedule (CheckReplayResult) for the hand-built trace and
// for ir_sch.GetTraceDesc().
TEST_F(TestScheduleDesc, StepKind_Rfactor) {
  Expr M(32);
  Expr N(2);
  Expr K(16);

  Placeholder<float> A("A", {M, K});
  Placeholder<float> B("B", {K, N});
  Var k(16, "k0");
  auto C = Compute(
      {M, N},
      [&](Var i, Var j) { return lang::ReduceSum(A(i, k) * B(k, j), {k}); },
      "C");

  lowered_funcs = cinn::lang::LowerVec("test_rfactor",
                                       CreateStages({A, B, C}),
                                       {A, B, C},
                                       {},
                                       {},
                                       nullptr,
                                       target,
                                       true);

  // The global name-id counter is reset around MakeIRSchedule; presumably this
  // keeps generated names deterministic for the replay comparison — confirm
  // against CheckReplayResult before changing.
  cinn::common::Context::Global().ResetNameId();
  ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs);
  cinn::common::Context::Global().ResetNameId();

  auto loops = ir_sch.GetLoops("C");
  // Guard the loops[2] access below instead of indexing blindly.
  ASSERT_GE(loops.size(), 3U);
  trace.Append(ScheduleDesc::Step(
      "GetLoopsWithName", {}, {{"block_name", std::string("C")}}, loops));

  auto new_rf_tensor = ir_sch.Rfactor(loops[2], 0);
  trace.Append(ScheduleDesc::Step("Rfactor",
                                  {{"rf_loop", std::vector<Expr>({loops[2]})}},
                                  {{"rf_axis", 0}},
                                  {new_rf_tensor}));

  CheckTracingOutputs({new_rf_tensor}, trace);
  CheckTracingOutputs({new_rf_tensor}, ir_sch.GetTraceDesc());
  CheckReplayResult(ir_sch, trace);
  CheckReplayResult(ir_sch, ir_sch.GetTraceDesc());
}

// Applies MergeExprs to a module made of two copies of the same lowered body,
// replays the recorded "MergeExprs" step on an identically built schedule, and
// checks that both modules print identically expression by expression.
TEST_F(TestScheduleDesc, StepKind_MergeExprs) {
  auto funcs_0 = LowerCompute({32, 128}, target);

  ir::IRSchedule ir_sch = ir::IRSchedule(ir::ModuleExpr(
      {optim::IRCopy(funcs_0[0]->body), optim::IRCopy(funcs_0[0]->body)}));
  ir_sch.MergeExprs();
  trace.Append(ScheduleDesc::Step("MergeExprs", {}, {}, {}));

  // Replay target: built from fresh copies of the same bodies as ir_sch.
  ir::IRSchedule replay_sch = ir::IRSchedule(ir::ModuleExpr(
      {optim::IRCopy(funcs_0[0]->body), optim::IRCopy(funcs_0[0]->body)}));
  trace.Replay(&replay_sch);

  auto lhs_exprs = ir_sch.GetModule().GetExprs();
  auto rhs_exprs = replay_sch.GetModule().GetExprs();
  ASSERT_EQ(lhs_exprs.size(), rhs_exprs.size());
  // size_t index avoids the signed/unsigned comparison against size().
  for (size_t i = 0; i < lhs_exprs.size(); ++i) {
    ASSERT_EQ(utils::GetStreamCnt(lhs_exprs.at(i)),
              utils::GetStreamCnt(rhs_exprs.at(i)));
  }
}

// Annotates block "B" (2-D 32x128 elementwise-copy example) with one attribute
// of each kind (int, bool, float, string), recording the matching
// Annotate*Attr step each time, and checks the result via CheckReplayResult for
// both the hand-built trace and ir_sch.GetTraceDesc().
TEST_F(TestScheduleDesc, StepKind_Annotate) {
  lowered_funcs = LowerCompute({32, 128}, target);
  ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs);

  // Fetches block "B" afresh and records the corresponding "GetBlock" step.
  // The block is re-fetched before every annotation (rather than reusing the
  // previous Expr), matching the step sequence this test has always recorded.
  auto get_block_b = [&]() {
    auto block = ir_sch.GetBlock("B");
    trace.Append(ScheduleDesc::Step(
        "GetBlock", {}, {{"block_name", std::string("B")}}, {block}));
    return block;
  };

  auto block_b = get_block_b();
  ir_sch.Annotate(block_b, "k1", int(64));
  trace.Append(
      ScheduleDesc::Step("AnnotateIntAttr",
                         {{"block", std::vector<Expr>({block_b})}},
                         {{"key", std::string("k1")}, {"value", int(64)}},
                         {}));

  block_b = get_block_b();
  ir_sch.Annotate(block_b, "k2", bool(true));
  trace.Append(
      ScheduleDesc::Step("AnnotateBoolAttr",
                         {{"block", std::vector<Expr>({block_b})}},
                         {{"key", std::string("k2")}, {"value", bool(true)}},
                         {}));

  block_b = get_block_b();
  ir_sch.Annotate(block_b, "k3", float(2.0));
  trace.Append(
      ScheduleDesc::Step("AnnotateFloatAttr",
                         {{"block", std::vector<Expr>({block_b})}},
                         {{"key", std::string("k3")}, {"value", float(2.0)}},
                         {}));

  block_b = get_block_b();
  ir_sch.Annotate(block_b, "k4", std::string("v4"));
  trace.Append(ScheduleDesc::Step(
      "AnnotateStringAttr",
      {{"block", std::vector<Expr>({block_b})}},
      {{"key", std::string("k4")}, {"value", std::string("v4")}},
      {}));

  CheckReplayResult(ir_sch, trace);
  CheckReplayResult(ir_sch, ir_sch.GetTraceDesc());
}

// Adds an int ("k1") and a bool ("k2") annotation to block "B" (2-D 32x128
// elementwise-copy example), then removes both with Unannotate, recording
// every step, and checks the result via CheckReplayResult for both the
// hand-built trace and ir_sch.GetTraceDesc().
TEST_F(TestScheduleDesc, StepKind_Unannotate) {
  lowered_funcs = LowerCompute({32, 128}, target);
  ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs);

  // Fetches block "B" afresh and records the corresponding "GetBlock" step.
  // The block is re-fetched before every (un)annotation, matching the step
  // sequence this test has always recorded.
  auto get_block_b = [&]() {
    auto block = ir_sch.GetBlock("B");
    trace.Append(ScheduleDesc::Step(
        "GetBlock", {}, {{"block_name", std::string("B")}}, {block}));
    return block;
  };

  auto block_b = get_block_b();
  ir_sch.Annotate(block_b, "k1", int(64));
  trace.Append(
      ScheduleDesc::Step("AnnotateIntAttr",
                         {{"block", std::vector<Expr>({block_b})}},
                         {{"key", std::string("k1")}, {"value", int(64)}},
                         {}));

  block_b = get_block_b();
  ir_sch.Annotate(block_b, "k2", bool(true));
  trace.Append(
      ScheduleDesc::Step("AnnotateBoolAttr",
                         {{"block", std::vector<Expr>({block_b})}},
                         {{"key", std::string("k2")}, {"value", bool(true)}},
                         {}));

  block_b = get_block_b();
  ir_sch.Unannotate(block_b, "k1");
  trace.Append(ScheduleDesc::Step("Unannotate",
                                  {{"block", std::vector<Expr>({block_b})}},
                                  {{"key", std::string("k1")}},
                                  {}));

  block_b = get_block_b();
  ir_sch.Unannotate(block_b, "k2");
  trace.Append(ScheduleDesc::Step("Unannotate",
                                  {{"block", std::vector<Expr>({block_b})}},
                                  {{"key", std::string("k2")}},
                                  {}));

  CheckReplayResult(ir_sch, trace);
  CheckReplayResult(ir_sch, ir_sch.GetTraceDesc());
}

// Samples a 2-way perfect tiling (max innermost factor 64) of the single loop
// of B[i] = A[i] + n (length 1024), records the "SamplePerfectTile" step with
// the sampled decision, and checks both the sampled outputs
// (CheckTracingOutputs) and the replayed schedule (CheckReplayResult) for the
// hand-built trace and for ir_sch.GetTraceDesc().
TEST_F(TestScheduleDesc, StepKind_SamplePerfectTile) {
  Expr M(1024);
  Var n(1, "n");

  Placeholder<int> A("A", {M});
  auto B = Compute(
      {M}, [&](Expr i) { return A(i) + n; }, "B");
  lowered_funcs = cinn::lang::LowerVec("test_sample_perfect_tile",
                                       CreateStages({A, B}),
                                       {A, B},
                                       {},
                                       {},
                                       nullptr,
                                       target,
                                       true);

  ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs);
  auto loops = ir_sch.GetLoops("B");
  // Guard the loops[0] access below instead of indexing blindly.
  ASSERT_FALSE(loops.empty());
  trace.Append(ScheduleDesc::Step(
      "GetLoopsWithName", {}, {{"block_name", std::string("B")}}, loops));

  auto result = ir_sch.SamplePerfectTile(loops[0], 2, 64);
  // The recorded "decision" attribute is the sampled tile sizes as ints.
  std::vector<int> decision;
  decision.reserve(result.size());
  std::transform(
      result.begin(), result.end(), std::back_inserter(decision), [](Expr x) {
        return x.as_int32();
      });
  trace.Append(ScheduleDesc::Step(
      "SamplePerfectTile",
      {{"loop", std::vector<Expr>({loops[0]})}},
      {{"n", 2}, {"max_innermost_factor", 64}, {"decision", decision}},
      result));

  CheckTracingOutputs(result, trace);
  CheckTracingOutputs(result, ir_sch.GetTraceDesc());
  CheckReplayResult(ir_sch, trace);
  CheckReplayResult(ir_sch, ir_sch.GetTraceDesc());
}

// Samples one candidate from {1, 2, 3} with weights {1.0, 2.0, 3.0}, records
// the "SampleCategorical" step with the sampled value as its decision, and
// checks both the sampled output (CheckTracingOutputs) and the replayed
// schedule (CheckReplayResult) for the hand-built trace and for
// ir_sch.GetTraceDesc().
TEST_F(TestScheduleDesc, StepKind_SampleCategorical) {
  lowered_funcs = LowerCompute({32, 32, 64}, target, true);
  ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs);

  Expr ret = ir_sch.SampleCategorical({1, 2, 3}, {1.0, 2.0, 3.0});
  std::vector<int> decision = {ret.as_int32()};
  trace.Append(
      ScheduleDesc::Step("SampleCategorical",
                         {},
                         {{"candidates", std::vector<int>({1, 2, 3})},
                          {"probs", std::vector<float>({1.0, 2.0, 3.0})},
                          {"decision", decision}},
                         {ret}));

  CheckTracingOutputs({ret}, trace);
  CheckTracingOutputs({ret}, ir_sch.GetTraceDesc());
  CheckReplayResult(ir_sch, trace);
  CheckReplayResult(ir_sch, ir_sch.GetTraceDesc());
}

}  // namespace ir
}  // namespace cinn