diff --git a/paddle/phi/kernels/sparse/CMakeLists.txt b/paddle/phi/kernels/sparse/CMakeLists.txt index eaea6d952167c149f8add498768b26dd0d54f16a..479d53042949861a4679715a3a9a6250d03beda0 100644 --- a/paddle/phi/kernels/sparse/CMakeLists.txt +++ b/paddle/phi/kernels/sparse/CMakeLists.txt @@ -1,3 +1,3 @@ -set(SPARSE_KERNEL_DEPS dense_tensor sparse_coo_tensor sparse_csr_tensor kernel_context kernel_factory arg_map_context convert_utils lod_utils math_function custom_kernel) -register_kernels(DEPS ${SPARSE_KERNEL_DEPS} SUB_DIR "sparse_kernel") +set(SPARSE_KERNEL_DEPS dense_tensor sparse_coo_tensor sparse_csr_tensor kernel_context kernel_factory arg_map_context convert_utils lod_utils math_function custom_kernel copy_kernel) +register_kernels(DEPS ${SPARSE_KERNEL_DEPS} SUB_DIR "sparse") diff --git a/paddle/phi/kernels/sparse/sparse_activation_grad_kernel.cc b/paddle/phi/kernels/sparse/sparse_activation_grad_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..ddfd51bc79b5d9e58034f865a772c4887cf8b6f4 --- /dev/null +++ b/paddle/phi/kernels/sparse/sparse_activation_grad_kernel.cc @@ -0,0 +1,70 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/kernels/sparse/sparse_activation_grad_kernel.h" +#include "paddle/phi/kernels/activation_grad_kernel.h" +#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/kernels/empty_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { +namespace sparse { + +template +void SparseReluGradKernel(const Context& dev_ctx, + const SparseCooTensor& x, + const SparseCooTensor& out_grad, + SparseCooTensor* x_grad) { + DenseTensor non_zero_indices = + phi::EmptyLike(dev_ctx, x.non_zero_indices()); + DenseTensor non_zero_elements = + phi::EmptyLike(dev_ctx, x.non_zero_elements()); + phi::Copy(dev_ctx, + x.non_zero_indices(), + dev_ctx.GetPlace(), + false, + &non_zero_indices); + phi::ReluGradKernel(dev_ctx, + x.non_zero_elements(), + out_grad.non_zero_elements(), + &non_zero_elements); + x_grad->SetMember(non_zero_indices, non_zero_elements, x.dims(), true); +} + +} // namespace sparse +} // namespace phi + +PD_REGISTER_KERNEL(sparse_relu_grad, + CPU, + ALL_LAYOUT, + phi::sparse::SparseReluGradKernel, + float, + double) { + kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO); +} + +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +PD_REGISTER_KERNEL(sparse_relu_grad, + GPU, + ALL_LAYOUT, + phi::sparse::SparseReluGradKernel, + float, + double, + phi::dtype::float16) { + kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO); +} +#endif diff --git a/paddle/phi/kernels/sparse/sparse_activation_grad_kernel.h b/paddle/phi/kernels/sparse/sparse_activation_grad_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..aab4a3e5a590b5cc14883b9c6619e6150687ab4b --- /dev/null +++ b/paddle/phi/kernels/sparse/sparse_activation_grad_kernel.h @@ -0,0 +1,29 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/phi/core/sparse_coo_tensor.h" + +namespace phi { +namespace sparse { + +template +void SparseReluGradKernel(const Context& dev_ctx, + const SparseCooTensor& x, + const SparseCooTensor& out_grad, + SparseCooTensor* x_grad); + +} // namespace sparse +} // namespace phi diff --git a/paddle/phi/kernels/sparse/sparse_activation_kernel.cc b/paddle/phi/kernels/sparse/sparse_activation_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..5f6bcda8262fa6e09190243836c6dbc9d30716fa --- /dev/null +++ b/paddle/phi/kernels/sparse/sparse_activation_kernel.cc @@ -0,0 +1,66 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/kernels/sparse/sparse_activation_kernel.h" +#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/kernels/empty_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { +namespace sparse { + +template +void SparseReluKernel(const Context& dev_ctx, + const SparseCooTensor& x, + SparseCooTensor* out) { + DenseTensor non_zero_indices = + phi::EmptyLike(dev_ctx, x.non_zero_indices()); + DenseTensor non_zero_elements = + phi::EmptyLike(dev_ctx, x.non_zero_elements()); + phi::Copy(dev_ctx, + x.non_zero_indices(), + dev_ctx.GetPlace(), + false, + &non_zero_indices); + phi::ReluKernel( + dev_ctx, x.non_zero_elements(), &non_zero_elements); + out->SetMember(non_zero_indices, non_zero_elements, x.dims(), true); +} + +} // namespace sparse +} // namespace phi + +PD_REGISTER_KERNEL(sparse_relu, + CPU, + ALL_LAYOUT, + phi::sparse::SparseReluKernel, + float, + double) { + kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO); +} + +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +PD_REGISTER_KERNEL(sparse_relu, + GPU, + ALL_LAYOUT, + phi::sparse::SparseReluKernel, + float, + double, + phi::dtype::float16) { + kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO); +} +#endif diff --git a/paddle/phi/kernels/sparse/sparse_activation_kernel.h b/paddle/phi/kernels/sparse/sparse_activation_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..568c0aa8b2ecb35b54cddbe67894d20efca9a348 --- /dev/null +++ b/paddle/phi/kernels/sparse/sparse_activation_kernel.h @@ -0,0 +1,39 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/sparse_coo_tensor.h" +#include "paddle/phi/kernels/activation_kernel.h" +#include "paddle/phi/kernels/empty_kernel.h" + +namespace phi { +namespace sparse { + +template +void SparseReluKernel(const Context& dev_ctx, + const SparseCooTensor& x, + SparseCooTensor* out); + +template +SparseCooTensor SparseRelu(const Context& dev_ctx, const SparseCooTensor& x) { + DenseTensor indices, values; + SparseCooTensor coo(indices, values, x.dims()); + SparseReluKernel(dev_ctx, x, &coo); + return coo; +} + +} // namespace sparse +} // namespace phi diff --git a/paddle/phi/tests/kernels/CMakeLists.txt b/paddle/phi/tests/kernels/CMakeLists.txt index f845cb419bfa07b53ebb7ab91f50e63103cad001..a02e4f3d57aa3badbcf57e5c61e81c25cc122e46 100644 --- a/paddle/phi/tests/kernels/CMakeLists.txt +++ b/paddle/phi/tests/kernels/CMakeLists.txt @@ -15,6 +15,7 @@ cc_test(test_split_dev_api SRCS test_split_dev_api.cc DEPS phi phi_api_utils) cc_test(test_sparse_utils_dev_api SRCS test_sparse_utils_dev_api.cc DEPS phi phi_api_utils) cc_test(test_sparse_conv3d_dev_api SRCS test_sparse_conv3d_dev_api.cc DEPS phi phi_api_utils) cc_test(test_sparse_pool_dev_api SRCS test_sparse_pool_dev_api.cc DEPS phi phi_api_utils) +cc_test(test_sparse_activation_dev_api SRCS test_sparse_activation_dev_api.cc DEPS phi phi_api_utils) cc_test(test_math_function SRCS test_math_function.cc DEPS math_function) if(WITH_GPU) diff --git a/paddle/phi/tests/kernels/test_sparse_activation_dev_api.cc b/paddle/phi/tests/kernels/test_sparse_activation_dev_api.cc new file mode 100644 index 0000000000000000000000000000000000000000..ee4d021d397b0fae33ea227574cb5657fdd6bf58 --- /dev/null +++ b/paddle/phi/tests/kernels/test_sparse_activation_dev_api.cc @@ -0,0 +1,83 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/place.h" + +#include "paddle/fluid/memory/allocation/allocator_facade.h" +#include "paddle/phi/api/lib/utils/allocator.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/activation_grad_kernel.h" +#include "paddle/phi/kernels/activation_kernel.h" +#include "paddle/phi/kernels/empty_kernel.h" +#include "paddle/phi/kernels/sparse/sparse_activation_grad_kernel.h" +#include "paddle/phi/kernels/sparse/sparse_activation_kernel.h" +#include "paddle/phi/kernels/sparse/sparse_utils_kernel.h" + +namespace phi { +namespace tests { + +TEST(DEV_API, sparse_relu) { + std::vector data = {0, -1, 0, 2, 0, 0, -3, 0, 4, 5, 0, 0}; + phi::CPUContext dev_ctx_cpu; + dev_ctx_cpu.SetAllocator( + paddle::memory::allocation::AllocatorFacade::Instance() + .GetAllocator(paddle::platform::CPUPlace()) + .get()); + dev_ctx_cpu.SetHostAllocator( + paddle::memory::allocation::AllocatorFacade::Instance() + .GetAllocator(paddle::platform::CPUPlace()) + .get()); + dev_ctx_cpu.Init(); + + DenseTensor dense_x = + phi::Empty(dev_ctx_cpu, + DenseTensorMeta(DataType::FLOAT32, {3, 4}, DataLayout::NCHW)); + memcpy(dense_x.data(), data.data(), data.size() * sizeof(float)); + auto sparse_coo = sparse::DenseToSparseCoo(dev_ctx_cpu, dense_x, 2); + + auto sparse_out = sparse::SparseRelu(dev_ctx_cpu, sparse_coo); + DenseTensor dense_out = + phi::EmptyLike(dev_ctx_cpu, sparse_out.non_zero_elements()); + ReluKernel(dev_ctx_cpu, sparse_coo.non_zero_elements(), &dense_out); + + int cmp = memcmp(dense_out.data(), + sparse_out.non_zero_elements().data(), + dense_out.numel() * sizeof(float)); + ASSERT_EQ(cmp, 0); + // backward + DenseTensor dense_grad_x = phi::EmptyLike(dev_ctx_cpu, dense_out); + ReluGradKernel( + dev_ctx_cpu, sparse_coo.non_zero_elements(), dense_out, &dense_grad_x); + SparseCooTensor sparse_grad_x( + phi::EmptyLike(dev_ctx_cpu, sparse_coo.non_zero_indices()), + phi::EmptyLike(dev_ctx_cpu, sparse_coo.non_zero_elements()), + {3, 4}); + + SparseCooTensor sparse_out_grad( + sparse_coo.non_zero_indices(), dense_out, {3, 4}); + sparse::SparseReluGradKernel( + dev_ctx_cpu, sparse_coo, sparse_out_grad, &sparse_grad_x); + + cmp = memcmp(dense_grad_x.data(), + sparse_grad_x.non_zero_elements().data(), + dense_grad_x.numel() * sizeof(float)); + ASSERT_EQ(cmp, 0); +} + +} // namespace tests +} // namespace phi diff --git a/python/paddle/fluid/tests/unittests/test_sparse_activation_op.py b/python/paddle/fluid/tests/unittests/test_sparse_activation_op.py new file mode 100644 index 0000000000000000000000000000000000000000..df13ae4e4b7fffe7945c434dd78a0b3cc5fdf42e --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_sparse_activation_op.py @@ -0,0 +1,40 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +import unittest +import numpy as np +import paddle +from paddle import _C_ops +from paddle.fluid.framework import _test_eager_guard + + +class TestSparseActivation(unittest.TestCase): + def test_sparse_relu(self): + with _test_eager_guard(): + x = [[0, -1, 0, 2], [0, 0, -3, 0], [4, 5, 0, 0]] + dense_x = paddle.to_tensor(x, dtype='float32') + dense_shape = [3, 4] + stop_gradient = True + sparse_dim = 2 + sparse_coo_x = dense_x.to_sparse_coo(sparse_dim) + #TODO(zhangkaihuo): change to test the corresponding API: paddle.sparse.relu(sparse_coo_x) + sparse_act_out = _C_ops.final_state_sparse_relu(sparse_coo_x) + correct_result = [0, 2, 0, 4, 5] + actual_result = sparse_act_out.non_zero_elements().numpy() + assert np.array_equal(correct_result, actual_result) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/tensor/to_string.py b/python/paddle/tensor/to_string.py index 91e5cfe97c6cdb503fba343a8a8d16a956aaffaf..f164bbc466f18da9b7145533c32369a85d6124df 100644 --- a/python/paddle/tensor/to_string.py +++ b/python/paddle/tensor/to_string.py @@ -317,7 +317,7 @@ def tensor_to_string(tensor, prefix='Tensor'): _template = "{prefix}(shape={shape}, dtype={dtype}, place={place}, stop_gradient={stop_gradient},\n{indent}{data})" - if not tensor._is_dense_tensor_hold_allocation(): + if not tensor._is_initialized(): return "Tensor(Not initialized)" if tensor.is_sparse(): diff --git a/python/paddle/utils/code_gen/sparse_api.yaml b/python/paddle/utils/code_gen/sparse_api.yaml index 293fdc1528a122a9ce2c446bb8b1fa917e47c637..24e965f85c53ff32599ba28739e6ec9672e1bce8 100644 --- a/python/paddle/utils/code_gen/sparse_api.yaml +++ b/python/paddle/utils/code_gen/sparse_api.yaml @@ -21,3 +21,11 @@ args : (Tensor x) output : Tensor(out@SparseCsrTensor) invoke : to_sparse_csr_impl(x) + +- api : relu + args : (Tensor x) + output : Tensor(out@SparseCooTensor) + kernel : + func : sparse_relu + layout : x + backward : sparse_relu_grad diff --git a/python/paddle/utils/code_gen/sparse_bw_api.yaml b/python/paddle/utils/code_gen/sparse_bw_api.yaml index 6532f103cbf86288ffc739656440dc378d48eb2d..711b4cedc59a586cdf5bf9e21f90e531cc6cbbc6 100644 --- a/python/paddle/utils/code_gen/sparse_bw_api.yaml +++ b/python/paddle/utils/code_gen/sparse_bw_api.yaml @@ -4,3 +4,10 @@ output : Tensor(x_grad@DenseTensor), Tensor(kernel_grad@DenseTensor) kernel : func : sparse_conv3d_grad + +- backward_api : sparse_relu_grad + forward : sparse_relu(Tensor x) -> Tensor(out@SparseCooTensor) + args : (Tensor x, Tensor out_grad) + output : Tensor(x_grad@SparseCooTensor) + kernel : + func : sparse_relu_grad