sort.cc 17.1 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/cinn/hlir/op/contrib/sort.h"

#include <gflags/gflags.h>

#include <algorithm>
#include <cstdint>
#include <functional>
#include <memory>
#include <numeric>
#include <string>
#include <utility>
#include <vector>

#include "paddle/cinn/common/cas.h"
#include "paddle/cinn/common/common.h"
#include "paddle/cinn/common/context.h"
#include "paddle/cinn/common/macros.h"
#include "paddle/cinn/hlir/framework/node.h"
#include "paddle/cinn/hlir/framework/op.h"
#include "paddle/cinn/hlir/framework/op_strategy.h"
#include "paddle/cinn/hlir/op/op_util.h"
#include "paddle/cinn/hlir/pe/elementwise.h"
#include "paddle/cinn/hlir/pe/ir_schedule_pe.h"
#include "paddle/cinn/hlir/pe/transform.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_base.h"
#include "paddle/cinn/ir/tensor.h"
#include "paddle/cinn/lang/builtin.h"
#include "paddle/cinn/lang/compute.h"

DECLARE_bool(cinn_ir_schedule);

namespace cinn {
namespace hlir {
namespace op {

using common::CINNValue;
using common::CINNValuePack;

std::vector<ir::Tensor> ArgSort(const ir::Tensor &A,
                                const common::Target &target,
                                poly::StageMap stages,
                                const int &axis,
                                const bool &is_ascend,
                                const std::string &name) {
  std::string find_func_name;
  std::string index_func_name;
  if (target.arch == common::Target::Arch::NVGPU) {
    find_func_name.assign("cinn_nvgpu_next_smallest_int32");
  } else if (target.arch == common::Target::Arch::X86) {
    find_func_name.assign("cinn_host_next_smallest_int32");
  } else {
    LOG(FATAL) << "ArgSort only supports X86 and NVGPU ! Please Check.\n";
  }
  if (is_ascend) {
66 67
    index_func_name =
        cinn::hlir::GetExternFuncName(target, A->type(), "lt_num");
68
  } else {
69 70
    index_func_name =
        cinn::hlir::GetExternFuncName(target, A->type(), "gt_num");
71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90
  }
  int pos_axis = axis;
  if (pos_axis < 0) {
    pos_axis += A->shape.size();
  }
  auto positions = Compute(
      A->shape,
      [=](const std::vector<Expr> &indices) {
        Expr offset(0);
        Expr stride(1);
        for (int i = 0; i < indices.size(); i++) {
          if (i < pos_axis) {
            offset = offset * A->shape[i] + indices[i];
          } else if (i == pos_axis) {
            offset = offset * A->shape[i];
          } else {
            offset = offset * A->shape[i] + indices[i];
            stride = stride * A->shape[i];
          }
        }
91 92
        offset = common::AutoSimplify(offset);
        stride = common::AutoSimplify(stride);
93
        auto A_shape_axis = A->shape[pos_axis];
94 95
        return lang::CallExtern(index_func_name,
                                {A, A_shape_axis, A(indices), offset, stride});
96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116
      },
      name + "_temp");
  auto res = Compute(
      A->shape,
      [=](const std::vector<Expr> &indices) {
        Expr offset(0);
        Expr stride(1);
        for (int i = 0; i < indices.size(); i++) {
          if (i < pos_axis) {
            offset = offset * A->shape[i] + indices[i];
          } else if (i == pos_axis) {
            offset = offset * A->shape[i];
          } else {
            offset = offset * A->shape[i] + indices[i];
            stride = stride * A->shape[i];
          }
        }
        offset = common::AutoSimplify(offset);
        stride = common::AutoSimplify(stride);

        auto A_shape_axis = A->shape[pos_axis];
117 118 119
        auto idx = lang::CallExtern(
            find_func_name,
            {positions, A_shape_axis, indices[pos_axis], offset, stride});
120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136
        return idx;
      },
      name);
  stages->InsertLazily(positions);
  return {res, positions};
}

// Sorts the values of `A` along `axis` (negative values count from the last
// dimension) by gathering `A` through the index tensor built by ArgSort().
//
// Returns {sorted values, ArgSort's first tensor (used as the source
// coordinate along the axis), ArgSort's second tensor}. ArgSort inserts its
// second tensor into `stages`; this function inserts the first one.
std::vector<ir::Tensor> Sort(const ir::Tensor &A,
                             const common::Target &target,
                             poly::StageMap stages,
                             const int &axis,
                             const bool &is_ascend,
                             const std::string &name) {
  const int pos_axis =
      axis < 0 ? axis + static_cast<int>(A->shape.size()) : axis;

  std::vector<ir::Tensor> arg_outputs =
      ArgSort(A, target, stages, pos_axis, is_ascend, name + "_index");
  ir::Tensor order = arg_outputs.at(0);

  auto sorted = Compute(
      A->shape,
      [=](const std::vector<Expr> &indices) {
        // Read A at the same coordinates, except along the sorted axis where
        // the coordinate is taken from the arg-sort result.
        std::vector<Expr> gather_indices = indices;
        gather_indices[pos_axis] = order(indices);
        return A(gather_indices);
      },
      name);
  stages->InsertLazily(order);
  return {sorted, order, arg_outputs.at(1)};
}

151 152 153 154 155 156
// Builds the op strategy for `sort`.
//
// Attributes: `axis` (int, required) and `is_ascend` (bool, defaults to true).
//
// Compute: args[0] must be a pack whose first element is the input tensor.
// When FLAGS_cinn_ir_schedule is set, the pack must hold exactly two entries,
// the second being the output tensor name. The result pack holds the three
// tensors returned by Sort() followed by the stage map.
//
// Schedule: under FLAGS_cinn_ir_schedule it does the following:
//   - merges the lowered exprs;
//   - sets the first two schedule blocks to "local" buffers;
//   - on X86, when the output has more than one element, applies the
//     injective CPU schedule.
// Without the flag, the argument pack is passed through unchanged.
std::shared_ptr<framework::OpStrategy> StrategyForSort(
    const framework::NodeAttr &attrs,
    const std::vector<ir::Tensor> &inputs,
    const std::vector<Type> &out_type,
    const std::vector<std::vector<int>> &output_shapes,
    const Target &target) {
  auto attr_store = attrs.attr_store;

  CHECK(attr_store.count("axis")) << "find no attr of axis";
  int axis = absl::get<int>(attr_store.at("axis"));
  bool is_ascend = true;
  if (attr_store.count("is_ascend")) {
    is_ascend = absl::get<bool>(attr_store.at("is_ascend"));
  }

  framework::CINNCompute sort_compute(
      [=](lang::Args args, lang::RetValue *ret) {
        CHECK(!args.empty())
            << "The input arguments of Sort compute is empty! Please check.\n";
        CINNValuePack pack_args = args[0];
        CHECK_GE(pack_args.size(), 1U)
            << "At least 1 input tensors for Sort compute\n";
        Expr A = pack_args[0];
        CHECK(A.as_tensor());
        CHECK(!output_shapes.empty());
        auto tensor_A = A.as_tensor_ref();
        auto stages = CreateStages({tensor_A});
        VLOG(3) << "A shape: " << utils::Join(tensor_A->shape, ", ")
                << ", output_shapes: " << utils::Join(output_shapes[0], ", ");
        auto tensor_name = UniqName("Sort_out");
        if (FLAGS_cinn_ir_schedule) {
          CHECK_EQ(pack_args.size(), 2U);
          CHECK(pack_args[1].is_string());
          tensor_name = pack_args[1].operator std::string();
        }
        std::vector<ir::Tensor> out =
            Sort(tensor_A, target, stages, axis, is_ascend, tensor_name);
        stages->InsertLazily(out[0]);
        std::vector<CINNValue> res{
            CINNValue(out[0]), CINNValue(out[1]), CINNValue(out[2])};
        CHECK(!out_type.empty())
            << "Output type of Sort is empty! Please check.\n";
        res.push_back(CINNValue(stages));
        *ret = CINNValuePack{res};
      });

  framework::CINNSchedule sort_schedule(
      [=](lang::Args args, lang::RetValue *ret) {
        CHECK(!args.empty())
            << "The input argument of sort_schedule is empty! Please check.\n";
        CINNValuePack arg_pack = args[0];

        if (!FLAGS_cinn_ir_schedule) {
          Expr out = arg_pack[0];
          CHECK(out.as_tensor());
          *ret = arg_pack;
          return;
        }

        std::vector<Expr> vec_ast;
        for (int i = 0; i < static_cast<int>(arg_pack.size()); ++i) {
          if (arg_pack[i].is_expr()) {
            Expr temp = arg_pack[i];
            vec_ast.emplace_back(temp);
          }
        }
        CHECK(!vec_ast.empty());
        ir::ModuleExpr mod_expr(vec_ast);
        ir::IRSchedule ir_sch(mod_expr);
        ir_sch.MergeExprs();
        auto blocks = ir_sch.GetAllBlocks();
        CHECK_GE(blocks.size(), 2U)
            << "sort_schedule expects at least 2 schedule blocks! Please "
               "check.\n";
        // TODO(Shixiaowei02): remove external calls, do not use local
        // variables, because the size will exceed the limit.
        ir_sch.SetBuffer(blocks[0], "local");
        ir_sch.SetBuffer(blocks[1], "local");

        CHECK(!output_shapes.empty());
        // Accumulate in 64 bits: with an `int` accumulator the product of a
        // large shape overflows before it is widened to int64_t.
        const int64_t prod_size =
            std::accumulate(output_shapes[0].begin(),
                            output_shapes[0].end(),
                            static_cast<int64_t>(1),
                            std::multiplies<int64_t>());
        if (prod_size > 1 && target.arch == Target::Arch::X86) {
          pe::IRScheduleInjectiveCPU(
              ir_sch, output_shapes.front(), target, true);
        }

        std::vector<common::CINNValue> res{
            common::CINNValue(ir_sch.GetModule().GetExprs().at(0))};
        *ret = common::CINNValuePack{res};
      });

  auto strategy = std::make_shared<framework::OpStrategy>();
  strategy->AddImpl(sort_compute, sort_schedule, "strategy.sort", 1);
  return strategy;
}

246 247 248 249 250 251
// Builds the op strategy for `argsort`.
//
// Attributes: `axis` (int, required) and `is_ascend` (bool, defaults to true).
//
// Compute: args[0] must be a pack whose first element is the input tensor.
// When FLAGS_cinn_ir_schedule is set, the pack must hold exactly three
// entries, the second being the output tensor name. The result pack holds
// both tensors returned by ArgSort() followed by the stage map.
//
// Schedule: under FLAGS_cinn_ir_schedule it merges the lowered exprs. On X86,
// when the output has more than one element, it also applies the injective
// CPU schedule. Without the flag, the argument pack is passed through
// unchanged.
std::shared_ptr<framework::OpStrategy> StrategyForArgSort(
    const framework::NodeAttr &attrs,
    const std::vector<ir::Tensor> &inputs,
    const std::vector<Type> &out_type,
    const std::vector<std::vector<int>> &output_shapes,
    const Target &target) {
  auto attr_store = attrs.attr_store;
  CHECK(attr_store.count("axis")) << "find no attr of axis";
  int axis = absl::get<int>(attr_store.at("axis"));
  bool is_ascend = true;
  if (attr_store.count("is_ascend")) {
    is_ascend = absl::get<bool>(attr_store.at("is_ascend"));
  }

  framework::CINNCompute argsort_compute([=](lang::Args args,
                                             lang::RetValue *ret) {
    CHECK(!args.empty())
        << "The input arguments of ArgSort compute is empty! Please check.\n";
    CINNValuePack pack_args = args[0];
    CHECK_GE(pack_args.size(), 1U)
        << "At least 1 input tensors for ArgSort compute\n";
    Expr A = pack_args[0];
    CHECK(A.as_tensor());
    CHECK(!output_shapes.empty());
    auto tensor_A = A.as_tensor_ref();
    auto stages = CreateStages({tensor_A});
    VLOG(3) << "A shape: " << utils::Join(tensor_A->shape, ", ")
            << ", output_shapes: " << utils::Join(output_shapes[0], ", ");
    auto tensor_name = UniqName("ArgSort_out");
    if (FLAGS_cinn_ir_schedule) {
      CHECK_EQ(pack_args.size(), 3U);
      CHECK(pack_args[1].is_string());
      tensor_name = pack_args[1].operator std::string();
    }
    auto out = ArgSort(tensor_A, target, stages, axis, is_ascend, tensor_name);
    std::vector<CINNValue> res;
    stages->InsertLazily(out.at(0));
    stages->InsertLazily(out.at(1));
    res.push_back(CINNValue(out.at(0)));
    res.push_back(CINNValue(out.at(1)));
    CHECK(!out_type.empty())
        << "Output type of ArgSort is empty! Please check.\n";
    res.push_back(CINNValue(stages));
    *ret = CINNValuePack{res};
  });

  framework::CINNSchedule argsort_schedule([=](lang::Args args,
                                               lang::RetValue *ret) {
    CHECK(!args.empty())
        << "The input argument of argsort_schedule is empty! Please check.\n";
    CINNValuePack arg_pack = args[0];

    if (!FLAGS_cinn_ir_schedule) {
      Expr out = arg_pack[0];
      CHECK(out.as_tensor());
      *ret = arg_pack;
      return;
    }

    std::vector<Expr> vec_ast;
    for (int i = 0; i < static_cast<int>(arg_pack.size()); ++i) {
      if (arg_pack[i].is_expr()) {
        Expr temp = arg_pack[i];
        vec_ast.emplace_back(temp);
      }
    }
    CHECK(!vec_ast.empty());
    ir::ModuleExpr mod_expr(vec_ast);
    ir::IRSchedule ir_sch(mod_expr);
    ir_sch.MergeExprs();
    // TODO(Shixiaowei02): remove external calls, do not use local variables,
    // because the size will exceed the limit.
    // TODO(lanxianghit): There is a bug, setting buffer to "local" here will
    // cause the var declared twice at CodeGen.
    // ir_sch.SetBuffer(ir_sch.GetAllBlocks()[0], "local");

    CHECK(!output_shapes.empty());
    // Accumulate in 64 bits: with an `int` accumulator the product of a large
    // shape overflows before it is widened to int64_t.
    const int64_t prod_size = std::accumulate(output_shapes[0].begin(),
                                              output_shapes[0].end(),
                                              static_cast<int64_t>(1),
                                              std::multiplies<int64_t>());
    if (prod_size > 1 && target.arch == Target::Arch::X86) {
      pe::IRScheduleInjectiveCPU(ir_sch, output_shapes.front(), target, true);
    }

    std::vector<common::CINNValue> res{
        common::CINNValue(ir_sch.GetModule().GetExprs().at(0))};
    *ret = common::CINNValuePack{res};
  });

  auto strategy = std::make_shared<framework::OpStrategy>();
  strategy->AddImpl(argsort_compute, argsort_schedule, "strategy.argsort", 1);
  return strategy;
}

340 341 342 343 344
// Infers the output shape of `sort`: a single output with the input's shape.
//
// The optional `axis` attribute (default 0) may be negative, counting from
// the last dimension, as Sort() itself supports. It must name an existing
// dimension. The range check is done on signed ints. Previously,
// CHECK_GT(size_t, int) converted a negative axis to a huge unsigned value,
// which wrongly rejected e.g. axis = -1.
std::vector<std::vector<int>> InferShapeForSort(
    const std::vector<std::vector<int>> &inputs_shape,
    const framework::AttrMapType &attrs) {
  CHECK_EQ(inputs_shape.size(), 1UL)
      << "The input's shape size should be 1! Please check again.";
  const int rank = static_cast<int>(inputs_shape[0].size());
  int axis = 0;
  auto axis_it = attrs.find("axis");
  if (axis_it != attrs.end()) {
    axis = absl::get<int>(axis_it->second);
  }
  if (axis < 0) {
    axis += rank;
  }
  CHECK_GE(axis, 0) << "The axis should not be less than -rank! ";
  CHECK_GT(rank, axis) << "The input's dim should be greater than axis! ";
  std::vector<std::vector<int>> res{inputs_shape[0]};
  return res;
}

358 359 360 361
// Infers the output dtype of `sort`: exactly one input is expected, and the
// single output reuses that input's dtype.
std::vector<Type> InferDtypeForSort(const std::vector<Type> &inputs_type,
                                    const framework::AttrMapType &attrs) {
  CHECK_EQ(inputs_type.size(), 1UL)
      << "The input's type size should be 1! Please check again.";
  const Type &value_type = inputs_type.front();
  return std::vector<Type>(1, value_type);
}

366 367 368 369 370
// Infers the output shapes of `argsort`: two outputs, both with the input's
// shape.
//
// The optional `axis` attribute (default 0) may be negative, counting from
// the last dimension, and must name an existing dimension. The bounds are
// checked as signed ints, so an axis below -rank is rejected explicitly. The
// old code relied on an implicit int -> size_t wrap inside CHECK_GT.
std::vector<std::vector<int>> InferShapeForArgSort(
    const std::vector<std::vector<int>> &inputs_shape,
    const framework::AttrMapType &attrs) {
  CHECK_EQ(inputs_shape.size(), 1UL)
      << "The input's shape size should be 1! Please check again.";
  const int rank = static_cast<int>(inputs_shape[0].size());
  int axis = 0;
  auto axis_it = attrs.find("axis");
  if (axis_it != attrs.end()) {
    axis = absl::get<int>(axis_it->second);
  }
  if (axis < 0) {
    axis += rank;
  }
  CHECK_GE(axis, 0) << "The axis should not be less than -rank! ";
  CHECK_GT(rank, axis) << "The input's dim should be greater than axis! ";

  std::vector<std::vector<int>> res{inputs_shape[0], inputs_shape[0]};
  return res;
}

388 389 390 391
// Infers the output dtypes of `argsort`. Exactly one input is expected, and
// both outputs are int32 regardless of the input dtype.
std::vector<Type> InferDtypeForArgSort(const std::vector<Type> &inputs_type,
                                       const framework::AttrMapType &attrs) {
  CHECK_EQ(inputs_type.size(), 1UL)
      << "The input's type size should be 1! Please check again.";
  std::vector<Type> index_types(2, Int(32));
  return index_types;
}

395 396 397 398 399 400
// Infers the output shapes of `top_k`: two outputs with identical shapes.
// Each is the input shape with the `axis` extent clamped to at most `k`.
//
// Both `k` and `axis` attributes are required. A negative `axis` counts from
// the last dimension and must resolve to an existing dimension.
std::vector<std::vector<int>> InferShapeForTopK(
    const std::vector<std::vector<int>> &inputs_shape,
    const framework::AttrMapType &attrs) {
  CHECK_EQ(inputs_shape.size(), 1UL)
      << "The input's shape size should be 1! Please check again.";

  const auto k_attr = attrs.find("k");
  CHECK(k_attr != attrs.end()) << "The attr k of topk does not exist.";
  const int k = absl::get<int>(k_attr->second);

  const auto axis_attr = attrs.find("axis");
  CHECK(axis_attr != attrs.end()) << "The attr axis of topk does not exist.";
  int axis = absl::get<int>(axis_attr->second);

  std::vector<int> out_shape = inputs_shape[0];
  if (axis < 0) {
    axis += out_shape.size();
  }
  CHECK_GE(axis, 0);
  CHECK_LT(axis, out_shape.size());
  out_shape[axis] = std::min(out_shape[axis], k);
  return {out_shape, out_shape};
}

416 417 418 419
// Infers the output dtypes of `top_k`. Exactly one input is expected. The
// first output keeps the input dtype and the second output is int64.
std::vector<Type> InferDtypeForTopK(const std::vector<Type> &inputs_type,
                                    const framework::AttrMapType &attrs) {
  CHECK_EQ(inputs_type.size(), 1UL)
      << "The input's type size should be 1! Please check again.";
  std::vector<Type> out_types;
  out_types.reserve(2);
  out_types.push_back(inputs_type[0]);
  out_types.push_back(Int(64));
  return out_types;
}

}  // namespace op
}  // namespace hlir
}  // namespace cinn

