Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
913f40ee
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
913f40ee
编写于
2月 10, 2023
作者:
Z
zhangkaihuo
提交者:
GitHub
2月 10, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[cherry-pick] remove if constexpr(), which is not supported on gcc54 (#50421)
att, cherry-pick #48563
上级
eb610740
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
73 addition
and
54 deletion
+73
-54
paddle/phi/kernels/sparse/gpu/conv_kernel.cu
paddle/phi/kernels/sparse/gpu/conv_kernel.cu
+12
-54
paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.h
paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.h
+61
-0
未找到文件。
paddle/phi/kernels/sparse/gpu/conv_kernel.cu
浏览文件 @
913f40ee
...
@@ -150,60 +150,18 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx,
...
@@ -150,60 +150,18 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx,
const
IntT
*
gather_indices
=
rulebook_ptr
+
h_offsets_ptr
[
i
];
const
IntT
*
gather_indices
=
rulebook_ptr
+
h_offsets_ptr
[
i
];
const
IntT
*
scatter_indices
=
const
IntT
*
scatter_indices
=
rulebook_ptr
+
rulebook_len
+
h_offsets_ptr
[
i
];
rulebook_ptr
+
rulebook_len
+
h_offsets_ptr
[
i
];
dispatchKernel
(
dev_ctx
,
if
constexpr
(
std
::
is_same
<
T
,
phi
::
dtype
::
float16
>::
value
&&
x
.
non_zero_elements
().
data
<
T
>
(),
std
::
is_same
<
IntT
,
int32_t
>::
value
)
{
tmp_kernel_ptr
,
fp16_gather_gemm_scatter
gather_gemm_scatter
=
out_values_ptr
,
getBestFp16Kernel
(
M
,
N
,
K
);
out_values_ptr
,
gather_gemm_scatter
(
M
,
dev_ctx
,
N
,
reinterpret_cast
<
const
cutlass
::
half_t
*>
(
K
,
x
.
non_zero_elements
().
data
<
T
>
()),
gather_indices
,
reinterpret_cast
<
const
cutlass
::
half_t
*>
(
tmp_kernel_ptr
),
scatter_indices
,
reinterpret_cast
<
cutlass
::
half_t
*>
(
out_values_ptr
),
cutlass
,
reinterpret_cast
<
cutlass
::
half_t
*>
(
out_values_ptr
),
x
.
dtype
());
M
,
N
,
K
,
static_cast
<
const
int32_t
*>
(
gather_indices
),
static_cast
<
const
int32_t
*>
(
scatter_indices
),
static_cast
<
cutlass
::
half_t
>
(
1
),
static_cast
<
cutlass
::
half_t
>
(
1
));
}
if
constexpr
(
std
::
is_same
<
T
,
float
>::
value
&&
std
::
is_same
<
IntT
,
int32_t
>::
value
)
{
fp32_gather_gemm_scatter
gather_gemm_scatter
=
getBestFp32Kernel
(
M
,
N
,
K
,
dev_ctx
.
GetComputeCapability
());
gather_gemm_scatter
(
dev_ctx
,
x
.
non_zero_elements
().
data
<
T
>
(),
tmp_kernel_ptr
,
out_values_ptr
,
out_values_ptr
,
M
,
N
,
K
,
gather_indices
,
scatter_indices
,
static_cast
<
T
>
(
1
),
static_cast
<
T
>
(
1
));
}
if
constexpr
(
std
::
is_same
<
T
,
double
>::
value
&&
std
::
is_same
<
IntT
,
int32_t
>::
value
)
{
fp64_gather_gemm_scatter
gather_gemm_scatter
=
getBestFp64Kernel
(
M
,
N
,
K
);
gather_gemm_scatter
(
dev_ctx
,
x
.
non_zero_elements
().
data
<
T
>
(),
tmp_kernel_ptr
,
out_values_ptr
,
out_values_ptr
,
M
,
N
,
K
,
gather_indices
,
scatter_indices
,
static_cast
<
T
>
(
1
),
static_cast
<
T
>
(
1
));
}
}
}
}
else
{
}
else
{
#endif
#endif
...
...
paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.h
浏览文件 @
913f40ee
...
@@ -23,6 +23,7 @@
...
@@ -23,6 +23,7 @@
#include "cutlass/util/device_memory.h"
#include "cutlass/util/device_memory.h"
#include "examples/common/helper.h"
#include "examples/common/helper.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/data_type.h"
namespace
phi
{
namespace
phi
{
namespace
sparse
{
namespace
sparse
{
typedef
void
(
*
fp16_gather_gemm_scatter
)(
const
GPUContext
&
dev_ctx
,
typedef
void
(
*
fp16_gather_gemm_scatter
)(
const
GPUContext
&
dev_ctx
,
...
@@ -115,6 +116,66 @@ void launchKernel(const GPUContext& dev_ctx,
...
@@ -115,6 +116,66 @@ void launchKernel(const GPUContext& dev_ctx,
CUTLASS_CHECK
(
status
);
CUTLASS_CHECK
(
status
);
gemm_op
(
dev_ctx
.
stream
());
gemm_op
(
dev_ctx
.
stream
());
}
}
static
void
dispatchKernel
(
const
GPUContext
&
dev_ctx
,
const
void
*
const
a
,
const
void
*
const
b
,
const
void
*
const
c
,
void
*
const
d
,
const
int
m
,
const
int
n
,
const
int
k
,
const
void
*
a_indices
,
const
void
*
c_d_indices
,
const
bool
cutlass
,
const
phi
::
DataType
type
)
{
if
(
!
cutlass
)
return
;
if
(
type
==
phi
::
DataType
::
FLOAT16
)
{
fp16_gather_gemm_scatter
gather_gemm_scatter
=
getBestFp16Kernel
(
m
,
n
,
k
);
gather_gemm_scatter
(
dev_ctx
,
static_cast
<
const
cutlass
::
half_t
*>
(
a
),
static_cast
<
const
cutlass
::
half_t
*>
(
b
),
static_cast
<
const
cutlass
::
half_t
*>
(
c
),
static_cast
<
cutlass
::
half_t
*>
(
d
),
m
,
n
,
k
,
static_cast
<
const
int32_t
*>
(
a_indices
),
static_cast
<
const
int32_t
*>
(
c_d_indices
),
static_cast
<
cutlass
::
half_t
>
(
1
),
static_cast
<
cutlass
::
half_t
>
(
1
));
}
else
if
(
type
==
phi
::
DataType
::
FLOAT32
)
{
fp32_gather_gemm_scatter
gather_gemm_scatter
=
getBestFp32Kernel
(
m
,
n
,
k
,
dev_ctx
.
GetComputeCapability
());
gather_gemm_scatter
(
dev_ctx
,
static_cast
<
const
float
*>
(
a
),
static_cast
<
const
float
*>
(
b
),
static_cast
<
const
float
*>
(
c
),
static_cast
<
float
*>
(
d
),
m
,
n
,
k
,
static_cast
<
const
int32_t
*>
(
a_indices
),
static_cast
<
const
int32_t
*>
(
c_d_indices
),
static_cast
<
float
>
(
1
),
static_cast
<
float
>
(
1
));
}
else
if
(
type
==
phi
::
DataType
::
FLOAT64
)
{
fp64_gather_gemm_scatter
gather_gemm_scatter
=
getBestFp64Kernel
(
m
,
n
,
k
);
gather_gemm_scatter
(
dev_ctx
,
static_cast
<
const
double
*>
(
a
),
static_cast
<
const
double
*>
(
b
),
static_cast
<
const
double
*>
(
c
),
static_cast
<
double
*>
(
d
),
m
,
n
,
k
,
static_cast
<
const
int32_t
*>
(
a_indices
),
static_cast
<
const
int32_t
*>
(
c_d_indices
),
static_cast
<
double
>
(
1
),
static_cast
<
double
>
(
1
));
}
}
struct
cutlass_tensorop_h1688gemm_128x64_32x2_nn_align8
{
struct
cutlass_tensorop_h1688gemm_128x64_32x2_nn_align8
{
using
Gemm
=
cutlass
::
gemm
::
device
::
GemmUniversal
<
using
Gemm
=
cutlass
::
gemm
::
device
::
GemmUniversal
<
cutlass
::
half_t
,
cutlass
::
half_t
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录