task_graph.cpp 31.7 KB
Newer Older
S
Shenghang Tsai 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
J
jiyuan 已提交
16
#include "oneflow/core/graph/task_graph.h"
J
Jinhui Yuan 已提交
17 18
#include "oneflow/core/graph/chain_graph.h"
#include "oneflow/core/common/util.h"
L
Li Xinqi 已提交
19
#include "oneflow/core/graph/inplace_lbi_graph.h"
20
#include "oneflow/core/register/runtime_blob_desc.h"
J
Jinhui Yuan 已提交
21
#include "oneflow/core/job/thrd_id_generator.h"
22
#include "oneflow/core/job/global_for.h"
L
Li Xinqi 已提交
23
#include "oneflow/core/operator/variable_op.h"
24
#include "oneflow/core/operator/user_op_util.h"
J
Juncheng 已提交
25
#include "oneflow/core/graph/op_graph.h"
C
cheng cheng 已提交
26
#include "oneflow/core/graph/normal_forward_compute_task_node.h"
J
Juncheng 已提交
27 28 29
#include "oneflow/core/graph/boxing/sub_task_graph_builder_context.h"
#include "oneflow/core/graph/boxing/sub_task_graph_builder.h"
#include "oneflow/core/graph/boxing/chain_sub_task_graph_builder.h"
J
Juncheng 已提交
30
#include "oneflow/core/graph/boxing/collective_boxing_sub_task_graph_builder.h"
L
Li Xinqi 已提交
31
#include "oneflow/core/graph/boxing/slice_boxing_sub_task_graph_builder.h"
J
Juncheng 已提交
32
#include "oneflow/core/graph/boxing/naive_b2b_sub_task_graph_builder.h"
J
Juncheng 已提交
33
#include "oneflow/core/graph/boxing/naive_b2p_sub_task_graph_builder.h"
J
Juncheng 已提交
34
#include "oneflow/core/graph/boxing/b21_sub_task_graph_builder.h"
J
Juncheng 已提交
35
#include "oneflow/core/graph/boxing/one_to_one_sub_task_graph_builder.h"
J
Juncheng 已提交
36
#include "oneflow/core/graph/boxing/sub_task_graph_builder_util.h"
37
#include "oneflow/core/graph/boxing_identity_task_node.h"
W
willzhang4a58 已提交
38 39 40

