// poly.cc — committed by ckey_Dou
/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "poly/scop.h"
namespace akg {
namespace ir {
/*!
 * \brief Poly entry
 */
class Poly {
 public:
  Poly() : isl_ctx_(isl::ctx(isl_ctx_alloc())) {}

27 28 29 30 31 32
  ~Poly() noexcept {
    scop_.reset();
    // scop must be deconstructed before isl_ctx is deconstructed
    isl_ctx_free(isl_ctx_.get());
  }

C
ckey_Dou 已提交
33 34 35
  void Run(const Stmt &stmt, const Map<Tensor, Buffer> &extern_buffer, const Map<std::string, NodeRef> &attrs,
           const bool is_spec_gemm, bool is_tuning, bool is_dynamic) {
    stmt_ = stmt;
36
    scop_.reset(new poly::Scop(Simplify_cce(stmt_), isl_ctx_));
C
ckey_Dou 已提交
37
    CHECK(scop_ != nullptr);
38
    scop_->ParseUserConfig(attrs, extern_buffer, is_spec_gemm, is_tuning, is_dynamic);
C
ckey_Dou 已提交
39 40

    std::chrono::high_resolution_clock::time_point timer_start;
41
    // generate isl schedule from Halide
C
ckey_Dou 已提交
42 43 44 45
    TIMER_START;
    isl::schedule sch = scop_->GenIsl();
    TIMER_SHOW("GenIsl", std::string(is_spec_gemm ? "_specgemm" : ""));

46 47 48 49
    // isl schedule transform
    TIMER_START;
    isl::schedule sched = scop_->Transform(sch);
    TIMER_SHOW("Transform", std::string(is_spec_gemm ? "_specgemm" : ""));
C
ckey_Dou 已提交
50 51

    // generate Halide from isl schedule
52 53 54 55 56 57 58 59 60 61
    TIMER_START;
    stmt_ = scop_->GenHalide(sched);
    TIMER_SHOW("GenHalide", std::string(is_spec_gemm ? "_specgemm" : ""));

    if (is_dynamic) stmt_ = RestoreCombinedParams(stmt_, scop_->info_);

    if (is_tuning) {
      spaces_ = GenerateTilingSpace(sched, scop_->info_, stmt_, scop_->info_.user_config_.GetDumpTuningLevel());
      return;
    }
C
ckey_Dou 已提交
62 63

    // optimize post poly Halide IR for Davinci
64 65
    if (scop_->info_.user_config_.GetEnableFeatureLib() || scop_->info_.user_config_.GetOptimizeForDavinci()) {
      stmt_ = poly::DavinciHalideOptimizer(stmt_, !scop_->info_.user_config_.GetParams().empty());
C
ckey_Dou 已提交
66
    }
67
    gen_empty_tiling = scop_->info_.analysis_result_.GetIsTiled();
C
ckey_Dou 已提交
68 69
  }

70
  Stmt GetStmt() { return stmt_; }
C
ckey_Dou 已提交
71

72 73 74
  NodeRef GetSpaces() { return spaces_; }

  Array<Var> GetTilingParams() {
C
ckey_Dou 已提交
75 76 77 78
    CHECK(scop_ != nullptr);
    Array<Var> tiling_params_array;
    if (gen_empty_tiling) return tiling_params_array;
    std::unordered_set<Var, NodeHash, NodeEqual> tiling_params;
79 80
    auto param_tiling_map = scop_->info_.user_config_.GetParamTilingMap();
    for (const auto &kv : param_tiling_map) {
C
ckey_Dou 已提交
81 82 83 84 85 86
      GatherVars(kv.second, &tiling_params);
    }
    for (const auto &param : tiling_params) tiling_params_array.push_back(param);
    return tiling_params_array;
  }

87 88 89 90 91 92
  void GatherVars(const Expr expr, std::unordered_set<Var, air::NodeHash, air::NodeEqual> *vset) {
    PostOrderVisit(expr, [&vset](const NodeRef &node) {
      if (node.as<Variable>()) {
        vset->insert(Downcast<Var>(node));
      }
    });
C
ckey_Dou 已提交
93 94 95 96 97 98 99 100
  }

 private:
  std::unique_ptr<poly::Scop> scop_{nullptr};
  // define isl_ctx outside scop because there are a lot of isl objects in the members of scop class,
  // and we need to ensure that they are deconstructed before the isl_ctx is freed.
  isl::ctx isl_ctx_;
  Stmt stmt_;
101 102
  NodeRef spaces_;
  bool gen_empty_tiling{false};
C
ckey_Dou 已提交
103 104 105 106 107 108 109
};

/*!
 * \brief Interface for the lower pass: run the polyhedral pipeline (non-tuning mode) on \p stmt.
 * \return Two-element array: the transformed statement, then the tiling parameters.
 */
Array<NodeRef> AutoPoly(const Stmt &stmt, const Map<Tensor, Buffer> &extern_buffer,
                        const Map<std::string, NodeRef> &attrs, const bool is_specgemm, const bool is_dynamic) {
  constexpr bool kTuningMode = false;
  Poly runner;
  runner.Run(stmt, extern_buffer, attrs, is_specgemm, kTuningMode, is_dynamic);
  Stmt lowered = runner.GetStmt();
  Array<Var> tiling_vars = runner.GetTilingParams();
  return Array<NodeRef>({lowered, tiling_vars});
}

/*!
 * \brief Run the polyhedral pipeline in tuning mode (static shapes) and return the tiling space.
 */
NodeRef GenTuningSpace(const Stmt &stmt, const Map<Tensor, Buffer> &extern_buffer,
                       const Map<std::string, NodeRef> &attrs, const bool is_specgemm) {
  constexpr bool kTuningMode = true;
  constexpr bool kDynamicShape = false;
  Poly runner;
  runner.Run(stmt, extern_buffer, attrs, is_specgemm, kTuningMode, kDynamicShape);
  NodeRef tiling_space = runner.GetSpaces();
  return tiling_space;
}
}  // namespace ir
}  // namespace akg