BaiXuePrincess / Paddle (forked from PaddlePaddle / Paddle)
Unverified commit 913f40ee
Authored on Feb 10, 2023 by zhangkaihuo · Committed via GitHub on Feb 10, 2023
[cherry-pick] remove if constexpr(), which is not supported on gcc54 (#50421)
att, cherry-pick #48563
Parent: eb610740
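Background for the change: `if constexpr` is a C++17 construct, and GCC 5.4 predates C++17 support, so it rejects the construct outright; the compile-time type dispatch in conv_kernel.cu therefore has to become an ordinary runtime branch on the tensor's dtype. Below is a minimal standalone sketch of the two dispatch styles, not Paddle code; the names and the DType enum are hypothetical stand-ins (DType mimics phi::DataType).

#include <cstdio>
#include <type_traits>

// C++17 style: the non-matching branch is discarded at compile time.
// GCC 5.4 does not understand "if constexpr" and fails to compile this.
template <typename T>
void dispatch_compile_time() {
  if constexpr (std::is_same<T, float>::value) {
    std::printf("fp32 kernel\n");
  } else if constexpr (std::is_same<T, double>::value) {
    std::printf("fp64 kernel\n");
  }
}

// Pre-C++17 alternative, in the spirit of the dispatchKernel() this commit
// adds: erase the element type to void* and branch at run time on a dtype
// tag. DType is a hypothetical stand-in for phi::DataType.
enum class DType { kFloat32, kFloat64 };

void dispatch_runtime(const void* /*data*/, DType type) {
  if (type == DType::kFloat32) {
    std::printf("fp32 kernel\n");
  } else if (type == DType::kFloat64) {
    std::printf("fp64 kernel\n");
  }
}

int main() {
  dispatch_compile_time<float>();              // resolved at compile time (needs C++17)
  dispatch_runtime(nullptr, DType::kFloat64);  // resolved at run time; this is
                                               // the pattern GCC 5.4 can compile
}

Moving the branches behind void* parameters, as the new dispatchKernel in gather_gemm_scatter.h does, keeps every branch well-formed for every input type without relying on constexpr-if.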
Showing 2 changed files with 73 additions and 54 deletions (+73 -54):
paddle/phi/kernels/sparse/gpu/conv_kernel.cu (+12 -54)
paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.h (+61 -0)
paddle/phi/kernels/sparse/gpu/conv_kernel.cu
@@ -150,60 +150,18 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx,
           const IntT* gather_indices = rulebook_ptr + h_offsets_ptr[i];
           const IntT* scatter_indices =
               rulebook_ptr + rulebook_len + h_offsets_ptr[i];
-          if constexpr (std::is_same<T, phi::dtype::float16>::value &&
-                        std::is_same<IntT, int32_t>::value) {
-            fp16_gather_gemm_scatter gather_gemm_scatter =
-                getBestFp16Kernel(M, N, K);
-            gather_gemm_scatter(
-                dev_ctx,
-                reinterpret_cast<const cutlass::half_t*>(
-                    x.non_zero_elements().data<T>()),
-                reinterpret_cast<const cutlass::half_t*>(tmp_kernel_ptr),
-                reinterpret_cast<cutlass::half_t*>(out_values_ptr),
-                reinterpret_cast<cutlass::half_t*>(out_values_ptr),
-                M,
-                N,
-                K,
-                static_cast<const int32_t*>(gather_indices),
-                static_cast<const int32_t*>(scatter_indices),
-                static_cast<cutlass::half_t>(1),
-                static_cast<cutlass::half_t>(1));
-          }
-          if constexpr (std::is_same<T, float>::value &&
-                        std::is_same<IntT, int32_t>::value) {
-            fp32_gather_gemm_scatter gather_gemm_scatter =
-                getBestFp32Kernel(M, N, K, dev_ctx.GetComputeCapability());
-            gather_gemm_scatter(dev_ctx,
-                                x.non_zero_elements().data<T>(),
-                                tmp_kernel_ptr,
-                                out_values_ptr,
-                                out_values_ptr,
-                                M,
-                                N,
-                                K,
-                                gather_indices,
-                                scatter_indices,
-                                static_cast<T>(1),
-                                static_cast<T>(1));
-          }
-          if constexpr (std::is_same<T, double>::value &&
-                        std::is_same<IntT, int32_t>::value) {
-            fp64_gather_gemm_scatter gather_gemm_scatter =
-                getBestFp64Kernel(M, N, K);
-            gather_gemm_scatter(dev_ctx,
-                                x.non_zero_elements().data<T>(),
-                                tmp_kernel_ptr,
-                                out_values_ptr,
-                                out_values_ptr,
-                                M,
-                                N,
-                                K,
-                                gather_indices,
-                                scatter_indices,
-                                static_cast<T>(1),
-                                static_cast<T>(1));
-          }
+          dispatchKernel(dev_ctx,
+                         x.non_zero_elements().data<T>(),
+                         tmp_kernel_ptr,
+                         out_values_ptr,
+                         out_values_ptr,
+                         M,
+                         N,
+                         K,
+                         gather_indices,
+                         scatter_indices,
+                         cutlass,
+                         x.dtype());
         }
       } else {
 #endif
paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.h
@@ -23,6 +23,7 @@
#include "cutlass/util/device_memory.h"
#include "examples/common/helper.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/data_type.h"
namespace
phi
{
namespace
sparse
{
typedef
void
(
*
fp16_gather_gemm_scatter
)(
const
GPUContext
&
dev_ctx
,
@@ -115,6 +116,66 @@ void launchKernel(const GPUContext& dev_ctx,
   CUTLASS_CHECK(status);
   gemm_op(dev_ctx.stream());
 }
+static void dispatchKernel(const GPUContext& dev_ctx,
+                           const void* const a,
+                           const void* const b,
+                           const void* const c,
+                           void* const d,
+                           const int m,
+                           const int n,
+                           const int k,
+                           const void* a_indices,
+                           const void* c_d_indices,
+                           const bool cutlass,
+                           const phi::DataType type) {
+  if (!cutlass) return;
+  if (type == phi::DataType::FLOAT16) {
+    fp16_gather_gemm_scatter gather_gemm_scatter =
+        getBestFp16Kernel(m, n, k);
+    gather_gemm_scatter(dev_ctx,
+                        static_cast<const cutlass::half_t*>(a),
+                        static_cast<const cutlass::half_t*>(b),
+                        static_cast<const cutlass::half_t*>(c),
+                        static_cast<cutlass::half_t*>(d),
+                        m,
+                        n,
+                        k,
+                        static_cast<const int32_t*>(a_indices),
+                        static_cast<const int32_t*>(c_d_indices),
+                        static_cast<cutlass::half_t>(1),
+                        static_cast<cutlass::half_t>(1));
+  } else if (type == phi::DataType::FLOAT32) {
+    fp32_gather_gemm_scatter gather_gemm_scatter =
+        getBestFp32Kernel(m, n, k, dev_ctx.GetComputeCapability());
+    gather_gemm_scatter(dev_ctx,
+                        static_cast<const float*>(a),
+                        static_cast<const float*>(b),
+                        static_cast<const float*>(c),
+                        static_cast<float*>(d),
+                        m,
+                        n,
+                        k,
+                        static_cast<const int32_t*>(a_indices),
+                        static_cast<const int32_t*>(c_d_indices),
+                        static_cast<float>(1),
+                        static_cast<float>(1));
+  } else if (type == phi::DataType::FLOAT64) {
+    fp64_gather_gemm_scatter gather_gemm_scatter =
+        getBestFp64Kernel(m, n, k);
+    gather_gemm_scatter(dev_ctx,
+                        static_cast<const double*>(a),
+                        static_cast<const double*>(b),
+                        static_cast<const double*>(c),
+                        static_cast<double*>(d),
+                        m,
+                        n,
+                        k,
+                        static_cast<const int32_t*>(a_indices),
+                        static_cast<const int32_t*>(c_d_indices),
+                        static_cast<double>(1),
+                        static_cast<double>(1));
+  }
+}
 struct cutlass_tensorop_h1688gemm_128x64_32x2_nn_align8 {
   using Gemm = cutlass::gemm::device::GemmUniversal<
       cutlass::half_t,