未验证 提交 3328eb03 编写于 作者: C Chen Weihang 提交者: GitHub

[PTen] Add slice api implemention for Tensor (#37276)

* add slice api impl of Tensor

* fix test slice error
上级 ca8c4f3e
...@@ -22,6 +22,7 @@ limitations under the License. */ ...@@ -22,6 +22,7 @@ limitations under the License. */
#include "paddle/pten/api/lib/ext_compat_utils.h" #include "paddle/pten/api/lib/ext_compat_utils.h"
#include "paddle/pten/api/lib/utils/allocator.h" #include "paddle/pten/api/lib/utils/allocator.h"
#include "paddle/pten/api/lib/utils/storage.h" #include "paddle/pten/api/lib/utils/storage.h"
#include "paddle/pten/core/compat_utils.h"
#include "paddle/pten/core/dense_tensor.h" #include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/tensor_base.h" #include "paddle/pten/core/tensor_base.h"
#include "paddle/pten/core/tensor_meta.h" #include "paddle/pten/core/tensor_meta.h"
...@@ -236,11 +237,18 @@ template PD_DLL_DECL paddle::platform::complex<double> ...@@ -236,11 +237,18 @@ template PD_DLL_DECL paddle::platform::complex<double>
template PD_DLL_DECL paddle::platform::float16 * template PD_DLL_DECL paddle::platform::float16 *
Tensor::data<paddle::platform::float16>(); Tensor::data<paddle::platform::float16>();
// TODO(chenweihang): replace slice impl by API
Tensor Tensor::slice(const int64_t begin_idx, const int64_t end_idx) const { Tensor Tensor::slice(const int64_t begin_idx, const int64_t end_idx) const {
if (detail::IsDenseTensor(impl_)) {
return Tensor(std::make_shared<pten::DenseTensor>(
std::move(pten::CompatibleDenseTensorUtils::Slice(
std::dynamic_pointer_cast<pten::DenseTensor>(impl_).get(),
begin_idx,
end_idx))));
} else {
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(platform::errors::Unimplemented(
"The slice operation is not supported now, " "Only supported slice operation on DenseTensor now."));
"and it will be implemented by calling the slice kernel later.")); }
return Tensor();
} }
std::shared_ptr<pten::TensorBase> Tensor::impl() const { return impl_; } std::shared_ptr<pten::TensorBase> Tensor::impl() const { return impl_; }
......
...@@ -15,3 +15,4 @@ cc_test(test_fill_api SRCS test_fill_api.cc DEPS pten_tensor pten_api pten_api_u ...@@ -15,3 +15,4 @@ cc_test(test_fill_api SRCS test_fill_api.cc DEPS pten_tensor pten_api pten_api_u
cc_test(test_flatten_api SRCS test_flatten_api.cc DEPS pten_tensor pten_api pten_api_utils) cc_test(test_flatten_api SRCS test_flatten_api.cc DEPS pten_tensor pten_api pten_api_utils)
cc_test(test_elementwise_api SRCS test_elementwise_api.cc DEPS pten_tensor pten_api pten_api_utils) cc_test(test_elementwise_api SRCS test_elementwise_api.cc DEPS pten_tensor pten_api pten_api_utils)
cc_test(test_reshape_api SRCS test_reshape_api.cc DEPS pten_tensor pten_api pten_api_utils) cc_test(test_reshape_api SRCS test_reshape_api.cc DEPS pten_tensor pten_api pten_api_utils)
cc_test(test_slice_api SRCS test_slice_api.cc DEPS pten_tensor pten_api pten_api_utils)
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include <memory>
#include "paddle/pten/api/include/creation.h"
#include "paddle/pten/api/include/tensor.h"
#include "paddle/pten/core/kernel_registry.h"
PT_DECLARE_MODULE(CreationCPU);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PT_DECLARE_MODULE(CreationCUDA);
#endif
namespace pten {
namespace tests {
TEST(Tensor, slice) {
auto x = paddle::experimental::full({4, 3}, 1, pten::DataType::INT64);
auto slice_x = x.slice(1, 2);
// check slice result
ASSERT_EQ(slice_x.dims().size(), 2);
ASSERT_EQ(slice_x.dims()[0], 1);
ASSERT_EQ(slice_x.dims()[1], 3);
ASSERT_EQ(slice_x.numel(), 3);
ASSERT_EQ(slice_x.is_cpu(), true);
ASSERT_EQ(slice_x.type(), pten::DataType::INT64);
ASSERT_EQ(slice_x.layout(), pten::DataLayout::NCHW);
ASSERT_EQ(slice_x.initialized(), true);
for (int64_t i = 0; i < slice_x.numel(); ++i) {
ASSERT_EQ(slice_x.mutable_data<int64_t>()[i], 1);
}
}
} // namespace tests
} // namespace pten
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册