Unverified commit abb38136, authored by OccupyMars2025 and committed by GitHub

[Hackathon 3rd No.22] add paddle.incubate.sparse.reshape (#46694)

* add sparse reshape

* change the dtype in all test cases to int64

* just one test case

* modify comments

* Update test_sparse_reshape_op.py

* change the type of "shape" from vector<int64_t> to IntArray

* check whether sp_out.to_dense() is the cause of the error

* print sp_out

* Update reshape_kernel.cc

* use numpy to generate the equal paddle tensor

* just check dense_tensor.numpy()

* check cpu and cuda versions

* Update test_sparse_reshape_op.py

* supply all test cases for cpu forward coo kernel

* test forward coo cuda kernel

* change configuration of cuda kernel

* keep only one test case

* test coo cpu kernel (forward and backward)

* row-major or column-major?

* test cuda coo forward kernel

* complete declaration and registration

* Update __init__.py

* rebuild

* retrigger CI

* add cudaMalloc and cudaMemcpy in ReshapeCooKernel and change back to row-major order in a CUDA dense tensor

* fix a minor error

* test only cpu coo forward kernel

* add all test cases for coo forward kernel (both cpu and gpu)

* test all forward kernels (coo, csr; cpu, gpu)

* add all test cases for all kinds of kernels

* just retrigger CI

* Update sparse_ops.yaml

* Update sparse_ops.yaml

* Update sparse_ops.yaml

* resolve conflicts

* Update sparse_ops.yaml

* don't specify tensor place

* new shape has -1 or 0 in it

* Update unary_grad_kernel.h

* correct lvalue error

* code style

* Update sparse_backward.yaml

* Update sparse_ops.yaml

* Update unary_kernel.h

* Update unary.py

* Update sparse_backward.yaml

* Update unary.py

* code style

* code style

* code style

* Update unary.py

* specify tensor place explicitly

* do not use numpy array

* use numpy array in unit test again

* modify example code in docstring
Parent 64307903
@@ -272,6 +272,17 @@
func : relu_coo_grad {sparse_coo, sparse_coo -> sparse_coo},
relu_csr_grad {sparse_csr, sparse_csr -> sparse_csr}
- backward_op : reshape_grad
forward : reshape(Tensor x, IntArray shape) -> Tensor(out)
args : (Tensor x, Tensor out_grad)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : reshape_coo_grad {sparse_coo, sparse_coo -> sparse_coo},
reshape_csr_grad {sparse_csr, sparse_csr -> sparse_csr}
- backward_op : scale_grad
forward : scale(Tensor x, float scale, float bias, bool bias_after_scale) -> Tensor(out)
args : (Tensor out_grad, float scale)
@@ -489,3 +489,14 @@
func : sync_batch_norm_coo{sparse_coo, dense, dense, dense, dense -> sparse_coo, dense, dense, dense, dense, dense}
data_type : x
backward : sync_batch_norm_grad
- op : reshape
args : (Tensor x, IntArray shape)
output : Tensor(out)
infer_meta :
func : ReshapeInferMeta
kernel :
func : reshape_coo{sparse_coo -> sparse_coo},
reshape_csr{sparse_csr -> sparse_csr}
layout : x
backward : reshape_grad
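
The `shape` argument declared above is an `IntArray`; it may contain `0` (copy the corresponding dimension of `x`) and at most one `-1` (inferred from the remaining dimensions). A minimal Python sketch of the resolution rule that `DDim::reshape` applies inside the kernels below — the helper name is ours, not part of the codebase:

def resolve_shape(x_shape, target):
    # Resolve 0 (copy from x) and a single -1 (inferred) in a target shape.
    out = [x_shape[i] if d == 0 else d for i, d in enumerate(target)]
    numel = 1
    for d in x_shape:
        numel *= d
    if -1 in out:
        known = 1
        for d in out:
            if d != -1:
                known *= d
        out[out.index(-1)] = numel // known
    return out

assert resolve_shape([6, 2, 3], [-1, 0, 3]) == [6, 2, 3]
assert resolve_shape([2, 4, 6], [-1, 0, 3, 2]) == [2, 4, 3, 2]
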
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/sparse/unary_grad_kernel.h"
#include "paddle/phi/kernels/sparse/unary_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/sparse/empty_kernel.h"
#include "paddle/phi/kernels/sparse/impl/unary_grad_kernel_impl.h"
namespace phi {
namespace sparse {
template <typename T, typename Context>
void ReshapeCooGradKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const SparseCooTensor& dout,
SparseCooTensor* dx) {
EmptyLikeCooKernel<T, Context>(dev_ctx, x, dx);
phi::IntArray x_shape(phi::vectorize(x.dims()));
ReshapeCooKernel<T, Context>(dev_ctx, dout, x_shape, dx);
}
template <typename T, typename Context>
void ReshapeCsrGradKernel(const Context& dev_ctx,
const SparseCsrTensor& x,
const SparseCsrTensor& dout,
SparseCsrTensor* dx) {
EmptyLikeCsrKernel<T, Context>(dev_ctx, x, dx);
phi::IntArray x_shape(phi::vectorize(x.dims()));
ReshapeCsrKernel<T, Context>(dev_ctx, dout, x_shape, dx);
}
} // namespace sparse
} // namespace phi
PD_REGISTER_KERNEL(reshape_coo_grad,
CPU,
ALL_LAYOUT,
phi::sparse::ReshapeCooGradKernel,
float,
double,
int8_t,
uint8_t,
int16_t,
int,
int64_t,
bool) {}
PD_REGISTER_KERNEL(reshape_csr_grad,
CPU,
ALL_LAYOUT,
phi::sparse::ReshapeCsrGradKernel,
float,
double,
int8_t,
uint8_t,
int16_t,
int,
int64_t,
bool) {}
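
Both grad kernels implement the same rule: since reshape only rearranges entries, the gradient with respect to x is the incoming gradient reshaped back to x's shape, and the set of nonzeros is unchanged. A minimal dense NumPy analogy of that rule (an illustration, not the sparse kernel itself):

import numpy as np

x = np.arange(36).reshape(6, 2, 3)
out = x.reshape(2, 9, 2)
dout = np.ones_like(out)        # upstream gradient w.r.t. out
dx = dout.reshape(x.shape)      # gradient w.r.t. x: just reshape back
assert dx.shape == x.shape
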
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/sparse/unary_kernel.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/kernels/sparse/sparse_utils_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
#include "paddle/phi/kernels/sparse/empty_kernel.h"
#include "paddle/phi/kernels/sparse/impl/unary_grad_kernel_impl.h"
#include "paddle/phi/kernels/sparse/impl/unary_kernel_impl.h"
namespace phi {
namespace sparse {
template <typename T, typename Context>
void ReshapeCooKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const phi::IntArray& shape,
SparseCooTensor* out) {
// TODO(OccupyMars2025): Currently, reshape is only applicable to sparse dims
int64_t x_nnz = x.nnz();
// Use DDim::reshape to handle -1 and 0 in the argument "shape"
std::vector<int> new_shape(shape.GetData().begin(), shape.GetData().end());
phi::DDim out_dims = x.dims().reshape(new_shape);
// get sparse part dimensions of x and out
std::vector<int64_t> x_sparse_part_dims;
std::vector<int64_t> out_sparse_part_dims;
for (int i = 0; i < x.sparse_dim(); ++i) {
x_sparse_part_dims.push_back(x.dims()[i]);
}
for (int i = 0; i < out_dims.size() - x.dense_dim(); ++i) {
out_sparse_part_dims.push_back(out_dims[i]);
}
DenseTensor out_indices = Empty<int64_t, Context>(
dev_ctx, {static_cast<int64_t>(out_sparse_part_dims.size()), x_nnz});
DenseTensor out_values(x.values());
out->SetMember(out_indices, out_values, out_dims, x.coalesced());
// compute the new indices
const DenseTensor& x_indices = x.indices();
const auto* x_indices_data = x_indices.data<int64_t>();
auto* out_indices_data = out_indices.data<int64_t>();
const phi::DDim& x_sparse_part_strides =
phi::stride(phi::make_ddim(x_sparse_part_dims));
const phi::DDim& out_sparse_part_strides =
phi::stride(phi::make_ddim(out_sparse_part_dims));
int64_t location = 0;
for (int64_t j = 0; j < x_nnz; ++j) {
location = 0;
for (int i = 0; i < x.sparse_dim(); ++i) {
location += x_indices_data[i * x_nnz + j] * x_sparse_part_strides[i];
}
for (size_t i = 0; i < out_sparse_part_dims.size(); ++i) {
out_indices_data[i * x_nnz + j] = location / out_sparse_part_strides[i];
location %= out_sparse_part_strides[i];
}
}
}
template <typename T, typename Context>
void ReshapeCsrKernel(const Context& dev_ctx,
const SparseCsrTensor& x,
const phi::IntArray& shape,
SparseCsrTensor* out) {
// transform csr format to coo format, and then use coo kernel
const SparseCooTensor x_coo = CsrToCoo<T, Context>(dev_ctx, x);
SparseCooTensor out_coo;
ReshapeCooKernel<T, Context>(dev_ctx, x_coo, shape, &out_coo);
CooToCsrKernel<T, Context>(dev_ctx, out_coo, out);
}
} // namespace sparse
} // namespace phi
PD_REGISTER_KERNEL(reshape_coo,
CPU,
ALL_LAYOUT,
phi::sparse::ReshapeCooKernel,
float,
double,
int8_t,
uint8_t,
int16_t,
int,
int64_t,
bool) {}
PD_REGISTER_KERNEL(reshape_csr,
CPU,
ALL_LAYOUT,
phi::sparse::ReshapeCsrKernel,
float,
double,
int8_t,
uint8_t,
int16_t,
int,
int64_t,
bool) {}
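
The core of ReshapeCooKernel is the index-remapping loop: each nonzero's sparse multi-index is flattened to a linear offset using the row-major strides of x's sparse part, then unflattened using the strides of out's sparse part. A NumPy sketch of the same computation, under the same row-major assumption (a minimal illustration, not the shipped kernel):

import numpy as np

def reshape_coo_indices(x_indices, x_sparse_dims, out_sparse_dims):
    # x_indices: int64 array of shape [len(x_sparse_dims), nnz]
    x_strides = np.append(np.cumprod(x_sparse_dims[::-1])[::-1][1:], 1)
    out_strides = np.append(np.cumprod(out_sparse_dims[::-1])[::-1][1:], 1)
    # flatten each multi-index into a linear, row-major location
    location = (x_indices * x_strides[:, None]).sum(axis=0)
    # unflatten the location into the new sparse dims
    return np.stack([(location // s) % d
                     for s, d in zip(out_strides, out_sparse_dims)])

# two nonzeros of a [6, 2, 3] tensor, reshaped to [2, 9, 2]
x_idx = np.array([[0, 5], [1, 0], [2, 1]])
print(reshape_coo_indices(x_idx, [6, 2, 3], [2, 9, 2]))
# [[0 1]
#  [2 6]
#  [1 1]]
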
@@ -329,7 +329,8 @@ PD_REGISTER_KERNEL(csr_to_coo,
int8_t,
int16_t,
int,
int64_t) {}
int64_t,
bool) {}
PD_REGISTER_KERNEL(coo_to_csr,
CPU,
@@ -342,7 +343,8 @@ PD_REGISTER_KERNEL(coo_to_csr,
int8_t,
int16_t,
int,
int64_t) {}
int64_t,
bool) {}
PD_REGISTER_KERNEL(dense_to_csr,
CPU,
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/sparse/unary_grad_kernel.h"
#include "paddle/phi/kernels/sparse/unary_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/sparse/empty_kernel.h"
#include "paddle/phi/kernels/sparse/impl/unary_grad_kernel_impl.h"
namespace phi {
namespace sparse {
// copied from paddle/phi/kernels/sparse/cpu/reshape_grad_kernel.cc
template <typename T, typename Context>
void ReshapeCooGradKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const SparseCooTensor& dout,
SparseCooTensor* dx) {
EmptyLikeCooKernel<T, Context>(dev_ctx, x, dx);
phi::IntArray x_shape(phi::vectorize(x.dims()));
ReshapeCooKernel<T, Context>(dev_ctx, dout, x_shape, dx);
}
// copied from paddle/phi/kernels/sparse/cpu/reshape_grad_kernel.cc
template <typename T, typename Context>
void ReshapeCsrGradKernel(const Context& dev_ctx,
const SparseCsrTensor& x,
const SparseCsrTensor& dout,
SparseCsrTensor* dx) {
EmptyLikeCsrKernel<T, Context>(dev_ctx, x, dx);
phi::IntArray x_shape(phi::vectorize(x.dims()));
ReshapeCsrKernel<T, Context>(dev_ctx, dout, x_shape, dx);
}
} // namespace sparse
} // namespace phi
PD_REGISTER_KERNEL(reshape_coo_grad,
GPU,
ALL_LAYOUT,
phi::sparse::ReshapeCooGradKernel,
phi::dtype::float16,
float,
double,
int8_t,
uint8_t,
int16_t,
int,
int64_t,
bool) {}
PD_REGISTER_KERNEL(reshape_csr_grad,
GPU,
ALL_LAYOUT,
phi::sparse::ReshapeCsrGradKernel,
phi::dtype::float16,
float,
double,
int8_t,
uint8_t,
int16_t,
int,
int64_t,
bool) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/sparse/unary_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
#include "paddle/phi/kernels/sparse/empty_kernel.h"
#include "paddle/phi/kernels/sparse/impl/unary_kernel_impl.h"
#include "paddle/phi/kernels/sparse/sparse_utils_kernel.h"
namespace phi {
namespace sparse {
__global__ void ReshapeCooCudaKernel(const int64_t* x_indices_data,
const int num_x_sparse_part_dims,
const int num_out_sparse_part_dims,
const int64_t x_nnz,
const int64_t* x_sparse_part_strides,
const int64_t* out_sparse_part_strides,
int64_t* out_indices_data) {
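// Grid-stride loop: each iteration remaps the multi-index of one nonzero.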
CUDA_KERNEL_LOOP_TYPE(j, x_nnz, int64_t) {
int64_t location = 0;
for (int i = 0; i < num_x_sparse_part_dims; ++i) {
location += x_indices_data[i * x_nnz + j] * x_sparse_part_strides[i];
}
for (int i = 0; i < num_out_sparse_part_dims; ++i) {
out_indices_data[i * x_nnz + j] = location / out_sparse_part_strides[i];
location %= out_sparse_part_strides[i];
}
}
}
template <typename T, typename Context>
void ReshapeCooKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const phi::IntArray& shape,
SparseCooTensor* out) {
int64_t x_nnz = x.nnz();
std::vector<int> new_shape(shape.GetData().begin(), shape.GetData().end());
phi::DDim out_dims = x.dims().reshape(new_shape);
// get sparse part dimensions of x and out
std::vector<int64_t> x_sparse_part_dims;
std::vector<int64_t> out_sparse_part_dims;
for (int i = 0; i < x.sparse_dim(); ++i) {
x_sparse_part_dims.push_back(x.dims()[i]);
}
for (int i = 0; i < out_dims.size() - x.dense_dim(); ++i) {
out_sparse_part_dims.push_back(out_dims[i]);
}
DenseTensor out_indices = Empty<int64_t, Context>(
dev_ctx, {static_cast<int64_t>(out_sparse_part_dims.size()), x_nnz});
DenseTensor out_values(x.values());
out->SetMember(out_indices, out_values, out_dims, x.coalesced());
// compute the new out indices
const auto* x_indices_data = x.indices().data<int64_t>();
auto* out_indices_data = out_indices.data<int64_t>();
const phi::DDim& x_sparse_part_strides =
phi::stride(phi::make_ddim(x_sparse_part_dims));
const phi::DDim& out_sparse_part_strides =
phi::stride(phi::make_ddim(out_sparse_part_dims));
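// The launched kernel cannot read host-side DDim storage, so the two stride
// arrays are first copied into freshly allocated device buffers.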
int64_t *destination_x_sparse_part_strides,
*destination_out_sparse_part_strides;
#ifdef PADDLE_WITH_HIP
hipMalloc(reinterpret_cast<void**>(&destination_x_sparse_part_strides),
sizeof(int64_t) * x_sparse_part_strides.size());
hipMemcpy(destination_x_sparse_part_strides,
x_sparse_part_strides.Get(),
sizeof(int64_t) * x_sparse_part_strides.size(),
hipMemcpyHostToDevice);
hipMalloc(reinterpret_cast<void**>(&destination_out_sparse_part_strides),
sizeof(int64_t) * out_sparse_part_strides.size());
hipMemcpy(destination_out_sparse_part_strides,
out_sparse_part_strides.Get(),
sizeof(int64_t) * out_sparse_part_strides.size(),
hipMemcpyHostToDevice);
#else
cudaMalloc(reinterpret_cast<void**>(&destination_x_sparse_part_strides),
sizeof(int64_t) * x_sparse_part_strides.size());
cudaMemcpy(destination_x_sparse_part_strides,
x_sparse_part_strides.Get(),
sizeof(int64_t) * x_sparse_part_strides.size(),
cudaMemcpyHostToDevice);
cudaMalloc(reinterpret_cast<void**>(&destination_out_sparse_part_strides),
sizeof(int64_t) * out_sparse_part_strides.size());
cudaMemcpy(destination_out_sparse_part_strides,
out_sparse_part_strides.Get(),
sizeof(int64_t) * out_sparse_part_strides.size(),
cudaMemcpyHostToDevice);
#endif
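// One thread per nonzero: each thread rewrites one column of out_indices.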
auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, x_nnz, 1);
ReshapeCooCudaKernel<<<config.block_per_grid.x,
config.thread_per_block.x,
0,
dev_ctx.stream()>>>(
x_indices_data,
x_sparse_part_dims.size(),
out_sparse_part_dims.size(),
x_nnz,
destination_x_sparse_part_strides,
destination_out_sparse_part_strides,
out_indices_data);
// free the temporary device buffers once the kernel has consumed them
// (cudaFree/hipFree synchronize with the launched kernel)
#ifdef PADDLE_WITH_HIP
hipFree(destination_x_sparse_part_strides);
hipFree(destination_out_sparse_part_strides);
#else
cudaFree(destination_x_sparse_part_strides);
cudaFree(destination_out_sparse_part_strides);
#endif
}
// copied from paddle/phi/kernels/sparse/cpu/reshape_kernel.cc
template <typename T, typename Context>
void ReshapeCsrKernel(const Context& dev_ctx,
const SparseCsrTensor& x,
const phi::IntArray& shape,
SparseCsrTensor* out) {
// transform csr format to coo format, and then use coo kernel
const SparseCooTensor x_coo = CsrToCoo<T, Context>(dev_ctx, x);
SparseCooTensor out_coo;
ReshapeCooKernel<T, Context>(dev_ctx, x_coo, shape, &out_coo);
CooToCsrKernel<T, Context>(dev_ctx, out_coo, out);
}
} // namespace sparse
} // namespace phi
PD_REGISTER_KERNEL(reshape_coo,
GPU,
ALL_LAYOUT,
phi::sparse::ReshapeCooKernel,
phi::dtype::float16,
float,
double,
int8_t,
uint8_t,
int16_t,
int,
int64_t,
bool) {}
PD_REGISTER_KERNEL(reshape_csr,
GPU,
ALL_LAYOUT,
phi::sparse::ReshapeCsrKernel,
phi::dtype::float16,
float,
double,
int8_t,
uint8_t,
int16_t,
int,
int64_t,
bool) {}
@@ -539,7 +539,8 @@ PD_REGISTER_KERNEL(csr_to_coo,
int8_t,
int16_t,
int,
int64_t) {}
int64_t,
bool) {}
PD_REGISTER_KERNEL(coo_to_csr,
GPU,
@@ -552,7 +553,8 @@ PD_REGISTER_KERNEL(coo_to_csr,
int8_t,
int16_t,
int,
int64_t) {}
int64_t,
bool) {}
PD_REGISTER_KERNEL(dense_to_csr,
GPU,
@@ -89,5 +89,17 @@ void TransposeCsrGradKernel(const Context& dev_ctx,
const std::vector<int>& perm,
SparseCsrTensor* dx);
template <typename T, typename Context>
void ReshapeCooGradKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const SparseCooTensor& dout,
SparseCooTensor* dx);
template <typename T, typename Context>
void ReshapeCsrGradKernel(const Context& dev_ctx,
const SparseCsrTensor& x,
const SparseCsrTensor& dout,
SparseCsrTensor* dx);
} // namespace sparse
} // namespace phi
@@ -14,6 +14,8 @@
#pragma once
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/sparse_coo_tensor.h"
#include "paddle/phi/core/sparse_csr_tensor.h"
@@ -155,5 +157,43 @@ SparseCooTensor ReluCsr(const Context& dev_ctx, const SparseCooTensor& x) {
return csr;
}
template <typename T, typename Context>
void ReshapeCooKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const phi::IntArray& shape,
SparseCooTensor* out);
template <typename T, typename Context>
void ReshapeCsrKernel(const Context& dev_ctx,
const SparseCsrTensor& x,
const phi::IntArray& shape,
SparseCsrTensor* out);
template <typename T, typename Context>
SparseCooTensor ReshapeCoo(const Context& dev_ctx,
const SparseCooTensor& x,
const phi::IntArray& shape) {
SparseCooTensor coo;
ReshapeCooKernel<T, Context>(dev_ctx, x, shape, &coo);
return coo;
}
template <typename T, typename Context>
SparseCsrTensor ReshapeCsr(const Context& dev_ctx,
const SparseCsrTensor& x,
const phi::IntArray& shape) {
PADDLE_ENFORCE_LE(
    2,
    shape.size(),
    phi::errors::InvalidArgument(
        "the length of the target shape of a SparseCsrTensor must be 2 or 3"));
PADDLE_ENFORCE_GE(
    3,
    shape.size(),
    phi::errors::InvalidArgument(
        "the length of the target shape of a SparseCsrTensor must be 2 or 3"));
SparseCsrTensor csr;
ReshapeCsrKernel<T, Context>(dev_ctx, x, shape, &csr);
return csr;
}
} // namespace sparse
} // namespace phi
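
Because a SparseCsrTensor is only defined for 2-D and 3-D shapes, ReshapeCsr checks that the target shape has exactly 2 or 3 entries. A hedged Python-level illustration of the resulting API behavior:

import paddle

x = paddle.to_tensor([[0., 1.], [2., 0.]]).to_sparse_csr()
y = paddle.incubate.sparse.reshape(x, [4, 1])   # OK: 2-D target shape
# paddle.incubate.sparse.reshape(x, [4])        # rejected: a CSR target shape
#                                               # must have 2 or 3 dimensions
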
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import numpy as np
import unittest
class TestReshape(unittest.TestCase):
    """
    Test the API paddle.incubate.sparse.reshape on some sparse tensors.
    x: sparse, out: sparse
    """

    def check_result(self, x_shape, new_shape, format):
        """
        x_shape: original shape
        new_shape: new shape
        format: "coo" or "csr"
        Transform a sparse tensor with shape "x_shape" to
        a sparse tensor with shape "new_shape".
        Compare the output of paddle.reshape and the output of
        paddle.incubate.sparse.reshape.
        """
        mask = np.random.randint(0, 2, x_shape)
        np_x = np.random.randint(-100, 100, x_shape) * mask

        # check cpu kernel
        dense_x = paddle.to_tensor(np_x, place=paddle.CPUPlace())
        dense_x.stop_gradient = False
        dense_out = paddle.reshape(dense_x, new_shape)

        if format == "coo":
            sp_x = paddle.to_tensor(np_x,
                                    place=paddle.CPUPlace()).to_sparse_coo(
                                        len(x_shape))
        else:
            sp_x = paddle.to_tensor(np_x,
                                    place=paddle.CPUPlace()).to_sparse_csr()
        sp_x.stop_gradient = False
        sp_out = paddle.incubate.sparse.reshape(sp_x, new_shape)

        np.testing.assert_allclose(sp_out.to_dense().numpy(),
                                   dense_out.numpy(),
                                   rtol=1e-05)

        dense_out.backward()
        sp_out.backward()
        np.testing.assert_allclose(sp_x.grad.to_dense().numpy(),
                                   dense_x.grad.numpy() *
                                   np_x.astype('bool').astype('int'),
                                   rtol=1e-05)

        # check gpu kernel
        if paddle.device.is_compiled_with_cuda():
            dense_x = paddle.to_tensor(np_x, place=paddle.CUDAPlace(0))
            dense_x.stop_gradient = False
            dense_out = paddle.reshape(dense_x, new_shape)

            if format == "coo":
                sp_x = paddle.to_tensor(
                    np_x, place=paddle.CUDAPlace(0)).to_sparse_coo(
                        len(x_shape))
            else:
                sp_x = paddle.to_tensor(
                    np_x, place=paddle.CUDAPlace(0)).to_sparse_csr()
            sp_x.stop_gradient = False
            sp_out = paddle.incubate.sparse.reshape(sp_x, new_shape)

            np.testing.assert_allclose(sp_out.to_dense().numpy(),
                                       dense_out.numpy(),
                                       rtol=1e-05)

            dense_out.backward()
            sp_out.backward()
            np.testing.assert_allclose(sp_x.grad.to_dense().numpy(),
                                       dense_x.grad.numpy() *
                                       np_x.astype('bool').astype('int'),
                                       rtol=1e-05)

    def test_reshape_2d(self):
        self.check_result([2, 5], [
            10,
        ], 'coo')
        self.check_result([12, 5], [15, 4], 'coo')

        self.check_result([10, 5], [2, 25], 'csr')
        self.check_result([9, 8], [18, 4], 'csr')

    def test_reshape_3d(self):
        self.check_result([6, 2, 3], [6, 2, 3], 'coo')
        self.check_result([6, 2, 3], [2, 3, 3, 2], 'coo')
        self.check_result([6, 2, 3], [1, 18, 2], 'coo')
        self.check_result([6, 2, 3], [2, 9, 2], 'coo')
        self.check_result([6, 2, 3], [2, 1, 18], 'coo')
        self.check_result([6, 2, 3], [1, 2, 2, 3, 3], 'coo')

        self.check_result([6, 2, 3], [6, 2, 3], 'csr')
        self.check_result([6, 2, 3], [6, 3, 2], 'csr')
        self.check_result([6, 2, 3], [2, 6, 3], 'csr')
        self.check_result([6, 2, 3], [3, 6, 2], 'csr')
        self.check_result([6, 2, 3], [4, 9, 1], 'csr')
        self.check_result([6, 2, 3], [12, 1, 3], 'csr')

    def test_reshape_nd(self):
        self.check_result([8, 3, 4, 4, 5, 3], [24, 8, 10, 3], 'coo')
        self.check_result([3, 4, 4, 5, 7], [1, 12, 2, 5, 14], 'coo')

    def test_reshape_with_zero_or_minus_one_in_new_shape(self):
        self.check_result([6, 2, 3], [-1, 0, 3], 'coo')
        self.check_result([6, 2, 3], [2, 3, 0, -1], 'coo')
        self.check_result([6, 2, 3], [1, -1, 2], 'coo')
        self.check_result([6, 2, 3], [-1, 9, 2], 'coo')
        self.check_result([6, 2, 3], [2, -1, 18], 'coo')
        self.check_result([6, 2, 3], [1, 0, 2, -1, 3], 'coo')

        self.check_result([6, 2, 3], [0, 0, -1], 'csr')
        self.check_result([6, 2, 3], [-1, 3, 2], 'csr')
        self.check_result([6, 2, 3], [2, -1, 0], 'csr')
        self.check_result([6, 2, 3], [-1, 6, 2], 'csr')
        self.check_result([6, 2, 3], [-1, 9, 1], 'csr')
        self.check_result([6, 2, 3], [-1, 1, 3], 'csr')


if __name__ == "__main__":
    unittest.main()
@@ -35,6 +35,7 @@ from .unary import deg2rad
from .unary import rad2deg
from .unary import expm1
from .unary import transpose
from .unary import reshape
from .binary import mv
from .binary import matmul
@@ -50,35 +51,9 @@ from .multiary import addmm
from . import nn
__all__ = [
'sparse_coo_tensor',
'sparse_csr_tensor',
'sin',
'tan',
'asin',
'atan',
'sinh',
'tanh',
'asinh',
'atanh',
'sqrt',
'square',
'log1p',
'abs',
'pow',
'cast',
'neg',
'deg2rad',
'rad2deg',
'expm1',
'mv',
'matmul',
'masked_matmul',
'addmm',
'add',
'subtract',
'transpose',
'multiply',
'divide',
'coalesce',
'is_same_shape',
'sparse_coo_tensor', 'sparse_csr_tensor', 'sin', 'tan', 'asin', 'atan',
'sinh', 'tanh', 'asinh', 'atanh', 'sqrt', 'square', 'log1p', 'abs', 'pow',
'cast', 'neg', 'deg2rad', 'rad2deg', 'expm1', 'mv', 'matmul',
'masked_matmul', 'addmm', 'add', 'subtract', 'transpose', 'multiply',
'divide', 'coalesce', 'is_same_shape', 'reshape'
]
@@ -639,3 +639,60 @@ def expm1(x, name=None):
out = paddle.incubate.sparse.expm1(sparse_x)
"""
return _C_ops.sparse_expm1(x)
@dygraph_only
def reshape(x, shape, name=None):
"""
Changes the shape of ``x`` without changing its value. ``x`` must be a SparseCooTensor or a SparseCsrTensor.
Currently this function can only reshape the sparse dims of ``x``, but the ``shape`` argument must still
specify the full shape of the reshaped tensor.
Note that if ``x`` is a SparseCsrTensor, then len(shape) must be 2 or 3.
There are some tricks when specifying the target shape.
- 1. -1 means the value of this dimension is inferred from the total element number of x and the remaining dimensions, so one and only one dimension can be set to -1.
- 2. 0 means the actual dimension value will be copied from the corresponding dimension of x. The indices of 0 in the target shape cannot exceed the rank of x.
Here are some examples to explain it.
- 1. Given a 3-D tensor x with shape [2, 4, 6] and a target shape of [6, 8], the reshape operator will transform x into a 2-D tensor with shape [6, 8], leaving x's data unchanged.
- 2. Given a 3-D tensor x with shape [2, 4, 6] and a target shape of [2, 3, -1, 2], the reshape operator will transform x into a 4-D tensor with shape [2, 3, 4, 2], leaving x's data unchanged. In this case, the one dimension set to -1 is inferred from the total element number of x and the remaining dimensions.
- 3. Given a 3-D tensor x with shape [2, 4, 6] and a target shape of [-1, 0, 3, 2], the reshape operator will transform x into a 4-D tensor with shape [2, 4, 3, 2], leaving x's data unchanged. In this case, besides -1, the 0 means the actual dimension value will be copied from the corresponding dimension of x.
Args:
x (Tensor): The input sparse tensor with data type ``float32``, ``float64``, ``int32``, ``int64`` or ``bool``.
shape (list|tuple): Define the target shape. At most one dimension of the target shape can be -1.
The data type is ``int32``.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Returns:
Tensor: A reshaped Tensor with the same data type as ``x``.
Examples:
.. code-block:: python

    import paddle

    x_shape = [6, 2, 3]
    new_shape = [1, 0, 2, -1, 3]
    format = "coo"

    dense_x = paddle.randint(-100, 100, x_shape) * paddle.randint(0, 2, x_shape)

    if format == "coo":
        sp_x = dense_x.to_sparse_coo(len(x_shape))
    else:
        sp_x = dense_x.to_sparse_csr()
    sp_out = paddle.incubate.sparse.reshape(sp_x, new_shape)

    print(sp_out)
    # the shape of sp_out is [1, 2, 2, 3, 3]
"""
return _C_ops.sparse_reshape(x, shape)
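
Since the backward op is registered (see sparse_backward.yaml above), gradients flow through the sparse reshape in dygraph mode; the grad kernels simply reshape the incoming gradient back to x's shape. A minimal sketch mirroring the unit test:

import paddle

dense_x = paddle.randint(-100, 100, [6, 2, 3]).astype('float32')
sp_x = dense_x.to_sparse_coo(3)
sp_x.stop_gradient = False
sp_out = paddle.incubate.sparse.reshape(sp_x, [2, 9, 2])
sp_out.backward()
print(sp_x.grad.shape)   # [6, 2, 3], the original shape of sp_x
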