Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Oneflow-Inc
oneflow
提交
03db51d7
O
oneflow
项目概览
Oneflow-Inc
/
oneflow
上一次同步 2 年多
通知
13
Star
2733
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
O
oneflow
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
03db51d7
编写于
6月 30, 2017
作者:
W
willzhang4a58
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
first auto format
上级
3fbb2567
变更
200
展开全部
隐藏空白更改
内联
并排
Showing
200 changed file
with
1989 addition
and
2362 deletion
+1989
-2362
oneflow/core/actor/actor.cpp
oneflow/core/actor/actor.cpp
+14
-18
oneflow/core/actor/actor.h
oneflow/core/actor/actor.h
+20
-21
oneflow/core/actor/actor_message.cpp
oneflow/core/actor/actor_message.cpp
+4
-6
oneflow/core/actor/actor_message.h
oneflow/core/actor/actor_message.h
+9
-16
oneflow/core/actor/actor_message_bus.cpp
oneflow/core/actor/actor_message_bus.cpp
+2
-2
oneflow/core/actor/actor_message_bus.h
oneflow/core/actor/actor_message_bus.h
+2
-2
oneflow/core/actor/actor_registry.cpp
oneflow/core/actor/actor_registry.cpp
+4
-3
oneflow/core/actor/actor_registry.h
oneflow/core/actor/actor_registry.h
+3
-2
oneflow/core/actor/boxing_actor.cpp
oneflow/core/actor/boxing_actor.cpp
+18
-15
oneflow/core/actor/boxing_actor.h
oneflow/core/actor/boxing_actor.h
+1
-2
oneflow/core/actor/bp_data_comp_actor.cpp
oneflow/core/actor/bp_data_comp_actor.cpp
+37
-26
oneflow/core/actor/bp_data_comp_actor.h
oneflow/core/actor/bp_data_comp_actor.h
+3
-3
oneflow/core/actor/compute_actor.h
oneflow/core/actor/compute_actor.h
+4
-4
oneflow/core/actor/copy_comm_net_actor.cpp
oneflow/core/actor/copy_comm_net_actor.cpp
+23
-18
oneflow/core/actor/copy_comm_net_actor.h
oneflow/core/actor/copy_comm_net_actor.h
+3
-4
oneflow/core/actor/copy_hd_actor.cpp
oneflow/core/actor/copy_hd_actor.cpp
+18
-15
oneflow/core/actor/copy_hd_actor.h
oneflow/core/actor/copy_hd_actor.h
+3
-4
oneflow/core/actor/cpu_device_context.h
oneflow/core/actor/cpu_device_context.h
+4
-6
oneflow/core/actor/cuda_device_context.cpp
oneflow/core/actor/cuda_device_context.cpp
+6
-8
oneflow/core/actor/cuda_device_context.h
oneflow/core/actor/cuda_device_context.h
+2
-2
oneflow/core/actor/device_context.h
oneflow/core/actor/device_context.h
+11
-17
oneflow/core/actor/fw_data_comp_actor.cpp
oneflow/core/actor/fw_data_comp_actor.cpp
+26
-26
oneflow/core/actor/fw_data_comp_actor.h
oneflow/core/actor/fw_data_comp_actor.h
+3
-3
oneflow/core/actor/model_diff_accumulate_actor.cpp
oneflow/core/actor/model_diff_accumulate_actor.cpp
+17
-17
oneflow/core/actor/model_diff_accumulate_actor.h
oneflow/core/actor/model_diff_accumulate_actor.h
+2
-2
oneflow/core/actor/model_save_comp_actor.cpp
oneflow/core/actor/model_save_comp_actor.cpp
+12
-13
oneflow/core/actor/model_save_comp_actor.h
oneflow/core/actor/model_save_comp_actor.h
+1
-1
oneflow/core/actor/model_update_comp_actor.cpp
oneflow/core/actor/model_update_comp_actor.cpp
+22
-24
oneflow/core/actor/model_update_comp_actor.h
oneflow/core/actor/model_update_comp_actor.h
+1
-2
oneflow/core/blas/cblas.h
oneflow/core/blas/cblas.h
+203
-212
oneflow/core/blas/cblas_template.cpp
oneflow/core/blas/cblas_template.cpp
+48
-52
oneflow/core/blas/cblas_template.h
oneflow/core/blas/cblas_template.h
+26
-33
oneflow/core/blas/cublas_template.cu
oneflow/core/blas/cublas_template.cu
+49
-57
oneflow/core/blas/cublas_template.h
oneflow/core/blas/cublas_template.h
+23
-32
oneflow/core/common/balanced_splitter.cpp
oneflow/core/common/balanced_splitter.cpp
+2
-2
oneflow/core/common/balanced_splitter.h
oneflow/core/common/balanced_splitter.h
+6
-6
oneflow/core/common/balanced_splitter_test.cpp
oneflow/core/common/balanced_splitter_test.cpp
+1
-1
oneflow/core/common/channel.h
oneflow/core/common/channel.h
+8
-9
oneflow/core/common/channel_test.cpp
oneflow/core/common/channel_test.cpp
+11
-27
oneflow/core/common/cuda_stream_handle.cpp
oneflow/core/common/cuda_stream_handle.cpp
+4
-10
oneflow/core/common/cuda_stream_handle.h
oneflow/core/common/cuda_stream_handle.h
+2
-3
oneflow/core/common/cuda_util.h
oneflow/core/common/cuda_util.h
+4
-7
oneflow/core/common/process_state.h
oneflow/core/common/process_state.h
+1
-2
oneflow/core/common/protobuf.cpp
oneflow/core/common/protobuf.cpp
+12
-12
oneflow/core/common/protobuf.h
oneflow/core/common/protobuf.h
+19
-24
oneflow/core/common/range.h
oneflow/core/common/range.h
+4
-4
oneflow/core/common/shape.cpp
oneflow/core/common/shape.cpp
+8
-17
oneflow/core/common/shape.h
oneflow/core/common/shape.h
+8
-10
oneflow/core/common/util.cpp
oneflow/core/common/util.cpp
+2
-3
oneflow/core/common/util.h
oneflow/core/common/util.h
+36
-44
oneflow/core/graph/boxing_task_node.cpp
oneflow/core/graph/boxing_task_node.cpp
+25
-28
oneflow/core/graph/boxing_task_node.h
oneflow/core/graph/boxing_task_node.h
+9
-13
oneflow/core/graph/chain_graph.cpp
oneflow/core/graph/chain_graph.cpp
+37
-52
oneflow/core/graph/chain_graph.h
oneflow/core/graph/chain_graph.h
+10
-23
oneflow/core/graph/comp_task_node.cpp
oneflow/core/graph/comp_task_node.cpp
+12
-13
oneflow/core/graph/comp_task_node.h
oneflow/core/graph/comp_task_node.h
+3
-4
oneflow/core/graph/copy_task_node.cpp
oneflow/core/graph/copy_task_node.cpp
+6
-6
oneflow/core/graph/copy_task_node.h
oneflow/core/graph/copy_task_node.h
+10
-19
oneflow/core/graph/data_comp_task_node.cpp
oneflow/core/graph/data_comp_task_node.cpp
+10
-14
oneflow/core/graph/data_comp_task_node.h
oneflow/core/graph/data_comp_task_node.h
+11
-17
oneflow/core/graph/data_task_graph.cpp
oneflow/core/graph/data_task_graph.cpp
+4
-6
oneflow/core/graph/data_task_graph.h
oneflow/core/graph/data_task_graph.h
+6
-8
oneflow/core/graph/exec_graph.cpp
oneflow/core/graph/exec_graph.cpp
+7
-9
oneflow/core/graph/exec_graph.h
oneflow/core/graph/exec_graph.h
+11
-12
oneflow/core/graph/graph.h
oneflow/core/graph/graph.h
+58
-76
oneflow/core/graph/graph_test.cpp
oneflow/core/graph/graph_test.cpp
+12
-11
oneflow/core/graph/in_boxing_task_node.cpp
oneflow/core/graph/in_boxing_task_node.cpp
+4
-7
oneflow/core/graph/in_boxing_task_node.h
oneflow/core/graph/in_boxing_task_node.h
+3
-4
oneflow/core/graph/logical_graph.cpp
oneflow/core/graph/logical_graph.cpp
+2
-3
oneflow/core/graph/logical_graph.h
oneflow/core/graph/logical_graph.h
+17
-28
oneflow/core/graph/model_diff_accumulate_comp_task_node.cpp
oneflow/core/graph/model_diff_accumulate_comp_task_node.cpp
+7
-6
oneflow/core/graph/model_diff_accumulate_comp_task_node.h
oneflow/core/graph/model_diff_accumulate_comp_task_node.h
+8
-8
oneflow/core/graph/model_diff_accumulate_task_graph.cpp
oneflow/core/graph/model_diff_accumulate_task_graph.cpp
+3
-4
oneflow/core/graph/model_diff_accumulate_task_graph.h
oneflow/core/graph/model_diff_accumulate_task_graph.h
+4
-5
oneflow/core/graph/model_save_comp_task_node.cpp
oneflow/core/graph/model_save_comp_task_node.cpp
+3
-3
oneflow/core/graph/model_save_comp_task_node.h
oneflow/core/graph/model_save_comp_task_node.h
+9
-9
oneflow/core/graph/model_save_task_graph.cpp
oneflow/core/graph/model_save_task_graph.cpp
+8
-5
oneflow/core/graph/model_save_task_graph.h
oneflow/core/graph/model_save_task_graph.h
+3
-4
oneflow/core/graph/model_update_comp_task_node.cpp
oneflow/core/graph/model_update_comp_task_node.cpp
+2
-2
oneflow/core/graph/model_update_comp_task_node.h
oneflow/core/graph/model_update_comp_task_node.h
+8
-8
oneflow/core/graph/model_update_task_graph.cpp
oneflow/core/graph/model_update_task_graph.cpp
+4
-5
oneflow/core/graph/model_update_task_graph.h
oneflow/core/graph/model_update_task_graph.h
+3
-4
oneflow/core/graph/node.cpp
oneflow/core/graph/node.cpp
+1
-1
oneflow/core/graph/node.h
oneflow/core/graph/node.h
+13
-29
oneflow/core/graph/out_boxing_task_node.cpp
oneflow/core/graph/out_boxing_task_node.cpp
+4
-7
oneflow/core/graph/out_boxing_task_node.h
oneflow/core/graph/out_boxing_task_node.h
+3
-4
oneflow/core/graph/stage_graph.cpp
oneflow/core/graph/stage_graph.cpp
+2
-2
oneflow/core/graph/stage_graph.h
oneflow/core/graph/stage_graph.h
+9
-23
oneflow/core/graph/task_graph.cpp
oneflow/core/graph/task_graph.cpp
+46
-58
oneflow/core/graph/task_graph.h
oneflow/core/graph/task_graph.h
+13
-16
oneflow/core/graph/task_node.cpp
oneflow/core/graph/task_node.cpp
+18
-12
oneflow/core/graph/task_node.h
oneflow/core/graph/task_node.h
+27
-31
oneflow/core/job/compiler.cpp
oneflow/core/job/compiler.cpp
+41
-46
oneflow/core/job/id_manager.h
oneflow/core/job/id_manager.h
+10
-13
oneflow/core/job/id_manager_test.cpp
oneflow/core/job/id_manager_test.cpp
+11
-11
oneflow/core/job/job_desc.cpp
oneflow/core/job/job_desc.cpp
+1
-1
oneflow/core/job/job_desc.h
oneflow/core/job/job_desc.h
+14
-9
oneflow/core/job/keyword.cpp
oneflow/core/job/keyword.cpp
+1
-1
oneflow/core/job/keyword.h
oneflow/core/job/keyword.h
+2
-2
oneflow/core/job/parallel_desc.cpp
oneflow/core/job/parallel_desc.cpp
+7
-7
oneflow/core/job/parallel_desc.h
oneflow/core/job/parallel_desc.h
+12
-14
oneflow/core/job/runtime.cpp
oneflow/core/job/runtime.cpp
+5
-7
oneflow/core/job/runtime_context.cpp
oneflow/core/job/runtime_context.cpp
+3
-7
oneflow/core/job/runtime_context.h
oneflow/core/job/runtime_context.h
+3
-4
oneflow/core/kernel/clone_kernel.cpp
oneflow/core/kernel/clone_kernel.cpp
+10
-10
oneflow/core/kernel/clone_kernel.h
oneflow/core/kernel/clone_kernel.h
+7
-5
oneflow/core/kernel/convolution_kernel.h
oneflow/core/kernel/convolution_kernel.h
+13
-7
oneflow/core/kernel/copy_hd_kernel.cpp
oneflow/core/kernel/copy_hd_kernel.cpp
+2
-2
oneflow/core/kernel/copy_hd_kernel_test.cpp
oneflow/core/kernel/copy_hd_kernel_test.cpp
+13
-21
oneflow/core/kernel/data_loader_kernel.cpp
oneflow/core/kernel/data_loader_kernel.cpp
+6
-5
oneflow/core/kernel/data_loader_kernel.h
oneflow/core/kernel/data_loader_kernel.h
+2
-2
oneflow/core/kernel/innerproduct_kernel.cpp
oneflow/core/kernel/innerproduct_kernel.cpp
+19
-28
oneflow/core/kernel/innerproduct_kernel.h
oneflow/core/kernel/innerproduct_kernel.h
+1
-1
oneflow/core/kernel/innerproduct_kernel_test.cpp
oneflow/core/kernel/innerproduct_kernel_test.cpp
+17
-25
oneflow/core/kernel/kernel.cpp
oneflow/core/kernel/kernel.cpp
+5
-8
oneflow/core/kernel/kernel.h
oneflow/core/kernel/kernel.h
+20
-28
oneflow/core/kernel/kernel_context.h
oneflow/core/kernel/kernel_context.h
+2
-2
oneflow/core/kernel/kernel_manager.cpp
oneflow/core/kernel/kernel_manager.cpp
+14
-14
oneflow/core/kernel/kernel_manager.h
oneflow/core/kernel/kernel_manager.h
+19
-11
oneflow/core/kernel/kernel_util.cpp
oneflow/core/kernel/kernel_util.cpp
+48
-53
oneflow/core/kernel/kernel_util.cu
oneflow/core/kernel/kernel_util.cu
+41
-40
oneflow/core/kernel/kernel_util.h
oneflow/core/kernel/kernel_util.h
+36
-34
oneflow/core/kernel/model_save_kernel.cpp
oneflow/core/kernel/model_save_kernel.cpp
+7
-5
oneflow/core/kernel/model_save_kernel.h
oneflow/core/kernel/model_save_kernel.h
+7
-4
oneflow/core/memory/memory_allocator.cpp
oneflow/core/memory/memory_allocator.cpp
+5
-9
oneflow/core/memory/memory_allocator.h
oneflow/core/memory/memory_allocator.h
+6
-7
oneflow/core/operator/boxing_op.cpp
oneflow/core/operator/boxing_op.cpp
+2
-4
oneflow/core/operator/boxing_op.h
oneflow/core/operator/boxing_op.h
+4
-6
oneflow/core/operator/boxing_op_test.cpp
oneflow/core/operator/boxing_op_test.cpp
+14
-17
oneflow/core/operator/clear_op.h
oneflow/core/operator/clear_op.h
+2
-2
oneflow/core/operator/clone_op.cpp
oneflow/core/operator/clone_op.cpp
+1
-3
oneflow/core/operator/clone_op.h
oneflow/core/operator/clone_op.h
+5
-7
oneflow/core/operator/clone_op_test.cpp
oneflow/core/operator/clone_op_test.cpp
+1
-1
oneflow/core/operator/concat_op.cpp
oneflow/core/operator/concat_op.cpp
+2
-4
oneflow/core/operator/concat_op.h
oneflow/core/operator/concat_op.h
+3
-5
oneflow/core/operator/concat_op_test.cpp
oneflow/core/operator/concat_op_test.cpp
+5
-5
oneflow/core/operator/convolution_op.cpp
oneflow/core/operator/convolution_op.cpp
+10
-10
oneflow/core/operator/convolution_op.h
oneflow/core/operator/convolution_op.h
+3
-4
oneflow/core/operator/convolution_op_test.cpp
oneflow/core/operator/convolution_op_test.cpp
+11
-11
oneflow/core/operator/copy_comm_net_op.cpp
oneflow/core/operator/copy_comm_net_op.cpp
+1
-1
oneflow/core/operator/copy_comm_net_op.h
oneflow/core/operator/copy_comm_net_op.h
+3
-4
oneflow/core/operator/copy_hd_op.cpp
oneflow/core/operator/copy_hd_op.cpp
+1
-1
oneflow/core/operator/copy_hd_op.h
oneflow/core/operator/copy_hd_op.h
+3
-4
oneflow/core/operator/data_loader_op.cpp
oneflow/core/operator/data_loader_op.cpp
+4
-6
oneflow/core/operator/data_loader_op.h
oneflow/core/operator/data_loader_op.h
+6
-8
oneflow/core/operator/innerproduct_op.cpp
oneflow/core/operator/innerproduct_op.cpp
+1
-3
oneflow/core/operator/innerproduct_op.h
oneflow/core/operator/innerproduct_op.h
+3
-5
oneflow/core/operator/innerproduct_op_test.cpp
oneflow/core/operator/innerproduct_op_test.cpp
+14
-16
oneflow/core/operator/model_diff_accumulate_op.cpp
oneflow/core/operator/model_diff_accumulate_op.cpp
+1
-1
oneflow/core/operator/model_diff_accumulate_op.h
oneflow/core/operator/model_diff_accumulate_op.h
+2
-3
oneflow/core/operator/model_save_op.cpp
oneflow/core/operator/model_save_op.cpp
+1
-1
oneflow/core/operator/model_save_op.h
oneflow/core/operator/model_save_op.h
+2
-2
oneflow/core/operator/model_update_op.cpp
oneflow/core/operator/model_update_op.cpp
+2
-2
oneflow/core/operator/model_update_op.h
oneflow/core/operator/model_update_op.h
+2
-3
oneflow/core/operator/multinomial_logistic_loss_op.cpp
oneflow/core/operator/multinomial_logistic_loss_op.cpp
+3
-5
oneflow/core/operator/multinomial_logistic_loss_op.h
oneflow/core/operator/multinomial_logistic_loss_op.h
+4
-6
oneflow/core/operator/multinomial_logistic_loss_op_test.cpp
oneflow/core/operator/multinomial_logistic_loss_op_test.cpp
+15
-16
oneflow/core/operator/operator.cpp
oneflow/core/operator/operator.cpp
+5
-8
oneflow/core/operator/operator.h
oneflow/core/operator/operator.h
+31
-38
oneflow/core/operator/operator_manager.cpp
oneflow/core/operator/operator_manager.cpp
+3
-4
oneflow/core/operator/operator_manager.h
oneflow/core/operator/operator_manager.h
+5
-6
oneflow/core/operator/pooling_op.cpp
oneflow/core/operator/pooling_op.cpp
+4
-7
oneflow/core/operator/pooling_op.h
oneflow/core/operator/pooling_op.h
+4
-5
oneflow/core/operator/relu_op.cpp
oneflow/core/operator/relu_op.cpp
+1
-3
oneflow/core/operator/relu_op.h
oneflow/core/operator/relu_op.h
+2
-3
oneflow/core/operator/relu_op_test.cpp
oneflow/core/operator/relu_op_test.cpp
+2
-2
oneflow/core/operator/softmax_op.cpp
oneflow/core/operator/softmax_op.cpp
+2
-4
oneflow/core/operator/softmax_op.h
oneflow/core/operator/softmax_op.h
+4
-6
oneflow/core/operator/softmax_op_test.cpp
oneflow/core/operator/softmax_op_test.cpp
+6
-6
oneflow/core/persistence/persistent_circular_line_reader.cpp
oneflow/core/persistence/persistent_circular_line_reader.cpp
+3
-3
oneflow/core/persistence/persistent_circular_line_reader.h
oneflow/core/persistence/persistent_circular_line_reader.h
+5
-5
oneflow/core/persistence/persistent_in_stream.cpp
oneflow/core/persistence/persistent_in_stream.cpp
+4
-5
oneflow/core/persistence/persistent_in_stream.h
oneflow/core/persistence/persistent_in_stream.h
+4
-8
oneflow/core/persistence/persistent_out_stream.h
oneflow/core/persistence/persistent_out_stream.h
+8
-11
oneflow/core/persistence/snapshot.cpp
oneflow/core/persistence/snapshot.cpp
+12
-14
oneflow/core/persistence/snapshot.h
oneflow/core/persistence/snapshot.h
+14
-14
oneflow/core/persistence/snapshot_manager.cpp
oneflow/core/persistence/snapshot_manager.cpp
+5
-3
oneflow/core/persistence/snapshot_manager.h
oneflow/core/persistence/snapshot_manager.h
+3
-5
oneflow/core/persistence/snapshot_test.cpp
oneflow/core/persistence/snapshot_test.cpp
+11
-9
oneflow/core/register/blob.h
oneflow/core/register/blob.h
+3
-3
oneflow/core/register/local_register_warpper.h
oneflow/core/register/local_register_warpper.h
+6
-13
oneflow/core/register/register.cpp
oneflow/core/register/register.cpp
+2
-4
oneflow/core/register/register.h
oneflow/core/register/register.h
+5
-9
oneflow/core/register/register_desc.cpp
oneflow/core/register/register_desc.cpp
+23
-23
oneflow/core/register/register_desc.h
oneflow/core/register/register_desc.h
+5
-6
oneflow/core/register/register_manager.cpp
oneflow/core/register/register_manager.cpp
+10
-7
oneflow/core/register/register_manager.h
oneflow/core/register/register_manager.h
+5
-6
oneflow/core/register/register_warpper.h
oneflow/core/register/register_warpper.h
+2
-3
oneflow/core/register/remote_register_warpper.h
oneflow/core/register/remote_register_warpper.h
+9
-20
oneflow/core/register/runtime_register_desc.cpp
oneflow/core/register/runtime_register_desc.cpp
+5
-4
oneflow/core/register/runtime_register_desc.h
oneflow/core/register/runtime_register_desc.h
+3
-4
oneflow/core/thread/cpu_thread.cpp
oneflow/core/thread/cpu_thread.cpp
+2
-4
oneflow/core/thread/cpu_thread.h
oneflow/core/thread/cpu_thread.h
+2
-3
oneflow/core/thread/gpu_thread.cpp
oneflow/core/thread/gpu_thread.cpp
+1
-1
oneflow/core/thread/gpu_thread.h
oneflow/core/thread/gpu_thread.h
+3
-3
oneflow/core/thread/thread.cpp
oneflow/core/thread/thread.cpp
+1
-3
oneflow/core/thread/thread.h
oneflow/core/thread/thread.h
+4
-5
oneflow/core/thread/thread_context.h
oneflow/core/thread/thread_context.h
+4
-5
oneflow/core/thread/thread_manager.cpp
oneflow/core/thread/thread_manager.cpp
+5
-7
oneflow/core/thread/thread_manager.h
oneflow/core/thread/thread_manager.h
+3
-3
未找到文件。
oneflow/core/actor/actor.cpp
浏览文件 @
03db51d7
...
...
@@ -28,8 +28,8 @@ void Actor::Init(const TaskProto& task_proto, const ThreadCtx& thread_ctx) {
}
// name2regst_desc_id_
for
(
const
auto
&
pair
:
task_proto
.
produced_regst_desc
())
{
CHECK
(
name2regst_desc_id_
.
emplace
(
pair
.
first
,
pair
.
second
.
regst_desc_id
())
.
second
);
CHECK
(
name2regst_desc_id_
.
emplace
(
pair
.
first
,
pair
.
second
.
regst_desc_id
())
.
second
);
}
for
(
const
auto
&
pair
:
task_proto
.
subscribed_regst_desc_id
())
{
CHECK
(
name2regst_desc_id_
.
emplace
(
pair
.
first
,
pair
.
second
).
second
);
...
...
@@ -48,9 +48,7 @@ void Actor::Init(const TaskProto& task_proto, const ThreadCtx& thread_ctx) {
int64_t
Actor
::
RegstDescId4Name
(
const
std
::
string
&
name
)
const
{
auto
find_it
=
name2regst_desc_id_
.
find
(
name
);
if
(
find_it
!=
name2regst_desc_id_
.
end
())
{
return
find_it
->
second
;
}
if
(
find_it
!=
name2regst_desc_id_
.
end
())
{
return
find_it
->
second
;
}
return
-
1
;
}
...
...
@@ -61,7 +59,8 @@ KernelCtx Actor::GenDefaultKernelCtx() const {
}
int
Actor
::
HandleWaitUntilReadingCntEqualZero
(
const
ActorMsg
&
msg
)
{
CHECK_EQ
(
TryUpdtStateAsProducedRegst
(
msg
.
regst_warpper
()
->
regst_raw_ptr
()),
0
);
CHECK_EQ
(
TryUpdtStateAsProducedRegst
(
msg
.
regst_warpper
()
->
regst_raw_ptr
()),
0
);
if
(
total_reading_cnt_
==
0
)
{
msg_handle_
=
nullptr
;
return
1
;
...
...
@@ -92,7 +91,8 @@ void Actor::AsyncSendReadableRegstMsg() {
ActorMsgBus
::
Singleton
().
SendMsg
(
std
::
move
(
msg
));
}
});
produced_regst2reading_cnt_
.
at
(
regst
)
=
regst
->
subscribers_actor_id
().
size
();
produced_regst2reading_cnt_
.
at
(
regst
)
=
regst
->
subscribers_actor_id
().
size
();
total_reading_cnt_
+=
regst
->
subscribers_actor_id
().
size
();
if
(
!
regst
->
subscribers_actor_id
().
empty
())
{
pair
.
second
.
pop
();
}
if
(
pair
.
second
.
empty
())
{
writeable_produced_regst_desc_num_
-=
1
;
}
...
...
@@ -121,13 +121,11 @@ void Actor::AsyncDo(std::function<void()> func) {
device_ctx_
->
AddCallBack
(
func
);
}
void
Actor
::
AsyncSendRegstMsgToProducer
(
const
std
::
shared_ptr
<
RegstWarpper
>&
wp
)
{
ActorMsg
msg
=
ActorMsg
::
BuildRegstMsgToProducer
(
wp
->
producer_actor_id
(),
wp
->
regst_raw_ptr
());
AsyncDo
([
msg
]()
{
ActorMsgBus
::
Singleton
().
SendMsg
(
msg
);
});
void
Actor
::
AsyncSendRegstMsgToProducer
(
const
std
::
shared_ptr
<
RegstWarpper
>&
wp
)
{
ActorMsg
msg
=
ActorMsg
::
BuildRegstMsgToProducer
(
wp
->
producer_actor_id
(),
wp
->
regst_raw_ptr
());
AsyncDo
([
msg
]()
{
ActorMsgBus
::
Singleton
().
SendMsg
(
msg
);
});
}
int
Actor
::
TryUpdtStateAsProducedRegst
(
Regst
*
regst
)
{
...
...
@@ -139,9 +137,7 @@ int Actor::TryUpdtStateAsProducedRegst(Regst* regst) {
if
(
reading_cnt_it
->
second
!=
0
)
{
return
0
;
}
auto
writeable_it
=
writeable_produced_regst_
.
find
(
regst
->
regst_desc_id
());
if
(
writeable_it
==
writeable_produced_regst_
.
end
())
{
return
0
;
}
if
(
writeable_it
->
second
.
empty
())
{
writeable_produced_regst_desc_num_
+=
1
;
}
if
(
writeable_it
->
second
.
empty
())
{
writeable_produced_regst_desc_num_
+=
1
;
}
writeable_it
->
second
.
push
(
regst
);
return
0
;
}
...
...
@@ -166,4 +162,4 @@ bool Actor::IsWriteReady() {
return
writeable_produced_regst_desc_num_
==
writeable_produced_regst_
.
size
();
}
}
// namespace oneflow
}
// namespace oneflow
oneflow/core/actor/actor.h
浏览文件 @
03db51d7
#ifndef ONEFLOW_CORE_ACTOR_ACTOR_H_
#define ONEFLOW_CORE_ACTOR_ACTOR_H_
#include "oneflow/core/common/cuda_stream_handle.h"
#include "oneflow/core/kernel/kernel_manager.h"
#include "oneflow/core/kernel/kernel_context.h"
#include "oneflow/core/actor/actor_message_bus.h"
#include "oneflow/core/actor/cpu_device_context.h"
#include "oneflow/core/actor/cuda_device_context.h"
#include "oneflow/core/common/cuda_stream_handle.h"
#include "oneflow/core/job/task.pb.h"
#include "oneflow/core/actor/actor_message_bus.h"
#include "oneflow/core/kernel/kernel_context.h"
#include "oneflow/core/kernel/kernel_manager.h"
#include "oneflow/core/persistence/snapshot_manager.h"
#include "oneflow/core/register/local_register_warpper.h"
#include "oneflow/core/register/remote_register_warpper.h"
#include "oneflow/core/register/register_manager.h"
#include "oneflow/core/register/remote_register_warpper.h"
#include "oneflow/core/thread/thread_context.h"
#include "oneflow/core/persistence/snapshot_manager.h"
namespace
oneflow
{
...
...
@@ -24,12 +24,10 @@ class Actor {
virtual
void
Init
(
const
TaskProto
&
,
const
ThreadCtx
&
)
=
0
;
// 1: success, and actor finish
// 0: success, and actor not finish
int
ProcessMsg
(
const
ActorMsg
&
msg
)
{
return
(
this
->*
msg_handle_
)(
msg
);
}
int
ProcessMsg
(
const
ActorMsg
&
msg
)
{
return
(
this
->*
msg_handle_
)(
msg
);
}
int64_t
actor_id
()
const
{
return
actor_id_
;
}
protected:
struct
ExecKernel
{
const
Kernel
*
kernel
;
...
...
@@ -45,11 +43,11 @@ class Actor {
// Msg Handle
using
MsgHandle
=
int
(
Actor
::*
)(
const
ActorMsg
&
);
void
set_msg_handle
(
MsgHandle
val
)
{
msg_handle_
=
val
;
}
#define OF_SET_MSG_HANDLE(val)
\
do {
\
LOG(INFO) << "Actor " << actor_id() << " switch to " << #val; \
set_msg_handle(static_cast<MsgHandle>(val));
\
} while
(0)
#define OF_SET_MSG_HANDLE(val)
\
do {
\
LOG(INFO) << "Actor " << actor_id() << " switch to " << #val; \
set_msg_handle(static_cast<MsgHandle>(val));
\
} while
(0)
// Common Handles
int
HandleWaitUntilReadingCntEqualZero
(
const
ActorMsg
&
msg
);
...
...
@@ -80,22 +78,23 @@ class Actor {
int64_t
actor_id_
;
KernelWardFunc
ward_func_
;
std
::
vector
<
ExecKernel
>
exec_kernel_vec_
;
HashMap
<
int64_t
,
std
::
vector
<
std
::
unique_ptr
<
Regst
>>>
produced_regsts_
;
// <regst_desc_id, regst>
HashMap
<
int64_t
,
std
::
vector
<
std
::
unique_ptr
<
Regst
>>>
produced_regsts_
;
// <regst_desc_id, regst>
HashMap
<
std
::
string
,
int64_t
>
name2regst_desc_id_
;
std
::
unique_ptr
<
DeviceCtx
>
device_ctx_
;
MsgHandle
msg_handle_
;
// Status of Produced Registers
int64_t
expected_piece_id_
;
HashMap
<
int64_t
,
std
::
queue
<
Regst
*>>
writeable_produced_regst_
;
// <regst_desc_id, regst>
HashMap
<
int64_t
,
std
::
queue
<
Regst
*>>
writeable_produced_regst_
;
// <regst_desc_id, regst>
int64_t
writeable_produced_regst_desc_num_
;
HashMap
<
Regst
*
,
int64_t
>
produced_regst2reading_cnt_
;
int64_t
total_reading_cnt_
;
};
}
// namespace oneflow
}
// namespace oneflow
#endif // ONEFLOW_CORE_ACTOR_ACTOR_H_
#endif
// ONEFLOW_CORE_ACTOR_ACTOR_H_
oneflow/core/actor/actor_message.cpp
浏览文件 @
03db51d7
#include "oneflow/core/actor/actor_message.h"
#include "oneflow/core/register/remote_register_warpper.h"
#include "oneflow/core/register/local_register_warpper.h"
#include "oneflow/core/job/id_manager.h"
#include "oneflow/core/job/runtime_context.h"
#include "oneflow/core/register/local_register_warpper.h"
#include "oneflow/core/register/remote_register_warpper.h"
namespace
oneflow
{
OF_DEFINE_ENUM_TO_OSTREAM_FUNC
(
ActorCmd
);
OF_DEFINE_ENUM_TO_OSTREAM_FUNC
(
ActorMsgType
);
ActorMsg
::
ActorMsg
()
{
dst_actor_id_
=
-
1
;
}
ActorMsg
::
ActorMsg
()
{
dst_actor_id_
=
-
1
;
}
ActorMsg
ActorMsg
::
BuildReadableRegstMsg
(
int64_t
reader_actor_id
,
Regst
*
regst_raw_ptr
)
{
...
...
@@ -36,4 +34,4 @@ ActorMsg ActorMsg::BuildRegstMsgToProducer(int64_t writer_actor_id,
return
msg
;
}
}
// namespace oneflow
}
// namespace oneflow
oneflow/core/actor/actor_message.h
浏览文件 @
03db51d7
...
...
@@ -7,18 +7,15 @@
namespace
oneflow
{
enum
class
ActorCmd
{
kInitializeModel
=
0
,
// MdUpdt Actor
kSendInitialModel
,
// MdUpdt Actor
kEORD
,
// End Of Register Desc, All Actor except Source Actor
kStart
// Source Actor
kInitializeModel
=
0
,
// MdUpdt Actor
kSendInitialModel
,
// MdUpdt Actor
kEORD
,
// End Of Register Desc, All Actor except Source Actor
kStart
// Source Actor
};
OF_DECLARE_ENUM_TO_OSTREAM_FUNC
(
ActorCmd
);
enum
class
ActorMsgType
{
kRegstMsg
=
0
,
kCmdMsg
};
enum
class
ActorMsgType
{
kRegstMsg
=
0
,
kCmdMsg
};
OF_DECLARE_ENUM_TO_OSTREAM_FUNC
(
ActorMsgType
);
...
...
@@ -43,9 +40,7 @@ class ActorMsg final {
return
actor_cmd_
;
}
// Setters
void
set_dst_actor_id
(
int64_t
val
)
{
dst_actor_id_
=
val
;
}
void
set_dst_actor_id
(
int64_t
val
)
{
dst_actor_id_
=
val
;
}
void
set_regst_warpper
(
std
::
shared_ptr
<
RegstWarpper
>
val
)
{
msg_type_
=
ActorMsgType
::
kRegstMsg
;
regst_warpper_
=
val
;
...
...
@@ -54,17 +49,15 @@ class ActorMsg final {
msg_type_
=
ActorMsgType
::
kCmdMsg
;
actor_cmd_
=
val
;
}
private:
private:
int64_t
dst_actor_id_
;
ActorMsgType
msg_type_
;
std
::
shared_ptr
<
RegstWarpper
>
regst_warpper_
;
ActorCmd
actor_cmd_
;
};
}
// namespace oneflow
}
// namespace oneflow
#endif // ONEFLOW_CORE_ACTOR_ACTOR_MESSAGE_H_
#endif
// ONEFLOW_CORE_ACTOR_ACTOR_MESSAGE_H_
oneflow/core/actor/actor_message_bus.cpp
浏览文件 @
03db51d7
...
...
@@ -7,10 +7,10 @@ namespace oneflow {
void
ActorMsgBus
::
SendMsg
(
const
ActorMsg
&
msg
)
{
int64_t
dst_machine_id
=
IDMgr
::
Singleton
().
MachineId4ActorId
(
msg
.
dst_actor_id
());
IDMgr
::
Singleton
().
MachineId4ActorId
(
msg
.
dst_actor_id
());
if
(
dst_machine_id
==
RuntimeCtx
::
Singleton
().
this_machine_id
())
{
int64_t
thrd_loc_id
=
IDMgr
::
Singleton
().
ThrdLocId4ActorId
(
msg
.
dst_actor_id
());
IDMgr
::
Singleton
().
ThrdLocId4ActorId
(
msg
.
dst_actor_id
());
ThreadMgr
::
Singleton
().
GetThrd
(
thrd_loc_id
)
->
GetMsgChannelPtr
()
->
Send
(
msg
);
}
else
{
TODO
();
...
...
oneflow/core/actor/actor_message_bus.h
浏览文件 @
03db51d7
...
...
@@ -2,8 +2,8 @@
#define ONEFLOW_CORE_ACTOR_ACTOR_MESSAGE_BUS_H_
#include <stdint.h>
#include "oneflow/core/common/util.h"
#include "oneflow/core/actor/actor_message.h"
#include "oneflow/core/common/util.h"
namespace
oneflow
{
...
...
@@ -22,4 +22,4 @@ class ActorMsgBus final {
}
// namespace oneflow
#endif // ONEFLOW_CORE_ACTOR_ACTOR_MESSAGE_BUS_H_
#endif
// ONEFLOW_CORE_ACTOR_ACTOR_MESSAGE_BUS_H_
oneflow/core/actor/actor_registry.cpp
浏览文件 @
03db51d7
...
...
@@ -5,20 +5,21 @@ namespace oneflow {
namespace
{
struct
PairHash
{
std
::
size_t
operator
()
(
const
std
::
pair
<
int
,
bool
>
&
p
)
const
{
std
::
size_t
operator
()(
const
std
::
pair
<
int
,
bool
>&
p
)
const
{
return
std
::
hash
<
int
>
{}((
p
.
first
<<
1
)
|
(
static_cast
<
int
>
(
p
.
second
)));
}
};
using
ActorTypePair
=
std
::
pair
<
TaskType
,
bool
>
;
using
ActorCreatorMap
=
HashMap
<
ActorTypePair
,
std
::
function
<
Actor
*
()
>
,
PairHash
>
;
using
ActorCreatorMap
=
HashMap
<
ActorTypePair
,
std
::
function
<
Actor
*
()
>
,
PairHash
>
;
ActorCreatorMap
&
ActorType2Creator
()
{
static
ActorCreatorMap
obj
;
return
obj
;
}
}
}
// namespace
void
AddActorCreator
(
TaskType
task_type
,
bool
is_forward
,
std
::
function
<
Actor
*
()
>
creator
)
{
...
...
oneflow/core/actor/actor_registry.h
浏览文件 @
03db51d7
...
...
@@ -20,8 +20,9 @@ struct ActorRegister {
};
#define REGISTER_ACTOR(TaskType, IsForward, ActorType) \
static ActorRegister<TaskType, IsForward, ActorType> g_##ActorType##_##IsForward##_register_var;
static ActorRegister<TaskType, IsForward, ActorType> \
g_##ActorType##_##IsForward##_register_var;
}
// namespace oneflow
#endif // ONEFLOW_CORE_ACTOR_ACTOR_REGISTRY_H_
#endif
// ONEFLOW_CORE_ACTOR_ACTOR_REGISTRY_H_
oneflow/core/actor/boxing_actor.cpp
浏览文件 @
03db51d7
...
...
@@ -4,7 +4,8 @@
namespace
oneflow
{
void
BoxingActor
::
Init
(
const
TaskProto
&
task_proto
,
const
ThreadCtx
&
thread_ctx
)
{
void
BoxingActor
::
Init
(
const
TaskProto
&
task_proto
,
const
ThreadCtx
&
thread_ctx
)
{
Actor
::
Init
(
task_proto
,
thread_ctx
);
num_of_subscribed_regsts_
=
task_proto
.
subscribed_regst_desc_id
().
size
();
num_of_read_empty_
=
num_of_subscribed_regsts_
;
...
...
@@ -22,7 +23,8 @@ int BoxingActor::HandleBoxing(const ActorMsg& msg) {
OF_SET_MSG_HANDLE
(
&
BoxingActor
::
HandleBoxingWhenNoReadableRegstMsg
);
}
}
else
if
(
msg
.
msg_type
()
==
ActorMsgType
::
kRegstMsg
)
{
if
(
TryUpdtStateAsProducedRegst
(
msg
.
regst_warpper
()
->
regst_raw_ptr
())
!=
0
)
{
if
(
TryUpdtStateAsProducedRegst
(
msg
.
regst_warpper
()
->
regst_raw_ptr
())
!=
0
)
{
std
::
shared_ptr
<
RegstWarpper
>
regst_wp
=
msg
.
regst_warpper
();
num_of_read_empty_
-=
read_regst_
[
regst_wp
->
regst_desc_id
()].
empty
();
read_regst_
.
at
(
regst_wp
->
regst_desc_id
()).
push
(
regst_wp
);
...
...
@@ -35,7 +37,8 @@ int BoxingActor::HandleBoxing(const ActorMsg& msg) {
}
int
BoxingActor
::
HandleBoxingWhenNoReadableRegstMsg
(
const
ActorMsg
&
msg
)
{
CHECK_EQ
(
TryUpdtStateAsProducedRegst
(
msg
.
regst_warpper
()
->
regst_raw_ptr
()),
0
);
CHECK_EQ
(
TryUpdtStateAsProducedRegst
(
msg
.
regst_warpper
()
->
regst_raw_ptr
()),
0
);
TryWardKernelAndSendMsg
();
if
(
num_of_read_empty_
==
num_of_subscribed_regsts_
)
{
AsyncSendEORDMsgForAllProducedRegstDesc
();
...
...
@@ -49,25 +52,25 @@ int BoxingActor::HandleBoxingWhenNoReadableRegstMsg(const ActorMsg& msg) {
}
return
0
;
}
void
BoxingActor
::
TryWardKernelAndSendMsg
()
{
if
(
!
num_of_read_empty_
&&
IsWriteReady
())
{
int64_t
piece_id
=
expected_piece_id
();
for
(
const
auto
&
pair
:
read_regst_
)
{
CHECK_EQ
(
pair
.
second
.
front
()
->
piece_id
(),
piece_id
);
}
AsyncWardKernel
(
GenDefaultKernelCtx
(),
AsyncWardKernel
(
GenDefaultKernelCtx
(),
[
this
](
int64_t
regst_desc_id
)
->
std
::
shared_ptr
<
RegstWarpper
>
{
Regst
*
regst
=
GetCurWriteableRegst
(
regst_desc_id
);
if
(
regst
==
nullptr
)
{
return
read_regst_
.
at
(
regst_desc_id
).
front
();
}
else
{
return
std
::
make_shared
<
LocalRegstWarpper
>
(
regst
);
}
});
ForEachCurWriteableRegst
([
piece_id
](
Regst
*
regst
)
{
regst
->
set_piece_id
(
piece_id
);
});
Regst
*
regst
=
GetCurWriteableRegst
(
regst_desc_id
);
if
(
regst
==
nullptr
)
{
return
read_regst_
.
at
(
regst_desc_id
).
front
();
}
else
{
return
std
::
make_shared
<
LocalRegstWarpper
>
(
regst
);
}
});
ForEachCurWriteableRegst
(
[
piece_id
](
Regst
*
regst
)
{
regst
->
set_piece_id
(
piece_id
);
});
AsyncSendReadableRegstMsg
();
for
(
auto
&
pair
:
read_regst_
)
{
AsyncSendRegstMsgToProducer
(
pair
.
second
.
front
());
...
...
oneflow/core/actor/boxing_actor.h
浏览文件 @
03db51d7
...
...
@@ -24,9 +24,8 @@ class BoxingActor final : public Actor {
int
num_of_eord_
;
// <regst_desc_id, queue<regst_wp>>
HashMap
<
int64_t
,
std
::
queue
<
std
::
shared_ptr
<
RegstWarpper
>>>
read_regst_
;
};
}
// namespace oneflow
#endif // ONEFLOW_CORE_ACTOR_BOXING_ACTOR_H_
#endif
// ONEFLOW_CORE_ACTOR_BOXING_ACTOR_H_
oneflow/core/actor/bp_data_comp_actor.cpp
浏览文件 @
03db51d7
...
...
@@ -4,7 +4,8 @@
namespace
oneflow
{
void
BpDataCompActor
::
Init
(
const
TaskProto
&
task_proto
,
const
ThreadCtx
&
thread_ctx
)
{
void
BpDataCompActor
::
Init
(
const
TaskProto
&
task_proto
,
const
ThreadCtx
&
thread_ctx
)
{
Actor
::
Init
(
task_proto
,
thread_ctx
);
model_regst_desc_id_
=
RegstDescId4Name
(
"model"
);
model_tmp_regst_desc_id_
=
RegstDescId4Name
(
"model_tmp"
);
...
...
@@ -24,11 +25,11 @@ void BpDataCompActor::Init(const TaskProto& task_proto, const ThreadCtx& thread_
}
bool
BpDataCompActor
::
IsReadReady
()
{
if
(
num_of_read_empty_
)
{
return
false
;
}
if
(
read_regst_
.
at
(
model_regst_desc_id_
).
front
()
->
model_version_id
()
!=
read_regst_
.
at
(
activation_regst_desc_id_
).
front
()
->
model_version_id
())
{
if
(
num_of_read_empty_
)
{
return
false
;
}
if
(
read_regst_
.
at
(
model_regst_desc_id_
).
front
()
->
model_version_id
()
!=
read_regst_
.
at
(
activation_regst_desc_id_
)
.
front
()
->
model_version_id
())
{
AsyncSendRegstMsgToProducer
(
read_regst_
.
at
(
model_regst_desc_id_
).
front
());
read_regst_
.
at
(
model_regst_desc_id_
).
pop
();
num_of_read_empty_
+=
read_regst_
.
at
(
model_regst_desc_id_
).
empty
();
...
...
@@ -44,7 +45,8 @@ int BpDataCompActor::HandleBpComp(const ActorMsg& msg) {
OF_SET_MSG_HANDLE
(
&
BpDataCompActor
::
HandleBpCompWhenNoReadableRegstMsg
);
}
}
else
if
(
msg
.
msg_type
()
==
ActorMsgType
::
kRegstMsg
)
{
if
(
TryUpdtStateAsProducedRegst
(
msg
.
regst_warpper
()
->
regst_raw_ptr
())
!=
0
)
{
if
(
TryUpdtStateAsProducedRegst
(
msg
.
regst_warpper
()
->
regst_raw_ptr
())
!=
0
)
{
std
::
shared_ptr
<
RegstWarpper
>
regst_wp
=
msg
.
regst_warpper
();
if
(
regst_wp
->
regst_desc_id
()
==
model_tmp_regst_desc_id_
)
{
CHECK
(
read_regst_
.
find
(
model_tmp_regst_desc_id_
)
==
read_regst_
.
end
());
...
...
@@ -62,14 +64,16 @@ int BpDataCompActor::HandleBpComp(const ActorMsg& msg) {
}
int
BpDataCompActor
::
HandleBpCompWhenNoReadableRegstMsg
(
const
ActorMsg
&
msg
)
{
CHECK_EQ
(
TryUpdtStateAsProducedRegst
(
msg
.
regst_warpper
()
->
regst_raw_ptr
()),
0
);
CHECK_EQ
(
TryUpdtStateAsProducedRegst
(
msg
.
regst_warpper
()
->
regst_raw_ptr
()),
0
);
TryWardKernelAndSendMsg
();
if
(
read_regst_
.
at
(
activation_regst_desc_id_
).
empty
())
{
while
(
!
read_regst_
.
at
(
model_regst_desc_id_
).
empty
())
{
AsyncSendRegstMsgToProducer
(
read_regst_
.
at
(
model_regst_desc_id_
).
front
());
read_regst_
.
at
(
model_regst_desc_id_
).
pop
();
}
AsyncSendRegstMsgToProducer
(
read_regst_
.
at
(
model_tmp_regst_desc_id_
).
front
());
AsyncSendRegstMsgToProducer
(
read_regst_
.
at
(
model_tmp_regst_desc_id_
).
front
());
read_regst_
.
at
(
model_tmp_regst_desc_id_
).
pop
();
AsyncSendEORDMsgForAllProducedRegstDesc
();
num_of_read_empty_
=
6
;
...
...
@@ -83,33 +87,40 @@ int BpDataCompActor::HandleBpCompWhenNoReadableRegstMsg(const ActorMsg& msg) {
}
return
0
;
}
void
BpDataCompActor
::
TryWardKernelAndSendMsg
()
{
while
(
IsReadReady
()
&&
IsWriteReady
())
{
int64_t
cur_model
=
read_regst_
.
at
(
model_regst_desc_id_
).
front
()
->
model_version_id
();
int64_t
cur_model
=
read_regst_
.
at
(
model_regst_desc_id_
).
front
()
->
model_version_id
();
int64_t
piece_id
=
expected_piece_id
();
CHECK_EQ
(
cur_model
,
read_regst_
.
at
(
activation_regst_desc_id_
).
front
()
->
model_version_id
());
CHECK_EQ
(
cur_model
,
read_regst_
.
at
(
data_tmp_regst_desc_id_
).
front
()
->
model_version_id
());
CHECK_EQ
(
cur_model
,
read_regst_
.
at
(
activation_regst_desc_id_
).
front
()
->
model_version_id
());
CHECK_EQ
(
cur_model
,
read_regst_
.
at
(
data_tmp_regst_desc_id_
).
front
()
->
model_version_id
());
for
(
const
auto
&
pair
:
read_regst_
)
{
if
(
pair
.
first
!=
model_regst_desc_id_
&&
pair
.
first
!=
model_tmp_regst_desc_id_
)
{
if
(
pair
.
first
!=
model_regst_desc_id_
&&
pair
.
first
!=
model_tmp_regst_desc_id_
)
{
CHECK_EQ
(
pair
.
second
.
front
()
->
piece_id
(),
piece_id
);
}
}
AsyncWardKernel
(
GenDefaultKernelCtx
(),
AsyncWardKernel
(
GenDefaultKernelCtx
(),
[
this
](
int64_t
regst_desc_id
)
->
std
::
shared_ptr
<
RegstWarpper
>
{
Regst
*
regst
=
GetCurWriteableRegst
(
regst_desc_id
);
if
(
regst
==
nullptr
)
{
return
read_regst_
.
at
(
regst_desc_id
).
front
();
}
else
{
return
std
::
make_shared
<
LocalRegstWarpper
>
(
regst
);
}
});
ForEachCurWriteableRegst
([
piece_id
](
Regst
*
regst
)
{
regst
->
set_piece_id
(
piece_id
);
});
Regst
*
regst
=
GetCurWriteableRegst
(
regst_desc_id
);
if
(
regst
==
nullptr
)
{
return
read_regst_
.
at
(
regst_desc_id
).
front
();
}
else
{
return
std
::
make_shared
<
LocalRegstWarpper
>
(
regst
);
}
});
ForEachCurWriteableRegst
(
[
piece_id
](
Regst
*
regst
)
{
regst
->
set_piece_id
(
piece_id
);
});
AsyncSendReadableRegstMsg
();
for
(
auto
&
pair
:
read_regst_
)
{
if
(
pair
.
first
!=
model_regst_desc_id_
&&
pair
.
first
!=
model_tmp_regst_desc_id_
)
{
if
(
pair
.
first
!=
model_regst_desc_id_
&&
pair
.
first
!=
model_tmp_regst_desc_id_
)
{
AsyncSendRegstMsgToProducer
(
pair
.
second
.
front
());
pair
.
second
.
pop
();
num_of_read_empty_
+=
pair
.
second
.
empty
();
...
...
oneflow/core/actor/bp_data_comp_actor.h
浏览文件 @
03db51d7
...
...
@@ -6,14 +6,14 @@
namespace
oneflow
{
class
BpDataCompActor
final
:
public
Actor
{
public:
public:
OF_DISALLOW_COPY_AND_MOVE
(
BpDataCompActor
);
BpDataCompActor
()
=
default
;
~
BpDataCompActor
()
=
default
;
void
Init
(
const
TaskProto
&
,
const
ThreadCtx
&
)
override
;
private:
private:
int
HandleBpComp
(
const
ActorMsg
&
);
int
HandleBpCompWhenNoReadableRegstMsg
(
const
ActorMsg
&
);
...
...
@@ -34,4 +34,4 @@ private:
}
// namespace oneflow
#endif // ONEFLOW_CORE_ACTOR_BP_DATA_COMP_ACTOR_H_
#endif
// ONEFLOW_CORE_ACTOR_BP_DATA_COMP_ACTOR_H_
oneflow/core/actor/compute_actor.h
浏览文件 @
03db51d7
...
...
@@ -13,7 +13,8 @@ class CompActor : public Actor {
protected:
CompActor
()
=
default
;
virtual
void
Init
(
const
TaskProto
&
task_proto
,
const
ThreadCtx
&
thread_ctx
)
override
{
virtual
void
Init
(
const
TaskProto
&
task_proto
,
const
ThreadCtx
&
thread_ctx
)
override
{
Actor
::
Init
(
task_proto
,
thread_ctx
);
parallel_id_
=
task_proto
.
parallel_id
();
}
...
...
@@ -26,9 +27,8 @@ class CompActor : public Actor {
ParallelPolicy
parallel_policy_
;
int64_t
parallel_id_
;
int64_t
parallel_num_
;
};
}
// namespace oneflow
}
// namespace oneflow
#endif // ONEFLOW_CORE_ACTOR_COMPUTE_ACTOR_H_
#endif
// ONEFLOW_CORE_ACTOR_COMPUTE_ACTOR_H_
oneflow/core/actor/copy_comm_net_actor.cpp
浏览文件 @
03db51d7
...
...
@@ -4,7 +4,8 @@
namespace
oneflow
{
void
CopyCommNetActor
::
Init
(
const
TaskProto
&
task_proto
,
const
ThreadCtx
&
thread_ctx
)
{
void
CopyCommNetActor
::
Init
(
const
TaskProto
&
task_proto
,
const
ThreadCtx
&
thread_ctx
)
{
Actor
::
Init
(
task_proto
,
thread_ctx
);
CHECK
(
thread_ctx
.
cpu_stream
);
mut_device_ctx
().
reset
(
new
CpuDeviceCtx
(
thread_ctx
.
cpu_stream
));
...
...
@@ -14,19 +15,23 @@ void CopyCommNetActor::Init(const TaskProto& task_proto, const ThreadCtx& thread
int
CopyCommNetActor
::
HandleCopyCommNet
(
const
ActorMsg
&
msg
)
{
if
(
msg
.
msg_type
()
==
ActorMsgType
::
kCmdMsg
)
{
CHECK_EQ
(
msg
.
actor_cmd
(),
ActorCmd
::
kEORD
);
OF_SET_MSG_HANDLE
(
&
CopyCommNetActor
::
HandleCopyCommNetWhenNoReadableRegstMsg
);
OF_SET_MSG_HANDLE
(
&
CopyCommNetActor
::
HandleCopyCommNetWhenNoReadableRegstMsg
);
}
else
if
(
msg
.
msg_type
()
==
ActorMsgType
::
kRegstMsg
)
{
auto
regst_wp
=
msg
.
regst_warpper
();
if
(
TryUpdtStateAsProducedRegst
(
regst_wp
->
regst_raw_ptr
())
!=
0
)
{
CHECK
(
piece_id2waiting_in_regst_
.
emplace
(
regst_wp
->
piece_id
(),
regst_wp
).
second
);
CHECK
(
piece_id2waiting_in_regst_
.
emplace
(
regst_wp
->
piece_id
(),
regst_wp
)
.
second
);
}
}
TryWardKernelAndSendMsg
();
return
0
;
}
int
CopyCommNetActor
::
HandleCopyCommNetWhenNoReadableRegstMsg
(
const
ActorMsg
&
msg
)
{
CHECK_EQ
(
TryUpdtStateAsProducedRegst
(
msg
.
regst_warpper
()
->
regst_raw_ptr
()),
0
);
int
CopyCommNetActor
::
HandleCopyCommNetWhenNoReadableRegstMsg
(
const
ActorMsg
&
msg
)
{
CHECK_EQ
(
TryUpdtStateAsProducedRegst
(
msg
.
regst_warpper
()
->
regst_raw_ptr
()),
0
);
TryWardKernelAndSendMsg
();
if
(
piece_id2waiting_in_regst_
.
empty
())
{
AsyncSendEORDMsgForAllProducedRegstDesc
();
...
...
@@ -40,23 +45,23 @@ int CopyCommNetActor::HandleCopyCommNetWhenNoReadableRegstMsg(const ActorMsg& ms
}
return
0
;
}
void
CopyCommNetActor
::
TryWardKernelAndSendMsg
()
{
auto
next_regst_it
=
piece_id2waiting_in_regst_
.
find
(
expected_piece_id
());
if
(
next_regst_it
==
piece_id2waiting_in_regst_
.
end
())
{
return
;
}
if
(
next_regst_it
==
piece_id2waiting_in_regst_
.
end
())
{
return
;
}
if
(
IsWriteReady
())
{
std
::
shared_ptr
<
RegstWarpper
>
regst_wp
=
next_regst_it
->
second
;
AsyncWardKernel
(
GenDefaultKernelCtx
(),
[
this
,
&
regst_wp
](
uint64_t
regst_desc_id
)
->
std
::
shared_ptr
<
RegstWarpper
>
{
Regst
*
regst
=
GetCurWriteableRegst
(
regst_desc_id
);
if
(
regst
==
nullptr
)
{
return
regst_wp
;
}
else
{
return
std
::
make_shared
<
LocalRegstWarpper
>
(
regst
);
}
});
AsyncWardKernel
(
GenDefaultKernelCtx
(),
[
this
,
&
regst_wp
](
uint64_t
regst_desc_id
)
->
std
::
shared_ptr
<
RegstWarpper
>
{
Regst
*
regst
=
GetCurWriteableRegst
(
regst_desc_id
);
if
(
regst
==
nullptr
)
{
return
regst_wp
;
}
else
{
return
std
::
make_shared
<
LocalRegstWarpper
>
(
regst
);
}
});
ForEachCurWriteableRegst
([
&
regst_wp
](
Regst
*
regst
)
{
regst
->
set_piece_id
(
regst_wp
->
piece_id
());
regst
->
set_model_version_id
(
regst_wp
->
model_version_id
());
...
...
oneflow/core/actor/copy_comm_net_actor.h
浏览文件 @
03db51d7
...
...
@@ -6,22 +6,21 @@
namespace
oneflow
{
class
CopyCommNetActor
final
:
public
Actor
{
public:
public:
OF_DISALLOW_COPY_AND_MOVE
(
CopyCommNetActor
);
CopyCommNetActor
()
=
default
;
~
CopyCommNetActor
()
=
default
;
void
Init
(
const
TaskProto
&
,
const
ThreadCtx
&
)
override
;
private:
private:
int
HandleCopyCommNet
(
const
ActorMsg
&
);
int
HandleCopyCommNetWhenNoReadableRegstMsg
(
const
ActorMsg
&
);
void
TryWardKernelAndSendMsg
();
HashMap
<
int64_t
,
std
::
shared_ptr
<
RegstWarpper
>>
piece_id2waiting_in_regst_
;
};
}
// namespace oneflow
#endif // ONEFLOW_CORE_ACTOR_COPY_COMM_NET_ACTOR_H_
#endif
// ONEFLOW_CORE_ACTOR_COPY_COMM_NET_ACTOR_H_
oneflow/core/actor/copy_hd_actor.cpp
浏览文件 @
03db51d7
...
...
@@ -4,12 +4,12 @@
namespace
oneflow
{
void
CopyHdActor
::
Init
(
const
TaskProto
&
task_proto
,
const
ThreadCtx
&
thread_ctx
)
{
void
CopyHdActor
::
Init
(
const
TaskProto
&
task_proto
,
const
ThreadCtx
&
thread_ctx
)
{
Actor
::
Init
(
task_proto
,
thread_ctx
);
CHECK
(
thread_ctx
.
copy_hd_cuda_stream
);
mut_device_ctx
().
reset
(
new
CudaDeviceCtx
(
thread_ctx
.
copy_hd_cuda_stream
,
nullptr
,
nullptr
));
mut_device_ctx
().
reset
(
new
CudaDeviceCtx
(
thread_ctx
.
copy_hd_cuda_stream
,
nullptr
,
nullptr
));
OF_SET_MSG_HANDLE
(
&
CopyHdActor
::
HandleCopyHd
);
}
...
...
@@ -18,7 +18,8 @@ int CopyHdActor::HandleCopyHd(const ActorMsg& msg) {
CHECK_EQ
(
msg
.
actor_cmd
(),
ActorCmd
::
kEORD
);
OF_SET_MSG_HANDLE
(
&
CopyHdActor
::
HandleCopyHdWhenNoReadableRegstMsg
);
}
else
if
(
msg
.
msg_type
()
==
ActorMsgType
::
kRegstMsg
)
{
if
(
TryUpdtStateAsProducedRegst
(
msg
.
regst_warpper
()
->
regst_raw_ptr
())
!=
0
)
{
if
(
TryUpdtStateAsProducedRegst
(
msg
.
regst_warpper
()
->
regst_raw_ptr
())
!=
0
)
{
waiting_in_regst_
.
push
(
msg
.
regst_warpper
());
}
}
...
...
@@ -27,7 +28,8 @@ int CopyHdActor::HandleCopyHd(const ActorMsg& msg) {
}
int
CopyHdActor
::
HandleCopyHdWhenNoReadableRegstMsg
(
const
ActorMsg
&
msg
)
{
CHECK_EQ
(
TryUpdtStateAsProducedRegst
(
msg
.
regst_warpper
()
->
regst_raw_ptr
()),
0
);
CHECK_EQ
(
TryUpdtStateAsProducedRegst
(
msg
.
regst_warpper
()
->
regst_raw_ptr
()),
0
);
TryWardKernelAndSendMsg
();
if
(
waiting_in_regst_
.
empty
())
{
AsyncSendEORDMsgForAllProducedRegstDesc
();
...
...
@@ -46,16 +48,17 @@ void CopyHdActor::TryWardKernelAndSendMsg() {
if
(
!
waiting_in_regst_
.
empty
()
&&
IsWriteReady
())
{
std
::
shared_ptr
<
RegstWarpper
>
regst_wp
=
waiting_in_regst_
.
front
();
CHECK_EQ
(
regst_wp
->
piece_id
(),
expected_piece_id
());
AsyncWardKernel
(
GenDefaultKernelCtx
(),
AsyncWardKernel
(
GenDefaultKernelCtx
(),
[
this
](
uint64_t
regst_desc_id
)
->
std
::
shared_ptr
<
RegstWarpper
>
{
Regst
*
regst
=
GetCurWriteableRegst
(
regst_desc_id
);
if
(
regst
==
nullptr
)
{
CHECK_EQ
(
regst_desc_id
,
waiting_in_regst_
.
front
()
->
regst_desc_id
());
return
waiting_in_regst_
.
front
();
}
else
{
return
std
::
make_shared
<
LocalRegstWarpper
>
(
regst
);
}
});
Regst
*
regst
=
GetCurWriteableRegst
(
regst_desc_id
);
if
(
regst
==
nullptr
)
{
CHECK_EQ
(
regst_desc_id
,
waiting_in_regst_
.
front
()
->
regst_desc_id
());
return
waiting_in_regst_
.
front
();
}
else
{
return
std
::
make_shared
<
LocalRegstWarpper
>
(
regst
);
}
});
ForEachCurWriteableRegst
([
&
regst_wp
](
Regst
*
regst
)
{
regst
->
set_piece_id
(
regst_wp
->
piece_id
());
regst
->
set_model_version_id
(
regst_wp
->
model_version_id
());
...
...
oneflow/core/actor/copy_hd_actor.h
浏览文件 @
03db51d7
...
...
@@ -6,22 +6,21 @@
namespace
oneflow
{
class
CopyHdActor
final
:
public
Actor
{
public:
public:
OF_DISALLOW_COPY_AND_MOVE
(
CopyHdActor
);
CopyHdActor
()
=
default
;
~
CopyHdActor
()
=
default
;
void
Init
(
const
TaskProto
&
,
const
ThreadCtx
&
)
override
;
private:
private:
int
HandleCopyHd
(
const
ActorMsg
&
);
int
HandleCopyHdWhenNoReadableRegstMsg
(
const
ActorMsg
&
);
void
TryWardKernelAndSendMsg
();
std
::
queue
<
std
::
shared_ptr
<
RegstWarpper
>>
waiting_in_regst_
;
};
}
// namespace oneflow
#endif // ONEFLOW_CORE_ACTOR_COPY_HD_ACTOR_H_
#endif
// ONEFLOW_CORE_ACTOR_COPY_HD_ACTOR_H_
oneflow/core/actor/cpu_device_context.h
浏览文件 @
03db51d7
...
...
@@ -10,10 +10,8 @@ class CpuDeviceCtx final : public DeviceCtx {
// OF_DISALLOW_COPY_AND_MOVE(CpuDeviceCtx);
CpuDeviceCtx
()
=
delete
;
~
CpuDeviceCtx
()
=
default
;
CpuDeviceCtx
(
Channel
<
std
::
function
<
void
()
>>*
chan
)
{
set_cpu_stream
(
chan
);
}
CpuDeviceCtx
(
Channel
<
std
::
function
<
void
()
>>*
chan
)
{
set_cpu_stream
(
chan
);
}
void
AddCallBack
(
std
::
function
<
void
()
>
callback
)
const
override
{
cpu_stream
()
->
Send
(
callback
);
...
...
@@ -22,6 +20,6 @@ class CpuDeviceCtx final : public DeviceCtx {
private:
};
}
// namespace oneflow
}
// namespace oneflow
#endif // ONEFLOW_CORE_ACTOR_CPU_DEVICE_CONTEXT_H_
#endif
// ONEFLOW_CORE_ACTOR_CPU_DEVICE_CONTEXT_H_
oneflow/core/actor/cuda_device_context.cpp
浏览文件 @
03db51d7
...
...
@@ -4,23 +4,21 @@ namespace oneflow {
namespace
{
void
CUDART_CB
CudaCallBackHandle
(
cudaStream_t
,
cudaError_t
status
,
void
CUDART_CB
CudaCallBackHandle
(
cudaStream_t
,
cudaError_t
status
,
void
*
void_ptr
)
{
CHECK_EQ
(
status
,
cudaSuccess
);
auto
callback_ptr
=
static_cast
<
std
::
function
<
void
()
>*>
(
void_ptr
);
auto
callback_ptr
=
static_cast
<
std
::
function
<
void
()
>*>
(
void_ptr
);
(
*
callback_ptr
)();
delete
callback_ptr
;
}
}
// namespace
}
// namespace
void
CudaDeviceCtx
::
AddCallBack
(
std
::
function
<
void
()
>
callback_stack
)
const
{
auto
callback_heap
=
new
std
::
function
<
void
()
>
(
callback_stack
);
CHECK_EQ
(
cudaStreamAddCallback
(
cuda_stream
(),
&
CudaCallBackHandle
,
auto
callback_heap
=
new
std
::
function
<
void
()
>
(
callback_stack
);
CHECK_EQ
(
cudaStreamAddCallback
(
cuda_stream
(),
&
CudaCallBackHandle
,
callback_heap
,
0
),
cudaSuccess
);
}
}
// namespace oneflow
}
// namespace oneflow
oneflow/core/actor/cuda_device_context.h
浏览文件 @
03db51d7
...
...
@@ -24,6 +24,6 @@ class CudaDeviceCtx final : public DeviceCtx {
private:
};
}
// namespace oneflow
}
// namespace oneflow
#endif // ONEFLOW_CORE_ACTOR_CUDA_DEVICE_CONTEXT_H_
#endif
// ONEFLOW_CORE_ACTOR_CUDA_DEVICE_CONTEXT_H_
oneflow/core/actor/device_context.h
浏览文件 @
03db51d7
#ifndef ONEFLOW_CORE_ACTOR_DEVICE_CONTEXT_H_
#define ONEFLOW_CORE_ACTOR_DEVICE_CONTEXT_H_
#include "oneflow/core/common/util.h"
#include "oneflow/core/common/channel.h"
#include "oneflow/core/common/util.h"
namespace
oneflow
{
...
...
@@ -19,32 +19,26 @@ class DeviceCtx {
virtual
void
AddCallBack
(
std
::
function
<
void
()
>
)
const
=
0
;
protected:
DeviceCtx
()
:
cpu_stream_
(
nullptr
),
cuda_stream_
(
nullptr
),
cublas_handle_
(
nullptr
),
cudnn_handle_
(
nullptr
)
{}
DeviceCtx
()
:
cpu_stream_
(
nullptr
),
cuda_stream_
(
nullptr
),
cublas_handle_
(
nullptr
),
cudnn_handle_
(
nullptr
)
{}
void
set_cpu_stream
(
Channel
<
std
::
function
<
void
()
>>*
val
)
{
cpu_stream_
=
val
;
}
void
set_cuda_stream
(
const
cudaStream_t
*
val
)
{
cuda_stream_
=
val
;
}
void
set_cublas_handle
(
const
cublasHandle_t
*
val
)
{
cublas_handle_
=
val
;
}
void
set_cudnn_handle
(
const
cudnnHandle_t
*
val
)
{
cudnn_handle_
=
val
;
}
void
set_cuda_stream
(
const
cudaStream_t
*
val
)
{
cuda_stream_
=
val
;
}
void
set_cublas_handle
(
const
cublasHandle_t
*
val
)
{
cublas_handle_
=
val
;
}
void
set_cudnn_handle
(
const
cudnnHandle_t
*
val
)
{
cudnn_handle_
=
val
;
}
private:
Channel
<
std
::
function
<
void
()
>>*
cpu_stream_
;
const
cudaStream_t
*
cuda_stream_
;
const
cublasHandle_t
*
cublas_handle_
;
const
cudnnHandle_t
*
cudnn_handle_
;
};
}
// namespace oneflow
}
// namespace oneflow
#endif // ONEFLOW_CORE_ACTOR_DEVICE_CONTEXT_H_
#endif
// ONEFLOW_CORE_ACTOR_DEVICE_CONTEXT_H_
oneflow/core/actor/fw_data_comp_actor.cpp
浏览文件 @
03db51d7
...
...
@@ -25,24 +25,24 @@ void FwDataCompActor::Init(const TaskProto& task_proto,
kernel_ctx_
.
other
=
reinterpret_cast
<
void
*>
(
parallel_id
());
OF_SET_MSG_HANDLE
(
&
FwDataCompActor
::
WaitToStart
);
}
else
{
num_of_not_eord_
=
1
+
(
model_regst_desc_id_
!=
-
1
)
+
(
model_tmp_regst_desc_id_
!=
-
1
);
num_of_not_eord_
=
1
+
(
model_regst_desc_id_
!=
-
1
)
+
(
model_tmp_regst_desc_id_
!=
-
1
);
OF_SET_MSG_HANDLE
(
&
FwDataCompActor
::
HandleFwComp
);
}
}
bool
FwDataCompActor
::
IsReadReady
()
{
if
(
in_desc_id_
==
-
1
)
{
return
true
;
}
if
(
in_desc_id_
==
-
1
)
{
return
true
;
}
if
(
in_
.
empty
()
||
(
model_regst_desc_id_
!=
-
1
&&
!
model_regst_
)
||
(
model_tmp_regst_desc_id_
!=
-
1
&&
!
model_tmp_regst_
))
{
||
(
model_tmp_regst_desc_id_
!=
-
1
&&
!
model_tmp_regst_
))
{
return
false
;
}
if
(
model_regst_desc_id_
!=
-
1
)
{
//Ho Q, Cipar J, Cui H, et al. More effective distributed ml via a stale synchronous parallel parameter server
// Ho Q, Cipar J, Cui H, et al. More effective distributed ml via a stale
// synchronous parallel parameter server
int32_t
staleness
=
JobDesc
::
Singleton
().
staleness
();
int32_t
num_of_piece_in_batch
=
JobDesc
::
Singleton
().
num_of_piece_in_batch
();
int32_t
num_of_piece_in_batch
=
JobDesc
::
Singleton
().
num_of_piece_in_batch
();
int64_t
cur_iteration
=
in_
.
front
()
->
piece_id
()
/
num_of_piece_in_batch
;
int64_t
stale_version
=
cur_iteration
-
staleness
;
return
model_regst_
->
model_version_id
()
>=
stale_version
;
...
...
@@ -65,7 +65,8 @@ int FwDataCompActor::HandleFwComp(const ActorMsg& msg) {
OF_SET_MSG_HANDLE
(
&
FwDataCompActor
::
HandleFwCompWhenNoReadableRegstMsg
);
}
}
else
if
(
msg
.
msg_type
()
==
ActorMsgType
::
kRegstMsg
)
{
if
(
TryUpdtStateAsProducedRegst
(
msg
.
regst_warpper
()
->
regst_raw_ptr
())
!=
0
)
{
if
(
TryUpdtStateAsProducedRegst
(
msg
.
regst_warpper
()
->
regst_raw_ptr
())
!=
0
)
{
std
::
shared_ptr
<
RegstWarpper
>
regst_wp
=
msg
.
regst_warpper
();
if
(
regst_wp
->
regst_desc_id
()
==
model_tmp_regst_desc_id_
)
{
CHECK
(
!
model_tmp_regst_
);
...
...
@@ -73,9 +74,7 @@ int FwDataCompActor::HandleFwComp(const ActorMsg& msg) {
ready_in_regst_
[
model_tmp_regst_desc_id_
]
=
regst_wp
;
}
else
if
(
regst_wp
->
regst_desc_id
()
==
model_regst_desc_id_
)
{
CHECK_EQ
(
regst_wp
->
model_version_id
(),
expected_model_version_id_
);
if
(
model_regst_
)
{
AsyncSendRegstMsgToProducer
(
model_regst_
);
}
if
(
model_regst_
)
{
AsyncSendRegstMsgToProducer
(
model_regst_
);
}
model_regst_
=
regst_wp
;
ready_in_regst_
[
model_regst_desc_id_
]
=
regst_wp
;
expected_model_version_id_
+=
1
;
...
...
@@ -89,10 +88,12 @@ int FwDataCompActor::HandleFwComp(const ActorMsg& msg) {
}
int
FwDataCompActor
::
HandleFwCompWhenNoReadableRegstMsg
(
const
ActorMsg
&
msg
)
{
CHECK_EQ
(
TryUpdtStateAsProducedRegst
(
msg
.
regst_warpper
()
->
regst_raw_ptr
()),
0
);
CHECK_EQ
(
TryUpdtStateAsProducedRegst
(
msg
.
regst_warpper
()
->
regst_raw_ptr
()),
0
);
TryWardKernelAndSendMsg
();
int
total_piece_num
=
JobDesc
::
Singleton
().
total_piece_num
();
if
((
in_desc_id_
!=-
1
&&
in_
.
empty
())
||
expected_piece_id
()
==
total_piece_num
)
{
if
((
in_desc_id_
!=
-
1
&&
in_
.
empty
())
||
expected_piece_id
()
==
total_piece_num
)
{
if
(
model_regst_desc_id_
!=
-
1
)
{
AsyncSendRegstMsgToProducer
(
model_regst_
);
model_regst_
=
nullptr
;
...
...
@@ -112,7 +113,7 @@ int FwDataCompActor::HandleFwCompWhenNoReadableRegstMsg(const ActorMsg& msg) {
}
return
0
;
}
void
FwDataCompActor
::
TryWardKernelAndSendMsg
()
{
while
(
IsReadReady
()
&&
IsWriteReady
())
{
int64_t
piece_id
=
expected_piece_id
();
...
...
@@ -121,18 +122,17 @@ void FwDataCompActor::TryWardKernelAndSendMsg() {
ready_in_regst_
[
in_
.
front
()
->
regst_desc_id
()]
=
in_
.
front
();
}
int64_t
model_version_id
=
-
1
;
if
(
model_regst_
)
{
model_version_id
=
model_regst_
->
model_version_id
();
}
AsyncWardKernel
(
kernel_ctx_
,
if
(
model_regst_
)
{
model_version_id
=
model_regst_
->
model_version_id
();
}
AsyncWardKernel
(
kernel_ctx_
,
[
this
](
int64_t
regst_desc_id
)
->
std
::
shared_ptr
<
RegstWarpper
>
{
Regst
*
regst
=
GetCurWriteableRegst
(
regst_desc_id
);
if
(
regst
==
nullptr
)
{
return
ready_in_regst_
.
at
(
regst_desc_id
);
}
else
{
return
std
::
make_shared
<
LocalRegstWarpper
>
(
regst
);
}
});
Regst
*
regst
=
GetCurWriteableRegst
(
regst_desc_id
);
if
(
regst
==
nullptr
)
{
return
ready_in_regst_
.
at
(
regst_desc_id
);
}
else
{
return
std
::
make_shared
<
LocalRegstWarpper
>
(
regst
);
}
});
ForEachCurWriteableRegst
([
piece_id
,
model_version_id
](
Regst
*
regst
)
{
regst
->
set_piece_id
(
piece_id
);
regst
->
set_model_version_id
(
model_version_id
);
...
...
oneflow/core/actor/fw_data_comp_actor.h
浏览文件 @
03db51d7
...
...
@@ -6,14 +6,14 @@
namespace
oneflow
{
class
FwDataCompActor
final
:
public
CompActor
{
public:
public:
OF_DISALLOW_COPY_AND_MOVE
(
FwDataCompActor
);
FwDataCompActor
()
=
default
;
~
FwDataCompActor
()
=
default
;
void
Init
(
const
TaskProto
&
,
const
ThreadCtx
&
)
override
;
private:
private:
int
WaitToStart
(
const
ActorMsg
&
);
int
HandleFwComp
(
const
ActorMsg
&
);
int
HandleFwCompWhenNoReadableRegstMsg
(
const
ActorMsg
&
);
...
...
@@ -36,4 +36,4 @@ private:
}
// namespace oneflow
#endif // ONEFLOW_CORE_ACTOR_FW_DATA_COMP_ACTOR_H_
#endif
// ONEFLOW_CORE_ACTOR_FW_DATA_COMP_ACTOR_H_
oneflow/core/actor/model_diff_accumulate_actor.cpp
浏览文件 @
03db51d7
...
...
@@ -15,9 +15,8 @@ void MdDiffAccActor::Init(const TaskProto& task_proto,
cuda_handle_
.
cudnn_handle
()));
}
OF_SET_MSG_HANDLE
(
&
MdDiffAccActor
::
HandleMdDiffAcc
);
ForEachCurWriteableRegst
([
this
](
Regst
*
regst
)
{
model_diff_acc_cnt_
[
regst
]
=
0
;
});
ForEachCurWriteableRegst
(
[
this
](
Regst
*
regst
)
{
model_diff_acc_cnt_
[
regst
]
=
0
;
});
}
int
MdDiffAccActor
::
HandleMdDiffAcc
(
const
ActorMsg
&
msg
)
{
...
...
@@ -25,7 +24,8 @@ int MdDiffAccActor::HandleMdDiffAcc(const ActorMsg& msg) {
CHECK_EQ
(
msg
.
actor_cmd
(),
ActorCmd
::
kEORD
);
OF_SET_MSG_HANDLE
(
&
MdDiffAccActor
::
HandleMdDiffAccWhenNoReadableRegstMsg
);
}
else
if
(
msg
.
msg_type
()
==
ActorMsgType
::
kRegstMsg
)
{
if
(
TryUpdtStateAsProducedRegst
(
msg
.
regst_warpper
()
->
regst_raw_ptr
())
!=
0
)
{
if
(
TryUpdtStateAsProducedRegst
(
msg
.
regst_warpper
()
->
regst_raw_ptr
())
!=
0
)
{
waiting_in_regst_
.
push
(
msg
.
regst_warpper
());
}
}
...
...
@@ -34,7 +34,8 @@ int MdDiffAccActor::HandleMdDiffAcc(const ActorMsg& msg) {
}
int
MdDiffAccActor
::
HandleMdDiffAccWhenNoReadableRegstMsg
(
const
ActorMsg
&
msg
)
{
CHECK_EQ
(
TryUpdtStateAsProducedRegst
(
msg
.
regst_warpper
()
->
regst_raw_ptr
()),
0
);
CHECK_EQ
(
TryUpdtStateAsProducedRegst
(
msg
.
regst_warpper
()
->
regst_raw_ptr
()),
0
);
TryWardKernelAndSendMsg
();
if
(
waiting_in_regst_
.
empty
())
{
AsyncSendEORDMsgForAllProducedRegstDesc
();
...
...
@@ -50,9 +51,7 @@ int MdDiffAccActor::HandleMdDiffAccWhenNoReadableRegstMsg(const ActorMsg& msg) {
}
void
MdDiffAccActor
::
TryWardKernelAndSendMsg
()
{
if
(
waiting_in_regst_
.
empty
()
||
!
IsWriteReady
())
{
return
;
}
if
(
waiting_in_regst_
.
empty
()
||
!
IsWriteReady
())
{
return
;
}
std
::
shared_ptr
<
RegstWarpper
>
regst_wp
=
waiting_in_regst_
.
front
();
CHECK_EQ
(
regst_wp
->
piece_id
(),
expected_piece_id
());
KernelCtx
ctx
=
GenDefaultKernelCtx
();
...
...
@@ -67,15 +66,16 @@ void MdDiffAccActor::TryWardKernelAndSendMsg() {
});
diff_cnt
->
second
=
0
;
});
AsyncWardKernel
(
ctx
,
[
this
](
uint64_t
regst_desc_id
)
->
std
::
shared_ptr
<
RegstWarpper
>
{
Regst
*
regst
=
GetCurWriteableRegst
(
regst_desc_id
);
if
(
regst
==
nullptr
)
{
CHECK_EQ
(
regst_desc_id
,
waiting_in_regst_
.
front
()
->
regst_desc_id
());
return
waiting_in_regst_
.
front
();
}
else
{
return
std
::
make_shared
<
LocalRegstWarpper
>
(
regst
);
}
});
AsyncWardKernel
(
ctx
,
[
this
](
uint64_t
regst_desc_id
)
->
std
::
shared_ptr
<
RegstWarpper
>
{
Regst
*
regst
=
GetCurWriteableRegst
(
regst_desc_id
);
if
(
regst
==
nullptr
)
{
CHECK_EQ
(
regst_desc_id
,
waiting_in_regst_
.
front
()
->
regst_desc_id
());
return
waiting_in_regst_
.
front
();
}
else
{
return
std
::
make_shared
<
LocalRegstWarpper
>
(
regst
);
}
});
ForEachCurWriteableRegst
([
this
,
&
regst_wp
](
Regst
*
regst
)
{
regst
->
set_piece_id
(
regst_wp
->
piece_id
());
++
model_diff_acc_cnt_
.
at
(
regst
);
...
...
oneflow/core/actor/model_diff_accumulate_actor.h
浏览文件 @
03db51d7
...
...
@@ -6,14 +6,14 @@
namespace
oneflow
{
class
MdDiffAccActor
final
:
public
CompActor
{
public:
public:
OF_DISALLOW_COPY_AND_MOVE
(
MdDiffAccActor
);
MdDiffAccActor
()
=
default
;
~
MdDiffAccActor
()
=
default
;
void
Init
(
const
TaskProto
&
,
const
ThreadCtx
&
)
override
;
private:
private:
int
HandleMdDiffAcc
(
const
ActorMsg
&
);
int
HandleMdDiffAccWhenNoReadableRegstMsg
(
const
ActorMsg
&
);
...
...
oneflow/core/actor/model_save_comp_actor.cpp
浏览文件 @
03db51d7
...
...
@@ -3,7 +3,8 @@
namespace
oneflow
{
void
MdSaveCompActor
::
Init
(
const
TaskProto
&
task_proto
,
const
ThreadCtx
&
thread_ctx
)
{
void
MdSaveCompActor
::
Init
(
const
TaskProto
&
task_proto
,
const
ThreadCtx
&
thread_ctx
)
{
CompActor
::
Init
(
task_proto
,
thread_ctx
);
model_regst_desc_id_
=
RegstDescId4Name
(
"model"
);
CHECK
(
thread_ctx
.
cpu_stream
);
...
...
@@ -18,29 +19,27 @@ int MdSaveCompActor::HandleSaveModel(const ActorMsg& actor_msg) {
}
else
if
(
actor_msg
.
msg_type
()
==
ActorMsgType
::
kRegstMsg
)
{
std
::
shared_ptr
<
RegstWarpper
>
regst_warpper
=
actor_msg
.
regst_warpper
();
int64_t
model_version_id
=
regst_warpper
->
model_version_id
();
int32_t
num_of_batches_in_snapshot
=
int32_t
num_of_batches_in_snapshot
=
JobDesc
::
Singleton
().
num_of_batches_in_snapshot
();
CHECK_GT
(
num_of_batches_in_snapshot
,
0
);
if
(
model_version_id
%
num_of_batches_in_snapshot
==
0
)
{
int64_t
snapshot_id
=
model_version_id
/
num_of_batches_in_snapshot
;
Snapshot
*
snapshot
=
SnapshotMgr
::
Singleton
().
GetWriteableSnapshot
(
snapshot_id
);
Snapshot
*
snapshot
=
SnapshotMgr
::
Singleton
().
GetWriteableSnapshot
(
snapshot_id
);
KernelCtx
kernel_ctx
=
GenDefaultKernelCtx
();
std
::
tuple
<
Snapshot
*
,
int64_t
>
save_ctx
=
std
::
make_tuple
(
snapshot
,
parallel_id
());
std
::
tuple
<
Snapshot
*
,
int64_t
>
save_ctx
=
std
::
make_tuple
(
snapshot
,
parallel_id
());
kernel_ctx
.
other
=
&
save_ctx
;
AsyncWardKernel
(
kernel_ctx
,
[
&
](
int64_t
regst_desc_id
)
->
std
::
shared_ptr
<
RegstWarpper
>
{
CHECK_EQ
(
regst_desc_id
,
model_regst_desc_id_
);
return
regst_warpper
;
});
CHECK_EQ
(
regst_desc_id
,
model_regst_desc_id_
);
return
regst_warpper
;
});
}
ActorMsg
msg
=
ActorMsg
::
BuildRegstMsgToProducer
(
regst_warpper
->
producer_actor_id
(),
regst_warpper
->
regst_raw_ptr
());
AsyncDo
([
msg
]()
{
ActorMsgBus
::
Singleton
().
SendMsg
(
msg
);
});
regst_warpper
->
producer_actor_id
(),
regst_warpper
->
regst_raw_ptr
());
AsyncDo
([
msg
]()
{
ActorMsgBus
::
Singleton
().
SendMsg
(
msg
);
});
}
else
{
UNEXPECTED_RUN
();
}
...
...
oneflow/core/actor/model_save_comp_actor.h
浏览文件 @
03db51d7
...
...
@@ -21,4 +21,4 @@ class MdSaveCompActor final : public CompActor {
}
// namespace oneflow
#endif // ONEFLOW_CORE_ACTOR_MODEL_SAVE_COMP_ACTOR_H_
#endif
// ONEFLOW_CORE_ACTOR_MODEL_SAVE_COMP_ACTOR_H_
oneflow/core/actor/model_update_comp_actor.cpp
浏览文件 @
03db51d7
...
...
@@ -4,7 +4,8 @@
namespace
oneflow
{
void
MdUpdtCompActor
::
Init
(
const
TaskProto
&
task_proto
,
const
ThreadCtx
&
thread_ctx
)
{
void
MdUpdtCompActor
::
Init
(
const
TaskProto
&
task_proto
,
const
ThreadCtx
&
thread_ctx
)
{
CompActor
::
Init
(
task_proto
,
thread_ctx
);
model_regst_desc_id_
=
RegstDescId4Name
(
"model"
);
model_tmp_regst_desc_id_
=
RegstDescId4Name
(
"model_tmp"
);
...
...
@@ -31,25 +32,20 @@ int MdUpdtCompActor::HandleBeforeInitializeModel(const ActorMsg& actor_msg) {
};
model_regst
->
ForEachLbn
(
CollectKernelsFromLbn
);
model_tmp_regst
->
ForEachLbn
(
CollectKernelsFromLbn
);
for
(
const
Kernel
*
kernel
:
kernels
)
{
kernel
->
InitModelAndModelTmpBlobs
(
GenDefaultKernelCtx
(),
parallel_policy
(),
parallel_id
(),
parallel_num
(),
GenDefaultKernelCtx
(),
parallel_policy
(),
parallel_id
(),
parallel_num
(),
SnapshotMgr
::
Singleton
().
GetReadableSnapshot
(),
[
&
](
const
std
::
string
&
bn_in_op
)
{
const
std
::
string
&
lbn
=
kernel
->
Lbn4BnInOp
(
bn_in_op
);
Blob
*
ret
=
model_regst
->
GetBlobPtrFromLbn
(
lbn
);
if
(
ret
==
nullptr
)
{
ret
=
model_tmp_regst
->
GetBlobPtrFromLbn
(
lbn
);
}
CHECK
(
ret
!=
nullptr
);
return
ret
;
});
const
std
::
string
&
lbn
=
kernel
->
Lbn4BnInOp
(
bn_in_op
);
Blob
*
ret
=
model_regst
->
GetBlobPtrFromLbn
(
lbn
);
if
(
ret
==
nullptr
)
{
ret
=
model_tmp_regst
->
GetBlobPtrFromLbn
(
lbn
);
}
CHECK
(
ret
!=
nullptr
);
return
ret
;
});
}
AsyncDo
([]()
{
RuntimeCtx
::
Singleton
().
OneModelInitDone
();
});
AsyncDo
([]()
{
RuntimeCtx
::
Singleton
().
OneModelInitDone
();
});
OF_SET_MSG_HANDLE
(
&
MdUpdtCompActor
::
HandleBeforeSendInitialModel
);
return
0
;
}
...
...
@@ -86,8 +82,9 @@ int MdUpdtCompActor::HandleUpdateModel(const ActorMsg& actor_msg) {
int
MdUpdtCompActor
::
HandleUpdtModelWhenNoReadableRegstMsg
(
const
ActorMsg
&
actor_msg
)
{
CHECK_EQ
(
TryUpdtStateAsProducedRegst
(
actor_msg
.
regst_warpper
()
->
regst_raw_ptr
()),
0
);
CHECK_EQ
(
TryUpdtStateAsProducedRegst
(
actor_msg
.
regst_warpper
()
->
regst_raw_ptr
()),
0
);
TryWardKernelAndSendMsg
();
if
(
waiting_model_diff_acc_queue_
.
empty
())
{
AsyncSendEORDMsgToSubscribers
(
model_regst_desc_id_
);
...
...
@@ -109,14 +106,15 @@ void MdUpdtCompActor::TryWardKernelAndSendMsg() {
Regst
*
model_regst
=
GetCurWriteableRegst
(
model_regst_desc_id_
);
auto
model_wpr
=
std
::
make_shared
<
LocalRegstWarpper
>
(
model_regst
);
model_regst
->
set_model_version_id
(
next_model_version_id_
++
);
AsyncWardKernel
(
GenDefaultKernelCtx
(),
AsyncWardKernel
(
GenDefaultKernelCtx
(),
[
&
](
int64_t
regst_desc_id
)
->
std
::
shared_ptr
<
RegstWarpper
>
{
if
(
regst_desc_id
==
model_regst_desc_id_
)
{
return
model_wpr
;
}
else
{
return
model_diff_acc_wpr
;
}
});
if
(
regst_desc_id
==
model_regst_desc_id_
)
{
return
model_wpr
;
}
else
{
return
model_diff_acc_wpr
;
}
});
AsyncSendReadableRegstMsg
();
AsyncSendRegstMsgToProducer
(
model_diff_acc_wpr
);
}
...
...
oneflow/core/actor/model_update_comp_actor.h
浏览文件 @
03db51d7
...
...
@@ -27,9 +27,8 @@ class MdUpdtCompActor final : public CompActor {
int64_t
model_tmp_regst_desc_id_
;
std
::
queue
<
std
::
shared_ptr
<
RegstWarpper
>>
waiting_model_diff_acc_queue_
;
int64_t
next_model_version_id_
;
};
}
// namespace oneflow
#endif // ONEFLOW_CORE_ACTOR_MODEL_UPDATE_COMP_ACTOR_H_
#endif
// ONEFLOW_CORE_ACTOR_MODEL_UPDATE_COMP_ACTOR_H_
oneflow/core/blas/cblas.h
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/blas/cblas_template.cpp
浏览文件 @
03db51d7
...
...
@@ -5,116 +5,112 @@ namespace oneflow {
// level 1 vector and vector
// dot product
template
<
>
float
cblas_dot
<
float
>
(
const
int
n
,
const
float
*
x
,
const
int
incx
,
const
float
*
y
,
const
int
incy
)
{
float
cblas_dot
<
float
>
(
const
int
n
,
const
float
*
x
,
const
int
incx
,
const
float
*
y
,
const
int
incy
)
{
return
cblas_sdot
(
n
,
x
,
incx
,
y
,
incy
);
}
template
<
>
double
cblas_dot
<
double
>
(
const
int
n
,
const
double
*
x
,
const
int
incx
,
const
double
*
y
,
const
int
incy
)
{
double
cblas_dot
<
double
>
(
const
int
n
,
const
double
*
x
,
const
int
incx
,
const
double
*
y
,
const
int
incy
)
{
return
cblas_ddot
(
n
,
x
,
incx
,
y
,
incy
);
}
// swap x and y
template
<
>
void
cblas_swap
<
float
>
(
const
int
n
,
float
*
x
,
const
int
incx
,
float
*
y
,
const
int
incy
)
{
void
cblas_swap
<
float
>
(
const
int
n
,
float
*
x
,
const
int
incx
,
float
*
y
,
const
int
incy
)
{
cblas_sswap
(
n
,
x
,
incx
,
y
,
incy
);
}
template
<
>
void
cblas_swap
<
double
>
(
const
int
n
,
double
*
x
,
const
int
incx
,
double
*
y
,
const
int
incy
)
{
void
cblas_swap
<
double
>
(
const
int
n
,
double
*
x
,
const
int
incx
,
double
*
y
,
const
int
incy
)
{
cblas_dswap
(
n
,
x
,
incx
,
y
,
incy
);
}
// copy x into y
template
<
>
void
cblas_copy
<
float
>
(
const
int
n
,
const
float
*
x
,
const
int
incx
,
float
*
y
,
const
int
incy
)
{
void
cblas_copy
<
float
>
(
const
int
n
,
const
float
*
x
,
const
int
incx
,
float
*
y
,
const
int
incy
)
{
cblas_scopy
(
n
,
x
,
incx
,
y
,
incy
);
}
template
<
>
void
cblas_copy
<
double
>
(
const
int
n
,
const
double
*
x
,
const
int
incx
,
double
*
y
,
const
int
incy
)
{
void
cblas_copy
<
double
>
(
const
int
n
,
const
double
*
x
,
const
int
incx
,
double
*
y
,
const
int
incy
)
{
cblas_dcopy
(
n
,
x
,
incx
,
y
,
incy
);
}
// y = a*x + y
template
<
>
void
cblas_axpy
<
float
>
(
const
int
n
,
const
float
alpha
,
const
float
*
x
,
const
int
incx
,
float
*
y
,
const
int
incy
)
{
void
cblas_axpy
<
float
>
(
const
int
n
,
const
float
alpha
,
const
float
*
x
,
const
int
incx
,
float
*
y
,
const
int
incy
)
{
cblas_saxpy
(
n
,
alpha
,
x
,
incx
,
y
,
incy
);
}
template
<
>
void
cblas_axpy
<
double
>
(
const
int
n
,
const
double
alpha
,
const
double
*
x
,
const
int
incx
,
double
*
y
,
const
int
incy
)
{
void
cblas_axpy
<
double
>
(
const
int
n
,
const
double
alpha
,
const
double
*
x
,
const
int
incx
,
double
*
y
,
const
int
incy
)
{
cblas_daxpy
(
n
,
alpha
,
x
,
incx
,
y
,
incy
);
}
// x = a*x
template
<
>
void
cblas_scal
<
float
>
(
const
int
n
,
const
float
alpha
,
float
*
x
,
const
int
incx
)
{
void
cblas_scal
<
float
>
(
const
int
n
,
const
float
alpha
,
float
*
x
,
const
int
incx
)
{
cblas_sscal
(
n
,
alpha
,
x
,
incx
);
}
template
<
>
void
cblas_scal
<
double
>
(
const
int
n
,
const
double
alpha
,
double
*
x
,
const
int
incx
)
{
void
cblas_scal
<
double
>
(
const
int
n
,
const
double
alpha
,
double
*
x
,
const
int
incx
)
{
cblas_dscal
(
n
,
alpha
,
x
,
incx
);
}
// level 2 matrix and vector
// matrix vector multiply
template
<
>
void
cblas_gemv
<
float
>
(
const
enum
CBLAS_ORDER
order
,
const
enum
CBLAS_TRANSPOSE
trans_a
,
const
int
m
,
const
int
n
,
const
float
alpha
,
const
float
*
a
,
const
int
lda
,
const
float
*
x
,
const
int
incx
,
const
float
beta
,
float
*
y
,
const
int
incy
)
{
void
cblas_gemv
<
float
>
(
const
enum
CBLAS_ORDER
order
,
const
enum
CBLAS_TRANSPOSE
trans_a
,
const
int
m
,
const
int
n
,
const
float
alpha
,
const
float
*
a
,
const
int
lda
,
const
float
*
x
,
const
int
incx
,
const
float
beta
,
float
*
y
,
const
int
incy
)
{
cblas_sgemv
(
order
,
trans_a
,
m
,
n
,
alpha
,
a
,
lda
,
x
,
incx
,
beta
,
y
,
incy
);
}
template
<
>
void
cblas_gemv
<
double
>
(
const
enum
CBLAS_ORDER
order
,
const
enum
CBLAS_TRANSPOSE
trans_a
,
const
int
m
,
const
int
n
,
const
double
alpha
,
const
double
*
a
,
const
int
lda
,
const
double
*
x
,
const
int
incx
,
const
double
beta
,
double
*
y
,
const
int
incy
)
{
void
cblas_gemv
<
double
>
(
const
enum
CBLAS_ORDER
order
,
const
enum
CBLAS_TRANSPOSE
trans_a
,
const
int
m
,
const
int
n
,
const
double
alpha
,
const
double
*
a
,
const
int
lda
,
const
double
*
x
,
const
int
incx
,
const
double
beta
,
double
*
y
,
const
int
incy
)
{
cblas_dgemv
(
order
,
trans_a
,
m
,
n
,
alpha
,
a
,
lda
,
x
,
incx
,
beta
,
y
,
incy
);
}
// matrix matrix multiply
template
<
>
void
cblas_gemm
<
float
>
(
const
enum
CBLAS_ORDER
order
,
const
enum
CBLAS_TRANSPOSE
trans_a
,
const
enum
CBLAS_TRANSPOSE
trans_b
,
const
int
m
,
const
int
n
,
const
int
k
,
const
float
alpha
,
const
float
*
a
,
const
int
lda
,
const
float
*
b
,
const
int
ldb
,
const
float
beta
,
float
*
c
,
const
int
ldc
)
{
cblas_sgemm
(
order
,
trans_a
,
trans_b
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
b
,
ldb
,
beta
,
c
,
ldc
);
void
cblas_gemm
<
float
>
(
const
enum
CBLAS_ORDER
order
,
const
enum
CBLAS_TRANSPOSE
trans_a
,
const
enum
CBLAS_TRANSPOSE
trans_b
,
const
int
m
,
const
int
n
,
const
int
k
,
const
float
alpha
,
const
float
*
a
,
const
int
lda
,
const
float
*
b
,
const
int
ldb
,
const
float
beta
,
float
*
c
,
const
int
ldc
)
{
cblas_sgemm
(
order
,
trans_a
,
trans_b
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
b
,
ldb
,
beta
,
c
,
ldc
);
}
template
<
>
void
cblas_gemm
<
double
>
(
const
enum
CBLAS_ORDER
order
,
const
enum
CBLAS_TRANSPOSE
trans_a
,
const
enum
CBLAS_TRANSPOSE
trans_b
,
const
int
m
,
const
int
n
,
const
int
k
,
const
double
alpha
,
const
double
*
a
,
const
int
lda
,
const
double
*
b
,
const
int
ldb
,
const
double
beta
,
double
*
c
,
const
int
ldc
)
{
cblas_dgemm
(
order
,
trans_a
,
trans_b
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
b
,
ldb
,
beta
,
c
,
ldc
);
void
cblas_gemm
<
double
>
(
const
enum
CBLAS_ORDER
order
,
const
enum
CBLAS_TRANSPOSE
trans_a
,
const
enum
CBLAS_TRANSPOSE
trans_b
,
const
int
m
,
const
int
n
,
const
int
k
,
const
double
alpha
,
const
double
*
a
,
const
int
lda
,
const
double
*
b
,
const
int
ldb
,
const
double
beta
,
double
*
c
,
const
int
ldc
)
{
cblas_dgemm
(
order
,
trans_a
,
trans_b
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
b
,
ldb
,
beta
,
c
,
ldc
);
}
}
// namespace oneflow
oneflow/core/blas/cblas_template.h
浏览文件 @
03db51d7
...
...
@@ -10,59 +10,52 @@ namespace oneflow {
// level 1 vector and vector
// dot product
template
<
typename
FloatingPointType
>
FloatingPointType
cblas_dot
(
const
int
n
,
const
FloatingPointType
*
x
,
const
int
incx
,
const
FloatingPointType
*
y
,
const
int
incy
);
FloatingPointType
cblas_dot
(
const
int
n
,
const
FloatingPointType
*
x
,
const
int
incx
,
const
FloatingPointType
*
y
,
const
int
incy
);
// swap x and y
template
<
typename
FloatingPointType
>
void
cblas_swap
(
const
int
n
,
FloatingPointType
*
x
,
const
int
incx
,
FloatingPointType
*
y
,
const
int
incy
);
void
cblas_swap
(
const
int
n
,
FloatingPointType
*
x
,
const
int
incx
,
FloatingPointType
*
y
,
const
int
incy
);
// copy x into y
template
<
typename
FloatingPointType
>
void
cblas_copy
(
const
int
n
,
const
FloatingPointType
*
x
,
const
int
incx
,
FloatingPointType
*
y
,
const
int
incy
);
void
cblas_copy
(
const
int
n
,
const
FloatingPointType
*
x
,
const
int
incx
,
FloatingPointType
*
y
,
const
int
incy
);
// y = a*x + y
template
<
typename
FloatingPointType
>
void
cblas_axpy
(
const
int
n
,
const
FloatingPointType
alpha
,
const
FloatingPointType
*
x
,
const
int
incx
,
FloatingPointType
*
y
,
const
int
incy
);
void
cblas_axpy
(
const
int
n
,
const
FloatingPointType
alpha
,
const
FloatingPointType
*
x
,
const
int
incx
,
FloatingPointType
*
y
,
const
int
incy
);
// x = a*x
template
<
typename
FloatingPointType
>
void
cblas_scal
(
const
int
n
,
const
FloatingPointType
alpha
,
FloatingPointType
*
x
,
const
int
incx
);
void
cblas_scal
(
const
int
n
,
const
FloatingPointType
alpha
,
FloatingPointType
*
x
,
const
int
incx
);
// level 2 matrix and vector
// matrix vector multiply
template
<
typename
FloatingPointType
>
void
cblas_gemv
(
const
enum
CBLAS_ORDER
order
,
const
enum
CBLAS_TRANSPOSE
trans_a
,
const
int
m
,
const
int
n
,
const
FloatingPointType
alph
a
,
const
FloatingPointType
*
a
,
const
int
lda
,
const
FloatingPointType
*
x
,
const
int
incx
,
const
FloatingPointType
beta
,
FloatingPointType
*
y
,
const
int
incy
);
void
cblas_gemv
(
const
enum
CBLAS_ORDER
order
,
const
enum
CBLAS_TRANSPOSE
trans_a
,
const
int
m
,
const
int
n
,
const
FloatingPointType
alpha
,
const
FloatingPointType
*
a
,
const
int
lda
,
const
FloatingPointType
*
x
,
const
int
incx
,
const
FloatingPointType
beta
,
FloatingPointType
*
y
,
const
int
incy
);
// level 3 matrix and matrix
// matrix matrix multiply
template
<
typename
FloatingPointType
>
void
cblas_gemm
(
const
enum
CBLAS_ORDER
order
,
const
enum
CBLAS_TRANSPOSE
trans_a
,
const
enum
CBLAS_TRANSPOSE
trans_b
,
const
int
m
,
const
int
n
,
const
int
k
,
const
FloatingPointType
alpha
,
const
FloatingPointType
*
a
,
const
int
lda
,
const
FloatingPointType
*
b
,
const
int
ldb
,
const
FloatingPointType
beta
,
FloatingPointType
*
c
,
const
int
ldc
);
void
cblas_gemm
(
const
enum
CBLAS_ORDER
order
,
const
enum
CBLAS_TRANSPOSE
trans_a
,
const
enum
CBLAS_TRANSPOSE
trans_b
,
const
int
m
,
const
int
n
,
const
int
k
,
const
FloatingPointType
alpha
,
const
FloatingPointType
*
a
,
const
int
lda
,
const
FloatingPointType
*
b
,
const
int
ldb
,
const
FloatingPointType
beta
,
FloatingPointType
*
c
,
const
int
ldc
);
}
// namespace oneflow
#endif // ONEFLOW_CORE_BLAS_CBLAS_TEMPLATE_H_
oneflow/core/blas/cublas_template.cu
浏览文件 @
03db51d7
...
...
@@ -5,125 +5,117 @@ namespace oneflow {
// level 1 vector and vector
// dot product
template
<
>
void
cublas_dot
<
float
>
(
cublasHandle_t
handle
,
int
n
,
const
float
*
x
,
int
incx
,
const
float
*
y
,
int
incy
,
float
*
result
)
{
void
cublas_dot
<
float
>
(
cublasHandle_t
handle
,
int
n
,
const
float
*
x
,
int
incx
,
const
float
*
y
,
int
incy
,
float
*
result
)
{
CHECK_EQ
(
cublasSdot
(
handle
,
n
,
x
,
incx
,
y
,
incy
,
result
),
CUBLAS_STATUS_SUCCESS
);
}
template
<
>
void
cublas_dot
<
double
>
(
cublasHandle_t
handle
,
int
n
,
const
double
*
x
,
int
incx
,
const
double
*
y
,
int
incy
,
double
*
result
)
{
void
cublas_dot
<
double
>
(
cublasHandle_t
handle
,
int
n
,
const
double
*
x
,
int
incx
,
const
double
*
y
,
int
incy
,
double
*
result
)
{
CHECK_EQ
(
cublasDdot
(
handle
,
n
,
x
,
incx
,
y
,
incy
,
result
),
CUBLAS_STATUS_SUCCESS
);
}
// swap x and y
template
<
>
void
cublas_swap
<
float
>
(
cublasHandle_t
handle
,
int
n
,
float
*
x
,
int
incx
,
float
*
y
,
int
incy
)
{
void
cublas_swap
<
float
>
(
cublasHandle_t
handle
,
int
n
,
float
*
x
,
int
incx
,
float
*
y
,
int
incy
)
{
CHECK_EQ
(
cublasSswap
(
handle
,
n
,
x
,
incx
,
y
,
incy
),
CUBLAS_STATUS_SUCCESS
);
}
template
<
>
void
cublas_swap
<
double
>
(
cublasHandle_t
handle
,
int
n
,
double
*
x
,
int
incx
,
double
*
y
,
int
incy
)
{
void
cublas_swap
<
double
>
(
cublasHandle_t
handle
,
int
n
,
double
*
x
,
int
incx
,
double
*
y
,
int
incy
)
{
CHECK_EQ
(
cublasDswap
(
handle
,
n
,
x
,
incx
,
y
,
incy
),
CUBLAS_STATUS_SUCCESS
);
}
// copy x into y
template
<
>
void
cublas_copy
<
float
>
(
cublasHandle_t
handle
,
int
n
,
const
float
*
x
,
int
incx
,
float
*
y
,
int
incy
)
{
void
cublas_copy
<
float
>
(
cublasHandle_t
handle
,
int
n
,
const
float
*
x
,
int
incx
,
float
*
y
,
int
incy
)
{
CHECK_EQ
(
cublasScopy
(
handle
,
n
,
x
,
incx
,
y
,
incy
),
CUBLAS_STATUS_SUCCESS
);
}
template
<
>
void
cublas_copy
<
double
>
(
cublasHandle_t
handle
,
int
n
,
const
double
*
x
,
int
incx
,
double
*
y
,
int
incy
)
{
void
cublas_copy
<
double
>
(
cublasHandle_t
handle
,
int
n
,
const
double
*
x
,
int
incx
,
double
*
y
,
int
incy
)
{
CHECK_EQ
(
cublasDcopy
(
handle
,
n
,
x
,
incx
,
y
,
incy
),
CUBLAS_STATUS_SUCCESS
);
}
// y = a*x + y
template
<
>
void
cublas_axpy
<
float
>
(
cublasHandle_t
handle
,
int
n
,
const
float
*
alpha
,
const
float
*
x
,
const
int
incx
,
float
*
y
,
const
int
incy
)
{
void
cublas_axpy
<
float
>
(
cublasHandle_t
handle
,
int
n
,
const
float
*
alpha
,
const
float
*
x
,
const
int
incx
,
float
*
y
,
const
int
incy
)
{
CHECK_EQ
(
cublasSaxpy
(
handle
,
n
,
alpha
,
x
,
incx
,
y
,
incy
),
CUBLAS_STATUS_SUCCESS
);
}
template
<
>
void
cublas_axpy
<
double
>
(
cublasHandle_t
handle
,
int
n
,
const
double
*
alpha
,
const
double
*
x
,
const
int
incx
,
double
*
y
,
int
incy
)
{
void
cublas_axpy
<
double
>
(
cublasHandle_t
handle
,
int
n
,
const
double
*
alpha
,
const
double
*
x
,
const
int
incx
,
double
*
y
,
int
incy
)
{
CHECK_EQ
(
cublasDaxpy
(
handle
,
n
,
alpha
,
x
,
incx
,
y
,
incy
),
CUBLAS_STATUS_SUCCESS
);
}
// x = a*x
template
<
>
void
cublas_scal
<
float
>
(
cublasHandle_t
handle
,
int
n
,
const
float
*
alpha
,
float
*
x
,
int
incx
)
{
void
cublas_scal
<
float
>
(
cublasHandle_t
handle
,
int
n
,
const
float
*
alpha
,
float
*
x
,
int
incx
)
{
CHECK_EQ
(
cublasSscal
(
handle
,
n
,
alpha
,
x
,
incx
),
CUBLAS_STATUS_SUCCESS
);
}
template
<
>
void
cublas_scal
<
double
>
(
cublasHandle_t
handle
,
int
n
,
const
double
*
alpha
,
double
*
x
,
int
incx
)
{
void
cublas_scal
<
double
>
(
cublasHandle_t
handle
,
int
n
,
const
double
*
alpha
,
double
*
x
,
int
incx
)
{
CHECK_EQ
(
cublasDscal
(
handle
,
n
,
alpha
,
x
,
incx
),
CUBLAS_STATUS_SUCCESS
);
}
// level 2 matrix and vector
// matrix vector multiply
template
<
>
void
cublas_gemv
<
float
>
(
cublasHandle_t
handle
,
cublasOperation_t
trans
,
int
m
,
int
n
,
const
float
*
alpha
,
const
float
*
a
,
int
lda
,
const
float
*
x
,
int
incx
,
const
float
*
beta
,
float
*
y
,
int
incy
)
{
CHECK_EQ
(
cublasSgemv
(
handle
,
trans
,
m
,
n
,
alpha
,
a
,
lda
,
x
,
incx
,
beta
,
y
,
incy
),
CUBLAS_STATUS_SUCCESS
);
void
cublas_gemv
<
float
>
(
cublasHandle_t
handle
,
cublasOperation_t
trans
,
int
m
,
int
n
,
const
float
*
alpha
,
const
float
*
a
,
int
lda
,
const
float
*
x
,
int
incx
,
const
float
*
beta
,
float
*
y
,
int
incy
)
{
CHECK_EQ
(
cublasSgemv
(
handle
,
trans
,
m
,
n
,
alpha
,
a
,
lda
,
x
,
incx
,
beta
,
y
,
incy
),
CUBLAS_STATUS_SUCCESS
);
}
template
<
>
void
cublas_gemv
<
double
>
(
cublasHandle_t
handle
,
cublasOperation_t
trans
,
int
m
,
int
n
,
const
double
*
alpha
,
const
double
*
a
,
int
lda
,
const
double
*
x
,
int
incx
,
const
double
*
beta
,
double
*
y
,
int
incy
)
{
CHECK_EQ
(
cublasDgemv
(
handle
,
trans
,
m
,
n
,
alpha
,
a
,
lda
,
x
,
incx
,
beta
,
y
,
incy
),
CUBLAS_STATUS_SUCCESS
);
void
cublas_gemv
<
double
>
(
cublasHandle_t
handle
,
cublasOperation_t
trans
,
int
m
,
int
n
,
const
double
*
alpha
,
const
double
*
a
,
int
lda
,
const
double
*
x
,
int
incx
,
const
double
*
beta
,
double
*
y
,
int
incy
)
{
CHECK_EQ
(
cublasDgemv
(
handle
,
trans
,
m
,
n
,
alpha
,
a
,
lda
,
x
,
incx
,
beta
,
y
,
incy
),
CUBLAS_STATUS_SUCCESS
);
}
// level 3 matrix and matrix
// matrix matrix multiply
template
<
>
void
cublas_gemm
<
float
>
(
cublasHandle_t
handle
,
cublasOperation_t
cutrans_a
,
cublasOperation_t
cutrans_b
,
int
m
,
int
n
,
int
k
,
const
float
*
alpha
,
const
float
*
a
,
int
lda
,
const
float
*
b
,
int
ldb
,
const
float
*
beta
,
float
*
c
,
int
ldc
)
{
CHECK_EQ
(
cublasSgemm
(
handle
,
cutrans_a
,
cutrans_b
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
b
,
ldb
,
beta
,
c
,
ldc
),
void
cublas_gemm
<
float
>
(
cublasHandle_t
handle
,
cublasOperation_t
cutrans_a
,
cublasOperation_t
cutrans_b
,
int
m
,
int
n
,
int
k
,
const
float
*
alpha
,
const
float
*
a
,
int
lda
,
const
float
*
b
,
int
ldb
,
const
float
*
beta
,
float
*
c
,
int
ldc
)
{
CHECK_EQ
(
cublasSgemm
(
handle
,
cutrans_a
,
cutrans_b
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
b
,
ldb
,
beta
,
c
,
ldc
),
CUBLAS_STATUS_SUCCESS
);
}
template
<
>
void
cublas_gemm
<
double
>
(
cublasHandle_t
handle
,
cublasOperation_t
cutrans_a
,
cublasOperation_t
cutrans_b
,
int
m
,
int
n
,
int
k
,
const
double
*
alpha
,
const
double
*
a
,
int
lda
,
const
double
*
b
,
int
ldb
,
const
double
*
beta
,
double
*
c
,
int
ldc
)
{
CHECK_EQ
(
cublasDgemm
(
handle
,
cutrans_a
,
cutrans_b
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
b
,
ldb
,
beta
,
c
,
ldc
),
void
cublas_gemm
<
double
>
(
cublasHandle_t
handle
,
cublasOperation_t
cutrans_a
,
cublasOperation_t
cutrans_b
,
int
m
,
int
n
,
int
k
,
const
double
*
alpha
,
const
double
*
a
,
int
lda
,
const
double
*
b
,
int
ldb
,
const
double
*
beta
,
double
*
c
,
int
ldc
)
{
CHECK_EQ
(
cublasDgemm
(
handle
,
cutrans_a
,
cutrans_b
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
b
,
ldb
,
beta
,
c
,
ldc
),
CUBLAS_STATUS_SUCCESS
);
}
...
...
oneflow/core/blas/cublas_template.h
浏览文件 @
03db51d7
...
...
@@ -8,57 +8,48 @@ namespace oneflow {
// level 1 vector and vector
// dot product
template
<
typename
FloatingPointType
>
void
cublas_dot
(
cublasHandle_t
handle
,
int
n
,
const
FloatingPointType
*
x
,
int
incx
,
const
FloatingPointType
*
y
,
int
incy
,
FloatingPointType
*
result
);
void
cublas_dot
(
cublasHandle_t
handle
,
int
n
,
const
FloatingPointType
*
x
,
int
incx
,
const
FloatingPointType
*
y
,
int
incy
,
FloatingPointType
*
result
);
// swap x and y
template
<
typename
FloatingPointType
>
void
cublas_swap
(
cublasHandle_t
handle
,
int
n
,
FloatingPointType
*
x
,
int
incx
,
FloatingPointType
*
y
,
int
incy
);
void
cublas_swap
(
cublasHandle_t
handle
,
int
n
,
FloatingPointType
*
x
,
int
incx
,
FloatingPointType
*
y
,
int
incy
);
// copy x into y
template
<
typename
FloatingPointType
>
void
cublas_copy
(
cublasHandle_t
handle
,
int
n
,
const
FloatingPointType
*
x
,
int
incx
,
FloatingPointType
*
y
,
int
incy
);
void
cublas_copy
(
cublasHandle_t
handle
,
int
n
,
const
FloatingPointType
*
x
,
int
incx
,
FloatingPointType
*
y
,
int
incy
);
// y = a*x + y
template
<
typename
FloatingPointType
>
void
cublas_axpy
(
cublasHandle_t
handle
,
int
n
,
const
FloatingPointType
*
alpha
,
const
FloatingPointType
*
x
,
int
incx
,
FloatingPointType
*
y
,
int
incy
);
void
cublas_axpy
(
cublasHandle_t
handle
,
int
n
,
const
FloatingPointType
*
alpha
,
const
FloatingPointType
*
x
,
int
incx
,
FloatingPointType
*
y
,
int
incy
);
// x = a*x
template
<
typename
FloatingPointType
>
void
cublas_scal
(
cublasHandle_t
handle
,
int
n
,
const
FloatingPointType
*
alpha
,
FloatingPointType
*
x
,
int
incx
);
void
cublas_scal
(
cublasHandle_t
handle
,
int
n
,
const
FloatingPointType
*
alpha
,
FloatingPointType
*
x
,
int
incx
);
// level 2 matrix and vector
// matrix vector multiply
template
<
typename
FloatingPointType
>
void
cublas_gemv
(
cublasHandle_t
handle
,
cublasOperation_t
trans
,
int
m
,
int
n
,
const
FloatingPointType
*
alpha
,
const
FloatingPointType
*
a
,
int
lda
,
const
FloatingPointType
*
x
,
int
incx
,
const
FloatingPointType
*
beta
,
FloatingPointType
*
y
,
int
incy
);
void
cublas_gemv
(
cublasHandle_t
handle
,
cublasOperation_t
trans
,
int
m
,
int
n
,
const
FloatingPointType
*
alpha
,
const
FloatingPointType
*
a
,
int
lda
,
const
FloatingPointType
*
x
,
int
incx
,
const
FloatingPointType
*
beta
,
FloatingPointType
*
y
,
int
incy
);
// level 3 matrix and matrix
// matrix matrix multiply
template
<
typename
FloatingPointType
>
void
cublas_gemm
(
cublasHandle_t
handle
,
cublasOperation_t
cutrans_a
,
cublasOperation_t
cutrans_b
,
int
m
,
int
n
,
int
k
,
const
FloatingPointType
*
alpha
,
const
FloatingPointType
*
a
,
int
lda
,
const
FloatingPointType
*
b
,
int
ldb
,
const
FloatingPointType
*
beta
,
FloatingPointType
*
c
,
int
ldc
);
void
cublas_gemm
(
cublasHandle_t
handle
,
cublasOperation_t
cutrans_a
,
cublasOperation_t
cutrans_b
,
int
m
,
int
n
,
int
k
,
const
FloatingPointType
*
alpha
,
const
FloatingPointType
*
a
,
int
lda
,
const
FloatingPointType
*
b
,
int
ldb
,
const
FloatingPointType
*
beta
,
FloatingPointType
*
c
,
int
ldc
);
}
// namespace oneflow
}
// namespace oneflow
#endif // ONEFLOW_CORE_BLAS_CUBLAS_TEMPLATE_H_
#endif
// ONEFLOW_CORE_BLAS_CUBLAS_TEMPLATE_H_
oneflow/core/common/balanced_splitter.cpp
浏览文件 @
03db51d7
...
...
@@ -17,10 +17,10 @@ Range BalancedSplitter::At(int64_t idx) const {
upper_pound_num
=
lower_pound_num
+
(
size_per_range_
+
1
);
}
else
{
lower_pound_num
=
(
size_per_range_
+
1
)
*
change_pos_
+
size_per_range_
*
(
idx
-
change_pos_
);
+
size_per_range_
*
(
idx
-
change_pos_
);
upper_pound_num
=
lower_pound_num
+
size_per_range_
;
}
return
Range
(
lower_pound_num
,
upper_pound_num
);
}
}
// namespace oneflow
}
// namespace oneflow
oneflow/core/common/balanced_splitter.h
浏览文件 @
03db51d7
...
...
@@ -2,8 +2,8 @@
#define ONEFLOW_CORE_COMMON_BALANCED_SPLITTER_H_
#include <stdint.h>
#include "oneflow/core/common/util.h"
#include "oneflow/core/common/range.h"
#include "oneflow/core/common/util.h"
namespace
oneflow
{
...
...
@@ -27,11 +27,11 @@ class BalancedSplitter final {
Range
At
(
int64_t
idx
)
const
;
private:
int64_t
size_per_range_
;
int64_t
change_pos_
;
int64_t
split_num_
;
int64_t
size_per_range_
;
int64_t
change_pos_
;
int64_t
split_num_
;
};
}
// namespace oneflow
}
// namespace oneflow
#endif // ONEFLOW_CORE_COMMON_BALANCED_SPLITTER_H_
#endif
// ONEFLOW_CORE_COMMON_BALANCED_SPLITTER_H_
oneflow/core/common/balanced_splitter_test.cpp
浏览文件 @
03db51d7
...
...
@@ -19,4 +19,4 @@ TEST(BalancedSplitter, split_2_to_3_part) {
ASSERT_TRUE
(
splitter
.
At
(
2
)
==
Range
(
2
,
2
));
}
}
// namespace oneflow
}
// namespace oneflow
oneflow/core/common/channel.h
浏览文件 @
03db51d7
...
...
@@ -26,7 +26,8 @@ class Channel final {
// close the channel's send end, the thread can't send item to the channel
void
CloseSendEnd
();
// close the channel's receive end , the thread can't receive item from channel
// close the channel's receive end , the thread can't receive item from
// channel
void
CloseReceiveEnd
();
private:
...
...
@@ -40,9 +41,7 @@ class Channel final {
template
<
typename
T
>
int
Channel
<
T
>::
Send
(
const
T
&
item
)
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
if
(
is_send_closed_
)
{
return
-
1
;
}
if
(
is_send_closed_
)
{
return
-
1
;
}
val_
.
push
(
item
);
cond_
.
notify_one
();
return
0
;
...
...
@@ -51,10 +50,10 @@ int Channel<T>::Send(const T& item) {
template
<
typename
T
>
int
Channel
<
T
>::
Receive
(
T
*
item
)
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
cond_
.
wait
(
lock
,
[
this
]()
{
return
!
val_
.
empty
()
||
is_receive_closed_
||
is_send_closed_
;
});
if
(
val_
.
empty
()
||
is_receive_closed_
)
{
return
-
1
;
}
cond_
.
wait
(
lock
,
[
this
]()
{
return
!
val_
.
empty
()
||
is_receive_closed_
||
is_send_closed_
;
})
;
if
(
val_
.
empty
()
||
is_receive_closed_
)
{
return
-
1
;
}
*
item
=
val_
.
front
();
val_
.
pop
();
return
0
;
...
...
@@ -76,4 +75,4 @@ void Channel<T>::CloseReceiveEnd() {
}
// namespace oneflow
#endif // ONEFLOW_CORE_COMMON_CHANNEL_H_
#endif
// ONEFLOW_CORE_COMMON_CHANNEL_H_
oneflow/core/common/channel_test.cpp
浏览文件 @
03db51d7
...
...
@@ -5,22 +5,16 @@ namespace oneflow {
void
CallFromSenderThread
(
Channel
<
int
>*
channel
,
Range
range
)
{
for
(
int
i
=
range
.
begin
();
i
<
range
.
end
();
++
i
)
{
if
(
channel
->
Send
(
i
)
==
-
1
)
{
break
;
}
if
(
channel
->
Send
(
i
)
==
-
1
)
{
break
;
}
}
}
void
CallFromReceiverThread
(
std
::
vector
<
int
>*
visit
,
Channel
<
int
>*
channel
)
{
void
CallFromReceiverThread
(
std
::
vector
<
int
>*
visit
,
Channel
<
int
>*
channel
)
{
int
num
=
-
1
;
int
*
num_ptr
=
&
num
;
while
(
channel
->
Receive
(
num_ptr
)
==
0
)
{
++
visit
->
at
(
*
num_ptr
);
}
while
(
channel
->
Receive
(
num_ptr
)
==
0
)
{
++
visit
->
at
(
*
num_ptr
);
}
}
TEST
(
Channel
,
30s
ender40receiver
)
{
Channel
<
int
>
channel
;
std
::
vector
<
std
::
thread
>
senders
;
...
...
@@ -31,34 +25,24 @@ TEST(Channel, 30sender40receiver) {
std
::
vector
<
std
::
vector
<
int
>>
visits
;
for
(
int
i
=
0
;
i
<
receiver_num
;
++
i
)
{
std
::
vector
<
int
>
visit_i
;
for
(
int
j
=
0
;
j
<
range_num
;
j
++
)
{
visit_i
.
push_back
(
0
);
}
for
(
int
j
=
0
;
j
<
range_num
;
j
++
)
{
visit_i
.
push_back
(
0
);
}
visits
.
push_back
(
visit_i
);
}
for
(
int
i
=
0
;
i
<
sender_num
;
++
i
)
{
senders
.
push_back
(
std
::
thread
(
CallFromSenderThread
,
&
channel
,
Range
(
0
,
range_num
)));
senders
.
push_back
(
std
::
thread
(
CallFromSenderThread
,
&
channel
,
Range
(
0
,
range_num
)));
}
for
(
int
i
=
0
;
i
<
receiver_num
;
++
i
)
{
receivers
.
push_back
(
std
::
thread
(
CallFromReceiverThread
,
&
visits
[
i
],
&
channel
));
}
for
(
std
::
thread
&
this_thread
:
senders
)
{
this_thread
.
join
();
receivers
.
push_back
(
std
::
thread
(
CallFromReceiverThread
,
&
visits
[
i
],
&
channel
));
}
for
(
std
::
thread
&
this_thread
:
senders
)
{
this_thread
.
join
();
}
channel
.
CloseSendEnd
();
for
(
std
::
thread
&
this_thread
:
receivers
)
{
this_thread
.
join
();
}
for
(
std
::
thread
&
this_thread
:
receivers
)
{
this_thread
.
join
();
}
channel
.
CloseReceiveEnd
();
for
(
int
i
=
0
;
i
<
range_num
;
++
i
)
{
int
visit_count
=
0
;
for
(
int
j
=
0
;
j
<
receiver_num
;
j
++
)
{
visit_count
+=
visits
[
j
][
i
];
}
for
(
int
j
=
0
;
j
<
receiver_num
;
j
++
)
{
visit_count
+=
visits
[
j
][
i
];
}
ASSERT_EQ
(
visit_count
,
sender_num
);
}
}
...
...
oneflow/core/common/cuda_stream_handle.cpp
浏览文件 @
03db51d7
...
...
@@ -29,15 +29,9 @@ const cudnnHandle_t* CudaStreamHandle::cudnn_handle() {
}
CudaStreamHandle
::~
CudaStreamHandle
()
{
if
(
cudnn_handle_
)
{
CHECK_EQ
(
cudnnDestroy
(
*
cudnn_handle_
),
0
);
}
if
(
cublas_handle_
)
{
CHECK_EQ
(
cublasDestroy
(
*
cublas_handle_
),
0
);
}
if
(
cuda_stream_
)
{
CHECK_EQ
(
cudaStreamDestroy
(
*
cuda_stream_
),
0
);
}
if
(
cudnn_handle_
)
{
CHECK_EQ
(
cudnnDestroy
(
*
cudnn_handle_
),
0
);
}
if
(
cublas_handle_
)
{
CHECK_EQ
(
cublasDestroy
(
*
cublas_handle_
),
0
);
}
if
(
cuda_stream_
)
{
CHECK_EQ
(
cudaStreamDestroy
(
*
cuda_stream_
),
0
);
}
}
}
// namespace oneflow
}
// namespace oneflow
oneflow/core/common/cuda_stream_handle.h
浏览文件 @
03db51d7
...
...
@@ -20,9 +20,8 @@ class CudaStreamHandle final {
std
::
unique_ptr
<
cudaStream_t
>
cuda_stream_
;
std
::
unique_ptr
<
cublasHandle_t
>
cublas_handle_
;
std
::
unique_ptr
<
cudnnHandle_t
>
cudnn_handle_
;
};
}
// namespace oneflow
}
// namespace oneflow
#endif // ONEFLOW_CORE_COMMON_CUDA_STREAM_HANDLE_H_
#endif
// ONEFLOW_CORE_COMMON_CUDA_STREAM_HANDLE_H_
oneflow/core/common/cuda_util.h
浏览文件 @
03db51d7
...
...
@@ -10,21 +10,18 @@ inline void CudaCheck(cudaError_t error) {
}
// CUDA: grid stride looping
#define CUDA_1D_KERNEL_LOOP(i, n) \
for (int32_t i = blockIdx.x * blockDim.x + threadIdx.x; \
i < (n); \
#define CUDA_1D_KERNEL_LOOP(i, n) \
for (int32_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
i += blockDim.x * gridDim.x)
// CUDA: check for error after kernel execution and exit loudly if there is one.
inline
void
CudaPostKernelCheck
()
{
CudaCheck
(
cudaPeekAtLastError
());
}
inline
void
CudaPostKernelCheck
()
{
CudaCheck
(
cudaPeekAtLastError
());
}
const
int32_t
kCudaThreadsNumPerBlock
=
512
;
const
int32_t
kCudaMaxBlocksNum
=
4096
;
inline
int32_t
BlocksNum4ThreadsNum
(
const
int32_t
N
)
{
return
std
::
min
((
N
+
kCudaThreadsNumPerBlock
-
1
)
/
kCudaThreadsNumPerBlock
,
return
std
::
min
((
N
+
kCudaThreadsNumPerBlock
-
1
)
/
kCudaThreadsNumPerBlock
,
kCudaMaxBlocksNum
);
}
...
...
oneflow/core/common/process_state.h
浏览文件 @
03db51d7
...
...
@@ -2,9 +2,9 @@
#define ONEFLOW_CORE_COMMON_PROCESS_STATE_H_
#if defined(_MSC_VER)
#include <WinSock2.h>
#include <direct.h>
#include <stdlib.h>
#include <WinSock2.h>
#pragma comment(lib, "Ws2_32.lib")
#else
#include <unistd.h>
...
...
@@ -32,4 +32,3 @@ std::string GetCwd() {
}
// namespace oneflow
#endif // ONEFLOW_CORE_COMMON_PROCESS_STATE_H_
oneflow/core/common/protobuf.cpp
浏览文件 @
03db51d7
...
...
@@ -35,8 +35,8 @@ void ParseProtoFromTextFile(const std::string& file_path, PbMessage* proto) {
}
void
PrintProtoToTextFile
(
const
PbMessage
&
proto
,
const
std
::
string
&
file_path
)
{
std
::
ofstream
out_stream
(
file_path
.
c_str
(),
std
::
ofstream
::
out
|
std
::
ofstream
::
trunc
);
std
::
ofstream
out_stream
(
file_path
.
c_str
(),
std
::
ofstream
::
out
|
std
::
ofstream
::
trunc
);
// make sure out_stream lives longer than output
{
OstreamOutputStream
output
(
&
out_stream
);
...
...
@@ -45,15 +45,15 @@ void PrintProtoToTextFile(const PbMessage& proto,
out_stream
.
close
();
}
#define DEFINE_GET_VAL_FROM_PBMESSAGE(ret_type, func_name) \
ret_type Get##func_name##FromPbMessage(const PbMessage& msg,
\
const std::string& field_name) { \
const Descriptor* d = msg.GetDescriptor();
\
const FieldDescriptor* fd = d->FindFieldByName(field_name);
\
CHECK_NOTNULL(fd);
\
const Reflection* r = msg.GetReflection();
\
return r->Get##func_name (msg, fd);
\
}
#define DEFINE_GET_VAL_FROM_PBMESSAGE(ret_type, func_name)
\
ret_type Get##func_name##FromPbMessage(const PbMessage& msg,
\
const std::string& field_name) { \
const Descriptor* d = msg.GetDescriptor();
\
const FieldDescriptor* fd = d->FindFieldByName(field_name);
\
CHECK_NOTNULL(fd);
\
const Reflection* r = msg.GetReflection();
\
return r->Get##func_name(msg, fd);
\
}
DEFINE_GET_VAL_FROM_PBMESSAGE
(
std
::
string
,
String
);
DEFINE_GET_VAL_FROM_PBMESSAGE
(
int32_t
,
Int32
);
...
...
@@ -61,4 +61,4 @@ DEFINE_GET_VAL_FROM_PBMESSAGE(uint32_t, UInt32);
DEFINE_GET_VAL_FROM_PBMESSAGE
(
int64_t
,
Int64
);
DEFINE_GET_VAL_FROM_PBMESSAGE
(
uint64_t
,
UInt64
);
}
// namespace oneflow
}
// namespace oneflow
oneflow/core/common/protobuf.h
浏览文件 @
03db51d7
...
...
@@ -4,10 +4,10 @@
#ifdef _MSC_VER
#include <io.h>
#endif
#include "
oneflow/core/common/util
.h"
#include "
google/protobuf/descriptor
.h"
#include "google/protobuf/map.h"
#include "google/protobuf/message.h"
#include "
google/protobuf/descriptor
.h"
#include "
oneflow/core/common/util
.h"
namespace
oneflow
{
...
...
@@ -24,16 +24,14 @@ void ParseProtoFromString(const std::string& str, PbMessage* proto);
void
PrintProtoToString
(
const
PbMessage
&
proto
,
std
::
string
*
str
);
// Prototxt <-> File
void
ParseProtoFromTextFile
(
const
std
::
string
&
file_path
,
PbMessage
*
proto
);
void
PrintProtoToTextFile
(
const
PbMessage
&
proto
,
const
std
::
string
&
file_path
);
void
ParseProtoFromTextFile
(
const
std
::
string
&
file_path
,
PbMessage
*
proto
);
void
PrintProtoToTextFile
(
const
PbMessage
&
proto
,
const
std
::
string
&
file_path
);
// Get From PbMessage
#define DECLARE_GET_VAL_FROM_PBMESSAGE(ret_type, func_name) \
ret_type Get##func_name##FromPbMessage(const PbMessage& msg, \
const std::string& field_name);
#define DECLARE_GET_VAL_FROM_PBMESSAGE(ret_type, func_name)
\
ret_type Get##func_name##FromPbMessage(const PbMessage& msg, \
const std::string& field_name);
DECLARE_GET_VAL_FROM_PBMESSAGE
(
std
::
string
,
String
);
DECLARE_GET_VAL_FROM_PBMESSAGE
(
int32_t
,
Int32
);
...
...
@@ -45,8 +43,7 @@ DECLARE_GET_VAL_FROM_PBMESSAGE(uint64_t, UInt64);
// Alias PbType
#define ALIAS_PB_TYPE(type, name) \
using Pb##name = google::protobuf::type; \
#define ALIAS_PB_TYPE(type, name) using Pb##name = google::protobuf::type;
ALIAS_PB_TYPE
(
int32
,
Int32
);
ALIAS_PB_TYPE
(
int64
,
Int64
);
...
...
@@ -55,13 +52,11 @@ ALIAS_PB_TYPE(uint64, UInt64);
#undef ALIAS_PB_TYPE
// PbRpf <-> std::vector
inline
std
::
vector
<
std
::
string
>
PbVec2StdVec
(
const
PbRpf
<
std
::
string
>&
rpf
)
{
return
std
::
vector
<
std
::
string
>
(
rpf
.
begin
(),
rpf
.
end
());
// PbRpf <-> std::vector
inline
std
::
vector
<
std
::
string
>
PbVec2StdVec
(
const
PbRpf
<
std
::
string
>&
rpf
)
{
return
std
::
vector
<
std
::
string
>
(
rpf
.
begin
(),
rpf
.
end
());
}
inline
PbRpf
<
std
::
string
>
StdVec2PbVec
(
const
std
::
vector
<
std
::
string
>&
vec
)
{
inline
PbRpf
<
std
::
string
>
StdVec2PbVec
(
const
std
::
vector
<
std
::
string
>&
vec
)
{
using
RetType
=
PbRpf
<
std
::
string
>
;
return
RetType
(
vec
.
begin
(),
vec
.
end
());
}
...
...
@@ -69,7 +64,7 @@ inline PbRpf<std::string> StdVec2PbVec (
// ProtoMap <-> HashMap
template
<
typename
K
,
typename
V
>
HashMap
<
K
,
V
>
PbMap2HashMap
(
const
google
::
protobuf
::
Map
<
K
,
V
>&
pb_map
)
{
return
HashMap
<
K
,
V
>
(
pb_map
.
begin
(),
pb_map
.
end
());
return
HashMap
<
K
,
V
>
(
pb_map
.
begin
(),
pb_map
.
end
());
}
template
<
typename
K
,
typename
V
>
...
...
@@ -79,16 +74,16 @@ google::protobuf::Map<K, V> HashMap2PbMap(const HashMap<K, V>& hash_map) {
}
// operator
inline
bool
operator
==
(
const
google
::
protobuf
::
MessageLite
&
lhs
,
const
google
::
protobuf
::
MessageLite
&
rhs
)
{
inline
bool
operator
==
(
const
google
::
protobuf
::
MessageLite
&
lhs
,
const
google
::
protobuf
::
MessageLite
&
rhs
)
{
return
lhs
.
SerializeAsString
()
==
rhs
.
SerializeAsString
();
}
inline
bool
operator
!=
(
const
google
::
protobuf
::
MessageLite
&
lhs
,
const
google
::
protobuf
::
MessageLite
&
rhs
)
{
inline
bool
operator
!=
(
const
google
::
protobuf
::
MessageLite
&
lhs
,
const
google
::
protobuf
::
MessageLite
&
rhs
)
{
return
!
(
lhs
==
rhs
);
}
}
// namespace caffe
}
// namespace oneflow
#endif // ONEFLOW_CORE_COMMON_PROTOBUF_H_
#endif
// ONEFLOW_CORE_COMMON_PROTOBUF_H_
oneflow/core/common/range.h
浏览文件 @
03db51d7
...
...
@@ -13,13 +13,13 @@ class Range final {
Range
(
int64_t
begin
,
int64_t
end
)
:
begin_
(
begin
),
end_
(
end
)
{}
bool
operator
==
(
const
Range
&
rhs
)
const
{
bool
operator
==
(
const
Range
&
rhs
)
const
{
return
begin_
==
rhs
.
begin_
&&
end_
==
rhs
.
end_
;
}
int64_t
begin
()
const
{
return
begin_
;
}
int64_t
end
()
const
{
return
end_
;
}
int64_t
&
mut_begin
()
{
return
begin_
;
}
int64_t
&
mut_end
()
{
return
end_
;
}
...
...
@@ -30,6 +30,6 @@ class Range final {
int64_t
end_
;
};
}
// namespace oneflow
}
// namespace oneflow
#endif // ONEFLOW_CORE_COMMON_RANGE_H_
#endif
// ONEFLOW_CORE_COMMON_RANGE_H_
oneflow/core/common/shape.cpp
浏览文件 @
03db51d7
...
...
@@ -9,28 +9,23 @@ Shape::Shape(const ShapeProto& shape_proto) {
}
void
Shape
::
ToProto
(
ShapeProto
*
ret
)
const
{
*
(
ret
->
mutable_dim
())
=
PbRf
<
PbInt64
>
(
dim_vec_
.
begin
(),
dim_vec_
.
end
());
*
(
ret
->
mutable_dim
())
=
PbRf
<
PbInt64
>
(
dim_vec_
.
begin
(),
dim_vec_
.
end
());
}
std
::
string
Shape
::
DebugStr
()
const
{
std
::
stringstream
ss
;
ss
<<
"{"
;
for
(
int64_t
dim
:
dim_vec_
)
{
ss
<<
dim
<<
","
;
}
for
(
int64_t
dim
:
dim_vec_
)
{
ss
<<
dim
<<
","
;
}
ss
<<
"("
<<
elem_cnt_
<<
")}"
;
return
ss
.
str
();
}
int64_t
Shape
::
Count
(
int64_t
begin_axis
,
int64_t
end_axis
)
const
{
CHECK
(
0
<=
begin_axis
&&
begin_axis
<=
end_axis
&&
end_axis
<=
NumAxes
())
<<
"[begin_axis:"
<<
begin_axis
<<
"][end_axis:"
<<
end_axis
<<
"[begin_axis:"
<<
begin_axis
<<
"][end_axis:"
<<
end_axis
<<
"][num_axes:"
<<
NumAxes
()
<<
"]"
;
int64_t
cnt
=
1
;
for
(
int64_t
i
=
begin_axis
;
i
<
end_axis
;
++
i
)
{
cnt
*=
At
(
i
);
}
for
(
int64_t
i
=
begin_axis
;
i
<
end_axis
;
++
i
)
{
cnt
*=
At
(
i
);
}
return
cnt
;
}
...
...
@@ -42,17 +37,13 @@ int64_t Shape::CanonicalAxisIndex(int64_t axis_index) const {
void
Shape
::
UpdateElemCnt
()
{
elem_cnt_
=
1
;
for
(
int64_t
s
:
dim_vec_
)
{
elem_cnt_
*=
s
;
}
if
(
dim_vec_
.
size
()
==
0
)
{
elem_cnt_
=
0
;
}
for
(
int64_t
s
:
dim_vec_
)
{
elem_cnt_
*=
s
;
}
if
(
dim_vec_
.
size
()
==
0
)
{
elem_cnt_
=
0
;
}
}
std
::
ostream
&
operator
<<
(
std
::
ostream
&
out
,
const
Shape
&
shape
)
{
std
::
ostream
&
operator
<<
(
std
::
ostream
&
out
,
const
Shape
&
shape
)
{
out
<<
shape
.
DebugStr
();
return
out
;
}
}
// namespace oneflow
}
// namespace oneflow
oneflow/core/common/shape.h
浏览文件 @
03db51d7
#ifndef ONEFLOW_CORE_COMMON_SHAPE_H_
#define ONEFLOW_CORE_COMMON_SHAPE_H_
#include "oneflow/core/common/util.h"
#include "oneflow/core/common/shape.pb.h"
#include "oneflow/core/common/util.h"
namespace
oneflow
{
...
...
@@ -13,8 +13,8 @@ class Shape final {
explicit
Shape
(
const
std
::
vector
<
int64_t
>&
dim_vec
);
Shape
(
const
ShapeProto
&
shape_proto
);
~
Shape
()
=
default
;
bool
operator
==
(
const
Shape
&
rhs
)
const
;
bool
operator
==
(
const
Shape
&
rhs
)
const
;
std
::
string
DebugStr
()
const
;
void
ToProto
(
ShapeProto
*
)
const
;
...
...
@@ -34,17 +34,15 @@ class Shape final {
std
::
vector
<
int64_t
>
dim_vec_
;
int64_t
elem_cnt_
;
};
std
::
ostream
&
operator
<<
(
std
::
ostream
&
out
,
const
Shape
&
shape
);
std
::
ostream
&
operator
<<
(
std
::
ostream
&
out
,
const
Shape
&
shape
);
inline
Shape
::
Shape
(
const
std
::
vector
<
int64_t
>&
dim_vec
)
:
dim_vec_
(
dim_vec
)
{
inline
Shape
::
Shape
(
const
std
::
vector
<
int64_t
>&
dim_vec
)
:
dim_vec_
(
dim_vec
)
{
UpdateElemCnt
();
}
inline
bool
Shape
::
operator
==
(
const
Shape
&
rhs
)
const
{
inline
bool
Shape
::
operator
==
(
const
Shape
&
rhs
)
const
{
return
dim_vec_
==
rhs
.
dim_vec_
;
}
...
...
@@ -61,6 +59,6 @@ inline int64_t Shape::Count(int64_t begin_axis) const {
return
Count
(
begin_axis
,
NumAxes
());
}
}
// namespace oneflow
}
// namespace oneflow
#endif // ONEFLOW_CORE_COMMON_SHAPE_H_
#endif
// ONEFLOW_CORE_COMMON_SHAPE_H_
oneflow/core/common/util.cpp
浏览文件 @
03db51d7
...
...
@@ -31,8 +31,7 @@ uint64_t oneflow_cast(const std::string& s) {
return
ret
;
}
void
Split
(
const
std
::
string
&
text
,
const
std
::
string
&
delims
,
void
Split
(
const
std
::
string
&
text
,
const
std
::
string
&
delims
,
std
::
function
<
void
(
std
::
string
&&
)
>
Func
)
{
size_t
token_start
=
0
;
if
(
text
.
empty
())
{
return
;
}
...
...
@@ -44,4 +43,4 @@ void Split(const std::string& text,
}
}
}
// namespace oneflow
}
// namespace oneflow
oneflow/core/common/util.h
浏览文件 @
03db51d7
#ifndef ONEFLOW_CORE_COMMON_UTIL_H_
#define ONEFLOW_CORE_COMMON_UTIL_H_
#include <unordered_set>
#include <unordered_map>
#include <functional>
#include <algorithm>
#include <mutex>
#include <utility>
#include <memory>
#include <thread>
#include <list>
#include <condition_variable>
#include <atomic>
#include <
queu
e>
#include <
condition_variabl
e>
#include <fstream>
#include <functional>
#include <iostream>
#include "glog/logging.h"
#include "gtest/gtest.h"
#include <list>
#include <memory>
#include <mutex>
#include <queue>
#include <thread>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include "cublas_v2.h"
#include "cuda.h"
#include "cuda_runtime.h"
#include "cublas_v2.h"
#include "cudnn.h"
#include "glog/logging.h"
#include "gtest/gtest.h"
namespace
oneflow
{
#define OF_DISALLOW_COPY(ClassName) \
#define OF_DISALLOW_COPY(ClassName)
\
ClassName(const ClassName&) = delete; \
ClassName& operator
=
(const ClassName&) = delete;
ClassName& operator
=
(const ClassName&) = delete;
#define OF_DISALLOW_MOVE(ClassName) \
ClassName(ClassName&&) = delete; \
ClassName& operator
=
(ClassName&&) = delete;
ClassName(ClassName&&) = delete;
\
ClassName& operator
=
(ClassName&&) = delete;
#define OF_DISALLOW_COPY_AND_MOVE(ClassName) \
OF_DISALLOW_COPY(ClassName) \
OF_DISALLOW_COPY(ClassName)
\
OF_DISALLOW_MOVE(ClassName)
#define UNEXPECTED_RUN() \
LOG(FATAL) << "Unexpected Run";
#define UNEXPECTED_RUN() LOG(FATAL) << "Unexpected Run";
#define TODO() \
LOG(FATAL) << "TODO";
#define TODO() LOG(FATAL) << "TODO";
#define OF_SINGLETON(ClassName) \
#define OF_SINGLETON(ClassName)
\
static ClassName& Singleton() { \
static ClassName obj; \
return obj; \
static ClassName obj;
\
return obj;
\
}
template
<
typename
T
>
bool
operator
==
(
const
std
::
weak_ptr
<
T
>&
lhs
,
const
std
::
weak_ptr
<
T
>&
rhs
)
{
bool
operator
==
(
const
std
::
weak_ptr
<
T
>&
lhs
,
const
std
::
weak_ptr
<
T
>&
rhs
)
{
return
lhs
.
lock
().
get
()
==
rhs
.
lock
().
get
();
}
...
...
@@ -83,9 +81,7 @@ inline std::string LogDir() {
inline
void
str_replace
(
std
::
string
*
str
,
char
old_ch
,
char
new_ch
)
{
for
(
size_t
i
=
0
;
i
<
str
->
size
();
++
i
)
{
if
(
str
->
at
(
i
)
==
old_ch
)
{
str
->
at
(
i
)
=
new_ch
;
}
if
(
str
->
at
(
i
)
==
old_ch
)
{
str
->
at
(
i
)
=
new_ch
;
}
}
}
...
...
@@ -102,30 +98,26 @@ void EraseIf(HashMap<K, V>* hash_map,
}
#define OF_DECLARE_ENUM_TO_OSTREAM_FUNC(EnumType) \
std::ostream& operator <<
(std::ostream& out_stream, const EnumType&)
std::ostream& operator<<
(std::ostream& out_stream, const EnumType&)
#define OF_DEFINE_ENUM_TO_OSTREAM_FUNC(EnumType) \
std::ostream& operator <<
(std::ostream& out_stream, const EnumType& x) { \
out_stream << static_cast<int> (x);
\
return out_stream;
\
}
#define OF_DEFINE_ENUM_TO_OSTREAM_FUNC(EnumType)
\
std::ostream& operator<<
(std::ostream& out_stream, const EnumType& x) { \
out_stream << static_cast<int>(x);
\
return out_stream;
\
}
template
<
typename
OutType
,
typename
InType
>
OutType
oneflow_cast
(
const
InType
&
);
void
Split
(
const
std
::
string
&
text
,
const
std
::
string
&
delims
,
void
Split
(
const
std
::
string
&
text
,
const
std
::
string
&
delims
,
std
::
function
<
void
(
std
::
string
&&
)
>
Func
);
template
<
typename
T
>
void
SplitAndParseAs
(
const
std
::
string
&
text
,
const
std
::
string
&
delims
,
void
SplitAndParseAs
(
const
std
::
string
&
text
,
const
std
::
string
&
delims
,
std
::
function
<
void
(
T
&&
)
>
Func
)
{
Split
(
text
,
delims
,
[
&
Func
](
std
::
string
&&
s
)
{
Func
(
oneflow_cast
<
T
>
(
s
));
});
Split
(
text
,
delims
,
[
&
Func
](
std
::
string
&&
s
)
{
Func
(
oneflow_cast
<
T
>
(
s
));
});
}
}
// namespace oneflow
}
// namespace oneflow
#endif // ONEFLOW_CORE_COMMON_UTIL_H_
#endif
// ONEFLOW_CORE_COMMON_UTIL_H_
oneflow/core/graph/boxing_task_node.cpp
浏览文件 @
03db51d7
#include "oneflow/core/graph/boxing_task_node.h"
#include "oneflow/core/operator/operator_manager.h"
#include "oneflow/core/operator/boxing_op.h"
#include "oneflow/core/operator/operator_manager.h"
namespace
oneflow
{
...
...
@@ -31,7 +31,7 @@ void FwCompleteBoxOpConfFakerMdUpdt(BoxingOpConf* conf) {
conf
->
mutable_clone_box
();
}
}
// namespace
}
// namespace
void
BoxingTaskNode
::
FwBuildExecAndEnrollLbn2Regsts
(
TaskGraph
*
gph
)
{
EnrollAllRegstAndBindRelatedEdge
();
...
...
@@ -68,14 +68,14 @@ void BoxingTaskNode::FwInitChain2SortedEdgesMaps(
}
for
(
auto
&
pair
:
*
chain2sorted_edges
)
{
std
::
vector
<
const
TaskEdge
*>&
edges
=
pair
.
second
;
std
::
sort
(
edges
.
begin
(),
edges
.
end
(),
[
&
edge2stage
](
const
TaskEdge
*
lhs
,
const
TaskEdge
*
rhs
)
{
const
StageNode
*
lhs_stage
=
edge2stage
.
at
(
lhs
);
const
StageNode
*
rhs_stage
=
edge2stage
.
at
(
rhs
);
CHECK
(
lhs_stage
->
chain_node
()
==
rhs_stage
->
chain_node
());
return
lhs_stage
->
parallel_range
().
begin
()
<
rhs_stage
->
parallel_range
().
begin
();
});
std
::
sort
(
edges
.
begin
(),
edges
.
end
(),
[
&
edge2stage
](
const
TaskEdge
*
lhs
,
const
TaskEdge
*
rhs
)
{
const
StageNode
*
lhs_stage
=
edge2stage
.
at
(
lhs
);
const
StageNode
*
rhs_stage
=
edge2stage
.
at
(
rhs
);
CHECK
(
lhs_stage
->
chain_node
()
==
rhs_stage
->
chain_node
());
return
lhs_stage
->
parallel_range
().
begin
()
<
rhs_stage
->
parallel_range
().
begin
();
});
}
}
...
...
@@ -91,12 +91,12 @@ void BoxingTaskNode::FwSortEdgesInnerStage(
}
return
ret
;
};
std
::
sort
(
edges_to_be_sorted
->
begin
(),
edges_to_be_sorted
->
end
(),
[
&
]
(
const
TaskEdge
*
lhs
,
const
TaskEdge
*
rhs
)
{
const
CompTaskNode
*
lhs_node
=
GetPredSuccCompTaskNode
(
lhs
);
const
CompTaskNode
*
rhs_node
=
GetPredSuccCompTaskNode
(
rhs
);
return
lhs_node
->
parallel_id
()
<
rhs_node
->
parallel_id
();
});
std
::
sort
(
edges_to_be_sorted
->
begin
(),
edges_to_be_sorted
->
end
(),
[
&
]
(
const
TaskEdge
*
lhs
,
const
TaskEdge
*
rhs
)
{
const
CompTaskNode
*
lhs_node
=
GetPredSuccCompTaskNode
(
lhs
);
const
CompTaskNode
*
rhs_node
=
GetPredSuccCompTaskNode
(
rhs
);
return
lhs_node
->
parallel_id
()
<
rhs_node
->
parallel_id
();
});
}
void
BoxingTaskNode
::
FwBuildChainSortedEdgesPair
(
...
...
@@ -141,9 +141,8 @@ void BoxingTaskNode::FwBuildChainSortedEdgesPair(
CHECK_EQ
(
lbns
.
size
(),
1
);
lbns
.
clear
();
auto
in_regst_0
=
GetRelatedRegst
(
sorted_in_edges
.
at
(
0
));
in_regst_0
->
ForEachLbn
([
&
](
const
std
::
string
&
lbn
)
{
lbns
.
push_back
(
lbn
);
});
in_regst_0
->
ForEachLbn
(
[
&
](
const
std
::
string
&
lbn
)
{
lbns
.
push_back
(
lbn
);
});
}
// Enroll Lbn
auto
middle_regst
=
GetProducedRegstDesc
(
"middle"
);
...
...
@@ -173,11 +172,9 @@ void BoxingTaskNode::FwBuildChainSortedEdgesPair(
void
BoxingTaskNode
::
FwInferShapeOfBlobsInProducedRegsts
(
TaskGraph
*
)
{
exec_gph
().
ConstForEachNode
([
this
](
const
ExecNode
*
exec_node
)
{
exec_node
->
op
()
->
InferShape4FwBlobs
(
exec_node
->
GetMutShapePtr4BnInOpFunc
(),
chain_node
()
->
parallel_desc
()
->
policy
(),
0
,
0
);
exec_node
->
op
()
->
InferShape4FwBlobs
(
exec_node
->
GetMutShapePtr4BnInOpFunc
(),
chain_node
()
->
parallel_desc
()
->
policy
(),
0
,
0
);
});
}
...
...
@@ -191,7 +188,7 @@ std::shared_ptr<RegstDesc> GetBpRegstFromFwRegst(
return
GetRelatedRegst
(
bp_edge
);
}
}
}
// namespace
void
BoxingTaskNode
::
BpBuildExecAndEnrollLbn2Regsts
(
TaskGraph
*
)
{
EnrollAllRegstAndBindRelatedEdge
();
...
...
@@ -231,7 +228,7 @@ void BoxingTaskNode::BpBuildExecAndEnrollLbn2Regsts(TaskGraph*) {
});
mut_exec_gph
().
UpdateSourceAndSink
();
}
void
BoxingTaskNode
::
BpInferShapeOfBlobsInProducedRegsts
(
TaskGraph
*
)
{
for
(
TaskEdge
*
fw_in_edge
:
GetFwNode
()
->
in_edges
())
{
auto
in_regst
=
GetRelatedRegst
(
fw_in_edge
);
...
...
@@ -240,8 +237,8 @@ void BoxingTaskNode::BpInferShapeOfBlobsInProducedRegsts(TaskGraph*) {
}
}
auto
fw_middle_regst
=
GetFwNode
()
->
GetProducedRegstDesc
(
"middle"
);
auto
bp_middle_regst
=
GetProducedRegstDesc
(
"middle"
);
auto
bp_middle_regst
=
GetProducedRegstDesc
(
"middle"
);
bp_middle_regst
->
CopyShapeFrom
(
fw_middle_regst
.
get
());
}
}
// namespace oneflow
}
// namespace oneflow
oneflow/core/graph/boxing_task_node.h
浏览文件 @
03db51d7
...
...
@@ -10,20 +10,18 @@ class BoxingTaskNode : public TaskNode {
OF_DISALLOW_COPY_AND_MOVE
(
BoxingTaskNode
);
BoxingTaskNode
()
=
default
;
virtual
~
BoxingTaskNode
()
=
default
;
std
::
string
VisualStr
()
const
override
{
return
TaskNode
::
VisualStr
()
+
"Boxing"
;
}
void
ToProto
(
TaskProto
*
ret
)
const
override
{
TaskNode
::
ToProto
(
ret
);
};
void
ToProto
(
TaskProto
*
ret
)
const
override
{
TaskNode
::
ToProto
(
ret
);
};
protected:
virtual
void
InitWithFwNode
(
TaskNode
*
fw_node
)
override
{
TaskNode
::
InitWithFwNode
(
fw_node
);
}
using
ChainEdgesPair
=
std
::
pair
<
const
ChainNode
*
,
std
::
vector
<
const
TaskEdge
*>>
;
using
Chain2EdgesMap
=
...
...
@@ -33,10 +31,9 @@ class BoxingTaskNode : public TaskNode {
const
std
::
unordered_set
<
TaskEdge
*>&
(
TaskNode
::*
in_out_edges
)()
const
,
TaskNode
*
(
TaskEdge
::*
src_dst_node
)()
const
,
TaskEdge
*
(
TaskNode
::*
SoleEdge
)()
const
);
void
FwSortEdgesInnerStage
(
std
::
vector
<
const
TaskEdge
*>*
edges_to_be_sorted
,
TaskNode
*
(
TaskEdge
::*
src_dst_node
)()
const
,
TaskEdge
*
(
TaskNode
::*
SoleEdge
)()
const
);
void
FwSortEdgesInnerStage
(
std
::
vector
<
const
TaskEdge
*>*
edges_to_be_sorted
,
TaskNode
*
(
TaskEdge
::*
src_dst_node
)()
const
,
TaskEdge
*
(
TaskNode
::*
SoleEdge
)()
const
);
void
FwBuildChainSortedEdgesPair
(
const
ChainEdgesPair
&
chain_sorted_in_edges
,
const
ChainEdgesPair
&
chain_sorted_out_edges
);
...
...
@@ -50,12 +47,11 @@ class BoxingTaskNode : public TaskNode {
void
FwInferShapeOfBlobsInProducedRegsts
(
TaskGraph
*
);
void
BpBuildExecAndEnrollLbn2Regsts
(
TaskGraph
*
);
void
BpInferShapeOfBlobsInProducedRegsts
(
TaskGraph
*
);
void
EnrollAllRegstAndBindRelatedEdge
();
TaskType
task_type
()
const
override
{
return
kBoxingTask
;
}
};
}
// namespace oneflow
}
// namespace oneflow
#endif // ONEFLOW_CORE_GRAPH_BOXING_TASK_NODE_H_
#endif
// ONEFLOW_CORE_GRAPH_BOXING_TASK_NODE_H_
oneflow/core/graph/chain_graph.cpp
浏览文件 @
03db51d7
...
...
@@ -27,10 +27,8 @@ void SetChainNodeWithChainIt(ChainNode* chain_node, ChainIt chain_it) {
}
}
void
InitChains
(
const
LogicalGraph
&
logi_gph
,
std
::
list
<
Chain
>*
chain_list
,
Logical2ChainItMap
*
logical2chain_it
)
{
void
InitChains
(
const
LogicalGraph
&
logi_gph
,
std
::
list
<
Chain
>*
chain_list
,
Logical2ChainItMap
*
logical2chain_it
)
{
chain_list
->
clear
();
logical2chain_it
->
clear
();
logi_gph
.
ConstForEachNode
([
&
](
const
LogicalNode
*
node
)
{
...
...
@@ -82,9 +80,8 @@ void InitChains(
});
}
void
ModelMergeChains
(
std
::
list
<
Chain
>*
chain_list
,
Logical2ChainItMap
*
logical2chain_it
)
{
void
ModelMergeChains
(
std
::
list
<
Chain
>*
chain_list
,
Logical2ChainItMap
*
logical2chain_it
)
{
for
(
auto
&
pair
:
*
logical2chain_it
)
{
// Get cur_node, pred_node
const
LogicalNode
*
cur_node
=
pair
.
first
;
...
...
@@ -101,8 +98,7 @@ void ModelMergeChains(
ChainIt
pred_chain
=
logical2chain_it
->
at
(
pred_node
);
ChainIt
cur_chain
=
pair
.
second
;
// Merge
pred_chain
->
nodes
.
insert
(
pred_chain
->
nodes
.
end
(),
cur_chain
->
nodes
.
begin
(),
pred_chain
->
nodes
.
insert
(
pred_chain
->
nodes
.
end
(),
cur_chain
->
nodes
.
begin
(),
cur_chain
->
nodes
.
end
());
for
(
const
LogicalNode
*
node
:
cur_chain
->
nodes
)
{
pred_chain
->
descendants
.
erase
(
node
);
...
...
@@ -112,11 +108,10 @@ void ModelMergeChains(
}
}
bool
TryMergeWithConnect
(
const
LogicalNode
*
up_node
,
const
LogicalNode
*
bottom_node
,
std
::
list
<
Chain
>*
chain_list
,
Logical2ChainItMap
*
logical2chain_it
)
{
bool
TryMergeWithConnect
(
const
LogicalNode
*
up_node
,
const
LogicalNode
*
bottom_node
,
std
::
list
<
Chain
>*
chain_list
,
Logical2ChainItMap
*
logical2chain_it
)
{
// Get chain
ChainIt
up_chain
=
logical2chain_it
->
at
(
up_node
);
ChainIt
bottom_chain
=
logical2chain_it
->
at
(
bottom_node
);
...
...
@@ -146,11 +141,10 @@ bool TryMergeWithConnect(
return
true
;
}
bool
TryMergeWithoutConnect
(
const
LogicalNode
*
lhs_node
,
const
LogicalNode
*
rhs_node
,
std
::
list
<
Chain
>*
chain_list
,
Logical2ChainItMap
*
logical2chain_it
)
{
bool
TryMergeWithoutConnect
(
const
LogicalNode
*
lhs_node
,
const
LogicalNode
*
rhs_node
,
std
::
list
<
Chain
>*
chain_list
,
Logical2ChainItMap
*
logical2chain_it
)
{
// Get chain
ChainIt
lhs_chain
=
logical2chain_it
->
at
(
lhs_node
);
ChainIt
rhs_chain
=
logical2chain_it
->
at
(
rhs_node
);
...
...
@@ -170,11 +164,9 @@ bool TryMergeWithoutConnect(
return
true
;
}
bool
TryDataMerge
(
const
LogicalNode
*
first
,
const
LogicalNode
*
second
,
std
::
list
<
Chain
>*
chain_list
,
Logical2ChainItMap
*
logical2chain_it
)
{
bool
TryDataMerge
(
const
LogicalNode
*
first
,
const
LogicalNode
*
second
,
std
::
list
<
Chain
>*
chain_list
,
Logical2ChainItMap
*
logical2chain_it
)
{
if
(
first
->
parallel_desc
()
->
Equal
(
second
->
parallel_desc
().
get
())
==
false
)
{
return
false
;
}
...
...
@@ -186,10 +178,9 @@ bool TryDataMerge(
return
false
;
}
bool
DoOneDataMerge
(
const
std
::
vector
<
const
LogicalNode
*>&
data_parallel_node
,
std
::
list
<
Chain
>*
chain_list
,
Logical2ChainItMap
*
logical2chain_it
)
{
bool
DoOneDataMerge
(
const
std
::
vector
<
const
LogicalNode
*>&
data_parallel_node
,
std
::
list
<
Chain
>*
chain_list
,
Logical2ChainItMap
*
logical2chain_it
)
{
for
(
const
LogicalNode
*
first
:
data_parallel_node
)
{
for
(
const
LogicalNode
*
second
:
data_parallel_node
)
{
if
(
first
==
second
)
{
continue
;
}
...
...
@@ -204,10 +195,9 @@ bool DoOneDataMerge(
return
false
;
}
void
DataMergeChains
(
const
LogicalGraph
&
logical_gph
,
std
::
list
<
Chain
>*
chain_list
,
Logical2ChainItMap
*
logical2chain_it
)
{
void
DataMergeChains
(
const
LogicalGraph
&
logical_gph
,
std
::
list
<
Chain
>*
chain_list
,
Logical2ChainItMap
*
logical2chain_it
)
{
std
::
vector
<
const
LogicalNode
*>
data_parallel_node
;
for
(
const
auto
&
pair
:
*
logical2chain_it
)
{
const
LogicalNode
*
cur_logi_node
=
pair
.
first
;
...
...
@@ -215,17 +205,14 @@ void DataMergeChains(
if
(
cur_logi_node
->
IsLossNode
())
{
continue
;
}
data_parallel_node
.
push_back
(
cur_logi_node
);
}
while
(
DoOneDataMerge
(
data_parallel_node
,
chain_list
,
logical2chain_it
))
{
}
while
(
DoOneDataMerge
(
data_parallel_node
,
chain_list
,
logical2chain_it
))
{}
}
}
// namespace
}
// namespace
std
::
string
ChainNode
::
ConcatedOpsName
()
const
{
std
::
stringstream
ss
;
for
(
auto
op
:
op_vec_
)
{
ss
<<
"
\\
n"
<<
op
->
op_name
();
}
for
(
auto
op
:
op_vec_
)
{
ss
<<
"
\\
n"
<<
op
->
op_name
();
}
if
(
!
op_vec_
.
empty
())
{
return
ss
.
str
().
substr
(
2
);
}
else
{
...
...
@@ -252,19 +239,21 @@ ChainGraph::ChainGraph(const LogicalGraph* logical_gph) {
DataMergeChains
(
*
logical_gph
,
&
chain_list
,
&
logical2chain_it
);
// Init chain_nodes
auto
HashChainIt
=
[](
const
ChainIt
&
chain_it
)
{
return
std
::
hash
<
Chain
*>
()(
&
(
*
chain_it
));
return
std
::
hash
<
Chain
*>
()(
&
(
*
chain_it
));
};
HashMap
<
ChainIt
,
ChainNode
*
,
decltype
(
HashChainIt
)
>
chain_it2chain_node
(
11
,
HashChainIt
);
HashMap
<
ChainIt
,
ChainNode
*
,
decltype
(
HashChainIt
)
>
chain_it2chain_node
(
11
,
HashChainIt
);
HashMap
<
ChainNode
*
,
std
::
unordered_set
<
ChainNode
*>>
chain_node2pred
;
for
(
auto
chain_it
=
chain_list
.
begin
();
chain_it
!=
chain_list
.
end
();
++
chain_it
)
{
for
(
auto
chain_it
=
chain_list
.
begin
();
chain_it
!=
chain_list
.
end
();
++
chain_it
)
{
ChainNode
*
chain_node
=
NewNode
();
chain_it2chain_node
[
chain_it
]
=
chain_node
;
chain_node2pred
[
chain_node
]
=
{};
SetChainNodeWithChainIt
(
chain_node
,
chain_it
);
}
// Record the predecessor
for
(
auto
chain_it
=
chain_list
.
begin
();
chain_it
!=
chain_list
.
end
();
++
chain_it
)
{
for
(
auto
chain_it
=
chain_list
.
begin
();
chain_it
!=
chain_list
.
end
();
++
chain_it
)
{
ChainNode
*
chain_node
=
chain_it2chain_node
.
at
(
chain_it
);
for
(
const
LogicalNode
*
logi_node
:
chain_it
->
nodes
)
{
for
(
auto
logi_in_edge
:
logi_node
->
in_edges
())
{
...
...
@@ -324,14 +313,12 @@ void ChainGraph::SetInOutLbn4AllChainNodeInDataTaskGraph() {
});
}
std
::
vector
<
std
::
string
>
FindLbnsBetween
(
const
ChainNode
*
src_node
,
std
::
vector
<
std
::
string
>
FindLbnsBetween
(
const
ChainNode
*
src_node
,
const
ChainNode
*
dst_node
)
{
std
::
vector
<
std
::
string
>
matching_lbns
;
for
(
const
std
::
string
&
src_node_output_lbn
:
src_node
->
output_lbns
())
{
for
(
const
std
::
string
&
dst_node_input_lbn
:
dst_node
->
input_lbns
())
{
if
(
src_node_output_lbn
!=
dst_node_input_lbn
)
{
continue
;
}
for
(
const
std
::
string
&
dst_node_input_lbn
:
dst_node
->
input_lbns
())
{
if
(
src_node_output_lbn
!=
dst_node_input_lbn
)
{
continue
;
}
matching_lbns
.
push_back
(
src_node_output_lbn
);
break
;
}
...
...
@@ -343,10 +330,8 @@ std::vector<std::string> FindLbnsBetween(const ChainNode* src_node,
std
::
string
ChainEdge
::
VisualStr
()
const
{
std
::
vector
<
std
::
string
>
lbns
=
FindLbnsBetween
(
src_node
(),
dst_node
());
std
::
stringstream
ss
;
for
(
const
std
::
string
&
lbn
:
lbns
)
{
ss
<<
"
\\
n"
<<
lbn
;
}
for
(
const
std
::
string
&
lbn
:
lbns
)
{
ss
<<
"
\\
n"
<<
lbn
;
}
return
ss
.
str
().
substr
(
2
);
}
}
// namespace oneflow
}
// namespace oneflow
oneflow/core/graph/chain_graph.h
浏览文件 @
03db51d7
...
...
@@ -22,9 +22,7 @@ class ChainNode final : public Node<ChainNode, ChainEdge> {
const
std
::
vector
<
std
::
shared_ptr
<
const
Operator
>>&
op_vec
()
const
{
return
op_vec_
;
}
std
::
vector
<
std
::
shared_ptr
<
const
Operator
>>&
mut_op_vec
()
{
return
op_vec_
;
}
std
::
vector
<
std
::
shared_ptr
<
const
Operator
>>&
mut_op_vec
()
{
return
op_vec_
;
}
std
::
shared_ptr
<
const
ParallelDesc
>
parallel_desc
()
const
{
return
parallel_desc_
;
...
...
@@ -33,26 +31,18 @@ class ChainNode final : public Node<ChainNode, ChainEdge> {
return
parallel_desc_
;
}
const
std
::
vector
<
std
::
string
>&
input_lbns
()
const
{
return
input_lbns_
;
}
std
::
vector
<
std
::
string
>&
mut_input_lbns
()
{
return
input_lbns_
;
}
const
std
::
vector
<
std
::
string
>&
output_lbns
()
const
{
return
output_lbns_
;
}
std
::
vector
<
std
::
string
>&
mut_output_lbns
()
{
return
output_lbns_
;
}
const
std
::
vector
<
std
::
string
>&
input_lbns
()
const
{
return
input_lbns_
;
}
std
::
vector
<
std
::
string
>&
mut_input_lbns
()
{
return
input_lbns_
;
}
const
std
::
vector
<
std
::
string
>&
output_lbns
()
const
{
return
output_lbns_
;
}
std
::
vector
<
std
::
string
>&
mut_output_lbns
()
{
return
output_lbns_
;
}
bool
IsLossNode
()
const
{
return
op_vec_
.
size
()
==
1
&&
op_vec_
.
front
()
->
IsLossOp
();
}
std
::
string
VisualStr
()
const
{
return
ConcatedOpsName
();
}
bool
HasOpWithModelOrModelTmpBlob
()
const
;
private:
...
...
@@ -60,10 +50,8 @@ class ChainNode final : public Node<ChainNode, ChainEdge> {
std
::
shared_ptr
<
const
ParallelDesc
>
parallel_desc_
;
std
::
vector
<
std
::
string
>
input_lbns_
;
std
::
vector
<
std
::
string
>
output_lbns_
;
};
class
ChainEdge
final
:
public
Edge
<
ChainNode
,
ChainEdge
>
{
public:
OF_DISALLOW_COPY_AND_MOVE
(
ChainEdge
);
...
...
@@ -87,11 +75,10 @@ class ChainGraph final : public Graph<ChainNode, ChainEdge> {
private:
void
SetInOutLbn4AllChainNodeInDataTaskGraph
();
};
std
::
vector
<
std
::
string
>
FindLbnsBetween
(
const
ChainNode
*
,
const
ChainNode
*
);
}
// namespace oneflow
}
// namespace oneflow
#endif // ONEFLOW_CORE_GRAPH_CHAIN_GRAPH_H_
#endif
// ONEFLOW_CORE_GRAPH_CHAIN_GRAPH_H_
oneflow/core/graph/comp_task_node.cpp
浏览文件 @
03db51d7
#include "oneflow/core/graph/comp_task_node.h"
#include "oneflow/core/graph/model_update_task_graph.h"
#include "oneflow/core/graph/model_save_task_graph.h"
#include "oneflow/core/
operator/operator_manager
.h"
#include "oneflow/core/
graph/model_update_task_graph
.h"
#include "oneflow/core/operator/clone_op.h"
#include "oneflow/core/operator/operator_manager.h"
namespace
oneflow
{
std
::
string
CompTaskNode
::
VisualStr
()
const
{
std
::
stringstream
ss
;
ss
<<
TaskNode
::
VisualStr
()
<<
"Compute"
<<
":"
<<
stage_node
()
->
machine_id_str
()
<<
":"
<<
thrd_loc_id_str
()
<<
"
\\
n"
ss
<<
TaskNode
::
VisualStr
()
<<
"Compute"
<<
":"
<<
stage_node
()
->
machine_id_str
()
<<
":"
<<
thrd_loc_id_str
()
<<
"
\\
n"
<<
chain_node
()
->
VisualStr
();
return
ss
.
str
();
}
std
::
string
CompTaskNode
::
device_name
()
const
{
return
IDMgr
::
Singleton
().
MachineName4MachineId
(
stage_node
()
->
machine_id
())
+
":"
+
std
::
to_string
(
IDMgr
::
Singleton
().
DevPhyId4ThrdLocId
(
thrd_loc_id
()));
+
":"
+
std
::
to_string
(
IDMgr
::
Singleton
().
DevPhyId4ThrdLocId
(
thrd_loc_id
()));
}
void
SortByParallelId
(
std
::
vector
<
CompTaskNode
*>*
comp_node_vec
)
{
std
::
sort
(
comp_node_vec
->
begin
(),
comp_node_vec
->
end
(),
[]
(
const
CompTaskNode
*
lhs
,
const
CompTaskNode
*
rhs
)
{
return
lhs
->
parallel_id
()
<
rhs
->
parallel_id
();
});
std
::
sort
(
comp_node_vec
->
begin
(),
comp_node_vec
->
end
(),
[]
(
const
CompTaskNode
*
lhs
,
const
CompTaskNode
*
rhs
)
{
return
lhs
->
parallel_id
()
<
rhs
->
parallel_id
();
});
}
}
// namespace oneflow
}
// namespace oneflow
oneflow/core/graph/comp_task_node.h
浏览文件 @
03db51d7
...
...
@@ -21,17 +21,16 @@ class CompTaskNode : public TaskNode {
protected:
virtual
void
InitWithFwNode
(
TaskNode
*
fw_node
)
override
{
TaskNode
::
InitWithFwNode
(
fw_node
);
auto
fw_comp_code
=
static_cast
<
CompTaskNode
*>
(
fw_node
);
auto
fw_comp_code
=
static_cast
<
CompTaskNode
*>
(
fw_node
);
parallel_id_
=
fw_comp_code
->
parallel_id_
;
}
private:
int64_t
parallel_id_
;
};
void
SortByParallelId
(
std
::
vector
<
CompTaskNode
*>*
comp_node_vec
);
}
// namespace oneflow
}
// namespace oneflow
#endif // ONEFLOW_CORE_GRAPH_COMP_TASK_NODE_H_
#endif
// ONEFLOW_CORE_GRAPH_COMP_TASK_NODE_H_
oneflow/core/graph/copy_task_node.cpp
浏览文件 @
03db51d7
#include "oneflow/core/graph/copy_task_node.h"
#include "oneflow/core/operator/copy_hd_op.h"
#include "oneflow/core/operator/copy_comm_net_op.h"
#include "oneflow/core/operator/copy_hd_op.h"
namespace
oneflow
{
void
CopyTaskNode
::
BuildExecAndEnrollLbn2Regsts
(
TaskGraph
*
){
void
CopyTaskNode
::
BuildExecAndEnrollLbn2Regsts
(
TaskGraph
*
)
{
auto
out_regst
=
NewProducedRegstDesc
(
"copy_out"
);
BindProducedRegstAndOutEdge
(
out_regst
,
SoleOutEdge
());
std
::
shared_ptr
<
RegstDesc
>
in_regst
=
GetRelatedRegst
(
SoleInEdge
());
SubscribeRegstDesc
(
"copy_in"
,
in_regst
);
out_regst
->
CopyLbnFrom
(
in_regst
.
get
());
ExecNode
*
node
=
mut_exec_gph
().
NewNode
();
node
->
mut_op
()
=
ConstructOp
();
if
(
IsFwNode
())
{
node
->
BindBnInOpAndRegst
(
node
->
op
()
->
SoleIbn
(),
in_regst
);
node
->
BindBnInOpAndRegst
(
node
->
op
()
->
SoleObn
(),
out_regst
);
...
...
@@ -21,7 +21,7 @@ void CopyTaskNode::BuildExecAndEnrollLbn2Regsts(TaskGraph*){
node
->
BindBnInOpAndRegst
(
node
->
op
()
->
SoleOdbn
(),
in_regst
);
node
->
BindBnInOpAndRegst
(
node
->
op
()
->
SoleIdbn
(),
out_regst
);
}
mut_exec_gph
().
UpdateSourceAndSink
();
}
...
...
@@ -56,4 +56,4 @@ std::shared_ptr<const Operator> CopyCommNetTaskNode::ConstructOp() const {
return
OpMgr
::
Singleton
().
ConstructOp
(
op_conf
);
}
}
// namespace oneflow
}
// namespace oneflow
oneflow/core/graph/copy_task_node.h
浏览文件 @
03db51d7
...
...
@@ -17,7 +17,6 @@ class CopyTaskNode : public TaskNode {
private:
void
BuildExecAndEnrollLbn2Regsts
(
TaskGraph
*
)
override
;
void
InferShapeOfBlobsInProducedRegsts
(
TaskGraph
*
)
override
;
};
class
CopyHDTaskNode
final
:
public
CopyTaskNode
{
...
...
@@ -25,26 +24,22 @@ class CopyHDTaskNode final : public CopyTaskNode {
OF_DISALLOW_COPY_AND_MOVE
(
CopyHDTaskNode
);
CopyHDTaskNode
()
=
default
;
~
CopyHDTaskNode
()
=
default
;
bool
IsH2D
()
const
{
return
((
IsFwInCopy
()
&&
IsFwNode
())
||
(
IsFwOutCopy
()
&&
IsBpNode
()));
}
bool
IsD2H
()
const
{
return
!
IsH2D
();
}
bool
IsD2H
()
const
{
return
!
IsH2D
();
}
bool
IsFwInCopy
()
const
{
return
is_fw_in_copy_
;
}
bool
IsFwOutCopy
()
const
{
return
!
is_fw_in_copy_
;
}
void
SetFwInCopy
();
void
SetFwOutCopy
();
std
::
string
VisualStr
()
const
override
{
return
TaskNode
::
VisualStr
()
+
"CopyHD"
;
}
void
ToProto
(
TaskProto
*
ret
)
const
override
{
TaskNode
::
ToProto
(
ret
);
};
void
ToProto
(
TaskProto
*
ret
)
const
override
{
TaskNode
::
ToProto
(
ret
);
};
private:
std
::
shared_ptr
<
const
Operator
>
ConstructOp
()
const
override
;
...
...
@@ -54,12 +49,11 @@ class CopyHDTaskNode final : public CopyTaskNode {
is_fw_in_copy_
=
static_cast
<
CopyHDTaskNode
*>
(
fw_node
)
->
is_fw_in_copy_
;
}
std
::
unique_ptr
<
TaskNode
>
CreateSameTypeNode
()
const
override
{
return
of_make_unique
<
CopyHDTaskNode
>
();
return
of_make_unique
<
CopyHDTaskNode
>
();
}
TaskType
task_type
()
const
override
{
return
kCopyHdTask
;
}
bool
is_fw_in_copy_
;
};
class
CopyCommNetTaskNode
final
:
public
CopyTaskNode
{
...
...
@@ -72,14 +66,12 @@ class CopyCommNetTaskNode final : public CopyTaskNode {
return
TaskNode
::
VisualStr
()
+
"CommNet"
;
}
void
ToProto
(
TaskProto
*
ret
)
const
override
{
TaskNode
::
ToProto
(
ret
);
};
void
ToProto
(
TaskProto
*
ret
)
const
override
{
TaskNode
::
ToProto
(
ret
);
};
private:
std
::
shared_ptr
<
const
Operator
>
ConstructOp
()
const
override
;
std
::
unique_ptr
<
TaskNode
>
CreateSameTypeNode
()
const
override
{
return
of_make_unique
<
CopyCommNetTaskNode
>
();
return
of_make_unique
<
CopyCommNetTaskNode
>
();
}
void
InitWithFwNode
(
TaskNode
*
fw_node
)
override
{
TaskNode
::
InitWithFwNode
(
fw_node
);
...
...
@@ -87,9 +79,8 @@ class CopyCommNetTaskNode final : public CopyTaskNode {
set_task_id
();
}
TaskType
task_type
()
const
override
{
return
kCopyCommNetTask
;
}
};
}
// namespace oneflow
}
// namespace oneflow
#endif // ONEFLOW_CORE_GRAPH_COPY_TASK_NODE_H_
#endif
// ONEFLOW_CORE_GRAPH_COPY_TASK_NODE_H_
oneflow/core/graph/data_comp_task_node.cpp
浏览文件 @
03db51d7
...
...
@@ -21,22 +21,20 @@ void DataCompTaskNode::FwBuildExecAndEnrollLbn2Regsts(TaskGraph*) {
FwSetExecNodeFromInRegst
(
extern_in_lbn2consumer
);
FwEnrollLbn2OutRegst
(
lbn2producer
);
FwEnrollLbn2ActivationRegst
();
FwEnrollLbn2ModelAndTmpRegsts
();
// model model_tmp data_tmp
FwEnrollLbn2ModelAndTmpRegsts
();
// model model_tmp data_tmp
}
void
DataCompTaskNode
::
FwInferShapeOfBlobsInProducedRegsts
(
TaskGraph
*
)
{
exec_gph
().
ConstTopoForEachNode
([
this
](
const
ExecNode
*
node
)
{
node
->
op
()
->
InferShape4FwBlobs
(
node
->
GetMutShapePtr4BnInOpFunc
(),
chain_node
()
->
parallel_desc
()
->
policy
(),
parallel_id
(),
chain_node
()
->
parallel_desc
()
->
policy
(),
parallel_id
(),
chain_node
()
->
parallel_desc
()
->
parallel_num
());
});
}
void
DataCompTaskNode
::
FwBuildFromUserOps
(
Lbn2NodeBnMap
*
lbn2producer
,
Lbn2NodeBnMap
*
extern_in_lbn2consumer
)
{
Lbn2NodeBnMap
*
lbn2producer
,
Lbn2NodeBnMap
*
extern_in_lbn2consumer
)
{
for
(
std
::
shared_ptr
<
const
Operator
>
op
:
chain_node
()
->
op_vec
())
{
ExecNode
*
cur_node
=
mut_exec_gph
().
NewNode
();
cur_node
->
mut_op
()
=
op
;
...
...
@@ -56,8 +54,7 @@ void DataCompTaskNode::FwBuildFromUserOps(
edge
->
mut_dst_bn
()
=
ibn
;
Connect
(
producer_it
->
second
.
first
,
edge
,
cur_node
);
}
else
{
CHECK
(
extern_in_lbn2consumer
->
insert
({
lbn
,
{
cur_node
,
ibn
}}).
second
);
CHECK
(
extern_in_lbn2consumer
->
insert
({
lbn
,
{
cur_node
,
ibn
}}).
second
);
}
}
});
...
...
@@ -161,8 +158,7 @@ void DataCompTaskNode::BpBuildExecAndEnrollLbn2Regsts(TaskGraph*) {
// Subscribe
SubscribeRegstDesc
(
"activation"
,
GetFwNode
()
->
GetProducedRegstDesc
(
"activation"
));
SubscribeRegstDesc
(
"data_tmp"
,
GetFwNode
()
->
GetProducedRegstDesc
(
"data_tmp"
));
SubscribeRegstDesc
(
"data_tmp"
,
GetFwNode
()
->
GetProducedRegstDesc
(
"data_tmp"
));
SubscribeRegstDesc
(
"model"
,
GetFwNode
()
->
GetSubscribedRegstDesc
(
"model"
));
SubscribeRegstDesc
(
"model_tmp"
,
GetFwNode
()
->
GetSubscribedRegstDesc
(
"model_tmp"
));
...
...
@@ -179,7 +175,8 @@ void DataCompTaskNode::BpInferShapeOfBlobsInProducedRegsts(TaskGraph*) {
in_diff_regst
->
CopyShapeFrom
(
in_regst
.
get
());
// model_diff_regst
if
(
auto
md_diff_regst
=
GetProducedRegstDesc
(
"model_diff"
))
{
md_diff_regst
->
CopyShapeFrom
(
GetFwNode
()
->
GetSubscribedRegstDesc
(
"model"
).
get
());
md_diff_regst
->
CopyShapeFrom
(
GetFwNode
()
->
GetSubscribedRegstDesc
(
"model"
).
get
());
}
// activation_diff_regst
if
(
auto
acti_diff_regst
=
GetProducedRegstDesc
(
"activation_diff"
))
{
...
...
@@ -201,8 +198,7 @@ void DataCompTaskNode::BpBuildExecGraph() {
bp_edge
->
set_lbn
(
fw_edge
->
lbn
());
bp_edge
->
mut_src_bn
()
=
GenDiffBn
(
fw_edge
->
dst_bn
());
bp_edge
->
mut_dst_bn
()
=
GenDiffBn
(
fw_edge
->
src_bn
());
Connect
(
fw_node2bp_node
.
at
(
fw_edge
->
dst_node
()),
bp_edge
,
Connect
(
fw_node2bp_node
.
at
(
fw_edge
->
dst_node
()),
bp_edge
,
fw_node2bp_node
.
at
(
fw_edge
->
src_node
()));
});
mut_exec_gph
().
UpdateSourceAndSink
();
...
...
@@ -222,7 +218,7 @@ void DataCompTaskNode::BpEnrollLbn2ActivationDiffRegst() {
exec_gph
().
ConstForEachEdge
([
&
](
const
ExecEdge
*
edge
)
{
edge
->
src_node
()
->
BindBnInOpAndRegst
(
edge
->
src_bn
(),
activation_diff_regst
);
edge
->
dst_node
()
->
BindBnInOpAndRegst
(
edge
->
dst_bn
(),
activation_diff_regst
);
edge
->
src_node
()
->
BindBnInOpAndRegst
(
GenUnDiffBn
(
edge
->
src_bn
()),
edge
->
src_node
()
->
BindBnInOpAndRegst
(
GenUnDiffBn
(
edge
->
src_bn
()),
activation_regst
);
edge
->
dst_node
()
->
BindBnInOpAndRegst
(
GenUnDiffBn
(
edge
->
dst_bn
()),
activation_regst
);
...
...
@@ -280,4 +276,4 @@ void DataCompTaskNode::BpEnrollLbn2ModelDiffRegst() {
});
}
}
// namespace oneflow
}
// namespace oneflow
oneflow/core/graph/data_comp_task_node.h
浏览文件 @
03db51d7
...
...
@@ -21,17 +21,14 @@ class DataCompTaskNode final : public CompTaskNode {
private:
OVERRIDE_IF_FW_BP_FOR_FUNC
(
BuildExecAndEnrollLbn2Regsts
);
OVERRIDE_IF_FW_BP_FOR_FUNC
(
InferShapeOfBlobsInProducedRegsts
);
using
Lbn2NodeBnMap
=
HashMap
<
std
::
string
,
std
::
pair
<
ExecNode
*
,
std
::
string
>>
;
using
Lbn2NodeBnMap
=
HashMap
<
std
::
string
,
std
::
pair
<
ExecNode
*
,
std
::
string
>>
;
void
FwBuildExecAndEnrollLbn2Regsts
(
TaskGraph
*
gph
);
void
FwInferShapeOfBlobsInProducedRegsts
(
TaskGraph
*
gph
);
void
FwBuildFromUserOps
(
Lbn2NodeBnMap
*
lbn2producer
,
Lbn2NodeBnMap
*
extern_in_lbn2consumer
);
void
FwSetExecNodeFromInRegst
(
const
Lbn2NodeBnMap
&
extern_in_lbn2consumer
);
void
FwBuildFromUserOps
(
Lbn2NodeBnMap
*
lbn2producer
,
Lbn2NodeBnMap
*
extern_in_lbn2consumer
);
void
FwSetExecNodeFromInRegst
(
const
Lbn2NodeBnMap
&
extern_in_lbn2consumer
);
void
FwEnrollLbn2OutRegst
(
const
Lbn2NodeBnMap
&
lbn2producer
);
void
FwEnrollLbn2OutRegstWhenLoss
();
void
FwEnrollLbn2OutRegstWhenNotLoss
(
const
Lbn2NodeBnMap
&
lbn2producer
);
...
...
@@ -45,16 +42,13 @@ class DataCompTaskNode final : public CompTaskNode {
void
BpSetExecNodeFromOutDiffRegst
();
void
BpEnrollLbn2InDiffRegst
();
void
BpEnrollLbn2ModelDiffRegst
();
TaskType
task_type
()
const
override
{
return
kDataCompTask
;
}
TaskType
task_type
()
const
override
{
return
kDataCompTask
;
}
std
::
unique_ptr
<
TaskNode
>
CreateSameTypeNode
()
const
override
{
return
of_make_unique
<
DataCompTaskNode
>
();
return
of_make_unique
<
DataCompTaskNode
>
();
}
};
}
// namespace oneflow
}
// namespace oneflow
#endif // ONEFLOW_CORE_GRAPH_DATA_COMP_TASK_NODE_H_
#endif
// ONEFLOW_CORE_GRAPH_DATA_COMP_TASK_NODE_H_
oneflow/core/graph/data_task_graph.cpp
浏览文件 @
03db51d7
...
...
@@ -4,11 +4,9 @@ namespace oneflow {
class
DataCompTaskNode
;
DataTaskGraph
::
DataTaskGraph
(
const
std
::
string
&
name
,
const
DLNetConf
&
dl_net_conf
,
const
Strategy
&
strategy_conf
,
bool
need_bp
)
{
DataTaskGraph
::
DataTaskGraph
(
const
std
::
string
&
name
,
const
DLNetConf
&
dl_net_conf
,
const
Strategy
&
strategy_conf
,
bool
need_bp
)
{
mut_name
()
=
name
;
LogicalGraph
logical_gph
(
dl_net_conf
,
strategy_conf
);
auto
chain_gph
=
of_make_unique
<
ChainGraph
>
(
&
logical_gph
);
...
...
@@ -16,4 +14,4 @@ DataTaskGraph::DataTaskGraph(
BuildExecAndEnrollLbn2Regsts
();
}
}
}
// namespace oneflow
oneflow/core/graph/data_task_graph.h
浏览文件 @
03db51d7
...
...
@@ -10,17 +10,15 @@ class DataTaskGraph final : public TaskGraph {
OF_DISALLOW_COPY_AND_MOVE
(
DataTaskGraph
);
DataTaskGraph
()
=
delete
;
~
DataTaskGraph
()
=
default
;
DataTaskGraph
(
const
std
::
string
&
name
,
const
DLNetConf
&
dl_net_conf
,
const
Strategy
&
strategy_conf
,
bool
need_bp
);
DataTaskGraph
(
const
std
::
string
&
name
,
const
DLNetConf
&
dl_net_conf
,
const
Strategy
&
strategy_conf
,
bool
need_bp
);
const
char
*
TypeName
()
const
override
{
return
"DataTaskGraph"
;
}
private:
};
}
// namespace oneflow
}
// namespace oneflow
#endif // ONEFLOW_CORE_GRAPH_DATA_TASK_GRAPH_H_
#endif
// ONEFLOW_CORE_GRAPH_DATA_TASK_GRAPH_H_
oneflow/core/graph/exec_graph.cpp
浏览文件 @
03db51d7
...
...
@@ -2,12 +2,10 @@
namespace
oneflow
{
void
ExecEdge
::
set_lbn
(
const
std
::
string
&
lbn
)
{
lbn_
=
lbn
;
}
void
ExecEdge
::
set_lbn
(
const
std
::
string
&
lbn
)
{
lbn_
=
lbn
;
}
std
::
function
<
Shape
*
(
const
std
::
string
&
)
>
ExecNode
::
GetMutShapePtr4BnInOpFunc
()
const
{
std
::
function
<
Shape
*
(
const
std
::
string
&
)
>
ExecNode
::
GetMutShapePtr4BnInOpFunc
()
const
{
return
[
this
](
const
std
::
string
&
bn_in_op
)
->
Shape
*
{
auto
it
=
this
->
bn_in_op2regst_
.
find
(
bn_in_op
);
if
(
it
==
this
->
bn_in_op2regst_
.
end
())
{
return
nullptr
;
}
...
...
@@ -19,11 +17,11 @@ ExecNode::GetMutShapePtr4BnInOpFunc() const {
void
ExecNode
::
ToProto
(
ExecNodeProto
*
ret
)
const
{
ret
->
set_op_name
(
op_
->
op_name
());
for
(
const
auto
&
bn_regst
:
bn_in_op2regst_
)
{
for
(
const
auto
&
bn_regst
:
bn_in_op2regst_
)
{
auto
regst
=
bn_regst
.
second
.
lock
();
if
(
regst
)
{
ret
->
mutable_bn_in_op2regst_desc_id
()
->
insert
(
{
bn_regst
.
first
,
regst
->
regst_desc_id
()});
ret
->
mutable_bn_in_op2regst_desc_id
()
->
insert
(
{
bn_regst
.
first
,
regst
->
regst_desc_id
()});
}
}
}
...
...
@@ -36,4 +34,4 @@ void ExecGraph::ToExecSequence(ExecSequence* ret) const {
});
}
}
// namespace oneflow
}
// namespace oneflow
oneflow/core/graph/exec_graph.h
浏览文件 @
03db51d7
#ifndef ONEFLOW_CORE_GRAPH_EXEC_GRAPH_H_
#define ONEFLOW_CORE_GRAPH_EXEC_GRAPH_H_
#include "oneflow/core/common/protobuf.h"
#include "oneflow/core/graph/exec_sequence.pb.h"
#include "oneflow/core/operator/operator.h"
#include "oneflow/core/graph/graph.h"
#include "oneflow/core/operator/operator.h"
#include "oneflow/core/register/register_desc.h"
#include "oneflow/core/common/protobuf.h"
namespace
oneflow
{
...
...
@@ -32,7 +32,6 @@ class ExecEdge final : public Edge<ExecNode, ExecEdge> {
std
::
string
lbn_
;
std
::
string
src_bn_
;
std
::
string
dst_bn_
;
};
class
ExecNode
final
:
public
Node
<
ExecNode
,
ExecEdge
>
{
...
...
@@ -44,10 +43,12 @@ class ExecNode final : public Node<ExecNode, ExecEdge> {
std
::
shared_ptr
<
const
Operator
>
op
()
const
{
return
op_
;
}
std
::
shared_ptr
<
const
Operator
>&
mut_op
()
{
return
op_
;
}
void
BindBnInOpAndRegst
(
const
std
::
string
&
bn_in_op
,
std
::
weak_ptr
<
RegstDesc
>
regst
)
{
void
BindBnInOpAndRegst
(
const
std
::
string
&
bn_in_op
,
std
::
weak_ptr
<
RegstDesc
>
regst
)
{
CHECK
(
bn_in_op2regst_
.
emplace
(
bn_in_op
,
regst
).
second
);
}
std
::
shared_ptr
<
RegstDesc
>
GetRegstFromBnInOp
(
const
std
::
string
&
bn_in_op
)
const
{
std
::
shared_ptr
<
RegstDesc
>
GetRegstFromBnInOp
(
const
std
::
string
&
bn_in_op
)
const
{
return
bn_in_op2regst_
.
at
(
bn_in_op
).
lock
();
}
const
HashMap
<
std
::
string
,
std
::
weak_ptr
<
RegstDesc
>>&
bn_in_op2regst
()
const
{
...
...
@@ -55,15 +56,14 @@ class ExecNode final : public Node<ExecNode, ExecEdge> {
}
std
::
function
<
Shape
*
(
const
std
::
string
&
)
>
GetMutShapePtr4BnInOpFunc
()
const
;
std
::
string
VisualStr
()
const
{
return
op_
->
op_name
();
}
void
ToProto
(
ExecNodeProto
*
ret
)
const
;
private:
std
::
shared_ptr
<
const
Operator
>
op_
;
HashMap
<
std
::
string
,
std
::
weak_ptr
<
RegstDesc
>>
bn_in_op2regst_
;
};
class
ExecGraph
final
:
public
Graph
<
ExecNode
,
ExecEdge
>
{
...
...
@@ -71,14 +71,13 @@ class ExecGraph final : public Graph<ExecNode, ExecEdge> {
OF_DISALLOW_COPY_AND_MOVE
(
ExecGraph
);
ExecGraph
()
=
default
;
~
ExecGraph
()
=
default
;
void
ToExecSequence
(
ExecSequence
*
ret
)
const
;
const
char
*
TypeName
()
const
override
{
return
"ExecGraph"
;
}
private:
};
}
// namespace oneflow
}
// namespace oneflow
#endif // ONEFLOW_CORE_GRAPH_EXEC_GRAPH_H_
#endif
// ONEFLOW_CORE_GRAPH_EXEC_GRAPH_H_
oneflow/core/graph/graph.h
浏览文件 @
03db51d7
#ifndef ONEFLOW_CORE_GRAPH_GRAPH_H_
#define ONEFLOW_CORE_GRAPH_GRAPH_H_
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/lib/io/path.h"
#include "gflags/gflags.h"
#include "oneflow/core/persistence/persistent_out_stream.h"
#include "oneflow/core/graph/node.h"
#include "oneflow/core/persistence/persistent_out_stream.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/env.h"
namespace
oneflow
{
...
...
@@ -23,11 +23,11 @@ class Graph {
void
ConstForEachNode
(
std
::
function
<
void
(
const
NodeType
*
)
>
)
const
;
void
ConstTopoForEachNode
(
std
::
function
<
void
(
const
NodeType
*
)
>
)
const
;
void
ConstReverseTopoForEachNode
(
std
::
function
<
void
(
const
NodeType
*
)
>
)
const
;
// For Each Edge
void
ForEachEdge
(
std
::
function
<
void
(
EdgeType
*
)
>
);
void
ConstForEachEdge
(
std
::
function
<
void
(
const
EdgeType
*
)
>
)
const
;
// Getters
const
std
::
unordered_set
<
NodeType
*>&
source_nodes
()
const
;
const
std
::
unordered_set
<
NodeType
*>&
sink_nodes
()
const
;
...
...
@@ -37,7 +37,7 @@ class Graph {
size_t
node_num
()
const
{
return
nodes_
.
size
();
}
size_t
edge_num
()
const
{
return
edges_
.
size
();
}
virtual
const
char
*
TypeName
()
const
{
return
"Not Defined"
;
}
// Setters
NodeType
*
NewNode
();
EdgeType
*
NewEdge
();
...
...
@@ -57,9 +57,9 @@ class Graph {
class
TopoIterator
;
class
ReverseTopoIterator
;
TopoIterator
begin
()
{
return
source_nodes_
;
}
TopoIterator
end
()
{
return
std
::
unordered_set
<
NodeType
*>
();
}
TopoIterator
end
()
{
return
std
::
unordered_set
<
NodeType
*>
();
}
ReverseTopoIterator
rbegin
()
{
return
sink_nodes_
;
}
ReverseTopoIterator
rend
()
{
return
std
::
unordered_set
<
NodeType
*>
();
}
ReverseTopoIterator
rend
()
{
return
std
::
unordered_set
<
NodeType
*>
();
}
//
std
::
unordered_set
<
NodeType
*>
source_nodes_
;
...
...
@@ -68,25 +68,22 @@ class Graph {
std
::
vector
<
std
::
unique_ptr
<
EdgeType
>>
edges_
;
};
template
<
typename
NodeType
,
typename
EdgeType
>
class
Graph
<
NodeType
,
EdgeType
>::
TopoIterator
final
{
public:
// OF_DISALLOW_COPY_AND_MOVE(TopoIterator);
TopoIterator
()
=
default
;
~
TopoIterator
()
=
default
;
TopoIterator
(
const
std
::
unordered_set
<
NodeType
*>&
source_nodes
)
{
for
(
NodeType
*
node
:
source_nodes
)
{
bfs_queue_
.
push
(
node
);
}
for
(
NodeType
*
node
:
source_nodes
)
{
bfs_queue_
.
push
(
node
);
}
}
NodeType
&
operator
*
()
{
return
*
(
bfs_queue_
.
front
());
}
NodeType
*
operator
->
()
{
return
&
(
*
(
*
this
));
}
TopoIterator
&
operator
++
();
bool
operator
!=
(
const
TopoIterator
&
)
const
;
NodeType
&
operator
*
()
{
return
*
(
bfs_queue_
.
front
());
}
NodeType
*
operator
->
()
{
return
&
(
*
(
*
this
));
}
TopoIterator
&
operator
++
();
bool
operator
!=
(
const
TopoIterator
&
)
const
;
private:
std
::
queue
<
NodeType
*>
bfs_queue_
;
...
...
@@ -99,57 +96,48 @@ class Graph<NodeType, EdgeType>::ReverseTopoIterator final {
// OF_DISALLOW_COPY_AND_MOVE(ReverseTopoIterator);
ReverseTopoIterator
()
=
default
;
~
ReverseTopoIterator
()
=
default
;
ReverseTopoIterator
(
const
std
::
unordered_set
<
NodeType
*>&
sink_nodes
)
{
for
(
NodeType
*
node
:
sink_nodes
)
{
bfs_queue_
.
push
(
node
);
}
for
(
NodeType
*
node
:
sink_nodes
)
{
bfs_queue_
.
push
(
node
);
}
}
NodeType
&
operator
*
()
{
return
*
(
bfs_queue_
.
front
());
}
NodeType
*
operator
->
()
{
return
&
(
*
(
*
this
));
}
ReverseTopoIterator
&
operator
++
();
bool
operator
!=
(
const
ReverseTopoIterator
&
)
const
;
NodeType
&
operator
*
()
{
return
*
(
bfs_queue_
.
front
());
}
NodeType
*
operator
->
()
{
return
&
(
*
(
*
this
));
}
ReverseTopoIterator
&
operator
++
();
bool
operator
!=
(
const
ReverseTopoIterator
&
)
const
;
private:
std
::
queue
<
NodeType
*>
bfs_queue_
;
HashMap
<
NodeType
*
,
int32_t
>
visited_cnt_
;
HashMap
<
NodeType
*
,
int32_t
>
visited_cnt_
;
};
template
<
typename
NodeType
,
typename
EdgeType
>
void
Graph
<
NodeType
,
EdgeType
>::
ForEachNode
(
std
::
function
<
void
(
NodeType
*
)
>
func
)
{
for
(
auto
&
x
:
nodes_
)
{
func
(
x
.
get
());
}
for
(
auto
&
x
:
nodes_
)
{
func
(
x
.
get
());
}
}
template
<
typename
NodeType
,
typename
EdgeType
>
void
Graph
<
NodeType
,
EdgeType
>::
TopoForEachNode
(
std
::
function
<
void
(
NodeType
*
)
>
func
)
{
for
(
TopoIterator
it
=
begin
();
it
!=
end
();
++
it
)
{
func
(
&
(
*
it
));
}
for
(
TopoIterator
it
=
begin
();
it
!=
end
();
++
it
)
{
func
(
&
(
*
it
));
}
}
template
<
typename
NodeType
,
typename
EdgeType
>
void
Graph
<
NodeType
,
EdgeType
>::
ReverseTopoForEachNode
(
std
::
function
<
void
(
NodeType
*
)
>
func
)
{
for
(
ReverseTopoIterator
it
=
rbegin
();
it
!=
rend
();
++
it
)
{
func
(
&
(
*
it
));
}
for
(
ReverseTopoIterator
it
=
rbegin
();
it
!=
rend
();
++
it
)
{
func
(
&
(
*
it
));
}
}
#define OF_DEFINE_CONST_FOR_EACH_NODE(FuncName) \
template<typename NodeType, typename EdgeType> \
void Graph<NodeType, EdgeType>::Const##FuncName( \
std::function<void(const NodeType*)> func) const { \
auto cast_this = const_cast<Graph<NodeType, EdgeType>*> (this); \
cast_this->FuncName(std::bind(func, std::placeholders::_1)); \
}
#define OF_DEFINE_CONST_FOR_EACH_NODE(FuncName) \
template<typename NodeType, typename EdgeType> \
void Graph<NodeType, EdgeType>::Const##FuncName( \
std::function<void(const NodeType*)> func) const { \
auto cast_this = const_cast<Graph<NodeType, EdgeType>*>(this); \
cast_this->FuncName(std::bind(func, std::placeholders::_1)); \
}
OF_DEFINE_CONST_FOR_EACH_NODE
(
ForEachNode
);
OF_DEFINE_CONST_FOR_EACH_NODE
(
TopoForEachNode
);
...
...
@@ -160,27 +148,25 @@ OF_DEFINE_CONST_FOR_EACH_NODE(ReverseTopoForEachNode);
template
<
typename
NodeType
,
typename
EdgeType
>
void
Graph
<
NodeType
,
EdgeType
>::
ForEachEdge
(
std
::
function
<
void
(
EdgeType
*
)
>
func
)
{
for
(
auto
&
x
:
edges_
)
{
func
(
x
.
get
());
}
for
(
auto
&
x
:
edges_
)
{
func
(
x
.
get
());
}
}
template
<
typename
NodeType
,
typename
EdgeType
>
void
Graph
<
NodeType
,
EdgeType
>::
ConstForEachEdge
(
std
::
function
<
void
(
const
EdgeType
*
)
>
func
)
const
{
auto
cast_this
=
const_cast
<
Graph
<
NodeType
,
EdgeType
>*>
(
this
);
auto
cast_this
=
const_cast
<
Graph
<
NodeType
,
EdgeType
>*>
(
this
);
cast_this
->
ForEachEdge
(
std
::
bind
(
func
,
std
::
placeholders
::
_1
));
}
template
<
typename
NodeType
,
typename
EdgeType
>
const
std
::
unordered_set
<
NodeType
*>&
Graph
<
NodeType
,
EdgeType
>::
source_nodes
()
const
{
const
std
::
unordered_set
<
NodeType
*>&
Graph
<
NodeType
,
EdgeType
>::
source_nodes
()
const
{
return
source_nodes_
;
}
template
<
typename
NodeType
,
typename
EdgeType
>
const
std
::
unordered_set
<
NodeType
*>&
Graph
<
NodeType
,
EdgeType
>::
sink_nodes
()
const
{
const
std
::
unordered_set
<
NodeType
*>&
Graph
<
NodeType
,
EdgeType
>::
sink_nodes
()
const
{
return
sink_nodes_
;
}
...
...
@@ -241,12 +227,8 @@ void Graph<NodeType, EdgeType>::UpdateSourceAndSink() {
source_nodes_
.
clear
();
sink_nodes_
.
clear
();
for
(
const
std
::
unique_ptr
<
NodeType
>&
node
:
nodes_
)
{
if
(
node
->
in_edges
().
empty
())
{
source_nodes_
.
insert
(
node
.
get
());
}
if
(
node
->
out_edges
().
empty
())
{
sink_nodes_
.
insert
(
node
.
get
());
}
if
(
node
->
in_edges
().
empty
())
{
source_nodes_
.
insert
(
node
.
get
());
}
if
(
node
->
out_edges
().
empty
())
{
sink_nodes_
.
insert
(
node
.
get
());
}
}
}
...
...
@@ -259,14 +241,15 @@ void Graph<NodeType, EdgeType>::ToDotWithStream(StreamT& out_stream) const {
});
this
->
ConstForEachEdge
([
&
](
const
EdgeType
*
edge
)
{
out_stream
<<
"
\"
"
<<
edge
->
src_node
()
->
VisualStr
()
<<
"
\"
-> "
<<
"
\"
"
<<
edge
->
dst_node
()
->
VisualStr
()
<<
"
\"
"
<<
"[label=
\"
"
<<
edge
->
VisualStr
()
<<
"
\"
];
\n
"
;
<<
"
\"
"
<<
edge
->
dst_node
()
->
VisualStr
()
<<
"
\"
"
<<
"[label=
\"
"
<<
edge
->
VisualStr
()
<<
"
\"
];
\n
"
;
});
out_stream
<<
"}
\n
"
;
}
template
<
typename
NodeType
,
typename
EdgeType
>
void
Graph
<
NodeType
,
EdgeType
>::
ToDotWithFilePath
(
const
std
::
string
&
file_path
)
const
{
void
Graph
<
NodeType
,
EdgeType
>::
ToDotWithFilePath
(
const
std
::
string
&
file_path
)
const
{
std
::
string
dir_name
=
tensorflow
::
io
::
Dirname
(
file_path
).
ToString
();
tensorflow
::
Env
*
env
=
tensorflow
::
Env
::
Default
();
if
(
env
->
IsDirectory
(
dir_name
).
code
()
!=
tensorflow
::
error
::
OK
)
{
...
...
@@ -278,17 +261,15 @@ void Graph<NodeType, EdgeType>::ToDotWithFilePath(const std::string& file_path)
template
<
typename
NodeType
,
typename
EdgeType
>
void
Graph
<
NodeType
,
EdgeType
>::
ToDotWithAutoFilePath
()
const
{
std
::
string
file_path
=
LogDir
()
+
"/dot/"
+
TypeName
()
+
"/"
+
NewUniqueId
()
+
".dot"
;
std
::
string
file_path
=
LogDir
()
+
"/dot/"
+
TypeName
()
+
"/"
+
NewUniqueId
()
+
".dot"
;
ToDotWithFilePath
(
file_path
);
}
template
<
typename
NodeType
>
bool
IsNotEqual4BfsQueue
(
const
std
::
queue
<
NodeType
*>&
lhs
,
const
std
::
queue
<
NodeType
*>&
rhs
)
{
if
(
lhs
.
empty
()
!=
rhs
.
empty
())
{
return
true
;
}
if
(
lhs
.
empty
()
!=
rhs
.
empty
())
{
return
true
;
}
if
(
lhs
.
empty
()
==
false
&&
rhs
.
empty
()
==
false
)
{
return
lhs
.
front
()
!=
rhs
.
front
();
}
...
...
@@ -296,27 +277,28 @@ bool IsNotEqual4BfsQueue(const std::queue<NodeType*>& lhs,
}
template
<
typename
NodeType
,
typename
EdgeType
>
auto
Graph
<
NodeType
,
EdgeType
>::
TopoIterator
::
operator
++
()
->
TopoIterator
&
{
auto
Graph
<
NodeType
,
EdgeType
>::
TopoIterator
::
operator
++
()
->
TopoIterator
&
{
NodeType
*
cur_node
=
bfs_queue_
.
front
();
bfs_queue_
.
pop
();
for
(
EdgeType
*
out_edge
:
cur_node
->
out_edges
())
{
NodeType
*
dst_node
=
out_edge
->
dst_node
();
visited_cnt_
[
dst_node
]
+=
1
;
if
(
visited_cnt_
.
at
(
dst_node
)
==
dst_node
->
in_edges
().
size
())
{
bfs_queue_
.
push
(
dst_node
);
bfs_queue_
.
push
(
dst_node
);
}
}
return
*
this
;
}
template
<
typename
NodeType
,
typename
EdgeType
>
bool
Graph
<
NodeType
,
EdgeType
>::
TopoIterator
::
operator
!=
(
bool
Graph
<
NodeType
,
EdgeType
>::
TopoIterator
::
operator
!=
(
const
TopoIterator
&
rhs
)
const
{
return
IsNotEqual4BfsQueue
(
bfs_queue_
,
rhs
.
bfs_queue_
);
}
template
<
typename
NodeType
,
typename
EdgeType
>
auto
Graph
<
NodeType
,
EdgeType
>::
ReverseTopoIterator
::
operator
++
()
->
ReverseTopoIterator
&
{
auto
Graph
<
NodeType
,
EdgeType
>::
ReverseTopoIterator
::
operator
++
()
->
ReverseTopoIterator
&
{
NodeType
*
cur_node
=
bfs_queue_
.
front
();
bfs_queue_
.
pop
();
for
(
EdgeType
*
in_edge
:
cur_node
->
in_edges
())
{
...
...
@@ -330,11 +312,11 @@ auto Graph<NodeType, EdgeType>::ReverseTopoIterator::operator ++ () -> ReverseTo
}
template
<
typename
NodeType
,
typename
EdgeType
>
bool
Graph
<
NodeType
,
EdgeType
>::
ReverseTopoIterator
::
operator
!=
(
bool
Graph
<
NodeType
,
EdgeType
>::
ReverseTopoIterator
::
operator
!=
(
const
ReverseTopoIterator
&
rhs
)
const
{
return
IsNotEqual4BfsQueue
(
bfs_queue_
,
rhs
.
bfs_queue_
);
}
}
// namespace oneflow
}
// namespace oneflow
#endif // ONEFLOW_CORE_GRAPH_GRAPH_H_
#endif
// ONEFLOW_CORE_GRAPH_GRAPH_H_
oneflow/core/graph/graph_test.cpp
浏览文件 @
03db51d7
...
...
@@ -7,12 +7,11 @@ class TestEdge;
class
TestNode
final
:
public
Node
<
TestNode
,
TestEdge
>
{
public:
OF_DISALLOW_COPY_AND_MOVE
(
TestNode
);
TestNode
(
int64_t
node_id_
)
{
test_node_id_
=
node_id_
;
}
TestNode
(
int64_t
node_id_
)
{
test_node_id_
=
node_id_
;
}
~
TestNode
()
=
default
;
int64_t
test_node_id
()
const
{
return
test_node_id_
;
}
private:
int64_t
test_node_id_
;
};
...
...
@@ -59,9 +58,11 @@ void DoOneTestGraph(const TestGraph& test_graph,
// 1. Determines whether the traversal result satisfies the topological order
HashMap
<
int64_t
,
int64_t
>
node_id2order
,
node_id2rorder
;
auto
NodePairHash
=
[](
const
NodeIdPair
&
val
)
{
return
val
.
first
^
val
.
second
;
};
std
::
unordered_set
<
NodeIdPair
,
decltype
(
NodePairHash
)
>
edges_node_pair
(
11
,
NodePairHash
);
auto
NodePairHash
=
[](
const
NodeIdPair
&
val
)
{
return
val
.
first
^
val
.
second
;
};
std
::
unordered_set
<
NodeIdPair
,
decltype
(
NodePairHash
)
>
edges_node_pair
(
11
,
NodePairHash
);
int64_t
order
=
0
;
test_graph
.
ConstTopoForEachNode
([
&
](
const
TestNode
*
node
)
{
node_id2order
.
emplace
(
node
->
test_node_id
(),
order
);
...
...
@@ -76,7 +77,7 @@ void DoOneTestGraph(const TestGraph& test_graph,
});
ASSERT_EQ
(
node_id2rorder
.
size
(),
node_num
);
// method :
// method :
// judge every directed edge <u,v>
// the node u's order is smaller than v
int64_t
edge_num
=
0
;
...
...
@@ -90,7 +91,7 @@ void DoOneTestGraph(const TestGraph& test_graph,
src_ord
=
node_id2rorder
.
at
(
src_node_id
);
dst_ord
=
node_id2rorder
.
at
(
dst_node_id
);
ASSERT_GE
(
src_ord
,
dst_ord
);
//
//
++
edge_num
;
edges_node_pair
.
insert
(
std
::
make_pair
(
src_node_id
,
dst_node_id
));
}
...
...
@@ -109,8 +110,8 @@ void DoOneTestGraph(const TestGraph& test_graph,
test_graph
.
ConstForEachEdge
([
&
](
const
TestEdge
*
cur_edge
)
{
int64_t
src_node_id
=
cur_edge
->
src_node
()
->
test_node_id
();
int64_t
dst_node_id
=
cur_edge
->
dst_node
()
->
test_node_id
();
ASSERT_TRUE
(
edges_node_pair
.
count
(
std
::
make_pair
(
src_node_id
,
dst_node_id
))
>
0
);
ASSERT_TRUE
(
edges_node_pair
.
count
(
std
::
make_pair
(
src_node_id
,
dst_node_id
))
>
0
);
});
}
...
...
@@ -129,4 +130,4 @@ TEST(TestGraph, test_graph_node_num_7) {
DoOneTestGraph
(
test_graph
,
graph_conf
);
}
}
// namespace oneflow
}
// namespace oneflow
oneflow/core/graph/in_boxing_task_node.cpp
浏览文件 @
03db51d7
...
...
@@ -4,15 +4,12 @@ namespace oneflow {
void
InBoxingTaskNode
::
FwVirtualBuild
()
{
Chain2EdgesMap
chain2sorted_in_edges
;
FwInitChain2SortedEdgesMaps
(
&
chain2sorted_in_edges
,
&
TaskNode
::
in_edges
,
&
TaskEdge
::
src_node
,
&
TaskNode
::
SoleInEdge
);
FwInitChain2SortedEdgesMaps
(
&
chain2sorted_in_edges
,
&
TaskNode
::
in_edges
,
&
TaskEdge
::
src_node
,
&
TaskNode
::
SoleInEdge
);
ChainEdgesPair
chain_sorted_out_edges
;
chain_sorted_out_edges
.
first
=
chain_node
();
chain_sorted_out_edges
.
second
.
assign
(
out_edges
().
begin
(),
out_edges
().
end
());
FwSortEdgesInnerStage
(
&
chain_sorted_out_edges
.
second
,
&
TaskEdge
::
dst_node
,
FwSortEdgesInnerStage
(
&
chain_sorted_out_edges
.
second
,
&
TaskEdge
::
dst_node
,
&
TaskNode
::
SoleOutEdge
);
for
(
const
ChainEdgesPair
&
chain_sorted_in_edges
:
chain2sorted_in_edges
)
{
FwBuildChainSortedEdgesPair
(
chain_sorted_in_edges
,
chain_sorted_out_edges
);
...
...
@@ -20,4 +17,4 @@ void InBoxingTaskNode::FwVirtualBuild() {
mut_exec_gph
().
UpdateSourceAndSink
();
}
}
// namespace oneflow
}
// namespace oneflow
oneflow/core/graph/in_boxing_task_node.h
浏览文件 @
03db51d7
...
...
@@ -13,15 +13,14 @@ class InBoxingTaskNode final : public BoxingTaskNode {
private:
std
::
unique_ptr
<
TaskNode
>
CreateSameTypeNode
()
const
override
{
return
of_make_unique
<
InBoxingTaskNode
>
();
return
of_make_unique
<
InBoxingTaskNode
>
();
}
void
InitWithFwNode
(
TaskNode
*
fw_node
)
override
{
BoxingTaskNode
::
InitWithFwNode
(
fw_node
);
}
void
FwVirtualBuild
()
override
;
};
}
// namespace oneflow
}
// namespace oneflow
#endif // ONEFLOW_CORE_GRAPH_IN_BOXING_TASK_NODE_H_
#endif
// ONEFLOW_CORE_GRAPH_IN_BOXING_TASK_NODE_H_
oneflow/core/graph/logical_graph.cpp
浏览文件 @
03db51d7
...
...
@@ -16,8 +16,7 @@ LogicalGraph::LogicalGraph(const DLNetConf& dl_net_conf,
}
void
LogicalGraph
::
NaiveBuildGraphStruct
(
const
DLNetConf
&
dl_net_conf
,
HashMap
<
LogicalEdge
*
,
std
::
string
>*
edge2lbn
,
const
DLNetConf
&
dl_net_conf
,
HashMap
<
LogicalEdge
*
,
std
::
string
>*
edge2lbn
,
HashMap
<
LogicalEdge
*
,
std
::
string
>*
edge2ibn
)
{
HashMap
<
std
::
string
,
LogicalNode
*>
lbn2producer
;
// Process Op
...
...
@@ -123,4 +122,4 @@ void LogicalGraph::AddOneCloneNode(
}
}
}
// namespace oneflow
}
// namespace oneflow
oneflow/core/graph/logical_graph.h
浏览文件 @
03db51d7
...
...
@@ -2,10 +2,10 @@
#define ONEFLOW_CORE_GRAPH_LOGICAL_GRAPH_H_
#include "oneflow/core/graph/graph.h"
#include "oneflow/core/operator/operator.h"
#include "oneflow/core/job/dlnet_conf.pb.h"
#include "oneflow/core/job/strategy.pb.h"
#include "oneflow/core/job/parallel_desc.h"
#include "oneflow/core/job/strategy.pb.h"
#include "oneflow/core/operator/operator.h"
namespace
oneflow
{
...
...
@@ -17,12 +17,8 @@ class LogicalNode final : public Node<LogicalNode, LogicalEdge> {
LogicalNode
()
=
default
;
~
LogicalNode
()
=
default
;
std
::
shared_ptr
<
Operator
>
op
()
const
{
return
op_
;
}
std
::
shared_ptr
<
Operator
>&
mut_op
()
{
return
op_
;
}
std
::
shared_ptr
<
Operator
>
op
()
const
{
return
op_
;
}
std
::
shared_ptr
<
Operator
>&
mut_op
()
{
return
op_
;
}
std
::
shared_ptr
<
const
ParallelDesc
>
parallel_desc
()
const
{
return
parallel_desc_
;
...
...
@@ -38,7 +34,6 @@ class LogicalNode final : public Node<LogicalNode, LogicalEdge> {
private:
std
::
shared_ptr
<
Operator
>
op_
;
std
::
shared_ptr
<
const
ParallelDesc
>
parallel_desc_
;
};
class
LogicalEdge
final
:
public
Edge
<
LogicalNode
,
LogicalEdge
>
{
...
...
@@ -56,16 +51,14 @@ class LogicalGraph final : public Graph<LogicalNode, LogicalEdge> {
LogicalGraph
()
=
delete
;
~
LogicalGraph
()
=
default
;
LogicalGraph
(
const
DLNetConf
&
dl_net_conf
,
const
Strategy
&
strategy_conf
);
LogicalGraph
(
const
DLNetConf
&
dl_net_conf
,
const
Strategy
&
strategy_conf
);
const
char
*
TypeName
()
const
override
{
return
"LogicalGraph"
;
}
private:
void
NaiveBuildGraphStruct
(
const
DLNetConf
&
dl_net_conf
,
HashMap
<
LogicalEdge
*
,
std
::
string
>*
edge2lbn
,
HashMap
<
LogicalEdge
*
,
std
::
string
>*
edge2ibn
);
void
NaiveBuildGraphStruct
(
const
DLNetConf
&
dl_net_conf
,
HashMap
<
LogicalEdge
*
,
std
::
string
>*
edge2lbn
,
HashMap
<
LogicalEdge
*
,
std
::
string
>*
edge2ibn
);
void
FillNodeWithParallelDesc
(
const
Strategy
&
strategy_conf
);
struct
CloneInfo
{
...
...
@@ -73,18 +66,14 @@ class LogicalGraph final : public Graph<LogicalNode, LogicalEdge> {
LogicalNode
*
pred_node
;
std
::
vector
<
LogicalEdge
*>
edges
;
};
void
AddCloneNodes
(
const
HashMap
<
LogicalEdge
*
,
std
::
string
>&
edge2lbn
,
const
HashMap
<
LogicalEdge
*
,
std
::
string
>&
edge2ibn
);
void
CollectCloneInfos
(
std
::
vector
<
CloneInfo
>*
clone_infos
,
const
HashMap
<
LogicalEdge
*
,
std
::
string
>&
edge2lbn
);
void
AddOneCloneNode
(
const
CloneInfo
&
clone_info
,
const
HashMap
<
LogicalEdge
*
,
std
::
string
>&
edge2ibn
);
void
AddCloneNodes
(
const
HashMap
<
LogicalEdge
*
,
std
::
string
>&
edge2lbn
,
const
HashMap
<
LogicalEdge
*
,
std
::
string
>&
edge2ibn
);
void
CollectCloneInfos
(
std
::
vector
<
CloneInfo
>*
clone_infos
,
const
HashMap
<
LogicalEdge
*
,
std
::
string
>&
edge2lbn
);
void
AddOneCloneNode
(
const
CloneInfo
&
clone_info
,
const
HashMap
<
LogicalEdge
*
,
std
::
string
>&
edge2ibn
);
};
}
// namespace oneflow
}
// namespace oneflow
#endif // ONEFLOW_CORE_GRAPH_LOGICAL_GRAPH_H_
#endif
// ONEFLOW_CORE_GRAPH_LOGICAL_GRAPH_H_
oneflow/core/graph/model_diff_accumulate_comp_task_node.cpp
浏览文件 @
03db51d7
...
...
@@ -5,8 +5,8 @@ namespace oneflow {
void
MdDiffAccCompTaskNode
::
BuildExecAndEnrollLbn2Regsts
(
TaskGraph
*
gph
)
{
CHECK
(
IsFwNode
());
auto
md_diff_acc_gph
=
static_cast
<
MdDiffAccTaskGraph
*>
(
gph
);
CompTaskNode
*
fw_task_
=
auto
md_diff_acc_gph
=
static_cast
<
MdDiffAccTaskGraph
*>
(
gph
);
CompTaskNode
*
fw_task_
=
md_diff_acc_gph
->
GetFwTaskFromParallelId
(
parallel_id
());
TaskNode
*
bp_task
=
fw_task_
->
GetBpNode
();
std
::
shared_ptr
<
RegstDesc
>
model_diff_regst
=
...
...
@@ -30,7 +30,7 @@ void MdDiffAccCompTaskNode::BuildExecAndEnrollLbn2Regsts(TaskGraph* gph) {
exec_node
->
BindBnInOpAndRegst
(
ibn
,
GetRelatedRegst
(
SoleInEdge
()));
SubscribeRegstDesc
(
ibn
,
GetRelatedRegst
(
SoleInEdge
()));
}
exec_node
->
BindBnInOpAndRegst
(
exec_node
->
op
()
->
SoleObn
(),
exec_node
->
BindBnInOpAndRegst
(
exec_node
->
op
()
->
SoleObn
(),
model_diff_acc_regst
);
mut_exec_gph
().
UpdateSourceAndSink
();
}
...
...
@@ -38,10 +38,11 @@ void MdDiffAccCompTaskNode::BuildExecAndEnrollLbn2Regsts(TaskGraph* gph) {
void
MdDiffAccCompTaskNode
::
InferShapeOfBlobsInProducedRegsts
(
TaskGraph
*
gph
)
{
CHECK
(
IsFwNode
());
if
(
!
chain_node
()
->
op_vec
().
empty
())
{
std
::
shared_ptr
<
RegstDesc
>
in_regst
=
GetSubscribedRegstDesc
(
"model_diff"
);
std
::
shared_ptr
<
RegstDesc
>
out_regst
=
GetProducedRegstDesc
(
"model_diff_acc"
);
std
::
shared_ptr
<
RegstDesc
>
in_regst
=
GetSubscribedRegstDesc
(
"model_diff"
);
std
::
shared_ptr
<
RegstDesc
>
out_regst
=
GetProducedRegstDesc
(
"model_diff_acc"
);
out_regst
->
CopyShapeFrom
(
in_regst
.
get
());
}
}
}
// namespace oneflow
}
// namespace oneflow
oneflow/core/graph/model_diff_accumulate_comp_task_node.h
浏览文件 @
03db51d7
...
...
@@ -13,23 +13,23 @@ class MdDiffAccCompTaskNode final : public CompTaskNode {
void
ToProto
(
TaskProto
*
proto
)
const
override
{
TaskNode
::
ToProto
(
proto
);
proto
->
set_parallel_policy
(
fw_task_
->
chain_node
()
->
parallel_desc
()
->
policy
());
proto
->
set_parallel_policy
(
fw_task_
->
chain_node
()
->
parallel_desc
()
->
policy
());
proto
->
set_parallel_id
(
fw_task_
->
parallel_id
());
proto
->
set_parallel_num
(
fw_task_
->
chain_node
()
->
parallel_desc
()
->
parallel_num
());
proto
->
set_parallel_num
(
fw_task_
->
chain_node
()
->
parallel_desc
()
->
parallel_num
());
}
private:
void
BuildExecAndEnrollLbn2Regsts
(
TaskGraph
*
gph
)
override
;
void
InferShapeOfBlobsInProducedRegsts
(
TaskGraph
*
gph
)
override
;
TaskType
task_type
()
const
override
{
return
kMdDiffAccCompTask
;
}
TaskType
task_type
()
const
override
{
return
kMdDiffAccCompTask
;
}
std
::
unique_ptr
<
TaskNode
>
CreateSameTypeNode
()
const
override
{
return
of_make_unique
<
MdDiffAccCompTaskNode
>
();
return
of_make_unique
<
MdDiffAccCompTaskNode
>
();
}
CompTaskNode
*
fw_task_
;
};
}
// namespace oneflow
}
// namespace oneflow
#endif // ONEFLOW_CORE_GRAPH_MODEL_DIFF_ACCUMULATE_COMP_TASK_NODE_H_
#endif
// ONEFLOW_CORE_GRAPH_MODEL_DIFF_ACCUMULATE_COMP_TASK_NODE_H_
oneflow/core/graph/model_diff_accumulate_task_graph.cpp
浏览文件 @
03db51d7
...
...
@@ -4,8 +4,7 @@
namespace
oneflow
{
MdDiffAccTaskGraph
::
MdDiffAccTaskGraph
(
const
std
::
string
&
name
,
const
ChainNode
*
data_chain
,
const
std
::
string
&
name
,
const
ChainNode
*
data_chain
,
const
std
::
vector
<
CompTaskNode
*>&
sorted_fw_comptasks4data_chain
)
{
mut_name
()
=
name
;
BuildTaskGraph
(
data_chain
);
...
...
@@ -22,7 +21,7 @@ void MdDiffAccTaskGraph::BuildTaskGraph(const ChainNode* data_chain) {
op_conf
.
mutable_model_diff_acc_conf
();
auto
model_diff_acc_op
=
OpMgr
::
Singleton
().
ConstructOp
(
op_conf
);
// ModelDiffAccChain
auto
chain_gph
=
of_make_unique
<
ChainGraph
>
();
auto
chain_gph
=
of_make_unique
<
ChainGraph
>
();
ChainNode
*
diff_acc_chain
=
chain_gph
->
NewNode
();
diff_acc_chain
->
mut_op_vec
()
=
{
model_diff_acc_op
};
auto
parallel_desc4diff_acc
=
...
...
@@ -46,4 +45,4 @@ void MdDiffAccTaskGraph::BuildTaskGraph(const ChainNode* data_chain) {
BuildFromChainGph
<
MdDiffAccCompTaskNode
>
(
std
::
move
(
chain_gph
),
false
);
}
}
// namespace oneflow
}
// namespace oneflow
oneflow/core/graph/model_diff_accumulate_task_graph.h
浏览文件 @
03db51d7
...
...
@@ -12,14 +12,13 @@ class MdDiffAccTaskGraph final : public TaskGraph {
~
MdDiffAccTaskGraph
()
=
default
;
MdDiffAccTaskGraph
(
const
std
::
string
&
name
,
const
ChainNode
*
data_chain
,
const
std
::
string
&
name
,
const
ChainNode
*
data_chain
,
const
std
::
vector
<
CompTaskNode
*>&
sorted_fw_comptasks4data_chain
);
CompTaskNode
*
GetFwTaskFromParallelId
(
int64_t
parallel_id
)
const
{
return
parallel_id2fw_task_
.
at
(
parallel_id
);
}
const
char
*
TypeName
()
const
override
{
return
"MdDiffAccTaskGraph"
;
}
private:
...
...
@@ -28,6 +27,6 @@ class MdDiffAccTaskGraph final : public TaskGraph {
HashMap
<
int64_t
,
CompTaskNode
*>
parallel_id2fw_task_
;
};
}
// namespace oneflow
}
// namespace oneflow
#endif // ONEFLOW_CORE_GRAPH_MODEL_DIFF_ACCUMULATE_TASK_GRAPH_H_
#endif
// ONEFLOW_CORE_GRAPH_MODEL_DIFF_ACCUMULATE_TASK_GRAPH_H_
oneflow/core/graph/model_save_comp_task_node.cpp
浏览文件 @
03db51d7
#include "oneflow/core/graph/model_save_comp_task_node.h"
#include "oneflow/core/graph/model_update_comp_task_node.h"
#include "oneflow/core/graph/model_save_task_graph.h"
#include "oneflow/core/graph/model_update_comp_task_node.h"
namespace
oneflow
{
void
MdSaveCompTaskNode
::
BuildExecAndEnrollLbn2Regsts
(
TaskGraph
*
gph
)
{
CHECK
(
IsFwNode
());
auto
md_save_gph
=
static_cast
<
MdSaveTaskGraph
*>
(
gph
);
auto
md_save_gph
=
static_cast
<
MdSaveTaskGraph
*>
(
gph
);
CompTaskNode
*
updt_task
=
md_save_gph
->
update_task
();
if
(
in_edges
().
empty
())
{
BindProducedRegstAndOutEdge
(
updt_task
->
GetProducedRegstDesc
(
"model"
),
...
...
@@ -36,4 +36,4 @@ void MdSaveCompTaskNode::InferShapeOfBlobsInProducedRegsts(TaskGraph* gph) {
CHECK
(
IsFwNode
());
}
}
// namespace oneflow
}
// namespace oneflow
oneflow/core/graph/model_save_comp_task_node.h
浏览文件 @
03db51d7
...
...
@@ -10,12 +10,14 @@ class MdSaveCompTaskNode final : public CompTaskNode {
OF_DISALLOW_COPY_AND_MOVE
(
MdSaveCompTaskNode
);
MdSaveCompTaskNode
()
=
default
;
~
MdSaveCompTaskNode
()
=
default
;
void
ToProto
(
TaskProto
*
proto
)
const
override
{
TaskNode
::
ToProto
(
proto
);
proto
->
set_parallel_policy
(
fw_task_
->
chain_node
()
->
parallel_desc
()
->
policy
());
proto
->
set_parallel_policy
(
fw_task_
->
chain_node
()
->
parallel_desc
()
->
policy
());
proto
->
set_parallel_id
(
fw_task_
->
parallel_id
());
proto
->
set_parallel_num
(
fw_task_
->
chain_node
()
->
parallel_desc
()
->
parallel_num
());
proto
->
set_parallel_num
(
fw_task_
->
chain_node
()
->
parallel_desc
()
->
parallel_num
());
}
void
set_fw_task
(
CompTaskNode
*
fw_task
)
{
fw_task_
=
fw_task
;
}
...
...
@@ -28,15 +30,13 @@ class MdSaveCompTaskNode final : public CompTaskNode {
return
!
GetSubscribedRegstDesc
(
"model"
);
}
TaskType
task_type
()
const
override
{
return
kMdSaveCompTask
;
}
TaskType
task_type
()
const
override
{
return
kMdSaveCompTask
;
}
std
::
unique_ptr
<
TaskNode
>
CreateSameTypeNode
()
const
override
{
return
of_make_unique
<
MdSaveCompTaskNode
>
();
return
of_make_unique
<
MdSaveCompTaskNode
>
();
}
CompTaskNode
*
fw_task_
;
};
}
// namespace oneflow
}
// namespace oneflow
#endif // ONEFLOW_CORE_GRAPH_MODEL_SAVE_COMP_TASK_NODE_H_
#endif
// ONEFLOW_CORE_GRAPH_MODEL_SAVE_COMP_TASK_NODE_H_
oneflow/core/graph/model_save_task_graph.cpp
浏览文件 @
03db51d7
...
...
@@ -13,12 +13,13 @@ MdSaveTaskGraph::MdSaveTaskGraph(const std::string& name,
}
void
MdSaveTaskGraph
::
BuildTaskGraph
()
{
auto
chain_gph
=
of_make_unique
<
ChainGraph
>
();
auto
chain_gph
=
of_make_unique
<
ChainGraph
>
();
// faker
ChainNode
*
faker_chain
=
chain_gph
->
NewNode
();
ParallelConf
faker_pr_conf
;
faker_pr_conf
.
set_policy
(
kDataParallel
);
faker_pr_conf
.
mutable_device_set
()
->
add_device_name
(
update_task_
->
device_name
());
faker_pr_conf
.
mutable_device_set
()
->
add_device_name
(
update_task_
->
device_name
());
faker_chain
->
mut_parallel_desc
().
reset
(
new
ParallelDesc
(
faker_pr_conf
));
faker_chain
->
mut_output_lbns
()
=
{
kBaledBlobName
};
// save
...
...
@@ -27,7 +28,8 @@ void MdSaveTaskGraph::BuildTaskGraph() {
GetMachineNameFromDeviceName
(
update_task_
->
device_name
());
ParallelConf
save_pr_conf
;
save_pr_conf
.
set_policy
(
kDataParallel
);
save_pr_conf
.
mutable_device_set
()
->
add_device_name
(
machine_name
+
":persistence"
);
save_pr_conf
.
mutable_device_set
()
->
add_device_name
(
machine_name
+
":persistence"
);
save_chain
->
mut_parallel_desc
().
reset
(
new
ParallelDesc
(
save_pr_conf
));
save_chain
->
mut_input_lbns
()
=
{
kBaledBlobName
};
//
...
...
@@ -40,9 +42,10 @@ void MdSaveTaskGraph::BuildTaskGraph() {
if
(
model_save_comp_task_node
!=
nullptr
)
{
auto
model_update_comp_task_node
=
static_cast
<
MdUpdtCompTaskNode
*>
(
update_task_
);
model_save_comp_task_node
->
set_fw_task
(
model_update_comp_task_node
->
fw_task
());
model_save_comp_task_node
->
set_fw_task
(
model_update_comp_task_node
->
fw_task
());
}
});
}
}
// namespace oneflow
}
// namespace oneflow
oneflow/core/graph/model_save_task_graph.h
浏览文件 @
03db51d7
...
...
@@ -11,8 +11,7 @@ class MdSaveTaskGraph final : public TaskGraph {
MdSaveTaskGraph
()
=
delete
;
~
MdSaveTaskGraph
()
=
default
;
MdSaveTaskGraph
(
const
std
::
string
&
name
,
CompTaskNode
*
update_task
);
MdSaveTaskGraph
(
const
std
::
string
&
name
,
CompTaskNode
*
update_task
);
CompTaskNode
*
update_task
()
const
{
return
update_task_
;
}
const
char
*
TypeName
()
const
override
{
return
"MdSaveTaskGraph"
;
}
...
...
@@ -23,6 +22,6 @@ class MdSaveTaskGraph final : public TaskGraph {
CompTaskNode
*
update_task_
;
};
}
// namespace oneflow
}
// namespace oneflow
#endif // ONEFLOW_CORE_GRAPH_MODEL_SAVE_TASK_GRAPH_H_
#endif
// ONEFLOW_CORE_GRAPH_MODEL_SAVE_TASK_GRAPH_H_
oneflow/core/graph/model_update_comp_task_node.cpp
浏览文件 @
03db51d7
...
...
@@ -6,7 +6,7 @@ namespace oneflow {
void
MdUpdtCompTaskNode
::
BuildExecAndEnrollLbn2Regsts
(
TaskGraph
*
gph
)
{
CHECK
(
IsFwNode
());
auto
md_updt_gph
=
static_cast
<
MdUpdtTaskGraph
*>
(
gph
);
auto
md_updt_gph
=
static_cast
<
MdUpdtTaskGraph
*>
(
gph
);
CompTaskNode
*
fw_task
=
md_updt_gph
->
fw_task
();
CompTaskNode
*
diff_acc_task
=
md_updt_gph
->
diff_acc_task
();
std
::
shared_ptr
<
RegstDesc
>
model_diff_acc_regst
;
...
...
@@ -33,4 +33,4 @@ void MdUpdtCompTaskNode::InferShapeOfBlobsInProducedRegsts(TaskGraph* gph) {
CHECK
(
IsFwNode
());
}
}
// namespace oneflow
}
// namespace oneflow
oneflow/core/graph/model_update_comp_task_node.h
浏览文件 @
03db51d7
...
...
@@ -13,9 +13,11 @@ class MdUpdtCompTaskNode final : public CompTaskNode {
void
ToProto
(
TaskProto
*
proto
)
const
override
{
TaskNode
::
ToProto
(
proto
);
proto
->
set_parallel_policy
(
fw_task_
->
chain_node
()
->
parallel_desc
()
->
policy
());
proto
->
set_parallel_policy
(
fw_task_
->
chain_node
()
->
parallel_desc
()
->
policy
());
proto
->
set_parallel_id
(
fw_task_
->
parallel_id
());
proto
->
set_parallel_num
(
fw_task_
->
chain_node
()
->
parallel_desc
()
->
parallel_num
());
proto
->
set_parallel_num
(
fw_task_
->
chain_node
()
->
parallel_desc
()
->
parallel_num
());
}
void
set_fw_task
(
CompTaskNode
*
fw_task
)
{
fw_task_
=
fw_task
;
}
...
...
@@ -24,15 +26,13 @@ class MdUpdtCompTaskNode final : public CompTaskNode {
private:
void
BuildExecAndEnrollLbn2Regsts
(
TaskGraph
*
gph
)
override
;
void
InferShapeOfBlobsInProducedRegsts
(
TaskGraph
*
gph
)
override
;
TaskType
task_type
()
const
override
{
return
kMdUpdtCompTask
;
}
TaskType
task_type
()
const
override
{
return
kMdUpdtCompTask
;
}
std
::
unique_ptr
<
TaskNode
>
CreateSameTypeNode
()
const
override
{
return
of_make_unique
<
MdUpdtCompTaskNode
>
();
return
of_make_unique
<
MdUpdtCompTaskNode
>
();
}
CompTaskNode
*
fw_task_
;
};
}
// namespace oneflow
}
// namespace oneflow
#endif // ONEFLOW_CORE_GRAPH_MODEL_UPDATE_COMP_TASK_NODE_H_
#endif
// ONEFLOW_CORE_GRAPH_MODEL_UPDATE_COMP_TASK_NODE_H_
oneflow/core/graph/model_update_task_graph.cpp
浏览文件 @
03db51d7
...
...
@@ -3,8 +3,7 @@
namespace
oneflow
{
MdUpdtTaskGraph
::
MdUpdtTaskGraph
(
const
std
::
string
&
name
,
CompTaskNode
*
fw_task
,
MdUpdtTaskGraph
::
MdUpdtTaskGraph
(
const
std
::
string
&
name
,
CompTaskNode
*
fw_task
,
CompTaskNode
*
diff_acc_task
)
{
mut_name
()
=
name
;
fw_task_
=
fw_task
;
...
...
@@ -14,9 +13,9 @@ MdUpdtTaskGraph::MdUpdtTaskGraph(const std::string& name,
}
void
MdUpdtTaskGraph
::
BuildTaskGraph
()
{
auto
chain_gph
=
of_make_unique
<
ChainGraph
>
();
auto
chain_gph
=
of_make_unique
<
ChainGraph
>
();
OperatorConf
op_conf
;
op_conf
.
set_name
(
"model_update_"
+
NewUniqueId
());
op_conf
.
set_name
(
"model_update_"
+
NewUniqueId
());
op_conf
.
mutable_model_update_conf
();
auto
model_updt_op
=
OpMgr
::
Singleton
().
ConstructOp
(
op_conf
);
...
...
@@ -39,4 +38,4 @@ void MdUpdtTaskGraph::BuildTaskGraph() {
});
}
}
// namespace oneflow
}
// namespace oneflow
oneflow/core/graph/model_update_task_graph.h
浏览文件 @
03db51d7
...
...
@@ -11,8 +11,7 @@ class MdUpdtTaskGraph final : public TaskGraph {
MdUpdtTaskGraph
()
=
delete
;
~
MdUpdtTaskGraph
()
=
default
;
MdUpdtTaskGraph
(
const
std
::
string
&
name
,
CompTaskNode
*
fw_task
,
MdUpdtTaskGraph
(
const
std
::
string
&
name
,
CompTaskNode
*
fw_task
,
CompTaskNode
*
diff_acc_task
);
CompTaskNode
*
fw_task
()
const
{
return
fw_task_
;
}
...
...
@@ -26,6 +25,6 @@ class MdUpdtTaskGraph final : public TaskGraph {
CompTaskNode
*
diff_acc_task_
;
};
}
// namespace oneflow
}
// namespace oneflow
#endif // ONEFLOW_CORE_GRAPH_MODEL_UPDATE_TASK_GRAPH_H_
#endif
// ONEFLOW_CORE_GRAPH_MODEL_UPDATE_TASK_GRAPH_H_
oneflow/core/graph/node.cpp
浏览文件 @
03db51d7
...
...
@@ -12,4 +12,4 @@ int64_t NewEdgeId() {
return
edge_id
++
;
}
}
// namespace oneflow
}
// namespace oneflow
oneflow/core/graph/node.h
浏览文件 @
03db51d7
...
...
@@ -8,9 +8,7 @@
namespace
oneflow
{
template
<
typename
NodeType
,
typename
EdgeType
>
void
Connect
(
NodeType
*
src_node
,
EdgeType
*
edge
,
NodeType
*
dst_node
)
{
void
Connect
(
NodeType
*
src_node
,
EdgeType
*
edge
,
NodeType
*
dst_node
)
{
CHECK
(
src_node
->
out_edges_
.
insert
(
edge
).
second
);
CHECK
(
dst_node
->
in_edges_
.
insert
(
edge
).
second
);
CHECK
(
edge
->
src_node_
==
nullptr
);
...
...
@@ -50,25 +48,21 @@ class Edge {
virtual
std
::
string
VisualStr
()
const
{
return
""
;
}
private:
friend
void
Connect
<
NodeType
,
EdgeType
>
(
NodeType
*
src_node
,
EdgeType
*
edge
,
friend
void
Connect
<
NodeType
,
EdgeType
>
(
NodeType
*
src_node
,
EdgeType
*
edge
,
NodeType
*
dst_node
);
friend
void
DisConnect
<
EdgeType
>
(
EdgeType
*
edge
);
int64_t
edge_id_
;
NodeType
*
src_node_
;
NodeType
*
dst_node_
;
};
template
<
typename
NodeType
,
typename
EdgeType
>
class
Node
{
public:
OF_DISALLOW_COPY_AND_MOVE
(
Node
);
Node
()
{
node_id_
=
NewNodeId
();
}
Node
()
{
node_id_
=
NewNodeId
();
}
virtual
~
Node
()
=
default
;
int64_t
node_id
()
const
{
return
node_id_
;
}
...
...
@@ -82,36 +76,26 @@ class Node {
return
*
(
out_edges_
.
begin
());
}
const
std
::
unordered_set
<
EdgeType
*>&
in_edges
()
const
{
return
in_edges_
;
}
const
std
::
unordered_set
<
EdgeType
*>&
out_edges
()
const
{
return
out_edges_
;
}
const
std
::
unordered_set
<
EdgeType
*>&
in_edges
()
const
{
return
in_edges_
;
}
const
std
::
unordered_set
<
EdgeType
*>&
out_edges
()
const
{
return
out_edges_
;
}
void
DisconnectAllEdges
()
{
for
(
EdgeType
*
edge
:
in_edges_
)
{
DisConnect
(
edge
);
}
for
(
EdgeType
*
edge
:
out_edges_
)
{
DisConnect
(
edge
);
}
for
(
EdgeType
*
edge
:
in_edges_
)
{
DisConnect
(
edge
);
}
for
(
EdgeType
*
edge
:
out_edges_
)
{
DisConnect
(
edge
);
}
}
virtual
std
::
string
VisualStr
()
const
{
return
""
;
}
private:
friend
void
Connect
<
NodeType
,
EdgeType
>
(
NodeType
*
src_node
,
EdgeType
*
edge
,
friend
void
Connect
<
NodeType
,
EdgeType
>
(
NodeType
*
src_node
,
EdgeType
*
edge
,
NodeType
*
dst_node
);
friend
void
DisConnect
<
EdgeType
>
(
EdgeType
*
edge
);
int64_t
node_id_
;
std
::
unordered_set
<
EdgeType
*>
in_edges_
;
std
::
unordered_set
<
EdgeType
*>
out_edges_
;
};
}
// namespace oneflow
}
// namespace oneflow
#endif // ONEFLOW_CORE_GRAPH_NODE_H_
#endif
// ONEFLOW_CORE_GRAPH_NODE_H_
oneflow/core/graph/out_boxing_task_node.cpp
浏览文件 @
03db51d7
...
...
@@ -4,15 +4,12 @@ namespace oneflow {
void
OutBoxingTaskNode
::
FwVirtualBuild
()
{
Chain2EdgesMap
chain2sorted_out_edges
;
FwInitChain2SortedEdgesMaps
(
&
chain2sorted_out_edges
,
&
TaskNode
::
out_edges
,
&
TaskEdge
::
dst_node
,
&
TaskNode
::
SoleOutEdge
);
FwInitChain2SortedEdgesMaps
(
&
chain2sorted_out_edges
,
&
TaskNode
::
out_edges
,
&
TaskEdge
::
dst_node
,
&
TaskNode
::
SoleOutEdge
);
ChainEdgesPair
chain_sorted_in_edges
;
chain_sorted_in_edges
.
first
=
chain_node
();
chain_sorted_in_edges
.
second
.
assign
(
in_edges
().
begin
(),
in_edges
().
end
());
FwSortEdgesInnerStage
(
&
chain_sorted_in_edges
.
second
,
&
TaskEdge
::
src_node
,
FwSortEdgesInnerStage
(
&
chain_sorted_in_edges
.
second
,
&
TaskEdge
::
src_node
,
&
TaskNode
::
SoleInEdge
);
for
(
const
ChainEdgesPair
&
chain_sorted_out_edges
:
chain2sorted_out_edges
)
{
FwBuildChainSortedEdgesPair
(
chain_sorted_in_edges
,
chain_sorted_out_edges
);
...
...
@@ -20,4 +17,4 @@ void OutBoxingTaskNode::FwVirtualBuild() {
mut_exec_gph
().
UpdateSourceAndSink
();
}
}
// namespace oneflow
}
// namespace oneflow
oneflow/core/graph/out_boxing_task_node.h
浏览文件 @
03db51d7
...
...
@@ -13,15 +13,14 @@ class OutBoxingTaskNode final : public BoxingTaskNode {
private:
std
::
unique_ptr
<
TaskNode
>
CreateSameTypeNode
()
const
override
{
return
of_make_unique
<
OutBoxingTaskNode
>
();
return
of_make_unique
<
OutBoxingTaskNode
>
();
}
void
InitWithFwNode
(
TaskNode
*
fw_node
)
override
{
BoxingTaskNode
::
InitWithFwNode
(
fw_node
);
}
void
FwVirtualBuild
()
override
;
};
}
// namespace oneflow
}
// namespace oneflow
#endif // ONEFLOW_CORE_GRAPH_OUT_BOXING_TASK_NODE_H_
#endif
// ONEFLOW_CORE_GRAPH_OUT_BOXING_TASK_NODE_H_
oneflow/core/graph/stage_graph.cpp
浏览文件 @
03db51d7
...
...
@@ -19,7 +19,7 @@ StageGraph::StageGraph(std::unique_ptr<const ChainGraph>&& chain_gph) {
size_t
device_num
=
parallel_desc
->
sorted_device_phy_ids
(
machine_id
).
size
();
if
(
device_num
==
0
)
{
device_num
=
1
;
// persistence
device_num
=
1
;
// persistence
}
range_idx
+=
device_num
;
stage_node
->
mut_parallel_range
().
mut_end
()
=
range_idx
;
...
...
@@ -43,4 +43,4 @@ StageGraph::StageGraph(std::unique_ptr<const ChainGraph>&& chain_gph) {
ToDotWithAutoFilePath
();
}
}
// namespace oneflow
}
// namespace oneflow
oneflow/core/graph/stage_graph.h
浏览文件 @
03db51d7
...
...
@@ -14,29 +14,17 @@ class StageNode final : public Node<StageNode, StageEdge> {
StageNode
()
=
default
;
~
StageNode
()
=
default
;
std
::
string
machine_id_str
()
const
{
return
std
::
to_string
(
machine_id_
);
}
const
int64_t
&
machine_id
()
const
{
return
machine_id_
;
}
int64_t
&
mut_machine_id
()
{
return
machine_id_
;
}
std
::
string
machine_id_str
()
const
{
return
std
::
to_string
(
machine_id_
);
}
const
int64_t
&
machine_id
()
const
{
return
machine_id_
;
}
int64_t
&
mut_machine_id
()
{
return
machine_id_
;
}
const
ChainNode
*
chain_node
()
const
{
return
chain_node_
;
}
const
ChainNode
*
chain_node
()
const
{
return
chain_node_
;
}
void
set_chain_node
(
const
ChainNode
*
new_chain_node
)
{
chain_node_
=
new_chain_node
;
}
const
Range
&
parallel_range
()
const
{
return
parallel_range_
;
}
Range
&
mut_parallel_range
()
{
return
parallel_range_
;
}
const
Range
&
parallel_range
()
const
{
return
parallel_range_
;
}
Range
&
mut_parallel_range
()
{
return
parallel_range_
;
}
const
std
::
vector
<
int64_t
>&
SortedDevicePhyIds
()
const
{
return
chain_node_
->
parallel_desc
()
->
sorted_device_phy_ids
(
machine_id_
);
...
...
@@ -50,7 +38,6 @@ class StageNode final : public Node<StageNode, StageEdge> {
const
ChainNode
*
chain_node_
;
int64_t
machine_id_
;
Range
parallel_range_
;
};
class
StageEdge
final
:
public
Edge
<
StageNode
,
StageEdge
>
{
...
...
@@ -58,7 +45,7 @@ class StageEdge final : public Edge<StageNode, StageEdge> {
OF_DISALLOW_COPY_AND_MOVE
(
StageEdge
);
StageEdge
()
=
default
;
~
StageEdge
()
=
default
;
private:
};
...
...
@@ -75,9 +62,8 @@ class StageGraph final : public Graph<StageNode, StageEdge> {
private:
std
::
unique_ptr
<
const
ChainGraph
>
chain_gph_
;
};
}
// namespace oneflow
}
// namespace oneflow
#endif // ONEFLOW_CORE_GRAPH_STAGE_GRAPH_H_
#endif
// ONEFLOW_CORE_GRAPH_STAGE_GRAPH_H_
oneflow/core/graph/task_graph.cpp
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/graph/task_graph.h
浏览文件 @
03db51d7
#ifndef ONEFLOW_CORE_GRAPH_TASK_GRAPH_H_
#define ONEFLOW_CORE_GRAPH_TASK_GRAPH_H_
#include "oneflow/core/graph/stage_graph.h"
#include "oneflow/core/graph/boxing_task_node.h"
#include "oneflow/core/graph/copy_task_node.h"
#include "oneflow/core/operator/operator.h"
#include "oneflow/core/operator/operator_manager.h"
#include "oneflow/core/job/parallel_desc.h"
#include "oneflow/core/graph/stage_graph.h"
#include "oneflow/core/job/id_manager.h"
#include "oneflow/core/job/job_desc.h"
#include "oneflow/core/job/parallel_desc.h"
#include "oneflow/core/operator/operator.h"
#include "oneflow/core/operator/operator_manager.h"
namespace
oneflow
{
...
...
@@ -16,22 +16,21 @@ class TaskGraph : public Graph<TaskNode, TaskEdge> {
public:
OF_DISALLOW_COPY_AND_MOVE
(
TaskGraph
);
virtual
~
TaskGraph
()
=
default
;
// Getters
const
StageGraph
*
stage_gph
()
const
{
return
stage_gph_
.
get
();
}
const
ChainGraph
*
chain_gph
()
const
{
return
stage_gph_
->
chain_gph
();
}
std
::
vector
<
CompTaskNode
*>
CompTasksInChain
(
const
ChainNode
*
);
void
InferShapeOfBlobsInProducedRegsts
();
const
std
::
string
&
name
()
const
{
return
name_
;
}
protected:
TaskGraph
()
=
default
;
template
<
typename
CompTaskNodeType
>
void
BuildFromChainGph
(
std
::
unique_ptr
<
ChainGraph
>&&
chain_gph
,
bool
need_bp
);
void
BuildFromChainGph
(
std
::
unique_ptr
<
ChainGraph
>&&
chain_gph
,
bool
need_bp
);
void
BuildExecAndEnrollLbn2Regsts
();
std
::
string
&
mut_name
()
{
return
name_
;
}
...
...
@@ -55,9 +54,8 @@ class TaskGraph : public Graph<TaskNode, TaskEdge> {
BoxingTaskNode
*
in_boxing_task_node
;
BoxingTaskNode
*
out_boxing_task_node
;
};
using
Stage2TaskNodesMap
=
HashMap
<
const
StageNode
*
,
TaskNodesInStage
>
;
using
Stage2TaskNodesMap
=
HashMap
<
const
StageNode
*
,
TaskNodesInStage
>
;
template
<
typename
TaskNodeType
>
void
InitCompTaskNodes
(
Stage2TaskNodesMap
*
stage2task_nodes
);
...
...
@@ -73,15 +71,14 @@ class TaskGraph : public Graph<TaskNode, TaskEdge> {
void
InitOutBoxingTaskNode
(
const
StageNode
*
stage
,
TaskNodesInStage
*
task_nodes_in_stage
);
void
ConnectBoxingTaskNodes
(
const
Stage2TaskNodesMap
*
stage2task_nodes
);
void
GenerateRelatedBpNodes
(
std
::
vector
<
TaskNode
*>
*
turning_node_vec
);
void
GenerateRelatedBpNodes
(
std
::
vector
<
TaskNode
*>
*
turning_node_vec
);
void
BackwardConnect
(
const
std
::
vector
<
TaskNode
*>&
turning_node_vec
);
void
BuildBpStruct
();
std
::
unique_ptr
<
const
StageGraph
>
stage_gph_
;
std
::
string
name_
;
std
::
string
name_
;
};
}
// namespace oneflow
}
// namespace oneflow
#endif // ONEFLOW_CORE_GRAPH_TASK_GRAPH_H_
#endif
// ONEFLOW_CORE_GRAPH_TASK_GRAPH_H_
oneflow/core/graph/task_node.cpp
浏览文件 @
03db51d7
...
...
@@ -2,7 +2,10 @@
namespace
oneflow
{
TaskNode
::
TaskNode
()
:
produced_regst2out_edge_
(
11
,
[](
const
std
::
weak_ptr
<
RegstDesc
>&
v
)
{
return
std
::
hash
<
void
*>
()
(
v
.
lock
().
get
());
})
{
TaskNode
::
TaskNode
()
:
produced_regst2out_edge_
(
11
,
[](
const
std
::
weak_ptr
<
RegstDesc
>&
v
)
{
return
std
::
hash
<
void
*>
()(
v
.
lock
().
get
());
})
{
stage_node_
=
nullptr
;
related_fw_or_bp_node_
=
nullptr
;
}
...
...
@@ -75,10 +78,11 @@ void TaskNode::TakeOverRegstDesc(TaskNode* rhs,
}
void
TaskNode
::
EraseProducedEmptyRegsts
()
{
EraseIf
<
std
::
string
,
std
::
shared_ptr
<
RegstDesc
>>
(
&
produced_regst_descs_
,
[]
(
HashMap
<
std
::
string
,
std
::
shared_ptr
<
RegstDesc
>>::
iterator
it
)
{
return
it
->
second
->
NumOfLbn
()
==
0
;
});
EraseIf
<
std
::
string
,
std
::
shared_ptr
<
RegstDesc
>>
(
&
produced_regst_descs_
,
[](
HashMap
<
std
::
string
,
std
::
shared_ptr
<
RegstDesc
>>::
iterator
it
)
{
return
it
->
second
->
NumOfLbn
()
==
0
;
});
}
void
TaskNode
::
EraseZeroSizeBlobInProducedRegsts
()
{
...
...
@@ -113,7 +117,7 @@ void TaskNode::BindProducedRegstAndOutEdge(std::weak_ptr<RegstDesc> regst,
std
::
shared_ptr
<
RegstDesc
>
TaskNode
::
NewProducedRegstDesc
(
const
std
::
string
&
regst_desc_name
)
{
auto
regst_desc
=
std
::
make_shared
<
RegstDesc
>
();
auto
regst_desc
=
std
::
make_shared
<
RegstDesc
>
();
regst_desc
->
SetProducer
(
this
);
regst_desc
->
set_regst_desc_id
(
IDMgr
::
Singleton
().
NewRegstDescId
());
CHECK
(
produced_regst_descs_
.
emplace
(
regst_desc_name
,
regst_desc
).
second
);
...
...
@@ -136,14 +140,16 @@ void TaskNode::ToProto(TaskProto* ret) const {
for
(
const
auto
&
pair
:
produced_regst_descs_
)
{
RegstDescProto
regst_desc_proto
;
pair
.
second
->
ToProto
(
&
regst_desc_proto
);
CHECK
(
ret
->
mutable_produced_regst_desc
()
->
insert
(
{
pair
.
first
,
regst_desc_proto
}).
second
);
CHECK
(
ret
->
mutable_produced_regst_desc
()
->
insert
({
pair
.
first
,
regst_desc_proto
})
.
second
);
}
for
(
const
auto
&
pair
:
subscribed_regst_descs_
)
{
auto
regst_desc
=
pair
.
second
.
lock
();
if
(
regst_desc
)
{
CHECK
(
ret
->
mutable_subscribed_regst_desc_id
()
->
insert
(
{
pair
.
first
,
regst_desc
->
regst_desc_id
()}).
second
);
CHECK
(
ret
->
mutable_subscribed_regst_desc_id
()
->
insert
({
pair
.
first
,
regst_desc
->
regst_desc_id
()})
.
second
);
}
}
}
...
...
@@ -159,10 +165,10 @@ std::string TaskNode::DebugStr() const {
std
::
stringstream
ss
;
ss
<<
"{"
<<
node_id_str
()
<<
"
\t
"
;
for
(
const
auto
&
pair
:
produced_regst_descs_
)
{
ss
<<
"{"
<<
pair
.
first
<<
":"
<<
pair
.
second
->
DebugStr
()
<<
"}"
;
ss
<<
"{"
<<
pair
.
first
<<
":"
<<
pair
.
second
->
DebugStr
()
<<
"}"
;
}
ss
<<
"}"
;
return
ss
.
str
();
}
}
// namespace oneflow
}
// namespace oneflow
oneflow/core/graph/task_node.h
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/job/compiler.cpp
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/job/id_manager.h
浏览文件 @
03db51d7
...
...
@@ -19,8 +19,10 @@ class IDMgr final {
machine_num_
=
resource
.
machine_size
();
CHECK_LT
(
machine_num_
,
static_cast
<
int64_t
>
(
1
)
<<
machine_id_bit_num_
);
device_num_per_machine_
=
resource
.
device_num_per_machine
();
// reserve 3 number of device_id for persistence_, boxing_ and commnet_ ThrdLocId
CHECK_LT
(
device_num_per_machine_
,
(
static_cast
<
int64_t
>
(
1
)
<<
device_id_bit_num_
)
-
3
);
// reserve 3 number of device_id for persistence_, boxing_ and commnet_
// ThrdLocId
CHECK_LT
(
device_num_per_machine_
,
(
static_cast
<
int64_t
>
(
1
)
<<
device_id_bit_num_
)
-
3
);
for
(
int64_t
i
=
0
;
i
<
machine_num_
;
++
i
)
{
const
std
::
string
&
machine_name
=
resource
.
machine
(
i
).
name
();
CHECK
(
machine_name2machine_id_
.
emplace
(
machine_name
,
i
).
second
);
...
...
@@ -51,20 +53,15 @@ class IDMgr final {
int64_t
machine_id64bit
=
machine_id
<<
(
63
-
machine_id_bit_num_
);
int64_t
device_id64bit
=
thrd_local_id
<<
task_id_bit_num_
;
int64_t
thrd_id
=
machine_id64bit
|
device_id64bit
;
CHECK_LT
(
thread_id2num_of_tasks_
[
thrd_id
],
(
static_cast
<
int64_t
>
(
1
)
<<
task_id_bit_num_
)
-
1
);
CHECK_LT
(
thread_id2num_of_tasks_
[
thrd_id
],
(
static_cast
<
int64_t
>
(
1
)
<<
task_id_bit_num_
)
-
1
);
return
thrd_id
|
(
thread_id2num_of_tasks_
[
thrd_id
]
++
);
}
int64_t
NewRegstDescId
()
{
return
regst_desc_id_count_
++
;
}
int64_t
NewRegstDescId
()
{
return
regst_desc_id_count_
++
;
}
// Runtime
int64_t
ActorId4TaskId
(
int64_t
task_id
)
{
return
task_id
;
}
int64_t
TaskId4ActorId
(
int64_t
actor_id
)
{
return
actor_id
;
}
int64_t
ActorId4TaskId
(
int64_t
task_id
)
{
return
task_id
;
}
int64_t
TaskId4ActorId
(
int64_t
actor_id
)
{
return
actor_id
;
}
int64_t
MachineId4ActorId
(
int64_t
actor_id
)
{
return
actor_id
>>
(
63
-
machine_id_bit_num_
);
}
...
...
@@ -99,4 +96,4 @@ class IDMgr final {
}
// namespace oneflow
#endif // ONEFLOW_CORE_JOB_ID_MANAGER_H_
#endif
// ONEFLOW_CORE_JOB_ID_MANAGER_H_
oneflow/core/job/id_manager_test.cpp
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/job/job_desc.cpp
浏览文件 @
03db51d7
...
...
@@ -50,4 +50,4 @@ void JobDesc::ToProto(JobDescProto* proto) const {
proto
->
set_total_batch_num
(
total_batch_num_
);
}
}
// namespace oneflow
}
// namespace oneflow
oneflow/core/job/job_desc.h
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/job/keyword.cpp
浏览文件 @
03db51d7
...
...
@@ -4,4 +4,4 @@ namespace oneflow {
const
char
*
kBaledBlobName
=
"_oneflow_BaledBlobName"
;
}
// namespace oneflow
}
// namespace oneflow
oneflow/core/job/keyword.h
浏览文件 @
03db51d7
...
...
@@ -5,6 +5,6 @@ namespace oneflow {
extern
const
char
*
kBaledBlobName
;
}
// namespace oneflow
}
// namespace oneflow
#endif // ONEFLOW_CORE_JOB_KEYWORD_H_
#endif
// ONEFLOW_CORE_JOB_KEYWORD_H_
oneflow/core/job/parallel_desc.cpp
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/job/parallel_desc.h
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/job/runtime.cpp
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/job/runtime_context.cpp
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/job/runtime_context.h
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/kernel/clone_kernel.cpp
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/kernel/clone_kernel.h
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/kernel/convolution_kernel.h
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/kernel/copy_hd_kernel.cpp
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/kernel/copy_hd_kernel_test.cpp
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/kernel/data_loader_kernel.cpp
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/kernel/data_loader_kernel.h
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/kernel/innerproduct_kernel.cpp
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/kernel/innerproduct_kernel.h
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/kernel/innerproduct_kernel_test.cpp
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/kernel/kernel.cpp
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/kernel/kernel.h
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/kernel/kernel_context.h
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/kernel/kernel_manager.cpp
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/kernel/kernel_manager.h
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/kernel/kernel_util.cpp
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/kernel/kernel_util.cu
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/kernel/kernel_util.h
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/kernel/model_save_kernel.cpp
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/kernel/model_save_kernel.h
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/memory/memory_allocator.cpp
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/memory/memory_allocator.h
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/operator/boxing_op.cpp
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/operator/boxing_op.h
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/operator/boxing_op_test.cpp
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/operator/clear_op.h
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/operator/clone_op.cpp
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/operator/clone_op.h
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/operator/clone_op_test.cpp
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/operator/concat_op.cpp
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/operator/concat_op.h
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/operator/concat_op_test.cpp
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/operator/convolution_op.cpp
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/operator/convolution_op.h
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/operator/convolution_op_test.cpp
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/operator/copy_comm_net_op.cpp
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/operator/copy_comm_net_op.h
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/operator/copy_hd_op.cpp
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/operator/copy_hd_op.h
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/operator/data_loader_op.cpp
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/operator/data_loader_op.h
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/operator/innerproduct_op.cpp
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/operator/innerproduct_op.h
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/operator/innerproduct_op_test.cpp
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/operator/model_diff_accumulate_op.cpp
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/operator/model_diff_accumulate_op.h
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/operator/model_save_op.cpp
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/operator/model_save_op.h
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/operator/model_update_op.cpp
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/operator/model_update_op.h
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/operator/multinomial_logistic_loss_op.cpp
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/operator/multinomial_logistic_loss_op.h
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/operator/multinomial_logistic_loss_op_test.cpp
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/operator/operator.cpp
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/operator/operator.h
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/operator/operator_manager.cpp
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/operator/operator_manager.h
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/operator/pooling_op.cpp
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/operator/pooling_op.h
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/operator/relu_op.cpp
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/operator/relu_op.h
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/operator/relu_op_test.cpp
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/operator/softmax_op.cpp
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/operator/softmax_op.h
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/operator/softmax_op_test.cpp
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/persistence/persistent_circular_line_reader.cpp
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/persistence/persistent_circular_line_reader.h
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/persistence/persistent_in_stream.cpp
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/persistence/persistent_in_stream.h
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/persistence/persistent_out_stream.h
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/persistence/snapshot.cpp
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/persistence/snapshot.h
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/persistence/snapshot_manager.cpp
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/persistence/snapshot_manager.h
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/persistence/snapshot_test.cpp
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/register/blob.h
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/register/local_register_warpper.h
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/register/register.cpp
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/register/register.h
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/register/register_desc.cpp
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/register/register_desc.h
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/register/register_manager.cpp
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/register/register_manager.h
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/register/register_warpper.h
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/register/remote_register_warpper.h
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/register/runtime_register_desc.cpp
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/register/runtime_register_desc.h
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/thread/cpu_thread.cpp
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/thread/cpu_thread.h
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/thread/gpu_thread.cpp
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/thread/gpu_thread.h
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/thread/thread.cpp
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/thread/thread.h
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/thread/thread_context.h
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/thread/thread_manager.cpp
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
oneflow/core/thread/thread_manager.h
浏览文件 @
03db51d7
此差异已折叠。
点击以展开。
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录