transpose_function.cu.h 52.7 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

H
hong 已提交
17
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
18
#include "paddle/phi/backends/gpu/gpu_primitives.h"
19
#include "paddle/phi/backends/gpu/gpu_utils.h"
20
#include "paddle/phi/core/tensor_utils.h"
21
#include "paddle/phi/kernels/autotune/auto_tune_base.h"
22
#include "paddle/phi/kernels/funcs/aligned_vector.h"
23
#include "paddle/phi/kernels/funcs/dims_simplifier.h"
24
#include "paddle/phi/kernels/funcs/math_function.h"
25
#include "paddle/phi/kernels/primitive/datamover_primitives.h"
26

27 28
namespace phi {
namespace funcs {
29

30
// Short local name for the dense tensor type used throughout this header.
typedef phi::DenseTensor Tensor;
31 32 33 34 35 36 37 38 39 40 41

// Comparison functor passed to CheckProperTileSize: true iff both ints are
// equal. constexpr so it can be evaluated inside template arguments.
struct EqualTo {
  constexpr bool operator()(int lhs, int rhs) const {
    return lhs == rhs;
  }
};

// Comparison functor passed to CheckProperTileSize: true iff the first int is
// strictly greater than the second. constexpr for compile-time use.
struct GreaterThan {
  constexpr bool operator()(int lhs, int rhs) const {
    return rhs < lhs;
  }
};

// Evaluates `op` against a fixed table of tile shapes. Which table applies is
// chosen by the element byte size size_T. For the row whose long side equals
// tile_long, the result is op(tile_short, limit). The result is false when
// tile_long matches no row or size_T is not 1, 2, 4, 8 or 16.
// The function is constexpr so it can be evaluated at compile time, e.g. in
// the enable_if conditions of the NarrowDims2TransposeDispatch
// specializations.
//
//   long side (multiples of INT_32):     1   2   4   8  16
//   size_T == 16       short limit:      4   4   4   2   -
//   size_T in {8,4,2,1} short limit:    15  15   8   4   2
template <typename FUN, int INT_32 = 32>
constexpr bool CheckProperTileSize(int tile_long,
                                   int tile_short,
                                   int size_T,
                                   FUN op) {
  return (size_T == 16 && ((tile_long == INT_32 && op(tile_short, 4)) ||
                           (tile_long == 2 * INT_32 && op(tile_short, 4)) ||
                           (tile_long == 4 * INT_32 && op(tile_short, 4)) ||
                           (tile_long == 8 * INT_32 && op(tile_short, 2)))) ||
         (size_T == 8 && ((tile_long == INT_32 && op(tile_short, 15)) ||
                          (tile_long == 2 * INT_32 && op(tile_short, 15)) ||
                          (tile_long == 4 * INT_32 && op(tile_short, 8)) ||
                          (tile_long == 8 * INT_32 && op(tile_short, 4)) ||
                          (tile_long == 16 * INT_32 && op(tile_short, 2)))) ||
         ((size_T == 4 || size_T == 2 || size_T == 1) &&
          ((tile_long == INT_32 && op(tile_short, 15)) ||
           (tile_long == 2 * INT_32 && op(tile_short, 15)) ||
           (tile_long == 4 * INT_32 && op(tile_short, 8)) ||
           (tile_long == 8 * INT_32 && op(tile_short, 4)) ||
           (tile_long == 16 * INT_32 && op(tile_short, 2))));
}

// True when (tile_long, tile_short) is exactly one of the shapes in the
// CheckProperTileSize table for an element of size_T bytes.
constexpr bool CheckLongTileSize(int tile_long, int tile_short, int size_T) {
  return CheckProperTileSize<EqualTo>(
      tile_long, tile_short, size_T, EqualTo{});
}

// True when tile_long is a long side listed in the CheckProperTileSize table
// for size_T and tile_short exceeds that row's short-side limit.
constexpr bool CheckOutsideTileSize(int tile_long, int tile_short, int size_T) {
  return CheckProperTileSize<GreaterThan>(
      tile_long, tile_short, size_T, GreaterThan{});
}

// True when (tile_long, tile_short) meets three conditions:
//   - it is not a listed shape (CheckLongTileSize);
//   - it is not beyond a listed limit (CheckOutsideTileSize);
//   - one more step (incrementing tile_short or doubling tile_long) would be
//     beyond a listed limit.
// This selects the NarrowDims2TransposeDispatch specialization that only
// grows the short side.
constexpr bool CheckNonLongTileSize(int tile_long, int tile_short, int size_T) {
  return !CheckLongTileSize(tile_long, tile_short, size_T) &&
         !CheckOutsideTileSize(tile_long, tile_short, size_T) &&
         (CheckOutsideTileSize(tile_long, tile_short + 1, size_T) ||
          CheckOutsideTileSize(2 * tile_long, tile_short, size_T));
}

// Swaps dims 1 and 2 of a [d0, d1, d2] tensor, staging one TileX x TileY tile
// per block in shared memory. Both the global read and the global write walk
// the innermost dimension, so neighbouring threads touch neighbouring
// addresses (coalesced access).
//
// Launch contract (enforced by the device asserts below):
//   - 1-D block of exactly NumThreads threads;
//   - 1-D grid with one block per tile, i.e.
//     d0 * ceil(d1 / TileX) * ceil(d2 / TileY) blocks.
// Static shared memory: TileX * (TileY + 1) * sizeof(T) bytes.
template <typename T,
          int NumThreads,
          int TileX,
          int TileY,
          typename IndexType = int>
__global__ void TilingSwapDim1And2(const T* __restrict__ input,
                                   Dim3 input_dims,
                                   T* __restrict__ output) {
  assert(blockDim.x == NumThreads);
  assert(blockDim.y == 1);
  assert(blockDim.z == 1);
  assert(gridDim.y == 1);
  assert(gridDim.z == 1);

  // Tile rows covered by one pass of the whole block when loading / storing.
  constexpr int kRowsPerLoad = NumThreads / TileY;
  constexpr int kRowsPerStore = NumThreads / TileX;

  // Staging buffer laid out as TileX rows of (TileY + 1) elements. The spare
  // column avoids shared memory bank conflicts on the column-wise reads.
  __shared__ __align__(
      alignof(T)) char staging_bytes[TileX * (TileY + 1) * sizeof(T)];
  typedef T(*StagingRow)[TileY + 1];
  StagingRow staging = reinterpret_cast<StagingRow>(staging_bytes);

  const int tid = threadIdx.x;

  // Output shape: input shape with dims 1 and 2 exchanged.
  Dim3 output_dims = {input_dims[0], input_dims[2], input_dims[1]};

  // Tile counts per dimension (dim 0 is not tiled).
  Dim3 tiles_per_dim = {
      input_dims[0],
      (input_dims[1] + TileX - 1) / TileX,
      (input_dims[2] + TileY - 1) / TileY,
  };

  // Tile handled by this block, and the flat input offset of its first element.
  Index3 tile_coord =
      ConvertTensorIndex<IndexType>(blockIdx.x, tiles_per_dim);
  Index3 tile_origin_in = {
      tile_coord[0], tile_coord[1] * TileX, tile_coord[2] * TileY};
  IndexType in_base = FlatTensorIndex<IndexType>(tile_origin_in, input_dims);

  // Tiles on the trailing edge of dim 2 / dim 1 may be clipped.
  bool full_tile = true;
  IndexType tile_width = TileY;
  if (tile_coord[2] == tiles_per_dim[2] - 1) {
    tile_width = input_dims[2] - (tiles_per_dim[2] - 1) * TileY;
    full_tile = false;
  }
  IndexType tile_height = TileX;
  if (tile_coord[1] == tiles_per_dim[1] - 1) {
    tile_height = input_dims[1] - (tiles_per_dim[1] - 1) * TileX;
    full_tile = false;
  }

  // Load phase: only a whole number of TileY-wide rows of threads take part.
  constexpr IndexType kLoadThreads = NumThreads / TileY * TileY;
  if (tid < kLoadThreads) {
    const int row = tid / TileY;
    const int col = tid % TileY;
    IndexType src = in_base + row * input_dims[2] + col;
    const IndexType src_step = kRowsPerLoad * input_dims[2];
    if (full_tile) {
#pragma unroll
      for (int r = row; r < (TileX); r += kRowsPerLoad) {
        staging[r][col] = input[src];
        src += src_step;
      }
    } else if (col < tile_width) {
#pragma unroll
      for (IndexType r = row; r < (tile_height); r += kRowsPerLoad) {
        staging[r][col] = input[src];
        src += src_step;
      }
    }
  }

  // Every thread reaches this barrier; the loads above must be visible
  // before the transposed reads below.
  __syncthreads();

  // The same tile in the output has its dim-1/dim-2 coordinates exchanged.
  Index3 tile_origin_out = {
      tile_coord[0], tile_coord[2] * TileY, tile_coord[1] * TileX};
  IndexType out_base =
      FlatTensorIndex<IndexType>(tile_origin_out, output_dims);

  // Store phase: only a whole number of TileX-wide rows of threads take part.
  constexpr IndexType kStoreThreads = NumThreads / TileX * TileX;
  if (tid < kStoreThreads) {
    const int row = tid / TileX;
    const int col = tid % TileX;
    IndexType dst = out_base + row * output_dims[2] + col;
    const IndexType dst_step = kRowsPerStore * output_dims[2];
    if (full_tile) {
#pragma unroll
      for (int r = row; r < (TileY); r += kRowsPerStore) {
        output[dst] = staging[col][r];
        dst += dst_step;
      }
    } else if (col < tile_height) {
#pragma unroll
      for (IndexType r = row; r < (tile_width); r += kRowsPerStore) {
        output[dst] = staging[col][r];
        dst += dst_step;
      }
    }
  }
}

// Collects candidate (long_side, short_side) tile shapes into *tiles for an
// element of TSIZE bytes.
// Long sides are powers of two from 32 up to 1024. For each one, the first
// short side in [2, 15] accepted by CheckLongTileSize is appended.
// Enumeration stops as soon as a candidate with short side 2 is appended.
//
// Returns true if such a short-side-2 candidate was found, false otherwise.
// Fails with InvalidArgument (via PADDLE_ENFORCE) when TSIZE > 16 or TSIZE
// is not a power of two.
template <int TSIZE>
bool SelectProperTileSize(std::vector<std::pair<int, int>>* tiles) {
  PADDLE_ENFORCE_LE(
      TSIZE,
      16,
      phi::errors::InvalidArgument(
          "The data type size should be no larger than 16, but received "
          "is:%d.",
          TSIZE));

  PADDLE_ENFORCE_EQ(
      (TSIZE & (TSIZE - 1)),
      0,
      phi::errors::InvalidArgument(
          "Data types should be powers of 2, but received size is:%d.",
          TSIZE));

  const int kMaxLongSideLen = 1024;
  const int kMaxShortSideLen = 15;

  for (int long_side = 32; long_side <= kMaxLongSideLen; long_side *= 2) {
    for (int short_side = 2; short_side <= kMaxShortSideLen; short_side += 1) {
      if (CheckLongTileSize(long_side, short_side, TSIZE)) {
        tiles->push_back(std::make_pair(long_side, short_side));

        // A short side of 2 is the narrowest shape in the table; stop here.
        if (short_side == 2) return true;

        // Only the first accepted short side per long side is kept.
        break;
      }
    }
  }
  return false;
}

// Maps a byte size to a built-in type of exactly that size, so elements of any
// T can be moved as raw bits. Only 1, 2, 4, 8 and 16 are specialized. Using
// ::type with any other size fails to compile, because the primary template
// is declared but never defined.
template <int ByteSize>
struct SystemElemType;

template <>
struct SystemElemType<1> {
  typedef uint8_t type;
};

template <>
struct SystemElemType<2> {
  typedef uint16_t type;
};

template <>
struct SystemElemType<4> {
  typedef uint32_t type;
};

template <>
struct SystemElemType<8> {
  typedef uint64_t type;
};

template <>
struct SystemElemType<16> {
  typedef float4 type;  // CUDA built-in 16-byte vector type.
};

280
// Launches TilingSwapDim1And2 on d's stream with tile_long threads per block
// and total_tiles_count blocks.
// If the requested sizes fit with the long side along dim 1
// (tile_size_i <= tile_long and tile_size_j <= tile_short), the tile is
// tile_long x tile_short (TileX x TileY). Otherwise it is
// tile_short x tile_long.
// No launch-error check is performed here.
template <typename T, int tile_long, int tile_short, typename IndexType = int>
void LaunchNarrowDims2TransposeKernel(const phi::GPUContext& d,
                                      int tile_size_i,
                                      int tile_size_j,
                                      IndexType total_tiles_count,
                                      const T* input,
                                      const Dim3& input_dims,
                                      T* output) {
  constexpr int NumThreads = tile_long;
  const bool long_side_on_dim1 =
      tile_size_i <= tile_long && tile_size_j <= tile_short;

  if (!long_side_on_dim1) {
    TilingSwapDim1And2<T, NumThreads, tile_short, tile_long, IndexType>
        <<<total_tiles_count, NumThreads, 0, d.stream()>>>(
            input, input_dims, output);
    return;
  }

  TilingSwapDim1And2<T, NumThreads, tile_long, tile_short, IndexType>
      <<<total_tiles_count, NumThreads, 0, d.stream()>>>(
          input, input_dims, output);
}

300 301 302
// Compile-time search for a tile shape large enough for the request
// (tile_size_i x tile_size_j, in either orientation). This primary template
// is used while (tile_long, tile_short) matches neither partial
// specialization below.
// Each step does one of three things:
//   - if the requested long side exceeds tile_long, recurse with 2*tile_long;
//   - else if the requested short side exceeds tile_short, recurse with
//     tile_short+1;
//   - otherwise launch with the current shape.
template <typename T,
          int tile_long,
          int tile_short,
          typename IndexType = int,
          typename dummy = void>
struct NarrowDims2TransposeDispatch {
  static void DoTranspose(const phi::GPUContext& d,
                          int tile_size_i,
                          int tile_size_j,
                          IndexType total_tiles_count,
                          const T* input,
                          const Dim3& input_dims,
                          T* output) {
    PADDLE_ENFORCE_EQ(
        (tile_long & (tile_long - 1)),
        0,
        phi::errors::InvalidArgument(
            "The length of the longer side of the tile should be power of 2."
            " But received value is:%d.",
            tile_long));

    const int requested_long = std::max(tile_size_i, tile_size_j);
    const int requested_short = std::min(tile_size_i, tile_size_j);

    if (requested_long > tile_long) {
      // Long side too small: double it and keep the short side.
      NarrowDims2TransposeDispatch<T, tile_long * 2, tile_short, IndexType>::
          DoTranspose(d,
                      tile_size_i,
                      tile_size_j,
                      total_tiles_count,
                      input,
                      input_dims,
                      output);
    } else if (requested_short > tile_short) {
      // Long side fits, short side does not: widen the short side by one.
      NarrowDims2TransposeDispatch<T, tile_long, tile_short + 1, IndexType>::
          DoTranspose(d,
                      tile_size_i,
                      tile_size_j,
                      total_tiles_count,
                      input,
                      input_dims,
                      output);
    } else {
      // Both sides fit the current shape.
      LaunchNarrowDims2TransposeKernel<T, tile_long, tile_short, IndexType>(
          d,
          tile_size_i,
          tile_size_j,
          total_tiles_count,
          input,
          input_dims,
          output);
    }
  }
};

// Partial specialization selected at compile time when
// CheckNonLongTileSize(tile_long, tile_short, sizeof(T)) holds, i.e. the
// long side cannot usefully grow further. It launches when the request fits
// the current shape. Otherwise it only grows the short side by one.
template <typename T, int tile_long, int tile_short, typename IndexType>
struct NarrowDims2TransposeDispatch<
    T,
    tile_long,
    tile_short,
    IndexType,
    typename std::enable_if<CheckNonLongTileSize(
                                tile_long, tile_short, sizeof(T)),
                            void>::type> {
  static void DoTranspose(const phi::GPUContext& d,
                          int tile_size_i,
                          int tile_size_j,
                          IndexType total_tiles_count,
                          const T* input,
                          const Dim3& input_dims,
                          T* output) {
    PADDLE_ENFORCE_EQ(
        (tile_long & (tile_long - 1)),
        0,
        phi::errors::InvalidArgument(
            "The length of the longer side of the tile should be power of 2."
            " But received value is:%d.",
            tile_long));

    const bool fits_current_shape =
        std::max(tile_size_i, tile_size_j) <= tile_long &&
        std::min(tile_size_i, tile_size_j) <= tile_short;

    if (!fits_current_shape) {
      NarrowDims2TransposeDispatch<T, tile_long, tile_short + 1, IndexType>::
          DoTranspose(d,
                      tile_size_i,
                      tile_size_j,
                      total_tiles_count,
                      input,
                      input_dims,
                      output);
      return;
    }

    LaunchNarrowDims2TransposeKernel<T, tile_long, tile_short, IndexType>(
        d,
        tile_size_i,
        tile_size_j,
        total_tiles_count,
        input,
        input_dims,
        output);
  }
};

// Terminal partial specialization, selected at compile time when
// (tile_long, tile_short) is exactly one of the listed shapes
// (CheckLongTileSize). It launches with this shape unconditionally and does
// not compare it against the requested sizes.
template <typename T, int tile_long, int tile_short, typename IndexType>
struct NarrowDims2TransposeDispatch<
    T,
    tile_long,
    tile_short,
    IndexType,
    typename std::enable_if<CheckLongTileSize(tile_long, tile_short, sizeof(T)),
                            void>::type> {
  static void DoTranspose(const phi::GPUContext& d,
                          int tile_size_i,
                          int tile_size_j,
                          IndexType total_tiles_count,
                          const T* input,
                          const Dim3& input_dims,
                          T* output) {
    // Non-zero exactly when tile_long is not a power of two.
    constexpr int kLongSideLowBits = tile_long & (tile_long - 1);
    PADDLE_ENFORCE_EQ(
        kLongSideLowBits,
        0,
        phi::errors::InvalidArgument(
            "The length of the longer side of the tile should be power of 2,"
            " but received is:%d.",
            tile_long));

    LaunchNarrowDims2TransposeKernel<T, tile_long, tile_short, IndexType>(
        d, tile_size_i, tile_size_j, total_tiles_count, input, input_dims,
        output);
  }
};

447
// Swaps dims 1 and 2 of a [d0, d1, d2] tensor using a rectangular
// (long x short) tile. It is chosen by the caller when one of dims 1/2 is
// long and the other is not.
//
// d:            device context; the kernel is launched on d.stream().
// input/output: element buffers of the input / transposed output.
// input_dims:   input shape [d0, d1, d2].
// kMinTileSize: if input_dims[1] >= kMinTileSize, the long tile side is placed
//               on dim 1; otherwise it is placed on dim 2.
// conjugate:    not used by this function.
//
// Elements are reinterpreted as SystemElemType<sizeof(T)>::type, so only the
// byte size of T reaches the kernel.
template <typename T, bool conjugate = false, typename IndexType = int>
void SwapDim1And2InNarrow(const phi::GPUContext& d,
                          const T* input,
                          const Dim3& input_dims,
                          T* output,
                          const int kMinTileSize) {
  // Candidate (long, short) tile shapes for this element size.
  std::vector<std::pair<int, int>> tile_sele;
  auto ret = SelectProperTileSize<sizeof(T)>(&tile_sele);
  PADDLE_ENFORCE_EQ(
      ret,
      true,
      phi::errors::InvalidArgument(
          "SelectProperTileSize should return true, but return value is:%d.",
          ret));

  int tile_long_edge = 0;
  int tile_short_edge = 0;
  float lowest_cost = std::numeric_limits<float>::max();
  int input_long_edge = std::max(input_dims[1], input_dims[2]);

  // Pick the candidate whose long side leaves the fewest elements of the
  // input's long edge outside whole tiles. Assuming CeilOrFloor<int, false>
  // is floor division, the cost below is input_long_edge % long_side.
  // Ties go to the later (longer) candidate, and a zero cost stops the search.
  for (const auto& tile_size_pair : tile_sele) {
    const int proposed_tile_long_edge = tile_size_pair.first;
    const int num_wasted_threads =
        input_long_edge -
        CeilOrFloor<int, false>(input_long_edge, proposed_tile_long_edge) *
            proposed_tile_long_edge;

    float cost = num_wasted_threads;

    if (cost <= lowest_cost) {
      tile_long_edge = proposed_tile_long_edge;
      tile_short_edge = tile_size_pair.second;
      lowest_cost = cost;
    }
    // The input's long edge is an exact multiple; nothing can be better.
    if (cost == 0) break;
  }

  // Orient the tile: the long side goes on dim 1 if dim1 >= kMinTileSize,
  // otherwise on dim 2.
  int select_tile_size_i =
      input_dims[1] >= kMinTileSize ? tile_long_edge : input_dims[1];
  int select_tile_size_j =
      input_dims[1] >= kMinTileSize ? input_dims[2] : tile_long_edge;

  // The side that is not the long edge is clamped to the short edge.
  select_tile_size_i = select_tile_size_i == tile_long_edge
                           ? tile_long_edge
                           : std::min(select_tile_size_i, tile_short_edge);
  select_tile_size_j = select_tile_size_j == tile_long_edge
                           ? tile_long_edge
                           : std::min(select_tile_size_j, tile_short_edge);

  // Number of tiles per dimension for the selected tile size.
  Dim3 input_dims_aligned = {
      input_dims[0],
      CeilOrFloor<int, true>(input_dims[1], select_tile_size_i),
      CeilOrFloor<int, true>(input_dims[2], select_tile_size_j),
  };

  // Accumulate in IndexType so the product uses the wider type when
  // IndexType is 64-bit.
  IndexType total_tiles_count = input_dims_aligned[0];
  total_tiles_count *= input_dims_aligned[1];
  total_tiles_count *= input_dims_aligned[2];

  // Move T as a same-sized built-in type.
  using ElemType = typename SystemElemType<sizeof(T)>::type;

  // Start the compile-time tile-shape search from the smallest shape (32 x 2).
  NarrowDims2TransposeDispatch<ElemType, 32, 2, IndexType>::DoTranspose(
      d,
      select_tile_size_i,
      select_tile_size_j,
      total_tiles_count,
      reinterpret_cast<const ElemType*>(input),
      input_dims,
      reinterpret_cast<ElemType*>(output));
}

// Fallback element-wise permutation of a 3-D tensor, used when a tiled copy
// cannot give coalesced reads and writes, or the input is too small to tile.
// Input dim k becomes output dim pos_k.
// Each output index in [0, nthreads) is mapped back to its input coordinate
// and copied. Iteration over output indices is delegated to
// CUDA_KERNEL_LOOP_TYPE (presumably a grid-stride loop — see its definition).
template <typename T, int pos0, int pos1, int pos2, typename IndexType = int>
__global__ void TransposeSimpleKernel(IndexType nthreads,
                                      const T* __restrict__ input,
                                      Dim3 input_dims,
                                      T* __restrict__ output) {
  Dim3 out_dims;
  out_dims[pos0] = input_dims[0];
  out_dims[pos1] = input_dims[1];
  out_dims[pos2] = input_dims[2];

  CUDA_KERNEL_LOOP_TYPE(out_idx, nthreads, IndexType) {
    Index3 out_coord = ConvertTensorIndex<IndexType>(out_idx, out_dims);

    // Undo the permutation to recover the source coordinate.
    Index3 in_coord;
    in_coord[0] = out_coord[pos0];
    in_coord[1] = out_coord[pos1];
    in_coord[2] = out_coord[pos2];

    output[out_idx] =
        input[FlatTensorIndex<IndexType>(in_coord, input_dims)];
  }
}

// Swaps dims 1 and 2 of a [d0, d1, d2] tensor, choosing one of three paths:
//   - both dims 1 and 2 >= 16: square 32x32 tiles, 256 threads per block;
//   - otherwise, either dim >= 96: rectangular tiles (SwapDim1And2InNarrow);
//   - otherwise: element-wise TransposeSimpleKernel.
// All launches go to d.stream().
template <typename T, typename IndexType = int>
void SendSwapDim1And2InTranspose(const phi::GPUContext& d,
                                 const T* input,
                                 const Dim3& input_dims,
                                 T* output) {
  static const int kMinTileSize = 16;
  static const int kMinNarrowTileSize = 96;

  const bool both_dims_large =
      input_dims[1] >= kMinTileSize && input_dims[2] >= kMinTileSize;
  if (both_dims_large) {
    // Square tiles; 256 threads = 8 warps per block.
    constexpr int kTileSize = 32;
    constexpr int kNumThreads = 256;

    Dim3 tiles_per_dim = {
        input_dims[0],
        CeilOrFloor<int, true>(input_dims[1], kTileSize),
        CeilOrFloor<int, true>(input_dims[2], kTileSize),
    };

    IndexType num_tiles = tiles_per_dim[0];
    num_tiles *= tiles_per_dim[1];
    num_tiles *= tiles_per_dim[2];

    TilingSwapDim1And2<T, kNumThreads, kTileSize, kTileSize, IndexType>
        <<<num_tiles, kNumThreads, 0, d.stream()>>>(input, input_dims, output);
    return;
  }

  const bool one_dim_long = input_dims[1] >= kMinNarrowTileSize ||
                            input_dims[2] >= kMinNarrowTileSize;
  if (one_dim_long) {
    // Rectangular shapes such as 2 x 100: search for a tile that covers the
    // input while keeping accesses coalesced.
    SwapDim1And2InNarrow<T, false, IndexType>(
        d, input, input_dims, output, kMinTileSize);
    return;
  }

  // Small inputs (e.g. 8 x 8): plain element-wise copy.
  IndexType numel = input_dims[0];
  numel *= input_dims[1];
  numel *= input_dims[2];
  auto config = phi::backends::gpu::GetGpuLaunchConfig1D(d, numel);
  TransposeSimpleKernel<T, 0, 2, 1, IndexType>
      <<<config.block_per_grid.x, config.thread_per_block.x, 0, d.stream()>>>(
          numel, input, input_dims, output);
}

615
// Functor: swaps dims 1 and 2 of a tensor whose (already combined) shape is
// given by the first three entries of combined_dims.
template <typename T, typename IndexType = int>
struct SwapDim1And2InTranspose {
  using Device = phi::GPUContext;

