diff --git a/paddle/framework/lod_tensor.cc b/paddle/framework/lod_tensor.cc index 908a1f2fd0abe0aa4016c72dbcbc18dcc144232c..3c349637cdbe59b2cf9a1ea28e7715f4181f9293 100644 --- a/paddle/framework/lod_tensor.cc +++ b/paddle/framework/lod_tensor.cc @@ -72,20 +72,16 @@ bool operator==(const LoD& a, const LoD& b) { return true; } -void LoDTensor::SliceLevels(size_t level_begin, size_t level_end) { +void LoDTensor::ShrinkLevels(size_t level_begin, size_t level_end) { auto new_lod = framework::SliceLevels(lod_, level_begin, level_end); lod_ = new_lod; } -void LoDTensor::SliceInLevel(size_t level, size_t elem_begin, size_t elem_end) { - PADDLE_ENFORCE(level < NumLevels(), "level [%d] out of range [%d]", level, - NumLevels()); - PADDLE_ENFORCE(elem_begin < NumElements(level), - "element begin [%d] out of range [%d]", elem_begin, - NumElements(level)); - PADDLE_ENFORCE(elem_end < NumElements(level) + 1, - "element end [%d] out of range [%d]", elem_end, - NumElements(level)); +void LoDTensor::ShrinkInLevel(size_t level, size_t elem_begin, + size_t elem_end) { + PADDLE_ENFORCE_LT(level, NumLevels()); + PADDLE_ENFORCE_LT(elem_begin, NumElements(level)); + PADDLE_ENFORCE_LT(elem_end, NumElements(level) + 1); auto new_lod = framework::SliceInLevel(lod_, level, elem_begin, elem_end); lod_ = new_lod; diff --git a/paddle/framework/lod_tensor.h b/paddle/framework/lod_tensor.h index fac5cd20aa7f9db0792f8102bb442192ab1ad63f..82f58464264c6871b51251e0feae3d5ca076cd2b 100644 --- a/paddle/framework/lod_tensor.h +++ b/paddle/framework/lod_tensor.h @@ -89,15 +89,15 @@ class LoDTensor : public Tensor { } /* - * Slice of levels[level_begin:level_end] + * Shrink levels[level_begin:level_end] */ - void SliceLevels(size_t level_begin, size_t level_end); + void ShrinkLevels(size_t level_begin, size_t level_end); /* - * Slice of elements of a level, [elem_begin: elem_end] + * Shrink elements of a level, [elem_begin: elem_end] * @note: low performance in slice lod_. */ - void SliceInLevel(size_t level, size_t elem_begin, size_t elem_end); + void ShrinkInLevel(size_t level, size_t elem_begin, size_t elem_end); private: LoD lod_; diff --git a/paddle/framework/lod_tensor_test.cc b/paddle/framework/lod_tensor_test.cc index 7915326b27a22e9280e3f09d9bbfc2a58f46aff7..486b839738ec077545163bc47e6a97ef188c3c2f 100644 --- a/paddle/framework/lod_tensor_test.cc +++ b/paddle/framework/lod_tensor_test.cc @@ -56,11 +56,11 @@ TEST_F(LoDTensorTester, NumElements) { ASSERT_EQ(lod_tensor_.NumElements(2), 8UL); } -TEST_F(LoDTensorTester, SliceLevels) { +TEST_F(LoDTensorTester, ShrinkLevels) { // slice 1 level for (size_t level = 0; level < 3UL; ++level) { LoDTensor new_lod_tensor = lod_tensor_; - new_lod_tensor.SliceLevels(level, level + 1); + new_lod_tensor.ShrinkLevels(level, level + 1); ASSERT_EQ(new_lod_tensor.NumLevels(), 1UL); ASSERT_EQ(new_lod_tensor.NumElements(0), lod_tensor_.NumElements(level)); ASSERT_EQ(new_lod_tensor.data(), lod_tensor_.data()); @@ -68,7 +68,7 @@ TEST_F(LoDTensorTester, SliceLevels) { // slice 2 level for (size_t level = 0; level < 2UL; ++level) { LoDTensor new_lod_tensor = lod_tensor_; - new_lod_tensor.SliceLevels(level, level + 2); + new_lod_tensor.ShrinkLevels(level, level + 2); ASSERT_EQ(new_lod_tensor.NumLevels(), 2UL); ASSERT_EQ(new_lod_tensor.NumElements(0), lod_tensor_.NumElements(level)); ASSERT_EQ(new_lod_tensor.NumElements(1), @@ -77,10 +77,10 @@ TEST_F(LoDTensorTester, SliceLevels) { } } -TEST_F(LoDTensorTester, SliceInLevel) { +TEST_F(LoDTensorTester, ShrinkInLevel) { size_t level = 0; LoDTensor new_lod_tensor = lod_tensor_; - new_lod_tensor.SliceInLevel(level, 0, 2); + new_lod_tensor.ShrinkInLevel(level, 0, 2); EXPECT_EQ(new_lod_tensor.NumLevels(), 3UL); EXPECT_EQ(new_lod_tensor.NumElements(0), 2UL); EXPECT_EQ(new_lod_tensor.NumElements(1), 4UL); @@ -89,7 +89,7 @@ TEST_F(LoDTensorTester, SliceInLevel) { level = 1; new_lod_tensor = lod_tensor_; - new_lod_tensor.SliceInLevel(level, 0, 2); + new_lod_tensor.ShrinkInLevel(level, 0, 2); ASSERT_EQ(new_lod_tensor.NumLevels(), 2UL); ASSERT_EQ(new_lod_tensor.NumElements(0), 2UL); ASSERT_EQ(new_lod_tensor.NumElements(1), 4UL);