提交 835572af 编写于 作者: Q qijun

make LODTensor class instead struct

上级 03978442
......@@ -48,11 +48,19 @@ bool operator==(const LOD& a, const LOD& b);
* LODTensor (Level of details Tensor)
* see https://en.wikipedia.org/wiki/Level_of_details for reference.
*/
struct LODTensor {
class LODTensor {
public:
LODTensor() {}
LODTensor(const LOD& lod, Tensor* t) : lod_(lod), tensor_(t) {}
void set_lod(const LOD& lod) { lod_ = lod; }
void set_tensor(Tensor* tensor) { tensor_ = tensor; }
Tensor& tensor() { return *tensor_; }
LOD lod() { return lod_; }
/*
* Get a element from LOD.
*/
......@@ -91,7 +99,7 @@ struct LODTensor {
*/
void SliceInLevel(size_t level, size_t elem_begin, size_t elem_end);
public:
private:
LOD lod_;
Tensor* tensor_; // not owned
};
......
......@@ -40,8 +40,8 @@ class LODTensorTester : public ::testing::Test {
// malloc memory
tensor.mutable_data<float>(place);
lod_tensor.lod_ = lod;
lod_tensor.tensor_ = &tensor;
lod_tensor.set_lod(lod);
lod_tensor.set_tensor(&tensor);
}
protected:
......@@ -65,8 +65,8 @@ TEST_F(LODTensorTester, SliceLevels) {
new_lod_tensor.SliceLevels(level, level + 1);
ASSERT_EQ(new_lod_tensor.NumLevels(), 1UL);
ASSERT_EQ(new_lod_tensor.NumElements(0), lod_tensor.NumElements(level));
ASSERT_EQ(new_lod_tensor.tensor_->data<float>(),
lod_tensor.tensor_->data<float>());
ASSERT_EQ(new_lod_tensor.tensor().data<float>(),
lod_tensor.tensor().data<float>());
}
// slice 2 level
for (size_t level = 0; level < 2UL; ++level) {
......@@ -75,8 +75,8 @@ TEST_F(LODTensorTester, SliceLevels) {
ASSERT_EQ(new_lod_tensor.NumLevels(), 2UL);
ASSERT_EQ(new_lod_tensor.NumElements(0), lod_tensor.NumElements(level));
ASSERT_EQ(new_lod_tensor.NumElements(1), lod_tensor.NumElements(level + 1));
ASSERT_EQ(new_lod_tensor.tensor_->data<float>(),
lod_tensor.tensor_->data<float>());
ASSERT_EQ(new_lod_tensor.tensor().data<float>(),
lod_tensor.tensor().data<float>());
}
}
......@@ -88,8 +88,8 @@ TEST_F(LODTensorTester, SliceInLevel) {
EXPECT_EQ(new_lod_tensor.NumElements(0), 2UL);
EXPECT_EQ(new_lod_tensor.NumElements(1), 4UL);
EXPECT_EQ(new_lod_tensor.NumElements(2), 8UL);
ASSERT_EQ(new_lod_tensor.tensor_->data<float>(),
lod_tensor.tensor_->data<float>());
ASSERT_EQ(new_lod_tensor.tensor().data<float>(),
lod_tensor.tensor().data<float>());
level = 1;
new_lod_tensor = lod_tensor;
......@@ -97,8 +97,8 @@ TEST_F(LODTensorTester, SliceInLevel) {
ASSERT_EQ(new_lod_tensor.NumLevels(), 2UL);
ASSERT_EQ(new_lod_tensor.NumElements(0), 2UL);
ASSERT_EQ(new_lod_tensor.NumElements(1), 4UL);
ASSERT_EQ(new_lod_tensor.tensor_->data<float>(),
lod_tensor.tensor_->data<float>());
ASSERT_EQ(new_lod_tensor.tensor().data<float>(),
lod_tensor.tensor().data<float>());
}
} // namespace framework
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册