提交 792d7888 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!81 Refactor the Autopoly pass to support multi-backend targets

Merge pull request !81 from anyrenwei/master
...@@ -28,6 +28,9 @@ set(AKG_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}") ...@@ -28,6 +28,9 @@ set(AKG_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}")
include(cmake/RT.cmake) include(cmake/RT.cmake)
include(cmake/utils.cmake) include(cmake/utils.cmake)
include(cmake/external_libs/isl.cmake) include(cmake/external_libs/isl.cmake)
set(ISL_DIR "${CMAKE_BINARY_DIR}/isl")
if(ENABLE_AKG) if(ENABLE_AKG)
message("-- Build akg in Mindspore") message("-- Build akg in Mindspore")
execute_process(COMMAND bash ${AKG_SOURCE_DIR}/third_party/apply_patches.sh ${CMAKE_CURRENT_BINARY_DIR} "1") execute_process(COMMAND bash ${AKG_SOURCE_DIR}/third_party/apply_patches.sh ${CMAKE_CURRENT_BINARY_DIR} "1")
...@@ -43,8 +46,6 @@ else() ...@@ -43,8 +46,6 @@ else()
set(UNITTEST_DIR "${AKG_SOURCE_DIR}/tests/unittest_cpp") set(UNITTEST_DIR "${AKG_SOURCE_DIR}/tests/unittest_cpp")
endif() endif()
set(ISL_DIR "${CMAKE_BINARY_DIR}/isl")
file(COPY ${AKG_SOURCE_DIR}/python/akg DESTINATION file(COPY ${AKG_SOURCE_DIR}/python/akg DESTINATION
${CMAKE_CURRENT_BINARY_DIR}) ${CMAKE_CURRENT_BINARY_DIR})
...@@ -175,6 +176,8 @@ file( ...@@ -175,6 +176,8 @@ file(
${TVM_DIR}/src/runtime/vm/profiler/*.cc ${TVM_DIR}/src/runtime/vm/profiler/*.cc
${TVM_DIR}/src/codegen/stackvm/*.cc ${TVM_DIR}/src/codegen/stackvm/*.cc
${AKG_SOURCE_DIR}/src/poly/*.cc ${AKG_SOURCE_DIR}/src/poly/*.cc
${AKG_SOURCE_DIR}/src/poly/schedule_pass/*.cc
${AKG_SOURCE_DIR}/src/poly/tiling/*.cc
${AKG_SOURCE_DIR}/src/api/*.cc ${AKG_SOURCE_DIR}/src/api/*.cc
${AKG_SOURCE_DIR}/src/pass/*.cc ${AKG_SOURCE_DIR}/src/pass/*.cc
${AKG_SOURCE_DIR}/src/rpc/*.cc ${AKG_SOURCE_DIR}/src/rpc/*.cc
......
...@@ -29,7 +29,7 @@ ...@@ -29,7 +29,7 @@
#include "ir_pass.h" #include "ir_pass.h"
#include "pass/utils.h" #include "pass/utils.h"
#include "pass/expr_alg_simplify.h" #include "pass/expr_alg_simplify.h"
#include "poly/tiling_algorithm.h" #include "poly/tiling/tiling_algorithm.h"
namespace akg { namespace akg {
namespace ir { namespace ir {
......
此差异已折叠。
...@@ -16,13 +16,7 @@ ...@@ -16,13 +16,7 @@
#ifndef POLY_CCE_ISL_EMITTER_H_ #ifndef POLY_CCE_ISL_EMITTER_H_
#define POLY_CCE_ISL_EMITTER_H_ #define POLY_CCE_ISL_EMITTER_H_
#include <tvm/expr.h>
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include "ir_pass.h" #include "ir_pass.h"
#include "isl.h"
#include "scop.h"
#include "isl_emitter.h" #include "isl_emitter.h"
namespace akg { namespace akg {
...@@ -39,12 +33,15 @@ class Liveness { ...@@ -39,12 +33,15 @@ class Liveness {
std::vector<IslIdSet> read_; std::vector<IslIdSet> read_;
std::vector<IslIdSet> write_; std::vector<IslIdSet> write_;
}; };
// Selector type for the `atomic` argument of CCEIslEmitter::EmitWrite.
// NOTE(review): Equ presumably denotes a plain store and Add an accumulating store — confirm against EmitWrite.
enum AtomicType { Equ = 0, Add };
/*! /*!
* IslEmitter for CCE * IslEmitter for CCE
*/ */
class CCEIslEmitter : public IslEmitter { class CCEIslEmitter : public IslEmitter {
public: public:
CCEIslEmitter(Scop &s, const NodeInfoRepo &n, const isl::id_list &i) : IslEmitter(s, n, i) { ProcBypassL1(s); } CCEIslEmitter(ScopInfo &info, const NodeInfoRepo &n, const isl::id_list &i) : IslEmitter(info, n, i) {
ProcBypassL1(info);
}
~CCEIslEmitter() override = default; ~CCEIslEmitter() override = default;
Stmt Emit(const isl::ast_node &node) final; Stmt Emit(const isl::ast_node &node) final;
...@@ -52,7 +49,6 @@ class CCEIslEmitter : public IslEmitter { ...@@ -52,7 +49,6 @@ class CCEIslEmitter : public IslEmitter {
private: private:
// override emitters for CCE // override emitters for CCE
Stmt EmitFor(const isl::ast_node_for &node) final; Stmt EmitFor(const isl::ast_node_for &node) final;
Stmt EmitIf(const isl::ast_node_if &node) final;
Stmt EmitMark(const isl::ast_node_mark &node_id) override; Stmt EmitMark(const isl::ast_node_mark &node_id) override;
Stmt EmitBlock(const isl::ast_node_block &node) final; Stmt EmitBlock(const isl::ast_node_block &node) final;
Stmt EmitStmt(const isl::ast_node_user &node) final; Stmt EmitStmt(const isl::ast_node_user &node) final;
...@@ -60,60 +56,68 @@ class CCEIslEmitter : public IslEmitter { ...@@ -60,60 +56,68 @@ class CCEIslEmitter : public IslEmitter {
// DMA emitters for CCE // DMA emitters for CCE
Expr EmitLoad(const isl::ast_expr &lhs, Type type); Expr EmitLoad(const isl::ast_expr &lhs, Type type);
Stmt EmitL1Read(const isl::ast_node_user &node); Stmt EmitRead(const isl::ast_node_user &node);
Stmt EmitL1Write(const isl::ast_node_user &node, Scop::AtomicType atomic); Stmt EmitWrite(const isl::ast_node_user &node, AtomicType atomic);
Stmt EmitSpecGemL1write(const isl::ast_node_mark &node, const Stmt &stmt); Stmt EmitSpecGemL1write(const isl::ast_node_mark &node, const Stmt &stmt);
// RangeInfo emitters for CCE // emit mark node
Stmt EmitGemmRangeInfoBackPropFilter(const Stmt &stmt);
Stmt EmitGemmRangeInfo(Stmt stmt);
// multicore emitters for CCE
Stmt EmitMarkMulticore(const isl::ast_node_mark &node); Stmt EmitMarkMulticore(const isl::ast_node_mark &node);
bool InjectMulticore(const std::string &iter);
Stmt EmitMarkFuseVector(const isl::ast_node_mark &node); Stmt EmitMarkFuseVector(const isl::ast_node_mark &node);
Stmt EmitMarkAllocRealizeOut(const isl::ast_node_mark &node); Stmt EmitMarkAllocRealizeOut(const isl::ast_node_mark &node);
Stmt EmitMarkAllocC(const isl::ast_node_mark &node); Stmt EmitMarkAllocC(const isl::ast_node_mark &node);
Stmt EmitMarkSpecGemm(const isl::ast_node_mark &node); Stmt EmitMarkSpecGemm(const isl::ast_node_mark &node);
// emit attrs
void EmitAttrStmt(const isl::ast_node_block &block_node, const Liveness &liveness, bool is_L1, bool is_L0, void EmitAttrStmt(const isl::ast_node_block &block_node, const Liveness &liveness, bool is_L1, bool is_L0,
std::vector<Stmt> &stmts); std::vector<Stmt> &stmts);
void EmitReadAttrAtL0(std::vector<Stmt> &stmts, int i, Tensor &t);
void EmitAttrStmtL0(Tensor &t, bool &is_im2col, bool &is_filter_l0, bool &is_gemm_data_trans, void EmitReadAttrAtL1(std::vector<Stmt> &stmts, int i, Tensor &t);
bool &is_gemm_weight_trans); void EmitReadAttr(const std::vector<IslIdSet> &read, std::vector<Stmt> &stmts, int i, bool is_L1, bool is_L0);
void EmitWriteAttr(const std::vector<IslIdSet> &write, std::vector<Stmt> &stmts, int i, bool is_L1);
void EmitAttrStmtL1(Tensor &t, bool &is_fractal, bool &is_filter_l1);
void EmitAttrStmtLiveness(const Liveness &liveness, std::vector<Stmt> &stmts, int i, bool is_L1);
void EmitAttrStmtAfterRealize(bool is_L1, bool is_L0, std::vector<Stmt> &stmts); void EmitAttrStmtAfterRealize(bool is_L1, bool is_L0, std::vector<Stmt> &stmts);
Stmt EmitGemmRangeInfoBackPropFilter(const Stmt &stmt);
Stmt EmitGemmRangeInfo(Stmt stmt);
// emit realize
void EmitRealize(const isl::ast_node_block &block_node, const Liveness &liveness_info, bool is_L1, bool is_L0, void EmitRealize(const isl::ast_node_block &block_node, const Liveness &liveness_info, bool is_L1, bool is_L0,
std::vector<Stmt> &stmts); std::vector<Stmt> &stmts);
void EmitRealizeLivenessInfo(std::vector<IslIdSet> &real, const Liveness &liveness_info, // emit access
std::unordered_map<isl::id, std::set<int>, isl::IslIdIslHash> &liveness, Stmt EmitAccessNodeCall(const Node *node, const VarMap &var_map_tmp, BufferedFootPrintInfo &buffer_fp_info) override;
std::function<bool(const std::string &id)> const &CheckGoOut);
void EmitGemmRangeInfoNewAxis(std::vector<Range> &range, std::vector<std::string> &prefix,
std::unordered_map<std::string, bool> &outerAxis, Range &axisMRange,
Map<std::string, Range> &range_map, Map<std::string, VarExpr> &axis_map);
void EmitGemmRangeInfoDynamic(Range &axisMRange, Map<std::string, Range> &range_map); // tool func
void EmitGemmRangeInfoStatic(Map<std::string, Range> &range_map); bool InjectMulticore(const std::string &iter);
// realize info for CCE void CollectLiveness(const Liveness &liveness_info, bool is_L1, std::vector<IslIdSet> &real,
std::unordered_map<isl::id, std::set<int>, isl::IslIdIslHash> &liveness,
std::function<bool(const std::string &id)> const &CheckGoOut);
void CollectGemmRangeInfoNewAxis(std::vector<Range> &range, std::vector<std::string> &prefix,
std::unordered_map<std::string, bool> &outerAxis, Range &axisMRange,
Map<std::string, Range> &range_map, Map<std::string, VarExpr> &axis_map);
void CollectGemmMWSize(Range &axis_m_range, Map<std::string, Range> &range_map);
void CollectGemmMWSizeDynamic(Map<std::string, Range> &range_map);
Expr FindRealizeScope(const isl::id &var); Expr FindRealizeScope(const isl::id &var);
std::string FindRealizeScopeToString(const isl::id &var); std::string FindRealizeScopeToString(const isl::id &var);
Stmt InsertRealize(Stmt stmt, const isl::id &var, bool is_L0); Stmt InsertRealize(Stmt stmt, const isl::id &var, bool is_L0);
void RealizeOut(); void RealizeOut();
Stmt RemoveCond(const Stmt &stmt); Stmt RemoveCond(const Stmt &stmt);
void ProcBypassL1(const Scop &scop); void ProcBypassL1(const ScopInfo &info);
void SetCube(const isl::id &stmt_id); void SetCube(const isl::id &stmt_id);
std::string ReplaceAxis(const std::string &oldAxis); std::string ReplaceAxis(const std::string &oldAxis);
static std::vector<std::string> ConstructPrefix(); static std::vector<std::string> ConstructPrefix();
void GemmTranspose(std::vector<Stmt> &stmts); void GemmTranspose(std::vector<Stmt> &stmts);
void ConvBackPropFilterFixMadInit(const isl::ast_node_mark &node, Expr &mad_init_cond); void ConvBackPropFilterFixMadInit(const isl::ast_node_mark &node, Expr &mad_init_cond);
isl::multi_aff TensorAccessMultAff(isl::id &tensor_id, const Array<Expr> &subscripts,
const isl::id &stmt_id) override;
bool IsTransferStmt() override;
bool IsCopyinFromAnotherBand(isl::multi_aff &access) override;
std::map<const Variable *, std::string> iters_old_name_;
std::map<const Variable *, std::string> iters_new_name_;
std::unordered_map<isl::id, VarMap, isl::IslIdIslHash> stmt_var_map_;
std::set<Tensor> realized_; std::set<Tensor> realized_;
IslIdSet hoisted_read_; IslIdSet hoisted_read_;
IslIdSet hoisted_write_; IslIdSet hoisted_write_;
......
此差异已折叠。
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef POLY_CONSTRUCT_POLY_ACCESSES_H_
#define POLY_CONSTRUCT_POLY_ACCESSES_H_
#include <tvm/ir_visitor.h>
#include <tvm/operation.h>
#include <tuple>
#include "poly/scop_info.h"
namespace akg {
namespace ir {
namespace poly {
/*
 * Declaration only; the definition is not in this header.
 * Takes the operator domain, an IR node, a tensor name with its index expressions, and a mutable
 * AccessMap, and returns a pair of isl::map.
 * NOTE(review): `accesses` is taken by non-const reference, so it is presumably updated as a side
 * effect; the meaning of each element of the returned pair must be confirmed against the definition.
 */
std::pair<isl::map, isl::map> ConstructPolyAccess(const OperatorDomainSpace &domain, const Node *op,
const std::string &tensor, const Array<Expr> &dimensions,
AccessMap &accesses);
/*
 * Declaration only; the definition is not in this header.
 * Statement-level counterpart: takes the operator domain, a Stmt, and a mutable AccessMap, and
 * returns three isl::union_map values.
 * NOTE(review): the tuple presumably carries read accesses, write accesses and one further
 * relation — confirm the order and meaning against the definition.
 */
std::tuple<isl::union_map, isl::union_map, isl::union_map> ConstructPolyAccesses(const OperatorDomainSpace &domain,
const Stmt &s, AccessMap &accesses);
} // namespace poly
} // namespace ir
} // namespace akg
#endif // POLY_CONSTRUCT_POLY_ACCESSES_H_
\ No newline at end of file
...@@ -607,7 +607,7 @@ class DynamicPaddingFix : public IRMutator { ...@@ -607,7 +607,7 @@ class DynamicPaddingFix : public IRMutator {
std::string fm_l1_{""}; std::string fm_l1_{""};
}; };
Stmt OptimizeCce(const Stmt &s, bool dynamicShape = false) { Stmt DavinciHalideOptimizer(const Stmt &s, bool dynamicShape = false) {
Stmt stmt = s; Stmt stmt = s;
if (dynamicShape) { if (dynamicShape) {
stmt = InductionVarElinate().Run(s); stmt = InductionVarElinate().Run(s);
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "poly/davinci_mgr_strategy.h"
#include "poly/schedule_pass/group.h"
#include "poly/schedule_pass/tile_outer_band.h"
#include "poly/schedule_pass/memory_manager.h"
#include "poly/schedule_pass/sink_c0.h"
#include "poly/schedule_pass/sink_last_axis.h"
#include "poly/schedule_pass/reorder_invariant_set_schedule.h"
#include "poly/schedule_pass/keep_outer_band_order.h"
#include "poly/schedule_pass/split_outer_band.h"
#include "poly/schedule_pass/transfer_stmt.h"
#include "poly/schedule_pass/reset_coincidence_of_reduce.h"
#include "poly/schedule_pass/set_all_coincidence.h"
#include "poly/schedule_pass/reschedule.h"
#include "poly/schedule_pass/reorder_inner_band.h"
#include "poly/schedule_pass/change_marknode_position.h"
#include "poly/schedule_pass/insert_node_for_allocc.h"
#include "poly/schedule_pass/label_realize_out_position.h"
#include "poly/schedule_pass/mark_fuse_op.h"
#include "poly/schedule_pass/reorder_mark_nodes.h"
#include "poly/schedule_pass/compute_transfer_copyin.h"
#include "poly/schedule_pass/compute_inner_band_dependency.h"
#include "poly/schedule_pass/mark_outer_most.h"
namespace akg {
namespace ir {
namespace poly {
void DavinciMgrStrategy::RegisterTilingPasses() {
RegisterPass(std::make_shared<TileOuterBand>(pass_info_, scop_info_));
}
void DavinciMgrStrategy::RegisterMemPromPasses() { RegisterPass(std::make_shared<MemoryManager>(scop_info_)); }
/*
 * Re-populate the pass list of this strategy.
 *
 * Clears passes_, then issues the registration calls below in a fixed order. Many passes are
 * gated on flags read from scop_info_.user_config_ and scop_info_.cube_info_. When
 * GetIsTuning() is true, the function returns before the tiling passes, so nothing from
 * RegisterTilingPasses() onward is registered.
 * The order of the calls is part of the behavior; do not reorder them.
 */
void DavinciMgrStrategy::RegisterPasses() {
// Start from an empty list so repeated calls do not accumulate passes.
passes_.clear();
RegisterNormalizationPasses();
// Statement grouping is registered unless the user config disables it.
if (!scop_info_.user_config_.GetDisableGroup()) {
RegisterPass(std::make_shared<GroupStatements>(pass_info_));
}
RegisterSchedulingPasses();
RegisterPass(std::make_shared<ReorderInvariantSetSchedule>(pass_info_));
// Optional schedule adjustments, each enabled by its own user-config flag.
if (scop_info_.user_config_.GetReorderSchedule()) {
RegisterPass(std::make_shared<SinkC0>());
}
if (scop_info_.user_config_.GetSinkLastAxis()) {
RegisterPass(std::make_shared<SinkLastAxis>(pass_info_));
}
if (scop_info_.user_config_.GetKeepOuterBandOrder()) {
RegisterPass(std::make_shared<KeepOuterBandOrder>(scop_info_));
}
// UnGroupStatements is registered unconditionally, even when GroupStatements was skipped above.
RegisterPass(std::make_shared<UnGroupStatements>(pass_info_));
// Outer-band splitting is skipped for the spec-gemm case.
if (scop_info_.user_config_.GetOuterBandNeedSplit() && !scop_info_.cube_info_.IsSpecGemm()) {
RegisterPass(std::make_shared<SplitOuterBand>());
}
RegisterPass(std::make_shared<ComputeInnerBandDependency>(scop_info_));
// Transfer copy-in is computed only for conv/gemm that is not spec-gemm.
if (!scop_info_.cube_info_.IsSpecGemm() && (scop_info_.cube_info_.IsConv() || scop_info_.cube_info_.IsGemm())) {
RegisterPass(std::make_shared<ComputeTransferCopyin>(scop_info_, pass_info_));
}
// In tuning mode the list ends here: no tiling or later passes are registered.
if (scop_info_.user_config_.GetIsTuning()) {
return;
}
RegisterTilingPasses();
// ReorderInvariantSetSchedule is registered a second time, now after tiling.
RegisterPass(std::make_shared<ReorderInvariantSetSchedule>(pass_info_));
RegisterPass(std::make_shared<ResetCoincidenceOfReduce>(scop_info_, pass_info_));
if (scop_info_.user_config_.GetPragmaSetAllCoincident()) {
RegisterPass(std::make_shared<SetAllCoincidence>());
}
// Reschedule is skipped only when the config is dynamic AND the op is a conv.
if (!scop_info_.user_config_.GetIsDynamic() || !scop_info_.cube_info_.IsConv()) {
RegisterPass(std::make_shared<Reschedule>(scop_info_, pass_info_));
}
RegisterPass(std::make_shared<ReorderInnerBand>(scop_info_.analysis_result_.GetCondVarsMap()));
RegisterPass(std::make_shared<ChangeMarkNodePosition>(scop_info_.analysis_result_.ExtractWithStmtId()));
RegisterPass(std::make_shared<LabelRealizeOutPosition>());
// InsertNodeForAllocC is registered only for spec-gemm, gemm, or conv backprop filter.
if (scop_info_.cube_info_.IsSpecGemm() || scop_info_.cube_info_.IsGemm() ||
scop_info_.cube_info_.IsConvBackpropFilter()) {
RegisterPass(std::make_shared<InsertNodeForAllocC>());
}
RegisterMemPromPasses();
// TransferStmt is not registered for spec-gemm.
if (!scop_info_.cube_info_.IsSpecGemm()) {
RegisterPass(std::make_shared<TransferStmt>(scop_info_, pass_info_));
}
RegisterPass(std::make_shared<ReorderMarkNodes>());
RegisterPass(std::make_shared<MarkFuseOp>());
// if coincidence constraints are disabled (due to reschedule), we cannot determine multicore axis reliably
// The GetEnableMarkMultiCore() flag below still forces MarkOuterMost even when can_use_multiCore is false.
bool can_use_multiCore = !scop_info_.cube_info_.IsSpecGemm() && scop_info_.user_config_.GetConsiderCoincidence();
if (can_use_multiCore || scop_info_.user_config_.GetEnableMarkMultiCore()) {
RegisterPass(std::make_shared<MarkOuterMost>(scop_info_));
}
}
} // namespace poly
} // namespace ir
} // namespace akg
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef POLY_DAVINCI_MGR_STRATEGY_H_
#define POLY_DAVINCI_MGR_STRATEGY_H_
#include "poly/pass_mgr_strategy.h"
namespace akg {
namespace ir {
namespace poly {
/*
 * Pass-manager strategy for the Davinci backend.
 * Derives from PassMgrStrategy and overrides its three registration hooks;
 * the hook bodies are defined out of line.
 */
class DavinciMgrStrategy : public PassMgrStrategy {
 public:
  // Forwards scop_info to the base class, then copies the user's
  // "consider coincidence" setting into pass_info_.coincident_.
  explicit DavinciMgrStrategy(ScopInfo &scop_info) : PassMgrStrategy(scop_info) {
    const auto consider_coincidence = scop_info_.user_config_.GetConsiderCoincidence();
    pass_info_.coincident_ = consider_coincidence;
  }

  ~DavinciMgrStrategy() override = default;

  // Registration hooks overridden from PassMgrStrategy.
  void RegisterPasses() override;
  void RegisterTilingPasses() override;
  void RegisterMemPromPasses() override;
};
} // namespace poly
} // namespace ir
} // namespace akg
#endif // POLY_DAVINCI_MGR_STRATEGY_H_
...@@ -13,9 +13,9 @@ ...@@ -13,9 +13,9 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#include "poly/dma_dataflow.h"
#include "poly/scop.h" #include "poly/dma_dataflow.h"
#include "poly/poly_util.h"
namespace akg { namespace akg {
namespace ir { namespace ir {
...@@ -193,7 +193,6 @@ void StmtDataFlowInfo::AddWriteTensor(const std::string &name, TENSOR_DATAFLOW_T ...@@ -193,7 +193,6 @@ void StmtDataFlowInfo::AddWriteTensor(const std::string &name, TENSOR_DATAFLOW_T
void StmtDataFlowInfo::CreateTensorDataFlow(TENSOR_DATAFLOW_TYPE type, const std::string &name, void StmtDataFlowInfo::CreateTensorDataFlow(TENSOR_DATAFLOW_TYPE type, const std::string &name,
TensorDataFlow &dataflow) { TensorDataFlow &dataflow) {
CHECK_NE(name, ""); CHECK_NE(name, "");
dataflow.tensor_name_ = name;
switch (type) { switch (type) {
case TENSOR_DATAFLOW_TYPE::CUBE_CONV_A: case TENSOR_DATAFLOW_TYPE::CUBE_CONV_A:
CubeConvA(name, dataflow); CubeConvA(name, dataflow);
......
...@@ -33,7 +33,7 @@ namespace akg { ...@@ -33,7 +33,7 @@ namespace akg {
namespace ir { namespace ir {
namespace poly { namespace poly {
class TensorFootprintCluster; class TensorFootprintCluster;
class TensorDataFlow; struct TensorDataFlow;
class StmtDataFlowInfo; class StmtDataFlowInfo;
enum MemType { DDR = 1, L1_, UB_, L0A_, L0B_, L0C_, UBL0_, UBL1_ }; enum MemType { DDR = 1, L1_, UB_, L0A_, L0B_, L0C_, UBL0_, UBL1_ };
...@@ -142,7 +142,6 @@ enum TENSOR_DATAFLOW_TYPE { ...@@ -142,7 +142,6 @@ enum TENSOR_DATAFLOW_TYPE {
}; };
struct TensorDataFlow { struct TensorDataFlow {
std::string tensor_name_;
std::vector<std::string> name_flow_; std::vector<std::string> name_flow_;
MemFlow mem_type_flow_; MemFlow mem_type_flow_;
......
此差异已折叠。
...@@ -13,20 +13,14 @@ ...@@ -13,20 +13,14 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#ifndef POLY_DMA_INJECT_H_ #ifndef POLY_DMA_INJECT_H_
#define POLY_DMA_INJECT_H_ #define POLY_DMA_INJECT_H_
#pragma once
#include <isl/constraint.h> #include <isl/constraint.h>
#include <iostream>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include <string>
#include <memory> #include <memory>
#include "poly/isl.h" #include "poly/isl.h"
#include "poly/scop.h" #include "poly/scop_info.h"
namespace akg { namespace akg {
namespace ir { namespace ir {
...@@ -177,30 +171,36 @@ std::vector<int> ExpandInvalidDims(const std::vector<int> &invalid_dims, const i ...@@ -177,30 +171,36 @@ std::vector<int> ExpandInvalidDims(const std::vector<int> &invalid_dims, const i
int &first_invalid_domain_dim); int &first_invalid_domain_dim);
isl::multi_aff ComputeBufferFootprint(const isl::map &access, const ScopedFootprint &foot_print); isl::multi_aff ComputeBufferFootprint(const isl::map &access, const ScopedFootprint &foot_print);
isl::schedule_node PlaceDataCopyBelowImpl(Scop &scop, isl::schedule_node tree, const TensorFootprintCluster &cluster, isl::schedule_node PlaceDataCopyBelowImpl(ScopInfo &scop_info, isl::schedule_node tree,
const isl::map &buffer_footprint, const isl::id &tensor_id, const TensorFootprintCluster &cluster, const isl::map &buffer_footprint,
const isl::set &original_elements, const isl::map &exact_reads, const isl::id &tensor_id, const isl::set &original_elements,
const isl::map &exact_writes); const isl::map &exact_reads, const isl::map &exact_writes,
const isl::union_map &sch);
void PlaceDataCopyBelowImplReadWrite(Scop &scop, isl::schedule_node &tree, const TensorFootprintCluster &cluster, void PlaceDataCopyBelowImplReadWrite(ScopInfo &scop_info, isl::schedule_node &tree,
const isl::map &footprint, const isl::id &tensor_id, const TensorFootprintCluster &cluster, const isl::map &footprint,
const isl::set &original_elements, const isl::map &exact_writes, const isl::id &tensor_id, const isl::set &original_elements,
isl::map &read_extension, isl::set &buffered_footprint, const isl::id &cluster_id, const isl::map &exact_writes, isl::map &read_extension,
isl::map &extension_map, isl::id &read_id); isl::set &buffered_footprint, const isl::id &cluster_id, isl::map &extension_map,
isl::id &read_id);
void PlaceDataCopyBelowImplFakeReads(Scop &scop, isl::schedule_node &tree, const TensorFootprintCluster &cluster, void PlaceDataCopyBelowImplFakeReads(ScopInfo &scop_info, isl::schedule_node &tree,
isl::map &read_extension, const isl::id &cluster_id); const TensorFootprintCluster &cluster, isl::map &read_extension,
const isl::id &cluster_id, const isl::union_map &sch);
isl::schedule_node PlaceInnerDataCopyBelow(Scop &scop, const isl::schedule_node &tree, isl::schedule_node PlaceInnerDataCopyBelow(ScopInfo &scop_info, const isl::schedule_node &tree,
const TensorFootprintCluster &cluster, const TensorFootprintCluster &cluster,
const TensorFootprintCluster &outer_scope_cluster, const isl::id &tensor_id, const TensorFootprintCluster &outer_scope_cluster, const isl::id &tensor_id,
const isl::id &cluster_id, const isl::id &outer_scope_cluster_id); const isl::id &cluster_id, const isl::id &outer_scope_cluster_id,
const isl::union_map &sch);
isl::schedule_node PlaceOuterDataCopyBelow(Scop &scop, const isl::schedule_node &tree, isl::schedule_node PlaceOuterDataCopyBelow(ScopInfo &scop_info, const isl::schedule_node &tree,
const TensorFootprintCluster &cluster, const isl::id &tensor_id, const TensorFootprintCluster &cluster, const isl::id &tensor_id,
const isl::id &cluster_id); const isl::id &cluster_id, const isl::union_map &sch,
const isl::space &sch_space);
isl::schedule_node PlaceIm2colBelow(Scop &scop, const isl::schedule_node &tree, const TensorFootprintCluster &cluster, isl::schedule_node PlaceIm2colBelow(ScopInfo &scop_info, const isl::schedule_node &tree,
const TensorFootprintCluster &cluster,
const TensorFootprintCluster &outer_scope_cluster, const isl::id &cluster_id, const TensorFootprintCluster &outer_scope_cluster, const isl::id &cluster_id,
const isl::id &outer_scope_cluster_id); const isl::id &outer_scope_cluster_id);
...@@ -210,7 +210,7 @@ class AffineBase { ...@@ -210,7 +210,7 @@ class AffineBase {
public: public:
virtual ~AffineBase() = default; virtual ~AffineBase() = default;
virtual isl::map ConstructAffine(isl::map original) = 0; virtual isl::map ConstructAffine(isl::map original) = 0;
virtual bool NotNeedConstruct(std::string name, Scop &scop) = 0; virtual bool NotNeedConstruct(std::string name, ScopInfo &scop_info) = 0;
}; };
class GemmInnerTransposeAffine : public AffineBase { class GemmInnerTransposeAffine : public AffineBase {
...@@ -221,13 +221,13 @@ class GemmInnerTransposeAffine : public AffineBase { ...@@ -221,13 +221,13 @@ class GemmInnerTransposeAffine : public AffineBase {
isl::map ConstructAffine(isl::map original_map) final; isl::map ConstructAffine(isl::map original_map) final;
void SetRightMatrix(AffineTensor v) { is_right_matrix_ = v; } void SetRightMatrix(AffineTensor v) { is_right_matrix_ = v; }
bool NotNeedConstruct(std::string name, Scop &scop) override { bool NotNeedConstruct(std::string name, ScopInfo &scop_info) override {
// right matrix filter !B tensor // right matrix filter !B tensor
if (is_right_matrix_ == AffineTensor::RIGHT_TENSOR && !scop.IsB(name)) { if (is_right_matrix_ == AffineTensor::RIGHT_TENSOR && !scop_info.cube_info_.IsB(name)) {
return true; return true;
} }
// left matrix filter !A tensor // left matrix filter !A tensor
if (is_right_matrix_ == AffineTensor::LEFT_TENSOR && !scop.IsA(name)) { if (is_right_matrix_ == AffineTensor::LEFT_TENSOR && !scop_info.cube_info_.IsA(name)) {
return true; return true;
} }
return false; return false;
...@@ -246,13 +246,13 @@ class GemmTransposeAffine : public AffineBase { ...@@ -246,13 +246,13 @@ class GemmTransposeAffine : public AffineBase {
void SetRightMatrix(AffineTensor v) { is_right_matrix_ = v; } void SetRightMatrix(AffineTensor v) { is_right_matrix_ = v; }
bool NotNeedConstruct(std::string name, Scop &scop) override { bool NotNeedConstruct(std::string name, ScopInfo &scop_info) override {
// right matrix filter !B tensor // right matrix filter !B tensor
if (is_right_matrix_ == AffineTensor::RIGHT_TENSOR && !scop.IsB(name)) { if (is_right_matrix_ == AffineTensor::RIGHT_TENSOR && !scop_info.cube_info_.IsB(name)) {
return true; return true;
} }
// left matrix filter !A tensor // left matrix filter !A tensor
if (is_right_matrix_ == AffineTensor::LEFT_TENSOR && !scop.IsA(name)) { if (is_right_matrix_ == AffineTensor::LEFT_TENSOR && !scop_info.cube_info_.IsA(name)) {
return true; return true;
} }
return false; return false;
...@@ -271,17 +271,17 @@ class GemmTransposeBlockAffine : public AffineBase { ...@@ -271,17 +271,17 @@ class GemmTransposeBlockAffine : public AffineBase {
void SetRightMatrix(AffineTensor v) { is_right_matrix_ = v; } void SetRightMatrix(AffineTensor v) { is_right_matrix_ = v; }
bool NotNeedConstruct(std::string name, Scop &scop) override { bool NotNeedConstruct(std::string name, ScopInfo &scop_info) override {
// right matrix filter !B tensor // right matrix filter !B tensor
if (AffineTensor::RIGHT_TENSOR == is_right_matrix_ && !scop.IsB(name)) { if (AffineTensor::RIGHT_TENSOR == is_right_matrix_ && !scop_info.cube_info_.IsB(name)) {
return true; return true;
} }
// left matrix filter !A tensor // left matrix filter !A tensor
if (is_right_matrix_ == AffineTensor::LEFT_TENSOR && !scop.IsA(name)) { if (is_right_matrix_ == AffineTensor::LEFT_TENSOR && !scop_info.cube_info_.IsA(name)) {
return true; return true;
} }
if (AffineTensor::OUT_TENSOR == is_right_matrix_ && !scop.IsC(name)) { if (AffineTensor::OUT_TENSOR == is_right_matrix_ && !scop_info.cube_info_.IsC(name)) {
return true; return true;
} }
...@@ -302,8 +302,8 @@ class Im2colAffine : public AffineBase { ...@@ -302,8 +302,8 @@ class Im2colAffine : public AffineBase {
void ConstructAffineMap(isl::map &footprint, std::vector<isl::aff> &v_aff_x, std::vector<isl::aff> &v_aff_y, void ConstructAffineMap(isl::map &footprint, std::vector<isl::aff> &v_aff_x, std::vector<isl::aff> &v_aff_y,
const isl::map &original_map, isl::local_space &ls); const isl::map &original_map, isl::local_space &ls);
bool NotNeedConstruct(std::string name, Scop &scop) override { bool NotNeedConstruct(std::string name, ScopInfo &scop_info) override {
if (!scop.IsA(name)) { if (!scop_info.cube_info_.IsA(name)) {
return true; return true;
} }
return false; return false;
...@@ -319,8 +319,8 @@ class WeightAffine : public AffineBase { ...@@ -319,8 +319,8 @@ class WeightAffine : public AffineBase {
isl::map ConstructAffine(isl::map original_map) final; isl::map ConstructAffine(isl::map original_map) final;
bool NotNeedConstruct(std::string name, Scop &scop) override { bool NotNeedConstruct(std::string name, ScopInfo &scop_info) override {
if (!scop.IsB(name)) { if (!scop_info.cube_info_.IsB(name)) {
return true; return true;
} }
return false; return false;
...@@ -339,8 +339,8 @@ class FractalAffine : public AffineBase { ...@@ -339,8 +339,8 @@ class FractalAffine : public AffineBase {
void ConstructAffineMap(isl::map &footprint, std::vector<isl::aff> &v_aff_x, std::vector<isl::aff> &v_aff_y, void ConstructAffineMap(isl::map &footprint, std::vector<isl::aff> &v_aff_x, std::vector<isl::aff> &v_aff_y,
const isl::map &original_map, isl::local_space &ls); const isl::map &original_map, isl::local_space &ls);
bool NotNeedConstruct(std::string name, Scop &scop) override { bool NotNeedConstruct(std::string name, ScopInfo &scop_info) override {
if (!scop.IsA(name)) { if (!scop_info.cube_info_.IsA(name)) {
return true; return true;
} }
return false; return false;
...@@ -371,7 +371,7 @@ class AffineRefGroupConstructor { ...@@ -371,7 +371,7 @@ class AffineRefGroupConstructor {
void create(); void create();
std::unique_ptr<TensorFootprintCluster> ConstructRefGroup(Scop &scop, const isl::union_map &accesses, std::unique_ptr<TensorFootprintCluster> ConstructRefGroup(ScopInfo &scop_info, const isl::union_map &accesses,
const isl::union_set &domain, const isl::union_set &domain,
const isl::union_map &schedule, ReferenceType type); const isl::union_map &schedule, ReferenceType type);
...@@ -391,7 +391,7 @@ class AffineRefGroupConstructor { ...@@ -391,7 +391,7 @@ class AffineRefGroupConstructor {
AffineType type_ = AffineType::AFFINE_GEMM; AffineType type_ = AffineType::AFFINE_GEMM;
}; };
std::unique_ptr<TensorFootprintCluster> ConstructAffineFpCluster(Scop &scop, const isl::union_map &accesses, std::unique_ptr<TensorFootprintCluster> ConstructAffineFpCluster(ScopInfo &info, const isl::union_map &accesses,
const isl::union_set &domain, const isl::union_set &domain,
const isl::union_map &schedule, ReferenceType type, const isl::union_map &schedule, ReferenceType type,
AffineType affine_type, AffineType affine_type,
......
...@@ -20,9 +20,10 @@ ...@@ -20,9 +20,10 @@
#include <fcntl.h> #include <fcntl.h>
#include <sys/stat.h> #include <sys/stat.h>
#include <fstream> #include <fstream>
#include <iostream>
#include <iomanip>
#include "poly/poly_util.h" #include "poly/poly_util.h"
#include "poly/scop.h"
#include "poly/dma_inject.h" #include "poly/dma_inject.h"
namespace akg { namespace akg {
...@@ -152,6 +153,11 @@ void PrettyPrintSchTree(std::FILE *fp, const isl::schedule &sch) { ...@@ -152,6 +153,11 @@ void PrettyPrintSchTree(std::FILE *fp, const isl::schedule &sch) {
} }
} }
/*
 * String-returning overload: dumps the schedule tree with DumpSchTreeToString,
 * then returns that text after passing it through FormatSchTreeStr.
 */
std::string PrettyPrintSchTree(const isl::schedule &sch) {
  // Kept as a named lvalue: the parameter type of FormatSchTreeStr is not visible here.
  std::string raw_tree_text = DumpSchTreeToString(sch);
  std::string formatted_tree_text = FormatSchTreeStr(raw_tree_text);
  return formatted_tree_text;
}
/* /*
* Check that file name is a simple relative path (does not start with "/", and does not include "." or ".."). * Check that file name is a simple relative path (does not start with "/", and does not include "." or "..").
* FileName should not include extension, and the extension will be appended to FileName. * FileName should not include extension, and the extension will be appended to FileName.
...@@ -218,6 +224,7 @@ bool CompareSchTreeWithString(const std::string &compare_sch_, const isl::schedu ...@@ -218,6 +224,7 @@ bool CompareSchTreeWithString(const std::string &compare_sch_, const isl::schedu
/*
 * Write a section banner to the dump stream `of`.
 * Output is: an empty line, then ">>>>>>>>>> " + str + " <<<<<<<<<<", then a line break.
 * Both line breaks are emitted with std::endl, so the stream is flushed each time.
 * (Unchanged context lines: the text is identical on both sides of the diff.)
 */
void PrintHeader(std::ofstream &of, const std::string &str) { void PrintHeader(std::ofstream &of, const std::string &str) {
of << std::endl << ">>>>>>>>>> " << str << " <<<<<<<<<<" << std::endl; of << std::endl << ">>>>>>>>>> " << str << " <<<<<<<<<<" << std::endl;
} }
/*
 * Console overload of PrintHeader: prints the banner ">>>>>>>>>> str <<<<<<<<<<" to
 * standard output, followed by a flushing line break. Unlike the std::ofstream overload
 * it does not emit a leading empty line.
 */
void PrintHeader(const std::string &str) {
  std::cout << ">>>>>>>>>> ";
  std::cout << str;
  std::cout << " <<<<<<<<<<" << std::endl;
}
void DumpNode(std::ofstream &of, const air::Node *node) { void DumpNode(std::ofstream &of, const air::Node *node) {
if (node->IsInstance<Provide>()) { if (node->IsInstance<Provide>()) {
...@@ -274,28 +281,28 @@ void CreateDirIfNotExist(const std::string &file_name) { ...@@ -274,28 +281,28 @@ void CreateDirIfNotExist(const std::string &file_name) {
free(file_name_); free(file_name_);
} }
void Scop::DumpScopDataBasics(std::ofstream &of) { void AnalysisResult::DumpScopDataBasics(std::ofstream &of) {
PrintHeader(of, "statements"); PrintHeader(of, "statements");
for (const auto &stmt : data_.statements) { for (const auto &stmt : GetStatementMap()) {
of << stmt.first << " : "; of << stmt.first << " : ";
DumpNode(of, stmt.second); DumpNode(of, stmt.second);
of << std::endl; of << std::endl;
} }
PrintHeader(of, "accesses"); PrintHeader(of, "accesses");
for (const auto &stmt : data_.accesses) { for (const auto &stmt : GetAccessMap()) {
of << stmt.second << " : "; of << stmt.second << " : ";
DumpNode(of, stmt.first); DumpNode(of, stmt.first);
of << std::endl; of << std::endl;
} }
PrintHeader(of, "domains"); PrintHeader(of, "domains");
for (const auto &stmt : data_.domains) { for (const auto &stmt : GetOperatorDomainMap()) {
of << stmt.first << " : param_space " << stmt.second.param_space << std::endl; of << stmt.first << " : param_space " << stmt.second.param_space << std::endl;
} }
PrintHeader(of, "stmt_op_Info"); PrintHeader(of, "stmt_op_Info");
for (const auto &stmt : data_.stmt_op_Info) { for (const auto &stmt : GetStmtOpInfoMap()) {
of << stmt.first << " : ops [ "; of << stmt.first << " : ops [ ";
for (auto op : stmt.second.ops) { for (auto op : stmt.second.ops) {
of << int(op) << ", "; of << int(op) << ", ";
...@@ -307,92 +314,79 @@ void Scop::DumpScopDataBasics(std::ofstream &of) { ...@@ -307,92 +314,79 @@ void Scop::DumpScopDataBasics(std::ofstream &of) {
of << "]" << std::endl; of << "]" << std::endl;
} }
PrintHeader(of, "iterators");
for (const auto &it : data_.iterators) {
of << it.first << " : [ ";
for (const auto &str : it.second) {
of << str << ", ";
}
of << "]" << std::endl;
}
PrintHeader(of, "reads"); PrintHeader(of, "reads");
of << FormatMupaStr(data_.reads) << std::endl; of << FormatMupaStr(GetReads()) << std::endl;
PrintHeader(of, "writes"); PrintHeader(of, "writes");
of << FormatMupaStr(data_.writes) << std::endl; of << FormatMupaStr(GetWrites()) << std::endl;
PrintHeader(of, "copyin"); PrintHeader(of, "copyin");
of << FormatMupaStr(data_.copyin) << std::endl; of << FormatMupaStr(GetCopyin()) << std::endl;
PrintHeader(of, "fake_copyin"); PrintHeader(of, "fake_copyin");
of << FormatMupaStr(data_.fake_copyin) << std::endl; of << FormatMupaStr(GetFakeCopyin()) << std::endl;
PrintHeader(of, "inter_band_dependency"); PrintHeader(of, "inter_band_dependency");
of << FormatMupaStr(data_.inter_band_dependency) << std::endl; of << FormatMupaStr(GetInnerBandDependency()) << std::endl;
PrintHeader(of, "transfer_stmt"); PrintHeader(of, "transfer_stmt");
of << FormatMupaStr(data_.transfer_stmt) << std::endl; of << FormatMupaStr(GetTransferStmt()) << std::endl;
PrintHeader(of, "reduce_stmts"); PrintHeader(of, "reduce_stmts");
for (const auto &stmt : data_.reduce_stmts) { for (const auto &stmt : GetReduceStmtMap()) {
of << stmt.first << ": reduce axis [ "; of << stmt.first << ": reduce axis [ ";
for (const auto &axis : stmt.second) { for (const auto &axis : stmt.second) {
of << axis << " "; of << axis << " ";
} }
of << "]" << std::endl; of << "]" << std::endl;
} }
PrintHeader(of, "group_filter_map");
for (const auto &group : group_filter_map_) {
of << group.first << " : [ ";
for (auto filter : group.second) {
of << filter << ", ";
}
of << "]" << std::endl;
}
} }
void Scop::DumpScopDataAdvanced(std::ofstream &of) { void ScopInfo::DumpScopDataAdvanced(std::ofstream &of) {
PrintHeader(of, "binds"); PrintHeader(of, "binds");
for (auto bind : binds_) { auto binds = user_config_.GetBind();
for (auto bind : binds) {
of << bind.first << " : " << bind.second << std::endl; of << bind.first << " : " << bind.second << std::endl;
} }
PrintHeader(of, "binds_orig"); PrintHeader(of, "binds_orig");
for (auto bind : binds_orig_) { auto binds_orig = user_config_.GetOriginBind();
for (auto bind : binds_orig) {
of << bind.first << " : " << bind.second << std::endl; of << bind.first << " : " << bind.second << std::endl;
} }
PrintHeader(of, "realize_from_input"); PrintHeader(of, "realize_from_input");
for (const auto &id : realize_from_input_) { auto realize_from_input = user_config_.GetRealizeFromInput();
for (const auto &id : realize_from_input) {
of << id << ", "; of << id << ", ";
} }
of << std::endl; of << std::endl;
PrintHeader(of, "dim_infos"); PrintHeader(of, "dim_infos");
for (const auto &dim_info : dim_infos_) { for (const auto &dim_info : analysis_result_.GetTileSizes()) {
of << "index=" << dim_info.index << " axis=" << dim_info.axis << " l1_tiling_size=" << dim_info.l1_tiling_size of << "index=" << dim_info.index << " axis=" << dim_info.axis << " l1_tiling_size=" << dim_info.l1_tiling_size
<< " l0_tiling_size=" << dim_info.l0_tiling_size << " dim_seq=" << dim_info.dim_seq << std::endl; << " l0_tiling_size=" << dim_info.l0_tiling_size << " dim_seq=" << dim_info.dim_seq << std::endl;
} }
PrintHeader(of, "fractal_int_info"); PrintHeader(of, "fractal_int_info");
for (const auto &info : fractal_int_info_) { for (const auto &info : cube_info_.fractal_int_info_) {
of << info.first << " : " << info.second << std::endl; of << info.first << " : " << info.second << std::endl;
} }
PrintHeader(of, "fractal_str_info"); PrintHeader(of, "fractal_str_info");
for (const auto &info : fractal_str_info_) { for (const auto &info : cube_info_.fractal_str_info_) {
of << info.first << " : " << info.second << std::endl; of << info.first << " : " << info.second << std::endl;
} }
PrintHeader(of, "conditional_write_buffer_footprints"); PrintHeader(of, "conditional_write_buffer_footprints");
for (const auto &tensor : conditional_write_buffer_footprints_) { auto conditional_write_buffer_footprints = analysis_result_.GetConditionalWriteBufferFootprints();
for (const auto &tensor : conditional_write_buffer_footprints) {
of << tensor << std::endl; of << tensor << std::endl;
} }
PrintHeader(of, "tensor_name_flows"); PrintHeader(of, "tensor_name_flows");
for (const auto &name_flow : tensor_name_flows_) { auto tensor_name_flows = analysis_result_.GetTensorNameFlows();
for (const auto &name_flow : tensor_name_flows) {
of << name_flow.first << " : [ "; of << name_flow.first << " : [ ";
for (const auto &name : name_flow.second) { for (const auto &name : name_flow.second) {
of << name << ", "; of << name << ", ";
...@@ -401,7 +395,8 @@ void Scop::DumpScopDataAdvanced(std::ofstream &of) { ...@@ -401,7 +395,8 @@ void Scop::DumpScopDataAdvanced(std::ofstream &of) {
} }
PrintHeader(of, "tensor_memflows"); PrintHeader(of, "tensor_memflows");
for (const auto &mem_flow : tensor_mem_flows_) { auto tensor_mem_flows = analysis_result_.GetTensorMemFlows();
for (const auto &mem_flow : tensor_mem_flows) {
of << mem_flow.first << " : [ "; of << mem_flow.first << " : [ ";
for (auto mem : mem_flow.second) { for (auto mem : mem_flow.second) {
of << static_cast<int>(mem) << ", "; of << static_cast<int>(mem) << ", ";
...@@ -409,25 +404,8 @@ void Scop::DumpScopDataAdvanced(std::ofstream &of) { ...@@ -409,25 +404,8 @@ void Scop::DumpScopDataAdvanced(std::ofstream &of) {
of << "]" << std::endl; of << "]" << std::endl;
} }
PrintHeader(of, "n_clusters");
for (const auto &cluster : n_clusters_) {
of << cluster.first << " : " << cluster.second << std::endl;
}
PrintHeader(of, "bufferedDecls");
for (const auto &buffered_decl : buffered_decls_) {
of << buffered_decl.first << " : "
<< "tensor_id=" << buffered_decl.second.tensor_id << "type=" << buffered_decl.second.type
<< "kind=" << static_cast<int>(buffered_decl.second.kind) << "tensor=" << buffered_decl.second.tensor
<< "size=[";
for (auto size : buffered_decl.second.sizes) {
of << size << ",";
}
of << "]" << std::endl;
}
PrintHeader(of, "active_buffer_footprints"); PrintHeader(of, "active_buffer_footprints");
for (const auto &active_buffer_footprint : active_buffer_footprints_) { for (const auto &active_buffer_footprint : analysis_result_.active_buffer_footprints_) {
of << "cluster_id : " << active_buffer_footprint.second.cluster_id << std::endl of << "cluster_id : " << active_buffer_footprint.second.cluster_id << std::endl
<< "domain : " << FormatMupaStr(active_buffer_footprint.first) << std::endl << "domain : " << FormatMupaStr(active_buffer_footprint.first) << std::endl
<< "cluster : " << *(active_buffer_footprint.second.cluster) << std::endl << "cluster : " << *(active_buffer_footprint.second.cluster) << std::endl
...@@ -436,81 +414,82 @@ void Scop::DumpScopDataAdvanced(std::ofstream &of) { ...@@ -436,81 +414,82 @@ void Scop::DumpScopDataAdvanced(std::ofstream &of) {
} }
PrintHeader(of, "buffered_decl_infos"); PrintHeader(of, "buffered_decl_infos");
DumpBufferDefInfos(of); analysis_result_.DumpBufferDefInfos(of);
of << std::endl;
of << "custom_tiling : ";
if (custom_tiling_.empty()) of << "empty" << std::endl;
for (const auto &tiling : custom_tiling_) {
of << tiling << " ";
}
of << std::endl; of << std::endl;
PrintHeader(of, "attr_info"); PrintHeader(of, "attr_info");
for (const auto &info : attr_info_) { for (const auto &info : cube_info_.GetConvAttrInfo()) {
of << info.first << " : " << info.second << std::endl; of << info.first << " : " << info.second << std::endl;
} }
} }
/*
 * Write the "schedule attrs" section of the scop dump to `of`.
 * After the section banner, every statement prints one "name : value" line for a single
 * scheduling attribute.
 * Two-column diff text: the left statements read Scop members directly (b_dim_, bypassL1_, ...),
 * the right statements call UserConfig getters (GetBDim(), GetByPassL1(), ...). The two sides
 * list the attributes in different orders, so a line's left and right statements do not
 * necessarily describe the same attribute; lines with one statement exist on one side only.
 */
void Scop::DumpScopDataScheduleAttrs(std::ofstream &of) { void UserConfig::DumpScopDataScheduleAttrs(std::ofstream &of) {
PrintHeader(of, "schedule attrs"); PrintHeader(of, "schedule attrs");
of << "dim : " << b_dim_ << std::endl; of << "dump_poly_dir : " << GetDumpPolyDir() << std::endl;
of << "kernel_h : " << matB_dim_h_ << std::endl;
of << "kernel_w : " << matB_dim_w_ << std::endl; of << "dump_tuning_level : " << GetDumpTuningLevel() << std::endl;
of << "conv_backprop_filter : " << conv_back_prop_filter_ << std::endl; of << "dim : " << GetBDim() << std::endl;
of << "bypassL1 : " << bypassL1_ << std::endl;
of << "dump_tuning_level : " << dump_tuning_level_ << std::endl; of << "pragma_rmselfdep : " << GetRemoveSelfDependence() << std::endl;
of << "pragma_rmselfdep : " << remove_self_dependence_ << std::endl; of << "pragma_force_rmselfdep : " << GetForceRemoveSelfDependence() << std::endl;
of << "pragma_force_rmselfdep : " << force_remove_self_dependence_ << std::endl; of << "pragma_reschedule : " << GetComputeReschedule() << std::endl;
of << "pragma_reschedule : " << compute_reschedule_ << std::endl; of << "pragma_disable_schedule_shift : " << GetDisableScheduleShift() << std::endl;
of << "pragma_disable_schedule_shift : " << disable_schedule_shift_ << std::endl; of << "pragma_enable_schedule_max_constant : " << GetEnableScheduleMaxConstant() << std::endl;
of << "pragma_enable_schedule_max_constant : " << enable_schedule_max_constant_ << std::endl; of << "pragma_disable_loop_reversal : " << GetDisableLoopReversal() << std::endl;
of << "pragma_disable_loop_reversal : " << disable_loop_reversal_ << std::endl; of << "pragma_disable_loop_fusion : " << GetDisableLoopFusion() << std::endl;
of << "pragma_disable_loop_fusion : " << disable_loop_fusion_ << std::endl; of << "pragma_modshift : " << GetModScheduleShift() << std::endl;
of << "pragma_modshift : " << mod_schedule_shift_ << std::endl; of << "pragma_reorder_schedule : " << GetReorderSchedule() << std::endl;
of << "pragma_conv_special_dma : " << conv_special_dma_ << std::endl; of << "pragma_checkcoincident : " << GetTileCheckCoincident() << std::endl;
of << "pragma_reorder_schedule : " << reorder_schedule_ << std::endl; of << "pragma_opt_for_davinci : " << GetOptimizeForDavinci() << std::endl;
of << "pragma_checkcoincident : " << tile_check_coincident_ << std::endl; of << "pragma_sink_last_axis : " << GetSinkLastAxis() << std::endl;
of << "pragma_opt_for_davinci : " << optimize_for_davinci_ << std::endl; of << "pragma_keep_outer_band_order : " << GetKeepOuterBandOrder() << std::endl;
of << "pragma_sink_last_axis : " << sink_last_axis_ << std::endl; of << "pragma_disable_group : " << GetDisableGroup() << std::endl;
of << "pragma_keep_outer_band_order : " << keep_outer_band_order_ << std::endl; of << "pragma_tile_inner_band : " << GetTileInnerBand() << std::endl;
of << "pragma_disable_group : " << disable_group_ << std::endl; of << "isolated_idx : " << GetIsolatedIdx() << std::endl;
of << "pragma_tile_inner_band : " << tile_inner_band_ << std::endl; of << "pragma_outerband_need_split : " << GetOuterBandNeedSplit() << std::endl;
of << "kernel_name : " << kernel_name_ << std::endl;
of << "dump_poly_dir : " << dump_poly_dir_ << std::endl; of << "dynamic_shape_bound : " << GetDynamicShapeBound() << std::endl;
of << "isolated_idx : " << isolated_idx_ << std::endl; of << "pragma_tilesize_is_var : " << GetTileSizeIsVar() << std::endl;
of << "dynamic_shape_bound : " << dynamic_shape_bound_ << std::endl;
of << "pragma_tilesize_is_var : " << tile_size_is_var_ << std::endl; of << "kernel_name : " << GetKernelName() << std::endl;
of << "pragma_outerband_need_split : " << outer_band_need_split_ << std::endl; of << "kernel_h : " << GetMatBDimH() << std::endl;
of << "pragma_is_conv : " << pragma_is_conv_ << std::endl; of << "kernel_w : " << GetMatBDimW() << std::endl;
of << "conv_backprop_filter : " << GetConvBackPropFilter() << std::endl;
of << "bypassL1 : " << GetByPassL1() << std::endl;
of << "pragma_is_conv : " << GetPragmaIsConv() << std::endl;
of << "pragma_conv_special_dma : " << GetConvSpecialDma() << std::endl;
} }
/*
 * Dump the scop data to the file named by `file_name`.
 * The name is canonicalized with FilePathCanonicalize, the file is created if missing and
 * opened for output, then three sections are written in order: basics, advanced, and
 * schedule attrs.
 * Returns false when the file cannot be created or opened, true once all sections are written.
 * Two-column diff text: on the right-hand side the basics dump is delegated to
 * analysis_result_ and the schedule-attrs dump to user_config_.
 */
bool Scop::DumpScopData(const std::string &file_name) { bool ScopInfo::DumpScopData(const std::string &file_name) {
std::string canonical_log_name = FilePathCanonicalize(file_name, true); std::string canonical_log_name = FilePathCanonicalize(file_name, true);
if (!CreateFileIfNotExist(canonical_log_name)) return false; if (!CreateFileIfNotExist(canonical_log_name)) return false;
std::ofstream of; std::ofstream of;
of.open(canonical_log_name, std::ios::out); of.open(canonical_log_name, std::ios::out);
if (!of.is_open()) return false; if (!of.is_open()) return false;
DumpScopDataBasics(of); analysis_result_.DumpScopDataBasics(of);
DumpScopDataAdvanced(of); DumpScopDataAdvanced(of);
DumpScopDataScheduleAttrs(of); user_config_.DumpScopDataScheduleAttrs(of);
of.close(); of.close();
return true; return true;
} }
void Scop::DumpSchTree(const std::string &file_name, const isl::schedule &sch_dump) { void ScopInfo::DumpSchTree(const std::string &file_name, const isl::schedule &sch_dump) {
if (dump_pass_ir_) { std::stringstream final_file_name;
final_file_name << std::setw(2) << std::setfill('0') << dump_schtree_count << "_" << file_name
<< std::string(cube_info_.IsSpecGemm() ? "_specgemm" : "");
if (user_config_.GetDumpPassIr()) {
#if DUMP_IR #if DUMP_IR
DumpSchTreeImpl(CreateDumpDir(file_name), sch_dump); DumpSchTreeImpl(CreateDumpDir(final_file_name.str()), sch_dump);
dump_schtree_count++;
#endif #endif
#if DUMP_SCOP_DATA #if DUMP_SCOP_DATA
#if DUMP_SCOP_DATA_PER_PASS #if DUMP_SCOP_DATA_PER_PASS
static_cast<void>(DumpScopData(CreateDumpDir(file_name))); static_cast<void>(DumpScopData(CreateDumpDir(final_file_name.str())));
#else #else
static_cast<void>(DumpScopData(CreateDumpDir("scop"))); static_cast<void>(DumpScopData(CreateDumpDir("scop")));
#endif #endif
...@@ -518,29 +497,29 @@ void Scop::DumpSchTree(const std::string &file_name, const isl::schedule &sch_du ...@@ -518,29 +497,29 @@ void Scop::DumpSchTree(const std::string &file_name, const isl::schedule &sch_du
} }
} }
/*
 * Build the path under which a dump file named `file_name` is placed.
 *  - When the isolated index is greater than 0, the name is prefixed with
 *    "specgemm_<isolated index>/".
 *  - Unless DUMP_IN_CURRENT_DIR is set at compile time, a non-empty dump-poly directory
 *    is then prepended as "<dump poly dir>/".
 * Only the string is assembled here; no directory is created.
 * Two-column diff text: the left side reads isolated_idx_ / dump_poly_dir_ directly, the
 * right side obtains them through user_config_ getters.
 */
std::string Scop::AddDumpDir(const std::string &file_name) { std::string ScopInfo::AddDumpDir(const std::string &file_name) {
std::string real_file_name = file_name; std::string real_file_name = file_name;
bool is_specgemm = (isolated_idx_ > 0); bool is_specgemm = (user_config_.GetIsolatedIdx() > 0);
if (is_specgemm) { if (is_specgemm) {
std::string dump_isolate_dir = "specgemm_" + std::to_string(isolated_idx_); std::string dump_isolate_dir = "specgemm_" + std::to_string(user_config_.GetIsolatedIdx());
real_file_name = dump_isolate_dir + '/' + real_file_name; real_file_name = dump_isolate_dir + '/' + real_file_name;
} }
#if (!DUMP_IN_CURRENT_DIR) #if (!DUMP_IN_CURRENT_DIR)
if (!dump_poly_dir_.empty()) { if (!user_config_.GetDumpPolyDir().empty()) {
real_file_name = dump_poly_dir_ + '/' + real_file_name; real_file_name = user_config_.GetDumpPolyDir() + '/' + real_file_name;
} }
#endif #endif
return real_file_name; return real_file_name;
} }
/*
 * Resolve `file_name` to its dump path with AddDumpDir, call CreateDirIfNotExist on that
 * path, and return the resolved path.
 * Two-column diff text: only the owning class differs (Scop on the left, ScopInfo on the right).
 */
std::string Scop::CreateDumpDir(const std::string &file_name) { std::string ScopInfo::CreateDumpDir(const std::string &file_name) {
std::string real_file_name = AddDumpDir(file_name); std::string real_file_name = AddDumpDir(file_name);
CreateDirIfNotExist(real_file_name); CreateDirIfNotExist(real_file_name);
return real_file_name; return real_file_name;
} }
void Scop::DumpBufferDefInfos(std::ostream &out) { void AnalysisResult::DumpBufferDefInfos(std::ostream &out) {
for (size_t index = 0; index < buffer_def_infos_.size(); index++) { for (size_t index = 0; index < buffer_def_infos_.size(); index++) {
out << "\r\nbufferedDefInfos_[" << index << "]: " << std::endl; out << "\r\nbufferedDefInfos_[" << index << "]: " << std::endl;
out << " tensor_id : " << buffer_def_infos_[index].tensor_id << std::endl; out << " tensor_id : " << buffer_def_infos_[index].tensor_id << std::endl;
...@@ -552,6 +531,48 @@ void Scop::DumpBufferDefInfos(std::ostream &out) { ...@@ -552,6 +531,48 @@ void Scop::DumpBufferDefInfos(std::ostream &out) {
out << " is_bind_tensor : " << buffer_def_infos_[index].is_bind_tensor << std::endl; out << " is_bind_tensor : " << buffer_def_infos_[index].is_bind_tensor << std::endl;
} }
} }
/*
 * Dump the transform-stage information held in `pass_info` to a file.
 * The file path is resolved (and its directory created) through CreateDumpDir(file_name).
 * Four sections are written, each introduced by PrintHeader:
 *   group_filter_map - every group key followed by its filters,
 *   dependences      - pass_info.dependences_ rendered with to_str() and FormatMupaStr,
 *   constraints      - pass_info.constraints_ printed by an isl printer in YAML block style,
 *   time_records     - one line per entry of time_records_.
 * If the file cannot be opened the function returns silently without writing anything.
 */
void ScopInfo::DumpTransform(const std::string &file_name, PassInfo &pass_info) {
  const std::string real_path = CreateDumpDir(file_name);
  std::ofstream of(real_path, std::ios::out);
  if (!of.is_open()) {
    return;
  }

  PrintHeader(of, "group_filter_map");
  for (const auto &group : pass_info.group_filter_map_) {
    of << group.first << " : [ ";
    // Bind by const reference: the filters are only printed, so copying each one is wasted work.
    for (const auto &filter : group.second) {
      of << filter << ", ";
    }
    of << "]" << std::endl;
  }

  PrintHeader(of, "dependences");
  of << FormatMupaStr(pass_info.dependences_.to_str()) << std::endl;

  PrintHeader(of, "constraints");
  // The isl printer is a raw C handle; it is reassigned after every call and must be
  // released with isl_printer_free once the string has been extracted.
  isl_printer *p = isl_printer_to_str(GetCtx().get());
  CHECK(p != nullptr);
  p = isl_printer_set_yaml_style(p, ISL_YAML_STYLE_BLOCK);
  p = isl_printer_print_schedule_constraints(p, pass_info.constraints_.get());
  char *s = isl_printer_get_str(p);
  if (s != nullptr) {
    of << FormatMupaStr(s);
    // The string returned by isl_printer_get_str is owned by the caller.
    free(s);
  }
  static_cast<void>(isl_printer_free(p));

  PrintHeader(of, "time_records");
  for (const auto &time_log : time_records_) {
    of << time_log << std::endl;
  }

  of.close();
}
} // namespace poly } // namespace poly
} // namespace ir } // namespace ir
} // namespace akg } // namespace akg
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
#include <isl/cpp.h> #include <isl/cpp.h>
#include <tvm/node/node.h> #include <tvm/node/node.h>
#include <string> #include <string>
#include "poly/poly_util.h"
namespace akg { namespace akg {
namespace ir { namespace ir {
namespace poly { namespace poly {
...@@ -35,11 +36,11 @@ bool CreateFileIfNotExist(const std::string &file_name); ...@@ -35,11 +36,11 @@ bool CreateFileIfNotExist(const std::string &file_name);
void CreateDirIfNotExist(const std::string &file_name); void CreateDirIfNotExist(const std::string &file_name);
std::string DumpSchTreeToString(const isl::schedule &sch); std::string DumpSchTreeToString(const isl::schedule &sch);
void DumpSchTreeImpl(const std::string &file_name, const isl::schedule &sch); void DumpSchTreeImpl(const std::string &file_name, const isl::schedule &sch);
std::string PrettyPrintSchTree(const isl::schedule &sch);
void PrintHeader(std::ofstream &of, const std::string &str); void PrintHeader(std::ofstream &of, const std::string &str);
void PrintHeader(const std::string &str);
void DumpNode(std::ofstream &of, const air::Node *node); void DumpNode(std::ofstream &of, const air::Node *node);
bool CompareSchTreeWithString(const std::string &compare_sch, const isl::schedule &sch);
} // namespace poly } // namespace poly
} // namespace ir } // namespace ir
} // namespace akg } // namespace akg
......
...@@ -203,11 +203,13 @@ Stmt IslEmitter::EmitFor(const isl::ast_node_for &node) { ...@@ -203,11 +203,13 @@ Stmt IslEmitter::EmitFor(const isl::ast_node_for &node) {
/*
 * Emit an IfThenElse statement for an isl AST `if` node.
 * The condition is translated with Interpret, the then-branch is emitted with EmitAst, and
 * the else-branch is emitted only when the node has one (otherwise else_case stays
 * default-constructed).
 * The two single-copy lines are additions on the new side of the diff: the condition's node
 * pointer is pushed onto cur_if_list_ before the branches are emitted and popped afterwards,
 * so cur_if_list_ holds the enclosing conditions while either branch is being emitted.
 */
Stmt IslEmitter::EmitIf(const isl::ast_node_if &node) { Stmt IslEmitter::EmitIf(const isl::ast_node_if &node) {
Expr cond_expr = Interpret(node.get_cond()); Expr cond_expr = Interpret(node.get_cond());
cur_if_list_.push_back(cond_expr.get());
Stmt then_case = EmitAst(node.get_then_node()); Stmt then_case = EmitAst(node.get_then_node());
Stmt else_case; Stmt else_case;
if (node.has_else_node()) { if (node.has_else_node()) {
else_case = EmitAst(node.get_else_node()); else_case = EmitAst(node.get_else_node());
} }
cur_if_list_.pop_back();
return IfThenElse::make(cond_expr, then_case, else_case); return IfThenElse::make(cond_expr, then_case, else_case);
} }
...@@ -230,25 +232,8 @@ Stmt IslEmitter::EmitBlock(const isl::ast_node_block &node) { ...@@ -230,25 +232,8 @@ Stmt IslEmitter::EmitBlock(const isl::ast_node_block &node) {
} }
} }
/*
 * IR mutator that substitutes variables by name.
 * For every Variable visited, the stored VarMap is scanned in iteration order; the first
 * entry whose key name (get_name()) equals the variable's name_hint supplies the replacement
 * expression. Variables with no matching entry are returned unchanged.
 */
class ReplaceLoopVar : public air::ir::IRMutator {
 public:
  explicit ReplaceLoopVar(VarMap replacements) : var_map(std::move(replacements)) {}
  ~ReplaceLoopVar() override = default;

  Expr Mutate_(const Variable *op, const Expr &e) final {
    for (auto entry = var_map.begin(); entry != var_map.end(); ++entry) {
      const bool same_name = (op->name_hint == entry->first.get_name());
      if (same_name) {
        return entry->second;
      }
    }
    // No mapping for this variable: keep the original expression.
    return e;
  }

 private:
  VarMap var_map;
};
isl::space IslEmitter::GetDomainSpace(const isl::id &node_id) { isl::space IslEmitter::GetDomainSpace(const isl::id &node_id) {
auto dom = isl::union_set(scop_.Domain()); auto dom = isl::union_set(info_.analysis_result_.Domain());
auto space = isl::space(); auto space = isl::space();
dom.foreach_set([&node_id, &space](const isl::set &s) -> void { dom.foreach_set([&node_id, &space](const isl::set &s) -> void {
if (s.get_tuple_id() == node_id) { if (s.get_tuple_id() == node_id) {
...@@ -265,12 +250,12 @@ isl::space IslEmitter::GetSpace(const isl::id &tensor_id, const Array<Expr> &ten ...@@ -265,12 +250,12 @@ isl::space IslEmitter::GetSpace(const isl::id &tensor_id, const Array<Expr> &ten
return space; return space;
} }
isl::multi_aff IslEmitter::TensorAccessMultAff(const isl::id &tensor_id, const Array<Expr> &tensor_index, isl::multi_aff IslEmitter::TensorAccessMultAff(isl::id &tensor_id, const Array<Expr> &tensor_index,
const isl::id &node_id) { const isl::id &node_id) {
CHECK_NE(tensor_index.size(), 0u); CHECK_NE(tensor_index.size(), 0u);
isl::pw_multi_aff iter_map = node_info_map_.at(node_id).iterator_map; isl::pw_multi_aff iter_map = node_info_map_.at(node_id).iterator_map;
isl::id stmt_id = iter_map.get_tuple_id(isl_dim_out); isl::id stmt_id = iter_map.get_tuple_id(isl_dim_out);
OperatorDomainSpace domain_space = scop_.data_.domains.at(stmt_id); OperatorDomainSpace domain_space = info_.analysis_result_.GetOperatorDomainMap().at(stmt_id);
isl::multi_aff ma = isl::multi_aff::zero(GetSpace(tensor_id, tensor_index, stmt_id)); isl::multi_aff ma = isl::multi_aff::zero(GetSpace(tensor_id, tensor_index, stmt_id));
for (size_t i = 0; i < tensor_index.size(); ++i) { for (size_t i = 0; i < tensor_index.size(); ++i) {
auto aff = Expr2Aff(domain_space.param_space, tensor_index[i]).unbind_params_insert_domain(domain_space.tuple); auto aff = Expr2Aff(domain_space.param_space, tensor_index[i]).unbind_params_insert_domain(domain_space.tuple);
...@@ -335,8 +320,9 @@ class EmitExpr : public air::ir::IRMutator { ...@@ -335,8 +320,9 @@ class EmitExpr : public air::ir::IRMutator {
Map<Expr, Expr> cache_; Map<Expr, Expr> cache_;
}; };
void FindBufferFootprintById(Scop::BufferedFootPrintInfo &buffer_footprint_info, BufferedFootPrintInfo FindBufferFootprintById(const std::vector<BufferedFootPrintInfo> &active_buf_footprints,
std::vector<Scop::BufferedFootPrintInfo> active_buf_footprints, isl::id fp_id) { const isl::id &fp_id) {
BufferedFootPrintInfo buffer_footprint_info;
for (const auto &act_buf_fp : active_buf_footprints) { for (const auto &act_buf_fp : active_buf_footprints) {
if (act_buf_fp.cluster != nullptr) { if (act_buf_fp.cluster != nullptr) {
for (const auto &fp : act_buf_fp.cluster->tensor_foot_prints) { for (const auto &fp : act_buf_fp.cluster->tensor_foot_prints) {
...@@ -347,14 +333,16 @@ void FindBufferFootprintById(Scop::BufferedFootPrintInfo &buffer_footprint_info, ...@@ -347,14 +333,16 @@ void FindBufferFootprintById(Scop::BufferedFootPrintInfo &buffer_footprint_info,
} }
} }
} }
return buffer_footprint_info;
} }
bool IsTransferStmt(Scop &scop, isl::id &stmt_id) { bool IslEmitter::IsTransferStmt() {
if (!scop.is_spec_gemm_ && scop.is_tiled_) { if (info_.analysis_result_.GetIsTiled()) {
isl::union_set transfer_stmt = scop.data_.transfer_stmt; isl::union_set transfer_stmt = info_.analysis_result_.GetTransferStmt();
if (!transfer_stmt.is_empty()) { if (!transfer_stmt.is_empty()) {
bool name_match = false; bool name_match = false;
transfer_stmt.foreach_set([&name_match, stmt_id](const isl::set &s) -> void { auto stmt_id = stmt_id_;
transfer_stmt.foreach_set([&name_match, &stmt_id](const isl::set &s) -> void {
if (s.get_tuple_name() == stmt_id.get_name()) { if (s.get_tuple_name() == stmt_id.get_name()) {
name_match = true; name_match = true;
} }
...@@ -365,8 +353,8 @@ bool IsTransferStmt(Scop &scop, isl::id &stmt_id) { ...@@ -365,8 +353,8 @@ bool IsTransferStmt(Scop &scop, isl::id &stmt_id) {
return false; return false;
} }
Stmt EmitAccessNodeProvide(const Node *node, const VarMap &var_map_tmp, Stmt IslEmitter::EmitAccessNodeProvide(const Node *node, const VarMap &var_map_tmp,
Scop::BufferedFootPrintInfo &buffer_footprint_info) { BufferedFootPrintInfo &buffer_footprint_info) {
const auto provide = static_cast<const Provide *>(node); const auto provide = static_cast<const Provide *>(node);
Expr value = ReplaceLoopVar(var_map_tmp).Mutate(provide->value); Expr value = ReplaceLoopVar(var_map_tmp).Mutate(provide->value);
Array<Expr> args; Array<Expr> args;
...@@ -380,8 +368,8 @@ Stmt EmitAccessNodeProvide(const Node *node, const VarMap &var_map_tmp, ...@@ -380,8 +368,8 @@ Stmt EmitAccessNodeProvide(const Node *node, const VarMap &var_map_tmp,
return Stmt(); return Stmt();
} }
Stmt EmitAccessNodeCall(const Node *node, const VarMap &var_map_tmp, Scop::BufferedFootPrintInfo &buffer_footprint_info, Stmt IslEmitter::EmitAccessNodeCall(const Node *node, const VarMap &var_map_tmp,
bool &is_transfer_stmt, Scop &scop) { BufferedFootPrintInfo &buffer_footprint_info) {
const Call *call = static_cast<const Call *>(node); const Call *call = static_cast<const Call *>(node);
Array<Expr> args; Array<Expr> args;
for (auto iv : call->args) { for (auto iv : call->args) {
...@@ -389,46 +377,35 @@ Stmt EmitAccessNodeCall(const Node *node, const VarMap &var_map_tmp, Scop::Buffe ...@@ -389,46 +377,35 @@ Stmt EmitAccessNodeCall(const Node *node, const VarMap &var_map_tmp, Scop::Buffe
} }
// Not hoisted, emitting just the mapped subscript. // Not hoisted, emitting just the mapped subscript.
if (!buffer_footprint_info.cluster_id) { if (!buffer_footprint_info.cluster_id) {
std::string call_name = call->name;
if (is_transfer_stmt && (std::string::npos == call_name.find("_local_UB"))) {
call_name = call_name + "_local_UB";
Tensor t = scop.FindTensor(call_name);
if (t.defined()) {
return Evaluate::make(Call::make(call->type, call_name, args, call->call_type, t->op, call->value_index));
} else {
LOG(WARNING) << "Call can not found tensor!!! tensor name: " << call_name;
}
}
return Evaluate::make(Call::make(call->type, call->name, args, call->call_type, call->func, call->value_index)); return Evaluate::make(Call::make(call->type, call->name, args, call->call_type, call->func, call->value_index));
} }
return Stmt(); return Stmt();
} }
/*
 * Report whether `access` targets a tensor that appears in the inter-band dependency.
 * Every map of the inter-band dependency's map list is inspected; the function returns true
 * as soon as one map's output tuple id equals the output tuple id of `access`, and false
 * when none matches.
 * Two-column diff text: the left (free-function) side performs the scan only when
 * scop.is_spec_gemm_ is false; the right (IslEmitter member) side has no such guard and
 * reads the dependency from info_.analysis_result_.
 */
bool IsCopyinFromAnotherBand(Scop &scop, isl::multi_aff &access) { bool IslEmitter::IsCopyinFromAnotherBand(isl::multi_aff &access) {
if (!scop.is_spec_gemm_) { for (isl::map inter_band_dependency : info_.analysis_result_.GetInnerBandDependency().get_map_list()) {
for (isl::map inter_band_dependency : scop.data_.inter_band_dependency.get_map_list()) { if (inter_band_dependency.get_tuple_id(isl_dim_out) == access.get_tuple_id(isl_dim_out)) {
if (inter_band_dependency.get_tuple_id(isl_dim_out) == access.get_tuple_id(isl_dim_out)) { return true;
return true;
}
} }
} }
return false; return false;
} }
/*
 * Conditionally replace `ast_to_schedule` with itself minus itself.
 * When either is_transfer_stmt or is_copyin_from_another_band is set, two copies of the
 * piecewise multi-affine are passed to isl_pw_multi_aff_sub and the result is stored back
 * into `ast_to_schedule`; otherwise the argument is left untouched.
 * Two-column diff text: the left side returns void and takes the flags by reference; the
 * right side takes the flags by value and returns the `ast_to_schedule` reference (the
 * single-copy `return` line belongs to the right side only).
 */
void AffSubForAstToSchedule(isl::pw_multi_aff &ast_to_schedule, bool &is_transfer_stmt, isl::pw_multi_aff &AffSubForAstToSchedule(isl::pw_multi_aff &ast_to_schedule, bool is_transfer_stmt,
bool &is_copyin_from_another_band) { bool is_copyin_from_another_band) {
if (is_transfer_stmt || is_copyin_from_another_band) { if (is_transfer_stmt || is_copyin_from_another_band) {
isl_pw_multi_aff *pma1 = ast_to_schedule.copy(); isl_pw_multi_aff *pma1 = ast_to_schedule.copy();
isl_pw_multi_aff *pma2 = ast_to_schedule.copy(); isl_pw_multi_aff *pma2 = ast_to_schedule.copy();
isl_pw_multi_aff *pma = isl_pw_multi_aff_sub(pma1, pma2); isl_pw_multi_aff *pma = isl_pw_multi_aff_sub(pma1, pma2);
ast_to_schedule = isl::manage(pma); ast_to_schedule = isl::manage(pma);
} }
return ast_to_schedule;
} }
Stmt IslEmitter::EmitAccessNodeFromPromoteAcsProvide(Scop &scop, isl::id var, const Node *node, Array<Expr> &args) { Stmt IslEmitter::EmitAccessNodeFromPromoteAcsProvide(isl::id var, const Node *node, Array<Expr> &args) {
const auto provide = static_cast<const Provide *>(node); const auto provide = static_cast<const Provide *>(node);
Tensor t = scop.FindTensor(var); Tensor t = info_.FindTensor(var);
if (scop.CountBufferDefInfo(var)) { if (info_.analysis_result_.CountBufferDefInfo(var)) {
realize_may_def_.insert(var); realize_may_def_.insert(var);
if_map_[var] = cur_if_list_; if_map_[var] = cur_if_list_;
if (cur_if_list_.empty()) { if (cur_if_list_.empty()) {
...@@ -439,10 +416,10 @@ Stmt IslEmitter::EmitAccessNodeFromPromoteAcsProvide(Scop &scop, isl::id var, co ...@@ -439,10 +416,10 @@ Stmt IslEmitter::EmitAccessNodeFromPromoteAcsProvide(Scop &scop, isl::id var, co
return s; return s;
} }
Stmt IslEmitter::EmitAccessNodeFromPromoteAcsCall(Scop &scop, isl::id var, const Node *node, Array<Expr> &args) { Stmt IslEmitter::EmitAccessNodeFromPromoteAcsCall(isl::id var, const Node *node, Array<Expr> &args) {
const Call *call = static_cast<const Call *>(node); const Call *call = static_cast<const Call *>(node);
Tensor t = scop.FindTensor(var); Tensor t = info_.FindTensor(var);
if (scop.CountBufferDefInfo(var)) { if (info_.analysis_result_.CountBufferDefInfo(var)) {
realize_use_.insert(var); realize_use_.insert(var);
if (!if_map_.count(var) || !AOutThanB(if_map_.at(var), cur_if_list_)) { if (!if_map_.count(var) || !AOutThanB(if_map_.at(var), cur_if_list_)) {
realize_use_with_may_def_.insert(var); realize_use_with_may_def_.insert(var);
...@@ -451,25 +428,6 @@ Stmt IslEmitter::EmitAccessNodeFromPromoteAcsCall(Scop &scop, isl::id var, const ...@@ -451,25 +428,6 @@ Stmt IslEmitter::EmitAccessNodeFromPromoteAcsCall(Scop &scop, isl::id var, const
return Evaluate::make(Call::make(call->type, var.get_name(), args, call->call_type, t->op, t->value_index)); return Evaluate::make(Call::make(call->type, var.get_name(), args, call->call_type, t->op, t->value_index));
} }
/// Rewrites `tensor_id` so that its name is cut off just before the first
/// "_local_" marker. Nothing is done when `scop.is_spec_gemm_` is set, or when
/// the name begins with the marker (the prefix would be empty).
void GetNameWithoutLocal(isl::id &tensor_id, Scop &scop) {
  if (scop.is_spec_gemm_) {
    return;
  }
  const std::string full_name = tensor_id.get_name();
  const size_t marker_pos = full_name.find("_local_");
  if (marker_pos == 0) {
    return;
  }
  // When the marker is absent, marker_pos is npos and substr() yields the whole
  // name; the id is still re-created from that name in this context.
  tensor_id = isl::id(tensor_id.ctx(), full_name.substr(0, marker_pos));
}
/// Dispatches to the Provide or the Call access-node emitter.
///
/// `is_Provide` selects EmitAccessNodeProvide; otherwise EmitAccessNodeCall is
/// used, which additionally receives `is_transfer_stmt` and `scop`.
Stmt EmitAccessNodeImpl(const Node *node, const VarMap &var_map_tmp, Scop::BufferedFootPrintInfo &buffer_footprint_info,
                        bool &is_transfer_stmt, Scop &scop, bool is_Provide) {
  return is_Provide ? EmitAccessNodeProvide(node, var_map_tmp, buffer_footprint_info)
                    : EmitAccessNodeCall(node, var_map_tmp, buffer_footprint_info, is_transfer_stmt, scop);
}
Stmt IslEmitter::EmitAccessNode(const std::string &name, const Node *node, const Array<Expr> &tensor_index, Stmt IslEmitter::EmitAccessNode(const std::string &name, const Node *node, const Array<Expr> &tensor_index,
const VarMap &var_map_tmp) { const VarMap &var_map_tmp) {
// Scalars are not hoisted or remapped. // Scalars are not hoisted or remapped.
...@@ -481,40 +439,34 @@ Stmt IslEmitter::EmitAccessNode(const std::string &name, const Node *node, const ...@@ -481,40 +439,34 @@ Stmt IslEmitter::EmitAccessNode(const std::string &name, const Node *node, const
auto build = node_info_map_.at(node_id_).build; auto build = node_info_map_.at(node_id_).build;
auto iterator_map = node_info_map_.at(node_id_).iterator_map; auto iterator_map = node_info_map_.at(node_id_).iterator_map;
CHECK_EQ(scop_.data_.accesses.count(node), 1u) CHECK_EQ(info_.analysis_result_.GetAccessMap().count(node), 1u)
<< "generating tensor " << name << " not in Scop" << node << " not allowed "; << "generating tensor " << name << " not in Scop" << node << " not allowed ";
auto fp_id = scop_.data_.accesses.at(node); auto fp_id = info_.analysis_result_.GetAccessMap().at(node);
Scop::BufferedFootPrintInfo buffer_footprint_info; std::vector<BufferedFootPrintInfo> active_buf_footprint;
std::vector<Scop::BufferedFootPrintInfo> active_buf_footprint; for (const auto &kv : info_.analysis_result_.ActiveBufferFootprints()) {
for (const auto &kv : scop_.ActiveBufferFootprints()) {
if (kv.first.intersect(isl::union_set(Domain())).is_empty()) { if (kv.first.intersect(isl::union_set(Domain())).is_empty()) {
continue; continue;
} }
active_buf_footprint.emplace_back(kv.second); active_buf_footprint.emplace_back(kv.second);
} }
FindBufferFootprintById(buffer_footprint_info, active_buf_footprint, fp_id); BufferedFootPrintInfo buffer_footprint_info = FindBufferFootprintById(active_buf_footprint, fp_id);
bool is_transfer_stmt = false;
is_transfer_stmt = IsTransferStmt(scop_, stmt_id_);
if (node->IsInstance<Provide>()) { if (node->IsInstance<Provide>()) {
if (EmitAccessNodeImpl(node, var_map_tmp, buffer_footprint_info, is_transfer_stmt, scop_, true).defined()) auto stmt = EmitAccessNodeProvide(node, var_map_tmp, buffer_footprint_info);
return EmitAccessNodeImpl(node, var_map_tmp, buffer_footprint_info, is_transfer_stmt, scop_, true); if (stmt.defined()) return stmt;
} }
if (node->IsInstance<Call>()) { if (node->IsInstance<Call>()) {
if (EmitAccessNodeImpl(node, var_map_tmp, buffer_footprint_info, is_transfer_stmt, scop_, false).defined()) auto stmt = EmitAccessNodeCall(node, var_map_tmp, buffer_footprint_info);
return EmitAccessNodeImpl(node, var_map_tmp, buffer_footprint_info, is_transfer_stmt, scop_, false); if (stmt.defined()) return stmt;
} }
auto buf_def = scop_.GetBufferDefInfo(buffer_footprint_info.cluster_id); auto buf_def = info_.analysis_result_.GetBufferDefInfo(buffer_footprint_info.cluster_id);
GetNameWithoutLocal(buf_def.tensor_id, scop_);
auto access = TensorAccessMultAff(buf_def.tensor_id, tensor_index, node_id_); auto access = TensorAccessMultAff(buf_def.tensor_id, tensor_index, node_id_);
bool is_copyin_from_another_band = false; bool is_copyin_from_another_band = IsCopyinFromAnotherBand(access);
is_copyin_from_another_band = IsCopyinFromAnotherBand(scop_, access);
auto memory_hoist = buffer_footprint_info.cluster->ComputeBufferedFootprints(); auto memory_hoist = buffer_footprint_info.cluster->ComputeBufferedFootprints();
if (is_copyin_from_another_band) { if (is_copyin_from_another_band) {
...@@ -523,33 +475,31 @@ Stmt IslEmitter::EmitAccessNode(const std::string &name, const Node *node, const ...@@ -523,33 +475,31 @@ Stmt IslEmitter::EmitAccessNode(const std::string &name, const Node *node, const
// split read-only or write-only input tensor memory_hoists // split read-only or write-only input tensor memory_hoists
// we need to find tensor by name because tensor_id is a fake isl::id // we need to find tensor by name because tensor_id is a fake isl::id
bool is_input_tensor = scop_.FindTensorInOrig(buf_def.tensor_id.name()).defined(); bool is_input_tensor = info_.FindTensorInOrig(buf_def.tensor_id.name()).defined();
if (is_input_tensor && buffer_footprint_info.cluster->foot_print_.should_split) { if (is_input_tensor && buffer_footprint_info.cluster->foot_print_.should_split) {
memory_hoist = buffer_footprint_info.cluster->UnshiftedBufferFootprint(memory_hoist, fp_id); memory_hoist = buffer_footprint_info.cluster->UnshiftedBufferFootprint(memory_hoist, fp_id);
} }
memory_hoist = memory_hoist.set_tuple_id(isl_dim_out, buffer_footprint_info.cluster_id); memory_hoist = memory_hoist.set_tuple_id(isl_dim_out, buffer_footprint_info.cluster_id);
auto schedule = isl::map::from(buffer_footprint_info.outer_schedule.intersect_domain(this->Domain())); auto schedule = isl::map::from(buffer_footprint_info.outer_schedule.intersect_domain(Domain()));
CHECK(schedule.is_single_valued()) << schedule << " is not single-valued schedule"; CHECK(schedule.is_single_valued()) << schedule << " is not single-valued schedule";
auto ast_to_schedule = isl::pw_multi_aff(schedule).pullback(iterator_map); auto ast_to_schedule = isl::pw_multi_aff(schedule).pullback(iterator_map);
AffSubForAstToSchedule(ast_to_schedule, is_transfer_stmt, is_copyin_from_another_band); ast_to_schedule = AffSubForAstToSchedule(ast_to_schedule, IsTransferStmt(), is_copyin_from_another_band);
auto ast_to_original = isl::pw_multi_aff(access).pullback(iterator_map); auto ast_to_original = isl::pw_multi_aff(access).pullback(iterator_map);
auto ast_to_scheduled_original = ast_to_schedule.range_product(ast_to_original); auto ast_to_scheduled_original = ast_to_schedule.range_product(ast_to_original);
auto ast_to_hoisted = isl::pw_multi_aff(memory_hoist).pullback(ast_to_scheduled_original); auto ast_to_hoisted = isl::pw_multi_aff(memory_hoist).pullback(ast_to_scheduled_original);
auto hoist_acs = build.access_from(ast_to_hoisted); auto hoist_acs = build.access_from(ast_to_hoisted);
if (auto op = hoist_acs.as<isl::ast_expr_op>()) { if (auto op = hoist_acs.as<isl::ast_expr_op>()) {
if (auto access_ = op.as<isl::ast_expr_op_access>()) { if (op.as<isl::ast_expr_op_access>()) {
Array<Expr> args; Array<Expr> args;
for (int i = 1; i < static_cast<int>(op.get_n_arg()); ++i) { for (int i = 1; i < static_cast<int>(op.get_n_arg()); ++i) {
args.push_back(Interpret(op.get_arg(i))); args.push_back(Interpret(op.get_arg(i)));
} }
if (node->IsInstance<Provide>()) if (node->IsInstance<Provide>())
return IslEmitter::EmitAccessNodeFromPromoteAcsProvide(scop_, op.get_arg(0).as<isl::ast_expr_id>().get_id(), return EmitAccessNodeFromPromoteAcsProvide(op.get_arg(0).as<isl::ast_expr_id>().get_id(), node, args);
node, args);
if (node->IsInstance<Call>()) if (node->IsInstance<Call>())
return IslEmitter::EmitAccessNodeFromPromoteAcsCall(scop_, op.get_arg(0).as<isl::ast_expr_id>().get_id(), node, return EmitAccessNodeFromPromoteAcsCall(op.get_arg(0).as<isl::ast_expr_id>().get_id(), node, args);
args);
} }
} }
return Evaluate::make(Expr("todo EmitAst")); return Evaluate::make(Expr("todo EmitAst"));
...@@ -569,7 +519,7 @@ Stmt IslEmitter::EmitUserStmtContent(const Evaluate *eva_node) { ...@@ -569,7 +519,7 @@ Stmt IslEmitter::EmitUserStmtContent(const Evaluate *eva_node) {
auto im2col = Call::make(call->type, call->name, args, call->call_type); auto im2col = Call::make(call->type, call->name, args, call->call_type);
Stmt res = Evaluate::make(im2col); Stmt res = Evaluate::make(im2col);
// add AttrStmt to im2col // add AttrStmt to im2col
for (const auto &item : scop_.data_.vecs) { for (const auto &item : info_.analysis_result_.GetBufferBindVec()) {
Expr replaced = ReplaceLoopVar(var_map_).Mutate(item.second); Expr replaced = ReplaceLoopVar(var_map_).Mutate(item.second);
res = AttrStmt::make(item.first, air::ir::attr::buffer_bind_scope, replaced, res); res = AttrStmt::make(item.first, air::ir::attr::buffer_bind_scope, replaced, res);
} }
...@@ -600,8 +550,9 @@ class SubstituteByNameMutator : public IRMutator { ...@@ -600,8 +550,9 @@ class SubstituteByNameMutator : public IRMutator {
* So, we need to sink the copy out statement into the innermost "if", * So, we need to sink the copy out statement into the innermost "if",
* i.e., copy out immediately after each computation. * i.e., copy out immediately after each computation.
*/ */
static Stmt GenerateCopyOut(const Scop &scop, const Provide *original, const Provide *hoisted, const VarMap &var_map) { static Stmt GenerateCopyOut(const ScopInfo &info, const Provide *original, const Provide *hoisted,
auto call_type = scop.GetDtypeOf(hoisted->func->func_name()); const VarMap &var_map) {
auto call_type = info.GetDtypeOf(hoisted->func->func_name());
Expr call_expr = Call::make(call_type, hoisted->func->func_name(), hoisted->args, Call::CallType::Halide, Expr call_expr = Call::make(call_type, hoisted->func->func_name(), hoisted->args, Call::CallType::Halide,
hoisted->func, hoisted->value_index); hoisted->func, hoisted->value_index);
Array<Expr> new_args; Array<Expr> new_args;
...@@ -621,8 +572,8 @@ Stmt IslEmitter::EmitUserStmtContent(const Provide *provide_node) { ...@@ -621,8 +572,8 @@ Stmt IslEmitter::EmitUserStmtContent(const Provide *provide_node) {
Expr value = EmitExpr(f, var_map_).Mutate(provide_node->value); Expr value = EmitExpr(f, var_map_).Mutate(provide_node->value);
Stmt provide_stmt = Provide::make(provide_new->func, provide_new->value_index, value, provide_new->args); Stmt provide_stmt = Provide::make(provide_new->func, provide_new->value_index, value, provide_new->args);
if (scop_.conditional_write_buffer_footprints_.count(write_tensor)) { if (info_.analysis_result_.GetConditionalWriteBufferFootprints().count(write_tensor)) {
return Block::make(provide_stmt, GenerateCopyOut(scop_, provide_node, provide_new, var_map_)); return Block::make(provide_stmt, GenerateCopyOut(info_, provide_node, provide_new, var_map_));
} }
return provide_stmt; return provide_stmt;
} }
...@@ -688,11 +639,11 @@ Stmt IslEmitter::EmitUserStmt(const isl::ast_node_user &node) { ...@@ -688,11 +639,11 @@ Stmt IslEmitter::EmitUserStmt(const isl::ast_node_user &node) {
isl::ast_expr_op usr_expr = node.get_expr().as<isl::ast_expr_op>(); isl::ast_expr_op usr_expr = node.get_expr().as<isl::ast_expr_op>();
stmt_id_ = usr_expr.get_arg(0).as<isl::ast_expr_id>().get_id(); stmt_id_ = usr_expr.get_arg(0).as<isl::ast_expr_id>().get_id();
node_id_ = node.get_annotation(); node_id_ = node.get_annotation();
const Node *stmt_node = scop_.data_.statements.at(stmt_id_); const Node *stmt_node = info_.analysis_result_.GetStatementMap().at(stmt_id_);
CHECK(stmt_node); CHECK(stmt_node);
// compute VarMap to replace old iterators // compute VarMap to replace old iterators
auto build = node_info_map_.at(node_id_).build; auto build = node_info_map_.at(node_id_).build;
auto tuple = scop_.data_.domains.at(stmt_id_).tuple; auto tuple = info_.analysis_result_.GetOperatorDomainMap().at(stmt_id_).tuple;
auto iterator_map = node_info_map_.at(node_id_).iterator_map; auto iterator_map = node_info_map_.at(node_id_).iterator_map;
var_map_.clear(); var_map_.clear();
...@@ -701,41 +652,51 @@ Stmt IslEmitter::EmitUserStmt(const isl::ast_node_user &node) { ...@@ -701,41 +652,51 @@ Stmt IslEmitter::EmitUserStmt(const isl::ast_node_user &node) {
auto isl_expr = build.expr_from(iterator_map.get_pw_aff(i)); auto isl_expr = build.expr_from(iterator_map.get_pw_aff(i));
Expr halide_new_iter = Interpret(isl_expr); Expr halide_new_iter = Interpret(isl_expr);
var_map_.emplace(isl_old_iter, halide_new_iter); var_map_.emplace(isl_old_iter, halide_new_iter);
std::string replace_id = isl_old_iter.get_name() + "_";
std::vector<const Variable *> vec = ExtractIterfromExpr().Run(halide_new_iter);
for (auto item : vec) {
std::string new_name = item->name_hint;
size_t pos = new_name.find(scop_.iter_prefix_);
if (pos != std::string::npos) {
new_name = new_name.replace(pos, scop_.iter_prefix_.size(), replace_id);
iters_old_name_.emplace(item, item->name_hint);
iters_new_name_.emplace(item, new_name);
}
}
} }
VarMap vmap = var_map_;
stmt_var_map_.emplace(stmt_id_, vmap);
return EmitUserStmtContent(stmt_node); return EmitUserStmtContent(stmt_node);
} }
Stmt IslEmitter::EmitStmt(const isl::ast_node_user &node) { return EmitUserStmt(node); } Stmt IslEmitter::EmitStmt(const isl::ast_node_user &node) {
CHECK(node.get_expr().isa<isl::ast_expr_op>());
isl::ast_expr_op usr_expr = node.get_expr().as<isl::ast_expr_op>();
CHECK(usr_expr);
auto stmt_id = usr_expr.get_arg(0).as<isl::ast_expr_id>().get_id();
if (info_.IsRead(stmt_id)) {
return Evaluate::make(Expr("todo EmitRead"));
}
if (info_.IsWrite(stmt_id)) {
return Evaluate::make(Expr("todo EmitWrite"));
}
return EmitUserStmt(node);
}
Stmt IslEmitter::EmitAst(const isl::ast_node &node) { Stmt IslEmitter::EmitAst(const isl::ast_node &node) {
Stmt s;
std::string info;
if (auto for_node = node.as<isl::ast_node_for>()) { if (auto for_node = node.as<isl::ast_node_for>()) {
return EmitFor(for_node); info = "[FOR_NODE]";
s = EmitFor(for_node);
} else if (auto if_node = node.as<isl::ast_node_if>()) { } else if (auto if_node = node.as<isl::ast_node_if>()) {
return EmitIf(if_node); info = "[IF_NODE]";
s = EmitIf(if_node);
} else if (auto block_node = node.as<isl::ast_node_block>()) { } else if (auto block_node = node.as<isl::ast_node_block>()) {
return EmitBlock(block_node); info = "[BLOCK_NODE]";
s = EmitBlock(block_node);
} else if (auto mark_node = node.as<isl::ast_node_mark>()) { } else if (auto mark_node = node.as<isl::ast_node_mark>()) {
return EmitMark(mark_node); info = "[MARK_NODE]";
s = EmitMark(mark_node);
} else if (auto user_node = node.as<isl::ast_node_user>()) { } else if (auto user_node = node.as<isl::ast_node_user>()) {
return EmitStmt(user_node); info = "[USER_NODE]";
s = EmitStmt(user_node);
} else { } else {
LOG(FATAL) << "NYI " << node << "\n"; s = Evaluate::make(Expr("todo EmitAst"));
} }
return Evaluate::make(Expr("todo EmitAst")); if (PRINT_EMMITER) {
LOG(INFO) << ">>>>>>>>>>>>INPUT AST_NODE" << info << "<<<<<<<<<<<<<<\n" << node;
LOG(INFO) << ">>>>>>>>>>>>OUTPUT STMT<<<<<<<<<<<<\n" << s;
}
return s;
} }
Stmt IslEmitter::Emit(const isl::ast_node &node) { return EmitAst(node); } Stmt IslEmitter::Emit(const isl::ast_node &node) { return EmitAst(node); }
......
...@@ -19,11 +19,9 @@ ...@@ -19,11 +19,9 @@
#include <tvm/expr.h> #include <tvm/expr.h>
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include "ir_pass.h" #include "ir_pass.h"
#include "poly/isl.h" #include "poly/scop_info.h"
#include "poly/scop.h"
namespace akg { namespace akg {
namespace ir { namespace ir {
...@@ -47,29 +45,31 @@ class IslEmitter { ...@@ -47,29 +45,31 @@ class IslEmitter {
Expr InterpretBinaryOp(const isl::ast_expr_op &e); Expr InterpretBinaryOp(const isl::ast_expr_op &e);
public: public:
explicit IslEmitter(Scop &s_, const NodeInfoRepo &n_, const isl::id_list &i_) explicit IslEmitter(ScopInfo &info, const NodeInfoRepo &n, const isl::id_list &i)
: scop_(s_), node_info_map_(n_), iter_names_(i_) {} : info_(info), node_info_map_(n), iter_names_(i) {}
virtual ~IslEmitter() = default; virtual ~IslEmitter() = default;
/// Interpret isl::ast_expr to Halide Expr // Interpret isl::ast_expr to Halide Expr
//@{
Expr Interpret(const isl::ast_expr &e); Expr Interpret(const isl::ast_expr &e);
//@}
// helper functions, which may can be moved into a separated class // helper functions, which may can be moved into a separated class
isl::space GetDomainSpace(const isl::id &stmt_id); isl::space GetDomainSpace(const isl::id &stmt_id);
isl::space GetSpace(const isl::id &tensor_id, const Array<Expr> &tensor_index, const isl::id &stmt_id); isl::space GetSpace(const isl::id &tensor_id, const Array<Expr> &tensor_index, const isl::id &stmt_id);
isl::multi_aff TensorAccessMultAff(const isl::id &tensor_id, const Array<Expr> &subscripts, const isl::id &stmt_id);
isl::set Domain() const { isl::set Domain() const {
auto iterator_map = node_info_map_.at(node_id_).iterator_map; auto iterator_map = node_info_map_.at(node_id_).iterator_map;
return isl::map::from(iterator_map).range(); return isl::map::from(iterator_map).range();
} }
Stmt EmitAccessNode(const std::string &name, const Node *node, const Array<Expr> &tensor_index, Stmt EmitAccessNode(const std::string &name, const Node *node, const Array<Expr> &tensor_index,
const VarMap &var_map_tmp); const VarMap &var_map_tmp);
Stmt EmitAccessNodeFromPromoteAcsProvide(Scop &scop, isl::id var, const Node *node, Array<Expr> &args); Stmt EmitAccessNodeFromPromoteAcsProvide(isl::id var, const Node *node, Array<Expr> &args);
Stmt EmitAccessNodeFromPromoteAcsCall(Scop &scop, isl::id var, const Node *node, Array<Expr> &args); Stmt EmitAccessNodeFromPromoteAcsCall(isl::id var, const Node *node, Array<Expr> &args);
/// Virtual emitters for different type node Stmt EmitAccessNodeProvide(const Node *node, const VarMap &var_map_tmp, BufferedFootPrintInfo &buffer_fp_info);
//@{ virtual Stmt EmitAccessNodeCall(const Node *node, const VarMap &var_map_tmp, BufferedFootPrintInfo &buffer_fp_info);
virtual isl::multi_aff TensorAccessMultAff(isl::id &tensor_id, const Array<Expr> &subscripts, const isl::id &stmt_id);
virtual bool IsTransferStmt();
virtual bool IsCopyinFromAnotherBand(isl::multi_aff &access);
// Virtual emitters for different type node
virtual Stmt Emit(const isl::ast_node &node); virtual Stmt Emit(const isl::ast_node &node);
virtual Stmt EmitFor(const isl::ast_node_for &node); virtual Stmt EmitFor(const isl::ast_node_for &node);
virtual Stmt EmitIf(const isl::ast_node_if &node); virtual Stmt EmitIf(const isl::ast_node_if &node);
...@@ -84,7 +84,12 @@ class IslEmitter { ...@@ -84,7 +84,12 @@ class IslEmitter {
virtual Stmt EmitUserStmtContent(const IfThenElse *if_node); virtual Stmt EmitUserStmtContent(const IfThenElse *if_node);
virtual Stmt EmitUserStmtContent(const For *for_node); virtual Stmt EmitUserStmtContent(const For *for_node);
virtual Stmt EmitUserStmtContent(const Block *block_node); virtual Stmt EmitUserStmtContent(const Block *block_node);
//@}
// Loop isl iters info
virtual void PushIter(const Variable *iter);
virtual void PopIter(const Variable *iter);
bool FindIter(const Variable *iter) const;
const Variable *GetIterByName(const std::string &id) const;
std::unordered_set<isl::id, isl::IslIdIslHash> realize_use_; std::unordered_set<isl::id, isl::IslIdIslHash> realize_use_;
std::unordered_set<isl::id, isl::IslIdIslHash> realize_use_with_may_def_; std::unordered_set<isl::id, isl::IslIdIslHash> realize_use_with_may_def_;
...@@ -93,28 +98,16 @@ class IslEmitter { ...@@ -93,28 +98,16 @@ class IslEmitter {
std::unordered_set<isl::id, isl::IslIdIslHash> realize_out_; std::unordered_set<isl::id, isl::IslIdIslHash> realize_out_;
std::unordered_set<isl::id, isl::IslIdIslHash> global_realize_out_; std::unordered_set<isl::id, isl::IslIdIslHash> global_realize_out_;
/// Scop ScopInfo &info_;
Scop &scop_;
/// Node information map including /// Node information map including
const NodeInfoRepo &node_info_map_; const NodeInfoRepo &node_info_map_;
/// Loop isl iters info
//@{
/// Loop isl iters list /// Loop isl iters list
isl::id_list iter_names_; isl::id_list iter_names_;
/// Loop declared halide iters /// Loop declared halide iters
std::vector<const Variable *> iters_; std::vector<const Variable *> iters_;
virtual void PushIter(const Variable *iter);
virtual void PopIter(const Variable *iter);
bool FindIter(const Variable *iter) const;
const Variable *GetIterByName(const std::string &id) const;
//@}
std::map<const Variable *, std::string> iters_old_name_;
std::map<const Variable *, std::string> iters_new_name_;
// current ast node id // current ast node id
isl::id node_id_; isl::id node_id_;
// current stmt id // current stmt id
...@@ -125,7 +118,6 @@ class IslEmitter { ...@@ -125,7 +118,6 @@ class IslEmitter {
// emit in if // emit in if
std::vector<const Node *> cur_if_list_; std::vector<const Node *> cur_if_list_;
std::unordered_map<isl::id, std::vector<const Node *>, isl::IslIdIslHash> if_map_; std::unordered_map<isl::id, std::vector<const Node *>, isl::IslIdIslHash> if_map_;
std::unordered_map<isl::id, VarMap, isl::IslIdIslHash> stmt_var_map_;
}; };
class ExtractIterfromExpr : public air::ir::IRVisitor { class ExtractIterfromExpr : public air::ir::IRVisitor {
...@@ -146,16 +138,23 @@ class ExtractIterfromExpr : public air::ir::IRVisitor { ...@@ -146,16 +138,23 @@ class ExtractIterfromExpr : public air::ir::IRVisitor {
std::vector<const Variable *> vec_; std::vector<const Variable *> vec_;
}; };
void FindBufferFootprintById(Scop::BufferedFootPrintInfo &buffer_footprint_info, class ReplaceLoopVar : public air::ir::IRMutator {
std::vector<Scop::BufferedFootPrintInfo> active_buffer_fp, isl::id id); public:
void GetNameWithoutLocal(isl::id &tensor_id, Scop &scop); explicit ReplaceLoopVar(VarMap v_) : var_map(std::move(v_)) {}
bool IsTransferStmt(Scop &scop, isl::id &stmt_id); ~ReplaceLoopVar() override = default;
bool IsCopyinFromAnotherBand(Scop &scop, isl::multi_aff &access); Expr Mutate_(const Variable *op, const Expr &e) final {
void AffSubForAstToSchedule(isl::pw_multi_aff &ast_to_schedule, bool &is_transfer_stmt, for (auto &i : var_map) {
bool &is_copyin_from_another_band); if (op->name_hint == i.first.get_name()) {
Stmt EmitAccessNodeProvide(const Node *node, const VarMap &var_map_tmp, Scop::BufferedFootPrintInfo &buffer_fp_info); return i.second;
Stmt EmitAccessNodeCall(const Node *node, const VarMap &var_map_tmp, Scop::BufferedFootPrintInfo &buffer_fp_info, }
bool &is_transfer_stmt, Scop &scop); }
return e;
}
private:
VarMap var_map;
};
} // namespace poly } // namespace poly
} // namespace ir } // namespace ir
} // namespace akg } // namespace akg
......
...@@ -13,21 +13,19 @@ ...@@ -13,21 +13,19 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#ifndef POLY_RESCHEDULE_H_ #include "poly/pass_info.h"
#define POLY_RESCHEDULE_H_
#pragma once #include <tvm/ir_visitor.h>
#include "poly/transform.h" #include <tvm/operation.h>
#include <isl/constraint.h>
#include <climits>
#include <fstream>
#include <queue>
#include <cmath>
namespace akg { namespace akg {
namespace ir { namespace ir {
namespace poly { namespace poly {} // namespace poly
isl::schedule_node ReorderFilters(const isl::schedule_node &node,
const std::unordered_map<size_t, size_t> &old_to_new_map);
} // namespace poly
} // namespace ir } // namespace ir
} // namespace akg } // namespace akg
#endif // POLY_RESCHEDULE_H_
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef POLY_PASS_INFO_H_
#define POLY_PASS_INFO_H_
#include <vector>
#include <map>
#include <unordered_map>
#include "isl.h"
namespace akg {
namespace ir {
namespace poly {
using ReduceStmtMap = std::unordered_map<isl::id, std::vector<std::string>, isl::IslIdIslHash>;
class Dependency {
private:
isl::id start_node_id_;
isl::id end_node_id_;
int64_t edge_weight_;
public:
Dependency(const isl::id start_node_id, const isl::id end_node_id, const int64_t edge_weight)
: start_node_id_(start_node_id), end_node_id_(end_node_id), edge_weight_(edge_weight) {}
~Dependency() {}
isl::id GetStartNode() { return start_node_id_; }
isl::id GetEndNode() { return end_node_id_; }
int64_t GetEdgeWeight() const { return edge_weight_; }
};
/// Bookkeeping for the schedule-transform passes.
///
/// Every field is public and is read/written directly by its users. Fields
/// without an in-class initializer are default-constructed.
class PassInfo {
 public:
  PassInfo() = default;
  ~PassInfo() = default;

  // NOTE(review): the per-field meanings below are inferred from the names only;
  // confirm against the passes that read and write them.

  // Statement grouping state.
  bool has_grouped_{false};
  bool tile_check_coincident_{false};
  std::unordered_map<isl::id, isl::union_set_list, isl::IslIdIslHash> group_filter_map_;
  std::vector<Dependency> dependency_list_;
  isl::union_pw_multi_aff group_upma_;

  // Scheduling constraints and dependences; coincident_ is the only flag that starts out true.
  isl::schedule_constraints constraints_;
  bool coincident_{true};
  isl::union_map dependences_;
  isl::union_map orig_dependences_;

  isl::union_set transfer_stmt_;
  // Maps a statement id to a list of strings (see ReduceStmtMap).
  ReduceStmtMap reduce_stmts_;

  std::map<std::string, int> invariant_state_;
  bool has_invariant_dependence_{false};
  bool restart_{false};
};
} // namespace poly
} // namespace ir
} // namespace akg
#endif // POLY_PASS_INFO_H_
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef POLY_PASS_MGR_STRATEGY_H_
#define POLY_PASS_MGR_STRATEGY_H_
#include "poly/schedule_pass.h"
#include "poly/dump_log.h"
#include "poly/schedule_pass/init_schedule.h"
#include "poly/schedule_pass/compute_schedule.h"
namespace akg {
namespace ir {
namespace poly {
class PassMgrStrategy {
public:
explicit PassMgrStrategy(ScopInfo &scop_info) : scop_info_(scop_info) {}
void RegisterPass(std::shared_ptr<SchedulePass> pass) {
CHECK(pass);
passes_.emplace_back(std::move(pass));
}
void RegisterNormalizationPasses() { RegisterPass(std::make_shared<InitSchedule>(pass_info_, scop_info_)); }
void RegisterSchedulingPasses() { RegisterPass(std::make_shared<ComputeSchedule>(pass_info_, scop_info_)); }
virtual void RegisterTilingPasses() = 0; // each backend has different achievement
virtual void RegisterMemPromPasses() = 0; // each backend has different achievement
virtual void RegisterPasses() = 0;
const std::vector<std::shared_ptr<SchedulePass>> &GetPasses() const { return passes_; };
virtual ~PassMgrStrategy() = default;
ScopInfo &scop_info_;
PassInfo pass_info_;
protected:
std::vector<std::shared_ptr<SchedulePass>> passes_;
};
} // namespace poly
} // namespace ir
} // namespace akg
#endif // POLY_PASS_MGR_STRATEGY_H_
\ No newline at end of file
...@@ -13,15 +13,8 @@ ...@@ -13,15 +13,8 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <memory>
#include "ir_pass.h"
#include "poly/scop.h" #include "poly/scop.h"
#include "pass/utils.h"
namespace akg { namespace akg {
namespace ir { namespace ir {
/*! /*!
...@@ -31,63 +24,72 @@ class Poly { ...@@ -31,63 +24,72 @@ class Poly {
public: public:
Poly() : isl_ctx_(isl::ctx(isl_ctx_alloc())) {} Poly() : isl_ctx_(isl::ctx(isl_ctx_alloc())) {}
~Poly() noexcept {
scop_.reset();
// scop must be deconstructed before isl_ctx is deconstructed
isl_ctx_free(isl_ctx_.get());
}
void Run(const Stmt &stmt, const Map<Tensor, Buffer> &extern_buffer, const Map<std::string, NodeRef> &attrs, void Run(const Stmt &stmt, const Map<Tensor, Buffer> &extern_buffer, const Map<std::string, NodeRef> &attrs,
const bool is_spec_gemm, bool is_tuning, bool is_dynamic) { const bool is_spec_gemm, bool is_tuning, bool is_dynamic) {
stmt_ = stmt; stmt_ = stmt;
scop_.reset(new poly::Scop(Simplify_cce(stmt_), extern_buffer, isl_ctx_, is_spec_gemm)); scop_.reset(new poly::Scop(Simplify_cce(stmt_), isl_ctx_));
CHECK(scop_ != nullptr); CHECK(scop_ != nullptr);
scop_->ParseUserConfig(attrs, extern_buffer, is_spec_gemm, is_tuning, is_dynamic);
scop_->SetAttrs(attrs);
scop_->is_dynamic_ = is_dynamic;
// generate isl schedule from Halide
std::chrono::high_resolution_clock::time_point timer_start; std::chrono::high_resolution_clock::time_point timer_start;
// generate isl schedule from Halide
TIMER_START; TIMER_START;
isl::schedule sch = scop_->GenIsl(); isl::schedule sch = scop_->GenIsl();
TIMER_SHOW("GenIsl", std::string(is_spec_gemm ? "_specgemm" : "")); TIMER_SHOW("GenIsl", std::string(is_spec_gemm ? "_specgemm" : ""));
// transform isl schedule with coincidence constraints // isl schedule transform
isl::schedule scht = scop_->Transform(sch, true, is_tuning); TIMER_START;
if (is_tuning) return; isl::schedule sched = scop_->Transform(sch);
TIMER_SHOW("Transform", std::string(is_spec_gemm ? "_specgemm" : ""));
if (scht.get() == sch.get()) {
// transform failed, redo transform without coincidence constraints
scht = scop_->Transform(sch, false);
}
// generate Halide from isl schedule // generate Halide from isl schedule
stmt_ = scop_->GenHalide(scht); TIMER_START;
stmt_ = scop_->GenHalide(sched);
TIMER_SHOW("GenHalide", std::string(is_spec_gemm ? "_specgemm" : ""));
if (is_dynamic) stmt_ = RestoreCombinedParams(stmt_, scop_->info_);
if (is_tuning) {
spaces_ = GenerateTilingSpace(sched, scop_->info_, stmt_, scop_->info_.user_config_.GetDumpTuningLevel());
return;
}
// optimize post poly Halide IR for Davinci // optimize post poly Halide IR for Davinci
if (scop_->enable_feature_library_ || scop_->optimize_for_davinci_) { if (scop_->info_.user_config_.GetEnableFeatureLib() || scop_->info_.user_config_.GetOptimizeForDavinci()) {
stmt_ = poly::OptimizeHalide(stmt_, !scop_->params_.empty()); stmt_ = poly::DavinciHalideOptimizer(stmt_, !scop_->info_.user_config_.GetParams().empty());
} }
gen_empty_tiling = scop_->is_tiled_; gen_empty_tiling = scop_->info_.analysis_result_.GetIsTiled();
} }
~Poly() noexcept { Stmt GetStmt() { return stmt_; }
scop_.reset();
// scop must be deconstructed before isl_ctx is deconstructed
isl_ctx_free(isl_ctx_.get());
}
Stmt getstmt() { return stmt_; } NodeRef GetSpaces() { return spaces_; }
bool gen_empty_tiling{false};
Array<Var> getTilingParams() { Array<Var> GetTilingParams() {
CHECK(scop_ != nullptr); CHECK(scop_ != nullptr);
Array<Var> tiling_params_array; Array<Var> tiling_params_array;
if (gen_empty_tiling) return tiling_params_array; if (gen_empty_tiling) return tiling_params_array;
std::unordered_set<Var, NodeHash, NodeEqual> tiling_params; std::unordered_set<Var, NodeHash, NodeEqual> tiling_params;
for (const auto &kv : scop_->param_tiling_map_) { auto param_tiling_map = scop_->info_.user_config_.GetParamTilingMap();
for (const auto &kv : param_tiling_map) {
GatherVars(kv.second, &tiling_params); GatherVars(kv.second, &tiling_params);
} }
for (const auto &param : tiling_params) tiling_params_array.push_back(param); for (const auto &param : tiling_params) tiling_params_array.push_back(param);
return tiling_params_array; return tiling_params_array;
} }
NodeRef getspaces() { void GatherVars(const Expr expr, std::unordered_set<Var, air::NodeHash, air::NodeEqual> *vset) {
CHECK(scop_ != nullptr); PostOrderVisit(expr, [&vset](const NodeRef &node) {
return scop_->spaces_; if (node.as<Variable>()) {
vset->insert(Downcast<Var>(node));
}
});
} }
private: private:
...@@ -96,6 +98,8 @@ class Poly { ...@@ -96,6 +98,8 @@ class Poly {
// and we need to ensure that they are deconstructed before the isl_ctx is freed. // and we need to ensure that they are deconstructed before the isl_ctx is freed.
isl::ctx isl_ctx_; isl::ctx isl_ctx_;
Stmt stmt_; Stmt stmt_;
NodeRef spaces_;
bool gen_empty_tiling{false};
}; };
/// Interface for lower pass /// Interface for lower pass
...@@ -103,14 +107,14 @@ Array<NodeRef> AutoPoly(const Stmt &stmt, const Map<Tensor, Buffer> &extern_buff ...@@ -103,14 +107,14 @@ Array<NodeRef> AutoPoly(const Stmt &stmt, const Map<Tensor, Buffer> &extern_buff
const Map<std::string, NodeRef> &attrs, const bool is_specgemm, const bool is_dynamic) { const Map<std::string, NodeRef> &attrs, const bool is_specgemm, const bool is_dynamic) {
Poly poly; Poly poly;
poly.Run(stmt, extern_buffer, attrs, is_specgemm, false, is_dynamic); poly.Run(stmt, extern_buffer, attrs, is_specgemm, false, is_dynamic);
return Array<NodeRef>({poly.getstmt(), poly.getTilingParams()}); return Array<NodeRef>({poly.GetStmt(), poly.GetTilingParams()});
} }
NodeRef GenTuningSpace(const Stmt &stmt, const Map<Tensor, Buffer> &extern_buffer, NodeRef GenTuningSpace(const Stmt &stmt, const Map<Tensor, Buffer> &extern_buffer,
const Map<std::string, NodeRef> &attrs, const bool is_specgemm) { const Map<std::string, NodeRef> &attrs, const bool is_specgemm) {
Poly poly; Poly poly;
poly.Run(stmt, extern_buffer, attrs, is_specgemm, true, false); poly.Run(stmt, extern_buffer, attrs, is_specgemm, true, false);
return poly.getspaces(); return poly.GetSpaces();
} }
} // namespace ir } // namespace ir
} // namespace akg } // namespace akg
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#include "poly/poly_util.h" #include "poly/poly_util.h"
namespace akg { namespace akg {
...@@ -120,6 +121,65 @@ Stmt PeelOuterLetStmt(const Stmt &s, std::vector<Stmt> &outer_stmts) { ...@@ -120,6 +121,65 @@ Stmt PeelOuterLetStmt(const Stmt &s, std::vector<Stmt> &outer_stmts) {
return body; return body;
} }
/*
 * Inspect an affine expression and report, through the reference parameters:
 *   offset   - the numerator of its constant term (as returned by get_num_si());
 *   num_vars - the number of input dimensions whose coefficient numerator is non-zero.
 */
void GetAffOffsetAndNumVars(const isl::aff &aff, int &offset, int &num_vars) {
  offset = aff.get_constant_val().get_num_si();
  num_vars = 0;

  const int dim = isl_aff_dim(aff.get(), isl_dim_in);
  CHECK_GE(dim, 0);

  int pos = 0;
  while (pos < dim) {
    isl_val *coefficient = isl_aff_get_coefficient_val(aff.get(), isl_dim_in, pos);
    const int numerator = isl_val_get_num_si(coefficient);
    // The coefficient handle is released here; only the extracted numerator is used below.
    static_cast<void>(isl_val_free(coefficient));
    if (numerator != 0) {
      ++num_vars;
    }
    ++pos;
  }
}
/*
 * Return true when the isl::aff looks like { [i0, i1, i2, i3, i4] -> [(-64 + i2)] },
 * i.e. exactly one input dimension has a non-zero coefficient and the constant
 * offset is non-zero.
 */
bool IsAffVarPlusOffset(const isl::aff &aff) {
  int constant_term = 0;
  int used_dims = 0;
  GetAffOffsetAndNumVars(aff, constant_term, used_dims);
  if (constant_term == 0) {
    return false;
  }
  return used_dims == 1;
}
/*
 * Return true when the isl::aff looks like { [i0, i1, i2, i3, i4] -> [(64)] },
 * i.e. no input dimension has a non-zero coefficient and the constant is non-zero.
 */
bool IsAffNonZeroConst(const isl::aff &aff) {
  int constant_term = 0;
  int used_dims = 0;
  GetAffOffsetAndNumVars(aff, constant_term, used_dims);
  if (constant_term == 0) {
    return false;
  }
  return used_dims == 0;
}
/*
 * Build a schedule for the domain of "node" by walking its ancestors from the
 * outermost generation (tree depth) inwards:
 *   - band nodes with at least one member append their partial schedule,
 *   - filter nodes restrict the domain,
 *   - extension nodes add their (reversed) extension, limited to the current range.
 * Generation 0 is visited only when "use_node" is true; otherwise the walk stops
 * at generation 1.
 */
isl::union_map LocalScheduleImpl(const isl::schedule_node &node, bool use_node) {
  const int tree_depth = node.get_tree_depth();
  const int last_generation = use_node ? 0 : 1;

  isl::union_map schedule = isl::union_map::from_domain(node.get_domain());
  for (int generation = tree_depth; generation >= last_generation; --generation) {
    isl::schedule_node current = node.ancestor(generation);

    if (auto band = current.as<isl::schedule_node_band>()) {
      if (band.n_member() > 0) {
        schedule = schedule.flat_range_product(band.get_partial_schedule_union_map());
      }
      continue;
    }
    if (auto filter = current.as<isl::schedule_node_filter>()) {
      schedule = schedule.intersect_domain(filter.get_filter());
      continue;
    }
    if (auto extension = current.as<isl::schedule_node_extension>()) {
      schedule = schedule.unite(extension.get_extension().reverse().intersect_range(schedule.range()));
    }
  }
  return schedule;
}
// Schedule accumulated over the ancestors of "node" only (LocalScheduleImpl with use_node == false).
isl::union_map ShortSchedule(const isl::schedule_node &node) {
  const bool include_node = false;
  return LocalScheduleImpl(node, include_node);
}
// Schedule accumulated over the ancestors of "node" plus generation 0 (LocalScheduleImpl with use_node == true).
isl::union_map LocalSchedule(const isl::schedule_node &node) {
  const bool include_node = true;
  return LocalScheduleImpl(node, include_node);
}
} // namespace poly } // namespace poly
} // namespace ir } // namespace ir
} // namespace akg } // namespace akg
...@@ -15,12 +15,9 @@ ...@@ -15,12 +15,9 @@
*/ */
#ifndef POLY_UTIL_H_ #ifndef POLY_UTIL_H_
#define POLY_UTIL_H_ #define POLY_UTIL_H_
#pragma once
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <ir_pass.h> #include <ir_pass.h>
#include <tvm/ir_visitor.h> #include <chrono>
#include <tvm/ir_mutator.h>
#include "isl.h" #include "isl.h"
namespace akg { namespace akg {
...@@ -31,28 +28,26 @@ namespace poly { ...@@ -31,28 +28,26 @@ namespace poly {
#define PRETTY_PRINT_IR true #define PRETTY_PRINT_IR true
#define DUMP_SCOP_DATA true #define DUMP_SCOP_DATA true
#define DUMP_SCOP_DATA_PER_PASS false #define DUMP_SCOP_DATA_PER_PASS false
#define DUMP_TRANSFORM true
#define DUMP_TRANSFORM_PER_PASS false
#define DUMP_IN_CURRENT_DIR false #define DUMP_IN_CURRENT_DIR false
#define PRINT_C false
#define PRINT_SCHEDULE_INFO false #define PRINT_SCHEDULE_INFO false
#define PRINT_ISL_EMMITER false
#define PRINT_CCE_ISL_EMMITER false
#define PRINT_EMMITER (PRINT_ISL_EMMITER || PRINT_CCE_ISL_EMMITER)
#define SPEC_GEMM true #define SPEC_GEMM true
#define DELETE_FRACTAL true #define DELETE_FRACTAL true
/// conv_backward options /// conv_backward options
#define SELECT_DOMAIN_OPT true #define SELECT_DOMAIN_OPT true
/// transform options // timer records
#define USE_CACHED_SCHEDULE false #define TIMER_START timer_start = std::chrono::high_resolution_clock::now()
#define ENABLE_REPLACE_SCHEDULE_HOOK true #define TIMER_DURATION \
(std::chrono::duration_cast<std::chrono::duration<double>>(std::chrono::high_resolution_clock::now() - timer_start) \
/// constants .count()) * \
constexpr auto kReadSuffix = "read"; 1000
constexpr auto kWriteSuffix = "write"; #define TIMER_SHOW(NAME, SPEC_GEMM) \
constexpr auto kIterNamePrefix = "cc"; { LOG(INFO) << "[ Polyhedral exec time" << SPEC_GEMM << " ], " << NAME << " spent " << TIMER_DURATION << " ms"; }
constexpr auto kGemmIterNamePrefix = "ee";
constexpr auto TENSORLISTTAILNAME = "TensorListTail";
unsigned int WrappedStrtol(const std::string &str); unsigned int WrappedStrtol(const std::string &str);
...@@ -68,6 +63,12 @@ Expr RemoveCast(Expr e); ...@@ -68,6 +63,12 @@ Expr RemoveCast(Expr e);
Stmt PeelOuterLetStmt(const Stmt &s, std::vector<Stmt> &outer_stmts); Stmt PeelOuterLetStmt(const Stmt &s, std::vector<Stmt> &outer_stmts);
isl::union_map ShortSchedule(const isl::schedule_node &node);
isl::union_map LocalSchedule(const isl::schedule_node &node);
void GetAffOffsetAndNumVars(const isl::aff &aff, int &offset, int &num_vars);
bool IsAffVarPlusOffset(const isl::aff &aff);
bool IsAffNonZeroConst(const isl::aff &aff);
class ConsolidateExprMutator : public IRMutator { class ConsolidateExprMutator : public IRMutator {
public: public:
explicit ConsolidateExprMutator(const std::unordered_map<std::string, Var> &params_) : params(params_) {} explicit ConsolidateExprMutator(const std::unordered_map<std::string, Var> &params_) : params(params_) {}
...@@ -86,15 +87,15 @@ class ConsolidateExprMutator : public IRMutator { ...@@ -86,15 +87,15 @@ class ConsolidateExprMutator : public IRMutator {
} }
// list operators that may appear in dynamic shape params // list operators that may appear in dynamic shape params
Expr Mutate_(const Add *op, const Expr &e) { return GenericMutate(op, e); } Expr Mutate_(const Add *op, const Expr &e) override { return GenericMutate(op, e); }
Expr Mutate_(const Sub *op, const Expr &e) { return GenericMutate(op, e); } Expr Mutate_(const Sub *op, const Expr &e) override { return GenericMutate(op, e); }
Expr Mutate_(const Mul *op, const Expr &e) { return GenericMutate(op, e); } Expr Mutate_(const Mul *op, const Expr &e) override { return GenericMutate(op, e); }
Expr Mutate_(const FloorDiv *op, const Expr &e) { return GenericMutate(op, e); } Expr Mutate_(const FloorDiv *op, const Expr &e) override { return GenericMutate(op, e); }
Expr Mutate_(const FloorMod *op, const Expr &e) { return GenericMutate(op, e); } Expr Mutate_(const FloorMod *op, const Expr &e) override { return GenericMutate(op, e); }
Expr Mutate_(const Div *op, const Expr &e) { return GenericMutate(op, e); } Expr Mutate_(const Div *op, const Expr &e) override { return GenericMutate(op, e); }
Expr Mutate_(const Mod *op, const Expr &e) { return GenericMutate(op, e); } Expr Mutate_(const Mod *op, const Expr &e) override { return GenericMutate(op, e); }
Expr Mutate_(const Min *op, const Expr &e) { return GenericMutate(op, e); } Expr Mutate_(const Min *op, const Expr &e) override { return GenericMutate(op, e); }
Expr Mutate_(const Max *op, const Expr &e) { return GenericMutate(op, e); } Expr Mutate_(const Max *op, const Expr &e) override { return GenericMutate(op, e); }
const std::unordered_map<std::string, Var> &params; const std::unordered_map<std::string, Var> &params;
}; };
...@@ -168,6 +169,9 @@ constexpr auto ATTR_GEMM_WEIGHT_TRANSPOSE_BLOCK_INNER = "pragma_weight_transpose ...@@ -168,6 +169,9 @@ constexpr auto ATTR_GEMM_WEIGHT_TRANSPOSE_BLOCK_INNER = "pragma_weight_transpose
constexpr auto ATTR_ATOMIC_ADD = "atomic_add"; constexpr auto ATTR_ATOMIC_ADD = "atomic_add";
constexpr auto ATOMIC_COND_CLEAN = "atomic_cond_clean"; constexpr auto ATOMIC_COND_CLEAN = "atomic_cond_clean";
constexpr auto UBL0 = "UBL0";
constexpr auto REALIZE_ = "realize_";
/****************************************************** /******************************************************
* Following const is the mark tags for schedule tree * Following const is the mark tags for schedule tree
******************************************************/ ******************************************************/
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "schedule_pass.h"
#include <climits>
#include <fstream>
namespace akg {
namespace ir {
namespace poly {
/* Reorder filters (children) of a sequence/set node.
 *
 * node: must be a sequence or set node; it is not modified, a new node is returned.
 * old_to_new_map: map from original child position to new child position.
 *   Every position must be smaller than the number of children of "node".
 *   The caller should make sure that there are no duplicate values.
 *
 * Returns "node" with its subtree replaced by the reordered tree.
 */
isl::schedule_node ReorderFilters(const isl::schedule_node &node,
const std::unordered_map<size_t, size_t> &old_to_new_map) {
auto n_children = node.n_children();
// Two separate tree handles are taken: children are always read from old_tree,
// so replacements already written into new_tree never influence later reads.
isl_schedule_tree *old_tree = isl_schedule_node_get_tree(node.get());
CHECK(old_tree != nullptr);
isl_schedule_tree *new_tree = isl_schedule_node_get_tree(node.get());
CHECK(new_tree != nullptr);
for (auto &it : old_to_new_map) {
auto old_pos = it.first;
auto new_pos = it.second;
CHECK(old_pos < n_children);
CHECK(new_pos < n_children);
isl_schedule_tree *old_child = isl_schedule_tree_get_child(old_tree, old_pos);
CHECK(old_child != nullptr);
// The returned tree handle replaces new_tree; old_child is not released here
// (presumably consumed by the call, following isl's take/give convention - TODO confirm).
new_tree = isl_schedule_tree_replace_child(new_tree, new_pos, old_child);
CHECK(new_tree != nullptr);
}
// old_tree was only a read-only source of children and is released explicitly.
static_cast<void>(isl_schedule_tree_free(old_tree));
// Graft the reordered tree onto a copy of the input node; new_tree is handed over to the call.
isl_schedule_node *new_node = isl_schedule_node_graft_tree(node.copy(), new_tree);
CHECK(new_node != nullptr);
return isl::manage(new_node);
}
/*
 * Run isl dataflow analysis and return the may-dependences.
 *   sources: accesses used as may-sources;
 *   targets: accesses used as sinks;
 *   kills:   accesses used as kills;
 *   sch:     schedule map that orders the statement instances.
 */
isl::union_map DependenceAnalysis(const isl::union_map &sources, const isl::union_map &targets,
                                  const isl::union_map &kills, const isl::union_map &sch) {
  return isl::union_access_info(targets)
    .set_kill(kills)
    .set_may_source(sources)
    .set_schedule_map(sch)
    .compute_flow()
    .get_may_dependence();
}
/*
 * Compute the union of all memory dependences of a schedule:
 * read-after-write (writes -> reads) plus write-after-read and
 * write-after-write (reads and writes -> writes). The result is coalesced.
 */
isl::union_map ComputeAllDependences(const isl::schedule &schedule, const isl::union_map &reads_um,
                                     const isl::union_map &writes_um) {
  isl::union_map read_accesses = reads_um.domain_factor_domain();
  isl::union_map write_accesses = writes_um.domain_factor_domain();
  isl::union_map schedule_map = schedule.get_map();

  // RAW: writes feeding later reads.
  isl::union_map raw_deps = DependenceAnalysis(write_accesses, read_accesses, write_accesses, schedule_map);
  // WAR and WAW: any earlier access followed by a write.
  isl::union_map war_waw_deps =
    DependenceAnalysis(write_accesses.unite(read_accesses), write_accesses, write_accesses, schedule_map);

  return raw_deps.unite(war_waw_deps).coalesce();
}
/*
 * Starting from "root", follow single-child links downwards and return the
 * first band node found. If a node with zero or several children is reached
 * first (a leaf or a branching point), that node is returned instead;
 * an empty band would be inserted elsewhere.
 */
isl::schedule_node GetOuterBand(const isl::schedule_node &root) {
  isl::schedule_node cursor = root;
  while (!cursor.isa<isl::schedule_node_band>() && cursor.n_children() == 1) {
    cursor = cursor.child(0);
  }
  return cursor;
}
// True when "node" is either a sequence node or a set node.
bool IsSequenceOrSet(const isl::schedule_node &node) {
  return node.isa<isl::schedule_node_sequence>() || node.isa<isl::schedule_node_set>();
}
/*
 * Compute the copyin of a single filter node.
 * The reads and writes of the whole scop are restricted to the filter, a
 * dataflow analysis is run with the filtered writes acting as both kills and
 * may-sources, and the original reads whose range matches reads without a
 * source are returned.
 * "node" must be a filter node.
 */
isl::union_map ComputeFilterCopyin(const isl::schedule_node &node, const isl::union_map &ori_reads,
                                   const isl::union_map &ori_writes, const isl::schedule ori_schedule) {
  CHECK(node.isa<isl::schedule_node_filter>()) << "The input should be a filter node!" << std::endl;

  auto filter_domain = node.as<isl::schedule_node_filter>().get_filter();
  auto filtered_reads = ori_reads.domain_factor_domain().intersect_domain(filter_domain);
  auto filtered_writes = ori_writes.domain_factor_domain().intersect_domain(filter_domain);

  auto flow = isl::union_access_info(filtered_reads)
                .set_kill(filtered_writes)
                .set_may_source(filtered_writes)
                .set_schedule(ori_schedule)
                .compute_flow();

  // Reads that no write inside this filter can have produced.
  auto unsourced_reads = flow.get_may_no_source();
  return ori_reads.intersect_range(unsourced_reads.range());
}
/*
 * Extend "fake_copyin" with the copyin of every child of the outermost
 * sequence/set node of "schedule". When the node returned by GetOuterBand is
 * neither a sequence nor a set, "fake_copyin" is returned unchanged.
 */
isl::union_map ComputeFakeCopyin(const isl::schedule &schedule, const isl::union_map &fake_copyin,
                                 const isl::union_map &ori_reads, const isl::union_map &ori_writes) {
  auto outer = GetOuterBand(schedule.get_root());
  if (!IsSequenceOrSet(outer)) {
    return fake_copyin;
  }

  isl::union_map merged = fake_copyin;
  auto child_count = outer.n_children();
  for (auto idx = 0u; idx < child_count; ++idx) {
    merged = merged.unite(ComputeFilterCopyin(outer.child(idx), ori_reads, ori_writes, schedule));
  }
  return merged;
}
/*
 * Build the scheduler constraints on the domain of "schedule".
 * The dependences stored in "pass_info" are used as validity and proximity
 * constraints, and additionally as coincidence constraints when
 * pass_info.coincident_ is set.
 */
isl::schedule_constraints MakeScheduleConstraints(const isl::schedule &schedule, PassInfo &pass_info) {
  isl::schedule_constraints constraints = isl::schedule_constraints::on_domain(schedule.get_domain());
  if (pass_info.coincident_) {
    // keep it, check for more cases
    constraints = constraints.set_coincidence(pass_info.dependences_);
  }
  constraints = constraints.set_validity(pass_info.dependences_);
  constraints = constraints.set_proximity(pass_info.dependences_);
  return constraints;
}
/*
 * Join string literals that a pretty printer split into several pieces.
 * Whenever a closing double quote is followed - after optional spaces, tabs,
 * line feeds or carriage returns - by another opening double quote, both
 * quotes and the whitespace between them are replaced by a single space, so
 * the two literals become one. All other characters are copied unchanged.
 * Scanning stops at the first NUL character of the input.
 */
static std::string UndoPrettyPrintSchTree(const std::string &schedule) {
  std::string merged;
  bool inside_literal = false;
  const char *cursor = schedule.c_str();

  while (*cursor != '\0') {
    if (*cursor != '"') {
      merged += *cursor;
      ++cursor;
      continue;
    }

    inside_literal = !inside_literal;
    if (inside_literal) {
      // opening quote: copied as-is
      merged += *cursor;
      ++cursor;
      continue;
    }

    // closing quote: look at the next non-blank character
    const char *peek = cursor + 1;
    while (*peek == ' ' || *peek == '\t' || *peek == '\n' || *peek == '\r') {
      ++peek;
    }
    if (*peek == '"') {
      // two consecutive literals: glue them together with one white space
      merged += ' ';
      cursor = peek + 1;
      inside_literal = true;
    } else {
      merged += *cursor;
      ++cursor;
    }
  }
  return merged;
}
/*
 * Read a schedule tree from the text file "filename" and store it in "schedule".
 *
 * The file content is first passed through UndoPrettyPrintSchTree and then
 * parsed by isl in the context of the current "schedule".
 *
 * Returns true and overwrites "schedule" on success. Returns false and leaves
 * "schedule" untouched when the file cannot be opened or cannot be parsed.
 */
bool LoadScheduleTreeFromFile(const std::string &filename, isl::schedule &schedule) {
  std::ifstream new_schedule_file_stream(filename);
  if (!new_schedule_file_stream.is_open()) {
    // Report an unreadable file as such, instead of feeding an empty string to
    // the isl parser and blaming the syntax of the schedule.
    LOG(WARNING) << "Failed to open file " << filename << ", schedule tree is not loaded.";
    return false;
  }

  std::string schedule_to_replace_str((std::istreambuf_iterator<char>(new_schedule_file_stream)),
                                      std::istreambuf_iterator<char>());
  schedule_to_replace_str = UndoPrettyPrintSchTree(schedule_to_replace_str);

  isl_schedule *ss = isl_schedule_read_from_str(schedule.ctx().get(), schedule_to_replace_str.c_str());
  if (ss == nullptr) {
    LOG(WARNING) << "Failed to load file " << filename << " to schedule tree, please check syntax of the new schedule.";
    return false;
  }
  schedule = isl::manage(ss);
  return true;
}
/*
 * Compare and replace schedule hook:
 * Lets users swap in a specific schedule for debugging purposes.
 * Both "old_schedule.txt" and "new_schedule.txt" are looked up through
 * info.AddDumpDir(). When both can be resolved and opened, and the current
 * schedule is identical to the one in the old file, "schedule" is replaced by
 * the content of the new file.
 *
 * Returns true only when the schedule was replaced.
 */
bool ReplaceScheduleTree(isl::schedule &schedule, ScopInfo &info) {
  const std::string OLD_SCHEDULE_FILE = info.AddDumpDir("old_schedule.txt");
  const std::string NEW_SCHEDULE_FILE = info.AddDumpDir("new_schedule.txt");

  // Both files have to resolve to a real path, otherwise there is nothing to do.
  char resolved_old[PATH_MAX + 1] = {0};
  char resolved_new[PATH_MAX + 1] = {0};
  if (!realpath(OLD_SCHEDULE_FILE.c_str(), resolved_old) || !realpath(NEW_SCHEDULE_FILE.c_str(), resolved_new)) {
    return false;
  }

  // Probe that both files can be opened for reading.
  FILE *old_fp = fopen(resolved_old, "r");
  FILE *new_fp = fopen(resolved_new, "r");
  const bool both_readable = (old_fp != nullptr) && (new_fp != nullptr);
  if (old_fp != nullptr && fclose(old_fp) != 0) {
    LOG(WARNING) << "Failed to close old_schedule.txt";
  }
  if (new_fp != nullptr && fclose(new_fp) != 0) {
    LOG(WARNING) << "Failed to close new_schedule.txt";
  }
  if (!both_readable) {
    return false;
  }

  std::ifstream old_schedule_file_stream(OLD_SCHEDULE_FILE);
  std::string schedule_to_compare_str;
  schedule_to_compare_str.assign(std::istreambuf_iterator<char>(old_schedule_file_stream),
                                 std::istreambuf_iterator<char>());

  if (!CompareSchTreeWithString(schedule_to_compare_str, schedule)) {
    LOG(INFO) << "Current schedule is different from " << OLD_SCHEDULE_FILE << ", not replacing.";
    return false;
  }

  LOG(INFO) << "Current schedule is same as " << OLD_SCHEDULE_FILE << ", replace it with new schedule "
            << NEW_SCHEDULE_FILE;
  CHECK(LoadScheduleTreeFromFile(NEW_SCHEDULE_FILE, schedule));
  return true;
}
} // namespace poly
} // namespace ir
} // namespace akg
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef POLY_PASS_H_
#define POLY_PASS_H_
#include "poly/isl.h"
#include "poly/scop_info.h"
#include "poly/pass_info.h"
namespace akg {
namespace ir {
namespace poly {
/*
 * Abstract base class of a schedule-tree pass.
 * A concrete pass implements Run(), which receives a schedule and returns the
 * (possibly transformed) schedule, and sets pass_name_ to identify itself.
 */
class SchedulePass {
 public:
  virtual ~SchedulePass() = default;

  // Apply the pass to "sch" and return the resulting schedule.
  virtual isl::schedule Run(isl::schedule sch) = 0;

  // Name of the pass; const-qualified so it can be queried through const references as well.
  std::string GetPassName() const { return pass_name_; }

  std::string pass_name_;
  bool restart_{false};  // triggers restart during runtime
};
isl::schedule_node ReorderFilters(const isl::schedule_node &node,
const std::unordered_map<size_t, size_t> &old_to_new_map);
isl::union_map DependenceAnalysis(const isl::union_map &sources, const isl::union_map &targets,
const isl::union_map &kills, const isl::union_map &sch);
isl::union_map ComputeAllDependences(const isl::schedule &schedule, const isl::union_map &reads_um,
const isl::union_map &writes_um);
isl::schedule_node GetOuterBand(const isl::schedule_node &root);
bool IsSequenceOrSet(const isl::schedule_node &node);
/*
* Compute copyin for each filter node, by intersecting the domains of
* reads and writes of the entire scop.
*/
isl::union_map ComputeFilterCopyin(const isl::schedule_node &node, const isl::union_map &ori_reads,
const isl::union_map &ori_writes, const isl::schedule ori_schedule);
bool CompareSchTreeWithString(const std::string &compare_sch, const isl::schedule &sch);
isl::schedule_constraints MakeScheduleConstraints(const isl::schedule &schedule, PassInfo &pass_info);
isl::union_map RemoveReduceOpSelfDependence(ScopInfo &scop_info, PassInfo &pass_info);
isl::union_map RemoveSelfDependence(PassInfo &pass_info);
isl::union_map RemoveInvariantDependence(const isl::schedule &schedule, PassInfo &pass_info);
/*
* Compute copyin for each filter and return the union of such copyins.
* In particular, return an empty result when the outermost band node
* is not a sequence/set node.
*
* "result" is the union of "copyin" from each filter node, which in
* turn is computed by ComputeFilterCopyin.
*/
isl::union_map ComputeFakeCopyin(const isl::schedule &schedule, const isl::union_map &fake_copyin,
const isl::union_map &ori_reads, const isl::union_map &ori_writes);
} // namespace poly
} // namespace ir
} // namespace akg
#endif // POLY_PASS_H_
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "change_marknode_position.h"
namespace akg {
namespace ir {
namespace poly {
/*
 * Move the "realize_UB" mark from the outer band down to the filters that
 * contain "with" statements.
 *
 * For every mark node named "realize_UB" whose child is a band followed by a
 * sequence, each filter of the sequence is inspected: when at least one of its
 * statement sets is listed in with_stmts_ids_, a "realize_UB" mark is inserted
 * below that filter. The outer mark is deleted only when every filter received
 * such a mark. The schedule is returned unchanged when with_stmts_ids_ is empty.
 */
isl::schedule ChangeMarkNodePosition::Run(isl::schedule curr_schedule) {
  const std::unordered_set<std::string> &ids = with_stmts_ids_;
  if (ids.empty()) {
    return curr_schedule;
  }

  auto relocate_mark = [&ids](isl::schedule_node node) -> isl::schedule_node {
    if (!node.isa<isl::schedule_node_mark>()) {
      return node;
    }
    const std::string mark_id = node.as<isl::schedule_node_mark>().get_id().get_name();
    if (mark_id != "realize_UB" || !node.child(0).isa<isl::schedule_node_band>()) {
      return node;
    }
    if (!node.child(0).child(0).isa<isl::schedule_node_sequence>()) {
      return node;
    }

    node = node.get_child(0).get_child(0);  // sequence
    bool delete_outer_mark = true;
    const int n = node.n_children();
    for (int idx = 0; idx < n; ++idx) {
      isl::schedule_node_filter filter_node = node.child(idx).as<isl::schedule_node_filter>();
      const bool has_no_with_stmt = filter_node.get_filter().every_set(
        [&ids](const isl::set &s) -> bool { return (ids.count(s.get_tuple_name()) == 0); });
      if (has_no_with_stmt) {
        // At least one filter still relies on the outer mark.
        delete_outer_mark = false;
        continue;
      }
      // Insert the mark right below the filter, then climb back to the sequence.
      isl::schedule_node below_filter = node.child(idx).child(0);
      below_filter = below_filter.insert_mark(isl::id(below_filter.ctx(), mark_id));
      node = below_filter.parent().parent();
    }

    // Climb from the sequence back to the original mark node.
    node = node.parent().parent();
    return delete_outer_mark ? node.del() : node;
  };

  return curr_schedule.get_root().map_descendant_bottom_up(relocate_mark).get_schedule();
}
} // namespace poly
} // namespace ir
} // namespace akg
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef POLY_CHANGE_MARKNODE_POSITION_H_
#define POLY_CHANGE_MARKNODE_POSITION_H_
#include "poly/schedule_pass.h"
#include <unordered_set>
namespace akg {
namespace ir {
namespace poly {
/*
* "with" stmt aims to work around the irregular problem.
* By default, the "realize_UB" mark is on the outer band. However, for tensor-of-tensor,
* the intermediate tensor may be too large if realized in the outermost scope.
* To narrow down the scope, we move "realize_UB" mark to the filter node.
* If all filter nodes of the band are "with" stmts, we remove the outer "realize_UB" mark.
*/
class ChangeMarkNodePosition : public SchedulePass {
public:
ChangeMarkNodePosition(const std::unordered_set<std::string> &with_stmts_ids) : with_stmts_ids_(with_stmts_ids) {
pass_name_ = __FUNCTION__;
};
~ChangeMarkNodePosition(){};
virtual isl::schedule Run(isl::schedule sch);
private:
std::unordered_set<std::string> with_stmts_ids_;
};
} // namespace poly
} // namespace ir
} // namespace akg
#endif // POLY_CHANGE_MARKNODE_POSITION_H_
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "compute_inner_band_dependency.h"
namespace akg {
namespace ir {
namespace poly {
/*
 * Compute the inner band dependency - the fake copyin of "sch" minus the
 * recorded copyin - and store it in the analysis result.
 * The schedule itself is returned unchanged.
 */
isl::schedule ComputeInnerBandDependency::Run(isl::schedule sch) {
  auto &analysis = scop_info_.analysis_result_;

  auto reads = analysis.GetReads();
  auto writes = analysis.GetWrites();
  auto recorded_fake_copyin = analysis.GetFakeCopyin();

  auto fake_copyin = ComputeFakeCopyin(sch, recorded_fake_copyin, reads, writes);
  analysis.RecordInnerBandDependency(fake_copyin.subtract(analysis.GetCopyin()));
  return sch;
}
} // namespace poly
} // namespace ir
} // namespace akg
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef POLY_COMPUTE_INNER_BAND_DEPENDENCY_H_
#define POLY_COMPUTE_INNER_BAND_DEPENDENCY_H_
#include "poly/schedule_pass.h"
#include "poly/scop_info.h"
namespace akg {
namespace ir {
namespace poly {
/*
* This class initialises the inner band dependency information used in InjectMulticoreToSchedule pass
* and record it in the scop info. No actual schedule tree transfrom is performed in this pass.
*/
class ComputeInnerBandDependency : public SchedulePass {
public:
ComputeInnerBandDependency(ScopInfo &scop_info) : scop_info_(scop_info) { pass_name_ = __FUNCTION__; }
~ComputeInnerBandDependency() {}
virtual isl::schedule Run(isl::schedule sch);
private:
ScopInfo &scop_info_;
};
} // namespace poly
} // namespace ir
} // namespace akg
#endif // POLY_COMPUTE_INNER_BAND_DEPENDENCY_H_
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "compute_schedule.h"
namespace akg {
namespace ir {
namespace poly {
/*
 * Rebuild the dependence relation map by map.
 * Maps whose input and output tuple ids differ are rewritten by applying
 * isl_map_plain_update_val_if_fixed to each of their input dimensions;
 * maps with identical tuple ids are copied unchanged.
 */
isl::union_map ComputeSchedule::ModDependences(const isl::union_map &dependences) {
  isl::union_map rewritten = isl::union_map::empty(dependences.ctx());

  dependences.foreach_map([&](const isl::map &dep) -> void {
    isl::map current = dep;
    const bool crosses_tuples = current.get_tuple_id(isl_dim_in) != current.get_tuple_id(isl_dim_out);
    if (crosses_tuples) {
      isl_map *raw = current.copy();
      const int n_in = isl_map_dim(raw, isl_dim_in);
      for (int pos = 0; pos < n_in; ++pos) {
        raw = isl_map_plain_update_val_if_fixed(raw, isl_dim_in, pos);
      }
      current = isl::manage(raw);
    }
    rewritten = rewritten.unite(isl::union_map(current));
  });

  return rewritten;
}
// Configure the isl scheduler options on the context of the current constraints,
// according to the user configuration. Every isl_options_set_* call is checked to
// return isl_stat_ok. The order of the blocks matters: later blocks may set the
// same option again.
void ComputeSchedule::SetIslOptions() {
auto ctx = pass_info_.constraints_.ctx().get();
// Always enabled, independent of the user configuration.
int status = isl_options_set_schedule_unit_max_var_coefficient_sum(ctx, 1);
CHECK(status == isl_stat_ok);
if (scop_info_.user_config_.GetComputeReschedule()) {
// Reschedule mode: turn whole-component scheduling off.
status = isl_options_set_schedule_whole_component(ctx, 0);
CHECK(status == isl_stat_ok);
} else {
// Default mode: turn maximize-coincidence off and whole-component scheduling on.
status = isl_options_set_schedule_maximize_coincidence(ctx, 0);
CHECK(status == isl_stat_ok);
status = isl_options_set_schedule_whole_component(ctx, 1);
CHECK(status == isl_stat_ok);
}
if (scop_info_.user_config_.GetDisableScheduleShift()) {
// Schedule shift disabled: set the maximal constant term to 0 and
// request non-negative variable coefficients.
status = isl_options_set_schedule_max_constant_term(ctx, 0);
CHECK(status == isl_stat_ok);
status = isl_options_set_schedule_nonneg_var_coefficient(ctx, 1);
CHECK(status == isl_stat_ok);
}
if (scop_info_.user_config_.GetEnableScheduleMaxConstant()) {
// Set the maximal constant term to 0 (same value as in the block above).
status = isl_options_set_schedule_max_constant_term(ctx, 0);
CHECK(status == isl_stat_ok);
}
if (scop_info_.user_config_.GetDisableLoopReversal()) {
// Loop reversal disabled: request non-negative variable coefficients.
status = isl_options_set_schedule_nonneg_var_coefficient(ctx, 1);
CHECK(status == isl_stat_ok);
}
if (scop_info_.user_config_.GetDisableLoopFusion()) {
// Loop fusion disabled: serialize the strongly connected components.
status = isl_options_set_schedule_serialize_sccs(ctx, 1);
CHECK(status == isl_stat_ok);
}
}
/*
 * Compute a new schedule with the isl scheduler:
 *   1. optionally rewrite the dependences (when GetModScheduleShift() is set),
 *   2. build the schedule constraints from "sch" and the dependences,
 *   3. apply the isl options and let isl compute the schedule.
 */
isl::schedule ComputeSchedule::Run(isl::schedule sch) {
  const bool mod_schedule_shift = scop_info_.user_config_.GetModScheduleShift();
  if (mod_schedule_shift) {
    pass_info_.dependences_ = ModDependences(pass_info_.dependences_);
  }

  pass_info_.constraints_ = MakeScheduleConstraints(sch, pass_info_);
  // Options are read from the context of the freshly built constraints.
  SetIslOptions();

  return pass_info_.constraints_.compute_schedule();
}
} // namespace poly
} // namespace ir
} // namespace akg
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef POLY_COMPUTE_SCHEDULE_H_
#define POLY_COMPUTE_SCHEDULE_H_
#include "poly/schedule_pass.h"
namespace akg {
namespace ir {
namespace poly {
/*
* compute schedule pass, the main tasks ars as follow
* 1. modify the dependences depends on the configuration switch
* 2. Generating constraints
* 3. According to the constraints, the ISL interface is called to generate a new schedule
*/
class ComputeSchedule : public SchedulePass {
public:
ComputeSchedule(PassInfo &pass_info, ScopInfo &scop_info) : pass_info_(pass_info), scop_info_(scop_info) {
pass_name_ = __FUNCTION__;
}
~ComputeSchedule() {}
virtual isl::schedule Run(isl::schedule sch);
void SetIslOptions();
isl::union_map ModDependences(const isl::union_map &dependences);
private:
PassInfo &pass_info_;
ScopInfo &scop_info_;
};
} // namespace poly
} // namespace ir
} // namespace akg
#endif // POLY_COMPUTE_SCHEDULE_H_
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "compute_transfer_copyin.h"
namespace akg {
namespace ir {
namespace poly {
// Updates the fake-copyin, transfer-statement, reads and copyin records held in
// scop_info_.analysis_result_. The schedule itself is only read (passed to
// ComputeFakeCopyin and queried with get_map()) and is returned as received.
isl::schedule ComputeTransferCopyin::Run(isl::schedule sch) {
  auto &analysis = scop_info_.analysis_result_;

  // Compute the fake copyin from the currently recorded accesses, drop whatever is
  // already recorded as copyin, and store the result.
  auto recorded_reads = analysis.GetReads();
  auto recorded_writes = analysis.GetWrites();
  auto recorded_fake_copyin = analysis.GetFakeCopyin();
  isl::union_map fake_copyin = ComputeFakeCopyin(sch, recorded_fake_copyin, recorded_reads, recorded_writes);
  fake_copyin = fake_copyin.subtract(analysis.GetCopyin());
  analysis.RecordFakeCopyin(fake_copyin);

  // Access relations after domain_factor_domain().
  // NOTE(review): presumably this strips a reference tag wrapped into the domain - confirm.
  isl::union_map all_writes = recorded_writes.domain_factor_domain();
  isl::union_map all_reads = recorded_reads.domain_factor_domain();
  isl::union_map known_copyin = analysis.GetCopyin().domain_factor_domain();
  isl::union_map fake_reads = fake_copyin.domain_factor_domain();

  // Worklist loop: start from the fake-copyin reads and iterate until no reads remain.
  isl::union_map pending_reads = fake_reads;
  isl::union_map transfer_copyin = fake_copyin;
  while (!pending_reads.is_empty()) {
    // Writes that touch the tensors of the pending reads, and the statements that
    // appear in the domain of the resulting dependence.
    isl::union_map producers = all_writes.intersect_range(pending_reads.range());
    isl::union_map dependence = DependenceAnalysis(producers, pending_reads, producers, sch.get_map());
    isl::union_set stmt = dependence.domain().universe();
    analysis.RecordTransferStmt(analysis.GetTransferStmt().unite(stmt));

    // Everything those statements read becomes the next set of pending reads.
    pending_reads = all_reads.intersect_domain(stmt);

    // compute transfer copyin: map each tensor written by the statements to the tensors
    // they read, and push the accumulated transfer relation through that mapping.
    isl::union_map relation = all_writes.intersect_domain(stmt).reverse().apply_range(pending_reads);
    transfer_copyin = transfer_copyin.apply_range(relation);

    // The part that lands on tensors of the initial copyin is recorded as reads and
    // copyin, then removed from the relation that is still being propagated.
    isl::union_map copyin = transfer_copyin.intersect_range(known_copyin.range().universe());
    analysis.RecordReads(analysis.GetReads().unite(copyin));
    analysis.RecordCopyin(analysis.GetCopyin().unite(copyin));
    transfer_copyin = transfer_copyin.subtract(copyin);

    // Do not revisit reads that belong to the initial copyin or to the fake copyin.
    pending_reads = pending_reads.subtract(known_copyin).subtract(fake_reads);
  }
  return sch;
}
} // namespace poly
} // namespace ir
} // namespace akg
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef POLY_COMPUTE_TRANSFER_COPYIN_H_
#define POLY_COMPUTE_TRANSFER_COPYIN_H_
#include "poly/schedule_pass.h"
namespace akg {
namespace ir {
namespace poly {
/*
* This class initialises the transfer copyin information used in the TransferStmt and MemoryPromotion passes
* and records it in the scop info. No actual schedule tree transformation is performed in this pass.
*/
class ComputeTransferCopyin : public SchedulePass {
 public:
  // Stores references to the caller-provided scop and pass state (no copies are made)
  // and sets pass_name_ from the name of this constructor.
  // NOTE(review): pass_name_ is not declared in this class; presumably inherited from SchedulePass.
  ComputeTransferCopyin(ScopInfo &scop_info, PassInfo &pass_info) : scop_info_(scop_info), pass_info_(pass_info) {
    pass_name_ = __FUNCTION__;
  }

  ~ComputeTransferCopyin() = default;

  // Pass entry point: receives a schedule and returns a schedule. Defined out of line.
  virtual isl::schedule Run(isl::schedule sch);

 private:
  ScopInfo &scop_info_;  // referenced, not owned
  PassInfo &pass_info_;  // referenced, not owned
};
} // namespace poly
} // namespace ir
} // namespace akg
#endif // POLY_COMPUTE_TRANSFER_COPYIN_H_
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
/** /**
* Copyright 2019 Huawei Technologies Co., Ltd * Copyright 2020 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
...@@ -13,24 +13,25 @@ ...@@ -13,24 +13,25 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#ifndef POLY_SINK_AXIS_H_ #ifndef POLY_INSERT_NODE_FOR_ALLOCC_H_
#define POLY_SINK_AXIS_H_ #define POLY_INSERT_NODE_FOR_ALLOCC_H_
#pragma once #include "poly/schedule_pass.h"
#include "poly/transform.h"
#define MAX_STRIDE 65535
namespace akg { namespace akg {
namespace ir { namespace ir {
namespace poly { namespace poly {
bool FindC0Schedule(const isl::pw_aff_list &paList); class InsertNodeForAllocC : public SchedulePass {
void ExchangeCoincident(std::vector<int> &coincident, const isl::schedule_node &node, public:
const std::unordered_map<int, bool> lastIdxSchedule, const int &n); InsertNodeForAllocC() { pass_name_ = __FUNCTION__; };
~InsertNodeForAllocC(){};
virtual isl::schedule Run(isl::schedule sched);
};
} // namespace poly } // namespace poly
} // namespace ir } // namespace ir
} // namespace akg } // namespace akg
#endif // POLY_SINK_AXIS_H_ #endif // POLY_INSERT_NODE_FOR_ALLOCC_H_
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册