// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <cmath>
#include <vector>

#include "paddle/cinn/common/target.h"
#include "paddle/cinn/ir/ir_schedule.h"

namespace cinn {
namespace auto_schedule {

/* Loop feature enums */
enum class ForOptimizeFeatureEnum : int {
  kNone,
  kGpuBind,
  kParallel,
  kUnroll,
  kVectorize
};
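
// Note: the five values above match LoopBlockFeature::kOptApplySize defined
// below, which presumably reserves one slot per optimization type in the
// fixed-size feature vector.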

/* function to scale feature numbers */
inline float slog(float x) {
  return x < 0 ? std::log2(-x + 1) : std::log2(x + 1);
}
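// For example, slog(3.0f) and slog(-3.0f) both evaluate to std::log2(4.0f),
// i.e. 2.0f, and slog(0.0f) == 0.0f; the transform is symmetric in sign and
// compresses large raw feature counts into a comparable range.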

class LoopBlockFeature {
 public:
  // TODO(zhhsplendid): distinguish more types such as float16, float32,
  // float64, etc. However, the speed gap between float and int is larger than
  // the gap between different bit widths, so we only distinguish int and
  // float here.
  /* Arithmetic features */
  int float_add_or_sub = 0;
  int float_mul = 0;
  int float_div_or_mod = 0;
  int float_cmp = 0;
  int float_math_func = 0;
  int float_other_call = 0;  // like simple assign, cast, etc.

  int int_add_or_sub = 0;
  int int_mul = 0;
  int int_div_or_mod = 0;
  int int_cmp = 0;
  int int_math_func = 0;
  int int_other_call = 0;  // like simple assign, cast, etc.

  int bool_op = 0;
  int select_op = 0;

  static constexpr int kArithSize = 6 * 2 + 2;

  /**
   * Buffer memory features, i.e. the number of memory operations.
   * Note that memory operations of different sizes can run at different
   * speeds, but the difference is relatively small. A more meticulous TODO
   * would be to collect operand sizes (such as alloc size, write size, etc.)
   */
  int mem_alloc = 0;
  int mem_free = 0;
  int mem_read = 0;
  int mem_write = 0;

  static constexpr int kMemSize = 4;

  /**
   * Reduce and Broadcast features
   */
  int float_reduce_sum_or_sub = 0;
  int float_reduce_mul = 0;
  int float_reduce_div = 0;
  int float_reduce_max_or_min = 0;
  int float_broadcast = 0;

  int int_reduce_sum_or_sub = 0;
  int int_reduce_mul = 0;
  int int_reduce_div = 0;
  int int_reduce_max_or_min = 0;
  int int_broadcast = 0;

  static constexpr int kReduceBroadcastSize = 10;

  /* Loop type features */

  // TODO: maybe add a loop position (Inner, Outer, Middle) feature

  ForOptimizeFeatureEnum loop_opt_type = ForOptimizeFeatureEnum::kNone;

  static constexpr int kOptApplySize = 5;

  /* Thread features if the loop is optimized by GPU or CPU parallelism.
   * Unused in other cases.
   */
  int len_blockIdx_x = 0;
  int len_blockIdx_y = 0;
  int len_blockIdx_z = 0;
  int len_threadIdx_x = 0;
  int len_threadIdx_y = 0;
  int len_threadIdx_z = 0;
  int len_vthread = 0;  // length of virtual thread
  int vectorize_factor = 0;

  static constexpr int kThreadFeatureSize = 8;

  static constexpr int kTotalSize = kArithSize + kMemSize +
                                    kReduceBroadcastSize + kOptApplySize +
                                    kThreadFeatureSize;

  /* Non-feature attributes, maintained by the feature extractor during extraction */

  // Number of sub loop blocks directly inside the current one
  int num_sub_loops = 0;

  // Number of repeats of this loop, -1 represents unknown
  int loop_length = 1;
};

/**
 * Feature of Expr. It is used in CostModel
 */
class Feature {
 public:
  Feature();

  explicit Feature(const common::Target& target);

  // Convert the variable-length loop block features to a fixed-size vector
  std::vector<float> ToFixedSizeVector();

  // Call when entering a loop block to start collecting its LoopBlockFeature
  void IntoLoopBlock();
  // Call when exiting a loop block to finish collecting its LoopBlockFeature
  void ExitLoopBlock();
  // The current loop block on which we should collect features
  LoopBlockFeature& CurrentLoopBlock();
  // The current loop block on which we should collect features
  const LoopBlockFeature& CurrentLoopBlock() const;

 private:
  // We encode the computation features as a variable-length vector.
  // The root compute block is not a loop, but we treat it as a size-1 loop.
  // Blocks are encoded like a stack. Each LoopBlockFeature contains
  // num_sub_loops to indicate how many next-level sub-loop-blocks it contains.
  //
  // For example, code like:
  //
  // some_compute_0
  // loop1 {
  //   some_compute_1
  //   loop2 {
  //     some_compute_2
  //   }
  // }
  //
  // loop3 {
  //   some_compute_3
  // }
  //
  // We walk through the code and push loops onto a stack, so the features are
  // encoded as [loop_block_feature_0, loop_block_feature_1,
  // loop_block_feature_2, loop_block_feature_3], where loop_block_feature_i
  // stores the features of some_compute_i (such as the number of arithmetic
  // operations)
  //
  // loop_block_feature_0.num_sub_loops = 2
  // loop_block_feature_1.num_sub_loops = 1
  // loop_block_feature_2.num_sub_loops = 0
  // loop_block_feature_3.num_sub_loops = 0
  std::vector<LoopBlockFeature> stack_encoded_feature_;
  int current_loop_block_index_;
  std::vector<int> parent_indices_;

  common::Target target_;
};
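
/**
 * Example usage (a minimal sketch; the calls below only illustrate the
 * intended order of the methods declared above, not the actual extractor):
 *
 *   Feature feature(target);
 *   feature.IntoLoopBlock();                           // enter loop1
 *   feature.CurrentLoopBlock().float_add_or_sub += 1;  // record some_compute_1
 *   feature.IntoLoopBlock();                           // enter loop2
 *   feature.CurrentLoopBlock().mem_read += 1;          // record some_compute_2
 *   feature.ExitLoopBlock();                           // exit loop2
 *   feature.ExitLoopBlock();                           // exit loop1
 *   std::vector<float> fixed = feature.ToFixedSizeVector();
 */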

}  // namespace auto_schedule
}  // namespace cinn