未验证 提交 c80bf368 编写于 作者: Z zyfncg 提交者: GitHub

[CINN] Refactor pass api of group fusion in CINN (#55090)

* new group fuse pass api

* fix header

* update

* change logic of get master node to fix bug

* revert update for ReduceFuseReduce

* modify according review

* modify by review

* refine

* update

* fix code-format
上级 9c5e4b4e
...@@ -433,6 +433,28 @@ function(download_and_uncompress INSTALL_DIR URL FILENAME) ...@@ -433,6 +433,28 @@ function(download_and_uncompress INSTALL_DIR URL FILENAME)
INSTALL_COMMAND "") INSTALL_COMMAND "")
endfunction() endfunction()
# Path of the generated header use_general_pass.h inside the build tree. It is
# stored as an INTERNAL cache entry so that functions called from other
# directories (see find_fusion_pass_register below) can refer to it.
set(fusion_pass_file
${CMAKE_CURRENT_BINARY_DIR}/paddle/cinn/hlir/pass/use_general_pass.h
CACHE INTERNAL "use_general_pass.h file")
# (Re)create the header at configure time with only its #include line.
# find_fusion_pass_register() appends one USE_FUSION_PASS(...) line per
# registration it finds.
file(
WRITE ${fusion_pass_file}
"#include \"paddle/cinn/common/macros.h\" // Generated by the paddle/cinn/hlir/pass/CMakeLists.txt. DO NOT EDIT!\n\n"
)
# Scan the source file FILENAME for registrations of the form
# `PATTERN(<pass_name>,` and append one `USE_FUSION_PASS(<pass_name>);` line
# per match to the file ADD_PATH.
function(find_fusion_pass_register FILENAME ADD_PATH PATTERN)
  file(READ ${FILENAME} source_text)
  # Every match has the shape "PATTERN(pass_name,". A file without any
  # registration yields an empty list, so the loop below does not run.
  string(REGEX MATCHALL "${PATTERN}\\([a-zA-Z0-9_]*," registrations
               "${source_text}")
  foreach(registration ${registrations})
    # Strip the leading "PATTERN(" and the trailing "," to keep the pass name.
    string(REPLACE "${PATTERN}(" "" pass_name "${registration}")
    string(REPLACE "," "" pass_name "${pass_name}")
    file(APPEND ${ADD_PATH} "USE_FUSION_PASS(${pass_name});\n")
  endforeach()
endfunction()
function(gather_srcs SRC_GROUP) function(gather_srcs SRC_GROUP)
set(options) set(options)
set(oneValueArgs) set(oneValueArgs)
...@@ -442,6 +464,8 @@ function(gather_srcs SRC_GROUP) ...@@ -442,6 +464,8 @@ function(gather_srcs SRC_GROUP)
set(${SRC_GROUP} set(${SRC_GROUP}
"${${SRC_GROUP}};${CMAKE_CURRENT_SOURCE_DIR}/${cpp}" "${${SRC_GROUP}};${CMAKE_CURRENT_SOURCE_DIR}/${cpp}"
CACHE INTERNAL "") CACHE INTERNAL "")
find_fusion_pass_register("${CMAKE_CURRENT_SOURCE_DIR}/${cpp}"
${fusion_pass_file} "CINN_REGISTER_FUSION_PASS")
endforeach() endforeach()
endfunction() endfunction()
......
...@@ -2,6 +2,7 @@ if(WITH_TESTING) ...@@ -2,6 +2,7 @@ if(WITH_TESTING)
cinn_cc_library(cinn_gtest_main SRCS gtest_main.cc DEPS gtest gflags) cinn_cc_library(cinn_gtest_main SRCS gtest_main.cc DEPS gtest gflags)
endif() endif()
add_subdirectory(api)
add_subdirectory(auto_schedule) add_subdirectory(auto_schedule)
add_subdirectory(common) add_subdirectory(common)
add_subdirectory(utils) add_subdirectory(utils)
......
# Register this directory's headers with the core target.
# NOTE(review): core_gather_headers() is defined elsewhere; its exact behavior
# is assumed from the name - confirm.
core_gather_headers()
# Add the API sources to the cinnapi_src group. gather_srcs() also scans each
# file for CINN_REGISTER_FUSION_PASS registrations.
gather_srcs(cinnapi_src SRCS op_node.cc tensor_node.cc)
# Print the collected source list at configure time.
message(STATUS "srcs: ${cinnapi_src}")
The classes in this directory form the interface of the group fusion pass. You can use these APIs to build the strategy for group fusion.
The classes and APIs are as follows:
`OpGroup` : A set of op nodes that is passed to the CINN backend to generate kernel code. Two groups can be fused together according to the merging rules written in the passes.
`OpNode` : Maps to an op in the program.
`TensorNode` : Maps to a tensor in the program.
`Shape` : The shape information of a tensor.
`FusePassCtx` : The context passed to the pass as its parameter; it holds all the data you need in the pass.
`FuseHelper` : We provide some util methods such as `DetectCycleIfFuse` in fuse_helper to simplify development of pass.
| Class | method | description |
| :--: | :--: | :--: |
| OpGroup | kind()| Get the Kind of group |
| | producers()| Get producer groups of current group |
| | consumers() | Get consumer groups of current group |
| | WalkOpNodes(const std::function<void(const OpNode&)>& VisitOpNode) | Visit the op_nodes in the group and execute the VisitOpNode function for each OpNode |
| | | |
| OpNode | kind() | Get the Kind of op_node |
| | inputs() | Get input tensors of op_node |
| | outputs() | Get output tensors of op_node |
| | GetAttr(const std::string& attr_name) | Get attribute of op_node by attr name |
| | | |
| TensorNode | shape() | Get shape of tensor |
| | producer() | Get the producer op_node of tensor |
| | consumers() | Get the consumer op_nodes of tensor |
| | | |
| Shape | numel() | Get total number of elements in the shape |
| | Other methods are the same as those of std::vector<int64_t> | |
| | | |
| LightwareFusePassCtx | PickOpGroup() | Get the current group in the pass context |
| | void EnableFuse(const OpGroup& first, const OpGroup& second) | Mark two groups that can be fused together |
| | fuse_helper() | Get the fuse_helper provided by pass context |
| | | |
| InputFusePassCtx | PickConsumersWithSameInputs() | Get all consumer groups for input tensors of graph |
| | void EnableFuse(const OpGroup& first, const OpGroup& second) | Mark two groups that can be fused together |
| | fuse_helper() | Get the fuse_helper provided by pass context |
| | | |
| FuseHelper | DetectCycleIfFuse(const OpGroup& first, const OpGroup& second) | Check whether the graph would contain a cycle after fusing the two groups |
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include "paddle/cinn/api/op_node.h"
#include "paddle/cinn/hlir/framework/graph.h"
#include "paddle/cinn/hlir/pass/fusion_helper_base.h"
namespace cinn {
namespace api {
class OpGroup {
public:
explicit OpGroup(const std::shared_ptr<hlir::framework::Graph::Group>& group)
: group_(group) {}
OpGroup(const OpGroup& other) = default;
using Comparator = hlir::framework::Graph::Group::SharedGroupComparator;
using Hasher = hlir::framework::Graph::Group::SharedGroupHasher;
class OpGroupListIterator {
public:
OpGroupListIterator(
std::unordered_set<std::shared_ptr<hlir::framework::Graph::Group>,
Hasher,
Comparator>::const_iterator it)
: iter_(it) {}
OpGroupListIterator& operator++() {
++iter_;
return *this;
}
OpGroupListIterator operator++(int) {
OpGroupListIterator tmp = *this;
++iter_;
return tmp;
}
bool operator==(const OpGroupListIterator& other) const {
return iter_ == other.iter_;
}
bool operator!=(const OpGroupListIterator& other) const {
return !(*this == other);
}
OpGroup operator*() const { return OpGroup(*iter_); }
private:
std::unordered_set<std::shared_ptr<hlir::framework::Graph::Group>,
Hasher,
Comparator>::const_iterator iter_;
};
class ProducerOpGroupListView {
public:
ProducerOpGroupListView(
const std::weak_ptr<hlir::framework::Graph::Group>& group)
: group_(group) {}
ProducerOpGroupListView(const ProducerOpGroupListView& other) = delete;
ProducerOpGroupListView(ProducerOpGroupListView&& other) = delete;
ProducerOpGroupListView& operator=(const ProducerOpGroupListView& other) =
delete;
using const_iterator = OpGroupListIterator;
size_t size() const {
CHECK(group_.lock());
return group_.lock()->producer_groups().size();
}
const_iterator begin() const {
CHECK(group_.lock());
return const_iterator(group_.lock()->producer_groups().begin());
}
const_iterator end() const {
CHECK(group_.lock());
return const_iterator(group_.lock()->producer_groups().end());
}
private:
const std::weak_ptr<hlir::framework::Graph::Group> group_;
};
class ConsumerOpGroupListView {
public:
ConsumerOpGroupListView(
const std::weak_ptr<hlir::framework::Graph::Group>& group)
: group_(group) {}
ConsumerOpGroupListView(const ConsumerOpGroupListView& other) = delete;
ConsumerOpGroupListView(ConsumerOpGroupListView&& other) = delete;
ConsumerOpGroupListView& operator=(const ConsumerOpGroupListView& other) =
delete;
using const_iterator = OpGroupListIterator;
size_t size() const {
CHECK(group_.lock());
return group_.lock()->consumer_groups().size();
}
const_iterator begin() const {
CHECK(group_.lock());
return const_iterator(group_.lock()->consumer_groups().begin());
}
const_iterator end() const {
CHECK(group_.lock());
return const_iterator(group_.lock()->consumer_groups().end());
}
private:
const std::weak_ptr<hlir::framework::Graph::Group> group_;
};
const std::string& group_id() const { return group_.lock()->group_id; }
hlir::framework::OpPatternKind kind() const { return group_.lock()->kind(); }
// The WalkOpNodes function is used to traverse the op_nodes in the group and
// execute the VisitOpNode function for each OpNode. This function is
// equivalent to for loop for op_nodes in graph.
//
// In order to avoid unnecessary memory copies, we use WalkOpNodes function
// instead of providing a function to get all op_nodes directly.
//
// Example: Get the all Reduction op_nodes in the group.
// OpGroup group = ...;
// std::set<api::OpNode> reduce_ op_set;
// // The lambda funtion of VisitOpNode to get reduction op_nodes.
// auto get_reduce_op = [&reduce_op_set](const api::OpNode& op){
// if (op.kind() == OpPatternKind::kReduction) {
// reduce_op_set.insert(op);
// }
// };
// group.WalkOpNodes(get_reduce_op);
void WalkOpNodes(
const std::function<void(const OpNode&)>& VisitOpNode) const {
group_.lock()->WalkNodes([&](const hlir::framework::Node* node) {
VisitOpNode(OpNode(node, group_.lock()->graph_));
});
}
ProducerOpGroupListView producers() const {
return ProducerOpGroupListView(group_);
}
ConsumerOpGroupListView consumers() const {
return ConsumerOpGroupListView(group_);
}
std::shared_ptr<hlir::framework::Graph::Group> GetGroup() const {
return group_.lock();
}
bool operator==(const OpGroup& other) const {
return group_.lock().get() == other.group_.lock().get();
}
bool operator<(const OpGroup& other) const {
return group_.lock().get() < other.group_.lock().get();
}
private:
const std::weak_ptr<hlir::framework::Graph::Group> group_;
};
} // namespace api
} // namespace cinn
namespace std {

// Hash support so that OpGroup can be used as a key in unordered containers.
// The hash value is derived from the address of the group the handle refers
// to.
template <>
struct hash<cinn::api::OpGroup> {
  size_t operator()(const cinn::api::OpGroup& obj) const {
    const auto group = obj.GetGroup();
    const size_t address = reinterpret_cast<size_t>(group.get());
    return std::hash<size_t>()(address);
  }
};

}  // namespace std
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/api/op_node.h"
namespace cinn {
namespace api {
// Resolves the edge the iterator currently points at to its tensor, using the
// edge-to-tensor function supplied at construction, and wraps it in a
// TensorNode.
TensorNode OpNode::TensorListIterator::operator*() const {
  hlir::framework::NodeData* tensor = get_tensor_from_edge_(*iter_);
  return TensorNode(tensor, graph_);
}
// Returns the index-th input tensor, i.e. the source end of the index-th
// in-edge. The index is not range-checked.
TensorNode OpNode::InputTensorListView::operator[](size_t index) const {
  const auto& in_edge = edges_[index];
  const hlir::framework::NodeData* tensor =
      in_edge->source()->safe_as<hlir::framework::NodeData>();
  return TensorNode(tensor, graph_);
}
// Returns the index-th output tensor, i.e. the sink end of the index-th
// out-edge. The index is not range-checked.
TensorNode OpNode::OutputTensorListView::operator[](size_t index) const {
  const auto& out_edge = edges_[index];
  const hlir::framework::NodeData* tensor =
      out_edge->sink()->safe_as<hlir::framework::NodeData>();
  return TensorNode(tensor, graph_);
}
} // namespace api
} // namespace cinn
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include "paddle/cinn/api/tensor_node.h"
#include "paddle/cinn/hlir/framework/graph.h"
#include "paddle/cinn/hlir/framework/op.h"
#include "paddle/cinn/hlir/pass/fusion_helper_base.h"
namespace cinn {
namespace api {
class OpNode {
public:
OpNode(const hlir::framework::Node* node, const hlir::framework::Graph* graph)
: node_(node),
graph_(graph),
input_tensors_(node->inlinks_in_order(), graph_),
output_tensors_(node->outlinks_in_order(), graph_) {}
OpNode(const OpNode& other)
: node_(other.node_),
graph_(other.graph_),
input_tensors_(node_->inlinks_in_order(), graph_),
output_tensors_(node_->outlinks_in_order(), graph_) {}
using OpPatternKind = cinn::hlir::framework::OpPatternKind;
OpPatternKind kind() const {
static const hlir::framework::OpValueType<OpPatternKind>& op_pattern_dict =
hlir::framework::Operator::GetAttrs<OpPatternKind>("OpPattern");
auto kind = op_pattern_dict[node_->op()];
if (kind == hlir::framework::kBroadcast) {
// As binary op was defined as broadcast, actually it should be
// element-wise.
if (node_->op()->name != "broadcast_to") {
return hlir::framework::kElementWise;
}
}
return kind;
}
class TensorListIterator {
public:
TensorListIterator(
std::vector<common::Shared<common::GraphEdge>>::const_iterator it,
const hlir::framework::Graph* graph,
std::function<hlir::framework::NodeData*(
common::Shared<common::GraphEdge>)> get_tensor_from_edge)
: iter_(it),
graph_(graph),
get_tensor_from_edge_(get_tensor_from_edge) {}
TensorListIterator& operator++() {
++iter_;
return *this;
}
TensorListIterator operator++(int) {
TensorListIterator tmp = *this;
++iter_;
return tmp;
}
bool operator==(const TensorListIterator& other) const {
return iter_ == other.iter_;
}
bool operator!=(const TensorListIterator& other) const {
return !(*this == other);
}
TensorNode operator*() const;
private:
std::vector<common::Shared<common::GraphEdge>>::const_iterator iter_;
const hlir::framework::Graph* graph_;
std::function<hlir::framework::NodeData*(common::Shared<common::GraphEdge>)>
get_tensor_from_edge_;
};
using const_iterator = TensorListIterator;
class InputTensorListView {
public:
InputTensorListView(
const std::vector<common::Shared<common::GraphEdge>>& edges,
const hlir::framework::Graph* graph)
: edges_(edges), graph_(graph) {}
InputTensorListView(const InputTensorListView& other) = delete;
InputTensorListView(InputTensorListView&& other) = delete;
InputTensorListView& operator=(const InputTensorListView& other) = delete;
size_t size() const { return edges_.size(); }
TensorNode operator[](size_t index) const;
const_iterator begin() const {
return const_iterator(
edges_.begin(), graph_, [](common::Shared<common::GraphEdge> edge) {
return edge->source()->safe_as<hlir::framework::NodeData>();
});
}
const_iterator end() const {
return const_iterator(
edges_.end(), graph_, [](common::Shared<common::GraphEdge> edge) {
return edge->source()->safe_as<hlir::framework::NodeData>();
});
}
private:
std::vector<common::Shared<common::GraphEdge>> edges_;
const hlir::framework::Graph* graph_;
};
class OutputTensorListView {
public:
OutputTensorListView(
const std::vector<common::Shared<common::GraphEdge>>& edges,
const hlir::framework::Graph* graph)
: edges_(edges), graph_(graph) {}
OutputTensorListView(const OutputTensorListView& other) = delete;
OutputTensorListView(OutputTensorListView&& other) = delete;
OutputTensorListView& operator=(const OutputTensorListView& other) = delete;
size_t size() const { return edges_.size(); }
TensorNode operator[](size_t index) const;
const_iterator begin() const {
return const_iterator(
edges_.begin(), graph_, [](common::Shared<common::GraphEdge> edge) {
return edge->sink()->safe_as<hlir::framework::NodeData>();
});
}
const_iterator end() const {
return const_iterator(
edges_.end(), graph_, [](common::Shared<common::GraphEdge> edge) {
return edge->sink()->safe_as<hlir::framework::NodeData>();
});
}
private:
std::vector<common::Shared<common::GraphEdge>> edges_;
const hlir::framework::Graph* graph_;
};
bool operator==(const OpNode& other) const { return node_ == other.node_; }
bool operator<(const OpNode& other) const { return node_ < other.node_; }
const InputTensorListView& inputs() const { return input_tensors_; }
const OutputTensorListView& outputs() const { return output_tensors_; }
template <typename T>
const T& GetAttr(const std::string& attr_name) const {
return absl::get<T>(GetAttr(attr_name));
}
private:
using Attribute = cinn::utils::Attribute;
const Attribute& GetAttr(const std::string& attr_name) const {
return node_->attrs.attr_store.at(attr_name);
}
friend struct std::hash<OpNode>;
const hlir::framework::Node* node_;
const hlir::framework::Graph* graph_;
const InputTensorListView input_tensors_;
const OutputTensorListView output_tensors_;
};
} // namespace api
} // namespace cinn
namespace std {

// Hash support so that OpNode can be used as a key in unordered containers.
// The hash value is derived from the address of the wrapped node, which is
// also what OpNode::operator== compares.
template <>
struct hash<cinn::api::OpNode> {
  size_t operator()(const cinn::api::OpNode& obj) const {
    const size_t address = reinterpret_cast<size_t>(obj.node_);
    return std::hash<size_t>()(address);
  }
};

}  // namespace std
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include "paddle/cinn/hlir/framework/graph.h"
#include "paddle/cinn/hlir/pass/fusion_helper_base.h"
#include "paddle/cinn/utils/small_vector.h"
#include "paddle/cinn/utils/type_defs.h"
namespace cinn {
namespace api {
// Shape is an immutable copy of the dimensions of a tensor.
//
// The dimensions are copied from a utils::ShapeType at construction and
// stored as int64_t values. A Shape is neither copyable nor movable.
class Shape final {
 public:
  explicit Shape(const utils::ShapeType& shape)
      : shape_(shape.begin(), shape.end()) {}