  void operator()(const Device& d,
                  const T* in,
                  const std::vector<int>& combined_dims,
                  T* out) {
    const Dim3 shape = {static_cast<int>(combined_dims[0]),
                        static_cast<int>(combined_dims[1]),
                        static_cast<int>(combined_dims[2])};
    SendSwapDim1And2InTranspose<T, IndexType>(d, in, shape, out);
  }
};

629
// Functor: swaps dims 0 and 2 of a tensor whose (already combined) shape is
// given by the first three entries of combined_dims. It uses the element-wise
// TransposeSimpleKernel on d's stream.
template <typename T, typename IndexType = int>
struct SwapDim0And2InTranspose {
  using Device = phi::GPUContext;

  void operator()(const Device& d,
                  const T* in,
                  const std::vector<int>& combined_dims,
                  T* out) {
    const Dim3 shape = {static_cast<int>(combined_dims[0]),
                        static_cast<int>(combined_dims[1]),
                        static_cast<int>(combined_dims[2])};

    // Accumulate in IndexType so a 64-bit IndexType avoids int overflow.
    IndexType numel = combined_dims[0];
    numel *= combined_dims[1];
    numel *= combined_dims[2];

    auto config = phi::backends::gpu::GetGpuLaunchConfig1D(d, numel);
    TransposeSimpleKernel<T, 2, 1, 0, IndexType>
        <<<config.block_per_grid.x, config.thread_per_block.x, 0, d.stream()>>>(
            numel, in, shape, out);
  }
};

// Merges input axes that stay adjacent and in order under `perm` into single
// axes, producing an equivalent lower-rank permutation. For example:
// (0, 1, 3, 2) --> (0, 2, 1), because input axes 0 and 1 move together.
//
// shape:    input shape; must have the same rank as perm.
// perm:     output axis i is input axis perm[i].
// new_perm: receives the permutation over the merged axes.
// new_dims: receives the sizes of the merged axes, in input-axis order.
//
// Fails with InvalidArgument (PADDLE_ENFORCE_EQ) if the ranks differ.
// A rank-0 input yields empty new_perm and new_dims.
inline void CombineTransposeDim3(const DDim& shape,
                                 const std::vector<int>& perm,
                                 std::vector<int>* new_perm,
                                 std::vector<int>* new_dims) {
  PADDLE_ENFORCE_EQ(shape.size(),
                    perm.size(),
                    phi::errors::InvalidArgument(
                        " shape should have the same dim with perm, but"
                        " received shape size is:%d, perm size is:%d.",
                        shape.size(),
                        perm.size()));

  std::vector<int> dim_vec;
  if (shape.size() == 0) {
    // Rank 0: nothing to merge, and perm[0] must not be read.
    new_perm->clear();
  } else if (shape.size() == 1) {
    // A single axis cannot be merged further.
    new_perm->resize(1);
    (*new_perm)[0] = perm[0];
    dim_vec.push_back(shape[0]);
  } else {
    int dim_idx = 0;
    // new_dim_pos[input_axis] = merged-axis index (in output order) that
    // starts at input_axis, or -1 if input_axis was merged into a previous one.
    std::vector<int> new_dim_pos(shape.size(), -1);
    std::vector<int> combined_dims(shape.size(), 0);

    int cur_head = perm[0];
    new_dim_pos[cur_head] = 0;
    combined_dims[0] = shape[cur_head];
    for (int perm_idx = 1; perm_idx < shape.size(); ++perm_idx) {
      if (cur_head + 1 == perm[perm_idx]) {
        // Next output axis is the next input axis: extend the current run.
        cur_head = perm[perm_idx];
        combined_dims[dim_idx] *= shape[cur_head];
      } else {
        // Otherwise start a new merged axis.
        cur_head = perm[perm_idx];
        dim_idx++;
        new_dim_pos[cur_head] = dim_idx;
        combined_dims[dim_idx] = shape[cur_head];
      }
    }

    new_perm->resize(dim_idx + 1);

    // Walk input axes in order; each run head contributes one merged axis.
    dim_idx = 0;
    for (int i = 0; i < static_cast<int>(new_dim_pos.size()); ++i) {
      if (new_dim_pos[i] >= 0) {
        int new_perm_idx = new_dim_pos[i];
        (*new_perm)[dim_idx] = new_perm_idx;
        dim_vec.push_back(combined_dims[new_perm_idx]);
        dim_idx++;
      }
    }
  }
  *new_dims = dim_vec;
}

707
// Fast path for transposes that reduce (after CombineTransposeDim3) to one
// of these patterns:
//   - a 2-D swap;
//   - a swap of the last two of three dims, perm (0, 2, 1);
//   - a swap of the outer dims, perm (2, 1, 0).
// Run returns true if the transpose was launched, false if the permutation
// is not handled here.
template <typename T>
struct TransposeSimple {
  // Uses 64-bit indexing when numel >= INT32_MAX, 32-bit otherwise.
  static bool Run(const phi::GPUContext& ctx,
                  const phi::DenseTensor& in,
                  const std::vector<int32_t>& perm,
                  phi::DenseTensor* out,
                  const int64_t numel) {
    const bool needs_wide_index =
        numel >= std::numeric_limits<int32_t>::max();
    return needs_wide_index ? RunImpl<int64_t>(ctx, in, perm, out)
                            : RunImpl<int32_t>(ctx, in, perm, out);
  }

