未验证 提交 5158fa4f 编写于 作者: U umiswing 提交者: GitHub

summer-ospp 2022: 飞桨PaddlePaddle Sparse Conv开发和优化: gather-gemm-scatter fuse (#46679)

上级 60e0c506
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Fetches NVIDIA CUTLASS as a source-only external project and exposes it to
# the rest of the build through the `cutlass` interface target.
include(ExternalProject)

set(CUTLASS_PREFIX_DIR ${THIRD_PARTY_PATH}/cutlass)
set(CUTLASS_REPOSITORY https://github.com/NVIDIA/cutlass.git)
set(CUTLASS_TAG v2.9.1)

# Lets sources guard CUTLASS-only code with `#ifdef PADDLE_WITH_CUTLASS`.
add_definitions("-DPADDLE_WITH_CUTLASS")

# Checkout root, library headers and utility headers, in that order.
include_directories(
  "${THIRD_PARTY_PATH}/cutlass/src/extern_cutlass/"
  "${THIRD_PARTY_PATH}/cutlass/src/extern_cutlass/include/"
  "${THIRD_PARTY_PATH}/cutlass/src/extern_cutlass/tools/util/include/")

# Every step command is empty, so only the git checkout is performed.
ExternalProject_Add(
  extern_cutlass
  ${EXTERNAL_PROJECT_LOG_ARGS} ${SHALLOW_CLONE}
  PREFIX ${CUTLASS_PREFIX_DIR}
  GIT_REPOSITORY ${CUTLASS_REPOSITORY}
  GIT_TAG "${CUTLASS_TAG}"
  UPDATE_COMMAND ""
  CONFIGURE_COMMAND ""
  BUILD_COMMAND ""
  INSTALL_COMMAND ""
  TEST_COMMAND "")

# Targets that depend on `cutlass` wait for the checkout to finish.
add_library(cutlass INTERFACE)
add_dependencies(cutlass extern_cutlass)
...@@ -505,4 +505,14 @@ if(WITH_CUSPARSELT) ...@@ -505,4 +505,14 @@ if(WITH_CUSPARSELT)
list(APPEND third_party_deps extern_cusparselt) list(APPEND third_party_deps extern_cusparselt)
endif() endif()
# CUTLASS is only pulled in for GPU builds that are not ARM, Windows or Apple
# and whose CUDA compiler is at least 11.0.
if(WITH_GPU
   AND NOT WITH_ARM
   AND NOT WIN32
   AND NOT APPLE)
  # Compare as a version (e.g. "11.6.124"), and pass the variable by name so an
  # unset value cannot leave `if()` with a missing operand.
  if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 11.0)
    include(external/cutlass) # download cutlass (sources only, nothing is built)
    list(APPEND third_party_deps extern_cutlass)
  endif()
endif()
add_custom_target(third_party ALL DEPENDS ${third_party_deps}) add_custom_target(third_party ALL DEPENDS ${third_party_deps})
...@@ -22,6 +22,9 @@ limitations under the License. */ ...@@ -22,6 +22,9 @@ limitations under the License. */
#include "paddle/phi/kernels/funcs/scatter.cu.h" #include "paddle/phi/kernels/funcs/scatter.cu.h"
#include "paddle/phi/kernels/funcs/sparse/scatter.cu.h" #include "paddle/phi/kernels/funcs/sparse/scatter.cu.h"
#include "paddle/phi/kernels/sparse/gpu/conv.cu.h" #include "paddle/phi/kernels/sparse/gpu/conv.cu.h"
#ifdef PADDLE_WITH_CUTLASS
#include "paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.h"
#endif
#include "glog/logging.h" #include "glog/logging.h"
...@@ -120,29 +123,6 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx, ...@@ -120,29 +123,6 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx,
dev_ctx, x, key, tmp_rulebook, h_counter, out, rulebook, counter); dev_ctx, x, key, tmp_rulebook, h_counter, out, rulebook, counter);
} }
// 2. gather
phi::DenseTensor in_features =
phi::Empty<T>(dev_ctx, {rulebook_len, in_channels});
phi::DenseTensor out_features =
phi::Empty<T>(dev_ctx, {rulebook_len, out_channels});
T* in_features_ptr = in_features.data<T>();
T* out_features_ptr = out_features.data<T>();
phi::funcs::SetConstant<GPUContext, T> set_zero;
set_zero(dev_ctx, &out_features, static_cast<T>(0.0f));
Gather<T, IntT>(dev_ctx,
x.values().data<T>(),
rulebook_ptr,
rulebook_len,
in_channels,
in_features_ptr);
// 3. call gemm for every weight
auto blas = phi::funcs::GetBlas<GPUContext, T>(dev_ctx);
auto* out_values = out->mutable_values();
T* out_values_ptr = out_values->data<T>();
set_zero(dev_ctx, out_values, static_cast<T>(0.0f));
if (subm) { if (subm) {
auto config = auto config =
phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, rulebook_len, 1); phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, rulebook_len, 1);
...@@ -162,43 +142,152 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx, ...@@ -162,43 +142,152 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx,
out_index_ptr, out_index_ptr,
unique_value_ptr); unique_value_ptr);
} }
#ifdef PADDLE_WITH_CUTLASS
bool cutlass = true;
if (dev_ctx.GetComputeCapability() < 80) cutlass = false;
if (in_channels % 4 != 0 || out_channels % 4 != 0) {
if (std::is_same<T, phi::dtype::float16>::value) cutlass = false;
if (std::is_same<T, float>::value) cutlass = false;
}
if (!std::is_same<IntT, int32_t>::value) cutlass = false;
if (cutlass) {
auto* out_values = out->mutable_non_zero_elements();
T* out_values_ptr = out_values->data<T>();
phi::funcs::SetConstant<GPUContext, T> set_zero;
set_zero(dev_ctx, out_values, static_cast<T>(0.0f));
const T* kernel_ptr = kernel.data<T>();
for (int i = 0; i < kernel_size; i++) {
if (h_counter_ptr[i] <= 0) {
continue;
}
const T* kernel_ptr = kernel.data<T>(); const int M = h_counter_ptr[i];
for (int i = 0; i < kernel_size; i++) { const int K = in_channels;
if (h_counter_ptr[i] <= 0) { const int N = out_channels;
continue; const T* tmp_kernel_ptr = kernel_ptr + i * K * N;
const IntT* gather_indices = rulebook_ptr + h_offsets_ptr[i];
const IntT* scatter_indices =
rulebook_ptr + rulebook_len + h_offsets_ptr[i];
if constexpr (std::is_same<T, phi::dtype::float16>::value &&
std::is_same<IntT, int32_t>::value) {
fp16_gather_gemm_scatter gather_gemm_scatter =
getBestFp16Kernel(M, N, K);
gather_gemm_scatter(
dev_ctx,
reinterpret_cast<const cutlass::half_t*>(
x.non_zero_elements().data<T>()),
reinterpret_cast<const cutlass::half_t*>(tmp_kernel_ptr),
reinterpret_cast<cutlass::half_t*>(out_values_ptr),
reinterpret_cast<cutlass::half_t*>(out_values_ptr),
M,
N,
K,
static_cast<const int32_t*>(gather_indices),
static_cast<const int32_t*>(scatter_indices),
static_cast<cutlass::half_t>(1),
static_cast<cutlass::half_t>(1));
}
if constexpr (std::is_same<T, float>::value &&
std::is_same<IntT, int32_t>::value) {
fp32_gather_gemm_scatter gather_gemm_scatter =
getBestFp32Kernel(M, N, K);
gather_gemm_scatter(dev_ctx,
x.non_zero_elements().data<T>(),
tmp_kernel_ptr,
out_values_ptr,
out_values_ptr,
M,
N,
K,
gather_indices,
scatter_indices,
static_cast<T>(1),
static_cast<T>(1));
}
if constexpr (std::is_same<T, double>::value &&
std::is_same<IntT, int32_t>::value) {
fp64_gather_gemm_scatter gather_gemm_scatter =
getBestFp64Kernel(M, N, K);
gather_gemm_scatter(dev_ctx,
x.non_zero_elements().data<T>(),
tmp_kernel_ptr,
out_values_ptr,
out_values_ptr,
M,
N,
K,
gather_indices,
scatter_indices,
static_cast<T>(1),
static_cast<T>(1));
}
} }
} else {
#endif
// 2. gather
phi::DenseTensor in_features =
phi::Empty<T>(dev_ctx, {rulebook_len, in_channels});
phi::DenseTensor out_features =
phi::Empty<T>(dev_ctx, {rulebook_len, out_channels});
T* in_features_ptr = in_features.data<T>();
T* out_features_ptr = out_features.data<T>();
phi::funcs::SetConstant<GPUContext, T> set_zero;
set_zero(dev_ctx, &out_features, static_cast<T>(0.0f));
// call gemm: (n, in_channels) * (in_channels, out_channels) Gather<T, IntT>(dev_ctx,
const int M = h_counter_ptr[i]; x.values().data<T>(),
const int K = in_channels; rulebook_ptr,
const int N = out_channels; rulebook_len,
T* tmp_in_ptr = in_features_ptr + h_offsets_ptr[i] * in_channels; in_channels,
const T* tmp_kernel_ptr = kernel_ptr + i * K * N; in_features_ptr);
T* tmp_out_ptr = out_features_ptr + h_offsets_ptr[i] * out_channels;
// 3. call gemm for every weight
blas.GEMM(CblasNoTrans, auto blas = phi::funcs::GetBlas<GPUContext, T>(dev_ctx);
CblasNoTrans, auto* out_values = out->mutable_values();
M, T* out_values_ptr = out_values->data<T>();
N, set_zero(dev_ctx, out_values, static_cast<T>(0.0f));
K,
static_cast<T>(1),
tmp_in_ptr,
tmp_kernel_ptr,
static_cast<T>(0),
tmp_out_ptr);
}
// 4. scatter const T* kernel_ptr = kernel.data<T>();
phi::funcs::sparse::ScatterV2<T>(dev_ctx, for (int i = 0; i < kernel_size; i++) {
out_features_ptr, if (h_counter_ptr[i] <= 0) {
out_index.data<int>(), continue;
unique_value.data<int>(), }
out->nnz(),
kernel_size, // call gemm: (n, in_channels) * (in_channels, out_channels)
out_channels, const int M = h_counter_ptr[i];
1, const int K = in_channels;
out_values_ptr); const int N = out_channels;
T* tmp_in_ptr = in_features_ptr + h_offsets_ptr[i] * in_channels;
const T* tmp_kernel_ptr = kernel_ptr + i * K * N;
T* tmp_out_ptr = out_features_ptr + h_offsets_ptr[i] * out_channels;
blas.GEMM(CblasNoTrans,
CblasNoTrans,
M,
N,
K,
static_cast<T>(1),
tmp_in_ptr,
tmp_kernel_ptr,
static_cast<T>(0),
tmp_out_ptr);
}
// 4. scatter
phi::funcs::sparse::ScatterV2<T>(dev_ctx,
out_features_ptr,
out_index.data<int>(),
unique_value.data<int>(),
out->nnz(),
kernel_size,
out_channels,
1,
out_values_ptr);
#ifdef PADDLE_WITH_CUTLASS
}
#endif
} }
/** /**
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifdef PADDLE_WITH_CUTLASS
#include "paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.h"
namespace phi {
namespace sparse {
/**
 * Selects the fp16 gather-GEMM-scatter launcher for a GEMM of shape M x N x K.
 *
 * The sparse conv call site passes M = number of rulebook entries for one
 * kernel offset, N = out_channels and K = in_channels.
 *
 * @param M rows of the gathered A matrix / scattered D matrix.
 * @param N columns of B and D.
 * @param K columns of A / rows of B.
 * @return pointer to the launchKernel instantiation chosen for this shape;
 *         the align4 64x64 kernel is the fallback for unlisted shapes.
 */
