Unverified commit 942a85b9, authored by niuliling123, committed by GitHub

Modified masked_select_kernel and where_index with kernel primitive api (#40517)

Parent 96d2f337
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
// CUDA and HIP use the same API
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#ifdef __NVCC__
#include "cub/cub.cuh"
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif
#include <algorithm>
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/primitive/kernel_primitives.h"
namespace kps = phi::kps;
namespace phi {
namespace funcs {
using Mode = kps::details::ReduceMode;
/*
* Count how many of the elements processed by the current block are true
* 1. Load data from global memory and cast from bool to int64_t
* 2. Get this thread's result with a thread-level reduce
* 3. Get this block's result with a block-level reduce
* 4. The first block additionally stores a leading 0 before its own result
*/
template <typename T>
struct NonZeroFunctor {
HOSTDEVICE NonZeroFunctor() {}
HOSTDEVICE inline T operator()(const T in) {
if (in) {
return static_cast<T>(1);
} else {
return static_cast<T>(0);
}
}
};
template <typename InT, typename OutT, int VecSize, int IsBoundary>
__device__ void GetBlockCountImpl(const InT *in,
OutT *out,
int num,
int repeat) {
InT in_data[VecSize];
OutT temp[VecSize];
OutT result = static_cast<OutT>(0.0f);
using Add = kps::AddFunctor<OutT>;
using Cast = NonZeroFunctor<InT>;
int store_fix = BLOCK_ID_X + repeat * GRID_NUM_X;
kps::Init<InT, VecSize>(&in_data[0], static_cast<InT>(0.0f));
kps::ReadData<InT, VecSize, 1, 1, IsBoundary>(&in_data[0], in, num);
kps::ElementwiseUnary<InT, OutT, VecSize, 1, 1, Cast>(
&temp[0], &in_data[0], Cast());
kps::Reduce<OutT, VecSize, 1, 1, Add, Mode::kLocalMode>(
&result, &temp[0], Add(), true);
kps::Reduce<OutT, 1, 1, 1, Add, Mode::kGlobalMode>(
&result, &result, Add(), true);
if (store_fix == 0) {
// first block's fix_size = 0;
OutT tmp = static_cast<OutT>(0.0f);
kps::WriteData<OutT, 1, 1, 1, true>(out + store_fix, &tmp, 1);
}
// store num of this block
kps::WriteData<OutT, 1, 1, 1, true>(out + store_fix + 1, &result, 1);
}
// Count how many elements are non-zero in the current block
template <typename InT, typename OutT, int VecSize>
__global__ void GetBlockCountKernel(const InT *in,
OutT *out,
int64_t numel,
int64_t main_offset) {
int data_offset = BLOCK_ID_X * BLOCK_NUM_X * VecSize;
int stride = BLOCK_NUM_X * GRID_NUM_X * VecSize;
int repeat = 0;
for (; data_offset < main_offset; data_offset += stride) {
GetBlockCountImpl<InT, OutT, VecSize, false>(
in + data_offset, out, BLOCK_NUM_X * VecSize, repeat);
repeat++; // to get the real blockIdx
}
int num = numel - data_offset;
if (num > 0) {
GetBlockCountImpl<InT, OutT, VecSize, true>(
in + data_offset, out, num, repeat);
}
}
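// After this kernel finishes, out holds [0, c_0, c_1, ..., c_{g-1}], where
// c_i is the number of true elements in the i-th grid-stride chunk; this is
// the layout that CumsumOneBlock below scans to get per-chunk store offsets.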
/*
* Compute the prefix sum of the per-block counts using a single block; VecSize must be 2
* 1. Each thread loads 2 elements: threadIdx.x and threadIdx.x + blockDim.x
* 2. Limitation of Cumsum: blockDim.x must be less than 512
*/
template <typename InT,
typename OutT,
typename Functor,
int VecSize,
bool IsBoundary>
__device__ void CumsumImpl(
const InT *in, OutT *out, OutT *pre_cumsum, int num, Functor func) {
__shared__ OutT max_thread_data;
OutT temp[VecSize];
InT arg[VecSize];
OutT result[VecSize];
// init the arg registers
kps::Init<InT, VecSize>(&arg[0], static_cast<InT>(0.0f));
// set pre_cumsum
kps::Init<OutT, VecSize>(&temp[0], *pre_cumsum);
// load data to arg
kps::ReadData<InT, InT, VecSize, 1, 1, IsBoundary>(
&arg[0], in, num, 1, BLOCK_NUM_X, 1);
// block cumsum
kps::Cumsum<InT, OutT, 1, Functor>(&result[0], &arg[0], func);
// result = cumsum_result + pre_cumsum
kps::ElementwiseBinary<OutT, OutT, VecSize, 1, 1, Functor>(
&result[0], &result[0], &temp[0], func);
// get the last prefix sum
if ((THREAD_ID_X == BLOCK_NUM_X - 1) && !IsBoundary) {
max_thread_data = result[VecSize - 1];
}
__syncthreads();
// update pre_cumsum
*pre_cumsum = max_thread_data;
kps::WriteData<OutT, OutT, VecSize, 1, 1, IsBoundary>(
out, &result[0], num, 1, BLOCK_NUM_X, 1);
}
// Compute the store offset of each block
template <typename InT, typename OutT, typename Functor, int VecSize>
__global__ void CumsumOneBlock(
const InT *in, OutT *out, int numel, int main_offset, Functor func) {
int stride = BLOCK_NUM_X * VecSize;
int offset = 0;
OutT pre_cumsum = static_cast<OutT>(0);
for (; offset < main_offset; offset += stride) {
CumsumImpl<InT, OutT, Functor, VecSize, false>(
in + offset, out + offset, &pre_cumsum, BLOCK_NUM_X * VecSize, func);
}
int num = numel - offset;
if (num > 0) {
CumsumImpl<InT, OutT, Functor, VecSize, true>(
in + offset, out + offset, &pre_cumsum, num, func);
}
}
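// e.g. with per-chunk counts [0, c_0, c_1, c_2] the inclusive scan yields
// [0, c_0, c_0 + c_1, c_0 + c_1 + c_2]: entry i is the global write offset
// of chunk i, and the last entry is the total number of selected elements.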
template <typename OutT,
typename MT,
typename InT,
typename IdT,
typename Functor,
int VecSize,
int IsBoundary,
int IsMaskData>
struct SelectCaller {
__device__ void inline operator()(OutT *store_data,
const MT *mask_data,
const InT *in,
Functor func,
int num,
int data_offset) {
// where_index op
IdT index_reg[VecSize];
// Set data index of global
kps::InitWithDataIndex<IdT, VecSize, 1, 1>(&index_reg[0], data_offset);
// Get store data according to mask_idt
kps::OperatorTernary<MT, IdT, OutT, Functor>(
store_data, mask_data, &index_reg[0], func, VecSize);
}
};
template <typename OutT,
typename MT,
typename InT,
typename IdT,
typename Functor,
int VecSize,
int IsBoundary>
struct SelectCaller<OutT,
MT,
InT,
IdT,
Functor,
VecSize,
IsBoundary,
1> { // masked_select
__device__ void inline operator()(OutT *store_data,
const MT *mask_data,
const InT *in,
Functor func,
int num,
int data_offset) {
InT in_data[VecSize];
kps::ReadData<InT, VecSize, 1, 1, IsBoundary>(&in_data[0], in, num);
// Get store data according to mask_idt
kps::OperatorTernary<MT, InT, OutT, Functor>(
store_data, mask_data, &in_data[0], func, VecSize);
}
};
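// The two SelectCaller specializations differ only in what they stage into
// store_data: MaskData = 0 (where_index) generates each element's flattened
// index with InitWithDataIndex, while MaskData = 1 (masked_select) reads the
// corresponding input values from in.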
/**
* Get the index (where_index) or value (masked_select) of each element whose mask is true
*/
template <typename InT,
typename MT,
typename OutT,
typename Functor,
int VecSize,
int MaskData,
int IsBoundary> // MaskData = 1 for masked_select, otherwise where_index
__device__ void
SelectKernelImpl(OutT *out,
const MT *mask,
const InT *in,
Functor func,
int num,
int data_offset,
int store_rank) {
const int kCVecSize = 2;
// each thread cumsums 2 elements
using IdT = int64_t;
// Set index data type
using Add = kps::AddFunctor<IdT>; // for cumsum
using Cast = NonZeroFunctor<InT>; // for mask
IdT init_idx = static_cast<IdT>(0.0f);
MT init_mask = static_cast<MT>(0.0f);
IdT num_thread[kCVecSize];
IdT cumsum_thread[kCVecSize];
OutT store_data[VecSize * phi::DDim::kMaxRank];
MT mask_data[VecSize];
IdT mask_idt[VecSize];
// init registers
kps::Init<IdT, kCVecSize>(&cumsum_thread[0], init_idx);
kps::Init<IdT, kCVecSize>(&num_thread[0], init_idx);
kps::Init<MT, VecSize>(&mask_data[0], init_mask);
// Load mask
kps::ReadData<MT, VecSize, 1, 1, IsBoundary>(&mask_data[0], mask, num);
// Cast the mask from MT to IdT
kps::ElementwiseUnary<MT, IdT, VecSize, 1, 1, Cast>(
&mask_idt[0], &mask_data[0], Cast());
// Reduce the mask values to get this thread's true count (num_thread[0])
kps::Reduce<IdT, VecSize, 1, 1, Add, Mode::kLocalMode>(
&num_thread[0], &mask_idt[0], Add(), true);
// Block-level cumsum over the per-thread counts; cumsum_thread[0] is this
// thread's inclusive prefix, from which thread_fix is derived
kps::Cumsum<IdT, IdT, 1, Add>(&cumsum_thread[0], &num_thread[0], Add());
// Get store data(index) according to mask_idt
SelectCaller<OutT, MT, InT, IdT, Functor, VecSize, IsBoundary, MaskData>
compute;
compute(&store_data[0], &mask_data[0], in, func, num, data_offset);
// get thread_fix
int thread_fix =
(static_cast<int>(cumsum_thread[0] - num_thread[0]) * store_rank);
// get how many values this thread needs to store
int store_num = static_cast<int>(num_thread[0]) * store_rank;
// each thread stores store_num values; the count may differ per thread
kps::details::WriteData<OutT>(out + thread_fix, &store_data[0], store_num);
}
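// e.g. if a thread's VecSize mask values are {1, 0, 1, 0}, then
// num_thread[0] = 2 and the thread writes 2 * store_rank values starting at
// out + thread_fix, where thread_fix is the exclusive in-block prefix
// (cumsum_thread[0] - num_thread[0]) scaled by store_rank.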
template <typename MT,
typename InT,
typename CT,
typename OutT,
typename Functor,
int VecSize,
int MaskData>
__global__ void SelectKernel(OutT *out,
const MT *mask,
const InT *in,
CT *cumsum,
Functor func,
const int64_t numel,
int64_t main_offset,
int store_rank) {
int data_offset = BLOCK_ID_X * BLOCK_NUM_X * VecSize;
int stride = BLOCK_NUM_X * GRID_NUM_X * VecSize;
int repeat = 0;
int size = VecSize * BLOCK_ID_X;
for (; data_offset < main_offset; data_offset += stride) {
// Cumsum index
int idx_cumsum = repeat * GRID_NUM_X + BLOCK_ID_X;
// niuliling todo: use ReadData API
int block_store_offset = cumsum[idx_cumsum];
SelectKernelImpl<InT, MT, OutT, Functor, VecSize, MaskData, false>(
out + block_store_offset * store_rank,
mask + data_offset,
in + data_offset,
func,
size,
data_offset,
store_rank);
repeat++;
}
int num = numel - data_offset;
if (num > 0) {
// Cumsum index
int idx_cumsum = repeat * GRID_NUM_X + BLOCK_ID_X;
// niuliling todo: use ReadData API
int block_store_offset = static_cast<int>(cumsum[idx_cumsum]);
SelectKernelImpl<InT, MT, OutT, Functor, VecSize, MaskData, true>(
out + block_store_offset * store_rank,
mask + data_offset,
in + data_offset,
func,
num,
data_offset,
store_rank);
}
}
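// Each block reads its global write offset from cumsum[idx_cumsum], i.e. the
// number of selected elements in all chunks that precede the one it is
// processing, so the per-block outputs are packed contiguously in input order.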
inline int64_t Floor(int64_t in, int64_t div) { return in / div * div; }
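// e.g. Floor(2500, 1024) == 2048: rounds in down to a multiple of div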
// SelectData = 1 then masked_select; SelectData = 0 then where_index
template <typename MT,
typename InT,
typename OutT,
int SelectData,
typename Functor>
void SelectKernel(const KPDevice &dev_ctx,
const DenseTensor &condition,
const DenseTensor &in_data,
DenseTensor *out,
Functor func) {
const MT *cond_data = condition.data<MT>();
const int64_t numel = condition.numel();
auto dims = condition.dims();
int rank = SelectData ? 1 : dims.size();
const InT *in_data_ptr = SelectData ? in_data.data<InT>() : nullptr;
// calculate the inclusive prefix sum over the boolean condition
// to get each true element's write index into the "out" tensor,
// and the total number of cond_data[i]==true.
// Example:
// condition: F T T F F F T T
// before: 0 1 1 0 0 0 1 1
// after: 0 1 2 2 2 2 3 4
// out: 1 2 6 7
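// The scan is done hierarchically: GetBlockCountKernel counts true elements
// per chunk, CumsumOneBlock scans those counts into per-chunk offsets, and
// SelectKernel writes the selected indices/values at those offsets.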
// alloc for cpu
using CT = int64_t; // set Count_data Type
const int t_size = sizeof(CT);
const paddle::platform::CUDAPlace &cuda_place = dev_ctx.GetPlace();
paddle::platform::CPUPlace cpu_place = paddle::platform::CPUPlace();
// 1.1 get the number of stored elements per block
int total_true_num = 0; // init
const int kVecSize = 4;
#ifdef PADDLE_WITH_XPU_KP
int block = 64;
auto stream = dev_ctx.x_context()->xpu_stream;
const int num_per_block = kVecSize * block;
const int need_grids = (numel + num_per_block - 1) / num_per_block;
const int grid = std::min(need_grids, 8);
#else
const int block = 256;
const int num_per_block = kVecSize * block;
const int need_grids = (numel + num_per_block - 1) / num_per_block;
const int grid = std::min(need_grids, 256);
auto stream = dev_ctx.stream();
#endif
const int64_t main_offset = Floor(numel, num_per_block);
// 1.2 alloc tmp data for the per-block counts
const int size_count_block = need_grids + 1;
std::vector<int> dims_vec = {size_count_block * 2};
ScalarArray dims_array(dims_vec);
DenseTensor count_mem = phi::Empty<CT, KPDevice>(dev_ctx, dims_array);
CT *count_data = count_mem.data<CT>();
// 1.3 launch GetBlockCountKernel
GetBlockCountKernel<MT, CT, kVecSize><<<grid, block, 0, stream>>>(
cond_data, count_data, numel, main_offset);
// 2.1 alloc cumsum data for the block-count prefix sum
DenseTensor cumsum_mem = phi::Empty<CT, KPDevice>(dev_ctx, dims_array);
CT *cumsum_data = cumsum_mem.data<CT>();
// 2.2 get prefix of count_data for real out_index
const int kCumVesize = 2;
const int block_c = 256;
const int main_offset_c = Floor(size_count_block, (kCumVesize * block_c));
using Add = kps::AddFunctor<CT>;
CumsumOneBlock<CT, CT, Add, kCumVesize><<<1, block_c, 0, stream>>>(
count_data, cumsum_data, size_count_block, main_offset_c, Add());
// 3.1 alloc for out
// 3.1.1 get true_num: for the GPU place, the last cumsum entry is the true_num
paddle::memory::Copy(cpu_place,
&total_true_num,
cuda_place,
cumsum_data + need_grids,
t_size,
dev_ctx.stream());
dev_ctx.Wait();
// 3.1.2 alloc out with total_true_num
std::vector<int64_t> out_dim = {static_cast<int64_t>(total_true_num)};
if (SelectData == 0) { // where_index
out_dim.push_back(rank);
}
out->Resize(phi::make_ddim(out_dim));
auto out_data = out->mutable_data<OutT>(cuda_place);
// 3.2 get true data's index according to cond_data and cumsum_data
if (total_true_num <= 0) return;
SelectKernel<MT,
InT,
CT,
OutT,
Functor,
kVecSize,
SelectData><<<grid, block, 0, stream>>>(out_data,
cond_data,
in_data_ptr,
cumsum_data,
func,
numel,
main_offset,
rank);
}
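// Callers pick the behaviour via SelectData: masked_select instantiates
// SelectData = 1 so out receives the selected values of in_data, while
// where_index instantiates SelectData = 0 so out receives a rank-sized
// coordinate tuple per true element.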
} // namespace funcs
} // namespace phi
#endif
@@ -19,34 +19,27 @@
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/select_impl.cu.h"
#include "paddle/phi/kernels/masked_select_kernel.h"
namespace phi {
__global__ void SetMaskArray(const bool* mask, int32_t* mask_array, int size) {
int idx = blockDim.x * blockIdx.x + threadIdx.x;
for (; idx < size; idx += blockDim.x * gridDim.x) {
if (mask[idx])
mask_array[idx] = 1;
else
mask_array[idx] = 0;
}
}
template <typename MT, typename InT, typename OutT>
struct MaskedSelectFunctor {
HOSTDEVICE MaskedSelectFunctor() {}
template <typename T>
__global__ void SelectWithPrefixMask(const int32_t* mask_prefix_sum,
const bool* mask,
const T* input,
T* out,
int size) {
int idx = blockDim.x * blockIdx.x + threadIdx.x;
for (; idx < size; idx += blockDim.x * gridDim.x) {
if (mask[idx]) {
int index = mask_prefix_sum[idx];
out[index] = input[idx];
HOSTDEVICE inline void operator()(OutT* out,
const MT* mask,
const InT* value,
int num) {
int store_fix = 0;
for (int idx = 0; idx < num; idx++) {
if (mask[idx]) {
out[store_fix++] = value[idx];
}
}
}
}
};
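// Invoked per thread through SelectKernelImpl: out points at the thread's
// register buffer, and the functor compacts the (up to VecSize) values whose
// mask entry is true before they are flushed to global memory.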
template <typename T, typename Context>
void MaskedSelectKernel(const Context& dev_ctx,
@@ -68,42 +61,9 @@ void MaskedSelectKernel(const Context& dev_ctx,
"value.",
input_dim,
mask_dim));
thrust::device_ptr<const bool> mask_dev_ptr =
thrust::device_pointer_cast(mask_data);
thrust::device_vector<T> mask_vec(mask_dev_ptr, mask_dev_ptr + mask_size);
auto out_size = thrust::count(mask_vec.begin(), mask_vec.end(), true);
DDim out_dim{out_size};
out->Resize(out_dim);
auto out_data = out->mutable_data<T>(dev_ctx.GetPlace());
DenseTensor mask_array;
DenseTensor mask_prefix_sum;
mask_array.Resize(mask_dim);
mask_prefix_sum.Resize(mask_dim);
int32_t* mask_array_data =
mask_array.mutable_data<int32_t>(dev_ctx.GetPlace());
int32_t* mask_prefix_sum_data =
mask_prefix_sum.mutable_data<int32_t>(dev_ctx.GetPlace());
int threads = 512;
int grid = (mask_size + threads - 1) / threads;
auto stream = dev_ctx.stream();
SetMaskArray<<<grid, threads, 0, stream>>>(
mask_data, mask_array_data, mask_size);
thrust::device_ptr<int32_t> mask_array_dev_ptr =
thrust::device_pointer_cast(mask_array_data);
thrust::device_vector<int32_t> mask_array_vec(mask_array_dev_ptr,
mask_array_dev_ptr + mask_size);
thrust::exclusive_scan(thrust::device,
mask_array_vec.begin(),
mask_array_vec.end(),
mask_prefix_sum_data);
SelectWithPrefixMask<T><<<grid, threads, 0, stream>>>(
mask_prefix_sum_data, mask_data, input_data, out_data, mask_size);
using Functor = MaskedSelectFunctor<bool, T, T>;
phi::funcs::SelectKernel<bool, T, T, 1, Functor>(
dev_ctx, mask, x, out, Functor());
}
} // namespace phi
@@ -20,150 +20,59 @@
namespace cub = hipcub;
#endif
#include "paddle/phi/kernels/where_index_kernel.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/select_impl.cu.h"
#include "paddle/phi/kernels/where_index_kernel.h"
namespace phi {
template <typename T>
__global__ void GetTrueNum(const T *cond_data,
const int64_t numel,
int64_t *true_num_array) {
const int64_t tid = blockIdx.x * blockDim.x + threadIdx.x;
for (int64_t idx = tid; idx < numel; idx += gridDim.x * blockDim.x) {
true_num_array[idx] =
static_cast<int64_t>(static_cast<bool>(cond_data[idx]));
template <typename T1, typename T2, typename OutT>
struct IndexFunctor {
T2 stride[phi::DDim::kMaxRank];
int dims;
explicit IndexFunctor(const phi::DDim &in_dims) {
dims = in_dims.size();
std::vector<T2> strides_in_tmp;
strides_in_tmp.resize(dims, 1);
// get strides according to in_dims
for (T2 i = 1; i < dims; i++) {
strides_in_tmp[i] = strides_in_tmp[i - 1] * in_dims[dims - i];
}
memcpy(stride, strides_in_tmp.data(), dims * sizeof(T2));
}
}
template <typename T>
__global__ void SetTrueIndex(int64_t *out_ptr,
const T *cond_data,
const int64_t numel,
const int64_t *stride_array,
const int64_t rank,
const int64_t *true_num_array) {
const int64_t tid = blockIdx.x * blockDim.x + threadIdx.x;
for (int64_t idx = tid; idx < numel; idx += gridDim.x * blockDim.x) {
// true_num_array is calculated by cub::InclusiveSum,
// so the first true element already has the value 1;
// we need to subtract 1 to get the 0-based true index.
const int64_t true_index = true_num_array[idx] - 1;
if (static_cast<bool>(cond_data[idx])) {
int64_t rank_index = idx;
for (int j = 0; j < rank; j++) {
const int64_t out_index = rank_index / stride_array[j];
out_ptr[true_index * rank + j] = out_index;
rank_index -= out_index * stride_array[j];
HOSTDEVICE inline void operator()(OutT *out,
const T1 *mask,
const T2 *index,
const int num) {
int store_fix = 0;
for (int idx = 0; idx < num; idx++) {
if (mask[idx]) {
T2 data_index = index[idx];
// get index
for (int rank_id = dims - 1; rank_id >= 0; --rank_id) {
out[store_fix] = static_cast<OutT>(data_index / stride[rank_id]);
data_index = data_index % stride[rank_id];
store_fix++;
}
}
}
}
}
};
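// e.g. for a condition of shape (2, 3, 4): stride = {1, 4, 12}; a true
// element at flat index 6 is emitted as the coordinates (0, 1, 2).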
template <typename T, typename Context>
void WhereIndexKernel(const Context &dev_ctx,
const DenseTensor &condition,
DenseTensor *out) {
const T *cond_data = condition.data<T>();
const int64_t numel = condition.numel();
DenseTensor in_data;
auto dims = condition.dims();
const int rank = dims.size();
auto d_array_mem =
paddle::memory::Alloc(dev_ctx, (numel + rank) * sizeof(int64_t));
auto h_array_mem =
paddle::memory::Alloc(phi::CPUPlace(), (rank + 1) * sizeof(int64_t));
// "stride_array" is an array and len(stride_array)==rank,
// each element is the stride of each dimension -- the length from i to i+1.
int64_t *h_stride_array = reinterpret_cast<int64_t *>(h_array_mem->ptr());
int64_t *d_stride_array = reinterpret_cast<int64_t *>(d_array_mem->ptr());
// "true_num_array" is an array and len(stride_array)==numel,
// at the beginning,
// "true_num_array" will set 1 if condition[i] == true else 0,
// then it will be calculated by cub::InclusiveSum,
// so that we can get the true number before i as the out index
int64_t *d_true_num_array = d_stride_array + rank;
// the total_true_num is the total number of condition[i] == true
int64_t *h_total_true_num = h_stride_array + rank;
// alloc cub memory
size_t cub_size = 0;
cub::DeviceScan::InclusiveSum(nullptr,
cub_size,
d_true_num_array,
d_true_num_array,
numel,
dev_ctx.stream());
auto cub_mem = paddle::memory::Alloc(dev_ctx, cub_size * sizeof(int64_t));
void *cub_data = cub_mem->ptr();
// set d_true_num_array[i]=1 if cond_data[i]==true else 0
const int threads = std::min(numel, static_cast<int64_t>(128));
const int64_t need_grids = (numel + threads - 1) / threads;
const int grids = std::min(need_grids, static_cast<int64_t>(256));
GetTrueNum<T><<<grids, threads, 0, dev_ctx.stream()>>>(
cond_data, numel, d_true_num_array);
// calculate the inclusive prefix sum of "true_num_array"
// to get the index of "out" tensor,
// and the total number of cond_data[i]==true.
// Example:
// condition: F T T F F F T T
// before: 0 1 1 0 0 0 1 1
// after: 0 1 2 2 2 2 3 4
// out: 1 2 6 7
cub::DeviceScan::InclusiveSum(cub_data,
cub_size,
d_true_num_array,
d_true_num_array,
numel,
dev_ctx.stream());
// calculate each dimension's stride
h_stride_array[rank - 1] = 1;
for (int i = rank - 2; i >= 0; i--) {
h_stride_array[i] = h_stride_array[i + 1] * dims[i + 1];
}
paddle::memory::Copy(dev_ctx.GetPlace(),
d_stride_array,
phi::CPUPlace(),
h_stride_array,
rank * sizeof(int64_t),
dev_ctx.stream());
// get total true number and set output size
// the last element of cub::InclusiveSum is the total number
paddle::memory::Copy(phi::CPUPlace(),
h_total_true_num,
dev_ctx.GetPlace(),
d_true_num_array + numel - 1,
sizeof(int64_t),
dev_ctx.stream());
dev_ctx.Wait();
int64_t true_num = *h_total_true_num;
out->Resize(phi::make_ddim({static_cast<int64_t>(true_num), rank}));
auto *out_data = dev_ctx.template Alloc<int64_t>(out);
if (true_num == 0) {
return;
}
// using true_num_array and stride_array to calculate the output index
SetTrueIndex<T><<<grids, threads, 0, dev_ctx.stream()>>>(
out_data, cond_data, numel, d_stride_array, rank, d_true_num_array);
using Functor = IndexFunctor<T, int64_t, int64_t>;
Functor index_functor = Functor(dims);
phi::funcs::SelectKernel<T, T, int64_t, 0, Functor>(
dev_ctx, condition, in_data, out, index_functor);
}
} // namespace phi
PD_REGISTER_KERNEL(where_index,
@@ -22,7 +22,6 @@
#endif
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
// #include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/float16.h"
namespace phi {
@@ -591,7 +590,7 @@ __device__ __forceinline__ void Cumsum(OutT* out,
int index = (tidx + 1) * 2 * stride - 1;
if (index < (blockDim.x * 2)) {
temp[index + index / 32] =
compute(temp[index + index / 2],
compute(temp[index + index / 32],
temp[index - stride + (index - stride) / 32]);
}
}