From 3f011d82538c9c52bf0c9aee9448bb879d96d642 Mon Sep 17 00:00:00 2001 From: Hao Lin Date: Wed, 11 Aug 2021 15:47:54 +0800 Subject: [PATCH] Add ext_tensor.slice() API (#34227) * Add ext_tensor.slice() API, test=develop * Call Tensor::mutable_data first to fix bugs and add test for writing to sliced tensor * Fix unit test bug * Fix code format problem, test=develop * Fix code format problem * Fix code format problem * strengthen unit test * Use CustomTensorUtils::ShareDataFrom to simplify codes --- paddle/fluid/extension/include/ext_tensor.h | 12 ++++++- paddle/fluid/extension/src/ext_tensor.cc | 25 +++++++++++++ paddle/fluid/framework/custom_tensor_test.cc | 37 ++++++++++++++++++++ 3 files changed, 73 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/extension/include/ext_tensor.h b/paddle/fluid/extension/include/ext_tensor.h index d40503409fb..7d13f56b02b 100644 --- a/paddle/fluid/extension/include/ext_tensor.h +++ b/paddle/fluid/extension/include/ext_tensor.h @@ -88,10 +88,20 @@ class PD_DLL_DECL Tensor { /// It's usually used to set the input tensor data. /// \param PlaceType of target place, of which /// the tensor will copy to. - template Tensor copy_to(const PlaceType& place) const; + /// \brief Return a sub-tensor of the given tensor. + /// It is usually used to extract a sub-tensor (which supports + /// modifying the data of the original tensor) to perform further + /// operations. + /// \param begin_idx The index of the start row (inclusive) to slice. + /// The index number begins from 0. + /// \param end_idx The index of the end row (exclusive) to slice. + /// The index number begins from begin_idx + 1. + /// \return The sliced tensor. + Tensor slice(const int64_t begin_idx, const int64_t end_idx) const; + /// \brief Return the shape of the Tensor. std::vector shape() const; diff --git a/paddle/fluid/extension/src/ext_tensor.cc b/paddle/fluid/extension/src/ext_tensor.cc index a9e286b4f9b..317fb7b2270 100644 --- a/paddle/fluid/extension/src/ext_tensor.cc +++ b/paddle/fluid/extension/src/ext_tensor.cc @@ -124,6 +124,21 @@ void DeviceCopy(T *src, T *dst, PlaceType src_plc, PlaceType dst_plc, } \ auto *tensor = static_cast(tensor_.get()); +#define GET_INNER_PLACE \ + platform::Place place; \ + switch (place_) { \ + case PlaceType::kCPU: \ + place = platform::CPUPlace(); \ + break; \ + case PlaceType::kGPU: \ + place = platform::CUDAPlace(); \ + break; \ + default: \ + PADDLE_THROW(platform::errors::Unavailable( \ + "Custom operator unsupported place id(%d)", \ + static_cast(place_))); \ + } + void Tensor::reshape(const std::vector &shape) { GET_CASTED_TENSOR auto new_dim = framework::make_ddim(shape); @@ -257,6 +272,16 @@ Tensor Tensor::copy_to(const PlaceType &target_place) const { return target; } +Tensor Tensor::slice(const int64_t begin_idx, const int64_t end_idx) const { + GET_CASTED_TENSOR + GET_INNER_PLACE + framework::Tensor intermediate = tensor->Slice(begin_idx, end_idx); + Tensor target = Tensor(place_); + framework::CustomTensorUtils::ShareDataFrom( + static_cast(&intermediate), target); + return target; +} + template PD_DLL_DECL Tensor Tensor::copy_to(const PlaceType &target_place) const; template PD_DLL_DECL Tensor diff --git a/paddle/fluid/framework/custom_tensor_test.cc b/paddle/fluid/framework/custom_tensor_test.cc index b2896c74c39..7fbc4f554ba 100644 --- a/paddle/fluid/framework/custom_tensor_test.cc +++ b/paddle/fluid/framework/custom_tensor_test.cc @@ -92,6 +92,41 @@ void TestAPISizeAndShape() { CHECK(t1.shape() == tensor_shape); } +void TestAPISlice() { + std::vector tensor_shape_origin1 = {5, 5}; + std::vector tensor_shape_sub1 = {3, 5}; + std::vector tensor_shape_origin2 = {5, 5, 5}; + std::vector tensor_shape_sub2 = {1, 5, 5}; +#ifdef PADDLE_WITH_CUDA + auto t1 = paddle::Tensor(paddle::PlaceType::kGPU, tensor_shape_origin1); + t1.mutable_data(); + CHECK(t1.slice(0, 5).shape() == tensor_shape_origin1); + CHECK(t1.slice(0, 3).shape() == tensor_shape_sub1); + auto t2 = paddle::Tensor(paddle::PlaceType::kGPU, tensor_shape_origin2); + t2.mutable_data(); + CHECK(t2.slice(4, 5).shape() == tensor_shape_sub2); +#endif + auto t3 = paddle::Tensor(paddle::PlaceType::kCPU, tensor_shape_origin1); + t3.mutable_data(); + CHECK(t3.slice(0, 5).shape() == tensor_shape_origin1); + CHECK(t3.slice(0, 3).shape() == tensor_shape_sub1); + auto t4 = paddle::Tensor(paddle::PlaceType::kCPU, tensor_shape_origin2); + t4.mutable_data(); + CHECK(t4.slice(4, 5).shape() == tensor_shape_sub2); + + // Test writing function for sliced tensor + auto t = InitCPUTensorForTest(); + auto t_sliced = t.slice(0, 1); + auto* t_sliced_data_ptr = t_sliced.mutable_data(); + for (int64_t i = 0; i < t_sliced.size(); i++) { + t_sliced_data_ptr[i] += static_cast(5); + } + auto* t_data_ptr = t.mutable_data(); + for (int64_t i = 0; i < t_sliced.size(); i++) { + CHECK_EQ(t_data_ptr[i], static_cast(10)); + } +} + template paddle::DataType TestDtype() { std::vector tensor_shape = {5, 5}; @@ -261,6 +296,8 @@ TEST(CustomTensor, copyTest) { TestAPISizeAndShape(); VLOG(2) << "TestPlace"; TestAPIPlace(); + VLOG(2) << "TestSlice"; + TestAPISlice(); VLOG(2) << "TestCast"; GroupTestCast(); VLOG(2) << "TestDtypeConvert"; -- GitLab