namespace oneflow {

L
Li Xinqi 已提交
41 42
namespace {

L
lixinqi 已提交
43
bool IsInterfaceTask(const TaskNode* node) {
L
Li Xinqi 已提交
44 45 46 47
  const auto* comp_task_node = dynamic_cast<const CompTaskNode*>(node);
  if (comp_task_node == nullptr) { return false; }
  if (comp_task_node->logical_node()->op_vec().size() != 1) { return false; }
  auto op_type_case = comp_task_node->logical_node()->SoleOp()->op_conf().op_type_case();
48
  return IsClassRegistered<int32_t, IsInterfaceOpConf4OpTypeCase>(op_type_case);
L
Li Xinqi 已提交
49 50
}

L
Li Xinqi 已提交
51 52 53 54 55 56 57 58 59
bool IsConnectToTickOp(const TaskNode* node) {
  const auto* comp_task_node = dynamic_cast<const CompTaskNode*>(node);
  if (comp_task_node == nullptr) { return false; }
  if (comp_task_node->logical_node()->op_vec().size() != 1) { return false; }
  const Operator* op = comp_task_node->logical_node()->SoleOp().get();
  if (dynamic_cast<const VariableOp*>(op) != nullptr) { return true; }
  return false;
}

C
cheng cheng 已提交
60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118
bool IsSpecialOpNotConsiderMergeInChain(const Operator* op) {
  const OperatorConf& op_conf = op->op_conf();
  if (op_conf.has_variable_conf() || op_conf.has_keep_header_only_conf() || op_conf.has_tick_conf()
      || op_conf.has_device_tick_conf() || op_conf.has_partial_tick_conf()) {
    return true;
  }
  return false;
}

// Returns true if any regst produced by `node` reports min_register_num() greater than 1.
bool IsTaskNodeProducedResgtHasMultiRegstNum(const TaskNode* node) {
  bool has_multi_regst = false;
  for (const auto& name_and_regst : node->produced_regsts()) {
    if (name_and_regst.second->min_register_num() > 1) {
      has_multi_regst = true;
      break;
    }
  }
  return has_multi_regst;
}

// Decides whether `node` may join a merged chain. A node qualifies only if all of these hold:
//   - none of its produced regsts needs more than one register,
//   - it is a NormalForwardCompTaskNode holding exactly one op,
//   - it runs on GPU,
//   - its op is not one of the special ops rejected by IsSpecialOpNotConsiderMergeInChain.
bool CanBeMergedInChain(const TaskNode* node) {
  if (IsTaskNodeProducedResgtHasMultiRegstNum(node)) { return false; }
  const auto* fw_node = dynamic_cast<const NormalForwardCompTaskNode*>(node);
  if (!fw_node) { return false; }
  const auto* logical = fw_node->logical_node();
  const bool is_single_op_on_gpu =
      logical->op_vec().size() == 1 && fw_node->device_type() == DeviceType::kGPU;
  if (!is_single_op_on_gpu) { return false; }
  return !IsSpecialOpNotConsiderMergeInChain(logical->SoleOp().get());
}

// Breadth-first walk starting at `this_node` that assigns `this_chain_id` to every node reachable
// over in/out edges which is mergeable, shares `this_node`'s global work stream id and shares its
// area id. `this_node` itself must not yet belong to a chain.
void TraverseConnectedSubGraphMergeInThisChain(TaskNode* this_node, const int64_t this_chain_id) {
  CHECK_NE(this_chain_id, -1);
  CHECK_EQ(this_node->chain_id(), -1);
  HashSet<TaskNode*> seen;
  std::queue<TaskNode*> pending;
  pending.push(this_node);
  seen.insert(this_node);

  // NOTE(chengcheng): use area_id to not merge optimizer ops with fw/bw ops
  const auto IsMergeCandidate = [&](TaskNode* neighbor) {
    return seen.find(neighbor) == seen.end() && CanBeMergedInChain(neighbor)
           && this_node->GlobalWorkStreamId() == neighbor->GlobalWorkStreamId()
           && this_node->area_id() == neighbor->area_id();
  };

  while (!pending.empty()) {
    TaskNode* cur = pending.front();
    pending.pop();

    CHECK_EQ(cur->chain_id(), -1);
    cur->set_chain_id(this_chain_id);

    cur->ForEachNodeOnInOutEdge([&](TaskNode* neighbor) {
      if (!IsMergeCandidate(neighbor)) { return; }
      if (neighbor->chain_id() != -1) {
        // A candidate that already has a chain id must already be in this very chain.
        CHECK_EQ(neighbor->chain_id(), this_chain_id);
        return;
      }
      pending.push(neighbor);
      seen.insert(neighbor);
    });
  }
}

L
Li Xinqi 已提交
119 120 121 122 123 124 125 126 127 128 129 130 131 132 133
// Builds a lookup from op name to the task node executing it, considering only task nodes whose
// exec graph has exactly one node. The returned getter yields nullptr when the op name is unknown
// or when more than one task node is recorded for it. The index is held by shared_ptr so the
// getter stays valid after this function returns.
std::function<TaskNode*(const std::string&)> MakeGetterTaskNode4SoleOpName(
    const HashSet<TaskNode*>& task_nodes) {
  using OpName2TaskNodes = HashMap<std::string, HashSet<TaskNode*>>;
  auto index = std::make_shared<OpName2TaskNodes>();
  for (TaskNode* node : task_nodes) {
    if (node->exec_gph().node_num() != 1) { continue; }
    ExecNode* sole_exec_node = node->exec_gph().SoleNode();
    // Every task node is expected to be inserted at most once.
    CHECK((*index)[sole_exec_node->op()->op_name()].emplace(node).second);
  }
  return [index](const std::string& op_name) -> TaskNode* {
    const auto& found = index->find(op_name);
    if (found == index->end() || found->second.size() > 1) { return nullptr; }
    return *found->second.begin();
  };
}
L
Li Xinqi 已提交
135 136 137 138 139 140 141 142 143

// Returns true if any regst carried by `edge` contains `lbi`.
bool IsLbiOnTaskEdge(const TaskEdge* edge, const LogicalBlobId& lbi) {
  bool found = false;
  for (const auto& regst : edge->GetRegsts()) {
    if (regst->HasLbi(lbi)) {
      found = true;
      break;
    }
  }
  return found;
}

std::function<bool(const LogicalBlobId&, const std::string&)>
S
scxfjiang 已提交
144
MakePredicatorIsLbiAllConsumersReachable(
J
Juncheng 已提交
145 146 147
    const std::function<const TaskNode*(const std::string&)>& TaskNode4SoleOpName,
    const std::function<bool(const std::string&, const std::string&)>&
        IsOpNameDataOrCtrlReachable) {
S
scxfjiang 已提交
148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164
  auto IsDataOrCtrlReachable = [IsOpNameDataOrCtrlReachable](const TaskNode* src_node,
                                                             const TaskNode* dst_node) -> bool {
    if (src_node->chain_id() == dst_node->chain_id()
        && src_node->order_in_graph() <= dst_node->order_in_graph()) {
      return true;
    }
    const CompTaskNode* comp_src_node = dynamic_cast<const CompTaskNode*>(src_node);
    if (comp_src_node == nullptr) { return false; }
    if (comp_src_node->logical_node()->op_vec().size() != 1) { return false; }
    const CompTaskNode* comp_dst_node = dynamic_cast<const CompTaskNode*>(dst_node);
    if (comp_dst_node == nullptr) { return false; }
    if (comp_dst_node->logical_node()->op_vec().size() != 1) { return false; }
    return IsOpNameDataOrCtrlReachable(comp_src_node->logical_node()->SoleOp()->op_name(),
                                       comp_dst_node->logical_node()->SoleOp()->op_name());
  };
  return [TaskNode4SoleOpName, IsDataOrCtrlReachable](const LogicalBlobId& lbi,
                                                      const std::string& op_name) -> bool {
L
Li Xinqi 已提交
165 166 167
    const TaskNode* src_task_node = TaskNode4SoleOpName(lbi.op_name());
    const TaskNode* dst_task_node = TaskNode4SoleOpName(op_name);
    size_t out_edges_size = 0;
L
Li Xinqi 已提交
168
    size_t reachable_out_edges_size = 0;
L
Li Xinqi 已提交
169 170 171
    for (TaskEdge* out_edge : src_task_node->out_edges()) {
      if (IsLbiOnTaskEdge(out_edge, lbi)) {
        out_edges_size += 1;
S
scxfjiang 已提交
172
        reachable_out_edges_size += IsDataOrCtrlReachable(out_edge->dst_node(), dst_task_node);
L
Li Xinqi 已提交
173 174
      }
    }
L
Li Xinqi 已提交
175
    return out_edges_size > 0 && out_edges_size == reachable_out_edges_size;
L
Li Xinqi 已提交
176 177 178 179 180 181 182 183 184 185 186 187 188 189
  };
}

// Checks whether the blobs named by `bns` on the sole exec node of `task_node` may share memory
// in place. Requirements:
//   - the task node's exec graph has exactly one node,
//   - for every bn, the producing op of its lbi is known to `TaskNode4SoleOpName`,
//   - for every bn, its regst holds exactly one lbi,
//   - all blobs agree on element count and data type.
bool IsInplaceAllowed(
    TaskNode* task_node, const std::vector<std::string>& bns,
    const std::function<const TaskNode*(const std::string&)>& TaskNode4SoleOpName) {
  if (task_node->exec_gph().node_num() != 1) { return false; }
  const auto& exec_node = *task_node->exec_gph().SoleNode();

  // Pass 1: per-bn structural checks.
  for (const std::string& bn : bns) {
    // TaskNode for bn is not nullptr if it's on the same device with `task_node`
    const bool producer_known =
        TaskNode4SoleOpName(exec_node.op()->BnInOp2Lbi(bn).op_name()) != nullptr;
    if (!producer_known) { return false; }
    if (exec_node.RegstDesc4BnInOp(bn)->NumOfLbi() != 1) { return false; }
  }

  // Pass 2: every blob must match the first one in element count and data type.
  const BlobDesc* reference_blob = nullptr;
  for (const std::string& bn : bns) {
    const BlobDesc* cur_blob = exec_node.RegstDesc4BnInOp(bn)->SoleBlobDesc();
    if (reference_blob == nullptr) {
      reference_blob = cur_blob;
      continue;
    }
    if (reference_blob->shape().elem_cnt() != cur_blob->shape().elem_cnt()
        || reference_blob->data_type() != cur_blob->data_type()) {
      return false;
    }
  }
  return true;
}

qq_22305325's avatar
qq_22305325 已提交
205 206 207 208 209 210 211 212 213
// Returns a CsvBoxingLogger writing to "boxing/log/<job_id>.csv" when debug mode is enabled in
// the session ResourceDesc, and a NullBoxingLogger otherwise.
std::unique_ptr<BoxingLogger> CreateBoxingLogger() {
  const bool debug_mode = Global<ResourceDesc, ForSession>::Get()->enable_debug_mode();
  if (!debug_mode) { return std::unique_ptr<BoxingLogger>(new NullBoxingLogger()); }
  return std::unique_ptr<BoxingLogger>(
      new CsvBoxingLogger(StrCat("boxing/log/", GlobalJobDesc().job_id()) + ".csv"));
}

L
Li Xinqi 已提交
214 215
}  // namespace

216 217
// Builds the task graph from a logical graph:
//   1. sets up the boxing builder context, the boxing logger and the chain of sub-task-graph
//      builders,
//   2. generates the sorted CompTaskNodes for every logical node and tags them with area ids,
//   3. assigns independent thread ids to the persistence tasks,
//   4. builds the sub task graph for every logical edge,
//   5. connects the necessary ctrl edges,
//   6. assigns order_in_graph and, in debug mode, dumps the graph to a dot file.
//
// @param logical_gph  the logical graph; ownership is taken over by this TaskGraph.
TaskGraph::TaskGraph(std::unique_ptr<const LogicalGraph>&& logical_gph) {
  logical_gph_ = std::move(logical_gph);
  sub_tsk_gph_builder_ctx_.reset(new SubTskGphBuilderCtx(this));
  boxing_logger_ = CreateBoxingLogger();

  // Registration order is kept exactly as before.
  // NOTE(review): presumably ChainSubTskGphBuilder tries the builders in this order — confirm.
  // make_shared (instead of a raw `new` handed to emplace_back) keeps each builder owned at all
  // times, so nothing leaks if growing the vector throws.
  std::vector<std::shared_ptr<SubTskGphBuilder>> builders;
  builders.reserve(6);
  builders.emplace_back(std::make_shared<OneToOneSubTskGphBuilder>());
  builders.emplace_back(std::make_shared<B21SubTskGphBuilder>());
  builders.emplace_back(std::make_shared<CollectiveBoxingSubTskGphBuilder>());
  builders.emplace_back(std::make_shared<SliceBoxingSubTskGphBuilder>());
  builders.emplace_back(std::make_shared<NaiveB2BSubTskGphBuilder>());
  builders.emplace_back(std::make_shared<NaiveB2PSubTskGphBuilder>());
  sub_tsk_gph_builder_.reset(new ChainSubTskGphBuilder(builders));

  HashMap<const LogicalNode*, std::vector<CompTaskNode*>> logical2sorted_comp_tasks;

  // buf_task[src comp task][machine id][mem zone id] -> cached task node slot.
  // The per-machine vector is lazily sized to MemZoneNum() and filled with nullptr.
  HashMap<CompTaskNode*, HashMap<int64_t, std::vector<TaskNode*>>> buf_task;
  auto MutBufTask = [&](CompTaskNode* task_node, int64_t machine_id, int32_t mem_zone_id) {
    auto& buf_vec = buf_task[task_node][machine_id];
    if (buf_vec.empty()) {
      buf_vec.assign(Global<ResourceDesc, ForSession>::Get()->MemZoneNum(), nullptr);
    }
    return &(buf_vec.at(mem_zone_id));
  };

  // Round-robin allocation of CPU device thread ids, tracked independently per machine.
  std::vector<int64_t> cpu_device_offset(Global<ResourceDesc, ForSession>::Get()->TotalMachineNum(),
                                         0);
  auto AllocateCpuThrdIdEvenly = [&](const TaskNode* task_node) {
    CHECK(!task_node->IsIndependent());
    int64_t& offset = cpu_device_offset.at(task_node->machine_id());
    int64_t ret = Global<IDMgr>::Get()->GetCpuDeviceThrdId(offset);
    offset = (offset + 1) % Global<ResourceDesc, ForSession>::Get()->CpuDeviceNum();
    return ret;
  };

  std::vector<std::pair<int64_t, CompTaskNode*>> machine_persistence_task_vec;
  logical_gph_->ForEachNode([&](const LogicalNode* logical_node) {
    logical_node->GenSortedCompTaskNodes(
        AllocateCpuThrdIdEvenly, &machine_persistence_task_vec, [&](CompTaskNode* comp_task_node) {
          AddAllocatedNode(comp_task_node);
          logical2sorted_comp_tasks[logical_node].push_back(comp_task_node);
          comp_task_node->set_area_id(logical_node->GetAreaId());
        });
  });

  GenerateIndependentThrdId(machine_persistence_task_vec);

  logical_gph_->ForEachEdge([&](const LogicalEdge* logical_edge) {
    BldSubTskGphMthd method =
        GetMthdForBldSubTskGph(logical_edge->src_node(), logical_edge->dst_node());
    (this->*method)(logical_edge->src_node(), logical_edge->dst_node(),
                    logical2sorted_comp_tasks.at(logical_edge->src_node()),
                    logical2sorted_comp_tasks.at(logical_edge->dst_node()), MutBufTask,
                    AllocateCpuThrdIdEvenly);
    // Nodes created while building this edge have no area id yet; tag them now.
    SetAreaIdForNewNodes(logical_edge->src_node(), logical_edge->dst_node());
  });

  logical_gph_->ForEachNecessaryCtrlEdge(
      [&](const LogicalNode* src, const LogicalNode* dst, int64_t ctrl_regst_num) {
        const auto& src_task_nodes = logical2sorted_comp_tasks.at(src);
        const auto& dst_task_nodes = logical2sorted_comp_tasks.at(dst);
        ConnectCtrlEdges(src_task_nodes, dst_task_nodes, ctrl_regst_num);
      });

  SetOrderInGraphForEachNode();
  if (Global<ResourceDesc, ForSession>::Get()->enable_debug_mode()) { ToDotWithAutoFilePath(); }
}

L
Li Xinqi 已提交
279 280 281 282 283
// Connects src_task_nodes[i] -> dst_task_nodes[i] with a ctrl regst and a new edge for every i.
// Both vectors must have the same size. Each ctrl regst gets its min/max regst num updated with
// `ctrl_regst_num` and its returned_regst_num set to `ctrl_regst_num`.
void TaskGraph::ConnectCtrlEdges(const std::vector<CompTaskNode*>& src_task_nodes,
                                 const std::vector<CompTaskNode*>& dst_task_nodes,
                                 int64_t ctrl_regst_num) {
  CHECK_EQ(src_task_nodes.size(), dst_task_nodes.size());
  FOR_RANGE(int32_t, i, 0, src_task_nodes.size()) {
    CompTaskNode* producer = src_task_nodes.at(i);
    CompTaskNode* consumer = dst_task_nodes.at(i);

    std::string ctrl_regst_name;
    RegstDesc* ctrl_regst = producer->BuildCtrlRegstDesc(consumer, &ctrl_regst_name);
    ctrl_regst->UpdtMinRegstNumIfNeed(ctrl_regst_num);
    ctrl_regst->UpdtMaxRegstNumIfNeed(ctrl_regst_num);
    ctrl_regst->mut_regst_desc_type()->mutable_ctrl_regst_desc()->set_returned_regst_num(
        ctrl_regst_num);

    TaskEdge* ctrl_edge = NewEdge();
    Connect<TaskNode>(producer, ctrl_edge, consumer);
    producer->BindEdgeWithProducedRegst(ctrl_edge, ctrl_regst_name);
  }
}

298
void TaskGraph::GenerateIndependentThrdId(
J
Jinhui Yuan 已提交
299 300 301 302 303 304
    const std::vector<std::pair<int64_t, CompTaskNode*>>& persistence_nodes) {
  std::vector<std::pair<int64_t, TaskType>> machine_task_type_vec;
  for (auto pair : persistence_nodes) {
    machine_task_type_vec.emplace_back(std::make_pair(pair.first, pair.second->GetTaskType()));
  }

305
  ThrdIdGenerator generator(machine_task_type_vec, Global<IDMgr>::Get()->BaseIndependentThrdId());
J
Juncheng 已提交
306
  for (const auto& pair : persistence_nodes) {
J
Jinhui Yuan 已提交
307 308 309 310 311
    int64_t thrd_id = generator.GenerateThrdId(pair.first, pair.second->GetTaskType());
    pair.second->set_thrd_id(thrd_id);
  }
}

312
void TaskGraph::AcyclicTopoForEachNode(std::function<bool(TaskNode* node)> IsAllowedStartNode,
J
Juncheng 已提交
313
                                       const std::function<void(TaskNode* node)>& Handler) const {
314
  auto ForEachInNode = [&](TaskNode* node, const std::function<void(TaskNode*)>& Handler) {
315
    node->ForEachNodeOnInEdge([&](TaskNode* node_on_in_edge) {
L
Li Xinqi 已提交
316
      if (IsBackEdge(node_on_in_edge, node)) { return; }
317
      Handler(const_cast<TaskNode*>(node_on_in_edge));
318 319
    });
  };
320
  auto ForEachOutNode = [&](TaskNode* node, const std::function<void(TaskNode*)>& Handler) {
321
    node->ForEachNodeOnOutEdge([&](TaskNode* node_on_out_edge) {
L
Li Xinqi 已提交
322
      if (IsBackEdge(node, node_on_out_edge)) { return; }
323
      Handler(const_cast<TaskNode*>(node_on_out_edge));
324 325
    });
  };
L
Li Xinqi 已提交
326 327 328 329 330 331 332 333 334
  auto IsSourceNode = [&](TaskNode* node) {
    int32_t in_node_num = 0;
    ForEachInNode(node, [&](TaskNode* in_node) { ++in_node_num; });
    return in_node_num == 0;
  };
  std::list<TaskNode*> starts;
  ForEachNode([&](TaskNode* node) {
    if (IsSourceNode(node) && IsAllowedStartNode(node)) { starts.push_back(node); }
  });
335
  // DfsTopo will cause inappropriate chain graph
336 337 338
  TopoForEachNode(starts, ForEachInNode, ForEachOutNode, Handler);
}

J
Juncheng 已提交
339
void TaskGraph::AcyclicTopoForEachNode(const std::function<void(TaskNode* node)>& Handler) const {
340
  return AcyclicTopoForEachNode([](TaskNode*) { return true; }, Handler);
341 342
}

J
Jinhui Yuan 已提交
343 344 345 346
void TaskGraph::RemoveEmptyRegsts() {
  ForEachNode([&](TaskNode* node) { node->EraseZeroSizeProducedBlob(); });
  ForEachNode([&](TaskNode* node) { node->EraseZeroSizeConsumedRegst(); });
  ForEachNode([&](TaskNode* node) { node->EraseZeroSizeProducedRegst(); });
347
  ForEachNode([&](TaskNode* node) { node->UnbindBnWithEmptyRegst(); });
J
Jinhui Yuan 已提交
348 349
}

C
cheng cheng 已提交
350 351 352 353
// Assigns a chain id to every task node (MergeChain) and then builds ctrl regsts between
// consecutive nodes that share a chain (BuildCtrlRegstDescInSameChain).
void TaskGraph::MergeChainAndAddOrderingCtrlEdgeInSameChain() {
  MergeChain();
  BuildCtrlRegstDescInSameChain();
}
J
Jinhui Yuan 已提交
354

C
cheng cheng 已提交
355
void TaskGraph::SetOrderInGraphForEachNode() {
J
Jinhui Yuan 已提交
356
  int64_t order_in_graph = 0;
C
cheng cheng 已提交
357
  auto SetOrderInGraph = [&](TaskNode* task_node) {
358 359 360
    task_node->set_order_in_graph(order_in_graph);
    ordered_task_nodes_.emplace_back(task_node);
    ++order_in_graph;
C
cheng cheng 已提交
361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380
  };
  AcyclicTopoForEachNode(SetOrderInGraph);
}

// Gives every node in ordered_task_nodes_ a chain id. A mergeable node seeds a traversal that
// pulls its connected mergeable neighbours into the same chain; any other node forms a chain of
// its own. Chain ids are handed out in increasing order starting at 0.
void TaskGraph::MergeChain() {
  int64_t next_chain_id = 0;
  for (auto* node : ordered_task_nodes_) {
    // Already claimed by an earlier traversal.
    if (node->chain_id() != -1) { continue; }

    if (CanBeMergedInChain(node)) {
      TraverseConnectedSubGraphMergeInThisChain(node, next_chain_id);
    } else {
      node->set_chain_id(next_chain_id);
    }
    next_chain_id += 1;
  }
  // Post-condition: no node is left without a chain.
  for (auto* node : ordered_task_nodes_) { CHECK_NE(node->chain_id(), -1); }
}

383
void TaskGraph::BuildCtrlRegstDescInSameChain() {
J
Jinhui Yuan 已提交
384
  HashMap<int64_t, TaskNode*> chain_id2node;
L
Li Xinqi 已提交
385 386
  for (auto* node : ordered_task_nodes_) {
    if (IsConnectToTickOp(node)) { continue; }
J
Jinhui Yuan 已提交
387 388 389 390 391 392 393 394 395 396 397
    int64_t chain_id = node->chain_id();
    auto iter = chain_id2node.find(chain_id);
    if (iter == chain_id2node.end()) {
      CHECK(chain_id2node.emplace(chain_id, node).second);
    } else {
      iter->second->BuildCtrlRegstDescIfNeed(node);
      iter->second = node;
    }
  }
}

L
Li Xinqi 已提交
398
void TaskGraph::GetInplaceOpBlobArgList(
399
    InplaceObasInfo* obas_info, const HashSet<TaskNode*>& dev_nodes,
L
Li Xinqi 已提交
400
    const std::function<const TaskNode*(const std::string&)>& TaskNode4OpName) const {
401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417
  auto AddMutableInplaceArgPair = [&](TaskNode* node, const std::string& ibn,
                                      const std::string& obn, const std::string& op_name) {
    if (IsInplaceAllowed(node, {ibn, obn}, TaskNode4OpName)) {
      auto* pair = obas_info->mut_inplace_oba_pairs.mutable_pair()->Add();
      *pair->mutable_first() = GenOpBlobArg(op_name, ibn);
      *pair->mutable_second() = GenOpBlobArg(op_name, obn);
    }
  };
  auto AddConstInplaceArgPair = [&](TaskNode* node, const std::string& ibn, const std::string& obn,
                                    const std::string& op_name) {
    if (IsInplaceAllowed(node, {ibn, obn}, TaskNode4OpName)) {
      auto* pair = obas_info->con_inplace_oba_pairs.mutable_pair()->Add();
      *pair->mutable_first() = GenOpBlobArg(op_name, ibn);
      *pair->mutable_second() = GenOpBlobArg(op_name, obn);
    }
  };

L
Li Xinqi 已提交
418 419 420 421 422 423
  for (TaskNode* task_node : dev_nodes) {
    if (task_node->exec_gph().node_num() != 1) { continue; }
    const auto& op = *task_node->exec_gph().SoleNode()->op();
    for (const std::string& ibn : op.input_bns()) {
      if (op.InputBlobModifier4Ibn(ibn).is_mutable()) {
        CHECK(IsInplaceAllowed(task_node, {ibn}, TaskNode4OpName));
424
        *obas_info->mut_in_obas.mutable_oba()->Add() = GenOpBlobArg(op.op_name(), ibn);
L
Li Xinqi 已提交
425 426 427
      }
    }
    for (const std::string& obn : op.output_bns()) {
428 429 430 431 432 433 434 435 436 437
      const auto& obn_modifier = op.OutputBlobModifier4Obn(obn);
      if (obn_modifier.has_mutable_inplace_ibn()) {
        AddMutableInplaceArgPair(task_node, obn_modifier.mutable_inplace_ibn(), obn, op.op_name());
      } else if (obn_modifier.has_const_inplace_ibn()) {
        AddConstInplaceArgPair(task_node, obn_modifier.const_inplace_ibn(), obn, op.op_name());
      }
    }

    if (op.op_conf().has_user_conf()) {
      const OpContext* op_ctx = task_node->exec_gph().SoleNode()->op_context();
J
Juncheng 已提交
438 439
      const UserOpCtx* user_op_ctx = dynamic_cast<const UserOpCtx*>(op_ctx);
      CHECK_NOTNULL(user_op_ctx);
440 441
      for (const auto& pair : user_op_ctx->mut_inplace_obn2ibn) {
        AddMutableInplaceArgPair(task_node, pair.second, pair.first, op.op_name());
L
Li Xinqi 已提交
442
      }
443 444
      for (const auto& pair : user_op_ctx->con_inplace_obn2ibn) {
        AddConstInplaceArgPair(task_node, pair.second, pair.first, op.op_name());
L
Li Xinqi 已提交
445
      }
L
Li Xinqi 已提交
446
    }
L
Li Xinqi 已提交
447 448 449 450
  }
}

void TaskGraph::GetSafeInplaceOpBlobArgList(
451
    InplaceObasInfo* safe_obas_info, const HashSet<TaskNode*>& dev_nodes,
J
Juncheng 已提交
452 453
    const std::function<bool(const std::string&, const std::string&)>& IsOpNameDataOrCtrlReachable)
    const {
L
Li Xinqi 已提交
454
  auto TaskNode4SoleOpName = MakeGetterTaskNode4SoleOpName(dev_nodes);
455 456
  InplaceObasInfo obas_info;
  GetInplaceOpBlobArgList(&obas_info, dev_nodes, TaskNode4SoleOpName);
L
Li Xinqi 已提交
457 458 459
  auto Op4OpName = [&](const std::string& op_name) -> const Operator* {
    return TaskNode4SoleOpName(op_name)->exec_gph().SoleNode()->op().get();
  };
S
scxfjiang 已提交
460 461
  auto IsLbiAllConsumersReachable =
      MakePredicatorIsLbiAllConsumersReachable(TaskNode4SoleOpName, IsOpNameDataOrCtrlReachable);
462 463 464
  InplaceLbiGraph origin_graph(obas_info, Op4OpName);
  InplaceLbiGraph safe_graph(*safe_obas_info, Op4OpName);
  origin_graph.ComputeSafeInplaceObns(safe_obas_info, IsLbiAllConsumersReachable);
465
  if (Global<ResourceDesc, ForSession>::Get()->enable_debug_mode()) {
J
Juncheng 已提交
466 467 468 469 470
    origin_graph.ToDotWithFilePath(
        JoinPath("dot", "InplaceLbiGraph", GlobalJobDesc().job_name() + "_origin.dot"));
    safe_graph.ToDotWithFilePath(
        JoinPath("dot", "InplaceLbiGraph", GlobalJobDesc().job_name() + "_safe.dot"));
  }
L
Li Xinqi 已提交
471 472
}

473
void TaskGraph::SetTaskRegstInplaceInfo(const InplaceObasInfo& obas_info,
L
Li Xinqi 已提交
474 475 476 477 478
                                        const HashSet<TaskNode*>& dev_nodes) const {
  auto TaskNode4SoleOpName = MakeGetterTaskNode4SoleOpName(dev_nodes);
  auto Op4OpName = [&](const std::string& op_name) -> const Operator* {
    return TaskNode4SoleOpName(op_name)->exec_gph().SoleNode()->op().get();
  };
479
  InplaceLbiGraph inplace_gph(obas_info, Op4OpName);
J
Juncheng 已提交
480
  inplace_gph.ForEachConnectedComponent([&](const HashSet<const InplaceLbiNode*>& inplace_nodes) {
L
Li Xinqi 已提交
481 482 483 484 485 486 487
    for (const auto* inplace_node : inplace_nodes) {
      if (inplace_node->in_edges().empty()) { continue; }
      const auto* inplace_edge = inplace_node->SoleInEdge();
      auto* exec_node = TaskNode4SoleOpName(inplace_edge->op().op_name())->exec_gph().SoleNode();
      RegstDesc* in_regst = exec_node->RegstDesc4BnInOp(inplace_edge->ibn());
      RegstDesc* out_regst = exec_node->RegstDesc4BnInOp(inplace_edge->obn());
      out_regst->set_hint_inplace_consumed_regst_desc_id(in_regst->regst_desc_id());
L
Li Xinqi 已提交
488
    }
L
Li Xinqi 已提交
489 490 491 492 493 494 495 496 497 498 499 500 501 502 503
  });
}

void TaskGraph::ForEachGpuDeviceNodes(
    const std::function<void(const HashSet<TaskNode*>& dev_nodes)>& Handler) const {
  HashMap<std::pair<int64_t, int64_t>, HashSet<TaskNode*>> global_dev_phy_id2nodes;
  ForEachNode([&](TaskNode* task_node) {
    if (task_node->device_type() != DeviceType::kGPU) { return; }
    int64_t dev_phy_id = Global<IDMgr>::Get()->GetGpuPhyIdFromThrdId(task_node->thrd_id());
    global_dev_phy_id2nodes[{task_node->machine_id(), dev_phy_id}].emplace(task_node);
  });
  for (const auto& pair : global_dev_phy_id2nodes) { Handler(pair.second); }
}

void TaskGraph::EnableInplaceMemSharing(
S
scxfjiang 已提交
504 505
    const std::function<bool(const std::string&, const std::string&)>&
        IsOpNameDataOrCtrlReachable) {
L
Li Xinqi 已提交
506
  ForEachGpuDeviceNodes([&](const HashSet<TaskNode*>& dev_nodes) {
507 508 509
    InplaceObasInfo safe_inplace_obas_info;
    GetSafeInplaceOpBlobArgList(&safe_inplace_obas_info, dev_nodes, IsOpNameDataOrCtrlReachable);
    SetTaskRegstInplaceInfo(safe_inplace_obas_info, dev_nodes);
L
Li Xinqi 已提交
510 511 512
  });
}

J
Jinhui Yuan 已提交
513 514 515 516 517
void TaskGraph::SetAreaIdForNewNodes(const LogicalNode* src_logical,
                                     const LogicalNode* dst_logical) {
  CHECK(src_logical != nullptr && dst_logical != nullptr);
  ForEachNode([&](TaskNode* node) {
    if (node->area_id() != static_cast<int64_t>(kInvalidArea)) return;
518 519
    if (src_logical->GetAreaId() == dst_logical->GetAreaId()) {
      node->set_area_id(src_logical->GetAreaId());
J
Jinhui Yuan 已提交
520 521 522 523 524 525
    } else {
      node->set_area_id(static_cast<int64_t>(kBoundaryArea));
    }
  });
}

526 527
// Expands to the signature of the out-of-line TaskGraph member `method_name`. The parameter list
// comes from BLD_SUB_TSK_GPH_MTHD_ARGS(), which is defined outside this part of the file.
#define DEFINE_BLD_SUB_TASK_GRAPH_METHOD(method_name) \
  void TaskGraph::method_name BLD_SUB_TSK_GPH_MTHD_ARGS()
528 529

// Builds the boxing sub task graph between src_logical and dst_logical, one lbi at a time.
// With a single lbi the src comp tasks feed the builder directly; with several lbis each src
// comp task first gets a per-lbi BoxingIdentityTaskNode. The builder's outputs are connected
// one-to-one to the dst comp tasks, and any ctrl tasks it returns are wired to the matching dst
// comp task.
DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByBoxing) {
  const std::vector<LogicalBlobId> lbis = src_logical->GetLbisTo(dst_logical);
  const bool needs_identity_nodes = lbis.size() != 1;
  for (const LogicalBlobId& lbi : lbis) {
    std::vector<TaskNode*> in_nodes;
    if (needs_identity_nodes) {
      for (CompTaskNode* src_node : sorted_src_comp_tasks) {
        auto* identity_node = NewNode<BoxingIdentityTaskNode>();
        identity_node->Init(src_node->machine_id(), src_node->thrd_id(), src_node->area_id(), lbi);
        Connect<TaskNode>(src_node, NewEdge(), identity_node);
        in_nodes.push_back(identity_node);
      }
    } else {
      in_nodes.assign(sorted_src_comp_tasks.begin(), sorted_src_comp_tasks.end());
    }

    std::vector<TaskNode*> out_nodes;
    out_nodes.reserve(sorted_dst_comp_tasks.size());
    std::vector<std::vector<TaskNode*>> sorted_ctrl_tasks;

    const SbpParallel& src_sbp_parallel =
        Global<OpGraph>::Get()->GetSbpParallel(src_logical->SoleOp()->op_name(), lbi);
    const SbpParallel& dst_sbp_parallel =
        Global<OpGraph>::Get()->GetSbpParallel(dst_logical->SoleOp()->op_name(), lbi);
    const std::shared_ptr<const ParallelDesc>& src_parallel_desc = src_logical->parallel_desc();
    const std::shared_ptr<const ParallelDesc>& dst_parallel_desc = dst_logical->parallel_desc();
    const BlobDesc& blob_desc = Global<OpGraph>::Get()->GetLogicalBlobDesc(lbi);

    auto status = CHECK_JUST(sub_tsk_gph_builder_->Build(
        sub_tsk_gph_builder_ctx_.get(), in_nodes, &out_nodes, &sorted_ctrl_tasks,
        *src_parallel_desc, *dst_parallel_desc, lbi, blob_desc, src_sbp_parallel, dst_sbp_parallel,
        *src_logical->out_blob_time_shape()));
    boxing_logger_->Log(*status, src_logical->SoleOp()->op_name(), dst_logical->SoleOp()->op_name(),
                        *src_parallel_desc, *dst_parallel_desc, src_sbp_parallel, dst_sbp_parallel,
                        lbi, blob_desc);
    sub_tsk_gph_builder_ctx_->ConnectAll121(out_nodes, sorted_dst_comp_tasks);

    if (sorted_ctrl_tasks.empty()) { continue; }
    // When present, ctrl tasks come as one list per dst comp task.
    CHECK_EQ(sorted_ctrl_tasks.size(), sorted_dst_comp_tasks.size());
    FOR_RANGE(size_t, i, 0, sorted_dst_comp_tasks.size()) {
      for (TaskNode* ctrl_node : sorted_ctrl_tasks.at(i)) {
        Connect<TaskNode>(ctrl_node, NewEdge(), sorted_dst_comp_tasks.at(i));
        ctrl_node->BuildCtrlRegstDesc(sorted_dst_comp_tasks.at(i));
      }
    }
  }
}

