task_graph.cpp 31.1 KB
Newer Older
S
Shenghang Tsai 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
J
jiyuan 已提交
16
#include "oneflow/core/graph/task_graph.h"
J
Jinhui Yuan 已提交
17 18
#include "oneflow/core/graph/chain_graph.h"
#include "oneflow/core/common/util.h"
L
Li Xinqi 已提交
19
#include "oneflow/core/graph/inplace_lbi_graph.h"
20
#include "oneflow/core/register/runtime_blob_desc.h"
J
Jinhui Yuan 已提交
21
#include "oneflow/core/job/thrd_id_generator.h"
22
#include "oneflow/core/job/global_for.h"
L
Li Xinqi 已提交
23
#include "oneflow/core/operator/variable_op.h"
J
Juncheng 已提交
24
#include "oneflow/core/graph/op_graph.h"
C
cheng cheng 已提交
25
#include "oneflow/core/graph/normal_forward_compute_task_node.h"
J
Juncheng 已提交
26 27 28
#include "oneflow/core/graph/boxing/sub_task_graph_builder_context.h"
#include "oneflow/core/graph/boxing/sub_task_graph_builder.h"
#include "oneflow/core/graph/boxing/chain_sub_task_graph_builder.h"
J
Juncheng 已提交
29
#include "oneflow/core/graph/boxing/collective_boxing_sub_task_graph_builder.h"
L
Li Xinqi 已提交
30
#include "oneflow/core/graph/boxing/slice_boxing_sub_task_graph_builder.h"
J
Juncheng 已提交
31
#include "oneflow/core/graph/boxing/naive_b2b_sub_task_graph_builder.h"
J
Juncheng 已提交
32
#include "oneflow/core/graph/boxing/naive_b2p_sub_task_graph_builder.h"
J
Juncheng 已提交
33
#include "oneflow/core/graph/boxing/b21_sub_task_graph_builder.h"
J
Juncheng 已提交
34
#include "oneflow/core/graph/boxing/one_to_one_sub_task_graph_builder.h"
J
Juncheng 已提交
35
#include "oneflow/core/graph/boxing/sub_task_graph_builder_util.h"
36
#include "oneflow/core/graph/boxing_identity_task_node.h"
W
willzhang4a58 已提交
37 38 39

