// graph_visualize_pass.cc (4.3 KB) -- last committed by Yan Chunwei.
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "lite/core/mir/graph_visualize_pass.h"
#include <map>
#include <memory>
#include <set>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
#include "lite/core/mir/pass_registry.h"
#include "lite/utils/string.h"

namespace paddle {
namespace lite {
namespace mir {

void GraphVisualizePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
30
  VLOG(5) << "\n" << Visualize(graph.get());
Y
Yan Chunwei 已提交
31 32 33
}

/*
 * Render `graph` as a human-readable debug dump.
 *
 * The returned text has two parts:
 *   1. For every statement node, in the order returned by
 *      `graph->StmtTopologicalOrder()`, a line "* <op_name>" followed by one
 *      line " - <attr_name>:<type>:<value>" per attribute of that op.
 *   2. The output of `Dot::Build()` for a graph in which every op is a
 *      filled yellow box and every input/output variable is a plain node,
 *      with edges var -> op for inputs and op -> var for outputs.
 *
 * Op node names are "<op_type><index>", where index counts the statement
 * nodes visited so far, so repeated op types get distinct names.
 * String and list attribute values longer than 100 characters are cut and
 * suffixed with "..." to keep the dump readable.
 *
 * NOTE(review): op names and variable names are both passed to Dot as node
 * names; a variable literally named like "<op_type><index>" would presumably
 * alias an op node -- confirm against Dot's implementation.
 */
std::string Visualize(mir::SSAGraph* graph) {
  std::ostringstream os;
  Dot dot;

  // Cut overly long values so a single attribute cannot flood the output.
  auto string_trunc = [](const std::string& str) -> std::string {
    constexpr std::size_t kMaxDispSize = 100;
    if (str.length() > kMaxDispSize) {
      return str.substr(0, kMaxDispSize) + "...";
    }
    return str;
  };

  // Format one attribute as ":<type>:<value>" (appended after its name).
  auto attr_repr = [&](const OpInfo* op_info,
                       const std::string& attr_name) -> std::string {
    std::ostringstream ss;  // local stream; does not shadow the outer `os`
    using AttrType = cpp::OpDesc::AttrType;
    const auto attr_type = op_info->GetAttrType(attr_name);
    switch (attr_type) {
      case AttrType::INT:
        ss << ":int:"
           << paddle::lite::to_string(op_info->GetAttr<int>(attr_name));
        break;
      case AttrType::FLOAT:
        ss << ":float:"
           << paddle::lite::to_string(op_info->GetAttr<float>(attr_name));
        break;
      case AttrType::BOOLEAN:
        // Label with the real attribute type (was mislabelled ":int:").
        ss << ":bool:"
           << paddle::lite::to_string(op_info->GetAttr<bool>(attr_name));
        break;
      case AttrType::STRING:
        ss << ":string: \""
           << string_trunc(op_info->GetAttr<std::string>(attr_name)) << "\"";
        break;
      case AttrType::FLOATS: {
        const auto vals = op_info->GetAttr<std::vector<float>>(attr_name);
        ss << ":floats: {" << string_trunc(Join(vals, ",")) << "}";
      } break;
      case AttrType::INTS: {
        const auto vals = op_info->GetAttr<std::vector<int>>(attr_name);
        ss << ":ints: {" << string_trunc(Join(vals, ",")) << "}";
      } break;
      case AttrType::STRINGS: {
        const auto vals =
            op_info->GetAttr<std::vector<std::string>>(attr_name);
        ss << ":strings: {" << string_trunc(Join(vals, ",")) << "}";
      } break;
      default:
        ss << ":Unknown type(" << static_cast<int>(attr_type) << ")";
        break;
    }
    return ss.str();
  };

  // Each variable becomes a Dot node exactly once, even if it is read or
  // written by several ops. set::insert reports novelty in a single lookup.
  std::set<std::string> exists_var_names;
  auto add_var_node = [&](const std::string& var_name) {
    if (exists_var_names.insert(var_name).second) {
      dot.AddNode(var_name, {});
    }
  };

  int op_idx = 0;
  for (auto& node : graph->StmtTopologicalOrder()) {
    if (!node->IsStmt()) continue;
    auto op_info = node->AsStmt().op_info();
    const std::string op_name =
        string_format("%s%d", op_info->Type().c_str(), op_idx++);

    // The op itself, then its input/output variables and the connecting
    // edges.
    dot.AddNode(op_name,
                {Dot::Attr("shape", "box"),
                 Dot::Attr("style", "filled"),
                 Dot::Attr("color", "black"),
                 Dot::Attr("fillcolor", "yellow")});
    for (auto& x : node->inlinks) {
      const auto& var_name = x->AsArg().name;
      add_var_node(var_name);
      dot.AddEdge(var_name, op_name, {});
    }
    for (auto& x : node->outlinks) {
      const auto& var_name = x->AsArg().name;
      add_var_node(var_name);
      dot.AddEdge(op_name, var_name, {});
    }

    // List all of the op's attributes (name and value).
    os << "* " << op_name << "\n";
    for (const auto& attr_name : op_info->AttrNames()) {
      os << " - " << attr_name << attr_repr(op_info, attr_name) << "\n";
    }
  }

  os << dot.Build();
  return os.str();
}

}  // namespace mir
}  // namespace lite
}  // namespace paddle

126
REGISTER_MIR_PASS(graph_visualize_pass, paddle::lite::mir::GraphVisualizePass)
127
    .BindTargets({TARGET(kAny)});