// Registers the `sort`, `argsort` and `top_k` ops with their shape/dtype
// inference functions. All three are marked non-fusible at support level 4.
CINN_REGISTER_HELPER(sort_ops) {
  // sort: one input, one declared output; lowered via StrategyForSort.
  CINN_REGISTER_OP(sort)
      .describe(
          "Sort a variable x along the given axis and return sorted Variable.")
      .set_num_inputs(1)
      .set_num_outputs(1)
      .set_attr<cinn::hlir::framework::StrategyFunction>(
          "CINNStrategy", cinn::hlir::op::StrategyForSort)
      .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForSort))
      .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForSort))
      .set_attr<cinn::hlir::framework::OpPatternKind>(
          "OpPattern", cinn::hlir::framework::OpPatternKind::kNonFusible)
      .set_support_level(4);

  // argsort: one input, two declared outputs; lowered via StrategyForArgSort.
  CINN_REGISTER_OP(argsort)
      .describe("Sort a variable x along the given axis and return indices.")
      .set_num_inputs(1)
      .set_num_outputs(2)
      .set_attr<cinn::hlir::framework::StrategyFunction>(
          "CINNStrategy", cinn::hlir::op::StrategyForArgSort)
      .set_attr("infershape",
                MakeOpFunction(cinn::hlir::op::InferShapeForArgSort))
      .set_attr("inferdtype",
                MakeOpFunction(cinn::hlir::op::InferDtypeForArgSort))
      .set_attr<cinn::hlir::framework::OpPatternKind>(
          "OpPattern", cinn::hlir::framework::OpPatternKind::kNonFusible)
      .set_support_level(4);

  // top_k: only shape/dtype inference is registered here (no CINNStrategy).
  // InferShapeForTopK requires an `axis` attribute, so the description refers
  // to the given axis rather than only the last dimension.
  CINN_REGISTER_OP(top_k)
      .describe(
          "Find values and indices of the k largest entries along the given "
          "axis.")
      .set_num_inputs(1)
      .set_num_outputs(2)
      .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForTopK))
      .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForTopK))
      .set_attr<cinn::hlir::framework::OpPatternKind>(
          "OpPattern", cinn::hlir::framework::OpPatternKind::kNonFusible)
      .set_support_level(4);

  return true;
}