From 903cf3f8ee100ec32965a1fbf9b570fad4600e03 Mon Sep 17 00:00:00 2001 From: willzhang4a58 Date: Sun, 28 May 2017 13:56:46 +0800 Subject: [PATCH] refine interface for register --- oneflow/memory/memory_allocator.cpp | 8 ++- oneflow/memory/memory_case.proto | 16 +++-- oneflow/register/blob.h | 6 +- oneflow/register/register.h | 7 ++- oneflow/register/register_manager.cpp | 75 ++---------------------- oneflow/register/register_manager.h | 14 +---- oneflow/register/runtime_register_desc.h | 32 ++++++++++ oneflow/runtime/elf_runner.cpp | 1 - 8 files changed, 60 insertions(+), 99 deletions(-) create mode 100644 oneflow/register/runtime_register_desc.h diff --git a/oneflow/memory/memory_allocator.cpp b/oneflow/memory/memory_allocator.cpp index 74f45bb983..08c8fef242 100644 --- a/oneflow/memory/memory_allocator.cpp +++ b/oneflow/memory/memory_allocator.cpp @@ -4,6 +4,8 @@ namespace oneflow { std::pair> MemoryAllocator::Allocate( MemoryCase mem_case,std::size_t size) { + TODO(); + /* char* dptr = nullptr; if (mem_case.has_host_pageable_mem()) { dptr = (char*) malloc (size); @@ -20,9 +22,12 @@ std::pair> MemoryAllocator::Allocate( CHECK_EQ(cudaSetDevice(current_device_id), 0); } return {dptr, std::bind(&MemoryAllocator::Deallocate, this, dptr, mem_case)}; + */ } void MemoryAllocator::Deallocate(char* dptr, MemoryCase mem_case) { + TODO(); + /* if (mem_case.has_host_pageable_mem()) { free(dptr); } else if (mem_case.has_cuda_pinned_mem()) { @@ -35,7 +40,8 @@ void MemoryAllocator::Deallocate(char* dptr, MemoryCase mem_case) { CHECK_EQ(cudaSetDevice(mem_case.gpu_mem().device_id()), 0); CHECK_EQ(cudaFree(&dptr), 0); CHECK_EQ(cudaSetDevice(current_device_id), 0); - } + } + */ } } // namespace oneflow diff --git a/oneflow/memory/memory_case.proto b/oneflow/memory/memory_case.proto index a5e74d6720..c33055f546 100644 --- a/oneflow/memory/memory_case.proto +++ b/oneflow/memory/memory_case.proto @@ -4,21 +4,19 @@ package oneflow; message HostPageableMemory { } -message CudaPinnedMemory { +message HostPinnedMemory { + bool need_cuda = 1; + bool need_rdma = 2; } -message RdmaPinnedMemory { -} - -message GpuMemory { - int32 device_id = 1; +message DeviceCudaMemory { + uint64 device_id = 1; } message MemoryCase { oneof case { HostPageableMemory host_pageable_mem = 1; - CudaPinnedMemory cuda_pinned_mem = 2; - RdmaPinnedMemory rdma_pinned_mem = 3; - GpuMemory gpu_mem = 4; + HostPinnedMemory host_pinned_mem = 2; + DeviceCudaMemory device_cuda_mem = 3; } } diff --git a/oneflow/register/blob.h b/oneflow/register/blob.h index 7f7223dcda..fdeba1f762 100644 --- a/oneflow/register/blob.h +++ b/oneflow/register/blob.h @@ -9,17 +9,17 @@ namespace oneflow { class Blob { public: OF_DISALLOW_COPY_AND_MOVE(Blob); - Blob(char* dptr, const Shape& shape) : dptr_(dptr), shape_(shape) {} + Blob(char* dptr, const Shape* shape) : dptr_(dptr), shape_(shape) {} ~Blob() {} const char* dptr() const { return dptr_; } - const Shape& shape() const { return shape_; } + const Shape& shape() const { return *shape_; } char* mut_dptr() { return dptr_; } private: char* dptr_ ; - Shape shape_; + const Shape* shape_; }; } // namespace oneflow diff --git a/oneflow/register/register.h b/oneflow/register/register.h index eb6bf5a7c1..fabebd8870 100644 --- a/oneflow/register/register.h +++ b/oneflow/register/register.h @@ -5,6 +5,7 @@ #include "actor/actor_message.pb.h" #include "actor/actor_msg_bus.h" #include "common/util.h" +#include "register/runtime_register_desc.h" namespace oneflow { @@ -20,9 +21,9 @@ class Regst final { private: friend class RegstMgr; Regst() = default; - uint64_t id_; - uint64_t producer_id_; - std::vector consumer_ids_; + + std::shared_ptr regst_desc_; + uint64_t regst_id_; std::function deleter_; HashMap> lbn2blob_; }; diff --git a/oneflow/register/register_manager.cpp b/oneflow/register/register_manager.cpp index f7e2f1ea33..8eedf94989 100644 --- a/oneflow/register/register_manager.cpp +++ b/oneflow/register/register_manager.cpp @@ -2,76 +2,11 @@ namespace oneflow { -void RegstMgr::NewRegstFromRegstDesc( - uint64_t producer_id, - const RegstDescProto& regstdesc, - std::size_t sizeof_floating, - HashMap>& actor_id2produced_regst_desc_id, - HashMap>& regst_desc_id2regst_ids) { - uint64_t regst_desc_id = regstdesc.regst_desc_id(); - for (int64_t i = 0; i < regstdesc.register_num(); ++i) { - std::unique_ptr regst(new Regst()); - regst->id_ = IDMgr::Singleton().NewRegstId(regst_desc_id); - regst->producer_id_ = producer_id; - std::size_t regst_size = 0; - for (const auto& mpair : regstdesc.lbn2shape()) { - Shape shape(mpair.second); - regst_size += shape.elem_cnt() * sizeof_floating; - } - auto mem_info = MemoryAllocator::Singleton().Allocate(MemoryCase(), - regst_size); - regst->deleter_ = mem_info.second; - char* dptr = mem_info.first; - for (const auto& mpair : regstdesc.lbn2shape()) { - Shape shape(mpair.second); - regst->lbn2blob_.emplace(mpair.first, of_make_unique(dptr, shape)); - dptr += shape.elem_cnt() * sizeof_floating; - } - regst_id2regst_.emplace(regst->id_, std::move(regst)); - actor_id2produced_regst_desc_id[producer_id].insert(regst_desc_id); - regst_desc_id2regst_ids[regst_desc_id].push_back(regst->id_); - } -} - -void RegstMgr::InitFromProto(const OfElf& ofelf) { - /* - //Init all regst for id, cnt, producer_id, lbn2blob - HashMap> actor_id2produced_regst_desc_id; - HashMap> regst_desc_id2regst_ids; - std::size_t sizeof_floating; - if (ofelf.job_desc().floating_point_type() == kFloat) { - sizeof_floating = sizeof(float); - } else { - sizeof_floating = sizeof(double); - } - for (const TaskProto& taskproto : ofelf.task()) { - if (taskproto.machine_id() != RuntimeInfo::Singleton().this_machine_id()) { continue; } - uint64_t actor_id = IDMgr::Singleton().GetActorIdFromTaskId(taskproto.id()); - for (const RegstDescProto& regstdesc : taskproto.produced_regst_desc()) { - NewRegstFromRegstDesc(actor_id, - regstdesc, - sizeof_floating, - actor_id2produced_regst_desc_id, - regst_desc_id2regst_ids); - } - } - //for consumer_ids, lbn2blob - for (const TaskProto& taskproto : ofelf.task()) { - if (taskproto.machine_id() != RuntimeInfo::Singleton().this_machine_id()) { continue; } - uint64_t actor_id = IDMgr::Singleton().GetActorIdFromTaskId(taskproto.id()); - HashSet processed_consumer; - for (const ExecNodeProto& execnode: taskproto.exec_sequence().exec_node()) { - for (const auto& mpair : execnode.bn_in_op2regst_desc_id()) { - if (actor_id2produced_regst_desc_id.at(actor_id).find(mpair.second) != - actor_id2produced_regst_desc_id.at(actor_id).end()) { continue; } - if (processed_consumer.find(mpair.second) != processed_consumer.end()) { continue; } - for (uint64_t regst_id : regst_desc_id2regst_ids[mpair.second]) { - GetRegstFromRegstID(regst_id)->consumer_ids_.push_back(actor_id); - } - } - } - }*/ - TODO(); +void NewRegsts(const RegstDescProto& regst_desc_proto, + std::function OneRegstDone) { + // One RegstDesc means Multi Regst + // All Regst has a shared_ptr point to the same RtRegstDesc obj + // Call OneRegstDone for each regst } } diff --git a/oneflow/register/register_manager.h b/oneflow/register/register_manager.h index 9e34c3d3a9..42a14614ee 100644 --- a/oneflow/register/register_manager.h +++ b/oneflow/register/register_manager.h @@ -20,22 +20,12 @@ class RegstMgr final { return obj; } - Regst* GetRegstFromRegstID(uint64_t regst_id) { - return regst_id2regst_.at(regst_id).get(); - } - - void InitFromProto(const OfElf& ofelf); + void NewRegsts(const RegstDescProto& regst_desc_proto, + std::function OneRegstDone); private: RegstMgr() = default; - void NewRegstFromRegstDesc( - uint64_t producer_id, - const RegstDescProto& regstdesc, - std::size_t sizeof_floating, - HashMap>& actor_id2produced_regst_desc_id, - HashMap>& regst_desc_id2regst_ids); - HashMap> regst_id2regst_; }; } // namespace oneflow diff --git a/oneflow/register/runtime_register_desc.h b/oneflow/register/runtime_register_desc.h new file mode 100644 index 0000000000..0a29e4cd48 --- /dev/null +++ b/oneflow/register/runtime_register_desc.h @@ -0,0 +1,32 @@ +#ifndef ONEFLOW_REGISTER_RUNTIME_REGISTER_DESC_H_ +#define ONEFLOW_REGISTER_RUNTIME_REGISTER_DESC_H_ + +#include "common/util.h" +#include "memory/memory_case.pb.h" +#include "register/register_desc.pb.h" + +namespace oneflow { + +class RtRegstDesc { + public: + OF_DISALLOW_COPY_AND_MOVE(RtRegstDesc); + RtRegstDesc() = delete; + ~RtRegstDesc() = default; + + RtRegstDesc(const RegstDescProto&) { TODO(); } + + // TODO: Add Getter + + private: + uint64_t regst_desc_id_; + uint64_t producer_task_id_; + std::vector subscribers_task_id_; + std::unordered_map> lbn2shape_; + int64_t register_num_; + MemoryCase mem_case_; + +}; + +} // namespace oneflow + +#endif // ONEFLOW_REGISTER_RUNTIME_REGISTER_DESC_H_ diff --git a/oneflow/runtime/elf_runner.cpp b/oneflow/runtime/elf_runner.cpp index ebbfb4554f..91d50581ac 100644 --- a/oneflow/runtime/elf_runner.cpp +++ b/oneflow/runtime/elf_runner.cpp @@ -23,7 +23,6 @@ class ElfRunner final { JobDesc::Singleton().InitFromProto(elf.job_desc()); IDMgr::Singleton().InitFromResource(JobDesc::Singleton().resource()); RuntimeInfo::Singleton().set_this_machine_name(this_machine_name); - RegstMgr::Singleton().InitFromProto(elf); TODO(); } -- GitLab