task_graph.cpp 43.4 KB
Newer Older
J
jiyuan 已提交
1
#include "oneflow/core/graph/task_graph.h"
C
chengtbf 已提交
2
#include "oneflow/core/graph/normal_forward_compute_task_node.h"
J
Jinhui Yuan 已提交
3
#include "oneflow/core/graph/chain_graph.h"
W
Will Zhang 已提交
4
#include "oneflow/core/graph/boxing_task_node.h"
J
Jinhui Yuan 已提交
5
#include "oneflow/core/common/util.h"
J
Jinhui Yuan 已提交
6
#include "oneflow/core/graph/reduce_add_compute_task_node.h"
L
Li Xinqi 已提交
7
#include "oneflow/core/graph/inplace_lbi_graph.h"
8
#include "oneflow/core/register/runtime_blob_desc.h"
J
Jinhui Yuan 已提交
9
#include "oneflow/core/job/thrd_id_generator.h"
L
Li Xinqi 已提交
10 11 12
#include "oneflow/core/graph/reduce_identity_task_node.h"
#include "oneflow/core/operator/variable_op.h"
#include "oneflow/core/operator/constant_op.h"
J
Juncheng 已提交
13 14 15 16 17 18
#include "oneflow/core/graph/op_graph.h"
#include "oneflow/core/graph/boxing/sub_task_graph_builder_context.h"
#include "oneflow/core/graph/boxing/sub_task_graph_builder.h"
#include "oneflow/core/graph/boxing/chain_sub_task_graph_builder.h"
#include "oneflow/core/graph/boxing/nccl_boxing_sub_task_graph_builder.h"
#include "oneflow/core/graph/boxing/sub_task_graph_builder_util.h"
W
willzhang4a58 已提交
19 20 21

