evolutionary_search_test.cc 7.8 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/cinn/auto_schedule/search_strategy/evolutionary_search.h"

#include <gtest/gtest.h>

#include <memory>
#include <utility>

#include "paddle/cinn/auto_schedule/cost_model/expr_cost_model.h"
#include "paddle/cinn/auto_schedule/database/database.h"
#include "paddle/cinn/auto_schedule/search_space/search_space.h"
#include "paddle/cinn/auto_schedule/search_space/search_state.h"
#include "paddle/cinn/auto_schedule/task/task_creator.h"
#include "paddle/cinn/auto_schedule/task/task_registry.h"
#include "paddle/cinn/auto_schedule/task/tune_task.h"
#include "paddle/cinn/auto_schedule/tuning.h"
#include "paddle/cinn/ir/ir_base.h"
#include "paddle/cinn/ir/ir_schedule.h"
#include "test/cpp/cinn/program_builder.h"

namespace cinn {
namespace auto_schedule {

37 38
std::vector<TuneTask> CreateTasks(const frontend::Program& program,
                                  const Target& target) {
39 40
  auto graph = std::make_shared<hlir::framework::Graph>(program, target);
  TaskCreator task_creator;
41 42 43 44 45 46 47 48
  auto tasks = task_creator.CreateTuneTaskOpLevel(graph.get());
  const auto& dtype_dict =
      graph->GetAttrs<absl::flat_hash_map<std::string, common::Type>>(
          "inferdtype");
  const auto& shape_dict = graph->GetAttrs<
      absl::flat_hash_map<std::string, hlir::framework::shape_t>>("infershape");
  auto op_lowerer = std::make_unique<hlir::framework::OpLowerer>(
      dtype_dict, shape_dict, target);
49 50 51
  InitialTaskRegistry* task_registry = InitialTaskRegistry::Global();
  for (auto i = 0; i < tasks.size(); ++i) {
    tasks[i].Initialize(shape_dict, dtype_dict, op_lowerer.get());
52 53
    task_registry->Regist(tasks[i].serialized_key,
                          ir::ModuleExpr(tasks[i].GetLoweredFuncBodyExprs()));
54 55 56 57 58 59 60 61 62 63 64 65 66
  }
  return tasks;
}

/**
 * A mock search space is only used for test. It creates integer ir::Expr from
 * 0, -1, -2, ... and set the cost value same as the integer value.
 *
 * So evolutionary search should be able to find the minimal ModuleExpr with
 * smallest ir::Expr. This file tests it.
 */
class MockSearchSpace : public SearchSpace {
 public:
67 68
  explicit MockSearchSpace(const TuneTask& tune_task)
      : SearchSpace(tune_task) {}
69 70 71 72 73

  int GetMinExprValue() const { return min_expr_value_; }

  int GetModuleExprSize() const { return module_expr_size_; }

74 75
  std::vector<SearchState> GenerateSketches(
      int num, const std::string& strategy) override {
76 77 78 79 80 81 82 83 84 85 86 87 88 89
    std::vector<SearchState> ret;
    for (int i = 0; i < num; ++i) {
      std::vector<ir::Expr> exprs;
      for (int j = 0; j < module_expr_size_; ++j) {
        exprs.push_back(ir::Expr(-i));
      }
      min_expr_value_ = -i;
      ret.push_back(SearchState(ir::IRSchedule(ir::ModuleExpr(exprs))));
    }
    return ret;
  }

