未验证 提交 3f011d82 编写于 作者: H Hao Lin 提交者: GitHub

Add ext_tensor.slice() API (#34227)

* Add ext_tensor.slice() API, test=develop

* Call Tensor::mutable_data first to fix bugs and add test for writing to sliced tensor

* Fix unit test bug

* Fix code format problem, test=develop

* Fix code format problem

* Fix code format problem

* Strengthen unit test

* Use CustomTensorUtils::ShareDataFrom to simplify the code
上级 4d7af372
......@@ -88,10 +88,20 @@ class PD_DLL_DECL Tensor {
/// It's usually used to set the input tensor data.
/// \param place The PlaceType of the target place that
/// the tensor will be copied to.
template <typename T>
Tensor copy_to(const PlaceType& place) const;
/// \brief Return a sub-tensor of the given tensor.
/// It is usually used to extract a sub-tensor (which supports
/// modifying the data of the original tensor) to perform further
/// operations.
/// \param begin_idx The index of the start row (inclusive) to slice.
/// Indices are zero-based.
/// \param end_idx The index of the end row (exclusive) to slice.
/// The smallest valid value is begin_idx + 1.
/// \return The sliced tensor.
Tensor slice(const int64_t begin_idx, const int64_t end_idx) const;
/// \brief Return the shape of the Tensor.
/// \return The size of each dimension, as a vector of int64_t.
std::vector<int64_t> shape() const;
......
......@@ -124,6 +124,21 @@ void DeviceCopy(T *src, T *dst, PlaceType src_plc, PlaceType dst_plc,
} \
auto *tensor = static_cast<framework::LoDTensor *>(tensor_.get());
// Declares a local `platform::Place place` in the expanding scope and assigns
// it from the `place_` visible at the expansion site: CPUPlace for kCPU and a
// default-constructed CUDAPlace for kGPU. Any other value is reported through
// PADDLE_THROW as an Unavailable error carrying the numeric place id.
#define GET_INNER_PLACE                                 \
  platform::Place place;                                \
  if (place_ == PlaceType::kCPU) {                      \
    place = platform::CPUPlace();                       \
  } else if (place_ == PlaceType::kGPU) {               \
    place = platform::CUDAPlace();                      \
  } else {                                              \
    PADDLE_THROW(platform::errors::Unavailable(         \
        "Custom operator unsupported place id(%d)",     \
        static_cast<int>(place_)));                     \
  }
void Tensor::reshape(const std::vector<int64_t> &shape) {
GET_CASTED_TENSOR
auto new_dim = framework::make_ddim(shape);
......@@ -257,6 +272,16 @@ Tensor Tensor::copy_to(const PlaceType &target_place) const {
return target;
}
// Builds a new Tensor on the same place_ from the inner tensor's
// Slice(begin_idx, end_idx) result, handing that result over through
// CustomTensorUtils::ShareDataFrom.
Tensor Tensor::slice(const int64_t begin_idx, const int64_t end_idx) const {
  GET_CASTED_TENSOR
  // Kept for its check on place_ (it throws for unsupported places); the
  // local `place` it declares is not used below.
  GET_INNER_PLACE
  const framework::Tensor sliced_inner = tensor->Slice(begin_idx, end_idx);
  Tensor result(place_);
  framework::CustomTensorUtils::ShareDataFrom(
      static_cast<const void *>(&sliced_inner), result);
  return result;
}
template PD_DLL_DECL Tensor
Tensor::copy_to<float>(const PlaceType &target_place) const;
template PD_DLL_DECL Tensor
......
......@@ -92,6 +92,41 @@ void TestAPISizeAndShape() {
CHECK(t1.shape() == tensor_shape);
}
void TestAPISlice() {
std::vector<int64_t> tensor_shape_origin1 = {5, 5};
std::vector<int64_t> tensor_shape_sub1 = {3, 5};
std::vector<int64_t> tensor_shape_origin2 = {5, 5, 5};
std::vector<int64_t> tensor_shape_sub2 = {1, 5, 5};
#ifdef PADDLE_WITH_CUDA
auto t1 = paddle::Tensor(paddle::PlaceType::kGPU, tensor_shape_origin1);
t1.mutable_data<float>();
CHECK(t1.slice(0, 5).shape() == tensor_shape_origin1);
CHECK(t1.slice(0, 3).shape() == tensor_shape_sub1);
auto t2 = paddle::Tensor(paddle::PlaceType::kGPU, tensor_shape_origin2);
t2.mutable_data<float>();
CHECK(t2.slice(4, 5).shape() == tensor_shape_sub2);
#endif
auto t3 = paddle::Tensor(paddle::PlaceType::kCPU, tensor_shape_origin1);
t3.mutable_data<float>();
CHECK(t3.slice(0, 5).shape() == tensor_shape_origin1);
CHECK(t3.slice(0, 3).shape() == tensor_shape_sub1);
auto t4 = paddle::Tensor(paddle::PlaceType::kCPU, tensor_shape_origin2);
t4.mutable_data<float>();
CHECK(t4.slice(4, 5).shape() == tensor_shape_sub2);
// Test writing function for sliced tensor
auto t = InitCPUTensorForTest<float>();
auto t_sliced = t.slice(0, 1);
auto* t_sliced_data_ptr = t_sliced.mutable_data<float>();
for (int64_t i = 0; i < t_sliced.size(); i++) {
t_sliced_data_ptr[i] += static_cast<float>(5);
}
auto* t_data_ptr = t.mutable_data<float>();
for (int64_t i = 0; i < t_sliced.size(); i++) {
CHECK_EQ(t_data_ptr[i], static_cast<float>(10));
}
}
template <typename T>
paddle::DataType TestDtype() {
std::vector<int64_t> tensor_shape = {5, 5};
......@@ -261,6 +296,8 @@ TEST(CustomTensor, copyTest) {
TestAPISizeAndShape();
VLOG(2) << "TestPlace";
TestAPIPlace();
VLOG(2) << "TestSlice";
TestAPISlice();
VLOG(2) << "TestCast";
GroupTestCast();
VLOG(2) << "TestDtypeConvert";
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册