namespace oneflow {

L
Li Xinqi 已提交
22 23
namespace {

L
lixinqi 已提交
24
bool IsInterfaceTask(const TaskNode* node) {
L
Li Xinqi 已提交
25 26 27 28
  const auto* comp_task_node = dynamic_cast<const CompTaskNode*>(node);
  if (comp_task_node == nullptr) { return false; }
  if (comp_task_node->logical_node()->op_vec().size() != 1) { return false; }
  auto op_type_case = comp_task_node->logical_node()->SoleOp()->op_conf().op_type_case();
L
lixinqi 已提交
29
  return IsClassRegistered<IsInterfaceOpConf4OpTypeCase>(op_type_case);
L
Li Xinqi 已提交
30 31
}

L
Li Xinqi 已提交
32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51
bool IsConnectToTickOp(const TaskNode* node) {
  const auto* comp_task_node = dynamic_cast<const CompTaskNode*>(node);
  if (comp_task_node == nullptr) { return false; }
  if (comp_task_node->logical_node()->op_vec().size() != 1) { return false; }
  const Operator* op = comp_task_node->logical_node()->SoleOp().get();
  if (dynamic_cast<const VariableOp*>(op) != nullptr) { return true; }
  if (dynamic_cast<const ConstantOp*>(op) != nullptr) { return true; }
  return false;
}

// Calls `Handler` on every node of `fw_nodes` that has no in-edge whose source is itself a
// member of `fw_nodes`. The "has backward node" filter is currently stubbed out and always
// answers false, so it never excludes a node.
void ForEachDeviceSrcUntrainableNode(const std::vector<NormalForwardCompTaskNode*>& fw_nodes,
                                     const std::function<void(CompTaskNode*)>& Handler) {
  const HashSet<const TaskNode*> fw_node_set(fw_nodes.begin(), fw_nodes.end());
  const auto HasInEdgeFromFwNode = [&](NormalForwardCompTaskNode* node) {
    for (TaskEdge* in_edge : node->in_edges()) {
      if (fw_node_set.count(in_edge->src_node()) > 0) { return true; }
    }
    return false;
  };
  const auto HasBwNode = [](NormalForwardCompTaskNode*) {
    // TODO: update method for fw bw split
    // const auto* fw_logical_node = dynamic_cast<const ForwardLogicalNode*>(node->logical_node());
    // return fw_logical_node->bw_node() != nullptr;
    return false;
  };
  for (NormalForwardCompTaskNode* candidate : fw_nodes) {
    if (HasInEdgeFromFwNode(candidate)) { continue; }
    if (HasBwNode(candidate)) { continue; }
    Handler(candidate);
  }
}

L
Li Xinqi 已提交
62 63 64 65 66 67 68 69 70 71 72 73 74 75 76
// Indexes the task nodes of `task_nodes` whose exec graph has exactly one node by the name
// of that node's op, and returns a getter over the index.
// The getter yields nullptr when the op name is unknown or when more than one task node
// was recorded under it; otherwise it yields the single recorded task node.
std::function<TaskNode*(const std::string&)> MakeGetterTaskNode4SoleOpName(
    const HashSet<TaskNode*>& task_nodes) {
  using OpName2TaskNodes = HashMap<std::string, HashSet<TaskNode*>>;
  // Shared so the returned closure keeps the index alive after this function returns.
  auto name2nodes = std::make_shared<OpName2TaskNodes>();
  for (TaskNode* node : task_nodes) {
    if (node->exec_gph().node_num() != 1) { continue; }
    ExecNode* sole_exec_node = node->exec_gph().SoleNode();
    const bool inserted = (*name2nodes)[sole_exec_node->op()->op_name()].emplace(node).second;
    CHECK(inserted);
  }
  return [name2nodes](const std::string& op_name) -> TaskNode* {
    const auto& it = name2nodes->find(op_name);
    if (it == name2nodes->end() || it->second.size() > 1) { return nullptr; }
    return *it->second.begin();
  };
}
L
Li Xinqi 已提交
78 79 80 81 82 83 84 85 86

// Returns true iff at least one regst returned by `edge->GetRegsts()` reports having `lbi`.
bool IsLbiOnTaskEdge(const TaskEdge* edge, const LogicalBlobId& lbi) {
  bool found = false;
  for (const auto& regst : edge->GetRegsts()) {
    if (regst->HasLbi(lbi)) {
      found = true;
      break;
    }
  }
  return found;
}

std::function<bool(const LogicalBlobId&, const std::string&)>
S
scxfjiang 已提交
87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106
MakePredicatorIsLbiAllConsumersReachable(
    std::function<const TaskNode*(const std::string&)> TaskNode4SoleOpName,
    std::function<bool(const std::string&, const std::string&)> IsOpNameDataOrCtrlReachable) {
  auto IsDataOrCtrlReachable = [IsOpNameDataOrCtrlReachable](const TaskNode* src_node,
                                                             const TaskNode* dst_node) -> bool {
    if (src_node->chain_id() == dst_node->chain_id()
        && src_node->order_in_graph() <= dst_node->order_in_graph()) {
      return true;
    }
    const CompTaskNode* comp_src_node = dynamic_cast<const CompTaskNode*>(src_node);
    if (comp_src_node == nullptr) { return false; }
    if (comp_src_node->logical_node()->op_vec().size() != 1) { return false; }
    const CompTaskNode* comp_dst_node = dynamic_cast<const CompTaskNode*>(dst_node);
    if (comp_dst_node == nullptr) { return false; }
    if (comp_dst_node->logical_node()->op_vec().size() != 1) { return false; }
    return IsOpNameDataOrCtrlReachable(comp_src_node->logical_node()->SoleOp()->op_name(),
                                       comp_dst_node->logical_node()->SoleOp()->op_name());
  };
  return [TaskNode4SoleOpName, IsDataOrCtrlReachable](const LogicalBlobId& lbi,
                                                      const std::string& op_name) -> bool {
L
Li Xinqi 已提交
107 108 109
    const TaskNode* src_task_node = TaskNode4SoleOpName(lbi.op_name());
    const TaskNode* dst_task_node = TaskNode4SoleOpName(op_name);
    size_t out_edges_size = 0;
L
Li Xinqi 已提交
110
    size_t reachable_out_edges_size = 0;
L
Li Xinqi 已提交
111 112 113
    for (TaskEdge* out_edge : src_task_node->out_edges()) {
      if (IsLbiOnTaskEdge(out_edge, lbi)) {
        out_edges_size += 1;
S
scxfjiang 已提交
114
        reachable_out_edges_size += IsDataOrCtrlReachable(out_edge->dst_node(), dst_task_node);
L
Li Xinqi 已提交
115 116
      }
    }
L
Li Xinqi 已提交
117
    return out_edges_size > 0 && out_edges_size == reachable_out_edges_size;
L
Li Xinqi 已提交
118 119 120 121 122 123 124 125 126 127 128 129 130 131
  };
}

// Decides whether the blobs named by `bns` on `task_node` may be used in place.
// Returns false when:
//   - the exec graph of `task_node` does not consist of exactly one node;
//   - for some bn, `TaskNode4SoleOpName` has no task node for the op named in the bn's lbi;
//   - for some bn, its regst desc does not hold exactly one lbi;
//   - the sole blob descs of the bns do not all share the same shape and data type.
// Otherwise returns true.
bool IsInplaceAllowed(
    TaskNode* task_node, const std::vector<std::string>& bns,
    const std::function<const TaskNode*(const std::string&)>& TaskNode4SoleOpName) {
  if (task_node->exec_gph().node_num() != 1) { return false; }
  const auto& sole_exec_node = *task_node->exec_gph().SoleNode();
  for (const std::string& bn : bns) {
    // TaskNode for bn is not nullptr if it's on the same device with `task_node`
    const TaskNode* lbi_op_task_node =
        TaskNode4SoleOpName(sole_exec_node.op()->BnInOp2Lbi(bn).op_name());
    if (lbi_op_task_node == nullptr) { return false; }
    const RegstDesc& regst = *sole_exec_node.RegstDesc4BnInOp(bn);
    if (regst.NumOfLbi() != 1) { return false; }
  }
  // Every blob is compared against the first one.
  const BlobDesc* reference_blob = nullptr;
  for (const std::string& bn : bns) {
    const BlobDesc* current_blob = sole_exec_node.RegstDesc4BnInOp(bn)->SoleBlobDesc();
    if (reference_blob == nullptr) {
      reference_blob = current_blob;
      continue;
    }
    if (!(reference_blob->shape() == current_blob->shape())) { return false; }
    if (!(reference_blob->data_type() == current_blob->data_type())) { return false; }
  }
  return true;
}

L
Li Xinqi 已提交
147 148
}  // namespace

149 150
// Builds the task graph out of `logical_gph` (ownership is taken over):
//   1. every logical node generates its sorted comp task nodes;
//   2. independent thread ids are generated for the collected persistence tasks;
//   3. every logical edge is expanded by the sub-task-graph builder chosen for its endpoints,
//      then nodes that still have an invalid area id get one assigned;
//   4. control edges are connected for every necessary logical ctrl edge;
//   5. chains are merged and every node gets its order in the graph;
//   6. in debug mode the graph is dumped to a dot file.
TaskGraph::TaskGraph(std::unique_ptr<const LogicalGraph>&& logical_gph) {
  logical_gph_ = std::move(logical_gph);

  using TaskNodeVec = std::vector<TaskNode*>;
  HashMap<const LogicalNode*, std::vector<CompTaskNode*>> logical2sorted_comp_tasks;
  HashMap<const LogicalNode*, TaskNodeVec> logical2sorted_in_box;
  HashMap<const LogicalNode*, TaskNodeVec> logical2sorted_out_box;

  // comp task -> machine id -> one slot per mem zone; the slot vector is filled with
  // nullptr the first time a (task, machine) pair is requested.
  HashMap<CompTaskNode*, HashMap<int64_t, TaskNodeVec>> buf_task;
  auto MutBufTask = [&](CompTaskNode* task_node, int64_t machine_id, int32_t mem_zone_id) {
    TaskNodeVec& slots = buf_task[task_node][machine_id];
    if (slots.empty()) { slots.assign(Global<ResourceDesc>::Get()->MemZoneNum(), nullptr); }
    return &(slots.at(mem_zone_id));
  };

  // Per-machine cursor, advanced modulo the cpu device num on every allocation.
  std::vector<int64_t> cpu_device_offset(Global<ResourceDesc>::Get()->TotalMachineNum(), 0);
  auto AllocateCpuThrdIdEvenly = [&](const TaskNode* task_node) {
    CHECK(!task_node->IsIndependent());
    int64_t& next_offset = cpu_device_offset.at(task_node->machine_id());
    const int64_t thrd_id = Global<IDMgr>::Get()->GetCpuDeviceThrdId(next_offset);
    next_offset = (next_offset + 1) % Global<ResourceDesc>::Get()->CpuDeviceNum();
    return thrd_id;
  };

  std::vector<std::pair<int64_t, CompTaskNode*>> machine_persistence_task_vec;
  logical_gph_->ForEachNode([&](const LogicalNode* logical_node) {
    auto RegisterCompTask = [&](CompTaskNode* comp_task_node) {
      AddAllocatedNode(comp_task_node);
      logical2sorted_comp_tasks[logical_node].push_back(comp_task_node);
      comp_task_node->set_area_id(logical_node->GetAreaId());
    };
    logical_node->GenSortedCompTaskNodes(AllocateCpuThrdIdEvenly, &machine_persistence_task_vec,
                                         RegisterCompTask);
  });

  GenerateIndependentThrdId(machine_persistence_task_vec);

  logical_gph_->ForEachEdge([&](const LogicalEdge* logical_edge) {
    auto* src_logical = logical_edge->src_node();
    auto* dst_logical = logical_edge->dst_node();
    BldSubTskGphMthd method = GetMthdForBldSubTskGph(src_logical, dst_logical);
    (this->*method)(src_logical, dst_logical, logical2sorted_comp_tasks.at(src_logical),
                    logical2sorted_comp_tasks.at(dst_logical), &logical2sorted_in_box,
                    &logical2sorted_out_box, MutBufTask, AllocateCpuThrdIdEvenly);
    SetAreaIdForNewNodes(src_logical, dst_logical);
  });

  logical_gph_->ForEachNecessaryCtrlEdge(
      [&](const LogicalNode* src, const LogicalNode* dst, int64_t ctrl_regst_num) {
        ConnectCtrlEdges(logical2sorted_comp_tasks.at(src), logical2sorted_comp_tasks.at(dst),
                         ctrl_regst_num);
      });

  MergeChainAndSetOrderInGraphForEachNode();
  if (Global<ResourceDesc>::Get()->enable_debug_mode()) { ToDotWithAutoFilePath(); }
}

L
Li Xinqi 已提交
202 203 204 205 206
// Pairs `src_task_nodes[i]` with `dst_task_nodes[i]` (both vectors must have the same size)
// and, for every pair, builds a ctrl regst desc on the source, pins its min/max/returned
// regst num to `ctrl_regst_num`, then connects the two nodes with a new edge that is bound
// to the produced ctrl regst.
void TaskGraph::ConnectCtrlEdges(const std::vector<CompTaskNode*>& src_task_nodes,
                                 const std::vector<CompTaskNode*>& dst_task_nodes,
                                 int64_t ctrl_regst_num) {
  CHECK_EQ(src_task_nodes.size(), dst_task_nodes.size());
  FOR_RANGE(int32_t, i, 0, src_task_nodes.size()) {
    CompTaskNode* src_task = src_task_nodes.at(i);
    CompTaskNode* dst_task = dst_task_nodes.at(i);

    std::string ctrl_regst_name;
    RegstDesc* ctrl_regst = src_task->BuildCtrlRegstDesc(dst_task, &ctrl_regst_name);
    ctrl_regst->UpdtMinRegstNumIfNeed(ctrl_regst_num);
    ctrl_regst->UpdtMaxRegstNumIfNeed(ctrl_regst_num);
    ctrl_regst->mut_regst_desc_type()->mutable_ctrl_regst_desc()->set_returned_regst_num(
        ctrl_regst_num);

    TaskEdge* ctrl_edge = NewEdge();
    Connect<TaskNode>(src_task, ctrl_edge, dst_task);
    src_task->BindEdgeWithProducedRegst(ctrl_edge, ctrl_regst_name);
  }
}

221
void TaskGraph::GenerateIndependentThrdId(
J
Jinhui Yuan 已提交
222 223 224 225 226 227
    const std::vector<std::pair<int64_t, CompTaskNode*>>& persistence_nodes) {
  std::vector<std::pair<int64_t, TaskType>> machine_task_type_vec;
  for (auto pair : persistence_nodes) {
    machine_task_type_vec.emplace_back(std::make_pair(pair.first, pair.second->GetTaskType()));
  }

228
  ThrdIdGenerator generator(machine_task_type_vec, Global<IDMgr>::Get()->BaseIndependentThrdId());
J
Jinhui Yuan 已提交
229 230 231 232 233 234
  for (const auto pair : persistence_nodes) {
    int64_t thrd_id = generator.GenerateThrdId(pair.first, pair.second->GetTaskType());
    pair.second->set_thrd_id(thrd_id);
  }
}

L
Li Xinqi 已提交
235 236 237 238 239 240 241 242 243 244 245 246 247
void TaskGraph::MdUpdtDelayedTopoForEachNode(std::function<void(TaskNode* node)> Handler) const {
  HashSet<const TaskNode*> built_nodes;
  auto Build = [&](TaskNode* node) {
    CHECK(built_nodes.emplace(node).second);
    Handler(node);
  };
  AcyclicTopoForEachNode([](TaskNode* node) { return node->GetTaskType() != kNormalMdUpdt; },
                         Build);
  AcyclicTopoForEachNode([](TaskNode* node) { return node->GetTaskType() == kNormalMdUpdt; },
                         Build);
  ForEachNode([&](TaskNode* node) { CHECK(built_nodes.find(node) != built_nodes.end()); });
}

248 249 250
void TaskGraph::AcyclicTopoForEachNode(std::function<bool(TaskNode* node)> IsAllowedStartNode,
                                       std::function<void(TaskNode* node)> Handler) const {
  auto ForEachInNode = [&](TaskNode* node, const std::function<void(TaskNode*)>& Handler) {
251
    node->ForEachNodeOnInEdge([&](TaskNode* node_on_in_edge) {
L
Li Xinqi 已提交
252
      if (IsBackEdge(node_on_in_edge, node)) { return; }
253
      Handler(const_cast<TaskNode*>(node_on_in_edge));
254 255
    });
  };
256
  auto ForEachOutNode = [&](TaskNode* node, const std::function<void(TaskNode*)>& Handler) {
257
    node->ForEachNodeOnOutEdge([&](TaskNode* node_on_out_edge) {
L
Li Xinqi 已提交
258
      if (IsBackEdge(node, node_on_out_edge)) { return; }
259
      Handler(const_cast<TaskNode*>(node_on_out_edge));
260 261
    });
  };
L
Li Xinqi 已提交
262 263 264 265 266 267 268 269 270
  auto IsSourceNode = [&](TaskNode* node) {
    int32_t in_node_num = 0;
    ForEachInNode(node, [&](TaskNode* in_node) { ++in_node_num; });
    return in_node_num == 0;
  };
  std::list<TaskNode*> starts;
  ForEachNode([&](TaskNode* node) {
    if (IsSourceNode(node) && IsAllowedStartNode(node)) { starts.push_back(node); }
  });
271
  // DfsTopo will cause inappropriate chain graph
272 273 274 275 276
  TopoForEachNode(starts, ForEachInNode, ForEachOutNode, Handler);
}

void TaskGraph::AcyclicTopoForEachNode(std::function<void(TaskNode* node)> Handler) const {
  return AcyclicTopoForEachNode([](TaskNode*) { return true; }, Handler);
277 278
}

J
Jinhui Yuan 已提交
279 280 281 282
void TaskGraph::RemoveEmptyRegsts() {
  ForEachNode([&](TaskNode* node) { node->EraseZeroSizeProducedBlob(); });
  ForEachNode([&](TaskNode* node) { node->EraseZeroSizeConsumedRegst(); });
  ForEachNode([&](TaskNode* node) { node->EraseZeroSizeProducedRegst(); });
283
  ForEachNode([&](TaskNode* node) { node->UnbindBnWithEmptyRegst(); });
J
Jinhui Yuan 已提交
284 285
}

286
// Thin forwarder to BuildCtrlRegstDescInSameChain().
void TaskGraph::AddOrderingCtrlEdgeInSameChain() {
  BuildCtrlRegstDescInSameChain();
}
J
Jinhui Yuan 已提交
287

288 289 290
// Builds a ChainGraph over this task graph and walks its ordered chain nodes. Every task
// node receives the chain id of its chain node and a graph-wide, strictly increasing
// order number (starting at 0), and is appended to `ordered_task_nodes_` in that order.
void TaskGraph::MergeChainAndSetOrderInGraphForEachNode() {
  ChainGraph chain_graph(*this);
  int64_t next_order_in_graph = 0;
  for (const auto& chain_node : chain_graph.OrderdedChainNodes()) {
    const int64_t chain_id = chain_node->chain_id();
    for (const auto& task_node : chain_node->TaskNodes()) {
      task_node->set_chain_id(chain_id);
      task_node->set_order_in_graph(next_order_in_graph);
      ordered_task_nodes_.emplace_back(task_node);
      next_order_in_graph += 1;
    }
  }
}

304
void TaskGraph::BuildCtrlRegstDescInSameChain() {
J
Jinhui Yuan 已提交
305
  HashMap<int64_t, TaskNode*> chain_id2node;
L
Li Xinqi 已提交
306 307
  for (auto* node : ordered_task_nodes_) {
    if (IsConnectToTickOp(node)) { continue; }
J
Jinhui Yuan 已提交
308 309 310 311 312 313 314 315 316 317 318
    int64_t chain_id = node->chain_id();
    auto iter = chain_id2node.find(chain_id);
    if (iter == chain_id2node.end()) {
      CHECK(chain_id2node.emplace(chain_id, node).second);
    } else {
      iter->second->BuildCtrlRegstDescIfNeed(node);
      iter->second = node;
    }
  }
}

L
Li Xinqi 已提交
319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358
// For every non-CPU thread that hosts ReduceIdentity task nodes, picks the identity node
// with the smallest `order_in_logical_graph` and adds a ctrl regst from each qualifying
// forward node on the same (machine, physical device) to that identity node.
//
// A forward node qualifies when ForEachDeviceSrcUntrainableNode hands it out, it produces
// an "out" regst, and that regst's time shape contains the identity node's time shape.
// The ctrl regst's min/max/returned regst num is the ratio of the two time-shape elem cnts.
void TaskGraph::AddReduceNoBwForwardNodeOverlapingCtrlEdges() {
  HashMap<int64_t, std::vector<ReduceIdentityCompTaskNode*>> global_thrd_id2identity_nodes;
  HashMap<std::pair<int64_t, int64_t>, std::vector<NormalForwardCompTaskNode*>>
      global_dev_phy_id2fw_nodes;
  const auto* id_mgr = Global<IDMgr>::Get();
  // Bucket identity nodes by global thread id and forward nodes by (machine, device).
  for (auto* node : ordered_task_nodes_) {
    if (id_mgr->GetDeviceTypeFromThrdId(node->thrd_id()) == DeviceType::kCPU) { continue; }
    int64_t global_thrd_id = id_mgr->GlobalThrdId4TaskId(node->task_id());
    auto* identity_node = dynamic_cast<ReduceIdentityCompTaskNode*>(node);
    auto* fw_node = dynamic_cast<NormalForwardCompTaskNode*>(node);
    if (identity_node != nullptr) {
      global_thrd_id2identity_nodes[global_thrd_id].push_back(identity_node);
    } else if (fw_node != nullptr) {
      int64_t dev_phy_id = id_mgr->GetGpuPhyIdFromThrdId(node->thrd_id());
      global_dev_phy_id2fw_nodes[std::make_pair(node->machine_id(), dev_phy_id)].push_back(fw_node);
    } else {
      // do nothing
    }
  }
  auto GetIdentityNodeOrder = [](const ReduceIdentityCompTaskNode* id_node) {
    const auto* id_logical_node =
        dynamic_cast<const ReduceIdentityLogicalNode*>(id_node->logical_node());
    // The cast result is dereferenced right below; fail loudly instead of on a null pointer.
    CHECK_NOTNULL(id_logical_node);
    return id_logical_node->order_in_logical_graph();
  };
  for (auto& pair : global_thrd_id2identity_nodes) {
    auto& identity_nodes = pair.second;
    std::sort(identity_nodes.begin(), identity_nodes.end(),
              [&](ReduceIdentityCompTaskNode* lhs, ReduceIdentityCompTaskNode* rhs) {
                return GetIdentityNodeOrder(lhs) < GetIdentityNodeOrder(rhs);
              });
    auto* first_identity_node = identity_nodes.at(0);
    int64_t machine_id = first_identity_node->machine_id();
    int64_t dev_phy_id = id_mgr->GetGpuPhyIdFromThrdId(first_identity_node->thrd_id());
    const auto& fw_nodes = global_dev_phy_id2fw_nodes.at(std::make_pair(machine_id, dev_phy_id));
    const Shape& identity_time_shape =
        *first_identity_node->GetProducedRegst("out")->data_regst_time_shape();
    ForEachDeviceSrcUntrainableNode(fw_nodes, [&](CompTaskNode* node) {
      std::shared_ptr<RegstDesc> regst_desc = node->GetProducedRegst("out");
      if (!regst_desc) { return; }
      const Shape& time_shape = *regst_desc->data_regst_time_shape();
      if (!time_shape.Containing(identity_time_shape)) { return; }
      CHECK_EQ(time_shape.elem_cnt() % identity_time_shape.elem_cnt(), 0);
      int regst_desc_num = time_shape.elem_cnt() / identity_time_shape.elem_cnt();
      RegstDesc* ctrl_regst_desc = node->BuildCtrlRegstDesc(first_identity_node);
      ctrl_regst_desc->UpdtMinRegstNumIfNeed(regst_desc_num);
      ctrl_regst_desc->UpdtMaxRegstNumIfNeed(regst_desc_num);
      ctrl_regst_desc->mut_regst_desc_type()->mutable_ctrl_regst_desc()->set_returned_regst_num(
          regst_desc_num);
    });
  }
}

371
void TaskGraph::EnableInplaceMemSharingInReduceStruct() {
J
Juncheng 已提交
372
  auto GetSuccReduceTaskNode = [](TaskNode* pred) {
J
Jinhui Yuan 已提交
373
    std::vector<TaskNode*> nodes;
J
Juncheng 已提交
374
    pred->ForEachNodeOnOutDataEdge([&](TaskNode* succ) {
J
Juncheng 已提交
375
      if (dynamic_cast<ReduceCompTaskNodeIf*>(succ) != nullptr) { nodes.push_back(succ); }
J
Jinhui Yuan 已提交
376 377
    });
    return nodes;
378 379
  };

J
Jinhui Yuan 已提交
380 381 382 383
  HashSet<TaskNode*> has_enabled_nodes;

  auto CollectReduceTaskNode = [&](TaskNode* from) {
    std::list<TaskNode*> nodes;
J
Juncheng 已提交
384 385
    nodes.push_back(from);
    TaskNode* pred = from;
J
Jinhui Yuan 已提交
386
    while (true) {
J
Juncheng 已提交
387 388 389 390 391 392
      std::vector<TaskNode*> succ_reduce_nodes = GetSuccReduceTaskNode(pred);
      if (succ_reduce_nodes.size() != 1) { break; }
      TaskNode* succ_reduce_node = succ_reduce_nodes.front();
      if (has_enabled_nodes.find(succ_reduce_node) != has_enabled_nodes.end()) { break; }
      nodes.push_back(succ_reduce_node);
      pred = succ_reduce_node;
393
    }
J
Jinhui Yuan 已提交
394
    return nodes;
395 396
  };

J
Juncheng 已提交
397 398
  auto CalcModelSize = [](ReduceIdentityCompTaskNode* node) {
    return InferRegstSize(*node->produced_regsts().at("out").get());
J
Jinhui Yuan 已提交
399
  };
400

J
Jinhui Yuan 已提交
401
  ForEachNode([&](TaskNode* node) {
J
Juncheng 已提交
402 403 404 405 406 407
    ReduceIdentityCompTaskNode* identity_node = dynamic_cast<ReduceIdentityCompTaskNode*>(node);
    if (!identity_node) { return; }
    if (identity_node->device_type() != DeviceType::kGPU) { return; }
    if (identity_node->parallel_ctx()->parallel_num() < 2) { return; }
    std::list<TaskNode*> reduce_task_nodes = CollectReduceTaskNode(identity_node);

408
    const int64_t mem_block_id = Global<IDMgr>::Get()->NewMemBlockId();
J
Juncheng 已提交
409
    const int64_t mem_size = CalcModelSize(identity_node);
410
    ReduceMemSharingCtx ctx(mem_size, mem_block_id);
J
Jinhui Yuan 已提交
411 412
    for (TaskNode* reduce_node : reduce_task_nodes) {
      auto reduce_task_node_if = dynamic_cast<ReduceCompTaskNodeIf*>(reduce_node);
J
Juncheng 已提交
413
      CHECK_NOTNULL(reduce_task_node_if);
J
Jinhui Yuan 已提交
414 415
      reduce_task_node_if->EnableMemSharingInReduce(ctx);
      has_enabled_nodes.insert(reduce_node);
416 417 418 419
    }
  });
}

L
Li Xinqi 已提交
420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446
// Appends in-place candidates of the single-exec-node tasks in `dev_nodes` to `inplace_obas`:
//   - every input bn whose modifier is mutable (IsInplaceAllowed is CHECKed for it);
//   - every output bn whose modifier names a const or mutable inplace ibn, provided
//     IsInplaceAllowed holds for the {ibn, obn} pair.
void TaskGraph::GetInplaceOpBlobArgList(
    OpBlobArgList* inplace_obas, const HashSet<TaskNode*>& dev_nodes,
    const std::function<const TaskNode*(const std::string&)>& TaskNode4OpName) const {
  for (TaskNode* task_node : dev_nodes) {
    if (task_node->exec_gph().node_num() != 1) { continue; }
    const auto& op = *task_node->exec_gph().SoleNode()->op();
    for (const std::string& ibn : op.input_bns()) {
      if (!op.InputBlobModifier4Ibn(ibn).is_mutable()) { continue; }
      CHECK(IsInplaceAllowed(task_node, {ibn}, TaskNode4OpName));
      *inplace_obas->mutable_oba()->Add() = GenOpBlobArg(op.op_name(), ibn);
    }
    for (const std::string& obn : op.output_bns()) {
      const auto& obn_modifier = op.OutputBlobModifier4Obn(obn);
      // Stays empty when the output names no inplace input.
      std::string inplace_ibn;
      if (obn_modifier.has_const_inplace_ibn()) {
        inplace_ibn = obn_modifier.const_inplace_ibn();
      } else if (obn_modifier.has_mutable_inplace_ibn()) {
        inplace_ibn = obn_modifier.mutable_inplace_ibn();
      }
      if (inplace_ibn.empty()) { continue; }
      if (!IsInplaceAllowed(task_node, {inplace_ibn, obn}, TaskNode4OpName)) { continue; }
      *inplace_obas->mutable_oba()->Add() = GenOpBlobArg(op.op_name(), obn);
    }
  }
}

void TaskGraph::GetSafeInplaceOpBlobArgList(
    OpBlobArgList* safe_obas, const HashSet<TaskNode*>& dev_nodes,
S
scxfjiang 已提交
453
    std::function<bool(const std::string&, const std::string&)> IsOpNameDataOrCtrlReachable) const {
L
Li Xinqi 已提交
454 455 456 457 458 459
  auto TaskNode4SoleOpName = MakeGetterTaskNode4SoleOpName(dev_nodes);
  OpBlobArgList inplace_obas;
  GetInplaceOpBlobArgList(&inplace_obas, dev_nodes, TaskNode4SoleOpName);
  auto Op4OpName = [&](const std::string& op_name) -> const Operator* {
    return TaskNode4SoleOpName(op_name)->exec_gph().SoleNode()->op().get();
  };
S
scxfjiang 已提交
460 461
  auto IsLbiAllConsumersReachable =
      MakePredicatorIsLbiAllConsumersReachable(TaskNode4SoleOpName, IsOpNameDataOrCtrlReachable);
462
  InplaceLbiGraph origin_graph(inplace_obas, Op4OpName);
S
scxfjiang 已提交
463
  origin_graph.ComputeSafeInplaceObns(safe_obas, IsLbiAllConsumersReachable);
J
Juncheng 已提交
464 465 466 467 468 469 470
  InplaceLbiGraph safe_graph(*safe_obas, Op4OpName);
  if (Global<ResourceDesc>::Get()->enable_debug_mode()) {
    origin_graph.ToDotWithFilePath(
        JoinPath("dot", "InplaceLbiGraph", GlobalJobDesc().job_name() + "_origin.dot"));
    safe_graph.ToDotWithFilePath(
        JoinPath("dot", "InplaceLbiGraph", GlobalJobDesc().job_name() + "_safe.dot"));
  }
L
Li Xinqi 已提交
471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487
}

void TaskGraph::SetTaskRegstInplaceInfo(const OpBlobArgList& obas,
                                        const HashSet<TaskNode*>& dev_nodes) const {
  auto TaskNode4SoleOpName = MakeGetterTaskNode4SoleOpName(dev_nodes);
  auto Op4OpName = [&](const std::string& op_name) -> const Operator* {
    return TaskNode4SoleOpName(op_name)->exec_gph().SoleNode()->op().get();
  };
  InplaceLbiGraph inplace_gph(obas, Op4OpName);
  inplace_gph.ForEachConnectedComponent([&](const HashSet<const InplaceLbiNode*> inplace_nodes) {
    for (const auto* inplace_node : inplace_nodes) {
      if (inplace_node->in_edges().empty()) { continue; }
      const auto* inplace_edge = inplace_node->SoleInEdge();
      auto* exec_node = TaskNode4SoleOpName(inplace_edge->op().op_name())->exec_gph().SoleNode();
      RegstDesc* in_regst = exec_node->RegstDesc4BnInOp(inplace_edge->ibn());
      RegstDesc* out_regst = exec_node->RegstDesc4BnInOp(inplace_edge->obn());
      out_regst->set_hint_inplace_consumed_regst_desc_id(in_regst->regst_desc_id());
L
Li Xinqi 已提交
488
    }
L
Li Xinqi 已提交
489 490 491 492 493 494 495 496 497 498 499 500 501 502 503
  });
}

// Groups all GPU task nodes by (machine id, physical GPU id) and calls `Handler` once per
// group with the set of nodes on that device.
void TaskGraph::ForEachGpuDeviceNodes(
    const std::function<void(const HashSet<TaskNode*>& dev_nodes)>& Handler) const {
  HashMap<std::pair<int64_t, int64_t>, HashSet<TaskNode*>> global_dev_phy_id2nodes;
  ForEachNode([&](TaskNode* task_node) {
    if (task_node->device_type() == DeviceType::kGPU) {
      const int64_t dev_phy_id =
          Global<IDMgr>::Get()->GetGpuPhyIdFromThrdId(task_node->thrd_id());
      global_dev_phy_id2nodes[{task_node->machine_id(), dev_phy_id}].emplace(task_node);
    }
  });
  for (const auto& dev_and_nodes : global_dev_phy_id2nodes) { Handler(dev_and_nodes.second); }
}

void TaskGraph::EnableInplaceMemSharing(
S
scxfjiang 已提交
504 505
    const std::function<bool(const std::string&, const std::string&)>&
        IsOpNameDataOrCtrlReachable) {
L
Li Xinqi 已提交
506 507
  ForEachGpuDeviceNodes([&](const HashSet<TaskNode*>& dev_nodes) {
    OpBlobArgList safe_inplace_obas;
S
scxfjiang 已提交
508
    GetSafeInplaceOpBlobArgList(&safe_inplace_obas, dev_nodes, IsOpNameDataOrCtrlReachable);
L
Li Xinqi 已提交
509
    SetTaskRegstInplaceInfo(safe_inplace_obas, dev_nodes);
L
Li Xinqi 已提交
510 511 512
  });
}

C
chengtbf 已提交
513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531
// For every H2D CopyHd task node in the data-forward or boundary area: BFS over its
// forward-type successors to gather candidate nodes, then, for the first candidate of each
// chain (in graph order), add a ctrl regst from the copy node to every model-update node
// feeding that candidate.
//
// NOTE: the candidate handler currently only calls TODO(); its former body is kept below
// as a comment, so `candidate_nodes` is never filled by this code.
void TaskGraph::AddOrderCtrlEdgeBetweenCopyAndMdUpdt() {
  for (TaskNode* task_node : ordered_task_nodes_) {
    auto copy_hd_task_node = dynamic_cast<CopyHdTaskNode*>(task_node);
    if (copy_hd_task_node == nullptr) { continue; }
    if (copy_hd_task_node->copy_type() != CopyHdOpConf::H2D) { continue; }
    const bool in_fw_area = copy_hd_task_node->area_id() == static_cast<int64_t>(kDataForwardArea);
    const bool in_boundary_area =
        copy_hd_task_node->area_id() == static_cast<int64_t>(kBoundaryArea);
    if (!in_fw_area && !in_boundary_area) { continue; }

    std::vector<TaskNode*> candidate_nodes;
    auto ForEachNextNode = [&](TaskNode* node,
                               const std::function<void(TaskNode*)>& TryPushNodeToQueue) {
      node->ForEachNodeOnOutEdge([&](TaskNode* out_node) {
        if (!IsForwardTaskType(out_node->GetTaskType())) { return; }
        TryPushNodeToQueue(out_node);
      });
    };
    auto HandlerAddCandidate = [&](TaskNode* node) {
      TODO();  // refactor the following code
      /*
      auto fw_task_node = dynamic_cast<NormalForwardCompTaskNode*>(node);
      if (fw_task_node != nullptr && fw_task_node->logical_node()->HasOpWithModelBlob()
          && fw_task_node->parallel_ctx()->parallel_num() > 1
          && fw_task_node->parallel_ctx()->policy() == kDataParallel) {
        candidate_nodes.push_back(node);
      }
      */
    };
    BfsForEachNode({task_node}, ForEachNextNode, HandlerAddCandidate);
    std::sort(candidate_nodes.begin(), candidate_nodes.end(),
              [](const TaskNode* lhs, const TaskNode* rhs) {
                return lhs->order_in_graph() < rhs->order_in_graph();
              });

    int64_t last_chain_id = -1;
    for (TaskNode* candidate : candidate_nodes) {
      // Only the first candidate of each run of equal chain ids is processed.
      if (candidate->chain_id() == last_chain_id) { continue; }
      last_chain_id = candidate->chain_id();
      candidate->ForEachNodeOnInEdge([&](TaskNode* in_node) {
        if (!IsMdUpdtTaskType(in_node->GetTaskType())) { return; }
        RegstDesc* ctrl_regst = task_node->BuildCtrlRegstDesc(in_node);
        RegstDesc* copy_out_regst = copy_hd_task_node->GetProducedRegst("copy_out").get();
        int64_t piece_num_in_batch = GlobalJobDesc().NumOfPiecesInBatch();
        ctrl_regst->UpdtMinRegstNumIfNeed(copy_out_regst->min_register_num()
                                          + piece_num_in_batch - 1);
        CtrlRegstDesc* ctrl_regst_desc =
            ctrl_regst->mut_regst_desc_type()->mutable_ctrl_regst_desc();
        ctrl_regst_desc->set_reliant_regst_desc_id(copy_out_regst->regst_desc_id());
        ctrl_regst_desc->set_returned_regst_num(piece_num_in_batch);
      });
    }
  }
}
J
Jinhui Yuan 已提交
568 569 570 571 572 573

void TaskGraph::SetAreaIdForNewNodes(const LogicalNode* src_logical,
                                     const LogicalNode* dst_logical) {
  CHECK(src_logical != nullptr && dst_logical != nullptr);
  ForEachNode([&](TaskNode* node) {
    if (node->area_id() != static_cast<int64_t>(kInvalidArea)) return;
574 575
    if (src_logical->GetAreaId() == dst_logical->GetAreaId()) {
      node->set_area_id(src_logical->GetAreaId());
J
Jinhui Yuan 已提交
576 577 578 579 580 581
    } else {
      node->set_area_id(static_cast<int64_t>(kBoundaryArea));
    }
  });
}

582 583
// Expands to the out-of-class definition header of the TaskGraph member function `method_name`.
// The parameter list is supplied by BLD_SUB_TSK_GPH_MTHD_ARGS(), which is defined outside this
// file section; the function body follows the macro invocation.
#define DEFINE_BLD_SUB_TASK_GRAPH_METHOD(method_name) \
  void TaskGraph::method_name BLD_SUB_TSK_GPH_MTHD_ARGS()
584 585

// Dispatches boxing sub-graph construction: the V2 builder is used when the job description
// enables use_boxing_v2, the V1 builder otherwise. All arguments are forwarded unchanged; the two
// callbacks are moved into whichever builder is chosen.
DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByBoxing) {
  const bool use_v2 = GlobalJobDesc().use_boxing_v2();
  if (!use_v2) {
    BldSubTskGphByBoxingV1(src_logical, dst_logical, sorted_src_comp_tasks, sorted_dst_comp_tasks,
                           logical2sorted_in_box, logical2sorted_out_box, std::move(MutBufTask),
                           std::move(AllocateCpuThrdIdEvenly));
    return;
  }
  BldSubTskGphByBoxingV2(src_logical, dst_logical, sorted_src_comp_tasks, sorted_dst_comp_tasks,
                         logical2sorted_in_box, logical2sorted_out_box, std::move(MutBufTask),
                         std::move(AllocateCpuThrdIdEvenly));
}

