提交 835572af 编写于 作者: Q qijun

make LODTensor class instead struct

上级 03978442
...@@ -48,11 +48,19 @@ bool operator==(const LOD& a, const LOD& b); ...@@ -48,11 +48,19 @@ bool operator==(const LOD& a, const LOD& b);
* LODTensor (Level of details Tensor) * LODTensor (Level of details Tensor)
* see https://en.wikipedia.org/wiki/Level_of_details for reference. * see https://en.wikipedia.org/wiki/Level_of_details for reference.
*/ */
struct LODTensor { class LODTensor {
public: public:
LODTensor() {} LODTensor() {}
LODTensor(const LOD& lod, Tensor* t) : lod_(lod), tensor_(t) {} LODTensor(const LOD& lod, Tensor* t) : lod_(lod), tensor_(t) {}
void set_lod(const LOD& lod) { lod_ = lod; }
void set_tensor(Tensor* tensor) { tensor_ = tensor; }
Tensor& tensor() { return *tensor_; }
LOD lod() { return lod_; }
/* /*
* Get a element from LOD. * Get a element from LOD.
*/ */
...@@ -91,7 +99,7 @@ struct LODTensor { ...@@ -91,7 +99,7 @@ struct LODTensor {
*/ */
void SliceInLevel(size_t level, size_t elem_begin, size_t elem_end); void SliceInLevel(size_t level, size_t elem_begin, size_t elem_end);
public: private:
LOD lod_; LOD lod_;
Tensor* tensor_; // not owned Tensor* tensor_; // not owned
}; };
......
...@@ -40,8 +40,8 @@ class LODTensorTester : public ::testing::Test { ...@@ -40,8 +40,8 @@ class LODTensorTester : public ::testing::Test {
// malloc memory // malloc memory
tensor.mutable_data<float>(place); tensor.mutable_data<float>(place);
lod_tensor.lod_ = lod; lod_tensor.set_lod(lod);
lod_tensor.tensor_ = &tensor; lod_tensor.set_tensor(&tensor);
} }
protected: protected:
...@@ -65,8 +65,8 @@ TEST_F(LODTensorTester, SliceLevels) { ...@@ -65,8 +65,8 @@ TEST_F(LODTensorTester, SliceLevels) {
new_lod_tensor.SliceLevels(level, level + 1); new_lod_tensor.SliceLevels(level, level + 1);
ASSERT_EQ(new_lod_tensor.NumLevels(), 1UL); ASSERT_EQ(new_lod_tensor.NumLevels(), 1UL);
ASSERT_EQ(new_lod_tensor.NumElements(0), lod_tensor.NumElements(level)); ASSERT_EQ(new_lod_tensor.NumElements(0), lod_tensor.NumElements(level));
ASSERT_EQ(new_lod_tensor.tensor_->data<float>(), ASSERT_EQ(new_lod_tensor.tensor().data<float>(),
lod_tensor.tensor_->data<float>()); lod_tensor.tensor().data<float>());
} }
// slice 2 level // slice 2 level
for (size_t level = 0; level < 2UL; ++level) { for (size_t level = 0; level < 2UL; ++level) {
...@@ -75,8 +75,8 @@ TEST_F(LODTensorTester, SliceLevels) { ...@@ -75,8 +75,8 @@ TEST_F(LODTensorTester, SliceLevels) {
ASSERT_EQ(new_lod_tensor.NumLevels(), 2UL); ASSERT_EQ(new_lod_tensor.NumLevels(), 2UL);
ASSERT_EQ(new_lod_tensor.NumElements(0), lod_tensor.NumElements(level)); ASSERT_EQ(new_lod_tensor.NumElements(0), lod_tensor.NumElements(level));
ASSERT_EQ(new_lod_tensor.NumElements(1), lod_tensor.NumElements(level + 1)); ASSERT_EQ(new_lod_tensor.NumElements(1), lod_tensor.NumElements(level + 1));
ASSERT_EQ(new_lod_tensor.tensor_->data<float>(), ASSERT_EQ(new_lod_tensor.tensor().data<float>(),
lod_tensor.tensor_->data<float>()); lod_tensor.tensor().data<float>());
} }
} }
...@@ -88,8 +88,8 @@ TEST_F(LODTensorTester, SliceInLevel) { ...@@ -88,8 +88,8 @@ TEST_F(LODTensorTester, SliceInLevel) {
EXPECT_EQ(new_lod_tensor.NumElements(0), 2UL); EXPECT_EQ(new_lod_tensor.NumElements(0), 2UL);
EXPECT_EQ(new_lod_tensor.NumElements(1), 4UL); EXPECT_EQ(new_lod_tensor.NumElements(1), 4UL);
EXPECT_EQ(new_lod_tensor.NumElements(2), 8UL); EXPECT_EQ(new_lod_tensor.NumElements(2), 8UL);
ASSERT_EQ(new_lod_tensor.tensor_->data<float>(), ASSERT_EQ(new_lod_tensor.tensor().data<float>(),
lod_tensor.tensor_->data<float>()); lod_tensor.tensor().data<float>());
level = 1; level = 1;
new_lod_tensor = lod_tensor; new_lod_tensor = lod_tensor;
...@@ -97,8 +97,8 @@ TEST_F(LODTensorTester, SliceInLevel) { ...@@ -97,8 +97,8 @@ TEST_F(LODTensorTester, SliceInLevel) {
ASSERT_EQ(new_lod_tensor.NumLevels(), 2UL); ASSERT_EQ(new_lod_tensor.NumLevels(), 2UL);
ASSERT_EQ(new_lod_tensor.NumElements(0), 2UL); ASSERT_EQ(new_lod_tensor.NumElements(0), 2UL);
ASSERT_EQ(new_lod_tensor.NumElements(1), 4UL); ASSERT_EQ(new_lod_tensor.NumElements(1), 4UL);
ASSERT_EQ(new_lod_tensor.tensor_->data<float>(), ASSERT_EQ(new_lod_tensor.tensor().data<float>(),
lod_tensor.tensor_->data<float>()); lod_tensor.tensor().data<float>());
} }
} // namespace framework } // namespace framework
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册