Unverified · Commit f53db251 · authored by niuliling123, committed by GitHub

Support MaskedSelectGrad op with Kernel Primitive API (#40617)

* Support MaskedSelectGrad op with Kernel Primitive API
Parent 34a256c9
@@ -168,57 +168,82 @@ __global__ void CumsumOneBlock(
}
}
// where_index
template <typename OutT,
          typename MT,
          typename InT,
-         typename IdT,
          typename Functor,
          int VecSize,
          int IsBoundary,
-         int IsMaskData>
+         int MaskData>
struct SelectCaller {
-  __device__ void inline operator()(OutT *store_data,
+  __device__ void inline operator()(OutT *out,
                                    const MT *mask_data,
                                    const InT *in,
                                    Functor func,
-                                    int num,
-                                    int data_offset) {
-    // where_index op
-    IdT index_reg[VecSize];
-    // Set data index of global
-    kps::InitWithDataIndex<IdT, VecSize, 1, 1>(&index_reg[0], data_offset);
+                                    int data_offset,
+                                    int store_num,
+                                    int thread_fix,
+                                    int num) {
+    int64_t in_data[VecSize];
+    OutT store_data[VecSize * phi::DDim::kMaxRank];
+    // set index
+    kps::InitWithDataIndex<int64_t, VecSize, 1, 1>(&in_data[0], data_offset);
    // Get store data according to mask_idt
-    kps::OperatorTernary<MT, IdT, OutT, Functor>(
-        store_data, mask_data, &index_reg[0], func, VecSize);
+    kps::OperatorTernary<MT, int64_t, OutT, Functor>(
+        store_data, mask_data, &in_data[0], func, VecSize);
+    kps::details::WriteData<OutT>(out + thread_fix, &store_data[0], store_num);
}
};
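For orientation: in the where_index path above, each tile is seeded with global flat indices (`InitWithDataIndex`), and `OperatorTernary` applies the functor to expand every masked flat index into `rank` coordinates (the job of `IndexFunctor`, whose body is truncated further below). A minimal host-side C++ sketch of that flat-index-to-coordinate step, assuming row-major strides for a hypothetical 2x3 tensor:

#include <cstdio>

int main() {
  const int rank = 2;
  const int64_t stride[rank] = {3, 1};  // assumed row-major strides of a 2x3 shape
  int64_t flat = 4;                     // flat index of element (1, 1)
  int64_t coord[rank];
  for (int d = 0; d < rank; ++d) {      // peel off one dimension per step
    coord[d] = flat / stride[d];
    flat = flat % stride[d];
  }
  printf("(%lld, %lld)\n", (long long)coord[0], (long long)coord[1]);  // (1, 1)
  return 0;
}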
// masked_select
template <typename OutT,
          typename MT,
          typename InT,
-         typename IdT,
          typename Functor,
          int VecSize,
          int IsBoundary>
-struct SelectCaller<OutT,
-                    MT,
-                    InT,
-                    IdT,
-                    Functor,
-                    VecSize,
-                    IsBoundary,
-                    1> {  // masked_select
-  __device__ void inline operator()(OutT *store_data,
+struct SelectCaller<OutT, MT, InT, Functor, VecSize, IsBoundary, 1> {
+  __device__ void inline operator()(OutT *out,
                                    const MT *mask_data,
                                    const InT *in,
                                    Functor func,
-                                    int num,
-                                    int data_offset) {
+                                    int data_offset,
+                                    int store_num,
+                                    int thread_fix,
+                                    int num) {
    InT in_data[VecSize];
+    OutT store_data[VecSize * phi::DDim::kMaxRank];
    kps::ReadData<InT, VecSize, 1, 1, IsBoundary>(&in_data[0], in, num);
    // Get store data according to mask_idt
    kps::OperatorTernary<MT, InT, OutT, Functor>(
        store_data, mask_data, &in_data[0], func, VecSize);
+    kps::details::WriteData<OutT>(out + thread_fix, &store_data[0], store_num);
}
};
// masked_select_grad
template <typename OutT,
typename MT,
typename InT,
typename Functor,
int VecSize,
int IsBoundary>
struct SelectCaller<OutT, MT, InT, Functor, VecSize, IsBoundary, 2> {
__device__ void inline operator()(OutT *out,
const MT *mask_data,
const InT *in,
Functor func,
int data_offset,
int store_num,
int thread_fix,
int num) {
InT in_data[VecSize];
OutT store_data[VecSize * phi::DDim::kMaxRank];
kps::details::ReadData<InT>(&in_data[0], in + thread_fix, store_num);
kps::OperatorTernary<MT, InT, OutT, Functor>(
store_data, mask_data, &in_data[0], func, VecSize);
kps::WriteData<OutT, VecSize, 1, 1, IsBoundary>(out, &store_data[0], num);
}
};
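The three `MaskData` specializations differ mainly in which side of the copy is compacted: where_index (0) and masked_select (1) write a compacted output at the cumsum-derived `thread_fix` via `kps::details::WriteData`, while masked_select_grad (2) instead reads the compacted values at `thread_fix` via `kps::details::ReadData` and writes a dense, boundary-guarded output. A host-side sketch of the forward (MaskData = 1) tile flow, with made-up values and illustrative names:

#include <cstdio>

int main() {
  const int vec_size = 4;
  bool mask[vec_size] = {true, false, true, true};
  float in[vec_size] = {10.f, 20.f, 30.f, 40.f};
  float out[8] = {0.f};  // global compacted output
  int thread_fix = 2;    // this thread's exclusive-cumsum offset
  int store_num = 0;
  float store_data[vec_size];
  for (int i = 0; i < vec_size; ++i)   // roughly OperatorTernary's role
    if (mask[i]) store_data[store_num++] = in[i];
  for (int i = 0; i < store_num; ++i)  // roughly details::WriteData's role
    out[thread_fix + i] = store_data[i];
  printf("%g %g %g at offset %d\n", out[2], out[3], out[4], thread_fix);
  return 0;
}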
@@ -253,7 +278,6 @@ SelectKernelImpl(OutT *out,
  IdT num_thread[kCVecSize];
  IdT cumsum_thread[kCVecSize];
-  OutT store_data[VecSize * phi::DDim::kMaxRank];
MT mask_data[VecSize];
IdT mask_idt[VecSize];
// init data_pr
@@ -271,17 +295,15 @@ SelectKernelImpl(OutT *out,
// Get cumsum_thread cumsum from 0 to num_thread cumsum_thread[0] is the
// thread_fix
kps::Cumsum<IdT, IdT, 1, Add>(&cumsum_thread[0], &num_thread[0], Add());
-    // Get store data(index) according to mask_idt
-    SelectCaller<OutT, MT, InT, IdT, Functor, VecSize, IsBoundary, MaskData>
-        compute;
-    compute(&store_data[0], &mask_data[0], in, func, num, data_offset);
    // get thread_fix
    int thread_fix =
        (static_cast<int>(cumsum_thread[0] - num_thread[0]) * store_rank);
    // get how many data need to store
    int store_num = static_cast<int>(num_thread[0]) * store_rank;
-    // thread store num data, each thread may has different num
-    kps::details::WriteData<OutT>(out + thread_fix, &store_data[0], store_num);
+    // Get store data (index) according to mask_idt
+    SelectCaller<OutT, MT, InT, Functor, VecSize, IsBoundary, MaskData> select;
+    select(out, mask_data, in, func, data_offset, store_num, thread_fix, num);
}
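The offset arithmetic above is the standard inclusive-to-exclusive cumsum trick: `kps::Cumsum` yields each thread's inclusive running total, so subtracting the thread's own count gives its write offset. A small host-side sketch with made-up counts:

#include <cstdio>

int main() {
  const int store_rank = 1;            // elements written per matched item
  int num_thread[4] = {2, 0, 3, 1};    // matches found by threads 0..3
  int inclusive = 0;
  for (int t = 0; t < 4; ++t) {
    inclusive += num_thread[t];        // what kps::Cumsum hands each thread
    int thread_fix = (inclusive - num_thread[t]) * store_rank;  // exclusive offset
    int store_num = num_thread[t] * store_rank;
    printf("thread %d writes %d elements at offset %d\n", t, store_num, thread_fix);
  }
  return 0;
}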
template <typename MT,
@@ -303,15 +325,17 @@ __global__ void SelectKernel(OutT *out,
int stride = BLOCK_NUM_X * GRID_NUM_X * VecSize;
int repeat = 0;
int size = VecSize * BLOCK_ID_X;
+  CT block_store_offset = 0;
  for (; data_offset < main_offset; data_offset += stride) {
    // Cumsum index
    int idx_cumsum = repeat * GRID_NUM_X + BLOCK_ID_X;
-    // niuliling todo: us ReadData API
-    int block_store_offset = cumsum[idx_cumsum];
+    kps::details::ReadData<CT>(&block_store_offset, cumsum + idx_cumsum, 1);
+    int out_fix = MaskData < 2 ? block_store_offset * store_rank : data_offset;
+    int in_fix = MaskData < 2 ? data_offset : block_store_offset * store_rank;
    SelectKernelImpl<InT, MT, OutT, Functor, VecSize, MaskData, false>(
-        out + block_store_offset * store_rank,
+        out + out_fix,
        mask + data_offset,
-        in + data_offset,
+        in + in_fix,
func,
size,
data_offset,
@@ -323,12 +347,13 @@ __global__ void SelectKernel(OutT *out,
if (num > 0) {
// Cumsum index
int idx_cumsum = repeat * GRID_NUM_X + BLOCK_ID_X;
-    // niuliling todo: us ReadData API
-    int block_store_offset = static_cast<int>(cumsum[idx_cumsum]);
+    kps::details::ReadData<CT>(&block_store_offset, cumsum + idx_cumsum, 1);
+    int out_fix = MaskData < 2 ? block_store_offset * store_rank : data_offset;
+    int in_fix = MaskData < 2 ? data_offset : block_store_offset * store_rank;
    SelectKernelImpl<InT, MT, OutT, Functor, VecSize, MaskData, true>(
-        out + block_store_offset * store_rank,
+        out + out_fix,
        mask + data_offset,
-        in + data_offset,
+        in + in_fix,
func,
num,
data_offset,
@@ -402,6 +427,7 @@ void SelectKernel(const KPDevice &dev_ctx,
const int kCumVesize = 2;
const int block_c = 256;
const int main_offset_c = Floor(size_count_block, (kCumVesize * block_c));
+  using Add = kps::AddFunctor<CT>;
CumsumOneBlock<CT, CT, Add, kCumVesize><<<1, block_c, 0, stream>>>(
count_data, cumsum_data, size_count_block, main_offset_c, Add());
@@ -418,10 +444,13 @@ void SelectKernel(const KPDevice &dev_ctx,
dev_ctx.Wait();
// 3.1.2 allock for out with total_true_num
std::vector<int64_t> out_dim = {static_cast<int64_t>(total_true_num)};
-  if (SelectData == 0) {  // where_index
+  if (SelectData == 1) {
+    out->Resize(phi::make_ddim(out_dim));
+  } else if (SelectData == 0) {  // == 0 where_index
    out_dim.push_back(rank);
+    out->Resize(phi::make_ddim(out_dim));
  }
-  out->Resize(phi::make_ddim(out_dim));
auto out_data = out->mutable_data<OutT>(cuda_place);
// 3.2 get true data's index according to cond_data and cumsum_data
if (total_true_num <= 0) return;
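The branch above sizes the output per mode (assuming `SelectData` mirrors `MaskData`): masked_select returns a 1-D tensor of the matched values, where_index returns one `rank`-wide index row per match, and masked_select_grad falls through both branches because `x_grad` keeps the dense shape of `x`. A sketch with a hypothetical 2x3 mask containing 4 true entries:

#include <cstdio>
#include <vector>

int main() {
  int64_t total_true_num = 4;
  int64_t rank = 2;
  std::vector<int64_t> masked_select_dim = {total_true_num};      // {4}
  std::vector<int64_t> where_index_dim = {total_true_num, rank};  // {4, 2}
  // masked_select_grad (SelectData == 2): no resize here; x_grad stays {2, 3}.
  printf("masked_select: {%lld}, where_index: {%lld, %lld}\n",
         (long long)masked_select_dim[0],
         (long long)where_index_dim[0], (long long)where_index_dim[1]);
  return 0;
}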
......
@@ -17,38 +17,31 @@
#include <thrust/reverse.h>
#include <thrust/scan.h>
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/select_impl.cu.h"
#include "paddle/phi/kernels/masked_select_grad_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
-__global__ void SetMaskArrayT(const bool* mask, int32_t* mask_array, int size) {
-  int idx = blockDim.x * blockIdx.x + threadIdx.x;
-  for (; idx < size; idx += blockDim.x * gridDim.x) {
-    if (mask[idx])
-      mask_array[idx] = 1;
-    else
-      mask_array[idx] = 0;
-  }
-}
-
-template <typename T>
-__global__ void SelectGradWithPrefixMask(const int32_t* mask_prefix_sum,
-                                         const bool* mask,
-                                         const T* input,
-                                         T* out,
-                                         int size) {
-  int idx = blockDim.x * blockIdx.x + threadIdx.x;
-  for (; idx < size; idx += blockDim.x * gridDim.x) {
-    if (mask[idx]) {
-      int index = mask_prefix_sum[idx];
-      out[idx] = input[index];
-    } else {
-      out[idx] = 0;
-    }
-  }
-}
+template <typename MT, typename InT, typename OutT>
+struct MaskedSelectGradFunctor {
+  HOSTDEVICE MaskedSelectGradFunctor() {}
+
+  HOSTDEVICE inline void operator()(OutT* out,
+                                    const MT* mask,
+                                    const InT* value,
+                                    int num) {
+    int read_fix = 0;
+    for (int idx = 0; idx < num; idx++) {
+      if (mask[idx]) {
+        out[idx] = value[read_fix++];
+      } else {
+        out[idx] = 0;
+      }
+    }
+  }
+};
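Since the functor is HOSTDEVICE, its scatter semantics are easy to check on the CPU. A self-contained copy of the same logic (renamed so the sketch stands alone): with mask {0,1,1,0,1} and compacted gradients {1, 2, 3}, the dense gradient comes out as {0, 1, 2, 0, 3}:

#include <cstdio>

template <typename MT, typename InT, typename OutT>
struct MaskedSelectGradFunctorSketch {  // mirrors the functor above
  inline void operator()(OutT* out, const MT* mask, const InT* value, int num) {
    int read_fix = 0;  // walks the compacted value buffer
    for (int idx = 0; idx < num; idx++) {
      out[idx] = mask[idx] ? static_cast<OutT>(value[read_fix++])
                           : static_cast<OutT>(0);
    }
  }
};

int main() {
  bool mask[5] = {false, true, true, false, true};
  float grads[3] = {1.f, 2.f, 3.f};  // compacted upstream gradients
  float x_grad[5];
  MaskedSelectGradFunctorSketch<bool, float, float>()(x_grad, mask, grads, 5);
  for (int i = 0; i < 5; ++i) printf("%g ", x_grad[i]);  // 0 1 2 0 3
  printf("\n");
  return 0;
}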
template <typename T, typename Context>
void MaskedSelectGradKernel(const Context& dev_ctx,
@@ -56,42 +49,12 @@ void MaskedSelectGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& mask,
DenseTensor* x_grad) {
-  auto* mask_data = mask.data<bool>();
-  auto* input_data = out_grad.data<T>();
-  auto* out_data = x_grad->mutable_data<T>(dev_ctx.GetPlace());
-  auto input_size = out_grad.numel();
  auto mask_size = mask.numel();
-  auto mask_dim = mask.dims();
-  auto out_size = mask_size;
-  DenseTensor mask_array;
-  DenseTensor mask_prefix_sum;
-  mask_array.Resize(mask_dim);
-  mask_prefix_sum.Resize(mask_dim);
-  int32_t* mask_array_data =
-      mask_array.mutable_data<int32_t>(dev_ctx.GetPlace());
-  int32_t* mask_prefix_sum_data =
-      mask_prefix_sum.mutable_data<int32_t>(dev_ctx.GetPlace());
-  int threads = 512;
-  int grid = (mask_size + threads - 1) / threads;
-  auto stream = dev_ctx.stream();
-  SetMaskArrayT<<<grid, threads, 0, stream>>>(
-      mask_data, mask_array_data, mask_size);
-  thrust::device_ptr<int32_t> mask_array_dev_ptr =
-      thrust::device_pointer_cast(mask_array_data);
-  thrust::device_vector<int32_t> mask_array_vec(mask_array_dev_ptr,
-                                                mask_array_dev_ptr + mask_size);
-  thrust::exclusive_scan(thrust::device,
-                         mask_array_vec.begin(),
-                         mask_array_vec.end(),
-                         mask_prefix_sum_data);
-  SelectGradWithPrefixMask<T><<<grid, threads, 0, stream>>>(
-      mask_prefix_sum_data, mask_data, input_data, out_data, mask_size);
+  auto* out_data = x_grad->mutable_data<T>(dev_ctx.GetPlace());
+  if (mask_size <= 0) return;
+  using Functor = MaskedSelectGradFunctor<bool, T, T>;
+  phi::funcs::SelectKernel<bool, T, T, 2, Functor>(
+      dev_ctx, mask, out_grad, x_grad, Functor());
}
} // namespace phi
......
@@ -17,11 +17,12 @@
#include <thrust/reverse.h>
#include <thrust/scan.h>
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/select_impl.cu.h"
#include "paddle/phi/kernels/masked_select_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename MT, typename InT, typename OutT>
......
@@ -20,13 +20,14 @@
namespace cub = hipcub;
#endif
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/select_impl.cu.h"
#include "paddle/phi/kernels/where_index_kernel.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T1, typename T2, typename OutT>
struct IndexFunctor {
......
@@ -123,6 +123,15 @@ __device__ __forceinline__ void WriteData(T* dst,
dst[i] = src[i];
}
}
template <typename T>
__device__ __forceinline__ void ReadData(T* dst,
const T* __restrict__ src,
int num) {
for (int i = 0; i < num; i++) {
dst[i] = src[i];
}
}
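This new `details::ReadData` is the read-side mirror of `details::WriteData` above: a plain scalar copy of `num` elements, needed because in the masked_select_grad path each thread reads a different number of compacted values, which the vectorized `kps::ReadData` overloads cannot express. A host-side sketch of how the grad specialization uses it, with made-up offsets:

#include <cstdio>

template <typename T>
void read_data(T* dst, const T* src, int num) {  // same loop as above
  for (int i = 0; i < num; i++) dst[i] = src[i];
}

int main() {
  float compacted[5] = {1.f, 2.f, 3.f, 4.f, 5.f};  // compacted gradients
  float tile[4];
  int thread_fix = 2, store_num = 3;  // this thread's slice: {3, 4, 5}
  read_data(tile, compacted + thread_fix, store_num);
  printf("%g %g %g\n", tile[0], tile[1], tile[2]);
  return 0;
}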
#undef INT_BITS
} // namespace details
......