  Shape(const Shape& other) = delete;
  Shape(Shape&& other) = delete;

  Shape& operator=(const Shape& other) = delete;

  // Two shapes are equal when they hold the same dimensions in the same order.
  bool operator==(const Shape& other) const { return shape_ == other.shape_; }

  // Returns the extent of dimension `index`, without range checking.
  size_t operator[](size_t index) const { return shape_[index]; }

  // Returns the extent of dimension `index`.
  // NOTE(review): unlike std::vector::at, this performs no range check.
  size_t at(size_t index) const { return shape_[index]; }

  // Returns the number of dimensions.
  size_t size() const { return shape_.size(); }

  // Returns the total number of elements in the shape, i.e. the product of
  // all dimensions (1 for an empty shape).
  //
  // The product is accumulated in size_t. Accumulating in int would truncate
  // every int64_t dimension and overflow once the product exceeds INT_MAX.
  size_t numel() const {
    return std::accumulate(shape_.begin(),
                           shape_.end(),
                           static_cast<size_t>(1),
                           std::multiplies<size_t>());
  }

 private:
  cinn::utils::SmallVector<int64_t, 12> shape_;
};
} // namespace api
} // namespace cinn
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/api/tensor_node.h"
#include "paddle/cinn/api/op_node.h"
namespace cinn {
namespace api {
// Returns the op that produces this tensor, i.e. the tensor's source node.
// Callers can use HasProducer() first: the source node may be null.
OpNode TensorNode::producer() const {
  const hlir::framework::Node* producer_op = node_data_->source_node.get();
  return OpNode(producer_op, graph_);
}
// Returns the consumer op at the current position: the sink end of the
// out-edge the iterator points at.
OpNode TensorNode::ConsumerOpListView::Iterator::operator*() const {
  const auto& out_edge = *iter_;
  const hlir::framework::Node* consumer_op =
      out_edge->sink()->safe_as<hlir::framework::Node>();
  return OpNode(consumer_op, graph_);
}
} // namespace api
} // namespace cinn
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include "paddle/cinn/api/shape.h"
#include "paddle/cinn/hlir/framework/graph.h"
#include "paddle/cinn/hlir/pass/fusion_helper_base.h"
#include "paddle/cinn/utils/small_vector.h"
#include "paddle/cinn/utils/type_defs.h"
namespace cinn {
namespace api {
class OpNode;
class TensorNode final {
public:
TensorNode(const hlir::framework::NodeData* node_data,
const hlir::framework::Graph* graph)
: node_data_(node_data),
graph_(graph),
consumers_(node_data_->outlinks(), graph_) {
const auto& shape_dict =
graph_->GetAttrs<absl::flat_hash_map<std::string, utils::ShapeType>>(
"infershape");
CHECK(shape_dict.count(node_data_->id()))
<< "Can't find " << node_data_->id() << " 's shape!";
shape_ = std::make_shared<Shape>(shape_dict.find(node_data_->id())->second);
}
// Get the shape of tensor.
const Shape& shape() const { return *shape_; }
// Input data has no producer.
bool HasProducer() const { return node_data_->source_node.get() != nullptr; }
OpNode producer() const;
class ConsumerOpListView {
public:
ConsumerOpListView(const std::set<common::Shared<common::GraphEdge>,
common::GraphEdgeCompare>& edges,
const hlir::framework::Graph* graph)
: edges_(edges), graph_(graph) {}
ConsumerOpListView(const ConsumerOpListView& other) = delete;
ConsumerOpListView(ConsumerOpListView&& other) = delete;
ConsumerOpListView& operator=(const ConsumerOpListView& other) = delete;
class Iterator {
public:
Iterator(std::set<common::Shared<common::GraphEdge>,
common::GraphEdgeCompare>::const_iterator it,
const hlir::framework::Graph* graph)
: iter_(it), graph_(graph) {}
Iterator& operator++() {
++iter_;
return *this;
}
Iterator operator++(int) {
Iterator tmp = *this;
++iter_;
return tmp;
}
bool operator==(const Iterator& other) const {
return iter_ == other.iter_;
}
bool operator!=(const Iterator& other) const { return !(*this == other); }
OpNode operator*() const;
private:
std::set<common::Shared<common::GraphEdge>,
common::GraphEdgeCompare>::const_iterator iter_;
const hlir::framework::Graph* graph_;
};
size_t size() const { return edges_.size(); }
Iterator begin() const { return Iterator(this->edges_.begin(), graph_); }
Iterator end() const { return Iterator(this->edges_.end(), graph_); }
private:
const std::set<Shared<common::GraphEdge>, common::GraphEdgeCompare>& edges_;
const hlir::framework::Graph* graph_;
};
const ConsumerOpListView& consumers() const { return consumers_; }
private:
const hlir::framework::NodeData* node_data_;
const hlir::framework::Graph* graph_;
std::shared_ptr<Shape> shape_;
const ConsumerOpListView consumers_;
};
} // namespace api
} // namespace cinn
...@@ -23,6 +23,10 @@ gather_srcs( ...@@ -23,6 +23,10 @@ gather_srcs(
message(STATUS "srcs: ${cinnapi_src}") message(STATUS "srcs: ${cinnapi_src}")
# Unit tests for the graph-walker utilities (dfs_walker,
# is_reachable_predicator, topo_walker). They depend only on gtest and glog.
cinn_cc_test(test_dfs_walker SRCS dfs_walker_test.cc DEPS gtest glog)
cinn_cc_test(test_is_reachable_predicator SRCS is_reachable_predicator_test.cc
DEPS gtest glog)
cinn_cc_test(test_topo_walker SRCS topo_walker_test.cc DEPS gtest glog)
cinn_cc_test(test_cinn_value SRCS cinn_value_test.cc DEPS cinncore) cinn_cc_test(test_cinn_value SRCS cinn_value_test.cc DEPS cinncore)
cinn_cc_test(test_shared SRCS shared_test.cc DEPS cinncore) cinn_cc_test(test_shared SRCS shared_test.cc DEPS cinncore)
cinn_cc_test(test_graph_utils SRCS graph_utils_test.cc DEPS cinncore) cinn_cc_test(test_graph_utils SRCS graph_utils_test.cc DEPS cinncore)
......
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <array>
#include <functional>
#include <queue>
#include <unordered_set>
namespace cinn {
namespace common {
// Breadth-first traversal over an implicit graph.
//
// The graph is described by `VisitNextNodes`: called with a node and a
// callback, it must invoke the callback once for every successor of that
// node. NodeType must be hashable and equality comparable, because visited
// nodes are tracked in a std::unordered_set.
template <typename NodeType>
class BfsWalker final {
 public:
  BfsWalker(const BfsWalker&) = delete;
  BfsWalker(BfsWalker&&) = delete;

  using NodeHandlerType = std::function<void(NodeType)>;
  using NodesVisitorType =
      std::function<void(NodeType, const NodeHandlerType&)>;

  BfsWalker(const NodesVisitorType& VisitNextNodes)
      : VisitNextNodes_(VisitNextNodes) {}

  // Walks the graph starting from the single node `node`.
  void operator()(NodeType node, const NodeHandlerType& NodeHandler) const {
    std::array<NodeType, 1> single_source{node};
    (*this)(single_source.begin(), single_source.end(), NodeHandler);
  }

  // Walks the graph starting from every node in [begin, end). `NodeHandler`
  // is called exactly once per reachable node, in breadth-first order.
  template <typename NodeIt>
  void operator()(NodeIt begin,
                  NodeIt end,
                  const NodeHandlerType& NodeHandler) const {
    std::queue<NodeType> pending;
    std::unordered_set<NodeType> seen;
    // Queues a node only the first time it is encountered.
    const NodeHandlerType enqueue_if_new = [&pending, &seen](NodeType node) {
      if (seen.insert(node).second) {
        pending.push(node);
      }
    };
    for (NodeIt source = begin; source != end; ++source) {
      enqueue_if_new(*source);
    }
    while (!pending.empty()) {
      NodeType current = pending.front();
      pending.pop();
      NodeHandler(current);
      VisitNextNodes_(current, enqueue_if_new);
    }
  }

 private:
  NodesVisitorType VisitNextNodes_;
};
} // namespace common
} // namespace cinn
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <array>
#include <functional>
#include <iostream>
#include <queue>
#include <stack>
#include <unordered_set>
namespace cinn {
namespace common {
// Depth-first traversal over an implicit graph.
//
// The graph is described by `VisitNextNodes`: called with a node and a
// callback, it must invoke the callback once for every successor of that
// node. NodeType must be hashable and equality comparable, because discovered
// nodes are tracked in a std::unordered_set.
template <typename NodeType>
class DfsWalker final {
 public:
  DfsWalker(const DfsWalker&) = delete;
  DfsWalker(DfsWalker&&) = delete;

  using NodeHandlerType = std::function<void(NodeType)>;
  using NodesVisitorType =
      std::function<void(NodeType, const NodeHandlerType&)>;

  DfsWalker(const NodesVisitorType& VisitNextNodes)
      : VisitNextNodes_(VisitNextNodes) {}

  // Walks the graph from the single node `node`, calling `NodeHandler` when a
  // node is first discovered.
  void operator()(NodeType node, const NodeHandlerType& NodeHandler) const {
    std::array<NodeType, 1> nodes{node};
    (*this)(nodes.begin(), nodes.end(), NodeHandler, [](NodeType) {});
  }

  // Walks the graph from every node in [begin, end), calling `NodeHandler`
  // when a node is first discovered.
  template <typename NodeIt>
  void operator()(NodeIt begin,
                  NodeIt end,
                  const NodeHandlerType& NodeHandler) const {
    (*this)(begin, end, NodeHandler, [](NodeType) {});
  }

  // Iterative depth-first search, see
  // https://en.wikipedia.org/wiki/Depth-first_search
  //
  // `NodeHandlerOnPush` is called when a node is first discovered (pre-order).
  // `NodeHandlerOnPop` is called once all successors of a node have been
  // processed (post-order). Each reachable node is reported exactly once to
  // each handler.
  template <typename NodeIt>
  void operator()(NodeIt begin,
                  NodeIt end,
                  const NodeHandlerType& NodeHandlerOnPush,
                  const NodeHandlerType& NodeHandlerOnPop) const {
    std::unordered_set<NodeType> discovered;
    // One stack frame: a node plus its successors not yet examined.
    struct Neighbours {
      NodeType producer;
      std::queue<NodeType> consumers;
    };
    std::stack<Neighbours> stack;
    const auto& TryPush = [&](NodeType node) {
      if (!discovered.insert(node).second) {
        return;  // Already discovered earlier in the walk.
      }
      NodeHandlerOnPush(node);
      // Plain aggregate initialization: designated initializers are only
      // standard from C++20 on.
      stack.push(Neighbours{node, std::queue<NodeType>()});
      VisitNextNodes_(node, [&](NodeType next_node) {
        stack.top().consumers.push(next_node);
      });
    };
    for (NodeIt node_iter = begin; node_iter != end; ++node_iter) {
      TryPush(*node_iter);
      while (!stack.empty()) {
        Neighbours& frame = stack.top();
        if (frame.consumers.empty()) {
          NodeHandlerOnPop(frame.producer);
          stack.pop();
          continue;
        }
        // Take the successor out of the frame before pushing, so that `frame`
        // is not touched after the stack has grown.
        NodeType next_node = frame.consumers.front();
        frame.consumers.pop();
        TryPush(next_node);
      }
    }
  }

 private:
  NodesVisitorType VisitNextNodes_;
};
} // namespace common
} // namespace cinn
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/common/dfs_walker.h"
#include <glog/logging.h>
#include <gtest/gtest.h>
namespace cinn {
namespace common {
// Checks the order in which DfsWalker reports nodes when they are first
// discovered. Graph edges: 0->3, 1->2, 1->3, 2->4, 3->4; the walk starts from
// the nodes {0, 1}.
TEST(DfsWalker, simple_on_push) {
  const auto visit_successors =
      [](int node, const std::function<void(int)>& NodeHandler) {
        switch (node) {
          case 0:
            NodeHandler(3);
            break;
          case 1:
            NodeHandler(2);
            NodeHandler(3);
            break;
          case 2:
          case 3:
            NodeHandler(4);
            break;
          default:
            break;
        }
      };
  DfsWalker<int> walker(visit_successors);
  std::vector<int> start_nodes{0, 1};
  std::vector<int> discovered_order;
  walker(start_nodes.begin(), start_nodes.end(), [&](int node) {
    LOG(ERROR) << node;
    discovered_order.push_back(node);
  });
  std::vector<int> expected{0, 3, 4, 1, 2};
  EXPECT_TRUE((discovered_order == expected));
}
// Checks the order in which DfsWalker reports nodes once all of their
// successors have been processed. Graph edges: 0->3, 1->2, 1->3, 2->4, 3->4;
// the walk starts from the nodes {0, 1}.
TEST(DfsWalker, simple_on_pop) {
  const auto visit_successors =
      [](int node, const std::function<void(int)>& NodeHandler) {
        switch (node) {
          case 0:
            NodeHandler(3);
            break;
          case 1:
            NodeHandler(2);
            NodeHandler(3);
            break;
          case 2:
          case 3:
            NodeHandler(4);
            break;
          default:
            break;
        }
      };
  DfsWalker<int> walker(visit_successors);
  std::vector<int> start_nodes{0, 1};
  std::vector<int> finished_order;
  const auto ignore_push = [](int) {};
  const auto record_pop = [&](int node) {
    LOG(ERROR) << node;
    finished_order.push_back(node);
  };
  walker(start_nodes.begin(), start_nodes.end(), ignore_push, record_pop);
  std::vector<int> expected{4, 3, 0, 2, 1};
  EXPECT_TRUE((finished_order == expected));
}
} // namespace common
} // namespace cinn
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <array>
#include <functional>
#include "paddle/cinn/common/bfs_walker.h"
namespace cinn {
namespace common {
// Decides whether `dst` can be reached from `src` by a breadth-first search
// that is pruned with per-node depth bounds.
//
// The graph is described by `VisitNextNodes` (invokes its callback for every
// successor of a node). `MinDepth4Node` and `MaxDepth4Node` return depth
// bounds of a node: a successor whose minimum depth is greater than the
// maximum depth of `dst` is never enqueued.
template <typename NodeType>
class IsReachablePredicator final {
 public:
  IsReachablePredicator(const IsReachablePredicator&) = delete;
  IsReachablePredicator(IsReachablePredicator&&) = delete;

  using NodeHandlerType = std::function<void(NodeType)>;
  using NodesVisitorType =
      std::function<void(NodeType, const NodeHandlerType&)>;
  using NodeDepthGetterType = std::function<size_t(NodeType)>;

  IsReachablePredicator(const NodeDepthGetterType& MinDepth4Node,
                        const NodeDepthGetterType& MaxDepth4Node,
                        const NodesVisitorType& VisitNextNodes)
      : MinDepth4Node_(MinDepth4Node),
        MaxDepth4Node_(MaxDepth4Node),
        VisitNextNodes_(VisitNextNodes) {}

  // Returns true if `dst` is visited by the search that starts at `src`.
  // `HandleVisited` is called for every node the search visits, `src`
  // included. Once `dst` has been visited no further nodes are enqueued, but
  // nodes that were already queued are still visited.
  bool operator()(NodeType src,
                  NodeType dst,
                  const NodeHandlerType& HandleVisited) const {
    const size_t depth_limit = MaxDepth4Node_(dst);
    bool reached = false;
    // Successor expansion with pruning: skip nodes that lie deeper than dst
    // and stop growing the frontier once dst has been found.
    const auto expand = [&](NodeType node, const NodeHandlerType& Enqueue) {
      VisitNextNodes_(node, [&](NodeType successor) {
        const bool deeper_than_dst = depth_limit < MinDepth4Node_(successor);
        if (deeper_than_dst || reached) {
          return;
        }
        Enqueue(successor);
      });
    };
    BfsWalker<NodeType> walker(expand);
    const auto on_visit = [&](NodeType node) {
      HandleVisited(node);
      if (node == dst) {
        reached = true;
      }
    };
    std::array<NodeType, 1> sources{src};
    walker(sources.begin(), sources.end(), on_visit);
    return reached;
  }

 private:
  NodeDepthGetterType MinDepth4Node_;
  NodeDepthGetterType MaxDepth4Node_;
  NodesVisitorType VisitNextNodes_;
};
} // namespace common
} // namespace cinn
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/common/is_reachable_predicator.h"
#include <glog/logging.h>
#include <gtest/gtest.h>
namespace cinn {
namespace common {
TEST(IsReachablePredicator, simple) {
  // Nodes are non-zero integers. A node's min and max depth are both its
  // absolute value, and its only successor is one step further from zero.
  const auto DepthOf = [](int x) { return std::abs(x); };
  const auto StepAwayFromZero = [](int x,
                                   const std::function<void(int)>& Handler) {
    const int direction = x / std::abs(x);
    Handler(x + direction);
  };
  IsReachablePredicator<int> IsReachable(
      DepthOf, DepthOf, StepAwayFromZero);

  const auto IgnoreVisited = [](int) {};
  // 33 -> 34 -> ... -> 99 is a valid path.
  EXPECT_TRUE(IsReachable(33, 99, IgnoreVisited));
  // Positive nodes only ever step to larger positive nodes.
  EXPECT_FALSE(IsReachable(33, -99, IgnoreVisited));
}
} // namespace common
} // namespace cinn
...@@ -50,3 +50,38 @@ ...@@ -50,3 +50,38 @@
#else #else
#define CINN_NODISCARD #define CINN_NODISCARD
#endif #endif
// Deletes the copy/move constructors and copy/move assignment operators of
// `classname`. Use it inside the class body and terminate it with `;`.
// The expansion starts a `private:` section, so members declared after it are
// private unless another access specifier follows.
#define DISABLE_COPY_AND_ASSIGN(classname)         \
 private:                                          \
  classname(const classname&) = delete;            \
  classname& operator=(const classname&) = delete; \
  classname(classname&&) = delete;                 \
  classname& operator=(classname&&) = delete
/**
 * Compile-time check that the enclosing macro is expanded in the global
 * namespace.
 *
 * Declares an empty struct `__test_global_namespace_<uniq_name>__` in the
 * current scope and static_asserts that it is the same type as the one found
 * through the `::`-qualified (global) name. Expanded anywhere other than the
 * global namespace, the two names do not denote the same type (or the
 * qualified name is not found at all), so compilation fails; `msg` is the
 * static_assert diagnostic.
 *
 * `uniq_name` must be unique per translation unit because it becomes part of
 * a type name. No trailing semicolon is emitted; the caller supplies it.
 *
 * NOTE(review): std::is_same needs <type_traits>; the include is not visible
 * in this hunk -- confirm it is provided by this header.
 */
#define STATIC_ASSERT_GLOBAL_NAMESPACE(uniq_name, msg) \
struct __test_global_namespace_##uniq_name##__ {}; \
static_assert(std::is_same<::__test_global_namespace_##uniq_name##__, \
__test_global_namespace_##uniq_name##__>::value, \
msg)
/**
 * Registers the fusion pass class `pass_class` under the name `pass_name`.
 *
 * The expansion
 *  - asserts that it is used in the global namespace,
 *  - defines a file-static
 *    `::cinn::hlir::pass::FusionPassRegistrar<pass_class>` object constructed
 *    with the stringified `pass_name`, and
 *  - defines `int TouchFusionPassRegistrar_<pass_name>()`, which calls
 *    `Touch()` on that object and returns 0. USE_FUSION_PASS declares and
 *    calls this function from other translation units.
 *
 * The build scans source files for the text
 * `CINN_REGISTER_FUSION_PASS(<pass_name>,` to generate use_general_pass.h, so
 * the pass name must directly follow the opening parenthesis and be directly
 * followed by a comma.
 *
 * The function body ends the expansion, so no trailing semicolon is needed.
 */
#define CINN_REGISTER_FUSION_PASS(pass_name, pass_class) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_pass__##pass_name, \
"CINN_REGISTER_FUSION_PASS must be called in global namespace"); \
static ::cinn::hlir::pass::FusionPassRegistrar<pass_class> \
__pass_registrar_##pass_name##__(#pass_name); \
int TouchFusionPassRegistrar_##pass_name() { \
__pass_registrar_##pass_name##__.Touch(); \
return 0; \
}
// Declares `int TouchFusionPassRegistrar_<pass_name>()` (defined by
// CINN_REGISTER_FUSION_PASS for the same name) and calls it from the
// initializer of a file-static int. Must be used in the global namespace and
// terminated with `;`.
// The diagnostic below names this macro so that a misuse points the reader at
// the right place (it previously referred to an unrelated macro name).
#define USE_FUSION_PASS(pass_name)                               \
  STATIC_ASSERT_GLOBAL_NAMESPACE(                                \
      __use_fusion_pass_##pass_name,                             \
      "USE_FUSION_PASS must be called in global namespace");     \
  extern int TouchFusionPassRegistrar_##pass_name();             \
  [[maybe_unused]] static int __use_fusion_pass_##pass_name##_ = \
      TouchFusionPassRegistrar_##pass_name()
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <glog/logging.h>
#include <functional>
#include <list>
#include <queue>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/cinn/common/dfs_walker.h"
namespace cinn {
namespace common {
// Strongly connected components (SCC) visitor.
//
// Constructed from two edge enumerators: `VisitPrevNodes` reports the
// predecessors of a node and `VisitNextNodes` its successors. Calling the
// walker with a range of start nodes invokes `SccHandler` once per strongly
// connected component found among the nodes reachable from those starts.
template <typename NodeType>
class SccWalker final {
 public:
  SccWalker(const SccWalker&) = delete;
  SccWalker(SccWalker&&) = delete;