573
// Builds a task path from the i-th src comp task to the i-th dst comp task; both lists must have
// the same size. Buffer task nodes are reused (use_buf_task_node = true).
DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByOneToOne) {
  CHECK_EQ(sorted_src_comp_tasks.size(), sorted_dst_comp_tasks.size());
  FOR_RANGE(size_t, idx, 0, sorted_src_comp_tasks.size()) {
    CompTaskNode* src_task = sorted_src_comp_tasks.at(idx);
    CompTaskNode* dst_task = sorted_dst_comp_tasks.at(idx);
    BuildTaskPath(src_task, dst_task, MutBufTask, true);
  }
}

582
// For every dst comp task, builds a task path from the src comp task selected by
// SubTskGphBuilderUtil::FindNearestNode.
DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByBroadcastToBroadcast) {
  for (CompTaskNode* dst_task : sorted_dst_comp_tasks) {
    CompTaskNode* src_task = SubTskGphBuilderUtil::FindNearestNode(sorted_src_comp_tasks, dst_task);
    CHECK_NOTNULL(src_task);
    BuildTaskPath(src_task, dst_task, MutBufTask, true);
  }
}

L
Li Xinqi 已提交
591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620
// Single src task, one dst task per dst input bn: the i-th dst task is connected to the src task
// only if the lbi of the i-th dst input bn is produced by the src op.
DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByPartialInLbiConnect) {
  const auto& src_op = *src_logical->SoleOp();
  HashSet<LogicalBlobId> produced_lbis;
  for (const auto& obn : src_op.output_bns()) { produced_lbis.insert(src_op.BnInOp2Lbi(obn)); }
  CHECK_EQ(sorted_src_comp_tasks.size(), 1);

  const auto& dst_op = *dst_logical->SoleOp();
  CHECK_EQ(dst_op.input_bns().size(), sorted_dst_comp_tasks.size());
  FOR_RANGE(int, i, 0, sorted_dst_comp_tasks.size()) {
    const auto& lbi = dst_op.BnInOp2Lbi(dst_op.input_bns().Get(i));
    if (produced_lbis.find(lbi) == produced_lbis.end()) { continue; }
    BuildTaskPath(sorted_src_comp_tasks.at(0), sorted_dst_comp_tasks.at(i), MutBufTask, true);
  }
}