 private:
  template <typename IndexType = int32_t>
  static bool RunImpl(const phi::GPUContext& ctx,
                      const phi::DenseTensor& in,
                      const std::vector<int32_t>& perm,
                      phi::DenseTensor* out) {
    auto src = in.data<T>();
    auto dst = out->data<T>();

    // Reduce the rank by merging axes that move together.
    std::vector<int> merged_perm;
    std::vector<int> merged_dims;
    CombineTransposeDim3(in.dims(), perm, &merged_perm, &merged_dims);
    if (merged_perm.size() < 2 || merged_perm.size() > 3) return false;

    // A 2-D swap is handled as a 3-D (0, 2, 1) swap with a leading dim of 1.
    const bool is_2d_swap = merged_perm.size() == 2 && merged_perm[1] == 0;
    if (is_2d_swap) merged_dims.insert(merged_dims.begin(), 1);

    if (is_2d_swap || merged_perm == std::vector<int>({0, 2, 1})) {
      SwapDim1And2InTranspose<T, IndexType>()(ctx, src, merged_dims, dst);
      return true;
    }
    if (merged_perm == std::vector<int>({2, 1, 0})) {
      // Element-wise for now; a coalesced variant might pay off depending on
      // the span of the data.
      SwapDim0And2InTranspose<T, IndexType>()(ctx, src, merged_dims, dst);
      return true;
    }
    return false;
  }
};

756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871
// Kind of permutation kernel chosen by PermTypeClassifier. Values start at 1
// and increase by one in declaration order.
enum PermuteType {
  kCopy = 1,
  kSwapTranspose,
  kGeneralTranspose,
  kVecPermute,
  kGeneralPermute
};

// Tiling parameters for the general transpose path.
constexpr int kBlockRows = 16;
constexpr int kTileSize = 32;
// Shared-memory row length: one padding column past the tile width.
constexpr int kShareCol = (kTileSize + 1);

// Number of ALIGN_-sized tiles needed to cover LEN_ elements, i.e.
// ceil(LEN_ / ALIGN_). ALIGN_ must be a power of two (the bit mask relies on
// it). Arguments and the whole expansion are parenthesized. This keeps uses
// such as `x / GET_TILE_SIZE(n, a)` and conditional-expression arguments
// correct.
#define GET_TILE_SIZE(LEN_, ALIGN_) \
  ((((LEN_) + ((ALIGN_)-1)) & ~((ALIGN_)-1)) / (ALIGN_))

template <typename T>
struct PermTypeClassifier {
 public:
  // Chooses a PermuteType for an already-simplified permutation and the
  // vector width (in elements) the selected kernel should use.
  //   sm_count : number of SMs; when the transpose grid would have fewer
  //              tiles than SMs, vectorization is disabled (vec size 1).
  //   rank, perm, dims : simplified rank, permutation and source shape.
  //   src, dst : data pointers; their alignment bounds the vector width via
  //              phi::GetVectorizedSize.
  PermTypeClassifier(const int sm_count,
                     const int rank,
                     const std::vector<int32_t>& perm,
                     const std::vector<int64_t>& dims,
                     const T* src,
                     T* dst) {
    if (rank == 1) {
      type_ = PermuteType::kCopy;
    } else {
      // Limitation of the setting in one dimension of cuda grid (the channel
      // count below becomes gridDim.z of the transpose launch).
      constexpr int64_t dim_limitation = 65536;
      int dst_vec_size = phi::GetVectorizedSize<T>(dst);

      // While the last dim is fixed, there is chance for vectorized IO.
      const int last_idx = rank - 1;
      if (perm[last_idx] == last_idx) {
        type_ = PermuteType::kVecPermute;
        vec_size_ = GetDimVecSize(dst_vec_size, dims[last_idx], src, false);
        return;
      }

      // Permute at last 2 dims, namely transpose: perm [1, 0] or [0, 2, 1].
      if ((rank == 2 && perm[1] == 0) ||
          (rank == 3 && perm[2] == 1 && perm[1] == 2)) {
        int64_t channel = rank == 2 ? 1 : dims[0];
        // Currently, transpose kernel cannot cover the case that channel
        // dimension is more than 65536 which is the limitation of dim3 setting.
        // This special case will be covered by extended transpose kernel later.
        // NOTE(review): num_rows_tile_ is also used (possibly reduced) as
        // gridDim.y of the transpose launch but is not checked against
        // dim_limitation here — confirm very tall inputs cannot reach this.
        if (channel < dim_limitation) {
          type_ = PermuteType::kGeneralTranspose;
          num_rows_tile_ = GET_TILE_SIZE(dims[rank - 2], kTileSize);
          int dim_vec_size = GetDimVecSize(dst_vec_size, dims[last_idx], src);
          // Kept in int64_t: the tile count is a product of int64 values and
          // must not be narrowed to int before the comparison.
          const int64_t tile_size =
              channel * num_rows_tile_ *
              (GET_TILE_SIZE(dims[last_idx], kTileSize));
          vec_size_ = tile_size < sm_count ? 1 : dim_vec_size;
        } else {
          type_ = PermuteType::kGeneralPermute;
        }
        return;
      }

      // Permute at first dim and third dim: perm [2, 1, 0].
      if (rank == 3 && perm[2] == 0 && perm[1] == 1) {
        // Currently, transpose kernel cannot cover the case that channel
        // dimension is more than 65536 which is the limitation of dim3 setting.
        // This special case will be covered by extended transpose kernel later.
        if (dims[1] < dim_limitation) {
          type_ = PermuteType::kSwapTranspose;
          num_rows_tile_ = GET_TILE_SIZE(dims[0], kTileSize);

          int dim_vec_size = GetDimVecSize(dst_vec_size, dims[last_idx], src);
          // Kept in int64_t for the same reason as above.
          const int64_t tile_size =
              dims[1] * num_rows_tile_ * (GET_TILE_SIZE(dims[2], kTileSize));
          vec_size_ = tile_size < sm_count ? 1 : dim_vec_size;
        } else {
          type_ = PermuteType::kGeneralPermute;
        }
        return;
      }

      // General permutation: only the writes are vectorized, so the width is
      // bounded by the destination alignment alone.
      vec_size_ = dst_vec_size;
    }
  }

