Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
0b98d1aa
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2305
Star
20932
Fork
5423
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
0b98d1aa
编写于
4月 13, 2023
作者:
U
umiswing
提交者:
GitHub
4月 13, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[cutlass] Sparse conv3d backward fusion (#52361)
上级
1acb845a
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
507 addition
and
226 deletion
+507
-226
paddle/phi/kernels/autotune/auto_tune_base.h
paddle/phi/kernels/autotune/auto_tune_base.h
+38
-9
paddle/phi/kernels/autotune/cache.h
paddle/phi/kernels/autotune/cache.h
+34
-19
paddle/phi/kernels/sparse/gpu/conv_grad_kernel.cu
paddle/phi/kernels/sparse/gpu/conv_grad_kernel.cu
+135
-58
paddle/phi/kernels/sparse/gpu/conv_kernel.cu
paddle/phi/kernels/sparse/gpu/conv_kernel.cu
+17
-12
paddle/phi/kernels/sparse/gpu/cutlass_generator/common.h
paddle/phi/kernels/sparse/gpu/cutlass_generator/common.h
+146
-58
paddle/phi/kernels/sparse/gpu/cutlass_generator/gather_gemm_scatter_generator.py
...se/gpu/cutlass_generator/gather_gemm_scatter_generator.py
+44
-9
paddle/phi/kernels/sparse/gpu/cutlass_generator/gather_gemm_scatter_manifest.py
...rse/gpu/cutlass_generator/gather_gemm_scatter_manifest.py
+22
-16
paddle/phi/kernels/sparse/gpu/cutlass_generator/gather_gemm_scatter_operation.py
...se/gpu/cutlass_generator/gather_gemm_scatter_operation.py
+11
-5
paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.h
paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.h
+60
-40
未找到文件。
paddle/phi/kernels/autotune/auto_tune_base.h
浏览文件 @
0b98d1aa
...
...
@@ -177,18 +177,34 @@ class MatmulAutoTuner
}
};
template
<
typename
T
,
typename
ReturnType
,
typename
...
Args
>
template
<
bool
TransposeA
,
bool
TransposeB
,
typename
T
,
typename
ReturnType
,
typename
...
Args
>
class
GatherGemmScatterAutoTuner
:
public
AutoTuneBase
<
T
,
KernelCallback
<
T
,
ReturnType
,
T
,
T
,
Args
...
>>
{
public:
static
GatherGemmScatterAutoTuner
<
T
,
ReturnType
,
Args
...
>*
Instance
(
ReturnType
(
*
func
)(
T
,
T
,
Args
...))
{
static
GatherGemmScatterAutoTuner
<
TransposeA
,
TransposeB
,
T
,
ReturnType
,
Args
...
>*
Instance
(
ReturnType
(
*
func
)(
T
,
T
,
Args
...))
{
static
std
::
once_flag
gather_gemm_scatter_init_flag
;
static
std
::
unique_ptr
<
GatherGemmScatterAutoTuner
<
T
,
ReturnType
,
Args
...
>>
static
std
::
unique_ptr
<
GatherGemmScatterAutoTuner
<
TransposeA
,
TransposeB
,
T
,
ReturnType
,
Args
...
>>
instance
;
std
::
call_once
(
gather_gemm_scatter_init_flag
,
[
&
]
{
auto
obj
=
MakeCallback
<
T
>
(
func
);
instance
.
reset
(
new
GatherGemmScatterAutoTuner
<
T
,
ReturnType
,
Args
...
>
);
instance
.
reset
(
new
GatherGemmScatterAutoTuner
<
TransposeA
,
TransposeB
,
T
,
ReturnType
,
Args
...
>
);
instance
->
AddCallBack
(
func
);
});
return
instance
.
get
();
...
...
@@ -201,7 +217,8 @@ class GatherGemmScatterAutoTuner
Args
...
args
)
{
this
->
is_init_
=
true
;
this
->
CheckKernelSize
();
auto
&
cache
=
AutoTuneCache
::
Instance
().
GetGatherGemmScatter
<
T
>
();
auto
&
cache
=
AutoTuneCache
::
Instance
()
.
GetGatherGemmScatter
<
T
,
TransposeA
,
TransposeB
>
();
if
(
cache
.
Find
(
key
))
{
auto
best_idx
=
cache
.
Get
(
key
);
...
...
@@ -250,10 +267,22 @@ class GatherGemmScatterAutoTuner
return
best_idx
;
}
};
template
<
typename
T
,
typename
ReturnType
,
typename
...
Args
>
static
GatherGemmScatterAutoTuner
<
T
,
ReturnType
,
Args
...
>*
template
<
bool
TransposeA
,
bool
TransposeB
,
typename
T
,
typename
ReturnType
,
typename
...
Args
>
static
GatherGemmScatterAutoTuner
<
TransposeA
,
TransposeB
,
T
,
ReturnType
,
Args
...
>*
MakeGatherGemmScatterTuner
(
ReturnType
(
*
func
)(
T
,
T
,
Args
...))
{
return
GatherGemmScatterAutoTuner
<
T
,
ReturnType
,
Args
...
>::
Instance
(
func
);
return
GatherGemmScatterAutoTuner
<
TransposeA
,
TransposeB
,
T
,
ReturnType
,
Args
...
>::
Instance
(
func
);
}
// Define the auto_tuner inital object.
...
...
paddle/phi/kernels/autotune/cache.h
浏览文件 @
0b98d1aa
...
...
@@ -47,13 +47,15 @@ enum class AlgorithmType {
kMatmul
=
5
,
kGatherGemmScatterFP16NN
=
6
,
kGatherGemmScatterFP32NN
=
7
,
kGatherGemmScatterFP32TN
=
8
,
kGatherGemmScatterFP32NT
=
9
,
#if !defined(PADDLE_WITH_CUDNN_FRONTEND)
kAlgorithmCount
=
8
kAlgorithmCount
=
10
#else
kConvForwardV8
=
8
,
kConvBackwardDataV8
=
9
,
kConvBackwardFilterV8
=
1
0
,
kAlgorithmCount
=
1
1
kConvForwardV8
=
10
,
kConvBackwardDataV8
=
11
,
kConvBackwardFilterV8
=
1
2
,
kAlgorithmCount
=
1
3
#endif
};
...
...
@@ -73,6 +75,17 @@ using CudnnV8AlgorithmsTypeMap =
std
::
unordered_map
<
int64_t
,
CudnnFrontendPlanCache
>
;
#endif
#define DEFINE_GET_GATHER_GEMM_SCATTER( \
dtype, transpose_a, transpose_b, algo_type) \
template <typename T, bool TransposeA, bool TransposeB> \
typename std::enable_if<std::is_same<T, dtype>::value && \
TransposeA == transpose_a && \
TransposeB == transpose_b, \
AlgorithmsCacheMap&>::type \
GetGatherGemmScatter() { \
return Get(algo_type); \
}
class
AutoTuneCache
{
public:
static
AutoTuneCache
&
Instance
()
{
...
...
@@ -89,20 +102,22 @@ class AutoTuneCache {
ConvAlgorithmsCacheMap
&
GetConv
(
const
AlgorithmType
&
algo_type
)
{
return
conv_auto_tune_map_
[
static_cast
<
int64_t
>
(
algo_type
)];
}
template
<
typename
T
>
typename
std
::
enable_if
<
std
::
is_same
<
T
,
float
>::
value
,
AlgorithmsCacheMap
&>::
type
GetGatherGemmScatter
()
{
return
Get
(
AlgorithmType
::
kGatherGemmScatterFP32NN
);
}
template
<
typename
T
>
typename
std
::
enable_if
<
std
::
is_same
<
T
,
phi
::
dtype
::
float16
>::
value
,
AlgorithmsCacheMap
&>::
type
GetGatherGemmScatter
()
{
return
Get
(
AlgorithmType
::
kGatherGemmScatterFP16NN
);
}
DEFINE_GET_GATHER_GEMM_SCATTER
(
phi
::
dtype
::
float16
,
false
,
false
,
AlgorithmType
::
kGatherGemmScatterFP16NN
);
DEFINE_GET_GATHER_GEMM_SCATTER
(
float
,
false
,
false
,
AlgorithmType
::
kGatherGemmScatterFP32NN
);
DEFINE_GET_GATHER_GEMM_SCATTER
(
float
,
true
,
false
,
AlgorithmType
::
kGatherGemmScatterFP32TN
);
DEFINE_GET_GATHER_GEMM_SCATTER
(
float
,
false
,
true
,
AlgorithmType
::
kGatherGemmScatterFP32NT
);
#ifdef PADDLE_WITH_CUDNN_FRONTEND
CudnnFrontendPlanCache
&
GetConvV8
(
const
AlgorithmType
&
algo_type
)
{
...
...
paddle/phi/kernels/sparse/gpu/conv_grad_kernel.cu
浏览文件 @
0b98d1aa
...
...
@@ -24,9 +24,13 @@ limitations under the License. */
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/sparse/gpu/conv.cu.h"
#ifdef PADDLE_WITH_CUTLASS
#include "paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.h"
#endif
namespace
phi
{
namespace
sparse
{
extern
size_t
workspace_size
;
// rulebook[3, rulebook_len]:
//[
...
...
@@ -130,34 +134,52 @@ void Conv3dCooGradGPUKernel(const GPUContext& dev_ctx,
phi
::
backends
::
gpu
::
GpuMemsetAsync
(
out_index_ptr
,
0
,
sizeof
(
int
)
*
x
.
nnz
()
*
2
,
dev_ctx
.
stream
());
GroupIndexsV2
<<<
config
.
block_per_grid
,
config
.
thread_per_block
,
0
,
dev_ctx
.
stream
()
>>>
(
rulebook_len
,
x
.
nnz
(),
kernel_size
,
offsets
[
kernel_size
/
2
],
rulebook_ptr
,
out_index_ptr
,
unique_value_ptr
);
#ifdef PADDLE_WITH_CUTLASS
bool
cutlass
=
true
;
if
(
dev_ctx
.
GetComputeCapability
()
<
80
)
cutlass
=
false
;
GatherV2
<
T
,
IntT
>
(
dev_ctx
,
x
.
values
().
data
<
T
>
(),
out_index_ptr
,
unique_value_ptr
,
x
.
nnz
(),
kernel_size
,
in_channels
,
2
,
in_features_ptr
);
if
(
in_channels
%
4
!=
0
||
out_channels
%
4
!=
0
)
cutlass
=
false
;
Gather
<
T
,
IntT
>
(
dev_ctx
,
out_grad
.
values
().
data
<
T
>
(),
rulebook_ptr
+
rulebook_len
,
rulebook_len
,
out_channels
,
out_grad_features_ptr
);
if
(
std
::
is_same
<
T
,
phi
::
dtype
::
float16
>::
value
||
std
::
is_same
<
T
,
double
>::
value
)
cutlass
=
false
;
if
(
!
std
::
is_same
<
IntT
,
int32_t
>::
value
)
cutlass
=
false
;
if
(
!
cutlass
)
{
#endif
GroupIndexsV2
<<<
config
.
block_per_grid
,
config
.
thread_per_block
,
0
,
dev_ctx
.
stream
()
>>>
(
rulebook_len
,
x
.
nnz
(),
kernel_size
,
offsets
[
kernel_size
/
2
],
rulebook_ptr
,
out_index_ptr
,
unique_value_ptr
);
GatherV2
<
T
,
IntT
>
(
dev_ctx
,
x
.
values
().
data
<
T
>
(),
out_index_ptr
,
unique_value_ptr
,
x
.
nnz
(),
kernel_size
,
in_channels
,
2
,
in_features_ptr
);
Gather
<
T
,
IntT
>
(
dev_ctx
,
out_grad
.
values
().
data
<
T
>
(),
rulebook_ptr
+
rulebook_len
,
rulebook_len
,
out_channels
,
out_grad_features_ptr
);
#ifdef PADDLE_WITH_CUTLASS
}
#endif
const
T
*
kernel_ptr
=
kernel
.
data
<
T
>
();
for
(
int
i
=
0
;
i
<
kernel_size
;
i
++
)
{
if
(
counter_ptr
[
i
]
<=
0
||
(
subm
&&
i
==
half_kernel_size
))
{
...
...
@@ -173,43 +195,98 @@ void Conv3dCooGradGPUKernel(const GPUContext& dev_ctx,
T
*
tmp_d_x_ptr
=
d_x_features_ptr
+
offsets
[
i
]
*
in_channels
;
T
*
tmp_d_kernel_ptr
=
d_kernel_ptr
+
i
*
in_channels
*
out_channels
;
// call gemm: d_kernel = transpose(x) * out_grad
// (in_channels, n) * (n, out_channels)
blas
.
GEMM
(
CblasTrans
,
CblasNoTrans
,
K
,
N
,
M
,
static_cast
<
T
>
(
1
),
tmp_in_ptr
,
tmp_out_grad_ptr
,
static_cast
<
T
>
(
0
),
tmp_d_kernel_ptr
);
#ifdef PADDLE_WITH_CUTLASS
if
(
cutlass
)
{
const
IntT
*
gather_x_indices
=
rulebook_ptr
+
offsets
[
i
];
const
IntT
*
scatter_x_indices
=
rulebook_ptr
+
offsets
[
i
];
const
IntT
*
gather_out_indices
=
rulebook_ptr
+
rulebook_len
+
offsets
[
i
];
const
size_t
key
=
autotune
::
GenKey
(
M
/
features_num_range
,
N
,
K
);
// call gemm: d_kernel = transpose(x) * out_grad
// (in_channels, n) * (n, out_channels)
static
cutlass
::
device_memory
::
allocation
<
uint8_t
>
workspace
(
workspace_size
);
GatherGemmScatterDriver
<
T
,
IntT
,
true
,
false
>
(
dev_ctx
,
key
,
x
.
values
().
data
<
T
>
(),
out_grad
.
values
().
data
<
T
>
(),
tmp_d_kernel_ptr
,
tmp_d_kernel_ptr
,
in_channels
,
out_channels
,
counter_ptr
[
i
],
gather_x_indices
,
gather_out_indices
,
static_cast
<
const
IntT
*>
(
nullptr
),
static_cast
<
const
T
>
(
1.0
),
static_cast
<
const
T
>
(
0.0
),
&
workspace
);
// call gemm: d_x = out_grad * transpose(kernel)
// (n, out_channels) * (out_channels, in_channels)
GatherGemmScatterDriver
<
T
,
IntT
,
false
,
true
>
(
dev_ctx
,
key
,
out_grad
.
values
().
data
<
T
>
(),
tmp_kernel_ptr
,
x_grad_values_ptr
,
x_grad_values_ptr
,
counter_ptr
[
i
],
in_channels
,
out_channels
,
gather_out_indices
,
static_cast
<
const
IntT
*>
(
nullptr
),
scatter_x_indices
,
static_cast
<
const
T
>
(
1.0
),
static_cast
<
const
T
>
(
1.0
),
nullptr
);
}
else
{
#endif
// call gemm: d_kernel = transpose(x) * out_grad
// (in_channels, n) * (n, out_channels)
blas
.
GEMM
(
CblasTrans
,
CblasNoTrans
,
K
,
N
,
M
,
static_cast
<
T
>
(
1
),
tmp_in_ptr
,
tmp_out_grad_ptr
,
static_cast
<
T
>
(
0
),
tmp_d_kernel_ptr
);
// call gemm: d_x = out_grad * transpose(kernel)
// (n, out_channels) * (out_channels, in_channels)
blas
.
GEMM
(
CblasNoTrans
,
CblasTrans
,
M
,
K
,
N
,
static_cast
<
T
>
(
1
),
tmp_out_grad_ptr
,
tmp_kernel_ptr
,
static_cast
<
T
>
(
0
),
tmp_d_x_ptr
);
// call gemm: d_x = out_grad * transpose(kernel)
// (n, out_channels) * (out_channels, in_channels)
blas
.
GEMM
(
CblasNoTrans
,
CblasTrans
,
M
,
K
,
N
,
static_cast
<
T
>
(
1
),
tmp_out_grad_ptr
,
tmp_kernel_ptr
,
static_cast
<
T
>
(
0
),
tmp_d_x_ptr
);
#ifdef PADDLE_WITH_CUTLASS
}
#endif
}
// 4. scatter
phi
::
funcs
::
sparse
::
ScatterV2
<
T
>
(
dev_ctx
,
d_x_features_ptr
,
out_index
.
data
<
int
>
(),
unique_value
.
data
<
int
>
(),
x_grad
->
nnz
(),
kernel_size
,
in_channels
,
2
,
x_grad_values_ptr
);
#ifdef PADDLE_WITH_CUTLASS
if
(
!
cutlass
)
{
#endif
phi
::
funcs
::
sparse
::
ScatterV2
<
T
>
(
dev_ctx
,
d_x_features_ptr
,
out_index
.
data
<
int
>
(),
unique_value
.
data
<
int
>
(),
x_grad
->
nnz
(),
kernel_size
,
in_channels
,
2
,
x_grad_values_ptr
);
#ifdef PADDLE_WITH_CUTLASS
}
#endif
}
template
<
typename
T
,
typename
Context
>
...
...
paddle/phi/kernels/sparse/gpu/conv_kernel.cu
浏览文件 @
0b98d1aa
...
...
@@ -154,18 +154,23 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx,
const
IntT
*
gather_indices
=
rulebook_ptr
+
h_offsets_ptr
[
i
];
const
IntT
*
scatter_indices
=
rulebook_ptr
+
rulebook_len
+
h_offsets_ptr
[
i
];
GatherGemmScatterDriver
(
dev_ctx
,
x
.
non_zero_elements
().
data
<
T
>
(),
tmp_kernel_ptr
,
out_values_ptr
,
out_values_ptr
,
M
,
N
,
K
,
gather_indices
,
scatter_indices
,
static_cast
<
T
>
(
1.0
),
static_cast
<
T
>
(
1.0
));
const
size_t
key
=
autotune
::
GenKey
(
M
/
features_num_range
,
N
,
K
);
GatherGemmScatterDriver
<
T
,
IntT
,
false
,
false
>
(
dev_ctx
,
key
,
x
.
non_zero_elements
().
data
<
T
>
(),
tmp_kernel_ptr
,
out_values_ptr
,
out_values_ptr
,
M
,
N
,
K
,
gather_indices
,
static_cast
<
const
IntT
*>
(
nullptr
),
scatter_indices
,
static_cast
<
T
>
(
1.0
),
static_cast
<
T
>
(
1.0
),
nullptr
);
}
}
else
{
#endif
...
...
paddle/phi/kernels/sparse/gpu/cutlass_generator/common.h
浏览文件 @
0b98d1aa
...
...
@@ -16,28 +16,41 @@
#ifdef PADDLE_WITH_CUTLASS
#include "cutlass/arch/mma.h"
#include "cutlass/device_kernel.h"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/device/gemm_universal.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
#include "cutlass/half.h"
#include "cutlass/reduction/device/reduce_split_k.h"
#include "cutlass/reduction/thread/reduction_operators.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/util/device_memory.h"
#include "examples/common/helper.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
namespace
phi
{
namespace
sparse
{
#define TYPEDEF_KERNEL_POINTER(kernel, dtype) \
typedef void (*kernel)(dtype const alpha, \
dtype const beta, \
const GPUContext& dev_ctx, \
const dtype* const a, \
const dtype* const b, \
const dtype* const c, \
dtype* const d, \
const int m, \
const int n, \
const int k, \
const int32_t* a_indices, \
const int32_t* c_d_indices);
size_t
constexpr
max_splitk_slices
=
256
;
size_t
constexpr
max_in_channels
=
256
;
size_t
constexpr
max_out_channels
=
256
;
static
size_t
workspace_size
=
sizeof
(
float
)
*
max_splitk_slices
*
max_in_channels
*
max_out_channels
;
#define TYPEDEF_KERNEL_POINTER(kernel, dtype) \
typedef void (*kernel)(dtype const alpha, \
dtype const beta, \
const GPUContext& dev_ctx, \
const dtype* const a, \
const dtype* const b, \
const dtype* const c, \
dtype* const d, \
const int m, \
const int n, \
const int k, \
const int32_t* a_indices, \
const int32_t* b_indices, \
const int32_t* c_d_indices, \
void* const workspace_ptr);
#define GATHER_GEMM_SCATTER_CHECK(status) \
{ \
cutlass::Status error = status; \
...
...
@@ -45,51 +58,126 @@ namespace sparse {
throw std::runtime_error(cutlassGetStatusString(error)); \
} \
}
#define DEFINE_LAUNCH_KERNEL(dtype, cutlass_type) \
template <typename Gemm> \
void launchKernel(dtype const alpha, \
dtype const beta, \
const GPUContext& dev_ctx, \
const dtype* const a, \
const dtype* const b, \
const dtype* const c, \
dtype* const d, \
const int m, \
const int n, \
const int k, \
const int32_t* a_indices, \
const int32_t* c_d_indices) { \
cutlass::gemm::GemmCoord problem_size_real({m, n, k}); \
int split_k_slices = 1; \
typename Gemm::Arguments arguments{ \
cutlass::gemm::GemmUniversalMode::kGemm, \
problem_size_real, \
split_k_slices, \
{static_cast<const cutlass_type>(static_cast<const float>(alpha)), \
static_cast<const cutlass_type>(static_cast<const float>(beta))}, \
reinterpret_cast<const cutlass_type* const>(a), \
reinterpret_cast<const cutlass_type* const>(b), \
reinterpret_cast<const cutlass_type* const>(c), \
reinterpret_cast<cutlass_type* const>(d), \
cutlass::layout::RowMajor().capacity(problem_size_real.mk()), \
cutlass::layout::RowMajor().capacity(problem_size_real.kn()), \
cutlass::layout::RowMajor().capacity(problem_size_real.mn()), \
cutlass::layout::RowMajor().capacity(problem_size_real.mn()), \
problem_size_real.k(), \
problem_size_real.n(), \
problem_size_real.n(), \
problem_size_real.n(), \
a_indices, \
nullptr, \
c_d_indices}; \
size_t workspace_size = Gemm::get_workspace_size(arguments); \
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size); \
Gemm gemm_op; \
cutlass::Status status = gemm_op.can_implement(arguments); \
GATHER_GEMM_SCATTER_CHECK(status); \
status = gemm_op.initialize(arguments, workspace.get()); \
GATHER_GEMM_SCATTER_CHECK(status); \
gemm_op(dev_ctx.stream()); \
#define DEFINE_LAUNCH_KERNEL(dtype, cutlass_type) \
template <typename Config> \
void launchKernel(dtype const alpha, \
dtype const beta, \
const GPUContext& dev_ctx, \
const dtype* const a, \
const dtype* const b, \
const dtype* const c, \
dtype* const d, \
const int m, \
const int n, \
const int k, \
const int32_t* a_indices, \
const int32_t* b_indices, \
const int32_t* c_d_indices, \
void* const workspace_ptr) { \
cutlass::gemm::GemmCoord problem_size_real({m, n, k}); \
using Gemm = typename Config::Gemm; \
int split_k_slices = std::max(std::min(64, k / 128), 1); \
typename Gemm::Arguments arguments{ \
Config::Mode, \
problem_size_real, \
split_k_slices, \
{static_cast<const cutlass_type>(static_cast<const float>(alpha)), \
static_cast<const cutlass_type>(static_cast<const float>(beta))}, \
reinterpret_cast<const cutlass_type* const>(a), \
reinterpret_cast<const cutlass_type* const>(b), \
reinterpret_cast<const cutlass_type* const>(c), \
reinterpret_cast<cutlass_type* const>(d), \
m * k, \
k * n, \
m * n, \
m * n, \
std::is_same<typename Gemm::Base::LayoutA, \
cutlass::layout::RowMajor>::value \
? problem_size_real.k() \
: problem_size_real.m(), \
std::is_same<typename Gemm::Base::LayoutB, \
cutlass::layout::RowMajor>::value \
? problem_size_real.n() \
: problem_size_real.k(), \
std::is_same<typename Gemm::Base::LayoutC, \
cutlass::layout::RowMajor>::value \
? problem_size_real.n() \
: problem_size_real.m(), \
std::is_same<typename Gemm::Base::LayoutC, \
cutlass::layout::RowMajor>::value \
? problem_size_real.n() \
: problem_size_real.m(), \
a_indices, \
b_indices, \
c_d_indices}; \
cutlass::device_memory::allocation<uint8_t>* const real_workspace_ptr = \
static_cast<cutlass::device_memory::allocation<uint8_t>* const>( \
workspace_ptr); \
if (Config::Mode == \
cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel) { \
size_t current_workspace_size = Gemm::get_workspace_size(arguments); \
if (current_workspace_size > workspace_size) { \
workspace_size = current_workspace_size; \
real_workspace_ptr->reset(workspace_size); \
} \
\
arguments.ptr_D = real_workspace_ptr->get(); \
} \
Gemm gemm_op; \
cutlass::Status status = gemm_op.can_implement(arguments); \
GATHER_GEMM_SCATTER_CHECK(status); \
if (Config::Mode == \
cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel) { \
status = gemm_op.initialize(arguments, real_workspace_ptr->get()); \
} else { \
cutlass::device_memory::allocation<uint8_t> empty_workspace(0); \
status = gemm_op.initialize(arguments, empty_workspace.get()); \
} \
GATHER_GEMM_SCATTER_CHECK(status); \
gemm_op(dev_ctx.stream()); \
if (Config::Mode == \
cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel) { \
using ReductionOp = cutlass::reduction::thread::ReduceAdd< \
typename Gemm::ElementAccumulator, \
typename Gemm::EpilogueOutputOp::ElementAccumulator, \
Gemm::EpilogueOutputOp::kCount>; \
\
using ReductionKernel = cutlass::reduction::kernel::ReduceSplitK< \
cutlass::MatrixShape<4, 32 * Gemm::EpilogueOutputOp::kCount>, \
typename Gemm::EpilogueOutputOp, \
ReductionOp>; \
using ReductionDevice = \
typename cutlass::reduction::device::ReduceSplitK<ReductionKernel>; \
ReductionDevice reduction_op; \
int splitk_gemm_stride = n; \
cutlass::layout::RowMajor splitk_gemm_layout(splitk_gemm_stride); \
void* workspace_gemm_ptr = real_workspace_ptr->get(); \
cutlass::TensorRef<typename Gemm::ElementAccumulator, \
cutlass::layout::RowMajor> \
ref_workspace(reinterpret_cast<typename Gemm::ElementAccumulator*>( \
workspace_gemm_ptr), \
splitk_gemm_layout); \
cutlass::TensorRef<typename Gemm::Base::ElementC, \
typename Gemm::Base::LayoutC> \
ref_c(reinterpret_cast<typename Gemm::Base::ElementC* const>(d), \
splitk_gemm_layout); \
cutlass::TensorRef<typename Gemm::Base::ElementC, \
typename Gemm::Base::LayoutC> \
ref_d(reinterpret_cast<typename Gemm::Base::ElementC* const>(d), \
splitk_gemm_layout); \
typename ReductionDevice::Arguments reduction_args( \
problem_size_real.mn(), \
split_k_slices, \
static_cast<size_t>(problem_size_real.m() * problem_size_real.n()), \
ref_workspace, \
ref_d, \
ref_c, \
{static_cast<const cutlass_type>(static_cast<const float>(alpha)), \
static_cast<const cutlass_type>(static_cast<const float>(beta))}); \
status = reduction_op.initialize(reduction_args); \
GATHER_GEMM_SCATTER_CHECK(status); \
reduction_op(dev_ctx.stream()); \
} \
}
TYPEDEF_KERNEL_POINTER
(
fp16_gather_gemm_scatter
,
phi
::
dtype
::
float16
)
...
...
paddle/phi/kernels/sparse/gpu/cutlass_generator/gather_gemm_scatter_generator.py
浏览文件 @
0b98d1aa
...
...
@@ -97,7 +97,7 @@ def CreateGatherGemmScatterOperator(
return
operations
def
GenerateSM80_TensorOp_16816
(
manifest
,
cuda_version
):
def
GenerateSM80_TensorOp_16816
(
manifest
,
cuda_version
,
debug
=
False
):
if
not
CudaToolkitVersionSatisfies
(
cuda_version
,
11
,
0
):
return
...
...
@@ -191,6 +191,12 @@ def GenerateSM80_TensorOp_16816(manifest, cuda_version):
[
64
,
64
,
64
],
5
,
[
2
,
2
,
1
],
math_inst
,
min_cc
,
max_cc
),
]
if
debug
:
tile_descriptions
=
[
TileDescription
(
[
256
,
128
,
32
],
3
,
[
4
,
2
,
1
],
math_inst
,
min_cc
,
max_cc
),
]
data_type
=
[
math_inst
.
element_a
,
...
...
@@ -218,13 +224,15 @@ def GenerateSM80_TensorOp_16816(manifest, cuda_version):
)
def
GenerateSM80_TensorOp_1688
(
manifest
,
cuda_version
):
def
GenerateSM80_TensorOp_1688
(
manifest
,
cuda_version
,
debug
=
False
):
if
not
CudaToolkitVersionSatisfies
(
cuda_version
,
11
,
0
):
return
layouts
=
[
(
LayoutType
.
RowMajor
,
LayoutType
.
RowMajor
,
LayoutType
.
RowMajor
),
(
LayoutType
.
RowMajor
,
LayoutType
.
ColumnMajor
,
LayoutType
.
RowMajor
),
(
LayoutType
.
ColumnMajor
,
LayoutType
.
RowMajor
,
LayoutType
.
RowMajor
),
]
math_instructions
=
[
...
...
@@ -302,6 +310,13 @@ def GenerateSM80_TensorOp_1688(manifest, cuda_version):
),
]
if
debug
:
tile_descriptions
=
[
TileDescription
(
[
256
,
128
,
16
],
3
,
[
4
,
2
,
1
],
math_inst
,
min_cc
,
max_cc
),
]
data_type
=
[
math_inst
.
element_a
,
math_inst
.
element_b
,
...
...
@@ -325,13 +340,15 @@ def GenerateSM80_TensorOp_1688(manifest, cuda_version):
)
def
GenerateSM80_TensorOp_1688_fast_math
(
manifest
,
cuda_version
):
def
GenerateSM80_TensorOp_1688_fast_math
(
manifest
,
cuda_version
,
debug
=
False
):
if
not
CudaToolkitVersionSatisfies
(
cuda_version
,
11
,
0
):
return
layouts
=
[
(
LayoutType
.
RowMajor
,
LayoutType
.
RowMajor
,
LayoutType
.
RowMajor
),
(
LayoutType
.
RowMajor
,
LayoutType
.
ColumnMajor
,
LayoutType
.
RowMajor
),
(
LayoutType
.
ColumnMajor
,
LayoutType
.
RowMajor
,
LayoutType
.
RowMajor
),
]
math_instructions
=
[
...
...
@@ -409,6 +426,13 @@ def GenerateSM80_TensorOp_1688_fast_math(manifest, cuda_version):
),
]
if
debug
:
tile_descriptions
=
[
TileDescription
(
[
256
,
128
,
16
],
3
,
[
4
,
2
,
1
],
math_inst
,
min_cc
,
max_cc
),
]
data_type
=
[
DataType
.
f32
,
DataType
.
f32
,
DataType
.
f32
,
DataType
.
f32
]
CreateGatherGemmScatterOperator
(
...
...
@@ -416,13 +440,17 @@ def GenerateSM80_TensorOp_1688_fast_math(manifest, cuda_version):
)
def
GenerateSM80_TensorOp_1688_fast_fp32_math
(
manifest
,
cuda_version
):
def
GenerateSM80_TensorOp_1688_fast_fp32_math
(
manifest
,
cuda_version
,
debug
=
False
):
if
not
CudaToolkitVersionSatisfies
(
cuda_version
,
11
,
0
):
return
layouts
=
[
(
LayoutType
.
RowMajor
,
LayoutType
.
RowMajor
,
LayoutType
.
RowMajor
),
(
LayoutType
.
RowMajor
,
LayoutType
.
ColumnMajor
,
LayoutType
.
RowMajor
),
(
LayoutType
.
ColumnMajor
,
LayoutType
.
RowMajor
,
LayoutType
.
RowMajor
),
]
math_instructions
=
[
...
...
@@ -482,6 +510,13 @@ def GenerateSM80_TensorOp_1688_fast_fp32_math(manifest, cuda_version):
),
]
if
debug
:
tile_descriptions
=
[
TileDescription
(
[
128
,
128
,
16
],
4
,
[
4
,
2
,
1
],
math_inst
,
min_cc
,
max_cc
),
]
data_type
=
[
DataType
.
f32
,
DataType
.
f32
,
DataType
.
f32
,
DataType
.
f32
]
CreateGatherGemmScatterOperator
(
...
...
@@ -489,11 +524,11 @@ def GenerateSM80_TensorOp_1688_fast_fp32_math(manifest, cuda_version):
)
def
GenerateSM80
(
manifest
,
cuda_version
):
GenerateSM80_TensorOp_16816
(
manifest
,
cuda_version
)
GenerateSM80_TensorOp_1688
(
manifest
,
cuda_version
)
GenerateSM80_TensorOp_1688_fast_math
(
manifest
,
cuda_version
)
GenerateSM80_TensorOp_1688_fast_fp32_math
(
manifest
,
cuda_version
)
def
GenerateSM80
(
manifest
,
cuda_version
,
debug
=
False
):
GenerateSM80_TensorOp_16816
(
manifest
,
cuda_version
,
debug
)
GenerateSM80_TensorOp_1688
(
manifest
,
cuda_version
,
debug
)
GenerateSM80_TensorOp_1688_fast_math
(
manifest
,
cuda_version
,
debug
)
GenerateSM80_TensorOp_1688_fast_fp32_math
(
manifest
,
cuda_version
,
debug
)
class
KernelCfg
:
...
...
paddle/phi/kernels/sparse/gpu/cutlass_generator/gather_gemm_scatter_manifest.py
浏览文件 @
0b98d1aa
...
...
@@ -41,12 +41,12 @@ namespace sparse {
} // namespace phi
#endif
"""
self
.
fp16_kernels_list
=
(
"
static std::vector<fp16_gather_gemm_scatter> fp16_kernels = {
\n
"
)
self
.
fp32_kernels_list
=
(
"st
atic std::vector<fp32_gather_gemm_scatter> fp32_kernels = {
\n
"
)
self
.
kernels_lists
=
{
"
hnn"
:
"static std::vector<fp16_gather_gemm_scatter> fp16_nn_kernels = {"
,
"snn"
:
"static std::vector<fp32_gather_gemm_scatter> fp32_nn_kernels = {"
,
"snt"
:
"static std::vector<fp32_gather_gemm_scatter> fp32_nt_kernels = {"
,
"st
n"
:
"static std::vector<fp32_gather_gemm_scatter> fp32_tn_kernels = {"
,
}
def
__enter__
(
self
):
self
.
operation_path
=
os
.
path
.
join
(
...
...
@@ -78,19 +78,25 @@ namespace sparse {
self
.
source_files
.
append
(
configuration_emitter
.
configuration_path
)
self
.
configurations
.
append
(
configuration_name
)
if
'h'
==
operations
[
0
].
short_math_name
():
self
.
fp16_kernels_list
+=
(
if
operations
[
0
].
layout_name
()
==
'tn'
:
self
.
kernels_lists
[
operations
[
0
].
short_math_name
()
+
operations
[
0
].
layout_name
()
]
+=
(
"""
launchKernel<"""
+
configuration_name
+
"::Gemm>,"
+
"<cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel"
+
">>,"
)
if
's'
==
operations
[
0
].
short_math_name
():
self
.
fp32_kernels_list
+=
(
else
:
self
.
kernels_lists
[
operations
[
0
].
short_math_name
()
+
operations
[
0
].
layout_name
()
]
+=
(
"""
launchKernel<"""
+
configuration_name
+
"
::Gemm
>,"
+
"
<>
>,"
)
self
.
top_level_file
.
write
(
...
...
@@ -117,11 +123,11 @@ launchKernel<"""
)
)
self
.
fp16_kernels_list
+=
"
\n
};
\n
"
self
.
fp32_kernels_list
+=
"
\n
};
\n
"
for
k
,
v
in
self
.
kernels_lists
.
items
():
self
.
kernels_lists
[
k
]
+=
"
\n
};
\n
"
self
.
top_level_file
.
write
(
self
.
namespace_template
)
self
.
top_level_file
.
write
(
self
.
fp16_kernels_list
)
self
.
top_level_file
.
write
(
self
.
fp32_kernels_list
)
for
k
,
v
in
self
.
kernels_lists
.
items
():
self
.
top_level_file
.
write
(
v
)
self
.
top_level_file
.
write
(
self
.
epilogue_template
)
self
.
top_level_file
.
close
()
...
...
paddle/phi/kernels/sparse/gpu/cutlass_generator/gather_gemm_scatter_operation.py
浏览文件 @
0b98d1aa
...
...
@@ -52,6 +52,8 @@ class EmitGatherGemmScatterInstance(EmitGemmInstance):
"""
self
.
gemm_template
=
"""
// Gemm operator ${operation_name}
template<cutlass::gemm::GemmUniversalMode Mode_ =
cutlass::gemm::GemmUniversalMode::kGemm>
struct ${operation_name} {
using Gemm =
cutlass::gemm::device::GemmUniversal<
...
...
@@ -75,10 +77,11 @@ struct ${operation_name} {
${math_operation},
${transform_a},
${transform_b},
true
, // gather a
false
, // gather b
true
// scatter d
${gather_a}
, // gather a
${gather_b}
, // gather b
${scatter_d}
// scatter d
>;
static const cutlass::gemm::GemmUniversalMode Mode = Mode_;
};
"""
...
...
@@ -192,6 +195,9 @@ struct ${operation_name} {
'math_operation'
:
MathOperationTag
[
operation
.
tile_description
.
math_instruction
.
math_operation
],
'gather_a'
:
'true'
,
'gather_b'
:
str
(
operation
.
layout_name
()
==
'tn'
).
lower
(),
'scatter_d'
:
str
(
operation
.
layout_name
()
!=
'tn'
).
lower
(),
}
return
SubstituteTemplate
(
gemm_template
,
values
)
...
...
@@ -295,8 +301,8 @@ class GatherGemmScatterOperation(GemmOperation):
B
,
C
,
element_epilogue
,
epilogue_functor
=
EpilogueFunctor
.
LinearCombination
,
swizzling_functor
=
SwizzlingFunctor
.
Identity8
,
epilogue_functor
,
swizzling_functor
,
)
self
.
ShortLayoutTypeNames
=
{
LayoutType
.
ColumnMajor
:
't'
,
...
...
paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.h
浏览文件 @
0b98d1aa
...
...
@@ -24,29 +24,50 @@ namespace phi {
namespace
sparse
{
// To reduce tuning time, map shape (m,n,k) to (m/features_num_range,n,k) so
// that shapes in this range share the same key.
// that shapes
with
in this range share the same key.
constexpr
int
features_num_range
=
10000
;
#define DEFINE_GATHER_GEMM_SCATTER_DRIVER(dtype, kernels) \
template <typename T, typename IntT> \
typename std::enable_if<std::is_same<T, dtype>::value && \
std::is_same<IntT, int32_t>::value, \
void>::type \
GatherGemmScatterDriver(const phi::GPUContext& ctx, \
const T* const a, \
const T* const b, \
const T* const c, \
T* const d, \
const int& m, \
const int& n, \
const int& k, \
const IntT* a_indices, \
const IntT* c_d_indices, \
T alpha, \
T beta) { \
auto* tuner = autotune::MakeGatherGemmScatterTuner(kernels[0]); \
template
<
typename
T
,
typename
IntT
,
bool
TransposeA
,
bool
TransposeB
>
void
GatherGemmScatterDriver
(
const
phi
::
GPUContext
&
ctx
,
const
size_t
key
,
const
T
*
const
a
,
const
T
*
const
b
,
const
T
*
const
c
,
T
*
const
d
,
const
int
&
m
,
const
int
&
n
,
const
int
&
k
,
const
IntT
*
a_indices
,
const
IntT
*
b_indices
,
const
IntT
*
c_d_indices
,
T
alpha
,
T
beta
,
cutlass
::
device_memory
::
allocation
<
uint8_t
>*
const
workspace_ptr
)
{}
#define EXPLICIT_SPECIALIZE_GATHER_GEMM_SCATTER_DRIVER( \
T, kernels, transpose_a, transpose_b) \
template <> \
inline void GatherGemmScatterDriver<T, int32_t, transpose_a, transpose_b>( \
const phi::GPUContext& ctx, \
const size_t key, \
const T* const a, \
const T* const b, \
const T* const c, \
T* const d, \
const int& m, \
const int& n, \
const int& k, \
const int32_t* a_indices, \
const int32_t* b_indices, \
const int32_t* c_d_indices, \
T alpha, \
T beta, \
cutlass::device_memory::allocation<uint8_t>* const workspace_ptr) { \
auto* tuner = \
autotune::MakeGatherGemmScatterTuner<transpose_a, transpose_b>( \
kernels[0]); \
for (auto i = 1; i < kernels.size(); i++) tuner->AddCallBack(kernels[i]); \
size_t key = autotune::GenKey(m / features_num_range, n, k); \
tuner->Run(ctx, \
key, \
alpha, \
...
...
@@ -60,28 +81,27 @@ constexpr int features_num_range = 10000;
n, \
k, \
a_indices, \
c_d_indices); \
b_indices, \
c_d_indices, \
workspace_ptr); \
}
template
<
typename
T
,
typename
IntT
>
typename
std
::
enable_if
<
std
::
is_same
<
T
,
double
>::
value
||
!
std
::
is_same
<
IntT
,
int32_t
>::
value
,
void
>::
type
GatherGemmScatterDriver
(
const
phi
::
GPUContext
&
ctx
,
const
T
*
const
a
,
const
T
*
const
b
,
const
T
*
const
c
,
T
*
const
d
,
const
int
&
m
,
const
int
&
n
,
const
int
&
k
,
const
IntT
*
a_indices
,
const
IntT
*
c_d_indices
,
T
alpha
,
T
beta
)
{}
DEFINE_GATHER_GEMM_SCATTER_DRIVER
(
phi
::
dtype
::
float16
,
fp16_kernels
)
DEFINE_GATHER_GEMM_SCATTER_DRIVER
(
float
,
fp32_kernels
)
EXPLICIT_SPECIALIZE_GATHER_GEMM_SCATTER_DRIVER
(
phi
::
dtype
::
float16
,
fp16_nn_kernels
,
false
,
false
)
EXPLICIT_SPECIALIZE_GATHER_GEMM_SCATTER_DRIVER
(
float
,
fp32_nn_kernels
,
false
,
false
)
EXPLICIT_SPECIALIZE_GATHER_GEMM_SCATTER_DRIVER
(
float
,
fp32_nt_kernels
,
false
,
true
)
EXPLICIT_SPECIALIZE_GATHER_GEMM_SCATTER_DRIVER
(
float
,
fp32_tn_kernels
,
true
,
false
)
}
// namespace sparse
}
// namespace phi
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录