Unverified commit 0b98d1aa, authored by umiswing, committed by GitHub

[cutlass] Sparse conv3d backward fusion (#52361)

Parent 1acb845a
@@ -177,18 +177,34 @@ class MatmulAutoTuner
   }
 };
-template <typename T, typename ReturnType, typename... Args>
+template <bool TransposeA,
+          bool TransposeB,
+          typename T,
+          typename ReturnType,
+          typename... Args>
 class GatherGemmScatterAutoTuner
     : public AutoTuneBase<T, KernelCallback<T, ReturnType, T, T, Args...>> {
  public:
-  static GatherGemmScatterAutoTuner<T, ReturnType, Args...>* Instance(
-      ReturnType (*func)(T, T, Args...)) {
+  static GatherGemmScatterAutoTuner<TransposeA,
+                                    TransposeB,
+                                    T,
+                                    ReturnType,
+                                    Args...>*
+  Instance(ReturnType (*func)(T, T, Args...)) {
     static std::once_flag gather_gemm_scatter_init_flag;
-    static std::unique_ptr<GatherGemmScatterAutoTuner<T, ReturnType, Args...>>
+    static std::unique_ptr<GatherGemmScatterAutoTuner<TransposeA,
+                                                      TransposeB,
+                                                      T,
+                                                      ReturnType,
+                                                      Args...>>
         instance;
     std::call_once(gather_gemm_scatter_init_flag, [&] {
       auto obj = MakeCallback<T>(func);
-      instance.reset(new GatherGemmScatterAutoTuner<T, ReturnType, Args...>);
+      instance.reset(new GatherGemmScatterAutoTuner<TransposeA,
+                                                    TransposeB,
+                                                    T,
+                                                    ReturnType,
+                                                    Args...>);
       instance->AddCallBack(func);
     });
     return instance.get();
@@ -201,7 +217,8 @@ class GatherGemmScatterAutoTuner
            Args... args) {
     this->is_init_ = true;
     this->CheckKernelSize();
-    auto& cache = AutoTuneCache::Instance().GetGatherGemmScatter<T>();
+    auto& cache = AutoTuneCache::Instance()
+                      .GetGatherGemmScatter<T, TransposeA, TransposeB>();
     if (cache.Find(key)) {
       auto best_idx = cache.Get(key);
@@ -250,10 +267,22 @@ class GatherGemmScatterAutoTuner
     return best_idx;
   }
 };
-template <typename T, typename ReturnType, typename... Args>
-static GatherGemmScatterAutoTuner<T, ReturnType, Args...>*
+template <bool TransposeA,
+          bool TransposeB,
+          typename T,
+          typename ReturnType,
+          typename... Args>
+static GatherGemmScatterAutoTuner<TransposeA,
+                                  TransposeB,
+                                  T,
+                                  ReturnType,
+                                  Args...>*
 MakeGatherGemmScatterTuner(ReturnType (*func)(T, T, Args...)) {
-  return GatherGemmScatterAutoTuner<T, ReturnType, Args...>::Instance(func);
+  return GatherGemmScatterAutoTuner<TransposeA,
+                                    TransposeB,
+                                    T,
+                                    ReturnType,
+                                    Args...>::Instance(func);
 }
 // Define the auto_tuner inital object.
...
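Note: the tuner is now templated on the transpose flags, so each (TransposeA, TransposeB) combination gets its own singleton instance and its own autotune cache. A minimal standalone sketch of that pattern (illustrative only, not the Paddle API):

#include <cstdio>
#include <unordered_map>

// Illustrative only: a singleton templated on the transpose flags, so each
// <TransposeA, TransposeB> combination owns a distinct instance and cache.
template <bool TransposeA, bool TransposeB>
class TinyTuner {
 public:
  static TinyTuner* Instance() {
    static TinyTuner instance;  // one object per <TransposeA, TransposeB>
    return &instance;
  }
  // Maps a problem-size key to the index of the best kernel found so far.
  std::unordered_map<size_t, int> cache;
};

int main() {
  // Distinct instantiations -> distinct singletons and caches.
  auto* nn = TinyTuner<false, false>::Instance();
  auto* tn = TinyTuner<true, false>::Instance();
  nn->cache[42] = 1;
  std::printf("nn has key 42: %d, tn has key 42: %d\n",
              static_cast<int>(nn->cache.count(42)),
              static_cast<int>(tn->cache.count(42)));
}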
@@ -47,13 +47,15 @@ enum class AlgorithmType {
   kMatmul = 5,
   kGatherGemmScatterFP16NN = 6,
   kGatherGemmScatterFP32NN = 7,
+  kGatherGemmScatterFP32TN = 8,
+  kGatherGemmScatterFP32NT = 9,
 #if !defined(PADDLE_WITH_CUDNN_FRONTEND)
-  kAlgorithmCount = 8
+  kAlgorithmCount = 10
 #else
-  kConvForwardV8 = 8,
-  kConvBackwardDataV8 = 9,
-  kConvBackwardFilterV8 = 10,
-  kAlgorithmCount = 11
+  kConvForwardV8 = 10,
+  kConvBackwardDataV8 = 11,
+  kConvBackwardFilterV8 = 12,
+  kAlgorithmCount = 13
 #endif
 };
@@ -73,6 +75,17 @@ using CudnnV8AlgorithmsTypeMap =
     std::unordered_map<int64_t, CudnnFrontendPlanCache>;
 #endif
+#define DEFINE_GET_GATHER_GEMM_SCATTER(                    \
+    dtype, transpose_a, transpose_b, algo_type)            \
+  template <typename T, bool TransposeA, bool TransposeB>  \
+  typename std::enable_if<std::is_same<T, dtype>::value && \
+                              TransposeA == transpose_a && \
+                              TransposeB == transpose_b,   \
+                          AlgorithmsCacheMap&>::type       \
+  GetGatherGemmScatter() {                                 \
+    return Get(algo_type);                                 \
+  }
 class AutoTuneCache {
  public:
  static AutoTuneCache& Instance() {
@@ -89,20 +102,22 @@ class AutoTuneCache {
   ConvAlgorithmsCacheMap& GetConv(const AlgorithmType& algo_type) {
     return conv_auto_tune_map_[static_cast<int64_t>(algo_type)];
   }
-
-  template <typename T>
-  typename std::enable_if<std::is_same<T, float>::value,
-                          AlgorithmsCacheMap&>::type
-  GetGatherGemmScatter() {
-    return Get(AlgorithmType::kGatherGemmScatterFP32NN);
-  }
-
-  template <typename T>
-  typename std::enable_if<std::is_same<T, phi::dtype::float16>::value,
-                          AlgorithmsCacheMap&>::type
-  GetGatherGemmScatter() {
-    return Get(AlgorithmType::kGatherGemmScatterFP16NN);
-  }
+  DEFINE_GET_GATHER_GEMM_SCATTER(phi::dtype::float16,
+                                 false,
+                                 false,
+                                 AlgorithmType::kGatherGemmScatterFP16NN);
+  DEFINE_GET_GATHER_GEMM_SCATTER(float,
+                                 false,
+                                 false,
+                                 AlgorithmType::kGatherGemmScatterFP32NN);
+  DEFINE_GET_GATHER_GEMM_SCATTER(float,
+                                 true,
+                                 false,
+                                 AlgorithmType::kGatherGemmScatterFP32TN);
+  DEFINE_GET_GATHER_GEMM_SCATTER(float,
+                                 false,
+                                 true,
+                                 AlgorithmType::kGatherGemmScatterFP32NT);
 #ifdef PADDLE_WITH_CUDNN_FRONTEND
   CudnnFrontendPlanCache& GetConvV8(const AlgorithmType& algo_type) {
...
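For reference, each expansion of DEFINE_GET_GATHER_GEMM_SCATTER resolves through enable_if so that exactly one GetGatherGemmScatter overload is viable per <T, TransposeA, TransposeB> combination. A self-contained sketch with stand-in types (the enum, AlgorithmsCacheMap and Get are simplified placeholders, not the real Paddle types):

#include <iostream>
#include <type_traits>

// Stand-ins for the real Paddle types, just to show how the SFINAE dispatch
// in DEFINE_GET_GATHER_GEMM_SCATTER selects one cache per (dtype, layout).
enum class AlgorithmType { kGatherGemmScatterFP32NN, kGatherGemmScatterFP32TN };
using AlgorithmsCacheMap = int;  // placeholder for the real cache type

struct TinyCache {
  int maps[2] = {0, 0};
  AlgorithmsCacheMap& Get(AlgorithmType t) { return maps[static_cast<int>(t)]; }

  // Roughly what one macro expansion produces
  // (float, TransposeA = true, TransposeB = false -> the FP32TN cache):
  template <typename T, bool TransposeA, bool TransposeB>
  typename std::enable_if<std::is_same<T, float>::value && TransposeA &&
                              !TransposeB,
                          AlgorithmsCacheMap&>::type
  GetGatherGemmScatter() {
    return Get(AlgorithmType::kGatherGemmScatterFP32TN);
  }

  // A second expansion (float, NN) shows why the overloads do not collide:
  template <typename T, bool TransposeA, bool TransposeB>
  typename std::enable_if<std::is_same<T, float>::value && !TransposeA &&
                              !TransposeB,
                          AlgorithmsCacheMap&>::type
  GetGatherGemmScatter() {
    return Get(AlgorithmType::kGatherGemmScatterFP32NN);
  }
};

int main() {
  TinyCache cache;
  cache.GetGatherGemmScatter<float, true, false>() = 7;
  std::cout << cache.maps[1] << "\n";  // prints 7: the TN slot was selected
}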
@@ -24,9 +24,13 @@ limitations under the License. */
 #include "paddle/phi/kernels/funcs/blas/blas.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
 #include "paddle/phi/kernels/sparse/gpu/conv.cu.h"
+#ifdef PADDLE_WITH_CUTLASS
+#include "paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.h"
+#endif
 namespace phi {
 namespace sparse {
+extern size_t workspace_size;
 // rulebook[3, rulebook_len]:
 //[
@@ -130,34 +134,52 @@ void Conv3dCooGradGPUKernel(const GPUContext& dev_ctx,
   phi::backends::gpu::GpuMemsetAsync(
       out_index_ptr, 0, sizeof(int) * x.nnz() * 2, dev_ctx.stream());
-  GroupIndexsV2<<<config.block_per_grid,
-                  config.thread_per_block,
-                  0,
-                  dev_ctx.stream()>>>(rulebook_len,
-                                      x.nnz(),
-                                      kernel_size,
-                                      offsets[kernel_size / 2],
-                                      rulebook_ptr,
-                                      out_index_ptr,
-                                      unique_value_ptr);
-  GatherV2<T, IntT>(dev_ctx,
-                    x.values().data<T>(),
-                    out_index_ptr,
-                    unique_value_ptr,
-                    x.nnz(),
-                    kernel_size,
-                    in_channels,
-                    2,
-                    in_features_ptr);
-  Gather<T, IntT>(dev_ctx,
-                  out_grad.values().data<T>(),
-                  rulebook_ptr + rulebook_len,
-                  rulebook_len,
-                  out_channels,
-                  out_grad_features_ptr);
+#ifdef PADDLE_WITH_CUTLASS
+  bool cutlass = true;
+  if (dev_ctx.GetComputeCapability() < 80) cutlass = false;
+
+  if (in_channels % 4 != 0 || out_channels % 4 != 0) cutlass = false;
+
+  if (std::is_same<T, phi::dtype::float16>::value ||
+      std::is_same<T, double>::value)
+    cutlass = false;
+
+  if (!std::is_same<IntT, int32_t>::value) cutlass = false;
+
+  if (!cutlass) {
+#endif
+    GroupIndexsV2<<<config.block_per_grid,
+                    config.thread_per_block,
+                    0,
+                    dev_ctx.stream()>>>(rulebook_len,
+                                        x.nnz(),
+                                        kernel_size,
+                                        offsets[kernel_size / 2],
+                                        rulebook_ptr,
+                                        out_index_ptr,
+                                        unique_value_ptr);
+    GatherV2<T, IntT>(dev_ctx,
+                      x.values().data<T>(),
+                      out_index_ptr,
+                      unique_value_ptr,
+                      x.nnz(),
+                      kernel_size,
+                      in_channels,
+                      2,
+                      in_features_ptr);
+    Gather<T, IntT>(dev_ctx,
+                    out_grad.values().data<T>(),
+                    rulebook_ptr + rulebook_len,
+                    rulebook_len,
+                    out_channels,
+                    out_grad_features_ptr);
+#ifdef PADDLE_WITH_CUTLASS
+  }
+#endif
   const T* kernel_ptr = kernel.data<T>();
   for (int i = 0; i < kernel_size; i++) {
     if (counter_ptr[i] <= 0 || (subm && i == half_kernel_size)) {
@@ -173,43 +195,98 @@ void Conv3dCooGradGPUKernel(const GPUContext& dev_ctx,
     T* tmp_d_x_ptr = d_x_features_ptr + offsets[i] * in_channels;
     T* tmp_d_kernel_ptr = d_kernel_ptr + i * in_channels * out_channels;
-    // call gemm: d_kernel = transpose(x) * out_grad
-    // (in_channels, n) * (n, out_channels)
-    blas.GEMM(CblasTrans,
-              CblasNoTrans,
-              K,
-              N,
-              M,
-              static_cast<T>(1),
-              tmp_in_ptr,
-              tmp_out_grad_ptr,
-              static_cast<T>(0),
-              tmp_d_kernel_ptr);
-    // call gemm: d_x = out_grad * transpose(kernel)
-    // (n, out_channels) * (out_channels, in_channels)
-    blas.GEMM(CblasNoTrans,
-              CblasTrans,
-              M,
-              K,
-              N,
-              static_cast<T>(1),
-              tmp_out_grad_ptr,
-              tmp_kernel_ptr,
-              static_cast<T>(0),
-              tmp_d_x_ptr);
+#ifdef PADDLE_WITH_CUTLASS
+    if (cutlass) {
+      const IntT* gather_x_indices = rulebook_ptr + offsets[i];
+      const IntT* scatter_x_indices = rulebook_ptr + offsets[i];
+      const IntT* gather_out_indices = rulebook_ptr + rulebook_len + offsets[i];
+      const size_t key = autotune::GenKey(M / features_num_range, N, K);
+      // call gemm: d_kernel = transpose(x) * out_grad
+      // (in_channels, n) * (n, out_channels)
+      static cutlass::device_memory::allocation<uint8_t> workspace(
+          workspace_size);
+      GatherGemmScatterDriver<T, IntT, true, false>(
+          dev_ctx,
+          key,
+          x.values().data<T>(),
+          out_grad.values().data<T>(),
+          tmp_d_kernel_ptr,
+          tmp_d_kernel_ptr,
+          in_channels,
+          out_channels,
+          counter_ptr[i],
+          gather_x_indices,
+          gather_out_indices,
+          static_cast<const IntT*>(nullptr),
+          static_cast<const T>(1.0),
+          static_cast<const T>(0.0),
+          &workspace);
+      // call gemm: d_x = out_grad * transpose(kernel)
+      // (n, out_channels) * (out_channels, in_channels)
+      GatherGemmScatterDriver<T, IntT, false, true>(
+          dev_ctx,
+          key,
+          out_grad.values().data<T>(),
+          tmp_kernel_ptr,
+          x_grad_values_ptr,
+          x_grad_values_ptr,
+          counter_ptr[i],
+          in_channels,
+          out_channels,
+          gather_out_indices,
+          static_cast<const IntT*>(nullptr),
+          scatter_x_indices,
+          static_cast<const T>(1.0),
+          static_cast<const T>(1.0),
+          nullptr);
+    } else {
+#endif
+      // call gemm: d_kernel = transpose(x) * out_grad
+      // (in_channels, n) * (n, out_channels)
+      blas.GEMM(CblasTrans,
+                CblasNoTrans,
+                K,
+                N,
+                M,
+                static_cast<T>(1),
+                tmp_in_ptr,
+                tmp_out_grad_ptr,
+                static_cast<T>(0),
+                tmp_d_kernel_ptr);
+      // call gemm: d_x = out_grad * transpose(kernel)
+      // (n, out_channels) * (out_channels, in_channels)
+      blas.GEMM(CblasNoTrans,
+                CblasTrans,
+                M,
+                K,
+                N,
+                static_cast<T>(1),
+                tmp_out_grad_ptr,
+                tmp_kernel_ptr,
+                static_cast<T>(0),
+                tmp_d_x_ptr);
+#ifdef PADDLE_WITH_CUTLASS
+    }
+#endif
   }
   // 4. scatter
-  phi::funcs::sparse::ScatterV2<T>(dev_ctx,
-                                   d_x_features_ptr,
-                                   out_index.data<int>(),
-                                   unique_value.data<int>(),
-                                   x_grad->nnz(),
-                                   kernel_size,
-                                   in_channels,
-                                   2,
-                                   x_grad_values_ptr);
+#ifdef PADDLE_WITH_CUTLASS
+  if (!cutlass) {
+#endif
+    phi::funcs::sparse::ScatterV2<T>(dev_ctx,
                                     d_x_features_ptr,
                                     out_index.data<int>(),
                                     unique_value.data<int>(),
                                     x_grad->nnz(),
                                     kernel_size,
                                     in_channels,
                                     2,
                                     x_grad_values_ptr);
+#ifdef PADDLE_WITH_CUTLASS
+  }
+#endif
 }
 template <typename T, typename Context>
...
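The backward kernel takes the fused CUTLASS path only when every gate above passes. A standalone restatement of those checks (an illustrative helper, not a function in the kernel; the kernel inlines them):

#include <cstdint>
#include <type_traits>

// The fused CUTLASS path is used only on SM80+, with in/out channels divisible
// by 4, neither fp16 nor double data (i.e. float, given the registered dtypes),
// and int32 rulebook indices.
template <typename T, typename IntT>
bool UseCutlassBackward(int compute_capability,
                        int in_channels,
                        int out_channels) {
  if (compute_capability < 80) return false;
  if (in_channels % 4 != 0 || out_channels % 4 != 0) return false;
  if (!std::is_same<T, float>::value) return false;       // fp16/double fall back
  if (!std::is_same<IntT, std::int32_t>::value) return false;  // rulebook must be int32
  return true;
}

int main() {
  // A float/int32 problem on SM80 with 64/128 channels qualifies:
  return UseCutlassBackward<float, std::int32_t>(80, 64, 128) ? 0 : 1;
}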
@@ -154,18 +154,23 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx,
       const IntT* gather_indices = rulebook_ptr + h_offsets_ptr[i];
       const IntT* scatter_indices =
           rulebook_ptr + rulebook_len + h_offsets_ptr[i];
-      GatherGemmScatterDriver(dev_ctx,
-                              x.non_zero_elements().data<T>(),
-                              tmp_kernel_ptr,
-                              out_values_ptr,
-                              out_values_ptr,
-                              M,
-                              N,
-                              K,
-                              gather_indices,
-                              scatter_indices,
-                              static_cast<T>(1.0),
-                              static_cast<T>(1.0));
+      const size_t key = autotune::GenKey(M / features_num_range, N, K);
+      GatherGemmScatterDriver<T, IntT, false, false>(
+          dev_ctx,
+          key,
+          x.non_zero_elements().data<T>(),
+          tmp_kernel_ptr,
+          out_values_ptr,
+          out_values_ptr,
+          M,
+          N,
+          K,
+          gather_indices,
+          static_cast<const IntT*>(nullptr),
+          scatter_indices,
+          static_cast<T>(1.0),
+          static_cast<T>(1.0),
+          nullptr);
     }
   } else {
 #endif
...
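Both the forward and backward kernels hand the gather, the GEMM and the scatter to one fused CUTLASS kernel instead of three separate launches. A plain-CPU sketch of the gather-GEMM-scatter semantics (shapes and indices are made up for illustration, and no CUTLASS types are used):

#include <cstdio>
#include <vector>

// Rows of A are gathered by a_indices, a small dense GEMM runs on the gathered
// rows, and result rows are scattered (accumulated) into D by c_d_indices.
int main() {
  const int n = 2, in_c = 2, out_c = 2;          // gathered rows, channels
  std::vector<float> A = {1, 2, 3, 4, 5, 6};     // 3 x in_c source rows
  std::vector<float> B = {1, 0, 0, 1};           // in_c x out_c (identity)
  std::vector<float> D(3 * out_c, 0.0f);         // 3 x out_c destination rows
  int a_indices[n] = {2, 0};                     // which rows of A to gather
  int c_d_indices[n] = {1, 2};                   // where to scatter results

  for (int i = 0; i < n; ++i)                    // fused gather -> GEMM -> scatter
    for (int j = 0; j < out_c; ++j) {
      float acc = 0.0f;
      for (int k = 0; k < in_c; ++k)
        acc += A[a_indices[i] * in_c + k] * B[k * out_c + j];
      D[c_d_indices[i] * out_c + j] += acc;      // beta = 1 accumulation
    }
  std::printf("D row 1 = (%.0f, %.0f)\n", D[2], D[3]);  // row 1 gets A's row 2
}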
@@ -16,28 +16,41 @@
 #ifdef PADDLE_WITH_CUTLASS
 #include "cutlass/arch/mma.h"
+#include "cutlass/device_kernel.h"
 #include "cutlass/epilogue/thread/linear_combination.h"
 #include "cutlass/gemm/device/gemm_universal.h"
 #include "cutlass/gemm/gemm.h"
+#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
 #include "cutlass/half.h"
+#include "cutlass/reduction/device/reduce_split_k.h"
+#include "cutlass/reduction/thread/reduction_operators.h"
+#include "cutlass/tensor_ref.h"
 #include "cutlass/util/device_memory.h"
 #include "examples/common/helper.h"
 #include "paddle/phi/backends/gpu/gpu_context.h"
 namespace phi {
 namespace sparse {
+size_t constexpr max_splitk_slices = 256;
+size_t constexpr max_in_channels = 256;
+size_t constexpr max_out_channels = 256;
+static size_t workspace_size =
+    sizeof(float) * max_splitk_slices * max_in_channels * max_out_channels;
+
 #define TYPEDEF_KERNEL_POINTER(kernel, dtype)        \
   typedef void (*kernel)(dtype const alpha,          \
                          dtype const beta,           \
                          const GPUContext& dev_ctx,  \
                          const dtype* const a,       \
                          const dtype* const b,       \
                          const dtype* const c,       \
                          dtype* const d,             \
                          const int m,                \
                          const int n,                \
                          const int k,                \
                          const int32_t* a_indices,   \
-                         const int32_t* c_d_indices);
+                         const int32_t* b_indices,   \
+                         const int32_t* c_d_indices, \
+                         void* const workspace_ptr);
 #define GATHER_GEMM_SCATTER_CHECK(status)            \
   {                                                  \
     cutlass::Status error = status;                  \
@@ -45,51 +58,126 @@ namespace sparse {
       throw std::runtime_error(cutlassGetStatusString(error)); \
     }                                                          \
   }
-#define DEFINE_LAUNCH_KERNEL(dtype, cutlass_type)                          \
-  template <typename Gemm>                                                 \
-  void launchKernel(dtype const alpha,                                     \
-                    dtype const beta,                                      \
-                    const GPUContext& dev_ctx,                             \
-                    const dtype* const a,                                  \
-                    const dtype* const b,                                  \
-                    const dtype* const c,                                  \
-                    dtype* const d,                                        \
-                    const int m,                                           \
-                    const int n,                                           \
-                    const int k,                                           \
-                    const int32_t* a_indices,                              \
-                    const int32_t* c_d_indices) {                          \
-    cutlass::gemm::GemmCoord problem_size_real({m, n, k});                 \
-    int split_k_slices = 1;                                                \
-    typename Gemm::Arguments arguments{                                    \
-        cutlass::gemm::GemmUniversalMode::kGemm,                           \
-        problem_size_real,                                                 \
-        split_k_slices,                                                    \
-        {static_cast<const cutlass_type>(static_cast<const float>(alpha)), \
-         static_cast<const cutlass_type>(static_cast<const float>(beta))}, \
-        reinterpret_cast<const cutlass_type* const>(a),                    \
-        reinterpret_cast<const cutlass_type* const>(b),                    \
-        reinterpret_cast<const cutlass_type* const>(c),                    \
-        reinterpret_cast<cutlass_type* const>(d),                          \
-        cutlass::layout::RowMajor().capacity(problem_size_real.mk()),      \
-        cutlass::layout::RowMajor().capacity(problem_size_real.kn()),      \
-        cutlass::layout::RowMajor().capacity(problem_size_real.mn()),      \
-        cutlass::layout::RowMajor().capacity(problem_size_real.mn()),      \
-        problem_size_real.k(),                                             \
-        problem_size_real.n(),                                             \
-        problem_size_real.n(),                                             \
-        problem_size_real.n(),                                             \
-        a_indices,                                                         \
-        nullptr,                                                           \
-        c_d_indices};                                                      \
-    size_t workspace_size = Gemm::get_workspace_size(arguments);           \
-    cutlass::device_memory::allocation<uint8_t> workspace(workspace_size); \
-    Gemm gemm_op;                                                          \
-    cutlass::Status status = gemm_op.can_implement(arguments);             \
-    GATHER_GEMM_SCATTER_CHECK(status);                                     \
-    status = gemm_op.initialize(arguments, workspace.get());               \
-    GATHER_GEMM_SCATTER_CHECK(status);                                     \
-    gemm_op(dev_ctx.stream());                                             \
+#define DEFINE_LAUNCH_KERNEL(dtype, cutlass_type)                            \
+  template <typename Config>                                                 \
+  void launchKernel(dtype const alpha,                                       \
+                    dtype const beta,                                        \
+                    const GPUContext& dev_ctx,                               \
+                    const dtype* const a,                                    \
+                    const dtype* const b,                                    \
+                    const dtype* const c,                                    \
+                    dtype* const d,                                          \
+                    const int m,                                             \
+                    const int n,                                             \
+                    const int k,                                             \
+                    const int32_t* a_indices,                                \
+                    const int32_t* b_indices,                                \
+                    const int32_t* c_d_indices,                              \
+                    void* const workspace_ptr) {                             \
+    cutlass::gemm::GemmCoord problem_size_real({m, n, k});                   \
+    using Gemm = typename Config::Gemm;                                      \
+    int split_k_slices = std::max(std::min(64, k / 128), 1);                 \
+    typename Gemm::Arguments arguments{                                      \
+        Config::Mode,                                                        \
+        problem_size_real,                                                   \
+        split_k_slices,                                                      \
+        {static_cast<const cutlass_type>(static_cast<const float>(alpha)),   \
+         static_cast<const cutlass_type>(static_cast<const float>(beta))},   \
+        reinterpret_cast<const cutlass_type* const>(a),                      \
+        reinterpret_cast<const cutlass_type* const>(b),                      \
+        reinterpret_cast<const cutlass_type* const>(c),                      \
+        reinterpret_cast<cutlass_type* const>(d),                            \
+        m * k,                                                               \
+        k * n,                                                               \
+        m * n,                                                               \
+        m * n,                                                               \
+        std::is_same<typename Gemm::Base::LayoutA,                           \
+                     cutlass::layout::RowMajor>::value                       \
+            ? problem_size_real.k()                                          \
+            : problem_size_real.m(),                                         \
+        std::is_same<typename Gemm::Base::LayoutB,                           \
+                     cutlass::layout::RowMajor>::value                       \
+            ? problem_size_real.n()                                          \
+            : problem_size_real.k(),                                         \
+        std::is_same<typename Gemm::Base::LayoutC,                           \
+                     cutlass::layout::RowMajor>::value                       \
+            ? problem_size_real.n()                                          \
+            : problem_size_real.m(),                                         \
+        std::is_same<typename Gemm::Base::LayoutC,                           \
+                     cutlass::layout::RowMajor>::value                       \
+            ? problem_size_real.n()                                          \
+            : problem_size_real.m(),                                         \
+        a_indices,                                                           \
+        b_indices,                                                           \
+        c_d_indices};                                                        \
+    cutlass::device_memory::allocation<uint8_t>* const real_workspace_ptr =  \
+        static_cast<cutlass::device_memory::allocation<uint8_t>* const>(     \
+            workspace_ptr);                                                  \
+    if (Config::Mode ==                                                      \
+        cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel) {             \
+      size_t current_workspace_size = Gemm::get_workspace_size(arguments);   \
+      if (current_workspace_size > workspace_size) {                         \
+        workspace_size = current_workspace_size;                             \
+        real_workspace_ptr->reset(workspace_size);                           \
+      }                                                                      \
+                                                                             \
+      arguments.ptr_D = real_workspace_ptr->get();                           \
+    }                                                                        \
+    Gemm gemm_op;                                                            \
+    cutlass::Status status = gemm_op.can_implement(arguments);               \
+    GATHER_GEMM_SCATTER_CHECK(status);                                       \
+    if (Config::Mode ==                                                      \
+        cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel) {             \
+      status = gemm_op.initialize(arguments, real_workspace_ptr->get());     \
+    } else {                                                                 \
+      cutlass::device_memory::allocation<uint8_t> empty_workspace(0);        \
+      status = gemm_op.initialize(arguments, empty_workspace.get());         \
+    }                                                                        \
+    GATHER_GEMM_SCATTER_CHECK(status);                                       \
+    gemm_op(dev_ctx.stream());                                               \
+    if (Config::Mode ==                                                      \
+        cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel) {             \
+      using ReductionOp = cutlass::reduction::thread::ReduceAdd<             \
+          typename Gemm::ElementAccumulator,                                 \
+          typename Gemm::EpilogueOutputOp::ElementAccumulator,               \
+          Gemm::EpilogueOutputOp::kCount>;                                   \
+                                                                             \
+      using ReductionKernel = cutlass::reduction::kernel::ReduceSplitK<      \
+          cutlass::MatrixShape<4, 32 * Gemm::EpilogueOutputOp::kCount>,      \
+          typename Gemm::EpilogueOutputOp,                                   \
+          ReductionOp>;                                                      \
+      using ReductionDevice =                                                \
+          typename cutlass::reduction::device::ReduceSplitK<ReductionKernel>;\
+      ReductionDevice reduction_op;                                          \
+      int splitk_gemm_stride = n;                                            \
+      cutlass::layout::RowMajor splitk_gemm_layout(splitk_gemm_stride);      \
+      void* workspace_gemm_ptr = real_workspace_ptr->get();                  \
+      cutlass::TensorRef<typename Gemm::ElementAccumulator,                  \
+                         cutlass::layout::RowMajor>                          \
+          ref_workspace(reinterpret_cast<typename Gemm::ElementAccumulator*>(\
+                            workspace_gemm_ptr),                             \
+                        splitk_gemm_layout);                                 \
+      cutlass::TensorRef<typename Gemm::Base::ElementC,                      \
+                         typename Gemm::Base::LayoutC>                       \
+          ref_c(reinterpret_cast<typename Gemm::Base::ElementC* const>(d),   \
+                splitk_gemm_layout);                                         \
+      cutlass::TensorRef<typename Gemm::Base::ElementC,                      \
+                         typename Gemm::Base::LayoutC>                       \
+          ref_d(reinterpret_cast<typename Gemm::Base::ElementC* const>(d),   \
+                splitk_gemm_layout);                                         \
+      typename ReductionDevice::Arguments reduction_args(                    \
+          problem_size_real.mn(),                                            \
+          split_k_slices,                                                    \
+          static_cast<size_t>(problem_size_real.m() * problem_size_real.n()),\
+          ref_workspace,                                                     \
+          ref_d,                                                             \
+          ref_c,                                                             \
+          {static_cast<const cutlass_type>(static_cast<const float>(alpha)), \
+           static_cast<const cutlass_type>(static_cast<const float>(beta))});\
+      status = reduction_op.initialize(reduction_args);                      \
+      GATHER_GEMM_SCATTER_CHECK(status);                                     \
+      reduction_op(dev_ctx.stream());                                        \
+    }                                                                        \
   }
 TYPEDEF_KERNEL_POINTER(fp16_gather_gemm_scatter, phi::dtype::float16)
...
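The new launchKernel picks split_k_slices = max(min(64, k / 128), 1) and, in split-K-parallel mode, writes per-slice partial products into the workspace before a ReduceSplitK pass sums them into D. A plain-CPU sketch of that flow, with no CUTLASS types, just to show the arithmetic:

#include <algorithm>
#include <cstdio>
#include <vector>

// Split-K on the CPU: the K dimension is partitioned into slices, each slice
// produces a partial C in its own workspace strip, and a final reduction sums
// the strips. This mirrors kGemmSplitKParallel + ReduceSplitK conceptually.
int main() {
  const int M = 2, N = 2, K = 512;
  const int split_k_slices = std::max(std::min(64, K / 128), 1);  // = 4
  std::vector<float> A(M * K, 1.0f), B(K * N, 1.0f);
  std::vector<float> workspace(split_k_slices * M * N, 0.0f);

  const int k_per_slice = K / split_k_slices;
  for (int s = 0; s < split_k_slices; ++s)            // "GEMM" per slice
    for (int i = 0; i < M; ++i)
      for (int j = 0; j < N; ++j)
        for (int kk = s * k_per_slice; kk < (s + 1) * k_per_slice; ++kk)
          workspace[s * M * N + i * N + j] += A[i * K + kk] * B[kk * N + j];

  std::vector<float> D(M * N, 0.0f);                  // "ReduceSplitK"
  for (int s = 0; s < split_k_slices; ++s)
    for (int e = 0; e < M * N; ++e) D[e] += workspace[s * M * N + e];

  std::printf("D[0] = %.0f (expected %d)\n", D[0], K);  // 512
}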
@@ -97,7 +97,7 @@ def CreateGatherGemmScatterOperator(
     return operations
-def GenerateSM80_TensorOp_16816(manifest, cuda_version):
+def GenerateSM80_TensorOp_16816(manifest, cuda_version, debug=False):
     if not CudaToolkitVersionSatisfies(cuda_version, 11, 0):
         return
@@ -191,6 +191,12 @@ def GenerateSM80_TensorOp_16816(manifest, cuda_version):
                 [64, 64, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc
             ),
         ]
+        if debug:
+            tile_descriptions = [
+                TileDescription(
+                    [256, 128, 32], 3, [4, 2, 1], math_inst, min_cc, max_cc
+                ),
+            ]
         data_type = [
             math_inst.element_a,
@@ -218,13 +224,15 @@ def GenerateSM80_TensorOp_16816(manifest, cuda_version):
     )
-def GenerateSM80_TensorOp_1688(manifest, cuda_version):
+def GenerateSM80_TensorOp_1688(manifest, cuda_version, debug=False):
     if not CudaToolkitVersionSatisfies(cuda_version, 11, 0):
         return
     layouts = [
         (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.RowMajor),
+        (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor),
+        (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.RowMajor),
     ]
     math_instructions = [
@@ -302,6 +310,13 @@ def GenerateSM80_TensorOp_1688(manifest, cuda_version):
             ),
         ]
+        if debug:
+            tile_descriptions = [
+                TileDescription(
+                    [256, 128, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc
+                ),
+            ]
         data_type = [
             math_inst.element_a,
             math_inst.element_b,
@@ -325,13 +340,15 @@ def GenerateSM80_TensorOp_1688(manifest, cuda_version):
     )
-def GenerateSM80_TensorOp_1688_fast_math(manifest, cuda_version):
+def GenerateSM80_TensorOp_1688_fast_math(manifest, cuda_version, debug=False):
     if not CudaToolkitVersionSatisfies(cuda_version, 11, 0):
         return
     layouts = [
         (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.RowMajor),
+        (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor),
+        (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.RowMajor),
     ]
     math_instructions = [
@@ -409,6 +426,13 @@ def GenerateSM80_TensorOp_1688_fast_math(manifest, cuda_version):
             ),
         ]
+        if debug:
+            tile_descriptions = [
+                TileDescription(
+                    [256, 128, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc
+                ),
+            ]
     data_type = [DataType.f32, DataType.f32, DataType.f32, DataType.f32]
     CreateGatherGemmScatterOperator(
@@ -416,13 +440,17 @@ def GenerateSM80_TensorOp_1688_fast_math(manifest, cuda_version):
     )
-def GenerateSM80_TensorOp_1688_fast_fp32_math(manifest, cuda_version):
+def GenerateSM80_TensorOp_1688_fast_fp32_math(
+    manifest, cuda_version, debug=False
+):
     if not CudaToolkitVersionSatisfies(cuda_version, 11, 0):
         return
     layouts = [
         (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.RowMajor),
+        (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor),
+        (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.RowMajor),
     ]
     math_instructions = [
@@ -482,6 +510,13 @@ def GenerateSM80_TensorOp_1688_fast_fp32_math(manifest, cuda_version):
             ),
         ]
+        if debug:
+            tile_descriptions = [
+                TileDescription(
+                    [128, 128, 16], 4, [4, 2, 1], math_inst, min_cc, max_cc
+                ),
+            ]
     data_type = [DataType.f32, DataType.f32, DataType.f32, DataType.f32]
     CreateGatherGemmScatterOperator(
@@ -489,11 +524,11 @@ def GenerateSM80_TensorOp_1688_fast_fp32_math(manifest, cuda_version):
     )
-def GenerateSM80(manifest, cuda_version):
-    GenerateSM80_TensorOp_16816(manifest, cuda_version)
-    GenerateSM80_TensorOp_1688(manifest, cuda_version)
-    GenerateSM80_TensorOp_1688_fast_math(manifest, cuda_version)
-    GenerateSM80_TensorOp_1688_fast_fp32_math(manifest, cuda_version)
+def GenerateSM80(manifest, cuda_version, debug=False):
+    GenerateSM80_TensorOp_16816(manifest, cuda_version, debug)
+    GenerateSM80_TensorOp_1688(manifest, cuda_version, debug)
+    GenerateSM80_TensorOp_1688_fast_math(manifest, cuda_version, debug)
+    GenerateSM80_TensorOp_1688_fast_fp32_math(manifest, cuda_version, debug)
 class KernelCfg:
...
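The extra RowMajor/ColumnMajor layout pairs exist because the sparse conv needs three GEMM variants (NN, TN, NT). A small sketch of that mapping, with the driver instantiations taken from the kernels in this commit (the table itself is illustrative):

#include <cstdio>

// How the three gather-GEMM-scatter problems map onto the layout variants the
// generator now emits, and onto GatherGemmScatterDriver's transpose flags.
struct GemmVariant {
  const char* purpose;
  bool transpose_a;  // column-major ('t') A  -> "TN"
  bool transpose_b;  // column-major ('t') B  -> "NT"
};

int main() {
  const GemmVariant variants[] = {
      {"forward:  out      = gather(x)      * kernel        (NN)", false, false},
      {"backward: d_kernel = gather(x)^T    * gather(d_out) (TN)", true, false},
      {"backward: d_x      = gather(d_out)  * kernel^T      (NT)", false, true},
  };
  for (const auto& v : variants)
    std::printf("%s -> GatherGemmScatterDriver<T, IntT, %d, %d>\n",
                v.purpose, v.transpose_a, v.transpose_b);
}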
@@ -41,12 +41,12 @@ namespace sparse {
 } // namespace phi
 #endif
 """
-        self.fp16_kernels_list = (
-            "static std::vector<fp16_gather_gemm_scatter> fp16_kernels = {\n"
-        )
-        self.fp32_kernels_list = (
-            "static std::vector<fp32_gather_gemm_scatter> fp32_kernels = {\n"
-        )
+        self.kernels_lists = {
+            "hnn": "static std::vector<fp16_gather_gemm_scatter> fp16_nn_kernels = {",
+            "snn": "static std::vector<fp32_gather_gemm_scatter> fp32_nn_kernels = {",
+            "snt": "static std::vector<fp32_gather_gemm_scatter> fp32_nt_kernels = {",
+            "stn": "static std::vector<fp32_gather_gemm_scatter> fp32_tn_kernels = {",
+        }
     def __enter__(self):
         self.operation_path = os.path.join(
@@ -78,19 +78,25 @@ namespace sparse {
             self.source_files.append(configuration_emitter.configuration_path)
             self.configurations.append(configuration_name)
-            if 'h' == operations[0].short_math_name():
-                self.fp16_kernels_list += (
-                    """
-launchKernel<"""
-                    + configuration_name
-                    + "::Gemm>,"
-                )
-            if 's' == operations[0].short_math_name():
-                self.fp32_kernels_list += (
-                    """
-launchKernel<"""
-                    + configuration_name
-                    + "::Gemm>,"
-                )
+            if operations[0].layout_name() == 'tn':
+                self.kernels_lists[
+                    operations[0].short_math_name() + operations[0].layout_name()
+                ] += (
+                    """
+launchKernel<"""
+                    + configuration_name
+                    + "<cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel"
+                    + ">>,"
+                )
+            else:
+                self.kernels_lists[
+                    operations[0].short_math_name() + operations[0].layout_name()
+                ] += (
+                    """
+launchKernel<"""
+                    + configuration_name
+                    + "<>>,"
+                )
         self.top_level_file.write(
@@ -117,11 +123,11 @@ launchKernel<"""
             )
         )
-        self.fp16_kernels_list += "\n};\n"
-        self.fp32_kernels_list += "\n};\n"
+        for k, v in self.kernels_lists.items():
+            self.kernels_lists[k] += "\n};\n"
         self.top_level_file.write(self.namespace_template)
-        self.top_level_file.write(self.fp16_kernels_list)
-        self.top_level_file.write(self.fp32_kernels_list)
+        for k, v in self.kernels_lists.items():
+            self.top_level_file.write(v)
         self.top_level_file.write(self.epilogue_template)
         self.top_level_file.close()
...
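Schematically, the emitter now closes the generated header with four kernel lists, and 'tn' configurations are instantiated in split-K-parallel mode. A self-contained mock of that shape (all type names and the configuration name are stand-ins, not generated output):

#include <vector>

namespace sketch {
// Stand-ins for the real function-pointer typedef, mode enum and config.
enum class Mode { kGemm, kGemmSplitKParallel };
using fp32_gather_gemm_scatter = void (*)(float alpha, float beta);

template <Mode M = Mode::kGemm>
struct example_tn_config { static constexpr Mode mode = M; };

template <typename Config>
void launchKernel(float, float) { (void)Config::mode; }

// The TN list is instantiated in split-K-parallel mode, mirroring the
// "kGemmSplitKParallel" string the emitter appends for 'tn' layouts:
static std::vector<fp32_gather_gemm_scatter> fp32_tn_kernels = {
    launchKernel<example_tn_config<Mode::kGemmSplitKParallel>>,
};
}  // namespace sketch

int main() { sketch::fp32_tn_kernels[0](1.0f, 0.0f); }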
@@ -52,6 +52,8 @@ class EmitGatherGemmScatterInstance(EmitGemmInstance):
         """
         self.gemm_template = """
 // Gemm operator ${operation_name}
+template<cutlass::gemm::GemmUniversalMode Mode_ =
+    cutlass::gemm::GemmUniversalMode::kGemm>
 struct ${operation_name} {
   using Gemm =
     cutlass::gemm::device::GemmUniversal<
@@ -75,10 +77,11 @@ struct ${operation_name} {
       ${math_operation},
       ${transform_a},
       ${transform_b},
-      true,  // gather a
-      false, // gather b
-      true   // scatter d
+      ${gather_a},  // gather a
+      ${gather_b},  // gather b
+      ${scatter_d}  // scatter d
     >;
+  static const cutlass::gemm::GemmUniversalMode Mode = Mode_;
 };
 """
@@ -192,6 +195,9 @@ struct ${operation_name} {
             'math_operation': MathOperationTag[
                 operation.tile_description.math_instruction.math_operation
             ],
+            'gather_a': 'true',
+            'gather_b': str(operation.layout_name() == 'tn').lower(),
+            'scatter_d': str(operation.layout_name() != 'tn').lower(),
         }
         return SubstituteTemplate(gemm_template, values)
@@ -295,8 +301,8 @@ class GatherGemmScatterOperation(GemmOperation):
             B,
             C,
             element_epilogue,
-            epilogue_functor=EpilogueFunctor.LinearCombination,
-            swizzling_functor=SwizzlingFunctor.Identity8,
+            epilogue_functor,
+            swizzling_functor,
         )
         self.ShortLayoutTypeNames = {
             LayoutType.ColumnMajor: 't',
...
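Each emitted operation struct now carries a Mode_ template parameter that defaults to plain GEMM; the top-level emitter overrides it with kGemmSplitKParallel for TN kernels. A standalone mock of that shape (GemmUniversal and the mode enum are replaced by stand-ins, and static constexpr is used here just to keep the mock self-contained):

#include <cstdio>

enum class GemmUniversalMode { kGemm, kGemmSplitKParallel };

struct FakeGemm {};  // stand-in for cutlass::gemm::device::GemmUniversal<...>

template <GemmUniversalMode Mode_ = GemmUniversalMode::kGemm>
struct cutlass_example_operation {
  using Gemm = FakeGemm;
  static constexpr GemmUniversalMode Mode = Mode_;
};

int main() {
  // Default instantiation (NN/NT kernels) vs. the TN instantiation:
  std::printf("%d %d\n",
              static_cast<int>(cutlass_example_operation<>::Mode),
              static_cast<int>(cutlass_example_operation<
                  GemmUniversalMode::kGemmSplitKParallel>::Mode));
}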
@@ -24,29 +24,50 @@ namespace phi {
 namespace sparse {
 // To reduce tuning time, map shape (m,n,k) to (m/features_num_range,n,k) so
-// that shapes in this range share the same key.
+// that shapes within this range share the same key.
 constexpr int features_num_range = 10000;
-#define DEFINE_GATHER_GEMM_SCATTER_DRIVER(dtype, kernels)           \
-  template <typename T, typename IntT>                              \
-  typename std::enable_if<std::is_same<T, dtype>::value &&          \
-                              std::is_same<IntT, int32_t>::value,   \
-                          void>::type                               \
-  GatherGemmScatterDriver(const phi::GPUContext& ctx,               \
-                          const T* const a,                         \
-                          const T* const b,                         \
-                          const T* const c,                         \
-                          T* const d,                               \
-                          const int& m,                             \
-                          const int& n,                             \
-                          const int& k,                             \
-                          const IntT* a_indices,                    \
-                          const IntT* c_d_indices,                  \
-                          T alpha,                                  \
-                          T beta) {                                 \
-    auto* tuner = autotune::MakeGatherGemmScatterTuner(kernels[0]); \
+template <typename T, typename IntT, bool TransposeA, bool TransposeB>
+void GatherGemmScatterDriver(
+    const phi::GPUContext& ctx,
+    const size_t key,
+    const T* const a,
+    const T* const b,
+    const T* const c,
+    T* const d,
+    const int& m,
+    const int& n,
+    const int& k,
+    const IntT* a_indices,
+    const IntT* b_indices,
+    const IntT* c_d_indices,
+    T alpha,
+    T beta,
+    cutlass::device_memory::allocation<uint8_t>* const workspace_ptr) {}
+
+#define EXPLICIT_SPECIALIZE_GATHER_GEMM_SCATTER_DRIVER(                      \
+    T, kernels, transpose_a, transpose_b)                                    \
+  template <>                                                                \
+  inline void GatherGemmScatterDriver<T, int32_t, transpose_a, transpose_b>( \
+      const phi::GPUContext& ctx,                                            \
+      const size_t key,                                                      \
+      const T* const a,                                                      \
+      const T* const b,                                                      \
+      const T* const c,                                                      \
+      T* const d,                                                            \
+      const int& m,                                                          \
+      const int& n,                                                          \
+      const int& k,                                                          \
+      const int32_t* a_indices,                                              \
+      const int32_t* b_indices,                                              \
+      const int32_t* c_d_indices,                                            \
+      T alpha,                                                               \
+      T beta,                                                                \
+      cutlass::device_memory::allocation<uint8_t>* const workspace_ptr) {    \
+    auto* tuner =                                                            \
+        autotune::MakeGatherGemmScatterTuner<transpose_a, transpose_b>(      \
+            kernels[0]);                                                     \
     for (auto i = 1; i < kernels.size(); i++) tuner->AddCallBack(kernels[i]); \
-    size_t key = autotune::GenKey(m / features_num_range, n, k);    \
     tuner->Run(ctx,                                                          \
               key,                                                           \
               alpha,                                                         \
@@ -60,28 +81,27 @@ constexpr int features_num_range = 10000;
               n,                                                             \
               k,                                                             \
               a_indices,                                                     \
-              c_d_indices);                                                  \
+              b_indices,                                                     \
+              c_d_indices,                                                   \
+              workspace_ptr);                                                \
   }
-template <typename T, typename IntT>
-typename std::enable_if<std::is_same<T, double>::value ||
-                            !std::is_same<IntT, int32_t>::value,
-                        void>::type
-GatherGemmScatterDriver(const phi::GPUContext& ctx,
-                        const T* const a,
-                        const T* const b,
-                        const T* const c,
-                        T* const d,
-                        const int& m,
-                        const int& n,
-                        const int& k,
-                        const IntT* a_indices,
-                        const IntT* c_d_indices,
-                        T alpha,
-                        T beta) {}
-
-DEFINE_GATHER_GEMM_SCATTER_DRIVER(phi::dtype::float16, fp16_kernels)
-DEFINE_GATHER_GEMM_SCATTER_DRIVER(float, fp32_kernels)
+EXPLICIT_SPECIALIZE_GATHER_GEMM_SCATTER_DRIVER(phi::dtype::float16,
+                                               fp16_nn_kernels,
+                                               false,
+                                               false)
+EXPLICIT_SPECIALIZE_GATHER_GEMM_SCATTER_DRIVER(float,
+                                               fp32_nn_kernels,
+                                               false,
+                                               false)
+EXPLICIT_SPECIALIZE_GATHER_GEMM_SCATTER_DRIVER(float,
+                                               fp32_nt_kernels,
+                                               false,
+                                               true)
+EXPLICIT_SPECIALIZE_GATHER_GEMM_SCATTER_DRIVER(float,
+                                               fp32_tn_kernels,
+                                               true,
+                                               false)
 } // namespace sparse
 } // namespace phi
...
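Callers now build the cache key as autotune::GenKey(M / features_num_range, N, K), so problem sizes whose M falls in the same 10000-wide bucket reuse one autotuned kernel choice. A sketch of the bucketing (GenKeySketch is a stand-in hash; the real autotune::GenKey differs):

#include <cstddef>
#include <cstdio>
#include <functional>

// Illustrative stand-in for autotune::GenKey: hash-combine the arguments.
static size_t GenKeySketch(size_t a, size_t b, size_t c) {
  size_t h = std::hash<size_t>{}(a);
  h ^= std::hash<size_t>{}(b) + 0x9e3779b97f4a7c15ULL + (h << 6) + (h >> 2);
  h ^= std::hash<size_t>{}(c) + 0x9e3779b97f4a7c15ULL + (h << 6) + (h >> 2);
  return h;
}

int main() {
  constexpr int features_num_range = 10000;
  // Two problems whose M values land in the same bucket share a key, so they
  // reuse the same autotuned kernel choice instead of re-running the search.
  int N = 64, K = 32;
  size_t k1 = GenKeySketch(123456 / features_num_range, N, K);
  size_t k2 = GenKeySketch(129999 / features_num_range, N, K);
  std::printf("same bucket: %d\n", k1 == k2);  // prints 1
}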