multi_devices_graph_print_pass.cc 2.9 KB
Newer Older
Y
yuyang18 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

X
Xin Pan 已提交
15
#include "paddle/fluid/framework/details/multi_devices_graph_print_pass.h"
Y
yuyang18 已提交
16
#include <string>
X
clean  
Xin Pan 已提交
17
#include "paddle/fluid/framework/ir/graph.h"
X
Xin Pan 已提交
18
#include "paddle/fluid/framework/ir/graph_helper.h"
Y
yuyang18 已提交
19 20 21 22 23 24

namespace paddle {
namespace framework {
namespace details {

template <typename Callback>
X
Xin Pan 已提交
25
static inline void IterAllVar(const ir::Graph &graph, Callback callback) {
X
Xin Pan 已提交
26
  for (auto &each : graph.Get<GraphVars>(kGraphVars)) {
Y
yuyang18 已提交
27 28 29 30 31 32 33
    for (auto &pair1 : each) {
      for (auto &pair2 : pair1.second) {
        callback(*pair2);
      }
    }
  }

X
Xin Pan 已提交
34
  for (auto &var : graph.Get<GraphDepVars>(kGraphDepVars)) {
Y
yuyang18 已提交
35 36 37 38
    callback(*var);
  }
}

X
Xin Pan 已提交
39
// Writes `graph` to `sout` in Graphviz DOT syntax ("digraph G { ... }").
//
// Output layout:
//   * one "var_<id>" node per variable handle visited by IterAllVar; ids are
//     handed out sequentially in visiting order. VarHandle nodes are labelled
//     with name, place, scope index and version; DummyVarHandle nodes are
//     labelled "dummy". Handles of any other dynamic type still consume an id
//     but get no explicit node statement.
//   * one rectangular "op_<id>" node per OpHandleBase found in the graph,
//     followed by an edge from each of its inputs and an edge to each of its
//     outputs.
//
// The stream is not flushed by this function; lines are terminated with '\n'
// instead of std::endl so that large graphs do not pay for a flush per line.
// Callers that need the data on disk immediately must flush `sout` themselves.
void GraphvizSSAGraphPrinter::Print(const ir::Graph &graph,
                                    std::ostream &sout) const {
  size_t var_id = 0;
  std::unordered_map<const VarHandleBase *, size_t> vars;

  sout << "digraph G {\n";

  IterAllVar(graph, [&](const VarHandleBase &var) {
    const VarHandleBase *var_ptr = &var;

    size_t cur_var_id = var_id++;
    vars[var_ptr] = cur_var_id;

    if (auto *var_handle_ptr = dynamic_cast<const VarHandle *>(var_ptr)) {
      // "\\n" emits the two characters '\' 'n', i.e. a DOT label line break.
      sout << "var_" << cur_var_id << " [label=\"" << var_handle_ptr->name_
           << "\\n"
           << var_handle_ptr->place_ << "\\n"
           << "scope: " << var_handle_ptr->scope_idx_ << "\\n"
           << "v" << var_handle_ptr->version_ << "\"]" << '\n';
    } else if (dynamic_cast<const DummyVarHandle *>(var_ptr) != nullptr) {
      sout << "var_" << cur_var_id << " [label=\"dummy\"]" << '\n';
    }
  });

  // Maps a handle to its DOT node name. A handle that IterAllVar never
  // reported is given a fresh id of its own: looking it up with operator[]
  // would default-insert 0 and wrongly attach the edge to "var_0".
  auto var_node_name = [&](const VarHandleBase *handle) -> std::string {
    auto lookup = vars.emplace(handle, var_id);
    if (lookup.second) {
      ++var_id;
    }
    return "var_" + std::to_string(lookup.first->second);
  };

  size_t op_id = 0;
  for (auto &op : ir::FilterByNodeWrapper<OpHandleBase>(graph)) {
    std::string op_name = "op_" + std::to_string(op_id++);
    sout << op_name << " [label=\"" << op->Name() << "\", shape=rect]"
         << '\n';
    for (auto in : op->Inputs()) {
      sout << var_node_name(in) << " -> " << op_name << '\n';
    }

    for (auto out : op->Outputs()) {
      sout << op_name << " -> " << var_node_name(out) << '\n';
    }
  }

  sout << "}\n";
}
}  // namespace details
}  // namespace framework
}  // namespace paddle
X
Xin Pan 已提交
86

X
Xin Pan 已提交
87
// Associates the pass name `multi_devices_print_pass` with the class below.
// NOTE(review): REGISTER_PASS is defined outside this file; presumably it
// adds the class to the framework's pass registry - confirm against the macro.
// The spelling "SSAGragh" must match the class declaration, so it is kept
// verbatim here.
REGISTER_PASS(multi_devices_print_pass,
              paddle::framework::details::SSAGraghBuilderWithPrinter);