// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <cstddef>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <vector>

#include "paddle/cinn/auto_schedule/auto_schedule.pb.h"
#include "paddle/cinn/auto_schedule/search_space/search_state.h"
#include "paddle/cinn/ir/schedule_desc.pb.h"

namespace cinn {
namespace auto_schedule {

// Record related data about tuning process of a measure candidate
struct TuningRecord {
  // the unique key to identify a task
  std::string task_key;
  // the predicted cost of CostModel
  float predicted_cost;  // unit: us
  // the ScheduleDesc of this tuning process
  ir::proto::ScheduleDesc trace;
  // the cost time of the candidate executed during measure
  double execution_cost;  // unit: us

  TuningRecord() = default;
37
  explicit TuningRecord(const proto::TuningRecord& record)
38 39 40 41
      : task_key(record.task_key()),
        predicted_cost(record.predicted_cost()),
        trace(record.trace()),
        execution_cost(record.execution_cost()) {}
42 43 44
  TuningRecord(const std::string& task_key,
               const SearchState& state,
               double execution_cost)
45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62
      : task_key(task_key),
        predicted_cost(state->predicted_cost),
        trace(state->ir_schedule.GetTraceDesc().ToProto()),
        execution_cost(execution_cost) {}

  // convert to proto object
  proto::TuningRecord ToProto() const;

  // a binary compare function that denotes when the left
  // will be sorted in the front of the right
  struct Compare {
    bool operator()(const TuningRecord& lhs, const TuningRecord& rhs) const;
  };
};

// The kind of storage backing a Database.
enum class DatabaseType : int { kMemory, kJSONFile };

// Options describing which Database to create (consumed by Database::Make).
struct DatabaseConfig {
  // Which storage backend to use.
  DatabaseType type = DatabaseType::kMemory;
  // Max number of records kept for each task.
  int capacity_per_task = 2;
  // Path of the record file; presumably only meaningful when `type` is
  // kJSONFile — TODO confirm against Database::Make.
  std::string record_file_path = "/tmp/tuning_record.json";
};

68 69 70 71
// A database supports insert or lookup historial tuning result with specified
// traits. It can be implemented with a concrete storage to save/load underlying
// data, such as memory, file, database server and so on, this base class can be
// regarded as one using memory as its underlying storage medium.
72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97
class Database {
 public:
  explicit Database(int capacity_per_task);
  ~Database() = default;

  // Create a Database with the specific config
  static std::unique_ptr<Database> Make(const DatabaseConfig& config);

  // add a record into the database
  bool AddRecord(const TuningRecord& record);
  // return all records whose task_keys are equal to the specified key
  std::vector<TuningRecord> LookUp(const std::string& task_key);
  // return the states of the top k in sorted candidates
  std::vector<TuningRecord> GetTopK(const std::string& task_key, int k);
  // return the total number of stored candidates
  size_t Size();
  // return the number of stored candidates with specified key
  size_t Count(const std::string& task_key);

 protected:
  // commit the newly added record into underlying storage
  virtual bool Commit(const TuningRecord& record) { return true; }
  // insert a newly added record into memory storage
  void Insert(const TuningRecord& record);

  // map task_key to its records
98 99 100
  std::unordered_map<std::string,
                     std::multiset<TuningRecord, TuningRecord::Compare>>
      key2record_;
101 102 103 104 105 106
  // the max number of candidates stored
  const int capacity_per_task_;
};

}  // namespace auto_schedule
}  // namespace cinn