task_graph.cpp 44.8 KB
Newer Older
J
jiyuan 已提交
1
#include "oneflow/core/graph/task_graph.h"
C
chengtbf 已提交
2
#include "oneflow/core/graph/normal_forward_compute_task_node.h"
J
Jinhui Yuan 已提交
3
#include "oneflow/core/graph/chain_graph.h"
W
Will Zhang 已提交
4
#include "oneflow/core/graph/boxing_task_node.h"
J
Jinhui Yuan 已提交
5
#include "oneflow/core/common/util.h"
J
Jinhui Yuan 已提交
6
#include "oneflow/core/graph/reduce_add_compute_task_node.h"
L
Li Xinqi 已提交
7
#include "oneflow/core/graph/inplace_lbi_graph.h"
8
#include "oneflow/core/register/runtime_blob_desc.h"
J
Jinhui Yuan 已提交
9
#include "oneflow/core/job/thrd_id_generator.h"
L
Li Xinqi 已提交
10 11 12
#include "oneflow/core/graph/reduce_identity_task_node.h"
#include "oneflow/core/operator/variable_op.h"
#include "oneflow/core/operator/constant_op.h"
13
#include "oneflow/core/operator/user_op_util.h"
J
Juncheng 已提交
14 15 16 17 18
#include "oneflow/core/graph/op_graph.h"
#include "oneflow/core/graph/boxing/sub_task_graph_builder_context.h"
#include "oneflow/core/graph/boxing/sub_task_graph_builder.h"
#include "oneflow/core/graph/boxing/chain_sub_task_graph_builder.h"
#include "oneflow/core/graph/boxing/nccl_boxing_sub_task_graph_builder.h"
L
Li Xinqi 已提交
19
#include "oneflow/core/graph/boxing/slice_boxing_sub_task_graph_builder.h"
J
Juncheng 已提交
20
#include "oneflow/core/graph/boxing/sub_task_graph_builder_util.h"
W
willzhang4a58 已提交
21 22 23

