Commit d72a21e2 authored by Li Xinqi, committed by GitHub

Strategy-based register coloring (#1613)

* mem_shared_hint_id

* sharable memory block

* rm useless code

* remove useless code

* bugfix: no redundant edges

* rename: MemBlockGroup => MemBlock

* put constructor of SharableMemBlockNode into header file

* bugfix

* rename field: MemBlock.block_id => MemBlock.mem_block_id


Former-commit-id: 6a8fc14c2ba6bbe148a84458fa6119af16cbe672
Parent 0b58c5d8
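The commit title refers to coloring registers by lifetime: register descriptions whose lifetimes never overlap may receive the same color and therefore reuse one memory block. Below is a minimal, self-contained sketch of that greedy idea with hypothetical types; it is not the OneFlow RegstLifetimeGraph implementation, whose per-chain strategy appears in the hunks that follow.

#include <cstdint>
#include <unordered_set>
#include <vector>

using ActorIdSet = std::unordered_set<int64_t>;

bool LifetimesIntersect(const ActorIdSet& a, const ActorIdSet& b) {
  for (int64_t id : a) {
    if (b.count(id) > 0) { return true; }
  }
  return false;
}

// Greedy coloring: give each register the first color whose current members
// all have lifetimes disjoint from it; registers sharing a color can share
// one memory block.
std::vector<int> ColorRegstsByLifetime(const std::vector<ActorIdSet>& lifetimes) {
  std::vector<int> color(lifetimes.size(), -1);
  std::vector<std::vector<size_t>> color2members;
  for (size_t i = 0; i < lifetimes.size(); ++i) {
    for (size_t c = 0; c < color2members.size() && color[i] == -1; ++c) {
      bool disjoint = true;
      for (size_t j : color2members[c]) {
        if (LifetimesIntersect(lifetimes[i], lifetimes[j])) { disjoint = false; break; }
      }
      if (disjoint) {
        color[i] = static_cast<int>(c);
        color2members[c].push_back(i);
      }
    }
    if (color[i] == -1) {
      color[i] = static_cast<int>(color2members.size());
      color2members.push_back({i});
    }
  }
  return color;
}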
......@@ -56,8 +56,6 @@ class RegstLifetimeGraph final : public Graph<const RegstLifetimeNode, RegstLife
const std::function<void(const RegstDescProto*, HashSet<int64_t>*)>& ComputeLifetimeActorIds,
std::list<RegstLifetimeNode*>* nodes);
void InitEdges(const std::list<RegstLifetimeNode*>& nodes);
HashMap<const RegstLifetimeNode*, HashSet<const RegstLifetimeNode*>>
regst_lifetime_node2intersected_nodes_;
};
} // namespace oneflow
......
#include "oneflow/core/graph/sharable_mem_block_graph.h"
#include "oneflow/core/register/register_desc.h"
#include "oneflow/core/register/runtime_register_desc.h"
namespace oneflow {
namespace {
bool IsConsumersAndProducerInSameChain(const RegstDescProto& regst_desc,
const PlanTaskGraph& plan_task_graph) {
auto ChainId4TaskId = [&](int64_t task_id) {
return plan_task_graph.TaskProto4TaskId(task_id)->task_set_info().chain_id();
};
int64_t producer_chain_id = ChainId4TaskId(regst_desc.producer_task_id());
for (int64_t consumer_task_id : regst_desc.consumer_task_id()) {
if (ChainId4TaskId(consumer_task_id) != producer_chain_id) { return false; }
}
return true;
}
} // namespace
SharableMemBlockGraph::SharableMemBlockGraph(
const PlanTaskGraph& plan_task_gph,
const std::function<bool(const RegstDescProto&)>& IsSharable) {
auto ForEachSharableChainRegstDesc =
[&](const std::function<void(int64_t, const RegstDescProto&)>& Handler) {
for (const TaskProto& task : plan_task_gph.plan().task()) {
for (const auto& pair : task.produced_regst_desc()) {
if (IsConsumersAndProducerInSameChain(pair.second, plan_task_gph)
&& IsSharable(pair.second)) {
Handler(task.task_set_info().chain_id(), pair.second);
}
}
}
};
HashMap<std::pair<int64_t, MemBlock>, HashSet<const RegstDescProto*>>
chain_id7mem_block2regst_descs;
HashSet<int64_t> mem_block_ids_check;
ForEachSharableChainRegstDesc([&](int64_t chain_id, const RegstDescProto& regst_desc) {
int32_t idx = 0;
for (const auto& mem_block : regst_desc.mem_block_hierarchy()) {
if (idx++ == 0) { CHECK(mem_block_ids_check.emplace(mem_block.mem_block_id()).second); }
auto& regst_descs = chain_id7mem_block2regst_descs[std::make_pair(chain_id, mem_block)];
CHECK(regst_descs.emplace(&regst_desc).second);
}
});
HashMap<std::pair<int64_t, MemBlock>, SharableMemBlockNode*> chain_id7mem_block2node;
for (const auto& pair : chain_id7mem_block2regst_descs) {
auto* node =
new SharableMemBlockNode(pair.first.first, pair.first.second, pair.second, plan_task_gph);
AddAllocatedNode(node);
CHECK(chain_id7mem_block2node.emplace(pair.first, node).second);
}
HashSet<const SharableMemBlockNode*> connected_children;
ForEachSharableChainRegstDesc([&](int64_t chain_id, const RegstDescProto& regst_desc) {
SharableMemBlockNode* child = nullptr;
for (const auto& mem_block : regst_desc.mem_block_hierarchy()) {
auto* parent = chain_id7mem_block2node.at(std::make_pair(chain_id, mem_block));
if (child != nullptr && connected_children.find(child) == connected_children.end()) {
Connect(parent, NewEdge(), child);
CHECK(connected_children.emplace(child).second);
}
child = parent;
}
});
}
void SharableMemBlockGraph::ForEachSourceNodeGroup(
const std::function<int64_t(const SharableMemBlockNode*)>& GroupBy,
const std::function<void(const std::vector<const SharableMemBlockNode*>&)>& Handler) const {
HashMap<int64_t, std::vector<const SharableMemBlockNode*>> chain_id2source_nodes;
for (const SharableMemBlockNode* source : source_nodes()) {
chain_id2source_nodes[GroupBy(source)].push_back(source);
}
for (const auto& pair : chain_id2source_nodes) { Handler(pair.second); }
}
} // namespace oneflow
#ifndef ONEFLOW_CORE_GRAPH_SHARABLE_MEM_BLOCK_GRAPH_H_
#define ONEFLOW_CORE_GRAPH_SHARABLE_MEM_BLOCK_GRAPH_H_
#include "oneflow/core/graph/graph.h"
#include "oneflow/core/register/register_desc.pb.h"
#include "oneflow/core/graph/plan_task_graph.h"
namespace oneflow {
class SharableMemBlockEdge;
class SharableMemBlockNode final : public Node<SharableMemBlockNode, SharableMemBlockEdge> {
public:
OF_DISALLOW_COPY_AND_MOVE(SharableMemBlockNode);
SharableMemBlockNode(int64_t chain_id, const MemBlock& mem_block,
const HashSet<const RegstDescProto*>& regst_descs,
const PlanTaskGraph& plan_task_graph)
: chain_id_(chain_id),
mem_block_(mem_block),
regst_descs_(regst_descs.begin(), regst_descs.end()) {}
~SharableMemBlockNode() = default;
int64_t chain_id() const { return chain_id_; }
const std::vector<const RegstDescProto*>& regst_descs() const { return regst_descs_; }
const MemBlock& mem_block() const { return mem_block_; }
private:
const int64_t chain_id_;
const MemBlock mem_block_;
const std::vector<const RegstDescProto*> regst_descs_;
};
class SharableMemBlockEdge final : public Edge<SharableMemBlockNode, SharableMemBlockEdge> {
public:
OF_DISALLOW_COPY_AND_MOVE(SharableMemBlockEdge);
SharableMemBlockEdge() = default;
~SharableMemBlockEdge() = default;
};
class SharableMemBlockGraph final : public Graph<const SharableMemBlockNode, SharableMemBlockEdge> {
public:
OF_DISALLOW_COPY_AND_MOVE(SharableMemBlockGraph);
SharableMemBlockGraph(const PlanTaskGraph& plan_task_gph,
const std::function<bool(const RegstDescProto&)>& IsSharable);
~SharableMemBlockGraph() = default;
void ForEachSourceNodeGroup(
const std::function<int64_t(const SharableMemBlockNode*)>& GroupBy,
const std::function<void(const std::vector<const SharableMemBlockNode*>&)>& Handler) const;
};
} // namespace oneflow
#endif // ONEFLOW_CORE_GRAPH_SHARABLE_MEM_BLOCK_GRAPH_H_
......@@ -464,6 +464,33 @@ void TaskGraph::EnableMemSharingInVariableOp() {
});
}
void TaskGraph::EnableInplaceMemSharing() {
AcyclicTopoForEachNode([&](TaskNode* node) {
if (node->exec_gph().node_num() != 1) { return; }
const Operator* op = node->exec_gph().SoleNode()->op().get();
auto* fw_task_node = dynamic_cast<NormalForwardCompTaskNode*>(node);
auto* bw_task_node = dynamic_cast<NormalBackwardCompTaskNode*>(node);
RegstDesc* input_regst = nullptr;
RegstDesc* output_regst = nullptr;
if (op->IsForwardInplace() && fw_task_node) {
input_regst = fw_task_node->GetSoleConsumedRegst("in").get();
output_regst = fw_task_node->GetProducedRegst("out").get();
} else if (op->IsBackwardInplace() && bw_task_node) {
input_regst = bw_task_node->GetSoleConsumedRegst(GenDiffBn("out")).get();
output_regst = bw_task_node->GetProducedRegst(GenDiffBn("in")).get();
} else {
// do nothing
return;
}
if (input_regst->NumOfLbi() != 1) { return; }
if (output_regst->NumOfLbi() != 1) { return; }
if (input_regst->mem_shared_inplace_block_id() == -1) {
input_regst->set_mem_shared_inplace_block_id(Global<IDMgr>::Get()->NewMemBlockId());
}
output_regst->set_mem_shared_inplace_block_id(input_regst->mem_shared_inplace_block_id());
});
}
void TaskGraph::RmUselessConsumeRelationshipBetweenFwBw() {
for (TaskNode* task_node : ordered_task_nodes_) {
auto bw_node = dynamic_cast<NormalBackwardCompTaskNode*>(task_node);
......
......@@ -27,6 +27,7 @@ class TaskGraph final : public Graph<TaskNode, TaskEdge> {
void EnableMemSharingInReduceStruct();
void EnableMemSharingAfterAllManualSetForMdUpdt();
void EnableMemSharingInVariableOp();
void EnableInplaceMemSharing();
void AddOrderCtrlEdgeBetweenCopyAndMdUpdt();
void RmUselessConsumeRelationshipBetweenFwBw();
......
......@@ -114,6 +114,7 @@ Plan Compiler::DoCompile() {
task_gph->EnableMemSharingInReduceStruct();
task_gph->EnableMemSharingAfterAllManualSetForMdUpdt(); // must last mem shared manual set
}
task_gph->EnableInplaceMemSharing();
if (job_desc->IsTrain()) { task_gph->AddOrderCtrlEdgeBetweenCopyAndMdUpdt(); }
if (job_desc->IsTrain()) { task_gph->RmUselessConsumeRelationshipBetweenFwBw(); }
task_gph->MdUpdtDelayedTopoForEachNode(&TaskNode::InferTimeShapeIfMeaningful);
......
......@@ -108,6 +108,7 @@ IDMgr::IDMgr() {
CHECK_LT(gpu_device_num_ + cpu_device_num_, (static_cast<int64_t>(1) << thread_id_bit_num_) - 3);
regst_desc_id_count_ = 0;
mem_shared_id_count_ = 0;
mem_block_id_count_ = 0;
}
int64_t IDMgr::GetMachineThrdId(int64_t machine_id, int64_t thrd_id) {
......
......@@ -28,6 +28,7 @@ class IDMgr final {
int64_t NewTaskId(int64_t machine_id, int64_t thrd_id, int64_t local_work_stream_id);
int64_t NewRegstDescId() { return regst_desc_id_count_++; }
int64_t NewMemSharedId() { return mem_shared_id_count_++; }
int64_t NewMemBlockId() { return mem_block_id_count_++; }
// MemZoneId
int64_t CpuMemZoneId() const { return Global<JobDesc>::Get()->GpuDeviceNum(); }
......@@ -74,6 +75,7 @@ class IDMgr final {
int64_t cpu_device_num_;
int64_t regst_desc_id_count_;
int64_t mem_shared_id_count_;
int64_t mem_block_id_count_;
HashMap<int64_t, int64_t> machine_thrd_id2num_of_tasks_;
HashMap<int64_t, int64_t> machine_thrd_id2stream_id_cnt_;
HashMap<int64_t, int64_t> stream_id2chain_cnt_;
......
......@@ -7,6 +7,7 @@
#include "oneflow/core/job/profiler.h"
#include "oneflow/core/graph/plan_task_graph.h"
#include "oneflow/core/graph/regst_lifetime_graph.h"
#include "oneflow/core/graph/sharable_mem_block_graph.h"
#include "oneflow/core/actor/act_event_logger.h"
namespace oneflow {
......@@ -27,13 +28,6 @@ bool IsConsumersAndProducerInSameChain(const RegstDescProto& regst_desc,
return true;
}
bool IsSharableRegstWithConsumer(const RegstDescProto& regst_desc,
const std::function<int64_t(int64_t)>& ChainId4TaskId) {
return regst_desc.mem_shared_id() == -1 && regst_desc.consumer_task_id_size() > 0
&& regst_desc.enable_mem_sharing() && regst_desc.register_num() == 1
&& IsConsumersAndProducerInSameChain(regst_desc, ChainId4TaskId);
}
void ForEachSharableStreamRegstDescsWithoutConsumer(
const Plan& plan, const std::function<void(const std::list<const RegstDescProto*>&)>& Handler) {
HashMap<int64_t, std::list<const RegstDescProto*>> global_work_stream_id2regst_descs;
......@@ -50,27 +44,6 @@ void ForEachSharableStreamRegstDescsWithoutConsumer(
}
}
void ForEachSharableChainRegstDescsWithConsumer(
const Plan& plan, const std::function<int64_t(int64_t)>& ChainId4TaskId,
const std::function<void(const std::list<const RegstDescProto*>&)>& Handler) {
HashMap<int64_t, std::list<const TaskProto*>> chain_id2task_proto;
for (const TaskProto& task : plan.task()) {
chain_id2task_proto[task.task_set_info().chain_id()].push_back(&task);
}
for (const auto& chain_tasks_pair : chain_id2task_proto) {
if (chain_tasks_pair.second.size() == 1) { continue; }
std::list<const RegstDescProto*> regst_descs;
for (const TaskProto* task : chain_tasks_pair.second) {
for (const auto& pair : task->produced_regst_desc()) {
if (IsSharableRegstWithConsumer(pair.second, ChainId4TaskId)) {
regst_descs.push_back(&pair.second);
}
}
}
if (regst_descs.size() > 1) { Handler(regst_descs); }
}
}
void ForEachSameColoredStreamRegstDescWithoutConsumer(
const Plan& plan, const std::function<void(const std::list<const RegstDescProto*>&)>& Handler) {
auto GetProducerTaskId = [](const RegstDescProto* regst_desc, HashSet<int64_t>* ret_actor_ids) {
......@@ -86,20 +59,70 @@ void ForEachSameColoredStreamRegstDescWithoutConsumer(
void ForEachSameColoredChainRegstDescWithConsumer(
const PlanTaskGraph& plan_task_graph,
const std::function<void(const std::list<const RegstDescProto*>&)>& Handler) {
// construct SharableMemBlockGraph
auto ChainId4TaskId = [&](int64_t task_id) {
return plan_task_graph.TaskProto4TaskId(task_id)->task_set_info().chain_id();
};
auto IsSharableRegstWithConsumer = [&](const RegstDescProto& regst_desc) {
return regst_desc.mem_shared_id() == -1 && regst_desc.consumer_task_id_size() > 0
&& regst_desc.enable_mem_sharing() && regst_desc.register_num() == 1
&& IsConsumersAndProducerInSameChain(regst_desc, ChainId4TaskId);
};
SharableMemBlockGraph sharable_mem_block_gph(plan_task_graph, IsSharableRegstWithConsumer);
sharable_mem_block_gph.ForEachNode([&](const SharableMemBlockNode* sharable_mem_block) {
CHECK_EQ(sharable_mem_block->mem_block().mem_reduce_method(), MemReduceMethod::kMemMax);
});
// group regst_descs for pre-colored regst_descs.
// example:
  // given dlnet: A -> B -> C -> D -> E -> F -> H -> I, where D is an inplace op.
  // Regst(C) and Regst(D) are pre-colored with the same color as a group, which
// then shares memory with other regsts like A, B, E, ...
HashMap<const RegstDescProto*, std::vector<const RegstDescProto*>> header2members;
for (const SharableMemBlockNode* sharable_mem_block : sharable_mem_block_gph.source_nodes()) {
auto regst_descs = sharable_mem_block->regst_descs();
HashMap<const RegstDescProto*, size_t> regst_desc2mem_size;
for (const RegstDescProto* regst_desc : regst_descs) {
size_t size = RtRegstDesc(*regst_desc).TotalMainByteSize4AllRegst();
CHECK(regst_desc2mem_size.emplace(regst_desc, size).second);
}
std::sort(regst_descs.begin(), regst_descs.end(),
[&](const RegstDescProto* lhs, const RegstDescProto* rhs) {
return regst_desc2mem_size.at(lhs) > regst_desc2mem_size.at(rhs);
});
header2members.emplace(regst_descs.at(0), regst_descs);
}
auto GetRegstDescs = [&](const std::vector<const SharableMemBlockNode*>& sharable_mem_blocks) {
std::list<const RegstDescProto*> ret;
for (const SharableMemBlockNode* sharable_mem_block : sharable_mem_blocks) {
for (const RegstDescProto* regst_desc : sharable_mem_block->regst_descs()) {
if (header2members.find(regst_desc) != header2members.end()) {
ret.push_back(regst_desc);
break;
}
}
}
return ret;
};
auto ComputeLifetimeSameChainActorIds = [&](const RegstDescProto* regst_desc,
HashSet<int64_t>* ret_actor_ids) {
CHECK(regst_desc->enable_mem_sharing());
ret_actor_ids->clear();
plan_task_graph.ComputeLifetimeSameChainActorIds(regst_desc, ret_actor_ids);
for (const RegstDescProto* member : header2members.at(regst_desc)) {
plan_task_graph.ComputeLifetimeSameChainActorIds(member, ret_actor_ids);
}
};
auto ChainId4TaskId = [&](int64_t task_id) {
return plan_task_graph.TaskProto4TaskId(task_id)->task_set_info().chain_id();
auto AppendGroupMembers = [&](const std::list<const RegstDescProto*>& regst_descs) {
std::list<const RegstDescProto*> members;
for (const auto* header : regst_descs) {
for (const auto* member : header2members.at(header)) { members.push_back(member); }
}
Handler(members);
};
const Plan& plan = plan_task_graph.plan();
ForEachSharableChainRegstDescsWithConsumer(
plan, ChainId4TaskId, [&](const std::list<const RegstDescProto*>& regst_descs) {
RegstLifetimeGraph(regst_descs, ComputeLifetimeSameChainActorIds)
.ForEachSameColoredRegstDescs(Handler);
sharable_mem_block_gph.ForEachSourceNodeGroup(
&SharableMemBlockNode::chain_id,
[&](const std::vector<const SharableMemBlockNode*>& sharable_mem_blocks) {
RegstLifetimeGraph(GetRegstDescs(sharable_mem_blocks), ComputeLifetimeSameChainActorIds)
.ForEachSameColoredRegstDescs(AppendGroupMembers);
});
}
......
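The hunk above replaces the flat per-chain grouping with a two-level scheme: for each source SharableMemBlockNode, the largest register is chosen as the group header, the header stands in for the whole group in the lifetime graph, and its effective lifetime is the union of all members' lifetimes. A self-contained sketch of that header selection and lifetime merging, with hypothetical types rather than the OneFlow API:

#include <algorithm>
#include <cstdint>
#include <unordered_set>
#include <vector>

struct Regst {
  int64_t id;
  size_t byte_size;
  std::unordered_set<int64_t> lifetime_actor_ids;
};

struct Group {
  const Regst* header;                          // largest member, used for coloring
  std::vector<const Regst*> members;            // all members, including the header
  std::unordered_set<int64_t> merged_lifetime;  // union of member lifetimes
};

// Pick the largest register as header and merge lifetimes, mirroring the
// header2members / ComputeLifetimeSameChainActorIds logic above.
Group MakeGroup(std::vector<const Regst*> members) {
  std::sort(members.begin(), members.end(),
            [](const Regst* lhs, const Regst* rhs) { return lhs->byte_size > rhs->byte_size; });
  Group group{members.front(), members, {}};
  for (const Regst* regst : members) {
    group.merged_lifetime.insert(regst->lifetime_actor_ids.begin(),
                                 regst->lifetime_actor_ids.end());
  }
  return group;
}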
......@@ -45,6 +45,8 @@ class Operator {
virtual bool NeedOutBlobWhenBackward() const { return true; }
bool NeedInBlobWhenBackwardIf() const { return NeedInBlobWhenBackward(); }
virtual bool NeedInBlobWhenBackward() const { return true; }
virtual bool IsForwardInplace() const { return false; }
virtual bool IsBackwardInplace() const { return false; }
// bn_in_op <-> lbi
const LogicalBlobId& BnInOp2Lbi(const std::string& bn_in_op) const;
......
......@@ -16,6 +16,8 @@ class ReshapeOp final : public Operator {
bool IsElemWiseOp() const override { return true; }
bool NeedInBlobWhenBackward() const override { return false; }
bool NeedOutBlobWhenBackward() const override { return false; }
bool IsForwardInplace() const override { return true; }
bool IsBackwardInplace() const override { return true; }
void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx) const override;
......
......@@ -26,6 +26,7 @@ RegstDesc::RegstDesc() {
enable_mem_sharing_ = false;
mem_shared_id_ = -1;
mem_shared_offset_ = -1;
mem_shared_inplace_block_id_ = -1;
}
int64_t RegstDesc::mem_shared_offset() const {
......@@ -153,6 +154,10 @@ void RegstDesc::ToProto(RegstDescProto* ret) const {
ret->set_enable_mem_sharing(enable_mem_sharing_);
ret->set_mem_shared_id(mem_shared_id_);
ret->set_mem_shared_offset(mem_shared_offset_);
ret->add_mem_block_hierarchy()->set_mem_block_id(Global<IDMgr>::Get()->NewMemBlockId());
if (mem_shared_inplace_block_id_ != -1) {
ret->add_mem_block_hierarchy()->set_mem_block_id(mem_shared_inplace_block_id_);
}
}
bool RegstDesc::HasSameMemSize(const RegstDesc* rhs) {
......
......@@ -54,6 +54,8 @@ class RegstDesc final {
void set_enable_mem_sharing(bool enable_mem_sharing) { enable_mem_sharing_ = enable_mem_sharing; }
int64_t mem_shared_offset() const;
void set_mem_shared_offset(int64_t val) { mem_shared_offset_ = val; }
int64_t mem_shared_inplace_block_id() const { return mem_shared_inplace_block_id_; }
void set_mem_shared_inplace_block_id(int64_t val) { mem_shared_inplace_block_id_ = val; }
int32_t mem_shared_id() const { return mem_shared_id_; }
void set_mem_shared_id(int32_t val) { mem_shared_id_ = val; }
bool HasSetMemSharedId() { return mem_shared_id_ != -1; }
......@@ -95,10 +97,28 @@ class RegstDesc final {
bool enable_mem_sharing_;
int32_t mem_shared_id_;
int64_t mem_shared_offset_;
int64_t mem_shared_inplace_block_id_;
std::shared_ptr<Shape> data_regst_time_shape_;
};
inline bool operator==(const MemBlock& lhs, const MemBlock& rhs) {
bool ret = (lhs.mem_block_id() == rhs.mem_block_id());
if (ret) { CHECK_EQ(lhs.mem_reduce_method(), rhs.mem_reduce_method()); }
return ret;
}
} // namespace oneflow
namespace std {
template<>
struct hash<oneflow::MemBlock> final {
size_t operator()(const oneflow::MemBlock& mem_block) const {
return hash<int64_t>()(mem_block.mem_block_id());
}
};
} // namespace std
#endif // ONEFLOW_CORE_REGISTER_REGISTER_DESC_H_
......@@ -29,6 +29,17 @@ message RegstDescTypeProto {
}
}
enum MemReduceMethod {
kMemInvalidSharedMethod = 0;
kMemSum = 1;
kMemMax = 2;
}
message MemBlock {
required int64 mem_block_id = 1;
optional MemReduceMethod mem_reduce_method = 2 [default = kMemMax];
}
message RegstDescProto {
required int64 regst_desc_id = 1;
required int64 producer_task_id = 2;
......@@ -41,4 +52,6 @@ message RegstDescProto {
required bool enable_mem_sharing = 9;
required int32 mem_shared_id = 10;
required int64 mem_shared_offset = 11;
// from bottom to top
repeated MemBlock mem_block_hierarchy = 13;
}
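The mem_block_hierarchy field lists the blocks a register belongs to from the innermost (bottom) to the outermost (top); RegstDesc::ToProto above emits a fresh per-register block first and then the shared inplace block, if any. SharableMemBlockGraph turns consecutive entries into parent-to-child edges. A standalone sketch of that edge construction, using plain integer block ids in place of the MemBlock proto:

#include <cstdint>
#include <set>
#include <utility>
#include <vector>

// hierarchy[0] is the bottom (per-register) block; later entries are
// progressively higher shared blocks, matching the "from bottom to top"
// comment on mem_block_hierarchy.
std::set<std::pair<int64_t, int64_t>> BuildParentChildEdges(
    const std::vector<std::vector<int64_t>>& hierarchies) {
  std::set<std::pair<int64_t, int64_t>> edges;  // (parent, child), deduplicated
  for (const auto& hierarchy : hierarchies) {
    for (size_t i = 0; i + 1 < hierarchy.size(); ++i) {
      edges.emplace(hierarchy[i + 1], hierarchy[i]);  // higher block is the parent
    }
  }
  return edges;
}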