// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <absl/container/flat_hash_map.h>

#include <string>
#include <vector>

#include "paddle/cinn/common/target.h"
#include "paddle/cinn/hlir/framework/node.h"
#include "paddle/cinn/hlir/pe/schedule_param.pb.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/lang/compute.h"
#include "paddle/cinn/poly/stage.h"

namespace cinn {
namespace hlir {
namespace pe {
class ScheduleParam {
 public:
  ~ScheduleParam();
  ScheduleParam(const ScheduleParam &) = delete;
  ScheduleParam &operator=(const ScheduleParam &) = delete;
  static ScheduleParam &get_cuda_instance() {
    static ScheduleParam instance{common::Target::Arch::NVGPU};
    return instance;
  }
  static ScheduleParam &get_x86_instance() {
    static ScheduleParam instance{common::Target::Arch::X86};
    return instance;
  }
45 46 47
  absl::flat_hash_map<std::string,
                      absl::flat_hash_map<std::string, std::vector<int>>>
      &GetParam() {
48 49
    return param_data;
  }
50 51 52 53
  absl::flat_hash_map<std::string, std::vector<int>> &operator[](
      const std::string &key) {
    return param_data[key];
  }
54 55 56 57
  int Count(const std::string &key) { return param_data.count(key); }

 private:
  ScheduleParam(common::Target::Arch arch);
58 59 60
  absl::flat_hash_map<std::string,
                      absl::flat_hash_map<std::string, std::vector<int>>>
      param_data;
61 62 63 64 65 66 67 68 69 70 71 72
};

// ---- Loop-split factor helpers -------------------------------------------

// Picks an inner split factor for an axis of extent `origin`, taking the
// extent of `other_axis` into account.
int GetInnerSplitter(int origin, int other_axis);

// Chooses a vectorization factor for an axis of extent `shape`, bounded by
// `split_factor`.
int GetVectorizeFactor(int shape, int split_factor);

// Splits `origin` into even parts and returns the chosen factor.
int SplitEven(int origin);

// Basic split factor derived from the element type and the target.
int GetBasicFactor(const Type &type, const common::Target &target);

// Best split factor for an axis of extent `shape`, starting from
// `split_factor`.
int GetBetterSplitFactor(int shape, int split_factor);

// Packing factor used by array-packing schedules for the given type/target.
int GetArrayPackingFactor(int shape,
                          const Type &type,
                          const common::Target &target);

void ScheduleInjectiveCPU(poly::Stage *stage,
                          const std::vector<int> &output_shape,
                          const common::Target &target,
                          bool vectorizable = true);
// to deprecate
void ScheduleInjectiveCPU1(poly::Stage *stage,
                           const std::vector<int> &output_shape,
                           const common::Target &target,
                           bool vectorizable = true);

87 88 89
void MatmulScheduleCUDA(poly::StageMap stages,
                        const ir::Tensor &output,
                        const common::Target &target);
90 91 92 93 94 95 96 97 98 99 100

void MatmulScheduleCPU(poly::StageMap stage,
                       const ir::Tensor &output,
                       const ir::Tensor &packedB,
                       const common::Target &target);

void MulScheduleCPU(poly::StageMap stage,
                    const ir::Tensor &output,
                    const ir::Tensor &input_tensor,
                    const common::Target &target);

101 102 103 104
void SoftmaxScheduleCPU(poly::StageMap stage,
                        const ir::Tensor &output,
                        const ir::Tensor &temp,
                        int axis = -1);
105 106 107 108 109 110 111 112 113 114

void GetConv2dFactors(absl::flat_hash_map<std::string, int> *factors,
                      int oc,
                      int ic,
                      int fc,
                      int oh,
                      int ow,
                      const Type &type,
                      const common::Target &target,
                      const std::string &key = "",
115
                      bool import_params = true);
116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133

void GetConv2d1x1Factors(absl::flat_hash_map<std::string, int> *factors,
                         int oc,
                         int ic,
                         int oh,
                         int ow,
                         const Type &type,
                         const common::Target &target);

void Conv2d_NCHWc_Schedule_CPU(poly::StageMap stages,
                               const ir::Tensor &res,
                               ir::Tensor &packed_out,
                               const ir::Tensor &input_pad,
                               const ir::Tensor &weights_dilation,
                               const ir::Tensor &data,
                               const common::Target &target,
                               const std::string &key,
                               bool do_padding);
134 135 136 137 138 139 140 141 142
void GlobalPoolScheduleGPU(poly::StageMap stages,
                           const std::vector<ir::Tensor> &output,
                           const common::Target &target);
void PoolScheduleCPU(poly::StageMap stages,
                     const ir::Tensor &output,
                     const common::Target &target);
void PoolScheduleGPU(poly::StageMap stages,
                     ir::Tensor &output,
                     const common::Target &target);
143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169

void Conv2d_NCHWc_Schedule_CPU_Nofuse(poly::StageMap stages,
                                      const ir::Tensor &res,
                                      ir::Tensor &packed_out,
                                      const ir::Tensor &input_pad,
                                      const ir::Tensor &weights_dilation,
                                      const ir::Tensor &data,
                                      const common::Target &target);

void Conv2d_NCHWc_1X1_Schedule_CPU(poly::StageMap stages,
                                   const ir::Tensor &res,
                                   ir::Tensor &packed_out,
                                   const ir::Tensor &input_pad,
                                   const ir::Tensor &weights_dilation,
                                   const ir::Tensor &data,
                                   const common::Target &target,
                                   const std::string &key,
                                   bool do_padding);

void Conv2d_NCHWc_1X1_Schedule_CPU_Nofuse(poly::StageMap stages,
                                          const ir::Tensor &res,
                                          ir::Tensor &packed_out,
                                          const ir::Tensor &input_pad,
                                          const ir::Tensor &weights_dilation,
                                          const ir::Tensor &data,
                                          const common::Target &target);

170 171 172 173 174 175 176 177 178
void Depthwise_Conv2d_NCHWc_Schedule_CPU_Nofuse(
    poly::StageMap stages,
    const ir::Tensor &res,
    ir::Tensor &packed_out,
    const ir::Tensor &input_pad,
    const ir::Tensor &weights_dilation,
    const ir::Tensor &data,
    const common::Target &target,
    bool do_padding);
179 180 181 182 183 184 185

void CudaScheduleMul(poly::StageMap stages,
                     ir::Tensor output,
                     const std::vector<int> &output_shape,
                     const common::Target &target);

// reduce shedules.
186 187 188 189
void CudaReduceSchedule(poly::StageMap stages,
                        ir::Tensor output,
                        int last_dimension_num,
                        const common::Target &target);
190

191 192 193 194
void CudaWarpReduceSchedule(poly::StageMap stages,
                            ir::Tensor tmp_out,
                            ir::Tensor out,
                            const common::Target &target);
195 196 197 198 199 200

void CudaBlockReduceInternalSchedule(poly::StageMap stages,
                                     ir::Tensor tmp_out,
                                     ir::Tensor out,
                                     const common::Target &target);

201 202 203 204 205
void CudaBlockReduceSchedule(poly::StageMap stages,
                             ir::Tensor reduce_tmp_out,
                             ir::Tensor tmp_out,
                             ir::Tensor out,
                             const common::Target &target);
206 207 208 209 210 211 212 213 214 215 216 217 218 219

void CudaBlockShuffleReduceSchedule(poly::StageMap stages,
                                    ir::Tensor reduce_reshape,
                                    ir::Tensor reduce_internal,
                                    ir::Tensor reduce_out,
                                    const common::Target &target);

void CudaTwoStepReduceSchedule(poly::StageMap stages,
                               ir::Tensor reshape,
                               ir::Tensor internal,
                               ir::Tensor tmp_out,
                               ir::Tensor out,
                               const common::Target &target);

220 221 222
void CudaScheduleDepthwiseConv(poly::StageMap stages,
                               ir::Tensor &output,
                               const common::Target &target);
223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240

void CudaScheduleConv(poly::StageMap stages,
                      ir::Tensor &input_pad,
                      ir::Tensor &weights,
                      ir::Tensor &output,
                      const common::Target &target);

void CudaScheduleWinogradConv(poly::StageMap wino_stages,
                              std::vector<ir::Tensor> &all_tensors,
                              const common::Target &target);

void CudaScheduleConv2(poly::StageMap stages,
                       ir::Tensor &input_pad,
                       ir::Tensor &weights,
                       ir::Tensor &output,
                       const common::Target &target,
                       const std::string &key);

241 242 243
void CudaScheduleInjective(poly::Stage *stage,
                           const std::vector<int> &output_shape,
                           const common::Target &target);
244 245 246 247 248 249 250 251 252 253 254 255 256

void CudaSplitSchedule(common::CINNValuePack *arg_pack,
                       const std::vector<std::vector<int>> &output_shapes,
                       int axis,
                       const common::Target &target);

// Writes the built-in CUDA schedule parameters to `file_name`.
void CreateCudaSerialData(const std::string &file_name = "default_serial.log");

// Builds the lookup key for tuned x86 conv parameters from symbolic (Expr)
// shapes. NOTE(review): `const int &index` is unusual for a scalar but is kept
// to preserve the linked definition's signature.
std::string GenerateX86ConvKey(const std::vector<Expr> &input_shape,
                               const std::vector<Expr> &weight_shape,
                               const std::vector<int> &strides,
                               const std::vector<int> &paddings,
                               const std::vector<int> &dilations,
                               const int &index = 0,
                               const std::string &model_name = "");

// Overload of GenerateX86ConvKey taking concrete int shapes.
std::string GenerateX86ConvKey(const std::vector<int> &input_shape,
                               const std::vector<int> &weight_shape,
                               const std::vector<int> &strides,
                               const std::vector<int> &paddings,
                               const std::vector<int> &dilations,
                               const int &index = 0,
                               const std::string &model_name = "");
// Writes the built-in x86 schedule parameters to `file_name`.
void CreateX86SerialData(const std::string &file_name = "default_serial.log");

269 270 271 272 273
void LoadSerialData(
    absl::flat_hash_map<std::string,
                        absl::flat_hash_map<std::string, std::vector<int>>>
        *params,
    const std::string &file_name = "default_serial.log");
274 275

void SaveSerialData(
276 277 278
    const absl::flat_hash_map<
        std::string,
        absl::flat_hash_map<std::string, std::vector<int>>> &model_data,
279 280 281 282
    const std::string &file_name = "default_serial.log");

int GetMaxSplitter(int a, int b);

283 284 285
absl::flat_hash_map<std::string,
                    absl::flat_hash_map<std::string, std::vector<int>>>
CreateCudaParams();
286 287 288 289

}  // namespace pe
}  // namespace hlir
}  // namespace cinn