namespace oneflow {

L
Li Xinqi 已提交
40 41
namespace {

L
lixinqi 已提交
42
bool IsInterfaceTask(const TaskNode* node) {
L
Li Xinqi 已提交
43 44 45 46
  const auto* comp_task_node = dynamic_cast<const CompTaskNode*>(node);
  if (comp_task_node == nullptr) { return false; }
  if (comp_task_node->logical_node()->op_vec().size() != 1) { return false; }
  auto op_type_case = comp_task_node->logical_node()->SoleOp()->op_conf().op_type_case();
47
  return IsClassRegistered<int32_t, IsInterfaceOpConf4OpTypeCase>(op_type_case);
L
Li Xinqi 已提交
48 49
}

L
Li Xinqi 已提交
50 51 52 53 54 55 56 57 58
bool IsConnectToTickOp(const TaskNode* node) {
  const auto* comp_task_node = dynamic_cast<const CompTaskNode*>(node);
  if (comp_task_node == nullptr) { return false; }
  if (comp_task_node->logical_node()->op_vec().size() != 1) { return false; }
  const Operator* op = comp_task_node->logical_node()->SoleOp().get();
  if (dynamic_cast<const VariableOp*>(op) != nullptr) { return true; }
  return false;
}

C
cheng cheng 已提交
59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117
bool IsSpecialOpNotConsiderMergeInChain(const Operator* op) {
  const OperatorConf& op_conf = op->op_conf();
  if (op_conf.has_variable_conf() || op_conf.has_keep_header_only_conf() || op_conf.has_tick_conf()
      || op_conf.has_device_tick_conf() || op_conf.has_partial_tick_conf()) {
    return true;
  }
  return false;
}

// Returns true if at least one register produced by `node` has min_register_num() > 1.
bool IsTaskNodeProducedResgtHasMultiRegstNum(const TaskNode* node) {
  for (const auto& name_and_regst : node->produced_regsts()) {
    const auto& produced_regst = name_and_regst.second;
    if (produced_regst->min_register_num() > 1) { return true; }
  }
  return false;
}

// A task node may join a chain only if all of the following hold:
//   * none of its produced registers needs more than one register instance;
//   * it is a NormalForwardCompTaskNode wrapping exactly one operator;
//   * it runs on a GPU device;
//   * its operator is not one of the special ops rejected by IsSpecialOpNotConsiderMergeInChain.
bool CanBeMergedInChain(const TaskNode* node) {
  if (IsTaskNodeProducedResgtHasMultiRegstNum(node)) { return false; }
  const auto* fw_node = dynamic_cast<const NormalForwardCompTaskNode*>(node);
  if (fw_node == nullptr) { return false; }
  const bool is_single_op_on_gpu = fw_node->logical_node()->op_vec().size() == 1
                                   && fw_node->device_type() == DeviceType::kGPU;
  if (!is_single_op_on_gpu) { return false; }
  return !IsSpecialOpNotConsiderMergeInChain(fw_node->logical_node()->SoleOp().get());
}

// Breadth-first walk from `this_node` over both in- and out-edges that assigns `this_chain_id`
// to every reachable node which (a) passes CanBeMergedInChain, (b) has the same
// GlobalWorkStreamId as `this_node` and (c) has the same area_id as `this_node`.
// A qualifying neighbour that already carries a chain id must carry `this_chain_id`.
void TraverseConnectedSubGraphMergeInThisChain(TaskNode* this_node, const int64_t this_chain_id) {
  CHECK_NE(this_chain_id, -1);
  CHECK_EQ(this_node->chain_id(), -1);
  // NOTE(chengcheng): use area_id to not merge optimizer ops with fw/bw ops
  auto IsMergeableWithThisNode = [this_node](TaskNode* candidate) -> bool {
    return CanBeMergedInChain(candidate)
           && candidate->GlobalWorkStreamId() == this_node->GlobalWorkStreamId()
           && candidate->area_id() == this_node->area_id();
  };
  HashSet<TaskNode*> visited_nodes;
  std::queue<TaskNode*> pending;
  auto Enqueue = [&](TaskNode* node) {
    pending.push(node);
    visited_nodes.insert(node);
  };
  Enqueue(this_node);
  while (!pending.empty()) {
    TaskNode* cur_node = pending.front();
    pending.pop();
    CHECK_EQ(cur_node->chain_id(), -1);
    cur_node->set_chain_id(this_chain_id);
    cur_node->ForEachNodeOnInOutEdge([&](TaskNode* next_node) {
      const bool is_new_candidate =
          visited_nodes.count(next_node) == 0 && IsMergeableWithThisNode(next_node);
      if (!is_new_candidate) { return; }
      if (next_node->chain_id() != -1) {
        CHECK_EQ(next_node->chain_id(), this_chain_id);
      } else {
        Enqueue(next_node);
      }
    });
  }
}

L
Li Xinqi 已提交
118 119 120 121 122 123 124 125 126 127 128 129 130 131 132
// Indexes the task nodes of `task_nodes` whose exec graph has exactly one exec node by the name of
// that exec node's op. The returned getter yields the task node for an op name, or nullptr when
// the name is unknown or more than one task node runs an op with that name.
std::function<TaskNode*(const std::string&)> MakeGetterTaskNode4SoleOpName(
    const HashSet<TaskNode*>& task_nodes) {
  // Held through a shared_ptr so the returned closure owns the index.
  auto op_name2task_nodes = std::make_shared<HashMap<std::string, HashSet<TaskNode*>>>();
  for (TaskNode* task_node : task_nodes) {
    if (task_node->exec_gph().node_num() != 1) { continue; }
    ExecNode* exec_node = task_node->exec_gph().SoleNode();
    CHECK((*op_name2task_nodes)[exec_node->op()->op_name()].emplace(task_node).second);
  }
  return [op_name2task_nodes](const std::string& op_name) -> TaskNode* {
    auto found = op_name2task_nodes->find(op_name);
    const bool is_unique = found != op_name2task_nodes->end() && found->second.size() <= 1;
    return is_unique ? *found->second.begin() : nullptr;
  };
}
L
Li Xinqi 已提交
134 135 136 137 138 139 140 141 142

// Returns true if any register carried by `edge` holds the blob identified by `lbi`.
bool IsLbiOnTaskEdge(const TaskEdge* edge, const LogicalBlobId& lbi) {
  bool carried = false;
  for (const auto& regst : edge->GetRegsts()) {
    carried = regst->HasLbi(lbi);
    if (carried) { break; }
  }
  return carried;
}

std::function<bool(const LogicalBlobId&, const std::string&)>
S
scxfjiang 已提交
143
MakePredicatorIsLbiAllConsumersReachable(
J
Juncheng 已提交
144 145 146
    const std::function<const TaskNode*(const std::string&)>& TaskNode4SoleOpName,
    const std::function<bool(const std::string&, const std::string&)>&
        IsOpNameDataOrCtrlReachable) {
S
scxfjiang 已提交
147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163
  auto IsDataOrCtrlReachable = [IsOpNameDataOrCtrlReachable](const TaskNode* src_node,
                                                             const TaskNode* dst_node) -> bool {
    if (src_node->chain_id() == dst_node->chain_id()
        && src_node->order_in_graph() <= dst_node->order_in_graph()) {
      return true;
    }
    const CompTaskNode* comp_src_node = dynamic_cast<const CompTaskNode*>(src_node);
    if (comp_src_node == nullptr) { return false; }
    if (comp_src_node->logical_node()->op_vec().size() != 1) { return false; }
    const CompTaskNode* comp_dst_node = dynamic_cast<const CompTaskNode*>(dst_node);
    if (comp_dst_node == nullptr) { return false; }
    if (comp_dst_node->logical_node()->op_vec().size() != 1) { return false; }
    return IsOpNameDataOrCtrlReachable(comp_src_node->logical_node()->SoleOp()->op_name(),
                                       comp_dst_node->logical_node()->SoleOp()->op_name());
  };
  return [TaskNode4SoleOpName, IsDataOrCtrlReachable](const LogicalBlobId& lbi,
                                                      const std::string& op_name) -> bool {
L
Li Xinqi 已提交
164 165 166
    const TaskNode* src_task_node = TaskNode4SoleOpName(lbi.op_name());
    const TaskNode* dst_task_node = TaskNode4SoleOpName(op_name);
    size_t out_edges_size = 0;
L
Li Xinqi 已提交
167
    size_t reachable_out_edges_size = 0;
L
Li Xinqi 已提交
168 169 170
    for (TaskEdge* out_edge : src_task_node->out_edges()) {
      if (IsLbiOnTaskEdge(out_edge, lbi)) {
        out_edges_size += 1;
S
scxfjiang 已提交
171
        reachable_out_edges_size += IsDataOrCtrlReachable(out_edge->dst_node(), dst_task_node);
L
Li Xinqi 已提交
172 173
      }
    }
L
Li Xinqi 已提交
174
    return out_edges_size > 0 && out_edges_size == reachable_out_edges_size;
L
Li Xinqi 已提交
175 176 177 178 179 180 181 182 183 184 185 186 187 188
  };
}

// Decides whether the blobs named by `bns` on the sole exec node of `task_node` may share memory
// in place. This requires all of the following:
//   * `task_node` has exactly one exec node;
//   * for every bn, the op producing its lbi resolves through `TaskNode4SoleOpName`, and the
//     register behind the bn holds exactly one blob;
//   * every blob has the same element count and data type as the first one.
bool IsInplaceAllowed(
    TaskNode* task_node, const std::vector<std::string>& bns,
    const std::function<const TaskNode*(const std::string&)>& TaskNode4SoleOpName) {
  if (task_node->exec_gph().node_num() != 1) { return false; }
  const auto& exec_node = *task_node->exec_gph().SoleNode();
  for (const std::string& bn : bns) {
    // TaskNode for bn is not nullptr if it's on the same device with `task_node`
    if (TaskNode4SoleOpName(exec_node.op()->BnInOp2Lbi(bn).op_name()) == nullptr) { return false; }
    if (exec_node.RegstDesc4BnInOp(bn)->NumOfLbi() != 1) { return false; }
  }
  const BlobDesc* reference_blob = nullptr;
  for (const std::string& bn : bns) {
    const BlobDesc* blob = exec_node.RegstDesc4BnInOp(bn)->SoleBlobDesc();
    if (reference_blob == nullptr) {
      reference_blob = blob;
      continue;
    }
    const bool same_elem_cnt = reference_blob->shape().elem_cnt() == blob->shape().elem_cnt();
    const bool same_data_type = reference_blob->data_type() == blob->data_type();
    if (!same_elem_cnt || !same_data_type) { return false; }
  }
  return true;
}

qq_22305325's avatar
qq_22305325 已提交
204 205 206 207 208 209 210 211 212
std::unique_ptr<BoxingLogger> CreateBoxingLogger() {
  if (Global<ResourceDesc, ForSession>::Get()->enable_debug_mode()) {
    return std::unique_ptr<BoxingLogger>(
        new CsvBoxingLogger(StrCat("boxing/log/", GlobalJobDesc().job_id()) + ".csv"));
  } else {
    return std::unique_ptr<BoxingLogger>(new NullBoxingLogger());
  }
}

L
Li Xinqi 已提交
213 214
}  // namespace

215 216
// Builds the task graph from `logical_gph`:
//   1. creates the sub-task-graph builder context, the boxing logger and the boxing builders
//      (handed to a ChainSubTskGphBuilder in the order listed below);
//   2. expands every logical node into its sorted compute task nodes, assigning CPU thread ids
//      round-robin per machine and tagging each task with its logical node's area id;
//   3. assigns thread ids to the collected persistence task nodes;
//   4. builds the sub task graph of every logical edge and tags newly created nodes with an area;
//   5. connects the ctrl edges the logical graph marks as necessary;
//   6. numbers all task nodes in topological order, and calls ToDotWithAutoFilePath in debug
//      mode.
TaskGraph::TaskGraph(std::unique_ptr<const LogicalGraph>&& logical_gph) {
  logical_gph_ = std::move(logical_gph);
  sub_tsk_gph_builder_ctx_.reset(new SubTskGphBuilderCtx(this));
  boxing_logger_ = CreateBoxingLogger();

  // std::make_shared instead of emplace_back(new ...): if the vector had to reallocate and that
  // allocation threw, the raw pointer given to emplace_back would never be adopted and would leak.
  std::vector<std::shared_ptr<SubTskGphBuilder>> builders;
  builders.push_back(std::make_shared<OneToOneSubTskGphBuilder>());
  builders.push_back(std::make_shared<B21SubTskGphBuilder>());
  builders.push_back(std::make_shared<CollectiveBoxingSubTskGphBuilder>());
  builders.push_back(std::make_shared<SliceBoxingSubTskGphBuilder>());
  builders.push_back(std::make_shared<NaiveB2BSubTskGphBuilder>());
  builders.push_back(std::make_shared<NaiveB2PSubTskGphBuilder>());
  sub_tsk_gph_builder_.reset(new ChainSubTskGphBuilder(builders));

  HashMap<const LogicalNode*, std::vector<CompTaskNode*>> logical2sorted_comp_tasks;

  // For each (compute task, machine id) pair, a lazily allocated vector with one slot per memory
  // zone. MutBufTask returns a pointer to the requested slot, which starts as nullptr.
  HashMap<CompTaskNode*, HashMap<int64_t, std::vector<TaskNode*>>> buf_task;
  auto MutBufTask = [&](CompTaskNode* task_node, int64_t machine_id, int32_t mem_zone_id) {
    auto& buf_vec = buf_task[task_node][machine_id];
    if (buf_vec.empty()) {
      buf_vec.assign(Global<ResourceDesc, ForSession>::Get()->MemZoneNum(), nullptr);
    }
    return &(buf_vec.at(mem_zone_id));
  };

  // Round-robin CPU device thread allocation, with a separate cursor for every machine.
  std::vector<int64_t> cpu_device_offset(Global<ResourceDesc, ForSession>::Get()->TotalMachineNum(),
                                         0);
  auto AllocateCpuThrdIdEvenly = [&](const TaskNode* task_node) {
    CHECK(!task_node->IsIndependent());
    int64_t& offset = cpu_device_offset.at(task_node->machine_id());
    int64_t ret = Global<IDMgr>::Get()->GetCpuDeviceThrdId(offset);
    offset = (offset + 1) % Global<ResourceDesc, ForSession>::Get()->CpuDeviceNum();
    return ret;
  };

  std::vector<std::pair<int64_t, CompTaskNode*>> machine_persistence_task_vec;
  logical_gph_->ForEachNode([&](const LogicalNode* logical_node) {
    logical_node->GenSortedCompTaskNodes(
        AllocateCpuThrdIdEvenly, &machine_persistence_task_vec, [&](CompTaskNode* comp_task_node) {
          AddAllocatedNode(comp_task_node);
          logical2sorted_comp_tasks[logical_node].push_back(comp_task_node);
          comp_task_node->set_area_id(logical_node->GetAreaId());
        });
  });

  GenerateIndependentThrdId(machine_persistence_task_vec);
  logical_gph_->ForEachEdge([&](const LogicalEdge* logical_edge) {
    BldSubTskGphMthd method =
        GetMthdForBldSubTskGph(logical_edge->src_node(), logical_edge->dst_node());
    (this->*method)(logical_edge->src_node(), logical_edge->dst_node(),
                    logical2sorted_comp_tasks.at(logical_edge->src_node()),
                    logical2sorted_comp_tasks.at(logical_edge->dst_node()), MutBufTask,
                    AllocateCpuThrdIdEvenly);
    SetAreaIdForNewNodes(logical_edge->src_node(), logical_edge->dst_node());
  });
  logical_gph_->ForEachNecessaryCtrlEdge(
      [&](const LogicalNode* src, const LogicalNode* dst, int64_t ctrl_regst_num) {
        const auto& src_task_nodes = logical2sorted_comp_tasks.at(src);
        const auto& dst_task_nodes = logical2sorted_comp_tasks.at(dst);
        ConnectCtrlEdges(src_task_nodes, dst_task_nodes, ctrl_regst_num);
      });

  SetOrderInGraphForEachNode();
  if (Global<ResourceDesc, ForSession>::Get()->enable_debug_mode()) { ToDotWithAutoFilePath(); }
}

L
Li Xinqi 已提交
278 279 280 281 282
// Links src_task_nodes[i] to dst_task_nodes[i] with a ctrl edge for every i. For each pair, a
// ctrl regst is built on the src task towards the dst task. Its min and max register numbers are
// updated (if needed) with `ctrl_regst_num`, its returned register number is set to
// `ctrl_regst_num`, and the new edge is bound to it. Both vectors must have the same length.
void TaskGraph::ConnectCtrlEdges(const std::vector<CompTaskNode*>& src_task_nodes,
                                 const std::vector<CompTaskNode*>& dst_task_nodes,
                                 int64_t ctrl_regst_num) {
  CHECK_EQ(src_task_nodes.size(), dst_task_nodes.size());
  for (size_t idx = 0; idx < src_task_nodes.size(); ++idx) {
    CompTaskNode* src = src_task_nodes.at(idx);
    CompTaskNode* dst = dst_task_nodes.at(idx);

    std::string ctrl_regst_name;
    RegstDesc* ctrl_regst = src->BuildCtrlRegstDesc(dst, &ctrl_regst_name);
    ctrl_regst->UpdtMinRegstNumIfNeed(ctrl_regst_num);
    ctrl_regst->UpdtMaxRegstNumIfNeed(ctrl_regst_num);
    ctrl_regst->mut_regst_desc_type()->mutable_ctrl_regst_desc()->set_returned_regst_num(
        ctrl_regst_num);

    TaskEdge* ctrl_edge = NewEdge();
    Connect<TaskNode>(src, ctrl_edge, dst);
    src->BindEdgeWithProducedRegst(ctrl_edge, ctrl_regst_name);
  }
}

297
void TaskGraph::GenerateIndependentThrdId(
J
Jinhui Yuan 已提交
298 299 300 301 302 303
    const std::vector<std::pair<int64_t, CompTaskNode*>>& persistence_nodes) {
  std::vector<std::pair<int64_t, TaskType>> machine_task_type_vec;
  for (auto pair : persistence_nodes) {
    machine_task_type_vec.emplace_back(std::make_pair(pair.first, pair.second->GetTaskType()));
  }

304
  ThrdIdGenerator generator(machine_task_type_vec, Global<IDMgr>::Get()->BaseIndependentThrdId());
J
Juncheng 已提交
305
  for (const auto& pair : persistence_nodes) {
J
Jinhui Yuan 已提交
306 307 308 309 310
    int64_t thrd_id = generator.GenerateThrdId(pair.first, pair.second->GetTaskType());
    pair.second->set_thrd_id(thrd_id);
  }
}

311
void TaskGraph::AcyclicTopoForEachNode(std::function<bool(TaskNode* node)> IsAllowedStartNode,
J
Juncheng 已提交
312
                                       const std::function<void(TaskNode* node)>& Handler) const {
313
  auto ForEachInNode = [&](TaskNode* node, const std::function<void(TaskNode*)>& Handler) {
314
    node->ForEachNodeOnInEdge([&](TaskNode* node_on_in_edge) {
L
Li Xinqi 已提交
315
      if (IsBackEdge(node_on_in_edge, node)) { return; }
316
      Handler(const_cast<TaskNode*>(node_on_in_edge));
317 318
    });
  };
319
  auto ForEachOutNode = [&](TaskNode* node, const std::function<void(TaskNode*)>& Handler) {
320
    node->ForEachNodeOnOutEdge([&](TaskNode* node_on_out_edge) {
L
Li Xinqi 已提交
321
      if (IsBackEdge(node, node_on_out_edge)) { return; }
322
      Handler(const_cast<TaskNode*>(node_on_out_edge));
323 324
    });
  };
L
Li Xinqi 已提交
325 326 327 328 329 330 331 332 333
  auto IsSourceNode = [&](TaskNode* node) {
    int32_t in_node_num = 0;
    ForEachInNode(node, [&](TaskNode* in_node) { ++in_node_num; });
    return in_node_num == 0;
  };
  std::list<TaskNode*> starts;
  ForEachNode([&](TaskNode* node) {
    if (IsSourceNode(node) && IsAllowedStartNode(node)) { starts.push_back(node); }
  });
334
  // DfsTopo will cause inappropriate chain graph
335 336 337
  TopoForEachNode(starts, ForEachInNode, ForEachOutNode, Handler);
}

J
Juncheng 已提交
338
void TaskGraph::AcyclicTopoForEachNode(const std::function<void(TaskNode* node)>& Handler) const {
339
  return AcyclicTopoForEachNode([](TaskNode*) { return true; }, Handler);
340 341
}

J
Jinhui Yuan 已提交
342 343 344 345
void TaskGraph::RemoveEmptyRegsts() {
  ForEachNode([&](TaskNode* node) { node->EraseZeroSizeProducedBlob(); });
  ForEachNode([&](TaskNode* node) { node->EraseZeroSizeConsumedRegst(); });
  ForEachNode([&](TaskNode* node) { node->EraseZeroSizeProducedRegst(); });
346
  ForEachNode([&](TaskNode* node) { node->UnbindBnWithEmptyRegst(); });
J
Jinhui Yuan 已提交
347 348
}

C
cheng cheng 已提交
349 350 351 352
void TaskGraph::MergeChainAndAddOrderingCtrlEdgeInSameChain() {
  MergeChain();
  BuildCtrlRegstDescInSameChain();
}
J
Jinhui Yuan 已提交
353

C
cheng cheng 已提交
354
void TaskGraph::SetOrderInGraphForEachNode() {
J
Jinhui Yuan 已提交
355
  int64_t order_in_graph = 0;
C
cheng cheng 已提交
356
  auto SetOrderInGraph = [&](TaskNode* task_node) {
357 358 359
    task_node->set_order_in_graph(order_in_graph);
    ordered_task_nodes_.emplace_back(task_node);
    ++order_in_graph;
C
cheng cheng 已提交
360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379
  };
  AcyclicTopoForEachNode(SetOrderInGraph);
}

void TaskGraph::MergeChain() {
  int64_t chain_id = 0;
  for (auto* this_node : ordered_task_nodes_) {
    // skip if this node has been set in a chain.
    if (this_node->chain_id() != -1) { continue; }

    CHECK_EQ(this_node->chain_id(), -1);
    if (CanBeMergedInChain(this_node)) {
      TraverseConnectedSubGraphMergeInThisChain(this_node, chain_id);
    } else {
      this_node->set_chain_id(chain_id);
    }

    ++chain_id;
  }
  for (auto* node : ordered_task_nodes_) { CHECK_NE(node->chain_id(), -1); }
J
Jinhui Yuan 已提交
380 381
}

382
void TaskGraph::BuildCtrlRegstDescInSameChain() {
J
Jinhui Yuan 已提交
383
  HashMap<int64_t, TaskNode*> chain_id2node;
L
Li Xinqi 已提交
384 385
  for (auto* node : ordered_task_nodes_) {
    if (IsConnectToTickOp(node)) { continue; }
J
Jinhui Yuan 已提交
386 387 388 389 390 391 392 393 394 395 396
    int64_t chain_id = node->chain_id();
    auto iter = chain_id2node.find(chain_id);
    if (iter == chain_id2node.end()) {
      CHECK(chain_id2node.emplace(chain_id, node).second);
    } else {
      iter->second->BuildCtrlRegstDescIfNeed(node);
      iter->second = node;
    }
  }
}

L
Li Xinqi 已提交
397
void TaskGraph::GetInplaceOpBlobArgList(
398
    InplaceObasInfo* obas_info, const HashSet<TaskNode*>& dev_nodes,
L
Li Xinqi 已提交
399
    const std::function<const TaskNode*(const std::string&)>& TaskNode4OpName) const {
400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416
  auto AddMutableInplaceArgPair = [&](TaskNode* node, const std::string& ibn,
                                      const std::string& obn, const std::string& op_name) {
    if (IsInplaceAllowed(node, {ibn, obn}, TaskNode4OpName)) {
      auto* pair = obas_info->mut_inplace_oba_pairs.mutable_pair()->Add();
      *pair->mutable_first() = GenOpBlobArg(op_name, ibn);
      *pair->mutable_second() = GenOpBlobArg(op_name, obn);
    }
  };
  auto AddConstInplaceArgPair = [&](TaskNode* node, const std::string& ibn, const std::string& obn,
                                    const std::string& op_name) {
    if (IsInplaceAllowed(node, {ibn, obn}, TaskNode4OpName)) {
      auto* pair = obas_info->con_inplace_oba_pairs.mutable_pair()->Add();
      *pair->mutable_first() = GenOpBlobArg(op_name, ibn);
      *pair->mutable_second() = GenOpBlobArg(op_name, obn);
    }
  };

L
Li Xinqi 已提交
417 418 419 420 421 422
  for (TaskNode* task_node : dev_nodes) {
    if (task_node->exec_gph().node_num() != 1) { continue; }
    const auto& op = *task_node->exec_gph().SoleNode()->op();
    for (const std::string& ibn : op.input_bns()) {
      if (op.InputBlobModifier4Ibn(ibn).is_mutable()) {
        CHECK(IsInplaceAllowed(task_node, {ibn}, TaskNode4OpName));
423
        *obas_info->mut_in_obas.mutable_oba()->Add() = GenOpBlobArg(op.op_name(), ibn);
L
Li Xinqi 已提交
424 425
      }
    }
426 427
    for (const auto& pair : task_node->exec_gph().SoleNode()->mut_inplace_obn2ibn()) {
      AddMutableInplaceArgPair(task_node, pair.second, pair.first, op.op_name());
428
    }
429 430
    for (const auto& pair : task_node->exec_gph().SoleNode()->con_inplace_obn2ibn()) {
      AddConstInplaceArgPair(task_node, pair.second, pair.first, op.op_name());
L
Li Xinqi 已提交
431
    }
L
Li Xinqi 已提交
432 433 434 435
  }
}

void TaskGraph::GetSafeInplaceOpBlobArgList(
436
    InplaceObasInfo* safe_obas_info, const HashSet<TaskNode*>& dev_nodes,
J
Juncheng 已提交
437 438
    const std::function<bool(const std::string&, const std::string&)>& IsOpNameDataOrCtrlReachable)
    const {
L
Li Xinqi 已提交
439
  auto TaskNode4SoleOpName = MakeGetterTaskNode4SoleOpName(dev_nodes);
440 441
  InplaceObasInfo obas_info;
  GetInplaceOpBlobArgList(&obas_info, dev_nodes, TaskNode4SoleOpName);
L
Li Xinqi 已提交
442 443 444
  auto Op4OpName = [&](const std::string& op_name) -> const Operator* {
    return TaskNode4SoleOpName(op_name)->exec_gph().SoleNode()->op().get();
  };
S
scxfjiang 已提交
445 446
  auto IsLbiAllConsumersReachable =
      MakePredicatorIsLbiAllConsumersReachable(TaskNode4SoleOpName, IsOpNameDataOrCtrlReachable);
447 448 449
  InplaceLbiGraph origin_graph(obas_info, Op4OpName);
  InplaceLbiGraph safe_graph(*safe_obas_info, Op4OpName);
  origin_graph.ComputeSafeInplaceObns(safe_obas_info, IsLbiAllConsumersReachable);
450
  if (Global<ResourceDesc, ForSession>::Get()->enable_debug_mode()) {
J
Juncheng 已提交
451 452 453 454 455
    origin_graph.ToDotWithFilePath(
        JoinPath("dot", "InplaceLbiGraph", GlobalJobDesc().job_name() + "_origin.dot"));
    safe_graph.ToDotWithFilePath(
        JoinPath("dot", "InplaceLbiGraph", GlobalJobDesc().job_name() + "_safe.dot"));
  }
L
Li Xinqi 已提交
456 457
}

458
void TaskGraph::SetTaskRegstInplaceInfo(const InplaceObasInfo& obas_info,
L
Li Xinqi 已提交
459 460 461 462 463
                                        const HashSet<TaskNode*>& dev_nodes) const {
  auto TaskNode4SoleOpName = MakeGetterTaskNode4SoleOpName(dev_nodes);
  auto Op4OpName = [&](const std::string& op_name) -> const Operator* {
    return TaskNode4SoleOpName(op_name)->exec_gph().SoleNode()->op().get();
  };
464
  InplaceLbiGraph inplace_gph(obas_info, Op4OpName);
J
Juncheng 已提交
465
  inplace_gph.ForEachConnectedComponent([&](const HashSet<const InplaceLbiNode*>& inplace_nodes) {
L
Li Xinqi 已提交
466 467 468 469 470 471 472
    for (const auto* inplace_node : inplace_nodes) {
      if (inplace_node->in_edges().empty()) { continue; }
      const auto* inplace_edge = inplace_node->SoleInEdge();
      auto* exec_node = TaskNode4SoleOpName(inplace_edge->op().op_name())->exec_gph().SoleNode();
      RegstDesc* in_regst = exec_node->RegstDesc4BnInOp(inplace_edge->ibn());
      RegstDesc* out_regst = exec_node->RegstDesc4BnInOp(inplace_edge->obn());
      out_regst->set_hint_inplace_consumed_regst_desc_id(in_regst->regst_desc_id());
L
Li Xinqi 已提交
473
    }
L
Li Xinqi 已提交
474 475 476 477 478 479 480 481 482 483 484 485 486 487 488
  });
}

// Groups the GPU task nodes by (machine id, physical GPU id derived from the thread id) and calls
// `Handler` once per group.
void TaskGraph::ForEachGpuDeviceNodes(
    const std::function<void(const HashSet<TaskNode*>& dev_nodes)>& Handler) const {
  HashMap<std::pair<int64_t, int64_t>, HashSet<TaskNode*>> global_dev_phy_id2nodes;
  ForEachNode([&](TaskNode* task_node) {
    if (task_node->device_type() == DeviceType::kGPU) {
      const int64_t gpu_phy_id = Global<IDMgr>::Get()->GetGpuPhyIdFromThrdId(task_node->thrd_id());
      const std::pair<int64_t, int64_t> device_key(task_node->machine_id(), gpu_phy_id);
      global_dev_phy_id2nodes[device_key].emplace(task_node);
    }
  });
  for (const auto& device_and_nodes : global_dev_phy_id2nodes) { Handler(device_and_nodes.second); }
}

void TaskGraph::EnableInplaceMemSharing(
S
scxfjiang 已提交
489 490
    const std::function<bool(const std::string&, const std::string&)>&
        IsOpNameDataOrCtrlReachable) {
L
Li Xinqi 已提交
491
  ForEachGpuDeviceNodes([&](const HashSet<TaskNode*>& dev_nodes) {
492 493 494
    InplaceObasInfo safe_inplace_obas_info;
    GetSafeInplaceOpBlobArgList(&safe_inplace_obas_info, dev_nodes, IsOpNameDataOrCtrlReachable);
    SetTaskRegstInplaceInfo(safe_inplace_obas_info, dev_nodes);
L
Li Xinqi 已提交
495 496 497
  });
}

J
Jinhui Yuan 已提交
498 499 500 501 502
void TaskGraph::SetAreaIdForNewNodes(const LogicalNode* src_logical,
                                     const LogicalNode* dst_logical) {
  CHECK(src_logical != nullptr && dst_logical != nullptr);
  ForEachNode([&](TaskNode* node) {
    if (node->area_id() != static_cast<int64_t>(kInvalidArea)) return;
503 504
    if (src_logical->GetAreaId() == dst_logical->GetAreaId()) {
      node->set_area_id(src_logical->GetAreaId());
J
Jinhui Yuan 已提交
505 506 507 508 509 510
    } else {
      node->set_area_id(static_cast<int64_t>(kBoundaryArea));
    }
  });
}

511 512
// Expands to the out-of-class definition header `void TaskGraph::<method_name>(...)` of a
// sub-task-graph building method. The parameter list comes from BLD_SUB_TSK_GPH_MTHD_ARGS().
#define DEFINE_BLD_SUB_TASK_GRAPH_METHOD(method_name) \
  void TaskGraph::method_name BLD_SUB_TSK_GPH_MTHD_ARGS()
513 514

// Builds a boxing sub task graph for every blob flowing from `src_logical` to `dst_logical`.
// When more than one blob flows along the edge, each src compute task first feeds a
// BoxingIdentityTaskNode per blob, and those identity nodes are the boxing inputs for that blob.
// sub_tsk_gph_builder_ produces the output nodes, which ConnectAll121 attaches to the dst tasks.
// Any ctrl tasks it returns get a ctrl edge and ctrl regst to the matching dst task. Every result
// is reported to boxing_logger_.
DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByBoxing) {
  const std::vector<LogicalBlobId> lbis = src_logical->GetLbisTo(dst_logical);
  const bool is_single_lbi = lbis.size() == 1;
  for (const LogicalBlobId& lbi : lbis) {
    std::vector<TaskNode*> in_nodes;
    if (is_single_lbi) {
      in_nodes.assign(sorted_src_comp_tasks.begin(), sorted_src_comp_tasks.end());
    } else {
      for (CompTaskNode* src_node : sorted_src_comp_tasks) {
        auto* identity_node = NewNode<BoxingIdentityTaskNode>();
        identity_node->Init(src_node->machine_id(), src_node->thrd_id(), src_node->area_id(), lbi);
        Connect<TaskNode>(src_node, NewEdge(), identity_node);
        in_nodes.push_back(identity_node);
      }
    }

    const std::string& src_op_name = src_logical->SoleOp()->op_name();
    const std::string& dst_op_name = dst_logical->SoleOp()->op_name();
    const SbpParallel& src_sbp_parallel = Global<OpGraph>::Get()->GetSbpParallel(src_op_name, lbi);
    const SbpParallel& dst_sbp_parallel = Global<OpGraph>::Get()->GetSbpParallel(dst_op_name, lbi);
    const std::shared_ptr<const ParallelDesc>& src_parallel_desc = src_logical->parallel_desc();
    const std::shared_ptr<const ParallelDesc>& dst_parallel_desc = dst_logical->parallel_desc();
    const BlobDesc& blob_desc = Global<OpGraph>::Get()->GetLogicalBlobDesc(lbi);

    std::vector<TaskNode*> out_nodes;
    out_nodes.reserve(sorted_dst_comp_tasks.size());
    std::vector<std::vector<TaskNode*>> sorted_ctrl_tasks;
    auto status = CHECK_JUST(sub_tsk_gph_builder_->Build(
        sub_tsk_gph_builder_ctx_.get(), in_nodes, &out_nodes, &sorted_ctrl_tasks,
        *src_parallel_desc, *dst_parallel_desc, lbi, blob_desc, src_sbp_parallel, dst_sbp_parallel,
        *src_logical->out_blob_time_shape()));
    boxing_logger_->Log(*status, src_op_name, dst_op_name, *src_parallel_desc, *dst_parallel_desc,
                        src_sbp_parallel, dst_sbp_parallel, lbi, blob_desc);
    sub_tsk_gph_builder_ctx_->ConnectAll121(out_nodes, sorted_dst_comp_tasks);

    if (sorted_ctrl_tasks.empty()) { continue; }
    CHECK_EQ(sorted_ctrl_tasks.size(), sorted_dst_comp_tasks.size());
    for (size_t dst_idx = 0; dst_idx < sorted_dst_comp_tasks.size(); ++dst_idx) {
      CompTaskNode* dst_task = sorted_dst_comp_tasks.at(dst_idx);
      for (TaskNode* ctrl_node : sorted_ctrl_tasks.at(dst_idx)) {
        Connect<TaskNode>(ctrl_node, NewEdge(), dst_task);
        ctrl_node->BuildCtrlRegstDesc(dst_task);
      }
    }
  }
}