// Single dst task, one src task per src output bn: the i-th src task is connected to the dst task
// only if the lbi of the i-th src output bn is consumed by the dst op.
DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByPartialOutLbiConnect) {
  const auto& dst_op = *dst_logical->SoleOp();
  HashSet<LogicalBlobId> consumed_lbis;
  for (const auto& ibn : dst_op.input_bns()) { consumed_lbis.insert(dst_op.BnInOp2Lbi(ibn)); }
  CHECK_EQ(sorted_dst_comp_tasks.size(), 1);

  const auto& src_op = *src_logical->SoleOp();
  CHECK_EQ(src_op.output_bns().size(), sorted_src_comp_tasks.size());
  FOR_RANGE(int, i, 0, sorted_src_comp_tasks.size()) {
    const auto& lbi = src_op.BnInOp2Lbi(src_op.output_bns().Get(i));
    if (consumed_lbis.find(lbi) == consumed_lbis.end()) { continue; }
    BuildTaskPath(sorted_src_comp_tasks.at(i), sorted_dst_comp_tasks.at(0), MutBufTask, true);
  }
}

621 622 623 624 625 626 627 628 629
// Directly connects the i-th src comp task to the i-th dst comp task with a fresh edge; both
// lists must have the same size.
DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphNormalForwardToDecodeH2D) {
  CHECK_EQ(sorted_src_comp_tasks.size(), sorted_dst_comp_tasks.size());
  FOR_RANGE(size_t, idx, 0, sorted_src_comp_tasks.size()) {
    CompTaskNode* src_task = sorted_src_comp_tasks.at(idx);
    CompTaskNode* dst_task = sorted_dst_comp_tasks.at(idx);
    Connect<TaskNode>(src_task, NewEdge(), dst_task);
  }
}