fp16_gather_gemm_scatter getBestFp16Kernel(const int M,
                                           const int N,
                                           const int K) {
  if (K == 4 && N == 16) {
    return launchKernel<cutlass::half_t,
                        cutlass_tensorop_h1688gemm_64x64_32x2_nn_align4::Gemm>;
  }
  if (K == 16 && N == 16) {
    return launchKernel<cutlass::half_t,
                        cutlass_tensorop_h1688gemm_64x64_32x2_nn_align8::Gemm>;
  }
  if (K == 16 && N == 32) {
    return launchKernel<cutlass::half_t,
                        cutlass_tensorop_h1688gemm_64x64_32x2_nn_align8::Gemm>;
  }
  if (K == 32 && N == 32) {
    return launchKernel<cutlass::half_t,
                        cutlass_tensorop_h1688gemm_64x64_32x2_nn_align8::Gemm>;
  }
  if (K == 32 && N == 64) {
    return launchKernel<cutlass::half_t,
                        cutlass_tensorop_h1688gemm_64x64_32x2_nn_align8::Gemm>;
  }
  if (K == 64 && N == 64) {
    // Thresholds are tested from largest to smallest M. Each branch must
    // return: without `return` the expression merely names the function and
    // the selection silently falls through to the next, smaller-M choice.
    if (M > 100000)
      return launchKernel<
          cutlass::half_t,
          cutlass_tensorop_f16_s1688gemm_f16_64x128_32x2_nn_align8::Gemm>;
    if (M > 20000)
      return launchKernel<
          cutlass::half_t,
          cutlass_tensorop_f16_s1688gemm_f16_64x64_32x2_nn_align8::Gemm>;
    if (M > 15000)
      return launchKernel<
          cutlass::half_t,
          cutlass_tensorop_h1688gemm_128x64_32x2_nn_align8::Gemm>;
    return launchKernel<cutlass::half_t,
                        cutlass_tensorop_h1688gemm_64x64_32x2_nn_align8::Gemm>;
  }
  if (K == 128) {
    if (M >= 5000)
      return launchKernel<
          cutlass::half_t,
          cutlass_tensorop_h1688gemm_64x64_32x2_nn_align8::Gemm>;
    return launchKernel<cutlass::half_t,
                        cutlass_tensorop_h16816gemm_64x64_64x5_nn_align8::Gemm>;
  }
  if (N == 128) {
    return launchKernel<cutlass::half_t,
                        cutlass_tensorop_h1688gemm_64x64_32x2_nn_align8::Gemm>;
  }
  // Fallback for every other shape.
  return launchKernel<cutlass::half_t,
                      cutlass_tensorop_h1688gemm_64x64_32x2_nn_align4::Gemm>;
}
/**
 * Selects the fp32 gather-GEMM-scatter launcher for a GEMM of shape M x N x K.
 *
 * @param M rows of the gathered A matrix / scattered D matrix.
 * @param N columns of B and D.
 * @param K columns of A / rows of B.
 * @return pointer to the launchKernel instantiation chosen for this shape;
 *         the fast-f16 64x64 kernel is the fallback.
 */
