Unverified commit 8f469ddd authored by zhangkaihuo, committed by GitHub

Add sparse kernel coalesced (#41784)

Parent c31dd04c
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/kernels/funcs/sparse/flatten_indices.h"
namespace phi {
namespace funcs {
namespace sparse {
template <typename IntT>
__global__ void FlattenIndicesKernel(const IntT* indices,
const IntT* sparse_offsets,
const int64_t non_zero_num,
const int64_t sparse_dim,
IntT* out) {
int tid = threadIdx.x + blockIdx.x * blockDim.x;
phi::funcs::sparse::FlattenIndices<IntT>(indices,
sparse_offsets,
non_zero_num,
sparse_dim,
tid,
gridDim.x * blockDim.x,
out);
}
template <typename IntT>
__global__ void IndexToCoordinateKernel(const IntT* indexs,
const Dim<DDim::kMaxRank> dims,
const int64_t non_zero_num,
const int64_t sparse_dim,
IntT* indices) {
int tid = threadIdx.x + blockIdx.x * blockDim.x;
IndexToCoordinate(indexs,
dims,
non_zero_num,
sparse_dim,
tid,
gridDim.x * blockDim.x,
indices);
}
} // namespace sparse
} // namespace funcs
} // namespace phi
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <stdint.h>
#include "paddle/phi/core/ddim.h"
namespace phi {
namespace funcs {
namespace sparse {
template <typename IntT>
inline const IntT HOSTDEVICE CoordinateToIndex(const IntT* indices,
const IntT* sparse_offsets,
const int64_t non_zero_num,
const int64_t sparse_dim,
const int i) {
IntT index = 0;
for (IntT j = 0; j < sparse_dim; j++) {
index += indices[j * non_zero_num + i] * sparse_offsets[j];
}
return index;
}
template <typename IntT>
inline void HOSTDEVICE FlattenIndices(const IntT* indices,
const IntT* sparse_offsets,
const int64_t non_zero_num,
const int64_t sparse_dim,
const int64_t start,
const int64_t stride,
IntT* out) {
for (int64_t i = start; i < non_zero_num; i += stride) {
out[i] =
CoordinateToIndex(indices, sparse_offsets, non_zero_num, sparse_dim, i);
}
}
// assumes indices.dims().size() == 2, i.e. indices has shape [sparse_dim, nnz]
template <typename IntT>
inline void CalcOffsetsPerDim(const DDim& dims,
const int64_t sparse_dim,
IntT* offsets) {
IntT offset = 1;
for (IntT i = sparse_dim - 1; i >= 0; i--) {
offsets[i] = offset;
offset *= dims[i];
}
}
template <typename IntT>
inline void HOSTDEVICE IndexToCoordinate(const IntT index,
const Dim<DDim::kMaxRank>& dims,
const int64_t non_zero_num,
const int64_t sparse_dim,
const int indices_offset,
IntT* indices) {
IntT tmp_index = index;
for (int j = sparse_dim - 1; j >= 0; j--) {
indices[j * non_zero_num + indices_offset] = tmp_index % dims[j];
tmp_index /= dims[j];
}
}
template <typename IntT>
inline void HOSTDEVICE IndexToCoordinate(const IntT* indexs,
const Dim<DDim::kMaxRank>& dims,
const int64_t non_zero_num,
const int64_t sparse_dim,
const int64_t start,
const int64_t stride,
IntT* indices) {
for (int64_t i = start; i < non_zero_num; i += stride) {
IntT tmp_index = indexs[i];
IndexToCoordinate(tmp_index, dims, non_zero_num, sparse_dim, i, indices);
}
}
} // namespace sparse
} // namespace funcs
} // namespace phi
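For reference, the helpers above (CalcOffsetsPerDim, CoordinateToIndex/FlattenIndices, IndexToCoordinate) implement the usual row-major flattening of COO indices and its inverse. A minimal NumPy sketch of the same arithmetic, purely illustrative and independent of the phi headers:

import numpy as np

# Coordinates of 3 non-zero entries of a tensor with sparse dims (4, 5),
# stored column-wise as shape [sparse_dim, nnz], like SparseCooTensor indices.
indices = np.array([[1, 0, 3],
                    [2, 4, 0]], dtype=np.int64)
dims = (4, 5)

# CalcOffsetsPerDim: row-major strides, here (5, 1).
offsets = np.array([int(np.prod(dims[i + 1:])) for i in range(len(dims))],
                   dtype=np.int64)

# CoordinateToIndex / FlattenIndices: dot each coordinate with the strides.
flat = (indices * offsets[:, None]).sum(axis=0)        # -> [7, 4, 15]

# IndexToCoordinate: repeated div/mod recovers the original coordinates.
recovered = np.stack(np.unravel_index(flat, dims))
assert np.array_equal(recovered, indices)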
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
namespace phi {
namespace funcs {
namespace sparse {
/**
* brief: scatter add
* input: the input feature values
* unique_value: refer to UpdateIndexKernel notes
* out_index: the output feature index
* non_zero_num: the number of output features
* rulebook_len: the length of the rulebook
* channels: the output channel size
* out: the output feature values
* subm: if true, accumulate into the existing values in out (submanifold conv)
**/
template <typename T>
__global__ void ScatterKernel(const T* input,
const int* unique_value,
const int* out_index,
const int non_zero_num,
const int rulebook_len,
const int channels,
T* out,
const bool subm = false) {
int tid = threadIdx.x + blockIdx.x * blockDim.x;
for (int i = tid; i < non_zero_num * channels; i += gridDim.x * blockDim.x) {
int indices_i = i / channels;
int channels_i = i - indices_i * channels;
int start = unique_value[indices_i];
int end = indices_i == non_zero_num - 1 ? rulebook_len
: unique_value[indices_i + 1];
// max(end-start) = kernel_size
T sum = static_cast<T>(0);
if (subm) {
sum = out[indices_i * channels + channels_i];
}
for (int j = start; j < end; j++) {
const int out_feature_i = out_index[j];
sum += input[out_feature_i * channels + channels_i];
}
out[indices_i * channels + channels_i] = sum;
}
}
} // namespace sparse
} // namespace funcs
} // namespace phi
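ScatterKernel reduces, for each of the non_zero_num output features, the input rows whose sorted slots fall in [unique_value[i], end), where end is unique_value[i + 1] or rulebook_len for the last feature. A hedged NumPy sketch of that reduction (the array contents are made up for illustration):

import numpy as np

channels = 4
rulebook_len = 6       # number of (input, output) pairs in the rulebook
non_zero_num = 3       # number of distinct output features
inputs = np.arange(rulebook_len * channels, dtype=np.float32).reshape(rulebook_len, channels)

out_index = np.array([0, 3, 1, 4, 2, 5])    # out_index[j]: input row feeding the j-th sorted slot
unique_value = np.array([0, 2, 4])          # start slot of each output feature

out = np.zeros((non_zero_num, channels), dtype=np.float32)
for i in range(non_zero_num):
    start = unique_value[i]
    end = rulebook_len if i == non_zero_num - 1 else unique_value[i + 1]
    out[i] = inputs[out_index[start:end]].sum(axis=0)   # same sum as the CUDA inner loop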
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
namespace phi {
namespace funcs {
namespace sparse {
// brief: calculate the distance between start and end
template <typename T>
__global__ void DistanceKernel(const T* start, const T* end, T* distance) {
if (threadIdx.x == 0) {
*distance = end - start;
}
}
} // namespace sparse
} // namespace funcs
} // namespace phi
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/sparse_coo_tensor.h"
#include "paddle/phi/kernels/empty_kernel.h"
namespace phi {
namespace sparse {
template <typename T, typename Context>
void CoalescedKernel(const Context& dev_ctx,
const SparseCooTensor& x,
SparseCooTensor* out);
} // namespace sparse
} // namespace phi
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/sparse/coalesced_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/visit_type.h"
#include "paddle/phi/kernels/funcs/sparse/flatten_indices.h"
namespace phi {
namespace sparse {
template <typename T, typename IntT>
void CoalescedCPUKernel(const CPUContext& dev_ctx,
const SparseCooTensor& x,
SparseCooTensor* out) {
const DenseTensor& x_indices = x.non_zero_indices();
const DenseTensor& x_values = x.non_zero_elements();
DenseTensor out_indices = phi::EmptyLike<IntT>(dev_ctx, x_indices);
DenseTensor out_values = phi::EmptyLike<T>(dev_ctx, x_values);
const int64_t sparse_dim = x.non_zero_indices().dims()[0];
std::vector<IntT> sparse_offsets(sparse_dim), x_indexs(x.nnz());
phi::funcs::sparse::CalcOffsetsPerDim<IntT>(
x.dims(), sparse_dim, sparse_offsets.data());
phi::funcs::sparse::FlattenIndices(x.non_zero_indices().data<IntT>(),
sparse_offsets.data(),
x.nnz(),
sparse_dim,
0,
1,
x_indexs.data());
const T* x_values_ptr = x_values.data<T>();
const int64_t stride =
x.dims().size() == sparse_dim ? 1 : x.dims().size() - sparse_dim;
std::map<IntT, std::vector<int64_t>> indices_to_index;
for (uint64_t i = 0; i < x_indexs.size(); i++) {
IntT index = x_indexs[i];
if (indices_to_index.find(index) == indices_to_index.end()) {
std::vector<int64_t> indexs;
indexs.push_back(i);
indices_to_index[index] = indexs;
} else {
indices_to_index[index].push_back(i);
}
}
const int64_t out_nnz = indices_to_index.size();
out_indices.Resize({x_indices.dims()[0], out_nnz});
if (out_values.dims().size() == 1) {
out_values.Resize(phi::make_ddim({out_nnz}));
} else {
out_values.Resize(phi::make_ddim({out_nnz, x_values.dims()[1]}));
}
IntT* out_indices_ptr = out_indices.data<IntT>();
T* out_values_ptr = out_values.data<T>();
auto iter = indices_to_index.begin();
Dim<DDim::kMaxRank> const_dims;
for (int i = 0; i < x.dims().size(); i++) {
const_dims[i] = x.dims()[i];
}
for (int i = 0; iter != indices_to_index.end(); iter++, i++) {
phi::funcs::sparse::IndexToCoordinate(
iter->first, const_dims, out_nnz, sparse_dim, i, out_indices_ptr);
memcpy(out_values_ptr + i * stride,
x_values_ptr + iter->second[0] * stride,
stride * sizeof(T));
for (uint64_t j = 1; j < iter->second.size(); j++) {
for (int k = 0; k < stride; k++) {
out_values_ptr[i * stride + k] +=
x_values_ptr[iter->second[j] * stride + k];
}
}
}
out->SetMember(out_indices, out_values, x.dims(), true);
}
template <typename T, typename Context>
void CoalescedKernel(const Context& dev_ctx,
const SparseCooTensor& x,
SparseCooTensor* out) {
PD_VISIT_INTEGRAL_TYPES(
x.non_zero_indices().dtype(), "CoalescedCPUKernel", ([&] {
CoalescedCPUKernel<T, data_t>(dev_ctx, x, out);
}));
}
} // namespace sparse
} // namespace phi
PD_REGISTER_KERNEL(sort,
CPU,
ALL_LAYOUT,
phi::sparse::CoalescedKernel,
float,
double,
phi::dtype::float16,
uint8_t,
int16_t,
int,
int64_t) {
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
}
......@@ -20,7 +20,9 @@ limitations under the License. */
#include "paddle/phi/kernels/copy_kernel.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/sparse/common_shape.h"
#include "paddle/phi/kernels/funcs/sparse/flatten_indices.h"
#include "paddle/phi/api/ext/dispatch.h"
namespace phi {
namespace sparse {
......@@ -56,10 +58,10 @@ void SparseMaskCPUKernel(const CPUContext& dev_ctx,
std::vector<IntT> out_indexs(non_zero_num), sparse_offsets(sparse_dim);
phi::funcs::sparse::CalcOffsetsPerDim<IntT>(
dims, sparse_dim, &sparse_offsets);
dims, sparse_dim, sparse_offsets.data());
for (int64_t i = 0; i < non_zero_num; i++) {
int64_t index = phi::funcs::sparse::IndicesToIndex<IntT>(
int64_t index = phi::funcs::sparse::CoordinateToIndex<IntT>(
indices_ptr, sparse_offsets.data(), non_zero_num, sparse_dim, i);
memcpy(out_values_ptr + i * cols, x_ptr + index * cols, cols * sizeof(T));
}
......@@ -98,7 +100,7 @@ void SparseMaskHelperCPUKernel(const CPUContext& dev_ctx,
std::vector<IntT> sparse_offsets(sparse_dim), x_indexs(x.nnz()),
mask_indexs(mask_indices.dims()[1]);
phi::funcs::sparse::CalcOffsetsPerDim<IntT>(
x.dims(), sparse_dim, &sparse_offsets);
x.dims(), sparse_dim, sparse_offsets.data());
phi::funcs::sparse::FlattenIndices(x.non_zero_indices().data<IntT>(),
sparse_offsets.data(),
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/visit_type.h"
#include "paddle/phi/kernels/funcs/index_impl.cu.h"
#include "paddle/phi/kernels/funcs/sparse/flatten_indices.cu.h"
#include "paddle/phi/kernels/funcs/sparse/scatter.cu.h"
#include "paddle/phi/kernels/funcs/sparse/utils.cu.h"
#include "paddle/phi/kernels/sparse/coalesced_kernel.h"
namespace phi {
namespace sparse {
template <typename T, typename IntT>
void CoalescedGPUKernel(const GPUContext& dev_ctx,
const SparseCooTensor& x,
SparseCooTensor* out) {
const DenseTensor& x_indices = x.non_zero_indices();
const DenseTensor& x_values = x.non_zero_elements();
DenseTensor out_indices = phi::EmptyLike<IntT>(dev_ctx, x_indices);
DenseTensor out_values = phi::EmptyLike<T>(dev_ctx, x_values);
const int64_t nnz = x.nnz();
const int64_t sparse_dim = x.non_zero_indices().dims()[0];
std::vector<IntT> sparse_offsets(sparse_dim);
phi::funcs::sparse::CalcOffsetsPerDim<IntT>(
x.dims(), sparse_dim, sparse_offsets.data());
DenseTensorMeta sparse_offset_meta(
paddle::experimental::CppTypeToDataType<IntT>::Type(),
{sparse_dim},
DataLayout::NCHW);
DenseTensor d_sparse_offsets =
phi::Empty<GPUContext>(dev_ctx, std::move(sparse_offset_meta));
DenseTensor indexs = phi::Empty(
dev_ctx, DenseTensorMeta(x_indices.dtype(), {nnz}, x_indices.layout()));
IntT* indexs_ptr = indexs.data<IntT>();
phi::backends::gpu::GpuMemcpyAsync(d_sparse_offsets.data<IntT>(),
sparse_offsets.data(),
sizeof(IntT) * sparse_dim,
#ifdef PADDLE_WITH_HIP
hipMemcpyHostToDevice,
#else
cudaMemcpyHostToDevice,
#endif
dev_ctx.stream());
// 1. flatten indices
auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, nnz, 1);
phi::funcs::sparse::FlattenIndicesKernel<<<config.block_per_grid,
config.thread_per_block,
0,
dev_ctx.stream()>>>(
x.non_zero_indices().data<IntT>(),
d_sparse_offsets.data<IntT>(),
indexs.numel(),
sparse_dim,
indexs_ptr);
// 2. get the position index of each non-zero value
const T* x_values_ptr = x_values.data<T>();
const int64_t stride =
x.dims().size() == sparse_dim ? 1 : x.dims().size() - sparse_dim;
DenseTensor values_indexs = phi::Empty(
dev_ctx, DenseTensorMeta(DataType::INT32, {nnz}, DataLayout::NCHW));
int* values_indexs_ptr = values_indexs.data<int>();
DenseTensor public_indexs = phi::EmptyLike<int>(dev_ctx, values_indexs);
// values_indexs = [0, 1, 2, ..., nnz-1]
phi::IndexKernel<int, kps::IdentityFunctor<int>>(
dev_ctx, &values_indexs, kps::IdentityFunctor<int>());
phi::IndexKernel<int, kps::IdentityFunctor<int>>(
dev_ctx, &public_indexs, kps::IdentityFunctor<int>());
// 3. sort (indices, values index)
#ifdef PADDLE_WITH_HIP
thrust::sort_by_key(thrust::hip::par.on(dev_ctx.stream()),
#else
thrust::sort_by_key(thrust::cuda::par.on(dev_ctx.stream()),
#endif
indexs_ptr,
indexs_ptr + nnz,
values_indexs_ptr);
// 4. unique index
thrust::pair<IntT*, int*> new_end =
#ifdef PADDLE_WITH_HIP
thrust::unique_by_key(thrust::hip::par.on(dev_ctx.stream()),
#else
thrust::unique_by_key(thrust::cuda::par.on(dev_ctx.stream()),
#endif
indexs_ptr,
indexs_ptr + nnz,
public_indexs.data<int>());
phi::funcs::sparse::DistanceKernel<<<1, 1, 0, dev_ctx.stream()>>>(
indexs_ptr, new_end.first, out_indices.data<IntT>());
IntT out_nnz = 0;
phi::backends::gpu::GpuMemcpyAsync(&out_nnz,
out_indices.data<IntT>(),
sizeof(IntT),
#ifdef PADDLE_WITH_HIP
hipMemcpyDeviceToHost,
#else
cudaMemcpyDeviceToHost,
#endif
dev_ctx.stream());
dev_ctx.Wait();
out_indices.Resize({x_indices.dims()[0], out_nnz});
if (out_values.dims().size() == 1) {
out_values.Resize(phi::make_ddim({out_nnz}));
} else {
out_values.Resize(phi::make_ddim({out_nnz, x_values.dims()[1]}));
}
// 5. scatter the values
config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, nnz * stride, 1);
phi::funcs::sparse::ScatterKernel<T><<<config.block_per_grid,
config.thread_per_block,
0,
dev_ctx.stream()>>>(
x_values_ptr,
public_indexs.data<int>(),
values_indexs_ptr,
out_nnz,
nnz,
stride,
out_values.data<T>());
// 6. convert index to coordinate
Dim<DDim::kMaxRank> const_dims;
for (int i = 0; i < x.dims().size(); i++) {
const_dims[i] = x.dims()[i];
}
config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, out_nnz, 1);
phi::funcs::sparse::IndexToCoordinateKernel<<<config.block_per_grid,
config.thread_per_block,
0,
dev_ctx.stream()>>>(
indexs_ptr, const_dims, out_nnz, sparse_dim, out_indices.data<IntT>());
out->SetMember(out_indices, out_values, x.dims(), true);
}
template <typename T, typename Context>
void CoalescedKernel(const Context& dev_ctx,
const SparseCooTensor& x,
SparseCooTensor* out) {
PD_VISIT_INTEGRAL_TYPES(
x.non_zero_indices().dtype(), "CoalescedGPUKernel", ([&] {
CoalescedGPUKernel<T, data_t>(dev_ctx, x, out);
}));
}
} // namespace sparse
} // namespace phi
PD_REGISTER_KERNEL(sort,
GPU,
ALL_LAYOUT,
phi::sparse::CoalescedKernel,
float,
double,
phi::dtype::float16,
uint8_t,
int16_t,
int,
int64_t) {
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
}
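On the GPU the same coalescing is expressed data-parallel: flatten the indices (step 1), sort (flattened index, value position) pairs with thrust::sort_by_key (step 3), deduplicate with thrust::unique_by_key (step 4), scatter-add the values (step 5), and unflatten the surviving keys (step 6). A rough NumPy analogue of that pipeline, for illustration only:

import numpy as np

dims = (3, 3)
indices = np.array([[1, 0, 0],
                    [0, 1, 1]])
values = np.array([1.0, 2.0, 3.0])

flat = np.ravel_multi_index(tuple(indices), dims)            # 1. flatten indices
order = np.argsort(flat, kind="stable")                      # 3. sort_by_key analogue
uniq, inverse = np.unique(flat[order], return_inverse=True)  # 4. unique_by_key analogue

out_values = np.zeros(len(uniq))
np.add.at(out_values, inverse, values[order])                # 5. scatter-add the values
out_indices = np.stack(np.unravel_index(uniq, dims))         # 6. index -> coordinate
# out_indices -> [[0, 1], [1, 0]], out_values -> [5.0, 1.0]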
......@@ -26,6 +26,7 @@ limitations under the License. */
#include "paddle/phi/kernels/copy_kernel.h"
#include "paddle/phi/kernels/funcs/index_impl.cu.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/sparse/utils.cu.h"
#include "paddle/phi/kernels/primitive/compute_primitives.h"
#include "paddle/phi/kernels/sparse/convolution_kernel.h"
......@@ -60,46 +61,6 @@ __global__ void GatherKernel(const T* params,
}
}
/**
* brief: scatter add
* input: the inputs
* unique_value: refer to UpdateIndexKernel notes
* out_index: the output feature index
* non_zero_num: the number of output features
* rulebook_len: the length of rulebook
* channels: the output channel size
* out: the outputs
**/
template <typename T>
__global__ void ScatterKernel(const T* input,
const int* unique_value,
const int* out_index,
const int non_zero_num,
const int rulebook_len,
const int channels,
T* out,
const bool subm = false) {
int tid = threadIdx.x + blockIdx.x * blockDim.x;
for (int i = tid; i < non_zero_num * channels; i += gridDim.x * blockDim.x) {
int indices_i = i / channels;
int channels_i = i - indices_i * channels;
int start = unique_value[indices_i];
int end = indices_i == non_zero_num - 1 ? rulebook_len
: unique_value[indices_i + 1];
// max(end-start) = kernel_size
T sum = static_cast<T>(0);
if (subm) {
sum = out[indices_i * channels + channels_i];
}
for (int j = start; j < end; j++) {
const int out_feature_i = out_index[j];
sum += input[out_feature_i * channels + channels_i];
}
out[indices_i * channels + channels_i] = sum;
}
}
template <typename Context, typename IntT = int>
inline IntT* SortedAndUniqueIndex(const Context& dev_ctx,
const IntT* rulebook_ptr,
......@@ -186,14 +147,6 @@ __global__ void UpdateIndexKernel(const T* unique_keys,
}
}
// brief: calculation the distance between start and end
template <typename T>
__global__ void DistanceKernel(const T* start, const T* end, T* distance) {
if (threadIdx.x == 0) {
*distance = end - start;
}
}
template <typename IntT>
__global__ void UpdateOutIndexAndCounterAfterLowerBound(
const IntT* x_indexs,
......@@ -402,7 +355,7 @@ int ProductRuleBook(const Context& dev_ctx,
rulebook_ptr + rulebook_rows * rulebook_cols,
-1);
DistanceKernel<IntT><<<1, 1, 0, dev_ctx.stream()>>>(
phi::funcs::sparse::DistanceKernel<IntT><<<1, 1, 0, dev_ctx.stream()>>>(
rulebook_ptr, last, rulebook_ptr + 3 * kernel_size * non_zero_num - 1);
IntT rulebook_len = 0;
phi::backends::gpu::GpuMemcpyAsync(
......@@ -468,7 +421,7 @@ int ProductRuleBook(const Context& dev_ctx,
rulebook_ptr,
rulebook_ptr + 3 * rulebook_len,
-1);
DistanceKernel<IntT><<<1, 1, 0, dev_ctx.stream()>>>(
phi::funcs::sparse::DistanceKernel<IntT><<<1, 1, 0, dev_ctx.stream()>>>(
rulebook_ptr, last, bound_ptr);
phi::backends::gpu::GpuMemcpyAsync(&rulebook_len,
bound_ptr,
......@@ -536,7 +489,7 @@ int ProductRuleBook(const Context& dev_ctx,
// thrust::distance doesn't support stream parameters
// const int out_non_zero_num = thrust::distance(unique_key_ptr,
// new_end.first);
DistanceKernel<IntT><<<1, 1, 0, dev_ctx.stream()>>>(
phi::funcs::sparse::DistanceKernel<IntT><<<1, 1, 0, dev_ctx.stream()>>>(
unique_key_ptr,
new_end,
rulebook_ptr + rulebook_rows * rulebook_cols - 1);
......
......@@ -22,6 +22,7 @@ limitations under the License. */
#include "paddle/phi/kernels/copy_kernel.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/sparse/scatter.cu.h"
#include "paddle/phi/kernels/sparse/convolution_grad_kernel.h"
#include "paddle/phi/kernels/sparse/gpu/convolution.cu.h"
......@@ -222,10 +223,11 @@ void Conv3dGradGPUKernel(const GPUContext& dev_ctx,
config = phi::backends::gpu::GetGpuLaunchConfig1D(
dev_ctx, rulebook_len * in_channels, 1);
ScatterKernel<T><<<config.block_per_grid.x,
phi::funcs::sparse::ScatterKernel<T><<<config.block_per_grid.x,
config.thread_per_block.x,
0,
dev_ctx.stream()>>>(d_x_features_ptr,
dev_ctx.stream()>>>(
d_x_features_ptr,
unique_value.data<int>(),
out_index.data<int>(),
x.nnz(),
......
......@@ -18,6 +18,7 @@ limitations under the License. */
#include "paddle/phi/core/visit_type.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/scatter.cu.h"
#include "paddle/phi/kernels/funcs/sparse/scatter.cu.h"
#include "paddle/phi/kernels/sparse/convolution_kernel.h"
#include "paddle/phi/kernels/sparse/gpu/convolution.cu.h"
......@@ -169,10 +170,11 @@ void Conv3dGPUKernel(const GPUContext& dev_ctx,
} else {
config = phi::backends::gpu::GetGpuLaunchConfig1D(
dev_ctx, out->nnz() * out_channels, 1);
ScatterKernel<T><<<config.block_per_grid.x,
phi::funcs::sparse::ScatterKernel<T><<<config.block_per_grid.x,
config.thread_per_block.x,
0,
dev_ctx.stream()>>>(out_features_ptr,
dev_ctx.stream()>>>(
out_features_ptr,
unique_value.data<int>(),
out_index.data<int>(),
out->nnz(),
......
......@@ -23,7 +23,7 @@ limitations under the License. */
#include "paddle/phi/kernels/copy_kernel.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/sparse/common_shape.h"
#include "paddle/phi/kernels/funcs/sparse/flatten_indices.cu.h"
#include "paddle/phi/kernels/sparse/sparse_mask_kernel.h"
namespace phi {
......@@ -123,23 +123,6 @@ void SparseMaskKernel(const Context& dev_ctx,
}));
}
// TODO(zhangkaihuo): Use an op to realize the function of FlattenIndices
template <typename IntT>
__global__ void FlattenIndicesKernel(const IntT* indices,
const IntT* sparse_offsets,
const int64_t non_zero_num,
const int64_t sparse_dim,
IntT* out) {
int tid = threadIdx.x + blockIdx.x * blockDim.x;
phi::funcs::sparse::FlattenIndices<IntT>(indices,
sparse_offsets,
non_zero_num,
sparse_dim,
tid,
gridDim.x * blockDim.x,
out);
}
template <typename T, typename IntT>
__global__ void SparseMaskCopyKernel(const IntT* x_indexs,
const IntT* mask_indexs,
......@@ -192,7 +175,8 @@ void SparseMaskHelperGPUKernel(const GPUContext& dev_ctx,
IntT* bound_out_ptr = bound_out.data<IntT>();
// 1. calc the offsets of each dim
phi::funcs::sparse::CalcOffsetsPerDim(x.dims(), sparse_dim, &sparse_offsets);
phi::funcs::sparse::CalcOffsetsPerDim(
x.dims(), sparse_dim, sparse_offsets.data());
// 2. copy sparse_offsets to device
phi::backends::gpu::GpuMemcpyAsync(d_sparse_offsets.data<IntT>(),
sparse_offsets.data(),
......@@ -207,10 +191,11 @@ void SparseMaskHelperGPUKernel(const GPUContext& dev_ctx,
// 3. flatten x indices and mask indices
auto config =
phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, x_indexs.numel(), 1);
FlattenIndicesKernel<<<config.block_per_grid,
phi::funcs::sparse::FlattenIndicesKernel<<<config.block_per_grid,
config.thread_per_block,
0,
dev_ctx.stream()>>>(x.non_zero_indices().data<IntT>(),
dev_ctx.stream()>>>(
x.non_zero_indices().data<IntT>(),
d_sparse_offsets.data<IntT>(),
x_indexs.numel(),
sparse_dim,
......@@ -218,10 +203,11 @@ void SparseMaskHelperGPUKernel(const GPUContext& dev_ctx,
config =
phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, mask_indexs.numel(), 1);
FlattenIndicesKernel<<<config.block_per_grid,
phi::funcs::sparse::FlattenIndicesKernel<<<config.block_per_grid,
config.thread_per_block,
0,
dev_ctx.stream()>>>(mask_indices.data<IntT>(),
dev_ctx.stream()>>>(
mask_indices.data<IntT>(),
d_sparse_offsets.data<IntT>(),
mask_indexs.numel(),
sparse_dim,
......
......@@ -20,6 +20,7 @@ limitations under the License. */
#include "paddle/phi/core/sparse_coo_tensor.h"
#include "paddle/phi/core/sparse_csr_tensor.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/sparse/coalesced_kernel.h"
namespace phi {
namespace sparse {
......@@ -154,9 +155,9 @@ void SparseCooTensorKernel(const Context& dev_ctx,
const DenseTensor& indices,
const IntArray& dense_shape,
SparseCooTensor* out) {
*out =
SparseCooTensor(indices, values, phi::make_ddim(dense_shape.GetData()));
// TODO(zhangkaihuo): sort and merge the duplicate indices
SparseCooTensor before_coalesced(
indices, values, phi::make_ddim(dense_shape.GetData()));
CoalescedKernel<T, Context>(dev_ctx, before_coalesced, out);
}
} // namespace sparse
......
......@@ -19,6 +19,8 @@ import paddle
import paddle.fluid.core as core
from paddle.fluid.framework import _test_eager_guard
devices = ['cpu', 'gpu']
class TestSparseCreate(unittest.TestCase):
def test_create_coo_by_tensor(self):
......@@ -30,6 +32,8 @@ class TestSparseCreate(unittest.TestCase):
dense_elements = paddle.to_tensor(values, dtype='float32')
coo = paddle.sparse.sparse_coo_tensor(
dense_indices, dense_elements, dense_shape, stop_gradient=False)
# test printing via to_string.py
print(coo)
assert np.array_equal(indices, coo.indices().numpy())
assert np.array_equal(values, coo.values().numpy())
......@@ -37,7 +41,7 @@ class TestSparseCreate(unittest.TestCase):
with _test_eager_guard():
indices = [[0, 1, 2], [1, 2, 0]]
values = [1.0, 2.0, 3.0]
dense_shape = [2, 3]
dense_shape = [3, 3]
coo = paddle.sparse.sparse_coo_tensor(indices, values, dense_shape)
assert np.array_equal(indices, coo.indices().numpy())
assert np.array_equal(values, coo.values().numpy())
......@@ -67,6 +71,8 @@ class TestSparseCreate(unittest.TestCase):
dense_shape = [3, 4]
csr = paddle.sparse.sparse_csr_tensor(crows, cols, values,
dense_shape)
# test printing via to_string.py
print(csr)
assert np.array_equal(crows, csr.crows().numpy())
assert np.array_equal(cols, csr.cols().numpy())
assert np.array_equal(values, csr.values().numpy())
......@@ -205,6 +211,10 @@ class TestSparseConvert(unittest.TestCase):
def test_sparse_coo_tensor_grad(self):
with _test_eager_guard():
for device in devices:
if device == 'cpu' or (device == 'gpu' and
paddle.is_compiled_with_cuda()):
paddle.device.set_device(device)
indices = [[0, 1], [0, 1]]
values = [1, 2]
indices = paddle.to_tensor(indices, dtype='int32')
......@@ -220,23 +230,135 @@ class TestSparseConvert(unittest.TestCase):
grad_indices, grad_values, shape=[2, 2])
sparse_x.backward(sparse_out_grad)
correct_values_grad = [0, 3]
assert np.array_equal(correct_values_grad, values.grad.numpy())
assert np.array_equal(correct_values_grad,
values.grad.numpy())
place = core.CPUPlace()
indices_cpu = paddle.to_tensor(indices, dtype='int32', place=place)
values_cpu = paddle.to_tensor(
values, dtype='float32', place=place, stop_gradient=False)
sparse_x_cpu = paddle.sparse.sparse_coo_tensor(
indices_cpu,
values_cpu,
shape=[2, 2],
place=place,
stop_gradient=False)
def test_sparse_coo_tensor_sorted(self):
with _test_eager_guard():
for device in devices:
if device == 'cpu' or (device == 'gpu' and
paddle.is_compiled_with_cuda()):
paddle.device.set_device(device)
# test unsorted and duplicate indices
indices = [[1, 0, 0], [0, 1, 1]]
values = [1.0, 2.0, 3.0]
indices = paddle.to_tensor(indices, dtype='int32')
values = paddle.to_tensor(values, dtype='float32')
sparse_x = paddle.sparse.sparse_coo_tensor(indices, values)
indices_sorted = [[0, 1], [1, 0]]
values_sorted = [5.0, 1.0]
assert np.array_equal(indices_sorted,
sparse_x.indices().numpy())
assert np.array_equal(values_sorted,
sparse_x.values().numpy())
class TestCooError(unittest.TestCase):
def test_small_shape(self):
with _test_eager_guard():
with self.assertRaises(ValueError):
indices = [[2, 3], [0, 2]]
values = [1, 2]
# 1. the shape is too small
dense_shape = [2, 2]
sparse_x = paddle.sparse.sparse_coo_tensor(
indices, values, shape=dense_shape)
def test_same_nnz(self):
with _test_eager_guard():
with self.assertRaises(ValueError):
# 2. the nnz of indices must be the same as the nnz of values
indices = [[1, 2], [1, 0]]
values = [1, 2, 3]
sparse_x = paddle.sparse.sparse_coo_tensor(indices, values)
def test_same_dimensions(self):
with _test_eager_guard():
with self.assertRaises(ValueError):
indices = [[1, 2], [1, 0]]
values = [1, 2, 3]
shape = [2, 3, 4]
sparse_x = paddle.sparse.sparse_coo_tensor(
indices, values, shape=shape)
sparse_out_grad_cpu = paddle.sparse.sparse_coo_tensor(
grad_indices, grad_values, shape=[2, 2], place=place)
sparse_x_cpu.backward(sparse_out_grad_cpu)
assert np.array_equal(correct_values_grad, values_cpu.grad.numpy())
def test_indices_dtype(self):
with _test_eager_guard():
with self.assertRaises(TypeError):
indices = [[1.0, 2.0], [0, 1]]
values = [1, 2]
sparse_x = paddle.sparse.sparse_coo_tensor(indices, values)
class TestCsrError(unittest.TestCase):
def test_dimension1(self):
with _test_eager_guard():
with self.assertRaises(ValueError):
crows = [0, 1, 2, 3]
cols = [0, 1, 2]
values = [1, 2, 3]
shape = [3]
sparse_x = paddle.sparse.sparse_csr_tensor(crows, cols, values,
shape)
def test_dimension2(self):
with _test_eager_guard():
with self.assertRaises(ValueError):
crows = [0, 1, 2, 3]
cols = [0, 1, 2]
values = [1, 2, 3]
shape = [3, 3, 3, 3]
sparse_x = paddle.sparse.sparse_csr_tensor(crows, cols, values,
shape)
def test_same_shape1(self):
with _test_eager_guard():
with self.assertRaises(ValueError):
crows = [0, 1, 2, 3]
cols = [0, 1, 2, 3]
values = [1, 2, 3]
shape = [3, 4]
sparse_x = paddle.sparse.sparse_csr_tensor(crows, cols, values,
shape)
def test_same_shape2(self):
with _test_eager_guard():
with self.assertRaises(ValueError):
crows = [0, 1, 2, 3]
cols = [0, 1, 2, 3]
values = [1, 2, 3, 4]
shape = [3, 4]
sparse_x = paddle.sparse.sparse_csr_tensor(crows, cols, values,
shape)
def test_same_shape3(self):
with _test_eager_guard():
with self.assertRaises(ValueError):
crows = [0, 1, 2, 3, 0, 1, 2]
cols = [0, 1, 2, 3, 0, 1, 2]
values = [1, 2, 3, 4, 0, 1, 2]
shape = [2, 3, 4]
sparse_x = paddle.sparse.sparse_csr_tensor(crows, cols, values,
shape)
def test_crows_first_value(self):
with _test_eager_guard():
with self.assertRaises(ValueError):
crows = [1, 1, 2, 3]
cols = [0, 1, 2]
values = [1, 2, 3]
shape = [3, 4]
sparse_x = paddle.sparse.sparse_csr_tensor(crows, cols, values,
shape)
def test_dtype(self):
with _test_eager_guard():
with self.assertRaises(TypeError):
crows = [0, 1, 2, 3.0]
cols = [0, 1, 2]
values = [1, 2, 3]
shape = [3]
sparse_x = paddle.sparse.sparse_csr_tensor(crows, cols, values,
shape)
if __name__ == "__main__":
......
......@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle import _C_ops
from ..framework import core, dygraph_only
from ..framework import _current_expected_place, _get_paddle_place
......@@ -51,6 +52,13 @@ def _get_place(place):
return place
def _check_indices_dtype(dtype):
if dtype not in [paddle.int8, paddle.int16, paddle.int32, paddle.int64]:
raise TypeError(
"the dtype of indices must be 'int8' or 'int16' or 'int32' or 'int64'"
)
@dygraph_only
def sparse_coo_tensor(indices,
values,
......@@ -117,6 +125,18 @@ def sparse_coo_tensor(indices,
if len(indices.shape) != 2:
raise ValueError("'indices' must be 2-D.")
nnz = indices.shape[1]
sparse_dim = indices.shape[0]
_check_indices_dtype(indices.dtype)
if nnz != values.shape[0]:
raise ValueError(
"the indices and values must have same number of non-zero, but get {} and {}".
format(nnz, values.shape[0]))
dense_dim = len(values.shape) - 1
if not indices.place._equals(place):
indices = indices._copy_to(place, False)
......@@ -125,8 +145,17 @@ def sparse_coo_tensor(indices,
values = _handle_dtype(values, dtype)
values.stop_gradient = stop_gradient
min_shape = _infer_dense_shape(indices)
if shape is None:
shape = _infer_dense_shape(indices)
shape = min_shape
else:
if shape < min_shape:
raise ValueError("the minimun shape required is {}, but get {}".
format(min_shape, shape))
if len(shape) != sparse_dim + dense_dim:
raise ValueError(
"the number of dimensions(len(shape) must be sparse_dim({}) + dense_dim({}), but get {}".
format(sparse_dim, dense_dim, len(shape)))
return _C_ops.final_state_sparse_create_sparse_coo_tensor(values, indices,
shape)
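Taken together, the checks above require sparse_dim = indices.shape[0], dense_dim = values.ndim - 1, len(shape) == sparse_dim + dense_dim, and a shape no smaller than the one inferred from the indices. A hedged usage sketch (eager mode, as in the tests above; the concrete numbers are only an example):

import paddle
from paddle.fluid.framework import _test_eager_guard

with _test_eager_guard():
    indices = [[0, 1, 2], [1, 2, 0]]                  # sparse_dim = 2
    values = [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]]     # dense_dim = 1 (each value is a length-2 vector)
    # len(shape) must be 2 + 1 = 3 and cover every index, so [3, 3, 2] is valid.
    coo = paddle.sparse.sparse_coo_tensor(indices, values, shape=[3, 3, 2])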
......@@ -144,6 +173,7 @@ def sparse_csr_tensor(crows,
r"""
Constructs a sparse ``paddle.Tensor`` in CSR(Compressed Sparse Row) format according to the
``crows``, ``cols`` and ``values``.
Currently, the crows and cols of each batch must be in ascending order.
Args:
crows(list|tuple|ndarray|Tensor): 1-D array, each element in the rows represents the
......@@ -202,10 +232,14 @@ def sparse_csr_tensor(crows,
cols = to_tensor(cols, dtype=None, place=place, stop_gradient=True)
if not isinstance(values, core.eager.Tensor):
values = to_tensor(values, dtype, place, stop_gradient)
if len(crows.shape) != 1 or len(cols.shape) != 1 or len(values.shape) != 1:
_check_indices_dtype(crows.dtype)
_check_indices_dtype(cols.dtype)
if len(shape) != 2 and len(shape) != 3:
raise ValueError(
"SparseCsrTensor only support 2-D or 3-D matrix. The 'crows', 'cols' and 'values' must be 1-D."
)
"SparseCsrTensor only support 2-D or 3-D matrix. but get shape {}".
format(shape))
if not crows.place._equals(place):
crows = crows._copy_to(place, False)
......@@ -217,5 +251,30 @@ def sparse_csr_tensor(crows,
values = values._copy_to(place, False)
values = _handle_dtype(values, dtype)
values.stop_gradient = stop_gradient
if len(crows.shape) != 1 or len(cols.shape) != 1 or len(values.shape) != 1:
raise ValueError("The 'crows', 'cols' and 'values' must be 1-D.")
if (len(cols) != len(values)):
raise ValueError("the length of cols must be same as length of values")
if len(shape) == 2:
if crows.shape[0] != shape[0] + 1:
raise ValueError(
"The length({}) of crows must be equal to the rows({})+1 of matrix.".
format(crows.shape[0], shape[0]))
if crows[0] != 0:
raise ValueError("the 0th value of crows must be 0")
if crows[-1] != values.shape[0]:
raise ValueError(
"the last value of crows must be equal the number of non-zero")
else:
if crows.shape[0] % (shape[0] + 1) != 0:
raise ValueError(
"The length({}) of crows must be divisible the rows({})+1 of matrix.".
format(crows.shape[0], shape[0]))
# TODO(zkh2016): check whether the values in crows and cols are legal
return core.eager.sparse_csr_tensor(crows, cols, values, shape,
stop_gradient)
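The crows checks above pin down the CSR layout for a single 2-D matrix: crows has rows + 1 entries, starts at 0, and ends at the number of non-zeros (monotonicity is left to the TODO above). A small example that satisfies them (hedged sketch, eager mode as in the tests above):

import paddle
from paddle.fluid.framework import _test_eager_guard

with _test_eager_guard():
    # 3 x 4 matrix with one non-zero per row.
    crows = [0, 1, 2, 3]      # len(crows) == rows + 1, crows[0] == 0, crows[-1] == nnz
    cols = [1, 3, 2]
    values = [1.0, 2.0, 3.0]
    csr = paddle.sparse.sparse_csr_tensor(crows, cols, values, [3, 4])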
......@@ -27,6 +27,7 @@
kernel :
func : sparse_coo_tensor
layout : values
data_type : values
backward : create_sparse_coo_tensor_grad
- api : csr_values
......