Unverified · Commit 1aa64d13, authored by zhangkaihuo, committed by GitHub

[Sparse]optimize sparse convolution and fix MaskHelper bug (#47703)

Parent 5c7fce47
@@ -26,6 +26,19 @@ __global__ void DistanceKernel(const T* start, const T* end, T* distance) {
   }
 }
 
+inline __device__ bool SetBits(const int value, int* ptr) {
+  const int index = value >> 5;
+  const int mask = 1 << (value & 31);
+  const int old = atomicOr(ptr + index, mask);
+  return (mask & old) != 0;
+}
+
+inline __device__ bool TestBits(const int value, const int* ptr) {
+  const int index = value >> 5;
+  const int mask = 1 << (value & 31);
+  return (mask & ptr[index]) != 0;
+}
+
 }  // namespace sparse
 }  // namespace funcs
 }  // namespace phi
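
Note: the SetBits/TestBits helpers added above implement a packed bitmask. value >> 5 selects which 32-bit word holds the flag, value & 31 selects the bit inside that word, and SetBits reports whether the bit was already set (atomicOr returns the old word). Below is a minimal host-side sketch of the same arithmetic, for illustration only; it is plain C++ and the names are hypothetical, not part of the patch.

#include <vector>

// One int stores 32 flags: word = value >> 5, bit = value & 31.
bool SetBitHost(int value, std::vector<int>* flags) {
  const int word = value >> 5;
  const int mask = 1 << (value & 31);
  const bool was_set = ((*flags)[word] & mask) != 0;  // old state of the bit
  (*flags)[word] |= mask;  // the device helper does this with atomicOr
  return was_set;
}

bool TestBitHost(int value, const std::vector<int>& flags) {
  return (flags[value >> 5] & (1 << (value & 31))) != 0;
}

// Usage sketch: flags covering N indices need only (N + 31) / 32 words, e.g.
//   std::vector<int> flags((1000 + 31) / 32, 0);  // 32 ints cover 1000 indices
//   SetBitHost(7, &flags);   // returns false: bit 7 was clear
//   TestBitHost(7, flags);   // returns true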
@@ -167,7 +167,7 @@ inline void GatherV2(const GPUContext& dev_ctx,
 template <typename IntT>
 __global__ void UniqueKernel(const IntT* in_indexs,
                              const int rulebook_len,
-                             int* out_index_table,
+                             int* index_flags,
                              int* out_indexs,
                              int* nnz) {
   extern __shared__ int cache[];
@@ -182,8 +182,8 @@ __global__ void UniqueKernel(const IntT* in_indexs,
   if (i < rulebook_len) {
     // atomicOr only support int
     int index = static_cast<int>(in_indexs[i]);
-    int flag = atomicOr(out_index_table + index, 1);
-    if (flag == 0) {
+    const bool flag = phi::funcs::sparse::SetBits(index, index_flags);
+    if (!flag) {
       int j = atomicAdd(&count, 1);
       cache[j] = index;
     }
@@ -284,7 +284,6 @@ __global__ void ProductRuleBookKernel(const T* x_indices,
         atomicAdd(&counter_buf[kernel_index], 1);
         kernel_i = kernel_index;
       }
-      // rulebook[kernel_index * non_zero_num + i] = kernel_i;
       rulebook[kernel_index * non_zero_num + i] = in_i;
       rulebook[kernel_index * non_zero_num + offset + i] = out_index;
       ++kernel_index;
@@ -299,17 +298,19 @@ __global__ void ProductRuleBookKernel(const T* x_indices,
 }
 
 template <typename IntT>
-__global__ void GetOutIndexTable(const IntT* indices,
-                                 const IntT non_zero_num,
-                                 const Dims4D dims,
-                                 int* out_index_table) {
+__global__ void GetOutIndexTable1(const IntT* indices,
+                                  const IntT non_zero_num,
+                                  const Dims4D dims,
+                                  int* index_flags,
+                                  int* out_index_table) {
   CUDA_KERNEL_LOOP_TYPE(i, non_zero_num, int64_t) {
     IntT batch = indices[i];
     IntT in_z = indices[i + non_zero_num];
     IntT in_y = indices[i + 2 * non_zero_num];
     IntT in_x = indices[i + 3 * non_zero_num];
     IntT index = PointToIndex(batch, in_x, in_y, in_z, dims);
-    out_index_table[index] = i == 0 ? -1 : i;
+    phi::funcs::sparse::SetBits(index, index_flags);
+    out_index_table[index] = i;
   }
 }
@@ -375,6 +376,7 @@ __global__ void ProductSubmRuleBookKernel(const T* x_indices,
                                           const Dims4D paddings,
                                           const Dims4D dilations,
                                           const Dims4D strides,
+                                          const int* index_flags,
                                           const int* out_index_table,
                                           T* rulebook,
                                           int* counter) {
@@ -417,9 +419,10 @@ __global__ void ProductSubmRuleBookKernel(const T* x_indices,
           T out_x = (in_x + paddings[3] - kx * dilations[3]) / strides[3];
           out_index = phi::funcs::sparse::PointToIndex<Dims4D>(
               batch, out_x, out_y, out_z, out_dims);
-          int real_out_index = out_index_table[out_index];
-          if (real_out_index != 0) {
-            real_out_index = real_out_index == -1 ? 0 : real_out_index;
+          const bool flag =
+              phi::funcs::sparse::TestBits(out_index, index_flags);
+          if (flag) {
+            int real_out_index = out_index_table[out_index];
             in_i = i;
             int buf_i = atomicAdd(&counter_buf[kernel_index], 1);
             kernel_i = kernel_index;
@@ -440,7 +443,6 @@ __global__ void ProductSubmRuleBookKernel(const T* x_indices,
   __syncthreads();
   for (int i = 0; i < kernel_size; i++) {
     if (threadIdx.x < counter_buf[i]) {
-      // rulebook[i * non_zero_num + counter_buf2[i] + threadIdx.x] = i;
       rulebook[i * non_zero_num + counter_buf2[i] + threadIdx.x] =
           rulebook_buf[i * blockDim.x + threadIdx.x];
       rulebook[i * non_zero_num + offset + counter_buf2[i] + threadIdx.x] =
@@ -575,12 +577,18 @@ int ProductRuleBook(const Context& dev_ctx,
   DenseTensorMeta rulebook_meta(
       indices_dtype, {rulebook_rows, rulebook_cols}, DataLayout::NCHW);
 
-  int64_t table_size = 1;
+  int table_size = 1;
   for (int i = 0; i < out_dims.size() - 1; i++) {
     table_size *= out_dims[i];
   }
   DenseTensor out_index_table = phi::Empty<int>(dev_ctx, {table_size});
   int* out_index_table_ptr = out_index_table.data<int>();
+  // index_flags: flag the indices exist or not
+  int index_flags_size = (table_size + 31) / 32;
+  DenseTensor index_flags = phi::Empty<int>(dev_ctx, {index_flags_size});
+  int* index_flags_ptr = index_flags.data<int>();
+  phi::backends::gpu::GpuMemsetAsync(
+      index_flags_ptr, 0, sizeof(int) * index_flags.numel(), dev_ctx.stream());
 
   if (subm) {
     DenseTensor tmp_rulebook = phi::Empty(dev_ctx, std::move(rulebook_meta));
@@ -590,16 +598,16 @@ int ProductRuleBook(const Context& dev_ctx,
     phi::Copy(dev_ctx, x.indices(), dev_ctx.GetPlace(), false, &out_indices);
 
-    phi::backends::gpu::GpuMemsetAsync(
-        out_index_table_ptr, 0, sizeof(int) * table_size, dev_ctx.stream());
-
     auto config =
         phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, non_zero_num, 1);
-    GetOutIndexTable<IntT><<<config.block_per_grid,
-                             config.thread_per_block,
-                             0,
-                             dev_ctx.stream()>>>(
-        out_indices.data<IntT>(), non_zero_num, d_x_dims, out_index_table_ptr);
+    GetOutIndexTable1<IntT><<<config.block_per_grid,
+                              config.thread_per_block,
+                              0,
+                              dev_ctx.stream()>>>(out_indices.data<IntT>(),
+                                                  non_zero_num,
+                                                  d_x_dims,
+                                                  index_flags_ptr,
+                                                  out_index_table_ptr);
 
     size_t cache_size =
         kernel_size * 2 * sizeof(int) +
@@ -625,6 +633,7 @@ int ProductRuleBook(const Context& dev_ctx,
         d_paddings,
         d_dilations,
         d_strides,
+        index_flags_ptr,
         out_index_table_ptr,
         rulebook_ptr,
         counter_ptr);
@@ -695,9 +704,6 @@ int ProductRuleBook(const Context& dev_ctx,
     int* out_index_ptr = out_index->data<int>();
     int* unique_key_ptr = unique_key.data<int>();
 
-    phi::backends::gpu::GpuMemsetAsync(
-        out_index_table_ptr, 0, sizeof(int) * table_size, dev_ctx.stream());
-
     phi::backends::gpu::GpuMemsetAsync(
         unique_key_ptr, 0, sizeof(int), dev_ctx.stream());
 
@@ -708,7 +714,7 @@ int ProductRuleBook(const Context& dev_ctx,
                    cache_size,
                    dev_ctx.stream()>>>(rulebook_ptr + rulebook_len,
                                        rulebook_len,
-                                       out_index_table_ptr,
+                                       index_flags_ptr,
                                        out_index_ptr,
                                        unique_key_ptr);
...
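
Note: in the ProductRuleBook changes above, the per-call GpuMemsetAsync over out_index_table (table_size ints) is removed; only the much smaller index_flags array of (table_size + 31) / 32 ints is zeroed. A table slot is only dereferenced after TestBits confirms its flag is set, so stale entries behind unset bits are never read and the big table no longer needs clearing. A rough sizing sketch follows; the numbers are illustrative, not taken from the patch.

// Illustrative sizing only: suppose the flattened output grid has
// table_size = batch * D * H * W = 2 * 32 * 64 * 64 = 262144 slots.
int table_size = 2 * 32 * 64 * 64;              // 262144 ints (1 MB) cleared before
int index_flags_size = (table_size + 31) / 32;  // 8192 ints (32 KB) cleared now
// out_index_table is still allocated at full size, but it is only read when
// the corresponding bit in index_flags is set, so it needs no memset.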
@@ -25,6 +25,7 @@ limitations under the License. */
 #include "paddle/phi/kernels/funcs/aligned_vector.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
 #include "paddle/phi/kernels/funcs/sparse/flatten_indices.cu.h"
+#include "paddle/phi/kernels/funcs/sparse/utils.cu.h"
 
 namespace phi {
 namespace sparse {
@@ -118,15 +119,20 @@ void SparseMaskKernel(const Context& dev_ctx,
 }
 
 template <typename IntT>
-__global__ void MaskTable(const IntT* x_indexs, const int n, int* table) {
+__global__ void MaskTable(const IntT* x_indexs,
+                          const int n,
+                          int* index_flags,
+                          int* table) {
   CUDA_KERNEL_LOOP_TYPE(i, n, int64_t) {
     int index = x_indexs[i];
-    table[index] = i == 0 ? -1 : i;
+    phi::funcs::sparse::SetBits(index, index_flags);
+    table[index] = i;
   }
 }
 
 template <typename T, typename IntT, int VecSize>
 __global__ void MaskCopy(const IntT* mask_indexs,
+                         const int* index_flags,
                          const int* table,
                          const int n,
                          const int stride,
@@ -135,9 +141,10 @@ __global__ void MaskCopy(const IntT* mask_indexs,
   using LoadT = phi::AlignedVector<T, VecSize>;
   using StoreT = phi::AlignedVector<T, VecSize>;
   CUDA_KERNEL_LOOP_TYPE(i, n, int64_t) {
-    int j = table[mask_indexs[i]];
-    if (j != 0) {
-      if (j == -1) j = 0;
+    const int mask_index = mask_indexs[i];
+    const bool flag = phi::funcs::sparse::TestBits(mask_index, index_flags);
+    if (flag) {
+      int j = table[mask_index];
       for (int k = 0; k < stride; k += VecSize) {
         LoadT vec_x;
         phi::Load<T, VecSize>(x_values + j * stride + k, &vec_x);
@@ -217,12 +224,15 @@ void SparseMaskHelperGPUKernel(const GPUContext& dev_ctx,
   int table_size = 1;
   auto x_dims = x.dims();
-  for (int i = 0; i < x_dims.size() - 1; i++) {
+  for (int i = 0; i < sparse_dim; i++) {
     table_size *= x_dims[i];
   }
   DenseTensor table = phi::Empty<int>(dev_ctx, {table_size});
-  phi::backends::gpu::GpuMemsetAsync(
-      table.data<int>(), 0, table_size * sizeof(int), dev_ctx.stream());
+  DenseTensor index_flags = phi::Empty<int>(dev_ctx, {(table_size + 31) / 32});
+  phi::backends::gpu::GpuMemsetAsync(index_flags.data<int>(),
+                                     0,
+                                     index_flags.numel() * sizeof(int),
+                                     dev_ctx.stream());
 
   const int64_t stride =
       x.dims().size() == sparse_dim ? 1 : x.values().dims()[1];
   *out = phi::EmptyLike<T>(dev_ctx, x.values());
@@ -234,8 +244,10 @@ void SparseMaskHelperGPUKernel(const GPUContext& dev_ctx,
   MaskTable<<<config.block_per_grid,
               config.thread_per_block,
               0,
-              dev_ctx.stream()>>>(
-      x_indexs_ptr, x_indexs.numel(), table.data<int>());
+              dev_ctx.stream()>>>(x_indexs_ptr,
+                                  x_indexs.numel(),
+                                  index_flags.data<int>(),
+                                  table.data<int>());
 
   config =
       phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, mask_indexs.numel(), 1);
@@ -246,6 +258,7 @@ void SparseMaskHelperGPUKernel(const GPUContext& dev_ctx,
                  config.thread_per_block,
                  0,
                  dev_ctx.stream()>>>(mask_indexs_ptr,
+                                     index_flags.data<int>(),
                                      table.data<int>(),
                                      mask_indexs.numel(),
                                      stride,
@@ -256,6 +269,7 @@ void SparseMaskHelperGPUKernel(const GPUContext& dev_ctx,
                  config.thread_per_block,
                  0,
                  dev_ctx.stream()>>>(mask_indexs_ptr,
+                                     index_flags.data<int>(),
                                      table.data<int>(),
                                      mask_indexs.numel(),
                                      stride,
...
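
Note: the MaskHelper part of the fix replaces the previous sentinel encoding (table[index] = i == 0 ? -1 : i, with 0 doubling as "absent") by a direct table[index] = i plus a separate presence bitmask, and it sizes the lookup table from sparse_dim instead of x_dims.size() - 1. A simplified host-side comparison of the two lookup patterns, with illustrative names rather than the kernel code itself:

// Old pattern: presence and row index share one zero-initialized array,
// so row 0 must be encoded as -1 and the value 0 also means "absent".
int LookupOld(const int* table, int index) {
  int j = table[index];
  if (j == 0) return -1;    // treated as "not present"
  return j == -1 ? 0 : j;   // decode the sentinel used for row 0
}

// New pattern: a separate bitmask answers "is this index present?",
// so the table can store every row index directly, including 0.
int LookupNew(const int* table, const int* flags, int index) {
  const bool present = (flags[index >> 5] & (1 << (index & 31))) != 0;
  return present ? table[index] : -1;  // -1 here only signals "not present"
}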