search_state.cc 5.2 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/cinn/auto_schedule/search_space/search_state.h"

#include <memory>
#include <sstream>
#include <utility>
#include <vector>

#include "paddle/cinn/ir/ir_base.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/ir_schedule.h"
#include "paddle/cinn/ir/ir_visitor.h"
#include "paddle/cinn/utils/functional.h"
#include "paddle/cinn/utils/string.h"

namespace cinn {
namespace auto_schedule {

32 33 34
// Creates a search state backed by a newly allocated _SearchState_ node and
// populates that node from the arguments.
//
//   ir_sch - schedule, taken by value and moved into the node.
//   cost   - stored as the node's predicted_cost.
//   rules  - copied into the node's applicable_rules.
SearchState::SearchState(ir::IRSchedule ir_sch,
                         float cost,
                         const std::vector<AutoGenRule*>& rules)
    : common::Shared<_SearchState_>(common::make_shared<_SearchState_>()) {
  auto* node = get();
  node->predicted_cost = cost;
  node->applicable_rules = rules;
  // The by-value parameter is no longer needed here, so hand it over.
  node->ir_schedule = std::move(ir_sch);
}

42 43 44
// Returns a new SearchState constructed from this state's ir_schedule and
// predicted_cost. The rule list passed to the new state is empty, i.e.
// applicable_rules is not carried over.
SearchState SearchState::Copy() const {
  const SearchState& self = *this;
  const std::vector<AutoGenRule*> no_rules;
  return SearchState(self->ir_schedule, self->predicted_cost, no_rules);
}
45 46 47 48 49

std::string _SearchState_::DebugString() const {
  const auto& exprs = ir_schedule.GetModule().GetExprs();
  std::stringstream module_stream;
  for (auto i = 0; i < exprs.size(); ++i) {
50 51
    module_stream << "Expr " << i << " {\n"
                  << exprs.at(i) << "\n}  // end Expr";
52 53 54 55 56 57 58 59 60 61 62
  }

  const char* fmt_str = R"ROC(
ModuleExpr {
%s
} // end ModuleExpr
ScheduleDesc {
%s
} // end ScheduleDesc
predicted_cost: %f)ROC";

63 64 65 66
  return utils::StringFormat(fmt_str,
                             module_stream.str().c_str(),
                             ir_schedule.GetTraceDesc().DebugString().c_str(),
                             predicted_cost);
67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103
}

// Orders search states by predicted cost: `left` precedes `right` when its
// predicted_cost is strictly smaller.
bool operator<(const SearchState& left, const SearchState& right) {
  const auto& left_cost = left->predicted_cost;
  const auto& right_cost = right->predicted_cost;
  return left_cost < right_cost;
}

// Visitor that walks an IR tree depth-first: for every node type listed by
// NODETY_FORALL it recurses into each defined entry of the node's
// expr_fields().
class DfsWithExprsFields : public ir::IRVisitor {
 protected:
// Generates the Visit override for one node type. Undefined fields are
// skipped; defined ones are dispatched through Visit(const Expr*).
#define DFS_VISIT_EXPR_FIELDS(NodeTy)        \
  void Visit(const ir::NodeTy* x) override { \
    for (auto* field : x->expr_fields()) {   \
      if (!field->defined()) continue;       \
      Visit(field);                          \
    }                                        \
  }

  NODETY_FORALL(DFS_VISIT_EXPR_FIELDS)
#undef DFS_VISIT_EXPR_FIELDS

  // Forwards to the base-class implementation for Expr pointers.
  void Visit(const Expr* expr) override { IRVisitor::Visit(expr); }
};

// Generate a reduce hash of a AST tree by combining hash of each AST node
class IrNodesStructuralHash : public DfsWithExprsFields {
 public:
  IrNodesStructuralHash(size_t init_key) : hash_key_(init_key) {}
  size_t operator()(const Expr* expr) {
    Visit(expr);
    return hash_key_;
  }

  void Visit(const Expr* expr) override {
    static decltype(ir::kIrNodeTyReprs) Node2Name = ir::kIrNodeTyReprs;
    if (!expr->defined()) return;
    auto type_code = static_cast<IrNodeTyUnderlyingType>(expr->node_type());
104
    hash_key_ = utils::HashCombine(hash_key_, type_code);
105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120
    DfsWithExprsFields::Visit(expr);
  }

 private:
  void Visit(const ir::_Tensor_* x) override {
    for (auto& e : x->shape) {
      Visit(&e);
    }
    DfsWithExprsFields::Visit(x->buffer.As<ir::_Buffer_>());
  }

  using IrNodeTyUnderlyingType = std::underlying_type<ir::IrNodeTy>::type;
  size_t hash_key_;
};

// Hashes a search state by chaining IrNodesStructuralHash over every
// expression of its schedule's module: the result for one expression seeds
// the hash of the next. A state with no expressions hashes to 0.
size_t SearchStateHash::operator()(const SearchState& s) const {
  const auto& module_exprs = s->ir_schedule.GetModule().GetExprs();
  size_t seed = 0;
  for (const auto& module_expr : module_exprs) {
    IrNodesStructuralHash hasher(seed);
    seed = hasher(&module_expr);
  }
  return seed;
}

129 130
// Returns true when both states hold the same number of module expressions
// and each pair at the same position compares equal under ir::IrEqualVisitor
// with allow_name_suffix_diff enabled.
bool SearchStateEqual::operator()(const SearchState& lhs,
                                  const SearchState& rhs) const {
  const auto& lhs_exprs = lhs->ir_schedule.GetModule().GetExprs();
  const auto& rhs_exprs = rhs->ir_schedule.GetModule().GetExprs();
  // Different expression counts can never be equal.
  if (lhs_exprs.size() != rhs_exprs.size()) return false;

  // Compare the expressions pairwise. The index is size_t so the bound check
  // against size() does not mix signed and unsigned operands.
  for (size_t i = 0; i < lhs_exprs.size(); ++i) {
    // A fresh visitor per pair, as before; suffix differences in names are
    // ignored.
    ir::IrEqualVisitor comparator(/*allow_name_suffix_diff=*/true);
    if (!comparator.Compare(lhs_exprs[i], rhs_exprs[i])) return false;
  }
  return true;
}

145 146 147
std::string JoinStatesDebugString(const std::string& title,
                                  const std::vector<SearchState>& states,
                                  bool verbose) {
148 149 150 151 152 153
  std::stringstream ss;
  ss << title << " states size:" << states.size() << "\n";
  SearchStateHash state_hasher;
  for (size_t i = 0; i < states.size(); ++i) {
    uint64_t hash_key = state_hasher(states[i]);
    if (verbose) {
154 155
      ss << "\tState-" << i << " hash:" << hash_key << "\t content:------>"
         << states[i]->DebugString() << "\n<------";
156 157 158 159 160 161 162 163 164
    } else {
      ss << "\tState-" << i << " hash:" << hash_key << "\n";
    }
  }
  return std::move(*ss.rdbuf()).str();
}

}  // namespace auto_schedule
}  // namespace cinn