diff --git a/paddle/framework/lod_tensor.h b/paddle/framework/lod_tensor.h index 7a9aebf50a6d96881a76b9046ea7280c2ec44be8..9e6b6b4aca41ed464292b56bf6f2d27514f874f7 100644 --- a/paddle/framework/lod_tensor.h +++ b/paddle/framework/lod_tensor.h @@ -48,11 +48,19 @@ bool operator==(const LOD& a, const LOD& b); * LODTensor (Level of details Tensor) * see https://en.wikipedia.org/wiki/Level_of_details for reference. */ -struct LODTensor { +class LODTensor { public: LODTensor() {} LODTensor(const LOD& lod, Tensor* t) : lod_(lod), tensor_(t) {} + void set_lod(const LOD& lod) { lod_ = lod; } + + void set_tensor(Tensor* tensor) { tensor_ = tensor; } + + Tensor& tensor() { return *tensor_; } + + LOD lod() { return lod_; } + /* * Get a element from LOD. */ @@ -91,7 +99,7 @@ struct LODTensor { */ void SliceInLevel(size_t level, size_t elem_begin, size_t elem_end); - public: + private: LOD lod_; Tensor* tensor_; // not owned }; diff --git a/paddle/framework/lod_tensor_test.cc b/paddle/framework/lod_tensor_test.cc index 30c8925ad0f3c032902893545236706f6aabb2a4..9a351605edb5013bdab2c6193bdd9ce401acc937 100644 --- a/paddle/framework/lod_tensor_test.cc +++ b/paddle/framework/lod_tensor_test.cc @@ -40,8 +40,8 @@ class LODTensorTester : public ::testing::Test { // malloc memory tensor.mutable_data(place); - lod_tensor.lod_ = lod; - lod_tensor.tensor_ = &tensor; + lod_tensor.set_lod(lod); + lod_tensor.set_tensor(&tensor); } protected: @@ -65,8 +65,8 @@ TEST_F(LODTensorTester, SliceLevels) { new_lod_tensor.SliceLevels(level, level + 1); ASSERT_EQ(new_lod_tensor.NumLevels(), 1UL); ASSERT_EQ(new_lod_tensor.NumElements(0), lod_tensor.NumElements(level)); - ASSERT_EQ(new_lod_tensor.tensor_->data(), - lod_tensor.tensor_->data()); + ASSERT_EQ(new_lod_tensor.tensor().data(), + lod_tensor.tensor().data()); } // slice 2 level for (size_t level = 0; level < 2UL; ++level) { @@ -75,8 +75,8 @@ TEST_F(LODTensorTester, SliceLevels) { ASSERT_EQ(new_lod_tensor.NumLevels(), 2UL); ASSERT_EQ(new_lod_tensor.NumElements(0), lod_tensor.NumElements(level)); ASSERT_EQ(new_lod_tensor.NumElements(1), lod_tensor.NumElements(level + 1)); - ASSERT_EQ(new_lod_tensor.tensor_->data(), - lod_tensor.tensor_->data()); + ASSERT_EQ(new_lod_tensor.tensor().data(), + lod_tensor.tensor().data()); } } @@ -88,8 +88,8 @@ TEST_F(LODTensorTester, SliceInLevel) { EXPECT_EQ(new_lod_tensor.NumElements(0), 2UL); EXPECT_EQ(new_lod_tensor.NumElements(1), 4UL); EXPECT_EQ(new_lod_tensor.NumElements(2), 8UL); - ASSERT_EQ(new_lod_tensor.tensor_->data(), - lod_tensor.tensor_->data()); + ASSERT_EQ(new_lod_tensor.tensor().data(), + lod_tensor.tensor().data()); level = 1; new_lod_tensor = lod_tensor; @@ -97,8 +97,8 @@ TEST_F(LODTensorTester, SliceInLevel) { ASSERT_EQ(new_lod_tensor.NumLevels(), 2UL); ASSERT_EQ(new_lod_tensor.NumElements(0), 2UL); ASSERT_EQ(new_lod_tensor.NumElements(1), 4UL); - ASSERT_EQ(new_lod_tensor.tensor_->data(), - lod_tensor.tensor_->data()); + ASSERT_EQ(new_lod_tensor.tensor().data(), + lod_tensor.tensor().data()); } } // namespace framework