  using NodeHandlerType = std::function<void(NodeType)>;
  using NodesVisitorType =
      std::function<void(NodeType, const NodeHandlerType&)>;

  SccWalker(const NodesVisitorType& VisitPrevNodes,
            const NodesVisitorType& VisitNextNodes)
      : VisitPrevNodes_(VisitPrevNodes), VisitNextNodes_(VisitNextNodes) {}

  using SccHandlerType = std::function<void(const std::vector<NodeType>&)>;

  // https://en.wikipedia.org/wiki/Kosaraju%27s_algorithm
  //
  // begin/end:  range of start nodes.
  // SccHandler: receives the nodes of one component per call. Components are
  //             emitted in the order their roots appear in the pass-1
  //             ordering.
  template <typename NodeIt>
  void operator()(NodeIt begin,
                  NodeIt end,
                  const SccHandlerType& SccHandler) const {
    // Pass 1: forward DFS from the start nodes. Each node is prepended when
    // it is popped, so the list ends up with the last-popped node first.
    const std::list<NodeType>& dfs_ordered_nodes = [&]() {
      std::list<NodeType> dfs_ordered_nodes;
      DfsVisitor<NodeType> visitor(VisitNextNodes_);
      visitor(
          begin,
          end,
          /*on push*/ [](NodeType) {},
          /*on pop*/
          [&](NodeType node) { dfs_ordered_nodes.push_front(node); });
      return dfs_ordered_nodes;
    }();

    // Pass 2 must stay inside the set of nodes seen by pass 1. A predecessor
    // that was never reached from the start nodes cannot be strongly
    // connected to a reached node; without this filter it would be wrongly
    // merged into that node's component whenever the start nodes do not
    // cover the whole graph.
    const std::unordered_set<NodeType> reachable_nodes(
        dfs_ordered_nodes.begin(), dfs_ordered_nodes.end());

    // Maps every node already assigned to a component to that component's
    // root.
    std::unordered_map<NodeType, NodeType> node2root;
    const auto& VisitPrevNode = [&](NodeType node,
                                    const NodeHandlerType& NodeHandler) {
      VisitPrevNodes_(node, [&](NodeType prev_node) {
        const bool reachable = reachable_nodes.count(prev_node) > 0;
        const bool assigned = node2root.count(prev_node) > 0;
        if (reachable && !assigned) {
          NodeHandler(prev_node);
        }
      });
    };

    // Pass 2: backward DFS from each not-yet-assigned node, in pass-1 order.
    for (NodeType root : dfs_ordered_nodes) {
      if (node2root.count(root) > 0) {
        continue;
      }
      std::vector<NodeType> scc;
      // node2root is only read inside the DFS visitor.
      DfsVisitor<NodeType> visitor(VisitPrevNode);
      visitor(root, [&](NodeType node) { scc.push_back(node); });
      SccHandler(scc);
      // node2root is updated only after the DFS has finished.
      for (NodeType node : scc) {
        CHECK(node2root.emplace(node, root).second);
      }
    }
  }

