cinn_cache_key.cc 4.2 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/paddle2cinn/cinn_cache_key.h"

17 18
#include <algorithm>
#include <functional>
19
#include <map>
20
#include <set>
21
#include <sstream>
22 23 24 25
#include <string>

#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/lod_tensor.h"
26
#include "paddle/phi/core/ddim.h"
27 28 29 30 31

namespace paddle {
namespace framework {
namespace paddle2cinn {

J
jiangcheng 已提交
32 33 34 35 36
// File-local shorthand for the graph-hashing callback type declared on
// CinnCacheKey; it is invoked as graph_hash_(graph) when a key is set.
using GraphHashStrategy = CinnCacheKey::GraphHashStrategy;

// Builds a key that only stores the graph-hashing callback. The graph hash
// value, the input shapes and the arch string are not set by this constructor.
CinnCacheKey::CinnCacheKey(GraphHashStrategy graph_hash)
    : graph_hash_(graph_hash) {}

37 38
// Builds a fully populated key from a graph, its input tensors and the target
// architecture string.
//
// graph_hash is stored first so that SetKey can use it to hash `graph`; the
// shape of every tensor in `input_tensors` is then recorded under its name.
CinnCacheKey::CinnCacheKey(
    const ir::Graph& graph,
    const std::map<std::string, const phi::DenseTensor*>& input_tensors,
    const std::string& arch_str,
    GraphHashStrategy graph_hash)
    : graph_hash_(graph_hash) {
  // Delegate to the tensor overload of SetKey to fill in the remaining fields.
  SetKey(graph, input_tensors, arch_str);
}

// Builds a fully populated key from a graph, already-known input shapes and
// the target architecture string.
//
// graph_hash is stored first so that SetKey can use it to hash `graph`.
CinnCacheKey::CinnCacheKey(const ir::Graph& graph,
                           const std::map<std::string, DDim>& input_shapes,
                           const std::string& arch_str,
                           GraphHashStrategy graph_hash)
    : graph_hash_(graph_hash) {
  // Delegate to the shape overload of SetKey to fill in the remaining fields.
  SetKey(graph, input_shapes, arch_str);
}

// Replaces the contents of this key with the given graph, the shapes of the
// given input tensors and the target architecture string.
//
// graph:         hashed with the stored graph_hash_ callback.
// input_tensors: name -> tensor; only each tensor's dims() is recorded.
//                Every pointer is dereferenced, so none may be null.
// arch_str:      copied verbatim into the key.
void CinnCacheKey::SetKey(
    const ir::Graph& graph,
    const std::map<std::string, const phi::DenseTensor*>& input_tensors,
    const std::string& arch_str) {
  graph_hash_val_ = graph_hash_(graph);

  // Discard shapes left over from an earlier SetKey call. Without this, names
  // absent from `input_tensors` would survive in the key, whereas the shape
  // overload of SetKey replaces the whole map.
  input_shapes_.clear();
  for (const auto& name_tensor : input_tensors) {
    input_shapes_[name_tensor.first] = name_tensor.second->dims();
  }

  arch_str_ = arch_str;
}

// Replaces the contents of this key with the given graph, input shapes and
// target architecture string.
//
// graph:        hashed with the stored graph_hash_ callback.
// input_shapes: name -> shape; copied wholesale, replacing any earlier shapes.
// arch_str:     copied verbatim into the key.
void CinnCacheKey::SetKey(const ir::Graph& graph,
                          const std::map<std::string, DDim>& input_shapes,
                          const std::string& arch_str) {
  // Hash the graph first; the two copies below do not depend on each other.
  graph_hash_val_ = graph_hash_(graph);
  arch_str_ = arch_str;
  input_shapes_ = input_shapes;
}

// Two keys differ exactly when operator== reports them as not equal.
bool CinnCacheKey::operator!=(const CinnCacheKey& other) const {
  return !(*this == other);
}

// Two keys are equal when their graph hash values, their recorded input
// shapes and their architecture strings all match. The graph hash strategy
// itself is not compared.
bool CinnCacheKey::operator==(const CinnCacheKey& other) const {
  if (graph_hash_val_ != other.graph_hash_val_) {
    return false;
  }
  if (input_shapes_ != other.input_shapes_) {
    return false;
  }
  return arch_str_ == other.arch_str_;
}

// Hashes a key by serializing its fields into one string and hashing that
// string with std::hash<std::string>.
//
// Serialization order: every (name, shape.to_str()) pair of input_shapes_ in
// the container's iteration order, then the graph hash value, then the arch
// string. No separators are written between the pieces.
size_t CinnCacheKey::Hash::operator()(const CinnCacheKey& key) const {
  std::ostringstream serialized;

  for (const auto& entry : key.input_shapes_) {
    serialized << entry.first << entry.second.to_str();
  }
  serialized << key.graph_hash_val_ << key.arch_str_;

  const std::string digest_input = serialized.str();
  return std::hash<std::string>()(digest_input);
}

J
jiangcheng 已提交
95 96 97 98 99 100 101 102 103
// Hashes a graph by its structure: the names of its nodes and, for each node,
// the names of its output nodes, visited in a deterministic order.
//
// Returns std::hash<std::string> of the concatenated names.
// NOTE(review): names are concatenated without separators, so distinct
// name sequences can yield the same string — confirm this is acceptable.
size_t CinnCacheKeyByStructure::HashGraph(const ir::Graph& graph) {
  using NodeOrder = bool (*)(ir::Node*, ir::Node*);

  // Order graph nodes by name, breaking ties between equal names by id.
  NodeOrder by_name_then_id = [](ir::Node* lhs, ir::Node* rhs) -> bool {
    if (lhs->Name() == rhs->Name()) {
      return lhs->id() < rhs->id();
    }
    return lhs->Name() < rhs->Name();
  };

  // graph.Nodes() is an unordered container, so its iteration order is not
  // stable; copy the nodes into an ordered set so that the same graph always
  // produces the same string.
  std::set<ir::Node*, NodeOrder> ordered_nodes(
      graph.Nodes().begin(), graph.Nodes().end(), by_name_then_id);

  std::string hash_str;
  for (ir::Node* node : ordered_nodes) {
    hash_str += node->Name();

    // Outputs are sorted with the same ordering before being appended.
    std::set<ir::Node*, NodeOrder> ordered_outputs(
        node->outputs.begin(), node->outputs.end(), by_name_then_id);
    for (ir::Node* output : ordered_outputs) {
      hash_str += output->Name();
    }
  }

  VLOG(1) << "The hash graph:\n" << hash_str;

  const size_t hash_val = std::hash<std::string>()(hash_str);
  VLOG(4) << "The graph's hash value by graph structure is: " << hash_val;
  return hash_val;
}

// Hashes a graph by identity: the hash value is the address of the Graph
// object reinterpreted as a size_t, so only the very same object matches.
size_t CinnCacheKeyByAddress::HashGraph(const ir::Graph& graph) {
  const ir::Graph* graph_addr = &graph;
  size_t addr_hash = reinterpret_cast<size_t>(graph_addr);
  VLOG(4) << "The graph's hash value by graph address is: " << addr_hash;
  return addr_hash;
}

132 133 134
}  // namespace paddle2cinn
}  // namespace framework
}  // namespace paddle