/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_utils.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/autotune/auto_tune_base.h"
#include "paddle/phi/kernels/funcs/dims_simplifier.h"
#include "paddle/phi/kernels/funcs/transpose_functor.h"
#include "paddle/phi/kernels/primitive/datamover_primitives.h"

namespace phi {
namespace funcs {

using Tensor = phi::DenseTensor;

struct EqualTo {
  constexpr bool operator()(int a, int b) const { return a == b; }
};

struct GreaterThan {
  constexpr bool operator()(int a, int b) const { return a > b; }
};

// The values can be decided at compile time.
template <typename FUN, int INT_32 = 32>
constexpr bool CheckProperTileSize(int tile_long,
                                   int tile_short,
                                   int size_T,
                                   FUN op) {
  return (size_T == 16 && ((tile_long == INT_32 && op(tile_short, 4)) ||
                           (tile_long == 2 * INT_32 && op(tile_short, 4)) ||
                           (tile_long == 4 * INT_32 && op(tile_short, 4)) ||
                           (tile_long == 8 * INT_32 && op(tile_short, 2)))) ||
         (size_T == 8 && ((tile_long == INT_32 && op(tile_short, 15)) ||
                          (tile_long == 2 * INT_32 && op(tile_short, 15)) ||
                          (tile_long == 4 * INT_32 && op(tile_short, 8)) ||
                          (tile_long == 8 * INT_32 && op(tile_short, 4)) ||
                          (tile_long == 16 * INT_32 && op(tile_short, 2)))) ||
         ((size_T == 4 || size_T == 2 || size_T == 1) &&
          ((tile_long == INT_32 && op(tile_short, 15)) ||
           (tile_long == 2 * INT_32 && op(tile_short, 15)) ||
           (tile_long == 4 * INT_32 && op(tile_short, 8)) ||
           (tile_long == 8 * INT_32 && op(tile_short, 4)) ||
           (tile_long == 16 * INT_32 && op(tile_short, 2))));
}

constexpr bool CheckLongTileSize(int tile_long, int tile_short, int size_T) {
  return CheckProperTileSize(tile_long, tile_short, size_T, EqualTo());
}

constexpr bool CheckOutsideTileSize(int tile_long, int tile_short, int size_T) {
  return CheckProperTileSize(tile_long, tile_short, size_T, GreaterThan());
}

constexpr bool CheckNonLongTileSize(int tile_long, int tile_short, int size_T) {
  return !CheckOutsideTileSize(tile_long, tile_short, size_T) &&
         (CheckOutsideTileSize(tile_long * 2, tile_short, size_T) ||
          CheckOutsideTileSize(tile_long, tile_short + 1, size_T)) &&
         !CheckLongTileSize(tile_long, tile_short, size_T);
}
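
// A worked example (illustrative only): with sizeof(T) == 4,
// CheckLongTileSize(256, 4, 4) is true because the table above pairs a long
// edge of 8 * 32 with a short edge of exactly 4, while
// CheckOutsideTileSize(256, 8, 4) is true because 8 > 4.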

// Use shared memory to do the data transfer: load a tile into shared memory,
// then store it out. All tile reads and writes are coalesced, which speeds up
// the memory copy.
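// For instance (hypothetical shape): with input_dims = {4, 100, 50} and a
// 32 x 32 tile, the large-tile launcher further below
// (SendSwapDim1And2InTranspose) builds a tile-aligned shape of
// {4, ceil(100 / 32), ceil(50 / 32)} = {4, 4, 2} and launches 4 * 4 * 2 = 32
// blocks, one per tile; each block moves its tile through the shared-memory
// buffer declared inside this kernel.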
template <typename T,
          int NumThreads,
          int TileX,
          int TileY,
          typename IndexType = int>
__global__ void TilingSwapDim1And2(const T* __restrict__ input,
                                   Dim3 input_dims,
                                   T* __restrict__ output) {
  assert(blockDim.x == NumThreads);
  assert(blockDim.y == 1);
  assert(blockDim.z == 1);
  assert(gridDim.y == 1);
  assert(gridDim.z == 1);

  constexpr int BlockReadRows = NumThreads / TileY;
  constexpr int BlockWriteRows = NumThreads / TileX;

  // One extra line in the inner dimension to avoid shared memory bank
  // conflicts.
  __shared__ __align__(
      alignof(T)) char share_mem_ptr[TileX * (TileY + 1) * sizeof(T)];
  typedef T(*ShareMemory)[TileY + 1];

  ShareMemory tile_sm = reinterpret_cast<ShareMemory>(share_mem_ptr);

  int x = threadIdx.x;

  Dim3 output_dims = {
      input_dims[0],
      input_dims[2],
      input_dims[1],
  };

  // Align dims to tiles.
  Dim3 tile_aligned_input_dim = {
      input_dims[0],
      (input_dims[1] + TileX - 1) / TileX,
      (input_dims[2] + TileY - 1) / TileY,
  };

  // Convert the block index to a tile index; each block processes one tile.
  Index3 input_block_tile_index =
      ConvertTensorIndex<IndexType>(blockIdx.x, tile_aligned_input_dim);

  // Compute the real index aligned to the tile: 0, 32, 64, ...
  Index3 block_tile_index_in_input = {
      input_block_tile_index[0],
      input_block_tile_index[1] * TileX,
      input_block_tile_index[2] * TileY,
  };

  // Compute the block's flat index against the input dims.
  IndexType input_origin_block_flat_index =
      FlatTensorIndex<IndexType>(block_tile_index_in_input, input_dims);

  bool full_tile = true;
  IndexType tile_width = TileY;

  // The last row may not be full.
  if (input_block_tile_index[2] == tile_aligned_input_dim[2] - 1) {
    tile_width = input_dims[2] - (tile_aligned_input_dim[2] - 1) * TileY;
    full_tile &= false;
  }

  IndexType tile_height = TileX;

  if (input_block_tile_index[1] == tile_aligned_input_dim[1] - 1) {
    tile_height = input_dims[1] - (tile_aligned_input_dim[1] - 1) * TileX;
    full_tile &= false;
  }

  constexpr IndexType in_effective_thread_num = NumThreads / TileY * TileY;

  if (x < in_effective_thread_num) {
    // Read a tile from input using block.
    int x_i = x / TileY;
    int x_j = x % TileY;
    IndexType input_ind =
        input_origin_block_flat_index + x_i * input_dims[2] + x_j;
    IndexType input_inc = BlockReadRows * input_dims[2];

    if (full_tile) {
#pragma unroll
      for (int ind_i = x_i; ind_i < (TileX); ind_i += BlockReadRows) {
        tile_sm[ind_i][x_j] = input[input_ind];
        input_ind += input_inc;
      }
    } else {
      if (x_j < tile_width) {
#pragma unroll
        for (IndexType ind_i = x_i; ind_i < (tile_height);
             ind_i += BlockReadRows) {
          tile_sm[ind_i][x_j] = input[input_ind];
          input_ind += input_inc;
        }
      }
    }
  }

  __syncthreads();

  // Store shared memory values back to the output.
  Index3 output_block_tile_index = {
      input_block_tile_index[0],
      input_block_tile_index[2],
      input_block_tile_index[1],
  };

  Index3 block_tile_index_in_output = {
      output_block_tile_index[0],
      output_block_tile_index[1] * TileY,
      output_block_tile_index[2] * TileX,
  };

  IndexType output_origin_block_flat_index =
      FlatTensorIndex<IndexType>(block_tile_index_in_output, output_dims);
  constexpr IndexType out_effective_thread_num = NumThreads / TileX * TileX;

  if (x < out_effective_thread_num) {
    int x_i = x / TileX;
    int x_j = x % TileX;
    IndexType output_ind =
        output_origin_block_flat_index + x_i * output_dims[2] + x_j;
    IndexType output_inc = BlockWriteRows * output_dims[2];

    if (full_tile) {
#pragma unroll
      for (int ind_i = x_i; ind_i < (TileY); ind_i += BlockWriteRows) {
        output[output_ind] = tile_sm[x_j][ind_i];
        output_ind += output_inc;
      }
    } else {
      if (x_j < tile_height) {
#pragma unroll
        for (IndexType ind_i = x_i; ind_i < (tile_width);
             ind_i += BlockWriteRows) {
          output[output_ind] = tile_sm[x_j][ind_i];
          output_ind += output_inc;
        }
      }
    }
  }
}

// This function finds long_side x short_side tile combinations as backups.
template <int TSIZE>
bool SelectProperTileSize(std::vector<std::pair<int, int>>* tiles) {
  PADDLE_ENFORCE_LE(
      TSIZE,
      16,
      phi::errors::InvalidArgument(
          "The tile size should be no larger than 16, but received is:%d.",
          TSIZE));

  PADDLE_ENFORCE_EQ(
      (TSIZE & (TSIZE - 1)),
      0,
      phi::errors::InvalidArgument(
          "The data type size should be a power of 2, but received size is:%d.",
          TSIZE));

  const int kMaxLongSideLen = 1024;
  const int kMaxShortSideLen = 15;

  for (int long_side = 32; long_side <= kMaxLongSideLen; long_side *= 2) {
    for (int short_side = 2; short_side <= kMaxShortSideLen; short_side += 1) {
      if (CheckLongTileSize(long_side, short_side, TSIZE)) {
        tiles->push_back(std::make_pair(long_side, short_side));

        if (short_side == 2) return true;

        break;
      }
    }
  }
  return false;
}
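
// As an illustration (hypothetical instantiation): for TSIZE == 4 the loops
// above collect (32, 15), (64, 15), (128, 8), (256, 4) and (512, 2) from the
// CheckLongTileSize table, returning true once a short side of 2 is reached.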

// Use system built-in types.
template <int ByteSize>
struct SystemElemType;
template <>
struct SystemElemType<1> {
  using type = uint8_t;
};
template <>
struct SystemElemType<2> {
  using type = uint16_t;
};
template <>
struct SystemElemType<4> {
  using type = uint32_t;
};
template <>
struct SystemElemType<8> {
  using type = uint64_t;
};
template <>
struct SystemElemType<16> {
  using type = float4;
};

template <typename T, int tile_long, int tile_short, typename IndexType = int>
void LaunchNarrowDims2TransposeKernel(const phi::GPUContext& d,
                                      int tile_size_i,
                                      int tile_size_j,
                                      IndexType total_tiles_count,
                                      const T* input,
                                      const Dim3& input_dims,
                                      T* output) {
  constexpr int NumThreads = tile_long;
  if (tile_size_i <= tile_long && tile_size_j <= tile_short) {
    TilingSwapDim1And2<T, NumThreads, tile_long, tile_short, IndexType>
        <<<total_tiles_count, NumThreads, 0, d.stream()>>>(
            input, input_dims, output);
  } else {
    TilingSwapDim1And2<T, NumThreads, tile_short, tile_long, IndexType>
        <<<total_tiles_count, NumThreads, 0, d.stream()>>>(
            input, input_dims, output);
  }
}

template <typename T,
          int tile_long,
          int tile_short,
          typename IndexType = int,
          typename dummy = void>
struct NarrowDims2TransposeDispatch {
  static void DoTranspose(const phi::GPUContext& d,
                          int tile_size_i,
                          int tile_size_j,
                          IndexType total_tiles_count,
                          const T* input,
                          const Dim3& input_dims,
                          T* output) {
    PADDLE_ENFORCE_EQ(
        (tile_long & (tile_long - 1)),
        0,
        phi::errors::InvalidArgument(
            "The length of the longer side of the tile should be power of 2."
            " But received value is:%d.",
            tile_long));

    bool request_satisfied = std::max(tile_size_i, tile_size_j) <= tile_long &&
                             std::min(tile_size_i, tile_size_j) <= tile_short;

    if (request_satisfied) {
      LaunchNarrowDims2TransposeKernel<T, tile_long, tile_short, IndexType>(
          d,
          tile_size_i,
          tile_size_j,
          total_tiles_count,
          input,
          input_dims,
          output);
      return;
    }

    const bool long_side_request_not_satisfied =
        std::max(tile_size_i, tile_size_j) > tile_long;

    if (long_side_request_not_satisfied) {
      NarrowDims2TransposeDispatch<T, tile_long * 2, tile_short, IndexType>::
          DoTranspose(d,
                      tile_size_i,
                      tile_size_j,
                      total_tiles_count,
                      input,
                      input_dims,
                      output);
    } else {
      NarrowDims2TransposeDispatch<T, tile_long, tile_short + 1, IndexType>::
          DoTranspose(d,
                      tile_size_i,
                      tile_size_j,
                      total_tiles_count,
                      input,
                      input_dims,
                      output);
    }
  }
};

// If it is not a long tile size, this specialization is chosen at compile
// time.
template <typename T, int tile_long, int tile_short, typename IndexType>
struct NarrowDims2TransposeDispatch<
    T,
    tile_long,
    tile_short,
    IndexType,
    typename std::enable_if<CheckNonLongTileSize(
                                tile_long, tile_short, sizeof(T)),
                            void>::type> {
  static void DoTranspose(const phi::GPUContext& d,
                          int tile_size_i,
                          int tile_size_j,
                          IndexType total_tiles_count,
                          const T* input,
                          const Dim3& input_dims,
                          T* output) {
    PADDLE_ENFORCE_EQ(
        (tile_long & (tile_long - 1)),
        0,
        phi::errors::InvalidArgument(
            "The length of the longer side of the tile should be power of 2."
            " But received value is:%d.",
            tile_long));

    bool request_satisfied = std::max(tile_size_i, tile_size_j) <= tile_long &&
                             std::min(tile_size_i, tile_size_j) <= tile_short;

    if (request_satisfied) {
      LaunchNarrowDims2TransposeKernel<T, tile_long, tile_short, IndexType>(
          d,
          tile_size_i,
          tile_size_j,
          total_tiles_count,
          input,
          input_dims,
          output);
      return;
    }

    NarrowDims2TransposeDispatch<T, tile_long, tile_short + 1, IndexType>::
        DoTranspose(d,
                    tile_size_i,
                    tile_size_j,
                    total_tiles_count,
                    input,
                    input_dims,
                    output);
  }
};

// If a long tile size is available, this specialization is chosen at compile
// time.
template <typename T, int tile_long, int tile_short, typename IndexType>
struct NarrowDims2TransposeDispatch<
    T,
    tile_long,
    tile_short,
    IndexType,
    typename std::enable_if<CheckLongTileSize(tile_long, tile_short, sizeof(T)),
                            void>::type> {
  static void DoTranspose(const phi::GPUContext& d,
                          int tile_size_i,
                          int tile_size_j,
                          IndexType total_tiles_count,
                          const T* input,
                          const Dim3& input_dims,
                          T* output) {
    PADDLE_ENFORCE_EQ(
        (tile_long & (tile_long - 1)),
        0,
        phi::errors::InvalidArgument(
            "The length of the longer side of the tile should be power of 2,"
            " but received is:%d.",
            tile_long));

    LaunchNarrowDims2TransposeKernel<T, tile_long, tile_short, IndexType>(
        d,
        tile_size_i,
        tile_size_j,
        total_tiles_count,
        input,
        input_dims,
        output);
  }
};

template <typename T, bool conjugate = false, typename IndexType = int>
void SwapDim1And2InNarrow(const phi::GPUContext& d,
                          const T* input,
                          const Dim3& input_dims,
                          T* output,
                          const int kMinTileSize) {
  // First get the available tile sizes for the requested data type as backups.
  std::vector<std::pair<int, int>> tile_sele;
  auto ret = SelectProperTileSize<sizeof(T)>(&tile_sele);
  PADDLE_ENFORCE_EQ(
      ret,
      true,
      phi::errors::InvalidArgument(
          "SelectProperTileSize should return true, but return value is:%d.",
          ret));

  int tile_long_edge = 0;
  int tile_short_edge = 0;
  float lowest_cost = std::numeric_limits<float>::max();
  int input_long_edge = std::max(input_dims[1], input_dims[2]);

  // Find the tile size that best suits the input.
  for (auto tile_size_pair : tile_sele) {
    int proposed_tile_long_edge = tile_size_pair.first;
    // The data may not be aligned to the tile, so some threads are wasted. We
    // need to find the tile with the fewest wasted threads, i.e. the tile
    // that splits the input most evenly; ideally num_wasted_threads == 0.
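    // For example (hypothetical numbers): if input_long_edge == 100 and
    // proposed_tile_long_edge == 32, then floor(100 / 32) * 32 == 96, so the
    // cost below is 100 - 96 == 4; a proposed edge that divided 100 exactly
    // would give a cost of 0.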
    int num_wasted_threads =
        input_long_edge -
        CeilOrFloor<int, false>(input_long_edge, proposed_tile_long_edge) *
            proposed_tile_long_edge;

    int num_full_tiles =
        CeilOrFloor<int, false>(input_long_edge, proposed_tile_long_edge);

    float cost = num_wasted_threads;

    if (cost <= lowest_cost) {
      tile_long_edge = proposed_tile_long_edge;
      tile_short_edge = tile_size_pair.second;
      lowest_cost = cost;
    }
    // Break, since we have already found the best tile size.
    if (cost == 0) break;
  }

  // The selected tile size should match the input dims: long side to long,
  // short side to short.
  // If dim1 >= kMinTileSize, assign the long edge to i and dim2 to j;
  // otherwise assign dim1 to i and the long edge to j.
  int select_tile_size_i =
      input_dims[1] >= kMinTileSize ? tile_long_edge : input_dims[1];
  int select_tile_size_j =
      input_dims[1] >= kMinTileSize ? input_dims[2] : tile_long_edge;

  // Check whether i is the long edge; if not, set i to the short edge.
  select_tile_size_i = select_tile_size_i == tile_long_edge
                           ? tile_long_edge
                           : std::min(select_tile_size_i, tile_short_edge);

  // Check whether j is the long edge; if not, set j to the short edge.
  select_tile_size_j = select_tile_size_j == tile_long_edge
                           ? tile_long_edge
                           : std::min(select_tile_size_j, tile_short_edge);

  // Here we finally get a proper long x short tile size.
  Dim3 input_dims_aligned = {
      input_dims[0],
      CeilOrFloor<int, true>(input_dims[1], select_tile_size_i),
      CeilOrFloor<int, true>(input_dims[2], select_tile_size_j),
  };

  IndexType total_tiles_count = input_dims_aligned[0];
  total_tiles_count *= input_dims_aligned[1];
  total_tiles_count *= input_dims_aligned[2];

  // Suppose T can be replaced by a system built-in type of the same size.
  using ElemType = typename SystemElemType<sizeof(T)>::type;

  NarrowDims2TransposeDispatch<ElemType, 32, 2, IndexType>::DoTranspose(
      d,
      select_tile_size_i,
      select_tile_size_j,
      total_tiles_count,
      reinterpret_cast<const ElemType*>(input),
      input_dims,
      reinterpret_cast<ElemType*>(output));
}

// This is for the case where coalesced reads and writes are not possible,
// or the input is too small to be split into tiles.
template <typename T, int pos0, int pos1, int pos2, typename IndexType = int>
__global__ void TransposeSimpleKernel(IndexType nthreads,
                                      const T* __restrict__ input,
                                      Dim3 input_dims,
                                      T* __restrict__ output) {
  Dim3 output_dims;
  output_dims[pos0] = input_dims[0];
  output_dims[pos1] = input_dims[1];
  output_dims[pos2] = input_dims[2];

  CUDA_KERNEL_LOOP_TYPE(output_index, nthreads, IndexType) {
    Index3 output_tensor_index =
        ConvertTensorIndex<IndexType>(output_index, output_dims);

    Index3 input_tensor_index;
    input_tensor_index[0] = output_tensor_index[pos0];
    input_tensor_index[1] = output_tensor_index[pos1];
    input_tensor_index[2] = output_tensor_index[pos2];

    IndexType input_index =
        FlatTensorIndex<IndexType>(input_tensor_index, input_dims);

    output[output_index] = input[input_index];
  }
}

// Here we assume every tensor has been collapsed to 3 dims, so only dim1 and
// dim2 need to be swapped.
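// Dispatch sketch (illustrative shapes): {b, 64, 48} has both sides >=
// kMinTileSize and takes the shared-memory tiling path; {b, 2, 100} is narrow
// but long on one side and goes to SwapDim1And2InNarrow; {b, 8, 8} falls
// through to TransposeSimpleKernel.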
template <typename T, typename IndexType = int>
void SendSwapDim1And2InTranspose(const phi::GPUContext& d,
                                 const T* input,
                                 const Dim3& input_dims,
                                 T* output) {
  // Suppose tile size > 16.
  static const int kMinTileSize = 16;
  static const int kMinNarrowTileSize = 96;

  bool large_tile =
      input_dims[1] >= kMinTileSize && input_dims[2] >= kMinTileSize;
  bool narrow_tile = input_dims[1] >= kMinNarrowTileSize ||
                     input_dims[2] >= kMinNarrowTileSize;
  if (large_tile) {
    // If the input is a large square, such as 32x32, use shared memory to do
    // the copy. A 32x32 tile is assumed to give the best performance, with 8
    // warps per block.
    constexpr int kTileSize = 32;
    constexpr int kNumThreads = 256;

    Dim3 input_dims_aligned = {
        input_dims[0],
        CeilOrFloor<int, true>(input_dims[1], kTileSize),
        CeilOrFloor<int, true>(input_dims[2], kTileSize),
    };

    IndexType total_tiles_count = input_dims_aligned[0];
    total_tiles_count *= input_dims_aligned[1];
    total_tiles_count *= input_dims_aligned[2];

    TilingSwapDim1And2<T, kNumThreads, kTileSize, kTileSize, IndexType>
        <<<total_tiles_count, kNumThreads, 0, d.stream()>>>(
            input, input_dims, output);

  } else if (narrow_tile) {
    // If the input shape is rectangular, such as 2x100, use a narrow tile
    // size. This is more involved, because we need to find a tile that covers
    // the input and still achieves the best coalescing.
    SwapDim1And2InNarrow<T, false, IndexType>(
        d, input, input_dims, output, kMinTileSize);
  } else {
    // If the input shape is small, such as 8x8, just do a simple copy.
    IndexType total_elements = input_dims[0];
    total_elements *= input_dims[1];
    total_elements *= input_dims[2];

    auto config = phi::backends::gpu::GetGpuLaunchConfig1D(d, total_elements);
    TransposeSimpleKernel<T, 0, 2, 1, IndexType>
        <<<config.block_per_grid.x, config.thread_per_block.x, 0, d.stream()>>>(
            total_elements, input, input_dims, output);
  }
}

template <typename T, typename IndexType = int>
struct SwapDim1And2InTranspose {
  typedef phi::GPUContext Device;
  void operator()(const Device& d,
                  const T* in,
                  const std::vector<int>& combined_dims,
                  T* out) {
    Dim3 input_dims = {static_cast<int>(combined_dims[0]),
                       static_cast<int>(combined_dims[1]),
                       static_cast<int>(combined_dims[2])};
    SendSwapDim1And2InTranspose<T, IndexType>(d, in, input_dims, out);
  }
};

template <typename T, typename IndexType = int>
struct SwapDim0And2InTranspose {
  typedef phi::GPUContext Device;
  void operator()(const Device& d,
                  const T* in,
                  const std::vector<int>& combined_dims,
                  T* out) {
    Dim3 input_dims = {static_cast<int>(combined_dims[0]),
                       static_cast<int>(combined_dims[1]),
                       static_cast<int>(combined_dims[2])};

    IndexType total_size = combined_dims[0];
    total_size *= combined_dims[1];
    total_size *= combined_dims[2];

    auto config = phi::backends::gpu::GetGpuLaunchConfig1D(d, total_size);

    TransposeSimpleKernel<T, 2, 1, 0, IndexType>
        <<<config.block_per_grid.x, config.thread_per_block.x, 0, d.stream()>>>(
            total_size, in, input_dims, out);
  }
};

// This function combines consecutive dimensions, for example:
// (0, 1, 3, 2) --> (0, 2, 1)
inline void CombineTransposeDim3(const DDim& shape,
                                 const std::vector<int>& perm,
                                 std::vector<int>* new_perm,
                                 std::vector<int>* new_dims) {
  PADDLE_ENFORCE_EQ(shape.size(),
                    perm.size(),
                    phi::errors::InvalidArgument(
                        "shape should have the same size as perm, but"
                        " received shape size is:%d, perm size is:%d.",
                        shape.size(),
                        perm.size()));

  std::vector<int> dim_vec;
  if (shape.size() == 1) {
    // If the input dimension is already 1, there is no need to combine dims.
    new_perm->resize(1);
    (*new_perm)[0] = perm[0];
    dim_vec.push_back(shape[0]);
  } else {
    int dim_idx = 0;
    std::vector<int> new_dim_pos(shape.size(), -1);
    std::vector<int> combined_dims(shape.size(), 0);

    int cur_head = perm[0];
    new_dim_pos[cur_head] = 0;
    combined_dims[0] = shape[cur_head];
    for (int perm_idx = 1; perm_idx < shape.size(); ++perm_idx) {
      // Combine consecutive dimensions.
      if (cur_head + 1 == perm[perm_idx]) {
        cur_head = perm[perm_idx];
        combined_dims[dim_idx] *= shape[cur_head];
      } else {
        // Otherwise start a new dimension.
        cur_head = perm[perm_idx];
        dim_idx++;
        new_dim_pos[cur_head] = dim_idx;
        combined_dims[dim_idx] = shape[cur_head];
      }
    }
    new_perm->resize(dim_idx + 1);

    dim_idx = 0;
    for (int i = 0; i < new_dim_pos.size(); ++i) {
      if (new_dim_pos[i] >= 0) {
        int new_perm_idx = new_dim_pos[i];
        (*new_perm)[dim_idx] = new_perm_idx;
        dim_vec.push_back(combined_dims[new_perm_idx]);
        dim_idx++;
      }
    }
  }
  *new_dims = dim_vec;
}
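
// A concrete trace (hypothetical shape): for shape (8, 16, 32, 64) and perm
// (0, 1, 3, 2), dims 0 and 1 are consecutive in perm and get merged, which
// yields new_dims = (128, 32, 64) and new_perm = (0, 2, 1); the caller can
// then treat the transpose as a plain dim1/dim2 swap on a 3-D tensor.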

template <typename T>
struct TransposeSimple {
  static bool Impl(const phi::GPUContext& ctx,
                   const phi::DenseTensor& in,
                   const std::vector<int32_t> perm,
                   phi::DenseTensor* out,
                   const int64_t numel) {
    if (numel >= std::numeric_limits<int32_t>::max()) {
      return Run<int64_t>(ctx, in, perm, out);
    } else {
      return Run<int32_t>(ctx, in, perm, out);
    }
  }

 private:
  template <typename IndexType = int32_t>
  static bool Run(const phi::GPUContext& ctx,
                  const phi::DenseTensor& in,
                  const std::vector<int32_t> perm,
                  phi::DenseTensor* out) {
    // First reduce the dimensions of the input tensor if possible.
    auto in_data = in.data<T>();
    auto out_data = out->data<T>();
    std::vector<int> new_perm;
    std::vector<int> new_dims;
    CombineTransposeDim3(in.dims(), perm, &new_perm, &new_dims);
    if (new_perm.size() < 2 || new_perm.size() > 3) return false;

    // In most cases, the rank will not be greater than 3 after combining.
    if (new_perm.size() == 2 && new_perm[1] == 0) {
      // Add the first dimension size as 1.
      new_dims.insert(new_dims.begin(), 1);
      SwapDim1And2InTranspose<T, IndexType>()(ctx, in_data, new_dims, out_data);
      return true;
    } else if (new_perm == std::vector<int>({0, 2, 1})) {
      SwapDim1And2InTranspose<T, IndexType>()(ctx, in_data, new_dims, out_data);
      return true;
    } else if (new_perm == std::vector<int>({2, 1, 0})) {
      // Maybe this can be optimized later: find a way to do a coalescing
      // memory copy. But it depends on the data size; if the span is not
      // large, coalescing may be possible.
      SwapDim0And2InTranspose<T, IndexType>()(ctx, in_data, new_dims, out_data);
      return true;
    } else {
      return false;
    }
  }
};

template <typename IndexT, int N>
class IdxHelper {
 public:
  IdxHelper() {}
  explicit IdxHelper(const IndexT* dims) {
    for (int i = N - 1; i >= 0; --i) {
      stride_[i] = i < (N - 1) ? dims[i + 1] * stride_[i + 1] : 1;
    }
  }

  __device__ __forceinline__ IndexT GetStride(int idx) const {
    return stride_[idx];
  }

  __device__ __forceinline__ void GetIndexFromOffset(IndexT offset,
                                                     IndexT* index) const {
    IndexT remaining = offset;
#pragma unroll
    for (int i = 0; i < N - 1; ++i) {
      const IndexT idx = remaining / stride_[i];
      remaining -= idx * stride_[i];
      index[i] = idx;
    }
    index[N - 1] = remaining;
  }

 private:
  IndexT stride_[N];
};
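
// For instance (hypothetical dims): with dims = {2, 3, 4} the strides are
// {12, 4, 1}, and GetIndexFromOffset(17, index) yields index = {1, 1, 1},
// since 17 = 1 * 12 + 1 * 4 + 1.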

template <int N>
class IdxHelper<uint32_t, N> {
 public:
  IdxHelper() {}
  explicit IdxHelper(const uint32_t* dims) {
    for (int i = N - 1; i >= 0; --i) {
      uint32_t value = i < (N - 1) ? dims[i + 1] * stride_[i + 1] : 1;
      divmoder_[i] = phi::kps::details::FastDivMod(value);
      stride_[i] = value;
    }
  }

  __device__ __forceinline__ uint32_t GetStride(int idx) const {
    return stride_[idx];
  }

  __device__ __forceinline__ void GetIndexFromOffset(uint32_t offset,
                                                     uint32_t* index) const {
    uint32_t remaining = offset;
#pragma unroll
    for (int i = 0; i < N - 1; ++i) {
      uint32_t idx = divmoder_[i].Div(remaining);
      index[i] = idx;
      remaining -= idx * stride_[i];
    }
    index[N - 1] = remaining;
  }

 private:
  uint32_t stride_[N];
  phi::kps::details::FastDivMod divmoder_[N];
};

// Transform indices between memory offsets and shape coordinates.
template <typename IndexT, int N>
class IdxAndOffsetHelper {
 public:
  IdxAndOffsetHelper() {}
  explicit IdxAndOffsetHelper(const IndexT* dims) {
    index_helper = IdxHelper<IndexT, N>(dims);
  }

  __device__ __forceinline__ IndexT IndexToOffset(const IndexT* index) const {
    IndexT offset = 0;
#pragma unroll
    for (int i = 0; i < N - 1; ++i) {
      offset += index[i] * index_helper.GetStride(i);
    }
    offset += index[N - 1];
    return offset;
  }

  __device__ __forceinline__ void OffsetToIndex(IndexT offset,
                                                IndexT* index) const {
    index_helper.GetIndexFromOffset(offset, index);
  }

 private:
  IdxHelper<IndexT, N> index_helper;
};

template <typename IndexT, int Rank>
struct PermuteParams {
 public:
  IdxAndOffsetHelper<IndexT, Rank> src_index_helper;
  IdxAndOffsetHelper<IndexT, Rank> dst_index_helper;
  int perm[Rank]{};

  explicit PermuteParams(const std::vector<int64_t>& dims,
                         const std::vector<int>& perm_) {
    IndexT dst_dims[Rank];
    IndexT src_dims[Rank];
    for (auto i = 0; i < Rank; ++i) {
      src_dims[i] = dims[i];
      dst_dims[i] = dims[perm_[i]];
      perm[i] = perm_[i];
    }
    dst_index_helper = IdxAndOffsetHelper<IndexT, Rank>(dst_dims);
    src_index_helper = IdxAndOffsetHelper<IndexT, Rank>(src_dims);
  }
};
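
// In the permute kernels below, src_index[params.perm[j]] = dst_index[j]
// inverts the permutation. For example (hypothetical): with perm = {0, 2, 1}
// and dst_index = {i, j, k}, the gathered src_index is {i, k, j}, i.e.
// dst(i, j, k) = src(i, k, j).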

// A special kernel for the target case; both vectorized reads and writes are
// supported.
template <typename T, typename IndexT, int VecSize, int Rank>
__global__ void VectorizedPermuteKernel(PermuteParams<IndexT, Rank> params,
                                        const IndexT count,
                                        const T* __restrict__ src_data,
                                        T* dst_data) {
  using VecT = phi::AlignedVector<T, VecSize>;
  IndexT src_index[Rank];
  IndexT dst_index[Rank];

  const VecT* __restrict__ vec_src =
      reinterpret_cast<const VecT* __restrict__>(src_data);
  VecT* vec_dst = reinterpret_cast<VecT*>(dst_data);

  IndexT tid = blockIdx.x * blockDim.x + threadIdx.x;
  for (IndexT i = tid; i < count; i += blockDim.x * gridDim.x) {
    params.dst_index_helper.OffsetToIndex(i, dst_index);

#pragma unroll
    for (int j = 0; j < Rank; ++j) {
      src_index[params.perm[j]] = dst_index[j];
    }
    IndexT src_offset = params.src_index_helper.IndexToOffset(src_index);
    vec_dst[i] = vec_src[src_offset];
  }
}

// A general kernel for the normal case; only vectorized writes are supported.
template <typename T, typename IndexT, int VecSize, int Rank>
__global__ void GeneralPermuteKernel(PermuteParams<IndexT, Rank> params,
                                     const IndexT main_cnt,
                                     const IndexT tail_cnt,
                                     const IndexT offset,
                                     const T* __restrict__ src,
                                     T* dst) {
  using VecT = phi::AlignedVector<T, VecSize>;
  VecT* vec_dst = reinterpret_cast<VecT*>(dst);
  IndexT src_index[VecSize][Rank];
  IndexT dst_index[VecSize][Rank];

  // Vectorized data loading.
  IndexT tid = blockIdx.x * blockDim.x + threadIdx.x;
  for (IndexT idx = tid; idx < main_cnt; idx += blockDim.x * gridDim.x) {
    VecT vec_data;
    IndexT vec_idx = idx * VecSize;

#pragma unroll
    for (int i = 0; i < VecSize; ++i) {
      params.dst_index_helper.OffsetToIndex(vec_idx + i, dst_index[i]);

#pragma unroll
      for (int j = 0; j < Rank; ++j) {
        src_index[i][params.perm[j]] = dst_index[i][j];
      }
      IndexT src_offset = params.src_index_helper.IndexToOffset(src_index[i]);
      vec_data[i] = src[src_offset];
    }
    vec_dst[idx] = vec_data;
  }

  // Scalar loading for the tail elements.
  if (tid < tail_cnt) {
    IndexT idx = tid + offset;
    params.dst_index_helper.OffsetToIndex(idx, dst_index[0]);

#pragma unroll
    for (int j = 0; j < Rank; ++j) {
      src_index[0][params.perm[j]] = dst_index[0][j];
    }
    IndexT src_offset = params.src_index_helper.IndexToOffset(src_index[0]);
    dst[idx] = src[src_offset];
  }
}

template <typename T, typename IndexT, int ReadSize, int WriteSize>
struct TransposeDataWriter {
  __device__ __forceinline__ void operator()(T* dst_data,
                                             const T* s_data,
                                             const IndexT rows,
                                             const IndexT cols,
                                             const IndexT chs_stride,
                                             const IndexT round_tile_cols,
                                             const IndexT col_stride = 1) {
    using OutVecT = phi::AlignedVector<T, WriteSize>;
    OutVecT* vec_dst = reinterpret_cast<OutVecT*>(dst_data);

    constexpr int kColTile = kTileSize * ReadSize;
    constexpr int kColStride = kShareCol * ReadSize;

    const IndexT vec_rows = rows / WriteSize;
    const IndexT col_in_mat = blockIdx.y * kTileSize + threadIdx.x;

    if (col_in_mat < /*dst_cols=*/vec_rows) {
      const int cols_range = (blockIdx.x < round_tile_cols)
                                 ? kTileSize
                                 : (cols - round_tile_cols * kTileSize);
      const int share_tile = threadIdx.x * (WriteSize * kColStride);
      const IndexT write_offset = blockIdx.z * chs_stride + col_in_mat;
#pragma unroll
      for (int tile_y = threadIdx.y; tile_y < cols_range;
           tile_y += kBlockRows) {
        OutVecT tmp_data[ReadSize];
#pragma unroll
        for (int i = 0; i < ReadSize; ++i) {
          int tile_tail = tile_y * ReadSize + i;
          int major_share_idx = share_tile + tile_tail;
          IndexT row_in_mat = (blockIdx.x * kColTile + tile_tail) * col_stride;
#pragma unroll
          for (int j = 0; j < WriteSize; ++j) {
            tmp_data[i].val[j] = s_data[j * kColStride + major_share_idx];
          }
          vec_dst[write_offset + row_in_mat * vec_rows] = tmp_data[i];
        }
      }
    }
  }
};

template <typename T, typename IndexT, int ReadSize>
struct TransposeDataWriter<T, IndexT, ReadSize, 1> {
  __device__ __forceinline__ void operator()(T* dst_data,
                                             const T* s_data,
                                             const IndexT rows,
                                             const IndexT cols,
                                             const IndexT chs_stride,
                                             const IndexT round_tile_cols,
                                             const IndexT col_stride = 1) {
    const IndexT col_in_mat = blockIdx.y * kTileSize + threadIdx.x;
    if (col_in_mat < /*dst_cols=*/rows) {
      const int cols_range = (blockIdx.x < round_tile_cols)
                                 ? kTileSize
                                 : (cols - round_tile_cols * kTileSize);
      const IndexT row_tile = blockIdx.x * kTileSize * ReadSize;
      const IndexT write_offset = blockIdx.z * chs_stride + col_in_mat;
      const int shared_tile = threadIdx.x * kShareCol * ReadSize;
#pragma unroll
      for (int tile_y = threadIdx.y; tile_y < cols_range;
           tile_y += kBlockRows) {
        const int shared_major = shared_tile + tile_y * ReadSize;
        const IndexT row_major = (row_tile + tile_y * ReadSize) * col_stride;
#pragma unroll
        for (int i = 0; i < ReadSize; ++i) {
          const IndexT row_in_mat = row_major + i * col_stride;
          dst_data[write_offset + row_in_mat * rows] = s_data[shared_major + i];
        }
      }
    }
  }
};

template <typename T, typename IndexT, int VecSize, IndexT RowTile>
struct TransposeDataReader {
  __device__ __forceinline__ void operator()(const T* __restrict__ src,
                                             T* s_shared,
                                             const IndexT cols,
                                             const IndexT rows,
                                             const IndexT chs_stride,
                                             const IndexT cols_thresh,
                                             const IndexT round_tile_rows) {
    using VecT = phi::AlignedVector<T, VecSize>;
    const VecT* __restrict__ v_src =
        reinterpret_cast<const VecT* __restrict__>(src);
    VecT* v_shared = reinterpret_cast<VecT*>(s_shared);

    const IndexT col_in_mat = blockIdx.x * kTileSize + threadIdx.x;
    if (col_in_mat < cols_thresh) {
      const int row_range = (blockIdx.y < round_tile_rows)
                                ? RowTile
                                : (rows - RowTile * round_tile_rows);
      const IndexT src_idx_major = blockIdx.z * chs_stride + col_in_mat;
#pragma unroll
      for (int tile_y = threadIdx.y; tile_y < row_range; tile_y += kBlockRows) {
        const IndexT row_in_mat = blockIdx.y * RowTile + tile_y;
        v_shared[tile_y * kShareCol + threadIdx.x] =
            v_src[row_in_mat * cols + src_idx_major];
      }
    }
    __syncthreads();
  }
};

// Aims at transposing the last 2 dimensions. Reference:
// https://developer.nvidia.com/blog/efficient-matrix-transpose-cuda-cc/
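// A sketch of the scheme: each thread block stages a tile of the source
// matrix in shared memory through TransposeDataReader (coalesced, optionally
// vectorized reads), then TransposeDataWriter emits the tile transposed with
// coalesced, optionally vectorized writes; kShareCol (defined in an included
// header) is used as the padded row stride of the shared buffer, presumably
// to avoid bank conflicts.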
template <typename T,
          typename IndexT,
          bool IsVecWrite,
          int ReadSize,
          int WriteSize = (IsVecWrite && (sizeof(T) < sizeof(float)))
                              ? sizeof(float) / sizeof(T)
                              : 1>
__global__ void SwapTransposeKernel(const T* __restrict__ src_data,
                                    T* dst_data,
                                    const IndexT round_tile_rows,
                                    const IndexT round_tile_cols,
                                    const IndexT cols,
                                    const IndexT rows,
                                    const IndexT chs /*=channel*/) {
  constexpr int kRowTile = kTileSize * WriteSize;
  __shared__ T s_data[kRowTile * kShareCol * ReadSize];

  const IndexT chs_stride = chs * cols;
  TransposeDataReader<T, IndexT, ReadSize, kRowTile>()(
      src_data, s_data, chs_stride, rows, cols, cols, round_tile_rows);
  TransposeDataWriter<T, IndexT, ReadSize, WriteSize>()(
      dst_data, s_data, rows, cols, rows / WriteSize, round_tile_cols, chs);
}

template <typename T,
          typename IndexT,
          bool IsVecWrite,
          int ReadSize,
          int WriteSize = (IsVecWrite && (sizeof(T) < sizeof(float)))
                              ? sizeof(float) / sizeof(T)
                              : 1>
__global__ void BatchTransposeKernel(const T* __restrict__ src_data,
                                     T* dst_data,
                                     const IndexT round_tile_rows,
                                     const IndexT round_tile_cols,
                                     const IndexT cols,
                                     const IndexT rows) {
  constexpr int kRowTile = kTileSize * WriteSize;
  __shared__ T s_data[kRowTile * kShareCol * ReadSize];

  const IndexT chs_stride = rows * cols;
  TransposeDataReader<T, IndexT, ReadSize, kRowTile>()(
      src_data, s_data, cols, rows, chs_stride, cols, round_tile_rows);
  TransposeDataWriter<T, IndexT, ReadSize, WriteSize>()(
      dst_data,
      s_data,
      rows,
      cols,
      chs_stride * ReadSize / WriteSize,
      round_tile_cols);
}

template <typename T, typename IndexT, int VecSize>
struct PermuteLauncher {
 public:
  void operator()(const phi::GPUContext& ctx,
                  const int& rank,
                  const IndexT& count,
                  const PermuteType& perm_type,
                  const std::vector<int64_t>& dims,
                  const std::vector<int32_t>& perm,
                  const T* src,
                  T* dst) {
    dims_ = dims;
    main_cnt_ = count / VecSize;
#define CALL_PERMUTE_DISPATCH_RANK(rank_)              \
  case rank_: {                                        \
    Run<rank_>(ctx, perm, perm_type, count, src, dst); \
    break;                                             \
  }

    switch (rank) {
      CALL_PERMUTE_DISPATCH_RANK(3);
      CALL_PERMUTE_DISPATCH_RANK(4);
      CALL_PERMUTE_DISPATCH_RANK(5);
      CALL_PERMUTE_DISPATCH_RANK(6);
      CALL_PERMUTE_DISPATCH_RANK(7);
      CALL_PERMUTE_DISPATCH_RANK(8);
      CALL_PERMUTE_DISPATCH_RANK(9);
    }
#undef CALL_PERMUTE_DISPATCH_RANK
  }

 private:
  IndexT main_cnt_{0};
  std::vector<int64_t> dims_;

  template <int Rank>
  void Run(const phi::GPUContext& ctx,
           const std::vector<int32_t>& perm,
           const PermuteType& perm_type,
           const IndexT& count,
           const T* src,
           T* dst) {
    auto cfg = phi::backends::gpu::GetGpuLaunchConfig1D(ctx, main_cnt_);
    if (perm_type == PermuteType::kVecPermute) {
      dims_[Rank - 1] /= VecSize;
      const auto params = PermuteParams<IndexT, Rank>(dims_, perm);

      VectorizedPermuteKernel<T, IndexT, VecSize, Rank>
          <<<cfg.block_per_grid, cfg.thread_per_block, 0, ctx.stream()>>>(
              params, main_cnt_, src, dst);
    } else {
      IndexT tail_cnt = count - main_cnt_ * VecSize;
      IndexT main_offset = count - tail_cnt;
      const auto params = PermuteParams<IndexT, Rank>(dims_, perm);

      GeneralPermuteKernel<T, IndexT, VecSize, Rank>
          <<<cfg.block_per_grid, cfg.thread_per_block, 0, ctx.stream()>>>(
              params, main_cnt_, tail_cnt, main_offset, src, dst);
    }
  }
};

template <typename T, typename IndexT, int VecSize>
struct TransposeLauncher {
 public:
  void operator()(const phi::GPUContext& ctx,
                  const int& rank,
                  const PermuteType& perm_type,
                  const std::vector<int64_t>& dims,
                  const IndexT& num_rows_tile,
                  const T* src,
                  T* dst) {
    constexpr int ReadSize = sizeof(T) > sizeof(float) ? 1 : VecSize;
    const IndexT cols = dims[rank - 1] / VecSize;
    const IndexT n_cols_tile = GETTILESIZE(cols, kTileSize);

    if (perm_type == PermuteType::kGeneralTranspose) {
      IndexT chs = (rank == 2) ? 1 : dims[0];
      IndexT rows = dims[rank - 2];
      IndexT n_rows_tile =
          FindRowTiles(chs, rows, num_rows_tile, n_cols_tile, ctx.GetSMCount());
      dim3 blocks(n_cols_tile, n_rows_tile, chs);
      dim3 threads(kTileSize, kBlockRows, 1);

      if (is_vec_write) {
        BatchTransposeKernel<T, IndexT, true, ReadSize>
            <<<blocks, threads, 0, ctx.stream()>>>(
                src, dst, n_rows_tile - 1, n_cols_tile - 1, cols, rows);
      } else {
        BatchTransposeKernel<T, IndexT, false, ReadSize>
            <<<blocks, threads, 0, ctx.stream()>>>(
                src, dst, n_rows_tile - 1, n_cols_tile - 1, cols, rows);
      }
    } else {
      IndexT rows = dims[0];
      IndexT chs = dims[rank - 2];
      IndexT n_rows_tile =
          FindRowTiles(chs, rows, num_rows_tile, n_cols_tile, ctx.GetSMCount());
      dim3 blocks(n_cols_tile, n_rows_tile, chs);
      dim3 threads(kTileSize, kBlockRows, 1);

      if (is_vec_write) {
        SwapTransposeKernel<T, IndexT, true, ReadSize>
            <<<blocks, threads, 0, ctx.stream()>>>(
                src, dst, n_rows_tile - 1, n_cols_tile - 1, cols, rows, chs);
      } else {
        SwapTransposeKernel<T, IndexT, false, ReadSize>
            <<<blocks, threads, 0, ctx.stream()>>>(
                src, dst, n_rows_tile - 1, n_cols_tile - 1, cols, rows, chs);
      }
    }
  }

 private:
  bool is_vec_write{false};
  inline IndexT FindRowTiles(const IndexT& chs,
                             const IndexT& rows,
                             const IndexT& num_rows_tile,
                             const IndexT& num_cols_tile,
                             const int& sm_count) {
    constexpr int kVecRow = sizeof(float) / sizeof(T);
    is_vec_write =
        (sizeof(T) < sizeof(float)) ? ((rows % kVecRow) ? false : true) : false;

    int vec_write = 1;
    if (is_vec_write) {
      is_vec_write = (chs * num_cols_tile * num_rows_tile) > sm_count;
      vec_write = is_vec_write ? kVecRow : 1;
    }
    IndexT n_rows_tile = is_vec_write
                             ? GETTILESIZE(rows, (kTileSize * vec_write))
                             : num_rows_tile;
    return n_rows_tile;
  }
};

template <typename T, typename IndexT>
struct PermuteDispatch {
 public:
  PermuteDispatch(const phi::GPUContext& ctx,
                  PermTypeClassifier<T>* cls_ptr,
                  const std::vector<int64_t>& dims,
                  const std::vector<int32_t>& perm,
                  const IndexT count,
                  const T* src,
                  T* dst)
      : dims_(dims), cls_(cls_ptr) {
    rank_ = dims_.size();
    type_ = cls_->GetPermType();
    KernelTypeDispatch(ctx, count, perm, src, dst);
  }
  ~PermuteDispatch() {}

 private:
  int rank_{0};
  std::vector<int64_t> dims_;
  PermTypeClassifier<T>* cls_;
  PermuteType type_{kGeneralPermute};

  void KernelTypeDispatch(const phi::GPUContext& ctx,
                          const IndexT& count,
                          const std::vector<int32_t>& perm,
                          const T* src,
                          T* dst) {
#define TRANSPOSE_DISPATCH_VEC_SIZE(size)                         \
  case size: {                                                    \
    TransposeLauncher<T, IndexT, size>()(                         \
        ctx, rank_, type_, dims_, cls_->GetRowsTile(), src, dst); \
    break;                                                        \
  }

#define PERMUTE_DISPATCH_VEC_SIZE(size)                   \
  case size: {                                            \
    PermuteLauncher<T, IndexT, size>()(                   \
        ctx, rank_, count, type_, dims_, perm, src, dst); \
    break;                                                \
  }

    switch (type_) {
      case kSwapTranspose:
      case kGeneralTranspose:
        switch (cls_->GetVecSize()) {
          TRANSPOSE_DISPATCH_VEC_SIZE(1);
          TRANSPOSE_DISPATCH_VEC_SIZE(2);
          TRANSPOSE_DISPATCH_VEC_SIZE(4);
        }
        break;
      default:
        switch (cls_->GetVecSize()) {
          PERMUTE_DISPATCH_VEC_SIZE(1);
          PERMUTE_DISPATCH_VEC_SIZE(2);
          PERMUTE_DISPATCH_VEC_SIZE(4);
        }
        break;
    }
#undef TRANSPOSE_DISPATCH_VEC_SIZE
#undef PERMUTE_DISPATCH_VEC_SIZE
  }
};

template <typename T>
inline void PermuteAndTranspose(const phi::GPUContext& ctx,
                                const int& rank,
                                const phi::DenseTensor& in,
                                phi::DenseTensor* out,
                                const DimsSimplifier& simplifier) {
  T* dst_data = out->data<T>();
  const T* src_data = in.data<T>();
  const auto count = simplifier.GetCount();
  auto classifier = PermTypeClassifier<T>(ctx.GetSMCount(),
                                          simplifier.GetRank(),
                                          simplifier.GetPerm(),
                                          simplifier.GetSrcDims(),
                                          src_data,
                                          dst_data);
  if (classifier.GetPermType() == PermuteType::kCopy) {
    // If perm is [0, 1, 2, 3], just perform a device-to-device copy.
    phi::backends::gpu::GpuMemcpyAsync(dst_data,
                                       src_data,
                                       count * sizeof(T),
                                       phi::gpuMemcpyDeviceToDevice,
                                       ctx.stream());
  } else {
    if (count < std::numeric_limits<uint32_t>::max()) {
      PermuteDispatch<T, uint32_t>(ctx,
                                   &classifier,
                                   simplifier.GetSrcDims(),
                                   simplifier.GetPerm(),
                                   static_cast<uint32_t>(count),
                                   src_data,
                                   dst_data);
    } else {
      PermuteDispatch<T, int64_t>(ctx,
                                  &classifier,
                                  simplifier.GetSrcDims(),
                                  simplifier.GetPerm(),
                                  static_cast<int64_t>(count),
                                  src_data,
                                  dst_data);
    }
  }
}

template <typename T>
inline void PermuteWithEigen(const phi::GPUContext& ctx,
                             const int& rank,
                             const phi::DenseTensor& in,
                             phi::DenseTensor* out,
                             const DimsSimplifier& simplifier) {
  const bool not_same_dims = simplifier.GetRank() != rank;
  if (not_same_dims) {
    phi::DDim dst_dims = out->dims();
    phi::DenseTensor temp_in;

    temp_in.ShareBufferWith(in);
    temp_in.Resize(phi::make_ddim(simplifier.GetSrcDims()));
    out->Resize(phi::make_ddim(simplifier.GetDstDims()));

    TransCompute<phi::GPUContext, T>(
        simplifier.GetRank(), ctx, temp_in, out, simplifier.GetPerm());
    out->Resize(dst_dims);
  } else {
    TransCompute<phi::GPUContext, T>(
        simplifier.GetRank(), ctx, in, out, simplifier.GetPerm());
  }
}

template <typename T>
void TransposeGPUKernelDriver(const phi::GPUContext& ctx,
                              const phi::DenseTensor& in,
                              const std::vector<int32_t>& perm,
                              phi::DenseTensor* out) {
  const int rank = perm.size();
  int64_t numel = in.numel();
  bool ret = TransposeSimple<T>::Impl(ctx, in, perm, out, numel);
  if (!ret) {
    auto simplifier =
        DimsSimplifier(rank, numel, perm, phi::vectorize<int64_t>(in.dims()));
    auto* tuner = phi::autotune::MakeTransposeTuner<T>(PermuteWithEigen<T>);
    tuner->AddCallBack(PermuteAndTranspose<T>);

    size_t key = phi::autotune::TransposeKey(
        simplifier.GetSrcDims(),
        simplifier.GetPerm(),
        paddle::experimental::CppTypeToDataType<T>::Type());

    tuner->Run(ctx,
               phi::autotune::AlgorithmType::kTranspose,
               key,
               ctx,
               rank,
               in,
               out,
               simplifier);
  }
}
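
// Minimal usage sketch (hypothetical caller; `dev_ctx`, `x` and `out` are
// assumptions, not part of this header): a kernel that has already allocated
// `out` with the permuted shape can forward to the driver above, which first
// tries TransposeSimple and otherwise autotunes between the Eigen path and
// the permute kernels:
//
//   std::vector<int32_t> perm = {0, 2, 3, 1};
//   dev_ctx.template Alloc<T>(out);
//   phi::funcs::TransposeGPUKernelDriver<T>(dev_ctx, x, perm, out);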

}  // namespace funcs
}  // namespace phi