未验证 提交 23e06680 编写于 作者: Z zhangkaihuo 提交者: GitHub

[Sparse]Conv sort out (#46216)

* sort out index
上级 7c4efa5a
......@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once
#include <thrust/remove.h>
#include <thrust/sort.h>
#include <thrust/unique.h>
#include "paddle/phi/kernels/sparse/conv_kernel.h"
......@@ -186,8 +187,7 @@ __global__ void UniqueKernel(const IntT* in_indexs,
if (i < rulebook_len) {
// atomicOr only support int
int index = static_cast<int>(in_indexs[i]);
int change_index = index == 0 ? -1 : index;
int flag = atomicOr(out_index_table + index, change_index);
int flag = atomicOr(out_index_table + index, 1);
if (flag == 0) {
int j = atomicAdd(&count, 1);
cache[j] = index;
......@@ -772,6 +772,7 @@ int ProductRuleBook(const Context& dev_ctx,
phi::backends::gpu::GpuMemsetAsync(
out_index_table_ptr, 0, sizeof(int) * table_size, dev_ctx.stream());
phi::backends::gpu::GpuMemsetAsync(
unique_key_ptr, 0, sizeof(int), dev_ctx.stream());
......@@ -785,6 +786,7 @@ int ProductRuleBook(const Context& dev_ctx,
out_index_table_ptr,
out_index_ptr,
unique_key_ptr);
int out_nnz = 0;
phi::backends::gpu::GpuMemcpyAsync(&out_nnz,
unique_key_ptr,
......@@ -792,6 +794,13 @@ int ProductRuleBook(const Context& dev_ctx,
gpuMemcpyDeviceToHost,
dev_ctx.stream());
dev_ctx.Wait();
#ifdef PADDLE_WITH_HIP
thrust::sort(thrust::hip::par.on(dev_ctx.stream()),
#else
thrust::sort(thrust::cuda::par.on(dev_ctx.stream()),
#endif
out_index_ptr,
out_index_ptr + out_nnz);
const int64_t sparse_dim = 4;
phi::DenseTensor out_indices =
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册