diff --git a/cmake/external/cutlass.cmake b/cmake/external/cutlass.cmake new file mode 100644 index 0000000000000000000000000000000000000000..a80a729a13957c9648da70380b41d78e3de662f2 --- /dev/null +++ b/cmake/external/cutlass.cmake @@ -0,0 +1,43 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +include(ExternalProject) + +set(CUTLASS_PREFIX_DIR ${THIRD_PARTY_PATH}/cutlass) + +set(CUTLASS_REPOSITORY https://github.com/NVIDIA/cutlass.git) +set(CUTLASS_TAG v2.9.1) + +include_directories("${THIRD_PARTY_PATH}/cutlass/src/extern_cutlass/") +include_directories("${THIRD_PARTY_PATH}/cutlass/src/extern_cutlass/include/") +include_directories( + "${THIRD_PARTY_PATH}/cutlass/src/extern_cutlass/tools/util/include/") + +add_definitions("-DPADDLE_WITH_CUTLASS") + +ExternalProject_Add( + extern_cutlass + ${EXTERNAL_PROJECT_LOG_ARGS} ${SHALLOW_CLONE} + GIT_REPOSITORY ${CUTLASS_REPOSITORY} + GIT_TAG "${CUTLASS_TAG}" + PREFIX ${CUTLASS_PREFIX_DIR} + UPDATE_COMMAND "" + CONFIGURE_COMMAND "" + BUILD_COMMAND "" + INSTALL_COMMAND "" + TEST_COMMAND "") + +add_library(cutlass INTERFACE) + +add_dependencies(cutlass extern_cutlass) diff --git a/cmake/third_party.cmake b/cmake/third_party.cmake index bfba3dfbac404837faaccc0fba5b2672f7190c12..02c5bc3602f737f3dea9777971fb49b4e3682d87 100755 --- a/cmake/third_party.cmake +++ b/cmake/third_party.cmake @@ -492,4 +492,14 @@ if(WITH_CUSPARSELT) list(APPEND third_party_deps extern_cusparselt) endif() +if(WITH_GPU + AND NOT WITH_ARM + AND NOT WIN32 + AND NOT APPLE) + if(${CMAKE_CUDA_COMPILER_VERSION} GREATER_EQUAL 11.0) + include(external/cutlass) # download, build, install cusparselt + list(APPEND third_party_deps extern_cutlass) + endif() +endif() + add_custom_target(third_party ALL DEPENDS ${third_party_deps}) diff --git a/paddle/phi/kernels/funcs/norm_utils.h b/paddle/phi/kernels/funcs/norm_utils.h index 2d0a879e41c783a801021db38711d997f62011b4..5c898549b353ead9856624ff5de556b7e8440c10 100644 --- a/paddle/phi/kernels/funcs/norm_utils.h +++ b/paddle/phi/kernels/funcs/norm_utils.h @@ -18,6 +18,10 @@ limitations under the License. 
*/ namespace phi { namespace funcs { +#define CUDNN_PER_ACTIVATION_THRESHOLD 10240 +#define CUDNN_SPATIAL_THRESHOLD_TRAIN 880801 +#define CUDNN_SPATIAL_THRESHOLD_EVAL 65535 + inline void ExtractNCWHD(const phi::DDim &dims, const DataLayout &data_layout, int *N, diff --git a/paddle/phi/kernels/funcs/sparse/utils.cu.h b/paddle/phi/kernels/funcs/sparse/utils.cu.h index 074fe1ca420497689cf7d6942bfe9c2709e5b191..f3b742dfc38cd516503be2993844b32de8b9cd2f 100644 --- a/paddle/phi/kernels/funcs/sparse/utils.cu.h +++ b/paddle/phi/kernels/funcs/sparse/utils.cu.h @@ -26,6 +26,19 @@ __global__ void DistanceKernel(const T* start, const T* end, T* distance) { } } +inline __device__ bool SetBits(const int value, int* ptr) { + const int index = value >> 5; + const int mask = 1 << (value & 31); + const int old = atomicOr(ptr + index, mask); + return (mask & old) != 0; +} + +inline __device__ bool TestBits(const int value, const int* ptr) { + const int index = value >> 5; + const int mask = 1 << (value & 31); + return (mask & ptr[index]) != 0; +} + } // namespace sparse } // namespace funcs } // namespace phi diff --git a/paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu b/paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu index cfd04cf6a8ef8db1904343fbd679e0aec6276231..5acccdfcea3899826acf83320994937173c52b2b 100644 --- a/paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu @@ -852,15 +852,17 @@ void BatchNormGradRawKernel(const Context &ctx, // ctx.GetPlace()), // epsilon, saved_mean_data, saved_var_data)); #else - // CUDNN only support small batch size - // const size_t CUDNN_PER_ACTIVATION_THRESHOLD = 131070; - const size_t CUDNN_PER_ACTIVATION_THRESHOLD = 10240; - const size_t CUDNN_SPATIAL_THRESHOLD = 880801; - const bool use_native_kernel = - ((x_dims.size() == 2 && N >= CUDNN_PER_ACTIVATION_THRESHOLD) || - (x_dims.size() == 3 && N >= CUDNN_SPATIAL_THRESHOLD)); - if (use_native_kernel) { - if (x_dims.size() == 2) { + } + // CUDNN only support small batch size + bool use_native_nhwc = + d_x ? 
(x_dims.size() == 4 && compute_format == DataLayout::kNHWC) + : false; + const bool use_native_kernel = + ((x_dims.size() == 2 && N >= CUDNN_PER_ACTIVATION_THRESHOLD) || + (x_dims.size() == 3 && N >= CUDNN_SPATIAL_THRESHOLD_TRAIN)); + if (use_native_nhwc || (d_x && d_scale && d_bias)) { + if (use_native_kernel || use_native_nhwc) { + if (x_dims.size() == 2 || use_native_nhwc) { dim3 block; dim3 grid; const int block_size = 512; diff --git a/paddle/phi/kernels/gpu/batch_norm_kernel.cu b/paddle/phi/kernels/gpu/batch_norm_kernel.cu index 5c6fd04c15e68b9d7b5ad0a944b4edd98d99f723..ec13e0167b2ce54eca8dac7d2ff96269537b6a21 100644 --- a/paddle/phi/kernels/gpu/batch_norm_kernel.cu +++ b/paddle/phi/kernels/gpu/batch_norm_kernel.cu @@ -72,6 +72,40 @@ static __global__ void BNForwardInference(const T *x, } } +template +static __global__ void InverseVariance(const BatchNormParamType *variance, + const double epsilon, + const int C, + BatchNormParamType *inv_variance) { + int tid = threadIdx.x + blockIdx.x * blockDim.x; + if (tid < C) { + inv_variance[tid] = 1 / sqrt(variance[tid] + epsilon); + } +} + +template +static __global__ void BN1DForwardInference( + const T *x, + const BatchNormParamType *mean, + const BatchNormParamType *inv_variance, + const BatchNormParamType *scale, + const BatchNormParamType *bias, + const int C, + const int N, + const int HxW, + const double epsilon, + T *y) { + int gid = blockIdx.x * blockDim.x + threadIdx.x; + int stride = blockDim.x * gridDim.x; + int num = N * C * HxW; + for (int i = gid; i < num; i += stride) { + const int c = layout == phi::DataLayout::kNCHW ? i / HxW % C : i % C; + BatchNormParamType x_sub_mean = + static_cast>(x[i]) - mean[c]; + y[i] = static_cast(scale[c] * x_sub_mean * inv_variance[c] + bias[c]); + } +} + template static __global__ LAUNCH_BOUNDS(BlockDim) void BNForwardTraining( const T *x, @@ -691,9 +725,6 @@ void BatchNormKernel(const Context &ctx, auto handle = ctx.cudnn_handle(); - const size_t CUDNN_PER_ACTIVATION_THRESHOLD = 10240; - const size_t CUDNN_SPATIAL_THRESHOLD = 880801; - // Now, depending on whether we are running test or not, we have two paths. // It is training mode when it's not reference AND not using pre-trained // model. @@ -797,8 +828,8 @@ void BatchNormKernel(const Context &ctx, // epsilon)); #else const bool use_native_kernel = - ((x_dims.size() == 2 && N >= CUDNN_PER_ACTIVATION_THRESHOLD) || - (x_dims.size() == 3 && N >= CUDNN_SPATIAL_THRESHOLD)); + (x_dims.size() == 2 || + (x_dims.size() == 3 && N >= CUDNN_SPATIAL_THRESHOLD_EVAL)); if (use_native_kernel) { const int block_size = 256; const int grid_size = (N * C * H * W * D + block_size - 1) / block_size; @@ -816,18 +847,43 @@ void BatchNormKernel(const Context &ctx, epsilon, transformed_y.template data()); } else { - BNForwardInference - <<>>( - transformed_x.template data(), - est_mean->template data>(), - est_var->template data>(), - scale.template data>(), - bias.template data>(), - C, - N, - H * W * D, - epsilon, - transformed_y.template data()); + if (x_dims.size() == 2) { + DenseTensor inv_var = phi::Empty>(ctx, {C}); + auto *inv_var_ptr = inv_var.data>(); + const int threads = 512 > C ? 
C : 512; + const int blocks = (C + 511) / 512; + InverseVariance<<>>( + est_var->template data>(), + epsilon, + C, + inv_var_ptr); + BN1DForwardInference + <<>>( + transformed_x.template data(), + est_mean->template data>(), + // est_var->template data>(), + inv_var_ptr, + scale.template data>(), + bias.template data>(), + C, + N, + H * W * D, + epsilon, + transformed_y.template data()); + } else { + BNForwardInference + <<>>( + transformed_x.template data(), + est_mean->template data>(), + est_var->template data>(), + scale.template data>(), + bias.template data>(), + C, + N, + H * W * D, + epsilon, + transformed_y.template data()); + } } } else { PADDLE_ENFORCE_GPU_SUCCESS( @@ -949,7 +1005,7 @@ void BatchNormKernel(const Context &ctx, // const size_t CUDNN_PER_ACTIVATION_THRESHOLD = 131070; const bool use_native_kernel = ((x_dims.size() == 2 && N >= CUDNN_PER_ACTIVATION_THRESHOLD) || - (x_dims.size() == 3 && N >= CUDNN_SPATIAL_THRESHOLD)); + (x_dims.size() == 3 && N >= CUDNN_SPATIAL_THRESHOLD_TRAIN)); if (use_native_kernel) { dim3 block; dim3 grid; diff --git a/paddle/phi/kernels/sparse/gpu/conv.cu.h b/paddle/phi/kernels/sparse/gpu/conv.cu.h index 161930a06fa854f49209f51f870e8c050ceb6e41..61457e506b22d15a8e3817ce9db501d158538c3d 100644 --- a/paddle/phi/kernels/sparse/gpu/conv.cu.h +++ b/paddle/phi/kernels/sparse/gpu/conv.cu.h @@ -15,8 +15,14 @@ limitations under the License. */ #pragma once #include -#include #include +#ifdef __NVCC__ +#include +#endif +#ifdef __HIPCC__ +#include +namespace cub = hipcub; +#endif #include "paddle/phi/kernels/sparse/conv_kernel.h" #include "paddle/phi/backends/gpu/gpu_context.h" @@ -167,7 +173,7 @@ inline void GatherV2(const GPUContext& dev_ctx, template __global__ void UniqueKernel(const IntT* in_indexs, const int rulebook_len, - int* out_index_table, + int* index_flags, int* out_indexs, int* nnz) { extern __shared__ int cache[]; @@ -182,8 +188,8 @@ __global__ void UniqueKernel(const IntT* in_indexs, if (i < rulebook_len) { // atomicOr only support int int index = static_cast(in_indexs[i]); - int flag = atomicOr(out_index_table + index, 1); - if (flag == 0) { + const bool flag = phi::funcs::sparse::SetBits(index, index_flags); + if (!flag) { int j = atomicAdd(&count, 1); cache[j] = index; } @@ -199,6 +205,88 @@ __global__ void UniqueKernel(const IntT* in_indexs, } } +inline __device__ uint32_t BitCount(const uint32_t data) { + uint32_t count = data; + count = (count & 0x55555555) + ((count >> 1) & 0x55555555); + count = (count & 0x33333333) + ((count >> 2) & 0x33333333); + count = (count & 0x0f0f0f0f) + ((count >> 4) & 0x0f0f0f0f); + count = (count & 0x00ff00ff) + ((count >> 8) & 0x00ff00ff); + count = (count & 0x0000ffff) + ((count >> 16) & 0x0000ffff); + return count; +} + +static __global__ void GetOutIndexsCounter(const int* flags, + const int n, + int* out) { + int tid = threadIdx.x + blockDim.x * blockIdx.x; + __shared__ int block_count; + if (threadIdx.x == 0) { + block_count = 0; + } + __syncthreads(); + + if (tid < n) { + // get the count of 1 in flags[tid] + uint32_t count = BitCount(static_cast(flags[tid])); + // add to block_count + // TODO(zhangkaihuo): replace with block reduce_sum + atomicAdd(&block_count, static_cast(count)); + } + __syncthreads(); + // write to out + if (threadIdx.x == 0) { + out[blockIdx.x] = block_count; + } +} + +template +__global__ void GetOutIndexs(const int* flags, + const int n, + const int* offsets, + const int out_nnz, + int* out) { + int tid = threadIdx.x + blockDim.x * blockIdx.x; + __shared__ int 
block_counts[BS]; + __shared__ int block_outs[BS * 32]; + + int count = 0; + + if (tid < n) { + // get the count of 1 in flags[tid] + int flag = flags[tid]; + count = BitCount(static_cast(flag)); + } + + // call block prefix_sum + // using namespace cub; + typedef cub::BlockScan BlockScan; + __shared__ typename BlockScan::TempStorage temp_storage; + BlockScan(temp_storage).ExclusiveSum(count, count); + __syncthreads(); + + // write index to out + if (tid < n) { + // get the count of 1 in flags[tid] + int flag = flags[tid]; + // int j = block_counts[threadIdx.x]; + int j = count; + // TODO(zhangkaihuo): opt the loop + for (int i = 0; i < 32; ++i) { + if ((1 & (flag >> i)) == 1) { + block_outs[j++] = (tid << 5) + i; + } + } + } + + __syncthreads(); + // write to block_outs + int start = offsets[blockIdx.x]; + int end = blockIdx.x == gridDim.x - 1 ? out_nnz : offsets[blockIdx.x + 1]; + for (int i = threadIdx.x; i < end - start; i += blockDim.x) { + out[start + i] = block_outs[i]; + } +} + template __global__ void GroupIndexs(const int* out_index_table, const int n, @@ -284,7 +372,6 @@ __global__ void ProductRuleBookKernel(const T* x_indices, atomicAdd(&counter_buf[kernel_index], 1); kernel_i = kernel_index; } - // rulebook[kernel_index * non_zero_num + i] = kernel_i; rulebook[kernel_index * non_zero_num + i] = in_i; rulebook[kernel_index * non_zero_num + offset + i] = out_index; ++kernel_index; @@ -299,17 +386,19 @@ __global__ void ProductRuleBookKernel(const T* x_indices, } template -__global__ void GetOutIndexTable(const IntT* indices, - const IntT non_zero_num, - const Dims4D dims, - int* out_index_table) { +__global__ void GetOutIndexTable1(const IntT* indices, + const IntT non_zero_num, + const Dims4D dims, + int* index_flags, + int* out_index_table) { CUDA_KERNEL_LOOP_TYPE(i, non_zero_num, int64_t) { IntT batch = indices[i]; IntT in_z = indices[i + non_zero_num]; IntT in_y = indices[i + 2 * non_zero_num]; IntT in_x = indices[i + 3 * non_zero_num]; IntT index = PointToIndex(batch, in_x, in_y, in_z, dims); - out_index_table[index] = i == 0 ? -1 : i; + phi::funcs::sparse::SetBits(index, index_flags); + out_index_table[index] = i; } } @@ -375,6 +464,7 @@ __global__ void ProductSubmRuleBookKernel(const T* x_indices, const Dims4D paddings, const Dims4D dilations, const Dims4D strides, + const int* index_flags, const int* out_index_table, T* rulebook, int* counter) { @@ -417,9 +507,10 @@ __global__ void ProductSubmRuleBookKernel(const T* x_indices, T out_x = (in_x + paddings[3] - kx * dilations[3]) / strides[3]; out_index = phi::funcs::sparse::PointToIndex( batch, out_x, out_y, out_z, out_dims); - int real_out_index = out_index_table[out_index]; - if (real_out_index != 0) { - real_out_index = real_out_index == -1 ? 
0 : real_out_index; + const bool flag = + phi::funcs::sparse::TestBits(out_index, index_flags); + if (flag) { + int real_out_index = out_index_table[out_index]; in_i = i; int buf_i = atomicAdd(&counter_buf[kernel_index], 1); kernel_i = kernel_index; @@ -440,7 +531,6 @@ __global__ void ProductSubmRuleBookKernel(const T* x_indices, __syncthreads(); for (int i = 0; i < kernel_size; i++) { if (threadIdx.x < counter_buf[i]) { - // rulebook[i * non_zero_num + counter_buf2[i] + threadIdx.x] = i; rulebook[i * non_zero_num + counter_buf2[i] + threadIdx.x] = rulebook_buf[i * blockDim.x + threadIdx.x]; rulebook[i * non_zero_num + offset + counter_buf2[i] + threadIdx.x] = @@ -575,12 +665,18 @@ int ProductRuleBook(const Context& dev_ctx, DenseTensorMeta rulebook_meta( indices_dtype, {rulebook_rows, rulebook_cols}, DataLayout::NCHW); - int64_t table_size = 1; + int table_size = 1; for (int i = 0; i < out_dims.size() - 1; i++) { table_size *= out_dims[i]; } DenseTensor out_index_table = phi::Empty(dev_ctx, {table_size}); int* out_index_table_ptr = out_index_table.data(); + // index_flags: flag the indices exist or not + int index_flags_size = (table_size + 31) / 32; + DenseTensor index_flags = phi::Empty(dev_ctx, {index_flags_size}); + int* index_flags_ptr = index_flags.data(); + phi::backends::gpu::GpuMemsetAsync( + index_flags_ptr, 0, sizeof(int) * index_flags.numel(), dev_ctx.stream()); if (subm) { DenseTensor tmp_rulebook = phi::Empty(dev_ctx, std::move(rulebook_meta)); @@ -590,16 +686,16 @@ int ProductRuleBook(const Context& dev_ctx, phi::Copy(dev_ctx, x.indices(), dev_ctx.GetPlace(), false, &out_indices); - phi::backends::gpu::GpuMemsetAsync( - out_index_table_ptr, 0, sizeof(int) * table_size, dev_ctx.stream()); - auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, non_zero_num, 1); - GetOutIndexTable<<>>( - out_indices.data(), non_zero_num, d_x_dims, out_index_table_ptr); + GetOutIndexTable1<<>>(out_indices.data(), + non_zero_num, + d_x_dims, + index_flags_ptr, + out_index_table_ptr); size_t cache_size = kernel_size * 2 * sizeof(int) + @@ -625,6 +721,7 @@ int ProductRuleBook(const Context& dev_ctx, d_paddings, d_dilations, d_strides, + index_flags_ptr, out_index_table_ptr, rulebook_ptr, counter_ptr); @@ -695,9 +792,6 @@ int ProductRuleBook(const Context& dev_ctx, int* out_index_ptr = out_index->data(); int* unique_key_ptr = unique_key.data(); - phi::backends::gpu::GpuMemsetAsync( - out_index_table_ptr, 0, sizeof(int) * table_size, dev_ctx.stream()); - phi::backends::gpu::GpuMemsetAsync( unique_key_ptr, 0, sizeof(int), dev_ctx.stream()); @@ -708,7 +802,7 @@ int ProductRuleBook(const Context& dev_ctx, cache_size, dev_ctx.stream()>>>(rulebook_ptr + rulebook_len, rulebook_len, - out_index_table_ptr, + index_flags_ptr, out_index_ptr, unique_key_ptr); @@ -719,13 +813,25 @@ int ProductRuleBook(const Context& dev_ctx, gpuMemcpyDeviceToHost, dev_ctx.stream()); dev_ctx.Wait(); + + const int threads = 256; + const int blocks = (index_flags.numel() + threads - 1) / threads; + GetOutIndexsCounter<<>>( + index_flags_ptr, index_flags.numel(), out_index_table_ptr); #ifdef PADDLE_WITH_HIP - thrust::sort(thrust::hip::par.on(dev_ctx.stream()), + thrust::exclusive_scan(thrust::hip::par.on(dev_ctx.stream()), #else - thrust::sort(thrust::cuda::par.on(dev_ctx.stream()), + thrust::exclusive_scan(thrust::cuda::par.on(dev_ctx.stream()), #endif - out_index_ptr, - out_index_ptr + out_nnz); + out_index_table_ptr, + out_index_table_ptr + blocks, + out_index_table_ptr); + GetOutIndexs + <<>>(index_flags_ptr, + 
index_flags.numel(), + out_index_table_ptr, + out_nnz, + out_index_ptr); const int64_t sparse_dim = 4; phi::DenseTensor out_indices = diff --git a/paddle/phi/kernels/sparse/gpu/conv_kernel.cu b/paddle/phi/kernels/sparse/gpu/conv_kernel.cu index 282033e62e3572050d63eaac356e182b62de6cbe..e6f3ca336491874bb86ffcdb2606ea8841a19a35 100644 --- a/paddle/phi/kernels/sparse/gpu/conv_kernel.cu +++ b/paddle/phi/kernels/sparse/gpu/conv_kernel.cu @@ -22,6 +22,9 @@ limitations under the License. */ #include "paddle/phi/kernels/funcs/scatter.cu.h" #include "paddle/phi/kernels/funcs/sparse/scatter.cu.h" #include "paddle/phi/kernels/sparse/gpu/conv.cu.h" +#ifdef PADDLE_WITH_CUTLASS +#include "paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.h" +#endif #include "glog/logging.h" @@ -120,85 +123,171 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx, dev_ctx, x, key, tmp_rulebook, h_counter, out, rulebook, counter); } - // 2. gather - phi::DenseTensor in_features = - phi::Empty(dev_ctx, {rulebook_len, in_channels}); - phi::DenseTensor out_features = - phi::Empty(dev_ctx, {rulebook_len, out_channels}); - T* in_features_ptr = in_features.data(); - T* out_features_ptr = out_features.data(); - phi::funcs::SetConstant set_zero; - set_zero(dev_ctx, &out_features, static_cast(0.0f)); - - Gather(dev_ctx, - x.values().data(), - rulebook_ptr, - rulebook_len, - in_channels, - in_features_ptr); - - // 3. call gemm for every werght - auto blas = phi::funcs::GetBlas(dev_ctx); - auto* out_values = out->mutable_values(); - T* out_values_ptr = out_values->data(); - set_zero(dev_ctx, out_values, static_cast(0.0f)); - - if (subm) { - auto config = - phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, rulebook_len, 1); - unique_value.ResizeAndAllocate( - {static_cast(out->nnz() * kernel_size)}); - out_index.ResizeAndAllocate({static_cast(rulebook_len)}); - int* out_index_ptr = out_index.data(); - int* unique_value_ptr = unique_value.data(); - phi::backends::gpu::GpuMemsetAsync( - out_index_ptr, 0, sizeof(int) * rulebook_len, dev_ctx.stream()); - GroupIndexs<<>>(rulebook_len, - kernel_size, - rulebook_ptr + rulebook_len, - out_index_ptr, - unique_value_ptr); +#ifdef PADDLE_WITH_CUTLASS + bool cutlass = true; + if (dev_ctx.GetComputeCapability() < 75) cutlass = false; + if (in_channels % 4 != 0 || out_channels % 4 != 0) { + if (std::is_same::value) cutlass = false; + if (std::is_same::value) cutlass = false; } + if (!std::is_same::value) cutlass = false; + if (cutlass) { + auto* out_values = out->mutable_non_zero_elements(); + T* out_values_ptr = out_values->data(); + phi::funcs::SetConstant set_zero; + set_zero(dev_ctx, out_values, static_cast(0.0f)); + + const T* kernel_ptr = kernel.data(); + for (int i = 0; i < kernel_size; i++) { + if (h_counter_ptr[i] <= 0) { + continue; + } - const T* kernel_ptr = kernel.data(); - for (int i = 0; i < kernel_size; i++) { - if (h_counter_ptr[i] <= 0) { - continue; + const int M = h_counter_ptr[i]; + const int K = in_channels; + const int N = out_channels; + const T* tmp_kernel_ptr = kernel_ptr + i * K * N; + const IntT* gather_indices = rulebook_ptr + h_offsets_ptr[i]; + const IntT* scatter_indices = + rulebook_ptr + rulebook_len + h_offsets_ptr[i]; + + if constexpr (std::is_same::value && + std::is_same::value) { + fp16_gather_gemm_scatter gather_gemm_scatter = + getBestFp16Kernel(M, N, K); + gather_gemm_scatter( + dev_ctx, + reinterpret_cast( + x.non_zero_elements().data()), + reinterpret_cast(tmp_kernel_ptr), + reinterpret_cast(out_values_ptr), + reinterpret_cast(out_values_ptr), + M, + N, 
+ K, + static_cast(gather_indices), + static_cast(scatter_indices), + static_cast(1), + static_cast(1)); + } + if constexpr (std::is_same::value && + std::is_same::value) { + fp32_gather_gemm_scatter gather_gemm_scatter = + getBestFp32Kernel(M, N, K, dev_ctx.GetComputeCapability()); + gather_gemm_scatter(dev_ctx, + x.non_zero_elements().data(), + tmp_kernel_ptr, + out_values_ptr, + out_values_ptr, + M, + N, + K, + gather_indices, + scatter_indices, + static_cast(1), + static_cast(1)); + } + if constexpr (std::is_same::value && + std::is_same::value) { + fp64_gather_gemm_scatter gather_gemm_scatter = + getBestFp64Kernel(M, N, K); + gather_gemm_scatter(dev_ctx, + x.non_zero_elements().data(), + tmp_kernel_ptr, + out_values_ptr, + out_values_ptr, + M, + N, + K, + gather_indices, + scatter_indices, + static_cast(1), + static_cast(1)); + } + } + } else { +#endif + if (subm) { + auto config = + phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, rulebook_len, 1); + unique_value.ResizeAndAllocate( + {static_cast(out->nnz() * kernel_size)}); + out_index.ResizeAndAllocate({static_cast(rulebook_len)}); + int* out_index_ptr = out_index.data(); + int* unique_value_ptr = unique_value.data(); + phi::backends::gpu::GpuMemsetAsync( + out_index_ptr, 0, sizeof(int) * rulebook_len, dev_ctx.stream()); + GroupIndexs<<>>(rulebook_len, + kernel_size, + rulebook_ptr + rulebook_len, + out_index_ptr, + unique_value_ptr); } + // 2. gather + phi::DenseTensor in_features = + phi::Empty(dev_ctx, {rulebook_len, in_channels}); + phi::DenseTensor out_features = + phi::Empty(dev_ctx, {rulebook_len, out_channels}); + T* in_features_ptr = in_features.data(); + T* out_features_ptr = out_features.data(); + phi::funcs::SetConstant set_zero; + set_zero(dev_ctx, &out_features, static_cast(0.0f)); - // call gemm: (n, in_channels) * (in_channels, out_channels) - const int M = h_counter_ptr[i]; - const int K = in_channels; - const int N = out_channels; - T* tmp_in_ptr = in_features_ptr + h_offsets_ptr[i] * in_channels; - const T* tmp_kernel_ptr = kernel_ptr + i * K * N; - T* tmp_out_ptr = out_features_ptr + h_offsets_ptr[i] * out_channels; - - blas.GEMM(CblasNoTrans, - CblasNoTrans, - M, - N, - K, - static_cast(1), - tmp_in_ptr, - tmp_kernel_ptr, - static_cast(0), - tmp_out_ptr); - } + Gather(dev_ctx, + x.values().data(), + rulebook_ptr, + rulebook_len, + in_channels, + in_features_ptr); - // 4. scatter - phi::funcs::sparse::ScatterV2(dev_ctx, - out_features_ptr, - out_index.data(), - unique_value.data(), - out->nnz(), - kernel_size, - out_channels, - 1, - out_values_ptr); + // 3. call gemm for every werght + auto blas = phi::funcs::GetBlas(dev_ctx); + auto* out_values = out->mutable_values(); + T* out_values_ptr = out_values->data(); + set_zero(dev_ctx, out_values, static_cast(0.0f)); + + const T* kernel_ptr = kernel.data(); + for (int i = 0; i < kernel_size; i++) { + if (h_counter_ptr[i] <= 0) { + continue; + } + + // call gemm: (n, in_channels) * (in_channels, out_channels) + const int M = h_counter_ptr[i]; + const int K = in_channels; + const int N = out_channels; + T* tmp_in_ptr = in_features_ptr + h_offsets_ptr[i] * in_channels; + const T* tmp_kernel_ptr = kernel_ptr + i * K * N; + T* tmp_out_ptr = out_features_ptr + h_offsets_ptr[i] * out_channels; + + blas.GEMM(CblasNoTrans, + CblasNoTrans, + M, + N, + K, + static_cast(1), + tmp_in_ptr, + tmp_kernel_ptr, + static_cast(0), + tmp_out_ptr); + } + + // 4. 
scatter + phi::funcs::sparse::ScatterV2(dev_ctx, + out_features_ptr, + out_index.data(), + unique_value.data(), + out->nnz(), + kernel_size, + out_channels, + 1, + out_values_ptr); +#ifdef PADDLE_WITH_CUTLASS + } +#endif } /** diff --git a/paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.cu b/paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.cu new file mode 100644 index 0000000000000000000000000000000000000000..cfbaa7f1d63068db4942a1102db9b95b13649c56 --- /dev/null +++ b/paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.cu @@ -0,0 +1,194 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifdef PADDLE_WITH_CUTLASS +#include "paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.h" +namespace phi { +namespace sparse { +fp16_gather_gemm_scatter getBestFp16Kernel(const int M, + const int N, + const int K) { + if (K == 4 && N == 16) { + return launchKernel; + } + if (K == 16 && N == 16) { + return launchKernel; + } + if (K == 16 && N == 32) { + return launchKernel; + } + if (K == 32 && N == 32) { + return launchKernel; + } + if (K == 32 && N == 64) { + return launchKernel; + } + if (K == 64 && N == 64) { + if (M > 100000) + launchKernel< + cutlass::half_t, + cutlass_tensorop_f16_s1688gemm_f16_64x128_32x2_nn_align8::Gemm>; + if (M > 20000) + launchKernel< + cutlass::half_t, + cutlass_tensorop_f16_s1688gemm_f16_64x64_32x2_nn_align8::Gemm>; + if (M > 15000) + return launchKernel< + cutlass::half_t, + cutlass_tensorop_h1688gemm_128x64_32x2_nn_align8::Gemm>; + return launchKernel; + } + if (K == 128) { + if (M >= 5000) + return launchKernel< + cutlass::half_t, + cutlass_tensorop_h1688gemm_64x64_32x2_nn_align8::Gemm>; + return launchKernel; + } + if (N == 128) { + return launchKernel; + } + return launchKernel; +} +fp32_gather_gemm_scatter getBestFp32Kernel(const int M, + const int N, + const int K, + const int SM) { + if (SM == 75) { + return launchKernel< + float, + cutlass_tensorop_s1688gemm_f16_64x64_32x2_nn_align4::Gemm>; + } + if (K == 4 && N == 16) { + return launchKernel< + float, + cutlass_tensorop_s1688f16gemm_64x64_16x10_nn_align4::Gemm>; + } + if (K == 16 && N == 16) { + return launchKernel< + float, + cutlass_tensorop_s1688f16gemm_64x64_16x10_nn_align4::Gemm>; + } + if (K == 16 && N == 32) { + if (M >= 10000) + return launchKernel< + float, + cutlass_tensorop_s1688gemm_64x64_16x3_nn_align4::Gemm>; + return launchKernel< + float, + cutlass_tensorop_s1688f16gemm_64x64_16x10_nn_align4::Gemm>; + } + if (K == 32 && N == 32) { + if (M >= 10000) + return launchKernel< + float, + cutlass_tensorop_s1688gemm_64x64_16x3_nn_align4::Gemm>; + return launchKernel< + float, + cutlass_tensorop_s1688f16gemm_64x64_16x10_nn_align4::Gemm>; + } + if (K == 32 && N == 64) { + if (M >= 10000) + return launchKernel< + float, + cutlass_tensorop_s1688gemm_64x64_16x3_nn_align4::Gemm>; + return launchKernel< + float, + cutlass_tensorop_s1688f16gemm_64x64_16x10_nn_align4::Gemm>; + } + if (K == 64 && N == 64) { + if (M >= 15000) + return 
launchKernel< + float, + cutlass_tensorop_s1688gemm_64x64_16x3_nn_align4::Gemm>; + return launchKernel< + float, + cutlass_tensorop_s1688f16gemm_64x64_16x10_nn_align4::Gemm>; + } + if (K == 128) { + if (M >= 100000) + return launchKernel< + float, + cutlass_tensorop_s1688f16gemm_128x128_16x3_nn_align4::Gemm>; + if (M >= 5000) + return launchKernel< + float, + cutlass_tensorop_s1688f16gemm_256x64_16x4_nn_align4::Gemm>; + return launchKernel< + float, + cutlass_tensorop_s1688tf32gemm_256x128_16x3_nn_align4::Gemm>; + } + if (N == 128) { + if (M >= 100000) + return launchKernel< + float, + cutlass_tensorop_s1688tf32gemm_256x128_16x3_nn_align4::Gemm>; + if (M >= 5000) + return launchKernel< + float, + cutlass_tensorop_s1688f16gemm_128x128_16x3_nn_align4::Gemm>; + return launchKernel< + float, + cutlass_tensorop_s1688f16gemm_64x128_16x6_nn_align4::Gemm>; + } + return launchKernel< + float, + cutlass_tensorop_s1688f16gemm_64x64_16x10_nn_align4::Gemm>; +} +fp64_gather_gemm_scatter getBestFp64Kernel(const int M, + const int N, + const int K) { + if (K == 4 && N == 16) { + return launchKernel; + } + if (K == 16 && N == 16) { + if (M >= 10000) + return launchKernel; + return launchKernel; + } + if (K == 16 && N == 32) { + return launchKernel; + } + if (K == 32 && N == 32) { + return launchKernel; + } + if (K == 32 && N == 64) { + return launchKernel; + } + if (K == 64 && N == 64) { + return launchKernel; + } + return launchKernel; +} + +} // namespace sparse +} // namespace phi +#endif diff --git a/paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.h b/paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.h new file mode 100644 index 0000000000000000000000000000000000000000..b596ff545383fe88e66ef46ad56e778e91703460 --- /dev/null +++ b/paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.h @@ -0,0 +1,580 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+#pragma once + +#ifdef PADDLE_WITH_CUTLASS +#include "cutlass/arch/mma.h" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/device/gemm_grouped.h" +#include "cutlass/gemm/device/gemm_universal.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/util/device_memory.h" +#include "examples/common/helper.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +namespace phi { +namespace sparse { +typedef void (*fp16_gather_gemm_scatter)(const GPUContext& dev_ctx, + const cutlass::half_t* const a, + const cutlass::half_t* const b, + const cutlass::half_t* const c, + cutlass::half_t* const d, + const int m, + const int n, + const int k, + const int32_t* a_indices, + const int32_t* c_d_indices, + cutlass::half_t const alpha, + cutlass::half_t const beta); +typedef void (*fp32_gather_gemm_scatter)(const GPUContext& dev_ctx, + const float* const a, + const float* const b, + const float* const c, + float* const d, + const int m, + const int n, + const int k, + const int32_t* a_indices, + const int32_t* c_d_indices, + float const alpha, + float const beta); +typedef void (*fp64_gather_gemm_scatter)(const GPUContext& dev_ctx, + const double* const a, + const double* const b, + const double* const c, + double* const d, + const int m, + const int n, + const int k, + const int32_t* a_indices, + const int32_t* c_d_indices, + double const alpha, + double const beta); +fp16_gather_gemm_scatter getBestFp16Kernel(const int M, + const int K, + const int N); +fp32_gather_gemm_scatter getBestFp32Kernel(const int M, + const int K, + const int N, + const int SM); +fp64_gather_gemm_scatter getBestFp64Kernel(const int M, + const int K, + const int N); +template +void launchKernel(const GPUContext& dev_ctx, + const T* const a, + const T* const b, + const T* const c, + T* const d, + const int m, + const int n, + const int k, + const int32_t* a_indices, + const int32_t* c_d_indices, + T const alpha, + T const beta) { + cutlass::gemm::GemmCoord problem_size_real({m, n, k}); + int split_k_slices = 1; + typename Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size_real, + split_k_slices, + {alpha, beta}, + a, + b, + c, + d, + cutlass::layout::RowMajor().capacity(problem_size_real.mk()), + cutlass::layout::RowMajor().capacity(problem_size_real.kn()), + cutlass::layout::RowMajor().capacity(problem_size_real.mn()), + cutlass::layout::RowMajor().capacity(problem_size_real.mn()), + problem_size_real.k(), + problem_size_real.n(), + problem_size_real.n(), + problem_size_real.n(), + a_indices, + nullptr, + c_d_indices}; + size_t workspace_size = Gemm::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + Gemm gemm_op; + cutlass::Status status = gemm_op.can_implement(arguments); + CUTLASS_CHECK(status); + status = gemm_op.initialize(arguments, workspace.get()); + CUTLASS_CHECK(status); + gemm_op(dev_ctx.stream()); +} +struct cutlass_tensorop_h1688gemm_128x64_32x2_nn_align8 { + using Gemm = cutlass::gemm::device::GemmUniversal< + cutlass::half_t, + cutlass::layout::RowMajor, + cutlass::half_t, + cutlass::layout::RowMajor, + cutlass::half_t, + cutlass::layout::RowMajor, + cutlass::half_t, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm75, + cutlass::gemm::GemmShape<128, 64, 32>, + cutlass::gemm::GemmShape<64, 32, 32>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, + 
2, + 8, + 8, + cutlass::arch::OpMultiplyAdd, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone, + true, + false, + true>; +}; +struct cutlass_tensorop_h1688gemm_64x128_32x2_nn_align8 { + using Gemm = cutlass::gemm::device::GemmUniversal< + cutlass::half_t, + cutlass::layout::RowMajor, + cutlass::half_t, + cutlass::layout::RowMajor, + cutlass::half_t, + cutlass::layout::RowMajor, + cutlass::half_t, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm75, + cutlass::gemm::GemmShape<64, 128, 32>, + cutlass::gemm::GemmShape<32, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, + 2, + 8, + 8, + cutlass::arch::OpMultiplyAdd, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone, + true, + false, + true>; +}; +struct cutlass_tensorop_h1688gemm_128x64_32x2_nn_align4 { + using Gemm = cutlass::gemm::device::GemmUniversal< + cutlass::half_t, + cutlass::layout::RowMajor, + cutlass::half_t, + cutlass::layout::RowMajor, + cutlass::half_t, + cutlass::layout::RowMajor, + cutlass::half_t, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm75, + cutlass::gemm::GemmShape<128, 64, 32>, + cutlass::gemm::GemmShape<64, 32, 32>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, + 2, + 4, + 4, + cutlass::arch::OpMultiplyAdd, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone, + true, + false, + true>; +}; +struct cutlass_tensorop_h1688gemm_64x64_32x2_nn_align4 { + using Gemm = cutlass::gemm::device::GemmUniversal< + cutlass::half_t, + cutlass::layout::RowMajor, + cutlass::half_t, + cutlass::layout::RowMajor, + cutlass::half_t, + cutlass::layout::RowMajor, + cutlass::half_t, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm75, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<32, 32, 32>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, + 2, + 4, + 4, + cutlass::arch::OpMultiplyAdd, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone, + true, + false, + true>; +}; +struct cutlass_tensorop_h1688gemm_64x64_32x2_nn_align8 { + using Gemm = cutlass::gemm::device::GemmUniversal< + cutlass::half_t, + cutlass::layout::RowMajor, + cutlass::half_t, + cutlass::layout::RowMajor, + cutlass::half_t, + cutlass::layout::RowMajor, + cutlass::half_t, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm75, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<32, 32, 32>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, + 2, + 8, + 8, + cutlass::arch::OpMultiplyAdd, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone, + true, + false, + true>; +}; +struct cutlass_tensorop_h16816gemm_64x64_64x5_nn_align8 { + using Gemm = cutlass::gemm::device::GemmUniversal< + cutlass::half_t, + cutlass::layout::RowMajor, + cutlass::half_t, + cutlass::layout::RowMajor, + cutlass::half_t, + cutlass::layout::RowMajor, + cutlass::half_t, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 64>, + cutlass::gemm::GemmShape<32, 32, 64>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination, + 
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, + 5, + 8, + 8, + cutlass::arch::OpMultiplyAdd, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone, + true, + false, + true>; +}; +struct cutlass_tensorop_f16_s1688gemm_f16_64x128_32x2_nn_align8 { + using Gemm = cutlass::gemm::device::GemmUniversal< + cutlass::half_t, + cutlass::layout::RowMajor, + cutlass::half_t, + cutlass::layout::RowMajor, + cutlass::half_t, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm75, + cutlass::gemm::GemmShape<64, 128, 32>, + cutlass::gemm::GemmShape<32, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread:: + LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, + 2, + 8, + 8, + cutlass::arch::OpMultiplyAdd, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone, + true, + false, + true>; +}; +struct cutlass_tensorop_f16_s1688gemm_f16_64x64_32x2_nn_align8 { + using Gemm = cutlass::gemm::device::GemmUniversal< + cutlass::half_t, + cutlass::layout::RowMajor, + cutlass::half_t, + cutlass::layout::RowMajor, + cutlass::half_t, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm75, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<32, 32, 32>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread:: + LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, + 2, + 8, + 8, + cutlass::arch::OpMultiplyAdd, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone, + true, + false, + true>; +}; +struct cutlass_tensorop_s1688f16gemm_64x64_16x10_nn_align4 { + using Gemm = cutlass::gemm::device::GemmUniversal< + float, + cutlass::layout::RowMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 16>, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, + 10, + 4, + 4, + cutlass::arch::OpMultiplyAddFastF16, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone, + true, + false, + true>; +}; +struct cutlass_tensorop_s1688f16gemm_128x128_16x3_nn_align4 { + using Gemm = cutlass::gemm::device::GemmUniversal< + float, + cutlass::layout::RowMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 16>, + cutlass::gemm::GemmShape<64, 64, 16>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, + 3, + 4, + 4, + cutlass::arch::OpMultiplyAddFastF16, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone, + true, + false, + true>; +}; +struct cutlass_tensorop_s1688f16gemm_256x64_16x4_nn_align4 { + using Gemm = cutlass::gemm::device::GemmUniversal< + float, + cutlass::layout::RowMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 64, 16>, + cutlass::gemm::GemmShape<64, 64, 16>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, + 4, + 4, + 4, + 
cutlass::arch::OpMultiplyAddFastF16, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone, + true, + false, + true>; +}; +struct cutlass_tensorop_s1688tf32gemm_256x128_16x3_nn_align4 { + using Gemm = cutlass::gemm::device::GemmUniversal< + float, + cutlass::layout::RowMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 128, 16>, + cutlass::gemm::GemmShape<64, 64, 16>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, + 3, + 4, + 4, + cutlass::arch::OpMultiplyAdd, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone, + true, + false, + true>; +}; +struct cutlass_tensorop_s1688f16gemm_64x128_16x6_nn_align4 { + using Gemm = cutlass::gemm::device::GemmUniversal< + float, + cutlass::layout::RowMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 128, 16>, + cutlass::gemm::GemmShape<32, 64, 16>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, + 6, + 4, + 4, + cutlass::arch::OpMultiplyAddFastF16, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone, + true, + false, + true>; +}; +struct cutlass_tensorop_s1688gemm_64x64_16x3_nn_align4 { + using Gemm = cutlass::gemm::device::GemmUniversal< + float, + cutlass::layout::RowMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 16>, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, + 3, + 4, + 4, + cutlass::arch::OpMultiplyAddFastF32, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone, + true, + false, + true>; +}; +struct cutlass_tensorop_d884gemm_16x32_16x5_nn_align1 { + using Gemm = cutlass::gemm::device::GemmUniversal< + double, + cutlass::layout::RowMajor, + double, + cutlass::layout::RowMajor, + double, + cutlass::layout::RowMajor, + double, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<16, 32, 16>, + cutlass::gemm::GemmShape<16, 16, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, + 5, + 1, + 1, + cutlass::arch::OpMultiplyAdd, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone, + true, + false, + true>; +}; +struct cutlass_tensorop_d884gemm_32x16_16x5_nn_align1 { + using Gemm = cutlass::gemm::device::GemmUniversal< + double, + cutlass::layout::RowMajor, + double, + cutlass::layout::RowMajor, + double, + cutlass::layout::RowMajor, + double, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<32, 16, 16>, + cutlass::gemm::GemmShape<16, 16, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, + 5, + 1, + 1, + cutlass::arch::OpMultiplyAdd, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone, + true, + false, + true>; +}; + +// sm75 +struct 
cutlass_tensorop_s1688gemm_f16_64x64_32x2_nn_align4 { + using Gemm = cutlass::gemm::device::GemmUniversal< + cutlass::half_t, + cutlass::layout::RowMajor, + cutlass::half_t, + cutlass::layout::RowMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm75, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<32, 32, 32>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, + 2, + 8, + 8, + cutlass::arch::OpMultiplyAdd>; +}; + +} // namespace sparse +} // namespace phi +#endif diff --git a/paddle/phi/kernels/sparse/gpu/mask_kernel.cu b/paddle/phi/kernels/sparse/gpu/mask_kernel.cu index c4d2a691a4b3bd4b8fe19207d8e8e6daa2bd7f74..45f827801bc10d2cb39f517dacfa6524aa297e6d 100644 --- a/paddle/phi/kernels/sparse/gpu/mask_kernel.cu +++ b/paddle/phi/kernels/sparse/gpu/mask_kernel.cu @@ -25,6 +25,7 @@ limitations under the License. */ #include "paddle/phi/kernels/funcs/aligned_vector.h" #include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/sparse/flatten_indices.cu.h" +#include "paddle/phi/kernels/funcs/sparse/utils.cu.h" namespace phi { namespace sparse { @@ -118,15 +119,20 @@ void SparseMaskKernel(const Context& dev_ctx, } template -__global__ void MaskTable(const IntT* x_indexs, const int n, int* table) { +__global__ void MaskTable(const IntT* x_indexs, + const int n, + int* index_flags, + int* table) { CUDA_KERNEL_LOOP_TYPE(i, n, int64_t) { int index = x_indexs[i]; - table[index] = i == 0 ? -1 : i; + phi::funcs::sparse::SetBits(index, index_flags); + table[index] = i; } } template __global__ void MaskCopy(const IntT* mask_indexs, + const int* index_flags, const int* table, const int n, const int stride, @@ -135,9 +141,10 @@ __global__ void MaskCopy(const IntT* mask_indexs, using LoadT = phi::AlignedVector; using StoreT = phi::AlignedVector; CUDA_KERNEL_LOOP_TYPE(i, n, int64_t) { - int j = table[mask_indexs[i]]; - if (j != 0) { - if (j == -1) j = 0; + const int mask_index = mask_indexs[i]; + const bool flag = phi::funcs::sparse::TestBits(mask_index, index_flags); + if (flag) { + int j = table[mask_index]; for (int k = 0; k < stride; k += VecSize) { LoadT vec_x; phi::Load(x_values + j * stride + k, &vec_x); @@ -217,12 +224,15 @@ void SparseMaskHelperGPUKernel(const GPUContext& dev_ctx, int table_size = 1; auto x_dims = x.dims(); - for (int i = 0; i < x_dims.size() - 1; i++) { + for (int i = 0; i < sparse_dim; i++) { table_size *= x_dims[i]; } DenseTensor table = phi::Empty(dev_ctx, {table_size}); - phi::backends::gpu::GpuMemsetAsync( - table.data(), 0, table_size * sizeof(int), dev_ctx.stream()); + DenseTensor index_flags = phi::Empty(dev_ctx, {(table_size + 31) / 32}); + phi::backends::gpu::GpuMemsetAsync(index_flags.data(), + 0, + index_flags.numel() * sizeof(int), + dev_ctx.stream()); const int64_t stride = x.dims().size() == sparse_dim ? 
1 : x.values().dims()[1]; *out = phi::EmptyLike(dev_ctx, x.values()); @@ -234,8 +244,10 @@ void SparseMaskHelperGPUKernel(const GPUContext& dev_ctx, MaskTable<<>>( - x_indexs_ptr, x_indexs.numel(), table.data()); + dev_ctx.stream()>>>(x_indexs_ptr, + x_indexs.numel(), + index_flags.data(), + table.data()); config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, mask_indexs.numel(), 1); @@ -246,6 +258,7 @@ void SparseMaskHelperGPUKernel(const GPUContext& dev_ctx, config.thread_per_block, 0, dev_ctx.stream()>>>(mask_indexs_ptr, + index_flags.data(), table.data(), mask_indexs.numel(), stride, @@ -256,6 +269,7 @@ void SparseMaskHelperGPUKernel(const GPUContext& dev_ctx, config.thread_per_block, 0, dev_ctx.stream()>>>(mask_indexs_ptr, + index_flags.data(), table.data(), mask_indexs.numel(), stride, diff --git a/python/paddle/fluid/tests/unittests/test_sparse_elementwise_op.py b/python/paddle/fluid/tests/unittests/test_sparse_elementwise_op.py index 279a688f6aeaf2fbfe40b175d1625ccce4097594..a0beb85d5cd03ec5f63eaa7e99b60ae66e42b32c 100644 --- a/python/paddle/fluid/tests/unittests/test_sparse_elementwise_op.py +++ b/python/paddle/fluid/tests/unittests/test_sparse_elementwise_op.py @@ -64,42 +64,50 @@ class TestSparseElementWiseAPI(unittest.TestCase): csr_y = s_dense_y.to_sparse_csr() actual_res = get_actual_res(csr_x, csr_y, op) - actual_res.backward(actual_res) expect_res = op(dense_x, dense_y) expect_res.backward(expect_res) - np.testing.assert_allclose(expect_res.numpy(), - actual_res.to_dense().numpy(), - rtol=1e-05, - equal_nan=True) + np.testing.assert_allclose( + expect_res.numpy(), + actual_res.to_dense().numpy(), + rtol=1e-05, + equal_nan=True, + ) if not (op == __truediv__ and dtype in ['int32', 'int64']): - np.testing.assert_allclose(dense_x.grad.numpy(), - csr_x.grad.to_dense().numpy(), - rtol=1e-05, - equal_nan=True) - np.testing.assert_allclose(dense_y.grad.numpy(), - csr_y.grad.to_dense().numpy(), - rtol=1e-05, - equal_nan=True) + actual_res.backward(actual_res) + np.testing.assert_allclose( + dense_x.grad.numpy(), + csr_x.grad.to_dense().numpy(), + rtol=1e-05, + equal_nan=True, + ) + np.testing.assert_allclose( + dense_y.grad.numpy(), + csr_y.grad.to_dense().numpy(), + rtol=1e-05, + equal_nan=True, + ) def func_test_coo(self, op): for sparse_dim in range(len(self.coo_shape) - 1, len(self.coo_shape)): for dtype in self.support_dtypes: - x = np.random.randint(-255, 255, - size=self.coo_shape).astype(dtype) - y = np.random.randint(-255, 255, - size=self.coo_shape).astype(dtype) + x = np.random.randint(-255, 255, size=self.coo_shape).astype( + dtype + ) + y = np.random.randint(-255, 255, size=self.coo_shape).astype( + dtype + ) dense_x = paddle.to_tensor(x, dtype=dtype, stop_gradient=False) dense_y = paddle.to_tensor(y, dtype=dtype, stop_gradient=False) - s_dense_x = paddle.to_tensor(x, - dtype=dtype, - stop_gradient=False) - s_dense_y = paddle.to_tensor(y, - dtype=dtype, - stop_gradient=False) + s_dense_x = paddle.to_tensor( + x, dtype=dtype, stop_gradient=False + ) + s_dense_y = paddle.to_tensor( + y, dtype=dtype, stop_gradient=False + ) coo_x = s_dense_x.to_sparse_coo(sparse_dim) coo_y = s_dense_y.to_sparse_coo(sparse_dim) @@ -109,18 +117,24 @@ class TestSparseElementWiseAPI(unittest.TestCase): expect_res = op(dense_x, dense_y) expect_res.backward(expect_res) - np.testing.assert_allclose(expect_res.numpy(), - actual_res.to_dense().numpy(), - rtol=1e-05, - equal_nan=True) - np.testing.assert_allclose(dense_x.grad.numpy(), - coo_x.grad.to_dense().numpy(), - rtol=1e-05, - 
equal_nan=True) - np.testing.assert_allclose(dense_y.grad.numpy(), - coo_y.grad.to_dense().numpy(), - rtol=1e-05, - equal_nan=True) + np.testing.assert_allclose( + expect_res.numpy(), + actual_res.to_dense().numpy(), + rtol=1e-05, + equal_nan=True, + ) + np.testing.assert_allclose( + dense_x.grad.numpy(), + coo_x.grad.to_dense().numpy(), + rtol=1e-05, + equal_nan=True, + ) + np.testing.assert_allclose( + dense_y.grad.numpy(), + coo_y.grad.to_dense().numpy(), + rtol=1e-05, + equal_nan=True, + ) def test_support_dtypes_csr(self): paddle.device.set_device('cpu') @@ -140,38 +154,37 @@ class TestSparseElementWiseAPI(unittest.TestCase): values2_data = [[1.0], [2.0]] shape = [2, 4, 2] - sp_a = sparse.sparse_coo_tensor(indices_data, - values1_data, - shape, - stop_gradient=False) - sp_b = sparse.sparse_coo_tensor(indices_data, - values2_data, - shape, - stop_gradient=False) + sp_a = sparse.sparse_coo_tensor( + indices_data, values1_data, shape, stop_gradient=False + ) + sp_b = sparse.sparse_coo_tensor( + indices_data, values2_data, shape, stop_gradient=False + ) values1 = paddle.to_tensor(values1_data, stop_gradient=False) values2 = paddle.to_tensor(values2_data, stop_gradient=False) - #c.values() = a.values() + b.values() + # c.values() = a.values() + b.values() sp_c = sparse.add(sp_a, sp_b) sp_c.backward() ref_c = values1 + values2 ref_c.backward() np.testing.assert_allclose(sp_c.values().numpy(), ref_c.numpy()) - np.testing.assert_allclose(sp_a.grad.values().numpy(), - values1.grad.numpy()) - np.testing.assert_allclose(sp_b.grad.values().numpy(), - values2.grad.numpy()) + np.testing.assert_allclose( + sp_a.grad.values().numpy(), values1.grad.numpy() + ) + np.testing.assert_allclose( + sp_b.grad.values().numpy(), values2.grad.numpy() + ) def test_add_bias(self): indices_data = [[0, 1], [0, 3]] values_data = [[1.0, 1.0], [2.0, 2.0]] shape = [2, 4, 2] - sp_a = sparse.sparse_coo_tensor(indices_data, - values_data, - shape, - stop_gradient=False) + sp_a = sparse.sparse_coo_tensor( + indices_data, values_data, shape, stop_gradient=False + ) bias_values = [1.0, 2.0] @@ -179,14 +192,15 @@ class TestSparseElementWiseAPI(unittest.TestCase): values2 = paddle.to_tensor(bias_values, stop_gradient=False) values3 = paddle.to_tensor(bias_values, stop_gradient=False) - #c.values() = a.values() + b + # c.values() = a.values() + b sp_c = sparse.add(sp_a, values2) sp_c.backward() ref_c = values1 + values3 ref_c.backward() np.testing.assert_allclose(sp_c.values().numpy(), ref_c.numpy()) - np.testing.assert_allclose(sp_a.grad.values().numpy(), - values1.grad.numpy()) + np.testing.assert_allclose( + sp_a.grad.values().numpy(), values1.grad.numpy() + ) np.testing.assert_allclose(values2.grad.numpy(), values3.grad.numpy())
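
For reference, a minimal host-side sketch of the packed-bitmask pattern that the new SetBits/TestBits device helpers in paddle/phi/kernels/funcs/sparse/utils.cu.h implement: each flat index maps to bit (index % 32) of word (index / 32), so a flag table for table_size entries needs only (table_size + 31) / 32 ints. The names SetBitsHost/TestBitsHost below are illustrative only and not part of this patch, and a plain read-modify-write stands in for the atomicOr used in the device code.

// Host-side mirror of the device-side SetBits/TestBits helpers (sketch only).
#include <cstdio>
#include <vector>

// Set bit `value` in the packed table; return true if it was already set.
inline bool SetBitsHost(const int value, int* ptr) {
  const int index = value >> 5;        // which 32-bit word holds this flag
  const int mask = 1 << (value & 31);  // bit position inside that word
  const int old = ptr[index];
  ptr[index] = old | mask;             // device code uses atomicOr(ptr + index, mask)
  return (mask & old) != 0;
}

// Test whether bit `value` is set in the packed table.
inline bool TestBitsHost(const int value, const int* ptr) {
  return (ptr[value >> 5] & (1 << (value & 31))) != 0;
}

int main() {
  const int table_size = 880801;  // e.g. product of the output spatial dims
  std::vector<int> index_flags((table_size + 31) / 32, 0);  // packed flag words

  SetBitsHost(42, index_flags.data());
  const bool seen_again = SetBitsHost(42, index_flags.data());  // true: duplicate index

  std::printf("%d %d %d\n",
              TestBitsHost(42, index_flags.data()),  // 1
              TestBitsHost(7, index_flags.data()),   // 0
              static_cast<int>(seen_again));         // 1
  return 0;
}

Compared with the previous full int table initialized with 0/-1 sentinels, this keeps one bit per possible output index, which is why ProductRuleBook now zero-fills only (table_size + 31) / 32 ints and UniqueKernel/MaskTable consult the flag words before reading out_index_table.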