Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
227a5112
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2310
Star
20933
Fork
5423
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
227a5112
编写于
12月 14, 2022
作者:
Z
zhangkaihuo
提交者:
GitHub
12月 14, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Sparse]Optimize performance of sparse conv on T4 (#49009)
上级
032cbfc2
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
140 addition
and
9 deletion
+140
-9
paddle/phi/kernels/sparse/gpu/conv.cu.h
paddle/phi/kernels/sparse/gpu/conv.cu.h
+105
-5
paddle/phi/kernels/sparse/gpu/conv_kernel.cu
paddle/phi/kernels/sparse/gpu/conv_kernel.cu
+2
-2
paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.cu
paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.cu
+7
-1
paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.h
paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.h
+26
-1
未找到文件。
paddle/phi/kernels/sparse/gpu/conv.cu.h
浏览文件 @
227a5112
...
...
@@ -15,8 +15,14 @@ limitations under the License. */
#pragma once
#include <thrust/remove.h>
#include <thrust/sort.h>
#include <thrust/unique.h>
#ifdef __NVCC__
#include <cub/block/block_scan.cuh>
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace
cub
=
hipcub
;
#endif
#include "paddle/phi/kernels/sparse/conv_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
...
...
@@ -199,6 +205,88 @@ __global__ void UniqueKernel(const IntT* in_indexs,
}
}
inline
__device__
uint32_t
BitCount
(
const
uint32_t
data
)
{
uint32_t
count
=
data
;
count
=
(
count
&
0x55555555
)
+
((
count
>>
1
)
&
0x55555555
);
count
=
(
count
&
0x33333333
)
+
((
count
>>
2
)
&
0x33333333
);
count
=
(
count
&
0x0f0f0f0f
)
+
((
count
>>
4
)
&
0x0f0f0f0f
);
count
=
(
count
&
0x00ff00ff
)
+
((
count
>>
8
)
&
0x00ff00ff
);
count
=
(
count
&
0x0000ffff
)
+
((
count
>>
16
)
&
0x0000ffff
);
return
count
;
}
static
__global__
void
GetOutIndexsCounter
(
const
int
*
flags
,
const
int
n
,
int
*
out
)
{
int
tid
=
threadIdx
.
x
+
blockDim
.
x
*
blockIdx
.
x
;
__shared__
int
block_count
;
if
(
threadIdx
.
x
==
0
)
{
block_count
=
0
;
}
__syncthreads
();
if
(
tid
<
n
)
{
// get the count of 1 in flags[tid]
uint32_t
count
=
BitCount
(
static_cast
<
uint32_t
>
(
flags
[
tid
]));
// add to block_count
// TODO(zhangkaihuo): replace with block reduce_sum
atomicAdd
(
&
block_count
,
static_cast
<
int
>
(
count
));
}
__syncthreads
();
// write to out
if
(
threadIdx
.
x
==
0
)
{
out
[
blockIdx
.
x
]
=
block_count
;
}
}
template
<
int
BS
>
__global__
void
GetOutIndexs
(
const
int
*
flags
,
const
int
n
,
const
int
*
offsets
,
const
int
out_nnz
,
int
*
out
)
{
int
tid
=
threadIdx
.
x
+
blockDim
.
x
*
blockIdx
.
x
;
__shared__
int
block_counts
[
BS
];
__shared__
int
block_outs
[
BS
*
32
];
int
count
=
0
;
if
(
tid
<
n
)
{
// get the count of 1 in flags[tid]
int
flag
=
flags
[
tid
];
count
=
BitCount
(
static_cast
<
uint32_t
>
(
flag
));
}
// call block prefix_sum
// using namespace cub;
typedef
cub
::
BlockScan
<
int
,
BS
>
BlockScan
;
__shared__
typename
BlockScan
::
TempStorage
temp_storage
;
BlockScan
(
temp_storage
).
ExclusiveSum
(
count
,
count
);
__syncthreads
();
// write index to out
if
(
tid
<
n
)
{
// get the count of 1 in flags[tid]
int
flag
=
flags
[
tid
];
// int j = block_counts[threadIdx.x];
int
j
=
count
;
// TODO(zhangkaihuo): opt the loop
for
(
int
i
=
0
;
i
<
32
;
++
i
)
{
if
((
1
&
(
flag
>>
i
))
==
1
)
{
block_outs
[
j
++
]
=
(
tid
<<
5
)
+
i
;
}
}
}
__syncthreads
();
// write to block_outs
int
start
=
offsets
[
blockIdx
.
x
];
int
end
=
blockIdx
.
x
==
gridDim
.
x
-
1
?
out_nnz
:
offsets
[
blockIdx
.
x
+
1
];
for
(
int
i
=
threadIdx
.
x
;
i
<
end
-
start
;
i
+=
blockDim
.
x
)
{
out
[
start
+
i
]
=
block_outs
[
i
];
}
}
template
<
typename
IntT
>
__global__
void
GroupIndexs
(
const
int
*
out_index_table
,
const
int
n
,
...
...
@@ -725,13 +813,25 @@ int ProductRuleBook(const Context& dev_ctx,
gpuMemcpyDeviceToHost
,
dev_ctx
.
stream
());
dev_ctx
.
Wait
();
const
int
threads
=
256
;
const
int
blocks
=
(
index_flags
.
numel
()
+
threads
-
1
)
/
threads
;
GetOutIndexsCounter
<<<
blocks
,
threads
,
0
,
dev_ctx
.
stream
()
>>>
(
index_flags_ptr
,
index_flags
.
numel
(),
out_index_table_ptr
);
#ifdef PADDLE_WITH_HIP
thrust
::
sort
(
thrust
::
hip
::
par
.
on
(
dev_ctx
.
stream
()),
thrust
::
exclusive_scan
(
thrust
::
hip
::
par
.
on
(
dev_ctx
.
stream
()),
#else
thrust
::
sort
(
thrust
::
cuda
::
par
.
on
(
dev_ctx
.
stream
()),
thrust
::
exclusive_scan
(
thrust
::
cuda
::
par
.
on
(
dev_ctx
.
stream
()),
#endif
out_index_ptr
,
out_index_ptr
+
out_nnz
);
out_index_table_ptr
,
out_index_table_ptr
+
blocks
,
out_index_table_ptr
);
GetOutIndexs
<
threads
>
<<<
blocks
,
threads
,
0
,
dev_ctx
.
stream
()
>>>
(
index_flags_ptr
,
index_flags
.
numel
(),
out_index_table_ptr
,
out_nnz
,
out_index_ptr
);
const
int64_t
sparse_dim
=
4
;
phi
::
DenseTensor
out_indices
=
...
...
paddle/phi/kernels/sparse/gpu/conv_kernel.cu
浏览文件 @
227a5112
...
...
@@ -125,7 +125,7 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx,
#ifdef PADDLE_WITH_CUTLASS
bool
cutlass
=
true
;
if
(
dev_ctx
.
GetComputeCapability
()
<
80
)
cutlass
=
false
;
if
(
dev_ctx
.
GetComputeCapability
()
<
75
)
cutlass
=
false
;
if
(
in_channels
%
4
!=
0
||
out_channels
%
4
!=
0
)
{
if
(
std
::
is_same
<
T
,
phi
::
dtype
::
float16
>::
value
)
cutlass
=
false
;
if
(
std
::
is_same
<
T
,
float
>::
value
)
cutlass
=
false
;
...
...
@@ -173,7 +173,7 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx,
if
constexpr
(
std
::
is_same
<
T
,
float
>::
value
&&
std
::
is_same
<
IntT
,
int32_t
>::
value
)
{
fp32_gather_gemm_scatter
gather_gemm_scatter
=
getBestFp32Kernel
(
M
,
N
,
K
);
getBestFp32Kernel
(
M
,
N
,
K
,
dev_ctx
.
GetComputeCapability
()
);
gather_gemm_scatter
(
dev_ctx
,
x
.
non_zero_elements
().
data
<
T
>
(),
tmp_kernel_ptr
,
...
...
paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.cu
浏览文件 @
227a5112
...
...
@@ -72,7 +72,13 @@ fp16_gather_gemm_scatter getBestFp16Kernel(const int M,
}
fp32_gather_gemm_scatter
getBestFp32Kernel
(
const
int
M
,
const
int
N
,
const
int
K
)
{
const
int
K
,
const
int
SM
)
{
if
(
SM
==
75
)
{
return
launchKernel
<
float
,
cutlass_tensorop_s1688gemm_f16_64x64_32x2_nn_align4
::
Gemm
>
;
}
if
(
K
==
4
&&
N
==
16
)
{
return
launchKernel
<
float
,
...
...
paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.h
浏览文件 @
227a5112
...
...
@@ -66,7 +66,8 @@ fp16_gather_gemm_scatter getBestFp16Kernel(const int M,
const
int
N
);
fp32_gather_gemm_scatter
getBestFp32Kernel
(
const
int
M
,
const
int
K
,
const
int
N
);
const
int
N
,
const
int
SM
);
fp64_gather_gemm_scatter
getBestFp64Kernel
(
const
int
M
,
const
int
K
,
const
int
N
);
...
...
@@ -550,6 +551,30 @@ struct cutlass_tensorop_d884gemm_32x16_16x5_nn_align1 {
false
,
true
>
;
};
// sm75
struct
cutlass_tensorop_s1688gemm_f16_64x64_32x2_nn_align4
{
using
Gemm
=
cutlass
::
gemm
::
device
::
GemmUniversal
<
cutlass
::
half_t
,
cutlass
::
layout
::
RowMajor
,
cutlass
::
half_t
,
cutlass
::
layout
::
RowMajor
,
float
,
cutlass
::
layout
::
RowMajor
,
float
,
cutlass
::
arch
::
OpClassTensorOp
,
cutlass
::
arch
::
Sm75
,
cutlass
::
gemm
::
GemmShape
<
64
,
64
,
32
>
,
cutlass
::
gemm
::
GemmShape
<
32
,
32
,
32
>
,
cutlass
::
gemm
::
GemmShape
<
16
,
8
,
8
>
,
cutlass
::
epilogue
::
thread
::
LinearCombination
<
float
,
4
,
float
,
float
>
,
cutlass
::
gemm
::
threadblock
::
GemmIdentityThreadblockSwizzle
<
8
>
,
2
,
8
,
8
,
cutlass
::
arch
::
OpMultiplyAdd
>
;
};
}
// namespace sparse
}
// namespace phi
#endif
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录