Unverified commit 148fa05e, authored by zhangkaihuo, committed by GitHub

Change sparse Copy from Kernel to basic component utils (#43916)

Parent 1dbbe20e
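For context, the practical effect of this commit is that a sparse tensor can now be copied through the same phi::Copy utility that already handles DenseTensor and SelectedRows, instead of dispatching the registered "copy_sparse_coo" / "copy_sparse_csr" kernels. A minimal, hypothetical sketch of the call pattern used throughout the diff below (CloneCooExample is not part of the commit):

#include "paddle/phi/core/sparse_coo_tensor.h"
#include "paddle/phi/core/tensor_utils.h"

// Hypothetical helper: deep-copies a SparseCooTensor onto the place the
// context runs on, using the new phi::Copy overload added by this commit.
template <typename Context>
void CloneCooExample(const Context& dev_ctx,
                     const phi::SparseCooTensor& src,
                     phi::SparseCooTensor* dst) {
  // Copies the indices and values tensors, the dims and the coalesced flag.
  phi::Copy(dev_ctx, src, dev_ctx.GetPlace(), /*blocking=*/false, dst);
}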
@@ -167,39 +167,21 @@ void Tensor::copy_(const Tensor &src,
                  blocking,
                  static_cast<phi::SelectedRows *>(impl_.get()));
   } else if (kernel_type == KernelType::SPARSE_COO_KERNEL) {
-    auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
-        "copy_sparse_coo", {kernel_backend, kernel_layout, kernel_data_type});
-    VLOG(6) << "copy API kernel key: " << kernel_key;
-    VLOG(6) << "copy API kernel: " << kernel;
-    using kernel_signature = void (*)(const platform::DeviceContext &,
-                                      const phi::SparseCooTensor &,
-                                      phi::Place,
-                                      bool,
-                                      phi::SparseCooTensor *);
-    this->set_impl(std::make_shared<phi::SparseCooTensor>());
-    auto *kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
-    (*kernel_fn)(*dev_ctx,
-                 (*(std::static_pointer_cast<phi::SparseCooTensor>(src.impl_))),
-                 target_place,
-                 blocking,
-                 static_cast<phi::SparseCooTensor *>(impl_.get()));
+    SetSparseKernelOutput(this, TensorType::SPARSE_COO);
+    // TODO(zhangkaihuo) add sparse infer_meta
+    phi::Copy(*dev_ctx,
+              (*(std::static_pointer_cast<phi::SparseCooTensor>(src.impl_))),
+              target_place,
+              blocking,
+              static_cast<phi::SparseCooTensor *>(impl_.get()));
   } else if (kernel_type == KernelType::SPARSE_CSR_KERNEL) {
-    auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
-        "copy_sparse_csr", {kernel_backend, kernel_layout, kernel_data_type});
-    VLOG(6) << "copy API kernel key: " << kernel_key;
-    VLOG(6) << "copy API kernel: " << kernel;
-    using kernel_signature = void (*)(const platform::DeviceContext &,
-                                      const phi::SparseCsrTensor &,
-                                      phi::Place,
-                                      bool,
-                                      phi::SparseCsrTensor *);
-    this->set_impl(std::make_shared<phi::SparseCsrTensor>());
-    auto *kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
-    (*kernel_fn)(*dev_ctx,
-                 (*(std::static_pointer_cast<phi::SparseCsrTensor>(src.impl_))),
-                 target_place,
-                 blocking,
-                 static_cast<phi::SparseCsrTensor *>(impl_.get()));
+    SetSparseKernelOutput(this, TensorType::SPARSE_CSR);
+    // TODO(zhangkaihuo) add sparse infer_meta
+    phi::Copy(*dev_ctx,
+              (*(std::static_pointer_cast<phi::SparseCsrTensor>(src.impl_))),
+              target_place,
+              blocking,
+              static_cast<phi::SparseCsrTensor *>(impl_.get()));
   } else {
     PADDLE_THROW(phi::errors::InvalidArgument(
         "We currently only support dense tensor copy for now and if u need to "
...
@@ -234,6 +234,53 @@ void Copy(const Context& dev_ctx,
       dev_ctx, src.value(), dst_place, blocking, dst->mutable_value());
 }
 
+template <typename Context>
+void Copy(const Context& dev_ctx,
+          const SparseCooTensor& src,
+          Place dst_place,
+          bool blocking,
+          SparseCooTensor* dst) {
+  phi::Copy<Context>(dev_ctx,
+                     src.non_zero_indices(),
+                     dst_place,
+                     blocking,
+                     dst->mutable_non_zero_indices());
+  phi::Copy<Context>(dev_ctx,
+                     src.non_zero_elements(),
+                     dst_place,
+                     blocking,
+                     dst->mutable_non_zero_elements());
+  dst->set_dims(src.dims());
+  dst->SetCoalesced(src.coalesced());
+}
+
+template <typename Context>
+void Copy(const Context& dev_ctx,
+          const SparseCsrTensor& src,
+          Place dst_place,
+          bool blocking,
+          SparseCsrTensor* dst) {
+  phi::Copy<Context>(dev_ctx,
+                     src.non_zero_crows(),
+                     dst_place,
+                     blocking,
+                     dst->mutable_non_zero_crows());
+  phi::Copy<Context>(dev_ctx,
+                     src.non_zero_cols(),
+                     dst_place,
+                     blocking,
+                     dst->mutable_non_zero_cols());
+  phi::Copy<Context>(dev_ctx,
+                     src.non_zero_elements(),
+                     dst_place,
+                     blocking,
+                     dst->mutable_non_zero_elements());
+  dst->set_dims(src.dims());
+}
+
 template void Copy(const CPUContext& dev_ctx,
                    const DenseTensor& src,
                    Place dst_place,
@@ -257,6 +304,30 @@ template void Copy(const DeviceContext& dev_ctx,
                    bool blocking,
                    SelectedRows* dst);
 
+template void Copy(const CPUContext& dev_ctx,
+                   const SparseCooTensor& src,
+                   Place dst_place,
+                   bool blocking,
+                   SparseCooTensor* dst);
+
+template void Copy(const DeviceContext& dev_ctx,
+                   const SparseCooTensor& src,
+                   Place dst_place,
+                   bool blocking,
+                   SparseCooTensor* dst);
+
+template void Copy(const CPUContext& dev_ctx,
+                   const SparseCsrTensor& src,
+                   Place dst_place,
+                   bool blocking,
+                   SparseCsrTensor* dst);
+
+template void Copy(const DeviceContext& dev_ctx,
+                   const SparseCsrTensor& src,
+                   Place dst_place,
+                   bool blocking,
+                   SparseCsrTensor* dst);
+
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
 template void Copy(const GPUContext& dev_ctx,
                    const DenseTensor& src,
@@ -268,6 +339,16 @@ template void Copy(const GPUContext& dev_ctx,
                    Place dst_place,
                    bool blocking,
                    SelectedRows* dst);
+template void Copy(const GPUContext& dev_ctx,
+                   const SparseCooTensor& src,
+                   Place dst_place,
+                   bool blocking,
+                   SparseCooTensor* dst);
+template void Copy(const GPUContext& dev_ctx,
+                   const SparseCsrTensor& src,
+                   Place dst_place,
+                   bool blocking,
+                   SparseCsrTensor* dst);
 #endif
 
 #ifdef PADDLE_WITH_XPU
...
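The new overloads copy a sparse tensor member by member (indices and values for COO, plus the coalesced flag; crows, cols and values for CSR) and are explicitly instantiated for CPUContext, DeviceContext and, under CUDA/HIP, GPUContext. Because every phi::Copy overload shares the same shape of signature, generic code can clone any supported tensor type with one template. A hypothetical sketch (CopyToPlaceExample is not part of the commit):

#include "paddle/phi/core/tensor_utils.h"

// Hypothetical helper: works for DenseTensor, SelectedRows, SparseCooTensor
// and SparseCsrTensor alike, because overload resolution picks the matching
// phi::Copy overload declared in tensor_utils.
template <typename Context, typename TensorT>
void CopyToPlaceExample(const Context& dev_ctx,
                        const TensorT& src,
                        const phi::Place& dst_place,
                        TensorT* dst) {
  // blocking=true waits for the copy to finish before returning.
  phi::Copy(dev_ctx, src, dst_place, /*blocking=*/true, dst);
}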
@@ -16,6 +16,8 @@ limitations under the License. */
 
 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/core/selected_rows.h"
+#include "paddle/phi/core/sparse_coo_tensor.h"
+#include "paddle/phi/core/sparse_csr_tensor.h"
 #include "paddle/phi/core/tensor_meta.h"
 
 namespace phi {
@@ -85,4 +87,18 @@ void Copy(const Context& dev_ctx,
           bool blocking,
           SelectedRows* dst);
 
+template <typename Context>
+void Copy(const Context& dev_ctx,
+          const SparseCooTensor& src,
+          Place dst_place,
+          bool blocking,
+          SparseCooTensor* dst);
+
+template <typename Context>
+void Copy(const Context& dev_ctx,
+          const SparseCsrTensor& src,
+          Place dst_place,
+          bool blocking,
+          SparseCsrTensor* dst);
+
 }  // namespace phi
deleted file: paddle/phi/kernels/sparse/copy_kernel.cc

/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/sparse/copy_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/sparse_coo_tensor.h"
#include "paddle/phi/core/sparse_csr_tensor.h"
#include "paddle/phi/core/tensor_utils.h"
namespace phi {
namespace sparse {
template <typename Context>
void CopyCoo(const Context& dev_ctx,
const SparseCooTensor& src,
Place dst_place,
bool blocking,
SparseCooTensor* dst) {
phi::Copy<Context>(dev_ctx,
src.non_zero_indices(),
dst_place,
blocking,
dst->mutable_non_zero_indices());
phi::Copy<Context>(dev_ctx,
src.non_zero_elements(),
dst_place,
blocking,
dst->mutable_non_zero_elements());
dst->set_dims(src.dims());
}
template <typename Context>
void CopyCsr(const Context& dev_ctx,
const SparseCsrTensor& src,
Place dst_place,
bool blocking,
SparseCsrTensor* dst) {
phi::Copy<Context>(dev_ctx,
src.non_zero_crows(),
dst_place,
blocking,
dst->mutable_non_zero_crows());
phi::Copy<Context>(dev_ctx,
src.non_zero_cols(),
dst_place,
blocking,
dst->mutable_non_zero_cols());
phi::Copy<Context>(dev_ctx,
src.non_zero_elements(),
dst_place,
blocking,
dst->mutable_non_zero_elements());
dst->set_dims(src.dims());
}
} // namespace sparse
} // namespace phi
PD_REGISTER_GENERAL_KERNEL(copy_sparse_coo,
CPU,
ALL_LAYOUT,
phi::sparse::CopyCoo<phi::CPUContext>,
ALL_DTYPE) {}
PD_REGISTER_GENERAL_KERNEL(copy_sparse_csr,
CPU,
ALL_LAYOUT,
phi::sparse::CopyCsr<phi::CPUContext>,
ALL_DTYPE) {}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_GENERAL_KERNEL(copy_sparse_coo,
GPU,
ALL_LAYOUT,
phi::sparse::CopyCoo<phi::GPUContext>,
ALL_DTYPE) {}
PD_REGISTER_GENERAL_KERNEL(copy_sparse_csr,
GPU,
ALL_LAYOUT,
phi::sparse::CopyCsr<phi::GPUContext>,
ALL_DTYPE) {}
#endif
deleted file: paddle/phi/kernels/sparse/copy_kernel.h

/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/sparse_coo_tensor.h"
#include "paddle/phi/core/sparse_csr_tensor.h"
#include "paddle/phi/kernels/empty_kernel.h"
namespace phi {
namespace sparse {
template <typename Context>
void CopyCoo(const Context& dev_ctx,
const SparseCooTensor& src,
Place dst_place,
bool blocking,
SparseCooTensor* dst);
template <typename Context>
void CopyCsr(const Context& dev_ctx,
const SparseCsrTensor& src,
Place dst_place,
bool blocking,
SparseCsrTensor* dst);
} // namespace sparse
} // namespace phi
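With the header above removed, code that previously included it switches to the core utility header, and anything that relied on it to pull in empty_kernel.h transitively now includes that directly; this is exactly the include change made in the elementwise gradient kernel below. A hedged sketch of the migration:

// Before this commit (removed):
//   #include "paddle/phi/kernels/sparse/copy_kernel.h"
// After this commit:
#include "paddle/phi/core/tensor_utils.h"     // provides the phi::Copy overloads
#include "paddle/phi/kernels/empty_kernel.h"  // previously reached transitively via copy_kernel.h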
@@ -16,11 +16,12 @@ limitations under the License. */
 
 #include "paddle/phi/core/enforce.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/core/tensor_meta.h"
+#include "paddle/phi/core/tensor_utils.h"
 #include "paddle/phi/core/visit_type.h"
 #include "paddle/phi/kernels/activation_kernel.h"
 #include "paddle/phi/kernels/elementwise_kernel.h"
+#include "paddle/phi/kernels/empty_kernel.h"
 #include "paddle/phi/kernels/funcs/eigen/common.h"
-#include "paddle/phi/kernels/sparse/copy_kernel.h"
 #include "paddle/phi/kernels/sparse/elementwise_kernel.h"
 
 namespace phi {
@@ -56,16 +57,16 @@ void ElementWiseAddCsrGradCPUKernel(const Context& dev_ctx,
   if (dx != nullptr && dy == nullptr) {
     VLOG(4) << "Special case when dy is not needed";
     AllocCsrPtr<T, IntT>(dev_ctx, x, dx);
-    CopyCsr(dev_ctx, dout, dev_ctx.GetPlace(), false, dx);
+    Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dx);
   } else if (dx == nullptr && dy != nullptr) {
     VLOG(4) << "Special case when dx is not needed";
     AllocCsrPtr<T, IntT>(dev_ctx, y, dy);
-    CopyCsr(dev_ctx, dout, dev_ctx.GetPlace(), false, dy);
+    Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dy);
   } else {
     AllocCsrPtr<T, IntT>(dev_ctx, x, dx);
     AllocCsrPtr<T, IntT>(dev_ctx, y, dy);
-    CopyCsr(dev_ctx, dout, dev_ctx.GetPlace(), false, dx);
-    CopyCsr(dev_ctx, dout, dev_ctx.GetPlace(), false, dy);
+    Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dx);
+    Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dy);
   }
 }
@@ -78,12 +79,12 @@ void ElementWiseSubtractCsrGradCPUKernel(const Context& dev_ctx,
                                          SparseCsrTensor* dy) {
   if (dx) {
     AllocCsrPtr<T, IntT>(dev_ctx, x, dx);
-    CopyCsr(dev_ctx, dout, dev_ctx.GetPlace(), false, dx);
+    Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dx);
   }
   if (dy) {
     AllocCsrPtr<T, IntT>(dev_ctx, y, dy);
-    CopyCsr(dev_ctx, dout, dev_ctx.GetPlace(), false, dy);
+    Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dy);
     phi::NegativeKernel<T, Context>(
         dev_ctx, dout.non_zero_elements(), dy->mutable_non_zero_elements());
   }
@@ -126,7 +127,7 @@ void ElementWiseDivideCsrGradCPUKernel(const Context& dev_ctx,
   if (dy) {
     // -dout * out / y
     AllocCsrPtr<T, IntT>(dev_ctx, y, dy);
-    CopyCsr(dev_ctx, dout, dev_ctx.GetPlace(), false, dy);
+    Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dy);
     phi::NegativeKernel<T, Context>(
         dev_ctx, dout.non_zero_elements(), dy->mutable_non_zero_elements());
     auto tmp = sparse::ElementWiseMultiplyCsr<T, Context>(dev_ctx, *dy, out);
@@ -145,16 +146,16 @@ void ElementWiseAddCooGradCPUKernel(const Context& dev_ctx,
   if (dx != nullptr && dy == nullptr) {
     VLOG(4) << "Special case when dy is not needed";
     AllocCooPtr<T, IntT>(dev_ctx, x, dx);
-    CopyCoo(dev_ctx, dout, dev_ctx.GetPlace(), false, dx);
+    Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dx);
   } else if (dx == nullptr && dy != nullptr) {
     VLOG(4) << "Special case when dx is not needed";
     AllocCooPtr<T, IntT>(dev_ctx, y, dy);
-    CopyCoo(dev_ctx, dout, dev_ctx.GetPlace(), false, dy);
+    Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dy);
   } else {
     AllocCooPtr<T, IntT>(dev_ctx, x, dx);
     AllocCooPtr<T, IntT>(dev_ctx, y, dy);
-    CopyCoo(dev_ctx, dout, dev_ctx.GetPlace(), false, dx);
-    CopyCoo(dev_ctx, dout, dev_ctx.GetPlace(), false, dy);
+    Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dx);
+    Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dy);
   }
 }
@@ -167,12 +168,12 @@ void ElementWiseSubtractCooGradCPUKernel(const Context& dev_ctx,
                                          SparseCooTensor* dy) {
   if (dx) {
     AllocCooPtr<T, IntT>(dev_ctx, x, dx);
-    CopyCoo(dev_ctx, dout, dev_ctx.GetPlace(), false, dx);
+    Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dx);
   }
   if (dy) {
     AllocCooPtr<T, IntT>(dev_ctx, y, dy);
-    CopyCoo(dev_ctx, dout, dev_ctx.GetPlace(), false, dy);
+    Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dy);
     phi::NegativeKernel<T, Context>(
         dev_ctx, dout.non_zero_elements(), dy->mutable_non_zero_elements());
   }
@@ -215,7 +216,7 @@ void ElementWiseDivideCooGradCPUKernel(const Context& dev_ctx,
   if (dy) {
     // -dout * out / y
     AllocCooPtr<T, IntT>(dev_ctx, y, dy);
-    CopyCoo(dev_ctx, dout, dev_ctx.GetPlace(), false, dy);
+    Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dy);
     phi::NegativeKernel<T, Context>(
         dev_ctx, dout.non_zero_elements(), dy->mutable_non_zero_elements());
     auto tmp = sparse::ElementWiseMultiplyCoo<T, Context>(dev_ctx, *dy, out);
...
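At the call sites the only change is that CopyCsr and CopyCoo become Copy; overload resolution selects the COO or CSR implementation, so the gradient kernels read identically for both formats. A hypothetical condensation of the pattern above (CopyGradExample is not in the commit):

// Hypothetical helper summarizing the call-site pattern: clone dout into the
// requested gradients; works for SparseCooTensor and SparseCsrTensor alike.
// (The real kernels first allocate *dx / *dy with AllocCooPtr / AllocCsrPtr.)
template <typename Context, typename SparseT>
void CopyGradExample(const Context& dev_ctx,
                     const SparseT& dout,
                     SparseT* dx,
                     SparseT* dy) {
  if (dx) phi::Copy(dev_ctx, dout, dev_ctx.GetPlace(), /*blocking=*/false, dx);
  if (dy) phi::Copy(dev_ctx, dout, dev_ctx.GetPlace(), /*blocking=*/false, dy);
}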