J
Jinhui Yuan 已提交
630
// Builds the chain of tasks that leads from `src` to `dst`.
//
// Starting at `src`, BuildTaskStep() is called repeatedly until the current
// node has the same machine id and the same MemZoneId121() as `dst`. If the
// node reached that way is not `dst` itself, a final edge to `dst` is added.
//
// `MutBufTask` returns the address of the slot holding the task recorded for
// (src, machine_id, mem_zone_id); it is used both to read and to fill slots.
// `use_buf_task_node` is forwarded unchanged to BuildTaskStep().
// `src` and `dst` must be different nodes (CHECKed).
void TaskGraph::BuildTaskPath(
    CompTaskNode* src, CompTaskNode* dst,
    std::function<TaskNode**(CompTaskNode* src, int64_t machine_id, int32_t mem_zone_id)>
        MutBufTask,
    bool use_buf_task_node) {
  CHECK_NE(src, dst);

  // Reads the slot for (src, machine_id, mem_zone_id).
  auto ReadSlot = [&](int64_t machine_id, int32_t mem_zone_id) -> TaskNode* {
    return *MutBufTask(src, machine_id, mem_zone_id);
  };
  // Fills an empty slot with `new_val`; a filled slot must already hold `new_val`.
  auto WriteSlot = [&](int64_t machine_id, int32_t mem_zone_id, TaskNode* new_val) -> TaskNode* {
    TaskNode*& slot = *MutBufTask(src, machine_id, mem_zone_id);
    if (slot != nullptr) {
      CHECK_EQ(slot, new_val);
    } else {
      slot = new_val;
    }
    return new_val;
  };
  // True once `node` shares both machine id and memory zone id with `dst`.
  auto SharesPlaceWithDst = [dst](TaskNode* node) {
    return node->machine_id() == dst->machine_id()
           && node->MemZoneId121() == dst->MemZoneId121();
  };

  TaskNode* cur_node = src;
  while (!SharesPlaceWithDst(cur_node)) {
    cur_node = BuildTaskStep(cur_node, dst, ReadSlot, WriteSlot, use_buf_task_node);
  }
  if (cur_node == dst) { return; }
  Connect<TaskNode>(cur_node, NewEdge(), dst);
}
656