  ~PermTypeClassifier() = default;

  int GetVecSize() const { return vec_size_; }
  int GetRowsTile() const { return num_rows_tile_; }
  PermuteType GetPermType() const { return type_; }

 private:
  int vec_size_{1};
  int64_t num_rows_tile_{0};
  PermuteType type_{kGeneralPermute};

  // Returns the largest power-of-two width, not above the vector size
  // allowed by both dst (dst_vec_size) and src alignment, that evenly
  // divides target_dim. When use_share_mem is true (the result feeds the
  // shared-memory transpose kernels), element types wider than float are
  // limited to width 1.
  int GetDimVecSize(const int dst_vec_size,
                    const int64_t target_dim,
                    const T* src,
                    bool use_share_mem = true) {
    int vec_size = std::min(dst_vec_size, phi::GetVectorizedSize<T>(src));
    int dim_vec_size = 1;
    for (int size = vec_size; size > 0; size /= 2) {
      if (target_dim % size == 0) {
        dim_vec_size = size;
        break;
      }
    }

    if (use_share_mem) {
      // By bytes limitation of shared_memory.
      return (sizeof(T) > sizeof(float) ? 1 : dim_vec_size);
    } else {
      return dim_vec_size;
    }
  }
};

template <typename IndexT, int Rank>
class IdxHelper {
 public:
  IdxHelper() {}

