Unverified commit acd08a9b, authored by zhangkaihuo, committed by GitHub

Add kernel sparse_mask_helper; sparse_coo_tensor_grad (#41586)

Parent d84934da
...@@ -40,6 +40,45 @@ inline const DDim InferDenseDims(const DDim& x_dims,
return values_dims;
}
template <typename IntT>
inline const IntT HOSTDEVICE IndicesToIndex(const IntT* indices,
const IntT* sparse_offsets,
const int64_t non_zero_num,
const int64_t sparse_dim,
const int i) {
IntT index = 0;
for (IntT j = 0; j < sparse_dim; j++) {
index += indices[j * non_zero_num + i] * sparse_offsets[j];
}
return index;
}
template <typename IntT>
inline void HOSTDEVICE FlattenIndices(const IntT* indices,
const IntT* sparse_offsets,
const int64_t non_zero_num,
const int64_t sparse_dim,
const int start,
const int stride,
IntT* out) {
for (int i = start; i < non_zero_num; i += stride) {
out[i] =
IndicesToIndex(indices, sparse_offsets, non_zero_num, sparse_dim, i);
}
}
// 1. indices.dims().size() == 2
template <typename IntT>
inline void CalcOffsetsPerDim(const DDim& dims,
const int64_t sparse_dim,
std::vector<IntT>* offsets) {
IntT offset = 1;
for (IntT i = sparse_dim - 1; i >= 0; i--) {
(*offsets)[i] = offset;
offset *= dims[i];
}
}
} // namespace sparse
} // namespace funcs
} // namespace phi
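Note: the helpers added above compute row-major strides over the leading sparse_dim dims (CalcOffsetsPerDim) and collapse each column of the [sparse_dim, non_zero_num] indices matrix into a single linear index (IndicesToIndex / FlattenIndices). Below is a minimal host-only sketch of the same scheme; it is not part of this patch, and the shape and indices are the illustrative values used in the Python test further below.

// Standalone host-only sketch (not part of this patch) of the flattening
// scheme: row-major strides over the sparse dims, then
// flat = sum_j indices[j][i] * stride[j].
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const int64_t sparse_dim = 2;
  const std::vector<int64_t> dims = {3, 4};  // dense shape of the sparse dims
  // indices stored as [sparse_dim, nnz], matching the SparseCooTensor layout
  const std::vector<int64_t> indices = {0, 0, 1, 2, 2,   // row coordinates
                                        1, 3, 2, 0, 1};  // col coordinates
  const int64_t nnz = 5;

  // Equivalent of CalcOffsetsPerDim: offsets = {4, 1} for shape {3, 4}.
  std::vector<int64_t> offsets(sparse_dim);
  int64_t offset = 1;
  for (int64_t j = sparse_dim - 1; j >= 0; j--) {
    offsets[j] = offset;
    offset *= dims[j];
  }

  // Equivalent of FlattenIndices with start = 0, stride = 1.
  for (int64_t i = 0; i < nnz; i++) {
    int64_t flat = 0;
    for (int64_t j = 0; j < sparse_dim; j++) {
      flat += indices[j * nnz + i] * offsets[j];
    }
    std::cout << flat << " ";  // prints: 1 3 6 8 9
  }
  std::cout << "\n";
  return 0;
}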
...@@ -19,6 +19,7 @@ limitations under the License. */
#include "paddle/phi/kernels/copy_kernel.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/sparse/common_shape.h"
#include "paddle/phi/api/ext/dispatch.h" #include "paddle/phi/api/ext/dispatch.h"
...@@ -38,12 +39,6 @@ void SparseMaskCPUKernel(const CPUContext& dev_ctx,
const DenseTensor& indices = mask.non_zero_indices();
const DenseTensor& values = mask.non_zero_elements();
int sparse_dim = indices.dims().size();
std::vector<int64_t> sparse_offsets(sparse_dim);
int64_t offset = 1;
for (int i = sparse_dim - 1; i >= 0; i--) {
sparse_offsets[i] = offset;
offset *= dims[i];
}
DenseTensor out_indices = phi::EmptyLike<T>(dev_ctx, indices);
DenseTensor out_values = phi::EmptyLike<T>(dev_ctx, values);
...@@ -51,21 +46,25 @@ void SparseMaskCPUKernel(const CPUContext& dev_ctx,
// the out_indices is same as indices of mask
phi::Copy(dev_ctx, indices, dev_ctx.GetPlace(), false, &out_indices);
const IntT* indices_ptr = indices.data<IntT>();
T* out_values_ptr = out_values.data<T>();
const T* x_ptr = x.data<T>();
const int64_t non_zero_num = mask.nnz();
auto dims_2d = flatten_to_2d(dims, sparse_dim);
const int cols = dims_2d[1];
const IntT* indices_ptr = indices.data<IntT>();
std::vector<IntT> out_indexs(non_zero_num), sparse_offsets(sparse_dim);
phi::funcs::sparse::CalcOffsetsPerDim<IntT>(
dims, sparse_dim, &sparse_offsets);
for (int64_t i = 0; i < non_zero_num; i++) {
int64_t index = 0;
for (int j = 0; j < sparse_dim; j++) {
index += indices_ptr[j * non_zero_num + i] * sparse_offsets[j];
}
int64_t index = phi::funcs::sparse::IndicesToIndex<IntT>(
indices_ptr, sparse_offsets.data(), non_zero_num, sparse_dim, i);
memcpy(out_values_ptr + i * cols, x_ptr + index * cols, cols * sizeof(T));
}
out->SetMember(out_indices, out_values, dims, true);
}
...@@ -85,6 +84,73 @@ void SparseMaskKernel(const Context& dev_ctx,
}));
}
template <typename T, typename IntT>
void SparseMaskHelperCPUKernel(const CPUContext& dev_ctx,
const SparseCooTensor& x,
const DenseTensor& mask_indices,
DenseTensor* out) {
PADDLE_ENFORCE_EQ(
mask_indices.dims().size(),
2,
phi::errors::InvalidArgument("the mask_indices must be 2-D tensor"));
const int64_t sparse_dim = x.non_zero_indices().dims()[0];
std::vector<IntT> sparse_offsets(sparse_dim), x_indexs(x.nnz()),
mask_indexs(mask_indices.dims()[1]);
phi::funcs::sparse::CalcOffsetsPerDim<IntT>(
x.dims(), sparse_dim, &sparse_offsets);
phi::funcs::sparse::FlattenIndices(x.non_zero_indices().data<IntT>(),
sparse_offsets.data(),
x.nnz(),
sparse_dim,
0,
1,
x_indexs.data());
phi::funcs::sparse::FlattenIndices(mask_indices.data<IntT>(),
sparse_offsets.data(),
x.nnz(),
sparse_dim,
0,
1,
mask_indexs.data());
std::unordered_map<IntT, uint64_t> x_indexs_map;
for (uint64_t i = 0; i < x_indexs.size(); i++) {
x_indexs_map[x_indexs[i]] = i;
}
*out = phi::EmptyLike<T>(dev_ctx, x.non_zero_elements());
T* out_ptr = out->data<T>();
memset(out_ptr, static_cast<T>(0), out->numel() * sizeof(T));
const int64_t stride =
x.dims().size() == sparse_dim ? 1 : x.dims().size() - sparse_dim;
const T* in_ptr = x.non_zero_elements().data<T>();
// TODO(zhangkaihuo): multithreading can be used for acceleration
for (uint64_t i = 0; i < mask_indexs.size(); i++) {
auto iter = x_indexs_map.find(mask_indexs[i]);
if (iter != x_indexs_map.end()) {
memcpy(out_ptr + i * stride,
in_ptr + iter->second * stride,
stride * sizeof(T));
}
}
}
/**
* @brief filter values from x.values() using mask_indices
*/
template <typename T, typename Context>
void SparseMaskHelperKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const DenseTensor& mask_indices,
DenseTensor* out) {
PD_DISPATCH_INTEGRAL_TYPES(
x.non_zero_indices().dtype(), "SparseMaskHelperCPUKernel", ([&] {
SparseMaskHelperCPUKernel<T, data_t>(dev_ctx, x, mask_indices, out);
}));
}
} // namespace sparse
} // namespace phi
...@@ -101,3 +167,16 @@ PD_REGISTER_KERNEL(sparse_mask,
int64_t) {
kernel->InputAt(1).SetDataLayout(phi::DataLayout::SPARSE_COO);
}
PD_REGISTER_KERNEL(sparse_mask_helper,
CPU,
ALL_LAYOUT,
phi::sparse::SparseMaskHelperKernel,
float,
double,
uint8_t,
int16_t,
int,
int64_t) {
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
}
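Note: SparseMaskHelperCPUKernel flattens both x's indices and mask_indices with the helpers above, builds an unordered_map from x's flattened index to its position, and copies the matching value row for each mask index, leaving zeros where x has no entry. Below is a host-only sketch of that gather, not part of this patch, assuming stride == 1 (one scalar per non-zero) and using illustrative data.

// Host-only sketch (not part of this patch) of the hash-map gather used by
// SparseMaskHelperCPUKernel, assuming stride == 1.
#include <cstdint>
#include <iostream>
#include <unordered_map>
#include <vector>

int main() {
  // Flattened indices of x's non-zeros and the values stored at them.
  std::vector<int64_t> x_indexs = {1, 3, 6, 8, 9};
  std::vector<float> x_values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f};
  // Flattened indices we want to keep (the mask).
  std::vector<int64_t> mask_indexs = {3, 5, 9};

  std::unordered_map<int64_t, uint64_t> x_indexs_map;
  for (uint64_t i = 0; i < x_indexs.size(); i++) {
    x_indexs_map[x_indexs[i]] = i;
  }

  // Output starts as zeros; only positions found in x are overwritten.
  std::vector<float> out(mask_indexs.size(), 0.0f);
  for (uint64_t i = 0; i < mask_indexs.size(); i++) {
    auto iter = x_indexs_map.find(mask_indexs[i]);
    if (iter != x_indexs_map.end()) {
      out[i] = x_values[iter->second];
    }
  }

  for (float v : out) std::cout << v << " ";  // prints: 2 0 5
  std::cout << "\n";
  return 0;
}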
...@@ -394,3 +394,15 @@ PD_REGISTER_KERNEL(csr_values,
int64_t) {
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
}
PD_REGISTER_KERNEL(sparse_coo_tensor,
CPU,
ALL_LAYOUT,
phi::sparse::SparseCooTensorKernel,
float,
double,
phi::dtype::float16,
uint8_t,
int16_t,
int,
int64_t) {}
...@@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <thrust/binary_search.h>
#include "paddle/phi/backends/gpu/gpu_info.h" #include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h" #include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/core/ddim.h" #include "paddle/phi/core/ddim.h"
...@@ -20,6 +22,7 @@ limitations under the License. */
#include "paddle/phi/kernels/copy_kernel.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/sparse/common_shape.h"
#include "paddle/phi/kernels/sparse/sparse_mask_kernel.h" #include "paddle/phi/kernels/sparse/sparse_mask_kernel.h"
#include "paddle/phi/api/ext/dispatch.h" #include "paddle/phi/api/ext/dispatch.h"
...@@ -59,7 +62,7 @@ void SparseMaskGPUKernel(const GPUContext& dev_ctx, ...@@ -59,7 +62,7 @@ void SparseMaskGPUKernel(const GPUContext& dev_ctx,
const DenseTensor& indices = mask.non_zero_indices();
const DenseTensor& values = mask.non_zero_elements();
int sparse_dim = indices.dims().size();
DenseTensor sparse_offsets = phi::Empty(
DenseTensor sparse_offsets = phi::Empty<GPUContext>(
dev_ctx,
DenseTensorMeta(DataType::INT64, {sparse_dim}, DataLayout::NCHW));
std::vector<int64_t> h_sparse_offsets(sparse_dim);
...@@ -121,6 +124,153 @@ void SparseMaskKernel(const Context& dev_ctx,
}));
}
// TODO(zhangkaihuo): Use an op to realize the function of FlattenIndices
template <typename IntT>
__global__ void FlattenIndicesKernel(const IntT* indices,
const IntT* sparse_offsets,
const int64_t non_zero_num,
const int64_t sparse_dim,
IntT* out) {
int tid = threadIdx.x + blockIdx.x * blockDim.x;
phi::funcs::sparse::FlattenIndices<IntT>(indices,
sparse_offsets,
non_zero_num,
sparse_dim,
tid,
gridDim.x * blockDim.x,
out);
}
template <typename T, typename IntT>
__global__ void SparseMaskCopyKernel(const IntT* x_indexs,
const IntT* mask_indexs,
const IntT* bound_out,
const T* x_values,
const int64_t n,
const int64_t stride,
T* out_values) {
CUDA_KERNEL_LOOP_TYPE(i, n, int64_t) {
const IntT j = bound_out[i];
if (j >= 0 && j < n && mask_indexs[i] == x_indexs[j]) {
for (int k = 0; k < stride; k++) {
out_values[i * stride + k] = x_values[j * stride + k];
}
}
}
}
template <typename T, typename IntT>
void SparseMaskHelperGPUKernel(const GPUContext& dev_ctx,
const SparseCooTensor& x,
const DenseTensor& mask_indices,
DenseTensor* out) {
PADDLE_ENFORCE_EQ(
mask_indices.dims().size(),
2,
phi::errors::InvalidArgument("the mask_indices must be 2-D tensor"));
const int64_t sparse_dim = x.non_zero_indices().dims()[0];
auto indices_dtype = paddle::experimental::CppTypeToDataType<IntT>::Type();
std::vector<IntT> sparse_offsets(sparse_dim);
DenseTensorMeta x_indexs_meta(indices_dtype, {x.nnz()}, DataLayout::NCHW);
DenseTensorMeta mask_indexs_meta(
indices_dtype, {mask_indices.dims()[1]}, DataLayout::NCHW);
DenseTensorMeta sparse_offset_meta(
indices_dtype, {sparse_dim}, DataLayout::NCHW);
DenseTensor x_indexs =
phi::Empty<GPUContext>(dev_ctx, std::move(x_indexs_meta));
DenseTensor mask_indexs =
phi::Empty<GPUContext>(dev_ctx, std::move(mask_indexs_meta));
DenseTensor bound_out =
phi::Empty<GPUContext>(dev_ctx, std::move(mask_indexs_meta));
DenseTensor d_sparse_offsets =
phi::Empty<GPUContext>(dev_ctx, std::move(sparse_offset_meta));
IntT* x_indexs_ptr = x_indexs.data<IntT>();
IntT* mask_indexs_ptr = mask_indexs.data<IntT>();
IntT* bound_out_ptr = bound_out.data<IntT>();
// 1. calc the offsets of per dim
phi::funcs::sparse::CalcOffsetsPerDim(x.dims(), sparse_dim, &sparse_offsets);
// 2. copy sparse_offsets to device
phi::backends::gpu::GpuMemcpyAsync(d_sparse_offsets.data<IntT>(),
sparse_offsets.data(),
sizeof(IntT) * sparse_dim,
#ifdef PADDLE_WITH_HIP
hipMemcpyHostToDevice,
#else
cudaMemcpyHostToDevice,
#endif
dev_ctx.stream());
// 3. flatten x indices and mask indices
auto config =
phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, x_indexs.numel(), 1);
FlattenIndicesKernel<<<config.block_per_grid,
config.thread_per_block,
0,
dev_ctx.stream()>>>(x.non_zero_indices().data<IntT>(),
d_sparse_offsets.data<IntT>(),
x_indexs.numel(),
sparse_dim,
x_indexs_ptr);
config =
phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, mask_indexs.numel(), 1);
FlattenIndicesKernel<<<config.block_per_grid,
config.thread_per_block,
0,
dev_ctx.stream()>>>(mask_indices.data<IntT>(),
d_sparse_offsets.data<IntT>(),
mask_indexs.numel(),
sparse_dim,
mask_indexs_ptr);
// 4. call thrust::lower_bound
#ifdef PADDLE_WITH_HIP
thrust::lower_bound(thrust::hip::par.on(dev_ctx.stream()),
#else
thrust::lower_bound(thrust::cuda::par.on(dev_ctx.stream()),
#endif
x_indexs_ptr,
x_indexs_ptr + x_indexs.numel(),
mask_indexs_ptr,
mask_indexs_ptr + mask_indexs.numel(),
bound_out_ptr);
// 5. copy value to out
*out = phi::EmptyLike<T>(dev_ctx, x.non_zero_elements());
phi::funcs::SetConstant<GPUContext, T> set_zero;
set_zero(dev_ctx, out, static_cast<T>(0));
T* out_ptr = out->data<T>();
const int64_t stride =
x.dims().size() == sparse_dim ? 1 : x.dims().size() - sparse_dim;
SparseMaskCopyKernel<<<config.block_per_grid,
config.thread_per_block,
0,
dev_ctx.stream()>>>(x_indexs_ptr,
mask_indexs_ptr,
bound_out_ptr,
x.non_zero_elements().data<T>(),
mask_indexs.numel(),
stride,
out_ptr);
}
template <typename T, typename Context>
void SparseMaskHelperKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const DenseTensor& mask_indices,
DenseTensor* out) {
PD_DISPATCH_INTEGRAL_TYPES(
x.non_zero_indices().dtype(), "SparseMaskHelperGPUKernel", ([&] {
SparseMaskHelperGPUKernel<T, data_t>(dev_ctx, x, mask_indices, out);
}));
}
} // namespace sparse
} // namespace phi
...@@ -138,3 +288,17 @@ PD_REGISTER_KERNEL(sparse_mask,
int64_t) {
kernel->InputAt(1).SetDataLayout(phi::DataLayout::SPARSE_COO);
}
PD_REGISTER_KERNEL(sparse_mask_helper,
GPU,
ALL_LAYOUT,
phi::sparse::SparseMaskHelperKernel,
float,
double,
phi::dtype::float16,
uint8_t,
int16_t,
int,
int64_t) {
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
}
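Note: the GPU path replaces the hash map with a binary search. Both index sets are flattened on the device, thrust::lower_bound locates each mask index inside x's flattened indices (assumed sorted, which holds for canonically ordered COO indices), and SparseMaskCopyKernel copies a row only when the located position holds an exactly equal index. Below is a host-side analogue using std::lower_bound; it is not part of this patch and assumes stride == 1, with illustrative data.

// Host-only analogue (not part of this patch) of the device-side
// thrust::lower_bound + SparseMaskCopyKernel matching step.
// Assumes x_indexs is sorted, as it is for canonically ordered COO indices.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  std::vector<int64_t> x_indexs = {1, 3, 6, 8, 9};  // sorted
  std::vector<float> x_values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f};
  std::vector<int64_t> mask_indexs = {3, 5, 9};

  const int64_t n = static_cast<int64_t>(mask_indexs.size());
  std::vector<float> out(n, 0.0f);
  for (int64_t i = 0; i < n; i++) {
    // Position of the first element in x_indexs that is >= mask_indexs[i].
    auto it =
        std::lower_bound(x_indexs.begin(), x_indexs.end(), mask_indexs[i]);
    int64_t j = it - x_indexs.begin();
    // Copy only on an exact match, mirroring SparseMaskCopyKernel.
    if (j >= 0 && j < static_cast<int64_t>(x_indexs.size()) &&
        x_indexs[j] == mask_indexs[i]) {
      out[i] = x_values[j];
    }
  }

  for (float v : out) std::cout << v << " ";  // prints: 2 0 5
  std::cout << "\n";
  return 0;
}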
...@@ -665,3 +665,15 @@ PD_REGISTER_KERNEL(csr_values,
int64_t) {
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
}
PD_REGISTER_KERNEL(sparse_coo_tensor,
GPU,
ALL_LAYOUT,
phi::sparse::SparseCooTensorKernel,
float,
double,
phi::dtype::float16,
uint8_t,
int16_t,
int,
int64_t) {}
...@@ -26,5 +26,11 @@ void SparseMaskKernel(const Context& dev_ctx,
const SparseCooTensor& mask,
SparseCooTensor* out);
template <typename T, typename Context>
void SparseMaskHelperKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const DenseTensor& mask_indices,
DenseTensor* out);
} // namespace sparse
} // namespace phi
...@@ -66,6 +66,19 @@ PD_REGISTER_KERNEL(sparse_coo_to_dense_grad,
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
}
PD_REGISTER_KERNEL(sparse_coo_tensor_grad,
CPU,
ALL_LAYOUT,
phi::sparse::SparseCooTensorGradKernel,
float,
double,
uint8_t,
int16_t,
int,
int64_t) {
kernel->InputAt(1).SetDataLayout(phi::DataLayout::SPARSE_COO);
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL(coo_values_grad,
GPU,
...@@ -95,4 +108,16 @@ PD_REGISTER_KERNEL(sparse_coo_to_dense_grad,
int64_t) {
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
}
PD_REGISTER_KERNEL(sparse_coo_tensor_grad,
GPU,
ALL_LAYOUT,
phi::sparse::SparseCooTensorGradKernel,
float,
double,
uint8_t,
int16_t,
int,
int64_t) {
kernel->InputAt(1).SetDataLayout(phi::DataLayout::SPARSE_COO);
}
#endif
...@@ -16,6 +16,7 @@ limitations under the License. */
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/sparse_coo_tensor.h"
#include "paddle/phi/kernels/sparse/sparse_mask_kernel.h"
namespace phi {
namespace sparse {
...@@ -32,5 +33,13 @@ void SparseCooToDenseGradKernel(const Context& dev_ctx,
const DenseTensor& out_grad,
SparseCooTensor* x_grad);
template <typename T, typename Context>
void SparseCooTensorGradKernel(const Context& dev_ctx,
const DenseTensor& indices,
const SparseCooTensor& out_grad,
DenseTensor* values_grad) {
SparseMaskHelperKernel<T, Context>(dev_ctx, out_grad, indices, values_grad);
}
} // namespace sparse
} // namespace phi
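Note: SparseCooTensorGradKernel reduces the backward of sparse_coo_tensor to SparseMaskHelperKernel: the gradient of values is out_grad sampled at the forward indices, with zeros where out_grad carries no entry. Below is a small hand-check, not part of this patch, using the data from test_sparse_coo_tensor_grad further below (expected values gradient [0, 3]).

// Hand-check (not part of this patch) of the sparse_coo_tensor_grad semantics
// on the data used by test_sparse_coo_tensor_grad: values_grad[i] is
// out_grad's value at the i-th forward coordinate, or 0 when out_grad has
// no value there.
#include <iostream>
#include <map>
#include <utility>
#include <vector>

int main() {
  // Forward tensor: indices [[0, 1], [0, 1]] -> coordinates (0,0), (1,1).
  std::vector<std::pair<int, int>> fwd_coords = {{0, 0}, {1, 1}};
  // out_grad: indices [[0, 1], [1, 1]] -> (0,1), (1,1) with values [2, 3].
  std::map<std::pair<int, int>, float> out_grad = {{{0, 1}, 2.0f},
                                                   {{1, 1}, 3.0f}};

  std::vector<float> values_grad;
  for (const auto& c : fwd_coords) {
    auto it = out_grad.find(c);
    values_grad.push_back(it == out_grad.end() ? 0.0f : it->second);
  }

  for (float v : values_grad) std::cout << v << " ";  // prints: 0 3
  std::cout << "\n";
  return 0;
}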
...@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once
#include "paddle/phi/api/lib/utils/storage.h"
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/sparse_coo_tensor.h" #include "paddle/phi/core/sparse_coo_tensor.h"
#include "paddle/phi/core/sparse_csr_tensor.h" #include "paddle/phi/core/sparse_csr_tensor.h"
...@@ -147,5 +148,16 @@ void CsrValuesKernel(const Context& dev_ctx, ...@@ -147,5 +148,16 @@ void CsrValuesKernel(const Context& dev_ctx,
*out = x.non_zero_elements(); *out = x.non_zero_elements();
} }
template <typename T, typename Context>
void SparseCooTensorKernel(const Context& dev_ctx,
const DenseTensor& values,
const DenseTensor& indices,
const IntArray& dense_shape,
SparseCooTensor* out) {
*out =
SparseCooTensor(indices, values, phi::make_ddim(dense_shape.GetData()));
// TODO(zhangkaihuo): sort and merge the duplicate indices
}
} // namespace sparse
} // namespace phi
...@@ -134,9 +134,11 @@ class TestSparseConvert(unittest.TestCase):
#test to_sparse_coo_grad backward
out_grad_indices = [[0, 1], [0, 1]]
out_grad_values = [2.0, 3.0]
out_grad = core.eager.sparse_coo_tensor(
paddle.to_tensor(out_grad_indices),
paddle.to_tensor(out_grad_values), out.shape, True)
out_grad = paddle.sparse.sparse_coo_tensor(
paddle.to_tensor(out_grad_indices),
paddle.to_tensor(out_grad_values),
shape=out.shape,
stop_gradient=True)
out.backward(out_grad)
assert np.array_equal(dense_x.grad.numpy(),
out_grad.to_dense().numpy())
...@@ -145,9 +147,11 @@ class TestSparseConvert(unittest.TestCase):
with _test_eager_guard():
indices = [[0, 0, 1, 2, 2], [1, 3, 2, 0, 1]]
values = [1.0, 2.0, 3.0, 4.0, 5.0]
sparse_x = core.eager.sparse_coo_tensor(
paddle.to_tensor(indices),
paddle.to_tensor(values), [3, 4], False)
sparse_x = paddle.sparse.sparse_coo_tensor(
paddle.to_tensor(indices),
paddle.to_tensor(values),
shape=[3, 4],
stop_gradient=False)
dense_tensor = sparse_x.to_dense()
#test to_dense_grad backward
out_grad = [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0],
...@@ -158,6 +162,17 @@ class TestSparseConvert(unittest.TestCase):
assert np.array_equal(correct_x_grad,
sparse_x.grad.values().numpy())
paddle.device.set_device("cpu")
sparse_x_cpu = paddle.sparse.sparse_coo_tensor(
paddle.to_tensor(indices),
paddle.to_tensor(values),
shape=[3, 4],
stop_gradient=False)
dense_tensor_cpu = sparse_x_cpu.to_dense()
dense_tensor_cpu.backward(paddle.to_tensor(out_grad))
assert np.array_equal(correct_x_grad,
sparse_x_cpu.grad.values().numpy())
def test_to_sparse_csr(self):
with _test_eager_guard():
x = [[0, 1, 0, 2], [0, 0, 3, 0], [4, 5, 0, 0]]
...@@ -177,15 +192,52 @@ class TestSparseConvert(unittest.TestCase):
with _test_eager_guard():
indices = [[0, 0, 1, 2, 2], [1, 3, 2, 0, 1]]
values = [1.0, 2.0, 3.0, 4.0, 5.0]
sparse_x = core.eager.sparse_coo_tensor(
paddle.to_tensor(indices),
paddle.to_tensor(values), [3, 4], False)
sparse_x = paddle.sparse.sparse_coo_tensor(
paddle.to_tensor(indices),
paddle.to_tensor(values),
shape=[3, 4],
stop_gradient=False)
values_tensor = sparse_x.values()
out_grad = [2.0, 3.0, 5.0, 8.0, 9.0]
# test coo_values_grad
values_tensor.backward(paddle.to_tensor(out_grad))
assert np.array_equal(out_grad, sparse_x.grad.values().numpy())
def test_sparse_coo_tensor_grad(self):
with _test_eager_guard():
indices = [[0, 1], [0, 1]]
values = [1, 2]
indices = paddle.to_tensor(indices, dtype='int32')
values = paddle.to_tensor(
values, dtype='float32', stop_gradient=False)
sparse_x = paddle.sparse.sparse_coo_tensor(
indices, values, shape=[2, 2], stop_gradient=False)
grad_indices = [[0, 1], [1, 1]]
grad_values = [2, 3]
grad_indices = paddle.to_tensor(grad_indices, dtype='int32')
grad_values = paddle.to_tensor(grad_values, dtype='float32')
sparse_out_grad = paddle.sparse.sparse_coo_tensor(
grad_indices, grad_values, shape=[2, 2])
sparse_x.backward(sparse_out_grad)
correct_values_grad = [0, 3]
assert np.array_equal(correct_values_grad, values.grad.numpy())
place = core.CPUPlace()
indices_cpu = paddle.to_tensor(indices, dtype='int32', place=place)
values_cpu = paddle.to_tensor(
values, dtype='float32', place=place, stop_gradient=False)
sparse_x_cpu = paddle.sparse.sparse_coo_tensor(
indices_cpu,
values_cpu,
shape=[2, 2],
place=place,
stop_gradient=False)
sparse_out_grad_cpu = paddle.sparse.sparse_coo_tensor(
grad_indices, grad_values, shape=[2, 2], place=place)
sparse_x_cpu.backward(sparse_out_grad_cpu)
assert np.array_equal(correct_values_grad, values_cpu.grad.numpy())
if __name__ == "__main__":
unittest.main()
...@@ -14,6 +14,7 @@
from paddle import _C_ops
from ..framework import core, dygraph_only
from ..framework import _current_expected_place, _get_paddle_place
from ..tensor import to_tensor
from ..tensor import max
from ..fluid.data_feeder import check_variable_and_dtype, check_type, check_dtype, convert_dtype
...@@ -38,6 +39,18 @@ def _infer_dense_shape(indices):
return list(lens.numpy())
def _get_place(place):
place = _get_paddle_place(place)
if place is None:
place = _current_expected_place()
elif not isinstance(place, (core.Place, core.CPUPlace, core.CUDAPinnedPlace,
core.CUDAPlace)):
raise ValueError(
"'place' must be any of paddle.Place, paddle.CPUPlace, paddle.CUDAPinnedPlace, paddle.CUDAPlace"
)
return place
@dygraph_only
def sparse_coo_tensor(indices,
values,
...@@ -94,6 +107,8 @@ def sparse_coo_tensor(indices,
# values=[1., 2., 3.])
"""
place = _get_place(place)
if not isinstance(indices, core.eager.Tensor):
indices = to_tensor(
indices, dtype=None, place=place, stop_gradient=True)
...@@ -101,13 +116,20 @@ def sparse_coo_tensor(indices,
values = to_tensor(values, dtype, place, stop_gradient)
if len(indices.shape) != 2:
raise ValueError("'indices' must be 2-D.")
if place is not None:
if not indices.place._equals(place):
indices = indices._copy_to(place, False)
if not values.place._equals(place):
values = values._copy_to(place, False)
values = _handle_dtype(values, dtype)
values.stop_gradient = stop_gradient
if shape is None:
shape = _infer_dense_shape(indices)
return core.eager.sparse_coo_tensor(indices, values, shape, stop_gradient)
return _C_ops.final_state_sparse_create_sparse_coo_tensor(values, indices,
shape)
#TODO: need to support shape is None
...@@ -171,6 +193,9 @@ def sparse_csr_tensor(crows,
# cols=[1, 3, 2, 0, 1],
# values=[1, 2, 3, 4, 5])
"""
place = _get_place(place)
if not isinstance(crows, core.eager.Tensor):
crows = to_tensor(crows, dtype=None, place=place, stop_gradient=True)
if not isinstance(cols, core.eager.Tensor):
...@@ -182,10 +207,15 @@ def sparse_csr_tensor(crows,
"SparseCsrTensor only support 2-D or 3-D matrix. The 'crows', 'cols' and 'values' must be 1-D."
)
if place is not None:
if not crows.place._equals(place):
crows = crows._copy_to(place, False)
if not cols.place._equals(place):
cols = cols._copy_to(place, False)
if not values.place._equals(place):
values = values._copy_to(place, False)
values = _handle_dtype(values, dtype)
values.stop_gradient = stop_gradient
return core.eager.sparse_csr_tensor(crows, cols, values, shape,
stop_gradient)
...@@ -21,6 +21,14 @@
layout : x
backward : coo_values_grad
- api : create_sparse_coo_tensor
args : (Tensor values, Tensor indices, IntArray dense_shape)
output : Tensor(out@SparseCooTensor)
kernel :
func : sparse_coo_tensor
layout : values
backward : create_sparse_coo_tensor_grad
- api : csr_values
args : (Tensor x)
output : Tensor(out@DenseTensor)
...
...@@ -19,6 +19,13 @@
kernel :
func : coo_values_grad
- backward_api : create_sparse_coo_tensor_grad
forward : create_sparse_coo_tensor(Tensor values, Tensor indices, IntArray dense_shape) -> Tensor(out@SparseCooTensor)
args : (Tensor indices, Tensor out_grad)
output : Tensor(values_grad@DenseTensor)
kernel :
func : sparse_coo_tensor_grad
- backward_api : dense_to_coo_grad
forward : dense_to_coo(Tensor x, int64_t sparse_dim) -> Tensor(out@SparseCooTensor)
args : (Tensor out_grad)
...