Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
0b98d1aa
P
Paddle
项目概览
PaddlePaddle
/
Paddle
接近 2 年 前同步成功
通知
2323
Star
20933
Fork
5424
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
0b98d1aa
编写于
4月 13, 2023
作者:
U
umiswing
提交者:
GitHub
4月 13, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[cutlass] Sparse conv3d backward fusion (#52361)
上级
1acb845a
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
507 addition
and
226 deletion
+507
-226
paddle/phi/kernels/autotune/auto_tune_base.h
paddle/phi/kernels/autotune/auto_tune_base.h
+38
-9
paddle/phi/kernels/autotune/cache.h
paddle/phi/kernels/autotune/cache.h
+34
-19
paddle/phi/kernels/sparse/gpu/conv_grad_kernel.cu
paddle/phi/kernels/sparse/gpu/conv_grad_kernel.cu
+135
-58
paddle/phi/kernels/sparse/gpu/conv_kernel.cu
paddle/phi/kernels/sparse/gpu/conv_kernel.cu
+17
-12
paddle/phi/kernels/sparse/gpu/cutlass_generator/common.h
paddle/phi/kernels/sparse/gpu/cutlass_generator/common.h
+146
-58
paddle/phi/kernels/sparse/gpu/cutlass_generator/gather_gemm_scatter_generator.py
...se/gpu/cutlass_generator/gather_gemm_scatter_generator.py
+44
-9
paddle/phi/kernels/sparse/gpu/cutlass_generator/gather_gemm_scatter_manifest.py
...rse/gpu/cutlass_generator/gather_gemm_scatter_manifest.py
+22
-16
paddle/phi/kernels/sparse/gpu/cutlass_generator/gather_gemm_scatter_operation.py
...se/gpu/cutlass_generator/gather_gemm_scatter_operation.py
+11
-5
paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.h
paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.h
+60
-40
未找到文件。
paddle/phi/kernels/autotune/auto_tune_base.h
浏览文件 @
0b98d1aa
...
@@ -177,18 +177,34 @@ class MatmulAutoTuner
...
@@ -177,18 +177,34 @@ class MatmulAutoTuner
}
}
};
};
template
<
typename
T
,
typename
ReturnType
,
typename
...
Args
>
template
<
bool
TransposeA
,
bool
TransposeB
,
typename
T
,
typename
ReturnType
,
typename
...
Args
>
class
GatherGemmScatterAutoTuner
class
GatherGemmScatterAutoTuner
:
public
AutoTuneBase
<
T
,
KernelCallback
<
T
,
ReturnType
,
T
,
T
,
Args
...
>>
{
:
public
AutoTuneBase
<
T
,
KernelCallback
<
T
,
ReturnType
,
T
,
T
,
Args
...
>>
{
public:
public:
static
GatherGemmScatterAutoTuner
<
T
,
ReturnType
,
Args
...
>*
Instance
(
static
GatherGemmScatterAutoTuner
<
TransposeA
,
ReturnType
(
*
func
)(
T
,
T
,
Args
...))
{
TransposeB
,
T
,
ReturnType
,
Args
...
>*
Instance
(
ReturnType
(
*
func
)(
T
,
T
,
Args
...))
{
static
std
::
once_flag
gather_gemm_scatter_init_flag
;
static
std
::
once_flag
gather_gemm_scatter_init_flag
;
static
std
::
unique_ptr
<
GatherGemmScatterAutoTuner
<
T
,
ReturnType
,
Args
...
>>
static
std
::
unique_ptr
<
GatherGemmScatterAutoTuner
<
TransposeA
,
TransposeB
,
T
,
ReturnType
,
Args
...
>>
instance
;
instance
;
std
::
call_once
(
gather_gemm_scatter_init_flag
,
[
&
]
{
std
::
call_once
(
gather_gemm_scatter_init_flag
,
[
&
]
{
auto
obj
=
MakeCallback
<
T
>
(
func
);
auto
obj
=
MakeCallback
<
T
>
(
func
);
instance
.
reset
(
new
GatherGemmScatterAutoTuner
<
T
,
ReturnType
,
Args
...
>
);
instance
.
reset
(
new
GatherGemmScatterAutoTuner
<
TransposeA
,
TransposeB
,
T
,
ReturnType
,
Args
...
>
);
instance
->
AddCallBack
(
func
);
instance
->
AddCallBack
(
func
);
});
});
return
instance
.
get
();
return
instance
.
get
();
...
@@ -201,7 +217,8 @@ class GatherGemmScatterAutoTuner
...
@@ -201,7 +217,8 @@ class GatherGemmScatterAutoTuner
Args
...
args
)
{
Args
...
args
)
{
this
->
is_init_
=
true
;
this
->
is_init_
=
true
;
this
->
CheckKernelSize
();
this
->
CheckKernelSize
();
auto
&
cache
=
AutoTuneCache
::
Instance
().
GetGatherGemmScatter
<
T
>
();
auto
&
cache
=
AutoTuneCache
::
Instance
()
.
GetGatherGemmScatter
<
T
,
TransposeA
,
TransposeB
>
();
if
(
cache
.
Find
(
key
))
{
if
(
cache
.
Find
(
key
))
{
auto
best_idx
=
cache
.
Get
(
key
);
auto
best_idx
=
cache
.
Get
(
key
);
...
@@ -250,10 +267,22 @@ class GatherGemmScatterAutoTuner
...
@@ -250,10 +267,22 @@ class GatherGemmScatterAutoTuner
return
best_idx
;
return
best_idx
;
}
}
};
};
template
<
typename
T
,
typename
ReturnType
,
typename
...
Args
>
template
<
bool
TransposeA
,
static
GatherGemmScatterAutoTuner
<
T
,
ReturnType
,
Args
...
>*
bool
TransposeB
,
typename
T
,
typename
ReturnType
,
typename
...
Args
>
static
GatherGemmScatterAutoTuner
<
TransposeA
,
TransposeB
,
T
,
ReturnType
,
Args
...
>*
MakeGatherGemmScatterTuner
(
ReturnType
(
*
func
)(
T
,
T
,
Args
...))
{
MakeGatherGemmScatterTuner
(
ReturnType
(
*
func
)(
T
,
T
,
Args
...))
{
return
GatherGemmScatterAutoTuner
<
T
,
ReturnType
,
Args
...
>::
Instance
(
func
);
return
GatherGemmScatterAutoTuner
<
TransposeA
,
TransposeB
,
T
,
ReturnType
,
Args
...
>::
Instance
(
func
);
}
}
// Define the auto_tuner inital object.
// Define the auto_tuner inital object.
...
...
paddle/phi/kernels/autotune/cache.h
浏览文件 @
0b98d1aa
...
@@ -47,13 +47,15 @@ enum class AlgorithmType {
...
@@ -47,13 +47,15 @@ enum class AlgorithmType {
kMatmul
=
5
,
kMatmul
=
5
,
kGatherGemmScatterFP16NN
=
6
,
kGatherGemmScatterFP16NN
=
6
,
kGatherGemmScatterFP32NN
=
7
,
kGatherGemmScatterFP32NN
=
7
,
kGatherGemmScatterFP32TN
=
8
,
kGatherGemmScatterFP32NT
=
9
,
#if !defined(PADDLE_WITH_CUDNN_FRONTEND)
#if !defined(PADDLE_WITH_CUDNN_FRONTEND)
kAlgorithmCount
=
8
kAlgorithmCount
=
10
#else
#else
kConvForwardV8
=
8
,
kConvForwardV8
=
10
,
kConvBackwardDataV8
=
9
,
kConvBackwardDataV8
=
11
,
kConvBackwardFilterV8
=
1
0
,
kConvBackwardFilterV8
=
1
2
,
kAlgorithmCount
=
1
1
kAlgorithmCount
=
1
3
#endif
#endif
};
};
...
@@ -73,6 +75,17 @@ using CudnnV8AlgorithmsTypeMap =
...
@@ -73,6 +75,17 @@ using CudnnV8AlgorithmsTypeMap =
std
::
unordered_map
<
int64_t
,
CudnnFrontendPlanCache
>
;
std
::
unordered_map
<
int64_t
,
CudnnFrontendPlanCache
>
;
#endif
#endif
#define DEFINE_GET_GATHER_GEMM_SCATTER( \
dtype, transpose_a, transpose_b, algo_type) \
template <typename T, bool TransposeA, bool TransposeB> \
typename std::enable_if<std::is_same<T, dtype>::value && \
TransposeA == transpose_a && \
TransposeB == transpose_b, \
AlgorithmsCacheMap&>::type \
GetGatherGemmScatter() { \
return Get(algo_type); \
}
class
AutoTuneCache
{
class
AutoTuneCache
{
public:
public:
static
AutoTuneCache
&
Instance
()
{
static
AutoTuneCache
&
Instance
()
{
...
@@ -89,20 +102,22 @@ class AutoTuneCache {
...
@@ -89,20 +102,22 @@ class AutoTuneCache {
ConvAlgorithmsCacheMap
&
GetConv
(
const
AlgorithmType
&
algo_type
)
{
ConvAlgorithmsCacheMap
&
GetConv
(
const
AlgorithmType
&
algo_type
)
{
return
conv_auto_tune_map_
[
static_cast
<
int64_t
>
(
algo_type
)];
return
conv_auto_tune_map_
[
static_cast
<
int64_t
>
(
algo_type
)];
}
}
DEFINE_GET_GATHER_GEMM_SCATTER
(
phi
::
dtype
::
float16
,
template
<
typename
T
>
false
,
typename
std
::
enable_if
<
std
::
is_same
<
T
,
float
>::
value
,
false
,
AlgorithmsCacheMap
&>::
type
AlgorithmType
::
kGatherGemmScatterFP16NN
);
GetGatherGemmScatter
()
{
DEFINE_GET_GATHER_GEMM_SCATTER
(
float
,
return
Get
(
AlgorithmType
::
kGatherGemmScatterFP32NN
);
false
,
}
false
,
AlgorithmType
::
kGatherGemmScatterFP32NN
);
template
<
typename
T
>
DEFINE_GET_GATHER_GEMM_SCATTER
(
float
,
typename
std
::
enable_if
<
std
::
is_same
<
T
,
phi
::
dtype
::
float16
>::
value
,
true
,
AlgorithmsCacheMap
&>::
type
false
,
GetGatherGemmScatter
()
{
AlgorithmType
::
kGatherGemmScatterFP32TN
);
return
Get
(
AlgorithmType
::
kGatherGemmScatterFP16NN
);
DEFINE_GET_GATHER_GEMM_SCATTER
(
float
,
}
false
,
true
,
AlgorithmType
::
kGatherGemmScatterFP32NT
);
#ifdef PADDLE_WITH_CUDNN_FRONTEND
#ifdef PADDLE_WITH_CUDNN_FRONTEND
CudnnFrontendPlanCache
&
GetConvV8
(
const
AlgorithmType
&
algo_type
)
{
CudnnFrontendPlanCache
&
GetConvV8
(
const
AlgorithmType
&
algo_type
)
{
...
...
paddle/phi/kernels/sparse/gpu/conv_grad_kernel.cu
浏览文件 @
0b98d1aa
...
@@ -24,9 +24,13 @@ limitations under the License. */
...
@@ -24,9 +24,13 @@ limitations under the License. */
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/sparse/gpu/conv.cu.h"
#include "paddle/phi/kernels/sparse/gpu/conv.cu.h"
#ifdef PADDLE_WITH_CUTLASS
#include "paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.h"
#endif
namespace
phi
{
namespace
phi
{
namespace
sparse
{
namespace
sparse
{
extern
size_t
workspace_size
;
// rulebook[3, rulebook_len]:
// rulebook[3, rulebook_len]:
//[
//[
...
@@ -130,34 +134,52 @@ void Conv3dCooGradGPUKernel(const GPUContext& dev_ctx,
...
@@ -130,34 +134,52 @@ void Conv3dCooGradGPUKernel(const GPUContext& dev_ctx,
phi
::
backends
::
gpu
::
GpuMemsetAsync
(
phi
::
backends
::
gpu
::
GpuMemsetAsync
(
out_index_ptr
,
0
,
sizeof
(
int
)
*
x
.
nnz
()
*
2
,
dev_ctx
.
stream
());
out_index_ptr
,
0
,
sizeof
(
int
)
*
x
.
nnz
()
*
2
,
dev_ctx
.
stream
());
GroupIndexsV2
<<<
config
.
block_per_grid
,
#ifdef PADDLE_WITH_CUTLASS
config
.
thread_per_block
,
bool
cutlass
=
true
;
0
,
if
(
dev_ctx
.
GetComputeCapability
()
<
80
)
cutlass
=
false
;
dev_ctx
.
stream
()
>>>
(
rulebook_len
,
x
.
nnz
(),
kernel_size
,
offsets
[
kernel_size
/
2
],
rulebook_ptr
,
out_index_ptr
,
unique_value_ptr
);
GatherV2
<
T
,
IntT
>
(
dev_ctx
,
if
(
in_channels
%
4
!=
0
||
out_channels
%
4
!=
0
)
cutlass
=
false
;
x
.
values
().
data
<
T
>
(),
out_index_ptr
,
unique_value_ptr
,
x
.
nnz
(),
kernel_size
,
in_channels
,
2
,
in_features_ptr
);
Gather
<
T
,
IntT
>
(
dev_ctx
,
if
(
std
::
is_same
<
T
,
phi
::
dtype
::
float16
>::
value
||
out_grad
.
values
().
data
<
T
>
(),
std
::
is_same
<
T
,
double
>::
value
)
rulebook_ptr
+
rulebook_len
,
cutlass
=
false
;
rulebook_len
,
out_channels
,
out_grad_features_ptr
);
if
(
!
std
::
is_same
<
IntT
,
int32_t
>::
value
)
cutlass
=
false
;
if
(
!
cutlass
)
{
#endif
GroupIndexsV2
<<<
config
.
block_per_grid
,
config
.
thread_per_block
,
0
,
dev_ctx
.
stream
()
>>>
(
rulebook_len
,
x
.
nnz
(),
kernel_size
,
offsets
[
kernel_size
/
2
],
rulebook_ptr
,
out_index_ptr
,
unique_value_ptr
);
GatherV2
<
T
,
IntT
>
(
dev_ctx
,
x
.
values
().
data
<
T
>
(),
out_index_ptr
,
unique_value_ptr
,
x
.
nnz
(),
kernel_size
,
in_channels
,
2
,
in_features_ptr
);
Gather
<
T
,
IntT
>
(
dev_ctx
,
out_grad
.
values
().
data
<
T
>
(),
rulebook_ptr
+
rulebook_len
,
rulebook_len
,
out_channels
,
out_grad_features_ptr
);
#ifdef PADDLE_WITH_CUTLASS
}
#endif
const
T
*
kernel_ptr
=
kernel
.
data
<
T
>
();
const
T
*
kernel_ptr
=
kernel
.
data
<
T
>
();
for
(
int
i
=
0
;
i
<
kernel_size
;
i
++
)
{
for
(
int
i
=
0
;
i
<
kernel_size
;
i
++
)
{
if
(
counter_ptr
[
i
]
<=
0
||
(
subm
&&
i
==
half_kernel_size
))
{
if
(
counter_ptr
[
i
]
<=
0
||
(
subm
&&
i
==
half_kernel_size
))
{
...
@@ -173,43 +195,98 @@ void Conv3dCooGradGPUKernel(const GPUContext& dev_ctx,
...
@@ -173,43 +195,98 @@ void Conv3dCooGradGPUKernel(const GPUContext& dev_ctx,
T
*
tmp_d_x_ptr
=
d_x_features_ptr
+
offsets
[
i
]
*
in_channels
;
T
*
tmp_d_x_ptr
=
d_x_features_ptr
+
offsets
[
i
]
*
in_channels
;
T
*
tmp_d_kernel_ptr
=
d_kernel_ptr
+
i
*
in_channels
*
out_channels
;
T
*
tmp_d_kernel_ptr
=
d_kernel_ptr
+
i
*
in_channels
*
out_channels
;
// call gemm: d_kernel = transpose(x) * out_grad
#ifdef PADDLE_WITH_CUTLASS
// (in_channels, n) * (n, out_channels)
if
(
cutlass
)
{
blas
.
GEMM
(
CblasTrans
,
const
IntT
*
gather_x_indices
=
rulebook_ptr
+
offsets
[
i
];
CblasNoTrans
,
const
IntT
*
scatter_x_indices
=
rulebook_ptr
+
offsets
[
i
];
K
,
const
IntT
*
gather_out_indices
=
rulebook_ptr
+
rulebook_len
+
offsets
[
i
];
N
,
const
size_t
key
=
autotune
::
GenKey
(
M
/
features_num_range
,
N
,
K
);
M
,
// call gemm: d_kernel = transpose(x) * out_grad
static_cast
<
T
>
(
1
),
// (in_channels, n) * (n, out_channels)
tmp_in_ptr
,
static
cutlass
::
device_memory
::
allocation
<
uint8_t
>
workspace
(
tmp_out_grad_ptr
,
workspace_size
);
static_cast
<
T
>
(
0
),
GatherGemmScatterDriver
<
T
,
IntT
,
true
,
false
>
(
tmp_d_kernel_ptr
);
dev_ctx
,
key
,
x
.
values
().
data
<
T
>
(),
out_grad
.
values
().
data
<
T
>
(),
tmp_d_kernel_ptr
,
tmp_d_kernel_ptr
,
in_channels
,
out_channels
,
counter_ptr
[
i
],
gather_x_indices
,
gather_out_indices
,
static_cast
<
const
IntT
*>
(
nullptr
),
static_cast
<
const
T
>
(
1.0
),
static_cast
<
const
T
>
(
0.0
),
&
workspace
);
// call gemm: d_x = out_grad * transpose(kernel)
// (n, out_channels) * (out_channels, in_channels)
GatherGemmScatterDriver
<
T
,
IntT
,
false
,
true
>
(
dev_ctx
,
key
,
out_grad
.
values
().
data
<
T
>
(),
tmp_kernel_ptr
,
x_grad_values_ptr
,
x_grad_values_ptr
,
counter_ptr
[
i
],
in_channels
,
out_channels
,
gather_out_indices
,
static_cast
<
const
IntT
*>
(
nullptr
),
scatter_x_indices
,
static_cast
<
const
T
>
(
1.0
),
static_cast
<
const
T
>
(
1.0
),
nullptr
);
}
else
{
#endif
// call gemm: d_kernel = transpose(x) * out_grad
// (in_channels, n) * (n, out_channels)
blas
.
GEMM
(
CblasTrans
,
CblasNoTrans
,
K
,
N
,
M
,
static_cast
<
T
>
(
1
),
tmp_in_ptr
,
tmp_out_grad_ptr
,
static_cast
<
T
>
(
0
),
tmp_d_kernel_ptr
);
// call gemm: d_x = out_grad * transpose(kernel)
// call gemm: d_x = out_grad * transpose(kernel)
// (n, out_channels) * (out_channels, in_channels)
// (n, out_channels) * (out_channels, in_channels)
blas
.
GEMM
(
CblasNoTrans
,
blas
.
GEMM
(
CblasNoTrans
,
CblasTrans
,
CblasTrans
,
M
,
M
,
K
,
K
,
N
,
N
,
static_cast
<
T
>
(
1
),
static_cast
<
T
>
(
1
),
tmp_out_grad_ptr
,
tmp_out_grad_ptr
,
tmp_kernel_ptr
,
tmp_kernel_ptr
,
static_cast
<
T
>
(
0
),
static_cast
<
T
>
(
0
),
tmp_d_x_ptr
);
tmp_d_x_ptr
);
#ifdef PADDLE_WITH_CUTLASS
}
#endif
}
}
// 4. scatter
// 4. scatter
phi
::
funcs
::
sparse
::
ScatterV2
<
T
>
(
dev_ctx
,
#ifdef PADDLE_WITH_CUTLASS
d_x_features_ptr
,
if
(
!
cutlass
)
{
out_index
.
data
<
int
>
(),
#endif
unique_value
.
data
<
int
>
(),
phi
::
funcs
::
sparse
::
ScatterV2
<
T
>
(
dev_ctx
,
x_grad
->
nnz
(),
d_x_features_ptr
,
kernel_size
,
out_index
.
data
<
int
>
(),
in_channels
,
unique_value
.
data
<
int
>
(),
2
,
x_grad
->
nnz
(),
x_grad_values_ptr
);
kernel_size
,
in_channels
,
2
,
x_grad_values_ptr
);
#ifdef PADDLE_WITH_CUTLASS
}
#endif
}
}
template
<
typename
T
,
typename
Context
>
template
<
typename
T
,
typename
Context
>
...
...
paddle/phi/kernels/sparse/gpu/conv_kernel.cu
浏览文件 @
0b98d1aa
...
@@ -154,18 +154,23 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx,
...
@@ -154,18 +154,23 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx,
const
IntT
*
gather_indices
=
rulebook_ptr
+
h_offsets_ptr
[
i
];
const
IntT
*
gather_indices
=
rulebook_ptr
+
h_offsets_ptr
[
i
];
const
IntT
*
scatter_indices
=
const
IntT
*
scatter_indices
=
rulebook_ptr
+
rulebook_len
+
h_offsets_ptr
[
i
];
rulebook_ptr
+
rulebook_len
+
h_offsets_ptr
[
i
];
GatherGemmScatterDriver
(
dev_ctx
,
const
size_t
key
=
autotune
::
GenKey
(
M
/
features_num_range
,
N
,
K
);
x
.
non_zero_elements
().
data
<
T
>
(),
GatherGemmScatterDriver
<
T
,
IntT
,
false
,
false
>
(
tmp_kernel_ptr
,
dev_ctx
,
out_values_ptr
,
key
,
out_values_ptr
,
x
.
non_zero_elements
().
data
<
T
>
(),
M
,
tmp_kernel_ptr
,
N
,
out_values_ptr
,
K
,
out_values_ptr
,
gather_indices
,
M
,
scatter_indices
,
N
,
static_cast
<
T
>
(
1.0
),
K
,
static_cast
<
T
>
(
1.0
));
gather_indices
,
static_cast
<
const
IntT
*>
(
nullptr
),
scatter_indices
,
static_cast
<
T
>
(
1.0
),
static_cast
<
T
>
(
1.0
),
nullptr
);
}
}
}
else
{
}
else
{
#endif
#endif
...
...
paddle/phi/kernels/sparse/gpu/cutlass_generator/common.h
浏览文件 @
0b98d1aa
...
@@ -16,28 +16,41 @@
...
@@ -16,28 +16,41 @@
#ifdef PADDLE_WITH_CUTLASS
#ifdef PADDLE_WITH_CUTLASS
#include "cutlass/arch/mma.h"
#include "cutlass/arch/mma.h"
#include "cutlass/device_kernel.h"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/device/gemm_universal.h"
#include "cutlass/gemm/device/gemm_universal.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
#include "cutlass/half.h"
#include "cutlass/half.h"
#include "cutlass/reduction/device/reduce_split_k.h"
#include "cutlass/reduction/thread/reduction_operators.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/util/device_memory.h"
#include "cutlass/util/device_memory.h"
#include "examples/common/helper.h"
#include "examples/common/helper.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
namespace
phi
{
namespace
phi
{
namespace
sparse
{
namespace
sparse
{
#define TYPEDEF_KERNEL_POINTER(kernel, dtype) \
size_t
constexpr
max_splitk_slices
=
256
;
typedef void (*kernel)(dtype const alpha, \
size_t
constexpr
max_in_channels
=
256
;
dtype const beta, \
size_t
constexpr
max_out_channels
=
256
;
const GPUContext& dev_ctx, \
static
size_t
workspace_size
=
const dtype* const a, \
sizeof
(
float
)
*
max_splitk_slices
*
max_in_channels
*
max_out_channels
;
const dtype* const b, \
const dtype* const c, \
#define TYPEDEF_KERNEL_POINTER(kernel, dtype) \
dtype* const d, \
typedef void (*kernel)(dtype const alpha, \
const int m, \
dtype const beta, \
const int n, \
const GPUContext& dev_ctx, \
const int k, \
const dtype* const a, \
const int32_t* a_indices, \
const dtype* const b, \
const int32_t* c_d_indices);
const dtype* const c, \
dtype* const d, \
const int m, \
const int n, \
const int k, \
const int32_t* a_indices, \
const int32_t* b_indices, \
const int32_t* c_d_indices, \
void* const workspace_ptr);
#define GATHER_GEMM_SCATTER_CHECK(status) \
#define GATHER_GEMM_SCATTER_CHECK(status) \
{ \
{ \
cutlass::Status error = status; \
cutlass::Status error = status; \
...
@@ -45,51 +58,126 @@ namespace sparse {
...
@@ -45,51 +58,126 @@ namespace sparse {
throw std::runtime_error(cutlassGetStatusString(error)); \
throw std::runtime_error(cutlassGetStatusString(error)); \
} \
} \
}
}
#define DEFINE_LAUNCH_KERNEL(dtype, cutlass_type) \
#define DEFINE_LAUNCH_KERNEL(dtype, cutlass_type) \
template <typename Gemm> \
template <typename Config> \
void launchKernel(dtype const alpha, \
void launchKernel(dtype const alpha, \
dtype const beta, \
dtype const beta, \
const GPUContext& dev_ctx, \
const GPUContext& dev_ctx, \
const dtype* const a, \
const dtype* const a, \
const dtype* const b, \
const dtype* const b, \
const dtype* const c, \
const dtype* const c, \
dtype* const d, \
dtype* const d, \
const int m, \
const int m, \
const int n, \
const int n, \
const int k, \
const int k, \
const int32_t* a_indices, \
const int32_t* a_indices, \
const int32_t* c_d_indices) { \
const int32_t* b_indices, \
cutlass::gemm::GemmCoord problem_size_real({m, n, k}); \
const int32_t* c_d_indices, \
int split_k_slices = 1; \
void* const workspace_ptr) { \
typename Gemm::Arguments arguments{ \
cutlass::gemm::GemmCoord problem_size_real({m, n, k}); \
cutlass::gemm::GemmUniversalMode::kGemm, \
using Gemm = typename Config::Gemm; \
problem_size_real, \
int split_k_slices = std::max(std::min(64, k / 128), 1); \
split_k_slices, \
typename Gemm::Arguments arguments{ \
{static_cast<const cutlass_type>(static_cast<const float>(alpha)), \
Config::Mode, \
static_cast<const cutlass_type>(static_cast<const float>(beta))}, \
problem_size_real, \
reinterpret_cast<const cutlass_type* const>(a), \
split_k_slices, \
reinterpret_cast<const cutlass_type* const>(b), \
{static_cast<const cutlass_type>(static_cast<const float>(alpha)), \
reinterpret_cast<const cutlass_type* const>(c), \
static_cast<const cutlass_type>(static_cast<const float>(beta))}, \
reinterpret_cast<cutlass_type* const>(d), \
reinterpret_cast<const cutlass_type* const>(a), \
cutlass::layout::RowMajor().capacity(problem_size_real.mk()), \
reinterpret_cast<const cutlass_type* const>(b), \
cutlass::layout::RowMajor().capacity(problem_size_real.kn()), \
reinterpret_cast<const cutlass_type* const>(c), \
cutlass::layout::RowMajor().capacity(problem_size_real.mn()), \
reinterpret_cast<cutlass_type* const>(d), \
cutlass::layout::RowMajor().capacity(problem_size_real.mn()), \
m * k, \
problem_size_real.k(), \
k * n, \
problem_size_real.n(), \
m * n, \
problem_size_real.n(), \
m * n, \
problem_size_real.n(), \
std::is_same<typename Gemm::Base::LayoutA, \
a_indices, \
cutlass::layout::RowMajor>::value \
nullptr, \
? problem_size_real.k() \
c_d_indices}; \
: problem_size_real.m(), \
size_t workspace_size = Gemm::get_workspace_size(arguments); \
std::is_same<typename Gemm::Base::LayoutB, \
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size); \
cutlass::layout::RowMajor>::value \
Gemm gemm_op; \
? problem_size_real.n() \
cutlass::Status status = gemm_op.can_implement(arguments); \
: problem_size_real.k(), \
GATHER_GEMM_SCATTER_CHECK(status); \
std::is_same<typename Gemm::Base::LayoutC, \
status = gemm_op.initialize(arguments, workspace.get()); \
cutlass::layout::RowMajor>::value \
GATHER_GEMM_SCATTER_CHECK(status); \
? problem_size_real.n() \
gemm_op(dev_ctx.stream()); \
: problem_size_real.m(), \
std::is_same<typename Gemm::Base::LayoutC, \
cutlass::layout::RowMajor>::value \
? problem_size_real.n() \
: problem_size_real.m(), \
a_indices, \
b_indices, \
c_d_indices}; \
cutlass::device_memory::allocation<uint8_t>* const real_workspace_ptr = \
static_cast<cutlass::device_memory::allocation<uint8_t>* const>( \
workspace_ptr); \
if (Config::Mode == \
cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel) { \
size_t current_workspace_size = Gemm::get_workspace_size(arguments); \
if (current_workspace_size > workspace_size) { \
workspace_size = current_workspace_size; \
real_workspace_ptr->reset(workspace_size); \
} \
\
arguments.ptr_D = real_workspace_ptr->get(); \
} \
Gemm gemm_op; \
cutlass::Status status = gemm_op.can_implement(arguments); \
GATHER_GEMM_SCATTER_CHECK(status); \
if (Config::Mode == \
cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel) { \
status = gemm_op.initialize(arguments, real_workspace_ptr->get()); \
} else { \
cutlass::device_memory::allocation<uint8_t> empty_workspace(0); \
status = gemm_op.initialize(arguments, empty_workspace.get()); \
} \
GATHER_GEMM_SCATTER_CHECK(status); \
gemm_op(dev_ctx.stream()); \
if (Config::Mode == \
cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel) { \
using ReductionOp = cutlass::reduction::thread::ReduceAdd< \
typename Gemm::ElementAccumulator, \
typename Gemm::EpilogueOutputOp::ElementAccumulator, \
Gemm::EpilogueOutputOp::kCount>; \
\
using ReductionKernel = cutlass::reduction::kernel::ReduceSplitK< \
cutlass::MatrixShape<4, 32 * Gemm::EpilogueOutputOp::kCount>, \
typename Gemm::EpilogueOutputOp, \
ReductionOp>; \
using ReductionDevice = \
typename cutlass::reduction::device::ReduceSplitK<ReductionKernel>; \
ReductionDevice reduction_op; \
int splitk_gemm_stride = n; \
cutlass::layout::RowMajor splitk_gemm_layout(splitk_gemm_stride); \
void* workspace_gemm_ptr = real_workspace_ptr->get(); \
cutlass::TensorRef<typename Gemm::ElementAccumulator, \
cutlass::layout::RowMajor> \
ref_workspace(reinterpret_cast<typename Gemm::ElementAccumulator*>( \
workspace_gemm_ptr), \
splitk_gemm_layout); \
cutlass::TensorRef<typename Gemm::Base::ElementC, \
typename Gemm::Base::LayoutC> \
ref_c(reinterpret_cast<typename Gemm::Base::ElementC* const>(d), \
splitk_gemm_layout); \
cutlass::TensorRef<typename Gemm::Base::ElementC, \
typename Gemm::Base::LayoutC> \
ref_d(reinterpret_cast<typename Gemm::Base::ElementC* const>(d), \
splitk_gemm_layout); \
typename ReductionDevice::Arguments reduction_args( \
problem_size_real.mn(), \
split_k_slices, \
static_cast<size_t>(problem_size_real.m() * problem_size_real.n()), \
ref_workspace, \
ref_d, \
ref_c, \
{static_cast<const cutlass_type>(static_cast<const float>(alpha)), \
static_cast<const cutlass_type>(static_cast<const float>(beta))}); \
status = reduction_op.initialize(reduction_args); \
GATHER_GEMM_SCATTER_CHECK(status); \
reduction_op(dev_ctx.stream()); \
} \
}
}
TYPEDEF_KERNEL_POINTER
(
fp16_gather_gemm_scatter
,
phi
::
dtype
::
float16
)
TYPEDEF_KERNEL_POINTER
(
fp16_gather_gemm_scatter
,
phi
::
dtype
::
float16
)
...
...
paddle/phi/kernels/sparse/gpu/cutlass_generator/gather_gemm_scatter_generator.py
浏览文件 @
0b98d1aa
...
@@ -97,7 +97,7 @@ def CreateGatherGemmScatterOperator(
...
@@ -97,7 +97,7 @@ def CreateGatherGemmScatterOperator(
return
operations
return
operations
def
GenerateSM80_TensorOp_16816
(
manifest
,
cuda_version
):
def
GenerateSM80_TensorOp_16816
(
manifest
,
cuda_version
,
debug
=
False
):
if
not
CudaToolkitVersionSatisfies
(
cuda_version
,
11
,
0
):
if
not
CudaToolkitVersionSatisfies
(
cuda_version
,
11
,
0
):
return
return
...
@@ -191,6 +191,12 @@ def GenerateSM80_TensorOp_16816(manifest, cuda_version):
...
@@ -191,6 +191,12 @@ def GenerateSM80_TensorOp_16816(manifest, cuda_version):
[
64
,
64
,
64
],
5
,
[
2
,
2
,
1
],
math_inst
,
min_cc
,
max_cc
[
64
,
64
,
64
],
5
,
[
2
,
2
,
1
],
math_inst
,
min_cc
,
max_cc
),
),
]
]
if
debug
:
tile_descriptions
=
[
TileDescription
(
[
256
,
128
,
32
],
3
,
[
4
,
2
,
1
],
math_inst
,
min_cc
,
max_cc
),
]
data_type
=
[
data_type
=
[
math_inst
.
element_a
,
math_inst
.
element_a
,
...
@@ -218,13 +224,15 @@ def GenerateSM80_TensorOp_16816(manifest, cuda_version):
...
@@ -218,13 +224,15 @@ def GenerateSM80_TensorOp_16816(manifest, cuda_version):
)
)
def
GenerateSM80_TensorOp_1688
(
manifest
,
cuda_version
):
def
GenerateSM80_TensorOp_1688
(
manifest
,
cuda_version
,
debug
=
False
):
if
not
CudaToolkitVersionSatisfies
(
cuda_version
,
11
,
0
):
if
not
CudaToolkitVersionSatisfies
(
cuda_version
,
11
,
0
):
return
return
layouts
=
[
layouts
=
[
(
LayoutType
.
RowMajor
,
LayoutType
.
RowMajor
,
LayoutType
.
RowMajor
),
(
LayoutType
.
RowMajor
,
LayoutType
.
RowMajor
,
LayoutType
.
RowMajor
),
(
LayoutType
.
RowMajor
,
LayoutType
.
ColumnMajor
,
LayoutType
.
RowMajor
),
(
LayoutType
.
ColumnMajor
,
LayoutType
.
RowMajor
,
LayoutType
.
RowMajor
),
]
]
math_instructions
=
[
math_instructions
=
[
...
@@ -302,6 +310,13 @@ def GenerateSM80_TensorOp_1688(manifest, cuda_version):
...
@@ -302,6 +310,13 @@ def GenerateSM80_TensorOp_1688(manifest, cuda_version):
),
),
]
]
if
debug
:
tile_descriptions
=
[
TileDescription
(
[
256
,
128
,
16
],
3
,
[
4
,
2
,
1
],
math_inst
,
min_cc
,
max_cc
),
]
data_type
=
[
data_type
=
[
math_inst
.
element_a
,
math_inst
.
element_a
,
math_inst
.
element_b
,
math_inst
.
element_b
,
...
@@ -325,13 +340,15 @@ def GenerateSM80_TensorOp_1688(manifest, cuda_version):
...
@@ -325,13 +340,15 @@ def GenerateSM80_TensorOp_1688(manifest, cuda_version):
)
)
def
GenerateSM80_TensorOp_1688_fast_math
(
manifest
,
cuda_version
):
def
GenerateSM80_TensorOp_1688_fast_math
(
manifest
,
cuda_version
,
debug
=
False
):
if
not
CudaToolkitVersionSatisfies
(
cuda_version
,
11
,
0
):
if
not
CudaToolkitVersionSatisfies
(
cuda_version
,
11
,
0
):
return
return
layouts
=
[
layouts
=
[
(
LayoutType
.
RowMajor
,
LayoutType
.
RowMajor
,
LayoutType
.
RowMajor
),
(
LayoutType
.
RowMajor
,
LayoutType
.
RowMajor
,
LayoutType
.
RowMajor
),
(
LayoutType
.
RowMajor
,
LayoutType
.
ColumnMajor
,
LayoutType
.
RowMajor
),
(
LayoutType
.
ColumnMajor
,
LayoutType
.
RowMajor
,
LayoutType
.
RowMajor
),
]
]
math_instructions
=
[
math_instructions
=
[
...
@@ -409,6 +426,13 @@ def GenerateSM80_TensorOp_1688_fast_math(manifest, cuda_version):
...
@@ -409,6 +426,13 @@ def GenerateSM80_TensorOp_1688_fast_math(manifest, cuda_version):
),
),
]
]
if
debug
:
tile_descriptions
=
[
TileDescription
(
[
256
,
128
,
16
],
3
,
[
4
,
2
,
1
],
math_inst
,
min_cc
,
max_cc
),
]
data_type
=
[
DataType
.
f32
,
DataType
.
f32
,
DataType
.
f32
,
DataType
.
f32
]
data_type
=
[
DataType
.
f32
,
DataType
.
f32
,
DataType
.
f32
,
DataType
.
f32
]
CreateGatherGemmScatterOperator
(
CreateGatherGemmScatterOperator
(
...
@@ -416,13 +440,17 @@ def GenerateSM80_TensorOp_1688_fast_math(manifest, cuda_version):
...
@@ -416,13 +440,17 @@ def GenerateSM80_TensorOp_1688_fast_math(manifest, cuda_version):
)
)
def
GenerateSM80_TensorOp_1688_fast_fp32_math
(
manifest
,
cuda_version
):
def
GenerateSM80_TensorOp_1688_fast_fp32_math
(
manifest
,
cuda_version
,
debug
=
False
):
if
not
CudaToolkitVersionSatisfies
(
cuda_version
,
11
,
0
):
if
not
CudaToolkitVersionSatisfies
(
cuda_version
,
11
,
0
):
return
return
layouts
=
[
layouts
=
[
(
LayoutType
.
RowMajor
,
LayoutType
.
RowMajor
,
LayoutType
.
RowMajor
),
(
LayoutType
.
RowMajor
,
LayoutType
.
RowMajor
,
LayoutType
.
RowMajor
),
(
LayoutType
.
RowMajor
,
LayoutType
.
ColumnMajor
,
LayoutType
.
RowMajor
),
(
LayoutType
.
ColumnMajor
,
LayoutType
.
RowMajor
,
LayoutType
.
RowMajor
),
]
]
math_instructions
=
[
math_instructions
=
[
...
@@ -482,6 +510,13 @@ def GenerateSM80_TensorOp_1688_fast_fp32_math(manifest, cuda_version):
...
@@ -482,6 +510,13 @@ def GenerateSM80_TensorOp_1688_fast_fp32_math(manifest, cuda_version):
),
),
]
]
if
debug
:
tile_descriptions
=
[
TileDescription
(
[
128
,
128
,
16
],
4
,
[
4
,
2
,
1
],
math_inst
,
min_cc
,
max_cc
),
]
data_type
=
[
DataType
.
f32
,
DataType
.
f32
,
DataType
.
f32
,
DataType
.
f32
]
data_type
=
[
DataType
.
f32
,
DataType
.
f32
,
DataType
.
f32
,
DataType
.
f32
]
CreateGatherGemmScatterOperator
(
CreateGatherGemmScatterOperator
(
...
@@ -489,11 +524,11 @@ def GenerateSM80_TensorOp_1688_fast_fp32_math(manifest, cuda_version):
...
@@ -489,11 +524,11 @@ def GenerateSM80_TensorOp_1688_fast_fp32_math(manifest, cuda_version):
)
)
def
GenerateSM80
(
manifest
,
cuda_version
):
def
GenerateSM80
(
manifest
,
cuda_version
,
debug
=
False
):
GenerateSM80_TensorOp_16816
(
manifest
,
cuda_version
)
GenerateSM80_TensorOp_16816
(
manifest
,
cuda_version
,
debug
)
GenerateSM80_TensorOp_1688
(
manifest
,
cuda_version
)
GenerateSM80_TensorOp_1688
(
manifest
,
cuda_version
,
debug
)
GenerateSM80_TensorOp_1688_fast_math
(
manifest
,
cuda_version
)
GenerateSM80_TensorOp_1688_fast_math
(
manifest
,
cuda_version
,
debug
)
GenerateSM80_TensorOp_1688_fast_fp32_math
(
manifest
,
cuda_version
)
GenerateSM80_TensorOp_1688_fast_fp32_math
(
manifest
,
cuda_version
,
debug
)
class
KernelCfg
:
class
KernelCfg
:
...
...
paddle/phi/kernels/sparse/gpu/cutlass_generator/gather_gemm_scatter_manifest.py
浏览文件 @
0b98d1aa
...
@@ -41,12 +41,12 @@ namespace sparse {
...
@@ -41,12 +41,12 @@ namespace sparse {
} // namespace phi
} // namespace phi
#endif
#endif
"""
"""
self
.
fp16_kernels_list
=
(
self
.
kernels_lists
=
{
"
static std::vector<fp16_gather_gemm_scatter> fp16_kernels = {
\n
"
"
hnn"
:
"static std::vector<fp16_gather_gemm_scatter> fp16_nn_kernels = {"
,
)
"snn"
:
"static std::vector<fp32_gather_gemm_scatter> fp32_nn_kernels = {"
,
self
.
fp32_kernels_list
=
(
"snt"
:
"static std::vector<fp32_gather_gemm_scatter> fp32_nt_kernels = {"
,
"st
atic std::vector<fp32_gather_gemm_scatter> fp32_kernels = {
\n
"
"st
n"
:
"static std::vector<fp32_gather_gemm_scatter> fp32_tn_kernels = {"
,
)
}
def
__enter__
(
self
):
def
__enter__
(
self
):
self
.
operation_path
=
os
.
path
.
join
(
self
.
operation_path
=
os
.
path
.
join
(
...
@@ -78,19 +78,25 @@ namespace sparse {
...
@@ -78,19 +78,25 @@ namespace sparse {
self
.
source_files
.
append
(
configuration_emitter
.
configuration_path
)
self
.
source_files
.
append
(
configuration_emitter
.
configuration_path
)
self
.
configurations
.
append
(
configuration_name
)
self
.
configurations
.
append
(
configuration_name
)
if
'h'
==
operations
[
0
].
short_math_name
():
self
.
fp16_kernels_list
+=
(
if
operations
[
0
].
layout_name
()
==
'tn'
:
self
.
kernels_lists
[
operations
[
0
].
short_math_name
()
+
operations
[
0
].
layout_name
()
]
+=
(
"""
"""
launchKernel<"""
launchKernel<"""
+
configuration_name
+
configuration_name
+
"::Gemm>,"
+
"<cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel"
+
">>,"
)
)
if
's'
==
operations
[
0
].
short_math_name
():
else
:
self
.
fp32_kernels_list
+=
(
self
.
kernels_lists
[
operations
[
0
].
short_math_name
()
+
operations
[
0
].
layout_name
()
]
+=
(
"""
"""
launchKernel<"""
launchKernel<"""
+
configuration_name
+
configuration_name
+
"
::Gemm
>,"
+
"
<>
>,"
)
)
self
.
top_level_file
.
write
(
self
.
top_level_file
.
write
(
...
@@ -117,11 +123,11 @@ launchKernel<"""
...
@@ -117,11 +123,11 @@ launchKernel<"""
)
)
)
)
self
.
fp16_kernels_list
+=
"
\n
};
\n
"
for
k
,
v
in
self
.
kernels_lists
.
items
():
self
.
fp32_kernels_list
+=
"
\n
};
\n
"
self
.
kernels_lists
[
k
]
+=
"
\n
};
\n
"
self
.
top_level_file
.
write
(
self
.
namespace_template
)
self
.
top_level_file
.
write
(
self
.
namespace_template
)
self
.
top_level_file
.
write
(
self
.
fp16_kernels_list
)
for
k
,
v
in
self
.
kernels_lists
.
items
():
self
.
top_level_file
.
write
(
self
.
fp32_kernels_list
)
self
.
top_level_file
.
write
(
v
)
self
.
top_level_file
.
write
(
self
.
epilogue_template
)
self
.
top_level_file
.
write
(
self
.
epilogue_template
)
self
.
top_level_file
.
close
()
self
.
top_level_file
.
close
()
...
...
paddle/phi/kernels/sparse/gpu/cutlass_generator/gather_gemm_scatter_operation.py
浏览文件 @
0b98d1aa
...
@@ -52,6 +52,8 @@ class EmitGatherGemmScatterInstance(EmitGemmInstance):
...
@@ -52,6 +52,8 @@ class EmitGatherGemmScatterInstance(EmitGemmInstance):
"""
"""
self
.
gemm_template
=
"""
self
.
gemm_template
=
"""
// Gemm operator ${operation_name}
// Gemm operator ${operation_name}
template<cutlass::gemm::GemmUniversalMode Mode_ =
cutlass::gemm::GemmUniversalMode::kGemm>
struct ${operation_name} {
struct ${operation_name} {
using Gemm =
using Gemm =
cutlass::gemm::device::GemmUniversal<
cutlass::gemm::device::GemmUniversal<
...
@@ -75,10 +77,11 @@ struct ${operation_name} {
...
@@ -75,10 +77,11 @@ struct ${operation_name} {
${math_operation},
${math_operation},
${transform_a},
${transform_a},
${transform_b},
${transform_b},
true
, // gather a
${gather_a}
, // gather a
false
, // gather b
${gather_b}
, // gather b
true
// scatter d
${scatter_d}
// scatter d
>;
>;
static const cutlass::gemm::GemmUniversalMode Mode = Mode_;
};
};
"""
"""
...
@@ -192,6 +195,9 @@ struct ${operation_name} {
...
@@ -192,6 +195,9 @@ struct ${operation_name} {
'math_operation'
:
MathOperationTag
[
'math_operation'
:
MathOperationTag
[
operation
.
tile_description
.
math_instruction
.
math_operation
operation
.
tile_description
.
math_instruction
.
math_operation
],
],
'gather_a'
:
'true'
,
'gather_b'
:
str
(
operation
.
layout_name
()
==
'tn'
).
lower
(),
'scatter_d'
:
str
(
operation
.
layout_name
()
!=
'tn'
).
lower
(),
}
}
return
SubstituteTemplate
(
gemm_template
,
values
)
return
SubstituteTemplate
(
gemm_template
,
values
)
...
@@ -295,8 +301,8 @@ class GatherGemmScatterOperation(GemmOperation):
...
@@ -295,8 +301,8 @@ class GatherGemmScatterOperation(GemmOperation):
B
,
B
,
C
,
C
,
element_epilogue
,
element_epilogue
,
epilogue_functor
=
EpilogueFunctor
.
LinearCombination
,
epilogue_functor
,
swizzling_functor
=
SwizzlingFunctor
.
Identity8
,
swizzling_functor
,
)
)
self
.
ShortLayoutTypeNames
=
{
self
.
ShortLayoutTypeNames
=
{
LayoutType
.
ColumnMajor
:
't'
,
LayoutType
.
ColumnMajor
:
't'
,
...
...
paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.h
浏览文件 @
0b98d1aa
...
@@ -24,29 +24,50 @@ namespace phi {
...
@@ -24,29 +24,50 @@ namespace phi {
namespace
sparse
{
namespace
sparse
{
// To reduce tuning time, map shape (m,n,k) to (m/features_num_range,n,k) so
// To reduce tuning time, map shape (m,n,k) to (m/features_num_range,n,k) so
// that shapes in this range share the same key.
// that shapes within this range share the same key.
constexpr
int
features_num_range
=
10000
;
constexpr
int
features_num_range
=
10000
;
#define DEFINE_GATHER_GEMM_SCATTER_DRIVER(dtype, kernels) \
template
<
typename
T
,
typename
IntT
,
bool
TransposeA
,
bool
TransposeB
>
template <typename T, typename IntT> \
void
GatherGemmScatterDriver
(
typename std::enable_if<std::is_same<T, dtype>::value && \
const
phi
::
GPUContext
&
ctx
,
std::is_same<IntT, int32_t>::value, \
const
size_t
key
,
void>::type \
const
T
*
const
a
,
GatherGemmScatterDriver(const phi::GPUContext& ctx, \
const
T
*
const
b
,
const T* const a, \
const
T
*
const
c
,
const T* const b, \
T
*
const
d
,
const T* const c, \
const
int
&
m
,
T* const d, \
const
int
&
n
,
const int& m, \
const
int
&
k
,
const int& n, \
const
IntT
*
a_indices
,
const int& k, \
const
IntT
*
b_indices
,
const IntT* a_indices, \
const
IntT
*
c_d_indices
,
const IntT* c_d_indices, \
T
alpha
,
T alpha, \
T
beta
,
T beta) { \
cutlass
::
device_memory
::
allocation
<
uint8_t
>*
const
workspace_ptr
)
{}
auto* tuner = autotune::MakeGatherGemmScatterTuner(kernels[0]); \
#define EXPLICIT_SPECIALIZE_GATHER_GEMM_SCATTER_DRIVER( \
T, kernels, transpose_a, transpose_b) \
template <> \
inline void GatherGemmScatterDriver<T, int32_t, transpose_a, transpose_b>( \
const phi::GPUContext& ctx, \
const size_t key, \
const T* const a, \
const T* const b, \
const T* const c, \
T* const d, \
const int& m, \
const int& n, \
const int& k, \
const int32_t* a_indices, \
const int32_t* b_indices, \
const int32_t* c_d_indices, \
T alpha, \
T beta, \
cutlass::device_memory::allocation<uint8_t>* const workspace_ptr) { \
auto* tuner = \
autotune::MakeGatherGemmScatterTuner<transpose_a, transpose_b>( \
kernels[0]); \
for (auto i = 1; i < kernels.size(); i++) tuner->AddCallBack(kernels[i]); \
for (auto i = 1; i < kernels.size(); i++) tuner->AddCallBack(kernels[i]); \
size_t key = autotune::GenKey(m / features_num_range, n, k); \
tuner->Run(ctx, \
tuner->Run(ctx, \
key, \
key, \
alpha, \
alpha, \
...
@@ -60,28 +81,27 @@ constexpr int features_num_range = 10000;
...
@@ -60,28 +81,27 @@ constexpr int features_num_range = 10000;
n, \
n, \
k, \
k, \
a_indices, \
a_indices, \
c_d_indices); \
b_indices, \
c_d_indices, \
workspace_ptr); \
}
}
template
<
typename
T
,
typename
IntT
>
EXPLICIT_SPECIALIZE_GATHER_GEMM_SCATTER_DRIVER
(
phi
::
dtype
::
float16
,
typename
std
::
enable_if
<
std
::
is_same
<
T
,
double
>::
value
||
fp16_nn_kernels
,
!
std
::
is_same
<
IntT
,
int32_t
>::
value
,
false
,
void
>::
type
false
)
GatherGemmScatterDriver
(
const
phi
::
GPUContext
&
ctx
,
EXPLICIT_SPECIALIZE_GATHER_GEMM_SCATTER_DRIVER
(
float
,
const
T
*
const
a
,
fp32_nn_kernels
,
const
T
*
const
b
,
false
,
const
T
*
const
c
,
false
)
T
*
const
d
,
EXPLICIT_SPECIALIZE_GATHER_GEMM_SCATTER_DRIVER
(
float
,
const
int
&
m
,
fp32_nt_kernels
,
const
int
&
n
,
false
,
const
int
&
k
,
true
)
const
IntT
*
a_indices
,
EXPLICIT_SPECIALIZE_GATHER_GEMM_SCATTER_DRIVER
(
float
,
const
IntT
*
c_d_indices
,
fp32_tn_kernels
,
T
alpha
,
true
,
T
beta
)
{}
false
)
DEFINE_GATHER_GEMM_SCATTER_DRIVER
(
phi
::
dtype
::
float16
,
fp16_kernels
)
DEFINE_GATHER_GEMM_SCATTER_DRIVER
(
float
,
fp32_kernels
)
}
// namespace sparse
}
// namespace sparse
}
// namespace phi
}
// namespace phi
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录