namespace oneflow {

L
Li Xinqi 已提交
24 25
namespace {

L
lixinqi 已提交
26
bool IsInterfaceTask(const TaskNode* node) {
L
Li Xinqi 已提交
27 28 29 30
  const auto* comp_task_node = dynamic_cast<const CompTaskNode*>(node);
  if (comp_task_node == nullptr) { return false; }
  if (comp_task_node->logical_node()->op_vec().size() != 1) { return false; }
  auto op_type_case = comp_task_node->logical_node()->SoleOp()->op_conf().op_type_case();
L
lixinqi 已提交
31
  return IsClassRegistered<IsInterfaceOpConf4OpTypeCase>(op_type_case);
L
Li Xinqi 已提交
32 33
}

L
Li Xinqi 已提交
34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53
bool IsConnectToTickOp(const TaskNode* node) {
  const auto* comp_task_node = dynamic_cast<const CompTaskNode*>(node);
  if (comp_task_node == nullptr) { return false; }
  if (comp_task_node->logical_node()->op_vec().size() != 1) { return false; }
  const Operator* op = comp_task_node->logical_node()->SoleOp().get();
  if (dynamic_cast<const VariableOp*>(op) != nullptr) { return true; }
  if (dynamic_cast<const ConstantOp*>(op) != nullptr) { return true; }
  return false;
}

// Calls `Handler` on every node of `fw_nodes` that (a) has no in-edge whose source is also in
// `fw_nodes`, i.e. it is a source within this set, and (b) has no backward node.
// NOTE: the backward-node test is currently stubbed to always report "no backward node".
void ForEachDeviceSrcUntrainableNode(const std::vector<NormalForwardCompTaskNode*>& fw_nodes,
                                     const std::function<void(CompTaskNode*)>& Handler) {
  const HashSet<const TaskNode*> members(fw_nodes.begin(), fw_nodes.end());
  auto HasInEdgeFromMembers = [&](NormalForwardCompTaskNode* node) {
    for (TaskEdge* in_edge : node->in_edges()) {
      if (members.count(in_edge->src_node()) > 0) { return true; }
    }
    return false;
  };
  auto HasBwNode = [&](NormalForwardCompTaskNode* node) {
    // TODO: update method for fw bw split
    // const auto* fw_logical_node = dynamic_cast<const ForwardLogicalNode*>(node->logical_node());
    // return fw_logical_node->bw_node() != nullptr;
    return false;
  };
  for (NormalForwardCompTaskNode* fw_node : fw_nodes) {
    if (HasInEdgeFromMembers(fw_node) || HasBwNode(fw_node)) { continue; }
    Handler(fw_node);
  }
}

L
Li Xinqi 已提交
64 65 66 67 68 69 70 71 72 73 74 75 76 77 78
// Indexes the task nodes of `task_nodes` whose exec graph consists of a single exec node, keyed
// by the op name of that exec node. The returned getter maps an op name to its task node, or
// to nullptr when the name is unknown or shared by more than one task node.
std::function<TaskNode*(const std::string&)> MakeGetterTaskNode4SoleOpName(
    const HashSet<TaskNode*>& task_nodes) {
  auto op_name2task_nodes = std::make_shared<HashMap<std::string, HashSet<TaskNode*>>>();
  for (TaskNode* task_node : task_nodes) {
    if (task_node->exec_gph().node_num() != 1) { continue; }
    HashSet<TaskNode*>& same_name_nodes =
        (*op_name2task_nodes)[task_node->exec_gph().SoleNode()->op()->op_name()];
    CHECK(same_name_nodes.insert(task_node).second);
  }
  return [op_name2task_nodes](const std::string& op_name) -> TaskNode* {
    const auto found = op_name2task_nodes->find(op_name);
    if (found == op_name2task_nodes->end() || found->second.size() > 1) { return nullptr; }
    return *found->second.begin();
  };
}
L
Li Xinqi 已提交
80 81 82 83 84 85 86 87 88

// Returns true iff at least one register bound to `edge` contains the logical blob `lbi`.
bool IsLbiOnTaskEdge(const TaskEdge* edge, const LogicalBlobId& lbi) {
  bool found = false;
  for (const auto& regst : edge->GetRegsts()) {
    found = regst->HasLbi(lbi);
    if (found) { break; }
  }
  return found;
}

std::function<bool(const LogicalBlobId&, const std::string&)>
S
scxfjiang 已提交
89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108
MakePredicatorIsLbiAllConsumersReachable(
    std::function<const TaskNode*(const std::string&)> TaskNode4SoleOpName,
    std::function<bool(const std::string&, const std::string&)> IsOpNameDataOrCtrlReachable) {
  auto IsDataOrCtrlReachable = [IsOpNameDataOrCtrlReachable](const TaskNode* src_node,
                                                             const TaskNode* dst_node) -> bool {
    if (src_node->chain_id() == dst_node->chain_id()
        && src_node->order_in_graph() <= dst_node->order_in_graph()) {
      return true;
    }
    const CompTaskNode* comp_src_node = dynamic_cast<const CompTaskNode*>(src_node);
    if (comp_src_node == nullptr) { return false; }
    if (comp_src_node->logical_node()->op_vec().size() != 1) { return false; }
    const CompTaskNode* comp_dst_node = dynamic_cast<const CompTaskNode*>(dst_node);
    if (comp_dst_node == nullptr) { return false; }
    if (comp_dst_node->logical_node()->op_vec().size() != 1) { return false; }
    return IsOpNameDataOrCtrlReachable(comp_src_node->logical_node()->SoleOp()->op_name(),
                                       comp_dst_node->logical_node()->SoleOp()->op_name());
  };
  return [TaskNode4SoleOpName, IsDataOrCtrlReachable](const LogicalBlobId& lbi,
                                                      const std::string& op_name) -> bool {
L
Li Xinqi 已提交
109 110 111
    const TaskNode* src_task_node = TaskNode4SoleOpName(lbi.op_name());
    const TaskNode* dst_task_node = TaskNode4SoleOpName(op_name);
    size_t out_edges_size = 0;
L
Li Xinqi 已提交
112
    size_t reachable_out_edges_size = 0;
L
Li Xinqi 已提交
113 114 115
    for (TaskEdge* out_edge : src_task_node->out_edges()) {
      if (IsLbiOnTaskEdge(out_edge, lbi)) {
        out_edges_size += 1;
S
scxfjiang 已提交
116
        reachable_out_edges_size += IsDataOrCtrlReachable(out_edge->dst_node(), dst_task_node);
L
Li Xinqi 已提交
117 118
      }
    }
L
Li Xinqi 已提交
119
    return out_edges_size > 0 && out_edges_size == reachable_out_edges_size;
L
Li Xinqi 已提交
120 121 122 123 124 125 126 127 128 129 130 131 132 133
  };
}

// Decides whether the blobs named by `bns` of the sole exec node of `task_node` may share
// memory in place. Requirements:
//   - `task_node` has exactly one exec node;
//   - every blob's producer op resolves through `TaskNode4SoleOpName`, which yields nullptr
//     for ops that are not on the same device as `task_node`;
//   - every blob's register holds exactly one logical blob;
//   - all blobs agree with the first one on shape and data type.
bool IsInplaceAllowed(
    TaskNode* task_node, const std::vector<std::string>& bns,
    const std::function<const TaskNode*(const std::string&)>& TaskNode4SoleOpName) {
  if (task_node->exec_gph().node_num() != 1) { return false; }
  const ExecNode& exec_node = *task_node->exec_gph().SoleNode();
  for (const std::string& bn : bns) {
    if (TaskNode4SoleOpName(exec_node.op()->BnInOp2Lbi(bn).op_name()) == nullptr) { return false; }
    if (exec_node.RegstDesc4BnInOp(bn)->NumOfLbi() != 1) { return false; }
  }
  const BlobDesc* reference = nullptr;
  for (const std::string& bn : bns) {
    const BlobDesc* current = exec_node.RegstDesc4BnInOp(bn)->SoleBlobDesc();
    if (reference == nullptr) {
      reference = current;
      continue;
    }
    if (!(reference->shape() == current->shape())
        || reference->data_type() != current->data_type()) {
      return false;
    }
  }
  return true;
}

L
Li Xinqi 已提交
149 150
}  // namespace

151 152
// Builds the task graph from `logical_gph`, taking ownership of it:
//   1. expands every logical node into its sorted compute task nodes, records them per logical
//      node and gives each one the logical node's area id;
//   2. assigns thread ids to the task nodes collected in `machine_persistence_task_vec`;
//   3. for every logical edge, builds its sub task graph with the method returned by
//      GetMthdForBldSubTskGph, then gives an area id to task nodes still lacking one;
//   4. connects ctrl edges for every pair reported by ForEachNecessaryCtrlEdge;
//   5. merges chains and sets chain id / order_in_graph on every task node.
// In debug mode the result is additionally dumped via ToDotWithAutoFilePath().
TaskGraph::TaskGraph(std::unique_ptr<const LogicalGraph>&& logical_gph) {
  logical_gph_ = std::move(logical_gph);
  HashMap<const LogicalNode*, std::vector<CompTaskNode*>> logical2sorted_comp_tasks;
  HashMap<const LogicalNode*, std::vector<TaskNode*>> logical2sorted_in_box;
  HashMap<const LogicalNode*, std::vector<TaskNode*>> logical2sorted_out_box;
  // Per (comp task, machine) table with one buffer-task slot per memory zone, created lazily;
  // returns a pointer to the slot for `mem_zone_id` so the caller can read or fill it.
  HashMap<CompTaskNode*, HashMap<int64_t, std::vector<TaskNode*>>> buf_task;
  auto MutBufTask = [&](CompTaskNode* task_node, int64_t machine_id, int32_t mem_zone_id) {
    auto& buf_vec = buf_task[task_node][machine_id];
    if (buf_vec.empty()) { buf_vec.assign(Global<ResourceDesc>::Get()->MemZoneNum(), nullptr); }
    return &(buf_vec.at(mem_zone_id));
  };

  // Hands out CPU device thread ids round-robin, with a separate cursor per machine.
  std::vector<int64_t> cpu_device_offset(Global<ResourceDesc>::Get()->TotalMachineNum(), 0);
  auto AllocateCpuThrdIdEvenly = [&](const TaskNode* task_node) {
    CHECK(!task_node->IsIndependent());
    int64_t& offset = cpu_device_offset.at(task_node->machine_id());
    const int64_t thrd_id = Global<IDMgr>::Get()->GetCpuDeviceThrdId(offset);
    offset = (offset + 1) % Global<ResourceDesc>::Get()->CpuDeviceNum();
    return thrd_id;
  };

  std::vector<std::pair<int64_t, CompTaskNode*>> machine_persistence_task_vec;
  logical_gph_->ForEachNode([&](const LogicalNode* logical_node) {
    logical_node->GenSortedCompTaskNodes(
        AllocateCpuThrdIdEvenly, &machine_persistence_task_vec, [&](CompTaskNode* comp_task_node) {
          AddAllocatedNode(comp_task_node);
          logical2sorted_comp_tasks[logical_node].push_back(comp_task_node);
          comp_task_node->set_area_id(logical_node->GetAreaId());
        });
  });

  GenerateIndependentThrdId(machine_persistence_task_vec);
  logical_gph_->ForEachEdge([&](const LogicalEdge* logical_edge) {
    BldSubTskGphMthd method =
        GetMthdForBldSubTskGph(logical_edge->src_node(), logical_edge->dst_node());
    (this->*method)(logical_edge->src_node(), logical_edge->dst_node(),
                    logical2sorted_comp_tasks.at(logical_edge->src_node()),
                    logical2sorted_comp_tasks.at(logical_edge->dst_node()), &logical2sorted_in_box,
                    &logical2sorted_out_box, MutBufTask, AllocateCpuThrdIdEvenly);
    SetAreaIdForNewNodes(logical_edge->src_node(), logical_edge->dst_node());
  });
  logical_gph_->ForEachNecessaryCtrlEdge(
      [&](const LogicalNode* src, const LogicalNode* dst, int64_t ctrl_regst_num) {
        const auto& src_task_nodes = logical2sorted_comp_tasks.at(src);
        const auto& dst_task_nodes = logical2sorted_comp_tasks.at(dst);
        ConnectCtrlEdges(src_task_nodes, dst_task_nodes, ctrl_regst_num);
      });

  MergeChainAndSetOrderInGraphForEachNode();
  if (Global<ResourceDesc>::Get()->enable_debug_mode()) { ToDotWithAutoFilePath(); }
}

L
Li Xinqi 已提交
204 205 206 207 208
// Pairs src_task_nodes[i] with dst_task_nodes[i] (both lists must have the same length). For
// each pair it builds a ctrl regst on the source, updates its min/max regst num (if needed)
// and sets its returned regst num from `ctrl_regst_num`, then connects the pair with a new
// task edge bound to that regst.
void TaskGraph::ConnectCtrlEdges(const std::vector<CompTaskNode*>& src_task_nodes,
                                 const std::vector<CompTaskNode*>& dst_task_nodes,
                                 int64_t ctrl_regst_num) {
  CHECK_EQ(src_task_nodes.size(), dst_task_nodes.size());
  for (size_t idx = 0; idx < src_task_nodes.size(); ++idx) {
    CompTaskNode* src = src_task_nodes.at(idx);
    CompTaskNode* dst = dst_task_nodes.at(idx);
    std::string ctrl_regst_name;
    RegstDesc* ctrl_regst = src->BuildCtrlRegstDesc(dst, &ctrl_regst_name);
    ctrl_regst->UpdtMinRegstNumIfNeed(ctrl_regst_num);
    ctrl_regst->UpdtMaxRegstNumIfNeed(ctrl_regst_num);
    ctrl_regst->mut_regst_desc_type()->mutable_ctrl_regst_desc()->set_returned_regst_num(
        ctrl_regst_num);

    TaskEdge* ctrl_edge = NewEdge();
    Connect<TaskNode>(src, ctrl_edge, dst);
    src->BindEdgeWithProducedRegst(ctrl_edge, ctrl_regst_name);
  }
}

223
void TaskGraph::GenerateIndependentThrdId(
J
Jinhui Yuan 已提交
224 225 226 227 228 229
    const std::vector<std::pair<int64_t, CompTaskNode*>>& persistence_nodes) {
  std::vector<std::pair<int64_t, TaskType>> machine_task_type_vec;
  for (auto pair : persistence_nodes) {
    machine_task_type_vec.emplace_back(std::make_pair(pair.first, pair.second->GetTaskType()));
  }

230
  ThrdIdGenerator generator(machine_task_type_vec, Global<IDMgr>::Get()->BaseIndependentThrdId());
J
Jinhui Yuan 已提交
231 232 233 234 235 236
  for (const auto pair : persistence_nodes) {
    int64_t thrd_id = generator.GenerateThrdId(pair.first, pair.second->GetTaskType());
    pair.second->set_thrd_id(thrd_id);
  }
}

L
Li Xinqi 已提交
237 238 239 240 241 242 243 244 245 246 247 248 249
void TaskGraph::MdUpdtDelayedTopoForEachNode(std::function<void(TaskNode* node)> Handler) const {
  HashSet<const TaskNode*> built_nodes;
  auto Build = [&](TaskNode* node) {
    CHECK(built_nodes.emplace(node).second);
    Handler(node);
  };
  AcyclicTopoForEachNode([](TaskNode* node) { return node->GetTaskType() != kNormalMdUpdt; },
                         Build);
  AcyclicTopoForEachNode([](TaskNode* node) { return node->GetTaskType() == kNormalMdUpdt; },
                         Build);
  ForEachNode([&](TaskNode* node) { CHECK(built_nodes.find(node) != built_nodes.end()); });
}

250 251 252
void TaskGraph::AcyclicTopoForEachNode(std::function<bool(TaskNode* node)> IsAllowedStartNode,
                                       std::function<void(TaskNode* node)> Handler) const {
  auto ForEachInNode = [&](TaskNode* node, const std::function<void(TaskNode*)>& Handler) {
253
    node->ForEachNodeOnInEdge([&](TaskNode* node_on_in_edge) {
L
Li Xinqi 已提交
254
      if (IsBackEdge(node_on_in_edge, node)) { return; }
255
      Handler(const_cast<TaskNode*>(node_on_in_edge));
256 257
    });
  };
258
  auto ForEachOutNode = [&](TaskNode* node, const std::function<void(TaskNode*)>& Handler) {
259
    node->ForEachNodeOnOutEdge([&](TaskNode* node_on_out_edge) {
L
Li Xinqi 已提交
260
      if (IsBackEdge(node, node_on_out_edge)) { return; }
261
      Handler(const_cast<TaskNode*>(node_on_out_edge));
262 263
    });
  };
L
Li Xinqi 已提交
264 265 266 267 268 269 270 271 272
  auto IsSourceNode = [&](TaskNode* node) {
    int32_t in_node_num = 0;
    ForEachInNode(node, [&](TaskNode* in_node) { ++in_node_num; });
    return in_node_num == 0;
  };
  std::list<TaskNode*> starts;
  ForEachNode([&](TaskNode* node) {
    if (IsSourceNode(node) && IsAllowedStartNode(node)) { starts.push_back(node); }
  });
273
  // DfsTopo will cause inappropriate chain graph
274 275 276 277 278
  TopoForEachNode(starts, ForEachInNode, ForEachOutNode, Handler);
}

void TaskGraph::AcyclicTopoForEachNode(std::function<void(TaskNode* node)> Handler) const {
  return AcyclicTopoForEachNode([](TaskNode*) { return true; }, Handler);
279 280
}

J
Jinhui Yuan 已提交
281 282 283 284
void TaskGraph::RemoveEmptyRegsts() {
  ForEachNode([&](TaskNode* node) { node->EraseZeroSizeProducedBlob(); });
  ForEachNode([&](TaskNode* node) { node->EraseZeroSizeConsumedRegst(); });
  ForEachNode([&](TaskNode* node) { node->EraseZeroSizeProducedRegst(); });
285
  ForEachNode([&](TaskNode* node) { node->UnbindBnWithEmptyRegst(); });
J
Jinhui Yuan 已提交
286 287
}

288
// Adds ordering ctrl edges inside each chain; the work is done by
// BuildCtrlRegstDescInSameChain().
void TaskGraph::AddOrderingCtrlEdgeInSameChain() {
  BuildCtrlRegstDescInSameChain();
}
J
Jinhui Yuan 已提交
289

290 291 292
void TaskGraph::MergeChainAndSetOrderInGraphForEachNode() {
  ChainGraph chain_graph(*this);
  const auto& ordered_chain_nodes = chain_graph.OrderdedChainNodes();
J
Jinhui Yuan 已提交
293 294
  int64_t order_in_graph = 0;
  for (auto& chain_node : ordered_chain_nodes) {
295
    auto& ordered_in_chain = chain_node->TaskNodes();
J
Jinhui Yuan 已提交
296 297 298 299 300 301 302 303 304 305
    int64_t chain_id = chain_node->chain_id();
    for (auto& task_node : ordered_in_chain) {
      task_node->set_chain_id(chain_id);
      task_node->set_order_in_graph(order_in_graph);
      ordered_task_nodes_.emplace_back(task_node);
      ++order_in_graph;
    }
  }
}

306
void TaskGraph::BuildCtrlRegstDescInSameChain() {
J
Jinhui Yuan 已提交
307
  HashMap<int64_t, TaskNode*> chain_id2node;
L
Li Xinqi 已提交
308 309
  for (auto* node : ordered_task_nodes_) {
    if (IsConnectToTickOp(node)) { continue; }
J
Jinhui Yuan 已提交
310 311 312 313 314 315 316 317 318 319 320
    int64_t chain_id = node->chain_id();
    auto iter = chain_id2node.find(chain_id);
    if (iter == chain_id2node.end()) {
      CHECK(chain_id2node.emplace(chain_id, node).second);
    } else {
      iter->second->BuildCtrlRegstDescIfNeed(node);
      iter->second = node;
    }
  }
}

L
Li Xinqi 已提交
321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360
// For every non-CPU thread running ReduceIdentityCompTaskNodes, picks the identity node with
// the smallest order_in_logical_graph. It then adds a ctrl regst to that identity node from
// each node ForEachDeviceSrcUntrainableNode selects among the forward nodes of the same
// (machine, GPU) device.
// Forward nodes without an "out" regst, or whose "out" time shape does not contain the
// identity node's, are skipped. The ctrl regst num is the ratio of the two time shapes'
// element counts.
void TaskGraph::AddReduceNoBwForwardNodeOverlapingCtrlEdges() {
  HashMap<int64_t, std::vector<ReduceIdentityCompTaskNode*>> global_thrd_id2identity_nodes;
  HashMap<std::pair<int64_t, int64_t>, std::vector<NormalForwardCompTaskNode*>>
      global_dev_phy_id2fw_nodes;
  const auto* id_mgr = Global<IDMgr>::Get();
  for (auto* node : ordered_task_nodes_) {
    if (id_mgr->GetDeviceTypeFromThrdId(node->thrd_id()) == DeviceType::kCPU) { continue; }
    int64_t global_thrd_id = id_mgr->GlobalThrdId4TaskId(node->task_id());
    auto* identity_node = dynamic_cast<ReduceIdentityCompTaskNode*>(node);
    auto* fw_node = dynamic_cast<NormalForwardCompTaskNode*>(node);
    if (identity_node != nullptr) {
      global_thrd_id2identity_nodes[global_thrd_id].push_back(identity_node);
    } else if (fw_node != nullptr) {
      int64_t dev_phy_id = id_mgr->GetGpuPhyIdFromThrdId(node->thrd_id());
      global_dev_phy_id2fw_nodes[std::make_pair(node->machine_id(), dev_phy_id)].push_back(fw_node);
    }
  }
  auto GetIdentityNodeOrder = [](const ReduceIdentityCompTaskNode* id_node) {
    const auto* id_logical_node =
        dynamic_cast<const ReduceIdentityLogicalNode*>(id_node->logical_node());
    // Fail loudly instead of dereferencing a null pointer on an unexpected logical node.
    CHECK_NOTNULL(id_logical_node);
    return id_logical_node->order_in_logical_graph();
  };
  for (auto& pair : global_thrd_id2identity_nodes) {
    auto& identity_nodes = pair.second;
    std::sort(identity_nodes.begin(), identity_nodes.end(),
              [&](ReduceIdentityCompTaskNode* lhs, ReduceIdentityCompTaskNode* rhs) {
                return GetIdentityNodeOrder(lhs) < GetIdentityNodeOrder(rhs);
              });
    auto* first_identity_node = identity_nodes.at(0);
    int64_t machine_id = first_identity_node->machine_id();
    int64_t dev_phy_id = id_mgr->GetGpuPhyIdFromThrdId(first_identity_node->thrd_id());
    // `.at()` used to throw std::out_of_range when the device had no forward node. With no
    // forward node there is nothing to overlap, which is what an empty list would yield.
    const auto fw_nodes_it =
        global_dev_phy_id2fw_nodes.find(std::make_pair(machine_id, dev_phy_id));
    if (fw_nodes_it == global_dev_phy_id2fw_nodes.end()) { continue; }
    const auto& fw_nodes = fw_nodes_it->second;
    const Shape& identity_time_shape =
        *first_identity_node->GetProducedRegst("out")->data_regst_time_shape();
    ForEachDeviceSrcUntrainableNode(fw_nodes, [&](CompTaskNode* node) {
      std::shared_ptr<RegstDesc> regst_desc = node->GetProducedRegst("out");
      if (!regst_desc) { return; }
      const Shape& time_shape = *regst_desc->data_regst_time_shape();
      if (!time_shape.Containing(identity_time_shape)) { return; }
      CHECK_EQ(time_shape.elem_cnt() % identity_time_shape.elem_cnt(), 0);
      int regst_desc_num = time_shape.elem_cnt() / identity_time_shape.elem_cnt();
      RegstDesc* ctrl_regst_desc = node->BuildCtrlRegstDesc(first_identity_node);
      ctrl_regst_desc->UpdtMinRegstNumIfNeed(regst_desc_num);
      ctrl_regst_desc->UpdtMaxRegstNumIfNeed(regst_desc_num);
      ctrl_regst_desc->mut_regst_desc_type()->mutable_ctrl_regst_desc()->set_returned_regst_num(
          regst_desc_num);
    });
  }
}

373
void TaskGraph::EnableInplaceMemSharingInReduceStruct() {
J
Juncheng 已提交
374
  auto GetSuccReduceTaskNode = [](TaskNode* pred) {
J
Jinhui Yuan 已提交
375
    std::vector<TaskNode*> nodes;
J
Juncheng 已提交
376
    pred->ForEachNodeOnOutDataEdge([&](TaskNode* succ) {
J
Juncheng 已提交
377
      if (dynamic_cast<ReduceCompTaskNodeIf*>(succ) != nullptr) { nodes.push_back(succ); }
J
Jinhui Yuan 已提交
378 379
    });
    return nodes;
380 381
  };

J
Jinhui Yuan 已提交
382 383 384 385
  HashSet<TaskNode*> has_enabled_nodes;

  auto CollectReduceTaskNode = [&](TaskNode* from) {
    std::list<TaskNode*> nodes;
J
Juncheng 已提交
386 387
    nodes.push_back(from);
    TaskNode* pred = from;
J
Jinhui Yuan 已提交
388
    while (true) {
J
Juncheng 已提交
389 390 391 392 393 394
      std::vector<TaskNode*> succ_reduce_nodes = GetSuccReduceTaskNode(pred);
      if (succ_reduce_nodes.size() != 1) { break; }
      TaskNode* succ_reduce_node = succ_reduce_nodes.front();
      if (has_enabled_nodes.find(succ_reduce_node) != has_enabled_nodes.end()) { break; }
      nodes.push_back(succ_reduce_node);
      pred = succ_reduce_node;
395
    }
J
Jinhui Yuan 已提交
396
    return nodes;
397 398
  };

J
Juncheng 已提交
399 400
  auto CalcModelSize = [](ReduceIdentityCompTaskNode* node) {
    return InferRegstSize(*node->produced_regsts().at("out").get());
J
Jinhui Yuan 已提交
401
  };
402

J
Jinhui Yuan 已提交
403
  ForEachNode([&](TaskNode* node) {
J
Juncheng 已提交
404 405 406 407 408 409
    ReduceIdentityCompTaskNode* identity_node = dynamic_cast<ReduceIdentityCompTaskNode*>(node);
    if (!identity_node) { return; }
    if (identity_node->device_type() != DeviceType::kGPU) { return; }
    if (identity_node->parallel_ctx()->parallel_num() < 2) { return; }
    std::list<TaskNode*> reduce_task_nodes = CollectReduceTaskNode(identity_node);

410
    const int64_t mem_block_id = Global<IDMgr>::Get()->NewMemBlockId();
J
Juncheng 已提交
411
    const int64_t mem_size = CalcModelSize(identity_node);
412
    ReduceMemSharingCtx ctx(mem_size, mem_block_id);
J
Jinhui Yuan 已提交
413 414
    for (TaskNode* reduce_node : reduce_task_nodes) {
      auto reduce_task_node_if = dynamic_cast<ReduceCompTaskNodeIf*>(reduce_node);
J
Juncheng 已提交
415
      CHECK_NOTNULL(reduce_task_node_if);
J
Jinhui Yuan 已提交
416 417
      reduce_task_node_if->EnableMemSharingInReduce(ctx);
      has_enabled_nodes.insert(reduce_node);
418 419 420 421
    }
  });
}

L
Li Xinqi 已提交
422
void TaskGraph::GetInplaceOpBlobArgList(
423
    InplaceObasInfo* obas_info, const HashSet<TaskNode*>& dev_nodes,
L
Li Xinqi 已提交
424
    const std::function<const TaskNode*(const std::string&)>& TaskNode4OpName) const {
425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441
  auto AddMutableInplaceArgPair = [&](TaskNode* node, const std::string& ibn,
                                      const std::string& obn, const std::string& op_name) {
    if (IsInplaceAllowed(node, {ibn, obn}, TaskNode4OpName)) {
      auto* pair = obas_info->mut_inplace_oba_pairs.mutable_pair()->Add();
      *pair->mutable_first() = GenOpBlobArg(op_name, ibn);
      *pair->mutable_second() = GenOpBlobArg(op_name, obn);
    }
  };
  auto AddConstInplaceArgPair = [&](TaskNode* node, const std::string& ibn, const std::string& obn,
                                    const std::string& op_name) {
    if (IsInplaceAllowed(node, {ibn, obn}, TaskNode4OpName)) {
      auto* pair = obas_info->con_inplace_oba_pairs.mutable_pair()->Add();
      *pair->mutable_first() = GenOpBlobArg(op_name, ibn);
      *pair->mutable_second() = GenOpBlobArg(op_name, obn);
    }
  };

L
Li Xinqi 已提交
442 443 444 445 446 447
  for (TaskNode* task_node : dev_nodes) {
    if (task_node->exec_gph().node_num() != 1) { continue; }
    const auto& op = *task_node->exec_gph().SoleNode()->op();
    for (const std::string& ibn : op.input_bns()) {
      if (op.InputBlobModifier4Ibn(ibn).is_mutable()) {
        CHECK(IsInplaceAllowed(task_node, {ibn}, TaskNode4OpName));
448
        *obas_info->mut_in_obas.mutable_oba()->Add() = GenOpBlobArg(op.op_name(), ibn);
L
Li Xinqi 已提交
449 450 451
      }
    }
    for (const std::string& obn : op.output_bns()) {
452 453 454 455 456 457 458 459 460 461 462 463 464
      const auto& obn_modifier = op.OutputBlobModifier4Obn(obn);
      if (obn_modifier.has_mutable_inplace_ibn()) {
        AddMutableInplaceArgPair(task_node, obn_modifier.mutable_inplace_ibn(), obn, op.op_name());
      } else if (obn_modifier.has_const_inplace_ibn()) {
        AddConstInplaceArgPair(task_node, obn_modifier.const_inplace_ibn(), obn, op.op_name());
      }
    }

    if (op.op_conf().has_user_conf()) {
      const OpContext* op_ctx = task_node->exec_gph().SoleNode()->op_context();
      const UserOpCtx* user_op_ctx = static_cast<const UserOpCtx*>(op_ctx);
      for (const auto& pair : user_op_ctx->mut_inplace_obn2ibn) {
        AddMutableInplaceArgPair(task_node, pair.second, pair.first, op.op_name());
L
Li Xinqi 已提交
465
      }
466 467
      for (const auto& pair : user_op_ctx->con_inplace_obn2ibn) {
        AddConstInplaceArgPair(task_node, pair.second, pair.first, op.op_name());
L
Li Xinqi 已提交
468
      }
L
Li Xinqi 已提交
469
    }
L
Li Xinqi 已提交
470 471 472 473
  }
}

void TaskGraph::GetSafeInplaceOpBlobArgList(
474
    InplaceObasInfo* safe_obas_info, const HashSet<TaskNode*>& dev_nodes,
S
scxfjiang 已提交
475
    std::function<bool(const std::string&, const std::string&)> IsOpNameDataOrCtrlReachable) const {
L
Li Xinqi 已提交
476
  auto TaskNode4SoleOpName = MakeGetterTaskNode4SoleOpName(dev_nodes);
477 478
  InplaceObasInfo obas_info;
  GetInplaceOpBlobArgList(&obas_info, dev_nodes, TaskNode4SoleOpName);
L
Li Xinqi 已提交
479 480 481
  auto Op4OpName = [&](const std::string& op_name) -> const Operator* {
    return TaskNode4SoleOpName(op_name)->exec_gph().SoleNode()->op().get();
  };
S
scxfjiang 已提交
482 483
  auto IsLbiAllConsumersReachable =
      MakePredicatorIsLbiAllConsumersReachable(TaskNode4SoleOpName, IsOpNameDataOrCtrlReachable);
484 485 486
  InplaceLbiGraph origin_graph(obas_info, Op4OpName);
  InplaceLbiGraph safe_graph(*safe_obas_info, Op4OpName);
  origin_graph.ComputeSafeInplaceObns(safe_obas_info, IsLbiAllConsumersReachable);
J
Juncheng 已提交
487 488 489 490 491 492
  if (Global<ResourceDesc>::Get()->enable_debug_mode()) {
    origin_graph.ToDotWithFilePath(
        JoinPath("dot", "InplaceLbiGraph", GlobalJobDesc().job_name() + "_origin.dot"));
    safe_graph.ToDotWithFilePath(
        JoinPath("dot", "InplaceLbiGraph", GlobalJobDesc().job_name() + "_safe.dot"));
  }
L
Li Xinqi 已提交
493 494
}

495
void TaskGraph::SetTaskRegstInplaceInfo(const InplaceObasInfo& obas_info,
L
Li Xinqi 已提交
496 497 498 499 500
                                        const HashSet<TaskNode*>& dev_nodes) const {
  auto TaskNode4SoleOpName = MakeGetterTaskNode4SoleOpName(dev_nodes);
  auto Op4OpName = [&](const std::string& op_name) -> const Operator* {
    return TaskNode4SoleOpName(op_name)->exec_gph().SoleNode()->op().get();
  };
501
  InplaceLbiGraph inplace_gph(obas_info, Op4OpName);
L
Li Xinqi 已提交
502 503 504 505 506 507 508 509
  inplace_gph.ForEachConnectedComponent([&](const HashSet<const InplaceLbiNode*> inplace_nodes) {
    for (const auto* inplace_node : inplace_nodes) {
      if (inplace_node->in_edges().empty()) { continue; }
      const auto* inplace_edge = inplace_node->SoleInEdge();
      auto* exec_node = TaskNode4SoleOpName(inplace_edge->op().op_name())->exec_gph().SoleNode();
      RegstDesc* in_regst = exec_node->RegstDesc4BnInOp(inplace_edge->ibn());
      RegstDesc* out_regst = exec_node->RegstDesc4BnInOp(inplace_edge->obn());
      out_regst->set_hint_inplace_consumed_regst_desc_id(in_regst->regst_desc_id());
L
Li Xinqi 已提交
510
    }
L
Li Xinqi 已提交
511 512 513 514 515 516 517 518 519 520 521 522 523 524 525
  });
}

void TaskGraph::ForEachGpuDeviceNodes(
    const std::function<void(const HashSet<TaskNode*>& dev_nodes)>& Handler) const {
  HashMap<std::pair<int64_t, int64_t>, HashSet<TaskNode*>> global_dev_phy_id2nodes;
  ForEachNode([&](TaskNode* task_node) {
    if (task_node->device_type() != DeviceType::kGPU) { return; }
    int64_t dev_phy_id = Global<IDMgr>::Get()->GetGpuPhyIdFromThrdId(task_node->thrd_id());
    global_dev_phy_id2nodes[{task_node->machine_id(), dev_phy_id}].emplace(task_node);
  });
  for (const auto& pair : global_dev_phy_id2nodes) { Handler(pair.second); }
}

void TaskGraph::EnableInplaceMemSharing(
S
scxfjiang 已提交
526 527
    const std::function<bool(const std::string&, const std::string&)>&
        IsOpNameDataOrCtrlReachable) {
L
Li Xinqi 已提交
528
  ForEachGpuDeviceNodes([&](const HashSet<TaskNode*>& dev_nodes) {
529 530 531
    InplaceObasInfo safe_inplace_obas_info;
    GetSafeInplaceOpBlobArgList(&safe_inplace_obas_info, dev_nodes, IsOpNameDataOrCtrlReachable);
    SetTaskRegstInplaceInfo(safe_inplace_obas_info, dev_nodes);
L
Li Xinqi 已提交
532 533 534
  });
}

C
chengtbf 已提交
535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553
// For every H2D CopyHdTaskNode in the data-forward or boundary area, this is meant to do three
// things:
//   - BFS along out edges into forward-type task nodes, collecting candidate forward nodes;
//   - sort the candidates by order_in_graph;
//   - for the first candidate of each chain, add a ctrl regst from the copy node to each
//     model-update in-neighbor of the candidate. That regst relies on the copy's "copy_out"
//     regst, with a min regst num of copy_out.min_register_num() + NumOfPiecesInBatch() - 1
//     and a returned regst num of NumOfPiecesInBatch().
// NOTE(review): the candidate filter is disabled. HandlerAddCandidate calls TODO() before the
// commented-out filter, so `candidate_nodes` can never be filled. If BfsForEachNode invokes the
// handler on the start node, any qualifying copy node hits TODO(). Confirm what TODO() does
// (presumably aborts) before enabling this pass.
void TaskGraph::AddOrderCtrlEdgeBetweenCopyAndMdUpdt() {
  for (TaskNode* task_node : ordered_task_nodes_) {
    auto copy_hd_task_node = dynamic_cast<CopyHdTaskNode*>(task_node);
    if (copy_hd_task_node == nullptr) { continue; }
    if (copy_hd_task_node->copy_type() != CopyHdOpConf::H2D) { continue; }
    if (copy_hd_task_node->area_id() != static_cast<int64_t>(kDataForwardArea)
        && copy_hd_task_node->area_id() != static_cast<int64_t>(kBoundaryArea)) {
      continue;
    }
    std::vector<TaskNode*> candidate_nodes;
    // BFS only follows out edges whose destination is a forward-type task.
    auto ForEachNextNode = [&](TaskNode* node,
                               const std::function<void(TaskNode*)>& TryPushNodeToQueue) {
      node->ForEachNodeOnOutEdge([&](TaskNode* node_on_out_edge) {
        if (IsForwardTaskType(node_on_out_edge->GetTaskType())) {
          TryPushNodeToQueue(node_on_out_edge);
        }
      });
    };
    auto HandlerAddCandidate = [&](TaskNode* node) {
      TODO();  // refactor the following code
      /*
      auto fw_task_node = dynamic_cast<NormalForwardCompTaskNode*>(node);
      if (fw_task_node != nullptr && fw_task_node->logical_node()->HasOpWithModelBlob()
          && fw_task_node->parallel_ctx()->parallel_num() > 1
          && fw_task_node->parallel_ctx()->policy() == kDataParallel) {
        candidate_nodes.push_back(node);
      }
      */
    };
    BfsForEachNode({task_node}, ForEachNextNode, HandlerAddCandidate);
    std::sort(candidate_nodes.begin(), candidate_nodes.end(),
              [](const TaskNode* a, const TaskNode* b) {
                return a->order_in_graph() < b->order_in_graph();
              });
    int64_t last_chain_id = -1;
    for (TaskNode* candidate_node : candidate_nodes) {
      // Only the first (lowest-order) candidate of each chain gets ctrl edges.
      if (candidate_node->chain_id() != last_chain_id) {
        last_chain_id = candidate_node->chain_id();
        candidate_node->ForEachNodeOnInEdge([&](TaskNode* node_on_in_edge) {
          if (IsMdUpdtTaskType(node_on_in_edge->GetTaskType())) {
            RegstDesc* ctrl_regst = task_node->BuildCtrlRegstDesc(node_on_in_edge);
            RegstDesc* copy_out_regst = copy_hd_task_node->GetProducedRegst("copy_out").get();
            int64_t piece_num_in_batch = GlobalJobDesc().NumOfPiecesInBatch();
            ctrl_regst->UpdtMinRegstNumIfNeed(copy_out_regst->min_register_num()
                                              + piece_num_in_batch - 1);
            CtrlRegstDesc* ctrl_regst_desc =
                ctrl_regst->mut_regst_desc_type()->mutable_ctrl_regst_desc();
            ctrl_regst_desc->set_reliant_regst_desc_id(copy_out_regst->regst_desc_id());
            ctrl_regst_desc->set_returned_regst_num(piece_num_in_batch);
          }
        });
      }
    }
  }
}
J
Jinhui Yuan 已提交
590 591 592 593 594 595

void TaskGraph::SetAreaIdForNewNodes(const LogicalNode* src_logical,
                                     const LogicalNode* dst_logical) {
  CHECK(src_logical != nullptr && dst_logical != nullptr);
  ForEachNode([&](TaskNode* node) {
    if (node->area_id() != static_cast<int64_t>(kInvalidArea)) return;
596 597
    if (src_logical->GetAreaId() == dst_logical->GetAreaId()) {
      node->set_area_id(src_logical->GetAreaId());
J
Jinhui Yuan 已提交
598 599 600 601 602 603
    } else {
      node->set_area_id(static_cast<int64_t>(kBoundaryArea));
    }
  });
}

604 605
// Expands to the out-of-class definition header `void TaskGraph::<method_name>(...)` of a
// sub-task-graph builder. The parameter list comes from BLD_SUB_TSK_GPH_MTHD_ARGS (defined
// elsewhere). Judging by the forwarding calls in this file, it supplies src_logical,
// dst_logical, sorted_src_comp_tasks, sorted_dst_comp_tasks, logical2sorted_in_box,
// logical2sorted_out_box, MutBufTask and AllocateCpuThrdIdEvenly.
#define DEFINE_BLD_SUB_TASK_GRAPH_METHOD(method_name) \
  void TaskGraph::method_name BLD_SUB_TSK_GPH_MTHD_ARGS()
606 607

// Entry point for boxing edges: forwards all arguments to the v2 builder when the job enables
// boxing v2, otherwise to the legacy v1 builder.
DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByBoxing) {
  if (!GlobalJobDesc().use_boxing_v2()) {
    BldSubTskGphByBoxingV1(src_logical, dst_logical, sorted_src_comp_tasks, sorted_dst_comp_tasks,
                           logical2sorted_in_box, logical2sorted_out_box, std::move(MutBufTask),
                           std::move(AllocateCpuThrdIdEvenly));
    return;
  }
  BldSubTskGphByBoxingV2(src_logical, dst_logical, sorted_src_comp_tasks, sorted_dst_comp_tasks,
                         logical2sorted_in_box, logical2sorted_out_box, std::move(MutBufTask),
                         std::move(AllocateCpuThrdIdEvenly));
}

