datamover_primitives_xpu2.h 41.1 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include "xpu/kernel/cluster_header.h"
#include "xpu/kernel/debug.h"
#include "xpu/kernel/math.h"

20
namespace phi {
21
namespace kps {
22 23
namespace details {

24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56
// Shortcut used to map output indices to input indices once the two operand
// shapes have been compressed. Examples are written as "x op y"; either
// operand may be the broadcast one.
enum class OptType {
  CanNotOptimize = -1,  // no shortcut: use the generic per-dimension mapping
  N_1 = 0,              // e.g. {1} op {100}, or {100} op {1}
  MN_N = 1,             // e.g. {100} op {3, 100}, or {3, 100} op {100}
  MN_M = 2,             // e.g. {3} op {3, 100}, or {3, 100} op {3}
  MNK_1N1 = 3,          // e.g. {3} op {2, 3, 100}, or {2, 3, 100} op {3}
  MNK_M1K = 4,  // e.g. {2, 1, 100} op {2, 3, 100}, or {2, 3, 100} op {2, 1, 100}
};

// Classifies how a single dimension pair broadcasts (a taken from x, b from y).
// Adjacent dimensions may be merged only when they fall in the same class:
//    0 - a == b
//    1 - a == 1 && b != 1   (x is broadcast along this dimension)
//    2 - a != 1 && b == 1   (y is broadcast along this dimension)
//   -1 - anything else (sizes differ and neither is 1)
static int judge_case(int a, int b) {
  if (a == b) {
    return 0;
  }
  // From here on a != b, so a == 1 already implies b != 1 and vice versa.
  if (a == 1) {
    return 1;
  }
  if (b == 1) {
    return 2;
  }
  return -1;
}

// True when two judge_case() classifications match, i.e. the corresponding
// adjacent dimensions can be merged.
static bool case_is_same(int case_front, int case_back) {
  return case_front == case_back;
}

57 58 59 60 61 62 63 64 65 66
// Fixed-size array of VecSize elements of T whose alignment equals its total
// size in bytes (sizeof(T) * VecSize).
template <typename T, int VecSize>
struct alignas(sizeof(T) * VecSize) VectorType {
  T val[VecSize];  // element storage, accessed directly by users
};

/**
 * Configuration of broadcast. Calculate the input data index according to the
 * index of the output data. if input or output shape is [dim0, dim1] then dims
 * must be [dim1, dim0].
 */
67
#pragma pack(4)
68
struct BroadcastConfig {
69 70 71
  int strides_in[phi::DDim::kMaxRank];
  int strides_out[phi::DDim::kMaxRank];
  int in_dim[phi::DDim::kMaxRank];
72
  int dim_after_cmp[phi::DDim::kMaxRank];
73
  int y_dim_after_cmp[phi::DDim::kMaxRank];
74 75 76 77 78 79 80
  int dim_size_after_cmp = 0;
  int cmp_res = 0;
  OptType cmp_type = OptType::CanNotOptimize;
  int m = 1;
  int n = 1;
  int k = 1;
  int buf_len = 0;
81
  int kDims;
82 83 84 85
  HOSTDEVICE BroadcastConfig() {}

  HOSTDEVICE BroadcastConfig(const std::vector<int64_t>& out_dims,
                             const std::vector<int64_t>& in_dims,
86
                             const std::vector<int64_t>& y_in_dims,
87
                             int dim_size) {
88 89 90 91 92 93 94 95 96
    std::vector<int> strides_in_tmp;
    std::vector<int> strides_out_tmp;
    std::vector<int> dim_tmp;
    strides_in_tmp.resize(dim_size, 1);
    strides_out_tmp.resize(dim_size, 1);
    dim_tmp.resize(dim_size, 1);
    for (int i = 1; i < dim_size; i++) {
      strides_in_tmp[i] = strides_in_tmp[i - 1] * in_dims[i - 1];
      strides_out_tmp[i] = strides_out_tmp[i - 1] * out_dims[i - 1];
97 98
    }

99 100
    for (int i = 0; i < dim_size; i++) {
      dim_tmp[i] = in_dims[i];
101
    }
102
    kDims = dim_size;
103 104 105
    memcpy(strides_in, strides_in_tmp.data(), kDims * sizeof(int));
    memcpy(strides_out, strides_out_tmp.data(), kDims * sizeof(int));
    memcpy(in_dim, dim_tmp.data(), kDims * sizeof(int));
106

107 108
    cmp_res = get_mnk_for_broadcast_ops(in_dims, y_in_dims);
    get_opt_type();
109 110 111 112 113 114 115 116 117 118 119 120 121
    buf_len = get_buf_len();
  }

  int get_buf_len() {
    if (cmp_type == OptType::CanNotOptimize) {
      return 256;
    }
    int max_buf_len = 512;
    int buf_len = m / 16 * 16;
    if (buf_len == 0) {
      buf_len = m;
    }
    return std::min(max_buf_len, buf_len);
122 123 124 125
  }