// Connects src_logical to dst_logical through boxing task nodes.
// The out-boxing nodes of src_logical and the in-boxing nodes of dst_logical are built on first
// use and cached in logical2sorted_out_box / logical2sorted_in_box. Every out-boxing node is then
// connected to every in-boxing node via ConnectWithCopyCommNetIfNeed.
DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByBoxingV1) {
  // Out-boxing side: create the cache entry first, then let BuildOutBoxing fill it.
  if (logical2sorted_out_box->find(src_logical) == logical2sorted_out_box->end()) {
    std::vector<TaskNode*>* new_out_boxes = &((*logical2sorted_out_box)[src_logical]);
    BuildOutBoxing(src_logical, sorted_src_comp_tasks, new_out_boxes, MutBufTask,
                   AllocateCpuThrdIdEvenly);
  }
  const std::vector<TaskNode*>& out_boxes = logical2sorted_out_box->at(src_logical);

  // In-boxing side: same lazily-built cache, keyed by the destination logical node.
  if (logical2sorted_in_box->find(dst_logical) == logical2sorted_in_box->end()) {
    std::vector<TaskNode*>* new_in_boxes = &((*logical2sorted_in_box)[dst_logical]);
    BuildInBoxing(dst_logical, sorted_dst_comp_tasks, new_in_boxes, AllocateCpuThrdIdEvenly);
  }
  const std::vector<TaskNode*>& in_boxes = logical2sorted_in_box->at(dst_logical);

  for (TaskNode* out_box : out_boxes) {
    for (TaskNode* in_box : in_boxes) { ConnectWithCopyCommNetIfNeed(out_box, in_box); }
  }
}

