graph_helper.cc 29.3 KB
Newer Older
X
better  
Xin Pan 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

C
chengduo 已提交
15
#include "paddle/fluid/framework/ir/graph_helper.h"
16

17
#include <queue>
Y
Yan Chunwei 已提交
18
#include <stack>
19

20
#include "paddle/fluid/framework/details/multi_devices_helper.h"
21
#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
22
#include "paddle/fluid/framework/ir/pass.h"
23
#include "paddle/fluid/framework/op_proto_maker.h"
24
#include "paddle/fluid/framework/program_utils.h"
X
better  
Xin Pan 已提交
25

26 27 28 29 30
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/framework/details/nccl_op_handle.h"
#include "paddle/fluid/platform/collective_helper.h"
#endif

31
// Defined elsewhere. In this file it makes GraphToBlock sort ops by desc order
// and makes GraphToProgram convert every sub-graph into its own block.
DECLARE_bool(convert_all_blocks);
// When non-empty, GraphNum() writes a dump of every connected sub-graph to
// this path (only when the graph has more than one such sub-graph).
PADDLE_DEFINE_EXPORTED_string(print_sub_graph_dir,
                              "",
                              "FLAGS_print_sub_graph_dir is used "
                              "to print the nodes of sub_graphs.");
C
chengduo 已提交
36