 private:
  NodesVisitorType VisitPrevNodes_;
  NodesVisitorType VisitNextNodes_;
};
} // namespace common
} // namespace cinn
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/common/scc_walker.h"
#include <glog/logging.h>
#include <gtest/gtest.h>
namespace cinn {
namespace common {
// Acyclic graph: every strongly connected component is a single node.
TEST(SccWalker, trivial) {
  std::list<std::pair<int, int>> edges{{0, 3}, {1, 2}, {1, 3}, {2, 4}, {3, 4}};
  SccWalker<int> visitor(
      // Predecessors: sources of the edges that end at `node`.
      [&](int node, const std::function<void(int)>& NodeHandler) {
        for (const auto& pair : edges) {
          if (pair.second == node) {
            NodeHandler(pair.first);
          }
        }
      },
      // Successors: targets of the edges that start at `node`.
      [&](int node, const std::function<void(int)>& NodeHandler) {
        for (const auto& pair : edges) {
          if (pair.first == node) {
            NodeHandler(pair.second);
          }
        }
      });
  std::vector<int> sources{0, 1};
  std::vector<std::vector<int>> outputs;
  visitor(sources.begin(), sources.end(), [&](const auto& nodes) {
    outputs.push_back(nodes);
  });
  std::vector<std::vector<int>> expected{{1}, {2}, {0}, {3}, {4}};
  // EXPECT_EQ prints both containers when the comparison fails.
  EXPECT_EQ(outputs, expected);
}
// A single cycle 0 -> 1 -> 2 -> 3 -> 4 -> 0: all nodes form one component.
TEST(SccWalker, circle) {
  std::list<std::pair<int, int>> edges{
      {0, 1},
      {1, 2},
      {2, 3},
      {3, 4},
      {4, 0},
  };
  SccWalker<int> visitor(
      // Predecessors: sources of the edges that end at `node`.
      [&](int node, const std::function<void(int)>& NodeHandler) {
        for (const auto& pair : edges) {
          if (pair.second == node) {
            NodeHandler(pair.first);
          }
        }
      },
      // Successors: targets of the edges that start at `node`.
      [&](int node, const std::function<void(int)>& NodeHandler) {
        for (const auto& pair : edges) {
          if (pair.first == node) {
            NodeHandler(pair.second);
          }
        }
      });
  std::vector<int> sources{0};
  std::vector<std::vector<int>> outputs;
  visitor(sources.begin(), sources.end(), [&](const auto& nodes) {
    outputs.push_back(nodes);
  });
  std::vector<std::vector<int>> expected{{0, 4, 3, 2, 1}};
  // EXPECT_EQ prints both containers when the comparison fails.
  EXPECT_EQ(outputs, expected);
}
// Two 2-cycles (0 <-> 1 and 2 <-> 3) joined by the one-way edge 1 -> 2:
// two components are expected.
TEST(SccWalker, double_circle) {
  std::list<std::pair<int, int>> edges{
      {0, 1},
      {1, 0},
      {1, 2},
      {2, 3},
      {3, 2},
  };
  SccWalker<int> visitor(
      // Predecessors: sources of the edges that end at `node`.
      [&](int node, const std::function<void(int)>& NodeHandler) {
        for (const auto& pair : edges) {
          if (pair.second == node) {
            NodeHandler(pair.first);
          }
        }
      },
      // Successors: targets of the edges that start at `node`.
      [&](int node, const std::function<void(int)>& NodeHandler) {
        for (const auto& pair : edges) {
          if (pair.first == node) {
            NodeHandler(pair.second);
          }
        }
      });
  std::vector<int> sources{0};
  std::vector<std::vector<int>> outputs;
  visitor(sources.begin(), sources.end(), [&](const auto& nodes) {
    outputs.push_back(nodes);
  });
  std::vector<std::vector<int>> expected{{0, 1}, {2, 3}};
  // EXPECT_EQ prints both containers when the comparison fails.
  EXPECT_EQ(outputs, expected);
}
} // namespace common
} // namespace cinn
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <array>
#include <functional>
#include <queue>
#include <unordered_set>
namespace cinn {
namespace common {
// Walks a graph in topological order.
//
// The graph is described by two callbacks: `VisitPrevNodes` enumerates the
// predecessors of a node and `VisitNextNodes` its successors. Start nodes are
// always visited; any other node is visited only after all of its
// predecessors have been scheduled.
template <typename NodeType>
class TopoWalker final {
 public:
  using NodeHandlerType = std::function<void(NodeType)>;
  using NodesVisitorType =
      std::function<void(NodeType, const NodeHandlerType&)>;