558
// Connects the i-th src compute task to the i-th dst compute task via BuildTaskPath, with buffer
// task nodes enabled. Both sides must have the same number of tasks.
DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByOneToOne) {
  CHECK_EQ(sorted_src_comp_tasks.size(), sorted_dst_comp_tasks.size());
  const size_t task_num = sorted_src_comp_tasks.size();
  for (size_t idx = 0; idx < task_num; ++idx) {
    BuildTaskPath(sorted_src_comp_tasks.at(idx), sorted_dst_comp_tasks.at(idx), MutBufTask, true);
  }
}

567
// Feeds each dst compute task from the src compute task that
// SubTskGphBuilderUtil::FindNearestNode picks for it (which must exist), using BuildTaskPath
// with buffer task nodes enabled.
DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByBroadcastToBroadcast) {
  for (auto dst_it = sorted_dst_comp_tasks.begin(); dst_it != sorted_dst_comp_tasks.end();
       ++dst_it) {
    CompTaskNode* dst_node = *dst_it;
    CompTaskNode* nearest_src_node =
        SubTskGphBuilderUtil::FindNearestNode(sorted_src_comp_tasks, dst_node);
    CHECK_NOTNULL(nearest_src_node);
    BuildTaskPath(nearest_src_node, dst_node, MutBufTask, true);
  }
}

