diff --git a/paddle/cinn/common/CMakeLists.txt b/paddle/cinn/common/CMakeLists.txt index 03acaa40320e56a82119225f14cd56f05099af80..eed3899ec35a177bd6399167147020b303a9ecfa 100644 --- a/paddle/cinn/common/CMakeLists.txt +++ b/paddle/cinn/common/CMakeLists.txt @@ -24,6 +24,7 @@ gather_srcs( message(STATUS "srcs: ${cinnapi_src}") cinn_cc_test(test_dfs_walker SRCS dfs_walker_test.cc DEPS gtest glog) +cinn_cc_test(test_dfs_topo_walker SRCS dfs_topo_walker_test.cc DEPS gtest glog) cinn_cc_test(test_is_reachable_predicator SRCS is_reachable_predicator_test.cc DEPS gtest glog) cinn_cc_test(test_topo_walker SRCS topo_walker_test.cc DEPS gtest glog) diff --git a/paddle/cinn/common/dfs_topo_walker.h b/paddle/cinn/common/dfs_topo_walker.h new file mode 100644 index 0000000000000000000000000000000000000000..b476ec76db8b4dfdfa1350161e7844b1d25955a5 --- /dev/null +++ b/paddle/cinn/common/dfs_topo_walker.h @@ -0,0 +1,112 @@ +// Copyright (c) 2023 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include + +namespace cinn { +namespace common { + +// DFS Topological order walker. +// Try to walk in a depth first manner while ensuring topological order. +// For example: +// Graph: +// 0 -> 1 +// 2 -> 3 +// 0 -> 3 +// 1 -> 3 +// 3 -> 4 +// Start nodes: 0, 2 +// Walking order: 0 -> 1 -> 2 -> 3 -> 4 +template , + typename NodeEqual = std::equal_to> +class DfsTopoWalker final { + public: + DfsTopoWalker(const DfsTopoWalker&) = delete; + DfsTopoWalker(DfsTopoWalker&&) = delete; + + using NodeHandlerType = std::function; + using NodesVisitorType = + std::function; + + DfsTopoWalker(const NodesVisitorType& VisitPreNodes, + const NodesVisitorType& VisitNextNodes) + : VisitPreNodes_(VisitPreNodes), VisitNextNodes_(VisitNextNodes) {} + + // Start walking from 1 node and make every effort to access all nodes that + // meet the walking rules. + // If there are more than 1 nodes with a degree of 0 in a graph, + // only one part will be accessed. + // If you want to access the entire graph, + // you need to provide all starting nodes. + void operator()(NodeType node, const NodeHandlerType& NodeHandler) const { + std::array nodes{node}; + (*this)(nodes.begin(), nodes.end(), NodeHandler); + } + + // Start walking from a collection of node and make every effort to access all + // nodes that meet the walking rules. + // If there are other start nodes in a graph, + // some nodes on the graph will not be accessed. + // If you want to access the entire graph, + // you need to provide all starting nodes. + template + void operator()(NodeIt begin, + NodeIt end, + const NodeHandlerType& NodeHandler) const { + std::stack node_stack; + std::unordered_set visited; + std::unordered_map in_degree; + const auto& InitInDegree = [&](NodeType node) { + if (in_degree.count(node) == 0) { + in_degree[node] = 0; + VisitPreNodes_(node, [&](NodeType in_node) { ++in_degree[node]; }); + } + }; + const auto& UpdateInDegree = [&](NodeType node) { + InitInDegree(node); + --in_degree[node]; + }; + const auto& TryPush = [&](NodeType node) { + InitInDegree(node); + if (visited.count(node) == 0 && in_degree[node] == 0) { + node_stack.push(node); + visited.insert(node); + } + }; + + for (NodeIt iter = begin; iter != end; ++iter) { + TryPush(*iter); + while (!node_stack.empty()) { + NodeType cur = node_stack.top(); + node_stack.pop(); + NodeHandler(cur); + VisitNextNodes_(cur, UpdateInDegree); + VisitNextNodes_(cur, TryPush); + } + } + } + + private: + NodesVisitorType VisitNextNodes_; + NodesVisitorType VisitPreNodes_; +}; + +} // namespace common +} // namespace cinn diff --git a/paddle/cinn/common/dfs_topo_walker_test.cc b/paddle/cinn/common/dfs_topo_walker_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..97f3cab02081fb167afdf0483cf2f8d0cfad8436 --- /dev/null +++ b/paddle/cinn/common/dfs_topo_walker_test.cc @@ -0,0 +1,54 @@ +// Copyright (c) 2023 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "paddle/cinn/common/dfs_topo_walker.h" + +namespace cinn { +namespace common { + +TEST(DfsTopoWalker, simple) { + std::vector> edges{ + {0, 1}, {2, 3}, {1, 3}, {0, 3}, {3, 4}}; + DfsTopoWalker walker( + [&](int node, const std::function& NodeHandler) { + for (const auto& pair : edges) { + if (pair.second == node) { + NodeHandler(pair.first); + } + } + }, + [&](int node, const std::function& NodeHandler) { + for (const auto& pair : edges) { + if (pair.first == node) { + NodeHandler(pair.second); + } + } + }); + std::vector sources{0, 2}; + std::vector outputs; + walker(sources.begin(), sources.end(), [&](int node) { + outputs.push_back(node); + }); + for (auto output : outputs) { + LOG(INFO) << output; + } + std::vector expected{0, 1, 2, 3, 4}; + EXPECT_TRUE((outputs == expected)); +} + +} // namespace common +} // namespace cinn diff --git a/paddle/cinn/ir/CMakeLists.txt b/paddle/cinn/ir/CMakeLists.txt index fdadb140078d2634e3286aa4526adb27a2440dd9..6435f4e9ee54b9cbc3f1751dede84976a6e050bd 100644 --- a/paddle/cinn/ir/CMakeLists.txt +++ b/paddle/cinn/ir/CMakeLists.txt @@ -14,7 +14,8 @@ gather_srcs( module.cc lowered_func.cc intrinsic_ops.cc - layout.cc) + layout.cc + schedule_block_graph.cc) add_subdirectory(op) add_subdirectory(test) diff --git a/paddle/cinn/ir/schedule/ir_schedule_util.cc b/paddle/cinn/ir/schedule/ir_schedule_util.cc index 87b7147d97803820b9a25b903a6a2432961fe27f..b4000ff212cadbdfdb648e43d59c0065197e089c 100644 --- a/paddle/cinn/ir/schedule/ir_schedule_util.cc +++ b/paddle/cinn/ir/schedule/ir_schedule_util.cc @@ -860,11 +860,18 @@ std::vector GetProducers(const Expr& block, const Expr& root) { auto compute_body = block.As() ->schedule_block.As() ->body; + std::string block_name = block.As() + ->schedule_block.As() + ->name; ir::CollectIRNodesWithoutTensor( - compute_body, [&producer_tensor_names](const Expr* x) { + compute_body, [&producer_tensor_names, &block_name](const Expr* x) { auto* load = x->As(); if (load) { producer_tensor_names.insert(load->tensor.as_tensor()->name); + if (load->tensor.as_tensor()->name == block_name) { + producer_tensor_names.insert( + GenReduceInitTensorNameOf(load->tensor.as_tensor()->name)); + } return true; } return false; @@ -896,6 +903,18 @@ std::vector GetConsumers(const Expr& block, const Expr& root) { CHECK(root.As()); std::vector consumers; std::string block_tensor = GetTensor(block)->name; + if (IsReduceInitTensorName(block_tensor)) { + std::string consumer_name = GetOriginalReduceTensorName(block_tensor); + auto consumer = ir::CollectIRNodesWithoutTensor(root, [&](const Expr* x) { + return x->As() && + x->As() + ->schedule_block.As() + ->name == consumer_name; + }); + CHECK_EQ(consumer.size(), 1); + return {*consumer.begin()}; + } + auto find_block = ir::CollectIRNodesWithoutTensor(root, [&](const Expr* x) { return x->As() && *x != block && *x != root; }); @@ -997,10 +1016,12 @@ std::vector CalculateRequiredRegions( // deduce accessed regions of the provided tensor in block by itering each // required block for (const Expr& pro_node : provided_nodes) { - const std::string& provided_tensor_name = + std::string provided_tensor_name = is_store_provided ? pro_node.As()->tensor.as_tensor()->name : pro_node.As()->tensor.as_tensor()->name; - + if (IsReduceInitTensorName(provided_tensor_name)) { + provided_tensor_name = GetOriginalReduceTensorName(provided_tensor_name); + } for (const Expr& req_block : required_blocks) { CHECK(req_block.As()); Expr block_body = diff --git a/paddle/cinn/ir/schedule_block_graph.cc b/paddle/cinn/ir/schedule_block_graph.cc new file mode 100644 index 0000000000000000000000000000000000000000..f5bd14f00fbeb2247fd6bac06b7b812fa7232185 --- /dev/null +++ b/paddle/cinn/ir/schedule_block_graph.cc @@ -0,0 +1,186 @@ +// Copyright (c) 2023 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/cinn/ir/schedule_block_graph.h" +#include "paddle/cinn/common/dfs_topo_walker.h" +#include "paddle/cinn/ir/schedule/ir_schedule_util.h" +#include "paddle/cinn/ir/utils/ir_printer.h" + +namespace cinn { +namespace ir { + +ScheduleBlockNode::ScheduleBlockNode(Expr block, const IRSchedule& ir_sch) + : ir_sch_(ir_sch) { + CHECK(block.As()) + << "Expr is not a ScheduleBlockRealize: " << block; + id_ = block.As() + ->schedule_block.As() + ->name; + VLOG(5) << "create schedule_block node: " << id_; +} + +Expr ScheduleBlockNode::Block() const { return ir_sch_.GetBlock(id_); } + +std::vector ScheduleBlockNode::ControlStmts() const { + return ir_sch_.GetLoops(id_); +} + +bool EdgeCompare(const common::Shared& a, + const common::Shared& b) { + CHECK_NOTNULL(a.get()); + CHECK_NOTNULL(b.get()); + return a->index() < b->index(); +} +std::vector> +ScheduleBlockNode::OrderedInLinks() const { + std::vector> ordered_links; + for (auto& in_edge : this->inlinks()) { + ordered_links.push_back(in_edge); + CHECK_GE(in_edge->index(), 0) + << "The index of a node's inlinks should be >= 0! Now index is: " + << in_edge->index() << ". Please check."; + } + std::sort(ordered_links.begin(), ordered_links.end(), EdgeCompare); + return ordered_links; +} + +std::vector> +ScheduleBlockNode::OrderedOutLinks() const { + std::vector> ordered_links; + for (auto& out_edge : this->outlinks()) { + ordered_links.push_back(out_edge); + CHECK_GE(out_edge->index(), 0) + << "The index of a node's outlinks should be >= 0! Now index is: " + << out_edge->index() << ". Please check."; + } + std::sort(ordered_links.begin(), ordered_links.end(), EdgeCompare); + return ordered_links; +} + +std::vector ScheduleBlockNode::Producers() const { + std::vector producers; + for (const auto& link : this->OrderedInLinks()) { + producers.push_back(dynamic_cast(link->source())); + } + return producers; +} +std::vector ScheduleBlockNode::Consumers() const { + std::vector consumers; + for (const auto& link : this->OrderedOutLinks()) { + consumers.push_back(dynamic_cast(link->sink())); + } + return consumers; +} + +ScheduleBlockGraph::ScheduleBlockGraph(const IRSchedule& ir_sch) { + Update(ir_sch); +} + +void ScheduleBlockGraph::Update(const IRSchedule& ir_sch) { + nodes_.clear(); + registry_.clear(); + std::vector all_blocks = ir_sch.GetAllBlocks(); + Expr root_block = ir_sch.GetRootBlock(all_blocks[0]); + for (Expr block : all_blocks) { + CHECK(block.As()) + << "Expr is not a ScheduleBlockRealize: " << block; + std::string id = block.As() + ->schedule_block.As() + ->name; + if (id == "root") { + continue; + } + ScheduleBlockNode* node = new ScheduleBlockNode(block, ir_sch); + RegisterNode(id, node); + VLOG(5) << "register schedule_block node: " << id; + block_ids_in_order_.push_back(id); + + std::vector producers = GetProducers(block, root_block); + for (Expr producer : producers) { + CHECK(producer.As()) + << "Expr is not a ScheduleBlockRealize: " << producer; + std::string producer_id = producer.As() + ->schedule_block.As() + ->name; + ScheduleBlockNode* producer_node = RetrieveNode(producer_id); + CHECK(producer_node) << "producer node: " << producer_id + << " does not exist in the graph"; + producer_node->Controls(node); + for (const std::string& upstream_node_id : + producer_node->UpstreamNodes()) { + node->AddUpstreamNode(upstream_node_id); + } + node->AddUpstreamNode(producer_id); + } + + for (const std::string& upstream_node_id : node->UpstreamNodes()) { + RetrieveNode(upstream_node_id)->AddDownstreamNode(id); + } + } +} + +std::vector ScheduleBlockGraph::StartPoints() { + std::vector res; + for (common::GraphNode* node : nodes()) { + if (node->inlinks().empty()) { + res.push_back(dynamic_cast(node)); + } + } + return res; +} + +std::vector ScheduleBlockGraph::EndPoints() { + std::vector res; + for (common::GraphNode* node : nodes()) { + if (node->outlinks().empty()) { + res.push_back(dynamic_cast(node)); + } + } + return res; +} + +void ScheduleBlockGraph::NodesWalk(const NodeHandlerType& NodeHandler) { + for (common::GraphNode* node : nodes()) { + ScheduleBlockNode* cur_node = dynamic_cast(node); + NodeHandler(cur_node); + } +} + +void ScheduleBlockGraph::DFSTopoWalk(const NodeHandlerType& NodeHandler, + bool is_reverse) { + auto VisitPreNodes = [&](const ScheduleBlockNode* node, + const NodeHandlerType& PreNodeHandler) { + std::vector pre_nodes = + is_reverse ? node->Consumers() : node->Producers(); + for (ScheduleBlockNode* pre_node : pre_nodes) { + PreNodeHandler(pre_node); + } + }; + auto VisitNextNodes = [&](const ScheduleBlockNode* node, + const NodeHandlerType& NextNodeHandler) { + std::vector next_nodes = + is_reverse ? node->Producers() : node->Consumers(); + for (ScheduleBlockNode* next_node : next_nodes) { + NextNodeHandler(next_node); + } + }; + common::DfsTopoWalker walker(VisitPreNodes, + VisitNextNodes); + std::vector starts = + is_reverse ? EndPoints() : StartPoints(); + walker(starts.begin(), starts.end(), NodeHandler); +} + +} // namespace ir +} // namespace cinn diff --git a/paddle/cinn/ir/schedule_block_graph.h b/paddle/cinn/ir/schedule_block_graph.h new file mode 100644 index 0000000000000000000000000000000000000000..c2ef0788b26436701a8bda47b69c9d194ec46a2f --- /dev/null +++ b/paddle/cinn/ir/schedule_block_graph.h @@ -0,0 +1,200 @@ +// Copyright (c) 2023 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +#include "paddle/cinn/common/graph_utils.h" +#include "paddle/cinn/hlir/framework/graph.h" +#include "paddle/cinn/ir/ir.h" +#include "paddle/cinn/ir/schedule/ir_schedule.h" +#include "paddle/cinn/ir/utils/ir_mutator.h" +#include "paddle/cinn/ir/utils/ir_printer.h" + +using Group = cinn::hlir::framework::Graph::Group; + +namespace cinn { +namespace ir { + +// Node in units of ScheduleBlock. +class ScheduleBlockNode : public common::GraphNode { + public: + ScheduleBlockNode(Expr block, const IRSchedule& ir_sch); + + // Get the id of this node, which is same as the name of ScheduleBlock. + std::string id() const { return id_; } + + // Get the ScheduleBlockRealize expr + Expr Block() const; + + // Get all control stmts containing the schedule_block, now only the For node + // is being considered. + std::vector ControlStmts() const; + + // Get all the upstream nodes that this node depends on. + std::unordered_set UpstreamNodes() const { + return upstream_nodes_; + } + + // Get all downstream nodes that depend on this node. + std::unordered_set DownstreamNodes() const { + return downstream_nodes_; + } + + // Get the producer node that this node directly depends on + std::vector Producers() const; + + // Get consumer nodes that directly depend on this node. + std::vector Consumers() const; + + void AddUpstreamNode(const std::string& node_id) { + upstream_nodes_.insert(node_id); + } + void AddDownstreamNode(const std::string& node_id) { + downstream_nodes_.insert(node_id); + } + + private: + std::vector> OrderedInLinks() const; + std::vector> OrderedOutLinks() const; + + private: + std::string id_; + std::unordered_set upstream_nodes_; + std::unordered_set downstream_nodes_; + const IRSchedule& ir_sch_; +}; + +// Graph in units of ScheduleBlockNode, each node corresponds to a ScheduleBlock +// in IR. +class ScheduleBlockGraph : public common::Graph { + public: + explicit ScheduleBlockGraph(const IRSchedule& ir_sch); + + // Update graph information according to the new IRSchedule. + void Update(const IRSchedule& ir_sch); + + // Retrieve a node in the graph by id, the id is same as the name of + // ScheduleBlock. + ScheduleBlockNode* RetrieveNode(const std::string& id) { + return dynamic_cast(common::Graph::RetrieveNode(id)); + } + + // Get all block name in order, + // this sequence may become invalid after some schedule operations, + // and an Update() operation is required. + std::list BlockIdsInOrder() const { return block_ids_in_order_; } + + // Get all nodes without input node. + std::vector StartPoints(); + + // Get all nodes without output node. + std::vector EndPoints(); + + // Function used to define the operations to be performed on each node. + using NodeHandlerType = std::function; + + // Walk through each node + // and perform some operations defined by NodeHandler on it. + void NodesWalk(const NodeHandlerType& NodeHandler); + + // Walk through each node topological dfs topo order + // and perform some operations defined by NodeHandler on it. + void DFSTopoWalk(const NodeHandlerType& NodeHandler, bool is_reverse = true); + + private: + std::list block_ids_in_order_; +}; + +/** + * The mutator used to construct the order of blocks and their control + * statements + * + * Example: + * for0: + * for1: + * block0 + * block1 + * block2 + * for2: + * block3 + * block4 + * + * the result is: + * [0]: for0 + * [0, 0]: for1 + * [0, 0, 0]: block0 + * [0, 0, 1]: block1 + * [0, 1]: block2 + * [0, 2]: for2 + * [0, 2, 0]: block3 + * [0, 2, 1]: block4 + */ +struct BlockOrderConstructor : public IRMutator { + std::map, Expr> operator()(ir::Expr* expr) { + IRMutator::Visit(expr, expr); + return block_order_with_ctrl_structure_; + } + + private: + void Visit(const For* x, Expr* op) { + if (global_idx_.empty() || + block_order_with_ctrl_structure_.rbegin()->first.size() == + global_idx_.size()) { + cur_idx_ = -1; + } + global_idx_.push_back(++cur_idx_); + block_order_with_ctrl_structure_.insert(std::make_pair(global_idx_, *op)); + IRMutator::Visit(x, op); + cur_idx_ = global_idx_.back(); + global_idx_.pop_back(); + } + + void Visit(const ScheduleBlockRealize* x, Expr* op) { + if (global_idx_.empty() || + block_order_with_ctrl_structure_.rbegin()->first.size() == + global_idx_.size()) { + cur_idx_ = -1; + } + global_idx_.push_back(++cur_idx_); + block_order_with_ctrl_structure_.insert(std::make_pair(global_idx_, *op)); + if (x->schedule_block.As()->name.substr(0, 4) == "root") { + IRMutator::Visit(x, op); + } + global_idx_.pop_back(); + } + + void Visit(const IfThenElse* x, Expr* op) { + if (global_idx_.empty() || + block_order_with_ctrl_structure_.rbegin()->first.size() == + global_idx_.size()) { + cur_idx_ = -1; + } + global_idx_.push_back(++cur_idx_); + block_order_with_ctrl_structure_.insert(std::make_pair(global_idx_, *op)); + IRMutator::Visit(x, op); + cur_idx_ = global_idx_.back(); + global_idx_.pop_back(); + } + + private: + int cur_idx_; + std::vector global_idx_; + std::map, Expr> block_order_with_ctrl_structure_; +}; + +} // namespace ir +} // namespace cinn diff --git a/paddle/cinn/ir/tensor.cc b/paddle/cinn/ir/tensor.cc index dbd056df69541503c65af4a6c72027c3ccefa1b4..2bfa6ee7737efa8057572975c81b9a99ce00e7a3 100644 --- a/paddle/cinn/ir/tensor.cc +++ b/paddle/cinn/ir/tensor.cc @@ -599,8 +599,26 @@ Shared CreateStage(Tensor tensor) { return poly::Stage::New(isl_domain, tensor->body(), tensor.self()); } +static constexpr char kReduceInitSuffix[] = "__reduce_init"; + std::string GenReduceInitTensorNameOf(const std::string &tensor_name) { - return tensor_name + "__reduce_init"; + return tensor_name + kReduceInitSuffix; +} + +bool IsReduceInitTensorName(const std::string &tensor_name) { + std::string reduce_init_suffix(kReduceInitSuffix); + return tensor_name.length() > reduce_init_suffix.size() && + tensor_name.substr(tensor_name.length() - reduce_init_suffix.size(), + reduce_init_suffix.size()) == reduce_init_suffix; +} + +std::string GetOriginalReduceTensorName(const std::string &tensor_name) { + std::string reduce_init_suffix(kReduceInitSuffix); + if (IsReduceInitTensorName(tensor_name)) { + return tensor_name.substr(0, + tensor_name.length() - reduce_init_suffix.size()); + } + return tensor_name; } bool _Tensor_::is_reduce_sum() const { diff --git a/paddle/cinn/ir/tensor.h b/paddle/cinn/ir/tensor.h index 437e0f2c5e6054ce7248f08305bed4a0152a9112..8879e35afa98df67de28daf5b3b47a4ed1c96ab9 100644 --- a/paddle/cinn/ir/tensor.h +++ b/paddle/cinn/ir/tensor.h @@ -116,6 +116,10 @@ class Tensor : public ir::IrNodeRef { */ std::string GenReduceInitTensorNameOf(const std::string& tensor_name); +bool IsReduceInitTensorName(const std::string& tensor_name); + +std::string GetOriginalReduceTensorName(const std::string& tensor_name); + class ComputeOp; class PlaceholderOp; struct ReadCacheRelation; diff --git a/paddle/cinn/ir/test/CMakeLists.txt b/paddle/cinn/ir/test/CMakeLists.txt index bef31ed067e3b82e00d7621fc14503e00df1e28d..e503f5ebfd9648a3c1db56d75d4dadcce1e5dadc 100644 --- a/paddle/cinn/ir/test/CMakeLists.txt +++ b/paddle/cinn/ir/test/CMakeLists.txt @@ -17,3 +17,5 @@ cinn_cc_test(test_ir_verify SRCS ir_verify_test.cc DEPS cinncore) cinn_cc_test(test_schedule_desc SRCS schedule_desc_test.cc DEPS cinncore) cinn_cc_test(test_ir_compare SRCS ir_compare_test.cc DEPS cinncore) cinn_cc_test(test_ir_copy SRCS ir_copy_test.cc DEPS cinncore) +cinn_cc_test(test_schedule_block_graph SRCS schedule_block_graph_test.cc DEPS + cinncore) diff --git a/paddle/cinn/ir/test/schedule_block_graph_test.cc b/paddle/cinn/ir/test/schedule_block_graph_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..52dd018ca39afb020da7f3bfa6f5f55ae4cbd3b7 --- /dev/null +++ b/paddle/cinn/ir/test/schedule_block_graph_test.cc @@ -0,0 +1,179 @@ +// Copyright (c) 2023 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/cinn/ir/schedule_block_graph.h" +#include +#include "paddle/cinn/frontend/net_builder.h" +#include "paddle/cinn/frontend/optimize.h" +#include "paddle/cinn/frontend/syntax.h" +#include "paddle/cinn/hlir/framework/op_lowering.h" +#include "paddle/cinn/ir/schedule/ir_schedule.h" + +namespace cinn { +namespace ir { + +IRSchedule MakeIRSchedule(frontend::Program* program) { +#ifdef CINN_WITH_CUDA + Target target = common::DefaultNVGPUTarget(); +#else + Target target = common::DefaultHostTarget(); +#endif + std::unordered_set fetch_ids; + auto graph = frontend::Optimize(program, fetch_ids, target); + LOG_IF(WARNING, graph->fusion_groups.size() > 1) + << "Test Graph has more than 1 group"; + auto& dtype_dict = + graph->GetMutableAttrs>( + "inferdtype"); + auto& shape_dict = graph->GetMutableAttrs< + absl::flat_hash_map>("infershape"); + hlir::framework::OpLowerer op_lowerer(dtype_dict, shape_dict, target); + + std::vector lowered_funcs = + op_lowerer.Lower(graph->fusion_groups.front(), false, false); + CHECK(!lowered_funcs.empty()) << "lowered_funcs_ is empty"; + + std::vector bodys; + for (auto&& func : lowered_funcs) { + bodys.emplace_back(func->body); + } + return IRSchedule(ModuleExpr({std::move(bodys)}), 1); +} + +std::string GetIR(const ir::IRSchedule& schedule) { + const auto& exprs = schedule.GetModule().GetExprs(); + std::stringstream module_stream; + for (auto i = 0; i < exprs.size(); ++i) { + module_stream << "Expr " << i << " {\n" + << exprs.at(i) << "\n} // end Expr " << i << "\n"; + } + return module_stream.str(); +} + +frontend::Program CreateElementwiseProgram() { + constexpr int M = 32; + constexpr int N = 24; + + frontend::NetBuilder builder("net_builder"); + auto a = builder.CreateInput(Float(32), {M, N}, "A"); + auto b = builder.CreateInput(Float(32), {M, N}, "B"); + auto c = builder.Add(a, b); + auto d = builder.Add(a, c); + auto e = builder.Relu(c); + auto f = builder.Relu(d); + auto program = builder.Build(); + + return program; +} + +frontend::Program CreateReduceProgram() { + constexpr int M = 64; + constexpr int N = 128; + + frontend::NetBuilder builder("net_builder"); + auto a = builder.CreateInput(Float(32), {M, N}, "A"); + auto b = builder.CreateInput(Float(32), {M, N}, "B"); + auto c = builder.Add(a, b); + auto d = builder.ReduceSum(c, {0}); + auto e = builder.BroadcastTo(d, {M, N}); + auto f = builder.Add(e, a); + auto program = builder.Build(); + + return program; +} + +TEST(ScheduleBlockGraph, elementwise) { + frontend::Program program = CreateElementwiseProgram(); + IRSchedule ir_sch = MakeIRSchedule(&program); + ScheduleBlockGraph sbg(ir_sch); + LOG(INFO) << GetIR(ir_sch); + LOG(INFO) << sbg.Visualize(); + CHECK_EQ(sbg.BlockIdsInOrder().size(), 6); + CHECK_EQ(sbg.nodes().size(), 6); + + ScheduleBlockNode* v2 = sbg.RetrieveNode("var_2"); + CHECK(v2); + CHECK_EQ(v2->UpstreamNodes().size(), 1); + CHECK_EQ(v2->DownstreamNodes().size(), 1); + + ScheduleBlockNode* v4 = sbg.RetrieveNode("var_4"); + CHECK(v4); + CHECK_EQ(v4->UpstreamNodes().size(), 3); + CHECK_EQ(v4->DownstreamNodes().size(), 0); + + std::vector reverse_dfs_topo_order_ids; + sbg.DFSTopoWalk([&reverse_dfs_topo_order_ids](const ScheduleBlockNode* node) { + reverse_dfs_topo_order_ids.push_back(node->id()); + }); + for (const std::string& id : reverse_dfs_topo_order_ids) { + LOG(INFO) << id; + } + CHECK_EQ(reverse_dfs_topo_order_ids.size(), 6); + + std::vector dfs_topo_order_ids; + sbg.DFSTopoWalk( + [&dfs_topo_order_ids](const ScheduleBlockNode* node) { + dfs_topo_order_ids.push_back(node->id()); + }, + false); + for (const std::string& id : dfs_topo_order_ids) { + LOG(INFO) << id; + } + CHECK_EQ(dfs_topo_order_ids.size(), 6); +} + +#ifdef CINN_WITH_CUDA +TEST(ScheduleBlockGraph, reduce) { + frontend::Program program = CreateReduceProgram(); + IRSchedule ir_sch = MakeIRSchedule(&program); + ScheduleBlockGraph sbg(ir_sch); + LOG(INFO) << GetIR(ir_sch); + LOG(INFO) << sbg.Visualize(); + CHECK_EQ(sbg.BlockIdsInOrder().size(), 8); + CHECK_EQ(sbg.nodes().size(), 8); + + ScheduleBlockNode* v_reduce_init = sbg.RetrieveNode("var_48__reduce_init"); + CHECK(v_reduce_init); + CHECK_EQ(v_reduce_init->UpstreamNodes().size(), 0); + CHECK_EQ(v_reduce_init->DownstreamNodes().size(), 3); + + ScheduleBlockNode* v = sbg.RetrieveNode("var_48"); + CHECK(v); + CHECK_EQ(v->UpstreamNodes().size(), 5); + CHECK_EQ(v->DownstreamNodes().size(), 2); + + std::vector reverse_dfs_topo_order_ids; + sbg.DFSTopoWalk([&reverse_dfs_topo_order_ids](const ScheduleBlockNode* node) { + reverse_dfs_topo_order_ids.push_back(node->id()); + }); + for (const std::string& id : reverse_dfs_topo_order_ids) { + LOG(INFO) << id; + } + CHECK_EQ(reverse_dfs_topo_order_ids.size(), 8); + + std::vector dfs_topo_order_ids; + sbg.DFSTopoWalk( + [&dfs_topo_order_ids](const ScheduleBlockNode* node) { + dfs_topo_order_ids.push_back(node->id()); + }, + false); + for (const std::string& id : dfs_topo_order_ids) { + LOG(INFO) << id; + } + CHECK_EQ(dfs_topo_order_ids.size(), 8); +} +#endif + +} // namespace ir +} // namespace cinn