Unverified commit abb38136, authored by OccupyMars2025, committed by GitHub

[Hackathon 3rd No.22 ] add paddle.incubate.sparse.reshape (#46694)

* add sparse reshape

* change the dtype in all test cases to int64

* just one test case

* modify comments

* Update test_sparse_reshape_op.py

* change the type of "shape" from vector<int64_t> to IntArray

* check whether sp_out.to_dense() is the cause of the error

* print sp_out

* Update reshape_kernel.cc

* use numpy to generate an equivalent paddle tensor

* just check dense_tensor.numpy()

* check cpu and cuda versions

* Update test_sparse_reshape_op.py

* supply all test cases for cpu forward coo kernel

* test forward coo cuda kernel

* change configuration of cuda kernel

* keep only one test case

* test coo cpu kernel (forward and backward)

* row major or column major ???

* test cuda coo forward kernel

* complete declaration and registration

* Update __init__.py

* rebuild

* retrigger CI

* add cudaMalloc and cudaMemcpy in ReshapeCooKernel and change back to row-major order in a CUDA dense tensor

* fix a minor error

* test only cpu coo forward kernel

* add all test cases for the coo forward kernel (both CPU and GPU)

* test all forward kernels (coo, csr; cpu, gpu)

* add all test cases for all kinds of kernels

* just retrigger CI

* Update sparse_ops.yaml

* Update sparse_ops.yaml

* Update sparse_ops.yaml

* resolve conflicts

* Update sparse_ops.yaml

* don't specify tensor place

* new shape has -1 or 0 in it

* Update unary_grad_kernel.h

* correct lvalue error

* code style

* Update sparse_backward.yaml

* Update sparse_ops.yaml

* Update unary_kernel.h

* Update unary.py

* Update sparse_backward.yaml

* Update unary.py

* code style

* code style

* code style

* Update unary.py

* specify tensor place explicitly

* do not use numpy array

* use numpy array in unit test again

* modify example code in docstring
Parent commit 64307903
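For context, a minimal usage sketch of the API this PR adds. It mirrors the docstring example added below and assumes a Paddle build that contains this change:

import paddle

x_shape = [6, 2, 3]
new_shape = [1, 0, 2, -1, 3]

# Build a sparse COO tensor, then reshape it (0 copies a dim of x, -1 is inferred).
dense_x = paddle.randint(-100, 100, x_shape) * paddle.randint(0, 2, x_shape)
sp_x = dense_x.to_sparse_coo(len(x_shape))
sp_out = paddle.incubate.sparse.reshape(sp_x, new_shape)
print(sp_out)
# the shape of sp_out is [1, 2, 2, 3, 3]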
@@ -272,6 +272,17 @@
    func : relu_coo_grad {sparse_coo, sparse_coo -> sparse_coo},
           relu_csr_grad {sparse_csr, sparse_csr -> sparse_csr}

- backward_op : reshape_grad
  forward : reshape(Tensor x, IntArray shape) -> Tensor(out)
  args : (Tensor x, Tensor out_grad)
  output : Tensor(x_grad)
  infer_meta :
    func : UnchangedInferMeta
    param : [x]
  kernel :
    func : reshape_coo_grad {sparse_coo, sparse_coo -> sparse_coo},
           reshape_csr_grad {sparse_csr, sparse_csr -> sparse_csr}

- backward_op : scale_grad
  forward : scale(Tensor x, float scale, float bias, bool bias_after_scale) -> Tensor(out)
  args : (Tensor out_grad, float scale)
......
@@ -489,3 +489,14 @@
    func : sync_batch_norm_coo{sparse_coo, dense, dense, dense, dense -> sparse_coo, dense, dense, dense, dense, dense}
    data_type : x
  backward : sync_batch_norm_grad

- op : reshape
  args : (Tensor x, IntArray shape)
  output : Tensor(out)
  infer_meta :
    func : ReshapeInferMeta
  kernel :
    func : reshape_coo{sparse_coo -> sparse_coo},
           reshape_csr{sparse_csr -> sparse_csr}
    layout : x
  backward : reshape_grad
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/sparse/unary_grad_kernel.h"
#include "paddle/phi/kernels/sparse/unary_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/sparse/empty_kernel.h"
#include "paddle/phi/kernels/sparse/impl/unary_grad_kernel_impl.h"
namespace phi {
namespace sparse {
template <typename T, typename Context>
void ReshapeCooGradKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const SparseCooTensor& dout,
SparseCooTensor* dx) {
EmptyLikeCooKernel<T, Context>(dev_ctx, x, dx);
phi::IntArray x_shape(phi::vectorize(x.dims()));
ReshapeCooKernel<T, Context>(dev_ctx, dout, x_shape, dx);
}
template <typename T, typename Context>
void ReshapeCsrGradKernel(const Context& dev_ctx,
const SparseCsrTensor& x,
const SparseCsrTensor& dout,
SparseCsrTensor* dx) {
EmptyLikeCsrKernel<T, Context>(dev_ctx, x, dx);
phi::IntArray x_shape(phi::vectorize(x.dims()));
ReshapeCsrKernel<T, Context>(dev_ctx, dout, x_shape, dx);
}
} // namespace sparse
} // namespace phi
PD_REGISTER_KERNEL(reshape_coo_grad,
CPU,
ALL_LAYOUT,
phi::sparse::ReshapeCooGradKernel,
float,
double,
int8_t,
uint8_t,
int16_t,
int,
int64_t,
bool) {}
PD_REGISTER_KERNEL(reshape_csr_grad,
CPU,
ALL_LAYOUT,
phi::sparse::ReshapeCsrGradKernel,
float,
double,
int8_t,
uint8_t,
int16_t,
int,
int64_t,
bool) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/sparse/unary_kernel.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/kernels/sparse/sparse_utils_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
#include "paddle/phi/kernels/sparse/empty_kernel.h"
#include "paddle/phi/kernels/sparse/impl/unary_grad_kernel_impl.h"
#include "paddle/phi/kernels/sparse/impl/unary_kernel_impl.h"
namespace phi {
namespace sparse {
template <typename T, typename Context>
void ReshapeCooKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const phi::IntArray& shape,
SparseCooTensor* out) {
// TODO(OccupyMars2025): Currently, reshape is only applicable to sparse dims
int64_t x_nnz = x.nnz();
// Use DDim::reshape to handle -1 and 0 in the argument "shape"
std::vector<int> new_shape(shape.GetData().begin(), shape.GetData().end());
phi::DDim out_dims = x.dims().reshape(new_shape);
// get sparse part dimensions of x and out
std::vector<int64_t> x_sparse_part_dims;
std::vector<int64_t> out_sparse_part_dims;
for (int i = 0; i < x.sparse_dim(); ++i) {
x_sparse_part_dims.push_back(x.dims()[i]);
}
for (int i = 0; i < out_dims.size() - x.dense_dim(); ++i) {
out_sparse_part_dims.push_back(out_dims[i]);
}
DenseTensor out_indices = Empty<int64_t, Context>(
dev_ctx, {static_cast<int64_t>(out_sparse_part_dims.size()), x_nnz});
DenseTensor out_values(x.values());
out->SetMember(out_indices, out_values, out_dims, x.coalesced());
// compute values of indices
const DenseTensor& x_indices = x.indices();
const auto* x_indices_data = x_indices.data<int64_t>();
auto* out_indices_data = out_indices.data<int64_t>();
const phi::DDim& x_sparse_part_strides =
phi::stride(phi::make_ddim(x_sparse_part_dims));
const phi::DDim& out_sparse_part_strides =
phi::stride(phi::make_ddim(out_sparse_part_dims));
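// For each nonzero, flatten its multi-index over x's sparse dims into a linear
// offset using x's row-major sparse-dim strides, then unflatten that offset into
// a multi-index over out's sparse dims using out's sparse-dim strides.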
int64_t location = 0;
for (int64_t j = 0; j < x_nnz; ++j) {
location = 0;
for (int i = 0; i < x.sparse_dim(); ++i) {
location += x_indices_data[i * x_nnz + j] * x_sparse_part_strides[i];
}
for (size_t i = 0; i < out_sparse_part_dims.size(); ++i) {
out_indices_data[i * x_nnz + j] = location / out_sparse_part_strides[i];
location %= out_sparse_part_strides[i];
}
}
}
template <typename T, typename Context>
void ReshapeCsrKernel(const Context& dev_ctx,
const SparseCsrTensor& x,
const phi::IntArray& shape,
SparseCsrTensor* out) {
// transform csr format to coo format, and then use coo kernel
const SparseCooTensor x_coo = CsrToCoo<T, Context>(dev_ctx, x);
SparseCooTensor out_coo;
ReshapeCooKernel<T, Context>(dev_ctx, x_coo, shape, &out_coo);
CooToCsrKernel<T, Context>(dev_ctx, out_coo, out);
}
} // namespace sparse
} // namespace phi
PD_REGISTER_KERNEL(reshape_coo,
CPU,
ALL_LAYOUT,
phi::sparse::ReshapeCooKernel,
float,
double,
int8_t,
uint8_t,
int16_t,
int,
int64_t,
bool) {}
PD_REGISTER_KERNEL(reshape_csr,
CPU,
ALL_LAYOUT,
phi::sparse::ReshapeCsrKernel,
float,
double,
int8_t,
uint8_t,
int16_t,
int,
int64_t,
bool) {}
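As an aside, here is a hedged NumPy sketch of the index remapping that ReshapeCooKernel performs on the COO indices. The helper name remap_coo_indices and the variable names are illustrative only and are not part of Paddle:

import numpy as np

def remap_coo_indices(x_indices, x_sparse_dims, out_sparse_dims):
    # x_indices: int64 array of shape [len(x_sparse_dims), nnz].
    # Row-major strides, e.g. dims [6, 3] -> strides [3, 1].
    x_strides = np.append(np.cumprod(x_sparse_dims[::-1])[::-1][1:], 1)
    out_strides = np.append(np.cumprod(out_sparse_dims[::-1])[::-1][1:], 1)
    # Flatten every column (one nonzero) into a linear offset over the sparse part...
    location = (x_indices * x_strides[:, None]).sum(axis=0)
    # ...and unflatten it with the strides of the new sparse shape.
    out_dims = np.asarray(out_sparse_dims)[:, None]
    return (location[None, :] // out_strides[:, None]) % out_dims

# Example: a [6, 3] COO tensor reshaped to [2, 9].
x_idx = np.array([[0, 1, 5], [1, 0, 2]])
print(remap_coo_indices(x_idx, [6, 3], [2, 9]))  # [[0 0 1], [1 3 8]]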
@@ -329,7 +329,8 @@ PD_REGISTER_KERNEL(csr_to_coo,
int8_t,
int16_t,
int,
int64_t,
bool) {}

PD_REGISTER_KERNEL(coo_to_csr,
CPU,
@@ -342,7 +343,8 @@ PD_REGISTER_KERNEL(coo_to_csr,
int8_t,
int16_t,
int,
int64_t,
bool) {}

PD_REGISTER_KERNEL(dense_to_csr,
CPU,
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/sparse/unary_grad_kernel.h"
#include "paddle/phi/kernels/sparse/unary_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/sparse/empty_kernel.h"
#include "paddle/phi/kernels/sparse/impl/unary_grad_kernel_impl.h"
namespace phi {
namespace sparse {
// Copied from paddle/phi/kernels/sparse/cpu/reshape_grad_kernel.cc
template <typename T, typename Context>
void ReshapeCooGradKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const SparseCooTensor& dout,
SparseCooTensor* dx) {
EmptyLikeCooKernel<T, Context>(dev_ctx, x, dx);
phi::IntArray x_shape(phi::vectorize(x.dims()));
ReshapeCooKernel<T, Context>(dev_ctx, dout, x_shape, dx);
}
// Copied from paddle/phi/kernels/sparse/cpu/reshape_grad_kernel.cc
template <typename T, typename Context>
void ReshapeCsrGradKernel(const Context& dev_ctx,
const SparseCsrTensor& x,
const SparseCsrTensor& dout,
SparseCsrTensor* dx) {
EmptyLikeCsrKernel<T, Context>(dev_ctx, x, dx);
phi::IntArray x_shape(phi::vectorize(x.dims()));
ReshapeCsrKernel<T, Context>(dev_ctx, dout, x_shape, dx);
}
} // namespace sparse
} // namespace phi
PD_REGISTER_KERNEL(reshape_coo_grad,
GPU,
ALL_LAYOUT,
phi::sparse::ReshapeCooGradKernel,
phi::dtype::float16,
float,
double,
int8_t,
uint8_t,
int16_t,
int,
int64_t,
bool) {}
PD_REGISTER_KERNEL(reshape_csr_grad,
GPU,
ALL_LAYOUT,
phi::sparse::ReshapeCsrGradKernel,
phi::dtype::float16,
float,
double,
int8_t,
uint8_t,
int16_t,
int,
int64_t,
bool) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/sparse/unary_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
#include "paddle/phi/kernels/sparse/empty_kernel.h"
#include "paddle/phi/kernels/sparse/impl/unary_kernel_impl.h"
#include "paddle/phi/kernels/sparse/sparse_utils_kernel.h"
namespace phi {
namespace sparse {
__global__ void ReshapeCooCudaKernel(const int64_t* x_indices_data,
const int num_x_sparse_part_dims,
const int num_out_sparse_part_dims,
const int64_t x_nnz,
const int64_t* x_sparse_part_strides,
const int64_t* out_sparse_part_strides,
int64_t* out_indices_data) {
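// Grid-stride loop: each CUDA thread remaps the sparse-dim indices of one or
// more nonzeros, using the same flatten/unflatten-by-strides scheme as the CPU kernel.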
CUDA_KERNEL_LOOP_TYPE(j, x_nnz, int64_t) {
int64_t location = 0;
for (int i = 0; i < num_x_sparse_part_dims; ++i) {
location += x_indices_data[i * x_nnz + j] * x_sparse_part_strides[i];
}
for (int i = 0; i < num_out_sparse_part_dims; ++i) {
out_indices_data[i * x_nnz + j] = location / out_sparse_part_strides[i];
location %= out_sparse_part_strides[i];
}
}
}
template <typename T, typename Context>
void ReshapeCooKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const phi::IntArray& shape,
SparseCooTensor* out) {
int64_t x_nnz = x.nnz();
std::vector<int> new_shape(shape.GetData().begin(), shape.GetData().end());
phi::DDim out_dims = x.dims().reshape(new_shape);
// get sparse part dimensions of x and out
std::vector<int64_t> x_sparse_part_dims;
std::vector<int64_t> out_sparse_part_dims;
for (int i = 0; i < x.sparse_dim(); ++i) {
x_sparse_part_dims.push_back(x.dims()[i]);
}
for (int i = 0; i < out_dims.size() - x.dense_dim(); ++i) {
out_sparse_part_dims.push_back(out_dims[i]);
}
DenseTensor out_indices = Empty<int64_t, Context>(
dev_ctx, {static_cast<int64_t>(out_sparse_part_dims.size()), x_nnz});
DenseTensor out_values(x.values());
out->SetMember(out_indices, out_values, out_dims, x.coalesced());
// compute values of out indices
const auto* x_indices_data = x.indices().data<int64_t>();
auto* out_indices_data = out_indices.data<int64_t>();
const phi::DDim& x_sparse_part_strides =
phi::stride(phi::make_ddim(x_sparse_part_dims));
const phi::DDim& out_sparse_part_strides =
phi::stride(phi::make_ddim(out_sparse_part_dims));
int64_t *destination_x_sparse_part_strides,
*destination_out_sparse_part_strides;
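// Copy the host-side stride arrays to device memory so the CUDA kernel can read them.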
#ifdef PADDLE_WITH_HIP
hipMalloc(reinterpret_cast<void**>(&destination_x_sparse_part_strides),
sizeof(int64_t) * x_sparse_part_strides.size());
hipMemcpy(destination_x_sparse_part_strides,
x_sparse_part_strides.Get(),
sizeof(int64_t) * x_sparse_part_strides.size(),
hipMemcpyHostToDevice);
hipMalloc(reinterpret_cast<void**>(&destination_out_sparse_part_strides),
sizeof(int64_t) * out_sparse_part_strides.size());
hipMemcpy(destination_out_sparse_part_strides,
out_sparse_part_strides.Get(),
sizeof(int64_t) * out_sparse_part_strides.size(),
hipMemcpyHostToDevice);
#else
cudaMalloc(reinterpret_cast<void**>(&destination_x_sparse_part_strides),
sizeof(int64_t) * x_sparse_part_strides.size());
cudaMemcpy(destination_x_sparse_part_strides,
x_sparse_part_strides.Get(),
sizeof(int64_t) * x_sparse_part_strides.size(),
cudaMemcpyHostToDevice);
cudaMalloc(reinterpret_cast<void**>(&destination_out_sparse_part_strides),
sizeof(int64_t) * out_sparse_part_strides.size());
cudaMemcpy(destination_out_sparse_part_strides,
out_sparse_part_strides.Get(),
sizeof(int64_t) * out_sparse_part_strides.size(),
cudaMemcpyHostToDevice);
#endif
auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, x_nnz, 1);
ReshapeCooCudaKernel<<<config.block_per_grid.x,
config.thread_per_block.x,
0,
dev_ctx.stream()>>>(
x_indices_data,
x_sparse_part_dims.size(),
out_sparse_part_dims.size(),
x_nnz,
destination_x_sparse_part_strides,
destination_out_sparse_part_strides,
out_indices_data);
}
// Copied from paddle/phi/kernels/sparse/cpu/reshape_kernel.cc
template <typename T, typename Context>
void ReshapeCsrKernel(const Context& dev_ctx,
const SparseCsrTensor& x,
const phi::IntArray& shape,
SparseCsrTensor* out) {
// transform csr format to coo format, and then use coo kernel
const SparseCooTensor x_coo = CsrToCoo<T, Context>(dev_ctx, x);
SparseCooTensor out_coo;
ReshapeCooKernel<T, Context>(dev_ctx, x_coo, shape, &out_coo);
CooToCsrKernel<T, Context>(dev_ctx, out_coo, out);
}
} // namespace sparse
} // namespace phi
PD_REGISTER_KERNEL(reshape_coo,
GPU,
ALL_LAYOUT,
phi::sparse::ReshapeCooKernel,
phi::dtype::float16,
float,
double,
int8_t,
uint8_t,
int16_t,
int,
int64_t,
bool) {}
PD_REGISTER_KERNEL(reshape_csr,
GPU,
ALL_LAYOUT,
phi::sparse::ReshapeCsrKernel,
phi::dtype::float16,
float,
double,
int8_t,
uint8_t,
int16_t,
int,
int64_t,
bool) {}
@@ -539,7 +539,8 @@ PD_REGISTER_KERNEL(csr_to_coo,
int8_t,
int16_t,
int,
int64_t,
bool) {}

PD_REGISTER_KERNEL(coo_to_csr,
GPU,
@@ -552,7 +553,8 @@ PD_REGISTER_KERNEL(coo_to_csr,
int8_t,
int16_t,
int,
int64_t,
bool) {}

PD_REGISTER_KERNEL(dense_to_csr,
GPU,
......
@@ -89,5 +89,17 @@ void TransposeCsrGradKernel(const Context& dev_ctx,
const std::vector<int>& perm,
SparseCsrTensor* dx);

template <typename T, typename Context>
void ReshapeCooGradKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const SparseCooTensor& dout,
SparseCooTensor* dx);

template <typename T, typename Context>
void ReshapeCsrGradKernel(const Context& dev_ctx,
const SparseCsrTensor& x,
const SparseCsrTensor& dout,
SparseCsrTensor* dx);

}  // namespace sparse
}  // namespace phi
@@ -14,6 +14,8 @@
#pragma once

#include "paddle/phi/common/int_array.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/sparse_coo_tensor.h"
#include "paddle/phi/core/sparse_csr_tensor.h"

@@ -155,5 +157,43 @@ SparseCooTensor ReluCsr(const Context& dev_ctx, const SparseCooTensor& x) {
return csr;
}
template <typename T, typename Context>
void ReshapeCooKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const phi::IntArray& shape,
SparseCooTensor* out);
template <typename T, typename Context>
void ReshapeCsrKernel(const Context& dev_ctx,
const SparseCsrTensor& x,
const phi::IntArray& shape,
SparseCsrTensor* out);
template <typename T, typename Context>
SparseCooTensor ReshapeCoo(const Context& dev_ctx,
const SparseCooTensor& x,
const phi::IntArray& shape) {
SparseCooTensor coo;
ReshapeCooKernel<T, Context>(dev_ctx, x, shape, &coo);
return coo;
}
template <typename T, typename Context>
SparseCsrTensor ReshapeCsr(const Context& dev_ctx,
const SparseCsrTensor& x,
const phi::IntArray& shape) {
PADDLE_ENFORCE_LE(
2,
shape.size(),
phi::errors::InvalidArgument("size of shape must be equal to 2 or 3"));
PADDLE_ENFORCE_GE(
3,
shape.size(),
phi::errors::InvalidArgument("size of shape must be equal to 2 or 3"));
SparseCsrTensor csr;
ReshapeCsrKernel<T, Context>(dev_ctx, x, shape, &csr);
return csr;
}
}  // namespace sparse
}  // namespace phi
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import numpy as np
import unittest
class TestReshape(unittest.TestCase):
"""
Test the API paddle.incubate.sparse.reshape on some sparse tensors.
x: sparse, out: sparse
"""
def check_result(self, x_shape, new_shape, format):
"""
x_shape: original shape
new_shape: new shape
format: "coo" or "csr"
Transform a sparse tensor with shape "x_shape" to
a sparse tensor with shape "new_shape".
Compare the output of paddle.reshape and the output of
paddle.incubate.sparse.reshape.
"""
mask = np.random.randint(0, 2, x_shape)
np_x = np.random.randint(-100, 100, x_shape) * mask
# check cpu kernel
dense_x = paddle.to_tensor(np_x, place=paddle.CPUPlace())
dense_x.stop_gradient = False
dense_out = paddle.reshape(dense_x, new_shape)
if format == "coo":
sp_x = paddle.to_tensor(np_x,
place=paddle.CPUPlace()).to_sparse_coo(
len(x_shape))
else:
sp_x = paddle.to_tensor(np_x,
place=paddle.CPUPlace()).to_sparse_csr()
sp_x.stop_gradient = False
sp_out = paddle.incubate.sparse.reshape(sp_x, new_shape)
np.testing.assert_allclose(sp_out.to_dense().numpy(),
dense_out.numpy(),
rtol=1e-05)
dense_out.backward()
sp_out.backward()
np.testing.assert_allclose(sp_x.grad.to_dense().numpy(),
dense_x.grad.numpy() *
np_x.astype('bool').astype('int'),
rtol=1e-05)
# check gpu kernel
if paddle.device.is_compiled_with_cuda():
dense_x = paddle.to_tensor(np_x, place=paddle.CUDAPlace(0))
dense_x.stop_gradient = False
dense_out = paddle.reshape(dense_x, new_shape)
if format == "coo":
sp_x = paddle.to_tensor(
np_x, place=paddle.CUDAPlace(0)).to_sparse_coo(len(x_shape))
else:
sp_x = paddle.to_tensor(
np_x, place=paddle.CUDAPlace(0)).to_sparse_csr()
sp_x.stop_gradient = False
sp_out = paddle.incubate.sparse.reshape(sp_x, new_shape)
np.testing.assert_allclose(sp_out.to_dense().numpy(),
dense_out.numpy(),
rtol=1e-05)
dense_out.backward()
sp_out.backward()
np.testing.assert_allclose(sp_x.grad.to_dense().numpy(),
dense_x.grad.numpy() *
np_x.astype('bool').astype('int'),
rtol=1e-05)
def test_reshape_2d(self):
self.check_result([2, 5], [
10,
], 'coo')
self.check_result([12, 5], [15, 4], 'coo')
self.check_result([10, 5], [2, 25], 'csr')
self.check_result([9, 8], [18, 4], 'csr')
def test_reshape_3d(self):
self.check_result([6, 2, 3], [6, 2, 3], 'coo')
self.check_result([6, 2, 3], [2, 3, 3, 2], 'coo')
self.check_result([6, 2, 3], [1, 18, 2], 'coo')
self.check_result([6, 2, 3], [2, 9, 2], 'coo')
self.check_result([6, 2, 3], [2, 1, 18], 'coo')
self.check_result([6, 2, 3], [1, 2, 2, 3, 3], 'coo')
self.check_result([6, 2, 3], [6, 2, 3], 'csr')
self.check_result([6, 2, 3], [6, 3, 2], 'csr')
self.check_result([6, 2, 3], [2, 6, 3], 'csr')
self.check_result([6, 2, 3], [3, 6, 2], 'csr')
self.check_result([6, 2, 3], [4, 9, 1], 'csr')
self.check_result([6, 2, 3], [12, 1, 3], 'csr')
def test_reshape_nd(self):
self.check_result([8, 3, 4, 4, 5, 3], [24, 8, 10, 3], 'coo')
self.check_result([3, 4, 4, 5, 7], [1, 12, 2, 5, 14], 'coo')
def test_reshape_with_zero_or_minus_one_in_new_shape(self):
self.check_result([6, 2, 3], [-1, 0, 3], 'coo')
self.check_result([6, 2, 3], [2, 3, 0, -1], 'coo')
self.check_result([6, 2, 3], [1, -1, 2], 'coo')
self.check_result([6, 2, 3], [-1, 9, 2], 'coo')
self.check_result([6, 2, 3], [2, -1, 18], 'coo')
self.check_result([6, 2, 3], [1, 0, 2, -1, 3], 'coo')
self.check_result([6, 2, 3], [0, 0, -1], 'csr')
self.check_result([6, 2, 3], [-1, 3, 2], 'csr')
self.check_result([6, 2, 3], [2, -1, 0], 'csr')
self.check_result([6, 2, 3], [-1, 6, 2], 'csr')
self.check_result([6, 2, 3], [-1, 9, 1], 'csr')
self.check_result([6, 2, 3], [-1, 1, 3], 'csr')
if __name__ == "__main__":
unittest.main()
@@ -35,6 +35,7 @@ from .unary import deg2rad
from .unary import rad2deg
from .unary import expm1
from .unary import transpose
from .unary import reshape

from .binary import mv
from .binary import matmul
@@ -50,35 +51,9 @@ from .multiary import addmm

from . import nn

__all__ = [
    'sparse_coo_tensor', 'sparse_csr_tensor', 'sin', 'tan', 'asin', 'atan',
    'sinh', 'tanh', 'asinh', 'atanh', 'sqrt', 'square', 'log1p', 'abs', 'pow',
    'cast', 'neg', 'deg2rad', 'rad2deg', 'expm1', 'mv', 'matmul',
    'masked_matmul', 'addmm', 'add', 'subtract', 'transpose', 'multiply',
    'divide', 'coalesce', 'is_same_shape', 'reshape'
]
@@ -639,3 +639,60 @@ def expm1(x, name=None):
out = paddle.incubate.sparse.expm1(sparse_x)
"""
return _C_ops.sparse_expm1(x)
@dygraph_only
def reshape(x, shape, name=None):
"""
Changes the shape of ``x`` without changing its value. ``x`` must be a SparseCooTensor or a SparseCsrTensor.
Currently this function can only reshape the sparse dims of ``x``, but ``shape`` must still specify the full shape of the reshaped tensor.
Note that if ``x`` is a SparseCsrTensor, then len(shape) must be 2 or 3.
There are some tricks when specifying the target shape:

- 1. -1 means the value of this dimension is inferred from the total element number of ``x`` and the remaining dimensions. Only one dimension can be set to -1.
- 2. 0 means the actual dimension value is copied from the corresponding dimension of ``x``. The index of a 0 in the target shape cannot exceed the rank of ``x``.

Here are some examples:

- 1. Given a 3-D tensor ``x`` with shape [2, 4, 6] and the target shape [6, 8], reshape transforms ``x`` into a 2-D tensor with shape [6, 8], leaving x's data unchanged.
- 2. Given a 3-D tensor ``x`` with shape [2, 4, 6] and the target shape [2, 3, -1, 2], reshape transforms ``x`` into a 4-D tensor with shape [2, 3, 4, 2], leaving x's data unchanged. The dimension set to -1 is inferred from the total element number of ``x`` and the remaining dimensions.
- 3. Given a 3-D tensor ``x`` with shape [2, 4, 6] and the target shape [-1, 0, 3, 2], reshape transforms ``x`` into a 4-D tensor with shape [2, 4, 3, 2], leaving x's data unchanged. Besides -1, the 0 means the actual dimension value is copied from the corresponding dimension of ``x``.
Args:
x (Tensor): The input sparse tensor with data type ``float32``, ``float64``, ``int32``, ``int64`` or ``bool``.
shape (list|tuple): Define the target shape. At most one dimension of the target shape can be -1.
The data type is ``int32``.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Returns:
Tensor: A reshaped Tensor with the same data type as ``x``.
Examples:
.. code-block:: python
import paddle
x_shape = [6, 2, 3]
new_shape = [1, 0, 2, -1, 3]
format = "coo"
dense_x = paddle.randint(-100, 100, x_shape) * paddle.randint(0, 2, x_shape)
if format == "coo":
sp_x = dense_x.to_sparse_coo(len(x_shape))
else:
sp_x = dense_x.to_sparse_csr()
sp_out = paddle.incubate.sparse.reshape(sp_x, new_shape)
print(sp_out)
# the shape of sp_out is [1, 2, 2, 3, 3]
"""
return _C_ops.sparse_reshape(x, shape)
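A further hedged example, not taken from the PR, showing the CSR path, whose target shape must have length 2 or 3:

import paddle

# A [2, 3] dense tensor converted to CSR, then reshaped to [3, 2].
csr_x = paddle.to_tensor([[0., 2., 0.], [1., 0., 0.]]).to_sparse_csr()
csr_out = paddle.incubate.sparse.reshape(csr_x, [3, 2])
print(csr_out.to_dense())
# [[0., 2.],
#  [0., 1.],
#  [0., 0.]]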