J
Juncheng 已提交
617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651
// Builds the boxing sub-graph between src_logical and dst_logical with the chained sub task graph
// builders (currently only NcclBoxingSubTskGphBuilder is registered).
// Falls back to BldSubTskGphByBoxingV1 when more than one lbi flows from src_logical to
// dst_logical, or when the builders report that the boxing is not supported. Any other build
// error ends in UNIMPLEMENTED().
DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByBoxingV2) {
  const std::vector<LogicalBlobId> lbis = src_logical->GetLbisTo(dst_logical);
  // Runs at most once per call on every path below; it moves the two callbacks into V1.
  const auto FallbackToV1 = [&]() {
    BldSubTskGphByBoxingV1(src_logical, dst_logical, sorted_src_comp_tasks, sorted_dst_comp_tasks,
                           logical2sorted_in_box, logical2sorted_out_box, std::move(MutBufTask),
                           std::move(AllocateCpuThrdIdEvenly));
  };
  if (lbis.size() > 1) {
    FallbackToV1();
    return;
  }
  CHECK_EQ(lbis.size(), 1);
  const LogicalBlobId& lbi = lbis.front();
  auto* op_graph = Global<OpGraph>::Get();
  const SbpParallel& src_sbp_parallel =
      op_graph->GetSbpParallel(src_logical->SoleOp()->op_name(), lbi);
  const SbpParallel& dst_sbp_parallel =
      op_graph->GetSbpParallel(dst_logical->SoleOp()->op_name(), lbi);
  const std::shared_ptr<const ParallelDesc>& src_parallel_desc = src_logical->parallel_desc();
  const std::shared_ptr<const ParallelDesc>& dst_parallel_desc = dst_logical->parallel_desc();
  const BlobDesc& blob_desc = op_graph->GetLogicalBlobDesc(lbi);

  SubTskGphBuilderCtx ctx(this);
  std::vector<std::shared_ptr<SubTskGphBuilder>> builders;
  builders.emplace_back(new NcclBoxingSubTskGphBuilder());
  Maybe<void> status = TRY(ChainSubTskGphBuilder(builders).Build(
      &ctx, sorted_src_comp_tasks, sorted_dst_comp_tasks, *src_parallel_desc, *dst_parallel_desc,
      lbi, blob_desc, src_sbp_parallel, dst_sbp_parallel));
  if (status.IsOk()) { return; }
  if (SubTskGphBuilderUtil::IsErrorBoxingNotSupported(*status.error())) {
    FallbackToV1();
    return;
  }
  UNIMPLEMENTED();
}

