提交 a49e738f 编写于 作者: X Xinqi Li

split registers into two catagory: splited or cloned


Former-commit-id: bd8f19466c91b5d21d10657371e493fe51ff9bd9
上级 62ad745a
......@@ -283,5 +283,32 @@ DemoChainGraph::CalcChainRegstId2PathChainNodeIds(
return ret;
}
std::vector<std::vector<int64_t>> DemoChainGraph::SplitedRegstIds() const {
std::vector<std::vector<int64_t>> ret;
for (const auto& regst : regsts_) {
if (!regst->IsRegstCloned()) {
ret.push_back(std::vector<int64_t>{regst->chain_regst_id()});
}
}
return ret;
}
std::vector<std::vector<int64_t>> DemoChainGraph::ClonedRegstIds() const {
std::vector<std::vector<int64_t>> ret;
for (const auto& regst : regsts_) {
if (regst->IsRegstCloned()) {
ret.push_back(std::vector<int64_t>{regst->chain_regst_id()});
}
}
return ret;
}
bool DemoChainRegst::IsRegstCloned() const {
return producer()->task_type() == TaskType::kMdDiffAcc
|| producer()->task_type() == TaskType::kMdUpdt
|| (consumers().size() == 1
&& consumers().front()->task_type() == TaskType::kMdDiffAcc);
}
} // namespace df
} // namespace oneflow
......@@ -24,10 +24,12 @@ class DemoChainRegst final {
void HandleDiff(DemoChainRegst* regst) const { diff_handler_(regst); }
bool IsRegstCloned() const;
// Getters
int64_t chain_regst_id() const { return chain_regst_id_; }
const DemoChainNode* producer() const { return producer_; }
const std::unordered_set<const DemoChainNode*>& consumers() const {
const std::list<const DemoChainNode*>& consumers() const {
return consumers_;
}
......@@ -38,14 +40,14 @@ class DemoChainRegst final {
diff_handler_ = diff_handler;
}
void AddConsumer(const DemoChainNode* consumer) {
consumers_.insert(consumer);
consumers_.push_back(consumer);
}
private:
DemoChainNode* producer_;
int64_t chain_regst_id_;
std::function<void(DemoChainRegst*)> diff_handler_;
std::unordered_set<const DemoChainNode*> consumers_;
std::list<const DemoChainNode*> consumers_;
};
class DemoChainEdge;
......@@ -80,10 +82,9 @@ class DemoChainNode final : public Node<DemoChainNode, DemoChainEdge> {
class DemoChainNodeSubGraph final {
public:
OF_DISALLOW_COPY_AND_MOVE(DemoChainNodeSubGraph);
DemoChainNodeSubGraph(
const DemoChainNode* start_node,
const std::unordered_set<const DemoChainNode*>& end_nodes,
const IsReachablePredicator& is_reachable)
DemoChainNodeSubGraph(const DemoChainNode* start_node,
const std::list<const DemoChainNode*>& end_nodes,
const IsReachablePredicator& is_reachable)
: start_node_(start_node),
end_nodes_(end_nodes),
is_reachable_(&is_reachable) {}
......@@ -108,7 +109,7 @@ class DemoChainNodeSubGraph final {
bool IsReachableToEndNode(const DemoChainNode* node) const;
const DemoChainNode* start_node_;
std::unordered_set<const DemoChainNode*> end_nodes_;
std::list<const DemoChainNode*> end_nodes_;
const IsReachablePredicator* is_reachable_;
};
......@@ -143,6 +144,9 @@ class DemoChainGraph final : public Graph<DemoChainNode, DemoChainEdge> {
[](int64_t) -> double { return 1; });
}
std::vector<std::vector<int64_t>> SplitedRegstIds() const;
std::vector<std::vector<int64_t>> ClonedRegstIds() const;
private:
friend class DemoChainGraphBuilder;
void InitIsReachable();
......
......@@ -24,6 +24,13 @@ TEST(DemoChainGraph, simple_without_model) {
std::vector<std::vector<int64_t>> expected_path{
{0, 2, 3, 5}, {1, 2, 3, 4}, {2, 3}, {2, 3}, {3, 4}, {3, 5}};
ASSERT_TRUE(graph.CalcChainRegstId2PathChainNodeIds() == expected_path);
std::vector<std::vector<int64_t>> expected_splited_regst_ids{{0}, {1}, {2},
{3}, {4}, {5}};
ASSERT_TRUE(graph.SplitedRegstIds() == expected_splited_regst_ids);
std::vector<std::vector<int64_t>> expected_cloned_regst_ids{};
ASSERT_TRUE(graph.ClonedRegstIds() == expected_cloned_regst_ids);
}
TEST(DemoChainGraph, simple_with_model) {
......@@ -45,6 +52,14 @@ TEST(DemoChainGraph, simple_with_model) {
std::vector<std::vector<int64_t>> expected_path{
{0, 1, 2, 5}, {4, 1, 2}, {1, 2}, {1, 2}, {2, 3}, {3, 4}, {2, 5}};
ASSERT_TRUE(graph.CalcChainRegstId2PathChainNodeIds() == expected_path);
std::vector<std::vector<int64_t>> expected_splited_regst_ids{
{0}, {2}, {3}, {6}};
ASSERT_TRUE(graph.SplitedRegstIds() == expected_splited_regst_ids);
std::vector<std::vector<int64_t>> expected_cloned_regst_ids{{1}, {4}, {5}};
ASSERT_TRUE(graph.ClonedRegstIds() == expected_cloned_regst_ids);
}
} // namespace test
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册