From 835572afe70e3c0a0f11ff2f40a53b899b7adda6 Mon Sep 17 00:00:00 2001 From: qijun Date: Thu, 31 Aug 2017 16:51:02 +0800 Subject: [PATCH] make LODTensor class instead struct --- paddle/framework/lod_tensor.h | 12 ++++++++++-- paddle/framework/lod_tensor_test.cc | 20 ++++++++++---------- 2 files changed, 20 insertions(+), 12 deletions(-) diff --git a/paddle/framework/lod_tensor.h b/paddle/framework/lod_tensor.h index 7a9aebf50..9e6b6b4ac 100644 --- a/paddle/framework/lod_tensor.h +++ b/paddle/framework/lod_tensor.h @@ -48,11 +48,19 @@ bool operator==(const LOD& a, const LOD& b); * LODTensor (Level of details Tensor) * see https://en.wikipedia.org/wiki/Level_of_details for reference. */ -struct LODTensor { +class LODTensor { public: LODTensor() {} LODTensor(const LOD& lod, Tensor* t) : lod_(lod), tensor_(t) {} + void set_lod(const LOD& lod) { lod_ = lod; } + + void set_tensor(Tensor* tensor) { tensor_ = tensor; } + + Tensor& tensor() { return *tensor_; } + + LOD lod() { return lod_; } + /* * Get a element from LOD. */ @@ -91,7 +99,7 @@ struct LODTensor { */ void SliceInLevel(size_t level, size_t elem_begin, size_t elem_end); - public: + private: LOD lod_; Tensor* tensor_; // not owned }; diff --git a/paddle/framework/lod_tensor_test.cc b/paddle/framework/lod_tensor_test.cc index 30c8925ad..9a351605e 100644 --- a/paddle/framework/lod_tensor_test.cc +++ b/paddle/framework/lod_tensor_test.cc @@ -40,8 +40,8 @@ class LODTensorTester : public ::testing::Test { // malloc memory tensor.mutable_data(place); - lod_tensor.lod_ = lod; - lod_tensor.tensor_ = &tensor; + lod_tensor.set_lod(lod); + lod_tensor.set_tensor(&tensor); } protected: @@ -65,8 +65,8 @@ TEST_F(LODTensorTester, SliceLevels) { new_lod_tensor.SliceLevels(level, level + 1); ASSERT_EQ(new_lod_tensor.NumLevels(), 1UL); ASSERT_EQ(new_lod_tensor.NumElements(0), lod_tensor.NumElements(level)); - ASSERT_EQ(new_lod_tensor.tensor_->data(), - lod_tensor.tensor_->data()); + ASSERT_EQ(new_lod_tensor.tensor().data(), + lod_tensor.tensor().data()); } // slice 2 level for (size_t level = 0; level < 2UL; ++level) { @@ -75,8 +75,8 @@ TEST_F(LODTensorTester, SliceLevels) { ASSERT_EQ(new_lod_tensor.NumLevels(), 2UL); ASSERT_EQ(new_lod_tensor.NumElements(0), lod_tensor.NumElements(level)); ASSERT_EQ(new_lod_tensor.NumElements(1), lod_tensor.NumElements(level + 1)); - ASSERT_EQ(new_lod_tensor.tensor_->data(), - lod_tensor.tensor_->data()); + ASSERT_EQ(new_lod_tensor.tensor().data(), + lod_tensor.tensor().data()); } } @@ -88,8 +88,8 @@ TEST_F(LODTensorTester, SliceInLevel) { EXPECT_EQ(new_lod_tensor.NumElements(0), 2UL); EXPECT_EQ(new_lod_tensor.NumElements(1), 4UL); EXPECT_EQ(new_lod_tensor.NumElements(2), 8UL); - ASSERT_EQ(new_lod_tensor.tensor_->data(), - lod_tensor.tensor_->data()); + ASSERT_EQ(new_lod_tensor.tensor().data(), + lod_tensor.tensor().data()); level = 1; new_lod_tensor = lod_tensor; @@ -97,8 +97,8 @@ TEST_F(LODTensorTester, SliceInLevel) { ASSERT_EQ(new_lod_tensor.NumLevels(), 2UL); ASSERT_EQ(new_lod_tensor.NumElements(0), 2UL); ASSERT_EQ(new_lod_tensor.NumElements(1), 4UL); - ASSERT_EQ(new_lod_tensor.tensor_->data(), - lod_tensor.tensor_->data()); + ASSERT_EQ(new_lod_tensor.tensor().data(), + lod_tensor.tensor().data()); } } // namespace framework -- GitLab