未验证 提交 23e06680 编写于 作者: Z zhangkaihuo 提交者: GitHub

[Sparse]Conv sort out (#46216)

* sort out index
上级 7c4efa5a
...@@ -15,6 +15,7 @@ limitations under the License. */ ...@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once #pragma once
#include <thrust/remove.h> #include <thrust/remove.h>
#include <thrust/sort.h>
#include <thrust/unique.h> #include <thrust/unique.h>
#include "paddle/phi/kernels/sparse/conv_kernel.h" #include "paddle/phi/kernels/sparse/conv_kernel.h"
...@@ -186,8 +187,7 @@ __global__ void UniqueKernel(const IntT* in_indexs, ...@@ -186,8 +187,7 @@ __global__ void UniqueKernel(const IntT* in_indexs,
if (i < rulebook_len) { if (i < rulebook_len) {
// atomicOr only support int // atomicOr only support int
int index = static_cast<int>(in_indexs[i]); int index = static_cast<int>(in_indexs[i]);
int change_index = index == 0 ? -1 : index; int flag = atomicOr(out_index_table + index, 1);
int flag = atomicOr(out_index_table + index, change_index);
if (flag == 0) { if (flag == 0) {
int j = atomicAdd(&count, 1); int j = atomicAdd(&count, 1);
cache[j] = index; cache[j] = index;
...@@ -772,6 +772,7 @@ int ProductRuleBook(const Context& dev_ctx, ...@@ -772,6 +772,7 @@ int ProductRuleBook(const Context& dev_ctx,
phi::backends::gpu::GpuMemsetAsync( phi::backends::gpu::GpuMemsetAsync(
out_index_table_ptr, 0, sizeof(int) * table_size, dev_ctx.stream()); out_index_table_ptr, 0, sizeof(int) * table_size, dev_ctx.stream());
phi::backends::gpu::GpuMemsetAsync( phi::backends::gpu::GpuMemsetAsync(
unique_key_ptr, 0, sizeof(int), dev_ctx.stream()); unique_key_ptr, 0, sizeof(int), dev_ctx.stream());
...@@ -785,6 +786,7 @@ int ProductRuleBook(const Context& dev_ctx, ...@@ -785,6 +786,7 @@ int ProductRuleBook(const Context& dev_ctx,
out_index_table_ptr, out_index_table_ptr,
out_index_ptr, out_index_ptr,
unique_key_ptr); unique_key_ptr);
int out_nnz = 0; int out_nnz = 0;
phi::backends::gpu::GpuMemcpyAsync(&out_nnz, phi::backends::gpu::GpuMemcpyAsync(&out_nnz,
unique_key_ptr, unique_key_ptr,
...@@ -792,6 +794,13 @@ int ProductRuleBook(const Context& dev_ctx, ...@@ -792,6 +794,13 @@ int ProductRuleBook(const Context& dev_ctx,
gpuMemcpyDeviceToHost, gpuMemcpyDeviceToHost,
dev_ctx.stream()); dev_ctx.stream());
dev_ctx.Wait(); dev_ctx.Wait();
#ifdef PADDLE_WITH_HIP
thrust::sort(thrust::hip::par.on(dev_ctx.stream()),
#else
thrust::sort(thrust::cuda::par.on(dev_ctx.stream()),
#endif
out_index_ptr,
out_index_ptr + out_nnz);
const int64_t sparse_dim = 4; const int64_t sparse_dim = 4;
phi::DenseTensor out_indices = phi::DenseTensor out_indices =
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册