/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "scop_info.h"
#include <regex>
#include "poly/dma_inject.h"

namespace akg {
namespace ir {
namespace poly {
constexpr int kInvalidIntAttr = -1;
Expr kInvalidExprAttr;

CubeInfo::~CubeInfo() {
  if (model_ != nullptr) {
    delete model_;
    model_ = nullptr;
  }
}
bool CubeInfo::IsConvBackpropInput() const {
  int n = ExtractIntFromAttrs(ATTR_CONV_BACKPROP_INPUT);
  return (IsConv() && (n != kInvalidIntAttr));
}

bool CubeInfo::IsConvBackpropFilter() const {
  int n = ExtractIntFromAttrs(ATTR_CONV_BACKPROP_FILTER);
  return (IsConv() && (n != kInvalidIntAttr));
}

Expr CubeInfo::ExtractExprFromAttrs(const std::string &name) const {
  for (auto i : analysis_result_.GetStmtOpInfoMap()) {
    if (!i.second.isCube) {
      continue;
    }

    const Node *stmt_node = analysis_result_.GetStatementMap().at(i.first);
    CHECK(stmt_node != nullptr);
    if (stmt_node->IsInstance<Provide>()) {
      auto provide = static_cast<const Provide *>(stmt_node);
      if (const auto cop = provide->func.as<ComputeOpNode>()) {
        if (cop->attrs.count(name) != 0) {
          return air::Downcast<Expr>(cop->attrs.at(name));
        }
      }
    }
  }
  return kInvalidExprAttr;
}

int CubeInfo::ExtractIntFromAttrs(const std::string &name) const {
  Expr expr_attr = ExtractExprFromAttrs(name);
  if (expr_attr.defined()) {
    if (const auto int_op = expr_attr.as<IntImm>())
      return int_op->value;
    else
      LOG(FATAL) << "attr " << name << " is not an integer";
  }
  return kInvalidIntAttr;
}

std::unordered_set<std::string> AnalysisResult::ExtractWithStmtId() const {
  std::unordered_set<std::string> res;
  for (auto i : GetStmtOpInfoMap()) {
    if (!i.second.isWith) {
      continue;
    }
    res.insert(i.first.get_name());
  }
  return res;
}

std::string CubeInfo::ExtractStringFromAttrs(const std::string &name) const {
  for (auto i : analysis_result_.GetStmtOpInfoMap()) {
    if (!i.second.isCube) {
      continue;
    }

    const Node *stmt_node = analysis_result_.GetStatementMap().at(i.first);
    if (stmt_node->IsInstance<Provide>()) {
      auto provide = static_cast<const Provide *>(stmt_node);
      if (const auto cop = provide->func.as<ComputeOpNode>()) {
        if (cop->attrs.count(name) != 0) {
          if (const auto str_op = cop->attrs.at(name).as<StringImm>()) {
            return str_op->value;
          } else {
            LOG(FATAL) << "attr " << name << " is not a string";
          }
        }
      }
    }
  }
  return "";
}

std::string CubeInfo::ExtractStringFromAttrsAndInfo(const std::string &name) const {
  for (auto i : analysis_result_.GetStmtOpInfoMap()) {
    if (!i.second.isCube) {
      continue;
    }

    const Node *stmt_node = analysis_result_.GetStatementMap().at(i.first);
    if (stmt_node->IsInstance<Provide>()) {
      auto provide = static_cast<const Provide *>(stmt_node);
      if (const auto cop = provide->func.as<ComputeOpNode>()) {
        if (cop->attrs.count(name) != 0) {
          if (const auto str_op = cop->attrs.at(name).as<StringImm>()) {
            return str_op->value;
          } else {
            LOG(FATAL) << "attr " << name << " is not a string";
          }
        }
      }
    }
  }

  if (GetConvAttrInfo().count(name) >= 1) {
    if (const auto str_op = GetConvAttrInfo().at(name).as<StringImm>()) {
      return str_op->value;
    } else {
      LOG(FATAL) << "attr " << name << " is not a string";
    }
  }

  return "";
}

bool ScopInfo::IsElewiseVMStmt(const isl::id &id) const {
  auto stmt = analysis_result_.GetStatementMap().at(id);
  if (stmt != nullptr && stmt->IsInstance<Provide>()) {
    auto provide = static_cast<const Provide *>(stmt);
    if (auto call = provide->value.as<Call>()) {
      if (call->call_type != Call::CallType::Halide && (call->name == "vmadd" || call->name == "vmla")) return true;
    }
  }
  return false;
}

bool ScopInfo::MayWriteAfterRead(const std::string &name) const {
  std::map<int, isl::id> def;
  std::map<int, isl::id> use;
  for (auto a : analysis_result_.GetWrites().get_map_list()) {
    isl::id id = a.domain().unwrap().domain().get_tuple_id();
    std::string idstr = id.get_name();
    if (a.get_tuple_id(isl_dim_out).get_name() != name) continue;
    CHECK_GE(idstr.size(), 2);
    idstr = idstr.substr(2, idstr.size());
    int ref = static_cast<int>(WrappedStrtol(idstr));
    def[ref] = id;
  }
  for (auto a : analysis_result_.GetReads().get_map_list()) {
    isl::id id = a.domain().unwrap().domain().get_tuple_id();
    std::string idstr = id.get_name();
    if (a.get_tuple_id(isl_dim_out).get_name() != name) continue;
    CHECK_GE(idstr.size(), 2);
    idstr = idstr.substr(2, idstr.size());
    int ref = static_cast<int>(WrappedStrtol(idstr));
    use[ref] = id;
  }

  if (def.empty() || use.empty()) return false;
  if (def.begin()->first >= use.begin()->first) return true;
  // if a statement of the form A = f(A) exists, we conservatively assume a WAR dependence
  for (auto i : def) {
    if (use.count(i.first)) {
      // vmadd/vmla insn is in the form A = f(A), but there is no WAR dependence
      if (!IsElewiseVMStmt(i.second)) return true;
    }
  }
  return false;
}
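
// Illustration (assuming statement ids of the form "S_<n>", as implied by the
// substr(2) above): with writes {S_0 -> A, S_2 -> A} and reads {S_1 -> A},
// def = {0, 2} and use = {1}. The earliest write (0) precedes the earliest
// read (1), so a WAR is reported only if some S_n both reads and writes A and
// is not a vmadd/vmla statement of the form A = f(A).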

bool CubeInfo::IsA(const std::string &name) const {
  for (auto &info : analysis_result_.GetStmtOpInfoMap()) {
    if (info.second.isCube) {
      if (info.second.A_ == name) {
        return true;
      }
    }
  }
  return false;
}

bool CubeInfo::IsB(const std::string &name) const {
  for (auto &info : analysis_result_.GetStmtOpInfoMap()) {
    if (info.second.isCube) {
      if (info.second.B_ == name) {
        return true;
      }
    }
  }
  return false;
}

bool CubeInfo::IsC(const std::string &name) const {
  for (auto &info : analysis_result_.GetStmtOpInfoMap()) {
    if (info.second.isCube) {
      if (info.second.C_ == name) {
        return true;
      }
    }
  }
  return false;
}

bool CubeInfo::IsCUB(const std::string &name) const {
  for (auto &info : analysis_result_.GetStmtOpInfoMap()) {
    if (info.second.isCube) {
      if (info.second.C_ + "_local_UB" == name) {
        return true;
      }
    }
  }
  return false;
}

std::string CubeInfo::GetAName() const {
  for (auto &info : analysis_result_.GetStmtOpInfoMap()) {
    if (info.second.isCube) {
      return info.second.A_;
    }
  }
  return "";
}

std::string CubeInfo::GetBName() const {
  for (auto &info : analysis_result_.GetStmtOpInfoMap()) {
    if (info.second.isCube) {
      return info.second.B_;
    }
  }
  return "";
}

std::string CubeInfo::GetCName() const {
  for (auto &info : analysis_result_.GetStmtOpInfoMap()) {
    if (info.second.isCube) {
      return info.second.C_;
    }
  }
  return "";
}

bool CubeInfo::IsIm2col() const {
  for (auto &info : analysis_result_.GetStmtOpInfoMap()) {
    if (info.second.isIm2col) return true;
  }
  return false;
}

bool CubeInfo::IsLoad3dL1Ub() const {
  for (auto &info : analysis_result_.GetStmtOpInfoMap()) {
    if (info.second.isLoad3d) return true;
  }
  return false;
}

bool CubeInfo::IsLoad3dL1UBStmt(const std::string &stmt_name) const {
  for (auto &info : analysis_result_.GetStmtOpInfoMap()) {
    if (info.second.isLoad3d && info.first.name() == stmt_name) {
      return true;
    }
  }
  return false;
}

bool CubeInfo::HasCube() const {
  for (auto &info : analysis_result_.GetStmtOpInfoMap()) {
    if (info.second.isCube) return true;
  }
  return false;
}

bool CubeInfo::IsGemmDataTransposeBlock() const {
  std::string trans_data_block = ExtractStringFromAttrsAndInfo(ATTR_GEMM_DATA_TRANSPOSE_BLOCK);
  return IsGemm() && !IsSpecGemm() && (trans_data_block == "Y");
}

bool CubeInfo::IsGemmWeightTransposeBlock() const {
  std::string trans_weight_block = ExtractStringFromAttrsAndInfo(ATTR_GEMM_WEIGHT_TRANSPOSE_BLOCK);
  return IsGemm() && !IsSpecGemm() && (trans_weight_block == "Y");
}

bool CubeInfo::IsGemmDataTransposeInnerBlock() const {
  std::string trans_data_inner_block = ExtractStringFromAttrsAndInfo(ATTR_GEMM_DATA_TRANSPOSE_BLOCK_INNER);
  return IsGemm() && !IsSpecGemm() && (trans_data_inner_block == "Y");
}
bool CubeInfo::IsGemmWeightTransposeInnerBlock() const {
  std::string trans_weight_inner_block = ExtractStringFromAttrsAndInfo(ATTR_GEMM_WEIGHT_TRANSPOSE_BLOCK_INNER);
  return IsGemm() && !IsSpecGemm() && (trans_weight_inner_block == "Y");
}
bool CubeInfo::IsGemmDataTranspose() const {
  std::string trans_data = ExtractStringFromAttrsAndInfo(ATTR_GEMM_DATA_TRANSPOSE);
  return IsGemm() && !IsSpecGemm() &&
         ((trans_data == "Y") || IsGemmDataTransposeBlock() || IsGemmDataTransposeInnerBlock());
}

bool CubeInfo::IsGemmWeightTranspose() const {
  std::string trans_weight = ExtractStringFromAttrsAndInfo(ATTR_GEMM_WEIGHT_TRANSPOSE);
  return IsGemm() && !IsSpecGemm() &&
         ((trans_weight == "Y") || IsGemmWeightTransposeBlock() || IsGemmWeightTransposeInnerBlock());
}

bool CubeInfo::IsGemm() const { return HasCube() && !IsConv(); }

bool CubeInfo::IsConv() const {
  std::string n = ExtractStringFromAttrs(ATTR_CONV_FEATURE_NAME);
  return (!n.empty());
}

void CubeInfo::UpdateComputeAttrInfo() {
  if (IsConv()) {
    FindComputeAttr(ConvATTRList);
  } else if (IsLoad3dL1Ub()) {
    FindComputeAttr(FastPoolingATTRList);
  }
}

void CubeInfo::FindComputeAttr(const std::vector<std::string> &op_keys) {
  for (auto i : analysis_result_.GetStmtOpInfoMap()) {
    if (i.second.isCube || i.second.isLoad3d) {
      const Node *stmt_node = analysis_result_.GetStatementMap().at(i.first);
      if (stmt_node->IsInstance<Provide>()) {
        auto provide = static_cast<const Provide *>(stmt_node);
        const auto cop = provide->func.as<ComputeOpNode>();
        if (cop != nullptr) {
          for (auto j : op_keys) {
            std::string err = "Error: You need to set attr feature " + j + " at akg.tvm.compute()!";
            CHECK(cop->attrs.count(j) != 0) << err;
          }
          SetConvAttrInfo(cop->attrs);
        }
      }
      break;
    }
  }
}

std::string CubeInfo::ConvOutName() {
  for (auto stmt : analysis_result_.GetStmtOpInfoMap()) {
    if (stmt.second.isCube) {
      return stmt.second.C_;
    }
  }
  return "";
}

bool CubeInfo::IsFilterCanByPass() {
  bool can_bypass = true;
  auto filter_name = ExtractStringFromAttrs(ATTR_CONV_FILTER_NAME);
  auto tensor_mem_flows = analysis_result_.GetTensorMemFlows();
  if (tensor_mem_flows.count(filter_name)) {
    auto filter_memflow = tensor_mem_flows[filter_name];
    auto it = std::find(filter_memflow.begin(), filter_memflow.end(), UBL1_);
    if (it != filter_memflow.end()) can_bypass = false;
  }
  return can_bypass;
}

Tensor ScopInfo::FindTensorInOrig(const isl::id &var) {
  auto binds_orig = user_config_.GetOriginBind();
  for (auto i : binds_orig) {
    if (i.first->op->name == var.get_name()) {
      return i.first;
    }
  }
  return Tensor();
}

Tensor ScopInfo::FindTensorInOrig(const std::string &str) {
  auto binds_orig = user_config_.GetOriginBind();
  for (auto i : binds_orig) {
    if (i.first->op->name == str) {
      return i.first;
    }
  }
  return Tensor();
}

// find the dtype of global buffer by the tensor name
Type ScopInfo::GetDtypeOf(const std::string &tensor_name) const {
  auto binds = user_config_.GetBind();
  for (auto i : binds) {
    if (i.first->op->name == tensor_name) {
      return i.second->dtype;
    }
  }
  LOG(INFO) << "no such tensor in binds: " << tensor_name;
  return Int(32);
}

Type ScopInfo::GetDtypeOf(const isl::ast_expr &e) const {
  if (auto op = e.as<isl::ast_expr_op>()) {
    isl::id var = op.get_arg(0).as<isl::ast_expr_id>().get_id();
    return GetDtypeOf(var);
  }
  return Int(32);
}

bool ScopInfo::IsInBinds(const std::string &name) const {
  auto binds_orig = user_config_.GetOriginBind();
  for (auto i : binds_orig) {
    if (name == i.first->op->name) {
      return true;
    }
  }
  return false;
}

air::DataType CubeInfo::MadCastType() {
  for (auto stmt : analysis_result_.GetStmtOpInfoMap()) {
    if (stmt.second.isCube) {
      return stmt.second.MadType_;
    }
  }
  return Float(16);
}

int CubeInfo::GetAttrValue(const std::string &key) {
  Map<std::string, NodeRef> attr_info = GetConvAttrInfo();
  CHECK(attr_info.find(key) != attr_info.end());
  if (attr_info[key].as<IntImm>() != nullptr) return attr_info[key].as<IntImm>()->value;
  if (attr_info[key].as<FloatImm>() != nullptr) {
    float res = attr_info[key].as<FloatImm>()->value;
    LOG(WARNING) << "attr: " << key << " : should be an integer, but found float. Force convert to int.";
    return static_cast<int>(res);
  }
  return -1;
}

Tensor ScopInfo::FindTensorWithLargestShape(const std::string &name) {
  size_t largest_size = 0;
  Tensor largest_tensor;
  for (auto i : analysis_result_.buffer_def_infos_) {
    if (!i.tensor.defined()) continue;
    if (i.dst_tensor_id.get_name() == name) {
      size_t tensor_size = 1;
      for (auto dim : i.tensor->shape) {
        if (dim.as<IntImm>()) {
          tensor_size *= dim.as<IntImm>()->value;
        }
      }
      if (tensor_size > largest_size) {
        largest_size = tensor_size;
        largest_tensor = i.tensor;
      }
    }
  }
  auto binds = user_config_.GetBind();
  for (auto i : binds) {
    if (!i.first.defined()) continue;
    if (i.first->op->name == name) {
      size_t tensor_size = 1;
      for (auto dim : i.first->shape) {
        if (dim.as<IntImm>()) {
          tensor_size *= dim.as<IntImm>()->value;
        }
      }
      if (tensor_size > largest_size) {
        largest_size = tensor_size;
        largest_tensor = i.first;
      }
    }
  }
  if (largest_size > 0) return largest_tensor;
  CHECK(false) << name << " is declared in neither binds nor promoted arrays";
  return Tensor();
}
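
// e.g. (hypothetical shapes) if "B" is bound with shape [32, 32] but was promoted
// with a padded shape [32, 64], the promoted tensor wins: 32 * 64 = 2048 elements
// versus 32 * 32 = 1024, and the larger one is returned.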

Tensor ScopInfo::FindTensorWithLargestShape(const isl::id &var) { return FindTensorWithLargestShape(var.get_name()); }

Tensor ScopInfo::FindTensor(const std::string &str) {
  for (auto i : analysis_result_.buffer_def_infos_) {
    if (str == i.dst_tensor_id.get_name() && i.is_bind_tensor && i.tensor.defined()) {
      return i.tensor;
    }
  }
  auto binds = user_config_.GetBind();
  for (auto i : binds) {
    if (i.first->op->name == str) {
      return i.first;
    }
  }
  CHECK(false) << str << " is declared in neither binds nor promoted arrays";
  return Tensor();
}

Tensor ScopInfo::FindTensor(const isl::id &var) {
  for (const auto &i : analysis_result_.buffer_def_infos_) {
    if (i.dst_tensor_id.get_name() == var.get_name() && i.is_bind_tensor && i.tensor.defined()) {
      return i.tensor;
    }
  }
  auto binds = user_config_.GetBind();
  for (const auto &i : binds) {
    if (i.first->op->name == var.get_name()) {
      return i.first;
    }
  }
  CHECK(false) << var.to_str() << " is declared in neither binds nor promoted arrays";
  return Tensor();
}

isl::id ScopInfo::GetOriginTensorId(const std::string &name) const {
  std::string tensor_name = name;
  size_t pos = name.find("_local_");
  if (std::string::npos != pos) {
    tensor_name = name.substr(0, pos);
  }
  return isl::id(GetCtx(), tensor_name);
}
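
// e.g. GetOriginTensorId("input_local_UB") returns the isl id of "input":
// everything from the first "_local_" on is treated as the promotion suffix.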

isl::id ScopInfo::GetOriginTensorId(const isl::id &id) const { return GetOriginTensorId(id.get_name()); }

bool CubeInfo::InitRangeStrideVec() {
  if (!GetRangeStride().empty()) return false;

  if (GetRangeInfo().empty()) {
    LOG(WARNING) << "range_info is not specified, please check";
    return false;
  }

  RecordRangeStrideBack(1);
  for (uint64_t i = GetRangeInfo().size(); i >= 1; --i) {
    RecordRangeStrideFront(GetRangeInfo()[i - 1].size() * (unsigned int)GetRangeStride()[0]);
  }
  return true;
}
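
// Illustration (hypothetical range_info_): with three isolate dimensions holding
// 2, 3 and 4 ranges, the stride vector becomes [24, 12, 4, 1]; a flat isolate
// index then decomposes into mixed-radix digits in GetIsolateVec below, e.g.
// 17 -> [17 % 24 / 12, 17 % 12 / 4, 17 % 4 / 1] = [1, 1, 1].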

std::vector<int> CubeInfo::GetIsolateVec(int range_idx) {
  static_cast<void>(InitRangeStrideVec());
  std::vector<int> idx;
  for (unsigned int i = 0; i < GetRangeStride().size() - 1; i++) {
    CHECK_NE(GetRangeStride()[i], 0);
    CHECK_NE(GetRangeStride()[i + 1], 0);
    idx.push_back(range_idx % GetRangeStride()[i] / GetRangeStride()[i + 1]);
  }
  return idx;
}

std::vector<Range> CubeInfo::GetRange(int range_idx) {
  std::vector<int> idx = GetIsolateVec(range_idx);
  std::vector<Range> res;
  CHECK(idx.size() == GetRangeInfo().size());
  for (unsigned int i = 0; i < idx.size(); i++) {
    res.push_back(GetRangeInfo()[i][(unsigned int)idx[i]]);
  }
  return res;
}

std::unordered_map<std::string, Expr> CubeInfo::GetConvInfoForTiling() {
  std::unordered_map<std::string, Expr> conv_info;
  conv_info[ATTR_CONV_FEATURE_H] = this->ExtractExprFromAttrs(ATTR_CONV_FEATURE_H);
  conv_info[ATTR_CONV_FEATURE_W] = this->ExtractExprFromAttrs(ATTR_CONV_FEATURE_W);
  conv_info[ATTR_CONV_KERNEL_H] = this->ExtractExprFromAttrs(ATTR_CONV_KERNEL_H);
  conv_info[ATTR_CONV_KERNEL_W] = this->ExtractExprFromAttrs(ATTR_CONV_KERNEL_W);
  conv_info[ATTR_CONV_PAD_TOP] = this->ExtractExprFromAttrs(ATTR_CONV_PAD_TOP);
  conv_info[ATTR_CONV_PAD_LEFT] = this->ExtractExprFromAttrs(ATTR_CONV_PAD_LEFT);
  conv_info[ATTR_CONV_STRIDE_H] = this->ExtractExprFromAttrs(ATTR_CONV_STRIDE_H);
  conv_info[ATTR_CONV_STRIDE_W] = this->ExtractExprFromAttrs(ATTR_CONV_STRIDE_W);
  conv_info[ATTR_CONV_DILATION_H] = this->ExtractExprFromAttrs(ATTR_CONV_DILATION_H);
  conv_info[ATTR_CONV_DILATION_W] = this->ExtractExprFromAttrs(ATTR_CONV_DILATION_W);
  return conv_info;
}

void CubeInfo::SetConvMNKInfo() {
  TileSizes &dimInfos_conv = analysis_result_.GetTileSizes();
  TileSizes L1_factors;
  TileSizes L0_factors;

  std::unordered_set<std::string> conv_pragmas = {
    ATTR_CONV_TILE_W, ATTR_CONV_TILE_H,  ATTR_CONV_TILE_CO, ATTR_CONV_TILE_M,  ATTR_CONV_TILE_N,
    ATTR_CONV_TILE_K, ATTR_CONV_M_INNER, ATTR_CONV_N_INNER, ATTR_CONV_K_INNER, ATTR_CONV_TILE_CIN,
    ATTR_CONV_TILE_B, ATTR_CONV_TILE_KH, ATTR_CONV_TILE_KW};

  for (auto dim : dimInfos_conv) {
    if (conv_pragmas.find(dim.axis) != conv_pragmas.end()) {
      L0_factors.emplace_back(dim);
    } else {
      L1_factors.emplace_back(dim);
    }
  }
  analysis_result_.SetTileSizes(L1_factors);
  SetConvMNKDims(L0_factors);
  auto conv_mnk_dims = GetConvMNKDims();
  if (user_config_.GetIsDynamic()) {
    for (const auto &dim : conv_mnk_dims) {
      fractal_int_info_[dim.axis] = IntImm::make(Int(32), dim.l1_tiling_size);
      attr_info_.Set(dim.axis, IntImm::make(Int(32), dim.l1_tiling_size));
    }
  } else {
    const int c0_size = 16;
    const int int_imm_num_bits = 32;
    for (const auto &dim : conv_mnk_dims) {
      int l0tile = static_cast<int>(dim.l0_tiling_size);
      if (dim.axis == ATTR_CONV_TILE_M || dim.axis == ATTR_CONV_TILE_N || dim.axis == ATTR_CONV_TILE_K) {
        // multiply the outer tile size by the inner (c0) block size
        l0tile *= c0_size;
      }
      fractal_int_info_[dim.axis] = l0tile;
      attr_info_.Set(dim.axis, IntImm::make(Int(int_imm_num_bits), l0tile));
    }
  }
}
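
// e.g. (static shape, hypothetical tiling) an L0 tile of 4 recorded on
// ATTR_CONV_TILE_M is stored as 4 * 16 = 64 in fractal_int_info_, because the
// outer tile counts fractal blocks of c0 = 16.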

void UserConfig::CollectParams() {
  auto FloorDivToDiv = [](Expr expr) -> Expr {
    if (const auto add = expr.as<air::ir::Add>()) {
      // case 1: floordiv(a, b) + 1 ==> (a + b) / b
      if (const auto imm = add->b.as<IntImm>()) {
        if (imm->value == 1) {
          if (const auto fd = add->a.as<air::ir::FloorDiv>()) {
            if (const auto denominator = fd->b.as<IntImm>()) {
              if (denominator->value == 2) {
                return CanonicalSimplify(air::ir::Div::make((fd->a + fd->b), fd->b));
              }
            }
            return air::ir::Div::make(CanonicalSimplify(fd->a), fd->b) + 1;
          }
        }
      }
    }
    return expr;
  };
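  // e.g. floordiv(n, 2) + 1 is rewritten to (n + 2) / 2, while floordiv(n, 3) + 1
  // becomes n / 3 + 1: only the denominator-2 case folds the "+ 1" into the division.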
  auto binds_orig = GetOriginBind();
  for (auto x : binds_orig) {
    for (const auto &expr : x.second->shape) {
      if (!is_const(expr)) {
        RegisterParam(FloorDivToDiv(expr));
      }
    }
  }
  auto outer_let_stmts = GetOuterLetStmts();
  for (auto it : outer_let_stmts) {
    if (auto let_op = it.as<LetStmt>()) {
      if (let_op->var.type().is_int() || let_op->var.type().is_uint()) {
        CHECK(params_.count(let_op->var->name_hint) == 0) << "duplicate name in params: " << let_op->var;
        params_.emplace(let_op->var->name_hint, let_op->var);
        params_rev_map_.emplace(let_op->var->name_hint, let_op->var);
      }
    }
  }
}

std::pair<std::string, std::string> ExprToString(const Expr &expr) {
  std::ostringstream os;
  if (auto var = expr.as<Variable>()) {
    os << var->name_hint;
  } else {
    os << expr;
  }
  std::string expr_str = os.str();

  std::string name = expr_str;
  // replace special chars with '_'
  std::replace_if(
    name.begin(), name.end(), [](const char c) -> bool { return !std::isalnum(c); }, '_');
  // remove leading '_'
  auto it = std::find_if(name.begin(), name.end(), [](const char c) { return c != '_'; });
  name.erase(name.begin(), it);
  // remove redundant '_'
  std::regex rx("_+");
  name = std::regex_replace(name, rx, "_");
  return std::pair<std::string, std::string>(expr_str, name);
}
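
// e.g. for an expression printed as "(n + 1)", the returned pair is
// {"(n + 1)", "n_1_"}: non-alphanumeric characters map to '_', leading
// underscores are stripped, and runs of '_' are collapsed (the trailing '_'
// from the closing parenthesis remains).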

void UserConfig::RegisterParam(const Expr &expr) {
  if (is_const(expr)) return;
  if (auto op = expr.as<air::ir::Mul>()) {
    if (is_const(op->a)) {
      RegisterParam(op->b);
      return;
    }
    if (is_const(op->b)) {
      RegisterParam(op->a);
      return;
    }
  } else if (auto add = expr.as<air::ir::Add>()) {
    RegisterParam(add->a);
    RegisterParam(add->b);
    return;
  } else if (auto sub = expr.as<air::ir::Sub>()) {
    RegisterParam(sub->a);
    RegisterParam(sub->b);
    return;
  } else if (auto floor_div = expr.as<air::ir::FloorDiv>()) {
    RegisterParam(floor_div->a);
    RegisterParam(floor_div->b);
    return;
  }

  // register the expression itself
  auto pair = ExprToString(expr);
  auto expr_str = pair.first;
  auto name = pair.second;
  if (params_.count(expr_str) > 0) return;
  if (params_rev_map_.count(name) > 0) {
    int suffix = 1;
    while (params_rev_map_.count(name + std::to_string(suffix)) > 0) ++suffix;
    name = name + std::to_string(suffix);
  }
  params_.emplace(expr_str, Variable::make(expr.type(), name));
  params_rev_map_.emplace(name, expr);
}
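
// e.g. RegisterParam(16 * m + n) recurses through the Add/Mul nodes and registers
// the non-constant leaves "m" and "n"; a non-affine term such as m * n (neither
// operand constant) is registered whole as a single opaque parameter.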

void CubeInfo::CreateConvModel() {
  if (model_) return;
  if (!attr_info_.empty()) {
    if (attr_info_.count(ATTR_CONV_BACKPROP_INPUT) > 0) {
      try {
        model_ = new ConvolutionBackpropInputModel(attr_info_, user_config_.GetIsDynamic());
      } catch (const std::bad_alloc &) {
        LOG(FATAL) << "bad_alloc exception occurred when constructing ConvolutionBackpropInputModel";
      }
    } else if (attr_info_.count(ATTR_CONV_BACKPROP_FILTER) > 0) {
      try {
        model_ = new ConvolutionBackpropFilterModel(attr_info_, user_config_.GetIsDynamic());
      } catch (const std::bad_alloc &) {
        LOG(FATAL) << "bad_alloc exception occurred when constructing ConvolutionBackpropFilterModel";
      }
    } else {
      try {
        model_ = new ConvolutionForwardModel(attr_info_, user_config_.GetIsDynamic());
      } catch (const std::bad_alloc &) {
        LOG(FATAL) << "bad_alloc exception occurred when constructing ConvolutionForwardModel";
      }
    }
    if (model_) {
      static_cast<void>(model_->infer_L1_tile());
    }
  }
}

void CubeInfo::UpdateFractalIntFirstInfo(bool is_conv_backprop_filter,
                                         const std::vector<size_t> &im2col_fp_cluster_size,
                                         const std::vector<size_t> &fractal_fp_cluster_size) {
  if (is_conv_backprop_filter) {
    UpdateFractalIntFirstInfoConvBackpropFilter(im2col_fp_cluster_size, fractal_fp_cluster_size);
  } else {
    UpdateFractalIntFirstInfoConvForward(im2col_fp_cluster_size, fractal_fp_cluster_size);
  }
}

void CubeInfo::UpdateFractalIntLastInfo(std::vector<size_t> filter_fp_cluster_size) {
  if (IsConvBackpropInput()) {
    CHECK_EQ(filter_fp_cluster_size.size(), 4);
    // conv_backprop_input filter: [ko, no, ni, ki]
    int64_t kh = ExtractIntFromAttrs(ATTR_CONV_KERNEL_H);
    int64_t kw = ExtractIntFromAttrs(ATTR_CONV_KERNEL_W);
    fractal_int_info_[ATTR_CONV_TILE_CO] = (int64_t)filter_fp_cluster_size[0] / (kh * kw);
    fractal_int_info_[ATTR_CONV_TILE_N] = (int64_t)filter_fp_cluster_size[0] / (kh * kw);

    fractal_int_info_[ATTR_CONV_N_INNER] = (int64_t)filter_fp_cluster_size[2];
  } else if (IsConvBackpropFilter()) {
    CHECK_EQ(filter_fp_cluster_size.size(), 5);
    // conv_backprop_filter filter: [batch, no, mo, ni, mi]
    fractal_int_info_[ATTR_CONV_TILE_M] = (int64_t)filter_fp_cluster_size[1];
    fractal_int_info_[ATTR_CONV_M_INNER] = (int64_t)filter_fp_cluster_size[3];
    fractal_int_info_[ATTR_CONV_GMM_M] = (int64_t)filter_fp_cluster_size[1] * filter_fp_cluster_size[3];
  } else {
    CHECK_EQ(filter_fp_cluster_size.size(), 4);
    // conv_forward filter: [ko, no, ni, ki]
    fractal_int_info_[ATTR_CONV_TILE_CO] = (int64_t)filter_fp_cluster_size[1];
    fractal_int_info_[ATTR_CONV_TILE_N] = (int64_t)filter_fp_cluster_size[1];
    fractal_int_info_[ATTR_CONV_N_INNER] = (int64_t)filter_fp_cluster_size[2];
  }
}

void CubeInfo::UpdateSpecGemmFractalInfo(const BufferDefInfo &tensor_info) {
  if (IsConv() && IsB(tensor_info.tensor_id.get_name())) {
    CHECK(tensor_info.footprints_cluster != nullptr);
    UpdateFractalIntLastInfo(tensor_info.footprints_cluster->GetFixedBoxSizes());
    fractal_str_info_[ATTR_CONV_GMM_WEIGHT] = tensor_info.dst_tensor_id.get_name();
    CHECK_NE(tensor_info.dst_tensor_id.get_name(), "");
  } else if (IsConv() && IsA(tensor_info.tensor_id.get_name())) {
    fractal_str_info_[ATTR_CONV_GMM_FEATURE] = tensor_info.data_stream[2].first.get_name();
    CHECK_NE(tensor_info.dst_tensor_id.get_name(), "");
  } else if (IsConv() && IsC(tensor_info.tensor_id.get_name())) {
    fractal_str_info_[ATTR_CONV_GMM_RES] = tensor_info.dst_tensor_id.get_name();
    CHECK_NE(tensor_info.dst_tensor_id.get_name(), "");
  }
}

void CubeInfo::UpdateFractalIntFirstInfoConvBackpropFilter(std::vector<size_t> im2col_fp_cluster_size,
                                                           std::vector<size_t> fractal_fp_cluster_size) {
  CHECK_EQ(fractal_fp_cluster_size.size(), 5);
  fractal_int_info_[ATTR_CONV_BATCH] = (int64_t)fractal_fp_cluster_size[0];
  fractal_int_info_[ATTR_CONV_TILE_K] = (int64_t)fractal_fp_cluster_size[1];
  fractal_int_info_[ATTR_CONV_TILE_N] = (int64_t)fractal_fp_cluster_size[2];
  fractal_int_info_[ATTR_CONV_N_INNER] = (int64_t)fractal_fp_cluster_size[3];
  fractal_int_info_[ATTR_CONV_K_INNER] = (int64_t)fractal_fp_cluster_size[4];

  fractal_int_info_[ATTR_CONV_TILE_CO] = (int64_t)fractal_fp_cluster_size[2];

  CHECK_EQ(im2col_fp_cluster_size.size(), 6);
  fractal_int_info_[ATTR_CONV_GMM_K] = (int64_t)im2col_fp_cluster_size[1];
}

void CubeInfo::UpdateFractalIntFirstInfoConvForward(std::vector<size_t> im2col_fp_cluster_size,
                                                    std::vector<size_t> fractal_fp_cluster_size) {
  CHECK_EQ(fractal_fp_cluster_size.size(), 5);
  fractal_int_info_[ATTR_CONV_BATCH] = (int64_t)fractal_fp_cluster_size[0];
  fractal_int_info_[ATTR_CONV_TILE_M] = (int64_t)fractal_fp_cluster_size[1];
  fractal_int_info_[ATTR_CONV_TILE_K] = (int64_t)fractal_fp_cluster_size[2];
  fractal_int_info_[ATTR_CONV_M_INNER] = (int64_t)fractal_fp_cluster_size[3];
  fractal_int_info_[ATTR_CONV_K_INNER] = (int64_t)fractal_fp_cluster_size[4];

  CHECK_EQ(im2col_fp_cluster_size.size(), 6);
  fractal_int_info_[ATTR_CONV_GMM_M] = (int64_t)im2col_fp_cluster_size[1];
}

void CubeInfo::UpdateFractalIntInfoConvForward(int isolate_idx) {
  auto C0_SIZE = IntImm::make(Int(32), 16);
  fractal_int_info_[ATTR_CONV_TILE_N] = floordiv(model_->get_co_isolate_info(isolate_idx).inner, C0_SIZE);

  Expr m = model_->get_h_win_isolate_info(isolate_idx).inner * model_->get_w_win_isolate_info(isolate_idx).inner;
  fractal_int_info_[ATTR_CONV_GMM_M] = m;
  fractal_int_info_[ATTR_CONV_TILE_M] = floordiv(m + C0_SIZE - 1, C0_SIZE);
  fractal_int_info_[ATTR_CONV_M_INNER] = C0_SIZE;
  fractal_int_info_[ATTR_CONV_M_CUT_SIZE] = model_->get_w_win_isolate_info(isolate_idx).inner;
  if (!user_config_.GetIsDynamic()) {
    if (IsConvBackpropInput()) {
      CHECK(model_->conv_.filter.kh.as<IntImm>());
      CHECK(model_->conv_.filter.kw.as<IntImm>());
      user_config_.SetMatBDimH(model_->conv_.filter.kh.as<IntImm>()->value);
      user_config_.SetMatBDimW(model_->conv_.filter.kw.as<IntImm>()->value);
    }
  } else {
    auto tile_h = ExtractExprFromAttrs(ATTR_CONV_TILE_H);
    tile_h = tile_h.get() ? tile_h : IntImm::make(Int(32), ExtractIntFromAttrs(ATTR_CONV_TILE_H));
    if (!Equal(tile_h, -1)) fractal_int_info_[ATTR_CONV_TILE_H] = tile_h;
    auto tile_w = ExtractExprFromAttrs(ATTR_CONV_TILE_W);
    tile_w = tile_w.get() ? tile_w : IntImm::make(Int(32), ExtractIntFromAttrs(ATTR_CONV_TILE_W));
    if (!Equal(tile_w, -1)) fractal_int_info_[ATTR_CONV_TILE_W] = tile_w;

    fractal_int_info_[ATTR_CONV_KERNEL_H] = IntImm::make(Int(32), ExtractIntFromAttrs(ATTR_CONV_KERNEL_H));
    fractal_int_info_[ATTR_CONV_STRIDE_H] = IntImm::make(Int(32), ExtractIntFromAttrs(ATTR_CONV_STRIDE_H));
    fractal_int_info_[ATTR_CONV_KERNEL_W] = IntImm::make(Int(32), ExtractIntFromAttrs(ATTR_CONV_KERNEL_W));
    fractal_int_info_[ATTR_CONV_STRIDE_W] = IntImm::make(Int(32), ExtractIntFromAttrs(ATTR_CONV_STRIDE_W));
  }
}

void CubeInfo::UpdateFractalIntInfoConvBackpropFilter(int isolate_idx) {
  // gemm_idx order as follow:
  // for (Ci Cut) {
  //   for (KH Cut) {
  //     for (KW Cut) {
  //       for (Co Cut) {
  //         for (Batch Cut) {
  //           for (H Cut) {
  //             for (W Cut) {
  //             }
  //           }
  //         }
  //       }
  //     }
  //   }
  // }

  const int block_size = 16;

  fractal_int_info_[ATTR_SPEC_GEMM_BATCH] = model_->get_b_isolate_info(isolate_idx).inner;
  fractal_int_info_[ATTR_SPEC_GEMM_M] = model_->get_co_isolate_info(isolate_idx).inner;
  CHECK_EQ(fractal_int_info_[ATTR_SPEC_GEMM_M].as<IntImm>()->value % block_size, 0);
  fractal_int_info_[ATTR_SPEC_GEMM_M_ALIGN] = fractal_int_info_[ATTR_SPEC_GEMM_M];
  CHECK(fractal_int_info_[ATTR_SPEC_GEMM_M_ALIGN].as<IntImm>());
  CHECK(model_->tile_.cut_m.as<IntImm>());
  if (fractal_int_info_[ATTR_SPEC_GEMM_M_ALIGN].as<IntImm>()->value < model_->tile_.cut_m.as<IntImm>()->value) {
    fractal_int_info_[ATTR_SPEC_GEMM_TILE_M] = fractal_int_info_[ATTR_SPEC_GEMM_M_ALIGN];
  } else {
    fractal_int_info_[ATTR_SPEC_GEMM_TILE_M] = model_->tile_.cut_m;
  }
  fractal_int_info_[ATTR_SPEC_GEMM_M_ALIGN] =
    fractal_int_info_[ATTR_SPEC_GEMM_M_ALIGN].as<IntImm>()->value / block_size;
  fractal_int_info_[ATTR_SPEC_GEMM_M_INNER] = block_size;
  fractal_int_info_[ATTR_CONV_TILE_M] = fractal_int_info_[ATTR_SPEC_GEMM_M_ALIGN];
  fractal_int_info_[ATTR_CONV_M_INNER] = block_size;

  CHECK(model_->get_h_win_isolate_info(isolate_idx).inner.as<IntImm>());
  CHECK(model_->get_w_win_isolate_info(isolate_idx).inner.as<IntImm>());
  int h_tile = model_->get_h_win_isolate_info(isolate_idx).inner.as<IntImm>()->value;
  int w_tile = model_->get_w_win_isolate_info(isolate_idx).inner.as<IntImm>()->value;
  fractal_int_info_[ATTR_SPEC_GEMM_K] = h_tile * w_tile;
  fractal_int_info_[ATTR_SPEC_GEMM_K_ALIGN] = (h_tile * w_tile + block_size - 1) / block_size * block_size;
  CHECK(fractal_int_info_[ATTR_SPEC_GEMM_K_ALIGN].as<IntImm>());
  CHECK(model_->tile_.cut_k.as<IntImm>());
  if (fractal_int_info_[ATTR_SPEC_GEMM_K_ALIGN].as<IntImm>()->value < model_->tile_.cut_k.as<IntImm>()->value) {
    fractal_int_info_[ATTR_SPEC_GEMM_TILE_K] = fractal_int_info_[ATTR_SPEC_GEMM_K_ALIGN];
  } else {
    fractal_int_info_[ATTR_SPEC_GEMM_TILE_K] = model_->tile_.cut_k;
  }
  fractal_int_info_[ATTR_SPEC_GEMM_K_ALIGN] =
    fractal_int_info_[ATTR_SPEC_GEMM_K_ALIGN].as<IntImm>()->value / block_size;
  fractal_int_info_[ATTR_SPEC_GEMM_K_INNER] = block_size;
  fractal_int_info_[ATTR_CONV_TILE_K] = fractal_int_info_[ATTR_SPEC_GEMM_K_ALIGN];
  fractal_int_info_[ATTR_CONV_K_INNER] = block_size;

  CHECK(model_->get_ci_isolate_info(isolate_idx).inner.as<IntImm>());
  CHECK(model_->get_kh_isolate_info(isolate_idx).inner.as<IntImm>());
  CHECK(model_->get_kw_isolate_info(isolate_idx).inner.as<IntImm>());
  int ci_tile = model_->get_ci_isolate_info(isolate_idx).inner.as<IntImm>()->value;
  int kh_tile = model_->get_kh_isolate_info(isolate_idx).inner.as<IntImm>()->value;
  int kw_tile = model_->get_kw_isolate_info(isolate_idx).inner.as<IntImm>()->value;
  fractal_int_info_[ATTR_SPEC_GEMM_N] = ci_tile * kh_tile * kw_tile;
  CHECK_EQ(fractal_int_info_[ATTR_SPEC_GEMM_N].as<IntImm>()->value % block_size, 0);
  fractal_int_info_[ATTR_SPEC_GEMM_N_ALIGN] = fractal_int_info_[ATTR_SPEC_GEMM_N];
  CHECK(fractal_int_info_[ATTR_SPEC_GEMM_N_ALIGN].as<IntImm>());
  CHECK(model_->tile_.cut_n.as<IntImm>());
  if (fractal_int_info_[ATTR_SPEC_GEMM_N_ALIGN].as<IntImm>()->value < model_->tile_.cut_n.as<IntImm>()->value) {
    fractal_int_info_[ATTR_SPEC_GEMM_TILE_N] = fractal_int_info_[ATTR_SPEC_GEMM_N_ALIGN];
  } else {
    fractal_int_info_[ATTR_SPEC_GEMM_TILE_N] = model_->tile_.cut_n;
  }
  fractal_int_info_[ATTR_SPEC_GEMM_N_ALIGN] =
    fractal_int_info_[ATTR_SPEC_GEMM_N_ALIGN].as<IntImm>()->value / block_size;
  fractal_int_info_[ATTR_SPEC_GEMM_N_INNER] = block_size;
  fractal_int_info_[ATTR_CONV_TILE_N] = fractal_int_info_[ATTR_SPEC_GEMM_N_ALIGN];
  fractal_int_info_[ATTR_CONV_N_INNER] = block_size;

  out_reduce_init_ = 0;
  int l1_reduce_base = model_->b_base * model_->h_base * model_->w_base;
  if ((l1_reduce_base > 1 && isolate_idx % l1_reduce_base == 0) || (l1_reduce_base == 1)) {
    out_reduce_init_ = 1;
  }
}

void CubeInfo::UpdateFractalIntInfo(int gemm_idx) {
  if (IsConvBackpropFilter()) {
    if (!user_config_.GetIsDynamic()) {
      UpdateFractalIntInfoConvBackpropFilter(gemm_idx);
    }
  } else {
    UpdateFractalIntInfoConvForward(gemm_idx);
  }
}

static bool CompareFootprintOfMaps(const isl::map &local_access, const isl::map &global_access) {
  isl::multi_val local_write_footprint = local_access.range_simple_fixed_box_hull().size();
  isl::multi_val global_write_footprint = global_access.range_simple_fixed_box_hull().size();
  if (local_write_footprint.size() != global_write_footprint.size()) return false;
  unsigned int dim = local_write_footprint.size();
  for (unsigned i = 0; i < dim; ++i) {
    if (local_write_footprint.get_val(i) < global_write_footprint.get_val(i)) return false;
  }
  return true;
}
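
// e.g. a local write footprint of sizes [16, 16] covers a global write footprint
// of [16, 8] (every dimension is at least as large), so the write spans the
// whole promoted buffer.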

bool ScopInfo::IsWriteWholeBufferFootPrint(const isl::id &poly_ref_id) const {
  for (const auto &buffer : analysis_result_.active_buffer_footprints_) {
    auto group = buffer.second.cluster;
    for (const auto &reference : group->tensor_foot_prints) {
      if (reference->id == poly_ref_id) {
        CHECK(reference->type == ReferenceType::Write);
        return CompareFootprintOfMaps(reference->scoped_access, group->RichWriteRelations());
      }
    }
  }
  LOG(WARNING) << "buffer for " << poly_ref_id << " is not found";
  return false;
}

/*
 * Checks if a promoted tensor is written conditionally, and there is no other unconditional statement
 * in the same buffer that writes the whole promoted tensor.
 */
bool ScopInfo::IsConditionalWriteTensor(const std::string &name,
                                        const std::vector<std::pair<isl::id, isl::id>> &write_stmts) const {
  bool has_conditional_write = false;
  bool has_unconditional_full_write = false;
  for (const auto &pair : write_stmts) {
    auto stmt_id = pair.first;
    auto poly_ref_id = pair.second;
    CHECK_GT(analysis_result_.GetStatementMap().count(stmt_id), 0);
    const Node *stmt = analysis_result_.GetStatementMap().at(stmt_id);
    if (stmt->IsInstance<IfThenElse>()) {
      has_conditional_write = true;
    } else if (IsWriteWholeBufferFootPrint(poly_ref_id)) {
      has_unconditional_full_write = true;
    }
  }
  return has_conditional_write && !has_unconditional_full_write;
}

void ScopInfo::CollectConditionalWritePromotions() {
  std::unordered_map<std::string, std::vector<std::pair<isl::id, isl::id>>> tensor_write_stmts_map;
  analysis_result_.GetWrites().foreach_map([&tensor_write_stmts_map](const isl::map &map) -> void {
    std::string tensor_name = map.get_tuple_id(isl_dim_out).name();
    isl::id stmt_id = map.domain().unwrap().get_tuple_id(isl_dim_in);
    isl::id poly_ref_id = map.domain().unwrap().get_tuple_id(isl_dim_out);
    tensor_write_stmts_map[tensor_name].push_back(std::make_pair(stmt_id, poly_ref_id));
  });

  auto binds_orig = user_config_.GetOriginBind();
  for (auto bind : binds_orig) {
    auto name = bind.first->op->name;
    if (tensor_write_stmts_map.count(name) == 0) continue;
    if (IsConditionalWriteTensor(name, tensor_write_stmts_map[name])) {
      LOG(INFO) << "found conditionally written promoted tensor: " << name
                << ", buffer will be sinked to the computation.";
      analysis_result_.InsertConditionalWriteBufferFootprints(name);
    }
  }
}

StmtIdHashMap ScopInfo::StmtWriteMap() {
  StmtIdHashMap stmt_write_map;
  isl::union_map write_stmt = analysis_result_.GetWrites().domain_factor_domain();
  for (auto stmt : write_stmt.get_map_list()) {
    auto stmtId = stmt.domain().get_tuple_id();
    auto write_tensor = stmt.get_tuple_id(isl_dim_out);
    stmt_write_map[stmtId].push_back(write_tensor);
  }
  return stmt_write_map;
}

StmtIdHashMap ScopInfo::StmtReadMap() {
  StmtIdHashMap stmt_read_map;
  isl::union_map read_stmt = analysis_result_.GetReads().domain_factor_domain();
  for (auto stmt : read_stmt.get_map_list()) {
    auto stmtId = stmt.domain().get_tuple_id();
    auto read_tensor = stmt.get_tuple_id(isl_dim_out);
    stmt_read_map[stmtId].push_back(read_tensor);
  }
  return stmt_read_map;
}

StmtIdHashMap ScopInfo::StmtCopyinMap() {
  StmtIdHashMap stmt_copyin_map;
  isl::union_map copyin_stmt = analysis_result_.GetCopyin().domain_factor_domain();
  for (auto stmt : copyin_stmt.get_map_list()) {
    auto stmtId = stmt.domain().get_tuple_id();
    auto read_tensor = stmt.get_tuple_id(isl_dim_out);
    stmt_copyin_map[stmtId].push_back(read_tensor);
  }
  return stmt_copyin_map;
}

bool ScopInfo::IsCopyinTensor(const std::string &tensor_name) {
  CHECK_NE(tensor_name, "");
  StmtIdHashMap copyin_map = StmtCopyinMap();
  for (const auto &item : copyin_map) {
    for (const auto &tensor : item.second) {
      if (tensor.get_name() == tensor_name) return true;
    }
  }
  return false;
}

bool CubeInfo::IsConvHeadTail(const std::string &conv_output, const isl::id &stmtId, const StmtOpInfo &op_info,
                              const StmtIdHashMap &op_write_map) {
  if (!IsConv()) return false;

  if (op_info.isCube || op_info.isCubeAssign) return false;

  if (op_info.ops.size() != 1) return false;

  if (op_write_map.find(stmtId) == op_write_map.end()) return false;

  if (op_write_map.at(stmtId).size() != 1) return false;

  if (op_info.ops[0] == PolyOpType::broadcast || op_info.ops[0] == PolyOpType::assignment) {
    isl::id writeId = op_write_map.at(stmtId)[0];
    if (writeId.get_name() == conv_output) return true;
  }

  return false;
}
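
// The "head"/"tail" statements are the single-op broadcast or assignment
// statements that initialize or copy out the conv result: they write only
// conv_output and do no cube computation, so CreateDataFlowInfo below schedules
// them as plain vector statements.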

void ScopInfo::CreateDataFlowInfo() {
  StmtIdHashMap op_write_map = StmtWriteMap();
  StmtIdHashMap op_read_map = StmtReadMap();
  std::string conv_output;
  if (cube_info_.IsConv()) {
    conv_output = cube_info_.ConvOutName();
  }
  uint64_t stmtNum = analysis_result_.GetStmtOpInfoMap().size();
  analysis_result_.stmt_type_.resize(stmtNum);
  DMADataFlow dma_dataflow;
  for (auto stmt : analysis_result_.GetStmtOpInfoMap()) {
    std::string name = stmt.first.get_name();
    size_t pos = name.find("_");
    CHECK(pos != name.size() - 1);
    std::string subNum = name.substr(pos + 1, name.size() - pos - 1);
    char *endptr = nullptr;
    const int radix = 10;
    size_t num = strtol(subNum.c_str(), &endptr, radix);
    if (endptr == nullptr || *endptr != '\0') LOG(FATAL) << "failed to convert string " << subNum << " to number";

    if (cube_info_.IsConv() && cube_info_.IsConvHeadTail(conv_output, stmt.first, stmt.second, op_write_map)) {
      analysis_result_.stmt_type_[num] = std::make_pair(stmt.first.get_name(), STMT_OP_TYPE::VECTOR);
      continue;
    }

    if (stmt.second.isCube && cube_info_.IsConv()) {
      analysis_result_.stmt_type_[num] = std::make_pair(stmt.first.get_name(), STMT_OP_TYPE::CUBE_CONV);
      dma_dataflow.CreateStmtDataFlow(STMT_OP_TYPE::CUBE_CONV, stmt.first, stmt.second, op_read_map, op_write_map);
    }

    if (stmt.second.isCube && !cube_info_.IsConv()) {
      analysis_result_.stmt_type_[num] = std::make_pair(stmt.first.get_name(), STMT_OP_TYPE::CUBE_GEMM);
      dma_dataflow.CreateStmtDataFlow(STMT_OP_TYPE::CUBE_GEMM, stmt.first, stmt.second, op_read_map, op_write_map);
    }

    if (stmt.second.isIm2col || stmt.second.isLoad3d) {
      analysis_result_.stmt_type_[num] = std::make_pair(stmt.first.get_name(), STMT_OP_TYPE::IM2COL_UB);
      dma_dataflow.CreateStmtDataFlow(STMT_OP_TYPE::IM2COL_UB, stmt.first, stmt.second, op_read_map, op_write_map);
    }

    if (!stmt.second.isCube && !stmt.second.isCubeAssign) {
      analysis_result_.stmt_type_[num] = std::make_pair(stmt.first.get_name(), STMT_OP_TYPE::VECTOR);
      dma_dataflow.CreateStmtDataFlow(STMT_OP_TYPE::VECTOR, stmt.first, stmt.second, op_read_map, op_write_map);
    }

    if (stmt.second.isCubeAssign) {
      analysis_result_.stmt_type_[num] = std::make_pair(stmt.first.get_name(), STMT_OP_TYPE::VECTOR);
    }
  }
  dma_dataflow.FusionAnalysis();
  std::map<std::string, std::vector<std::string>> tensor_name_flows;
  std::map<std::string, MemFlow> tensor_mem_flows;
  dma_dataflow.OpDataflowInfo(tensor_name_flows, tensor_mem_flows);
  analysis_result_.SetTensorNameFlows(tensor_name_flows);
  analysis_result_.SetTensorMemFlows(tensor_mem_flows);
}

void ScopInfo::AddPartitionInfoToData(const std::vector<std::vector<int>> &partition_info) {
  for (unsigned int i = 0; i < partition_info.size(); i++) {
    for (unsigned int j = 1; j < partition_info[i].size(); j++) {
      cube_info_.RecordRangeAt(i, Range(Expr(partition_info[i][j - 1]), Expr(partition_info[i][j])));
    }
    if (partition_info[i].size() == 1) {
      cube_info_.RecordRangeAt(i, Range(Expr(0), Expr(0)));
    }
  }
}

void CubeInfo::ComputeByPassL1() {
  if (user_config_.GetByPassL1() == 0) {
    int value = ExtractIntFromAttrs(ATTR_CONV_BYPASS_L1);
    if (value >= 0) {
      user_config_.SetByPassL1(value);
    }
  }
  if (!IsFilterCanByPass()) {
    user_config_.SetByPassL1(0);
  }
}

void GatherVars(const Expr &expr, std::unordered_set<Var, air::NodeHash, air::NodeEqual> *vset) {
  PostOrderVisit(expr, [&vset](const NodeRef &node) {
    if (node.as<Variable>()) {
      vset->insert(Downcast<Var>(node));
    }
  });
}

void GatherVarNames(const Expr &expr, CondVarsMap &cond_vars, const isl::id &id) {
  std::unordered_set<Var, NodeHash, NodeEqual> vars_in_cond;
  GatherVars(expr, &vars_in_cond);
  for (const auto &var : vars_in_cond) {
    cond_vars[id].insert(var->name_hint);
  }
}

CondVarsMap AnalysisResult::GetCondVarsMap() {
  CondVarsMap cond_vars;
  for (const auto &pair : statements_) {
    const isl::id &id = pair.first;
    const Node *stmt = pair.second;
    CHECK(stmt);
    if (stmt->IsInstance<IfThenElse>()) {
      const auto op = static_cast<const IfThenElse *>(stmt);
      GatherVarNames(op->condition, cond_vars, id);
    } else if (stmt->IsInstance<Provide>()) {
      const auto op = static_cast<const Provide *>(stmt);
      PostOrderVisit(op->value, [&id, &cond_vars](const NodeRef &node) -> void {
        if (auto op = node.as<Select>()) {
          GatherVarNames(op->condition, cond_vars, id);
        }
      });
    }
  }
  return cond_vars;
}

const BufferDefInfo &AnalysisResult::GetBufferDefInfo(const isl::id &tensor_id) const {
  for (const auto &idx : BufferDefInfos()) {
    if (idx.dst_tensor_id.get_name() == tensor_id.get_name()) {
      return idx;
    }
  }
  LOG(FATAL) << "Hoist footprint of tensor " << tensor_id << " has no buffer definition";
  return default_buffer_def_info_;
}

int AnalysisResult::CountBufferDefInfo(const isl::id &tensor_id) const {
  int num = 0;
  for (const auto &tensorIter : BufferDefInfos()) {
    if (tensorIter.dst_tensor_id.get_name() == tensor_id.get_name()) {
      num++;
    }
  }
  return num;
}

bool AnalysisResult::HasBufferDefInfo(const isl::id &tensor_id) const {
  for (const auto &idx : BufferDefInfos()) {
    if (idx.dst_tensor_id.get_name() == tensor_id.get_name()) {
      return true;
    }
  }
  return false;
}

static std::string MemTypeToString(const MemType &memType) {
  switch (memType) {
    case MemType::UB_:
      return "UB";
    case MemType::L1_:
      return "L1";
    case MemType::UBL0_:
      return "UBL0";
    case MemType::UBL1_:
      return "UBL1";
    case MemType::L0A_:
      return "L0A";
    case MemType::L0B_:
      return "L0B";
    case MemType::L0C_:
      return "L0C";
    case MemType::DDR:
      return "GM";
    default:
      return "";
  }
}

std::string ScopInfo::GetIslReadName(const isl::id &cluster_id) {
  auto tensor_info = analysis_result_.GetBufferDefInfo(cluster_id);
  MemType memType = tensor_info.SrcMemType();
  return MemTypeToString(memType) + "read";
}

std::string ScopInfo::GetIslWriteName(const isl::id &cluster_id) {
  if (analysis_result_.HasBufferDefInfo(cluster_id)) {
    auto tensor_info = analysis_result_.GetBufferDefInfo(cluster_id);
    MemType memType = tensor_info.DstMemType();
    return MemTypeToString(memType) + "write";
  }
  return MemTypeToString(MemType::DDR) + "write";
}

std::string TensorMarkTag(MemType mem_type, MemFlow mem_flow) {
  /******************************
   * This interface converts a tensor's MemType to an isl schedule tree mark_tag,
   * which records the extension position of each tensor in the isl schedule tree.
   * Currently the REALIZE_L1/REALIZE_L0/REALIZE_UB mark_tags correspond directly
   * to the MemType; for MemType DDR, the mark_tag is the empty string "".
   ******************************/
  switch (mem_type) {
    case MemType::L1_:
      if (mem_flow.size() == 3 && mem_flow[0] == MemType::DDR && mem_flow[1] == MemType::L1_ &&
          mem_flow[2] == MemType::UBL1_)
        return REALIZE_L1UBL1;
      return REALIZE_L1;
    case MemType::UB_:
      // ordinary conv condition no fusion
      if (mem_flow.size() == 3 && mem_flow[0] == MemType::DDR && mem_flow[1] == mem_type &&
          mem_flow[2] == MemType::L0C_)
        return REALIZE_L0;
      return REALIZE_UB;
    case MemType::L0A_:
      return REALIZE_L0;
    case MemType::L0B_:
      return REALIZE_L0;
    case MemType::L0C_:
      return REALIZE_L0;
    case MemType::UBL0_:
      return REALIZE_UBL0;
    case MemType::UBL1_:
      if (mem_flow.size() == 2 && mem_flow[0] == MemType::DDR && mem_flow[1] == MemType::UBL1_) return REALIZE_L1;
      return REALIZE_UBL1;
    case MemType::DDR:
      return "";
    default:
      LOG(FATAL) << "undefined mem_type " << mem_type;
      return "";
  }
}
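
// e.g. TensorMarkTag(MemType::L1_, {DDR, L1_, UBL1_}) yields REALIZE_L1UBL1,
// while a plain vector flow TensorMarkTag(MemType::UB_, {DDR, UB_}) yields
// REALIZE_UB.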

}  // namespace poly
}  // namespace ir
}  // namespace akg