未验证 提交 12d43da9 编写于 作者: U umiswing 提交者: GitHub

Auto tune for cutlass (#50809)

上级 be9515f2
...@@ -39,12 +39,14 @@ ExternalProject_Add( ...@@ -39,12 +39,14 @@ ExternalProject_Add(
UPDATE_COMMAND "" UPDATE_COMMAND ""
CONFIGURE_COMMAND "" CONFIGURE_COMMAND ""
BUILD_COMMAND BUILD_COMMAND
rm -rf
${CMAKE_SOURCE_DIR}/paddle/phi/kernels/sparse/gpu/cutlass_generator/build &&
mkdir -p mkdir -p
${CMAKE_SOURCE_DIR}/paddle/phi/kernels/sparse/gpu/cutlass/build/generated/gemm ${CMAKE_SOURCE_DIR}/paddle/phi/kernels/sparse/gpu/cutlass_generator/build/generated/gemm
&& ${PYTHON_EXECUTABLE} -B && ${PYTHON_EXECUTABLE} -B
${CMAKE_SOURCE_DIR}/paddle/phi/kernels/sparse/gpu/cutlass/gather_gemm_scatter_generator.py ${CMAKE_SOURCE_DIR}/paddle/phi/kernels/sparse/gpu/cutlass_generator/gather_gemm_scatter_generator.py
"${THIRD_PARTY_PATH}/cutlass/src/extern_cutlass/tools/library/scripts/" "${THIRD_PARTY_PATH}/cutlass/src/extern_cutlass/tools/library/scripts/"
"${CMAKE_SOURCE_DIR}/paddle/phi/kernels/sparse/gpu/cutlass/build" "${CMAKE_SOURCE_DIR}/paddle/phi/kernels/sparse/gpu/cutlass_generator/build"
"${CMAKE_CUDA_COMPILER_VERSION}" "${CMAKE_CUDA_COMPILER_VERSION}"
INSTALL_COMMAND "" INSTALL_COMMAND ""
TEST_COMMAND "") TEST_COMMAND "")
......
...@@ -177,6 +177,85 @@ class MatmulAutoTuner ...@@ -177,6 +177,85 @@ class MatmulAutoTuner
} }
}; };
// Auto-tuner for gather-gemm-scatter kernels.
//
// The first time a cache key is seen, every kernel stored in kernels_ is
// timed and the index of the fastest one is recorded in
// AutoTuneCache::GetGatherGemmScatter<T>(). Later calls with the same key
// dispatch straight to the cached kernel.
//
// T          : scalar type of alpha/beta; also selects the cache map.
// ReturnType : return type of the tuned kernel function pointers.
// Args       : remaining kernel arguments, forwarded after alpha and beta.
template <typename T, typename ReturnType, typename... Args>
class GatherGemmScatterAutoTuner
    : public AutoTuneBase<T, KernelCallback<T, ReturnType, T, T, Args...>> {
 public:
  // Returns the tuner instance for this template instantiation, creating it
  // on the first call and registering `func` through AddCallBack. `func` is
  // ignored on every later call (std::call_once).
  static GatherGemmScatterAutoTuner<T, ReturnType, Args...>* Instance(
      ReturnType (*func)(T, T, Args...)) {
    static std::once_flag gather_gemm_scatter_init_flag;
    static std::unique_ptr<GatherGemmScatterAutoTuner<T, ReturnType, Args...>>
        instance;
    std::call_once(gather_gemm_scatter_init_flag, [&] {
      instance.reset(new GatherGemmScatterAutoTuner<T, ReturnType, Args...>);
      instance->AddCallBack(func);
    });
    return instance.get();
  }

  // Runs the best known kernel for `key` with the given alpha/beta/args,
  // tuning first if `key` has no cached result yet.
  void Run(const phi::GPUContext& ctx,
           const size_t key,
           T const alpha,
           T const beta,
           Args... args) {
    this->is_init_ = true;
    this->CheckKernelSize();
    auto& cache = AutoTuneCache::Instance().GetGatherGemmScatter<T>();
    if (cache.Find(key)) {
      auto best_idx = cache.Get(key);
      this->kernels_[best_idx].Run(alpha, beta, args...);
    } else {
      // Set alpha to 0 and beta to 1 to avoid changing the value of d when
      // picking the best kernel
      auto best_idx =
          PickBestKernel(ctx, static_cast<T>(0), static_cast<T>(1), args...);
      cache.Set(key, best_idx);
      this->kernels_[best_idx].Run(alpha, beta, args...);
    }
  }

 protected:
  // Times every kernel in kernels_ and returns the index of the fastest one.
  // Kernels whose measurement throws std::runtime_error are logged and
  // skipped. mutex_ is held for the whole search.
  //
  // NOTE(review): when no kernel succeeds the process is terminated with
  // exit(-1); callers cannot recover from that. Kept as-is to preserve the
  // existing contract.
  size_t PickBestKernel(const phi::GPUContext& ctx,
                        const T& alpha,
                        const T& beta,
                        Args&... args) {
    std::lock_guard<std::mutex> lock(this->mutex_);
    constexpr size_t NO_KERNEL_WORKS = std::numeric_limits<size_t>::max();
    size_t best_idx = NO_KERNEL_WORKS;
    float min_time = std::numeric_limits<float>::max();

    // Time cost test established in default stream.
    const int kernel_count = static_cast<int>(this->kernels_.size());
    for (int i = 0; i < kernel_count; ++i) {
      // Some kernels may require more shared memory than available, skip these
      // kernels.
      try {
        const float elapsed =
            this->RunAndMeasureKernel(ctx, i, alpha, beta, args...);
        if (elapsed < min_time) {
          min_time = elapsed;
          best_idx = static_cast<size_t>(i);
        }
      } catch (const std::runtime_error& error) {
        VLOG(3) << "the kernels_[" << i << "] get error:" << error.what();
      }
    }
    if (best_idx == NO_KERNEL_WORKS) {
      LOG(ERROR) << "No kernel works!\n";
      exit(-1);
    }
    VLOG(3) << "best kernel idx is " << best_idx;
    return best_idx;
  }
};
// Convenience wrapper that deduces the tuner's template arguments from the
// kernel function pointer and forwards to
// GatherGemmScatterAutoTuner<...>::Instance(func).
template <typename T, typename ReturnType, typename... Args>
static GatherGemmScatterAutoTuner<T, ReturnType, Args...>*
MakeGatherGemmScatterTuner(ReturnType (*func)(T, T, Args...)) {
  using Tuner = GatherGemmScatterAutoTuner<T, ReturnType, Args...>;
  return Tuner::Instance(func);
}
// Define the auto_tuner initial object. // Define the auto_tuner initial object.
#define DEFINE_AUTOTUNER_COMMON_OBJ(name) \ #define DEFINE_AUTOTUNER_COMMON_OBJ(name) \
template <typename T, typename ReturnType, typename... Args> \ template <typename T, typename ReturnType, typename... Args> \
......
...@@ -45,13 +45,15 @@ enum class AlgorithmType { ...@@ -45,13 +45,15 @@ enum class AlgorithmType {
kConvBackwardFilter = 3, kConvBackwardFilter = 3,
kTranspose = 4, kTranspose = 4,
kMatmul = 5, kMatmul = 5,
kGatherGemmScatterFP16NN = 6,
kGatherGemmScatterFP32NN = 7,
#if !defined(PADDLE_WITH_CUDNN_FRONTEND) #if !defined(PADDLE_WITH_CUDNN_FRONTEND)
kAlgorithmCount = 6 kAlgorithmCount = 8
#else #else
kConvForwardV8 = 6, kConvForwardV8 = 8,
kConvBackwardDataV8 = 7, kConvBackwardDataV8 = 9,
kConvBackwardFilterV8 = 8, kConvBackwardFilterV8 = 10,
kAlgorithmCount = 9 kAlgorithmCount = 11
#endif #endif
}; };
...@@ -88,6 +90,20 @@ class AutoTuneCache { ...@@ -88,6 +90,20 @@ class AutoTuneCache {
return conv_auto_tune_map_[static_cast<int64_t>(algo_type)]; return conv_auto_tune_map_[static_cast<int64_t>(algo_type)];
} }
template <typename T>
typename std::enable_if<std::is_same<T, float>::value,
AlgorithmsCacheMap&>::type
GetGatherGemmScatter() {
return Get(AlgorithmType::kGatherGemmScatterFP32NN);
}
template <typename T>
typename std::enable_if<std::is_same<T, phi::dtype::float16>::value,
AlgorithmsCacheMap&>::type
GetGatherGemmScatter() {
return Get(AlgorithmType::kGatherGemmScatterFP16NN);
}
#ifdef PADDLE_WITH_CUDNN_FRONTEND #ifdef PADDLE_WITH_CUDNN_FRONTEND
CudnnFrontendPlanCache& GetConvV8(const AlgorithmType& algo_type) { CudnnFrontendPlanCache& GetConvV8(const AlgorithmType& algo_type) {
return cudnn_v8_auto_tune_map_[static_cast<int64_t>(algo_type)]; return cudnn_v8_auto_tune_map_[static_cast<int64_t>(algo_type)];
......
...@@ -125,12 +125,16 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx, ...@@ -125,12 +125,16 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx,
#ifdef PADDLE_WITH_CUTLASS #ifdef PADDLE_WITH_CUTLASS
bool cutlass = true; bool cutlass = true;
if (dev_ctx.GetComputeCapability() < 75) cutlass = false; if (dev_ctx.GetComputeCapability() < 80) cutlass = false;
if (in_channels % 4 != 0 || out_channels % 4 != 0) { if (in_channels % 8 != 0 || out_channels % 8 != 0) {
if (std::is_same<T, phi::dtype::float16>::value) cutlass = false; if (std::is_same<T, phi::dtype::float16>::value) cutlass = false;
}
if (in_channels % 4 != 0 || out_channels % 4 != 0) {
if (std::is_same<T, float>::value) cutlass = false; if (std::is_same<T, float>::value) cutlass = false;
} }
if (std::is_same<T, double>::value) cutlass = false;
if (!std::is_same<IntT, int32_t>::value) cutlass = false; if (!std::is_same<IntT, int32_t>::value) cutlass = false;
if (cutlass) { if (cutlass) {
auto* out_values = out->mutable_non_zero_elements(); auto* out_values = out->mutable_non_zero_elements();
T* out_values_ptr = out_values->data<T>(); T* out_values_ptr = out_values->data<T>();
...@@ -150,18 +154,18 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx, ...@@ -150,18 +154,18 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx,
const IntT* gather_indices = rulebook_ptr + h_offsets_ptr[i]; const IntT* gather_indices = rulebook_ptr + h_offsets_ptr[i];
const IntT* scatter_indices = const IntT* scatter_indices =
rulebook_ptr + rulebook_len + h_offsets_ptr[i]; rulebook_ptr + rulebook_len + h_offsets_ptr[i];
dispatchKernel(dev_ctx, GatherGemmScatterDriver(dev_ctx,
x.non_zero_elements().data<T>(), x.non_zero_elements().data<T>(),
tmp_kernel_ptr, tmp_kernel_ptr,
out_values_ptr, out_values_ptr,
out_values_ptr, out_values_ptr,
M, M,
N, N,
K, K,
gather_indices, gather_indices,
scatter_indices, scatter_indices,
cutlass, static_cast<T>(1.0),
x.dtype()); static_cast<T>(1.0));
} }
} else { } else {
#endif #endif
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#ifdef PADDLE_WITH_CUTLASS
#include "cutlass/arch/mma.h"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/device/gemm_universal.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/half.h"
#include "cutlass/util/device_memory.h"
#include "examples/common/helper.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
namespace phi {
namespace sparse {
// Declares `kernel` as the function-pointer type shared by every generated
// gather-gemm-scatter launcher for element type `dtype`: it takes alpha and
// beta, the GPU context, the a/b/c input pointers and the d output pointer,
// the m/n/k sizes, and the two int32 index arrays, and returns void.
#define TYPEDEF_KERNEL_POINTER(kernel, dtype)         \
  using kernel = void (*)(dtype const alpha,          \
                          dtype const beta,           \
                          const GPUContext& dev_ctx,  \
                          const dtype* const a,       \
                          const dtype* const b,       \
                          const dtype* const c,       \
                          dtype* const d,             \
                          const int m,                \
                          const int n,                \
                          const int k,                \
                          const int32_t* a_indices,   \
                          const int32_t* c_d_indices);
// Evaluates `status` once and throws std::runtime_error carrying
// cutlassGetStatusString(status) unless it equals cutlass::Status::kSuccess.
//
// Wrapped in do { ... } while (0) so that `GATHER_GEMM_SCATTER_CHECK(x);`
// behaves as a single statement (safe in an unbraced if/else). The local uses
// a long, macro-specific name so it cannot shadow a caller variable that is
// passed in as `status`.
#define GATHER_GEMM_SCATTER_CHECK(status)                                  \
  do {                                                                     \
    const cutlass::Status gather_gemm_scatter_check_status_ = (status);    \
    if (gather_gemm_scatter_check_status_ != cutlass::Status::kSuccess) {  \
      throw std::runtime_error(                                            \
          cutlassGetStatusString(gather_gemm_scatter_check_status_));      \
    }                                                                      \
  } while (0)
// Defines `template <typename Gemm> void launchKernel(...)` for element type
// `dtype`, whose buffers are reinterpreted as `cutlass_type` for CUTLASS.
//
// The generated function builds Gemm::Arguments for an (m, n, k) problem in
// GemmUniversalMode::kGemm with a single split-k slice, with alpha/beta
// converted through float to `cutlass_type`. It allocates the workspace
// reported by Gemm::get_workspace_size, then calls can_implement, initialize
// and finally launches on dev_ctx.stream().
//
// The trailing arguments (a_indices, nullptr, c_d_indices) are presumably the
// gather indices for A, no gather for B, and the scatter indices for D --
// confirm against the generated Gemm::Arguments.
//
// Any non-success cutlass::Status from can_implement, initialize or the
// launch itself is turned into std::runtime_error by
// GATHER_GEMM_SCATTER_CHECK. The launch status used to be dropped, which hid
// launch failures from callers that rely on the exception.
//
// NOTE(review): `workspace` is released when the function returns while the
// kernel may still be running on the stream; harmless only if the workspace
// size is zero for this configuration -- confirm.
#define DEFINE_LAUNCH_KERNEL(dtype, cutlass_type)                          \
  template <typename Gemm>                                                 \
  void launchKernel(dtype const alpha,                                     \
                    dtype const beta,                                      \
                    const GPUContext& dev_ctx,                             \
                    const dtype* const a,                                  \
                    const dtype* const b,                                  \
                    const dtype* const c,                                  \
                    dtype* const d,                                        \
                    const int m,                                           \
                    const int n,                                           \
                    const int k,                                           \
                    const int32_t* a_indices,                              \
                    const int32_t* c_d_indices) {                          \
    cutlass::gemm::GemmCoord problem_size_real({m, n, k});                 \
    int split_k_slices = 1;                                                \
    typename Gemm::Arguments arguments{                                    \
        cutlass::gemm::GemmUniversalMode::kGemm,                           \
        problem_size_real,                                                 \
        split_k_slices,                                                    \
        {static_cast<const cutlass_type>(static_cast<const float>(alpha)), \
         static_cast<const cutlass_type>(static_cast<const float>(beta))}, \
        reinterpret_cast<const cutlass_type* const>(a),                    \
        reinterpret_cast<const cutlass_type* const>(b),                    \
        reinterpret_cast<const cutlass_type* const>(c),                    \
        reinterpret_cast<cutlass_type* const>(d),                          \
        cutlass::layout::RowMajor().capacity(problem_size_real.mk()),      \
        cutlass::layout::RowMajor().capacity(problem_size_real.kn()),      \
        cutlass::layout::RowMajor().capacity(problem_size_real.mn()),      \
        cutlass::layout::RowMajor().capacity(problem_size_real.mn()),      \
        problem_size_real.k(),                                             \
        problem_size_real.n(),                                             \
        problem_size_real.n(),                                             \
        problem_size_real.n(),                                             \
        a_indices,                                                         \
        nullptr,                                                           \
        c_d_indices};                                                      \
    size_t workspace_size = Gemm::get_workspace_size(arguments);           \
    cutlass::device_memory::allocation<uint8_t> workspace(workspace_size); \
    Gemm gemm_op;                                                          \
    cutlass::Status status = gemm_op.can_implement(arguments);             \
    GATHER_GEMM_SCATTER_CHECK(status);                                     \
    status = gemm_op.initialize(arguments, workspace.get());               \
    GATHER_GEMM_SCATTER_CHECK(status);                                     \
    status = gemm_op(dev_ctx.stream());                                    \
    GATHER_GEMM_SCATTER_CHECK(status);                                     \
  }
// Function-pointer types for the float16 and float kernel launchers.
TYPEDEF_KERNEL_POINTER(fp16_gather_gemm_scatter, phi::dtype::float16)
TYPEDEF_KERNEL_POINTER(fp32_gather_gemm_scatter, float)
// launchKernel<Gemm> overloads: phi::dtype::float16 buffers are reinterpreted
// as cutlass::half_t, float buffers are passed through as float.
DEFINE_LAUNCH_KERNEL(phi::dtype::float16, cutlass::half_t)
DEFINE_LAUNCH_KERNEL(float, float)
} // namespace sparse
} // namespace phi
#endif
...@@ -41,7 +41,6 @@ def CreateGatherGemmScatterOperator( ...@@ -41,7 +41,6 @@ def CreateGatherGemmScatterOperator(
layouts, layouts,
tile_descriptions, tile_descriptions,
data_type, data_type,
alignment_constraints,
complex_transforms=None, complex_transforms=None,
epilogue_functor=EpilogueFunctor.LinearCombination, epilogue_functor=EpilogueFunctor.LinearCombination,
swizzling_functor=SwizzlingFunctor.Identity8, swizzling_functor=SwizzlingFunctor.Identity8,
...@@ -55,12 +54,15 @@ def CreateGatherGemmScatterOperator( ...@@ -55,12 +54,15 @@ def CreateGatherGemmScatterOperator(
element_a, element_b, element_c, element_epilogue = data_type element_a, element_b, element_c, element_epilogue = data_type
operations = [] alignment_constraints = [0]
if 'f16' == element_a.name or 'bf16' == element_a.name:
alignment_constraints = [8]
elif 'f32' == element_a.name or 'tf32' == element_a.name:
alignment_constraints = [4]
elif 'f64' == element_a.name:
alignment_constraints = [1]
# by default, only generate the largest tile and largest alignment operations = []
# if manifest.kernel_filter == '':
# tile_descriptions = [tile_descriptions[0],]
# alignment_constraints = [alignment_constraints[0],]
for layout in layouts: for layout in layouts:
for tile_description in tile_descriptions: for tile_description in tile_descriptions:
...@@ -95,9 +97,9 @@ def CreateGatherGemmScatterOperator( ...@@ -95,9 +97,9 @@ def CreateGatherGemmScatterOperator(
return operations return operations
def GenerateSM70_TensorOp_884(manifest, cuda_version): def GenerateSM80_TensorOp_16816(manifest, cuda_version):
if not CudaToolkitVersionSatisfies(cuda_version, 10, 1): if not CudaToolkitVersionSatisfies(cuda_version, 11, 0):
return return
layouts = [ layouts = [
...@@ -106,15 +108,7 @@ def GenerateSM70_TensorOp_884(manifest, cuda_version): ...@@ -106,15 +108,7 @@ def GenerateSM70_TensorOp_884(manifest, cuda_version):
math_instructions = [ math_instructions = [
MathInstruction( MathInstruction(
[8, 8, 4], [16, 8, 16],
DataType.f16,
DataType.f16,
DataType.f32,
OpcodeClass.TensorOp,
MathOperation.multiply_add,
),
MathInstruction(
[8, 8, 4],
DataType.f16, DataType.f16,
DataType.f16, DataType.f16,
DataType.f16, DataType.f16,
...@@ -123,36 +117,78 @@ def GenerateSM70_TensorOp_884(manifest, cuda_version): ...@@ -123,36 +117,78 @@ def GenerateSM70_TensorOp_884(manifest, cuda_version):
), ),
] ]
min_cc = 70 min_cc = 80
max_cc = 75 max_cc = 1024
alignment_constraints = [8, 4, 2, 1] alignment_constraints = [8]
for math_inst in math_instructions: for math_inst in math_instructions:
tile_descriptions = [ tile_descriptions = [
TileDescription( TileDescription(
[256, 128, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc [256, 128, 32], 3, [4, 2, 1], math_inst, min_cc, max_cc
),
TileDescription(
[128, 256, 32], 3, [2, 4, 1], math_inst, min_cc, max_cc
),
TileDescription(
[256, 64, 32], 3, [4, 1, 1], math_inst, min_cc, max_cc
),
TileDescription(
[256, 64, 32], 4, [4, 1, 1], math_inst, min_cc, max_cc
),
TileDescription(
[64, 256, 32], 4, [1, 4, 1], math_inst, min_cc, max_cc
),
TileDescription(
[128, 128, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc
),
TileDescription(
[128, 128, 32], 4, [2, 2, 1], math_inst, min_cc, max_cc
),
TileDescription(
[128, 128, 32], 5, [2, 2, 1], math_inst, min_cc, max_cc
),
TileDescription(
[128, 64, 32], 6, [2, 2, 1], math_inst, min_cc, max_cc
),
TileDescription(
[64, 128, 32], 6, [2, 2, 1], math_inst, min_cc, max_cc
),
TileDescription(
[64, 64, 32], 10, [2, 2, 1], math_inst, min_cc, max_cc
), ),
TileDescription( TileDescription(
[128, 256, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc [256, 128, 64], 3, [4, 2, 1], math_inst, min_cc, max_cc
), ),
TileDescription( TileDescription(
[128, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc [128, 256, 64], 3, [2, 4, 1], math_inst, min_cc, max_cc
), ),
TileDescription( TileDescription(
[256, 64, 32], 2, [4, 1, 1], math_inst, min_cc, max_cc [256, 64, 64], 4, [4, 1, 1], math_inst, min_cc, max_cc
), ),
TileDescription( TileDescription(
[64, 256, 32], 2, [1, 4, 1], math_inst, min_cc, max_cc [64, 256, 64], 4, [1, 4, 1], math_inst, min_cc, max_cc
), ),
TileDescription( TileDescription(
[64, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc [128, 128, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc
), ),
TileDescription( TileDescription(
[128, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc [256, 64, 64], 3, [4, 1, 1], math_inst, min_cc, max_cc
), ),
TileDescription( TileDescription(
[64, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc [64, 256, 64], 3, [1, 4, 1], math_inst, min_cc, max_cc
),
TileDescription(
[128, 128, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc
),
TileDescription(
[128, 64, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc
),
TileDescription(
[64, 128, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc
),
TileDescription(
[64, 64, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc
), ),
] ]
...@@ -164,11 +200,7 @@ def GenerateSM70_TensorOp_884(manifest, cuda_version): ...@@ -164,11 +200,7 @@ def GenerateSM70_TensorOp_884(manifest, cuda_version):
] ]
CreateGatherGemmScatterOperator( CreateGatherGemmScatterOperator(
manifest, manifest, layouts, tile_descriptions, data_type
layouts,
tile_descriptions,
data_type,
alignment_constraints,
) )
# Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation)
...@@ -182,16 +214,286 @@ def GenerateSM70_TensorOp_884(manifest, cuda_version): ...@@ -182,16 +214,286 @@ def GenerateSM70_TensorOp_884(manifest, cuda_version):
] ]
CreateGatherGemmScatterOperator( CreateGatherGemmScatterOperator(
manifest, manifest, layouts, tile_descriptions, data_type_mixed
layouts,
tile_descriptions,
data_type_mixed,
alignment_constraints,
) )
def GenerateSM70(manifest, cuda_version): def GenerateSM80_TensorOp_1688(manifest, cuda_version):
GenerateSM70_TensorOp_884(manifest, cuda_version)
if not CudaToolkitVersionSatisfies(cuda_version, 11, 0):
return
layouts = [
(LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.RowMajor),
]
math_instructions = [
MathInstruction(
[16, 8, 8],
DataType.tf32,
DataType.tf32,
DataType.f32,
OpcodeClass.TensorOp,
MathOperation.multiply_add,
)
]
min_cc = 80
max_cc = 1024
for math_inst in math_instructions:
tile_descriptions = [
TileDescription(
[256, 128, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc
),
TileDescription(
[128, 256, 16], 3, [2, 4, 1], math_inst, min_cc, max_cc
),
TileDescription(
[256, 64, 16], 4, [4, 1, 1], math_inst, min_cc, max_cc
),
TileDescription(
[64, 256, 16], 4, [1, 4, 1], math_inst, min_cc, max_cc
),
TileDescription(
[128, 128, 16], 5, [2, 2, 1], math_inst, min_cc, max_cc
),
TileDescription(
[128, 128, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc
),
TileDescription(
[128, 128, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc
),
TileDescription(
[128, 64, 16], 6, [2, 2, 1], math_inst, min_cc, max_cc
),
TileDescription(
[64, 128, 16], 6, [2, 2, 1], math_inst, min_cc, max_cc
),
TileDescription(
[64, 64, 16], 10, [2, 2, 1], math_inst, min_cc, max_cc
),
TileDescription(
[256, 128, 32], 3, [4, 2, 1], math_inst, min_cc, max_cc
),
TileDescription(
[128, 256, 32], 3, [2, 4, 1], math_inst, min_cc, max_cc
),
TileDescription(
[256, 64, 32], 4, [4, 1, 1], math_inst, min_cc, max_cc
),
TileDescription(
[64, 256, 32], 4, [1, 4, 1], math_inst, min_cc, max_cc
),
TileDescription(
[128, 128, 32], 4, [2, 2, 1], math_inst, min_cc, max_cc
),
TileDescription(
[128, 128, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc
),
TileDescription(
[128, 64, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc
),
TileDescription(
[64, 128, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc
),
TileDescription(
[64, 64, 32], 5, [2, 2, 1], math_inst, min_cc, max_cc
),
]
data_type = [
math_inst.element_a,
math_inst.element_b,
math_inst.element_accumulator,
math_inst.element_accumulator,
]
data_type_mixed = [
math_inst.element_a,
math_inst.element_b,
math_inst.element_a,
math_inst.element_accumulator,
]
CreateGatherGemmScatterOperator(
manifest, layouts, tile_descriptions, data_type
)
CreateGatherGemmScatterOperator(
manifest, layouts, tile_descriptions, data_type_mixed
)
def GenerateSM80_TensorOp_1688_fast_math(manifest, cuda_version):
    """Register the SM80 16x8x8 tf32 tensor-op kernels that take f32 operands.

    Does nothing unless the CUDA toolkit version is at least 11.0. For each
    math instruction, one TileDescription is built per row of the table
    below (in table order) and handed to CreateGatherGemmScatterOperator with
    an all-f32 data type list.
    """
    if not CudaToolkitVersionSatisfies(cuda_version, 11, 0):
        return

    layouts = [
        (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.RowMajor),
    ]

    math_instructions = [
        MathInstruction(
            [16, 8, 8],
            DataType.tf32,
            DataType.tf32,
            DataType.f32,
            OpcodeClass.TensorOp,
            MathOperation.multiply_add,
        ),
    ]

    min_cc = 80
    max_cc = 1024

    # First three TileDescription arguments, one row per kernel variant
    # (presumably threadblock shape, pipeline stages, warp count -- confirm
    # against the CUTLASS library script).
    tile_table = (
        ((256, 128, 16), 3, (4, 2, 1)),
        ((128, 256, 16), 3, (2, 4, 1)),
        ((256, 64, 16), 4, (4, 1, 1)),
        ((64, 256, 16), 4, (1, 4, 1)),
        ((128, 128, 16), 5, (2, 2, 1)),
        ((128, 128, 16), 4, (2, 2, 1)),
        ((128, 128, 16), 3, (2, 2, 1)),
        ((128, 64, 16), 6, (2, 2, 1)),
        ((64, 128, 16), 6, (2, 2, 1)),
        ((64, 64, 16), 10, (2, 2, 1)),
        ((256, 128, 32), 3, (4, 2, 1)),
        ((128, 256, 32), 3, (2, 4, 1)),
        ((256, 64, 32), 4, (4, 1, 1)),
        ((64, 256, 32), 4, (1, 4, 1)),
        ((128, 128, 32), 4, (2, 2, 1)),
        ((128, 128, 32), 3, (2, 2, 1)),
        ((128, 64, 32), 3, (2, 2, 1)),
        ((64, 128, 32), 3, (2, 2, 1)),
        ((64, 64, 32), 5, (2, 2, 1)),
    )

    for math_inst in math_instructions:
        # Fresh lists per description, exactly as the literal form produced.
        tile_descriptions = [
            TileDescription(
                list(shape), stages, list(warps), math_inst, min_cc, max_cc
            )
            for shape, stages, warps in tile_table
        ]

        data_type = [DataType.f32, DataType.f32, DataType.f32, DataType.f32]

        CreateGatherGemmScatterOperator(
            manifest, layouts, tile_descriptions, data_type
        )
def GenerateSM80_TensorOp_1688_fast_fp32_math(manifest, cuda_version):
    """Register the SM80 16x8x8 f32 kernels using multiply_add_fast_f32.

    Does nothing unless the CUDA toolkit version is at least 11.0. For each
    math instruction, one TileDescription is built per row of the table
    below (in table order) and handed to CreateGatherGemmScatterOperator with
    an all-f32 data type list.
    """
    if not CudaToolkitVersionSatisfies(cuda_version, 11, 0):
        return

    layouts = [
        (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.RowMajor),
    ]

    math_instructions = [
        MathInstruction(
            [16, 8, 8],
            DataType.f32,
            DataType.f32,
            DataType.f32,
            OpcodeClass.TensorOp,
            MathOperation.multiply_add_fast_f32,
        ),
    ]

    min_cc = 80
    max_cc = 1024

    # First three TileDescription arguments, one row per kernel variant
    # (presumably threadblock shape, pipeline stages, warp count -- confirm
    # against the CUTLASS library script).
    tile_table = (
        ((128, 128, 16), 4, (4, 2, 1)),
        ((128, 128, 16), 3, (4, 2, 1)),
        ((256, 64, 16), 3, (4, 2, 1)),
        ((64, 256, 16), 3, (2, 4, 1)),
        ((128, 64, 16), 4, (2, 2, 1)),
        ((64, 128, 16), 4, (2, 2, 1)),
        ((64, 64, 16), 3, (2, 2, 1)),
        ((128, 128, 32), 3, (4, 2, 1)),
        ((256, 64, 32), 3, (4, 2, 1)),
        ((64, 256, 32), 3, (2, 4, 1)),
        ((128, 64, 32), 3, (2, 2, 1)),
        ((64, 128, 32), 3, (2, 2, 1)),
        ((64, 64, 32), 3, (2, 2, 1)),
    )

    for math_inst in math_instructions:
        # Fresh lists per description, exactly as the literal form produced.
        tile_descriptions = [
            TileDescription(
                list(shape), stages, list(warps), math_inst, min_cc, max_cc
            )
            for shape, stages, warps in tile_table
        ]

        data_type = [DataType.f32, DataType.f32, DataType.f32, DataType.f32]

        CreateGatherGemmScatterOperator(
            manifest, layouts, tile_descriptions, data_type
        )
def GenerateSM80(manifest, cuda_version):
    """Run each SM80 kernel generator in turn with the same arguments."""
    sm80_generators = (
        GenerateSM80_TensorOp_16816,
        GenerateSM80_TensorOp_1688,
        GenerateSM80_TensorOp_1688_fast_math,
        GenerateSM80_TensorOp_1688_fast_fp32_math,
    )
    for generate in sm80_generators:
        generate(manifest, cuda_version)
class KernelCfg: class KernelCfg:
...@@ -229,7 +531,7 @@ class KernelCfg: ...@@ -229,7 +531,7 @@ class KernelCfg:
if __name__ == "__main__": if __name__ == "__main__":
args = KernelCfg( args = KernelCfg(
architectures='70', architectures='80',
build_dir=sys.argv[2], build_dir=sys.argv[2],
cuda_version=sys.argv[3], cuda_version=sys.argv[3],
curr_build_dir=sys.argv[2], curr_build_dir=sys.argv[2],
...@@ -245,6 +547,6 @@ if __name__ == "__main__": ...@@ -245,6 +547,6 @@ if __name__ == "__main__":
) )
manifest = GatherGemmScatterManifest(args) manifest = GatherGemmScatterManifest(args)
GenerateSM70(manifest, args.cuda_version) GenerateSM80(manifest, args.cuda_version)
manifest.emit(GeneratorTarget.Library) manifest.emit(GeneratorTarget.Library)
...@@ -18,7 +18,7 @@ import shutil ...@@ -18,7 +18,7 @@ import shutil
from gather_gemm_scatter_operation import ( from gather_gemm_scatter_operation import (
EmitGatherGemmScatterConfigurationLibrary, EmitGatherGemmScatterConfigurationLibrary,
) )
from library import OperationKind, OperationKindNames from library import OperationKind, OperationKindNames, SubstituteTemplate
from manifest import EmitOperationKindLibrary, GeneratorTarget, Manifest from manifest import EmitOperationKindLibrary, GeneratorTarget, Manifest
...@@ -28,11 +28,25 @@ class GatherGemmScatterEmitOperationKindLibrary(EmitOperationKindLibrary): ...@@ -28,11 +28,25 @@ class GatherGemmScatterEmitOperationKindLibrary(EmitOperationKindLibrary):
self.emitters = { self.emitters = {
OperationKind.Gemm: EmitGatherGemmScatterConfigurationLibrary OperationKind.Gemm: EmitGatherGemmScatterConfigurationLibrary
} }
self.header_template = "#pragma once\n#ifdef PADDLE_WITH_CUTLASS\n" self.header_template = "#pragma once\n#ifdef PADDLE_WITH_CUTLASS\n#include \"paddle/phi/kernels/sparse/gpu/cutlass_generator/common.h\"\n"
self.entry_template = "" self.entry_template = ""
self.configuration_prototype_template = "" self.configuration_prototype_template = ""
self.configuration_template = "" self.configuration_template = ""
self.epilogue_template = "#endif" self.namespace_template = """
namespace phi {
namespace sparse {
"""
self.epilogue_template = """
} // namespace sparse
} // namespace phi
#endif
"""
self.fp16_kernels_list = (
"static std::vector<fp16_gather_gemm_scatter> fp16_kernels = {\n"
)
self.fp32_kernels_list = (
"static std::vector<fp32_gather_gemm_scatter> fp32_kernels = {\n"
)
def __enter__(self): def __enter__(self):
self.operation_path = os.path.join( self.operation_path = os.path.join(
...@@ -64,6 +78,21 @@ class GatherGemmScatterEmitOperationKindLibrary(EmitOperationKindLibrary): ...@@ -64,6 +78,21 @@ class GatherGemmScatterEmitOperationKindLibrary(EmitOperationKindLibrary):
self.source_files.append(configuration_emitter.configuration_path) self.source_files.append(configuration_emitter.configuration_path)
self.configurations.append(configuration_name) self.configurations.append(configuration_name)
if 'h' == operations[0].short_math_name():
self.fp16_kernels_list += (
"""
launchKernel<"""
+ configuration_name
+ "::Gemm>,"
)
if 's' == operations[0].short_math_name():
self.fp32_kernels_list += (
"""
launchKernel<"""
+ configuration_name
+ "::Gemm>,"
)
self.top_level_file.write( self.top_level_file.write(
'#include "' '#include "'
+ self.operation_path + self.operation_path
...@@ -72,6 +101,30 @@ class GatherGemmScatterEmitOperationKindLibrary(EmitOperationKindLibrary): ...@@ -72,6 +101,30 @@ class GatherGemmScatterEmitOperationKindLibrary(EmitOperationKindLibrary):
+ '.h"\n' + '.h"\n'
) )
def __exit__(self, exception_type, exception_value, traceback):
    """Finish the top-level header and close it.

    Writes the substituted entry template, one substituted configuration
    template per collected configuration, then the namespace opener, the
    fp16 and fp32 kernel-list initializers (each closed with "};") and the
    epilogue, in that order.
    """
    out = self.top_level_file

    entry_text = SubstituteTemplate(
        self.entry_template,
        {'operation_name': OperationKindNames[self.kind]},
    )
    out.write(entry_text)

    for name in self.configurations:
        out.write(
            SubstituteTemplate(
                self.configuration_template,
                {'configuration_name': name},
            )
        )

    # Close both C++ initializer lists before they are emitted.
    list_terminator = "\n};\n"
    self.fp16_kernels_list += list_terminator
    self.fp32_kernels_list += list_terminator

    for chunk in (
        self.namespace_template,
        self.fp16_kernels_list,
        self.fp32_kernels_list,
        self.epilogue_template,
    ):
        out.write(chunk)

    out.close()
class GatherGemmScatterManifest(Manifest): class GatherGemmScatterManifest(Manifest):
def emit(self, target=GeneratorTarget.Library): def emit(self, target=GeneratorTarget.Library):
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import collections
import enum import enum
import os.path import os.path
...@@ -40,16 +41,7 @@ from library import ( ...@@ -40,16 +41,7 @@ from library import (
class EmitGatherGemmScatterInstance(EmitGemmInstance): class EmitGatherGemmScatterInstance(EmitGemmInstance):
def __init__(self, operation_suffix=''): def __init__(self, operation_suffix=''):
self.operation_suffix = operation_suffix self.operation_suffix = operation_suffix
self.includes = [ self.includes = []
"cutlass/cutlass.h",
"cutlass/numeric_types.h",
"cutlass/arch/arch.h",
"cutlass/arch/mma.h",
"cutlass/layout/matrix.h",
"cutlass/gemm/device/gemm.h",
"cutlass/gemm/device/gemm_universal_adapter.h",
"cutlass/gemm/kernel/default_gemm_universal.h",
]
self.builtin_epilogue_functor_template = """ self.builtin_epilogue_functor_template = """
${epilogue_functor}< ${epilogue_functor}<
${element_c}, ${element_c},
...@@ -247,6 +239,18 @@ namespace sparse { ...@@ -247,6 +239,18 @@ namespace sparse {
#endif #endif
""" """
def __enter__(self):
    """Open the configuration file, write its preamble, reset collectors.

    The header template and the separator are written first; the include
    map and the instance/operation lists are then reset to empty.
    """
    self.configuration_file = open(self.configuration_path, "w")
    for preamble in (self.header_template, self.separator):
        self.configuration_file.write(preamble)

    self.includes = collections.OrderedDict()
    self.instance_definitions = []
    self.instance_wrappers = []
    self.operations = []

    return self
def __exit__(self, exception_type, exception_value, traceback): def __exit__(self, exception_type, exception_value, traceback):
# Write includes # Write includes
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifdef PADDLE_WITH_CUTLASS
#include "paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.h"
namespace phi {
namespace sparse {
// Selects the fp16 gather-gemm-scatter launcher for an (M, N, K) problem from
// a fixed table keyed on K and N, with M thresholds for the larger shapes.
// Shapes that match no rule fall back to the align4 64x64 kernel.
//
// Fix: the `M > 100000` and `M > 20000` branches of the K == 64 && N == 64
// case named a kernel without returning it, so they had no effect and the
// choice fell through to the `M > 15000` rule. They now return.
fp16_gather_gemm_scatter getBestFp16Kernel(const int M,
                                           const int N,
                                           const int K) {
  if (K == 4 && N == 16) {
    return launchKernel<cutlass::half_t,
                        cutlass_tensorop_h1688gemm_64x64_32x2_nn_align4::Gemm>;
  }
  if (K == 16 && N == 16) {
    return launchKernel<cutlass::half_t,
                        cutlass_tensorop_h1688gemm_64x64_32x2_nn_align8::Gemm>;
  }
  if (K == 16 && N == 32) {
    return launchKernel<cutlass::half_t,
                        cutlass_tensorop_h1688gemm_64x64_32x2_nn_align8::Gemm>;
  }
  if (K == 32 && N == 32) {
    return launchKernel<cutlass::half_t,
                        cutlass_tensorop_h1688gemm_64x64_32x2_nn_align8::Gemm>;
  }
  if (K == 32 && N == 64) {
    return launchKernel<cutlass::half_t,
                        cutlass_tensorop_h1688gemm_64x64_32x2_nn_align8::Gemm>;
  }
  if (K == 64 && N == 64) {
    // Thresholds are checked from largest to smallest M.
    if (M > 100000) {
      return launchKernel<
          cutlass::half_t,
          cutlass_tensorop_f16_s1688gemm_f16_64x128_32x2_nn_align8::Gemm>;
    }
    if (M > 20000) {
      return launchKernel<
          cutlass::half_t,
          cutlass_tensorop_f16_s1688gemm_f16_64x64_32x2_nn_align8::Gemm>;
    }
    if (M > 15000) {
      return launchKernel<
          cutlass::half_t,
          cutlass_tensorop_h1688gemm_128x64_32x2_nn_align8::Gemm>;
    }
    return launchKernel<cutlass::half_t,
                        cutlass_tensorop_h1688gemm_64x64_32x2_nn_align8::Gemm>;
  }
  if (K == 128) {
    if (M >= 5000) {
      return launchKernel<
          cutlass::half_t,
          cutlass_tensorop_h1688gemm_64x64_32x2_nn_align8::Gemm>;
    }
    return launchKernel<cutlass::half_t,
                        cutlass_tensorop_h16816gemm_64x64_64x5_nn_align8::Gemm>;
  }
  if (N == 128) {
    return launchKernel<cutlass::half_t,
                        cutlass_tensorop_h1688gemm_64x64_32x2_nn_align8::Gemm>;
  }
  return launchKernel<cutlass::half_t,
                      cutlass_tensorop_h1688gemm_64x64_32x2_nn_align4::Gemm>;
}
// Selects the fp32 gather-gemm-scatter launcher for an (M, N, K) problem.
// SM == 75 always gets one fixed kernel; otherwise the choice is a fixed
// table keyed on K and N with M thresholds. Shapes that match no rule fall
// back to the s1688f16 64x64_16x10 kernel.
fp32_gather_gemm_scatter getBestFp32Kernel(const int M,
                                           const int N,
                                           const int K,
                                           const int SM) {
  if (SM == 75) {
    return launchKernel<
        float,
        cutlass_tensorop_s1688gemm_f16_64x64_32x2_nn_align4::Gemm>;
  }

  // (K, N) in {(4, 16), (16, 16)}.
  const bool narrow_output = (N == 16) && (K == 4 || K == 16);
  if (narrow_output) {
    return launchKernel<
        float,
        cutlass_tensorop_s1688f16gemm_64x64_16x10_nn_align4::Gemm>;
  }

  // (K, N) in {(16, 32), (32, 32), (32, 64)} switch kernels at M = 10000;
  // (64, 64) uses the same pair of kernels but switches at M = 15000.
  const bool mid_shape =
      (K == 16 && N == 32) || (K == 32 && (N == 32 || N == 64));
  const bool square_64 = (K == 64 && N == 64);
  if (mid_shape || square_64) {
    const int large_m_threshold = square_64 ? 15000 : 10000;
    if (M < large_m_threshold) {
      return launchKernel<
          float,
          cutlass_tensorop_s1688f16gemm_64x64_16x10_nn_align4::Gemm>;
    }
    return launchKernel<
        float,
        cutlass_tensorop_s1688gemm_64x64_16x3_nn_align4::Gemm>;
  }

  if (K == 128) {
    if (M >= 100000) {
      return launchKernel<
          float,
          cutlass_tensorop_s1688f16gemm_128x128_16x3_nn_align4::Gemm>;
    }
    if (M >= 5000) {
      return launchKernel<
          float,
          cutlass_tensorop_s1688f16gemm_256x64_16x4_nn_align4::Gemm>;
    }
    return launchKernel<
        float,
        cutlass_tensorop_s1688tf32gemm_256x128_16x3_nn_align4::Gemm>;
  }

  if (N == 128) {
    if (M >= 100000) {
      return launchKernel<
          float,
          cutlass_tensorop_s1688tf32gemm_256x128_16x3_nn_align4::Gemm>;
    }
    if (M >= 5000) {
      return launchKernel<
          float,
          cutlass_tensorop_s1688f16gemm_128x128_16x3_nn_align4::Gemm>;
    }
    return launchKernel<
        float,
        cutlass_tensorop_s1688f16gemm_64x128_16x6_nn_align4::Gemm>;
  }

  return launchKernel<
      float,
      cutlass_tensorop_s1688f16gemm_64x64_16x10_nn_align4::Gemm>;
}
// Selects the fp64 gather-gemm-scatter launcher for an (M, N, K) problem.
// Only two kernels are in play: the 16x32 tile is chosen for three specific
// shapes, and every other shape gets the 32x16 tile.
fp64_gather_gemm_scatter getBestFp64Kernel(const int M,
                                           const int N,
                                           const int K) {
  const bool k4_n16 = (K == 4 && N == 16);
  const bool k16_n16_small_m = (K == 16 && N == 16 && M < 10000);
  const bool k32_n32 = (K == 32 && N == 32);

  if (k4_n16 || k16_n16_small_m || k32_n32) {
    return launchKernel<double,
                        cutlass_tensorop_d884gemm_16x32_16x5_nn_align1::Gemm>;
  }
  return launchKernel<double,
                      cutlass_tensorop_d884gemm_32x16_16x5_nn_align1::Gemm>;
}
} // namespace sparse
} // namespace phi
#endif
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册