Unverified commit 78132fe1 authored by zhangkaihuo, committed by GitHub

Add a Sparse OP: to_sparse_coo (#39264)

* dense_to_sparse_coo

* optimize unit testing; support rocm

* 1. delete fluid-related header file
2. update the copyright

* fix hipMemcpy

* update dense_to_sparse_coo

* add namespace sparse
Parent commit: 2d6d6fa1
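For orientation, a minimal sketch of what the new API computes, using the same 3x3 input as the test_sparse_utils_api.cc test added in this commit (illustrative only, not part of the diff):

// Dense input x (3 x 3):
//   0.0  1.0  0.0
//   2.0  0.0  3.0
//   3.2  0.0  0.0
// to_sparse_coo(x, Backend::CPU, /*sparse_dim=*/2) returns a Tensor whose impl
// is a SparseCooTensor with nnz() == 4:
//   non_zero_indices()  (2 x 4): 0, 1, 1, 2   <- row of each non-zero
//                                1, 0, 2, 0   <- column of each non-zero
//   non_zero_elements() (4)    : 1.0, 2.0, 3.0, 3.2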
add_subdirectory(lib)
cc_library(pten_api SRCS all.cc DEPS pten_function_api pten_bw_function_api manual_api)
cc_library(pten_api SRCS all.cc DEPS pten_function_api pten_bw_function_api manual_api sparse_api)
@@ -27,6 +27,7 @@ limitations under the License. */
// new pten apis
#include "paddle/pten/api/include/api.h"
#include "paddle/pten/api/include/manual_api.h"
#include "paddle/pten/api/include/sparse_api.h"
#include "paddle/pten/api/include/tensor.h"
// pten common headers
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/pten/api/include/tensor.h"
#include "paddle/pten/common/backend.h"
namespace paddle {
namespace experimental {
namespace sparse {
PADDLE_API Tensor to_sparse_coo(const Tensor& x,
Backend backend,
const int64_t sparse_dim);
} // namespace sparse
} // namespace experimental
} // namespace paddle
@@ -63,5 +63,6 @@ add_custom_command(
VERBATIM)
cc_library(manual_api SRCS manual_api.cc DEPS pten_tensor pten kernel_dispatch)
cc_library(sparse_api SRCS sparse_api.cc DEPS pten_tensor pten kernel_dispatch)
cc_library(pten_function_api SRCS ${api_source_file} DEPS pten_tensor pten kernel_dispatch)
cc_library(pten_bw_function_api SRCS ${bw_api_source_file} DEPS pten_tensor pten kernel_dispatch backward_infermeta pten_function_api)
@@ -19,3 +19,4 @@ limitations under the License. */
PT_DECLARE_API(Math);
PT_DECLARE_API(Utils);
PT_DECLARE_API(SparseApi);
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/pten/api/include/sparse_api.h"
#include <memory>
#include "glog/logging.h"
#include "paddle/pten/api/lib/api_registry.h"
#include "paddle/pten/api/lib/kernel_dispatch.h"
#include "paddle/pten/api/lib/utils/storage.h"
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/infermeta/unary.h"
PT_DECLARE_KERNEL(dense_to_sparse_coo, CPU, ALL_LAYOUT);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PT_DECLARE_KERNEL(dense_to_sparse_coo, GPU, ALL_LAYOUT);
#endif
namespace paddle {
namespace experimental {
namespace sparse {
PADDLE_API Tensor to_sparse_coo(const Tensor& x,
Backend backend,
const int64_t sparse_dim) {
if (x.layout() == pten::DataLayout::SPARSE_COO) {
return x;
}
// 1. Get kernel signature and kernel
auto kernel_key_set = ParseKernelKeyByInputArgs(x);
kernel_key_set.backend_set = kernel_key_set.backend_set | BackendSet(backend);
auto kernel_key = kernel_key_set.GetHigestPriorityKernelKey();
std::string kernel_name = "dense_to_sparse_coo";
if (x.layout() == pten::DataLayout::SPARSE_CSR) {
kernel_name = "sparse_csr_to_coo";
}
auto kernel = pten::KernelFactory::Instance().SelectKernelOrThrowError(
kernel_name, kernel_key);
VLOG(6) << "to API kernel key: " << kernel_key;
VLOG(6) << "to API kernel: " << kernel;
// 2. Get Device Context
auto* dev_ctx = GetDeviceContextByBackend(kernel_key.backend());
auto kernel_context = pten::KernelContext(dev_ctx);
// 3. Auto data transform
if (x.layout() == pten::DataLayout::SPARSE_CSR) {
auto input = std::dynamic_pointer_cast<pten::SparseCsrTensor>(x.impl());
kernel_context.EmplaceBackInput(input.get());
} else {
auto input = std::dynamic_pointer_cast<pten::DenseTensor>(x.impl());
kernel_context.EmplaceBackInput(input.get());
kernel_context.EmplaceBackAttr(sparse_dim);
}
// 4. InferMeta
auto indices_meta = pten::DenseTensorMeta(
pten::DataType::INT64, {-1}, pten::DataLayout::NCHW);
auto elements_meta = pten::DenseTensorMeta(x.dtype(), {-1}, x.layout());
// 5. Prepare outputs
// create empty SparseCooTensor
pten::DenseTensor non_zero_indices(
pten::make_intrusive<paddle::experimental::SharedStorage>(
pten::TransToFluidPlace(backend)),
std::move(indices_meta));
pten::DenseTensor non_zero_elements(
pten::make_intrusive<paddle::experimental::SharedStorage>(
pten::TransToFluidPlace(backend)),
std::move(elements_meta));
auto coo = std::make_shared<pten::SparseCooTensor>(
non_zero_indices, non_zero_elements, x.dims());
kernel_context.EmplaceBackOutput(coo.get());
Tensor out;
out.set_impl(coo);
// 6. Call kernel
kernel(&kernel_context);
return out;
}
} // namespace sparse
} // namespace experimental
} // namespace paddle
PT_REGISTER_API(SparseApi);
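A condensed caller-side sketch of the dispatch path above, mirroring the API unit test added later in this diff; dense_x is assumed to be a std::shared_ptr<pten::DenseTensor> on the CPU that has already been filled with data:

#include <memory>
#include "paddle/pten/api/include/sparse_api.h"

// Wrap the existing DenseTensor in the public Tensor type and convert it.
paddle::experimental::Tensor x(dense_x);
auto out = paddle::experimental::sparse::to_sparse_coo(
    x, pten::Backend::CPU, /*sparse_dim=*/2);
// The result holds a SparseCooTensor; indices and values can then be inspected.
auto coo = std::dynamic_pointer_cast<pten::SparseCooTensor>(out.impl());
int64_t nnz = coo->nnz();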
@@ -9,7 +9,7 @@ add_subdirectory(funcs)
# pten depends all pten kernel targets
set_property(GLOBAL PROPERTY PTEN_KERNELS "")
set(COMMON_KERNEL_DEPS dense_tensor kernel_context kernel_factory arg_map_context convert_utils lod_utils)
set(COMMON_KERNEL_DEPS dense_tensor sparse_coo_tensor sparse_csr_tensor kernel_context kernel_factory arg_map_context convert_utils lod_utils)
set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} eigen_function blas math_function)
# remove this dep after removing fluid deps on tensor creation
set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} pten_api_utils)
@@ -23,5 +23,6 @@ endif()
# auto build kernel targets by cmake
register_kernels(EXCLUDES math_kernel DEPS ${COMMON_KERNEL_DEPS})
kernel_library(math_kernel DEPS ${MATH_KERNEL_DEPS})
add_subdirectory(sparse)
copy_if_different(${kernel_declare_file} ${kernel_declare_file_final})
set(SPARSE_KERNEL_DEPS dense_tensor sparse_coo_tensor sparse_csr_tensor kernel_context kernel_factory arg_map_context convert_utils lod_utils)
register_kernels(DEPS ${SPARSE_KERNEL_DEPS} SUB_DIR "sparse_kernel")
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/pten/kernels/sparse/sparse_utils_kernel.h"
#include "paddle/pten/api/lib/utils/allocator.h"
#include "paddle/pten/backends/gpu/gpu_context.h"
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/core/tensor_meta.h"
namespace pten {
namespace sparse {
template <typename T>
inline bool IsZero(const T* data, const size_t n) {
const T zero = static_cast<T>(0);
for (size_t i = 0; i < n; i++) {
if (data[i] != zero) {
return false;
}
}
return true;
}
// TODO(zhangkaihuo): implement a kernel to count the number of non-zero
// elements in tensor
template <typename T>
inline int64_t GetNonZeroNum(const DenseTensor& dense,
const int64_t sparse_dim) {
const auto& dims = dense.dims();
PADDLE_ENFORCE_GE(
dims.size(),
sparse_dim,
paddle::platform::errors::InvalidArgument(
"sparse_dim(%d) should be less than or equal to dense.dim(%d)",
sparse_dim,
dims.size()));
auto dims_2d = flatten_to_2d(dims, sparse_dim);
const int rows = dims_2d[0];
const int cols = dims_2d[1];
const T* data = dense.data<T>();
int64_t non_zero_num = 0;
for (int64_t i = 0; i < rows; i++) {
if (!IsZero(data + i * cols, cols)) {
non_zero_num = non_zero_num + 1;
}
}
return non_zero_num;
}
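// Illustrative note (not part of this commit): flatten_to_2d collapses the
// first sparse_dim dims into rows and the remaining dims into cols. For a
// {2, 3, 3} input with sparse_dim = 2 the scan above therefore sees
// rows = 2 * 3 = 6 and cols = 3, and a "non-zero element" is any length-cols
// row slice containing at least one non-zero value; with sparse_dim = 3 it
// sees rows = 18, cols = 1, i.e. every scalar is tested individually.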
template <typename T, typename Context>
void DenseToSparseCooKernel(const Context& dev_ctx,
const DenseTensor& x,
const int64_t sparse_dim,
SparseCooTensor* out) {
const T* x_data = x.data<T>();
const auto& x_dims = x.dims();
int64_t non_zero_num = GetNonZeroNum<T>(x, sparse_dim);
const auto place = dev_ctx.GetPlace();
const auto values_dims = InferDenseDims(x_dims, sparse_dim, non_zero_num);
DenseTensorMeta indices_meta(DataType::INT64,
{sparse_dim, static_cast<int64_t>(non_zero_num)},
DataLayout::NCHW);
DenseTensorMeta values_meta(x.meta().dtype, values_dims, x.meta().layout);
pten::DenseTensor indices =
pten::Empty<int64_t, Context>(dev_ctx, std::move(indices_meta));
pten::DenseTensor values =
pten::Empty<T, Context>(dev_ctx, std::move(values_meta));
int64_t* indices_data = indices.mutable_data<int64_t>(place);
T* values_data = values.mutable_data<T>(place);
auto dims_2d = flatten_to_2d(x_dims, sparse_dim);
const int rows = dims_2d[0];
const int cols = dims_2d[1];
int index = 0;
for (int i = 0; i < rows; i++) {
if (!IsZero(x_data + i * cols, cols)) {
int64_t sparse_index = i;
for (int64_t j = sparse_dim - 1; j >= 0; j--) {
indices_data[j * non_zero_num + index] = sparse_index % x_dims[j];
sparse_index /= x_dims[j];
}
memcpy(values_data + index * cols, x_data + i * cols, cols * sizeof(T));
++index;
}
}
out->SetMember(indices, values, x_dims, true);
}
} // namespace sparse
} // namespace pten
PT_REGISTER_KERNEL(dense_to_sparse_coo,
CPU,
ALL_LAYOUT,
pten::sparse::DenseToSparseCooKernel,
float,
double,
paddle::float16,
uint8_t,
int8_t,
int16_t,
int,
int64_t) {}
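The indices loop in the CPU kernel above is ordinary mixed-radix decoding of a flattened row index (last dimension fastest). A small standalone sketch, using the {2, 3, 3} shape from the to_sparse_coo_batch test later in this diff; DecodeIndex is a hypothetical helper written for illustration, not part of this commit:

#include <cstdint>
#include <vector>

// Decode a flattened row index into sparse_dim coordinates.
// E.g. for dims = {2, 3, 3}, flat index 15 decodes to {1, 2, 0}, which is the
// position of the value 4.0 in the batch test.
std::vector<int64_t> DecodeIndex(int64_t flat, const std::vector<int64_t>& dims) {
  std::vector<int64_t> coords(dims.size());
  for (int64_t j = static_cast<int64_t>(dims.size()) - 1; j >= 0; --j) {
    coords[j] = flat % dims[j];
    flat /= dims[j];
  }
  return coords;
}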
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <thrust/execution_policy.h>
#include <thrust/remove.h>
#include "paddle/pten/backends/gpu/gpu_context.h"
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/core/tensor_meta.h"
#include "paddle/pten/kernels/sparse/sparse_utils_kernel.h"
namespace pten {
namespace sparse {
template <typename T>
inline __device__ bool DevIsZero(const T* data, const int64_t cols) {
const T zero = static_cast<T>(0);
// TODO(zhangkaihuo): check whether the data is zero in parallel when cols > 1
for (int64_t i = 0; i < cols; i++) {
if (data[i] != zero) {
return false;
}
}
return true;
}
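// (Descriptive note on the kernel below.) Each thread walks rows with a
// grid-stride loop; a block-local __shared__ counter accumulates hits via
// atomicAdd, and thread 0 of each block adds the block total to the global
// count. Rows containing a non-zero store their row id in temp_indexs and
// zero rows store -1, so the kept row ids can be compacted afterwards with
// thrust::remove.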
template <typename T>
__global__ void GetNonZeroNums(const T* dense_data,
const int rows,
const int cols,
int* non_zero_num,
int* temp_indexs) {
int tid = threadIdx.x + blockIdx.x * blockDim.x;
__shared__ int counter;
if (threadIdx.x == 0) counter = 0;
__syncthreads();
for (int i = tid; i < rows; i += gridDim.x * blockDim.x) {
int index = -1;
// TODO(zhangkaihuo): when cols=1, vectorization can be used
if (!DevIsZero(dense_data + i * cols, cols)) {
// use reductions?
atomicAdd(&counter, 1);
index = i;
}
temp_indexs[i] = index;
}
__syncthreads();
if (threadIdx.x == 0) {
atomicAdd(non_zero_num, counter);
}
}
template <typename T>
__global__ void GetNonZeroElementsAndIndices(const T* dense_data,
const int64_t sparse_dim,
const int64_t cols,
const int64_t* x_dims,
const int non_zero_num,
const int* indexs,
int64_t* indices,
T* sparse_data) {
int tid = threadIdx.x + blockIdx.x * blockDim.x;
for (int i = tid; i < non_zero_num; i += gridDim.x * blockDim.x) {
int64_t sparse_index = indexs[i];
int64_t x_index = sparse_index;
for (int64_t j = sparse_dim - 1; j >= 0; j--) {
indices[j * non_zero_num + i] = sparse_index % x_dims[j];
sparse_index /= x_dims[j];
}
for (int j = 0; j < cols; j++) {
sparse_data[i * cols + j] = dense_data[x_index * cols + j];
}
}
}
template <typename Context>
void GetGpuLaunchConfig1D(const Context& dev_ctx,
const int64_t n,
int* grid_size,
int* block_size) {
const int MAX_BLOCK_DIM = dev_ctx.GetMaxThreadsPerBlock();
const int MAX_GRID_DIM = dev_ctx.GetMaxPhysicalThreadCount() / MAX_BLOCK_DIM;
*block_size = (n >= MAX_BLOCK_DIM) ? MAX_BLOCK_DIM
: (1 << static_cast<int>(std::log2(n)));
*grid_size = n / *block_size;
*grid_size = (*grid_size >= MAX_GRID_DIM) ? MAX_GRID_DIM : *grid_size;
}
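// Worked example of the launch-config arithmetic above (illustrative, assuming
// a device reporting GetMaxThreadsPerBlock() == 1024 and
// GetMaxPhysicalThreadCount() == 65536, so MAX_BLOCK_DIM = 1024 and
// MAX_GRID_DIM = 64):
//   n = 5000    : block_size = 1024, grid_size = 5000 / 1024 = 4
//   n = 100     : block_size = 1 << (int)std::log2(100) = 64, grid_size = 1
//   n = 10000000: block_size = 1024, grid_size capped at MAX_GRID_DIM = 64
// Under-sized grids are fine here because the kernels above use grid-stride
// loops.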
template <typename T, typename Context>
void DenseToSparseCooKernel(const Context& dev_ctx,
const DenseTensor& x,
const int64_t sparse_dim,
SparseCooTensor* out) {
const T* x_data = x.data<T>();
const auto& x_dims = x.dims();
auto dims_2d = flatten_to_2d(x_dims, sparse_dim);
const int rows = dims_2d[0];
const int cols = dims_2d[1];
auto nums_meta =
pten::DenseTensorMeta(DataType::INT32, {1}, pten::DataLayout::NCHW);
DenseTensor nums =
pten::Empty<int64_t, Context>(dev_ctx, std::move(nums_meta));
auto x_dims_meta =
pten::DenseTensorMeta(DataType::INT64,
{static_cast<int64_t>(x_dims.size())},
pten::DataLayout::NCHW);
DenseTensor d_x_dims =
pten::Empty<T, Context>(dev_ctx, std::move(x_dims_meta));
const auto place = dev_ctx.GetPlace();
// 1. get the number of non-zero elements and record the index of each non-zero row
int* nums_ptr = nums.mutable_data<int>(place);
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_GPU_SUCCESS(
hipMemsetAsync(nums_ptr, 0, sizeof(int), dev_ctx.stream()));
#else
PADDLE_ENFORCE_GPU_SUCCESS(
cudaMemsetAsync(nums_ptr, 0, sizeof(int), dev_ctx.stream()));
#endif
int grid_size = 1, block_size = 1;
GetGpuLaunchConfig1D(dev_ctx, rows, &grid_size, &block_size);
auto temp_indexs_meta =
pten::DenseTensorMeta(DataType::INT32, {rows}, pten::DataLayout::NCHW);
DenseTensor temp_indexs =
pten::Empty<T, Context>(dev_ctx, std::move(temp_indexs_meta));
int* temp_indexs_ptr = temp_indexs.mutable_data<int>(place);
GetNonZeroNums<<<grid_size, block_size, 0, dev_ctx.stream()>>>(
x_data, rows, cols, nums_ptr, temp_indexs_ptr);
#ifdef PADDLE_WITH_HIP
thrust::remove(thrust::hip::par.on(dev_ctx.stream()),
#else
thrust::remove(thrust::cuda::par.on(dev_ctx.stream()),
#endif
temp_indexs_ptr,
temp_indexs_ptr + rows,
-1);
// 2. copy non_zero_num to host, copy x_dims to device
int non_zero_num = 0;
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_GPU_SUCCESS(hipMemcpyAsync(&non_zero_num,
nums_ptr,
sizeof(int),
hipMemcpyDeviceToHost,
dev_ctx.stream()));
#else
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpyAsync(&non_zero_num,
nums_ptr,
sizeof(int),
cudaMemcpyDeviceToHost,
dev_ctx.stream()));
#endif
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_GPU_SUCCESS(
hipMemcpyAsync(d_x_dims.mutable_data<int64_t>(place),
x_dims.Get(),
x_dims.size() * sizeof(x_dims[0]),
hipMemcpyHostToDevice,
dev_ctx.stream()));
#else
PADDLE_ENFORCE_GPU_SUCCESS(
cudaMemcpyAsync(d_x_dims.mutable_data<int64_t>(place),
x_dims.Get(),
x_dims.size() * sizeof(x_dims[0]),
cudaMemcpyHostToDevice,
dev_ctx.stream()));
#endif
dev_ctx.Wait();  // wait for the async copies to finish
const auto values_dims = InferDenseDims(x_dims, sparse_dim, non_zero_num);
DenseTensorMeta indices_meta(DataType::INT64,
{sparse_dim, static_cast<int64_t>(non_zero_num)},
DataLayout::NCHW);
DenseTensorMeta values_meta(x.meta().dtype, values_dims, x.meta().layout);
pten::DenseTensor indices(
pten::make_intrusive<paddle::experimental::SharedStorage>(
dev_ctx.GetPlace()),
std::move(indices_meta));
pten::DenseTensor values(
pten::make_intrusive<paddle::experimental::SharedStorage>(
dev_ctx.GetPlace()),
std::move(values_meta));
int64_t* indices_data = indices.mutable_data<int64_t>(place);
T* sparse_data = values.mutable_data<T>(place);
// 3. compute the sparse indices from the kept indexes and gather the corresponding values
GetGpuLaunchConfig1D(dev_ctx, non_zero_num, &grid_size, &block_size);
GetNonZeroElementsAndIndices<<<grid_size, block_size, 0, dev_ctx.stream()>>>(
x_data,
sparse_dim,
cols,
d_x_dims.data<int64_t>(),
non_zero_num,
temp_indexs_ptr,
indices_data,
sparse_data);
out->SetMember(indices, values, x_dims, true);
}
} // namespace sparse
} // namespace pten
PT_REGISTER_KERNEL(dense_to_sparse_coo,
GPU,
ALL_LAYOUT,
pten::sparse::DenseToSparseCooKernel,
float,
double,
pten::dtype::float16,
uint8_t,
int8_t,
int16_t,
int,
int64_t) {}
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/sparse_coo_tensor.h"
#include "paddle/pten/core/sparse_csr_tensor.h"
#include "paddle/pten/kernels/empty_kernel.h"
namespace pten {
namespace sparse {
inline const DDim InferDenseDims(const DDim& x_dims,
const int64_t sparse_dim,
const int64_t non_zero_num) {
auto dense_dim = x_dims.size() - sparse_dim;
DDim values_dims;
if (dense_dim) {
std::vector<int64_t> dense_dim_vec(dense_dim + 1);
dense_dim_vec[0] = non_zero_num;
memcpy(&dense_dim_vec[1],
x_dims.Get() + sparse_dim,
dense_dim * sizeof(x_dims[0]));
values_dims = pten::framework::make_ddim(dense_dim_vec);
} else {
values_dims = pten::framework::make_ddim({non_zero_num});
}
return values_dims;
}
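// Worked examples (illustrative, matching the unit tests added later in this
// diff):
//   x_dims = {3, 3}, sparse_dim = 2, non_zero_num = 4:
//     dense_dim = 0, values_dims = {4}      (each non-zero element is a scalar)
//   x_dims = {3, 3}, sparse_dim = 1, non_zero_num = 2:
//     dense_dim = 1, values_dims = {2, 3}   (each non-zero element is a whole
//                                            length-3 row, the "hybrid" case)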
template <typename T, typename Context>
void DenseToSparseCooKernel(const Context& dev_ctx,
const DenseTensor& x,
const int64_t sparse_dim,
SparseCooTensor* out);
template <typename T, typename Context>
SparseCooTensor DenseToSparseCoo(const Context& dev_ctx,
const DenseTensor& x,
const int64_t sparse_dim) {
DenseTensor indices = pten::Empty<T, Context>(dev_ctx);
DenseTensor values = pten::Empty<T, Context>(dev_ctx);
SparseCooTensor coo(indices, values, x.dims());
DenseToSparseCooKernel<T, Context>(dev_ctx, x, sparse_dim, &coo);
return coo;
}
} // namespace sparse
} // namespace pten
@@ -22,3 +22,4 @@ cc_test(test_scale_api SRCS test_scale_api.cc DEPS pten_tensor pten_api pten_api
cc_test(test_scale_benchmark SRCS test_scale_benchmark.cc DEPS pten_tensor pten_api pten_api_utils)
cc_test(test_conj_api SRCS test_conj_api.cc DEPS pten_tensor pten_api pten_api_utils)
cc_test(test_concat_api SRCS test_concat_api.cc DEPS pten_tensor pten_api pten_api_utils)
cc_test(test_sparse_utils_api SRCS test_sparse_utils_api.cc DEPS pten_tensor pten_api pten_api_utils)
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include <memory>
#include "paddle/pten/api/include/api.h"
#include "paddle/pten/api/include/sparse_api.h"
#include "paddle/pten/api/lib/utils/allocator.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/core/sparse_coo_tensor.h"
TEST(API, to_sparse_coo) {
const auto alloc = std::make_shared<paddle::experimental::DefaultAllocator>(
paddle::platform::CPUPlace());
auto dense_x = std::make_shared<pten::DenseTensor>(
alloc.get(),
pten::DenseTensorMeta(pten::DataType::FLOAT32,
pten::framework::make_ddim({3, 3}),
pten::DataLayout::NCHW));
pten::CPUPlace cpu;
const int64_t sparse_dim = 2;
auto* dense_x_data = dense_x->mutable_data<float>(cpu);
float dense_data[3][3] = {{0.0, 1.0, 0.0}, {2.0, 0.0, 3.0}, {3.2, 0.0, 0.0}};
std::vector<float> non_zero_data = {1.0, 2.0, 3.0, 3.2};
std::vector<int64_t> indices_data = {0, 1, 1, 2, 1, 0, 2, 0};
std::vector<int64_t> cols_data = {1, 0, 2, 0};
std::vector<int64_t> crows_data = {0, 1, 3, 4};
const int64_t non_zero_num = 4;
std::copy(&dense_data[0][0], &dense_data[0][0] + 9, dense_x_data);
pten::CPUContext dev_ctx_cpu;
// 1. test dense_to_sparse_coo
paddle::experimental::Tensor x(dense_x);
auto out = paddle::experimental::sparse::to_sparse_coo(
x, pten::Backend::CPU, sparse_dim);
auto coo = std::dynamic_pointer_cast<pten::SparseCooTensor>(out.impl());
ASSERT_EQ(coo->nnz(), non_zero_num);
int cmp_indices = memcmp(coo->non_zero_indices().data<int64_t>(),
indices_data.data(),
indices_data.size() * sizeof(int64_t));
ASSERT_EQ(cmp_indices, 0);
int cmp_elements = memcmp(coo->non_zero_elements().data<float>(),
non_zero_data.data(),
non_zero_data.size() * sizeof(float));
ASSERT_EQ(cmp_elements, 0);
}
@@ -11,3 +11,4 @@ cc_test(test_reshape_dev_api SRCS test_reshape_dev_api.cc DEPS pten pten_api_uti
cc_test(test_sum_dev_api SRCS test_sum_dev_api.cc DEPS pten pten_api_utils)
cc_test(test_conj_dev_api SRCS test_conj_dev_api.cc DEPS pten pten_api_utils)
cc_test(test_concat_dev_api SRCS test_concat_dev_api.cc DEPS pten pten_api_utils)
cc_test(test_sparse_utils_dev_api SRCS test_sparse_utils_dev_api.cc DEPS pten pten_api_utils)
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include <memory>
#include "paddle/pten/kernels/copy_kernel.h"
#include "paddle/pten/kernels/sparse/sparse_utils_kernel.h"
#include "paddle/pten/api/lib/utils/allocator.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/kernel_registry.h"
namespace pten {
namespace tests {
template <typename ValueT, typename IndicesT>
inline void CheckResult(
const DeviceContext* dev_ctx,
const SparseCooTensor& coo,
const std::vector<ValueT> non_zero_elements,
const std::vector<IndicesT>& non_zero_indices,
const int64_t non_zero_num,
const std::shared_ptr<paddle::experimental::DefaultAllocator>& alloc) {
const DenseTensor real_indices = coo.non_zero_indices();
const DenseTensor real_elements = coo.non_zero_elements();
ASSERT_EQ(coo.nnz(), non_zero_num);
#if defined(PADDLE_WITH_CUDA)
if (coo.place() == paddle::platform::CUDAPlace()) {
const auto* dev_ctx_cuda =
static_cast<const paddle::platform::CUDADeviceContext*>(dev_ctx);
DenseTensor indices(
alloc.get(),
DenseTensorMeta(
DataType::INT64, real_indices.dims(), real_indices.layout()));
DenseTensor elements(alloc.get(),
DenseTensorMeta(real_elements.dtype(),
real_elements.dims(),
real_elements.layout()));
pten::Copy(*dev_ctx_cuda, real_indices, true, &indices);
pten::Copy(*dev_ctx_cuda, real_elements, true, &elements);
int cmp_indices = memcmp(indices.data<IndicesT>(),
non_zero_indices.data(),
non_zero_indices.size() * sizeof(IndicesT));
ASSERT_EQ(cmp_indices, 0);
int cmp_elements = memcmp(elements.data<ValueT>(),
non_zero_elements.data(),
non_zero_elements.size() * sizeof(ValueT));
ASSERT_EQ(cmp_elements, 0);
} else {
#endif
int cmp_indices = memcmp(real_indices.data<IndicesT>(),
non_zero_indices.data(),
non_zero_indices.size() * sizeof(IndicesT));
ASSERT_EQ(cmp_indices, 0);
int cmp_elements = memcmp(real_elements.data<ValueT>(),
non_zero_elements.data(),
non_zero_elements.size() * sizeof(ValueT));
ASSERT_EQ(cmp_elements, 0);
#if defined(PADDLE_WITH_CUDA)
}
#endif
}
template <typename T>
void TestDenseToSparseCoo(const DenseTensor& dense_x,
const int64_t sparse_dim,
const std::vector<T>& non_zero_data,
const std::vector<int64_t>& indices_data,
const int64_t non_zero_num) {
const auto alloc = std::make_shared<paddle::experimental::DefaultAllocator>(
paddle::platform::CPUPlace());
pten::CPUContext dev_ctx_cpu;
// 1. test cpu
auto cpu_sparse_out =
sparse::DenseToSparseCoo<T>(dev_ctx_cpu, dense_x, sparse_dim);
CheckResult<T, int64_t>(&dev_ctx_cpu,
cpu_sparse_out,
non_zero_data,
indices_data,
non_zero_num,
alloc);
// 2. test cuda
#if defined(PADDLE_WITH_CUDA)
paddle::platform::DeviceContextPool& pool =
paddle::platform::DeviceContextPool::Instance();
auto* dev_ctx_cuda = pool.GetByPlace(paddle::platform::CUDAPlace());
const auto cuda_alloc =
std::make_shared<paddle::experimental::DefaultAllocator>(
paddle::platform::CUDAPlace());
DenseTensor d_dense_x(
cuda_alloc.get(),
DenseTensorMeta(dense_x.dtype(), dense_x.dims(), dense_x.layout()));
pten::Copy(*dev_ctx_cuda, dense_x, true, &d_dense_x);
auto sparse_out =
sparse::DenseToSparseCoo<T>(*dev_ctx_cuda, d_dense_x, sparse_dim);
CheckResult<T, int64_t>(dev_ctx_cuda,
sparse_out,
non_zero_data,
indices_data,
non_zero_num,
alloc);
#endif
}
TEST(DEV_API, to_sparse_coo) {
const auto alloc = std::make_shared<paddle::experimental::DefaultAllocator>(
paddle::platform::CPUPlace());
std::default_random_engine random(time(NULL));
std::uniform_real_distribution<float> dis(0.0, 1.0);
std::uniform_int_distribution<int> dis_int(4, 64);
const int rows = dis_int(random), cols = dis_int(random);
DenseTensor dense_x(
alloc.get(),
DenseTensorMeta(DataType::FLOAT32, {rows, cols}, DataLayout::NCHW));
pten::CPUPlace cpu;
auto* dense_x_data = dense_x.mutable_data<float>(cpu);
std::vector<float> dense_data(rows * cols);
std::vector<float> non_zero_data;
std::vector<int64_t> rows_data, cols_data;
const int64_t sparse_dim = 2;
const float zero_rate = dis(random);
int64_t non_zero_num = 0;
for (int i = 0; i < rows; i++) {
for (int j = 0; j < cols; j++) {
bool iszero = dis(random) < zero_rate;
if (iszero) {
dense_data[i * cols + j] = 0.0;
} else {
float data = dis(random);
dense_data[i * cols + j] = data;
non_zero_data.push_back(data);
rows_data.push_back(i);
cols_data.push_back(j);
non_zero_num += 1;
}
}
}
std::copy(
dense_data.data(), dense_data.data() + dense_data.size(), dense_x_data);
std::vector<int64_t> indices_data(non_zero_num * 2);
memcpy(&indices_data[0], &rows_data[0], non_zero_num * sizeof(int64_t));
memcpy(&indices_data[non_zero_num],
&cols_data[0],
non_zero_num * sizeof(int64_t));
TestDenseToSparseCoo(
dense_x, sparse_dim, non_zero_data, indices_data, non_zero_num);
}
TEST(DEV_API, to_sparse_coo_hybird) {
const auto alloc = std::make_shared<paddle::experimental::DefaultAllocator>(
paddle::platform::CPUPlace());
DenseTensor dense_x(
alloc.get(),
DenseTensorMeta(DataType::FLOAT32, {3, 3}, DataLayout::NCHW));
pten::CPUPlace cpu;
const int64_t sparse_dim = 1;  // each non-zero element is a vector
auto* dense_x_data = dense_x.mutable_data<float>(cpu);
float dense_data[3][3] = {{0.0, 1.0, 0.0}, {0.0, 0.0, 0.0}, {3.2, 0.0, 0.0}};
std::vector<float> non_zero_data = {
/*element0(*/ 0.0, 1.0, 0.0 /*)*/, /*element1(*/ 3.2, 0.0, 0.0 /*)*/};
std::vector<int64_t> indices_data = {0, 2};
const int64_t non_zero_num = 2;
std::copy(&dense_data[0][0], &dense_data[0][0] + 9, dense_x_data);
TestDenseToSparseCoo(
dense_x, sparse_dim, non_zero_data, indices_data, non_zero_num);
}
TEST(DEV_API, to_sparse_coo_fp16) {
const auto alloc = std::make_shared<paddle::experimental::DefaultAllocator>(
paddle::platform::CPUPlace());
DenseTensor dense_x(
alloc.get(),
DenseTensorMeta(DataType::FLOAT16, {3, 3}, DataLayout::NCHW));
pten::CPUPlace cpu;
const int64_t sparse_dim = 2;
const int64_t non_zero_num = 2;
auto* dense_x_data = dense_x.mutable_data<pten::dtype::float16>(cpu);
float dense_data[3][3] = {{0.0, 1.0, 0.0}, {0.0, 0.0, 0.0}, {3.2, 0.0, 0.0}};
std::vector<float> data = {1.0, 3.2};
std::vector<pten::dtype::float16> non_zero_data(non_zero_num);
for (int i = 0; i < non_zero_num; i++) {
non_zero_data[i] = static_cast<pten::dtype::float16>(data[i]);
}
std::vector<int64_t> indices_data = {0, 2, 1, 0};
std::copy(&dense_data[0][0], &dense_data[0][0] + 9, dense_x_data);
TestDenseToSparseCoo<paddle::float16>(
dense_x, sparse_dim, non_zero_data, indices_data, non_zero_num);
}
TEST(DEV_API, to_sparse_coo_batch) {
const auto alloc = std::make_shared<paddle::experimental::DefaultAllocator>(
paddle::platform::CPUPlace());
DenseTensor dense_x(
alloc.get(),
DenseTensorMeta(DataType::FLOAT32, {2, 3, 3}, DataLayout::NCHW));
pten::CPUPlace cpu;
const int64_t sparse_dim = 3;
const int64_t non_zero_num = 4;
auto* dense_x_data = dense_x.mutable_data<float>(cpu);
float dense_data[2][3][3] = {
{{0.0, 1.0, 0.0}, {0.0, 0.0, 0.0}, {2.0, 0.0, 0.0}},
{{0.0, 0.0, 0.0}, {0.0, 3.0, 0.0}, {4.0, 0.0, 0.0}}};
std::vector<float> non_zero_data = {1.0, 2.0, 3.0, 4.0};
std::vector<int64_t> indices_data = {0, 0, 1, 1, 0, 2, 1, 2, 1, 0, 1, 0};
/*
0, 0, 1, 1,
0, 2, 1, 2,
1, 0, 1, 0
*/
std::copy(&dense_data[0][0][0], &dense_data[0][0][0] + 18, dense_x_data);
TestDenseToSparseCoo<float>(
dense_x, sparse_dim, non_zero_data, indices_data, non_zero_num);
}
} // namespace tests
} // namespace pten