diff --git a/.gitignore b/.gitignore index debec551d9cd7344a31efbbb709bfbb759a15d3f..a2009a1ed30a1c6a17627b06170734fc17390d31 100644 --- a/.gitignore +++ b/.gitignore @@ -7,9 +7,11 @@ paddle/fluid/op_use_default_grad_maker_DEV.spec paddle/fluid/op_use_default_grad_maker_PR.spec paddle/phi/api/backward/backward_api.h paddle/phi/api/include/api.h +paddle/phi/api/include/sparse_api.h paddle/phi/api/lib/api.cc paddle/phi/api/lib/dygraph_api.* paddle/phi/api/lib/backward_api.cc +paddle/phi/api/lib/sparse_api.cc paddle/phi/extension.h paddle/phi/include/* paddle/phi/infermeta/generated.* diff --git a/paddle/phi/api/lib/CMakeLists.txt b/paddle/phi/api/lib/CMakeLists.txt index 5edb83f8c3fc01d198d3f63b64047b9e45cd747b..4f449c578bab00482fd91528496f4d8788f927b1 100644 --- a/paddle/phi/api/lib/CMakeLists.txt +++ b/paddle/phi/api/lib/CMakeLists.txt @@ -32,6 +32,14 @@ set(bw_api_source_file ${CMAKE_SOURCE_DIR}/paddle/phi/api/lib/backward_api.cc) set(bw_api_header_file_tmp ${bw_api_header_file}.tmp) set(bw_api_source_file_tmp ${bw_api_source_file}.tmp) +# sparse api file +set(sparse_api_gen_file ${CMAKE_SOURCE_DIR}/python/paddle/utils/code_gen/sparse_api_gen.py) +set(sparse_api_yaml_file ${CMAKE_SOURCE_DIR}/python/paddle/utils/code_gen/sparse_api.yaml) +set(sparse_api_header_file ${CMAKE_SOURCE_DIR}/paddle/phi/api/include/sparse_api.h) +set(sparse_api_source_file ${CMAKE_SOURCE_DIR}/paddle/phi/api/lib/sparse_api.cc) +set(sparse_api_header_file_tmp ${sparse_api_header_file}.tmp) +set(sparse_api_source_file_tmp ${sparse_api_source_file}.tmp) + # wrapped infermeta file set(wrapped_infermeta_gen_file ${CMAKE_SOURCE_DIR}/python/paddle/utils/code_gen/wrapped_infermeta_gen.py) set(api_yaml_file ${CMAKE_SOURCE_DIR}/python/paddle/utils/code_gen/api.yaml) @@ -73,6 +81,19 @@ add_custom_command( DEPENDS ${bw_api_yaml_file} ${bw_api_gen_file} ${api_gen_base} VERBATIM) +# generate sparse api +add_custom_command( + OUTPUT ${sparse_api_header_file} ${sparse_api_source_file} + COMMAND ${PYTHON_EXECUTABLE} ${sparse_api_gen_file} + --api_yaml_path ${sparse_api_yaml_file} + --api_header_path ${sparse_api_header_file_tmp} + --api_source_path ${sparse_api_source_file_tmp} + COMMAND ${CMAKE_COMMAND} -E copy_if_different ${sparse_api_header_file_tmp} ${sparse_api_header_file} + COMMAND ${CMAKE_COMMAND} -E copy_if_different ${sparse_api_source_file_tmp} ${sparse_api_source_file} + COMMENT "copy_if_different ${sparse_api_header_file} ${sparse_api_source_file}" + DEPENDS ${sparse_api_yaml_file} ${sparse_api_gen_file} ${api_gen_base} + VERBATIM) + # generate wrapped infermeta add_custom_command( OUTPUT ${wrapped_infermeta_header_file} ${wrapped_infermeta_source_file} @@ -87,12 +108,14 @@ cc_library(op_meta_info SRCS op_meta_info.cc DEPS phi_tensor_raw) cc_library(wrapped_infermeta SRCS ${wrapped_infermeta_source_file} DEPS phi) cc_library(kernel_dispatch SRCS kernel_dispatch.cc DEPS phi_tensor_raw phi_context kernel_factory) +cc_library(api_gen_utils SRCS api_gen_utils.cc DEPS phi_tensor_raw selected_rows sparse_csr_tensor sparse_coo_tensor) cc_library(phi_data_transform SRCS data_transform.cc DEPS phi_tensor_raw transfer_layout_kernel cast_kernel data_device_transform) -cc_library(api_custom_impl SRCS api_custom_impl.cc DEPS phi_tensor_raw phi kernel_dispatch phi_data_transform) +cc_library(api_custom_impl SRCS api_custom_impl.cc DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils phi_data_transform) +cc_library(sparse_api_custom_impl SRCS sparse_api_custom_impl.cc DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils
phi_data_transform) -cc_library(sparse_api SRCS sparse_api.cc DEPS phi_tensor_raw phi kernel_dispatch phi_data_transform) -cc_library(phi_function_api SRCS ${api_source_file} DEPS phi_tensor_raw phi kernel_dispatch phi_data_transform api_custom_impl) -cc_library(phi_dygraph_api SRCS ${dygraph_api_source_file} DEPS phi_tensor_raw phi kernel_dispatch phi_data_transform) -cc_library(phi_bw_function_api SRCS ${bw_api_source_file} DEPS phi_tensor_raw phi kernel_dispatch backward_infermeta phi_data_transform phi_function_api api_custom_impl) +cc_library(sparse_api SRCS sparse_api.cc DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils sparse_api_custom_impl) +cc_library(phi_function_api SRCS ${api_source_file} DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils phi_data_transform api_custom_impl) +cc_library(phi_dygraph_api SRCS ${dygraph_api_source_file} DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils phi_data_transform) +cc_library(phi_bw_function_api SRCS ${bw_api_source_file} DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils backward_infermeta phi_data_transform phi_function_api api_custom_impl) cc_library(phi_tensor SRCS tensor_method.cc DEPS phi_tensor_raw phi_function_api) diff --git a/paddle/phi/api/lib/api_custom_impl.cc b/paddle/phi/api/lib/api_custom_impl.cc index 19b113838eab5403aca00d9d97b278646228c512..fc1afb26bf4143e5c75398b3dc1042581e1f1546 100644 --- a/paddle/phi/api/lib/api_custom_impl.cc +++ b/paddle/phi/api/lib/api_custom_impl.cc @@ -14,8 +14,8 @@ limitations under the License. */ #include "paddle/phi/api/lib/api_custom_impl.h" +#include "paddle/phi/api/lib/api_gen_utils.h" #include "paddle/phi/api/lib/api_registry.h" -#include "paddle/phi/api/lib/api_utils.h" #include "paddle/phi/api/lib/data_transform.h" #include "paddle/phi/api/lib/kernel_dispatch.h" #include "paddle/phi/api/lib/utils/storage.h" diff --git a/paddle/phi/api/lib/api_utils.h b/paddle/phi/api/lib/api_gen_utils.cc similarity index 62% rename from paddle/phi/api/lib/api_utils.h rename to paddle/phi/api/lib/api_gen_utils.cc index 6c1fa97c0f52a697383a3526220cc758d778823d..f04e74b45fcd42cfeee860b05f52855ec15ef8f6 100644 --- a/paddle/phi/api/lib/api_utils.h +++ b/paddle/phi/api/lib/api_gen_utils.cc @@ -12,26 +12,18 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ -#pragma once - -#include "paddle/phi/api/include/tensor.h" -#include "paddle/phi/api/lib/utils/storage.h" -#include "paddle/phi/core/compat/convert_utils.h" -#include "paddle/phi/core/dense_tensor.h" -#include "paddle/phi/core/meta_tensor.h" -#include "paddle/phi/core/selected_rows.h" +#include "paddle/phi/api/lib/api_gen_utils.h" namespace paddle { namespace experimental { /* ------------------ for input ----------------------- */ -inline std::shared_ptr TensorToDenseTensor( - const Tensor& tensor) { +std::shared_ptr TensorToDenseTensor(const Tensor& tensor) { return std::dynamic_pointer_cast(tensor.impl()); } -inline std::shared_ptr TensorToDenseTensor( +std::shared_ptr TensorToDenseTensor( const paddle::optional& tensor) { if (tensor) { return std::dynamic_pointer_cast(tensor->impl()); @@ -39,7 +31,7 @@ inline std::shared_ptr TensorToDenseTensor( return nullptr; } -inline std::unique_ptr> TensorToDenseTensor( +std::unique_ptr> TensorToDenseTensor( const std::vector& tensors) { auto pt_tensors = std::make_unique>(); pt_tensors->reserve(tensors.size()); @@ -52,12 +44,11 @@ inline std::unique_ptr> TensorToDenseTensor( return std::move(pt_tensors); } -inline std::shared_ptr TensorToSelectedRows( - const Tensor& tensor) { +std::shared_ptr TensorToSelectedRows(const Tensor& tensor) { return std::dynamic_pointer_cast(tensor.impl()); } -inline std::shared_ptr TensorToSelectedRows( +std::shared_ptr TensorToSelectedRows( const paddle::optional& tensor) { if (tensor) { return std::dynamic_pointer_cast(tensor->impl()); @@ -67,11 +58,11 @@ inline std::shared_ptr TensorToSelectedRows( /* ----------------- for infer_meta --------------------- */ -inline phi::MetaTensor MakeMetaTensor(const phi::DenseTensor& tensor) { +phi::MetaTensor MakeMetaTensor(const phi::DenseTensor& tensor) { return phi::MetaTensor(tensor); } -inline paddle::optional MakeMetaTensor( +paddle::optional MakeMetaTensor( const paddle::optional& tensor) { if (tensor) { return {phi::MetaTensor(*tensor)}; @@ -79,7 +70,7 @@ inline paddle::optional MakeMetaTensor( return {paddle::none}; } -inline std::vector MakeMetaTensor( +std::vector MakeMetaTensor( const std::vector& tensors) { std::vector meta_tensors; meta_tensors.reserve(tensors.size()); @@ -89,11 +80,11 @@ inline std::vector MakeMetaTensor( return meta_tensors; } -inline phi::MetaTensor MakeMetaTensor(const phi::SelectedRows& tensor) { +phi::MetaTensor MakeMetaTensor(const phi::SelectedRows& tensor) { return phi::MetaTensor(tensor); } -inline paddle::optional MakeMetaTensor( +paddle::optional MakeMetaTensor( const paddle::optional& tensor) { if (tensor) { return {phi::MetaTensor(*tensor)}; @@ -103,7 +94,7 @@ inline paddle::optional MakeMetaTensor( /* ------------------ for output ----------------------- */ -inline phi::DenseTensor* SetKernelOutput(Backend backend, Tensor* out) { +phi::DenseTensor* SetKernelOutput(Backend backend, Tensor* out) { if (!out->initialized()) { auto dense_tensor = std::make_shared( phi::make_intrusive(phi::TransToPhiPlace(backend)), @@ -114,8 +105,9 @@ inline phi::DenseTensor* SetKernelOutput(Backend backend, Tensor* out) { return static_cast(out->impl().get()); } -inline std::vector SetKernelOutput( - size_t out_size, Backend backend, std::vector* out) { +std::vector SetKernelOutput(size_t out_size, + Backend backend, + std::vector* out) { out->reserve(out_size); std::vector results(out_size); for (size_t i = 0; i < out_size; ++i) { @@ -129,8 +121,7 @@ inline std::vector SetKernelOutput( return results; } -inline phi::SelectedRows* 
SetSelectedRowsKernelOutput(Backend backend, - Tensor* out) { +phi::SelectedRows* SetSelectedRowsKernelOutput(Backend backend, Tensor* out) { if (!out->initialized()) { auto select_rows = std::make_shared(); out->set_impl(select_rows); @@ -139,5 +130,29 @@ inline phi::SelectedRows* SetSelectedRowsKernelOutput(Backend backend, return static_cast(out->impl().get()); } +phi::TensorBase* SetSparseKernelOutput(Tensor* out, TensorType type) { + if (!out->initialized()) { + if (type == TensorType::SPARSE_COO) { + auto sparse_tensor = std::make_shared( + phi::DenseTensor(), phi::DenseTensor(), phi::DDim{-1}); + out->set_impl(sparse_tensor); + return sparse_tensor.get(); + } else if (type == TensorType::SPARSE_CSR) { + auto sparse_tensor = + std::make_shared(phi::DenseTensor(), + phi::DenseTensor(), + phi::DenseTensor(), + phi::DDim{-1}); + out->set_impl(sparse_tensor); + return sparse_tensor.get(); + } else { + auto dense_tensor = std::make_shared(); + out->set_impl(dense_tensor); + return dense_tensor.get(); + } + } + return out->impl().get(); +} + } // namespace experimental } // namespace paddle diff --git a/paddle/phi/api/lib/api_gen_utils.h b/paddle/phi/api/lib/api_gen_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..109c6e7ab71f5f889e63c410ee84aaad6c6b8110 --- /dev/null +++ b/paddle/phi/api/lib/api_gen_utils.h @@ -0,0 +1,74 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once + +#include "paddle/phi/api/include/tensor.h" +#include "paddle/phi/api/lib/utils/storage.h" +#include "paddle/phi/core/compat/convert_utils.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/meta_tensor.h" +#include "paddle/phi/core/selected_rows.h" +#include "paddle/phi/core/sparse_coo_tensor.h" +#include "paddle/phi/core/sparse_csr_tensor.h" + +namespace paddle { +namespace experimental { + +enum class TensorType { DENSE_TENSOR, SPARSE_CSR, SPARSE_COO }; + +/* ------------------ for input ----------------------- */ + +std::shared_ptr TensorToDenseTensor(const Tensor& tensor); + +std::shared_ptr TensorToDenseTensor( + const paddle::optional& tensor); + +std::unique_ptr> TensorToDenseTensor( + const std::vector& tensors); + +std::shared_ptr TensorToSelectedRows(const Tensor& tensor); + +std::shared_ptr TensorToSelectedRows( + const paddle::optional& tensor); + +/* ----------------- for infer_meta --------------------- */ + +phi::MetaTensor MakeMetaTensor(const phi::DenseTensor& tensor); + +paddle::optional MakeMetaTensor( + const paddle::optional& tensor); + +std::vector MakeMetaTensor( + const std::vector& tensors); + +phi::MetaTensor MakeMetaTensor(const phi::SelectedRows& tensor); + +paddle::optional MakeMetaTensor( + const paddle::optional& tensor); + +/* ------------------ for output ----------------------- */ + +phi::DenseTensor* SetKernelOutput(Backend backend, Tensor* out); + +std::vector SetKernelOutput(size_t out_size, + Backend backend, + std::vector* out); + +phi::SelectedRows* SetSelectedRowsKernelOutput(Backend backend, Tensor* out); + +phi::TensorBase* SetSparseKernelOutput(Tensor* out, TensorType type); + +} // namespace experimental +} // namespace paddle diff --git a/paddle/phi/api/lib/sparse_api.cc b/paddle/phi/api/lib/sparse_api_custom_impl.cc similarity index 86% rename from paddle/phi/api/lib/sparse_api.cc rename to paddle/phi/api/lib/sparse_api_custom_impl.cc index 9e1f59c0aa74329b15efcbff123b137fbf0b1360..832c19361e5eb03419fe988c9a30304b5993afdf 100644 --- a/paddle/phi/api/lib/sparse_api.cc +++ b/paddle/phi/api/lib/sparse_api_custom_impl.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/phi/api/include/sparse_api.h" +#include "paddle/phi/api/lib/sparse_api_custom_impl.h" #include #include "glog/logging.h" @@ -20,31 +20,14 @@ limitations under the License. 
*/ #include "paddle/phi/api/lib/kernel_dispatch.h" #include "paddle/phi/api/lib/utils/storage.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/infermeta/unary.h" - -PD_DECLARE_KERNEL(dense_to_sparse_coo, CPU, ALL_LAYOUT); -PD_DECLARE_KERNEL(sparse_csr_to_coo, CPU, ALL_LAYOUT); -PD_DECLARE_KERNEL(dense_to_sparse_csr, CPU, ALL_LAYOUT); -PD_DECLARE_KERNEL(sparse_coo_to_csr, CPU, ALL_LAYOUT); -PD_DECLARE_KERNEL(sparse_coo_to_dense, CPU, ALL_LAYOUT); -PD_DECLARE_KERNEL(sparse_csr_to_dense, CPU, ALL_LAYOUT); - -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) -PD_DECLARE_KERNEL(dense_to_sparse_coo, GPU, ALL_LAYOUT); -PD_DECLARE_KERNEL(sparse_csr_to_coo, GPU, ALL_LAYOUT); -PD_DECLARE_KERNEL(dense_to_sparse_csr, GPU, ALL_LAYOUT); -PD_DECLARE_KERNEL(sparse_coo_to_csr, GPU, ALL_LAYOUT); -PD_DECLARE_KERNEL(sparse_coo_to_dense, GPU, ALL_LAYOUT); -PD_DECLARE_KERNEL(sparse_csr_to_dense, GPU, ALL_LAYOUT); -#endif namespace paddle { namespace experimental { namespace sparse { -PADDLE_API Tensor to_sparse_coo(const Tensor& x, - Backend backend, - const int64_t sparse_dim) { +Tensor to_sparse_coo_impl(const Tensor& x, + Backend backend, + const int64_t sparse_dim) { if (x.layout() == phi::DataLayout::SPARSE_COO) { return x; } @@ -105,7 +88,7 @@ PADDLE_API Tensor to_sparse_coo(const Tensor& x, return out; } -PADDLE_API Tensor to_sparse_csr(const Tensor& x, Backend backend) { +Tensor to_sparse_csr_impl(const Tensor& x, Backend backend) { if (x.layout() == phi::DataLayout::SPARSE_CSR) { return x; } @@ -171,7 +154,7 @@ PADDLE_API Tensor to_sparse_csr(const Tensor& x, Backend backend) { return out; } -PADDLE_API Tensor to_dense(const Tensor& x, Backend backend) { +Tensor to_dense_impl(const Tensor& x, Backend backend) { if (x.layout() != phi::DataLayout::SPARSE_CSR && x.layout() != phi::DataLayout::SPARSE_COO) { return x; diff --git a/paddle/phi/api/include/sparse_api.h b/paddle/phi/api/lib/sparse_api_custom_impl.h similarity index 74% rename from paddle/phi/api/include/sparse_api.h rename to paddle/phi/api/lib/sparse_api_custom_impl.h index a131804cd6f582c01586671a21851066910b21d4..293b2cfa3d33480ccccd0f601f8e15c639b93e1e 100644 --- a/paddle/phi/api/include/sparse_api.h +++ b/paddle/phi/api/lib/sparse_api_custom_impl.h @@ -21,13 +21,13 @@ namespace paddle { namespace experimental { namespace sparse { -PADDLE_API Tensor to_sparse_coo(const Tensor& x, - Backend backend, - const int64_t sparse_dim); +Tensor to_dense_impl(const Tensor& x, Backend backend); -PADDLE_API Tensor to_sparse_csr(const Tensor& x, Backend backend); +Tensor to_sparse_coo_impl(const Tensor& x, + Backend backend, + const int64_t sparse_dim); -PADDLE_API Tensor to_dense(const Tensor& x, Backend backend); +Tensor to_sparse_csr_impl(const Tensor& x, Backend backend); } // namespace sparse } // namespace experimental diff --git a/paddle/phi/kernels/sparse/cpu/convolution.h b/paddle/phi/kernels/sparse/cpu/convolution.h index ab2fef5320f716b6bc780ad14b8e2adef44427dd..1031f76917920adba26ec75d166f18d85435be70 100644 --- a/paddle/phi/kernels/sparse/cpu/convolution.h +++ b/paddle/phi/kernels/sparse/cpu/convolution.h @@ -107,7 +107,9 @@ void ProductRuleBook(const Context& dev_ctx, f_calc_rulebook(nullptr); // alloc the rulebook - rulebook->ResizeAndAllocate({3, rulebook_len}); + DenseTensorMeta rulebook_meta( + DataType::INT32, {3, rulebook_len}, DataLayout::NCHW); + rulebook->set_meta(rulebook_meta); dev_ctx.Alloc(rulebook, rulebook->dtype(), rulebook->numel() * sizeof(int)); int* rulebook_ptr = rulebook->data(); 
f_calc_rulebook(rulebook_ptr); diff --git a/paddle/phi/tests/api/CMakeLists.txt b/paddle/phi/tests/api/CMakeLists.txt index cde085423e482e62a280815700ead9a0b6c64262..be12960d1d675b46987996713a0631399d1f0652 100644 --- a/paddle/phi/tests/api/CMakeLists.txt +++ b/paddle/phi/tests/api/CMakeLists.txt @@ -25,3 +25,4 @@ cc_test(test_concat_api SRCS test_concat_api.cc DEPS phi_tensor phi_api phi_api_ cc_test(test_split_api SRCS test_split_api.cc DEPS phi_tensor phi_api phi_api_utils) cc_test(test_data_transform SRCS test_data_transform.cc DEPS phi_tensor phi_api phi_api_utils) cc_test(test_sparse_utils_api SRCS test_sparse_utils_api.cc DEPS phi_tensor phi_api phi_api_utils) +cc_test(test_sparse_conv_api SRCS test_sparse_conv_api.cc DEPS phi_tensor phi_api phi_api_utils) diff --git a/paddle/phi/tests/api/test_sparse_conv_api.cc b/paddle/phi/tests/api/test_sparse_conv_api.cc new file mode 100644 index 0000000000000000000000000000000000000000..16d7cb66f4cc5f14abf31bb0a16d58c266bc15fb --- /dev/null +++ b/paddle/phi/tests/api/test_sparse_conv_api.cc @@ -0,0 +1,174 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See +the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include + +#include "paddle/phi/api/include/api.h" + +#include "paddle/phi/api/include/sparse_api.h" + +#include "paddle/phi/api/lib/utils/allocator.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/sparse_coo_tensor.h" + +template +void TestConv3dBase(const std::vector& indices, + const std::vector& features, + const phi::DDim& x_dims, + const std::vector& kernel, + const phi::DDim& kernel_dims, + const std::vector& correct_out_indices, + const std::vector& correct_out_features, + const phi::DDim& correct_out_dims, + const int non_zero_num, + const std::vector& paddings, + const std::vector& strides, + const std::vector& dilations, + const float diff = 1e-3) { + const auto alloc = std::make_unique( + paddle::platform::CPUPlace()); + + const int in_channels = kernel_dims[3]; + const int out_channels = kernel_dims[4]; + + phi::DenseTensor indices_tensor( + alloc.get(), + phi::DenseTensorMeta( + phi::DataType::INT32, {4, non_zero_num}, phi::DataLayout::NCHW)); + memcpy( + indices_tensor.data(), indices.data(), indices.size() * sizeof(int)); + + phi::DenseTensor features_tensor( + alloc.get(), + phi::DenseTensorMeta(paddle::experimental::CppTypeToDataType::Type(), + {non_zero_num, in_channels}, + phi::DataLayout::NHWC)); + memcpy( + features_tensor.data(), features.data(), features.size() * sizeof(T)); + + auto x_tensor = std::make_shared( + indices_tensor, features_tensor, x_dims); + paddle::experimental::Tensor x(x_tensor); + + auto kernel_tensor = std::make_shared( + alloc.get(), + phi::DenseTensorMeta(paddle::experimental::CppTypeToDataType::Type(), + kernel_dims, + phi::DataLayout::NHWC)); + paddle::experimental::Tensor weight(kernel_tensor); + + memcpy(kernel_tensor->mutable_data(paddle::platform::CPUPlace()), + kernel.data(), + kernel.size() * sizeof(T)); + + if 
(!std::is_same::value) { + auto outs = paddle::experimental::sparse::conv3d( + x, weight, paddings, dilations, strides, 1); + + auto out = std::dynamic_pointer_cast( + std::get<0>(outs).impl()); + ASSERT_EQ(correct_out_dims.size(), out->dims().size()); + for (int i = 0; i < correct_out_dims.size(); i++) { + ASSERT_EQ(correct_out_dims[i], out->dims()[i]); + } + ASSERT_EQ((int64_t)correct_out_features.size() / out_channels, out->nnz()); + + int cmp_indices = memcmp(correct_out_indices.data(), + out->non_zero_indices().data(), + correct_out_indices.size() * sizeof(int)); + ASSERT_EQ(cmp_indices, 0); + + for (uint64_t i = 0; i < correct_out_features.size(); i++) { + float tmp = std::fabs(static_cast( + correct_out_features[i] - out->non_zero_elements().data()[i])); + ASSERT_LT(tmp, diff); + } + } +} + +void TestConv3d(const std::vector& indices, + const std::vector& features, + const phi::DDim& x_dims, + const std::vector& kernel, + const phi::DDim& kernel_dims, + const std::vector& correct_out_indices, + const std::vector& correct_out_features, + const phi::DDim& correct_out_dims, + const int non_zero_num, + const std::vector& paddings, + const std::vector& strides, + const std::vector& dilations) { + // test float + TestConv3dBase(indices, + features, + x_dims, + kernel, + kernel_dims, + correct_out_indices, + correct_out_features, + correct_out_dims, + non_zero_num, + paddings, + strides, + dilations); +} + +TEST(API, sparse_conv2d) { + const auto alloc = std::make_shared( + paddle::platform::CPUPlace()); + const int in_channels = 1; + const int out_channels = 1; + phi::DDim x_dims = {1, 1, 5, 5, in_channels}; + phi::DDim kernel_dims = {1, 3, 3, in_channels, out_channels}; + phi::DDim out_dims = {1, 1, 3, 3, out_channels}; + std::vector paddings = {0, 0, 0}; + std::vector strides = {1, 1, 1}; + std::vector dilations = {1, 1, 1}; + + const int non_zero_num = 3; + std::vector indices_flatten = {0, 0, 0, 0, 0, 0, 0, 4, 0, 3, 2, 4}; + + std::vector features = {-0.79394531, -0.3125, -0.55029297}; + // 3*3*3=27 + std::vector kernel = {0.65820312, + 0.75048828, + 0.21411133, + 0.17370605, + 0.85546875, + 0.53076172, + 0.28833008, + 0.71044922, + 0.00659943}; + + std::vector out_indices_flatten = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 2, 2, 2, 1, 2, 0, 1, 2}; + + std::vector out_features = { + -0.17004, -0.71338, -0.00206, -0.22205, -0.09009}; + + TestConv3d(indices_flatten, + features, + x_dims, + kernel, + kernel_dims, + out_indices_flatten, + out_features, + out_dims, + non_zero_num, + paddings, + strides, + dilations); +} diff --git a/python/paddle/utils/code_gen/api_base.py b/python/paddle/utils/code_gen/api_base.py index cfd817c24c7367f69673353a8aaceeedec506e15..6c07cdec2ee19c9e689f354d4a5314049235402c 100644 --- a/python/paddle/utils/code_gen/api_base.py +++ b/python/paddle/utils/code_gen/api_base.py @@ -43,7 +43,9 @@ class BaseAPI(object): self.is_base_api = False self.invoke = api_item_yaml['invoke'] else: - self.infer_meta = self.parse_infer_meta(api_item_yaml['infer_meta']) + if 'infer_meta' in api_item_yaml: + self.infer_meta = self.parse_infer_meta(api_item_yaml[ + 'infer_meta']) self.kernel = self.parse_kernel(api_item_yaml['kernel']) self.support_selected_rows_kernel = False if len(self.kernel[ 'func']) == 1 else True @@ -182,9 +184,9 @@ class BaseAPI(object): 'Tensor': 'Tensor', 'Tensor[]': 'std::vector' } - if re.search(r'\(\w*\)', output_item): + if re.search(r'\([a-zA-Z0-9_@]*\)', output_item): result = re.search( - r"(?P[a-zA-Z0-9_[\]]+)\s*\((?P\w+)\)", + 
r"(?P[a-zA-Z0-9_[\]]+)\s*\((?P[a-zA-Z0-9_@]+)\)", output_item) out_type = result.group('out_type') assert out_type in output_type_map, \ @@ -499,11 +501,8 @@ PADDLE_API {self.outputs['return_type']} {self.get_api_func_name() + '_'}({self. def get_kernel_args(self, code_indent): input_trans_map = { 'const Tensor&': 'const phi::DenseTensor&', - 'const Tensor &': 'const phi::DenseTensor&', 'const std::vector&': 'const std::vector&', - 'const std::vector &': - 'const std::vector&', 'const paddle::optional&': 'paddle::optional', 'const paddle::optional>&': @@ -592,7 +591,6 @@ PADDLE_API {self.outputs['return_type']} {self.get_api_func_name() + '_'}({self. def get_selected_rows_kernel_args(self, code_indent): input_trans_map = { 'const Tensor&': 'const phi::SelectedRows&', - 'const Tensor &': 'const phi::SelectedRows&', 'const paddle::optional&': 'paddle::optional' } diff --git a/python/paddle/utils/code_gen/api_gen.py b/python/paddle/utils/code_gen/api_gen.py index a26630ad04100fbebdb7c270b83912bb722040d4..1bdfa8b66972eb0d4ff45509ada066ce92ae5f78 100644 --- a/python/paddle/utils/code_gen/api_gen.py +++ b/python/paddle/utils/code_gen/api_gen.py @@ -105,7 +105,7 @@ def source_include(header_file_path): #include "paddle/phi/api/lib/api_custom_impl.h" #include "paddle/phi/api/lib/api_registry.h" -#include "paddle/phi/api/lib/api_utils.h" +#include "paddle/phi/api/lib/api_gen_utils.h" #include "paddle/phi/api/lib/data_transform.h" #include "paddle/phi/api/lib/kernel_dispatch.h" #include "paddle/phi/api/lib/utils/storage.h" diff --git a/python/paddle/utils/code_gen/backward_api_gen.py b/python/paddle/utils/code_gen/backward_api_gen.py index 125ebed82de8b25b0a2c20ca7b76560966313566..b9f991f9b0f88daa3ae07cba33b439c073d8fbe0 100644 --- a/python/paddle/utils/code_gen/backward_api_gen.py +++ b/python/paddle/utils/code_gen/backward_api_gen.py @@ -146,7 +146,7 @@ def source_include(header_file_path): #include "paddle/phi/api/lib/api_custom_impl.h" #include "paddle/phi/api/lib/api_registry.h" -#include "paddle/phi/api/lib/api_utils.h" +#include "paddle/phi/api/lib/api_gen_utils.h" #include "paddle/phi/api/lib/data_transform.h" #include "paddle/phi/api/lib/kernel_dispatch.h" #include "paddle/phi/api/lib/utils/storage.h" diff --git a/python/paddle/utils/code_gen/sparse_api.yaml b/python/paddle/utils/code_gen/sparse_api.yaml new file mode 100644 index 0000000000000000000000000000000000000000..135989121cca695b0e629192774af0eb3e41c812 --- /dev/null +++ b/python/paddle/utils/code_gen/sparse_api.yaml @@ -0,0 +1,21 @@ +- sparse_api : conv3d + args : (Tensor x, Tensor kernel, int[] paddings, int[] dilations, int[] strides, int groups) + output : Tensor(out@SparseCooTensor), Tensor(rulebook@DenseTensor) + kernel : + func : sparse_conv3d + layout : x + +- sparse_api : to_dense + args : (Tensor x, Backend backend) + output : Tensor(out@DenseTensor) + invoke : to_dense_impl(x, backend) + +- sparse_api : to_sparse_coo + args : (Tensor x, Backend backend, int64_t sparse_dim) + output : Tensor(out@SparseCooTensor) + invoke : to_sparse_coo_impl(x, backend, sparse_dim) + +- sparse_api : to_sparse_csr + args : (Tensor x, Backend backend) + output : Tensor(out@SparseCsrTensor) + invoke : to_sparse_csr_impl(x, backend) diff --git a/python/paddle/utils/code_gen/sparse_api_gen.py b/python/paddle/utils/code_gen/sparse_api_gen.py new file mode 100644 index 0000000000000000000000000000000000000000..99c5a4f49f8c41920135953ca02a17148164eb45 --- /dev/null +++ b/python/paddle/utils/code_gen/sparse_api_gen.py @@ -0,0 +1,282 @@ +# 
Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import yaml +import argparse +import re + +from api_base import BaseAPI + + +class SparseAPI(BaseAPI): + def __init__(self, api_item_yaml): + super(SparseAPI, self).__init__(api_item_yaml) + + def get_api_name(self, api_item_yaml): + return api_item_yaml['sparse_api'] + + def get_api_func_name(self): + return self.api + + def get_return_type(self, out_type_list): + return out_type_list[0] if len( + out_type_list) == 1 else "std::tuple<" + ",".join( + out_type_list) + ">" + + def gene_api_declaration(self): + return f""" +// {", ".join(self.outputs['names'])} +PADDLE_API {self.outputs['return_type']} {self.get_api_func_name()}({self.args_str['args_declare']}); +""" + + def get_kernel_tensor_out_type(self, output_name): + sparse_type = 'TensorType::DENSE_TENSOR' + if output_name.endswith('@SparseCooTensor'): + sparse_type = 'TensorType::SPARSE_COO' + elif output_name.endswith('@SparseCsrTensor'): + sparse_type = 'TensorType::SPARSE_CSR' + return sparse_type + + def gene_output(self, + output_type_list, + set_out_func, + code_indent, + inplace_flag=False): + kernel_output = "" + output_names = [] + output_create = "" + + if len(output_type_list) == 1: + kernel_output = 'kernel_out' + output_names.append('kernel_out') + inplace_assign = " = " + self.inplace_map[self.outputs['names'][ + 0]] if inplace_flag and self.inplace_map is not None and self.outputs[ + 'names'][0] in self.inplace_map else "" + output_create = f""" + {self.outputs['return_type']} out{inplace_assign}; + auto* kernel_out = {set_out_func}(&out, {self.get_kernel_tensor_out_type(self.outputs['names'][0])});""" + + elif len(output_type_list) > 1: + output_create = f""" + {self.outputs['return_type']} out;""" + + for i in range(len(output_type_list)): + kernel_output = kernel_output + f'kernel_out_{i}, ' + output_names.append(f'kernel_out_{i}') + if inplace_flag and self.inplace_map is not None and self.outputs[ + 'names'][i] in self.inplace_map: + output_create = output_create + f""" + std::get<{i}>(out) = {self.inplace_map[self.outputs['names'][i]]};""" + + output_create = output_create + f""" + auto* kernel_out_{i} = {set_out_func}(&std::get<{i}>(out), {self.get_kernel_tensor_out_type(self.outputs['names'][i])});""" + + kernel_output = kernel_output[:-2] + else: + raise ValueError( + "{} : Output error: the output should not be empty.".format( + self.api)) + + return kernel_output, output_names, output_create + + def gen_sparse_kernel_context(self, kernel_output_names): + input_trans_map = { + 'const Tensor&': 'const phi::TensorBase&', + 'const std::vector&': 'const std::vector&', + 'const paddle::optional&': + 'paddle::optional' + } + out_trans_map = { + 'Tensor': 'phi::TensorBase*', + 'std::vector': 'std::vector' + } + input_names = self.inputs['names'] + input_infos = self.inputs['input_info'] + + attr_names = self.attrs['names'] + kernel_param = self.kernel['param'] + if kernel_param is None: + kernel_param
= input_names + attr_names + + kernel_context_code = "" + for param in kernel_param: + if param in input_names: + if param in self.optional_vars: + raise ValueError( + f"{self.api} : Unsupport optional input({param}) for sparse api." + ) + else: + kernel_context_code = kernel_context_code + f""" + kernel_context.EmplaceBackInput({param}.impl().get());""" + + continue + if param in attr_names: + # set attr for kernel_context + if 'ScalarArray' in self.attrs['attr_info'][param][0]: + param = 'phi::ScalarArray(' + param + ')' + elif 'Scalar' in self.attrs['attr_info'][param][0]: + param = 'phi::Scalar(' + param + ')' + elif isinstance(param, bool): + param = str(param).lower() + else: + param = str(param) + kernel_context_code = kernel_context_code + f""" + kernel_context.EmplaceBackAttr({param});""" + + for out_name in kernel_output_names: + kernel_context_code = kernel_context_code + f""" + kernel_context.EmplaceBackOutput({out_name});""" + + return kernel_context_code + + def gen_sparse_kernel_code(self, inplace_flag=False): + _, kernel_output_names, output_create = self.gene_output( + self.outputs['types'], 'SetSparseKernelOutput', '', inplace_flag) + + kernel_context_code = self.gen_sparse_kernel_context( + kernel_output_names) + + return f""" + auto phi_kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError( + "{self.kernel['func'][0]}", {{kernel_backend, kernel_layout, kernel_data_type}}); + VLOG(6) << "{self.api} api sparse kernel key: [" << kernel_backend << ", " << kernel_layout << ", "<< kernel_data_type << "]"; + VLOG(6) << "{self.api} api sparse kernel: " << phi_kernel; + + auto* dev_ctx = GetDeviceContextByBackend(kernel_backend); + auto kernel_context = phi::KernelContext(dev_ctx); +{output_create} +{kernel_context_code} + phi_kernel(&kernel_context); + + return out;""" + + def gene_base_api_code(self, inplace_flag=False): + api_func_name = self.get_api_func_name() + return f""" +PADDLE_API {self.outputs['return_type']} {api_func_name}({self.args_str["args_define"]}) {{ +{self.gene_kernel_select()} +{self.gen_sparse_kernel_code(inplace_flag)} +}} +""" + + +def header_include(): + return """ +#include + +#include "paddle/phi/api/include/tensor.h" +#include "paddle/phi/common/scalar.h" +#include "paddle/phi/common/scalar_array.h" +#include "paddle/utils/optional.h" +""" + + +def source_include(header_file_path): + return f""" +#include "{header_file_path}" +#include + +#include "glog/logging.h" + +#include "paddle/phi/api/lib/api_registry.h" +#include "paddle/phi/api/lib/api_gen_utils.h" +#include "paddle/phi/api/lib/data_transform.h" +#include "paddle/phi/api/lib/kernel_dispatch.h" +#include "paddle/phi/api/lib/sparse_api_custom_impl.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/declarations.h" +""" + + +def api_register(): + return """ +PD_REGISTER_API(Test); +""" + + +def api_namespace(): + return (""" +namespace paddle { +namespace experimental { +namespace sparse { + +""", """ + +} // namespace sparse +} // namespace experimental +} // namespace paddle +""") + + +def generate_api(api_yaml_path, header_file_path, source_file_path): + + with open(api_yaml_path, 'r') as f: + apis = yaml.load(f, Loader=yaml.FullLoader) + header_file = open(header_file_path, 'w') + source_file = open(source_file_path, 'w') + + namespace = api_namespace() + + header_file.write("#pragma once\n") + header_file.write(header_include()) + header_file.write(namespace[0]) + + include_header_file = "paddle/phi/api/include/sparse_api.h" +
source_file.write(source_include(include_header_file)) + source_file.write(namespace[0]) + + for api in apis: + sparse_api = SparseAPI(api) + header_file.write(sparse_api.gene_api_declaration()) + source_file.write(sparse_api.gene_api_code()) + + header_file.write(namespace[1]) + source_file.write(namespace[1]) + + source_file.write(api_register()) + + header_file.close() + source_file.close() + + +def main(): + parser = argparse.ArgumentParser( + description='Generate PaddlePaddle C++ Sparse API files') + parser.add_argument( + '--api_yaml_path', + help='path to sparse api yaml file', + default='python/paddle/utils/code_gen/sparse_api.yaml') + + parser.add_argument( + '--api_header_path', + help='output of generated api header code file', + default='paddle/phi/api/include/sparse_api.h') + + parser.add_argument( + '--api_source_path', + help='output of generated api source code file', + default='paddle/phi/api/lib/sparse_api.cc') + + options = parser.parse_args() + + api_yaml_path = options.api_yaml_path + header_file_path = options.api_header_path + source_file_path = options.api_source_path + + generate_api(api_yaml_path, header_file_path, source_file_path) + + +if __name__ == '__main__': + main()