  __device__ inline int operator()(int index_output) const {
    int index_src = 0;
126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153

    switch (cmp_type) {
      int div, mod, tmp_index;
      case OptType::MNK_M1K:
        div = index_output / (m * n);
        mod = index_output % (m * n) % m;
        index_src = div * m + mod;
        break;
      case OptType::MNK_1N1:
        // index_src = index_output / m % n;
        index_src = index_output % (m * n) / m;
        break;
      case OptType::N_1:
        index_src = 0;
        break;
      case OptType::MN_N:
        index_src = index_output / m;
        break;
      case OptType::MN_M:
        index_src = index_output % m;
        break;
      case OptType::CanNotOptimize:
        for (int i = kDims - 1; i >= 0; --i) {
          tmp_index = (index_output / strides_out[i]);
          index_output = index_output - tmp_index * strides_out[i];
          index_src += (tmp_index % in_dim[i]) * strides_in[i];
        }
        break;
154 155
    }
    return index_src;
156
  }
157

158
  void get_opt_type() {
159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244
    if (dim_size_after_cmp == 1) {
      if (dim_after_cmp[0] == 1 && y_dim_after_cmp[0] != 1) {  // {1} op {n}
        n = y_dim_after_cmp[0];
        cmp_type = OptType::N_1;
      } else if (dim_after_cmp[0] != 1 &&
                 y_dim_after_cmp[0] == 1) {  // {n} op {1}
        n = dim_after_cmp[0];
        cmp_type = OptType::N_1;
      } else {
        cmp_type = OptType::CanNotOptimize;  // xshape == yshape
      }
    }
    if (dim_size_after_cmp == 2) {
      if (dim_after_cmp[0] == 1 && dim_after_cmp[1] != 1 &&
          y_dim_after_cmp[0] != 1 &&
          y_dim_after_cmp[1] != 1) {  // {n} op {m, n}
        m = y_dim_after_cmp[0];
        n = y_dim_after_cmp[1];
        cmp_type = OptType::MN_N;
      } else if (dim_after_cmp[0] != 1 && dim_after_cmp[1] == 1 &&
                 y_dim_after_cmp[0] != 1 &&
                 y_dim_after_cmp[1] != 1) {  // {m} op {m, n}
        m = y_dim_after_cmp[0];
        n = y_dim_after_cmp[1];
        cmp_type = OptType::MN_M;
      } else if (dim_after_cmp[0] != 1 && dim_after_cmp[1] != 1 &&
                 y_dim_after_cmp[0] == 1 &&
                 y_dim_after_cmp[1] != 1) {  // {m, n} op {n}
        m = dim_after_cmp[0];
        n = dim_after_cmp[1];
        cmp_type = OptType::MN_N;
      } else if (dim_after_cmp[0] != 1 && dim_after_cmp[1] != 1 &&
                 y_dim_after_cmp[0] != 1 &&
                 y_dim_after_cmp[1] == 1) {  // {m, n} op {m}
        m = dim_after_cmp[0];
        n = dim_after_cmp[1];
        cmp_type = OptType::MN_M;
      } else {
        cmp_type = OptType::CanNotOptimize;
      }
    }
    if (dim_size_after_cmp == 3) {
      if (dim_after_cmp[0] == 1 && dim_after_cmp[1] != 1 &&
          dim_after_cmp[2] == 1 && y_dim_after_cmp[0] != 1 &&
          y_dim_after_cmp[1] != 1 &&
          y_dim_after_cmp[2] != 1) {  // {1, n, 1} op {m, n, k}
        m = y_dim_after_cmp[0];
        n = y_dim_after_cmp[1];
        k = y_dim_after_cmp[2];
        cmp_type = OptType::MNK_1N1;
      } else if (dim_after_cmp[0] != 1 && dim_after_cmp[1] != 1 &&
                 dim_after_cmp[2] != 1 && y_dim_after_cmp[0] == 1 &&
                 y_dim_after_cmp[1] != 1 &&
                 y_dim_after_cmp[2] == 1) {  // {m, n, k} op {1, n, 1}
        m = dim_after_cmp[0];
        n = dim_after_cmp[1];
        k = dim_after_cmp[2];
        cmp_type = OptType::MNK_1N1;
      } else if (dim_after_cmp[0] != 1 && dim_after_cmp[1] == 1 &&
                 dim_after_cmp[2] != 1 && y_dim_after_cmp[0] != 1 &&
                 y_dim_after_cmp[1] != 1 &&
                 y_dim_after_cmp[2] != 1) {  // {m, 1, k} op {m, n, k}
        m = y_dim_after_cmp[0];
        n = y_dim_after_cmp[1];
        k = y_dim_after_cmp[2];
        cmp_type = OptType::MNK_M1K;
      } else if (dim_after_cmp[0] != 1 && dim_after_cmp[1] != 1 &&
                 dim_after_cmp[2] != 1 && y_dim_after_cmp[0] != 1 &&
                 y_dim_after_cmp[1] == 1 &&
                 y_dim_after_cmp[2] != 1) {  // {m, n, k} op {m, 1, k}
        m = dim_after_cmp[0];
        n = dim_after_cmp[1];
        k = dim_after_cmp[2];
        cmp_type = OptType::MNK_M1K;
      } else {
        cmp_type = OptType::CanNotOptimize;
      }
    }
  }

