Unverified · Commit 76d527e1 · Authored by: Z zhangkaihuo · Committed by: GitHub

Add a Sparse Op: to_sparse_csr (#39333)

* implement AllocateFrom

* dense_to_sparse_coo

* optimize unit testing; support rocm

* 1. delete fluid related header file
2. update the copyright

* fix hipMemcpy

* update dense_to_sparsecoo

* add namespace sparse

* sparse_csr_to_dense

* test to_sparse_coo: csr_to_coo

* fix writing error

* to_sparse_csr: dense_to_sparse_csr and sparse_coo_to_csr

* fix check shape

* fix unit test

* replace CUDADeviceContext by GPUContext
Parent e606b44a
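
For reference, a minimal usage sketch of the new C++ API, pieced together from the unit test added below; dense_x stands for the test's 3x3 FLOAT32 DenseTensor and is illustrative only:

    // Convert a dense (or SPARSE_COO) Tensor to CSR on the CPU backend.
    // to_sparse_csr returns the input unchanged if it is already SPARSE_CSR,
    // otherwise it dispatches to dense_to_sparse_csr or sparse_coo_to_csr.
    paddle::experimental::Tensor x(dense_x);
    auto out = paddle::experimental::sparse::to_sparse_csr(x, pten::Backend::CPU);
    auto csr = std::dynamic_pointer_cast<pten::SparseCsrTensor>(out.impl());
    // csr->non_zero_crows(), csr->non_zero_cols() and csr->non_zero_elements()
    // now hold the CSR representation.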
@@ -25,6 +25,8 @@ PADDLE_API Tensor to_sparse_coo(const Tensor& x,
Backend backend,
const int64_t sparse_dim);
PADDLE_API Tensor to_sparse_csr(const Tensor& x, Backend backend);
} // namespace sparse
} // namespace experimental
} // namespace paddle
@@ -23,9 +23,15 @@ limitations under the License. */
#include "paddle/pten/infermeta/unary.h"
PT_DECLARE_KERNEL(dense_to_sparse_coo, CPU, ALL_LAYOUT);
PT_DECLARE_KERNEL(sparse_csr_to_coo, CPU, ALL_LAYOUT);
PT_DECLARE_KERNEL(dense_to_sparse_csr, CPU, ALL_LAYOUT);
PT_DECLARE_KERNEL(sparse_coo_to_csr, CPU, ALL_LAYOUT);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PT_DECLARE_KERNEL(dense_to_sparse_coo, GPU, ALL_LAYOUT);
PT_DECLARE_KERNEL(sparse_csr_to_coo, GPU, ALL_LAYOUT);
PT_DECLARE_KERNEL(dense_to_sparse_csr, GPU, ALL_LAYOUT);
PT_DECLARE_KERNEL(sparse_coo_to_csr, GPU, ALL_LAYOUT);
#endif
namespace paddle {
@@ -95,6 +101,71 @@ PADDLE_API Tensor to_sparse_coo(const Tensor& x,
return out;
}
PADDLE_API Tensor to_sparse_csr(const Tensor& x, Backend backend) {
if (x.layout() == pten::DataLayout::SPARSE_CSR) {
return x;
}
// 1. Get kernel signature and kernel
auto kernel_key_set = ParseKernelKeyByInputArgs(x);
kernel_key_set.backend_set = kernel_key_set.backend_set | BackendSet(backend);
auto kernel_key = kernel_key_set.GetHigestPriorityKernelKey();
std::string kernel_name = "dense_to_sparse_csr";
if (x.layout() == pten::DataLayout::SPARSE_COO) {
kernel_name = "sparse_coo_to_csr";
}
auto kernel = pten::KernelFactory::Instance().SelectKernelOrThrowError(
kernel_name, kernel_key);
VLOG(6) << "to API kernel key: " << kernel_key;
VLOG(6) << "to API kernel: " << kernel;
// 2. Get Device Context
auto* dev_ctx = GetDeviceContextByBackend(kernel_key.backend());
auto kernel_context = pten::KernelContext(dev_ctx);
// 3. Auto data transform
if (x.layout() == pten::DataLayout::SPARSE_COO) {
auto input = std::dynamic_pointer_cast<pten::SparseCooTensor>(x.impl());
kernel_context.EmplaceBackInput(input.get());
} else {
auto input = std::dynamic_pointer_cast<pten::DenseTensor>(x.impl());
kernel_context.EmplaceBackInput(input.get());
}
// 4. InferMeta
auto crows_meta = pten::DenseTensorMeta(
pten::DataType::INT64, {-1}, pten::DataLayout::NCHW);
auto cols_meta = pten::DenseTensorMeta(
pten::DataType::INT64, {-1}, pten::DataLayout::NCHW);
auto elements_meta = pten::DenseTensorMeta(x.dtype(), {-1}, x.layout());
// 5. Prepare outputs
// create empty SparseCsrTensor
pten::DenseTensor non_zero_crows(
pten::make_intrusive<paddle::experimental::SharedStorage>(
pten::TransToFluidPlace(backend)),
std::move(crows_meta));
pten::DenseTensor non_zero_cols(
pten::make_intrusive<paddle::experimental::SharedStorage>(
pten::TransToFluidPlace(backend)),
std::move(cols_meta));
pten::DenseTensor non_zero_elements(
pten::make_intrusive<paddle::experimental::SharedStorage>(
pten::TransToFluidPlace(backend)),
std::move(elements_meta));
auto csr = std::make_shared<pten::SparseCsrTensor>(
non_zero_crows, non_zero_cols, non_zero_elements, x.dims());
kernel_context.EmplaceBackOutput(csr.get());
Tensor out;
out.set_impl(csr);
// 6. Call kernel
kernel(&kernel_context);
return out;
}
} // namespace sparse
} // namespace experimental
} // namespace paddle
@@ -157,6 +157,91 @@ void SparseCsrToCooKernel(const Context& dev_ctx,
out->SetMember(indices, values, x_dims, true);
}
template <typename T, typename Context>
void SparseCooToCsrKernel(const Context& dev_ctx,
const SparseCooTensor& x,
SparseCsrTensor* out) {
const auto& x_dims = x.dims();
bool valid = x_dims.size() == 2 || x_dims.size() == 3;
PADDLE_ENFORCE_EQ(valid,
true,
paddle::platform::errors::InvalidArgument(
"SparseCsrTensor only support 2-D or 3-D matrix"));
const int64_t non_zero_num = x.nnz();
if (non_zero_num <= 0) return;
int batchs = x_dims.size() == 2 ? 1 : x_dims[0];
int rows = x_dims.size() == 2 ? x_dims[0] : x_dims[1];
const auto place = dev_ctx.GetPlace();
DenseTensorMeta crows_meta(
DataType::INT64, {batchs * (rows + 1)}, DataLayout::NCHW);
DenseTensorMeta cols_meta(DataType::INT64, {non_zero_num}, DataLayout::NCHW);
DenseTensorMeta values_meta(x.dtype(), {non_zero_num}, x.layout());
pten::DenseTensor non_zero_crows(
pten::make_intrusive<paddle::experimental::SharedStorage>(place),
std::move(crows_meta));
pten::DenseTensor non_zero_cols(
pten::make_intrusive<paddle::experimental::SharedStorage>(place),
std::move(cols_meta));
pten::DenseTensor non_zero_elements(
pten::make_intrusive<paddle::experimental::SharedStorage>(place),
std::move(values_meta));
int64_t* csr_crows_data = non_zero_crows.mutable_data<int64_t>(place);
int64_t* csr_cols_data = non_zero_cols.mutable_data<int64_t>(place);
T* csr_values_data = non_zero_elements.mutable_data<T>(place);
const auto& coo_indices = x.non_zero_indices();
const auto& coo_values = x.non_zero_elements();
const int64_t* batchs_ptr = coo_indices.data<int64_t>();
const int64_t* coo_rows_data =
batchs == 1 ? batchs_ptr : batchs_ptr + non_zero_num;
const int64_t* coo_cols_data = coo_rows_data + non_zero_num;
const T* coo_values_data = coo_values.data<T>();
if (!x.coalesced()) {
// TODO(zhangkaihuo): call coalesced() to deduplicate and sort the indices
}
std::vector<int64_t> offsets(batchs, 0);
if (batchs > 1) {
for (int i = 0; i < non_zero_num; i++) {
if (i == non_zero_num - 1 || batchs_ptr[i] != batchs_ptr[i + 1]) {
offsets[batchs_ptr[i]] = i + 1;
}
}
} else {
offsets[0] = non_zero_num;
}
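// Convert each batch's COO row indices into CSR crows. This assumes the
// indices are already ordered by batch and row (see the coalesced() TODO
// above); e.g. with rows = 3, row indices [0, 1, 1, 2] become crows [0, 1, 3, 4].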
for (int b = 0; b < batchs; b++) {
if (offsets[b] == 0) continue;
int batch_start = 0;
int batch_non_zero_num = offsets[b];
if (b > 0) {
batch_start = offsets[b - 1];
batch_non_zero_num -= batch_start;
}
auto* coo_rows_ptr = coo_rows_data + batch_start;
for (int i = 0; i <= coo_rows_ptr[0]; i++) {
csr_crows_data[b * (rows + 1) + i] = 0;
}
for (int64_t i = 1; i < batch_non_zero_num; i++) {
for (int j = coo_rows_ptr[i - 1]; j < coo_rows_ptr[i]; j++) {
csr_crows_data[b * (rows + 1) + j + 1] = i;
}
}
for (int64_t i = coo_rows_ptr[batch_non_zero_num - 1] + 1; i < rows + 1;
i++) {
csr_crows_data[b * (rows + 1) + i] = batch_non_zero_num;
}
}
memcpy(csr_cols_data, coo_cols_data, sizeof(int64_t) * non_zero_num);
memcpy(csr_values_data, coo_values_data, sizeof(T) * non_zero_num);
out->SetMember(non_zero_crows, non_zero_cols, non_zero_elements, x_dims);
}
} // namespace sparse
} // namespace pten
@@ -185,3 +270,29 @@ PT_REGISTER_KERNEL(sparse_csr_to_coo,
int16_t,
int,
int64_t) {}
PT_REGISTER_KERNEL(sparse_coo_to_csr,
CPU,
ALL_LAYOUT,
pten::sparse::SparseCooToCsrKernel,
float,
double,
pten::dtype::float16,
uint8_t,
int8_t,
int16_t,
int,
int64_t) {}
PT_REGISTER_KERNEL(dense_to_sparse_csr,
CPU,
ALL_LAYOUT,
pten::sparse::DenseToSparseCsrKernel,
float,
double,
pten::dtype::float16,
uint8_t,
int8_t,
int16_t,
int,
int64_t) {}
@@ -330,6 +330,143 @@ void SparseCsrToCooKernel(const Context& dev_ctx,
out->SetMember(indices, values, x_dims, true);
}
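// For each batch b, record in batchs_offset[b] the exclusive end offset of that
// batch's entries in the batch-ordered COO indices (one grid-stride loop over
// all non-zero elements).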
__global__ void GetBatchsOffset(const int64_t* batchs_ptr,
const int non_zero_num,
int64_t* batchs_offset) {
int tid = threadIdx.x + blockIdx.x * blockDim.x;
for (int i = tid; i < non_zero_num; i += gridDim.x * blockDim.x) {
if (i == non_zero_num - 1 || batchs_ptr[i] != batchs_ptr[i + 1]) {
batchs_offset[batchs_ptr[i]] = i + 1;
}
}
}
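// Each blockIdx.y handles one batch: threads stride over that batch's sorted COO
// row indices and fill its segment of csr_crows_data, e.g. with rows = 3 the row
// indices [0, 1, 1, 2] produce crows [0, 1, 3, 4].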
__global__ void ConvertCooRowsToCsrCrows(
const int64_t* batchs_offset, // can be null if batchs = 1
const int64_t* coo_rows_data,
int64_t* csr_crows_data,
const int rows,
const int64_t non_zero_num) {
const int b = blockIdx.y;
int batch_non_zero_num =
batchs_offset == nullptr ? non_zero_num : batchs_offset[b];
if (batch_non_zero_num == 0) return;
int batch_start = 0;
if (b > 0) {
batch_start = batchs_offset[b - 1];
batch_non_zero_num -= batch_start;
}
auto* coo_rows_ptr = coo_rows_data + batch_start;
const int tid = threadIdx.x + blockIdx.x * blockDim.x;
for (int i = tid; i < batch_non_zero_num; i += gridDim.x * blockDim.x) {
if (i == 0) {
for (int j = 0; j <= coo_rows_ptr[0]; j++) {
csr_crows_data[b * (rows + 1) + j] = 0;
}
} else {
for (int j = coo_rows_ptr[i - 1]; j < coo_rows_ptr[i]; j++) {
csr_crows_data[b * (rows + 1) + j + 1] = i;
}
}
if (i == batch_non_zero_num - 1) {
for (int64_t i = coo_rows_ptr[batch_non_zero_num - 1] + 1; i < rows + 1;
i++) {
csr_crows_data[b * (rows + 1) + i] = batch_non_zero_num;
}
}
}
}
template <typename T, typename Context>
void SparseCooToCsrKernel(const Context& dev_ctx,
const SparseCooTensor& x,
SparseCsrTensor* out) {
const auto& x_dims = x.dims();
bool valid = x_dims.size() == 2 || x_dims.size() == 3;
PADDLE_ENFORCE_EQ(valid,
true,
paddle::platform::errors::InvalidArgument(
"SparseCsrTensor only support 2-D or 3-D matrix"));
const int64_t non_zero_num = x.nnz();
if (non_zero_num <= 0) return;
int batchs = x_dims.size() == 2 ? 1 : x_dims[0];
int rows = x_dims.size() == 2 ? x_dims[0] : x_dims[1];
const auto place = dev_ctx.GetPlace();
DenseTensorMeta crows_meta(
DataType::INT64, {batchs * (rows + 1)}, DataLayout::NCHW);
DenseTensorMeta cols_meta(DataType::INT64, {non_zero_num}, DataLayout::NCHW);
DenseTensorMeta values_meta(x.dtype(), {non_zero_num}, x.layout());
pten::DenseTensor non_zero_crows(
pten::make_intrusive<paddle::experimental::SharedStorage>(place),
std::move(crows_meta));
pten::DenseTensor non_zero_cols(
pten::make_intrusive<paddle::experimental::SharedStorage>(place),
std::move(cols_meta));
pten::DenseTensor non_zero_elements(
pten::make_intrusive<paddle::experimental::SharedStorage>(place),
std::move(values_meta));
int64_t* csr_crows_data = non_zero_crows.mutable_data<int64_t>(place);
int64_t* csr_cols_data = non_zero_cols.mutable_data<int64_t>(place);
T* csr_values_data = non_zero_elements.mutable_data<T>(place);
const auto& coo_indices = x.non_zero_indices();
const auto& coo_values = x.non_zero_elements();
const int64_t* batchs_ptr = coo_indices.data<int64_t>();
const int64_t* coo_rows_data =
batchs == 1 ? batchs_ptr : batchs_ptr + non_zero_num;
const int64_t* coo_cols_data = coo_rows_data + non_zero_num;
const T* coo_values_data = coo_values.data<T>();
if (!x.coalesced()) {
// TODO(zhangkaihuo): call coalesced() to deduplicate and sort the indices
}
int grid_size = 1, block_size = 1;
GetGpuLaunchConfig1D(dev_ctx, batchs, &grid_size, &block_size);
if (batchs > 1) {
DenseTensorMeta batchs_meta(DataType::INT64, {batchs}, DataLayout::NCHW);
pten::DenseTensor batchs_offset(
pten::make_intrusive<paddle::experimental::SharedStorage>(place),
std::move(batchs_meta));
int64_t* batchs_offset_ptr = batchs_offset.mutable_data<int64_t>(place);
GetBatchsOffset<<<grid_size, block_size, 0, dev_ctx.stream()>>>(
batchs_ptr, non_zero_num, batchs_offset_ptr);
dim3 grids(grid_size, batchs, 1);
ConvertCooRowsToCsrCrows<<<grids, block_size, 0, dev_ctx.stream()>>>(
batchs_offset_ptr, coo_rows_data, csr_crows_data, rows, non_zero_num);
} else {
ConvertCooRowsToCsrCrows<<<grid_size, block_size, 0, dev_ctx.stream()>>>(
nullptr, coo_rows_data, csr_crows_data, rows, non_zero_num);
}
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_GPU_SUCCESS(hipMemcpyAsync(csr_cols_data,
coo_cols_data,
sizeof(int64_t) * non_zero_num,
hipMemcpyDeviceToDevice,
dev_ctx.stream()));
PADDLE_ENFORCE_GPU_SUCCESS(hipMemcpyAsync(csr_values_data,
coo_values_data,
sizeof(T) * non_zero_num,
hipMemcpyDeviceToDevice,
dev_ctx.stream()));
#else
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpyAsync(csr_cols_data,
coo_cols_data,
sizeof(int64_t) * non_zero_num,
cudaMemcpyDeviceToDevice,
dev_ctx.stream()));
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpyAsync(csr_values_data,
coo_values_data,
sizeof(T) * non_zero_num,
cudaMemcpyDeviceToDevice,
dev_ctx.stream()));
#endif
out->SetMember(non_zero_crows, non_zero_cols, non_zero_elements, x_dims);
}
} // namespace sparse
} // namespace pten
@@ -358,3 +495,29 @@ PT_REGISTER_KERNEL(sparse_csr_to_coo,
int16_t,
int,
int64_t) {}
PT_REGISTER_KERNEL(sparse_coo_to_csr,
GPU,
ALL_LAYOUT,
pten::sparse::SparseCooToCsrKernel,
float,
double,
pten::dtype::float16,
uint8_t,
int8_t,
int16_t,
int,
int64_t) {}
PT_REGISTER_KERNEL(dense_to_sparse_csr,
GPU,
ALL_LAYOUT,
pten::sparse::DenseToSparseCsrKernel,
float,
double,
pten::dtype::float16,
uint8_t,
int8_t,
int16_t,
int,
int64_t) {}
@@ -72,5 +72,51 @@ SparseCooTensor SparseCsrToCoo(const Context& dev_ctx,
return coo;
}
template <typename T, typename Context>
void SparseCooToCsrKernel(const Context& dev_ctx,
const SparseCooTensor& x,
SparseCsrTensor* out);
template <typename T, typename Context>
SparseCsrTensor SparseCooToCsr(const Context& dev_ctx,
const SparseCooTensor& x) {
DenseTensor non_zero_crows = pten::Empty<int64_t, Context>(dev_ctx);
DenseTensor non_zero_cols = pten::Empty<int64_t, Context>(dev_ctx);
DenseTensor non_zero_elements = pten::Empty<T, Context>(dev_ctx);
SparseCsrTensor csr(
non_zero_crows, non_zero_cols, non_zero_elements, x.dims());
SparseCooToCsrKernel<T, Context>(dev_ctx, x, &csr);
return csr;
}
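// Dense -> CSR is implemented by composing the existing kernels:
// dense -> COO (DenseToSparseCooKernel), then COO -> CSR (SparseCooToCsrKernel).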
template <typename T, typename Context>
void DenseToSparseCsrKernel(const Context& dev_ctx,
const DenseTensor& x,
SparseCsrTensor* out) {
const auto& x_dims = x.dims();
bool valid = x_dims.size() == 2 || x_dims.size() == 3;
PADDLE_ENFORCE_EQ(valid,
true,
paddle::platform::errors::InvalidArgument(
"SparseCsrTensor only support 2-D or 3-D Tensor."));
const int64_t sparse_dim = x_dims.size() == 2 ? 2 : 3;
DenseTensor indices = pten::Empty<T, Context>(dev_ctx);
DenseTensor values = pten::Empty<T, Context>(dev_ctx);
SparseCooTensor coo(indices, values, x.dims());
DenseToSparseCooKernel<T, Context>(dev_ctx, x, sparse_dim, &coo);
SparseCooToCsrKernel<T, Context>(dev_ctx, coo, out);
}
template <typename T, typename Context>
SparseCsrTensor DenseToSparseCsr(const Context& dev_ctx, const DenseTensor& x) {
DenseTensor non_zero_crows = pten::Empty<int64_t, Context>(dev_ctx);
DenseTensor non_zero_cols = pten::Empty<int64_t, Context>(dev_ctx);
DenseTensor non_zero_elements = pten::Empty<T, Context>(dev_ctx);
SparseCsrTensor csr(
non_zero_crows, non_zero_cols, non_zero_elements, x.dims());
DenseToSparseCsrKernel<T, Context>(dev_ctx, x, &csr);
return csr;
}
} // namespace sparse
} // namespace pten
@@ -8,8 +8,8 @@ You may obtain a copy of the License at
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See
+the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
@@ -103,3 +103,75 @@ TEST(API, to_sparse_coo) {
non_zero_data.size() * sizeof(float));
ASSERT_EQ(cmp_elements2, 0);
}
TEST(API, to_sparse_csr) {
const auto alloc = std::make_shared<paddle::experimental::DefaultAllocator>(
paddle::platform::CPUPlace());
auto dense_x = std::make_shared<pten::DenseTensor>(
alloc.get(),
pten::DenseTensorMeta(pten::DataType::FLOAT32,
pten::framework::make_ddim({3, 3}),
pten::DataLayout::NCHW));
pten::CPUPlace cpu;
const int64_t sparse_dim = 2;
auto* dense_x_data = dense_x->mutable_data<float>(cpu);
float dense_data[3][3] = {{0.0, 1.0, 0.0}, {2.0, 0.0, 3.0}, {3.2, 0.0, 0.0}};
std::vector<float> non_zero_data = {1.0, 2.0, 3.0, 3.2};
std::vector<int64_t> indices_data = {0, 1, 1, 2, 1, 0, 2, 0};
std::vector<int64_t> cols_data = {1, 0, 2, 0};
std::vector<int64_t> crows_data = {0, 1, 3, 4};
const int64_t non_zero_num = 4;
std::copy(&dense_data[0][0], &dense_data[0][0] + 9, dense_x_data);
pten::CPUContext dev_ctx_cpu;
// 1. test dense_to_sparse_csr
paddle::experimental::Tensor x(dense_x);
auto out = paddle::experimental::sparse::to_sparse_csr(x, pten::Backend::CPU);
auto csr = std::dynamic_pointer_cast<pten::SparseCsrTensor>(out.impl());
auto check = [&](const pten::SparseCsrTensor& csr) {
ASSERT_EQ(csr.non_zero_cols().numel(), non_zero_num);
int cmp_crows = memcmp(csr.non_zero_crows().data<int64_t>(),
crows_data.data(),
crows_data.size() * sizeof(int64_t));
ASSERT_EQ(cmp_crows, 0);
int cmp_cols = memcmp(csr.non_zero_cols().data<int64_t>(),
cols_data.data(),
cols_data.size() * sizeof(int64_t));
ASSERT_EQ(cmp_cols, 0);
int cmp_elements = memcmp(csr.non_zero_elements().data<float>(),
non_zero_data.data(),
non_zero_data.size() * sizeof(float));
ASSERT_EQ(cmp_elements, 0);
};
check(*csr);
// 2. test sparse_coo_to_csr
auto dense_dims = pten::framework::make_ddim({3, 3});
pten::DenseTensorMeta indices_meta(pten::DataType::INT64,
{sparse_dim, non_zero_num},
pten::DataLayout::NCHW);
pten::DenseTensorMeta values_meta(
pten::DataType::FLOAT32, {non_zero_num}, pten::DataLayout::NCHW);
pten::CPUPlace place;
pten::DenseTensor indices(alloc.get(), indices_meta);
pten::DenseTensor values(alloc.get(), values_meta);
memcpy(indices.mutable_data<int64_t>(place),
indices_data.data(),
indices_data.size() * sizeof(int64_t));
memcpy(values.mutable_data<float>(place),
non_zero_data.data(),
non_zero_data.size() * sizeof(float));
auto coo =
std::make_shared<pten::SparseCooTensor>(indices, values, dense_dims);
paddle::experimental::Tensor coo_x(coo);
auto out2 =
paddle::experimental::sparse::to_sparse_csr(coo_x, pten::Backend::CPU);
auto csr2 = std::dynamic_pointer_cast<pten::SparseCsrTensor>(out2.impl());
check(*csr2);
}
@@ -43,7 +43,7 @@ inline void CheckResult(
#if defined(PADDLE_WITH_CUDA)
if (coo.place() == pten::GPUPlace()) {
-const auto* dev_ctx_cuda = static_cast<const pten::GPUContext*>(dev_ctx);
+const auto* dev_ctx_gpu = static_cast<const pten::GPUContext*>(dev_ctx);
DenseTensor indices(
alloc.get(),
DenseTensorMeta(
@@ -53,8 +53,8 @@
DenseTensorMeta(real_elements.dtype(),
real_elements.dims(),
real_elements.layout()));
-pten::Copy(*dev_ctx_cuda, real_indices, true, &indices);
-pten::Copy(*dev_ctx_cuda, real_elements, true, &elements);
+pten::Copy(*dev_ctx_gpu, real_indices, true, &indices);
+pten::Copy(*dev_ctx_gpu, real_elements, true, &elements);
int cmp_indices = memcmp(indices.data<IndicesT>(),
non_zero_indices.data(),
@@ -103,9 +103,6 @@ void TestDenseToSparseCoo(const DenseTensor& dense_x,
// 2. test cuda
#if defined(PADDLE_WITH_CUDA)
-// paddle::platform::DeviceContextPool& pool =
-//     paddle::platform::DeviceContextPool::Instance();
-// auto* dev_ctx_cuda = pool.GetByPlace(paddle::platform::CUDAPlace());
pten::GPUContext dev_ctx_gpu;
dev_ctx_gpu.PartialInitWithoutAllocator();
dev_ctx_gpu.SetAllocator(
@@ -327,8 +324,6 @@ void TestSparseCsrToCoo(const DDim& dense_dims,
const auto cuda_alloc =
std::make_shared<paddle::experimental::DefaultAllocator>(
paddle::platform::CUDAPlace());
-// auto& pool = paddle::platform::DeviceContextPool::Instance();
-// auto* dev_ctx_cuda = pool.GetByPlace(paddle::platform::CUDAPlace());
pten::DenseTensor d_crows(cuda_alloc.get(), crows_meta);
pten::DenseTensor d_cols(cuda_alloc.get(), cols_meta);
pten::DenseTensor d_values(cuda_alloc.get(), values_meta);
@@ -382,5 +377,287 @@ TEST(DEV_API, sparse_csr_to_coo_batch_and_fp16) {
non_zero_num);
}
template <typename ValueT, typename IndicesT>
inline void CheckCsrResult(
const DeviceContext* dev_ctx,
const SparseCsrTensor& csr,
const std::vector<ValueT> non_zero_elements,
const std::vector<IndicesT>& non_zero_crows,
const std::vector<IndicesT>& non_zero_cols,
const int64_t non_zero_num,
const std::shared_ptr<paddle::experimental::DefaultAllocator>& alloc) {
const DenseTensor real_crows = csr.non_zero_crows();
const DenseTensor real_cols = csr.non_zero_cols();
const DenseTensor real_elements = csr.non_zero_elements();
ASSERT_EQ(csr.non_zero_cols().numel(), non_zero_num);
#if defined(PADDLE_WITH_CUDA)
if (csr.place() == paddle::platform::CUDAPlace()) {
const auto* dev_ctx_gpu = static_cast<const pten::GPUContext*>(dev_ctx);
DenseTensor crows(
alloc.get(),
DenseTensorMeta(
DataType::INT64, real_crows.dims(), real_crows.layout()));
DenseTensor cols(
alloc.get(),
DenseTensorMeta(DataType::INT64, real_cols.dims(), real_cols.layout()));
DenseTensor elements(alloc.get(),
DenseTensorMeta(real_elements.dtype(),
real_elements.dims(),
real_elements.layout()));
pten::Copy(*dev_ctx_gpu, real_crows, true, &crows);
pten::Copy(*dev_ctx_gpu, real_cols, true, &cols);
pten::Copy(*dev_ctx_gpu, real_elements, true, &elements);
int cmp_crows = memcmp(crows.data<IndicesT>(),
non_zero_crows.data(),
non_zero_crows.size() * sizeof(IndicesT));
ASSERT_EQ(cmp_crows, 0);
int cmp_cols = memcmp(cols.data<IndicesT>(),
non_zero_cols.data(),
non_zero_cols.size() * sizeof(IndicesT));
ASSERT_EQ(cmp_cols, 0);
int cmp_elements = memcmp(elements.data<ValueT>(),
non_zero_elements.data(),
non_zero_elements.size() * sizeof(ValueT));
ASSERT_EQ(cmp_elements, 0);
} else {
#endif
int cmp_crows = memcmp(real_crows.data<IndicesT>(),
non_zero_crows.data(),
non_zero_crows.size() * sizeof(IndicesT));
ASSERT_EQ(cmp_crows, 0);
int cmp_cols = memcmp(real_cols.data<IndicesT>(),
non_zero_cols.data(),
non_zero_cols.size() * sizeof(IndicesT));
ASSERT_EQ(cmp_cols, 0);
int cmp_elements = memcmp(real_elements.data<ValueT>(),
non_zero_elements.data(),
non_zero_elements.size() * sizeof(ValueT));
ASSERT_EQ(cmp_elements, 0);
#if defined(PADDLE_WITH_CUDA)
}
#endif
}
template <typename T>
void TestCooToCsr(const DDim& dense_dims,
const int64_t& non_zero_num,
const std::vector<T>& non_zero_data,
const std::vector<int64_t>& non_zero_indices,
const std::vector<int64_t>& cols_data,
const std::vector<int64_t>& crows_data) {
const auto alloc = std::make_shared<paddle::experimental::DefaultAllocator>(
paddle::platform::CPUPlace());
pten::CPUPlace cpu;
DenseTensorMeta indices_meta(
DataType::INT64,
{static_cast<int64_t>(dense_dims.size()), non_zero_num},
DataLayout::NCHW);
DenseTensor indices(alloc.get(), indices_meta);
DenseTensorMeta values_meta(
paddle::experimental::CppTypeToDataType<T>::Type(),
{non_zero_num},
DataLayout::NCHW);
DenseTensor values(alloc.get(), values_meta);
memcpy(indices.mutable_data<int64_t>(cpu),
non_zero_indices.data(),
non_zero_indices.size() * sizeof(int64_t));
memcpy(values.mutable_data<T>(cpu),
non_zero_data.data(),
non_zero_data.size() * sizeof(T));
pten::SparseCooTensor coo(indices, values, dense_dims);
// 1. test cpu
pten::CPUContext dev_ctx_cpu;
auto cpu_sparse_out = sparse::SparseCooToCsr<T>(dev_ctx_cpu, coo);
CheckCsrResult<T, int64_t>(&dev_ctx_cpu,
cpu_sparse_out,
non_zero_data,
crows_data,
cols_data,
non_zero_num,
alloc);
// 2. test cuda
#if defined(PADDLE_WITH_CUDA)
const auto cuda_alloc =
std::make_shared<paddle::experimental::DefaultAllocator>(
paddle::platform::CUDAPlace());
pten::GPUContext dev_ctx_gpu;
dev_ctx_gpu.PartialInitWithoutAllocator();
dev_ctx_gpu.SetAllocator(
paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(dev_ctx_gpu.GetPlace(), dev_ctx_gpu.stream())
.get());
dev_ctx_gpu.SetHostAllocator(
paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(pten::CPUPlace())
.get());
dev_ctx_gpu.PartialInitWithAllocator();
pten::DenseTensor d_indices(cuda_alloc.get(), indices_meta);
pten::DenseTensor d_values(cuda_alloc.get(), values_meta);
pten::Copy(dev_ctx_gpu, indices, true, &d_indices);
pten::Copy(dev_ctx_gpu, values, true, &d_values);
pten::SparseCooTensor d_coo(d_indices, d_values, dense_dims);
auto cuda_sparse_out = sparse::SparseCooToCsr<T>(dev_ctx_gpu, d_coo);
CheckCsrResult<T, int64_t>(&dev_ctx_gpu,
cuda_sparse_out,
non_zero_data,
crows_data,
cols_data,
non_zero_num,
alloc);
#endif
}
TEST(DEV_API, coo_to_csr) {
// float dense_data[3][3] = {{0.0, 1.0, 0.0}, {2.0, 0.0, 3.0}, {3.2, 0.0,
// 0.0}};
std::vector<float> non_zero_data = {1.0, 2.0, 3.0, 3.2};
std::vector<int64_t> non_zero_indices = {0, 1, 1, 2, 1, 0, 2, 0};
std::vector<int64_t> cols_data = {1, 0, 2, 0};
std::vector<int64_t> crows_data = {0, 1, 3, 4};
const int64_t non_zero_num = 4;
auto dense_dims = pten::framework::make_ddim({3, 3});
TestCooToCsr<float>(dense_dims,
non_zero_num,
non_zero_data,
non_zero_indices,
cols_data,
crows_data);
}
TEST(DEV_API, batch_coo_to_csr) {
// float dense_data[2][3][3] =
// {{{0.0, 1.0, 0.0}, {2.0, 0.0, 3.0}, {3.2, 0.0, 0.0}},
// {{0.0, 1.0, 0.0}, {2.0, 0.0, 3.0}, {0.0, 0.0, 0.0}}};
const int64_t non_zero_num = 7;
std::vector<float> data = {1.0, 2.0, 3.0, 3.2, 1.0, 2.0, 3.0};
std::vector<pten::dtype::float16> non_zero_data(non_zero_num);
for (int64_t i = 0; i < non_zero_num; i++) {
non_zero_data[i] = static_cast<pten::dtype::float16>(data[i]);
}
std::vector<int64_t> non_zero_indices = {0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 2,
0, 1, 1, 1, 0, 2, 0, 1, 0, 2};
std::vector<int64_t> cols_data = {1, 0, 2, 0, 1, 0, 2};
std::vector<int64_t> crows_data = {0, 1, 3, 4, 0, 1, 3, 3};
auto dense_dims = pten::framework::make_ddim({2, 3, 3});
TestCooToCsr<pten::dtype::float16>(dense_dims,
non_zero_num,
non_zero_data,
non_zero_indices,
cols_data,
crows_data);
}
template <typename T>
void TestDenseToSparseCsr(const DenseTensor& dense_x,
const int64_t non_zero_num,
const std::vector<T>& non_zero_data,
const std::vector<int64_t>& crows_data,
const std::vector<int64_t>& cols_data) {
const auto alloc = std::make_shared<paddle::experimental::DefaultAllocator>(
paddle::platform::CPUPlace());
pten::CPUContext dev_ctx_cpu;
// 1. test cpu
auto cpu_sparse_out = sparse::DenseToSparseCsr<T>(dev_ctx_cpu, dense_x);
CheckCsrResult<T, int64_t>(&dev_ctx_cpu,
cpu_sparse_out,
non_zero_data,
crows_data,
cols_data,
non_zero_num,
alloc);
// 2. test cuda
#if defined(PADDLE_WITH_CUDA)
const auto cuda_alloc =
std::make_shared<paddle::experimental::DefaultAllocator>(
paddle::platform::CUDAPlace());
DenseTensor d_dense_x(
cuda_alloc.get(),
DenseTensorMeta(dense_x.dtype(), dense_x.dims(), dense_x.layout()));
pten::GPUContext dev_ctx_gpu;
dev_ctx_gpu.PartialInitWithoutAllocator();
dev_ctx_gpu.SetAllocator(
paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(dev_ctx_gpu.GetPlace(), dev_ctx_gpu.stream())
.get());
dev_ctx_gpu.SetHostAllocator(
paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(pten::CPUPlace())
.get());
dev_ctx_gpu.PartialInitWithAllocator();
pten::Copy(dev_ctx_gpu, dense_x, true, &d_dense_x);
auto sparse_out = sparse::DenseToSparseCsr<T>(dev_ctx_gpu, d_dense_x);
CheckCsrResult<T, int64_t>(&dev_ctx_gpu,
sparse_out,
non_zero_data,
crows_data,
cols_data,
non_zero_num,
alloc);
#endif
}
TEST(DEV_API, dense_to_sparse_csr) {
const auto alloc = std::make_shared<paddle::experimental::DefaultAllocator>(
paddle::platform::CPUPlace());
DenseTensor dense_x(
alloc.get(),
DenseTensorMeta(
DataType::FLOAT32, framework::make_ddim({3, 3}), DataLayout::NCHW));
pten::CPUPlace cpu;
auto* dense_x_data = dense_x.mutable_data<float>(cpu);
float dense_data[3][3] = {{0.0, 1.0, 0.0}, {2.0, 0.0, 3.0}, {3.2, 0.0, 0.0}};
std::vector<float> non_zero_data = {1.0, 2.0, 3.0, 3.2};
std::vector<int64_t> cols_data = {1, 0, 2, 0};
std::vector<int64_t> crows_data = {0, 1, 3, 4};
const int64_t non_zero_num = 4;
std::copy(&dense_data[0][0], &dense_data[0][0] + 9, dense_x_data);
TestDenseToSparseCsr<float>(
dense_x, non_zero_num, non_zero_data, crows_data, cols_data);
}
TEST(DEV_API, dense_to_sparse_csr_batch) {
const auto alloc = std::make_shared<paddle::experimental::DefaultAllocator>(
paddle::platform::CPUPlace());
DenseTensor dense_x(alloc.get(),
DenseTensorMeta(DataType::FLOAT16,
framework::make_ddim({2, 3, 3}),
DataLayout::NCHW));
pten::CPUPlace cpu;
auto* dense_x_data = dense_x.mutable_data<pten::dtype::float16>(cpu);
const int64_t non_zero_num = 7;
float dense_data[2][3][3] = {
{{0.0, 1.0, 0.0}, {2.0, 0.0, 3.0}, {3.2, 0.0, 0.0}},
{{0.0, 1.0, 0.0}, {2.0, 0.0, 0.0}, {3.2, 0.0, 0.0}}};
std::vector<float> data = {1.0, 2.0, 3.0, 3.2, 1.0, 2.0, 3.2};
std::vector<pten::dtype::float16> non_zero_data(non_zero_num);
for (int64_t i = 0; i < non_zero_num; i++) {
non_zero_data[i] = static_cast<pten::dtype::float16>(data[i]);
}
std::vector<int64_t> cols_data = {1, 0, 2, 0, 1, 0, 0};
std::vector<int64_t> crows_data = {0, 1, 3, 4, 0, 1, 2, 3};
float* dense_ptr = &dense_data[0][0][0];
for (int i = 0; i < 18; i++) {
dense_x_data[i] = static_cast<pten::dtype::float16>(dense_ptr[i]);
}
TestDenseToSparseCsr<pten::dtype::float16>(
dense_x, non_zero_num, non_zero_data, crows_data, cols_data);
}
} // namespace tests
} // namespace pten