// Legacy (v1) boxing.
// Producers are funnelled into per-machine out-boxing nodes and consumers are fed from
// per-machine in-boxing nodes. Every out-boxing node is then connected to every in-boxing
// node, with a comm-net copy inserted when the two sit on different machines.
// The boxing node lists are memoized per logical node in logical2sorted_out_box /
// logical2sorted_in_box. They are built only on first use and reused afterwards.
DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByBoxingV1) {
  if (logical2sorted_out_box->count(src_logical) == 0) {
    std::vector<TaskNode*>* fresh_out_box = &(*logical2sorted_out_box)[src_logical];
    BuildOutBoxing(src_logical, sorted_src_comp_tasks, fresh_out_box, MutBufTask,
                   AllocateCpuThrdIdEvenly);
  }
  const std::vector<TaskNode*>& out_boxes = logical2sorted_out_box->at(src_logical);

  if (logical2sorted_in_box->count(dst_logical) == 0) {
    std::vector<TaskNode*>* fresh_in_box = &(*logical2sorted_in_box)[dst_logical];
    BuildInBoxing(dst_logical, sorted_dst_comp_tasks, fresh_in_box, AllocateCpuThrdIdEvenly);
  }
  const std::vector<TaskNode*>& in_boxes = logical2sorted_in_box->at(dst_logical);

  for (TaskNode* out_box : out_boxes) {
    for (TaskNode* in_box : in_boxes) { ConnectWithCopyCommNetIfNeed(out_box, in_box); }
  }
}