L
Li Xinqi 已提交
576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605
// The single src task is connected (via BuildTaskPath) to the i-th dst task exactly when the dst
// op's i-th input blob is one of the src op's output blobs. Expects one src task and one dst task
// per dst-op input.
DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByPartialInLbiConnect) {
  HashSet<LogicalBlobId> src_out_lbis;
  for (const auto& obn : src_logical->SoleOp()->output_bns()) {
    src_out_lbis.insert(src_logical->SoleOp()->BnInOp2Lbi(obn));
  }
  CHECK_EQ(sorted_src_comp_tasks.size(), 1);
  CHECK_EQ(dst_logical->SoleOp()->input_bns().size(), sorted_dst_comp_tasks.size());
  CompTaskNode* sole_src_task = sorted_src_comp_tasks.at(0);
  const auto& dst_ibns = dst_logical->SoleOp()->input_bns();
  const int dst_task_num = static_cast<int>(sorted_dst_comp_tasks.size());
  for (int i = 0; i < dst_task_num; ++i) {
    const LogicalBlobId& lbi = dst_logical->SoleOp()->BnInOp2Lbi(dst_ibns.Get(i));
    if (src_out_lbis.count(lbi) > 0) {
      BuildTaskPath(sole_src_task, sorted_dst_comp_tasks.at(i), MutBufTask, true);
    }
  }
}

