未验证 提交 8c5e432b 编写于 作者: Z zhangkaihuo 提交者: GitHub

[cherry-pick] Optimize sparse kernel and fix some bug (#50118)

cherry-pick some PR about optimize sparse kernel and fix some bug:
#47736 #47703 #47604 #46679 #48439 #49009 #49734
上级 e32ff656
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
include(ExternalProject)
set(CUTLASS_PREFIX_DIR ${THIRD_PARTY_PATH}/cutlass)
set(CUTLASS_REPOSITORY https://github.com/NVIDIA/cutlass.git)
set(CUTLASS_TAG v2.9.1)
include_directories("${THIRD_PARTY_PATH}/cutlass/src/extern_cutlass/")
include_directories("${THIRD_PARTY_PATH}/cutlass/src/extern_cutlass/include/")
include_directories(
"${THIRD_PARTY_PATH}/cutlass/src/extern_cutlass/tools/util/include/")
add_definitions("-DPADDLE_WITH_CUTLASS")
ExternalProject_Add(
extern_cutlass
${EXTERNAL_PROJECT_LOG_ARGS} ${SHALLOW_CLONE}
GIT_REPOSITORY ${CUTLASS_REPOSITORY}
GIT_TAG "${CUTLASS_TAG}"
PREFIX ${CUTLASS_PREFIX_DIR}
UPDATE_COMMAND ""
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
INSTALL_COMMAND ""
TEST_COMMAND "")
add_library(cutlass INTERFACE)
add_dependencies(cutlass extern_cutlass)
......@@ -492,4 +492,14 @@ if(WITH_CUSPARSELT)
list(APPEND third_party_deps extern_cusparselt)
endif()
if(WITH_GPU
AND NOT WITH_ARM
AND NOT WIN32
AND NOT APPLE)
if(${CMAKE_CUDA_COMPILER_VERSION} GREATER_EQUAL 11.0)
include(external/cutlass) # download, build, install cusparselt
list(APPEND third_party_deps extern_cutlass)
endif()
endif()
add_custom_target(third_party ALL DEPENDS ${third_party_deps})
......@@ -18,6 +18,10 @@ limitations under the License. */
namespace phi {
namespace funcs {
#define CUDNN_PER_ACTIVATION_THRESHOLD 10240
#define CUDNN_SPATIAL_THRESHOLD_TRAIN 880801
#define CUDNN_SPATIAL_THRESHOLD_EVAL 65535
inline void ExtractNCWHD(const phi::DDim &dims,
const DataLayout &data_layout,
int *N,
......
......@@ -26,6 +26,19 @@ __global__ void DistanceKernel(const T* start, const T* end, T* distance) {
}
}
inline __device__ bool SetBits(const int value, int* ptr) {
const int index = value >> 5;
const int mask = 1 << (value & 31);
const int old = atomicOr(ptr + index, mask);
return (mask & old) != 0;
}
inline __device__ bool TestBits(const int value, const int* ptr) {
const int index = value >> 5;
const int mask = 1 << (value & 31);
return (mask & ptr[index]) != 0;
}
} // namespace sparse
} // namespace funcs
} // namespace phi
......@@ -852,15 +852,17 @@ void BatchNormGradRawKernel(const Context &ctx,
// ctx.GetPlace()),
// epsilon, saved_mean_data, saved_var_data));
#else
// CUDNN only support small batch size
// const size_t CUDNN_PER_ACTIVATION_THRESHOLD = 131070;
const size_t CUDNN_PER_ACTIVATION_THRESHOLD = 10240;
const size_t CUDNN_SPATIAL_THRESHOLD = 880801;
const bool use_native_kernel =
((x_dims.size() == 2 && N >= CUDNN_PER_ACTIVATION_THRESHOLD) ||
(x_dims.size() == 3 && N >= CUDNN_SPATIAL_THRESHOLD));
if (use_native_kernel) {
if (x_dims.size() == 2) {
}
// CUDNN only support small batch size
bool use_native_nhwc =
d_x ? (x_dims.size() == 4 && compute_format == DataLayout::kNHWC)
: false;
const bool use_native_kernel =
((x_dims.size() == 2 && N >= CUDNN_PER_ACTIVATION_THRESHOLD) ||
(x_dims.size() == 3 && N >= CUDNN_SPATIAL_THRESHOLD_TRAIN));
if (use_native_nhwc || (d_x && d_scale && d_bias)) {
if (use_native_kernel || use_native_nhwc) {
if (x_dims.size() == 2 || use_native_nhwc) {
dim3 block;
dim3 grid;
const int block_size = 512;
......
......@@ -72,6 +72,40 @@ static __global__ void BNForwardInference(const T *x,
}
}
template <typename T>
static __global__ void InverseVariance(const BatchNormParamType<T> *variance,
const double epsilon,
const int C,
BatchNormParamType<T> *inv_variance) {
int tid = threadIdx.x + blockIdx.x * blockDim.x;
if (tid < C) {
inv_variance[tid] = 1 / sqrt(variance[tid] + epsilon);
}
}
template <typename T, phi::DataLayout layout>
static __global__ void BN1DForwardInference(
const T *x,
const BatchNormParamType<T> *mean,
const BatchNormParamType<T> *inv_variance,
const BatchNormParamType<T> *scale,
const BatchNormParamType<T> *bias,
const int C,
const int N,
const int HxW,
const double epsilon,
T *y) {
int gid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
int num = N * C * HxW;
for (int i = gid; i < num; i += stride) {
const int c = layout == phi::DataLayout::kNCHW ? i / HxW % C : i % C;
BatchNormParamType<T> x_sub_mean =
static_cast<BatchNormParamType<T>>(x[i]) - mean[c];
y[i] = static_cast<T>(scale[c] * x_sub_mean * inv_variance[c] + bias[c]);
}
}
template <typename T, int BlockDim, phi::DataLayout layout>
static __global__ LAUNCH_BOUNDS(BlockDim) void BNForwardTraining(
const T *x,
......@@ -691,9 +725,6 @@ void BatchNormKernel(const Context &ctx,
auto handle = ctx.cudnn_handle();
const size_t CUDNN_PER_ACTIVATION_THRESHOLD = 10240;
const size_t CUDNN_SPATIAL_THRESHOLD = 880801;
// Now, depending on whether we are running test or not, we have two paths.
// It is training mode when it's not reference AND not using pre-trained
// model.
......@@ -797,8 +828,8 @@ void BatchNormKernel(const Context &ctx,
// epsilon));
#else
const bool use_native_kernel =
((x_dims.size() == 2 && N >= CUDNN_PER_ACTIVATION_THRESHOLD) ||
(x_dims.size() == 3 && N >= CUDNN_SPATIAL_THRESHOLD));
(x_dims.size() == 2 ||
(x_dims.size() == 3 && N >= CUDNN_SPATIAL_THRESHOLD_EVAL));
if (use_native_kernel) {
const int block_size = 256;
const int grid_size = (N * C * H * W * D + block_size - 1) / block_size;
......@@ -816,18 +847,43 @@ void BatchNormKernel(const Context &ctx,
epsilon,
transformed_y.template data<T>());
} else {
BNForwardInference<T, DataLayout::kNHWC>
<<<grid_size, block_size, 0, ctx.stream()>>>(
transformed_x.template data<T>(),
est_mean->template data<BatchNormParamType<T>>(),
est_var->template data<BatchNormParamType<T>>(),
scale.template data<BatchNormParamType<T>>(),
bias.template data<BatchNormParamType<T>>(),
C,
N,
H * W * D,
epsilon,
transformed_y.template data<T>());
if (x_dims.size() == 2) {
DenseTensor inv_var = phi::Empty<BatchNormParamType<T>>(ctx, {C});
auto *inv_var_ptr = inv_var.data<BatchNormParamType<T>>();
const int threads = 512 > C ? C : 512;
const int blocks = (C + 511) / 512;
InverseVariance<T><<<blocks, threads>>>(
est_var->template data<BatchNormParamType<T>>(),
epsilon,
C,
inv_var_ptr);
BN1DForwardInference<T, DataLayout::kNHWC>
<<<grid_size, block_size, 0, ctx.stream()>>>(
transformed_x.template data<T>(),
est_mean->template data<BatchNormParamType<T>>(),
// est_var->template data<BatchNormParamType<T>>(),
inv_var_ptr,
scale.template data<BatchNormParamType<T>>(),
bias.template data<BatchNormParamType<T>>(),
C,
N,
H * W * D,
epsilon,
transformed_y.template data<T>());
} else {
BNForwardInference<T, DataLayout::kNHWC>
<<<grid_size, block_size, 0, ctx.stream()>>>(
transformed_x.template data<T>(),
est_mean->template data<BatchNormParamType<T>>(),
est_var->template data<BatchNormParamType<T>>(),
scale.template data<BatchNormParamType<T>>(),
bias.template data<BatchNormParamType<T>>(),
C,
N,
H * W * D,
epsilon,
transformed_y.template data<T>());
}
}
} else {
PADDLE_ENFORCE_GPU_SUCCESS(
......@@ -949,7 +1005,7 @@ void BatchNormKernel(const Context &ctx,
// const size_t CUDNN_PER_ACTIVATION_THRESHOLD = 131070;
const bool use_native_kernel =
((x_dims.size() == 2 && N >= CUDNN_PER_ACTIVATION_THRESHOLD) ||
(x_dims.size() == 3 && N >= CUDNN_SPATIAL_THRESHOLD));
(x_dims.size() == 3 && N >= CUDNN_SPATIAL_THRESHOLD_TRAIN));
if (use_native_kernel) {
dim3 block;
dim3 grid;
......
......@@ -15,8 +15,14 @@ limitations under the License. */
#pragma once
#include <thrust/remove.h>
#include <thrust/sort.h>
#include <thrust/unique.h>
#ifdef __NVCC__
#include <cub/block/block_scan.cuh>
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif
#include "paddle/phi/kernels/sparse/conv_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
......@@ -167,7 +173,7 @@ inline void GatherV2(const GPUContext& dev_ctx,
template <typename IntT>
__global__ void UniqueKernel(const IntT* in_indexs,
const int rulebook_len,
int* out_index_table,
int* index_flags,
int* out_indexs,
int* nnz) {
extern __shared__ int cache[];
......@@ -182,8 +188,8 @@ __global__ void UniqueKernel(const IntT* in_indexs,
if (i < rulebook_len) {
// atomicOr only support int
int index = static_cast<int>(in_indexs[i]);
int flag = atomicOr(out_index_table + index, 1);
if (flag == 0) {
const bool flag = phi::funcs::sparse::SetBits(index, index_flags);
if (!flag) {
int j = atomicAdd(&count, 1);
cache[j] = index;
}
......@@ -199,6 +205,88 @@ __global__ void UniqueKernel(const IntT* in_indexs,
}
}
inline __device__ uint32_t BitCount(const uint32_t data) {
uint32_t count = data;
count = (count & 0x55555555) + ((count >> 1) & 0x55555555);
count = (count & 0x33333333) + ((count >> 2) & 0x33333333);
count = (count & 0x0f0f0f0f) + ((count >> 4) & 0x0f0f0f0f);
count = (count & 0x00ff00ff) + ((count >> 8) & 0x00ff00ff);
count = (count & 0x0000ffff) + ((count >> 16) & 0x0000ffff);
return count;
}
static __global__ void GetOutIndexsCounter(const int* flags,
const int n,
int* out) {
int tid = threadIdx.x + blockDim.x * blockIdx.x;
__shared__ int block_count;
if (threadIdx.x == 0) {
block_count = 0;
}
__syncthreads();
if (tid < n) {
// get the count of 1 in flags[tid]
uint32_t count = BitCount(static_cast<uint32_t>(flags[tid]));
// add to block_count
// TODO(zhangkaihuo): replace with block reduce_sum
atomicAdd(&block_count, static_cast<int>(count));
}
__syncthreads();
// write to out
if (threadIdx.x == 0) {
out[blockIdx.x] = block_count;
}
}
template <int BS>
__global__ void GetOutIndexs(const int* flags,
const int n,
const int* offsets,
const int out_nnz,
int* out) {
int tid = threadIdx.x + blockDim.x * blockIdx.x;
__shared__ int block_counts[BS];
__shared__ int block_outs[BS * 32];
int count = 0;
if (tid < n) {
// get the count of 1 in flags[tid]
int flag = flags[tid];
count = BitCount(static_cast<uint32_t>(flag));
}
// call block prefix_sum
// using namespace cub;
typedef cub::BlockScan<int, BS> BlockScan;
__shared__ typename BlockScan::TempStorage temp_storage;
BlockScan(temp_storage).ExclusiveSum(count, count);
__syncthreads();
// write index to out
if (tid < n) {
// get the count of 1 in flags[tid]
int flag = flags[tid];
// int j = block_counts[threadIdx.x];
int j = count;
// TODO(zhangkaihuo): opt the loop
for (int i = 0; i < 32; ++i) {
if ((1 & (flag >> i)) == 1) {
block_outs[j++] = (tid << 5) + i;
}
}
}
__syncthreads();
// write to block_outs
int start = offsets[blockIdx.x];
int end = blockIdx.x == gridDim.x - 1 ? out_nnz : offsets[blockIdx.x + 1];
for (int i = threadIdx.x; i < end - start; i += blockDim.x) {
out[start + i] = block_outs[i];
}
}
template <typename IntT>
__global__ void GroupIndexs(const int* out_index_table,
const int n,
......@@ -284,7 +372,6 @@ __global__ void ProductRuleBookKernel(const T* x_indices,
atomicAdd(&counter_buf[kernel_index], 1);
kernel_i = kernel_index;
}
// rulebook[kernel_index * non_zero_num + i] = kernel_i;
rulebook[kernel_index * non_zero_num + i] = in_i;
rulebook[kernel_index * non_zero_num + offset + i] = out_index;
++kernel_index;
......@@ -299,17 +386,19 @@ __global__ void ProductRuleBookKernel(const T* x_indices,
}
template <typename IntT>
__global__ void GetOutIndexTable(const IntT* indices,
const IntT non_zero_num,
const Dims4D dims,
int* out_index_table) {
__global__ void GetOutIndexTable1(const IntT* indices,
const IntT non_zero_num,
const Dims4D dims,
int* index_flags,
int* out_index_table) {
CUDA_KERNEL_LOOP_TYPE(i, non_zero_num, int64_t) {
IntT batch = indices[i];
IntT in_z = indices[i + non_zero_num];
IntT in_y = indices[i + 2 * non_zero_num];
IntT in_x = indices[i + 3 * non_zero_num];
IntT index = PointToIndex(batch, in_x, in_y, in_z, dims);
out_index_table[index] = i == 0 ? -1 : i;
phi::funcs::sparse::SetBits(index, index_flags);
out_index_table[index] = i;
}
}
......@@ -375,6 +464,7 @@ __global__ void ProductSubmRuleBookKernel(const T* x_indices,
const Dims4D paddings,
const Dims4D dilations,
const Dims4D strides,
const int* index_flags,
const int* out_index_table,
T* rulebook,
int* counter) {
......@@ -417,9 +507,10 @@ __global__ void ProductSubmRuleBookKernel(const T* x_indices,
T out_x = (in_x + paddings[3] - kx * dilations[3]) / strides[3];
out_index = phi::funcs::sparse::PointToIndex<Dims4D>(
batch, out_x, out_y, out_z, out_dims);
int real_out_index = out_index_table[out_index];
if (real_out_index != 0) {
real_out_index = real_out_index == -1 ? 0 : real_out_index;
const bool flag =
phi::funcs::sparse::TestBits(out_index, index_flags);
if (flag) {
int real_out_index = out_index_table[out_index];
in_i = i;
int buf_i = atomicAdd(&counter_buf[kernel_index], 1);
kernel_i = kernel_index;
......@@ -440,7 +531,6 @@ __global__ void ProductSubmRuleBookKernel(const T* x_indices,
__syncthreads();
for (int i = 0; i < kernel_size; i++) {
if (threadIdx.x < counter_buf[i]) {
// rulebook[i * non_zero_num + counter_buf2[i] + threadIdx.x] = i;
rulebook[i * non_zero_num + counter_buf2[i] + threadIdx.x] =
rulebook_buf[i * blockDim.x + threadIdx.x];
rulebook[i * non_zero_num + offset + counter_buf2[i] + threadIdx.x] =
......@@ -575,12 +665,18 @@ int ProductRuleBook(const Context& dev_ctx,
DenseTensorMeta rulebook_meta(
indices_dtype, {rulebook_rows, rulebook_cols}, DataLayout::NCHW);
int64_t table_size = 1;
int table_size = 1;
for (int i = 0; i < out_dims.size() - 1; i++) {
table_size *= out_dims[i];
}
DenseTensor out_index_table = phi::Empty<int>(dev_ctx, {table_size});
int* out_index_table_ptr = out_index_table.data<int>();
// index_flags: flag the indices exist or not
int index_flags_size = (table_size + 31) / 32;
DenseTensor index_flags = phi::Empty<int>(dev_ctx, {index_flags_size});
int* index_flags_ptr = index_flags.data<int>();
phi::backends::gpu::GpuMemsetAsync(
index_flags_ptr, 0, sizeof(int) * index_flags.numel(), dev_ctx.stream());
if (subm) {
DenseTensor tmp_rulebook = phi::Empty(dev_ctx, std::move(rulebook_meta));
......@@ -590,16 +686,16 @@ int ProductRuleBook(const Context& dev_ctx,
phi::Copy(dev_ctx, x.indices(), dev_ctx.GetPlace(), false, &out_indices);
phi::backends::gpu::GpuMemsetAsync(
out_index_table_ptr, 0, sizeof(int) * table_size, dev_ctx.stream());
auto config =
phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, non_zero_num, 1);
GetOutIndexTable<IntT><<<config.block_per_grid,
config.thread_per_block,
0,
dev_ctx.stream()>>>(
out_indices.data<IntT>(), non_zero_num, d_x_dims, out_index_table_ptr);
GetOutIndexTable1<IntT><<<config.block_per_grid,
config.thread_per_block,
0,
dev_ctx.stream()>>>(out_indices.data<IntT>(),
non_zero_num,
d_x_dims,
index_flags_ptr,
out_index_table_ptr);
size_t cache_size =
kernel_size * 2 * sizeof(int) +
......@@ -625,6 +721,7 @@ int ProductRuleBook(const Context& dev_ctx,
d_paddings,
d_dilations,
d_strides,
index_flags_ptr,
out_index_table_ptr,
rulebook_ptr,
counter_ptr);
......@@ -695,9 +792,6 @@ int ProductRuleBook(const Context& dev_ctx,
int* out_index_ptr = out_index->data<int>();
int* unique_key_ptr = unique_key.data<int>();
phi::backends::gpu::GpuMemsetAsync(
out_index_table_ptr, 0, sizeof(int) * table_size, dev_ctx.stream());
phi::backends::gpu::GpuMemsetAsync(
unique_key_ptr, 0, sizeof(int), dev_ctx.stream());
......@@ -708,7 +802,7 @@ int ProductRuleBook(const Context& dev_ctx,
cache_size,
dev_ctx.stream()>>>(rulebook_ptr + rulebook_len,
rulebook_len,
out_index_table_ptr,
index_flags_ptr,
out_index_ptr,
unique_key_ptr);
......@@ -719,13 +813,25 @@ int ProductRuleBook(const Context& dev_ctx,
gpuMemcpyDeviceToHost,
dev_ctx.stream());
dev_ctx.Wait();
const int threads = 256;
const int blocks = (index_flags.numel() + threads - 1) / threads;
GetOutIndexsCounter<<<blocks, threads, 0, dev_ctx.stream()>>>(
index_flags_ptr, index_flags.numel(), out_index_table_ptr);
#ifdef PADDLE_WITH_HIP
thrust::sort(thrust::hip::par.on(dev_ctx.stream()),
thrust::exclusive_scan(thrust::hip::par.on(dev_ctx.stream()),
#else
thrust::sort(thrust::cuda::par.on(dev_ctx.stream()),
thrust::exclusive_scan(thrust::cuda::par.on(dev_ctx.stream()),
#endif
out_index_ptr,
out_index_ptr + out_nnz);
out_index_table_ptr,
out_index_table_ptr + blocks,
out_index_table_ptr);
GetOutIndexs<threads>
<<<blocks, threads, 0, dev_ctx.stream()>>>(index_flags_ptr,
index_flags.numel(),
out_index_table_ptr,
out_nnz,
out_index_ptr);
const int64_t sparse_dim = 4;
phi::DenseTensor out_indices =
......
......@@ -22,6 +22,9 @@ limitations under the License. */
#include "paddle/phi/kernels/funcs/scatter.cu.h"
#include "paddle/phi/kernels/funcs/sparse/scatter.cu.h"
#include "paddle/phi/kernels/sparse/gpu/conv.cu.h"
#ifdef PADDLE_WITH_CUTLASS
#include "paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.h"
#endif
#include "glog/logging.h"
......@@ -120,85 +123,171 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx,
dev_ctx, x, key, tmp_rulebook, h_counter, out, rulebook, counter);
}
// 2. gather
phi::DenseTensor in_features =
phi::Empty<T>(dev_ctx, {rulebook_len, in_channels});
phi::DenseTensor out_features =
phi::Empty<T>(dev_ctx, {rulebook_len, out_channels});
T* in_features_ptr = in_features.data<T>();
T* out_features_ptr = out_features.data<T>();
phi::funcs::SetConstant<GPUContext, T> set_zero;
set_zero(dev_ctx, &out_features, static_cast<T>(0.0f));
Gather<T, IntT>(dev_ctx,
x.values().data<T>(),
rulebook_ptr,
rulebook_len,
in_channels,
in_features_ptr);
// 3. call gemm for every werght
auto blas = phi::funcs::GetBlas<GPUContext, T>(dev_ctx);
auto* out_values = out->mutable_values();
T* out_values_ptr = out_values->data<T>();
set_zero(dev_ctx, out_values, static_cast<T>(0.0f));
if (subm) {
auto config =
phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, rulebook_len, 1);
unique_value.ResizeAndAllocate(
{static_cast<int>(out->nnz() * kernel_size)});
out_index.ResizeAndAllocate({static_cast<int>(rulebook_len)});
int* out_index_ptr = out_index.data<int>();
int* unique_value_ptr = unique_value.data<int>();
phi::backends::gpu::GpuMemsetAsync(
out_index_ptr, 0, sizeof(int) * rulebook_len, dev_ctx.stream());
GroupIndexs<<<config.block_per_grid,
config.thread_per_block,
0,
dev_ctx.stream()>>>(rulebook_len,
kernel_size,
rulebook_ptr + rulebook_len,
out_index_ptr,
unique_value_ptr);
#ifdef PADDLE_WITH_CUTLASS
bool cutlass = true;
if (dev_ctx.GetComputeCapability() < 75) cutlass = false;
if (in_channels % 4 != 0 || out_channels % 4 != 0) {
if (std::is_same<T, phi::dtype::float16>::value) cutlass = false;
if (std::is_same<T, float>::value) cutlass = false;
}
if (!std::is_same<IntT, int32_t>::value) cutlass = false;
if (cutlass) {
auto* out_values = out->mutable_non_zero_elements();
T* out_values_ptr = out_values->data<T>();
phi::funcs::SetConstant<GPUContext, T> set_zero;
set_zero(dev_ctx, out_values, static_cast<T>(0.0f));
const T* kernel_ptr = kernel.data<T>();
for (int i = 0; i < kernel_size; i++) {
if (h_counter_ptr[i] <= 0) {
continue;
}
const T* kernel_ptr = kernel.data<T>();
for (int i = 0; i < kernel_size; i++) {
if (h_counter_ptr[i] <= 0) {
continue;
const int M = h_counter_ptr[i];
const int K = in_channels;
const int N = out_channels;
const T* tmp_kernel_ptr = kernel_ptr + i * K * N;
const IntT* gather_indices = rulebook_ptr + h_offsets_ptr[i];
const IntT* scatter_indices =
rulebook_ptr + rulebook_len + h_offsets_ptr[i];
if constexpr (std::is_same<T, phi::dtype::float16>::value &&
std::is_same<IntT, int32_t>::value) {
fp16_gather_gemm_scatter gather_gemm_scatter =
getBestFp16Kernel(M, N, K);
gather_gemm_scatter(
dev_ctx,
reinterpret_cast<const cutlass::half_t*>(
x.non_zero_elements().data<T>()),
reinterpret_cast<const cutlass::half_t*>(tmp_kernel_ptr),
reinterpret_cast<cutlass::half_t*>(out_values_ptr),
reinterpret_cast<cutlass::half_t*>(out_values_ptr),
M,
N,
K,
static_cast<const int32_t*>(gather_indices),
static_cast<const int32_t*>(scatter_indices),
static_cast<cutlass::half_t>(1),
static_cast<cutlass::half_t>(1));
}
if constexpr (std::is_same<T, float>::value &&
std::is_same<IntT, int32_t>::value) {
fp32_gather_gemm_scatter gather_gemm_scatter =
getBestFp32Kernel(M, N, K, dev_ctx.GetComputeCapability());
gather_gemm_scatter(dev_ctx,
x.non_zero_elements().data<T>(),
tmp_kernel_ptr,
out_values_ptr,
out_values_ptr,
M,
N,
K,
gather_indices,
scatter_indices,
static_cast<T>(1),
static_cast<T>(1));
}
if constexpr (std::is_same<T, double>::value &&
std::is_same<IntT, int32_t>::value) {
fp64_gather_gemm_scatter gather_gemm_scatter =
getBestFp64Kernel(M, N, K);
gather_gemm_scatter(dev_ctx,
x.non_zero_elements().data<T>(),
tmp_kernel_ptr,
out_values_ptr,
out_values_ptr,
M,
N,
K,
gather_indices,
scatter_indices,
static_cast<T>(1),
static_cast<T>(1));
}
}
} else {
#endif
if (subm) {
auto config =
phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, rulebook_len, 1);
unique_value.ResizeAndAllocate(
{static_cast<int>(out->nnz() * kernel_size)});
out_index.ResizeAndAllocate({static_cast<int>(rulebook_len)});
int* out_index_ptr = out_index.data<int>();
int* unique_value_ptr = unique_value.data<int>();
phi::backends::gpu::GpuMemsetAsync(
out_index_ptr, 0, sizeof(int) * rulebook_len, dev_ctx.stream());
GroupIndexs<<<config.block_per_grid,
config.thread_per_block,
0,
dev_ctx.stream()>>>(rulebook_len,
kernel_size,
rulebook_ptr + rulebook_len,
out_index_ptr,
unique_value_ptr);
}
// 2. gather
phi::DenseTensor in_features =
phi::Empty<T>(dev_ctx, {rulebook_len, in_channels});
phi::DenseTensor out_features =
phi::Empty<T>(dev_ctx, {rulebook_len, out_channels});
T* in_features_ptr = in_features.data<T>();
T* out_features_ptr = out_features.data<T>();
phi::funcs::SetConstant<GPUContext, T> set_zero;
set_zero(dev_ctx, &out_features, static_cast<T>(0.0f));
// call gemm: (n, in_channels) * (in_channels, out_channels)
const int M = h_counter_ptr[i];
const int K = in_channels;
const int N = out_channels;
T* tmp_in_ptr = in_features_ptr + h_offsets_ptr[i] * in_channels;
const T* tmp_kernel_ptr = kernel_ptr + i * K * N;
T* tmp_out_ptr = out_features_ptr + h_offsets_ptr[i] * out_channels;
blas.GEMM(CblasNoTrans,
CblasNoTrans,
M,
N,
K,
static_cast<T>(1),
tmp_in_ptr,
tmp_kernel_ptr,
static_cast<T>(0),
tmp_out_ptr);
}
Gather<T, IntT>(dev_ctx,
x.values().data<T>(),
rulebook_ptr,
rulebook_len,
in_channels,
in_features_ptr);
// 4. scatter
phi::funcs::sparse::ScatterV2<T>(dev_ctx,
out_features_ptr,
out_index.data<int>(),
unique_value.data<int>(),
out->nnz(),
kernel_size,
out_channels,
1,
out_values_ptr);
// 3. call gemm for every werght
auto blas = phi::funcs::GetBlas<GPUContext, T>(dev_ctx);
auto* out_values = out->mutable_values();
T* out_values_ptr = out_values->data<T>();
set_zero(dev_ctx, out_values, static_cast<T>(0.0f));
const T* kernel_ptr = kernel.data<T>();
for (int i = 0; i < kernel_size; i++) {
if (h_counter_ptr[i] <= 0) {
continue;
}
// call gemm: (n, in_channels) * (in_channels, out_channels)
const int M = h_counter_ptr[i];
const int K = in_channels;
const int N = out_channels;
T* tmp_in_ptr = in_features_ptr + h_offsets_ptr[i] * in_channels;
const T* tmp_kernel_ptr = kernel_ptr + i * K * N;
T* tmp_out_ptr = out_features_ptr + h_offsets_ptr[i] * out_channels;
blas.GEMM(CblasNoTrans,
CblasNoTrans,
M,
N,
K,
static_cast<T>(1),
tmp_in_ptr,
tmp_kernel_ptr,
static_cast<T>(0),
tmp_out_ptr);
}
// 4. scatter
phi::funcs::sparse::ScatterV2<T>(dev_ctx,
out_features_ptr,
out_index.data<int>(),
unique_value.data<int>(),
out->nnz(),
kernel_size,
out_channels,
1,
out_values_ptr);
#ifdef PADDLE_WITH_CUTLASS
}
#endif
}
/**
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifdef PADDLE_WITH_CUTLASS
#include "paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.h"
namespace phi {
namespace sparse {
fp16_gather_gemm_scatter getBestFp16Kernel(const int M,
const int N,
const int K) {
if (K == 4 && N == 16) {
return launchKernel<cutlass::half_t,
cutlass_tensorop_h1688gemm_64x64_32x2_nn_align4::Gemm>;
}
if (K == 16 && N == 16) {
return launchKernel<cutlass::half_t,
cutlass_tensorop_h1688gemm_64x64_32x2_nn_align8::Gemm>;
}
if (K == 16 && N == 32) {
return launchKernel<cutlass::half_t,
cutlass_tensorop_h1688gemm_64x64_32x2_nn_align8::Gemm>;
}
if (K == 32 && N == 32) {
return launchKernel<cutlass::half_t,
cutlass_tensorop_h1688gemm_64x64_32x2_nn_align8::Gemm>;
}
if (K == 32 && N == 64) {
return launchKernel<cutlass::half_t,
cutlass_tensorop_h1688gemm_64x64_32x2_nn_align8::Gemm>;
}
if (K == 64 && N == 64) {
if (M > 100000)
launchKernel<
cutlass::half_t,
cutlass_tensorop_f16_s1688gemm_f16_64x128_32x2_nn_align8::Gemm>;
if (M > 20000)
launchKernel<
cutlass::half_t,
cutlass_tensorop_f16_s1688gemm_f16_64x64_32x2_nn_align8::Gemm>;
if (M > 15000)
return launchKernel<
cutlass::half_t,
cutlass_tensorop_h1688gemm_128x64_32x2_nn_align8::Gemm>;
return launchKernel<cutlass::half_t,
cutlass_tensorop_h1688gemm_64x64_32x2_nn_align8::Gemm>;
}
if (K == 128) {
if (M >= 5000)
return launchKernel<
cutlass::half_t,
cutlass_tensorop_h1688gemm_64x64_32x2_nn_align8::Gemm>;
return launchKernel<cutlass::half_t,
cutlass_tensorop_h16816gemm_64x64_64x5_nn_align8::Gemm>;
}
if (N == 128) {
return launchKernel<cutlass::half_t,
cutlass_tensorop_h1688gemm_64x64_32x2_nn_align8::Gemm>;
}
return launchKernel<cutlass::half_t,
cutlass_tensorop_h1688gemm_64x64_32x2_nn_align4::Gemm>;
}
fp32_gather_gemm_scatter getBestFp32Kernel(const int M,
const int N,
const int K,
const int SM) {
if (SM == 75) {
return launchKernel<
float,
cutlass_tensorop_s1688gemm_f16_64x64_32x2_nn_align4::Gemm>;
}
if (K == 4 && N == 16) {
return launchKernel<
float,
cutlass_tensorop_s1688f16gemm_64x64_16x10_nn_align4::Gemm>;
}
if (K == 16 && N == 16) {
return launchKernel<
float,
cutlass_tensorop_s1688f16gemm_64x64_16x10_nn_align4::Gemm>;
}
if (K == 16 && N == 32) {
if (M >= 10000)
return launchKernel<
float,
cutlass_tensorop_s1688gemm_64x64_16x3_nn_align4::Gemm>;
return launchKernel<
float,
cutlass_tensorop_s1688f16gemm_64x64_16x10_nn_align4::Gemm>;
}
if (K == 32 && N == 32) {
if (M >= 10000)
return launchKernel<
float,
cutlass_tensorop_s1688gemm_64x64_16x3_nn_align4::Gemm>;
return launchKernel<
float,
cutlass_tensorop_s1688f16gemm_64x64_16x10_nn_align4::Gemm>;
}
if (K == 32 && N == 64) {
if (M >= 10000)
return launchKernel<
float,
cutlass_tensorop_s1688gemm_64x64_16x3_nn_align4::Gemm>;
return launchKernel<
float,
cutlass_tensorop_s1688f16gemm_64x64_16x10_nn_align4::Gemm>;
}
if (K == 64 && N == 64) {
if (M >= 15000)
return launchKernel<
float,
cutlass_tensorop_s1688gemm_64x64_16x3_nn_align4::Gemm>;
return launchKernel<
float,
cutlass_tensorop_s1688f16gemm_64x64_16x10_nn_align4::Gemm>;
}
if (K == 128) {
if (M >= 100000)
return launchKernel<
float,
cutlass_tensorop_s1688f16gemm_128x128_16x3_nn_align4::Gemm>;
if (M >= 5000)
return launchKernel<
float,
cutlass_tensorop_s1688f16gemm_256x64_16x4_nn_align4::Gemm>;
return launchKernel<
float,
cutlass_tensorop_s1688tf32gemm_256x128_16x3_nn_align4::Gemm>;
}
if (N == 128) {
if (M >= 100000)
return launchKernel<
float,
cutlass_tensorop_s1688tf32gemm_256x128_16x3_nn_align4::Gemm>;
if (M >= 5000)
return launchKernel<
float,
cutlass_tensorop_s1688f16gemm_128x128_16x3_nn_align4::Gemm>;
return launchKernel<
float,
cutlass_tensorop_s1688f16gemm_64x128_16x6_nn_align4::Gemm>;
}
return launchKernel<
float,
cutlass_tensorop_s1688f16gemm_64x64_16x10_nn_align4::Gemm>;
}
fp64_gather_gemm_scatter getBestFp64Kernel(const int M,
const int N,
const int K) {
if (K == 4 && N == 16) {
return launchKernel<double,
cutlass_tensorop_d884gemm_16x32_16x5_nn_align1::Gemm>;
}
if (K == 16 && N == 16) {
if (M >= 10000)
return launchKernel<double,
cutlass_tensorop_d884gemm_32x16_16x5_nn_align1::Gemm>;
return launchKernel<double,
cutlass_tensorop_d884gemm_16x32_16x5_nn_align1::Gemm>;
}
if (K == 16 && N == 32) {
return launchKernel<double,
cutlass_tensorop_d884gemm_32x16_16x5_nn_align1::Gemm>;
}
if (K == 32 && N == 32) {
return launchKernel<double,
cutlass_tensorop_d884gemm_16x32_16x5_nn_align1::Gemm>;
}
if (K == 32 && N == 64) {
return launchKernel<double,
cutlass_tensorop_d884gemm_32x16_16x5_nn_align1::Gemm>;
}
if (K == 64 && N == 64) {
return launchKernel<double,
cutlass_tensorop_d884gemm_32x16_16x5_nn_align1::Gemm>;
}
return launchKernel<double,
cutlass_tensorop_d884gemm_32x16_16x5_nn_align1::Gemm>;
}
} // namespace sparse
} // namespace phi
#endif
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#ifdef PADDLE_WITH_CUTLASS
#include "cutlass/arch/mma.h"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/device/gemm_grouped.h"
#include "cutlass/gemm/device/gemm_universal.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/util/device_memory.h"
#include "examples/common/helper.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
namespace phi {
namespace sparse {
typedef void (*fp16_gather_gemm_scatter)(const GPUContext& dev_ctx,
const cutlass::half_t* const a,
const cutlass::half_t* const b,
const cutlass::half_t* const c,
cutlass::half_t* const d,
const int m,
const int n,
const int k,
const int32_t* a_indices,
const int32_t* c_d_indices,
cutlass::half_t const alpha,
cutlass::half_t const beta);
typedef void (*fp32_gather_gemm_scatter)(const GPUContext& dev_ctx,
const float* const a,
const float* const b,
const float* const c,
float* const d,
const int m,
const int n,
const int k,
const int32_t* a_indices,
const int32_t* c_d_indices,
float const alpha,
float const beta);
typedef void (*fp64_gather_gemm_scatter)(const GPUContext& dev_ctx,
const double* const a,
const double* const b,
const double* const c,
double* const d,
const int m,
const int n,
const int k,
const int32_t* a_indices,
const int32_t* c_d_indices,
double const alpha,
double const beta);
fp16_gather_gemm_scatter getBestFp16Kernel(const int M,
const int K,
const int N);
fp32_gather_gemm_scatter getBestFp32Kernel(const int M,
const int K,
const int N,
const int SM);
fp64_gather_gemm_scatter getBestFp64Kernel(const int M,
const int K,
const int N);
template <typename T, typename Gemm>
void launchKernel(const GPUContext& dev_ctx,
const T* const a,
const T* const b,
const T* const c,
T* const d,
const int m,
const int n,
const int k,
const int32_t* a_indices,
const int32_t* c_d_indices,
T const alpha,
T const beta) {
cutlass::gemm::GemmCoord problem_size_real({m, n, k});
int split_k_slices = 1;
typename Gemm::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
problem_size_real,
split_k_slices,
{alpha, beta},
a,
b,
c,
d,
cutlass::layout::RowMajor().capacity(problem_size_real.mk()),
cutlass::layout::RowMajor().capacity(problem_size_real.kn()),
cutlass::layout::RowMajor().capacity(problem_size_real.mn()),
cutlass::layout::RowMajor().capacity(problem_size_real.mn()),
problem_size_real.k(),
problem_size_real.n(),
problem_size_real.n(),
problem_size_real.n(),
a_indices,
nullptr,
c_d_indices};
size_t workspace_size = Gemm::get_workspace_size(arguments);
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
Gemm gemm_op;
cutlass::Status status = gemm_op.can_implement(arguments);
CUTLASS_CHECK(status);
status = gemm_op.initialize(arguments, workspace.get());
CUTLASS_CHECK(status);
gemm_op(dev_ctx.stream());
}
struct cutlass_tensorop_h1688gemm_128x64_32x2_nn_align8 {
using Gemm = cutlass::gemm::device::GemmUniversal<
cutlass::half_t,
cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm75,
cutlass::gemm::GemmShape<128, 64, 32>,
cutlass::gemm::GemmShape<64, 32, 32>,
cutlass::gemm::GemmShape<16, 8, 8>,
cutlass::epilogue::thread::LinearCombination<cutlass::half_t,
8,
cutlass::half_t,
cutlass::half_t>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>,
2,
8,
8,
cutlass::arch::OpMultiplyAdd,
cutlass::ComplexTransform::kNone,
cutlass::ComplexTransform::kNone,
true,
false,
true>;
};
struct cutlass_tensorop_h1688gemm_64x128_32x2_nn_align8 {
using Gemm = cutlass::gemm::device::GemmUniversal<
cutlass::half_t,
cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm75,
cutlass::gemm::GemmShape<64, 128, 32>,
cutlass::gemm::GemmShape<32, 64, 32>,
cutlass::gemm::GemmShape<16, 8, 8>,
cutlass::epilogue::thread::LinearCombination<cutlass::half_t,
8,
cutlass::half_t,
cutlass::half_t>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>,
2,
8,
8,
cutlass::arch::OpMultiplyAdd,
cutlass::ComplexTransform::kNone,
cutlass::ComplexTransform::kNone,
true,
false,
true>;
};
struct cutlass_tensorop_h1688gemm_128x64_32x2_nn_align4 {
using Gemm = cutlass::gemm::device::GemmUniversal<
cutlass::half_t,
cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm75,
cutlass::gemm::GemmShape<128, 64, 32>,
cutlass::gemm::GemmShape<64, 32, 32>,
cutlass::gemm::GemmShape<16, 8, 8>,
cutlass::epilogue::thread::LinearCombination<cutlass::half_t,
4,
cutlass::half_t,
cutlass::half_t>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>,
2,
4,
4,
cutlass::arch::OpMultiplyAdd,
cutlass::ComplexTransform::kNone,
cutlass::ComplexTransform::kNone,
true,
false,
true>;
};
struct cutlass_tensorop_h1688gemm_64x64_32x2_nn_align4 {
using Gemm = cutlass::gemm::device::GemmUniversal<
cutlass::half_t,
cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm75,
cutlass::gemm::GemmShape<64, 64, 32>,
cutlass::gemm::GemmShape<32, 32, 32>,
cutlass::gemm::GemmShape<16, 8, 8>,
cutlass::epilogue::thread::LinearCombination<cutlass::half_t,
4,
cutlass::half_t,
cutlass::half_t>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>,
2,
4,
4,
cutlass::arch::OpMultiplyAdd,
cutlass::ComplexTransform::kNone,
cutlass::ComplexTransform::kNone,
true,
false,
true>;
};
struct cutlass_tensorop_h1688gemm_64x64_32x2_nn_align8 {
using Gemm = cutlass::gemm::device::GemmUniversal<
cutlass::half_t,
cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm75,
cutlass::gemm::GemmShape<64, 64, 32>,
cutlass::gemm::GemmShape<32, 32, 32>,
cutlass::gemm::GemmShape<16, 8, 8>,
cutlass::epilogue::thread::LinearCombination<cutlass::half_t,
8,
cutlass::half_t,
cutlass::half_t>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>,
2,
8,
8,
cutlass::arch::OpMultiplyAdd,
cutlass::ComplexTransform::kNone,
cutlass::ComplexTransform::kNone,
true,
false,
true>;
};
struct cutlass_tensorop_h16816gemm_64x64_64x5_nn_align8 {
using Gemm = cutlass::gemm::device::GemmUniversal<
cutlass::half_t,
cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
cutlass::gemm::GemmShape<64, 64, 64>,
cutlass::gemm::GemmShape<32, 32, 64>,
cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombination<cutlass::half_t,
8,
cutlass::half_t,
cutlass::half_t>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>,
5,
8,
8,
cutlass::arch::OpMultiplyAdd,
cutlass::ComplexTransform::kNone,
cutlass::ComplexTransform::kNone,
true,
false,
true>;
};
struct cutlass_tensorop_f16_s1688gemm_f16_64x128_32x2_nn_align8 {
using Gemm = cutlass::gemm::device::GemmUniversal<
cutlass::half_t,
cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::layout::RowMajor,
float,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm75,
cutlass::gemm::GemmShape<64, 128, 32>,
cutlass::gemm::GemmShape<32, 64, 32>,
cutlass::gemm::GemmShape<16, 8, 8>,
cutlass::epilogue::thread::
LinearCombination<cutlass::half_t, 8, float, float>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>,
2,
8,
8,
cutlass::arch::OpMultiplyAdd,
cutlass::ComplexTransform::kNone,
cutlass::ComplexTransform::kNone,
true,
false,
true>;
};
struct cutlass_tensorop_f16_s1688gemm_f16_64x64_32x2_nn_align8 {
using Gemm = cutlass::gemm::device::GemmUniversal<
cutlass::half_t,
cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::layout::RowMajor,
float,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm75,
cutlass::gemm::GemmShape<64, 64, 32>,
cutlass::gemm::GemmShape<32, 32, 32>,
cutlass::gemm::GemmShape<16, 8, 8>,
cutlass::epilogue::thread::
LinearCombination<cutlass::half_t, 8, float, float>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>,
2,
8,
8,
cutlass::arch::OpMultiplyAdd,
cutlass::ComplexTransform::kNone,
cutlass::ComplexTransform::kNone,
true,
false,
true>;
};
struct cutlass_tensorop_s1688f16gemm_64x64_16x10_nn_align4 {
using Gemm = cutlass::gemm::device::GemmUniversal<
float,
cutlass::layout::RowMajor,
float,
cutlass::layout::RowMajor,
float,
cutlass::layout::RowMajor,
float,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
cutlass::gemm::GemmShape<64, 64, 16>,
cutlass::gemm::GemmShape<32, 32, 16>,
cutlass::gemm::GemmShape<16, 8, 8>,
cutlass::epilogue::thread::LinearCombination<float, 4, float, float>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>,
10,
4,
4,
cutlass::arch::OpMultiplyAddFastF16,
cutlass::ComplexTransform::kNone,
cutlass::ComplexTransform::kNone,
true,
false,
true>;
};
struct cutlass_tensorop_s1688f16gemm_128x128_16x3_nn_align4 {
using Gemm = cutlass::gemm::device::GemmUniversal<
float,
cutlass::layout::RowMajor,
float,
cutlass::layout::RowMajor,
float,
cutlass::layout::RowMajor,
float,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
cutlass::gemm::GemmShape<128, 128, 16>,
cutlass::gemm::GemmShape<64, 64, 16>,
cutlass::gemm::GemmShape<16, 8, 8>,
cutlass::epilogue::thread::LinearCombination<float, 4, float, float>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>,
3,
4,
4,
cutlass::arch::OpMultiplyAddFastF16,
cutlass::ComplexTransform::kNone,
cutlass::ComplexTransform::kNone,
true,
false,
true>;
};
struct cutlass_tensorop_s1688f16gemm_256x64_16x4_nn_align4 {
using Gemm = cutlass::gemm::device::GemmUniversal<
float,
cutlass::layout::RowMajor,
float,
cutlass::layout::RowMajor,
float,
cutlass::layout::RowMajor,
float,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
cutlass::gemm::GemmShape<256, 64, 16>,
cutlass::gemm::GemmShape<64, 64, 16>,
cutlass::gemm::GemmShape<16, 8, 8>,
cutlass::epilogue::thread::LinearCombination<float, 4, float, float>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>,
4,
4,
4,
cutlass::arch::OpMultiplyAddFastF16,
cutlass::ComplexTransform::kNone,
cutlass::ComplexTransform::kNone,
true,
false,
true>;
};
struct cutlass_tensorop_s1688tf32gemm_256x128_16x3_nn_align4 {
using Gemm = cutlass::gemm::device::GemmUniversal<
float,
cutlass::layout::RowMajor,
float,
cutlass::layout::RowMajor,
float,
cutlass::layout::RowMajor,
float,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
cutlass::gemm::GemmShape<256, 128, 16>,
cutlass::gemm::GemmShape<64, 64, 16>,
cutlass::gemm::GemmShape<16, 8, 8>,
cutlass::epilogue::thread::LinearCombination<float, 4, float, float>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>,
3,
4,
4,
cutlass::arch::OpMultiplyAdd,
cutlass::ComplexTransform::kNone,
cutlass::ComplexTransform::kNone,
true,
false,
true>;
};
struct cutlass_tensorop_s1688f16gemm_64x128_16x6_nn_align4 {
using Gemm = cutlass::gemm::device::GemmUniversal<
float,
cutlass::layout::RowMajor,
float,
cutlass::layout::RowMajor,
float,
cutlass::layout::RowMajor,
float,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
cutlass::gemm::GemmShape<64, 128, 16>,
cutlass::gemm::GemmShape<32, 64, 16>,
cutlass::gemm::GemmShape<16, 8, 8>,
cutlass::epilogue::thread::LinearCombination<float, 4, float, float>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>,
6,
4,
4,
cutlass::arch::OpMultiplyAddFastF16,
cutlass::ComplexTransform::kNone,
cutlass::ComplexTransform::kNone,
true,
false,
true>;
};
struct cutlass_tensorop_s1688gemm_64x64_16x3_nn_align4 {
using Gemm = cutlass::gemm::device::GemmUniversal<
float,
cutlass::layout::RowMajor,
float,
cutlass::layout::RowMajor,
float,
cutlass::layout::RowMajor,
float,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
cutlass::gemm::GemmShape<64, 64, 16>,
cutlass::gemm::GemmShape<32, 32, 16>,
cutlass::gemm::GemmShape<16, 8, 8>,
cutlass::epilogue::thread::LinearCombination<float, 4, float, float>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>,
3,
4,
4,
cutlass::arch::OpMultiplyAddFastF32,
cutlass::ComplexTransform::kNone,
cutlass::ComplexTransform::kNone,
true,
false,
true>;
};
struct cutlass_tensorop_d884gemm_16x32_16x5_nn_align1 {
using Gemm = cutlass::gemm::device::GemmUniversal<
double,
cutlass::layout::RowMajor,
double,
cutlass::layout::RowMajor,
double,
cutlass::layout::RowMajor,
double,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
cutlass::gemm::GemmShape<16, 32, 16>,
cutlass::gemm::GemmShape<16, 16, 16>,
cutlass::gemm::GemmShape<8, 8, 4>,
cutlass::epilogue::thread::LinearCombination<double, 1, double, double>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>,
5,
1,
1,
cutlass::arch::OpMultiplyAdd,
cutlass::ComplexTransform::kNone,
cutlass::ComplexTransform::kNone,
true,
false,
true>;
};
struct cutlass_tensorop_d884gemm_32x16_16x5_nn_align1 {
using Gemm = cutlass::gemm::device::GemmUniversal<
double,
cutlass::layout::RowMajor,
double,
cutlass::layout::RowMajor,
double,
cutlass::layout::RowMajor,
double,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
cutlass::gemm::GemmShape<32, 16, 16>,
cutlass::gemm::GemmShape<16, 16, 16>,
cutlass::gemm::GemmShape<8, 8, 4>,
cutlass::epilogue::thread::LinearCombination<double, 1, double, double>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>,
5,
1,
1,
cutlass::arch::OpMultiplyAdd,
cutlass::ComplexTransform::kNone,
cutlass::ComplexTransform::kNone,
true,
false,
true>;
};
// sm75
struct cutlass_tensorop_s1688gemm_f16_64x64_32x2_nn_align4 {
using Gemm = cutlass::gemm::device::GemmUniversal<
cutlass::half_t,
cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::layout::RowMajor,
float,
cutlass::layout::RowMajor,
float,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm75,
cutlass::gemm::GemmShape<64, 64, 32>,
cutlass::gemm::GemmShape<32, 32, 32>,
cutlass::gemm::GemmShape<16, 8, 8>,
cutlass::epilogue::thread::LinearCombination<float, 4, float, float>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>,
2,
8,
8,
cutlass::arch::OpMultiplyAdd>;
};
} // namespace sparse
} // namespace phi
#endif
......@@ -25,6 +25,7 @@ limitations under the License. */
#include "paddle/phi/kernels/funcs/aligned_vector.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/sparse/flatten_indices.cu.h"
#include "paddle/phi/kernels/funcs/sparse/utils.cu.h"
namespace phi {
namespace sparse {
......@@ -118,15 +119,20 @@ void SparseMaskKernel(const Context& dev_ctx,
}
template <typename IntT>
__global__ void MaskTable(const IntT* x_indexs, const int n, int* table) {
__global__ void MaskTable(const IntT* x_indexs,
const int n,
int* index_flags,
int* table) {
CUDA_KERNEL_LOOP_TYPE(i, n, int64_t) {
int index = x_indexs[i];
table[index] = i == 0 ? -1 : i;
phi::funcs::sparse::SetBits(index, index_flags);
table[index] = i;
}
}
template <typename T, typename IntT, int VecSize>
__global__ void MaskCopy(const IntT* mask_indexs,
const int* index_flags,
const int* table,
const int n,
const int stride,
......@@ -135,9 +141,10 @@ __global__ void MaskCopy(const IntT* mask_indexs,
using LoadT = phi::AlignedVector<T, VecSize>;
using StoreT = phi::AlignedVector<T, VecSize>;
CUDA_KERNEL_LOOP_TYPE(i, n, int64_t) {
int j = table[mask_indexs[i]];
if (j != 0) {
if (j == -1) j = 0;
const int mask_index = mask_indexs[i];
const bool flag = phi::funcs::sparse::TestBits(mask_index, index_flags);
if (flag) {
int j = table[mask_index];
for (int k = 0; k < stride; k += VecSize) {
LoadT vec_x;
phi::Load<T, VecSize>(x_values + j * stride + k, &vec_x);
......@@ -217,12 +224,15 @@ void SparseMaskHelperGPUKernel(const GPUContext& dev_ctx,
int table_size = 1;
auto x_dims = x.dims();
for (int i = 0; i < x_dims.size() - 1; i++) {
for (int i = 0; i < sparse_dim; i++) {
table_size *= x_dims[i];
}
DenseTensor table = phi::Empty<int>(dev_ctx, {table_size});
phi::backends::gpu::GpuMemsetAsync(
table.data<int>(), 0, table_size * sizeof(int), dev_ctx.stream());
DenseTensor index_flags = phi::Empty<int>(dev_ctx, {(table_size + 31) / 32});
phi::backends::gpu::GpuMemsetAsync(index_flags.data<int>(),
0,
index_flags.numel() * sizeof(int),
dev_ctx.stream());
const int64_t stride =
x.dims().size() == sparse_dim ? 1 : x.values().dims()[1];
*out = phi::EmptyLike<T>(dev_ctx, x.values());
......@@ -234,8 +244,10 @@ void SparseMaskHelperGPUKernel(const GPUContext& dev_ctx,
MaskTable<<<config.block_per_grid,
config.thread_per_block,
0,
dev_ctx.stream()>>>(
x_indexs_ptr, x_indexs.numel(), table.data<int>());
dev_ctx.stream()>>>(x_indexs_ptr,
x_indexs.numel(),
index_flags.data<int>(),
table.data<int>());
config =
phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, mask_indexs.numel(), 1);
......@@ -246,6 +258,7 @@ void SparseMaskHelperGPUKernel(const GPUContext& dev_ctx,
config.thread_per_block,
0,
dev_ctx.stream()>>>(mask_indexs_ptr,
index_flags.data<int>(),
table.data<int>(),
mask_indexs.numel(),
stride,
......@@ -256,6 +269,7 @@ void SparseMaskHelperGPUKernel(const GPUContext& dev_ctx,
config.thread_per_block,
0,
dev_ctx.stream()>>>(mask_indexs_ptr,
index_flags.data<int>(),
table.data<int>(),
mask_indexs.numel(),
stride,
......
......@@ -64,42 +64,50 @@ class TestSparseElementWiseAPI(unittest.TestCase):
csr_y = s_dense_y.to_sparse_csr()
actual_res = get_actual_res(csr_x, csr_y, op)
actual_res.backward(actual_res)
expect_res = op(dense_x, dense_y)
expect_res.backward(expect_res)
np.testing.assert_allclose(expect_res.numpy(),
actual_res.to_dense().numpy(),
rtol=1e-05,
equal_nan=True)
np.testing.assert_allclose(
expect_res.numpy(),
actual_res.to_dense().numpy(),
rtol=1e-05,
equal_nan=True,
)
if not (op == __truediv__ and dtype in ['int32', 'int64']):
np.testing.assert_allclose(dense_x.grad.numpy(),
csr_x.grad.to_dense().numpy(),
rtol=1e-05,
equal_nan=True)
np.testing.assert_allclose(dense_y.grad.numpy(),
csr_y.grad.to_dense().numpy(),
rtol=1e-05,
equal_nan=True)
actual_res.backward(actual_res)
np.testing.assert_allclose(
dense_x.grad.numpy(),
csr_x.grad.to_dense().numpy(),
rtol=1e-05,
equal_nan=True,
)
np.testing.assert_allclose(
dense_y.grad.numpy(),
csr_y.grad.to_dense().numpy(),
rtol=1e-05,
equal_nan=True,
)
def func_test_coo(self, op):
for sparse_dim in range(len(self.coo_shape) - 1, len(self.coo_shape)):
for dtype in self.support_dtypes:
x = np.random.randint(-255, 255,
size=self.coo_shape).astype(dtype)
y = np.random.randint(-255, 255,
size=self.coo_shape).astype(dtype)
x = np.random.randint(-255, 255, size=self.coo_shape).astype(
dtype
)
y = np.random.randint(-255, 255, size=self.coo_shape).astype(
dtype
)
dense_x = paddle.to_tensor(x, dtype=dtype, stop_gradient=False)
dense_y = paddle.to_tensor(y, dtype=dtype, stop_gradient=False)
s_dense_x = paddle.to_tensor(x,
dtype=dtype,
stop_gradient=False)
s_dense_y = paddle.to_tensor(y,
dtype=dtype,
stop_gradient=False)
s_dense_x = paddle.to_tensor(
x, dtype=dtype, stop_gradient=False
)
s_dense_y = paddle.to_tensor(
y, dtype=dtype, stop_gradient=False
)
coo_x = s_dense_x.to_sparse_coo(sparse_dim)
coo_y = s_dense_y.to_sparse_coo(sparse_dim)
......@@ -109,18 +117,24 @@ class TestSparseElementWiseAPI(unittest.TestCase):
expect_res = op(dense_x, dense_y)
expect_res.backward(expect_res)
np.testing.assert_allclose(expect_res.numpy(),
actual_res.to_dense().numpy(),
rtol=1e-05,
equal_nan=True)
np.testing.assert_allclose(dense_x.grad.numpy(),
coo_x.grad.to_dense().numpy(),
rtol=1e-05,
equal_nan=True)
np.testing.assert_allclose(dense_y.grad.numpy(),
coo_y.grad.to_dense().numpy(),
rtol=1e-05,
equal_nan=True)
np.testing.assert_allclose(
expect_res.numpy(),
actual_res.to_dense().numpy(),
rtol=1e-05,
equal_nan=True,
)
np.testing.assert_allclose(
dense_x.grad.numpy(),
coo_x.grad.to_dense().numpy(),
rtol=1e-05,
equal_nan=True,
)
np.testing.assert_allclose(
dense_y.grad.numpy(),
coo_y.grad.to_dense().numpy(),
rtol=1e-05,
equal_nan=True,
)
def test_support_dtypes_csr(self):
paddle.device.set_device('cpu')
......@@ -140,38 +154,37 @@ class TestSparseElementWiseAPI(unittest.TestCase):
values2_data = [[1.0], [2.0]]
shape = [2, 4, 2]
sp_a = sparse.sparse_coo_tensor(indices_data,
values1_data,
shape,
stop_gradient=False)
sp_b = sparse.sparse_coo_tensor(indices_data,
values2_data,
shape,
stop_gradient=False)
sp_a = sparse.sparse_coo_tensor(
indices_data, values1_data, shape, stop_gradient=False
)
sp_b = sparse.sparse_coo_tensor(
indices_data, values2_data, shape, stop_gradient=False
)
values1 = paddle.to_tensor(values1_data, stop_gradient=False)
values2 = paddle.to_tensor(values2_data, stop_gradient=False)
#c.values() = a.values() + b.values()
# c.values() = a.values() + b.values()
sp_c = sparse.add(sp_a, sp_b)
sp_c.backward()
ref_c = values1 + values2
ref_c.backward()
np.testing.assert_allclose(sp_c.values().numpy(), ref_c.numpy())
np.testing.assert_allclose(sp_a.grad.values().numpy(),
values1.grad.numpy())
np.testing.assert_allclose(sp_b.grad.values().numpy(),
values2.grad.numpy())
np.testing.assert_allclose(
sp_a.grad.values().numpy(), values1.grad.numpy()
)
np.testing.assert_allclose(
sp_b.grad.values().numpy(), values2.grad.numpy()
)
def test_add_bias(self):
indices_data = [[0, 1], [0, 3]]
values_data = [[1.0, 1.0], [2.0, 2.0]]
shape = [2, 4, 2]
sp_a = sparse.sparse_coo_tensor(indices_data,
values_data,
shape,
stop_gradient=False)
sp_a = sparse.sparse_coo_tensor(
indices_data, values_data, shape, stop_gradient=False
)
bias_values = [1.0, 2.0]
......@@ -179,14 +192,15 @@ class TestSparseElementWiseAPI(unittest.TestCase):
values2 = paddle.to_tensor(bias_values, stop_gradient=False)
values3 = paddle.to_tensor(bias_values, stop_gradient=False)
#c.values() = a.values() + b
# c.values() = a.values() + b
sp_c = sparse.add(sp_a, values2)
sp_c.backward()
ref_c = values1 + values3
ref_c.backward()
np.testing.assert_allclose(sp_c.values().numpy(), ref_c.numpy())
np.testing.assert_allclose(sp_a.grad.values().numpy(),
values1.grad.numpy())
np.testing.assert_allclose(
sp_a.grad.values().numpy(), values1.grad.numpy()
)
np.testing.assert_allclose(values2.grad.numpy(), values3.grad.numpy())
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册