Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
5158fa4f
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
5158fa4f
编写于
11月 01, 2022
作者:
U
umiswing
提交者:
GitHub
11月 01, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
summer-ospp 2022: 飞桨PaddlePaddle Sparse Conv开发和优化: gather-gemm-scatter fuse (#46679)
上级
60e0c506
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
941 addition
and
56 deletion
+941
-56
cmake/external/cutlass.cmake
cmake/external/cutlass.cmake
+43
-0
cmake/third_party.cmake
cmake/third_party.cmake
+10
-0
paddle/phi/kernels/sparse/gpu/conv_kernel.cu
paddle/phi/kernels/sparse/gpu/conv_kernel.cu
+145
-56
paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.cu
paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.cu
+188
-0
paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.h
paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.h
+555
-0
未找到文件。
cmake/external/cutlass.cmake
0 → 100644
浏览文件 @
5158fa4f
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
include(ExternalProject)

# Location of the CUTLASS checkout and the upstream revision to fetch.
set(CUTLASS_PREFIX_DIR ${THIRD_PARTY_PATH}/cutlass)
set(CUTLASS_REPOSITORY https://github.com/NVIDIA/cutlass.git)
set(CUTLASS_TAG v2.9.1)

# Header search paths inside the checkout: the repository root, the public
# include/ tree and the utility headers under tools/util/include/.
include_directories(
  "${THIRD_PARTY_PATH}/cutlass/src/extern_cutlass/"
  "${THIRD_PARTY_PATH}/cutlass/src/extern_cutlass/include/"
  "${THIRD_PARTY_PATH}/cutlass/src/extern_cutlass/tools/util/include/")

# Sources guard their CUTLASS-only code with #ifdef PADDLE_WITH_CUTLASS.
add_definitions("-DPADDLE_WITH_CUTLASS")

# Only the sources are fetched: the update, configure, build, install and
# test steps are all set to empty commands.
ExternalProject_Add(
  extern_cutlass
  ${EXTERNAL_PROJECT_LOG_ARGS} ${SHALLOW_CLONE}
  GIT_REPOSITORY ${CUTLASS_REPOSITORY}
  GIT_TAG "${CUTLASS_TAG}"
  PREFIX ${CUTLASS_PREFIX_DIR}
  UPDATE_COMMAND ""
  CONFIGURE_COMMAND ""
  BUILD_COMMAND ""
  INSTALL_COMMAND ""
  TEST_COMMAND "")

# Interface target that other targets can depend on to make sure the
# checkout exists before they are built.
add_library(cutlass INTERFACE)
add_dependencies(cutlass extern_cutlass)
cmake/third_party.cmake
浏览文件 @
5158fa4f
...
...
@@ -505,4 +505,14 @@ if(WITH_CUSPARSELT)
list
(
APPEND third_party_deps extern_cusparselt
)
endif
()
# CUTLASS is only pulled in for GPU builds that are not ARM, Windows or macOS,
# and only when the CUDA compiler is at least version 11.0.
if(WITH_GPU
   AND NOT WITH_ARM
   AND NOT WIN32
   AND NOT APPLE)
  # Compare as a version string (e.g. "11.2.152"), not as a plain number, and
  # pass the variable by name so an unset value cannot break the if() syntax.
  if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 11.0)
    include(external/cutlass) # download cutlass
    list(APPEND third_party_deps extern_cutlass)
  endif()
endif()

add_custom_target(third_party ALL DEPENDS ${third_party_deps})
paddle/phi/kernels/sparse/gpu/conv_kernel.cu
浏览文件 @
5158fa4f
...
...
@@ -22,6 +22,9 @@ limitations under the License. */
#include "paddle/phi/kernels/funcs/scatter.cu.h"
#include "paddle/phi/kernels/funcs/sparse/scatter.cu.h"
#include "paddle/phi/kernels/sparse/gpu/conv.cu.h"
#ifdef PADDLE_WITH_CUTLASS
#include "paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.h"
#endif
#include "glog/logging.h"
...
...
@@ -120,29 +123,6 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx,
dev_ctx
,
x
,
key
,
tmp_rulebook
,
h_counter
,
out
,
rulebook
,
counter
);
}
// 2. gather
phi
::
DenseTensor
in_features
=
phi
::
Empty
<
T
>
(
dev_ctx
,
{
rulebook_len
,
in_channels
});
phi
::
DenseTensor
out_features
=
phi
::
Empty
<
T
>
(
dev_ctx
,
{
rulebook_len
,
out_channels
});
T
*
in_features_ptr
=
in_features
.
data
<
T
>
();
T
*
out_features_ptr
=
out_features
.
data
<
T
>
();
phi
::
funcs
::
SetConstant
<
GPUContext
,
T
>
set_zero
;
set_zero
(
dev_ctx
,
&
out_features
,
static_cast
<
T
>
(
0.0
f
));
Gather
<
T
,
IntT
>
(
dev_ctx
,
x
.
values
().
data
<
T
>
(),
rulebook_ptr
,
rulebook_len
,
in_channels
,
in_features_ptr
);
// 3. call gemm for every weight
auto
blas
=
phi
::
funcs
::
GetBlas
<
GPUContext
,
T
>
(
dev_ctx
);
auto
*
out_values
=
out
->
mutable_values
();
T
*
out_values_ptr
=
out_values
->
data
<
T
>
();
set_zero
(
dev_ctx
,
out_values
,
static_cast
<
T
>
(
0.0
f
));
if
(
subm
)
{
auto
config
=
phi
::
backends
::
gpu
::
GetGpuLaunchConfig1D
(
dev_ctx
,
rulebook_len
,
1
);
...
...
@@ -162,43 +142,152 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx,
out_index_ptr
,
unique_value_ptr
);
}
#ifdef PADDLE_WITH_CUTLASS
bool
cutlass
=
true
;
if
(
dev_ctx
.
GetComputeCapability
()
<
80
)
cutlass
=
false
;
if
(
in_channels
%
4
!=
0
||
out_channels
%
4
!=
0
)
{
if
(
std
::
is_same
<
T
,
phi
::
dtype
::
float16
>::
value
)
cutlass
=
false
;
if
(
std
::
is_same
<
T
,
float
>::
value
)
cutlass
=
false
;
}
if
(
!
std
::
is_same
<
IntT
,
int32_t
>::
value
)
cutlass
=
false
;
if
(
cutlass
)
{
auto
*
out_values
=
out
->
mutable_non_zero_elements
();
T
*
out_values_ptr
=
out_values
->
data
<
T
>
();
phi
::
funcs
::
SetConstant
<
GPUContext
,
T
>
set_zero
;
set_zero
(
dev_ctx
,
out_values
,
static_cast
<
T
>
(
0.0
f
));
const
T
*
kernel_ptr
=
kernel
.
data
<
T
>
();
for
(
int
i
=
0
;
i
<
kernel_size
;
i
++
)
{
if
(
h_counter_ptr
[
i
]
<=
0
)
{
continue
;
}
const
T
*
kernel_ptr
=
kernel
.
data
<
T
>
();
for
(
int
i
=
0
;
i
<
kernel_size
;
i
++
)
{
if
(
h_counter_ptr
[
i
]
<=
0
)
{
continue
;
const
int
M
=
h_counter_ptr
[
i
];
const
int
K
=
in_channels
;
const
int
N
=
out_channels
;
const
T
*
tmp_kernel_ptr
=
kernel_ptr
+
i
*
K
*
N
;
const
IntT
*
gather_indices
=
rulebook_ptr
+
h_offsets_ptr
[
i
];
const
IntT
*
scatter_indices
=
rulebook_ptr
+
rulebook_len
+
h_offsets_ptr
[
i
];
if
constexpr
(
std
::
is_same
<
T
,
phi
::
dtype
::
float16
>::
value
&&
std
::
is_same
<
IntT
,
int32_t
>::
value
)
{
fp16_gather_gemm_scatter
gather_gemm_scatter
=
getBestFp16Kernel
(
M
,
N
,
K
);
gather_gemm_scatter
(
dev_ctx
,
reinterpret_cast
<
const
cutlass
::
half_t
*>
(
x
.
non_zero_elements
().
data
<
T
>
()),
reinterpret_cast
<
const
cutlass
::
half_t
*>
(
tmp_kernel_ptr
),
reinterpret_cast
<
cutlass
::
half_t
*>
(
out_values_ptr
),
reinterpret_cast
<
cutlass
::
half_t
*>
(
out_values_ptr
),
M
,
N
,
K
,
static_cast
<
const
int32_t
*>
(
gather_indices
),
static_cast
<
const
int32_t
*>
(
scatter_indices
),
static_cast
<
cutlass
::
half_t
>
(
1
),
static_cast
<
cutlass
::
half_t
>
(
1
));
}
if
constexpr
(
std
::
is_same
<
T
,
float
>::
value
&&
std
::
is_same
<
IntT
,
int32_t
>::
value
)
{
fp32_gather_gemm_scatter
gather_gemm_scatter
=
getBestFp32Kernel
(
M
,
N
,
K
);
gather_gemm_scatter
(
dev_ctx
,
x
.
non_zero_elements
().
data
<
T
>
(),
tmp_kernel_ptr
,
out_values_ptr
,
out_values_ptr
,
M
,
N
,
K
,
gather_indices
,
scatter_indices
,
static_cast
<
T
>
(
1
),
static_cast
<
T
>
(
1
));
}
if
constexpr
(
std
::
is_same
<
T
,
double
>::
value
&&
std
::
is_same
<
IntT
,
int32_t
>::
value
)
{
fp64_gather_gemm_scatter
gather_gemm_scatter
=
getBestFp64Kernel
(
M
,
N
,
K
);
gather_gemm_scatter
(
dev_ctx
,
x
.
non_zero_elements
().
data
<
T
>
(),
tmp_kernel_ptr
,
out_values_ptr
,
out_values_ptr
,
M
,
N
,
K
,
gather_indices
,
scatter_indices
,
static_cast
<
T
>
(
1
),
static_cast
<
T
>
(
1
));
}
}
}
else
{
#endif
// 2. gather
phi
::
DenseTensor
in_features
=
phi
::
Empty
<
T
>
(
dev_ctx
,
{
rulebook_len
,
in_channels
});
phi
::
DenseTensor
out_features
=
phi
::
Empty
<
T
>
(
dev_ctx
,
{
rulebook_len
,
out_channels
});
T
*
in_features_ptr
=
in_features
.
data
<
T
>
();
T
*
out_features_ptr
=
out_features
.
data
<
T
>
();
phi
::
funcs
::
SetConstant
<
GPUContext
,
T
>
set_zero
;
set_zero
(
dev_ctx
,
&
out_features
,
static_cast
<
T
>
(
0.0
f
));
// call gemm: (n, in_channels) * (in_channels, out_channels)
const
int
M
=
h_counter_ptr
[
i
];
const
int
K
=
in_channels
;
const
int
N
=
out_channels
;
T
*
tmp_in_ptr
=
in_features_ptr
+
h_offsets_ptr
[
i
]
*
in_channels
;
const
T
*
tmp_kernel_ptr
=
kernel_ptr
+
i
*
K
*
N
;
T
*
tmp_out_ptr
=
out_features_ptr
+
h_offsets_ptr
[
i
]
*
out_channels
;
blas
.
GEMM
(
CblasNoTrans
,
CblasNoTrans
,
M
,
N
,
K
,
static_cast
<
T
>
(
1
),
tmp_in_ptr
,
tmp_kernel_ptr
,
static_cast
<
T
>
(
0
),
tmp_out_ptr
);
}
Gather
<
T
,
IntT
>
(
dev_ctx
,
x
.
values
().
data
<
T
>
(),
rulebook_ptr
,
rulebook_len
,
in_channels
,
in_features_ptr
);
// 3. call gemm for every weight
auto
blas
=
phi
::
funcs
::
GetBlas
<
GPUContext
,
T
>
(
dev_ctx
);
auto
*
out_values
=
out
->
mutable_values
();
T
*
out_values_ptr
=
out_values
->
data
<
T
>
();
set_zero
(
dev_ctx
,
out_values
,
static_cast
<
T
>
(
0.0
f
));
// 4. scatter
phi
::
funcs
::
sparse
::
ScatterV2
<
T
>
(
dev_ctx
,
out_features_ptr
,
out_index
.
data
<
int
>
(),
unique_value
.
data
<
int
>
(),
out
->
nnz
(),
kernel_size
,
out_channels
,
1
,
out_values_ptr
);
const
T
*
kernel_ptr
=
kernel
.
data
<
T
>
();
for
(
int
i
=
0
;
i
<
kernel_size
;
i
++
)
{
if
(
h_counter_ptr
[
i
]
<=
0
)
{
continue
;
}
// call gemm: (n, in_channels) * (in_channels, out_channels)
const
int
M
=
h_counter_ptr
[
i
];
const
int
K
=
in_channels
;
const
int
N
=
out_channels
;
T
*
tmp_in_ptr
=
in_features_ptr
+
h_offsets_ptr
[
i
]
*
in_channels
;
const
T
*
tmp_kernel_ptr
=
kernel_ptr
+
i
*
K
*
N
;
T
*
tmp_out_ptr
=
out_features_ptr
+
h_offsets_ptr
[
i
]
*
out_channels
;
blas
.
GEMM
(
CblasNoTrans
,
CblasNoTrans
,
M
,
N
,
K
,
static_cast
<
T
>
(
1
),
tmp_in_ptr
,
tmp_kernel_ptr
,
static_cast
<
T
>
(
0
),
tmp_out_ptr
);
}
// 4. scatter
phi
::
funcs
::
sparse
::
ScatterV2
<
T
>
(
dev_ctx
,
out_features_ptr
,
out_index
.
data
<
int
>
(),
unique_value
.
data
<
int
>
(),
out
->
nnz
(),
kernel_size
,
out_channels
,
1
,
out_values_ptr
);
#ifdef PADDLE_WITH_CUTLASS
}
#endif
}
/**
...
...
paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.cu
0 → 100644
浏览文件 @
5158fa4f
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifdef PADDLE_WITH_CUTLASS
#include "paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.h"
namespace
phi
{
namespace
sparse
{
// Selects the fp16 gather-GEMM-scatter launcher for a problem of size
// M x N x K.
//
// The choice is a fixed lookup on (K, N), refined by M for the larger channel
// counts. Every returned pointer is an instantiation of
// launchKernel<cutlass::half_t, ...> for one of the cutlass_tensorop_* GEMM
// configurations.
//
// @param M  number of rows of the gathered left-hand operand
// @param N  number of output channels
// @param K  number of input channels
// @return   launcher to call for this problem size; never null
fp16_gather_gemm_scatter getBestFp16Kernel(const int M,
                                           const int N,
                                           const int K) {
  if (K == 4 && N == 16) {
    return launchKernel<
        cutlass::half_t,
        cutlass_tensorop_h1688gemm_64x64_32x2_nn_align4::Gemm>;
  }
  if (K == 16 && N == 16) {
    return launchKernel<
        cutlass::half_t,
        cutlass_tensorop_h1688gemm_64x64_32x2_nn_align8::Gemm>;
  }
  if (K == 16 && N == 32) {
    return launchKernel<
        cutlass::half_t,
        cutlass_tensorop_h1688gemm_64x64_32x2_nn_align8::Gemm>;
  }
  if (K == 32 && N == 32) {
    return launchKernel<
        cutlass::half_t,
        cutlass_tensorop_h1688gemm_64x64_32x2_nn_align8::Gemm>;
  }
  if (K == 32 && N == 64) {
    return launchKernel<
        cutlass::half_t,
        cutlass_tensorop_h1688gemm_64x64_32x2_nn_align8::Gemm>;
  }
  if (K == 64 && N == 64) {
    // The two largest-M cases previously named their kernel without
    // returning it, so they silently fell through to the M > 15000 choice.
    if (M > 100000)
      return launchKernel<
          cutlass::half_t,
          cutlass_tensorop_f16_s1688gemm_f16_64x128_32x2_nn_align8::Gemm>;
    if (M > 20000)
      return launchKernel<
          cutlass::half_t,
          cutlass_tensorop_f16_s1688gemm_f16_64x64_32x2_nn_align8::Gemm>;
    if (M > 15000)
      return launchKernel<
          cutlass::half_t,
          cutlass_tensorop_h1688gemm_128x64_32x2_nn_align8::Gemm>;
    return launchKernel<
        cutlass::half_t,
        cutlass_tensorop_h1688gemm_64x64_32x2_nn_align8::Gemm>;
  }
  if (K == 128) {
    if (M >= 5000)
      return launchKernel<
          cutlass::half_t,
          cutlass_tensorop_h1688gemm_64x64_32x2_nn_align8::Gemm>;
    return launchKernel<
        cutlass::half_t,
        cutlass_tensorop_h16816gemm_64x64_64x5_nn_align8::Gemm>;
  }
  if (N == 128) {
    return launchKernel<
        cutlass::half_t,
        cutlass_tensorop_h1688gemm_64x64_32x2_nn_align8::Gemm>;
  }
  // Fallback for every other shape: the 4-element-aligned configuration.
  return launchKernel<
      cutlass::half_t,
      cutlass_tensorop_h1688gemm_64x64_32x2_nn_align4::Gemm>;
}
// Selects the fp32 gather-GEMM-scatter launcher for a problem of size
// M x N x K. The decision depends on the channel counts (K, N) and, for some
// of them, on whether M reaches a size threshold. Every returned pointer is
// an instantiation of launchKernel<float, ...>.
fp32_gather_gemm_scatter getBestFp32Kernel(const int M,
                                           const int N,
                                           const int K) {
  using SmallTileFastF16 =
      cutlass_tensorop_s1688f16gemm_64x64_16x10_nn_align4::Gemm;
  using SmallTileFastF32 =
      cutlass_tensorop_s1688gemm_64x64_16x3_nn_align4::Gemm;

  // 128 input channels take precedence over 128 output channels.
  if (K == 128) {
    if (M >= 100000) {
      return launchKernel<
          float,
          cutlass_tensorop_s1688f16gemm_128x128_16x3_nn_align4::Gemm>;
    }
    if (M >= 5000) {
      return launchKernel<
          float,
          cutlass_tensorop_s1688f16gemm_256x64_16x4_nn_align4::Gemm>;
    }
    return launchKernel<
        float,
        cutlass_tensorop_s1688tf32gemm_256x128_16x3_nn_align4::Gemm>;
  }
  if (N == 128) {
    if (M >= 100000) {
      return launchKernel<
          float,
          cutlass_tensorop_s1688tf32gemm_256x128_16x3_nn_align4::Gemm>;
    }
    if (M >= 5000) {
      return launchKernel<
          float,
          cutlass_tensorop_s1688f16gemm_128x128_16x3_nn_align4::Gemm>;
    }
    return launchKernel<
        float,
        cutlass_tensorop_s1688f16gemm_64x128_16x6_nn_align4::Gemm>;
  }

  // Shapes that switch to the fast-F32 kernel once M reaches 10000.
  const bool switches_at_10000 = (K == 16 && N == 32) ||
                                 (K == 32 && N == 32) ||
                                 (K == 32 && N == 64);
  if (switches_at_10000 && M >= 10000) {
    return launchKernel<float, SmallTileFastF32>;
  }
  // 64 -> 64 channels switches later, at M >= 15000.
  if (K == 64 && N == 64 && M >= 15000) {
    return launchKernel<float, SmallTileFastF32>;
  }

  // Everything else, including (K, N) = (4, 16) and (16, 16).
  return launchKernel<float, SmallTileFastF16>;
}
// Selects the fp64 gather-GEMM-scatter launcher for a problem of size
// M x N x K. Only two configurations exist; a few (K, N) shapes use the
// 16x32 tile and every other shape uses the 32x16 tile.
fp64_gather_gemm_scatter getBestFp64Kernel(const int M,
                                           const int N,
                                           const int K) {
  using Tile16x32 = cutlass_tensorop_d884gemm_16x32_16x5_nn_align1::Gemm;
  using Tile32x16 = cutlass_tensorop_d884gemm_32x16_16x5_nn_align1::Gemm;

  // (16, 16) only keeps the 16x32 tile while M stays below 10000.
  const bool use_16x32 = (K == 4 && N == 16) ||
                         (K == 16 && N == 16 && M < 10000) ||
                         (K == 32 && N == 32);
  if (use_16x32) {
    return launchKernel<double, Tile16x32>;
  }
  return launchKernel<double, Tile32x16>;
}
}
// namespace sparse
}
// namespace phi
#endif
paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.h
0 → 100644
浏览文件 @
5158fa4f
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#ifdef PADDLE_WITH_CUTLASS
#include "cutlass/arch/mma.h"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/device/gemm_grouped.h"
#include "cutlass/gemm/device/gemm_universal.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/util/device_memory.h"
#include "examples/common/helper.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
namespace
phi
{
namespace
sparse
{
typedef
void
(
*
fp16_gather_gemm_scatter
)(
const
GPUContext
&
dev_ctx
,
const
cutlass
::
half_t
*
const
a
,
const
cutlass
::
half_t
*
const
b
,
const
cutlass
::
half_t
*
const
c
,
cutlass
::
half_t
*
const
d
,
const
int
m
,
const
int
n
,
const
int
k
,
const
int32_t
*
a_indices
,
const
int32_t
*
c_d_indices
,
cutlass
::
half_t
const
alpha
,
cutlass
::
half_t
const
beta
);
// Pointer to a launchKernel<float, Gemm> instantiation, as returned by
// getBestFp32Kernel.
using fp32_gather_gemm_scatter = void (*)(const GPUContext& dev_ctx,
                                          const float* const a,
                                          const float* const b,
                                          const float* const c,
                                          float* const d,
                                          const int m,
                                          const int n,
                                          const int k,
                                          const int32_t* a_indices,
                                          const int32_t* c_d_indices,
                                          float const alpha,
                                          float const beta);
// Pointer to a launchKernel<double, Gemm> instantiation, as returned by
// getBestFp64Kernel.
using fp64_gather_gemm_scatter = void (*)(const GPUContext& dev_ctx,
                                          const double* const a,
                                          const double* const b,
                                          const double* const c,
                                          double* const d,
                                          const int m,
                                          const int n,
                                          const int k,
                                          const int32_t* a_indices,
                                          const int32_t* c_d_indices,
                                          double const alpha,
                                          double const beta);
// Kernel selectors, one per element type. Each returns the launcher to use
// for a GEMM of size M x N x K.
//
// The parameter order is (M, N, K): this matches the definitions in
// gather_gemm_scatter.cu and the call sites, which pass
// (rows, out_channels, in_channels).
fp16_gather_gemm_scatter getBestFp16Kernel(const int M,
                                           const int N,
                                           const int K);
fp32_gather_gemm_scatter getBestFp32Kernel(const int M,
                                           const int N,
                                           const int K);
fp64_gather_gemm_scatter getBestFp64Kernel(const int M,
                                           const int N,
                                           const int K);
// Configures and runs the CUTLASS GEMM type `Gemm` on dev_ctx.stream().
//
// The problem size is (m, n, k); a, b, c and d are passed as the GEMM operand
// pointers and alpha/beta as the epilogue scalars. a_indices and c_d_indices
// are forwarded as the first and third index arrays of Gemm::Arguments, with
// nullptr in between.
// NOTE(review): the roles noted beside each argument below follow
// GemmUniversal::Arguments in CUTLASS v2.9.1 — confirm when bumping CUTLASS.
//
// Every CUTLASS status, including the one returned by the launch itself, is
// passed through CUTLASS_CHECK.
template <typename T, typename Gemm>
void launchKernel(const GPUContext& dev_ctx,
                  const T* const a,
                  const T* const b,
                  const T* const c,
                  T* const d,
                  const int m,
                  const int n,
                  const int k,
                  const int32_t* a_indices,
                  const int32_t* c_d_indices,
                  T const alpha,
                  T const beta) {
  cutlass::gemm::GemmCoord problem_size_real({m, n, k});
  const int split_k_slices = 1;
  typename Gemm::Arguments arguments{
      cutlass::gemm::GemmUniversalMode::kGemm,
      problem_size_real,
      split_k_slices,
      {alpha, beta},
      a,
      b,
      c,
      d,
      // Sizes of the four row-major operands (A, B, C, D).
      cutlass::layout::RowMajor().capacity(problem_size_real.mk()),
      cutlass::layout::RowMajor().capacity(problem_size_real.kn()),
      cutlass::layout::RowMajor().capacity(problem_size_real.mn()),
      cutlass::layout::RowMajor().capacity(problem_size_real.mn()),
      // Leading dimensions: k for A, n for B, C and D.
      problem_size_real.k(),
      problem_size_real.n(),
      problem_size_real.n(),
      problem_size_real.n(),
      a_indices,     // gather indices for A
      nullptr,       // no index array for B
      c_d_indices};  // scatter indices for D

  // Scratch memory requested by this GEMM configuration.
  const size_t workspace_size = Gemm::get_workspace_size(arguments);
  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

  Gemm gemm_op;
  cutlass::Status status = gemm_op.can_implement(arguments);
  CUTLASS_CHECK(status);
  status = gemm_op.initialize(arguments, workspace.get());
  CUTLASS_CHECK(status);
  // Check the launch result as well; it used to be discarded.
  status = gemm_op(dev_ctx.stream());
  CUTLASS_CHECK(status);
}
// fp16 GEMM with fp16 accumulation: 128x64x32 threadblock tile, 2 stages,
// alignment 8, Sm75 tensor ops.
// NOTE(review): argument roles follow GemmUniversal's template parameter
// order in CUTLASS v2.9.1 — confirm when bumping CUTLASS.
struct cutlass_tensorop_h1688gemm_128x64_32x2_nn_align8 {
 private:
  using Element = cutlass::half_t;
  using Layout = cutlass::layout::RowMajor;
  using ThreadblockTile = cutlass::gemm::GemmShape<128, 64, 32>;
  using WarpTile = cutlass::gemm::GemmShape<64, 32, 32>;
  using MmaTile = cutlass::gemm::GemmShape<16, 8, 8>;
  using Epilogue = cutlass::epilogue::thread::
      LinearCombination<Element, 8, Element, Element>;
  using Swizzle =
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>;

 public:
  using Gemm = cutlass::gemm::device::GemmUniversal<
      Element, Layout,  // A
      Element, Layout,  // B
      Element, Layout,  // C / D
      Element,          // accumulator
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm75,
      ThreadblockTile, WarpTile, MmaTile,
      Epilogue, Swizzle,
      2,  // stages
      8,  // alignment of A
      8,  // alignment of B
      cutlass::arch::OpMultiplyAdd,
      cutlass::ComplexTransform::kNone,
      cutlass::ComplexTransform::kNone,
      true,   // gather A
      false,  // gather B
      true>;  // scatter D
};
// fp16 GEMM with fp16 accumulation: 64x128x32 threadblock tile, 2 stages,
// alignment 8, Sm75 tensor ops.
// NOTE(review): argument roles follow GemmUniversal's template parameter
// order in CUTLASS v2.9.1 — confirm when bumping CUTLASS.
struct cutlass_tensorop_h1688gemm_64x128_32x2_nn_align8 {
 private:
  using Element = cutlass::half_t;
  using Layout = cutlass::layout::RowMajor;
  using ThreadblockTile = cutlass::gemm::GemmShape<64, 128, 32>;
  using WarpTile = cutlass::gemm::GemmShape<32, 64, 32>;
  using MmaTile = cutlass::gemm::GemmShape<16, 8, 8>;
  using Epilogue = cutlass::epilogue::thread::
      LinearCombination<Element, 8, Element, Element>;
  using Swizzle =
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>;

 public:
  using Gemm = cutlass::gemm::device::GemmUniversal<
      Element, Layout,  // A
      Element, Layout,  // B
      Element, Layout,  // C / D
      Element,          // accumulator
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm75,
      ThreadblockTile, WarpTile, MmaTile,
      Epilogue, Swizzle,
      2,  // stages
      8,  // alignment of A
      8,  // alignment of B
      cutlass::arch::OpMultiplyAdd,
      cutlass::ComplexTransform::kNone,
      cutlass::ComplexTransform::kNone,
      true,   // gather A
      false,  // gather B
      true>;  // scatter D
};
// fp16 GEMM with fp16 accumulation: 128x64x32 threadblock tile, 2 stages,
// alignment 4, Sm75 tensor ops.
// NOTE(review): argument roles follow GemmUniversal's template parameter
// order in CUTLASS v2.9.1 — confirm when bumping CUTLASS.
struct cutlass_tensorop_h1688gemm_128x64_32x2_nn_align4 {
 private:
  using Element = cutlass::half_t;
  using Layout = cutlass::layout::RowMajor;
  using ThreadblockTile = cutlass::gemm::GemmShape<128, 64, 32>;
  using WarpTile = cutlass::gemm::GemmShape<64, 32, 32>;
  using MmaTile = cutlass::gemm::GemmShape<16, 8, 8>;
  using Epilogue = cutlass::epilogue::thread::
      LinearCombination<Element, 4, Element, Element>;
  using Swizzle =
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>;

 public:
  using Gemm = cutlass::gemm::device::GemmUniversal<
      Element, Layout,  // A
      Element, Layout,  // B
      Element, Layout,  // C / D
      Element,          // accumulator
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm75,
      ThreadblockTile, WarpTile, MmaTile,
      Epilogue, Swizzle,
      2,  // stages
      4,  // alignment of A
      4,  // alignment of B
      cutlass::arch::OpMultiplyAdd,
      cutlass::ComplexTransform::kNone,
      cutlass::ComplexTransform::kNone,
      true,   // gather A
      false,  // gather B
      true>;  // scatter D
};
// fp16 GEMM with fp16 accumulation: 64x64x32 threadblock tile, 2 stages,
// alignment 4, Sm75 tensor ops.
// NOTE(review): argument roles follow GemmUniversal's template parameter
// order in CUTLASS v2.9.1 — confirm when bumping CUTLASS.
struct cutlass_tensorop_h1688gemm_64x64_32x2_nn_align4 {
 private:
  using Element = cutlass::half_t;
  using Layout = cutlass::layout::RowMajor;
  using ThreadblockTile = cutlass::gemm::GemmShape<64, 64, 32>;
  using WarpTile = cutlass::gemm::GemmShape<32, 32, 32>;
  using MmaTile = cutlass::gemm::GemmShape<16, 8, 8>;
  using Epilogue = cutlass::epilogue::thread::
      LinearCombination<Element, 4, Element, Element>;
  using Swizzle =
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>;

 public:
  using Gemm = cutlass::gemm::device::GemmUniversal<
      Element, Layout,  // A
      Element, Layout,  // B
      Element, Layout,  // C / D
      Element,          // accumulator
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm75,
      ThreadblockTile, WarpTile, MmaTile,
      Epilogue, Swizzle,
      2,  // stages
      4,  // alignment of A
      4,  // alignment of B
      cutlass::arch::OpMultiplyAdd,
      cutlass::ComplexTransform::kNone,
      cutlass::ComplexTransform::kNone,
      true,   // gather A
      false,  // gather B
      true>;  // scatter D
};
// fp16 GEMM with fp16 accumulation: 64x64x32 threadblock tile, 2 stages,
// alignment 8, Sm75 tensor ops.
// NOTE(review): argument roles follow GemmUniversal's template parameter
// order in CUTLASS v2.9.1 — confirm when bumping CUTLASS.
struct cutlass_tensorop_h1688gemm_64x64_32x2_nn_align8 {
 private:
  using Element = cutlass::half_t;
  using Layout = cutlass::layout::RowMajor;
  using ThreadblockTile = cutlass::gemm::GemmShape<64, 64, 32>;
  using WarpTile = cutlass::gemm::GemmShape<32, 32, 32>;
  using MmaTile = cutlass::gemm::GemmShape<16, 8, 8>;
  using Epilogue = cutlass::epilogue::thread::
      LinearCombination<Element, 8, Element, Element>;
  using Swizzle =
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>;

 public:
  using Gemm = cutlass::gemm::device::GemmUniversal<
      Element, Layout,  // A
      Element, Layout,  // B
      Element, Layout,  // C / D
      Element,          // accumulator
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm75,
      ThreadblockTile, WarpTile, MmaTile,
      Epilogue, Swizzle,
      2,  // stages
      8,  // alignment of A
      8,  // alignment of B
      cutlass::arch::OpMultiplyAdd,
      cutlass::ComplexTransform::kNone,
      cutlass::ComplexTransform::kNone,
      true,   // gather A
      false,  // gather B
      true>;  // scatter D
};
// fp16 GEMM with fp16 accumulation: 64x64x64 threadblock tile, 5 stages,
// alignment 8, Sm80 tensor ops with a 16x8x16 instruction shape.
// NOTE(review): argument roles follow GemmUniversal's template parameter
// order in CUTLASS v2.9.1 — confirm when bumping CUTLASS.
struct cutlass_tensorop_h16816gemm_64x64_64x5_nn_align8 {
 private:
  using Element = cutlass::half_t;
  using Layout = cutlass::layout::RowMajor;
  using ThreadblockTile = cutlass::gemm::GemmShape<64, 64, 64>;
  using WarpTile = cutlass::gemm::GemmShape<32, 32, 64>;
  using MmaTile = cutlass::gemm::GemmShape<16, 8, 16>;
  using Epilogue = cutlass::epilogue::thread::
      LinearCombination<Element, 8, Element, Element>;
  using Swizzle =
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>;

 public:
  using Gemm = cutlass::gemm::device::GemmUniversal<
      Element, Layout,  // A
      Element, Layout,  // B
      Element, Layout,  // C / D
      Element,          // accumulator
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm80,
      ThreadblockTile, WarpTile, MmaTile,
      Epilogue, Swizzle,
      5,  // stages
      8,  // alignment of A
      8,  // alignment of B
      cutlass::arch::OpMultiplyAdd,
      cutlass::ComplexTransform::kNone,
      cutlass::ComplexTransform::kNone,
      true,   // gather A
      false,  // gather B
      true>;  // scatter D
};
// fp16 GEMM with fp32 accumulation and fp32 epilogue arithmetic: 64x128x32
// threadblock tile, 2 stages, alignment 8, Sm75 tensor ops.
// NOTE(review): argument roles follow GemmUniversal's template parameter
// order in CUTLASS v2.9.1 — confirm when bumping CUTLASS.
struct cutlass_tensorop_f16_s1688gemm_f16_64x128_32x2_nn_align8 {
 private:
  using Element = cutlass::half_t;
  using Accumulator = float;
  using Layout = cutlass::layout::RowMajor;
  using ThreadblockTile = cutlass::gemm::GemmShape<64, 128, 32>;
  using WarpTile = cutlass::gemm::GemmShape<32, 64, 32>;
  using MmaTile = cutlass::gemm::GemmShape<16, 8, 8>;
  using Epilogue = cutlass::epilogue::thread::
      LinearCombination<Element, 8, Accumulator, float>;
  using Swizzle =
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>;

 public:
  using Gemm = cutlass::gemm::device::GemmUniversal<
      Element, Layout,  // A
      Element, Layout,  // B
      Element, Layout,  // C / D
      Accumulator,
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm75,
      ThreadblockTile, WarpTile, MmaTile,
      Epilogue, Swizzle,
      2,  // stages
      8,  // alignment of A
      8,  // alignment of B
      cutlass::arch::OpMultiplyAdd,
      cutlass::ComplexTransform::kNone,
      cutlass::ComplexTransform::kNone,
      true,   // gather A
      false,  // gather B
      true>;  // scatter D
};
// fp16 GEMM with fp32 accumulation and fp32 epilogue arithmetic: 64x64x32
// threadblock tile, 2 stages, alignment 8, Sm75 tensor ops.
// NOTE(review): argument roles follow GemmUniversal's template parameter
// order in CUTLASS v2.9.1 — confirm when bumping CUTLASS.
struct cutlass_tensorop_f16_s1688gemm_f16_64x64_32x2_nn_align8 {
 private:
  using Element = cutlass::half_t;
  using Accumulator = float;
  using Layout = cutlass::layout::RowMajor;
  using ThreadblockTile = cutlass::gemm::GemmShape<64, 64, 32>;
  using WarpTile = cutlass::gemm::GemmShape<32, 32, 32>;
  using MmaTile = cutlass::gemm::GemmShape<16, 8, 8>;
  using Epilogue = cutlass::epilogue::thread::
      LinearCombination<Element, 8, Accumulator, float>;
  using Swizzle =
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>;

 public:
  using Gemm = cutlass::gemm::device::GemmUniversal<
      Element, Layout,  // A
      Element, Layout,  // B
      Element, Layout,  // C / D
      Accumulator,
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm75,
      ThreadblockTile, WarpTile, MmaTile,
      Epilogue, Swizzle,
      2,  // stages
      8,  // alignment of A
      8,  // alignment of B
      cutlass::arch::OpMultiplyAdd,
      cutlass::ComplexTransform::kNone,
      cutlass::ComplexTransform::kNone,
      true,   // gather A
      false,  // gather B
      true>;  // scatter D
};
// fp32 GEMM using the OpMultiplyAddFastF16 math operator: 64x64x16
// threadblock tile, 10 stages, alignment 4, Sm80 tensor ops.
// NOTE(review): argument roles follow GemmUniversal's template parameter
// order in CUTLASS v2.9.1 — confirm when bumping CUTLASS.
struct cutlass_tensorop_s1688f16gemm_64x64_16x10_nn_align4 {
 private:
  using Element = float;
  using Layout = cutlass::layout::RowMajor;
  using ThreadblockTile = cutlass::gemm::GemmShape<64, 64, 16>;
  using WarpTile = cutlass::gemm::GemmShape<32, 32, 16>;
  using MmaTile = cutlass::gemm::GemmShape<16, 8, 8>;
  using Epilogue = cutlass::epilogue::thread::
      LinearCombination<Element, 4, Element, Element>;
  using Swizzle =
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>;

 public:
  using Gemm = cutlass::gemm::device::GemmUniversal<
      Element, Layout,  // A
      Element, Layout,  // B
      Element, Layout,  // C / D
      Element,          // accumulator
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm80,
      ThreadblockTile, WarpTile, MmaTile,
      Epilogue, Swizzle,
      10,  // stages
      4,   // alignment of A
      4,   // alignment of B
      cutlass::arch::OpMultiplyAddFastF16,
      cutlass::ComplexTransform::kNone,
      cutlass::ComplexTransform::kNone,
      true,   // gather A
      false,  // gather B
      true>;  // scatter D
};
// fp32 GEMM using the OpMultiplyAddFastF16 math operator: 128x128x16
// threadblock tile, 3 stages, alignment 4, Sm80 tensor ops.
// NOTE(review): argument roles follow GemmUniversal's template parameter
// order in CUTLASS v2.9.1 — confirm when bumping CUTLASS.
struct cutlass_tensorop_s1688f16gemm_128x128_16x3_nn_align4 {
 private:
  using Element = float;
  using Layout = cutlass::layout::RowMajor;
  using ThreadblockTile = cutlass::gemm::GemmShape<128, 128, 16>;
  using WarpTile = cutlass::gemm::GemmShape<64, 64, 16>;
  using MmaTile = cutlass::gemm::GemmShape<16, 8, 8>;
  using Epilogue = cutlass::epilogue::thread::
      LinearCombination<Element, 4, Element, Element>;
  using Swizzle =
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>;

 public:
  using Gemm = cutlass::gemm::device::GemmUniversal<
      Element, Layout,  // A
      Element, Layout,  // B
      Element, Layout,  // C / D
      Element,          // accumulator
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm80,
      ThreadblockTile, WarpTile, MmaTile,
      Epilogue, Swizzle,
      3,  // stages
      4,  // alignment of A
      4,  // alignment of B
      cutlass::arch::OpMultiplyAddFastF16,
      cutlass::ComplexTransform::kNone,
      cutlass::ComplexTransform::kNone,
      true,   // gather A
      false,  // gather B
      true>;  // scatter D
};
// fp32 GEMM using the OpMultiplyAddFastF16 math operator: 256x64x16
// threadblock tile, 4 stages, alignment 4, Sm80 tensor ops.
// NOTE(review): argument roles follow GemmUniversal's template parameter
// order in CUTLASS v2.9.1 — confirm when bumping CUTLASS.
struct cutlass_tensorop_s1688f16gemm_256x64_16x4_nn_align4 {
 private:
  using Element = float;
  using Layout = cutlass::layout::RowMajor;
  using ThreadblockTile = cutlass::gemm::GemmShape<256, 64, 16>;
  using WarpTile = cutlass::gemm::GemmShape<64, 64, 16>;
  using MmaTile = cutlass::gemm::GemmShape<16, 8, 8>;
  using Epilogue = cutlass::epilogue::thread::
      LinearCombination<Element, 4, Element, Element>;
  using Swizzle =
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>;

 public:
  using Gemm = cutlass::gemm::device::GemmUniversal<
      Element, Layout,  // A
      Element, Layout,  // B
      Element, Layout,  // C / D
      Element,          // accumulator
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm80,
      ThreadblockTile, WarpTile, MmaTile,
      Epilogue, Swizzle,
      4,  // stages
      4,  // alignment of A
      4,  // alignment of B
      cutlass::arch::OpMultiplyAddFastF16,
      cutlass::ComplexTransform::kNone,
      cutlass::ComplexTransform::kNone,
      true,   // gather A
      false,  // gather B
      true>;  // scatter D
};
// fp32 GEMM using the plain OpMultiplyAdd math operator: 256x128x16
// threadblock tile, 3 stages, alignment 4, Sm80 tensor ops.
// NOTE(review): argument roles follow GemmUniversal's template parameter
// order in CUTLASS v2.9.1 — confirm when bumping CUTLASS.
struct cutlass_tensorop_s1688tf32gemm_256x128_16x3_nn_align4 {
 private:
  using Element = float;
  using Layout = cutlass::layout::RowMajor;
  using ThreadblockTile = cutlass::gemm::GemmShape<256, 128, 16>;
  using WarpTile = cutlass::gemm::GemmShape<64, 64, 16>;
  using MmaTile = cutlass::gemm::GemmShape<16, 8, 8>;
  using Epilogue = cutlass::epilogue::thread::
      LinearCombination<Element, 4, Element, Element>;
  using Swizzle =
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>;

 public:
  using Gemm = cutlass::gemm::device::GemmUniversal<
      Element, Layout,  // A
      Element, Layout,  // B
      Element, Layout,  // C / D
      Element,          // accumulator
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm80,
      ThreadblockTile, WarpTile, MmaTile,
      Epilogue, Swizzle,
      3,  // stages
      4,  // alignment of A
      4,  // alignment of B
      cutlass::arch::OpMultiplyAdd,
      cutlass::ComplexTransform::kNone,
      cutlass::ComplexTransform::kNone,
      true,   // gather A
      false,  // gather B
      true>;  // scatter D
};
// fp32 GEMM using the OpMultiplyAddFastF16 math operator: 64x128x16
// threadblock tile, 6 stages, alignment 4, Sm80 tensor ops.
// NOTE(review): argument roles follow GemmUniversal's template parameter
// order in CUTLASS v2.9.1 — confirm when bumping CUTLASS.
struct cutlass_tensorop_s1688f16gemm_64x128_16x6_nn_align4 {
 private:
  using Element = float;
  using Layout = cutlass::layout::RowMajor;
  using ThreadblockTile = cutlass::gemm::GemmShape<64, 128, 16>;
  using WarpTile = cutlass::gemm::GemmShape<32, 64, 16>;
  using MmaTile = cutlass::gemm::GemmShape<16, 8, 8>;
  using Epilogue = cutlass::epilogue::thread::
      LinearCombination<Element, 4, Element, Element>;
  using Swizzle =
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>;

 public:
  using Gemm = cutlass::gemm::device::GemmUniversal<
      Element, Layout,  // A
      Element, Layout,  // B
      Element, Layout,  // C / D
      Element,          // accumulator
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm80,
      ThreadblockTile, WarpTile, MmaTile,
      Epilogue, Swizzle,
      6,  // stages
      4,  // alignment of A
      4,  // alignment of B
      cutlass::arch::OpMultiplyAddFastF16,
      cutlass::ComplexTransform::kNone,
      cutlass::ComplexTransform::kNone,
      true,   // gather A
      false,  // gather B
      true>;  // scatter D
};
// fp32 GEMM using the OpMultiplyAddFastF32 math operator: 64x64x16
// threadblock tile, 3 stages, alignment 4, Sm80 tensor ops.
// NOTE(review): argument roles follow GemmUniversal's template parameter
// order in CUTLASS v2.9.1 — confirm when bumping CUTLASS.
struct cutlass_tensorop_s1688gemm_64x64_16x3_nn_align4 {
 private:
  using Element = float;
  using Layout = cutlass::layout::RowMajor;
  using ThreadblockTile = cutlass::gemm::GemmShape<64, 64, 16>;
  using WarpTile = cutlass::gemm::GemmShape<32, 32, 16>;
  using MmaTile = cutlass::gemm::GemmShape<16, 8, 8>;
  using Epilogue = cutlass::epilogue::thread::
      LinearCombination<Element, 4, Element, Element>;
  using Swizzle =
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>;

 public:
  using Gemm = cutlass::gemm::device::GemmUniversal<
      Element, Layout,  // A
      Element, Layout,  // B
      Element, Layout,  // C / D
      Element,          // accumulator
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm80,
      ThreadblockTile, WarpTile, MmaTile,
      Epilogue, Swizzle,
      3,  // stages
      4,  // alignment of A
      4,  // alignment of B
      cutlass::arch::OpMultiplyAddFastF32,
      cutlass::ComplexTransform::kNone,
      cutlass::ComplexTransform::kNone,
      true,   // gather A
      false,  // gather B
      true>;  // scatter D
};
// fp64 GEMM: 16x32x16 threadblock tile, 5 stages, alignment 1, Sm80 tensor
// ops with an 8x8x4 instruction shape.
// NOTE(review): argument roles follow GemmUniversal's template parameter
// order in CUTLASS v2.9.1 — confirm when bumping CUTLASS.
struct cutlass_tensorop_d884gemm_16x32_16x5_nn_align1 {
 private:
  using Element = double;
  using Layout = cutlass::layout::RowMajor;
  using ThreadblockTile = cutlass::gemm::GemmShape<16, 32, 16>;
  using WarpTile = cutlass::gemm::GemmShape<16, 16, 16>;
  using MmaTile = cutlass::gemm::GemmShape<8, 8, 4>;
  using Epilogue = cutlass::epilogue::thread::
      LinearCombination<Element, 1, Element, Element>;
  using Swizzle =
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>;

 public:
  using Gemm = cutlass::gemm::device::GemmUniversal<
      Element, Layout,  // A
      Element, Layout,  // B
      Element, Layout,  // C / D
      Element,          // accumulator
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm80,
      ThreadblockTile, WarpTile, MmaTile,
      Epilogue, Swizzle,
      5,  // stages
      1,  // alignment of A
      1,  // alignment of B
      cutlass::arch::OpMultiplyAdd,
      cutlass::ComplexTransform::kNone,
      cutlass::ComplexTransform::kNone,
      true,   // gather A
      false,  // gather B
      true>;  // scatter D
};
// fp64 GEMM: 32x16x16 threadblock tile, 5 stages, alignment 1, Sm80 tensor
// ops with an 8x8x4 instruction shape.
// NOTE(review): argument roles follow GemmUniversal's template parameter
// order in CUTLASS v2.9.1 — confirm when bumping CUTLASS.
struct cutlass_tensorop_d884gemm_32x16_16x5_nn_align1 {
 private:
  using Element = double;
  using Layout = cutlass::layout::RowMajor;
  using ThreadblockTile = cutlass::gemm::GemmShape<32, 16, 16>;
  using WarpTile = cutlass::gemm::GemmShape<16, 16, 16>;
  using MmaTile = cutlass::gemm::GemmShape<8, 8, 4>;
  using Epilogue = cutlass::epilogue::thread::
      LinearCombination<Element, 1, Element, Element>;
  using Swizzle =
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>;

 public:
  using Gemm = cutlass::gemm::device::GemmUniversal<
      Element, Layout,  // A
      Element, Layout,  // B
      Element, Layout,  // C / D
      Element,          // accumulator
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm80,
      ThreadblockTile, WarpTile, MmaTile,
      Epilogue, Swizzle,
      5,  // stages
      1,  // alignment of A
      1,  // alignment of B
      cutlass::arch::OpMultiplyAdd,
      cutlass::ComplexTransform::kNone,
      cutlass::ComplexTransform::kNone,
      true,   // gather A
      false,  // gather B
      true>;  // scatter D
};
}
// namespace sparse
}
// namespace phi
#endif
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录