// File: transpose_op.cu.h
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <limits>
#include <type_traits>
#include <utility>
#include <vector>

#include "paddle/fluid/framework/gpu_utils.h"
#include "paddle/fluid/operators/transpose_op.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/fluid/platform/fast_divmod.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/autotune/auto_tune_base.h"

namespace paddle {
namespace operators {

// Short local aliases for framework types used throughout this header.
using Tensor = framework::Tensor;
// Shape and coordinate containers used by the 3-D transpose kernels below
// (indexed with [0], [1], [2]).
using Dim3 = framework::Dim3;
using Index3 = framework::Index3;

// Comparison functor usable in constant expressions: `lhs == rhs`.
struct EqualTo {
  constexpr bool operator()(int lhs, int rhs) const {
    return lhs == rhs;
  }
};

// Comparison functor usable in constant expressions: `lhs > rhs`.
struct GreaterThan {
  constexpr bool operator()(int lhs, int rhs) const {
    return rhs < lhs;
  }
};

// Value can be decided at compile time.
//
// Compares `tile_short` against the reference short side that this table
// pairs with `tile_long` for elements of `size_T` bytes, using `op`
// (e.g. EqualTo for an exact match, GreaterThan for "exceeds the limit").
// Returns false when `tile_long` is not a listed long side for `size_T`, or
// when `size_T` is not 1, 2, 4, 8 or 16.
//
// Reference short side per long side (long side in multiples of INT_32):
//   size_T == 16        : 1x -> 4, 2x -> 4, 4x -> 4, 8x -> 2
//   size_T in {1,2,4,8} : 1x -> 15, 2x -> 15, 4x -> 8, 8x -> 4, 16x -> 2
template <typename FUN, int INT_32 = 32>
constexpr bool CheckProperTileSize(int tile_long,
                                   int tile_short,
                                   int size_T,
                                   FUN op) {
  return (size_T == 16 && ((tile_long == INT_32 && op(tile_short, 4)) ||
                           (tile_long == 2 * INT_32 && op(tile_short, 4)) ||
                           (tile_long == 4 * INT_32 && op(tile_short, 4)) ||
                           (tile_long == 8 * INT_32 && op(tile_short, 2)))) ||
         ((size_T == 8 || size_T == 4 || size_T == 2 || size_T == 1) &&
          ((tile_long == INT_32 && op(tile_short, 15)) ||
           (tile_long == 2 * INT_32 && op(tile_short, 15)) ||
           (tile_long == 4 * INT_32 && op(tile_short, 8)) ||
           (tile_long == 8 * INT_32 && op(tile_short, 4)) ||
           (tile_long == 16 * INT_32 && op(tile_short, 2))));
}

// True when (tile_long, tile_short) exactly matches one of the reference
// tile shapes in CheckProperTileSize for `size_T`-byte elements.
constexpr bool CheckLongTileSize(int tile_long, int tile_short, int size_T) {
  return CheckProperTileSize<EqualTo>(
      tile_long, tile_short, size_T, EqualTo{});
}

// True when `tile_short` is larger than the reference short side that
// CheckProperTileSize pairs with `tile_long` for `size_T`-byte elements
// (false if `tile_long` is not a listed long side).
constexpr bool CheckOutsideTileSize(int tile_long, int tile_short, int size_T) {
  return CheckProperTileSize<GreaterThan>(
      tile_long, tile_short, size_T, GreaterThan{});
}

// True for a (tile_long, tile_short) pair that is neither a reference shape
// nor beyond its limit, but where doubling the long side or incrementing the
// short side would exceed a limit.
constexpr bool CheckNonLongTileSize(int tile_long, int tile_short, int size_T) {
  return !(CheckOutsideTileSize(tile_long, tile_short, size_T) ||
           CheckLongTileSize(tile_long, tile_short, size_T)) &&
         (CheckOutsideTileSize(tile_long * 2, tile_short, size_T) ||
          CheckOutsideTileSize(tile_long, tile_short + 1, size_T));
}

// Use shared memory (SM) to do the data transfer: load a tile into SM, then
// store it out transposed. Tile reads and writes are coalescing, which
// speeds up the memory copy.
//
// Swaps dims 1 and 2 of a 3-D input: output dims are
// {input_dims[0], input_dims[2], input_dims[1]}.
// Launch contract (checked by the asserts below): 1-D grid with one block
// per TileX (dim 1) x TileY (dim 2) input tile, 1-D block of exactly
// NumThreads threads. Shared memory is static: TileX * (TileY + 1) elements.
template <typename T,
          int NumThreads,
          int TileX,
          int TileY,
          typename IDX_T = int>
__global__ void TilingSwapDim1And2(const T* __restrict__ input,
                                   Dim3 input_dims,
                                   T* __restrict__ output) {
  assert(blockDim.x == NumThreads);
  assert(blockDim.y == 1);
  assert(blockDim.z == 1);
  assert(gridDim.y == 1);
  assert(gridDim.z == 1);

  constexpr int BlockReadRows = NumThreads / TileY;
  constexpr int BlockWriteRows = NumThreads / TileX;

  // One extra element in the inner dimension to avoid shared memory bank
  // conflicts when the tile is read back column-wise.
  __shared__ __align__(
      alignof(T)) char share_mem_ptr[TileX * (TileY + 1) * sizeof(T)];
  typedef T(*ShareMemory)[TileY + 1];

  ShareMemory tile_sm = reinterpret_cast<ShareMemory>(share_mem_ptr);

  int x = threadIdx.x;

  Dim3 output_dims = {
      input_dims[0],
      input_dims[2],
      input_dims[1],
  };

  // Number of tiles along each dimension (dims 1 and 2 rounded up).
  Dim3 tile_aligned_input_dim = {
      input_dims[0],
      (input_dims[1] + TileX - 1) / TileX,
      (input_dims[2] + TileY - 1) / TileY,
  };

  // Convert the block index to a tile index; each block processes one tile.
  Index3 input_block_tile_index =
      framework::ConvertTensorIndex<IDX_T>(blockIdx.x, tile_aligned_input_dim);

  // Element coordinate of the tile origin: 0, TileX, 2 * TileX, ...
  Index3 block_tile_index_in_input = {
      input_block_tile_index[0],
      input_block_tile_index[1] * TileX,
      input_block_tile_index[2] * TileY,
  };

  // Flat offset of the tile origin in the input.
  IDX_T input_origin_block_flat_index =
      framework::FlatTensorIndex<IDX_T>(block_tile_index_in_input, input_dims);

  bool full_tile = true;
  IDX_T tile_width = TileY;

  // The last tile along dim 2 may be partial.
  if (input_block_tile_index[2] == tile_aligned_input_dim[2] - 1) {
    tile_width = input_dims[2] - (tile_aligned_input_dim[2] - 1) * TileY;
    full_tile = false;
  }

  IDX_T tile_height = TileX;

  // The last tile along dim 1 may be partial.
  if (input_block_tile_index[1] == tile_aligned_input_dim[1] - 1) {
    tile_height = input_dims[1] - (tile_aligned_input_dim[1] - 1) * TileX;
    full_tile = false;
  }

  // Only the first (NumThreads / TileY) * TileY threads read.
  constexpr IDX_T in_effective_thread_num = NumThreads / TileY * TileY;

  if (x < in_effective_thread_num) {
    // Read a tile from input using the block.
    int x_i = x / TileY;
    int x_j = x % TileY;
    // Widen before multiplying: row offsets can exceed the int range when
    // IDX_T is 64-bit and input_dims[2] is large.
    IDX_T input_ind = input_origin_block_flat_index +
                      static_cast<IDX_T>(x_i) * input_dims[2] + x_j;
    IDX_T input_inc = static_cast<IDX_T>(BlockReadRows) * input_dims[2];

    if (full_tile) {
#pragma unroll
      for (int ind_i = x_i; ind_i < (TileX); ind_i += BlockReadRows) {
        tile_sm[ind_i][x_j] = input[input_ind];
        input_ind += input_inc;
      }
    } else {
      if (x_j < tile_width) {
#pragma unroll
        for (IDX_T ind_i = x_i; ind_i < (tile_height); ind_i += BlockReadRows) {
          tile_sm[ind_i][x_j] = input[input_ind];
          input_ind += input_inc;
        }
      }
    }
  }

  // Reached by every thread of the block: no divergent barrier.
  __syncthreads();

  // Store the shared-memory tile back to the output, transposed.
  Index3 output_block_tile_index = {
      input_block_tile_index[0],
      input_block_tile_index[2],
      input_block_tile_index[1],
  };

  Index3 block_tile_index_in_output = {
      output_block_tile_index[0],
      output_block_tile_index[1] * TileY,
      output_block_tile_index[2] * TileX,
  };

  IDX_T output_origin_block_flat_index = framework::FlatTensorIndex<IDX_T>(
      block_tile_index_in_output, output_dims);

  // Only the first (NumThreads / TileX) * TileX threads write.
  constexpr IDX_T out_effective_thread_num = NumThreads / TileX * TileX;

  if (x < out_effective_thread_num) {
    int x_i = x / TileX;
    int x_j = x % TileX;
    // Same widening as on the read side.
    IDX_T output_ind = output_origin_block_flat_index +
                       static_cast<IDX_T>(x_i) * output_dims[2] + x_j;
    IDX_T output_inc = static_cast<IDX_T>(BlockWriteRows) * output_dims[2];

    if (full_tile) {
#pragma unroll
      for (int ind_i = x_i; ind_i < (TileY); ind_i += BlockWriteRows) {
        output[output_ind] = tile_sm[x_j][ind_i];
        output_ind += output_inc;
      }
    } else {
      if (x_j < tile_height) {
#pragma unroll
        for (IDX_T ind_i = x_i; ind_i < (tile_width); ind_i += BlockWriteRows) {
          output[output_ind] = tile_sm[x_j][ind_i];
          output_ind += output_inc;
        }
      }
    }
  }
}

// Collects candidate (long_side, short_side) tile shapes for elements of
// TSIZE bytes into `tiles`, in increasing long-side order. For each long side
// 32, 64, ..., 1024, the first short side in [2, 15] accepted by
// CheckLongTileSize is appended. The scan stops as soon as a shape with short
// side 2 is appended.
// Returns true if that terminating shape was found, false otherwise.
template <int TSIZE>
bool SelectProperTileSize(std::vector<std::pair<int, int>>* tiles) {
  PADDLE_ENFORCE_LE(
      TSIZE,
      16,
      platform::errors::InvalidArgument(
          "The tile size should be no larger than 16, but received is:%d.",
          TSIZE));

  PADDLE_ENFORCE_EQ(
      (TSIZE & (TSIZE - 1)),
      0,
      platform::errors::InvalidArgument(
          "Data types should be powers of 2, but received size is:%d.",
          TSIZE));

  const int kMaxLongSideLen = 1024;
  const int kMaxShortSideLen = 15;

  for (int long_side = 32; long_side <= kMaxLongSideLen; long_side *= 2) {
    for (int short_side = 2; short_side <= kMaxShortSideLen; short_side += 1) {
      if (CheckLongTileSize(long_side, short_side, TSIZE)) {
        tiles->push_back(std::make_pair(long_side, short_side));

        if (short_side == 2) return true;

        break;
      }
    }
  }
  return false;
}

// Maps an element byte size to a built-in type of that size, so elements of
// any T can be copied as raw values of the same width.
template <int ByteSize>
struct SystemElemType;

template <>
struct SystemElemType<1> {
  typedef uint8_t type;
};

template <>
struct SystemElemType<2> {
  typedef uint16_t type;
};

template <>
struct SystemElemType<4> {
  typedef uint32_t type;
};

template <>
struct SystemElemType<8> {
  typedef uint64_t type;
};

template <>
struct SystemElemType<16> {
  typedef float4 type;
};

// Launches TilingSwapDim1And2 with `tile_long` threads per block and one block
// per tile. If the requested tile fits as tile_long along dim 1 by tile_short
// along dim 2, that orientation is used. Otherwise the tile is oriented as
// tile_short along dim 1 by tile_long along dim 2.
template <typename T, int tile_long, int tile_short, typename IDX_T = int>
void LaunchNarrowDims2TransposeKernel(const phi::GPUContext& d,
                                      int tile_size_i,
                                      int tile_size_j,
                                      IDX_T total_tiles_count,
                                      const T* input,
                                      const Dim3& input_dims,
                                      T* output) {
  constexpr int NumThreads = tile_long;
  if (tile_size_i <= tile_long && tile_size_j <= tile_short) {
    TilingSwapDim1And2<T, NumThreads, tile_long, tile_short, IDX_T>
        <<<total_tiles_count, NumThreads, 0, d.stream()>>>(
            input, input_dims, output);
  } else {
    TilingSwapDim1And2<T, NumThreads, tile_short, tile_long, IDX_T>
        <<<total_tiles_count, NumThreads, 0, d.stream()>>>(
            input, input_dims, output);
  }
}

// Recursive template dispatch over tile shapes.
// The primary template is used when (tile_long, tile_short, sizeof(T)) is
// neither a reference ("long") shape nor a "non-long" boundary shape; the two
// partial specializations below cover those cases.
// If the requested tile_size_i x tile_size_j fits in tile_long x tile_short,
// the kernel is launched. Otherwise the dispatch recurses: tile_long is doubled
// when the long side is too short, and tile_short is incremented otherwise.
template <typename T,
          int tile_long,
          int tile_short,
          typename IDX_T = int,
          typename dummy = void>
struct NarrowDims2TransposeDispatch {
  static void DoTranspose(const phi::GPUContext& d,
                          int tile_size_i,
                          int tile_size_j,
                          IDX_T total_tiles_count,
                          const T* input,
                          const Dim3& input_dims,
                          T* output) {
    PADDLE_ENFORCE_EQ(
        (tile_long & (tile_long - 1)),
        0,
        platform::errors::InvalidArgument(
            "The length of the longer side of the tile should be power of 2."
            " But received value is:%d.",
            tile_long));

    bool request_satisfied = std::max(tile_size_i, tile_size_j) <= tile_long &&
                             std::min(tile_size_i, tile_size_j) <= tile_short;

    if (request_satisfied) {
      LaunchNarrowDims2TransposeKernel<T, tile_long, tile_short, IDX_T>(
          d,
          tile_size_i,
          tile_size_j,
          total_tiles_count,
          input,
          input_dims,
          output);
      return;
    }

    const bool long_side_request_not_satisfied =
        std::max(tile_size_i, tile_size_j) > tile_long;

    if (long_side_request_not_satisfied) {
      NarrowDims2TransposeDispatch<T, tile_long * 2, tile_short, IDX_T>::
          DoTranspose(d,
                      tile_size_i,
                      tile_size_j,
                      total_tiles_count,
                      input,
                      input_dims,
                      output);
    } else {
      NarrowDims2TransposeDispatch<T, tile_long, tile_short + 1, IDX_T>::
          DoTranspose(d,
                      tile_size_i,
                      tile_size_j,
                      total_tiles_count,
                      input,
                      input_dims,
                      output);
    }
  }
};

// Selected at compile time when CheckNonLongTileSize holds for
// (tile_long, tile_short, sizeof(T)). If the request fits, the kernel is
// launched. Otherwise only the short side is grown (tile_short + 1).
template <typename T, int tile_long, int tile_short, typename IDX_T>
struct NarrowDims2TransposeDispatch<
    T,
    tile_long,
    tile_short,
    IDX_T,
    typename std::enable_if<CheckNonLongTileSize(
                                tile_long, tile_short, sizeof(T)),
                            void>::type> {
  static void DoTranspose(const phi::GPUContext& d,
                          int tile_size_i,
                          int tile_size_j,
                          IDX_T total_tiles_count,
                          const T* input,
                          const Dim3& input_dims,
                          T* output) {
    PADDLE_ENFORCE_EQ(
        (tile_long & (tile_long - 1)),
        0,
        platform::errors::InvalidArgument(
            "The length of the longer side of the tile should be power of 2."
            " But received value is:%d.",
            tile_long));

    bool request_satisfied = std::max(tile_size_i, tile_size_j) <= tile_long &&
                             std::min(tile_size_i, tile_size_j) <= tile_short;

    if (request_satisfied) {
      LaunchNarrowDims2TransposeKernel<T, tile_long, tile_short, IDX_T>(
          d,
          tile_size_i,
          tile_size_j,
          total_tiles_count,
          input,
          input_dims,
          output);
      return;
    }

    NarrowDims2TransposeDispatch<T, tile_long, tile_short + 1, IDX_T>::
        DoTranspose(d,
                    tile_size_i,
                    tile_size_j,
                    total_tiles_count,
                    input,
                    input_dims,
                    output);
  }
};

// Selected at compile time when (tile_long, tile_short) is a reference shape
// for sizeof(T) (CheckLongTileSize). It launches unconditionally, which ends
// the dispatch recursion.
template <typename T, int tile_long, int tile_short, typename IDX_T>
struct NarrowDims2TransposeDispatch<
    T,
    tile_long,
    tile_short,
    IDX_T,
    typename std::enable_if<CheckLongTileSize(tile_long, tile_short, sizeof(T)),
                            void>::type> {
  static void DoTranspose(const phi::GPUContext& d,
                          int tile_size_i,
                          int tile_size_j,
                          IDX_T total_tiles_count,
                          const T* input,
                          const Dim3& input_dims,
                          T* output) {
    PADDLE_ENFORCE_EQ(
        (tile_long & (tile_long - 1)),
        0,
        platform::errors::InvalidArgument(
            "The length of the longer side of the tile should be power of 2,"
            " but received is:%d.",
            tile_long));

    LaunchNarrowDims2TransposeKernel<T, tile_long, tile_short, IDX_T>(
        d,
        tile_size_i,
        tile_size_j,
        total_tiles_count,
        input,
        input_dims,
        output);
  }
};

// Swaps dims 1 and 2 for "narrow" inputs, where one side is long and the
// other may be short (e.g. 2 x 100):
//   1. Pick a long x short tile from SelectProperTileSize.
//   2. Orient the tile against the input.
//   3. Dispatch through NarrowDims2TransposeDispatch on a built-in element
//      type of the same size as T.
// The `conjugate` template parameter is not used here.
template <typename T, bool conjugate = false, typename IDX_T = int>
void SwapDim1And2InNarrow(const phi::GPUContext& d,
                          const T* input,
                          const Dim3& input_dims,
                          T* output,
                          const int kMinTileSize) {
  // First get the available tile sizes for this element size as candidates.
  std::vector<std::pair<int, int>> tile_sele;
  auto ret = SelectProperTileSize<sizeof(T)>(&tile_sele);
  PADDLE_ENFORCE_EQ(
      ret,
      true,
      platform::errors::InvalidArgument(
          "SelectProperTileSize should return true, but return value is:%d.",
          ret));

  int tile_long_edge = 0;
  int tile_short_edge = 0;
  float lowest_cost = std::numeric_limits<float>::max();
  int input_long_edge = std::max(input_dims[1], input_dims[2]);

  // Find the tile size that best suits the input.
  for (auto tile_size_pair : tile_sele) {
    int proposed_tile_long_edge = tile_size_pair.first;
    // Data may not be aligned to the tile, so some threads are wasted. Look
    // for the tile that splits the input properly, i.e. the one where
    // num_wasted_threads == 0.
    int num_wasted_threads =
        input_long_edge - framework::CeilOrFloor<int, false>(
                              input_long_edge, proposed_tile_long_edge) *
                              proposed_tile_long_edge;

    float cost = num_wasted_threads;

    // `<=` lets a later candidate (larger long side) win ties.
    if (cost <= lowest_cost) {
      tile_long_edge = proposed_tile_long_edge;
      tile_short_edge = tile_size_pair.second;
      lowest_cost = cost;
    }
    // Break as we already found the best tile size.
    if (cost == 0) break;
  }

  // Match the selected tile to the input: long side to long, short side to
  // short. If dim 1 >= kMinTileSize, the long side goes to i (dim 1);
  // otherwise it goes to j (dim 2).
  int select_tile_size_i =
      input_dims[1] >= kMinTileSize ? tile_long_edge : input_dims[1];
  int select_tile_size_j =
      input_dims[1] >= kMinTileSize ? input_dims[2] : tile_long_edge;

  // If i is not the long edge, clamp it to the short edge.
  select_tile_size_i = select_tile_size_i == tile_long_edge
                           ? tile_long_edge
                           : std::min(select_tile_size_i, tile_short_edge);

  // If j is not the long edge, clamp it to the short edge.
  select_tile_size_j = select_tile_size_j == tile_long_edge
                           ? tile_long_edge
                           : std::min(select_tile_size_j, tile_short_edge);

  // Here we finally have the long x short tile size.
  Dim3 input_dims_aligned = {
      input_dims[0],
      framework::CeilOrFloor<int, true>(input_dims[1], select_tile_size_i),
      framework::CeilOrFloor<int, true>(input_dims[2], select_tile_size_j),
  };

  IDX_T total_tiles_count = input_dims_aligned[0];
  total_tiles_count *= input_dims_aligned[1];
  total_tiles_count *= input_dims_aligned[2];

  // Suppose T can be replaced by a system built-in type of the same size.
  using ElemType = typename SystemElemType<sizeof(T)>::type;

  NarrowDims2TransposeDispatch<ElemType, 32, 2, IDX_T>::DoTranspose(
      d,
      select_tile_size_i,
      select_tile_size_j,
      total_tiles_count,
      reinterpret_cast<const ElemType*>(input),
      input_dims,
      reinterpret_cast<ElemType*>(output));
}

// Element-wise fallback for cases that cannot do coalesced reads and writes,
// or inputs too small to split into tiles.
// Input dim k is placed at output position pos_k (output_dims[pos0] =
// input_dims[0], ...). The loop runs `output_index` over [0, nthreads) via
// CUDA_KERNEL_LOOP_TYPE; callers in this file pass the total element count.
template <typename T, int pos0, int pos1, int pos2, typename IDX_T = int>
__global__ void TransposeSimpleKernel(IDX_T nthreads,
                                      const T* __restrict__ input,
                                      Dim3 input_dims,
                                      T* __restrict__ output) {
  Dim3 output_dims;
  output_dims[pos0] = input_dims[0];
  output_dims[pos1] = input_dims[1];
  output_dims[pos2] = input_dims[2];

  CUDA_KERNEL_LOOP_TYPE(output_index, nthreads, IDX_T) {
    Index3 output_tensor_index =
        framework::ConvertTensorIndex<IDX_T>(output_index, output_dims);

    // Gather: the input coordinate along dim k is the output coordinate at
    // position pos_k.
    Index3 input_tensor_index;
    input_tensor_index[0] = output_tensor_index[pos0];
    input_tensor_index[1] = output_tensor_index[pos1];
    input_tensor_index[2] = output_tensor_index[pos2];

    IDX_T input_index =
        framework::FlatTensorIndex<IDX_T>(input_tensor_index, input_dims);

    output[output_index] = input[input_index];
  }
}

// Here we suppose all tensors are converted to 3-D, so we just swap dims 1
// and 2. The kernel is chosen by shape:
//   * dim 1 and dim 2 both >= 16        -> 32 x 32 shared-memory tiles;
//   * dim 1 or dim 2 >= 96 ("narrow")   -> SwapDim1And2InNarrow;
//   * otherwise (small)                 -> element-wise TransposeSimpleKernel.
template <typename T, typename IDX_T = int>
void SendSwapDim1And2InTranspose(const phi::GPUContext& d,
                                 const T* input,
                                 const Dim3& input_dims,
                                 T* output) {
  // Suppose tile size > 16
  static const int kMinTileSize = 16;
  static const int kMinNarrowTileSize = 96;

  // Nothing to copy for an empty tensor. Returning early also avoids a
  // zero-block launch (an invalid configuration) and, in the narrow path,
  // using a zero tile size as a divisor.
  if (input_dims[0] == 0 || input_dims[1] == 0 || input_dims[2] == 0) {
    return;
  }

  bool large_tile =
      input_dims[1] >= kMinTileSize && input_dims[2] >= kMinTileSize;
  bool narrow_tile = input_dims[1] >= kMinNarrowTileSize ||
                     input_dims[2] >= kMinNarrowTileSize;
  if (large_tile) {
    // If the input is a large square, such as 32 x 32, use SM to do the copy.
    // Suppose 32 x 32 gives the best performance, with 8 warps per block.
    constexpr int kTileSize = 32;
    constexpr int kNumThreads = 256;

    Dim3 input_dims_aligned = {
        input_dims[0],
        framework::CeilOrFloor<int, true>(input_dims[1], kTileSize),
        framework::CeilOrFloor<int, true>(input_dims[2], kTileSize),
    };

    IDX_T total_tiles_count = input_dims_aligned[0];
    total_tiles_count *= input_dims_aligned[1];
    total_tiles_count *= input_dims_aligned[2];

    TilingSwapDim1And2<T, kNumThreads, kTileSize, kTileSize, IDX_T>
        <<<total_tiles_count, kNumThreads, 0, d.stream()>>>(
            input, input_dims, output);

  } else if (narrow_tile) {
    // If the input shape is a rectangle, such as 2 x 100, use a narrow tile.
    // This is more complicated, because we need a tile that covers the input
    // and also achieves the best coalescing.
    SwapDim1And2InNarrow<T, false, IDX_T>(
        d, input, input_dims, output, kMinTileSize);
  } else {
    // If the input shape is small, such as 8 x 8, just do a simple copy.
    IDX_T total_elements = input_dims[0];
    total_elements *= input_dims[1];
    total_elements *= input_dims[2];
    auto config = phi::backends::gpu::GetGpuLaunchConfig1D(d, total_elements);
    TransposeSimpleKernel<T, 0, 2, 1, IDX_T>
        <<<config.block_per_grid.x, config.thread_per_block.x, 0, d.stream()>>>(
            total_elements, input, input_dims, output);
  }
}

// Functor that swaps dims 1 and 2 of a 3-D tensor: (d0, d1, d2) becomes
// (d0, d2, d1). Only combined_dims[0..2] are read.
template <typename T, typename IDX_T = int>
struct SwapDim1And2InTranspose {
  typedef phi::GPUContext Device;
  void operator()(const Device& d,
                  const T* in,
                  const std::vector<int>& combined_dims,
                  T* out) {
    Dim3 input_dims = {static_cast<int>(combined_dims[0]),
                       static_cast<int>(combined_dims[1]),
                       static_cast<int>(combined_dims[2])};
    SendSwapDim1And2InTranspose<T, IDX_T>(d, in, input_dims, out);
  }
};

// Functor that swaps dims 0 and 2 of a 3-D tensor: (d0, d1, d2) becomes
// (d2, d1, d0). It uses the element-wise TransposeSimpleKernel.
// Only combined_dims[0..2] are read.
template <typename T, typename IDX_T = int>
struct SwapDim0And2InTranspose {
  typedef phi::GPUContext Device;
  void operator()(const Device& d,
                  const T* in,
                  const std::vector<int>& combined_dims,
                  T* out) {
    Dim3 input_dims = {static_cast<int>(combined_dims[0]),
                       static_cast<int>(combined_dims[1]),
                       static_cast<int>(combined_dims[2])};

    IDX_T total_size = combined_dims[0];
    total_size *= combined_dims[1];
    total_size *= combined_dims[2];

    // Empty tensor: nothing to copy, and nothing worth launching.
    if (total_size == 0) return;

    auto config = phi::backends::gpu::GetGpuLaunchConfig1D(d, total_size);

    TransposeSimpleKernel<T, 2, 1, 0, IDX_T>
        <<<config.block_per_grid.x, config.thread_per_block.x, 0, d.stream()>>>(
            total_size, in, input_dims, out);
  }
};

// This function combines dimensions that stay adjacent and in order under
// `perm`, so the transpose can be expressed with fewer dimensions. For
// example:
// (0, 1, 3, 2) --> (0, 2, 1)
// On return, `*new_perm` is the reduced permutation and `*new_dims` holds the
// merged extents, listed in the input's dimension order.
inline void CombineTransposeDim3(const framework::DDim& shape,
                                 const std::vector<int>& perm,
                                 std::vector<int>* new_perm,
                                 framework::DDim* new_dims) {
  PADDLE_ENFORCE_EQ(shape.size(),
                    perm.size(),
                    platform::errors::InvalidArgument(
                        " shape should have the save dim with perm, but"
                        " received shape size is:%d, perm size is:%d.",
                        shape.size(),
                        perm.size()));

  const int rank = shape.size();
  if (rank == 0) {
    // 0-D tensor: nothing to permute or combine (and perm[0] does not exist).
    new_perm->clear();
    *new_dims = shape;
    return;
  }

  // Keep extents as int64_t: merged dimensions can exceed the int range.
  std::vector<int64_t> dim_vec;
  if (rank == 1) {
    // If the input dimension is already 1, there is no need to combine dims.
    new_perm->resize(1);
    (*new_perm)[0] = perm[0];
    dim_vec.push_back(shape[0]);
    *new_dims = phi::make_ddim(dim_vec);
    return;
  }

  // new_dim_pos[d] is the output position of the group headed by input dim d,
  // or -1 if d was merged into a preceding group.
  std::vector<int> new_dim_pos(rank, -1);
  std::vector<int64_t> combined_dims(rank, 0);
  int cur_head = perm[0];
  new_dim_pos[cur_head] = 0;
  combined_dims[0] = shape[cur_head];
  int dim_idx = 0;
  for (int perm_idx = 1; perm_idx < rank; ++perm_idx) {
    if (cur_head + 1 == perm[perm_idx]) {
      // Combine consecutive dimensions.
      cur_head = perm[perm_idx];
      combined_dims[dim_idx] *= shape[cur_head];
    } else {
      // Otherwise start a new dimension.
      cur_head = perm[perm_idx];
      dim_idx++;
      new_dim_pos[cur_head] = dim_idx;
      combined_dims[dim_idx] = shape[cur_head];
    }
  }

  new_perm->resize(dim_idx + 1);

  // Walk the input dims in order; every group head contributes one entry.
  dim_idx = 0;
  for (int i = 0; i < rank; ++i) {
    if (new_dim_pos[i] >= 0) {
      int new_perm_idx = new_dim_pos[i];
      (*new_perm)[dim_idx] = new_perm_idx;
      dim_vec.push_back(combined_dims[new_perm_idx]);
      dim_idx++;
    }
  }

  *new_dims = phi::make_ddim(dim_vec);
}

// Fast path for transposes that reduce to 2 or 3 dims after
// CombineTransposeDim3. Handles the reduced permutations (1, 0), (0, 2, 1)
// and (2, 1, 0). Returns true when a kernel was launched, and false when the
// caller must use another implementation.
template <typename T, typename IDX_T = int>
struct TransposeSimple {
  static bool run(const phi::GPUContext& ctx,
                  const Tensor& in,
                  const std::vector<int32_t> perm,
                  Tensor* out) {
    // First reduce the dimensions of the input tensor if possible.
    std::vector<int> new_perm;
    framework::DDim new_dims;
    CombineTransposeDim3(in.dims(), perm, &new_perm, &new_dims);

    // Only use the tile copy GPU kernel when the dimension is 2 or 3.
    int dims = new_dims.size();
    if (dims < 2 || dims > 3) return false;

    // The functors below take `int` extents. A combined dimension that does
    // not fit in int would be truncated, so leave such shapes to the caller.
    for (int i = 0; i < dims; ++i) {
      if (new_dims[i] > std::numeric_limits<int>::max()) return false;
    }

    std::vector<int> new_dim_vec = phi::vectorize<int>(new_dims);
    auto in_data = in.data<T>();
    auto out_data = out->data<T>();
    // In most cases, dim will not be greater than 3 after combining.
    switch (dims) {
      case 2:
        if (new_perm[0] == 1 && new_perm[1] == 0) {
          // Add the first dimension size as 1.
          new_dim_vec.insert(new_dim_vec.begin(), 1);
          SwapDim1And2InTranspose<T, IDX_T>()(
              ctx, in_data, new_dim_vec, out_data);
          return true;
        }
        break;
      case 3:
        // In this case, suppose we can do coalescing read and write in tile.
        if (new_perm == std::vector<int>({0, 2, 1})) {
          SwapDim1And2InTranspose<T, IDX_T>()(
              ctx, in_data, new_dim_vec, out_data);
          return true;
        } else if (new_perm == std::vector<int>({2, 1, 0})) {
          // Maybe this can be optimized later by finding a way to do
          // coalescing memory copy; whether that helps likely depends on the
          // data size.
          SwapDim0And2InTranspose<T, IDX_T>()(
              ctx, in_data, new_dim_vec, out_data);
          return true;
        }
        return false;
      default:
        return false;
    }
    return false;
  }
};

// Row-major stride table for an N-D shape with index type T.
template <int N, typename T>
class IdxHelper {
 public:
  IdxHelper() {}