// The i-th src task is connected (via BuildTaskPath) to the single dst task exactly when the src
// op's i-th output blob is one of the dst op's input blobs. Expects one dst task and one src task
// per src-op output.
DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByPartialOutLbiConnect) {
  HashSet<LogicalBlobId> dst_in_lbis;
  for (const auto& ibn : dst_logical->SoleOp()->input_bns()) {
    dst_in_lbis.insert(dst_logical->SoleOp()->BnInOp2Lbi(ibn));
  }
  CHECK_EQ(sorted_dst_comp_tasks.size(), 1);
  CHECK_EQ(src_logical->SoleOp()->output_bns().size(), sorted_src_comp_tasks.size());
  CompTaskNode* sole_dst_task = sorted_dst_comp_tasks.at(0);
  const auto& src_obns = src_logical->SoleOp()->output_bns();
  const int src_task_num = static_cast<int>(sorted_src_comp_tasks.size());
  for (int i = 0; i < src_task_num; ++i) {
    const LogicalBlobId& lbi = src_logical->SoleOp()->BnInOp2Lbi(src_obns.Get(i));
    if (dst_in_lbis.count(lbi) > 0) {
      BuildTaskPath(sorted_src_comp_tasks.at(i), sole_dst_task, MutBufTask, true);
    }
  }
}