J
Juncheng 已提交
639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660
// Boxing v2.
// For an edge that carries exactly one blob, tries the chained v2 builders in order
// (NCCL boxing, then slice boxing). It falls back to the legacy v1 construction when:
//   - the edge carries more than one blob, or
//   - the chain reports that the requested SBP transformation is not supported.
// Any other builder error is fatal.
DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByBoxingV2) {
  // Bind by const reference: avoids a copy when GetLbisTo returns a reference, and extends the
  // temporary's lifetime when it returns by value.
  const auto& lbis = src_logical->GetLbisTo(dst_logical);
  // Called at most once on any path, so moving the callbacks into the v1 builder is safe.
  const auto Fallback = [&]() {
    BldSubTskGphByBoxingV1(src_logical, dst_logical, sorted_src_comp_tasks, sorted_dst_comp_tasks,
                           logical2sorted_in_box, logical2sorted_out_box, std::move(MutBufTask),
                           std::move(AllocateCpuThrdIdEvenly));
  };
  if (lbis.size() > 1) {
    Fallback();
    return;
  }
  CHECK_EQ(lbis.size(), 1);
  const LogicalBlobId& lbi = lbis.front();
  const SbpParallel& src_sbp_parallel =
      Global<OpGraph>::Get()->GetSbpParallel(src_logical->SoleOp()->op_name(), lbi);
  const SbpParallel& dst_sbp_parallel =
      Global<OpGraph>::Get()->GetSbpParallel(dst_logical->SoleOp()->op_name(), lbi);
  const std::shared_ptr<const ParallelDesc>& src_parallel_desc = src_logical->parallel_desc();
  const std::shared_ptr<const ParallelDesc>& dst_parallel_desc = dst_logical->parallel_desc();
  const BlobDesc& blob_desc = Global<OpGraph>::Get()->GetLogicalBlobDesc(lbi);
  SubTskGphBuilderCtx ctx(this);
  std::vector<std::shared_ptr<SubTskGphBuilder>> builders;
  builders.reserve(2);
  // make_shared instead of emplace_back(new ...): if the vector's allocation throws, no
  // raw owning pointer is left behind to leak.
  builders.push_back(std::make_shared<NcclBoxingSubTskGphBuilder>());
  builders.push_back(std::make_shared<SliceBoxingSubTskGphBuilder>());
  Maybe<void> status = TRY(ChainSubTskGphBuilder(builders).Build(
      &ctx, sorted_src_comp_tasks, sorted_dst_comp_tasks, *src_parallel_desc, *dst_parallel_desc,
      lbi, blob_desc, src_sbp_parallel, dst_sbp_parallel));
  if (status.IsOk()) { return; }
  if (SubTskGphBuilderUtil::IsErrorBoxingNotSupported(*status.error())) {
    Fallback();
  } else {
    UNIMPLEMENTED();
  }
}

