multi_devices_graph_check_pass.cc 3.0 KB
Newer Older
C
chengduoZH 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

X
clean  
Xin Pan 已提交
15
#include <string>
16
#include "paddle/fluid/framework/details/multi_devices_helper.h"
X
clean  
Xin Pan 已提交
17
#include "paddle/fluid/framework/ir/graph.h"
X
Xin Pan 已提交
18
#include "paddle/fluid/framework/ir/graph_helper.h"
C
chengduoZH 已提交
19 20 21 22 23

namespace paddle {
namespace framework {
namespace details {

24 25 26 27 28 29 30
class SSAGraghBuilderWithChecker : public ir::Pass {
 protected:
  std::unique_ptr<ir::Graph> ApplyImpl(
      std::unique_ptr<ir::Graph> graph) const override {
    PADDLE_ENFORCE(IsValidGraph(graph.get()));
    return graph;
  }
C
chengduoZH 已提交
31

32 33 34 35 36
  bool IsValidGraph(const ir::Graph *graph) const {
    std::unordered_map<OpHandleBase *, size_t> pending_ops;
    std::unordered_set<VarHandleBase *> pending_vars;
    std::unordered_set<VarHandleBase *> ready_vars;
    std::unordered_set<OpHandleBase *> ready_ops;
C
chengduoZH 已提交
37

38 39 40 41
    auto insert_pending_var = [&](VarHandleBase *var) {
      pending_vars.insert(var);
      if (var->GeneratedOp() == nullptr) {
        ready_vars.emplace(var);
C
chengduoZH 已提交
42
      }
43
    };
C
chengduoZH 已提交
44

45 46 47 48 49 50 51
    for (auto &var_map : graph->Get<GraphVars>(kGraphVars)) {
      for (auto &name_pair : var_map) {
        for (auto &version_pair : name_pair.second) {
          insert_pending_var(version_pair);
        }
      }
    }
C
chengduoZH 已提交
52

53 54
    for (auto &var : graph->Get<GraphDepVars>(kGraphDepVars)) {
      insert_pending_var(var);
C
chengduoZH 已提交
55 56
    }

57 58 59 60 61
    for (OpHandleBase *op : ir::FilterByNodeWrapper<OpHandleBase>(*graph)) {
      if (op->Inputs().empty()) {
        ready_ops.insert(op);
      } else {
        pending_ops.insert({op, op->NoDupInputSize()});
C
chengduoZH 已提交
62 63 64
      }
    }

65 66 67 68 69 70 71 72
    auto run_all_ops = [&](std::unordered_set<OpHandleBase *> &set) {
      for (auto *op : set) {
        for (auto out : op->Outputs()) {
          ready_vars.emplace(out);
        }
      }
      set.clear();
    };
C
chengduoZH 已提交
73

74 75
    while (!pending_vars.empty()) {
      run_all_ops(ready_ops);
C
chengduoZH 已提交
76

77 78 79 80 81 82 83 84 85 86 87
      if (ready_vars.empty()) {
        return false;
      }

      for (auto ready_var : ready_vars) {
        pending_vars.erase(ready_var);
        for (auto *op : ready_var->PendingOps()) {
          auto &deps = --pending_ops[op];
          if (deps == 0) {
            ready_ops.insert(op);
          }
C
chengduoZH 已提交
88 89
        }
      }
90
      ready_vars.clear();
C
chengduoZH 已提交
91
    }
92
    return true;
C
chengduoZH 已提交
93
  }
94 95
};

C
chengduoZH 已提交
96 97 98
}  // namespace details
}  // namespace framework
}  // namespace paddle
X
Xin Pan 已提交
99

X
Xin Pan 已提交
100
// Registers SSAGraghBuilderWithChecker as "multi_devices_check_pass".
// IsValidGraph() reads the kGraphVars and kGraphDepVars graph attributes, so
// the pass declares both as required.
REGISTER_PASS(multi_devices_check_pass,
              paddle::framework::details::SSAGraghBuilderWithChecker)
    .RequireGraphAttr(paddle::framework::details::kGraphVars)
    .RequireGraphAttr(paddle::framework::details::kGraphDepVars);