// visualize_helper.h
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <absl/container/flat_hash_map.h>
#include <sys/stat.h>
#include <unistd.h>

#include <cstdint>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "paddle/cinn/frontend/syntax.h"
#include "paddle/cinn/hlir/framework/graph.h"
#include "paddle/cinn/utils/dot_lang.h"

namespace cinn {
namespace hlir {
namespace framework {

class PassPrinter {
 public:
  static PassPrinter* GetInstance() {
    static PassPrinter printer;
    return &printer;
  }

  bool Begin(const std::unordered_set<std::string>& fetch_ids = {});
44 45
  bool PassBegin(const std::string& pass_name,
                 const frontend::Program& program);
46 47 48 49 50 51 52 53 54 55 56 57
  bool PassEnd(const std::string& pass_name, const frontend::Program& program);
  bool PassBegin(const std::string& pass_name, Graph* g);
  bool PassEnd(const std::string& pass_name, Graph* g);
  bool End();

 private:
  std::unordered_set<std::string> fetch_ids_;
  std::string save_path_;
  int64_t graph_id_{0};
  int64_t pass_id_{0};
};

58 59
inline void WriteToFile(const std::string& filepath,
                        const std::string& content) {
60 61 62 63 64 65 66 67
  VLOG(4) << "Write to " << filepath;
  std::ofstream of(filepath);
  CHECK(of.is_open()) << "Failed to open " << filepath;
  of << content;
  of.close();
}

// Builds the DOT cluster id for group number `group_id`,
// e.g. "group_3(size=5)" for a 5-node group.
inline std::string GenClusterId(const std::vector<Node*>& group, int group_id) {
  return "group_" + std::to_string(group_id) + "(size=" +
         std::to_string(group.size()) + ")";
}

// Returns the DOT id of op `node`.
//
// With `is_recomputed` set, the id gets the suffix "/<recompute_id>" so that
// several copies of the same node can coexist in one DOT graph (presumably
// one copy per group that recomputes it — see FindRecomputeNodes).
inline std::string GenNodeId(const Node* node,
                             bool is_recomputed,
                             int recompute_id) {
  if (!is_recomputed) {
    return node->id();
  }
  return node->id() + "/" + std::to_string(recompute_id);
}

// Returns the DOT id of variable `data`.
//
// Follows the same scheme as GenNodeId: the plain data id, or
// "<id>/<recompute_id>" when `is_recomputed` is true.
inline std::string GenNodeDataId(const NodeData* data,
                                 bool is_recomputed,
                                 int recompute_id) {
  std::string dot_id = data->id();
  if (is_recomputed) {
    dot_id += "/";
    dot_id += std::to_string(recompute_id);
  }
  return dot_id;
}

inline std::vector<utils::DotAttr> GetGroupOpAttrs(bool is_recomputed = false) {
  std::string color = is_recomputed ? "#836FFF" : "#8EABFF";
94 95 96
  return std::vector<utils::DotAttr>{utils::DotAttr("shape", "Mrecord"),
                                     utils::DotAttr("color", color),
                                     utils::DotAttr("style", "filled")};
97 98 99
}

inline std::vector<utils::DotAttr> GetOutlinkOpAttrs() {
100 101 102
  return std::vector<utils::DotAttr>{utils::DotAttr("shape", "Mrecord"),
                                     utils::DotAttr("color", "#ff7f00"),
                                     utils::DotAttr("style", "filled")};
103 104 105 106
}

inline std::vector<utils::DotAttr> GetGroupVarAttrs(bool is_fetched = false) {
  if (is_fetched) {
107 108 109
    return std::vector<utils::DotAttr>{utils::DotAttr("peripheries", "2"),
                                       utils::DotAttr("color", "#43CD80"),
                                       utils::DotAttr("style", "filled")};
110
  } else {
111 112
    return std::vector<utils::DotAttr>{utils::DotAttr("color", "#FFDC85"),
                                       utils::DotAttr("style", "filled")};
113 114 115 116 117 118 119 120 121 122 123 124 125 126 127
  }
}

inline std::vector<utils::DotAttr> GetGroupAttrs(size_t group_size) {
  std::string fillcolor;
  if (group_size == 1) {
    fillcolor = "#E8E8E8";
  } else if (group_size <= 3) {
    fillcolor = "#FFFFF0";
  } else if (group_size <= 10) {
    fillcolor = "#F0FFFF";
  } else {
    // group_size > 10
    fillcolor = "#EEE5DE";
  }
128 129 130
  std::vector<utils::DotAttr> attrs = {utils::DotAttr("color", "grey"),
                                       utils::DotAttr("style", "filled"),
                                       utils::DotAttr("fillcolor", fillcolor)};
131 132 133 134 135 136 137 138 139
  return attrs;
}

bool MakeDirectory(const std::string& dirname, mode_t mode);

std::string GetFilePathForGroup(const std::vector<std::vector<Node*>>& groups,
                                const int group_id,
                                const std::string& viz_path);

140 141 142 143 144
std::string GenNodeDataLabel(
    const NodeData* node,
    const absl::flat_hash_map<std::string, shape_t>& shape_dict,
    const absl::flat_hash_map<std::string, common::Type>& dtype_dict,
    const std::string dot_nodedata_id);
145

146 147
void Summary(const std::vector<std::vector<Node*>>& groups,
             const std::string& viz_path);
148 149 150 151 152 153

std::string DebugString(const Node* node);

void FindRecomputeNodes(const std::vector<std::vector<Node*>>& groups,
                        std::unordered_map<std::string, int>* recompute_nodes);

154 155 156 157 158 159 160 161 162 163
void AddGroupNode(
    const Node* node,
    const std::string& dot_cluster_id,
    const std::unordered_set<std::string>& fetch_var_ids,
    const absl::flat_hash_map<std::string, shape_t>& shape_dict,
    const absl::flat_hash_map<std::string, common::Type>& dtype_dict,
    std::unordered_map<std::string, int>* recompute_nodes,
    std::unordered_map<std::string, std::string>* outnode2dot_id,
    std::unordered_set<std::string>* nodedatas_set,
    utils::DotLang* dot);
164 165 166 167 168 169 170 171

// Helpers used by CheckFusionAccuracyPass.

// Returns the id of the accuracy-check counterpart of `node_id`.
std::string GenerateAccCheckNodeId(const std::string& node_id);

// Predicates: whether an op, a variable, or a whole group belongs to the
// accuracy-check machinery. The exact criterion lives in the definitions.
bool IsAccCheckOp(const Node* op);
bool IsAccCheckVar(const NodeData* var);
bool IsAccCheckGroup(const std::vector<Node*>& group);

// Returns a new list of groups with the accuracy-check groups removed
// (presumably those for which IsAccCheckGroup() is true). `groups` itself is
// not modified.
std::vector<std::vector<Node*>> RemoveAccCheckGroups(
    const std::vector<std::vector<Node*>>& groups);

}  // namespace framework
}  // namespace hlir
}  // namespace cinn