From 942a85b9f4476e12716552078bda6a90039b42a8 Mon Sep 17 00:00:00 2001
From: niuliling123 <51102941+niuliling123@users.noreply.github.com>
Date: Thu, 17 Mar 2022 17:34:26 +0800
Subject: [PATCH] Modified masked_select_kernel and where_index with kernel
 primitive api (#40517)

---
 paddle/phi/kernels/funcs/select_impl.cu.h      | 447 ++++++++++++++++++
 .../phi/kernels/gpu/masked_select_kernel.cu    |  74 +--
 paddle/phi/kernels/gpu/where_index_kernel.cu   | 161 ++-----
 .../kernels/primitive/compute_primitives.h     |   3 +-
 4 files changed, 500 insertions(+), 185 deletions(-)
 create mode 100644 paddle/phi/kernels/funcs/select_impl.cu.h

diff --git a/paddle/phi/kernels/funcs/select_impl.cu.h b/paddle/phi/kernels/funcs/select_impl.cu.h
new file mode 100644
index 00000000000..3a1d9b8ea7a
--- /dev/null
+++ b/paddle/phi/kernels/funcs/select_impl.cu.h
@@ -0,0 +1,447 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+// CUDA and HIP use the same api
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+#ifdef __NVCC__
+#include "cub/cub.cuh"
+#endif
+#ifdef __HIPCC__
+#include <hipcub/hipcub.hpp>
+namespace cub = hipcub;
+#endif
+
+#include <algorithm>
+#include "paddle/fluid/memory/malloc.h"
+#include "paddle/fluid/memory/memcpy.h"
+#include "paddle/phi/backends/gpu/gpu_launch_config.h"
+#include "paddle/phi/core/ddim.h"
+#include "paddle/phi/kernels/empty_kernel.h"
+#include "paddle/phi/kernels/primitive/kernel_primitives.h"
+
+namespace kps = phi::kps;
+
+namespace phi {
+namespace funcs {
+using Mode = kps::details::ReduceMode;
+
+/*
+* Count how many of the data being processed by the current block are true
+* 1. Load data from global memory and cast from bool to int64_t
+* 2. Get the result of this thread according to thread reduce
+* 3. Get the result of this block according to block reduce
+* 4. The first block stores 0 followed by its own count
+*/
+template <typename T>
+struct NonZeroFunctor {
+  HOSTDEVICE NonZeroFunctor() {}
+  HOSTDEVICE inline T operator()(const T in) {
+    if (in) {
+      return static_cast<T>(1);
+    } else {
+      return static_cast<T>(0);
+    }
+  }
+};
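// ---------------------------------------------------------------------------
// Illustrative sketch, not part of the file above: the kernels in this file
// implement a three-phase stream compaction -- count per block, prefix-sum
// the counts, then scatter. A serial CPU reference of the where_index
// flavour (ReferenceSelect is a hypothetical name, not a Paddle API) makes
// the data flow easy to check against the device code:

#include <algorithm>
#include <cstdint>
#include <vector>

std::vector<int64_t> ReferenceSelect(const std::vector<bool> &mask,
                                     int64_t block_size) {
  const int64_t numel = static_cast<int64_t>(mask.size());
  const int64_t num_blocks = (numel + block_size - 1) / block_size;
  // Phase 1: per-block counts, with a leading 0 so the inclusive scan in
  // phase 2 directly yields each block's store offset.
  std::vector<int64_t> counts(num_blocks + 1, 0);
  for (int64_t i = 0; i < numel; ++i) {
    if (mask[i]) ++counts[i / block_size + 1];
  }
  // Phase 2: inclusive prefix sum; the last element is the total true count.
  std::vector<int64_t> cumsum(counts);
  for (size_t i = 1; i < cumsum.size(); ++i) cumsum[i] += cumsum[i - 1];
  // Phase 3: each block scatters its true indices starting at its offset.
  std::vector<int64_t> out(cumsum[num_blocks]);
  for (int64_t b = 0; b < num_blocks; ++b) {
    int64_t store = cumsum[b];
    const int64_t end = std::min(numel, (b + 1) * block_size);
    for (int64_t i = b * block_size; i < end; ++i) {
      if (mask[i]) out[store++] = i;
    }
  }
  return out;
}
// ---------------------------------------------------------------------------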
+
+template <typename InT, typename OutT, int VecSize, bool IsBoundary>
+__device__ void GetBlockCountImpl(const InT *in,
+                                  OutT *out,
+                                  int num,
+                                  int repeat) {
+  InT in_data[VecSize];
+  OutT temp[VecSize];
+  OutT result = static_cast<OutT>(0.0f);
+  using Add = kps::AddFunctor<OutT>;
+  using Cast = NonZeroFunctor<InT>;
+  int store_fix = BLOCK_ID_X + repeat * GRID_NUM_X;
+
+  kps::Init<InT, VecSize>(&in_data[0], static_cast<InT>(0.0f));
+  kps::ReadData<InT, VecSize, 1, 1, IsBoundary>(&in_data[0], in, num);
+  kps::ElementwiseUnary<InT, OutT, VecSize, 1, 1, Cast>(
+      &temp[0], &in_data[0], Cast());
+  kps::Reduce<OutT, VecSize, 1, 1, Add, Mode::kLocalMode>(
+      &result, &temp[0], Add(), true);
+  kps::Reduce<OutT, 1, 1, 1, Add, Mode::kGlobalMode>(
+      &result, &result, Add(), true);
+  if (store_fix == 0) {
+    // the first block's fix_size = 0
+    OutT tmp = static_cast<OutT>(0.0f);
+    kps::WriteData<OutT, 1, 1, 1, true>(out + store_fix, &tmp, 1);
+  }
+
+  // store the count of this block
+  kps::WriteData<OutT, 1, 1, 1, true>(out + store_fix + 1, &result, 1);
+}
+
+// Count how many elements in the current block are not zero
+template <typename InT, typename OutT, int VecSize>
+__global__ void GetBlockCountKernel(const InT *in,
+                                    OutT *out,
+                                    int64_t numel,
+                                    int64_t main_offset) {
+  int data_offset = BLOCK_ID_X * BLOCK_NUM_X * VecSize;
+  int stride = BLOCK_NUM_X * GRID_NUM_X * VecSize;
+  int repeat = 0;
+  for (; data_offset < main_offset; data_offset += stride) {
+    GetBlockCountImpl<InT, OutT, VecSize, false>(
+        in + data_offset, out, BLOCK_NUM_X * VecSize, repeat);
+    repeat++;  // to get the real blockIdx
+  }
+
+  int num = numel - data_offset;
+  if (num > 0) {
+    GetBlockCountImpl<InT, OutT, VecSize, true>(
+        in + data_offset, out, num, repeat);
+  }
+}
+
+/*
+* Compute the prefix sum of the block counts using one block;
+* VecSize must be 2.
+* 1. Each thread loads 2 elements: threadIdx.x and threadIdx.x + blockDim.x
+* 2. Cumsum limitation: blockDim.x must be less than 512
+*/
+template <typename InT,
+          typename OutT,
+          typename Functor,
+          int VecSize,
+          bool IsBoundary>
+__device__ void CumsumImpl(
+    const InT *in, OutT *out, OutT *pre_cumsum, int num, Functor func) {
+  __shared__ OutT max_thread_data;
+  OutT temp[VecSize];
+  InT arg[VecSize];
+  OutT result[VecSize];
+  // init data_pr
+  kps::Init<InT, VecSize>(&arg[0], static_cast<InT>(0.0f));
+  // set pre_cumsum
+  kps::Init<OutT, VecSize>(&temp[0], *pre_cumsum);
+  // load data to arg
+  kps::ReadData<InT, InT, VecSize, 1, 1, IsBoundary>(
+      &arg[0], in, num, 1, BLOCK_NUM_X, 1);
+  // block cumsum
+  kps::Cumsum<InT, OutT, 1, Functor>(&result[0], &arg[0], func);
+  // result = cumsum_result + pre_cumsum
+  kps::ElementwiseBinary<OutT, OutT, VecSize, 1, 1, Functor>(
+      &result[0], &result[0], &temp[0], func);
+  // get the last prefix sum
+  if ((THREAD_ID_X == BLOCK_NUM_X - 1) && !IsBoundary) {
+    max_thread_data = result[VecSize - 1];
+  }
+  __syncthreads();
+  // update pre_cumsum
+  *pre_cumsum = max_thread_data;
+  kps::WriteData<OutT, OutT, VecSize, 1, 1, IsBoundary>(
+      out, &result[0], num, 1, BLOCK_NUM_X, 1);
+}
+
+// Compute the store_offset of each block
+template <typename InT, typename OutT, typename Functor, int VecSize>
+__global__ void CumsumOneBlock(
+    const InT *in, OutT *out, int numel, int main_offset, Functor func) {
+  int stride = BLOCK_NUM_X * VecSize;
+  int offset = 0;
+  OutT pre_cumsum = static_cast<OutT>(0);
+  for (; offset < main_offset; offset += stride) {
+    CumsumImpl<InT, OutT, Functor, VecSize, false>(
+        in + offset, out + offset, &pre_cumsum, BLOCK_NUM_X * VecSize, func);
+  }
+
+  int num = numel - offset;
+  if (num > 0) {
+    CumsumImpl<InT, OutT, Functor, VecSize, true>(
+        in + offset, out + offset, &pre_cumsum, num, func);
+  }
+}
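// ---------------------------------------------------------------------------
// Worked example (illustrative values) of the two kernels above:
//   condition, 8 elements, 2 blocks of 4:   F T T F | F F T T
//   GetBlockCountKernel writes   count  = [0, 2, 2]
//     (a leading 0 plus one count per block, which is why the launcher
//      below sizes this buffer as need_grids + 1)
//   CumsumOneBlock, inclusive:   cumsum = [0, 2, 4]
//   so block 0 stores its hits at offset 0, block 1 at offset 2, and the
//   last element (4) is the total number of selected elements.
// ---------------------------------------------------------------------------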
+
+template <typename OutT,
+          typename MT,
+          typename InT,
+          typename IdT,
+          typename Functor,
+          int VecSize,
+          int IsBoundary,
+          int MaskData>
+struct SelectCaller {
+  __device__ void inline operator()(OutT *store_data,
+                                    const MT *mask_data,
+                                    const InT *in,
+                                    Functor func,
+                                    int num,
+                                    int data_offset) {
+    // where_index op
+    IdT index_reg[VecSize];
+    // Set the global data index
+    kps::InitWithDataIndex<IdT, VecSize, 1, 1>(&index_reg[0], data_offset);
+    // Get store data according to mask_idt
+    kps::OperatorTernary<MT, IdT, OutT, Functor>(
+        store_data, mask_data, &index_reg[0], func, VecSize);
+  }
+};
+
+template <typename OutT,
+          typename MT,
+          typename InT,
+          typename IdT,
+          typename Functor,
+          int VecSize,
+          int IsBoundary>
+struct SelectCaller<OutT, MT, InT, IdT, Functor, VecSize, IsBoundary, 1> {
+  // masked_select
+  __device__ void inline operator()(OutT *store_data,
+                                    const MT *mask_data,
+                                    const InT *in,
+                                    Functor func,
+                                    int num,
+                                    int data_offset) {
+    InT in_data[VecSize];
+    kps::ReadData<InT, VecSize, 1, 1, IsBoundary>(&in_data[0], in, num);
+    // Get store data according to mask_idt
+    kps::OperatorTernary<MT, InT, OutT, Functor>(
+        store_data, mask_data, &in_data[0], func, VecSize);
+  }
+};
+
+/**
+* Get mask's index if mask == true
+*/
+template <typename InT,
+          typename MT,
+          typename OutT,
+          typename Functor,
+          int VecSize,
+          int MaskData,
+          int IsBoundary>  // MaskData = 1 for masked_select, else where_index
+__device__ void
+SelectKernelImpl(OutT *out,
+                 const MT *mask,
+                 const InT *in,
+                 Functor func,
+                 int num,
+                 int data_offset,
+                 int store_rank) {
+  const int kCVecSize = 2;
+  // each thread cumsums 2 elements
+  using IdT = int64_t;
+  // index data type
+  using Add = kps::AddFunctor<IdT>;  // for cumsum
+  using Cast = NonZeroFunctor<InT>;  // for mask
+
+  IdT init_idx = static_cast<IdT>(0.0f);
+  MT init_mask = static_cast<MT>(0.0f);
+
+  IdT num_thread[kCVecSize];
+  IdT cumsum_thread[kCVecSize];
+
+  OutT store_data[VecSize * phi::DDim::kMaxRank];
+  MT mask_data[VecSize];
+  IdT mask_idt[VecSize];
+  // init data_pr
+  kps::Init<IdT, kCVecSize>(&cumsum_thread[0], init_idx);
+  kps::Init<IdT, kCVecSize>(&num_thread[0], init_idx);
+  kps::Init<MT, VecSize>(&mask_data[0], init_mask);
+  // Load mask
+  kps::ReadData<MT, VecSize, 1, 1, IsBoundary>(&mask_data[0], mask, num);
+  // Cast from MT to int
+  kps::ElementwiseUnary<MT, IdT, VecSize, 1, 1, Cast>(
+      &mask_idt[0], &mask_data[0], Cast());
+  // Get the num of this thread; only num_thread[1] has data
+  kps::Reduce<IdT, VecSize, 1, 1, Add, Mode::kLocalMode>(
+      &num_thread[0], &mask_idt[0], Add(), true);
+  // Get cumsum_thread, the cumsum from 0 to num_thread; cumsum_thread[0] is
+  // the thread_fix
+  kps::Cumsum<IdT, IdT, 1, Add>(&cumsum_thread[0], &num_thread[0], Add());
+  // Get store data (index or value) according to mask_idt
+  SelectCaller<OutT, MT, InT, IdT, Functor, VecSize, IsBoundary, MaskData>
+      compute;
+  compute(&store_data[0], &mask_data[0], in, func, num, data_offset);
+  // get thread_fix
+  int thread_fix =
+      (static_cast<int>(cumsum_thread[0] - num_thread[0]) * store_rank);
+  // get how many elements need to be stored
+  int store_num = static_cast<int>(num_thread[0]) * store_rank;
+  // this thread stores store_num elements; each thread may store a
+  // different number
+  kps::details::WriteData(out + thread_fix, &store_data[0], store_num);
+}
+
+template <typename MT,
+          typename InT,
+          typename CT,
+          typename OutT,
+          typename Functor,
+          int VecSize,
+          int MaskData>
+__global__ void SelectKernel(OutT *out,
+                             const MT *mask,
+                             const InT *in,
+                             CT *cumsum,
+                             Functor func,
+                             const int64_t numel,
+                             int64_t main_offset,
+                             int store_rank) {
+  int data_offset = BLOCK_ID_X * BLOCK_NUM_X * VecSize;
+  int stride = BLOCK_NUM_X * GRID_NUM_X * VecSize;
+  int repeat = 0;
+  int size = VecSize * BLOCK_NUM_X;  // elements per full tile of this block
+  for (; data_offset < main_offset; data_offset += stride) {
+    // Cumsum index
+    int idx_cumsum = repeat * GRID_NUM_X + BLOCK_ID_X;
+    // niuliling todo: use ReadData API
+    int block_store_offset = cumsum[idx_cumsum];
+    SelectKernelImpl<InT, MT, OutT, Functor, VecSize, MaskData, false>(
+        out + block_store_offset * store_rank,
+        mask + data_offset,
+        in + data_offset,
+        func,
+        size,
+        data_offset,
+        store_rank);
+    repeat++;
+  }
+
+  int num = numel - data_offset;
+  if (num > 0) {
+    // Cumsum index
+    int idx_cumsum = repeat * GRID_NUM_X + BLOCK_ID_X;
+    // niuliling todo: use ReadData API
+    int block_store_offset = static_cast<int>(cumsum[idx_cumsum]);
+    SelectKernelImpl<InT, MT, OutT, Functor, VecSize, MaskData, true>(
+        out + block_store_offset * store_rank,
+        mask + data_offset,
+        in + data_offset,
+        func,
+        num,
+        data_offset,
+        store_rank);
+  }
+}
+
+inline int64_t Floor(int64_t in, int64_t div) { return in / div * div; }
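// ---------------------------------------------------------------------------
// Illustrative sketch of what the two SelectCaller specializations store,
// assuming VecSize == 4 and one thread whose tile is
//   mask = {0, 1, 0, 1},  in = {a, b, c, d},  data_offset = 8:
// - where_index (primary template): index_reg = {8, 9, 10, 11} comes from
//   InitWithDataIndex, so func packs the global indices {9, 11}.
// - masked_select (MaskData == 1): in_data = {a, b, c, d} is loaded from the
//   input, so func packs the values {b, d}.
// In both cases SelectKernelImpl then writes the packed elements at the
// thread's cumsum-derived offset (thread_fix).
// ---------------------------------------------------------------------------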
+
+// SelectData = 1 for masked_select; SelectData = 0 for where_index
+template <typename MT,
+          typename InT,
+          typename OutT,
+          int SelectData,
+          typename Functor>
+void SelectKernel(const KPDevice &dev_ctx,
+                  const DenseTensor &condition,
+                  const DenseTensor &in_data,
+                  DenseTensor *out,
+                  Functor func) {
+  const MT *cond_data = condition.data<MT>();
+  const int64_t numel = condition.numel();
+  auto dims = condition.dims();
+  int rank = SelectData ? 1 : dims.size();
+  const InT *in_data_ptr = SelectData ? in_data.data<InT>() : nullptr;
+  // calculate the inclusive prefix sum of "true_num_array"
+  // to get the index of the "out" tensor,
+  // and the total number of cond_data[i] == true.
+  // Example:
+  // condition: F T T F F F T T
+  // before:    0 1 1 0 0 0 1 1
+  // after:     0 1 2 2 2 2 3 4
+  // out:       1 2 6 7
+  // alloc for cpu
+  using CT = int64_t;  // Count_data type
+  const int t_size = sizeof(CT);
+
+  const paddle::platform::CUDAPlace &cuda_place = dev_ctx.GetPlace();
+  paddle::platform::CPUPlace cpu_place = paddle::platform::CPUPlace();
+
+  // 1.1 get the number of stored data per block
+  CT total_true_num = 0;  // init; must be CT since t_size bytes are copied in
+  const int kVecSize = 4;
+#ifdef PADDLE_WITH_XPU_KP
+  int block = 64;
+  auto stream = dev_ctx.x_context()->xpu_stream;
+  const int num_per_block = kVecSize * block;
+  const int need_grids = (numel + num_per_block - 1) / num_per_block;
+  const int grid = std::min(need_grids, 8);
+#else
+  const int block = 256;
+  const int num_per_block = kVecSize * block;
+  const int need_grids = (numel + num_per_block - 1) / num_per_block;
+  const int grid = std::min(need_grids, 256);
+  auto stream = dev_ctx.stream();
+#endif
+  const int64_t main_offset = Floor(numel, num_per_block);
+  // 1.2 alloc tmp data for CountBlock
+  const int size_count_block = need_grids + 1;
+  std::vector<int64_t> dims_vec = {size_count_block * 2};
+  ScalarArray dims_array(dims_vec);
+  DenseTensor count_mem = phi::Empty<CT, KPDevice>(dev_ctx, dims_array);
+  CT *count_data = count_mem.data<CT>();
+  // 1.3 launch CountKernel
+  GetBlockCountKernel<MT, CT, kVecSize><<<grid, block, 0, stream>>>(
+      cond_data, count_data, numel, main_offset);
+  // 2.1 alloc cumsum data for the CountBlock prefix
+  DenseTensor cumsum_mem = phi::Empty<CT, KPDevice>(dev_ctx, dims_array);
+  CT *cumsum_data = cumsum_mem.data<CT>();
+  // 2.2 get the prefix of count_data for the real out_index
+  const int kCumVecSize = 2;
+  const int block_c = 256;
+  const int main_offset_c = Floor(size_count_block, (kCumVecSize * block_c));
+  using Add = kps::AddFunctor<CT>;
+  CumsumOneBlock<CT, CT, Add, kCumVecSize><<<1, block_c, 0, stream>>>(
+      count_data, cumsum_data, size_count_block, main_offset_c, Add());
+  // 3.1 set temp ptr for in;
+  // 3.1 alloc for out
+  // 3.1.1 get true_num; for gpu place the last cumsum element is the true_num
+  paddle::memory::Copy(cpu_place,
+                       &total_true_num,
+                       cuda_place,
+                       cumsum_data + need_grids,
+                       t_size,
+                       dev_ctx.stream());
+
+  dev_ctx.Wait();
+  // 3.1.2 alloc for out with total_true_num
+  std::vector<int64_t> out_dim = {static_cast<int64_t>(total_true_num)};
+  if (SelectData == 0) {  // where_index
+    out_dim.push_back(rank);
+  }
+  out->Resize(phi::make_ddim(out_dim));
+  auto out_data = out->mutable_data<OutT>(cuda_place);
+  // 3.2 get the true data's index according to cond_data and cumsum_data
+  if (total_true_num <= 0) return;
+  SelectKernel<MT, InT, CT, OutT, Functor, kVecSize, SelectData>
+      <<<grid, block, 0, stream>>>(out_data,
+                                   cond_data,
+                                   in_data_ptr,
+                                   cumsum_data,
+                                   func,
+                                   numel,
+                                   main_offset,
+                                   rank);
+}
+
+}  // namespace funcs
+}  // namespace phi
+
+#endif
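// ---------------------------------------------------------------------------
// Worked example of the launch arithmetic above (CUDA branch, illustrative
// values):
//   numel = 3000, block = 256, kVecSize = 4
//   num_per_block = 1024 elements per block per iteration
//   need_grids    = ceil(3000 / 1024) = 3,  grid = min(3, 256) = 3
//   main_offset   = Floor(3000, 1024) = 2048
// Blocks 0 and 1 sweep full 1024-element tiles below main_offset on the
// non-boundary path; block 2 takes the 952-element IsBoundary tail.
// The count buffer holds need_grids + 1 = 4 entries: a leading 0 plus one
// count per block, so its inclusive scan is exactly each block's offset.
// ---------------------------------------------------------------------------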

diff --git a/paddle/phi/kernels/gpu/masked_select_kernel.cu b/paddle/phi/kernels/gpu/masked_select_kernel.cu
index fc4adca2f42..b443ae6b8fb 100644
--- a/paddle/phi/kernels/gpu/masked_select_kernel.cu
+++ b/paddle/phi/kernels/gpu/masked_select_kernel.cu
@@ -19,34 +19,27 @@
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/select_impl.cu.h"
 #include "paddle/phi/kernels/masked_select_kernel.h"
 
 namespace phi {
 
-__global__ void SetMaskArray(const bool* mask, int32_t* mask_array, int size) {
-  int idx = blockDim.x * blockIdx.x + threadIdx.x;
-  for (; idx < size; idx += blockDim.x * gridDim.x) {
-    if (mask[idx])
-      mask_array[idx] = 1;
-    else
-      mask_array[idx] = 0;
-  }
-}
+template <typename MT, typename InT, typename OutT>
+struct MaskedSelectFunctor {
+  HOSTDEVICE MaskedSelectFunctor() {}
 
-template <typename T>
-__global__ void SelectWithPrefixMask(const int32_t* mask_prefix_sum,
-                                     const bool* mask,
-                                     const T* input,
-                                     T* out,
-                                     int size) {
-  int idx = blockDim.x * blockIdx.x + threadIdx.x;
-  for (; idx < size; idx += blockDim.x * gridDim.x) {
-    if (mask[idx]) {
-      int index = mask_prefix_sum[idx];
-      out[index] = input[idx];
+  HOSTDEVICE inline void operator()(OutT* out,
+                                    const MT* mask,
+                                    const InT* value,
+                                    int num) {
+    int store_fix = 0;
+    for (int idx = 0; idx < num; idx++) {
+      if (mask[idx]) {
+        out[store_fix++] = value[idx];
+      }
     }
   }
-}
+};
 
 template <typename T, typename Context>
 void MaskedSelectKernel(const Context& dev_ctx,
@@ -68,42 +61,9 @@ void MaskedSelectKernel(const Context& dev_ctx,
                         "value.",
                         input_dim,
                         mask_dim));
-
-  thrust::device_ptr<const bool> mask_dev_ptr =
-      thrust::device_pointer_cast(mask_data);
-  thrust::device_vector<bool> mask_vec(mask_dev_ptr, mask_dev_ptr + mask_size);
-  auto out_size = thrust::count(mask_vec.begin(), mask_vec.end(), true);
-
-  DDim out_dim{out_size};
-  out->Resize(out_dim);
-  auto out_data = out->mutable_data<T>(dev_ctx.GetPlace());
-
-  DenseTensor mask_array;
-  DenseTensor mask_prefix_sum;
-  mask_array.Resize(mask_dim);
-  mask_prefix_sum.Resize(mask_dim);
-
-  int32_t* mask_array_data =
-      mask_array.mutable_data<int32_t>(dev_ctx.GetPlace());
-  int32_t* mask_prefix_sum_data =
-      mask_prefix_sum.mutable_data<int32_t>(dev_ctx.GetPlace());
-  int threads = 512;
-  int grid = (mask_size + threads - 1) / threads;
-  auto stream = dev_ctx.stream();
-  SetMaskArray<<<grid, threads, 0, stream>>>(
-      mask_data, mask_array_data, mask_size);
-
-  thrust::device_ptr<int32_t> mask_array_dev_ptr =
-      thrust::device_pointer_cast(mask_array_data);
-  thrust::device_vector<int32_t> mask_array_vec(mask_array_dev_ptr,
-                                                mask_array_dev_ptr + mask_size);
-  thrust::exclusive_scan(thrust::device,
-                         mask_array_vec.begin(),
-                         mask_array_vec.end(),
-                         mask_prefix_sum_data);
-
-  SelectWithPrefixMask<T><<<grid, threads, 0, stream>>>(
-      mask_prefix_sum_data, mask_data, input_data, out_data, mask_size);
+  using Functor = MaskedSelectFunctor<bool, T, T>;
+  phi::funcs::SelectKernel<bool, T, T, 1, Functor>(
+      dev_ctx, mask, x, out, Functor());
 }
 
 }  // namespace phi
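// ---------------------------------------------------------------------------
// Illustrative host-side check of MaskedSelectFunctor's per-tile contract
// (the functor is HOSTDEVICE, so it also runs on the CPU); the values here
// are examples, not Paddle test code:

#include <cassert>

void MaskedSelectFunctorExample() {
  bool mask[4] = {true, false, false, true};
  float vals[4] = {1.f, 2.f, 3.f, 4.f};
  float out[4] = {0.f, 0.f, 0.f, 0.f};
  phi::MaskedSelectFunctor<bool, float, float>()(out, mask, vals, 4);
  // Selected values are compacted to the front; only the first two slots
  // are meaningful, and SelectKernelImpl writes exactly that many.
  assert(out[0] == 1.f && out[1] == 4.f);
}
// ---------------------------------------------------------------------------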

diff --git a/paddle/phi/kernels/gpu/where_index_kernel.cu b/paddle/phi/kernels/gpu/where_index_kernel.cu
index 535cb812a20..9538533f70d 100644
--- a/paddle/phi/kernels/gpu/where_index_kernel.cu
+++ b/paddle/phi/kernels/gpu/where_index_kernel.cu
@@ -20,150 +20,59 @@ namespace cub = hipcub;
 #endif
 
-#include "paddle/phi/kernels/where_index_kernel.h"
-
-#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
 #include "paddle/phi/core/ddim.h"
 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
+#include "paddle/phi/kernels/funcs/select_impl.cu.h"
+#include "paddle/phi/kernels/where_index_kernel.h"
 
 namespace phi {
-
-template <typename T>
-__global__ void GetTrueNum(const T *cond_data,
-                           const int64_t numel,
-                           int64_t *true_num_array) {
-  const int64_t tid = blockIdx.x * blockDim.x + threadIdx.x;
-
-  for (int64_t idx = tid; idx < numel; idx += gridDim.x * blockDim.x) {
-    true_num_array[idx] =
-        static_cast<int64_t>(static_cast<bool>(cond_data[idx]));
+template <typename T1, typename T2, typename OutT>
+struct IndexFunctor {
+  T2 stride[phi::DDim::kMaxRank];
+  int dims;
+  explicit IndexFunctor(const phi::DDim &in_dims) {
+    dims = in_dims.size();
+    std::vector<T2> strides_in_tmp;
+    strides_in_tmp.resize(dims, 1);
+    // get strides according to in_dims
+    for (T2 i = 1; i < dims; i++) {
+      strides_in_tmp[i] = strides_in_tmp[i - 1] * in_dims[dims - i];
+    }
+    memcpy(stride, strides_in_tmp.data(), dims * sizeof(T2));
   }
-}
-
-template <typename T>
-__global__ void SetTrueIndex(int64_t *out_ptr,
-                             const T *cond_data,
-                             const int64_t numel,
-                             const int64_t *stride_array,
-                             const int64_t rank,
-                             const int64_t *true_num_array) {
-  const int64_t tid = blockIdx.x * blockDim.x + threadIdx.x;
-  for (int64_t idx = tid; idx < numel; idx += gridDim.x * blockDim.x) {
-    // true_num_array is calculated by cub::InclusiveSum,
-    // so the first element of true_num_array is 1,
-    // and we need to subtract 1 to get the true index.
-    const int64_t true_index = true_num_array[idx] - 1;
-    if (static_cast<bool>(cond_data[idx])) {
-      int64_t rank_index = idx;
-      for (int j = 0; j < rank; j++) {
-        const int64_t out_index = rank_index / stride_array[j];
-        out_ptr[true_index * rank + j] = out_index;
-        rank_index -= out_index * stride_array[j];
+
+  HOSTDEVICE inline void operator()(OutT *out,
+                                    const T1 *mask,
+                                    const T2 *index,
+                                    const int num) {
+    int store_fix = 0;
+    for (int idx = 0; idx < num; idx++) {
+      if (mask[idx]) {
+        T2 data_index = index[idx];
+        // decompose the flat index into per-rank coordinates
+        for (int rank_id = dims - 1; rank_id >= 0; --rank_id) {
+          out[store_fix] = static_cast<OutT>(data_index / stride[rank_id]);
+          data_index = data_index % stride[rank_id];
+          store_fix++;
+        }
       }
     }
   }
-}
+};
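// ---------------------------------------------------------------------------
// Illustrative sketch: the same flat-index -> coordinates decomposition that
// IndexFunctor::operator() performs, as a standalone function. DecomposeIndex
// is a hypothetical name; it assumes strides laid out as above, i.e.
// stride[0] = 1 (innermost) up to stride[dims - 1] (outermost).

#include <array>
#include <cstdint>

template <int kMaxRank = 9>
std::array<int64_t, kMaxRank> DecomposeIndex(int64_t flat,
                                             const int64_t *stride,
                                             int dims) {
  std::array<int64_t, kMaxRank> coords{};
  // Walk from the outermost stride down, peeling one coordinate per rank.
  for (int rank_id = dims - 1, k = 0; rank_id >= 0; --rank_id, ++k) {
    coords[k] = flat / stride[rank_id];
    flat = flat % stride[rank_id];
  }
  return coords;
}
// e.g. shape [2, 3]: stride = {1, 3}; flat index 5 -> coordinates (1, 2).
// ---------------------------------------------------------------------------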
 
 template <typename T, typename Context>
 void WhereIndexKernel(const Context &dev_ctx,
                       const DenseTensor &condition,
                       DenseTensor *out) {
-  const T *cond_data = condition.data<T>();
-  const int64_t numel = condition.numel();
+  DenseTensor in_data;
   auto dims = condition.dims();
-  const int rank = dims.size();
-
-  auto d_array_mem =
-      paddle::memory::Alloc(dev_ctx, (numel + rank) * sizeof(int64_t));
-  auto h_array_mem =
-      paddle::memory::Alloc(phi::CPUPlace(), (rank + 1) * sizeof(int64_t));
-
-  // "stride_array" is an array with len(stride_array) == rank;
-  // each element is the stride of one dimension -- the length from i to i+1.
-  int64_t *h_stride_array = reinterpret_cast<int64_t *>(h_array_mem->ptr());
-  int64_t *d_stride_array = reinterpret_cast<int64_t *>(d_array_mem->ptr());
-
-  // "true_num_array" is an array with len(true_num_array) == numel;
-  // at the beginning,
-  // "true_num_array" is set to 1 if condition[i] == true else 0,
-  // then it is calculated by cub::InclusiveSum,
-  // so that we can get the true count before i as the out index
-  int64_t *d_true_num_array = d_stride_array + rank;
-
-  // total_true_num is the total number of condition[i] == true
-  int64_t *h_total_true_num = h_stride_array + rank;
-
-  // alloc cub memory
-  size_t cub_size = 0;
-  cub::DeviceScan::InclusiveSum(nullptr,
-                                cub_size,
-                                d_true_num_array,
-                                d_true_num_array,
-                                numel,
-                                dev_ctx.stream());
-  auto cub_mem = paddle::memory::Alloc(dev_ctx, cub_size * sizeof(int64_t));
-  void *cub_data = cub_mem->ptr();
-
-  // set d_true_num_array[i] = 1 if cond_data[i] == true else 0
-  const int threads = std::min(numel, static_cast<int64_t>(128));
-  const int64_t need_grids = (numel + threads - 1) / threads;
-  const int grids = std::min(need_grids, static_cast<int64_t>(256));
-  GetTrueNum<T><<<grids, threads, 0, dev_ctx.stream()>>>(
-      cond_data, numel, d_true_num_array);
-
-  // calculate the inclusive prefix sum of "true_num_array"
-  // to get the index of the "out" tensor,
-  // and the total number of cond_data[i] == true.
-  // Example:
-  // condition: F T T F F F T T
-  // before:    0 1 1 0 0 0 1 1
-  // after:     0 1 2 2 2 2 3 4
-  // out:       1 2 6 7
-  cub::DeviceScan::InclusiveSum(cub_data,
-                                cub_size,
-                                d_true_num_array,
-                                d_true_num_array,
-                                numel,
-                                dev_ctx.stream());
-
-  // calculate each dimension's stride
-  h_stride_array[rank - 1] = 1;
-  for (int i = rank - 2; i >= 0; i--) {
-    h_stride_array[i] = h_stride_array[i + 1] * dims[i + 1];
-  }
-  paddle::memory::Copy(dev_ctx.GetPlace(),
-                       d_stride_array,
-                       phi::CPUPlace(),
-                       h_stride_array,
-                       rank * sizeof(int64_t),
-                       dev_ctx.stream());
-
-  // get the total true number and set the output size;
-  // the last element of cub::InclusiveSum is the total number
-  paddle::memory::Copy(phi::CPUPlace(),
-                       h_total_true_num,
-                       dev_ctx.GetPlace(),
-                       d_true_num_array + numel - 1,
-                       sizeof(int64_t),
-                       dev_ctx.stream());
-  dev_ctx.Wait();
-
-  int64_t true_num = *h_total_true_num;
-  out->Resize(phi::make_ddim({static_cast<int64_t>(true_num), rank}));
-  auto *out_data = dev_ctx.template Alloc<int64_t>(out);
-
-  if (true_num == 0) {
-    return;
-  }
-
-  // use true_num_array and stride_array to calculate the output index
-  SetTrueIndex<T><<<grids, threads, 0, dev_ctx.stream()>>>(
-      out_data, cond_data, numel, d_stride_array, rank, d_true_num_array);
+  using Functor = IndexFunctor<T, int64_t, int64_t>;
+  Functor index_functor = Functor(dims);
+  phi::funcs::SelectKernel<T, T, int64_t, 0, Functor>(
+      dev_ctx, condition, in_data, out, index_functor);
 }
-
 }  // namespace phi
 
 PD_REGISTER_KERNEL(where_index,
diff --git a/paddle/phi/kernels/primitive/compute_primitives.h b/paddle/phi/kernels/primitive/compute_primitives.h
index 632ad00f6d0..e02f4450a8b 100644
--- a/paddle/phi/kernels/primitive/compute_primitives.h
+++ b/paddle/phi/kernels/primitive/compute_primitives.h
@@ -22,7 +22,6 @@
 #endif
 
 #include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
-// #include "paddle/phi/common/bfloat16.h"
 #include "paddle/phi/common/float16.h"
 
 namespace phi {
@@ -591,7 +590,7 @@ __device__ __forceinline__ void Cumsum(OutT* out,
     int index = (tidx + 1) * 2 * stride - 1;
     if (index < (blockDim.x * 2)) {
       temp[index + index / 32] =
-          compute(temp[index + index / 2],
+          compute(temp[index + index / 32],
                   temp[index - stride + (index - stride) / 32]);
     }
   }
-- 
GitLab
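// ---------------------------------------------------------------------------
// Endnote on the compute_primitives.h hunk: the shared-memory cumsum stores
// temp[i] at the padded slot i + i / 32 (one spare slot per 32 entries) to
// avoid bank conflicts on 32-bank hardware. Before the fix, the left operand
// of compute() was read with i + i / 2, a layout nothing else in the routine
// uses, so the down-sweep combined values from two different layouts. A
// minimal sketch of the padded indexing, assuming 32 banks (PaddedIndex is
// an illustrative helper, not a Paddle API):

__device__ __forceinline__ int PaddedIndex(int i) {
  // One padding slot is inserted after every 32 logical elements.
  return i + i / 32;
}

// Usage inside the scan (shared buffer sized 2 * blockDim.x plus padding):
//   temp[PaddedIndex(index)] =
//       compute(temp[PaddedIndex(index)],            // was: index + index/2
//               temp[PaddedIndex(index - stride)]);
// ---------------------------------------------------------------------------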