652
// Connects the i-th source task to the i-th destination task for every index, reusing buffered
// intermediate task nodes (use_buf_task_node = true). Both lists must have the same length.
DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByOneToOne) {
  CHECK_EQ(sorted_src_comp_tasks.size(), sorted_dst_comp_tasks.size());
  const size_t pair_num = sorted_src_comp_tasks.size();
  for (size_t idx = 0; idx < pair_num; ++idx) {
    BuildTaskPath(sorted_src_comp_tasks.at(idx), sorted_dst_comp_tasks.at(idx), MutBufTask, true);
  }
}

661
// Connects broadcast source tasks to broadcast destination tasks.
// The number of destination tasks must be a multiple of the number of source tasks.
//   * Equal counts: the i-th source is connected to the i-th destination.
//   * Otherwise each destination (keyed by its (machine_id, thrd_id) pair) is fed by, in order of
//     preference: the source on the same machine and thread, the last listed source on the same
//     machine, or the first entry of the per-machine source map.
// All paths reuse buffered intermediate task nodes (use_buf_task_node = true).
DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByBroadcastToBroadcast) {
  // An empty source list would turn the modulo below into a remainder by zero (undefined
  // behaviour); fail with a diagnosable check instead.
  CHECK(!sorted_src_comp_tasks.empty());
  CHECK_EQ(sorted_dst_comp_tasks.size() % sorted_src_comp_tasks.size(), 0);
  if (sorted_src_comp_tasks.size() == sorted_dst_comp_tasks.size()) {
    FOR_RANGE(size_t, i, 0, sorted_src_comp_tasks.size()) {
      BuildTaskPath(sorted_src_comp_tasks.at(i), sorted_dst_comp_tasks.at(i), MutBufTask, true);
    }
    return;
  }

  // Index the sources by machine and by (machine, thread).
  HashMap<size_t, CompTaskNode*> machine_id2last_src_task;
  HashMap<std::pair<int64_t, int64_t>, CompTaskNode*> global_thrd_id2src_task;
  auto GlobalThrdId4TaskNode = [](TaskNode* task_node) -> std::pair<int64_t, int64_t> {
    return std::make_pair(task_node->machine_id(), task_node->thrd_id());
  };
  for (CompTaskNode* src_node : sorted_src_comp_tasks) {
    machine_id2last_src_task[src_node->machine_id()] = src_node;
    global_thrd_id2src_task[GlobalThrdId4TaskNode(src_node)] = src_node;
  }

  // Index the destinations by (machine, thread).
  HashMap<std::pair<int64_t, int64_t>, CompTaskNode*> global_thrd_id2dst_task;
  for (CompTaskNode* dst_node : sorted_dst_comp_tasks) {
    global_thrd_id2dst_task[GlobalThrdId4TaskNode(dst_node)] = dst_node;
  }

  // Picks the source for a destination located at global_thrd_id (see the preference order above).
  auto GetSrcNode = [&](const std::pair<int64_t, int64_t>& global_thrd_id) -> CompTaskNode* {
    const auto src_task_it = global_thrd_id2src_task.find(global_thrd_id);
    if (src_task_it != global_thrd_id2src_task.end()) { return src_task_it->second; }
    const auto m_src_task_it = machine_id2last_src_task.find(global_thrd_id.first);
    if (m_src_task_it != machine_id2last_src_task.end()) { return m_src_task_it->second; }
    // Non-empty: the source list was checked above and every source was inserted.
    return machine_id2last_src_task.begin()->second;
  };
  for (const auto& pair : global_thrd_id2dst_task) {
    BuildTaskPath(GetSrcNode(pair.first), pair.second, MutBufTask, true);
  }
}