675
// Pairs the i-th producer with the i-th consumer. Both lists must have the same length.
DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByOneToOne) {
  CHECK_EQ(sorted_src_comp_tasks.size(), sorted_dst_comp_tasks.size());
  const size_t pair_num = sorted_src_comp_tasks.size();
  for (size_t idx = 0; idx < pair_num; ++idx) {
    BuildTaskPath(sorted_src_comp_tasks.at(idx), sorted_dst_comp_tasks.at(idx), MutBufTask, true);
  }
}

684
// Connects producers and consumers of a broadcast blob.
// When both sides have the same parallel number, tasks are paired by index. Otherwise each
// consumer is fed by exactly one producer, chosen in this order of preference:
//   (1) the producer on the same (machine, thread);
//   (2) the last producer, in sorted order, on the same machine;
//   (3) the producer on the smallest machine id.
DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByBroadcastToBroadcast) {
  // An empty producer list would turn the modulo below into a division by zero.
  CHECK(!sorted_src_comp_tasks.empty());
  CHECK_EQ(sorted_dst_comp_tasks.size() % sorted_src_comp_tasks.size(), 0);
  if (sorted_src_comp_tasks.size() == sorted_dst_comp_tasks.size()) {
    FOR_RANGE(size_t, i, 0, sorted_src_comp_tasks.size()) {
      CompTaskNode* src = sorted_src_comp_tasks.at(i);
      CompTaskNode* dst = sorted_dst_comp_tasks.at(i);
      BuildTaskPath(src, dst, MutBufTask, true);
    }
    return;
  }
  auto GlobalThrdId4TaskNode = [](TaskNode* task_node) -> std::pair<int64_t, int64_t> {
    return std::make_pair(task_node->machine_id(), task_node->thrd_id());
  };
  // Ordered by machine id, so the last-resort fallback (begin()) is deterministic rather than
  // depending on hash-table iteration order.
  std::map<int64_t, CompTaskNode*> machine_id2last_src_task;
  HashMap<std::pair<int64_t, int64_t>, CompTaskNode*> global_thrd_id2src_task;
  for (CompTaskNode* src_node : sorted_src_comp_tasks) {
    machine_id2last_src_task[src_node->machine_id()] = src_node;
    global_thrd_id2src_task[GlobalThrdId4TaskNode(src_node)] = src_node;
  }
  auto GetSrcNode = [&](const std::pair<int64_t, int64_t>& global_thrd_id) -> CompTaskNode* {
    const auto src_task_it = global_thrd_id2src_task.find(global_thrd_id);
    if (src_task_it != global_thrd_id2src_task.end()) { return src_task_it->second; }
    const auto m_src_task_it = machine_id2last_src_task.find(global_thrd_id.first);
    if (m_src_task_it != machine_id2last_src_task.end()) { return m_src_task_it->second; }
    return machine_id2last_src_task.begin()->second;
  };
  // Walk consumers in their sorted order. This keeps creation order deterministic and never
  // drops a consumer, even if two consumers share a (machine, thread) pair.
  for (CompTaskNode* dst_node : sorted_dst_comp_tasks) {
    BuildTaskPath(GetSrcNode(GlobalThrdId4TaskNode(dst_node)), dst_node, MutBufTask, true);
  }
}