  // Builds strides from the N extents in `dims`:
  // stride_[N - 1] = 1 and stride_[k] = dims[k + 1] * stride_[k + 1].
  explicit IdxHelper(const T* dims) {
    stride_[N - 1] = 1;
    for (int k = N - 2; k >= 0; --k) {
      stride_[k] = dims[k + 1] * stride_[k + 1];
    }
  }

  // Stride (in elements) of axis `idx`.
  __device__ inline T GetStride(int idx) const { return stride_[idx]; }

  // Splits a flat `offset` into N coordinates written to `index`.
  __device__ inline void GetIndexFromOffset(T offset, T* index) const {
    T rest = offset;
#pragma unroll
    for (int k = 0; k < N - 1; ++k) {
      const T coord = rest / stride_[k];
      index[k] = coord;
      rest -= coord * stride_[k];
    }
    index[N - 1] = rest;
  }

 private:
  T stride_[N];
};

// uint32_t specialization of IdxHelper: same stride table, plus one
// paddle::platform::FastDivMod per stride, which is used for the divisions in
// GetIndexFromOffset.
template <int N>
class IdxHelper<N, uint32_t> {
 public:
  IdxHelper() {}

  explicit IdxHelper(const uint32_t* dims) {
    for (int k = N - 1; k >= 0; --k) {
      const uint32_t stride =
          (k == N - 1) ? 1u : dims[k + 1] * stride_[k + 1];
      stride_[k] = stride;
      divmoder_[k] = paddle::platform::FastDivMod(stride);
    }
  }