 private:
  int module_expr_size_ = 10;
90
  int min_expr_value_ = 0;
91 92 93
};

class MockCostModel : public ExprCostModel {
94 95 96
  float Predict(const ir::ModuleExpr& sample,
                const common::Target& target) const override {
    float cost = 0.0f;
97 98 99 100 101 102 103 104 105 106 107 108
    std::vector<ir::Expr> exprs = sample.GetExprs();
    for (const ir::Expr& expr : exprs) {
      if (expr.as_int32()) {
        cost += static_cast<float>((expr.as_int32()));
      }
    }
    return cost;
  }
};

TEST(EvolutionarySearch, GetOneBest) {
  TuneTask mock_tune_task;
109 110
  mock_tune_task.serialized_key = "mock_task";
  mock_tune_task.target = common::DefaultTarget();
111
  InitialTaskRegistry* task_registry = InitialTaskRegistry::Global();
112 113
  task_registry->Regist(mock_tune_task.serialized_key,
                        ir::ModuleExpr({ir::Expr(0)}));
114 115 116 117 118 119 120 121
  MockCostModel cost_model;
  TuningOptions options;
  Database db(2);
  EvolutionarySearch evolutionary_search(mock_tune_task, cost_model, &db);

  MockSearchSpace* mock_search_space = new MockSearchSpace(mock_tune_task);
  // Ownership is transferred so don't delete mock_search_space
  evolutionary_search.SetSearchSpace(mock_search_space);
122
  SearchState best_state = evolutionary_search.SearchModuleExpr(options);
123 124 125 126 127 128 129 130 131
  std::vector<ir::Expr> exprs = best_state->ir_schedule.GetModule().GetExprs();
  EXPECT_GE(exprs.size(), 1UL);
  for (const ir::Expr& e : exprs) {
    EXPECT_EQ(e.as_int32(), mock_search_space->GetMinExprValue());
  }
}

TEST(EvolutionarySearch, GetEpsGreedy) {
  TuneTask mock_tune_task;
132 133
  mock_tune_task.serialized_key = "mock_task";
  mock_tune_task.target = common::DefaultTarget();
134
  InitialTaskRegistry* task_registry = InitialTaskRegistry::Global();
135 136
  task_registry->Regist(mock_tune_task.serialized_key,
                        ir::ModuleExpr({ir::Expr(0)}));
137 138 139 140 141 142 143 144
  ExprCostModel cost_model;
  TuningOptions options;
  Database db(2);
  EvolutionarySearch evolutionary_search(mock_tune_task, cost_model, &db);

  MockSearchSpace* mock_search_space = new MockSearchSpace(mock_tune_task);
  // Ownership is transferred so don't delete mock_search_space
  evolutionary_search.SetSearchSpace(mock_search_space);
145 146
  std::vector<SearchState> search_states =
      evolutionary_search.SearchModuleExprEpsGreedy(options);
147 148

  EXPECT_GE(search_states.size(), 1UL);
149 150
  size_t expr_size =
      static_cast<size_t>(mock_search_space->GetModuleExprSize());
151 152 153 154 155 156 157
  for (const SearchState& state : search_states) {
    EXPECT_EQ(state->ir_schedule.GetModule().GetExprs().size(), expr_size);
  }
}

TEST(EvolutionarySearch, Evolve) {
  auto target = common::DefaultNVGPUTarget();
158 159 160
  auto tasks = CreateTasks(
      tests::OpBuilder("matmul").Build({{"X", {32, 32}}, {"Y", {32, 32}}}),
      target);
161 162 163 164 165 166 167
  CHECK_EQ(tasks.size(), 1);
  ExprCostModel cost_model;
  std::vector<const ir::ModuleExpr*> cost_model_samples(1);
  std::vector<float> cost_model_labels(1);
  for (size_t i = 0; i < 2; ++i) {
    ir::ModuleExpr me({ir::Expr(tasks[0].lowered_funcs[0])});
    cost_model_samples[0] = &me;
168
    cost_model_labels[0] = i + 10;
169 170 171 172 173 174 175 176 177
    cost_model.Update(cost_model_samples, cost_model_labels, target);
  }

  Database db(2);
  TuningOptions options;
  options.evolution_pick_database_topk = 0;

  EvolutionarySearch evolutionary_search(tasks[0], cost_model, &db);

178 179 180
  int num_population = 10;
  std::vector<SearchState> init_sketch =
      evolutionary_search.TestInitSketch(num_population, "rule_prune");
181 182 183
  for (int i = 0; i < num_population; ++i) {
    ir::ModuleExpr me(init_sketch[i]->ir_schedule.GetModule());
    cost_model_samples[0] = &me;
184
    cost_model_labels[0] = i;
185 186 187 188 189 190
    cost_model.Update(cost_model_samples, cost_model_labels, target);
  }
  VLOG(6) << "init sketch costs:";
  for (auto s : init_sketch) {
    VLOG(6) << "cost = " << s->predicted_cost;
  }
191 192
  std::vector<SearchState>*population_pre_ptr = &init_sketch,
  *population_next_ptr;
193 194
  std::vector<SearchState> population;
  for (int i = 0; i < 10; ++i) {
195 196
    population = evolutionary_search.TestEvolve(
        *population_pre_ptr, /*cross_over_num*/ 0, /*ret_num*/ 10);
197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214
    population_next_ptr = &population;
    VLOG(6) << "population[" << i + 1 << "] costs:";
    double total_cost_pre = 0.0, total_cost_next = 0.0;
    for (auto s : *population_pre_ptr) {
      total_cost_pre += s->predicted_cost;
    }
    for (auto s : *population_next_ptr) {
      total_cost_next += s->predicted_cost;
      VLOG(6) << "cost = " << s->predicted_cost;
    }
    VLOG(6) << "total_cost_next = " << total_cost_next;
    CHECK_LE(total_cost_next, total_cost_pre);
    std::swap(population_pre_ptr, population_next_ptr);
  }
}

}  // namespace auto_schedule
}  // namespace cinn