  TopoWalker(const NodesVisitorType& VisitPrevNodes,
             const NodesVisitorType& VisitNextNodes)
      : VisitPrevNodes_(VisitPrevNodes), VisitNextNodes_(VisitNextNodes) {}

  TopoWalker(const TopoWalker&) = delete;
  TopoWalker(TopoWalker&&) = delete;

  // Convenience overload: walk from a single start node.
  void operator()(NodeType node, const NodeHandlerType& NodeHandler) const {
    std::array<NodeType, 1> single_start{node};
    (*this)(single_start.begin(), single_start.end(), NodeHandler);
  }

  // Walks from every node in [begin, end), calling `NodeHandler` once per
  // visited node in visiting order.
  template <typename NodeIt>
  void operator()(NodeIt begin,
                  NodeIt end,
                  const NodeHandlerType& NodeHandler) const {
    std::queue<NodeType> pending;
    // Every node that has ever been pushed to `pending`.
    std::unordered_set<NodeType> scheduled;

    const auto Schedule = [&](NodeType candidate) {
      if (scheduled.count(candidate) > 0) {
        return;
      }
      pending.push(candidate);
      scheduled.insert(candidate);
    };

    for (NodeIt it = begin; it != end; ++it) {
      Schedule(*it);
    }

    while (!pending.empty()) {
      NodeType current = pending.front();
      pending.pop();
      NodeHandler(current);
      VisitNextNodes_(current, [&](NodeType successor) {
        // The queue is FIFO, so a predecessor that is already scheduled is
        // handled before anything scheduled after it.
        size_t missing_inputs = 0;
        VisitPrevNodes_(successor, [&](NodeType predecessor) {
          if (scheduled.count(predecessor) == 0) {
            ++missing_inputs;
          }
        });
        if (missing_inputs == 0) {
          Schedule(successor);
        }
      });
    }
  }