  int get_mnk_for_broadcast_ops(const std::vector<int64_t>& xshape,
                                const std::vector<int64_t>& yshape) {
    int idx = 0;
    int cmp_x = 0;
    int cmp_y = 0;
    bool is_same = false;
245

246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279
    std::vector<int64_t> xshape_after_remove_ones = xshape;
    std::vector<int64_t> yshape_after_remove_ones = yshape;
    // first step: remove excess ones
    std::vector<int64_t>::iterator x_iter = xshape_after_remove_ones.begin();
    std::vector<int64_t>::iterator y_iter = yshape_after_remove_ones.begin();
    for (; x_iter != xshape_after_remove_ones.end();) {
      if (*x_iter == 1 && *y_iter == 1) {
        x_iter = xshape_after_remove_ones.erase(x_iter);
        y_iter = yshape_after_remove_ones.erase(y_iter);
      } else {
        x_iter++;
        y_iter++;
      }
    }
    // second step: compress dims
    int after_cmp_idx = 0;
    for (int i = 0; i < 3; i++) {
      cmp_x = xshape_after_remove_ones[idx];
      cmp_y = yshape_after_remove_ones[idx];
      while ((idx + 1) < xshape_after_remove_ones.size()) {
        is_same = case_is_same(judge_case(xshape_after_remove_ones[idx],
                                          yshape_after_remove_ones[idx]),
                               judge_case(xshape_after_remove_ones[idx + 1],
                                          yshape_after_remove_ones[idx + 1]));
        if (is_same) {
          cmp_x = cmp_x * xshape_after_remove_ones[idx + 1];
          cmp_y = cmp_y * yshape_after_remove_ones[idx + 1];
          idx++;
        } else {
          break;
        }
      }
      idx = idx + 1;
      dim_after_cmp[after_cmp_idx] = cmp_x;
280
      y_dim_after_cmp[after_cmp_idx] = cmp_y;
281 282 283 284 285 286 287 288
      after_cmp_idx++;
      if (idx == xshape_after_remove_ones.size()) {
        dim_size_after_cmp = after_cmp_idx;
        return 0;
      }
    }
    return -1;  // can not compress dims
  }
289
};
290
#pragma pack()
291

292
template <typename T>
293
__device__ __forceinline__ void WriteData(T _global_ptr_* dst,
294 295 296 297 298 299 300 301
                                          T* src,
                                          int num) {
  if (num > 0) {
    LM2GM(src, dst, num * sizeof(T));
  }
}
#undef INT_BITS

302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329
}  // namespace details

/**
 * @brief Read 2D data from global memory to register according to Tx type, and
 * store it as Ty type into register.
 *
 * @template paraments
 * Tx: The type of data stored in the global memory.
 * Ty: The type of data that needs to be stored in registers.
 * NX: The number of data columns loaded by each thread.
 * NY: The number of data rows loaded by each thread.
 * BlockSize: Identifies the current device thread index method. For xpu,
 * core_id() is used as the index.
 * IsBoundary: Indicates whether to perform block access storage out-of-bounds
 * judgment. When the number of data processed by the block is less than
 * NX x NY x core_num(), boundary judgment is required to avoid memory access
 * crossing the boundary.
 *
 * @param:
 * dst: The register pointer of the thread, the size is NX * NY.
 * src: The data pointer of the current block.
 * size_nx: The maximum offset of the current block is size_nx elements in the
 * lowest dimension. The parameters are only calculated when isboundary = true.
 * size_ny: The maximum offset of the current block is size_ny elements in the
 * first dimension. The parameters are only calculated when isboundary = true.
 * stride_nx: Each read one element stride stride_nx elements in the last dim.
 * stride_ny: Each read one element stride stride_ny elements in the first dim.
 */
330 331 332 333 334
template <typename Tx,
          typename Ty,
          int NX,
          int NY,
          int BlockSize,
335
          bool IsBoundary = false>
336 337 338 339 340
__device__ __inline__ void ReadData(Ty* dst,
                                    const Tx _global_ptr_* src,
                                    int size_nx,
                                    int size_ny,
                                    int stride_nx,
341
                                    int stride_ny) {
342 343
  int thread_offset = core_id();
  int left_size_nx = size_nx - thread_offset;
344
  __local__ Tx in_temp[1];
345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407
  // Each branch is added for better performance
  if (NX == 1 && NY == 1) {  // for NX == 1 and NY == 1
    if (IsBoundary) {
      if (left_size_nx > 0) {
        GM2LM(src + thread_offset, in_temp, sizeof(Tx));
        dst[0] = static_cast<Ty>(in_temp[0]);
      }
    } else {
      GM2LM(src + thread_offset, in_temp, sizeof(Tx));
      dst[0] = static_cast<Ty>(in_temp[0]);
    }
  } else if (NX == 1) {  // for NX == 1 and NY != 1
#pragma unroll
    for (int idy = 0; idy < NY; ++idy) {
      if (IsBoundary) {
        if (idy * stride_ny >= size_ny) {
          break;
        }
      }
      GM2LM(src + thread_offset + idy * stride_ny, in_temp, sizeof(Tx));
      dst[idy] = static_cast<Ty>(in_temp[0]);
    }
  } else if (NY == 1) {  // for NY == 1 and NX != 1
#pragma unroll
    for (int idx = 0; idx < NX; ++idx) {
      if (IsBoundary) {
        if (idx * stride_nx >= left_size_nx) {
          break;
        }
      }
      GM2LM(src + thread_offset + idx * stride_nx, in_temp, sizeof(Tx));
      dst[idx] = static_cast<Ty>(in_temp[0]);
    }
  } else {  // for NX != 1 and NY != 1
#pragma unroll
    for (int idx = 0; idx < NX; ++idx) {
#pragma unroll
      for (int idy = 0; idy < NY; ++idy) {
        if (IsBoundary) {
          if (idy * stride_ny >= size_ny || idx * stride_nx >= left_size_nx) {
            break;
          }
        }
        int fix = thread_offset + idx * stride_nx + idy * stride_ny;
        GM2LM(src + fix, in_temp, sizeof(Tx));
        dst[idy * NX + idx] = static_cast<Ty>(in_temp[0]);
      }
    }
  }
}

/**
 * @brief Initialize register with init_data.
 *
 * @template paraments
 * T: Data type of register.
 * NX: Number of data to initialize.
 *
 * @param:
 * dst: The register pointer of the thread, the size is NX.
 * init_data: Initial value.
 */
template <typename T, int NX>
408
__device__ __inline__ void Init(T* dst, T init_data) {
409 410 411 412 413 414
#pragma unroll
  for (int i = 0; i < NX; i++) {
    dst[i] = init_data;
  }
}

415 416 417 418 419 420 421 422
/**
 * @brief Fill the first read_lens elements of a register buffer with one
 * value. The element count is the runtime read_lens; NX is not referenced.
 */
template <typename T, int NX>
__device__ __inline__ void Init(T* dst, T init_data, int read_lens) {
#pragma unroll
  for (int slot = 0; slot < read_lens; ++slot) {
    dst[slot] = init_data;
  }
}

423 424 425 426 427 428 429 430 431 432 433 434
/**
 * @brief Initialize field Index of each of the NX elements of dst with
 * init_data. Unlike the single-type Init, the destination elements are
 * tuple-like (std::get<Index> must be valid on ArgsT), so inputs of different
 * data types can share one buffer.
 */
template <typename T, typename ArgsT, int Index, int NX>
__device__ __forceinline__ void Init(ArgsT* dst, T init_data) {
#pragma unroll
  for (int slot = 0; slot < NX; ++slot) {
    auto& field = std::get<Index>(dst[slot]);
    field = init_data;
  }
}

435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455
/**
 * @brief Read 1D data from global memory to register. When IsBoundary = true
 * and (NX % 4 == 0 or Nx % 2 == 0), vectorized load data will be used to
 * improve memory access efficiency.
 *
 * @template paraments
 * T: The type of data.
 * NX: Each thread load NX data from global memory continuously.
 * NY: Each thread need to load NY rows, only NY = 1 was supported.
 * BlockSize: Identifies the current device thread index method. For xpu,
 * core_id() is used as the index.
 * IsBoundary: Whether to make an out-of-bounds judgment on access to memory.
 * When the number of data processed by this block is less than
 * NX x NY x core_num(), boundary judgment is required to avoid memory access
 * crossing the boundary.
 *
 * @param:
 * dst: The register pointer of the thread, the size is NX * NY.
 * src: The data pointer of the current block.
 * size: The current block needs to load size data continuously.
 */
456
template <typename T, int NX, int NY, int BlockSize, bool IsBoundary>
457 458
__device__ __inline__ void ReadData(T* dst,
                                    const T _global_ptr_* src,
459
                                    int num) {
460 461 462 463 464 465 466 467 468 469 470 471 472 473 474
  int thread_offset = core_id() * NX;
  __local__ T in_temp[1];
  if (IsBoundary) {  // core_num() * NX > num
#pragma unroll
    for (int idx = 0; idx < NX; ++idx) {
      if (idx + thread_offset < num) {
        GM2LM(src + thread_offset + idx, in_temp, sizeof(T));
        dst[idx] = in_temp[0];
      }
    }
  } else {  // core_num() * NX < num
    GM2LM(src + thread_offset, dst, NX * sizeof(T));
  }
}

475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494
/**
 * @brief Read 1D data from global memory into registers, read_lens contiguous
 * elements per core (runtime count; NX and NY are not referenced). With
 * IsBoundary, offsets at or beyond num are skipped; otherwise a single GM2LM
 * transfer copies the whole slice. dst must hold read_lens elements.
 */
template <typename T, int NX, int NY, int BlockSize, bool IsBoundary>
__device__ __inline__ void ReadData(T* dst,
                                    const T _global_ptr_* src,
                                    int num,
                                    int read_lens) {
  const int base = core_id() * read_lens;
  __local__ T staging[1];
  if (!IsBoundary) {
    // Caller guarantees core_num() * read_lens <= num: one bulk transfer.
    GM2LM(src + base, dst, read_lens * sizeof(T));
    return;
  }
  // core_num() * read_lens > num: element-wise copy that skips the tail.
#pragma unroll
  for (int i = 0; i < read_lens; ++i) {
    if (base + i < num) {
      GM2LM(src + base + i, staging, sizeof(T));
      dst[i] = staging[0];
    }
  }
}

495 496 497
/**
 * @brief Read 1D data from global memory to register. The difference
 * from the above function is that it supports different data types of inputs.
498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515
 *
 * @template paraments
 * T: The type of data.
 * NX: Each thread load NX data from global memory continuously.
 * NY: Each thread need to load NY rows, only NY = 1 was supported.
 * ArgsT: The Type if dst, ArgsT can be std::tuple<T> or std::tuple<Args>
 * Index: The index of data stored in dst.
 * BlockSize: Identifies the current device thread index method. For xpu,
 * core_id() is used as the index.
 * IsBoundary: Whether to make an out-of-bounds judgment on access to memory.
 * When the number of data processed by this block is less than
 * NX x NY x blockDim.x, boundary judgment is required to avoid memory access
 * crossing the boundary.
 *
 * @param:
 * dst: The register pointer of the thread, the size is NX * NY.
 * src: The data pointer of the current block.
 * size: The current block needs to load size data continuously.
516 517 518 519 520 521 522
 */
template <typename T,
          int NX,
          int NY,
          int BlockSize,
          typename ArgsT,
          int Index,
523
          bool IsBoundary>
524
__device__ __forceinline__ void ReadData(ArgsT* dst,
525
                                         const T _global_ptr_* src,
526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546
                                         int num) {
  int thread_offset = core_id() * NX;
  __local__ T in_temp[1];
  __local__ T in_vec[NX];
  if (IsBoundary) {  // core_num() * NX > num
#pragma unroll
    for (int idx = 0; idx < NX; ++idx) {
      if (idx + thread_offset < num) {
        GM2LM(src + thread_offset + idx, in_temp, sizeof(T));
        std::get<Index>(dst[idx]) = in_temp[0];
      }
    }
  } else {  // core_num() * NX < num
    GM2LM(src + thread_offset, in_vec, NX * sizeof(T));
#pragma unroll
    for (int idx = 0; idx < NX; ++idx) {
      std::get<Index>(dst[idx]) = in_vec[idx];
    }
  }
}

547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570
/**
 * @brief Read 2D data from global memory to registers with broadcast form.
 *
 * @template paraments
 * T: The type of data stored in the global memory.
 * NX: The number of data columns loaded by each thread.
 * NY: The number of data rows loaded by each thread.
 * BlockSize: Identifies the current device thread index method. For xpu,
 * core_id() is used as the index.
 * IsBoundary: Indicates whether to perform block access storage out-of-bounds
 * judgment. When the number of data processed by the block is less than
 * NX x NY x core_num(), boundary judgment is required to avoid memory access
 * crossing the boundary.
 *
 * @param:
 * dst: The register pointer of the thread, the size is NX * NY.
 * src: Raw input data pointer of kernel.
 * block_offset: Data offset of this block, core_num() *  cluster_id() * NX;
 * config: Calculation configuration of broadcast. It is used to calculate the
 * coordinate mapping relationship between output data and input data.
 * total_num_output: Total number of original output.
 * stride_nx: Each read one element stride stride_nx elements in the last dim.
 * stride_ny: Each read one element stride stride_ny elements in the first dim.
 */
571
template <typename T, int NX, int NY, int BlockSize, bool IsBoundary = false>
572 573
__device__ __inline__ void ReadDataBc(T* dst,
                                      const T _global_ptr_* src,
574
                                      uint32_t block_offset,
575
                                      const details::BroadcastConfig& config,
576 577
                                      int total_num_output,
                                      int stride_nx,
578
                                      int stride_ny) {
579 580 581 582 583 584 585 586 587 588 589
  uint32_t thread_offset = block_offset + core_id();
  uint32_t index_src = 0;
  __local__ T in_temp[1];

#pragma unroll
  for (int ny = 0; ny < NY; ++ny) {
#pragma unroll
    for (uint32_t nx = 0; nx < NX; ++nx) {
      uint32_t index_output = thread_offset + ny * stride_ny + nx * stride_nx;
      index_src = 0;
      if (IsBoundary) {
590
        if (index_output >= (uint32_t)total_num_output) {
591 592 593
          break;
        }
      }
594
      index_src = config(index_output);
595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631
      GM2LM(src + index_src, in_temp, sizeof(T));
      dst[nx + ny * NX] = in_temp[0];
    }
  }
}

/**
 * @brief Read 2D data from global memory to register with reduce form.
 *
 * @template paraments
 * T: The type of data.
 * NX: The number of data columns loaded by each thread.
 * NY: The number of data rows loaded by each thread.
 * BlockSize: Identifies the current device thread index method. For xpu,
 * core_id() is used as the index.
 * Rank: The shape size of out. eg in[1, 35], out[32, 35] then shape size is 2.
 * IsBoundary: Indicates whether to perform block access storage out-of-bounds
 * judgment. When the number of data processed by the block is less than
 * NX x NY x core_num(), boundary judgment is required to avoid memory access
 * crossing the boundary.
 *
 * @param:
 * dst: The register pointer of the thread, the size is NX * NY.
 * src: The input data pointer of this block.
 * block_offset: The data offset of this block, blockDim.x * cluster_id() * NX.
 * index_cal: Calculation configuration of Reduce. It is used to calculate the
 * coordinate mapping relationship between output data and input data.
 * size_nx: The current block needs to load size_nx columns of data, this
 * parameter will participate in the calculation when isboundary = true.
 * size_ny: The current block needs to load size_ny rows of data, this parameter
 * will participate in the calculation when isboundary = true.
 * will be used when IsBoundary = true.
 * stride_nx: Each read one element stride stride_nx columns.
 * stride_ny: Each read one element stride stride_ny raws.
 * reduce_last_dim: Used to indicate whether the dimension of reduce contains
 * the lowest dimension.
 */
632 633
template <typename Tx,
          typename Ty,
634 635 636 637 638
          int NX,
          int NY,
          int BlockSize,
          int Rank,
          typename IndexCal,
639
          typename Functor,
640
          bool IsBoundary = false>
641 642 643 644 645 646 647 648 649 650 651
__device__ __forceinline__ void ReadDataReduce(
    Ty* dst,
    const Tx _global_ptr_* __restrict__ src,
    int block_offset,
    const IndexCal& index_cal,
    int size_nx,
    int size_ny,
    int stride_nx,
    int stride_ny,
    Functor func,
    bool reduce_last_dim) {
652
  __local__ Tx in_temp[1];
653
  int thread_offset = 0;
654
  int left_idx = 0;
655
  if (reduce_last_dim) {
656 657
    thread_offset = core_id();
    left_idx = 0;
658
  } else {
659 660
    thread_offset = 0;
    left_idx = 0;
661 662 663 664 665 666
  }

  if (NX == 1) {
#pragma unroll
    for (int ny = 0; ny < NY; ++ny) {
      if (IsBoundary) {
667
        if (thread_offset >= size_ny) {
668 669 670
          break;
        }
      }
671 672 673
      uint32_t index_src = index_cal(thread_offset + block_offset);
      GM2LM(src + index_src, in_temp, sizeof(Tx));
      dst[ny] = static_cast<Ty>(func(in_temp[0]));
674 675 676 677 678 679 680 681
      thread_offset += stride_ny;
    }
  } else {
#pragma unroll
    for (int nx = 0; nx < NX; ++nx) {
#pragma unroll
      for (int ny = 0; ny < NY; ++ny) {
        if (IsBoundary) {
682 683
          if ((thread_offset >= size_ny) ||
              (left_idx + nx * stride_nx >= size_nx)) {
684 685 686
            break;
          }
        }
687 688 689
        uint32_t index_src = index_cal(thread_offset + block_offset);
        GM2LM(src + index_src, in_temp, sizeof(Tx));
        dst[nx + ny * NX] = static_cast<Ty>(func(in_temp[0]));
690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716
        thread_offset += stride_ny;
      }
    }
  }
}
/**
 * @brief Write 1D data from registers to global memory. When IsBoundary = true
 * and (NX % 4 == 0 or Nx % 2 == 0), the data will be vectorized to improve the
 * data loading efficiency
 *
 * @template paraments
 * T: The type of data.
 * NX: The number of data continuously writed by each thread.
 * NY: The number of data rows loaded by each thread, only NY = 1 was supported.
 * BlockSize: Identifies the current device thread index method. For xpu,
 * core_id() is used as the index.
 * IsBoundary: Indicates whether to perform block access storage out-of-bounds
 * judgment. When the number of data processed by the block is less than
 * NX x NY x core_num(), boundary judgment is required to avoid memory access
 * crossing the boundary.
 *
 * @param:
 * dst: The data pointer of the current block.
 * src: The register pointer, the size is NX * NY.
 * size: The current block needs to load size elements continuously.
 */

717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737
template <typename T, int NX, int NY, int BlockSize, bool IsBoundary>
__device__ void WriteData(T _global_ptr_* dst,
                          const T* src,
                          int num,
                          int read_lens) {
  int thread_offset = core_id() * read_lens;
  __local__ T in_temp[1];

  if (IsBoundary) {  // core_num() * read_lens > num
#pragma unroll
    for (int idx = 0; idx < read_lens; ++idx) {
      if (idx + thread_offset < num) {
        in_temp[0] = src[idx];
        LM2GM(in_temp, dst + idx + thread_offset, sizeof(T));
      }
    }
  } else {  // core_num() * read_lens < num
    LM2GM(src, dst + thread_offset, read_lens * sizeof(T));
  }
}

738 739 740 741
/**
 * @brief Write 1D data from registers to global memory, NX contiguous elements
 * per core. When IsBoundary is false, the slice is written with a single LM2GM
 * transfer. Otherwise each element is written individually, and offsets at or
 * beyond num are skipped.
 */
template <typename T, int NX, int NY, int BlockSize, bool IsBoundary>
__device__ void WriteData(T _global_ptr_* dst, const T* src, int num) {
  const int base = core_id() * NX;
  __local__ T staging[1];

  if (!IsBoundary) {
    // Caller guarantees core_num() * NX <= num: one bulk transfer.
    LM2GM(src, dst + base, NX * sizeof(T));
    return;
  }
  // core_num() * NX > num: element-wise copy that skips the tail.
#pragma unroll
  for (int i = 0; i < NX; ++i) {
    if (base + i < num) {
      staging[0] = src[i];
      LM2GM(staging, dst + base + i, sizeof(T));
    }
  }
}

/**
 * @brief Write 2D data from register to global memory according to Tx type, and
 * store it as Ty type.
 *
 * @template paraments
 * Tx: The type of data that needs to be stored in registers.
 * Ty: The type of data stored in the global memory.
 * NX: The number of data columns loaded by each thread.
 * NY: The number of data rows loaded by each thread.
 * BlockSize: Identifies the current device thread index method. For xpu,
 * core_id() is used as the index.
 * IsBoundary: Indicates whether to perform block access storage out-of-bounds
 * judgment. When the number of data processed by the block is less than
 * NX x NY x core_num(), boundary judgment is required to avoid memory access
 * crossing the boundary.
 *
 * @param:
 * dst: Data pointer of the current block.
 * src: The register pointer of the thread, the size is NX * NY.
 * size_nx: The current block needs to load size_nx columns of data, this
 * parameter will be used when IsBoundary = true.
 * size_ny: The current block needs to load size_ny rows of data. This parameter
 * will be used when IsBoundary = true.
 * stride_nx: Each read one element stride stride_nx elements in the last dim.
 * stride_ny: Each read one element stride stride_ny elements in the first dim.
 */
782 783 784 785 786
template <typename Tx,
          typename Ty,
          int NX,
          int NY,
          int BlockSize,
787
          bool IsBoundary = false>
788 789 790 791 792
__device__ __inline__ void WriteData(Ty _global_ptr_* dst,
                                     const Tx* src,
                                     int size_nx,
                                     int size_ny,
                                     int stride_nx,
793
                                     int stride_ny) {
794 795 796 797 798 799 800 801 802
  int thread_offset = core_id();
  int left_size_nx = size_nx - thread_offset;
  __local__ Ty in_temp[1];

  // Each branch is added for better performance
  if (NX == 1 && NY == 1) {
    if (IsBoundary) {
      if (left_size_nx > 0) {
        in_temp[0] = static_cast<Ty>(src[0]);
803
        LM2GM(in_temp, dst + thread_offset, sizeof(Ty));
804 805 806
      }
    } else {
      in_temp[0] = static_cast<Ty>(src[0]);
807
      LM2GM(in_temp, dst + thread_offset, sizeof(Ty));
808 809 810 811 812 813 814 815 816 817 818
    }
  } else if (NX == 1) {
#pragma unroll
    for (int idy = 0; idy < NY; ++idy) {
      if (IsBoundary) {
        if (idy * stride_ny >= size_ny) {
          break;
        }
      }

      in_temp[0] = static_cast<Ty>(src[idy]);
819
      LM2GM(in_temp, dst + thread_offset + idy * stride_ny, sizeof(Ty));
820 821 822 823 824 825 826 827 828 829 830
    }
  } else if (NY == 1) {  // for NY == 1 and NX != 1
#pragma unroll
    for (int idx = 0; idx < NX; ++idx) {
      if (IsBoundary) {
        if (idx * stride_nx >= left_size_nx) {
          break;
        }
      }

      in_temp[0] = static_cast<Ty>(src[idx]);
831
      LM2GM(in_temp, dst + thread_offset + idx * stride_nx, sizeof(Ty));
832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848
    }
  } else {  // for NX != 1 and NY != 1
#pragma unroll
    for (int idx = 0; idx < NX; ++idx) {
      if (IsBoundary) {
        if (idx * stride_nx >= left_size_nx) {
          break;
        }
      }
#pragma unroll
      for (int idy = 0; idy < NY; ++idy) {
        if (IsBoundary) {
          if (idy * stride_ny >= size_ny) {
            break;
          }
        }
        in_temp[0] = static_cast<Ty>(src[idx + idy * NX]);
849 850
        LM2GM(in_temp,
              dst + thread_offset + idx * stride_nx + idy * stride_ny,
851
              sizeof(Ty));
852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868
      }
    }
  }
}

/**
 * @brief Initialize register with init_data.
 *
 * @template paraments
 * T: Data type of register.
 * NX: Number of data to initialize.
 *
 * @param:
 * dst: The register pointer of the thread, the size is NX.
 * init_data: The register pointer of init data, the size is NX.
 */
template <typename T, int NX, bool IsBoundary = false>
869
__device__ __inline__ void Init(T* dst, T* init_data, int num) {
870 871 872 873 874 875 876 877 878 879 880
#pragma unroll
  for (int i = 0; i < NX; i++) {
    if (IsBoundary) {
      if (i >= num) {
        break;
      }
    }
    dst[i] = init_data[i];
  }
}

881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896
/**
 * @brief Read data from global memory to local memory with broadcast
 * {m, 1, k}-> {m, n, k} form.
 *
 * @template paraments
 * T: Data type of register.
 * Rank: The shape size of out. eg in[1, 35], out[32, 35] then shape size is 2.
 *
 * @param:
 * dst: The register pointer of the thread, the size is NX.
 * src: The original input data pointer of kernel.
 * thread_offset: The data offset of this thread.
 * config: Calculation configuration of broadcast. It is used to calculate the
 * coordinate mapping relationship between output data and input data.
 * read_lens: The number of data continuously loaded by each thread.
 */
897
template <typename T>
898 899 900 901
__device__ __inline__ void ReadDataBcM1kMnk(
    T* dst,
    const T _global_ptr_* src,
    int thread_offset,
902
    const details::BroadcastConfig& config,
903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943
    int read_lens) {
  int index_output = thread_offset;
  int index_base = config(index_output);
  int m = config.m;
  int n = config.n;

  int m_pos = index_base % m;
  if ((m - m_pos) < read_lens) {
    int last_col = m - m_pos;
    GM2LM(src + index_base, dst, last_col * sizeof(T));
    int n_pos = index_output % (m * n) / m;
    int next_part_index = 0;
    if (n_pos != config.n - 1) {
      next_part_index = index_base / m * m;
    } else {
      next_part_index = (index_base / m + 1) * m;
    }
    GM2LM(src + next_part_index,
          dst + last_col,
          (read_lens - last_col) * sizeof(T));
  } else {
    GM2LM(src + index_base, dst, read_lens * sizeof(T));
  }
}

/**
 * @brief Read data from global memory to local memory with broadcast
 * {m, 1}-> {m, n} form.
 *
 * @template paraments
 * T: Data type of register.
 * Rank: The shape size of out. eg in[1, 35], out[32, 35] then shape size is 2.
 *
 * @param:
 * dst: The register pointer of the thread, the size is NX.
 * src: The original input data pointer of kernel.
 * thread_offset: The data offset of this thread.
 * config: Calculation configuration of broadcast. It is used to calculate the
 * coordinate mapping relationship between output data and input data.
 * read_lens: The number of data continuously loaded by each thread.
 */
944
template <typename T>
945 946 947 948
__device__ __inline__ void ReadDataBcM1Mn(
    T* dst,
    const T _global_ptr_* src,
    int thread_offset,
949
    const details::BroadcastConfig& config,
950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980
    int read_lens) {
  int index_output = thread_offset;
  int index_base = config(index_output);
  int m = config.m;
  int n = config.n;

  int m_pos = index_base % m;
  if ((m - m_pos) < read_lens) {
    int last_col = m - m_pos;
    GM2LM(src + index_base, dst, last_col * sizeof(T));
    GM2LM(src, dst + last_col, (read_lens - last_col) * sizeof(T));
  } else {
    GM2LM(src + index_base, dst, read_lens * sizeof(T));
  }
}

/**
 * @brief Read data from global memory to local memory with broadcast
 * {1, n}-> {m, n} form.
 *
 * @template paraments
 * T: Data type of register.
 *
 * @param:
 * dst: The register pointer of the thread, the size is NX.
 * src: The original input data pointer of kernel.
 * thread_offset: The data offset of this thread.
 * config: Calculation configuration of broadcast. It is used to calculate the
 * coordinate mapping relationship between output data and input data.
 * read_lens: The number of data continuously loaded by each thread.
 */
981
template <typename T>
982 983 984 985
__device__ __inline__ void ReadDataBc1NMn(
    T* dst,
    const T _global_ptr_* src,
    int thread_offset,
986
    const details::BroadcastConfig& config,
987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 1022 1023 1024 1025 1026 1027
    int read_lens) {
  int index_output = thread_offset;
  int index_base = config(index_output);
  int m = config.m;
  int n = config.n;
  T in_temp;

  int m_pos = index_output % m;
  if ((m - m_pos) < read_lens) {
    int last_col = m - m_pos;
    GM2LM(src + index_base, &in_temp, sizeof(T));
    for (int i = 0; i < last_col; i++) {
      dst[i] = in_temp;
    }
    GM2LM(src + index_base + 1, &in_temp, sizeof(T));
    for (int i = 0; i < read_lens - last_col; i++) {
      dst[last_col + i] = in_temp;
    }
  } else {
    GM2LM(src + index_base, &in_temp, sizeof(T));
    for (int i = 0; i < read_lens; i++) {
      dst[i] = in_temp;
    }
  }
}

/**
 * @brief Read data from global memory to local memory with broadcast
 * {1, n, 1}-> {m, n, k} form.
 *
 * @template paraments
 * T: Data type of register.
 *
 * @param:
 * dst: The register pointer of the thread, the size is NX.
 * src: The original input data pointer of kernel.
 * thread_offset: The data offset of this thread.
 * config: Calculation configuration of broadcast. It is used to calculate the
 * coordinate mapping relationship between output data and input data.
 * read_lens: The number of data continuously loaded by each thread.
 */
1028
template <typename T>
1029 1030 1031 1032
__device__ __inline__ void ReadDataBc1N1Mnk(
    T* dst,
    const T _global_ptr_* src,
    int thread_offset,
1033
    const details::BroadcastConfig& config,
1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081
    int read_lens) {
  int index_output = thread_offset;
  int index_base = config(index_output);
  int m = config.m;
  int n = config.n;
  T in_temp;

  int m_pos = index_output % m;
  if ((m - m_pos) < read_lens) {
    int last_col = m - m_pos;
    GM2LM(src + index_base, &in_temp, sizeof(T));
    for (int i = 0; i < last_col; i++) {
      dst[i] = in_temp;
    }
    int n_pos = index_output % (m * n) / m;
    int next_part_index = 0;
    if (n_pos != n - 1) {
      next_part_index = n_pos + 1;
    } else {
      next_part_index = 0;
    }
    GM2LM(src + next_part_index, &in_temp, sizeof(T));
    for (int i = 0; i < read_lens - last_col; i++) {
      dst[last_col + i] = in_temp;
    }
  } else {
    GM2LM(src + index_base, &in_temp, sizeof(T));
    for (int i = 0; i < read_lens; i++) {
      dst[i] = in_temp;
    }
  }
}

/**
 * @brief Read data from global memory to local memory with broadcast
 * {1}-> {n} form.
 *
 * @template paraments
 * T: Data type of register.
 *
 * @param:
 * dst: The register pointer of the thread, the size is NX.
 * src: The original input data pointer of kernel.
 * thread_offset: The data offset of this thread.
 * config: Calculation configuration of broadcast. It is used to calculate the
 * coordinate mapping relationship between output data and input data.
 * read_lens: The number of data continuously loaded by each thread.
 */
1082 1083 1084 1085 1086 1087
template <typename T>
__device__ __inline__ void ReadDataBc1N(T* dst,
                                        const T _global_ptr_* src,
                                        int thread_offset,
                                        const details::BroadcastConfig& config,
                                        int read_lens) {
1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110 1111 1112 1113 1114
  int index_output = thread_offset;
  int index_base = config(index_output);
  T in_temp;

  GM2LM(src + index_base, &in_temp, sizeof(T));
  for (int i = 0; i < read_lens; i++) {
    dst[i] = in_temp;
  }
}

/**
 * @brief Read data from global memory to local memory with broadcast
 * form which can not compress.
 *
 * @template paraments
 * T: Data type of register.
 * Rank: The shape size of out. eg in[1, 35], out[32, 35] then shape size is 2.
 *
 * @param:
 * dst: The register pointer of the thread, the size is NX.
 * src: The original input data pointer of kernel.
 * thread_offset: The data offset of this thread.
 * config: Calculation configuration of broadcast. It is used to calculate the
 * coordinate mapping relationship between output data and input data.
 * total_num_output: Total number of original output.
 * read_lens: The number of data continuously loaded by each thread.
 */
1115
template <typename T, bool IsBoundary = false>
1116 1117 1118 1119
__device__ __inline__ void ReadDataBcCanNotCmp(
    T* dst,
    const T _global_ptr_* src,
    int thread_offset,
1120
    const details::BroadcastConfig& config,
1121 1122 1123 1124 1125 1126 1127 1128 1129 1130 1131 1132 1133 1134 1135 1136 1137 1138 1139 1140 1141 1142 1143 1144 1145 1146 1147 1148 1149 1150 1151 1152 1153 1154 1155 1156 1157 1158 1159 1160 1161 1162 1163 1164 1165 1166 1167 1168 1169
    int total_num_output,
    int read_lens) {
  int index_output = thread_offset;
  int index_base = config(index_output);
  T in_temp;
  int cache_size = 256;
  __local__ T src_temp[cache_size];
  GM2LM(src + index_base, src_temp, cache_size * sizeof(T));

  for (int nx = 0; nx < read_lens; ++nx) {
    index_output = thread_offset + nx;
    if (IsBoundary) {
      if (index_output >= total_num_output) {
        break;
      }
    }
    int index_src = config(index_output);
    if (index_src >= index_base && index_src < index_base + cache_size) {
      in_temp = src_temp[index_src - index_base];
    } else {
      GM2LM(src + index_src, &in_temp, sizeof(T));
    }
    dst[nx] = in_temp;
  }
}

/**
 * @brief Read 1D data from global memory to register with broadcast form.
 *
 * @template paraments
 * T: The type of data stored in the global memory.
 * NX: The number of data continuously loaded by each thread.
 * NY: The number of data rows loaded by each thread, only NY = 1 was supported.
 * BlockSize: Identifies the current device thread index method. For xpu,
 * core_id() is used as the index.
 * IsBoundary: Indicates whether to perform block access storage out-of-bounds
 * judgment. When the number of data processed by the block is less than
 * NX x NY x core_num(), boundary judgment is required to avoid memory access
 * crossing the boundary.
 *
 * @param:
 * dst: The register pointer of the thread, the size is NX * NY.
 * src: The original input data pointer of kernel.
 * block_offset: The data offset of this block, core_num() * blockIdx.x * NX;
 * config: Calculation configuration of broadcast. It is used to calculate the
 * coordinate mapping relationship between output data and input data.
 * read_lens: The number of data continuously loaded by each thread.
 * total_num_output: Total number of original output.
 */
1170 1171 1172 1173 1174 1175 1176
template <typename T, int NX, int NY, int BlockSize, bool IsBoundary = false>
__device__ __inline__ void ReadDataBc(T* dst,
                                      const T _global_ptr_* src,
                                      uint32_t block_offset,
                                      const details::BroadcastConfig& config,
                                      int total_num_output,
                                      int read_lens) {
1177 1178 1179
  int thread_offset = block_offset + core_id() * read_lens;

  if (config.cmp_type == details::OptType::MNK_M1K) {
1180
    ReadDataBcM1kMnk<T>(dst, src, thread_offset, config, read_lens);
1181
  } else if (config.cmp_type == details::OptType::N_1) {
1182
    ReadDataBc1N<T>(dst, src, thread_offset, config, read_lens);
1183
  } else if (config.cmp_type == details::OptType::MN_M) {
1184
    ReadDataBcM1Mn<T>(dst, src, thread_offset, config, read_lens);
1185
  } else if (config.cmp_type == details::OptType::MN_N) {
1186
    ReadDataBc1NMn<T>(dst, src, thread_offset, config, read_lens);
1187
  } else if (config.cmp_type == details::OptType::MNK_1N1) {
1188
    ReadDataBc1N1Mnk<T>(dst, src, thread_offset, config, read_lens);
1189
  } else {
1190
    ReadDataBcCanNotCmp<T, IsBoundary>(
1191 1192 1193 1194
        dst, src, thread_offset, config, total_num_output, read_lens);
  }
}

1195 1196 1197 1198 1199 1200 1201 1202 1203 1204 1205 1206 1207 1208 1209 1210 1211 1212 1213 1214 1215 1216 1217
/**
 * @brief Initialize register with data index.
 *
 * @template paraments
 * T: Data type of register.
 * NX: Number of data to initialize.
 * NY: Number of data to initialize, NY only can be 1.
 * BlockSize: Identifies the current device thread index method. For xpu,
 * core_id() is used as the index.
 *
 * @param:
 * dst: The register pointer of the thread, the size is NX.
 * init_data: The register pointer of init data, the size is NX.
 */
template <typename T, int NX, int NY, int BlockSize>
__device__ __forceinline__ void InitWithDataIndex(T* dst, int block_offset) {
  int thread_offset = block_offset + core_id() * NX;
#pragma unroll
  for (int nx = 0; nx < NX; ++nx) {
    dst[nx] = static_cast<T>(thread_offset + nx);
  }
}

1218
}  // namespace kps
}  // namespace phi