未验证 提交 0e32ee69 编写于 作者: L leaves-zwx 提交者: GitHub

Remove get local work stream id api (#4227)

* rm LocalWorkStreamId

* rm AllocateLocalWorkStreamId in TaskNode

* rm local work stream id in task node and commnet task node

* rm local_work_stream_id param in NewTaskId

* fix test
Co-authored-by: Noneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
上级 aebb9f5b
......@@ -259,14 +259,12 @@ int64_t Actor::GetPieceId4NaiveOrInplaceCurReadableDataRegst() const {
void Actor::InitDeviceCtx(const ThreadCtx& thread_ctx) {
switch (GetDeviceType()) {
case DeviceType::kCPU: {
CHECK_EQ(GetLocalWorkStreamId(), 0);
device_ctx_.reset(new CpuDeviceCtx());
break;
}
#ifdef WITH_CUDA
case DeviceType::kGPU: {
CudaStreamHandle* cuda_handle = nullptr;
CHECK_EQ(GetLocalWorkStreamId(), 0);
cuda_handle = thread_ctx.g_cuda_stream.get();
device_ctx_.reset(new CudaDeviceCtx(cuda_handle));
break;
......@@ -713,10 +711,6 @@ int64_t Actor::GetGlobalWorkStreamId() const {
return Global<IDMgr>::Get()->GlobalWorkStreamId4ActorId(actor_id_);
}
int64_t Actor::GetLocalWorkStreamId() const {
return Global<IDMgr>::Get()->LocalWorkStreamId4ActorId(actor_id_);
}
Regst* Actor::GetNaiveOrInplaceCurReadable(int64_t regst_desc_id) const {
Regst* regst = naive_consumed_rs_.Front(regst_desc_id);
if (regst == nullptr) { regst = inplace_consumed_rs_.Front(regst_desc_id); }
......
......@@ -148,7 +148,6 @@ class Actor {
protected:
int64_t GetGlobalWorkStreamId() const;
int64_t GetLocalWorkStreamId() const;
virtual bool NeedCollectActEvent() const {
return Global<RuntimeCtx>::Get()->NeedCollectActEvent();
}
......
......@@ -57,7 +57,7 @@ TaskNode* SubTskGphBuilderCtx::GetProxyNode(TaskNode* src_node, int64_t src_mem_
GetProxyNode(src_node, src_mem_zone_id, src_node->machine_id(),
Global<IDMgr>::Get()->CpuMemZoneId());
CopyCommNetTaskNode* copy_comm_net_task = task_graph()->NewNode<CopyCommNetTaskNode>();
copy_comm_net_task->Init(dst_machine_id, proxy_on_src_host->machine_id());
copy_comm_net_task->Init(dst_machine_id);
Connect<TaskNode>(proxy_on_src_host, task_graph()->NewEdge(), copy_comm_net_task);
node2proxies_[src_node][key] = copy_comm_net_task;
return copy_comm_net_task;
......
......@@ -90,44 +90,9 @@ OperatorConf CopyHdTaskNode::NewCopyOpConf() {
return conf;
}
void CopyCommNetTaskNode::Init(int64_t machine_id, int64_t src_machine_id) {
void CopyCommNetTaskNode::Init(int64_t machine_id) {
set_machine_id(machine_id);
set_thrd_id(Global<IDMgr>::Get()->CommNetThrdId());
peer_machine_id_ = src_machine_id;
}
namespace {
HashMap<int64_t, HashMap<int64_t, int64_t>>* GetConnection2LocalStreamIdMap() {
// this_machine_id -> {peer_machine_id, local_work_stream_id}
static HashMap<int64_t, HashMap<int64_t, int64_t>> connection2stream_id;
return &connection2stream_id;
}
int64_t GetLocalStreamId4Connection(int64_t this_machine_id, int64_t peer_machine_id) {
auto& dict = *GetConnection2LocalStreamIdMap();
auto this_machine_it = dict.find(this_machine_id);
if (this_machine_it == dict.end()) { return -1; }
auto peer_machine_it = this_machine_it->second.find(peer_machine_id);
if (peer_machine_it == this_machine_it->second.end()) { return -1; }
return peer_machine_it->second;
}
void InsertLocalStreamId4Connection(int64_t this_machine_id, int64_t peer_machine_id) {
auto& dict = *GetConnection2LocalStreamIdMap();
dict[this_machine_id][peer_machine_id] = dict[this_machine_id].size();
}
} // namespace
int64_t CopyCommNetTaskNode::AllocateLocalWorkStreamId() {
int64_t this_machine_id = machine_id();
int64_t local_work_stream_id = GetLocalStreamId4Connection(this_machine_id, peer_machine_id_);
if (local_work_stream_id == -1) {
InsertLocalStreamId4Connection(this_machine_id, peer_machine_id_);
local_work_stream_id = GetLocalStreamId4Connection(this_machine_id, peer_machine_id_);
}
return local_work_stream_id;
}
void CopyCommNetTaskNode::InitProducedRegstMemCase(MemoryCase* mem_case) {
......
......@@ -74,15 +74,12 @@ class CopyCommNetTaskNode final : public CopyTaskNode {
TaskType GetTaskType() const override { return TaskType::kCopyCommNet; }
void Init(int64_t machine_id, int64_t src_machine_id);
int64_t AllocateLocalWorkStreamId() override;
int64_t peer_machine_id() const { return peer_machine_id_; }
void Init(int64_t machine_id);
private:
void InitProducedRegstMemCase(MemoryCase*) override;
void PinConsumedRegstMemCase(MemoryCase*) override;
OperatorConf NewCopyOpConf() override;
int64_t peer_machine_id_;
};
} // namespace oneflow
......
......@@ -752,7 +752,7 @@ TaskNode* TaskGraph::AddCopyD2HTaskFrom(TaskNode* task) {
TaskNode* TaskGraph::AddCopyCommNetTaskBetween(TaskNode* src, TaskNode* dst) {
CHECK_NE(src->machine_id(), dst->machine_id());
CopyCommNetTaskNode* copy_comm_net_task = NewNode<CopyCommNetTaskNode>();
copy_comm_net_task->Init(dst->machine_id(), src->machine_id());
copy_comm_net_task->Init(dst->machine_id());
return copy_comm_net_task;
}
......
......@@ -382,21 +382,10 @@ void TaskNode::FixRegisterNumRange() {
}
}
int64_t TaskNode::AllocateLocalWorkStreamId() {
CHECK_NE(machine_id_, -1);
CHECK_NE(thrd_id_, -1);
return 0;
}
void TaskNode::UpdateTaskId() {
CHECK_NE(machine_id_, -1);
CHECK_NE(thrd_id_, -1);
task_id_ = Global<IDMgr>::Get()->NewTaskId(machine_id_, thrd_id_, AllocateLocalWorkStreamId());
}
int64_t TaskNode::LocalWorkStreamId() const {
CHECK_NE(task_id_, -1);
return Global<IDMgr>::Get()->LocalWorkStreamId4TaskId(task_id_);
task_id_ = Global<IDMgr>::Get()->NewTaskId(machine_id_, thrd_id_);
}
int64_t TaskNode::GlobalWorkStreamId() const {
......
......@@ -56,7 +56,6 @@ class TaskNode : public Node<TaskNode, TaskEdge> {
}
DeviceType device_type() const;
virtual const ParallelContext* parallel_ctx() const { return nullptr; }
int64_t LocalWorkStreamId() const;
int64_t GlobalWorkStreamId() const;
int64_t GpuPhyId() const { return Global<IDMgr>::Get()->GetGpuPhyIdFromThrdId(thrd_id_); }
virtual int64_t AreaId4ChainMerge() const { return area_id(); }
......@@ -132,8 +131,6 @@ class TaskNode : public Node<TaskNode, TaskEdge> {
virtual void LockRegsts();
void FixRegisterNumRange();
virtual int64_t AllocateLocalWorkStreamId();
virtual void InferProducedDataRegstTimeShape() = 0;
void NaiveInferProducedDataRegstTimeShape();
......
......@@ -43,13 +43,11 @@ void IDMgr::UpdateBaseIndependentThrdId(int64_t val) {
if (val >= base_independent_thrd_id_) { base_independent_thrd_id_ = val + 1; }
}
int64_t IDMgr::NewTaskId(int64_t machine_id, int64_t thrd_id, int64_t local_work_stream_id) {
int64_t IDMgr::NewTaskId(int64_t machine_id, int64_t thrd_id) {
int64_t machine_thrd_id = GetMachineThrdId(machine_id, thrd_id);
CHECK_LT(machine_thrd_id2num_of_tasks_[machine_thrd_id],
(static_cast<int64_t>(1) << task_id_bit_num_) - 1);
CHECK_LT(local_work_stream_id, static_cast<int64_t>(1) << local_work_stream_id_bit_num_);
return machine_thrd_id | (local_work_stream_id << task_id_bit_num_)
| (machine_thrd_id2num_of_tasks_[machine_thrd_id]++);
return machine_thrd_id | (machine_thrd_id2num_of_tasks_[machine_thrd_id]++);
}
DeviceType IDMgr::GetDeviceTypeFromThrdId(int64_t thrd_id) const {
......@@ -80,10 +78,6 @@ int64_t IDMgr::ThrdId4ActorId(int64_t actor_id) const {
return tmp >> (63 - thread_id_bit_num_);
}
int64_t IDMgr::AllocateLocalWorkStreamId(int64_t machine_id, int64_t thrd_id) {
return 100 + (machine_thrd_id2stream_id_cnt_[GetMachineThrdId(machine_id, thrd_id)]++);
}
int64_t IDMgr::GlobalWorkStreamId4TaskId(int64_t task_id) const {
return (task_id >> task_id_bit_num_) << task_id_bit_num_;
}
......@@ -97,16 +91,6 @@ int64_t IDMgr::GlobalThrdId4TaskId(int64_t task_id) const {
return (task_id >> shift) << shift;
}
int64_t IDMgr::LocalWorkStreamId4TaskId(int64_t task_id) const {
int64_t tmp = (task_id << (machine_id_bit_num_ + thread_id_bit_num_));
tmp &= ~(static_cast<int64_t>(1) << 63);
return tmp >> (63 - local_work_stream_id_bit_num_);
}
int64_t IDMgr::LocalWorkStreamId4ActorId(int64_t actor_id) const {
return LocalWorkStreamId4TaskId(actor_id);
}
int64_t IDMgr::AllocateChainId(int64_t global_work_stream_id) {
CHECK_LT(stream_id2chain_cnt_[global_work_stream_id],
(static_cast<int64_t>(1) << task_id_bit_num_) - 1);
......
......@@ -41,7 +41,7 @@ class IDMgr final {
int64_t BaseIndependentThrdId() const;
void UpdateBaseIndependentThrdId(int64_t val);
int64_t NewTaskId(int64_t machine_id, int64_t thrd_id, int64_t local_work_stream_id);
int64_t NewTaskId(int64_t machine_id, int64_t thrd_id);
int64_t NewRegstDescId() { return regst_desc_id_count_++; }
int64_t NewMemBlockId() { return mem_block_id_count_++; }
int64_t NewChunkId() { return chunk_id_count_++; }
......@@ -65,14 +65,6 @@ class IDMgr final {
int64_t MachineId4ActorId(int64_t actor_id) const;
int64_t ThrdId4ActorId(int64_t actor_id) const;
// local_work_stream_id
// for cpu:
// 0: the actor thread
// for gpu:
// 0: the global cuda stream
int64_t AllocateLocalWorkStreamId(int64_t machine_id, int64_t thrd_id);
int64_t LocalWorkStreamId4TaskId(int64_t task_id) const;
int64_t LocalWorkStreamId4ActorId(int64_t actor_id) const;
// global_thread_id
// sign | machine_id | thrd_id | 0 | 0
// 1 | 10 | 11 | 21 | 21
......
......@@ -64,14 +64,12 @@ TEST(IDMgr, compile_task_id) {
(static_cast<int64_t>(1) << machine_id_shl) + (static_cast<int64_t>(2) << thread_id_shl);
int64_t machine3thrd4 =
(static_cast<int64_t>(3) << machine_id_shl) + (static_cast<int64_t>(4) << thread_id_shl);
int64_t local_work_stream1 = (static_cast<int64_t>(1) << local_work_stream_shl);
int64_t local_work_stream3 = (static_cast<int64_t>(3) << local_work_stream_shl);
ASSERT_EQ(Global<IDMgr>::Get()->NewTaskId(1, 2, 0), machine1thrd2 | 0 | 0);
ASSERT_EQ(Global<IDMgr>::Get()->NewTaskId(1, 2, 1), machine1thrd2 | local_work_stream1 | 1);
ASSERT_EQ(Global<IDMgr>::Get()->NewTaskId(1, 2, 1), machine1thrd2 | local_work_stream1 | 2);
ASSERT_EQ(Global<IDMgr>::Get()->NewTaskId(3, 4, 1), machine3thrd4 | local_work_stream1 | 0);
ASSERT_EQ(Global<IDMgr>::Get()->NewTaskId(3, 4, 1), machine3thrd4 | local_work_stream1 | 1);
ASSERT_EQ(Global<IDMgr>::Get()->NewTaskId(3, 4, 3), machine3thrd4 | local_work_stream3 | 2);
ASSERT_EQ(Global<IDMgr>::Get()->NewTaskId(1, 2), machine1thrd2 | 0);
ASSERT_EQ(Global<IDMgr>::Get()->NewTaskId(1, 2), machine1thrd2 | 1);
ASSERT_EQ(Global<IDMgr>::Get()->NewTaskId(1, 2), machine1thrd2 | 2);
ASSERT_EQ(Global<IDMgr>::Get()->NewTaskId(3, 4), machine3thrd4 | 0);
ASSERT_EQ(Global<IDMgr>::Get()->NewTaskId(3, 4), machine3thrd4 | 1);
ASSERT_EQ(Global<IDMgr>::Get()->NewTaskId(3, 4), machine3thrd4 | 2);
Delete();
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册