Unverified commit 8c5e432b, authored by zhangkaihuo, committed by GitHub

[cherry-pick] Optimize sparse kernels and fix some bugs (#50118)

Cherry-pick of several PRs that optimize sparse kernels and fix some bugs:
#47736 #47703 #47604 #46679 #48439 #49009 #49734
Parent e32ff656
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
include(ExternalProject)
set(CUTLASS_PREFIX_DIR ${THIRD_PARTY_PATH}/cutlass)
set(CUTLASS_REPOSITORY https://github.com/NVIDIA/cutlass.git)
set(CUTLASS_TAG v2.9.1)
include_directories("${THIRD_PARTY_PATH}/cutlass/src/extern_cutlass/")
include_directories("${THIRD_PARTY_PATH}/cutlass/src/extern_cutlass/include/")
include_directories(
"${THIRD_PARTY_PATH}/cutlass/src/extern_cutlass/tools/util/include/")
add_definitions("-DPADDLE_WITH_CUTLASS")
ExternalProject_Add(
extern_cutlass
${EXTERNAL_PROJECT_LOG_ARGS} ${SHALLOW_CLONE}
GIT_REPOSITORY ${CUTLASS_REPOSITORY}
GIT_TAG "${CUTLASS_TAG}"
PREFIX ${CUTLASS_PREFIX_DIR}
UPDATE_COMMAND ""
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
INSTALL_COMMAND ""
TEST_COMMAND "")
add_library(cutlass INTERFACE)
add_dependencies(cutlass extern_cutlass)
@@ -492,4 +492,14 @@ if(WITH_CUSPARSELT)
  list(APPEND third_party_deps extern_cusparselt)
endif()

if(WITH_GPU
   AND NOT WITH_ARM
   AND NOT WIN32
   AND NOT APPLE)
  if(${CMAKE_CUDA_COMPILER_VERSION} GREATER_EQUAL 11.0)
    include(external/cutlass) # download, build, install cutlass
    list(APPEND third_party_deps extern_cutlass)
  endif()
endif()

add_custom_target(third_party ALL DEPENDS ${third_party_deps})
@@ -18,6 +18,10 @@ limitations under the License. */

namespace phi {
namespace funcs {

#define CUDNN_PER_ACTIVATION_THRESHOLD 10240
#define CUDNN_SPATIAL_THRESHOLD_TRAIN 880801
#define CUDNN_SPATIAL_THRESHOLD_EVAL 65535

inline void ExtractNCWHD(const phi::DDim &dims,
                         const DataLayout &data_layout,
                         int *N,
......

@@ -26,6 +26,19 @@ __global__ void DistanceKernel(const T* start, const T* end, T* distance) {
  }
}
inline __device__ bool SetBits(const int value, int* ptr) {
const int index = value >> 5;
const int mask = 1 << (value & 31);
const int old = atomicOr(ptr + index, mask);
return (mask & old) != 0;
}
inline __device__ bool TestBits(const int value, const int* ptr) {
const int index = value >> 5;
const int mask = 1 << (value & 31);
return (mask & ptr[index]) != 0;
}
} // namespace sparse
} // namespace funcs
} // namespace phi
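Note on the two helpers added above: SetBits/TestBits pack one flag bit per candidate index into an int array, 32 flags per word, and the atomicOr both sets the bit and reports whether it was already set, which is what lets a kernel detect duplicates in a single pass. A minimal host-side sketch of the same bit arithmetic (illustration only, not part of the patch):

// Host-side sketch of the bit layout SetBits/TestBits rely on: one bit per
// possible index, packed 32 per int. On the GPU the read-modify-write is an
// atomicOr so the first setter can be distinguished from later duplicates.
#include <cassert>
#include <vector>

bool set_bit(int value, std::vector<int>* flags) {
  const int word = value >> 5;         // value / 32: which 32-bit word
  const int mask = 1 << (value & 31);  // value % 32: which bit in that word
  const int old = (*flags)[word];
  (*flags)[word] = old | mask;         // atomicOr(ptr + word, mask) on device
  return (mask & old) != 0;            // true if the bit was already set
}

int main() {
  std::vector<int> flags((1000 + 31) / 32, 0);
  assert(!set_bit(37, &flags));  // first setter sees "not yet set"
  assert(set_bit(37, &flags));   // a duplicate sees "already set"
  return 0;
}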
@@ -852,15 +852,17 @@ void BatchNormGradRawKernel(const Context &ctx,
      // ctx.GetPlace()),
      // epsilon, saved_mean_data, saved_var_data));
#else
    }
    // CUDNN only support small batch size
    bool use_native_nhwc =
        d_x ? (x_dims.size() == 4 && compute_format == DataLayout::kNHWC)
            : false;
    const bool use_native_kernel =
        ((x_dims.size() == 2 && N >= CUDNN_PER_ACTIVATION_THRESHOLD) ||
         (x_dims.size() == 3 && N >= CUDNN_SPATIAL_THRESHOLD_TRAIN));
    if (use_native_nhwc || (d_x && d_scale && d_bias)) {
      if (use_native_kernel || use_native_nhwc) {
        if (x_dims.size() == 2 || use_native_nhwc) {
          dim3 block;
          dim3 grid;
          const int block_size = 512;
......
@@ -72,6 +72,40 @@ static __global__ void BNForwardInference(const T *x,
  }
}
template <typename T>
static __global__ void InverseVariance(const BatchNormParamType<T> *variance,
const double epsilon,
const int C,
BatchNormParamType<T> *inv_variance) {
int tid = threadIdx.x + blockIdx.x * blockDim.x;
if (tid < C) {
inv_variance[tid] = 1 / sqrt(variance[tid] + epsilon);
}
}
template <typename T, phi::DataLayout layout>
static __global__ void BN1DForwardInference(
const T *x,
const BatchNormParamType<T> *mean,
const BatchNormParamType<T> *inv_variance,
const BatchNormParamType<T> *scale,
const BatchNormParamType<T> *bias,
const int C,
const int N,
const int HxW,
const double epsilon,
T *y) {
int gid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
int num = N * C * HxW;
for (int i = gid; i < num; i += stride) {
const int c = layout == phi::DataLayout::kNCHW ? i / HxW % C : i % C;
BatchNormParamType<T> x_sub_mean =
static_cast<BatchNormParamType<T>>(x[i]) - mean[c];
y[i] = static_cast<T>(scale[c] * x_sub_mean * inv_variance[c] + bias[c]);
}
}
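InverseVariance precomputes 1/sqrt(variance[c] + epsilon) once per channel so that BN1DForwardInference only needs a multiply per element instead of a sqrt and a divide inside the N*C*HxW loop. A scalar sketch of the math the two kernels implement together (hypothetical helper, for illustration only):

// Scalar sketch of the inference-time batch norm transform split across the
// two kernels above: inv_std is computed once per channel, then reused.
#include <cmath>

float bn_inference_elem(float x, float mean, float var, float scale,
                        float bias, double epsilon) {
  const float inv_std = 1.0f / sqrtf(var + static_cast<float>(epsilon));
  return scale * (x - mean) * inv_std + bias;
}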
template <typename T, int BlockDim, phi::DataLayout layout>
static __global__ LAUNCH_BOUNDS(BlockDim) void BNForwardTraining(
    const T *x,
@@ -691,9 +725,6 @@ void BatchNormKernel(const Context &ctx,

  auto handle = ctx.cudnn_handle();

  // Now, depending on whether we are running test or not, we have two paths.
  // It is training mode when it's not reference AND not using pre-trained
  // model.
@@ -797,8 +828,8 @@ void BatchNormKernel(const Context &ctx,
      // epsilon));
#else
      const bool use_native_kernel =
          (x_dims.size() == 2 ||
           (x_dims.size() == 3 && N >= CUDNN_SPATIAL_THRESHOLD_EVAL));
      if (use_native_kernel) {
        const int block_size = 256;
        const int grid_size = (N * C * H * W * D + block_size - 1) / block_size;
@@ -816,18 +847,43 @@ void BatchNormKernel(const Context &ctx,
                epsilon,
                transformed_y.template data<T>());
        } else {
          if (x_dims.size() == 2) {
            DenseTensor inv_var = phi::Empty<BatchNormParamType<T>>(ctx, {C});
            auto *inv_var_ptr = inv_var.data<BatchNormParamType<T>>();
            const int threads = 512 > C ? C : 512;
            const int blocks = (C + 511) / 512;
            InverseVariance<T><<<blocks, threads>>>(
                est_var->template data<BatchNormParamType<T>>(),
                epsilon,
                C,
                inv_var_ptr);
            BN1DForwardInference<T, DataLayout::kNHWC>
                <<<grid_size, block_size, 0, ctx.stream()>>>(
                    transformed_x.template data<T>(),
                    est_mean->template data<BatchNormParamType<T>>(),
                    // est_var->template data<BatchNormParamType<T>>(),
                    inv_var_ptr,
                    scale.template data<BatchNormParamType<T>>(),
                    bias.template data<BatchNormParamType<T>>(),
                    C,
                    N,
                    H * W * D,
                    epsilon,
                    transformed_y.template data<T>());
          } else {
            BNForwardInference<T, DataLayout::kNHWC>
                <<<grid_size, block_size, 0, ctx.stream()>>>(
                    transformed_x.template data<T>(),
                    est_mean->template data<BatchNormParamType<T>>(),
                    est_var->template data<BatchNormParamType<T>>(),
                    scale.template data<BatchNormParamType<T>>(),
                    bias.template data<BatchNormParamType<T>>(),
                    C,
                    N,
                    H * W * D,
                    epsilon,
                    transformed_y.template data<T>());
          }
        }
      } else {
        PADDLE_ENFORCE_GPU_SUCCESS(
@@ -949,7 +1005,7 @@ void BatchNormKernel(const Context &ctx,
      // const size_t CUDNN_PER_ACTIVATION_THRESHOLD = 131070;
      const bool use_native_kernel =
          ((x_dims.size() == 2 && N >= CUDNN_PER_ACTIVATION_THRESHOLD) ||
           (x_dims.size() == 3 && N >= CUDNN_SPATIAL_THRESHOLD_TRAIN));
      if (use_native_kernel) {
        dim3 block;
        dim3 grid;
......
@@ -15,8 +15,14 @@ limitations under the License. */

#pragma once

#include <thrust/remove.h>
#include <thrust/unique.h>
#ifdef __NVCC__
#include <cub/block/block_scan.cuh>
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif

#include "paddle/phi/kernels/sparse/conv_kernel.h"

#include "paddle/phi/backends/gpu/gpu_context.h"
@@ -167,7 +173,7 @@ inline void GatherV2(const GPUContext& dev_ctx,

template <typename IntT>
__global__ void UniqueKernel(const IntT* in_indexs,
                             const int rulebook_len,
                             int* index_flags,
                             int* out_indexs,
                             int* nnz) {
  extern __shared__ int cache[];
@@ -182,8 +188,8 @@ __global__ void UniqueKernel(const IntT* in_indexs,
  if (i < rulebook_len) {
    // atomicOr only support int
    int index = static_cast<int>(in_indexs[i]);
    const bool flag = phi::funcs::sparse::SetBits(index, index_flags);
    if (!flag) {
      int j = atomicAdd(&count, 1);
      cache[j] = index;
    }
@@ -199,6 +205,88 @@ __global__ void UniqueKernel(const IntT* in_indexs,
  }
}
inline __device__ uint32_t BitCount(const uint32_t data) {
uint32_t count = data;
count = (count & 0x55555555) + ((count >> 1) & 0x55555555);
count = (count & 0x33333333) + ((count >> 2) & 0x33333333);
count = (count & 0x0f0f0f0f) + ((count >> 4) & 0x0f0f0f0f);
count = (count & 0x00ff00ff) + ((count >> 8) & 0x00ff00ff);
count = (count & 0x0000ffff) + ((count >> 16) & 0x0000ffff);
return count;
}
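BitCount is the classic pairwise (SWAR) popcount: each step sums adjacent groups of bits twice as wide as in the previous step. A host-side sketch that checks the same reduction against a naive bit loop (on the device, the __popc intrinsic should give the same count, noted only as an aside):

// Host-side sketch of the pairwise popcount used by BitCount above.
#include <cassert>
#include <cstdint>

uint32_t bit_count(uint32_t count) {
  count = (count & 0x55555555) + ((count >> 1) & 0x55555555);
  count = (count & 0x33333333) + ((count >> 2) & 0x33333333);
  count = (count & 0x0f0f0f0f) + ((count >> 4) & 0x0f0f0f0f);
  count = (count & 0x00ff00ff) + ((count >> 8) & 0x00ff00ff);
  count = (count & 0x0000ffff) + ((count >> 16) & 0x0000ffff);
  return count;
}

int main() {
  for (uint32_t v : {0u, 1u, 0x80000000u, 0xFFFFFFFFu, 0x12345678u}) {
    uint32_t naive = 0;
    for (int i = 0; i < 32; ++i) naive += (v >> i) & 1u;
    assert(bit_count(v) == naive);  // both count the set bits of v
  }
  return 0;
}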
static __global__ void GetOutIndexsCounter(const int* flags,
const int n,
int* out) {
int tid = threadIdx.x + blockDim.x * blockIdx.x;
__shared__ int block_count;
if (threadIdx.x == 0) {
block_count = 0;
}
__syncthreads();
if (tid < n) {
// get the count of 1 in flags[tid]
uint32_t count = BitCount(static_cast<uint32_t>(flags[tid]));
// add to block_count
// TODO(zhangkaihuo): replace with block reduce_sum
atomicAdd(&block_count, static_cast<int>(count));
}
__syncthreads();
// write to out
if (threadIdx.x == 0) {
out[blockIdx.x] = block_count;
}
}
template <int BS>
__global__ void GetOutIndexs(const int* flags,
const int n,
const int* offsets,
const int out_nnz,
int* out) {
int tid = threadIdx.x + blockDim.x * blockIdx.x;
__shared__ int block_counts[BS];
__shared__ int block_outs[BS * 32];
int count = 0;
if (tid < n) {
// get the count of 1 in flags[tid]
int flag = flags[tid];
count = BitCount(static_cast<uint32_t>(flag));
}
// call block prefix_sum
// using namespace cub;
typedef cub::BlockScan<int, BS> BlockScan;
__shared__ typename BlockScan::TempStorage temp_storage;
BlockScan(temp_storage).ExclusiveSum(count, count);
__syncthreads();
// write index to out
if (tid < n) {
// get the count of 1 in flags[tid]
int flag = flags[tid];
// int j = block_counts[threadIdx.x];
int j = count;
// TODO(zhangkaihuo): opt the loop
for (int i = 0; i < 32; ++i) {
if ((1 & (flag >> i)) == 1) {
block_outs[j++] = (tid << 5) + i;
}
}
}
__syncthreads();
// write to block_outs
int start = offsets[blockIdx.x];
int end = blockIdx.x == gridDim.x - 1 ? out_nnz : offsets[blockIdx.x + 1];
for (int i = threadIdx.x; i < end - start; i += blockDim.x) {
out[start + i] = block_outs[i];
}
}
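Taken together, GetOutIndexsCounter and GetOutIndexs turn the bitset of output indices into a compacted, already-sorted index list: the first kernel counts the set bits per block, the caller runs an exclusive scan over the per-block counts to get each block's output offset, and the second kernel expands every set bit back into its index (word * 32 + bit) at that offset. Because bits are visited in increasing order, the result comes out sorted and the earlier thrust::sort over out_index is no longer needed. A host-side sketch of the expansion step (illustration only):

// Host-side sketch: expand a packed flag array back into sorted indices.
#include <cstdint>
#include <vector>

std::vector<int> expand_flags(const std::vector<uint32_t>& flags) {
  std::vector<int> out;
  for (size_t w = 0; w < flags.size(); ++w) {
    for (int b = 0; b < 32; ++b) {
      if ((flags[w] >> b) & 1u) out.push_back(static_cast<int>(w) * 32 + b);
    }
  }
  return out;  // already in ascending order, so no sort is required afterwards
}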
template <typename IntT>
__global__ void GroupIndexs(const int* out_index_table,
                            const int n,
@@ -284,7 +372,6 @@ __global__ void ProductRuleBookKernel(const T* x_indices,
        atomicAdd(&counter_buf[kernel_index], 1);
        kernel_i = kernel_index;
      }
      rulebook[kernel_index * non_zero_num + i] = in_i;
      rulebook[kernel_index * non_zero_num + offset + i] = out_index;
      ++kernel_index;
@@ -299,17 +386,19 @@ __global__ void ProductRuleBookKernel(const T* x_indices,
  }
}

template <typename IntT>
__global__ void GetOutIndexTable1(const IntT* indices,
                                  const IntT non_zero_num,
                                  const Dims4D dims,
                                  int* index_flags,
                                  int* out_index_table) {
  CUDA_KERNEL_LOOP_TYPE(i, non_zero_num, int64_t) {
    IntT batch = indices[i];
    IntT in_z = indices[i + non_zero_num];
    IntT in_y = indices[i + 2 * non_zero_num];
    IntT in_x = indices[i + 3 * non_zero_num];
    IntT index = PointToIndex(batch, in_x, in_y, in_z, dims);
    phi::funcs::sparse::SetBits(index, index_flags);
    out_index_table[index] = i;
  }
}
@@ -375,6 +464,7 @@ __global__ void ProductSubmRuleBookKernel(const T* x_indices,
                                          const Dims4D paddings,
                                          const Dims4D dilations,
                                          const Dims4D strides,
                                          const int* index_flags,
                                          const int* out_index_table,
                                          T* rulebook,
                                          int* counter) {
@@ -417,9 +507,10 @@ __global__ void ProductSubmRuleBookKernel(const T* x_indices,
          T out_x = (in_x + paddings[3] - kx * dilations[3]) / strides[3];
          out_index = phi::funcs::sparse::PointToIndex<Dims4D>(
              batch, out_x, out_y, out_z, out_dims);
          const bool flag =
              phi::funcs::sparse::TestBits(out_index, index_flags);
          if (flag) {
            int real_out_index = out_index_table[out_index];
            in_i = i;
            int buf_i = atomicAdd(&counter_buf[kernel_index], 1);
            kernel_i = kernel_index;
@@ -440,7 +531,6 @@ __global__ void ProductSubmRuleBookKernel(const T* x_indices,
  __syncthreads();
  for (int i = 0; i < kernel_size; i++) {
    if (threadIdx.x < counter_buf[i]) {
      rulebook[i * non_zero_num + counter_buf2[i] + threadIdx.x] =
          rulebook_buf[i * blockDim.x + threadIdx.x];
      rulebook[i * non_zero_num + offset + counter_buf2[i] + threadIdx.x] =
@@ -575,12 +665,18 @@ int ProductRuleBook(const Context& dev_ctx,
  DenseTensorMeta rulebook_meta(
      indices_dtype, {rulebook_rows, rulebook_cols}, DataLayout::NCHW);

  int table_size = 1;
  for (int i = 0; i < out_dims.size() - 1; i++) {
    table_size *= out_dims[i];
  }
  DenseTensor out_index_table = phi::Empty<int>(dev_ctx, {table_size});
  int* out_index_table_ptr = out_index_table.data<int>();
  // index_flags: flag the indices exist or not
  int index_flags_size = (table_size + 31) / 32;
  DenseTensor index_flags = phi::Empty<int>(dev_ctx, {index_flags_size});
  int* index_flags_ptr = index_flags.data<int>();
  phi::backends::gpu::GpuMemsetAsync(
      index_flags_ptr, 0, sizeof(int) * index_flags.numel(), dev_ctx.stream());

  if (subm) {
    DenseTensor tmp_rulebook = phi::Empty(dev_ctx, std::move(rulebook_meta));
@@ -590,16 +686,16 @@ int ProductRuleBook(const Context& dev_ctx,

    phi::Copy(dev_ctx, x.indices(), dev_ctx.GetPlace(), false, &out_indices);

    auto config =
        phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, non_zero_num, 1);
    GetOutIndexTable1<IntT><<<config.block_per_grid,
                              config.thread_per_block,
                              0,
                              dev_ctx.stream()>>>(out_indices.data<IntT>(),
                                                  non_zero_num,
                                                  d_x_dims,
                                                  index_flags_ptr,
                                                  out_index_table_ptr);

    size_t cache_size =
        kernel_size * 2 * sizeof(int) +
@@ -625,6 +721,7 @@ int ProductRuleBook(const Context& dev_ctx,
        d_paddings,
        d_dilations,
        d_strides,
        index_flags_ptr,
        out_index_table_ptr,
        rulebook_ptr,
        counter_ptr);
@@ -695,9 +792,6 @@ int ProductRuleBook(const Context& dev_ctx,
    int* out_index_ptr = out_index->data<int>();
    int* unique_key_ptr = unique_key.data<int>();

    phi::backends::gpu::GpuMemsetAsync(
        unique_key_ptr, 0, sizeof(int), dev_ctx.stream());
@@ -708,7 +802,7 @@ int ProductRuleBook(const Context& dev_ctx,
                  cache_size,
                  dev_ctx.stream()>>>(rulebook_ptr + rulebook_len,
                                      rulebook_len,
                                      index_flags_ptr,
                                      out_index_ptr,
                                      unique_key_ptr);
@@ -719,13 +813,25 @@ int ProductRuleBook(const Context& dev_ctx,
                            gpuMemcpyDeviceToHost,
                            dev_ctx.stream());
    dev_ctx.Wait();

    const int threads = 256;
    const int blocks = (index_flags.numel() + threads - 1) / threads;
    GetOutIndexsCounter<<<blocks, threads, 0, dev_ctx.stream()>>>(
        index_flags_ptr, index_flags.numel(), out_index_table_ptr);
#ifdef PADDLE_WITH_HIP
    thrust::exclusive_scan(thrust::hip::par.on(dev_ctx.stream()),
#else
    thrust::exclusive_scan(thrust::cuda::par.on(dev_ctx.stream()),
#endif
                           out_index_table_ptr,
                           out_index_table_ptr + blocks,
                           out_index_table_ptr);
    GetOutIndexs<threads>
        <<<blocks, threads, 0, dev_ctx.stream()>>>(index_flags_ptr,
                                                   index_flags.numel(),
                                                   out_index_table_ptr,
                                                   out_nnz,
                                                   out_index_ptr);

    const int64_t sparse_dim = 4;
    phi::DenseTensor out_indices =
......
@@ -22,6 +22,9 @@ limitations under the License. */
#include "paddle/phi/kernels/funcs/scatter.cu.h"
#include "paddle/phi/kernels/funcs/sparse/scatter.cu.h"
#include "paddle/phi/kernels/sparse/gpu/conv.cu.h"
#ifdef PADDLE_WITH_CUTLASS
#include "paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.h"
#endif

#include "glog/logging.h"

@@ -120,85 +123,171 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx,
        dev_ctx, x, key, tmp_rulebook, h_counter, out, rulebook, counter);
  }
#ifdef PADDLE_WITH_CUTLASS
  bool cutlass = true;
  if (dev_ctx.GetComputeCapability() < 75) cutlass = false;
  if (in_channels % 4 != 0 || out_channels % 4 != 0) {
    if (std::is_same<T, phi::dtype::float16>::value) cutlass = false;
    if (std::is_same<T, float>::value) cutlass = false;
  }
  if (!std::is_same<IntT, int32_t>::value) cutlass = false;
  if (cutlass) {
    auto* out_values = out->mutable_non_zero_elements();
    T* out_values_ptr = out_values->data<T>();
    phi::funcs::SetConstant<GPUContext, T> set_zero;
    set_zero(dev_ctx, out_values, static_cast<T>(0.0f));

    const T* kernel_ptr = kernel.data<T>();
    for (int i = 0; i < kernel_size; i++) {
      if (h_counter_ptr[i] <= 0) {
        continue;
      }

      const int M = h_counter_ptr[i];
      const int K = in_channels;
      const int N = out_channels;
      const T* tmp_kernel_ptr = kernel_ptr + i * K * N;
      const IntT* gather_indices = rulebook_ptr + h_offsets_ptr[i];
      const IntT* scatter_indices =
          rulebook_ptr + rulebook_len + h_offsets_ptr[i];

      if constexpr (std::is_same<T, phi::dtype::float16>::value &&
                    std::is_same<IntT, int32_t>::value) {
        fp16_gather_gemm_scatter gather_gemm_scatter =
            getBestFp16Kernel(M, N, K);
        gather_gemm_scatter(
            dev_ctx,
            reinterpret_cast<const cutlass::half_t*>(
                x.non_zero_elements().data<T>()),
            reinterpret_cast<const cutlass::half_t*>(tmp_kernel_ptr),
            reinterpret_cast<cutlass::half_t*>(out_values_ptr),
            reinterpret_cast<cutlass::half_t*>(out_values_ptr),
            M,
            N,
            K,
            static_cast<const int32_t*>(gather_indices),
            static_cast<const int32_t*>(scatter_indices),
            static_cast<cutlass::half_t>(1),
            static_cast<cutlass::half_t>(1));
      }
      if constexpr (std::is_same<T, float>::value &&
                    std::is_same<IntT, int32_t>::value) {
        fp32_gather_gemm_scatter gather_gemm_scatter =
            getBestFp32Kernel(M, N, K, dev_ctx.GetComputeCapability());
        gather_gemm_scatter(dev_ctx,
                            x.non_zero_elements().data<T>(),
                            tmp_kernel_ptr,
                            out_values_ptr,
                            out_values_ptr,
                            M,
                            N,
                            K,
                            gather_indices,
                            scatter_indices,
                            static_cast<T>(1),
                            static_cast<T>(1));
      }
      if constexpr (std::is_same<T, double>::value &&
                    std::is_same<IntT, int32_t>::value) {
        fp64_gather_gemm_scatter gather_gemm_scatter =
            getBestFp64Kernel(M, N, K);
        gather_gemm_scatter(dev_ctx,
                            x.non_zero_elements().data<T>(),
                            tmp_kernel_ptr,
                            out_values_ptr,
                            out_values_ptr,
                            M,
                            N,
                            K,
                            gather_indices,
                            scatter_indices,
                            static_cast<T>(1),
                            static_cast<T>(1));
      }
    }
  } else {
#endif
    if (subm) {
      auto config =
          phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, rulebook_len, 1);
      unique_value.ResizeAndAllocate(
          {static_cast<int>(out->nnz() * kernel_size)});
      out_index.ResizeAndAllocate({static_cast<int>(rulebook_len)});
      int* out_index_ptr = out_index.data<int>();
      int* unique_value_ptr = unique_value.data<int>();
      phi::backends::gpu::GpuMemsetAsync(
          out_index_ptr, 0, sizeof(int) * rulebook_len, dev_ctx.stream());
      GroupIndexs<<<config.block_per_grid,
                    config.thread_per_block,
                    0,
                    dev_ctx.stream()>>>(rulebook_len,
                                        kernel_size,
                                        rulebook_ptr + rulebook_len,
                                        out_index_ptr,
                                        unique_value_ptr);
    }

    // 2. gather
    phi::DenseTensor in_features =
        phi::Empty<T>(dev_ctx, {rulebook_len, in_channels});
    phi::DenseTensor out_features =
        phi::Empty<T>(dev_ctx, {rulebook_len, out_channels});
    T* in_features_ptr = in_features.data<T>();
    T* out_features_ptr = out_features.data<T>();
    phi::funcs::SetConstant<GPUContext, T> set_zero;
    set_zero(dev_ctx, &out_features, static_cast<T>(0.0f));

    Gather<T, IntT>(dev_ctx,
                    x.values().data<T>(),
                    rulebook_ptr,
                    rulebook_len,
                    in_channels,
                    in_features_ptr);

    // 3. call gemm for every weight
    auto blas = phi::funcs::GetBlas<GPUContext, T>(dev_ctx);
    auto* out_values = out->mutable_values();
    T* out_values_ptr = out_values->data<T>();
    set_zero(dev_ctx, out_values, static_cast<T>(0.0f));

    const T* kernel_ptr = kernel.data<T>();
    for (int i = 0; i < kernel_size; i++) {
      if (h_counter_ptr[i] <= 0) {
        continue;
      }

      // call gemm: (n, in_channels) * (in_channels, out_channels)
      const int M = h_counter_ptr[i];
      const int K = in_channels;
      const int N = out_channels;
      T* tmp_in_ptr = in_features_ptr + h_offsets_ptr[i] * in_channels;
      const T* tmp_kernel_ptr = kernel_ptr + i * K * N;
      T* tmp_out_ptr = out_features_ptr + h_offsets_ptr[i] * out_channels;

      blas.GEMM(CblasNoTrans,
                CblasNoTrans,
                M,
                N,
                K,
                static_cast<T>(1),
                tmp_in_ptr,
                tmp_kernel_ptr,
                static_cast<T>(0),
                tmp_out_ptr);
    }

    // 4. scatter
    phi::funcs::sparse::ScatterV2<T>(dev_ctx,
                                     out_features_ptr,
                                     out_index.data<int>(),
                                     unique_value.data<int>(),
                                     out->nnz(),
                                     kernel_size,
                                     out_channels,
                                     1,
                                     out_values_ptr);
#ifdef PADDLE_WITH_CUTLASS
  }
#endif
}

/**
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifdef PADDLE_WITH_CUTLASS
#include "paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.h"
namespace phi {
namespace sparse {
fp16_gather_gemm_scatter getBestFp16Kernel(const int M,
const int N,
const int K) {
if (K == 4 && N == 16) {
return launchKernel<cutlass::half_t,
cutlass_tensorop_h1688gemm_64x64_32x2_nn_align4::Gemm>;
}
if (K == 16 && N == 16) {
return launchKernel<cutlass::half_t,
cutlass_tensorop_h1688gemm_64x64_32x2_nn_align8::Gemm>;
}
if (K == 16 && N == 32) {
return launchKernel<cutlass::half_t,
cutlass_tensorop_h1688gemm_64x64_32x2_nn_align8::Gemm>;
}
if (K == 32 && N == 32) {
return launchKernel<cutlass::half_t,
cutlass_tensorop_h1688gemm_64x64_32x2_nn_align8::Gemm>;
}
if (K == 32 && N == 64) {
return launchKernel<cutlass::half_t,
cutlass_tensorop_h1688gemm_64x64_32x2_nn_align8::Gemm>;
}
if (K == 64 && N == 64) {
if (M > 100000)
return launchKernel<
cutlass::half_t,
cutlass_tensorop_f16_s1688gemm_f16_64x128_32x2_nn_align8::Gemm>;
if (M > 20000)
return launchKernel<
cutlass::half_t,
cutlass_tensorop_f16_s1688gemm_f16_64x64_32x2_nn_align8::Gemm>;
if (M > 15000)
return launchKernel<
cutlass::half_t,
cutlass_tensorop_h1688gemm_128x64_32x2_nn_align8::Gemm>;
return launchKernel<cutlass::half_t,
cutlass_tensorop_h1688gemm_64x64_32x2_nn_align8::Gemm>;
}
if (K == 128) {
if (M >= 5000)
return launchKernel<
cutlass::half_t,
cutlass_tensorop_h1688gemm_64x64_32x2_nn_align8::Gemm>;
return launchKernel<cutlass::half_t,
cutlass_tensorop_h16816gemm_64x64_64x5_nn_align8::Gemm>;
}
if (N == 128) {
return launchKernel<cutlass::half_t,
cutlass_tensorop_h1688gemm_64x64_32x2_nn_align8::Gemm>;
}
return launchKernel<cutlass::half_t,
cutlass_tensorop_h1688gemm_64x64_32x2_nn_align4::Gemm>;
}
fp32_gather_gemm_scatter getBestFp32Kernel(const int M,
const int N,
const int K,
const int SM) {
if (SM == 75) {
return launchKernel<
float,
cutlass_tensorop_s1688gemm_f16_64x64_32x2_nn_align4::Gemm>;
}
if (K == 4 && N == 16) {
return launchKernel<
float,
cutlass_tensorop_s1688f16gemm_64x64_16x10_nn_align4::Gemm>;
}
if (K == 16 && N == 16) {
return launchKernel<
float,
cutlass_tensorop_s1688f16gemm_64x64_16x10_nn_align4::Gemm>;
}
if (K == 16 && N == 32) {
if (M >= 10000)
return launchKernel<
float,
cutlass_tensorop_s1688gemm_64x64_16x3_nn_align4::Gemm>;
return launchKernel<
float,
cutlass_tensorop_s1688f16gemm_64x64_16x10_nn_align4::Gemm>;
}
if (K == 32 && N == 32) {
if (M >= 10000)
return launchKernel<
float,
cutlass_tensorop_s1688gemm_64x64_16x3_nn_align4::Gemm>;
return launchKernel<
float,
cutlass_tensorop_s1688f16gemm_64x64_16x10_nn_align4::Gemm>;
}
if (K == 32 && N == 64) {
if (M >= 10000)
return launchKernel<
float,
cutlass_tensorop_s1688gemm_64x64_16x3_nn_align4::Gemm>;
return launchKernel<
float,
cutlass_tensorop_s1688f16gemm_64x64_16x10_nn_align4::Gemm>;
}
if (K == 64 && N == 64) {
if (M >= 15000)
return launchKernel<
float,
cutlass_tensorop_s1688gemm_64x64_16x3_nn_align4::Gemm>;
return launchKernel<
float,
cutlass_tensorop_s1688f16gemm_64x64_16x10_nn_align4::Gemm>;
}
if (K == 128) {
if (M >= 100000)
return launchKernel<
float,
cutlass_tensorop_s1688f16gemm_128x128_16x3_nn_align4::Gemm>;
if (M >= 5000)
return launchKernel<
float,
cutlass_tensorop_s1688f16gemm_256x64_16x4_nn_align4::Gemm>;
return launchKernel<
float,
cutlass_tensorop_s1688tf32gemm_256x128_16x3_nn_align4::Gemm>;
}
if (N == 128) {
if (M >= 100000)
return launchKernel<
float,
cutlass_tensorop_s1688tf32gemm_256x128_16x3_nn_align4::Gemm>;
if (M >= 5000)
return launchKernel<
float,
cutlass_tensorop_s1688f16gemm_128x128_16x3_nn_align4::Gemm>;
return launchKernel<
float,
cutlass_tensorop_s1688f16gemm_64x128_16x6_nn_align4::Gemm>;
}
return launchKernel<
float,
cutlass_tensorop_s1688f16gemm_64x64_16x10_nn_align4::Gemm>;
}
fp64_gather_gemm_scatter getBestFp64Kernel(const int M,
const int N,
const int K) {
if (K == 4 && N == 16) {
return launchKernel<double,
cutlass_tensorop_d884gemm_16x32_16x5_nn_align1::Gemm>;
}
if (K == 16 && N == 16) {
if (M >= 10000)
return launchKernel<double,
cutlass_tensorop_d884gemm_32x16_16x5_nn_align1::Gemm>;
return launchKernel<double,
cutlass_tensorop_d884gemm_16x32_16x5_nn_align1::Gemm>;
}
if (K == 16 && N == 32) {
return launchKernel<double,
cutlass_tensorop_d884gemm_32x16_16x5_nn_align1::Gemm>;
}
if (K == 32 && N == 32) {
return launchKernel<double,
cutlass_tensorop_d884gemm_16x32_16x5_nn_align1::Gemm>;
}
if (K == 32 && N == 64) {
return launchKernel<double,
cutlass_tensorop_d884gemm_32x16_16x5_nn_align1::Gemm>;
}
if (K == 64 && N == 64) {
return launchKernel<double,
cutlass_tensorop_d884gemm_32x16_16x5_nn_align1::Gemm>;
}
return launchKernel<double,
cutlass_tensorop_d884gemm_32x16_16x5_nn_align1::Gemm>;
}
} // namespace sparse
} // namespace phi
#endif
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#ifdef PADDLE_WITH_CUTLASS
#include "cutlass/arch/mma.h"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/device/gemm_grouped.h"
#include "cutlass/gemm/device/gemm_universal.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/util/device_memory.h"
#include "examples/common/helper.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
namespace phi {
namespace sparse {
typedef void (*fp16_gather_gemm_scatter)(const GPUContext& dev_ctx,
const cutlass::half_t* const a,
const cutlass::half_t* const b,
const cutlass::half_t* const c,
cutlass::half_t* const d,
const int m,
const int n,
const int k,
const int32_t* a_indices,
const int32_t* c_d_indices,
cutlass::half_t const alpha,
cutlass::half_t const beta);
typedef void (*fp32_gather_gemm_scatter)(const GPUContext& dev_ctx,
const float* const a,
const float* const b,
const float* const c,
float* const d,
const int m,
const int n,
const int k,
const int32_t* a_indices,
const int32_t* c_d_indices,
float const alpha,
float const beta);
typedef void (*fp64_gather_gemm_scatter)(const GPUContext& dev_ctx,
const double* const a,
const double* const b,
const double* const c,
double* const d,
const int m,
const int n,
const int k,
const int32_t* a_indices,
const int32_t* c_d_indices,
double const alpha,
double const beta);
fp16_gather_gemm_scatter getBestFp16Kernel(const int M,
                                           const int N,
                                           const int K);
fp32_gather_gemm_scatter getBestFp32Kernel(const int M,
                                           const int N,
                                           const int K,
                                           const int SM);
fp64_gather_gemm_scatter getBestFp64Kernel(const int M,
                                           const int N,
                                           const int K);
template <typename T, typename Gemm>
void launchKernel(const GPUContext& dev_ctx,
const T* const a,
const T* const b,
const T* const c,
T* const d,
const int m,
const int n,
const int k,
const int32_t* a_indices,
const int32_t* c_d_indices,
T const alpha,
T const beta) {
cutlass::gemm::GemmCoord problem_size_real({m, n, k});
int split_k_slices = 1;
typename Gemm::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
problem_size_real,
split_k_slices,
{alpha, beta},
a,
b,
c,
d,
cutlass::layout::RowMajor().capacity(problem_size_real.mk()),
cutlass::layout::RowMajor().capacity(problem_size_real.kn()),
cutlass::layout::RowMajor().capacity(problem_size_real.mn()),
cutlass::layout::RowMajor().capacity(problem_size_real.mn()),
problem_size_real.k(),
problem_size_real.n(),
problem_size_real.n(),
problem_size_real.n(),
a_indices,
nullptr,
c_d_indices};
size_t workspace_size = Gemm::get_workspace_size(arguments);
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
Gemm gemm_op;
cutlass::Status status = gemm_op.can_implement(arguments);
CUTLASS_CHECK(status);
status = gemm_op.initialize(arguments, workspace.get());
CUTLASS_CHECK(status);
gemm_op(dev_ctx.stream());
}
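A hypothetical usage sketch of the pieces above (not from the patch): pick a tuned Gemm instantiation by problem size through getBestFp32Kernel, then invoke it via the function-pointer wrapper, mirroring what the sparse conv kernel does. All pointers are assumed to be device pointers of the right extent, and alpha = beta = 1 accumulates into d as in the conv kernel.

// Hypothetical usage sketch, assuming device pointers sized for M, N, K.
#ifdef PADDLE_WITH_CUTLASS
#include "paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.h"

void run_fp32_gather_gemm_scatter(const phi::GPUContext& dev_ctx,
                                  const float* a, const float* b, float* d,
                                  const int32_t* gather_indices,
                                  const int32_t* scatter_indices,
                                  int M, int N, int K) {
  phi::sparse::fp32_gather_gemm_scatter kernel =
      phi::sparse::getBestFp32Kernel(M, N, K, dev_ctx.GetComputeCapability());
  // d += gather(a, gather_indices) * b, scattered through scatter_indices.
  kernel(dev_ctx, a, b, d, d, M, N, K, gather_indices, scatter_indices,
         1.0f, 1.0f);
}
#endif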
struct cutlass_tensorop_h1688gemm_128x64_32x2_nn_align8 {
using Gemm = cutlass::gemm::device::GemmUniversal<
cutlass::half_t,
cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm75,
cutlass::gemm::GemmShape<128, 64, 32>,
cutlass::gemm::GemmShape<64, 32, 32>,
cutlass::gemm::GemmShape<16, 8, 8>,
cutlass::epilogue::thread::LinearCombination<cutlass::half_t,
8,
cutlass::half_t,
cutlass::half_t>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>,
2,
8,
8,
cutlass::arch::OpMultiplyAdd,
cutlass::ComplexTransform::kNone,
cutlass::ComplexTransform::kNone,
true,
false,
true>;
};
struct cutlass_tensorop_h1688gemm_64x128_32x2_nn_align8 {
using Gemm = cutlass::gemm::device::GemmUniversal<
cutlass::half_t,
cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm75,
cutlass::gemm::GemmShape<64, 128, 32>,
cutlass::gemm::GemmShape<32, 64, 32>,
cutlass::gemm::GemmShape<16, 8, 8>,
cutlass::epilogue::thread::LinearCombination<cutlass::half_t,
8,
cutlass::half_t,
cutlass::half_t>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>,
2,
8,
8,
cutlass::arch::OpMultiplyAdd,
cutlass::ComplexTransform::kNone,
cutlass::ComplexTransform::kNone,
true,
false,
true>;
};
struct cutlass_tensorop_h1688gemm_128x64_32x2_nn_align4 {
using Gemm = cutlass::gemm::device::GemmUniversal<
cutlass::half_t,
cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm75,
cutlass::gemm::GemmShape<128, 64, 32>,
cutlass::gemm::GemmShape<64, 32, 32>,
cutlass::gemm::GemmShape<16, 8, 8>,
cutlass::epilogue::thread::LinearCombination<cutlass::half_t,
4,
cutlass::half_t,
cutlass::half_t>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>,
2,
4,
4,
cutlass::arch::OpMultiplyAdd,
cutlass::ComplexTransform::kNone,
cutlass::ComplexTransform::kNone,
true,
false,
true>;
};
struct cutlass_tensorop_h1688gemm_64x64_32x2_nn_align4 {
using Gemm = cutlass::gemm::device::GemmUniversal<
cutlass::half_t,
cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm75,
cutlass::gemm::GemmShape<64, 64, 32>,
cutlass::gemm::GemmShape<32, 32, 32>,
cutlass::gemm::GemmShape<16, 8, 8>,
cutlass::epilogue::thread::LinearCombination<cutlass::half_t,
4,
cutlass::half_t,
cutlass::half_t>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>,
2,
4,
4,
cutlass::arch::OpMultiplyAdd,
cutlass::ComplexTransform::kNone,
cutlass::ComplexTransform::kNone,
true,
false,
true>;
};
struct cutlass_tensorop_h1688gemm_64x64_32x2_nn_align8 {
using Gemm = cutlass::gemm::device::GemmUniversal<
cutlass::half_t,
cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm75,
cutlass::gemm::GemmShape<64, 64, 32>,
cutlass::gemm::GemmShape<32, 32, 32>,
cutlass::gemm::GemmShape<16, 8, 8>,
cutlass::epilogue::thread::LinearCombination<cutlass::half_t,
8,
cutlass::half_t,
cutlass::half_t>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>,
2,
8,
8,
cutlass::arch::OpMultiplyAdd,
cutlass::ComplexTransform::kNone,
cutlass::ComplexTransform::kNone,
true,
false,
true>;
};
struct cutlass_tensorop_h16816gemm_64x64_64x5_nn_align8 {
using Gemm = cutlass::gemm::device::GemmUniversal<
cutlass::half_t,
cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
cutlass::gemm::GemmShape<64, 64, 64>,
cutlass::gemm::GemmShape<32, 32, 64>,
cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombination<cutlass::half_t,
8,
cutlass::half_t,
cutlass::half_t>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>,
5,
8,
8,
cutlass::arch::OpMultiplyAdd,
cutlass::ComplexTransform::kNone,
cutlass::ComplexTransform::kNone,
true,
false,
true>;
};
struct cutlass_tensorop_f16_s1688gemm_f16_64x128_32x2_nn_align8 {
using Gemm = cutlass::gemm::device::GemmUniversal<
cutlass::half_t,
cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::layout::RowMajor,
float,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm75,
cutlass::gemm::GemmShape<64, 128, 32>,
cutlass::gemm::GemmShape<32, 64, 32>,
cutlass::gemm::GemmShape<16, 8, 8>,
cutlass::epilogue::thread::
LinearCombination<cutlass::half_t, 8, float, float>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>,
2,
8,
8,
cutlass::arch::OpMultiplyAdd,
cutlass::ComplexTransform::kNone,
cutlass::ComplexTransform::kNone,
true,
false,
true>;
};
struct cutlass_tensorop_f16_s1688gemm_f16_64x64_32x2_nn_align8 {
using Gemm = cutlass::gemm::device::GemmUniversal<
cutlass::half_t,
cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::layout::RowMajor,
float,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm75,
cutlass::gemm::GemmShape<64, 64, 32>,
cutlass::gemm::GemmShape<32, 32, 32>,
cutlass::gemm::GemmShape<16, 8, 8>,
cutlass::epilogue::thread::
LinearCombination<cutlass::half_t, 8, float, float>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>,
2,
8,
8,
cutlass::arch::OpMultiplyAdd,
cutlass::ComplexTransform::kNone,
cutlass::ComplexTransform::kNone,
true,
false,
true>;
};
struct cutlass_tensorop_s1688f16gemm_64x64_16x10_nn_align4 {
using Gemm = cutlass::gemm::device::GemmUniversal<
float,
cutlass::layout::RowMajor,
float,
cutlass::layout::RowMajor,
float,
cutlass::layout::RowMajor,
float,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
cutlass::gemm::GemmShape<64, 64, 16>,
cutlass::gemm::GemmShape<32, 32, 16>,
cutlass::gemm::GemmShape<16, 8, 8>,
cutlass::epilogue::thread::LinearCombination<float, 4, float, float>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>,
10,
4,
4,
cutlass::arch::OpMultiplyAddFastF16,
cutlass::ComplexTransform::kNone,
cutlass::ComplexTransform::kNone,
true,
false,
true>;
};
struct cutlass_tensorop_s1688f16gemm_128x128_16x3_nn_align4 {
using Gemm = cutlass::gemm::device::GemmUniversal<
float,
cutlass::layout::RowMajor,
float,
cutlass::layout::RowMajor,
float,
cutlass::layout::RowMajor,
float,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
cutlass::gemm::GemmShape<128, 128, 16>,
cutlass::gemm::GemmShape<64, 64, 16>,
cutlass::gemm::GemmShape<16, 8, 8>,
cutlass::epilogue::thread::LinearCombination<float, 4, float, float>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>,
3,
4,
4,
cutlass::arch::OpMultiplyAddFastF16,
cutlass::ComplexTransform::kNone,
cutlass::ComplexTransform::kNone,
true,
false,
true>;
};
struct cutlass_tensorop_s1688f16gemm_256x64_16x4_nn_align4 {
using Gemm = cutlass::gemm::device::GemmUniversal<
float,
cutlass::layout::RowMajor,
float,
cutlass::layout::RowMajor,
float,
cutlass::layout::RowMajor,
float,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
cutlass::gemm::GemmShape<256, 64, 16>,
cutlass::gemm::GemmShape<64, 64, 16>,
cutlass::gemm::GemmShape<16, 8, 8>,
cutlass::epilogue::thread::LinearCombination<float, 4, float, float>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>,
4,
4,
4,
cutlass::arch::OpMultiplyAddFastF16,
cutlass::ComplexTransform::kNone,
cutlass::ComplexTransform::kNone,
true,
false,
true>;
};
struct cutlass_tensorop_s1688tf32gemm_256x128_16x3_nn_align4 {
using Gemm = cutlass::gemm::device::GemmUniversal<
float,
cutlass::layout::RowMajor,
float,
cutlass::layout::RowMajor,
float,
cutlass::layout::RowMajor,
float,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
cutlass::gemm::GemmShape<256, 128, 16>,
cutlass::gemm::GemmShape<64, 64, 16>,
cutlass::gemm::GemmShape<16, 8, 8>,
cutlass::epilogue::thread::LinearCombination<float, 4, float, float>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>,
3,
4,
4,
cutlass::arch::OpMultiplyAdd,
cutlass::ComplexTransform::kNone,
cutlass::ComplexTransform::kNone,
true,
false,
true>;
};
struct cutlass_tensorop_s1688f16gemm_64x128_16x6_nn_align4 {
using Gemm = cutlass::gemm::device::GemmUniversal<
float,
cutlass::layout::RowMajor,
float,
cutlass::layout::RowMajor,
float,
cutlass::layout::RowMajor,
float,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
cutlass::gemm::GemmShape<64, 128, 16>,
cutlass::gemm::GemmShape<32, 64, 16>,
cutlass::gemm::GemmShape<16, 8, 8>,
cutlass::epilogue::thread::LinearCombination<float, 4, float, float>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>,
6,
4,
4,
cutlass::arch::OpMultiplyAddFastF16,
cutlass::ComplexTransform::kNone,
cutlass::ComplexTransform::kNone,
true,
false,
true>;
};
struct cutlass_tensorop_s1688gemm_64x64_16x3_nn_align4 {
using Gemm = cutlass::gemm::device::GemmUniversal<
float,
cutlass::layout::RowMajor,
float,
cutlass::layout::RowMajor,
float,
cutlass::layout::RowMajor,
float,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
cutlass::gemm::GemmShape<64, 64, 16>,
cutlass::gemm::GemmShape<32, 32, 16>,
cutlass::gemm::GemmShape<16, 8, 8>,
cutlass::epilogue::thread::LinearCombination<float, 4, float, float>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>,
3,
4,
4,
cutlass::arch::OpMultiplyAddFastF32,
cutlass::ComplexTransform::kNone,
cutlass::ComplexTransform::kNone,
true,
false,
true>;
};
struct cutlass_tensorop_d884gemm_16x32_16x5_nn_align1 {
using Gemm = cutlass::gemm::device::GemmUniversal<
double,
cutlass::layout::RowMajor,
double,
cutlass::layout::RowMajor,
double,
cutlass::layout::RowMajor,
double,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
cutlass::gemm::GemmShape<16, 32, 16>,
cutlass::gemm::GemmShape<16, 16, 16>,
cutlass::gemm::GemmShape<8, 8, 4>,
cutlass::epilogue::thread::LinearCombination<double, 1, double, double>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>,
5,
1,
1,
cutlass::arch::OpMultiplyAdd,
cutlass::ComplexTransform::kNone,
cutlass::ComplexTransform::kNone,
true,
false,
true>;
};
struct cutlass_tensorop_d884gemm_32x16_16x5_nn_align1 {
using Gemm = cutlass::gemm::device::GemmUniversal<
double,
cutlass::layout::RowMajor,
double,
cutlass::layout::RowMajor,
double,
cutlass::layout::RowMajor,
double,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
cutlass::gemm::GemmShape<32, 16, 16>,
cutlass::gemm::GemmShape<16, 16, 16>,
cutlass::gemm::GemmShape<8, 8, 4>,
cutlass::epilogue::thread::LinearCombination<double, 1, double, double>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>,
5,
1,
1,
cutlass::arch::OpMultiplyAdd,
cutlass::ComplexTransform::kNone,
cutlass::ComplexTransform::kNone,
true,
false,
true>;
};
// sm75
struct cutlass_tensorop_s1688gemm_f16_64x64_32x2_nn_align4 {
using Gemm = cutlass::gemm::device::GemmUniversal<
cutlass::half_t,
cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::layout::RowMajor,
float,
cutlass::layout::RowMajor,
float,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm75,
cutlass::gemm::GemmShape<64, 64, 32>,
cutlass::gemm::GemmShape<32, 32, 32>,
cutlass::gemm::GemmShape<16, 8, 8>,
cutlass::epilogue::thread::LinearCombination<float, 4, float, float>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>,
2,
8,
8,
cutlass::arch::OpMultiplyAdd>;
};
} // namespace sparse
} // namespace phi
#endif
@@ -25,6 +25,7 @@ limitations under the License. */
#include "paddle/phi/kernels/funcs/aligned_vector.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/sparse/flatten_indices.cu.h"
#include "paddle/phi/kernels/funcs/sparse/utils.cu.h"

namespace phi {
namespace sparse {
@@ -118,15 +119,20 @@ void SparseMaskKernel(const Context& dev_ctx,
}

template <typename IntT>
__global__ void MaskTable(const IntT* x_indexs,
                          const int n,
                          int* index_flags,
                          int* table) {
  CUDA_KERNEL_LOOP_TYPE(i, n, int64_t) {
    int index = x_indexs[i];
    phi::funcs::sparse::SetBits(index, index_flags);
    table[index] = i;
  }
}

template <typename T, typename IntT, int VecSize>
__global__ void MaskCopy(const IntT* mask_indexs,
                         const int* index_flags,
                         const int* table,
                         const int n,
                         const int stride,
@@ -135,9 +141,10 @@ __global__ void MaskCopy(const IntT* mask_indexs,
  using LoadT = phi::AlignedVector<T, VecSize>;
  using StoreT = phi::AlignedVector<T, VecSize>;
  CUDA_KERNEL_LOOP_TYPE(i, n, int64_t) {
    const int mask_index = mask_indexs[i];
    const bool flag = phi::funcs::sparse::TestBits(mask_index, index_flags);
    if (flag) {
      int j = table[mask_index];
      for (int k = 0; k < stride; k += VecSize) {
        LoadT vec_x;
        phi::Load<T, VecSize>(x_values + j * stride + k, &vec_x);
@@ -217,12 +224,15 @@ void SparseMaskHelperGPUKernel(const GPUContext& dev_ctx,

  int table_size = 1;
  auto x_dims = x.dims();
  for (int i = 0; i < sparse_dim; i++) {
    table_size *= x_dims[i];
  }
  DenseTensor table = phi::Empty<int>(dev_ctx, {table_size});
  DenseTensor index_flags = phi::Empty<int>(dev_ctx, {(table_size + 31) / 32});
  phi::backends::gpu::GpuMemsetAsync(index_flags.data<int>(),
                                     0,
                                     index_flags.numel() * sizeof(int),
                                     dev_ctx.stream());
  const int64_t stride =
      x.dims().size() == sparse_dim ? 1 : x.values().dims()[1];
  *out = phi::EmptyLike<T>(dev_ctx, x.values());
@@ -234,8 +244,10 @@ void SparseMaskHelperGPUKernel(const GPUContext& dev_ctx,
  MaskTable<<<config.block_per_grid,
              config.thread_per_block,
              0,
              dev_ctx.stream()>>>(x_indexs_ptr,
                                  x_indexs.numel(),
                                  index_flags.data<int>(),
                                  table.data<int>());
  config =
      phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, mask_indexs.numel(), 1);
@@ -246,6 +258,7 @@ void SparseMaskHelperGPUKernel(const GPUContext& dev_ctx,
             config.thread_per_block,
             0,
             dev_ctx.stream()>>>(mask_indexs_ptr,
                                 index_flags.data<int>(),
                                 table.data<int>(),
                                 mask_indexs.numel(),
                                 stride,
@@ -256,6 +269,7 @@ void SparseMaskHelperGPUKernel(const GPUContext& dev_ctx,
             config.thread_per_block,
             0,
             dev_ctx.stream()>>>(mask_indexs_ptr,
                                 index_flags.data<int>(),
                                 table.data<int>(),
                                 mask_indexs.numel(),
                                 stride,
......
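The patched MaskTable/MaskCopy pair above replaces the old sentinel encoding (row 0 was stored as -1 so that a 0 entry could mean "absent") with a two-level lookup: a per-index flag bit says whether the entry exists at all, and the int table then stores the row id directly. A host-side sketch of that lookup structure (illustration only):

// Host-side sketch: existence is tracked by a packed bit, the payload table
// stores the row id directly, so row 0 needs no special sentinel value.
#include <cstdint>
#include <vector>

struct IndexTable {
  std::vector<uint32_t> flags;  // 1 bit per possible flattened index
  std::vector<int> rows;        // row id, valid only where the flag is set

  explicit IndexTable(int table_size)
      : flags((table_size + 31) / 32, 0), rows(table_size, 0) {}

  void insert(int index, int row) {
    flags[index >> 5] |= 1u << (index & 31);
    rows[index] = row;
  }

  bool lookup(int index, int* row) const {
    if (((flags[index >> 5] >> (index & 31)) & 1u) == 0) return false;
    *row = rows[index];
    return true;
  }
};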
@@ -64,42 +64,50 @@ class TestSparseElementWiseAPI(unittest.TestCase):
            csr_y = s_dense_y.to_sparse_csr()

            actual_res = get_actual_res(csr_x, csr_y, op)

            expect_res = op(dense_x, dense_y)
            expect_res.backward(expect_res)

            np.testing.assert_allclose(
                expect_res.numpy(),
                actual_res.to_dense().numpy(),
                rtol=1e-05,
                equal_nan=True,
            )
            if not (op == __truediv__ and dtype in ['int32', 'int64']):
                actual_res.backward(actual_res)
                np.testing.assert_allclose(
                    dense_x.grad.numpy(),
                    csr_x.grad.to_dense().numpy(),
                    rtol=1e-05,
                    equal_nan=True,
                )
                np.testing.assert_allclose(
                    dense_y.grad.numpy(),
                    csr_y.grad.to_dense().numpy(),
                    rtol=1e-05,
                    equal_nan=True,
                )

    def func_test_coo(self, op):
        for sparse_dim in range(len(self.coo_shape) - 1, len(self.coo_shape)):
            for dtype in self.support_dtypes:
                x = np.random.randint(-255, 255, size=self.coo_shape).astype(
                    dtype
                )
                y = np.random.randint(-255, 255, size=self.coo_shape).astype(
                    dtype
                )

                dense_x = paddle.to_tensor(x, dtype=dtype, stop_gradient=False)
                dense_y = paddle.to_tensor(y, dtype=dtype, stop_gradient=False)

                s_dense_x = paddle.to_tensor(
                    x, dtype=dtype, stop_gradient=False
                )
                s_dense_y = paddle.to_tensor(
                    y, dtype=dtype, stop_gradient=False
                )
                coo_x = s_dense_x.to_sparse_coo(sparse_dim)
                coo_y = s_dense_y.to_sparse_coo(sparse_dim)
@@ -109,18 +117,24 @@ class TestSparseElementWiseAPI(unittest.TestCase):
                expect_res = op(dense_x, dense_y)
                expect_res.backward(expect_res)

                np.testing.assert_allclose(
                    expect_res.numpy(),
                    actual_res.to_dense().numpy(),
                    rtol=1e-05,
                    equal_nan=True,
                )
                np.testing.assert_allclose(
                    dense_x.grad.numpy(),
                    coo_x.grad.to_dense().numpy(),
                    rtol=1e-05,
                    equal_nan=True,
                )
                np.testing.assert_allclose(
                    dense_y.grad.numpy(),
                    coo_y.grad.to_dense().numpy(),
                    rtol=1e-05,
                    equal_nan=True,
                )

    def test_support_dtypes_csr(self):
        paddle.device.set_device('cpu')
@@ -140,38 +154,37 @@ class TestSparseElementWiseAPI(unittest.TestCase):
        values2_data = [[1.0], [2.0]]
        shape = [2, 4, 2]

        sp_a = sparse.sparse_coo_tensor(
            indices_data, values1_data, shape, stop_gradient=False
        )
        sp_b = sparse.sparse_coo_tensor(
            indices_data, values2_data, shape, stop_gradient=False
        )

        values1 = paddle.to_tensor(values1_data, stop_gradient=False)
        values2 = paddle.to_tensor(values2_data, stop_gradient=False)

        # c.values() = a.values() + b.values()
        sp_c = sparse.add(sp_a, sp_b)
        sp_c.backward()
        ref_c = values1 + values2
        ref_c.backward()
        np.testing.assert_allclose(sp_c.values().numpy(), ref_c.numpy())
        np.testing.assert_allclose(
            sp_a.grad.values().numpy(), values1.grad.numpy()
        )
        np.testing.assert_allclose(
            sp_b.grad.values().numpy(), values2.grad.numpy()
        )

    def test_add_bias(self):
        indices_data = [[0, 1], [0, 3]]
        values_data = [[1.0, 1.0], [2.0, 2.0]]
        shape = [2, 4, 2]

        sp_a = sparse.sparse_coo_tensor(
            indices_data, values_data, shape, stop_gradient=False
        )

        bias_values = [1.0, 2.0]
@@ -179,14 +192,15 @@ class TestSparseElementWiseAPI(unittest.TestCase):
        values2 = paddle.to_tensor(bias_values, stop_gradient=False)
        values3 = paddle.to_tensor(bias_values, stop_gradient=False)

        # c.values() = a.values() + b
        sp_c = sparse.add(sp_a, values2)
        sp_c.backward()
        ref_c = values1 + values3
        ref_c.backward()
        np.testing.assert_allclose(sp_c.values().numpy(), ref_c.numpy())
        np.testing.assert_allclose(
            sp_a.grad.values().numpy(), values1.grad.numpy()
        )
        np.testing.assert_allclose(values2.grad.numpy(), values3.grad.numpy())
......