提交 b09fb51c 编写于 作者: W Wind5

rename

上级 53b5fbd3
......@@ -8,14 +8,20 @@ namespace oneflow {
namespace {
void SetHostPinnedMemory4CopyTask(MemoryCase& mem_case,
const HashSet<const TaskNode*>& subs) {
void SetDeviceCudaMemoryAccordingToThrdLocId(MemoryCase& mem_case,
uint64_t thrd_loc_id) {
uint64_t device_id = IDMgr::Singleton().DevPhyId4ThrdLocId(thrd_loc_id);
mem_case.mutable_device_cuda_mem()->set_device_id(device_id);
}
void SetHostPinnedMemoryAccordingToSubscribers(MemoryCase& mem_case,
const HashSet<const TaskNode*>& subs) {
for (const TaskNode* sub : subs) {
if (sub->task_type() == kCopyCommNetTask) {
mem_case.mutable_host_pinned_mem()->set_need_rdma(true);
}
if (const CopyHDTaskNode* cp_hd = dynamic_cast<const CopyHDTaskNode*>(sub)) {
if (cp_hd->IsH2D()) {
if (auto cp_hd_sub = dynamic_cast<const CopyHDTaskNode*>(sub)) {
if (cp_hd_sub->IsH2D()) {
mem_case.mutable_host_pinned_mem()->set_need_cuda(true);
}
}
......@@ -113,24 +119,22 @@ void RegstDesc::ToProto(RegstDescProto* ret) const {
MemoryCase RegstDesc::InferMemCase() const {
MemoryCase mem_case;
uint64_t device_id = IDMgr::Singleton().DevPhyId4ThrdLocId(producer_->thrd_loc_id());
DeviceType device_type = producer_->chain_node()->parallel_desc()->device_type();
if (const CopyHDTaskNode* cp_hd = dynamic_cast<const CopyHDTaskNode*>(producer_)) {
if (cp_hd->IsH2D()) {
mem_case.mutable_device_cuda_mem()->set_device_id(device_id);
if (auto cp_hd_producer = dynamic_cast<const CopyHDTaskNode*>(producer_)) {
if (cp_hd_producer->IsH2D()) {
SetDeviceCudaMemoryAccordingToThrdLocId(mem_case, producer_->thrd_loc_id());
} else {
mem_case.mutable_host_pinned_mem()->set_need_cuda(true);
SetHostPinnedMemory4CopyTask(mem_case, subscribers_);
SetHostPinnedMemoryAccordingToSubscribers(mem_case, subscribers_);
}
} else if (producer_->task_type() == kCopyCommNetTask) {
mem_case.mutable_host_pinned_mem()->set_need_rdma(true);
SetHostPinnedMemory4CopyTask(mem_case, subscribers_);
SetHostPinnedMemoryAccordingToSubscribers(mem_case, subscribers_);
} else {
if (device_type == kGPU && producer_->task_type() != kBoxingTask) {
mem_case.mutable_device_cuda_mem()->set_device_id(device_id);
SetDeviceCudaMemoryAccordingToThrdLocId(mem_case, producer_->thrd_loc_id());
} else {
mem_case.mutable_host_pageable_mem();
SetHostPinnedMemory4CopyTask(mem_case, subscribers_);
SetHostPinnedMemoryAccordingToSubscribers(mem_case, subscribers_);
}
}
return mem_case;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册