Unverified commit 148fa05e, authored by zhangkaihuo and committed by GitHub

Change sparse Copy from Kernel to basic component utils (#43916)

Parent 1dbbe20e
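In short, Tensor::copy_ no longer selects the copy_sparse_coo / copy_sparse_csr kernels through phi::KernelFactory; the sparse copy logic now lives in phi::Copy overloads in the basic tensor_utils component, which callers invoke directly. A minimal caller-side sketch of the new path (the wrapper function and the CPU context here are illustrative, not part of this diff):

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/sparse_coo_tensor.h"
#include "paddle/phi/core/tensor_utils.h"

// Sketch: copy a SparseCooTensor via the phi::Copy overload added in this change.
// It copies non_zero_indices() and non_zero_elements(), then propagates dims
// and the coalesced flag.
void CopyCooExample(const phi::CPUContext& dev_ctx,
                    const phi::SparseCooTensor& src,
                    phi::SparseCooTensor* dst) {
  phi::Copy(dev_ctx, src, dev_ctx.GetPlace(), /*blocking=*/false, dst);
}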
@@ -167,39 +167,21 @@ void Tensor::copy_(const Tensor &src,
blocking,
static_cast<phi::SelectedRows *>(impl_.get()));
} else if (kernel_type == KernelType::SPARSE_COO_KERNEL) {
auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
"copy_sparse_coo", {kernel_backend, kernel_layout, kernel_data_type});
VLOG(6) << "copy API kernel key: " << kernel_key;
VLOG(6) << "copy API kernel: " << kernel;
using kernel_signature = void (*)(const platform::DeviceContext &,
const phi::SparseCooTensor &,
phi::Place,
bool,
phi::SparseCooTensor *);
this->set_impl(std::make_shared<phi::SparseCooTensor>());
auto *kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
(*kernel_fn)(*dev_ctx,
(*(std::static_pointer_cast<phi::SparseCooTensor>(src.impl_))),
target_place,
blocking,
static_cast<phi::SparseCooTensor *>(impl_.get()));
SetSparseKernelOutput(this, TensorType::SPARSE_COO);
// TODO(zhangkaihuo) add sparse infer_meta
phi::Copy(*dev_ctx,
(*(std::static_pointer_cast<phi::SparseCooTensor>(src.impl_))),
target_place,
blocking,
static_cast<phi::SparseCooTensor *>(impl_.get()));
} else if (kernel_type == KernelType::SPARSE_CSR_KERNEL) {
auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
"copy_sparse_csr", {kernel_backend, kernel_layout, kernel_data_type});
VLOG(6) << "copy API kernel key: " << kernel_key;
VLOG(6) << "copy API kernel: " << kernel;
using kernel_signature = void (*)(const platform::DeviceContext &,
const phi::SparseCsrTensor &,
phi::Place,
bool,
phi::SparseCsrTensor *);
this->set_impl(std::make_shared<phi::SparseCsrTensor>());
auto *kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
(*kernel_fn)(*dev_ctx,
(*(std::static_pointer_cast<phi::SparseCsrTensor>(src.impl_))),
target_place,
blocking,
static_cast<phi::SparseCsrTensor *>(impl_.get()));
SetSparseKernelOutput(this, TensorType::SPARSE_CSR);
// TODO(zhangkaihuo) add sparse infer_meta
phi::Copy(*dev_ctx,
(*(std::static_pointer_cast<phi::SparseCsrTensor>(src.impl_))),
target_place,
blocking,
static_cast<phi::SparseCsrTensor *>(impl_.get()));
} else {
PADDLE_THROW(phi::errors::InvalidArgument(
"We currently only support dense tensor copy for now and if u need to "
......
@@ -234,6 +234,53 @@ void Copy(const Context& dev_ctx,
dev_ctx, src.value(), dst_place, blocking, dst->mutable_value());
}
template <typename Context>
void Copy(const Context& dev_ctx,
const SparseCooTensor& src,
Place dst_place,
bool blocking,
SparseCooTensor* dst) {
phi::Copy<Context>(dev_ctx,
src.non_zero_indices(),
dst_place,
blocking,
dst->mutable_non_zero_indices());
phi::Copy<Context>(dev_ctx,
src.non_zero_elements(),
dst_place,
blocking,
dst->mutable_non_zero_elements());
dst->set_dims(src.dims());
dst->SetCoalesced(src.coalesced());
}
template <typename Context>
void Copy(const Context& dev_ctx,
const SparseCsrTensor& src,
Place dst_place,
bool blocking,
SparseCsrTensor* dst) {
phi::Copy<Context>(dev_ctx,
src.non_zero_crows(),
dst_place,
blocking,
dst->mutable_non_zero_crows());
phi::Copy<Context>(dev_ctx,
src.non_zero_cols(),
dst_place,
blocking,
dst->mutable_non_zero_cols());
phi::Copy<Context>(dev_ctx,
src.non_zero_elements(),
dst_place,
blocking,
dst->mutable_non_zero_elements());
dst->set_dims(src.dims());
}
template void Copy(const CPUContext& dev_ctx,
const DenseTensor& src,
Place dst_place,
@@ -257,6 +304,30 @@ template void Copy(const DeviceContext& dev_ctx,
bool blocking,
SelectedRows* dst);
template void Copy(const CPUContext& dev_ctx,
const SparseCooTensor& src,
Place dst_place,
bool blocking,
SparseCooTensor* dst);
template void Copy(const DeviceContext& dev_ctx,
const SparseCooTensor& src,
Place dst_place,
bool blocking,
SparseCooTensor* dst);
template void Copy(const CPUContext& dev_ctx,
const SparseCsrTensor& src,
Place dst_place,
bool blocking,
SparseCsrTensor* dst);
template void Copy(const DeviceContext& dev_ctx,
const SparseCsrTensor& src,
Place dst_place,
bool blocking,
SparseCsrTensor* dst);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
template void Copy(const GPUContext& dev_ctx,
const DenseTensor& src,
@@ -268,6 +339,16 @@ template void Copy(const GPUContext& dev_ctx,
Place dst_place,
bool blocking,
SelectedRows* dst);
template void Copy(const GPUContext& dev_ctx,
const SparseCooTensor& src,
Place dst_place,
bool blocking,
SparseCooTensor* dst);
template void Copy(const GPUContext& dev_ctx,
const SparseCsrTensor& src,
Place dst_place,
bool blocking,
SparseCsrTensor* dst);
#endif
#ifdef PADDLE_WITH_XPU
......
@@ -16,6 +16,8 @@ limitations under the License. */
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/selected_rows.h"
#include "paddle/phi/core/sparse_coo_tensor.h"
#include "paddle/phi/core/sparse_csr_tensor.h"
#include "paddle/phi/core/tensor_meta.h"
namespace phi {
@@ -85,4 +87,18 @@ void Copy(const Context& dev_ctx,
bool blocking,
SelectedRows* dst);
template <typename Context>
void Copy(const Context& dev_ctx,
const SparseCooTensor& src,
Place dst_place,
bool blocking,
SparseCooTensor* dst);
template <typename Context>
void Copy(const Context& dev_ctx,
const SparseCsrTensor& src,
Place dst_place,
bool blocking,
SparseCsrTensor* dst);
} // namespace phi
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/sparse/copy_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/sparse_coo_tensor.h"
#include "paddle/phi/core/sparse_csr_tensor.h"
#include "paddle/phi/core/tensor_utils.h"
namespace phi {
namespace sparse {
template <typename Context>
void CopyCoo(const Context& dev_ctx,
const SparseCooTensor& src,
Place dst_place,
bool blocking,
SparseCooTensor* dst) {
phi::Copy<Context>(dev_ctx,
src.non_zero_indices(),
dst_place,
blocking,
dst->mutable_non_zero_indices());
phi::Copy<Context>(dev_ctx,
src.non_zero_elements(),
dst_place,
blocking,
dst->mutable_non_zero_elements());
dst->set_dims(src.dims());
}
template <typename Context>
void CopyCsr(const Context& dev_ctx,
const SparseCsrTensor& src,
Place dst_place,
bool blocking,
SparseCsrTensor* dst) {
phi::Copy<Context>(dev_ctx,
src.non_zero_crows(),
dst_place,
blocking,
dst->mutable_non_zero_crows());
phi::Copy<Context>(dev_ctx,
src.non_zero_cols(),
dst_place,
blocking,
dst->mutable_non_zero_cols());
phi::Copy<Context>(dev_ctx,
src.non_zero_elements(),
dst_place,
blocking,
dst->mutable_non_zero_elements());
dst->set_dims(src.dims());
}
} // namespace sparse
} // namespace phi
PD_REGISTER_GENERAL_KERNEL(copy_sparse_coo,
CPU,
ALL_LAYOUT,
phi::sparse::CopyCoo<phi::CPUContext>,
ALL_DTYPE) {}
PD_REGISTER_GENERAL_KERNEL(copy_sparse_csr,
CPU,
ALL_LAYOUT,
phi::sparse::CopyCsr<phi::CPUContext>,
ALL_DTYPE) {}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_GENERAL_KERNEL(copy_sparse_coo,
GPU,
ALL_LAYOUT,
phi::sparse::CopyCoo<phi::GPUContext>,
ALL_DTYPE) {}
PD_REGISTER_GENERAL_KERNEL(copy_sparse_csr,
GPU,
ALL_LAYOUT,
phi::sparse::CopyCsr<phi::GPUContext>,
ALL_DTYPE) {}
#endif
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/sparse_coo_tensor.h"
#include "paddle/phi/core/sparse_csr_tensor.h"
#include "paddle/phi/kernels/empty_kernel.h"
namespace phi {
namespace sparse {
template <typename Context>
void CopyCoo(const Context& dev_ctx,
const SparseCooTensor& src,
Place dst_place,
bool blocking,
SparseCooTensor* dst);
template <typename Context>
void CopyCsr(const Context& dev_ctx,
const SparseCsrTensor& src,
Place dst_place,
bool blocking,
SparseCsrTensor* dst);
} // namespace sparse
} // namespace phi
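Call sites that previously used sparse::CopyCoo / sparse::CopyCsr, such as the CPU element-wise grad kernels in the hunks below, now call phi::Copy directly. A hedged sketch of the CSR call pattern (the wrapper is illustrative; in the real kernels dx is first allocated with AllocCsrPtr):

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/sparse_csr_tensor.h"
#include "paddle/phi/core/tensor_utils.h"

// Sketch: what the grad kernels below now do with dout.
// phi::Copy copies crows, cols and values of dout into dx, then sets dims.
void CopyCsrGradExample(const phi::CPUContext& dev_ctx,
                        const phi::SparseCsrTensor& dout,
                        phi::SparseCsrTensor* dx) {
  phi::Copy(dev_ctx, dout, dev_ctx.GetPlace(), /*blocking=*/false, dx);
}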
@@ -16,11 +16,12 @@ limitations under the License. */
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_meta.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/core/visit_type.h"
#include "paddle/phi/kernels/activation_kernel.h"
#include "paddle/phi/kernels/elementwise_kernel.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/sparse/copy_kernel.h"
#include "paddle/phi/kernels/sparse/elementwise_kernel.h"
namespace phi {
@@ -56,16 +57,16 @@ void ElementWiseAddCsrGradCPUKernel(const Context& dev_ctx,
if (dx != nullptr && dy == nullptr) {
VLOG(4) << "Special case when dy is not needed";
AllocCsrPtr<T, IntT>(dev_ctx, x, dx);
CopyCsr(dev_ctx, dout, dev_ctx.GetPlace(), false, dx);
Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dx);
} else if (dx == nullptr && dy != nullptr) {
VLOG(4) << "Special case when dx is not needed";
AllocCsrPtr<T, IntT>(dev_ctx, y, dy);
CopyCsr(dev_ctx, dout, dev_ctx.GetPlace(), false, dy);
Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dy);
} else {
AllocCsrPtr<T, IntT>(dev_ctx, x, dx);
AllocCsrPtr<T, IntT>(dev_ctx, y, dy);
CopyCsr(dev_ctx, dout, dev_ctx.GetPlace(), false, dx);
CopyCsr(dev_ctx, dout, dev_ctx.GetPlace(), false, dy);
Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dx);
Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dy);
}
}
@@ -78,12 +79,12 @@ void ElementWiseSubtractCsrGradCPUKernel(const Context& dev_ctx,
SparseCsrTensor* dy) {
if (dx) {
AllocCsrPtr<T, IntT>(dev_ctx, x, dx);
CopyCsr(dev_ctx, dout, dev_ctx.GetPlace(), false, dx);
Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dx);
}
if (dy) {
AllocCsrPtr<T, IntT>(dev_ctx, y, dy);
CopyCsr(dev_ctx, dout, dev_ctx.GetPlace(), false, dy);
Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dy);
phi::NegativeKernel<T, Context>(
dev_ctx, dout.non_zero_elements(), dy->mutable_non_zero_elements());
}
@@ -126,7 +127,7 @@ void ElementWiseDivideCsrGradCPUKernel(const Context& dev_ctx,
if (dy) {
// -dout * out / y
AllocCsrPtr<T, IntT>(dev_ctx, y, dy);
CopyCsr(dev_ctx, dout, dev_ctx.GetPlace(), false, dy);
Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dy);
phi::NegativeKernel<T, Context>(
dev_ctx, dout.non_zero_elements(), dy->mutable_non_zero_elements());
auto tmp = sparse::ElementWiseMultiplyCsr<T, Context>(dev_ctx, *dy, out);
@@ -145,16 +146,16 @@ void ElementWiseAddCooGradCPUKernel(const Context& dev_ctx,
if (dx != nullptr && dy == nullptr) {
VLOG(4) << "Special case when dy is not needed";
AllocCooPtr<T, IntT>(dev_ctx, x, dx);
CopyCoo(dev_ctx, dout, dev_ctx.GetPlace(), false, dx);
Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dx);
} else if (dx == nullptr && dy != nullptr) {
VLOG(4) << "Special case when dx is not needed";
AllocCooPtr<T, IntT>(dev_ctx, y, dy);
CopyCoo(dev_ctx, dout, dev_ctx.GetPlace(), false, dy);
Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dy);
} else {
AllocCooPtr<T, IntT>(dev_ctx, x, dx);
AllocCooPtr<T, IntT>(dev_ctx, y, dy);
CopyCoo(dev_ctx, dout, dev_ctx.GetPlace(), false, dx);
CopyCoo(dev_ctx, dout, dev_ctx.GetPlace(), false, dy);
Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dx);
Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dy);
}
}
@@ -167,12 +168,12 @@ void ElementWiseSubtractCooGradCPUKernel(const Context& dev_ctx,
SparseCooTensor* dy) {
if (dx) {
AllocCooPtr<T, IntT>(dev_ctx, x, dx);
CopyCoo(dev_ctx, dout, dev_ctx.GetPlace(), false, dx);
Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dx);
}
if (dy) {
AllocCooPtr<T, IntT>(dev_ctx, y, dy);
CopyCoo(dev_ctx, dout, dev_ctx.GetPlace(), false, dy);
Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dy);
phi::NegativeKernel<T, Context>(
dev_ctx, dout.non_zero_elements(), dy->mutable_non_zero_elements());
}
@@ -215,7 +216,7 @@ void ElementWiseDivideCooGradCPUKernel(const Context& dev_ctx,
if (dy) {
// -dout * out / y
AllocCooPtr<T, IntT>(dev_ctx, y, dy);
CopyCoo(dev_ctx, dout, dev_ctx.GetPlace(), false, dy);
Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dy);
phi::NegativeKernel<T, Context>(
dev_ctx, dout.non_zero_elements(), dy->mutable_non_zero_elements());
auto tmp = sparse::ElementWiseMultiplyCoo<T, Context>(dev_ctx, *dy, out);
......