未验证 提交 41ba2722 编写于 作者: Z zhangkaihuo 提交者: GitHub

group the index in not cutlass mode (#48439)

上级 505f4100
...@@ -123,25 +123,6 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx, ...@@ -123,25 +123,6 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx,
dev_ctx, x, key, tmp_rulebook, h_counter, out, rulebook, counter); dev_ctx, x, key, tmp_rulebook, h_counter, out, rulebook, counter);
} }
if (subm) {
auto config =
phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, rulebook_len, 1);
unique_value.ResizeAndAllocate(
{static_cast<int>(out->nnz() * kernel_size)});
out_index.ResizeAndAllocate({static_cast<int>(rulebook_len)});
int* out_index_ptr = out_index.data<int>();
int* unique_value_ptr = unique_value.data<int>();
phi::backends::gpu::GpuMemsetAsync(
out_index_ptr, 0, sizeof(int) * rulebook_len, dev_ctx.stream());
GroupIndexs<<<config.block_per_grid,
config.thread_per_block,
0,
dev_ctx.stream()>>>(rulebook_len,
kernel_size,
rulebook_ptr + rulebook_len,
out_index_ptr,
unique_value_ptr);
}
#ifdef PADDLE_WITH_CUTLASS #ifdef PADDLE_WITH_CUTLASS
bool cutlass = true; bool cutlass = true;
if (dev_ctx.GetComputeCapability() < 80) cutlass = false; if (dev_ctx.GetComputeCapability() < 80) cutlass = false;
...@@ -226,6 +207,25 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx, ...@@ -226,6 +207,25 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx,
} }
} else { } else {
#endif #endif
if (subm) {
auto config =
phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, rulebook_len, 1);
unique_value.ResizeAndAllocate(
{static_cast<int>(out->nnz() * kernel_size)});
out_index.ResizeAndAllocate({static_cast<int>(rulebook_len)});
int* out_index_ptr = out_index.data<int>();
int* unique_value_ptr = unique_value.data<int>();
phi::backends::gpu::GpuMemsetAsync(
out_index_ptr, 0, sizeof(int) * rulebook_len, dev_ctx.stream());
GroupIndexs<<<config.block_per_grid,
config.thread_per_block,
0,
dev_ctx.stream()>>>(rulebook_len,
kernel_size,
rulebook_ptr + rulebook_len,
out_index_ptr,
unique_value_ptr);
}
// 2. gather // 2. gather
phi::DenseTensor in_features = phi::DenseTensor in_features =
phi::Empty<T>(dev_ctx, {rulebook_len, in_channels}); phi::Empty<T>(dev_ctx, {rulebook_len, in_channels});
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册