未验证 提交 0b98d1aa 编写于 作者: U umiswing 提交者: GitHub

[cutlass] Sparse conv3d backward fusion (#52361)

上级 1acb845a
......@@ -177,18 +177,34 @@ class MatmulAutoTuner
}
};
template <typename T, typename ReturnType, typename... Args>
template <bool TransposeA,
bool TransposeB,
typename T,
typename ReturnType,
typename... Args>
class GatherGemmScatterAutoTuner
: public AutoTuneBase<T, KernelCallback<T, ReturnType, T, T, Args...>> {
public:
static GatherGemmScatterAutoTuner<T, ReturnType, Args...>* Instance(
ReturnType (*func)(T, T, Args...)) {
static GatherGemmScatterAutoTuner<TransposeA,
TransposeB,
T,
ReturnType,
Args...>*
Instance(ReturnType (*func)(T, T, Args...)) {
static std::once_flag gather_gemm_scatter_init_flag;
static std::unique_ptr<GatherGemmScatterAutoTuner<T, ReturnType, Args...>>
static std::unique_ptr<GatherGemmScatterAutoTuner<TransposeA,
TransposeB,
T,
ReturnType,
Args...>>
instance;
std::call_once(gather_gemm_scatter_init_flag, [&] {
auto obj = MakeCallback<T>(func);
instance.reset(new GatherGemmScatterAutoTuner<T, ReturnType, Args...>);
instance.reset(new GatherGemmScatterAutoTuner<TransposeA,
TransposeB,
T,
ReturnType,
Args...>);
instance->AddCallBack(func);
});
return instance.get();
......@@ -201,7 +217,8 @@ class GatherGemmScatterAutoTuner
Args... args) {
this->is_init_ = true;
this->CheckKernelSize();
auto& cache = AutoTuneCache::Instance().GetGatherGemmScatter<T>();
auto& cache = AutoTuneCache::Instance()
.GetGatherGemmScatter<T, TransposeA, TransposeB>();
if (cache.Find(key)) {
auto best_idx = cache.Get(key);
......@@ -250,10 +267,22 @@ class GatherGemmScatterAutoTuner
return best_idx;
}
};
template <typename T, typename ReturnType, typename... Args>
static GatherGemmScatterAutoTuner<T, ReturnType, Args...>*
template <bool TransposeA,
bool TransposeB,
typename T,
typename ReturnType,
typename... Args>
static GatherGemmScatterAutoTuner<TransposeA,
TransposeB,
T,
ReturnType,
Args...>*
MakeGatherGemmScatterTuner(ReturnType (*func)(T, T, Args...)) {
return GatherGemmScatterAutoTuner<T, ReturnType, Args...>::Instance(func);
return GatherGemmScatterAutoTuner<TransposeA,
TransposeB,
T,
ReturnType,
Args...>::Instance(func);
}
// Define the auto_tuner inital object.
......
......@@ -47,13 +47,15 @@ enum class AlgorithmType {
kMatmul = 5,
kGatherGemmScatterFP16NN = 6,
kGatherGemmScatterFP32NN = 7,
kGatherGemmScatterFP32TN = 8,
kGatherGemmScatterFP32NT = 9,
#if !defined(PADDLE_WITH_CUDNN_FRONTEND)
kAlgorithmCount = 8
kAlgorithmCount = 10
#else
kConvForwardV8 = 8,
kConvBackwardDataV8 = 9,
kConvBackwardFilterV8 = 10,
kAlgorithmCount = 11
kConvForwardV8 = 10,
kConvBackwardDataV8 = 11,
kConvBackwardFilterV8 = 12,
kAlgorithmCount = 13
#endif
};
......@@ -73,6 +75,17 @@ using CudnnV8AlgorithmsTypeMap =
std::unordered_map<int64_t, CudnnFrontendPlanCache>;
#endif
#define DEFINE_GET_GATHER_GEMM_SCATTER( \
dtype, transpose_a, transpose_b, algo_type) \
template <typename T, bool TransposeA, bool TransposeB> \
typename std::enable_if<std::is_same<T, dtype>::value && \
TransposeA == transpose_a && \
TransposeB == transpose_b, \
AlgorithmsCacheMap&>::type \
GetGatherGemmScatter() { \
return Get(algo_type); \
}
class AutoTuneCache {
public:
static AutoTuneCache& Instance() {
......@@ -89,20 +102,22 @@ class AutoTuneCache {
ConvAlgorithmsCacheMap& GetConv(const AlgorithmType& algo_type) {
return conv_auto_tune_map_[static_cast<int64_t>(algo_type)];
}
template <typename T>
typename std::enable_if<std::is_same<T, float>::value,
AlgorithmsCacheMap&>::type
GetGatherGemmScatter() {
return Get(AlgorithmType::kGatherGemmScatterFP32NN);
}
template <typename T>
typename std::enable_if<std::is_same<T, phi::dtype::float16>::value,
AlgorithmsCacheMap&>::type
GetGatherGemmScatter() {
return Get(AlgorithmType::kGatherGemmScatterFP16NN);
}
DEFINE_GET_GATHER_GEMM_SCATTER(phi::dtype::float16,
false,
false,
AlgorithmType::kGatherGemmScatterFP16NN);
DEFINE_GET_GATHER_GEMM_SCATTER(float,
false,
false,
AlgorithmType::kGatherGemmScatterFP32NN);
DEFINE_GET_GATHER_GEMM_SCATTER(float,
true,
false,
AlgorithmType::kGatherGemmScatterFP32TN);
DEFINE_GET_GATHER_GEMM_SCATTER(float,
false,
true,
AlgorithmType::kGatherGemmScatterFP32NT);
#ifdef PADDLE_WITH_CUDNN_FRONTEND
CudnnFrontendPlanCache& GetConvV8(const AlgorithmType& algo_type) {
......
......@@ -24,9 +24,13 @@ limitations under the License. */
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/sparse/gpu/conv.cu.h"
#ifdef PADDLE_WITH_CUTLASS
#include "paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.h"
#endif
namespace phi {
namespace sparse {
extern size_t workspace_size;
// rulebook[3, rulebook_len]:
//[
......@@ -130,6 +134,21 @@ void Conv3dCooGradGPUKernel(const GPUContext& dev_ctx,
phi::backends::gpu::GpuMemsetAsync(
out_index_ptr, 0, sizeof(int) * x.nnz() * 2, dev_ctx.stream());
#ifdef PADDLE_WITH_CUTLASS
bool cutlass = true;
if (dev_ctx.GetComputeCapability() < 80) cutlass = false;
if (in_channels % 4 != 0 || out_channels % 4 != 0) cutlass = false;
if (std::is_same<T, phi::dtype::float16>::value ||
std::is_same<T, double>::value)
cutlass = false;
if (!std::is_same<IntT, int32_t>::value) cutlass = false;
if (!cutlass) {
#endif
GroupIndexsV2<<<config.block_per_grid,
config.thread_per_block,
0,
......@@ -158,6 +177,9 @@ void Conv3dCooGradGPUKernel(const GPUContext& dev_ctx,
out_channels,
out_grad_features_ptr);
#ifdef PADDLE_WITH_CUTLASS
}
#endif
const T* kernel_ptr = kernel.data<T>();
for (int i = 0; i < kernel_size; i++) {
if (counter_ptr[i] <= 0 || (subm && i == half_kernel_size)) {
......@@ -173,6 +195,52 @@ void Conv3dCooGradGPUKernel(const GPUContext& dev_ctx,
T* tmp_d_x_ptr = d_x_features_ptr + offsets[i] * in_channels;
T* tmp_d_kernel_ptr = d_kernel_ptr + i * in_channels * out_channels;
#ifdef PADDLE_WITH_CUTLASS
if (cutlass) {
const IntT* gather_x_indices = rulebook_ptr + offsets[i];
const IntT* scatter_x_indices = rulebook_ptr + offsets[i];
const IntT* gather_out_indices = rulebook_ptr + rulebook_len + offsets[i];
const size_t key = autotune::GenKey(M / features_num_range, N, K);
// call gemm: d_kernel = transpose(x) * out_grad
// (in_channels, n) * (n, out_channels)
static cutlass::device_memory::allocation<uint8_t> workspace(
workspace_size);
GatherGemmScatterDriver<T, IntT, true, false>(
dev_ctx,
key,
x.values().data<T>(),
out_grad.values().data<T>(),
tmp_d_kernel_ptr,
tmp_d_kernel_ptr,
in_channels,
out_channels,
counter_ptr[i],
gather_x_indices,
gather_out_indices,
static_cast<const IntT*>(nullptr),
static_cast<const T>(1.0),
static_cast<const T>(0.0),
&workspace);
// call gemm: d_x = out_grad * transpose(kernel)
// (n, out_channels) * (out_channels, in_channels)
GatherGemmScatterDriver<T, IntT, false, true>(
dev_ctx,
key,
out_grad.values().data<T>(),
tmp_kernel_ptr,
x_grad_values_ptr,
x_grad_values_ptr,
counter_ptr[i],
in_channels,
out_channels,
gather_out_indices,
static_cast<const IntT*>(nullptr),
scatter_x_indices,
static_cast<const T>(1.0),
static_cast<const T>(1.0),
nullptr);
} else {
#endif
// call gemm: d_kernel = transpose(x) * out_grad
// (in_channels, n) * (n, out_channels)
blas.GEMM(CblasTrans,
......@@ -198,9 +266,15 @@ void Conv3dCooGradGPUKernel(const GPUContext& dev_ctx,
tmp_kernel_ptr,
static_cast<T>(0),
tmp_d_x_ptr);
#ifdef PADDLE_WITH_CUTLASS
}
#endif
}
// 4. scatter
#ifdef PADDLE_WITH_CUTLASS
if (!cutlass) {
#endif
phi::funcs::sparse::ScatterV2<T>(dev_ctx,
d_x_features_ptr,
out_index.data<int>(),
......@@ -210,6 +284,9 @@ void Conv3dCooGradGPUKernel(const GPUContext& dev_ctx,
in_channels,
2,
x_grad_values_ptr);
#ifdef PADDLE_WITH_CUTLASS
}
#endif
}
template <typename T, typename Context>
......
......@@ -154,7 +154,10 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx,
const IntT* gather_indices = rulebook_ptr + h_offsets_ptr[i];
const IntT* scatter_indices =
rulebook_ptr + rulebook_len + h_offsets_ptr[i];
GatherGemmScatterDriver(dev_ctx,
const size_t key = autotune::GenKey(M / features_num_range, N, K);
GatherGemmScatterDriver<T, IntT, false, false>(
dev_ctx,
key,
x.non_zero_elements().data<T>(),
tmp_kernel_ptr,
out_values_ptr,
......@@ -163,9 +166,11 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx,
N,
K,
gather_indices,
static_cast<const IntT*>(nullptr),
scatter_indices,
static_cast<T>(1.0),
static_cast<T>(1.0));
static_cast<T>(1.0),
nullptr);
}
} else {
#endif
......
......@@ -16,15 +16,26 @@
#ifdef PADDLE_WITH_CUTLASS
#include "cutlass/arch/mma.h"
#include "cutlass/device_kernel.h"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/device/gemm_universal.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
#include "cutlass/half.h"
#include "cutlass/reduction/device/reduce_split_k.h"
#include "cutlass/reduction/thread/reduction_operators.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/util/device_memory.h"
#include "examples/common/helper.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
namespace phi {
namespace sparse {
size_t constexpr max_splitk_slices = 256;
size_t constexpr max_in_channels = 256;
size_t constexpr max_out_channels = 256;
static size_t workspace_size =
sizeof(float) * max_splitk_slices * max_in_channels * max_out_channels;
#define TYPEDEF_KERNEL_POINTER(kernel, dtype) \
typedef void (*kernel)(dtype const alpha, \
dtype const beta, \
......@@ -37,7 +48,9 @@ namespace sparse {
const int n, \
const int k, \
const int32_t* a_indices, \
const int32_t* c_d_indices);
const int32_t* b_indices, \
const int32_t* c_d_indices, \
void* const workspace_ptr);
#define GATHER_GEMM_SCATTER_CHECK(status) \
{ \
cutlass::Status error = status; \
......@@ -46,7 +59,7 @@ namespace sparse {
} \
}
#define DEFINE_LAUNCH_KERNEL(dtype, cutlass_type) \
template <typename Gemm> \
template <typename Config> \
void launchKernel(dtype const alpha, \
dtype const beta, \
const GPUContext& dev_ctx, \
......@@ -58,11 +71,14 @@ namespace sparse {
const int n, \
const int k, \
const int32_t* a_indices, \
const int32_t* c_d_indices) { \
const int32_t* b_indices, \
const int32_t* c_d_indices, \
void* const workspace_ptr) { \
cutlass::gemm::GemmCoord problem_size_real({m, n, k}); \
int split_k_slices = 1; \
using Gemm = typename Config::Gemm; \
int split_k_slices = std::max(std::min(64, k / 128), 1); \
typename Gemm::Arguments arguments{ \
cutlass::gemm::GemmUniversalMode::kGemm, \
Config::Mode, \
problem_size_real, \
split_k_slices, \
{static_cast<const cutlass_type>(static_cast<const float>(alpha)), \
......@@ -71,25 +87,97 @@ namespace sparse {
reinterpret_cast<const cutlass_type* const>(b), \
reinterpret_cast<const cutlass_type* const>(c), \
reinterpret_cast<cutlass_type* const>(d), \
cutlass::layout::RowMajor().capacity(problem_size_real.mk()), \
cutlass::layout::RowMajor().capacity(problem_size_real.kn()), \
cutlass::layout::RowMajor().capacity(problem_size_real.mn()), \
cutlass::layout::RowMajor().capacity(problem_size_real.mn()), \
problem_size_real.k(), \
problem_size_real.n(), \
problem_size_real.n(), \
problem_size_real.n(), \
m * k, \
k * n, \
m * n, \
m * n, \
std::is_same<typename Gemm::Base::LayoutA, \
cutlass::layout::RowMajor>::value \
? problem_size_real.k() \
: problem_size_real.m(), \
std::is_same<typename Gemm::Base::LayoutB, \
cutlass::layout::RowMajor>::value \
? problem_size_real.n() \
: problem_size_real.k(), \
std::is_same<typename Gemm::Base::LayoutC, \
cutlass::layout::RowMajor>::value \
? problem_size_real.n() \
: problem_size_real.m(), \
std::is_same<typename Gemm::Base::LayoutC, \
cutlass::layout::RowMajor>::value \
? problem_size_real.n() \
: problem_size_real.m(), \
a_indices, \
nullptr, \
b_indices, \
c_d_indices}; \
size_t workspace_size = Gemm::get_workspace_size(arguments); \
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size); \
cutlass::device_memory::allocation<uint8_t>* const real_workspace_ptr = \
static_cast<cutlass::device_memory::allocation<uint8_t>* const>( \
workspace_ptr); \
if (Config::Mode == \
cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel) { \
size_t current_workspace_size = Gemm::get_workspace_size(arguments); \
if (current_workspace_size > workspace_size) { \
workspace_size = current_workspace_size; \
real_workspace_ptr->reset(workspace_size); \
} \
\
arguments.ptr_D = real_workspace_ptr->get(); \
} \
Gemm gemm_op; \
cutlass::Status status = gemm_op.can_implement(arguments); \
GATHER_GEMM_SCATTER_CHECK(status); \
status = gemm_op.initialize(arguments, workspace.get()); \
if (Config::Mode == \
cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel) { \
status = gemm_op.initialize(arguments, real_workspace_ptr->get()); \
} else { \
cutlass::device_memory::allocation<uint8_t> empty_workspace(0); \
status = gemm_op.initialize(arguments, empty_workspace.get()); \
} \
GATHER_GEMM_SCATTER_CHECK(status); \
gemm_op(dev_ctx.stream()); \
if (Config::Mode == \
cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel) { \
using ReductionOp = cutlass::reduction::thread::ReduceAdd< \
typename Gemm::ElementAccumulator, \
typename Gemm::EpilogueOutputOp::ElementAccumulator, \
Gemm::EpilogueOutputOp::kCount>; \
\
using ReductionKernel = cutlass::reduction::kernel::ReduceSplitK< \
cutlass::MatrixShape<4, 32 * Gemm::EpilogueOutputOp::kCount>, \
typename Gemm::EpilogueOutputOp, \
ReductionOp>; \
using ReductionDevice = \
typename cutlass::reduction::device::ReduceSplitK<ReductionKernel>; \
ReductionDevice reduction_op; \
int splitk_gemm_stride = n; \
cutlass::layout::RowMajor splitk_gemm_layout(splitk_gemm_stride); \
void* workspace_gemm_ptr = real_workspace_ptr->get(); \
cutlass::TensorRef<typename Gemm::ElementAccumulator, \
cutlass::layout::RowMajor> \
ref_workspace(reinterpret_cast<typename Gemm::ElementAccumulator*>( \
workspace_gemm_ptr), \
splitk_gemm_layout); \
cutlass::TensorRef<typename Gemm::Base::ElementC, \
typename Gemm::Base::LayoutC> \
ref_c(reinterpret_cast<typename Gemm::Base::ElementC* const>(d), \
splitk_gemm_layout); \
cutlass::TensorRef<typename Gemm::Base::ElementC, \
typename Gemm::Base::LayoutC> \
ref_d(reinterpret_cast<typename Gemm::Base::ElementC* const>(d), \
splitk_gemm_layout); \
typename ReductionDevice::Arguments reduction_args( \
problem_size_real.mn(), \
split_k_slices, \
static_cast<size_t>(problem_size_real.m() * problem_size_real.n()), \
ref_workspace, \
ref_d, \
ref_c, \
{static_cast<const cutlass_type>(static_cast<const float>(alpha)), \
static_cast<const cutlass_type>(static_cast<const float>(beta))}); \
status = reduction_op.initialize(reduction_args); \
GATHER_GEMM_SCATTER_CHECK(status); \
reduction_op(dev_ctx.stream()); \
} \
}
TYPEDEF_KERNEL_POINTER(fp16_gather_gemm_scatter, phi::dtype::float16)
......
......@@ -97,7 +97,7 @@ def CreateGatherGemmScatterOperator(
return operations
def GenerateSM80_TensorOp_16816(manifest, cuda_version):
def GenerateSM80_TensorOp_16816(manifest, cuda_version, debug=False):
if not CudaToolkitVersionSatisfies(cuda_version, 11, 0):
return
......@@ -191,6 +191,12 @@ def GenerateSM80_TensorOp_16816(manifest, cuda_version):
[64, 64, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc
),
]
if debug:
tile_descriptions = [
TileDescription(
[256, 128, 32], 3, [4, 2, 1], math_inst, min_cc, max_cc
),
]
data_type = [
math_inst.element_a,
......@@ -218,13 +224,15 @@ def GenerateSM80_TensorOp_16816(manifest, cuda_version):
)
def GenerateSM80_TensorOp_1688(manifest, cuda_version):
def GenerateSM80_TensorOp_1688(manifest, cuda_version, debug=False):
if not CudaToolkitVersionSatisfies(cuda_version, 11, 0):
return
layouts = [
(LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.RowMajor),
(LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor),
(LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.RowMajor),
]
math_instructions = [
......@@ -302,6 +310,13 @@ def GenerateSM80_TensorOp_1688(manifest, cuda_version):
),
]
if debug:
tile_descriptions = [
TileDescription(
[256, 128, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc
),
]
data_type = [
math_inst.element_a,
math_inst.element_b,
......@@ -325,13 +340,15 @@ def GenerateSM80_TensorOp_1688(manifest, cuda_version):
)
def GenerateSM80_TensorOp_1688_fast_math(manifest, cuda_version):
def GenerateSM80_TensorOp_1688_fast_math(manifest, cuda_version, debug=False):
if not CudaToolkitVersionSatisfies(cuda_version, 11, 0):
return
layouts = [
(LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.RowMajor),
(LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor),
(LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.RowMajor),
]
math_instructions = [
......@@ -409,6 +426,13 @@ def GenerateSM80_TensorOp_1688_fast_math(manifest, cuda_version):
),
]
if debug:
tile_descriptions = [
TileDescription(
[256, 128, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc
),
]
data_type = [DataType.f32, DataType.f32, DataType.f32, DataType.f32]
CreateGatherGemmScatterOperator(
......@@ -416,13 +440,17 @@ def GenerateSM80_TensorOp_1688_fast_math(manifest, cuda_version):
)
def GenerateSM80_TensorOp_1688_fast_fp32_math(manifest, cuda_version):
def GenerateSM80_TensorOp_1688_fast_fp32_math(
manifest, cuda_version, debug=False
):
if not CudaToolkitVersionSatisfies(cuda_version, 11, 0):
return
layouts = [
(LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.RowMajor),
(LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor),
(LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.RowMajor),
]
math_instructions = [
......@@ -482,6 +510,13 @@ def GenerateSM80_TensorOp_1688_fast_fp32_math(manifest, cuda_version):
),
]
if debug:
tile_descriptions = [
TileDescription(
[128, 128, 16], 4, [4, 2, 1], math_inst, min_cc, max_cc
),
]
data_type = [DataType.f32, DataType.f32, DataType.f32, DataType.f32]
CreateGatherGemmScatterOperator(
......@@ -489,11 +524,11 @@ def GenerateSM80_TensorOp_1688_fast_fp32_math(manifest, cuda_version):
)
def GenerateSM80(manifest, cuda_version):
GenerateSM80_TensorOp_16816(manifest, cuda_version)
GenerateSM80_TensorOp_1688(manifest, cuda_version)
GenerateSM80_TensorOp_1688_fast_math(manifest, cuda_version)
GenerateSM80_TensorOp_1688_fast_fp32_math(manifest, cuda_version)
def GenerateSM80(manifest, cuda_version, debug=False):
GenerateSM80_TensorOp_16816(manifest, cuda_version, debug)
GenerateSM80_TensorOp_1688(manifest, cuda_version, debug)
GenerateSM80_TensorOp_1688_fast_math(manifest, cuda_version, debug)
GenerateSM80_TensorOp_1688_fast_fp32_math(manifest, cuda_version, debug)
class KernelCfg:
......
......@@ -41,12 +41,12 @@ namespace sparse {
} // namespace phi
#endif
"""
self.fp16_kernels_list = (
"static std::vector<fp16_gather_gemm_scatter> fp16_kernels = {\n"
)
self.fp32_kernels_list = (
"static std::vector<fp32_gather_gemm_scatter> fp32_kernels = {\n"
)
self.kernels_lists = {
"hnn": "static std::vector<fp16_gather_gemm_scatter> fp16_nn_kernels = {",
"snn": "static std::vector<fp32_gather_gemm_scatter> fp32_nn_kernels = {",
"snt": "static std::vector<fp32_gather_gemm_scatter> fp32_nt_kernels = {",
"stn": "static std::vector<fp32_gather_gemm_scatter> fp32_tn_kernels = {",
}
def __enter__(self):
self.operation_path = os.path.join(
......@@ -78,19 +78,25 @@ namespace sparse {
self.source_files.append(configuration_emitter.configuration_path)
self.configurations.append(configuration_name)
if 'h' == operations[0].short_math_name():
self.fp16_kernels_list += (
if operations[0].layout_name() == 'tn':
self.kernels_lists[
operations[0].short_math_name() + operations[0].layout_name()
] += (
"""
launchKernel<"""
+ configuration_name
+ "::Gemm>,"
+ "<cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel"
+ ">>,"
)
if 's' == operations[0].short_math_name():
self.fp32_kernels_list += (
else:
self.kernels_lists[
operations[0].short_math_name() + operations[0].layout_name()
] += (
"""
launchKernel<"""
+ configuration_name
+ "::Gemm>,"
+ "<>>,"
)
self.top_level_file.write(
......@@ -117,11 +123,11 @@ launchKernel<"""
)
)
self.fp16_kernels_list += "\n};\n"
self.fp32_kernels_list += "\n};\n"
for k, v in self.kernels_lists.items():
self.kernels_lists[k] += "\n};\n"
self.top_level_file.write(self.namespace_template)
self.top_level_file.write(self.fp16_kernels_list)
self.top_level_file.write(self.fp32_kernels_list)
for k, v in self.kernels_lists.items():
self.top_level_file.write(v)
self.top_level_file.write(self.epilogue_template)
self.top_level_file.close()
......
......@@ -52,6 +52,8 @@ class EmitGatherGemmScatterInstance(EmitGemmInstance):
"""
self.gemm_template = """
// Gemm operator ${operation_name}
template<cutlass::gemm::GemmUniversalMode Mode_ =
cutlass::gemm::GemmUniversalMode::kGemm>
struct ${operation_name} {
using Gemm =
cutlass::gemm::device::GemmUniversal<
......@@ -75,10 +77,11 @@ struct ${operation_name} {
${math_operation},
${transform_a},
${transform_b},
true, // gather a
false, // gather b
true // scatter d
${gather_a}, // gather a
${gather_b}, // gather b
${scatter_d} // scatter d
>;
static const cutlass::gemm::GemmUniversalMode Mode = Mode_;
};
"""
......@@ -192,6 +195,9 @@ struct ${operation_name} {
'math_operation': MathOperationTag[
operation.tile_description.math_instruction.math_operation
],
'gather_a': 'true',
'gather_b': str(operation.layout_name() == 'tn').lower(),
'scatter_d': str(operation.layout_name() != 'tn').lower(),
}
return SubstituteTemplate(gemm_template, values)
......@@ -295,8 +301,8 @@ class GatherGemmScatterOperation(GemmOperation):
B,
C,
element_epilogue,
epilogue_functor=EpilogueFunctor.LinearCombination,
swizzling_functor=SwizzlingFunctor.Identity8,
epilogue_functor,
swizzling_functor,
)
self.ShortLayoutTypeNames = {
LayoutType.ColumnMajor: 't',
......
......@@ -24,15 +24,33 @@ namespace phi {
namespace sparse {
// To reduce tuning time, map shape (m,n,k) to (m/features_num_range,n,k) so
// that shapes in this range share the same key.
// that shapes within this range share the same key.
constexpr int features_num_range = 10000;
#define DEFINE_GATHER_GEMM_SCATTER_DRIVER(dtype, kernels) \
template <typename T, typename IntT> \
typename std::enable_if<std::is_same<T, dtype>::value && \
std::is_same<IntT, int32_t>::value, \
void>::type \
GatherGemmScatterDriver(const phi::GPUContext& ctx, \
template <typename T, typename IntT, bool TransposeA, bool TransposeB>
void GatherGemmScatterDriver(
const phi::GPUContext& ctx,
const size_t key,
const T* const a,
const T* const b,
const T* const c,
T* const d,
const int& m,
const int& n,
const int& k,
const IntT* a_indices,
const IntT* b_indices,
const IntT* c_d_indices,
T alpha,
T beta,
cutlass::device_memory::allocation<uint8_t>* const workspace_ptr) {}
#define EXPLICIT_SPECIALIZE_GATHER_GEMM_SCATTER_DRIVER( \
T, kernels, transpose_a, transpose_b) \
template <> \
inline void GatherGemmScatterDriver<T, int32_t, transpose_a, transpose_b>( \
const phi::GPUContext& ctx, \
const size_t key, \
const T* const a, \
const T* const b, \
const T* const c, \
......@@ -40,13 +58,16 @@ constexpr int features_num_range = 10000;
const int& m, \
const int& n, \
const int& k, \
const IntT* a_indices, \
const IntT* c_d_indices, \
const int32_t* a_indices, \
const int32_t* b_indices, \
const int32_t* c_d_indices, \
T alpha, \
T beta) { \
auto* tuner = autotune::MakeGatherGemmScatterTuner(kernels[0]); \
T beta, \
cutlass::device_memory::allocation<uint8_t>* const workspace_ptr) { \
auto* tuner = \
autotune::MakeGatherGemmScatterTuner<transpose_a, transpose_b>( \
kernels[0]); \
for (auto i = 1; i < kernels.size(); i++) tuner->AddCallBack(kernels[i]); \
size_t key = autotune::GenKey(m / features_num_range, n, k); \
tuner->Run(ctx, \
key, \
alpha, \
......@@ -60,28 +81,27 @@ constexpr int features_num_range = 10000;
n, \
k, \
a_indices, \
c_d_indices); \
b_indices, \
c_d_indices, \
workspace_ptr); \
}
template <typename T, typename IntT>
typename std::enable_if<std::is_same<T, double>::value ||
!std::is_same<IntT, int32_t>::value,
void>::type
GatherGemmScatterDriver(const phi::GPUContext& ctx,
const T* const a,
const T* const b,
const T* const c,
T* const d,
const int& m,
const int& n,
const int& k,
const IntT* a_indices,
const IntT* c_d_indices,
T alpha,
T beta) {}
DEFINE_GATHER_GEMM_SCATTER_DRIVER(phi::dtype::float16, fp16_kernels)
DEFINE_GATHER_GEMM_SCATTER_DRIVER(float, fp32_kernels)
EXPLICIT_SPECIALIZE_GATHER_GEMM_SCATTER_DRIVER(phi::dtype::float16,
fp16_nn_kernels,
false,
false)
EXPLICIT_SPECIALIZE_GATHER_GEMM_SCATTER_DRIVER(float,
fp32_nn_kernels,
false,
false)
EXPLICIT_SPECIALIZE_GATHER_GEMM_SCATTER_DRIVER(float,
fp32_nt_kernels,
false,
true)
EXPLICIT_SPECIALIZE_GATHER_GEMM_SCATTER_DRIVER(float,
fp32_tn_kernels,
true,
false)
} // namespace sparse
} // namespace phi
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册