Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
wmsofts
Paddle
提交
12d43da9
P
Paddle
项目概览
wmsofts
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
12d43da9
编写于
3月 15, 2023
作者:
U
umiswing
提交者:
GitHub
3月 15, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Auto tune for cutlass (#50809)
上级
be9515f2
变更
10
展开全部
隐藏空白更改
内联
并排
Showing
10 changed file
with
908 addition
and
842 deletion
+908
-842
cmake/external/cutlass.cmake
cmake/external/cutlass.cmake
+5
-3
paddle/phi/kernels/autotune/auto_tune_base.h
paddle/phi/kernels/autotune/auto_tune_base.h
+79
-0
paddle/phi/kernels/autotune/cache.h
paddle/phi/kernels/autotune/cache.h
+21
-5
paddle/phi/kernels/sparse/gpu/conv_kernel.cu
paddle/phi/kernels/sparse/gpu/conv_kernel.cu
+18
-14
paddle/phi/kernels/sparse/gpu/cutlass_generator/common.h
paddle/phi/kernels/sparse/gpu/cutlass_generator/common.h
+103
-0
paddle/phi/kernels/sparse/gpu/cutlass_generator/gather_gemm_scatter_generator.py
...se/gpu/cutlass_generator/gather_gemm_scatter_generator.py
+552
-0
paddle/phi/kernels/sparse/gpu/cutlass_generator/gather_gemm_scatter_manifest.py
...rse/gpu/cutlass_generator/gather_gemm_scatter_manifest.py
+56
-3
paddle/phi/kernels/sparse/gpu/cutlass_generator/gather_gemm_scatter_operation.py
...se/gpu/cutlass_generator/gather_gemm_scatter_operation.py
+14
-10
paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.cu
paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.cu
+0
-194
paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.h
paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.h
+60
-613
未找到文件。
cmake/external/cutlass.cmake
浏览文件 @
12d43da9
...
...
@@ -39,12 +39,14 @@ ExternalProject_Add(
UPDATE_COMMAND
""
CONFIGURE_COMMAND
""
BUILD_COMMAND
rm -rf
${
CMAKE_SOURCE_DIR
}
/paddle/phi/kernels/sparse/gpu/cutlass_generator/build &&
mkdir -p
${
CMAKE_SOURCE_DIR
}
/paddle/phi/kernels/sparse/gpu/cutlass/build/generated/gemm
${
CMAKE_SOURCE_DIR
}
/paddle/phi/kernels/sparse/gpu/cutlass
_generator
/build/generated/gemm
&&
${
PYTHON_EXECUTABLE
}
-B
${
CMAKE_SOURCE_DIR
}
/paddle/phi/kernels/sparse/gpu/cutlass/gather_gemm_scatter_generator.py
${
CMAKE_SOURCE_DIR
}
/paddle/phi/kernels/sparse/gpu/cutlass
_generator
/gather_gemm_scatter_generator.py
"
${
THIRD_PARTY_PATH
}
/cutlass/src/extern_cutlass/tools/library/scripts/"
"
${
CMAKE_SOURCE_DIR
}
/paddle/phi/kernels/sparse/gpu/cutlass/build"
"
${
CMAKE_SOURCE_DIR
}
/paddle/phi/kernels/sparse/gpu/cutlass
_generator
/build"
"
${
CMAKE_CUDA_COMPILER_VERSION
}
"
INSTALL_COMMAND
""
TEST_COMMAND
""
)
...
...
paddle/phi/kernels/autotune/auto_tune_base.h
浏览文件 @
12d43da9
...
...
@@ -177,6 +177,85 @@ class MatmulAutoTuner
}
};
// Auto-tuner for gather-gemm-scatter kernels.
//
// Every registered kernel is invoked as `kernel(alpha, beta, args...)`. For a
// given cache `key`, the first call times all registered kernels, stores the
// index of the fastest one in the AutoTuneCache map returned by
// GetGatherGemmScatter<T>(), and later calls with the same key reuse it.
template <typename T, typename ReturnType, typename... Args>
class GatherGemmScatterAutoTuner
    : public AutoTuneBase<T, KernelCallback<T, ReturnType, T, T, Args...>> {
 public:
  // Returns the process-wide tuner for this template instantiation. The tuner
  // is created exactly once (std::call_once) and `func` is registered as its
  // first kernel; `func` passed on later calls is ignored.
  static GatherGemmScatterAutoTuner<T, ReturnType, Args...>* Instance(
      ReturnType (*func)(T, T, Args...)) {
    static std::once_flag gather_gemm_scatter_init_flag;
    static std::unique_ptr<GatherGemmScatterAutoTuner<T, ReturnType, Args...>>
        instance;
    std::call_once(gather_gemm_scatter_init_flag, [&] {
      instance.reset(new GatherGemmScatterAutoTuner<T, ReturnType, Args...>);
      instance->AddCallBack(func);
    });
    return instance.get();
  }

  // Runs the best kernel for `key` with the caller's `alpha`/`beta`.
  // On a cache miss, all kernels are benchmarked first and the winner's index
  // is cached under `key`.
  void Run(const phi::GPUContext& ctx,
           const size_t key,
           T const alpha,
           T const beta,
           Args... args) {
    this->is_init_ = true;
    this->CheckKernelSize();
    auto& cache = AutoTuneCache::Instance().GetGatherGemmScatter<T>();

    if (cache.Find(key)) {
      auto best_idx = cache.Get(key);
      this->kernels_[best_idx].Run(alpha, beta, args...);
    } else {
      // Set alpha to 0 and beta to 1 to avoid changing the value of d when
      // picking the best kernel.
      auto best_idx = PickBestKernel(
          ctx, static_cast<T>(0), static_cast<T>(1), args...);
      cache.Set(key, best_idx);
      this->kernels_[best_idx].Run(alpha, beta, args...);
    }
  }

 protected:
  // Times every registered kernel and returns the index of the fastest one.
  // Kernels that throw std::runtime_error are logged and skipped. If no
  // kernel succeeds, an error is logged and the process exits.
  size_t PickBestKernel(const phi::GPUContext& ctx,
                        const T& alpha,
                        const T& beta,
                        Args&... args) {
    std::lock_guard<std::mutex> lock(this->mutex_);

    // Sentinel meaning "no kernel ran successfully".
    constexpr size_t NO_KERNEL_WORKS = std::numeric_limits<size_t>::max();
    size_t best_idx = NO_KERNEL_WORKS;
    float min_time = std::numeric_limits<float>::max();

    // Time cost test established in default stream.
    for (size_t i = 0; i < this->kernels_.size(); ++i) {
      float time = 0;
      // Some kernels may require more shared memory than available, skip these
      // kernels.
      try {
        time = this->RunAndMeasureKernel(
            ctx, static_cast<int>(i), alpha, beta, args...);
        if (time < min_time) {
          min_time = time;
          best_idx = i;
        }
      } catch (const std::runtime_error& error) {
        VLOG(3) << "the kernels_[" << i << "] get error:" << error.what();
      }
    }

    if (best_idx == NO_KERNEL_WORKS) {
      LOG(ERROR) << "No kernel works!\n";
      exit(-1);
    }
    VLOG(3) << "best kernel idx is " << best_idx;
    return best_idx;
  }
};
// Convenience factory: deduces T, ReturnType and Args from the kernel
// function pointer and forwards it to GatherGemmScatterAutoTuner::Instance.
template <typename T, typename ReturnType, typename... Args>
static GatherGemmScatterAutoTuner<T, ReturnType, Args...>*
MakeGatherGemmScatterTuner(ReturnType (*func)(T, T, Args...)) {
  using Tuner = GatherGemmScatterAutoTuner<T, ReturnType, Args...>;
  Tuner* const tuner = Tuner::Instance(func);
  return tuner;
}
// Define the auto_tuner initial object.
#define DEFINE_AUTOTUNER_COMMON_OBJ(name) \
template <typename T, typename ReturnType, typename... Args> \
...
...
paddle/phi/kernels/autotune/cache.h
浏览文件 @
12d43da9
...
...
@@ -45,13 +45,15 @@ enum class AlgorithmType {
kConvBackwardFilter
=
3
,
kTranspose
=
4
,
kMatmul
=
5
,
kGatherGemmScatterFP16NN
=
6
,
kGatherGemmScatterFP32NN
=
7
,
#if !defined(PADDLE_WITH_CUDNN_FRONTEND)
kAlgorithmCount
=
6
kAlgorithmCount
=
8
#else
kConvForwardV8
=
6
,
kConvBackwardDataV8
=
7
,
kConvBackwardFilterV8
=
8
,
kAlgorithmCount
=
9
kConvForwardV8
=
8
,
kConvBackwardDataV8
=
9
,
kConvBackwardFilterV8
=
10
,
kAlgorithmCount
=
11
#endif
};
...
...
@@ -88,6 +90,20 @@ class AutoTuneCache {
return
conv_auto_tune_map_
[
static_cast
<
int64_t
>
(
algo_type
)];
}
// Returns the cache map for the FP32 gather-gemm-scatter algorithm.
// This overload participates only when T is float.
template <typename T>
typename std::enable_if<std::is_same<T, float>::value,
                        AlgorithmsCacheMap&>::type
