Unverified commit 8a1cdc70, authored by umiswing, committed by GitHub

[cutlass] gather-gemm-scatter fusion on sm 75 (#53017)

Parent c09fc385
@@ -205,7 +205,7 @@ void Conv3dCooGradGPUKernel(const GPUContext& dev_ctx,
// (in_channels, n) * (n, out_channels)
static cutlass::device_memory::allocation<uint8_t> workspace(
workspace_size);
GatherGemmScatterDriver<T, IntT, true, false>(
GatherGemmScatterDriver<80, true, false>(
dev_ctx,
key,
x.values().data<T>(),
@@ -223,7 +223,7 @@ void Conv3dCooGradGPUKernel(const GPUContext& dev_ctx,
&workspace);
// call gemm: d_x = out_grad * transpose(kernel)
// (n, out_channels) * (out_channels, in_channels)
GatherGemmScatterDriver<T, IntT, false, true>(
GatherGemmScatterDriver<80, false, true>(
dev_ctx,
key,
out_grad.values().data<T>(),
......
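For orientation, the call shape after this change: the compute capability becomes a non-type template parameter and the element types are deduced from the pointer arguments. A minimal sketch of the d_x call above, using illustrative names (x_grad_ptr, a_indices, and so on) for the arguments the hunk elides:

// Sketch only: argument order follows the GatherGemmScatterDriver
// declaration later in this diff; the index arrays and m/n/k come from
// the rulebook the surrounding kernel has already built.
GatherGemmScatterDriver<80, /*TransposeA=*/false, /*TransposeB=*/true>(
    dev_ctx,
    key,
    out_grad.values().data<T>(),  // a
    kernel_ptr,                   // b
    x_grad_ptr,                   // c: read, scaled by beta
    x_grad_ptr,                   // d: written through c_d_indices
    m, n, k,
    a_indices, b_indices, c_d_indices,
    static_cast<T>(1.0), static_cast<T>(1.0),
    &workspace);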
@@ -18,6 +18,7 @@ limitations under the License. */
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_meta.h"
#include "paddle/phi/core/visit_type.h"
#include "paddle/phi/kernels/cast_kernel.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/scatter.cu.h"
#include "paddle/phi/kernels/funcs/sparse/scatter.cu.h"
@@ -31,6 +32,41 @@ limitations under the License. */
namespace phi {
namespace sparse {
#define GATHER_GEMM_SCATTER(arch, input_type, x_nnz, kernel) \
({ \
const input_type* kernel_ptr = kernel.data<input_type>(); \
const input_type* x_nnz_ptr = x_nnz.data<input_type>(); \
for (int i = 0; i < kernel_size; i++) { \
if (h_counter_ptr[i] <= 0) { \
continue; \
} \
const int M = h_counter_ptr[i]; \
const int K = in_channels; \
const int N = out_channels; \
const input_type* tmp_kernel_ptr = kernel_ptr + i * K * N; \
const IntT* gather_indices = rulebook_ptr + h_offsets_ptr[i]; \
const IntT* scatter_indices = \
rulebook_ptr + rulebook_len + h_offsets_ptr[i]; \
const size_t key = autotune::GenKey(M / features_num_range, N, K); \
GatherGemmScatterDriver<arch, false, false>( \
dev_ctx, \
key, \
x_nnz_ptr, \
tmp_kernel_ptr, \
out_values_ptr, \
out_values_ptr, \
M, \
N, \
K, \
gather_indices, \
static_cast<const IntT*>(nullptr), \
scatter_indices, \
static_cast<T>(1.0), \
static_cast<T>(1.0), \
nullptr); \
} \
})
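The `({ ... })` wrapper is a GNU statement expression (an extension supported by GCC and Clang, the host compilers nvcc uses here); it lets the macro declare locals such as kernel_ptr and x_nnz_ptr without leaking them into the caller's scope. A self-contained illustration of the idiom:

#include <cstdio>

// GNU statement expression: a brace block usable where an expression is
// expected, with its own local scope.
#define SQUARE_AND_LOG(x)            \
  ({                                 \
    const int v = (x);               \
    std::printf("squaring %d\n", v); \
    v * v;                           \
  })

int main() {
  int nine = SQUARE_AND_LOG(3);  // the local `v` does not escape the macro
  std::printf("%d\n", nine);     // prints 9
  return 0;
}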
template <typename T, typename IntT>
void Conv3dCooGPUKernel(const GPUContext& dev_ctx,
const SparseCooTensor& x,
@@ -124,10 +160,14 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx,
}
#ifdef PADDLE_WITH_CUTLASS
bool mixed_precision = dev_ctx.GetComputeCapability() >= 75 &&
dev_ctx.GetComputeCapability() < 80 &&
std::is_same<T, float>::value;
bool cutlass = true;
if (dev_ctx.GetComputeCapability() < 80) cutlass = false;
if (dev_ctx.GetComputeCapability() < 75) cutlass = false;
if (in_channels % 8 != 0 || out_channels % 8 != 0) {
if (std::is_same<T, phi::dtype::float16>::value) cutlass = false;
if (mixed_precision) cutlass = false;
}
if (in_channels % 4 != 0 || out_channels % 4 != 0) {
if (std::is_same<T, float>::value) cutlass = false;
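Taken together, the gating now reads: CUTLASS requires at least SM 7.5, and on SM 7.5–7.x an fp32 model is routed through fp16 tensor cores (mixed_precision), so it must also satisfy the fp16 alignment rule. A self-contained sketch of the combined predicate, assuming the elided lines keep the existing divisible-by-4 fp32 check (PhiFloat16 stands in for phi::dtype::float16):

#include <type_traits>

struct PhiFloat16 { unsigned short bits; };

template <typename T>
bool UseCutlass(int compute_capability, int in_channels, int out_channels) {
  if (compute_capability < 75) return false;  // SM 7.5 is now the floor
  // SM 7.5..7.x runs fp32 models through fp16 tensor cores:
  const bool mixed_precision =
      compute_capability < 80 && std::is_same<T, float>::value;
  // Anything feeding fp16 tensor cores needs 8-element channel alignment:
  if ((std::is_same<T, PhiFloat16>::value || mixed_precision) &&
      (in_channels % 8 != 0 || out_channels % 8 != 0))
    return false;
  // The fp32 kernels keep the existing 4-element alignment rule:
  if (std::is_same<T, float>::value &&
      (in_channels % 4 != 0 || out_channels % 4 != 0))
    return false;
  return true;
}

int main() {
  // fp32 on SM 7.5: needs channels % 8 == 0 because of the fp16 fallback.
  return UseCutlass<float>(75, 64, 64) ? 0 : 1;
}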
@@ -141,36 +181,17 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx,
phi::funcs::SetConstant<GPUContext, T> set_zero;
set_zero(dev_ctx, out_values, static_cast<T>(0.0f));
const T* kernel_ptr = kernel.data<T>();
for (int i = 0; i < kernel_size; i++) {
if (h_counter_ptr[i] <= 0) {
continue;
}
const int M = h_counter_ptr[i];
const int K = in_channels;
const int N = out_channels;
const T* tmp_kernel_ptr = kernel_ptr + i * K * N;
const IntT* gather_indices = rulebook_ptr + h_offsets_ptr[i];
const IntT* scatter_indices =
rulebook_ptr + rulebook_len + h_offsets_ptr[i];
const size_t key = autotune::GenKey(M / features_num_range, N, K);
GatherGemmScatterDriver<T, IntT, false, false>(
dev_ctx,
key,
x.non_zero_elements().data<T>(),
tmp_kernel_ptr,
out_values_ptr,
out_values_ptr,
M,
N,
K,
gather_indices,
static_cast<const IntT*>(nullptr),
scatter_indices,
static_cast<T>(1.0),
static_cast<T>(1.0),
nullptr);
if (mixed_precision) {
DenseTensor kernel_fp16 =
phi::Cast<T, GPUContext>(dev_ctx, kernel, DataType::FLOAT16);
DenseTensor x_nnz_fp16 = phi::Cast<T, GPUContext>(
dev_ctx, x.non_zero_elements(), DataType::FLOAT16);
GATHER_GEMM_SCATTER(75, phi::dtype::float16, x_nnz_fp16, kernel_fp16);
} else {
if (dev_ctx.GetComputeCapability() < 80)
GATHER_GEMM_SCATTER(75, T, x.non_zero_elements(), kernel);
else
GATHER_GEMM_SCATTER(80, T, x.non_zero_elements(), kernel);
}
} else {
#endif
......
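On the mixed-precision path the activations and weights are cast to fp16 once per call while out_values_ptr stays fp32, so GATHER_GEMM_SCATTER(75, phi::dtype::float16, ...) deduces Input = fp16 and Output = float at the driver. Roughly, one loop iteration of the macro expands to (a sketch reusing the macro's own local names):

// What GATHER_GEMM_SCATTER(75, phi::dtype::float16, x_nnz_fp16, kernel_fp16)
// does per rulebook segment when T = float: Input is deduced as fp16 from
// the casted pointers, Output as float from out_values_ptr.
GatherGemmScatterDriver<75, false, false>(
    dev_ctx, key,
    x_nnz_fp16.data<phi::dtype::float16>(),               // a (fp16)
    kernel_fp16.data<phi::dtype::float16>() + i * K * N,  // b (fp16)
    out_values_ptr,                                       // c (fp32)
    out_values_ptr,                                       // d (fp32)
    M, N, K,
    gather_indices, static_cast<const IntT*>(nullptr), scatter_indices,
    static_cast<T>(1.0), static_cast<T>(1.0), nullptr);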
@@ -36,20 +36,20 @@ size_t constexpr max_out_channels = 256;
static size_t workspace_size =
sizeof(float) * max_splitk_slices * max_in_channels * max_out_channels;
#define TYPEDEF_KERNEL_POINTER(kernel, dtype) \
typedef void (*kernel)(dtype const alpha, \
dtype const beta, \
const GPUContext& dev_ctx, \
const dtype* const a, \
const dtype* const b, \
const dtype* const c, \
dtype* const d, \
const int m, \
const int n, \
const int k, \
const int32_t* a_indices, \
const int32_t* b_indices, \
const int32_t* c_d_indices, \
#define TYPEDEF_KERNEL_POINTER(kernel, in_type, out_type) \
typedef void (*kernel)(out_type const alpha, \
out_type const beta, \
const GPUContext& dev_ctx, \
const in_type* const a, \
const in_type* const b, \
const out_type* const c, \
out_type* const d, \
const int m, \
const int n, \
const int k, \
const int32_t* a_indices, \
const int32_t* b_indices, \
const int32_t* c_d_indices, \
void* const workspace_ptr);
#define GATHER_GEMM_SCATTER_CHECK(status) \
{ \
@@ -58,15 +58,15 @@ static size_t workspace_size =
throw std::runtime_error(cutlassGetStatusString(error)); \
} \
}
#define DEFINE_LAUNCH_KERNEL(dtype, cutlass_type) \
#define DEFINE_LAUNCH_KERNEL(in_type, out_type) \
template <typename Config> \
void launchKernel(dtype const alpha, \
dtype const beta, \
void launchKernel(out_type const alpha, \
out_type const beta, \
const GPUContext& dev_ctx, \
const dtype* const a, \
const dtype* const b, \
const dtype* const c, \
dtype* const d, \
const in_type* const a, \
const in_type* const b, \
const out_type* const c, \
out_type* const d, \
const int m, \
const int n, \
const int k, \
@@ -81,12 +81,14 @@ static size_t workspace_size =
Config::Mode, \
problem_size_real, \
split_k_slices, \
{static_cast<const cutlass_type>(static_cast<const float>(alpha)), \
static_cast<const cutlass_type>(static_cast<const float>(beta))}, \
reinterpret_cast<const cutlass_type* const>(a), \
reinterpret_cast<const cutlass_type* const>(b), \
reinterpret_cast<const cutlass_type* const>(c), \
reinterpret_cast<cutlass_type* const>(d), \
{static_cast<const typename Gemm::Base::ElementAccumulator>( \
static_cast<const float>(alpha)), \
static_cast<const typename Gemm::Base::ElementAccumulator>( \
static_cast<const float>(beta))}, \
reinterpret_cast<const typename Gemm::Base::ElementA* const>(a), \
reinterpret_cast<const typename Gemm::Base::ElementB* const>(b), \
reinterpret_cast<const typename Gemm::Base::ElementC* const>(c), \
reinterpret_cast<typename Gemm::Base::ElementC* const>(d), \
m * k, \
k * n, \
m * n, \
@@ -172,19 +174,23 @@ static size_t workspace_size =
ref_workspace, \
ref_d, \
ref_c, \
{static_cast<const cutlass_type>(static_cast<const float>(alpha)), \
static_cast<const cutlass_type>(static_cast<const float>(beta))}); \
{static_cast<const typename Gemm::Base::ElementAccumulator>( \
static_cast<const float>(alpha)), \
static_cast<const typename Gemm::Base::ElementAccumulator>( \
static_cast<const float>(beta))}); \
status = reduction_op.initialize(reduction_args); \
GATHER_GEMM_SCATTER_CHECK(status); \
reduction_op(dev_ctx.stream()); \
} \
}
TYPEDEF_KERNEL_POINTER(fp16_gather_gemm_scatter, phi::dtype::float16)
TYPEDEF_KERNEL_POINTER(fp32_gather_gemm_scatter, float)
TYPEDEF_KERNEL_POINTER(gather_hgemm_scatter, phi::dtype::float16, phi::dtype::float16)
TYPEDEF_KERNEL_POINTER(gather_sgemm_scatter, float, float)
TYPEDEF_KERNEL_POINTER(gather_sgemm_f16_scatter, phi::dtype::float16, float)
DEFINE_LAUNCH_KERNEL(phi::dtype::float16, cutlass::half_t)
DEFINE_LAUNCH_KERNEL(phi::dtype::float16, phi::dtype::float16)
DEFINE_LAUNCH_KERNEL(float, float)
DEFINE_LAUNCH_KERNEL(phi::dtype::float16, float)
} // namespace sparse
} // namespace phi
......
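The launch macro now takes its cast targets from the instantiated GEMM itself (Gemm::Base::ElementA/ElementB/ElementC and ElementAccumulator) rather than from a single cutlass_type, which is what lets one macro body serve fp16→fp16, fp32→fp32, and fp16→fp32 kernels. A self-contained illustration of the pattern, with a hypothetical FakeGemmF16F32 standing in for a CUTLASS device GEMM:

#include <cstdio>

// Hypothetical stand-in: the element types travel with the configuration
// instead of with the launcher.
struct FakeGemmF16F32 {
  using ElementA = unsigned short;  // pretend half-precision storage
  using ElementC = float;
};

template <typename Gemm>
void Launch(const void* a, void* d) {
  // Cast through the Gemm's own element types, as DEFINE_LAUNCH_KERNEL
  // does with reinterpret_cast<const typename Gemm::Base::ElementA*>(a).
  const auto* a_typed = static_cast<const typename Gemm::ElementA*>(a);
  auto* d_typed = static_cast<typename Gemm::ElementC*>(d);
  std::printf("a[0] bits=0x%x\n", unsigned(a_typed[0]));
  d_typed[0] = 1.0f;
}

int main() {
  unsigned short a[1] = {0x3C00};  // 1.0 as raw IEEE half bits
  float d[1] = {0.0f};
  Launch<FakeGemmF16F32>(a, d);
  std::printf("d[0]=%f\n", d[0]);
  return 0;
}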
@@ -524,6 +524,91 @@ def GenerateSM80_TensorOp_1688_fast_fp32_math(
)
def GenerateSM75_TensorOp_1688(manifest, cuda_version, debug=False):
if not CudaToolkitVersionSatisfies(cuda_version, 10, 2):
return
layouts = [
(LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.RowMajor),
]
math_instructions = [
MathInstruction(
[16, 8, 8],
DataType.f16,
DataType.f16,
DataType.f32,
OpcodeClass.TensorOp,
MathOperation.multiply_add,
),
MathInstruction(
[16, 8, 8],
DataType.f16,
DataType.f16,
DataType.f16,
OpcodeClass.TensorOp,
MathOperation.multiply_add,
),
]
min_cc = 75
max_cc = 1024
for math_inst in math_instructions:
tile_descriptions = [
TileDescription(
[256, 128, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc
),
TileDescription(
[128, 256, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc
),
TileDescription(
[128, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc
),
TileDescription(
[64, 256, 32], 2, [1, 4, 1], math_inst, min_cc, max_cc
),
TileDescription(
[256, 64, 32], 2, [4, 1, 1], math_inst, min_cc, max_cc
),
TileDescription(
[64, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc
),
TileDescription(
[128, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc
),
TileDescription(
[64, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc
),
TileDescription(
[64, 128, 64], 2, [1, 2, 2], math_inst, min_cc, max_cc
),
]
if debug:
tile_descriptions = [
TileDescription(
[256, 128, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc
),
]
data_type = [
math_inst.element_a,
math_inst.element_b,
math_inst.element_accumulator,
math_inst.element_accumulator,
]
CreateGatherGemmScatterOperator(
manifest, layouts, tile_descriptions, data_type
)
def GenerateSM75(manifest, cuda_version, debug=False):
GenerateSM75_TensorOp_1688(manifest, cuda_version, debug)
def GenerateSM80(manifest, cuda_version, debug=False):
GenerateSM80_TensorOp_16816(manifest, cuda_version, debug)
GenerateSM80_TensorOp_1688(manifest, cuda_version, debug)
@@ -582,6 +667,8 @@ if __name__ == "__main__":
)
manifest = GatherGemmScatterManifest(args)
GenerateSM80(manifest, args.cuda_version)
debug = False
GenerateSM75(manifest, args.cuda_version, debug)
GenerateSM80(manifest, args.cuda_version, debug)
manifest.emit(GeneratorTarget.Library)
@@ -42,10 +42,12 @@ namespace sparse {
#endif
"""
self.kernels_lists = {
"hnn": "static std::vector<fp16_gather_gemm_scatter> fp16_nn_kernels = {",
"snn": "static std::vector<fp32_gather_gemm_scatter> fp32_nn_kernels = {",
"snt": "static std::vector<fp32_gather_gemm_scatter> fp32_nt_kernels = {",
"stn": "static std::vector<fp32_gather_gemm_scatter> fp32_tn_kernels = {",
"hnn75": "static std::vector<gather_hgemm_scatter> sm75_fp16_nn_kernels = {",
"snn75": "static std::vector<gather_sgemm_f16_scatter> sm75_fp32_nn_kernels = {",
"hnn80": "static std::vector<gather_hgemm_scatter> sm80_fp16_nn_kernels = {",
"snn80": "static std::vector<gather_sgemm_scatter> sm80_fp32_nn_kernels = {",
"snt80": "static std::vector<gather_sgemm_scatter> sm80_fp32_nt_kernels = {",
"stn80": "static std::vector<gather_sgemm_scatter> sm80_fp32_tn_kernels = {",
}
def __enter__(self):
@@ -81,7 +83,9 @@ namespace sparse {
if operations[0].layout_name() == 'tn':
self.kernels_lists[
operations[0].short_math_name() + operations[0].layout_name()
operations[0].short_math_name()
+ operations[0].layout_name()
+ str(operations[0].arch)
] += (
"""
launchKernel<"""
@@ -91,7 +95,9 @@ launchKernel<"""
)
else:
self.kernels_lists[
operations[0].short_math_name() + operations[0].layout_name()
operations[0].short_math_name()
+ operations[0].layout_name()
+ str(operations[0].arch)
] += (
"""
launchKernel<"""
......
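Each kernels_lists entry above becomes one C++ vector in the generated header, keyed by math type, layout, and now architecture, and filled with launchKernel instantiations. Roughly like this (the instantiation names are illustrative, not the exact generated ones):

// Sketch of the emitted list for the sm75 fp16 NN case; every entry must
// match the gather_hgemm_scatter pointer type defined in the header above.
static std::vector<gather_hgemm_scatter> sm75_fp16_nn_kernels = {
launchKernel<cutlass_tensorop_h1688gemm_256x128_32x2_nn_align8::Gemm>,
launchKernel<cutlass_tensorop_h1688gemm_128x128_32x2_nn_align8::Gemm>,
};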
@@ -27,42 +27,56 @@ namespace sparse {
// that shapes within this range share the same key.
constexpr int features_num_range = 10000;
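In other words, problem sizes whose M values fall in the same 10000-wide bucket (with equal N and K) hit the same autotune cache entry; for example (reusing the GenKey call from the kernels above):

// Both keys bucket M to 12, so the second call reuses the kernel choice
// tuned for the first; M = 130001 would bucket to 13 and tune again.
const size_t key_a = autotune::GenKey(123000 / features_num_range, 64, 32);
const size_t key_b = autotune::GenKey(129999 / features_num_range, 64, 32);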
template <typename T, typename IntT, bool TransposeA, bool TransposeB>
template <int ComputeCapability,
bool TransposeA,
bool TransposeB,
typename Input,
typename Output,
typename IntT>
void GatherGemmScatterDriver(
const phi::GPUContext& ctx,
const size_t key,
const T* const a,
const T* const b,
const T* const c,
T* const d,
const Input* const a,
const Input* const b,
const Output* const c,
Output* const d,
const int& m,
const int& n,
const int& k,
const IntT* a_indices,
const IntT* b_indices,
const IntT* c_d_indices,
T alpha,
T beta,
cutlass::device_memory::allocation<uint8_t>* const workspace_ptr) {}
Output alpha,
Output beta,
cutlass::device_memory::allocation<uint8_t>* const workspace_ptr) {
PADDLE_THROW(
phi::errors::Unimplemented("gather_gemm_scatter fusion only supports "
"fp16_nn, fp32_nn, fp32_nt and fp32_tn now."));
}
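The primary template now fails loudly instead of compiling to a silent no-op, and every supported (arch, layout, dtype) combination is supplied as a full specialization below. A self-contained sketch of the dispatch pattern, with a plain std::runtime_error in place of PADDLE_THROW:

#include <cstdio>
#include <stdexcept>

// Primary template: reachable only for combinations nobody specialized.
template <int Arch, bool TransA, bool TransB, typename In, typename Out>
void Driver(const In* /*a*/, Out* /*d*/) {
  throw std::runtime_error("unsupported arch/layout/dtype combination");
}

// Full specialization: one of the combinations with a real implementation.
template <>
void Driver<80, false, false, float, float>(const float* a, float* d) {
  d[0] = a[0];  // stand-in for the real gather-GEMM-scatter launch
}

int main() {
  float a[1] = {2.0f}, d[1] = {0.0f};
  Driver<80, false, false>(a, d);  // In/Out deduced as float
  std::printf("%f\n", d[0]);
  try {
    Driver<75, true, true>(a, d);  // no specialization -> throws
  } catch (const std::exception& e) {
    std::printf("caught: %s\n", e.what());
  }
  return 0;
}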
#define EXPLICIT_SPECIALIZE_GATHER_GEMM_SCATTER_DRIVER( \
T, kernels, transpose_a, transpose_b) \
compute_capability, transpose_a, transpose_b, in_type, out_type, kernels) \
template <> \
inline void GatherGemmScatterDriver<T, int32_t, transpose_a, transpose_b>( \
inline void GatherGemmScatterDriver<compute_capability, \
transpose_a, \
transpose_b, \
in_type, \
out_type, \
int32_t>( \
const phi::GPUContext& ctx, \
const size_t key, \
const T* const a, \
const T* const b, \
const T* const c, \
T* const d, \
const in_type* const a, \
const in_type* const b, \
const out_type* const c, \
out_type* const d, \
const int& m, \
const int& n, \
const int& k, \
const int32_t* a_indices, \
const int32_t* b_indices, \
const int32_t* c_d_indices, \
T alpha, \
T beta, \
out_type alpha, \
out_type beta, \
cutlass::device_memory::allocation<uint8_t>* const workspace_ptr) { \
auto* tuner = \
autotune::MakeGatherGemmScatterTuner<transpose_a, transpose_b>( \
@@ -86,22 +100,26 @@ void GatherGemmScatterDriver(
workspace_ptr); \
}
EXPLICIT_SPECIALIZE_GATHER_GEMM_SCATTER_DRIVER(phi::dtype::float16,
fp16_nn_kernels,
false,
false)
EXPLICIT_SPECIALIZE_GATHER_GEMM_SCATTER_DRIVER(float,
fp32_nn_kernels,
false,
false)
EXPLICIT_SPECIALIZE_GATHER_GEMM_SCATTER_DRIVER(float,
fp32_nt_kernels,
false,
true)
EXPLICIT_SPECIALIZE_GATHER_GEMM_SCATTER_DRIVER(float,
fp32_tn_kernels,
true,
false)
EXPLICIT_SPECIALIZE_GATHER_GEMM_SCATTER_DRIVER(75,
false,
false,
phi::dtype::float16,
phi::dtype::float16,
sm75_fp16_nn_kernels)
EXPLICIT_SPECIALIZE_GATHER_GEMM_SCATTER_DRIVER(
75, false, false, phi::dtype::float16, float, sm75_fp32_nn_kernels)
EXPLICIT_SPECIALIZE_GATHER_GEMM_SCATTER_DRIVER(80,
false,
false,
phi::dtype::float16,
phi::dtype::float16,
sm80_fp16_nn_kernels)
EXPLICIT_SPECIALIZE_GATHER_GEMM_SCATTER_DRIVER(
80, false, false, float, float, sm80_fp32_nn_kernels)
EXPLICIT_SPECIALIZE_GATHER_GEMM_SCATTER_DRIVER(
80, false, true, float, float, sm80_fp32_nt_kernels)
EXPLICIT_SPECIALIZE_GATHER_GEMM_SCATTER_DRIVER(
80, true, false, float, float, sm80_fp32_tn_kernels)
} // namespace sparse
} // namespace phi
......
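Collected, the specializations above (and therefore the only combinations that avoid the Unimplemented throw) are:

// Supported driver specializations after this change:
//   <75, NN, fp16 -> fp16>  sm75_fp16_nn_kernels
//   <75, NN, fp16 -> fp32>  sm75_fp32_nn_kernels  (mixed-precision path)
//   <80, NN, fp16 -> fp16>  sm80_fp16_nn_kernels
//   <80, NN, fp32 -> fp32>  sm80_fp32_nn_kernels
//   <80, NT, fp32 -> fp32>  sm80_fp32_nt_kernels
//   <80, TN, fp32 -> fp32>  sm80_fp32_tn_kernels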