719
// Feeds the single consumer task from exactly one producer task.
// Selection walks the producers in sorted order:
//   - the first producer is taken provisionally;
//   - the first producer on the consumer's machine replaces a provisional pick from another
//     machine;
//   - once the pick is on the consumer's machine, a producer that also matches the consumer's
//     thread is taken and the search stops.
// The chosen producer is then connected one-to-one with the consumer.
DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphBySelectOneSourceToSoleSink) {
  CHECK_EQ(sorted_dst_comp_tasks.size(), 1);
  CompTaskNode* sink = sorted_dst_comp_tasks.front();
  CompTaskNode* selected_src_comp_task = nullptr;
  bool selected_on_sink_machine = false;
  for (CompTaskNode* candidate : sorted_src_comp_tasks) {
    const bool on_sink_machine = candidate->machine_id() == sink->machine_id();
    const bool take_candidate =
        selected_src_comp_task == nullptr || (on_sink_machine && !selected_on_sink_machine);
    if (take_candidate) {
      selected_src_comp_task = candidate;
      selected_on_sink_machine = on_sink_machine;
      continue;
    }
    if (on_sink_machine && candidate->thrd_id() == sink->thrd_id()) {
      selected_src_comp_task = candidate;
      selected_on_sink_machine = on_sink_machine;
      break;
    }
  }
  CHECK_NOTNULL(selected_src_comp_task);
  BldSubTskGphByOneToOne(nullptr, nullptr, {selected_src_comp_task}, sorted_dst_comp_tasks, nullptr,
                         nullptr, MutBufTask, AllocateCpuThrdIdEvenly);
}

