// graph_visualize_pass.cc
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "lite/core/mir/graph_visualize_pass.h"
#include <map>
#include <memory>
#include <set>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
#include "lite/core/mir/pass_registry.h"
#include "lite/utils/string.h"

namespace paddle {
namespace lite {
namespace mir {

void GraphVisualizePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
30
  VLOG(5) << "\n" << Visualize(graph.get());
Y
Yan Chunwei 已提交
31 32 33
}

// Renders `graph` as a human-readable attribute dump followed by a DOT
// description built with `Dot`.
//
// Statement nodes are visited in graph->StmtTopologicalOrder(). For each one:
//   * a "* <op_name>" line is written, followed by one " - <attr><repr>" line
//     per attribute of the op (see attr_repr below for the <repr> format);
//   * a filled yellow box node named <op_name> is added to the DOT graph and
//     connected to its input variables (var -> op) and output variables
//     (op -> var). Each variable node is added to the DOT graph only once.
//
// <op_name> is "<op type><running index>". When the statement needs stream
// synchronization it is extended with ", stream=<id>, sync_streams={a,b,...}".
// Variable names are extended with ", lane=<n>" when their lane is not -1.
//
// @param graph  graph to render; dereferenced unconditionally, so it must not
//               be null.
// @return the attribute dump followed by the output of Dot::Build().
std::string Visualize(mir::SSAGraph* graph) {
  std::ostringstream os;
  Dot dot;
  std::set<std::string> exists_var_names;

  // Clips long strings so a single attribute cannot flood the dump.
  auto string_trunc = [](const std::string& str) -> std::string {
    const std::size_t kMaxDispSize = 100;
    if (str.length() > kMaxDispSize) {
      return str.substr(0, kMaxDispSize) + "...";
    }
    return str;
  };

  // Formats one attribute of `op_info` as ":<type>:<value>".
  auto attr_repr = [&](const OpInfo* op_info,
                       const std::string& attr_name) -> std::string {
    std::ostringstream repr;  // local name avoids shadowing the outer `os`
    using AttrType = cpp::OpDesc::AttrType;
    auto attr_type = op_info->GetAttrType(attr_name);
    switch (attr_type) {
      case AttrType::INT:
        repr << ":int:"
             << paddle::lite::to_string(op_info->GetAttr<int>(attr_name));
        break;
      case AttrType::FLOAT:
        repr << ":float:"
             << paddle::lite::to_string(op_info->GetAttr<float>(attr_name));
        break;
      case AttrType::BOOLEAN:
        // Labelled with its real type (this case previously printed ":int:").
        repr << ":bool:"
             << paddle::lite::to_string(op_info->GetAttr<bool>(attr_name));
        break;
      case AttrType::STRING:
        repr << ":string: \""
             << string_trunc(op_info->GetAttr<std::string>(attr_name))
             << "\"";
        break;
      case AttrType::FLOATS: {
        std::vector<float> vals =
            op_info->GetAttr<std::vector<float>>(attr_name);
        repr << ":floats: {" << Join(vals, ",") << "}";
      } break;
      case AttrType::INTS: {
        std::vector<int> vals = op_info->GetAttr<std::vector<int>>(attr_name);
        repr << ":ints: {" << Join(vals, ",") << "}";
      } break;
      case AttrType::STRINGS: {
        std::vector<std::string> vals =
            op_info->GetAttr<std::vector<std::string>>(attr_name);
        repr << ":strings: {" << string_trunc(Join(vals, ",")) << "}";
      } break;
      default:
        repr << ":Unknown type(" << static_cast<int>(attr_type) << ")";
        break;
    }
    return repr.str();
  };

  // Display name of a variable: its name, plus ", lane=<n>" when a lane is set
  // (-1 means "no lane").
  auto var_label = [](const std::string& name, int lane) -> std::string {
    if (lane == -1) {
      return name;
    }
    return string_format("%s, lane=%d", name.c_str(), lane);
  };

  // Adds a DOT node for `var_name` the first time that name is seen.
  auto add_var_node = [&](const std::string& var_name) {
    if (exists_var_names.insert(var_name).second) {
      dot.AddNode(var_name, {});
    }
  };

  int op_idx = 0;
  for (auto& node : graph->StmtTopologicalOrder()) {
    if (!node->IsStmt()) continue;
    auto&& stmt = node->AsStmt();
    auto op_info = stmt.op_info();
    auto op_type = op_info->Type();
    const int cur_idx = op_idx++;

    std::string op_name;
    if (stmt.need_sync_) {
      // Comma-separated list of the streams this statement synchronizes with.
      std::ostringstream streams;
      for (size_t i = 0; i < stmt.sync_streams_.size(); ++i) {
        if (i != 0) {
          streams << ",";
        }
        streams << std::to_string(stmt.sync_streams_[i]);
      }
      op_name = string_format("%s%d, stream=%d, sync_streams={%s}",
                              op_type.c_str(),
                              cur_idx,
                              stmt.stream_id_,
                              streams.str().c_str());
    } else {
      op_name = string_format("%s%d", op_type.c_str(), cur_idx);
    }

    // The op itself, then its input & output variables, as DOT nodes/edges.
    dot.AddNode(op_name,
                {Dot::Attr("shape", "box"),
                 Dot::Attr("style", "filled"),
                 Dot::Attr("color", "black"),
                 Dot::Attr("fillcolor", "yellow")});
    for (auto& x : node->inlinks) {
      const std::string var_name =
          var_label(x->AsArg().name, x->AsArg().lane);
      add_var_node(var_name);
      dot.AddEdge(var_name, op_name, {});
    }
    for (auto& x : node->outlinks) {
      const std::string var_name =
          var_label(x->AsArg().name, x->AsArg().lane);
      add_var_node(var_name);
      dot.AddEdge(op_name, var_name, {});
    }

    // Dump all of the op's attributes (name and value).
    os << "* " << op_name << "\n";
    for (const auto& attr_name : op_info->AttrNames()) {
      os << " - " << attr_name << attr_repr(op_info, attr_name) << "\n";
    }
  }

  os << dot.Build();
  return os.str();
}

}  // namespace mir
}  // namespace lite
}  // namespace paddle

156
REGISTER_MIR_PASS(graph_visualize_pass, paddle::lite::mir::GraphVisualizePass)
157
    .BindTargets({TARGET(kAny)});