X
better  
Xin Pan 已提交
37 38 39 40
namespace paddle {
namespace framework {
namespace ir {
namespace {
41 42

template <class NodeComparator = ir::NodeComp>
43 44
void SortHelper(const std::map<ir::Node *,
                               std::set<ir::Node *, NodeComparator>,
45
                               NodeComparator> &adj_list,
46 47
                ir::Node *node,
                std::unordered_set<ir::Node *> *visited,
48
                std::vector<ir::Node *> *ret) {
X
better  
Xin Pan 已提交
49 50 51 52
  visited->insert(node);

  for (auto adj : adj_list.at(node)) {
    if (visited->find(adj) == visited->end()) {
53
      SortHelper<NodeComparator>(adj_list, adj, visited, ret);
X
better  
Xin Pan 已提交
54 55 56
    }
  }

Y
Yan Chunwei 已提交
57
  VLOG(5) << "topology sort insert: " << node->Name() << " "
M
minqiyang 已提交
58
          << reinterpret_cast<void *>(node) << " input " << node->inputs.size();
X
better  
Xin Pan 已提交
59 60 61
  ret->push_back(node);
}

62
template <class NodeComparator = ir::NodeComp>
63 64 65 66 67 68 69
bool HasCircleHelper(ir::Node *node,
                     const std::map<ir::Node *,
                                    std::set<ir::Node *, NodeComparator>,
                                    NodeComparator> &adj_list,
                     std::unordered_set<ir::Node *> *visited,
                     std::unordered_set<ir::Node *> *in_trace,
                     std::vector<std::vector<ir::Node *>> *circles) {
X
better  
Xin Pan 已提交
70 71 72 73 74 75
  if (visited->find(node) == visited->end()) {
    visited->insert(node);
    in_trace->insert(node);

    for (ir::Node *in : adj_list.at(node)) {
      if (visited->find(in) == visited->end() &&
76 77
          HasCircleHelper<NodeComparator>(
              in, adj_list, visited, in_trace, circles)) {
X
better  
Xin Pan 已提交
78 79
        return true;
      } else if (in_trace->find(in) != in_trace->end()) {
D
dzhwinter 已提交
80 81 82 83 84 85 86 87 88 89 90 91
        if (circles != nullptr) {
          std::vector<ir::Node *> circle;
          circle.emplace_back(in);
          ir::Node *p = in;
          for (auto &adj : adj_list.at(p)) {
            if (in_trace->count(adj)) {
              circle.emplace_back(adj);
              p = adj;
            }
          }
          circles->emplace_back(circle);
        }
X
better  
Xin Pan 已提交
92 93 94 95 96 97 98 99
        return true;
      }
    }
  }
  in_trace->erase(node);
  return false;
}

100
template <class NodeComparator = ir::NodeComp>
101 102 103 104
bool HasCircleInternal(const std::map<ir::Node *,
                                      std::set<ir::Node *, NodeComparator>,
                                      NodeComparator> &adj_list,
                       std::vector<std::vector<ir::Node *>> *circles) {
X
better  
Xin Pan 已提交
105 106 107
  std::unordered_set<ir::Node *> visited;
  std::unordered_set<ir::Node *> in_trace;
  for (auto &adj : adj_list) {
108 109
    if (HasCircleHelper<NodeComparator>(
            adj.first, adj_list, &visited, &in_trace, circles)) {
X
better  
Xin Pan 已提交
110 111 112 113 114
      return true;
    }
  }
  return false;
}
X
Xin Pan 已提交
115 116 117
}  // namespace

bool HasCircle(const Graph &graph) {
D
dzhwinter 已提交
118 119 120
  return HasCircleInternal(BuildOperationAdjList(graph), nullptr);
}

121 122 123 124 125 126 127 128 129 130
bool VarDescIsConsistency(const Graph &graph) {
  std::unordered_map<std::string, std::unordered_set<ir::Node *>>
      var_name2node_set;
  for (ir::Node *node : graph.Nodes()) {
    if (node->IsVar() && node->Var()) {
      var_name2node_set[node->Var()->Name()].emplace(node);
    }
  }
  for (auto &iter : var_name2node_set) {
    auto &first_node = *iter.second.begin();
131 132
    bool is_persistable = std::any_of(iter.second.begin(),
                                      iter.second.end(),
133 134 135 136 137
                                      [&first_node](const ir::Node *node) {
                                        return node->Var()->Persistable();
                                      });
    if (is_persistable) {
      bool is_consistency =
138 139
          std::all_of(iter.second.begin(),
                      iter.second.end(),
140 141 142 143 144 145 146 147
                      [&first_node](const ir::Node *node) {
                        return *node->Var() == *first_node->Var();
                      });
      if (!is_consistency) return false;
    }
  }
  return true;
}
D
dzhwinter 已提交
148 149 150
bool FindCircleSubGraph(const Graph &graph,
                        std::vector<std::vector<ir::Node *>> *circles) {
  // Same search as HasCircle, but the detected circle's nodes are recorded.
  const auto adj_list = BuildOperationAdjList(graph);
  return HasCircleInternal(adj_list, circles);
}
X
better  
Xin Pan 已提交
152

X
Xin Pan 已提交
153
std::vector<ir::Node *> TopologySortOperations(const Graph &graph) {
154 155
  std::map<ir::Node *, std::set<ir::Node *, ir::NodeComp>, ir::NodeComp>
      adj_list = BuildOperationAdjList(graph);
156 157
  PADDLE_ENFORCE_EQ(HasCircleInternal(adj_list, nullptr),
                    false,
158 159
                    platform::errors::InvalidArgument(
                        "Generated graph shouldn't contain cycle."));
X
better  
Xin Pan 已提交
160 161 162 163
  std::unordered_set<ir::Node *> visited;
  std::vector<ir::Node *> ret;
  for (auto adj : adj_list) {
    if (visited.find(adj.first) == visited.end()) {
164
      SortHelper<ir::NodeComp>(adj_list, adj.first, &visited, &ret);
X
better  
Xin Pan 已提交
165 166
    }
  }
167

X
better  
Xin Pan 已提交
168 169 170
  return ret;
}

Z
Zeng Jinle 已提交
171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196
bool IsTopologySortOperationsUnique(const Graph &graph) {
  auto nodes = TopologySortOperations(graph);
  size_t n = nodes.size();
  for (size_t i = 1; i < n; ++i) {
    auto *prev_op = nodes[i - 1];
    auto *cur_op = nodes[i];

    std::unordered_set<Node *> prev_op_outputs;
    for (auto *output : prev_op->outputs) {
      prev_op_outputs.insert(output);
    }

    bool found = false;
    for (auto *input : cur_op->inputs) {
      if (prev_op_outputs.count(input) > 0) {
        found = true;
        break;
      }
    }
    if (!found) {
      return false;
    }
  }
  return true;
}

Y
Yan Chunwei 已提交
197 198 199 200 201 202 203 204 205 206 207 208
// Build operator outlink edge table.
std::map<ir::Node *, std::unordered_set<ir::Node *>> BuildOperationOutAdjList(
    const Graph &graph) {
  std::map<ir::Node *, std::unordered_set<ir::Node *>> adj_list;

  for (auto &n : graph.Nodes()) {
    if (!n->IsOp()) continue;
    if (adj_list.find(n) == adj_list.end()) {
      adj_list[n] = std::unordered_set<ir::Node *>();
    }
    for (auto &var : n->outputs) {
      for (auto &adj_n : var->outputs) {
209 210 211 212 213 214
        PADDLE_ENFORCE_EQ(adj_n->NodeType(),
                          ir::Node::Type::kOperation,
                          platform::errors::InvalidArgument(
                              "Node(%s)'s type(%d) must be kOperation type.",
                              adj_n->Name(),
                              static_cast<int>(adj_n->NodeType())));
Y
Yan Chunwei 已提交
215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314
        VLOG(40) << "adj " << adj_n->Name() << reinterpret_cast<void *>(adj_n)
                 << " -> " << n->Name() << reinterpret_cast<void *>(n)
                 << "  via " << var->Name() << reinterpret_cast<void *>(var);
        adj_list[n].insert(adj_n);
      }
    }
  }
  return adj_list;
}

std::vector<ir::Node *> OpDFSSort(const Graph &graph) {
  auto edge_table = BuildOperationOutAdjList(graph);
  std::stack<Node *> stack;
  for (auto &ele : edge_table) {
    if (ele.first->inputs.empty()) {
      // find the input ops (those without input vars)
      stack.push(ele.first);
    } else {
      // find the ops with only persistable vars as inputs.
      bool all_persistable = true;
      for (auto *input : ele.first->inputs) {
        if (!(input->IsVar() && input->Var() && input->Var()->Persistable())) {
          all_persistable = false;
        }
      }
      if (all_persistable) {
        stack.push(ele.first);
      }
    }
  }

  std::vector<Node *> res;
  // start from the feed op and DFS
  std::unordered_set<Node *> unique_set;
  while (!stack.empty()) {
    // will start from the last feed by default.
    auto cur = stack.top();
    stack.pop();
    unique_set.insert(cur);
    res.push_back(cur);

    for (auto *op : edge_table[cur]) {
      if (!unique_set.count(op)) {
        stack.push(op);
      }
    }
  }
  return res;
}

std::vector<ir::Node *> TopologyDfsSortOperations(const Graph &graph) {
  std::vector<ir::Node *> nodes;
  std::unordered_map<Node *, int> in_degree;

  auto set_out_ops_ready = [&](Node *var) {
    for (auto *op : var->outputs) {
      --in_degree[op];
    }
  };
  // build in_degree
  for (auto *node : graph.Nodes()) {
    if (node->IsOp()) {
      in_degree[node] += node->inputs.size();
    } else if (node->IsVar() && node->inputs.empty()) {
      // put all the inputs of the whole graph ready.
      set_out_ops_ready(node);
    }
  }

  std::deque<Node *> op_queue;
  // first visit
  for (auto &node : OpDFSSort(graph)) {
    if (node->IsOp()) {
      op_queue.push_back(node);
    }
  }

  // traverse the graph
  int num_ops = op_queue.size();
  while (num_ops) {
    for (auto it = op_queue.begin(); it != op_queue.end(); it++) {
      auto *&cur_op = *it;
      if (!cur_op || in_degree[cur_op] > 0) continue;
      // visit this node
      // put all the output var of this op valid.
      for (auto *out_var : cur_op->outputs) {
        if (!out_var) continue;
        set_out_ops_ready(out_var);
      }
      VLOG(8) << "visit " << cur_op->Name();
      nodes.push_back(cur_op);

      cur_op = nullptr;
      num_ops--;
    }
  }

  return nodes;
}

C
chengduo 已提交
315
size_t GraphNum(const Graph &graph) {
D
dzhwinter 已提交
316
  std::unordered_set<ir::Node *> nodes(graph.Nodes());
C
chengduo 已提交
317 318 319 320 321
  std::unordered_set<ir::Node *> visited_nodes;
  visited_nodes.reserve(nodes.size());
  std::deque<ir::Node *> q_nodes;
  std::vector<std::unordered_set<ir::Node *>> graph_nodes;
  std::unordered_set<ir::Node *> g_nodes;
W
Wu Yi 已提交
322 323
  // q_set used to record records in the queue.
  std::unordered_set<ir::Node *> q_set;
C
chengduo 已提交
324 325
  size_t graph_count = 0;

326 327 328 329 330 331 332 333 334
  auto traverse_nodes =
      [&visited_nodes, &q_nodes, &q_set](const std::vector<ir::Node *> &nodes) {
        for (auto n : nodes) {
          if (visited_nodes.count(n) == 0 && q_set.count(n) == 0) {
            q_nodes.push_back(n);
            q_set.insert(n);
          }
        }
      };
C
chengduo 已提交
335 336 337 338 339

  while (visited_nodes.size() != nodes.size()) {
    if (!q_nodes.empty()) {
      auto cur_node = q_nodes.front();
      q_nodes.pop_front();
W
Wu Yi 已提交
340
      q_set.erase(cur_node);
C
chengduo 已提交
341 342 343 344 345 346 347 348 349 350 351 352 353
      visited_nodes.insert(cur_node);
      g_nodes.insert(cur_node);
      traverse_nodes(cur_node->inputs);
      traverse_nodes(cur_node->outputs);
    } else {
      ++graph_count;
      if (g_nodes.size()) {
        graph_nodes.emplace_back(g_nodes);
      }
      g_nodes.clear();
      for (auto &n : nodes) {
        if (visited_nodes.count(n) == 0) {
          q_nodes.push_back(n);
W
Wu Yi 已提交
354
          q_set.insert(n);
C
chengduo 已提交
355 356 357 358 359 360 361 362 363 364
          break;
        }
      }
    }
  }

  if (g_nodes.size()) {
    graph_nodes.emplace_back(g_nodes);
  }

C
chengduo 已提交
365 366 367 368 369 370 371 372 373
  if (FLAGS_print_sub_graph_dir.size()) {
    if (graph_nodes.size() > 1) {
      std::stringstream out;
      for (auto &g_n : graph_nodes) {
        out << "graph_nodes: " << g_n.size() << "\n";
      }
      out << "\n\n";
      for (auto &g_n : graph_nodes) {
        out << "graph_nodes: " << g_n.size();
C
chengduo 已提交
374 375 376 377 378 379 380 381 382 383 384
        for (auto &node : g_n) {
          out << "\nNode: " << node->Name() << " in [";
          for (auto &n : node->inputs) {
            out << n->Name() << ", ";
          }
          out << "], out[";
          for (auto &n : node->outputs) {
            out << n->Name() << ", ";
          }
          out << "]";
        }
C
chengduo 已提交
385
        out << "\n\n\n";
C
chengduo 已提交
386
      }
C
chengduo 已提交
387 388
      std::unique_ptr<std::ostream> fout(
          new std::ofstream(FLAGS_print_sub_graph_dir));
389 390
      PADDLE_ENFORCE_EQ(fout->good(),
                        true,
391 392 393
                        platform::errors::Unavailable(
                            "Can not open file %s for printing the graph.",
                            FLAGS_print_sub_graph_dir));
C
chengduo 已提交
394
      *fout << out.str();
C
chengduo 已提交
395 396 397 398 399 400
    }
  }

  return graph_count;
}

Y
Yan Chunwei 已提交
401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423
// Removes from `graph` every node that has neither inputs nor outputs.
void CleanIndividualNodes(Graph *graph) {
  // Collect first, remove afterwards: RemoveNode presumably mutates the node
  // set being iterated (TODO confirm).
  std::unordered_set<Node *> isolated;
  for (Node *candidate : graph->Nodes()) {
    const bool no_edges =
        candidate->inputs.empty() && candidate->outputs.empty();
    if (no_edges) isolated.insert(candidate);
  }

  for (Node *victim : isolated) graph->RemoveNode(victim);
}

// Picks the op ordering for `sort_kind`: SortKind::TS maps to
// TopologySortOperations, every other kind to TopologyDfsSortOperations.
std::vector<Node *> TopologyVarientSort(const Graph &graph,
                                        SortKind sort_kind) {
  if (sort_kind == SortKind::TS) {
    return framework::ir::TopologySortOperations(graph);
  }
  return framework::ir::TopologyDfsSortOperations(graph);
}

424 425
class DescOrderComparator {
 public:
426 427 428 429 430 431 432 433
  bool operator()(Node *const &n1, Node *const &n2) const {
    if (n1->DescOrder() < n2->DescOrder()) {
      return true;
    } else if (n1->DescOrder() == n2->DescOrder()) {
      return n1->id() < n2->id() ||
             (n1->id() == n2->id() && n1->ToString() < n2->ToString());
    }
    return false;
434 435 436 437
  }
};

std::vector<ir::Node *> TopologySortGraphByDescOrder(const Graph &graph) {
438 439
  std::map<ir::Node *,
           std::set<ir::Node *, DescOrderComparator>,
440 441 442
           DescOrderComparator>
      adj_list = BuildOperationAdjList<DescOrderComparator>(graph);
  PADDLE_ENFORCE_EQ(HasCircleInternal<DescOrderComparator>(adj_list, nullptr),
443 444 445
                    false,
                    platform::errors::InvalidArgument(
                        "Generated graph shouldn't contain cycle."));
446 447 448 449 450
  std::unordered_set<ir::Node *> visited;
  std::vector<ir::Node *> ret;
  for (auto adj : adj_list) {
    if (visited.find(adj.first) == visited.end()) {
      SortHelper<DescOrderComparator>(adj_list, adj.first, &visited, &ret);
451 452 453
    }
  }

454
  return ret;
455 456
}

457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477
// Drops every input and output argument of `op_desc` whose name contains
// ir::Node::kControlDepVarName, then flushes the desc.
void RemoveControlDepInputAndOuput(OpDesc *op_desc) {
  auto remove_control_dep_var = [](VariableNameMap *var_name_map) {
    for (auto &pair : *var_name_map) {
      std::vector<std::string> &var_names = pair.second;
      auto it = var_names.begin();
      while (it != var_names.end()) {
        if (it->find(ir::Node::kControlDepVarName) != std::string::npos) {
          // Log before erasing. Afterwards `it` refers to the next element
          // or end(), and dereferencing end() is undefined behavior.
          VLOG(6) << "Remove var " << *it;
          it = var_names.erase(it);
        } else {
          ++it;
        }
      }
    }
  };

  remove_control_dep_var(op_desc->MutableInputs());
  remove_control_dep_var(op_desc->MutableOutputs());
  op_desc->Flush();
}

478 479
// Rewrites `desc` into a fill_constant op that stands in for the
// scale_loss_grad node `node`. The op gets shape {1}, force_cpu=false, the
// backward|loss op role, and Out = the node's output var names. "value"
// defaults to 1.0f. When `node` wraps an OpHandleBase, that handle must be a
// ScaleLossGradOpHandle, and "dtype"/"value" are taken from it. Returns `desc`.
static OpDesc *ReplaceScaleLossGradOp(const Node &node, OpDesc *desc) {
  desc->SetType("fill_constant");
  desc->SetAttr("shape", std::vector<int64_t>({1}));
  desc->SetAttr("value", 1.0f);

  if (node.IsWrappedBy<details::OpHandleBase>()) {
    details::OpHandleBase &op_handle =
        const_cast<Node *>(&node)->Wrapper<details::OpHandleBase>();
    // Cast once and check it: a failed dynamic_cast yields nullptr, which
    // would otherwise be dereferenced.
    auto *scale_loss_grad_handle =
        dynamic_cast<details::ScaleLossGradOpHandle *>(&op_handle);
    PADDLE_ENFORCE_NOT_NULL(
        scale_loss_grad_handle,
        platform::errors::InvalidArgument(
            "Node(%s) is expected to wrap a ScaleLossGradOpHandle.",
            node.Name()));
    desc->SetAttr("dtype", scale_loss_grad_handle->DType());
    desc->SetAttr("value", scale_loss_grad_handle->Coeff());
  }

  desc->SetAttr("force_cpu", false);
  desc->SetAttr(
      OpProtoAndCheckerMaker::OpRoleAttrName(),
      (static_cast<int>(OpRole::kBackward) | static_cast<int>(OpRole::kLoss)));
  // TODO(Ruibiao) : Set OpDeviceAttrName when needed

  std::vector<std::string> output_names;
  for (auto out : node.outputs) {
    output_names.emplace_back(out->Name());
  }
  desc->SetOutput("Out", output_names);
  return desc;
}

508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572
// Lowers a fused_all_reduce node into two ops appended to `ops`:
//   1. check_memory_continue: X = the handle's non-dummy input vars,
//      Out = a new LOD_TENSOR var "fake_coalesce_<k>" added to `block`,
//      XOut = the same input vars;
//   2. c_allreduce_sum with X = Out = that new var, using the ring id of the
//      handle's NCCL communicator and use_calc_stream = true.
// <k> is ops->size() right after check_memory_continue is appended. Both ops
// get the backward op role. Inputs, outputs and attrs other than the op role
// are only set when `node` wraps an OpHandleBase. Builds without NCCL/RCCL
// throw Unimplemented.
static void ReplaceAllReduceOp(const Node &node,
                               proto::BlockDesc *block,
                               std::vector<OpDesc> *ops) {
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
  const std::string name = "fake_coalesce_" + std::to_string(ops->size() + 1);
  // Append both ops before taking references. A second emplace_back may
  // reallocate the vector and leave a reference to the first op dangling.
  ops->emplace_back();
  ops->emplace_back();
  auto &desc1 = (*ops)[ops->size() - 2];
  auto &desc2 = ops->back();
  desc1.SetType("check_memory_continue");
  desc2.SetType("c_allreduce_sum");

  if (node.IsWrappedBy<details::OpHandleBase>()) {
    details::OpHandleBase &op_handle =
        const_cast<Node *>(&node)->Wrapper<details::OpHandleBase>();

    // set inputs, skipping dummy vars
    std::vector<std::string> in_names;
    for (const auto &in : op_handle.Inputs()) {
      if (dynamic_cast<details::DummyVarHandle *>(in) != nullptr) {
        continue;
      }
      in_names.emplace_back(in->Name());
    }
    desc1.SetInput("X", in_names);

    proto::VarDesc var_desc;
    var_desc.set_name(name);
    var_desc.mutable_type()->set_type(proto::VarType::LOD_TENSOR);
    block->mutable_vars()->Add()->CopyFrom(var_desc);
    desc1.SetOutput("Out", {name});
    desc1.SetOutput("XOut", in_names);
    VLOG(4) << "add variable for check_memory_continue: " << name;

    // The allreduce reads and writes the coalesced var in place, so the
    // handle's own output var names are not needed.
    desc2.SetInput("X", {name});
    desc2.SetOutput("Out", {name});

    auto *nccl_handle =
        dynamic_cast<details::NCCLOpHandleBase *>(&op_handle);
    PADDLE_ENFORCE_NOT_NULL(
        nccl_handle,
        platform::errors::InvalidArgument(
            "Node(%s) is expected to wrap an NCCLOpHandleBase.",
            node.Name()));
    int ring_id =
        platform::NCCLCommContext::Instance().GetRingId(nccl_handle->GetComm());
    desc2.SetAttr("ring_id", ring_id);
    desc2.SetAttr("use_calc_stream", true);
  }

  desc1.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
                (static_cast<int>(OpRole::kBackward)));
  desc2.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
                (static_cast<int>(OpRole::kBackward)));
#else
  PADDLE_THROW(
      platform::errors::Unimplemented("ReplaceAllReduceOp is only implemented "
                                      "for paddle compiled with NCCL/RCCL."));
#endif
}

573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600
void UpdateControlOpSkipEagerDeletionVars(const Node &node,
                                          const Graph &graph,
                                          const size_t graph_idx,
                                          const std::string &control_type) {
  // Note(zhangbo): SkipEagerDeletionVars pass policy for control flow ops:
  // 1) if the op is in main_block, the SkipEagerDeletionVars information is
  //    written into the Graph OpNode wrapped by OpHandleBase (nothing is done
  //    here for graph_idx == 0);
  // 2) if the op is in a sub_block, the information is written into the
  //    graph's OriginProgram OpDesc. Please refer to
  //    FindAllConditionalBlockAndConditionalBlockGradOp in
  //    "paddle/fluid/operators/controlflow/conditional_block_op_helper.cc"
  // For case 2, this copies "skip_eager_deletion_vars" onto node.Op() from
  // every op of block `graph_idx` that has type `control_type`, carries that
  // attr, and has the same input and output argument names as node.Op().
  if (graph_idx == 0) return;

  // Bind by reference. Copying the whole ProgramDesc for every control op is
  // costly, and only read access is needed.
  const auto &origin_program = graph.OriginProgram();
  const auto &block = origin_program.Block(graph_idx);
  for (size_t j = 0; j < block.OpSize(); ++j) {
    auto *op = block.Op(j);
    if (op->Type() == control_type &&
        op->HasAttr("skip_eager_deletion_vars")) {
      if (op->InputArgumentNames() == node.Op()->InputArgumentNames() &&
          op->OutputArgumentNames() == node.Op()->OutputArgumentNames()) {
        node.Op()->SetAttr("skip_eager_deletion_vars",
                           op->GetAttr("skip_eager_deletion_vars"));
      }
    }
  }
}

601
static void GetGraphOpDesc(const std::vector<Node *> &nodes,
602
                           proto::BlockDesc *block,
603 604 605
                           std::vector<OpDesc> *ops,
                           const Graph &graph,
                           const size_t graph_idx) {
606 607 608 609 610 611 612 613 614 615 616 617 618
  auto is_fused_opt = [](Node *n) -> bool {
    auto op_type = n->Op()->Type();
    auto is_opt =
        (op_type == "adam" || op_type == "momentum" || op_type == "sgd");
    auto input_names = n->Op()->InputArgumentNames();
    auto contains_fused_var = std::any_of(
        input_names.begin(), input_names.end(), [](std::string name) {
          return name.find(details::kFusedVarNamePrefix) != std::string::npos;
        });
    VLOG(4) << is_opt << " " << contains_fused_var;
    return is_opt && contains_fused_var;
  };

619 620 621 622 623 624
  for (Node *n : nodes) {
    // if node is not Op, skip
    if (!n->IsOp()) continue;

    // create fill_constant op
    if (n->Name() == "scale_loss_grad") {
625
      VLOG(4) << "convert op node scale_loss_grad to desc fill_constant";
626 627 628
      ops->emplace_back();
      auto &desc = ops->back();
      ReplaceScaleLossGradOp(*n, &desc);
629 630 631 632
    } else if (n->Name() == "fused_all_reduce") {
      VLOG(4) << "convert op node fused_all_reduce to desc c_allreduce_sum";
      ReplaceAllReduceOp(*n, block, ops);
      VLOG(4) << n->ToString();
633
    } else if (n->Op()) {
634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652
      VLOG(4) << "convert op node to desc " << n->Op()->Type();
      if (is_fused_opt(n)) {
        OpDesc depend_desc(n->Op()->Block());

        std::vector<std::string> deps;
        for (auto in : n->inputs) {
          if (in->IsVar() && !in->IsCtrlVar()) {
            deps.push_back(in->Name());
          }
        }
        depend_desc.SetType("depend");
        depend_desc.SetInput("X",
                             n->Op()->Inputs().at(n->Op()->InputNames()[0]));
        depend_desc.SetInput("Dep", deps);
        depend_desc.SetOutput("Out",
                              n->Op()->Inputs().at(n->Op()->InputNames()[0]));
        ops->emplace_back(depend_desc);
        VLOG(4) << "add depend op";
      }
653 654 655 656 657 658 659
      if (n->Name() == "while" || n->Name() == "while_grad" ||
          n->Name() == "conditional_block" ||
          n->Name() == "conditional_block_grad" || n->Name() == "recurrent" ||
          n->Name() == "recurrent_grad") {
        VLOG(1) << "Update control op attr: skip_eager_deletion_vars";
        UpdateControlOpSkipEagerDeletionVars(*n, graph, graph_idx, n->Name());
      }
660
      ops->emplace_back(*n->Op());
661
      VLOG(4) << n->ToString();
662 663 664 665 666
    }
    // delete no OpDesc op
  }
}

667 668 669 670 671 672 673 674 675 676 677 678
// Appends to `vars` a copy of the proto VarDesc of every var node in `nodes`
// that has a VarDesc and belongs to the same block as `graph`. T is a
// node handle: Node* or std::shared_ptr<ir::Node>.
template <class T = Node *>
static void GetGraphVarDesc(const Graph &graph,
                            const std::unordered_set<T> &nodes,
                            std::vector<proto::VarDesc> *vars) {
  // Iterate by const reference: when T is a shared_ptr, a by-value loop
  // variable costs an atomic refcount increment and decrement per element.
  for (const T &node : nodes) {
    if (node->IsVar() && node->Var() &&
        node->GetVarNodeBlockId() == graph.GetBlockId()) {
      vars->emplace_back(*node->Var()->Proto());
    }
  }
}

679 680
static void GraphToBlock(const Graph &graph,
                         proto::BlockDesc *block,
681 682
                         const SortKind *sort_kind,
                         const size_t graph_idx) {
683 684 685 686 687 688 689 690 691
  // Remove the unneeded variables after memory optimization.
  std::unordered_set<std::string> vars2remove;
  if (graph.Has(kGraphToProgramVarsToRemove)) {
    vars2remove =
        graph.Get<std::unordered_set<std::string>>(kGraphToProgramVarsToRemove);
    VLOG(2) << "graph (id: " << block->idx() << ") to program remove "
            << vars2remove.size() << " nodes";
  }

692
  std::vector<proto::VarDesc> vars_in_graph;
693 694 695 696 697
  GetGraphVarDesc<Node *>(graph, graph.Nodes(), &vars_in_graph);
  if (graph.Has(details::kRemovedVars)) {
    auto &removed_vars = graph.Get<details::RemovedVars>(details::kRemovedVars);
    GetGraphVarDesc<std::shared_ptr<ir::Node>>(
        graph, removed_vars, &vars_in_graph);
698 699 700
  }

  // add vars_in_graph to blcok
701 702
  block->clear_vars();
  std::unordered_set<std::string> visited_vars;
703 704 705 706 707 708
  for (proto::VarDesc &var : vars_in_graph) {
    const std::string &var_name = var.name();
    if (visited_vars.find(var_name) == visited_vars.end() &&
        vars2remove.find(var_name) == vars2remove.end()) {
      block->add_vars()->MergeFrom(var);
      visited_vars.insert(var_name);
709 710 711
    }
  }

712
  block->clear_ops();
713 714 715 716 717 718 719 720 721 722 723 724 725
  std::vector<Node *> nodes;
  if (sort_kind != nullptr) {
    // Inference Memory Optimize relays on this branch.
    nodes = TopologyVarientSort(graph, *sort_kind);
  } else {
    if (FLAGS_convert_all_blocks) {
      nodes = TopologySortGraphByDescOrder(graph);
    } else {
      nodes = TopologySortOperations(graph);
    }
  }

  std::vector<OpDesc> ops;
726
  GetGraphOpDesc(nodes, block, &ops, graph, graph_idx);
727

728
  for (auto &op : ops) {
729
    RemoveControlDepInputAndOuput(&op);
730 731 732 733
    block->add_ops()->MergeFrom(*op.Proto());
  }
}

734 735
void GraphToProgram(const Graph &graph,
                    ProgramDesc *program,
736
                    const SortKind *sort_kind) {
737 738
  PADDLE_ENFORCE_EQ(graph.IsMainGraph(),
                    true,
739 740 741 742 743 744 745 746 747 748 749 750 751
                    platform::errors::InvalidArgument(
                        "This graph is a sub_graph, "
                        "and can't convert to program individually"));
  PADDLE_ENFORCE_NOT_NULL(
      program,
      platform::errors::InvalidArgument(
          "program must not be nullptr when converting graph to program"));

  proto::ProgramDesc program_pb(*(program->Proto()));
  auto block = program_pb.mutable_blocks(kRootBlockIndex);
  block->set_idx(kRootBlockIndex);

  if (FLAGS_convert_all_blocks) {
752 753 754 755
    GraphToBlock(*graph.GetSubGraph(kRootBlockIndex),
                 block,
                 sort_kind,
                 graph.GetSubGraph(kRootBlockIndex)->GetBlockId());
756 757 758 759 760 761 762 763 764

    VLOG(3) << "Graph to program need convert " << graph.SubGraphsSize()
            << " sub graph";
    for (size_t idx = 0; idx < graph.SubGraphsSize(); ++idx) {
      // avoid kRootBlockIndex not 0
      if (idx == kRootBlockIndex) continue;

      block = program_pb.add_blocks();
      block->set_idx(idx);
765
      block->set_parent_idx(kRootBlockIndex);
766 767 768 769
      GraphToBlock(*graph.GetSubGraph(idx),
                   block,
                   sort_kind,
                   graph.GetSubGraph(idx)->GetBlockId());
770 771
    }
  } else {
772
    GraphToBlock(graph, block, sort_kind, graph.GetBlockId());
773 774 775
  }

  program->CopyFrom(program_pb);
776 777 778 779 780 781 782

  if (graph.Has(details::kProgramDescs)) {
    details::ProgramDescs program_descs =
        graph.Get<details::ProgramDescs>(details::kProgramDescs);
    VLOG(8) << "Merge main programs";
    MergePrograms(program, program_descs, /*append=*/false);
  }
783 784
}

785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842
// Builds the op dependency matrix of `block` from the op nodes in `nodes`.
// Entry [i][j] describes op i relative to op j, where indices follow
// block.AllOps() order: kSame on the diagonal, kAfter when j is a
// (transitive) predecessor of i, kBefore when j is a successor of i, kNoDep
// otherwise. Predecessors are found through var producer/consumer edges and
// accumulated in Kahn (ready-queue) order. Throws InvalidArgument on
// duplicate op ids or on an op node whose id is not in `block`.
static std::vector<std::vector<ir::Node::Dep>> GetOpDependencies(
    const BlockDesc &block, const std::unordered_set<ir::Node *> &nodes) {
  using NodeSet = std::unordered_set<const ir::Node *>;
  const auto block_ops = block.AllOps();
  const size_t op_num = block_ops.size();

  // Direct predecessor/successor ops and remaining in-degree of each op node.
  std::unordered_map<const ir::Node *, NodeSet> preceding_ops(op_num);
  std::unordered_map<const ir::Node *, size_t> preceding_deps(op_num);
  std::unordered_map<const ir::Node *, NodeSet> pending_ops(op_num);
  std::queue<const ir::Node *> ready_ops;

  for (const ir::Node *node : nodes) {
    if (!node->IsOp()) continue;

    NodeSet &direct_in = preceding_ops[node];
    for (const ir::Node *in_var : node->inputs) {
      direct_in.insert(in_var->inputs.begin(), in_var->inputs.end());
    }
    if (direct_in.empty()) ready_ops.push(node);
    preceding_deps[node] = direct_in.size();

    NodeSet &direct_out = pending_ops[node];
    for (const ir::Node *out_var : node->outputs) {
      direct_out.insert(out_var->outputs.begin(), out_var->outputs.end());
    }
  }

  // Transitive predecessors: an op's ancestors are its direct predecessors
  // plus their ancestors, all of which are complete once the op is ready.
  std::unordered_map<const ir::Node *, NodeSet> all_preceding_ops;
  while (!ready_ops.empty()) {
    const ir::Node *cur_op = ready_ops.front();
    ready_ops.pop();

    NodeSet &ancestors = all_preceding_ops[cur_op];
    for (const ir::Node *parent : preceding_ops.at(cur_op)) {
      ancestors.insert(parent);
      const NodeSet &grand_ancestors = all_preceding_ops[parent];
      ancestors.insert(grand_ancestors.begin(), grand_ancestors.end());
    }

    for (const ir::Node *child : pending_ops.at(cur_op)) {
      if (--preceding_deps.at(child) == 0) ready_ops.push(child);
    }
  }

  // Position of each op (by original id) inside block.AllOps().
  std::unordered_map<uint64_t, size_t> op_id_to_idx(op_num);
  for (const auto *op_desc : block_ops) {
    const size_t next_idx = op_id_to_idx.size();
    const bool inserted =
        op_id_to_idx.emplace(op_desc->OriginalId(), next_idx).second;
    PADDLE_ENFORCE_EQ(
        inserted,
        true,
        platform::errors::InvalidArgument(
            "There should not be duplicate op id: %d", op_desc->OriginalId()));
  }

  std::vector<std::vector<ir::Node::Dep>> dep_matrix(
      op_num, std::vector<ir::Node::Dep>(op_num, ir::Node::Dep::kNoDep));
  for (size_t i = 0; i < op_num; ++i) {
    dep_matrix[i][i] = ir::Node::Dep::kSame;
  }

  auto get_op_idx_by_id = [&op_id_to_idx](uint64_t op_id) {
    auto iter = op_id_to_idx.find(op_id);
    PADDLE_ENFORCE_NE(iter,
                      op_id_to_idx.end(),
                      platform::errors::InvalidArgument(
                          "Cannot find OpDesc with id %d", op_id));
    return iter->second;
  };

  for (const auto &entry : all_preceding_ops) {
    const size_t later_idx = get_op_idx_by_id(entry.first->Op()->OriginalId());
    for (const ir::Node *ancestor : entry.second) {
      const size_t earlier_idx =
          get_op_idx_by_id(ancestor->Op()->OriginalId());
      dep_matrix[later_idx][earlier_idx] = ir::Node::Dep::kAfter;
      dep_matrix[earlier_idx][later_idx] = ir::Node::Dep::kBefore;
    }
  }
  return dep_matrix;
}

// Builds an ir::Graph from `program` and returns one dependency matrix per
// block: entry i is computed from program.Block(i) and the nodes of
// sub-graph i.
std::vector<std::vector<std::vector<ir::Node::Dep>>> GetOpDependencies(
    const ProgramDesc &program) {
  ir::Graph graph(program);
  const size_t block_num = program.Size();

  std::vector<std::vector<std::vector<ir::Node::Dep>>> deps;
  deps.reserve(block_num);
  for (size_t block_idx = 0; block_idx != block_num; ++block_idx) {
    const auto &sub_graph_nodes = graph.GetSubGraph(block_idx)->Nodes();
    deps.push_back(
        GetOpDependencies(program.Block(block_idx), sub_graph_nodes));
  }
  return deps;
}

X
better  
Xin Pan 已提交
889 890 891
}  // namespace ir
}  // namespace framework
}  // namespace paddle