graph_helper.cc 26.5 KB
Newer Older
X
better  
Xin Pan 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

C
chengduo 已提交
15
#include "paddle/fluid/framework/ir/graph_helper.h"
16

17
#include <queue>
Y
Yan Chunwei 已提交
18
#include <stack>
19

20
#include "paddle/fluid/framework/details/multi_devices_helper.h"
21
#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
22
#include "paddle/fluid/framework/ir/pass.h"
23
#include "paddle/fluid/framework/op_proto_maker.h"
24
#include "paddle/fluid/framework/program_utils.h"
X
better  
Xin Pan 已提交
25

26
// Declared here, defined elsewhere. When set, GraphToProgram converts every
// sub graph of the main graph into its own program block.
DECLARE_bool(convert_all_blocks);

// File path GraphNum writes its per-sub-graph node listing to when the graph
// splits into more than one sub graph; an empty value disables the dump.
PADDLE_DEFINE_EXPORTED_string(
    print_sub_graph_dir,
    "",
    "FLAGS_print_sub_graph_dir is used "
    "to print the nodes of sub_graphs.");
C
chengduo 已提交
31

X
better  
Xin Pan 已提交
32 33 34 35
namespace paddle {
namespace framework {
namespace ir {
namespace {
36 37

template <class NodeComparator = ir::NodeComp>
38 39
void SortHelper(const std::map<ir::Node *,
                               std::set<ir::Node *, NodeComparator>,
40
                               NodeComparator> &adj_list,
41 42
                ir::Node *node,
                std::unordered_set<ir::Node *> *visited,
43
                std::vector<ir::Node *> *ret) {
X
better  
Xin Pan 已提交
44 45 46 47
  visited->insert(node);

  for (auto adj : adj_list.at(node)) {
    if (visited->find(adj) == visited->end()) {
48
      SortHelper<NodeComparator>(adj_list, adj, visited, ret);
X
better  
Xin Pan 已提交
49 50 51
    }
  }

Y
Yan Chunwei 已提交
52
  VLOG(5) << "topology sort insert: " << node->Name() << " "
M
minqiyang 已提交
53
          << reinterpret_cast<void *>(node) << " input " << node->inputs.size();
X
better  
Xin Pan 已提交
54 55 56
  ret->push_back(node);
}

57
template <class NodeComparator = ir::NodeComp>
58 59 60 61 62 63 64
bool HasCircleHelper(ir::Node *node,
                     const std::map<ir::Node *,
                                    std::set<ir::Node *, NodeComparator>,
                                    NodeComparator> &adj_list,
                     std::unordered_set<ir::Node *> *visited,
                     std::unordered_set<ir::Node *> *in_trace,
                     std::vector<std::vector<ir::Node *>> *circles) {
X
better  
Xin Pan 已提交
65 66 67 68 69 70
  if (visited->find(node) == visited->end()) {
    visited->insert(node);
    in_trace->insert(node);

    for (ir::Node *in : adj_list.at(node)) {
      if (visited->find(in) == visited->end() &&
71 72
          HasCircleHelper<NodeComparator>(
              in, adj_list, visited, in_trace, circles)) {
X
better  
Xin Pan 已提交
73 74
        return true;
      } else if (in_trace->find(in) != in_trace->end()) {
D
dzhwinter 已提交
75 76 77 78 79 80 81 82 83 84 85 86
        if (circles != nullptr) {
          std::vector<ir::Node *> circle;
          circle.emplace_back(in);
          ir::Node *p = in;
          for (auto &adj : adj_list.at(p)) {
            if (in_trace->count(adj)) {
              circle.emplace_back(adj);
              p = adj;
            }
          }
          circles->emplace_back(circle);
        }
X
better  
Xin Pan 已提交
87 88 89 90 91 92 93 94
        return true;
      }
    }
  }
  in_trace->erase(node);
  return false;
}

95
template <class NodeComparator = ir::NodeComp>
96 97 98 99
bool HasCircleInternal(const std::map<ir::Node *,
                                      std::set<ir::Node *, NodeComparator>,
                                      NodeComparator> &adj_list,
                       std::vector<std::vector<ir::Node *>> *circles) {
X
better  
Xin Pan 已提交
100 101 102
  std::unordered_set<ir::Node *> visited;
  std::unordered_set<ir::Node *> in_trace;
  for (auto &adj : adj_list) {
103 104
    if (HasCircleHelper<NodeComparator>(
            adj.first, adj_list, &visited, &in_trace, circles)) {
X
better  
Xin Pan 已提交
105 106 107 108 109
      return true;
    }
  }
  return false;
}
X
Xin Pan 已提交
110 111 112
}  // namespace

bool HasCircle(const Graph &graph) {
D
dzhwinter 已提交
113 114 115
  return HasCircleInternal(BuildOperationAdjList(graph), nullptr);
}

116 117 118 119 120 121 122 123 124 125
bool VarDescIsConsistency(const Graph &graph) {
  std::unordered_map<std::string, std::unordered_set<ir::Node *>>
      var_name2node_set;
  for (ir::Node *node : graph.Nodes()) {
    if (node->IsVar() && node->Var()) {
      var_name2node_set[node->Var()->Name()].emplace(node);
    }
  }
  for (auto &iter : var_name2node_set) {
    auto &first_node = *iter.second.begin();
126 127
    bool is_persistable = std::any_of(iter.second.begin(),
                                      iter.second.end(),
128 129 130 131 132
                                      [&first_node](const ir::Node *node) {
                                        return node->Var()->Persistable();
                                      });
    if (is_persistable) {
      bool is_consistency =
133 134
          std::all_of(iter.second.begin(),
                      iter.second.end(),
135 136 137 138 139 140 141 142
                      [&first_node](const ir::Node *node) {
                        return *node->Var() == *first_node->Var();
                      });
      if (!is_consistency) return false;
    }
  }
  return true;
}
D
dzhwinter 已提交
143 144 145
// Searches the operator adjacency list of `graph` for a cycle. Returns true
// when one exists; the detected cycle (at most one) is appended to `circles`
// when `circles` is not nullptr.
bool FindCircleSubGraph(const Graph &graph,
                        std::vector<std::vector<ir::Node *>> *circles) {
  const auto op_adj_list = BuildOperationAdjList(graph);
  const bool has_circle = HasCircleInternal(op_adj_list, circles);
  return has_circle;
}
X
better  
Xin Pan 已提交
147

X
Xin Pan 已提交
148
std::vector<ir::Node *> TopologySortOperations(const Graph &graph) {
149 150
  std::map<ir::Node *, std::set<ir::Node *, ir::NodeComp>, ir::NodeComp>
      adj_list = BuildOperationAdjList(graph);
151 152
  PADDLE_ENFORCE_EQ(HasCircleInternal(adj_list, nullptr),
                    false,
153 154
                    platform::errors::InvalidArgument(
                        "Generated graph shouldn't contain cycle."));
X
better  
Xin Pan 已提交
155 156 157 158
  std::unordered_set<ir::Node *> visited;
  std::vector<ir::Node *> ret;
  for (auto adj : adj_list) {
    if (visited.find(adj.first) == visited.end()) {
159
      SortHelper<ir::NodeComp>(adj_list, adj.first, &visited, &ret);
X
better  
Xin Pan 已提交
160 161
    }
  }
162

X
better  
Xin Pan 已提交
163 164 165
  return ret;
}

Z
Zeng Jinle 已提交
166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191
bool IsTopologySortOperationsUnique(const Graph &graph) {
  auto nodes = TopologySortOperations(graph);
  size_t n = nodes.size();
  for (size_t i = 1; i < n; ++i) {
    auto *prev_op = nodes[i - 1];
    auto *cur_op = nodes[i];

    std::unordered_set<Node *> prev_op_outputs;
    for (auto *output : prev_op->outputs) {
      prev_op_outputs.insert(output);
    }

    bool found = false;
    for (auto *input : cur_op->inputs) {
      if (prev_op_outputs.count(input) > 0) {
        found = true;
        break;
      }
    }
    if (!found) {
      return false;
    }
  }
  return true;
}

Y
Yan Chunwei 已提交
192 193 194 195 196 197 198 199 200 201 202 203
// Build operator outlink edge table.
std::map<ir::Node *, std::unordered_set<ir::Node *>> BuildOperationOutAdjList(
    const Graph &graph) {
  std::map<ir::Node *, std::unordered_set<ir::Node *>> adj_list;

  for (auto &n : graph.Nodes()) {
    if (!n->IsOp()) continue;
    if (adj_list.find(n) == adj_list.end()) {
      adj_list[n] = std::unordered_set<ir::Node *>();
    }
    for (auto &var : n->outputs) {
      for (auto &adj_n : var->outputs) {
204 205 206 207 208 209
        PADDLE_ENFORCE_EQ(adj_n->NodeType(),
                          ir::Node::Type::kOperation,
                          platform::errors::InvalidArgument(
                              "Node(%s)'s type(%d) must be kOperation type.",
                              adj_n->Name(),
                              static_cast<int>(adj_n->NodeType())));
Y
Yan Chunwei 已提交
210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309
        VLOG(40) << "adj " << adj_n->Name() << reinterpret_cast<void *>(adj_n)
                 << " -> " << n->Name() << reinterpret_cast<void *>(n)
                 << "  via " << var->Name() << reinterpret_cast<void *>(var);
        adj_list[n].insert(adj_n);
      }
    }
  }
  return adj_list;
}

std::vector<ir::Node *> OpDFSSort(const Graph &graph) {
  auto edge_table = BuildOperationOutAdjList(graph);
  std::stack<Node *> stack;
  for (auto &ele : edge_table) {
    if (ele.first->inputs.empty()) {
      // find the input ops (those without input vars)
      stack.push(ele.first);
    } else {
      // find the ops with only persistable vars as inputs.
      bool all_persistable = true;
      for (auto *input : ele.first->inputs) {
        if (!(input->IsVar() && input->Var() && input->Var()->Persistable())) {
          all_persistable = false;
        }
      }
      if (all_persistable) {
        stack.push(ele.first);
      }
    }
  }

  std::vector<Node *> res;
  // start from the feed op and DFS
  std::unordered_set<Node *> unique_set;
  while (!stack.empty()) {
    // will start from the last feed by default.
    auto cur = stack.top();
    stack.pop();
    unique_set.insert(cur);
    res.push_back(cur);

    for (auto *op : edge_table[cur]) {
      if (!unique_set.count(op)) {
        stack.push(op);
      }
    }
  }
  return res;
}

std::vector<ir::Node *> TopologyDfsSortOperations(const Graph &graph) {
  std::vector<ir::Node *> nodes;
  std::unordered_map<Node *, int> in_degree;

  auto set_out_ops_ready = [&](Node *var) {
    for (auto *op : var->outputs) {
      --in_degree[op];
    }
  };
  // build in_degree
  for (auto *node : graph.Nodes()) {
    if (node->IsOp()) {
      in_degree[node] += node->inputs.size();
    } else if (node->IsVar() && node->inputs.empty()) {
      // put all the inputs of the whole graph ready.
      set_out_ops_ready(node);
    }
  }

  std::deque<Node *> op_queue;
  // first visit
  for (auto &node : OpDFSSort(graph)) {
    if (node->IsOp()) {
      op_queue.push_back(node);
    }
  }

  // traverse the graph
  int num_ops = op_queue.size();
  while (num_ops) {
    for (auto it = op_queue.begin(); it != op_queue.end(); it++) {
      auto *&cur_op = *it;
      if (!cur_op || in_degree[cur_op] > 0) continue;
      // visit this node
      // put all the output var of this op valid.
      for (auto *out_var : cur_op->outputs) {
        if (!out_var) continue;
        set_out_ops_ready(out_var);
      }
      VLOG(8) << "visit " << cur_op->Name();
      nodes.push_back(cur_op);

      cur_op = nullptr;
      num_ops--;
    }
  }

  return nodes;
}

C
chengduo 已提交
310
size_t GraphNum(const Graph &graph) {
D
dzhwinter 已提交
311
  std::unordered_set<ir::Node *> nodes(graph.Nodes());
C
chengduo 已提交
312 313 314 315 316
  std::unordered_set<ir::Node *> visited_nodes;
  visited_nodes.reserve(nodes.size());
  std::deque<ir::Node *> q_nodes;
  std::vector<std::unordered_set<ir::Node *>> graph_nodes;
  std::unordered_set<ir::Node *> g_nodes;
W
Wu Yi 已提交
317 318
  // q_set used to record records in the queue.
  std::unordered_set<ir::Node *> q_set;
C
chengduo 已提交
319 320
  size_t graph_count = 0;

321 322 323 324 325 326 327 328 329
  auto traverse_nodes =
      [&visited_nodes, &q_nodes, &q_set](const std::vector<ir::Node *> &nodes) {
        for (auto n : nodes) {
          if (visited_nodes.count(n) == 0 && q_set.count(n) == 0) {
            q_nodes.push_back(n);
            q_set.insert(n);
          }
        }
      };
C
chengduo 已提交
330 331 332 333 334

  while (visited_nodes.size() != nodes.size()) {
    if (!q_nodes.empty()) {
      auto cur_node = q_nodes.front();
      q_nodes.pop_front();
W
Wu Yi 已提交
335
      q_set.erase(cur_node);
C
chengduo 已提交
336 337 338 339 340 341 342 343 344 345 346 347 348
      visited_nodes.insert(cur_node);
      g_nodes.insert(cur_node);
      traverse_nodes(cur_node->inputs);
      traverse_nodes(cur_node->outputs);
    } else {
      ++graph_count;
      if (g_nodes.size()) {
        graph_nodes.emplace_back(g_nodes);
      }
      g_nodes.clear();
      for (auto &n : nodes) {
        if (visited_nodes.count(n) == 0) {
          q_nodes.push_back(n);
W
Wu Yi 已提交
349
          q_set.insert(n);
C
chengduo 已提交
350 351 352 353 354 355 356 357 358 359
          break;
        }
      }
    }
  }

  if (g_nodes.size()) {
    graph_nodes.emplace_back(g_nodes);
  }

C
chengduo 已提交
360 361 362 363 364 365 366 367 368
  if (FLAGS_print_sub_graph_dir.size()) {
    if (graph_nodes.size() > 1) {
      std::stringstream out;
      for (auto &g_n : graph_nodes) {
        out << "graph_nodes: " << g_n.size() << "\n";
      }
      out << "\n\n";
      for (auto &g_n : graph_nodes) {
        out << "graph_nodes: " << g_n.size();
C
chengduo 已提交
369 370 371 372 373 374 375 376 377 378 379
        for (auto &node : g_n) {
          out << "\nNode: " << node->Name() << " in [";
          for (auto &n : node->inputs) {
            out << n->Name() << ", ";
          }
          out << "], out[";
          for (auto &n : node->outputs) {
            out << n->Name() << ", ";
          }
          out << "]";
        }
C
chengduo 已提交
380
        out << "\n\n\n";
C
chengduo 已提交
381
      }
C
chengduo 已提交
382 383
      std::unique_ptr<std::ostream> fout(
          new std::ofstream(FLAGS_print_sub_graph_dir));
384 385
      PADDLE_ENFORCE_EQ(fout->good(),
                        true,
386 387 388
                        platform::errors::Unavailable(
                            "Can not open file %s for printing the graph.",
                            FLAGS_print_sub_graph_dir));
C
chengduo 已提交
389
      *fout << out.str();
C
chengduo 已提交
390 391 392 393 394 395
    }
  }

  return graph_count;
}

Y
Yan Chunwei 已提交
396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418
// Removes from `graph` every node that has neither inputs nor outputs.
void CleanIndividualNodes(Graph *graph) {
  // Collect first, remove afterwards, so the graph is not modified while its
  // node set is being iterated.
  std::unordered_set<Node *> isolated;
  for (auto *candidate : graph->Nodes()) {
    const bool has_no_links =
        candidate->inputs.empty() && candidate->outputs.empty();
    if (has_no_links) {
      isolated.insert(candidate);
    }
  }

  for (auto *victim : isolated) {
    graph->RemoveNode(victim);
  }
}

// Dispatches to an operator sort: SortKind::TS selects
// TopologySortOperations, every other kind selects TopologyDfsSortOperations.
std::vector<Node *> TopologyVarientSort(const Graph &graph,
                                        SortKind sort_kind) {
  if (sort_kind == SortKind::TS) {
    return framework::ir::TopologySortOperations(graph);
  }
  return framework::ir::TopologyDfsSortOperations(graph);
}

419 420
class DescOrderComparator {
 public:
421 422 423 424 425 426 427 428
  bool operator()(Node *const &n1, Node *const &n2) const {
    if (n1->DescOrder() < n2->DescOrder()) {
      return true;
    } else if (n1->DescOrder() == n2->DescOrder()) {
      return n1->id() < n2->id() ||
             (n1->id() == n2->id() && n1->ToString() < n2->ToString());
    }
    return false;
429 430 431 432
  }
};

std::vector<ir::Node *> TopologySortGraphByDescOrder(const Graph &graph) {
433 434
  std::map<ir::Node *,
           std::set<ir::Node *, DescOrderComparator>,
435 436 437
           DescOrderComparator>
      adj_list = BuildOperationAdjList<DescOrderComparator>(graph);
  PADDLE_ENFORCE_EQ(HasCircleInternal<DescOrderComparator>(adj_list, nullptr),
438 439 440
                    false,
                    platform::errors::InvalidArgument(
                        "Generated graph shouldn't contain cycle."));
441 442 443 444 445
  std::unordered_set<ir::Node *> visited;
  std::vector<ir::Node *> ret;
  for (auto adj : adj_list) {
    if (visited.find(adj.first) == visited.end()) {
      SortHelper<DescOrderComparator>(adj_list, adj.first, &visited, &ret);
446 447 448
    }
  }

449
  return ret;
450 451
}

452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472
// Strips every control-dependency var from the inputs and outputs of
// `op_desc`, then flushes the desc. A control-dependency var is any name
// containing ir::Node::kControlDepVarName.
void RemoveControlDepInputAndOuput(OpDesc *op_desc) {
  auto remove_control_dep_var = [](VariableNameMap *var_name_map) {
    for (auto &pair : *var_name_map) {
      std::vector<std::string> &var_names = pair.second;
      auto it = var_names.begin();
      while (it != var_names.end()) {
        if (it->find(ir::Node::kControlDepVarName) != std::string::npos) {
          // Log before erasing. The iterator returned by erase() may be
          // end(), so dereferencing it afterwards is undefined behavior, and
          // otherwise it names the next var rather than the removed one.
          VLOG(6) << "Remove var " << *it;
          it = var_names.erase(it);
        } else {
          ++it;
        }
      }
    }
  };

  remove_control_dep_var(op_desc->MutableInputs());
  remove_control_dep_var(op_desc->MutableOutputs());
  op_desc->Flush();
}

473 474
// Fills `desc` as a fill_constant op standing in for the scale_loss_grad op
// node `node`, and returns `desc`.
//
// - Attributes: shape {1}, value 1.0f, force_cpu false, and the op role
//   kBackward | kLoss.
// - "dtype" is copied from the node's OpHandleBase wrapper, but only when the
//   node is wrapped. That wrapper must be a ScaleLossGradOpHandle; otherwise
//   InvalidArgument is raised.
// - Every output var name of `node` becomes the "Out" output of `desc`.
static OpDesc *ReplaceScaleLossGradOp(const Node &node, OpDesc *desc) {
  desc->SetType("fill_constant");
  desc->SetAttr("shape", std::vector<int64_t>({1}));
  desc->SetAttr("value", 1.0f);

  if (node.IsWrappedBy<details::OpHandleBase>()) {
    details::OpHandleBase &op_handle =
        const_cast<Node *>(&node)->Wrapper<details::OpHandleBase>();
    // dynamic_cast yields nullptr for any other handle type. Fail with a
    // clear error instead of dereferencing a null pointer.
    auto *scale_loss_grad_handle =
        dynamic_cast<details::ScaleLossGradOpHandle *>(&op_handle);
    PADDLE_ENFORCE_NOT_NULL(
        scale_loss_grad_handle,
        platform::errors::InvalidArgument(
            "Op node %s is expected to be wrapped by ScaleLossGradOpHandle.",
            node.Name()));
    desc->SetAttr("dtype", scale_loss_grad_handle->DType());
  }

  desc->SetAttr("force_cpu", false);
  desc->SetAttr(
      OpProtoAndCheckerMaker::OpRoleAttrName(),
      (static_cast<int>(OpRole::kBackward) | static_cast<int>(OpRole::kLoss)));
  // TODO(Ruibiao) : Set OpDeviceAttrName when needed

  std::vector<std::string> output_names;
  output_names.reserve(node.outputs.size());
  for (auto out : node.outputs) {
    output_names.emplace_back(out->Name());
  }
  desc->SetOutput("Out", output_names);
  return desc;
}

500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527
// Copies the "skip_eager_deletion_vars" attribute onto `node`'s OpDesc from
// the matching op of the graph's origin program.
//
// Does nothing when graph_idx == 0. Otherwise, block `graph_idx` of the
// origin program is scanned for ops that:
// - have type `control_type`,
// - carry the attribute, and
// - have the same input and output argument names as `node`'s op.
// If several ops match, the last one's attribute wins.
void UpdateControlOpSkipEagerDeletionVars(const Node &node,
                                          const Graph &graph,
                                          const size_t graph_idx,
                                          const std::string &control_type) {
  // Note(zhangbo): SkipEagerDeletionVars pass policy for control flow class op:
  // 1) if op is in main_block: SkipEagerDeletionVars information will be
  // written into Graph OpNode which wrapped by OpHandleBase; 2) if op is in
  // sub_block: SkipEagerDeletionVars information will be written into graph's
  // OriginProgram OpDesc. Please refer to
  // FindAllConditionalBlockAndConditionalBlockGradOp in
  // "paddle/fluid/operators/controlflow/conditional_block_op_helper.cc"
  if (graph_idx == 0) {
    return;
  }
  // Bind by const reference. A by-value `auto` would copy the whole
  // ProgramDesc (all blocks and ops) on every call. const auto& is also
  // correct if OriginProgram() returns by value.
  const auto &origin_program = graph.OriginProgram();
  const auto &block = origin_program.Block(graph_idx);
  for (size_t j = 0; j < block.OpSize(); ++j) {
    auto *op = block.Op(j);
    if (op->Type() != control_type ||
        !op->HasAttr("skip_eager_deletion_vars")) {
      continue;
    }
    if (op->InputArgumentNames() == node.Op()->InputArgumentNames() &&
        op->OutputArgumentNames() == node.Op()->OutputArgumentNames()) {
      node.Op()->SetAttr("skip_eager_deletion_vars",
                         op->GetAttr("skip_eager_deletion_vars"));
    }
  }
}

528
static void GetGraphOpDesc(const std::vector<Node *> &nodes,
529 530 531
                           std::vector<OpDesc> *ops,
                           const Graph &graph,
                           const size_t graph_idx) {
532 533 534 535 536 537 538 539 540 541 542 543 544
  auto is_fused_opt = [](Node *n) -> bool {
    auto op_type = n->Op()->Type();
    auto is_opt =
        (op_type == "adam" || op_type == "momentum" || op_type == "sgd");
    auto input_names = n->Op()->InputArgumentNames();
    auto contains_fused_var = std::any_of(
        input_names.begin(), input_names.end(), [](std::string name) {
          return name.find(details::kFusedVarNamePrefix) != std::string::npos;
        });
    VLOG(4) << is_opt << " " << contains_fused_var;
    return is_opt && contains_fused_var;
  };

545 546 547 548 549 550
  for (Node *n : nodes) {
    // if node is not Op, skip
    if (!n->IsOp()) continue;

    // create fill_constant op
    if (n->Name() == "scale_loss_grad") {
551
      VLOG(4) << "convert op node scale_loss_grad to desc fill_constant";
552 553 554 555
      ops->emplace_back();
      auto &desc = ops->back();
      ReplaceScaleLossGradOp(*n, &desc);
    } else if (n->Op()) {
556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574
      VLOG(4) << "convert op node to desc " << n->Op()->Type();
      if (is_fused_opt(n)) {
        OpDesc depend_desc(n->Op()->Block());

        std::vector<std::string> deps;
        for (auto in : n->inputs) {
          if (in->IsVar() && !in->IsCtrlVar()) {
            deps.push_back(in->Name());
          }
        }
        depend_desc.SetType("depend");
        depend_desc.SetInput("X",
                             n->Op()->Inputs().at(n->Op()->InputNames()[0]));
        depend_desc.SetInput("Dep", deps);
        depend_desc.SetOutput("Out",
                              n->Op()->Inputs().at(n->Op()->InputNames()[0]));
        ops->emplace_back(depend_desc);
        VLOG(4) << "add depend op";
      }
575 576 577 578 579 580 581
      if (n->Name() == "while" || n->Name() == "while_grad" ||
          n->Name() == "conditional_block" ||
          n->Name() == "conditional_block_grad" || n->Name() == "recurrent" ||
          n->Name() == "recurrent_grad") {
        VLOG(1) << "Update control op attr: skip_eager_deletion_vars";
        UpdateControlOpSkipEagerDeletionVars(*n, graph, graph_idx, n->Name());
      }
582
      ops->emplace_back(*n->Op());
583
      VLOG(4) << n->ToString();
584 585 586 587 588
    }
    // delete no OpDesc op
  }
}

589 590 591 592 593 594 595 596 597 598 599 600
template <class T = Node *>
static void GetGraphVarDesc(const Graph &graph,
                            const std::unordered_set<T> &nodes,
                            std::vector<proto::VarDesc> *vars) {
  for (T node : nodes) {
    if (node->IsVar() && node->Var() &&
        node->GetVarNodeBlockId() == graph.GetBlockId()) {
      vars->emplace_back(*node->Var()->Proto());
    }
  }
}

601 602
static void GraphToBlock(const Graph &graph,
                         proto::BlockDesc *block,
603 604
                         const SortKind *sort_kind,
                         const size_t graph_idx) {
605 606 607 608 609 610 611 612 613
  // Remove the unneeded variables after memory optimization.
  std::unordered_set<std::string> vars2remove;
  if (graph.Has(kGraphToProgramVarsToRemove)) {
    vars2remove =
        graph.Get<std::unordered_set<std::string>>(kGraphToProgramVarsToRemove);
    VLOG(2) << "graph (id: " << block->idx() << ") to program remove "
            << vars2remove.size() << " nodes";
  }

614
  std::vector<proto::VarDesc> vars_in_graph;
615 616 617 618 619
  GetGraphVarDesc<Node *>(graph, graph.Nodes(), &vars_in_graph);
  if (graph.Has(details::kRemovedVars)) {
    auto &removed_vars = graph.Get<details::RemovedVars>(details::kRemovedVars);
    GetGraphVarDesc<std::shared_ptr<ir::Node>>(
        graph, removed_vars, &vars_in_graph);
620 621 622
  }

  // add vars_in_graph to blcok
623 624
  block->clear_vars();
  std::unordered_set<std::string> visited_vars;
625 626 627 628 629 630
  for (proto::VarDesc &var : vars_in_graph) {
    const std::string &var_name = var.name();
    if (visited_vars.find(var_name) == visited_vars.end() &&
        vars2remove.find(var_name) == vars2remove.end()) {
      block->add_vars()->MergeFrom(var);
      visited_vars.insert(var_name);
631 632 633
    }
  }

634
  block->clear_ops();
635 636 637 638 639 640 641 642 643 644 645 646 647
  std::vector<Node *> nodes;
  if (sort_kind != nullptr) {
    // Inference Memory Optimize relays on this branch.
    nodes = TopologyVarientSort(graph, *sort_kind);
  } else {
    if (FLAGS_convert_all_blocks) {
      nodes = TopologySortGraphByDescOrder(graph);
    } else {
      nodes = TopologySortOperations(graph);
    }
  }

  std::vector<OpDesc> ops;
648
  GetGraphOpDesc(nodes, &ops, graph, graph_idx);
649

650
  for (auto &op : ops) {
651
    RemoveControlDepInputAndOuput(&op);
652 653 654 655
    block->add_ops()->MergeFrom(*op.Proto());
  }
}

656 657
void GraphToProgram(const Graph &graph,
                    ProgramDesc *program,
658
                    const SortKind *sort_kind) {
659 660
  PADDLE_ENFORCE_EQ(graph.IsMainGraph(),
                    true,
661 662 663 664 665 666 667 668 669 670 671 672 673
                    platform::errors::InvalidArgument(
                        "This graph is a sub_graph, "
                        "and can't convert to program individually"));
  PADDLE_ENFORCE_NOT_NULL(
      program,
      platform::errors::InvalidArgument(
          "program must not be nullptr when converting graph to program"));

  proto::ProgramDesc program_pb(*(program->Proto()));
  auto block = program_pb.mutable_blocks(kRootBlockIndex);
  block->set_idx(kRootBlockIndex);

  if (FLAGS_convert_all_blocks) {
674 675 676 677
    GraphToBlock(*graph.GetSubGraph(kRootBlockIndex),
                 block,
                 sort_kind,
                 graph.GetSubGraph(kRootBlockIndex)->GetBlockId());
678 679 680 681 682 683 684 685 686

    VLOG(3) << "Graph to program need convert " << graph.SubGraphsSize()
            << " sub graph";
    for (size_t idx = 0; idx < graph.SubGraphsSize(); ++idx) {
      // avoid kRootBlockIndex not 0
      if (idx == kRootBlockIndex) continue;

      block = program_pb.add_blocks();
      block->set_idx(idx);
687
      block->set_parent_idx(kRootBlockIndex);
688 689 690 691
      GraphToBlock(*graph.GetSubGraph(idx),
                   block,
                   sort_kind,
                   graph.GetSubGraph(idx)->GetBlockId());
692 693
    }
  } else {
694
    GraphToBlock(graph, block, sort_kind, graph.GetBlockId());
695 696 697
  }

  program->CopyFrom(program_pb);
698 699 700 701 702 703 704

  if (graph.Has(details::kProgramDescs)) {
    details::ProgramDescs program_descs =
        graph.Get<details::ProgramDescs>(details::kProgramDescs);
    VLOG(8) << "Merge main programs";
    MergePrograms(program, program_descs, /*append=*/false);
  }
705 706
}

707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764
// Computes the pairwise dependency matrix of the ops of `block`.
//
// `nodes` are the graph nodes of that block; op nodes are matched to block
// ops through OpDesc::OriginalId(). Indices i/j are positions in
// block.AllOps(). Entry [i][j] is:
// - kSame on the diagonal;
// - kAfter when op i transitively reads (via var edges) something produced
//   by op j;
// - kBefore for the mirrored case;
// - kNoDep otherwise.
// Ops that never become ready in the topological pass (e.g. on a cycle)
// contribute no entries. Duplicate or unknown op ids raise InvalidArgument.
static std::vector<std::vector<ir::Node::Dep>> GetOpDependencies(
    const BlockDesc &block, const std::unordered_set<ir::Node *> &nodes) {
  const auto block_ops = block.AllOps();
  const size_t op_num = block_ops.size();

  // Direct producer ops, number of producers not yet processed, and direct
  // consumer ops of every op node.
  std::unordered_map<const ir::Node *, std::unordered_set<const ir::Node *>>
      direct_producers(op_num);
  std::unordered_map<const ir::Node *, size_t> unresolved(op_num);
  std::unordered_map<const ir::Node *, std::unordered_set<const ir::Node *>>
      direct_consumers(op_num);

  std::queue<const ir::Node *> ready;
  for (const ir::Node *op_node : nodes) {
    if (!op_node->IsOp()) continue;

    auto &producers = direct_producers[op_node];
    for (const ir::Node *var : op_node->inputs) {
      producers.insert(var->inputs.begin(), var->inputs.end());
    }
    if (producers.empty()) {
      ready.push(op_node);
    }
    unresolved[op_node] = producers.size();

    auto &consumers = direct_consumers[op_node];
    for (const ir::Node *var : op_node->outputs) {
      consumers.insert(var->outputs.begin(), var->outputs.end());
    }
  }

  // Process ops in topological order, accumulating transitive producers.
  std::unordered_map<const ir::Node *, std::unordered_set<const ir::Node *>>
      ancestors;
  while (!ready.empty()) {
    const ir::Node *cur = ready.front();
    ready.pop();

    auto &cur_ancestors = ancestors[cur];
    for (const ir::Node *producer : direct_producers.at(cur)) {
      cur_ancestors.insert(producer);
      const auto &producer_ancestors = ancestors[producer];
      cur_ancestors.insert(producer_ancestors.begin(),
                           producer_ancestors.end());
    }

    for (const ir::Node *consumer : direct_consumers.at(cur)) {
      if (--unresolved.at(consumer) == 0) {
        ready.push(consumer);
      }
    }
  }

  std::unordered_map<uint64_t, size_t> op_id_to_idx(op_num);
  for (size_t pos = 0; pos < op_num; ++pos) {
    const auto *op_desc = block_ops[pos];
    const bool inserted =
        op_id_to_idx.emplace(op_desc->OriginalId(), pos).second;
    PADDLE_ENFORCE_EQ(
        inserted,
        true,
        platform::errors::InvalidArgument(
            "There should not be duplicate op id: %d", op_desc->OriginalId()));
  }

  std::vector<std::vector<ir::Node::Dep>> dep_matrix(
      op_num, std::vector<ir::Node::Dep>(op_num, ir::Node::Dep::kNoDep));
  for (size_t pos = 0; pos < op_num; ++pos) {
    dep_matrix[pos][pos] = ir::Node::Dep::kSame;
  }

  auto get_op_idx_by_id = [&op_id_to_idx](uint64_t op_id) {
    auto iter = op_id_to_idx.find(op_id);
    PADDLE_ENFORCE_NE(iter,
                      op_id_to_idx.end(),
                      platform::errors::InvalidArgument(
                          "Cannot find OpDesc with id %d", op_id));
    return iter->second;
  };

  for (const auto &entry : ancestors) {
    const size_t later = get_op_idx_by_id(entry.first->Op()->OriginalId());
    for (const ir::Node *earlier_node : entry.second) {
      const size_t earlier =
          get_op_idx_by_id(earlier_node->Op()->OriginalId());
      dep_matrix[later][earlier] = ir::Node::Dep::kAfter;
      dep_matrix[earlier][later] = ir::Node::Dep::kBefore;
    }
  }
  return dep_matrix;
}

// Builds an ir::Graph from `program` and returns, for each block in order,
// the op dependency matrix of that block computed from the matching sub
// graph's nodes.
std::vector<std::vector<std::vector<ir::Node::Dep>>> GetOpDependencies(
    const ProgramDesc &program) {
  ir::Graph graph(program);
  std::vector<std::vector<std::vector<ir::Node::Dep>>> deps;
  deps.reserve(program.Size());
  for (size_t block_idx = 0, block_num = program.Size(); block_idx < block_num;
       ++block_idx) {
    const auto &sub_graph_nodes = graph.GetSubGraph(block_idx)->Nodes();
    deps.push_back(GetOpDependencies(program.Block(block_idx), sub_graph_nodes));
  }
  return deps;
}

X
better  
Xin Pan 已提交
811 812 813
}  // namespace ir
}  // namespace framework
}  // namespace paddle