Unverified commit e8de9dfd, authored by zhangkaihuo, committed by GitHub

SparseConv support duplicate coordinates (#44976)

* sparse conv support duplicate coordinates
Parent 090caa0e
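Context for the diff below: the sparse-conv rulebook code previously assumed that each nnz index appears in the rulebook at most `kernel_size` times (see the "kernel_size at most" comments), so the `index_groups` buffers reserved `kernel_size` slots per index. When the input contains duplicate coordinates, several points map to the same voxel and an index can be hit up to `kernel_size * max_voxel` times, where `max_voxel` is the largest number of points sharing one voxel. The patch computes `max_voxel` on the GPU and threads it through `GetOutIndexTable`, `GroupIndexs`, `GatherV2`, and `ScatterV2`. A minimal host-side sketch of what `max_voxel` measures (`ComputeMaxVoxel` and `flat_voxel_indices` are illustrative names, not part of the patch):

```cpp
#include <algorithm>
#include <cstdint>
#include <unordered_map>
#include <vector>

// For each flattened voxel index, count how many input points map to it;
// max_voxel is the largest such count (floored at 1).
int ComputeMaxVoxel(const std::vector<int64_t>& flat_voxel_indices) {
  std::unordered_map<int64_t, int> counts;
  int max_voxel = 1;
  for (int64_t idx : flat_voxel_indices) {
    max_voxel = std::max(max_voxel, ++counts[idx]);
  }
  return max_voxel;
}
```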
@@ -79,6 +79,7 @@ __global__ void ScatterKernelV2(const T* input,
                                 const int* index_groups,
                                 const int non_zero_num,
                                 const int kernel_size,
+                                const int max_voxel,
                                 const int channels,
                                 const int buffer_counts,
                                 T* out) {
@@ -96,10 +97,11 @@ __global__ void ScatterKernelV2(const T* input,
               &sums);
     for (int it = 0; it < buffer_counts; it++) {
       int len = index_counts[indices_i + it * non_zero_num];
-      const int group_offset = it * kernel_size * non_zero_num;
+      const int group_offset = it * max_voxel * kernel_size * non_zero_num;
       for (int j = 0; j < len; j++) {
         const int out_feature_i =
-            index_groups[indices_i * kernel_size + j + group_offset];
+            index_groups[indices_i * max_voxel * kernel_size + j +
+                         group_offset];
         LoadT vec_in;
         phi::Load<T, VecSize>(
             input + out_feature_i * channels + channels_i * VecSize, &vec_in);
@@ -121,6 +123,7 @@ void ScatterV2(const GPUContext& dev_ctx,
                const int* index_groups,
                const int non_zero_num,
                const int kernel_size,
+               const int max_voxel,
                const int channels,
                const int buffer_counts,
                T* output) {
@@ -136,6 +139,7 @@ void ScatterV2(const GPUContext& dev_ctx,
                                            index_groups,
                                            non_zero_num,
                                            kernel_size,
+                                           max_voxel,
                                            channels,
                                            buffer_counts,
                                            output);
@@ -150,6 +154,7 @@ void ScatterV2(const GPUContext& dev_ctx,
                                            index_groups,
                                            non_zero_num,
                                            kernel_size,
+                                           max_voxel,
                                            channels,
                                            buffer_counts,
                                            output);
......
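In `ScatterKernelV2`/`ScatterV2` above, every nnz row now owns `max_voxel * kernel_size` entries in `index_groups` instead of `kernel_size`, so both the per-buffer offset and the per-row base scale by `max_voxel`. A hedged sketch of the new addressing, with a hypothetical helper name (`GroupEntry` is not in the patch; only the arithmetic is):

```cpp
// Mirrors the index computation in the updated ScatterKernelV2: buffer `it`
// starts at it * max_voxel * kernel_size * non_zero_num, and row `indices_i`
// owns max_voxel * kernel_size consecutive slots within that buffer.
__device__ __forceinline__ int GroupEntry(const int* index_groups,
                                          int indices_i,   // nnz row
                                          int it,          // buffer id
                                          int j,           // slot within the row
                                          int kernel_size,
                                          int max_voxel,
                                          int non_zero_num) {
  const int group_offset = it * max_voxel * kernel_size * non_zero_num;
  return index_groups[indices_i * max_voxel * kernel_size + j + group_offset];
}
```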
@@ -65,6 +65,7 @@ __global__ void GatherKernelV2(const T* inputs,
                                const int* index_groups,
                                const int non_zero_num,
                                const int kernel_size,
+                               const int max_voxel,
                                const int channels,
                                const int buffer_count,
                                T* output) {
@@ -82,10 +83,11 @@ __global__ void GatherKernelV2(const T* inputs,
 #pragma unroll
     for (int it = 0; it < buffer_count; it++) {
       int len = index_counts[indices_i + it * non_zero_num];
-      const int group_offset = it * kernel_size * non_zero_num;
+      const int group_offset = it * kernel_size * max_voxel * non_zero_num;
 #pragma unroll
       for (int j = 0; j < len; j++) {
-        int out_i = index_groups[indices_i * kernel_size + j + group_offset];
+        int out_i = index_groups[indices_i * kernel_size * max_voxel + j +
+                                 group_offset];
         phi::Store<T, VecSize>(
             in_vec, output + out_i * channels + channels_i * VecSize);
       }
@@ -127,6 +129,7 @@ inline void GatherV2(const GPUContext& dev_ctx,
                      const int* index_groups,
                      const int non_zero_num,
                      const int kernel_size,
+                     const int max_voxel,
                      const int channels,
                      const int buffer_count,
                      T* output) {
@@ -142,6 +145,7 @@ inline void GatherV2(const GPUContext& dev_ctx,
                                           index_groups,
                                           non_zero_num,
                                           kernel_size,
+                                          max_voxel,
                                           channels,
                                           buffer_count,
                                           output);
@@ -156,6 +160,7 @@ inline void GatherV2(const GPUContext& dev_ctx,
                                           index_groups,
                                           non_zero_num,
                                           kernel_size,
+                                          max_voxel,
                                           channels,
                                           buffer_count,
                                           output);
@@ -202,7 +207,7 @@ __global__ void UniqueKernel(const IntT* in_indexs,
 template <typename IntT>
 __global__ void GroupIndexs(const int* out_index_table,
                             const int n,
-                            const int kernel_size,
+                            const int offset,
                             IntT* out_indexs,
                             int* out_index_counts,
                             int* out_index_groups) {
@@ -214,7 +219,7 @@ __global__ void GroupIndexs(const int* out_index_table,
     // kernel_size at most
     int j = atomicAdd(out_index_counts + real_index, 1);
     // nnz * kernel_size
-    out_index_groups[real_index * kernel_size + j] = i;
+    out_index_groups[real_index * offset + j] = i;
   }
 }
@@ -298,18 +303,36 @@ __global__ void ProductRuleBookKernel(const T* x_indices,
   }
 }
 
-template <typename IntT>
+template <typename IntT, bool save_out_index = true>
 __global__ void GetOutIndexTable(const IntT* indices,
                                  const IntT non_zero_num,
                                  const Dims4D dims,
-                                 int* out_index_table) {
+                                 int* out_index_table,
+                                 int* out_index_table2,
+                                 int* max_voxel) {
+  __shared__ int cache_max;
+  if (threadIdx.x == 0) {
+    cache_max = 0;
+  }
+  __syncthreads();
+
   CUDA_KERNEL_LOOP_TYPE(i, non_zero_num, int64_t) {
     IntT batch = indices[i];
     IntT in_z = indices[i + non_zero_num];
     IntT in_y = indices[i + 2 * non_zero_num];
     IntT in_x = indices[i + 3 * non_zero_num];
     IntT index = PointToIndex(batch, in_x, in_y, in_z, dims);
-    out_index_table[index] = i == 0 ? -1 : i;
+    if (save_out_index) {
+      out_index_table[index] = i == 0 ? -1 : i;
+    }
+
+    int count = atomicAdd(out_index_table2 + index, 1);
+    atomicMax(&cache_max, count);
+  }
+
+  __syncthreads();
+  if (threadIdx.x == 0) {
+    atomicMax(max_voxel, cache_max + 1);
   }
 }
@@ -318,10 +341,22 @@ __global__ void GetOutIndexTable(int* indexs,
                                  const int non_zero_num,
                                  const Dims4D out_dims,
                                  int* out_index_table,
+                                 int* out_index_table2,
+                                 int* max_voxel,
                                  IntT* out_indices) {
+  __shared__ int cache_max;
+  if (threadIdx.x == 0) {
+    cache_max = 0;
+  }
+  __syncthreads();
+
   CUDA_KERNEL_LOOP_TYPE(i, non_zero_num, int64_t) {
     IntT index = static_cast<IntT>(indexs[i]);
     out_index_table[index] = i;
+    int count = atomicAdd(out_index_table2 + index, 1);
+    atomicMax(&cache_max, count);
+
     IntT batch, x, y, z;
     phi::funcs::sparse::IndexToPoint<Dims4D>(
         index, out_dims, &batch, &x, &y, &z);
@@ -332,6 +367,11 @@ __global__ void GetOutIndexTable(int* indexs,
     out_indices[i + non_zero_num * 3] = x;
     indexs[i] = 0;
   }
+
+  __syncthreads();
+  if (threadIdx.x == 0) {
+    atomicMax(max_voxel, cache_max + 1);
+  }
 }
 
 template <typename IntT>
@@ -451,7 +491,7 @@ __global__ void ProductSubmRuleBookKernel(const T* x_indices,
 template <typename IntT>
 __global__ void GroupIndexs(const int n,
-                            const int kernel_size,
+                            const int offset,
                             const IntT* indexs,
                             int* index_counts,
                             int* index_groups) {
@@ -460,7 +500,7 @@ __global__ void GroupIndexs(const int n,
     // kernel_size at most
     int j = atomicAdd(index_counts + index, 1);
     // nnz * kernel_size
-    index_groups[index * kernel_size + j] = i;
+    index_groups[index * offset + j] = i;
   }
 }
@@ -468,7 +508,7 @@ __global__ void GroupIndexs(const int n,
 template <typename IntT>
 __global__ void GroupIndexsV2(const int rulebook_len,
                               const int non_zero_num,
-                              const int kernel_size,
+                              const int offset,
                               const int half_kernel_offset,
                               const IntT* indexs,
                               int* index_counts,
@@ -479,11 +519,11 @@ __global__ void GroupIndexsV2(const int rulebook_len,
         i < half_kernel_offset ? index_counts : index_counts + non_zero_num;
     int* groups_ptr = i < half_kernel_offset
                           ? index_groups
-                          : index_groups + non_zero_num * kernel_size;
+                          : index_groups + non_zero_num * offset;
     // conflict kernel_size times at most
     int j = atomicAdd(counts_ptr + index, 1);
     // nnz * kernel_size
-    groups_ptr[index * kernel_size + j] = i;
+    groups_ptr[index * offset + j] = i;
   }
 }
@@ -582,6 +622,10 @@ int ProductRuleBook(const Context& dev_ctx,
   DenseTensor out_index_table = phi::Empty<int>(dev_ctx, {table_size});
   int* out_index_table_ptr = out_index_table.data<int>();
 
+  DenseTensor out_index_table2 = phi::Empty<int>(dev_ctx, {table_size + 1});
+  int* out_index_table2_ptr = out_index_table2.data<int>();
+  int* h_max_voxel = h_counter + kernel_size;
+
   if (subm) {
     DenseTensor tmp_rulebook = phi::Empty(dev_ctx, std::move(rulebook_meta));
     IntT* rulebook_ptr = tmp_rulebook.data<IntT>();
@@ -594,14 +638,29 @@ int ProductRuleBook(const Context& dev_ctx,
     phi::backends::gpu::GpuMemsetAsync(
         out_index_table_ptr, 0, sizeof(int) * table_size, dev_ctx.stream());
+    phi::backends::gpu::GpuMemsetAsync(out_index_table2_ptr,
+                                       0,
+                                       sizeof(int) * (table_size + 1),
+                                       dev_ctx.stream());
 
     auto config =
         phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, non_zero_num, 1);
-    GetOutIndexTable<IntT><<<config.block_per_grid,
-                             config.thread_per_block,
-                             0,
-                             dev_ctx.stream()>>>(
-        out_indices.data<IntT>(), non_zero_num, d_x_dims, out_index_table_ptr);
+    GetOutIndexTable<IntT>
+        <<<config.block_per_grid,
+           config.thread_per_block,
+           0,
+           dev_ctx.stream()>>>(out_indices.data<IntT>(),
+                               non_zero_num,
+                               d_x_dims,
+                               out_index_table_ptr,
+                               out_index_table2_ptr,
+                               out_index_table2_ptr + table_size);
+
+    phi::backends::gpu::GpuMemcpyAsync(h_max_voxel,
+                                       out_index_table2_ptr + table_size,
+                                       sizeof(int),
+                                       gpuMemcpyDeviceToHost,
+                                       dev_ctx.stream());
+    dev_ctx.Wait();
 
     size_t cache_size = kernel_size * 2 + kernel_size *
                                               config.thread_per_block.x * 2 *
@@ -655,6 +714,22 @@ int ProductRuleBook(const Context& dev_ctx,
                                              out_rulebook_ptr);
     *rulebook = out_rulebook;
 
+    unique_value->ResizeAndAllocate(
+        {static_cast<int>(non_zero_num * h_max_voxel[0] * kernel_size)});
+    int* unique_value_ptr = unique_value->data<int>();
+    out_index->ResizeAndAllocate({static_cast<int>(rulebook_len)});
+    int* out_index_ptr = out_index->data<int>();
+    phi::backends::gpu::GpuMemsetAsync(
+        out_index_ptr, 0, sizeof(int) * rulebook_len, dev_ctx.stream());
+    GroupIndexs<<<config.block_per_grid,
+                  config.thread_per_block,
+                  0,
+                  dev_ctx.stream()>>>(rulebook_len,
+                                      kernel_size * h_max_voxel[0],
+                                      out_rulebook_ptr + rulebook_len,
+                                      out_index_ptr,
+                                      unique_value_ptr);
+
     return rulebook_len;
 
   } else {
@@ -729,17 +804,35 @@ int ProductRuleBook(const Context& dev_ctx,
     IntT* out_indices_ptr = out_indices.data<IntT>();
 
+    phi::backends::gpu::GpuMemsetAsync(
+        out_index_table_ptr, 0, sizeof(int) * table_size, dev_ctx.stream());
+    phi::backends::gpu::GpuMemsetAsync(out_index_table2_ptr,
+                                       0,
+                                       sizeof(int) * (table_size + 1),
+                                       dev_ctx.stream());
+
     config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, out_nnz, 1);
-    GetOutIndexTable<IntT><<<config.block_per_grid,
-                             config.thread_per_block,
-                             0,
-                             dev_ctx.stream()>>>(out_index_ptr,
-                                                 out_nnz,
-                                                 d_out_dims,
-                                                 out_index_table_ptr,
-                                                 out_indices_ptr);
+    GetOutIndexTable<IntT>
+        <<<config.block_per_grid,
+           config.thread_per_block,
+           0,
+           dev_ctx.stream()>>>(out_index_ptr,
+                               out_nnz,
+                               d_out_dims,
+                               out_index_table_ptr,
+                               out_index_table2_ptr,
+                               out_index_table2_ptr + table_size,
+                               out_indices_ptr);
+
+    phi::backends::gpu::GpuMemcpyAsync(h_max_voxel,
+                                       out_index_table2_ptr + table_size,
+                                       sizeof(int),
+                                       gpuMemcpyDeviceToHost,
+                                       dev_ctx.stream());
+    dev_ctx.Wait();
 
     config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, rulebook_len, 1);
-    unique_value->ResizeAndAllocate({static_cast<int>(out_nnz * kernel_size)});
+    unique_value->ResizeAndAllocate(
+        {static_cast<int>(out_nnz * h_max_voxel[0] * kernel_size)});
     int* unique_value_ptr = unique_value->data<int>();
 
     GroupIndexs<<<config.block_per_grid,
@@ -747,7 +840,7 @@ int ProductRuleBook(const Context& dev_ctx,
                   0,
                   dev_ctx.stream()>>>(out_index_table_ptr,
                                       rulebook_len,
-                                      kernel_size,
+                                      kernel_size * h_max_voxel[0],
                                       rulebook_ptr + rulebook_len,
                                       out_index_ptr,
                                       unique_value_ptr);
......
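The rulebook changes above do two things: `GetOutIndexTable` now also builds a per-voxel occupancy histogram (`out_index_table2`, with one extra slot at `table_size` receiving the result) and reduces the maximum occupancy into `max_voxel` via a block-local shared-memory max followed by a global `atomicMax`; the `GroupIndexs`/`GroupIndexsV2` kernels take a generic group stride `offset` (now `kernel_size * max_voxel`) instead of a hard-coded `kernel_size`. A self-contained sketch of the occupancy reduction, with illustrative names (`MaxOccupancySketch`, `voxel_of_point`, `per_voxel_count` are not from the patch):

```cpp
// Each thread counts points per voxel with atomicAdd; the pre-increment value
// is the occupancy seen so far, so the block keeps its running maximum in
// shared memory and one thread folds it into the global result at the end.
__global__ void MaxOccupancySketch(const int* voxel_of_point,
                                   int n,
                                   int* per_voxel_count,  // zero-initialized
                                   int* max_voxel) {      // zero-initialized
  __shared__ int cache_max;
  if (threadIdx.x == 0) cache_max = 0;
  __syncthreads();
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += gridDim.x * blockDim.x) {
    int count = atomicAdd(per_voxel_count + voxel_of_point[i], 1);
    atomicMax(&cache_max, count);
  }
  __syncthreads();
  if (threadIdx.x == 0) atomicMax(max_voxel, cache_max + 1);
}
```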
@@ -124,10 +124,44 @@ void Conv3dCooGradGPUKernel(const GPUContext& dev_ctx,
     }
   }
 
+  int max_voxel = counter_ptr[kernel_size];
+  if (!subm) {
+    const auto& x_dims = x.dims();
+    Dims4D d_x_dims(x_dims[0], x_dims[3], x_dims[2], x_dims[1]);
+    int64_t table_size = 1;
+    for (int i = 0; i < x_dims.size() - 1; i++) {
+      table_size *= x_dims[i];
+    }
+    DenseTensor in_index_table = phi::Empty<int>(dev_ctx, {table_size + 1});
+    int* in_index_table_ptr = in_index_table.data<int>();
+    phi::backends::gpu::GpuMemsetAsync(in_index_table_ptr,
+                                       0,
+                                       sizeof(int) * (table_size + 1),
+                                       dev_ctx.stream());
+    auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, x.nnz(), 1);
+    GetOutIndexTable<IntT, false>
+        <<<config.block_per_grid,
+           config.thread_per_block,
+           0,
+           dev_ctx.stream()>>>(x.non_zero_indices().data<IntT>(),
+                               x.nnz(),
+                               d_x_dims,
+                               nullptr,
+                               in_index_table_ptr,
+                               in_index_table_ptr + table_size);
+
+    phi::backends::gpu::GpuMemcpyAsync(&max_voxel,
+                                       in_index_table_ptr + table_size,
+                                       sizeof(int),
+                                       gpuMemcpyDeviceToHost,
+                                       dev_ctx.stream());
+    dev_ctx.Wait();
+  }
+
   auto config =
       phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, rulebook_len, 1);
   DenseTensor unique_value = phi::Empty<int>(
-      dev_ctx, {static_cast<int>(x_grad->nnz() * kernel_size * 2)});
+      dev_ctx, {static_cast<int>(x_grad->nnz() * max_voxel * kernel_size * 2)});
   DenseTensor out_index =
       phi::Empty<int>(dev_ctx, {static_cast<int>(x.nnz() * 2)});
   int* out_index_ptr = out_index.data<int>();
@@ -140,7 +174,7 @@ void Conv3dCooGradGPUKernel(const GPUContext& dev_ctx,
                                 0,
                                 dev_ctx.stream()>>>(rulebook_len,
                                                     x.nnz(),
-                                                    kernel_size,
+                                                    kernel_size * max_voxel,
                                                     offsets[kernel_size / 2],
                                                     rulebook_ptr,
                                                     out_index_ptr,
@@ -152,6 +186,7 @@ void Conv3dCooGradGPUKernel(const GPUContext& dev_ctx,
               unique_value_ptr,
               x.nnz(),
               kernel_size,
+              max_voxel,
               in_channels,
               2,
               in_features_ptr);
@@ -212,6 +247,7 @@ void Conv3dCooGradGPUKernel(const GPUContext& dev_ctx,
               unique_value.data<int>(),
               x_grad->nnz(),
               kernel_size,
+              max_voxel,
              in_channels,
              2,
              x_grad_values_ptr);
......
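In the gradient kernel above, `max_voxel` is read from the extra counter slot `counter_ptr[kernel_size]`; for the non-submanifold path the occupancy table is rebuilt over `x.non_zero_indices()` with `GetOutIndexTable<IntT, false>` (the new `save_out_index` flag skips the index-table write, hence the `nullptr` argument) purely to recover the input-side `max_voxel`. The grouping buffers then grow by the same factor. A hedged sketch of the implied capacity (`UniqueValueCapacity` is an illustrative helper, not part of the patch; the trailing factor appears to match the `buffer_counts` value passed to `GatherV2`/`ScatterV2`, 2 here and 1 in the forward kernel):

```cpp
#include <cstdint>

// Capacity of the unique_value/index_groups buffers once duplicate coordinates
// are allowed: each nnz row may occupy kernel_size * max_voxel slots per buffer.
inline int64_t UniqueValueCapacity(int64_t nnz,
                                   int kernel_size,
                                   int max_voxel,
                                   int buffer_counts) {
  return nnz * static_cast<int64_t>(max_voxel) * kernel_size * buffer_counts;
}
```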
@@ -66,7 +66,7 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx,
   const int in_channels = kernel_dims[3];
   const int out_channels = kernel_dims[4];
   DenseTensor h_counter, h_offsets;
-  h_counter.Resize({kernel_size});
+  h_counter.Resize({kernel_size + 1});
   h_offsets.Resize({kernel_size + 1});
   int* h_counter_ptr = dev_ctx.template HostAlloc<int>(&h_counter);
   int* h_offsets_ptr = dev_ctx.template HostAlloc<int>(&h_offsets);
@@ -74,7 +74,7 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx,
   // Second algorithm:
   // https://pdfs.semanticscholar.org/5125/a16039cabc6320c908a4764f32596e018ad3.pdf
   // 1. product rulebook
-  DenseTensor counter_per_kernel = phi::Empty<int>(dev_ctx, {kernel_size});
+  DenseTensor counter_per_kernel = phi::Empty<int>(dev_ctx, {kernel_size + 1});
   DenseTensor offsets_per_kernel = phi::Empty<int>(dev_ctx, {kernel_size});
   DenseTensor out_index = phi::Empty<int>(dev_ctx, {1});
   DenseTensor unique_value = phi::Empty<int>(dev_ctx, {1});
@@ -143,26 +143,6 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx,
   T* out_values_ptr = out_values->data<T>();
   set_zero(dev_ctx, out_values, static_cast<T>(0.0f));
 
-  if (subm) {
-    auto config =
-        phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, rulebook_len, 1);
-    unique_value.ResizeAndAllocate(
-        {static_cast<int>(out->nnz() * kernel_size)});
-    out_index.ResizeAndAllocate({static_cast<int>(rulebook_len)});
-    int* out_index_ptr = out_index.data<int>();
-    int* unique_value_ptr = unique_value.data<int>();
-    phi::backends::gpu::GpuMemsetAsync(
-        out_index_ptr, 0, sizeof(int) * rulebook_len, dev_ctx.stream());
-    GroupIndexs<<<config.block_per_grid,
-                  config.thread_per_block,
-                  0,
-                  dev_ctx.stream()>>>(rulebook_len,
-                                      kernel_size,
-                                      rulebook_ptr + rulebook_len,
-                                      out_index_ptr,
-                                      unique_value_ptr);
-  }
-
   const T* kernel_ptr = kernel.data<T>();
   for (int i = 0; i < kernel_size; i++) {
     if (h_counter_ptr[i] <= 0) {
@@ -196,6 +176,7 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx,
                   unique_value.data<int>(),
                   out->nnz(),
                   kernel_size,
+                  h_counter_ptr[kernel_size],
                   out_channels,
                   1,
                   out_values_ptr);
......
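In the forward kernel above, the host counter is enlarged to `kernel_size + 1` entries so that `ProductRuleBook` can return `max_voxel` in the trailing slot (`h_counter_ptr[kernel_size]`), which is then forwarded to `ScatterV2`; the submanifold-specific `GroupIndexs` pass is dropped here because the grouping now happens inside `ProductRuleBook` with the `kernel_size * max_voxel` stride. A small sketch of that "+1 slot" convention (`HostCounter` is a hypothetical wrapper, not part of the patch, which uses the raw `h_counter_ptr` buffer directly):

```cpp
#include <vector>

// kernel_size per-offset counts followed by one extra slot carrying max_voxel
// back from rulebook construction to the gather/scatter steps.
class HostCounter {
 public:
  explicit HostCounter(int kernel_size)
      : kernel_size_(kernel_size), data_(kernel_size + 1, 0) {}
  int* data() { return data_.data(); }                  // filled by the rulebook builder
  int CountForOffset(int k) const { return data_[k]; }  // k in [0, kernel_size)
  int MaxVoxel() const { return data_[kernel_size_]; }  // trailing slot
 private:
  int kernel_size_;
  std::vector<int> data_;
};
```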