Commit 61a0bddf authored by Li Xinqi, committed by GitHub

Dev mem sharing for variable op (#1604)

* pseudo chains of OpGraph

* ConvertPseudoChainToChain

* refine pseudo_chain

* refine register coloring algorithm

* rename op_graph log file name

* remove unused code

* EnableMemSharingInVariableOp

* no mem_sharing for out_diff & model_diff in variable_op


Former-commit-id: 224bb63a0576e7da8929ab829dfcf7c709398a97
Parent a5581ba4
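In rough terms, TaskGraph::EnableMemSharingInVariableOp (first hunk below) points the variable op's produced "out" regst at the same shared memory block as its consumed "model" regst, offset by the model blob's position inside the packed regst body, so the forward kernel no longer needs to copy. A minimal stand-alone sketch of that bookkeeping; every constant here is invented for illustration and is not part of the diff:

#include <cassert>
#include <cstdint>

int main() {
  const int32_t model_mem_shared_id = 7;       // stands in for Global<IDMgr>::Get()->NewMemSharedId()
  const int64_t model_mem_shared_offset = 0;   // model regst sits at the head of the shared block
  const int64_t model_blob_body_offset = 256;  // stands in for ByteOffsetInPackedBlobDescBody(model_lbi)

  // What EnableMemSharingInVariableOp records on the "out" regst:
  const int32_t out_mem_shared_id = model_mem_shared_id;
  const int64_t out_mem_shared_offset = model_mem_shared_offset + model_blob_body_offset;

  assert(out_mem_shared_id == 7);
  assert(out_mem_shared_offset == 256);
  return 0;
}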
......@@ -413,6 +413,33 @@ void TaskGraph::EnableMemSharingAfterAllManualSetForMdUpdt() {
});
}
void TaskGraph::EnableMemSharingInVariableOp() {
ForEachNode([&](TaskNode* node) {
if (node->exec_gph().node_num() != 1) { return; }
auto* variable_op = dynamic_cast<const VariableOp*>(node->exec_gph().SoleNode()->op().get());
if (variable_op == nullptr) { return; }
std::string model_bn = variable_op->op_conf().variable_conf().model_name();
auto* fw_task_node = dynamic_cast<NormalForwardCompTaskNode*>(node);
if (fw_task_node) {
const LogicalBlobId& lbi = variable_op->BnInOp2Lbi(model_bn);
RegstDesc* model_regst = fw_task_node->GetSoleConsumedRegst("model").get();
if (model_regst->enable_mem_sharing() == false) {
model_regst->set_enable_mem_sharing(true);
model_regst->set_mem_shared_id(Global<IDMgr>::Get()->NewMemSharedId());
model_regst->set_mem_shared_offset(0);
}
RegstDesc* out_regst = fw_task_node->GetProducedRegst("out").get();
CHECK_EQ(out_regst->NumOfLbi(), 1);
out_regst->set_enable_mem_sharing(true);
out_regst->set_mem_shared_id(model_regst->mem_shared_id());
out_regst->set_mem_shared_offset(model_regst->mem_shared_offset()
+ model_regst->ByteOffsetInPackedBlobDescBody(lbi));
} else {
// do nothing
}
});
}
void TaskGraph::RmUselessConsumeRelationshipBetweenFwBw() {
for (TaskNode* task_node : ordered_task_nodes_) {
auto bw_node = dynamic_cast<NormalBackwardCompTaskNode*>(task_node);
......
......@@ -26,6 +26,7 @@ class TaskGraph final : public Graph<TaskNode, TaskEdge> {
void EnableMemSharingInReduceStruct();
void EnableMemSharingAfterAllManualSetForMdUpdt();
void EnableMemSharingInVariableOp();
void AddOrderCtrlEdgeBetweenCopyAndMdUpdt();
void RmUselessConsumeRelationshipBetweenFwBw();
......
......@@ -117,6 +117,9 @@ Plan Compiler::DoCompile() {
if (job_desc->IsTrain()) { task_gph->AddOrderCtrlEdgeBetweenCopyAndMdUpdt(); }
if (job_desc->IsTrain()) { task_gph->RmUselessConsumeRelationshipBetweenFwBw(); }
task_gph->MdUpdtDelayedTopoForEachNode(&TaskNode::InferTimeShapeIfMeaningful);
if (job_desc->IsTrain() && job_desc->enable_mem_sharing()) {
task_gph->EnableMemSharingInVariableOp();
}
if (job_desc->IsTrain()) { task_gph->AddReduceNoBwForwardNodeOverlapingCtrlEdges(); }
Plan plan;
......
......@@ -9,7 +9,11 @@ void VariableKernel<device_type, T>::ForwardDataContent(
Blob* out_blob = BnInOp2Blob("out");
if ((this->op_conf().trainable() && *tick_ % Global<JobDesc>::Get()->NumOfPiecesInBatch() == 0)
|| (this->op_conf().trainable() == false && *tick_ == 0)) {
out_blob->CopyDataContentFrom(ctx.device_ctx, model_blob);
if (Global<JobDesc>::Get()->enable_mem_sharing()) {
CHECK_EQ(out_blob->dptr(), model_blob->dptr());
} else {
out_blob->CopyDataContentFrom(ctx.device_ctx, model_blob);
}
} else {
// do nothing
}
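Under that sharing, "out" and "model" resolve to the same device pointer, so the old CopyDataContentFrom would copy a buffer onto itself; the kernel now only checks that the two pointers coincide. A minimal host-side sketch of the invariant, in plain C++ with no OneFlow types, purely illustrative:

#include <cassert>
#include <vector>

int main() {
  std::vector<float> shared_block(64, 1.0f);    // stands in for the shared regst memory
  const float* model_dptr = shared_block.data();
  const float* out_dptr = shared_block.data();  // aliased by EnableMemSharingInVariableOp
  assert(out_dptr == model_dptr);               // mirrors CHECK_EQ(out_blob->dptr(), model_blob->dptr())
  return 0;
}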
......@@ -20,8 +24,9 @@ template<DeviceType device_type, typename T>
void VariableKernel<device_type, T>::BackwardDataContent(
const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
CHECK(this->op_conf().trainable());
BnInOp2Blob(GenDiffBn(ModelName()))
->CopyDataContentFrom(ctx.device_ctx, BnInOp2Blob(GenDiffBn("out")));
const Blob* out_diff_blob = BnInOp2Blob(GenDiffBn("out"));
Blob* model_diff_blob = BnInOp2Blob(GenDiffBn(ModelName()));
model_diff_blob->CopyDataContentFrom(ctx.device_ctx, out_diff_blob);
}
template<DeviceType device_type, typename T>
......
......@@ -262,4 +262,9 @@ std::unique_ptr<BlobDesc> ComputePackedBlobDesc(
return ret;
}
bool CompareLbiBlobDescPair(const LbiBlobDescPair& lhs, const LbiBlobDescPair& rhs) {
  const int32_t lhs_mem_id = lhs.blob_desc().header().blob_mem_id();
  const int32_t rhs_mem_id = rhs.blob_desc().header().blob_mem_id();
  // Compare blob_mem_id first, then lbi, so std::sort gets a strict weak ordering.
  if (lhs_mem_id != rhs_mem_id) { return lhs_mem_id < rhs_mem_id; }
  return lhs.lbi() < rhs.lbi();
}
} // namespace oneflow
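A hedged sketch of the ordering this comparator is meant to establish (FakePair and all values below are invented stand-ins, not OneFlow types): sorting places blobs with equal blob_mem_id next to each other and then orders by lbi, which is what the offset walk in RtRegstDesc::ForEachBlobDescOffsetInOnRegst relies on.

#include <algorithm>
#include <cstdint>
#include <string>
#include <vector>

struct FakePair {
  int32_t blob_mem_id;
  std::string lbi;
};

// Same rule as CompareLbiBlobDescPair: blob_mem_id first, then lbi.
bool Compare(const FakePair& lhs, const FakePair& rhs) {
  if (lhs.blob_mem_id != rhs.blob_mem_id) { return lhs.blob_mem_id < rhs.blob_mem_id; }
  return lhs.lbi < rhs.lbi;
}

int main() {
  std::vector<FakePair> pairs = {{-1, "c"}, {3, "a"}, {3, "b"}, {-1, "d"}};
  std::sort(pairs.begin(), pairs.end(), Compare);
  // Result: (-1,"c"), (-1,"d"), (3,"a"), (3,"b") -- blobs sharing mem id 3 end up adjacent.
  return 0;
}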
......@@ -7,6 +7,7 @@
#include "oneflow/core/register/blob_desc.pb.h"
#include "oneflow/core/register/pod_desc.h"
#include "oneflow/core/job/job_desc.h"
#include "oneflow/core/register/register_desc.pb.h"
namespace oneflow {
......@@ -98,6 +99,8 @@ class BlobDesc {
std::unique_ptr<BlobDesc> ComputePackedBlobDesc(
const HashMap<LogicalBlobId, std::unique_ptr<BlobDesc>>& lbi2blob_desc);
bool CompareLbiBlobDescPair(const LbiBlobDescPair& lhs, const LbiBlobDescPair& rhs);
} // namespace oneflow
#endif // ONEFLOW_CORE_REGISTER_BLOB_DESC_H_
......@@ -81,23 +81,25 @@ class StructPodDesc final : public PodDesc {
~StructPodDesc() = default;
StructPodDesc* MutStructField(const FieldId& field_id);
StructPodDesc* MutStructField(const FieldId& field_id, int32_t default_alignment);
const PodDesc& Field(FieldKey field_key) const { return Field(NewFieldId(field_key)); }
const PodDesc& Field(const FieldId& field_id) const;
void AddField(FieldKey field_key, const PodDesc& pod_desc);
void AddField(const FieldId& field_id, const PodDesc& pod_desc);
size_t ByteSize() const override;
void InitFromProto(const StructPodProto& struct_pod);
void AddField(const FieldId& field_id, const PodDesc& pod_desc, size_t alignment);
bool HasField(FieldKey field_key) const { return HasField(NewFieldId(field_key)); }
bool HasField(const FieldId& field_id) const;
StructPodDesc& operator=(const StructPodDesc&);
std::unique_ptr<PodDesc> Clone() const override { return std::make_unique<StructPodDesc>(*this); }
void InitFromProto(const StructPodProto& struct_pod);
void ToProto(PodProto* pod_proto) const override { ToProto(pod_proto->mutable_struct_pod()); }
void ToProto(StructPodProto* pod_proto) const;
StructPodDesc* MutStructField(const FieldId& field_id, int32_t default_alignment);
void AddField(const FieldId& field_id, const PodDesc& pod_desc, size_t alignment);
bool operator==(const PodDesc& rhs) const override;
size_t ByteOffset4Field(const FieldId& field_name) const;
size_t ByteSize() const override;
StructPodDesc& operator=(const StructPodDesc&);
bool operator==(const PodDesc& rhs) const override;
private:
void Clear();
......
......@@ -3,6 +3,7 @@
#include "oneflow/core/graph/copy_task_node.h"
#include "oneflow/core/job/id_manager.h"
#include "oneflow/core/register/runtime_blob_desc.h"
#include "oneflow/core/register/runtime_register_desc.h"
namespace oneflow {
......@@ -27,6 +28,11 @@ RegstDesc::RegstDesc() {
mem_shared_offset_ = -1;
}
int64_t RegstDesc::mem_shared_offset() const {
CHECK_GE(mem_shared_offset_, 0);
return mem_shared_offset_;
}
void RegstDesc::AddConsumer(const TaskNode* new_consumer) {
CHECK(consumers_.insert(new_consumer).second);
}
......@@ -164,6 +170,34 @@ bool RegstDesc::HasSameBlobDescs(const RegstDesc* rhs) {
return true;
}
int64_t RegstDesc::ByteOffsetInPackedBlobDescBody(const LogicalBlobId& lbi) const {
RegstDescProto regst_desc_proto;
ToProto(&regst_desc_proto);
RtRegstDesc rt_regst_desc(regst_desc_proto);
std::vector<LbiBlobDescPair> lbi_blob_desc_pairs;
for (const auto& pair : lbi2blob_desc_) {
LbiBlobDescPair lbi_blob_desc_pair;
*lbi_blob_desc_pair.mutable_lbi() = pair.first;
pair.second->ToProto(lbi_blob_desc_pair.mutable_blob_desc());
lbi_blob_desc_pairs.push_back(lbi_blob_desc_pair);
}
std::sort(lbi_blob_desc_pairs.begin(), lbi_blob_desc_pairs.end(), CompareLbiBlobDescPair);
bool found = false;
int64_t offset = 0;
rt_regst_desc.ForEachBlobDescOffsetInOnRegst(
lbi_blob_desc_pairs,
[&](const LbiBlobDescPair& lbi_blob_desc_pair, int64_t body_offset, int64_t header_offset) {
if (found) { return; }
if (lbi_blob_desc_pair.lbi() == lbi) {
offset = body_offset;
found = true;
}
});
CHECK(found);
return offset;
}
void InitCtrlRegstDesc(int64_t producer_task_id, RegstDescProto* ctrl_regst_proto) {
CHECK_NOTNULL(ctrl_regst_proto);
ctrl_regst_proto->set_regst_desc_id(Global<IDMgr>::Get()->NewRegstDescId());
......
......@@ -50,8 +50,9 @@ class RegstDesc final {
// mem
const MemoryCase& mem_case() const { return mem_case_; }
MemoryCase* mut_mem_case() { return &mem_case_; }
bool enable_mem_sharing() { return enable_mem_sharing_; }
void set_enable_mem_sharing(bool enable_mem_sharing) { enable_mem_sharing_ = enable_mem_sharing; }
int64_t mem_shared_offset() const { return mem_shared_offset_; }
int64_t mem_shared_offset() const;
void set_mem_shared_offset(int64_t val) { mem_shared_offset_ = val; }
int32_t mem_shared_id() const { return mem_shared_id_; }
void set_mem_shared_id(int32_t val) { mem_shared_id_ = val; }
......@@ -76,6 +77,7 @@ class RegstDesc final {
void EraseZeroSizeBlob();
void ToProto(RegstDescProto*) const;
bool HasSameBlobDescs(const RegstDesc*);
int64_t ByteOffsetInPackedBlobDescBody(const LogicalBlobId& lbi) const;
private:
int64_t regst_desc_id_;
......
......@@ -75,12 +75,7 @@ void RegstMgr::NewRegsts(const RegstDescProto& regst_desc_proto,
for (const LbiBlobDescPair& pair : regst_desc_type.data_regst_desc().lbi2blob_desc()) {
lbi_pairs.push_back(pair);
}
std::sort(lbi_pairs.begin(), lbi_pairs.end(),
[&](const LbiBlobDescPair& lhs, const LbiBlobDescPair& rhs) {
return lhs.blob_desc().header().blob_mem_id()
< rhs.blob_desc().header().blob_mem_id()
|| lhs.lbi() < rhs.lbi();
});
std::sort(lbi_pairs.begin(), lbi_pairs.end(), &CompareLbiBlobDescPair);
CHECK(!lbi_pairs.empty());
CHECK(main_mem_ptr != nullptr);
}
......@@ -126,23 +121,14 @@ void RegstMgr::NewBlobsInOneRegst(const std::vector<LbiBlobDescPair>& lbis, Regs
cur_header_pointer = main_mem_ptr;
cur_body_pointer = main_mem_ptr + packed_blob_desc->ByteSizeOfBlobHeader();
}
int32_t last_blob_mem_id = -1;
size_t last_size = 0;
for (const LbiBlobDescPair& lbi : lbis) {
const RtBlobDesc* blob_desc = rt_regst_desc->GetRtBlobDescFromLbi(lbi.lbi());
int32_t cur_blob_mem_id = lbi.blob_desc().header().blob_mem_id();
if (cur_blob_mem_id == -1 || cur_blob_mem_id != last_blob_mem_id) {
cur_body_pointer += last_size;
}
std::unique_ptr<Blob> blob_ptr(
new Blob(regst, blob_desc, cur_header_pointer, cur_body_pointer));
InitOFRecordBlobIfNeed(blob_ptr.get());
CHECK(regst->lbi2blob_.emplace(lbi.lbi(), std::move(blob_ptr)).second);
cur_header_pointer += blob_desc->ByteSizeOfBlobHeader();
last_blob_mem_id = cur_blob_mem_id;
last_size = blob_desc->ByteSizeOfBlobBody();
}
rt_regst_desc->ForEachBlobDescOffsetInOnRegst(
lbis, [&](const LbiBlobDescPair& lbi, int64_t body_offset, int64_t header_offset) {
const RtBlobDesc* blob_desc = rt_regst_desc->GetRtBlobDescFromLbi(lbi.lbi());
std::unique_ptr<Blob> blob_ptr(new Blob(
regst, blob_desc, cur_header_pointer + header_offset, cur_body_pointer + body_offset));
InitOFRecordBlobIfNeed(blob_ptr.get());
CHECK(regst->lbi2blob_.emplace(lbi.lbi(), std::move(blob_ptr)).second);
});
}
void RegstMgr::InitOFRecordBlobIfNeed(Blob* blob_ptr) {
......
......@@ -70,4 +70,25 @@ const Shape& RtRegstDesc::data_regst_time_shape() const {
return *data_regst_time_shape_;
}
void RtRegstDesc::ForEachBlobDescOffsetInOnRegst(
const std::vector<LbiBlobDescPair>& lbis,
const std::function<void(const LbiBlobDescPair&, int64_t body_offset, int64_t header_offset)>&
Handler) const {
int32_t last_blob_mem_id = -1;
size_t last_size = 0;
int64_t cur_body_offset = 0;
int64_t cur_header_offset = 0;
for (const LbiBlobDescPair& lbi : lbis) {
const RtBlobDesc* blob_desc = GetRtBlobDescFromLbi(lbi.lbi());
int32_t cur_blob_mem_id = lbi.blob_desc().header().blob_mem_id();
if (cur_blob_mem_id == -1 || cur_blob_mem_id != last_blob_mem_id) {
cur_body_offset += last_size;
}
Handler(lbi, cur_body_offset, cur_header_offset);
cur_header_offset += blob_desc->ByteSizeOfBlobHeader();
last_blob_mem_id = cur_blob_mem_id;
last_size = blob_desc->ByteSizeOfBlobBody();
}
}
} // namespace oneflow
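For reference, a stand-alone re-trace of the offset rule that RtRegstDesc::ForEachBlobDescOffsetInOnRegst implements above (blob names, mem ids, and sizes are invented): blobs carrying the same blob_mem_id reuse one body offset, a blob_mem_id of -1 always gets a fresh one, and header offsets stay densely packed.

#include <cstdint>
#include <cstdio>
#include <vector>

struct FakeBlob {
  const char* name;
  int32_t blob_mem_id;  // -1 means "never shares a body offset"
  int64_t body_size;
  int64_t header_size;
};

int main() {
  std::vector<FakeBlob> blobs = {{"A", 0, 64, 16}, {"B", 0, 64, 16}, {"C", -1, 32, 16}};
  int32_t last_blob_mem_id = -1;
  int64_t last_size = 0;
  int64_t body_offset = 0;
  int64_t header_offset = 0;
  for (const FakeBlob& blob : blobs) {
    // Only advance the body offset when the mem id changes (or is -1), exactly as in the loop above.
    if (blob.blob_mem_id == -1 || blob.blob_mem_id != last_blob_mem_id) { body_offset += last_size; }
    std::printf("%s body_offset=%lld header_offset=%lld\n", blob.name, (long long)body_offset,
                (long long)header_offset);
    header_offset += blob.header_size;
    last_blob_mem_id = blob.blob_mem_id;
    last_size = blob.body_size;
  }
  // Prints: A 0/0, B 0/16 (same mem id, body offset reused), C 64/32 (skips past the shared 64-byte body).
  return 0;
}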
......@@ -31,6 +31,11 @@ class RtRegstDesc {
size_t MainByteSize4OneRegst() const;
const Shape& data_regst_time_shape() const;
void ForEachBlobDescOffsetInOnRegst(
const std::vector<LbiBlobDescPair>& lbis,
const std::function<void(const LbiBlobDescPair&, int64_t body_offset, int64_t header_offset)>&
Handler) const;
private:
int64_t regst_desc_id_;
int64_t producer_actor_id_;
......