提交 dcf4682b 编写于 作者: S superjom

add LoDTensor::NumElements(id,id)

上级 d226daee
......@@ -77,9 +77,15 @@ size_t LoDTensor::NumElements(size_t level, size_t idx) const {
PADDLE_ENFORCE_LT(idx, NumElements(level));
// the last level of LoD, just return number of records in Tensor
if (level == NumLevels() - 1) {
return lod_[level][idx + 1] - lod_[level][idx];
}
// high level of LoD, and there is another lower level, return number of
// lower-level elements
auto tmp = SliceInLevel(lod_, level, idx, idx + 1);
PADDLE_ENFORCE_GE(tmp.size(), 2);
// there is a 0 as a placeholder stored in LoD, so the number of elements
// equals lod.size() - 1
return tmp[1].size() - 1;
}
void LoDTensor::SliceLevels(size_t level_begin, size_t level_end) {
......
......@@ -38,6 +38,18 @@ using Vector = thrust::host_vector<
T, thrust::system::cuda::experimental::pinned_allocator<T>>;
#endif
/*
* 3-level LoD stores
*
* 0 10 20
* 0 5 10 15 20
* 0 2 5 7 10 12 15 20
*
* - in a level, each element indicates offset in the underlying Tensor
* - the first element should be 0 and that indicates that this sequence start
* from 0
* - each sequence's begin and end(no-inclusive) is level[id, id+1]
*/
using LoD = std::vector<Vector<size_t>>;
LoD SliceLevels(const LoD& in, size_t level_begin, size_t level_end);
......
......@@ -56,6 +56,12 @@ TEST_F(LoDTensorTester, NumElements) {
ASSERT_EQ(lod_tensor_.NumElements(2), 8UL);
}
TEST_F(LoDTensorTester, NumElements2) {
ASSERT_EQ(lod_tensor_.NumElements(0, 0), 2UL);
ASSERT_EQ(lod_tensor_.NumElements(0, 1), 2UL);
ASSERT_EQ(lod_tensor_.NumElements(1, 1), 2UL);
}
TEST_F(LoDTensorTester, SliceLevels) {
// slice 1 level
for (size_t level = 0; level < 3UL; ++level) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册
新手
引导
客服 返回
顶部