diff --git a/paddle/phi/kernels/funcs/sparse/flatten_indices.cu.h b/paddle/phi/kernels/funcs/sparse/flatten_indices.cu.h new file mode 100644 index 0000000000000000000000000000000000000000..26b8549aaafdc76767f18357a3cc420f7d646864 --- /dev/null +++ b/paddle/phi/kernels/funcs/sparse/flatten_indices.cu.h @@ -0,0 +1,57 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/phi/kernels/funcs/sparse/flatten_indices.h" + +namespace phi { +namespace funcs { +namespace sparse { + +template +__global__ void FlattenIndicesKernel(const IntT* indices, + const IntT* sparse_offsets, + const int64_t non_zero_num, + const int64_t sparse_dim, + IntT* out) { + int tid = threadIdx.x + blockIdx.x * blockDim.x; + phi::funcs::sparse::FlattenIndices(indices, + sparse_offsets, + non_zero_num, + sparse_dim, + tid, + gridDim.x * blockDim.x, + out); +} + +template +__global__ void IndexToCoordinateKernel(const IntT* indexs, + const Dim dims, + const int64_t non_zero_num, + const int64_t sparse_dim, + IntT* indices) { + int tid = threadIdx.x + blockIdx.x * blockDim.x; + IndexToCoordinate(indexs, + dims, + non_zero_num, + sparse_dim, + tid, + gridDim.x * blockDim.x, + indices); +} + +} // namespace sparse +} // namespace funcs +} // namespace phi diff --git a/paddle/phi/kernels/funcs/sparse/flatten_indices.h b/paddle/phi/kernels/funcs/sparse/flatten_indices.h new file mode 100644 index 0000000000000000000000000000000000000000..ca212e4366ec439ceeef22cc8de48790ee0d6c67 --- /dev/null +++ b/paddle/phi/kernels/funcs/sparse/flatten_indices.h @@ -0,0 +1,93 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once + +#include +#include "paddle/phi/core/ddim.h" + +namespace phi { +namespace funcs { +namespace sparse { + +template +inline const IntT HOSTDEVICE CoordinateToIndex(const IntT* indices, + const IntT* sparse_offsets, + const int64_t non_zero_num, + const int64_t sparse_dim, + const int i) { + IntT index = 0; + for (IntT j = 0; j < sparse_dim; j++) { + index += indices[j * non_zero_num + i] * sparse_offsets[j]; + } + return index; +} + +template +inline void HOSTDEVICE FlattenIndices(const IntT* indices, + const IntT* sparse_offsets, + const int64_t non_zero_num, + const int64_t sparse_dim, + const int64_t start, + const int64_t stride, + IntT* out) { + for (int64_t i = start; i < non_zero_num; i += stride) { + out[i] = + CoordinateToIndex(indices, sparse_offsets, non_zero_num, sparse_dim, i); + } +} + +// 1. indices.dims().size() == 2 +template +inline void CalcOffsetsPerDim(const DDim& dims, + const int64_t sparse_dim, + IntT* offsets) { + IntT offset = 1; + for (IntT i = sparse_dim - 1; i >= 0; i--) { + offsets[i] = offset; + offset *= dims[i]; + } +} + +template +inline void HOSTDEVICE IndexToCoordinate(const IntT index, + const Dim& dims, + const int64_t non_zero_num, + const int64_t sparse_dim, + const int indices_offset, + IntT* indices) { + IntT tmp_index = index; + for (int j = sparse_dim - 1; j >= 0; j--) { + indices[j * non_zero_num + indices_offset] = tmp_index % dims[j]; + tmp_index /= dims[j]; + } +} + +template +inline void HOSTDEVICE IndexToCoordinate(const IntT* indexs, + const Dim& dims, + const int64_t non_zero_num, + const int64_t sparse_dim, + const int64_t start, + const int64_t stride, + IntT* indices) { + for (int64_t i = start; i < non_zero_num; i += stride) { + IntT tmp_index = indexs[i]; + IndexToCoordinate(tmp_index, dims, non_zero_num, sparse_dim, i, indices); + } +} + +} // namespace sparse +} // namespace funcs +} // namespace phi diff --git a/paddle/phi/kernels/funcs/sparse/scatter.cu.h b/paddle/phi/kernels/funcs/sparse/scatter.cu.h new file mode 100644 index 0000000000000000000000000000000000000000..9ed7cef12a148cc24ab925c1767c4fa92321879c --- /dev/null +++ b/paddle/phi/kernels/funcs/sparse/scatter.cu.h @@ -0,0 +1,63 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/
+#pragma once
+
+namespace phi {
+namespace funcs {
+namespace sparse {
+
+/**
+ * brief: scatter add
+ * input: the input features to be scattered
+ * unique_value: refer to UpdateIndexKernel notes
+ * out_index: the output feature index
+ * non_zero_num: the number of output features
+ * rulebook_len: the length of the rulebook
+ * channels: the output channel size
+ * out: the output features
+**/
+template <typename T>
+__global__ void ScatterKernel(const T* input,
+                              const int* unique_value,
+                              const int* out_index,
+                              const int non_zero_num,
+                              const int rulebook_len,
+                              const int channels,
+                              T* out,
+                              const bool subm = false) {
+  int tid = threadIdx.x + blockIdx.x * blockDim.x;
+  for (int i = tid; i < non_zero_num * channels; i += gridDim.x * blockDim.x) {
+    int indices_i = i / channels;
+    int channels_i = i - indices_i * channels;
+
+    int start = unique_value[indices_i];
+    int end = indices_i == non_zero_num - 1 ? rulebook_len
+                                            : unique_value[indices_i + 1];
+    // max(end-start) = kernel_size
+    T sum = static_cast<T>(0);
+    if (subm) {
+      sum = out[indices_i * channels + channels_i];
+    }
+    for (int j = start; j < end; j++) {
+      const int out_feature_i = out_index[j];
+      sum += input[out_feature_i * channels + channels_i];
+    }
+    out[indices_i * channels + channels_i] = sum;
+  }
+}
+
+}  // namespace sparse
+}  // namespace funcs
+}  // namespace phi
diff --git a/paddle/phi/kernels/funcs/sparse/utils.cu.h b/paddle/phi/kernels/funcs/sparse/utils.cu.h
new file mode 100644
index 0000000000000000000000000000000000000000..074fe1ca420497689cf7d6942bfe9c2709e5b191
--- /dev/null
+++ b/paddle/phi/kernels/funcs/sparse/utils.cu.h
@@ -0,0 +1,31 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+namespace phi {
+namespace funcs {
+namespace sparse {
+
+// brief: calculate the distance between start and end
+template <typename T>
+__global__ void DistanceKernel(const T* start, const T* end, T* distance) {
+  if (threadIdx.x == 0) {
+    *distance = end - start;
+  }
+}
+
+}  // namespace sparse
+}  // namespace funcs
+}  // namespace phi
diff --git a/paddle/phi/kernels/sparse/coalesced_kernel.h b/paddle/phi/kernels/sparse/coalesced_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..0755579a57ade255c98396d862dbe76283cb29a5
--- /dev/null
+++ b/paddle/phi/kernels/sparse/coalesced_kernel.h
@@ -0,0 +1,30 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
*/ + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/sparse_coo_tensor.h" +#include "paddle/phi/kernels/empty_kernel.h" + +namespace phi { +namespace sparse { + +template +void CoalescedKernel(const Context& dev_ctx, + const SparseCooTensor& x, + SparseCooTensor* out); + +} // namespace sparse +} // namespace phi diff --git a/paddle/phi/kernels/sparse/cpu/coalesced_kernel.cc b/paddle/phi/kernels/sparse/cpu/coalesced_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..0ebddf9b683f066904538db3dba812c23ea83fd0 --- /dev/null +++ b/paddle/phi/kernels/sparse/cpu/coalesced_kernel.cc @@ -0,0 +1,121 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/kernels/sparse/coalesced_kernel.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/visit_type.h" +#include "paddle/phi/kernels/funcs/sparse/flatten_indices.h" + +namespace phi { +namespace sparse { + +template +void CoalescedCPUKernel(const CPUContext& dev_ctx, + const SparseCooTensor& x, + SparseCooTensor* out) { + const DenseTensor& x_indices = x.non_zero_indices(); + const DenseTensor& x_values = x.non_zero_elements(); + DenseTensor out_indices = phi::EmptyLike(dev_ctx, x_indices); + DenseTensor out_values = phi::EmptyLike(dev_ctx, x_values); + + const int64_t sparse_dim = x.non_zero_indices().dims()[0]; + std::vector sparse_offsets(sparse_dim), x_indexs(x.nnz()); + phi::funcs::sparse::CalcOffsetsPerDim( + x.dims(), sparse_dim, sparse_offsets.data()); + + phi::funcs::sparse::FlattenIndices(x.non_zero_indices().data(), + sparse_offsets.data(), + x.nnz(), + sparse_dim, + 0, + 1, + x_indexs.data()); + + const T* x_values_ptr = x_values.data(); + const int64_t stride = + x.dims().size() == sparse_dim ? 
1 : x.dims().size() - sparse_dim; + + std::map> indices_to_index; + for (uint64_t i = 0; i < x_indexs.size(); i++) { + IntT index = x_indexs[i]; + if (indices_to_index.find(index) == indices_to_index.end()) { + std::vector indexs; + indexs.push_back(i); + indices_to_index[index] = indexs; + } else { + indices_to_index[index].push_back(i); + } + } + + const int64_t out_nnz = indices_to_index.size(); + + out_indices.Resize({x_indices.dims()[0], out_nnz}); + if (out_values.dims().size() == 1) { + out_values.Resize(phi::make_ddim({out_nnz})); + } else { + out_values.Resize(phi::make_ddim({out_nnz, x_values.dims()[1]})); + } + + IntT* out_indices_ptr = out_indices.data(); + T* out_values_ptr = out_values.data(); + auto iter = indices_to_index.begin(); + + Dim const_dims; + for (int i = 0; i < x.dims().size(); i++) { + const_dims[i] = x.dims()[i]; + } + + for (int i = 0; iter != indices_to_index.end(); iter++, i++) { + phi::funcs::sparse::IndexToCoordinate( + iter->first, const_dims, out_nnz, sparse_dim, i, out_indices_ptr); + memcpy(out_values_ptr + i * stride, + x_values_ptr + iter->second[0] * stride, + stride * sizeof(T)); + for (uint64_t j = 1; j < iter->second.size(); j++) { + for (int k = 0; k < stride; k++) { + out_values_ptr[i * stride + k] += + x_values_ptr[iter->second[j] * stride + k]; + } + } + } + + out->SetMember(out_indices, out_values, x.dims(), true); +} + +template +void CoalescedKernel(const Context& dev_ctx, + const SparseCooTensor& x, + SparseCooTensor* out) { + PD_VISIT_INTEGRAL_TYPES( + x.non_zero_indices().dtype(), "CoalescedCPUKernel", ([&] { + CoalescedCPUKernel(dev_ctx, x, out); + })); +} + +} // namespace sparse +} // namespace phi + +PD_REGISTER_KERNEL(sort, + CPU, + ALL_LAYOUT, + phi::sparse::CoalescedKernel, + float, + double, + phi::dtype::float16, + uint8_t, + int16_t, + int, + int64_t) { + kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO); +} diff --git a/paddle/phi/kernels/sparse/cpu/sparse_mask_kernel.cc b/paddle/phi/kernels/sparse/cpu/sparse_mask_kernel.cc index c10a240c684302d6d76d86935816794185f1c5e6..1508de407caa7ea62aff6e456cb20b2bb7599601 100644 --- a/paddle/phi/kernels/sparse/cpu/sparse_mask_kernel.cc +++ b/paddle/phi/kernels/sparse/cpu/sparse_mask_kernel.cc @@ -20,7 +20,9 @@ limitations under the License. 
*/ #include "paddle/phi/kernels/copy_kernel.h" #include "paddle/phi/kernels/empty_kernel.h" #include "paddle/phi/kernels/funcs/math_function.h" -#include "paddle/phi/kernels/funcs/sparse/common_shape.h" +#include "paddle/phi/kernels/funcs/sparse/flatten_indices.h" + +#include "paddle/phi/api/ext/dispatch.h" namespace phi { namespace sparse { @@ -56,10 +58,10 @@ void SparseMaskCPUKernel(const CPUContext& dev_ctx, std::vector out_indexs(non_zero_num), sparse_offsets(sparse_dim); phi::funcs::sparse::CalcOffsetsPerDim( - dims, sparse_dim, &sparse_offsets); + dims, sparse_dim, sparse_offsets.data()); for (int64_t i = 0; i < non_zero_num; i++) { - int64_t index = phi::funcs::sparse::IndicesToIndex( + int64_t index = phi::funcs::sparse::CoordinateToIndex( indices_ptr, sparse_offsets.data(), non_zero_num, sparse_dim, i); memcpy(out_values_ptr + i * cols, x_ptr + index * cols, cols * sizeof(T)); } @@ -98,7 +100,7 @@ void SparseMaskHelperCPUKernel(const CPUContext& dev_ctx, std::vector sparse_offsets(sparse_dim), x_indexs(x.nnz()), mask_indexs(mask_indices.dims()[1]); phi::funcs::sparse::CalcOffsetsPerDim( - x.dims(), sparse_dim, &sparse_offsets); + x.dims(), sparse_dim, sparse_offsets.data()); phi::funcs::sparse::FlattenIndices(x.non_zero_indices().data(), sparse_offsets.data(), diff --git a/paddle/phi/kernels/sparse/gpu/coalesced_kernel.cu b/paddle/phi/kernels/sparse/gpu/coalesced_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..3ffcd28955a532feef303d1e9b47eaf40f0dfae4 --- /dev/null +++ b/paddle/phi/kernels/sparse/gpu/coalesced_kernel.cu @@ -0,0 +1,189 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/phi/backends/gpu/gpu_info.h" +#include "paddle/phi/backends/gpu/gpu_launch_config.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/visit_type.h" +#include "paddle/phi/kernels/funcs/index_impl.cu.h" +#include "paddle/phi/kernels/funcs/sparse/flatten_indices.cu.h" +#include "paddle/phi/kernels/funcs/sparse/scatter.cu.h" +#include "paddle/phi/kernels/funcs/sparse/utils.cu.h" +#include "paddle/phi/kernels/sparse/coalesced_kernel.h" + +namespace phi { +namespace sparse { + +template +void CoalescedGPUKernel(const GPUContext& dev_ctx, + const SparseCooTensor& x, + SparseCooTensor* out) { + const DenseTensor& x_indices = x.non_zero_indices(); + const DenseTensor& x_values = x.non_zero_elements(); + DenseTensor out_indices = phi::EmptyLike(dev_ctx, x_indices); + DenseTensor out_values = phi::EmptyLike(dev_ctx, x_values); + + const int64_t nnz = x.nnz(); + const int64_t sparse_dim = x.non_zero_indices().dims()[0]; + std::vector sparse_offsets(sparse_dim); + + phi::funcs::sparse::CalcOffsetsPerDim( + x.dims(), sparse_dim, sparse_offsets.data()); + + DenseTensorMeta sparse_offset_meta( + paddle::experimental::CppTypeToDataType::Type(), + {sparse_dim}, + DataLayout::NCHW); + DenseTensor d_sparse_offsets = + phi::Empty(dev_ctx, std::move(sparse_offset_meta)); + DenseTensor indexs = phi::Empty( + dev_ctx, DenseTensorMeta(x_indices.dtype(), {nnz}, x_indices.layout())); + IntT* indexs_ptr = indexs.data(); + + phi::backends::gpu::GpuMemcpyAsync(d_sparse_offsets.data(), + sparse_offsets.data(), + sizeof(IntT) * sparse_dim, +#ifdef PADDLE_WITH_HIP + hipMemcpyHostToDevice, +#else + cudaMemcpyHostToDevice, +#endif + dev_ctx.stream()); + + // 1. flatten indices + auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, nnz, 1); + phi::funcs::sparse::FlattenIndicesKernel<<>>( + x.non_zero_indices().data(), + d_sparse_offsets.data(), + indexs.numel(), + sparse_dim, + indexs_ptr); + + // 2. get the address of each non-zero values + const T* x_values_ptr = x_values.data(); + const int64_t stride = + x.dims().size() == sparse_dim ? 1 : x.dims().size() - sparse_dim; + DenseTensor values_indexs = phi::Empty( + dev_ctx, DenseTensorMeta(DataType::INT32, {nnz}, DataLayout::NCHW)); + int* values_indexs_ptr = values_indexs.data(); + DenseTensor public_indexs = phi::EmptyLike(dev_ctx, values_indexs); + + // values_indexs = [0,1,2,,,nnz-1] + phi::IndexKernel>( + dev_ctx, &values_indexs, kps::IdentityFunctor()); + phi::IndexKernel>( + dev_ctx, &public_indexs, kps::IdentityFunctor()); + +// 3. sort (indices, values index) +#ifdef PADDLE_WITH_HIP + thrust::sort_by_key(thrust::hip::par.on(dev_ctx.stream()), +#else + thrust::sort_by_key(thrust::cuda::par.on(dev_ctx.stream()), +#endif + indexs_ptr, + indexs_ptr + nnz, + values_indexs_ptr); + + // 4. 
unique index + thrust::pair new_end = +#ifdef PADDLE_WITH_HIP + thrust::unique_by_key(thrust::hip::par.on(dev_ctx.stream()), +#else + thrust::unique_by_key(thrust::cuda::par.on(dev_ctx.stream()), +#endif + indexs_ptr, + indexs_ptr + nnz, + public_indexs.data()); + + phi::funcs::sparse::DistanceKernel<<<1, 1, 0, dev_ctx.stream()>>>( + indexs_ptr, new_end.first, out_indices.data()); + + IntT out_nnz = 0; + phi::backends::gpu::GpuMemcpyAsync(&out_nnz, + out_indices.data(), + sizeof(IntT), +#ifdef PADDLE_WITH_HIP + hipMemcpyDeviceToHost, +#else + cudaMemcpyDeviceToHost, +#endif + dev_ctx.stream()); + dev_ctx.Wait(); + + out_indices.Resize({x_indices.dims()[0], out_nnz}); + if (out_values.dims().size() == 1) { + out_values.Resize(phi::make_ddim({out_nnz})); + } else { + out_values.Resize(phi::make_ddim({out_nnz, x_values.dims()[1]})); + } + + // 5. scatter the values + config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, nnz * stride, 1); + phi::funcs::sparse::ScatterKernel<<>>( + x_values_ptr, + public_indexs.data(), + values_indexs_ptr, + out_nnz, + nnz, + stride, + out_values.data()); + + // 6. convert index to coordinate + Dim const_dims; + for (int i = 0; i < x.dims().size(); i++) { + const_dims[i] = x.dims()[i]; + } + + config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, out_nnz, 1); + phi::funcs::sparse::IndexToCoordinateKernel<<>>( + indexs_ptr, const_dims, out_nnz, sparse_dim, out_indices.data()); + + out->SetMember(out_indices, out_values, x.dims(), true); +} + +template +void CoalescedKernel(const Context& dev_ctx, + const SparseCooTensor& x, + SparseCooTensor* out) { + PD_VISIT_INTEGRAL_TYPES( + x.non_zero_indices().dtype(), "CoalescedGPUKernel", ([&] { + CoalescedGPUKernel(dev_ctx, x, out); + })); +} + +} // namespace sparse +} // namespace phi + +PD_REGISTER_KERNEL(sort, + GPU, + ALL_LAYOUT, + phi::sparse::CoalescedKernel, + float, + double, + phi::dtype::float16, + uint8_t, + int16_t, + int, + int64_t) { + kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO); +} diff --git a/paddle/phi/kernels/sparse/gpu/convolution.cu.h b/paddle/phi/kernels/sparse/gpu/convolution.cu.h index 2396a5975de4e85c932b43b40a51cf6b03427aa5..fcbb3c60183eb53a3ab3933a8907717149da9393 100644 --- a/paddle/phi/kernels/sparse/gpu/convolution.cu.h +++ b/paddle/phi/kernels/sparse/gpu/convolution.cu.h @@ -26,6 +26,7 @@ limitations under the License. 
*/ #include "paddle/phi/kernels/copy_kernel.h" #include "paddle/phi/kernels/funcs/index_impl.cu.h" #include "paddle/phi/kernels/funcs/math_function.h" +#include "paddle/phi/kernels/funcs/sparse/utils.cu.h" #include "paddle/phi/kernels/primitive/compute_primitives.h" #include "paddle/phi/kernels/sparse/convolution_kernel.h" @@ -60,46 +61,6 @@ __global__ void GatherKernel(const T* params, } } -/** - * brief: scatter add - * input: the inputs - * unique_value: refer to UpdateIndexKernel notes - * out_index: the output feature index - * non_zero_num: the number of output features - * rulebook_len: the length of rulebook - * channels: the output channel size - * out: the outputs -**/ -template -__global__ void ScatterKernel(const T* input, - const int* unique_value, - const int* out_index, - const int non_zero_num, - const int rulebook_len, - const int channels, - T* out, - const bool subm = false) { - int tid = threadIdx.x + blockIdx.x * blockDim.x; - for (int i = tid; i < non_zero_num * channels; i += gridDim.x * blockDim.x) { - int indices_i = i / channels; - int channels_i = i - indices_i * channels; - - int start = unique_value[indices_i]; - int end = indices_i == non_zero_num - 1 ? rulebook_len - : unique_value[indices_i + 1]; - // max(end-start) = kernel_size - T sum = static_cast(0); - if (subm) { - sum = out[indices_i * channels + channels_i]; - } - for (int j = start; j < end; j++) { - const int out_feature_i = out_index[j]; - sum += input[out_feature_i * channels + channels_i]; - } - out[indices_i * channels + channels_i] = sum; - } -} - template inline IntT* SortedAndUniqueIndex(const Context& dev_ctx, const IntT* rulebook_ptr, @@ -186,14 +147,6 @@ __global__ void UpdateIndexKernel(const T* unique_keys, } } -// brief: calculation the distance between start and end -template -__global__ void DistanceKernel(const T* start, const T* end, T* distance) { - if (threadIdx.x == 0) { - *distance = end - start; - } -} - template __global__ void UpdateOutIndexAndCounterAfterLowerBound( const IntT* x_indexs, @@ -402,7 +355,7 @@ int ProductRuleBook(const Context& dev_ctx, rulebook_ptr + rulebook_rows * rulebook_cols, -1); - DistanceKernel<<<1, 1, 0, dev_ctx.stream()>>>( + phi::funcs::sparse::DistanceKernel<<<1, 1, 0, dev_ctx.stream()>>>( rulebook_ptr, last, rulebook_ptr + 3 * kernel_size * non_zero_num - 1); IntT rulebook_len = 0; phi::backends::gpu::GpuMemcpyAsync( @@ -468,7 +421,7 @@ int ProductRuleBook(const Context& dev_ctx, rulebook_ptr, rulebook_ptr + 3 * rulebook_len, -1); - DistanceKernel<<<1, 1, 0, dev_ctx.stream()>>>( + phi::funcs::sparse::DistanceKernel<<<1, 1, 0, dev_ctx.stream()>>>( rulebook_ptr, last, bound_ptr); phi::backends::gpu::GpuMemcpyAsync(&rulebook_len, bound_ptr, @@ -536,7 +489,7 @@ int ProductRuleBook(const Context& dev_ctx, // thrust::distance doesn't support stream parameters // const int out_non_zero_num = thrust::distance(unique_key_ptr, // new_end.first); - DistanceKernel<<<1, 1, 0, dev_ctx.stream()>>>( + phi::funcs::sparse::DistanceKernel<<<1, 1, 0, dev_ctx.stream()>>>( unique_key_ptr, new_end, rulebook_ptr + rulebook_rows * rulebook_cols - 1); diff --git a/paddle/phi/kernels/sparse/gpu/convolution_grad_kernel.cu b/paddle/phi/kernels/sparse/gpu/convolution_grad_kernel.cu index ed9579fcd5b672e37c6aa6cbbb1a5c5835a342fa..e54e39f5541d55a0613554ddfad9cb8a26ca23ff 100644 --- a/paddle/phi/kernels/sparse/gpu/convolution_grad_kernel.cu +++ b/paddle/phi/kernels/sparse/gpu/convolution_grad_kernel.cu @@ -22,6 +22,7 @@ limitations under the License. 
*/ #include "paddle/phi/kernels/copy_kernel.h" #include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/math_function.h" +#include "paddle/phi/kernels/funcs/sparse/scatter.cu.h" #include "paddle/phi/kernels/sparse/convolution_grad_kernel.h" #include "paddle/phi/kernels/sparse/gpu/convolution.cu.h" @@ -222,17 +223,18 @@ void Conv3dGradGPUKernel(const GPUContext& dev_ctx, config = phi::backends::gpu::GetGpuLaunchConfig1D( dev_ctx, rulebook_len * in_channels, 1); - ScatterKernel<<>>(d_x_features_ptr, - unique_value.data(), - out_index.data(), - x.nnz(), - rulebook_len, - in_channels, - x_grad_values_ptr, - subm); + phi::funcs::sparse::ScatterKernel<<>>( + d_x_features_ptr, + unique_value.data(), + out_index.data(), + x.nnz(), + rulebook_len, + in_channels, + x_grad_values_ptr, + subm); } template diff --git a/paddle/phi/kernels/sparse/gpu/convolution_kernel.cu b/paddle/phi/kernels/sparse/gpu/convolution_kernel.cu index 93da65dc0f7d8c630d1348b6ed1c31c7372973f6..30f0482a0cc360fe9294e642abc7f0f1b5ba3548 100644 --- a/paddle/phi/kernels/sparse/gpu/convolution_kernel.cu +++ b/paddle/phi/kernels/sparse/gpu/convolution_kernel.cu @@ -18,6 +18,7 @@ limitations under the License. */ #include "paddle/phi/core/visit_type.h" #include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/scatter.cu.h" +#include "paddle/phi/kernels/funcs/sparse/scatter.cu.h" #include "paddle/phi/kernels/sparse/convolution_kernel.h" #include "paddle/phi/kernels/sparse/gpu/convolution.cu.h" @@ -169,16 +170,17 @@ void Conv3dGPUKernel(const GPUContext& dev_ctx, } else { config = phi::backends::gpu::GetGpuLaunchConfig1D( dev_ctx, out->nnz() * out_channels, 1); - ScatterKernel<<>>(out_features_ptr, - unique_value.data(), - out_index.data(), - out->nnz(), - n, - out_channels, - out_values_ptr); + phi::funcs::sparse::ScatterKernel<<>>( + out_features_ptr, + unique_value.data(), + out_index.data(), + out->nnz(), + n, + out_channels, + out_values_ptr); } } /** diff --git a/paddle/phi/kernels/sparse/gpu/sparse_mask_kernel.cu b/paddle/phi/kernels/sparse/gpu/sparse_mask_kernel.cu index dff1cc2318f132e1ecc41d73c01346b57ce70b9d..4e2d12f33955e4810ee995aec0843507c70a9d25 100644 --- a/paddle/phi/kernels/sparse/gpu/sparse_mask_kernel.cu +++ b/paddle/phi/kernels/sparse/gpu/sparse_mask_kernel.cu @@ -23,7 +23,7 @@ limitations under the License. */ #include "paddle/phi/kernels/copy_kernel.h" #include "paddle/phi/kernels/empty_kernel.h" #include "paddle/phi/kernels/funcs/math_function.h" -#include "paddle/phi/kernels/funcs/sparse/common_shape.h" +#include "paddle/phi/kernels/funcs/sparse/flatten_indices.cu.h" #include "paddle/phi/kernels/sparse/sparse_mask_kernel.h" namespace phi { @@ -123,23 +123,6 @@ void SparseMaskKernel(const Context& dev_ctx, })); } -// TODO(zhangkaihuo): Use an op to realize the function of FlattenIndices -template -__global__ void FlattenIndicesKernel(const IntT* indices, - const IntT* sparse_offsets, - const int64_t non_zero_num, - const int64_t sparse_dim, - IntT* out) { - int tid = threadIdx.x + blockIdx.x * blockDim.x; - phi::funcs::sparse::FlattenIndices(indices, - sparse_offsets, - non_zero_num, - sparse_dim, - tid, - gridDim.x * blockDim.x, - out); -} - template __global__ void SparseMaskCopyKernel(const IntT* x_indexs, const IntT* mask_indexs, @@ -192,7 +175,8 @@ void SparseMaskHelperGPUKernel(const GPUContext& dev_ctx, IntT* bound_out_ptr = bound_out.data(); // 1. 
calc the offsets of per dim - phi::funcs::sparse::CalcOffsetsPerDim(x.dims(), sparse_dim, &sparse_offsets); + phi::funcs::sparse::CalcOffsetsPerDim( + x.dims(), sparse_dim, sparse_offsets.data()); // 2. copy sparse_offsets to device phi::backends::gpu::GpuMemcpyAsync(d_sparse_offsets.data(), sparse_offsets.data(), @@ -207,25 +191,27 @@ void SparseMaskHelperGPUKernel(const GPUContext& dev_ctx, // 3. flatten x indices and mask indices auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, x_indexs.numel(), 1); - FlattenIndicesKernel<<>>(x.non_zero_indices().data(), - d_sparse_offsets.data(), - x_indexs.numel(), - sparse_dim, - x_indexs_ptr); + phi::funcs::sparse::FlattenIndicesKernel<<>>( + x.non_zero_indices().data(), + d_sparse_offsets.data(), + x_indexs.numel(), + sparse_dim, + x_indexs_ptr); config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, mask_indexs.numel(), 1); - FlattenIndicesKernel<<>>(mask_indices.data(), - d_sparse_offsets.data(), - mask_indexs.numel(), - sparse_dim, - mask_indexs_ptr); + phi::funcs::sparse::FlattenIndicesKernel<<>>( + mask_indices.data(), + d_sparse_offsets.data(), + mask_indexs.numel(), + sparse_dim, + mask_indexs_ptr); // 4. call thrust::lower_bound #ifdef PADDLE_WITH_HIP thrust::lower_bound(thrust::hip::par.on(dev_ctx.stream()), diff --git a/paddle/phi/kernels/sparse/sparse_utils_kernel.h b/paddle/phi/kernels/sparse/sparse_utils_kernel.h index 8cf9c0a28648ac88f29b3b570bee052a777001b8..072e6f141f8f1cb6f5fd2bc0d5aa90dc0bd7df5b 100644 --- a/paddle/phi/kernels/sparse/sparse_utils_kernel.h +++ b/paddle/phi/kernels/sparse/sparse_utils_kernel.h @@ -20,6 +20,7 @@ limitations under the License. */ #include "paddle/phi/core/sparse_coo_tensor.h" #include "paddle/phi/core/sparse_csr_tensor.h" #include "paddle/phi/kernels/empty_kernel.h" +#include "paddle/phi/kernels/sparse/coalesced_kernel.h" namespace phi { namespace sparse { @@ -154,9 +155,9 @@ void SparseCooTensorKernel(const Context& dev_ctx, const DenseTensor& indices, const IntArray& dense_shape, SparseCooTensor* out) { - *out = - SparseCooTensor(indices, values, phi::make_ddim(dense_shape.GetData())); - // TODO(zhangkaihuo): sort and merge the dumplicate indices + SparseCooTensor before_coalesced( + indices, values, phi::make_ddim(dense_shape.GetData())); + CoalescedKernel(dev_ctx, before_coalesced, out); } } // namespace sparse diff --git a/python/paddle/fluid/tests/unittests/test_sparse_utils_op.py b/python/paddle/fluid/tests/unittests/test_sparse_utils_op.py index 89cfc711910ce77c7b19578e3bd641b67ba1b94f..c87626a10c6315f609886a34f486fa93269ea090 100644 --- a/python/paddle/fluid/tests/unittests/test_sparse_utils_op.py +++ b/python/paddle/fluid/tests/unittests/test_sparse_utils_op.py @@ -19,6 +19,8 @@ import paddle import paddle.fluid.core as core from paddle.fluid.framework import _test_eager_guard +devices = ['cpu', 'gpu'] + class TestSparseCreate(unittest.TestCase): def test_create_coo_by_tensor(self): @@ -30,6 +32,8 @@ class TestSparseCreate(unittest.TestCase): dense_elements = paddle.to_tensor(values, dtype='float32') coo = paddle.sparse.sparse_coo_tensor( dense_indices, dense_elements, dense_shape, stop_gradient=False) + # test the to_string.py + print(coo) assert np.array_equal(indices, coo.indices().numpy()) assert np.array_equal(values, coo.values().numpy()) @@ -37,7 +41,7 @@ class TestSparseCreate(unittest.TestCase): with _test_eager_guard(): indices = [[0, 1, 2], [1, 2, 0]] values = [1.0, 2.0, 3.0] - dense_shape = [2, 3] + dense_shape = [3, 3] coo = 
paddle.sparse.sparse_coo_tensor(indices, values, dense_shape) assert np.array_equal(indices, coo.indices().numpy()) assert np.array_equal(values, coo.values().numpy()) @@ -67,6 +71,8 @@ class TestSparseCreate(unittest.TestCase): dense_shape = [3, 4] csr = paddle.sparse.sparse_csr_tensor(crows, cols, values, dense_shape) + # test the to_string.py + print(csr) assert np.array_equal(crows, csr.crows().numpy()) assert np.array_equal(cols, csr.cols().numpy()) assert np.array_equal(values, csr.values().numpy()) @@ -205,38 +211,154 @@ class TestSparseConvert(unittest.TestCase): def test_sparse_coo_tensor_grad(self): with _test_eager_guard(): - indices = [[0, 1], [0, 1]] - values = [1, 2] - indices = paddle.to_tensor(indices, dtype='int32') - values = paddle.to_tensor( - values, dtype='float32', stop_gradient=False) - sparse_x = paddle.sparse.sparse_coo_tensor( - indices, values, shape=[2, 2], stop_gradient=False) - grad_indices = [[0, 1], [1, 1]] - grad_values = [2, 3] - grad_indices = paddle.to_tensor(grad_indices, dtype='int32') - grad_values = paddle.to_tensor(grad_values, dtype='float32') - sparse_out_grad = paddle.sparse.sparse_coo_tensor( - grad_indices, grad_values, shape=[2, 2]) - sparse_x.backward(sparse_out_grad) - correct_values_grad = [0, 3] - assert np.array_equal(correct_values_grad, values.grad.numpy()) + for device in devices: + if device == 'cpu' or (device == 'gpu' and + paddle.is_compiled_with_cuda()): + paddle.device.set_device(device) + indices = [[0, 1], [0, 1]] + values = [1, 2] + indices = paddle.to_tensor(indices, dtype='int32') + values = paddle.to_tensor( + values, dtype='float32', stop_gradient=False) + sparse_x = paddle.sparse.sparse_coo_tensor( + indices, values, shape=[2, 2], stop_gradient=False) + grad_indices = [[0, 1], [1, 1]] + grad_values = [2, 3] + grad_indices = paddle.to_tensor(grad_indices, dtype='int32') + grad_values = paddle.to_tensor(grad_values, dtype='float32') + sparse_out_grad = paddle.sparse.sparse_coo_tensor( + grad_indices, grad_values, shape=[2, 2]) + sparse_x.backward(sparse_out_grad) + correct_values_grad = [0, 3] + assert np.array_equal(correct_values_grad, + values.grad.numpy()) - place = core.CPUPlace() - indices_cpu = paddle.to_tensor(indices, dtype='int32', place=place) - values_cpu = paddle.to_tensor( - values, dtype='float32', place=place, stop_gradient=False) - sparse_x_cpu = paddle.sparse.sparse_coo_tensor( - indices_cpu, - values_cpu, - shape=[2, 2], - place=place, - stop_gradient=False) + def test_sparse_coo_tensor_sorted(self): + with _test_eager_guard(): + for device in devices: + if device == 'cpu' or (device == 'gpu' and + paddle.is_compiled_with_cuda()): + paddle.device.set_device(device) + #test unsorted and duplicate indices + indices = [[1, 0, 0], [0, 1, 1]] + values = [1.0, 2.0, 3.0] + indices = paddle.to_tensor(indices, dtype='int32') + values = paddle.to_tensor(values, dtype='float32') + sparse_x = paddle.sparse.sparse_coo_tensor(indices, values) + indices_sorted = [[0, 1], [1, 0]] + values_sorted = [5.0, 1.0] + assert np.array_equal(indices_sorted, + sparse_x.indices().numpy()) + assert np.array_equal(values_sorted, + sparse_x.values().numpy()) + + +class TestCooError(unittest.TestCase): + def test_small_shape(self): + with _test_eager_guard(): + with self.assertRaises(ValueError): + indices = [[2, 3], [0, 2]] + values = [1, 2] + # 1. 
the shape too small + dense_shape = [2, 2] + sparse_x = paddle.sparse.sparse_coo_tensor( + indices, values, shape=dense_shape) + + def test_same_nnz(self): + with _test_eager_guard(): + with self.assertRaises(ValueError): + # 2. test the nnz of indices must same as nnz of values + indices = [[1, 2], [1, 0]] + values = [1, 2, 3] + sparse_x = paddle.sparse.sparse_coo_tensor(indices, values) + + def test_same_dimensions(self): + with _test_eager_guard(): + with self.assertRaises(ValueError): + indices = [[1, 2], [1, 0]] + values = [1, 2, 3] + shape = [2, 3, 4] + sparse_x = paddle.sparse.sparse_coo_tensor( + indices, values, shape=shape) + + def test_indices_dtype(self): + with _test_eager_guard(): + with self.assertRaises(TypeError): + indices = [[1.0, 2.0], [0, 1]] + values = [1, 2] + sparse_x = paddle.sparse.sparse_coo_tensor(indices, values) + + +class TestCsrError(unittest.TestCase): + def test_dimension1(self): + with _test_eager_guard(): + with self.assertRaises(ValueError): + crows = [0, 1, 2, 3] + cols = [0, 1, 2] + values = [1, 2, 3] + shape = [3] + sparse_x = paddle.sparse.sparse_csr_tensor(crows, cols, values, + shape) + + def test_dimension2(self): + with _test_eager_guard(): + with self.assertRaises(ValueError): + crows = [0, 1, 2, 3] + cols = [0, 1, 2] + values = [1, 2, 3] + shape = [3, 3, 3, 3] + sparse_x = paddle.sparse.sparse_csr_tensor(crows, cols, values, + shape) + + def test_same_shape1(self): + with _test_eager_guard(): + with self.assertRaises(ValueError): + crows = [0, 1, 2, 3] + cols = [0, 1, 2, 3] + values = [1, 2, 3] + shape = [3, 4] + sparse_x = paddle.sparse.sparse_csr_tensor(crows, cols, values, + shape) - sparse_out_grad_cpu = paddle.sparse.sparse_coo_tensor( - grad_indices, grad_values, shape=[2, 2], place=place) - sparse_x_cpu.backward(sparse_out_grad_cpu) - assert np.array_equal(correct_values_grad, values_cpu.grad.numpy()) + def test_same_shape2(self): + with _test_eager_guard(): + with self.assertRaises(ValueError): + crows = [0, 1, 2, 3] + cols = [0, 1, 2, 3] + values = [1, 2, 3, 4] + shape = [3, 4] + sparse_x = paddle.sparse.sparse_csr_tensor(crows, cols, values, + shape) + + def test_same_shape3(self): + with _test_eager_guard(): + with self.assertRaises(ValueError): + crows = [0, 1, 2, 3, 0, 1, 2] + cols = [0, 1, 2, 3, 0, 1, 2] + values = [1, 2, 3, 4, 0, 1, 2] + shape = [2, 3, 4] + sparse_x = paddle.sparse.sparse_csr_tensor(crows, cols, values, + shape) + + def test_crows_first_value(self): + with _test_eager_guard(): + with self.assertRaises(ValueError): + crows = [1, 1, 2, 3] + cols = [0, 1, 2] + values = [1, 2, 3] + shape = [3, 4] + sparse_x = paddle.sparse.sparse_csr_tensor(crows, cols, values, + shape) + + def test_dtype(self): + with _test_eager_guard(): + with self.assertRaises(TypeError): + crows = [0, 1, 2, 3.0] + cols = [0, 1, 2] + values = [1, 2, 3] + shape = [3] + sparse_x = paddle.sparse.sparse_csr_tensor(crows, cols, values, + shape) if __name__ == "__main__": diff --git a/python/paddle/sparse/creation.py b/python/paddle/sparse/creation.py index ac9276f3142c0ca322b5fb73286e14677740826b..d494336e1ff507a8fb16726a215321829a62ce04 100644 --- a/python/paddle/sparse/creation.py +++ b/python/paddle/sparse/creation.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
 
+import paddle
 from paddle import _C_ops
 from ..framework import core, dygraph_only
 from ..framework import _current_expected_place, _get_paddle_place
@@ -51,6 +52,13 @@ def _get_place(place):
     return place
 
 
+def _check_indices_dtype(dtype):
+    if dtype not in [paddle.int8, paddle.int16, paddle.int32, paddle.int64]:
+        raise TypeError(
+            "the dtype of indices must be one of: 'int8', 'int16', 'int32', 'int64'"
+        )
+
+
 @dygraph_only
 def sparse_coo_tensor(indices,
                       values,
@@ -117,6 +125,18 @@ def sparse_coo_tensor(indices,
     if len(indices.shape) != 2:
         raise ValueError("'indices' must be 2-D.")
 
+    nnz = indices.shape[1]
+    sparse_dim = indices.shape[0]
+
+    _check_indices_dtype(indices.dtype)
+
+    if nnz != values.shape[0]:
+        raise ValueError(
+            "the indices and values must have the same number of non-zero entries, but got {} and {}".
+            format(nnz, values.shape[0]))
+
+    dense_dim = len(values.shape) - 1
+
     if not indices.place._equals(place):
         indices = indices._copy_to(place, False)
 
@@ -125,8 +145,17 @@ def sparse_coo_tensor(indices,
         values = _handle_dtype(values, dtype)
         values.stop_gradient = stop_gradient
 
+    min_shape = _infer_dense_shape(indices)
     if shape is None:
-        shape = _infer_dense_shape(indices)
+        shape = min_shape
+    else:
+        if shape < min_shape:
+            raise ValueError("the minimum shape required is {}, but got {}".
+                             format(min_shape, shape))
+        if len(shape) != sparse_dim + dense_dim:
+            raise ValueError(
+                "the number of dimensions, i.e. len(shape), must be sparse_dim({}) + dense_dim({}), but got {}".
+                format(sparse_dim, dense_dim, len(shape)))
 
     return _C_ops.final_state_sparse_create_sparse_coo_tensor(values, indices,
                                                               shape)
@@ -144,6 +173,7 @@ def sparse_csr_tensor(crows,
     r"""
     Constructs a sparse ``paddle.Tensor`` in CSR(Compressed Sparse Row) format according to
     the ``crows``, ``cols`` and ``values``.
+    Currently, the crows and cols of each batch must be in ascending order.
 
     Args:
         crows(list|tuple|ndarray|Tensor): 1-D array, each element in the rows represents the
@@ -202,10 +232,14 @@ def sparse_csr_tensor(crows,
         cols = to_tensor(cols, dtype=None, place=place, stop_gradient=True)
     if not isinstance(values, core.eager.Tensor):
         values = to_tensor(values, dtype, place, stop_gradient)
-    if len(crows.shape) != 1 or len(cols.shape) != 1 or len(values.shape) != 1:
+
+    _check_indices_dtype(crows.dtype)
+    _check_indices_dtype(cols.dtype)
+
+    if len(shape) != 2 and len(shape) != 3:
         raise ValueError(
-            "SparseCsrTensor only support 2-D or 3-D matrix. The 'crows', 'cols' and 'values' must be 1-D."
-        )
+            "SparseCsrTensor only supports 2-D or 3-D matrices, but got shape {}".
+            format(shape))
 
     if not crows.place._equals(place):
         crows = crows._copy_to(place, False)
@@ -217,5 +251,30 @@ def sparse_csr_tensor(crows,
         values = values._copy_to(place, False)
         values = _handle_dtype(values, dtype)
         values.stop_gradient = stop_gradient
+
+    if len(crows.shape) != 1 or len(cols.shape) != 1 or len(values.shape) != 1:
+        raise ValueError("The 'crows', 'cols' and 'values' must be 1-D.")
+
+    if len(cols) != len(values):
+        raise ValueError("the length of cols must equal the length of values")
+
+    if len(shape) == 2:
+        if crows.shape[0] != shape[0] + 1:
+            raise ValueError(
+                "The length({}) of crows must be equal to the rows({})+1 of the matrix.".
+                format(crows.shape[0], shape[0]))
+        if crows[0] != 0:
+            raise ValueError("the first value of crows must be 0")
+
+        if crows[-1] != values.shape[0]:
+            raise ValueError(
+                "the last value of crows must equal the number of non-zero values")
+    else:
+        if crows.shape[0] % (shape[0] + 1) != 0:
+            raise ValueError(
+                "The length({}) of crows must be divisible by the rows({})+1 of the matrix.".
+                format(crows.shape[0], shape[0]))
+    # TODO(zkh2016): check whether the value in crows and cols is legal
+
     return core.eager.sparse_csr_tensor(crows, cols, values, shape,
                                         stop_gradient)
diff --git a/python/paddle/utils/code_gen/sparse_api.yaml b/python/paddle/utils/code_gen/sparse_api.yaml
index 2187d4abb2d637c71012dffcde291a8b609e9fd1..100d7ad78319b99b1c275d03c904d0ff4cf029c6 100644
--- a/python/paddle/utils/code_gen/sparse_api.yaml
+++ b/python/paddle/utils/code_gen/sparse_api.yaml
@@ -27,6 +27,7 @@
   kernel :
     func : sparse_coo_tensor
     layout : values
+    data_type : values
  backward : create_sparse_coo_tensor_grad
 
 - api : csr_values
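
Note (not part of the diff): the new CPU and GPU coalesce kernels follow the same three-step scheme — flatten each coordinate to a scalar key with row-major offsets, sort and merge duplicate keys by summing their values, then convert the unique keys back to coordinates. The NumPy sketch below only illustrates that scheme under the assumption that `values` is 1-D (dense_dim == 0); `coalesce_reference` and its variable names are hypothetical and do not exist in Paddle.

```python
import numpy as np

def coalesce_reference(indices, values, shape):
    sparse_dim = indices.shape[0]
    # 1. flatten each coordinate to a scalar key with row-major offsets
    #    (CalcOffsetsPerDim + FlattenIndices in flatten_indices.h)
    offsets = np.ones(sparse_dim, dtype=np.int64)
    for i in range(sparse_dim - 2, -1, -1):
        offsets[i] = offsets[i + 1] * shape[i + 1]
    flat = (indices * offsets[:, None]).sum(axis=0)
    # 2. sort by the flattened key and merge duplicates by summing values
    #    (thrust::sort_by_key + unique_by_key + ScatterKernel on the GPU)
    order = np.argsort(flat)
    flat, values = flat[order], values[order]
    uniq, start = np.unique(flat, return_index=True)
    merged = np.add.reduceat(values, start)
    # 3. convert the unique flattened keys back to coordinates (IndexToCoordinate)
    out_indices = np.empty((sparse_dim, len(uniq)), dtype=indices.dtype)
    rest = uniq.copy()
    for j in range(sparse_dim - 1, -1, -1):
        out_indices[j] = rest % shape[j]
        rest //= shape[j]
    return out_indices, merged

# coalesce_reference(np.array([[1, 0, 0], [0, 1, 1]]), np.array([1., 2., 3.]), (2, 2))
# -> indices [[0, 1], [1, 0]], values [5.0, 1.0]
```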
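The unit test `test_sparse_coo_tensor_sorted` added above exercises the new behavior end to end; the snippet below restates it as a stand-alone usage example (eager mode, CPU or GPU).

```python
import numpy as np
import paddle
from paddle.fluid.framework import _test_eager_guard

with _test_eager_guard():
    # duplicate entry at coordinate (0, 1) and unsorted order
    indices = paddle.to_tensor([[1, 0, 0], [0, 1, 1]], dtype='int32')
    values = paddle.to_tensor([1.0, 2.0, 3.0], dtype='float32')
    coo = paddle.sparse.sparse_coo_tensor(indices, values)
    # the creation op now sorts the indices and sums the duplicates
    assert np.array_equal(coo.indices().numpy(), [[0, 1], [1, 0]])
    assert np.array_equal(coo.values().numpy(), [5.0, 1.0])
```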
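Likewise, the new input checks in `sparse_csr_tensor` reject malformed `crows`, mirroring `TestCsrError.test_crows_first_value`; a minimal illustrative sketch:

```python
import paddle
from paddle.fluid.framework import _test_eager_guard

with _test_eager_guard():
    crows = [1, 1, 2, 3]  # invalid: the first value of crows must be 0
    cols = [0, 1, 2]
    values = [1.0, 2.0, 3.0]
    try:
        paddle.sparse.sparse_csr_tensor(crows, cols, values, [3, 4])
    except ValueError as e:
        print("rejected:", e)
```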