From fd8ec4a1304b6bfbd1a5f35a214f67c31ccd3d65 Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Tue, 20 Sep 2022 14:57:21 +0800 Subject: [PATCH] [Cherry-pick] Sparse add InferMeta (#46235) cherry-pick : #46016, #46021, #45974 * [Sparse]Sparse add support gpu (#45974) * [Sparse]Remove unused code (#46021) * [Sparse] Add infer meta (#46016) --- paddle/phi/api/lib/CMakeLists.txt | 12 +- paddle/phi/api/lib/api_gen_utils.cc | 20 ++ paddle/phi/api/lib/api_gen_utils.h | 8 + paddle/phi/api/lib/sparse_api_custom_impl.cc | 202 ------------------ paddle/phi/api/lib/tensor_method.cc | 13 +- .../yaml/generator/intermediate_api_gen.py | 5 +- .../phi/api/yaml/generator/sparse_api_gen.py | 56 ++++- .../api/yaml/generator/sparse_bw_api_gen.py | 9 +- paddle/phi/api/yaml/sparse_backward.yaml | 109 ++++++++++ paddle/phi/api/yaml/sparse_ops.yaml | 93 ++++++++ paddle/phi/core/meta_tensor.cc | 35 ++- paddle/phi/core/sparse_coo_tensor.cc | 61 +++++- paddle/phi/core/sparse_coo_tensor.h | 40 +++- paddle/phi/core/sparse_csr_tensor.cc | 52 ++++- paddle/phi/core/sparse_csr_tensor.h | 41 +++- paddle/phi/core/tensor_meta.cc | 12 ++ paddle/phi/core/tensor_meta.h | 20 ++ paddle/phi/core/tensor_utils.cc | 2 +- paddle/phi/core/tensor_utils.h | 8 + paddle/phi/infermeta/CMakeLists.txt | 1 + paddle/phi/infermeta/sparse/CMakeLists.txt | 9 + paddle/phi/infermeta/sparse/backward.cc | 35 +++ paddle/phi/infermeta/sparse/backward.h | 33 +++ paddle/phi/infermeta/sparse/binary.cc | 147 +++++++++++++ paddle/phi/infermeta/sparse/binary.h | 52 +++++ paddle/phi/infermeta/sparse/multiary.cc | 32 +++ paddle/phi/infermeta/sparse/multiary.h | 32 +++ paddle/phi/infermeta/sparse/unary.cc | 36 ++++ .../sparse/unary.h} | 16 +- paddle/phi/kernels/CMakeLists.txt | 3 +- .../sparse/cpu/elementwise_grad_kernel.cc | 7 +- .../kernels/sparse/cpu/elementwise_kernel.cc | 33 ++- .../kernels/sparse/cpu/sparse_utils_kernel.cc | 38 ++-- .../kernels/sparse/elementwise_grad_kernel.h | 11 + .../phi/kernels/sparse/elementwise_kernel.h | 13 
+- paddle/phi/kernels/sparse/empty_kernel.cc | 10 +- .../sparse/gpu/elementwise_grad_kernel.cu | 56 +++++ .../kernels/sparse/gpu/elementwise_kernel.cu | 88 ++++++++ .../kernels/sparse/gpu/sparse_utils_kernel.cu | 20 +- .../kernels/sparse/impl/unary_kernel_impl.h | 4 - .../phi/kernels/sparse/sparse_utils_kernel.h | 21 ++ .../phi/tests/core/test_sparse_coo_tensor.cc | 1 - .../phi/tests/core/test_sparse_csr_tensor.cc | 1 - .../unittests/test_sparse_elementwise_op.py | 31 ++- 44 files changed, 1200 insertions(+), 328 deletions(-) delete mode 100644 paddle/phi/api/lib/sparse_api_custom_impl.cc create mode 100644 paddle/phi/infermeta/sparse/CMakeLists.txt create mode 100644 paddle/phi/infermeta/sparse/backward.cc create mode 100644 paddle/phi/infermeta/sparse/backward.h create mode 100644 paddle/phi/infermeta/sparse/binary.cc create mode 100644 paddle/phi/infermeta/sparse/binary.h create mode 100644 paddle/phi/infermeta/sparse/multiary.cc create mode 100644 paddle/phi/infermeta/sparse/multiary.h create mode 100644 paddle/phi/infermeta/sparse/unary.cc rename paddle/phi/{api/lib/sparse_api_custom_impl.h => infermeta/sparse/unary.h} (66%) create mode 100644 paddle/phi/kernels/sparse/gpu/elementwise_grad_kernel.cu create mode 100644 paddle/phi/kernels/sparse/gpu/elementwise_kernel.cu diff --git a/paddle/phi/api/lib/CMakeLists.txt b/paddle/phi/api/lib/CMakeLists.txt index 957d43b4623..3795060d24b 100644 --- a/paddle/phi/api/lib/CMakeLists.txt +++ b/paddle/phi/api/lib/CMakeLists.txt @@ -370,12 +370,6 @@ cc_library( SRCS api_custom_impl.cc DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils backward_infermeta phi_data_transform) -cc_library( - sparse_api_custom_impl - SRCS sparse_api_custom_impl.cc - DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils phi_data_transform - tensor_copy) - cc_library( phi_function_api SRCS ${api_source_file} @@ -389,6 +383,7 @@ cc_library( kernel_dispatch api_gen_utils backward_infermeta + sparse_backward_infermeta phi_data_transform 
phi_function_api api_custom_impl @@ -396,12 +391,12 @@ cc_library( cc_library( sparse_api SRCS ${sparse_api_source_file} - DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils sparse_api_custom_impl) + DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils) cc_library( sparse_bw_api SRCS ${sparse_bw_api_source_file} DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils sparse_api - sparse_api_custom_impl) + sparse_backward_infermeta) cc_library( phi_dygraph_api SRCS ${dygraph_api_source_file} @@ -424,6 +419,7 @@ cc_library( api_gen_utils kernel_dispatch infermeta + sparse_infermeta sparse_api strings_api) cc_library( diff --git a/paddle/phi/api/lib/api_gen_utils.cc b/paddle/phi/api/lib/api_gen_utils.cc index cbcc475a0df..39f9fa93918 100644 --- a/paddle/phi/api/lib/api_gen_utils.cc +++ b/paddle/phi/api/lib/api_gen_utils.cc @@ -88,6 +88,10 @@ std::shared_ptr TensorToStringTensor(const Tensor& tensor) { return std::dynamic_pointer_cast(tensor.impl()); } +std::shared_ptr TensorToSparseCooTensor( + const Tensor& tensor) { + return std::static_pointer_cast(tensor.impl()); +} /* ----------------- for infer_meta --------------------- */ phi::MetaTensor MakeMetaTensor(const phi::TensorBase& tensor) { @@ -130,6 +134,22 @@ phi::MetaTensor MakeMetaTensor( return phi::MetaTensor(); } +phi::MetaTensor MakeMetaTensor( + const paddle::optional& tensor) { + if (tensor) { + return {phi::MetaTensor(*tensor)}; + } + return phi::MetaTensor(); +} + +phi::MetaTensor MakeMetaTensor( + const paddle::optional& tensor) { + if (tensor) { + return {phi::MetaTensor(*tensor)}; + } + return phi::MetaTensor(); +} + std::vector MakeMetaTensor( const paddle::optional>& tensors) { std::vector meta_tensors; diff --git a/paddle/phi/api/lib/api_gen_utils.h b/paddle/phi/api/lib/api_gen_utils.h index 2a7f283dabb..797fcd72973 100644 --- a/paddle/phi/api/lib/api_gen_utils.h +++ b/paddle/phi/api/lib/api_gen_utils.h @@ -52,6 +52,8 @@ paddle::optional TensorToSelectedRows( std::shared_ptr 
TensorToStringTensor(const Tensor& tensor); +std::shared_ptr TensorToSparseCooTensor( + const Tensor& tensor); /* ----------------- for infer_meta --------------------- */ phi::MetaTensor MakeMetaTensor(const phi::TensorBase& tensor); @@ -68,6 +70,12 @@ std::vector MakeMetaTensor( phi::MetaTensor MakeMetaTensor( const paddle::optional& tensor); +phi::MetaTensor MakeMetaTensor( + const paddle::optional& tensor); + +phi::MetaTensor MakeMetaTensor( + const paddle::optional& tensor); + std::vector MakeMetaTensor( const paddle::optional>& tensors); diff --git a/paddle/phi/api/lib/sparse_api_custom_impl.cc b/paddle/phi/api/lib/sparse_api_custom_impl.cc deleted file mode 100644 index 6aaf21a5e7f..00000000000 --- a/paddle/phi/api/lib/sparse_api_custom_impl.cc +++ /dev/null @@ -1,202 +0,0 @@ -/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/phi/api/lib/sparse_api_custom_impl.h" - -#include - -#include "glog/logging.h" -#include "paddle/phi/api/lib/kernel_dispatch.h" -#include "paddle/phi/core/kernel_registry.h" - -namespace paddle { -namespace experimental { -namespace sparse { - -Tensor to_sparse_coo_impl(const Tensor& x, const int64_t sparse_dim) { - if (x.layout() == phi::DataLayout::SPARSE_COO) { - return x; - } - - // 1. 
Get kernel signature and kernel - std::string kernel_name = "dense_to_coo"; - if (x.layout() == phi::DataLayout::SPARSE_CSR) { - kernel_name = "csr_to_coo"; - } - - auto kernel_key_set = ParseKernelKeyByInputArgs(x); - auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey(); - - auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError( - kernel_name, kernel_key); - const auto& kernel = kernel_result.kernel; - - VLOG(6) << "add API kernel key: " << kernel_key; - VLOG(6) << "to API kernel: " << kernel; - - // 2. Get Device Context - auto* dev_ctx = GetDeviceContextByBackend(kernel_key.backend()); - auto kernel_context = phi::KernelContext(dev_ctx); - - // 3. Auto data transform - if (x.layout() == phi::DataLayout::SPARSE_CSR) { - auto input = std::dynamic_pointer_cast(x.impl()); - kernel_context.EmplaceBackInput(input.get()); - } else { - auto input = std::dynamic_pointer_cast(x.impl()); - kernel_context.EmplaceBackInput(input.get()); - kernel_context.EmplaceBackAttr(sparse_dim); - } - - // 4. InferMeta - auto indices_meta = - phi::DenseTensorMeta(phi::DataType::INT64, {1}, phi::DataLayout::NCHW); - auto elements_meta = phi::DenseTensorMeta(x.dtype(), {1}, x.layout()); - - // 5. Prepare outputs - // create empty SparseCooTensor - phi::DenseTensor non_zero_indices(std::make_shared(), - std::move(indices_meta)); - phi::DenseTensor non_zero_elements(std::make_shared(), - std::move(elements_meta)); - auto coo = std::make_shared( - non_zero_indices, non_zero_elements, x.dims()); - - kernel_context.EmplaceBackOutput(coo.get()); - Tensor out; - out.set_impl(coo); - - // 6. Call kernel - kernel(&kernel_context); - - return out; -} - -Tensor to_sparse_csr_impl(const Tensor& x) { - if (x.layout() == phi::DataLayout::SPARSE_CSR) { - return x; - } - // 1. 
Get kernel signature and kernel - std::string kernel_name = "dense_to_csr"; - if (x.layout() == phi::DataLayout::SPARSE_COO) { - kernel_name = "coo_to_csr"; - } - - auto kernel_key_set = ParseKernelKeyByInputArgs(x); - auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey(); - - auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError( - kernel_name, kernel_key); - const auto& kernel = kernel_result.kernel; - - VLOG(6) << "add API kernel key: " << kernel_key; - VLOG(6) << "to API kernel: " << kernel; - - // 2. Get Device Context - auto* dev_ctx = GetDeviceContextByBackend(kernel_key.backend()); - auto kernel_context = phi::KernelContext(dev_ctx); - - // 3. Auto data transform - if (x.layout() == phi::DataLayout::SPARSE_COO) { - auto input = std::dynamic_pointer_cast(x.impl()); - kernel_context.EmplaceBackInput(input.get()); - } else { - auto input = std::dynamic_pointer_cast(x.impl()); - kernel_context.EmplaceBackInput(input.get()); - } - - // 4. InferMeta - auto crows_meta = - phi::DenseTensorMeta(phi::DataType::INT64, {1}, phi::DataLayout::NCHW); - auto cols_meta = - phi::DenseTensorMeta(phi::DataType::INT64, {1}, phi::DataLayout::NCHW); - auto elements_meta = phi::DenseTensorMeta(x.dtype(), {1}, x.layout()); - - // 5. Prepare outputs - // create empty SparseCooTensor - phi::DenseTensor non_zero_crows(std::make_shared(), - std::move(crows_meta)); - phi::DenseTensor non_zero_cols(std::make_shared(), - std::move(cols_meta)); - phi::DenseTensor non_zero_elements(std::make_shared(), - std::move(elements_meta)); - auto csr = std::make_shared( - non_zero_crows, non_zero_cols, non_zero_elements, x.dims()); - - kernel_context.EmplaceBackOutput(csr.get()); - Tensor out; - out.set_impl(csr); - - // 6. Call kernel - kernel(&kernel_context); - - return out; -} - -Tensor to_dense_impl(const Tensor& x) { - if (x.layout() != phi::DataLayout::SPARSE_CSR && - x.layout() != phi::DataLayout::SPARSE_COO) { - return x; - } - - // 1. 
Get kernel signature and kernel - std::string kernel_name = "coo_to_dense"; - if (x.layout() == phi::DataLayout::SPARSE_CSR) { - kernel_name = "csr_to_dense"; - } - - auto kernel_key_set = ParseKernelKeyByInputArgs(x); - auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey(); - - auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError( - kernel_name, kernel_key); - const auto& kernel = kernel_result.kernel; - - VLOG(6) << "add API kernel key: " << kernel_key; - VLOG(6) << "to API kernel: " << kernel; - - // 2. Get Device Context - auto* dev_ctx = GetDeviceContextByBackend(kernel_key.backend()); - auto kernel_context = phi::KernelContext(dev_ctx); - - // 3. Auto data transform - if (x.layout() == phi::DataLayout::SPARSE_COO) { - auto input = std::dynamic_pointer_cast(x.impl()); - kernel_context.EmplaceBackInput(input.get()); - } else { - auto input = std::dynamic_pointer_cast(x.impl()); - kernel_context.EmplaceBackInput(input.get()); - } - - // 4. InferMeta - auto dense_meta = phi::DenseTensorMeta(x.dtype(), x.dims(), x.layout()); - - // 5. Prepare outputs - // create empty SparseCooTensor - auto dense_out = std::make_shared( - std::make_shared(), std::move(dense_meta)); - - kernel_context.EmplaceBackOutput(dense_out.get()); - Tensor out; - out.set_impl(dense_out); - - // 6. Call kernel - kernel(&kernel_context); - - return out; -} - -} // namespace sparse -} // namespace experimental -} // namespace paddle diff --git a/paddle/phi/api/lib/tensor_method.cc b/paddle/phi/api/lib/tensor_method.cc index 96f9aefbb1f..312f52fa5e6 100644 --- a/paddle/phi/api/lib/tensor_method.cc +++ b/paddle/phi/api/lib/tensor_method.cc @@ -24,6 +24,7 @@ limitations under the License. 
*/ #include "paddle/phi/api/lib/kernel_dispatch.h" #include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/infermeta/unary.h" +// clang-format off namespace paddle { namespace experimental { @@ -165,7 +166,11 @@ void Tensor::copy_(const Tensor &src, static_cast(impl_.get())); } else if (kernel_type == KernelType::SPARSE_COO_KERNEL) { SetSparseKernelOutput(this, TensorType::SPARSE_COO); - // TODO(zhangkaihuo) add sparse infer_meta + phi::MetaTensor meta_out(impl_.get()); + phi::UnchangedInferMeta( + MakeMetaTensor( + *(std::static_pointer_cast(src.impl_))), + &meta_out); phi::Copy(*dev_ctx, (*(std::static_pointer_cast(src.impl_))), target_place, @@ -173,7 +178,11 @@ void Tensor::copy_(const Tensor &src, static_cast(impl_.get())); } else if (kernel_type == KernelType::SPARSE_CSR_KERNEL) { SetSparseKernelOutput(this, TensorType::SPARSE_CSR); - // TODO(zhangkaihuo) add sparse infer_meta + phi::MetaTensor meta_out(impl_.get()); + phi::UnchangedInferMeta( + MakeMetaTensor( + *(std::static_pointer_cast(src.impl_))), + &meta_out); phi::Copy(*dev_ctx, (*(std::static_pointer_cast(src.impl_))), target_place, diff --git a/paddle/phi/api/yaml/generator/intermediate_api_gen.py b/paddle/phi/api/yaml/generator/intermediate_api_gen.py index 8bec3e8c158..ce615dcb248 100644 --- a/paddle/phi/api/yaml/generator/intermediate_api_gen.py +++ b/paddle/phi/api/yaml/generator/intermediate_api_gen.py @@ -43,7 +43,6 @@ def source_include(header_file_path): #include "paddle/phi/api/lib/api_gen_utils.h" #include "paddle/phi/api/lib/data_transform.h" #include "paddle/phi/api/lib/kernel_dispatch.h" -#include "paddle/phi/api/lib/sparse_api_custom_impl.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/infermeta/binary.h" #include "paddle/phi/infermeta/multiary.h" @@ -51,6 +50,10 @@ def source_include(header_file_path): #include "paddle/phi/infermeta/unary.h" #include "paddle/phi/infermeta/ternary.h" +#include "paddle/phi/infermeta/sparse/unary.h" +#include 
"paddle/phi/infermeta/sparse/binary.h" +#include "paddle/phi/infermeta/sparse/multiary.h" + #include "paddle/fluid/platform/profiler/event_tracing.h" #include "paddle/fluid/platform/profiler/supplement_tracing.h" """ diff --git a/paddle/phi/api/yaml/generator/sparse_api_gen.py b/paddle/phi/api/yaml/generator/sparse_api_gen.py index eb36bea8e89..9123cf7fff5 100644 --- a/paddle/phi/api/yaml/generator/sparse_api_gen.py +++ b/paddle/phi/api/yaml/generator/sparse_api_gen.py @@ -18,6 +18,7 @@ import argparse import re from api_gen import ForwardAPI +from api_base import PREFIX_TENSOR_NAME class SparseAPI(ForwardAPI): @@ -136,6 +137,36 @@ class SparseAPI(ForwardAPI): return kernel_context_code + def prepare_input(self): + input_names = self.inputs['names'] + input_types = self.inputs['tensor_type'] + attr_names = self.attrs['names'] + infer_meta = self.infer_meta + + infer_meta_params = infer_meta['param'] if infer_meta[ + 'param'] is not None else input_names + attr_names + + create_input_var_code = "" + tensor_type_map = { + 'dense': 'phi::DenseTensor', + 'sparse_coo': 'phi::SparseCooTensor', + 'sparse_csr': 'phi::SparseCsrTensor' + } + for param in infer_meta_params: + if param in input_names: + var_name = "auto " + PREFIX_TENSOR_NAME + param + " = " + if self.inputs['input_info'][param] == "const Tensor&": + create_input_var_code = create_input_var_code + var_name + param + ".impl();\n" + elif param in self.optional_vars: + tensor_type = 'phi::DenseTensor' + for name, input_type in zip(input_names, input_types): + if param == name: + tensor_type = tensor_type_map[input_type] + break + optional_var = "paddle::optional<" + tensor_type + ">(" + create_input_var_code = create_input_var_code + var_name + param + " ? 
" + optional_var + "*static_cast<" + tensor_type + "*>((*" + param + ").impl().get())) : " + optional_var + "paddle::none);\n" + return f"""{create_input_var_code}""" + def gen_sparse_kernel_code(self, kernel_name, inplace_flag=False): _, kernel_output_names, output_create = self.gene_output( self.kernel['dispatch'][kernel_name][1], None, '', inplace_flag) @@ -154,6 +185,8 @@ class SparseAPI(ForwardAPI): auto* dev_ctx = GetDeviceContextByBackend(kernel_result.has_fallback_cpu ? Backend::CPU : kernel_backend); auto kernel_context = phi::KernelContext(dev_ctx); {output_create} +{self.prepare_input()} +{self.gene_infer_meta(kernel_output_names, '')} {kernel_context_code} phi_kernel(&kernel_context); {return_code}""" @@ -167,6 +200,7 @@ class SparseAPI(ForwardAPI): 'sparse_csr': 'DataLayout::SPARSE_CSR' } condition_list = [] + tensor_type_list = [] for i, in_type in enumerate(input_types): if in_type == "dense": if self.inputs['names'][i] in self.optional_vars: @@ -178,9 +212,15 @@ class SparseAPI(ForwardAPI): f"phi::DenseTensor::classof({self.inputs['names'][i]}.impl().get())" ) else: - condition_list.append( - f"{self.inputs['names'][i]}.layout() == {sparse_type_map[in_type]}" - ) + if in_type == 'sparse_coo': + condition_list.append( + f"{self.inputs['names'][i]}.is_sparse_coo_tensor()") + else: + condition_list.append( + f"{self.inputs['names'][i]}.is_sparse_csr_tensor()") + tensor_type_list.append(in_type) + self.inputs['tensor_type'] = tensor_type_list + return " && ".join(condition_list) def gene_dispatch_code(self, kernel_name, inplace_flag=False): @@ -229,8 +269,16 @@ def source_include(header_file_path): #include "paddle/phi/api/lib/api_gen_utils.h" #include "paddle/phi/api/lib/data_transform.h" #include "paddle/phi/api/lib/kernel_dispatch.h" -#include "paddle/phi/api/lib/sparse_api_custom_impl.h" #include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/infermeta/unary.h" +#include "paddle/phi/infermeta/binary.h" +#include 
"paddle/phi/infermeta/ternary.h" +#include "paddle/phi/infermeta/multiary.h" +#include "paddle/utils/none.h" + +#include "paddle/phi/infermeta/sparse/unary.h" +#include "paddle/phi/infermeta/sparse/binary.h" +#include "paddle/phi/infermeta/sparse/multiary.h" """ diff --git a/paddle/phi/api/yaml/generator/sparse_bw_api_gen.py b/paddle/phi/api/yaml/generator/sparse_bw_api_gen.py index 6845f91c604..83569d69510 100644 --- a/paddle/phi/api/yaml/generator/sparse_bw_api_gen.py +++ b/paddle/phi/api/yaml/generator/sparse_bw_api_gen.py @@ -111,8 +111,15 @@ def source_include(header_file_path): #include "paddle/phi/api/include/sparse_api.h" #include "paddle/phi/api/lib/api_gen_utils.h" #include "paddle/phi/api/lib/kernel_dispatch.h" -#include "paddle/phi/api/lib/sparse_api_custom_impl.h" #include "paddle/phi/core/kernel_registry.h" + +#include "paddle/phi/infermeta/unary.h" +#include "paddle/phi/infermeta/binary.h" +#include "paddle/phi/infermeta/backward.h" + +#include "paddle/phi/infermeta/sparse/unary.h" +#include "paddle/phi/infermeta/sparse/binary.h" +#include "paddle/phi/infermeta/sparse/backward.h" """ diff --git a/paddle/phi/api/yaml/sparse_backward.yaml b/paddle/phi/api/yaml/sparse_backward.yaml index 4bc306388d1..41816898c3a 100644 --- a/paddle/phi/api/yaml/sparse_backward.yaml +++ b/paddle/phi/api/yaml/sparse_backward.yaml @@ -2,6 +2,9 @@ forward : tanh(Tensor x) -> Tensor(out) args : (Tensor x, Tensor out_grad) output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param : [x] kernel : func : abs_coo_grad {sparse_coo, sparse_coo -> sparse_coo}, abs_csr_grad {sparse_csr, sparse_csr -> sparse_csr} @@ -10,6 +13,9 @@ forward : acos(Tensor x) -> Tensor(out) args : (Tensor x, Tensor out_grad) output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param : [x] kernel : func : acos_coo_grad {sparse_coo, sparse_coo -> sparse_coo}, acos_csr_grad {sparse_csr, sparse_csr -> sparse_csr} @@ -18,6 +24,9 @@ forward : acosh(Tensor x) -> Tensor(out) args 
: (Tensor x, Tensor out_grad) output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param : [x] kernel : func : acosh_coo_grad {sparse_coo, sparse_coo -> sparse_coo}, acosh_csr_grad {sparse_csr, sparse_csr -> sparse_csr} @@ -26,6 +35,9 @@ forward : add(Tensor x, Tensor y) -> Tensor(out) args : (Tensor x, Tensor y, Tensor out_grad) output : Tensor(x_grad), Tensor(y_grad) + infer_meta : + func : GeneralBinaryGradInferMeta + param : [x, y] kernel : func : add_coo_coo_grad{sparse_coo, sparse_coo, sparse_coo -> sparse_coo, sparse_coo}, add_csr_csr_grad{sparse_csr, sparse_csr, sparse_csr -> sparse_csr, sparse_csr} @@ -34,6 +46,9 @@ forward : addmm(Tensor input, Tensor x, Tensor y, float alpha=1.0, float beta=1.0) -> Tensor(out) args : (Tensor input, Tensor x, Tensor y, Tensor out_grad, float alpha=1.0, float beta=1.0) output : Tensor(input_grad), Tensor(x_grad), Tensor(y_grad) + infer_meta : + func : GeneralTernaryGradInferMeta + param : [input, x, y] kernel : func : addmm_csr_dense_grad {dense, sparse_csr, dense, dense -> dense, sparse_csr, dense}, addmm_csr_csr_grad {sparse_csr, sparse_csr, sparse_csr, sparse_csr -> sparse_csr, sparse_csr, sparse_csr}, @@ -44,6 +59,9 @@ forward : asin(Tensor x) -> Tensor(out) args : (Tensor x, Tensor out_grad) output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param : [x] kernel : func : asin_coo_grad {sparse_coo, sparse_coo -> sparse_coo}, asin_csr_grad {sparse_csr, sparse_csr -> sparse_csr} @@ -52,6 +70,9 @@ forward : asinh(Tensor x) -> Tensor(out) args : (Tensor x, Tensor out_grad) output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param : [x] kernel : func : asinh_coo_grad {sparse_coo, sparse_coo -> sparse_coo}, asinh_csr_grad {sparse_csr, sparse_csr -> sparse_csr} @@ -60,6 +81,9 @@ forward : atan(Tensor x) -> Tensor(out) args : (Tensor x, Tensor out_grad) output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param : [x] kernel : func : atan_coo_grad {sparse_coo, 
sparse_coo -> sparse_coo}, atan_csr_grad {sparse_csr, sparse_csr -> sparse_csr} @@ -68,6 +92,9 @@ forward : atanh(Tensor x) -> Tensor(out) args : (Tensor x, Tensor out_grad) output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param : [x] kernel : func : atanh_coo_grad {sparse_coo, sparse_coo -> sparse_coo}, atanh_csr_grad {sparse_csr, sparse_csr -> sparse_csr} @@ -76,6 +103,9 @@ forward : cast(Tensor x, DataType index_dtype, DataType value_dtype) -> Tensor(out) args : (Tensor x, Tensor out_grad, DataType value_dtype) output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param: [x] kernel : func : cast_coo_grad {sparse_coo, sparse_coo -> sparse_coo}, cast_csr_grad {sparse_csr, sparse_csr -> sparse_csr} @@ -85,6 +115,9 @@ forward : conv3d_coo (Tensor x, Tensor kernel, int[] paddings, int[] dilations, int[] strides, int groups, bool subm, str key) -> Tensor(out), Tensor(rulebook), Tensor(counter) args : (Tensor x, Tensor kernel, Tensor out, Tensor rulebook, Tensor counter, Tensor out_grad, int[] paddings, int[] dilations, int[] strides, int groups, bool subm, str key) output : Tensor(x_grad), Tensor(kernel_grad) + infer_meta : + func : GeneralBinaryGradInferMeta + param : [x, kernel] kernel : func : conv3d_coo_grad{sparse_coo, dense, sparse_coo, dense, dense, sparse_coo -> sparse_coo, dense} @@ -92,6 +125,9 @@ forward : divide(Tensor x, Tensor y) -> Tensor(out) args : (Tensor x, Tensor y, Tensor out, Tensor out_grad) output : Tensor(x_grad), Tensor(y_grad) + infer_meta : + func : GeneralBinaryGradInferMeta + param : [x, y] kernel : func : divide_coo_coo_grad{sparse_coo, sparse_coo, sparse_coo, sparse_coo -> sparse_coo, sparse_coo}, divide_csr_csr_grad{sparse_csr, sparse_csr, sparse_csr, sparse_csr -> sparse_csr, sparse_csr} @@ -106,6 +142,9 @@ forward : expm1(Tensor x) -> Tensor(out) args : (Tensor out, Tensor out_grad) output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param : [out] kernel : func : expm1_coo_grad 
{sparse_coo, sparse_coo -> sparse_coo}, expm1_csr_grad {sparse_csr, sparse_csr -> sparse_csr} @@ -114,6 +153,9 @@ forward : leaky_relu(Tensor x, float alpha) -> Tensor(out) args : (Tensor x, Tensor out_grad, float alpha) output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param : [x] kernel : func : leaky_relu_coo_grad {sparse_coo, sparse_coo -> sparse_coo}, leaky_relu_csr_grad {sparse_csr, sparse_csr -> sparse_csr} @@ -122,6 +164,9 @@ forward : log1p(Tensor x) -> Tensor(out) args : (Tensor x, Tensor out_grad) output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param : [x] kernel : func : log1p_coo_grad {sparse_coo, sparse_coo -> sparse_coo}, log1p_csr_grad {sparse_csr, sparse_csr -> sparse_csr} @@ -130,6 +175,9 @@ forward : masked_matmul(Tensor x, Tensor y, Tensor mask) -> Tensor(out) args : (Tensor x, Tensor y, Tensor out_grad) output : Tensor(x_grad), Tensor(y_grad) + infer_meta : + func : GeneralBinaryGradInferMeta + param : [x, y] kernel : func : masked_matmul_csr_grad{dense, dense, sparse_csr -> dense, dense} @@ -137,6 +185,9 @@ forward : matmul(Tensor x, Tensor y) -> Tensor(out) args : (Tensor x, Tensor y, Tensor out_grad) output : Tensor(x_grad), Tensor(y_grad) + infer_meta : + func : GeneralBinaryGradInferMeta + param : [x, y] kernel : func : matmul_csr_dense_grad {sparse_csr, dense, dense -> sparse_csr, dense}, matmul_csr_csr_grad {sparse_csr, sparse_csr, sparse_csr -> sparse_csr, sparse_csr}, @@ -147,6 +198,9 @@ forward : maxpool(Tensor x, int[] kernel_sizes, int[] paddings, int[] dilations, int[] strides) -> Tensor(out), Tensor(rulebook), Tensor(counter) args : (Tensor x, Tensor rulebook, Tensor counter, Tensor out, Tensor out_grad, int[] kernel_sizes) output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param: [x] kernel : func : maxpool_coo_grad {sparse_coo, dense, dense, sparse_coo, sparse_coo -> sparse_coo} @@ -154,6 +208,9 @@ forward : multiply(Tensor x, Tensor y) -> Tensor(out) args : (Tensor x, 
Tensor y, Tensor out_grad) output : Tensor(x_grad), Tensor(y_grad) + infer_meta : + func : GeneralBinaryGradInferMeta + param : [x, y] kernel : func : multiply_coo_coo_grad{sparse_coo, sparse_coo, sparse_coo -> sparse_coo, sparse_coo}, multiply_csr_csr_grad{sparse_csr, sparse_csr, sparse_csr -> sparse_csr, sparse_csr} @@ -162,6 +219,9 @@ forward : mv(Tensor x, Tensor vec) -> Tensor(out) args : (Tensor x, Tensor vec, Tensor out_grad) output : Tensor(x_grad), Tensor(vec_grad) + infer_meta : + func : GeneralBinaryGradInferMeta + param : [x, vec] kernel : func : mv_coo_grad{sparse_coo, dense, dense -> sparse_coo, dense}, mv_csr_grad{sparse_csr, dense, dense -> sparse_csr, dense} @@ -170,6 +230,9 @@ forward : pow(Tensor x, float factor) -> Tensor(out) args : (Tensor x, Tensor out_grad, float factor) output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param : [x] kernel : func : pow_coo_grad {sparse_coo, sparse_coo -> sparse_coo}, pow_csr_grad {sparse_csr, sparse_csr -> sparse_csr} @@ -178,6 +241,9 @@ forward : relu6(Tensor x, float threshold) -> Tensor(out) args : (Tensor out, Tensor out_grad, float threshold) output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param : [out] kernel : func : relu6_coo_grad {sparse_coo, sparse_coo -> sparse_coo}, relu6_csr_grad {sparse_csr, sparse_csr -> sparse_csr} @@ -186,6 +252,9 @@ forward : relu(Tensor x) -> Tensor(out) args : (Tensor out, Tensor out_grad) output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param : [out] kernel : func : relu_coo_grad {sparse_coo, sparse_coo -> sparse_coo}, relu_csr_grad {sparse_csr, sparse_csr -> sparse_csr} @@ -194,12 +263,18 @@ forward : scale(Tensor x, float scale, float bias, bool bias_after_scale) -> Tensor(out) args : (Tensor out_grad, float scale) output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param : [out_grad] invoke : scale(out_grad, scale, 0.0, true) - backward_op : sin_grad forward : sin(Tensor x) -> Tensor(out) args 
: (Tensor x, Tensor out_grad) output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param : [x] kernel : func : sin_coo_grad {sparse_coo, sparse_coo -> sparse_coo}, sin_csr_grad {sparse_csr, sparse_csr -> sparse_csr} @@ -208,6 +283,9 @@ forward : sinh(Tensor x) -> Tensor(out) args : (Tensor x, Tensor out_grad) output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param : [x] kernel : func : sinh_coo_grad {sparse_coo, sparse_coo -> sparse_coo}, sinh_csr_grad {sparse_csr, sparse_csr -> sparse_csr} @@ -216,6 +294,9 @@ forward : softmax(Tensor x, int axis=-1) -> Tensor(out) args : (Tensor out, Tensor out_grad, int axis) output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param : [out] kernel : func : softmax_csr_grad{sparse_csr, sparse_csr -> sparse_csr} @@ -223,6 +304,9 @@ forward : sparse_coo_tensor(Tensor values, Tensor indices, IntArray dense_shape) -> Tensor(out) args : (Tensor indices, Tensor out_grad) output : Tensor(values_grad) + infer_meta : + func : UnchangedInferMeta + param: [out_grad] kernel : func : sparse_coo_tensor_grad{dense, sparse_coo -> dense} @@ -230,6 +314,9 @@ forward : sqrt(Tensor x) -> Tensor(out) args : (Tensor out, Tensor out_grad) output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param : [out] kernel : func : sqrt_coo_grad {sparse_coo, sparse_coo -> sparse_coo}, sqrt_csr_grad {sparse_csr, sparse_csr -> sparse_csr} @@ -238,6 +325,9 @@ forward : square(Tensor x) -> Tensor(out) args : (Tensor x, Tensor out_grad) output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param : [x] kernel : func : square_coo_grad {sparse_coo, sparse_coo -> sparse_coo}, square_csr_grad {sparse_csr, sparse_csr -> sparse_csr} @@ -246,6 +336,9 @@ forward : subtract(Tensor x, Tensor y) -> Tensor(out) args : (Tensor x, Tensor y, Tensor out_grad) output : Tensor(x_grad), Tensor(y_grad) + infer_meta : + func : GeneralBinaryGradInferMeta + param : [x, y] kernel : func : 
subtract_coo_coo_grad{sparse_coo, sparse_coo, sparse_coo -> sparse_coo, sparse_coo}, subtract_csr_csr_grad{sparse_csr, sparse_csr, sparse_csr -> sparse_csr, sparse_csr} @@ -254,6 +347,9 @@ forward : tan(Tensor x) -> Tensor(out) args : (Tensor x, Tensor out_grad) output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param : [x] kernel : func : tan_coo_grad {sparse_coo, sparse_coo -> sparse_coo}, tan_csr_grad {sparse_csr, sparse_csr -> sparse_csr} @@ -262,6 +358,9 @@ forward : tanh(Tensor x) -> Tensor(out) args : (Tensor out, Tensor out_grad) output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param : [out] kernel : func : tanh_coo_grad {sparse_coo, sparse_coo -> sparse_coo}, tanh_csr_grad {sparse_csr, sparse_csr -> sparse_csr} @@ -270,6 +369,9 @@ forward : to_dense(Tensor x) -> Tensor(out) args : (Tensor x, Tensor out_grad) output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param : [x] kernel : func : coo_to_dense_grad{sparse_coo, dense -> sparse_coo} @@ -277,6 +379,8 @@ forward : to_sparse_coo(Tensor x, int64_t sparse_dim) -> Tensor(out) args : (Tensor out_grad) output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta kernel : func : coo_to_dense { sparse_coo -> dense } @@ -284,6 +388,9 @@ forward : values_coo(Tensor x) -> Tensor(out) args : (Tensor x, Tensor out_grad) output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param : [x] kernel : func : values_coo_grad{sparse_coo, dense-> sparse_coo} @@ -291,6 +398,8 @@ forward : fused_attention_csr(Tensor query, Tensor key, Tensor value, Tensor sparse_mask, Tensor key_padding_mask, Tensor attn_mask) -> Tensor(out), Tensor(softmax) args: (Tensor query, Tensor key, Tensor value, Tensor softmax, Tensor out_grad) output : Tensor(query_grad), Tensor(key_grad), Tensor(value_grad) + infer_meta : + func : sparse::FusedAttentionGradInferMeta kernel : func : fused_attention_csr_grad{dense, dense, dense, sparse_csr, dense -> dense, dense, dense} layout : 
softmax diff --git a/paddle/phi/api/yaml/sparse_ops.yaml b/paddle/phi/api/yaml/sparse_ops.yaml index c36181e0e8a..043c12615fb 100644 --- a/paddle/phi/api/yaml/sparse_ops.yaml +++ b/paddle/phi/api/yaml/sparse_ops.yaml @@ -1,6 +1,8 @@ - op : abs args : (Tensor x) output : Tensor(out) + infer_meta : + func : UnchangedInferMeta kernel : func : abs_coo{sparse_coo -> sparse_coo}, abs_csr{sparse_csr -> sparse_csr} @@ -10,6 +12,8 @@ - op : acos args : (Tensor x) output : Tensor(out) + infer_meta : + func : UnchangedInferMeta kernel : func : acos_coo{sparse_coo -> sparse_coo}, acos_csr{sparse_csr -> sparse_csr} @@ -19,6 +23,8 @@ - op : acosh args : (Tensor x) output : Tensor(out) + infer_meta : + func : UnchangedInferMeta kernel : func : acosh_coo{sparse_coo -> sparse_coo}, acosh_csr{sparse_csr -> sparse_csr} @@ -28,6 +34,8 @@ - op : add args : (Tensor x, Tensor y) output : Tensor(out) + infer_meta : + func : ElementwiseInferMeta kernel : func : add_coo_coo{sparse_coo, sparse_coo -> sparse_coo}, add_csr_csr{sparse_csr, sparse_csr -> sparse_csr} @@ -37,6 +45,8 @@ - op : asin args : (Tensor x) output : Tensor(out) + infer_meta : + func : UnchangedInferMeta kernel : func : asin_coo{sparse_coo -> sparse_coo}, asin_csr{sparse_csr -> sparse_csr} @@ -46,6 +56,8 @@ - op : asinh args : (Tensor x) output : Tensor(out) + infer_meta : + func : UnchangedInferMeta kernel : func : asinh_coo{sparse_coo -> sparse_coo}, asinh_csr{sparse_csr -> sparse_csr} @@ -55,6 +67,8 @@ - op : atan args : (Tensor x) output : Tensor(out) + infer_meta : + func : UnchangedInferMeta kernel : func : atan_coo{sparse_coo -> sparse_coo}, atan_csr{sparse_csr -> sparse_csr} @@ -64,6 +78,8 @@ - op : atanh args : (Tensor x) output : Tensor(out) + infer_meta : + func : UnchangedInferMeta kernel : func : atanh_coo{sparse_coo -> sparse_coo}, atanh_csr{sparse_csr -> sparse_csr} @@ -73,6 +89,9 @@ - op : cast args : (Tensor x, DataType index_dtype=DataType::UNDEFINED, DataType value_dtype=DataType::UNDEFINED) output : 
Tensor(out) + infer_meta : + func : CastInferMeta + param: [x, value_dtype] kernel : func : cast_coo{sparse_coo -> sparse_coo}, cast_csr{sparse_csr -> sparse_csr} @@ -83,6 +102,8 @@ - op : conv3d args : (Tensor x, Tensor kernel, int[] paddings, int[] dilations, int[] strides, int groups, bool subm, str key) output : Tensor(out), Tensor(rulebook), Tensor(counter) + infer_meta : + func : sparse::Conv3dInferMeta kernel : func : conv3d_coo{sparse_coo, dense -> sparse_coo, dense, dense} layout : x @@ -92,6 +113,8 @@ - op : divide args : (Tensor x, Tensor y) output : Tensor(out) + infer_meta : + func : ElementwiseInferMeta kernel : func : divide_coo_coo{sparse_coo, sparse_coo -> sparse_coo}, divide_csr_csr{sparse_csr, sparse_csr -> sparse_csr} @@ -101,6 +124,9 @@ - op : divide_scalar args : (Tensor x, float scalar) output : Tensor(out) + infer_meta : + func : UnchangedInferMeta + param : [x] kernel : func : divide_coo_scalar{sparse_coo -> sparse_coo}, divide_csr_scalar{sparse_csr -> sparse_csr} @@ -109,6 +135,8 @@ - op : expm1 args : (Tensor x) output : Tensor(out) + infer_meta : + func : UnchangedInferMeta kernel : func : expm1_coo{sparse_coo -> sparse_coo}, expm1_csr{sparse_csr -> sparse_csr} @@ -118,6 +146,9 @@ - op : leaky_relu args : (Tensor x, float alpha) output : Tensor(out) + infer_meta : + func : UnchangedInferMeta + param : [x] kernel : func : leaky_relu_coo{sparse_coo -> sparse_coo}, leaky_relu_csr{sparse_csr -> sparse_csr} @@ -127,6 +158,8 @@ - op : log1p args : (Tensor x) output : Tensor(out) + infer_meta : + func : UnchangedInferMeta kernel : func : log1p_coo{sparse_coo -> sparse_coo}, log1p_csr{sparse_csr -> sparse_csr} @@ -136,6 +169,8 @@ - op : multiply args : (Tensor x, Tensor y) output : Tensor(out) + infer_meta : + func : ElementwiseInferMeta kernel : func : multiply_coo_coo{sparse_coo, sparse_coo -> sparse_coo}, multiply_csr_csr{sparse_csr, sparse_csr -> sparse_csr} @@ -145,6 +180,9 @@ - op : pow args : (Tensor x, float factor) output : Tensor(out) 
+ infer_meta : + func : UnchangedInferMeta + param : [x] kernel : func : pow_coo{sparse_coo -> sparse_coo}, pow_csr{sparse_csr -> sparse_csr} @@ -154,6 +192,8 @@ - op : relu args : (Tensor x) output : Tensor(out) + infer_meta : + func : UnchangedInferMeta kernel : func : relu_coo{sparse_coo -> sparse_coo}, relu_csr{sparse_csr -> sparse_csr} @@ -163,6 +203,9 @@ - op : relu6 args : (Tensor x, float threshold) output : Tensor(out) + infer_meta : + func : UnchangedInferMeta + param : [x] kernel : func : relu6_coo{sparse_coo -> sparse_coo}, relu6_csr{sparse_csr -> sparse_csr} @@ -172,6 +215,9 @@ - op : scale args : (Tensor x, float scale, float bias, bool bias_after_scale) output : Tensor(out) + infer_meta : + func : UnchangedInferMeta + param : [x] kernel : func : scale_coo{sparse_coo -> sparse_coo}, scale_csr{sparse_csr -> sparse_csr} @@ -180,6 +226,8 @@ - op : sin args : (Tensor x) output : Tensor(out) + infer_meta : + func : UnchangedInferMeta kernel : func : sin_coo{sparse_coo -> sparse_coo}, sin_csr{sparse_csr -> sparse_csr} @@ -189,6 +237,8 @@ - op : sinh args : (Tensor x) output : Tensor(out) + infer_meta : + func : UnchangedInferMeta kernel : func : sinh_coo{sparse_coo -> sparse_coo}, sinh_csr{sparse_csr -> sparse_csr} @@ -198,6 +248,9 @@ - op : softmax args : (Tensor x, int axis=-1) output : Tensor(out) + infer_meta : + func : UnchangedInferMeta + param : [x] kernel : func : softmax_csr{sparse_csr -> sparse_csr} layout : x @@ -206,6 +259,8 @@ - op : sparse_coo_tensor args : (Tensor values, Tensor indices, IntArray dense_shape) output : Tensor(out) + infer_meta : + func : sparse::SparseCooTensorInferMeta kernel : func : sparse_coo_tensor{dense, dense -> sparse_coo} layout : values @@ -215,6 +270,8 @@ - op : sqrt args : (Tensor x) output : Tensor(out) + infer_meta : + func : UnchangedInferMeta kernel : func : sqrt_coo{sparse_coo -> sparse_coo}, sqrt_csr{sparse_csr -> sparse_csr} @@ -224,6 +281,8 @@ - op : square args : (Tensor x) output : Tensor(out) + 
infer_meta : + func : UnchangedInferMeta kernel : func : square_coo{sparse_coo -> sparse_coo}, square_csr{sparse_csr -> sparse_csr} @@ -233,6 +292,8 @@ - op : subtract args : (Tensor x, Tensor y) output : Tensor(out) + infer_meta : + func : ElementwiseInferMeta kernel : func : subtract_coo_coo{sparse_coo, sparse_coo -> sparse_coo}, subtract_csr_csr{sparse_csr, sparse_csr -> sparse_csr} @@ -242,6 +303,8 @@ - op : tan args : (Tensor x) output : Tensor(out) + infer_meta : + func : UnchangedInferMeta kernel : func : tan_coo{sparse_coo -> sparse_coo}, tan_csr{sparse_csr -> sparse_csr} @@ -251,6 +314,8 @@ - op : tanh args : (Tensor x) output : Tensor(out) + infer_meta : + func : UnchangedInferMeta kernel : func : tanh_coo{sparse_coo -> sparse_coo}, tanh_csr{sparse_csr -> sparse_csr} @@ -260,6 +325,8 @@ - op : to_dense args : (Tensor x) output : Tensor(out) + infer_meta : + func : UnchangedInferMeta kernel : func : coo_to_dense {sparse_coo -> dense}, csr_to_dense {sparse_csr -> dense} @@ -268,6 +335,9 @@ - op : to_sparse_coo args : (Tensor x, int64_t sparse_dim) output : Tensor(out) + infer_meta : + func : UnchangedInferMeta + param : [x] kernel : func : dense_to_coo { dense -> sparse_coo }, csr_to_coo { sparse_csr -> sparse_coo} @@ -276,6 +346,8 @@ - op : to_sparse_csr args : (Tensor x) output : Tensor(out) + infer_meta : + func : UnchangedInferMeta kernel : func : dense_to_csr {dense -> sparse_csr}, coo_to_csr {sparse_coo -> sparse_csr} @@ -283,6 +355,8 @@ - op : values args : (Tensor x) output : Tensor(out) + infer_meta : + func : sparse::ValuesInferMeta kernel : func : values_coo{sparse_coo -> dense}, values_csr{sparse_csr -> dense} @@ -292,6 +366,8 @@ - op: addmm args : (Tensor input, Tensor x, Tensor y, float alpha=1.0, float beta=1.0) output : Tensor(out) + infer_meta : + func : AddmmInferMeta kernel : func : addmm_csr_dense {dense, sparse_csr, dense -> dense}, addmm_csr_csr {sparse_csr, sparse_csr, sparse_csr -> sparse_csr}, @@ -303,6 +379,8 @@ - op: coalesce args 
: (Tensor x) output : Tensor(out) + infer_meta : + func : UnchangedInferMeta kernel : func: coalesce{sparse_coo -> sparse_coo} layout : x @@ -310,6 +388,9 @@ - op: full_like args : (Tensor x, Scalar value, DataType dtype=DataType::UNDEFINED) output : Tensor(out) + infer_meta : + func : CreateLikeInferMeta + param : [x, dtype] kernel : func : coo_full_like{sparse_coo -> sparse_coo}, csr_full_like{sparse_csr -> sparse_csr} @@ -319,6 +400,8 @@ - op: fused_attention args : (Tensor query, Tensor key, Tensor value, Tensor sparse_mask, Tensor key_padding_mask, Tensor attn_mask) output : Tensor(out), Tensor(softmax) + infer_meta : + func : sparse::FusedAttentionInferMeta kernel : func : fused_attention_csr{dense, dense, dense, sparse_csr, dense, dense -> dense, sparse_csr} layout : sparse_mask @@ -330,6 +413,9 @@ - op: masked_matmul args : (Tensor x, Tensor y, Tensor mask) output : Tensor(out) + infer_meta : + func : MatmulInferMeta + param : [x, y, false, false] kernel : func : masked_matmul_csr{dense, dense, sparse_csr -> sparse_csr} layout : x @@ -338,6 +424,9 @@ - op: matmul args : (Tensor x, Tensor y) output : Tensor(out) + infer_meta : + func : MatmulInferMeta + param: [x, y, false, false] kernel : func : matmul_csr_dense {sparse_csr, dense -> dense}, matmul_csr_csr {sparse_csr, sparse_csr -> sparse_csr}, @@ -349,6 +438,8 @@ - op: maxpool args : (Tensor x, int[] kernel_sizes, int[] paddings, int[] dilations, int[] strides) output : Tensor(out), Tensor(rulebook), Tensor(counter) + infer_meta : + func : sparse::Pool3dInferMeta kernel : func : maxpool_coo{sparse_coo -> sparse_coo, dense, dense} layout : x @@ -358,6 +449,8 @@ - op: mv args : (Tensor x, Tensor vec) output : Tensor(out) + infer_meta : + func : MvInferMeta kernel : func : mv_coo{sparse_coo, dense -> dense}, mv_csr{sparse_csr, dense -> dense} diff --git a/paddle/phi/core/meta_tensor.cc b/paddle/phi/core/meta_tensor.cc index 9a008e429da..8915b2ee871 100644 --- a/paddle/phi/core/meta_tensor.cc +++ 
b/paddle/phi/core/meta_tensor.cc @@ -64,6 +64,12 @@ void MetaTensor::set_dims(const DDim& dims) { DenseTensorUtils::GetMutableMeta( static_cast(tensor_)->mutable_value()) ->dims = dims; + } else if (phi::SparseCooTensor::classof(tensor_)) { + DenseTensorUtils::GetMutableMeta(static_cast(tensor_)) + ->dims = dims; + } else if (phi::SparseCsrTensor::classof(tensor_)) { + DenseTensorUtils::GetMutableMeta(static_cast(tensor_)) + ->dims = dims; } else { PADDLE_THROW(phi::errors::Unimplemented( "Unsupported setting dims for `%s`.", tensor_->type_info().name())); @@ -81,6 +87,13 @@ void MetaTensor::set_dtype(DataType dtype) { DenseTensorUtils::GetMutableMeta( static_cast(tensor_)->mutable_value()) ->dtype = dtype; + } else if (phi::SparseCooTensor::classof(tensor_)) { + DenseTensorUtils::GetMutableMeta(static_cast(tensor_)) + ->dtype = dtype; + } else if (phi::SparseCsrTensor::classof(tensor_)) { + DenseTensorUtils::GetMutableMeta(static_cast(tensor_)) + ->dtype = dtype; + // No need to set dtype } else { PADDLE_THROW(phi::errors::Unimplemented( "Unsupported settting dtype for `%s`.", tensor_->type_info().name())); @@ -98,6 +111,12 @@ void MetaTensor::set_layout(DataLayout layout) { DenseTensorUtils::GetMutableMeta( static_cast(tensor_)->mutable_value()) ->layout = layout; + } else if (phi::SparseCooTensor::classof(tensor_)) { + DenseTensorUtils::GetMutableMeta(static_cast(tensor_)) + ->layout = layout; + } else if (phi::SparseCsrTensor::classof(tensor_)) { + DenseTensorUtils::GetMutableMeta(static_cast(tensor_)) + ->layout = layout; } else { PADDLE_THROW(phi::errors::Unimplemented( "Unsupported settting layout for `%s`.", tensor_->type_info().name())); @@ -107,6 +126,10 @@ void MetaTensor::set_layout(DataLayout layout) { void MetaTensor::share_lod(const MetaTensor& meta_tensor) { ValidCheck(*this); ValidCheck(meta_tensor); + if (phi::SparseCooTensor::classof(tensor_) || + phi::SparseCsrTensor::classof(tensor_)) { + return; + } if (meta_tensor.lod().size() == 0) { // no 
need share return; @@ -128,7 +151,9 @@ void MetaTensor::share_lod(const MetaTensor& meta_tensor) { void MetaTensor::share_meta(const MetaTensor& meta_tensor) { ValidCheck(*this); if (phi::DenseTensor::classof(tensor_) || - phi::SelectedRows::classof(tensor_)) { + phi::SelectedRows::classof(tensor_) || + phi::SparseCooTensor::classof(tensor_) || + phi::SparseCsrTensor::classof(tensor_)) { share_dims(meta_tensor); set_dtype(meta_tensor.dtype()); set_layout(meta_tensor.layout()); @@ -143,7 +168,9 @@ void MetaTensor::share_dims(const MetaTensor& meta_tensor) { ValidCheck(*this); bool is_dense_tensor = phi::DenseTensor::classof(tensor_); bool is_selected_rows = phi::SelectedRows::classof(tensor_); - if (is_dense_tensor || is_selected_rows) { + bool is_sparse_coo = phi::SparseCooTensor::classof(tensor_); + bool is_sparse_csr = phi::SparseCsrTensor::classof(tensor_); + if (is_dense_tensor || is_selected_rows || is_sparse_coo || is_sparse_csr) { set_dims(meta_tensor.dims()); if (is_selected_rows) { const auto in_tensor_base = meta_tensor.tensor(); @@ -172,6 +199,10 @@ const LoD& MetaTensor::lod() const { return static_cast(tensor_)->lod(); } else if (phi::SelectedRows::classof(tensor_)) { return static_cast(tensor_)->value().lod(); + } else if (phi::SparseCooTensor::classof(tensor_)) { + return static_cast(tensor_)->non_zero_elements().lod(); + } else if (phi::SparseCsrTensor::classof(tensor_)) { + return static_cast(tensor_)->non_zero_elements().lod(); } else { PADDLE_THROW(phi::errors::Unimplemented("Unsupported getting lod of `%s`.", tensor_->type_info().name())); diff --git a/paddle/phi/core/sparse_coo_tensor.cc b/paddle/phi/core/sparse_coo_tensor.cc index bf4d601c0b5..8df031421fe 100644 --- a/paddle/phi/core/sparse_coo_tensor.cc +++ b/paddle/phi/core/sparse_coo_tensor.cc @@ -21,34 +21,47 @@ SparseCooTensor::SparseCooTensor() { this->SetMember(non_zero_indices, non_zero_elements, {1}, true); } +SparseCooTensor::SparseCooTensor(SparseCooTensor&& other) { + 
this->non_zero_elements_ = other.non_zero_elements_; + this->non_zero_indices_ = other.non_zero_indices_; + this->coalesced_ = other.coalesced_; + set_meta(other.meta()); +} + SparseCooTensor::SparseCooTensor(const DenseTensor& non_zero_indices, const DenseTensor& non_zero_elements, const DDim& dims) : non_zero_indices_(non_zero_indices), non_zero_elements_(non_zero_elements), - coalesced_(false), - dims_(dims) {} + coalesced_(false) { + meta_.dims = dims; + meta_.layout = DataLayout::NCHW; + meta_.dtype = non_zero_elements.dtype(); +} SparseCooTensor::SparseCooTensor(DenseTensor&& non_zero_indices, DenseTensor&& non_zero_elements, const DDim& dims) : non_zero_indices_(non_zero_indices), non_zero_elements_(non_zero_elements), - coalesced_(false), - dims_(dims) {} + coalesced_(false) { + meta_.dims = dims; + meta_.layout = DataLayout::NCHW; + meta_.dtype = non_zero_elements.dtype(); +} SparseCooTensor::SparseCooTensor(const SparseCooTensor& other) : non_zero_indices_(other.non_zero_indices_), - non_zero_elements_(other.non_zero_elements_), - dims_(other.dims_) { + non_zero_elements_(other.non_zero_elements_) { this->coalesced_ = other.coalesced_; + set_meta(other.meta()); } SparseCooTensor SparseCooTensor::operator=(const SparseCooTensor& other) { - this->dims_ = other.dims_; - this->non_zero_indices_ = other.non_zero_indices_; this->non_zero_elements_ = other.non_zero_elements_; + this->non_zero_indices_ = other.non_zero_indices_; this->coalesced_ = other.coalesced_; + set_meta(other.meta()); return *this; } @@ -111,8 +124,18 @@ void SparseCooTensor::SetMember(const DenseTensor& non_zero_indices, const bool coalesced) { this->non_zero_indices_ = non_zero_indices; this->non_zero_elements_ = non_zero_elements; - this->dims_ = dims; + this->meta_.dims = dims; + this->coalesced_ = coalesced; +} + +void SparseCooTensor::SetMember(const DenseTensor& non_zero_indices, + const DenseTensor& non_zero_elements, + const SparseTensorMeta& meta, + const bool coalesced) { + 
this->non_zero_indices_ = non_zero_indices; + this->non_zero_elements_ = non_zero_elements; this->coalesced_ = coalesced; + set_meta(meta); } int32_t SparseCooTensor::sparse_dim() const { @@ -120,7 +143,25 @@ int32_t SparseCooTensor::sparse_dim() const { } int32_t SparseCooTensor::dense_dim() const { - return dims_.size() - sparse_dim(); + return meta_.dims.size() - sparse_dim(); +} + +void SparseCooTensor::set_meta(SparseTensorMeta&& meta) { + PADDLE_ENFORCE(!meta_.valid(), + phi::errors::InvalidArgument( + "Only when the original attribute of Tensor is " + "incomplete, can it be reset.")); + meta_ = std::move(meta); +} + +void SparseCooTensor::set_meta(const SparseTensorMeta& meta) { + PADDLE_ENFORCE( + meta.valid(), + phi::errors::InvalidArgument( + "Input meta is invalid, please check the meta attribute.")); + meta_.dims = meta.dims; + meta_.dtype = meta.dtype; + meta_.layout = meta.layout; } } // namespace phi diff --git a/paddle/phi/core/sparse_coo_tensor.h b/paddle/phi/core/sparse_coo_tensor.h index ba85a751dc0..f8869aa524d 100644 --- a/paddle/phi/core/sparse_coo_tensor.h +++ b/paddle/phi/core/sparse_coo_tensor.h @@ -20,6 +20,8 @@ limitations under the License. */ namespace phi { +class DenseTensorUtils; + /// \brief The SparseCooTensor uses two DenseTensors to represent /// the non zero elements and the indices of non zero elements of /// original DenseTensor. @@ -93,21 +95,19 @@ class SparseCooTensor : public TensorBase, /// \brief Return the number of elements contained in original dense tensor /// \return The number of elements contained in original dense tensor - int64_t numel() const override { return product(dims_); } + int64_t numel() const override { return product(meta_.dims); } /// \brief Returns the dims of the original dense tensor. /// \return The dims of the original dense tensor. 
- const DDim& dims() const noexcept override { return dims_; } + const DDim& dims() const noexcept override { return meta_.dims; } /// \brief Returns the data type of the tensor. /// \return The data type of the tensor. - DataType dtype() const noexcept override { - return non_zero_elements_.dtype(); - } + DataType dtype() const noexcept override { return meta_.dtype; } /// \brief Returns the data layout of the tensor. /// \return The data layout of the tensor. - DataLayout layout() const noexcept override { return DataLayout::SPARSE_COO; } + DataLayout layout() const noexcept override { return meta_.layout; } /// \brief Returns the data place of the tensor. /// \return The data place of the tensor. @@ -140,6 +140,17 @@ class SparseCooTensor : public TensorBase, const DDim& dims, const bool coalesced = false); + /// \brief set the member of sparse coo tensor. + /// \param non_zero_indices The indices of non zero elements in original dense + /// tensor. + /// \param non_zero_elements The non zero elements of original dense tensor. + /// \param meta The meta of original dense tensor. + /// \param coalesced whether the indices has coalesced. + void SetMember(const DenseTensor& non_zero_indices, + const DenseTensor& non_zero_elements, + const SparseTensorMeta& meta, + const bool coalesced = false); + /// \brief Get a mutable pointer of non_zero_indices_. /// return a mutable pointer of non_zero_indices_. DenseTensor* mutable_indices() { return &non_zero_indices_; } @@ -161,15 +172,22 @@ class SparseCooTensor : public TensorBase, DataType dtype, size_t requested_size = 0) override; - /// \brief set the dims of original dense tensor - void set_dims(const DDim& dims) { this->dims_ = dims; } - /// \brief get the sparse dim int32_t sparse_dim() const; /// \brief get the dnese dim int32_t dense_dim() const; + /// \brief Returns the meta information of the tensor. + /// \return The meta information of the tensor. 
+ const SparseTensorMeta& meta() const noexcept { return meta_; } + + void set_meta(SparseTensorMeta&& meta); + + void set_meta(const SparseTensorMeta& meta); + + void set_dims(const DDim& dims) { meta_.dims = dims; } + /// \brief query table according to key const std::pair* IndicesPairs( const std::string& key) const { @@ -213,6 +231,10 @@ class SparseCooTensor : public TensorBase, } private: + friend class DenseTensorUtils; + + SparseTensorMeta meta_; + // save the indices of non zero elements in original dense tensor DenseTensor non_zero_indices_; // save the non zero elements of original dense tensor diff --git a/paddle/phi/core/sparse_csr_tensor.cc b/paddle/phi/core/sparse_csr_tensor.cc index 45131f48338..5c793048ea3 100644 --- a/paddle/phi/core/sparse_csr_tensor.cc +++ b/paddle/phi/core/sparse_csr_tensor.cc @@ -21,7 +21,6 @@ SparseCsrTensor::SparseCsrTensor() { this->non_zero_crows_ = crows; this->non_zero_cols_ = cols; this->non_zero_elements_ = values; - this->dims_ = phi::make_ddim({1, 1}); } inline void check_shape(const DDim& dims) { @@ -54,27 +53,30 @@ SparseCsrTensor::SparseCsrTensor(const DenseTensor& non_zero_crows, const DDim& dims) : non_zero_crows_(non_zero_crows), non_zero_cols_(non_zero_cols), - non_zero_elements_(non_zero_elements), - dims_(dims) { + non_zero_elements_(non_zero_elements) { if (non_zero_crows.initialized()) { - Check(non_zero_crows_, non_zero_cols_, non_zero_elements_, dims_); + Check(non_zero_crows_, non_zero_cols_, non_zero_elements_, dims); } else { // create a empty tensor check_shape(dims); } + meta_.dims = dims; + meta_.layout = DataLayout::NCHW; + meta_.dtype = non_zero_elements.dtype(); } SparseCsrTensor::SparseCsrTensor(const SparseCsrTensor& other) : non_zero_crows_(other.non_zero_crows_), non_zero_cols_(other.non_zero_cols_), - non_zero_elements_(other.non_zero_elements_), - dims_(other.dims_) {} + non_zero_elements_(other.non_zero_elements_) { + set_meta(other.meta()); +} SparseCsrTensor& 
SparseCsrTensor::operator=(const SparseCsrTensor& other) { - this->dims_ = other.dims(); - this->non_zero_crows_ = other.crows(); - this->non_zero_cols_ = other.cols(); - this->non_zero_elements_ = other.values(); + this->non_zero_crows_ = other.non_zero_crows(); + this->non_zero_cols_ = other.non_zero_cols(); + this->non_zero_elements_ = other.non_zero_elements(); + set_meta(other.meta()); return *this; } @@ -114,7 +116,35 @@ void SparseCsrTensor::SetMember(const DenseTensor& non_zero_crows, this->non_zero_crows_ = non_zero_crows; this->non_zero_cols_ = non_zero_cols; this->non_zero_elements_ = non_zero_elements; - this->dims_ = dims; + meta_.dims = dims; +} + +void SparseCsrTensor::SetMember(const DenseTensor& non_zero_crows, + const DenseTensor& non_zero_cols, + const DenseTensor& non_zero_elements, + const SparseTensorMeta& meta) { + Check(non_zero_crows, non_zero_cols, non_zero_elements, meta.dims); + this->non_zero_crows_ = non_zero_crows; + this->non_zero_cols_ = non_zero_cols; + this->non_zero_elements_ = non_zero_elements; + set_meta(meta); } +void SparseCsrTensor::set_meta(SparseTensorMeta&& meta) { + PADDLE_ENFORCE(!meta_.valid(), + phi::errors::InvalidArgument( + "Only when the original attribute of Tensor is " + "incomplete, can it be reset.")); + meta_ = std::move(meta); +} + +void SparseCsrTensor::set_meta(const SparseTensorMeta& meta) { + PADDLE_ENFORCE( + meta.valid(), + phi::errors::InvalidArgument( + "Input meta is invalid, please check the meta attribute.")); + meta_.dims = meta.dims; + meta_.dtype = meta.dtype; + meta_.layout = meta.layout; +} } // namespace phi diff --git a/paddle/phi/core/sparse_csr_tensor.h b/paddle/phi/core/sparse_csr_tensor.h index ee47e39f97f..056d049942a 100644 --- a/paddle/phi/core/sparse_csr_tensor.h +++ b/paddle/phi/core/sparse_csr_tensor.h @@ -14,14 +14,13 @@ limitations under the License. 
*/ #pragma once -#include "paddle/phi/core/allocator.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/tensor_base.h" #include "paddle/phi/core/tensor_meta.h" namespace phi { -class CompatibleDenseTensorUtils; +class DenseTensorUtils; /// \brief The SparseCsrTensor uses three 1-D DenseTensors to represent /// the row index , column index and non zero elements of the original @@ -100,21 +99,19 @@ class SparseCsrTensor : public TensorBase, /// \brief Return the number of elements contained in original dense tensor /// \return The number of elements contained in original dense tensor - int64_t numel() const override { return product(dims_); } + int64_t numel() const override { return product(meta_.dims); } /// \brief Returns the dims of the original dense tensor. /// \return The dims of the original dense tensor. - const DDim& dims() const noexcept override { return dims_; } + const DDim& dims() const noexcept override { return meta_.dims; } /// \brief Returns the data type of the tensor. /// \return The data type of the tensor. - DataType dtype() const noexcept override { - return non_zero_elements_.dtype(); - } + DataType dtype() const noexcept override { return meta_.dtype; } /// \brief Returns the data layout of the tensor. /// \return The data layout of the tensor. - DataLayout layout() const noexcept override { return DataLayout::SPARSE_CSR; } + DataLayout layout() const noexcept override { return meta_.layout; } /// \brief Returns the data place of the tensor. /// \return The data place of the tensor. @@ -145,6 +142,18 @@ class SparseCsrTensor : public TensorBase, const DenseTensor& non_zero_elements, const DDim& dims); + /// \brief set the member of sparse csr tensor. + /// \param non_zero_crows The compressed row index of non zero elements in + /// original dense tensor. + /// \param non_zero_cols The column index of non zero elements in original + /// dense tensor. 
+ /// \param non_zero_elements The non zero elements of original dense tensor. + /// \param meta The meta of original dense tensor. + void SetMember(const DenseTensor& non_zero_crows, + const DenseTensor& non_zero_cols, + const DenseTensor& non_zero_elements, + const SparseTensorMeta& meta); + /// \brief Get a mutable pointer of non_zero_crows. /// return a mutable pointer of non_zero_crows. DenseTensor* mutable_crows() { return &non_zero_crows_; } @@ -169,18 +178,28 @@ class SparseCsrTensor : public TensorBase, /// mutable_values() DenseTensor* mutable_non_zero_elements() { return &non_zero_elements_; } + /// \brief Returns the meta information of the tensor. + /// \return The meta information of the tensor. + const SparseTensorMeta& meta() const noexcept { return meta_; } + + void set_meta(SparseTensorMeta&& meta); + + void set_meta(const SparseTensorMeta& meta); + /// \brief set the dims of original dense tensor - void set_dims(const DDim& dims) { this->dims_ = dims; } + void set_dims(const DDim& dims) { meta_.dims = dims; } + + protected: + SparseTensorMeta meta_; private: + friend class DenseTensorUtils; // save the compressed rows information of non zero elements DenseTensor non_zero_crows_; // save the columns information of non zero elements DenseTensor non_zero_cols_; // save the non zero elements DenseTensor non_zero_elements_; - // save the number of non zero elements in each batch - DDim dims_; /* --------------------------- */ /* example: 2-D Tensor */ /* --------------------------- */ diff --git a/paddle/phi/core/tensor_meta.cc b/paddle/phi/core/tensor_meta.cc index 0140ec23937..da088025768 100644 --- a/paddle/phi/core/tensor_meta.cc +++ b/paddle/phi/core/tensor_meta.cc @@ -48,4 +48,16 @@ bool StringTensorMeta::valid() const noexcept { return valid; } +SparseTensorMeta::SparseTensorMeta(const DDim& dims) : dims(dims) {} + +SparseTensorMeta::SparseTensorMeta(const DDim& dims, const DataLayout& layout) + : dims(dims), layout(layout) {} + +bool 
SparseTensorMeta::valid() const noexcept { + bool valid{true}; + valid = valid && (layout != DataLayout::UNDEFINED); + valid = valid && (product(dims) >= 0); + return valid; +} + } // namespace phi diff --git a/paddle/phi/core/tensor_meta.h b/paddle/phi/core/tensor_meta.h index 18f276f8b62..8969ef16d95 100644 --- a/paddle/phi/core/tensor_meta.h +++ b/paddle/phi/core/tensor_meta.h @@ -99,4 +99,24 @@ inline bool operator==(const StringTensorMeta& lhs, (lhs.offset == rhs.offset); } +struct SparseTensorMeta { + using DataLayout = paddle::experimental::DataLayout; + + SparseTensorMeta() = default; + explicit SparseTensorMeta(const DDim& dims); + explicit SparseTensorMeta(const DDim& dims, const DataLayout& layout); + /// \brief Test whether the metadata is valid. Does not throw exceptions. + /// \return Whether the metadata is valid. + bool valid() const noexcept; + + DDim dims; + DataType dtype{DataType::UNDEFINED}; + DataLayout layout{DataLayout::NCHW}; +}; + +inline bool operator==(const SparseTensorMeta& lhs, + const SparseTensorMeta& rhs) { + return (lhs.dims == rhs.dims) && (lhs.layout == rhs.layout); +} + } // namespace phi diff --git a/paddle/phi/core/tensor_utils.cc b/paddle/phi/core/tensor_utils.cc index dcd25180e29..6e87f40ed0a 100644 --- a/paddle/phi/core/tensor_utils.cc +++ b/paddle/phi/core/tensor_utils.cc @@ -296,7 +296,7 @@ void Copy(const Context& dev_ctx, dst_place, blocking, dst->mutable_non_zero_elements()); - dst->set_dims(src.dims()); + dst->set_meta(src.meta()); dst->SetCoalesced(src.coalesced()); } diff --git a/paddle/phi/core/tensor_utils.h b/paddle/phi/core/tensor_utils.h index c478e3e0895..ceb46e2abec 100644 --- a/paddle/phi/core/tensor_utils.h +++ b/paddle/phi/core/tensor_utils.h @@ -28,6 +28,14 @@ class DenseTensorUtils { return &(tensor->meta_); } + static SparseTensorMeta* GetMutableMeta(SparseCooTensor* tensor) { + return &(tensor->meta_); + } + + static SparseTensorMeta* GetMutableMeta(SparseCsrTensor* tensor) { + return &(tensor->meta_); + } + static const 
std::shared_ptr& GetHolder( const DenseTensor& tensor) { return tensor.holder_; diff --git a/paddle/phi/infermeta/CMakeLists.txt b/paddle/phi/infermeta/CMakeLists.txt index 92b64ab4e66..b896bb818fa 100644 --- a/paddle/phi/infermeta/CMakeLists.txt +++ b/paddle/phi/infermeta/CMakeLists.txt @@ -7,3 +7,4 @@ cc_library( SRCS backward.cc DEPS meta_tensor convert_utils) add_subdirectory(strings) +add_subdirectory(sparse) diff --git a/paddle/phi/infermeta/sparse/CMakeLists.txt b/paddle/phi/infermeta/sparse/CMakeLists.txt new file mode 100644 index 00000000000..8717ef2cf6f --- /dev/null +++ b/paddle/phi/infermeta/sparse/CMakeLists.txt @@ -0,0 +1,9 @@ +cc_library( + sparse_infermeta + SRCS unary.cc binary.cc multiary.cc + DEPS convert_utils infermeta_utils) + +cc_library( + sparse_backward_infermeta + SRCS backward.cc + DEPS meta_tensor convert_utils) diff --git a/paddle/phi/infermeta/sparse/backward.cc b/paddle/phi/infermeta/sparse/backward.cc new file mode 100644 index 00000000000..d09c0e6fb84 --- /dev/null +++ b/paddle/phi/infermeta/sparse/backward.cc @@ -0,0 +1,35 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/phi/infermeta/sparse/backward.h" +#include "paddle/phi/infermeta/unary.h" + +#include "paddle/phi/core/infermeta_utils.h" + +namespace phi { +namespace sparse { + +void FusedAttentionGradInferMeta(const MetaTensor& query, + const MetaTensor& key, + const MetaTensor& value, + const MetaTensor& softmax, + const MetaTensor& out_grad, + MetaTensor* query_grad, + MetaTensor* key_grad, + MetaTensor* value_grad) { + // TODO(zhouwei, zhangkaihuo) add correct infer meta +} + +} // namespace sparse +} // namespace phi diff --git a/paddle/phi/infermeta/sparse/backward.h b/paddle/phi/infermeta/sparse/backward.h new file mode 100644 index 00000000000..e5c797923df --- /dev/null +++ b/paddle/phi/infermeta/sparse/backward.h @@ -0,0 +1,33 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once + +#include "paddle/phi/core/meta_tensor.h" +#include "paddle/phi/core/tensor_meta.h" + +namespace phi { +namespace sparse { + +void FusedAttentionGradInferMeta(const MetaTensor& query, + const MetaTensor& key, + const MetaTensor& value, + const MetaTensor& softmax, + const MetaTensor& out_grad, + MetaTensor* query_grad, + MetaTensor* key_grad, + MetaTensor* value_grad); + +} // namespace sparse +} // namespace phi diff --git a/paddle/phi/infermeta/sparse/binary.cc b/paddle/phi/infermeta/sparse/binary.cc new file mode 100644 index 00000000000..1b86f00ac2e --- /dev/null +++ b/paddle/phi/infermeta/sparse/binary.cc @@ -0,0 +1,147 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/phi/infermeta/sparse/binary.h" + +namespace phi { +namespace sparse { + +inline void GetOutShape(const DDim& x_dims, + const std::vector& kernel_sizes, + const std::vector& paddings, + const std::vector& dilations, + const std::vector& strides, + DDim* out_dims) { + PADDLE_ENFORCE_EQ( + x_dims.size(), + 5, + phi::errors::InvalidArgument("the shape of x should be (N, D, H, W, C)")); + PADDLE_ENFORCE_EQ(kernel_sizes.size(), + 5, + phi::errors::InvalidArgument( + "the shape of kernel should be (D, H, W, C, OC)")); + + // infer out shape + (*out_dims)[0] = x_dims[0]; + (*out_dims)[4] = kernel_sizes[4]; + for (int i = 1; i < 4; i++) { + (*out_dims)[i] = (x_dims[i] + 2 * paddings[i - 1] - + dilations[i - 1] * (kernel_sizes[i - 1] - 1) - 1) / + strides[i - 1] + + 1; + } +} + +inline void ResetSubmKernelSizeAndStrides(const DDim& kernel_dims, + std::vector* paddings, + std::vector* strides) { + for (uint64_t i = 0; i < paddings->size(); i++) { + (*paddings)[i] = kernel_dims[i] / 2; + (*strides)[i] = 1; + } +} + +void Conv3dInferMeta(const MetaTensor& x, + const MetaTensor& kernel, + const std::vector& paddings, + const std::vector& dilations, + const std::vector& strides, + const int groups, + const bool subm, + const std::string& key, + MetaTensor* out, + MetaTensor* rulebook, + MetaTensor* counter) { + const auto& x_dims = x.dims(); + const auto& kernel_dims = kernel.dims(); + DDim out_dims = {1, 1, 1, 1, 1}; + + std::vector kernel_sizes(kernel_dims.size()); + for (int i = 0; i < kernel_dims.size(); i++) { + kernel_sizes[i] = kernel_dims[i]; + } + + std::vector subm_paddings(paddings), subm_strides(strides); + if (subm) { + // the out shape of subm_conv is same as input shape + // reset the padding=kernel_size/2 and strides=1 + ResetSubmKernelSizeAndStrides(kernel.dims(), &subm_paddings, &subm_strides); + } + + GetOutShape( + x_dims, kernel_sizes, subm_paddings, dilations, subm_strides, &out_dims); + + out->set_dtype(x.dtype()); + 
out->set_dims(out_dims); + out->set_layout(x.layout()); + + rulebook->set_dtype(DataType::INT32); + rulebook->set_layout(DataLayout::NCHW); + rulebook->set_dims({1}); + + counter->set_dtype(DataType::INT32); + counter->set_layout(DataLayout::NCHW); + counter->set_dims({1}); +} + +inline const std::vector PoolResetKernel( + const std::vector& kernel_sizes, + const int in_channels, + const int out_channels) { + std::vector res(kernel_sizes); + res.resize(5); + res[3] = in_channels; + res[4] = out_channels; + return res; +} + +void Pool3dInferMeta(const MetaTensor& x, + const std::vector& kernel_sizes, + const std::vector& paddings, + const std::vector& dilations, + const std::vector& strides, + MetaTensor* out, + MetaTensor* rulebook, + MetaTensor* counter) { + const auto& x_dims = x.dims(); + DDim out_dims = {1, 1, 1, 1, 1}; + + const std::vector& real_kernel_sizes = + PoolResetKernel(kernel_sizes, x_dims[4], x_dims[4]); + GetOutShape( + x_dims, real_kernel_sizes, paddings, dilations, strides, &out_dims); + out->set_dtype(x.dtype()); + out->set_dims(out_dims); + out->set_layout(x.layout()); + + rulebook->set_dtype(DataType::INT32); + rulebook->set_layout(DataLayout::NCHW); + rulebook->set_dims({1}); + + counter->set_dtype(DataType::INT32); + counter->set_layout(DataLayout::NCHW); + counter->set_dims({1}); +} + +void SparseCooTensorInferMeta(const MetaTensor& values, + const MetaTensor& indices, + const IntArray& dense_shape, + MetaTensor* out) { + out->set_dims(phi::make_ddim(dense_shape.GetData())); + out->set_dtype(values.dtype()); + out->set_layout(values.layout()); +} + +} // namespace sparse +} // namespace phi diff --git a/paddle/phi/infermeta/sparse/binary.h b/paddle/phi/infermeta/sparse/binary.h new file mode 100644 index 00000000000..39d58bb539e --- /dev/null +++ b/paddle/phi/infermeta/sparse/binary.h @@ -0,0 +1,52 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/phi/common/int_array.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/core/meta_tensor.h" +#include "paddle/phi/core/tensor_meta.h" + +namespace phi { +namespace sparse { + +void Conv3dInferMeta(const MetaTensor& x, + const MetaTensor& kernel, + const std::vector& paddings, + const std::vector& dilations, + const std::vector& strides, + const int groups, + const bool subm, + const std::string& key, + MetaTensor* out, + MetaTensor* rulebook, + MetaTensor* counter); + +void Pool3dInferMeta(const MetaTensor& x, + const std::vector& kernel_sizes, + const std::vector& paddings, + const std::vector& dilations, + const std::vector& strides, + MetaTensor* out, + MetaTensor* rulebook, + MetaTensor* counter); + +void SparseCooTensorInferMeta(const MetaTensor& values, + const MetaTensor& indices, + const IntArray& dense_shape, + MetaTensor* out); + +} // namespace sparse +} // namespace phi diff --git a/paddle/phi/infermeta/sparse/multiary.cc b/paddle/phi/infermeta/sparse/multiary.cc new file mode 100644 index 00000000000..fc940239d40 --- /dev/null +++ b/paddle/phi/infermeta/sparse/multiary.cc @@ -0,0 +1,32 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/infermeta/sparse/multiary.h" + +namespace phi { +namespace sparse { + +void FusedAttentionInferMeta(const MetaTensor& query, + const MetaTensor& key, + const MetaTensor& value, + const MetaTensor& sparse_mask, + const MetaTensor& key_padding_mask, + const MetaTensor& attn_mask, + MetaTensor* out, + MetaTensor* softmax) { + // TODO(zhouwei,zhangkaihuo) add correct infer meta +} + +} // namespace sparse +} // namespace phi diff --git a/paddle/phi/infermeta/sparse/multiary.h b/paddle/phi/infermeta/sparse/multiary.h new file mode 100644 index 00000000000..20070e2cd9d --- /dev/null +++ b/paddle/phi/infermeta/sparse/multiary.h @@ -0,0 +1,32 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once + +#include "paddle/phi/core/meta_tensor.h" + +namespace phi { +namespace sparse { + +void FusedAttentionInferMeta(const MetaTensor& query, + const MetaTensor& key, + const MetaTensor& value, + const MetaTensor& sparse_mask, + const MetaTensor& key_padding_mask, + const MetaTensor& attn_mask, + MetaTensor* out, + MetaTensor* softmax); + +} // namespace sparse +} // namespace phi diff --git a/paddle/phi/infermeta/sparse/unary.cc b/paddle/phi/infermeta/sparse/unary.cc new file mode 100644 index 00000000000..45cb4f75e38 --- /dev/null +++ b/paddle/phi/infermeta/sparse/unary.cc @@ -0,0 +1,36 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/phi/infermeta/sparse/unary.h" + +#include "paddle/phi/core/infermeta_utils.h" + +namespace phi { +namespace sparse { + +void IndicesInferMeta(const MetaTensor& x, MetaTensor* out) { + out->set_dims({-1}); + out->set_dtype(DataType::INT32); + out->set_layout(DataLayout::NCHW); +} + +void ValuesInferMeta(const MetaTensor& x, MetaTensor* out) { + const auto& x_dims = x.dims(); + out->set_dims({-1, x_dims[x_dims.size() - 1]}); + out->set_dtype(x.dtype()); + out->set_layout(x.layout()); +} + +} // namespace sparse +} // namespace phi diff --git a/paddle/phi/api/lib/sparse_api_custom_impl.h b/paddle/phi/infermeta/sparse/unary.h similarity index 66% rename from paddle/phi/api/lib/sparse_api_custom_impl.h rename to paddle/phi/infermeta/sparse/unary.h index 6053d281f0f..880e90b7ae6 100644 --- a/paddle/phi/api/lib/sparse_api_custom_impl.h +++ b/paddle/phi/infermeta/sparse/unary.h @@ -14,19 +14,15 @@ limitations under the License. */ #pragma once -#include "paddle/phi/api/include/tensor.h" -#include "paddle/phi/common/backend.h" +#include "paddle/phi/core/meta_tensor.h" +#include "paddle/phi/core/tensor_meta.h" -namespace paddle { -namespace experimental { +namespace phi { namespace sparse { -Tensor to_dense_impl(const Tensor& x); +void IndicesInferMeta(const MetaTensor& x, MetaTensor* out); -Tensor to_sparse_coo_impl(const Tensor& x, const int64_t sparse_dim); - -Tensor to_sparse_csr_impl(const Tensor& x); +void ValuesInferMeta(const MetaTensor& x, MetaTensor* out); } // namespace sparse -} // namespace experimental -} // namespace paddle +} // namespace phi diff --git a/paddle/phi/kernels/CMakeLists.txt b/paddle/phi/kernels/CMakeLists.txt index d60584f77dc..7ea9041d77a 100644 --- a/paddle/phi/kernels/CMakeLists.txt +++ b/paddle/phi/kernels/CMakeLists.txt @@ -45,7 +45,8 @@ set(COMMON_KERNEL_DEPS selected_rows_functor) # remove this dep after removing fluid deps on tensor creation set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} phi_api_utils) 
-set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} infermeta infermeta_utils) +set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} infermeta infermeta_utils + sparse_infermeta) set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} switch_autotune) set(COMMON_KERNEL_DEPS diff --git a/paddle/phi/kernels/sparse/cpu/elementwise_grad_kernel.cc b/paddle/phi/kernels/sparse/cpu/elementwise_grad_kernel.cc index 9c0939ec114..58ed3f2d6b0 100644 --- a/paddle/phi/kernels/sparse/cpu/elementwise_grad_kernel.cc +++ b/paddle/phi/kernels/sparse/cpu/elementwise_grad_kernel.cc @@ -12,6 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include "paddle/phi/kernels/sparse/elementwise_grad_kernel.h" +#include "paddle/phi/kernels/sparse/elementwise_kernel.h" + #include "glog/logging.h" #include "paddle/phi/backends/cpu/cpu_context.h" @@ -24,7 +27,7 @@ limitations under the License. */ #include "paddle/phi/kernels/elementwise_kernel.h" #include "paddle/phi/kernels/empty_kernel.h" #include "paddle/phi/kernels/funcs/eigen/common.h" -#include "paddle/phi/kernels/sparse/elementwise_kernel.h" +#include "paddle/phi/kernels/sparse/empty_kernel.h" namespace phi { namespace sparse { @@ -45,7 +48,7 @@ void AllocCooPtr(const Context& dev_ctx, SparseCooTensor* dx) { DenseTensor dx_indices = phi::EmptyLike(dev_ctx, x.indices()); DenseTensor dx_values = phi::EmptyLike(dev_ctx, x.values()); - dx->SetMember(dx_indices, dx_values, x.dims(), true); + dx->SetMember(dx_indices, dx_values, x.dims(), x.coalesced()); } template diff --git a/paddle/phi/kernels/sparse/cpu/elementwise_kernel.cc b/paddle/phi/kernels/sparse/cpu/elementwise_kernel.cc index 4156e46dc81..4e0eb90d781 100644 --- a/paddle/phi/kernels/sparse/cpu/elementwise_kernel.cc +++ b/paddle/phi/kernels/sparse/cpu/elementwise_kernel.cc @@ -13,14 +13,15 @@ See the License for the specific language governing permissions and limitations under 
the License. */ #include "paddle/phi/kernels/sparse/elementwise_kernel.h" - #include "paddle/phi/core/enforce.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/tensor_meta.h" #include "paddle/phi/core/visit_type.h" +#include "paddle/phi/kernels/elementwise_add_kernel.h" #include "paddle/phi/kernels/elementwise_kernel.h" #include "paddle/phi/kernels/funcs/elementwise_functor.h" #include "paddle/phi/kernels/funcs/sparse/flatten_indices.h" +#include "paddle/phi/kernels/sparse/empty_kernel.h" #include "paddle/phi/kernels/sparse/sparse_utils_kernel.h" namespace phi { @@ -246,9 +247,7 @@ void ElementWiseCooKernelImpl(const Context& dev_ctx, vectorize(slice_ddim(x.values().dims(), 1, x.values().dims().size())); indeces_dim.insert(indeces_dim.begin(), nnz); DenseTensorMeta values_meta( - paddle::experimental::CppTypeToDataType::Type(), - phi::make_ddim(indeces_dim), - DataLayout::NCHW); + x.dtype(), phi::make_ddim(indeces_dim), DataLayout::NCHW); phi::DenseTensor out_indices = phi::Empty(dev_ctx, std::move(indices_meta)); phi::DenseTensor out_values = phi::Empty(dev_ctx, std::move(values_meta)); @@ -263,22 +262,16 @@ void ElementWiseCooKernelImpl(const Context& dev_ctx, } } -#define DEFINE_CSR_ELEMENTWISE_CPU_KERNEL(name) \ - template \ - void ElementWise##name##CsrCPUKernel(const Context& dev_ctx, \ - const SparseCsrTensor& x, \ - const SparseCsrTensor& y, \ - SparseCsrTensor* out) { \ - funcs::name##Functor functor; \ - auto coo_x = CsrToCoo(dev_ctx, x); \ - auto coo_y = CsrToCoo(dev_ctx, y); \ - DenseTensor indeces; \ - DenseTensor values; \ - SparseCooTensor coo_out; \ - coo_out.SetMember(indeces, values, x.dims()); \ - ElementWiseCooKernelImpl>( \ - dev_ctx, coo_x, coo_y, &coo_out, functor); \ - *out = CooToCsr(dev_ctx, coo_out); \ +#define DEFINE_CSR_ELEMENTWISE_CPU_KERNEL(name) \ + template \ + void ElementWise##name##CsrCPUKernel(const Context& dev_ctx, \ + const SparseCsrTensor& x, \ + const SparseCsrTensor& y, \ + SparseCsrTensor* out) { \ 
+ auto coo_x = CsrToCoo(dev_ctx, x); \ + auto coo_y = CsrToCoo(dev_ctx, y); \ + auto coo_out = ElementWise##name##Coo(dev_ctx, coo_x, coo_y); \ + CooToCsrKernel(dev_ctx, coo_out, out); \ } #define DEFINE_CSR_ELEMENTWISE_KERNEL(name) \ diff --git a/paddle/phi/kernels/sparse/cpu/sparse_utils_kernel.cc b/paddle/phi/kernels/sparse/cpu/sparse_utils_kernel.cc index 5199f42ed99..d0016099cd7 100644 --- a/paddle/phi/kernels/sparse/cpu/sparse_utils_kernel.cc +++ b/paddle/phi/kernels/sparse/cpu/sparse_utils_kernel.cc @@ -103,6 +103,7 @@ void DenseToCooKernel(const Context& dev_ctx, ++index; } } + out->SetMember(indices, values, x_dims, true); } @@ -181,17 +182,12 @@ void CooToCsrCPUKernel(const CPUContext& dev_ctx, int batchs = x_dims.size() == 2 ? 1 : x_dims[0]; int rows = x_dims.size() == 2 ? x_dims[0] : x_dims[1]; - phi::DenseTensor crows; - crows.Resize({batchs * (rows + 1)}); - IntT* csr_crows_data = dev_ctx.template Alloc(&crows); - - phi::DenseTensor cols; - cols.Resize({non_zero_num}); - IntT* csr_cols_data = dev_ctx.template Alloc(&cols); - - phi::DenseTensor values; - values.Resize({non_zero_num}); - T* csr_values_data = dev_ctx.template Alloc(&values); + phi::DenseTensor crows = phi::Empty(dev_ctx, {batchs * (rows + 1)}); + phi::DenseTensor cols = phi::Empty(dev_ctx, {non_zero_num}); + phi::DenseTensor values = phi::EmptyLike(dev_ctx, x.values()); + IntT* csr_crows_data = crows.data(); + IntT* csr_cols_data = cols.data(); + T* csr_values_data = values.data(); const auto& coo_indices = x.indices(); const auto& coo_values = x.values(); @@ -270,8 +266,7 @@ void CooToDenseCPUKernel(const CPUContext& dev_ctx, const int64_t dense_dim = x.dense_dim(); const T* x_data = values.data(); - *out = phi::Empty(dev_ctx, - DenseTensorMeta(x.dtype(), x.dims(), x.values().layout())); + dev_ctx.template Alloc(out); T* out_data = out->data(); int64_t base_offset = 1; for (int64_t i = 0; i < dense_dim; i++) { @@ -403,6 +398,21 @@ PD_REGISTER_KERNEL(values_coo, 
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO); } +PD_REGISTER_KERNEL(indices_coo, + CPU, + ALL_LAYOUT, + phi::sparse::IndicesCooKernel, + float, + double, + phi::dtype::float16, + uint8_t, + int8_t, + int16_t, + int, + int64_t) { + kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO); +} + PD_REGISTER_KERNEL(values_csr, CPU, ALL_LAYOUT, @@ -415,7 +425,7 @@ PD_REGISTER_KERNEL(values_csr, int16_t, int, int64_t) { - kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO); + kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR); } PD_REGISTER_KERNEL(sparse_coo_tensor, diff --git a/paddle/phi/kernels/sparse/elementwise_grad_kernel.h b/paddle/phi/kernels/sparse/elementwise_grad_kernel.h index df3feb597e3..86eb3b4381d 100644 --- a/paddle/phi/kernels/sparse/elementwise_grad_kernel.h +++ b/paddle/phi/kernels/sparse/elementwise_grad_kernel.h @@ -15,7 +15,9 @@ limitations under the License. */ #pragma once #include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/sparse_coo_tensor.h" #include "paddle/phi/core/sparse_csr_tensor.h" +#include "paddle/phi/infermeta/sparse/unary.h" #include "paddle/phi/kernels/empty_kernel.h" namespace phi { @@ -49,6 +51,9 @@ namespace sparse { const Sparse##type##Tensor& dout) { \ Sparse##type##Tensor dx; \ Sparse##type##Tensor dy; \ + MetaTensor meta_dx(&dx), meta_dy(&dy); \ + phi::UnchangedInferMeta(x, &meta_dx); \ + phi::UnchangedInferMeta(y, &meta_dy); \ ElementWise##name##type##GradKernel( \ dev_ctx, x, y, dout, &dx, &dy); \ return std::vector{dx, dy}; \ @@ -89,6 +94,9 @@ std::vector ElementWiseDivideCsrGrad( const SparseCsrTensor& dout) { SparseCsrTensor dx; SparseCsrTensor dy; + MetaTensor meta_dx(&dx), meta_dy(&dy); + phi::UnchangedInferMeta(x, &meta_dx); + phi::UnchangedInferMeta(y, &meta_dy); ElementWiseDivideCsrGradKernel( dev_ctx, x, y, out, dout, &dx, &dy); return std::vector{dx, dy}; @@ -103,6 +111,9 @@ std::vector ElementWiseDivideCooGrad( const SparseCooTensor& dout) { 
SparseCooTensor dx; SparseCooTensor dy; + MetaTensor meta_dx(&dx), meta_dy(&dy); + phi::UnchangedInferMeta(x, &meta_dx); + phi::UnchangedInferMeta(y, &meta_dy); ElementWiseDivideCooGradKernel( dev_ctx, x, y, out, dout, &dx, &dy); return std::vector{dx, dy}; diff --git a/paddle/phi/kernels/sparse/elementwise_kernel.h b/paddle/phi/kernels/sparse/elementwise_kernel.h index 0f9e67f7063..59a554348cf 100644 --- a/paddle/phi/kernels/sparse/elementwise_kernel.h +++ b/paddle/phi/kernels/sparse/elementwise_kernel.h @@ -17,6 +17,7 @@ limitations under the License. */ #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/sparse_coo_tensor.h" #include "paddle/phi/core/sparse_csr_tensor.h" +#include "paddle/phi/infermeta/binary.h" namespace phi { namespace sparse { @@ -45,8 +46,10 @@ namespace sparse { const SparseCsrTensor& y) { \ DenseTensor crows; \ DenseTensor cols; \ - DenseTensor non_zero_elements; \ - SparseCsrTensor out(crows, cols, non_zero_elements, x.dims()); \ + DenseTensor values; \ + SparseCsrTensor out(crows, cols, values, x.dims()); \ + MetaTensor meta_out(out); \ + phi::ElementwiseInferMeta(x, y, &meta_out); \ ElementWise##name##CsrKernel(dev_ctx, x, y, &out); \ return out; \ } @@ -57,8 +60,10 @@ namespace sparse { const SparseCooTensor& x, \ const SparseCooTensor& y) { \ DenseTensor indices; \ - DenseTensor non_zero_elements; \ - SparseCooTensor out(indices, non_zero_elements, x.dims()); \ + DenseTensor values; \ + SparseCooTensor out(indices, values, x.dims()); \ + MetaTensor meta_out(out); \ + phi::ElementwiseInferMeta(x, y, &meta_out); \ ElementWise##name##CooKernel(dev_ctx, x, y, &out); \ return out; \ } diff --git a/paddle/phi/kernels/sparse/empty_kernel.cc b/paddle/phi/kernels/sparse/empty_kernel.cc index ebe0abc45ce..96a7301c589 100644 --- a/paddle/phi/kernels/sparse/empty_kernel.cc +++ b/paddle/phi/kernels/sparse/empty_kernel.cc @@ -26,11 +26,10 @@ template void EmptyLikeCooKernel(const Context& dev_ctx, const SparseCooTensor& x, 
SparseCooTensor* out) { - out->set_dims(x.dims()); *(out->mutable_indices()) = x.indices(); - const DenseTensor& x_values = x.non_zero_elements(); - DenseTensor* out_values = out->mutable_non_zero_elements(); + const DenseTensor& x_values = x.values(); + DenseTensor* out_values = out->mutable_values(); out_values->Resize(x_values.dims()); dev_ctx.template Alloc(out_values); } @@ -39,12 +38,11 @@ template void EmptyLikeCsrKernel(const Context& dev_ctx, const SparseCsrTensor& x, SparseCsrTensor* out) { - out->set_dims(x.dims()); *(out->mutable_crows()) = x.crows(); *(out->mutable_cols()) = x.cols(); - const DenseTensor& x_values = x.non_zero_elements(); - DenseTensor* out_values = out->mutable_non_zero_elements(); + const DenseTensor& x_values = x.values(); + DenseTensor* out_values = out->mutable_values(); out_values->Resize(x_values.dims()); dev_ctx.template Alloc(out_values); } diff --git a/paddle/phi/kernels/sparse/gpu/elementwise_grad_kernel.cu b/paddle/phi/kernels/sparse/gpu/elementwise_grad_kernel.cu new file mode 100644 index 00000000000..e434dad588e --- /dev/null +++ b/paddle/phi/kernels/sparse/gpu/elementwise_grad_kernel.cu @@ -0,0 +1,56 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/phi/kernels/sparse/elementwise_grad_kernel.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/tensor_utils.h" +#include "paddle/phi/kernels/sparse/empty_kernel.h" + +namespace phi { +namespace sparse { + +template +void ElementWiseAddCooGradKernel(const Context& dev_ctx, + const SparseCooTensor& x, + const SparseCooTensor& y, + const SparseCooTensor& dout, + SparseCooTensor* dx, + SparseCooTensor* dy) { + if (dx) { + EmptyLikeCooKernel(dev_ctx, x, dx); + Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dx); + } + + if (dy) { + EmptyLikeCooKernel(dev_ctx, y, dy); + Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dy); + } +} + +} // namespace sparse +} // namespace phi + +PD_REGISTER_KERNEL(add_coo_coo_grad, + GPU, + ALL_LAYOUT, + phi::sparse::ElementWiseAddCooGradKernel, + float, + double, + int16_t, + int, + int64_t, + phi::dtype::float16) { + kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO); + kernel->InputAt(1).SetDataLayout(phi::DataLayout::SPARSE_COO); +} diff --git a/paddle/phi/kernels/sparse/gpu/elementwise_kernel.cu b/paddle/phi/kernels/sparse/gpu/elementwise_kernel.cu new file mode 100644 index 00000000000..7496f47de89 --- /dev/null +++ b/paddle/phi/kernels/sparse/gpu/elementwise_kernel.cu @@ -0,0 +1,88 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include +#include + +#include "paddle/phi/kernels/elementwise_add_kernel.h" +#include "paddle/phi/kernels/sparse/elementwise_kernel.h" +#include "paddle/phi/kernels/sparse/empty_kernel.h" + +#include "paddle/phi/core/enforce.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/visit_type.h" + +namespace phi { +namespace sparse { + +template +void ElementWiseAddCooGPUKernel(const GPUContext& dev_ctx, + const SparseCooTensor& x, + const SparseCooTensor& y, + SparseCooTensor* out) { + const auto& x_indices = x.indices(); + const auto& y_indices = y.indices(); + PADDLE_ENFORCE_EQ( + x_indices.numel(), + y_indices.numel(), + phi::errors::PreconditionNotMet( + "The numel of x.indices() and y.indices() should be equal")); + const IntT* x_indices_ptr = x_indices.data(); + const IntT* y_indices_ptr = y_indices.data(); +#ifdef PADDLE_WITH_HIP + bool is_same = thrust::equal(thrust::hip::par.on(dev_ctx.stream()), +#else + bool is_same = thrust::equal(thrust::cuda::par.on(dev_ctx.stream()), +#endif + x_indices_ptr, + x_indices_ptr + x_indices.numel(), + y_indices_ptr); + PADDLE_ENFORCE_EQ( + is_same, + true, + phi::errors::PreconditionNotMet( + "Currently, ElementWiseAddCooKernel only supports the case " + "where x and y have the same indices")); + EmptyLikeCooKernel(dev_ctx, x, out); + phi::AddKernel( + dev_ctx, x.values(), y.values(), out->mutable_values()); +} + +template +void ElementWiseAddCooKernel(const Context& dev_ctx, + const SparseCooTensor& x, + const SparseCooTensor& y, + SparseCooTensor* out) { + PD_VISIT_BASE_INTEGRAL_TYPES(x.indices().dtype(), "VerifyIndices", ([&] { + ElementWiseAddCooGPUKernel( + dev_ctx, x, y, out); + })); +} + +} // namespace sparse +} // namespace phi + +PD_REGISTER_KERNEL(add_coo_coo, + GPU, + ALL_LAYOUT, + phi::sparse::ElementWiseAddCooKernel, + float, + double, + int16_t, + int, + int64_t, + phi::dtype::float16) { + kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO); + 
kernel->InputAt(1).SetDataLayout(phi::DataLayout::SPARSE_COO); +} diff --git a/paddle/phi/kernels/sparse/gpu/sparse_utils_kernel.cu b/paddle/phi/kernels/sparse/gpu/sparse_utils_kernel.cu index 2ceda7da750..c037f6b1b83 100644 --- a/paddle/phi/kernels/sparse/gpu/sparse_utils_kernel.cu +++ b/paddle/phi/kernels/sparse/gpu/sparse_utils_kernel.cu @@ -172,6 +172,7 @@ void DenseToCooKernel(const Context& dev_ctx, temp_indexs_ptr, indices_data, sparse_data); + out->SetMember(indices, values, x_dims, true); } @@ -461,8 +462,8 @@ void CooToDenseGPUKernel(const GPUContext& dev_ctx, const auto place = dev_ctx.GetPlace(); const T* x_data = values.data(); - *out = phi::Empty( - dev_ctx, phi::DenseTensorMeta(x.dtype(), x.dims(), x.values().layout())); + dev_ctx.template Alloc(out); + T* out_data = out->data(); int64_t base_offset = 1; for (int64_t i = 0; i < dense_dim; i++) { @@ -619,6 +620,21 @@ PD_REGISTER_KERNEL(values_csr, int16_t, int, int64_t) { + kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR); +} + +PD_REGISTER_KERNEL(indices_coo, + GPU, + ALL_LAYOUT, + phi::sparse::IndicesCooKernel, + float, + double, + phi::dtype::float16, + uint8_t, + int8_t, + int16_t, + int, + int64_t) { kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO); } diff --git a/paddle/phi/kernels/sparse/impl/unary_kernel_impl.h b/paddle/phi/kernels/sparse/impl/unary_kernel_impl.h index c965be21fb0..9b8b33d4d3a 100644 --- a/paddle/phi/kernels/sparse/impl/unary_kernel_impl.h +++ b/paddle/phi/kernels/sparse/impl/unary_kernel_impl.h @@ -129,8 +129,6 @@ void CastCooKernel(const Context& dev_ctx, DataType index_dtype, DataType value_dtype, SparseCooTensor* out) { - out->set_dims(x.dims()); - const DenseTensor& x_indices = x.indices(); const DenseTensor& x_values = x.non_zero_elements(); DenseTensor* out_indices = out->mutable_indices(); @@ -165,8 +163,6 @@ void CastCsrKernel(const Context& dev_ctx, DataType index_dtype, DataType value_dtype, SparseCsrTensor* out) { - 
out->set_dims(x.dims()); - const DenseTensor& x_crows = x.crows(); const DenseTensor& x_cols = x.cols(); const DenseTensor& x_values = x.non_zero_elements(); diff --git a/paddle/phi/kernels/sparse/sparse_utils_kernel.h b/paddle/phi/kernels/sparse/sparse_utils_kernel.h index 932427d42cd..fa16114e0f9 100644 --- a/paddle/phi/kernels/sparse/sparse_utils_kernel.h +++ b/paddle/phi/kernels/sparse/sparse_utils_kernel.h @@ -18,6 +18,7 @@ limitations under the License. */ #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/sparse_coo_tensor.h" #include "paddle/phi/core/sparse_csr_tensor.h" +#include "paddle/phi/infermeta/unary.h" #include "paddle/phi/kernels/empty_kernel.h" namespace phi { @@ -36,6 +37,8 @@ SparseCooTensor DenseToCoo(const Context& dev_ctx, DenseTensor indices; DenseTensor values; SparseCooTensor coo(indices, values, x.dims()); + MetaTensor meta_out(&coo); + phi::UnchangedInferMeta(x, &meta_out); DenseToCooKernel(dev_ctx, x, sparse_dim, &coo); return coo; } @@ -50,6 +53,8 @@ SparseCooTensor CsrToCoo(const Context& dev_ctx, const SparseCsrTensor& x) { DenseTensor indices; DenseTensor values; SparseCooTensor coo(indices, values, x.dims()); + MetaTensor meta_out(&coo); + phi::UnchangedInferMeta(x, &meta_out); CsrToCooKernel(dev_ctx, x, &coo); return coo; } @@ -65,6 +70,8 @@ SparseCsrTensor CooToCsr(const Context& dev_ctx, const SparseCooTensor& x) { DenseTensor cols; DenseTensor non_zero_elements; SparseCsrTensor csr(crows, cols, non_zero_elements, x.dims()); + MetaTensor meta_out(&csr); + phi::UnchangedInferMeta(x, &meta_out); CooToCsrKernel(dev_ctx, x, &csr); return csr; } @@ -79,10 +86,13 @@ void DenseToCsrKernel(const Context& dev_ctx, true, phi::errors::InvalidArgument( "SparseCsrTensor only support 2-D or 3-D Tensor.")); + const int64_t sparse_dim = x_dims.size() == 2 ? 
2 : 3; DenseTensor indices; DenseTensor values; SparseCooTensor coo(indices, values, x.dims()); + MetaTensor meta_out(&coo); + phi::UnchangedInferMeta(x, &meta_out); DenseToCooKernel(dev_ctx, x, sparse_dim, &coo); CooToCsrKernel(dev_ctx, coo, out); } @@ -93,6 +103,8 @@ SparseCsrTensor DenseToCsr(const Context& dev_ctx, const DenseTensor& x) { DenseTensor cols; DenseTensor non_zero_elements; SparseCsrTensor csr(crows, cols, non_zero_elements, x.dims()); + MetaTensor meta_out(&csr); + phi::UnchangedInferMeta(x, &meta_out); DenseToCsrKernel(dev_ctx, x, &csr); return csr; } @@ -117,6 +129,8 @@ void CsrToDenseKernel(const Context& dev_ctx, DenseTensor indices; DenseTensor values; SparseCooTensor coo(indices, values, x.dims()); + MetaTensor meta_out(&coo); + phi::UnchangedInferMeta(x, &meta_out); CsrToCooKernel(dev_ctx, x, &coo); CooToDenseKernel(dev_ctx, coo, out); } @@ -143,6 +157,13 @@ void ValuesCsrKernel(const Context& dev_ctx, *out = x.non_zero_elements(); } +template +void IndicesCooKernel(const Context& dev_ctx, + const SparseCooTensor& x, + DenseTensor* out) { + *out = x.indices(); +} + template void SparseCooTensorKernel(const Context& dev_ctx, const DenseTensor& values, diff --git a/paddle/phi/tests/core/test_sparse_coo_tensor.cc b/paddle/phi/tests/core/test_sparse_coo_tensor.cc index e9ee1dde6b2..81e58843f54 100644 --- a/paddle/phi/tests/core/test_sparse_coo_tensor.cc +++ b/paddle/phi/tests/core/test_sparse_coo_tensor.cc @@ -52,7 +52,6 @@ TEST(sparse_coo_tensor, construct) { CHECK_EQ(sparse.numel(), 9); CHECK(sparse.dims() == dense_dims); CHECK(sparse.dtype() == DataType::FLOAT32); - CHECK(sparse.layout() == DataLayout::SPARSE_COO); CHECK(sparse.place() == phi::CPUPlace()); } diff --git a/paddle/phi/tests/core/test_sparse_csr_tensor.cc b/paddle/phi/tests/core/test_sparse_csr_tensor.cc index 7fad7bac399..42f87fc5aae 100644 --- a/paddle/phi/tests/core/test_sparse_csr_tensor.cc +++ b/paddle/phi/tests/core/test_sparse_csr_tensor.cc @@ -62,7 +62,6 @@ 
TEST(sparse_csr_tensor, construct) { CHECK_EQ(sparse.numel(), 9); CHECK(sparse.dims() == dense_dims); CHECK(sparse.dtype() == DataType::FLOAT32); - CHECK(sparse.layout() == DataLayout::SPARSE_CSR); CHECK(sparse.place() == paddle::platform::CPUPlace()); CHECK(sparse.initialized() == true); } diff --git a/python/paddle/fluid/tests/unittests/test_sparse_elementwise_op.py b/python/paddle/fluid/tests/unittests/test_sparse_elementwise_op.py index 149c4cfb22b..20f66e5f9a6 100644 --- a/python/paddle/fluid/tests/unittests/test_sparse_elementwise_op.py +++ b/python/paddle/fluid/tests/unittests/test_sparse_elementwise_op.py @@ -18,7 +18,7 @@ from operator import __add__, __sub__, __mul__, __truediv__ import numpy as np import paddle -from paddle.fluid.framework import _test_eager_guard +import paddle.incubate.sparse as sparse op_list = [__add__, __sub__, __mul__, __truediv__] @@ -134,6 +134,35 @@ class TestSparseElementWiseAPI(unittest.TestCase): for op in op_list: self.func_test_coo(op) + def test_add_same_indices(self): + indices_data = [[0, 1], [0, 3]] + values1_data = [[1.0], [2.0]] + values2_data = [[1.0], [2.0]] + shape = [2, 4, 2] + + sp_a = sparse.sparse_coo_tensor(indices_data, + values1_data, + shape, + stop_gradient=False) + sp_b = sparse.sparse_coo_tensor(indices_data, + values2_data, + shape, + stop_gradient=False) + + values1 = paddle.to_tensor(values1_data, stop_gradient=False) + values2 = paddle.to_tensor(values2_data, stop_gradient=False) + + #c.values() = a.values() + b.values() + sp_c = sparse.add(sp_a, sp_b) + sp_c.backward() + ref_c = values1 + values2 + ref_c.backward() + np.testing.assert_allclose(sp_c.values().numpy(), ref_c.numpy()) + np.testing.assert_allclose(sp_a.grad.values().numpy(), + values1.grad.numpy()) + np.testing.assert_allclose(sp_b.grad.values().numpy(), + values2.grad.numpy()) + if __name__ == "__main__": paddle.device.set_device('cpu') -- GitLab