Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
e8de9dfd
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
e8de9dfd
编写于
8月 08, 2022
作者:
Z
zhangkaihuo
提交者:
GitHub
8月 08, 2022
1
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
SparseConv support duplicate coordinates (#44976)
* sparse conv support duplicate coordinates
上级
090caa0e
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
168 addition
and
53 deletion
+168
-53
paddle/phi/kernels/funcs/sparse/scatter.cu.h
paddle/phi/kernels/funcs/sparse/scatter.cu.h
+7
-2
paddle/phi/kernels/sparse/gpu/conv.cu.h
paddle/phi/kernels/sparse/gpu/conv.cu.h
+120
-27
paddle/phi/kernels/sparse/gpu/conv_grad_kernel.cu
paddle/phi/kernels/sparse/gpu/conv_grad_kernel.cu
+38
-2
paddle/phi/kernels/sparse/gpu/conv_kernel.cu
paddle/phi/kernels/sparse/gpu/conv_kernel.cu
+3
-22
未找到文件。
paddle/phi/kernels/funcs/sparse/scatter.cu.h
浏览文件 @
e8de9dfd
...
@@ -79,6 +79,7 @@ __global__ void ScatterKernelV2(const T* input,
...
@@ -79,6 +79,7 @@ __global__ void ScatterKernelV2(const T* input,
const
int
*
index_groups
,
const
int
*
index_groups
,
const
int
non_zero_num
,
const
int
non_zero_num
,
const
int
kernel_size
,
const
int
kernel_size
,
const
int
max_voxel
,
const
int
channels
,
const
int
channels
,
const
int
buffer_counts
,
const
int
buffer_counts
,
T
*
out
)
{
T
*
out
)
{
...
@@ -96,10 +97,11 @@ __global__ void ScatterKernelV2(const T* input,
...
@@ -96,10 +97,11 @@ __global__ void ScatterKernelV2(const T* input,
&
sums
);
&
sums
);
for
(
int
it
=
0
;
it
<
buffer_counts
;
it
++
)
{
for
(
int
it
=
0
;
it
<
buffer_counts
;
it
++
)
{
int
len
=
index_counts
[
indices_i
+
it
*
non_zero_num
];
int
len
=
index_counts
[
indices_i
+
it
*
non_zero_num
];
const
int
group_offset
=
it
*
kernel_size
*
non_zero_num
;
const
int
group_offset
=
it
*
max_voxel
*
kernel_size
*
non_zero_num
;
for
(
int
j
=
0
;
j
<
len
;
j
++
)
{
for
(
int
j
=
0
;
j
<
len
;
j
++
)
{
const
int
out_feature_i
=
const
int
out_feature_i
=
index_groups
[
indices_i
*
kernel_size
+
j
+
group_offset
];
index_groups
[
indices_i
*
max_voxel
*
kernel_size
+
j
+
group_offset
];
LoadT
vec_in
;
LoadT
vec_in
;
phi
::
Load
<
T
,
VecSize
>
(
phi
::
Load
<
T
,
VecSize
>
(
input
+
out_feature_i
*
channels
+
channels_i
*
VecSize
,
&
vec_in
);
input
+
out_feature_i
*
channels
+
channels_i
*
VecSize
,
&
vec_in
);
...
@@ -121,6 +123,7 @@ void ScatterV2(const GPUContext& dev_ctx,
...
@@ -121,6 +123,7 @@ void ScatterV2(const GPUContext& dev_ctx,
const
int
*
index_groups
,
const
int
*
index_groups
,
const
int
non_zero_num
,
const
int
non_zero_num
,
const
int
kernel_size
,
const
int
kernel_size
,
const
int
max_voxel
,
const
int
channels
,
const
int
channels
,
const
int
buffer_counts
,
const
int
buffer_counts
,
T
*
output
)
{
T
*
output
)
{
...
@@ -136,6 +139,7 @@ void ScatterV2(const GPUContext& dev_ctx,
...
@@ -136,6 +139,7 @@ void ScatterV2(const GPUContext& dev_ctx,
index_groups
,
index_groups
,
non_zero_num
,
non_zero_num
,
kernel_size
,
kernel_size
,
max_voxel
,
channels
,
channels
,
buffer_counts
,
buffer_counts
,
output
);
output
);
...
@@ -150,6 +154,7 @@ void ScatterV2(const GPUContext& dev_ctx,
...
@@ -150,6 +154,7 @@ void ScatterV2(const GPUContext& dev_ctx,
index_groups
,
index_groups
,
non_zero_num
,
non_zero_num
,
kernel_size
,
kernel_size
,
max_voxel
,
channels
,
channels
,
buffer_counts
,
buffer_counts
,
output
);
output
);
...
...
paddle/phi/kernels/sparse/gpu/conv.cu.h
浏览文件 @
e8de9dfd
...
@@ -65,6 +65,7 @@ __global__ void GatherKernelV2(const T* inputs,
...
@@ -65,6 +65,7 @@ __global__ void GatherKernelV2(const T* inputs,
const
int
*
index_groups
,
const
int
*
index_groups
,
const
int
non_zero_num
,
const
int
non_zero_num
,
const
int
kernel_size
,
const
int
kernel_size
,
const
int
max_voxel
,
const
int
channels
,
const
int
channels
,
const
int
buffer_count
,
const
int
buffer_count
,
T
*
output
)
{
T
*
output
)
{
...
@@ -82,10 +83,11 @@ __global__ void GatherKernelV2(const T* inputs,
...
@@ -82,10 +83,11 @@ __global__ void GatherKernelV2(const T* inputs,
#pragma unroll
#pragma unroll
for
(
int
it
=
0
;
it
<
buffer_count
;
it
++
)
{
for
(
int
it
=
0
;
it
<
buffer_count
;
it
++
)
{
int
len
=
index_counts
[
indices_i
+
it
*
non_zero_num
];
int
len
=
index_counts
[
indices_i
+
it
*
non_zero_num
];
const
int
group_offset
=
it
*
kernel_size
*
non_zero_num
;
const
int
group_offset
=
it
*
kernel_size
*
max_voxel
*
non_zero_num
;
#pragma unroll
#pragma unroll
for
(
int
j
=
0
;
j
<
len
;
j
++
)
{
for
(
int
j
=
0
;
j
<
len
;
j
++
)
{
int
out_i
=
index_groups
[
indices_i
*
kernel_size
+
j
+
group_offset
];
int
out_i
=
index_groups
[
indices_i
*
kernel_size
*
max_voxel
+
j
+
group_offset
];
phi
::
Store
<
T
,
VecSize
>
(
phi
::
Store
<
T
,
VecSize
>
(
in_vec
,
output
+
out_i
*
channels
+
channels_i
*
VecSize
);
in_vec
,
output
+
out_i
*
channels
+
channels_i
*
VecSize
);
}
}
...
@@ -127,6 +129,7 @@ inline void GatherV2(const GPUContext& dev_ctx,
...
@@ -127,6 +129,7 @@ inline void GatherV2(const GPUContext& dev_ctx,
const
int
*
index_groups
,
const
int
*
index_groups
,
const
int
non_zero_num
,
const
int
non_zero_num
,
const
int
kernel_size
,
const
int
kernel_size
,
const
int
max_voxel
,
const
int
channels
,
const
int
channels
,
const
int
buffer_count
,
const
int
buffer_count
,
T
*
output
)
{
T
*
output
)
{
...
@@ -142,6 +145,7 @@ inline void GatherV2(const GPUContext& dev_ctx,
...
@@ -142,6 +145,7 @@ inline void GatherV2(const GPUContext& dev_ctx,
index_groups
,
index_groups
,
non_zero_num
,
non_zero_num
,
kernel_size
,
kernel_size
,
max_voxel
,
channels
,
channels
,
buffer_count
,
buffer_count
,
output
);
output
);
...
@@ -156,6 +160,7 @@ inline void GatherV2(const GPUContext& dev_ctx,
...
@@ -156,6 +160,7 @@ inline void GatherV2(const GPUContext& dev_ctx,
index_groups
,
index_groups
,
non_zero_num
,
non_zero_num
,
kernel_size
,
kernel_size
,
max_voxel
,
channels
,
channels
,
buffer_count
,
buffer_count
,
output
);
output
);
...
@@ -202,7 +207,7 @@ __global__ void UniqueKernel(const IntT* in_indexs,
...
@@ -202,7 +207,7 @@ __global__ void UniqueKernel(const IntT* in_indexs,
template
<
typename
IntT
>
template
<
typename
IntT
>
__global__
void
GroupIndexs
(
const
int
*
out_index_table
,
__global__
void
GroupIndexs
(
const
int
*
out_index_table
,
const
int
n
,
const
int
n
,
const
int
kernel_size
,
const
int
offset
,
IntT
*
out_indexs
,
IntT
*
out_indexs
,
int
*
out_index_counts
,
int
*
out_index_counts
,
int
*
out_index_groups
)
{
int
*
out_index_groups
)
{
...
@@ -214,7 +219,7 @@ __global__ void GroupIndexs(const int* out_index_table,
...
@@ -214,7 +219,7 @@ __global__ void GroupIndexs(const int* out_index_table,
// kernel_size at most
// kernel_size at most
int
j
=
atomicAdd
(
out_index_counts
+
real_index
,
1
);
int
j
=
atomicAdd
(
out_index_counts
+
real_index
,
1
);
// nnz * kernel_size
// nnz * kernel_size
out_index_groups
[
real_index
*
kernel_size
+
j
]
=
i
;
out_index_groups
[
real_index
*
offset
+
j
]
=
i
;
}
}
}
}
...
@@ -298,18 +303,36 @@ __global__ void ProductRuleBookKernel(const T* x_indices,
...
@@ -298,18 +303,36 @@ __global__ void ProductRuleBookKernel(const T* x_indices,
}
}
}
}
template
<
typename
IntT
>
template
<
typename
IntT
,
bool
save_out_index
=
true
>
__global__
void
GetOutIndexTable
(
const
IntT
*
indices
,
__global__
void
GetOutIndexTable
(
const
IntT
*
indices
,
const
IntT
non_zero_num
,
const
IntT
non_zero_num
,
const
Dims4D
dims
,
const
Dims4D
dims
,
int
*
out_index_table
)
{
int
*
out_index_table
,
int
*
out_index_table2
,
int
*
max_voxel
)
{
__shared__
int
cache_max
;
if
(
threadIdx
.
x
==
0
)
{
cache_max
=
0
;
}
__syncthreads
();
CUDA_KERNEL_LOOP_TYPE
(
i
,
non_zero_num
,
int64_t
)
{
CUDA_KERNEL_LOOP_TYPE
(
i
,
non_zero_num
,
int64_t
)
{
IntT
batch
=
indices
[
i
];
IntT
batch
=
indices
[
i
];
IntT
in_z
=
indices
[
i
+
non_zero_num
];
IntT
in_z
=
indices
[
i
+
non_zero_num
];
IntT
in_y
=
indices
[
i
+
2
*
non_zero_num
];
IntT
in_y
=
indices
[
i
+
2
*
non_zero_num
];
IntT
in_x
=
indices
[
i
+
3
*
non_zero_num
];
IntT
in_x
=
indices
[
i
+
3
*
non_zero_num
];
IntT
index
=
PointToIndex
(
batch
,
in_x
,
in_y
,
in_z
,
dims
);
IntT
index
=
PointToIndex
(
batch
,
in_x
,
in_y
,
in_z
,
dims
);
out_index_table
[
index
]
=
i
==
0
?
-
1
:
i
;
if
(
save_out_index
)
{
out_index_table
[
index
]
=
i
==
0
?
-
1
:
i
;
}
int
count
=
atomicAdd
(
out_index_table2
+
index
,
1
);
atomicMax
(
&
cache_max
,
count
);
}
__syncthreads
();
if
(
threadIdx
.
x
==
0
)
{
atomicMax
(
max_voxel
,
cache_max
+
1
);
}
}
}
}
...
@@ -318,10 +341,22 @@ __global__ void GetOutIndexTable(int* indexs,
...
@@ -318,10 +341,22 @@ __global__ void GetOutIndexTable(int* indexs,
const
int
non_zero_num
,
const
int
non_zero_num
,
const
Dims4D
out_dims
,
const
Dims4D
out_dims
,
int
*
out_index_table
,
int
*
out_index_table
,
int
*
out_index_table2
,
int
*
max_voxel
,
IntT
*
out_indices
)
{
IntT
*
out_indices
)
{
__shared__
int
cache_max
;
if
(
threadIdx
.
x
==
0
)
{
cache_max
=
0
;
}
__syncthreads
();
CUDA_KERNEL_LOOP_TYPE
(
i
,
non_zero_num
,
int64_t
)
{
CUDA_KERNEL_LOOP_TYPE
(
i
,
non_zero_num
,
int64_t
)
{
IntT
index
=
static_cast
<
IntT
>
(
indexs
[
i
]);
IntT
index
=
static_cast
<
IntT
>
(
indexs
[
i
]);
out_index_table
[
index
]
=
i
;
out_index_table
[
index
]
=
i
;
int
count
=
atomicAdd
(
out_index_table2
+
index
,
1
);
atomicMax
(
&
cache_max
,
count
);
IntT
batch
,
x
,
y
,
z
;
IntT
batch
,
x
,
y
,
z
;
phi
::
funcs
::
sparse
::
IndexToPoint
<
Dims4D
>
(
phi
::
funcs
::
sparse
::
IndexToPoint
<
Dims4D
>
(
index
,
out_dims
,
&
batch
,
&
x
,
&
y
,
&
z
);
index
,
out_dims
,
&
batch
,
&
x
,
&
y
,
&
z
);
...
@@ -332,6 +367,11 @@ __global__ void GetOutIndexTable(int* indexs,
...
@@ -332,6 +367,11 @@ __global__ void GetOutIndexTable(int* indexs,
out_indices
[
i
+
non_zero_num
*
3
]
=
x
;
out_indices
[
i
+
non_zero_num
*
3
]
=
x
;
indexs
[
i
]
=
0
;
indexs
[
i
]
=
0
;
}
}
__syncthreads
();
if
(
threadIdx
.
x
==
0
)
{
atomicMax
(
max_voxel
,
cache_max
+
1
);
}
}
}
template
<
typename
IntT
>
template
<
typename
IntT
>
...
@@ -451,7 +491,7 @@ __global__ void ProductSubmRuleBookKernel(const T* x_indices,
...
@@ -451,7 +491,7 @@ __global__ void ProductSubmRuleBookKernel(const T* x_indices,
template
<
typename
IntT
>
template
<
typename
IntT
>
__global__
void
GroupIndexs
(
const
int
n
,
__global__
void
GroupIndexs
(
const
int
n
,
const
int
kernel_size
,
const
int
offset
,
const
IntT
*
indexs
,
const
IntT
*
indexs
,
int
*
index_counts
,
int
*
index_counts
,
int
*
index_groups
)
{
int
*
index_groups
)
{
...
@@ -460,7 +500,7 @@ __global__ void GroupIndexs(const int n,
...
@@ -460,7 +500,7 @@ __global__ void GroupIndexs(const int n,
// kernel_size at most
// kernel_size at most
int
j
=
atomicAdd
(
index_counts
+
index
,
1
);
int
j
=
atomicAdd
(
index_counts
+
index
,
1
);
// nnz * kernel_size
// nnz * kernel_size
index_groups
[
index
*
kernel_size
+
j
]
=
i
;
index_groups
[
index
*
offset
+
j
]
=
i
;
}
}
}
}
...
@@ -468,7 +508,7 @@ __global__ void GroupIndexs(const int n,
...
@@ -468,7 +508,7 @@ __global__ void GroupIndexs(const int n,
template
<
typename
IntT
>
template
<
typename
IntT
>
__global__
void
GroupIndexsV2
(
const
int
rulebook_len
,
__global__
void
GroupIndexsV2
(
const
int
rulebook_len
,
const
int
non_zero_num
,
const
int
non_zero_num
,
const
int
kernel_size
,
const
int
offset
,
const
int
half_kernel_offset
,
const
int
half_kernel_offset
,
const
IntT
*
indexs
,
const
IntT
*
indexs
,
int
*
index_counts
,
int
*
index_counts
,
...
@@ -479,11 +519,11 @@ __global__ void GroupIndexsV2(const int rulebook_len,
...
@@ -479,11 +519,11 @@ __global__ void GroupIndexsV2(const int rulebook_len,
i
<
half_kernel_offset
?
index_counts
:
index_counts
+
non_zero_num
;
i
<
half_kernel_offset
?
index_counts
:
index_counts
+
non_zero_num
;
int
*
groups_ptr
=
i
<
half_kernel_offset
int
*
groups_ptr
=
i
<
half_kernel_offset
?
index_groups
?
index_groups
:
index_groups
+
non_zero_num
*
kernel_size
;
:
index_groups
+
non_zero_num
*
offset
;
// conflict kernel_size times at most
// conflict kernel_size times at most
int
j
=
atomicAdd
(
counts_ptr
+
index
,
1
);
int
j
=
atomicAdd
(
counts_ptr
+
index
,
1
);
// nnz * kernel_size
// nnz * kernel_size
groups_ptr
[
index
*
kernel_size
+
j
]
=
i
;
groups_ptr
[
index
*
offset
+
j
]
=
i
;
}
}
}
}
...
@@ -582,6 +622,10 @@ int ProductRuleBook(const Context& dev_ctx,
...
@@ -582,6 +622,10 @@ int ProductRuleBook(const Context& dev_ctx,
DenseTensor
out_index_table
=
phi
::
Empty
<
int
>
(
dev_ctx
,
{
table_size
});
DenseTensor
out_index_table
=
phi
::
Empty
<
int
>
(
dev_ctx
,
{
table_size
});
int
*
out_index_table_ptr
=
out_index_table
.
data
<
int
>
();
int
*
out_index_table_ptr
=
out_index_table
.
data
<
int
>
();
DenseTensor
out_index_table2
=
phi
::
Empty
<
int
>
(
dev_ctx
,
{
table_size
+
1
});
int
*
out_index_table2_ptr
=
out_index_table2
.
data
<
int
>
();
int
*
h_max_voxel
=
h_counter
+
kernel_size
;
if
(
subm
)
{
if
(
subm
)
{
DenseTensor
tmp_rulebook
=
phi
::
Empty
(
dev_ctx
,
std
::
move
(
rulebook_meta
));
DenseTensor
tmp_rulebook
=
phi
::
Empty
(
dev_ctx
,
std
::
move
(
rulebook_meta
));
IntT
*
rulebook_ptr
=
tmp_rulebook
.
data
<
IntT
>
();
IntT
*
rulebook_ptr
=
tmp_rulebook
.
data
<
IntT
>
();
...
@@ -594,14 +638,29 @@ int ProductRuleBook(const Context& dev_ctx,
...
@@ -594,14 +638,29 @@ int ProductRuleBook(const Context& dev_ctx,
phi
::
backends
::
gpu
::
GpuMemsetAsync
(
phi
::
backends
::
gpu
::
GpuMemsetAsync
(
out_index_table_ptr
,
0
,
sizeof
(
int
)
*
table_size
,
dev_ctx
.
stream
());
out_index_table_ptr
,
0
,
sizeof
(
int
)
*
table_size
,
dev_ctx
.
stream
());
phi
::
backends
::
gpu
::
GpuMemsetAsync
(
out_index_table2_ptr
,
0
,
sizeof
(
int
)
*
(
table_size
+
1
),
dev_ctx
.
stream
());
auto
config
=
auto
config
=
phi
::
backends
::
gpu
::
GetGpuLaunchConfig1D
(
dev_ctx
,
non_zero_num
,
1
);
phi
::
backends
::
gpu
::
GetGpuLaunchConfig1D
(
dev_ctx
,
non_zero_num
,
1
);
GetOutIndexTable
<
IntT
><<<
config
.
block_per_grid
,
GetOutIndexTable
<
IntT
>
config
.
thread_per_block
,
<<<
config
.
block_per_grid
,
0
,
config
.
thread_per_block
,
dev_ctx
.
stream
()
>>>
(
0
,
out_indices
.
data
<
IntT
>
(),
non_zero_num
,
d_x_dims
,
out_index_table_ptr
);
dev_ctx
.
stream
()
>>>
(
out_indices
.
data
<
IntT
>
(),
non_zero_num
,
d_x_dims
,
out_index_table_ptr
,
out_index_table2_ptr
,
out_index_table2_ptr
+
table_size
);
phi
::
backends
::
gpu
::
GpuMemcpyAsync
(
h_max_voxel
,
out_index_table2_ptr
+
table_size
,
sizeof
(
int
),
gpuMemcpyDeviceToHost
,
dev_ctx
.
stream
());
dev_ctx
.
Wait
();
size_t
cache_size
=
kernel_size
*
2
+
kernel_size
*
size_t
cache_size
=
kernel_size
*
2
+
kernel_size
*
config
.
thread_per_block
.
x
*
2
*
config
.
thread_per_block
.
x
*
2
*
...
@@ -655,6 +714,22 @@ int ProductRuleBook(const Context& dev_ctx,
...
@@ -655,6 +714,22 @@ int ProductRuleBook(const Context& dev_ctx,
out_rulebook_ptr
);
out_rulebook_ptr
);
*
rulebook
=
out_rulebook
;
*
rulebook
=
out_rulebook
;
unique_value
->
ResizeAndAllocate
(
{
static_cast
<
int
>
(
non_zero_num
*
h_max_voxel
[
0
]
*
kernel_size
)});
int
*
unique_value_ptr
=
unique_value
->
data
<
int
>
();
out_index
->
ResizeAndAllocate
({
static_cast
<
int
>
(
rulebook_len
)});
int
*
out_index_ptr
=
out_index
->
data
<
int
>
();
phi
::
backends
::
gpu
::
GpuMemsetAsync
(
out_index_ptr
,
0
,
sizeof
(
int
)
*
rulebook_len
,
dev_ctx
.
stream
());
GroupIndexs
<<<
config
.
block_per_grid
,
config
.
thread_per_block
,
0
,
dev_ctx
.
stream
()
>>>
(
rulebook_len
,
kernel_size
*
h_max_voxel
[
0
],
out_rulebook_ptr
+
rulebook_len
,
out_index_ptr
,
unique_value_ptr
);
return
rulebook_len
;
return
rulebook_len
;
}
else
{
}
else
{
...
@@ -729,17 +804,35 @@ int ProductRuleBook(const Context& dev_ctx,
...
@@ -729,17 +804,35 @@ int ProductRuleBook(const Context& dev_ctx,
IntT
*
out_indices_ptr
=
out_indices
.
data
<
IntT
>
();
IntT
*
out_indices_ptr
=
out_indices
.
data
<
IntT
>
();
phi
::
backends
::
gpu
::
GpuMemsetAsync
(
out_index_table_ptr
,
0
,
sizeof
(
int
)
*
table_size
,
dev_ctx
.
stream
());
phi
::
backends
::
gpu
::
GpuMemsetAsync
(
out_index_table2_ptr
,
0
,
sizeof
(
int
)
*
(
table_size
+
1
),
dev_ctx
.
stream
());
config
=
phi
::
backends
::
gpu
::
GetGpuLaunchConfig1D
(
dev_ctx
,
out_nnz
,
1
);
config
=
phi
::
backends
::
gpu
::
GetGpuLaunchConfig1D
(
dev_ctx
,
out_nnz
,
1
);
GetOutIndexTable
<
IntT
><<<
config
.
block_per_grid
,
GetOutIndexTable
<
IntT
>
config
.
thread_per_block
,
<<<
config
.
block_per_grid
,
0
,
config
.
thread_per_block
,
dev_ctx
.
stream
()
>>>
(
out_index_ptr
,
0
,
out_nnz
,
dev_ctx
.
stream
()
>>>
(
out_index_ptr
,
d_out_dims
,
out_nnz
,
out_index_table_ptr
,
d_out_dims
,
out_indices_ptr
);
out_index_table_ptr
,
out_index_table2_ptr
,
out_index_table2_ptr
+
table_size
,
out_indices_ptr
);
phi
::
backends
::
gpu
::
GpuMemcpyAsync
(
h_max_voxel
,
out_index_table2_ptr
+
table_size
,
sizeof
(
int
),
gpuMemcpyDeviceToHost
,
dev_ctx
.
stream
());
dev_ctx
.
Wait
();
config
=
phi
::
backends
::
gpu
::
GetGpuLaunchConfig1D
(
dev_ctx
,
rulebook_len
,
1
);
config
=
phi
::
backends
::
gpu
::
GetGpuLaunchConfig1D
(
dev_ctx
,
rulebook_len
,
1
);
unique_value
->
ResizeAndAllocate
({
static_cast
<
int
>
(
out_nnz
*
kernel_size
)});
unique_value
->
ResizeAndAllocate
(
{
static_cast
<
int
>
(
out_nnz
*
h_max_voxel
[
0
]
*
kernel_size
)});
int
*
unique_value_ptr
=
unique_value
->
data
<
int
>
();
int
*
unique_value_ptr
=
unique_value
->
data
<
int
>
();
GroupIndexs
<<<
config
.
block_per_grid
,
GroupIndexs
<<<
config
.
block_per_grid
,
...
@@ -747,7 +840,7 @@ int ProductRuleBook(const Context& dev_ctx,
...
@@ -747,7 +840,7 @@ int ProductRuleBook(const Context& dev_ctx,
0
,
0
,
dev_ctx
.
stream
()
>>>
(
out_index_table_ptr
,
dev_ctx
.
stream
()
>>>
(
out_index_table_ptr
,
rulebook_len
,
rulebook_len
,
kernel_size
,
kernel_size
*
h_max_voxel
[
0
]
,
rulebook_ptr
+
rulebook_len
,
rulebook_ptr
+
rulebook_len
,
out_index_ptr
,
out_index_ptr
,
unique_value_ptr
);
unique_value_ptr
);
...
...
paddle/phi/kernels/sparse/gpu/conv_grad_kernel.cu
浏览文件 @
e8de9dfd
...
@@ -124,10 +124,44 @@ void Conv3dCooGradGPUKernel(const GPUContext& dev_ctx,
...
@@ -124,10 +124,44 @@ void Conv3dCooGradGPUKernel(const GPUContext& dev_ctx,
}
}
}
}
int
max_voxel
=
counter_ptr
[
kernel_size
];
if
(
!
subm
)
{
const
auto
&
x_dims
=
x
.
dims
();
Dims4D
d_x_dims
(
x_dims
[
0
],
x_dims
[
3
],
x_dims
[
2
],
x_dims
[
1
]);
int64_t
table_size
=
1
;
for
(
int
i
=
0
;
i
<
x_dims
.
size
()
-
1
;
i
++
)
{
table_size
*=
x_dims
[
i
];
}
DenseTensor
in_index_table
=
phi
::
Empty
<
int
>
(
dev_ctx
,
{
table_size
+
1
});
int
*
in_index_table_ptr
=
in_index_table
.
data
<
int
>
();
phi
::
backends
::
gpu
::
GpuMemsetAsync
(
in_index_table_ptr
,
0
,
sizeof
(
int
)
*
(
table_size
+
1
),
dev_ctx
.
stream
());
auto
config
=
phi
::
backends
::
gpu
::
GetGpuLaunchConfig1D
(
dev_ctx
,
x
.
nnz
(),
1
);
GetOutIndexTable
<
IntT
,
false
>
<<<
config
.
block_per_grid
,
config
.
thread_per_block
,
0
,
dev_ctx
.
stream
()
>>>
(
x
.
non_zero_indices
().
data
<
IntT
>
(),
x
.
nnz
(),
d_x_dims
,
nullptr
,
in_index_table_ptr
,
in_index_table_ptr
+
table_size
);
phi
::
backends
::
gpu
::
GpuMemcpyAsync
(
&
max_voxel
,
in_index_table_ptr
+
table_size
,
sizeof
(
int
),
gpuMemcpyDeviceToHost
,
dev_ctx
.
stream
());
dev_ctx
.
Wait
();
}
auto
config
=
auto
config
=
phi
::
backends
::
gpu
::
GetGpuLaunchConfig1D
(
dev_ctx
,
rulebook_len
,
1
);
phi
::
backends
::
gpu
::
GetGpuLaunchConfig1D
(
dev_ctx
,
rulebook_len
,
1
);
DenseTensor
unique_value
=
phi
::
Empty
<
int
>
(
DenseTensor
unique_value
=
phi
::
Empty
<
int
>
(
dev_ctx
,
{
static_cast
<
int
>
(
x_grad
->
nnz
()
*
kernel_size
*
2
)});
dev_ctx
,
{
static_cast
<
int
>
(
x_grad
->
nnz
()
*
max_voxel
*
kernel_size
*
2
)});
DenseTensor
out_index
=
DenseTensor
out_index
=
phi
::
Empty
<
int
>
(
dev_ctx
,
{
static_cast
<
int
>
(
x
.
nnz
()
*
2
)});
phi
::
Empty
<
int
>
(
dev_ctx
,
{
static_cast
<
int
>
(
x
.
nnz
()
*
2
)});
int
*
out_index_ptr
=
out_index
.
data
<
int
>
();
int
*
out_index_ptr
=
out_index
.
data
<
int
>
();
...
@@ -140,7 +174,7 @@ void Conv3dCooGradGPUKernel(const GPUContext& dev_ctx,
...
@@ -140,7 +174,7 @@ void Conv3dCooGradGPUKernel(const GPUContext& dev_ctx,
0
,
0
,
dev_ctx
.
stream
()
>>>
(
rulebook_len
,
dev_ctx
.
stream
()
>>>
(
rulebook_len
,
x
.
nnz
(),
x
.
nnz
(),
kernel_size
,
kernel_size
*
max_voxel
,
offsets
[
kernel_size
/
2
],
offsets
[
kernel_size
/
2
],
rulebook_ptr
,
rulebook_ptr
,
out_index_ptr
,
out_index_ptr
,
...
@@ -152,6 +186,7 @@ void Conv3dCooGradGPUKernel(const GPUContext& dev_ctx,
...
@@ -152,6 +186,7 @@ void Conv3dCooGradGPUKernel(const GPUContext& dev_ctx,
unique_value_ptr
,
unique_value_ptr
,
x
.
nnz
(),
x
.
nnz
(),
kernel_size
,
kernel_size
,
max_voxel
,
in_channels
,
in_channels
,
2
,
2
,
in_features_ptr
);
in_features_ptr
);
...
@@ -212,6 +247,7 @@ void Conv3dCooGradGPUKernel(const GPUContext& dev_ctx,
...
@@ -212,6 +247,7 @@ void Conv3dCooGradGPUKernel(const GPUContext& dev_ctx,
unique_value
.
data
<
int
>
(),
unique_value
.
data
<
int
>
(),
x_grad
->
nnz
(),
x_grad
->
nnz
(),
kernel_size
,
kernel_size
,
max_voxel
,
in_channels
,
in_channels
,
2
,
2
,
x_grad_values_ptr
);
x_grad_values_ptr
);
...
...
paddle/phi/kernels/sparse/gpu/conv_kernel.cu
浏览文件 @
e8de9dfd
...
@@ -66,7 +66,7 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx,
...
@@ -66,7 +66,7 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx,
const
int
in_channels
=
kernel_dims
[
3
];
const
int
in_channels
=
kernel_dims
[
3
];
const
int
out_channels
=
kernel_dims
[
4
];
const
int
out_channels
=
kernel_dims
[
4
];
DenseTensor
h_counter
,
h_offsets
;
DenseTensor
h_counter
,
h_offsets
;
h_counter
.
Resize
({
kernel_size
});
h_counter
.
Resize
({
kernel_size
+
1
});
h_offsets
.
Resize
({
kernel_size
+
1
});
h_offsets
.
Resize
({
kernel_size
+
1
});
int
*
h_counter_ptr
=
dev_ctx
.
template
HostAlloc
<
int
>(
&
h_counter
);
int
*
h_counter_ptr
=
dev_ctx
.
template
HostAlloc
<
int
>(
&
h_counter
);
int
*
h_offsets_ptr
=
dev_ctx
.
template
HostAlloc
<
int
>(
&
h_offsets
);
int
*
h_offsets_ptr
=
dev_ctx
.
template
HostAlloc
<
int
>(
&
h_offsets
);
...
@@ -74,7 +74,7 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx,
...
@@ -74,7 +74,7 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx,
// Second algorithm:
// Second algorithm:
// https://pdfs.semanticscholar.org/5125/a16039cabc6320c908a4764f32596e018ad3.pdf
// https://pdfs.semanticscholar.org/5125/a16039cabc6320c908a4764f32596e018ad3.pdf
// 1. product rulebook
// 1. product rulebook
DenseTensor
counter_per_kernel
=
phi
::
Empty
<
int
>
(
dev_ctx
,
{
kernel_size
});
DenseTensor
counter_per_kernel
=
phi
::
Empty
<
int
>
(
dev_ctx
,
{
kernel_size
+
1
});
DenseTensor
offsets_per_kernel
=
phi
::
Empty
<
int
>
(
dev_ctx
,
{
kernel_size
});
DenseTensor
offsets_per_kernel
=
phi
::
Empty
<
int
>
(
dev_ctx
,
{
kernel_size
});
DenseTensor
out_index
=
phi
::
Empty
<
int
>
(
dev_ctx
,
{
1
});
DenseTensor
out_index
=
phi
::
Empty
<
int
>
(
dev_ctx
,
{
1
});
DenseTensor
unique_value
=
phi
::
Empty
<
int
>
(
dev_ctx
,
{
1
});
DenseTensor
unique_value
=
phi
::
Empty
<
int
>
(
dev_ctx
,
{
1
});
...
@@ -143,26 +143,6 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx,
...
@@ -143,26 +143,6 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx,
T
*
out_values_ptr
=
out_values
->
data
<
T
>
();
T
*
out_values_ptr
=
out_values
->
data
<
T
>
();
set_zero
(
dev_ctx
,
out_values
,
static_cast
<
T
>
(
0.0
f
));
set_zero
(
dev_ctx
,
out_values
,
static_cast
<
T
>
(
0.0
f
));
if
(
subm
)
{
auto
config
=
phi
::
backends
::
gpu
::
GetGpuLaunchConfig1D
(
dev_ctx
,
rulebook_len
,
1
);
unique_value
.
ResizeAndAllocate
(
{
static_cast
<
int
>
(
out
->
nnz
()
*
kernel_size
)});
out_index
.
ResizeAndAllocate
({
static_cast
<
int
>
(
rulebook_len
)});
int
*
out_index_ptr
=
out_index
.
data
<
int
>
();
int
*
unique_value_ptr
=
unique_value
.
data
<
int
>
();
phi
::
backends
::
gpu
::
GpuMemsetAsync
(
out_index_ptr
,
0
,
sizeof
(
int
)
*
rulebook_len
,
dev_ctx
.
stream
());
GroupIndexs
<<<
config
.
block_per_grid
,
config
.
thread_per_block
,
0
,
dev_ctx
.
stream
()
>>>
(
rulebook_len
,
kernel_size
,
rulebook_ptr
+
rulebook_len
,
out_index_ptr
,
unique_value_ptr
);
}
const
T
*
kernel_ptr
=
kernel
.
data
<
T
>
();
const
T
*
kernel_ptr
=
kernel
.
data
<
T
>
();
for
(
int
i
=
0
;
i
<
kernel_size
;
i
++
)
{
for
(
int
i
=
0
;
i
<
kernel_size
;
i
++
)
{
if
(
h_counter_ptr
[
i
]
<=
0
)
{
if
(
h_counter_ptr
[
i
]
<=
0
)
{
...
@@ -196,6 +176,7 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx,
...
@@ -196,6 +176,7 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx,
unique_value
.
data
<
int
>
(),
unique_value
.
data
<
int
>
(),
out
->
nnz
(),
out
->
nnz
(),
kernel_size
,
kernel_size
,
h_counter_ptr
[
kernel_size
],
out_channels
,
out_channels
,
1
,
1
,
out_values_ptr
);
out_values_ptr
);
...
...
saxon_zh
@saxon_zh
mentioned in commit
8fbe97e4
·
9月 22, 2022
mentioned in commit
8fbe97e4
mentioned in commit 8fbe97e446b2cdb4c06a200af123712fb667e238
开关提交列表
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录