From 443b3dfd968a7506afcbfb89abf5ed45d8903971 Mon Sep 17 00:00:00 2001 From: lixinqi Date: Wed, 14 Aug 2019 00:02:22 +0800 Subject: [PATCH] copyhd-free output_op task_node --- oneflow/core/graph/task_graph.cpp | 16 ++++--- oneflow/core/job_completer/job_completer.cpp | 50 +++++++++----------- oneflow/core/operator/operator.h | 5 -- oneflow/core/operator/output_op.cpp | 1 - oneflow/core/operator/switch_output_op.cpp | 1 - 5 files changed, 33 insertions(+), 40 deletions(-) diff --git a/oneflow/core/graph/task_graph.cpp b/oneflow/core/graph/task_graph.cpp index d2f8e604c4..59cd4ba801 100644 --- a/oneflow/core/graph/task_graph.cpp +++ b/oneflow/core/graph/task_graph.cpp @@ -15,12 +15,12 @@ namespace oneflow { namespace { -bool IsOutputInterfaceTask(const TaskNode* node) { +bool IsInterfaceTask(const TaskNode* node) { const auto* comp_task_node = dynamic_cast(node); if (comp_task_node == nullptr) { return false; } if (comp_task_node->logical_node()->op_vec().size() != 1) { return false; } auto op_type_case = comp_task_node->logical_node()->SoleOp()->op_conf().op_type_case(); - return IsClassRegistered(op_type_case); + return IsClassRegistered(op_type_case); } bool IsConnectToTickOp(const TaskNode* node) { @@ -762,7 +762,7 @@ void TaskGraph::BuildTaskPath( || cur_node->MemZoneId121() != dst->MemZoneId121()) { cur_node = BuildTaskStep(cur_node, dst, GetBufTask, SetBufTask, use_buf_task_node); } - Connect(cur_node, NewEdge(), dst); + if (cur_node != dst) { Connect(cur_node, NewEdge(), dst); } } TaskNode* TaskGraph::BuildTaskStep( @@ -783,6 +783,7 @@ TaskNode* TaskGraph::BuildTaskStep( next_mem_zone_id = dst->MemZoneId121(); if (!use_buf_task_node || !(next_node = GetBufTask(cur_node->machine_id(), next_mem_zone_id))) { next_node = TryAddCopyH2DTaskTo(dst); + if (next_node == nullptr) { next_node = dst; } Connect(cur_node, NewEdge(), next_node); } } else if (cur_node->machine_id() != dst->machine_id()) { @@ -794,12 +795,14 @@ TaskNode* TaskGraph::BuildTaskStep( } else { UNIMPLEMENTED(); } - if (use_buf_task_node) { SetBufTask(next_node->machine_id(), next_mem_zone_id, next_node); } + if (use_buf_task_node && (next_node != dst)) { + SetBufTask(next_node->machine_id(), next_mem_zone_id, next_node); + } return next_node; } TaskNode* TaskGraph::TryAddCopyH2DTaskTo(TaskNode* task) { - if (IsOutputInterfaceTask(task)) { return task; } + if (IsInterfaceTask(task)) { return nullptr; } CHECK_EQ(task->device_type(), DeviceType::kGPU); CopyHdTaskNode* copy_task = NewNode(); copy_task->Init(CopyHdOpConf::H2D, task->machine_id(), task->GpuPhyId()); @@ -860,7 +863,8 @@ void TaskGraph::BuildInBoxing(const LogicalNode* logical, TaskNode* task = comp_task; if (task->device_type() == DeviceType::kGPU) { task = TryAddCopyH2DTaskTo(comp_task); - Connect(task, NewEdge(), comp_task); + if (task == nullptr) { task = comp_task; } + if (task != comp_task) { Connect(task, NewEdge(), comp_task); } } machine_id2bound_task[task->machine_id()].push_back(task); } diff --git a/oneflow/core/job_completer/job_completer.cpp b/oneflow/core/job_completer/job_completer.cpp index 0cd9429136..7141aa7caf 100644 --- a/oneflow/core/job_completer/job_completer.cpp +++ b/oneflow/core/job_completer/job_completer.cpp @@ -32,10 +32,11 @@ void WithOpGraphAndMutJob(Job* job, const std::function& Handler) { - OpGraph op_graph(job_builder->job()); - Handler(op_graph, job_builder); + OpGraph op_graph(*job); + JobBuilder job_builder(job); + Handler(op_graph, &job_builder); } void GenerateFacadeImplOpConfIf(const OpNode& op_node, JobBuilder* job_builder) { @@ -461,36 +462,31 @@ void JobCompleter::Complete(Job* job) const { // replace facade op SplitDecodeOps(job); AddRecordLoadOps(job); - auto job_builder = std::make_unique(job); - WithOpGraphAndMutJobBuilder(job_builder.get(), &ReplaceFacade); + WithOpGraphAndMutJobBuilder(job, &ReplaceFacade); // complete variable ops - WithOpGraphAndMutJobBuilder(job_builder.get(), &AutoVar); - WithOpGraphAndMutJobBuilder(job_builder.get(), &SetDefaultVariableConf); + WithOpGraphAndMutJobBuilder(job, &AutoVar); + WithOpGraphAndMutJobBuilder(job, &SetDefaultVariableConf); if (GlobalJobDesc().IsTrain()) { WithOpGraphAndMutJob(job, &TieUpChainHeadersUnReachableFromAnyVariableOps); - job_builder.reset(new JobBuilder(job)); - WithOpGraphAndMutJobBuilder(job_builder.get(), &EnableAutoMixedPrecision); + WithOpGraphAndMutJobBuilder(job, &EnableAutoMixedPrecision); // complete ops for trainning - WithOpGraphAndMutJobBuilder(job_builder.get(), &GenerateOpConf4Trainning); - WithOpGraphAndMutJobBuilder(job_builder.get(), &RewriteBoxingWithAllReduce); - WithOpGraphAndMutJobBuilder(job_builder.get(), &MakeAllReduceSequence); + WithOpGraphAndMutJobBuilder(job, &GenerateOpConf4Trainning); + WithOpGraphAndMutJobBuilder(job, &RewriteBoxingWithAllReduce); + WithOpGraphAndMutJobBuilder(job, &MakeAllReduceSequence); } - WithOpGraphAndMutJobBuilder(job_builder.get(), &DumpLogicalBlobDescAndSbpSignature); - WithOpGraphAndMutJobBuilder(job_builder.get(), &GroupBoxingByDstParallel); - WithOpGraphAndMutJobBuilder(job_builder.get(), &AddKeepHeaderOnlyOp); - WithOpGraphAndMutJobBuilder(job_builder.get(), &SetCtrlInOpName4VariableOp); + WithOpGraphAndMutJobBuilder(job, &DumpLogicalBlobDescAndSbpSignature); + WithOpGraphAndMutJobBuilder(job, &GroupBoxingByDstParallel); + WithOpGraphAndMutJobBuilder(job, &AddKeepHeaderOnlyOp); + WithOpGraphAndMutJobBuilder(job, &SetCtrlInOpName4VariableOp); // complete tick ops - job_builder.reset(new JobBuilder(job)); - WithOpGraphAndMutJobBuilder(job_builder.get(), &AutoSourceTick); - job_builder.reset(new JobBuilder(job)); - WithOpGraphAndMutJobBuilder(job_builder.get(), &AddTickForTimeShape); - job_builder.reset(new JobBuilder(job)); - WithOpGraphAndMutJobBuilder(job_builder.get(), &AutoSinkTick); - AddGlobalTotalJobCriticalSection(job_builder->job()); - WithOpGraphAndMutJobBuilder(job_builder.get(), &AddGlobalInputCriticalSections); - WithOpGraphAndMutJobBuilder(job_builder.get(), &AddGlobalOutputCriticalSections); - WithOpGraphAndMutJobBuilder(job_builder.get(), &DumpLogicalBlobDescAndSbpSignature); - WithOpGraphAndMutJobBuilder(job_builder.get(), &SetOpTimeShape7BatchDimLbis); + WithOpGraphAndMutJobBuilder(job, &AutoSourceTick); + WithOpGraphAndMutJobBuilder(job, &AddTickForTimeShape); + WithOpGraphAndMutJobBuilder(job, &AutoSinkTick); + AddGlobalTotalJobCriticalSection(*job); + WithOpGraphAndMutJobBuilder(job, &AddGlobalInputCriticalSections); + WithOpGraphAndMutJobBuilder(job, &AddGlobalOutputCriticalSections); + WithOpGraphAndMutJobBuilder(job, &DumpLogicalBlobDescAndSbpSignature); + WithOpGraphAndMutJobBuilder(job, &SetOpTimeShape7BatchDimLbis); CheckOpGraph(OpGraph(*job)); } diff --git a/oneflow/core/operator/operator.h b/oneflow/core/operator/operator.h index 8d6630f5e7..adbafa988e 100644 --- a/oneflow/core/operator/operator.h +++ b/oneflow/core/operator/operator.h @@ -313,7 +313,6 @@ struct RuntimeMemBlockNum4OpSameOutputBlob final { }; struct IsInterfaceOpConf4OpTypeCase final {}; -struct IsOutputInterfaceOpConf4OpTypeCase final {}; #define REGISTER_OP(op_type_case, OpType) \ REGISTER_CLASS_CREATOR(op_type_case, OnlyCpuSupportPredicator, \ @@ -338,10 +337,6 @@ struct IsOutputInterfaceOpConf4OpTypeCase final {}; REGISTER_CLASS_CREATOR(op_type_case, IsInterfaceOpConf4OpTypeCase, \ ([] { return new IsInterfaceOpConf4OpTypeCase(); })) -#define REGISTER_OUTPUT_INTERFACE_OP(op_type_case) \ - REGISTER_CLASS_CREATOR(op_type_case, IsOutputInterfaceOpConf4OpTypeCase, \ - ([] { return new IsOutputInterfaceOpConf4OpTypeCase(); })) - std::shared_ptr ConstructOp(const OperatorConf& op_conf); inline std::shared_ptr ConstructOp(const OperatorConf& op_conf, DeviceType device_type) { diff --git a/oneflow/core/operator/output_op.cpp b/oneflow/core/operator/output_op.cpp index 7ff370d6e4..c8a960cbb7 100644 --- a/oneflow/core/operator/output_op.cpp +++ b/oneflow/core/operator/output_op.cpp @@ -34,6 +34,5 @@ void OutputOp::GetSbpSignatures( REGISTER_OP(OperatorConf::kOutputConf, OutputOp); REGISTER_OP_SAME_OUTPUT_BLOB_MEM_BLOCK_NUM(OperatorConf::kOutputConf, 1); REGISTER_INTERFACE_OP(OperatorConf::kOutputConf); -REGISTER_OUTPUT_INTERFACE_OP(OperatorConf::kOutputConf); } // namespace oneflow diff --git a/oneflow/core/operator/switch_output_op.cpp b/oneflow/core/operator/switch_output_op.cpp index 795ad9a1cc..0392155a4a 100644 --- a/oneflow/core/operator/switch_output_op.cpp +++ b/oneflow/core/operator/switch_output_op.cpp @@ -51,6 +51,5 @@ void SwitchOutputOp::GetSbpSignatures( REGISTER_OP(OperatorConf::kSwitchOutputConf, SwitchOutputOp); REGISTER_OP_SAME_OUTPUT_BLOB_MEM_BLOCK_NUM(OperatorConf::kSwitchOutputConf, 1); REGISTER_INTERFACE_OP(OperatorConf::kSwitchOutputConf); -REGISTER_OUTPUT_INTERFACE_OP(OperatorConf::kSwitchOutputConf); } // namespace oneflow -- GitLab