未验证 提交 6ec9d32c 编写于 作者: 石晓伟 提交者: GitHub

vector view for flatbuffers, test=develop (#3862)

* vector view for flatbuffers, test=develop

* update dependencies, test=develop
上级 5ab1d7e7
......@@ -9,3 +9,6 @@ lite_fbs_library(fbs_op_desc SRCS op_desc.cc FBS_DEPS framework_fbs_header)
lite_fbs_library(fbs_var_desc SRCS var_desc.cc FBS_DEPS framework_fbs_header)
lite_fbs_library(fbs_block_desc SRCS block_desc.cc FBS_DEPS framework_fbs_header)
lite_fbs_library(fbs_program_desc SRCS program_desc.cc FBS_DEPS framework_fbs_header)
lite_fbs_library(vector_view SRCS vector_view.cc FBS_DEPS framework_fbs_header)
lite_cc_test(test_vector_view SRCS vector_view_test.cc DEPS vector_view)
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/model_parser/flatbuffers/vector_view.h"
namespace paddle {
namespace lite {} // namespace lite
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <type_traits>
#include <vector>
#include "flatbuffers/flatbuffers.h"
namespace paddle {
namespace lite {
namespace fbs {
struct Flatbuffers {};
struct Standand {};
template <typename T, typename U = void>
struct ElementTraits {
typedef T element_type;
};
template <typename T>
struct ElementTraits<T,
typename std::enable_if<std::is_class<T>::value>::type> {
typedef flatbuffers::Offset<T> element_type;
};
template <>
struct ElementTraits<std::string, void> {
typedef flatbuffers::Offset<flatbuffers::String> element_type;
};
template <typename T, typename U>
struct VectorTraits;
template <typename T>
struct VectorTraits<T, Flatbuffers> {
typedef flatbuffers::Vector<typename ElementTraits<T>::element_type>
vector_type;
typedef typename vector_type::const_iterator const_iterator;
typedef typename const_iterator::value_type value_type;
typedef const typename const_iterator::reference const_reference;
typedef value_type subscript_return_type;
};
template <typename T>
struct VectorTraits<T, Standand> {
typedef std::vector<T> vector_type;
typedef typename vector_type::const_iterator const_iterator;
typedef typename vector_type::const_reference const_reference;
typedef const_reference subscript_return_type;
};
template <typename T, typename U = Flatbuffers>
class VectorView {
public:
typedef VectorTraits<T, U> Traits;
explicit VectorView(typename Traits::vector_type const* cvec) {
cvec_ = cvec;
}
typename Traits::subscript_return_type operator[](size_t i) const {
return cvec_->operator[](i);
}
typename Traits::const_iterator begin() const { return cvec_->begin(); }
typename Traits::const_iterator end() const { return cvec_->end(); }
size_t size() const { return cvec_->size(); }
~VectorView() = default;
private:
typename Traits::vector_type const* cvec_;
};
struct FBSStrIterator {
typedef flatbuffers::VectorIterator<
flatbuffers::Offset<flatbuffers::String>,
typename flatbuffers::IndirectHelper<
flatbuffers::Offset<flatbuffers::String>>::return_type>
VI;
explicit FBSStrIterator(const VI& iter) { iter_ = iter; }
const VI& RawIter() const { return iter_; }
bool operator==(const FBSStrIterator& other) const {
return iter_ == other.RawIter();
}
bool operator<(const FBSStrIterator& other) const {
return iter_ < other.RawIter();
}
bool operator!=(const FBSStrIterator& other) const {
return iter_ != other.RawIter();
}
ptrdiff_t operator-(const FBSStrIterator& other) const {
return iter_ - other.RawIter();
}
std::string operator*() const { return iter_.operator*()->str(); }
std::string operator->() const { return iter_.operator->()->str(); }
FBSStrIterator& operator++() {
iter_++;
return *this;
}
FBSStrIterator& operator--() {
iter_--;
return *this;
}
FBSStrIterator operator+(const size_t& offset) {
return FBSStrIterator(iter_ + offset);
}
FBSStrIterator operator-(const size_t& offset) {
return FBSStrIterator(iter_ - offset);
}
private:
VI iter_;
};
template <>
class VectorView<std::string, Flatbuffers> {
public:
typedef VectorTraits<std::string, Flatbuffers> Traits;
explicit VectorView(typename Traits::vector_type const* cvec) {
cvec_ = cvec;
}
std::string operator[](size_t i) const { return cvec_->operator[](i)->str(); }
FBSStrIterator begin() const { return FBSStrIterator(cvec_->begin()); }
FBSStrIterator end() const { return FBSStrIterator(cvec_->end()); }
size_t size() const { return cvec_->size(); }
~VectorView() = default;
private:
typename Traits::vector_type const* cvec_;
};
} // namespace fbs
} // namespace lite
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/model_parser/flatbuffers/vector_view.h"
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include <iostream>
#include <vector>
#include "lite/model_parser/flatbuffers/framework_generated.h"
namespace paddle {
namespace lite {
TEST(VectorView, std_vector) {
std::vector<int64_t> vector{1, 2, 3};
fbs::VectorView<int64_t, fbs::Standand> vector_view(&vector);
size_t i = 0;
for (const auto& value : vector_view) {
EXPECT_EQ(value, vector[i]);
++i;
}
for (size_t j = 0; j < vector_view.size(); ++j) {
EXPECT_EQ(vector_view[i], vector[i]);
}
}
TEST(VectorView, Flatbuffers) {
using namespace flatbuffers; // NOLINT
using namespace paddle::lite::fbs; // NOLINT
auto create_desc = [](FlatBufferBuilder& fbb) {
/* --------- Set --------- */
// Attr
std::vector<int32_t> ints({-1, 0, 1, 2, 3});
auto string_0 = fbb.CreateString("string_0");
auto string_1 = fbb.CreateString("string_1");
std::vector<Offset<String>> strings;
strings.push_back(string_0);
strings.push_back(string_1);
auto attr = proto::OpDesc_::CreateAttrDirect(fbb,
nullptr,
proto::AttrType_INT,
0,
0.0f,
nullptr,
&ints,
nullptr,
&strings);
// OpDesc
std::vector<Offset<proto::OpDesc_::Attr>> attrs;
attrs.push_back(attr);
auto op_desc =
proto::CreateOpDescDirect(fbb, "hello!", nullptr, nullptr, &attrs);
// BlockDesc 0
std::vector<Offset<proto::OpDesc>> ops;
ops.push_back(op_desc);
auto block_0 = proto::CreateBlockDescDirect(fbb, 0, 0, nullptr, &ops);
// BlockDesc 1
auto block_1 = proto::CreateBlockDescDirect(fbb, 1);
// ProgramDesc
std::vector<Offset<proto::BlockDesc>> block_vector;
block_vector.push_back(block_0);
block_vector.push_back(block_1);
auto orc = proto::CreateProgramDescDirect(fbb, &block_vector);
fbb.Finish(orc);
};
FlatBufferBuilder fbb;
create_desc(fbb);
auto program = fbs::proto::GetProgramDesc(fbb.GetBufferPointer());
// BlockDesc View
VectorView<proto::BlockDesc> block_view(program->blocks());
EXPECT_EQ(block_view.size(), static_cast<size_t>(2));
EXPECT_EQ(block_view[0]->idx(), 0);
EXPECT_EQ(block_view[1]->idx(), 1);
// OpDesc & Attr View
VectorView<proto::OpDesc> op_view(block_view[0]->ops());
EXPECT_EQ(op_view[0]->type()->str(), std::string("hello!"));
VectorView<proto::OpDesc_::Attr> attr_view(op_view[0]->attrs());
// int32_t View
VectorView<int32_t> ints_view(attr_view[0]->ints());
std::vector<int32_t> ints({-1, 0, 1, 2, 3});
size_t cnt_0 = 0;
for (const auto& i : ints_view) {
EXPECT_EQ(i, ints[cnt_0]);
++cnt_0;
}
for (size_t i = 0; i < ints_view.size(); ++i) {
EXPECT_EQ(ints_view[i], ints[i]);
}
// String View
VectorView<std::string> strings_view(attr_view[0]->strings());
std::vector<std::string> strings({"string_0", "string_1"});
EXPECT_EQ(strings_view.size(), strings.size());
size_t cnt_1 = 0;
for (const auto& s : strings_view) {
EXPECT_EQ(s, strings[cnt_1]);
++cnt_1;
}
for (size_t i = 0; i < strings_view.size(); ++i) {
EXPECT_EQ(strings_view[i], strings[i]);
}
}
} // namespace lite
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册