GetGatherGemmScatter() {
  constexpr AlgorithmType kAlgo = AlgorithmType::kGatherGemmScatterFP32NN;
  return Get(kAlgo);
}

// Returns the cache map for the FP16 gather-gemm-scatter algorithm.
// This overload participates only when T is phi::dtype::float16.
template <typename T>
typename std::enable_if<std::is_same<T, phi::dtype::float16>::value,
                        AlgorithmsCacheMap&>::type
GetGatherGemmScatter() {
  constexpr AlgorithmType kAlgo = AlgorithmType::kGatherGemmScatterFP16NN;
  return Get(kAlgo);
}
#ifdef PADDLE_WITH_CUDNN_FRONTEND
CudnnFrontendPlanCache
&
GetConvV8
(
const
AlgorithmType
&
algo_type
)
{
return
cudnn_v8_auto_tune_map_
[
static_cast
<
int64_t
>
(
algo_type
)];
...
...
paddle/phi/kernels/sparse/gpu/conv_kernel.cu
浏览文件 @
12d43da9
...
...
@@ -125,12 +125,16 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx,
#ifdef PADDLE_WITH_CUTLASS
bool
cutlass
=
true
;
if
(
dev_ctx
.
GetComputeCapability
()
<
75
)
cutlass
=
false
;
if
(
in_channels
%
4
!=
0
||
out_channels
%
4
!=
0
)
{
if
(
dev_ctx
.
GetComputeCapability
()
<
80
)
cutlass
=
false
;
if
(
in_channels
%
8
!=
0
||
out_channels
%
8
!=
0
)
{
if
(
std
::
is_same
<
T
,
phi
::
dtype
::
float16
>::
value
)
cutlass
=
false
;
}
if
(
in_channels
%
4
!=
0
||
out_channels
%
4
!=
0
)
{
if
(
std
::
is_same
<
T
,
float
>::
value
)
cutlass
=
false
;
}
if
(
std
::
is_same
<
T
,
double
>::
value
)
cutlass
=
false
;
if
(
!
std
::
is_same
<
IntT
,
int32_t
>::
value
)
cutlass
=
false
;
if
(
cutlass
)
{
auto
*
out_values
=
out
->
mutable_non_zero_elements
();
T
*
out_values_ptr
=
out_values
->
data
<
T
>
();
...
...
@@ -150,18 +154,18 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx,
const
IntT
*
gather_indices
=
rulebook_ptr
+
h_offsets_ptr
[
i
];
const
IntT
*
scatter_indices
=
rulebook_ptr
+
rulebook_len
+
h_offsets_ptr
[
i
];
dispatchKernel
(
dev_ctx
,
x
.
non_zero_elements
().
data
<
T
>
(),
tmp_kernel_ptr
,
out_values_ptr
,
out_values_ptr
,
M
,
N
,
K
,
gather_indices
,
scatter_indices
,
cutlass
,
x
.
dtype
(
));
GatherGemmScatterDriver
(
dev_ctx
,
x
.
non_zero_elements
().
data
<
T
>
(),
tmp_kernel_ptr
,
out_values_ptr
,
out_values_ptr
,
M
,
N
,
K
,
gather_indices
,
scatter_indices
,
static_cast
<
T
>
(
1.0
)
,
static_cast
<
T
>
(
1.0
));
}
}
else
{
#endif
...
...
paddle/phi/kernels/sparse/gpu/cutlass_generator/common.h
0 → 100644
浏览文件 @
12d43da9
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#ifdef PADDLE_WITH_CUTLASS
#include "cutlass/arch/mma.h"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/device/gemm_universal.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/half.h"
#include "cutlass/util/device_memory.h"
#include "examples/common/helper.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
namespace
phi
{
namespace
sparse
{
// Declares `kernel` as the function-pointer type of a gather-gemm-scatter
// launcher for element type `dtype`. The parameter list matches the
// launchKernel signature produced by DEFINE_LAUNCH_KERNEL: alpha, beta, the
// device context, the a/b/c input pointers, the d output pointer, the m/n/k
// extents and the two int32 index arrays. The expansion ends with ';', so
// the macro is invoked without a trailing semicolon.
#define TYPEDEF_KERNEL_POINTER(kernel, dtype)              \
  using kernel = void (*)(dtype const alpha,               \
                          dtype const beta,                \
                          const GPUContext& dev_ctx,       \
                          const dtype* const a,            \
                          const dtype* const b,            \
                          const dtype* const c,            \
                          dtype* const d,                  \
                          const int m,                     \
                          const int n,                     \
                          const int k,                     \
                          const int32_t* a_indices,        \
                          const int32_t* c_d_indices);
// Throws std::runtime_error carrying cutlassGetStatusString(status) unless
// `status` equals cutlass::Status::kSuccess.
//
// Wrapped in do { ... } while (0) so that `GATHER_GEMM_SCATTER_CHECK(x);`
// is a single statement and stays correct inside an unbraced if/else. The
// argument is parenthesized and evaluated once into a uniquely named local,
// so a caller variable called `error` cannot collide with it.
#define GATHER_GEMM_SCATTER_CHECK(status)                            \
  do {                                                               \
    const cutlass::Status gather_gemm_scatter_status_ = (status);    \
    if (gather_gemm_scatter_status_ != cutlass::Status::kSuccess) {  \
      throw std::runtime_error(                                      \
          cutlassGetStatusString(gather_gemm_scatter_status_));      \
    }                                                                \
  } while (0)
// Defines `template <typename Gemm> void launchKernel(...)` for one element
// type. `dtype` is the Paddle-side element type of the pointers and scalars;
// `cutlass_type` is the type the pointers are reinterpreted to and the
// scalars are converted to (via float) before being handed to CUTLASS.
//
// The generated function builds Gemm::Arguments for a single (m, n, k)
// problem in GemmUniversalMode::kGemm with one split-k slice, using RowMajor
// capacities as the batch strides and k / n / n / n as the leading
// dimensions. `a_indices` and `c_d_indices` are passed in the first and third
// index slots and nullptr in the second -- presumably gather-A / gather-B /
// scatter-D; confirm against the CUTLASS version in use.
//
// can_implement(), initialize() and the launch itself are all checked with
// GATHER_GEMM_SCATTER_CHECK, so any non-success status is reported as
// std::runtime_error. Checking the launch status matters for auto-tuning:
// the tuner skips kernels that throw, whereas an unreported failed launch
// would be timed as if it were a working kernel.
#define DEFINE_LAUNCH_KERNEL(dtype, cutlass_type)                            \
  template <typename Gemm>                                                   \
  void launchKernel(dtype const alpha,                                       \
                    dtype const beta,                                        \
                    const GPUContext& dev_ctx,                               \
                    const dtype* const a,                                    \
                    const dtype* const b,                                    \
                    const dtype* const c,                                    \
                    dtype* const d,                                          \
                    const int m,                                             \
                    const int n,                                             \
                    const int k,                                             \
                    const int32_t* a_indices,                                \
                    const int32_t* c_d_indices) {                            \
    cutlass::gemm::GemmCoord problem_size_real({m, n, k});                   \
    int split_k_slices = 1;                                                  \
    typename Gemm::Arguments arguments{                                      \
        cutlass::gemm::GemmUniversalMode::kGemm,                             \
        problem_size_real,                                                   \
        split_k_slices,                                                      \
        {static_cast<const cutlass_type>(static_cast<const float>(alpha)),   \
         static_cast<const cutlass_type>(static_cast<const float>(beta))},   \
        reinterpret_cast<const cutlass_type* const>(a),                      \
        reinterpret_cast<const cutlass_type* const>(b),                      \
        reinterpret_cast<const cutlass_type* const>(c),                      \
        reinterpret_cast<cutlass_type* const>(d),                            \
        cutlass::layout::RowMajor().capacity(problem_size_real.mk()),        \
        cutlass::layout::RowMajor().capacity(problem_size_real.kn()),        \
        cutlass::layout::RowMajor().capacity(problem_size_real.mn()),        \
        cutlass::layout::RowMajor().capacity(problem_size_real.mn()),        \
        problem_size_real.k(),                                               \
        problem_size_real.n(),                                               \
        problem_size_real.n(),                                               \
        problem_size_real.n(),                                               \
        a_indices,                                                           \
        nullptr,                                                             \
        c_d_indices};                                                        \
    size_t workspace_size = Gemm::get_workspace_size(arguments);             \
    cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);   \
    Gemm gemm_op;                                                            \
    cutlass::Status status = gemm_op.can_implement(arguments);               \
    GATHER_GEMM_SCATTER_CHECK(status);                                       \
    status = gemm_op.initialize(arguments, workspace.get());                 \
    GATHER_GEMM_SCATTER_CHECK(status);                                       \
    status = gemm_op(dev_ctx.stream());                                      \
    GATHER_GEMM_SCATTER_CHECK(status);                                       \
  }
// Function-pointer types for the fp16 and fp32 gather-gemm-scatter launchers.
TYPEDEF_KERNEL_POINTER(fp16_gather_gemm_scatter, phi::dtype::float16)
TYPEDEF_KERNEL_POINTER(fp32_gather_gemm_scatter, float)

// launchKernel<Gemm> overloads: float16 data is handed to CUTLASS as
// cutlass::half_t, float data is passed through as float.
DEFINE_LAUNCH_KERNEL(phi::dtype::float16, cutlass::half_t)
DEFINE_LAUNCH_KERNEL(float, float)
}
// namespace sparse
}
// namespace phi
#endif
paddle/phi/kernels/sparse/gpu/cutlass/gather_gemm_scatter_generator.py
→
paddle/phi/kernels/sparse/gpu/cutlass
_generator
/gather_gemm_scatter_generator.py
浏览文件 @
12d43da9
...
...
@@ -41,7 +41,6 @@ def CreateGatherGemmScatterOperator(
layouts
,
tile_descriptions
,
data_type
,
alignment_constraints
,
complex_transforms
=
None
,
epilogue_functor
=
EpilogueFunctor
.
LinearCombination
,
swizzling_functor
=
SwizzlingFunctor
.
Identity8
,
...
...
@@ -55,12 +54,15 @@ def CreateGatherGemmScatterOperator(
element_a
,
element_b
,
element_c
,
element_epilogue
=
data_type
operations
=
[]
alignment_constraints
=
[
0
]
if
'f16'
==
element_a
.
name
or
'bf16'
==
element_a
.
name
:
alignment_constraints
=
[
8
]
elif
'f32'
==
element_a
.
name
or
'tf32'
==
element_a
.
name
:
alignment_constraints
=
[
4
]
elif
'f64'
==
element_a
.
name
:
alignment_constraints
=
[
1
]
# by default, only generate the largest tile and largest alignment
# if manifest.kernel_filter == '':
# tile_descriptions = [tile_descriptions[0],]
# alignment_constraints = [alignment_constraints[0],]
operations
=
[]
for
layout
in
layouts
:
for
tile_description
in
tile_descriptions
:
...
...
@@ -95,9 +97,9 @@ def CreateGatherGemmScatterOperator(
return
operations
def
GenerateSM
70_TensorOp_884
(
manifest
,
cuda_version
):
def
GenerateSM
80_TensorOp_16816
(
manifest
,
cuda_version
):
if
not
CudaToolkitVersionSatisfies
(
cuda_version
,
1
0
,
1
):
if
not
CudaToolkitVersionSatisfies
(
cuda_version
,
1
1
,
0
):
return
layouts
=
[
...
...
@@ -106,15 +108,7 @@ def GenerateSM70_TensorOp_884(manifest, cuda_version):
math_instructions
=
[
MathInstruction
(
[
8
,
8
,
4
],
DataType
.
f16
,
DataType
.
f16
,
DataType
.
f32
,
OpcodeClass
.
TensorOp
,
MathOperation
.
multiply_add
,
),
MathInstruction
(
[
8
,
8
,
4
],
[
16
,
8
,
16
],
DataType
.
f16
,
DataType
.
f16
,
DataType
.
f16
,
...
...
@@ -123,36 +117,78 @@ def GenerateSM70_TensorOp_884(manifest, cuda_version):
),
]
min_cc
=
7
0
max_cc
=
75
min_cc
=
8
0
max_cc
=
1024
alignment_constraints
=
[
8
,
4
,
2
,
1
]
alignment_constraints
=
[
8
]
for
math_inst
in
math_instructions
:
tile_descriptions
=
[
TileDescription
(
[
256
,
128
,
32
],
2
,
[
4
,
2
,
1
],
math_inst
,
min_cc
,
max_cc
[
256
,
128
,
32
],
3
,
[
4
,
2
,
1
],
math_inst
,
min_cc
,
max_cc
),
TileDescription
(
[
128
,
256
,
32
],
3
,
[
2
,
4
,
1
],
math_inst
,
min_cc
,
max_cc
),
TileDescription
(
[
256
,
64
,
32
],
3
,
[
4
,
1
,
1
],
math_inst
,
min_cc
,
max_cc
),
TileDescription
(
[
256
,
64
,
32
],
4
,
[
4
,
1
,
1
],
math_inst
,
min_cc
,
max_cc
),
TileDescription
(
[
64
,
256
,
32
],
4
,
[
1
,
4
,
1
],
math_inst
,
min_cc
,
max_cc
),
TileDescription
(
[
128
,
128
,
32
],
3
,
[
2
,
2
,
1
],
math_inst
,
min_cc
,
max_cc
),
TileDescription
(
[
128
,
128
,
32
],
4
,
[
2
,
2
,
1
],
math_inst
,
min_cc
,
max_cc
),
TileDescription
(
[
128
,
128
,
32
],
5
,
[
2
,
2
,
1
],
math_inst
,
min_cc
,
max_cc
),
TileDescription
(
[
128
,
64
,
32
],
6
,
[
2
,
2
,
1
],
math_inst
,
min_cc
,
max_cc
),
TileDescription
(
[
64
,
128
,
32
],
6
,
[
2
,
2
,
1
],
math_inst
,
min_cc
,
max_cc
),
TileDescription
(
[
64
,
64
,
32
],
10
,
[
2
,
2
,
1
],
math_inst
,
min_cc
,
max_cc
),
TileDescription
(
[
128
,
256
,
32
],
2
,
[
2
,
4
,
1
],
math_inst
,
min_cc
,
max_cc
[
256
,
128
,
64
],
3
,
[
4
,
2
,
1
],
math_inst
,
min_cc
,
max_cc
),
TileDescription
(
[
128
,
128
,
32
],
2
,
[
2
,
2
,
1
],
math_inst
,
min_cc
,
max_cc
[
128
,
256
,
64
],
3
,
[
2
,
4
,
1
],
math_inst
,
min_cc
,
max_cc
),
TileDescription
(
[
256
,
64
,
32
],
2
,
[
4
,
1
,
1
],
math_inst
,
min_cc
,
max_cc
[
256
,
64
,
64
],
4
,
[
4
,
1
,
1
],
math_inst
,
min_cc
,
max_cc
),
TileDescription
(
[
64
,
256
,
32
],
2
,
[
1
,
4
,
1
],
math_inst
,
min_cc
,
max_cc
[
64
,
256
,
64
],
4
,
[
1
,
4
,
1
],
math_inst
,
min_cc
,
max_cc
),
TileDescription
(
[
64
,
128
,
32
],
2
,
[
2
,
2
,
1
],
math_inst
,
min_cc
,
max_cc
[
128
,
128
,
64
],
4
,
[
2
,
2
,
1
],
math_inst
,
min_cc
,
max_cc
),
TileDescription
(
[
128
,
64
,
32
],
2
,
[
2
,
2
,
1
],
math_inst
,
min_cc
,
max_cc
[
256
,
64
,
64
],
3
,
[
4
,
1
,
1
],
math_inst
,
min_cc
,
max_cc
),
TileDescription
(
[
64
,
64
,
32
],
2
,
[
2
,
2
,
1
],
math_inst
,
min_cc
,
max_cc
[
64
,
256
,
64
],
3
,
[
1
,
4
,
1
],
math_inst
,
min_cc
,
max_cc
),
TileDescription
(
[
128
,
128
,
64
],
3
,
[
2
,
2
,
1
],
math_inst
,
min_cc
,
max_cc
),
TileDescription
(
[
128
,
64
,
64
],
3
,
[
2
,
2
,
1
],
math_inst
,
min_cc
,
max_cc
),
TileDescription
(
[
64
,
128
,
64
],
3
,
[
2
,
2
,
1
],
math_inst
,
min_cc
,
max_cc
),
TileDescription
(
[
64
,
64
,
64
],
5
,
[
2
,
2
,
1
],
math_inst
,
min_cc
,
max_cc
),
]
...
...
@@ -164,11 +200,7 @@ def GenerateSM70_TensorOp_884(manifest, cuda_version):
]
CreateGatherGemmScatterOperator
(
manifest
,
layouts
,
tile_descriptions
,
data_type
,
alignment_constraints
,
manifest
,
layouts
,
tile_descriptions
,
data_type
)
# Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation)
...
...
@@ -182,16 +214,286 @@ def GenerateSM70_TensorOp_884(manifest, cuda_version):
]
CreateGatherGemmScatterOperator
(
manifest
,
layouts
,
tile_descriptions
,
data_type_mixed
,
alignment_constraints
,
manifest
,
layouts
,
tile_descriptions
,
data_type_mixed
)
def
GenerateSM70
(
manifest
,
cuda_version
):
GenerateSM70_TensorOp_884
(
manifest
,
cuda_version
)
def GenerateSM80_TensorOp_1688(manifest, cuda_version):
    """Register the SM80 TF32 tensor-op (16x8x8) gather-gemm-scatter kernels.

    Does nothing unless the CUDA toolkit is at least 11.0. For the single
    TF32 math instruction, one operator set is created with the accumulator
    type as the third element type and one with the A element type in that
    position (the "mixed" variant), both over the same row-major layouts and
    tile descriptions.
    """
    if not CudaToolkitVersionSatisfies(cuda_version, 11, 0):
        return

    row_major = LayoutType.RowMajor
    layouts = [(row_major, row_major, row_major)]

    math_instructions = [
        MathInstruction(
            [16, 8, 8],
            DataType.tf32,
            DataType.tf32,
            DataType.f32,
            OpcodeClass.TensorOp,
            MathOperation.multiply_add,
        )
    ]

    min_cc = 80
    max_cc = 1024

    # (threadblock shape, stages, warp count) for every tile to generate.
    tile_specs = (
        ((256, 128, 16), 3, (4, 2, 1)),
        ((128, 256, 16), 3, (2, 4, 1)),
        ((256, 64, 16), 4, (4, 1, 1)),
        ((64, 256, 16), 4, (1, 4, 1)),
        ((128, 128, 16), 5, (2, 2, 1)),
        ((128, 128, 16), 4, (2, 2, 1)),
        ((128, 128, 16), 3, (2, 2, 1)),
        ((128, 64, 16), 6, (2, 2, 1)),
        ((64, 128, 16), 6, (2, 2, 1)),
        ((64, 64, 16), 10, (2, 2, 1)),
        ((256, 128, 32), 3, (4, 2, 1)),
        ((128, 256, 32), 3, (2, 4, 1)),
        ((256, 64, 32), 4, (4, 1, 1)),
        ((64, 256, 32), 4, (1, 4, 1)),
        ((128, 128, 32), 4, (2, 2, 1)),
        ((128, 128, 32), 3, (2, 2, 1)),
        ((128, 64, 32), 3, (2, 2, 1)),
        ((64, 128, 32), 3, (2, 2, 1)),
        ((64, 64, 32), 5, (2, 2, 1)),
    )

    for math_inst in math_instructions:
        tile_descriptions = [
            TileDescription(
                list(shape), stages, list(warps), math_inst, min_cc, max_cc
            )
            for shape, stages, warps in tile_specs
        ]

        data_type = [
            math_inst.element_a,
            math_inst.element_b,
            math_inst.element_accumulator,
            math_inst.element_accumulator,
        ]
        data_type_mixed = [
            math_inst.element_a,
            math_inst.element_b,
            math_inst.element_a,
            math_inst.element_accumulator,
        ]

        for element_types in (data_type, data_type_mixed):
            CreateGatherGemmScatterOperator(
                manifest, layouts, tile_descriptions, element_types
            )
def GenerateSM80_TensorOp_1688_fast_math(manifest, cuda_version):
    """Register SM80 tensor-op (16x8x8) kernels that take plain f32 data.

    Does nothing unless the CUDA toolkit is at least 11.0. The math
    instruction is the TF32 multiply-add, while all four element types passed
    to CreateGatherGemmScatterOperator are f32.
    """
    if not CudaToolkitVersionSatisfies(cuda_version, 11, 0):
        return

    row_major = LayoutType.RowMajor
    layouts = [(row_major, row_major, row_major)]

    math_instructions = [
        MathInstruction(
            [16, 8, 8],
            DataType.tf32,
            DataType.tf32,
            DataType.f32,
            OpcodeClass.TensorOp,
            MathOperation.multiply_add,
        ),
    ]

    min_cc = 80
    max_cc = 1024

    # (threadblock shape, stages, warp count) for every tile to generate.
    tile_specs = (
        ((256, 128, 16), 3, (4, 2, 1)),
        ((128, 256, 16), 3, (2, 4, 1)),
        ((256, 64, 16), 4, (4, 1, 1)),
        ((64, 256, 16), 4, (1, 4, 1)),
        ((128, 128, 16), 5, (2, 2, 1)),
        ((128, 128, 16), 4, (2, 2, 1)),
        ((128, 128, 16), 3, (2, 2, 1)),
        ((128, 64, 16), 6, (2, 2, 1)),
        ((64, 128, 16), 6, (2, 2, 1)),
        ((64, 64, 16), 10, (2, 2, 1)),
        ((256, 128, 32), 3, (4, 2, 1)),
        ((128, 256, 32), 3, (2, 4, 1)),
        ((256, 64, 32), 4, (4, 1, 1)),
        ((64, 256, 32), 4, (1, 4, 1)),
        ((128, 128, 32), 4, (2, 2, 1)),
        ((128, 128, 32), 3, (2, 2, 1)),
        ((128, 64, 32), 3, (2, 2, 1)),
        ((64, 128, 32), 3, (2, 2, 1)),
        ((64, 64, 32), 5, (2, 2, 1)),
    )

    for math_inst in math_instructions:
        tile_descriptions = [
            TileDescription(
                list(shape), stages, list(warps), math_inst, min_cc, max_cc
            )
            for shape, stages, warps in tile_specs
        ]

        data_type = [DataType.f32, DataType.f32, DataType.f32, DataType.f32]

        CreateGatherGemmScatterOperator(
            manifest, layouts, tile_descriptions, data_type
        )
def GenerateSM80_TensorOp_1688_fast_fp32_math(manifest, cuda_version):
    """Register SM80 tensor-op (16x8x8) kernels using multiply_add_fast_f32.

    Does nothing unless the CUDA toolkit is at least 11.0. Both the math
    instruction operands and all four operator element types are f32.
    """
    if not CudaToolkitVersionSatisfies(cuda_version, 11, 0):
        return

    row_major = LayoutType.RowMajor
    layouts = [(row_major, row_major, row_major)]

    math_instructions = [
        MathInstruction(
            [16, 8, 8],
            DataType.f32,
            DataType.f32,
            DataType.f32,
            OpcodeClass.TensorOp,
            MathOperation.multiply_add_fast_f32,
        ),
    ]

    min_cc = 80
    max_cc = 1024

    # (threadblock shape, stages, warp count) for every tile to generate.
    tile_specs = (
        ((128, 128, 16), 4, (4, 2, 1)),
        ((128, 128, 16), 3, (4, 2, 1)),
        ((256, 64, 16), 3, (4, 2, 1)),
        ((64, 256, 16), 3, (2, 4, 1)),
        ((128, 64, 16), 4, (2, 2, 1)),
        ((64, 128, 16), 4, (2, 2, 1)),
        ((64, 64, 16), 3, (2, 2, 1)),
        ((128, 128, 32), 3, (4, 2, 1)),
        ((256, 64, 32), 3, (4, 2, 1)),
        ((64, 256, 32), 3, (2, 4, 1)),
        ((128, 64, 32), 3, (2, 2, 1)),
        ((64, 128, 32), 3, (2, 2, 1)),
        ((64, 64, 32), 3, (2, 2, 1)),
    )

    for math_inst in math_instructions:
        tile_descriptions = [
            TileDescription(
                list(shape), stages, list(warps), math_inst, min_cc, max_cc
            )
            for shape, stages, warps in tile_specs
        ]

        data_type = [DataType.f32, DataType.f32, DataType.f32, DataType.f32]

        CreateGatherGemmScatterOperator(
            manifest, layouts, tile_descriptions, data_type
        )
def GenerateSM80(manifest, cuda_version):
    """Run every SM80 kernel generator against `manifest`, in a fixed order."""
    sm80_generators = (
        GenerateSM80_TensorOp_16816,
        GenerateSM80_TensorOp_1688,
        GenerateSM80_TensorOp_1688_fast_math,
        GenerateSM80_TensorOp_1688_fast_fp32_math,
    )
    for generate in sm80_generators:
        generate(manifest, cuda_version)
class
KernelCfg
:
...
...
@@ -229,7 +531,7 @@ class KernelCfg:
if
__name__
==
"__main__"
:
args
=
KernelCfg
(
architectures
=
'
7
0'
,
architectures
=
'
8
0'
,
build_dir
=
sys
.
argv
[
2
],
cuda_version
=
sys
.
argv
[
3
],
curr_build_dir
=
sys
.
argv
[
2
],
...
...
@@ -245,6 +547,6 @@ if __name__ == "__main__":
)
manifest
=
GatherGemmScatterManifest
(
args
)
GenerateSM
7
0
(
manifest
,
args
.
cuda_version
)
GenerateSM
8
0
(
manifest
,
args
.
cuda_version
)
manifest
.
emit
(
GeneratorTarget
.
Library
)
paddle/phi/kernels/sparse/gpu/cutlass/gather_gemm_scatter_manifest.py
→
paddle/phi/kernels/sparse/gpu/cutlass
_generator
/gather_gemm_scatter_manifest.py
浏览文件 @
12d43da9
...
...
@@ -18,7 +18,7 @@ import shutil
from
gather_gemm_scatter_operation
import
(
EmitGatherGemmScatterConfigurationLibrary
,
)
from
library
import
OperationKind
,
OperationKindNames
from
library
import
OperationKind
,
OperationKindNames
,
SubstituteTemplate
from
manifest
import
EmitOperationKindLibrary
,
GeneratorTarget
,
Manifest
...
...
@@ -28,11 +28,25 @@ class GatherGemmScatterEmitOperationKindLibrary(EmitOperationKindLibrary):
self
.
emitters
=
{
OperationKind
.
Gemm
:
EmitGatherGemmScatterConfigurationLibrary
}
self
.
header_template
=
"#pragma once
\n
#ifdef PADDLE_WITH_CUTLASS
\n
"
self
.
header_template
=
"#pragma once
\n
#ifdef PADDLE_WITH_CUTLASS
\n
#include
\"
paddle/phi/kernels/sparse/gpu/cutlass_generator/common.h
\"\n
"
self
.
entry_template
=
""
self
.
configuration_prototype_template
=
""
self
.
configuration_template
=
""
self
.
epilogue_template
=
"#endif"
self
.
namespace_template
=
"""
namespace phi {
namespace sparse {
"""
self
.
epilogue_template
=
"""
} // namespace sparse
} // namespace phi
#endif
"""
self
.
fp16_kernels_list
=
(
"static std::vector<fp16_gather_gemm_scatter> fp16_kernels = {
\n
"
)
self
.
fp32_kernels_list
=
(
"static std::vector<fp32_gather_gemm_scatter> fp32_kernels = {
\n
"
)
def
__enter__
(
self
):
self
.
operation_path
=
os
.
path
.
join
(
...
...
@@ -64,6 +78,21 @@ class GatherGemmScatterEmitOperationKindLibrary(EmitOperationKindLibrary):
self
.
source_files
.
append
(
configuration_emitter
.
configuration_path
)
self
.
configurations
.
append
(
configuration_name
)
if
'h'
==
operations
[
0
].
short_math_name
():
self
.
fp16_kernels_list
+=
(
"""
launchKernel<"""
+
configuration_name
+
"::Gemm>,"
)
if
's'
==
operations
[
0
].
short_math_name
():
self
.
fp32_kernels_list
+=
(
"""
launchKernel<"""
+
configuration_name
+
"::Gemm>,"
)
self
.
top_level_file
.
write
(
'#include "'
+
self
.
operation_path
...
...
@@ -72,6 +101,30 @@ class GatherGemmScatterEmitOperationKindLibrary(EmitOperationKindLibrary):
+
'.h"
\n
'
)
def __exit__(self, exception_type, exception_value, traceback):
    """Finish and close the top-level generated file.

    Writes the entry template, one configuration template per recorded
    configuration, then the namespace opener, the terminated fp16 and fp32
    kernel lists, and the epilogue. The exception arguments are not used.
    """
    out = self.top_level_file

    entry_values = {'operation_name': OperationKindNames[self.kind]}
    out.write(SubstituteTemplate(self.entry_template, entry_values))

    for name in self.configurations:
        out.write(
            SubstituteTemplate(
                self.configuration_template, {'configuration_name': name}
            )
        )

    # Close the two C++ initializer lists built up while emitting.
    list_terminator = "\n};\n"
    self.fp16_kernels_list += list_terminator
    self.fp32_kernels_list += list_terminator

    out.write(self.namespace_template)
    out.write(self.fp16_kernels_list)
    out.write(self.fp32_kernels_list)
    out.write(self.epilogue_template)
    out.close()
class
GatherGemmScatterManifest
(
Manifest
):
def
emit
(
self
,
target
=
GeneratorTarget
.
Library
):
...
...
paddle/phi/kernels/sparse/gpu/cutlass/gather_gemm_scatter_operation.py
→
paddle/phi/kernels/sparse/gpu/cutlass
_generator
/gather_gemm_scatter_operation.py
浏览文件 @
12d43da9
...
...
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import
collections
import
enum
import
os.path
...
...
@@ -40,16 +41,7 @@ from library import (
class
EmitGatherGemmScatterInstance
(
EmitGemmInstance
):
def
__init__
(
self
,
operation_suffix
=
''
):
self
.
operation_suffix
=
operation_suffix
self
.
includes
=
[
"cutlass/cutlass.h"
,
"cutlass/numeric_types.h"
,
"cutlass/arch/arch.h"
,
"cutlass/arch/mma.h"
,
"cutlass/layout/matrix.h"
,
"cutlass/gemm/device/gemm.h"
,
"cutlass/gemm/device/gemm_universal_adapter.h"
,
"cutlass/gemm/kernel/default_gemm_universal.h"
,
]
self
.
includes
=
[]
self
.
builtin_epilogue_functor_template
=
"""
${epilogue_functor}<
${element_c},
...
...
@@ -247,6 +239,18 @@ namespace sparse {
#endif
"""
def __enter__(self):
    """Open the configuration file, write its preamble and reset state.

    Writes the header template and the separator to a freshly opened file
    at `configuration_path`, then clears the per-file collections.
    """
    handle = open(self.configuration_path, "w")
    self.configuration_file = handle
    handle.write(self.header_template)
    handle.write(self.separator)

    # Fresh, empty bookkeeping for this configuration file.
    self.includes = collections.OrderedDict()
    self.instance_definitions = list()
    self.instance_wrappers = list()
    self.operations = list()

    return self
def
__exit__
(
self
,
exception_type
,
exception_value
,
traceback
):
# Write includes
...
...
paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.cu
已删除
100644 → 0
浏览文件 @
be9515f2
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifdef PADDLE_WITH_CUTLASS
#include "paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.h"
namespace
phi
{
namespace
sparse
{
// Picks a pre-instantiated fp16 gather-gemm-scatter kernel from hard-coded
// heuristics on the problem extents M, N and K.
//
// Fix: in the K == 64 && N == 64 case, the `M > 100000` and `M > 20000`
// branches were bare expression statements without `return`, so their
// kernels could never be selected and control always fell through to the
// `M > 15000` check. Both branches now return their kernel.
fp16_gather_gemm_scatter getBestFp16Kernel(const int M,
                                           const int N,
                                           const int K) {
  if (K == 4 && N == 16) {
    return launchKernel<cutlass::half_t,
                        cutlass_tensorop_h1688gemm_64x64_32x2_nn_align4::Gemm>;
  }
  if (K == 16 && N == 16) {
    return launchKernel<cutlass::half_t,
                        cutlass_tensorop_h1688gemm_64x64_32x2_nn_align8::Gemm>;
  }
  if (K == 16 && N == 32) {
    return launchKernel<cutlass::half_t,
                        cutlass_tensorop_h1688gemm_64x64_32x2_nn_align8::Gemm>;
  }
  if (K == 32 && N == 32) {
    return launchKernel<cutlass::half_t,
                        cutlass_tensorop_h1688gemm_64x64_32x2_nn_align8::Gemm>;
  }
  if (K == 32 && N == 64) {
    return launchKernel<cutlass::half_t,
                        cutlass_tensorop_h1688gemm_64x64_32x2_nn_align8::Gemm>;
  }
  if (K == 64 && N == 64) {
    // Thresholds are checked from largest to smallest M.
    if (M > 100000) {
      return launchKernel<
          cutlass::half_t,
          cutlass_tensorop_f16_s1688gemm_f16_64x128_32x2_nn_align8::Gemm>;
    }
    if (M > 20000) {
      return launchKernel<
          cutlass::half_t,
          cutlass_tensorop_f16_s1688gemm_f16_64x64_32x2_nn_align8::Gemm>;
    }
    if (M > 15000) {
      return launchKernel<
          cutlass::half_t,
          cutlass_tensorop_h1688gemm_128x64_32x2_nn_align8::Gemm>;
    }
    return launchKernel<cutlass::half_t,
                        cutlass_tensorop_h1688gemm_64x64_32x2_nn_align8::Gemm>;
  }
  if (K == 128) {
    if (M >= 5000) {
      return launchKernel<
          cutlass::half_t,
          cutlass_tensorop_h1688gemm_64x64_32x2_nn_align8::Gemm>;
    }
    return launchKernel<cutlass::half_t,
                        cutlass_tensorop_h16816gemm_64x64_64x5_nn_align8::Gemm>;
  }
  if (N == 128) {
    return launchKernel<cutlass::half_t,
                        cutlass_tensorop_h1688gemm_64x64_32x2_nn_align8::Gemm>;
  }
  // Fallback for every other shape.
  return launchKernel<cutlass::half_t,
                      cutlass_tensorop_h1688gemm_64x64_32x2_nn_align4::Gemm>;
}
// Picks a pre-instantiated fp32 gather-gemm-scatter kernel from hard-coded
// heuristics. SM == 75 always gets one fixed kernel; otherwise the choice
// depends on the extents M, N and K. The K == 128 rule is tested before the
// N == 128 rule.
fp32_gather_gemm_scatter getBestFp32Kernel(const int M,
                                           const int N,
                                           const int K,
                                           const int SM) {
  if (SM == 75) {
    return launchKernel<
        float,
        cutlass_tensorop_s1688gemm_f16_64x64_32x2_nn_align4::Gemm>;
  }

  // Kernel used for the small shapes and as the overall fallback.
  const fp32_gather_gemm_scatter small_tile_kernel =
      launchKernel<float,
                   cutlass_tensorop_s1688f16gemm_64x64_16x10_nn_align4::Gemm>;
  // Kernel used for the mid-sized shapes once M is large enough.
  const fp32_gather_gemm_scatter large_m_kernel =
      launchKernel<float,
                   cutlass_tensorop_s1688gemm_64x64_16x3_nn_align4::Gemm>;

  const bool tiny_shape = (K == 4 || K == 16) && N == 16;
  if (tiny_shape) {
    return small_tile_kernel;
  }

  const bool mid_shape =
      (K == 16 && N == 32) || (K == 32 && (N == 32 || N == 64));
  if (mid_shape) {
    if (M >= 10000) {
      return large_m_kernel;
    }
    return small_tile_kernel;
  }

  if (K == 64 && N == 64) {
    if (M >= 15000) {
      return large_m_kernel;
    }
    return small_tile_kernel;
  }

  if (K == 128) {
    if (M >= 100000) {
      return launchKernel<
          float,
          cutlass_tensorop_s1688f16gemm_128x128_16x3_nn_align4::Gemm>;
    } else if (M >= 5000) {
      return launchKernel<
          float,
          cutlass_tensorop_s1688f16gemm_256x64_16x4_nn_align4::Gemm>;
    } else {
      return launchKernel<
          float,
          cutlass_tensorop_s1688tf32gemm_256x128_16x3_nn_align4::Gemm>;
    }
  }

  if (N == 128) {
    if (M >= 100000) {
      return launchKernel<
          float,
          cutlass_tensorop_s1688tf32gemm_256x128_16x3_nn_align4::Gemm>;
    } else if (M >= 5000) {
      return launchKernel<
          float,
          cutlass_tensorop_s1688f16gemm_128x128_16x3_nn_align4::Gemm>;
    } else {
      return launchKernel<
          float,
          cutlass_tensorop_s1688f16gemm_64x128_16x6_nn_align4::Gemm>;
    }
  }

  return small_tile_kernel;
}
// Picks one of two pre-instantiated fp64 gather-gemm-scatter kernels from
// hard-coded heuristics on M, N and K. The 16x32 tile is used for
// (K=4, N=16), for (K=16, N=16) when M < 10000, and for (K=32, N=32); every
// other shape uses the 32x16 tile.
fp64_gather_gemm_scatter getBestFp64Kernel(const int M,
                                           const int N,
                                           const int K) {
  const bool prefer_16x32_tile = (K == 4 && N == 16) ||
                                 (K == 16 && N == 16 && M < 10000) ||
                                 (K == 32 && N == 32);
  if (prefer_16x32_tile) {
    return launchKernel<double,
                        cutlass_tensorop_d884gemm_16x32_16x5_nn_align1::Gemm>;
  }
  return launchKernel<double,
                      cutlass_tensorop_d884gemm_32x16_16x5_nn_align1::Gemm>;
}
}
// namespace sparse
}
// namespace phi
#endif
paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.h
浏览文件 @
12d43da9
此差异已折叠。
点击以展开。
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录