696
// Picks one source task for the single destination task and connects the pair one-to-one.
// Selection scans the sources in order: the first source is the initial pick; a source on the
// destination's machine replaces a pick from another machine; a source that also shares the
// destination's thrd_id replaces a same-machine pick and ends the scan.
DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphBySelectOneSourceToSoleSink) {
  CHECK_EQ(sorted_dst_comp_tasks.size(), 1);
  CompTaskNode* sole_sink = sorted_dst_comp_tasks.front();
  CompTaskNode* selected_src_comp_task = nullptr;
  bool selected_on_sink_machine = false;
  for (CompTaskNode* candidate : sorted_src_comp_tasks) {
    const bool candidate_on_sink_machine = (candidate->machine_id() == sole_sink->machine_id());
    if (selected_src_comp_task == nullptr) {
      selected_src_comp_task = candidate;
      selected_on_sink_machine = candidate_on_sink_machine;
      continue;
    }
    if (!candidate_on_sink_machine) { continue; }
    if (!selected_on_sink_machine) {
      selected_src_comp_task = candidate;
      selected_on_sink_machine = true;
      continue;
    }
    if (candidate->thrd_id() == sole_sink->thrd_id()) {
      selected_src_comp_task = candidate;
      break;
    }
  }
  CHECK_NOTNULL(selected_src_comp_task);
  BldSubTskGphByOneToOne(nullptr, nullptr, {selected_src_comp_task}, sorted_dst_comp_tasks, nullptr,
                         nullptr, MutBufTask, AllocateCpuThrdIdEvenly);
}

J
Jinhui Yuan 已提交
726 727 728 729 730 731 732 733 734 735 736
// Connects reduce-scatter source tasks to reduce-add destination tasks without buffered
// intermediate nodes (use_buf_task_node = false).
//   * No local reduce (single machine, or one device per machine): every source feeds every
//     destination.
//   * Local reduce stage (the predecessor of the source logical node is neither a ReduceAdd nor an
//     NcclReduceScatter logical node): only tasks on the same machine are connected.
//   * Otherwise: only tasks whose parallel_id is congruent modulo the per-machine device count
//     are connected.
DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByReduceScatter2ReduceAdd) {
  const LogicalNode* src_logical_node = sorted_src_comp_tasks.front()->logical_node();
  const auto& pd = src_logical_node->parallel_desc();
  const bool has_local_reduce =
      pd->sorted_machine_ids().size() > 1 && pd->device_num_of_each_machine() > 1;
  const LogicalNode* pred_src_logical_node = src_logical_node->SoleInEdge()->src_node();
  const bool is_local_reduce =
      has_local_reduce
      && dynamic_cast<const ReduceAddLogicalNode*>(pred_src_logical_node) == nullptr
      && dynamic_cast<const NcclReduceScatterLogicalNode*>(pred_src_logical_node) == nullptr;

  const auto ShouldConnect = [&](CompTaskNode* src_task, CompTaskNode* dst_task) -> bool {
    if (!has_local_reduce) { return true; }
    if (is_local_reduce) { return src_task->machine_id() == dst_task->machine_id(); }
    const auto device_num = pd->device_num_of_each_machine();
    return src_task->parallel_id() % device_num == dst_task->parallel_id() % device_num;
  };

  for (CompTaskNode* src_task : sorted_src_comp_tasks) {
    for (CompTaskNode* dst_task : sorted_dst_comp_tasks) {
      if (ShouldConnect(src_task, dst_task)) {
        BuildTaskPath(src_task, dst_task, MutBufTask, false);
      }
    }
  }
}

