未验证 提交 22bcb75a 编写于 作者: U umiswing 提交者: GitHub

remove if constexpr(), which is not supported on gcc54 (#50395)

上级 17d10a5d
......@@ -150,31 +150,7 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx,
const IntT* gather_indices = rulebook_ptr + h_offsets_ptr[i];
const IntT* scatter_indices =
rulebook_ptr + rulebook_len + h_offsets_ptr[i];
if constexpr (std::is_same<T, phi::dtype::float16>::value &&
std::is_same<IntT, int32_t>::value) {
fp16_gather_gemm_scatter gather_gemm_scatter =
getBestFp16Kernel(M, N, K);
gather_gemm_scatter(
dev_ctx,
reinterpret_cast<const cutlass::half_t*>(
x.non_zero_elements().data<T>()),
reinterpret_cast<const cutlass::half_t*>(tmp_kernel_ptr),
reinterpret_cast<cutlass::half_t*>(out_values_ptr),
reinterpret_cast<cutlass::half_t*>(out_values_ptr),
M,
N,
K,
static_cast<const int32_t*>(gather_indices),
static_cast<const int32_t*>(scatter_indices),
static_cast<cutlass::half_t>(1),
static_cast<cutlass::half_t>(1));
}
if constexpr (std::is_same<T, float>::value &&
std::is_same<IntT, int32_t>::value) {
fp32_gather_gemm_scatter gather_gemm_scatter =
getBestFp32Kernel(M, N, K, dev_ctx.GetComputeCapability());
gather_gemm_scatter(dev_ctx,
dispatchKernel(dev_ctx,
x.non_zero_elements().data<T>(),
tmp_kernel_ptr,
out_values_ptr,
......@@ -184,26 +160,8 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx,
K,
gather_indices,
scatter_indices,
static_cast<T>(1),
static_cast<T>(1));
}
if constexpr (std::is_same<T, double>::value &&
std::is_same<IntT, int32_t>::value) {
fp64_gather_gemm_scatter gather_gemm_scatter =
getBestFp64Kernel(M, N, K);
gather_gemm_scatter(dev_ctx,
x.non_zero_elements().data<T>(),
tmp_kernel_ptr,
out_values_ptr,
out_values_ptr,
M,
N,
K,
gather_indices,
scatter_indices,
static_cast<T>(1),
static_cast<T>(1));
}
cutlass,
x.dtype());
}
} else {
#endif
......
......@@ -23,6 +23,7 @@
#include "cutlass/util/device_memory.h"
#include "examples/common/helper.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/data_type.h"
namespace phi {
namespace sparse {
typedef void (*fp16_gather_gemm_scatter)(const GPUContext& dev_ctx,
......@@ -115,6 +116,66 @@ void launchKernel(const GPUContext& dev_ctx,
CUTLASS_CHECK(status);
gemm_op(dev_ctx.stream());
}
static void dispatchKernel(const GPUContext& dev_ctx,
const void* const a,
const void* const b,
const void* const c,
void* const d,
const int m,
const int n,
const int k,
const void* a_indices,
const void* c_d_indices,
const bool cutlass,
const phi::DataType type) {
if (!cutlass) return;
if (type == phi::DataType::FLOAT16) {
fp16_gather_gemm_scatter gather_gemm_scatter = getBestFp16Kernel(m, n, k);
gather_gemm_scatter(dev_ctx,
static_cast<const cutlass::half_t*>(a),
static_cast<const cutlass::half_t*>(b),
static_cast<const cutlass::half_t*>(c),
static_cast<cutlass::half_t*>(d),
m,
n,
k,
static_cast<const int32_t*>(a_indices),
static_cast<const int32_t*>(c_d_indices),
static_cast<cutlass::half_t>(1),
static_cast<cutlass::half_t>(1));
} else if (type == phi::DataType::FLOAT32) {
fp32_gather_gemm_scatter gather_gemm_scatter =
getBestFp32Kernel(m, n, k, dev_ctx.GetComputeCapability());
gather_gemm_scatter(dev_ctx,
static_cast<const float*>(a),
static_cast<const float*>(b),
static_cast<const float*>(c),
static_cast<float*>(d),
m,
n,
k,
static_cast<const int32_t*>(a_indices),
static_cast<const int32_t*>(c_d_indices),
static_cast<float>(1),
static_cast<float>(1));
} else if (type == phi::DataType::FLOAT64) {
fp64_gather_gemm_scatter gather_gemm_scatter = getBestFp64Kernel(m, n, k);
gather_gemm_scatter(dev_ctx,
static_cast<const double*>(a),
static_cast<const double*>(b),
static_cast<const double*>(c),
static_cast<double*>(d),
m,
n,
k,
static_cast<const int32_t*>(a_indices),
static_cast<const int32_t*>(c_d_indices),
static_cast<double>(1),
static_cast<double>(1));
}
}
struct cutlass_tensorop_h1688gemm_128x64_32x2_nn_align8 {
using Gemm = cutlass::gemm::device::GemmUniversal<
cutlass::half_t,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册