Unverified commit f53db251, authored by niuliling123, committed by GitHub

Support MaskedSelectGrad op with Kernel Primitive API (#40617)

* Support MaskedSelectGrad op with Kernel Primitive API
Parent 34a256c9
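The hunks below generalize the shared `SelectCaller` in `select_impl.cu.h` so that one integer template tag (`MaskData` device-side, `SelectData` host-side) picks the operator: 0 = where_index, 1 = masked_select, 2 = the new masked_select_grad. As a minimal standalone sketch of that dispatch pattern (the tag values come from the diff; everything else here is illustrative, not Paddle's actual API):

// Minimal sketch: an integer template tag selects one of several
// behaviors at compile time, with no runtime branching.
#include <cstdio>

template <int MaskData>
struct SelectCaller {
  // Primary template covers MaskData == 0, the where_index path.
  void operator()() const { std::printf("where_index\n"); }
};

template <>
struct SelectCaller<1> {  // masked_select path
  void operator()() const { std::printf("masked_select\n"); }
};

template <>
struct SelectCaller<2> {  // masked_select_grad path
  void operator()() const { std::printf("masked_select_grad\n"); }
};

int main() {
  SelectCaller<0>()();  // prints "where_index"
  SelectCaller<2>()();  // prints "masked_select_grad"
}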
@@ -168,57 +168,82 @@ __global__ void CumsumOneBlock(
   }
 }
 
+// where_index
 template <typename OutT,
           typename MT,
           typename InT,
-          typename IdT,
           typename Functor,
           int VecSize,
           int IsBoundary,
-          int IsMaskData>
+          int MaskData>
 struct SelectCaller {
-  __device__ void inline operator()(OutT *store_data,
+  __device__ void inline operator()(OutT *out,
                                     const MT *mask_data,
                                     const InT *in,
                                     Functor func,
-                                    int num,
-                                    int data_offset) {
-    // where_index op
-    IdT index_reg[VecSize];
-    // Set data index of global
-    kps::InitWithDataIndex<IdT, VecSize, 1, 1>(&index_reg[0], data_offset);
+                                    int data_offset,
+                                    int store_num,
+                                    int thread_fix,
+                                    int num) {
+    int64_t in_data[VecSize];
+    OutT store_data[VecSize * phi::DDim::kMaxRank];
+    // set index
+    kps::InitWithDataIndex<int64_t, VecSize, 1, 1>(&in_data[0], data_offset);
     // Get store data according to mask_idt
-    kps::OperatorTernary<MT, IdT, OutT, Functor>(
-        store_data, mask_data, &index_reg[0], func, VecSize);
+    kps::OperatorTernary<MT, int64_t, OutT, Functor>(
+        store_data, mask_data, &in_data[0], func, VecSize);
+    kps::details::WriteData<OutT>(out + thread_fix, &store_data[0], store_num);
   }
 };
 
+// masked_select
 template <typename OutT,
           typename MT,
           typename InT,
-          typename IdT,
           typename Functor,
           int VecSize,
           int IsBoundary>
-struct SelectCaller<OutT,
-                    MT,
-                    InT,
-                    IdT,
-                    Functor,
-                    VecSize,
-                    IsBoundary,
-                    1> {  // masked_select
-  __device__ void inline operator()(OutT *store_data,
+struct SelectCaller<OutT, MT, InT, Functor, VecSize, IsBoundary, 1> {
+  __device__ void inline operator()(OutT *out,
                                     const MT *mask_data,
                                     const InT *in,
                                     Functor func,
-                                    int num,
-                                    int data_offset) {
+                                    int data_offset,
+                                    int store_num,
+                                    int thread_fix,
+                                    int num) {
     InT in_data[VecSize];
+    OutT store_data[VecSize * phi::DDim::kMaxRank];
     kps::ReadData<InT, VecSize, 1, 1, IsBoundary>(&in_data[0], in, num);
     // Get store data according to mask_idt
     kps::OperatorTernary<MT, InT, OutT, Functor>(
         store_data, mask_data, &in_data[0], func, VecSize);
+    kps::details::WriteData<OutT>(out + thread_fix, &store_data[0], store_num);
+  }
+};
+
+// masked_select_grad
+template <typename OutT,
+          typename MT,
+          typename InT,
+          typename Functor,
+          int VecSize,
+          int IsBoundary>
+struct SelectCaller<OutT, MT, InT, Functor, VecSize, IsBoundary, 2> {
+  __device__ void inline operator()(OutT *out,
+                                    const MT *mask_data,
+                                    const InT *in,
+                                    Functor func,
+                                    int data_offset,
+                                    int store_num,
+                                    int thread_fix,
+                                    int num) {
+    InT in_data[VecSize];
+    OutT store_data[VecSize * phi::DDim::kMaxRank];
+    kps::details::ReadData<InT>(&in_data[0], in + thread_fix, store_num);
+    kps::OperatorTernary<MT, InT, OutT, Functor>(
+        store_data, mask_data, &in_data[0], func, VecSize);
+    kps::WriteData<OutT, VecSize, 1, 1, IsBoundary>(out, &store_data[0], num);
   }
 };
@@ -253,7 +278,6 @@ SelectKernelImpl(OutT *out,
   IdT num_thread[kCVecSize];
   IdT cumsum_thread[kCVecSize];
-  OutT store_data[VecSize * phi::DDim::kMaxRank];
   MT mask_data[VecSize];
   IdT mask_idt[VecSize];
   // init data_pr
@@ -271,17 +295,15 @@ SelectKernelImpl(OutT *out,
   // Get cumsum_thread cumsum from 0 to num_thread cumsum_thread[0] is the
   // thread_fix
   kps::Cumsum<IdT, IdT, 1, Add>(&cumsum_thread[0], &num_thread[0], Add());
-  // Get store data(index) according to mask_idt
-  SelectCaller<OutT, MT, InT, IdT, Functor, VecSize, IsBoundary, MaskData>
-      compute;
-  compute(&store_data[0], &mask_data[0], in, func, num, data_offset);
   // get thread_fix
   int thread_fix =
       (static_cast<int>(cumsum_thread[0] - num_thread[0]) * store_rank);
   // get how many data need to store
   int store_num = static_cast<int>(num_thread[0]) * store_rank;
   // thread store num data, each thread may has different num
+  // Get store data(index) according to mask_idt
+  SelectCaller<OutT, MT, InT, Functor, VecSize, IsBoundary, MaskData> select;
+  select(out, mask_data, in, func, data_offset, store_num, thread_fix, num);
 }
 
 template <typename MT,
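In `SelectKernelImpl`, each thread first counts its masked hits (`num_thread`), then `kps::Cumsum` forms a block-wide running total: `thread_fix = (cumsum - count) * store_rank` is the exclusive prefix where the thread starts writing, and `store_num = count * store_rank` is how much it writes. A host-side sketch with made-up per-thread counts (illustrative, not `kps::Cumsum` itself):

// Host sketch of how thread_fix/store_num fall out of a cumulative sum.
// kps::Cumsum yields an inclusive scan, so the exclusive prefix is
// cumsum - count. Counts below are example data.
#include <cstdio>

int main() {
  const int store_rank = 1;                // 1 for masked_select; rank for where_index
  const int num_thread[4] = {2, 0, 3, 1};  // masked hits per thread

  int running = 0;
  for (int tid = 0; tid < 4; tid++) {
    running += num_thread[tid];                                 // inclusive scan
    int thread_fix = (running - num_thread[tid]) * store_rank;  // exclusive prefix
    int store_num = num_thread[tid] * store_rank;
    std::printf("thread %d writes %d values at offset %d\n",
                tid, store_num, thread_fix);
  }
}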
@@ -303,15 +325,17 @@ __global__ void SelectKernel(OutT *out,
   int stride = BLOCK_NUM_X * GRID_NUM_X * VecSize;
   int repeat = 0;
   int size = VecSize * BLOCK_ID_X;
+  CT block_store_offset = 0;
   for (; data_offset < main_offset; data_offset += stride) {
     // Cumsum index
     int idx_cumsum = repeat * GRID_NUM_X + BLOCK_ID_X;
-    // niuliling todo: us ReadData API
-    int block_store_offset = cumsum[idx_cumsum];
+    kps::details::ReadData<CT>(&block_store_offset, cumsum + idx_cumsum, 1);
+    int out_fix = MaskData < 2 ? block_store_offset * store_rank : data_offset;
+    int in_fix = MaskData < 2 ? data_offset : block_store_offset * store_rank;
     SelectKernelImpl<InT, MT, OutT, Functor, VecSize, MaskData, false>(
-        out + block_store_offset * store_rank,
+        out + out_fix,
         mask + data_offset,
-        in + data_offset,
+        in + in_fix,
         func,
         size,
         data_offset,
@@ -323,12 +347,13 @@ __global__ void SelectKernel(OutT *out,
   if (num > 0) {
     // Cumsum index
     int idx_cumsum = repeat * GRID_NUM_X + BLOCK_ID_X;
-    // niuliling todo: us ReadData API
-    int block_store_offset = static_cast<int>(cumsum[idx_cumsum]);
+    kps::details::ReadData<CT>(&block_store_offset, cumsum + idx_cumsum, 1);
+    int out_fix = MaskData < 2 ? block_store_offset * store_rank : data_offset;
+    int in_fix = MaskData < 2 ? data_offset : block_store_offset * store_rank;
     SelectKernelImpl<InT, MT, OutT, Functor, VecSize, MaskData, true>(
-        out + block_store_offset * store_rank,
+        out + out_fix,
         mask + data_offset,
-        in + data_offset,
+        in + in_fix,
         func,
         num,
         data_offset,
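The new `out_fix`/`in_fix` pair encodes which side of the kernel is compacted. For `MaskData < 2` (where_index, masked_select) a block reads a dense tile at `data_offset` and writes its hits at the block's cumsum offset; for `MaskData == 2` (masked_select_grad) the direction flips. A small sketch with illustrative offsets:

// Sketch of the out_fix / in_fix selection (values illustrative).
// MaskData < 2  (where_index, masked_select): dense read, compacted write.
// MaskData == 2 (masked_select_grad):         compacted read, dense write.
#include <cstdio>

int offsets(int mask_data, int data_offset, int block_store_offset,
            int store_rank, int* in_fix) {
  int out_fix = mask_data < 2 ? block_store_offset * store_rank : data_offset;
  *in_fix = mask_data < 2 ? data_offset : block_store_offset * store_rank;
  return out_fix;
}

int main() {
  int in_fix = 0;
  // Forward: the block writes its hits starting at the cumsum offset (8).
  int out_fix = offsets(/*mask_data=*/1, /*data_offset=*/256,
                        /*block_store_offset=*/8, /*store_rank=*/1, &in_fix);
  std::printf("forward:  read %d, write %d\n", in_fix, out_fix);  // 256, 8
  // Backward: the block reads compacted grads from the cumsum offset (8).
  out_fix = offsets(2, 256, 8, 1, &in_fix);
  std::printf("backward: read %d, write %d\n", in_fix, out_fix);  // 8, 256
}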
@@ -402,6 +427,7 @@ void SelectKernel(const KPDevice &dev_ctx,
   const int kCumVesize = 2;
   const int block_c = 256;
   const int main_offset_c = Floor(size_count_block, (kCumVesize * block_c));
+
   using Add = kps::AddFunctor<CT>;
   CumsumOneBlock<CT, CT, Add, kCumVesize><<<1, block_c, 0, stream>>>(
       count_data, cumsum_data, size_count_block, main_offset_c, Add());
@@ -418,10 +444,13 @@ void SelectKernel(const KPDevice &dev_ctx,
   dev_ctx.Wait();
   // 3.1.2 allock for out with total_true_num
   std::vector<int64_t> out_dim = {static_cast<int64_t>(total_true_num)};
-  if (SelectData == 0) {  // where_index
+
+  if (SelectData == 1) {
+    out->Resize(phi::make_ddim(out_dim));
+  } else if (SelectData == 0) {  // == 0 where_index
     out_dim.push_back(rank);
-  }
-  out->Resize(phi::make_ddim(out_dim));
+    out->Resize(phi::make_ddim(out_dim));
+  }
   auto out_data = out->mutable_data<OutT>(cuda_place);
   // 3.2 get true data's index according to cond_data and cumsum_data
   if (total_true_num <= 0) return;
...
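The resize logic distinguishes the callers by `SelectData`: masked_select (1) keeps the flat `[total_true_num]` shape, where_index (0) appends the rank to get `[total_true_num, rank]` coordinates, and masked_select_grad (2) takes neither branch because `x_grad` already carries x's shape. A toy reproduction of just the shape computation (illustrative values):

// Shape bookkeeping for the two paths that resize the output.
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  const int64_t total_true_num = 5;
  const int64_t rank = 3;
  const int select_data = 0;  // 0 = where_index, 1 = masked_select

  std::vector<int64_t> out_dim = {total_true_num};
  if (select_data == 0) {
    out_dim.push_back(rank);  // where_index emits rank coordinates per hit
  }
  for (int64_t d : out_dim) {
    std::printf("%lld ", static_cast<long long>(d));
  }
  std::printf("\n");  // "5 3" for where_index, "5" for masked_select
}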
@@ -17,38 +17,31 @@
 #include <thrust/reverse.h>
 #include <thrust/scan.h>
 
-#include "paddle/phi/backends/cpu/cpu_context.h"
-#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/select_impl.cu.h"
 #include "paddle/phi/kernels/masked_select_grad_kernel.h"
 
+#include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+
 namespace phi {
 
-__global__ void SetMaskArrayT(const bool* mask, int32_t* mask_array, int size) {
-  int idx = blockDim.x * blockIdx.x + threadIdx.x;
-  for (; idx < size; idx += blockDim.x * gridDim.x) {
-    if (mask[idx])
-      mask_array[idx] = 1;
-    else
-      mask_array[idx] = 0;
-  }
-}
+template <typename MT, typename InT, typename OutT>
+struct MaskedSelectGradFunctor {
+  HOSTDEVICE MaskedSelectGradFunctor() {}
 
-template <typename T>
-__global__ void SelectGradWithPrefixMask(const int32_t* mask_prefix_sum,
-                                         const bool* mask,
-                                         const T* input,
-                                         T* out,
-                                         int size) {
-  int idx = blockDim.x * blockIdx.x + threadIdx.x;
-  for (; idx < size; idx += blockDim.x * gridDim.x) {
-    if (mask[idx]) {
-      int index = mask_prefix_sum[idx];
-      out[idx] = input[index];
-    } else {
-      out[idx] = 0;
-    }
-  }
-}
+  HOSTDEVICE inline void operator()(OutT* out,
+                                    const MT* mask,
+                                    const InT* value,
+                                    int num) {
+    int read_fix = 0;
+    for (int idx = 0; idx < num; idx++) {
+      if (mask[idx]) {
+        out[idx] = value[read_fix++];
+      } else {
+        out[idx] = 0;
+      }
+    }
+  }
+};
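`MaskedSelectGradFunctor` replaces the old prefix-sum kernels: walking a tile once, it scatters the compacted upstream gradients back to the true-mask slots and zeros the rest. A host-side sketch of the same semantics (the mirror struct below drops the HOSTDEVICE qualifier so it compiles as plain C++; data is illustrative):

// Host-side demo of the scatter semantics. Inside SelectKernel the functor
// runs per-thread on a VecSize-wide tile; here it covers a whole array.
#include <cstdio>

template <typename MT, typename InT, typename OutT>
struct MaskedSelectGradFunctor {  // mirror of the functor in the diff
  inline void operator()(OutT* out, const MT* mask, const InT* value,
                         int num) const {
    int read_fix = 0;  // running index into the compacted gradients
    for (int idx = 0; idx < num; idx++) {
      out[idx] = mask[idx] ? static_cast<OutT>(value[read_fix++])
                           : static_cast<OutT>(0);
    }
  }
};

int main() {
  const bool mask[6] = {true, false, true, true, false, false};
  const float out_grad[3] = {10.f, 20.f, 30.f};  // one value per true entry
  float x_grad[6];
  MaskedSelectGradFunctor<bool, float, float>()(x_grad, mask, out_grad, 6);
  for (float g : x_grad) std::printf("%g ", g);  // 10 0 20 30 0 0
  std::printf("\n");
}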
 template <typename T, typename Context>
 void MaskedSelectGradKernel(const Context& dev_ctx,
@@ -56,42 +49,12 @@ void MaskedSelectGradKernel(const Context& dev_ctx,
                             const DenseTensor& x,
                             const DenseTensor& mask,
                             DenseTensor* x_grad) {
-  auto* mask_data = mask.data<bool>();
-  auto* input_data = out_grad.data<T>();
-  auto* out_data = x_grad->mutable_data<T>(dev_ctx.GetPlace());
-  auto input_size = out_grad.numel();
   auto mask_size = mask.numel();
-  auto mask_dim = mask.dims();
-
-  auto out_size = mask_size;
-
-  DenseTensor mask_array;
-  DenseTensor mask_prefix_sum;
-  mask_array.Resize(mask_dim);
-  mask_prefix_sum.Resize(mask_dim);
-  int32_t* mask_array_data =
-      mask_array.mutable_data<int32_t>(dev_ctx.GetPlace());
-  int32_t* mask_prefix_sum_data =
-      mask_prefix_sum.mutable_data<int32_t>(dev_ctx.GetPlace());
-
-  int threads = 512;
-  int grid = (mask_size + threads - 1) / threads;
-  auto stream = dev_ctx.stream();
-  SetMaskArrayT<<<grid, threads, 0, stream>>>(
-      mask_data, mask_array_data, mask_size);
-
-  thrust::device_ptr<int32_t> mask_array_dev_ptr =
-      thrust::device_pointer_cast(mask_array_data);
-  thrust::device_vector<int32_t> mask_array_vec(mask_array_dev_ptr,
-                                                mask_array_dev_ptr + mask_size);
-  thrust::exclusive_scan(thrust::device,
-                         mask_array_vec.begin(),
-                         mask_array_vec.end(),
-                         mask_prefix_sum_data);
-
-  SelectGradWithPrefixMask<T><<<grid, threads, 0, stream>>>(
-      mask_prefix_sum_data, mask_data, input_data, out_data, mask_size);
+  auto* out_data = x_grad->mutable_data<T>(dev_ctx.GetPlace());
+  if (mask_size <= 0) return;
+  using Functor = MaskedSelectGradFunctor<bool, T, T>;
+  phi::funcs::SelectKernel<bool, T, T, 2, Functor>(
+      dev_ctx, mask, out_grad, x_grad, Functor());
 }
 
 }  // namespace phi
...
@@ -17,11 +17,12 @@
 #include <thrust/reverse.h>
 #include <thrust/scan.h>
 
-#include "paddle/phi/backends/gpu/gpu_context.h"
-#include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/funcs/select_impl.cu.h"
 #include "paddle/phi/kernels/masked_select_kernel.h"
 
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+
 namespace phi {
 
 template <typename MT, typename InT, typename OutT>
...
@@ -20,13 +20,14 @@
 namespace cub = hipcub;
 #endif
 
-#include "paddle/phi/core/ddim.h"
-#include "paddle/phi/core/dense_tensor.h"
-#include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
 #include "paddle/phi/kernels/funcs/select_impl.cu.h"
 #include "paddle/phi/kernels/where_index_kernel.h"
 
+#include "paddle/phi/core/ddim.h"
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/kernel_registry.h"
+
 namespace phi {
 template <typename T1, typename T2, typename OutT>
 struct IndexFunctor {
...
@@ -123,6 +123,15 @@ __device__ __forceinline__ void WriteData(T* dst,
     dst[i] = src[i];
   }
 }
+
+template <typename T>
+__device__ __forceinline__ void ReadData(T* dst,
+                                         const T* __restrict__ src,
+                                         int num) {
+  for (int i = 0; i < num; i++) {
+    dst[i] = src[i];
+  }
+}
 #undef INT_BITS
 }  // namespace details
...
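The new `kps::details::ReadData` is the read-side twin of `details::WriteData` directly above it: a scalar loop over a runtime-sized run, in contrast to the vectorized `kps::ReadData<T, VecSize, ...>` tile loader whose width is fixed at compile time. That is what lets the grad specialization pull `store_num` compacted values from `in + thread_fix`. A host-side equivalent, with illustrative sizes:

// Host-side equivalent of the new details::ReadData: a plain element-wise
// copy of a runtime-sized run (here, from a compacted gradient buffer into
// a thread-local tile). Sizes and data are illustrative.
#include <cstdio>

template <typename T>
inline void ReadDataHost(T* dst, const T* src, int num) {
  for (int i = 0; i < num; i++) {
    dst[i] = src[i];  // num is a runtime value, so no fixed-size vector load
  }
}

int main() {
  const float compacted[5] = {1.f, 2.f, 3.f, 4.f, 5.f};
  float in_data[4];                   // register tile, VecSize = 4
  int thread_fix = 2, store_num = 3;  // this thread's slice of the buffer
  ReadDataHost(in_data, compacted + thread_fix, store_num);
  for (int i = 0; i < store_num; i++) std::printf("%g ", in_data[i]);  // 3 4 5
  std::printf("\n");
}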