606 607 608 609 610 611 612 613 614
// Joins the i-th src task directly to the i-th dst task with one new edge. No intermediate task
// nodes are created. Both sides must have the same number of tasks.
DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphNormalForwardToDecodeH2D) {
  CHECK_EQ(sorted_src_comp_tasks.size(), sorted_dst_comp_tasks.size());
  auto dst_it = sorted_dst_comp_tasks.begin();
  for (CompTaskNode* src : sorted_src_comp_tasks) {
    Connect<TaskNode>(src, NewEdge(), *dst_it);
    ++dst_it;
  }
}

J
Jinhui Yuan 已提交
615
// Builds a chain of task nodes leading from `src` to `dst`, one hop per BuildTaskStep call, until
// the current node sits on dst's machine and in dst's memory zone. If the last hop did not
// already land on `dst`, a final edge to `dst` is added.
//
// MutBufTask(src, machine_id, mem_zone_id) returns the address of a per-(src, machine, mem zone)
// slot holding an intermediate node that may be reused. The slot is read through GetBufTask and
// written through SetBufTask. Whether BuildTaskStep consults the slots is controlled by
// `use_buf_task_node`.
void TaskGraph::BuildTaskPath(
    CompTaskNode* src, CompTaskNode* dst,
    std::function<TaskNode**(CompTaskNode* src, int64_t machine_id, int32_t mem_zone_id)>
        MutBufTask,
    bool use_buf_task_node) {
  CHECK_NE(src, dst);
  // Current content of the slot for `src` at (machine_id, mem_zone_id); may be nullptr.
  auto GetBufTask = [&](int64_t machine_id, int32_t mem_zone_id) -> TaskNode* {
    return *MutBufTask(src, machine_id, mem_zone_id);
  };
  // Stores `new_val` into an empty slot; a non-empty slot must already hold exactly `new_val`.
  auto SetBufTask = [&](int64_t machine_id, int32_t mem_zone_id, TaskNode* new_val) -> TaskNode* {
    TaskNode** slot = MutBufTask(src, machine_id, mem_zone_id);
    if (*slot != nullptr) {
      CHECK_EQ(*slot, new_val);
    } else {
      *slot = new_val;
    }
    return new_val;
  };
  // True once `node` shares both machine and memory zone with `dst`.
  auto ReachedDstLocation = [dst](TaskNode* node) {
    return node->machine_id() == dst->machine_id()
           && node->MemZoneId121() == dst->MemZoneId121();
  };

  TaskNode* cur_node = src;
  while (!ReachedDstLocation(cur_node)) {
    cur_node = BuildTaskStep(cur_node, dst, GetBufTask, SetBufTask, use_buf_task_node);
  }
  if (cur_node != dst) { Connect<TaskNode>(cur_node, NewEdge(), dst); }
}
641