J
Jinhui Yuan 已提交
657
// Advances one hop from `cur_node` towards `dst` and returns the node reached.
//
// Exactly one of three cases applies:
//   1. `cur_node` is not in the CPU memory zone: the next node is a D2H copy
//      task on the same machine (target zone: CPU).
//   2. `cur_node` is in the CPU memory zone on the same machine as `dst`: the
//      next node is what TryAddCopyH2DTaskTo(dst) returns, or `dst` itself when
//      that returns nullptr (target zone: dst's zone).
//   3. `cur_node` is in the CPU memory zone on another machine: the next node
//      is a comm-net copy task towards dst's machine (target zone: CPU).
//
// When `use_buf_task_node` is true, `GetBufTask` is consulted first; a non-null
// result is reused as the next node and no new task or edge is created.
// Otherwise a new task is created and connected behind `cur_node`. Any next
// node other than `dst` is then recorded through `SetBufTask`.
//
// Params:
//   cur_node           node the step starts from.
//   dst                final destination of the path being built.
//   GetBufTask         looks up the task recorded for (machine_id, mem_zone_id);
//                      nullptr means none.
//   SetBufTask         records a task for (machine_id, mem_zone_id).
//   use_buf_task_node  enables the lookup/record behaviour described above.
// Returns: the next node on the path (never nullptr).
TaskNode* TaskGraph::BuildTaskStep(
    TaskNode* cur_node, TaskNode* dst,
    const std::function<TaskNode*(int64_t machine_id, int32_t mem_zone_id)>& GetBufTask,
    const std::function<TaskNode*(int64_t machine_id, int32_t mem_zone_id, TaskNode*)>& SetBufTask,
    bool use_buf_task_node) {
  const int32_t cpu_mem_zone_id = Global<IDMgr>::Get()->CpuMemZoneId();

  // Returns the recorded task when buffering is enabled, nullptr otherwise.
  // GetBufTask is deliberately not called when buffering is disabled.
  const auto FindBufTask = [&](int64_t machine_id, int32_t mem_zone_id) -> TaskNode* {
    return use_buf_task_node ? GetBufTask(machine_id, mem_zone_id) : nullptr;
  };

  int32_t next_mem_zone_id = -1;
  TaskNode* next_node = nullptr;
  if (cur_node->MemZoneId121() != cpu_mem_zone_id) {
    // Case 1: leave the non-CPU memory zone first.
    next_mem_zone_id = cpu_mem_zone_id;
    next_node = FindBufTask(cur_node->machine_id(), next_mem_zone_id);
    if (next_node == nullptr) {
      next_node = AddCopyD2HTaskFrom(cur_node);
      Connect<TaskNode>(cur_node, NewEdge(), next_node);
    }
  } else if (cur_node->machine_id() == dst->machine_id()) {
    // Case 2: already on dst's machine; move into dst's memory zone.
    next_mem_zone_id = dst->MemZoneId121();
    next_node = FindBufTask(cur_node->machine_id(), next_mem_zone_id);
    if (next_node == nullptr) {
      next_node = TryAddCopyH2DTaskTo(dst);
      if (next_node == nullptr) { next_node = dst; }
      Connect<TaskNode>(cur_node, NewEdge(), next_node);
    }
  } else {
    // Case 3: the machine ids differ. This is the exact complement of case 2,
    // so no further fallback branch can ever be reached.
    next_mem_zone_id = cpu_mem_zone_id;
    next_node = FindBufTask(dst->machine_id(), next_mem_zone_id);
    if (next_node == nullptr) {
      next_node = AddCopyCommNetTaskBetween(cur_node, dst);
      Connect<TaskNode>(cur_node, NewEdge(), next_node);
    }
  }

  // `dst` itself is never recorded; only intermediate nodes are.
  if (use_buf_task_node && (next_node != dst)) {
    SetBufTask(next_node->machine_id(), next_mem_zone_id, next_node);
  }
  return next_node;
}