J
Jinhui Yuan 已提交
749 750 751 752 753 754 755 756 757 758 759
// Wires reduce-scatter producers to reduce-add consumers (buffer tasks are not reused).
// When the parallel desc spans more than one machine with more than one device each:
//   - Local stage: the scatter's sole predecessor is neither a ReduceAdd nor an
//     NcclReduceScatter logical node. Only pairs on the same machine are connected.
//   - Otherwise: only pairs whose parallel ids agree modulo the per-machine device count are
//     connected.
// Without such a layout, every producer/consumer pair is connected.
DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByReduceScatter2ReduceAdd) {
  const LogicalNode* src_logical_node = sorted_src_comp_tasks.front()->logical_node();
  const auto& pd = src_logical_node->parallel_desc();
  const bool has_local_reduce =
      pd->sorted_machine_ids().size() > 1 && pd->device_num_of_each_machine() > 1;
  const LogicalNode* pred_src_logical_node = src_logical_node->SoleInEdge()->src_node();
  bool is_local_reduce = false;
  if (has_local_reduce) {
    is_local_reduce =
        dynamic_cast<const ReduceAddLogicalNode*>(pred_src_logical_node) == nullptr
        && dynamic_cast<const NcclReduceScatterLogicalNode*>(pred_src_logical_node) == nullptr;
  }
  auto ShouldConnect = [&](CompTaskNode* producer, CompTaskNode* consumer) -> bool {
    if (!has_local_reduce) { return true; }
    if (is_local_reduce) { return producer->machine_id() == consumer->machine_id(); }
    return producer->parallel_id() % pd->device_num_of_each_machine()
           == consumer->parallel_id() % pd->device_num_of_each_machine();
  };
  for (CompTaskNode* src_comp_task : sorted_src_comp_tasks) {
    for (CompTaskNode* dst_comp_task : sorted_dst_comp_tasks) {
      if (!ShouldConnect(src_comp_task, dst_comp_task)) { continue; }
      BuildTaskPath(src_comp_task, dst_comp_task, MutBufTask, false);
    }
  }
}

J
Jinhui Yuan 已提交
780
// Wires reduce-add producers to reduce-gather consumers (buffer tasks may be reused).
// When the parallel desc spans more than one machine with more than one device each, only
// pairs whose parallel ids agree modulo the per-machine device count are connected.
// Otherwise every pair is connected.
DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByReduceAdd2ReduceGather) {
  const auto& pd = sorted_src_comp_tasks.front()->logical_node()->parallel_desc();
  const bool has_local_reduce =
      pd->sorted_machine_ids().size() > 1 && pd->device_num_of_each_machine() > 1;
  for (CompTaskNode* src_comp_task : sorted_src_comp_tasks) {
    for (CompTaskNode* dst_comp_task : sorted_dst_comp_tasks) {
      const bool connect =
          !has_local_reduce
          || src_comp_task->parallel_id() % pd->device_num_of_each_machine()
                 == dst_comp_task->parallel_id() % pd->device_num_of_each_machine();
      if (connect) { BuildTaskPath(src_comp_task, dst_comp_task, MutBufTask, true); }
    }
  }
}

// Connects reduce-gather tasks to the next reduce-gather stage on the same machine only.
// Requires more than one machine and more than one device per machine.
DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByReduceGather2ReduceGather) {
  const auto& pd = sorted_src_comp_tasks.front()->logical_node()->parallel_desc();
  CHECK_GT(pd->device_num_of_each_machine(), 1);
  CHECK_GT(pd->sorted_machine_ids().size(), 1);
  for (CompTaskNode* src_comp_task : sorted_src_comp_tasks) {
    const auto src_machine_id = src_comp_task->machine_id();
    for (CompTaskNode* dst_comp_task : sorted_dst_comp_tasks) {
      if (dst_comp_task->machine_id() != src_machine_id) { continue; }
      BuildTaskPath(src_comp_task, dst_comp_task, MutBufTask, true);
    }
  }
}
810

811 812 813 814 815 816 817 818 819 820
// Adds a direct edge between every producer/consumer pair that shares both machine and
// physical GPU id. No copy tasks are inserted.
DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByConnectNodeOnSameGpuDevice) {
  auto OnSameGpu = [](CompTaskNode* lhs, CompTaskNode* rhs) -> bool {
    return lhs->machine_id() == rhs->machine_id() && lhs->GpuPhyId() == rhs->GpuPhyId();
  };
  for (CompTaskNode* src : sorted_src_comp_tasks) {
    for (CompTaskNode* dst : sorted_dst_comp_tasks) {
      if (!OnSameGpu(src, dst)) { continue; }
      Connect<TaskNode>(src, NewEdge(), dst);
    }
  }
}

L
Li Xinqi 已提交
821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850
// One producer task, one consumer task per input of the destination op.
// The i-th consumer task corresponds to the i-th input bn of the destination op. It is
// connected to the sole producer only if that input's blob is one of the source op's outputs.
DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByPartialInLbiConnect) {
  const auto& src_op = src_logical->SoleOp();
  HashSet<LogicalBlobId> produced_lbis;
  for (const auto& obn : src_op->output_bns()) { produced_lbis.insert(src_op->BnInOp2Lbi(obn)); }
  CHECK_EQ(sorted_src_comp_tasks.size(), 1);
  CHECK_EQ(dst_logical->SoleOp()->input_bns().size(), sorted_dst_comp_tasks.size());
  const auto& dst_op = dst_logical->SoleOp();
  const int dst_task_num = static_cast<int>(sorted_dst_comp_tasks.size());
  for (int idx = 0; idx < dst_task_num; ++idx) {
    const auto& lbi = dst_op->BnInOp2Lbi(dst_op->input_bns().Get(idx));
    if (produced_lbis.count(lbi) == 0) { continue; }
    BuildTaskPath(sorted_src_comp_tasks.at(0), sorted_dst_comp_tasks.at(idx), MutBufTask, true);
  }
}

// Mirror of BldSubTskGphByPartialInLbiConnect: one producer task per output of the source op,
// one consumer task.
// The i-th producer task corresponds to the i-th output bn of the source op. It is connected
// to the sole consumer only if that output's blob is one of the destination op's inputs.
DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByPartialOutLbiConnect) {
  const auto& dst_op = dst_logical->SoleOp();
  HashSet<LogicalBlobId> consumed_lbis;
  for (const auto& ibn : dst_op->input_bns()) { consumed_lbis.insert(dst_op->BnInOp2Lbi(ibn)); }
  CHECK_EQ(sorted_dst_comp_tasks.size(), 1);
  CHECK_EQ(src_logical->SoleOp()->output_bns().size(), sorted_src_comp_tasks.size());
  const auto& src_op = src_logical->SoleOp();
  const int src_task_num = static_cast<int>(sorted_src_comp_tasks.size());
  for (int idx = 0; idx < src_task_num; ++idx) {
    const auto& lbi = src_op->BnInOp2Lbi(src_op->output_bns().Get(idx));
    if (consumed_lbis.count(lbi) == 0) { continue; }
    BuildTaskPath(sorted_src_comp_tasks.at(idx), sorted_dst_comp_tasks.at(0), MutBufTask, true);
  }
}

J
Jinhui Yuan 已提交
851
// Builds a chain of copy tasks from `src` to `dst`.
// BuildTaskStep is called repeatedly until the current node sits on dst's machine and memory
// zone. A final edge into `dst` is added unless the last step already produced `dst`.
// MutBufTask exposes per-src cache slots keyed by (machine_id, mem_zone_id). When
// use_buf_task_node is true, intermediate copy nodes created for one destination can be
// reused for another. A slot that is already filled must hold the same node (CHECKed).
void TaskGraph::BuildTaskPath(
    CompTaskNode* src, CompTaskNode* dst,
    std::function<TaskNode**(CompTaskNode* src, int64_t machine_id, int32_t mem_zone_id)>
        MutBufTask,
    bool use_buf_task_node) {
  CHECK_NE(src, dst);
  auto GetBufTask = [&](int64_t machine_id, int32_t mem_zone_id) -> TaskNode* {
    return *MutBufTask(src, machine_id, mem_zone_id);
  };
  auto SetBufTask = [&](int64_t machine_id, int32_t mem_zone_id, TaskNode* new_val) -> TaskNode* {
    TaskNode** cur_val = MutBufTask(src, machine_id, mem_zone_id);
    if (*cur_val != nullptr) {
      CHECK_EQ(*cur_val, new_val);
    } else {
      *cur_val = new_val;
    }
    return new_val;
  };
  auto ReachedDstLocation = [&](TaskNode* node) -> bool {
    return node->machine_id() == dst->machine_id() && node->MemZoneId121() == dst->MemZoneId121();
  };

  TaskNode* cur_node = src;
  while (!ReachedDstLocation(cur_node)) {
    cur_node = BuildTaskStep(cur_node, dst, GetBufTask, SetBufTask, use_buf_task_node);
  }
  if (cur_node == dst) { return; }
  Connect<TaskNode>(cur_node, NewEdge(), dst);
}
877

