未验证 提交 da72707f 编写于 作者: B BiynXu 提交者: GitHub

[CINN] Add ScheduleBlock graph (#56122)

Added a graph data structure in units of ScheduleBlock and some necessary operations, such as finding upstream and downstream nodes, and performing operations in the DFS topological order.
上级 2951521a
......@@ -24,6 +24,7 @@ gather_srcs(
message(STATUS "srcs: ${cinnapi_src}")
cinn_cc_test(test_dfs_walker SRCS dfs_walker_test.cc DEPS gtest glog)
cinn_cc_test(test_dfs_topo_walker SRCS dfs_topo_walker_test.cc DEPS gtest glog)
cinn_cc_test(test_is_reachable_predicator SRCS is_reachable_predicator_test.cc
DEPS gtest glog)
cinn_cc_test(test_topo_walker SRCS topo_walker_test.cc DEPS gtest glog)
......
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <array>
#include <functional>
#include <stack>
#include <unordered_map>
#include <unordered_set>
namespace cinn {
namespace common {
// Depth-first walker that respects topological order.
//
// A node is handed to the handler only after every predecessor reported by
// VisitPreNodes has been handled; among the nodes that are ready, the most
// recently discovered one is handled first (LIFO stack).
//
// Example:
// Graph:
// 0 -> 1
// 2 -> 3
// 0 -> 3
// 1 -> 3
// 3 -> 4
// Start nodes: 0, 2
// Walking order: 0 -> 1 -> 2 -> 3 -> 4
template <typename NodeType,
          typename NodeHash = std::hash<NodeType>,
          typename NodeEqual = std::equal_to<NodeType>>
class DfsTopoWalker final {
 public:
  DfsTopoWalker(const DfsTopoWalker&) = delete;
  DfsTopoWalker(DfsTopoWalker&&) = delete;

  // Callback invoked on a single node.
  using NodeHandlerType = std::function<void(NodeType)>;
  // Callback that enumerates the neighbours of a node by calling the given
  // handler once per neighbour.
  using NodesVisitorType =
      std::function<void(NodeType, const NodeHandlerType&)>;

  // VisitPreNodes enumerates the predecessors of a node, VisitNextNodes its
  // successors. Both callbacks are copied into the walker.
  DfsTopoWalker(const NodesVisitorType& VisitPreNodes,
                const NodesVisitorType& VisitNextNodes)
      : VisitPreNodes_(VisitPreNodes), VisitNextNodes_(VisitNextNodes) {}

  // Walks from a single start node.
  // Only nodes whose predecessors all get handled during this walk are
  // reached, so when the graph has several nodes with in-degree 0 just a part
  // of it is visited. Pass every start node to the iterator overload to cover
  // the entire graph.
  void operator()(NodeType node, const NodeHandlerType& NodeHandler) const {
    std::array<NodeType, 1> nodes{node};
    (*this)(nodes.begin(), nodes.end(), NodeHandler);
  }

  // Walks from the start nodes in [begin, end), processing them in iteration
  // order. A start node is skipped while it still has unhandled predecessors.
  // Nodes that depend on a start node that was not supplied are never
  // handled; provide all start nodes to cover the entire graph.
  template <typename NodeIt>
  void operator()(NodeIt begin,
                  NodeIt end,
                  const NodeHandlerType& NodeHandler) const {
    std::stack<NodeType> node_stack;
    std::unordered_set<NodeType, NodeHash, NodeEqual> visited;
    // Number of predecessors of a node that have not been handled yet.
    std::unordered_map<NodeType, int, NodeHash, NodeEqual> in_degree;

    // Lazily counts the predecessors of `node` the first time it is seen.
    const auto InitInDegree = [&](NodeType node) {
      if (in_degree.find(node) != in_degree.end()) {
        return;
      }
      int degree = 0;
      VisitPreNodes_(node, [&degree](NodeType) { ++degree; });
      in_degree.emplace(node, degree);
    };
    // Records that one predecessor of `node` has just been handled.
    const auto UpdateInDegree = [&](NodeType node) {
      InitInDegree(node);
      --in_degree[node];
    };
    // Schedules `node` once all of its predecessors have been handled.
    const auto TryPush = [&](NodeType node) {
      InitInDegree(node);
      if (visited.count(node) == 0 && in_degree[node] == 0) {
        node_stack.push(node);
        visited.insert(node);
      }
    };

    for (NodeIt iter = begin; iter != end; ++iter) {
      TryPush(*iter);
      while (!node_stack.empty()) {
        NodeType cur = node_stack.top();
        node_stack.pop();
        NodeHandler(cur);
        // Release all successors first, then push the ones that became ready.
        VisitNextNodes_(cur, UpdateInDegree);
        VisitNextNodes_(cur, TryPush);
      }
    }
  }

 private:
  // Declared in the same order as the constructor initializes them.
  NodesVisitorType VisitPreNodes_;
  NodesVisitorType VisitNextNodes_;
};
} // namespace common
} // namespace cinn
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <glog/logging.h>
#include <gtest/gtest.h>
#include "paddle/cinn/common/dfs_topo_walker.h"
namespace cinn {
namespace common {
// Walks the five-edge sample graph from start nodes {0, 2} and expects the
// handler to see the nodes in the order 0, 1, 2, 3, 4.
TEST(DfsTopoWalker, simple) {
  using Edge = std::pair<int, int>;
  using Handler = std::function<void(int)>;

  std::vector<Edge> edges{{0, 1}, {2, 3}, {1, 3}, {0, 3}, {3, 4}};

  // Reports every source of an edge that ends at `node`.
  const auto visit_pre_nodes = [&](int node, const Handler& NodeHandler) {
    for (const Edge& edge : edges) {
      if (edge.second != node) {
        continue;
      }
      NodeHandler(edge.first);
    }
  };
  // Reports every sink of an edge that starts at `node`.
  const auto visit_next_nodes = [&](int node, const Handler& NodeHandler) {
    for (const Edge& edge : edges) {
      if (edge.first != node) {
        continue;
      }
      NodeHandler(edge.second);
    }
  };
  DfsTopoWalker<int> walker(visit_pre_nodes, visit_next_nodes);

  std::vector<int> sources{0, 2};
  std::vector<int> outputs;
  const auto record = [&](int node) { outputs.push_back(node); };
  walker(sources.begin(), sources.end(), record);

  for (auto output : outputs) {
    LOG(INFO) << output;
  }
  std::vector<int> expected{0, 1, 2, 3, 4};
  EXPECT_TRUE((outputs == expected));
}
} // namespace common
} // namespace cinn
......@@ -14,7 +14,8 @@ gather_srcs(
module.cc
lowered_func.cc
intrinsic_ops.cc
layout.cc)
layout.cc
schedule_block_graph.cc)
add_subdirectory(op)
add_subdirectory(test)
......
......@@ -860,11 +860,18 @@ std::vector<Expr> GetProducers(const Expr& block, const Expr& root) {
auto compute_body = block.As<ir::ScheduleBlockRealize>()
->schedule_block.As<ir::ScheduleBlock>()
->body;
std::string block_name = block.As<ir::ScheduleBlockRealize>()
->schedule_block.As<ir::ScheduleBlock>()
->name;
ir::CollectIRNodesWithoutTensor(
compute_body, [&producer_tensor_names](const Expr* x) {
compute_body, [&producer_tensor_names, &block_name](const Expr* x) {
auto* load = x->As<ir::Load>();
if (load) {
producer_tensor_names.insert(load->tensor.as_tensor()->name);
if (load->tensor.as_tensor()->name == block_name) {
producer_tensor_names.insert(
GenReduceInitTensorNameOf(load->tensor.as_tensor()->name));
}
return true;
}
return false;
......@@ -896,6 +903,18 @@ std::vector<Expr> GetConsumers(const Expr& block, const Expr& root) {
CHECK(root.As<ir::ScheduleBlockRealize>());
std::vector<Expr> consumers;
std::string block_tensor = GetTensor(block)->name;
if (IsReduceInitTensorName(block_tensor)) {
std::string consumer_name = GetOriginalReduceTensorName(block_tensor);
auto consumer = ir::CollectIRNodesWithoutTensor(root, [&](const Expr* x) {
return x->As<ir::ScheduleBlockRealize>() &&
x->As<ir::ScheduleBlockRealize>()
->schedule_block.As<ir::ScheduleBlock>()
->name == consumer_name;
});
CHECK_EQ(consumer.size(), 1);
return {*consumer.begin()};
}
auto find_block = ir::CollectIRNodesWithoutTensor(root, [&](const Expr* x) {
return x->As<ir::ScheduleBlockRealize>() && *x != block && *x != root;
});
......@@ -997,10 +1016,12 @@ std::vector<IterRange> CalculateRequiredRegions(
// deduce accessed regions of the provided tensor in block by itering each
// required block
for (const Expr& pro_node : provided_nodes) {
const std::string& provided_tensor_name =
std::string provided_tensor_name =
is_store_provided ? pro_node.As<ir::Store>()->tensor.as_tensor()->name
: pro_node.As<ir::Load>()->tensor.as_tensor()->name;
if (IsReduceInitTensorName(provided_tensor_name)) {
provided_tensor_name = GetOriginalReduceTensorName(provided_tensor_name);
}
for (const Expr& req_block : required_blocks) {
CHECK(req_block.As<ir::ScheduleBlockRealize>());
Expr block_body =
......
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/ir/schedule_block_graph.h"
#include "paddle/cinn/common/dfs_topo_walker.h"
#include "paddle/cinn/ir/schedule/ir_schedule_util.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
namespace cinn {
namespace ir {
// Creates a node for `block`, which must be a ScheduleBlockRealize. The node
// id is the name of the wrapped ScheduleBlock. `ir_sch` is stored by
// reference and is used later to look the block up again by name.
ScheduleBlockNode::ScheduleBlockNode(Expr block, const IRSchedule& ir_sch)
    : ir_sch_(ir_sch) {
  const auto* realize = block.As<ScheduleBlockRealize>();
  CHECK(realize) << "Expr is not a ScheduleBlockRealize: " << block;
  const auto* schedule_block = realize->schedule_block.As<ScheduleBlock>();
  id_ = schedule_block->name;
  VLOG(5) << "create schedule_block node: " << id_;
}
// Looks the block up in the stored IRSchedule by this node's id.
Expr ScheduleBlockNode::Block() const {
  Expr realize = ir_sch_.GetBlock(id_);
  return realize;
}
// Returns the loops that the stored IRSchedule reports for this node's id.
std::vector<Expr> ScheduleBlockNode::ControlStmts() const {
  std::vector<Expr> loops = ir_sch_.GetLoops(id_);
  return loops;
}
// Strict ordering of graph edges by ascending index(). Both edges must be
// non-null.
bool EdgeCompare(const common::Shared<common::GraphEdge>& a,
                 const common::Shared<common::GraphEdge>& b) {
  CHECK_NOTNULL(a.get());
  CHECK_NOTNULL(b.get());
  const auto lhs_index = a->index();
  const auto rhs_index = b->index();
  return lhs_index < rhs_index;
}
std::vector<common::Shared<common::GraphEdge>>
ScheduleBlockNode::OrderedInLinks() const {
std::vector<common::Shared<common::GraphEdge>> ordered_links;
for (auto& in_edge : this->inlinks()) {
ordered_links.push_back(in_edge);
CHECK_GE(in_edge->index(), 0)
<< "The index of a node's inlinks should be >= 0! Now index is: "
<< in_edge->index() << ". Please check.";
}
std::sort(ordered_links.begin(), ordered_links.end(), EdgeCompare);
return ordered_links;
}
std::vector<common::Shared<common::GraphEdge>>
ScheduleBlockNode::OrderedOutLinks() const {
std::vector<common::Shared<common::GraphEdge>> ordered_links;
for (auto& out_edge : this->outlinks()) {
ordered_links.push_back(out_edge);
CHECK_GE(out_edge->index(), 0)
<< "The index of a node's outlinks should be >= 0! Now index is: "
<< out_edge->index() << ". Please check.";
}
std::sort(ordered_links.begin(), ordered_links.end(), EdgeCompare);
return ordered_links;
}
std::vector<ScheduleBlockNode*> ScheduleBlockNode::Producers() const {
std::vector<ScheduleBlockNode*> producers;
for (const auto& link : this->OrderedInLinks()) {
producers.push_back(dynamic_cast<ScheduleBlockNode*>(link->source()));
}
return producers;
}
std::vector<ScheduleBlockNode*> ScheduleBlockNode::Consumers() const {
std::vector<ScheduleBlockNode*> consumers;
for (const auto& link : this->OrderedOutLinks()) {
consumers.push_back(dynamic_cast<ScheduleBlockNode*>(link->sink()));
}
return consumers;
}
// Builds the graph right away from the blocks of `ir_sch`.
ScheduleBlockGraph::ScheduleBlockGraph(const IRSchedule& ir_sch) {
  this->Update(ir_sch);
}
void ScheduleBlockGraph::Update(const IRSchedule& ir_sch) {
nodes_.clear();
registry_.clear();
std::vector<Expr> all_blocks = ir_sch.GetAllBlocks();
Expr root_block = ir_sch.GetRootBlock(all_blocks[0]);
for (Expr block : all_blocks) {
CHECK(block.As<ScheduleBlockRealize>())
<< "Expr is not a ScheduleBlockRealize: " << block;
std::string id = block.As<ScheduleBlockRealize>()
->schedule_block.As<ScheduleBlock>()
->name;
if (id == "root") {
continue;
}
ScheduleBlockNode* node = new ScheduleBlockNode(block, ir_sch);
RegisterNode(id, node);
VLOG(5) << "register schedule_block node: " << id;
block_ids_in_order_.push_back(id);
std::vector<Expr> producers = GetProducers(block, root_block);
for (Expr producer : producers) {
CHECK(producer.As<ScheduleBlockRealize>())
<< "Expr is not a ScheduleBlockRealize: " << producer;
std::string producer_id = producer.As<ScheduleBlockRealize>()
->schedule_block.As<ScheduleBlock>()
->name;
ScheduleBlockNode* producer_node = RetrieveNode(producer_id);
CHECK(producer_node) << "producer node: " << producer_id
<< " does not exist in the graph";
producer_node->Controls(node);
for (const std::string& upstream_node_id :
producer_node->UpstreamNodes()) {
node->AddUpstreamNode(upstream_node_id);
}
node->AddUpstreamNode(producer_id);
}
for (const std::string& upstream_node_id : node->UpstreamNodes()) {
RetrieveNode(upstream_node_id)->AddDownstreamNode(id);
}
}
}
std::vector<ScheduleBlockNode*> ScheduleBlockGraph::StartPoints() {
std::vector<ScheduleBlockNode*> res;
for (common::GraphNode* node : nodes()) {
if (node->inlinks().empty()) {
res.push_back(dynamic_cast<ScheduleBlockNode*>(node));
}
}
return res;
}
std::vector<ScheduleBlockNode*> ScheduleBlockGraph::EndPoints() {
std::vector<ScheduleBlockNode*> res;
for (common::GraphNode* node : nodes()) {
if (node->outlinks().empty()) {
res.push_back(dynamic_cast<ScheduleBlockNode*>(node));
}
}
return res;
}
void ScheduleBlockGraph::NodesWalk(const NodeHandlerType& NodeHandler) {
for (common::GraphNode* node : nodes()) {
ScheduleBlockNode* cur_node = dynamic_cast<ScheduleBlockNode*>(node);
NodeHandler(cur_node);
}
}
// Visits the nodes with common::DfsTopoWalker.
// Forward walk (is_reverse == false): starts at StartPoints() and a node's
// predecessors are its producers. Reverse walk (is_reverse == true): starts
// at EndPoints() and the roles of producers and consumers are swapped.
void ScheduleBlockGraph::DFSTopoWalk(const NodeHandlerType& NodeHandler,
                                     bool is_reverse) {
  auto VisitPreNodes = [&](const ScheduleBlockNode* node,
                           const NodeHandlerType& PreNodeHandler) {
    if (is_reverse) {
      for (ScheduleBlockNode* consumer : node->Consumers()) {
        PreNodeHandler(consumer);
      }
    } else {
      for (ScheduleBlockNode* producer : node->Producers()) {
        PreNodeHandler(producer);
      }
    }
  };
  auto VisitNextNodes = [&](const ScheduleBlockNode* node,
                            const NodeHandlerType& NextNodeHandler) {
    if (is_reverse) {
      for (ScheduleBlockNode* producer : node->Producers()) {
        NextNodeHandler(producer);
      }
    } else {
      for (ScheduleBlockNode* consumer : node->Consumers()) {
        NextNodeHandler(consumer);
      }
    }
  };

  std::vector<ScheduleBlockNode*> starts;
  if (is_reverse) {
    starts = EndPoints();
  } else {
    starts = StartPoints();
  }

  common::DfsTopoWalker<ScheduleBlockNode*> walker(VisitPreNodes,
                                                   VisitNextNodes);
  walker(starts.begin(), starts.end(), NodeHandler);
}
} // namespace ir
} // namespace cinn
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <list>
#include <stack>
#include "paddle/cinn/common/graph_utils.h"
#include "paddle/cinn/hlir/framework/graph.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/ir/utils/ir_mutator.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
using Group = cinn::hlir::framework::Graph::Group;
namespace cinn {
namespace ir {
// A graph node that stands for one ScheduleBlock of the IR.
class ScheduleBlockNode : public common::GraphNode {
 public:
  // `block` must be a ScheduleBlockRealize. `ir_sch` is kept by reference, so
  // it has to outlive this node.
  ScheduleBlockNode(Expr block, const IRSchedule& ir_sch);

  // Id of this node; identical to the name of the ScheduleBlock.
  std::string id() const { return id_; }

  // The ScheduleBlockRealize expr of this node.
  Expr Block() const;

  // Control statements that contain the schedule_block; currently only For
  // nodes are taken into account.
  std::vector<Expr> ControlStmts() const;

  // Ids of all upstream nodes this node depends on (returned as a copy).
  std::unordered_set<std::string> UpstreamNodes() const {
    return upstream_nodes_;
  }

  // Ids of all downstream nodes that depend on this node (returned as a copy).
  std::unordered_set<std::string> DownstreamNodes() const {
    return downstream_nodes_;
  }

  // Nodes this node directly depends on.
  std::vector<ScheduleBlockNode*> Producers() const;

  // Nodes that directly depend on this node.
  std::vector<ScheduleBlockNode*> Consumers() const;

  // Records `node_id` as an upstream node of this node.
  void AddUpstreamNode(const std::string& node_id) {
    upstream_nodes_.insert(node_id);
  }

  // Records `node_id` as a downstream node of this node.
  void AddDownstreamNode(const std::string& node_id) {
    downstream_nodes_.insert(node_id);
  }

 private:
  // Incoming / outgoing edges of this node, sorted.
  std::vector<common::Shared<common::GraphEdge>> OrderedInLinks() const;
  std::vector<common::Shared<common::GraphEdge>> OrderedOutLinks() const;

  std::string id_;
  std::unordered_set<std::string> upstream_nodes_;
  std::unordered_set<std::string> downstream_nodes_;
  const IRSchedule& ir_sch_;
};
// Graph in units of ScheduleBlockNode, each node corresponds to a ScheduleBlock
// in IR.
class ScheduleBlockGraph : public common::Graph {
public:
explicit ScheduleBlockGraph(const IRSchedule& ir_sch);
// Update graph information according to the new IRSchedule.
void Update(const IRSchedule& ir_sch);
// Retrieve a node in the graph by id, the id is same as the name of
// ScheduleBlock.
ScheduleBlockNode* RetrieveNode(const std::string& id) {
return dynamic_cast<ScheduleBlockNode*>(common::Graph::RetrieveNode(id));
}
// Get all block name in order,
// this sequence may become invalid after some schedule operations,
// and an Update() operation is required.
std::list<std::string> BlockIdsInOrder() const { return block_ids_in_order_; }
// Get all nodes without input node.
std::vector<ScheduleBlockNode*> StartPoints();
// Get all nodes without output node.
std::vector<ScheduleBlockNode*> EndPoints();
// Function used to define the operations to be performed on each node.
using NodeHandlerType = std::function<void(ScheduleBlockNode*)>;
// Walk through each node
// and perform some operations defined by NodeHandler on it.
void NodesWalk(const NodeHandlerType& NodeHandler);
// Walk through each node topological dfs topo order
// and perform some operations defined by NodeHandler on it.
void DFSTopoWalk(const NodeHandlerType& NodeHandler, bool is_reverse = true);
private:
std::list<std::string> block_ids_in_order_;
};
/**
 * The mutator used to construct the order of blocks and their control
 * statements.
 *
 * Every visited For, IfThenElse and ScheduleBlockRealize is recorded under a
 * key that is its index path: one index per nesting level.
 *
 * Example:
 * for0:
 *   for1:
 *     block0
 *     block1
 *   block2
 *   for2:
 *     block3
 *     block4
 *
 * the result is:
 * [0]: for0
 * [0, 0]: for1
 * [0, 0, 0]: block0
 * [0, 0, 1]: block1
 * [0, 1]: block2
 * [0, 2]: for2
 * [0, 2, 0]: block3
 * [0, 2, 1]: block4
 */
struct BlockOrderConstructor : public IRMutator<Expr*> {
  // Visits `expr` and returns a copy of the recorded index-path -> Expr map.
  std::map<std::vector<int>, Expr> operator()(ir::Expr* expr) {
    IRMutator::Visit(expr, expr);
    return block_order_with_ctrl_structure_;
  }

 private:
  // Assigns the next index path to `*op` and records it in the map.
  // The sibling counter restarts when no path is open yet, or when the
  // largest recorded key is as long as the open path; the latter appears to
  // identify the first child of the most recently recorded node.
  void PushIndexAndRecord(Expr* op) {
    if (global_idx_.empty() ||
        block_order_with_ctrl_structure_.rbegin()->first.size() ==
            global_idx_.size()) {
      cur_idx_ = -1;
    }
    global_idx_.push_back(++cur_idx_);
    block_order_with_ctrl_structure_.insert(std::make_pair(global_idx_, *op));
  }

  // Closes a control statement: restores the sibling counter to the
  // statement's own index and drops its level from the open path.
  void PopControlIndex() {
    cur_idx_ = global_idx_.back();
    global_idx_.pop_back();
  }

  void Visit(const For* x, Expr* op) {
    PushIndexAndRecord(op);
    IRMutator<Expr*>::Visit(x, op);
    PopControlIndex();
  }

  // Only blocks whose name starts with "root" are descended into; every other
  // block is recorded as a leaf.
  void Visit(const ScheduleBlockRealize* x, Expr* op) {
    PushIndexAndRecord(op);
    if (x->schedule_block.As<ScheduleBlock>()->name.substr(0, 4) == "root") {
      IRMutator<Expr*>::Visit(x, op);
    }
    global_idx_.pop_back();
  }

  void Visit(const IfThenElse* x, Expr* op) {
    PushIndexAndRecord(op);
    IRMutator<Expr*>::Visit(x, op);
    PopControlIndex();
  }

 private:
  // Index of the most recently numbered sibling at the current level.
  // Initialized so that it never holds an indeterminate value.
  int cur_idx_ = -1;
  // Index path of the statements that are currently open.
  std::vector<int> global_idx_;
  std::map<std::vector<int>, Expr> block_order_with_ctrl_structure_;
};
} // namespace ir
} // namespace cinn
......@@ -599,8 +599,26 @@ Shared<poly::Stage> CreateStage(Tensor tensor) {
return poly::Stage::New(isl_domain, tensor->body(), tensor.self());
}
// Suffix appended to a reduce tensor's name to form the name of its init
// tensor.
static constexpr char kReduceInitSuffix[] = "__reduce_init";

// Returns `tensor_name` with the reduce-init suffix appended.
std::string GenReduceInitTensorNameOf(const std::string &tensor_name) {
  return tensor_name + kReduceInitSuffix;
}

// Returns true when `tensor_name` ends with the reduce-init suffix and has at
// least one character in front of it; the bare suffix is not a match.
bool IsReduceInitTensorName(const std::string &tensor_name) {
  // sizeof counts the terminating '\0', which is not part of the suffix.
  const std::string::size_type suffix_len = sizeof(kReduceInitSuffix) - 1;
  if (tensor_name.length() <= suffix_len) {
    return false;
  }
  return tensor_name.compare(tensor_name.length() - suffix_len,
                             suffix_len,
                             kReduceInitSuffix) == 0;
}

// Strips the reduce-init suffix from `tensor_name`; a name that is not a
// reduce-init tensor name is returned unchanged.
std::string GetOriginalReduceTensorName(const std::string &tensor_name) {
  if (!IsReduceInitTensorName(tensor_name)) {
    return tensor_name;
  }
  const std::string::size_type suffix_len = sizeof(kReduceInitSuffix) - 1;
  return tensor_name.substr(0, tensor_name.length() - suffix_len);
}
bool _Tensor_::is_reduce_sum() const {
......
......@@ -116,6 +116,10 @@ class Tensor : public ir::IrNodeRef {
*/
std::string GenReduceInitTensorNameOf(const std::string& tensor_name);
// Returns true if `tensor_name` ends with the "__reduce_init" suffix and has
// at least one character before that suffix.
bool IsReduceInitTensorName(const std::string& tensor_name);
// Returns `tensor_name` with a trailing "__reduce_init" suffix removed; names
// that are not reduce-init tensor names are returned unchanged.
std::string GetOriginalReduceTensorName(const std::string& tensor_name);
class ComputeOp;
class PlaceholderOp;
struct ReadCacheRelation;
......
......@@ -17,3 +17,5 @@ cinn_cc_test(test_ir_verify SRCS ir_verify_test.cc DEPS cinncore)
cinn_cc_test(test_schedule_desc SRCS schedule_desc_test.cc DEPS cinncore)
cinn_cc_test(test_ir_compare SRCS ir_compare_test.cc DEPS cinncore)
cinn_cc_test(test_ir_copy SRCS ir_copy_test.cc DEPS cinncore)
cinn_cc_test(test_schedule_block_graph SRCS schedule_block_graph_test.cc DEPS
cinncore)
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/ir/schedule_block_graph.h"
#include <gtest/gtest.h>
#include "paddle/cinn/frontend/net_builder.h"
#include "paddle/cinn/frontend/optimize.h"
#include "paddle/cinn/frontend/syntax.h"
#include "paddle/cinn/hlir/framework/op_lowering.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
namespace cinn {
namespace ir {
// Optimizes `program`, lowers the first fusion group of the resulting graph
// and wraps the bodies of the lowered functions in an IRSchedule.
// The target is the default NVGPU target when CINN_WITH_CUDA is defined and
// the default host target otherwise.
IRSchedule MakeIRSchedule(frontend::Program* program) {
#ifdef CINN_WITH_CUDA
  Target target = common::DefaultNVGPUTarget();
#else
  Target target = common::DefaultHostTarget();
#endif
  std::unordered_set<std::string> fetch_ids;
  auto graph = frontend::Optimize(program, fetch_ids, target);
  // Only the first fusion group is lowered below.
  const bool has_multiple_groups = graph->fusion_groups.size() > 1;
  LOG_IF(WARNING, has_multiple_groups) << "Test Graph has more than 1 group";

  using DtypeDict = absl::flat_hash_map<std::string, common::Type>;
  using ShapeDict = absl::flat_hash_map<std::string, hlir::framework::shape_t>;
  auto& dtype_dict = graph->GetMutableAttrs<DtypeDict>("inferdtype");
  auto& shape_dict = graph->GetMutableAttrs<ShapeDict>("infershape");

  hlir::framework::OpLowerer op_lowerer(dtype_dict, shape_dict, target);
  std::vector<LoweredFunc> lowered_funcs =
      op_lowerer.Lower(graph->fusion_groups.front(), false, false);
  CHECK(!lowered_funcs.empty()) << "lowered_funcs_ is empty";

  std::vector<Expr> bodys;
  for (auto& lowered_func : lowered_funcs) {
    bodys.emplace_back(lowered_func->body);
  }
  return IRSchedule(ModuleExpr({std::move(bodys)}), 1);
}
// Renders every expr of the schedule's module as text, wrapping the i-th expr
// in "Expr i { ... } // end Expr i".
std::string GetIR(const ir::IRSchedule& schedule) {
  std::stringstream module_stream;
  const auto& exprs = schedule.GetModule().GetExprs();
  int expr_id = 0;
  for (const auto& expr : exprs) {
    module_stream << "Expr " << expr_id << " {\n"
                  << expr << "\n} // end Expr " << expr_id << "\n";
    ++expr_id;
  }
  return module_stream.str();
}
// Builds a program over two 32x24 float32 inputs A and B:
// c = A + B, d = A + c, relu(c), relu(d).
frontend::Program CreateElementwiseProgram() {
  constexpr int M = 32;
  constexpr int N = 24;

  frontend::NetBuilder builder("net_builder");
  auto input_a = builder.CreateInput(Float(32), {M, N}, "A");
  auto input_b = builder.CreateInput(Float(32), {M, N}, "B");
  auto sum_ab = builder.Add(input_a, input_b);
  auto sum_a_ab = builder.Add(input_a, sum_ab);
  auto relu_first = builder.Relu(sum_ab);
  auto relu_second = builder.Relu(sum_a_ab);
  return builder.Build();
}
// Builds a program over two 64x128 float32 inputs A and B:
// c = A + B, d = reduce_sum(c, {0}), e = broadcast_to(d, {64, 128}),
// f = e + A.
frontend::Program CreateReduceProgram() {
  constexpr int M = 64;
  constexpr int N = 128;

  frontend::NetBuilder builder("net_builder");
  auto input_a = builder.CreateInput(Float(32), {M, N}, "A");
  auto input_b = builder.CreateInput(Float(32), {M, N}, "B");
  auto sum_ab = builder.Add(input_a, input_b);
  auto reduced = builder.ReduceSum(sum_ab, {0});
  auto broadcasted = builder.BroadcastTo(reduced, {M, N});
  auto result = builder.Add(broadcasted, input_a);
  return builder.Build();
}
// Builds the graph of the elementwise program and checks the node count, the
// upstream/downstream set sizes of var_2 and var_4, and that both the reverse
// (default) and the forward DFS topological walk visit all 6 nodes.
TEST(ScheduleBlockGraph, elementwise) {
  frontend::Program program = CreateElementwiseProgram();
  IRSchedule ir_sch = MakeIRSchedule(&program);
  ScheduleBlockGraph sbg(ir_sch);
  LOG(INFO) << GetIR(ir_sch);
  LOG(INFO) << sbg.Visualize();

  CHECK_EQ(sbg.BlockIdsInOrder().size(), 6);
  CHECK_EQ(sbg.nodes().size(), 6);

  ScheduleBlockNode* v2 = sbg.RetrieveNode("var_2");
  CHECK(v2);
  CHECK_EQ(v2->UpstreamNodes().size(), 1);
  CHECK_EQ(v2->DownstreamNodes().size(), 1);

  ScheduleBlockNode* v4 = sbg.RetrieveNode("var_4");
  CHECK(v4);
  CHECK_EQ(v4->UpstreamNodes().size(), 3);
  CHECK_EQ(v4->DownstreamNodes().size(), 0);

  const auto log_ids = [](const std::vector<std::string>& ids) {
    for (const std::string& id : ids) {
      LOG(INFO) << id;
    }
  };

  // Walk with the default direction of DFSTopoWalk.
  std::vector<std::string> reverse_dfs_topo_order_ids;
  sbg.DFSTopoWalk([&](const ScheduleBlockNode* node) {
    reverse_dfs_topo_order_ids.push_back(node->id());
  });
  log_ids(reverse_dfs_topo_order_ids);
  CHECK_EQ(reverse_dfs_topo_order_ids.size(), 6);

  // Walk with is_reverse == false.
  std::vector<std::string> dfs_topo_order_ids;
  const bool is_reverse = false;
  sbg.DFSTopoWalk(
      [&](const ScheduleBlockNode* node) {
        dfs_topo_order_ids.push_back(node->id());
      },
      is_reverse);
  log_ids(dfs_topo_order_ids);
  CHECK_EQ(dfs_topo_order_ids.size(), 6);
}
#ifdef CINN_WITH_CUDA
// Builds the graph of the reduce program and checks the node count, the
// upstream/downstream set sizes of var_48 and of its reduce-init block, and
// that both the reverse (default) and the forward DFS topological walk visit
// all 8 nodes.
TEST(ScheduleBlockGraph, reduce) {
  frontend::Program program = CreateReduceProgram();
  IRSchedule ir_sch = MakeIRSchedule(&program);
  ScheduleBlockGraph sbg(ir_sch);
  LOG(INFO) << GetIR(ir_sch);
  LOG(INFO) << sbg.Visualize();

  CHECK_EQ(sbg.BlockIdsInOrder().size(), 8);
  CHECK_EQ(sbg.nodes().size(), 8);

  ScheduleBlockNode* v_reduce_init = sbg.RetrieveNode("var_48__reduce_init");
  CHECK(v_reduce_init);
  CHECK_EQ(v_reduce_init->UpstreamNodes().size(), 0);
  CHECK_EQ(v_reduce_init->DownstreamNodes().size(), 3);

  ScheduleBlockNode* v = sbg.RetrieveNode("var_48");
  CHECK(v);
  CHECK_EQ(v->UpstreamNodes().size(), 5);
  CHECK_EQ(v->DownstreamNodes().size(), 2);

  const auto log_ids = [](const std::vector<std::string>& ids) {
    for (const std::string& id : ids) {
      LOG(INFO) << id;
    }
  };

  // Walk with the default direction of DFSTopoWalk.
  std::vector<std::string> reverse_dfs_topo_order_ids;
  sbg.DFSTopoWalk([&](const ScheduleBlockNode* node) {
    reverse_dfs_topo_order_ids.push_back(node->id());
  });
  log_ids(reverse_dfs_topo_order_ids);
  CHECK_EQ(reverse_dfs_topo_order_ids.size(), 8);

  // Walk with is_reverse == false.
  std::vector<std::string> dfs_topo_order_ids;
  const bool is_reverse = false;
  sbg.DFSTopoWalk(
      [&](const ScheduleBlockNode* node) {
        dfs_topo_order_ids.push_back(node->id());
      },
      is_reverse);
  log_ids(dfs_topo_order_ids);
  CHECK_EQ(dfs_topo_order_ids.size(), 8);
}
#endif
} // namespace ir
} // namespace cinn
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册