Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Oneflow-Inc
oneflow
提交
9ec6871d
O
oneflow
项目概览
Oneflow-Inc
/
oneflow
上一次同步 2 年多
通知
13
Star
2733
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
O
oneflow
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
9ec6871d
编写于
11月 02, 2021
作者:
L
leaves-zwx
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
modify by review
上级
d338b271
变更
16
隐藏空白更改
内联
并排
Showing
16 changed file
with
84 addition
and
72 deletion
+84
-72
oneflow/core/device/device_id.h
oneflow/core/device/device_id.h
+21
-16
oneflow/core/device/stream_index.h
oneflow/core/device/stream_index.h
+1
-1
oneflow/core/graph/boxing/collective_boxing_sub_task_graph_builder.cpp
...graph/boxing/collective_boxing_sub_task_graph_builder.cpp
+9
-9
oneflow/core/graph/boxing/naive_b2p_sub_task_graph_builder.cpp
...ow/core/graph/boxing/naive_b2p_sub_task_graph_builder.cpp
+4
-3
oneflow/core/graph/boxing/slice_boxing_sub_task_graph_builder.cpp
...core/graph/boxing/slice_boxing_sub_task_graph_builder.cpp
+2
-2
oneflow/core/graph/copy_task_node.cpp
oneflow/core/graph/copy_task_node.cpp
+2
-2
oneflow/core/graph/stream_index_getter_registry_manager.cpp
oneflow/core/graph/stream_index_getter_registry_manager.cpp
+1
-1
oneflow/core/graph/stream_index_getter_registry_manager.h
oneflow/core/graph/stream_index_getter_registry_manager.h
+1
-1
oneflow/core/graph/task_graph.cpp
oneflow/core/graph/task_graph.cpp
+11
-9
oneflow/core/graph/task_id.cpp
oneflow/core/graph/task_id.cpp
+5
-4
oneflow/core/graph/task_id.h
oneflow/core/graph/task_id.h
+7
-6
oneflow/core/graph/task_id_generator.h
oneflow/core/graph/task_id_generator.h
+1
-1
oneflow/core/memory/memory_zone.cpp
oneflow/core/memory/memory_zone.cpp
+3
-3
oneflow/core/memory/memory_zone.h
oneflow/core/memory/memory_zone.h
+1
-1
oneflow/core/stream/stream_id.cpp
oneflow/core/stream/stream_id.cpp
+4
-3
oneflow/core/stream/stream_id.h
oneflow/core/stream/stream_id.h
+11
-10
未找到文件。
oneflow/core/device/device_id.h
浏览文件 @
9ec6871d
...
...
@@ -29,27 +29,32 @@ namespace oneflow {
class
DeviceId
{
public:
using
index_t
=
uint32_t
;
using
node_index_t
=
uint32_t
;
using
device_type_t
=
uint32_t
;
using
device_index_t
=
uint32_t
;
constexpr
static
size_t
kNodeIndexBits
=
19
;
constexpr
static
size_t
kDeviceTypeBits
=
5
;
constexpr
static
size_t
kDeviceIndexBits
=
7
;
constexpr
static
index_t
kMaxNodeIndex
=
(
index_t
{
1
}
<<
kNodeIndexBits
)
-
index_t
{
1
};
constexpr
static
index_t
kMaxDeviceTypeVal
=
(
index_t
{
1
}
<<
kDeviceTypeBits
)
-
index_t
{
1
};
constexpr
static
index_t
kMaxDeviceIndex
=
(
index_t
{
1
}
<<
kDeviceIndexBits
)
-
index_t
{
1
};
DeviceId
(
index_t
node_index
,
DeviceType
device_type
,
index_t
device_index
)
constexpr
static
node_index_t
kMaxNodeIndex
=
(
node_index_t
{
1
}
<<
kNodeIndexBits
)
-
node_index_t
{
1
};
constexpr
static
device_type_t
kMaxDeviceTypeVal
=
(
device_type_t
{
1
}
<<
kDeviceTypeBits
)
-
device_type_t
{
1
};
constexpr
static
device_index_t
kMaxDeviceIndex
=
(
device_index_t
{
1
}
<<
kDeviceIndexBits
)
-
device_index_t
{
1
};
DeviceId
(
node_index_t
node_index
,
DeviceType
device_type
,
device_index_t
device_index
)
:
node_index_
(
node_index
),
device_type_
(
static_cast
<
index
_t
>
(
device_type
)),
device_type_
(
static_cast
<
device_type
_t
>
(
device_type
)),
device_index_
(
device_index
)
{
CHECK_LE
(
node_index_
,
kMaxNodeIndex
);
CHECK_LE
(
device_type_
,
kMaxDeviceTypeVal
);
CHECK_LE
(
device_index
,
kMaxDeviceIndex
);
CHECK_LE
(
device_index
_
,
kMaxDeviceIndex
);
}
index_t
node_index
()
const
{
return
node_index_
;
}
node_
index_t
node_index
()
const
{
return
node_index_
;
}
DeviceType
device_type
()
const
{
return
static_cast
<
DeviceType
>
(
device_type_
);
}
index_t
device_index
()
const
{
return
device_index_
;
}
device_
index_t
device_index
()
const
{
return
device_index_
;
}
bool
operator
==
(
const
DeviceId
&
rhs
)
const
{
return
node_index_
==
rhs
.
node_index_
&&
device_type_
==
rhs
.
device_type_
...
...
@@ -59,16 +64,16 @@ class DeviceId {
bool
operator
!=
(
const
DeviceId
&
rhs
)
const
{
return
!
(
*
this
==
rhs
);
}
size_t
hash
()
const
{
size_t
hash
=
std
::
hash
<
index_t
>
{}(
node_index_
);
HashCombine
(
&
hash
,
std
::
hash
<
index
_t
>
{}(
device_type_
));
HashCombine
(
&
hash
,
std
::
hash
<
index_t
>
{}(
device_index_
));
size_t
hash
=
std
::
hash
<
node_
index_t
>
{}(
node_index_
);
HashCombine
(
&
hash
,
std
::
hash
<
device_type
_t
>
{}(
device_type_
));
HashCombine
(
&
hash
,
std
::
hash
<
device_
index_t
>
{}(
device_index_
));
return
hash
;
}
private:
index_t
node_index_
;
index
_t
device_type_
;
index_t
device_index_
;
node_
index_t
node_index_
;
device_type
_t
device_type_
;
device_
index_t
device_index_
;
};
}
// namespace oneflow
...
...
oneflow/core/device/stream_index.h
浏览文件 @
9ec6871d
...
...
@@ -25,7 +25,7 @@ namespace oneflow {
class
StreamIndexGenerator
{
public:
virtual
~
StreamIndexGenerator
()
{}
using
index_t
=
StreamId
::
index_t
;
using
index_t
=
StreamId
::
stream_
index_t
;
virtual
index_t
GenerateComputeStreamIndex
()
=
0
;
virtual
index_t
GenerateH2DStreamIndex
()
=
0
;
...
...
oneflow/core/graph/boxing/collective_boxing_sub_task_graph_builder.cpp
浏览文件 @
9ec6871d
...
...
@@ -65,8 +65,8 @@ void NcclInitCollectiveNode(CollectiveBoxingGenericTaskNode* node,
const
int64_t
machine_id
=
CHECK_JUST
(
parallel_desc
.
MachineId4ParallelId
(
parallel_id
));
const
int64_t
device_index
=
CHECK_JUST
(
parallel_desc
.
DeviceId4ParallelId
(
parallel_id
));
DeviceId
device_id
{
static_cast
<
DeviceId
::
index_t
>
(
machine_id
),
DeviceType
::
kGPU
,
static_cast
<
DeviceId
::
index_t
>
(
device_index
)};
DeviceId
device_id
{
static_cast
<
DeviceId
::
node_
index_t
>
(
machine_id
),
DeviceType
::
kGPU
,
static_cast
<
DeviceId
::
device_
index_t
>
(
device_index
)};
auto
*
stream_index_generator
=
dynamic_cast
<
CudaStreamIndexGenerator
*>
(
Global
<
IDMgr
>::
Get
()
->
GetStreamIndexGeneratorManager
()
->
GetGenerator
(
device_id
));
CHECK_NOTNULL
(
stream_index_generator
);
...
...
@@ -191,8 +191,8 @@ class NcclCollectiveBoxingP2SNoncontinuousSubTskGphBuilder final : public SubTsk
FOR_RANGE
(
int64_t
,
i
,
0
,
in_parallel_desc
.
parallel_num
())
{
const
int64_t
machine_id
=
CHECK_JUST
(
in_parallel_desc
.
MachineId4ParallelId
(
i
));
const
int64_t
device_index
=
CHECK_JUST
(
in_parallel_desc
.
DeviceId4ParallelId
(
i
));
DeviceId
device_id
{
static_cast
<
DeviceId
::
index_t
>
(
machine_id
),
DeviceType
::
kGPU
,
static_cast
<
DeviceId
::
index_t
>
(
device_index
)};
DeviceId
device_id
{
static_cast
<
DeviceId
::
node_
index_t
>
(
machine_id
),
DeviceType
::
kGPU
,
static_cast
<
DeviceId
::
device_
index_t
>
(
device_index
)};
auto
*
stream_index_generator
=
Global
<
IDMgr
>::
Get
()
->
GetStreamIndexGeneratorManager
()
->
GetGenerator
(
device_id
);
auto
stream_index
=
stream_index_generator
->
GenerateComputeStreamIndex
();
...
...
@@ -293,8 +293,8 @@ class NcclCollectiveBoxingS2BNoncontinuousSubTskGphBuilder final : public SubTsk
FOR_RANGE
(
int64_t
,
i
,
0
,
in_parallel_desc
.
parallel_num
())
{
const
int64_t
machine_id
=
CHECK_JUST
(
out_parallel_desc
.
MachineId4ParallelId
(
i
));
const
int64_t
device_index
=
CHECK_JUST
(
out_parallel_desc
.
DeviceId4ParallelId
(
i
));
DeviceId
device_id
{
static_cast
<
DeviceId
::
index_t
>
(
machine_id
),
DeviceType
::
kGPU
,
static_cast
<
DeviceId
::
index_t
>
(
device_index
)};
DeviceId
device_id
{
static_cast
<
DeviceId
::
node_
index_t
>
(
machine_id
),
DeviceType
::
kGPU
,
static_cast
<
DeviceId
::
device_
index_t
>
(
device_index
)};
auto
*
stream_index_generator
=
Global
<
IDMgr
>::
Get
()
->
GetStreamIndexGeneratorManager
()
->
GetGenerator
(
device_id
);
auto
stream_index
=
stream_index_generator
->
GenerateComputeStreamIndex
();
...
...
@@ -406,7 +406,7 @@ class CollectiveBoxingScatterThenNcclAllGatherSubTskGphBuilder final : public Su
SliceBoxingTaskNode
*
slice_node
=
ctx
->
task_graph
()
->
NewNode
<
SliceBoxingTaskNode
>
();
// slice on cpu
const
auto
in_machine_id
=
CHECK_JUST
(
in_parallel_desc
.
MachineId4ParallelId
(
0
));
DeviceId
device_id
{
static_cast
<
DeviceId
::
index_t
>
(
in_machine_id
),
DeviceType
::
kCPU
,
0
};
DeviceId
device_id
{
static_cast
<
DeviceId
::
node_
index_t
>
(
in_machine_id
),
DeviceType
::
kCPU
,
0
};
auto
*
stream_index_generator
=
Global
<
IDMgr
>::
Get
()
->
GetStreamIndexGeneratorManager
()
->
GetGenerator
(
device_id
);
auto
stream_index
=
stream_index_generator
->
GenerateComputeStreamIndex
();
...
...
@@ -522,8 +522,8 @@ class NcclCollectiveBoxingAll2AllSubTskGphBuilder final : public SubTskGphBuilde
FOR_RANGE
(
int64_t
,
i
,
0
,
in_parallel_desc
.
parallel_num
())
{
const
int64_t
machine_id
=
CHECK_JUST
(
in_parallel_desc
.
MachineId4ParallelId
(
i
));
const
int64_t
device_index
=
CHECK_JUST
(
in_parallel_desc
.
DeviceId4ParallelId
(
i
));
DeviceId
device_id
{
static_cast
<
DeviceId
::
index_t
>
(
machine_id
),
DeviceType
::
kGPU
,
static_cast
<
DeviceId
::
index_t
>
(
device_index
)};
DeviceId
device_id
{
static_cast
<
DeviceId
::
node_
index_t
>
(
machine_id
),
DeviceType
::
kGPU
,
static_cast
<
DeviceId
::
device_
index_t
>
(
device_index
)};
auto
*
stream_index_generator
=
Global
<
IDMgr
>::
Get
()
->
GetStreamIndexGeneratorManager
()
->
GetGenerator
(
device_id
);
auto
stream_index
=
stream_index_generator
->
GenerateComputeStreamIndex
();
...
...
oneflow/core/graph/boxing/naive_b2p_sub_task_graph_builder.cpp
浏览文件 @
9ec6871d
...
...
@@ -58,8 +58,8 @@ Maybe<SubTskGphBuilderStatus> NaiveB2PSubTskGphBuilder::Build(
int64_t
thrd_id
=
-
1
;
if
(
out_parallel_desc
.
device_type
()
==
DeviceType
::
kGPU
)
{
#ifdef WITH_CUDA
DeviceId
device_id
{
static_cast
<
DeviceId
::
index_t
>
(
out_machine_id
),
DeviceType
::
kGPU
,
static_cast
<
DeviceId
::
index_t
>
(
out_dev_phy_id
)};
DeviceId
device_id
{
static_cast
<
DeviceId
::
node_
index_t
>
(
out_machine_id
),
DeviceType
::
kGPU
,
static_cast
<
DeviceId
::
device_
index_t
>
(
out_dev_phy_id
)};
auto
*
stream_index_generator
=
Global
<
IDMgr
>::
Get
()
->
GetStreamIndexGeneratorManager
()
->
GetGenerator
(
device_id
);
auto
stream_index
=
stream_index_generator
->
GenerateComputeStreamIndex
();
...
...
@@ -68,7 +68,8 @@ Maybe<SubTskGphBuilderStatus> NaiveB2PSubTskGphBuilder::Build(
UNIMPLEMENTED
();
#endif
}
else
if
(
out_parallel_desc
.
device_type
()
==
DeviceType
::
kCPU
)
{
DeviceId
device_id
{
static_cast
<
DeviceId
::
index_t
>
(
out_machine_id
),
DeviceType
::
kCPU
,
0
};
DeviceId
device_id
{
static_cast
<
DeviceId
::
node_index_t
>
(
out_machine_id
),
DeviceType
::
kCPU
,
0
};
auto
*
stream_index_generator
=
Global
<
IDMgr
>::
Get
()
->
GetStreamIndexGeneratorManager
()
->
GetGenerator
(
device_id
);
auto
stream_index
=
stream_index_generator
->
GenerateComputeStreamIndex
();
...
...
oneflow/core/graph/boxing/slice_boxing_sub_task_graph_builder.cpp
浏览文件 @
9ec6871d
...
...
@@ -61,8 +61,8 @@ Maybe<SubTskGphBuilderStatus> SliceBoxingSubTskGphBuilder::Build(
}
else
{
dev_id
=
CHECK_JUST
(
pd
.
DeviceId4ParallelId
(
parallel_id
));
}
DeviceId
device_id
{
static_cast
<
DeviceId
::
index_t
>
(
machine_id
),
pd
.
device_type
(),
static_cast
<
DeviceId
::
index_t
>
(
dev_id
)};
DeviceId
device_id
{
static_cast
<
DeviceId
::
node_
index_t
>
(
machine_id
),
pd
.
device_type
(),
static_cast
<
DeviceId
::
device_
index_t
>
(
dev_id
)};
auto
*
stream_index_generator
=
Global
<
IDMgr
>::
Get
()
->
GetStreamIndexGeneratorManager
()
->
GetGenerator
(
device_id
);
auto
stream_index
=
stream_index_generator
->
GenerateComputeStreamIndex
();
...
...
oneflow/core/graph/copy_task_node.cpp
浏览文件 @
9ec6871d
...
...
@@ -45,7 +45,7 @@ void CopyHdTaskNode::Init(CopyHdOpConf::Type copy_type, const DeviceId& device_i
set_machine_id
(
device_id
.
node_index
());
auto
*
stream_index_generator
=
Global
<
IDMgr
>::
Get
()
->
GetStreamIndexGeneratorManager
()
->
GetGenerator
(
device_id
);
StreamId
::
index_t
stream_index
=
0
;
StreamId
::
stream_
index_t
stream_index
=
0
;
if
(
copy_type
==
CopyHdOpConf
::
H2D
)
{
stream_index
=
stream_index_generator
->
GenerateH2DStreamIndex
();
}
else
if
(
copy_type
==
CopyHdOpConf
::
D2H
)
{
...
...
@@ -84,7 +84,7 @@ OperatorConf CopyHdTaskNode::NewCopyOpConf() {
void
CopyCommNetTaskNode
::
Init
(
int64_t
machine_id
,
const
LogicalBlobId
&
lbi
)
{
set_machine_id
(
machine_id
);
DeviceId
device_id
{
static_cast
<
DeviceId
::
index_t
>
(
machine_id
),
DeviceType
::
kCPU
,
0
};
DeviceId
device_id
{
static_cast
<
DeviceId
::
node_
index_t
>
(
machine_id
),
DeviceType
::
kCPU
,
0
};
auto
*
generator
=
dynamic_cast
<
CPUStreamIndexGenerator
*>
(
Global
<
IDMgr
>::
Get
()
->
GetStreamIndexGeneratorManager
()
->
GetGenerator
(
device_id
));
CHECK_NOTNULL
(
generator
);
...
...
oneflow/core/graph/stream_index_getter_registry_manager.cpp
浏览文件 @
9ec6871d
...
...
@@ -22,7 +22,7 @@ StreamIndexGetterRegistryManager& StreamIndexGetterRegistryManager::Get() {
return
mgr
;
}
StreamId
::
index_t
StreamIndexGetterRegistryManager
::
StreamIndex4DeviceIdAndTaskType
(
StreamId
::
stream_
index_t
StreamIndexGetterRegistryManager
::
StreamIndex4DeviceIdAndTaskType
(
DeviceId
device_id
,
TaskType
task_type
)
{
auto
index_getter_fn
=
StreamIndexGetterRegistryManager
::
GetStreamIndexGetterFunc
(
device_id
.
device_type
(),
task_type
);
...
...
oneflow/core/graph/stream_index_getter_registry_manager.h
浏览文件 @
9ec6871d
...
...
@@ -47,7 +47,7 @@ class StreamIndexGetterRegistryManager final {
StreamIndexKeyMap
<
StreamIndexGetterFn
>&
StreamIndexGetterFuncs
();
StreamId
::
index_t
StreamIndex4DeviceIdAndTaskType
(
DeviceId
device_id
,
TaskType
task_type
);
StreamId
::
stream_
index_t
StreamIndex4DeviceIdAndTaskType
(
DeviceId
device_id
,
TaskType
task_type
);
private:
StreamIndexGetterFn
GetStreamIndexGetterFunc
(
DeviceType
dev_type
,
TaskType
task_type
);
...
...
oneflow/core/graph/task_graph.cpp
浏览文件 @
9ec6871d
...
...
@@ -284,16 +284,17 @@ void GenSortedCompTaskNodes(const OpNode* op_node, std::vector<CompTaskNode*>* s
comp_task_node
->
mut_parallel_ctx
()
->
set_parallel_id
(
parallel_idx
++
);
comp_task_node
->
mut_parallel_ctx
()
->
set_parallel_num
(
parallel_num
);
DeviceId
::
index_t
device_index
=
parallel_desc
.
device_type
()
==
DeviceType
::
kCPU
?
0
:
static_cast
<
DeviceId
::
index_t
>
(
dev_phy_id
);
DeviceId
device_id
{
static_cast
<
DeviceId
::
index_t
>
(
machine_id
),
parallel_desc
.
device_type
(),
device_index
};
StreamId
::
index_t
stream_index
{};
DeviceId
::
device_index_t
device_index
=
parallel_desc
.
device_type
()
==
DeviceType
::
kCPU
?
0
:
static_cast
<
DeviceId
::
device_index_t
>
(
dev_phy_id
);
DeviceId
device_id
{
static_cast
<
DeviceId
::
node_index_t
>
(
machine_id
),
parallel_desc
.
device_type
(),
device_index
};
StreamId
::
stream_index_t
stream_index
{};
if
(
op_node
->
op
().
op_conf
().
has_stream_index_hint
())
{
int32_t
stream_index_hint
=
op_node
->
op
().
op_conf
().
stream_index_hint
();
LOG
(
INFO
)
<<
"set op: "
<<
op_node
->
op
().
op_name
()
<<
" to stream: "
<<
stream_index_hint
;
stream_index
=
static_cast
<
StreamId
::
index_t
>
(
stream_index_hint
);
stream_index
=
static_cast
<
StreamId
::
stream_
index_t
>
(
stream_index_hint
);
}
else
{
stream_index
=
StreamIndexGetterRegistryManager
::
Get
().
StreamIndex4DeviceIdAndTaskType
(
device_id
,
comp_task_node
->
GetTaskType
());
...
...
@@ -522,8 +523,9 @@ TaskNode* TaskGraph::GetProxyNode(TaskNode* src_node, const LogicalBlobId& lbi,
const
int64_t
dev_id
=
CHECK_JUST
(
dst_parallel_desc
.
DeviceId4ParallelId
(
dst_parallel_id
));
DeviceType
device_type
=
dst_parallel_desc
.
device_type
();
auto
device_index
=
(
device_type
==
DeviceType
::
kCPU
?
0
:
static_cast
<
DeviceId
::
index_t
>
(
dev_id
));
MemZoneId
mem_zone_id
{
static_cast
<
MemZoneId
::
index_t
>
(
dst_machine_id
),
device_type
,
device_index
};
(
device_type
==
DeviceType
::
kCPU
?
0
:
static_cast
<
DeviceId
::
node_index_t
>
(
dev_id
));
MemZoneId
mem_zone_id
{
static_cast
<
MemZoneId
::
node_index_t
>
(
dst_machine_id
),
device_type
,
device_index
};
return
GetProxyNode
(
src_node
,
lbi
,
mem_zone_id
);
}
...
...
oneflow/core/graph/task_id.cpp
浏览文件 @
9ec6871d
...
...
@@ -65,10 +65,11 @@ TaskId DecodeTaskIdFromInt64(int64_t task_id_val) {
int64_t
device_index
=
(
task_id_val
&
kDeviceIndexInt64Mask
)
>>
kDeviceIndexShift
;
int64_t
stream_index
=
(
task_id_val
&
kStreamIndexInt64Mask
)
>>
kStreamIndexShift
;
int64_t
task_index
=
task_id_val
&
kTaskIndexInt64Mask
;
StreamId
stream_id
{
static_cast
<
DeviceId
::
index_t
>
(
node_index
),
static_cast
<
DeviceType
>
(
device_type
),
static_cast
<
DeviceId
::
index_t
>
(
device_index
),
static_cast
<
StreamId
::
index_t
>
(
stream_index
)};
return
TaskId
{
stream_id
,
static_cast
<
TaskId
::
index_t
>
(
task_index
)};
StreamId
stream_id
{
static_cast
<
DeviceId
::
node_index_t
>
(
node_index
),
static_cast
<
DeviceType
>
(
device_type
),
static_cast
<
DeviceId
::
device_index_t
>
(
device_index
),
static_cast
<
StreamId
::
stream_index_t
>
(
stream_index
)};
return
TaskId
{
stream_id
,
static_cast
<
TaskId
::
task_index_t
>
(
task_index
)};
}
int64_t
MachineId4ActorId
(
int64_t
actor_id
)
{
...
...
oneflow/core/graph/task_id.h
浏览文件 @
9ec6871d
...
...
@@ -22,18 +22,19 @@ namespace oneflow {
class
TaskId
{
public:
using
index_t
=
uint32_t
;
using
task_
index_t
=
uint32_t
;
const
static
size_t
kTaskIndexBits
=
21
;
constexpr
static
index_t
kMaxTaskIndex
=
(
index_t
{
1
}
<<
kTaskIndexBits
)
-
index_t
{
1
};
constexpr
static
task_index_t
kMaxTaskIndex
=
(
task_index_t
{
1
}
<<
kTaskIndexBits
)
-
task_index_t
{
1
};
TaskId
(
const
StreamId
&
stream_id
,
index_t
task_index
)
TaskId
(
const
StreamId
&
stream_id
,
task_
index_t
task_index
)
:
stream_id_
(
stream_id
),
task_index_
(
task_index
)
{
CHECK_LE
(
task_index_
,
kMaxTaskIndex
);
}
const
StreamId
&
stream_id
()
const
{
return
stream_id_
;
}
index_t
task_index
()
const
{
return
task_index_
;
}
task_
index_t
task_index
()
const
{
return
task_index_
;
}
bool
operator
==
(
const
TaskId
&
rhs
)
const
{
return
stream_id_
==
rhs
.
stream_id_
&&
task_index_
==
rhs
.
task_index_
;
...
...
@@ -42,13 +43,13 @@ class TaskId {
size_t
hash
()
const
{
size_t
hash
=
stream_id_
.
hash
();
HashCombine
(
&
hash
,
std
::
hash
<
index_t
>
{}(
task_index_
));
HashCombine
(
&
hash
,
std
::
hash
<
task_
index_t
>
{}(
task_index_
));
return
hash
;
}
private:
StreamId
stream_id_
;
index_t
task_index_
;
task_
index_t
task_index_
;
};
int64_t
EncodeTaskIdToInt64
(
const
TaskId
&
);
...
...
oneflow/core/graph/task_id_generator.h
浏览文件 @
9ec6871d
...
...
@@ -22,7 +22,7 @@ namespace oneflow {
class
TaskIdGenerator
final
{
public:
using
task_index_t
=
TaskId
::
index_t
;
using
task_index_t
=
TaskId
::
task_
index_t
;
TaskIdGenerator
()
=
default
;
OF_DISALLOW_COPY_AND_MOVE
(
TaskIdGenerator
);
...
...
oneflow/core/memory/memory_zone.cpp
浏览文件 @
9ec6871d
...
...
@@ -32,7 +32,7 @@ constexpr int64_t kMemZoneIdDeviceIndexInt64Mask = (int64_t{1} << MemZoneId::kDe
const
MemZoneId
kInvalidMemZoneId
=
MemZoneId
{
0
,
DeviceType
::
kInvalidDevice
,
0
};
MemZoneId
GetNodeCPUMemZoneId
(
MemZoneId
::
index_t
node_index
)
{
MemZoneId
GetNodeCPUMemZoneId
(
MemZoneId
::
node_
index_t
node_index
)
{
return
MemZoneId
{
node_index
,
DeviceType
::
kCPU
,
0
};
}
...
...
@@ -47,9 +47,9 @@ MemZoneId DecodeMemZoneIdFromInt64(int64_t mem_zone_id) {
int64_t
node_index
=
(
mem_zone_id
&
kMemZoneIdNodeIndexInt64Mask
)
>>
kMemZoneIdNodeIndexShift
;
int64_t
device_type
=
(
mem_zone_id
&
kMemZoneIdDeviceTypeInt64Mask
)
>>
kMemZoneIdDeviceTypeShift
;
int64_t
device_index
=
mem_zone_id
&
kMemZoneIdDeviceIndexInt64Mask
;
return
MemZoneId
(
static_cast
<
MemZoneId
::
index_t
>
(
node_index
),
return
MemZoneId
(
static_cast
<
MemZoneId
::
node_
index_t
>
(
node_index
),
static_cast
<
DeviceType
>
(
device_type
),
static_cast
<
MemZoneId
::
index_t
>
(
device_index
));
static_cast
<
MemZoneId
::
device_
index_t
>
(
device_index
));
}
}
// namespace oneflow
oneflow/core/memory/memory_zone.h
浏览文件 @
9ec6871d
...
...
@@ -25,7 +25,7 @@ using MemZoneId = DeviceId;
int64_t
EncodeMemZoneIdToInt64
(
const
MemZoneId
&
);
MemZoneId
DecodeMemZoneIdFromInt64
(
int64_t
);
MemZoneId
GetNodeCPUMemZoneId
(
MemZoneId
::
index_t
node_index
);
MemZoneId
GetNodeCPUMemZoneId
(
MemZoneId
::
node_
index_t
node_index
);
extern
const
MemZoneId
kInvalidMemZoneId
;
...
...
oneflow/core/stream/stream_id.cpp
浏览文件 @
9ec6871d
...
...
@@ -59,9 +59,10 @@ StreamId DecodeStreamIdFromInt64(int64_t stream_id_val) {
int64_t
device_type
=
(
stream_id_val
&
kDeviceTypeInt64Mask
)
>>
kDeviceTypeShift
;
int64_t
device_index
=
(
stream_id_val
&
kDeviceIndexInt64Mask
)
>>
kDeviceIndexShift
;
int64_t
stream_index
=
(
stream_id_val
&
kStreamIndexInt64Mask
);
return
StreamId
{
static_cast
<
DeviceId
::
index_t
>
(
node_index
),
static_cast
<
DeviceType
>
(
device_type
),
static_cast
<
DeviceId
::
index_t
>
(
device_index
),
static_cast
<
StreamId
::
index_t
>
(
stream_index
)};
return
StreamId
{
static_cast
<
DeviceId
::
node_index_t
>
(
node_index
),
static_cast
<
DeviceType
>
(
device_type
),
static_cast
<
DeviceId
::
device_index_t
>
(
device_index
),
static_cast
<
StreamId
::
stream_index_t
>
(
stream_index
)};
}
}
// namespace oneflow
oneflow/core/stream/stream_id.h
浏览文件 @
9ec6871d
...
...
@@ -22,26 +22,27 @@ namespace oneflow {
class
StreamId
{
public:
using
index_t
=
uint32_t
;
using
stream_
index_t
=
uint32_t
;
constexpr
static
size_t
kStreamIndexBits
=
12
;
constexpr
static
index_t
kMaxStreamIndex
=
(
index_t
{
1
}
<<
kStreamIndexBits
)
-
index_t
{
1
};
constexpr
static
stream_index_t
kMaxStreamIndex
=
(
stream_index_t
{
1
}
<<
kStreamIndexBits
)
-
stream_index_t
{
1
};
StreamId
(
const
DeviceId
&
device_id
,
index_t
stream_index
)
StreamId
(
const
DeviceId
&
device_id
,
stream_
index_t
stream_index
)
:
device_id_
(
device_id
),
stream_index_
(
stream_index
)
{
CHECK_LE
(
stream_index
,
kMaxStreamIndex
);
}
StreamId
(
DeviceId
::
index_t
node_index
,
DeviceType
device_type
,
DeviceId
::
index_t
device_index
,
index_t
stream_index
)
StreamId
(
DeviceId
::
node_index_t
node_index
,
DeviceType
device_type
,
DeviceId
::
node_index_t
device_index
,
stream_
index_t
stream_index
)
:
device_id_
(
node_index
,
device_type
,
device_index
),
stream_index_
(
stream_index
)
{
CHECK_LE
(
stream_index
,
kMaxStreamIndex
);
}
const
DeviceId
&
device_id
()
const
{
return
device_id_
;
}
DeviceId
::
index_t
node_index
()
const
{
return
device_id_
.
node_index
();
}
DeviceId
::
node_
index_t
node_index
()
const
{
return
device_id_
.
node_index
();
}
DeviceType
device_type
()
const
{
return
device_id_
.
device_type
();
}
DeviceId
::
index_t
device_index
()
const
{
return
device_id_
.
device_index
();
}
index_t
stream_index
()
const
{
return
stream_index_
;
}
DeviceId
::
node_
index_t
device_index
()
const
{
return
device_id_
.
device_index
();
}
stream_
index_t
stream_index
()
const
{
return
stream_index_
;
}
bool
operator
==
(
const
StreamId
&
rhs
)
const
{
return
device_id_
==
rhs
.
device_id_
&&
stream_index_
==
rhs
.
stream_index_
;
...
...
@@ -51,13 +52,13 @@ class StreamId {
size_t
hash
()
const
{
size_t
hash
=
device_id_
.
hash
();
HashCombine
(
&
hash
,
std
::
hash
<
index_t
>
{}(
stream_index_
));
HashCombine
(
&
hash
,
std
::
hash
<
stream_
index_t
>
{}(
stream_index_
));
return
hash
;
}
private:
DeviceId
device_id_
;
index_t
stream_index_
;
stream_
index_t
stream_index_
;
};
int64_t
EncodeStreamIdToInt64
(
const
StreamId
&
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录