  // Builds row-major strides for a tensor of shape dims[0 .. Rank-1]:
  // the innermost stride is 1 and every outer stride is the product of all
  // inner extents.
  explicit IdxHelper(const IndexT* dims) {
    stride_[Rank - 1] = 1;
    for (int axis = Rank - 2; axis >= 0; --axis) {
      stride_[axis] = dims[axis + 1] * stride_[axis + 1];
    }
  }

  __device__ __forceinline__ IndexT GetStride(int idx) const {
    return stride_[idx];
  }

  // Splits a linear row-major offset into one coordinate per axis.
  __device__ __forceinline__ void GetIndexFromOffset(IndexT offset,
                                                     IndexT* index) const {
#pragma unroll
    for (int axis = 0; axis < Rank - 1; ++axis) {
      const IndexT coord = offset / stride_[axis];
      index[axis] = coord;
      offset -= coord * stride_[axis];
    }
    index[Rank - 1] = offset;
  }

 private:
  IndexT stride_[Rank];
};

template <int Rank>
class IdxHelper<uint32_t, Rank> {
 public:
  IdxHelper() {}

  // 32-bit specialization: alongside the row-major strides, keeps one
  // phi::kps::details::FastDivMod per stride, which GetIndexFromOffset uses
  // in place of the `/` operator.
  explicit IdxHelper(const uint32_t* dims) {
    uint32_t running = 1;
    for (int axis = Rank - 1; axis >= 0; --axis) {
      stride_[axis] = running;
      divmoder_[axis] = phi::kps::details::FastDivMod(running);
      running *= dims[axis];
    }
  }

  __device__ __forceinline__ uint32_t GetStride(int idx) const {
    return stride_[idx];
  }

  // Splits a linear row-major offset into one coordinate per axis.
  __device__ __forceinline__ void GetIndexFromOffset(uint32_t offset,
                                                     uint32_t* index) const {
#pragma unroll
    for (int axis = 0; axis < Rank - 1; ++axis) {
      const uint32_t coord = divmoder_[axis].Div(offset);
      offset -= coord * stride_[axis];
      index[axis] = coord;
    }
    index[Rank - 1] = offset;
  }

 private:
  uint32_t stride_[Rank];
  phi::kps::details::FastDivMod divmoder_[Rank];
};

// Converts between a linear memory offset and a per-axis coordinate for a
// row-major tensor of rank Rank.
template <typename IndexT, int Rank>
class IdxAndOffsetHelper {
 public:
  IdxAndOffsetHelper() {}

  explicit IdxAndOffsetHelper(const IndexT* dims) : index_helper(dims) {}

  // Linearizes per-axis coordinates into a row-major offset (the innermost
  // axis has stride 1).
  __device__ __forceinline__ IndexT IndexToOffset(const IndexT* index) const {
    IndexT offset = index[Rank - 1];
#pragma unroll
    for (int axis = 0; axis < Rank - 1; ++axis) {
      offset += index[axis] * index_helper.GetStride(axis);
    }
    return offset;
  }

  // Inverse of IndexToOffset.
  __device__ __forceinline__ void OffsetToIndex(IndexT offset,
                                                IndexT* index) const {
    index_helper.GetIndexFromOffset(offset, index);
  }

 private:
  IdxHelper<IndexT, Rank> index_helper;
};

template <typename IndexT, int Rank>
struct PermuteParams {
 public:
  IdxAndOffsetHelper<IndexT, Rank> src_index_helper;
  IdxAndOffsetHelper<IndexT, Rank> dst_index_helper;
  int perm[Rank]{};