L
Li Xinqi 已提交
693
// Creates an H2D copy task for `task`, or returns nullptr when none is made.
//
// No copy task is created (nullptr is returned) when IsInterfaceTask(task)
// holds or when the task's type is registered as a TickTockTaskType. Otherwise
// `task` must be a GPU task (CHECKed) and the new copy task is initialised with
// the task's machine id and GPU physical id.
TaskNode* TaskGraph::TryAddCopyH2DTaskTo(TaskNode* task) {
  const bool skip_copy = IsInterfaceTask(task)
                         || IsClassRegistered<int32_t, TickTockTaskType>(task->GetTaskType());
  if (skip_copy) { return nullptr; }

  CHECK_EQ(task->device_type(), DeviceType::kGPU);
  CopyHdTaskNode* h2d_task = NewNode<CopyHdTaskNode>();
  h2d_task->Init(CopyHdOpConf::H2D, task->machine_id(), task->GpuPhyId());
  return h2d_task;
}

J
Jinhui Yuan 已提交
702 703
// Creates a D2H copy task for `task` and returns it.
//
// `task` must be a GPU task (CHECKed). The new copy task is initialised with
// the task's machine id and GPU physical id. No edge is added here; connecting
// the returned node is left to the caller.
TaskNode* TaskGraph::AddCopyD2HTaskFrom(TaskNode* task) {
  CHECK_EQ(task->device_type(), DeviceType::kGPU);
  CopyHdTaskNode* d2h_task = NewNode<CopyHdTaskNode>();
  d2h_task->Init(CopyHdOpConf::D2H, task->machine_id(), task->GpuPhyId());
  return d2h_task;
}
708

