Unverified · Commit 788bd1a5 · authored by cheng cheng, committed by GitHub

Fix bug of Multi-Client src tick output order (#6221)

* Fix bug of Multi-Client src tick output order

* Add input/output ctrl edges to DstSubsetTick to order IO and callback_notify

* add test scripts

* remove note

* auto format by CI

* add note of sleep

* auto format by CI
Co-authored-by: Noneflow-ci-bot <ci-bot@oneflow.org>
Parent: 86f77141
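In short: the Multi-Client bug was that, within a rank, an output's pull callback could run before that rank's source tick had executed. The fix adds per-rank ctrl edges in the task graph so that every rank executes its SrcSubsetTick, then its input/output ops, then its DstSubsetTick. A rough sketch of the ordering the new ctrl edges enforce (node names as in the diff below):

  SrcSubsetTick --ctrl--> { input/output NormalForward nodes } --ctrl--> DstSubsetTick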
@@ -199,6 +199,7 @@ Maybe<void> NNGraph::CompileAndInitRuntime() {
<< " , compile time: " << (GetCurTime() - start) / 1000000000.0 << " seconds.\n";
if (Global<ResourceDesc, ForSession>::Get()->enable_debug_mode()) {
TeePersistentLogStream::Create("job_" + name_ + "_plan")->Write(plan_);
PlanUtil::ToDotFile(plan_, "job_" + name_ + "_plan.dot");
}
// TODO(chengcheng): test collective boxing for multi-job.
PlanUtil::GenCollectiveBoxingPlan(&job_, &plan_);
...
@@ -26,6 +26,7 @@ limitations under the License.
#include "oneflow/core/job/scope.h"
#include "oneflow/core/vm/symbol_storage.h"
#include "oneflow/core/job_rewriter/calculation_pass.h"
#include "oneflow/core/job/env_desc.h"
#include "oneflow/core/graph/boxing/sub_task_graph_builder_util.h"
#include "oneflow/core/graph/boxing/hierarchical_sub_task_graph_builder_impl.h"
#include "oneflow/core/graph/stream_index_getter_registry_manager.h"
@@ -542,6 +543,48 @@ void TaskGraph::ConnectCtrlEdges(const std::vector<CompTaskNode*>& src_task_node
  }
}

void TaskGraph::AddCtrlEdgeBetweenSrcDstTickAndInputOutputInSameRank() {
  if (!CHECK_JUST(GlobalMultiClientEnv())) { return; }
  HashMap<int64_t, TaskNode*> rank_id2src_tick;
  HashMap<int64_t, TaskNode*> rank_id2dst_tick;
  HashMap<int64_t, HashSet<TaskNode*>> rank_id2input_output_nodes;
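  // Collect, for each rank (machine_id), its SrcSubsetTick node, its DstSubsetTick node,
  // and every NormalForward node that is an input or output op.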
  ForEachNode([&](TaskNode* node) {
    if (node->GetTaskType() == TaskType::kSrcSubsetTick) {
      CHECK(rank_id2src_tick.emplace(node->machine_id(), node).second);
    } else if (node->GetTaskType() == TaskType::kDstSubsetTick) {
      CHECK(rank_id2dst_tick.emplace(node->machine_id(), node).second);
    } else if (node->GetTaskType() == TaskType::kNormalForward) {
      // dynamic_cast (rather than reinterpret_cast) so the CHECK below actually
      // validates the node type.
      auto* forward_node = dynamic_cast<NormalForwardCompTaskNode*>(node);
      CHECK(forward_node);
      if (forward_node->op()->op_conf().has_input_conf()
          || forward_node->op()->op_conf().has_output_conf()) {
        CHECK(rank_id2input_output_nodes[node->machine_id()].insert(node).second);
      }
    }
  });
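  // Build a ctrl regst from src to dst: dst consumes a control regst produced by src,
  // so dst cannot start executing before src has finished.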
  auto AddCtrlEdge = [&](TaskNode* src, TaskNode* dst) {
    std::string ctrl_regst_name;
    src->BuildCtrlRegstDesc(dst, &ctrl_regst_name);
    TaskEdge* edge = NewEdge();
    Connect<TaskNode>(src, edge, dst);
    src->BindEdgeWithProducedRegst(edge, ctrl_regst_name);
  };
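  // Within each rank, force SrcSubsetTick to run before every input/output node, and
  // every input/output node to run before DstSubsetTick.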
  for (auto& pair : rank_id2src_tick) {
    int64_t rank_id = pair.first;
    TaskNode* src = pair.second;
    for (TaskNode* io_task : rank_id2input_output_nodes[rank_id]) { AddCtrlEdge(src, io_task); }
  }
  for (auto& pair : rank_id2dst_tick) {
    int64_t rank_id = pair.first;
    TaskNode* dst = pair.second;
    for (TaskNode* io_task : rank_id2input_output_nodes[rank_id]) { AddCtrlEdge(io_task, dst); }
  }
}

void TaskGraph::RemoveEmptyRegsts() {
  ForEachNode([&](TaskNode* node) { node->EraseUninitializedShapeProducedBlob(); });
  ForEachNode([&](TaskNode* node) { node->EraseZeroSizeConsumedRegst(); });
...
@@ -47,6 +47,7 @@ class TaskGraph final : public Graph<TaskNode, TaskEdge> {
  const char* TypeName() const override { return "TaskGraph"; }
  void RemoveEmptyRegsts();
  void AddCtrlEdgeBetweenSrcDstTickAndInputOutputInSameRank();
  void MergeChainAndAddOrderingCtrlEdgeInSameChain();
  void EnableInplaceMemSharing(const std::function<bool(const std::string&, const std::string&)>&
...
@@ -68,6 +68,10 @@ void Compiler::Compile(Job* job, Plan* plan, bool need_job_complete) const {
  task_gph->ForEachNode(std::bind(&TaskNode::PinConsumedRegst, _1));
  task_gph->TopoForEachNode(&TaskNode::Build);
  task_gph->RemoveEmptyRegsts();
  // NOTE(chengcheng):
  //   In Multi-Client mode, each rank has its own src_tick/dst_tick and input/output ops
  //   with callbacks, whose execution must be forcibly sequenced.
  task_gph->AddCtrlEdgeBetweenSrcDstTickAndInputOutputInSameRank();
  task_gph->MergeChainAndAddOrderingCtrlEdgeInSameChain();
  auto IsReachable = Global<OpGraph>::Get()->MakePredicatorIsOpNameDataOrCtrlReachable();
  if (job_desc.enable_inplace()) { task_gph->EnableInplaceMemSharing(IsReachable); }
...
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
import time
import unittest
import numpy as np
import oneflow as flow
import oneflow.unittest
def _test_graph_pipeline_delay_output(test_case):
    class StageLayerModule(flow.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear1 = flow.nn.Linear(10, 8, False)
            self.linear2 = flow.nn.Linear(8, 10)
            flow.nn.init.constant_(self.linear1.weight, 0.023)
            flow.nn.init.constant_(self.linear2.weight, 1.23)

        def forward(self, x):
            out0 = self.linear1(x)
            out0 = out0 + 1.0
            out0 = out0 * 2.0
            out1 = self.linear2(out0)
            return out1
    P0 = flow.placement("cuda", {0: [0]})
    P1 = flow.placement("cuda", {0: [1]})
    B = flow.sbp.broadcast
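    # P0/P1 place stage 0 on device 0 and stage 1 on device 1 of node 0; B replicates
    # tensors across each placement (broadcast SBP).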
    class PipelineModule(flow.nn.Module):
        def __init__(self):
            super().__init__()
            self.layer_0 = StageLayerModule()
            self.layer_1 = StageLayerModule()
            self.layer_0.to_consistent(P0, B)
            self.layer_1.to_consistent(P1, B)

        def forward(self, x):
            # stage 0
            in0 = x.to_consistent(P0, B)
            out0 = self.layer_0(in0)
            # stage 1
            in1 = out0.to_consistent(P1, B)
            out1 = self.layer_1(in1)
            return out1
    pp_m = PipelineModule()
    pp_m.train()
    of_sgd = flow.optim.SGD(pp_m.parameters(), lr=0.001)
    class PipelineGraph(flow.nn.Graph):
        def __init__(self):
            super().__init__()
            self.pp_m = pp_m
            self.pp_m.layer_0.config.stage_id = 0
            self.pp_m.layer_1.config.stage_id = 1
            self.config.set_gradient_accumulation_steps(4)
            self.add_optimizer(of_sgd)

        def build(self, x, y):
            pp_out = self.pp_m(x)
            loss = pp_out.mean()
            loss.backward()
            y = x + y
            free_out = y.to_consistent(P1, B)
            return loss, free_out
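    # free_out lives on placement P1, so its pull callback fires only on rank 1; it is
    # exactly the output whose ordering against the rank-1 src/dst tick this test checks.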
    pp_g = PipelineGraph()
    rank = flow.env.get_rank()
    for i in range(3):
        x = flow.randn(16, 10)
        y = flow.randn(16, 10)
        x = x.to_consistent(P0, B)
        y = y.to_consistent(P0, B)
        if rank == 1:
            time.sleep(2)
        loss_pack_4, free_out = pp_g(x, y)
        if rank == 1:
            # NOTE(chengcheng): Before Oneflow-Inc/oneflow#6221 fixed the src/dst tick
            # order with respect to input/output, the sleep on rank 1 exposed the bug:
            # free_out is an output only on rank 1, but it was NOT controlled by the
            # rank 1 src/dst tick, so with a manual sleep on rank 1 the free_out pull
            # callback could execute before the rank 1 src tick and hit the bug of an
            # empty output_kernel buffer status.
            # After this fix, this test case ensures that the src/dst tick and the
            # input/output callbacks execute in the expected order on each rank.
            time.sleep(2)
        print(
            "rank: ",
            rank,
            "packed loss with 4 micro-batch = ",
            loss_pack_4.to_local(),
        )
        print(
            "rank: ",
            rank,
            "packed free_out with 4 micro-batch = ",
            free_out.to_local(),
        )
@unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases")
@flow.unittest.skip_unless_1n2d()
class TestGraphPipelineDelayOutput(oneflow.unittest.TestCase):
    def test_graph_pipeline_delay_output(test_case):
        _test_graph_pipeline_delay_output(test_case)


if __name__ == "__main__":
    unittest.main()
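The skip_unless_1n2d decorator requires 1 node with 2 devices, so the test would presumably be launched through OneFlow's distributed launcher, along these lines (the launcher invocation and the script file name are assumptions, not part of this diff):

python3 -m oneflow.distributed.launch --nproc_per_node 2 test_graph_pipeline_delay_output.py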