diff --git a/paddle/phi/api/lib/tensor_method.cc b/paddle/phi/api/lib/tensor_method.cc index 2ead95e11b7eb210e3d4add501e8258ca7af4763..6d38bbda36310dd346f322a5faae66a139fb70b3 100644 --- a/paddle/phi/api/lib/tensor_method.cc +++ b/paddle/phi/api/lib/tensor_method.cc @@ -167,39 +167,21 @@ void Tensor::copy_(const Tensor &src, blocking, static_cast(impl_.get())); } else if (kernel_type == KernelType::SPARSE_COO_KERNEL) { - auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError( - "copy_sparse_coo", {kernel_backend, kernel_layout, kernel_data_type}); - VLOG(6) << "copy API kernel key: " << kernel_key; - VLOG(6) << "copy API kernel: " << kernel; - using kernel_signature = void (*)(const platform::DeviceContext &, - const phi::SparseCooTensor &, - phi::Place, - bool, - phi::SparseCooTensor *); - this->set_impl(std::make_shared()); - auto *kernel_fn = kernel.GetVariadicKernelFn(); - (*kernel_fn)(*dev_ctx, - (*(std::static_pointer_cast(src.impl_))), - target_place, - blocking, - static_cast(impl_.get())); + SetSparseKernelOutput(this, TensorType::SPARSE_COO); + // TODO(zhangkaihuo) add sparse infer_meta + phi::Copy(*dev_ctx, + (*(std::static_pointer_cast(src.impl_))), + target_place, + blocking, + static_cast(impl_.get())); } else if (kernel_type == KernelType::SPARSE_CSR_KERNEL) { - auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError( - "copy_sparse_csr", {kernel_backend, kernel_layout, kernel_data_type}); - VLOG(6) << "copy API kernel key: " << kernel_key; - VLOG(6) << "copy API kernel: " << kernel; - using kernel_signature = void (*)(const platform::DeviceContext &, - const phi::SparseCsrTensor &, - phi::Place, - bool, - phi::SparseCsrTensor *); - this->set_impl(std::make_shared()); - auto *kernel_fn = kernel.GetVariadicKernelFn(); - (*kernel_fn)(*dev_ctx, - (*(std::static_pointer_cast(src.impl_))), - target_place, - blocking, - static_cast(impl_.get())); + SetSparseKernelOutput(this, TensorType::SPARSE_CSR); + // TODO(zhangkaihuo) add sparse infer_meta + phi::Copy(*dev_ctx, + (*(std::static_pointer_cast(src.impl_))), + target_place, + blocking, + static_cast(impl_.get())); } else { PADDLE_THROW(phi::errors::InvalidArgument( "We currently only support dense tensor copy for now and if u need to " diff --git a/paddle/phi/core/tensor_utils.cc b/paddle/phi/core/tensor_utils.cc index f6743a0c1849b2350b67810840b5737a6387b153..e2d8d5a03651c9ea7cd8dfd8fa62708398d1da94 100644 --- a/paddle/phi/core/tensor_utils.cc +++ b/paddle/phi/core/tensor_utils.cc @@ -234,6 +234,53 @@ void Copy(const Context& dev_ctx, dev_ctx, src.value(), dst_place, blocking, dst->mutable_value()); } +template +void Copy(const Context& dev_ctx, + const SparseCooTensor& src, + Place dst_place, + bool blocking, + SparseCooTensor* dst) { + phi::Copy(dev_ctx, + src.non_zero_indices(), + dst_place, + blocking, + dst->mutable_non_zero_indices()); + + phi::Copy(dev_ctx, + src.non_zero_elements(), + dst_place, + blocking, + dst->mutable_non_zero_elements()); + dst->set_dims(src.dims()); + dst->SetCoalesced(src.coalesced()); +} + +template +void Copy(const Context& dev_ctx, + const SparseCsrTensor& src, + Place dst_place, + bool blocking, + SparseCsrTensor* dst) { + phi::Copy(dev_ctx, + src.non_zero_crows(), + dst_place, + blocking, + dst->mutable_non_zero_crows()); + + phi::Copy(dev_ctx, + src.non_zero_cols(), + dst_place, + blocking, + dst->mutable_non_zero_cols()); + + phi::Copy(dev_ctx, + src.non_zero_elements(), + dst_place, + blocking, + dst->mutable_non_zero_elements()); + dst->set_dims(src.dims()); +} + template void Copy(const CPUContext& dev_ctx, const DenseTensor& src, Place dst_place, @@ -257,6 +304,30 @@ template void Copy(const DeviceContext& dev_ctx, bool blocking, SelectedRows* dst); +template void Copy(const CPUContext& dev_ctx, + const SparseCooTensor& src, + Place dst_place, + bool blocking, + SparseCooTensor* dst); + +template void Copy(const DeviceContext& dev_ctx, + const SparseCooTensor& src, + Place dst_place, + bool blocking, + SparseCooTensor* dst); + +template void Copy(const CPUContext& dev_ctx, + const SparseCsrTensor& src, + Place dst_place, + bool blocking, + SparseCsrTensor* dst); + +template void Copy(const DeviceContext& dev_ctx, + const SparseCsrTensor& src, + Place dst_place, + bool blocking, + SparseCsrTensor* dst); + #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) template void Copy(const GPUContext& dev_ctx, const DenseTensor& src, @@ -268,6 +339,16 @@ template void Copy(const GPUContext& dev_ctx, Place dst_place, bool blocking, SelectedRows* dst); +template void Copy(const GPUContext& dev_ctx, + const SparseCooTensor& src, + Place dst_place, + bool blocking, + SparseCooTensor* dst); +template void Copy(const GPUContext& dev_ctx, + const SparseCsrTensor& src, + Place dst_place, + bool blocking, + SparseCsrTensor* dst); #endif #ifdef PADDLE_WITH_XPU diff --git a/paddle/phi/core/tensor_utils.h b/paddle/phi/core/tensor_utils.h index 1c490fd53931c7f5f168dca6a576a015a1aae99f..c478e3e0895763b05a0b5902850a4528c93ff5b8 100644 --- a/paddle/phi/core/tensor_utils.h +++ b/paddle/phi/core/tensor_utils.h @@ -16,6 +16,8 @@ limitations under the License. */ #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/selected_rows.h" +#include "paddle/phi/core/sparse_coo_tensor.h" +#include "paddle/phi/core/sparse_csr_tensor.h" #include "paddle/phi/core/tensor_meta.h" namespace phi { @@ -85,4 +87,18 @@ void Copy(const Context& dev_ctx, bool blocking, SelectedRows* dst); +template +void Copy(const Context& dev_ctx, + const SparseCooTensor& src, + Place dst_place, + bool blocking, + SparseCooTensor* dst); + +template +void Copy(const Context& dev_ctx, + const SparseCsrTensor& src, + Place dst_place, + bool blocking, + SparseCsrTensor* dst); + } // namespace phi diff --git a/paddle/phi/kernels/sparse/copy_kernel.cc b/paddle/phi/kernels/sparse/copy_kernel.cc deleted file mode 100644 index 76726f0ffcce02ad8b9f1e425e9f8a3cbddaf980..0000000000000000000000000000000000000000 --- a/paddle/phi/kernels/sparse/copy_kernel.cc +++ /dev/null @@ -1,99 +0,0 @@ -/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/phi/kernels/sparse/copy_kernel.h" - -#include "paddle/phi/backends/cpu/cpu_context.h" -#include "paddle/phi/backends/gpu/gpu_context.h" -#include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/core/sparse_coo_tensor.h" -#include "paddle/phi/core/sparse_csr_tensor.h" -#include "paddle/phi/core/tensor_utils.h" - -namespace phi { -namespace sparse { - -template -void CopyCoo(const Context& dev_ctx, - const SparseCooTensor& src, - Place dst_place, - bool blocking, - SparseCooTensor* dst) { - phi::Copy(dev_ctx, - src.non_zero_indices(), - dst_place, - blocking, - dst->mutable_non_zero_indices()); - - phi::Copy(dev_ctx, - src.non_zero_elements(), - dst_place, - blocking, - dst->mutable_non_zero_elements()); - dst->set_dims(src.dims()); -} - -template -void CopyCsr(const Context& dev_ctx, - const SparseCsrTensor& src, - Place dst_place, - bool blocking, - SparseCsrTensor* dst) { - phi::Copy(dev_ctx, - src.non_zero_crows(), - dst_place, - blocking, - dst->mutable_non_zero_crows()); - - phi::Copy(dev_ctx, - src.non_zero_cols(), - dst_place, - blocking, - dst->mutable_non_zero_cols()); - - phi::Copy(dev_ctx, - src.non_zero_elements(), - dst_place, - blocking, - dst->mutable_non_zero_elements()); - dst->set_dims(src.dims()); -} - -} // namespace sparse -} // namespace phi - -PD_REGISTER_GENERAL_KERNEL(copy_sparse_coo, - CPU, - ALL_LAYOUT, - phi::sparse::CopyCoo, - ALL_DTYPE) {} - -PD_REGISTER_GENERAL_KERNEL(copy_sparse_csr, - CPU, - ALL_LAYOUT, - phi::sparse::CopyCsr, - ALL_DTYPE) {} - -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) -PD_REGISTER_GENERAL_KERNEL(copy_sparse_coo, - GPU, - ALL_LAYOUT, - phi::sparse::CopyCoo, - ALL_DTYPE) {} -PD_REGISTER_GENERAL_KERNEL(copy_sparse_csr, - GPU, - ALL_LAYOUT, - phi::sparse::CopyCsr, - ALL_DTYPE) {} -#endif diff --git a/paddle/phi/kernels/sparse/copy_kernel.h b/paddle/phi/kernels/sparse/copy_kernel.h deleted file mode 100644 index 70e2aaef8a88806f5bf77f014856055b434d0f7c..0000000000000000000000000000000000000000 --- a/paddle/phi/kernels/sparse/copy_kernel.h +++ /dev/null @@ -1,40 +0,0 @@ -/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include "paddle/phi/core/dense_tensor.h" -#include "paddle/phi/core/sparse_coo_tensor.h" -#include "paddle/phi/core/sparse_csr_tensor.h" -#include "paddle/phi/kernels/empty_kernel.h" - -namespace phi { -namespace sparse { - -template -void CopyCoo(const Context& dev_ctx, - const SparseCooTensor& src, - Place dst_place, - bool blocking, - SparseCooTensor* dst); - -template -void CopyCsr(const Context& dev_ctx, - const SparseCsrTensor& src, - Place dst_place, - bool blocking, - SparseCsrTensor* dst); - -} // namespace sparse -} // namespace phi diff --git a/paddle/phi/kernels/sparse/cpu/elementwise_grad_kernel.cc b/paddle/phi/kernels/sparse/cpu/elementwise_grad_kernel.cc index 6ae74e77422a4ce4f3ba94badb0b4487242e56eb..d9ebbd10267f5655b922fdfa381d3bf947a2f62a 100644 --- a/paddle/phi/kernels/sparse/cpu/elementwise_grad_kernel.cc +++ b/paddle/phi/kernels/sparse/cpu/elementwise_grad_kernel.cc @@ -16,11 +16,12 @@ limitations under the License. */ #include "paddle/phi/core/enforce.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/tensor_meta.h" +#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/core/visit_type.h" #include "paddle/phi/kernels/activation_kernel.h" #include "paddle/phi/kernels/elementwise_kernel.h" +#include "paddle/phi/kernels/empty_kernel.h" #include "paddle/phi/kernels/funcs/eigen/common.h" -#include "paddle/phi/kernels/sparse/copy_kernel.h" #include "paddle/phi/kernels/sparse/elementwise_kernel.h" namespace phi { @@ -56,16 +57,16 @@ void ElementWiseAddCsrGradCPUKernel(const Context& dev_ctx, if (dx != nullptr && dy == nullptr) { VLOG(4) << "Special case when dy is not needed"; AllocCsrPtr(dev_ctx, x, dx); - CopyCsr(dev_ctx, dout, dev_ctx.GetPlace(), false, dx); + Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dx); } else if (dx == nullptr && dy != nullptr) { VLOG(4) << "Special case when dx is not needed"; AllocCsrPtr(dev_ctx, y, dy); - CopyCsr(dev_ctx, dout, dev_ctx.GetPlace(), false, dy); + Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dy); } else { AllocCsrPtr(dev_ctx, x, dx); AllocCsrPtr(dev_ctx, y, dy); - CopyCsr(dev_ctx, dout, dev_ctx.GetPlace(), false, dx); - CopyCsr(dev_ctx, dout, dev_ctx.GetPlace(), false, dy); + Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dx); + Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dy); } } @@ -78,12 +79,12 @@ void ElementWiseSubtractCsrGradCPUKernel(const Context& dev_ctx, SparseCsrTensor* dy) { if (dx) { AllocCsrPtr(dev_ctx, x, dx); - CopyCsr(dev_ctx, dout, dev_ctx.GetPlace(), false, dx); + Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dx); } if (dy) { AllocCsrPtr(dev_ctx, y, dy); - CopyCsr(dev_ctx, dout, dev_ctx.GetPlace(), false, dy); + Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dy); phi::NegativeKernel( dev_ctx, dout.non_zero_elements(), dy->mutable_non_zero_elements()); } @@ -126,7 +127,7 @@ void ElementWiseDivideCsrGradCPUKernel(const Context& dev_ctx, if (dy) { // -dout * out / y AllocCsrPtr(dev_ctx, y, dy); - CopyCsr(dev_ctx, dout, dev_ctx.GetPlace(), false, dy); + Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dy); phi::NegativeKernel( dev_ctx, dout.non_zero_elements(), dy->mutable_non_zero_elements()); auto tmp = sparse::ElementWiseMultiplyCsr(dev_ctx, *dy, out); @@ -145,16 +146,16 @@ void ElementWiseAddCooGradCPUKernel(const Context& dev_ctx, if (dx != nullptr && dy == nullptr) { VLOG(4) << "Special case when dy is not needed"; AllocCooPtr(dev_ctx, x, dx); - CopyCoo(dev_ctx, dout, dev_ctx.GetPlace(), false, dx); + Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dx); } else if (dx == nullptr && dy != nullptr) { VLOG(4) << "Special case when dx is not needed"; AllocCooPtr(dev_ctx, y, dy); - CopyCoo(dev_ctx, dout, dev_ctx.GetPlace(), false, dy); + Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dy); } else { AllocCooPtr(dev_ctx, x, dx); AllocCooPtr(dev_ctx, y, dy); - CopyCoo(dev_ctx, dout, dev_ctx.GetPlace(), false, dx); - CopyCoo(dev_ctx, dout, dev_ctx.GetPlace(), false, dy); + Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dx); + Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dy); } } @@ -167,12 +168,12 @@ void ElementWiseSubtractCooGradCPUKernel(const Context& dev_ctx, SparseCooTensor* dy) { if (dx) { AllocCooPtr(dev_ctx, x, dx); - CopyCoo(dev_ctx, dout, dev_ctx.GetPlace(), false, dx); + Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dx); } if (dy) { AllocCooPtr(dev_ctx, y, dy); - CopyCoo(dev_ctx, dout, dev_ctx.GetPlace(), false, dy); + Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dy); phi::NegativeKernel( dev_ctx, dout.non_zero_elements(), dy->mutable_non_zero_elements()); } @@ -215,7 +216,7 @@ void ElementWiseDivideCooGradCPUKernel(const Context& dev_ctx, if (dy) { // -dout * out / y AllocCooPtr(dev_ctx, y, dy); - CopyCoo(dev_ctx, dout, dev_ctx.GetPlace(), false, dy); + Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dy); phi::NegativeKernel( dev_ctx, dout.non_zero_elements(), dy->mutable_non_zero_elements()); auto tmp = sparse::ElementWiseMultiplyCoo(dev_ctx, *dy, out);