// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/paddle2cinn/cinn_cache_key.h"

#include <algorithm>
#include <functional>
#include <map>
#include <set>
#include <sstream>
#include <string>
#include <unordered_set>

#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/paddle2cinn/transform_type.h"
#include "paddle/phi/core/ddim.h"

namespace paddle {
namespace framework {
namespace paddle2cinn {

using GraphHashStrategy = CinnCacheKey::GraphHashStrategy;

CinnCacheKey::CinnCacheKey(GraphHashStrategy graph_hash)
    : graph_hash_(graph_hash) {}

CinnCacheKey::CinnCacheKey(
    const ir::Graph& graph,
    const std::map<std::string, const phi::DenseTensor*>& input_tensors,
    const std::string& arch_str,
    GraphHashStrategy graph_hash)
    : graph_hash_(graph_hash) {
  this->SetKey(graph, input_tensors, arch_str);
}

CinnCacheKey::CinnCacheKey(const ir::Graph& graph,
                           const std::map<std::string, DDim>& input_shapes,
                           const std::map<std::string, DataType>& input_dtypes,
                           const std::string& arch_str,
                           GraphHashStrategy graph_hash)
    : graph_hash_(graph_hash) {
  this->SetKey(graph, input_shapes, input_dtypes, arch_str);
}

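// Fills the key fields from the given graph and input tensors, reading each
// input's shape and dtype directly from its DenseTensor.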
void CinnCacheKey::SetKey(
    const ir::Graph& graph,
    const std::map<std::string, const phi::DenseTensor*>& input_tensors,
    const std::string& arch_str) {
  graph_hash_val_ = graph_hash_(graph);
  for (const auto& name_tensor : input_tensors) {
    input_shapes_[name_tensor.first] = name_tensor.second->dims();
    input_dtypes_[name_tensor.first] = name_tensor.second->dtype();
  }
  arch_str_ = arch_str;
}

void CinnCacheKey::SetKey(const ir::Graph& graph,
                          const std::map<std::string, DDim>& input_shapes,
                          const std::map<std::string, DataType>& input_dtypes,
                          const std::string& arch_str) {
  PADDLE_ENFORCE_EQ(
      input_shapes.size(),
      input_dtypes.size(),
      platform::errors::PreconditionNotMet(
          "Required input_shapes has same length with input_dtypes."));

  graph_hash_val_ = graph_hash_(graph);
  input_shapes_ = input_shapes;
  input_dtypes_ = input_dtypes;
  arch_str_ = arch_str;
}

bool CinnCacheKey::operator!=(const CinnCacheKey& other) const {
  return !this->operator==(other);
}

bool CinnCacheKey::operator==(const CinnCacheKey& other) const {
  return graph_hash_val_ == other.graph_hash_val_ &&
         input_shapes_ == other.input_shapes_ &&
         input_dtypes_ == other.input_dtypes_ && arch_str_ == other.arch_str_;
}

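// Serializes the input shapes, dtypes, target architecture string and graph
// hash value into one string and returns that string's std::hash value.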
size_t CinnCacheKey::Hash::operator()(const CinnCacheKey& key) const {
  std::ostringstream hash_str;

  for (const auto& name_shape : key.input_shapes_) {
    hash_str << name_shape.first << ",";
    hash_str << "[" << name_shape.second << "],";
    PADDLE_ENFORCE_NE(key.input_dtypes_.find(name_shape.first),
                      key.input_dtypes_.end(),
                      platform::errors::PreconditionNotMet(
                          "%s is not in key.input_dtypes_.", name_shape.first));
    hash_str << key.input_dtypes_.at(name_shape.first) << ";";
  }

  hash_str << key.arch_str_ << ",";
  hash_str << key.graph_hash_val_;
  VLOG(1) << "CinnCacheKey : " << hash_str.str();
  return std::hash<std::string>()(hash_str.str());
}

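// Hashes a graph by structure: the op nodes (sorted by name and id), their
// input/output counts and their non-bookkeeping attributes are serialized
// into a string, and that string's hash is used as the graph hash. Two graphs
// with the same structure therefore share a cache key.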
size_t CinnCacheKeyByStructure::HashGraph(const ir::Graph& graph) {
  // Sort graph op nodes by name, then by id, so the ordering is deterministic.
  auto compare = [](ir::Node* n1, ir::Node* n2) {
    return (n1->Name() == n2->Name()) ? (n1->id() < n2->id())
                                      : (n1->Name() < n2->Name());
  };

  // graph.Nodes() returns an unordered_set; collect the op nodes into an
  // ordered std::set so the same graph always produces the same hash.
  std::set<ir::Node*, bool (*)(ir::Node*, ir::Node*)> node_set(compare);
  for (ir::Node* node : graph.Nodes()) {
    if (node->IsOp()) {
      // only need cache graph with same op
      node_set.insert(node);
    }
  }

  static std::unordered_set<std::string> ignore_attr = {"op_callstack",
                                                        "op_device",
                                                        "op_namescope",
                                                        "op_role",
                                                        "op_role_var",
                                                        "with_quant_attr"};

  std::ostringstream hash_str;
  for (ir::Node* op : node_set) {
    hash_str << op->Name() << ":";
    hash_str << "input_num=" << op->inputs.size() << ",";
    hash_str << "output_num=" << op->outputs.size() << ",";

    const auto& attrs_unordered_map = op->Op()->GetAttrMap();
    std::map<std::string, Attribute> attrs_map(attrs_unordered_map.begin(),
                                               attrs_unordered_map.end());
    for (const auto& attr : attrs_map) {
      if (ignore_attr.count(attr.first)) {
        continue;
      }
      const auto& attr_str = PaddleAttributeToString(attr.second);
      if (!attr_str.empty()) {
        hash_str << attr.first << "=" << attr_str << ",";
      }
    }
    hash_str << ";";
  }

  VLOG(1) << "The hash graph:\n" << hash_str.str();

  size_t hash_val = std::hash<std::string>()(hash_str.str());
  VLOG(4) << "The graph's hash value by graph structure is: " << hash_val;
  return hash_val;
}

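// Hashes a graph by address: the graph object's pointer value is used
// directly, so only the very same graph instance produces a matching key.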
size_t CinnCacheKeyByAddress::HashGraph(const ir::Graph& graph) {
  size_t hash_val = reinterpret_cast<size_t>(&graph);
  VLOG(4) << "The graph's hash value by graph address is: " << hash_val;
  return hash_val;
}

}  // namespace paddle2cinn
}  // namespace framework
}  // namespace paddle