  // Captures the permutation plus index helpers for the source shape `dims`
  // and the permuted destination shape (dst axis i has extent
  // dims[perm_[i]]). The struct is passed by value to the permute kernels.
  explicit PermuteParams(const std::vector<int64_t>& dims,
                         const std::vector<int>& perm_) {
    IndexT src_dims[Rank];
    IndexT dst_dims[Rank];
    for (int axis = 0; axis < Rank; ++axis) {
      const int from = perm_[axis];
      perm[axis] = from;
      src_dims[axis] = static_cast<IndexT>(dims[axis]);
      dst_dims[axis] = static_cast<IndexT>(dims[from]);
    }
    src_index_helper = IdxAndOffsetHelper<IndexT, Rank>(src_dims);
    dst_index_helper = IdxAndOffsetHelper<IndexT, Rank>(dst_dims);
  }
};

// A special kernel for the case where the last axis is not permuted: the
// innermost dimension is addressed in VecSize-element vectors, so both the
// reads and the writes are vectorized.
//
// Launch: 1-D grid; threads walk the `count` output vectors with a
// grid-stride loop, so any grid size produces the full result.
template <typename T, typename IndexT, int VecSize, int Rank>
__global__ void VectorizedPermuteKernel(PermuteParams<IndexT, Rank> params,
                                        const IndexT count,
                                        const T* __restrict__ src_data,
                                        T* dst_data) {
  using VecT = phi::AlignedVector<T, VecSize>;
  IndexT src_index[Rank];
  IndexT dst_index[Rank];

  const VecT* __restrict__ vec_src =
      reinterpret_cast<const VecT* __restrict__>(src_data);
  VecT* vec_dst = reinterpret_cast<VecT*>(dst_data);

  // Widen before multiplying: blockIdx.x * blockDim.x and
  // blockDim.x * gridDim.x are otherwise evaluated in 32-bit unsigned
  // arithmetic and wrap for the int64_t instantiation once the launch spans
  // more than 2^32 threads.
  const IndexT tid =
      static_cast<IndexT>(blockIdx.x) * blockDim.x + threadIdx.x;
  const IndexT stride = static_cast<IndexT>(blockDim.x) * gridDim.x;
  for (IndexT i = tid; i < count; i += stride) {
    params.dst_index_helper.OffsetToIndex(i, dst_index);

#pragma unroll
    for (int j = 0; j < Rank; ++j) {
      src_index[params.perm[j]] = dst_index[j];
    }
    const IndexT src_offset =
        params.src_index_helper.IndexToOffset(src_index);
    vec_dst[i] = vec_src[src_offset];
  }
}

// A general kernel for the normal case: each thread gathers VecSize source
// elements one by one and stores them as a single vector (vectorized write
// only). The `tail_cnt` elements that do not fill a whole vector start at
// `offset` and are written individually by the first tail_cnt threads.
//
// Launch: 1-D grid; the vector part uses a grid-stride loop over main_cnt.
template <typename T, typename IndexT, int VecSize, int Rank>
__global__ void GeneralPermuteKernel(PermuteParams<IndexT, Rank> params,
                                     const IndexT main_cnt,
                                     const IndexT tail_cnt,
                                     const IndexT offset,
                                     const T* __restrict__ src,
                                     T* dst) {
  using VecT = phi::AlignedVector<T, VecSize>;
  VecT* vec_dst = reinterpret_cast<VecT*>(dst);
  IndexT src_index[VecSize][Rank];
  IndexT dst_index[VecSize][Rank];

  // Widen before multiplying: blockIdx.x * blockDim.x and
  // blockDim.x * gridDim.x are otherwise evaluated in 32-bit unsigned
  // arithmetic and wrap for the int64_t instantiation once the launch spans
  // more than 2^32 threads.
  const IndexT tid =
      static_cast<IndexT>(blockIdx.x) * blockDim.x + threadIdx.x;
  const IndexT stride = static_cast<IndexT>(blockDim.x) * gridDim.x;

  // Vectorized store of gathered data.
  for (IndexT idx = tid; idx < main_cnt; idx += stride) {
    VecT vec_data;
    const IndexT vec_idx = idx * VecSize;

#pragma unroll
    for (int i = 0; i < VecSize; ++i) {
      params.dst_index_helper.OffsetToIndex(vec_idx + i, dst_index[i]);

#pragma unroll
      for (int j = 0; j < Rank; ++j) {
        src_index[i][params.perm[j]] = dst_index[i][j];
      }
      const IndexT src_offset =
          params.src_index_helper.IndexToOffset(src_index[i]);
      vec_data[i] = src[src_offset];
    }
    vec_dst[idx] = vec_data;
  }

  // Scalar handling of the tail that does not fill a whole vector.
  if (tid < tail_cnt) {
    const IndexT idx = tid + offset;
    params.dst_index_helper.OffsetToIndex(idx, dst_index[0]);

#pragma unroll
    for (int j = 0; j < Rank; ++j) {
      src_index[0][params.perm[j]] = dst_index[0][j];
    }
    const IndexT src_offset =
        params.src_index_helper.IndexToOffset(src_index[0]);
    dst[idx] = src[src_offset];
  }
}

template <typename T, typename IndexT, int ReadSize, int WriteSize>
struct TransposeDataWriter {
  // Writes the shared-memory tile (staged by TransposeDataReader in the
  // transpose kernels) to dst_data in transposed order, packing WriteSize
  // consecutive output elements into one AlignedVector store.
  // chs_stride is the per-blockIdx.z offset in units of WriteSize vectors;
  // col_stride scales the output row index (1 for a plain batched transpose).
  __device__ __forceinline__ void operator()(T* dst_data,
                                             const T* s_data,
                                             const IndexT rows,
                                             const IndexT cols,
                                             const IndexT chs_stride,
                                             const IndexT round_tile_cols,
                                             const IndexT col_stride = 1) {
    using OutVecT = phi::AlignedVector<T, WriteSize>;
    constexpr int kColTile = kTileSize * ReadSize;
    constexpr int kColStride = kShareCol * ReadSize;

    const IndexT vec_rows = rows / WriteSize;
    const IndexT out_col = blockIdx.y * kTileSize + threadIdx.x;
    // Threads past the last destination column have nothing to store.
    if (!(out_col < /*dst_cols=*/vec_rows)) {
      return;
    }

    OutVecT* out_vec = reinterpret_cast<OutVecT*>(dst_data);
    // Full tiles cover kTileSize rows; the last tile covers the remainder.
    const int tile_rows = (blockIdx.x < round_tile_cols)
                              ? kTileSize
                              : (cols - round_tile_cols * kTileSize);
    const int smem_base = threadIdx.x * (WriteSize * kColStride);
    const IndexT out_base = blockIdx.z * chs_stride + out_col;
#pragma unroll
    for (int y = threadIdx.y; y < tile_rows; y += kBlockRows) {
#pragma unroll
      for (int r = 0; r < ReadSize; ++r) {
        const int tail = y * ReadSize + r;
        const int smem_idx = smem_base + tail;
        const IndexT in_row = (blockIdx.x * kColTile + tail) * col_stride;
        OutVecT pack;
#pragma unroll
        for (int w = 0; w < WriteSize; ++w) {
          pack.val[w] = s_data[w * kColStride + smem_idx];
        }
        out_vec[out_base + in_row * vec_rows] = pack;
      }
    }
  }
};

template <typename T, typename IndexT, int ReadSize>
struct TransposeDataWriter<T, IndexT, ReadSize, 1> {
  // Scalar-store specialization (WriteSize == 1): every element of the
  // shared-memory tile is written to dst_data individually, in transposed
  // order.
  __device__ __forceinline__ void operator()(T* dst_data,
                                             const T* s_data,
                                             const IndexT rows,
                                             const IndexT cols,
                                             const IndexT chs_stride,
                                             const IndexT round_tile_cols,
                                             const IndexT col_stride = 1) {
    const IndexT out_col = blockIdx.y * kTileSize + threadIdx.x;
    // Threads past the last destination column have nothing to store.
    if (!(out_col < /*dst_cols=*/rows)) {
      return;
    }

    // Full tiles cover kTileSize rows; the last tile covers the remainder.
    const int tile_rows = (blockIdx.x < round_tile_cols)
                              ? kTileSize
                              : (cols - round_tile_cols * kTileSize);
    const IndexT tile_row_base = blockIdx.x * kTileSize * ReadSize;
    const IndexT out_base = blockIdx.z * chs_stride + out_col;
    const int smem_base = threadIdx.x * kShareCol * ReadSize;
#pragma unroll
    for (int y = threadIdx.y; y < tile_rows; y += kBlockRows) {
      const int smem_row = smem_base + y * ReadSize;
      const IndexT first_row = (tile_row_base + y * ReadSize) * col_stride;
#pragma unroll
      for (int r = 0; r < ReadSize; ++r) {
        const IndexT in_row = first_row + r * col_stride;
        dst_data[out_base + in_row * rows] = s_data[smem_row + r];
      }
    }
  }
};

template <typename T, typename IndexT, int VecSize, IndexT RowTile>
struct TransposeDataReader {
  // Loads one RowTile x kTileSize tile of VecSize-wide vectors from src into
  // shared memory (row pitch kShareCol vectors), then synchronizes the block.
  // `cols` is the source row stride and `chs_stride` the per-blockIdx.z
  // offset, both in units of VecSize vectors; columns at or beyond
  // cols_thresh are skipped.
  __device__ __forceinline__ void operator()(const T* __restrict__ src,
                                             T* s_shared,
                                             const IndexT cols,
                                             const IndexT rows,
                                             const IndexT chs_stride,
                                             const IndexT cols_thresh,
                                             const IndexT round_tile_rows) {
    using VecT = phi::AlignedVector<T, VecSize>;
    const IndexT in_col = blockIdx.x * kTileSize + threadIdx.x;
    if (in_col < cols_thresh) {
      const VecT* __restrict__ vec_src =
          reinterpret_cast<const VecT* __restrict__>(src);
      VecT* vec_smem = reinterpret_cast<VecT*>(s_shared);

      // Full tiles cover RowTile rows; the last tile covers the remainder.
      const int valid_rows = (blockIdx.y < round_tile_rows)
                                 ? RowTile
                                 : (rows - RowTile * round_tile_rows);
      const IndexT src_base = blockIdx.z * chs_stride + in_col;
      const IndexT row_base = blockIdx.y * RowTile;
#pragma unroll
      for (int y = threadIdx.y; y < valid_rows; y += kBlockRows) {
        vec_smem[y * kShareCol + threadIdx.x] =
            vec_src[(row_base + y) * cols + src_base];
      }
    }
    // Every thread of the block must reach this barrier, so it stays outside
    // the column guard above.
    __syncthreads();
  }
};

// Shared-memory tiled transpose kernels, following
// https://developer.nvidia.com/blog/efficient-matrix-transpose-cuda-cc/
//
// SwapTransposeKernel swaps the first and last axes of a 3-D tensor,
// [rows, chs, cols] -> [cols, chs, rows]; blockIdx.z selects the middle
// (channel) index. Source offsets are in units of ReadSize-wide vectors.
template <typename T,
          typename IndexT,
          bool IsVecWrite,
          int ReadSize,
          int WriteSize = (IsVecWrite && (sizeof(T) < sizeof(float)))
                              ? sizeof(float) / sizeof(T)
                              : 1>
__global__ void SwapTransposeKernel(const T* __restrict__ src_data,
                                    T* dst_data,
                                    const IndexT round_tile_rows,
                                    const IndexT round_tile_cols,
                                    const IndexT cols,
                                    const IndexT rows,
                                    const IndexT chs /*=channel*/) {
  constexpr int kRowTile = kTileSize * WriteSize;
  __shared__ T s_data[kRowTile * kShareCol * ReadSize];

  // Source layout: consecutive rows are chs * cols apart, consecutive
  // channels are cols apart.
  const IndexT src_row_stride = chs * cols;
  const IndexT src_ch_stride = cols;
  TransposeDataReader<T, IndexT, ReadSize, kRowTile>()(src_data,
                                                       s_data,
                                                       src_row_stride,
                                                       rows,
                                                       src_ch_stride,
                                                       cols,
                                                       round_tile_rows);

  // Destination layout: channels are rows / WriteSize vectors apart and
  // each output row index is scaled by chs.
  const IndexT dst_ch_stride = rows / WriteSize;
  TransposeDataWriter<T, IndexT, ReadSize, WriteSize>()(dst_data,
                                                        s_data,
                                                        rows,
                                                        cols,
                                                        dst_ch_stride,
                                                        round_tile_cols,
                                                        chs);
}

// BatchTransposeKernel transposes the last two axes of a batch of matrices,
// [chs, rows, cols] -> [chs, cols, rows]; blockIdx.z selects the matrix.
template <typename T,
          typename IndexT,
          bool IsVecWrite,
          int ReadSize,
          int WriteSize = (IsVecWrite && (sizeof(T) < sizeof(float)))
                              ? sizeof(float) / sizeof(T)
                              : 1>
__global__ void BatchTransposeKernel(const T* __restrict__ src_data,
                                     T* dst_data,
                                     const IndexT round_tile_rows,
                                     const IndexT round_tile_cols,
                                     const IndexT cols,
                                     const IndexT rows) {
  constexpr int kRowTile = kTileSize * WriteSize;
  __shared__ T s_data[kRowTile * kShareCol * ReadSize];

  // One matrix spans rows * cols source vectors.
  const IndexT matrix_size = rows * cols;
  TransposeDataReader<T, IndexT, ReadSize, kRowTile>()(
      src_data, s_data, cols, rows, matrix_size, cols, round_tile_rows);

  // Same matrix stride re-expressed in units of WriteSize-wide output vectors.
  const IndexT dst_matrix_stride = matrix_size * ReadSize / WriteSize;
  TransposeDataWriter<T, IndexT, ReadSize, WriteSize>()(dst_data,
                                                        s_data,
                                                        rows,
                                                        cols,
                                                        dst_matrix_stride,
                                                        round_tile_cols);
}

template <typename T, typename IndexT, int VecSize>
struct PermuteLauncher {
 public:
  // Launches VectorizedPermuteKernel (perm_type == kVecPermute) or
  // GeneralPermuteKernel on ctx's stream for a tensor of rank 3..9.
  // NOTE(review): any other rank falls through the switch and launches
  // nothing.
  void operator()(const phi::GPUContext& ctx,
                  const int& rank,
                  const IndexT& count,
                  const PermuteType& perm_type,
                  const std::vector<int64_t>& dims,
                  const std::vector<int32_t>& perm,
                  const T* src,
                  T* dst) {
    dims_ = dims;
    main_cnt_ = count / VecSize;

#define PERMUTE_LAUNCH_WITH_RANK(R)                \
  case R: {                                        \
    Run<R>(ctx, perm, perm_type, count, src, dst); \
    break;                                         \
  }

    switch (rank) {
      PERMUTE_LAUNCH_WITH_RANK(3);
      PERMUTE_LAUNCH_WITH_RANK(4);
      PERMUTE_LAUNCH_WITH_RANK(5);
      PERMUTE_LAUNCH_WITH_RANK(6);
      PERMUTE_LAUNCH_WITH_RANK(7);
      PERMUTE_LAUNCH_WITH_RANK(8);
      PERMUTE_LAUNCH_WITH_RANK(9);
    }
#undef PERMUTE_LAUNCH_WITH_RANK
  }

