提交 443b3dfd 编写于 作者: L lixinqi

Make output_op task nodes copyhd-free

上级 2fab3c60
......@@ -15,12 +15,12 @@ namespace oneflow {
namespace {
bool IsOutputInterfaceTask(const TaskNode* node) {
bool IsInterfaceTask(const TaskNode* node) {
const auto* comp_task_node = dynamic_cast<const CompTaskNode*>(node);
if (comp_task_node == nullptr) { return false; }
if (comp_task_node->logical_node()->op_vec().size() != 1) { return false; }
auto op_type_case = comp_task_node->logical_node()->SoleOp()->op_conf().op_type_case();
return IsClassRegistered<IsOutputInterfaceOpConf4OpTypeCase>(op_type_case);
return IsClassRegistered<IsInterfaceOpConf4OpTypeCase>(op_type_case);
}
bool IsConnectToTickOp(const TaskNode* node) {
......@@ -762,7 +762,7 @@ void TaskGraph::BuildTaskPath(
|| cur_node->MemZoneId121() != dst->MemZoneId121()) {
cur_node = BuildTaskStep(cur_node, dst, GetBufTask, SetBufTask, use_buf_task_node);
}
Connect<TaskNode>(cur_node, NewEdge(), dst);
if (cur_node != dst) { Connect<TaskNode>(cur_node, NewEdge(), dst); }
}
TaskNode* TaskGraph::BuildTaskStep(
......@@ -783,6 +783,7 @@ TaskNode* TaskGraph::BuildTaskStep(
next_mem_zone_id = dst->MemZoneId121();
if (!use_buf_task_node || !(next_node = GetBufTask(cur_node->machine_id(), next_mem_zone_id))) {
next_node = TryAddCopyH2DTaskTo(dst);
if (next_node == nullptr) { next_node = dst; }
Connect<TaskNode>(cur_node, NewEdge(), next_node);
}
} else if (cur_node->machine_id() != dst->machine_id()) {
......@@ -794,12 +795,14 @@ TaskNode* TaskGraph::BuildTaskStep(
} else {
UNIMPLEMENTED();
}
if (use_buf_task_node) { SetBufTask(next_node->machine_id(), next_mem_zone_id, next_node); }
if (use_buf_task_node && (next_node != dst)) {
SetBufTask(next_node->machine_id(), next_mem_zone_id, next_node);
}
return next_node;
}
TaskNode* TaskGraph::TryAddCopyH2DTaskTo(TaskNode* task) {
if (IsOutputInterfaceTask(task)) { return task; }
if (IsInterfaceTask(task)) { return nullptr; }
CHECK_EQ(task->device_type(), DeviceType::kGPU);
CopyHdTaskNode* copy_task = NewNode<CopyHdTaskNode>();
copy_task->Init(CopyHdOpConf::H2D, task->machine_id(), task->GpuPhyId());
......@@ -860,7 +863,8 @@ void TaskGraph::BuildInBoxing(const LogicalNode* logical,
TaskNode* task = comp_task;
if (task->device_type() == DeviceType::kGPU) {
task = TryAddCopyH2DTaskTo(comp_task);
Connect<TaskNode>(task, NewEdge(), comp_task);
if (task == nullptr) { task = comp_task; }
if (task != comp_task) { Connect<TaskNode>(task, NewEdge(), comp_task); }
}
machine_id2bound_task[task->machine_id()].push_back(task);
}
......
......@@ -32,10 +32,11 @@ void WithOpGraphAndMutJob(Job* job, const std::function<void(const OpGraph&, Job
Handler(op_graph, job);
}
void WithOpGraphAndMutJobBuilder(JobBuilder* job_builder,
void WithOpGraphAndMutJobBuilder(Job* job,
const std::function<void(const OpGraph&, JobBuilder*)>& Handler) {
OpGraph op_graph(job_builder->job());
Handler(op_graph, job_builder);
OpGraph op_graph(*job);
JobBuilder job_builder(job);
Handler(op_graph, &job_builder);
}
void GenerateFacadeImplOpConfIf(const OpNode& op_node, JobBuilder* job_builder) {
......@@ -461,36 +462,31 @@ void JobCompleter::Complete(Job* job) const {
// replace facade op
SplitDecodeOps(job);
AddRecordLoadOps(job);
auto job_builder = std::make_unique<JobBuilder>(job);
WithOpGraphAndMutJobBuilder(job_builder.get(), &ReplaceFacade);
WithOpGraphAndMutJobBuilder(job, &ReplaceFacade);
// complete variable ops
WithOpGraphAndMutJobBuilder(job_builder.get(), &AutoVar);
WithOpGraphAndMutJobBuilder(job_builder.get(), &SetDefaultVariableConf);
WithOpGraphAndMutJobBuilder(job, &AutoVar);
WithOpGraphAndMutJobBuilder(job, &SetDefaultVariableConf);
if (GlobalJobDesc().IsTrain()) {
WithOpGraphAndMutJob(job, &TieUpChainHeadersUnReachableFromAnyVariableOps);
job_builder.reset(new JobBuilder(job));
WithOpGraphAndMutJobBuilder(job_builder.get(), &EnableAutoMixedPrecision);
WithOpGraphAndMutJobBuilder(job, &EnableAutoMixedPrecision);
// complete ops for trainning
WithOpGraphAndMutJobBuilder(job_builder.get(), &GenerateOpConf4Trainning);
WithOpGraphAndMutJobBuilder(job_builder.get(), &RewriteBoxingWithAllReduce);
WithOpGraphAndMutJobBuilder(job_builder.get(), &MakeAllReduceSequence);
WithOpGraphAndMutJobBuilder(job, &GenerateOpConf4Trainning);
WithOpGraphAndMutJobBuilder(job, &RewriteBoxingWithAllReduce);
WithOpGraphAndMutJobBuilder(job, &MakeAllReduceSequence);
}
WithOpGraphAndMutJobBuilder(job_builder.get(), &DumpLogicalBlobDescAndSbpSignature);
WithOpGraphAndMutJobBuilder(job_builder.get(), &GroupBoxingByDstParallel);
WithOpGraphAndMutJobBuilder(job_builder.get(), &AddKeepHeaderOnlyOp);
WithOpGraphAndMutJobBuilder(job_builder.get(), &SetCtrlInOpName4VariableOp);
WithOpGraphAndMutJobBuilder(job, &DumpLogicalBlobDescAndSbpSignature);
WithOpGraphAndMutJobBuilder(job, &GroupBoxingByDstParallel);
WithOpGraphAndMutJobBuilder(job, &AddKeepHeaderOnlyOp);
WithOpGraphAndMutJobBuilder(job, &SetCtrlInOpName4VariableOp);
// complete tick ops
job_builder.reset(new JobBuilder(job));
WithOpGraphAndMutJobBuilder(job_builder.get(), &AutoSourceTick);
job_builder.reset(new JobBuilder(job));
WithOpGraphAndMutJobBuilder(job_builder.get(), &AddTickForTimeShape);
job_builder.reset(new JobBuilder(job));
WithOpGraphAndMutJobBuilder(job_builder.get(), &AutoSinkTick);
AddGlobalTotalJobCriticalSection(job_builder->job());
WithOpGraphAndMutJobBuilder(job_builder.get(), &AddGlobalInputCriticalSections);
WithOpGraphAndMutJobBuilder(job_builder.get(), &AddGlobalOutputCriticalSections);
WithOpGraphAndMutJobBuilder(job_builder.get(), &DumpLogicalBlobDescAndSbpSignature);
WithOpGraphAndMutJobBuilder(job_builder.get(), &SetOpTimeShape7BatchDimLbis);
WithOpGraphAndMutJobBuilder(job, &AutoSourceTick);
WithOpGraphAndMutJobBuilder(job, &AddTickForTimeShape);
WithOpGraphAndMutJobBuilder(job, &AutoSinkTick);
AddGlobalTotalJobCriticalSection(*job);
WithOpGraphAndMutJobBuilder(job, &AddGlobalInputCriticalSections);
WithOpGraphAndMutJobBuilder(job, &AddGlobalOutputCriticalSections);
WithOpGraphAndMutJobBuilder(job, &DumpLogicalBlobDescAndSbpSignature);
WithOpGraphAndMutJobBuilder(job, &SetOpTimeShape7BatchDimLbis);
CheckOpGraph(OpGraph(*job));
}
......
......@@ -313,7 +313,6 @@ struct RuntimeMemBlockNum4OpSameOutputBlob final {
};
struct IsInterfaceOpConf4OpTypeCase final {};
struct IsOutputInterfaceOpConf4OpTypeCase final {};
#define REGISTER_OP(op_type_case, OpType) \
REGISTER_CLASS_CREATOR(op_type_case, OnlyCpuSupportPredicator, \
......@@ -338,10 +337,6 @@ struct IsOutputInterfaceOpConf4OpTypeCase final {};
REGISTER_CLASS_CREATOR(op_type_case, IsInterfaceOpConf4OpTypeCase, \
([] { return new IsInterfaceOpConf4OpTypeCase(); }))
#define REGISTER_OUTPUT_INTERFACE_OP(op_type_case) \
REGISTER_CLASS_CREATOR(op_type_case, IsOutputInterfaceOpConf4OpTypeCase, \
([] { return new IsOutputInterfaceOpConf4OpTypeCase(); }))
std::shared_ptr<Operator> ConstructOp(const OperatorConf& op_conf);
inline std::shared_ptr<Operator> ConstructOp(const OperatorConf& op_conf, DeviceType device_type) {
......
......@@ -34,6 +34,5 @@ void OutputOp::GetSbpSignatures(
REGISTER_OP(OperatorConf::kOutputConf, OutputOp);
REGISTER_OP_SAME_OUTPUT_BLOB_MEM_BLOCK_NUM(OperatorConf::kOutputConf, 1);
REGISTER_INTERFACE_OP(OperatorConf::kOutputConf);
REGISTER_OUTPUT_INTERFACE_OP(OperatorConf::kOutputConf);
} // namespace oneflow
......@@ -51,6 +51,5 @@ void SwitchOutputOp::GetSbpSignatures(
REGISTER_OP(OperatorConf::kSwitchOutputConf, SwitchOutputOp);
REGISTER_OP_SAME_OUTPUT_BLOB_MEM_BLOCK_NUM(OperatorConf::kSwitchOutputConf, 1);
REGISTER_INTERFACE_OP(OperatorConf::kSwitchOutputConf);
REGISTER_OUTPUT_INTERFACE_OP(OperatorConf::kSwitchOutputConf);
} // namespace oneflow
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册