提交 903cf3f8 编写于 作者: W willzhang4a58

refine interface for register

上级 f95e9857
...@@ -4,6 +4,8 @@ namespace oneflow { ...@@ -4,6 +4,8 @@ namespace oneflow {
std::pair<char*, std::function<void()>> MemoryAllocator::Allocate( std::pair<char*, std::function<void()>> MemoryAllocator::Allocate(
MemoryCase mem_case,std::size_t size) { MemoryCase mem_case,std::size_t size) {
TODO();
/*
char* dptr = nullptr; char* dptr = nullptr;
if (mem_case.has_host_pageable_mem()) { if (mem_case.has_host_pageable_mem()) {
dptr = (char*) malloc (size); dptr = (char*) malloc (size);
...@@ -20,9 +22,12 @@ std::pair<char*, std::function<void()>> MemoryAllocator::Allocate( ...@@ -20,9 +22,12 @@ std::pair<char*, std::function<void()>> MemoryAllocator::Allocate(
CHECK_EQ(cudaSetDevice(current_device_id), 0); CHECK_EQ(cudaSetDevice(current_device_id), 0);
} }
return {dptr, std::bind(&MemoryAllocator::Deallocate, this, dptr, mem_case)}; return {dptr, std::bind(&MemoryAllocator::Deallocate, this, dptr, mem_case)};
*/
} }
void MemoryAllocator::Deallocate(char* dptr, MemoryCase mem_case) { void MemoryAllocator::Deallocate(char* dptr, MemoryCase mem_case) {
TODO();
/*
if (mem_case.has_host_pageable_mem()) { if (mem_case.has_host_pageable_mem()) {
free(dptr); free(dptr);
} else if (mem_case.has_cuda_pinned_mem()) { } else if (mem_case.has_cuda_pinned_mem()) {
...@@ -35,7 +40,8 @@ void MemoryAllocator::Deallocate(char* dptr, MemoryCase mem_case) { ...@@ -35,7 +40,8 @@ void MemoryAllocator::Deallocate(char* dptr, MemoryCase mem_case) {
CHECK_EQ(cudaSetDevice(mem_case.gpu_mem().device_id()), 0); CHECK_EQ(cudaSetDevice(mem_case.gpu_mem().device_id()), 0);
CHECK_EQ(cudaFree(&dptr), 0); CHECK_EQ(cudaFree(&dptr), 0);
CHECK_EQ(cudaSetDevice(current_device_id), 0); CHECK_EQ(cudaSetDevice(current_device_id), 0);
} }
*/
} }
} // namespace oneflow } // namespace oneflow
...@@ -4,21 +4,19 @@ package oneflow; ...@@ -4,21 +4,19 @@ package oneflow;
message HostPageableMemory { message HostPageableMemory {
} }
message CudaPinnedMemory { message HostPinnedMemory {
bool need_cuda = 1;
bool need_rdma = 2;
} }
message RdmaPinnedMemory { message DeviceCudaMemory {
} uint64 device_id = 1;
message GpuMemory {
int32 device_id = 1;
} }
message MemoryCase { message MemoryCase {
oneof case { oneof case {
HostPageableMemory host_pageable_mem = 1; HostPageableMemory host_pageable_mem = 1;
CudaPinnedMemory cuda_pinned_mem = 2; HostPinnedMemory host_pinned_mem = 2;
RdmaPinnedMemory rdma_pinned_mem = 3; DeviceCudaMemory device_cuda_mem = 3;
GpuMemory gpu_mem = 4;
} }
} }
...@@ -9,17 +9,17 @@ namespace oneflow { ...@@ -9,17 +9,17 @@ namespace oneflow {
class Blob { class Blob {
public: public:
OF_DISALLOW_COPY_AND_MOVE(Blob); OF_DISALLOW_COPY_AND_MOVE(Blob);
Blob(char* dptr, const Shape& shape) : dptr_(dptr), shape_(shape) {} Blob(char* dptr, const Shape* shape) : dptr_(dptr), shape_(shape) {}
~Blob() {} ~Blob() {}
const char* dptr() const { return dptr_; } const char* dptr() const { return dptr_; }
const Shape& shape() const { return shape_; } const Shape& shape() const { return *shape_; }
char* mut_dptr() { return dptr_; } char* mut_dptr() { return dptr_; }
private: private:
char* dptr_ ; char* dptr_ ;
Shape shape_; const Shape* shape_;
}; };
} // namespace oneflow } // namespace oneflow
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include "actor/actor_message.pb.h" #include "actor/actor_message.pb.h"
#include "actor/actor_msg_bus.h" #include "actor/actor_msg_bus.h"
#include "common/util.h" #include "common/util.h"
#include "register/runtime_register_desc.h"
namespace oneflow { namespace oneflow {
...@@ -20,9 +21,9 @@ class Regst final { ...@@ -20,9 +21,9 @@ class Regst final {
private: private:
friend class RegstMgr; friend class RegstMgr;
Regst() = default; Regst() = default;
uint64_t id_;
uint64_t producer_id_; std::shared_ptr<const RtRegstDesc> regst_desc_;
std::vector<uint64_t> consumer_ids_; uint64_t regst_id_;
std::function<void()> deleter_; std::function<void()> deleter_;
HashMap<std::string, std::unique_ptr<Blob>> lbn2blob_; HashMap<std::string, std::unique_ptr<Blob>> lbn2blob_;
}; };
......
...@@ -2,76 +2,11 @@ ...@@ -2,76 +2,11 @@
namespace oneflow { namespace oneflow {
void RegstMgr::NewRegstFromRegstDesc( void NewRegsts(const RegstDescProto& regst_desc_proto,
uint64_t producer_id, std::function<void(Regst*)> OneRegstDone) {
const RegstDescProto& regstdesc, // One RegstDesc means Multi Regst
std::size_t sizeof_floating, // All Regst has a shared_ptr point to the same RtRegstDesc obj
HashMap<uint64_t, HashSet<uint64_t>>& actor_id2produced_regst_desc_id, // Call OneRegstDone for each regst
HashMap<uint64_t, std::vector<uint64_t>>& regst_desc_id2regst_ids) {
uint64_t regst_desc_id = regstdesc.regst_desc_id();
for (int64_t i = 0; i < regstdesc.register_num(); ++i) {
std::unique_ptr<Regst> regst(new Regst());
regst->id_ = IDMgr::Singleton().NewRegstId(regst_desc_id);
regst->producer_id_ = producer_id;
std::size_t regst_size = 0;
for (const auto& mpair : regstdesc.lbn2shape()) {
Shape shape(mpair.second);
regst_size += shape.elem_cnt() * sizeof_floating;
}
auto mem_info = MemoryAllocator::Singleton().Allocate(MemoryCase(),
regst_size);
regst->deleter_ = mem_info.second;
char* dptr = mem_info.first;
for (const auto& mpair : regstdesc.lbn2shape()) {
Shape shape(mpair.second);
regst->lbn2blob_.emplace(mpair.first, of_make_unique<Blob>(dptr, shape));
dptr += shape.elem_cnt() * sizeof_floating;
}
regst_id2regst_.emplace(regst->id_, std::move(regst));
actor_id2produced_regst_desc_id[producer_id].insert(regst_desc_id);
regst_desc_id2regst_ids[regst_desc_id].push_back(regst->id_);
}
}
void RegstMgr::InitFromProto(const OfElf& ofelf) {
/*
//Init all regst for id, cnt, producer_id, lbn2blob
HashMap<uint64_t, HashSet<uint64_t>> actor_id2produced_regst_desc_id;
HashMap<uint64_t, std::vector<uint64_t>> regst_desc_id2regst_ids;
std::size_t sizeof_floating;
if (ofelf.job_desc().floating_point_type() == kFloat) {
sizeof_floating = sizeof(float);
} else {
sizeof_floating = sizeof(double);
}
for (const TaskProto& taskproto : ofelf.task()) {
if (taskproto.machine_id() != RuntimeInfo::Singleton().this_machine_id()) { continue; }
uint64_t actor_id = IDMgr::Singleton().GetActorIdFromTaskId(taskproto.id());
for (const RegstDescProto& regstdesc : taskproto.produced_regst_desc()) {
NewRegstFromRegstDesc(actor_id,
regstdesc,
sizeof_floating,
actor_id2produced_regst_desc_id,
regst_desc_id2regst_ids);
}
}
//for consumer_ids, lbn2blob
for (const TaskProto& taskproto : ofelf.task()) {
if (taskproto.machine_id() != RuntimeInfo::Singleton().this_machine_id()) { continue; }
uint64_t actor_id = IDMgr::Singleton().GetActorIdFromTaskId(taskproto.id());
HashSet<uint64_t> processed_consumer;
for (const ExecNodeProto& execnode: taskproto.exec_sequence().exec_node()) {
for (const auto& mpair : execnode.bn_in_op2regst_desc_id()) {
if (actor_id2produced_regst_desc_id.at(actor_id).find(mpair.second) !=
actor_id2produced_regst_desc_id.at(actor_id).end()) { continue; }
if (processed_consumer.find(mpair.second) != processed_consumer.end()) { continue; }
for (uint64_t regst_id : regst_desc_id2regst_ids[mpair.second]) {
GetRegstFromRegstID(regst_id)->consumer_ids_.push_back(actor_id);
}
}
}
}*/
TODO();
} }
} }
...@@ -20,22 +20,12 @@ class RegstMgr final { ...@@ -20,22 +20,12 @@ class RegstMgr final {
return obj; return obj;
} }
Regst* GetRegstFromRegstID(uint64_t regst_id) { void NewRegsts(const RegstDescProto& regst_desc_proto,
return regst_id2regst_.at(regst_id).get(); std::function<void(Regst*)> OneRegstDone);
}
void InitFromProto(const OfElf& ofelf);
private: private:
RegstMgr() = default; RegstMgr() = default;
void NewRegstFromRegstDesc(
uint64_t producer_id,
const RegstDescProto& regstdesc,
std::size_t sizeof_floating,
HashMap<uint64_t, HashSet<uint64_t>>& actor_id2produced_regst_desc_id,
HashMap<uint64_t, std::vector<uint64_t>>& regst_desc_id2regst_ids);
HashMap<uint64_t, std::unique_ptr<Regst>> regst_id2regst_;
}; };
} // namespace oneflow } // namespace oneflow
......
#ifndef ONEFLOW_REGISTER_RUNTIME_REGISTER_DESC_H_
#define ONEFLOW_REGISTER_RUNTIME_REGISTER_DESC_H_
#include "common/util.h"
#include "memory/memory_case.pb.h"
#include "register/register_desc.pb.h"
namespace oneflow {
class RtRegstDesc {
public:
OF_DISALLOW_COPY_AND_MOVE(RtRegstDesc);
RtRegstDesc() = delete;
~RtRegstDesc() = default;
RtRegstDesc(const RegstDescProto&) { TODO(); }
// TODO: Add Getter
private:
uint64_t regst_desc_id_;
uint64_t producer_task_id_;
std::vector<uint64_t> subscribers_task_id_;
std::unordered_map<std::string, std::unique_ptr<Shape>> lbn2shape_;
int64_t register_num_;
MemoryCase mem_case_;
};
} // namespace oneflow
#endif // ONEFLOW_REGISTER_RUNTIME_REGISTER_DESC_H_
...@@ -23,7 +23,6 @@ class ElfRunner final { ...@@ -23,7 +23,6 @@ class ElfRunner final {
JobDesc::Singleton().InitFromProto(elf.job_desc()); JobDesc::Singleton().InitFromProto(elf.job_desc());
IDMgr::Singleton().InitFromResource(JobDesc::Singleton().resource()); IDMgr::Singleton().InitFromResource(JobDesc::Singleton().resource());
RuntimeInfo::Singleton().set_this_machine_name(this_machine_name); RuntimeInfo::Singleton().set_this_machine_name(this_machine_name);
RegstMgr::Singleton().InitFromProto(elf);
TODO(); TODO();
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册