未验证 提交 12d43da9 编写于 作者: U umiswing 提交者: GitHub

Auto tune for cutlass (#50809)

上级 be9515f2
......@@ -39,12 +39,14 @@ ExternalProject_Add(
UPDATE_COMMAND ""
CONFIGURE_COMMAND ""
BUILD_COMMAND
rm -rf
${CMAKE_SOURCE_DIR}/paddle/phi/kernels/sparse/gpu/cutlass_generator/build &&
mkdir -p
${CMAKE_SOURCE_DIR}/paddle/phi/kernels/sparse/gpu/cutlass/build/generated/gemm
${CMAKE_SOURCE_DIR}/paddle/phi/kernels/sparse/gpu/cutlass_generator/build/generated/gemm
&& ${PYTHON_EXECUTABLE} -B
${CMAKE_SOURCE_DIR}/paddle/phi/kernels/sparse/gpu/cutlass/gather_gemm_scatter_generator.py
${CMAKE_SOURCE_DIR}/paddle/phi/kernels/sparse/gpu/cutlass_generator/gather_gemm_scatter_generator.py
"${THIRD_PARTY_PATH}/cutlass/src/extern_cutlass/tools/library/scripts/"
"${CMAKE_SOURCE_DIR}/paddle/phi/kernels/sparse/gpu/cutlass/build"
"${CMAKE_SOURCE_DIR}/paddle/phi/kernels/sparse/gpu/cutlass_generator/build"
"${CMAKE_CUDA_COMPILER_VERSION}"
INSTALL_COMMAND ""
TEST_COMMAND "")
......
......@@ -177,6 +177,85 @@ class MatmulAutoTuner
}
};
// Auto-tuner for the gather-gemm-scatter kernels.
//
// Every registered kernel is invoked as kernel(alpha, beta, args...). The
// first time a cache key is seen, all registered kernels are timed and the
// index of the fastest one is stored in the AutoTuneCache map selected by T;
// later calls with the same key reuse the stored index.
template <typename T, typename ReturnType, typename... Args>
class GatherGemmScatterAutoTuner
    : public AutoTuneBase<T, KernelCallback<T, ReturnType, T, T, Args...>> {
 public:
  // Returns the lazily created tuner for this <T, ReturnType, Args...>
  // instantiation. `func` is registered only by the call that creates the
  // instance; it is ignored on every later call.
  static GatherGemmScatterAutoTuner<T, ReturnType, Args...>* Instance(
      ReturnType (*func)(T, T, Args...)) {
    static std::once_flag gather_gemm_scatter_init_flag;
    static std::unique_ptr<GatherGemmScatterAutoTuner<T, ReturnType, Args...>>
        instance;
    std::call_once(gather_gemm_scatter_init_flag, [&] {
      instance.reset(new GatherGemmScatterAutoTuner<T, ReturnType, Args...>);
      instance->AddCallBack(func);
    });
    return instance.get();
  }

  // Runs the best known kernel for `key` with the caller's alpha/beta,
  // tuning first when `key` is not yet in the cache.
  void Run(const phi::GPUContext& ctx,
           const size_t key,
           T const alpha,
           T const beta,
           Args... args) {
    this->is_init_ = true;
    this->CheckKernelSize();
    auto& cache = AutoTuneCache::Instance().GetGatherGemmScatter<T>();
    if (cache.Find(key)) {
      auto best_idx = cache.Get(key);
      this->kernels_[best_idx].Run(alpha, beta, args...);
    } else {
      // Set alpha to 0 and beta to 1 to avoid changing the value of d when
      // picking the best kernel
      auto best_idx =
          PickBestKernel(ctx, static_cast<T>(0), static_cast<T>(1), args...);
      cache.Set(key, best_idx);
      this->kernels_[best_idx].Run(alpha, beta, args...);
    }
  }

 protected:
  // Times every registered kernel once with (alpha, beta, args...) and
  // returns the index of the fastest one. A kernel whose measurement throws
  // std::runtime_error is logged and skipped. When no kernel completes, an
  // error is logged and the process is terminated with exit(-1).
  size_t PickBestKernel(const phi::GPUContext& ctx,
                        const T& alpha,
                        const T& beta,
                        Args&... args) {
    std::lock_guard<std::mutex> lock(this->mutex_);
    // Sentinel meaning "no kernel has succeeded yet".
    constexpr size_t NO_KERNEL_WORKS = std::numeric_limits<size_t>::max();
    size_t best_idx = NO_KERNEL_WORKS;
    float min_time = std::numeric_limits<float>::max();
    // Signed count so the loop index compares without a sign mismatch.
    const int kernel_count = static_cast<int>(this->kernels_.size());
    // Time cost test established in default stream.
    for (int i = 0; i < kernel_count; ++i) {
      // Some kernels may require more shared memory than available, skip these
      // kernels.
      try {
        const float time =
            this->RunAndMeasureKernel(ctx, i, alpha, beta, args...);
        if (time < min_time) {
          min_time = time;
          best_idx = static_cast<size_t>(i);
        }
      } catch (const std::runtime_error& error) {
        VLOG(3) << "the kernels_[" << i << "] get error:" << error.what();
      }
    }
    if (best_idx == NO_KERNEL_WORKS) {
      LOG(ERROR) << "No kernel works!\n";
      exit(-1);
    }
    VLOG(3) << "best kernel idx is " << best_idx;
    return best_idx;
  }
};
// Convenience wrapper: deduces T, ReturnType and Args from `func` and hands
// `func` to the Instance() accessor of the matching tuner type.
template <typename T, typename ReturnType, typename... Args>
static GatherGemmScatterAutoTuner<T, ReturnType, Args...>*
MakeGatherGemmScatterTuner(ReturnType (*func)(T, T, Args...)) {
  using Tuner = GatherGemmScatterAutoTuner<T, ReturnType, Args...>;
  return Tuner::Instance(func);
}
// Define the auto_tuner inital object.
#define DEFINE_AUTOTUNER_COMMON_OBJ(name) \
template <typename T, typename ReturnType, typename... Args> \
......
......@@ -45,13 +45,15 @@ enum class AlgorithmType {
kConvBackwardFilter = 3,
kTranspose = 4,
kMatmul = 5,
kGatherGemmScatterFP16NN = 6,
kGatherGemmScatterFP32NN = 7,
#if !defined(PADDLE_WITH_CUDNN_FRONTEND)
kAlgorithmCount = 6
kAlgorithmCount = 8
#else
kConvForwardV8 = 6,
kConvBackwardDataV8 = 7,
kConvBackwardFilterV8 = 8,
kAlgorithmCount = 9
kConvForwardV8 = 8,
kConvBackwardDataV8 = 9,
kConvBackwardFilterV8 = 10,
kAlgorithmCount = 11
#endif
};
......@@ -88,6 +90,20 @@ class AutoTuneCache {
return conv_auto_tune_map_[static_cast<int64_t>(algo_type)];
}
// Overload selected when T is float: returns the cache map stored under
// AlgorithmType::kGatherGemmScatterFP32NN. std::enable_if removes this
// overload from consideration for every other T.
template <typename T>
typename std::enable_if<std::is_same<T, float>::value,
AlgorithmsCacheMap&>::type
GetGatherGemmScatter() {
return Get(AlgorithmType::kGatherGemmScatterFP32NN);
}
// Overload selected when T is phi::dtype::float16: returns the cache map
// stored under AlgorithmType::kGatherGemmScatterFP16NN. std::enable_if removes
// this overload from consideration for every other T.
template <typename T>
typename std::enable_if<std::is_same<T, phi::dtype::float16>::value,
AlgorithmsCacheMap&>::type
GetGatherGemmScatter() {
return Get(AlgorithmType::kGatherGemmScatterFP16NN);
}
#ifdef PADDLE_WITH_CUDNN_FRONTEND
CudnnFrontendPlanCache& GetConvV8(const AlgorithmType& algo_type) {
return cudnn_v8_auto_tune_map_[static_cast<int64_t>(algo_type)];
......
......@@ -125,12 +125,16 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx,
#ifdef PADDLE_WITH_CUTLASS
bool cutlass = true;
if (dev_ctx.GetComputeCapability() < 75) cutlass = false;
if (in_channels % 4 != 0 || out_channels % 4 != 0) {
if (dev_ctx.GetComputeCapability() < 80) cutlass = false;
if (in_channels % 8 != 0 || out_channels % 8 != 0) {
if (std::is_same<T, phi::dtype::float16>::value) cutlass = false;
}
if (in_channels % 4 != 0 || out_channels % 4 != 0) {
if (std::is_same<T, float>::value) cutlass = false;
}
if (std::is_same<T, double>::value) cutlass = false;
if (!std::is_same<IntT, int32_t>::value) cutlass = false;
if (cutlass) {
auto* out_values = out->mutable_non_zero_elements();
T* out_values_ptr = out_values->data<T>();
......@@ -150,18 +154,18 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx,
const IntT* gather_indices = rulebook_ptr + h_offsets_ptr[i];
const IntT* scatter_indices =
rulebook_ptr + rulebook_len + h_offsets_ptr[i];
dispatchKernel(dev_ctx,
x.non_zero_elements().data<T>(),
tmp_kernel_ptr,
out_values_ptr,
out_values_ptr,
M,
N,
K,
gather_indices,
scatter_indices,
cutlass,
x.dtype());
GatherGemmScatterDriver(dev_ctx,
x.non_zero_elements().data<T>(),
tmp_kernel_ptr,
out_values_ptr,
out_values_ptr,
M,
N,
K,
gather_indices,
scatter_indices,
static_cast<T>(1.0),
static_cast<T>(1.0));
}
} else {
#endif
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#ifdef PADDLE_WITH_CUTLASS
#include <cstdint>
#include <stdexcept>

#include "cutlass/arch/mma.h"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/device/gemm_universal.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/half.h"
#include "cutlass/util/device_memory.h"
#include "examples/common/helper.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
namespace phi {
namespace sparse {
// Declares `kernel` as the function-pointer type of a gather-gemm-scatter
// launcher for element type `dtype`. The parameter list matches the
// launchKernel functions generated by DEFINE_LAUNCH_KERNEL: scaling factors
// first, then the GPU context, the a/b/c/d pointers, the m/n/k sizes and the
// two index arrays. The expansion already ends with ';'.
#define TYPEDEF_KERNEL_POINTER(kernel, dtype)              \
  using kernel = void (*)(dtype const alpha,               \
                          dtype const beta,                \
                          const GPUContext& dev_ctx,       \
                          const dtype* const a,            \
                          const dtype* const b,            \
                          const dtype* const c,            \
                          dtype* const d,                  \
                          const int m,                     \
                          const int n,                     \
                          const int k,                     \
                          const int32_t* a_indices,        \
                          const int32_t* c_d_indices);
// Evaluates `status` exactly once and throws std::runtime_error carrying the
// text from cutlassGetStatusString() unless it equals
// cutlass::Status::kSuccess.
//
// The body is wrapped in do { } while (0) so the macro expands to a single
// statement that needs a trailing ';' and is safe inside an unbraced if/else.
// The temporary has a macro-specific name so that an argument expression
// named `error` is not shadowed by it.
#define GATHER_GEMM_SCATTER_CHECK(status)                                \
  do {                                                                   \
    const cutlass::Status gather_gemm_scatter_status_ = (status);        \
    if (gather_gemm_scatter_status_ != cutlass::Status::kSuccess) {      \
      throw std::runtime_error(                                          \
          cutlassGetStatusString(gather_gemm_scatter_status_));          \
    }                                                                    \
  } while (0)
// Defines the function template launchKernel<Gemm> for element type `dtype`,
// whose CUTLASS representation is `cutlass_type`.
//
// The generated function builds Gemm::Arguments for an m x n x k problem with
// row-major capacities and leading dimensions k, n, n, n, passes `a_indices`
// as the gather index for A and `c_d_indices` as the scatter index (the middle
// index argument is nullptr), and launches the GEMM on dev_ctx.stream().
// alpha and beta are converted through float to `cutlass_type`.
//
// Any non-success cutlass::Status from can_implement(), initialize() or the
// launch itself is reported through GATHER_GEMM_SCATTER_CHECK, i.e. as a
// std::runtime_error. The launch status is checked as well so that a kernel
// that fails to launch is not mistaken for one that finished quickly.
#define DEFINE_LAUNCH_KERNEL(dtype, cutlass_type)                            \
  template <typename Gemm>                                                   \
  void launchKernel(dtype const alpha,                                       \
                    dtype const beta,                                        \
                    const GPUContext& dev_ctx,                               \
                    const dtype* const a,                                    \
                    const dtype* const b,                                    \
                    const dtype* const c,                                    \
                    dtype* const d,                                          \
                    const int m,                                             \
                    const int n,                                             \
                    const int k,                                             \
                    const int32_t* a_indices,                                \
                    const int32_t* c_d_indices) {                            \
    cutlass::gemm::GemmCoord problem_size_real({m, n, k});                   \
    int split_k_slices = 1;                                                  \
    typename Gemm::Arguments arguments{                                      \
        cutlass::gemm::GemmUniversalMode::kGemm,                             \
        problem_size_real,                                                   \
        split_k_slices,                                                      \
        {static_cast<cutlass_type>(static_cast<float>(alpha)),               \
         static_cast<cutlass_type>(static_cast<float>(beta))},               \
        reinterpret_cast<const cutlass_type* const>(a),                      \
        reinterpret_cast<const cutlass_type* const>(b),                      \
        reinterpret_cast<const cutlass_type* const>(c),                      \
        reinterpret_cast<cutlass_type* const>(d),                            \
        cutlass::layout::RowMajor().capacity(problem_size_real.mk()),        \
        cutlass::layout::RowMajor().capacity(problem_size_real.kn()),        \
        cutlass::layout::RowMajor().capacity(problem_size_real.mn()),        \
        cutlass::layout::RowMajor().capacity(problem_size_real.mn()),        \
        problem_size_real.k(),                                               \
        problem_size_real.n(),                                               \
        problem_size_real.n(),                                               \
        problem_size_real.n(),                                               \
        a_indices,                                                           \
        nullptr,                                                             \
        c_d_indices};                                                        \
    Gemm gemm_op;                                                            \
    /* Reject unsupported problems before allocating any workspace. */       \
    cutlass::Status status = gemm_op.can_implement(arguments);               \
    GATHER_GEMM_SCATTER_CHECK(status);                                       \
    size_t workspace_size = Gemm::get_workspace_size(arguments);             \
    cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);   \
    status = gemm_op.initialize(arguments, workspace.get());                 \
    GATHER_GEMM_SCATTER_CHECK(status);                                       \
    /* Do not drop the launch status: a failed launch must throw too. */     \
    status = gemm_op(dev_ctx.stream());                                      \
    GATHER_GEMM_SCATTER_CHECK(status);                                       \
  }
// Function-pointer types for the fp16 and fp32 launchers. The fp16 type uses
// phi::dtype::float16 in its signature.
TYPEDEF_KERNEL_POINTER(fp16_gather_gemm_scatter, phi::dtype::float16)
TYPEDEF_KERNEL_POINTER(fp32_gather_gemm_scatter, float)
// launchKernel<Gemm> overloads: the fp16 one takes phi::dtype::float16 and
// reinterprets the buffers as cutlass::half_t; the fp32 one uses float for
// both.
DEFINE_LAUNCH_KERNEL(phi::dtype::float16, cutlass::half_t)
DEFINE_LAUNCH_KERNEL(float, float)
} // namespace sparse
} // namespace phi
#endif
......@@ -41,7 +41,6 @@ def CreateGatherGemmScatterOperator(
layouts,
tile_descriptions,
data_type,
alignment_constraints,
complex_transforms=None,
epilogue_functor=EpilogueFunctor.LinearCombination,
swizzling_functor=SwizzlingFunctor.Identity8,
......@@ -55,12 +54,15 @@ def CreateGatherGemmScatterOperator(
element_a, element_b, element_c, element_epilogue = data_type
operations = []
alignment_constraints = [0]
if 'f16' == element_a.name or 'bf16' == element_a.name:
alignment_constraints = [8]
elif 'f32' == element_a.name or 'tf32' == element_a.name:
alignment_constraints = [4]
elif 'f64' == element_a.name:
alignment_constraints = [1]
# by default, only generate the largest tile and largest alignment
# if manifest.kernel_filter == '':
# tile_descriptions = [tile_descriptions[0],]
# alignment_constraints = [alignment_constraints[0],]
operations = []
for layout in layouts:
for tile_description in tile_descriptions:
......@@ -95,9 +97,9 @@ def CreateGatherGemmScatterOperator(
return operations
def GenerateSM70_TensorOp_884(manifest, cuda_version):
def GenerateSM80_TensorOp_16816(manifest, cuda_version):
if not CudaToolkitVersionSatisfies(cuda_version, 10, 1):
if not CudaToolkitVersionSatisfies(cuda_version, 11, 0):
return
layouts = [
......@@ -106,15 +108,7 @@ def GenerateSM70_TensorOp_884(manifest, cuda_version):
math_instructions = [
MathInstruction(
[8, 8, 4],
DataType.f16,
DataType.f16,
DataType.f32,
OpcodeClass.TensorOp,
MathOperation.multiply_add,
),
MathInstruction(
[8, 8, 4],
[16, 8, 16],
DataType.f16,
DataType.f16,
DataType.f16,
......@@ -123,36 +117,78 @@ def GenerateSM70_TensorOp_884(manifest, cuda_version):
),
]
min_cc = 70
max_cc = 75
min_cc = 80
max_cc = 1024
alignment_constraints = [8, 4, 2, 1]
alignment_constraints = [8]
for math_inst in math_instructions:
tile_descriptions = [
TileDescription(
[256, 128, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc
[256, 128, 32], 3, [4, 2, 1], math_inst, min_cc, max_cc
),
TileDescription(
[128, 256, 32], 3, [2, 4, 1], math_inst, min_cc, max_cc
),
TileDescription(
[256, 64, 32], 3, [4, 1, 1], math_inst, min_cc, max_cc
),
TileDescription(
[256, 64, 32], 4, [4, 1, 1], math_inst, min_cc, max_cc
),
TileDescription(
[64, 256, 32], 4, [1, 4, 1], math_inst, min_cc, max_cc
),
TileDescription(
[128, 128, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc
),
TileDescription(
[128, 128, 32], 4, [2, 2, 1], math_inst, min_cc, max_cc
),
TileDescription(
[128, 128, 32], 5, [2, 2, 1], math_inst, min_cc, max_cc
),
TileDescription(
[128, 64, 32], 6, [2, 2, 1], math_inst, min_cc, max_cc
),
TileDescription(
[64, 128, 32], 6, [2, 2, 1], math_inst, min_cc, max_cc
),
TileDescription(
[64, 64, 32], 10, [2, 2, 1], math_inst, min_cc, max_cc
),
TileDescription(
[128, 256, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc
[256, 128, 64], 3, [4, 2, 1], math_inst, min_cc, max_cc
),
TileDescription(
[128, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc
[128, 256, 64], 3, [2, 4, 1], math_inst, min_cc, max_cc
),
TileDescription(
[256, 64, 32], 2, [4, 1, 1], math_inst, min_cc, max_cc
[256, 64, 64], 4, [4, 1, 1], math_inst, min_cc, max_cc
),
TileDescription(
[64, 256, 32], 2, [1, 4, 1], math_inst, min_cc, max_cc
[64, 256, 64], 4, [1, 4, 1], math_inst, min_cc, max_cc
),
TileDescription(
[64, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc
[128, 128, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc
),
TileDescription(
[128, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc
[256, 64, 64], 3, [4, 1, 1], math_inst, min_cc, max_cc
),
TileDescription(
[64, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc
[64, 256, 64], 3, [1, 4, 1], math_inst, min_cc, max_cc
),
TileDescription(
[128, 128, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc
),
TileDescription(
[128, 64, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc
),
TileDescription(
[64, 128, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc
),
TileDescription(
[64, 64, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc
),
]
......@@ -164,11 +200,7 @@ def GenerateSM70_TensorOp_884(manifest, cuda_version):
]
CreateGatherGemmScatterOperator(
manifest,
layouts,
tile_descriptions,
data_type,
alignment_constraints,
manifest, layouts, tile_descriptions, data_type
)
# Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation)
......@@ -182,16 +214,286 @@ def GenerateSM70_TensorOp_884(manifest, cuda_version):
]
CreateGatherGemmScatterOperator(
manifest,
layouts,
tile_descriptions,
data_type_mixed,
alignment_constraints,
manifest, layouts, tile_descriptions, data_type_mixed
)
# Entry point for the SM70 kernel set: forwards both arguments to the single
# SM70 tensor-op generator, GenerateSM70_TensorOp_884.
def GenerateSM70(manifest, cuda_version):
GenerateSM70_TensorOp_884(manifest, cuda_version)
def GenerateSM80_TensorOp_1688(manifest, cuda_version):
    """Create the TF32 [16, 8, 8] tensor-op gather-gemm-scatter operations.

    Does nothing unless CudaToolkitVersionSatisfies(cuda_version, 11, 0)
    holds. Otherwise builds one row-major layout and a fixed list of tile
    descriptions (min_cc=80, max_cc=1024) and calls
    CreateGatherGemmScatterOperator twice with the same tiles: once with the
    accumulator type as element C, once with element A's type as element C.
    """
    if not CudaToolkitVersionSatisfies(cuda_version, 11, 0):
        return

    row_major = LayoutType.RowMajor
    layouts = [(row_major, row_major, row_major)]

    math_instructions = [
        MathInstruction(
            [16, 8, 8],
            DataType.tf32,
            DataType.tf32,
            DataType.f32,
            OpcodeClass.TensorOp,
            MathOperation.multiply_add,
        )
    ]

    min_cc = 80
    max_cc = 1024

    # (threadblock shape, stages, warp count) for every tile, in emit order.
    tile_specs = (
        ((256, 128, 16), 3, (4, 2, 1)),
        ((128, 256, 16), 3, (2, 4, 1)),
        ((256, 64, 16), 4, (4, 1, 1)),
        ((64, 256, 16), 4, (1, 4, 1)),
        ((128, 128, 16), 5, (2, 2, 1)),
        ((128, 128, 16), 4, (2, 2, 1)),
        ((128, 128, 16), 3, (2, 2, 1)),
        ((128, 64, 16), 6, (2, 2, 1)),
        ((64, 128, 16), 6, (2, 2, 1)),
        ((64, 64, 16), 10, (2, 2, 1)),
        ((256, 128, 32), 3, (4, 2, 1)),
        ((128, 256, 32), 3, (2, 4, 1)),
        ((256, 64, 32), 4, (4, 1, 1)),
        ((64, 256, 32), 4, (1, 4, 1)),
        ((128, 128, 32), 4, (2, 2, 1)),
        ((128, 128, 32), 3, (2, 2, 1)),
        ((128, 64, 32), 3, (2, 2, 1)),
        ((64, 128, 32), 3, (2, 2, 1)),
        ((64, 64, 32), 5, (2, 2, 1)),
    )

    for math_inst in math_instructions:
        # Fresh lists per description, as in a literal-by-literal listing.
        tile_descriptions = [
            TileDescription(
                list(shape), stages, list(warps), math_inst, min_cc, max_cc
            )
            for shape, stages, warps in tile_specs
        ]

        accumulator = math_inst.element_accumulator
        data_type = [
            math_inst.element_a,
            math_inst.element_b,
            accumulator,
            accumulator,
        ]
        data_type_mixed = [
            math_inst.element_a,
            math_inst.element_b,
            math_inst.element_a,
            accumulator,
        ]

        for element_types in (data_type, data_type_mixed):
            CreateGatherGemmScatterOperator(
                manifest, layouts, tile_descriptions, element_types
            )
def GenerateSM80_TensorOp_1688_fast_math(manifest, cuda_version):
    """Create [16, 8, 8] tensor-op operations exposed with f32 element types.

    Does nothing unless CudaToolkitVersionSatisfies(cuda_version, 11, 0)
    holds. The math instruction is declared with tf32 inputs and an f32
    accumulator, while the data types handed to
    CreateGatherGemmScatterOperator are f32 for A, B, C and the epilogue.
    """
    if not CudaToolkitVersionSatisfies(cuda_version, 11, 0):
        return

    row_major = LayoutType.RowMajor
    layouts = [(row_major, row_major, row_major)]

    math_instructions = [
        MathInstruction(
            [16, 8, 8],
            DataType.tf32,
            DataType.tf32,
            DataType.f32,
            OpcodeClass.TensorOp,
            MathOperation.multiply_add,
        ),
    ]

    min_cc = 80
    max_cc = 1024

    # (threadblock shape, stages, warp count) for every tile, in emit order.
    tile_specs = (
        ((256, 128, 16), 3, (4, 2, 1)),
        ((128, 256, 16), 3, (2, 4, 1)),
        ((256, 64, 16), 4, (4, 1, 1)),
        ((64, 256, 16), 4, (1, 4, 1)),
        ((128, 128, 16), 5, (2, 2, 1)),
        ((128, 128, 16), 4, (2, 2, 1)),
        ((128, 128, 16), 3, (2, 2, 1)),
        ((128, 64, 16), 6, (2, 2, 1)),
        ((64, 128, 16), 6, (2, 2, 1)),
        ((64, 64, 16), 10, (2, 2, 1)),
        ((256, 128, 32), 3, (4, 2, 1)),
        ((128, 256, 32), 3, (2, 4, 1)),
        ((256, 64, 32), 4, (4, 1, 1)),
        ((64, 256, 32), 4, (1, 4, 1)),
        ((128, 128, 32), 4, (2, 2, 1)),
        ((128, 128, 32), 3, (2, 2, 1)),
        ((128, 64, 32), 3, (2, 2, 1)),
        ((64, 128, 32), 3, (2, 2, 1)),
        ((64, 64, 32), 5, (2, 2, 1)),
    )

    for math_inst in math_instructions:
        tile_descriptions = [
            TileDescription(
                list(shape), stages, list(warps), math_inst, min_cc, max_cc
            )
            for shape, stages, warps in tile_specs
        ]

        data_type = [DataType.f32] * 4

        CreateGatherGemmScatterOperator(
            manifest, layouts, tile_descriptions, data_type
        )
def GenerateSM80_TensorOp_1688_fast_fp32_math(manifest, cuda_version):
    """Create the f32 [16, 8, 8] multiply_add_fast_f32 tensor-op operations.

    Does nothing unless CudaToolkitVersionSatisfies(cuda_version, 11, 0)
    holds. Uses one row-major layout, f32 for every element type, and a
    fixed tile list with min_cc=80 and max_cc=1024.
    """
    if not CudaToolkitVersionSatisfies(cuda_version, 11, 0):
        return

    row_major = LayoutType.RowMajor
    layouts = [(row_major, row_major, row_major)]

    math_instructions = [
        MathInstruction(
            [16, 8, 8],
            DataType.f32,
            DataType.f32,
            DataType.f32,
            OpcodeClass.TensorOp,
            MathOperation.multiply_add_fast_f32,
        ),
    ]

    min_cc = 80
    max_cc = 1024

    # (threadblock shape, stages, warp count) for every tile, in emit order.
    tile_specs = (
        ((128, 128, 16), 4, (4, 2, 1)),
        ((128, 128, 16), 3, (4, 2, 1)),
        ((256, 64, 16), 3, (4, 2, 1)),
        ((64, 256, 16), 3, (2, 4, 1)),
        ((128, 64, 16), 4, (2, 2, 1)),
        ((64, 128, 16), 4, (2, 2, 1)),
        ((64, 64, 16), 3, (2, 2, 1)),
        ((128, 128, 32), 3, (4, 2, 1)),
        ((256, 64, 32), 3, (4, 2, 1)),
        ((64, 256, 32), 3, (2, 4, 1)),
        ((128, 64, 32), 3, (2, 2, 1)),
        ((64, 128, 32), 3, (2, 2, 1)),
        ((64, 64, 32), 3, (2, 2, 1)),
    )

    for math_inst in math_instructions:
        tile_descriptions = [
            TileDescription(
                list(shape), stages, list(warps), math_inst, min_cc, max_cc
            )
            for shape, stages, warps in tile_specs
        ]

        data_type = [DataType.f32] * 4

        CreateGatherGemmScatterOperator(
            manifest, layouts, tile_descriptions, data_type
        )
def GenerateSM80(manifest, cuda_version):
    """Run each SM80 generator on `manifest`, one after another."""
    sm80_generators = (
        GenerateSM80_TensorOp_16816,
        GenerateSM80_TensorOp_1688,
        GenerateSM80_TensorOp_1688_fast_math,
        GenerateSM80_TensorOp_1688_fast_fp32_math,
    )
    for generate in sm80_generators:
        generate(manifest, cuda_version)
class KernelCfg:
......@@ -229,7 +531,7 @@ class KernelCfg:
if __name__ == "__main__":
args = KernelCfg(
architectures='70',
architectures='80',
build_dir=sys.argv[2],
cuda_version=sys.argv[3],
curr_build_dir=sys.argv[2],
......@@ -245,6 +547,6 @@ if __name__ == "__main__":
)
manifest = GatherGemmScatterManifest(args)
GenerateSM70(manifest, args.cuda_version)
GenerateSM80(manifest, args.cuda_version)
manifest.emit(GeneratorTarget.Library)
......@@ -18,7 +18,7 @@ import shutil
from gather_gemm_scatter_operation import (
EmitGatherGemmScatterConfigurationLibrary,
)
from library import OperationKind, OperationKindNames
from library import OperationKind, OperationKindNames, SubstituteTemplate
from manifest import EmitOperationKindLibrary, GeneratorTarget, Manifest
......@@ -28,11 +28,25 @@ class GatherGemmScatterEmitOperationKindLibrary(EmitOperationKindLibrary):
self.emitters = {
OperationKind.Gemm: EmitGatherGemmScatterConfigurationLibrary
}
self.header_template = "#pragma once\n#ifdef PADDLE_WITH_CUTLASS\n"
self.header_template = "#pragma once\n#ifdef PADDLE_WITH_CUTLASS\n#include \"paddle/phi/kernels/sparse/gpu/cutlass_generator/common.h\"\n"
self.entry_template = ""
self.configuration_prototype_template = ""
self.configuration_template = ""
self.epilogue_template = "#endif"
self.namespace_template = """
namespace phi {
namespace sparse {
"""
self.epilogue_template = """
} // namespace sparse
} // namespace phi
#endif
"""
self.fp16_kernels_list = (
"static std::vector<fp16_gather_gemm_scatter> fp16_kernels = {\n"
)
self.fp32_kernels_list = (
"static std::vector<fp32_gather_gemm_scatter> fp32_kernels = {\n"
)
def __enter__(self):
self.operation_path = os.path.join(
......@@ -64,6 +78,21 @@ class GatherGemmScatterEmitOperationKindLibrary(EmitOperationKindLibrary):
self.source_files.append(configuration_emitter.configuration_path)
self.configurations.append(configuration_name)
if 'h' == operations[0].short_math_name():
self.fp16_kernels_list += (
"""
launchKernel<"""
+ configuration_name
+ "::Gemm>,"
)
if 's' == operations[0].short_math_name():
self.fp32_kernels_list += (
"""
launchKernel<"""
+ configuration_name
+ "::Gemm>,"
)
self.top_level_file.write(
'#include "'
+ self.operation_path
......@@ -72,6 +101,30 @@ class GatherGemmScatterEmitOperationKindLibrary(EmitOperationKindLibrary):
+ '.h"\n'
)
def __exit__(self, exception_type, exception_value, traceback):
    """Finish and close the top-level generated header.

    Writes the substituted entry template and one substituted configuration
    template per recorded configuration, terminates both kernel-list
    initialisers with a closing brace, then writes the namespace opener, the
    fp16 list, the fp32 list and the epilogue before closing the file.
    """
    out = self.top_level_file

    out.write(
        SubstituteTemplate(
            self.entry_template,
            {'operation_name': OperationKindNames[self.kind]},
        )
    )

    for configuration_name in self.configurations:
        out.write(
            SubstituteTemplate(
                self.configuration_template,
                {'configuration_name': configuration_name},
            )
        )

    # Close the two C++ vector initialisers accumulated while emitting.
    self.fp16_kernels_list += "\n};\n"
    self.fp32_kernels_list += "\n};\n"

    trailer_sections = (
        self.namespace_template,
        self.fp16_kernels_list,
        self.fp32_kernels_list,
        self.epilogue_template,
    )
    for section in trailer_sections:
        out.write(section)

    out.close()
class GatherGemmScatterManifest(Manifest):
def emit(self, target=GeneratorTarget.Library):
......
......@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import enum
import os.path
......@@ -40,16 +41,7 @@ from library import (
class EmitGatherGemmScatterInstance(EmitGemmInstance):
def __init__(self, operation_suffix=''):
self.operation_suffix = operation_suffix
self.includes = [
"cutlass/cutlass.h",
"cutlass/numeric_types.h",
"cutlass/arch/arch.h",
"cutlass/arch/mma.h",
"cutlass/layout/matrix.h",
"cutlass/gemm/device/gemm.h",
"cutlass/gemm/device/gemm_universal_adapter.h",
"cutlass/gemm/kernel/default_gemm_universal.h",
]
self.includes = []
self.builtin_epilogue_functor_template = """
${epilogue_functor}<
${element_c},
......@@ -247,6 +239,18 @@ namespace sparse {
#endif
"""
def __enter__(self):
    """Open the configuration file and reset the per-file emit state.

    Opens self.configuration_path for writing, writes the header template
    followed by the separator, and starts with an empty ordered include map
    and empty instance-definition, instance-wrapper and operation lists.
    Returns self so the object can be used in a `with` statement.
    """
    self.configuration_file = open(self.configuration_path, "w")
    for preamble in (self.header_template, self.separator):
        self.configuration_file.write(preamble)

    self.includes = collections.OrderedDict()
    self.instance_definitions = []
    self.instance_wrappers = []
    self.operations = []
    return self
def __exit__(self, exception_type, exception_value, traceback):
# Write includes
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifdef PADDLE_WITH_CUTLASS
#include "paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.h"
namespace phi {
namespace sparse {
// Picks the fp16 gather-gemm-scatter launcher for an M x N x K problem from a
// fixed table keyed on K, N and (for some shapes) thresholds on M. Shapes not
// listed fall back to the 64x64_32x2 align4 kernel.
fp16_gather_gemm_scatter getBestFp16Kernel(const int M,
                                           const int N,
                                           const int K) {
  if (K == 4 && N == 16) {
    return launchKernel<cutlass::half_t,
                        cutlass_tensorop_h1688gemm_64x64_32x2_nn_align4::Gemm>;
  }
  if (K == 16 && N == 16) {
    return launchKernel<cutlass::half_t,
                        cutlass_tensorop_h1688gemm_64x64_32x2_nn_align8::Gemm>;
  }
  if (K == 16 && N == 32) {
    return launchKernel<cutlass::half_t,
                        cutlass_tensorop_h1688gemm_64x64_32x2_nn_align8::Gemm>;
  }
  if (K == 32 && N == 32) {
    return launchKernel<cutlass::half_t,
                        cutlass_tensorop_h1688gemm_64x64_32x2_nn_align8::Gemm>;
  }
  if (K == 32 && N == 64) {
    return launchKernel<cutlass::half_t,
                        cutlass_tensorop_h1688gemm_64x64_32x2_nn_align8::Gemm>;
  }
  if (K == 64 && N == 64) {
    // Thresholds are tested from largest to smallest. The first two branches
    // must return: without `return` the expression is a no-op and the
    // selection silently falls through to the smaller-M kernels.
    if (M > 100000)
      return launchKernel<
          cutlass::half_t,
          cutlass_tensorop_f16_s1688gemm_f16_64x128_32x2_nn_align8::Gemm>;
    if (M > 20000)
      return launchKernel<
          cutlass::half_t,
          cutlass_tensorop_f16_s1688gemm_f16_64x64_32x2_nn_align8::Gemm>;
    if (M > 15000)
      return launchKernel<
          cutlass::half_t,
          cutlass_tensorop_h1688gemm_128x64_32x2_nn_align8::Gemm>;
    return launchKernel<cutlass::half_t,
                        cutlass_tensorop_h1688gemm_64x64_32x2_nn_align8::Gemm>;
  }
  if (K == 128) {
    if (M >= 5000)
      return launchKernel<
          cutlass::half_t,
          cutlass_tensorop_h1688gemm_64x64_32x2_nn_align8::Gemm>;
    return launchKernel<cutlass::half_t,
                        cutlass_tensorop_h16816gemm_64x64_64x5_nn_align8::Gemm>;
  }
  if (N == 128) {
    return launchKernel<cutlass::half_t,
                        cutlass_tensorop_h1688gemm_64x64_32x2_nn_align8::Gemm>;
  }
  return launchKernel<cutlass::half_t,
                      cutlass_tensorop_h1688gemm_64x64_32x2_nn_align4::Gemm>;
}
// Picks the fp32 gather-gemm-scatter launcher for an M x N x K problem.
// SM == 75 always gets one fixed kernel; otherwise the choice is a fixed
// table keyed on K, N and thresholds on M, with the s1688f16gemm 64x64_16x10
// kernel as the fallback.
fp32_gather_gemm_scatter getBestFp32Kernel(const int M,
                                           const int N,
                                           const int K,
                                           const int SM) {
  // The two launchers shared by most small and medium shapes.
  const fp32_gather_gemm_scatter f16_64x64_16x10 =
      launchKernel<float,
                   cutlass_tensorop_s1688f16gemm_64x64_16x10_nn_align4::Gemm>;
  const fp32_gather_gemm_scatter s_64x64_16x3 =
      launchKernel<float,
                   cutlass_tensorop_s1688gemm_64x64_16x3_nn_align4::Gemm>;

  if (SM == 75) {
    return launchKernel<
        float,
        cutlass_tensorop_s1688gemm_f16_64x64_32x2_nn_align4::Gemm>;
  }

  // K in {4, 16} with N == 16: no dependence on M.
  if ((K == 4 || K == 16) && N == 16) {
    return f16_64x64_16x10;
  }

  // (16, 32), (32, 32) and (32, 64) as (K, N): switch kernels at M = 10000.
  const bool switches_at_10000 =
      (K == 16 && N == 32) || (K == 32 && (N == 32 || N == 64));
  if (switches_at_10000) {
    return M >= 10000 ? s_64x64_16x3 : f16_64x64_16x10;
  }

  if (K == 64 && N == 64) {
    return M >= 15000 ? s_64x64_16x3 : f16_64x64_16x10;
  }

  if (K == 128) {
    if (M >= 100000) {
      return launchKernel<
          float,
          cutlass_tensorop_s1688f16gemm_128x128_16x3_nn_align4::Gemm>;
    }
    if (M >= 5000) {
      return launchKernel<
          float,
          cutlass_tensorop_s1688f16gemm_256x64_16x4_nn_align4::Gemm>;
    }
    return launchKernel<
        float,
        cutlass_tensorop_s1688tf32gemm_256x128_16x3_nn_align4::Gemm>;
  }

  if (N == 128) {
    if (M >= 100000) {
      return launchKernel<
          float,
          cutlass_tensorop_s1688tf32gemm_256x128_16x3_nn_align4::Gemm>;
    }
    if (M >= 5000) {
      return launchKernel<
          float,
          cutlass_tensorop_s1688f16gemm_128x128_16x3_nn_align4::Gemm>;
    }
    return launchKernel<
        float,
        cutlass_tensorop_s1688f16gemm_64x128_16x6_nn_align4::Gemm>;
  }

  return f16_64x64_16x10;
}
// Picks the fp64 gather-gemm-scatter launcher for an M x N x K problem. Only
// two kernels are ever returned: the 16x32 tile for three specific shapes and
// the 32x16 tile for everything else.
fp64_gather_gemm_scatter getBestFp64Kernel(const int M,
                                           const int N,
                                           const int K) {
  // Shapes that map to the 16x32 tile; (16, 16) does so only below M = 10000.
  const bool wants_16x32 = (K == 4 && N == 16) ||
                           (K == 16 && N == 16 && M < 10000) ||
                           (K == 32 && N == 32);
  if (wants_16x32) {
    return launchKernel<double,
                        cutlass_tensorop_d884gemm_16x32_16x5_nn_align1::Gemm>;
  }
  return launchKernel<double,
                      cutlass_tensorop_d884gemm_32x16_16x5_nn_align1::Gemm>;
}
} // namespace sparse
} // namespace phi
#endif
......@@ -13,628 +13,75 @@
// limitations under the License.
#pragma once
#include <type_traits>
#ifdef PADDLE_WITH_CUTLASS
#include "cutlass/arch/mma.h"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/device/gemm_grouped.h"
#include "cutlass/gemm/device/gemm_universal.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/util/device_memory.h"
#include "examples/common/helper.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/kernels/autotune/auto_tune_base.h"
#include "paddle/phi/kernels/sparse/gpu/cutlass_generator/build/generated/gemm/all_gemm_operations.h"
namespace phi {
namespace sparse {
typedef void (*fp16_gather_gemm_scatter)(const GPUContext& dev_ctx,
const cutlass::half_t* const a,
const cutlass::half_t* const b,
const cutlass::half_t* const c,
cutlass::half_t* const d,
const int m,
const int n,
const int k,
const int32_t* a_indices,
const int32_t* c_d_indices,
cutlass::half_t const alpha,
cutlass::half_t const beta);
// Pointer to an fp32 gather-gemm-scatter launcher: context first, then the
// a/b/c/d buffers, the m/n/k sizes, the two index arrays, and alpha/beta last.
using fp32_gather_gemm_scatter = void (*)(const GPUContext& dev_ctx,
                                          const float* const a,
                                          const float* const b,
                                          const float* const c,
                                          float* const d,
                                          const int m,
                                          const int n,
                                          const int k,
                                          const int32_t* a_indices,
                                          const int32_t* c_d_indices,
                                          float const alpha,
                                          float const beta);
// Pointer to an fp64 gather-gemm-scatter launcher: context first, then the
// a/b/c/d buffers, the m/n/k sizes, the two index arrays, and alpha/beta last.
using fp64_gather_gemm_scatter = void (*)(const GPUContext& dev_ctx,
                                          const double* const a,
                                          const double* const b,
                                          const double* const c,
                                          double* const d,
                                          const int m,
                                          const int n,
                                          const int k,
                                          const int32_t* a_indices,
                                          const int32_t* c_d_indices,
                                          double const alpha,
                                          double const beta);
// Kernel selectors. Each returns the launcher to use for an M x N x K problem.
// The parameter order is (M, N, K), matching the definitions and the call
// sites, which pass (m, n, k); the fp32 selector additionally takes the
// device's compute capability as SM.
fp16_gather_gemm_scatter getBestFp16Kernel(const int M,
                                           const int N,
                                           const int K);
fp32_gather_gemm_scatter getBestFp32Kernel(const int M,
                                           const int N,
                                           const int K,
                                           const int SM);
fp64_gather_gemm_scatter getBestFp64Kernel(const int M,
                                           const int N,
                                           const int K);
// Launches the CUTLASS operator `Gemm` on dev_ctx.stream() for an m x n x k
// problem with element type T.
//
// Arguments are built with row-major capacities and leading dimensions
// k, n, n, n; `a_indices` is passed as the gather index for A and
// `c_d_indices` as the scatter index (the middle index argument is nullptr).
// The status of can_implement(), initialize() and of the launch itself is
// passed to CUTLASS_CHECK, so a failed launch is no longer silently ignored.
template <typename T, typename Gemm>
void launchKernel(const GPUContext& dev_ctx,
                  const T* const a,
                  const T* const b,
                  const T* const c,
                  T* const d,
                  const int m,
                  const int n,
                  const int k,
                  const int32_t* a_indices,
                  const int32_t* c_d_indices,
                  T const alpha,
                  T const beta) {
  cutlass::gemm::GemmCoord problem_size_real({m, n, k});
  int split_k_slices = 1;
  typename Gemm::Arguments arguments{
      cutlass::gemm::GemmUniversalMode::kGemm,
      problem_size_real,
      split_k_slices,
      {alpha, beta},
      a,
      b,
      c,
      d,
      cutlass::layout::RowMajor().capacity(problem_size_real.mk()),
      cutlass::layout::RowMajor().capacity(problem_size_real.kn()),
      cutlass::layout::RowMajor().capacity(problem_size_real.mn()),
      cutlass::layout::RowMajor().capacity(problem_size_real.mn()),
      problem_size_real.k(),
      problem_size_real.n(),
      problem_size_real.n(),
      problem_size_real.n(),
      a_indices,
      nullptr,
      c_d_indices};
  size_t workspace_size = Gemm::get_workspace_size(arguments);
  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
  Gemm gemm_op;
  cutlass::Status status = gemm_op.can_implement(arguments);
  CUTLASS_CHECK(status);
  status = gemm_op.initialize(arguments, workspace.get());
  CUTLASS_CHECK(status);
  // The launch returns a status as well; check it like the two calls above.
  status = gemm_op(dev_ctx.stream());
  CUTLASS_CHECK(status);
}
// Type-erased entry point: casts the void* operands to the element type named
// by `type`, asks the matching getBest*Kernel selector for a launcher and
// calls it with alpha = 1 and beta = 1.
//
//   cutlass   when false the function returns immediately without launching.
//   type      FLOAT16, FLOAT32 and FLOAT64 are the only values handled by the
//             branches visible here.
//
// NOTE(review): the selectors are called with (m, n, k) while their
// declarations name the parameters (M, K, N) — confirm the intended order.
static void dispatchKernel(const GPUContext& dev_ctx,
const void* const a,
const void* const b,
const void* const c,
void* const d,
const int m,
const int n,
const int k,
const void* a_indices,
const void* c_d_indices,
const bool cutlass,
const phi::DataType type) {
// Nothing to do unless the CUTLASS path was requested.
if (!cutlass) return;
if (type == phi::DataType::FLOAT16) {
fp16_gather_gemm_scatter gather_gemm_scatter = getBestFp16Kernel(m, n, k);
gather_gemm_scatter(dev_ctx,
static_cast<const cutlass::half_t*>(a),
static_cast<const cutlass::half_t*>(b),
static_cast<const cutlass::half_t*>(c),
static_cast<cutlass::half_t*>(d),
m,
n,
k,
static_cast<const int32_t*>(a_indices),
static_cast<const int32_t*>(c_d_indices),
static_cast<cutlass::half_t>(1),
static_cast<cutlass::half_t>(1));
} else if (type == phi::DataType::FLOAT32) {
// The fp32 selector additionally receives the device compute capability.
fp32_gather_gemm_scatter gather_gemm_scatter =
getBestFp32Kernel(m, n, k, dev_ctx.GetComputeCapability());
gather_gemm_scatter(dev_ctx,
static_cast<const float*>(a),
static_cast<const float*>(b),
static_cast<const float*>(c),
static_cast<float*>(d),
m,
n,
k,
static_cast<const int32_t*>(a_indices),
static_cast<const int32_t*>(c_d_indices),
static_cast<float>(1),
static_cast<float>(1));
} else if (type == phi::DataType::FLOAT64) {
fp64_gather_gemm_scatter gather_gemm_scatter = getBestFp64Kernel(m, n, k);
gather_gemm_scatter(dev_ctx,
static_cast<const double*>(a),
static_cast<const double*>(b),
static_cast<const double*>(c),
static_cast<double*>(d),
m,
n,
k,
static_cast<const int32_t*>(a_indices),
static_cast<const int32_t*>(c_d_indices),
static_cast<double>(1),
static_cast<double>(1));
// To reduce tuning time, map shape (m,n,k) to (m/features_num_range,n,k) so
// that shapes in this range share the same key.
constexpr int features_num_range = 10000;

// Defines the GatherGemmScatterDriver overload that is enabled when T is
// `dtype` and IntT is int32_t.
//
//   dtype    element type the overload is enabled for.
//   kernels  indexable collection (operator[] and size()) of candidate
//            launchers; element 0 creates/fetches the tuner and the remaining
//            elements are handed to AddCallBack.
//
// The generated function derives the cache key from
// (m / features_num_range, n, k) and forwards alpha/beta plus the launcher
// arguments (ctx, a, b, c, d, m, n, k, a_indices, c_d_indices) to
// tuner->Run. The loop index is size_t so it is compared with
// kernels.size() without a signed/unsigned mismatch.
#define DEFINE_GATHER_GEMM_SCATTER_DRIVER(dtype, kernels)                 \
  template <typename T, typename IntT>                                    \
  typename std::enable_if<std::is_same<T, dtype>::value &&                \
                              std::is_same<IntT, int32_t>::value,         \
                          void>::type                                     \
  GatherGemmScatterDriver(const phi::GPUContext& ctx,                     \
                          const T* const a,                               \
                          const T* const b,                               \
                          const T* const c,                               \
                          T* const d,                                     \
                          const int& m,                                   \
                          const int& n,                                   \
                          const int& k,                                   \
                          const IntT* a_indices,                          \
                          const IntT* c_d_indices,                        \
                          T alpha,                                        \
                          T beta) {                                       \
    auto* tuner = autotune::MakeGatherGemmScatterTuner(kernels[0]);       \
    for (size_t i = 1; i < kernels.size(); ++i) {                         \
      tuner->AddCallBack(kernels[i]);                                     \
    }                                                                     \
    size_t key = autotune::GenKey(m / features_num_range, n, k);          \
    tuner->Run(ctx,                                                       \
               key,                                                       \
               alpha,                                                     \
               beta,                                                      \
               ctx,                                                       \
               a,                                                         \
               b,                                                         \
               c,                                                         \
               d,                                                         \
               m,                                                         \
               n,                                                         \
               k,                                                         \
               a_indices,                                                 \
               c_d_indices);                                              \
  }
}
// GemmUniversal configuration: fp16 A/B/C with fp16 accumulation, Sm75
// tensor-op path, 128x64x32 / 64x32x32 / 16x8x8 tile shapes.
// Alias names and inline notes follow CUTLASS's positional GemmUniversal
// parameter order — verify against the pinned CUTLASS version.
struct cutlass_tensorop_h1688gemm_128x64_32x2_nn_align8 {
 private:
  using Element = cutlass::half_t;
  using Accum = cutlass::half_t;
  using Layout = cutlass::layout::RowMajor;
  using BlockTile = cutlass::gemm::GemmShape<128, 64, 32>;
  using WarpTile = cutlass::gemm::GemmShape<64, 32, 32>;
  using MmaTile = cutlass::gemm::GemmShape<16, 8, 8>;
  using Epilogue =
      cutlass::epilogue::thread::LinearCombination<Element, 8, Accum, Accum>;
  using Swizzle =
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>;

 public:
  using Gemm = cutlass::gemm::device::GemmUniversal<
      Element, Layout,  // A
      Element, Layout,  // B
      Element, Layout,  // C/D
      Accum,
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm75,
      BlockTile, WarpTile, MmaTile,
      Epilogue, Swizzle,
      2,     // stages
      8, 8,  // alignment of A, B
      cutlass::arch::OpMultiplyAdd,
      cutlass::ComplexTransform::kNone,
      cutlass::ComplexTransform::kNone,
      true, false, true>;  // gather A, gather B, scatter D
};
// GemmUniversal configuration: fp16 A/B/C with fp16 accumulation, Sm75
// tensor-op path, 64x128x32 / 32x64x32 / 16x8x8 tile shapes.
// Alias names and inline notes follow CUTLASS's positional GemmUniversal
// parameter order — verify against the pinned CUTLASS version.
struct cutlass_tensorop_h1688gemm_64x128_32x2_nn_align8 {
 private:
  using Element = cutlass::half_t;
  using Accum = cutlass::half_t;
  using Layout = cutlass::layout::RowMajor;
  using BlockTile = cutlass::gemm::GemmShape<64, 128, 32>;
  using WarpTile = cutlass::gemm::GemmShape<32, 64, 32>;
  using MmaTile = cutlass::gemm::GemmShape<16, 8, 8>;
  using Epilogue =
      cutlass::epilogue::thread::LinearCombination<Element, 8, Accum, Accum>;
  using Swizzle =
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>;

 public:
  using Gemm = cutlass::gemm::device::GemmUniversal<
      Element, Layout,  // A
      Element, Layout,  // B
      Element, Layout,  // C/D
      Accum,
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm75,
      BlockTile, WarpTile, MmaTile,
      Epilogue, Swizzle,
      2,     // stages
      8, 8,  // alignment of A, B
      cutlass::arch::OpMultiplyAdd,
      cutlass::ComplexTransform::kNone,
      cutlass::ComplexTransform::kNone,
      true, false, true>;  // gather A, gather B, scatter D
};
// GemmUniversal configuration: fp16 A/B/C with fp16 accumulation, Sm75
// tensor-op path, 128x64x32 / 64x32x32 / 16x8x8 tile shapes, alignment 4.
// Alias names and inline notes follow CUTLASS's positional GemmUniversal
// parameter order — verify against the pinned CUTLASS version.
struct cutlass_tensorop_h1688gemm_128x64_32x2_nn_align4 {
 private:
  using Element = cutlass::half_t;
  using Accum = cutlass::half_t;
  using Layout = cutlass::layout::RowMajor;
  using BlockTile = cutlass::gemm::GemmShape<128, 64, 32>;
  using WarpTile = cutlass::gemm::GemmShape<64, 32, 32>;
  using MmaTile = cutlass::gemm::GemmShape<16, 8, 8>;
  using Epilogue =
      cutlass::epilogue::thread::LinearCombination<Element, 4, Accum, Accum>;
  using Swizzle =
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>;

 public:
  using Gemm = cutlass::gemm::device::GemmUniversal<
      Element, Layout,  // A
      Element, Layout,  // B
      Element, Layout,  // C/D
      Accum,
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm75,
      BlockTile, WarpTile, MmaTile,
      Epilogue, Swizzle,
      2,     // stages
      4, 4,  // alignment of A, B
      cutlass::arch::OpMultiplyAdd,
      cutlass::ComplexTransform::kNone,
      cutlass::ComplexTransform::kNone,
      true, false, true>;  // gather A, gather B, scatter D
};
// GemmUniversal configuration: fp16 A/B/C with fp16 accumulation, Sm75
// tensor-op path, 64x64x32 / 32x32x32 / 16x8x8 tile shapes, alignment 4.
// Alias names and inline notes follow CUTLASS's positional GemmUniversal
// parameter order — verify against the pinned CUTLASS version.
struct cutlass_tensorop_h1688gemm_64x64_32x2_nn_align4 {
 private:
  using Element = cutlass::half_t;
  using Accum = cutlass::half_t;
  using Layout = cutlass::layout::RowMajor;
  using BlockTile = cutlass::gemm::GemmShape<64, 64, 32>;
  using WarpTile = cutlass::gemm::GemmShape<32, 32, 32>;
  using MmaTile = cutlass::gemm::GemmShape<16, 8, 8>;
  using Epilogue =
      cutlass::epilogue::thread::LinearCombination<Element, 4, Accum, Accum>;
  using Swizzle =
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>;

 public:
  using Gemm = cutlass::gemm::device::GemmUniversal<
      Element, Layout,  // A
      Element, Layout,  // B
      Element, Layout,  // C/D
      Accum,
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm75,
      BlockTile, WarpTile, MmaTile,
      Epilogue, Swizzle,
      2,     // stages
      4, 4,  // alignment of A, B
      cutlass::arch::OpMultiplyAdd,
      cutlass::ComplexTransform::kNone,
      cutlass::ComplexTransform::kNone,
      true, false, true>;  // gather A, gather B, scatter D
};
// GemmUniversal configuration: fp16 A/B/C with fp16 accumulation, Sm75
// tensor-op path, 64x64x32 / 32x32x32 / 16x8x8 tile shapes, alignment 8.
// Alias names and inline notes follow CUTLASS's positional GemmUniversal
// parameter order — verify against the pinned CUTLASS version.
struct cutlass_tensorop_h1688gemm_64x64_32x2_nn_align8 {
 private:
  using Element = cutlass::half_t;
  using Accum = cutlass::half_t;
  using Layout = cutlass::layout::RowMajor;
  using BlockTile = cutlass::gemm::GemmShape<64, 64, 32>;
  using WarpTile = cutlass::gemm::GemmShape<32, 32, 32>;
  using MmaTile = cutlass::gemm::GemmShape<16, 8, 8>;
  using Epilogue =
      cutlass::epilogue::thread::LinearCombination<Element, 8, Accum, Accum>;
  using Swizzle =
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>;

 public:
  using Gemm = cutlass::gemm::device::GemmUniversal<
      Element, Layout,  // A
      Element, Layout,  // B
      Element, Layout,  // C/D
      Accum,
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm75,
      BlockTile, WarpTile, MmaTile,
      Epilogue, Swizzle,
      2,     // stages
      8, 8,  // alignment of A, B
      cutlass::arch::OpMultiplyAdd,
      cutlass::ComplexTransform::kNone,
      cutlass::ComplexTransform::kNone,
      true, false, true>;  // gather A, gather B, scatter D
};
// GemmUniversal configuration: fp16 A/B/C with fp16 accumulation, Sm80
// tensor-op path, 64x64x64 / 32x32x64 / 16x8x16 tile shapes, 5 stages.
// Alias names and inline notes follow CUTLASS's positional GemmUniversal
// parameter order — verify against the pinned CUTLASS version.
struct cutlass_tensorop_h16816gemm_64x64_64x5_nn_align8 {
 private:
  using Element = cutlass::half_t;
  using Accum = cutlass::half_t;
  using Layout = cutlass::layout::RowMajor;
  using BlockTile = cutlass::gemm::GemmShape<64, 64, 64>;
  using WarpTile = cutlass::gemm::GemmShape<32, 32, 64>;
  using MmaTile = cutlass::gemm::GemmShape<16, 8, 16>;
  using Epilogue =
      cutlass::epilogue::thread::LinearCombination<Element, 8, Accum, Accum>;
  using Swizzle =
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>;

 public:
  using Gemm = cutlass::gemm::device::GemmUniversal<
      Element, Layout,  // A
      Element, Layout,  // B
      Element, Layout,  // C/D
      Accum,
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm80,
      BlockTile, WarpTile, MmaTile,
      Epilogue, Swizzle,
      5,     // stages
      8, 8,  // alignment of A, B
      cutlass::arch::OpMultiplyAdd,
      cutlass::ComplexTransform::kNone,
      cutlass::ComplexTransform::kNone,
      true, false, true>;  // gather A, gather B, scatter D
};
// GemmUniversal configuration: fp16 A/B/C with float accumulation, Sm75
// tensor-op path, 64x128x32 / 32x64x32 / 16x8x8 tile shapes.
// Alias names and inline notes follow CUTLASS's positional GemmUniversal
// parameter order — verify against the pinned CUTLASS version.
struct cutlass_tensorop_f16_s1688gemm_f16_64x128_32x2_nn_align8 {
 private:
  using Element = cutlass::half_t;
  using Accum = float;
  using Layout = cutlass::layout::RowMajor;
  using BlockTile = cutlass::gemm::GemmShape<64, 128, 32>;
  using WarpTile = cutlass::gemm::GemmShape<32, 64, 32>;
  using MmaTile = cutlass::gemm::GemmShape<16, 8, 8>;
  using Epilogue =
      cutlass::epilogue::thread::LinearCombination<Element, 8, Accum, Accum>;
  using Swizzle =
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>;

 public:
  using Gemm = cutlass::gemm::device::GemmUniversal<
      Element, Layout,  // A
      Element, Layout,  // B
      Element, Layout,  // C/D
      Accum,
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm75,
      BlockTile, WarpTile, MmaTile,
      Epilogue, Swizzle,
      2,     // stages
      8, 8,  // alignment of A, B
      cutlass::arch::OpMultiplyAdd,
      cutlass::ComplexTransform::kNone,
      cutlass::ComplexTransform::kNone,
      true, false, true>;  // gather A, gather B, scatter D
};
// GemmUniversal configuration: fp16 A/B/C with float accumulation, Sm75
// tensor-op path, 64x64x32 / 32x32x32 / 16x8x8 tile shapes.
// Alias names and inline notes follow CUTLASS's positional GemmUniversal
// parameter order — verify against the pinned CUTLASS version.
struct cutlass_tensorop_f16_s1688gemm_f16_64x64_32x2_nn_align8 {
 private:
  using Element = cutlass::half_t;
  using Accum = float;
  using Layout = cutlass::layout::RowMajor;
  using BlockTile = cutlass::gemm::GemmShape<64, 64, 32>;
  using WarpTile = cutlass::gemm::GemmShape<32, 32, 32>;
  using MmaTile = cutlass::gemm::GemmShape<16, 8, 8>;
  using Epilogue =
      cutlass::epilogue::thread::LinearCombination<Element, 8, Accum, Accum>;
  using Swizzle =
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>;

 public:
  using Gemm = cutlass::gemm::device::GemmUniversal<
      Element, Layout,  // A
      Element, Layout,  // B
      Element, Layout,  // C/D
      Accum,
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm75,
      BlockTile, WarpTile, MmaTile,
      Epilogue, Swizzle,
      2,     // stages
      8, 8,  // alignment of A, B
      cutlass::arch::OpMultiplyAdd,
      cutlass::ComplexTransform::kNone,
      cutlass::ComplexTransform::kNone,
      true, false, true>;  // gather A, gather B, scatter D
};
// GemmUniversal configuration: float A/B/C with float accumulation, Sm80
// tensor-op path using OpMultiplyAddFastF16, 64x64x16 / 32x32x16 / 16x8x8
// tile shapes, 10 stages.
// Alias names and inline notes follow CUTLASS's positional GemmUniversal
// parameter order — verify against the pinned CUTLASS version.
struct cutlass_tensorop_s1688f16gemm_64x64_16x10_nn_align4 {
 private:
  using Element = float;
  using Accum = float;
  using Layout = cutlass::layout::RowMajor;
  using BlockTile = cutlass::gemm::GemmShape<64, 64, 16>;
  using WarpTile = cutlass::gemm::GemmShape<32, 32, 16>;
  using MmaTile = cutlass::gemm::GemmShape<16, 8, 8>;
  using Epilogue =
      cutlass::epilogue::thread::LinearCombination<Element, 4, Accum, Accum>;
  using Swizzle =
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>;

 public:
  using Gemm = cutlass::gemm::device::GemmUniversal<
      Element, Layout,  // A
      Element, Layout,  // B
      Element, Layout,  // C/D
      Accum,
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm80,
      BlockTile, WarpTile, MmaTile,
      Epilogue, Swizzle,
      10,    // stages
      4, 4,  // alignment of A, B
      cutlass::arch::OpMultiplyAddFastF16,
      cutlass::ComplexTransform::kNone,
      cutlass::ComplexTransform::kNone,
      true, false, true>;  // gather A, gather B, scatter D
};
// GemmUniversal configuration: float A/B/C with float accumulation, Sm80
// tensor-op path using OpMultiplyAddFastF16, 128x128x16 / 64x64x16 / 16x8x8
// tile shapes, 3 stages.
// Alias names and inline notes follow CUTLASS's positional GemmUniversal
// parameter order — verify against the pinned CUTLASS version.
struct cutlass_tensorop_s1688f16gemm_128x128_16x3_nn_align4 {
 private:
  using Element = float;
  using Accum = float;
  using Layout = cutlass::layout::RowMajor;
  using BlockTile = cutlass::gemm::GemmShape<128, 128, 16>;
  using WarpTile = cutlass::gemm::GemmShape<64, 64, 16>;
  using MmaTile = cutlass::gemm::GemmShape<16, 8, 8>;
  using Epilogue =
      cutlass::epilogue::thread::LinearCombination<Element, 4, Accum, Accum>;
  using Swizzle =
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>;

 public:
  using Gemm = cutlass::gemm::device::GemmUniversal<
      Element, Layout,  // A
      Element, Layout,  // B
      Element, Layout,  // C/D
      Accum,
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm80,
      BlockTile, WarpTile, MmaTile,
      Epilogue, Swizzle,
      3,     // stages
      4, 4,  // alignment of A, B
      cutlass::arch::OpMultiplyAddFastF16,
      cutlass::ComplexTransform::kNone,
      cutlass::ComplexTransform::kNone,
      true, false, true>;  // gather A, gather B, scatter D
};
// GemmUniversal configuration: float A/B/C with float accumulation, Sm80
// tensor-op path using OpMultiplyAddFastF16, 256x64x16 / 64x64x16 / 16x8x8
// tile shapes, 4 stages.
// Alias names and inline notes follow CUTLASS's positional GemmUniversal
// parameter order — verify against the pinned CUTLASS version.
struct cutlass_tensorop_s1688f16gemm_256x64_16x4_nn_align4 {
 private:
  using Element = float;
  using Accum = float;
  using Layout = cutlass::layout::RowMajor;
  using BlockTile = cutlass::gemm::GemmShape<256, 64, 16>;
  using WarpTile = cutlass::gemm::GemmShape<64, 64, 16>;
  using MmaTile = cutlass::gemm::GemmShape<16, 8, 8>;
  using Epilogue =
      cutlass::epilogue::thread::LinearCombination<Element, 4, Accum, Accum>;
  using Swizzle =
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>;

 public:
  using Gemm = cutlass::gemm::device::GemmUniversal<
      Element, Layout,  // A
      Element, Layout,  // B
      Element, Layout,  // C/D
      Accum,
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm80,
      BlockTile, WarpTile, MmaTile,
      Epilogue, Swizzle,
      4,     // stages
      4, 4,  // alignment of A, B
      cutlass::arch::OpMultiplyAddFastF16,
      cutlass::ComplexTransform::kNone,
      cutlass::ComplexTransform::kNone,
      true, false, true>;  // gather A, gather B, scatter D
};
// GemmUniversal configuration: float A/B/C with float accumulation, Sm80
// tensor-op path using OpMultiplyAdd, 256x128x16 / 64x64x16 / 16x8x8 tile
// shapes, 3 stages.
// Alias names and inline notes follow CUTLASS's positional GemmUniversal
// parameter order — verify against the pinned CUTLASS version.
struct cutlass_tensorop_s1688tf32gemm_256x128_16x3_nn_align4 {
 private:
  using Element = float;
  using Accum = float;
  using Layout = cutlass::layout::RowMajor;
  using BlockTile = cutlass::gemm::GemmShape<256, 128, 16>;
  using WarpTile = cutlass::gemm::GemmShape<64, 64, 16>;
  using MmaTile = cutlass::gemm::GemmShape<16, 8, 8>;
  using Epilogue =
      cutlass::epilogue::thread::LinearCombination<Element, 4, Accum, Accum>;
  using Swizzle =
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>;

 public:
  using Gemm = cutlass::gemm::device::GemmUniversal<
      Element, Layout,  // A
      Element, Layout,  // B
      Element, Layout,  // C/D
      Accum,
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm80,
      BlockTile, WarpTile, MmaTile,
      Epilogue, Swizzle,
      3,     // stages
      4, 4,  // alignment of A, B
      cutlass::arch::OpMultiplyAdd,
      cutlass::ComplexTransform::kNone,
      cutlass::ComplexTransform::kNone,
      true, false, true>;  // gather A, gather B, scatter D
};
// GemmUniversal configuration: float A/B/C with float accumulation, Sm80
// tensor-op path using OpMultiplyAddFastF16, 64x128x16 / 32x64x16 / 16x8x8
// tile shapes, 6 stages.
// Alias names and inline notes follow CUTLASS's positional GemmUniversal
// parameter order — verify against the pinned CUTLASS version.
struct cutlass_tensorop_s1688f16gemm_64x128_16x6_nn_align4 {
 private:
  using Element = float;
  using Accum = float;
  using Layout = cutlass::layout::RowMajor;
  using BlockTile = cutlass::gemm::GemmShape<64, 128, 16>;
  using WarpTile = cutlass::gemm::GemmShape<32, 64, 16>;
  using MmaTile = cutlass::gemm::GemmShape<16, 8, 8>;
  using Epilogue =
      cutlass::epilogue::thread::LinearCombination<Element, 4, Accum, Accum>;
  using Swizzle =
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>;

 public:
  using Gemm = cutlass::gemm::device::GemmUniversal<
      Element, Layout,  // A
      Element, Layout,  // B
      Element, Layout,  // C/D
      Accum,
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm80,
      BlockTile, WarpTile, MmaTile,
      Epilogue, Swizzle,
      6,     // stages
      4, 4,  // alignment of A, B
      cutlass::arch::OpMultiplyAddFastF16,
      cutlass::ComplexTransform::kNone,
      cutlass::ComplexTransform::kNone,
      true, false, true>;  // gather A, gather B, scatter D
};
// GemmUniversal configuration: float A/B/C with float accumulation, Sm80
// tensor-op path using OpMultiplyAddFastF32, 64x64x16 / 32x32x16 / 16x8x8
// tile shapes, 3 stages.
// Alias names and inline notes follow CUTLASS's positional GemmUniversal
// parameter order — verify against the pinned CUTLASS version.
struct cutlass_tensorop_s1688gemm_64x64_16x3_nn_align4 {
 private:
  using Element = float;
  using Accum = float;
  using Layout = cutlass::layout::RowMajor;
  using BlockTile = cutlass::gemm::GemmShape<64, 64, 16>;
  using WarpTile = cutlass::gemm::GemmShape<32, 32, 16>;
  using MmaTile = cutlass::gemm::GemmShape<16, 8, 8>;
  using Epilogue =
      cutlass::epilogue::thread::LinearCombination<Element, 4, Accum, Accum>;
  using Swizzle =
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>;

 public:
  using Gemm = cutlass::gemm::device::GemmUniversal<
      Element, Layout,  // A
      Element, Layout,  // B
      Element, Layout,  // C/D
      Accum,
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm80,
      BlockTile, WarpTile, MmaTile,
      Epilogue, Swizzle,
      3,     // stages
      4, 4,  // alignment of A, B
      cutlass::arch::OpMultiplyAddFastF32,
      cutlass::ComplexTransform::kNone,
      cutlass::ComplexTransform::kNone,
      true, false, true>;  // gather A, gather B, scatter D
};
// GemmUniversal configuration: double A/B/C with double accumulation, Sm80
// tensor-op path, 16x32x16 / 16x16x16 / 8x8x4 tile shapes, 5 stages.
// Alias names and inline notes follow CUTLASS's positional GemmUniversal
// parameter order — verify against the pinned CUTLASS version.
struct cutlass_tensorop_d884gemm_16x32_16x5_nn_align1 {
 private:
  using Element = double;
  using Accum = double;
  using Layout = cutlass::layout::RowMajor;
  using BlockTile = cutlass::gemm::GemmShape<16, 32, 16>;
  using WarpTile = cutlass::gemm::GemmShape<16, 16, 16>;
  using MmaTile = cutlass::gemm::GemmShape<8, 8, 4>;
  using Epilogue =
      cutlass::epilogue::thread::LinearCombination<Element, 1, Accum, Accum>;
  using Swizzle =
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>;

 public:
  using Gemm = cutlass::gemm::device::GemmUniversal<
      Element, Layout,  // A
      Element, Layout,  // B
      Element, Layout,  // C/D
      Accum,
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm80,
      BlockTile, WarpTile, MmaTile,
      Epilogue, Swizzle,
      5,     // stages
      1, 1,  // alignment of A, B
      cutlass::arch::OpMultiplyAdd,
      cutlass::ComplexTransform::kNone,
      cutlass::ComplexTransform::kNone,
      true, false, true>;  // gather A, gather B, scatter D
};
// GemmUniversal configuration: double A/B/C with double accumulation, Sm80
// tensor-op path, 32x16x16 / 16x16x16 / 8x8x4 tile shapes, 5 stages.
// Alias names and inline notes follow CUTLASS's positional GemmUniversal
// parameter order — verify against the pinned CUTLASS version.
struct cutlass_tensorop_d884gemm_32x16_16x5_nn_align1 {
 private:
  using Element = double;
  using Accum = double;
  using Layout = cutlass::layout::RowMajor;
  using BlockTile = cutlass::gemm::GemmShape<32, 16, 16>;
  using WarpTile = cutlass::gemm::GemmShape<16, 16, 16>;
  using MmaTile = cutlass::gemm::GemmShape<8, 8, 4>;
  using Epilogue =
      cutlass::epilogue::thread::LinearCombination<Element, 1, Accum, Accum>;
  using Swizzle =
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>;

 public:
  using Gemm = cutlass::gemm::device::GemmUniversal<
      Element, Layout,  // A
      Element, Layout,  // B
      Element, Layout,  // C/D
      Accum,
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm80,
      BlockTile, WarpTile, MmaTile,
      Epilogue, Swizzle,
      5,     // stages
      1, 1,  // alignment of A, B
      cutlass::arch::OpMultiplyAdd,
      cutlass::ComplexTransform::kNone,
      cutlass::ComplexTransform::kNone,
      true, false, true>;  // gather A, gather B, scatter D
};
template <typename T, typename IntT>
typename std::enable_if<std::is_same<T, double>::value ||
!std::is_same<IntT, int32_t>::value,
void>::type
GatherGemmScatterDriver(const phi::GPUContext& ctx,
const T* const a,
const T* const b,
const T* const c,
T* const d,
const int& m,
const int& n,
const int& k,
const IntT* a_indices,
const IntT* c_d_indices,
T alpha,
T beta) {}
// sm75
// GemmUniversal configuration: fp16 A/B, float C with float accumulation,
// Sm75 tensor-op path, 64x64x32 / 32x32x32 / 16x8x8 tile shapes, 2 stages.
// Only the first 18 template arguments are given; the rest keep CUTLASS's
// defaults. Alias names and inline notes follow CUTLASS's positional
// GemmUniversal parameter order — verify against the pinned CUTLASS version.
// NOTE(review): the struct name says align4 while the two alignment
// arguments are 8 (the epilogue vector width is 4) — confirm intent.
struct cutlass_tensorop_s1688gemm_f16_64x64_32x2_nn_align4 {
 private:
  using InputElement = cutlass::half_t;
  using OutputElement = float;
  using Accum = float;
  using Layout = cutlass::layout::RowMajor;
  using BlockTile = cutlass::gemm::GemmShape<64, 64, 32>;
  using WarpTile = cutlass::gemm::GemmShape<32, 32, 32>;
  using MmaTile = cutlass::gemm::GemmShape<16, 8, 8>;
  using Epilogue = cutlass::epilogue::thread::
      LinearCombination<OutputElement, 4, Accum, Accum>;
  using Swizzle =
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>;

 public:
  using Gemm = cutlass::gemm::device::GemmUniversal<
      InputElement, Layout,   // A
      InputElement, Layout,   // B
      OutputElement, Layout,  // C/D
      Accum,
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm75,
      BlockTile, WarpTile, MmaTile,
      Epilogue, Swizzle,
      2,     // stages
      8, 8,  // alignment of A, B
      cutlass::arch::OpMultiplyAdd>;
};
// Generate the int32-index GatherGemmScatterDriver overloads for half and
// single precision. fp16_kernels / fp32_kernels are defined outside this
// excerpt — presumably the generated kernel lists pulled in through
// all_gemm_operations.h; confirm.
DEFINE_GATHER_GEMM_SCATTER_DRIVER(phi::dtype::float16, fp16_kernels)
DEFINE_GATHER_GEMM_SCATTER_DRIVER(float, fp32_kernels)
} // namespace sparse
} // namespace phi
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册