未验证 提交 8a1cdc70 编写于 作者: U umiswing 提交者: GitHub

[cutlass] gather-gemm-scatter fusion on sm 75 (#53017)

上级 c09fc385
...@@ -205,7 +205,7 @@ void Conv3dCooGradGPUKernel(const GPUContext& dev_ctx, ...@@ -205,7 +205,7 @@ void Conv3dCooGradGPUKernel(const GPUContext& dev_ctx,
// (in_channels, n) * (n, out_channels) // (in_channels, n) * (n, out_channels)
static cutlass::device_memory::allocation<uint8_t> workspace( static cutlass::device_memory::allocation<uint8_t> workspace(
workspace_size); workspace_size);
GatherGemmScatterDriver<T, IntT, true, false>( GatherGemmScatterDriver<80, true, false>(
dev_ctx, dev_ctx,
key, key,
x.values().data<T>(), x.values().data<T>(),
...@@ -223,7 +223,7 @@ void Conv3dCooGradGPUKernel(const GPUContext& dev_ctx, ...@@ -223,7 +223,7 @@ void Conv3dCooGradGPUKernel(const GPUContext& dev_ctx,
&workspace); &workspace);
// call gemm: d_x = out_grad * transpose(kernel) // call gemm: d_x = out_grad * transpose(kernel)
// (n, out_channels) * (out_channels, in_channels) // (n, out_channels) * (out_channels, in_channels)
GatherGemmScatterDriver<T, IntT, false, true>( GatherGemmScatterDriver<80, false, true>(
dev_ctx, dev_ctx,
key, key,
out_grad.values().data<T>(), out_grad.values().data<T>(),
......
...@@ -18,6 +18,7 @@ limitations under the License. */ ...@@ -18,6 +18,7 @@ limitations under the License. */
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_meta.h" #include "paddle/phi/core/tensor_meta.h"
#include "paddle/phi/core/visit_type.h" #include "paddle/phi/core/visit_type.h"
#include "paddle/phi/kernels/cast_kernel.h"
#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/scatter.cu.h" #include "paddle/phi/kernels/funcs/scatter.cu.h"
#include "paddle/phi/kernels/funcs/sparse/scatter.cu.h" #include "paddle/phi/kernels/funcs/sparse/scatter.cu.h"
...@@ -31,6 +32,41 @@ limitations under the License. */ ...@@ -31,6 +32,41 @@ limitations under the License. */
namespace phi { namespace phi {
namespace sparse { namespace sparse {
#define GATHER_GEMM_SCATTER(arch, input_type, x_nnz, kernel) \
({ \
const input_type* kernel_ptr = kernel.data<input_type>(); \
const input_type* x_nnz_ptr = x_nnz.data<input_type>(); \
for (int i = 0; i < kernel_size; i++) { \
if (h_counter_ptr[i] <= 0) { \
continue; \
} \
const int M = h_counter_ptr[i]; \
const int K = in_channels; \
const int N = out_channels; \
const input_type* tmp_kernel_ptr = kernel_ptr + i * K * N; \
const IntT* gather_indices = rulebook_ptr + h_offsets_ptr[i]; \
const IntT* scatter_indices = \
rulebook_ptr + rulebook_len + h_offsets_ptr[i]; \
const size_t key = autotune::GenKey(M / features_num_range, N, K); \
GatherGemmScatterDriver<arch, false, false>( \
dev_ctx, \
key, \
x_nnz_ptr, \
tmp_kernel_ptr, \
out_values_ptr, \
out_values_ptr, \
M, \
N, \
K, \
gather_indices, \
static_cast<const IntT*>(nullptr), \
scatter_indices, \
static_cast<T>(1.0), \
static_cast<T>(1.0), \
nullptr); \
} \
})
template <typename T, typename IntT> template <typename T, typename IntT>
void Conv3dCooGPUKernel(const GPUContext& dev_ctx, void Conv3dCooGPUKernel(const GPUContext& dev_ctx,
const SparseCooTensor& x, const SparseCooTensor& x,
...@@ -124,10 +160,14 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx, ...@@ -124,10 +160,14 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx,
} }
#ifdef PADDLE_WITH_CUTLASS #ifdef PADDLE_WITH_CUTLASS
bool mixed_precision = dev_ctx.GetComputeCapability() >= 75 &&
dev_ctx.GetComputeCapability() < 80 &&
std::is_same<T, float>::value;
bool cutlass = true; bool cutlass = true;
if (dev_ctx.GetComputeCapability() < 80) cutlass = false; if (dev_ctx.GetComputeCapability() < 75) cutlass = false;
if (in_channels % 8 != 0 || out_channels % 8 != 0) { if (in_channels % 8 != 0 || out_channels % 8 != 0) {
if (std::is_same<T, phi::dtype::float16>::value) cutlass = false; if (std::is_same<T, phi::dtype::float16>::value) cutlass = false;
if (mixed_precision) cutlass = false;
} }
if (in_channels % 4 != 0 || out_channels % 4 != 0) { if (in_channels % 4 != 0 || out_channels % 4 != 0) {
if (std::is_same<T, float>::value) cutlass = false; if (std::is_same<T, float>::value) cutlass = false;
...@@ -141,36 +181,17 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx, ...@@ -141,36 +181,17 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx,
phi::funcs::SetConstant<GPUContext, T> set_zero; phi::funcs::SetConstant<GPUContext, T> set_zero;
set_zero(dev_ctx, out_values, static_cast<T>(0.0f)); set_zero(dev_ctx, out_values, static_cast<T>(0.0f));
const T* kernel_ptr = kernel.data<T>(); if (mixed_precision) {
for (int i = 0; i < kernel_size; i++) { DenseTensor kernel_fp16 =
if (h_counter_ptr[i] <= 0) { phi::Cast<T, GPUContext>(dev_ctx, kernel, DataType::FLOAT16);
continue; DenseTensor x_nnz_fp16 = phi::Cast<T, GPUContext>(
} dev_ctx, x.non_zero_elements(), DataType::FLOAT16);
GATHER_GEMM_SCATTER(75, phi::dtype::float16, x_nnz_fp16, kernel_fp16);
const int M = h_counter_ptr[i]; } else {
const int K = in_channels; if (dev_ctx.GetComputeCapability() < 80)
const int N = out_channels; GATHER_GEMM_SCATTER(75, T, x.non_zero_elements(), kernel);
const T* tmp_kernel_ptr = kernel_ptr + i * K * N; else
const IntT* gather_indices = rulebook_ptr + h_offsets_ptr[i]; GATHER_GEMM_SCATTER(80, T, x.non_zero_elements(), kernel);
const IntT* scatter_indices =
rulebook_ptr + rulebook_len + h_offsets_ptr[i];
const size_t key = autotune::GenKey(M / features_num_range, N, K);
GatherGemmScatterDriver<T, IntT, false, false>(
dev_ctx,
key,
x.non_zero_elements().data<T>(),
tmp_kernel_ptr,
out_values_ptr,
out_values_ptr,
M,
N,
K,
gather_indices,
static_cast<const IntT*>(nullptr),
scatter_indices,
static_cast<T>(1.0),
static_cast<T>(1.0),
nullptr);
} }
} else { } else {
#endif #endif
......
...@@ -36,20 +36,20 @@ size_t constexpr max_out_channels = 256; ...@@ -36,20 +36,20 @@ size_t constexpr max_out_channels = 256;
static size_t workspace_size = static size_t workspace_size =
sizeof(float) * max_splitk_slices * max_in_channels * max_out_channels; sizeof(float) * max_splitk_slices * max_in_channels * max_out_channels;
#define TYPEDEF_KERNEL_POINTER(kernel, dtype) \ #define TYPEDEF_KERNEL_POINTER(kernel, in_type, out_type) \
typedef void (*kernel)(dtype const alpha, \ typedef void (*kernel)(out_type const alpha, \
dtype const beta, \ out_type const beta, \
const GPUContext& dev_ctx, \ const GPUContext& dev_ctx, \
const dtype* const a, \ const in_type* const a, \
const dtype* const b, \ const in_type* const b, \
const dtype* const c, \ const out_type* const c, \
dtype* const d, \ out_type* const d, \
const int m, \ const int m, \
const int n, \ const int n, \
const int k, \ const int k, \
const int32_t* a_indices, \ const int32_t* a_indices, \
const int32_t* b_indices, \ const int32_t* b_indices, \
const int32_t* c_d_indices, \ const int32_t* c_d_indices, \
void* const workspace_ptr); void* const workspace_ptr);
#define GATHER_GEMM_SCATTER_CHECK(status) \ #define GATHER_GEMM_SCATTER_CHECK(status) \
{ \ { \
...@@ -58,15 +58,15 @@ static size_t workspace_size = ...@@ -58,15 +58,15 @@ static size_t workspace_size =
throw std::runtime_error(cutlassGetStatusString(error)); \ throw std::runtime_error(cutlassGetStatusString(error)); \
} \ } \
} }
#define DEFINE_LAUNCH_KERNEL(dtype, cutlass_type) \ #define DEFINE_LAUNCH_KERNEL(in_type, out_type) \
template <typename Config> \ template <typename Config> \
void launchKernel(dtype const alpha, \ void launchKernel(out_type const alpha, \
dtype const beta, \ out_type const beta, \
const GPUContext& dev_ctx, \ const GPUContext& dev_ctx, \
const dtype* const a, \ const in_type* const a, \
const dtype* const b, \ const in_type* const b, \
const dtype* const c, \ const out_type* const c, \
dtype* const d, \ out_type* const d, \
const int m, \ const int m, \
const int n, \ const int n, \
const int k, \ const int k, \
...@@ -81,12 +81,14 @@ static size_t workspace_size = ...@@ -81,12 +81,14 @@ static size_t workspace_size =
Config::Mode, \ Config::Mode, \
problem_size_real, \ problem_size_real, \
split_k_slices, \ split_k_slices, \
{static_cast<const cutlass_type>(static_cast<const float>(alpha)), \ {static_cast<const typename Gemm::Base::ElementAccumulator>( \
static_cast<const cutlass_type>(static_cast<const float>(beta))}, \ static_cast<const float>(alpha)), \
reinterpret_cast<const cutlass_type* const>(a), \ static_cast<const typename Gemm::Base::ElementAccumulator>( \
reinterpret_cast<const cutlass_type* const>(b), \ static_cast<const float>(beta))}, \
reinterpret_cast<const cutlass_type* const>(c), \ reinterpret_cast<const typename Gemm::Base::ElementA* const>(a), \
reinterpret_cast<cutlass_type* const>(d), \ reinterpret_cast<const typename Gemm::Base::ElementB* const>(b), \
reinterpret_cast<const typename Gemm::Base::ElementC* const>(c), \
reinterpret_cast<typename Gemm::Base::ElementC* const>(d), \
m * k, \ m * k, \
k * n, \ k * n, \
m * n, \ m * n, \
...@@ -172,19 +174,23 @@ static size_t workspace_size = ...@@ -172,19 +174,23 @@ static size_t workspace_size =
ref_workspace, \ ref_workspace, \
ref_d, \ ref_d, \
ref_c, \ ref_c, \
{static_cast<const cutlass_type>(static_cast<const float>(alpha)), \ {static_cast<const typename Gemm::Base::ElementAccumulator>( \
static_cast<const cutlass_type>(static_cast<const float>(beta))}); \ static_cast<const float>(alpha)), \
static_cast<const typename Gemm::Base::ElementAccumulator>( \
static_cast<const float>(beta))}); \
status = reduction_op.initialize(reduction_args); \ status = reduction_op.initialize(reduction_args); \
GATHER_GEMM_SCATTER_CHECK(status); \ GATHER_GEMM_SCATTER_CHECK(status); \
reduction_op(dev_ctx.stream()); \ reduction_op(dev_ctx.stream()); \
} \ } \
} }
TYPEDEF_KERNEL_POINTER(fp16_gather_gemm_scatter, phi::dtype::float16) TYPEDEF_KERNEL_POINTER(gather_hgemm_scatter, phi::dtype::float16, phi::float16)
TYPEDEF_KERNEL_POINTER(fp32_gather_gemm_scatter, float) TYPEDEF_KERNEL_POINTER(gather_sgemm_scatter, float, float)
TYPEDEF_KERNEL_POINTER(gather_sgemm_f16_scatter, phi::dtype::float16, float)
DEFINE_LAUNCH_KERNEL(phi::dtype::float16, cutlass::half_t) DEFINE_LAUNCH_KERNEL(phi::dtype::float16, phi::dtype::float16)
DEFINE_LAUNCH_KERNEL(float, float) DEFINE_LAUNCH_KERNEL(float, float)
DEFINE_LAUNCH_KERNEL(phi::dtype::float16, float)
} // namespace sparse } // namespace sparse
} // namespace phi } // namespace phi
......
...@@ -524,6 +524,91 @@ def GenerateSM80_TensorOp_1688_fast_fp32_math( ...@@ -524,6 +524,91 @@ def GenerateSM80_TensorOp_1688_fast_fp32_math(
) )
def GenerateSM75_TensorOp_1688(manifest, cuda_version, debug=False):
if not CudaToolkitVersionSatisfies(cuda_version, 10, 2):
return
layouts = [
(LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.RowMajor),
]
math_instructions = [
MathInstruction(
[16, 8, 8],
DataType.f16,
DataType.f16,
DataType.f32,
OpcodeClass.TensorOp,
MathOperation.multiply_add,
),
MathInstruction(
[16, 8, 8],
DataType.f16,
DataType.f16,
DataType.f16,
OpcodeClass.TensorOp,
MathOperation.multiply_add,
),
]
min_cc = 75
max_cc = 1024
for math_inst in math_instructions:
tile_descriptions = [
TileDescription(
[256, 128, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc
),
TileDescription(
[128, 256, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc
),
TileDescription(
[128, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc
),
TileDescription(
[64, 256, 32], 2, [1, 4, 1], math_inst, min_cc, max_cc
),
TileDescription(
[256, 64, 32], 2, [4, 1, 1], math_inst, min_cc, max_cc
),
TileDescription(
[64, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc
),
TileDescription(
[128, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc
),
TileDescription(
[64, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc
),
TileDescription(
[64, 128, 64], 2, [1, 2, 2], math_inst, min_cc, max_cc
),
]
if debug:
tile_descriptions = [
TileDescription(
[256, 128, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc
),
]
data_type = [
math_inst.element_a,
math_inst.element_b,
math_inst.element_accumulator,
math_inst.element_accumulator,
]
CreateGatherGemmScatterOperator(
manifest, layouts, tile_descriptions, data_type
)
def GenerateSM75(manifest, cuda_version, debug=False):
GenerateSM75_TensorOp_1688(manifest, cuda_version, debug)
def GenerateSM80(manifest, cuda_version, debug=False): def GenerateSM80(manifest, cuda_version, debug=False):
GenerateSM80_TensorOp_16816(manifest, cuda_version, debug) GenerateSM80_TensorOp_16816(manifest, cuda_version, debug)
GenerateSM80_TensorOp_1688(manifest, cuda_version, debug) GenerateSM80_TensorOp_1688(manifest, cuda_version, debug)
...@@ -582,6 +667,8 @@ if __name__ == "__main__": ...@@ -582,6 +667,8 @@ if __name__ == "__main__":
) )
manifest = GatherGemmScatterManifest(args) manifest = GatherGemmScatterManifest(args)
GenerateSM80(manifest, args.cuda_version) debug = False
GenerateSM75(manifest, args.cuda_version, debug)
GenerateSM80(manifest, args.cuda_version, debug)
manifest.emit(GeneratorTarget.Library) manifest.emit(GeneratorTarget.Library)
...@@ -42,10 +42,12 @@ namespace sparse { ...@@ -42,10 +42,12 @@ namespace sparse {
#endif #endif
""" """
self.kernels_lists = { self.kernels_lists = {
"hnn": "static std::vector<fp16_gather_gemm_scatter> fp16_nn_kernels = {", "hnn75": "static std::vector<gather_hgemm_scatter> sm75_fp16_nn_kernels = {",
"snn": "static std::vector<fp32_gather_gemm_scatter> fp32_nn_kernels = {", "snn75": "static std::vector<gather_sgemm_f16_scatter> sm75_fp32_nn_kernels = {",
"snt": "static std::vector<fp32_gather_gemm_scatter> fp32_nt_kernels = {", "hnn80": "static std::vector<gather_hgemm_scatter> sm80_fp16_nn_kernels = {",
"stn": "static std::vector<fp32_gather_gemm_scatter> fp32_tn_kernels = {", "snn80": "static std::vector<gather_sgemm_scatter> sm80_fp32_nn_kernels = {",
"snt80": "static std::vector<gather_sgemm_scatter> sm80_fp32_nt_kernels = {",
"stn80": "static std::vector<gather_sgemm_scatter> sm80_fp32_tn_kernels = {",
} }
def __enter__(self): def __enter__(self):
...@@ -81,7 +83,9 @@ namespace sparse { ...@@ -81,7 +83,9 @@ namespace sparse {
if operations[0].layout_name() == 'tn': if operations[0].layout_name() == 'tn':
self.kernels_lists[ self.kernels_lists[
operations[0].short_math_name() + operations[0].layout_name() operations[0].short_math_name()
+ operations[0].layout_name()
+ str(operations[0].arch)
] += ( ] += (
""" """
launchKernel<""" launchKernel<"""
...@@ -91,7 +95,9 @@ launchKernel<""" ...@@ -91,7 +95,9 @@ launchKernel<"""
) )
else: else:
self.kernels_lists[ self.kernels_lists[
operations[0].short_math_name() + operations[0].layout_name() operations[0].short_math_name()
+ operations[0].layout_name()
+ str(operations[0].arch)
] += ( ] += (
""" """
launchKernel<""" launchKernel<"""
......
...@@ -27,42 +27,56 @@ namespace sparse { ...@@ -27,42 +27,56 @@ namespace sparse {
// that shapes within this range share the same key. // that shapes within this range share the same key.
constexpr int features_num_range = 10000; constexpr int features_num_range = 10000;
template <typename T, typename IntT, bool TransposeA, bool TransposeB> template <int ComputeCapability,
bool TransposeA,
bool TransposeB,
typename Input,
typename Output,
typename IntT>
void GatherGemmScatterDriver( void GatherGemmScatterDriver(
const phi::GPUContext& ctx, const phi::GPUContext& ctx,
const size_t key, const size_t key,
const T* const a, const Input* const a,
const T* const b, const Input* const b,
const T* const c, const Output* const c,
T* const d, Output* const d,
const int& m, const int& m,
const int& n, const int& n,
const int& k, const int& k,
const IntT* a_indices, const IntT* a_indices,
const IntT* b_indices, const IntT* b_indices,
const IntT* c_d_indices, const IntT* c_d_indices,
T alpha, Output alpha,
T beta, Output beta,
cutlass::device_memory::allocation<uint8_t>* const workspace_ptr) {} cutlass::device_memory::allocation<uint8_t>* const workspace_ptr) {
PADDLE_THROW(
phi::errors::Unimplemented("gather_gemm_scatter fusion only supports "
"fp16_nn, fp32_nn, fp32_nt and fp32_tn now."));
}
#define EXPLICIT_SPECIALIZE_GATHER_GEMM_SCATTER_DRIVER( \ #define EXPLICIT_SPECIALIZE_GATHER_GEMM_SCATTER_DRIVER( \
T, kernels, transpose_a, transpose_b) \ compute_capability, transpose_a, transpose_b, in_type, out_type, kernels) \
template <> \ template <> \
inline void GatherGemmScatterDriver<T, int32_t, transpose_a, transpose_b>( \ inline void GatherGemmScatterDriver<compute_capability, \
transpose_a, \
transpose_b, \
in_type, \
out_type, \
int32_t>( \
const phi::GPUContext& ctx, \ const phi::GPUContext& ctx, \
const size_t key, \ const size_t key, \
const T* const a, \ const in_type* const a, \
const T* const b, \ const in_type* const b, \
const T* const c, \ const out_type* const c, \
T* const d, \ out_type* const d, \
const int& m, \ const int& m, \
const int& n, \ const int& n, \
const int& k, \ const int& k, \
const int32_t* a_indices, \ const int32_t* a_indices, \
const int32_t* b_indices, \ const int32_t* b_indices, \
const int32_t* c_d_indices, \ const int32_t* c_d_indices, \
T alpha, \ out_type alpha, \
T beta, \ out_type beta, \
cutlass::device_memory::allocation<uint8_t>* const workspace_ptr) { \ cutlass::device_memory::allocation<uint8_t>* const workspace_ptr) { \
auto* tuner = \ auto* tuner = \
autotune::MakeGatherGemmScatterTuner<transpose_a, transpose_b>( \ autotune::MakeGatherGemmScatterTuner<transpose_a, transpose_b>( \
...@@ -86,22 +100,26 @@ void GatherGemmScatterDriver( ...@@ -86,22 +100,26 @@ void GatherGemmScatterDriver(
workspace_ptr); \ workspace_ptr); \
} }
EXPLICIT_SPECIALIZE_GATHER_GEMM_SCATTER_DRIVER(phi::dtype::float16, EXPLICIT_SPECIALIZE_GATHER_GEMM_SCATTER_DRIVER(75,
fp16_nn_kernels,
false, false,
false)
EXPLICIT_SPECIALIZE_GATHER_GEMM_SCATTER_DRIVER(float,
fp32_nn_kernels,
false, false,
false) phi::dtype::float16,
EXPLICIT_SPECIALIZE_GATHER_GEMM_SCATTER_DRIVER(float, phi::dtype::float16,
fp32_nt_kernels, sm75_fp16_nn_kernels)
EXPLICIT_SPECIALIZE_GATHER_GEMM_SCATTER_DRIVER(
75, false, false, phi::dtype::float16, float, sm75_fp32_nn_kernels)
EXPLICIT_SPECIALIZE_GATHER_GEMM_SCATTER_DRIVER(80,
false, false,
true) false,
EXPLICIT_SPECIALIZE_GATHER_GEMM_SCATTER_DRIVER(float, phi::dtype::float16,
fp32_tn_kernels, phi::dtype::float16,
true, sm80_fp16_nn_kernels)
false) EXPLICIT_SPECIALIZE_GATHER_GEMM_SCATTER_DRIVER(
80, false, false, float, float, sm80_fp32_nn_kernels)
EXPLICIT_SPECIALIZE_GATHER_GEMM_SCATTER_DRIVER(
80, false, true, float, float, sm80_fp32_nt_kernels)
EXPLICIT_SPECIALIZE_GATHER_GEMM_SCATTER_DRIVER(
80, true, false, float, float, sm80_fp32_tn_kernels)
} // namespace sparse } // namespace sparse
} // namespace phi } // namespace phi
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册