J
Jinhui Yuan 已提交
642
// Advances one hop from `cur_node` toward `dst` and returns the node reached.
//
// Hop selection:
//   1. cur_node is not in CPU memory          -> device-to-host copy on cur_node's machine.
//   2. cur_node is in CPU memory, same machine -> move into dst's memory zone: an H2D copy node,
//                                                 or `dst` itself when TryAddCopyH2DTaskTo
//                                                 returns nullptr.
//   3. cur_node is in CPU memory, other machine -> CopyCommNet node toward dst's machine.
// Cases 2 and 3 are exhaustive for a CPU-resident node, so there is no unreachable fallback.
//
// When `use_buf_task_node` is true, a node previously recorded via SetBufTask for the target
// (machine, mem zone) is reused instead of creating a new one. Every returned node other than
// `dst` is then recorded via SetBufTask. When the reused node comes from the buffer, no new edge
// from cur_node is added here.
TaskNode* TaskGraph::BuildTaskStep(
    TaskNode* cur_node, TaskNode* dst,
    const std::function<TaskNode*(int64_t machine_id, int32_t mem_zone_id)>& GetBufTask,
    const std::function<TaskNode*(int64_t machine_id, int32_t mem_zone_id, TaskNode*)>& SetBufTask,
    bool use_buf_task_node) {
  const int32_t cpu_mem_zone_id = Global<IDMgr>::Get()->CpuMemZoneId();
  // Buffered node for (machine_id, mem_zone_id), or nullptr if buffering is off / nothing cached.
  auto TryGetBufTask = [&](int64_t machine_id, int32_t mem_zone_id) -> TaskNode* {
    return use_buf_task_node ? GetBufTask(machine_id, mem_zone_id) : nullptr;
  };

  int32_t next_mem_zone_id = -1;
  TaskNode* next_node = nullptr;
  if (cur_node->MemZoneId121() != cpu_mem_zone_id) {
    // Case 1: leave device memory first.
    next_mem_zone_id = cpu_mem_zone_id;
    next_node = TryGetBufTask(cur_node->machine_id(), next_mem_zone_id);
    if (next_node == nullptr) {
      next_node = AddCopyD2HTaskFrom(cur_node);
      Connect<TaskNode>(cur_node, NewEdge(), next_node);
    }
  } else if (cur_node->machine_id() == dst->machine_id()) {
    // Case 2: on dst's machine in CPU memory; step into dst's memory zone.
    next_mem_zone_id = dst->MemZoneId121();
    next_node = TryGetBufTask(cur_node->machine_id(), next_mem_zone_id);
    if (next_node == nullptr) {
      next_node = TryAddCopyH2DTaskTo(dst);
      // No copy node required: connect straight to dst.
      if (next_node == nullptr) { next_node = dst; }
      Connect<TaskNode>(cur_node, NewEdge(), next_node);
    }
  } else {
    // Case 3: in CPU memory on another machine; cross the network to dst's machine.
    next_mem_zone_id = cpu_mem_zone_id;
    next_node = TryGetBufTask(dst->machine_id(), next_mem_zone_id);
    if (next_node == nullptr) {
      next_node = AddCopyCommNetTaskBetween(cur_node, dst);
      Connect<TaskNode>(cur_node, NewEdge(), next_node);
    }
  }
  // dst itself is never cached as a reusable intermediate node.
  if (use_buf_task_node && next_node != dst) {
    SetBufTask(next_node->machine_id(), next_mem_zone_id, next_node);
  }
  return next_node;
}