J
Jinhui Yuan 已提交
757
// Connects reduce-add source tasks to reduce-gather destination tasks, reusing buffered
// intermediate nodes (use_buf_task_node = true).
// With more than one machine and more than one device per machine, only tasks whose parallel_id
// is congruent modulo the per-machine device count are connected; otherwise every source feeds
// every destination.
DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByReduceAdd2ReduceGather) {
  const auto& pd = sorted_src_comp_tasks.front()->logical_node()->parallel_desc();
  const bool has_local_reduce =
      pd->sorted_machine_ids().size() > 1 && pd->device_num_of_each_machine() > 1;
  for (CompTaskNode* src_task : sorted_src_comp_tasks) {
    for (CompTaskNode* dst_task : sorted_dst_comp_tasks) {
      if (has_local_reduce) {
        const auto device_num = pd->device_num_of_each_machine();
        if (src_task->parallel_id() % device_num != dst_task->parallel_id() % device_num) {
          continue;
        }
      }
      BuildTaskPath(src_task, dst_task, MutBufTask, true);
    }
  }
}

// Connects each reduce-gather source task to every reduce-gather destination task located on the
// same machine, reusing buffered intermediate nodes. Requires more than one machine and more than
// one device per machine.
DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByReduceGather2ReduceGather) {
  const auto& pd = sorted_src_comp_tasks.front()->logical_node()->parallel_desc();
  CHECK_GT(pd->device_num_of_each_machine(), 1);
  CHECK_GT(pd->sorted_machine_ids().size(), 1);
  for (CompTaskNode* src_task : sorted_src_comp_tasks) {
    for (CompTaskNode* dst_task : sorted_dst_comp_tasks) {
      if (src_task->machine_id() != dst_task->machine_id()) { continue; }
      BuildTaskPath(src_task, dst_task, MutBufTask, true);
    }
  }
}
787

788 789 790 791 792 793 794 795 796 797
// Adds a direct edge from every source task to every destination task that sits on the same
// machine and the same physical GPU. No copy task nodes are inserted.
DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByConnectNodeOnSameGpuDevice) {
  for (CompTaskNode* src_task : sorted_src_comp_tasks) {
    for (CompTaskNode* dst_task : sorted_dst_comp_tasks) {
      // The GPU ids are only compared once the machines are known to match.
      if (src_task->machine_id() != dst_task->machine_id()) { continue; }
      if (src_task->GpuPhyId() != dst_task->GpuPhyId()) { continue; }
      Connect<TaskNode>(src_task, NewEdge(), dst_task);
    }
  }
}

L
Li Xinqi 已提交
798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827
// Connects the single source task to a subset of the destination tasks.
// The i-th destination task is paired with the i-th input blob name of the destination op; it is
// connected only when that input's lbi is one of the lbis produced by the source op.
DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByPartialInLbiConnect) {
  HashSet<LogicalBlobId> produced_lbis;
  const auto& src_op = src_logical->SoleOp();
  for (const auto& obn : src_op->output_bns()) { produced_lbis.insert(src_op->BnInOp2Lbi(obn)); }
  CHECK_EQ(sorted_src_comp_tasks.size(), 1);
  CHECK_EQ(dst_logical->SoleOp()->input_bns().size(), sorted_dst_comp_tasks.size());
  const int dst_task_num = sorted_dst_comp_tasks.size();
  for (int idx = 0; idx < dst_task_num; ++idx) {
    const auto& dst_op = dst_logical->SoleOp();
    const auto& consumed_lbi = dst_op->BnInOp2Lbi(dst_op->input_bns().Get(idx));
    if (produced_lbis.find(consumed_lbi) == produced_lbis.end()) { continue; }
    BuildTaskPath(sorted_src_comp_tasks.at(0), sorted_dst_comp_tasks.at(idx), MutBufTask, true);
  }
}

// Connects a subset of the source tasks to the single destination task.
// The i-th source task is paired with the i-th output blob name of the source op; it is connected
// only when that output's lbi is one of the lbis consumed by the destination op.
DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByPartialOutLbiConnect) {
  HashSet<LogicalBlobId> consumed_lbis;
  const auto& dst_op = dst_logical->SoleOp();
  for (const auto& ibn : dst_op->input_bns()) { consumed_lbis.insert(dst_op->BnInOp2Lbi(ibn)); }
  CHECK_EQ(sorted_dst_comp_tasks.size(), 1);
  CHECK_EQ(src_logical->SoleOp()->output_bns().size(), sorted_src_comp_tasks.size());
  const int src_task_num = sorted_src_comp_tasks.size();
  for (int idx = 0; idx < src_task_num; ++idx) {
    const auto& src_op = src_logical->SoleOp();
    const auto& produced_lbi = src_op->BnInOp2Lbi(src_op->output_bns().Get(idx));
    if (consumed_lbis.find(produced_lbi) == consumed_lbis.end()) { continue; }
    BuildTaskPath(sorted_src_comp_tasks.at(idx), sorted_dst_comp_tasks.at(0), MutBufTask, true);
  }
}

J
Jinhui Yuan 已提交
828
// Builds the chain of task nodes that carries data from src to dst.
// BuildTaskStep is applied repeatedly until the current node is on dst's machine and in dst's
// mem zone; a final edge to dst is added unless the chain already ended on dst itself.
// MutBufTask exposes, per (src, machine_id, mem_zone_id), the slot holding a previously built
// intermediate node; use_buf_task_node tells BuildTaskStep whether those slots may be reused.
void TaskGraph::BuildTaskPath(
    CompTaskNode* src, CompTaskNode* dst,
    std::function<TaskNode**(CompTaskNode* src, int64_t machine_id, int32_t mem_zone_id)>
        MutBufTask,
    bool use_buf_task_node) {
  CHECK_NE(src, dst);
  // Reads the slot for (src, machine_id, mem_zone_id).
  const auto GetBufTask = [&](int64_t machine_id, int32_t mem_zone_id) -> TaskNode* {
    TaskNode** slot = MutBufTask(src, machine_id, mem_zone_id);
    return *slot;
  };
  // Fills an empty slot; a slot that is already filled must hold the very same node.
  const auto SetBufTask = [&](int64_t machine_id, int32_t mem_zone_id,
                              TaskNode* new_val) -> TaskNode* {
    TaskNode** cur_val = MutBufTask(src, machine_id, mem_zone_id);
    if (*cur_val != nullptr) {
      CHECK_EQ(*cur_val, new_val);
      return new_val;
    }
    *cur_val = new_val;
    return new_val;
  };

  TaskNode* cur_node = src;
  for (;;) {
    const bool arrived = cur_node->machine_id() == dst->machine_id()
                         && cur_node->MemZoneId121() == dst->MemZoneId121();
    if (arrived) { break; }
    cur_node = BuildTaskStep(cur_node, dst, GetBufTask, SetBufTask, use_buf_task_node);
  }
  if (cur_node == dst) { return; }
  Connect<TaskNode>(cur_node, NewEdge(), dst);
}
854

J
Jinhui Yuan 已提交
855
// Adds (or reuses) the next task node on the path from cur_node towards dst and returns it.
//   * cur_node is not in the CPU mem zone: the next node is a D2H copy task in the CPU mem zone
//     of cur_node's machine.
//   * cur_node is in the CPU mem zone on dst's machine: the next node is an H2D copy task for
//     dst's mem zone, or dst itself when TryAddCopyH2DTaskTo yields no copy task.
//   * cur_node is on another machine than dst: the next node is a CopyCommNet task in the CPU
//     mem zone of dst's machine.
// When use_buf_task_node is true, a node already recorded for the target (machine, mem zone) is
// reused instead of creating and connecting a new one, and the returned node (unless it is dst)
// is recorded through SetBufTask.
TaskNode* TaskGraph::BuildTaskStep(
    TaskNode* cur_node, TaskNode* dst,
    std::function<TaskNode*(int64_t machine_id, int32_t mem_zone_id)> GetBufTask,
    std::function<TaskNode*(int64_t machine_id, int32_t mem_zone_id, TaskNode*)> SetBufTask,
    bool use_buf_task_node) {
  const int32_t cpu_mem_zone_id = Global<IDMgr>::Get()->CpuMemZoneId();
  int32_t next_mem_zone_id = -1;
  TaskNode* next_node = nullptr;
  // Looks up a reusable buffered node; yields nullptr when buffering is disabled.
  const auto FindBufTask = [&](int64_t machine_id, int32_t mem_zone_id) -> TaskNode* {
    return use_buf_task_node ? GetBufTask(machine_id, mem_zone_id) : nullptr;
  };

  if (cur_node->MemZoneId121() != cpu_mem_zone_id) {
    next_mem_zone_id = cpu_mem_zone_id;
    next_node = FindBufTask(cur_node->machine_id(), next_mem_zone_id);
    if (next_node == nullptr) {
      next_node = AddCopyD2HTaskFrom(cur_node);
      Connect<TaskNode>(cur_node, NewEdge(), next_node);
    }
  } else if (cur_node->machine_id() == dst->machine_id()) {
    next_mem_zone_id = dst->MemZoneId121();
    next_node = FindBufTask(cur_node->machine_id(), next_mem_zone_id);
    if (next_node == nullptr) {
      next_node = TryAddCopyH2DTaskTo(dst);
      if (next_node == nullptr) { next_node = dst; }
      Connect<TaskNode>(cur_node, NewEdge(), next_node);
    }
  } else {
    // The machines differ here; this is the exact complement of the branch above, so no further
    // fall-through case exists.
    next_mem_zone_id = cpu_mem_zone_id;
    next_node = FindBufTask(dst->machine_id(), next_mem_zone_id);
    if (next_node == nullptr) {
      next_node = AddCopyCommNetTaskBetween(cur_node, dst);
      Connect<TaskNode>(cur_node, NewEdge(), next_node);
    }
  }

  if (use_buf_task_node && (next_node != dst)) {
    SetBufTask(next_node->machine_id(), next_mem_zone_id, next_node);
  }
  return next_node;
}

