diff --git a/cmake/cinn/core.cmake b/cmake/cinn/core.cmake index c54aea3cc29dc8dc7453cb64b38c73b71a443e84..b6ad32bb746f4b6fc812c6c33288bda7be4d3491 100644 --- a/cmake/cinn/core.cmake +++ b/cmake/cinn/core.cmake @@ -433,6 +433,28 @@ function(download_and_uncompress INSTALL_DIR URL FILENAME) INSTALL_COMMAND "") endfunction() +set(fusion_pass_file + ${CMAKE_CURRENT_BINARY_DIR}/paddle/cinn/hlir/pass/use_general_pass.h + CACHE INTERNAL "use_general_pass.h file") +file( + WRITE ${fusion_pass_file} + "#include \"paddle/cinn/common/macros.h\" // Generated by the paddle/cinn/hlir/pass/CMakeLists.txt. DO NOT EDIT!\n\n" +) + +function(find_fusion_pass_register FILENAME ADD_PATH PATTERN) + # Scan FILENAME for PATTERN(pass_name, ...) registrations and append USE_FUSION_PASS(pass_name); lines to ADD_PATH. + file(READ ${FILENAME} CONTENT) + string(REGEX MATCHALL "${PATTERN}\\([a-zA-Z0-9_]*," fusion_pass_patterns + "${CONTENT}") + if(NOT fusion_pass_patterns STREQUAL "") + foreach(pass_pattern ${fusion_pass_patterns}) + string(REPLACE "${PATTERN}(" "" pass_pattern "${pass_pattern}") + string(REPLACE "," "" pass_pattern "${pass_pattern}") + file(APPEND ${ADD_PATH} "USE_FUSION_PASS(${pass_pattern});\n") + endforeach() + endif() +endfunction() + function(gather_srcs SRC_GROUP) set(options) set(oneValueArgs) @@ -442,6 +464,8 @@ function(gather_srcs SRC_GROUP) set(${SRC_GROUP} "${${SRC_GROUP}};${CMAKE_CURRENT_SOURCE_DIR}/${cpp}" CACHE INTERNAL "") + find_fusion_pass_register("${CMAKE_CURRENT_SOURCE_DIR}/${cpp}" + ${fusion_pass_file} "CINN_REGISTER_FUSION_PASS") endforeach() endfunction() diff --git a/paddle/cinn/CMakeLists.txt b/paddle/cinn/CMakeLists.txt index 84912f18cbc509a4537a0d6b40c9ca498a291c96..4166e6453b4202ac951c3c706cc8e122f67e1c85 100644 --- a/paddle/cinn/CMakeLists.txt +++ b/paddle/cinn/CMakeLists.txt @@ -2,6 +2,7 @@ if(WITH_TESTING) cinn_cc_library(cinn_gtest_main SRCS gtest_main.cc DEPS gtest gflags) endif() +add_subdirectory(api) add_subdirectory(auto_schedule) add_subdirectory(common) add_subdirectory(utils) diff --git a/paddle/cinn/api/CMakeLists.txt b/paddle/cinn/api/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..2806b33154c8b45081366b503478a3529d940808 --- /dev/null +++ b/paddle/cinn/api/CMakeLists.txt @@ -0,0 +1,5 @@ +core_gather_headers() + +gather_srcs(cinnapi_src SRCS op_node.cc tensor_node.cc) + +message(STATUS "srcs: ${cinnapi_src}") diff --git a/paddle/cinn/api/README.md b/paddle/cinn/api/README.md new file mode 100644 index 0000000000000000000000000000000000000000..999f27c41cdbe71504155109b5411822a98aae2e --- /dev/null +++ b/paddle/cinn/api/README.md @@ -0,0 +1,45 @@ +The classes in this directory are the interface of the group fusion pass; you can use these APIs to build the strategy for group fusion. + + +The classes and APIs are as follows: + +`OpGroup` : A set of op nodes, which will be passed to the CINN backend to generate kernel code. Two groups can be fused together according to the merging rules written in the passes. + +`OpNode` : Represents an op in the program. + +`TensorNode` : Represents a tensor in the program. + +`Shape` : The shape information of a tensor. + +`FusePassCtx` : The context passed to the pass as a parameter; it holds all the data you need in the pass. + +`FuseHelper` : Provides utility methods, such as `DetectCycleIfFuse`, in `fuse_helper` to simplify the development of passes.
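+
+For example, a fuse pass might combine these classes roughly as shown below. This is only a minimal sketch, not code from an actual pass: the function name `ExampleElementwiseFuse` is made up, and the exact pass entry point, signature, and registration boilerplate are illustrative.
+
+```cpp
+// Sketch: fuse an elementwise producer group into the current group when
+// doing so cannot introduce a cycle into the graph.
+void ExampleElementwiseFuse(cinn::api::LightwareFusePassCtx* ctx) {
+  const cinn::api::OpGroup group = ctx->PickOpGroup();
+  for (const cinn::api::OpGroup& producer : group.producers()) {
+    if (producer.kind() != cinn::hlir::framework::kElementWise) continue;
+    // fuse_helper() reports whether fusing the two groups would form a cycle.
+    if (ctx->fuse_helper().DetectCycleIfFuse(producer, group)) continue;
+    ctx->EnableFuse(producer, group);
+  }
+}
+```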
+ +| Class | Method | Description | +| :--: | :--: | :--: | +| OpGroup | kind() | Get the kind of the group | +| | producers() | Get the producer groups of the current group | +| | consumers() | Get the consumer groups of the current group | +| | WalkOpNodes(const std::function<void(const OpNode&)>& VisitOpNode) | Visit each op_node in the group and call VisitOpNode on it | +| | | | +| OpNode | kind() | Get the kind of the op_node | +| | inputs() | Get the input tensors of the op_node | +| | outputs() | Get the output tensors of the op_node | +| | GetAttr(const std::string& attr_name) | Get an attribute of the op_node by name | +| | | | +| TensorNode | shape() | Get the shape of the tensor | +| | producer() | Get the producer op_node of the tensor | +| | consumers() | Get the consumer op_nodes of the tensor | +| | | | +| Shape | numel() | Get the total number of elements in the shape | +| | other methods are the same as std::vector | | +| | | | +| LightwareFusePassCtx | PickOpGroup() | Get the current group in the pass context | +| | void EnableFuse(const OpGroup& first, const OpGroup& second) | Mark the two groups as fusible together | +| | fuse_helper() | Get the fuse_helper provided by the pass context | +| | | | +| InputFusePassCtx | PickConsumersWithSameInputs() | Get all consumer groups of the graph's input tensors | +| | void EnableFuse(const OpGroup& first, const OpGroup& second) | Mark the two groups as fusible together | +| | fuse_helper() | Get the fuse_helper provided by the pass context | +| | | | +| FuseHelper | DetectCycleIfFuse(const OpGroup& first, const OpGroup& second) | Whether the graph would contain a cycle after fusing the two groups | diff --git a/paddle/cinn/api/op_group.h b/paddle/cinn/api/op_group.h new file mode 100644 index 0000000000000000000000000000000000000000..7e9a2581a9535b4d9d1c0daf6d9c1faabce053f7 --- /dev/null +++ b/paddle/cinn/api/op_group.h @@ -0,0 +1,202 @@ +// Copyright (c) 2023 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+ +#pragma once + +#include + +#include "paddle/cinn/api/op_node.h" + +#include "paddle/cinn/hlir/framework/graph.h" +#include "paddle/cinn/hlir/pass/fusion_helper_base.h" + +namespace cinn { +namespace api { + +class OpGroup { + public: + explicit OpGroup(const std::shared_ptr& group) + : group_(group) {} + + OpGroup(const OpGroup& other) = default; + + using Comparator = hlir::framework::Graph::Group::SharedGroupComparator; + using Hasher = hlir::framework::Graph::Group::SharedGroupHasher; + + class OpGroupListIterator { + public: + OpGroupListIterator( + std::unordered_set, + Hasher, + Comparator>::const_iterator it) + : iter_(it) {} + + OpGroupListIterator& operator++() { + ++iter_; + return *this; + } + + OpGroupListIterator operator++(int) { + OpGroupListIterator tmp = *this; + ++iter_; + return tmp; + } + + bool operator==(const OpGroupListIterator& other) const { + return iter_ == other.iter_; + } + + bool operator!=(const OpGroupListIterator& other) const { + return !(*this == other); + } + + OpGroup operator*() const { return OpGroup(*iter_); } + + private: + std::unordered_set, + Hasher, + Comparator>::const_iterator iter_; + }; + + class ProducerOpGroupListView { + public: + ProducerOpGroupListView( + const std::weak_ptr& group) + : group_(group) {} + + ProducerOpGroupListView(const ProducerOpGroupListView& other) = delete; + ProducerOpGroupListView(ProducerOpGroupListView&& other) = delete; + + ProducerOpGroupListView& operator=(const ProducerOpGroupListView& other) = + delete; + + using const_iterator = OpGroupListIterator; + + size_t size() const { + CHECK(group_.lock()); + return group_.lock()->producer_groups().size(); + } + + const_iterator begin() const { + CHECK(group_.lock()); + return const_iterator(group_.lock()->producer_groups().begin()); + } + + const_iterator end() const { + CHECK(group_.lock()); + return const_iterator(group_.lock()->producer_groups().end()); + } + + private: + const std::weak_ptr group_; + }; + + class ConsumerOpGroupListView { + public: + ConsumerOpGroupListView( + const std::weak_ptr& group) + : group_(group) {} + + ConsumerOpGroupListView(const ConsumerOpGroupListView& other) = delete; + ConsumerOpGroupListView(ConsumerOpGroupListView&& other) = delete; + + ConsumerOpGroupListView& operator=(const ConsumerOpGroupListView& other) = + delete; + + using const_iterator = OpGroupListIterator; + + size_t size() const { + CHECK(group_.lock()); + return group_.lock()->consumer_groups().size(); + } + + const_iterator begin() const { + CHECK(group_.lock()); + return const_iterator(group_.lock()->consumer_groups().begin()); + } + + const_iterator end() const { + CHECK(group_.lock()); + return const_iterator(group_.lock()->consumer_groups().end()); + } + + private: + const std::weak_ptr group_; + }; + + const std::string& group_id() const { return group_.lock()->group_id; } + + hlir::framework::OpPatternKind kind() const { return group_.lock()->kind(); } + + // The WalkOpNodes function is used to traverse the op_nodes in the group and + // execute the VisitOpNode function for each OpNode. This function is + // equivalent to for loop for op_nodes in graph. + // + // In order to avoid unnecessary memory copies, we use WalkOpNodes function + // instead of providing a function to get all op_nodes directly. + // + // Example: Get the all Reduction op_nodes in the group. + // OpGroup group = ...; + // std::set reduce_ op_set; + // // The lambda funtion of VisitOpNode to get reduction op_nodes. 
+ // auto get_reduce_op = [&reduce_op_set](const api::OpNode& op){ + // if (op.kind() == OpPatternKind::kReduction) { + // reduce_op_set.insert(op); + // } + // }; + // group.WalkOpNodes(get_reduce_op); + void WalkOpNodes( + const std::function& VisitOpNode) const { + group_.lock()->WalkNodes([&](const hlir::framework::Node* node) { + VisitOpNode(OpNode(node, group_.lock()->graph_)); + }); + } + + ProducerOpGroupListView producers() const { + return ProducerOpGroupListView(group_); + } + + ConsumerOpGroupListView consumers() const { + return ConsumerOpGroupListView(group_); + } + + std::shared_ptr GetGroup() const { + return group_.lock(); + } + + bool operator==(const OpGroup& other) const { + return group_.lock().get() == other.group_.lock().get(); + } + + bool operator<(const OpGroup& other) const { + return group_.lock().get() < other.group_.lock().get(); + } + + private: + const std::weak_ptr group_; +}; + +} // namespace api +} // namespace cinn + +namespace std { + +template <> +struct hash { + size_t operator()(const cinn::api::OpGroup& obj) const { + return std::hash()(reinterpret_cast(obj.GetGroup().get())); + } +}; + +} // namespace std diff --git a/paddle/cinn/api/op_node.cc b/paddle/cinn/api/op_node.cc new file mode 100644 index 0000000000000000000000000000000000000000..7b0562052e0fc1baa6fd65c850ae5efc41f62994 --- /dev/null +++ b/paddle/cinn/api/op_node.cc @@ -0,0 +1,35 @@ +// Copyright (c) 2023 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/cinn/api/op_node.h" + +namespace cinn { +namespace api { + +TensorNode OpNode::TensorListIterator::operator*() const { + return TensorNode(get_tensor_from_edge_(*iter_), graph_); +} + +TensorNode OpNode::InputTensorListView::operator[](size_t index) const { + return TensorNode( + edges_[index]->source()->safe_as(), graph_); +} + +TensorNode OpNode::OutputTensorListView::operator[](size_t index) const { + return TensorNode(edges_[index]->sink()->safe_as(), + graph_); +} + +} // namespace api +} // namespace cinn diff --git a/paddle/cinn/api/op_node.h b/paddle/cinn/api/op_node.h new file mode 100644 index 0000000000000000000000000000000000000000..5bb9a79f1be88ec19e0a0e3c78be4c43c31a7bd4 --- /dev/null +++ b/paddle/cinn/api/op_node.h @@ -0,0 +1,208 @@ +// Copyright (c) 2023 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
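+
+// Usage sketch (illustrative only, not part of the API): given an
+// api::OpGroup `group`, visit its reduce ops and inspect their inputs.
+// The attribute name "dim" below is just an example.
+//
+//   group.WalkOpNodes([&](const cinn::api::OpNode& op) {
+//     if (op.kind() == hlir::framework::kReduction) {
+//       const auto& axes = op.GetAttr<std::vector<int>>("dim");
+//       VLOG(3) << op.inputs()[0].shape().numel() << " elements reduced over "
+//               << axes.size() << " axes";
+//     }
+//   });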
+ +#pragma once + +#include +#include "paddle/cinn/api/tensor_node.h" +#include "paddle/cinn/hlir/framework/graph.h" +#include "paddle/cinn/hlir/framework/op.h" +#include "paddle/cinn/hlir/pass/fusion_helper_base.h" + +namespace cinn { +namespace api { + +class OpNode { + public: + OpNode(const hlir::framework::Node* node, const hlir::framework::Graph* graph) + : node_(node), + graph_(graph), + input_tensors_(node->inlinks_in_order(), graph_), + output_tensors_(node->outlinks_in_order(), graph_) {} + + OpNode(const OpNode& other) + : node_(other.node_), + graph_(other.graph_), + input_tensors_(node_->inlinks_in_order(), graph_), + output_tensors_(node_->outlinks_in_order(), graph_) {} + + using OpPatternKind = cinn::hlir::framework::OpPatternKind; + + OpPatternKind kind() const { + static const hlir::framework::OpValueType& op_pattern_dict = + hlir::framework::Operator::GetAttrs("OpPattern"); + auto kind = op_pattern_dict[node_->op()]; + + if (kind == hlir::framework::kBroadcast) { + // As binary op was defined as broadcast, actually it should be + // element-wise. + if (node_->op()->name != "broadcast_to") { + return hlir::framework::kElementWise; + } + } + return kind; + } + + class TensorListIterator { + public: + TensorListIterator( + std::vector>::const_iterator it, + const hlir::framework::Graph* graph, + std::function)> get_tensor_from_edge) + : iter_(it), + graph_(graph), + get_tensor_from_edge_(get_tensor_from_edge) {} + + TensorListIterator& operator++() { + ++iter_; + return *this; + } + + TensorListIterator operator++(int) { + TensorListIterator tmp = *this; + ++iter_; + return tmp; + } + + bool operator==(const TensorListIterator& other) const { + return iter_ == other.iter_; + } + + bool operator!=(const TensorListIterator& other) const { + return !(*this == other); + } + + TensorNode operator*() const; + + private: + std::vector>::const_iterator iter_; + const hlir::framework::Graph* graph_; + std::function)> + get_tensor_from_edge_; + }; + + using const_iterator = TensorListIterator; + + class InputTensorListView { + public: + InputTensorListView( + const std::vector>& edges, + const hlir::framework::Graph* graph) + : edges_(edges), graph_(graph) {} + + InputTensorListView(const InputTensorListView& other) = delete; + InputTensorListView(InputTensorListView&& other) = delete; + + InputTensorListView& operator=(const InputTensorListView& other) = delete; + + size_t size() const { return edges_.size(); } + + TensorNode operator[](size_t index) const; + + const_iterator begin() const { + return const_iterator( + edges_.begin(), graph_, [](common::Shared edge) { + return edge->source()->safe_as(); + }); + } + + const_iterator end() const { + return const_iterator( + edges_.end(), graph_, [](common::Shared edge) { + return edge->source()->safe_as(); + }); + } + + private: + std::vector> edges_; + const hlir::framework::Graph* graph_; + }; + + class OutputTensorListView { + public: + OutputTensorListView( + const std::vector>& edges, + const hlir::framework::Graph* graph) + : edges_(edges), graph_(graph) {} + + OutputTensorListView(const OutputTensorListView& other) = delete; + OutputTensorListView(OutputTensorListView&& other) = delete; + + OutputTensorListView& operator=(const OutputTensorListView& other) = delete; + + size_t size() const { return edges_.size(); } + + TensorNode operator[](size_t index) const; + + const_iterator begin() const { + return const_iterator( + edges_.begin(), graph_, [](common::Shared edge) { + return edge->sink()->safe_as(); + }); + } + + 
const_iterator end() const { + return const_iterator( + edges_.end(), graph_, [](common::Shared edge) { + return edge->sink()->safe_as(); + }); + } + + private: + std::vector> edges_; + const hlir::framework::Graph* graph_; + }; + + bool operator==(const OpNode& other) const { return node_ == other.node_; } + + bool operator<(const OpNode& other) const { return node_ < other.node_; } + + const InputTensorListView& inputs() const { return input_tensors_; } + + const OutputTensorListView& outputs() const { return output_tensors_; } + + template + const T& GetAttr(const std::string& attr_name) const { + return absl::get(GetAttr(attr_name)); + } + + private: + using Attribute = cinn::utils::Attribute; + const Attribute& GetAttr(const std::string& attr_name) const { + return node_->attrs.attr_store.at(attr_name); + } + + friend struct std::hash; + + const hlir::framework::Node* node_; + const hlir::framework::Graph* graph_; + + const InputTensorListView input_tensors_; + const OutputTensorListView output_tensors_; +}; + +} // namespace api +} // namespace cinn + +namespace std { + +template <> +struct hash { + size_t operator()(const cinn::api::OpNode& obj) const { + return std::hash()(reinterpret_cast(obj.node_)); + } +}; + +} // namespace std diff --git a/paddle/cinn/api/shape.h b/paddle/cinn/api/shape.h new file mode 100644 index 0000000000000000000000000000000000000000..f5ddf8b8daa91e27e21a18989c39d81dfed460d8 --- /dev/null +++ b/paddle/cinn/api/shape.h @@ -0,0 +1,56 @@ +// Copyright (c) 2023 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "paddle/cinn/hlir/framework/graph.h" +#include "paddle/cinn/hlir/pass/fusion_helper_base.h" +#include "paddle/cinn/utils/small_vector.h" +#include "paddle/cinn/utils/type_defs.h" + +namespace cinn { +namespace api { + +class Shape final { + public: + explicit Shape(const utils::ShapeType& shape) + : shape_(shape.begin(), shape.end()) {} + + Shape(const Shape& other) = delete; + Shape(Shape&& other) = delete; + + Shape& operator=(const Shape& other) = delete; + + bool operator==(const Shape& other) const { return shape_ == other.shape_; } + + size_t operator[](size_t index) const { return shape_[index]; } + + size_t at(size_t index) const { return shape_[index]; } + + size_t size() const { return shape_.size(); } + + // Returns the total number of elements in the shape. + size_t numel() const { + return std::accumulate( + shape_.begin(), shape_.end(), 1, std::multiplies()); + } + + private: + cinn::utils::SmallVector shape_; +}; + +} // namespace api +} // namespace cinn diff --git a/paddle/cinn/api/tensor_node.cc b/paddle/cinn/api/tensor_node.cc new file mode 100644 index 0000000000000000000000000000000000000000..ff744316bfcc7299a4ea1f1674c03624f2eb5f80 --- /dev/null +++ b/paddle/cinn/api/tensor_node.cc @@ -0,0 +1,31 @@ +// Copyright (c) 2023 CINN Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/cinn/api/tensor_node.h" + +#include "paddle/cinn/api/op_node.h" + +namespace cinn { +namespace api { + +OpNode TensorNode::producer() const { + return OpNode(node_data_->source_node.get(), graph_); +} + +OpNode TensorNode::ConsumerOpListView::Iterator::operator*() const { + return OpNode((*iter_)->sink()->safe_as(), graph_); +} + +} // namespace api +} // namespace cinn diff --git a/paddle/cinn/api/tensor_node.h b/paddle/cinn/api/tensor_node.h new file mode 100644 index 0000000000000000000000000000000000000000..fca0a844108bc8380afa4f32c93835c956b1ce6e --- /dev/null +++ b/paddle/cinn/api/tensor_node.h @@ -0,0 +1,119 @@ +// Copyright (c) 2023 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "paddle/cinn/api/shape.h" +#include "paddle/cinn/hlir/framework/graph.h" +#include "paddle/cinn/hlir/pass/fusion_helper_base.h" +#include "paddle/cinn/utils/small_vector.h" +#include "paddle/cinn/utils/type_defs.h" + +namespace cinn { +namespace api { + +class OpNode; + +class TensorNode final { + public: + TensorNode(const hlir::framework::NodeData* node_data, + const hlir::framework::Graph* graph) + : node_data_(node_data), + graph_(graph), + consumers_(node_data_->outlinks(), graph_) { + const auto& shape_dict = + graph_->GetAttrs>( + "infershape"); + CHECK(shape_dict.count(node_data_->id())) + << "Can't find " << node_data_->id() << " 's shape!"; + shape_ = std::make_shared(shape_dict.find(node_data_->id())->second); + } + + // Get the shape of tensor. + const Shape& shape() const { return *shape_; } + + // Input data has no producer. 
+ bool HasProducer() const { return node_data_->source_node.get() != nullptr; } + + OpNode producer() const; + + class ConsumerOpListView { + public: + ConsumerOpListView(const std::set, + common::GraphEdgeCompare>& edges, + const hlir::framework::Graph* graph) + : edges_(edges), graph_(graph) {} + + ConsumerOpListView(const ConsumerOpListView& other) = delete; + ConsumerOpListView(ConsumerOpListView&& other) = delete; + + ConsumerOpListView& operator=(const ConsumerOpListView& other) = delete; + + class Iterator { + public: + Iterator(std::set, + common::GraphEdgeCompare>::const_iterator it, + const hlir::framework::Graph* graph) + : iter_(it), graph_(graph) {} + + Iterator& operator++() { + ++iter_; + return *this; + } + + Iterator operator++(int) { + Iterator tmp = *this; + ++iter_; + return tmp; + } + + bool operator==(const Iterator& other) const { + return iter_ == other.iter_; + } + + bool operator!=(const Iterator& other) const { return !(*this == other); } + + OpNode operator*() const; + + private: + std::set, + common::GraphEdgeCompare>::const_iterator iter_; + const hlir::framework::Graph* graph_; + }; + + size_t size() const { return edges_.size(); } + + Iterator begin() const { return Iterator(this->edges_.begin(), graph_); } + + Iterator end() const { return Iterator(this->edges_.end(), graph_); } + + private: + const std::set, common::GraphEdgeCompare>& edges_; + const hlir::framework::Graph* graph_; + }; + + const ConsumerOpListView& consumers() const { return consumers_; } + + private: + const hlir::framework::NodeData* node_data_; + const hlir::framework::Graph* graph_; + + std::shared_ptr shape_; + const ConsumerOpListView consumers_; +}; + +} // namespace api +} // namespace cinn diff --git a/paddle/cinn/common/CMakeLists.txt b/paddle/cinn/common/CMakeLists.txt index c9a3267681ab343d53a7a887d2954481aa217524..03acaa40320e56a82119225f14cd56f05099af80 100644 --- a/paddle/cinn/common/CMakeLists.txt +++ b/paddle/cinn/common/CMakeLists.txt @@ -23,6 +23,10 @@ gather_srcs( message(STATUS "srcs: ${cinnapi_src}") +cinn_cc_test(test_dfs_walker SRCS dfs_walker_test.cc DEPS gtest glog) +cinn_cc_test(test_is_reachable_predicator SRCS is_reachable_predicator_test.cc + DEPS gtest glog) +cinn_cc_test(test_topo_walker SRCS topo_walker_test.cc DEPS gtest glog) cinn_cc_test(test_cinn_value SRCS cinn_value_test.cc DEPS cinncore) cinn_cc_test(test_shared SRCS shared_test.cc DEPS cinncore) cinn_cc_test(test_graph_utils SRCS graph_utils_test.cc DEPS cinncore) diff --git a/paddle/cinn/common/bfs_walker.h b/paddle/cinn/common/bfs_walker.h new file mode 100644 index 0000000000000000000000000000000000000000..33530f3add43d94d1efb4db2b8aeb53e74f37da2 --- /dev/null +++ b/paddle/cinn/common/bfs_walker.h @@ -0,0 +1,72 @@ +// Copyright (c) 2023 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
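+
+// Usage sketch (illustrative, with made-up example data): walk a small
+// adjacency list breadth-first. BfsWalker only needs a callback that yields
+// the neighbors of a node; it does not own the graph.
+//
+//   std::unordered_map<int, std::vector<int>> adj = {{0, {1, 2}}, {1, {3}}, {2, {3}}};
+//   cinn::common::BfsWalker<int> walker(
+//       [&](int node, const std::function<void(int)>& Visit) {
+//         for (int next : adj[node]) Visit(next);
+//       });
+//   walker(/*start=*/0, [](int node) { VLOG(3) << "visited " << node; });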
+ +#pragma once + +#include +#include +#include +#include + +namespace cinn { +namespace common { + +// breadth-first search visitor +template +class BfsWalker final { + public: + BfsWalker(const BfsWalker&) = delete; + BfsWalker(BfsWalker&&) = delete; + + using NodeHandlerType = std::function; + using NodesVisitorType = + std::function; + + BfsWalker(const NodesVisitorType& VisitNextNodes) + : VisitNextNodes_(VisitNextNodes) {} + + void operator()(NodeType node, const NodeHandlerType& NodeHandler) const { + std::array nodes{node}; + (*this)(nodes.begin(), nodes.end(), NodeHandler); + } + + template + void operator()(NodeIt begin, + NodeIt end, + const NodeHandlerType& NodeHandler) const { + std::queue node_queue; + std::unordered_set queued_nodes; + const auto& TryEnqueueNode = [&](NodeType node) { + if (queued_nodes.count(node) == 0) { + node_queue.push(node); + queued_nodes.insert(node); + } + }; + for (NodeIt iter = begin; iter != end; ++iter) { + TryEnqueueNode(*iter); + } + while (!node_queue.empty()) { + NodeType node = node_queue.front(); + node_queue.pop(); + NodeHandler(node); + VisitNextNodes_(node, TryEnqueueNode); + } + } + + private: + NodesVisitorType VisitNextNodes_; +}; + +} // namespace common +} // namespace cinn diff --git a/paddle/cinn/common/dfs_walker.h b/paddle/cinn/common/dfs_walker.h new file mode 100644 index 0000000000000000000000000000000000000000..840ea53edc4f8f384590f4e6877bff44cf945e5a --- /dev/null +++ b/paddle/cinn/common/dfs_walker.h @@ -0,0 +1,95 @@ +// Copyright (c) 2023 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace cinn { +namespace common { + +// depth-first search visitor +template +class DfsWalker final { + public: + DfsWalker(const DfsWalker&) = delete; + DfsWalker(DfsWalker&&) = delete; + + using NodeHandlerType = std::function; + using NodesVisitorType = + std::function; + + DfsWalker(const NodesVisitorType& VisitNextNodes) + : VisitNextNodes_(VisitNextNodes) {} + + void operator()(NodeType node, const NodeHandlerType& NodeHandler) const { + std::array nodes{node}; + (*this)(nodes.begin(), nodes.end(), NodeHandler, [&](NodeType) {}); + } + + template + void operator()(NodeIt begin, + NodeIt end, + const NodeHandlerType& NodeHandler) const { + (*this)(begin, end, NodeHandler, [&](NodeType) {}); + } + + // https://en.wikipedia.org/wiki/Depth-first_search + template + void operator()(NodeIt begin, + NodeIt end, + const NodeHandlerType& NodeHandlerOnPush, + const NodeHandlerType& NodeHandlerOnPop) const { + std::unordered_set discovered; + struct Neighbours { + NodeType producer; + std::queue consumers; + }; + std::stack stack; + const auto& TryPush = [&](NodeType node) { + if (discovered.count(node) == 0) { + discovered.insert(node); + NodeHandlerOnPush(node); + stack.push(Neighbours{.producer = node}); + VisitNextNodes_(node, [&](NodeType next_node) { + stack.top().consumers.push(next_node); + }); + } + }; + for (NodeIt node_iter = begin; node_iter != end; ++node_iter) { + TryPush(*node_iter); + while (!stack.empty()) { + auto* neighbours = &stack.top(); + if (neighbours->consumers.empty()) { + NodeHandlerOnPop(neighbours->producer); + stack.pop(); + } else { + TryPush(neighbours->consumers.front()); + neighbours->consumers.pop(); + } + } + } + } + + private: + NodesVisitorType VisitNextNodes_; +}; + +} // namespace common +} // namespace cinn diff --git a/paddle/cinn/common/dfs_walker_test.cc b/paddle/cinn/common/dfs_walker_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..a2a875df67e1d23fb300f854ddab7072b1bf8ddf --- /dev/null +++ b/paddle/cinn/common/dfs_walker_test.cc @@ -0,0 +1,72 @@ +// Copyright (c) 2023 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/cinn/common/dfs_walker.h" + +#include +#include + +namespace cinn { +namespace common { + +TEST(DfsWalker, simple_on_push) { + DfsWalker visitor( + [](int node, const std::function& NodeHandler) { + if (node == 0) { + NodeHandler(3); + } else if (node == 1) { + NodeHandler(2); + NodeHandler(3); + } else if (node == 2 || node == 3) { + NodeHandler(4); + } + }); + std::vector sources{0, 1}; + std::vector outputs; + visitor(sources.begin(), sources.end(), [&](int node) { + LOG(ERROR) << node; + outputs.push_back(node); + }); + std::vector expected{0, 3, 4, 1, 2}; + EXPECT_TRUE((outputs == expected)); +} + +TEST(DfsWalker, simple_on_pop) { + DfsWalker visitor( + [](int node, const std::function& NodeHandler) { + if (node == 0) { + NodeHandler(3); + } else if (node == 1) { + NodeHandler(2); + NodeHandler(3); + } else if (node == 2 || node == 3) { + NodeHandler(4); + } + }); + std::vector sources{0, 1}; + std::vector outputs; + visitor( + sources.begin(), + sources.end(), + [](int) {}, + [&](int node) { + LOG(ERROR) << node; + outputs.push_back(node); + }); + std::vector expected{4, 3, 0, 2, 1}; + EXPECT_TRUE((outputs == expected)); +} + +} // namespace common +} // namespace cinn diff --git a/paddle/cinn/common/is_reachable_predicator.h b/paddle/cinn/common/is_reachable_predicator.h new file mode 100644 index 0000000000000000000000000000000000000000..4d2b38cfb3ddf43658e59f3bc0f8766dec31c5bf --- /dev/null +++ b/paddle/cinn/common/is_reachable_predicator.h @@ -0,0 +1,79 @@ +// Copyright (c) 2023 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +#include "paddle/cinn/common/bfs_walker.h" + +namespace cinn { +namespace common { + +template +class IsReachablePredicator final { + public: + IsReachablePredicator(const IsReachablePredicator&) = delete; + IsReachablePredicator(IsReachablePredicator&&) = delete; + + using NodeHandlerType = std::function; + using NodesVisitorType = + std::function; + using NodeDepthGetterType = std::function; + + IsReachablePredicator(const NodeDepthGetterType& MinDepth4Node, + const NodeDepthGetterType& MaxDepth4Node, + const NodesVisitorType& VisitNextNodes) + : MinDepth4Node_(MinDepth4Node), + MaxDepth4Node_(MaxDepth4Node), + VisitNextNodes_(VisitNextNodes) {} + + bool operator()(NodeType src, + NodeType dst, + const NodeHandlerType& HandleVisited) const { + const size_t dst_max_depth = MaxDepth4Node_(dst); + bool detect_reachable = false; + BfsWalker bfs_walker( + [&](NodeType node, const NodeHandlerType& Handler) { + VisitNextNodes_(node, [&](NodeType out_node) { + if (dst_max_depth < MinDepth4Node_(out_node)) { + // Pruned. + // Do nothing. + } else if (detect_reachable) { + // Pruned. + // Reachability is detected. 
+ } else { + Handler(out_node); + } + }); + }); + std::array starts{src}; + bfs_walker(starts.begin(), starts.end(), [&](NodeType node) { + HandleVisited(node); + if (node == dst) { + detect_reachable = true; + } + }); + return detect_reachable; + } + + private: + NodeDepthGetterType MinDepth4Node_; + NodeDepthGetterType MaxDepth4Node_; + NodesVisitorType VisitNextNodes_; +}; + +} // namespace common +} // namespace cinn diff --git a/paddle/cinn/common/is_reachable_predicator_test.cc b/paddle/cinn/common/is_reachable_predicator_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..fc12bc37bee96e95a0ed873b37c9a0a9273829bf --- /dev/null +++ b/paddle/cinn/common/is_reachable_predicator_test.cc @@ -0,0 +1,38 @@ +// Copyright (c) 2023 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/cinn/common/is_reachable_predicator.h" + +#include +#include + +namespace cinn { +namespace common { + +TEST(IsReachablePredicator, simple) { + IsReachablePredicator IsReachable( + // Get min depth + [](int x) { return std::abs(x); }, + // Get max depth + [](int x) { return std::abs(x); }, + // visit next node + [](int x, const std::function& Handler) { + Handler(x + (x / std::abs(x))); + }); + EXPECT_TRUE(IsReachable(33, 99, [](int) {})); + EXPECT_FALSE(IsReachable(33, -99, [](int) {})); +} + +} // namespace common +} // namespace cinn diff --git a/paddle/cinn/common/macros.h b/paddle/cinn/common/macros.h index 2b9b75064bc0749b99a96b5cf974b782f64616a8..3494d3af0bf3b544d73e3455d54041eecb79537f 100644 --- a/paddle/cinn/common/macros.h +++ b/paddle/cinn/common/macros.h @@ -50,3 +50,38 @@ #else #define CINN_NODISCARD #endif + +#define DISABLE_COPY_AND_ASSIGN(classname) \ + private: \ + classname(const classname&) = delete; \ + classname(classname&&) = delete; \ + classname& operator=(const classname&) = delete; \ + classname& operator=(classname&&) = delete + +/** + * check if MACRO is used in GLOBAL NAMESPACE. 
+ */ +#define STATIC_ASSERT_GLOBAL_NAMESPACE(uniq_name, msg) \ + struct __test_global_namespace_##uniq_name##__ {}; \ + static_assert(std::is_same<::__test_global_namespace_##uniq_name##__, \ + __test_global_namespace_##uniq_name##__>::value, \ + msg) + +#define CINN_REGISTER_FUSION_PASS(pass_name, pass_class) \ + STATIC_ASSERT_GLOBAL_NAMESPACE( \ + __reg_pass__##pass_name, \ + "CINN_REGISTER_FUSION_PASS must be called in global namespace"); \ + static ::cinn::hlir::pass::FusionPassRegistrar \ + __pass_registrar_##pass_name##__(#pass_name); \ + int TouchFusionPassRegistrar_##pass_name() { \ + __pass_registrar_##pass_name##__.Touch(); \ + return 0; \ + } + +#define USE_FUSION_PASS(pass_name) \ + STATIC_ASSERT_GLOBAL_NAMESPACE( \ + __use_fusion_pass_##pass_name, \ + "USE_OP_ITSELF must be called in global namespace"); \ + extern int TouchFusionPassRegistrar_##pass_name(); \ + [[maybe_unused]] static int __use_fusion_pass_##pass_name##_ = \ + TouchFusionPassRegistrar_##pass_name() diff --git a/paddle/cinn/common/scc_walker.h b/paddle/cinn/common/scc_walker.h new file mode 100644 index 0000000000000000000000000000000000000000..36247763f336460eec52aff1ebbb45ab10901b24 --- /dev/null +++ b/paddle/cinn/common/scc_walker.h @@ -0,0 +1,95 @@ +// Copyright (c) 2023 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include +#include +#include +#include +#include +#include + +#include "paddle/cinn/common/dfs_walker.h" + +namespace cinn { +namespace common { + +// strong connnected components visitor +template +class SccWalker final { + public: + SccWalker(const SccWalker&) = delete; + SccWalker(SccWalker&&) = delete; + + using NodeHandlerType = std::function; + using NodesVisitorType = + std::function; + + SccWalker(const NodesVisitorType& VisitPrevNodes, + const NodesVisitorType& VisitNextNodes) + : VisitPrevNodes_(VisitPrevNodes), VisitNextNodes_(VisitNextNodes) {} + + using SccHandlerType = std::function&)>; + + // https://en.wikipedia.org/wiki/Kosaraju%27s_algorithm + template + void operator()(NodeIt begin, + NodeIt end, + const SccHandlerType& SccHandler) const { + const std::list& dfs_ordered_nodes = [&]() { + std::list dfs_ordered_nodes; + DfsVisitor visitor(VisitNextNodes_); + visitor( + begin, + end, + /*on push*/ [](NodeType) {}, + /*on pop*/ + [&](NodeType node) { dfs_ordered_nodes.push_front(node); }); + return dfs_ordered_nodes; + }(); + std::unordered_map node2root; + const auto& VisitPrevNode = [&](NodeType node, + const NodeHandlerType& NodeHandler) { + VisitPrevNodes_(node, [&](NodeType prev_node) { + if (node2root.count(prev_node) == 0) { + NodeHandler(prev_node); + } + }); + }; + for (NodeType root : dfs_ordered_nodes) { + if (node2root.count(root) > 0) { + continue; + } + std::vector scc; + // Use node2root immutablely inside dfs visitor. + DfsVisitor visitor(VisitPrevNode); + visitor(root, [&](NodeType node) { scc.push_back(node); }); + SccHandler(scc); + // Update node2root outside dfs visitor. 
+ for (NodeType node : scc) { + CHECK(node2root.emplace(node, root).second); + } + } + } + + private: + NodesVisitorType VisitPrevNodes_; + NodesVisitorType VisitNextNodes_; +}; + +} // namespace common +} // namespace cinn diff --git a/paddle/cinn/common/scc_walker_test.cc b/paddle/cinn/common/scc_walker_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..362243d9ee8806f963f715f538745d058b94016f --- /dev/null +++ b/paddle/cinn/common/scc_walker_test.cc @@ -0,0 +1,117 @@ +// Copyright (c) 2023 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/cinn/common/scc_walker.h" + +#include +#include + +namespace cinn { +namespace common { + +TEST(SccWalker, trivial) { + std::list> edges{{0, 3}, {1, 2}, {1, 3}, {2, 4}, {3, 4}}; + + SccWalker visitor( + [&](int node, const std::function& NodeHandler) { + for (const auto& pair : edges) { + if (pair.second == node) { + NodeHandler(pair.first); + } + } + }, + [&](int node, const std::function& NodeHandler) { + for (const auto& pair : edges) { + if (pair.first == node) { + NodeHandler(pair.second); + } + } + }); + std::vector sources{0, 1}; + std::vector> outputs; + visitor(sources.begin(), sources.end(), [&](const auto& nodes) { + outputs.push_back(nodes); + }); + std::vector> expected{{1}, {2}, {0}, {3}, {4}}; + EXPECT_TRUE((outputs == expected)); +} + +TEST(SccWalker, circle) { + std::list> edges{ + {0, 1}, + {1, 2}, + {2, 3}, + {3, 4}, + {4, 0}, + }; + + SccWalker visitor( + [&](int node, const std::function& NodeHandler) { + for (const auto& pair : edges) { + if (pair.second == node) { + NodeHandler(pair.first); + } + } + }, + [&](int node, const std::function& NodeHandler) { + for (const auto& pair : edges) { + if (pair.first == node) { + NodeHandler(pair.second); + } + } + }); + std::vector sources{0}; + std::vector> outputs; + visitor(sources.begin(), sources.end(), [&](const auto& nodes) { + outputs.push_back(nodes); + }); + std::vector> expected{{0, 4, 3, 2, 1}}; + EXPECT_TRUE((outputs == expected)); +} + +TEST(SccWalker, double_circle) { + std::list> edges{ + {0, 1}, + {1, 0}, + {1, 2}, + {2, 3}, + {3, 2}, + }; + + SccWalker visitor( + [&](int node, const std::function& NodeHandler) { + for (const auto& pair : edges) { + if (pair.second == node) { + NodeHandler(pair.first); + } + } + }, + [&](int node, const std::function& NodeHandler) { + for (const auto& pair : edges) { + if (pair.first == node) { + NodeHandler(pair.second); + } + } + }); + std::vector sources{0}; + std::vector> outputs; + visitor(sources.begin(), sources.end(), [&](const auto& nodes) { + outputs.push_back(nodes); + }); + std::vector> expected{{0, 1}, {2, 3}}; + EXPECT_TRUE((outputs == expected)); +} + +} // namespace common +} // namespace cinn diff --git a/paddle/cinn/common/topo_walker.h b/paddle/cinn/common/topo_walker.h new file mode 100644 index 0000000000000000000000000000000000000000..1361cd52c0e912feab56ed6d2b559fe62e265880 --- /dev/null +++ b/paddle/cinn/common/topo_walker.h @@ 
-0,0 +1,82 @@ +// Copyright (c) 2023 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include + +namespace cinn { +namespace common { + +// Topological order visitor +template +class TopoWalker final { + public: + TopoWalker(const TopoWalker&) = delete; + TopoWalker(TopoWalker&&) = delete; + + using NodeHandlerType = std::function; + using NodesVisitorType = + std::function; + + TopoWalker(const NodesVisitorType& VisitPrevNodes, + const NodesVisitorType& VisitNextNodes) + : VisitPrevNodes_(VisitPrevNodes), VisitNextNodes_(VisitNextNodes) {} + + void operator()(NodeType node, const NodeHandlerType& NodeHandler) const { + std::array nodes{node}; + (*this)(nodes.begin(), nodes.end(), NodeHandler); + } + + template + void operator()(NodeIt begin, + NodeIt end, + const NodeHandlerType& NodeHandler) const { + std::queue node_queue; + std::unordered_set queued_nodes; + const auto& TryEnqueueNode = [&](NodeType node) { + if (queued_nodes.count(node) == 0) { + node_queue.push(node); + queued_nodes.insert(node); + } + }; + for (NodeIt iter = begin; iter != end; ++iter) { + TryEnqueueNode(*iter); + } + while (!node_queue.empty()) { + NodeType node = node_queue.front(); + node_queue.pop(); + NodeHandler(node); + VisitNextNodes_(node, [&](NodeType node) { + size_t num_unfinished_inputs = 0; + VisitPrevNodes_(node, [&](NodeType in_node) { + num_unfinished_inputs += (queued_nodes.count(in_node) > 0 ? 0 : 1); + }); + if (num_unfinished_inputs == 0) { + TryEnqueueNode(node); + } + }); + } + } + + private: + NodesVisitorType VisitPrevNodes_; + NodesVisitorType VisitNextNodes_; +}; + +} // namespace common +} // namespace cinn diff --git a/paddle/cinn/common/topo_walker_test.cc b/paddle/cinn/common/topo_walker_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..3a0907e7af70de433e8f17d748db4280ac9067a9 --- /dev/null +++ b/paddle/cinn/common/topo_walker_test.cc @@ -0,0 +1,51 @@ +// Copyright (c) 2023 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/cinn/common/topo_walker.h" + +#include +#include + +namespace cinn { +namespace common { + +TEST(TopoWalker, simple) { + std::vector> edges{ + {0, 3}, {1, 2}, {1, 3}, {2, 3}, {3, 4}}; + TopoWalker visitor( + [&](int node, const std::function& NodeHandler) { + for (const auto& pair : edges) { + if (pair.second == node) { + NodeHandler(pair.first); + } + } + }, + [&](int node, const std::function& NodeHandler) { + for (const auto& pair : edges) { + if (pair.first == node) { + NodeHandler(pair.second); + } + } + }); + std::vector sources{0, 1}; + std::vector outputs; + visitor(sources.begin(), sources.end(), [&](int node) { + outputs.push_back(node); + }); + std::vector expected{0, 1, 2, 3, 4}; + EXPECT_TRUE((outputs == expected)); +} + +} // namespace common +} // namespace cinn diff --git a/paddle/cinn/frontend/decomposer/test_helper.h b/paddle/cinn/frontend/decomposer/test_helper.h index 9188b4ee48a70c73d2c12b17dbaffe1f7d41bd4b..c55329f87e7a9b4a0325b91bccfeccdbed1f127f 100644 --- a/paddle/cinn/frontend/decomposer/test_helper.h +++ b/paddle/cinn/frontend/decomposer/test_helper.h @@ -30,6 +30,7 @@ #include "paddle/cinn/hlir/framework/pass.h" #include "paddle/cinn/hlir/framework/tensor.h" #include "paddle/cinn/hlir/op/use_ops.h" +#include "paddle/cinn/hlir/pass/use_general_pass.h" #include "paddle/cinn/hlir/pass/use_pass.h" namespace cinn::frontend { diff --git a/paddle/cinn/frontend/interpreter.cc b/paddle/cinn/frontend/interpreter.cc old mode 100755 new mode 100644 index 6a432d4f58414d42ff96f2196af1b30a2c59a232..744aac1f2672d006b1d5f19be3c5965fc643eca9 --- a/paddle/cinn/frontend/interpreter.cc +++ b/paddle/cinn/frontend/interpreter.cc @@ -21,6 +21,7 @@ #include "paddle/cinn/hlir/framework/graph.h" #include "paddle/cinn/hlir/framework/pass.h" #include "paddle/cinn/hlir/op/use_ops.h" +#include "paddle/cinn/hlir/pass/use_general_pass.h" #include "paddle/cinn/hlir/pass/use_pass.h" #include "paddle/cinn/runtime/flags.h" diff --git a/paddle/cinn/frontend/optimize.cc b/paddle/cinn/frontend/optimize.cc index 3387d32612c7bc04ec82deb75a8fd60e91410376..caff9b19dc95308ed4ac1dc0169af0c1c498a28f 100644 --- a/paddle/cinn/frontend/optimize.cc +++ b/paddle/cinn/frontend/optimize.cc @@ -26,6 +26,7 @@ #include "paddle/cinn/hlir/framework/graph.h" #include "paddle/cinn/hlir/framework/pass.h" #include "paddle/cinn/hlir/framework/visualize_helper.h" +#include "paddle/cinn/hlir/pass/use_general_pass.h" #include "paddle/cinn/hlir/pass/use_pass.h" #include "paddle/cinn/runtime/flags.h" @@ -37,6 +38,7 @@ DECLARE_bool(cinn_use_custom_call); DECLARE_bool(use_reduce_split_pass); DECLARE_bool(cinn_use_dense_merge_pass); DECLARE_string(cinn_custom_call_deny_ops); +DECLARE_bool(general_fusion_merge_pass); namespace cinn { namespace frontend { @@ -96,7 +98,11 @@ OptimizeOptions DefaultTrainingOptimizeOptions() { if (FLAGS_cinn_use_op_fusion) { options.graph_passes.emplace_back("OpFusionPass"); - options.graph_passes.emplace_back("FusionMergePass"); + if (FLAGS_general_fusion_merge_pass) { + options.graph_passes.emplace_back("GeneralFusionMergePass"); + } else { + options.graph_passes.emplace_back("FusionMergePass"); + } } else { options.graph_passes.emplace_back("BuildNonFusedGroupsPass"); } diff --git a/paddle/cinn/frontend/pass/test_helper.h b/paddle/cinn/frontend/pass/test_helper.h index 468f15a164d9c12fa19bdf7425ff2eb2b581757b..87725d7b906171bad8f3e54e7e09a1f68e3e1153 100644 --- a/paddle/cinn/frontend/pass/test_helper.h +++ b/paddle/cinn/frontend/pass/test_helper.h @@ -24,6 +24,7 @@ #include 
"paddle/cinn/frontend/program_pass.h" #include "paddle/cinn/hlir/framework/graph_compiler.h" #include "paddle/cinn/hlir/framework/pass.h" +#include "paddle/cinn/hlir/pass/use_general_pass.h" #include "paddle/cinn/hlir/pass/use_pass.h" namespace cinn::frontend { diff --git a/paddle/cinn/hlir/framework/graph.h b/paddle/cinn/hlir/framework/graph.h index 39f9ae4a12ce5eab7f67f481b523fefbb1220da1..5f4d2e4d9791fbd48345800896d081aacfe598f2 100644 --- a/paddle/cinn/hlir/framework/graph.h +++ b/paddle/cinn/hlir/framework/graph.h @@ -58,6 +58,13 @@ class Graph : public cinn::common::Graph { std::vector> groups; struct Group { + Group() = default; + + explicit Group(const Graph* graph) : graph_(graph) {} + + // The graph that group belongs to. + const Graph* graph_ = nullptr; + // distance to last group. int depth{0}; int max_depth{0}; @@ -81,6 +88,15 @@ class Graph : public cinn::common::Graph { // master node for schedule std::unordered_set master_nodes; + // fused sub-groups, used for fusion merge pass + std::vector> fused_sub_groups; + // if as sub-group, used for belong groups. + std::unordered_set> belong_groups; + + // for op lowering. + std::vector input_names; + std::vector output_names; + struct SharedGroupHasher { size_t operator()(const std::shared_ptr& group) const noexcept { return std::hash()(reinterpret_cast(group.get())); @@ -92,27 +108,6 @@ class Graph : public cinn::common::Graph { return first.get() == second.get(); } }; - // input groups - std::unordered_set, - SharedGroupHasher, - SharedGroupComparator> - producer_groups; - // output grous - std::unordered_set, - SharedGroupHasher, - SharedGroupComparator> - consumer_groups; - // fused sub-groups, used for fusion merge pass - std::vector> fused_sub_groups; - // if as sub-group, used for belong groups. - std::unordered_set, - SharedGroupHasher, - SharedGroupComparator> - belong_groups; - - // for op lowering. 
- std::vector input_names; - std::vector output_names; std::vector CollectNodes() { if (fused_sub_groups.size()) { @@ -127,6 +122,20 @@ class Graph : public cinn::common::Graph { } } + void WalkNodes(const std::function& VisitNode) const { + if (fused_sub_groups.size()) { + for (auto& group : fused_sub_groups) { + for (const auto* node : group->nodes) { + VisitNode(node); + } + } + } else { + for (const auto* node : nodes) { + VisitNode(node); + } + } + } + std::unordered_set NodeSet() { std::unordered_set node_set; for (auto node : CollectNodes()) { @@ -139,6 +148,49 @@ class Graph : public cinn::common::Graph { std::unordered_set GetOutputNodeDatas(); std::string GetFuncName() { return "fn_" + group_id + unique_id; } + + public: + const std::unordered_set, + SharedGroupHasher, + SharedGroupComparator>& + producer_groups() const { + return producer_groups_; + } + + const std::unordered_set, + SharedGroupHasher, + SharedGroupComparator>& + consumer_groups() const { + return consumer_groups_; + } + + std::unordered_set, + SharedGroupHasher, + SharedGroupComparator>* + mut_producer_groups() { + return &producer_groups_; + } + + std::unordered_set, + SharedGroupHasher, + SharedGroupComparator>* + mut_consumer_groups() { + return &consumer_groups_; + } + + hlir::framework::OpPatternKind kind() const { return op_pattern_kind; } + + private: + // input groups + std::unordered_set, + SharedGroupHasher, + SharedGroupComparator> + producer_groups_; + // output grous + std::unordered_set, + SharedGroupHasher, + SharedGroupComparator> + consumer_groups_; }; std::vector> fusion_groups; diff --git a/paddle/cinn/hlir/framework/pass.cc b/paddle/cinn/hlir/framework/pass.cc index d10490e607ed11003b78c76cd476d3d95e304cfc..5b329ab748a6e4defe62118c141c2274c24a62e0 100644 --- a/paddle/cinn/hlir/framework/pass.cc +++ b/paddle/cinn/hlir/framework/pass.cc @@ -15,6 +15,7 @@ #include "paddle/cinn/hlir/framework/pass.h" #include "paddle/cinn/hlir/framework/visualize_helper.h" +#include "paddle/cinn/hlir/pass/use_general_pass.h" #include "paddle/cinn/hlir/pass/use_pass.h" namespace cinn { diff --git a/paddle/cinn/hlir/pass/CMakeLists.txt b/paddle/cinn/hlir/pass/CMakeLists.txt index ce09c9eddc562c62c5252d259c322ffa06a739d7..79ee0e6fa442a3aa6afd45110ca635197272f4c0 100644 --- a/paddle/cinn/hlir/pass/CMakeLists.txt +++ b/paddle/cinn/hlir/pass/CMakeLists.txt @@ -9,6 +9,7 @@ gather_srcs( const_propagate.cc op_fusion_pass.cc fusion_merge_pass.cc + general_fusion_merge_pass.cc dot_merger.cc check_fusion_accuracy_pass.cc custom_call_pass.cc diff --git a/paddle/cinn/hlir/pass/fusion_merge_pass.cc b/paddle/cinn/hlir/pass/fusion_merge_pass.cc index fc0b372b3ede0fb59971745186888f76e3558f99..da2cfe86a4c215b720789ce263820878175dd900 100644 --- a/paddle/cinn/hlir/pass/fusion_merge_pass.cc +++ b/paddle/cinn/hlir/pass/fusion_merge_pass.cc @@ -62,10 +62,10 @@ class FusionMergePassHelper : public FusionHelperBase { for (auto& sub_group : group->fused_sub_groups) { VLOG(3) << " Fused Sub-Group -> " << sub_group->group_id; } - for (auto& producer : group->producer_groups) { + for (const auto& producer : group->producer_groups()) { VLOG(3) << " Producer -> " << producer->group_id; } - for (auto& consumer : group->consumer_groups) { + for (const auto& consumer : group->consumer_groups()) { VLOG(3) << " Consumer -> " << consumer->group_id; } } @@ -94,7 +94,7 @@ class FusionMergePassHelper : public FusionHelperBase { continue; } // do horizontal fusion. 
- updated |= HorizontalFusion(producer, producer->consumer_groups); + updated |= HorizontalFusion(producer, producer->consumer_groups()); } if (updated) { @@ -115,9 +115,10 @@ class FusionMergePassHelper : public FusionHelperBase { } // do horizontal fusion. if (!recompute) { - updated |= HorizontalFusion(producer, producer->consumer_groups); + updated |= HorizontalFusion(producer, producer->consumer_groups()); } - updated |= VerticalFusion(producer, producer->consumer_groups, recompute); + updated |= + VerticalFusion(producer, producer->consumer_groups(), recompute); } // fuse input consumers updated |= FuseInputToConsumers(); @@ -151,7 +152,7 @@ class FusionMergePassHelper : public FusionHelperBase { } bool exist = false; - for (auto& producer : group->producer_groups) { + for (const auto& producer : group->producer_groups()) { if (fusion_groups_set.count(producer)) { VLOG(4) << group->group_id << " " << producer->group_id; exist = true; @@ -183,7 +184,7 @@ class FusionMergePassHelper : public FusionHelperBase { } std::unordered_set candidates; - for (auto& consumer : consumers) { + for (const auto& consumer : consumers) { // relation auto& relation = fusion_relation_map_[consumer->op_pattern_kind]; // check horizontal relation exist @@ -324,18 +325,18 @@ class FusionMergePassHelper : public FusionHelperBase { fused_group->fused_sub_groups.push_back(consumer); } // producer group - for (auto& producer : consumer->producer_groups) { - fused_group->producer_groups.insert(producer); + for (auto& producer : *consumer->mut_producer_groups()) { + fused_group->mut_producer_groups()->insert(producer); // update producer's consumer - producer->consumer_groups.erase(consumer); - producer->consumer_groups.insert(fused_group); + producer->mut_consumer_groups()->erase(consumer); + producer->mut_consumer_groups()->insert(fused_group); } // consumer group - for (auto& gconsumer : consumer->consumer_groups) { - fused_group->consumer_groups.insert(gconsumer); + for (auto& gconsumer : *consumer->mut_consumer_groups()) { + fused_group->mut_consumer_groups()->insert(gconsumer); // update consumer's producer - gconsumer->producer_groups.erase(consumer); - gconsumer->producer_groups.insert(fused_group); + gconsumer->mut_producer_groups()->erase(consumer); + gconsumer->mut_producer_groups()->insert(fused_group); } // belongs group consumer->belong_groups.insert(fused_group); @@ -412,7 +413,7 @@ class FusionMergePassHelper : public FusionHelperBase { std::unordered_set fuse_consumers_unsafe; std::unordered_set fuse_consumers; - for (auto& consumer : consumers) { + for (const auto& consumer : consumers) { VLOG(4) << "Check consuemr " << consumer->group_id << " can fuse to producer " << producer->group_id; // if can't fuse @@ -458,7 +459,7 @@ class FusionMergePassHelper : public FusionHelperBase { // if can_fuse_consumers == consumers // if producer op kind == kElementwise // if use recompute - if (fuse_consumers_unsafe.size() == producer->consumer_groups.size() && + if (fuse_consumers_unsafe.size() == producer->consumer_groups().size() && producer->op_pattern_kind == framework::kElementWise) { if (!recompute) { return false; @@ -531,11 +532,11 @@ class FusionMergePassHelper : public FusionHelperBase { } // producer groups - for (auto& group : producer->producer_groups) { - fused_group->producer_groups.insert(group); + for (auto& group : *producer->mut_producer_groups()) { + fused_group->mut_producer_groups()->insert(group); // update producer's producer's consumer - group->consumer_groups.erase(producer); - 
group->consumer_groups.insert(fused_group); + group->mut_consumer_groups()->erase(producer); + group->mut_consumer_groups()->insert(fused_group); } // sub groups @@ -581,20 +582,20 @@ class FusionMergePassHelper : public FusionHelperBase { } // producer nodes - for (auto& group : consumer->producer_groups) { + for (auto& group : *consumer->mut_producer_groups()) { if (group.get() != producer.get()) { - fused_group->producer_groups.insert(group); + fused_group->mut_producer_groups()->insert(group); // update consumer's producer's consumer - group->consumer_groups.erase(consumer); - group->consumer_groups.insert(fused_group); + group->mut_consumer_groups()->erase(consumer); + group->mut_consumer_groups()->insert(fused_group); } } // consumer nodes - for (auto& group : consumer->consumer_groups) { - fused_group->consumer_groups.insert(group); + for (auto& group : *consumer->mut_consumer_groups()) { + fused_group->mut_consumer_groups()->insert(group); // update consumer's consumer's producer - group->producer_groups.erase(consumer); - group->producer_groups.insert(fused_group); + group->mut_producer_groups()->erase(consumer); + group->mut_producer_groups()->insert(fused_group); } // sub group @@ -631,7 +632,7 @@ class FusionMergePassHelper : public FusionHelperBase { for (auto& node : producer->output_nodes) { bool be_output = true; - for (auto& consumer : producer->consumer_groups) { + for (const auto& consumer : producer->consumer_groups()) { // if consumer is in fusionable. if (fusionable_consumers.count(consumer)) { if (consumer->input_nodes.count(node)) { @@ -658,14 +659,14 @@ class FusionMergePassHelper : public FusionHelperBase { } } // insert unfusionable consumer groups - for (auto& consumer : producer->consumer_groups) { + for (auto& consumer : *producer->mut_consumer_groups()) { if (fusionable_consumers.count(consumer)) { continue; } - master_fuesd_group->consumer_groups.insert(consumer); + master_fuesd_group->mut_consumer_groups()->insert(consumer); // update consumer's producer - consumer->producer_groups.erase(producer); - consumer->producer_groups.insert(master_fuesd_group); + consumer->mut_producer_groups()->erase(producer); + consumer->mut_producer_groups()->insert(master_fuesd_group); } } @@ -699,13 +700,13 @@ class FusionMergePassHelper : public FusionHelperBase { sub_group->nodes_set.insert(producer->CollectNodes()[0]); // remove depency. consumer->input_nodes.erase(producer->CollectNodes()[0]); - consumer->producer_groups.erase(producer); - producer->consumer_groups.erase(consumer); + consumer->mut_producer_groups()->erase(producer); + producer->mut_consumer_groups()->erase(consumer); } } - CHECK_GE(producer->consumer_groups.size(), candidates.size()); - if (producer->consumer_groups.size() == 0 && candidates.size() == 0 && + CHECK_GE(producer->consumer_groups().size(), candidates.size()); + if (producer->consumer_groups().size() == 0 && candidates.size() == 0 && output_nodes_set_.count(producer->CollectNodes()[0]) == 0) { producer->belong_groups.insert(*fusionable_consumers->begin()); } @@ -714,7 +715,7 @@ class FusionMergePassHelper : public FusionHelperBase { return; } // 1 to 1 fusion. 
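// A producer with exactly one consumer needs no further selection below: the
// candidate set is left untouched and the caller fuses that single consumer.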
- if (producer->consumer_groups.size() == 1) { + if (producer->consumer_groups().size() == 1) { return; } @@ -805,7 +806,7 @@ class FusionMergePassHelper : public FusionHelperBase { while (!candidates.empty()) { auto& candidate = candidates.front(); candidates.pop(); - for (auto& producer : candidate->producer_groups) { + for (const auto& producer : candidate->producer_groups()) { if (producer.get() == producer_g.get()) { continue; } @@ -833,7 +834,7 @@ class FusionMergePassHelper : public FusionHelperBase { while (!candidates.empty()) { auto& candidate = candidates.front(); candidates.pop(); - for (auto& producer : candidate->producer_groups) { + for (auto& producer : candidate->producer_groups()) { if (producer.get() == producer_g.get()) { continue; } @@ -934,8 +935,8 @@ class FusionMergePassHelper : public FusionHelperBase { belong_group->output_nodes = group->output_nodes; belong_group->op_pattern_kind = group->op_pattern_kind; belong_group->master_nodes = group->master_nodes; - belong_group->producer_groups = group->producer_groups; - belong_group->consumer_groups = group->consumer_groups; + (*belong_group->mut_producer_groups()) = group->producer_groups(); + (*belong_group->mut_consumer_groups()) = group->consumer_groups(); belong_group->fused_sub_groups.push_back(group); group->belong_groups.insert(belong_group); // replace group to fused_group @@ -949,18 +950,19 @@ class FusionMergePassHelper : public FusionHelperBase { std::unordered_set producers; std::unordered_set consumers; - for (auto& producer : group->producer_groups) { + for (const auto& producer : group->producer_groups()) { CHECK(producer->belong_groups.size()); producers.insert(*producer->belong_groups.begin()); } - for (auto& consumer : group->consumer_groups) { + + for (auto& consumer : *group->mut_consumer_groups()) { CHECK(consumer->belong_groups.size()); consumers.insert(*consumer->belong_groups.begin()); } - CHECK_EQ(group->producer_groups.size(), producers.size()); - CHECK_EQ(group->consumer_groups.size(), consumers.size()); - group->producer_groups = producers; - group->consumer_groups = consumers; + CHECK_EQ(group->producer_groups().size(), producers.size()); + CHECK_EQ(group->consumer_groups().size(), consumers.size()); + (*group->mut_producer_groups()) = producers; + (*group->mut_consumer_groups()) = consumers; } } diff --git a/paddle/cinn/hlir/pass/fusion_merge_pass_util.h b/paddle/cinn/hlir/pass/fusion_merge_pass_util.h index 73e83ee31f39ef4b5a42bfca5975da76a8bfa559..6b6f786cab4a09985ecbc5a1b4869b632de5bf6c 100644 --- a/paddle/cinn/hlir/pass/fusion_merge_pass_util.h +++ b/paddle/cinn/hlir/pass/fusion_merge_pass_util.h @@ -73,8 +73,8 @@ CONDITION_FUNC(is_same_size) { return size_0 == size_1; } -bool is_const_group(const FusionHelperBase* helper, - const std::shared_ptr& group) { +inline bool is_const_group(const FusionHelperBase* helper, + const std::shared_ptr& group) { return group->CollectNodes().size() == 1 && helper->IsConstOp(group->CollectNodes()[0]); } diff --git a/paddle/cinn/hlir/pass/general_fusion_merge_pass.cc b/paddle/cinn/hlir/pass/general_fusion_merge_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..a4b4a64cef16daf4e4709ee55274f5712f569096 --- /dev/null +++ b/paddle/cinn/hlir/pass/general_fusion_merge_pass.cc @@ -0,0 +1,2151 @@ +// Copyright (c) 2023 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "paddle/cinn/api/op_group.h" +#include "paddle/cinn/common/is_reachable_predicator.h" +#include "paddle/cinn/common/macros.h" +#include "paddle/cinn/hlir/pass/general_fusion_merge_pass_utils.h" + +DECLARE_bool(enhance_vertical_fusion_with_recompute); + +namespace cinn { +namespace hlir { +namespace pass { + +using framework::Graph; +using framework::Node; +using framework::NodeData; +using framework::OpPatternKind; +using framework::shape_t; + +using common::GraphEdge; +using common::GraphNode; + +using GroupPtr = std::shared_ptr; +using GroupList = std::vector; + +using Comparator = Graph::Group::SharedGroupComparator; +using Hasher = Graph::Group::SharedGroupHasher; + +using OpGroupPtr = api::OpGroup; +using OpGroupList = std::vector; + +using ConditionFunction = std::function; + +class FuseHelper { + public: + virtual ~FuseHelper() = default; + + virtual bool AllOutputsSameSize(const OpGroupPtr& first, + const OpGroupPtr& second) const = 0; + + virtual bool HorizontalElementwiseFuseReduce(const OpGroupPtr& src, + const OpGroupPtr& dst) const = 0; + + virtual bool ElementwiseFuseBroadcast(const OpGroupPtr& src, + const OpGroupPtr& dst) const = 0; + + virtual bool HorizontalWithInjective(const OpGroupPtr& src, + const OpGroupPtr& dst) const = 0; + + virtual bool ElementwiseFuseReduce(const OpGroupPtr& src, + const OpGroupPtr& dst) const = 0; + + virtual bool BroadcastFuseReduce(const OpGroupPtr& src, + const OpGroupPtr& dst) const = 0; + + virtual bool InjectiveHorizontalWithReduce(const OpGroupPtr& src, + const OpGroupPtr& dst) const = 0; + + virtual bool ReduceFuseElementwise(const OpGroupPtr& src, + const OpGroupPtr& dst) const = 0; + + virtual bool ReduceFuseBroadcast(const OpGroupPtr& src, + const OpGroupPtr& dst) const = 0; + + virtual bool ReduceFuseReduce(const OpGroupPtr& src, + const OpGroupPtr& dst) const = 0; + + virtual bool IsReachable(const OpGroupPtr& lhs, + const OpGroupPtr& rhs) const = 0; + + virtual bool DetectCycleIfFuse(const OpGroupPtr& src, + const OpGroupPtr& dst) const = 0; + + virtual bool IsConsumerSetsReachable( + const OpGroupPtr& group, + const std::unordered_set& consumers) const = 0; + + protected: + FuseHelper() = default; +}; + +template +class GraphGroupFuseHelper final : public FuseHelper { + public: + explicit GraphGroupFuseHelper(const FusePassCtxT* ctx) : ctx_(ctx) {} + + bool AllOutputsSameSize(const OpGroupPtr& first, + const OpGroupPtr& second) const override; + + bool HorizontalElementwiseFuseReduce(const OpGroupPtr& src, + const OpGroupPtr& dst) const override; + + bool ElementwiseFuseBroadcast(const OpGroupPtr& src, + const OpGroupPtr& dst) const override; + + bool HorizontalWithInjective(const OpGroupPtr& src, + const OpGroupPtr& dst) const override; + + bool ElementwiseFuseReduce(const OpGroupPtr& src, + const OpGroupPtr& dst) const override; + + bool BroadcastFuseReduce(const OpGroupPtr& src, + const OpGroupPtr& dst) const override; + + bool InjectiveHorizontalWithReduce(const OpGroupPtr& src, + const OpGroupPtr& dst) const override; + + bool ReduceFuseElementwise(const OpGroupPtr& src, + 
const OpGroupPtr& dst) const override; + + bool ReduceFuseBroadcast(const OpGroupPtr& src, + const OpGroupPtr& dst) const override; + + bool ReduceFuseReduce(const OpGroupPtr& src, + const OpGroupPtr& dst) const override; + + bool IsReachable(const OpGroupPtr& lhs, + const OpGroupPtr& rhs) const override { + return IsReachableInDag(lhs, rhs) || IsReachableInDag(rhs, lhs); + } + + bool DetectCycleIfFuse(const OpGroupPtr& lhs, + const OpGroupPtr& rhs) const override { + return ReachableIfDirectEdgeIgnored(lhs, rhs) || + ReachableIfDirectEdgeIgnored(rhs, lhs); + } + + bool IsConsumerSetsReachable( + const OpGroupPtr& group, + const std::unordered_set& consumers) const override { + for (const auto& consumer : consumers) { + if (group == consumer) { + continue; + } + if (IsReachableInDag(consumer, group)) { + return true; + } + } + return false; + } + + private: + bool IsReachableInDag(const OpGroupPtr& producer, + const OpGroupPtr& consumer) const { + const auto& MinDepth4Node = [&](const OpGroupPtr& node) { + return node.GetGroup()->min_depth; + }; + const auto& MaxDepth4Node = [&](const OpGroupPtr& node) { + return node.GetGroup()->max_depth; + }; + const auto& VisitNextNodes = + [&](const OpGroupPtr& node, + const std::function& Visit) { + for (const auto& node_producer : node.producers()) { + Visit(node_producer); + } + }; + common::IsReachablePredicator is_reachable( + MinDepth4Node, MaxDepth4Node, VisitNextNodes); + return is_reachable(consumer, producer, [](OpGroupPtr) {}); + } + + bool ReachableIfDirectEdgeIgnored(const OpGroupPtr& producer, + const OpGroupPtr& consumer) const { + const auto& MinDepth4Node = [&](const OpGroupPtr& node) { + return node.GetGroup()->min_depth; + }; + const auto& MaxDepth4Node = [&](const OpGroupPtr& node) { + return node.GetGroup()->max_depth; + }; + const auto& VisitNextNodes = + [&](const OpGroupPtr& node, + const std::function& Visit) { + for (const auto& node_producer : node.producers()) { + if (node == consumer && node_producer == producer) { + continue; + } + Visit(node_producer); + } + }; + common::IsReachablePredicator is_reachable( + MinDepth4Node, MaxDepth4Node, VisitNextNodes); + return is_reachable(consumer, producer, [](OpGroupPtr) {}); + } + + const FusePassCtxT* ctx_; +}; + +class FusePassCtx { + public: + virtual ~FusePassCtx() {} + + virtual const FuseHelper& fuse_helper() const = 0; + + virtual void MarkFusible(const OpGroupPtr& first, + const OpGroupPtr& second) = 0; + + protected: + FusePassCtx() = default; +}; + +class LightwareFusePassCtx : public FusePassCtx { + public: + virtual ~LightwareFusePassCtx() {} + + virtual const OpGroupPtr& PickOpGroup() const = 0; + + virtual const FuseHelper& fuse_helper() const = 0; + + virtual void MarkFusible(const OpGroupPtr& first, + const OpGroupPtr& second) = 0; + + virtual void MarkFusible(const OpGroupList& candidates) = 0; + + protected: + LightwareFusePassCtx() = default; +}; + +class GraphGroupLightwareFusePassCtx final : public LightwareFusePassCtx { + public: + GraphGroupLightwareFusePassCtx( + const FusionHelperBase* graph_group_fusion_helper, + const OpGroupPtr& group, + const std::function& MarkFusible) + : graph_group_fusion_helper_(graph_group_fusion_helper), + group_(group), + MarkFusible_(MarkFusible), + fuse_helper_( + new GraphGroupFuseHelper(this)) {} + + GraphGroupLightwareFusePassCtx( + const FusionHelperBase* graph_group_fusion_helper, + const OpGroupPtr& group, + const std::function& + MarkGroupListFusible) + : graph_group_fusion_helper_(graph_group_fusion_helper), + 
group_(group), + MarkGroupListFusible_(MarkGroupListFusible), + fuse_helper_( + new GraphGroupFuseHelper(this)) {} + + const OpGroupPtr& PickOpGroup() const override { return group_; } + + const FuseHelper& fuse_helper() const override { return *fuse_helper_; } + + void MarkFusible(const OpGroupPtr& first, const OpGroupPtr& second) override { + MarkFusible_(first, second); + } + + void MarkFusible(const OpGroupList& candidates) override { + MarkGroupListFusible_(candidates); + } + + const FusionHelperBase& graph_group_fusion_helper() const { + return *graph_group_fusion_helper_; + } + + private: + const FusionHelperBase* graph_group_fusion_helper_; + const OpGroupPtr& group_; + const std::function + MarkFusible_; + const std::function + MarkGroupListFusible_; + const std::unique_ptr fuse_helper_; +}; + +class InputFusePassCtx : public FusePassCtx { + public: + virtual ~InputFusePassCtx() {} + + virtual const OpGroupList& PickConsumersWithSameInputs() const = 0; + + virtual const FuseHelper& fuse_helper() const = 0; + + virtual void MarkFusible(const OpGroupPtr& first, + const OpGroupPtr& second) = 0; + + virtual void MarkFusible(const OpGroupList& candidates) = 0; + + protected: + InputFusePassCtx() = default; +}; + +class GraphGroupInputFusePassCtx final : public InputFusePassCtx { + public: + GraphGroupInputFusePassCtx( + const FusionHelperBase* graph_group_fusion_helper, + const OpGroupList& groups, + const std::function& MarkFusible) + : graph_group_fusion_helper_(graph_group_fusion_helper), + groups_(groups), + MarkFusible_(MarkFusible), + fuse_helper_( + new GraphGroupFuseHelper(this)) {} + + GraphGroupInputFusePassCtx( + const FusionHelperBase* graph_group_fusion_helper, + const OpGroupList& groups, + const std::function& + MarkGroupListFusible) + : graph_group_fusion_helper_(graph_group_fusion_helper), + groups_(groups), + MarkGroupListFusible_(MarkGroupListFusible), + fuse_helper_( + new GraphGroupFuseHelper(this)) {} + + const OpGroupList& PickConsumersWithSameInputs() const override { + return groups_; + } + + const FuseHelper& fuse_helper() const override { return *fuse_helper_; } + + void MarkFusible(const OpGroupPtr& first, const OpGroupPtr& second) override { + MarkFusible_(first, second); + } + + void MarkFusible(const OpGroupList& candidates) override { + MarkGroupListFusible_(candidates); + } + + const FusionHelperBase& graph_group_fusion_helper() const { + return *graph_group_fusion_helper_; + } + + private: + const FusionHelperBase* graph_group_fusion_helper_; + const OpGroupList& groups_; + const std::function + MarkFusible_; + const std::function + MarkGroupListFusible_; + const std::unique_ptr fuse_helper_; +}; + +template +bool GraphGroupFuseHelper::AllOutputsSameSize( + const OpGroupPtr& first, const OpGroupPtr& second) const { + return is_same_size( + &ctx_->graph_group_fusion_helper(), first.GetGroup(), second.GetGroup()); +} + +template +bool GraphGroupFuseHelper::HorizontalElementwiseFuseReduce( + const OpGroupPtr& src, const OpGroupPtr& dst) const { + return honrizontal_elementwise_fuse_reduce( + &ctx_->graph_group_fusion_helper(), src.GetGroup(), dst.GetGroup()); +} + +template +bool GraphGroupFuseHelper::ElementwiseFuseBroadcast( + const OpGroupPtr& src, const OpGroupPtr& dst) const { + return elementwise_fuse_broadcast( + &ctx_->graph_group_fusion_helper(), src.GetGroup(), dst.GetGroup()); +} + +template +bool GraphGroupFuseHelper::HorizontalWithInjective( + const OpGroupPtr& src, const OpGroupPtr& dst) const { + return horizontal_with_injective( + 
&ctx_->graph_group_fusion_helper(), src.GetGroup(), dst.GetGroup()); +} + +template +bool GraphGroupFuseHelper::ElementwiseFuseReduce( + const OpGroupPtr& src, const OpGroupPtr& dst) const { + return elementwise_fuse_reduce( + &ctx_->graph_group_fusion_helper(), src.GetGroup(), dst.GetGroup()); +} + +template +bool GraphGroupFuseHelper::BroadcastFuseReduce( + const OpGroupPtr& src, const OpGroupPtr& dst) const { + return broadcast_fuse_reduce( + &ctx_->graph_group_fusion_helper(), src.GetGroup(), dst.GetGroup()); +} + +template +bool GraphGroupFuseHelper::InjectiveHorizontalWithReduce( + const OpGroupPtr& src, const OpGroupPtr& dst) const { + return injective_horizontal_with_reduce( + &ctx_->graph_group_fusion_helper(), src.GetGroup(), dst.GetGroup()); +} + +template +bool GraphGroupFuseHelper::ReduceFuseElementwise( + const OpGroupPtr& src, const OpGroupPtr& dst) const { + return reduce_fuse_elementwise( + &ctx_->graph_group_fusion_helper(), src.GetGroup(), dst.GetGroup()); +} + +template +bool GraphGroupFuseHelper::ReduceFuseBroadcast( + const OpGroupPtr& src, const OpGroupPtr& dst) const { + return reduce_fuse_broadcast( + &ctx_->graph_group_fusion_helper(), src.GetGroup(), dst.GetGroup()); +} + +template +bool GraphGroupFuseHelper::ReduceFuseReduce( + const OpGroupPtr& src, const OpGroupPtr& dst) const { + return reduce_fuse_reduce( + &ctx_->graph_group_fusion_helper(), src.GetGroup(), dst.GetGroup()); +} + +template +struct HorizontalFuseUtil { + using KindKeyT = std::pair; + + static bool DetectFusabilityByKind(FusePassCtxT* ctx, + const OpGroupPtr& src, + const OpGroupPtr& dst) { + const KindKeyT kind_pair(src.kind(), dst.kind()); + const auto& map = GetConditionMap(); + const auto& iter = map.find(kind_pair); + if (iter == map.end()) { + return false; + } + return iter->second(ctx, src, dst); + } + + typedef bool (*ConditionT)(FusePassCtxT* ctx, + const OpGroupPtr& src, + const OpGroupPtr& dst); + + static const std::map& GetConditionMap() { + thread_local static std::map map(RawConditionMap()); + return map; + } + + static std::map RawConditionMap() { + return std::map{ + {{OpPatternKind::kElementWise, framework::kElementWise}, &IsSameSize}, + {{OpPatternKind::kElementWise, framework::kBroadcast}, &IsSameSize}, + {{OpPatternKind::kElementWise, framework::kInjective}, &IsSameSize}, + {{OpPatternKind::kElementWise, framework::kReduction}, + &HorizontalElementwiseFuseReduce}, + + {{OpPatternKind::kBroadcast, framework::kElementWise}, &IsSameSize}, + {{OpPatternKind::kBroadcast, framework::kBroadcast}, &IsSameSize}, + {{OpPatternKind::kBroadcast, framework::kInjective}, &IsSameSize}, + {{OpPatternKind::kBroadcast, framework::kReduction}, &IsSameSize}, + + {{OpPatternKind::kInjective, framework::kElementWise}, &IsSameSize}, + {{OpPatternKind::kInjective, framework::kBroadcast}, &IsSameSize}, + {{OpPatternKind::kInjective, framework::kInjective}, &IsSameSize}, + {{OpPatternKind::kInjective, framework::kReduction}, &IsSameSize}, + + {{OpPatternKind::kReduction, framework::kElementWise}, + &HorizontalElementwiseFuseReduce}, + {{OpPatternKind::kReduction, framework::kBroadcast}, &IsSameSize}, + {{OpPatternKind::kReduction, framework::kInjective}, &IsSameSize}, + {{OpPatternKind::kReduction, framework::kReduction}, &ReduceFuseReduce}, + }; + } + + static bool IsSameSize(FusePassCtxT* ctx, + const OpGroupPtr& src, + const OpGroupPtr& dst) { + return utils::IsSameSize(src, dst); + } + + static bool HorizontalElementwiseFuseReduce(FusePassCtxT* ctx, + const OpGroupPtr& src, + const OpGroupPtr& 
dst) { + // if same shape with horizontal relation + if (IsSameSize(ctx, src, dst)) { + return true; + } + + const OpGroupPtr* ele_group = nullptr; + const OpGroupPtr* reduce_group = nullptr; + + if (src.kind() == framework::kReduction) { + ele_group = &dst; + reduce_group = &src; + } else { + ele_group = &src; + reduce_group = &dst; + } + + size_t size_ele = + utils::GetMasterNode(*ele_group).outputs()[0].shape().numel(); + + bool can_fuse = false; + reduce_group->WalkOpNodes([&](const api::OpNode& op) { + if (op.kind() == OpPatternKind::kReduction) { + size_t size_master = op.outputs()[0].shape().numel(); + if (size_ele == size_master) { + can_fuse = true; + } + } + }); + + return can_fuse; + } + + static bool ReduceFuseReduce(FusePassCtxT* ctx, + const OpGroupPtr& src, + const OpGroupPtr& dst) { + return ctx->fuse_helper().ReduceFuseReduce(src, dst); + } +}; + +class FusePass { + public: + virtual ~FusePass() = default; + + virtual const std::string FuseMode() const = 0; + + virtual int Benefit() const = 0; + + protected: + FusePass() = default; +}; + +class InputFusePass : public FusePass { + public: + virtual ~InputFusePass() = default; + + virtual void operator()(InputFusePassCtx* ctx) const = 0; + + const std::string FuseMode() const final { return "InputFuse"; } + + virtual int Benefit() const = 0; + + protected: + InputFusePass() = default; +}; + +class DefaultInputFusePass final : public InputFusePass { + public: + DefaultInputFusePass() : InputFusePass() {} + + int Benefit() const override { return 100; } + + void operator()(InputFusePassCtx* ctx) const override { + const auto& consumer_set = ctx->PickConsumersWithSameInputs(); + + const std::unordered_set consumer_candidates = + [&]() -> std::unordered_set { + std::unordered_set consumers; + for (const auto& consumer : consumer_set) { + if (consumer.kind() == framework::kElementWise || + consumer.kind() == framework::kBroadcast || + consumer.kind() == framework::kInjective || + consumer.kind() == framework::kReduction) { + consumers.insert(consumer); + } + } + return consumers; + }(); + if (consumer_candidates.size() <= 1) { + return; + } + + std::vector fusionable_consumers; + for (auto& candidate : consumer_candidates) { + if (ctx->fuse_helper().IsConsumerSetsReachable(candidate, + consumer_candidates)) { + continue; + } + if (fusionable_consumers.empty()) { + fusionable_consumers.push_back({candidate}); + continue; + } + // check each fusionable groups + bool fusionable = false; + for (auto& groups : fusionable_consumers) { + auto& last = groups.back(); + if (!HorizontalFuseUtil::DetectFusabilityByKind( + ctx, candidate, last)) { + continue; + } + groups.push_back(candidate); + fusionable = true; + break; + } + + // if can't fuse to othors Groups, new Groups. 
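// Greedy bucketing: each reachability-safe candidate joins the first existing
// bucket whose last member it can horizontally fuse with (per HorizontalFuseUtil's
// kind-pair table); otherwise it opens a new single-element bucket below.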
+ if (!fusionable) { + fusionable_consumers.push_back({candidate}); + } + } + + for (const auto& groups : fusionable_consumers) { + if (groups.size() > 1) { + ctx->MarkFusible(groups); + } + } + VLOG(1) << "DefaultInputFusePass Finish"; + } +}; + +class LightwareFusePass : public FusePass { + public: + virtual ~LightwareFusePass() = default; + + virtual void operator()(LightwareFusePassCtx* ctx) const = 0; + + virtual const std::string FuseMode() const = 0; + + virtual int Benefit() const = 0; + + protected: + LightwareFusePass() = default; +}; + +class HorizontalFusePass : public LightwareFusePass { + public: + virtual ~HorizontalFusePass() = default; + + virtual void operator()(LightwareFusePassCtx* ctx) const = 0; + + const std::string FuseMode() const final { return "HorizontalFuse"; } + + virtual int Benefit() const = 0; + + protected: + HorizontalFusePass() = default; +}; + +class DefaultHorizontalFusePass final : public HorizontalFusePass { + public: + DefaultHorizontalFusePass() : HorizontalFusePass() {} + + int Benefit() const override { return 100; } + + void operator()(LightwareFusePassCtx* ctx) const override { + const auto& producer = ctx->PickOpGroup(); + const std::unordered_set consumer_candidates = + [&]() -> std::unordered_set { + std::unordered_set consumers; + for (const auto& consumer : producer.consumers()) { + if (consumer.kind() == framework::kElementWise || + consumer.kind() == framework::kBroadcast || + consumer.kind() == framework::kInjective || + consumer.kind() == framework::kReduction) { + consumers.insert(consumer); + } + } + return consumers; + }(); + if (consumer_candidates.size() <= 1) { + return; + } + + std::vector fusionable_consumers; + for (auto& candidate : consumer_candidates) { + if (ctx->fuse_helper().IsConsumerSetsReachable(candidate, + consumer_candidates)) { + continue; + } + if (fusionable_consumers.empty()) { + fusionable_consumers.push_back({candidate}); + continue; + } + // check each fusionable groups + bool fusionable = false; + for (auto& groups : fusionable_consumers) { + auto& last = groups.back(); + if (!HorizontalFuseUtil::DetectFusabilityByKind( + ctx, candidate, last)) { + continue; + } + groups.push_back(candidate); + fusionable = true; + break; + } + + // if can't fuse to othors Groups, new Groups. 
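// Same greedy bucketing as DefaultInputFusePass, except that the candidates here
// are the consumers of the producer picked from the pass context rather than
// consumer groups sharing the same graph inputs.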
+ if (!fusionable) { + fusionable_consumers.push_back({candidate}); + } + } + + for (const auto& groups : fusionable_consumers) { + if (groups.size() > 1) { + // Trick for BERT, maybe not required, wait for substitution from + // unordered_set to set + if (groups.size() == 2) { + OpGroupList fuse_group; + if (groups[1].group_id().substr(0, 4) == "cast" && + groups[0].group_id() == "reshape_split") { + fuse_group.push_back(groups[1]); + fuse_group.push_back(groups[0]); + ctx->MarkFusible(fuse_group); + continue; + } + } + ctx->MarkFusible(groups); + } + } + } +}; + +class VerticalFusePass : public LightwareFusePass { + public: + virtual ~VerticalFusePass() = default; + + virtual void operator()(LightwareFusePassCtx* ctx) const = 0; + + const std::string FuseMode() const final { return "VerticalFuse"; } + + virtual int Benefit() const = 0; + + protected: + VerticalFusePass() = default; +}; + +class DefaultVerticalFusePass final : public VerticalFusePass { + public: + DefaultVerticalFusePass() : VerticalFusePass() {} + + int Benefit() const override { return 100; } + + void operator()(LightwareFusePassCtx* ctx) const override { + const auto& producer = ctx->PickOpGroup(); + const OpGroupList consumers = [&]() { + OpGroupList consumers; + for (const auto& consumer : producer.consumers()) { + consumers.push_back(consumer); + } + return consumers; + }(); + if (consumers.size() == 0) { + return; + } + + std::vector candidates; + for (int i = 0; i < consumers.size(); ++i) { + const auto& consumer = consumers.at(i); + if (!DetectFusabilityByKind(ctx, producer, consumer)) { + break; + } + candidates.push_back(consumer); + } + if (candidates.size() == consumers.size() && + producer.kind() == framework::kElementWise) { + return; + } + + for (int i = 0; i < consumers.size(); ++i) { + const auto& consumer = consumers.at(i); + if (!DetectFusabilityByKind(ctx, producer, consumer)) { + continue; + } + if (ctx->fuse_helper().DetectCycleIfFuse(producer, consumer)) { + VLOG(4) << "Can't fuse because detect cycle"; + continue; + } + ctx->MarkFusible(producer, consumer); + } + } + + using KindKeyT = std::pair; + bool DetectFusabilityByKind(LightwareFusePassCtx* ctx, + const OpGroupPtr& src, + const OpGroupPtr& dst) const { + const KindKeyT kind_pair(src.kind(), dst.kind()); + const auto& map = GetConditionMap(); + const auto& iter = map.find(kind_pair); + if (iter == map.end()) { + return false; + } + return iter->second(ctx, src, dst); + } + + typedef bool (*ConditionT)(LightwareFusePassCtx* ctx, + const OpGroupPtr& src, + const OpGroupPtr& dst); + + static const std::map& GetConditionMap() { + thread_local static std::map map(RawConditionMap()); + return map; + } + + static std::map RawConditionMap() { + return std::map{ + {{OpPatternKind::kElementWise, framework::kElementWise}, + &DefaultVerticalFusePass::IsSameSize}, + {{OpPatternKind::kElementWise, framework::kBroadcast}, + &DefaultVerticalFusePass::ElementwiseFuseBroadcast}, + {{OpPatternKind::kElementWise, framework::kInjective}, + &DefaultVerticalFusePass::HorizontalWithInjective}, + {{OpPatternKind::kElementWise, framework::kReduction}, + &DefaultVerticalFusePass::ElementwiseFuseReduce}, + + {{OpPatternKind::kBroadcast, framework::kElementWise}, + &DefaultVerticalFusePass::IsSameSize}, + {{OpPatternKind::kBroadcast, framework::kBroadcast}, + &DefaultVerticalFusePass::IsSameSize}, + {{OpPatternKind::kBroadcast, framework::kInjective}, + &DefaultVerticalFusePass::HorizontalWithInjective}, + {{OpPatternKind::kBroadcast, framework::kReduction}, + 
&DefaultVerticalFusePass::BroadcastFuseReduce}, + + {{OpPatternKind::kInjective, framework::kElementWise}, + &DefaultVerticalFusePass::IsSameSize}, + {{OpPatternKind::kInjective, framework::kBroadcast}, + &DefaultVerticalFusePass::IsSameSize}, + {{OpPatternKind::kInjective, framework::kInjective}, + &DefaultVerticalFusePass::HorizontalWithInjective}, + {{OpPatternKind::kInjective, framework::kReduction}, + &DefaultVerticalFusePass::InjectiveHorizontalWithReduce}, + + {{OpPatternKind::kReduction, framework::kElementWise}, + &DefaultVerticalFusePass::ReduceFuseElementwise}, + {{OpPatternKind::kReduction, framework::kBroadcast}, + &DefaultVerticalFusePass::ReduceFuseBroadcast}, + {{OpPatternKind::kReduction, framework::kInjective}, + &DefaultVerticalFusePass::HorizontalWithInjective}, + {{OpPatternKind::kReduction, framework::kReduction}, + &DefaultVerticalFusePass::ReduceFuseReduce}, + }; + } + + static bool IsSameSize(LightwareFusePassCtx* ctx, + const OpGroupPtr& src, + const OpGroupPtr& dst) { + return utils::IsSameSize(src, dst); + } + + static bool ElementwiseFuseBroadcast(LightwareFusePassCtx* ctx, + const OpGroupPtr& src, + const OpGroupPtr& dst) { + return ctx->fuse_helper().ElementwiseFuseBroadcast(src, dst); + } + + static bool HorizontalWithInjective(LightwareFusePassCtx* ctx, + const OpGroupPtr& src, + const OpGroupPtr& dst) { + return ctx->fuse_helper().HorizontalWithInjective(src, dst); + } + + static bool ElementwiseFuseReduce(LightwareFusePassCtx* ctx, + const OpGroupPtr& src, + const OpGroupPtr& dst) { + return ctx->fuse_helper().ElementwiseFuseReduce(src, dst); + } + + static bool BroadcastFuseReduce(LightwareFusePassCtx* ctx, + const OpGroupPtr& src, + const OpGroupPtr& dst) { + return ctx->fuse_helper().BroadcastFuseReduce(src, dst); + } + + static bool InjectiveHorizontalWithReduce(LightwareFusePassCtx* ctx, + const OpGroupPtr& src, + const OpGroupPtr& dst) { + return ctx->fuse_helper().InjectiveHorizontalWithReduce(src, dst); + } + + static bool ReduceFuseElementwise(LightwareFusePassCtx* ctx, + const OpGroupPtr& src, + const OpGroupPtr& dst) { + return ctx->fuse_helper().ReduceFuseElementwise(src, dst); + } + + static bool ReduceFuseBroadcast(LightwareFusePassCtx* ctx, + const OpGroupPtr& src, + const OpGroupPtr& dst) { + return ctx->fuse_helper().ReduceFuseBroadcast(src, dst); + } + + static bool ReduceFuseReduce(LightwareFusePassCtx* ctx, + const OpGroupPtr& src, + const OpGroupPtr& dst) { + return ctx->fuse_helper().ReduceFuseReduce(src, dst); + } +}; + +class RecomputeFusePass : public LightwareFusePass { + public: + virtual ~RecomputeFusePass() = default; + + virtual void operator()(LightwareFusePassCtx* ctx) const = 0; + + const std::string FuseMode() const final { return "RecomputeFuse"; } + + virtual int Benefit() const = 0; + + protected: + RecomputeFusePass() = default; +}; + +class DefaultRecomputeFusePass final : public RecomputeFusePass { + public: + DefaultRecomputeFusePass() : RecomputeFusePass() {} + + int Benefit() const override { return 100; } + + void operator()(LightwareFusePassCtx* ctx) const override { + const auto& producer = ctx->PickOpGroup(); + const OpGroupList consumers = [&]() { + OpGroupList consumers; + for (const auto& consumer : producer.consumers()) { + consumers.push_back(consumer); + } + return consumers; + }(); + // Borrows unsafe_candidates and candidates concept from origin + // fusion_merge_pass + std::vector unsafe_candidates; + std::vector candidates; + for (int i = 0; i < consumers.size(); ++i) { + const auto& consumer = 
consumers.at(i); + if (!DetectFusabilityByKind(ctx, producer, consumer)) { + continue; + } + unsafe_candidates.push_back(consumer); + if (ctx->fuse_helper().DetectCycleIfFuse(producer, consumer)) { + continue; + } + candidates.push_back(consumer); + } + + if (!candidates.empty() && unsafe_candidates.size() == consumers.size() && + producer.kind() == framework::kElementWise) { + for (const auto& consumer : consumers) { + ctx->MarkFusible(producer, consumer); + } + } + } + + using KindKeyT = std::pair; + bool DetectFusabilityByKind(LightwareFusePassCtx* ctx, + const OpGroupPtr& src, + const OpGroupPtr& dst) const { + const KindKeyT kind_pair(src.kind(), dst.kind()); + const auto& map = DefaultVerticalFusePass::GetConditionMap(); + const auto& iter = map.find(kind_pair); + if (iter == map.end()) { + return false; + } + return iter->second(ctx, src, dst); + } +}; + +struct LightwareFusePassComparator { + bool operator()(const std::shared_ptr& lhs, + const std::shared_ptr& rhs) const { + return lhs->Benefit() > rhs->Benefit(); + } +}; + +struct InputFusePassComparator { + bool operator()(const std::shared_ptr& lhs, + const std::shared_ptr& rhs) const { + return lhs->Benefit() > rhs->Benefit(); + } +}; + +class FusionPassMap { + public: + static FusionPassMap& Instance() { + static FusionPassMap global_fusion_pass_map; + return global_fusion_pass_map; + } + + bool Has(const std::string& pass_name) const { + return map_.find(pass_name) != map_.end(); + } + + void Insert(const std::string& pass_name, + const std::shared_ptr& pass) { + CHECK(!Has(pass_name)) << "FusePass " << pass_name + << " has already been registered."; + map_.insert({pass_name, pass}); + } + + std::shared_ptr Get(const std::string& pass_name) const { + auto it = map_.find(pass_name); + CHECK(it != map_.end()) + << "FusePass " << pass_name << " has not been registered."; + return it->second; + } + + // fuse_mode: HorizontalFuse, VerticalFuse, RecomputeFuse + std::vector> GetLightwareFusePassesByMode( + const std::string& fuse_mode) const { + CHECK(fuse_mode == "HorizontalFuse" || fuse_mode == "VerticalFuse" || + fuse_mode == "RecomputeFuse") + << "fuse_mode only supports HorizontalFuse, VerticalFuse and " + "RecomputeFuse. Please check your input modes = " + << fuse_mode; + std::set, LightwareFusePassComparator> + candidate_passes; + for (const auto iter : map_) { + if (fuse_mode == iter.second->FuseMode()) { + candidate_passes.insert( + std::dynamic_pointer_cast(iter.second)); + } + } + return std::vector>( + candidate_passes.begin(), candidate_passes.end()); + } + + std::vector> GetInputFusePasses() const { + std::set, InputFusePassComparator> + candidate_passes; + for (const auto iter : map_) { + if (iter.second->FuseMode() == "InputFuse") { + candidate_passes.insert( + std::dynamic_pointer_cast(iter.second)); + } + } + return std::vector>(candidate_passes.begin(), + candidate_passes.end()); + } + + private: + FusionPassMap() = default; + std::unordered_map> map_; + + DISABLE_COPY_AND_ASSIGN(FusionPassMap); +}; + +class Registrar { + public: + // In our design, various kinds of classes, e.g., operators and kernels, + // have their corresponding registry and registrar. The action of + // registration is in the constructor of a global registrar variable, which + // are not used in the code that calls package framework, and would + // be removed from the generated binary file by the linker. To avoid such + // removal, we add Touch to all registrar classes and make USE_OP macros to + // call this method. 
So, as long as the callee code calls USE_OP, the global + // registrar variable won't be removed by the linker. + void Touch() {} +}; + +template +class FusionPassRegistrar final : public Registrar { + public: + explicit FusionPassRegistrar(const std::string& pass_name) { + FusionPassMap::Instance().Insert( + pass_name, std::shared_ptr(new PassClassT())); + } +}; + +// Op Fusion Pass which performs Ops fusion, Ops are fused +// "vertically", meaning producing Ops are fused into their consumers +// with the intent that the loops which compute their values will be fused in +// code generation. +class GeneralFusionMergePassHelper : public FusionHelperBase { + public: + explicit GeneralFusionMergePassHelper(const Graph* graph) + : FusionHelperBase(graph), graph_(graph) { + fusion_groups_ = graph->fusion_groups; + // init input to consumers. + InitInputToConsumers(); + // init fusion group index. + InitFusionGroupsAndIndex(); + } + + GroupList operator()() { + // run fusion merge untill no update. + DoFusionMerge(); + for (auto& group : fusion_groups_) { + VLOG(3) << "Fusion Group -> " << group->group_id; + for (auto& sub_group : group->fused_sub_groups) { + VLOG(3) << " Fused Sub-Group -> " << sub_group->group_id; + } + for (const auto& producer : group->producer_groups()) { + VLOG(3) << " Producer -> " << producer->group_id; + } + for (const auto& consumer : group->consumer_groups()) { + VLOG(3) << " Consumer -> " << consumer->group_id; + } + } + return fusion_groups_; + } + + private: + void DoFusionMerge() { + VLOG(3) << "DoFusionMerge...!"; + while (DoGeneralHorizontalFusion()) { + } + while (DoGeneralVerticalFusion()) { + } + while (DoGeneralRecomputeAndVerticalFusion()) { + } + } + + bool DoGeneralHorizontalFusion() { + VLOG(3) << "DoGeneralHorizontalFusion...!"; + bool updated = false; + for (int idx = 0; idx < fusion_groups_.size(); ++idx) { + auto producer = fusion_groups_[idx]; + VLOG(3) << "Fusion Producer idx " << idx << " Group -> " + << producer->group_id; + // if producer is sub group. + if (producer->belong_groups.size()) { + continue; + } + // do horizontal fusion. + updated |= GeneralHorizontalFuse(producer); + } + + if (updated) { + UpdateFusionGroup(); + } + return updated; + } + + bool DoGeneralVerticalFusion() { + VLOG(3) << "DoGeneralVerticalFusion...!"; + bool updated = false; + for (int idx = 0; idx < fusion_groups_.size(); ++idx) { + auto producer = fusion_groups_[idx]; + VLOG(3) << "Fusion Producer idx " << idx << " Group -> " + << producer->group_id; + // if producer is sub group. + if (producer->belong_groups.size()) { + continue; + } + // do horizontal fusion. + updated |= GeneralHorizontalFuse(producer); + updated |= GeneralVerticalFuse(producer); + } + + // fuse input consumers + updated |= GeneralInputFuse(); + + if (updated) { + UpdateFusionGroup(); + } + return updated; + } + + bool DoGeneralRecomputeAndVerticalFusion() { + VLOG(3) << "DoGeneralRecomputeAndVerticalFusion...!"; + bool updated = false; + for (int idx = 0; idx < fusion_groups_.size(); ++idx) { + auto producer = fusion_groups_[idx]; + VLOG(3) << "Fusion Producer idx " << idx << " Group -> " + << producer->group_id; + // if producer is sub group. + if (producer->belong_groups.size()) { + continue; + } + // do horizontal fusion. 
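// Recompute fusion is attempted first for this producer; vertical fusion is only
// tried when recompute fusion did not fire.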
+ bool recompute_success = GeneralRecomputeFuse(producer); + updated |= recompute_success; + if (!recompute_success) { + updated |= GeneralVerticalFuse(producer); + } + } + + // fuse input consumers + updated |= GeneralInputFuse(); + + if (updated) { + UpdateFusionGroup(); + } + return updated; + } + + void UpdateFusionGroup() { + VLOG(3) << "UpdateFusionGroup..."; + GroupList fusion_groups; + std::unordered_set fusion_groups_set; + // update fusion_groups_ + for (auto& group : fusion_groups_) { + if (!group->belong_groups.size()) { + fusion_groups.push_back(group); + fusion_groups_set.insert(group); + } + } + // keep group in order + fusion_groups_.clear(); + fusion_groups_index_.clear(); + while (!fusion_groups_set.empty()) { + bool is_ring = true; + for (int idx = 0; idx < fusion_groups.size(); ++idx) { + auto& group = fusion_groups[idx]; + if (!group.get()) { + continue; + } + + bool exist = false; + for (const auto& producer : group->producer_groups()) { + if (fusion_groups_set.count(producer)) { + VLOG(4) << group->group_id << " " << producer->group_id; + exist = true; + break; + } + } + + if (!exist) { + fusion_groups_index_[group] = fusion_groups_.size(); + fusion_groups_.push_back(group); + fusion_groups_set.erase(group); + group.reset(); + is_ring = false; + continue; + } + } + if (is_ring) { + LOG(FATAL) << "Exists Ring, Please Check!"; + } + } + } + + std::vector> RawHorizontalFusePasses() + const { + return FusionPassMap::Instance().GetLightwareFusePassesByMode( + "HorizontalFuse"); + } + + const std::vector>& + GetHorizontalFusePasses() const { + thread_local static std::vector> + fuse_passes = RawHorizontalFusePasses(); + return fuse_passes; + } + + void EnableFusedHorizontalGroups(LightwareFusePassCtx* ctx) const { + const auto& producer = ctx->PickOpGroup(); + if (producer.consumers().size() <= 1) { + return; + } + const auto& fuse_passes = GetHorizontalFusePasses(); + for (const auto& fuse_pass : fuse_passes) { + (*fuse_pass)(ctx); + } + } + + bool GeneralHorizontalFuse(const GroupPtr& producer) { + VLOG(3) << "GeneralHorizontalFuse handling producer : " + << producer->group_id; + const auto& GetFusableConsumerGroupLists = + [&]() -> std::vector { + std::vector tagged_lists; + const auto& MarkFusible = [&](const OpGroupList& candidates) { + tagged_lists.push_back(candidates); + }; + GraphGroupLightwareFusePassCtx fuse_ctx( + this, api::OpGroup(producer), MarkFusible); + EnableFusedHorizontalGroups(&fuse_ctx); + return tagged_lists; + }; + const auto& GetFusableConsumerGroupList = [&]() -> std::vector { + const auto& group_lists = GetFusableConsumerGroupLists(); + if (group_lists.empty()) { + return std::vector{}; + } + std::vector ret; + for (const auto& group_list : group_lists) { + GroupList tmp; + for (const auto& group : group_list) { + tmp.push_back(group.GetGroup()); + } + ret.push_back(tmp); + } + return ret; + }; + + const auto& group_lists = GetFusableConsumerGroupList(); + if (group_lists.empty()) { + return false; + } + for (const auto& group_list : group_lists) { + HorizontalFuse(group_list); + } + + return true; + } + + std::vector> RawInputFusePasses() const { + return FusionPassMap::Instance().GetInputFusePasses(); + } + + const std::vector>& GetInputFusePasses() + const { + thread_local static std::vector> + fuse_passes = RawInputFusePasses(); + return fuse_passes; + } + + void EnableFusedInputGroups(InputFusePassCtx* ctx) const { + const auto& fuse_passes = GetInputFusePasses(); + for (const auto& fuse_pass : fuse_passes) { + (*fuse_pass)(ctx); + } + } + 
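// Runs every registered "InputFuse" pass over consumer groups that share the same
// graph input and horizontally fuses each list the passes tag as fusible. A new
// input-fuse rule only needs to subclass InputFusePass and register itself via
// FusionPassRegistrar; a minimal sketch (the pass name "MyInputFusePass" and its
// benefit value are illustrative, not part of this patch):
//
//   class MyInputFusePass final : public InputFusePass {
//    public:
//     int Benefit() const override { return 50; }
//     void operator()(InputFusePassCtx* ctx) const override {
//       const auto& consumers = ctx->PickConsumersWithSameInputs();
//       if (consumers.size() > 1) {
//         ctx->MarkFusible(consumers);
//       }
//     }
//   };
//   FusionPassRegistrar<MyInputFusePass> my_input_fuse_pass_registrar(
//       "MyInputFusePass");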
+ bool CallGeneralInputFusePass( + const std::unordered_set& consumers) { + VLOG(3) << "CallGeneralInputFusePass...!"; + const auto& GetFusableConsumerGroupLists = + [&]() -> std::vector { + std::vector tagged_lists; + const auto& MarkFusible = [&](const OpGroupList& candidates) { + tagged_lists.push_back(candidates); + }; + OpGroupList consumer_groups; + consumer_groups.reserve(consumers.size()); + for (auto& consumer : consumers) { + consumer_groups.push_back(api::OpGroup(consumer)); + } + GraphGroupInputFusePassCtx fuse_ctx(this, consumer_groups, MarkFusible); + EnableFusedInputGroups(&fuse_ctx); + return tagged_lists; + }; + const auto& GetFusableConsumerGroupList = [&]() -> std::vector { + const auto& group_lists = GetFusableConsumerGroupLists(); + if (group_lists.empty()) { + return std::vector{}; + } + std::vector ret; + for (const auto& group_list : group_lists) { + GroupList tmp; + for (const auto& group : group_list) { + tmp.push_back(group.GetGroup()); + } + ret.push_back(tmp); + } + return ret; + }; + + const auto& group_lists = GetFusableConsumerGroupList(); + if (group_lists.empty()) { + return false; + } + for (const auto& group_list : group_lists) { + HorizontalFuse(group_list); + } + + return true; + } + + void HorizontalFuse(const GroupList& consumers) { + VLOG(3) << "HorizontalFuse Groups..."; + // create fusion group + auto fused_group = std::make_shared(graph_); + // As recompute exist which may case sub-group used by more than one time. + std::vector repeat_sub_groups; + std::unordered_set sub_group_set; + // find the first consumer. + GroupPtr first_consumer(nullptr); + // fuse all group into fusion group. + for (const auto& consumer : consumers) { + VLOG(3) << "fuse consumer " << consumer->group_id << " into fused_group!"; + // update depth + fused_group->max_depth = + std::max(fused_group->max_depth, consumer->max_depth); + fused_group->min_depth = + std::min(fused_group->min_depth, consumer->min_depth); + // update group id + if (fused_group->group_id.size()) { + fused_group->group_id += "_" + consumer->group_id; + } else { + fused_group->group_id = consumer->group_id; + } + // set op pattern kind + fused_group->op_pattern_kind = + static_cast(fused_group->op_pattern_kind) >= + static_cast(consumer->op_pattern_kind) + ? fused_group->op_pattern_kind + : consumer->op_pattern_kind; + // input nodes + for (auto& node : consumer->input_nodes) { + if (fused_group->input_nodes.count(node.first)) { + fused_group->input_nodes[node.first] += node.second; + } else { + fused_group->input_nodes.insert(node); + } + } + // output node + for (auto& node : consumer->output_nodes) { + fused_group->output_nodes.insert(node); + } + // internal node + if (consumer->fused_sub_groups.size()) { + for (auto& node : consumer->internal_nodes) { + fused_group->internal_nodes.insert(node); + } + } + // master node + for (auto& node : consumer->master_nodes) { + if (GetOpKind(node) == framework::kReduction) { + fused_group->master_nodes.insert(node); + } + } + // insert sub group + if (consumer->fused_sub_groups.size()) { + for (auto& sub_group : consumer->fused_sub_groups) { + // check sub group is repeat. + if (sub_group_set.count(sub_group)) { + VLOG(3) << sub_group->group_id << " is repeated!"; + repeat_sub_groups.push_back(sub_group); + continue; + } + // record sub group + sub_group_set.insert(sub_group); + + // insert to fused sub group. 
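// Every non-repeated sub group is re-parented below: it is appended to the fused
// group's sub-group list and its belong_groups membership is moved from the old
// consumer to the new fused group.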
+ fused_group->fused_sub_groups.push_back(sub_group); + // update belongs group + sub_group->belong_groups.erase(consumer); + sub_group->belong_groups.insert(fused_group); + } + } else { + fused_group->fused_sub_groups.push_back(consumer); + } + // producer group + for (auto& producer : *consumer->mut_producer_groups()) { + fused_group->mut_producer_groups()->insert(producer); + // update producer's consumer + producer->mut_consumer_groups()->erase(consumer); + producer->mut_consumer_groups()->insert(fused_group); + } + // consumer group + for (auto& gconsumer : *consumer->mut_consumer_groups()) { + fused_group->mut_consumer_groups()->insert(gconsumer); + // update consumer's producer + gconsumer->mut_producer_groups()->erase(consumer); + gconsumer->mut_producer_groups()->insert(fused_group); + } + // belongs group + consumer->belong_groups.insert(fused_group); + + // find the first consumer. + CHECK(fusion_groups_index_.count(consumer)) + << "Can't find consumer " << consumer->group_id + << " index in fusion_groups_index_!"; + if (first_consumer.get()) { + if (fusion_groups_index_[consumer] < + fusion_groups_index_[first_consumer]) { + first_consumer = consumer; + } + } else { + first_consumer = consumer; + } + } + + // if node is output nodes of sub_group, check it can't be internal node. + for (auto& sub_group : repeat_sub_groups) { + // check each output node in sub_group. + for (auto& node : sub_group->output_nodes) { + // if node is not output node of fused_group. + if (!fused_group->output_nodes.count(node)) { + fused_group->internal_nodes.insert(node); + } + } + } + + if (static_cast(framework::kReduction) > + static_cast((consumers.back())->op_pattern_kind)) { + auto consumer = consumers.back(); + + for (auto& node : consumer->master_nodes) { + fused_group->master_nodes.insert(node); + } + } else { + for (auto consumer = consumers.rbegin(); consumer != consumers.rend(); + ++consumer) { + Node* master_node = nullptr; + for (auto& node : (*consumer)->master_nodes) { + if (GetOpKind(node) != framework::kReduction) { + master_node = node; + break; + } + } + if (master_node) { + VLOG(3) << "Insert Master node : " << master_node->id() + << " into group : " << fused_group->group_id; + fused_group->master_nodes.insert(master_node); + break; + } + } + } + + auto postion = fusion_groups_index_[first_consumer]; + fusion_groups_[postion] = fused_group; + fusion_groups_index_[fused_group] = postion; + + CHECK(fused_group->output_nodes.size()) + << "No output node is found, " << fused_group->group_id; + } + + std::vector> RawVerticalFusePasses() + const { + return FusionPassMap::Instance().GetLightwareFusePassesByMode( + "VerticalFuse"); + } + + const std::vector>& GetVerticalFusePasses() + const { + thread_local static std::vector> + fuse_passes = RawVerticalFusePasses(); + return fuse_passes; + } + + void TagVerticalGroups(LightwareFusePassCtx* ctx) const { + const auto& producer = ctx->PickOpGroup(); + if (producer.consumers().size() == 0) { + return; + } + const auto& fuse_passes = GetVerticalFusePasses(); + for (const auto& fuse_pass : fuse_passes) { + (*fuse_pass)(ctx); + } + } + + bool GeneralVerticalFuse(const GroupPtr& producer) { + VLOG(3) << "GeneralVerticalFuse...!"; + using GroupSets = std::vector>; + const auto& GetFusableConsumerOpGroupSets = [&]() -> GroupSets { + GroupSets tagged_sets; + const auto& MarkFusible = [&](const OpGroupPtr& first, + const OpGroupPtr& second) { + tagged_sets.push_back(std::make_pair(first, second)); + }; + GraphGroupLightwareFusePassCtx fuse_ctx( + 
this, api::OpGroup(producer), MarkFusible); + TagVerticalGroups(&fuse_ctx); + return tagged_sets; + }; + + auto GetFusableConsumerGroupSet = + [&]() -> std::unordered_set { + const auto& group_sets = GetFusableConsumerOpGroupSets(); + if (group_sets.empty()) { + return {}; + } + std::unordered_set ret; + for (const auto& group_pair : group_sets) { + ret.insert(group_pair.second.GetGroup()); + } + return ret; + }; + + bool update = false; + auto consumer_groups = GetFusableConsumerGroupSet(); + if (consumer_groups.size()) { + SelectConsumerToFuse(producer, &consumer_groups); + } + if (consumer_groups.size() > 0) { + VerticalFuse(producer, consumer_groups); + update = true; + } + return update; + } + + void VerticalFuse(const GroupPtr& producer, + const std::unordered_set& + fusionable_consumers) { + VLOG(3) << "VerticalFuse...!"; + GroupList fused_groups; + GroupPtr master_fuesd_group(nullptr); + for (auto& consumer : fusionable_consumers) { + auto fused_group = std::make_shared(graph_); + // update depth using consumer depth. + fused_group->max_depth = + std::max(producer->max_depth, consumer->max_depth); + fused_group->min_depth = + std::min(producer->min_depth, consumer->min_depth); + // update group id + fused_group->group_id = producer->group_id + "_" + consumer->group_id; + VLOG(3) << "fuse producer " << producer->group_id << " into consumer " + << consumer->group_id; + // fuse producer into fusion group + fused_group->op_pattern_kind = + static_cast(producer->op_pattern_kind) >= + static_cast(consumer->op_pattern_kind) + ? producer->op_pattern_kind + : consumer->op_pattern_kind; + // input nodes + fused_group->input_nodes = producer->input_nodes; + + // internal nodes + if (producer->fused_sub_groups.size()) { + for (auto& node : producer->internal_nodes) { + fused_group->internal_nodes.insert(node); + } + } + // convert producer's output node to internal. + for (auto node : producer->output_nodes) { + // if node is used more than 1 time. + if (consumer->input_nodes.count(node)) { + if (consumer->input_nodes[node] > 1 && node->inlinks().size() > 0) { + fused_group->internal_nodes.insert(node); + } + } + } + // master nodes + for (auto& node : producer->master_nodes) { + if (GetOpKind(node) == framework::kReduction) { + fused_group->master_nodes.insert(node); + } + } + + // producer groups + for (auto& group : *producer->mut_producer_groups()) { + fused_group->mut_producer_groups()->insert(group); + // update producer's producer's consumer + group->mut_consumer_groups()->erase(producer); + group->mut_consumer_groups()->insert(fused_group); + } + + // sub groups + if (producer->fused_sub_groups.size()) { + for (auto& group : producer->fused_sub_groups) { + fused_group->fused_sub_groups.push_back(group); + // update belong group + group->belong_groups.erase(producer); + group->belong_groups.insert(fused_group); + } + } else { + fused_group->fused_sub_groups.push_back(producer); + } + producer->belong_groups.insert(fused_group); + + // input nodes + for (auto& input_node : consumer->input_nodes) { + // if input node not in producer output. 
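// Only consumer inputs that the producer does not itself produce are merged into
// the fused group's inputs (accumulating use counts for nodes both sides read);
// nodes fed by the producer are no longer external inputs of the fused group.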
+ if (!producer->output_nodes.count(input_node.first)) { + if (fused_group->input_nodes.count(input_node.first)) { + fused_group->input_nodes[input_node.first] += input_node.second; + } else { + fused_group->input_nodes.insert(input_node); + } + } + } + + // output nodes + for (auto& node : consumer->output_nodes) { + fused_group->output_nodes.insert(node); + } + + // internal nodes + if (consumer->fused_sub_groups.size()) { + for (auto& node : consumer->internal_nodes) { + fused_group->internal_nodes.insert(node); + } + } + + // master nodes + for (auto& node : consumer->master_nodes) { + fused_group->master_nodes.insert(node); + } + + // producer nodes + for (auto& group : *consumer->mut_producer_groups()) { + if (group.get() != producer.get()) { + fused_group->mut_producer_groups()->insert(group); + // update consumer's producer's consumer + group->mut_consumer_groups()->erase(consumer); + group->mut_consumer_groups()->insert(fused_group); + } + } + + // consumer nodes + for (auto& group : *consumer->mut_consumer_groups()) { + fused_group->mut_consumer_groups()->insert(group); + // update consumer's consumer's producer + group->mut_producer_groups()->erase(consumer); + group->mut_producer_groups()->insert(fused_group); + } + + // sub group + if (consumer->fused_sub_groups.size()) { + for (auto& sub_group : consumer->fused_sub_groups) { + if (std::find(fused_group->fused_sub_groups.begin(), + fused_group->fused_sub_groups.end(), + sub_group) == fused_group->fused_sub_groups.end()) { + fused_group->fused_sub_groups.push_back(sub_group); + } + // update belong group + sub_group->belong_groups.erase(consumer); + sub_group->belong_groups.insert(fused_group); + } + } else { + fused_group->fused_sub_groups.push_back(consumer); + } + consumer->belong_groups.insert(fused_group); + + fused_groups.push_back(fused_group); + CHECK(fusion_groups_index_.count(consumer)) + << "Can't find consumer " << consumer->group_id + << " index in fusion_groups_index_!"; + auto postion = fusion_groups_index_[consumer]; + fusion_groups_[postion] = fused_group; + fusion_groups_index_[fused_group] = postion; + + if (!master_fuesd_group.get()) { + master_fuesd_group = fused_group; + } + CHECK(fused_group->output_nodes.size()) + << "No output node is found, " << fused_group->group_id; + } + + for (auto& node : producer->output_nodes) { + bool be_output = true; + for (const auto& consumer : producer->consumer_groups()) { + // if consumer is in fusionable. + if (fusionable_consumers.count(consumer)) { + if (consumer->input_nodes.count(node)) { + be_output = false; + } + continue; + } + // if consumer is not in fusionable. + if (consumer->input_nodes.count(node)) { + be_output = true; + break; + } + // others node is as graph output. 
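// (Consumers that do not touch this node leave be_output unchanged; graph
// outputs are forced to stay outputs right after this loop.)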
+ } + + if (output_nodes_set_.count(node)) { + be_output = true; + } + + if (be_output) { + VLOG(4) << "Insert Id " << node->id() << " Into Group " + << master_fuesd_group->group_id; + master_fuesd_group->output_nodes.insert(node); + } + } + // insert unfusionable consumer groups + for (auto& consumer : *producer->mut_consumer_groups()) { + if (fusionable_consumers.count(consumer)) { + continue; + } + master_fuesd_group->mut_consumer_groups()->insert(consumer); + // update consumer's producer + consumer->mut_producer_groups()->erase(producer); + consumer->mut_producer_groups()->insert(master_fuesd_group); + } + } + + std::vector> RawRecomputeFusePasses() + const { + return FusionPassMap::Instance().GetLightwareFusePassesByMode( + "RecomputeFuse"); + } + + const std::vector>& + GetRecomputeFusePasses() const { + thread_local static std::vector> + fuse_passes = RawRecomputeFusePasses(); + return fuse_passes; + } + + void TagRecomputeGroups(LightwareFusePassCtx* ctx) const { + const auto& fuse_passes = GetRecomputeFusePasses(); + for (const auto& fuse_pass : fuse_passes) { + (*fuse_pass)(ctx); + } + } + + bool GeneralRecomputeFuse(const GroupPtr& producer) { + VLOG(3) << "GeneralRecomputeFuse handling producer : " + << producer->group_id; + using GroupSets = std::set>; + const auto& GetFusableConsumerOpGroupSets = [&]() -> GroupSets { + GroupSets tagged_sets; + const auto& MarkFusible = [&](const OpGroupPtr& first, + const OpGroupPtr& second) { + tagged_sets.insert(std::make_pair(first, second)); + }; + GraphGroupLightwareFusePassCtx fuse_ctx( + this, api::OpGroup(producer), MarkFusible); + TagRecomputeGroups(&fuse_ctx); + return tagged_sets; + }; + + auto GetFusableConsumerGroupSet = + [&]() -> std::unordered_set { + const auto& group_sets = GetFusableConsumerOpGroupSets(); + if (group_sets.empty()) { + return {}; + } + std::unordered_set ret; + for (const auto& group_pair : group_sets) { + ret.insert(group_pair.second.GetGroup()); + } + return ret; + }; + + bool update = false; + auto consumer_groups = GetFusableConsumerGroupSet(); + if (consumer_groups.size() > 0) { + CHECK(consumer_groups.size() == producer->mut_consumer_groups()->size()) + << "Recompute requires fuse all consumers!"; + RecomputeFuse(producer, consumer_groups); + update = true; + } + return update; + } + + void RecomputeFuse(const GroupPtr& producer, + const std::unordered_set& + fusionable_consumers) { + VerticalFuse(producer, fusionable_consumers); + } + + void SelectConsumerToFuse( + const GroupPtr& producer, + std::unordered_set* fusionable_consumers) { + // if is const op + if (is_const_group(this, producer)) { + std::unordered_set candidates; + for (auto& consumer : *fusionable_consumers) { + // if can be output node. + if (is_same_shape(this, producer, consumer)) { + candidates.insert(consumer); + } else { + VLOG(4) << "Fuse Producer : " << producer->group_id + << " into Consumer : " << consumer->group_id; + consumer->group_id = producer->group_id + "_" + consumer->group_id; + // just merge the node into group. + auto& sub_group = consumer->fused_sub_groups.front(); + sub_group->group_id = producer->group_id + "_" + sub_group->group_id; + sub_group->nodes.insert(sub_group->nodes.begin(), + producer->CollectNodes()[0]); + sub_group->nodes_set.insert(producer->CollectNodes()[0]); + // remove depency. 
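// The const node has been copied into the consumer's leading sub group, so the
// direct producer/consumer edge between the two groups is dropped below.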
+ consumer->input_nodes.erase(producer->CollectNodes()[0]); + consumer->mut_producer_groups()->erase(producer); + producer->mut_consumer_groups()->erase(consumer); + } + } + + CHECK_GE(producer->consumer_groups().size(), candidates.size()); + if (producer->consumer_groups().size() == 0 && candidates.size() == 0 && + output_nodes_set_.count(producer->CollectNodes()[0]) == 0) { + producer->belong_groups.insert(*fusionable_consumers->begin()); + } + + *fusionable_consumers = candidates; + return; + } + // 1 to 1 fusion. + if (producer->consumer_groups().size() == 1) { + return; + } + + if (FLAGS_enhance_vertical_fusion_with_recompute) { + std::vector candidates; + for (auto& consumer : *fusionable_consumers) { + if (consumer->op_pattern_kind == framework::kElementWise) { + candidates.push_back(consumer); + continue; + } + + auto producer_output_shape = + this->GetNodeDataShape(*producer->output_nodes.begin()); + auto consumer_output_shape = + this->GetNodeDataShape(*consumer->output_nodes.begin()); + auto consumer_master_input_shape = + this->GetNodeInputShape(*(consumer->master_nodes.begin())); + int producer_output_numel = + std::accumulate(producer_output_shape.begin(), + producer_output_shape.end(), + 1, + std::multiplies()); + int consumer_output_numel = + std::accumulate(consumer_output_shape.begin(), + consumer_output_shape.end(), + 1, + std::multiplies()); + int consumer_master_input_numel = + std::accumulate(consumer_master_input_shape.begin(), + consumer_master_input_shape.end(), + 1, + std::multiplies()); + if (producer_output_numel == consumer_output_numel) { + candidates.push_back(consumer); + continue; + } + + if (producer->op_pattern_kind != framework::kInjective && + consumer->op_pattern_kind == framework::kReduction && + producer_output_numel == consumer_master_input_numel) { + candidates.push_back(consumer); + } + } + sort(candidates.begin(), + candidates.end(), + [](const auto& lhs, const auto& rhs) { + return lhs->op_pattern_kind < rhs->op_pattern_kind; + }); + + fusionable_consumers->clear(); + if (candidates.size()) { + fusionable_consumers->insert(*candidates.begin()); + } + } else { + std::vector candidates; + for (auto& consumer : *fusionable_consumers) { + if (consumer->op_pattern_kind == framework::kElementWise) { + candidates.push_back(consumer); + continue; + } + + auto shape0 = this->GetNodeDataShape(*producer->output_nodes.begin()); + auto shape1 = this->GetNodeDataShape(*consumer->output_nodes.begin()); + + if (std::accumulate( + shape0.begin(), shape0.end(), 1, std::multiplies()) == + std::accumulate( + shape1.begin(), shape1.end(), 1, std::multiplies())) { + candidates.push_back(consumer); + } + } + + fusionable_consumers->clear(); + if (candidates.size()) { + fusionable_consumers->insert(candidates.front()); + } + } + } + + bool IsDependency( + const GroupPtr& producer_g, + const GroupPtr& consumer, + const std::unordered_set& consumers) { + std::queue candidates; + candidates.push(consumer); + + std::unordered_set visited_set; + while (!candidates.empty()) { + auto& candidate = candidates.front(); + candidates.pop(); + for (const auto& producer_and_list : candidate->producer_groups()) { + if (producer_and_list.get() == producer_g.get()) { + continue; + } + const auto& producer = + std::dynamic_pointer_cast(producer_and_list); + if (consumers.count(producer)) { + return true; + } + if (!visited_set.count(producer)) { + visited_set.insert(producer); + candidates.push(producer); + } + } + } + return false; + } + + bool IsDependencySimplify( + const GroupPtr& 
producer_g, + const GroupPtr& consumer, + const std::unordered_set& consumers) { + std::queue candidates; + candidates.push(consumer); + // check upper. + int check_upper_depth = producer_g.get() ? producer_g->max_depth : INT_MAX; + std::unordered_set visited_set; + while (!candidates.empty()) { + auto& candidate = candidates.front(); + candidates.pop(); + for (auto& producer_and_list : candidate->producer_groups()) { + if (producer_and_list.get() == producer_g.get()) { + continue; + } + const auto& producer = + std::dynamic_pointer_cast(producer_and_list); + if (producer->min_depth > check_upper_depth) { + continue; + } + if (consumers.count(producer)) { + return true; + } + if (!visited_set.count(producer)) { + visited_set.insert(producer); + candidates.push(producer); + } + } + } + return false; + } + + bool GeneralInputFuse() { + VLOG(3) << "GeneralInputFuse...!"; + auto updated = false; + UpdateInputToConsumers(); + for (auto& input_consumers : input_to_consumers_) { + // if group set size == 1. + if (input_consumers.second.size() == 1) { + continue; + } + // do input fusion. + auto st = CallGeneralInputFusePass(input_consumers.second); + if (st) { + // fused consumers, update + UpdateInputToConsumers(); + } + updated |= st; + } + + return updated; + } + + void UpdateInputToConsumers() { + for (auto& input_consumers : input_to_consumers_) { + auto& consumers = input_consumers.second; + std::unordered_set updated_consumers; + for (auto& consumer : consumers) { + std::queue fused_groups; + fused_groups.push(consumer); + while (!fused_groups.empty()) { + auto& cur = fused_groups.front(); + fused_groups.pop(); + // if group is sub group + if (cur->belong_groups.empty()) { + updated_consumers.insert(cur); + } else { + for (auto& belong_group : cur->belong_groups) { + if (belong_group->group_id == cur->group_id) { + updated_consumers.insert(belong_group); + } else { + fused_groups.push(belong_group); + } + } + } + } + } + consumers = updated_consumers; + } + } + + void InitInputToConsumers() { + VLOG(3) << "InitInputToConsumers...!"; + // init input data node -> fusion group map. + for (auto& group : fusion_groups_) { + for (auto& node : group->nodes_set) { + // collect producer node data. + auto producer_node_datas = GetProducerNodeData(node); + for (auto& node_data : producer_node_datas) { + // node data's source node is null. + if (!node_data->source_node.get()) { + // insert group to set. + input_to_consumers_[node_data].insert(group); + } + } + } + } + } + + void InitFusionGroupsAndIndex() { + VLOG(3) << "InitFusionGroupsAndIndex...!"; + // init the postion of groups in fusion groups. + for (int idx = 0; idx < fusion_groups_.size(); ++idx) { + auto group = fusion_groups_[idx]; + auto belong_group = std::make_shared(graph_); + // copy from group. + belong_group->max_depth = group->depth; + belong_group->min_depth = group->depth; + belong_group->group_id = group->group_id; + belong_group->input_nodes = group->input_nodes; + belong_group->output_nodes = group->output_nodes; + belong_group->op_pattern_kind = group->op_pattern_kind; + belong_group->master_nodes = group->master_nodes; + (*belong_group->mut_producer_groups()) = group->producer_groups(); + (*belong_group->mut_consumer_groups()) = group->consumer_groups(); + belong_group->fused_sub_groups.push_back(group); + group->belong_groups.insert(belong_group); + // replace group to fused_group + fusion_groups_[idx] = belong_group; + // record idx + fusion_groups_index_[belong_group] = idx; + } + + // update producer and consumer. 
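+ // The wrapper groups built above copied their producer/consumer sets from
+ // the original groups; remap every edge so it points at the corresponding
+ // wrapper (belong_group) instead of the wrapped sub-group.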
+ for (auto& group : fusion_groups_) { + std::unordered_set producers; + std::unordered_set consumers; + + for (const auto& producer : group->producer_groups()) { + CHECK(producer->belong_groups.size()); + producers.insert(*producer->belong_groups.begin()); + } + + for (auto& consumer : *group->mut_consumer_groups()) { + CHECK(consumer->belong_groups.size()); + consumers.insert(*consumer->belong_groups.begin()); + } + CHECK_EQ(group->producer_groups().size(), producers.size()); + CHECK_EQ(group->consumer_groups().size(), consumers.size()); + (*group->mut_producer_groups()) = producers; + (*group->mut_consumer_groups()) = consumers; + } + } + + const Graph* graph_; + GroupList fusion_groups_; + std::unordered_map fusion_groups_index_; + std::unordered_map> + input_to_consumers_; +}; + +void GeneralFusionMergePassInternal(Graph* graph) { + if (graph->fusion_groups.size() <= 1) { + VLOG(3) << "Don't do Fusoin Merge Pass...!"; + return; + } + + GeneralFusionMergePassHelper fusion_merge_pass_helper(graph); + graph->fusion_groups = fusion_merge_pass_helper(); +} + +} // namespace pass +} // namespace hlir +} // namespace cinn + +CINN_REGISTER_HELPER(GeneralFusionMergePass) { + CINN_REGISTER_PASS(GeneralFusionMergePass) + .describe( + "Fusion Merge Pass which performs Fusion-Ops fusion, Producer " + "Fusion-Ops are fused into Consumer Fusion-Ops " + "with certain conditions.") + .set_change_structure(false) + .set_body(cinn::hlir::pass::GeneralFusionMergePassInternal); + + return true; +} + +CINN_REGISTER_FUSION_PASS(DefaultHorizontalFusePass, + cinn::hlir::pass::DefaultHorizontalFusePass); +CINN_REGISTER_FUSION_PASS(DefaultVerticalFusePass, + cinn::hlir::pass::DefaultVerticalFusePass); +CINN_REGISTER_FUSION_PASS(DefaultRecomputeFusePass, + cinn::hlir::pass::DefaultRecomputeFusePass); +CINN_REGISTER_FUSION_PASS(DefaultInputFusePass, + cinn::hlir::pass::DefaultInputFusePass); diff --git a/paddle/cinn/hlir/pass/general_fusion_merge_pass_utils.h b/paddle/cinn/hlir/pass/general_fusion_merge_pass_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..8873f3f1f4f49c3b9b9bca053335dd79fe4dbdcd --- /dev/null +++ b/paddle/cinn/hlir/pass/general_fusion_merge_pass_utils.h @@ -0,0 +1,276 @@ +// Copyright (c) 2023 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
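+
+// Helper routines shared by GeneralFusionMergePass: master-node lookup,
+// output-size comparison (IsSameSize), kernel-argument limiting (limit_args),
+// reduce-axis checks (WithoutLastDimInReduce / GetSharedSize) and the
+// reduce-fuse-reduce fusibility rule (ReduceFuseReduce).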
+ +#pragma once + +#include "paddle/cinn/api/op_group.h" +#include "paddle/cinn/hlir/pass/fusion_merge_pass_util.h" + +namespace cinn { +namespace hlir { +namespace pass { +namespace utils { + +using framework::OpPatternKind; + +using OpGroupPtr = api::OpGroup; +using OpGroupList = std::vector; + +static api::OpNode GetMasterNode(const OpGroupPtr& op_group) { + std::vector master_nodes; + op_group.WalkOpNodes([&](const api::OpNode& op) { + if (op.kind() == OpPatternKind::kReduction) { + master_nodes.push_back(op); + } + }); + if (!master_nodes.empty()) { + return master_nodes.front(); + } + + op_group.WalkOpNodes( + [&](const api::OpNode& op) { master_nodes.push_back(op); }); + return master_nodes.back(); +} + +static bool IsSameSize(const OpGroupPtr& src, const OpGroupPtr& dst) { + api::OpNode src_master_node = GetMasterNode(src); + api::OpNode dst_master_node = GetMasterNode(dst); + + auto size_0 = src_master_node.outputs()[0].shape().numel(); + auto size_1 = dst_master_node.outputs()[0].shape().numel(); + + return size_0 == size_1; +} + +static std::unordered_set GetInputOps(const OpGroupPtr& op_group) { + std::unordered_set ops_set; + op_group.WalkOpNodes( + [&ops_set](const api::OpNode& op_node) { ops_set.insert(op_node); }); + + std::unordered_set input_ops; + op_group.WalkOpNodes([&](const api::OpNode& op) { + const auto& input_tensors = op.inputs(); + for (size_t i = 0; i < input_tensors.size(); ++i) { + if (input_tensors[i].HasProducer()) { + api::OpNode producer = input_tensors[i].producer(); + if (ops_set.find(producer) == ops_set.end()) { + input_ops.insert(producer); + } + } + } + }); + return input_ops; +} + +static std::unordered_set GetOutputOps( + const OpGroupPtr& op_group) { + std::unordered_set ops_set; + op_group.WalkOpNodes( + [&ops_set](const api::OpNode& op_node) { ops_set.insert(op_node); }); + std::unordered_set output_ops; + op_group.WalkOpNodes([&](const api::OpNode& op) { + const auto& output_tensors = op.outputs(); + for (size_t i = 0; i < output_tensors.size(); ++i) { + const auto& consumers = output_tensors[i].consumers(); + for (const auto& consumer : consumers) { + if (ops_set.find(consumer) == ops_set.end()) { + output_ops.insert(consumer); + break; + } + } + } + }); + return output_ops; +} + +// limit the group args number to less equal 512, as args stack size is 4K. +static bool limit_args(const OpGroupPtr& first, const OpGroupPtr& second) { + std::unordered_set args; + for (auto& group : {first, second}) { + for (const auto& node : GetInputOps(group)) { + args.insert(node); + } + for (const auto& node : GetOutputOps(group)) { + args.insert(node); + } + } + + if (args.size() > 512) { + return false; + } else { + return true; + } +} + +bool WithoutLastDimInReduce(const api::Shape& inshape, + const std::vector& axes) { + // if last axis is in reduce. 
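+ // Returns true only when the reduce axes leave a non-trivial tail of
+ // dimensions untouched. E.g. for inshape = [16, 32, 64] and axes = {1},
+ // the last axis (2) is not reduced and the product of the dims behind
+ // axes.back() is 64 > 1, so the result is true; for axes = {2} or
+ // axes = {-1} the result is false.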
+ if (std::find(axes.begin(), axes.end(), inshape.size() - 1) != axes.end() || + std::find(axes.begin(), axes.end(), -1) != axes.end()) { + return false; + } + + int sum_last_axes = 1; + for (int idx = axes.back() + 1; idx < inshape.size(); ++idx) { + sum_last_axes *= inshape[idx]; + } + + if (sum_last_axes > 1) { + return true; + } else { + return false; + } +} + +static int GetSharedSize(const api::OpNode& op_node) { + const auto& producers = op_node.inputs(); + CHECK_GT(producers.size(), 0); + const auto& inshape = producers[0].shape(); + const auto& axes = op_node.GetAttr>("dim"); + if (WithoutLastDimInReduce(inshape, axes)) { + int lane = 1; + for (int idx = axes.back() + 1; idx < inshape.size(); ++idx) { + lane = inshape[idx]; + } + int max_num_threads = common::DefaultNVGPUTarget().max_num_threads(); + if (lane > max_num_threads / 2) { + return 0; + } + int index = axes.size() - 1; + for (; index >= 0; --index) { + if (index + 1 < axes.size() && axes[index] != axes[index + 1] - 1) { + break; + } + lane *= inshape[axes[index]]; + if (lane > max_num_threads / 2) { + break; + } + } + // if lane > (max_num_threads / 2),the loop break from lane > + // max_num_threads / 2. + int axis = lane > (max_num_threads / 2) ? axes[index] : axes[index + 1]; + if (lane <= max_num_threads) { + return lane * sizeof(float); + } else { + int prefix = inshape[axis]; + int tail = lane / prefix; + for (int idx = max_num_threads / tail; + idx > ((max_num_threads / 2) / tail); + --idx) { + if (prefix % idx == 0) { + return idx * tail * sizeof(float); + } + } + int num = max_num_threads / tail; + return num * tail * sizeof(float); + } + } + return 0; +} + +static bool ReduceFuseReduce(const OpGroupPtr& first, + const OpGroupPtr& second) { + if (!limit_args(first, second)) { + return false; + } + std::unique_ptr reducer_0 = nullptr; + first.WalkOpNodes([&](const api::OpNode& op) { + if (!reducer_0 && op.kind() == OpPatternKind::kReduction) { + reducer_0.reset(new api::OpNode(op)); + } + }); + CHECK(reducer_0) << "Can't find reduce op in group " << first.group_id(); + + std::unique_ptr reducer_1 = nullptr; + second.WalkOpNodes([&](const api::OpNode& op) { + if (!reducer_1 && op.kind() == OpPatternKind::kReduction) { + reducer_1.reset(new api::OpNode(op)); + } + }); + + CHECK(reducer_1) << "Can't find reduce op in group " << second.group_id(); + + // check reduce has same input shape and output shape + const auto& reducer_0_input_shape = reducer_0->inputs()[0].shape(); + const auto& reducer_0_output_shape = reducer_0->outputs()[0].shape(); + + const auto& reducer_1_input_shape = reducer_1->inputs()[0].shape(); + const auto& reducer_1_output_shape = reducer_1->outputs()[0].shape(); + + auto reducer_0_reduce_dim = reducer_0->GetAttr>("dim"); + auto reducer_1_reduce_dim = reducer_1->GetAttr>("dim"); + + for (auto& dim : reducer_0_reduce_dim) { + // if dim = -1, set as shape.size() - 1 + if (dim == -1) { + dim = reducer_0_reduce_dim.size() - 1; + } + } + + for (auto& dim : reducer_1_reduce_dim) { + // if dim = -1, set as shape.size() - 1 + if (dim == -1) { + dim = reducer_1_reduce_dim.size() - 1; + } + } + + // check shape is same + if (reducer_0_input_shape == reducer_1_input_shape && + reducer_0_output_shape == reducer_1_output_shape && + reducer_0_reduce_dim == reducer_1_reduce_dim) { + auto shared_size = 0; + for (auto& fusion_group : {first, second}) { + fusion_group.WalkOpNodes([&](const api::OpNode& op) { + if (op.kind() == OpPatternKind::kReduction) { + shared_size += GetSharedSize(op); + } + }); + } + 
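+ // The two matching reductions are only allowed to fuse if their combined
+ // shared-memory estimate (GetSharedSize) fits in the 32KB budget checked
+ // below.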
+#define MAX_AVAILABLE_SHREAD 32 * 1024 + if (shared_size > MAX_AVAILABLE_SHREAD) { + return false; + } +#undef MAX_AVAILABLE_SHREAD + return true; + } + + if (WithoutLastDimInReduce(reducer_0_input_shape, reducer_0_reduce_dim) && + WithoutLastDimInReduce(reducer_1_input_shape, reducer_1_reduce_dim) && + reducer_0_output_shape == reducer_1_output_shape && + reducer_0_reduce_dim == reducer_1_reduce_dim) { + auto shared_size = 0; + for (auto& fusion_group : {first, second}) { + fusion_group.WalkOpNodes([&](const api::OpNode& op) { + if (op.kind() == OpPatternKind::kReduction) { + shared_size += GetSharedSize(op); + } + }); + } + +#define MAX_AVAILABLE_SHREAD 32 * 1024 + if (shared_size > MAX_AVAILABLE_SHREAD) { + return false; + } +#undef MAX_AVAILABLE_SHREAD + return true; + } + + return false; +} + +} // namespace utils +} // namespace pass +} // namespace hlir +} // namespace cinn diff --git a/paddle/cinn/hlir/pass/op_fusion_pass.cc b/paddle/cinn/hlir/pass/op_fusion_pass.cc index 302c9b71b5a9d9912a61f42ad437cdc3f20fa184..84a95dfe277dddeef1851e709566d1443f17d2ff 100644 --- a/paddle/cinn/hlir/pass/op_fusion_pass.cc +++ b/paddle/cinn/hlir/pass/op_fusion_pass.cc @@ -49,7 +49,7 @@ class OpFusionPassHelper : public FusionHelperBase { auto node = graph_node->safe_as(); if (node) { nodes_.push_back(node); - auto group = std::make_shared(); + auto group = std::make_shared(graph); // init group group->nodes.push_back(node); group->nodes_set.insert(node); @@ -101,14 +101,14 @@ class OpFusionPassHelper : public FusionHelperBase { for (auto& consumer : fusion_groups) { for (auto& input_node : consumer->input_nodes) { auto& producer = fusion_groups_[input_node.first]; - consumer->producer_groups.insert(producer); - producer->consumer_groups.insert(consumer); + consumer->mut_producer_groups()->insert(producer); + producer->mut_consumer_groups()->insert(consumer); } } // init group depth. for (auto& group : fusion_groups) { - for (auto& consumer : group->consumer_groups) { + for (const auto& consumer : group->consumer_groups()) { // update depth. 
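+ // The depth computed here is copied into min_depth/max_depth by
+ // GeneralFusionMergePass and used to prune its dependency walks
+ // (see IsDependencySimplify).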
group->depth = std::max(group->depth, consumer->depth + 1); } @@ -376,10 +376,10 @@ void OpFusionPassInternal(Graph* graph) { for (auto& group : graph->fusion_groups) { VLOG(3) << "Group Id : " << group->group_id; - for (auto& producer : group->producer_groups) { + for (const auto& producer : group->producer_groups()) { VLOG(3) << " producer group -> " << producer->group_id; } - for (auto& consumer : group->consumer_groups) { + for (const auto& consumer : group->consumer_groups()) { VLOG(3) << " consumer group -> " << consumer->group_id; } } diff --git a/paddle/cinn/hlir/pass/use_pass.h b/paddle/cinn/hlir/pass/use_pass.h index a9397cc8d281f5085f2e3ad7acf451d3424cce9a..8e95f198447222f493367a4756af5cf6e60c75f0 100644 --- a/paddle/cinn/hlir/pass/use_pass.h +++ b/paddle/cinn/hlir/pass/use_pass.h @@ -25,6 +25,7 @@ CINN_USE_REGISTER(DCE) CINN_USE_REGISTER(DotMerger) CINN_USE_REGISTER(OpFusionPass) CINN_USE_REGISTER(FusionMergePass) +CINN_USE_REGISTER(GeneralFusionMergePass) CINN_USE_REGISTER(CheckFusionAccuracyPass) CINN_USE_REGISTER(CommonSubexpressionEliminationPass) diff --git a/paddle/cinn/runtime/flags.cc b/paddle/cinn/runtime/flags.cc index 3baca1958d182be7081c1e2e5aa56a1f09b95287..6a6f13a741bfc22e97dd382c18755b0dfcccd978 100644 --- a/paddle/cinn/runtime/flags.cc +++ b/paddle/cinn/runtime/flags.cc @@ -57,6 +57,10 @@ DEFINE_bool(cinn_use_op_fusion, BoolFromEnv("FLAGS_cinn_use_op_fusion", true), "Whether to use op fusion pass."); +DEFINE_bool(general_fusion_merge_pass, + BoolFromEnv("FLAGS_general_fusion_merge_pass", true), + "Whether to use general fusion_merge pass."); + DEFINE_bool(cinn_use_common_subexpression_elimination, BoolFromEnv("FLAGS_cinn_use_common_subexpression_elimination", false),