 private:
  NodesVisitorType VisitPrevNodes_;
  NodesVisitorType VisitNextNodes_;
};
} // namespace common
} // namespace cinn
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/common/topo_walker.h"
#include <glog/logging.h>
#include <gtest/gtest.h>
namespace cinn {
namespace common {
// Walks a small DAG from its two sources and checks the visiting order.
TEST(TopoWalker, simple) {
  std::vector<std::pair<int, int>> edges{
      {0, 3}, {1, 2}, {1, 3}, {2, 3}, {3, 4}};
  TopoWalker<int> visitor(
      // Predecessors: sources of the edges that end at `node`.
      [&](int node, const std::function<void(int)>& NodeHandler) {
        for (const auto& pair : edges) {
          if (pair.second == node) {
            NodeHandler(pair.first);
          }
        }
      },
      // Successors: targets of the edges that start at `node`.
      [&](int node, const std::function<void(int)>& NodeHandler) {
        for (const auto& pair : edges) {
          if (pair.first == node) {
            NodeHandler(pair.second);
          }
        }
      });
  std::vector<int> sources{0, 1};
  std::vector<int> outputs;
  visitor(sources.begin(), sources.end(), [&](int node) {
    outputs.push_back(node);
  });
  std::vector<int> expected{0, 1, 2, 3, 4};
  // EXPECT_EQ prints both containers when the comparison fails.
  EXPECT_EQ(outputs, expected);
}
} // namespace common
} // namespace cinn
...@@ -30,6 +30,7 @@ ...@@ -30,6 +30,7 @@
#include "paddle/cinn/hlir/framework/pass.h" #include "paddle/cinn/hlir/framework/pass.h"
#include "paddle/cinn/hlir/framework/tensor.h" #include "paddle/cinn/hlir/framework/tensor.h"
#include "paddle/cinn/hlir/op/use_ops.h" #include "paddle/cinn/hlir/op/use_ops.h"
#include "paddle/cinn/hlir/pass/use_general_pass.h"
#include "paddle/cinn/hlir/pass/use_pass.h" #include "paddle/cinn/hlir/pass/use_pass.h"
namespace cinn::frontend { namespace cinn::frontend {
......
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
#include "paddle/cinn/hlir/framework/graph.h" #include "paddle/cinn/hlir/framework/graph.h"
#include "paddle/cinn/hlir/framework/pass.h" #include "paddle/cinn/hlir/framework/pass.h"
#include "paddle/cinn/hlir/op/use_ops.h" #include "paddle/cinn/hlir/op/use_ops.h"
#include "paddle/cinn/hlir/pass/use_general_pass.h"
#include "paddle/cinn/hlir/pass/use_pass.h" #include "paddle/cinn/hlir/pass/use_pass.h"
#include "paddle/cinn/runtime/flags.h" #include "paddle/cinn/runtime/flags.h"
......
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
#include "paddle/cinn/hlir/framework/graph.h" #include "paddle/cinn/hlir/framework/graph.h"
#include "paddle/cinn/hlir/framework/pass.h" #include "paddle/cinn/hlir/framework/pass.h"
#include "paddle/cinn/hlir/framework/visualize_helper.h" #include "paddle/cinn/hlir/framework/visualize_helper.h"
#include "paddle/cinn/hlir/pass/use_general_pass.h"
#include "paddle/cinn/hlir/pass/use_pass.h" #include "paddle/cinn/hlir/pass/use_pass.h"
#include "paddle/cinn/runtime/flags.h" #include "paddle/cinn/runtime/flags.h"
...@@ -37,6 +38,7 @@ DECLARE_bool(cinn_use_custom_call); ...@@ -37,6 +38,7 @@ DECLARE_bool(cinn_use_custom_call);
DECLARE_bool(use_reduce_split_pass); DECLARE_bool(use_reduce_split_pass);
DECLARE_bool(cinn_use_dense_merge_pass); DECLARE_bool(cinn_use_dense_merge_pass);
DECLARE_string(cinn_custom_call_deny_ops); DECLARE_string(cinn_custom_call_deny_ops);
DECLARE_bool(general_fusion_merge_pass);
namespace cinn { namespace cinn {
namespace frontend { namespace frontend {
...@@ -96,7 +98,11 @@ OptimizeOptions DefaultTrainingOptimizeOptions() { ...@@ -96,7 +98,11 @@ OptimizeOptions DefaultTrainingOptimizeOptions() {
if (FLAGS_cinn_use_op_fusion) { if (FLAGS_cinn_use_op_fusion) {
options.graph_passes.emplace_back("OpFusionPass"); options.graph_passes.emplace_back("OpFusionPass");
options.graph_passes.emplace_back("FusionMergePass"); if (FLAGS_general_fusion_merge_pass) {
options.graph_passes.emplace_back("GeneralFusionMergePass");
} else {
options.graph_passes.emplace_back("FusionMergePass");
}
} else { } else {
options.graph_passes.emplace_back("BuildNonFusedGroupsPass"); options.graph_passes.emplace_back("BuildNonFusedGroupsPass");
} }
......
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
#include "paddle/cinn/frontend/program_pass.h" #include "paddle/cinn/frontend/program_pass.h"
#include "paddle/cinn/hlir/framework/graph_compiler.h" #include "paddle/cinn/hlir/framework/graph_compiler.h"
#include "paddle/cinn/hlir/framework/pass.h" #include "paddle/cinn/hlir/framework/pass.h"
#include "paddle/cinn/hlir/pass/use_general_pass.h"
#include "paddle/cinn/hlir/pass/use_pass.h" #include "paddle/cinn/hlir/pass/use_pass.h"
namespace cinn::frontend { namespace cinn::frontend {
......
...@@ -58,6 +58,13 @@ class Graph : public cinn::common::Graph { ...@@ -58,6 +58,13 @@ class Graph : public cinn::common::Graph {
std::vector<std::vector<Node*>> groups; std::vector<std::vector<Node*>> groups;
struct Group { struct Group {
Group() = default;
explicit Group(const Graph* graph) : graph_(graph) {}
// The graph that group belongs to.
const Graph* graph_ = nullptr;
// distance to last group. // distance to last group.
int depth{0}; int depth{0};
int max_depth{0}; int max_depth{0};
...@@ -81,6 +88,15 @@ class Graph : public cinn::common::Graph { ...@@ -81,6 +88,15 @@ class Graph : public cinn::common::Graph {
// master node for schedule // master node for schedule
std::unordered_set<Node*> master_nodes; std::unordered_set<Node*> master_nodes;
// fused sub-groups, used for fusion merge pass
std::vector<std::shared_ptr<Group>> fused_sub_groups;
// if as sub-group, used for belong groups.
std::unordered_set<std::shared_ptr<Group>> belong_groups;
// for op lowering.
std::vector<std::string> input_names;
std::vector<std::string> output_names;
struct SharedGroupHasher { struct SharedGroupHasher {
size_t operator()(const std::shared_ptr<Group>& group) const noexcept { size_t operator()(const std::shared_ptr<Group>& group) const noexcept {
return std::hash<uint64_t>()(reinterpret_cast<uint64_t>(group.get())); return std::hash<uint64_t>()(reinterpret_cast<uint64_t>(group.get()));
...@@ -92,27 +108,6 @@ class Graph : public cinn::common::Graph { ...@@ -92,27 +108,6 @@ class Graph : public cinn::common::Graph {
return first.get() == second.get(); return first.get() == second.get();
} }
}; };
// input groups
std::unordered_set<std::shared_ptr<Group>,
SharedGroupHasher,
SharedGroupComparator>
producer_groups;
// output grous
std::unordered_set<std::shared_ptr<Group>,
SharedGroupHasher,
SharedGroupComparator>
consumer_groups;
// fused sub-groups, used for fusion merge pass
std::vector<std::shared_ptr<Group>> fused_sub_groups;
// if as sub-group, used for belong groups.
std::unordered_set<std::shared_ptr<Group>,
SharedGroupHasher,
SharedGroupComparator>
belong_groups;
// for op lowering.
std::vector<std::string> input_names;
std::vector<std::string> output_names;
std::vector<Node*> CollectNodes() { std::vector<Node*> CollectNodes() {
if (fused_sub_groups.size()) { if (fused_sub_groups.size()) {
...@@ -127,6 +122,20 @@ class Graph : public cinn::common::Graph { ...@@ -127,6 +122,20 @@ class Graph : public cinn::common::Graph {
} }
} }
void WalkNodes(const std::function<void(const Node*)>& VisitNode) const {
if (fused_sub_groups.size()) {
for (auto& group : fused_sub_groups) {
for (const auto* node : group->nodes) {
VisitNode(node);
}
}
} else {
for (const auto* node : nodes) {
VisitNode(node);
}
}
}
std::unordered_set<Node*> NodeSet() { std::unordered_set<Node*> NodeSet() {
std::unordered_set<Node*> node_set; std::unordered_set<Node*> node_set;
for (auto node : CollectNodes()) { for (auto node : CollectNodes()) {
...@@ -139,6 +148,49 @@ class Graph : public cinn::common::Graph { ...@@ -139,6 +148,49 @@ class Graph : public cinn::common::Graph {
std::unordered_set<NodeData*> GetOutputNodeDatas(); std::unordered_set<NodeData*> GetOutputNodeDatas();
std::string GetFuncName() { return "fn_" + group_id + unique_id; } std::string GetFuncName() { return "fn_" + group_id + unique_id; }
public:
const std::unordered_set<std::shared_ptr<Group>,
SharedGroupHasher,
SharedGroupComparator>&
producer_groups() const {
return producer_groups_;
}
const std::unordered_set<std::shared_ptr<Group>,
SharedGroupHasher,
SharedGroupComparator>&
consumer_groups() const {
return consumer_groups_;
}
std::unordered_set<std::shared_ptr<Group>,
SharedGroupHasher,
SharedGroupComparator>*
mut_producer_groups() {
return &producer_groups_;
}
std::unordered_set<std::shared_ptr<Group>,
SharedGroupHasher,
SharedGroupComparator>*
mut_consumer_groups() {
return &consumer_groups_;
}
hlir::framework::OpPatternKind kind() const { return op_pattern_kind; }
private:
// input groups
std::unordered_set<std::shared_ptr<Group>,
SharedGroupHasher,
SharedGroupComparator>
producer_groups_;
// output grous
std::unordered_set<std::shared_ptr<Group>,
SharedGroupHasher,
SharedGroupComparator>
consumer_groups_;
}; };
std::vector<std::shared_ptr<Group>> fusion_groups; std::vector<std::shared_ptr<Group>> fusion_groups;
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include "paddle/cinn/hlir/framework/pass.h" #include "paddle/cinn/hlir/framework/pass.h"
#include "paddle/cinn/hlir/framework/visualize_helper.h" #include "paddle/cinn/hlir/framework/visualize_helper.h"
#include "paddle/cinn/hlir/pass/use_general_pass.h"
#include "paddle/cinn/hlir/pass/use_pass.h" #include "paddle/cinn/hlir/pass/use_pass.h"
namespace cinn { namespace cinn {
......
...@@ -9,6 +9,7 @@ gather_srcs( ...@@ -9,6 +9,7 @@ gather_srcs(
const_propagate.cc const_propagate.cc
op_fusion_pass.cc op_fusion_pass.cc
fusion_merge_pass.cc fusion_merge_pass.cc
general_fusion_merge_pass.cc
dot_merger.cc dot_merger.cc
check_fusion_accuracy_pass.cc check_fusion_accuracy_pass.cc
custom_call_pass.cc custom_call_pass.cc
......
...@@ -62,10 +62,10 @@ class FusionMergePassHelper : public FusionHelperBase { ...@@ -62,10 +62,10 @@ class FusionMergePassHelper : public FusionHelperBase {
for (auto& sub_group : group->fused_sub_groups) { for (auto& sub_group : group->fused_sub_groups) {
VLOG(3) << " Fused Sub-Group -> " << sub_group->group_id; VLOG(3) << " Fused Sub-Group -> " << sub_group->group_id;
} }
for (auto& producer : group->producer_groups) { for (const auto& producer : group->producer_groups()) {
VLOG(3) << " Producer -> " << producer->group_id; VLOG(3) << " Producer -> " << producer->group_id;
} }
for (auto& consumer : group->consumer_groups) { for (const auto& consumer : group->consumer_groups()) {
VLOG(3) << " Consumer -> " << consumer->group_id; VLOG(3) << " Consumer -> " << consumer->group_id;
} }
} }
...@@ -94,7 +94,7 @@ class FusionMergePassHelper : public FusionHelperBase { ...@@ -94,7 +94,7 @@ class FusionMergePassHelper : public FusionHelperBase {
continue; continue;
} }
// do horizontal fusion. // do horizontal fusion.
updated |= HorizontalFusion(producer, producer->consumer_groups); updated |= HorizontalFusion(producer, producer->consumer_groups());
} }
if (updated) { if (updated) {
...@@ -115,9 +115,10 @@ class FusionMergePassHelper : public FusionHelperBase { ...@@ -115,9 +115,10 @@ class FusionMergePassHelper : public FusionHelperBase {
} }
// do horizontal fusion. // do horizontal fusion.
if (!recompute) { if (!recompute) {
updated |= HorizontalFusion(producer, producer->consumer_groups); updated |= HorizontalFusion(producer, producer->consumer_groups());
} }
updated |= VerticalFusion(producer, producer->consumer_groups, recompute); updated |=
VerticalFusion(producer, producer->consumer_groups(), recompute);
} }
// fuse input consumers // fuse input consumers
updated |= FuseInputToConsumers(); updated |= FuseInputToConsumers();
...@@ -151,7 +152,7 @@ class FusionMergePassHelper : public FusionHelperBase { ...@@ -151,7 +152,7 @@ class FusionMergePassHelper : public FusionHelperBase {
} }
bool exist = false; bool exist = false;
for (auto& producer : group->producer_groups) { for (const auto& producer : group->producer_groups()) {
if (fusion_groups_set.count(producer)) { if (fusion_groups_set.count(producer)) {
VLOG(4) << group->group_id << " " << producer->group_id; VLOG(4) << group->group_id << " " << producer->group_id;
exist = true; exist = true;
...@@ -183,7 +184,7 @@ class FusionMergePassHelper : public FusionHelperBase { ...@@ -183,7 +184,7 @@ class FusionMergePassHelper : public FusionHelperBase {
} }
std::unordered_set<GroupPtr, Hasher, Comparator> candidates; std::unordered_set<GroupPtr, Hasher, Comparator> candidates;
for (auto& consumer : consumers) { for (const auto& consumer : consumers) {
// relation // relation
auto& relation = fusion_relation_map_[consumer->op_pattern_kind]; auto& relation = fusion_relation_map_[consumer->op_pattern_kind];
// check horizontal relation exist // check horizontal relation exist
...@@ -324,18 +325,18 @@ class FusionMergePassHelper : public FusionHelperBase { ...@@ -324,18 +325,18 @@ class FusionMergePassHelper : public FusionHelperBase {
fused_group->fused_sub_groups.push_back(consumer); fused_group->fused_sub_groups.push_back(consumer);
} }
// producer group // producer group
for (auto& producer : consumer->producer_groups) { for (auto& producer : *consumer->mut_producer_groups()) {
fused_group->producer_groups.insert(producer); fused_group->mut_producer_groups()->insert(producer);
// update producer's consumer // update producer's consumer
producer->consumer_groups.erase(consumer); producer->mut_consumer_groups()->erase(consumer);
producer->consumer_groups.insert(fused_group); producer->mut_consumer_groups()->insert(fused_group);
} }
// consumer group // consumer group
for (auto& gconsumer : consumer->consumer_groups) { for (auto& gconsumer : *consumer->mut_consumer_groups()) {
fused_group->consumer_groups.insert(gconsumer); fused_group->mut_consumer_groups()->insert(gconsumer);
// update consumer's producer // update consumer's producer
gconsumer->producer_groups.erase(consumer); gconsumer->mut_producer_groups()->erase(consumer);
gconsumer->producer_groups.insert(fused_group); gconsumer->mut_producer_groups()->insert(fused_group);
} }
// belongs group // belongs group
consumer->belong_groups.insert(fused_group); consumer->belong_groups.insert(fused_group);
...@@ -412,7 +413,7 @@ class FusionMergePassHelper : public FusionHelperBase { ...@@ -412,7 +413,7 @@ class FusionMergePassHelper : public FusionHelperBase {
std::unordered_set<GroupPtr, Hasher, Comparator> fuse_consumers_unsafe; std::unordered_set<GroupPtr, Hasher, Comparator> fuse_consumers_unsafe;
std::unordered_set<GroupPtr, Hasher, Comparator> fuse_consumers; std::unordered_set<GroupPtr, Hasher, Comparator> fuse_consumers;
for (auto& consumer : consumers) { for (const auto& consumer : consumers) {
VLOG(4) << "Check consuemr " << consumer->group_id VLOG(4) << "Check consuemr " << consumer->group_id
<< " can fuse to producer " << producer->group_id; << " can fuse to producer " << producer->group_id;
// if can't fuse // if can't fuse
...@@ -458,7 +459,7 @@ class FusionMergePassHelper : public FusionHelperBase { ...@@ -458,7 +459,7 @@ class FusionMergePassHelper : public FusionHelperBase {
// if can_fuse_consumers == consumers // if can_fuse_consumers == consumers
// if producer op kind == kElementwise // if producer op kind == kElementwise
// if use recompute // if use recompute
if (fuse_consumers_unsafe.size() == producer->consumer_groups.size() && if (fuse_consumers_unsafe.size() == producer->consumer_groups().size() &&
producer->op_pattern_kind == framework::kElementWise) { producer->op_pattern_kind == framework::kElementWise) {
if (!recompute) { if (!recompute) {
return false; return false;
...@@ -531,11 +532,11 @@ class FusionMergePassHelper : public FusionHelperBase { ...@@ -531,11 +532,11 @@ class FusionMergePassHelper : public FusionHelperBase {
} }
// producer groups // producer groups
for (auto& group : producer->producer_groups) { for (auto& group : *producer->mut_producer_groups()) {
fused_group->producer_groups.insert(group); fused_group->mut_producer_groups()->insert(group);
// update producer's producer's consumer // update producer's producer's consumer
group->consumer_groups.erase(producer); group->mut_consumer_groups()->erase(producer);
group->consumer_groups.insert(fused_group); group->mut_consumer_groups()->insert(fused_group);
} }
// sub groups // sub groups
...@@ -581,20 +582,20 @@ class FusionMergePassHelper : public FusionHelperBase { ...@@ -581,20 +582,20 @@ class FusionMergePassHelper : public FusionHelperBase {
} }
// producer nodes // producer nodes
for (auto& group : consumer->producer_groups) { for (auto& group : *consumer->mut_producer_groups()) {
if (group.get() != producer.get()) { if (group.get() != producer.get()) {
fused_group->producer_groups.insert(group); fused_group->mut_producer_groups()->insert(group);
// update consumer's producer's consumer // update consumer's producer's consumer
group->consumer_groups.erase(consumer); group->mut_consumer_groups()->erase(consumer);
group->consumer_groups.insert(fused_group); group->mut_consumer_groups()->insert(fused_group);
} }
} }
// consumer nodes // consumer nodes
for (auto& group : consumer->consumer_groups) { for (auto& group : *consumer->mut_consumer_groups()) {
fused_group->consumer_groups.insert(group); fused_group->mut_consumer_groups()->insert(group);
// update consumer's consumer's producer // update consumer's consumer's producer
group->producer_groups.erase(consumer); group->mut_producer_groups()->erase(consumer);
group->producer_groups.insert(fused_group); group->mut_producer_groups()->insert(fused_group);
} }
// sub group // sub group
...@@ -631,7 +632,7 @@ class FusionMergePassHelper : public FusionHelperBase { ...@@ -631,7 +632,7 @@ class FusionMergePassHelper : public FusionHelperBase {
for (auto& node : producer->output_nodes) { for (auto& node : producer->output_nodes) {
bool be_output = true; bool be_output = true;
for (auto& consumer : producer->consumer_groups) { for (const auto& consumer : producer->consumer_groups()) {
// if consumer is in fusionable. // if consumer is in fusionable.
if (fusionable_consumers.count(consumer)) { if (fusionable_consumers.count(consumer)) {
if (consumer->input_nodes.count(node)) { if (consumer->input_nodes.count(node)) {
...@@ -658,14 +659,14 @@ class FusionMergePassHelper : public FusionHelperBase { ...@@ -658,14 +659,14 @@ class FusionMergePassHelper : public FusionHelperBase {
} }
} }
// insert unfusionable consumer groups // insert unfusionable consumer groups
for (auto& consumer : producer->consumer_groups) { for (auto& consumer : *producer->mut_consumer_groups()) {
if (fusionable_consumers.count(consumer)) { if (fusionable_consumers.count(consumer)) {
continue; continue;
} }
master_fuesd_group->consumer_groups.insert(consumer); master_fuesd_group->mut_consumer_groups()->insert(consumer);
// update consumer's producer // update consumer's producer
consumer->producer_groups.erase(producer); consumer->mut_producer_groups()->erase(producer);
consumer->producer_groups.insert(master_fuesd_group); consumer->mut_producer_groups()->insert(master_fuesd_group);
} }
} }
...@@ -699,13 +700,13 @@ class FusionMergePassHelper : public FusionHelperBase { ...@@ -699,13 +700,13 @@ class FusionMergePassHelper : public FusionHelperBase {
sub_group->nodes_set.insert(producer->CollectNodes()[0]); sub_group->nodes_set.insert(producer->CollectNodes()[0]);
// remove depency. // remove depency.
consumer->input_nodes.erase(producer->CollectNodes()[0]); consumer->input_nodes.erase(producer->CollectNodes()[0]);
consumer->producer_groups.erase(producer); consumer->mut_producer_groups()->erase(producer);
producer->consumer_groups.erase(consumer); producer->mut_consumer_groups()->erase(consumer);
} }
} }
CHECK_GE(producer->consumer_groups.size(), candidates.size()); CHECK_GE(producer->consumer_groups().size(), candidates.size());
if (producer->consumer_groups.size() == 0 && candidates.size() == 0 && if (producer->consumer_groups().size() == 0 && candidates.size() == 0 &&
output_nodes_set_.count(producer->CollectNodes()[0]) == 0) { output_nodes_set_.count(producer->CollectNodes()[0]) == 0) {
producer->belong_groups.insert(*fusionable_consumers->begin()); producer->belong_groups.insert(*fusionable_consumers->begin());
} }
...@@ -714,7 +715,7 @@ class FusionMergePassHelper : public FusionHelperBase { ...@@ -714,7 +715,7 @@ class FusionMergePassHelper : public FusionHelperBase {
return; return;
} }
// 1 to 1 fusion. // 1 to 1 fusion.
if (producer->consumer_groups.size() == 1) { if (producer->consumer_groups().size() == 1) {
return; return;
} }
...@@ -805,7 +806,7 @@ class FusionMergePassHelper : public FusionHelperBase { ...@@ -805,7 +806,7 @@ class FusionMergePassHelper : public FusionHelperBase {
while (!candidates.empty()) { while (!candidates.empty()) {
auto& candidate = candidates.front(); auto& candidate = candidates.front();
candidates.pop(); candidates.pop();
for (auto& producer : candidate->producer_groups) { for (const auto& producer : candidate->producer_groups()) {
if (producer.get() == producer_g.get()) { if (producer.get() == producer_g.get()) {
continue; continue;
} }
...@@ -833,7 +834,7 @@ class FusionMergePassHelper : public FusionHelperBase { ...@@ -833,7 +834,7 @@ class FusionMergePassHelper : public FusionHelperBase {
while (!candidates.empty()) { while (!candidates.empty()) {
auto& candidate = candidates.front(); auto& candidate = candidates.front();
candidates.pop(); candidates.pop();
for (auto& producer : candidate->producer_groups) { for (auto& producer : candidate->producer_groups()) {
if (producer.get() == producer_g.get()) { if (producer.get() == producer_g.get()) {
continue; continue;
} }
...@@ -934,8 +935,8 @@ class FusionMergePassHelper : public FusionHelperBase { ...@@ -934,8 +935,8 @@ class FusionMergePassHelper : public FusionHelperBase {
belong_group->output_nodes = group->output_nodes; belong_group->output_nodes = group->output_nodes;
belong_group->op_pattern_kind = group->op_pattern_kind; belong_group->op_pattern_kind = group->op_pattern_kind;
belong_group->master_nodes = group->master_nodes; belong_group->master_nodes = group->master_nodes;
belong_group->producer_groups = group->producer_groups; (*belong_group->mut_producer_groups()) = group->producer_groups();
belong_group->consumer_groups = group->consumer_groups; (*belong_group->mut_consumer_groups()) = group->consumer_groups();
belong_group->fused_sub_groups.push_back(group); belong_group->fused_sub_groups.push_back(group);
group->belong_groups.insert(belong_group); group->belong_groups.insert(belong_group);
// replace group to fused_group // replace group to fused_group
...@@ -949,18 +950,19 @@ class FusionMergePassHelper : public FusionHelperBase { ...@@ -949,18 +950,19 @@ class FusionMergePassHelper : public FusionHelperBase {
std::unordered_set<GroupPtr, Hasher, Comparator> producers; std::unordered_set<GroupPtr, Hasher, Comparator> producers;
std::unordered_set<GroupPtr, Hasher, Comparator> consumers; std::unordered_set<GroupPtr, Hasher, Comparator> consumers;
for (auto& producer : group->producer_groups) { for (const auto& producer : group->producer_groups()) {
CHECK(producer->belong_groups.size()); CHECK(producer->belong_groups.size());
producers.insert(*producer->belong_groups.begin()); producers.insert(*producer->belong_groups.begin());
} }
for (auto& consumer : group->consumer_groups) {
for (auto& consumer : *group->mut_consumer_groups()) {
CHECK(consumer->belong_groups.size()); CHECK(consumer->belong_groups.size());
consumers.insert(*consumer->belong_groups.begin()); consumers.insert(*consumer->belong_groups.begin());
} }
CHECK_EQ(group->producer_groups.size(), producers.size()); CHECK_EQ(group->producer_groups().size(), producers.size());
CHECK_EQ(group->consumer_groups.size(), consumers.size()); CHECK_EQ(group->consumer_groups().size(), consumers.size());
group->producer_groups = producers; (*group->mut_producer_groups()) = producers;
group->consumer_groups = consumers; (*group->mut_consumer_groups()) = consumers;
} }
} }
......
...@@ -73,8 +73,8 @@ CONDITION_FUNC(is_same_size) { ...@@ -73,8 +73,8 @@ CONDITION_FUNC(is_same_size) {
return size_0 == size_1; return size_0 == size_1;
} }
bool is_const_group(const FusionHelperBase* helper, inline bool is_const_group(const FusionHelperBase* helper,
const std::shared_ptr<Graph::Group>& group) { const std::shared_ptr<Graph::Group>& group) {
return group->CollectNodes().size() == 1 && return group->CollectNodes().size() == 1 &&
helper->IsConstOp(group->CollectNodes()[0]); helper->IsConstOp(group->CollectNodes()[0]);
} }
......
此差异已折叠。
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/cinn/api/op_group.h"
#include "paddle/cinn/hlir/pass/fusion_merge_pass_util.h"
namespace cinn {
namespace hlir {
namespace pass {
namespace utils {
using framework::OpPatternKind;
using OpGroupPtr = api::OpGroup;
using OpGroupList = std::vector<OpGroupPtr>;
// Picks the representative ("master") op of `op_group`.
//
// The first reduction op met while walking the group wins; when the group
// holds no reduction op, the last op visited by the walk is returned instead.
// The group must contain at least one op: an empty group is reported through
// CHECK instead of reading the back of an empty vector.
static api::OpNode GetMasterNode(const OpGroupPtr& op_group) {
  std::vector<api::OpNode> reduce_nodes;
  std::vector<api::OpNode> all_nodes;
  // A single walk gathers both candidates lists in visiting order.
  op_group.WalkOpNodes([&](const api::OpNode& op) {
    if (op.kind() == OpPatternKind::kReduction) {
      reduce_nodes.push_back(op);
    }
    all_nodes.push_back(op);
  });
  if (!reduce_nodes.empty()) {
    return reduce_nodes.front();
  }
  CHECK(!all_nodes.empty())
      << "Can't get master node from empty group " << op_group.group_id();
  return all_nodes.back();
}
static bool IsSameSize(const OpGroupPtr& src, const OpGroupPtr& dst) {
api::OpNode src_master_node = GetMasterNode(src);
api::OpNode dst_master_node = GetMasterNode(dst);
auto size_0 = src_master_node.outputs()[0].shape().numel();
auto size_1 = dst_master_node.outputs()[0].shape().numel();
return size_0 == size_1;
}
// Collects the ops outside `op_group` that produce a tensor read by some op
// inside the group. Input tensors without a producer are ignored.
static std::unordered_set<api::OpNode> GetInputOps(const OpGroupPtr& op_group) {
  // Every op that belongs to the group, used to tell inner from outer ops.
  std::unordered_set<api::OpNode> members;
  op_group.WalkOpNodes(
      [&members](const api::OpNode& member) { members.insert(member); });

  std::unordered_set<api::OpNode> outside_producers;
  op_group.WalkOpNodes([&](const api::OpNode& member) {
    const auto& inputs = member.inputs();
    for (size_t pos = 0; pos < inputs.size(); ++pos) {
      if (!inputs[pos].HasProducer()) {
        continue;
      }
      api::OpNode producer = inputs[pos].producer();
      if (members.count(producer) == 0) {
        outside_producers.insert(producer);
      }
    }
  });
  return outside_producers;
}
// Collects ops outside `op_group` that consume a tensor written by some op
// inside the group. For each output tensor only the first outside consumer
// found is recorded; the remaining consumers of that tensor are not visited.
static std::unordered_set<api::OpNode> GetOutputOps(
    const OpGroupPtr& op_group) {
  // Every op that belongs to the group, used to tell inner from outer ops.
  std::unordered_set<api::OpNode> members;
  op_group.WalkOpNodes(
      [&members](const api::OpNode& member) { members.insert(member); });

  std::unordered_set<api::OpNode> outside_consumers;
  op_group.WalkOpNodes([&](const api::OpNode& member) {
    const auto& outputs = member.outputs();
    for (size_t pos = 0; pos < outputs.size(); ++pos) {
      const auto& tensor = outputs[pos];
      const auto& tensor_consumers = tensor.consumers();
      for (const auto& consumer : tensor_consumers) {
        if (members.count(consumer) != 0) {
          continue;
        }
        outside_consumers.insert(consumer);
        break;  // one outside consumer per output tensor is enough
      }
    }
  });
  return outside_consumers;
}
// limit the group args number to less equal 512, as args stack size is 4K.
static bool limit_args(const OpGroupPtr& first, const OpGroupPtr& second) {
std::unordered_set<api::OpNode> args;
for (auto& group : {first, second}) {
for (const auto& node : GetInputOps(group)) {
args.insert(node);
}
for (const auto& node : GetOutputOps(group)) {
args.insert(node);
}
}
if (args.size() > 512) {
return false;
} else {
return true;
}
}
// Tells whether a reduce over `axes` of a tensor shaped `inshape` leaves the
// trailing dimensions untouched.
//
// Returns false when `axes` names the last axis (either as
// `inshape.size() - 1` or as -1). Otherwise returns true only if the product
// of the extents after `axes.back()` is greater than 1.
//
// Precondition: `axes` is non-empty, because `axes.back()` is read.
//
// Declared `static` like the other helpers of this header so that including
// the header from several translation units does not produce multiple
// definitions of the same external symbol.
static bool WithoutLastDimInReduce(const api::Shape& inshape,
                                   const std::vector<int>& axes) {
  // if last axis is in reduce.
  if (std::find(axes.begin(), axes.end(), inshape.size() - 1) != axes.end() ||
      std::find(axes.begin(), axes.end(), -1) != axes.end()) {
    return false;
  }

  // Product of the extents that follow the last listed reduce axis.
  int sum_last_axes = 1;
  for (int idx = axes.back() + 1; idx < inshape.size(); ++idx) {
    sum_last_axes *= inshape[idx];
  }

  return sum_last_axes > 1;
}
// Computes a size in bytes (a multiple of sizeof(float)) for the reduce op
// `op_node`, from the shape of its first input and its "dim" attribute.
// ReduceFuseReduce sums this value over reduce ops and compares the total
// against a 32 * 1024 budget; presumably it is the shared-memory footprint of
// the reduce -- TODO confirm against the code generator.
//
// Returns 0 when WithoutLastDimInReduce(inshape, axes) is false, or when the
// initial `lane` already exceeds half of the target's max thread count.
// Requires the op to have at least one input (enforced by CHECK_GT).
static int GetSharedSize(const api::OpNode& op_node) {
const auto& producers = op_node.inputs();
CHECK_GT(producers.size(), 0);
const auto& inshape = producers[0].shape();
const auto& axes = op_node.GetAttr<std::vector<int>>("dim");
if (WithoutLastDimInReduce(inshape, axes)) {
// After this loop `lane` holds the extent of the last dimension visited,
// because each iteration assigns rather than multiplies.
// NOTE(review): confirm that `=` (not `*=`) is intended here.
int lane = 1;
for (int idx = axes.back() + 1; idx < inshape.size(); ++idx) {
lane = inshape[idx];
}
int max_num_threads = common::DefaultNVGPUTarget().max_num_threads();
if (lane > max_num_threads / 2) {
return 0;
}
// Walk the reduce axes from the last one backwards, multiplying `lane` by
// each axis extent. Stop at the first axis that is not adjacent to the one
// after it, or as soon as `lane` exceeds half of max_num_threads.
int index = axes.size() - 1;
for (; index >= 0; --index) {
if (index + 1 < axes.size() && axes[index] != axes[index + 1] - 1) {
break;
}
lane *= inshape[axes[index]];
if (lane > max_num_threads / 2) {
break;
}
}
// If lane > (max_num_threads / 2), the loop stopped right after folding
// axes[index] into `lane`, so that axis is the one to use; otherwise the last
// axis folded in is axes[index + 1].
int axis = lane > (max_num_threads / 2) ? axes[index] : axes[index + 1];
if (lane <= max_num_threads) {
return lane * sizeof(float);
} else {
// `lane` is larger than max_num_threads: look for the largest factor `idx`
// of the extent of `axis`, within (max_num_threads / 2 / tail,
// max_num_threads / tail], and size the result with it.
int prefix = inshape[axis];
int tail = lane / prefix;
for (int idx = max_num_threads / tail;
idx > ((max_num_threads / 2) / tail);
--idx) {
if (prefix % idx == 0) {
return idx * tail * sizeof(float);
}
}
// No such factor exists: fall back to max_num_threads / tail.
int num = max_num_threads / tail;
return num * tail * sizeof(float);
}
}
return 0;
}
static bool ReduceFuseReduce(const OpGroupPtr& first,
const OpGroupPtr& second) {
if (!limit_args(first, second)) {
return false;
}
std::unique_ptr<api::OpNode> reducer_0 = nullptr;
first.WalkOpNodes([&](const api::OpNode& op) {
if (!reducer_0 && op.kind() == OpPatternKind::kReduction) {
reducer_0.reset(new api::OpNode(op));
}
});
CHECK(reducer_0) << "Can't find reduce op in group " << first.group_id();
std::unique_ptr<api::OpNode> reducer_1 = nullptr;
second.WalkOpNodes([&](const api::OpNode& op) {
if (!reducer_1 && op.kind() == OpPatternKind::kReduction) {
reducer_1.reset(new api::OpNode(op));
}
});
CHECK(reducer_1) << "Can't find reduce op in group " << second.group_id();
// check reduce has same input shape and output shape
const auto& reducer_0_input_shape = reducer_0->inputs()[0].shape();
const auto& reducer_0_output_shape = reducer_0->outputs()[0].shape();
const auto& reducer_1_input_shape = reducer_1->inputs()[0].shape();
const auto& reducer_1_output_shape = reducer_1->outputs()[0].shape();
auto reducer_0_reduce_dim = reducer_0->GetAttr<std::vector<int>>("dim");
auto reducer_1_reduce_dim = reducer_1->GetAttr<std::vector<int>>("dim");
for (auto& dim : reducer_0_reduce_dim) {
// if dim = -1, set as shape.size() - 1
if (dim == -1) {
dim = reducer_0_reduce_dim.size() - 1;
}
}
for (auto& dim : reducer_1_reduce_dim) {
// if dim = -1, set as shape.size() - 1
if (dim == -1) {
dim = reducer_1_reduce_dim.size() - 1;
}
}
// check shape is same
if (reducer_0_input_shape == reducer_1_input_shape &&
reducer_0_output_shape == reducer_1_output_shape &&
reducer_0_reduce_dim == reducer_1_reduce_dim) {
auto shared_size = 0;
for (auto& fusion_group : {first, second}) {
fusion_group.WalkOpNodes([&](const api::OpNode& op) {
if (op.kind() == OpPatternKind::kReduction) {
shared_size += GetSharedSize(op);
}
});
}
#define MAX_AVAILABLE_SHREAD 32 * 1024
if (shared_size > MAX_AVAILABLE_SHREAD) {
return false;
}
#undef MAX_AVAILABLE_SHREAD
return true;
}
if (WithoutLastDimInReduce(reducer_0_input_shape, reducer_0_reduce_dim) &&
WithoutLastDimInReduce(reducer_1_input_shape, reducer_1_reduce_dim) &&
reducer_0_output_shape == reducer_1_output_shape &&
reducer_0_reduce_dim == reducer_1_reduce_dim) {
auto shared_size = 0;
for (auto& fusion_group : {first, second}) {
fusion_group.WalkOpNodes([&](const api::OpNode& op) {
if (op.kind() == OpPatternKind::kReduction) {
shared_size += GetSharedSize(op);
}
});
}
#define MAX_AVAILABLE_SHREAD 32 * 1024
if (shared_size > MAX_AVAILABLE_SHREAD) {
return false;
}
#undef MAX_AVAILABLE_SHREAD
return true;
}
return false;
}
} // namespace utils
} // namespace pass
} // namespace hlir
} // namespace cinn
...@@ -49,7 +49,7 @@ class OpFusionPassHelper : public FusionHelperBase { ...@@ -49,7 +49,7 @@ class OpFusionPassHelper : public FusionHelperBase {
auto node = graph_node->safe_as<Node>(); auto node = graph_node->safe_as<Node>();
if (node) { if (node) {
nodes_.push_back(node); nodes_.push_back(node);
auto group = std::make_shared<Graph::Group>(); auto group = std::make_shared<Graph::Group>(graph);
// init group // init group
group->nodes.push_back(node); group->nodes.push_back(node);
group->nodes_set.insert(node); group->nodes_set.insert(node);
...@@ -101,14 +101,14 @@ class OpFusionPassHelper : public FusionHelperBase { ...@@ -101,14 +101,14 @@ class OpFusionPassHelper : public FusionHelperBase {
for (auto& consumer : fusion_groups) { for (auto& consumer : fusion_groups) {
for (auto& input_node : consumer->input_nodes) { for (auto& input_node : consumer->input_nodes) {
auto& producer = fusion_groups_[input_node.first]; auto& producer = fusion_groups_[input_node.first];
consumer->producer_groups.insert(producer); consumer->mut_producer_groups()->insert(producer);
producer->consumer_groups.insert(consumer); producer->mut_consumer_groups()->insert(consumer);
} }
} }
// init group depth. // init group depth.
for (auto& group : fusion_groups) { for (auto& group : fusion_groups) {
for (auto& consumer : group->consumer_groups) { for (const auto& consumer : group->consumer_groups()) {
// update depth. // update depth.
group->depth = std::max(group->depth, consumer->depth + 1); group->depth = std::max(group->depth, consumer->depth + 1);
} }
...@@ -376,10 +376,10 @@ void OpFusionPassInternal(Graph* graph) { ...@@ -376,10 +376,10 @@ void OpFusionPassInternal(Graph* graph) {
for (auto& group : graph->fusion_groups) { for (auto& group : graph->fusion_groups) {
VLOG(3) << "Group Id : " << group->group_id; VLOG(3) << "Group Id : " << group->group_id;
for (auto& producer : group->producer_groups) { for (const auto& producer : group->producer_groups()) {
VLOG(3) << " producer group -> " << producer->group_id; VLOG(3) << " producer group -> " << producer->group_id;
} }
for (auto& consumer : group->consumer_groups) { for (const auto& consumer : group->consumer_groups()) {
VLOG(3) << " consumer group -> " << consumer->group_id; VLOG(3) << " consumer group -> " << consumer->group_id;
} }
} }
......
...@@ -25,6 +25,7 @@ CINN_USE_REGISTER(DCE) ...@@ -25,6 +25,7 @@ CINN_USE_REGISTER(DCE)
CINN_USE_REGISTER(DotMerger) CINN_USE_REGISTER(DotMerger)
CINN_USE_REGISTER(OpFusionPass) CINN_USE_REGISTER(OpFusionPass)
CINN_USE_REGISTER(FusionMergePass) CINN_USE_REGISTER(FusionMergePass)
CINN_USE_REGISTER(GeneralFusionMergePass)
CINN_USE_REGISTER(CheckFusionAccuracyPass) CINN_USE_REGISTER(CheckFusionAccuracyPass)
CINN_USE_REGISTER(CommonSubexpressionEliminationPass) CINN_USE_REGISTER(CommonSubexpressionEliminationPass)
......
...@@ -57,6 +57,10 @@ DEFINE_bool(cinn_use_op_fusion, ...@@ -57,6 +57,10 @@ DEFINE_bool(cinn_use_op_fusion,
BoolFromEnv("FLAGS_cinn_use_op_fusion", true), BoolFromEnv("FLAGS_cinn_use_op_fusion", true),
"Whether to use op fusion pass."); "Whether to use op fusion pass.");
DEFINE_bool(general_fusion_merge_pass,
BoolFromEnv("FLAGS_general_fusion_merge_pass", true),
"Whether to use general fusion_merge pass.");
DEFINE_bool(cinn_use_common_subexpression_elimination, DEFINE_bool(cinn_use_common_subexpression_elimination,
BoolFromEnv("FLAGS_cinn_use_common_subexpression_elimination", BoolFromEnv("FLAGS_cinn_use_common_subexpression_elimination",
false), false),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册