fp32_gather_gemm_scatter getBestFp32Kernel(const int M,
                                           const int N,
                                           const int K) {
  // The two 64x64 kernels that most shapes choose between.
  const fp32_gather_gemm_scatter small_fast_f16 =
      launchKernel<float,
                   cutlass_tensorop_s1688f16gemm_64x64_16x10_nn_align4::Gemm>;
  const fp32_gather_gemm_scatter small_fast_f32 =
      launchKernel<float,
                   cutlass_tensorop_s1688gemm_64x64_16x3_nn_align4::Gemm>;

  // K == 128 takes precedence over N == 128 when both hold.
  if (K == 128) {
    if (M >= 100000) {
      return launchKernel<
          float,
          cutlass_tensorop_s1688f16gemm_128x128_16x3_nn_align4::Gemm>;
    }
    if (M >= 5000) {
      return launchKernel<
          float,
          cutlass_tensorop_s1688f16gemm_256x64_16x4_nn_align4::Gemm>;
    }
    return launchKernel<
        float,
        cutlass_tensorop_s1688tf32gemm_256x128_16x3_nn_align4::Gemm>;
  }
  if (N == 128) {
    if (M >= 100000) {
      return launchKernel<
          float,
          cutlass_tensorop_s1688tf32gemm_256x128_16x3_nn_align4::Gemm>;
    }
    if (M >= 5000) {
      return launchKernel<
          float,
          cutlass_tensorop_s1688f16gemm_128x128_16x3_nn_align4::Gemm>;
    }
    return launchKernel<
        float,
        cutlass_tensorop_s1688f16gemm_64x128_16x6_nn_align4::Gemm>;
  }

  // (K, N) in {(16, 32), (32, 32), (32, 64)} switch kernels at M = 10000.
  const bool switches_at_10k = (K == 16 && N == 32) ||
                               (K == 32 && N == 32) ||
                               (K == 32 && N == 64);
  if (switches_at_10k) {
    return M >= 10000 ? small_fast_f32 : small_fast_f16;
  }
  // (64, 64) switches at M = 15000.
  if (K == 64 && N == 64) {
    return M >= 15000 ? small_fast_f32 : small_fast_f16;
  }
  // (4, 16), (16, 16) and every unlisted shape share the same kernel.
  return small_fast_f16;
}
/**
 * Selects the fp64 gather-GEMM-scatter launcher for a GEMM of shape M x N x K.
 *
 * @param M rows of the gathered A matrix / scattered D matrix.
 * @param N columns of B and D.
 * @param K columns of A / rows of B.
 * @return pointer to one of the two double-precision launchKernel
 *         instantiations; the 32x16 tile is the fallback.
 */