J
Jinhui Yuan 已提交
709
// Creates a comm-net copy task between `src` and `dst` and returns it.
//
// The two nodes must be on different machines (CHECKed). The new task is
// initialised with dst's machine id followed by src's machine id. No edge is
// added here; connecting the returned node is left to the caller.
TaskNode* TaskGraph::AddCopyCommNetTaskBetween(TaskNode* src, TaskNode* dst) {
  CHECK_NE(src->machine_id(), dst->machine_id());
  CopyCommNetTaskNode* comm_net_task = NewNode<CopyCommNetTaskNode>();
  comm_net_task->Init(dst->machine_id(), src->machine_id());
  return comm_net_task;
}

W
willzhang4a58 已提交
716
// Connects `src` to `dst`, routing through a comm-net copy task when the two
// nodes are on different machines and using a single direct edge otherwise.
void TaskGraph::ConnectWithCopyCommNetIfNeed(TaskNode* src, TaskNode* dst) {
  if (src->machine_id() != dst->machine_id()) {
    // Cross-machine: src -> comm-net copy -> dst.
    TaskNode* relay_task = AddCopyCommNetTaskBetween(src, dst);
    Connect<TaskNode>(src, NewEdge(), relay_task);
    Connect<TaskNode>(relay_task, NewEdge(), dst);
    return;
  }
  Connect(src, NewEdge(), dst);
}

L
Li Xinqi 已提交
726
// Reports whether the edge from `src` to `dst` is a back edge. The current
// implementation ignores both arguments and always answers false.
bool IsBackEdge(TaskNode* src, TaskNode* dst) {
  return false;
}
J
Jinhui Yuan 已提交
727

W
willzhang4a58 已提交
728
}  // namespace oneflow