  // Stride (in elements) of axis `idx`.
  __device__ inline uint32_t GetStride(int idx) const { return stride_[idx]; }

  // Splits a flat `offset` into N coordinates written to `index`.
  __device__ inline void GetIndexFromOffset(uint32_t offset,
                                            uint32_t* index) const {
    uint32_t rest = offset;
#pragma unroll
    for (int k = 0; k < N - 1; ++k) {
      const uint32_t coord = divmoder_[k].Div(rest);
      index[k] = coord;
      rest -= coord * stride_[k];
    }
    index[N - 1] = rest;
  }

 private:
  uint32_t stride_[N];
  paddle::platform::FastDivMod divmoder_[N];
};

// Converts between a flat row-major memory offset and an N-D coordinate,
// using the stride table held in an IdxHelper<N, T>.
template <typename T, int N>
class IdxAndOffsetHelper {
 public:
  IdxAndOffsetHelper() {}
  ~IdxAndOffsetHelper() = default;

  explicit IdxAndOffsetHelper(const T* dims) : index_helper(dims) {}

  // Accepts extents of another integer type; they are cast to T first.
  template <typename U>
  explicit IdxAndOffsetHelper(const U* dims) {
    T converted[N];
    for (int k = 0; k < N; ++k) converted[k] = static_cast<T>(dims[k]);
    index_helper = IdxHelper<N, T>(converted);
  }

