Commit 792d7888 authored by M mindspore-ci-bot, committed by Gitee

!81 Refactor the Autopoly pass to support multi-backend targets

Merge pull request !81 from anyrenwei/master
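Context for the diff below: the refactor routes backend-specific pass registration through a strategy interface (PassMgrStrategy, with DavinciMgrStrategy added in this PR for the Davinci backend). A minimal, self-contained sketch of that pattern follows; apart from the class and method names visible in the diff, every detail here is a simplifying assumption rather than the real implementation.

#include <iostream>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Simplified stand-in for the passes under src/poly/schedule_pass/.
class SchedulePass {
 public:
  explicit SchedulePass(std::string name) : name_(std::move(name)) {}
  virtual ~SchedulePass() = default;
  std::string name_;
};

// Stand-in for poly/pass_mgr_strategy.h: each backend overrides
// RegisterPasses() to build its own ordered pipeline.
class PassMgrStrategy {
 public:
  virtual ~PassMgrStrategy() = default;
  virtual void RegisterPasses() = 0;
  void RegisterPass(std::shared_ptr<SchedulePass> pass) { passes_.push_back(std::move(pass)); }
  std::vector<std::shared_ptr<SchedulePass>> passes_;
};

// Davinci-like backend: mirrors the shape of DavinciMgrStrategy::RegisterPasses()
// below; the real version registers many more passes, gated by user config.
class DavinciLikeStrategy : public PassMgrStrategy {
 public:
  void RegisterPasses() override {
    RegisterPass(std::make_shared<SchedulePass>("TileOuterBand"));
    RegisterPass(std::make_shared<SchedulePass>("MemoryManager"));
  }
};

int main() {
  DavinciLikeStrategy strategy;
  strategy.RegisterPasses();
  for (const auto &pass : strategy.passes_) std::cout << pass->name_ << std::endl;
  return 0;
}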
......@@ -28,6 +28,9 @@ set(AKG_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}")
include(cmake/RT.cmake)
include(cmake/utils.cmake)
include(cmake/external_libs/isl.cmake)
set(ISL_DIR "${CMAKE_BINARY_DIR}/isl")
if(ENABLE_AKG)
message("-- Build akg in Mindspore")
execute_process(COMMAND bash ${AKG_SOURCE_DIR}/third_party/apply_patches.sh ${CMAKE_CURRENT_BINARY_DIR} "1")
......@@ -43,8 +46,6 @@ else()
set(UNITTEST_DIR "${AKG_SOURCE_DIR}/tests/unittest_cpp")
endif()
set(ISL_DIR "${CMAKE_BINARY_DIR}/isl")
file(COPY ${AKG_SOURCE_DIR}/python/akg DESTINATION
${CMAKE_CURRENT_BINARY_DIR})
......@@ -175,6 +176,8 @@ file(
${TVM_DIR}/src/runtime/vm/profiler/*.cc
${TVM_DIR}/src/codegen/stackvm/*.cc
${AKG_SOURCE_DIR}/src/poly/*.cc
${AKG_SOURCE_DIR}/src/poly/schedule_pass/*.cc
${AKG_SOURCE_DIR}/src/poly/tiling/*.cc
${AKG_SOURCE_DIR}/src/api/*.cc
${AKG_SOURCE_DIR}/src/pass/*.cc
${AKG_SOURCE_DIR}/src/rpc/*.cc
......
......@@ -29,7 +29,7 @@
#include "ir_pass.h"
#include "pass/utils.h"
#include "pass/expr_alg_simplify.h"
#include "poly/tiling_algorithm.h"
#include "poly/tiling/tiling_algorithm.h"
namespace akg {
namespace ir {
......
This diff is collapsed.
......@@ -16,13 +16,7 @@
#ifndef POLY_CCE_ISL_EMITTER_H_
#define POLY_CCE_ISL_EMITTER_H_
#include <tvm/expr.h>
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include "ir_pass.h"
#include "isl.h"
#include "scop.h"
#include "isl_emitter.h"
namespace akg {
......@@ -39,12 +33,15 @@ class Liveness {
std::vector<IslIdSet> read_;
std::vector<IslIdSet> write_;
};
enum AtomicType { Equ = 0, Add };
/*!
* IslEmitter for CCE
*/
class CCEIslEmitter : public IslEmitter {
public:
CCEIslEmitter(Scop &s, const NodeInfoRepo &n, const isl::id_list &i) : IslEmitter(s, n, i) { ProcBypassL1(s); }
CCEIslEmitter(ScopInfo &info, const NodeInfoRepo &n, const isl::id_list &i) : IslEmitter(info, n, i) {
ProcBypassL1(info);
}
~CCEIslEmitter() override = default;
Stmt Emit(const isl::ast_node &node) final;
......@@ -52,7 +49,6 @@ class CCEIslEmitter : public IslEmitter {
private:
// override emitters for CCE
Stmt EmitFor(const isl::ast_node_for &node) final;
Stmt EmitIf(const isl::ast_node_if &node) final;
Stmt EmitMark(const isl::ast_node_mark &node_id) override;
Stmt EmitBlock(const isl::ast_node_block &node) final;
Stmt EmitStmt(const isl::ast_node_user &node) final;
......@@ -60,60 +56,68 @@ class CCEIslEmitter : public IslEmitter {
// DMA emitters for CCE
Expr EmitLoad(const isl::ast_expr &lhs, Type type);
Stmt EmitL1Read(const isl::ast_node_user &node);
Stmt EmitL1Write(const isl::ast_node_user &node, Scop::AtomicType atomic);
Stmt EmitRead(const isl::ast_node_user &node);
Stmt EmitWrite(const isl::ast_node_user &node, AtomicType atomic);
Stmt EmitSpecGemL1write(const isl::ast_node_mark &node, const Stmt &stmt);
// RangeInfo emitters for CCE
Stmt EmitGemmRangeInfoBackPropFilter(const Stmt &stmt);
Stmt EmitGemmRangeInfo(Stmt stmt);
// multicore emitters for CCE
// emit mark node
Stmt EmitMarkMulticore(const isl::ast_node_mark &node);
bool InjectMulticore(const std::string &iter);
Stmt EmitMarkFuseVector(const isl::ast_node_mark &node);
Stmt EmitMarkAllocRealizeOut(const isl::ast_node_mark &node);
Stmt EmitMarkAllocC(const isl::ast_node_mark &node);
Stmt EmitMarkSpecGemm(const isl::ast_node_mark &node);
// emit attrs
void EmitAttrStmt(const isl::ast_node_block &block_node, const Liveness &liveness, bool is_L1, bool is_L0,
std::vector<Stmt> &stmts);
void EmitAttrStmtL0(Tensor &t, bool &is_im2col, bool &is_filter_l0, bool &is_gemm_data_trans,
bool &is_gemm_weight_trans);
void EmitAttrStmtL1(Tensor &t, bool &is_fractal, bool &is_filter_l1);
void EmitAttrStmtLiveness(const Liveness &liveness, std::vector<Stmt> &stmts, int i, bool is_L1);
void EmitReadAttrAtL0(std::vector<Stmt> &stmts, int i, Tensor &t);
void EmitReadAttrAtL1(std::vector<Stmt> &stmts, int i, Tensor &t);
void EmitReadAttr(const std::vector<IslIdSet> &read, std::vector<Stmt> &stmts, int i, bool is_L1, bool is_L0);
void EmitWriteAttr(const std::vector<IslIdSet> &write, std::vector<Stmt> &stmts, int i, bool is_L1);
void EmitAttrStmtAfterRealize(bool is_L1, bool is_L0, std::vector<Stmt> &stmts);
Stmt EmitGemmRangeInfoBackPropFilter(const Stmt &stmt);
Stmt EmitGemmRangeInfo(Stmt stmt);
// emit realize
void EmitRealize(const isl::ast_node_block &block_node, const Liveness &liveness_info, bool is_L1, bool is_L0,
std::vector<Stmt> &stmts);
void EmitRealizeLivenessInfo(std::vector<IslIdSet> &real, const Liveness &liveness_info,
std::unordered_map<isl::id, std::set<int>, isl::IslIdIslHash> &liveness,
std::function<bool(const std::string &id)> const &CheckGoOut);
void EmitGemmRangeInfoNewAxis(std::vector<Range> &range, std::vector<std::string> &prefix,
std::unordered_map<std::string, bool> &outerAxis, Range &axisMRange,
Map<std::string, Range> &range_map, Map<std::string, VarExpr> &axis_map);
// emit access
Stmt EmitAccessNodeCall(const Node *node, const VarMap &var_map_tmp, BufferedFootPrintInfo &buffer_fp_info) override;
void EmitGemmRangeInfoDynamic(Range &axisMRange, Map<std::string, Range> &range_map);
void EmitGemmRangeInfoStatic(Map<std::string, Range> &range_map);
// realize info for CCE
// tool func
bool InjectMulticore(const std::string &iter);
void CollectLiveness(const Liveness &liveness_info, bool is_L1, std::vector<IslIdSet> &real,
std::unordered_map<isl::id, std::set<int>, isl::IslIdIslHash> &liveness,
std::function<bool(const std::string &id)> const &CheckGoOut);
void CollectGemmRangeInfoNewAxis(std::vector<Range> &range, std::vector<std::string> &prefix,
std::unordered_map<std::string, bool> &outerAxis, Range &axisMRange,
Map<std::string, Range> &range_map, Map<std::string, VarExpr> &axis_map);
void CollectGemmMWSize(Range &axis_m_range, Map<std::string, Range> &range_map);
void CollectGemmMWSizeDynamic(Map<std::string, Range> &range_map);
Expr FindRealizeScope(const isl::id &var);
std::string FindRealizeScopeToString(const isl::id &var);
Stmt InsertRealize(Stmt stmt, const isl::id &var, bool is_L0);
void RealizeOut();
Stmt RemoveCond(const Stmt &stmt);
void ProcBypassL1(const Scop &scop);
void ProcBypassL1(const ScopInfo &info);
void SetCube(const isl::id &stmt_id);
std::string ReplaceAxis(const std::string &oldAxis);
static std::vector<std::string> ConstructPrefix();
void GemmTranspose(std::vector<Stmt> &stmts);
void ConvBackPropFilterFixMadInit(const isl::ast_node_mark &node, Expr &mad_init_cond);
isl::multi_aff TensorAccessMultAff(isl::id &tensor_id, const Array<Expr> &subscripts,
const isl::id &stmt_id) override;
bool IsTransferStmt() override;
bool IsCopyinFromAnotherBand(isl::multi_aff &access) override;
std::map<const Variable *, std::string> iters_old_name_;
std::map<const Variable *, std::string> iters_new_name_;
std::unordered_map<isl::id, VarMap, isl::IslIdIslHash> stmt_var_map_;
std::set<Tensor> realized_;
IslIdSet hoisted_read_;
IslIdSet hoisted_write_;
......
This diff is collapsed.
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef POLY_CONSTRUCT_POLY_ACCESSES_H_
#define POLY_CONSTRUCT_POLY_ACCESSES_H_
#include <tvm/ir_visitor.h>
#include <tvm/operation.h>
#include <tuple>
#include "poly/scop_info.h"
namespace akg {
namespace ir {
namespace poly {
std::pair<isl::map, isl::map> ConstructPolyAccess(const OperatorDomainSpace &domain, const Node *op,
const std::string &tensor, const Array<Expr> &dimensions,
AccessMap &accesses);
std::tuple<isl::union_map, isl::union_map, isl::union_map> ConstructPolyAccesses(const OperatorDomainSpace &domain,
const Stmt &s, AccessMap &accesses);
} // namespace poly
} // namespace ir
} // namespace akg
#endif // POLY_CONSTRUCT_POLY_ACCESSES_H_
\ No newline at end of file
......@@ -607,7 +607,7 @@ class DynamicPaddingFix : public IRMutator {
std::string fm_l1_{""};
};
Stmt OptimizeCce(const Stmt &s, bool dynamicShape = false) {
Stmt DavinciHalideOptimizer(const Stmt &s, bool dynamicShape = false) {
Stmt stmt = s;
if (dynamicShape) {
stmt = InductionVarElinate().Run(s);
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "poly/davinci_mgr_strategy.h"
#include "poly/schedule_pass/group.h"
#include "poly/schedule_pass/tile_outer_band.h"
#include "poly/schedule_pass/memory_manager.h"
#include "poly/schedule_pass/sink_c0.h"
#include "poly/schedule_pass/sink_last_axis.h"
#include "poly/schedule_pass/reorder_invariant_set_schedule.h"
#include "poly/schedule_pass/keep_outer_band_order.h"
#include "poly/schedule_pass/split_outer_band.h"
#include "poly/schedule_pass/transfer_stmt.h"
#include "poly/schedule_pass/reset_coincidence_of_reduce.h"
#include "poly/schedule_pass/set_all_coincidence.h"
#include "poly/schedule_pass/reschedule.h"
#include "poly/schedule_pass/reorder_inner_band.h"
#include "poly/schedule_pass/change_marknode_position.h"
#include "poly/schedule_pass/insert_node_for_allocc.h"
#include "poly/schedule_pass/label_realize_out_position.h"
#include "poly/schedule_pass/mark_fuse_op.h"
#include "poly/schedule_pass/reorder_mark_nodes.h"
#include "poly/schedule_pass/compute_transfer_copyin.h"
#include "poly/schedule_pass/compute_inner_band_dependency.h"
#include "poly/schedule_pass/mark_outer_most.h"
namespace akg {
namespace ir {
namespace poly {
void DavinciMgrStrategy::RegisterTilingPasses() {
RegisterPass(std::make_shared<TileOuterBand>(pass_info_, scop_info_));
}
void DavinciMgrStrategy::RegisterMemPromPasses() { RegisterPass(std::make_shared<MemoryManager>(scop_info_)); }
void DavinciMgrStrategy::RegisterPasses() {
passes_.clear();
RegisterNormalizationPasses();
if (!scop_info_.user_config_.GetDisableGroup()) {
RegisterPass(std::make_shared<GroupStatements>(pass_info_));
}
RegisterSchedulingPasses();
RegisterPass(std::make_shared<ReorderInvariantSetSchedule>(pass_info_));
if (scop_info_.user_config_.GetReorderSchedule()) {
RegisterPass(std::make_shared<SinkC0>());
}
if (scop_info_.user_config_.GetSinkLastAxis()) {
RegisterPass(std::make_shared<SinkLastAxis>(pass_info_));
}
if (scop_info_.user_config_.GetKeepOuterBandOrder()) {
RegisterPass(std::make_shared<KeepOuterBandOrder>(scop_info_));
}
RegisterPass(std::make_shared<UnGroupStatements>(pass_info_));
if (scop_info_.user_config_.GetOuterBandNeedSplit() && !scop_info_.cube_info_.IsSpecGemm()) {
RegisterPass(std::make_shared<SplitOuterBand>());
}
RegisterPass(std::make_shared<ComputeInnerBandDependency>(scop_info_));
if (!scop_info_.cube_info_.IsSpecGemm() && (scop_info_.cube_info_.IsConv() || scop_info_.cube_info_.IsGemm())) {
RegisterPass(std::make_shared<ComputeTransferCopyin>(scop_info_, pass_info_));
}
if (scop_info_.user_config_.GetIsTuning()) {
return;
}
RegisterTilingPasses();
RegisterPass(std::make_shared<ReorderInvariantSetSchedule>(pass_info_));
RegisterPass(std::make_shared<ResetCoincidenceOfReduce>(scop_info_, pass_info_));
if (scop_info_.user_config_.GetPragmaSetAllCoincident()) {
RegisterPass(std::make_shared<SetAllCoincidence>());
}
if (!scop_info_.user_config_.GetIsDynamic() || !scop_info_.cube_info_.IsConv()) {
RegisterPass(std::make_shared<Reschedule>(scop_info_, pass_info_));
}
RegisterPass(std::make_shared<ReorderInnerBand>(scop_info_.analysis_result_.GetCondVarsMap()));
RegisterPass(std::make_shared<ChangeMarkNodePosition>(scop_info_.analysis_result_.ExtractWithStmtId()));
RegisterPass(std::make_shared<LabelRealizeOutPosition>());
if (scop_info_.cube_info_.IsSpecGemm() || scop_info_.cube_info_.IsGemm() ||
scop_info_.cube_info_.IsConvBackpropFilter()) {
RegisterPass(std::make_shared<InsertNodeForAllocC>());
}
RegisterMemPromPasses();
if (!scop_info_.cube_info_.IsSpecGemm()) {
RegisterPass(std::make_shared<TransferStmt>(scop_info_, pass_info_));
}
RegisterPass(std::make_shared<ReorderMarkNodes>());
RegisterPass(std::make_shared<MarkFuseOp>());
// If coincidence constraints are disabled (due to reschedule), we cannot determine the multicore axis reliably
bool can_use_multiCore = !scop_info_.cube_info_.IsSpecGemm() && scop_info_.user_config_.GetConsiderCoincidence();
if (can_use_multiCore || scop_info_.user_config_.GetEnableMarkMultiCore()) {
RegisterPass(std::make_shared<MarkOuterMost>(scop_info_));
}
}
} // namespace poly
} // namespace ir
} // namespace akg
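RegisterPasses() above builds an ordered, schedule-to-schedule pipeline; how that pipeline is executed lies outside this diff. A plausible driver loop is sketched below, where the SchedulePass interface and its Run(schedule) signature are assumptions, not the actual akg API.

#include <isl/cpp.h>

#include <memory>
#include <vector>

// Assumed interface: the concrete passes in poly/schedule_pass/ are
// constructed with ScopInfo/PassInfo and transform an isl::schedule.
class SchedulePass {
 public:
  virtual ~SchedulePass() = default;
  virtual isl::schedule Run(isl::schedule sch) = 0;
};

// Hypothetical driver: apply each registered pass in order.
isl::schedule RunPipeline(const std::vector<std::shared_ptr<SchedulePass>> &passes, isl::schedule sch) {
  for (const auto &pass : passes) {
    sch = pass->Run(sch);
  }
  return sch;
}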
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef POLY_DAVINCI_MGR_STRATEGY_H_
#define POLY_DAVINCI_MGR_STRATEGY_H_
#include "poly/pass_mgr_strategy.h"
namespace akg {
namespace ir {
namespace poly {
class DavinciMgrStrategy : public PassMgrStrategy {
public:
explicit DavinciMgrStrategy(ScopInfo &scop_info) : PassMgrStrategy(scop_info) {
pass_info_.coincident_ = scop_info_.user_config_.GetConsiderCoincidence();
}
~DavinciMgrStrategy() override = default;
void RegisterTilingPasses() override;
void RegisterMemPromPasses() override;
void RegisterPasses() override;
};
} // namespace poly
} // namespace ir
} // namespace akg
#endif // POLY_DAVINCI_MGR_STRATEGY_H_
......@@ -13,9 +13,9 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "poly/dma_dataflow.h"
#include "poly/scop.h"
#include "poly/dma_dataflow.h"
#include "poly/poly_util.h"
namespace akg {
namespace ir {
......@@ -193,7 +193,6 @@ void StmtDataFlowInfo::AddWriteTensor(const std::string &name, TENSOR_DATAFLOW_T
void StmtDataFlowInfo::CreateTensorDataFlow(TENSOR_DATAFLOW_TYPE type, const std::string &name,
TensorDataFlow &dataflow) {
CHECK_NE(name, "");
dataflow.tensor_name_ = name;
switch (type) {
case TENSOR_DATAFLOW_TYPE::CUBE_CONV_A:
CubeConvA(name, dataflow);
......
......@@ -33,7 +33,7 @@ namespace akg {
namespace ir {
namespace poly {
class TensorFootprintCluster;
class TensorDataFlow;
struct TensorDataFlow;
class StmtDataFlowInfo;
enum MemType { DDR = 1, L1_, UB_, L0A_, L0B_, L0C_, UBL0_, UBL1_ };
......@@ -142,7 +142,6 @@ enum TENSOR_DATAFLOW_TYPE {
};
struct TensorDataFlow {
std::string tensor_name_;
std::vector<std::string> name_flow_;
MemFlow mem_type_flow_;
......
This diff is collapsed.
......@@ -13,20 +13,14 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef POLY_DMA_INJECT_H_
#define POLY_DMA_INJECT_H_
#pragma once
#include <isl/constraint.h>
#include <iostream>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include <string>
#include <memory>
#include "poly/isl.h"
#include "poly/scop.h"
#include "poly/scop_info.h"
namespace akg {
namespace ir {
......@@ -177,30 +171,36 @@ std::vector<int> ExpandInvalidDims(const std::vector<int> &invalid_dims, const i
int &first_invalid_domain_dim);
isl::multi_aff ComputeBufferFootprint(const isl::map &access, const ScopedFootprint &foot_print);
isl::schedule_node PlaceDataCopyBelowImpl(Scop &scop, isl::schedule_node tree, const TensorFootprintCluster &cluster,
const isl::map &buffer_footprint, const isl::id &tensor_id,
const isl::set &original_elements, const isl::map &exact_reads,
const isl::map &exact_writes);
isl::schedule_node PlaceDataCopyBelowImpl(ScopInfo &scop_info, isl::schedule_node tree,
const TensorFootprintCluster &cluster, const isl::map &buffer_footprint,
const isl::id &tensor_id, const isl::set &original_elements,
const isl::map &exact_reads, const isl::map &exact_writes,
const isl::union_map &sch);
void PlaceDataCopyBelowImplReadWrite(Scop &scop, isl::schedule_node &tree, const TensorFootprintCluster &cluster,
const isl::map &footprint, const isl::id &tensor_id,
const isl::set &original_elements, const isl::map &exact_writes,
isl::map &read_extension, isl::set &buffered_footprint, const isl::id &cluster_id,
isl::map &extension_map, isl::id &read_id);
void PlaceDataCopyBelowImplReadWrite(ScopInfo &scop_info, isl::schedule_node &tree,
const TensorFootprintCluster &cluster, const isl::map &footprint,
const isl::id &tensor_id, const isl::set &original_elements,
const isl::map &exact_writes, isl::map &read_extension,
isl::set &buffered_footprint, const isl::id &cluster_id, isl::map &extension_map,
isl::id &read_id);
void PlaceDataCopyBelowImplFakeReads(Scop &scop, isl::schedule_node &tree, const TensorFootprintCluster &cluster,
isl::map &read_extension, const isl::id &cluster_id);
void PlaceDataCopyBelowImplFakeReads(ScopInfo &scop_info, isl::schedule_node &tree,
const TensorFootprintCluster &cluster, isl::map &read_extension,
const isl::id &cluster_id, const isl::union_map &sch);
isl::schedule_node PlaceInnerDataCopyBelow(Scop &scop, const isl::schedule_node &tree,
isl::schedule_node PlaceInnerDataCopyBelow(ScopInfo &scop_info, const isl::schedule_node &tree,
const TensorFootprintCluster &cluster,
const TensorFootprintCluster &outer_scope_cluster, const isl::id &tensor_id,
const isl::id &cluster_id, const isl::id &outer_scope_cluster_id);
const isl::id &cluster_id, const isl::id &outer_scope_cluster_id,
const isl::union_map &sch);
isl::schedule_node PlaceOuterDataCopyBelow(Scop &scop, const isl::schedule_node &tree,
isl::schedule_node PlaceOuterDataCopyBelow(ScopInfo &scop_info, const isl::schedule_node &tree,
const TensorFootprintCluster &cluster, const isl::id &tensor_id,
const isl::id &cluster_id);
const isl::id &cluster_id, const isl::union_map &sch,
const isl::space &sch_space);
isl::schedule_node PlaceIm2colBelow(Scop &scop, const isl::schedule_node &tree, const TensorFootprintCluster &cluster,
isl::schedule_node PlaceIm2colBelow(ScopInfo &scop_info, const isl::schedule_node &tree,
const TensorFootprintCluster &cluster,
const TensorFootprintCluster &outer_scope_cluster, const isl::id &cluster_id,
const isl::id &outer_scope_cluster_id);
......@@ -210,7 +210,7 @@ class AffineBase {
public:
virtual ~AffineBase() = default;
virtual isl::map ConstructAffine(isl::map original) = 0;
virtual bool NotNeedConstruct(std::string name, Scop &scop) = 0;
virtual bool NotNeedConstruct(std::string name, ScopInfo &scop_info) = 0;
};
class GemmInnerTransposeAffine : public AffineBase {
......@@ -221,13 +221,13 @@ class GemmInnerTransposeAffine : public AffineBase {
isl::map ConstructAffine(isl::map original_map) final;
void SetRightMatrix(AffineTensor v) { is_right_matrix_ = v; }
bool NotNeedConstruct(std::string name, Scop &scop) override {
bool NotNeedConstruct(std::string name, ScopInfo &scop_info) override {
// right matrix filter !B tensor
if (is_right_matrix_ == AffineTensor::RIGHT_TENSOR && !scop.IsB(name)) {
if (is_right_matrix_ == AffineTensor::RIGHT_TENSOR && !scop_info.cube_info_.IsB(name)) {
return true;
}
// left matrix filter !A tensor
if (is_right_matrix_ == AffineTensor::LEFT_TENSOR && !scop.IsA(name)) {
if (is_right_matrix_ == AffineTensor::LEFT_TENSOR && !scop_info.cube_info_.IsA(name)) {
return true;
}
return false;
......@@ -246,13 +246,13 @@ class GemmTransposeAffine : public AffineBase {
void SetRightMatrix(AffineTensor v) { is_right_matrix_ = v; }
bool NotNeedConstruct(std::string name, Scop &scop) override {
bool NotNeedConstruct(std::string name, ScopInfo &scop_info) override {
// right matrix filter !B tensor
if (is_right_matrix_ == AffineTensor::RIGHT_TENSOR && !scop.IsB(name)) {
if (is_right_matrix_ == AffineTensor::RIGHT_TENSOR && !scop_info.cube_info_.IsB(name)) {
return true;
}
// left matrix filter !A tensor
if (is_right_matrix_ == AffineTensor::LEFT_TENSOR && !scop.IsA(name)) {
if (is_right_matrix_ == AffineTensor::LEFT_TENSOR && !scop_info.cube_info_.IsA(name)) {
return true;
}
return false;
......@@ -271,17 +271,17 @@ class GemmTransposeBlockAffine : public AffineBase {
void SetRightMatrix(AffineTensor v) { is_right_matrix_ = v; }
bool NotNeedConstruct(std::string name, Scop &scop) override {
bool NotNeedConstruct(std::string name, ScopInfo &scop_info) override {
// right matrix filter !B tensor
if (AffineTensor::RIGHT_TENSOR == is_right_matrix_ && !scop.IsB(name)) {
if (AffineTensor::RIGHT_TENSOR == is_right_matrix_ && !scop_info.cube_info_.IsB(name)) {
return true;
}
// left matrix filter !A tensor
if (is_right_matrix_ == AffineTensor::LEFT_TENSOR && !scop.IsA(name)) {
if (is_right_matrix_ == AffineTensor::LEFT_TENSOR && !scop_info.cube_info_.IsA(name)) {
return true;
}
if (AffineTensor::OUT_TENSOR == is_right_matrix_ && !scop.IsC(name)) {
if (AffineTensor::OUT_TENSOR == is_right_matrix_ && !scop_info.cube_info_.IsC(name)) {
return true;
}
......@@ -302,8 +302,8 @@ class Im2colAffine : public AffineBase {
void ConstructAffineMap(isl::map &footprint, std::vector<isl::aff> &v_aff_x, std::vector<isl::aff> &v_aff_y,
const isl::map &original_map, isl::local_space &ls);
bool NotNeedConstruct(std::string name, Scop &scop) override {
if (!scop.IsA(name)) {
bool NotNeedConstruct(std::string name, ScopInfo &scop_info) override {
if (!scop_info.cube_info_.IsA(name)) {
return true;
}
return false;
......@@ -319,8 +319,8 @@ class WeightAffine : public AffineBase {
isl::map ConstructAffine(isl::map original_map) final;
bool NotNeedConstruct(std::string name, Scop &scop) override {
if (!scop.IsB(name)) {
bool NotNeedConstruct(std::string name, ScopInfo &scop_info) override {
if (!scop_info.cube_info_.IsB(name)) {
return true;
}
return false;
......@@ -339,8 +339,8 @@ class FractalAffine : public AffineBase {
void ConstructAffineMap(isl::map &footprint, std::vector<isl::aff> &v_aff_x, std::vector<isl::aff> &v_aff_y,
const isl::map &original_map, isl::local_space &ls);
bool NotNeedConstruct(std::string name, Scop &scop) override {
if (!scop.IsA(name)) {
bool NotNeedConstruct(std::string name, ScopInfo &scop_info) override {
if (!scop_info.cube_info_.IsA(name)) {
return true;
}
return false;
......@@ -371,7 +371,7 @@ class AffineRefGroupConstructor {
void create();
std::unique_ptr<TensorFootprintCluster> ConstructRefGroup(Scop &scop, const isl::union_map &accesses,
std::unique_ptr<TensorFootprintCluster> ConstructRefGroup(ScopInfo &scop_info, const isl::union_map &accesses,
const isl::union_set &domain,
const isl::union_map &schedule, ReferenceType type);
......@@ -391,7 +391,7 @@ class AffineRefGroupConstructor {
AffineType type_ = AffineType::AFFINE_GEMM;
};
std::unique_ptr<TensorFootprintCluster> ConstructAffineFpCluster(Scop &scop, const isl::union_map &accesses,
std::unique_ptr<TensorFootprintCluster> ConstructAffineFpCluster(ScopInfo &info, const isl::union_map &accesses,
const isl::union_set &domain,
const isl::union_map &schedule, ReferenceType type,
AffineType affine_type,
......
......@@ -20,9 +20,10 @@
#include <fcntl.h>
#include <sys/stat.h>
#include <fstream>
#include <iostream>
#include <iomanip>
#include "poly/poly_util.h"
#include "poly/scop.h"
#include "poly/dma_inject.h"
namespace akg {
......@@ -152,6 +153,11 @@ void PrettyPrintSchTree(std::FILE *fp, const isl::schedule &sch) {
}
}
std::string PrettyPrintSchTree(const isl::schedule &sch) {
std::string sch_tree_str = DumpSchTreeToString(sch);
return FormatSchTreeStr(sch_tree_str);
}
/*
* Check that file name is a simple relative path (does not start with "/", and does not include "." or "..").
 * FileName should not include an extension; the extension will be appended to FileName.
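The contract in the comment above can be expressed compactly. A minimal sketch, assuming a free helper function (hypothetical name; the real check sits behind FilePathCanonicalize, used later in this file):

#include <sstream>
#include <string>

// Hypothetical helper mirroring the documented rule: reject absolute paths
// and any "." or ".." path component; the caller appends the extension.
bool IsSimpleRelativePath(const std::string &file_name) {
  if (file_name.empty() || file_name[0] == '/') return false;
  std::stringstream ss(file_name);
  std::string component;
  while (std::getline(ss, component, '/')) {
    if (component == "." || component == "..") return false;
  }
  return true;
}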
......@@ -218,6 +224,7 @@ bool CompareSchTreeWithString(const std::string &compare_sch_, const isl::schedu
void PrintHeader(std::ofstream &of, const std::string &str) {
of << std::endl << ">>>>>>>>>> " << str << " <<<<<<<<<<" << std::endl;
}
void PrintHeader(const std::string &str) { std::cout << ">>>>>>>>>> " << str << " <<<<<<<<<<" << std::endl; }
void DumpNode(std::ofstream &of, const air::Node *node) {
if (node->IsInstance<Provide>()) {
......@@ -274,28 +281,28 @@ void CreateDirIfNotExist(const std::string &file_name) {
free(file_name_);
}
void Scop::DumpScopDataBasics(std::ofstream &of) {
void AnalysisResult::DumpScopDataBasics(std::ofstream &of) {
PrintHeader(of, "statements");
for (const auto &stmt : data_.statements) {
for (const auto &stmt : GetStatementMap()) {
of << stmt.first << " : ";
DumpNode(of, stmt.second);
of << std::endl;
}
PrintHeader(of, "accesses");
for (const auto &stmt : data_.accesses) {
for (const auto &stmt : GetAccessMap()) {
of << stmt.second << " : ";
DumpNode(of, stmt.first);
of << std::endl;
}
PrintHeader(of, "domains");
for (const auto &stmt : data_.domains) {
for (const auto &stmt : GetOperatorDomainMap()) {
of << stmt.first << " : param_space " << stmt.second.param_space << std::endl;
}
PrintHeader(of, "stmt_op_Info");
for (const auto &stmt : data_.stmt_op_Info) {
for (const auto &stmt : GetStmtOpInfoMap()) {
of << stmt.first << " : ops [ ";
for (auto op : stmt.second.ops) {
of << int(op) << ", ";
......@@ -307,92 +314,79 @@ void Scop::DumpScopDataBasics(std::ofstream &of) {
of << "]" << std::endl;
}
PrintHeader(of, "iterators");
for (const auto &it : data_.iterators) {
of << it.first << " : [ ";
for (const auto &str : it.second) {
of << str << ", ";
}
of << "]" << std::endl;
}
PrintHeader(of, "reads");
of << FormatMupaStr(data_.reads) << std::endl;
of << FormatMupaStr(GetReads()) << std::endl;
PrintHeader(of, "writes");
of << FormatMupaStr(data_.writes) << std::endl;
of << FormatMupaStr(GetWrites()) << std::endl;
PrintHeader(of, "copyin");
of << FormatMupaStr(data_.copyin) << std::endl;
of << FormatMupaStr(GetCopyin()) << std::endl;
PrintHeader(of, "fake_copyin");
of << FormatMupaStr(data_.fake_copyin) << std::endl;
of << FormatMupaStr(GetFakeCopyin()) << std::endl;
PrintHeader(of, "inter_band_dependency");
of << FormatMupaStr(data_.inter_band_dependency) << std::endl;
of << FormatMupaStr(GetInnerBandDependency()) << std::endl;
PrintHeader(of, "transfer_stmt");
of << FormatMupaStr(data_.transfer_stmt) << std::endl;
of << FormatMupaStr(GetTransferStmt()) << std::endl;
PrintHeader(of, "reduce_stmts");
for (const auto &stmt : data_.reduce_stmts) {
for (const auto &stmt : GetReduceStmtMap()) {
of << stmt.first << ": reduce axis [ ";
for (const auto &axis : stmt.second) {
of << axis << " ";
}
of << "]" << std::endl;
}
PrintHeader(of, "group_filter_map");
for (const auto &group : group_filter_map_) {
of << group.first << " : [ ";
for (auto filter : group.second) {
of << filter << ", ";
}
of << "]" << std::endl;
}
}
void Scop::DumpScopDataAdvanced(std::ofstream &of) {
void ScopInfo::DumpScopDataAdvanced(std::ofstream &of) {
PrintHeader(of, "binds");
for (auto bind : binds_) {
auto binds = user_config_.GetBind();
for (auto bind : binds) {
of << bind.first << " : " << bind.second << std::endl;
}
PrintHeader(of, "binds_orig");
for (auto bind : binds_orig_) {
auto binds_orig = user_config_.GetOriginBind();
for (auto bind : binds_orig) {
of << bind.first << " : " << bind.second << std::endl;
}
PrintHeader(of, "realize_from_input");
for (const auto &id : realize_from_input_) {
auto realize_from_input = user_config_.GetRealizeFromInput();
for (const auto &id : realize_from_input) {
of << id << ", ";
}
of << std::endl;
PrintHeader(of, "dim_infos");
for (const auto &dim_info : dim_infos_) {
for (const auto &dim_info : analysis_result_.GetTileSizes()) {
of << "index=" << dim_info.index << " axis=" << dim_info.axis << " l1_tiling_size=" << dim_info.l1_tiling_size
<< " l0_tiling_size=" << dim_info.l0_tiling_size << " dim_seq=" << dim_info.dim_seq << std::endl;
}
PrintHeader(of, "fractal_int_info");
for (const auto &info : fractal_int_info_) {
for (const auto &info : cube_info_.fractal_int_info_) {
of << info.first << " : " << info.second << std::endl;
}
PrintHeader(of, "fractal_str_info");
for (const auto &info : fractal_str_info_) {
for (const auto &info : cube_info_.fractal_str_info_) {
of << info.first << " : " << info.second << std::endl;
}
PrintHeader(of, "conditional_write_buffer_footprints");
for (const auto &tensor : conditional_write_buffer_footprints_) {
auto conditional_write_buffer_footprints = analysis_result_.GetConditionalWriteBufferFootprints();
for (const auto &tensor : conditional_write_buffer_footprints) {
of << tensor << std::endl;
}
PrintHeader(of, "tensor_name_flows");
for (const auto &name_flow : tensor_name_flows_) {
auto tensor_name_flows = analysis_result_.GetTensorNameFlows();
for (const auto &name_flow : tensor_name_flows) {
of << name_flow.first << " : [ ";
for (const auto &name : name_flow.second) {
of << name << ", ";
......@@ -401,7 +395,8 @@ void Scop::DumpScopDataAdvanced(std::ofstream &of) {
}
PrintHeader(of, "tensor_memflows");
for (const auto &mem_flow : tensor_mem_flows_) {
auto tensor_mem_flows = analysis_result_.GetTensorMemFlows();
for (const auto &mem_flow : tensor_mem_flows) {
of << mem_flow.first << " : [ ";
for (auto mem : mem_flow.second) {
of << static_cast<int>(mem) << ", ";
......@@ -409,25 +404,8 @@ void Scop::DumpScopDataAdvanced(std::ofstream &of) {
of << "]" << std::endl;
}
PrintHeader(of, "n_clusters");
for (const auto &cluster : n_clusters_) {
of << cluster.first << " : " << cluster.second << std::endl;
}
PrintHeader(of, "bufferedDecls");
for (const auto &buffered_decl : buffered_decls_) {
of << buffered_decl.first << " : "
<< "tensor_id=" << buffered_decl.second.tensor_id << "type=" << buffered_decl.second.type
<< "kind=" << static_cast<int>(buffered_decl.second.kind) << "tensor=" << buffered_decl.second.tensor
<< "size=[";
for (auto size : buffered_decl.second.sizes) {
of << size << ",";
}
of << "]" << std::endl;
}
PrintHeader(of, "active_buffer_footprints");
for (const auto &active_buffer_footprint : active_buffer_footprints_) {
for (const auto &active_buffer_footprint : analysis_result_.active_buffer_footprints_) {
of << "cluster_id : " << active_buffer_footprint.second.cluster_id << std::endl
<< "domain : " << FormatMupaStr(active_buffer_footprint.first) << std::endl
<< "cluster : " << *(active_buffer_footprint.second.cluster) << std::endl
......@@ -436,81 +414,82 @@ void Scop::DumpScopDataAdvanced(std::ofstream &of) {
}
PrintHeader(of, "buffered_decl_infos");
DumpBufferDefInfos(of);
of << std::endl;
of << "custom_tiling : ";
if (custom_tiling_.empty()) of << "empty" << std::endl;
for (const auto &tiling : custom_tiling_) {
of << tiling << " ";
}
analysis_result_.DumpBufferDefInfos(of);
of << std::endl;
PrintHeader(of, "attr_info");
for (const auto &info : attr_info_) {
for (const auto &info : cube_info_.GetConvAttrInfo()) {
of << info.first << " : " << info.second << std::endl;
}
}
void Scop::DumpScopDataScheduleAttrs(std::ofstream &of) {
void UserConfig::DumpScopDataScheduleAttrs(std::ofstream &of) {
PrintHeader(of, "schedule attrs");
of << "dim : " << b_dim_ << std::endl;
of << "kernel_h : " << matB_dim_h_ << std::endl;
of << "kernel_w : " << matB_dim_w_ << std::endl;
of << "conv_backprop_filter : " << conv_back_prop_filter_ << std::endl;
of << "bypassL1 : " << bypassL1_ << std::endl;
of << "dump_tuning_level : " << dump_tuning_level_ << std::endl;
of << "pragma_rmselfdep : " << remove_self_dependence_ << std::endl;
of << "pragma_force_rmselfdep : " << force_remove_self_dependence_ << std::endl;
of << "pragma_reschedule : " << compute_reschedule_ << std::endl;
of << "pragma_disable_schedule_shift : " << disable_schedule_shift_ << std::endl;
of << "pragma_enable_schedule_max_constant : " << enable_schedule_max_constant_ << std::endl;
of << "pragma_disable_loop_reversal : " << disable_loop_reversal_ << std::endl;
of << "pragma_disable_loop_fusion : " << disable_loop_fusion_ << std::endl;
of << "pragma_modshift : " << mod_schedule_shift_ << std::endl;
of << "pragma_conv_special_dma : " << conv_special_dma_ << std::endl;
of << "pragma_reorder_schedule : " << reorder_schedule_ << std::endl;
of << "pragma_checkcoincident : " << tile_check_coincident_ << std::endl;
of << "pragma_opt_for_davinci : " << optimize_for_davinci_ << std::endl;
of << "pragma_sink_last_axis : " << sink_last_axis_ << std::endl;
of << "pragma_keep_outer_band_order : " << keep_outer_band_order_ << std::endl;
of << "pragma_disable_group : " << disable_group_ << std::endl;
of << "pragma_tile_inner_band : " << tile_inner_band_ << std::endl;
of << "kernel_name : " << kernel_name_ << std::endl;
of << "dump_poly_dir : " << dump_poly_dir_ << std::endl;
of << "isolated_idx : " << isolated_idx_ << std::endl;
of << "dynamic_shape_bound : " << dynamic_shape_bound_ << std::endl;
of << "pragma_tilesize_is_var : " << tile_size_is_var_ << std::endl;
of << "pragma_outerband_need_split : " << outer_band_need_split_ << std::endl;
of << "pragma_is_conv : " << pragma_is_conv_ << std::endl;
of << "dump_poly_dir : " << GetDumpPolyDir() << std::endl;
of << "dump_tuning_level : " << GetDumpTuningLevel() << std::endl;
of << "dim : " << GetBDim() << std::endl;
of << "pragma_rmselfdep : " << GetRemoveSelfDependence() << std::endl;
of << "pragma_force_rmselfdep : " << GetForceRemoveSelfDependence() << std::endl;
of << "pragma_reschedule : " << GetComputeReschedule() << std::endl;
of << "pragma_disable_schedule_shift : " << GetDisableScheduleShift() << std::endl;
of << "pragma_enable_schedule_max_constant : " << GetEnableScheduleMaxConstant() << std::endl;
of << "pragma_disable_loop_reversal : " << GetDisableLoopReversal() << std::endl;
of << "pragma_disable_loop_fusion : " << GetDisableLoopFusion() << std::endl;
of << "pragma_modshift : " << GetModScheduleShift() << std::endl;
of << "pragma_reorder_schedule : " << GetReorderSchedule() << std::endl;
of << "pragma_checkcoincident : " << GetTileCheckCoincident() << std::endl;
of << "pragma_opt_for_davinci : " << GetOptimizeForDavinci() << std::endl;
of << "pragma_sink_last_axis : " << GetSinkLastAxis() << std::endl;
of << "pragma_keep_outer_band_order : " << GetKeepOuterBandOrder() << std::endl;
of << "pragma_disable_group : " << GetDisableGroup() << std::endl;
of << "pragma_tile_inner_band : " << GetTileInnerBand() << std::endl;
of << "isolated_idx : " << GetIsolatedIdx() << std::endl;
of << "pragma_outerband_need_split : " << GetOuterBandNeedSplit() << std::endl;
of << "dynamic_shape_bound : " << GetDynamicShapeBound() << std::endl;
of << "pragma_tilesize_is_var : " << GetTileSizeIsVar() << std::endl;
of << "kernel_name : " << GetKernelName() << std::endl;
of << "kernel_h : " << GetMatBDimH() << std::endl;
of << "kernel_w : " << GetMatBDimW() << std::endl;
of << "conv_backprop_filter : " << GetConvBackPropFilter() << std::endl;
of << "bypassL1 : " << GetByPassL1() << std::endl;
of << "pragma_is_conv : " << GetPragmaIsConv() << std::endl;
of << "pragma_conv_special_dma : " << GetConvSpecialDma() << std::endl;
}
bool Scop::DumpScopData(const std::string &file_name) {
bool ScopInfo::DumpScopData(const std::string &file_name) {
std::string canonical_log_name = FilePathCanonicalize(file_name, true);
if (!CreateFileIfNotExist(canonical_log_name)) return false;
std::ofstream of;
of.open(canonical_log_name, std::ios::out);
if (!of.is_open()) return false;
DumpScopDataBasics(of);
analysis_result_.DumpScopDataBasics(of);
DumpScopDataAdvanced(of);
DumpScopDataScheduleAttrs(of);
user_config_.DumpScopDataScheduleAttrs(of);
of.close();
return true;
}
void Scop::DumpSchTree(const std::string &file_name, const isl::schedule &sch_dump) {
if (dump_pass_ir_) {
void ScopInfo::DumpSchTree(const std::string &file_name, const isl::schedule &sch_dump) {
std::stringstream final_file_name;
final_file_name << std::setw(2) << std::setfill('0') << dump_schtree_count << "_" << file_name
<< std::string(cube_info_.IsSpecGemm() ? "_specgemm" : "");
if (user_config_.GetDumpPassIr()) {
#if DUMP_IR
DumpSchTreeImpl(CreateDumpDir(file_name), sch_dump);
DumpSchTreeImpl(CreateDumpDir(final_file_name.str()), sch_dump);
dump_schtree_count++;
#endif
#if DUMP_SCOP_DATA
#if DUMP_SCOP_DATA_PER_PASS
static_cast<void>(DumpScopData(CreateDumpDir(file_name)));
static_cast<void>(DumpScopData(CreateDumpDir(final_file_name.str())));
#else
static_cast<void>(DumpScopData(CreateDumpDir("scop")));
#endif
......@@ -518,29 +497,29 @@ void Scop::DumpSchTree(const std::string &file_name, const isl::schedule &sch_du
}
}
std::string Scop::AddDumpDir(const std::string &file_name) {
std::string ScopInfo::AddDumpDir(const std::string &file_name) {
std::string real_file_name = file_name;
bool is_specgemm = (isolated_idx_ > 0);
bool is_specgemm = (user_config_.GetIsolatedIdx() > 0);
if (is_specgemm) {
std::string dump_isolate_dir = "specgemm_" + std::to_string(isolated_idx_);
std::string dump_isolate_dir = "specgemm_" + std::to_string(user_config_.GetIsolatedIdx());
real_file_name = dump_isolate_dir + '/' + real_file_name;
}
#if (!DUMP_IN_CURRENT_DIR)
if (!dump_poly_dir_.empty()) {
real_file_name = dump_poly_dir_ + '/' + real_file_name;
if (!user_config_.GetDumpPolyDir().empty()) {
real_file_name = user_config_.GetDumpPolyDir() + '/' + real_file_name;
}
#endif
return real_file_name;
}
std::string Scop::CreateDumpDir(const std::string &file_name) {
std::string ScopInfo::CreateDumpDir(const std::string &file_name) {
std::string real_file_name = AddDumpDir(file_name);
CreateDirIfNotExist(real_file_name);
return real_file_name;
}
void Scop::DumpBufferDefInfos(std::ostream &out) {
void AnalysisResult::DumpBufferDefInfos(std::ostream &out) {
for (size_t index = 0; index < buffer_def_infos_.size(); index++) {
out << "\r\nbufferedDefInfos_[" << index << "]: " << std::endl;
out << " tensor_id : " << buffer_def_infos_[index].tensor_id << std::endl;
......@@ -552,6 +531,48 @@ void Scop::DumpBufferDefInfos(std::ostream &out) {
out << " is_bind_tensor : " << buffer_def_infos_[index].is_bind_tensor << std::endl;
}
}
void ScopInfo::DumpTransform(const std::string &file_name, PassInfo &pass_info) {
auto real_path = CreateDumpDir(file_name);
std::ofstream of;
of.open(real_path, std::ios::out);
if (!of.is_open()) {
return;
}
PrintHeader(of, "group_filter_map");
for (const auto &group : pass_info.group_filter_map_) {
of << group.first << " : [ ";
for (auto filter : group.second) {
of << filter << ", ";
}
of << "]" << std::endl;
}
PrintHeader(of, "dependences");
of << FormatMupaStr(pass_info.dependences_.to_str()) << std::endl;
PrintHeader(of, "constraints");
isl_printer *p;
char *s = nullptr;
p = isl_printer_to_str(GetCtx().get());
CHECK(p != nullptr);
p = isl_printer_set_yaml_style(p, ISL_YAML_STYLE_BLOCK);
p = isl_printer_print_schedule_constraints(p, pass_info.constraints_.get());
s = isl_printer_get_str(p);
if (s) {
of << FormatMupaStr(s);
free(s);
}
static_cast<void>(isl_printer_free(p));
PrintHeader(of, "time_records");
for (auto time_log : time_records_) {
of << time_log << std::endl;
}
of.close();
}
} // namespace poly
} // namespace ir
} // namespace akg
......@@ -19,6 +19,7 @@
#include <isl/cpp.h>
#include <tvm/node/node.h>
#include <string>
#include "poly/poly_util.h"
namespace akg {
namespace ir {
namespace poly {
......@@ -35,11 +36,11 @@ bool CreateFileIfNotExist(const std::string &file_name);
void CreateDirIfNotExist(const std::string &file_name);
std::string DumpSchTreeToString(const isl::schedule &sch);
void DumpSchTreeImpl(const std::string &file_name, const isl::schedule &sch);
std::string PrettyPrintSchTree(const isl::schedule &sch);
void PrintHeader(std::ofstream &of, const std::string &str);
void PrintHeader(const std::string &str);
void DumpNode(std::ofstream &of, const air::Node *node);
bool CompareSchTreeWithString(const std::string &compare_sch, const isl::schedule &sch);
} // namespace poly
} // namespace ir
} // namespace akg
......
......@@ -203,11 +203,13 @@ Stmt IslEmitter::EmitFor(const isl::ast_node_for &node) {
Stmt IslEmitter::EmitIf(const isl::ast_node_if &node) {
Expr cond_expr = Interpret(node.get_cond());
cur_if_list_.push_back(cond_expr.get());
Stmt then_case = EmitAst(node.get_then_node());
Stmt else_case;
if (node.has_else_node()) {
else_case = EmitAst(node.get_else_node());
}
cur_if_list_.pop_back();
return IfThenElse::make(cond_expr, then_case, else_case);
}
......@@ -230,25 +232,8 @@ Stmt IslEmitter::EmitBlock(const isl::ast_node_block &node) {
}
}
class ReplaceLoopVar : public air::ir::IRMutator {
public:
explicit ReplaceLoopVar(VarMap v_) : var_map(std::move(v_)) {}
~ReplaceLoopVar() override = default;
Expr Mutate_(const Variable *op, const Expr &e) final {
for (auto &i : var_map) {
if (op->name_hint == i.first.get_name()) {
return i.second;
}
}
return e;
}
private:
VarMap var_map;
};
isl::space IslEmitter::GetDomainSpace(const isl::id &node_id) {
auto dom = isl::union_set(scop_.Domain());
auto dom = isl::union_set(info_.analysis_result_.Domain());
auto space = isl::space();
dom.foreach_set([&node_id, &space](const isl::set &s) -> void {
if (s.get_tuple_id() == node_id) {
......@@ -265,12 +250,12 @@ isl::space IslEmitter::GetSpace(const isl::id &tensor_id, const Array<Expr> &ten
return space;
}
isl::multi_aff IslEmitter::TensorAccessMultAff(const isl::id &tensor_id, const Array<Expr> &tensor_index,
isl::multi_aff IslEmitter::TensorAccessMultAff(isl::id &tensor_id, const Array<Expr> &tensor_index,
const isl::id &node_id) {
CHECK_NE(tensor_index.size(), 0u);
isl::pw_multi_aff iter_map = node_info_map_.at(node_id).iterator_map;
isl::id stmt_id = iter_map.get_tuple_id(isl_dim_out);
OperatorDomainSpace domain_space = scop_.data_.domains.at(stmt_id);
OperatorDomainSpace domain_space = info_.analysis_result_.GetOperatorDomainMap().at(stmt_id);
isl::multi_aff ma = isl::multi_aff::zero(GetSpace(tensor_id, tensor_index, stmt_id));
for (size_t i = 0; i < tensor_index.size(); ++i) {
auto aff = Expr2Aff(domain_space.param_space, tensor_index[i]).unbind_params_insert_domain(domain_space.tuple);
......@@ -335,8 +320,9 @@ class EmitExpr : public air::ir::IRMutator {
Map<Expr, Expr> cache_;
};
void FindBufferFootprintById(Scop::BufferedFootPrintInfo &buffer_footprint_info,
std::vector<Scop::BufferedFootPrintInfo> active_buf_footprints, isl::id fp_id) {
BufferedFootPrintInfo FindBufferFootprintById(const std::vector<BufferedFootPrintInfo> &active_buf_footprints,
const isl::id &fp_id) {
BufferedFootPrintInfo buffer_footprint_info;
for (const auto &act_buf_fp : active_buf_footprints) {
if (act_buf_fp.cluster != nullptr) {
for (const auto &fp : act_buf_fp.cluster->tensor_foot_prints) {
......@@ -347,14 +333,16 @@ void FindBufferFootprintById(Scop::BufferedFootPrintInfo &buffer_footprint_info,
}
}
}
return buffer_footprint_info;
}
bool IsTransferStmt(Scop &scop, isl::id &stmt_id) {
if (!scop.is_spec_gemm_ && scop.is_tiled_) {
isl::union_set transfer_stmt = scop.data_.transfer_stmt;
bool IslEmitter::IsTransferStmt() {
if (info_.analysis_result_.GetIsTiled()) {
isl::union_set transfer_stmt = info_.analysis_result_.GetTransferStmt();
if (!transfer_stmt.is_empty()) {
bool name_match = false;
transfer_stmt.foreach_set([&name_match, stmt_id](const isl::set &s) -> void {
auto stmt_id = stmt_id_;
transfer_stmt.foreach_set([&name_match, &stmt_id](const isl::set &s) -> void {
if (s.get_tuple_name() == stmt_id.get_name()) {
name_match = true;
}
......@@ -365,8 +353,8 @@ bool IsTransferStmt(Scop &scop, isl::id &stmt_id) {
return false;
}
Stmt EmitAccessNodeProvide(const Node *node, const VarMap &var_map_tmp,
Scop::BufferedFootPrintInfo &buffer_footprint_info) {
Stmt IslEmitter::EmitAccessNodeProvide(const Node *node, const VarMap &var_map_tmp,
BufferedFootPrintInfo &buffer_footprint_info) {
const auto provide = static_cast<const Provide *>(node);
Expr value = ReplaceLoopVar(var_map_tmp).Mutate(provide->value);
Array<Expr> args;
......@@ -380,8 +368,8 @@ Stmt EmitAccessNodeProvide(const Node *node, const VarMap &var_map_tmp,
return Stmt();
}
Stmt EmitAccessNodeCall(const Node *node, const VarMap &var_map_tmp, Scop::BufferedFootPrintInfo &buffer_footprint_info,
bool &is_transfer_stmt, Scop &scop) {
Stmt IslEmitter::EmitAccessNodeCall(const Node *node, const VarMap &var_map_tmp,
BufferedFootPrintInfo &buffer_footprint_info) {
const Call *call = static_cast<const Call *>(node);
Array<Expr> args;
for (auto iv : call->args) {
......@@ -389,46 +377,35 @@ Stmt EmitAccessNodeCall(const Node *node, const VarMap &var_map_tmp, Scop::Buffe
}
// Not hoisted, emitting just the mapped subscript.
if (!buffer_footprint_info.cluster_id) {
std::string call_name = call->name;
if (is_transfer_stmt && (std::string::npos == call_name.find("_local_UB"))) {
call_name = call_name + "_local_UB";
Tensor t = scop.FindTensor(call_name);
if (t.defined()) {
return Evaluate::make(Call::make(call->type, call_name, args, call->call_type, t->op, call->value_index));
} else {
LOG(WARNING) << "Call can not found tensor!!! tensor name: " << call_name;
}
}
return Evaluate::make(Call::make(call->type, call->name, args, call->call_type, call->func, call->value_index));
}
return Stmt();
}
bool IsCopyinFromAnotherBand(Scop &scop, isl::multi_aff &access) {
if (!scop.is_spec_gemm_) {
for (isl::map inter_band_dependency : scop.data_.inter_band_dependency.get_map_list()) {
if (inter_band_dependency.get_tuple_id(isl_dim_out) == access.get_tuple_id(isl_dim_out)) {
return true;
}
bool IslEmitter::IsCopyinFromAnotherBand(isl::multi_aff &access) {
for (isl::map inter_band_dependency : info_.analysis_result_.GetInnerBandDependency().get_map_list()) {
if (inter_band_dependency.get_tuple_id(isl_dim_out) == access.get_tuple_id(isl_dim_out)) {
return true;
}
}
return false;
}
void AffSubForAstToSchedule(isl::pw_multi_aff &ast_to_schedule, bool &is_transfer_stmt,
bool &is_copyin_from_another_band) {
isl::pw_multi_aff &AffSubForAstToSchedule(isl::pw_multi_aff &ast_to_schedule, bool is_transfer_stmt,
bool is_copyin_from_another_band) {
if (is_transfer_stmt || is_copyin_from_another_band) {
isl_pw_multi_aff *pma1 = ast_to_schedule.copy();
isl_pw_multi_aff *pma2 = ast_to_schedule.copy();
isl_pw_multi_aff *pma = isl_pw_multi_aff_sub(pma1, pma2);
ast_to_schedule = isl::manage(pma);
}
return ast_to_schedule;
}
Stmt IslEmitter::EmitAccessNodeFromPromoteAcsProvide(Scop &scop, isl::id var, const Node *node, Array<Expr> &args) {
Stmt IslEmitter::EmitAccessNodeFromPromoteAcsProvide(isl::id var, const Node *node, Array<Expr> &args) {
const auto provide = static_cast<const Provide *>(node);
Tensor t = scop.FindTensor(var);
if (scop.CountBufferDefInfo(var)) {
Tensor t = info_.FindTensor(var);
if (info_.analysis_result_.CountBufferDefInfo(var)) {
realize_may_def_.insert(var);
if_map_[var] = cur_if_list_;
if (cur_if_list_.empty()) {
......@@ -439,10 +416,10 @@ Stmt IslEmitter::EmitAccessNodeFromPromoteAcsProvide(Scop &scop, isl::id var, co
return s;
}
Stmt IslEmitter::EmitAccessNodeFromPromoteAcsCall(Scop &scop, isl::id var, const Node *node, Array<Expr> &args) {
Stmt IslEmitter::EmitAccessNodeFromPromoteAcsCall(isl::id var, const Node *node, Array<Expr> &args) {
const Call *call = static_cast<const Call *>(node);
Tensor t = scop.FindTensor(var);
if (scop.CountBufferDefInfo(var)) {
Tensor t = info_.FindTensor(var);
if (info_.analysis_result_.CountBufferDefInfo(var)) {
realize_use_.insert(var);
if (!if_map_.count(var) || !AOutThanB(if_map_.at(var), cur_if_list_)) {
realize_use_with_may_def_.insert(var);
......@@ -451,25 +428,6 @@ Stmt IslEmitter::EmitAccessNodeFromPromoteAcsCall(Scop &scop, isl::id var, const
return Evaluate::make(Call::make(call->type, var.get_name(), args, call->call_type, t->op, t->value_index));
}
void GetNameWithoutLocal(isl::id &tensor_id, Scop &scop) {
if (!scop.is_spec_gemm_) {
size_t pos = tensor_id.get_name().find("_local_");
std::string substr = tensor_id.get_name().substr(0, pos);
if (pos != 0) tensor_id = isl::id(tensor_id.ctx(), substr);
}
}
Stmt EmitAccessNodeImpl(const Node *node, const VarMap &var_map_tmp, Scop::BufferedFootPrintInfo &buffer_footprint_info,
bool &is_transfer_stmt, Scop &scop, bool is_Provide) {
Stmt s;
if (is_Provide) {
s = EmitAccessNodeProvide(node, var_map_tmp, buffer_footprint_info);
} else {
s = EmitAccessNodeCall(node, var_map_tmp, buffer_footprint_info, is_transfer_stmt, scop);
}
return s;
}
Stmt IslEmitter::EmitAccessNode(const std::string &name, const Node *node, const Array<Expr> &tensor_index,
const VarMap &var_map_tmp) {
// Scalars are not hoisted or remapped.
......@@ -481,40 +439,34 @@ Stmt IslEmitter::EmitAccessNode(const std::string &name, const Node *node, const
auto build = node_info_map_.at(node_id_).build;
auto iterator_map = node_info_map_.at(node_id_).iterator_map;
CHECK_EQ(scop_.data_.accesses.count(node), 1u)
CHECK_EQ(info_.analysis_result_.GetAccessMap().count(node), 1u)
<< "generating tensor " << name << " not in Scop" << node << " not allowed ";
auto fp_id = scop_.data_.accesses.at(node);
auto fp_id = info_.analysis_result_.GetAccessMap().at(node);
Scop::BufferedFootPrintInfo buffer_footprint_info;
std::vector<Scop::BufferedFootPrintInfo> active_buf_footprint;
for (const auto &kv : scop_.ActiveBufferFootprints()) {
std::vector<BufferedFootPrintInfo> active_buf_footprint;
for (const auto &kv : info_.analysis_result_.ActiveBufferFootprints()) {
if (kv.first.intersect(isl::union_set(Domain())).is_empty()) {
continue;
}
active_buf_footprint.emplace_back(kv.second);
}
FindBufferFootprintById(buffer_footprint_info, active_buf_footprint, fp_id);
bool is_transfer_stmt = false;
is_transfer_stmt = IsTransferStmt(scop_, stmt_id_);
BufferedFootPrintInfo buffer_footprint_info = FindBufferFootprintById(active_buf_footprint, fp_id);
if (node->IsInstance<Provide>()) {
if (EmitAccessNodeImpl(node, var_map_tmp, buffer_footprint_info, is_transfer_stmt, scop_, true).defined())
return EmitAccessNodeImpl(node, var_map_tmp, buffer_footprint_info, is_transfer_stmt, scop_, true);
auto stmt = EmitAccessNodeProvide(node, var_map_tmp, buffer_footprint_info);
if (stmt.defined()) return stmt;
}
if (node->IsInstance<Call>()) {
if (EmitAccessNodeImpl(node, var_map_tmp, buffer_footprint_info, is_transfer_stmt, scop_, false).defined())
return EmitAccessNodeImpl(node, var_map_tmp, buffer_footprint_info, is_transfer_stmt, scop_, false);
auto stmt = EmitAccessNodeCall(node, var_map_tmp, buffer_footprint_info);
if (stmt.defined()) return stmt;
}
auto buf_def = scop_.GetBufferDefInfo(buffer_footprint_info.cluster_id);
GetNameWithoutLocal(buf_def.tensor_id, scop_);
auto buf_def = info_.analysis_result_.GetBufferDefInfo(buffer_footprint_info.cluster_id);
auto access = TensorAccessMultAff(buf_def.tensor_id, tensor_index, node_id_);
bool is_copyin_from_another_band = false;
is_copyin_from_another_band = IsCopyinFromAnotherBand(scop_, access);
bool is_copyin_from_another_band = IsCopyinFromAnotherBand(access);
auto memory_hoist = buffer_footprint_info.cluster->ComputeBufferedFootprints();
if (is_copyin_from_another_band) {
......@@ -523,33 +475,31 @@ Stmt IslEmitter::EmitAccessNode(const std::string &name, const Node *node, const
// split read-only or write-only input tensor memory_hoists
// we need to find tensor by name because tensor_id is a fake isl::id
bool is_input_tensor = scop_.FindTensorInOrig(buf_def.tensor_id.name()).defined();
bool is_input_tensor = info_.FindTensorInOrig(buf_def.tensor_id.name()).defined();
if (is_input_tensor && buffer_footprint_info.cluster->foot_print_.should_split) {
memory_hoist = buffer_footprint_info.cluster->UnshiftedBufferFootprint(memory_hoist, fp_id);
}
memory_hoist = memory_hoist.set_tuple_id(isl_dim_out, buffer_footprint_info.cluster_id);
auto schedule = isl::map::from(buffer_footprint_info.outer_schedule.intersect_domain(this->Domain()));
auto schedule = isl::map::from(buffer_footprint_info.outer_schedule.intersect_domain(Domain()));
CHECK(schedule.is_single_valued()) << schedule << " is not single-valued schedule";
auto ast_to_schedule = isl::pw_multi_aff(schedule).pullback(iterator_map);
AffSubForAstToSchedule(ast_to_schedule, is_transfer_stmt, is_copyin_from_another_band);
ast_to_schedule = AffSubForAstToSchedule(ast_to_schedule, IsTransferStmt(), is_copyin_from_another_band);
auto ast_to_original = isl::pw_multi_aff(access).pullback(iterator_map);
auto ast_to_scheduled_original = ast_to_schedule.range_product(ast_to_original);
auto ast_to_hoisted = isl::pw_multi_aff(memory_hoist).pullback(ast_to_scheduled_original);
auto hoist_acs = build.access_from(ast_to_hoisted);
if (auto op = hoist_acs.as<isl::ast_expr_op>()) {
if (auto access_ = op.as<isl::ast_expr_op_access>()) {
if (op.as<isl::ast_expr_op_access>()) {
Array<Expr> args;
for (int i = 1; i < static_cast<int>(op.get_n_arg()); ++i) {
args.push_back(Interpret(op.get_arg(i)));
}
if (node->IsInstance<Provide>())
return IslEmitter::EmitAccessNodeFromPromoteAcsProvide(scop_, op.get_arg(0).as<isl::ast_expr_id>().get_id(),
node, args);
return EmitAccessNodeFromPromoteAcsProvide(op.get_arg(0).as<isl::ast_expr_id>().get_id(), node, args);
if (node->IsInstance<Call>())
return IslEmitter::EmitAccessNodeFromPromoteAcsCall(scop_, op.get_arg(0).as<isl::ast_expr_id>().get_id(), node,
args);
return EmitAccessNodeFromPromoteAcsCall(op.get_arg(0).as<isl::ast_expr_id>().get_id(), node, args);
}
}
return Evaluate::make(Expr("todo EmitAst"));
......@@ -569,7 +519,7 @@ Stmt IslEmitter::EmitUserStmtContent(const Evaluate *eva_node) {
auto im2col = Call::make(call->type, call->name, args, call->call_type);
Stmt res = Evaluate::make(im2col);
// add AttrStmt to im2col
for (const auto &item : scop_.data_.vecs) {
for (const auto &item : info_.analysis_result_.GetBufferBindVec()) {
Expr replaced = ReplaceLoopVar(var_map_).Mutate(item.second);
res = AttrStmt::make(item.first, air::ir::attr::buffer_bind_scope, replaced, res);
}
......@@ -600,8 +550,9 @@ class SubstituteByNameMutator : public IRMutator {
* So, we need to sink the copy out statement into the innermost "if",
* i.e., copy out immediately after each computation.
*/
static Stmt GenerateCopyOut(const Scop &scop, const Provide *original, const Provide *hoisted, const VarMap &var_map) {
auto call_type = scop.GetDtypeOf(hoisted->func->func_name());
static Stmt GenerateCopyOut(const ScopInfo &info, const Provide *original, const Provide *hoisted,
const VarMap &var_map) {
auto call_type = info.GetDtypeOf(hoisted->func->func_name());
Expr call_expr = Call::make(call_type, hoisted->func->func_name(), hoisted->args, Call::CallType::Halide,
hoisted->func, hoisted->value_index);
Array<Expr> new_args;
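To make the comment above concrete, an illustrative before/after of the transformation on Halide-style IR (pseudo-IR for explanation only, with hypothetical tensor names):

// Before: the local buffer is written under a condition, but the copy-out
// runs unconditionally, so an unwritten element could be copied back:
//   if (cond) { out_local_UB(i) = v }
//   out(i) = out_local_UB(i)
//
// After sinking the copy-out into the innermost "if" (the
// Block::make(provide_stmt, copy_out) in EmitUserStmtContent below):
//   if (cond) { out_local_UB(i) = v; out(i) = out_local_UB(i) }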
......@@ -621,8 +572,8 @@ Stmt IslEmitter::EmitUserStmtContent(const Provide *provide_node) {
Expr value = EmitExpr(f, var_map_).Mutate(provide_node->value);
Stmt provide_stmt = Provide::make(provide_new->func, provide_new->value_index, value, provide_new->args);
if (scop_.conditional_write_buffer_footprints_.count(write_tensor)) {
return Block::make(provide_stmt, GenerateCopyOut(scop_, provide_node, provide_new, var_map_));
if (info_.analysis_result_.GetConditionalWriteBufferFootprints().count(write_tensor)) {
return Block::make(provide_stmt, GenerateCopyOut(info_, provide_node, provide_new, var_map_));
}
return provide_stmt;
}
......@@ -688,11 +639,11 @@ Stmt IslEmitter::EmitUserStmt(const isl::ast_node_user &node) {
isl::ast_expr_op usr_expr = node.get_expr().as<isl::ast_expr_op>();
stmt_id_ = usr_expr.get_arg(0).as<isl::ast_expr_id>().get_id();
node_id_ = node.get_annotation();
const Node *stmt_node = scop_.data_.statements.at(stmt_id_);
const Node *stmt_node = info_.analysis_result_.GetStatementMap().at(stmt_id_);
CHECK(stmt_node);
// compute VarMap to replace old iterators
auto build = node_info_map_.at(node_id_).build;
auto tuple = scop_.data_.domains.at(stmt_id_).tuple;
auto tuple = info_.analysis_result_.GetOperatorDomainMap().at(stmt_id_).tuple;
auto iterator_map = node_info_map_.at(node_id_).iterator_map;
var_map_.clear();
......@@ -701,41 +652,51 @@ Stmt IslEmitter::EmitUserStmt(const isl::ast_node_user &node) {
auto isl_expr = build.expr_from(iterator_map.get_pw_aff(i));
Expr halide_new_iter = Interpret(isl_expr);
var_map_.emplace(isl_old_iter, halide_new_iter);
std::string replace_id = isl_old_iter.get_name() + "_";
std::vector<const Variable *> vec = ExtractIterfromExpr().Run(halide_new_iter);
for (auto item : vec) {
std::string new_name = item->name_hint;
      // assumption: after the refactor the iterator prefix is exposed via the user config
      size_t pos = new_name.find(info_.user_config_.GetIterPrefix());
      if (pos != std::string::npos) {
        new_name = new_name.replace(pos, info_.user_config_.GetIterPrefix().size(), replace_id);
iters_old_name_.emplace(item, item->name_hint);
iters_new_name_.emplace(item, new_name);
}
}
}
VarMap vmap = var_map_;
stmt_var_map_.emplace(stmt_id_, vmap);
return EmitUserStmtContent(stmt_node);
}
Stmt IslEmitter::EmitStmt(const isl::ast_node_user &node) {
CHECK(node.get_expr().isa<isl::ast_expr_op>());
isl::ast_expr_op usr_expr = node.get_expr().as<isl::ast_expr_op>();
CHECK(usr_expr);
auto stmt_id = usr_expr.get_arg(0).as<isl::ast_expr_id>().get_id();
if (info_.IsRead(stmt_id)) {
return Evaluate::make(Expr("todo EmitRead"));
}
if (info_.IsWrite(stmt_id)) {
return Evaluate::make(Expr("todo EmitWrite"));
}
return EmitUserStmt(node);
}
Stmt IslEmitter::EmitAst(const isl::ast_node &node) {
Stmt s;
std::string info;
  if (auto for_node = node.as<isl::ast_node_for>()) {
    info = "[FOR_NODE]";
    s = EmitFor(for_node);
  } else if (auto if_node = node.as<isl::ast_node_if>()) {
    info = "[IF_NODE]";
    s = EmitIf(if_node);
  } else if (auto block_node = node.as<isl::ast_node_block>()) {
    info = "[BLOCK_NODE]";
    s = EmitBlock(block_node);
  } else if (auto mark_node = node.as<isl::ast_node_mark>()) {
    info = "[MARK_NODE]";
    s = EmitMark(mark_node);
  } else if (auto user_node = node.as<isl::ast_node_user>()) {
    info = "[USER_NODE]";
    s = EmitStmt(user_node);
  } else {
    LOG(FATAL) << "NYI " << node << "\n";
    s = Evaluate::make(Expr("todo EmitAst"));
  }
if (PRINT_EMMITER) {
LOG(INFO) << ">>>>>>>>>>>>INPUT AST_NODE" << info << "<<<<<<<<<<<<<<\n" << node;
LOG(INFO) << ">>>>>>>>>>>>OUTPUT STMT<<<<<<<<<<<<\n" << s;
}
return s;
}
Stmt IslEmitter::Emit(const isl::ast_node &node) { return EmitAst(node); }
......
......@@ -19,11 +19,9 @@
#include <tvm/expr.h>
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include "ir_pass.h"
#include "poly/isl.h"
#include "poly/scop.h"
#include "poly/scop_info.h"
namespace akg {
namespace ir {
......@@ -47,29 +45,31 @@ class IslEmitter {
Expr InterpretBinaryOp(const isl::ast_expr_op &e);
public:
explicit IslEmitter(ScopInfo &info, const NodeInfoRepo &n, const isl::id_list &i)
: info_(info), node_info_map_(n), iter_names_(i) {}
virtual ~IslEmitter() = default;
  // Interpret isl::ast_expr to Halide Expr
  Expr Interpret(const isl::ast_expr &e);

  // helper functions, which may be moved into a separate class later
isl::space GetDomainSpace(const isl::id &stmt_id);
isl::space GetSpace(const isl::id &tensor_id, const Array<Expr> &tensor_index, const isl::id &stmt_id);
isl::set Domain() const {
auto iterator_map = node_info_map_.at(node_id_).iterator_map;
return isl::map::from(iterator_map).range();
}
Stmt EmitAccessNode(const std::string &name, const Node *node, const Array<Expr> &tensor_index,
const VarMap &var_map_tmp);
Stmt EmitAccessNodeFromPromoteAcsProvide(isl::id var, const Node *node, Array<Expr> &args);
Stmt EmitAccessNodeFromPromoteAcsCall(isl::id var, const Node *node, Array<Expr> &args);
Stmt EmitAccessNodeProvide(const Node *node, const VarMap &var_map_tmp, BufferedFootPrintInfo &buffer_fp_info);
virtual Stmt EmitAccessNodeCall(const Node *node, const VarMap &var_map_tmp, BufferedFootPrintInfo &buffer_fp_info);
virtual isl::multi_aff TensorAccessMultAff(isl::id &tensor_id, const Array<Expr> &subscripts, const isl::id &stmt_id);
virtual bool IsTransferStmt();
virtual bool IsCopyinFromAnotherBand(isl::multi_aff &access);
// Virtual emitters for different type node
virtual Stmt Emit(const isl::ast_node &node);
virtual Stmt EmitFor(const isl::ast_node_for &node);
virtual Stmt EmitIf(const isl::ast_node_if &node);
......@@ -84,7 +84,12 @@ class IslEmitter {
virtual Stmt EmitUserStmtContent(const IfThenElse *if_node);
virtual Stmt EmitUserStmtContent(const For *for_node);
virtual Stmt EmitUserStmtContent(const Block *block_node);
// Loop isl iters info
virtual void PushIter(const Variable *iter);
virtual void PopIter(const Variable *iter);
bool FindIter(const Variable *iter) const;
const Variable *GetIterByName(const std::string &id) const;
std::unordered_set<isl::id, isl::IslIdIslHash> realize_use_;
std::unordered_set<isl::id, isl::IslIdIslHash> realize_use_with_may_def_;
......@@ -93,28 +98,16 @@ class IslEmitter {
std::unordered_set<isl::id, isl::IslIdIslHash> realize_out_;
std::unordered_set<isl::id, isl::IslIdIslHash> global_realize_out_;
  /// Scop information
  ScopInfo &info_;
/// Node information map including
const NodeInfoRepo &node_info_map_;
  /// Loop isl iters list
  isl::id_list iter_names_;
  /// Loop declared halide iters
  std::vector<const Variable *> iters_;
std::map<const Variable *, std::string> iters_old_name_;
std::map<const Variable *, std::string> iters_new_name_;
// current ast node id
isl::id node_id_;
// current stmt id
......@@ -125,7 +118,6 @@ class IslEmitter {
// emit in if
std::vector<const Node *> cur_if_list_;
std::unordered_map<isl::id, std::vector<const Node *>, isl::IslIdIslHash> if_map_;
std::unordered_map<isl::id, VarMap, isl::IslIdIslHash> stmt_var_map_;
};
class ExtractIterfromExpr : public air::ir::IRVisitor {
......@@ -146,16 +138,23 @@ class ExtractIterfromExpr : public air::ir::IRVisitor {
std::vector<const Variable *> vec_;
};
class ReplaceLoopVar : public air::ir::IRMutator {
public:
explicit ReplaceLoopVar(VarMap v_) : var_map(std::move(v_)) {}
~ReplaceLoopVar() override = default;
Expr Mutate_(const Variable *op, const Expr &e) final {
for (auto &i : var_map) {
if (op->name_hint == i.first.get_name()) {
return i.second;
}
}
return e;
}
private:
VarMap var_map;
};
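// Typical use in the emitter (see EmitUserStmtContent above): substitute the
// iterator mapping into a Halide expression, e.g.
//   Expr replaced = ReplaceLoopVar(var_map_).Mutate(expr);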
} // namespace poly
} // namespace ir
} // namespace akg
......
......@@ -13,21 +13,19 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include "poly/pass_info.h"
namespace akg {
namespace ir {
namespace poly {}  // namespace poly
} // namespace ir
} // namespace akg
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef POLY_PASS_INFO_H_
#define POLY_PASS_INFO_H_
#include <vector>
#include <map>
#include <unordered_map>
#include "isl.h"
namespace akg {
namespace ir {
namespace poly {
using ReduceStmtMap = std::unordered_map<isl::id, std::vector<std::string>, isl::IslIdIslHash>;
class Dependency {
private:
isl::id start_node_id_;
isl::id end_node_id_;
int64_t edge_weight_;
public:
Dependency(const isl::id start_node_id, const isl::id end_node_id, const int64_t edge_weight)
: start_node_id_(start_node_id), end_node_id_(end_node_id), edge_weight_(edge_weight) {}
~Dependency() {}
isl::id GetStartNode() { return start_node_id_; }
isl::id GetEndNode() { return end_node_id_; }
int64_t GetEdgeWeight() const { return edge_weight_; }
};
// pass info on schedule transform
class PassInfo {
public:
PassInfo() {}
~PassInfo() {}
bool has_grouped_{false};
bool tile_check_coincident_{false};
std::unordered_map<isl::id, isl::union_set_list, isl::IslIdIslHash> group_filter_map_;
std::vector<Dependency> dependency_list_;
isl::union_pw_multi_aff group_upma_;
isl::schedule_constraints constraints_;
bool coincident_{true};
isl::union_map dependences_;
isl::union_map orig_dependences_;
isl::union_set transfer_stmt_;
ReduceStmtMap reduce_stmts_;
std::map<std::string, int> invariant_state_;
bool has_invariant_dependence_{false};
bool restart_{false};
};
} // namespace poly
} // namespace ir
} // namespace akg
#endif // POLY_PASS_INFO_H_
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef POLY_PASS_MGR_STRATEGY_H_
#define POLY_PASS_MGR_STRATEGY_H_
#include "poly/schedule_pass.h"
#include "poly/dump_log.h"
#include "poly/schedule_pass/init_schedule.h"
#include "poly/schedule_pass/compute_schedule.h"
namespace akg {
namespace ir {
namespace poly {
class PassMgrStrategy {
public:
explicit PassMgrStrategy(ScopInfo &scop_info) : scop_info_(scop_info) {}
void RegisterPass(std::shared_ptr<SchedulePass> pass) {
CHECK(pass);
passes_.emplace_back(std::move(pass));
}
void RegisterNormalizationPasses() { RegisterPass(std::make_shared<InitSchedule>(pass_info_, scop_info_)); }
void RegisterSchedulingPasses() { RegisterPass(std::make_shared<ComputeSchedule>(pass_info_, scop_info_)); }
  virtual void RegisterTilingPasses() = 0;   // implementation differs per backend target
  virtual void RegisterMemPromPasses() = 0;  // implementation differs per backend target
virtual void RegisterPasses() = 0;
  const std::vector<std::shared_ptr<SchedulePass>> &GetPasses() const { return passes_; }
virtual ~PassMgrStrategy() = default;
ScopInfo &scop_info_;
PassInfo pass_info_;
protected:
std::vector<std::shared_ptr<SchedulePass>> passes_;
};
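// A minimal sketch (kept as a comment) of a backend-specific strategy built on
// the interface above; MyBackendStrategy and the bodies of the Register*Passes
// overrides are hypothetical, only the Register* methods shown above are real.
// class MyBackendStrategy : public PassMgrStrategy {
//  public:
//   explicit MyBackendStrategy(ScopInfo &scop_info) : PassMgrStrategy(scop_info) {}
//   void RegisterTilingPasses() override { /* backend tiling passes */ }
//   void RegisterMemPromPasses() override { /* backend memory promotion passes */ }
//   void RegisterPasses() override {
//     RegisterNormalizationPasses();
//     RegisterSchedulingPasses();
//     RegisterTilingPasses();
//     RegisterMemPromPasses();
//   }
// };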
} // namespace poly
} // namespace ir
} // namespace akg
#endif // POLY_PASS_MGR_STRATEGY_H_
......@@ -13,15 +13,8 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <memory>
#include "ir_pass.h"
#include "poly/scop.h"
#include "pass/utils.h"
namespace akg {
namespace ir {
/*!
......@@ -31,63 +24,72 @@ class Poly {
public:
Poly() : isl_ctx_(isl::ctx(isl_ctx_alloc())) {}
~Poly() noexcept {
scop_.reset();
// scop must be deconstructed before isl_ctx is deconstructed
isl_ctx_free(isl_ctx_.get());
}
void Run(const Stmt &stmt, const Map<Tensor, Buffer> &extern_buffer, const Map<std::string, NodeRef> &attrs,
const bool is_spec_gemm, bool is_tuning, bool is_dynamic) {
stmt_ = stmt;
    scop_.reset(new poly::Scop(Simplify_cce(stmt_), isl_ctx_));
    CHECK(scop_ != nullptr);
    scop_->ParseUserConfig(attrs, extern_buffer, is_spec_gemm, is_tuning, is_dynamic);

    std::chrono::high_resolution_clock::time_point timer_start;
    // generate isl schedule from Halide
TIMER_START;
isl::schedule sch = scop_->GenIsl();
TIMER_SHOW("GenIsl", std::string(is_spec_gemm ? "_specgemm" : ""));
// isl schedule transform
TIMER_START;
isl::schedule sched = scop_->Transform(sch);
TIMER_SHOW("Transform", std::string(is_spec_gemm ? "_specgemm" : ""));
// generate Halide from isl schedule
TIMER_START;
stmt_ = scop_->GenHalide(sched);
TIMER_SHOW("GenHalide", std::string(is_spec_gemm ? "_specgemm" : ""));
if (is_dynamic) stmt_ = RestoreCombinedParams(stmt_, scop_->info_);
if (is_tuning) {
spaces_ = GenerateTilingSpace(sched, scop_->info_, stmt_, scop_->info_.user_config_.GetDumpTuningLevel());
return;
}
// optimize post poly Halide IR for Davinci
if (scop_->info_.user_config_.GetEnableFeatureLib() || scop_->info_.user_config_.GetOptimizeForDavinci()) {
stmt_ = poly::DavinciHalideOptimizer(stmt_, !scop_->info_.user_config_.GetParams().empty());
}
gen_empty_tiling = scop_->info_.analysis_result_.GetIsTiled();
}
Stmt GetStmt() { return stmt_; }
NodeRef GetSpaces() { return spaces_; }
Array<Var> GetTilingParams() {
CHECK(scop_ != nullptr);
Array<Var> tiling_params_array;
if (gen_empty_tiling) return tiling_params_array;
std::unordered_set<Var, NodeHash, NodeEqual> tiling_params;
auto param_tiling_map = scop_->info_.user_config_.GetParamTilingMap();
for (const auto &kv : param_tiling_map) {
GatherVars(kv.second, &tiling_params);
}
for (const auto &param : tiling_params) tiling_params_array.push_back(param);
return tiling_params_array;
}
void GatherVars(const Expr expr, std::unordered_set<Var, air::NodeHash, air::NodeEqual> *vset) {
PostOrderVisit(expr, [&vset](const NodeRef &node) {
if (node.as<Variable>()) {
vset->insert(Downcast<Var>(node));
}
});
}
private:
......@@ -96,6 +98,8 @@ class Poly {
// and we need to ensure that they are deconstructed before the isl_ctx is freed.
isl::ctx isl_ctx_;
Stmt stmt_;
NodeRef spaces_;
bool gen_empty_tiling{false};
};
/// Interface for lower pass
......@@ -103,14 +107,14 @@ Array<NodeRef> AutoPoly(const Stmt &stmt, const Map<Tensor, Buffer> &extern_buff
const Map<std::string, NodeRef> &attrs, const bool is_specgemm, const bool is_dynamic) {
Poly poly;
poly.Run(stmt, extern_buffer, attrs, is_specgemm, false, is_dynamic);
return Array<NodeRef>({poly.GetStmt(), poly.GetTilingParams()});
}
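// Usage sketch for callers of AutoPoly: the returned array holds the
// transformed stmt and the tiling parameters (variable names below are
// illustrative assumptions).
//   Array<NodeRef> res = AutoPoly(stmt, extern_buffer, attrs, false, false);
//   Stmt poly_stmt = Downcast<Stmt>(res[0]);
//   Array<Var> tiling_params = Downcast<Array<Var>>(res[1]);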
NodeRef GenTuningSpace(const Stmt &stmt, const Map<Tensor, Buffer> &extern_buffer,
const Map<std::string, NodeRef> &attrs, const bool is_specgemm) {
Poly poly;
poly.Run(stmt, extern_buffer, attrs, is_specgemm, true, false);
return poly.GetSpaces();
}
} // namespace ir
} // namespace akg
......@@ -13,6 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "poly/poly_util.h"
namespace akg {
......@@ -120,6 +121,65 @@ Stmt PeelOuterLetStmt(const Stmt &s, std::vector<Stmt> &outer_stmts) {
return body;
}
void GetAffOffsetAndNumVars(const isl::aff &aff, int &offset, int &num_vars) {
offset = aff.get_constant_val().get_num_si();
num_vars = 0;
int dim = isl_aff_dim(aff.get(), isl_dim_in);
CHECK_GE(dim, 0);
for (int j = 0; j < dim; ++j) {
isl_val *coef = isl_aff_get_coefficient_val(aff.get(), isl_dim_in, j);
int coef_val = isl_val_get_num_si(coef);
static_cast<void>(isl_val_free(coef));
if (coef_val != 0) ++num_vars;
}
}
/*
* Check the isl::aff is in the form of { [i0, i1, i2, i3, i4] -> [(-64 + i2)] }
* i.e. the mapping is one variable plus a non-zero constant offset.
*/
bool IsAffVarPlusOffset(const isl::aff &aff) {
int offset = 0, num_vars = 0;
GetAffOffsetAndNumVars(aff, offset, num_vars);
return offset != 0 && num_vars == 1;
}
/*
* Check the isl::aff is in the form of { [i0, i1, i2, i3, i4] -> [(64)] }
* i.e. the mapping is a non-zero constant.
*/
bool IsAffNonZeroConst(const isl::aff &aff) {
int offset = 0, num_vars = 0;
GetAffOffsetAndNumVars(aff, offset, num_vars);
return offset != 0 && num_vars == 0;
}
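// For example, for aff = { [i0, i1, i2, i3, i4] -> [(-64 + i2)] },
// GetAffOffsetAndNumVars yields offset = -64 and num_vars = 1, so
// IsAffVarPlusOffset(aff) is true and IsAffNonZeroConst(aff) is false.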
isl::union_map LocalScheduleImpl(const isl::schedule_node &node, bool use_node) {
int tree_depth = node.get_tree_depth();
int new_tree_depth = tree_depth;
if (use_node) ++new_tree_depth;
isl::schedule_node tmp_node;
isl::union_map schedule = isl::union_map::from_domain(node.get_domain());
for (int i = 0; i < new_tree_depth; ++i) {
tmp_node = node.ancestor(tree_depth - i);
if (auto band_node = tmp_node.as<isl::schedule_node_band>()) {
if (band_node.n_member() > 0) {
schedule = schedule.flat_range_product(band_node.get_partial_schedule_union_map());
}
} else if (auto filter_node = tmp_node.as<isl::schedule_node_filter>()) {
schedule = schedule.intersect_domain(filter_node.get_filter());
} else if (auto extension_node = tmp_node.as<isl::schedule_node_extension>()) {
schedule = schedule.unite(extension_node.get_extension().reverse().intersect_range(schedule.range()));
}
}
return schedule;
}
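// ShortSchedule collects the partial schedules of the ancestors of "node";
// LocalSchedule additionally includes the partial schedule of "node" itself.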
isl::union_map ShortSchedule(const isl::schedule_node &node) { return LocalScheduleImpl(node, false); }
isl::union_map LocalSchedule(const isl::schedule_node &node) { return LocalScheduleImpl(node, true); }
} // namespace poly
} // namespace ir
} // namespace akg
......@@ -15,12 +15,9 @@
*/
#ifndef POLY_UTIL_H_
#define POLY_UTIL_H_
#pragma once
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <ir_pass.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_mutator.h>
#include <chrono>
#include "isl.h"
namespace akg {
......@@ -31,28 +28,26 @@ namespace poly {
#define PRETTY_PRINT_IR true
#define DUMP_SCOP_DATA true
#define DUMP_SCOP_DATA_PER_PASS false
#define DUMP_TRANSFORM true
#define DUMP_TRANSFORM_PER_PASS false
#define DUMP_IN_CURRENT_DIR false
#define PRINT_C false
#define PRINT_SCHEDULE_INFO false
#define PRINT_ISL_EMMITER false
#define PRINT_CCE_ISL_EMMITER false
#define PRINT_EMMITER (PRINT_ISL_EMMITER || PRINT_CCE_ISL_EMMITER)
#define SPEC_GEMM true
#define DELETE_FRACTAL true
/// conv_backward options
#define SELECT_DOMAIN_OPT true
/// transform options
#define USE_CACHED_SCHEDULE false
#define ENABLE_REPLACE_SCHEDULE_HOOK true
/// constants
constexpr auto kReadSuffix = "read";
constexpr auto kWriteSuffix = "write";
constexpr auto kIterNamePrefix = "cc";
constexpr auto kGemmIterNamePrefix = "ee";
constexpr auto TENSORLISTTAILNAME = "TensorListTail";
// timer records
#define TIMER_START timer_start = std::chrono::high_resolution_clock::now()
#define TIMER_DURATION \
(std::chrono::duration_cast<std::chrono::duration<double>>(std::chrono::high_resolution_clock::now() - timer_start) \
.count()) * \
1000
#define TIMER_SHOW(NAME, SPEC_GEMM) \
{ LOG(INFO) << "[ Polyhedral exec time" << SPEC_GEMM << " ], " << NAME << " spent " << TIMER_DURATION << " ms"; }
unsigned int WrappedStrtol(const std::string &str);
......@@ -68,6 +63,12 @@ Expr RemoveCast(Expr e);
Stmt PeelOuterLetStmt(const Stmt &s, std::vector<Stmt> &outer_stmts);
isl::union_map ShortSchedule(const isl::schedule_node &node);
isl::union_map LocalSchedule(const isl::schedule_node &node);
void GetAffOffsetAndNumVars(const isl::aff &aff, int &offset, int &num_vars);
bool IsAffVarPlusOffset(const isl::aff &aff);
bool IsAffNonZeroConst(const isl::aff &aff);
class ConsolidateExprMutator : public IRMutator {
public:
explicit ConsolidateExprMutator(const std::unordered_map<std::string, Var> &params_) : params(params_) {}
......@@ -86,15 +87,15 @@ class ConsolidateExprMutator : public IRMutator {
}
// list operators that may appear in dynamic shape params
Expr Mutate_(const Add *op, const Expr &e) override { return GenericMutate(op, e); }
Expr Mutate_(const Sub *op, const Expr &e) override { return GenericMutate(op, e); }
Expr Mutate_(const Mul *op, const Expr &e) override { return GenericMutate(op, e); }
Expr Mutate_(const FloorDiv *op, const Expr &e) override { return GenericMutate(op, e); }
Expr Mutate_(const FloorMod *op, const Expr &e) override { return GenericMutate(op, e); }
Expr Mutate_(const Div *op, const Expr &e) override { return GenericMutate(op, e); }
Expr Mutate_(const Mod *op, const Expr &e) override { return GenericMutate(op, e); }
Expr Mutate_(const Min *op, const Expr &e) override { return GenericMutate(op, e); }
Expr Mutate_(const Max *op, const Expr &e) override { return GenericMutate(op, e); }
const std::unordered_map<std::string, Var> &params;
};
......@@ -168,6 +169,9 @@ constexpr auto ATTR_GEMM_WEIGHT_TRANSPOSE_BLOCK_INNER = "pragma_weight_transpose
constexpr auto ATTR_ATOMIC_ADD = "atomic_add";
constexpr auto ATOMIC_COND_CLEAN = "atomic_cond_clean";
constexpr auto UBL0 = "UBL0";
constexpr auto REALIZE_ = "realize_";
/******************************************************
* Following const is the mark tags for schedule tree
******************************************************/
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "schedule_pass.h"
#include <climits>
#include <fstream>
namespace akg {
namespace ir {
namespace poly {
/* Reorder filters of a sequence/set node.
* node: must be a sequence or set node.
* old_to_new_map: map from original child position to new child position.
* The caller should make sure that there are no duplicate values.
*/
isl::schedule_node ReorderFilters(const isl::schedule_node &node,
const std::unordered_map<size_t, size_t> &old_to_new_map) {
auto n_children = node.n_children();
isl_schedule_tree *old_tree = isl_schedule_node_get_tree(node.get());
CHECK(old_tree != nullptr);
isl_schedule_tree *new_tree = isl_schedule_node_get_tree(node.get());
CHECK(new_tree != nullptr);
for (auto &it : old_to_new_map) {
auto old_pos = it.first;
auto new_pos = it.second;
CHECK(old_pos < n_children);
CHECK(new_pos < n_children);
isl_schedule_tree *old_child = isl_schedule_tree_get_child(old_tree, old_pos);
CHECK(old_child != nullptr);
new_tree = isl_schedule_tree_replace_child(new_tree, new_pos, old_child);
CHECK(new_tree != nullptr);
}
static_cast<void>(isl_schedule_tree_free(old_tree));
isl_schedule_node *new_node = isl_schedule_node_graft_tree(node.copy(), new_tree);
CHECK(new_node != nullptr);
return isl::manage(new_node);
}
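// A hedged usage sketch (not part of this file): swap the first two filters of
// a sequence node; "seq_node" is assumed to point at the sequence/set node.
//   std::unordered_map<size_t, size_t> old_to_new = {{0, 1}, {1, 0}};
//   isl::schedule_node reordered = ReorderFilters(seq_node, old_to_new);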
isl::union_map DependenceAnalysis(const isl::union_map &sources, const isl::union_map &targets,
const isl::union_map &kills, const isl::union_map &sch) {
auto access_info = isl::union_access_info(targets);
access_info = access_info.set_kill(kills);
access_info = access_info.set_may_source(sources);
access_info = access_info.set_schedule_map(sch);
auto union_flow = access_info.compute_flow();
return union_flow.get_may_dependence();
}
isl::union_map ComputeAllDependences(const isl::schedule &schedule, const isl::union_map &reads_um,
const isl::union_map &writes_um) {
auto reads = reads_um.domain_factor_domain();
auto writes = writes_um.domain_factor_domain();
auto sch = schedule.get_map();
// RAW
auto flowDeps = DependenceAnalysis(writes, reads, writes, sch);
// WAR and WAW
auto falseDeps = DependenceAnalysis(writes.unite(reads), writes, writes, sch);
return flowDeps.unite(falseDeps).coalesce();
}
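// Usage sketch: the scheduler constraints are typically derived from this
// union, e.g.
//   pass_info.dependences_ = ComputeAllDependences(sch, reads, writes);
// before MakeScheduleConstraints(sch, pass_info) is called (see below).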
isl::schedule_node GetOuterBand(const isl::schedule_node &root) {
auto outer_band = root;
while (!outer_band.isa<isl::schedule_node_band>()) {
auto n = outer_band.n_children();
if (n == 1) {
outer_band = outer_band.child(0);
continue;
} else {
/*
* return the node when encountered branching or a leaf
* an empty band would be inserted elsewhere
*/
return outer_band;
}
}
return outer_band;
}
bool IsSequenceOrSet(const isl::schedule_node &node) {
if (node.isa<isl::schedule_node_sequence>()) return true;
return node.isa<isl::schedule_node_set>();
}
isl::union_map ComputeFilterCopyin(const isl::schedule_node &node, const isl::union_map &ori_reads,
const isl::union_map &ori_writes, const isl::schedule ori_schedule) {
CHECK(node.isa<isl::schedule_node_filter>()) << "The input should be a filter node!" << std::endl;
auto filter = node.as<isl::schedule_node_filter>().get_filter();
auto reads = ori_reads.domain_factor_domain().intersect_domain(filter);
auto writes = ori_writes.domain_factor_domain().intersect_domain(filter);
auto uai = isl::union_access_info(reads);
uai = uai.set_kill(writes);
uai = uai.set_may_source(writes);
uai = uai.set_schedule(ori_schedule);
auto flow = uai.compute_flow();
auto mayNoSource = flow.get_may_no_source();
auto copyin = ori_reads.intersect_range(mayNoSource.range());
return copyin;
}
isl::union_map ComputeFakeCopyin(const isl::schedule &schedule, const isl::union_map &fake_copyin,
const isl::union_map &ori_reads, const isl::union_map &ori_writes) {
auto root = schedule.get_root();
auto node = GetOuterBand(root);
auto result = fake_copyin;
if (!IsSequenceOrSet(node)) return result;
auto n = node.n_children();
for (auto i = 0u; i < n; ++i) {
auto child = node.child(i);
auto copyin = ComputeFilterCopyin(child, ori_reads, ori_writes, schedule);
result = result.unite(copyin);
}
return result;
}
isl::schedule_constraints MakeScheduleConstraints(const isl::schedule &schedule, PassInfo &pass_info) {
isl::schedule_constraints constraints;
if (pass_info.coincident_) {
constraints = isl::schedule_constraints::on_domain(schedule.get_domain())
.set_coincidence(pass_info.dependences_) // keep it, check for more cases
.set_validity(pass_info.dependences_)
.set_proximity(pass_info.dependences_);
} else {
constraints = isl::schedule_constraints::on_domain(schedule.get_domain())
.set_validity(pass_info.dependences_)
.set_proximity(pass_info.dependences_);
}
return constraints;
}
/*
* Merge multiple lines of strings into a single-line string
*/
static std::string UndoPrettyPrintSchTree(const std::string &schedule) {
const char *src = schedule.c_str();
std::stringstream dst;
bool in_string = false;
while (*src != '\0') {
if (*src == '"') {
in_string = !in_string;
if (!in_string) {
// end of string, find next non-empty char
const char *next = src + 1;
while (*next != '\0') {
char c = *next;
if (c != ' ' && c != '\t' && c != '\n' && c != '\r') {
break;
}
++next;
}
if (*next == '"') {
// multiple consecutive strings, merge them and insert a white space
dst << " ";
src = next + 1;
in_string = true;
continue;
}
}
}
dst << *src++;
}
return dst.str();
}
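// For example, the pretty-printed fragment
//   "domain:"
//   "{ S_0[i] : 0 <= i < 16 }"
// is merged into the single string "domain: { S_0[i] : 0 <= i < 16 }".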
bool LoadScheduleTreeFromFile(const std::string &filename, isl::schedule &schedule) {
std::ifstream new_schedule_file_stream(filename);
std::string schedule_to_replace_str((std::istreambuf_iterator<char>(new_schedule_file_stream)),
std::istreambuf_iterator<char>());
schedule_to_replace_str = UndoPrettyPrintSchTree(schedule_to_replace_str);
isl_schedule *ss = isl_schedule_read_from_str(schedule.ctx().get(), schedule_to_replace_str.c_str());
if (ss != nullptr) {
schedule = isl::manage(ss);
return true;
} else {
LOG(WARNING) << "Failed to load file " << filename << " to schedule tree, please check syntax of the new schedule.";
return false;
}
}
/*
* Compare and replace schedule hook:
* Enable users to replace a specific schedule for debugging purpose.
* If the current schedule is identical to the schedule in OLD_SCHEDULE_FILE,
* the schedule will be replaced with NEW_SCHEDULE_FILE.
*/
bool ReplaceScheduleTree(isl::schedule &schedule, ScopInfo &info) {
const std::string OLD_SCHEDULE_FILE = info.AddDumpDir("old_schedule.txt");
const std::string NEW_SCHEDULE_FILE = info.AddDumpDir("new_schedule.txt");
// check if two files exist
char pathBuffOld[PATH_MAX + 1] = {0};
char pathBuffNew[PATH_MAX + 1] = {0};
bool should_compare_and_replace = false;
if (realpath(OLD_SCHEDULE_FILE.c_str(), pathBuffOld) && realpath(NEW_SCHEDULE_FILE.c_str(), pathBuffNew)) {
FILE *schedule_to_compare = fopen(pathBuffOld, "r");
FILE *schedule_to_replace = fopen(pathBuffNew, "r");
should_compare_and_replace = (schedule_to_compare != nullptr && schedule_to_replace != nullptr);
if (schedule_to_compare != nullptr) {
int status = fclose(schedule_to_compare);
if (status != 0) LOG(WARNING) << "Failed to close old_schedule.txt";
}
if (schedule_to_replace != nullptr) {
int status = fclose(schedule_to_replace);
if (status != 0) LOG(WARNING) << "Failed to close new_schedule.txt";
}
}
if (should_compare_and_replace) {
std::ifstream old_schedule_file_stream(OLD_SCHEDULE_FILE);
std::string schedule_to_compare_str((std::istreambuf_iterator<char>(old_schedule_file_stream)),
std::istreambuf_iterator<char>());
if (CompareSchTreeWithString(schedule_to_compare_str, schedule)) {
LOG(INFO) << "Current schedule is same as " << OLD_SCHEDULE_FILE << ", replace it with new schedule "
<< NEW_SCHEDULE_FILE;
CHECK(LoadScheduleTreeFromFile(NEW_SCHEDULE_FILE, schedule));
return true;
} else {
LOG(INFO) << "Current schedule is different from " << OLD_SCHEDULE_FILE << ", not replacing.";
}
}
return false;
}
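// Debug workflow sketch (assumed, based on the checks above): dump the current
// schedule to old_schedule.txt in the dump directory, copy it to
// new_schedule.txt, hand-edit the copy, and re-run; the edited tree replaces
// the computed one only while the dumped and current trees still match.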
} // namespace poly
} // namespace ir
} // namespace akg
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef POLY_PASS_H_
#define POLY_PASS_H_
#include "poly/isl.h"
#include "poly/scop_info.h"
#include "poly/pass_info.h"
namespace akg {
namespace ir {
namespace poly {
class SchedulePass {
public:
virtual ~SchedulePass() {}
virtual isl::schedule Run(isl::schedule sch) = 0;
std::string GetPassName() { return pass_name_; }
std::string pass_name_;
bool restart_{false}; // triggers restart during runtime
};
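// A minimal sketch (as a comment) of a concrete pass implementing this
// interface; the IdentityPass name is hypothetical and not part of AKG.
// class IdentityPass : public SchedulePass {
//  public:
//   IdentityPass() { pass_name_ = __FUNCTION__; }
//   ~IdentityPass() override = default;
//   isl::schedule Run(isl::schedule sch) override { return sch; }
// };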
isl::schedule_node ReorderFilters(const isl::schedule_node &node,
const std::unordered_map<size_t, size_t> &old_to_new_map);
isl::union_map DependenceAnalysis(const isl::union_map &sources, const isl::union_map &targets,
const isl::union_map &kills, const isl::union_map &sch);
isl::union_map ComputeAllDependences(const isl::schedule &schedule, const isl::union_map &reads_um,
const isl::union_map &writes_um);
isl::schedule_node GetOuterBand(const isl::schedule_node &root);
bool IsSequenceOrSet(const isl::schedule_node &node);
/*
* Compute copyin for each filter node, by intersecting the domains of
* reads and writes of the entire scop.
*/
isl::union_map ComputeFilterCopyin(const isl::schedule_node &node, const isl::union_map &ori_reads,
const isl::union_map &ori_writes, const isl::schedule ori_schedule);
bool CompareSchTreeWithString(const std::string &compare_sch, const isl::schedule &sch);
isl::schedule_constraints MakeScheduleConstraints(const isl::schedule &schedule, PassInfo &pass_info);
isl::union_map RemoveReduceOpSelfDependence(ScopInfo &scop_info, PassInfo &pass_info);
isl::union_map RemoveSelfDependence(PassInfo &pass_info);
isl::union_map RemoveInvariantDependence(const isl::schedule &schedule, PassInfo &pass_info);
/*
* Compute copyin for each filter and return the union of such copyins.
* In particular, return an empty result when the outermost band node
* is not a sequence/set node.
*
* "result" is the union of "copyin" from each filter node, which in
* turn is computed by ComputeFilterCopyin.
*/
isl::union_map ComputeFakeCopyin(const isl::schedule &schedule, const isl::union_map &fake_copyin,
const isl::union_map &ori_reads, const isl::union_map &ori_writes);
} // namespace poly
} // namespace ir
} // namespace akg
#endif // POLY_PASS_H_
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "change_marknode_position.h"
namespace akg {
namespace ir {
namespace poly {
isl::schedule ChangeMarkNodePosition::Run(isl::schedule curr_schedule) {
std::unordered_set<std::string> ids = with_stmts_ids_;
if (ids.empty()) {
return curr_schedule;
}
auto fn = [&ids](isl::schedule_node node) -> isl::schedule_node {
if (node.isa<isl::schedule_node_mark>()) {
std::string mark_id = node.as<isl::schedule_node_mark>().get_id().get_name();
if (mark_id == "realize_UB" && node.child(0).isa<isl::schedule_node_band>()) {
if (node.child(0).child(0).isa<isl::schedule_node_sequence>()) {
node = node.get_child(0).get_child(0); // sequence
bool delete_outer_mark = true;
int n = node.n_children();
for (int i = 0; i < n; i++) {
isl::schedule_node_filter filter_node = node.child(i).as<isl::schedule_node_filter>();
bool is_not_with_stmt = filter_node.get_filter().every_set(
[&ids](const isl::set &s) -> bool { return (ids.count(s.get_tuple_name()) == 0); });
if (is_not_with_stmt) {
delete_outer_mark = false;
} else {
node = node.child(i).child(0);
node = node.insert_mark(isl::id(node.ctx(), mark_id));
node = node.parent().parent();
}
}
node = node.parent().parent();
if (delete_outer_mark) {
node = node.del();
}
}
}
}
return node;
};
return curr_schedule.get_root().map_descendant_bottom_up(fn).get_schedule();
}
} // namespace poly
} // namespace ir
} // namespace akg
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef POLY_CHANGE_MARKNODE_POSITION_H_
#define POLY_CHANGE_MARKNODE_POSITION_H_
#include "poly/schedule_pass.h"
#include <unordered_set>
namespace akg {
namespace ir {
namespace poly {
/*
* "with" stmt aims to work around the irregular problem.
* By default, the "realize_UB" mark is on the outer band. However, for tensor-of-tensor,
* the intermediate tensor may be too large if realized in the outermost scope.
* To narrow down the scope, we move "realize_UB" mark to the filter node.
* If all filter nodes of the band are "with" stmts, we remove the outer "realize_UB" mark.
*/
class ChangeMarkNodePosition : public SchedulePass {
public:
ChangeMarkNodePosition(const std::unordered_set<std::string> &with_stmts_ids) : with_stmts_ids_(with_stmts_ids) {
pass_name_ = __FUNCTION__;
};
~ChangeMarkNodePosition(){};
virtual isl::schedule Run(isl::schedule sch);
private:
std::unordered_set<std::string> with_stmts_ids_;
};
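// Illustration (simplified tree, assumed shapes): when every filter below the
// band is a "with" stmt,
//   mark "realize_UB" -> band -> sequence -> [filter S0, filter S1]
// becomes
//   band -> sequence -> [filter S0 -> mark "realize_UB" -> ...,
//                        filter S1 -> mark "realize_UB" -> ...]
// If any filter is not a "with" stmt, the outer mark is kept as well.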
} // namespace poly
} // namespace ir
} // namespace akg
#endif // POLY_CHANGE_MARKNODE_POSITION_H_
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "compute_inner_band_dependency.h"
namespace akg {
namespace ir {
namespace poly {
isl::schedule ComputeInnerBandDependency::Run(isl::schedule sch) {
auto ori_reads = scop_info_.analysis_result_.GetReads();
auto ori_writes = scop_info_.analysis_result_.GetWrites();
auto ori_fake_copyin = scop_info_.analysis_result_.GetFakeCopyin();
auto inner_band_dependency =
ComputeFakeCopyin(sch, ori_fake_copyin, ori_reads, ori_writes).subtract(scop_info_.analysis_result_.GetCopyin());
scop_info_.analysis_result_.RecordInnerBandDependency(inner_band_dependency);
return sch;
}
} // namespace poly
} // namespace ir
} // namespace akg
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef POLY_COMPUTE_INNER_BAND_DEPENDENCY_H_
#define POLY_COMPUTE_INNER_BAND_DEPENDENCY_H_
#include "poly/schedule_pass.h"
#include "poly/scop_info.h"
namespace akg {
namespace ir {
namespace poly {
/*
 * This class initialises the inner band dependency information used in the InjectMulticoreToSchedule pass
 * and records it in the scop info. No actual schedule tree transform is performed in this pass.
*/
class ComputeInnerBandDependency : public SchedulePass {
public:
ComputeInnerBandDependency(ScopInfo &scop_info) : scop_info_(scop_info) { pass_name_ = __FUNCTION__; }
~ComputeInnerBandDependency() {}
virtual isl::schedule Run(isl::schedule sch);
private:
ScopInfo &scop_info_;
};
} // namespace poly
} // namespace ir
} // namespace akg
#endif // POLY_COMPUTE_INNER_BAND_DEPENDENCY_H_
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "compute_schedule.h"
namespace akg {
namespace ir {
namespace poly {
isl::union_map ComputeSchedule::ModDependences(const isl::union_map &dependences) {
isl::union_map umap = isl::union_map::empty(dependences.ctx());
dependences.foreach_map([&](const isl::map &m) -> void {
isl::map mm = m;
if (mm.get_tuple_id(isl_dim_in) != mm.get_tuple_id(isl_dim_out)) {
isl_map *pmap = mm.copy();
int n_in = isl_map_dim(pmap, isl_dim_in);
for (int i = 0; i < n_in; ++i) {
pmap = isl_map_plain_update_val_if_fixed(pmap, isl_dim_in, i);
}
mm = isl::manage(pmap);
}
umap = umap.unite(isl::union_map(mm));
});
return umap;
}
void ComputeSchedule::SetIslOptions() {
auto ctx = pass_info_.constraints_.ctx().get();
int status = isl_options_set_schedule_unit_max_var_coefficient_sum(ctx, 1);
CHECK(status == isl_stat_ok);
if (scop_info_.user_config_.GetComputeReschedule()) {
status = isl_options_set_schedule_whole_component(ctx, 0);
CHECK(status == isl_stat_ok);
} else {
status = isl_options_set_schedule_maximize_coincidence(ctx, 0);
CHECK(status == isl_stat_ok);
status = isl_options_set_schedule_whole_component(ctx, 1);
CHECK(status == isl_stat_ok);
}
if (scop_info_.user_config_.GetDisableScheduleShift()) {
status = isl_options_set_schedule_max_constant_term(ctx, 0);
CHECK(status == isl_stat_ok);
status = isl_options_set_schedule_nonneg_var_coefficient(ctx, 1);
CHECK(status == isl_stat_ok);
}
if (scop_info_.user_config_.GetEnableScheduleMaxConstant()) {
status = isl_options_set_schedule_max_constant_term(ctx, 0);
CHECK(status == isl_stat_ok);
}
if (scop_info_.user_config_.GetDisableLoopReversal()) {
status = isl_options_set_schedule_nonneg_var_coefficient(ctx, 1);
CHECK(status == isl_stat_ok);
}
if (scop_info_.user_config_.GetDisableLoopFusion()) {
status = isl_options_set_schedule_serialize_sccs(ctx, 1);
CHECK(status == isl_stat_ok);
}
}
isl::schedule ComputeSchedule::Run(isl::schedule sch) {
if (scop_info_.user_config_.GetModScheduleShift()) {
pass_info_.dependences_ = ModDependences(pass_info_.dependences_);
}
pass_info_.constraints_ = MakeScheduleConstraints(sch, pass_info_);
SetIslOptions();
return pass_info_.constraints_.compute_schedule();
}
} // namespace poly
} // namespace ir
} // namespace akg
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef POLY_COMPUTE_SCHEDULE_H_
#define POLY_COMPUTE_SCHEDULE_H_
#include "poly/schedule_pass.h"
namespace akg {
namespace ir {
namespace poly {
/*
 * Compute schedule pass. Its main tasks are as follows:
 * 1. modify the dependences according to the configuration switches
 * 2. generate the schedule constraints
 * 3. call the isl scheduler with these constraints to compute a new schedule
*/
class ComputeSchedule : public SchedulePass {
public:
ComputeSchedule(PassInfo &pass_info, ScopInfo &scop_info) : pass_info_(pass_info), scop_info_(scop_info) {
pass_name_ = __FUNCTION__;
}
~ComputeSchedule() {}
virtual isl::schedule Run(isl::schedule sch);
void SetIslOptions();
isl::union_map ModDependences(const isl::union_map &dependences);
private:
PassInfo &pass_info_;
ScopInfo &scop_info_;
};
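// Usage sketch (assuming pass_info.dependences_ has already been populated,
// e.g. via ComputeAllDependences):
//   ComputeSchedule pass(pass_info, scop_info);
//   isl::schedule new_sch = pass.Run(old_sch);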
} // namespace poly
} // namespace ir
} // namespace akg
#endif // POLY_COMPUTE_SCHEDULE_H_
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "compute_transfer_copyin.h"
namespace akg {
namespace ir {
namespace poly {
isl::schedule ComputeTransferCopyin::Run(isl::schedule sch) {
// compute fake copyin
auto ori_reads = scop_info_.analysis_result_.GetReads();
auto ori_writes = scop_info_.analysis_result_.GetWrites();
auto ori_fake_copyin = scop_info_.analysis_result_.GetFakeCopyin();
isl::union_map fake_copyin = ComputeFakeCopyin(sch, ori_fake_copyin, ori_reads, ori_writes);
fake_copyin = fake_copyin.subtract(scop_info_.analysis_result_.GetCopyin());
scop_info_.analysis_result_.RecordFakeCopyin(fake_copyin);
isl::union_map raw_writes = ori_writes.domain_factor_domain();
isl::union_map raw_reads = ori_reads.domain_factor_domain();
isl::union_map raw_copyin = scop_info_.analysis_result_.GetCopyin().domain_factor_domain();
isl::union_map reads = fake_copyin.domain_factor_domain();
isl::union_map transfer_copyin = fake_copyin;
while (!reads.is_empty()) {
isl::union_map writes = raw_writes.intersect_range(reads.range());
isl::union_map dependence = DependenceAnalysis(writes, reads, writes, sch.get_map());
isl::union_set stmt = dependence.domain().universe();
scop_info_.analysis_result_.RecordTransferStmt(scop_info_.analysis_result_.GetTransferStmt().unite(stmt));
reads = raw_reads.intersect_domain(stmt);
// compute transfer copyin
isl::union_map target_acc = raw_writes.intersect_domain(stmt);
isl::union_map relation = target_acc.reverse().apply_range(reads);
transfer_copyin = transfer_copyin.apply_range(relation);
isl::union_map copyin = transfer_copyin.intersect_range(raw_copyin.range().universe());
scop_info_.analysis_result_.RecordReads(scop_info_.analysis_result_.GetReads().unite(copyin));
scop_info_.analysis_result_.RecordCopyin(scop_info_.analysis_result_.GetCopyin().unite(copyin));
transfer_copyin = transfer_copyin.subtract(copyin);
reads = reads.subtract(raw_copyin);
reads = reads.subtract(fake_copyin.domain_factor_domain());
}
return sch;
}
} // namespace poly
} // namespace ir
} // namespace akg
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef POLY_COMPUTE_TRANSFER_COPYIN_H_
#define POLY_COMPUTE_TRANSFER_COPYIN_H_
#include "poly/schedule_pass.h"
namespace akg {
namespace ir {
namespace poly {
/*
 * This class initialises the transfer copyin information used in the TransferStmt and MemoryPromotion passes
 * and records it in the scop info. No actual schedule tree transform is performed in this pass.
*/
class ComputeTransferCopyin : public SchedulePass {
public:
ComputeTransferCopyin(ScopInfo &scop_info, PassInfo &pass_info) : scop_info_(scop_info), pass_info_(pass_info) {
pass_name_ = __FUNCTION__;
}
~ComputeTransferCopyin() {}
virtual isl::schedule Run(isl::schedule sch);
private:
ScopInfo &scop_info_;
PassInfo &pass_info_;
};
} // namespace poly
} // namespace ir
} // namespace akg
#endif // POLY_COMPUTE_TRANSFER_COPYIN_H_
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
......@@ -13,24 +13,25 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef POLY_INSERT_NODE_FOR_ALLOCC_H_
#define POLY_INSERT_NODE_FOR_ALLOCC_H_

#include "poly/schedule_pass.h"
namespace akg {
namespace ir {
namespace poly {
class InsertNodeForAllocC : public SchedulePass {
public:
InsertNodeForAllocC() { pass_name_ = __FUNCTION__; };
~InsertNodeForAllocC(){};
virtual isl::schedule Run(isl::schedule sched);
};
} // namespace poly
} // namespace ir
} // namespace akg
#endif // POLY_INSERT_NODE_FOR_ALLOCC_H_