未验证 提交 9c101490 编写于 作者: H HongyuJia 提交者: GitHub

[CINN Frontend] Optimize user interface, polish fuse_pass classes (#55705)

* [CINN Frontend] Optimize user interface, polish fuse_pass classes"

* Fix some compile error
上级 09a60477
......@@ -433,28 +433,6 @@ function(download_and_uncompress INSTALL_DIR URL FILENAME)
INSTALL_COMMAND "")
endfunction()
set(fusion_pass_file
${CMAKE_CURRENT_BINARY_DIR}/paddle/cinn/hlir/pass/use_general_pass.h
CACHE INTERNAL "use_general_pass.h file")
file(
WRITE ${fusion_pass_file}
"#include \"paddle/cinn/common/macros.h\" // Generated by the paddle/cinn/hlir/pass/CMakeLists.txt. DO NOT EDIT!\n\n"
)
function(find_fusion_pass_register FILENAME ADD_PATH PATTERN)
# set op_name to OUTPUT
file(READ ${FILENAME} CONTENT)
string(REGEX MATCHALL "${PATTERN}\\([a-zA-Z0-9_]*," fusion_pass_patterns
"${CONTENT}")
if(NOT fusion_pass_patterns STREQUAL "")
foreach(pass_pattern ${fusion_pass_patterns})
string(REPLACE "${PATTERN}(" "" pass_pattern "${pass_pattern}")
string(REPLACE "," "" pass_pattern "${pass_pattern}")
file(APPEND ${ADD_PATH} "USE_FUSION_PASS(${pass_pattern});\n")
endforeach()
endif()
endfunction()
function(gather_srcs SRC_GROUP)
set(options)
set(oneValueArgs)
......@@ -464,8 +442,6 @@ function(gather_srcs SRC_GROUP)
set(${SRC_GROUP}
"${${SRC_GROUP}};${CMAKE_CURRENT_SOURCE_DIR}/${cpp}"
CACHE INTERNAL "")
find_fusion_pass_register("${CMAKE_CURRENT_SOURCE_DIR}/${cpp}"
${fusion_pass_file} "CINN_REGISTER_FUSION_PASS")
endforeach()
endfunction()
......
......@@ -67,17 +67,6 @@
__test_global_namespace_##uniq_name##__>::value, \
msg)
#define CINN_REGISTER_FUSION_PASS(pass_name, pass_class) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_pass__##pass_name, \
"CINN_REGISTER_FUSION_PASS must be called in global namespace"); \
static ::cinn::hlir::pass::FusionPassRegistrar<pass_class> \
__pass_registrar_##pass_name##__(#pass_name); \
int TouchFusionPassRegistrar_##pass_name() { \
__pass_registrar_##pass_name##__.Touch(); \
return 0; \
}
#define USE_FUSION_PASS(pass_name) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__use_fusion_pass_##pass_name, \
......
......@@ -30,7 +30,6 @@
#include "paddle/cinn/hlir/framework/pass.h"
#include "paddle/cinn/hlir/framework/tensor.h"
#include "paddle/cinn/hlir/op/use_ops.h"
#include "paddle/cinn/hlir/pass/use_general_pass.h"
#include "paddle/cinn/hlir/pass/use_pass.h"
namespace cinn::frontend {
......
......@@ -21,7 +21,6 @@
#include "paddle/cinn/hlir/framework/graph.h"
#include "paddle/cinn/hlir/framework/pass.h"
#include "paddle/cinn/hlir/op/use_ops.h"
#include "paddle/cinn/hlir/pass/use_general_pass.h"
#include "paddle/cinn/hlir/pass/use_pass.h"
#include "paddle/cinn/runtime/flags.h"
......
......@@ -26,7 +26,6 @@
#include "paddle/cinn/hlir/framework/graph.h"
#include "paddle/cinn/hlir/framework/pass.h"
#include "paddle/cinn/hlir/framework/visualize_helper.h"
#include "paddle/cinn/hlir/pass/use_general_pass.h"
#include "paddle/cinn/hlir/pass/use_pass.h"
#include "paddle/cinn/runtime/flags.h"
......
......@@ -24,7 +24,6 @@
#include "paddle/cinn/frontend/program_pass.h"
#include "paddle/cinn/hlir/framework/graph_compiler.h"
#include "paddle/cinn/hlir/framework/pass.h"
#include "paddle/cinn/hlir/pass/use_general_pass.h"
#include "paddle/cinn/hlir/pass/use_pass.h"
namespace cinn::frontend {
......
......@@ -15,7 +15,6 @@
#include "paddle/cinn/hlir/framework/pass.h"
#include "paddle/cinn/hlir/framework/visualize_helper.h"
#include "paddle/cinn/hlir/pass/use_general_pass.h"
#include "paddle/cinn/hlir/pass/use_pass.h"
namespace cinn {
......
add_subdirectory(general_fusion_merge_pass)
core_gather_headers()
gather_srcs(
......
......@@ -15,9 +15,15 @@
#include <map>
#include <unordered_map>
#include "glog/logging.h"
#include "paddle/cinn/api/op_group.h"
#include "paddle/cinn/common/is_reachable_predicator.h"
#include "paddle/cinn/common/macros.h"
#include "paddle/cinn/hlir/pass/general_fusion_merge_pass/fusion_pass_map.h"
#include "paddle/cinn/hlir/pass/general_fusion_merge_pass/graph_group_input_fuse_pass_ctx.h"
#include "paddle/cinn/hlir/pass/general_fusion_merge_pass/graph_group_lightware_fuse_pass_ctx.h"
#include "paddle/cinn/hlir/pass/general_fusion_merge_pass/input_fuse_pass.h"
#include "paddle/cinn/hlir/pass/general_fusion_merge_pass/lightware_fuse_pass.h"
#include "paddle/cinn/hlir/pass/general_fusion_merge_pass/lightware_fuse_pass_ctx.h"
#include "paddle/cinn/hlir/pass/general_fusion_merge_pass_utils.h"
DECLARE_bool(enhance_vertical_fusion_with_recompute);
......@@ -47,1015 +53,6 @@ using OpGroupList = std::vector<OpGroupPtr>;
using ConditionFunction = std::function<bool(
const FusionHelperBase*, const GroupPtr&, const GroupPtr&)>;
class FuseHelper {
public:
virtual ~FuseHelper() = default;
virtual bool AllOutputsSameSize(const OpGroupPtr& first,
const OpGroupPtr& second) const = 0;
virtual bool HorizontalElementwiseFuseReduce(const OpGroupPtr& src,
const OpGroupPtr& dst) const = 0;
virtual bool ElementwiseFuseBroadcast(const OpGroupPtr& src,
const OpGroupPtr& dst) const = 0;
virtual bool HorizontalWithInjective(const OpGroupPtr& src,
const OpGroupPtr& dst) const = 0;
virtual bool ElementwiseFuseReduce(const OpGroupPtr& src,
const OpGroupPtr& dst) const = 0;
virtual bool BroadcastFuseReduce(const OpGroupPtr& src,
const OpGroupPtr& dst) const = 0;
virtual bool InjectiveHorizontalWithReduce(const OpGroupPtr& src,
const OpGroupPtr& dst) const = 0;
virtual bool ReduceFuseElementwise(const OpGroupPtr& src,
const OpGroupPtr& dst) const = 0;
virtual bool ReduceFuseBroadcast(const OpGroupPtr& src,
const OpGroupPtr& dst) const = 0;
virtual bool ReduceFuseReduce(const OpGroupPtr& src,
const OpGroupPtr& dst) const = 0;
virtual bool IsReachable(const OpGroupPtr& lhs,
const OpGroupPtr& rhs) const = 0;
virtual bool DetectCycleIfFuse(const OpGroupPtr& src,
const OpGroupPtr& dst) const = 0;
virtual bool IsConsumerSetsReachable(
const OpGroupPtr& group,
const std::unordered_set<OpGroupPtr>& consumers) const = 0;
protected:
FuseHelper() = default;
};
template <typename FusePassCtxT>
class GraphGroupFuseHelper final : public FuseHelper {
public:
explicit GraphGroupFuseHelper(const FusePassCtxT* ctx) : ctx_(ctx) {}
bool AllOutputsSameSize(const OpGroupPtr& first,
const OpGroupPtr& second) const override;
bool HorizontalElementwiseFuseReduce(const OpGroupPtr& src,
const OpGroupPtr& dst) const override;
bool ElementwiseFuseBroadcast(const OpGroupPtr& src,
const OpGroupPtr& dst) const override;
bool HorizontalWithInjective(const OpGroupPtr& src,
const OpGroupPtr& dst) const override;
bool ElementwiseFuseReduce(const OpGroupPtr& src,
const OpGroupPtr& dst) const override;
bool BroadcastFuseReduce(const OpGroupPtr& src,
const OpGroupPtr& dst) const override;
bool InjectiveHorizontalWithReduce(const OpGroupPtr& src,
const OpGroupPtr& dst) const override;
bool ReduceFuseElementwise(const OpGroupPtr& src,
const OpGroupPtr& dst) const override;
bool ReduceFuseBroadcast(const OpGroupPtr& src,
const OpGroupPtr& dst) const override;
bool ReduceFuseReduce(const OpGroupPtr& src,
const OpGroupPtr& dst) const override;
bool IsReachable(const OpGroupPtr& lhs,
const OpGroupPtr& rhs) const override {
return IsReachableInDag(lhs, rhs) || IsReachableInDag(rhs, lhs);
}
bool DetectCycleIfFuse(const OpGroupPtr& lhs,
const OpGroupPtr& rhs) const override {
return ReachableIfDirectEdgeIgnored(lhs, rhs) ||
ReachableIfDirectEdgeIgnored(rhs, lhs);
}
bool IsConsumerSetsReachable(
const OpGroupPtr& group,
const std::unordered_set<OpGroupPtr>& consumers) const override {
for (const auto& consumer : consumers) {
if (group == consumer) {
continue;
}
if (IsReachableInDag(consumer, group)) {
return true;
}
}
return false;
}
private:
bool IsReachableInDag(const OpGroupPtr& producer,
const OpGroupPtr& consumer) const {
const auto& MinDepth4Node = [&](const OpGroupPtr& node) {
return node.GetGroup()->min_depth;
};
const auto& MaxDepth4Node = [&](const OpGroupPtr& node) {
return node.GetGroup()->max_depth;
};
const auto& VisitNextNodes =
[&](const OpGroupPtr& node,
const std::function<void(OpGroupPtr)>& Visit) {
for (const auto& node_producer : node.producers()) {
Visit(node_producer);
}
};
common::IsReachablePredicator<OpGroupPtr> is_reachable(
MinDepth4Node, MaxDepth4Node, VisitNextNodes);
return is_reachable(consumer, producer, [](OpGroupPtr) {});
}
bool ReachableIfDirectEdgeIgnored(const OpGroupPtr& producer,
const OpGroupPtr& consumer) const {
const auto& MinDepth4Node = [&](const OpGroupPtr& node) {
return node.GetGroup()->min_depth;
};
const auto& MaxDepth4Node = [&](const OpGroupPtr& node) {
return node.GetGroup()->max_depth;
};
const auto& VisitNextNodes =
[&](const OpGroupPtr& node,
const std::function<void(OpGroupPtr)>& Visit) {
for (const auto& node_producer : node.producers()) {
if (node == consumer && node_producer == producer) {
continue;
}
Visit(node_producer);
}
};
common::IsReachablePredicator<OpGroupPtr> is_reachable(
MinDepth4Node, MaxDepth4Node, VisitNextNodes);
return is_reachable(consumer, producer, [](OpGroupPtr) {});
}
const FusePassCtxT* ctx_;
};
class FusePassCtx {
public:
virtual ~FusePassCtx() {}
virtual const FuseHelper& fuse_helper() const = 0;
virtual void MarkFusible(const OpGroupPtr& first,
const OpGroupPtr& second) = 0;
protected:
FusePassCtx() = default;
};
class LightwareFusePassCtx : public FusePassCtx {
public:
virtual ~LightwareFusePassCtx() {}
virtual const OpGroupPtr& PickOpGroup() const = 0;
virtual const FuseHelper& fuse_helper() const = 0;
virtual void MarkFusible(const OpGroupPtr& first,
const OpGroupPtr& second) = 0;
virtual void MarkFusible(const OpGroupList& candidates) = 0;
protected:
LightwareFusePassCtx() = default;
};
class GraphGroupLightwareFusePassCtx final : public LightwareFusePassCtx {
public:
GraphGroupLightwareFusePassCtx(
const FusionHelperBase* graph_group_fusion_helper,
const OpGroupPtr& group,
const std::function<void(const OpGroupPtr& first,
const OpGroupPtr& second)>& MarkFusible)
: graph_group_fusion_helper_(graph_group_fusion_helper),
group_(group),
MarkFusible_(MarkFusible),
fuse_helper_(
new GraphGroupFuseHelper<GraphGroupLightwareFusePassCtx>(this)) {}
GraphGroupLightwareFusePassCtx(
const FusionHelperBase* graph_group_fusion_helper,
const OpGroupPtr& group,
const std::function<void(const OpGroupList& candidates)>&
MarkGroupListFusible)
: graph_group_fusion_helper_(graph_group_fusion_helper),
group_(group),
MarkGroupListFusible_(MarkGroupListFusible),
fuse_helper_(
new GraphGroupFuseHelper<GraphGroupLightwareFusePassCtx>(this)) {}
const OpGroupPtr& PickOpGroup() const override { return group_; }
const FuseHelper& fuse_helper() const override { return *fuse_helper_; }
void MarkFusible(const OpGroupPtr& first, const OpGroupPtr& second) override {
MarkFusible_(first, second);
}
void MarkFusible(const OpGroupList& candidates) override {
MarkGroupListFusible_(candidates);
}
const FusionHelperBase& graph_group_fusion_helper() const {
return *graph_group_fusion_helper_;
}
private:
const FusionHelperBase* graph_group_fusion_helper_;
const OpGroupPtr& group_;
const std::function<void(const OpGroupPtr& first, const OpGroupPtr& second)>
MarkFusible_;
const std::function<void(const OpGroupList& candidates)>
MarkGroupListFusible_;
const std::unique_ptr<const FuseHelper> fuse_helper_;
};
class InputFusePassCtx : public FusePassCtx {
public:
virtual ~InputFusePassCtx() {}
virtual const OpGroupList& PickConsumersWithSameInputs() const = 0;
virtual const FuseHelper& fuse_helper() const = 0;
virtual void MarkFusible(const OpGroupPtr& first,
const OpGroupPtr& second) = 0;
virtual void MarkFusible(const OpGroupList& candidates) = 0;
protected:
InputFusePassCtx() = default;
};
class GraphGroupInputFusePassCtx final : public InputFusePassCtx {
public:
GraphGroupInputFusePassCtx(
const FusionHelperBase* graph_group_fusion_helper,
const OpGroupList& groups,
const std::function<void(const OpGroupPtr& first,
const OpGroupPtr& second)>& MarkFusible)
: graph_group_fusion_helper_(graph_group_fusion_helper),
groups_(groups),
MarkFusible_(MarkFusible),
fuse_helper_(
new GraphGroupFuseHelper<GraphGroupInputFusePassCtx>(this)) {}
GraphGroupInputFusePassCtx(
const FusionHelperBase* graph_group_fusion_helper,
const OpGroupList& groups,
const std::function<void(const OpGroupList& candidates)>&
MarkGroupListFusible)
: graph_group_fusion_helper_(graph_group_fusion_helper),
groups_(groups),
MarkGroupListFusible_(MarkGroupListFusible),
fuse_helper_(
new GraphGroupFuseHelper<GraphGroupInputFusePassCtx>(this)) {}
const OpGroupList& PickConsumersWithSameInputs() const override {
return groups_;
}
const FuseHelper& fuse_helper() const override { return *fuse_helper_; }
void MarkFusible(const OpGroupPtr& first, const OpGroupPtr& second) override {
MarkFusible_(first, second);
}
void MarkFusible(const OpGroupList& candidates) override {
MarkGroupListFusible_(candidates);
}
const FusionHelperBase& graph_group_fusion_helper() const {
return *graph_group_fusion_helper_;
}
private:
const FusionHelperBase* graph_group_fusion_helper_;
const OpGroupList& groups_;
const std::function<void(const OpGroupPtr& first, const OpGroupPtr& second)>
MarkFusible_;
const std::function<void(const OpGroupList& candidates)>
MarkGroupListFusible_;
const std::unique_ptr<const FuseHelper> fuse_helper_;
};
template <typename FusePassCtxT>
bool GraphGroupFuseHelper<FusePassCtxT>::AllOutputsSameSize(
const OpGroupPtr& first, const OpGroupPtr& second) const {
return is_same_size(
&ctx_->graph_group_fusion_helper(), first.GetGroup(), second.GetGroup());
}
template <typename FusePassCtxT>
bool GraphGroupFuseHelper<FusePassCtxT>::HorizontalElementwiseFuseReduce(
const OpGroupPtr& src, const OpGroupPtr& dst) const {
return honrizontal_elementwise_fuse_reduce(
&ctx_->graph_group_fusion_helper(), src.GetGroup(), dst.GetGroup());
}
template <typename FusePassCtxT>
bool GraphGroupFuseHelper<FusePassCtxT>::ElementwiseFuseBroadcast(
const OpGroupPtr& src, const OpGroupPtr& dst) const {
return elementwise_fuse_broadcast(
&ctx_->graph_group_fusion_helper(), src.GetGroup(), dst.GetGroup());
}
template <typename FusePassCtxT>
bool GraphGroupFuseHelper<FusePassCtxT>::HorizontalWithInjective(
const OpGroupPtr& src, const OpGroupPtr& dst) const {
return horizontal_with_injective(
&ctx_->graph_group_fusion_helper(), src.GetGroup(), dst.GetGroup());
}
template <typename FusePassCtxT>
bool GraphGroupFuseHelper<FusePassCtxT>::ElementwiseFuseReduce(
const OpGroupPtr& src, const OpGroupPtr& dst) const {
return elementwise_fuse_reduce(
&ctx_->graph_group_fusion_helper(), src.GetGroup(), dst.GetGroup());
}
template <typename FusePassCtxT>
bool GraphGroupFuseHelper<FusePassCtxT>::BroadcastFuseReduce(
const OpGroupPtr& src, const OpGroupPtr& dst) const {
return broadcast_fuse_reduce(
&ctx_->graph_group_fusion_helper(), src.GetGroup(), dst.GetGroup());
}
template <typename FusePassCtxT>
bool GraphGroupFuseHelper<FusePassCtxT>::InjectiveHorizontalWithReduce(
const OpGroupPtr& src, const OpGroupPtr& dst) const {
return injective_horizontal_with_reduce(
&ctx_->graph_group_fusion_helper(), src.GetGroup(), dst.GetGroup());
}
template <typename FusePassCtxT>
bool GraphGroupFuseHelper<FusePassCtxT>::ReduceFuseElementwise(
const OpGroupPtr& src, const OpGroupPtr& dst) const {
return reduce_fuse_elementwise(
&ctx_->graph_group_fusion_helper(), src.GetGroup(), dst.GetGroup());
}
template <typename FusePassCtxT>
bool GraphGroupFuseHelper<FusePassCtxT>::ReduceFuseBroadcast(
const OpGroupPtr& src, const OpGroupPtr& dst) const {
return reduce_fuse_broadcast(
&ctx_->graph_group_fusion_helper(), src.GetGroup(), dst.GetGroup());
}
template <typename FusePassCtxT>
bool GraphGroupFuseHelper<FusePassCtxT>::ReduceFuseReduce(
const OpGroupPtr& src, const OpGroupPtr& dst) const {
return reduce_fuse_reduce(
&ctx_->graph_group_fusion_helper(), src.GetGroup(), dst.GetGroup());
}
template <typename FusePassCtxT>
struct HorizontalFuseUtil {
using KindKeyT = std::pair<OpPatternKind, OpPatternKind>;
static bool DetectFusabilityByKind(FusePassCtxT* ctx,
const OpGroupPtr& src,
const OpGroupPtr& dst) {
const KindKeyT kind_pair(src.kind(), dst.kind());
const auto& map = GetConditionMap();
const auto& iter = map.find(kind_pair);
if (iter == map.end()) {
return false;
}
return iter->second(ctx, src, dst);
}
typedef bool (*ConditionT)(FusePassCtxT* ctx,
const OpGroupPtr& src,
const OpGroupPtr& dst);
static const std::map<KindKeyT, ConditionT>& GetConditionMap() {
thread_local static std::map<KindKeyT, ConditionT> map(RawConditionMap());
return map;
}
static std::map<KindKeyT, ConditionT> RawConditionMap() {
return std::map<KindKeyT, ConditionT>{
{{OpPatternKind::kElementWise, framework::kElementWise}, &IsSameSize},
{{OpPatternKind::kElementWise, framework::kBroadcast}, &IsSameSize},
{{OpPatternKind::kElementWise, framework::kInjective}, &IsSameSize},
{{OpPatternKind::kElementWise, framework::kReduction},
&HorizontalElementwiseFuseReduce},
{{OpPatternKind::kBroadcast, framework::kElementWise}, &IsSameSize},
{{OpPatternKind::kBroadcast, framework::kBroadcast}, &IsSameSize},
{{OpPatternKind::kBroadcast, framework::kInjective}, &IsSameSize},
{{OpPatternKind::kBroadcast, framework::kReduction}, &IsSameSize},
{{OpPatternKind::kInjective, framework::kElementWise}, &IsSameSize},
{{OpPatternKind::kInjective, framework::kBroadcast}, &IsSameSize},
{{OpPatternKind::kInjective, framework::kInjective}, &IsSameSize},
{{OpPatternKind::kInjective, framework::kReduction}, &IsSameSize},
{{OpPatternKind::kReduction, framework::kElementWise},
&HorizontalElementwiseFuseReduce},
{{OpPatternKind::kReduction, framework::kBroadcast}, &IsSameSize},
{{OpPatternKind::kReduction, framework::kInjective}, &IsSameSize},
{{OpPatternKind::kReduction, framework::kReduction}, &ReduceFuseReduce},
};
}
static bool IsSameSize(FusePassCtxT* ctx,
const OpGroupPtr& src,
const OpGroupPtr& dst) {
return utils::IsSameSize(src, dst);
}
static bool HorizontalElementwiseFuseReduce(FusePassCtxT* ctx,
const OpGroupPtr& src,
const OpGroupPtr& dst) {
// if same shape with horizontal relation
if (IsSameSize(ctx, src, dst)) {
return true;
}
const OpGroupPtr* ele_group = nullptr;
const OpGroupPtr* reduce_group = nullptr;
if (src.kind() == framework::kReduction) {
ele_group = &dst;
reduce_group = &src;
} else {
ele_group = &src;
reduce_group = &dst;
}
size_t size_ele =
utils::GetMasterNode(*ele_group).outputs()[0].shape().numel();
bool can_fuse = false;
reduce_group->WalkOpNodes([&](const api::OpNode& op) {
if (op.kind() == OpPatternKind::kReduction) {
size_t size_master = op.outputs()[0].shape().numel();
if (size_ele == size_master) {
can_fuse = true;
}
}
});
return can_fuse;
}
static bool ReduceFuseReduce(FusePassCtxT* ctx,
const OpGroupPtr& src,
const OpGroupPtr& dst) {
return ctx->fuse_helper().ReduceFuseReduce(src, dst);
}
};
class FusePass {
public:
virtual ~FusePass() = default;
virtual const std::string FuseMode() const = 0;
virtual int Benefit() const = 0;
protected:
FusePass() = default;
};
class InputFusePass : public FusePass {
public:
virtual ~InputFusePass() = default;
virtual void operator()(InputFusePassCtx* ctx) const = 0;
const std::string FuseMode() const final { return "InputFuse"; }
virtual int Benefit() const = 0;
protected:
InputFusePass() = default;
};
class DefaultInputFusePass final : public InputFusePass {
public:
DefaultInputFusePass() : InputFusePass() {}
int Benefit() const override { return 100; }
void operator()(InputFusePassCtx* ctx) const override {
const auto& consumer_set = ctx->PickConsumersWithSameInputs();
const std::unordered_set<OpGroupPtr> consumer_candidates =
[&]() -> std::unordered_set<OpGroupPtr> {
std::unordered_set<OpGroupPtr> consumers;
for (const auto& consumer : consumer_set) {
if (consumer.kind() == framework::kElementWise ||
consumer.kind() == framework::kBroadcast ||
consumer.kind() == framework::kInjective ||
consumer.kind() == framework::kReduction) {
consumers.insert(consumer);
}
}
return consumers;
}();
if (consumer_candidates.size() <= 1) {
return;
}
std::vector<OpGroupList> fusionable_consumers;
for (auto& candidate : consumer_candidates) {
if (ctx->fuse_helper().IsConsumerSetsReachable(candidate,
consumer_candidates)) {
continue;
}
if (fusionable_consumers.empty()) {
fusionable_consumers.push_back({candidate});
continue;
}
// check each fusionable groups
bool fusionable = false;
for (auto& groups : fusionable_consumers) {
auto& last = groups.back();
if (!HorizontalFuseUtil<InputFusePassCtx>::DetectFusabilityByKind(
ctx, candidate, last)) {
continue;
}
groups.push_back(candidate);
fusionable = true;
break;
}
// if can't fuse to othors Groups, new Groups.
if (!fusionable) {
fusionable_consumers.push_back({candidate});
}
}
for (const auto& groups : fusionable_consumers) {
if (groups.size() > 1) {
ctx->MarkFusible(groups);
}
}
VLOG(1) << "DefaultInputFusePass Finish";
}
};
class LightwareFusePass : public FusePass {
public:
virtual ~LightwareFusePass() = default;
virtual void operator()(LightwareFusePassCtx* ctx) const = 0;
virtual const std::string FuseMode() const = 0;
virtual int Benefit() const = 0;
protected:
LightwareFusePass() = default;
};
class HorizontalFusePass : public LightwareFusePass {
public:
virtual ~HorizontalFusePass() = default;
virtual void operator()(LightwareFusePassCtx* ctx) const = 0;
const std::string FuseMode() const final { return "HorizontalFuse"; }
virtual int Benefit() const = 0;
protected:
HorizontalFusePass() = default;
};
class DefaultHorizontalFusePass final : public HorizontalFusePass {
public:
DefaultHorizontalFusePass() : HorizontalFusePass() {}
int Benefit() const override { return 100; }
void operator()(LightwareFusePassCtx* ctx) const override {
const auto& producer = ctx->PickOpGroup();
const std::unordered_set<OpGroupPtr> consumer_candidates =
[&]() -> std::unordered_set<OpGroupPtr> {
std::unordered_set<OpGroupPtr> consumers;
for (const auto& consumer : producer.consumers()) {
if (consumer.kind() == framework::kElementWise ||
consumer.kind() == framework::kBroadcast ||
consumer.kind() == framework::kInjective ||
consumer.kind() == framework::kReduction) {
consumers.insert(consumer);
}
}
return consumers;
}();
if (consumer_candidates.size() <= 1) {
return;
}
std::vector<OpGroupList> fusionable_consumers;
for (auto& candidate : consumer_candidates) {
if (ctx->fuse_helper().IsConsumerSetsReachable(candidate,
consumer_candidates)) {
continue;
}
if (fusionable_consumers.empty()) {
fusionable_consumers.push_back({candidate});
continue;
}
// check each fusionable groups
bool fusionable = false;
for (auto& groups : fusionable_consumers) {
auto& last = groups.back();
if (!HorizontalFuseUtil<LightwareFusePassCtx>::DetectFusabilityByKind(
ctx, candidate, last)) {
continue;
}
groups.push_back(candidate);
fusionable = true;
break;
}
// if can't fuse to othors Groups, new Groups.
if (!fusionable) {
fusionable_consumers.push_back({candidate});
}
}
for (const auto& groups : fusionable_consumers) {
if (groups.size() > 1) {
// Trick for BERT, maybe not required, wait for substitution from
// unordered_set to set
if (groups.size() == 2) {
OpGroupList fuse_group;
if (groups[1].group_id().substr(0, 4) == "cast" &&
groups[0].group_id() == "reshape_split") {
fuse_group.push_back(groups[1]);
fuse_group.push_back(groups[0]);
ctx->MarkFusible(fuse_group);
continue;
}
}
ctx->MarkFusible(groups);
}
}
}
};
class VerticalFusePass : public LightwareFusePass {
public:
virtual ~VerticalFusePass() = default;
virtual void operator()(LightwareFusePassCtx* ctx) const = 0;
const std::string FuseMode() const final { return "VerticalFuse"; }
virtual int Benefit() const = 0;
protected:
VerticalFusePass() = default;
};
class DefaultVerticalFusePass final : public VerticalFusePass {
public:
DefaultVerticalFusePass() : VerticalFusePass() {}
int Benefit() const override { return 100; }
void operator()(LightwareFusePassCtx* ctx) const override {
const auto& producer = ctx->PickOpGroup();
const OpGroupList consumers = [&]() {
OpGroupList consumers;
for (const auto& consumer : producer.consumers()) {
consumers.push_back(consumer);
}
return consumers;
}();
if (consumers.size() == 0) {
return;
}
std::vector<OpGroupPtr> candidates;
for (int i = 0; i < consumers.size(); ++i) {
const auto& consumer = consumers.at(i);
if (!DetectFusabilityByKind(ctx, producer, consumer)) {
break;
}
candidates.push_back(consumer);
}
if (candidates.size() == consumers.size() &&
producer.kind() == framework::kElementWise) {
return;
}
for (int i = 0; i < consumers.size(); ++i) {
const auto& consumer = consumers.at(i);
if (!DetectFusabilityByKind(ctx, producer, consumer)) {
continue;
}
if (ctx->fuse_helper().DetectCycleIfFuse(producer, consumer)) {
VLOG(4) << "Can't fuse because detect cycle";
continue;
}
ctx->MarkFusible(producer, consumer);
}
}
using KindKeyT = std::pair<OpPatternKind, OpPatternKind>;
bool DetectFusabilityByKind(LightwareFusePassCtx* ctx,
const OpGroupPtr& src,
const OpGroupPtr& dst) const {
const KindKeyT kind_pair(src.kind(), dst.kind());
const auto& map = GetConditionMap();
const auto& iter = map.find(kind_pair);
if (iter == map.end()) {
return false;
}
return iter->second(ctx, src, dst);
}
typedef bool (*ConditionT)(LightwareFusePassCtx* ctx,
const OpGroupPtr& src,
const OpGroupPtr& dst);
static const std::map<KindKeyT, ConditionT>& GetConditionMap() {
thread_local static std::map<KindKeyT, ConditionT> map(RawConditionMap());
return map;
}
static std::map<KindKeyT, ConditionT> RawConditionMap() {
return std::map<KindKeyT, ConditionT>{
{{OpPatternKind::kElementWise, framework::kElementWise},
&DefaultVerticalFusePass::IsSameSize},
{{OpPatternKind::kElementWise, framework::kBroadcast},
&DefaultVerticalFusePass::ElementwiseFuseBroadcast},
{{OpPatternKind::kElementWise, framework::kInjective},
&DefaultVerticalFusePass::HorizontalWithInjective},
{{OpPatternKind::kElementWise, framework::kReduction},
&DefaultVerticalFusePass::ElementwiseFuseReduce},
{{OpPatternKind::kBroadcast, framework::kElementWise},
&DefaultVerticalFusePass::IsSameSize},
{{OpPatternKind::kBroadcast, framework::kBroadcast},
&DefaultVerticalFusePass::IsSameSize},
{{OpPatternKind::kBroadcast, framework::kInjective},
&DefaultVerticalFusePass::HorizontalWithInjective},
{{OpPatternKind::kBroadcast, framework::kReduction},
&DefaultVerticalFusePass::BroadcastFuseReduce},
{{OpPatternKind::kInjective, framework::kElementWise},
&DefaultVerticalFusePass::IsSameSize},
{{OpPatternKind::kInjective, framework::kBroadcast},
&DefaultVerticalFusePass::IsSameSize},
{{OpPatternKind::kInjective, framework::kInjective},
&DefaultVerticalFusePass::HorizontalWithInjective},
{{OpPatternKind::kInjective, framework::kReduction},
&DefaultVerticalFusePass::InjectiveHorizontalWithReduce},
{{OpPatternKind::kReduction, framework::kElementWise},
&DefaultVerticalFusePass::ReduceFuseElementwise},
{{OpPatternKind::kReduction, framework::kBroadcast},
&DefaultVerticalFusePass::ReduceFuseBroadcast},
{{OpPatternKind::kReduction, framework::kInjective},
&DefaultVerticalFusePass::HorizontalWithInjective},
{{OpPatternKind::kReduction, framework::kReduction},
&DefaultVerticalFusePass::ReduceFuseReduce},
};
}
static bool IsSameSize(LightwareFusePassCtx* ctx,
const OpGroupPtr& src,
const OpGroupPtr& dst) {
return utils::IsSameSize(src, dst);
}
static bool ElementwiseFuseBroadcast(LightwareFusePassCtx* ctx,
const OpGroupPtr& src,
const OpGroupPtr& dst) {
return ctx->fuse_helper().ElementwiseFuseBroadcast(src, dst);
}
static bool HorizontalWithInjective(LightwareFusePassCtx* ctx,
const OpGroupPtr& src,
const OpGroupPtr& dst) {
return ctx->fuse_helper().HorizontalWithInjective(src, dst);
}
static bool ElementwiseFuseReduce(LightwareFusePassCtx* ctx,
const OpGroupPtr& src,
const OpGroupPtr& dst) {
return ctx->fuse_helper().ElementwiseFuseReduce(src, dst);
}
static bool BroadcastFuseReduce(LightwareFusePassCtx* ctx,
const OpGroupPtr& src,
const OpGroupPtr& dst) {
return ctx->fuse_helper().BroadcastFuseReduce(src, dst);
}
static bool InjectiveHorizontalWithReduce(LightwareFusePassCtx* ctx,
const OpGroupPtr& src,
const OpGroupPtr& dst) {
return ctx->fuse_helper().InjectiveHorizontalWithReduce(src, dst);
}
static bool ReduceFuseElementwise(LightwareFusePassCtx* ctx,
const OpGroupPtr& src,
const OpGroupPtr& dst) {
return ctx->fuse_helper().ReduceFuseElementwise(src, dst);
}
static bool ReduceFuseBroadcast(LightwareFusePassCtx* ctx,
const OpGroupPtr& src,
const OpGroupPtr& dst) {
return ctx->fuse_helper().ReduceFuseBroadcast(src, dst);
}
static bool ReduceFuseReduce(LightwareFusePassCtx* ctx,
const OpGroupPtr& src,
const OpGroupPtr& dst) {
return ctx->fuse_helper().ReduceFuseReduce(src, dst);
}
};
class RecomputeFusePass : public LightwareFusePass {
public:
virtual ~RecomputeFusePass() = default;
virtual void operator()(LightwareFusePassCtx* ctx) const = 0;
const std::string FuseMode() const final { return "RecomputeFuse"; }
virtual int Benefit() const = 0;
protected:
RecomputeFusePass() = default;
};
class DefaultRecomputeFusePass final : public RecomputeFusePass {
public:
DefaultRecomputeFusePass() : RecomputeFusePass() {}
int Benefit() const override { return 100; }
void operator()(LightwareFusePassCtx* ctx) const override {
const auto& producer = ctx->PickOpGroup();
const OpGroupList consumers = [&]() {
OpGroupList consumers;
for (const auto& consumer : producer.consumers()) {
consumers.push_back(consumer);
}
return consumers;
}();
// Borrows unsafe_candidates and candidates concept from origin
// fusion_merge_pass
std::vector<OpGroupPtr> unsafe_candidates;
std::vector<OpGroupPtr> candidates;
for (int i = 0; i < consumers.size(); ++i) {
const auto& consumer = consumers.at(i);
if (!DetectFusabilityByKind(ctx, producer, consumer)) {
continue;
}
unsafe_candidates.push_back(consumer);
if (ctx->fuse_helper().DetectCycleIfFuse(producer, consumer)) {
continue;
}
candidates.push_back(consumer);
}
if (!candidates.empty() && unsafe_candidates.size() == consumers.size() &&
producer.kind() == framework::kElementWise) {
for (const auto& consumer : consumers) {
ctx->MarkFusible(producer, consumer);
}
}
}
using KindKeyT = std::pair<OpPatternKind, OpPatternKind>;
bool DetectFusabilityByKind(LightwareFusePassCtx* ctx,
const OpGroupPtr& src,
const OpGroupPtr& dst) const {
const KindKeyT kind_pair(src.kind(), dst.kind());
const auto& map = DefaultVerticalFusePass::GetConditionMap();
const auto& iter = map.find(kind_pair);
if (iter == map.end()) {
return false;
}
return iter->second(ctx, src, dst);
}
};
struct LightwareFusePassComparator {
bool operator()(const std::shared_ptr<LightwareFusePass>& lhs,
const std::shared_ptr<LightwareFusePass>& rhs) const {
return lhs->Benefit() > rhs->Benefit();
}
};
struct InputFusePassComparator {
bool operator()(const std::shared_ptr<InputFusePass>& lhs,
const std::shared_ptr<InputFusePass>& rhs) const {
return lhs->Benefit() > rhs->Benefit();
}
};
class FusionPassMap {
public:
static FusionPassMap& Instance() {
static FusionPassMap global_fusion_pass_map;
return global_fusion_pass_map;
}
bool Has(const std::string& pass_name) const {
return map_.find(pass_name) != map_.end();
}
void Insert(const std::string& pass_name,
const std::shared_ptr<FusePass>& pass) {
CHECK(!Has(pass_name)) << "FusePass " << pass_name
<< " has already been registered.";
map_.insert({pass_name, pass});
}
std::shared_ptr<FusePass> Get(const std::string& pass_name) const {
auto it = map_.find(pass_name);
CHECK(it != map_.end())
<< "FusePass " << pass_name << " has not been registered.";
return it->second;
}
// fuse_mode: HorizontalFuse, VerticalFuse, RecomputeFuse
std::vector<std::shared_ptr<LightwareFusePass>> GetLightwareFusePassesByMode(
const std::string& fuse_mode) const {
CHECK(fuse_mode == "HorizontalFuse" || fuse_mode == "VerticalFuse" ||
fuse_mode == "RecomputeFuse")
<< "fuse_mode only supports HorizontalFuse, VerticalFuse and "
"RecomputeFuse. Please check your input modes = "
<< fuse_mode;
std::set<std::shared_ptr<LightwareFusePass>, LightwareFusePassComparator>
candidate_passes;
for (const auto iter : map_) {
if (fuse_mode == iter.second->FuseMode()) {
candidate_passes.insert(
std::dynamic_pointer_cast<LightwareFusePass>(iter.second));
}
}
return std::vector<std::shared_ptr<LightwareFusePass>>(
candidate_passes.begin(), candidate_passes.end());
}
std::vector<std::shared_ptr<InputFusePass>> GetInputFusePasses() const {
std::set<std::shared_ptr<InputFusePass>, InputFusePassComparator>
candidate_passes;
for (const auto iter : map_) {
if (iter.second->FuseMode() == "InputFuse") {
candidate_passes.insert(
std::dynamic_pointer_cast<InputFusePass>(iter.second));
}
}
return std::vector<std::shared_ptr<InputFusePass>>(candidate_passes.begin(),
candidate_passes.end());
}
private:
FusionPassMap() = default;
std::unordered_map<std::string, std::shared_ptr<FusePass>> map_;
DISABLE_COPY_AND_ASSIGN(FusionPassMap);
};
class Registrar {
public:
// In our design, various kinds of classes, e.g., operators and kernels,
// have their corresponding registry and registrar. The action of
// registration is in the constructor of a global registrar variable, which
// are not used in the code that calls package framework, and would
// be removed from the generated binary file by the linker. To avoid such
// removal, we add Touch to all registrar classes and make USE_OP macros to
// call this method. So, as long as the callee code calls USE_OP, the global
// registrar variable won't be removed by the linker.
void Touch() {}
};
template <typename PassClassT>
class FusionPassRegistrar final : public Registrar {
public:
explicit FusionPassRegistrar(const std::string& pass_name) {
FusionPassMap::Instance().Insert(
pass_name, std::shared_ptr<PassClassT>(new PassClassT()));
}
};
// Op Fusion Pass which performs Ops fusion, Ops are fused
// "vertically", meaning producing Ops are fused into their consumers
// with the intent that the loops which compute their values will be fused in
......@@ -2078,12 +1075,3 @@ CINN_REGISTER_HELPER(GeneralFusionMergePass) {
return true;
}
CINN_REGISTER_FUSION_PASS(DefaultHorizontalFusePass,
cinn::hlir::pass::DefaultHorizontalFusePass);
CINN_REGISTER_FUSION_PASS(DefaultVerticalFusePass,
cinn::hlir::pass::DefaultVerticalFusePass);
CINN_REGISTER_FUSION_PASS(DefaultRecomputeFusePass,
cinn::hlir::pass::DefaultRecomputeFusePass);
CINN_REGISTER_FUSION_PASS(DefaultInputFusePass,
cinn::hlir::pass::DefaultInputFusePass);
set(fusion_pass_file
${PADDLE_BINARY_DIR}/paddle/cinn/hlir/pass/use_general_pass.h)
file(
WRITE ${fusion_pass_file}
"#include \"paddle/cinn/common/macros.h\" // Generated by the paddle/cinn/hlir/pass/CMakeLists.txt. DO NOT EDIT!\n\n#pragma once\n\n"
)
function(find_fusion_pass_register FILENAME ADD_PATH PATTERN)
# set op_name to OUTPUT
file(READ ${FILENAME} CONTENT)
string(REGEX MATCHALL "${PATTERN}\\([a-zA-Z0-9_]*," fusion_pass_patterns
"${CONTENT}")
if(NOT fusion_pass_patterns STREQUAL "")
foreach(pass_pattern ${fusion_pass_patterns})
string(REPLACE "${PATTERN}(" "" pass_pattern "${pass_pattern}")
string(REPLACE "," "" pass_pattern "${pass_pattern}")
file(APPEND ${ADD_PATH} "USE_FUSION_PASS(${pass_pattern});\n")
endforeach()
endif()
endfunction()
file(GLOB_RECURSE FUSION_PASS_HEADER ./*.cc)
foreach(file ${FUSION_PASS_HEADER})
find_fusion_pass_register("${file}" ${fusion_pass_file}
"CINN_REGISTER_FUSION_PASS")
endforeach()
core_gather_headers()
gather_srcs(
cinnapi_src SRCS default_horizontal_fuse_pass.cc default_input_fuse_pass.cc
default_vertical_fuse_pass.cc default_recompute_fuse_pass.cc)
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/cinn/hlir/pass/general_fusion_merge_pass/fusion_pass_registrar.h"
#include "paddle/cinn/hlir/pass/general_fusion_merge_pass/horizontal_fuse_pass.h"
#include "paddle/cinn/hlir/pass/general_fusion_merge_pass/horizontal_fuse_util.h"
#include "paddle/cinn/hlir/pass/general_fusion_merge_pass/lightware_fuse_pass_ctx.h"
namespace cinn {
namespace hlir {
namespace pass {
class DefaultHorizontalFusePass final : public HorizontalFusePass {
public:
DefaultHorizontalFusePass() : HorizontalFusePass() {}
int Benefit() const override { return 100; }
void operator()(LightwareFusePassCtx* ctx) const override {
const auto& producer = ctx->PickOpGroup();
const std::unordered_set<OpGroupPtr> consumer_candidates =
[&]() -> std::unordered_set<OpGroupPtr> {
std::unordered_set<OpGroupPtr> consumers;
for (const auto& consumer : producer.consumers()) {
if (consumer.kind() == framework::kElementWise ||
consumer.kind() == framework::kBroadcast ||
consumer.kind() == framework::kInjective ||
consumer.kind() == framework::kReduction) {
consumers.insert(consumer);
}
}
return consumers;
}();
if (consumer_candidates.size() <= 1) {
return;
}
std::vector<OpGroupList> fusionable_consumers;
for (auto& candidate : consumer_candidates) {
if (ctx->fuse_helper().IsConsumerSetsReachable(candidate,
consumer_candidates)) {
continue;
}
if (fusionable_consumers.empty()) {
fusionable_consumers.push_back({candidate});
continue;
}
// check each fusionable groups
bool fusionable = false;
for (auto& groups : fusionable_consumers) {
auto& last = groups.back();
if (!HorizontalFuseUtil<LightwareFusePassCtx>::DetectFusabilityByKind(
ctx, candidate, last)) {
continue;
}
groups.push_back(candidate);
fusionable = true;
break;
}
// if can't fuse to othors Groups, new Groups.
if (!fusionable) {
fusionable_consumers.push_back({candidate});
}
}
for (const auto& groups : fusionable_consumers) {
if (groups.size() > 1) {
// Trick for BERT, maybe not required, wait for substitution from
// unordered_set to set
if (groups.size() == 2) {
OpGroupList fuse_group;
if (groups[1].group_id().substr(0, 4) == "cast" &&
groups[0].group_id() == "reshape_split") {
fuse_group.push_back(groups[1]);
fuse_group.push_back(groups[0]);
ctx->MarkFusible(fuse_group);
continue;
}
}
ctx->MarkFusible(groups);
}
}
}
};
} // namespace pass
} // namespace hlir
} // namespace cinn
CINN_REGISTER_FUSION_PASS(DefaultHorizontalFusePass,
cinn::hlir::pass::DefaultHorizontalFusePass);
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/cinn/hlir/pass/general_fusion_merge_pass/fusion_pass_registrar.h"
#include "paddle/cinn/hlir/pass/general_fusion_merge_pass/horizontal_fuse_util.h"
#include "paddle/cinn/hlir/pass/general_fusion_merge_pass/input_fuse_pass.h"
#include "paddle/cinn/hlir/pass/general_fusion_merge_pass/input_fuse_pass_ctx.h"
namespace cinn {
namespace hlir {
namespace pass {
class DefaultInputFusePass final : public InputFusePass {
public:
DefaultInputFusePass() : InputFusePass() {}
int Benefit() const override { return 100; }
void operator()(InputFusePassCtx* ctx) const override {
const auto& consumer_set = ctx->PickConsumersWithSameInputs();
const std::unordered_set<OpGroupPtr> consumer_candidates =
[&]() -> std::unordered_set<OpGroupPtr> {
std::unordered_set<OpGroupPtr> consumers;
for (const auto& consumer : consumer_set) {
if (consumer.kind() == framework::kElementWise ||
consumer.kind() == framework::kBroadcast ||
consumer.kind() == framework::kInjective ||
consumer.kind() == framework::kReduction) {
consumers.insert(consumer);
}
}
return consumers;
}();
if (consumer_candidates.size() <= 1) {
return;
}
std::vector<OpGroupList> fusionable_consumers;
for (auto& candidate : consumer_candidates) {
if (ctx->fuse_helper().IsConsumerSetsReachable(candidate,
consumer_candidates)) {
continue;
}
if (fusionable_consumers.empty()) {
fusionable_consumers.push_back({candidate});
continue;
}
// check each fusionable groups
bool fusionable = false;
for (auto& groups : fusionable_consumers) {
auto& last = groups.back();
if (!HorizontalFuseUtil<InputFusePassCtx>::DetectFusabilityByKind(
ctx, candidate, last)) {
continue;
}
groups.push_back(candidate);
fusionable = true;
break;
}
// if can't fuse to othors Groups, new Groups.
if (!fusionable) {
fusionable_consumers.push_back({candidate});
}
}
for (const auto& groups : fusionable_consumers) {
if (groups.size() > 1) {
ctx->MarkFusible(groups);
}
}
VLOG(1) << "DefaultInputFusePass Finish";
}
};
} // namespace pass
} // namespace hlir
} // namespace cinn
CINN_REGISTER_FUSION_PASS(DefaultInputFusePass,
cinn::hlir::pass::DefaultInputFusePass);
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/cinn/hlir/pass/general_fusion_merge_pass/fusion_pass_registrar.h"
#include "paddle/cinn/hlir/pass/general_fusion_merge_pass/lightware_fuse_pass_ctx.h"
#include "paddle/cinn/hlir/pass/general_fusion_merge_pass/recompute_fuse_pass.h"
#include "paddle/cinn/hlir/pass/general_fusion_merge_pass/vertical_fuse_util.h"
namespace cinn {
namespace hlir {
namespace pass {
class DefaultRecomputeFusePass final : public RecomputeFusePass {
public:
DefaultRecomputeFusePass() : RecomputeFusePass() {}
int Benefit() const override { return 100; }
void operator()(LightwareFusePassCtx* ctx) const override {
const auto& producer = ctx->PickOpGroup();
const OpGroupList consumers = [&]() {
OpGroupList consumers;
for (const auto& consumer : producer.consumers()) {
consumers.push_back(consumer);
}
return consumers;
}();
// Borrows unsafe_candidates and candidates concept from origin
// fusion_merge_pass
std::vector<OpGroupPtr> unsafe_candidates;
std::vector<OpGroupPtr> candidates;
for (int i = 0; i < consumers.size(); ++i) {
const auto& consumer = consumers.at(i);
if (!VerticalFuseUtil::DetectFusabilityByKind(ctx, producer, consumer)) {
continue;
}
unsafe_candidates.push_back(consumer);
if (ctx->fuse_helper().DetectCycleIfFuse(producer, consumer)) {
continue;
}
candidates.push_back(consumer);
}
if (!candidates.empty() && unsafe_candidates.size() == consumers.size() &&
producer.kind() == framework::kElementWise) {
for (const auto& consumer : consumers) {
ctx->MarkFusible(producer, consumer);
}
}
}
};
} // namespace pass
} // namespace hlir
} // namespace cinn
CINN_REGISTER_FUSION_PASS(DefaultRecomputeFusePass,
cinn::hlir::pass::DefaultRecomputeFusePass);
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/cinn/hlir/pass/general_fusion_merge_pass/fusion_pass_registrar.h"
#include "paddle/cinn/hlir/pass/general_fusion_merge_pass/lightware_fuse_pass_ctx.h"
#include "paddle/cinn/hlir/pass/general_fusion_merge_pass/vertical_fuse_pass.h"
#include "paddle/cinn/hlir/pass/general_fusion_merge_pass/vertical_fuse_util.h"
#include "paddle/cinn/hlir/pass/general_fusion_merge_pass_utils.h"
namespace cinn {
namespace hlir {
namespace pass {
class DefaultVerticalFusePass final : public VerticalFusePass {
public:
DefaultVerticalFusePass() : VerticalFusePass() {}
int Benefit() const override { return 100; }
void operator()(LightwareFusePassCtx* ctx) const override {
const auto& producer = ctx->PickOpGroup();
const OpGroupList consumers = [&]() {
OpGroupList consumers;
for (const auto& consumer : producer.consumers()) {
consumers.push_back(consumer);
}
return consumers;
}();
if (consumers.size() == 0) {
return;
}
std::vector<OpGroupPtr> candidates;
for (int i = 0; i < consumers.size(); ++i) {
const auto& consumer = consumers.at(i);
if (!VerticalFuseUtil::DetectFusabilityByKind(ctx, producer, consumer)) {
break;
}
candidates.push_back(consumer);
}
if (candidates.size() == consumers.size() &&
producer.kind() == framework::kElementWise) {
return;
}
for (int i = 0; i < consumers.size(); ++i) {
const auto& consumer = consumers.at(i);
if (!VerticalFuseUtil::DetectFusabilityByKind(ctx, producer, consumer)) {
continue;
}
if (ctx->fuse_helper().DetectCycleIfFuse(producer, consumer)) {
VLOG(4) << "Can't fuse because detect cycle";
continue;
}
ctx->MarkFusible(producer, consumer);
}
}
};
} // namespace pass
} // namespace hlir
} // namespace cinn
CINN_REGISTER_FUSION_PASS(DefaultVerticalFusePass,
cinn::hlir::pass::DefaultVerticalFusePass);
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/cinn/api/op_group.h"
namespace cinn {
namespace hlir {
namespace pass {
using OpGroupPtr = api::OpGroup;
class FuseHelper {
public:
virtual ~FuseHelper() = default;
virtual bool AllOutputsSameSize(const OpGroupPtr& first,
const OpGroupPtr& second) const = 0;
virtual bool HorizontalElementwiseFuseReduce(const OpGroupPtr& src,
const OpGroupPtr& dst) const = 0;
virtual bool ElementwiseFuseBroadcast(const OpGroupPtr& src,
const OpGroupPtr& dst) const = 0;
virtual bool HorizontalWithInjective(const OpGroupPtr& src,
const OpGroupPtr& dst) const = 0;
virtual bool ElementwiseFuseReduce(const OpGroupPtr& src,
const OpGroupPtr& dst) const = 0;
virtual bool BroadcastFuseReduce(const OpGroupPtr& src,
const OpGroupPtr& dst) const = 0;
virtual bool InjectiveHorizontalWithReduce(const OpGroupPtr& src,
const OpGroupPtr& dst) const = 0;
virtual bool ReduceFuseElementwise(const OpGroupPtr& src,
const OpGroupPtr& dst) const = 0;
virtual bool ReduceFuseBroadcast(const OpGroupPtr& src,
const OpGroupPtr& dst) const = 0;
virtual bool ReduceFuseReduce(const OpGroupPtr& src,
const OpGroupPtr& dst) const = 0;
virtual bool IsReachable(const OpGroupPtr& lhs,
const OpGroupPtr& rhs) const = 0;
virtual bool DetectCycleIfFuse(const OpGroupPtr& src,
const OpGroupPtr& dst) const = 0;
virtual bool IsConsumerSetsReachable(
const OpGroupPtr& group,
const std::unordered_set<OpGroupPtr>& consumers) const = 0;
protected:
FuseHelper() = default;
};
} // namespace pass
} // namespace hlir
} // namespace cinn
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/cinn/hlir/pass/use_general_pass.h"
namespace cinn {
namespace hlir {
namespace pass {
class FusePass {
public:
virtual ~FusePass() = default;
virtual const std::string FuseMode() const = 0;
virtual int Benefit() const = 0;
protected:
FusePass() = default;
};
} // namespace pass
} // namespace hlir
} // namespace cinn
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/cinn/api/op_group.h"
#include "paddle/cinn/hlir/pass/general_fusion_merge_pass/fuse_helper.h"
namespace cinn {
namespace hlir {
namespace pass {
using OpGroupPtr = api::OpGroup;
class FusePassCtx {
public:
virtual ~FusePassCtx() {}
virtual const FuseHelper& fuse_helper() const = 0;
virtual void MarkFusible(const OpGroupPtr& first,
const OpGroupPtr& second) = 0;
protected:
FusePassCtx() = default;
};
} // namespace pass
} // namespace hlir
} // namespace cinn
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <set>
#include <unordered_map>
#include "glog/logging.h"
#include "paddle/cinn/hlir/pass/general_fusion_merge_pass/fuse_pass.h"
#include "paddle/cinn/hlir/pass/general_fusion_merge_pass/input_fuse_pass.h"
#include "paddle/cinn/hlir/pass/general_fusion_merge_pass/lightware_fuse_pass.h"
namespace cinn {
namespace hlir {
namespace pass {
struct LightwareFusePassComparator {
bool operator()(const std::shared_ptr<LightwareFusePass>& lhs,
const std::shared_ptr<LightwareFusePass>& rhs) const {
return lhs->Benefit() > rhs->Benefit();
}
};
struct InputFusePassComparator {
bool operator()(const std::shared_ptr<InputFusePass>& lhs,
const std::shared_ptr<InputFusePass>& rhs) const {
return lhs->Benefit() > rhs->Benefit();
}
};
class FusionPassMap {
public:
static FusionPassMap& Instance() {
static FusionPassMap global_fusion_pass_map;
return global_fusion_pass_map;
}
bool Has(const std::string& pass_name) const {
return map_.find(pass_name) != map_.end();
}
void Insert(const std::string& pass_name,
const std::shared_ptr<FusePass>& pass) {
CHECK(!Has(pass_name)) << "FusePass " << pass_name
<< " has already been registered.";
map_.insert({pass_name, pass});
}
std::shared_ptr<FusePass> Get(const std::string& pass_name) const {
auto it = map_.find(pass_name);
CHECK(it != map_.end())
<< "FusePass " << pass_name << " has not been registered.";
return it->second;
}
// fuse_mode: HorizontalFuse, VerticalFuse, RecomputeFuse
std::vector<std::shared_ptr<LightwareFusePass>> GetLightwareFusePassesByMode(
const std::string& fuse_mode) const {
CHECK(fuse_mode == "HorizontalFuse" || fuse_mode == "VerticalFuse" ||
fuse_mode == "RecomputeFuse")
<< "fuse_mode only supports HorizontalFuse, VerticalFuse and "
"RecomputeFuse. Please check your input modes = "
<< fuse_mode;
std::set<std::shared_ptr<LightwareFusePass>, LightwareFusePassComparator>
candidate_passes;
for (const auto iter : map_) {
if (fuse_mode == iter.second->FuseMode()) {
candidate_passes.insert(
std::dynamic_pointer_cast<LightwareFusePass>(iter.second));
}
}
return std::vector<std::shared_ptr<LightwareFusePass>>(
candidate_passes.begin(), candidate_passes.end());
}
std::vector<std::shared_ptr<InputFusePass>> GetInputFusePasses() const {
std::set<std::shared_ptr<InputFusePass>, InputFusePassComparator>
candidate_passes;
for (const auto iter : map_) {
if (iter.second->FuseMode() == "InputFuse") {
candidate_passes.insert(
std::dynamic_pointer_cast<InputFusePass>(iter.second));
}
}
return std::vector<std::shared_ptr<InputFusePass>>(candidate_passes.begin(),
candidate_passes.end());
}
private:
FusionPassMap() = default;
std::unordered_map<std::string, std::shared_ptr<FusePass>> map_;
DISABLE_COPY_AND_ASSIGN(FusionPassMap);
};
} // namespace pass
} // namespace hlir
} // namespace cinn
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <unordered_map>
#include "paddle/cinn/common/macros.h"
#include "paddle/cinn/hlir/pass/general_fusion_merge_pass/fusion_pass_map.h"
namespace cinn {
namespace hlir {
namespace pass {
class Registrar {
public:
// In our design, various kinds of classes, e.g., operators and kernels,
// have their corresponding registry and registrar. The action of
// registration is in the constructor of a global registrar variable, which
// are not used in the code that calls package framework, and would
// be removed from the generated binary file by the linker. To avoid such
// removal, we add Touch to all registrar classes and make USE_OP macros to
// call this method. So, as long as the callee code calls USE_OP, the global
// registrar variable won't be removed by the linker.
void Touch() {}
};
template <typename PassClassT>
class FusionPassRegistrar final : public Registrar {
public:
explicit FusionPassRegistrar(const std::string& pass_name) {
FusionPassMap::Instance().Insert(
pass_name, std::shared_ptr<PassClassT>(new PassClassT()));
}
};
} // namespace pass
} // namespace hlir
} // namespace cinn
#define CINN_REGISTER_FUSION_PASS(pass_name, pass_class) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_pass__##pass_name, \
"CINN_REGISTER_FUSION_PASS must be called in global namespace"); \
static ::cinn::hlir::pass::FusionPassRegistrar<pass_class> \
__pass_registrar_##pass_name##__(#pass_name); \
int TouchFusionPassRegistrar_##pass_name() { \
__pass_registrar_##pass_name##__.Touch(); \
return 0; \
}
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/cinn/common/is_reachable_predicator.h"
#include "paddle/cinn/hlir/pass/general_fusion_merge_pass/fuse_helper.h"
namespace cinn {
namespace hlir {
namespace pass {
template <typename FusePassCtxT>
class GraphGroupFuseHelper final : public FuseHelper {
public:
explicit GraphGroupFuseHelper(const FusePassCtxT* ctx) : ctx_(ctx) {}
bool AllOutputsSameSize(const OpGroupPtr& first,
const OpGroupPtr& second) const override;
bool HorizontalElementwiseFuseReduce(const OpGroupPtr& src,
const OpGroupPtr& dst) const override;
bool ElementwiseFuseBroadcast(const OpGroupPtr& src,
const OpGroupPtr& dst) const override;
bool HorizontalWithInjective(const OpGroupPtr& src,
const OpGroupPtr& dst) const override;
bool ElementwiseFuseReduce(const OpGroupPtr& src,
const OpGroupPtr& dst) const override;
bool BroadcastFuseReduce(const OpGroupPtr& src,
const OpGroupPtr& dst) const override;
bool InjectiveHorizontalWithReduce(const OpGroupPtr& src,
const OpGroupPtr& dst) const override;
bool ReduceFuseElementwise(const OpGroupPtr& src,
const OpGroupPtr& dst) const override;
bool ReduceFuseBroadcast(const OpGroupPtr& src,
const OpGroupPtr& dst) const override;
bool ReduceFuseReduce(const OpGroupPtr& src,
const OpGroupPtr& dst) const override;
bool IsReachable(const OpGroupPtr& lhs,
const OpGroupPtr& rhs) const override {
return IsReachableInDag(lhs, rhs) || IsReachableInDag(rhs, lhs);
}
bool DetectCycleIfFuse(const OpGroupPtr& lhs,
const OpGroupPtr& rhs) const override {
return ReachableIfDirectEdgeIgnored(lhs, rhs) ||
ReachableIfDirectEdgeIgnored(rhs, lhs);
}
bool IsConsumerSetsReachable(
const OpGroupPtr& group,
const std::unordered_set<OpGroupPtr>& consumers) const override {
for (const auto& consumer : consumers) {
if (group == consumer) {
continue;
}
if (IsReachableInDag(consumer, group)) {
return true;
}
}
return false;
}
private:
bool IsReachableInDag(const OpGroupPtr& producer,
const OpGroupPtr& consumer) const {
const auto& MinDepth4Node = [&](const OpGroupPtr& node) {
return node.GetGroup()->min_depth;
};
const auto& MaxDepth4Node = [&](const OpGroupPtr& node) {
return node.GetGroup()->max_depth;
};
const auto& VisitNextNodes =
[&](const OpGroupPtr& node,
const std::function<void(OpGroupPtr)>& Visit) {
for (const auto& node_producer : node.producers()) {
Visit(node_producer);
}
};
common::IsReachablePredicator<OpGroupPtr> is_reachable(
MinDepth4Node, MaxDepth4Node, VisitNextNodes);
return is_reachable(consumer, producer, [](OpGroupPtr) {});
}
bool ReachableIfDirectEdgeIgnored(const OpGroupPtr& producer,
const OpGroupPtr& consumer) const {
const auto& MinDepth4Node = [&](const OpGroupPtr& node) {
return node.GetGroup()->min_depth;
};
const auto& MaxDepth4Node = [&](const OpGroupPtr& node) {
return node.GetGroup()->max_depth;
};
const auto& VisitNextNodes =
[&](const OpGroupPtr& node,
const std::function<void(OpGroupPtr)>& Visit) {
for (const auto& node_producer : node.producers()) {
if (node == consumer && node_producer == producer) {
continue;
}
Visit(node_producer);
}
};
common::IsReachablePredicator<OpGroupPtr> is_reachable(
MinDepth4Node, MaxDepth4Node, VisitNextNodes);
return is_reachable(consumer, producer, [](OpGroupPtr) {});
}
const FusePassCtxT* ctx_;
};
template <typename FusePassCtxT>
bool GraphGroupFuseHelper<FusePassCtxT>::AllOutputsSameSize(
const OpGroupPtr& first, const OpGroupPtr& second) const {
return is_same_size(
&ctx_->graph_group_fusion_helper(), first.GetGroup(), second.GetGroup());
}
template <typename FusePassCtxT>
bool GraphGroupFuseHelper<FusePassCtxT>::HorizontalElementwiseFuseReduce(
const OpGroupPtr& src, const OpGroupPtr& dst) const {
return honrizontal_elementwise_fuse_reduce(
&ctx_->graph_group_fusion_helper(), src.GetGroup(), dst.GetGroup());
}
template <typename FusePassCtxT>
bool GraphGroupFuseHelper<FusePassCtxT>::ElementwiseFuseBroadcast(
const OpGroupPtr& src, const OpGroupPtr& dst) const {
return elementwise_fuse_broadcast(
&ctx_->graph_group_fusion_helper(), src.GetGroup(), dst.GetGroup());
}
template <typename FusePassCtxT>
bool GraphGroupFuseHelper<FusePassCtxT>::HorizontalWithInjective(
const OpGroupPtr& src, const OpGroupPtr& dst) const {
return horizontal_with_injective(
&ctx_->graph_group_fusion_helper(), src.GetGroup(), dst.GetGroup());
}
template <typename FusePassCtxT>
bool GraphGroupFuseHelper<FusePassCtxT>::ElementwiseFuseReduce(
const OpGroupPtr& src, const OpGroupPtr& dst) const {
return elementwise_fuse_reduce(
&ctx_->graph_group_fusion_helper(), src.GetGroup(), dst.GetGroup());
}
template <typename FusePassCtxT>
bool GraphGroupFuseHelper<FusePassCtxT>::BroadcastFuseReduce(
const OpGroupPtr& src, const OpGroupPtr& dst) const {
return broadcast_fuse_reduce(
&ctx_->graph_group_fusion_helper(), src.GetGroup(), dst.GetGroup());
}
template <typename FusePassCtxT>
bool GraphGroupFuseHelper<FusePassCtxT>::InjectiveHorizontalWithReduce(
const OpGroupPtr& src, const OpGroupPtr& dst) const {
return injective_horizontal_with_reduce(
&ctx_->graph_group_fusion_helper(), src.GetGroup(), dst.GetGroup());
}
template <typename FusePassCtxT>
bool GraphGroupFuseHelper<FusePassCtxT>::ReduceFuseElementwise(
const OpGroupPtr& src, const OpGroupPtr& dst) const {
return reduce_fuse_elementwise(
&ctx_->graph_group_fusion_helper(), src.GetGroup(), dst.GetGroup());
}
template <typename FusePassCtxT>
bool GraphGroupFuseHelper<FusePassCtxT>::ReduceFuseBroadcast(
const OpGroupPtr& src, const OpGroupPtr& dst) const {
return reduce_fuse_broadcast(
&ctx_->graph_group_fusion_helper(), src.GetGroup(), dst.GetGroup());
}
template <typename FusePassCtxT>
bool GraphGroupFuseHelper<FusePassCtxT>::ReduceFuseReduce(
const OpGroupPtr& src, const OpGroupPtr& dst) const {
return reduce_fuse_reduce(
&ctx_->graph_group_fusion_helper(), src.GetGroup(), dst.GetGroup());
}
} // namespace pass
} // namespace hlir
} // namespace cinn
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/cinn/hlir/pass/general_fusion_merge_pass/graph_group_fuse_helper.h"
#include "paddle/cinn/hlir/pass/general_fusion_merge_pass/input_fuse_pass_ctx.h"
namespace cinn {
namespace hlir {
namespace pass {
class GraphGroupInputFusePassCtx;
class GraphGroupInputFusePassCtx final : public InputFusePassCtx {
public:
GraphGroupInputFusePassCtx(
const FusionHelperBase* graph_group_fusion_helper,
const OpGroupList& groups,
const std::function<void(const OpGroupPtr& first,
const OpGroupPtr& second)>& MarkFusible)
: graph_group_fusion_helper_(graph_group_fusion_helper),
groups_(groups),
MarkFusible_(MarkFusible),
fuse_helper_(
new GraphGroupFuseHelper<GraphGroupInputFusePassCtx>(this)) {}
GraphGroupInputFusePassCtx(
const FusionHelperBase* graph_group_fusion_helper,
const OpGroupList& groups,
const std::function<void(const OpGroupList& candidates)>&
MarkGroupListFusible)
: graph_group_fusion_helper_(graph_group_fusion_helper),
groups_(groups),
MarkGroupListFusible_(MarkGroupListFusible),
fuse_helper_(
new GraphGroupFuseHelper<GraphGroupInputFusePassCtx>(this)) {}
const OpGroupList& PickConsumersWithSameInputs() const override {
return groups_;
}
const FuseHelper& fuse_helper() const override { return *fuse_helper_; }
void MarkFusible(const OpGroupPtr& first, const OpGroupPtr& second) override {
MarkFusible_(first, second);
}
void MarkFusible(const OpGroupList& candidates) override {
MarkGroupListFusible_(candidates);
}
const FusionHelperBase& graph_group_fusion_helper() const {
return *graph_group_fusion_helper_;
}
private:
const FusionHelperBase* graph_group_fusion_helper_;
const OpGroupList& groups_;
const std::function<void(const OpGroupPtr& first, const OpGroupPtr& second)>
MarkFusible_;
const std::function<void(const OpGroupList& candidates)>
MarkGroupListFusible_;
const std::unique_ptr<const FuseHelper> fuse_helper_;
};
} // namespace pass
} // namespace hlir
} // namespace cinn
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/cinn/hlir/pass/general_fusion_merge_pass/graph_group_fuse_helper.h"
#include "paddle/cinn/hlir/pass/general_fusion_merge_pass/lightware_fuse_pass_ctx.h"
namespace cinn {
namespace hlir {
namespace pass {
class GraphGroupLightwareFusePassCtx;
class GraphGroupLightwareFusePassCtx final : public LightwareFusePassCtx {
public:
GraphGroupLightwareFusePassCtx(
const FusionHelperBase* graph_group_fusion_helper,
const OpGroupPtr& group,
const std::function<void(const OpGroupPtr& first,
const OpGroupPtr& second)>& MarkFusible)
: graph_group_fusion_helper_(graph_group_fusion_helper),
group_(group),
MarkFusible_(MarkFusible),
fuse_helper_(
new GraphGroupFuseHelper<GraphGroupLightwareFusePassCtx>(this)) {}
GraphGroupLightwareFusePassCtx(
const FusionHelperBase* graph_group_fusion_helper,
const OpGroupPtr& group,
const std::function<void(const OpGroupList& candidates)>&
MarkGroupListFusible)
: graph_group_fusion_helper_(graph_group_fusion_helper),
group_(group),
MarkGroupListFusible_(MarkGroupListFusible),
fuse_helper_(
new GraphGroupFuseHelper<GraphGroupLightwareFusePassCtx>(this)) {}
const OpGroupPtr& PickOpGroup() const override { return group_; }
const FuseHelper& fuse_helper() const override { return *fuse_helper_; }
void MarkFusible(const OpGroupPtr& first, const OpGroupPtr& second) override {
MarkFusible_(first, second);
}
void MarkFusible(const OpGroupList& candidates) override {
MarkGroupListFusible_(candidates);
}
const FusionHelperBase& graph_group_fusion_helper() const {
return *graph_group_fusion_helper_;
}
private:
const FusionHelperBase* graph_group_fusion_helper_;
const OpGroupPtr& group_;
const std::function<void(const OpGroupPtr& first, const OpGroupPtr& second)>
MarkFusible_;
const std::function<void(const OpGroupList& candidates)>
MarkGroupListFusible_;
const std::unique_ptr<const FuseHelper> fuse_helper_;
};
} // namespace pass
} // namespace hlir
} // namespace cinn
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/cinn/hlir/pass/general_fusion_merge_pass/lightware_fuse_pass.h"
namespace cinn {
namespace hlir {
namespace pass {
class HorizontalFusePass : public LightwareFusePass {
public:
virtual ~HorizontalFusePass() = default;
virtual void operator()(LightwareFusePassCtx* ctx) const = 0;
const std::string FuseMode() const final { return "HorizontalFuse"; }
virtual int Benefit() const = 0;
protected:
HorizontalFusePass() = default;
};
} // namespace pass
} // namespace hlir
} // namespace cinn
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/cinn/api/op_group.h"
#include "paddle/cinn/hlir/framework/op.h"
#include "paddle/cinn/hlir/pass/general_fusion_merge_pass_utils.h"
namespace cinn {
namespace hlir {
namespace pass {
using OpGroupPtr = api::OpGroup;
using framework::OpPatternKind;
template <typename FusePassCtxT>
struct HorizontalFuseUtil {
using KindKeyT = std::pair<OpPatternKind, OpPatternKind>;
static bool DetectFusabilityByKind(FusePassCtxT* ctx,
const OpGroupPtr& src,
const OpGroupPtr& dst) {
const KindKeyT kind_pair(src.kind(), dst.kind());
const auto& map = GetConditionMap();
const auto& iter = map.find(kind_pair);
if (iter == map.end()) {
return false;
}
return iter->second(ctx, src, dst);
}
typedef bool (*ConditionT)(FusePassCtxT* ctx,
const OpGroupPtr& src,
const OpGroupPtr& dst);
static const std::map<KindKeyT, ConditionT>& GetConditionMap() {
thread_local static std::map<KindKeyT, ConditionT> map(RawConditionMap());
return map;
}
static std::map<KindKeyT, ConditionT> RawConditionMap() {
return std::map<KindKeyT, ConditionT>{
{{OpPatternKind::kElementWise, framework::kElementWise}, &IsSameSize},
{{OpPatternKind::kElementWise, framework::kBroadcast}, &IsSameSize},
{{OpPatternKind::kElementWise, framework::kInjective}, &IsSameSize},
{{OpPatternKind::kElementWise, framework::kReduction},
&HorizontalElementwiseFuseReduce},
{{OpPatternKind::kBroadcast, framework::kElementWise}, &IsSameSize},
{{OpPatternKind::kBroadcast, framework::kBroadcast}, &IsSameSize},
{{OpPatternKind::kBroadcast, framework::kInjective}, &IsSameSize},
{{OpPatternKind::kBroadcast, framework::kReduction}, &IsSameSize},
{{OpPatternKind::kInjective, framework::kElementWise}, &IsSameSize},
{{OpPatternKind::kInjective, framework::kBroadcast}, &IsSameSize},
{{OpPatternKind::kInjective, framework::kInjective}, &IsSameSize},
{{OpPatternKind::kInjective, framework::kReduction}, &IsSameSize},
{{OpPatternKind::kReduction, framework::kElementWise},
&HorizontalElementwiseFuseReduce},
{{OpPatternKind::kReduction, framework::kBroadcast}, &IsSameSize},
{{OpPatternKind::kReduction, framework::kInjective}, &IsSameSize},
{{OpPatternKind::kReduction, framework::kReduction}, &ReduceFuseReduce},
};
}
static bool IsSameSize(FusePassCtxT* ctx,
const OpGroupPtr& src,
const OpGroupPtr& dst) {
return utils::IsSameSize(src, dst);
}
static bool HorizontalElementwiseFuseReduce(FusePassCtxT* ctx,
const OpGroupPtr& src,
const OpGroupPtr& dst) {
// if same shape with horizontal relation
if (IsSameSize(ctx, src, dst)) {
return true;
}
const OpGroupPtr* ele_group = nullptr;
const OpGroupPtr* reduce_group = nullptr;
if (src.kind() == framework::kReduction) {
ele_group = &dst;
reduce_group = &src;
} else {
ele_group = &src;
reduce_group = &dst;
}
size_t size_ele =
utils::GetMasterNode(*ele_group).outputs()[0].shape().numel();
bool can_fuse = false;
reduce_group->WalkOpNodes([&](const api::OpNode& op) {
if (op.kind() == OpPatternKind::kReduction) {
size_t size_master = op.outputs()[0].shape().numel();
if (size_ele == size_master) {
can_fuse = true;
}
}
});
return can_fuse;
}
static bool ReduceFuseReduce(FusePassCtxT* ctx,
const OpGroupPtr& src,
const OpGroupPtr& dst) {
return ctx->fuse_helper().ReduceFuseReduce(src, dst);
}
};
} // namespace pass
} // namespace hlir
} // namespace cinn
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/cinn/hlir/pass/general_fusion_merge_pass/fuse_pass.h"
namespace cinn {
namespace hlir {
namespace pass {
class InputFusePassCtx;
class InputFusePass : public FusePass {
public:
virtual ~InputFusePass() = default;
virtual void operator()(InputFusePassCtx* ctx) const = 0;
const std::string FuseMode() const final { return "InputFuse"; }
virtual int Benefit() const = 0;
protected:
InputFusePass() = default;
};
} // namespace pass
} // namespace hlir
} // namespace cinn
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/cinn/hlir/pass/general_fusion_merge_pass/fuse_pass_ctx.h"
namespace cinn {
namespace hlir {
namespace pass {
using OpGroupList = std::vector<OpGroupPtr>;
class InputFusePassCtx : public FusePassCtx {
public:
virtual ~InputFusePassCtx() {}
virtual const OpGroupList& PickConsumersWithSameInputs() const = 0;
virtual const FuseHelper& fuse_helper() const = 0;
virtual void MarkFusible(const OpGroupPtr& first,
const OpGroupPtr& second) = 0;
virtual void MarkFusible(const OpGroupList& candidates) = 0;
protected:
InputFusePassCtx() = default;
};
} // namespace pass
} // namespace hlir
} // namespace cinn
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/cinn/hlir/pass/general_fusion_merge_pass/fuse_pass.h"
namespace cinn {
namespace hlir {
namespace pass {
class LightwareFusePassCtx;
class LightwareFusePass : public FusePass {
public:
virtual ~LightwareFusePass() = default;
virtual void operator()(LightwareFusePassCtx* ctx) const = 0;
virtual const std::string FuseMode() const = 0;
virtual int Benefit() const = 0;
protected:
LightwareFusePass() = default;
};
} // namespace pass
} // namespace hlir
} // namespace cinn
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/cinn/hlir/pass/general_fusion_merge_pass/fuse_pass_ctx.h"
namespace cinn {
namespace hlir {
namespace pass {
using OpGroupList = std::vector<OpGroupPtr>;
class LightwareFusePassCtx : public FusePassCtx {
public:
virtual ~LightwareFusePassCtx() {}
virtual const OpGroupPtr& PickOpGroup() const = 0;
virtual const FuseHelper& fuse_helper() const = 0;
virtual void MarkFusible(const OpGroupPtr& first,
const OpGroupPtr& second) = 0;
virtual void MarkFusible(const OpGroupList& candidates) = 0;
protected:
LightwareFusePassCtx() = default;
};
} // namespace pass
} // namespace hlir
} // namespace cinn
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/cinn/hlir/pass/general_fusion_merge_pass/lightware_fuse_pass.h"
namespace cinn {
namespace hlir {
namespace pass {
class RecomputeFusePass : public LightwareFusePass {
public:
virtual ~RecomputeFusePass() = default;
virtual void operator()(LightwareFusePassCtx* ctx) const = 0;
const std::string FuseMode() const final { return "RecomputeFuse"; }
virtual int Benefit() const = 0;
protected:
RecomputeFusePass() = default;
};
} // namespace pass
} // namespace hlir
} // namespace cinn
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/cinn/hlir/pass/general_fusion_merge_pass/lightware_fuse_pass.h"
namespace cinn {
namespace hlir {
namespace pass {
class VerticalFusePass : public LightwareFusePass {
public:
virtual ~VerticalFusePass() = default;
virtual void operator()(LightwareFusePassCtx* ctx) const = 0;
const std::string FuseMode() const final { return "VerticalFuse"; }
virtual int Benefit() const = 0;
protected:
VerticalFusePass() = default;
};
} // namespace pass
} // namespace hlir
} // namespace cinn
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/cinn/api/op_group.h"
#include "paddle/cinn/hlir/framework/op.h"
#include "paddle/cinn/hlir/pass/general_fusion_merge_pass/lightware_fuse_pass_ctx.h"
#include "paddle/cinn/hlir/pass/general_fusion_merge_pass_utils.h"
namespace cinn {
namespace hlir {
namespace pass {
using OpGroupPtr = api::OpGroup;
using framework::OpPatternKind;
struct VerticalFuseUtil {
using KindKeyT = std::pair<OpPatternKind, OpPatternKind>;
static bool DetectFusabilityByKind(LightwareFusePassCtx* ctx,
const OpGroupPtr& src,
const OpGroupPtr& dst) {
const KindKeyT kind_pair(src.kind(), dst.kind());
const auto& map = GetConditionMap();
const auto& iter = map.find(kind_pair);
if (iter == map.end()) {
return false;
}
return iter->second(ctx, src, dst);
}
typedef bool (*ConditionT)(LightwareFusePassCtx* ctx,
const OpGroupPtr& src,
const OpGroupPtr& dst);
static const std::map<KindKeyT, ConditionT>& GetConditionMap() {
thread_local static std::map<KindKeyT, ConditionT> map(RawConditionMap());
return map;
}
static std::map<KindKeyT, ConditionT> RawConditionMap() {
return std::map<KindKeyT, ConditionT>{
{{OpPatternKind::kElementWise, framework::kElementWise}, &IsSameSize},
{{OpPatternKind::kElementWise, framework::kBroadcast},
&ElementwiseFuseBroadcast},
{{OpPatternKind::kElementWise, framework::kInjective},
&HorizontalWithInjective},
{{OpPatternKind::kElementWise, framework::kReduction},
&ElementwiseFuseReduce},
{{OpPatternKind::kBroadcast, framework::kElementWise}, &IsSameSize},
{{OpPatternKind::kBroadcast, framework::kBroadcast}, &IsSameSize},
{{OpPatternKind::kBroadcast, framework::kInjective},
&HorizontalWithInjective},
{{OpPatternKind::kBroadcast, framework::kReduction},
&BroadcastFuseReduce},
{{OpPatternKind::kInjective, framework::kElementWise}, &IsSameSize},
{{OpPatternKind::kInjective, framework::kBroadcast}, &IsSameSize},
{{OpPatternKind::kInjective, framework::kInjective},
&HorizontalWithInjective},
{{OpPatternKind::kInjective, framework::kReduction},
&InjectiveHorizontalWithReduce},
{{OpPatternKind::kReduction, framework::kElementWise},
&ReduceFuseElementwise},
{{OpPatternKind::kReduction, framework::kBroadcast},
&ReduceFuseBroadcast},
{{OpPatternKind::kReduction, framework::kInjective},
&HorizontalWithInjective},
{{OpPatternKind::kReduction, framework::kReduction}, &ReduceFuseReduce},
};
}
static bool IsSameSize(LightwareFusePassCtx* ctx,
const OpGroupPtr& src,
const OpGroupPtr& dst) {
return utils::IsSameSize(src, dst);
}
static bool ElementwiseFuseBroadcast(LightwareFusePassCtx* ctx,
const OpGroupPtr& src,
const OpGroupPtr& dst) {
return ctx->fuse_helper().ElementwiseFuseBroadcast(src, dst);
}
static bool HorizontalWithInjective(LightwareFusePassCtx* ctx,
const OpGroupPtr& src,
const OpGroupPtr& dst) {
return ctx->fuse_helper().HorizontalWithInjective(src, dst);
}
static bool ElementwiseFuseReduce(LightwareFusePassCtx* ctx,
const OpGroupPtr& src,
const OpGroupPtr& dst) {
return ctx->fuse_helper().ElementwiseFuseReduce(src, dst);
}
static bool BroadcastFuseReduce(LightwareFusePassCtx* ctx,
const OpGroupPtr& src,
const OpGroupPtr& dst) {
return ctx->fuse_helper().BroadcastFuseReduce(src, dst);
}
static bool InjectiveHorizontalWithReduce(LightwareFusePassCtx* ctx,
const OpGroupPtr& src,
const OpGroupPtr& dst) {
return ctx->fuse_helper().InjectiveHorizontalWithReduce(src, dst);
}
static bool ReduceFuseElementwise(LightwareFusePassCtx* ctx,
const OpGroupPtr& src,
const OpGroupPtr& dst) {
return ctx->fuse_helper().ReduceFuseElementwise(src, dst);
}
static bool ReduceFuseBroadcast(LightwareFusePassCtx* ctx,
const OpGroupPtr& src,
const OpGroupPtr& dst) {
return ctx->fuse_helper().ReduceFuseBroadcast(src, dst);
}
static bool ReduceFuseReduce(LightwareFusePassCtx* ctx,
const OpGroupPtr& src,
const OpGroupPtr& dst) {
return ctx->fuse_helper().ReduceFuseReduce(src, dst);
}
};
} // namespace pass
} // namespace hlir
} // namespace cinn
......@@ -113,8 +113,8 @@ static bool limit_args(const OpGroupPtr& first, const OpGroupPtr& second) {
}
}
bool WithoutLastDimInReduce(const api::Shape& inshape,
const std::vector<int>& axes) {
inline bool WithoutLastDimInReduce(const api::Shape& inshape,
const std::vector<int>& axes) {
// if last axis is in reduce.
if (std::find(axes.begin(), axes.end(), inshape.size() - 1) != axes.end() ||
std::find(axes.begin(), axes.end(), -1) != axes.end()) {
......
......@@ -15,6 +15,7 @@
#pragma once
#include "paddle/cinn/common/macros.h"
#include "paddle/cinn/hlir/pass/use_general_pass.h"
CINN_USE_REGISTER(InferShape)
CINN_USE_REGISTER(OpFusion)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册