L
Li Xinqi 已提交
891
// Creates an H2D copy task node on task's machine and physical GPU and returns it.
// Returns nullptr, without creating anything, when task is an interface task or its task type is
// registered as TickTockTaskType. In every other case task must be a GPU task.
// The new node is not connected to any other node here.
TaskNode* TaskGraph::TryAddCopyH2DTaskTo(TaskNode* task) {
  const bool skip_copy =
      IsInterfaceTask(task) || IsClassRegistered<TickTockTaskType>(task->GetTaskType());
  if (skip_copy) { return nullptr; }
  CHECK_EQ(task->device_type(), DeviceType::kGPU);
  auto* h2d_copy = NewNode<CopyHdTaskNode>();
  h2d_copy->Init(CopyHdOpConf::H2D, task->machine_id(), task->GpuPhyId());
  return h2d_copy;
}

J
Jinhui Yuan 已提交
900 901
// Creates a D2H copy task node on task's machine and physical GPU and returns it.
// task must be a GPU task. The new node is not connected to any other node here.
TaskNode* TaskGraph::AddCopyD2HTaskFrom(TaskNode* task) {
  CHECK_EQ(task->device_type(), DeviceType::kGPU);
  auto* d2h_copy = NewNode<CopyHdTaskNode>();
  d2h_copy->Init(CopyHdOpConf::D2H, task->machine_id(), task->GpuPhyId());
  return d2h_copy;
}
906

J
Jinhui Yuan 已提交
907
// Creates a CopyCommNet task node initialised with dst's machine id followed by src's machine id
// and returns it. src and dst must be on different machines. The new node is not connected to any
// other node here.
TaskNode* TaskGraph::AddCopyCommNetTaskBetween(TaskNode* src, TaskNode* dst) {
  CHECK_NE(src->machine_id(), dst->machine_id());
  auto* comm_net_copy = NewNode<CopyCommNetTaskNode>();
  comm_net_copy->Init(dst->machine_id(), src->machine_id());
  return comm_net_copy;
}

914 915 916 917 918 919
// Builds one OutBoxingTaskNode per machine for the given compute tasks and appends the boxing
// nodes to sorted_out_box in ascending machine-id order.
// A GPU compute task reaches its boxing node through a D2H copy task: the copy task stored in the
// MutBufTask slot for the machine's CPU mem zone is reused if present, otherwise a new one is
// created, connected and stored. Other tasks are connected to the boxing node directly.
// The `logical` argument is not used by this function.
void TaskGraph::BuildOutBoxing(
    const LogicalNode* logical, const std::vector<CompTaskNode*>& sorted_comp_tasks,
    std::vector<TaskNode*>* sorted_out_box,
    std::function<TaskNode**(CompTaskNode* src, int64_t machine_id, int32_t mem_zone_id)>
        MutBufTask,
    std::function<int64_t(const TaskNode*)> AllocateCpuThrdIdEvenly) {
  // Step 1: group the nodes that feed the boxing nodes by machine.
  std::map<int64_t, std::vector<TaskNode*>> machine_id2bound_task;
  for (CompTaskNode* comp_task : sorted_comp_tasks) {
    TaskNode* bound_task = comp_task;
    if (bound_task->device_type() == DeviceType::kGPU) {
      TaskNode** d2h_slot =
          MutBufTask(comp_task, comp_task->machine_id(), Global<IDMgr>::Get()->CpuMemZoneId());
      if ((*d2h_slot) != nullptr) {
        bound_task = *d2h_slot;
      } else {
        bound_task = AddCopyD2HTaskFrom(comp_task);
        Connect<TaskNode>(comp_task, NewEdge(), bound_task);
        *d2h_slot = bound_task;
      }
    }
    machine_id2bound_task[bound_task->machine_id()].push_back(bound_task);
  }
  // Step 2: create one out-boxing node per machine and wire its inputs.
  for (const auto& machine_and_tasks : machine_id2bound_task) {
    const std::vector<TaskNode*>& tasks_on_machine = machine_and_tasks.second;
    auto* out_boxing = NewNode<OutBoxingTaskNode>();
    out_boxing->set_machine_id(tasks_on_machine.front()->machine_id());
    out_boxing->set_thrd_id(AllocateCpuThrdIdEvenly(out_boxing));
    for (TaskNode* producer : tasks_on_machine) {
      Connect<TaskNode>(producer, NewEdge(), out_boxing);
    }
    sorted_out_box->push_back(out_boxing);
  }
}

W
willzhang4a58 已提交
945 946 947
// Builds one InBoxingTaskNode per machine for the given compute tasks and appends the boxing
// nodes to sorted_in_box in ascending machine-id order.
// A GPU compute task is fed through an H2D copy task when TryAddCopyH2DTaskTo provides one;
// otherwise the boxing node is connected to the compute task directly.
// The `logical` argument is not used by this function.
void TaskGraph::BuildInBoxing(const LogicalNode* logical,
                              const std::vector<CompTaskNode*>& sorted_comp_tasks,
                              std::vector<TaskNode*>* sorted_in_box,
                              std::function<int64_t(const TaskNode*)> AllocateCpuThrdIdEvenly) {
  // Step 1: group the nodes fed by the boxing nodes by machine.
  std::map<int64_t, std::vector<TaskNode*>> machine_id2bound_task;
  for (CompTaskNode* comp_task : sorted_comp_tasks) {
    TaskNode* bound_task = comp_task;
    if (bound_task->device_type() == DeviceType::kGPU) {
      TaskNode* h2d_copy = TryAddCopyH2DTaskTo(comp_task);
      if (h2d_copy != nullptr && h2d_copy != comp_task) {
        Connect<TaskNode>(h2d_copy, NewEdge(), comp_task);
        bound_task = h2d_copy;
      }
    }
    machine_id2bound_task[bound_task->machine_id()].push_back(bound_task);
  }
  // Step 2: create one in-boxing node per machine and wire its outputs.
  for (const auto& machine_and_tasks : machine_id2bound_task) {
    const std::vector<TaskNode*>& tasks_on_machine = machine_and_tasks.second;
    auto* in_boxing = NewNode<InBoxingTaskNode>();
    in_boxing->set_machine_id(tasks_on_machine.front()->machine_id());
    in_boxing->set_thrd_id(AllocateCpuThrdIdEvenly(in_boxing));
    for (TaskNode* consumer : tasks_on_machine) {
      Connect<TaskNode>(in_boxing, NewEdge(), consumer);
    }
    sorted_in_box->push_back(in_boxing);
  }
}

W
willzhang4a58 已提交
968
// Connects src to dst. Nodes on the same machine get a direct edge; nodes on different machines
// are linked through a newly created CopyCommNet task node placed between them.
void TaskGraph::ConnectWithCopyCommNetIfNeed(TaskNode* src, TaskNode* dst) {
  const bool same_machine = (src->machine_id() == dst->machine_id());
  if (same_machine) {
    Connect(src, NewEdge(), dst);
    return;
  }
  TaskNode* comm_net_copy = AddCopyCommNetTaskBetween(src, dst);
  Connect<TaskNode>(src, NewEdge(), comm_net_copy);
  Connect<TaskNode>(comm_net_copy, NewEdge(), dst);
}

L
Li Xinqi 已提交
978
// Reports whether the edge from src to dst is a back edge. This implementation ignores both
// arguments and never classifies an edge as one.
bool IsBackEdge(TaskNode* src, TaskNode* dst) {
  return false;
}
J
Jinhui Yuan 已提交
979

W
willzhang4a58 已提交
980
}  // namespace oneflow