 private:
  IndexT main_cnt_{0};
  std::vector<int64_t> dims_;

  template <int Rank>
  void Run(const phi::GPUContext& ctx,
           const std::vector<int32_t>& perm,
           const PermuteType& perm_type,
           const IndexT& count,
           const T* src,
           T* dst) {
    auto cfg = phi::backends::gpu::GetGpuLaunchConfig1D(ctx, main_cnt_);
    const bool vectorized_io = (perm_type == PermuteType::kVecPermute);
    if (vectorized_io) {
      // The unpermuted last axis is addressed in VecSize-element vectors.
      dims_[Rank - 1] /= VecSize;
    }
    const auto params = PermuteParams<IndexT, Rank>(dims_, perm);

    if (vectorized_io) {
      VectorizedPermuteKernel<T, IndexT, VecSize, Rank>
          <<<cfg.block_per_grid, cfg.thread_per_block, 0, ctx.stream()>>>(
              params, main_cnt_, src, dst);
      return;
    }

    // main_cnt_ full vectors, then tail_cnt leftover elements starting at
    // main_offset.
    const IndexT tail_cnt = count - main_cnt_ * VecSize;
    const IndexT main_offset = count - tail_cnt;
    GeneralPermuteKernel<T, IndexT, VecSize, Rank>
        <<<cfg.block_per_grid, cfg.thread_per_block, 0, ctx.stream()>>>(
            params, main_cnt_, tail_cnt, main_offset, src, dst);
  }
};

template <typename T, typename IndexT, int VecSize>
struct TransposeLauncher {
 public:
  // Launches BatchTransposeKernel for kGeneralTranspose and
  // SwapTransposeKernel otherwise (kSwapTranspose), on ctx's stream.
  // Grid: (column tiles, row tiles, channels); block: kTileSize x kBlockRows.
  void operator()(const phi::GPUContext& ctx,
                  const int& rank,
                  const PermuteType& perm_type,
                  const std::vector<int64_t>& dims,
                  const IndexT& num_rows_tile,
                  const T* src,
                  T* dst) {
    constexpr int ReadSize = sizeof(T) > sizeof(float) ? 1 : VecSize;
    const IndexT cols = dims[rank - 1] / VecSize;
    const IndexT n_cols_tile = GET_TILE_SIZE(cols, kTileSize);
    const dim3 threads(kTileSize, kBlockRows, 1);

    if (perm_type == PermuteType::kGeneralTranspose) {
      // [chs, rows, cols] -> [chs, cols, rows]; rank 2 is a single matrix.
      const IndexT chs = (rank == 2) ? 1 : dims[0];
      const IndexT rows = dims[rank - 2];
      // FindRowTiles also decides is_vec_write, read just below.
      const IndexT n_rows_tile =
          FindRowTiles(chs, rows, num_rows_tile, n_cols_tile, ctx.GetSMCount());
      const dim3 blocks(n_cols_tile, n_rows_tile, chs);
      if (is_vec_write) {
        BatchTransposeKernel<T, IndexT, true, ReadSize>
            <<<blocks, threads, 0, ctx.stream()>>>(
                src, dst, n_rows_tile - 1, n_cols_tile - 1, cols, rows);
      } else {
        BatchTransposeKernel<T, IndexT, false, ReadSize>
            <<<blocks, threads, 0, ctx.stream()>>>(
                src, dst, n_rows_tile - 1, n_cols_tile - 1, cols, rows);
      }
      return;
    }

    // kSwapTranspose: [rows, chs, cols] -> [cols, chs, rows].
    const IndexT rows = dims[0];
    const IndexT chs = dims[rank - 2];
    const IndexT n_rows_tile =
        FindRowTiles(chs, rows, num_rows_tile, n_cols_tile, ctx.GetSMCount());
    const dim3 blocks(n_cols_tile, n_rows_tile, chs);
    if (is_vec_write) {
      SwapTransposeKernel<T, IndexT, true, ReadSize>
          <<<blocks, threads, 0, ctx.stream()>>>(
              src, dst, n_rows_tile - 1, n_cols_tile - 1, cols, rows, chs);
    } else {
      SwapTransposeKernel<T, IndexT, false, ReadSize>
          <<<blocks, threads, 0, ctx.stream()>>>(
              src, dst, n_rows_tile - 1, n_cols_tile - 1, cols, rows, chs);
    }
  }