fp64_gather_gemm_scatter getBestFp64Kernel(const int M,
                                           const int N,
                                           const int K) {
  const fp64_gather_gemm_scatter tile_16x32 =
      launchKernel<double,
                   cutlass_tensorop_d884gemm_16x32_16x5_nn_align1::Gemm>;
  const fp64_gather_gemm_scatter tile_32x16 =
      launchKernel<double,
                   cutlass_tensorop_d884gemm_32x16_16x5_nn_align1::Gemm>;

  if (N == 16) {
    if (K == 4) {
      return tile_16x32;
    }
    if (K == 16) {
      // Only this shape depends on M.
      return M >= 10000 ? tile_32x16 : tile_16x32;
    }
  }
  if (K == 32 && N == 32) {
    return tile_16x32;
  }
  // (16, 32), (32, 64), (64, 64) and every unlisted shape.
  return tile_32x16;
}
} // namespace sparse
} // namespace phi
#endif
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#ifdef PADDLE_WITH_CUTLASS
#include "cutlass/arch/mma.h"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/device/gemm_grouped.h"
#include "cutlass/gemm/device/gemm_universal.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/util/device_memory.h"
#include "examples/common/helper.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
namespace phi {
namespace sparse {
typedef void (*fp16_gather_gemm_scatter)(const GPUContext& dev_ctx,
const cutlass::half_t* const a,
const cutlass::half_t* const b,
const cutlass::half_t* const c,
cutlass::half_t* const d,
const int m,
const int n,
const int k,
const int32_t* a_indices,
const int32_t* c_d_indices,
cutlass::half_t const alpha,
cutlass::half_t const beta);
// Pointer to a launcher that runs one fused gather-GEMM-scatter on float
// data; launchKernel<float, Gemm> has this signature.
using fp32_gather_gemm_scatter = void (*)(const GPUContext& dev_ctx,
                                          const float* const a,
                                          const float* const b,
                                          const float* const c,
                                          float* const d,
                                          const int m,
                                          const int n,
                                          const int k,
                                          const int32_t* a_indices,
                                          const int32_t* c_d_indices,
                                          float const alpha,
                                          float const beta);
// Pointer to a launcher that runs one fused gather-GEMM-scatter on double
// data; launchKernel<double, Gemm> has this signature.
using fp64_gather_gemm_scatter = void (*)(const GPUContext& dev_ctx,
                                          const double* const a,
                                          const double* const b,
                                          const double* const c,
                                          double* const d,
                                          const int m,
                                          const int n,
                                          const int k,
                                          const int32_t* a_indices,
                                          const int32_t* c_d_indices,
                                          double const alpha,
                                          double const beta);
// Kernel selectors: each returns the launcher to use for a GEMM of shape
// M x N x K. The parameter order is (M, N, K), which is the order used by the
// definitions and by the call sites.
fp16_gather_gemm_scatter getBestFp16Kernel(const int M,
                                           const int N,
                                           const int K);
fp32_gather_gemm_scatter getBestFp32Kernel(const int M,
                                           const int N,
                                           const int K);
fp64_gather_gemm_scatter getBestFp64Kernel(const int M,
                                           const int N,
                                           const int K);
/**
 * Builds the arguments for one CUTLASS `Gemm` instance and launches it on the
 * stream of `dev_ctx`.
 *
 * All matrices are described as row-major: A is m x k (leading dimension k),
 * B is k x n, and C/D are m x n (leading dimension n).
 *
 * @param dev_ctx      GPU context; its stream receives the kernel launch.
 * @param a            A operand.
 * @param b            B operand.
 * @param c            source accumulator, scaled by `beta` in the epilogue.
 * @param d            destination.
 * @param m, n, k      GEMM problem size.
 * @param a_indices    index array passed as the A gather indices.
 * @param c_d_indices  index array passed as the D scatter indices.
 * @param alpha, beta  epilogue scalars.
 *
 * Every CUTLASS status is passed to CUTLASS_CHECK, including the status of
 * the launch itself.
 */
template <typename T, typename Gemm>
void launchKernel(const GPUContext& dev_ctx,
                  const T* const a,
                  const T* const b,
                  const T* const c,
                  T* const d,
                  const int m,
                  const int n,
                  const int k,
                  const int32_t* a_indices,
                  const int32_t* c_d_indices,
                  T const alpha,
                  T const beta) {
  cutlass::gemm::GemmCoord problem_size_real({m, n, k});
  int split_k_slices = 1;
  // Argument roles presumably follow CUTLASS GemmUniversal::Arguments: mode,
  // problem size, batch/split-k count, epilogue scalars, A/B/C/D pointers,
  // four batch strides, four leading dimensions, then the gather-A, gather-B
  // and scatter-D index pointers -- confirm against the pinned CUTLASS tag.
  typename Gemm::Arguments arguments{
      cutlass::gemm::GemmUniversalMode::kGemm,
      problem_size_real,
      split_k_slices,
      {alpha, beta},
      a,
      b,
      c,
      d,
      cutlass::layout::RowMajor().capacity(problem_size_real.mk()),
      cutlass::layout::RowMajor().capacity(problem_size_real.kn()),
      cutlass::layout::RowMajor().capacity(problem_size_real.mn()),
      cutlass::layout::RowMajor().capacity(problem_size_real.mn()),
      problem_size_real.k(),
      problem_size_real.n(),
      problem_size_real.n(),
      problem_size_real.n(),
      a_indices,
      nullptr,  // B is not gathered.
      c_d_indices};
  size_t workspace_size = Gemm::get_workspace_size(arguments);
  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
  Gemm gemm_op;
  cutlass::Status status = gemm_op.can_implement(arguments);
  CUTLASS_CHECK(status);
  status = gemm_op.initialize(arguments, workspace.get());
  CUTLASS_CHECK(status);
  // The launch reports failures through its return value as well; do not
  // drop it.
  status = gemm_op(dev_ctx.stream());
  CUTLASS_CHECK(status);
}
// GEMM configuration, read in cutlass::gemm::device::GemmUniversal template
// order: half_t A/B/C (all RowMajor), half_t accumulator, tensor-op class,
// Sm75 arch tag; threadblock tile 128x64x32, warp tile 64x32x32, instruction
// 16x8x8; 2 stages; A/B alignment and epilogue vector width of 8 elements.
// The trailing `true, false, true` are presumably GatherA / GatherB /
// ScatterD -- confirm against the pinned CUTLASS version.
struct cutlass_tensorop_h1688gemm_128x64_32x2_nn_align8 {
using Gemm = cutlass::gemm::device::GemmUniversal<
cutlass::half_t,
cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm75,
cutlass::gemm::GemmShape<128, 64, 32>,
cutlass::gemm::GemmShape<64, 32, 32>,
cutlass::gemm::GemmShape<16, 8, 8>,
cutlass::epilogue::thread::LinearCombination<cutlass::half_t,
8,
cutlass::half_t,
cutlass::half_t>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>,
2,
8,
8,
cutlass::arch::OpMultiplyAdd,
cutlass::ComplexTransform::kNone,
cutlass::ComplexTransform::kNone,
true,
false,
true>;
};
// GEMM configuration, read in cutlass::gemm::device::GemmUniversal template
// order: half_t A/B/C (all RowMajor), half_t accumulator, tensor-op class,
// Sm75 arch tag; threadblock tile 64x128x32, warp tile 32x64x32, instruction
// 16x8x8; 2 stages; A/B alignment and epilogue vector width of 8 elements.
// The trailing `true, false, true` are presumably GatherA / GatherB /
// ScatterD -- confirm against the pinned CUTLASS version.
// NOTE(review): not referenced by the kernel selectors visible in this view.
struct cutlass_tensorop_h1688gemm_64x128_32x2_nn_align8 {
using Gemm = cutlass::gemm::device::GemmUniversal<
cutlass::half_t,
cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm75,
cutlass::gemm::GemmShape<64, 128, 32>,
cutlass::gemm::GemmShape<32, 64, 32>,
cutlass::gemm::GemmShape<16, 8, 8>,
cutlass::epilogue::thread::LinearCombination<cutlass::half_t,
8,
cutlass::half_t,
cutlass::half_t>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>,
2,
8,
8,
cutlass::arch::OpMultiplyAdd,
cutlass::ComplexTransform::kNone,
cutlass::ComplexTransform::kNone,
true,
false,
true>;
};
// GEMM configuration, read in cutlass::gemm::device::GemmUniversal template
// order: half_t A/B/C (all RowMajor), half_t accumulator, tensor-op class,
// Sm75 arch tag; threadblock tile 128x64x32, warp tile 64x32x32, instruction
// 16x8x8; 2 stages; A/B alignment and epilogue vector width of 4 elements.
// The trailing `true, false, true` are presumably GatherA / GatherB /
// ScatterD -- confirm against the pinned CUTLASS version.
// NOTE(review): not referenced by the kernel selectors visible in this view.
struct cutlass_tensorop_h1688gemm_128x64_32x2_nn_align4 {
using Gemm = cutlass::gemm::device::GemmUniversal<
cutlass::half_t,
cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm75,
cutlass::gemm::GemmShape<128, 64, 32>,
cutlass::gemm::GemmShape<64, 32, 32>,
cutlass::gemm::GemmShape<16, 8, 8>,
cutlass::epilogue::thread::LinearCombination<cutlass::half_t,
4,
cutlass::half_t,
cutlass::half_t>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>,
2,
4,
4,
cutlass::arch::OpMultiplyAdd,
cutlass::ComplexTransform::kNone,
cutlass::ComplexTransform::kNone,
true,
false,
true>;
};
// GEMM configuration, read in cutlass::gemm::device::GemmUniversal template
// order: half_t A/B/C (all RowMajor), half_t accumulator, tensor-op class,
// Sm75 arch tag; threadblock tile 64x64x32, warp tile 32x32x32, instruction
// 16x8x8; 2 stages; A/B alignment and epilogue vector width of 4 elements.
// The trailing `true, false, true` are presumably GatherA / GatherB /
// ScatterD -- confirm against the pinned CUTLASS version.
// This is the fallback kernel returned by getBestFp16Kernel.
struct cutlass_tensorop_h1688gemm_64x64_32x2_nn_align4 {
using Gemm = cutlass::gemm::device::GemmUniversal<
cutlass::half_t,
cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm75,
cutlass::gemm::GemmShape<64, 64, 32>,
cutlass::gemm::GemmShape<32, 32, 32>,
cutlass::gemm::GemmShape<16, 8, 8>,
cutlass::epilogue::thread::LinearCombination<cutlass::half_t,
4,
cutlass::half_t,
cutlass::half_t>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>,
2,
4,
4,
cutlass::arch::OpMultiplyAdd,
cutlass::ComplexTransform::kNone,
cutlass::ComplexTransform::kNone,
true,
false,
true>;
};
// GEMM configuration, read in cutlass::gemm::device::GemmUniversal template
// order: half_t A/B/C (all RowMajor), half_t accumulator, tensor-op class,
// Sm75 arch tag; threadblock tile 64x64x32, warp tile 32x32x32, instruction
// 16x8x8; 2 stages; A/B alignment and epilogue vector width of 8 elements.
// The trailing `true, false, true` are presumably GatherA / GatherB /
// ScatterD -- confirm against the pinned CUTLASS version.
struct cutlass_tensorop_h1688gemm_64x64_32x2_nn_align8 {
using Gemm = cutlass::gemm::device::GemmUniversal<
cutlass::half_t,
cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm75,
cutlass::gemm::GemmShape<64, 64, 32>,
cutlass::gemm::GemmShape<32, 32, 32>,
cutlass::gemm::GemmShape<16, 8, 8>,
cutlass::epilogue::thread::LinearCombination<cutlass::half_t,
8,
cutlass::half_t,
cutlass::half_t>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>,
2,
8,
8,
cutlass::arch::OpMultiplyAdd,
cutlass::ComplexTransform::kNone,
cutlass::ComplexTransform::kNone,
true,
false,
true>;
};
// GEMM configuration, read in cutlass::gemm::device::GemmUniversal template
// order: half_t A/B/C (all RowMajor), half_t accumulator, tensor-op class,
// Sm80 arch tag; threadblock tile 64x64x64, warp tile 32x32x64, instruction
// 16x8x16; 5 stages; A/B alignment and epilogue vector width of 8 elements.
// The trailing `true, false, true` are presumably GatherA / GatherB /
// ScatterD -- confirm against the pinned CUTLASS version.
struct cutlass_tensorop_h16816gemm_64x64_64x5_nn_align8 {
using Gemm = cutlass::gemm::device::GemmUniversal<
cutlass::half_t,
cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
cutlass::gemm::GemmShape<64, 64, 64>,
cutlass::gemm::GemmShape<32, 32, 64>,
cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombination<cutlass::half_t,
8,
cutlass::half_t,
cutlass::half_t>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>,
5,
8,
8,
cutlass::arch::OpMultiplyAdd,
cutlass::ComplexTransform::kNone,
cutlass::ComplexTransform::kNone,
true,
false,
true>;
};
// GEMM configuration, read in cutlass::gemm::device::GemmUniversal template
// order: half_t A/B/C (all RowMajor) with a float accumulator, tensor-op
// class, Sm75 arch tag; threadblock tile 64x128x32, warp tile 32x64x32,
// instruction 16x8x8; 2 stages; A/B alignment of 8 elements. The epilogue
// writes half_t in vectors of 8 and uses float for its other two type
// arguments.
// The trailing `true, false, true` are presumably GatherA / GatherB /
// ScatterD -- confirm against the pinned CUTLASS version.
struct cutlass_tensorop_f16_s1688gemm_f16_64x128_32x2_nn_align8 {
using Gemm = cutlass::gemm::device::GemmUniversal<
cutlass::half_t,
cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::layout::RowMajor,
float,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm75,
cutlass::gemm::GemmShape<64, 128, 32>,
cutlass::gemm::GemmShape<32, 64, 32>,
cutlass::gemm::GemmShape<16, 8, 8>,
cutlass::epilogue::thread::
LinearCombination<cutlass::half_t, 8, float, float>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>,
2,
8,
8,
cutlass::arch::OpMultiplyAdd,
cutlass::ComplexTransform::kNone,
cutlass::ComplexTransform::kNone,
true,
false,
true>;
};
// GEMM configuration, read in cutlass::gemm::device::GemmUniversal template
// order: half_t A/B/C (all RowMajor) with a float accumulator, tensor-op
// class, Sm75 arch tag; threadblock tile 64x64x32, warp tile 32x32x32,
// instruction 16x8x8; 2 stages; A/B alignment of 8 elements. The epilogue
// writes half_t in vectors of 8 and uses float for its other two type
// arguments.
// The trailing `true, false, true` are presumably GatherA / GatherB /
// ScatterD -- confirm against the pinned CUTLASS version.
struct cutlass_tensorop_f16_s1688gemm_f16_64x64_32x2_nn_align8 {
using Gemm = cutlass::gemm::device::GemmUniversal<
cutlass::half_t,
cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::layout::RowMajor,
float,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm75,
cutlass::gemm::GemmShape<64, 64, 32>,
cutlass::gemm::GemmShape<32, 32, 32>,
cutlass::gemm::GemmShape<16, 8, 8>,
cutlass::epilogue::thread::
LinearCombination<cutlass::half_t, 8, float, float>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>,
2,
8,
8,
cutlass::arch::OpMultiplyAdd,
cutlass::ComplexTransform::kNone,
cutlass::ComplexTransform::kNone,
true,
false,
true>;
};
// GEMM configuration, read in cutlass::gemm::device::GemmUniversal template
// order: float A/B/C (all RowMajor), float accumulator, tensor-op class,
// Sm80 arch tag; threadblock tile 64x64x16, warp tile 32x32x16, instruction
// 16x8x8; 10 stages; A/B alignment and epilogue vector width of 4 elements;
// math operator tag OpMultiplyAddFastF16.
// The trailing `true, false, true` are presumably GatherA / GatherB /
// ScatterD -- confirm against the pinned CUTLASS version.
// This is the fallback kernel returned by getBestFp32Kernel.
struct cutlass_tensorop_s1688f16gemm_64x64_16x10_nn_align4 {
using Gemm = cutlass::gemm::device::GemmUniversal<
float,
cutlass::layout::RowMajor,
float,
cutlass::layout::RowMajor,
float,
cutlass::layout::RowMajor,
float,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
cutlass::gemm::GemmShape<64, 64, 16>,
cutlass::gemm::GemmShape<32, 32, 16>,
cutlass::gemm::GemmShape<16, 8, 8>,
cutlass::epilogue::thread::LinearCombination<float, 4, float, float>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>,
10,
4,
4,
cutlass::arch::OpMultiplyAddFastF16,
cutlass::ComplexTransform::kNone,
cutlass::ComplexTransform::kNone,
true,
false,
true>;
};
// GEMM configuration, read in cutlass::gemm::device::GemmUniversal template
// order: float A/B/C (all RowMajor), float accumulator, tensor-op class,
// Sm80 arch tag; threadblock tile 128x128x16, warp tile 64x64x16, instruction
// 16x8x8; 3 stages; A/B alignment and epilogue vector width of 4 elements;
// math operator tag OpMultiplyAddFastF16.
// The trailing `true, false, true` are presumably GatherA / GatherB /
// ScatterD -- confirm against the pinned CUTLASS version.
struct cutlass_tensorop_s1688f16gemm_128x128_16x3_nn_align4 {
using Gemm = cutlass::gemm::device::GemmUniversal<
float,
cutlass::layout::RowMajor,
float,
cutlass::layout::RowMajor,
float,
cutlass::layout::RowMajor,
float,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
cutlass::gemm::GemmShape<128, 128, 16>,
cutlass::gemm::GemmShape<64, 64, 16>,
cutlass::gemm::GemmShape<16, 8, 8>,
cutlass::epilogue::thread::LinearCombination<float, 4, float, float>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>,
3,
4,
4,
cutlass::arch::OpMultiplyAddFastF16,
cutlass::ComplexTransform::kNone,
cutlass::ComplexTransform::kNone,
true,
false,
true>;
};
// GEMM configuration, read in cutlass::gemm::device::GemmUniversal template
// order: float A/B/C (all RowMajor), float accumulator, tensor-op class,
// Sm80 arch tag; threadblock tile 256x64x16, warp tile 64x64x16, instruction
// 16x8x8; 4 stages; A/B alignment and epilogue vector width of 4 elements;
// math operator tag OpMultiplyAddFastF16.
// The trailing `true, false, true` are presumably GatherA / GatherB /
// ScatterD -- confirm against the pinned CUTLASS version.
struct cutlass_tensorop_s1688f16gemm_256x64_16x4_nn_align4 {
using Gemm = cutlass::gemm::device::GemmUniversal<
float,
cutlass::layout::RowMajor,
float,
cutlass::layout::RowMajor,
float,
cutlass::layout::RowMajor,
float,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
cutlass::gemm::GemmShape<256, 64, 16>,
cutlass::gemm::GemmShape<64, 64, 16>,
cutlass::gemm::GemmShape<16, 8, 8>,
cutlass::epilogue::thread::LinearCombination<float, 4, float, float>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>,
4,
4,
4,
cutlass::arch::OpMultiplyAddFastF16,
cutlass::ComplexTransform::kNone,
cutlass::ComplexTransform::kNone,
true,
false,
true>;
};
// GEMM configuration, read in cutlass::gemm::device::GemmUniversal template
// order: float A/B/C (all RowMajor), float accumulator, tensor-op class,
// Sm80 arch tag; threadblock tile 256x128x16, warp tile 64x64x16, instruction
// 16x8x8; 3 stages; A/B alignment and epilogue vector width of 4 elements;
// math operator tag OpMultiplyAdd.
// The trailing `true, false, true` are presumably GatherA / GatherB /
// ScatterD -- confirm against the pinned CUTLASS version.
struct cutlass_tensorop_s1688tf32gemm_256x128_16x3_nn_align4 {
using Gemm = cutlass::gemm::device::GemmUniversal<
float,
cutlass::layout::RowMajor,
float,
cutlass::layout::RowMajor,
float,
cutlass::layout::RowMajor,
float,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
cutlass::gemm::GemmShape<256, 128, 16>,
cutlass::gemm::GemmShape<64, 64, 16>,
cutlass::gemm::GemmShape<16, 8, 8>,
cutlass::epilogue::thread::LinearCombination<float, 4, float, float>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>,
3,
4,
4,
cutlass::arch::OpMultiplyAdd,
cutlass::ComplexTransform::kNone,
cutlass::ComplexTransform::kNone,
true,
false,
true>;
};
// GEMM configuration, read in cutlass::gemm::device::GemmUniversal template
// order: float A/B/C (all RowMajor), float accumulator, tensor-op class,
// Sm80 arch tag; threadblock tile 64x128x16, warp tile 32x64x16, instruction
// 16x8x8; 6 stages; A/B alignment and epilogue vector width of 4 elements;
// math operator tag OpMultiplyAddFastF16.
// The trailing `true, false, true` are presumably GatherA / GatherB /
// ScatterD -- confirm against the pinned CUTLASS version.
struct cutlass_tensorop_s1688f16gemm_64x128_16x6_nn_align4 {
using Gemm = cutlass::gemm::device::GemmUniversal<
float,
cutlass::layout::RowMajor,
float,
cutlass::layout::RowMajor,
float,
cutlass::layout::RowMajor,
float,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
cutlass::gemm::GemmShape<64, 128, 16>,
cutlass::gemm::GemmShape<32, 64, 16>,
cutlass::gemm::GemmShape<16, 8, 8>,
cutlass::epilogue::thread::LinearCombination<float, 4, float, float>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>,
6,
4,
4,
cutlass::arch::OpMultiplyAddFastF16,
cutlass::ComplexTransform::kNone,
cutlass::ComplexTransform::kNone,
true,
false,
true>;
};
// GEMM configuration, read in cutlass::gemm::device::GemmUniversal template
// order: float A/B/C (all RowMajor), float accumulator, tensor-op class,
// Sm80 arch tag; threadblock tile 64x64x16, warp tile 32x32x16, instruction
// 16x8x8; 3 stages; A/B alignment and epilogue vector width of 4 elements;
// math operator tag OpMultiplyAddFastF32.
// The trailing `true, false, true` are presumably GatherA / GatherB /
// ScatterD -- confirm against the pinned CUTLASS version.
struct cutlass_tensorop_s1688gemm_64x64_16x3_nn_align4 {
using Gemm = cutlass::gemm::device::GemmUniversal<
float,
cutlass::layout::RowMajor,
float,
cutlass::layout::RowMajor,
float,
cutlass::layout::RowMajor,
float,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
cutlass::gemm::GemmShape<64, 64, 16>,
cutlass::gemm::GemmShape<32, 32, 16>,
cutlass::gemm::GemmShape<16, 8, 8>,
cutlass::epilogue::thread::LinearCombination<float, 4, float, float>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>,
3,
4,
4,
cutlass::arch::OpMultiplyAddFastF32,
cutlass::ComplexTransform::kNone,
cutlass::ComplexTransform::kNone,
true,
false,
true>;
};
// GEMM configuration, read in cutlass::gemm::device::GemmUniversal template
// order: double A/B/C (all RowMajor), double accumulator, tensor-op class,
// Sm80 arch tag; threadblock tile 16x32x16, warp tile 16x16x16, instruction
// 8x8x4; 5 stages; A/B alignment and epilogue vector width of 1 element.
// The trailing `true, false, true` are presumably GatherA / GatherB /
// ScatterD -- confirm against the pinned CUTLASS version.
struct cutlass_tensorop_d884gemm_16x32_16x5_nn_align1 {
using Gemm = cutlass::gemm::device::GemmUniversal<
double,
cutlass::layout::RowMajor,
double,
cutlass::layout::RowMajor,
double,
cutlass::layout::RowMajor,
double,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
cutlass::gemm::GemmShape<16, 32, 16>,
cutlass::gemm::GemmShape<16, 16, 16>,
cutlass::gemm::GemmShape<8, 8, 4>,
cutlass::epilogue::thread::LinearCombination<double, 1, double, double>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>,
5,
1,
1,
cutlass::arch::OpMultiplyAdd,
cutlass::ComplexTransform::kNone,
cutlass::ComplexTransform::kNone,
true,
false,
true>;
};
// GEMM configuration, read in cutlass::gemm::device::GemmUniversal template
// order: double A/B/C (all RowMajor), double accumulator, tensor-op class,
// Sm80 arch tag; threadblock tile 32x16x16, warp tile 16x16x16, instruction
// 8x8x4; 5 stages; A/B alignment and epilogue vector width of 1 element.
// The trailing `true, false, true` are presumably GatherA / GatherB /
// ScatterD -- confirm against the pinned CUTLASS version.
// This is the fallback kernel returned by getBestFp64Kernel.
struct cutlass_tensorop_d884gemm_32x16_16x5_nn_align1 {
using Gemm = cutlass::gemm::device::GemmUniversal<
double,
cutlass::layout::RowMajor,
double,
cutlass::layout::RowMajor,
double,
cutlass::layout::RowMajor,
double,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
cutlass::gemm::GemmShape<32, 16, 16>,
cutlass::gemm::GemmShape<16, 16, 16>,
cutlass::gemm::GemmShape<8, 8, 4>,
cutlass::epilogue::thread::LinearCombination<double, 1, double, double>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>,
5,
1,
1,
cutlass::arch::OpMultiplyAdd,
cutlass::ComplexTransform::kNone,
cutlass::ComplexTransform::kNone,
true,
false,
true>;
};
} // namespace sparse
} // namespace phi
#endif
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册