J
Jinhui Yuan 已提交
878
// Advances one hop from `cur_node` toward `dst` and returns the node reached.
// The hop kinds are:
//   - cur_node not in host memory: device-to-host copy on cur_node's machine.
//   - in host memory on dst's machine: host-to-device copy into dst's zone, or `dst` itself
//     when TryAddCopyH2DTaskTo declines to add one.
//   - in host memory on another machine: comm-net copy into dst's machine.
// When use_buf_task_node is true, a node cached for the hop's (machine, mem zone) is reused
// instead of creating a new one. Any resulting node other than `dst` is recorded back through
// SetBufTask.
TaskNode* TaskGraph::BuildTaskStep(
    TaskNode* cur_node, TaskNode* dst,
    std::function<TaskNode*(int64_t machine_id, int32_t mem_zone_id)> GetBufTask,
    std::function<TaskNode*(int64_t machine_id, int32_t mem_zone_id, TaskNode*)> SetBufTask,
    bool use_buf_task_node) {
  enum class HopKind { kDeviceToHost, kIntoDstMachine, kAcrossNetwork };
  const int32_t cpu_mem_zone_id = Global<IDMgr>::Get()->CpuMemZoneId();

  // Classify the hop and find the (machine, mem zone) slot under which its result is cached.
  HopKind hop_kind;
  int64_t buf_machine_id;
  int32_t next_mem_zone_id;
  if (cur_node->MemZoneId121() != cpu_mem_zone_id) {
    hop_kind = HopKind::kDeviceToHost;
    buf_machine_id = cur_node->machine_id();
    next_mem_zone_id = cpu_mem_zone_id;
  } else if (cur_node->machine_id() == dst->machine_id()) {
    hop_kind = HopKind::kIntoDstMachine;
    buf_machine_id = cur_node->machine_id();
    next_mem_zone_id = dst->MemZoneId121();
  } else {
    hop_kind = HopKind::kAcrossNetwork;
    buf_machine_id = dst->machine_id();
    next_mem_zone_id = cpu_mem_zone_id;
  }

  TaskNode* next_node =
      use_buf_task_node ? GetBufTask(buf_machine_id, next_mem_zone_id) : nullptr;
  if (next_node == nullptr) {
    switch (hop_kind) {
      case HopKind::kDeviceToHost: next_node = AddCopyD2HTaskFrom(cur_node); break;
      case HopKind::kIntoDstMachine:
        next_node = TryAddCopyH2DTaskTo(dst);
        if (next_node == nullptr) { next_node = dst; }
        break;
      case HopKind::kAcrossNetwork: next_node = AddCopyCommNetTaskBetween(cur_node, dst); break;
    }
    Connect<TaskNode>(cur_node, NewEdge(), next_node);
  }

  if (use_buf_task_node && next_node != dst) {
    SetBufTask(next_node->machine_id(), next_mem_zone_id, next_node);
  }
  return next_node;
}

L
Li Xinqi 已提交
914
// Creates (but does not connect) a host-to-device copy node targeting `task`'s GPU.
// Returns nullptr, adding nothing, when `task` is an interface task or a registered tick/tock
// task type. Otherwise `task` must be a GPU task.
TaskNode* TaskGraph::TryAddCopyH2DTaskTo(TaskNode* task) {
  const bool skip_copy =
      IsInterfaceTask(task) || IsClassRegistered<TickTockTaskType>(task->GetTaskType());
  if (skip_copy) { return nullptr; }
  CHECK_EQ(task->device_type(), DeviceType::kGPU);
  auto* h2d_copy = NewNode<CopyHdTaskNode>();
  h2d_copy->Init(CopyHdOpConf::H2D, task->machine_id(), task->GpuPhyId());
  return h2d_copy;
}

J
Jinhui Yuan 已提交
923 924
// Creates (but does not connect) a device-to-host copy node reading from GPU task `task`, on
// the same machine and physical GPU.
TaskNode* TaskGraph::AddCopyD2HTaskFrom(TaskNode* task) {
  CHECK_EQ(task->device_type(), DeviceType::kGPU);
  auto* d2h_copy = NewNode<CopyHdTaskNode>();
  d2h_copy->Init(CopyHdOpConf::D2H, task->machine_id(), task->GpuPhyId());
  return d2h_copy;
}
929

J
Jinhui Yuan 已提交
930
// Creates (but does not connect) a network copy node between two tasks on different machines.
// Init receives dst's machine id first and src's second. Presumably the node is placed on the
// receiving machine; confirm against CopyCommNetTaskNode::Init.
TaskNode* TaskGraph::AddCopyCommNetTaskBetween(TaskNode* src, TaskNode* dst) {
  CHECK_NE(src->machine_id(), dst->machine_id());
  auto* comm_net_copy = NewNode<CopyCommNetTaskNode>();
  comm_net_copy->Init(dst->machine_id(), src->machine_id());
  return comm_net_copy;
}

937 938 939 940 941 942
// Builds the out-boxing layer for a producer logical node.
// - Each GPU producer is first routed through a device-to-host copy. That copy is cached per
//   producer in the MutBufTask slot for (its machine, CPU mem zone) and reused if present.
// - The resulting host-side tasks are grouped by machine id.
// - Each group feeds one new OutBoxingTaskNode on that machine, whose CPU thread comes from
//   AllocateCpuThrdIdEvenly.
// The boxing nodes are appended to *sorted_out_box in ascending machine-id order.
// `logical` is not used here.
void TaskGraph::BuildOutBoxing(
    const LogicalNode* logical, const std::vector<CompTaskNode*>& sorted_comp_tasks,
    std::vector<TaskNode*>* sorted_out_box,
    std::function<TaskNode**(CompTaskNode* src, int64_t machine_id, int32_t mem_zone_id)>
        MutBufTask,
    std::function<int64_t(const TaskNode*)> AllocateCpuThrdIdEvenly) {
  std::map<int64_t, std::vector<TaskNode*>> machine2exits;
  for (CompTaskNode* comp_task : sorted_comp_tasks) {
    TaskNode* exit_task = comp_task;
    if (comp_task->device_type() == DeviceType::kGPU) {
      TaskNode** cached_d2h =
          MutBufTask(comp_task, comp_task->machine_id(), Global<IDMgr>::Get()->CpuMemZoneId());
      if (*cached_d2h != nullptr) {
        exit_task = *cached_d2h;
      } else {
        exit_task = AddCopyD2HTaskFrom(comp_task);
        Connect<TaskNode>(comp_task, NewEdge(), exit_task);
        *cached_d2h = exit_task;
      }
    }
    machine2exits[exit_task->machine_id()].push_back(exit_task);
  }
  for (const auto& machine_and_exits : machine2exits) {
    const std::vector<TaskNode*>& exits = machine_and_exits.second;
    OutBoxingTaskNode* out_box = NewNode<OutBoxingTaskNode>();
    out_box->set_machine_id(exits.front()->machine_id());
    out_box->set_thrd_id(AllocateCpuThrdIdEvenly(out_box));
    for (TaskNode* exit_task : exits) { Connect<TaskNode>(exit_task, NewEdge(), out_box); }
    sorted_out_box->push_back(out_box);
  }
}

W
willzhang4a58 已提交
968 969 970
// Builds the in-boxing layer for a consumer logical node.
// - Each GPU consumer gets a freshly created host-to-device copy in front of it, unless
//   TryAddCopyH2DTaskTo declines to add one.
// - The resulting entry tasks are grouped by machine id.
// - Each group is fed by one new InBoxingTaskNode on that machine, whose CPU thread comes from
//   AllocateCpuThrdIdEvenly.
// The boxing nodes are appended to *sorted_in_box in ascending machine-id order.
// `logical` is not used here.
void TaskGraph::BuildInBoxing(const LogicalNode* logical,
                              const std::vector<CompTaskNode*>& sorted_comp_tasks,
                              std::vector<TaskNode*>* sorted_in_box,
                              std::function<int64_t(const TaskNode*)> AllocateCpuThrdIdEvenly) {
  std::map<int64_t, std::vector<TaskNode*>> machine2entries;
  for (CompTaskNode* comp_task : sorted_comp_tasks) {
    TaskNode* entry_task = comp_task;
    if (comp_task->device_type() == DeviceType::kGPU) {
      TaskNode* h2d_copy = TryAddCopyH2DTaskTo(comp_task);
      if (h2d_copy != nullptr) {
        Connect<TaskNode>(h2d_copy, NewEdge(), comp_task);
        entry_task = h2d_copy;
      }
    }
    machine2entries[entry_task->machine_id()].push_back(entry_task);
  }
  for (const auto& machine_and_entries : machine2entries) {
    const std::vector<TaskNode*>& entries = machine_and_entries.second;
    InBoxingTaskNode* in_box = NewNode<InBoxingTaskNode>();
    in_box->set_machine_id(entries.front()->machine_id());
    in_box->set_thrd_id(AllocateCpuThrdIdEvenly(in_box));
    for (TaskNode* entry_task : entries) { Connect<TaskNode>(in_box, NewEdge(), entry_task); }
    sorted_in_box->push_back(in_box);
  }
}

W
willzhang4a58 已提交
991
// Connects src to dst.
// On the same machine this is a direct edge. Across machines it is src -> comm-net copy -> dst.
void TaskGraph::ConnectWithCopyCommNetIfNeed(TaskNode* src, TaskNode* dst) {
  if (src->machine_id() == dst->machine_id()) {
    Connect(src, NewEdge(), dst);
    return;
  }
  TaskNode* network_hop = AddCopyCommNetTaskBetween(src, dst);
  Connect<TaskNode>(src, NewEdge(), network_hop);
  Connect<TaskNode>(network_hop, NewEdge(), dst);
}

L
Li Xinqi 已提交
1001
bool IsBackEdge(TaskNode* src, TaskNode* dst) { return false; }
J
Jinhui Yuan 已提交
1002

W
willzhang4a58 已提交
1003
}  // namespace oneflow