 private:
  bool is_vec_write{false};

  // Sets is_vec_write and returns the number of row tiles (gridDim.y).
  // Writes are packed sizeof(float) / sizeof(T) elements at a time only for
  // element types narrower than float, when `rows` divides evenly and the
  // grid has more tiles than SMs; packing enlarges each row tile
  // accordingly, so fewer row tiles are needed.
  inline IndexT FindRowTiles(const IndexT& chs,
                             const IndexT& rows,
                             const IndexT& num_rows_tile,
                             const IndexT& num_cols_tile,
                             const int& sm_count) {
    constexpr int kVecRow = sizeof(float) / sizeof(T);
    const bool rows_divisible =
        (sizeof(T) < sizeof(float)) ? (rows % kVecRow == 0) : false;
    is_vec_write = rows_divisible &&
                   ((chs * num_cols_tile * num_rows_tile) > sm_count);

    int vec_write = 1;
    if (is_vec_write) {
      vec_write = kVecRow;
    }
    return is_vec_write ? GET_TILE_SIZE(rows, (kTileSize * vec_write))
                        : num_rows_tile;
  }
};

// Routes a classified permutation to the matching launcher, instantiated for
// the vector width chosen by the classifier (1, 2 or 4):
//   kSwapTranspose / kGeneralTranspose -> TransposeLauncher
//   everything else                     -> PermuteLauncher
template <typename T, typename IndexT>
inline void PermuteDispatch(const phi::GPUContext& ctx,
                            const IndexT& count,
                            PermTypeClassifier<T>* cls_ptr,
                            const std::vector<int64_t>& dims,
                            const std::vector<int32_t>& perm,
                            const T* src,
                            T* dst) {
  int rank = dims.size();
  PermuteType type = cls_ptr->GetPermType();

#define TRANSPOSE_DISPATCH_VEC_SIZE(size)                         \
  case size: {                                                    \
    TransposeLauncher<T, IndexT, size>()(                         \
        ctx, rank, type, dims, cls_ptr->GetRowsTile(), src, dst); \
    break;                                                        \
  }

#define PERMUTE_DISPATCH_VEC_SIZE(size)                \
  case size: {                                         \
    PermuteLauncher<T, IndexT, size>()(                \
        ctx, rank, count, type, dims, perm, src, dst); \
    break;                                             \
  }

  switch (type) {
    case kSwapTranspose:
    case kGeneralTranspose:
      switch (cls_ptr->GetVecSize()) {
        TRANSPOSE_DISPATCH_VEC_SIZE(1);
        TRANSPOSE_DISPATCH_VEC_SIZE(2);
        TRANSPOSE_DISPATCH_VEC_SIZE(4);
      }
      break;
    default:
      switch (cls_ptr->GetVecSize()) {
        PERMUTE_DISPATCH_VEC_SIZE(1);
        PERMUTE_DISPATCH_VEC_SIZE(2);
        PERMUTE_DISPATCH_VEC_SIZE(4);
      }
      break;
  }

  // The helper macros are local to this function: undefine them so they do
  // not leak into (or clash with) code that includes this header.
#undef TRANSPOSE_DISPATCH_VEC_SIZE
#undef PERMUTE_DISPATCH_VEC_SIZE
}

// Custom-kernel implementation of the transpose, registered as an autotune
// callback. `rank` is unused here but is part of the callback signature
// shared with PermuteWithEigen.
template <typename T>
inline void PermuteAndTranspose(
    const phi::GPUContext& ctx,
    const int& rank,
    const phi::DenseTensor& in,
    phi::DenseTensor* out,
    const phi::funcs::PermuteDimsSimplifier& simplifier) {
  T* dst_data = out->data<T>();
  const T* src_data = in.data<T>();
  const auto count = simplifier.GetCount();
  auto classifier = PermTypeClassifier<T>(ctx.GetSMCount(),
                                          simplifier.GetRank(),
                                          simplifier.GetPerm(),
                                          simplifier.GetSrcDims(),
                                          src_data,
                                          dst_data);

  if (classifier.GetPermType() == PermuteType::kCopy) {
    // The simplified problem has rank 1 (e.g. perm [0, 1, 2, 3]), so an
    // asynchronous device-to-device copy on ctx's stream suffices.
    phi::backends::gpu::GpuMemcpyAsync(dst_data,
                                       src_data,
                                       count * sizeof(T),
                                       phi::gpuMemcpyDeviceToDevice,
                                       ctx.stream());
    return;
  }

  // Use the 32-bit index path (IdxHelper has a FastDivMod specialization for
  // uint32_t) whenever the element count fits; otherwise index with int64_t.
  const bool fits_u32 = count < std::numeric_limits<uint32_t>::max();
  if (fits_u32) {
    PermuteDispatch<T, uint32_t>(ctx,
                                 static_cast<uint32_t>(count),
                                 &classifier,
                                 simplifier.GetSrcDims(),
                                 simplifier.GetPerm(),
                                 src_data,
                                 dst_data);
  } else {
    PermuteDispatch<T, int64_t>(ctx,
                                static_cast<int64_t>(count),
                                &classifier,
                                simplifier.GetSrcDims(),
                                simplifier.GetPerm(),
                                src_data,
                                dst_data);
  }
}

// TransCompute-based implementation of the transpose, registered as the
// default autotune candidate. It runs on the simplified shape/permutation.
template <typename T>
inline void PermuteWithEigen(
    const phi::GPUContext& ctx,
    const int& rank,
    const phi::DenseTensor& in,
    phi::DenseTensor* out,
    const phi::funcs::PermuteDimsSimplifier& simplifier) {
  const bool same_rank = (simplifier.GetRank() == rank);
  if (same_rank) {
    TransCompute<phi::GPUContext, T>(
        simplifier.GetRank(), ctx, in, out, simplifier.GetPerm());
    return;
  }

  // The simplifier changed the rank: run on views reshaped to the simplified
  // source/destination dims, then restore out's original shape.
  const phi::DDim original_out_dims = out->dims();
  phi::DenseTensor reshaped_in;
  reshaped_in.ShareBufferWith(in);
  reshaped_in.Resize(phi::make_ddim(simplifier.GetSrcDims()));
  out->Resize(phi::make_ddim(simplifier.GetDstDims()));

  TransCompute<phi::GPUContext, T>(
      simplifier.GetRank(), ctx, reshaped_in, out, simplifier.GetPerm());
  out->Resize(original_out_dims);
}

// Entry point for GPU transpose: tries TransposeSimple first; if it declines
// (returns false), simplifies the permutation and lets the transpose
// autotuner choose between PermuteWithEigen and PermuteAndTranspose.
template <typename T>
void TransposeGPUKernelDriver(const phi::GPUContext& ctx,
                              const phi::DenseTensor& in,
                              const std::vector<int32_t>& perm,
                              phi::DenseTensor* out) {
  const int rank = perm.size();
  int64_t numel = in.numel();
  if (TransposeSimple<T>::Run(ctx, in, perm, out, numel)) {
    return;
  }

  auto simplifier = phi::funcs::PermuteDimsSimplifier(
      rank, numel, perm, phi::vectorize<int64_t>(in.dims()));

  // Candidates: PermuteWithEigen (given to MakeTransposeTuner) and
  // PermuteAndTranspose; the tuner is keyed on the simplified dims, perm and
  // dtype.
  auto* tuner = phi::autotune::MakeTransposeTuner<T>(PermuteWithEigen<T>);
  tuner->AddCallBack(PermuteAndTranspose<T>);

  const size_t key = phi::autotune::TransposeKey(
      simplifier.GetSrcDims(),
      simplifier.GetPerm(),
      paddle::experimental::CppTypeToDataType<T>::Type());

  tuner->Run(ctx,
             phi::autotune::AlgorithmType::kTranspose,
             key,
             ctx,
             rank,
             in,
             out,
             simplifier);
}

}  // namespace funcs
}  // namespace phi