Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
wmsofts
Paddle
提交
507af1c8
P
Paddle
项目概览
wmsofts
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
507af1c8
编写于
2月 20, 2023
作者:
U
umiswing
提交者:
GitHub
2月 20, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add generator scripts for cutlass (#50364)
上级
c92b1c54
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
679 addition
and
1 deletion
+679
-1
cmake/external/cutlass.cmake
cmake/external/cutlass.cmake
+8
-1
paddle/phi/kernels/sparse/gpu/cutlass/gather_gemm_scatter_generator.py
...rnels/sparse/gpu/cutlass/gather_gemm_scatter_generator.py
+250
-0
paddle/phi/kernels/sparse/gpu/cutlass/gather_gemm_scatter_manifest.py
...ernels/sparse/gpu/cutlass/gather_gemm_scatter_manifest.py
+101
-0
paddle/phi/kernels/sparse/gpu/cutlass/gather_gemm_scatter_operation.py
...rnels/sparse/gpu/cutlass/gather_gemm_scatter_operation.py
+320
-0
未找到文件。
cmake/external/cutlass.cmake
浏览文件 @
507af1c8
...
...
@@ -34,7 +34,14 @@ ExternalProject_Add(
PREFIX
${
CUTLASS_PREFIX_DIR
}
UPDATE_COMMAND
""
CONFIGURE_COMMAND
""
BUILD_COMMAND
""
BUILD_COMMAND
mkdir -p
${
CMAKE_SOURCE_DIR
}
/paddle/phi/kernels/sparse/gpu/cutlass/build/generated/gemm
&&
${
PYTHON_EXECUTABLE
}
-B
${
CMAKE_SOURCE_DIR
}
/paddle/phi/kernels/sparse/gpu/cutlass/gather_gemm_scatter_generator.py
"
${
THIRD_PARTY_PATH
}
/cutlass/src/extern_cutlass/tools/library/scripts/"
"
${
CMAKE_SOURCE_DIR
}
/paddle/phi/kernels/sparse/gpu/cutlass/build"
"
${
CMAKE_CUDA_COMPILER_VERSION
}
"
INSTALL_COMMAND
""
TEST_COMMAND
""
)
...
...
paddle/phi/kernels/sparse/gpu/cutlass/gather_gemm_scatter_generator.py
0 → 100644
浏览文件 @
507af1c8
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
sys
sys
.
path
.
append
(
sys
.
argv
[
1
])
from
gather_gemm_scatter_manifest
import
GatherGemmScatterManifest
from
gather_gemm_scatter_operation
import
GatherGemmScatterOperation
from
generator
import
(
ComplexTransform
,
CudaToolkitVersionSatisfies
,
EpilogueFunctor
,
GemmKind
,
SwizzlingFunctor
,
TensorDescription
,
)
from
library
import
(
DataType
,
LayoutType
,
MathInstruction
,
MathOperation
,
OpcodeClass
,
TileDescription
,
)
from
manifest
import
GeneratorTarget
def
CreateGatherGemmScatterOperator
(
manifest
,
layouts
,
tile_descriptions
,
data_type
,
alignment_constraints
,
complex_transforms
=
None
,
epilogue_functor
=
EpilogueFunctor
.
LinearCombination
,
swizzling_functor
=
SwizzlingFunctor
.
Identity8
,
):
# To use StreamK decomposition for basic GEMMs, set `swizzling_functor = SwizzlingFunctor.StreamK`
if
complex_transforms
is
None
:
complex_transforms
=
[
(
ComplexTransform
.
none
,
ComplexTransform
.
none
),
]
element_a
,
element_b
,
element_c
,
element_epilogue
=
data_type
operations
=
[]
# by default, only generate the largest tile and largest alignment
# if manifest.kernel_filter == '':
# tile_descriptions = [tile_descriptions[0],]
# alignment_constraints = [alignment_constraints[0],]
for
layout
in
layouts
:
for
tile_description
in
tile_descriptions
:
for
alignment
in
alignment_constraints
:
for
complex_transform
in
complex_transforms
:
alignment_c
=
min
(
8
,
alignment
)
A
=
TensorDescription
(
element_a
,
layout
[
0
],
alignment
,
complex_transform
[
0
]
)
B
=
TensorDescription
(
element_b
,
layout
[
1
],
alignment
,
complex_transform
[
1
]
)
C
=
TensorDescription
(
element_c
,
layout
[
2
],
alignment_c
)
new_operation
=
GatherGemmScatterOperation
(
GemmKind
.
Universal
,
tile_description
.
minimum_compute_capability
,
tile_description
,
A
,
B
,
C
,
element_epilogue
,
epilogue_functor
,
swizzling_functor
,
)
manifest
.
append
(
new_operation
)
operations
.
append
(
new_operation
)
return
operations
def
GenerateSM70_TensorOp_884
(
manifest
,
cuda_version
):
if
not
CudaToolkitVersionSatisfies
(
cuda_version
,
10
,
1
):
return
layouts
=
[
(
LayoutType
.
RowMajor
,
LayoutType
.
RowMajor
,
LayoutType
.
RowMajor
),
]
math_instructions
=
[
MathInstruction
(
[
8
,
8
,
4
],
DataType
.
f16
,
DataType
.
f16
,
DataType
.
f32
,
OpcodeClass
.
TensorOp
,
MathOperation
.
multiply_add
,
),
MathInstruction
(
[
8
,
8
,
4
],
DataType
.
f16
,
DataType
.
f16
,
DataType
.
f16
,
OpcodeClass
.
TensorOp
,
MathOperation
.
multiply_add
,
),
]
min_cc
=
70
max_cc
=
75
alignment_constraints
=
[
8
,
4
,
2
,
1
]
for
math_inst
in
math_instructions
:
tile_descriptions
=
[
TileDescription
(
[
256
,
128
,
32
],
2
,
[
4
,
2
,
1
],
math_inst
,
min_cc
,
max_cc
),
TileDescription
(
[
128
,
256
,
32
],
2
,
[
2
,
4
,
1
],
math_inst
,
min_cc
,
max_cc
),
TileDescription
(
[
128
,
128
,
32
],
2
,
[
2
,
2
,
1
],
math_inst
,
min_cc
,
max_cc
),
TileDescription
(
[
256
,
64
,
32
],
2
,
[
4
,
1
,
1
],
math_inst
,
min_cc
,
max_cc
),
TileDescription
(
[
64
,
256
,
32
],
2
,
[
1
,
4
,
1
],
math_inst
,
min_cc
,
max_cc
),
TileDescription
(
[
64
,
128
,
32
],
2
,
[
2
,
2
,
1
],
math_inst
,
min_cc
,
max_cc
),
TileDescription
(
[
128
,
64
,
32
],
2
,
[
2
,
2
,
1
],
math_inst
,
min_cc
,
max_cc
),
TileDescription
(
[
64
,
64
,
32
],
2
,
[
2
,
2
,
1
],
math_inst
,
min_cc
,
max_cc
),
]
data_type
=
[
math_inst
.
element_a
,
math_inst
.
element_b
,
math_inst
.
element_accumulator
,
math_inst
.
element_accumulator
,
]
CreateGatherGemmScatterOperator
(
manifest
,
layouts
,
tile_descriptions
,
data_type
,
alignment_constraints
,
)
# Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation)
if
math_inst
.
element_a
!=
math_inst
.
element_accumulator
:
data_type_mixed
=
[
math_inst
.
element_a
,
math_inst
.
element_b
,
math_inst
.
element_a
,
math_inst
.
element_accumulator
,
]
CreateGatherGemmScatterOperator
(
manifest
,
layouts
,
tile_descriptions
,
data_type_mixed
,
alignment_constraints
,
)
def
GenerateSM70
(
manifest
,
cuda_version
):
GenerateSM70_TensorOp_884
(
manifest
,
cuda_version
)
class
KernelCfg
:
def
__init__
(
self
,
architectures
,
build_dir
,
cuda_version
,
curr_build_dir
,
disable_full_archs_compilation
,
filter_by_cc
,
generator_target
,
ignore_kernels
,
interface_dir
,
kernel_filter_file
,
kernels
,
operations
,
selected_kernel_list
,
):
self
.
architectures
=
architectures
self
.
build_dir
=
build_dir
self
.
cuda_version
=
cuda_version
self
.
curr_build_dir
=
curr_build_dir
self
.
disable_full_archs_compilation
=
disable_full_archs_compilation
self
.
filter_by_cc
=
filter_by_cc
self
.
generator_target
=
generator_target
self
.
ignore_kernels
=
ignore_kernels
self
.
interface_dir
=
interface_dir
self
.
kernel_filter_file
=
kernel_filter_file
self
.
kernels
=
kernels
self
.
operations
=
operations
self
.
selected_kernel_list
=
selected_kernel_list
if
__name__
==
"__main__"
:
args
=
KernelCfg
(
architectures
=
'70'
,
build_dir
=
sys
.
argv
[
2
],
cuda_version
=
sys
.
argv
[
3
],
curr_build_dir
=
sys
.
argv
[
2
],
disable_full_archs_compilation
=
False
,
filter_by_cc
=
'True'
,
generator_target
=
'library'
,
ignore_kernels
=
''
,
interface_dir
=
None
,
kernel_filter_file
=
None
,
kernels
=
''
,
operations
=
'all'
,
selected_kernel_list
=
None
,
)
manifest
=
GatherGemmScatterManifest
(
args
)
GenerateSM70
(
manifest
,
args
.
cuda_version
)
manifest
.
emit
(
GeneratorTarget
.
Library
)
paddle/phi/kernels/sparse/gpu/cutlass/gather_gemm_scatter_manifest.py
0 → 100644
浏览文件 @
507af1c8
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
shutil
from
gather_gemm_scatter_operation
import
(
EmitGatherGemmScatterConfigurationLibrary
,
)
from
library
import
OperationKind
,
OperationKindNames
from
manifest
import
EmitOperationKindLibrary
,
GeneratorTarget
,
Manifest
class
GatherGemmScatterEmitOperationKindLibrary
(
EmitOperationKindLibrary
):
def
__init__
(
self
,
generated_path
,
kind
,
args
):
super
().
__init__
(
generated_path
,
kind
,
args
)
self
.
emitters
=
{
OperationKind
.
Gemm
:
EmitGatherGemmScatterConfigurationLibrary
}
self
.
header_template
=
"#pragma once
\n
#ifdef PADDLE_WITH_CUTLASS
\n
"
self
.
entry_template
=
""
self
.
configuration_prototype_template
=
""
self
.
configuration_template
=
""
self
.
epilogue_template
=
"#endif"
def
__enter__
(
self
):
self
.
operation_path
=
os
.
path
.
join
(
self
.
generated_path
,
OperationKindNames
[
self
.
kind
]
)
os
.
mkdir
(
self
.
operation_path
)
self
.
top_level_path
=
os
.
path
.
join
(
self
.
operation_path
,
"all_%s_operations.h"
%
OperationKindNames
[
self
.
kind
],
)
self
.
top_level_file
=
open
(
self
.
top_level_path
,
"w"
)
self
.
top_level_file
.
write
(
self
.
header_template
)
self
.
source_files
=
[
self
.
top_level_path
,
]
return
self
def
emit
(
self
,
configuration_name
,
operations
):
with
self
.
emitters
[
self
.
kind
](
self
.
operation_path
,
configuration_name
)
as
configuration_emitter
:
for
operation
in
operations
:
configuration_emitter
.
emit
(
operation
)
self
.
source_files
.
append
(
configuration_emitter
.
configuration_path
)
self
.
configurations
.
append
(
configuration_name
)
self
.
top_level_file
.
write
(
'#include "'
+
self
.
operation_path
+
'/'
+
configuration_name
+
'.h"
\n
'
)
class
GatherGemmScatterManifest
(
Manifest
):
def
emit
(
self
,
target
=
GeneratorTarget
.
Library
):
operation_emitters
=
{
GeneratorTarget
.
Library
:
GatherGemmScatterEmitOperationKindLibrary
}
generated_path
=
os
.
path
.
join
(
self
.
curr_build_dir
,
'generated'
)
# create generated/
if
os
.
path
.
exists
(
generated_path
):
shutil
.
rmtree
(
generated_path
)
os
.
mkdir
(
generated_path
)
source_files
=
[]
# for each operation kind, emit initializer for all configurations
for
operation_kind
,
configurations
in
self
.
operations
.
items
():
with
operation_emitters
[
target
](
generated_path
,
operation_kind
,
self
.
args
)
as
operation_kind_emitter
:
for
configuration_name
,
operations
in
configurations
.
items
():
operation_kind_emitter
.
emit
(
configuration_name
,
operations
)
source_files
+=
operation_kind_emitter
.
source_files
paddle/phi/kernels/sparse/gpu/cutlass/gather_gemm_scatter_operation.py
0 → 100644
浏览文件 @
507af1c8
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
enum
import
os.path
from
gemm_operation
import
(
EmitGemmConfigurationLibrary
,
EmitGemmInstance
,
EpilogueFunctor
,
GemmOperation
,
SwizzlingFunctor
,
)
from
library
import
(
ComplexTransformTag
,
DataTypeSize
,
DataTypeTag
,
EpilogueFunctorTag
,
GemmKind
,
LayoutTag
,
LayoutType
,
MathOperationTag
,
OpcodeClassTag
,
SubstituteTemplate
,
SwizzlingFunctorTag
,
)
class
EmitGatherGemmScatterInstance
(
EmitGemmInstance
):
def
__init__
(
self
,
operation_suffix
=
''
):
self
.
operation_suffix
=
operation_suffix
self
.
includes
=
[
"cutlass/cutlass.h"
,
"cutlass/numeric_types.h"
,
"cutlass/arch/arch.h"
,
"cutlass/arch/mma.h"
,
"cutlass/layout/matrix.h"
,
"cutlass/gemm/device/gemm.h"
,
"cutlass/gemm/device/gemm_universal_adapter.h"
,
"cutlass/gemm/kernel/default_gemm_universal.h"
,
]
self
.
builtin_epilogue_functor_template
=
"""
${epilogue_functor}<
${element_c},
${epilogue_vector_length},
${element_accumulator},
${element_epilogue}
>
"""
self
.
gemm_template
=
"""
// Gemm operator ${operation_name}
struct ${operation_name} {
using Gemm =
cutlass::gemm::device::GemmUniversal<
${element_a},
${layout_a},
${element_b},
${layout_b},
${element_c},
${layout_c},
${element_accumulator},
${opcode_class},
${arch},
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
${epilogue_functor},
${swizzling_functor},
${stages},
${align_a},
${align_b},
${math_operation},
${transform_a},
${transform_b},
true, // gather a
false, // gather b
true // scatter d
>;
};
"""
def
instance_template
(
self
):
return
""
def
emit
(
self
,
operation
):
threadblock_shape
=
operation
.
tile_description
.
threadblock_shape
warp_count
=
operation
.
tile_description
.
warp_count
warp_shape
=
[
threadblock_shape
[
idx
]
//
warp_count
[
idx
]
for
idx
in
range
(
3
)
]
transpose_layouts
=
{
LayoutType
.
ColumnMajor
:
LayoutType
.
ColumnMajor
,
LayoutType
.
RowMajor
:
LayoutType
.
RowMajor
,
}
if
(
operation
.
A
.
layout
in
transpose_layouts
.
keys
()
and
operation
.
B
.
layout
in
transpose_layouts
.
keys
()
and
operation
.
C
.
layout
in
transpose_layouts
.
keys
()
):
instance_layout_A
=
transpose_layouts
[
operation
.
A
.
layout
]
instance_layout_B
=
transpose_layouts
[
operation
.
B
.
layout
]
instance_layout_C
=
transpose_layouts
[
operation
.
C
.
layout
]
gemm_template
=
self
.
gemm_template
else
:
instance_layout_A
,
instance_layout_B
,
instance_layout_C
=
(
operation
.
A
.
layout
,
operation
.
B
.
layout
,
operation
.
C
.
layout
,
)
gemm_template
=
self
.
gemm_template_interleaved
# Support built-in epilogue functors or user-defined functions
if
isinstance
(
operation
.
epilogue_functor
,
enum
.
Enum
):
epilogue_vector_length
=
(
min
(
operation
.
C
.
alignment
*
DataTypeSize
[
operation
.
C
.
element
],
128
,
)
//
DataTypeSize
[
operation
.
C
.
element
]
)
values
=
{
'epilogue_vector_length'
:
str
(
epilogue_vector_length
),
'element_epilogue'
:
str
(
DataTypeTag
[
operation
.
element_epilogue
]
),
'epilogue_functor'
:
EpilogueFunctorTag
[
operation
.
epilogue_functor
],
}
epilogue_functor
=
SubstituteTemplate
(
self
.
builtin_epilogue_functor_template
,
values
)
else
:
epilogue_functor
=
self
.
epilogue_functor
.
emit_declaration
()
values
=
{
'operation_name'
:
operation
.
procedural_name
(),
'operation_suffix'
:
self
.
operation_suffix
,
'element_a'
:
DataTypeTag
[
operation
.
A
.
element
],
'layout_a'
:
LayoutTag
[
instance_layout_A
],
'element_b'
:
DataTypeTag
[
operation
.
B
.
element
],
'layout_b'
:
LayoutTag
[
instance_layout_B
],
'element_c'
:
DataTypeTag
[
operation
.
C
.
element
],
'layout_c'
:
LayoutTag
[
instance_layout_C
],
'element_accumulator'
:
DataTypeTag
[
operation
.
accumulator_type
()],
'opcode_class'
:
OpcodeClassTag
[
operation
.
tile_description
.
math_instruction
.
opcode_class
],
'arch'
:
"cutlass::arch::Sm%d"
%
operation
.
arch
,
'threadblock_shape_m'
:
str
(
operation
.
tile_description
.
threadblock_shape
[
0
]
),
'threadblock_shape_n'
:
str
(
operation
.
tile_description
.
threadblock_shape
[
1
]
),
'threadblock_shape_k'
:
str
(
operation
.
tile_description
.
threadblock_shape
[
2
]
),
'warp_shape_m'
:
str
(
warp_shape
[
0
]),
'warp_shape_n'
:
str
(
warp_shape
[
1
]),
'warp_shape_k'
:
str
(
warp_shape
[
2
]),
'instruction_shape_m'
:
str
(
operation
.
tile_description
.
math_instruction
.
instruction_shape
[
0
]
),
'instruction_shape_n'
:
str
(
operation
.
tile_description
.
math_instruction
.
instruction_shape
[
1
]
),
'instruction_shape_k'
:
str
(
operation
.
tile_description
.
math_instruction
.
instruction_shape
[
2
]
),
'epilogue_functor'
:
epilogue_functor
,
'swizzling_functor'
:
SwizzlingFunctorTag
[
operation
.
swizzling_functor
],
'stages'
:
str
(
operation
.
tile_description
.
stages
),
'align_a'
:
str
(
operation
.
A
.
alignment
),
'align_b'
:
str
(
operation
.
B
.
alignment
),
'transform_a'
:
ComplexTransformTag
[
operation
.
A
.
complex_transform
],
'transform_b'
:
ComplexTransformTag
[
operation
.
B
.
complex_transform
],
'math_operation'
:
MathOperationTag
[
operation
.
tile_description
.
math_instruction
.
math_operation
],
}
return
SubstituteTemplate
(
gemm_template
,
values
)
class
EmitGatherGemmScatterConfigurationLibrary
(
EmitGemmConfigurationLibrary
):
def
__init__
(
self
,
operation_path
,
configuration_name
):
self
.
configuration_name
=
configuration_name
self
.
configuration_path
=
os
.
path
.
join
(
operation_path
,
"%s.h"
%
configuration_name
).
replace
(
'
\\
'
,
'/'
)
self
.
instance_emitter
=
{
GemmKind
.
Universal
:
EmitGatherGemmScatterInstance
,
}
self
.
gemm_kind_wrappers
=
{
GemmKind
.
Universal
:
'GemmUniversalOperation'
,
}
self
.
wmma_guard_start
=
(
"#if defined(CUTLASS_ARCH_WMMA_SM${sm_number}_ENABLED)"
)
self
.
separator
=
"""
///////////////////////////////////////////////////////////////////////////////////////////////////
"""
self
.
header_template
=
"""
/*
Generated by gemm_operation.py - Do not edit.
*/
#pragma once
#ifdef PADDLE_WITH_CUTLASS
"""
self
.
namespace_template
=
"""
namespace phi {
namespace sparse {
"""
self
.
epilogue_template
=
"""
} // namespace sparse
} // namespace phi
#endif
"""
def
__exit__
(
self
,
exception_type
,
exception_value
,
traceback
):
# Write includes
for
incl
,
_
in
self
.
includes
.
items
():
include_statement
=
"#include
\"
%s
\"\n
"
%
incl
self
.
configuration_file
.
write
(
include_statement
)
self
.
configuration_file
.
write
(
self
.
separator
)
self
.
configuration_file
.
write
(
self
.
namespace_template
)
# Write instance definitions in top-level namespace
for
instance_definition
in
self
.
instance_definitions
:
self
.
configuration_file
.
write
(
instance_definition
)
for
instance_wrapper
in
self
.
instance_wrappers
:
self
.
configuration_file
.
write
(
instance_wrapper
)
self
.
configuration_file
.
write
(
self
.
epilogue_template
)
self
.
configuration_file
.
close
()
class
GatherGemmScatterOperation
(
GemmOperation
):
# cutlass transpose A and B in the library.py, so we transpose it back here
def
__init__
(
self
,
gemm_kind
,
arch
,
tile_description
,
A
,
B
,
C
,
element_epilogue
,
epilogue_functor
=
EpilogueFunctor
.
LinearCombination
,
swizzling_functor
=
SwizzlingFunctor
.
Identity8
,
):
super
().
__init__
(
gemm_kind
,
arch
,
tile_description
,
A
,
B
,
C
,
element_epilogue
,
epilogue_functor
=
EpilogueFunctor
.
LinearCombination
,
swizzling_functor
=
SwizzlingFunctor
.
Identity8
,
)
self
.
ShortLayoutTypeNames
=
{
LayoutType
.
ColumnMajor
:
't'
,
LayoutType
.
ColumnMajorInterleaved2
:
't2'
,
LayoutType
.
ColumnMajorInterleaved32
:
't32'
,
LayoutType
.
ColumnMajorInterleaved64
:
't64'
,
LayoutType
.
RowMajor
:
'n'
,
LayoutType
.
RowMajorInterleaved2
:
'n2'
,
LayoutType
.
RowMajorInterleaved32
:
'n32'
,
LayoutType
.
RowMajorInterleaved64
:
'n64'
,
LayoutType
.
TensorNHWC
:
'nhwc'
,
LayoutType
.
TensorNDHWC
:
'ndhwc'
,
LayoutType
.
TensorNCHW
:
'nchw'
,
LayoutType
.
TensorNGHWC
:
'nghwc'
,
LayoutType
.
TensorNC32HW32
:
'nc32hw32'
,
LayoutType
.
TensorNC64HW64
:
'nc64hw64'
,
LayoutType
.
TensorC32RSK32
:
'c32rsk32'
,
LayoutType
.
TensorC64RSK64
:
'c64rsk64'
,
}
def
layout_name
(
self
):
return
"%s%s"
%
(
self
.
ShortLayoutTypeNames
[
self
.
A
.
layout
],
self
.
ShortLayoutTypeNames
[
self
.
B
.
layout
],
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录