/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "poly/scop.h"

#include <fstream>

#include "poly/scop_builder.h"
#include "poly/poly_util.h"
#include "poly/cce_isl_emitter.h"
#include "poly/davinci_mgr_strategy.h"
#include "poly/schedule_pass_mgr.h"

namespace akg {
namespace ir {
namespace poly {
/**
 * Copies the user-supplied build attributes and flags into this Scop's ScopInfo.
 *
 * @param attrs          Build attributes; forwarded to both user_config_ and cube_info_.
 * @param extern_buffer  Tensor-to-buffer bindings; recorded as both the current and the origin binding.
 * @param is_spec_gemm   Forwarded to cube_info_.SetSpecGemm().
 * @param is_tuning      Forwarded to user_config_.SetIsTuning().
 * @param is_dynamic     Forwarded to user_config_.SetDynamic().
 */
void Scop::ParseUserConfig(const Map<std::string, NodeRef> &attrs, const Map<Tensor, Buffer> &extern_buffer,
                           bool is_spec_gemm, bool is_tuning, bool is_dynamic) {
  auto &user_cfg = info_.user_config_;
  user_cfg.SetAttrs(attrs);
  user_cfg.SetBind(extern_buffer);
  user_cfg.SetOriginBind(extern_buffer);
  user_cfg.SetIsTuning(is_tuning);
  user_cfg.SetDynamic(is_dynamic);

  auto &cube = info_.cube_info_;
  cube.SetAttrs(attrs);
  cube.SetSpecGemm(is_spec_gemm);
  // Conv attributes are only parsed in spec-gemm mode; the flag is read back through IsSpecGemm().
  if (!cube.IsSpecGemm()) {
    return;
  }
  cube.SetConvAttrInfo(attrs);
}

/**
 * Builds the parameter context set that bounds the symbolic parameters of the schedule.
 *
 * Every parameter collected in info.user_config_ is constrained to be strictly positive. For each
 * DynamicShapeNode entry whose tensor_name equals the parameter's name_hint, the parameter is also
 * constrained to be strictly below that entry's poly_upper_bound.
 *
 * @param info  Scop information providing the isl context, the parameters and the dynamic shapes.
 * @return the intersection of all those constraints over the parameter space.
 */
isl::set CreateParamsSet(ScopInfo &info) {
  auto params = info.user_config_.GetParams();
  auto space = CreateParamsSpace(info.GetCtx(), params);
  auto dynamic_shape = info.user_config_.GetDynamicShape();

  isl::set context = isl::set::universe(space);
  for (const auto &param : params) {
    const auto &var_name = param.second->name_hint;
    isl::aff var_aff(isl::aff::param_on_domain(space, isl::id(info.GetCtx(), var_name)));
    context = context & (var_aff > 0);

    // Tighten the upper bound with every matching dynamic-shape annotation (no-op when there are none).
    for (const auto &ds : dynamic_shape) {
      auto dsn = ds.as<air::DynamicShapeNode>();
      if (dsn != nullptr && dsn->tensor_name == var_name) {
        context = context & (var_aff < dsn->poly_upper_bound);
      }
    }
  }
  return context;
}

/**
 * Converts body_ into the initial isl schedule tree.
 *
 * Steps:
 *  1. Peel the outer LetStmts off body_ and store them in the user config.
 *  2. Collect the symbolic parameters. If there are any, rewrite body_ and every non-constant
 *     extent of the bound tensors and buffers with a ConsolidateExprMutator, then store the new
 *     bindings.
 *  3. Build the parameter space and context, then the schedule tree for body_.
 *  4. Create the data-flow info and update the cube info before returning the schedule.
 */
isl::schedule Scop::GenIsl() {
  auto outer_let_stmts = info_.user_config_.GetOuterLetStmts();
  body_ = PeelOuterLetStmt(body_, outer_let_stmts);
  info_.user_config_.SetOuterLetStmts(outer_let_stmts);

  info_.user_config_.CollectParams();
  auto params = info_.user_config_.GetParams();
  if (!params.empty()) {
    auto mutator = ConsolidateExprMutator(params);
    body_ = mutator.Mutate(body_);

    // Rewrites every non-constant extent of `shape` with the same mutator applied to body_.
    auto consolidate_shape = [&mutator](Array<Expr> shape) {
      for (size_t i = 0; i < shape.size(); ++i) {
        if (!is_const(shape[i])) {
          shape.Set(i, mutator.Mutate(shape[i]));
        }
      }
      return shape;
    };

    Binds new_binds;
    auto binds = info_.user_config_.GetBind();
    for (const auto &kv : binds) {
      const auto &old_tensor = kv.first;
      const auto &old_buffer = kv.second;
      // The tensor is rebuilt (and its shape mutated) before the buffer's shape is touched.
      Tensor tensor = TensorNode::make(consolidate_shape(old_tensor->shape), old_tensor->dtype, old_tensor->op,
                                       old_tensor->value_index);
      Buffer buffer = BufferNode::make(old_buffer->data, old_buffer->dtype, consolidate_shape(old_buffer->shape),
                                       old_buffer->strides, old_buffer->elem_offset, old_buffer->name,
                                       old_buffer->scope, old_buffer->data_alignment, old_buffer->offset_factor,
                                       old_buffer->buffer_type);
      new_binds.Set(tensor, buffer);
    }
    info_.user_config_.SetBind(new_binds);
  }

  isl::space param_space = CreateParamsSpace(ctx_, params);
  isl::set param_set = CreateParamsSet(info_);

  info_.user_config_.SetBody(body_);
  // MakeScheduleTree receives its own handle to the (possibly rewritten) body.
  Stmt root_stmt = body_;
  isl::schedule initial_schedule = MakeScheduleTree(param_space, param_set, root_stmt, info_);

  info_.CreateDataFlowInfo();
  info_.cube_info_.UpdateComputeAttrInfo();
  info_.cube_info_.ComputeByPassL1();
  return initial_schedule;
}

/**
 * Runs the schedule-transformation passes on input_schedule.
 *
 * The passes first run with coincidence considered. If the pass manager then reports need_restart_
 * (a scalar statement could not be tiled), they are re-run from input_schedule with coincidence
 * disabled. A non-null result is stored in analysis_result_ and returned.
 */
isl::schedule Scop::Transform(const isl::schedule &input_schedule) {
  info_.user_config_.SetConsiderCoincidence(true);
  DavinciMgrStrategy davinci_strategy(info_);
  SchedulePassMgr pass_mgr(info_);
  auto result = pass_mgr.Run(input_schedule, davinci_strategy);
  // NOTE(review): "transfrom" is misspelled in this dump file name; kept as-is in case tooling
  // looks for this exact name.
  info_.DumpTransform("davinci_transfrom.log", davinci_strategy.pass_info_);

  if (pass_mgr.need_restart_) {
    // Restart: re-compute / re-tile without coincidence constraints.
    info_.user_config_.SetConsiderCoincidence(false);
    DavinciMgrStrategy scalar_strategy(info_);
    result = pass_mgr.Run(input_schedule, scalar_strategy);
    info_.DumpTransform("scalar_transform.log", scalar_strategy.pass_info_);
  }

  if (result.get()) {
    info_.analysis_result_.SetTranstormedSchedule(result);
  }
  return result;
}

/**
 * Creates the list of AST iterator names "<prefix>0", "<prefix>1", ... for a schedule.
 *
 * The list length is the deepest band level in the schedule tree, i.e. the maximum over all band
 * nodes of (schedule depth of the band + number of band members).
 *
 * @param schedule_iter  Schedule whose tree is scanned; the callback returns every node unchanged.
 * @param prefix         Name prefix for every iterator id.
 * @return one isl::id per level; empty if the tree contains no band node.
 */
isl::id_list CreateIteratorList(const isl::schedule &schedule_iter, const std::string &prefix) {
  int max_depth = 0;
  // The traversal is only used to observe band nodes; it does not modify the tree.
  auto track_band_depth = [&max_depth](const isl::schedule_node &node) -> isl::schedule_node {
    if (auto band = node.as<isl::schedule_node_band>()) {
      const int band_end = static_cast<int>(node.schedule_depth()) + static_cast<int>(band.n_member());
      if (band_end > max_depth) {
        max_depth = band_end;
      }
    }
    return node;
  };
  auto root = schedule_iter.root().map_descendant_bottom_up(track_band_depth);

  isl::id_list iterators(root.ctx(), max_depth);
  for (int level = 0; level < max_depth; ++level) {
    iterators = iterators.add(isl::id(root.ctx(), prefix + std::to_string(level)));
  }
  return iterators;
}

/**
 * Per-thread counter used by GenHalide to build unique AST node annotation ids.
 * Returns a mutable reference so callers can read and post-increment it in place.
 */
size_t &AstNodeNum() {
  static thread_local size_t counter{0};
  return counter;
}
// Prefix of the annotation ids attached to AST nodes in GenHalide.
constexpr const char *AST_NODE_ID_PREFIX = "__node_";
/**
 * Generates a Halide statement from an isl schedule.
 *
 * An isl AST is built from `sch`. Every node reported through the at-each-domain callback is
 * annotated with a fresh id, and its iterator map and build are recorded in a NodeInfoRepo. The
 * chosen emitter then uses that repo to turn the AST into a Stmt.
 *
 * @param info                    Scop information shared with the emitter.
 * @param sch                     Schedule to generate code for.
 * @param used_for_tile_out_band  When false, enables coscheduled-group AST generation on sch's
 *                                isl context and creates the conv model for conv ops. It also
 *                                selects the emitter when PRINT_ISL_EMMITER is set.
 * @return the emitted statement.
 */
Stmt GenHalide(ScopInfo &info, const isl::schedule &sch, bool used_for_tile_out_band) {
  if (!used_for_tile_out_band) {
    // we should check the return value to be isl_stat_ok, but it returns isl_stat_error, so we skip this check.
    static_cast<void>(isl_options_set_ast_build_group_coscheduled(sch.ctx().get(), isl_bool_true));
    if (info.cube_info_.IsConv()) {
      info.cube_info_.CreateConvModel();
    }
  }

  // Annotate each domain node with a unique id and remember how its iterators map to the schedule.
  NodeInfoRepo node_info_repo;
  auto record_node = [&node_info_repo](const isl::ast_node &node, const isl::ast_build &build) -> isl::ast_node {
    auto schedule_map = isl::map::from(build.get_schedule());
    auto node_id = isl::id(node.ctx(), std::string(AST_NODE_ID_PREFIX) + std::to_string(AstNodeNum()++));
    CHECK_EQ(0u, node_info_repo.count(node_id)) << "node already exists: " << node_id;

    auto &node_info = node_info_repo[node_id];
    node_info.iterator_map = isl::pw_multi_aff(schedule_map.reverse());
    node_info.build = build;
    return node.set_annotation(node_id);
  };

  // Configure the AST builder: domain callback first, then the iterator names.
  auto builder = isl::ast_build(sch.ctx());
  builder = builder.set_at_each_domain(record_node);
  auto iter_prefix = info.user_config_.GetIterPrefix(info.cube_info_.IsSpecGemm());
  isl::id_list iters = CreateIteratorList(sch, iter_prefix);
  builder = builder.set_iterators(iters);

  // `timer_start` is the variable used by the TIMER_START / TIMER_SHOW macros.
  std::chrono::high_resolution_clock::time_point timer_start;
  TIMER_START;
  auto ast_node = builder.node_from(sch);
  TIMER_SHOW("NodeFrom", std::string(info.cube_info_.IsSpecGemm() ? "_specgemm" : ""));

  ast_node = CanonicalizeBlockInAst(ast_node);

  TIMER_START;
  // With PRINT_ISL_EMMITER set, the emitter name is printed, and the generic IslEmitter is used unless
  // used_for_tile_out_band is true. Without it, CCEIslEmitter is always used.
  const bool debug_emitter = PRINT_ISL_EMMITER;
  Stmt stmt;
  if (debug_emitter && !used_for_tile_out_band) {
    PrintHeader("IslEmitter");
    stmt = IslEmitter(info, node_info_repo, iters).Emit(ast_node);
  } else {
    if (debug_emitter) {
      PrintHeader("CCEIslEmitter");
    }
    stmt = CCEIslEmitter(info, node_info_repo, iters).Emit(ast_node);
  }
  TIMER_SHOW("CCEIslEmitter", std::string(info.cube_info_.IsSpecGemm() ? "_specgemm" : ""));

  if (PRINT_EMMITER) {
    PrintHeader("FINAL SCHEDULE");
    std::cout << PrettyPrintSchTree(sch) << std::endl;
    PrintHeader("FINAL ASTNODE");
    std::cout << FormatMupaStr(ast_node.to_str(), false) << std::endl << std::endl;
    PrintHeader("FINAL ASTNODE TO C");
    std::cout << ast_node.to_C_str() << std::endl;
    PrintHeader("FINAL STMT");
    std::cout << stmt;
  }
  return stmt;
}

/**
 * Generates the Halide statement for `sch` using this Scop's info.
 * Forwards to the free function poly::GenHalide with used_for_tile_out_band = false.
 */
Stmt Scop::GenHalide(const isl::schedule &sch) {
  constexpr bool kUsedForTileOutBand = false;
  return poly::GenHalide(info_, sch, kUsedForTileOutBand);
}

}  // namespace poly
}  // namespace ir
}  // namespace akg