  // Flat offset of the coordinate in `index` (N entries).
  __device__ inline T IndexToOffset(const T* index) const {
    T offset = 0;
#pragma unroll
    for (int k = 0; k < N - 1; ++k) {
      offset += index[k] * index_helper.GetStride(k);
    }
    return offset + index[N - 1];
  }

  // Coordinate (N entries, written to `index`) of a flat `offset`.
  __device__ inline void OffsetToIndex(T offset, T* index) const {
    index_helper.GetIndexFromOffset(offset, index);
  }

 private:
  IdxHelper<N, T> index_helper;
};

// Host-built description of a permutation, passed by value to the permute
// kernels: offset<->index helpers for the source shape `dims` and for the
// permuted destination shape, plus a copy of the permutation itself.
template <size_t Rank, typename IndexT>
struct PermuteParams {
 public:
  IdxAndOffsetHelper<IndexT, Rank> src_index_helper;
  IdxAndOffsetHelper<IndexT, Rank> dst_index_helper;
  int perm[Rank]{};

  // dst axis `axis` takes its extent from src axis perm_[axis]. Only the
  // first Rank entries of `dims` and `perm_` are read.
  explicit PermuteParams(const std::vector<size_t>& dims,
                         const std::vector<int>& perm_) {
    size_t dst_dims[Rank];
    for (size_t axis = 0; axis < Rank; ++axis) {
      const int src_axis = perm_[axis];
      perm[axis] = src_axis;
      dst_dims[axis] = dims[src_axis];
    }
    src_index_helper = IdxAndOffsetHelper<IndexT, Rank>(dims.data());
    dst_index_helper = IdxAndOffsetHelper<IndexT, Rank>(dst_dims);
  }
};

// A special kernel for the target case where both reads and writes are
// vectorized: every VecT-wide chunk of dst is copied from one VecT-wide chunk
// of src.
//
// `count` is the number of VecT chunks to write (the launcher passes
// count / VecSize). Offsets produced by `params` are used directly as VecT
// indices, so presumably its dims are already expressed in VecT units —
// TODO confirm against DimsSimplifier. Launched as a 1D grid; each thread
// walks the chunks with a grid-stride loop.
template <typename T, typename IndexT, int VecSize, int Rank>
__global__ void VectorizedPermuteKernel(PermuteParams<Rank, IndexT> params,
                                        const size_t count,
                                        const T* __restrict__ src_data,
                                        T* dst_data) {
  using VecT = phi::AlignedVector<T, VecSize>;
  IndexT src_index[Rank];
  IndexT dst_index[Rank];

  const VecT* __restrict__ src =
      reinterpret_cast<const VecT* __restrict__>(src_data);
  VecT* dst = reinterpret_cast<VecT*>(dst_data);

  // Promote to IndexT before multiplying. blockIdx.x * blockDim.x and
  // blockDim.x * gridDim.x are otherwise evaluated in 32-bit unsigned
  // arithmetic, which wraps for large grids and corrupts the start index and
  // the stride when IndexT is int64_t.
  const IndexT tid =
      static_cast<IndexT>(blockIdx.x) * blockDim.x + threadIdx.x;
  const IndexT stride = static_cast<IndexT>(blockDim.x) * gridDim.x;
  for (IndexT i = tid; i < count; i += stride) {
    params.dst_index_helper.OffsetToIndex(i, dst_index);

#pragma unroll
    for (int j = 0; j < Rank; ++j) {
      // dst axis j corresponds to src axis perm[j].
      src_index[params.perm[j]] = dst_index[j];
    }
    IndexT src_offset = params.src_index_helper.IndexToOffset(src_index);
    dst[i] = src[src_offset];
  }
}

// A general kernel for the normal case: only writes are vectorized.
//
// Each main-loop iteration gathers VecSize scalars one at a time from `src`
// (mapping dst coordinates to src coordinates through `params`) and stores
// them with a single VecT-wide write to `dst`. `main_cnt` is the number of
// VecT chunks written that way. The remaining `tail_cnt` scalars, starting
// at scalar index `offset`, are written individually by the threads whose
// global id is below `tail_cnt`.
//
// Launched as a 1D grid. `params.perm` is staged in shared memory by the
// first Rank threads of each block, so blockDim.x must be at least Rank.
template <typename T, typename IndexT, int VecSize, int Rank>
__global__ void GeneralPermuteKernel(PermuteParams<Rank, IndexT> params,
                                     const T* __restrict__ src,
                                     T* dst,
                                     const size_t main_cnt,
                                     const size_t tail_cnt,
                                     const size_t offset) {
  using VecT = phi::AlignedVector<T, VecSize>;
  VecT* vec_dst = reinterpret_cast<VecT*>(dst);

  IndexT src_index[VecSize][Rank];
  IndexT dst_index[VecSize][Rank];

  // Stage perm in shared memory so both load phases below read it from there
  // instead of from the kernel parameters.
  __shared__ int perm[Rank];
  if (threadIdx.x < Rank) {
    perm[threadIdx.x] = params.perm[threadIdx.x];
  }
  __syncthreads();

  // Promote to IndexT before multiplying. blockIdx.x * blockDim.x and
  // blockDim.x * gridDim.x are otherwise evaluated in 32-bit unsigned
  // arithmetic, which wraps for large grids when IndexT is int64_t.
  const IndexT tid =
      static_cast<IndexT>(blockIdx.x) * blockDim.x + threadIdx.x;
  const IndexT stride = static_cast<IndexT>(blockDim.x) * gridDim.x;

  // Vectorized store: gather VecSize scalars, then write them as one VecT.
  for (IndexT idx = tid; idx < main_cnt; idx += stride) {
    VecT vec_data;
    IndexT vec_idx = idx * VecSize;

#pragma unroll
    for (int i = 0; i < VecSize; ++i) {
      params.dst_index_helper.OffsetToIndex(vec_idx + i, dst_index[i]);

#pragma unroll
      for (int j = 0; j < Rank; ++j) {
        src_index[i][perm[j]] = dst_index[i][j];
      }
      IndexT src_offset = params.src_index_helper.IndexToOffset(src_index[i]);
      vec_data[i] = src[src_offset];
    }
    vec_dst[idx] = vec_data;
  }

  // Scalar store for the leftover elements that do not fill a whole VecT.
  if (tid < tail_cnt) {
    IndexT idx = tid + offset;
    params.dst_index_helper.OffsetToIndex(idx, dst_index[0]);

#pragma unroll
    for (int j = 0; j < Rank; ++j) {
      src_index[0][perm[j]] = dst_index[0][j];
    }
    IndexT src_offset = params.src_index_helper.IndexToOffset(src_index[0]);
    dst[idx] = src[src_offset];
  }
}

// A general permute method that directly finds, for each dst element, the
// coordinate of the matching element in the source data.
//
// `count` is the total number of T elements. For kNormalPermute only the
// writes are vectorized (GeneralPermuteKernel, which also handles the
// count % VecSize leftover elements); for other permute types both reads and
// writes are vectorized (VectorizedPermuteKernel over count / VecSize
// chunks). The kernel is enqueued on ctx.stream().
template <typename T, typename IndexT, int VecSize, int Rank>
inline void LaunchPermuteKernel(const phi::GPUContext& ctx,
                                const IndexT count,
                                const PermuteType perm_type,
                                const std::vector<size_t>& dims,
                                const std::vector<int>& perm,
                                const T* src,
                                T* dst) {
  size_t main_count = count / VecSize;
  auto params = PermuteParams<Rank, IndexT>(dims, perm);
  auto config = phi::backends::gpu::GetGpuLaunchConfig1D(ctx, main_count);

  if (perm_type == PermuteType::kNormalPermute) {
    // Elements that do not fill a whole VecT are written one by one.
    size_t tail_count = count - main_count * VecSize;
    size_t offset = count - tail_count;
    GeneralPermuteKernel<T, IndexT, VecSize, Rank>
        <<<config.GetGridSize(), config.GetBlockSize(), 0, ctx.stream()>>>(
            params, src, dst, main_count, tail_count, offset);
  } else {
    VectorizedPermuteKernel<T, IndexT, VecSize, Rank>
        <<<config.GetGridSize(), config.GetBlockSize(), 0, ctx.stream()>>>(
            params, main_count, src, dst);
  }
}

// Maps the runtime rank dims.size() onto the compile-time Rank parameter of
// LaunchPermuteKernel. Ranks 1 through 9 are supported; any other rank
// raises an Unimplemented error instead of silently launching nothing.
template <typename T, typename IndexT, int VecSize>
inline void LaunchPermuteRankDispatch(const phi::GPUContext& ctx,
                                      const IndexT count,
                                      const PermuteType perm_type,
                                      const std::vector<size_t>& dims,
                                      const std::vector<int>& perm,
                                      const T* src,
                                      T* dst) {
#define CALL_DISPATCH_RANK(rank)                      \
  case rank: {                                        \
    LaunchPermuteKernel<T, IndexT, VecSize, rank>(    \
        ctx, count, perm_type, dims, perm, src, dst); \
    break;                                            \
  }

  switch (dims.size()) {
    CALL_DISPATCH_RANK(1);
    CALL_DISPATCH_RANK(2);
    CALL_DISPATCH_RANK(3);
    CALL_DISPATCH_RANK(4);
    CALL_DISPATCH_RANK(5);
    CALL_DISPATCH_RANK(6);
    CALL_DISPATCH_RANK(7);
    CALL_DISPATCH_RANK(8);
    CALL_DISPATCH_RANK(9);
    default: {
      // Without this branch an unsupported rank would leave dst unwritten.
      PADDLE_THROW(phi::errors::Unimplemented(
          "Unsupported permute rank: %d !", static_cast<int>(dims.size())));
      break;
    }
  }
#undef CALL_DISPATCH_RANK
}

// Aim at transposing the last 2 dimensions. Refer from
// https://developer.nvidia.com/blog/efficient-matrix-transpose-cuda-cc/
//
// Transposes a batch of [rows, cols] matrices whose elements are read as VecT
// chunks (cols counts VecT chunks, so each source row holds cols * VecSize
// scalars). Each output matrix is laid out as [cols * VecSize, rows].
//
// Launch layout (as set up by LaunchTransposeKernel): blockIdx.x tiles the
// columns, blockIdx.y tiles the rows, blockIdx.z selects the batch, and
// blockDim is (kTileSize, kBlockRows). Loads are staged through a static
// kTileSize x kShareCol shared-memory tile of VecT.
template <typename T, typename IndexT, int VecSize>
__global__ void BatchTransposeKernel(const T* __restrict__ src_data,
                                     T* dst_data,
                                     IndexT rows,
                                     IndexT cols) {
  using VecT = phi::AlignedVector<T, VecSize>;

  __shared__ VecT tile[kTileSize][kShareCol];
  T* single_tile = reinterpret_cast<T*>(tile);

  IndexT col_in_matrix = blockIdx.x * kTileSize + threadIdx.x;
  IndexT offset = blockIdx.z * rows * cols;

  // Vectorized load data from src into shared memory. [rows, cols]
  const VecT* __restrict__ src =
      reinterpret_cast<const VecT* __restrict__>(src_data);

  for (IndexT tile_y = threadIdx.y; tile_y < kTileSize; tile_y += kBlockRows) {
    IndexT row_in_matrix = tile_y + blockIdx.y * kTileSize;

    if (col_in_matrix < cols && row_in_matrix < rows) {
      tile[tile_y][threadIdx.x] =
          src[offset + row_in_matrix * cols + col_in_matrix];
    }
  }

  // Singularized store from shared memory into dst,
  // where dst_cols = rows and dst_rows = cols: [cols * VecSize, rows].
  col_in_matrix = blockIdx.y * kTileSize + threadIdx.x;
  offset = offset * VecSize + col_in_matrix;
  IndexT tile_x_idx = threadIdx.x * (kShareCol * VecSize);

  // Every thread reaches this barrier; the loads above are guarded, not the
  // barrier.
  __syncthreads();

  for (IndexT tile_y = threadIdx.y; tile_y < kTileSize; tile_y += kBlockRows) {
    IndexT row_in_matrix = tile_y + blockIdx.x * kTileSize;
    IndexT dst_idx = offset + row_in_matrix * VecSize * rows;
    IndexT tile_idx = tile_x_idx + tile_y * VecSize;
    if (col_in_matrix < /*dst_cols=*/rows &&
        row_in_matrix < /*dst_rows=*/cols) {
      // Scatter the VecSize scalars of one VecT into VecSize consecutive
      // dst rows.
#pragma unroll
      for (auto i = 0; i < VecSize; ++i) {
        dst_data[dst_idx + i * rows] = single_tile[tile_idx + i];
      }
    }
  }
}

// Because of the byte limitation of shared memory, vectorization depends on
// element size: any T wider than 8 bytes always uses VecSize = 1, while a T
// of at most 8 bytes uses VecSize = Size.
//
// A rank-2 `dims` is treated as a single [rows, cols] matrix; otherwise
// dims[0] is taken as the batch count and the last two entries as
// [rows, cols]. Presumably dims has rank 2 or 3 here and cols is already
// expressed in VecT chunks (BatchTransposeKernel indexes src in VecT units)
// — TODO confirm against DimsSimplifier.
//
// NOTE(review): blocks.y (row tiles) and blocks.z (batches) are subject to
// CUDA's per-dimension grid limits, which this launcher does not check.
template <typename T,
          typename IndexT,
          int Size,
          int VecSize = (sizeof(T) > 8 ? 1 : Size)>
inline void LaunchTransposeKernel(const phi::GPUContext& ctx,
                                  const std::vector<size_t>& dims,
                                  const T* src,
                                  T* dst) {
  auto rank = dims.size();
  IndexT num_batches = (rank == 2) ? 1 : dims[0];
  IndexT rows = dims[rank - 2];
  IndexT cols = dims[rank - 1];
  IndexT num_tile_rows = (rows + kTileSize - 1) / kTileSize;
  IndexT num_tile_cols = (cols + kTileSize - 1) / kTileSize;

  dim3 blocks(num_tile_cols, num_tile_rows, num_batches);
  dim3 threads(kTileSize, kBlockRows, 1);

  BatchTransposeKernel<T, IndexT, VecSize>
      <<<blocks, threads, 0, ctx.stream()>>>(src, dst, rows, cols);
}

// Maps the runtime vec_size (1, 2 or 4) onto a compile-time VecSize.
// kTranspose goes to the batched tiled transpose; every other permute type
// goes through the rank dispatcher. Any other vec_size raises Unimplemented.
template <typename T, typename IndexT>
inline void LaunchWithDispatchVecSize(const phi::GPUContext& ctx,
                                      const int vec_size,
                                      const PermuteType perm_type,
                                      const std::vector<size_t>& dims,
                                      const std::vector<int>& perm,
                                      const T* src,
                                      T* dst,
                                      IndexT count) {
#define CALL_DISPATCH_VEC_SIZE(vec_size)                               \
  case vec_size: {                                                     \
    if (perm_type == PermuteType::kTranspose) {                        \
      LaunchTransposeKernel<T, IndexT, vec_size>(ctx, dims, src, dst); \
    } else {                                                           \
      LaunchPermuteRankDispatch<T, IndexT, vec_size>(                  \
          ctx, count, perm_type, dims, perm, src, dst);                \
    }                                                                  \
    break;                                                             \
  }

  switch (vec_size) {
    CALL_DISPATCH_VEC_SIZE(1);
    CALL_DISPATCH_VEC_SIZE(2);
    CALL_DISPATCH_VEC_SIZE(4);
    default: {
      PADDLE_THROW(phi::errors::Unimplemented(
          "Unsupported vectorized size: %d !", vec_size));
      break;
    }
  }
#undef CALL_DISPATCH_VEC_SIZE
}

// Chooses the index type for kernel arithmetic. uint32_t is used when
// `count` is below UINT32_MAX; otherwise int64_t is used. The choice is then
// forwarded to the vec-size dispatcher.
template <typename T>
inline void LaunchWithDispatchIndex(const phi::GPUContext& ctx,
                                    const size_t count,
                                    const int vec_size,
                                    const PermuteType perm_type,
                                    const std::vector<size_t>& dims,
                                    const std::vector<int>& perm,
                                    const T* src,
                                    T* dst) {
  if (count < std::numeric_limits<uint32_t>::max()) {
    LaunchWithDispatchVecSize<T, uint32_t>(ctx,
                                           vec_size,
                                           perm_type,
                                           dims,
                                           perm,
                                           src,
                                           dst,
                                           static_cast<uint32_t>(count));
  } else {
    LaunchWithDispatchVecSize<T, int64_t>(ctx,
                                          vec_size,
                                          perm_type,
                                          dims,
                                          perm,
                                          src,
                                          dst,
                                          static_cast<int64_t>(count));
  }
}

// Simplifies (dims, perm) with DimsSimplifier, then executes the result.
// If the simplified permutation is kCopy, `in` is copied to `out` with
// phi::Copy. Otherwise the simplified count, vec size, permute type, dims
// and perm are handed to LaunchWithDispatchIndex.
template <typename DeviceContext, typename T>
inline void SimplifyThenLaunch(const int rank,
                               const DeviceContext& ctx,
                               const Tensor& in,
                               Tensor* out,
                               const std::vector<int32_t>& perm) {
  int sm_count = ctx.GetSMCount();
  auto src_dims = phi::vectorize<size_t>(in.dims());
  auto simplifier = DimsSimplifier<T>(
      sm_count, rank, perm, src_dims, in.data<T>(), out->data<T>());

  if (simplifier.GetPermType() == PermuteType::kCopy) {
    // If perm is [0,1,2,3], then just operate a DtoD copy.
    phi::Copy(ctx, in, ctx.GetPlace(), false, out);
  } else {
    LaunchWithDispatchIndex<T>(ctx,
                               simplifier.GetCount(),
                               simplifier.GetVecSize(),
                               simplifier.GetPermType(),
                               simplifier.GetDims(),
                               simplifier.GetPerm(),
                               in.data<T>(),
                               out->data<T>());
  }
}

// GPU entry point: writes `in` permuted by `perm` into `out`.
//
// First tries TransposeSimple, using 64-bit indices when in.numel() >=
// INT32_MAX. If TransposeSimple::run returns false (presumably meaning the
// case is not one of its fast paths — confirm), it falls back to an autotune
// tuner. The tuner is built around TransCompute, has SimplifyThenLaunch
// registered as an additional callback, and is keyed on the input dims,
// perm and data type.
template <typename T>
void TransposeGPUKernelDriver(const phi::GPUContext& ctx,
                              const Tensor& in,
                              const std::vector<int32_t>& perm,
                              Tensor* out) {
  const int rank = perm.size();
  int64_t numel = in.numel();
  bool ret{false};
  if (numel >= INT32_MAX) {
    ret = TransposeSimple<T, int64_t>::run(ctx, in, perm, out);
  } else {
    ret = TransposeSimple<T>::run(ctx, in, perm, out);
  }

  if (!ret) {
    auto* tuner =
        phi::autotune::MakeTransposeTuner<T>(TransCompute<phi::GPUContext, T>);
    tuner->AddCallBack(
        phi::autotune::MakeCallback<T>(SimplifyThenLaunch<phi::GPUContext, T>));

    size_t key = phi::autotune::TransposeKey(
        phi::vectorize(in.dims()),
        perm,
        paddle::experimental::CppTypeToDataType<T>::Type());

    tuner->Run(ctx,
               phi::autotune::AlgorithmType::kTranspose,
               key,
               rank,
               ctx,
               in,
               out,
               perm);
  }
}

}  // namespace operators
}  // namespace paddle