L
Li Xinqi 已提交
678
// Creates a host-to-device copy node targeting `task`'s machine and GPU and returns it.
// Returns nullptr, creating nothing, when `task` is an interface task or its task type is
// registered as a TickTockTaskType. Otherwise `task` must be a GPU task.
// The new node is not connected to anything here.
TaskNode* TaskGraph::TryAddCopyH2DTaskTo(TaskNode* task) {
  const bool no_copy_needed =
      IsInterfaceTask(task) || IsClassRegistered<int32_t, TickTockTaskType>(task->GetTaskType());
  if (no_copy_needed) { return nullptr; }
  CHECK_EQ(task->device_type(), DeviceType::kGPU);
  auto* h2d_task = NewNode<CopyHdTaskNode>();
  h2d_task->Init(CopyHdOpConf::H2D, task->machine_id(), task->GpuPhyId());
  return h2d_task;
}

J
Jinhui Yuan 已提交
687 688
// Creates a device-to-host copy node on `task`'s machine and GPU and returns it.
// `task` must be a GPU task. The new node is not connected to anything here.
TaskNode* TaskGraph::AddCopyD2HTaskFrom(TaskNode* task) {
  CHECK_EQ(task->device_type(), DeviceType::kGPU);
  auto* d2h_task = NewNode<CopyHdTaskNode>();
  d2h_task->Init(CopyHdOpConf::D2H, task->machine_id(), task->GpuPhyId());
  return d2h_task;
}
693

J
Jinhui Yuan 已提交
694
// Creates a CopyCommNet node for moving data from `src`'s machine to `dst`'s machine and returns
// it. The two machines must differ. The new node is not connected to anything here.
TaskNode* TaskGraph::AddCopyCommNetTaskBetween(TaskNode* src, TaskNode* dst) {
  CHECK_NE(src->machine_id(), dst->machine_id());
  auto* comm_net_task = NewNode<CopyCommNetTaskNode>();
  // Init receives dst's machine id first, then src's machine id.
  const auto dst_machine_id = dst->machine_id();
  const auto src_machine_id = src->machine_id();
  comm_net_task->Init(dst_machine_id, src_machine_id);
  return comm_net_task;
}

W
willzhang4a58 已提交
701
// Connects `src` to `dst`. If they sit on different machines, a CopyCommNet node is inserted
// between them (src -> copy -> dst); otherwise a single direct edge is added.
void TaskGraph::ConnectWithCopyCommNetIfNeed(TaskNode* src, TaskNode* dst) {
  if (src->machine_id() != dst->machine_id()) {
    TaskNode* comm_net_task = AddCopyCommNetTaskBetween(src, dst);
    Connect<TaskNode>(src, NewEdge(), comm_net_task);
    Connect<TaskNode>(comm_net_task, NewEdge(), dst);
    return;
  }
  Connect(src, NewEdge(), dst);
}

L
Li Xinqi 已提交
711
// Reports whether src -> dst is a back edge. The current implementation never classifies an edge
// as a back edge: it ignores both arguments and always returns false.
bool IsBackEdge(TaskNode* src, TaskNode* dst) {
  (void)src;
  (void)dst;
  return false;
}
J
Jinhui Yuan 已提交
712

W
willzhang4a58 已提交
713
}  // namespace oneflow