From 3328eb033e9a66dcc5904f31df5f9f8f0e321a01 Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Wed, 17 Nov 2021 03:32:55 -0600 Subject: [PATCH] [PTen] Add slice api implemention for Tensor (#37276) * add slice api impl of Tensor * fix test slice error --- paddle/pten/api/lib/tensor.cc | 16 ++++++-- paddle/pten/tests/api/CMakeLists.txt | 1 + paddle/pten/tests/api/test_slice_api.cc | 50 +++++++++++++++++++++++++ 3 files changed, 63 insertions(+), 4 deletions(-) create mode 100644 paddle/pten/tests/api/test_slice_api.cc diff --git a/paddle/pten/api/lib/tensor.cc b/paddle/pten/api/lib/tensor.cc index bb3fba88586..db5fa9f671f 100644 --- a/paddle/pten/api/lib/tensor.cc +++ b/paddle/pten/api/lib/tensor.cc @@ -22,6 +22,7 @@ limitations under the License. */ #include "paddle/pten/api/lib/ext_compat_utils.h" #include "paddle/pten/api/lib/utils/allocator.h" #include "paddle/pten/api/lib/utils/storage.h" +#include "paddle/pten/core/compat_utils.h" #include "paddle/pten/core/dense_tensor.h" #include "paddle/pten/core/tensor_base.h" #include "paddle/pten/core/tensor_meta.h" @@ -236,11 +237,18 @@ template PD_DLL_DECL paddle::platform::complex template PD_DLL_DECL paddle::platform::float16 * Tensor::data(); +// TODO(chenweihang): replace slice impl by API Tensor Tensor::slice(const int64_t begin_idx, const int64_t end_idx) const { - PADDLE_THROW(platform::errors::Unimplemented( - "The slice operation is not supported now, " - "and it will be implemented by calling the slice kernel later.")); - return Tensor(); + if (detail::IsDenseTensor(impl_)) { + return Tensor(std::make_shared( + std::move(pten::CompatibleDenseTensorUtils::Slice( + std::dynamic_pointer_cast(impl_).get(), + begin_idx, + end_idx)))); + } else { + PADDLE_THROW(platform::errors::Unimplemented( + "Only supported slice operation on DenseTensor now.")); + } } std::shared_ptr Tensor::impl() const { return impl_; } diff --git a/paddle/pten/tests/api/CMakeLists.txt b/paddle/pten/tests/api/CMakeLists.txt index 5bc5f0ace88..9e688d8200b 100644 --- a/paddle/pten/tests/api/CMakeLists.txt +++ b/paddle/pten/tests/api/CMakeLists.txt @@ -15,3 +15,4 @@ cc_test(test_fill_api SRCS test_fill_api.cc DEPS pten_tensor pten_api pten_api_u cc_test(test_flatten_api SRCS test_flatten_api.cc DEPS pten_tensor pten_api pten_api_utils) cc_test(test_elementwise_api SRCS test_elementwise_api.cc DEPS pten_tensor pten_api pten_api_utils) cc_test(test_reshape_api SRCS test_reshape_api.cc DEPS pten_tensor pten_api pten_api_utils) +cc_test(test_slice_api SRCS test_slice_api.cc DEPS pten_tensor pten_api pten_api_utils) diff --git a/paddle/pten/tests/api/test_slice_api.cc b/paddle/pten/tests/api/test_slice_api.cc new file mode 100644 index 00000000000..eb8be21bcbb --- /dev/null +++ b/paddle/pten/tests/api/test_slice_api.cc @@ -0,0 +1,50 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include + +#include "paddle/pten/api/include/creation.h" +#include "paddle/pten/api/include/tensor.h" +#include "paddle/pten/core/kernel_registry.h" + +PT_DECLARE_MODULE(CreationCPU); + +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +PT_DECLARE_MODULE(CreationCUDA); +#endif + +namespace pten { +namespace tests { + +TEST(Tensor, slice) { + auto x = paddle::experimental::full({4, 3}, 1, pten::DataType::INT64); + auto slice_x = x.slice(1, 2); + + // check slice result + ASSERT_EQ(slice_x.dims().size(), 2); + ASSERT_EQ(slice_x.dims()[0], 1); + ASSERT_EQ(slice_x.dims()[1], 3); + ASSERT_EQ(slice_x.numel(), 3); + ASSERT_EQ(slice_x.is_cpu(), true); + ASSERT_EQ(slice_x.type(), pten::DataType::INT64); + ASSERT_EQ(slice_x.layout(), pten::DataLayout::NCHW); + ASSERT_EQ(slice_x.initialized(), true); + for (int64_t i = 0; i < slice_x.numel(); ++i) { + ASSERT_EQ(slice_x.mutable_data()[i], 1); + } +} + +} // namespace tests +} // namespace pten -- GitLab