提交 f0088335 编写于 作者: M Megvii Engine Team

feat(mgb): upgrade flatbuffer

GitOrigin-RevId: 7b1a04934ea3a7936661bb0558e4c58b258cdff2
上级 657db8dc
...@@ -33,7 +33,7 @@ public: ...@@ -33,7 +33,7 @@ public:
void rewind() override { std::rewind(m_fptr); } void rewind() override { std::rewind(m_fptr); }
void skip(size_t bytes) override { void skip(int64_t bytes) override {
auto err = fseek(m_fptr, bytes, SEEK_CUR); auto err = fseek(m_fptr, bytes, SEEK_CUR);
mgb_assert(!err); mgb_assert(!err);
} }
...@@ -104,7 +104,7 @@ public: ...@@ -104,7 +104,7 @@ public:
void rewind() override { m_offset = 0; } void rewind() override { m_offset = 0; }
void skip(size_t bytes) override { void skip(int64_t bytes) override {
m_offset += bytes; m_offset += bytes;
mgb_assert(m_offset <= m_size); mgb_assert(m_offset <= m_size);
} }
...@@ -146,7 +146,7 @@ public: ...@@ -146,7 +146,7 @@ public:
m_offset = 0; m_offset = 0;
} }
void skip(size_t bytes) override { void skip(int64_t bytes) override {
m_offset += bytes; m_offset += bytes;
mgb_assert(m_offset <= m_size); mgb_assert(m_offset <= m_size);
} }
......
...@@ -838,25 +838,22 @@ GraphLoader::LoadResult GraphLoaderOSS::load(const LoadConfig& config, bool rewi ...@@ -838,25 +838,22 @@ GraphLoader::LoadResult GraphLoaderOSS::load(const LoadConfig& config, bool rewi
// Read fbs::Graph // Read fbs::Graph
uint32_t size; uint32_t size;
m_file->read(&size, sizeof(size)); m_file->read(&size, sizeof(size));
m_graph_buf = m_file->read_shared(size); m_file->skip(-sizeof(size));
m_graph_buf = m_file->read_shared(size + sizeof(size));
// Rewind back to tensor data // Rewind back to tensor data
m_file->rewind(); m_file->rewind();
m_file->skip(tensor_begin); m_file->skip(tensor_begin);
mgb_throw_if(
!fbs::GraphBufferHasIdentifier(m_graph_buf.data()), SerializationError,
"invalid fbs model");
{ {
flatbuffers::Verifier verifier( flatbuffers::Verifier verifier(
static_cast<const uint8_t*>(m_graph_buf.data()), m_graph_buf.size()); static_cast<const uint8_t*>(m_graph_buf.data()), m_graph_buf.size());
mgb_throw_if( mgb_throw_if(
!fbs::VerifyGraphBuffer(verifier), SerializationError, !fbs::VerifySizePrefixedGraphBuffer(verifier), SerializationError,
"model verification failed (invalid or corrupted model?)"); "model verification failed (invalid or corrupted model?)");
} }
m_graph = fbs::GetGraph(m_graph_buf.data()); m_graph = fbs::GetSizePrefixedGraph(m_graph_buf.data());
m_mgb_version = m_graph->mgb_version(); m_mgb_version = m_graph->mgb_version();
if (m_graph->mgb_version() > MGB_VERSION) { if (m_graph->mgb_version() > MGB_VERSION) {
mgb_log_warn( mgb_log_warn(
......
...@@ -801,21 +801,18 @@ GraphLoader::LoadResult GraphLoaderOSSV2::load(const LoadConfig& config, bool re ...@@ -801,21 +801,18 @@ GraphLoader::LoadResult GraphLoaderOSSV2::load(const LoadConfig& config, bool re
// Read fbs::Graph // Read fbs::Graph
uint32_t size; uint32_t size;
m_file->read(&size, sizeof(size)); m_file->read(&size, sizeof(size));
m_model_buf = m_file->read_shared(size); m_file->skip(-sizeof(size));
m_model_buf = m_file->read_shared(size + sizeof(size));
mgb_throw_if(
!fbs::v2::ModelBufferHasIdentifier(m_model_buf.data()), SerializationError,
"invalid fbs model");
{ {
flatbuffers::Verifier verifier( flatbuffers::Verifier verifier(
static_cast<const uint8_t*>(m_model_buf.data()), m_model_buf.size()); static_cast<const uint8_t*>(m_model_buf.data()), m_model_buf.size());
mgb_throw_if( mgb_throw_if(
!fbs::v2::VerifyModelBuffer(verifier), SerializationError, !fbs::v2::VerifySizePrefixedModelBuffer(verifier), SerializationError,
"model verification failed (invalid or corrupted model?)"); "model verification failed (invalid or corrupted model?)");
} }
m_model = fbs::v2::GetModel(m_model_buf.data()); m_model = fbs::v2::GetSizePrefixedModel(m_model_buf.data());
m_mgb_version = m_model->mge_version(); m_mgb_version = m_model->mge_version();
m_model_version = m_model->model_version(); m_model_version = m_model->model_version();
if (m_model->mge_version() > MGB_VERSION) { if (m_model->mge_version() > MGB_VERSION) {
......
...@@ -33,7 +33,7 @@ public: ...@@ -33,7 +33,7 @@ public:
virtual void rewind() = 0; virtual void rewind() = 0;
//! skip given number of bytes //! skip given number of bytes
virtual void skip(size_t bytes) = 0; virtual void skip(int64_t bytes) = 0;
//! read data into buffer //! read data into buffer
virtual void read(void* dst, size_t size) = 0; virtual void read(void* dst, size_t size) = 0;
......
...@@ -217,10 +217,10 @@ if(NOT DEFINED IOS_DEPLOYMENT_TARGET) ...@@ -217,10 +217,10 @@ if(NOT DEFINED IOS_DEPLOYMENT_TARGET)
"2.0" "2.0"
CACHE STRING "Minimum iOS version to build for.") CACHE STRING "Minimum iOS version to build for.")
else() else()
# Unless specified, SDK version 10.0 is used by default as minimum target version # Unless specified, SDK version 11.0 is used by default as minimum target version
# (iOS, tvOS). # (iOS, tvOS).
set(IOS_DEPLOYMENT_TARGET set(IOS_DEPLOYMENT_TARGET
"10.0" "11.0"
CACHE STRING "Minimum iOS version to build for.") CACHE STRING "Minimum iOS version to build for.")
endif() endif()
message( message(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册