提交 903cf3f8 编写于 作者: W willzhang4a58

refine interface for register

上级 f95e9857
......@@ -4,6 +4,8 @@ namespace oneflow {
std::pair<char*, std::function<void()>> MemoryAllocator::Allocate(
MemoryCase mem_case,std::size_t size) {
TODO();
/*
char* dptr = nullptr;
if (mem_case.has_host_pageable_mem()) {
dptr = (char*) malloc (size);
......@@ -20,9 +22,12 @@ std::pair<char*, std::function<void()>> MemoryAllocator::Allocate(
CHECK_EQ(cudaSetDevice(current_device_id), 0);
}
return {dptr, std::bind(&MemoryAllocator::Deallocate, this, dptr, mem_case)};
*/
}
void MemoryAllocator::Deallocate(char* dptr, MemoryCase mem_case) {
TODO();
/*
if (mem_case.has_host_pageable_mem()) {
free(dptr);
} else if (mem_case.has_cuda_pinned_mem()) {
......@@ -35,7 +40,8 @@ void MemoryAllocator::Deallocate(char* dptr, MemoryCase mem_case) {
CHECK_EQ(cudaSetDevice(mem_case.gpu_mem().device_id()), 0);
CHECK_EQ(cudaFree(dptr), 0);
CHECK_EQ(cudaSetDevice(current_device_id), 0);
}
}
*/
}
} // namespace oneflow
......@@ -4,21 +4,19 @@ package oneflow;
message HostPageableMemory {
}
message CudaPinnedMemory {
message HostPinnedMemory {
bool need_cuda = 1;
bool need_rdma = 2;
}
message RdmaPinnedMemory {
}
message GpuMemory {
int32 device_id = 1;
message DeviceCudaMemory {
uint64 device_id = 1;
}
message MemoryCase {
oneof case {
HostPageableMemory host_pageable_mem = 1;
CudaPinnedMemory cuda_pinned_mem = 2;
RdmaPinnedMemory rdma_pinned_mem = 3;
GpuMemory gpu_mem = 4;
HostPinnedMemory host_pinned_mem = 2;
DeviceCudaMemory device_cuda_mem = 3;
}
}
......@@ -9,17 +9,17 @@ namespace oneflow {
class Blob {
public:
OF_DISALLOW_COPY_AND_MOVE(Blob);
Blob(char* dptr, const Shape& shape) : dptr_(dptr), shape_(shape) {}
Blob(char* dptr, const Shape* shape) : dptr_(dptr), shape_(shape) {}
~Blob() {}
const char* dptr() const { return dptr_; }
const Shape& shape() const { return shape_; }
const Shape& shape() const { return *shape_; }
char* mut_dptr() { return dptr_; }
private:
char* dptr_ ;
Shape shape_;
const Shape* shape_;
};
} // namespace oneflow
......
......@@ -5,6 +5,7 @@
#include "actor/actor_message.pb.h"
#include "actor/actor_msg_bus.h"
#include "common/util.h"
#include "register/runtime_register_desc.h"
namespace oneflow {
......@@ -20,9 +21,9 @@ class Regst final {
private:
friend class RegstMgr;
Regst() = default;
uint64_t id_;
uint64_t producer_id_;
std::vector<uint64_t> consumer_ids_;
std::shared_ptr<const RtRegstDesc> regst_desc_;
uint64_t regst_id_;
std::function<void()> deleter_;
HashMap<std::string, std::unique_ptr<Blob>> lbn2blob_;
};
......
......@@ -2,76 +2,11 @@
namespace oneflow {
// Creates `regstdesc.register_num()` Regst objects for one register
// description and records them in this manager and in the two output maps.
//
// For every new Regst:
//  - a fresh id is requested from IDMgr for `regstdesc.regst_desc_id()`;
//  - a single buffer large enough for all blobs listed in
//    `regstdesc.lbn2shape()` is obtained from MemoryAllocator (each blob takes
//    `elem_cnt * sizeof_floating` bytes);
//  - one Blob per lbn is created, laid out back to back inside that buffer.
//
// Params:
//   producer_id      id stored as the producer of each new Regst.
//   regstdesc        register description to instantiate.
//   sizeof_floating  size in bytes of one element.
//   actor_id2produced_regst_desc_id  (out) gets `regst_desc_id` inserted under
//                    `producer_id`.
//   regst_desc_id2regst_ids          (out) gets every new regst id appended
//                    under `regst_desc_id`.
void RegstMgr::NewRegstFromRegstDesc(
    uint64_t producer_id,
    const RegstDescProto& regstdesc,
    std::size_t sizeof_floating,
    HashMap<uint64_t, HashSet<uint64_t>>& actor_id2produced_regst_desc_id,
    HashMap<uint64_t, std::vector<uint64_t>>& regst_desc_id2regst_ids) {
  uint64_t regst_desc_id = regstdesc.regst_desc_id();
  for (int64_t i = 0; i < regstdesc.register_num(); ++i) {
    // Regst's constructor is private (RegstMgr is a friend), so the object is
    // created with `new` here rather than through a make_unique-style helper.
    std::unique_ptr<Regst> regst(new Regst());
    // Keep the id in a local: it is still needed after ownership of `regst`
    // has been moved into regst_id2regst_ below.
    const uint64_t regst_id = IDMgr::Singleton().NewRegstId(regst_desc_id);
    regst->id_ = regst_id;
    regst->producer_id_ = producer_id;

    // Total number of bytes needed by all blobs of this regst.
    std::size_t regst_size = 0;
    for (const auto& mpair : regstdesc.lbn2shape()) {
      Shape shape(mpair.second);
      regst_size += shape.elem_cnt() * sizeof_floating;
    }

    // NOTE(review): a default-constructed MemoryCase is passed here rather
    // than one taken from `regstdesc` -- confirm this is intended.
    auto mem_info = MemoryAllocator::Singleton().Allocate(MemoryCase(),
                                                          regst_size);
    regst->deleter_ = mem_info.second;

    // Hand each blob its slice of the buffer, in lbn2shape() iteration order.
    char* dptr = mem_info.first;
    for (const auto& mpair : regstdesc.lbn2shape()) {
      Shape shape(mpair.second);
      regst->lbn2blob_.emplace(mpair.first, of_make_unique<Blob>(dptr, shape));
      dptr += shape.elem_cnt() * sizeof_floating;
    }

    regst_id2regst_.emplace(regst_id, std::move(regst));
    // `regst` is empty from here on; use `regst_id`, never `regst->id_`.
    actor_id2produced_regst_desc_id[producer_id].insert(regst_desc_id);
    regst_desc_id2regst_ids[regst_desc_id].push_back(regst_id);
  }
}
void RegstMgr::InitFromProto(const OfElf& ofelf) {
/*
//Init all regst for id, cnt, producer_id, lbn2blob
HashMap<uint64_t, HashSet<uint64_t>> actor_id2produced_regst_desc_id;
HashMap<uint64_t, std::vector<uint64_t>> regst_desc_id2regst_ids;
std::size_t sizeof_floating;
if (ofelf.job_desc().floating_point_type() == kFloat) {
sizeof_floating = sizeof(float);
} else {
sizeof_floating = sizeof(double);
}
for (const TaskProto& taskproto : ofelf.task()) {
if (taskproto.machine_id() != RuntimeInfo::Singleton().this_machine_id()) { continue; }
uint64_t actor_id = IDMgr::Singleton().GetActorIdFromTaskId(taskproto.id());
for (const RegstDescProto& regstdesc : taskproto.produced_regst_desc()) {
NewRegstFromRegstDesc(actor_id,
regstdesc,
sizeof_floating,
actor_id2produced_regst_desc_id,
regst_desc_id2regst_ids);
}
}
//for consumer_ids, lbn2blob
for (const TaskProto& taskproto : ofelf.task()) {
if (taskproto.machine_id() != RuntimeInfo::Singleton().this_machine_id()) { continue; }
uint64_t actor_id = IDMgr::Singleton().GetActorIdFromTaskId(taskproto.id());
HashSet<uint64_t> processed_consumer;
for (const ExecNodeProto& execnode: taskproto.exec_sequence().exec_node()) {
for (const auto& mpair : execnode.bn_in_op2regst_desc_id()) {
if (actor_id2produced_regst_desc_id.at(actor_id).find(mpair.second) !=
actor_id2produced_regst_desc_id.at(actor_id).end()) { continue; }
if (processed_consumer.find(mpair.second) != processed_consumer.end()) { continue; }
for (uint64_t regst_id : regst_desc_id2regst_ids[mpair.second]) {
GetRegstFromRegstID(regst_id)->consumer_ids_.push_back(actor_id);
}
}
}
}*/
TODO();
void NewRegsts(const RegstDescProto& regst_desc_proto,
std::function<void(Regst*)> OneRegstDone) {
// One RegstDesc means Multi Regst
// All Regst has a shared_ptr point to the same RtRegstDesc obj
// Call OneRegstDone for each regst
}
}
......@@ -20,22 +20,12 @@ class RegstMgr final {
return obj;
}
Regst* GetRegstFromRegstID(uint64_t regst_id) {
return regst_id2regst_.at(regst_id).get();
}
void InitFromProto(const OfElf& ofelf);
void NewRegsts(const RegstDescProto& regst_desc_proto,
std::function<void(Regst*)> OneRegstDone);
private:
RegstMgr() = default;
void NewRegstFromRegstDesc(
uint64_t producer_id,
const RegstDescProto& regstdesc,
std::size_t sizeof_floating,
HashMap<uint64_t, HashSet<uint64_t>>& actor_id2produced_regst_desc_id,
HashMap<uint64_t, std::vector<uint64_t>>& regst_desc_id2regst_ids);
HashMap<uint64_t, std::unique_ptr<Regst>> regst_id2regst_;
};
} // namespace oneflow
......
#ifndef ONEFLOW_REGISTER_RUNTIME_REGISTER_DESC_H_
#define ONEFLOW_REGISTER_RUNTIME_REGISTER_DESC_H_
#include "common/util.h"
#include "memory/memory_case.pb.h"
#include "register/register_desc.pb.h"
namespace oneflow {
// Runtime-side description of a register, constructed from a RegstDescProto.
// It is neither copyable nor movable and cannot be default-constructed.
//
// NOTE(review): the constructor body is still a TODO() stub, so none of the
// members below are filled in from the proto yet.
class RtRegstDesc {
 public:
  OF_DISALLOW_COPY_AND_MOVE(RtRegstDesc);
  RtRegstDesc() = delete;
  ~RtRegstDesc() = default;

  // `explicit` so that a RegstDescProto is never silently converted into an
  // RtRegstDesc.
  explicit RtRegstDesc(const RegstDescProto&) { TODO(); }

  // TODO: Add Getter

 private:
  // Scalar members carry default initializers so they never hold
  // indeterminate values, whatever the constructor ends up doing.
  uint64_t regst_desc_id_ = 0;
  uint64_t producer_task_id_ = 0;
  std::vector<uint64_t> subscribers_task_id_;
  // Presumably keyed by logical blob name (lbn) -- TODO confirm.
  std::unordered_map<std::string, std::unique_ptr<Shape>> lbn2shape_;
  int64_t register_num_ = 0;
  MemoryCase mem_case_;
};
} // namespace oneflow
#endif // ONEFLOW_REGISTER_RUNTIME_REGISTER_DESC_H_
......@@ -23,7 +23,6 @@ class ElfRunner final {
JobDesc::Singleton().InitFromProto(elf.job_desc());
IDMgr::Singleton().InitFromResource(JobDesc::Singleton().resource());
RuntimeInfo::Singleton().set_this_machine_name(this_machine_name);
RegstMgr::Singleton().InitFromProto(elf);
TODO();
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册