Unverified · commit f65b9b18 authored by 石晓伟, committed by GitHub

[API] Flatbuffers interfaces, test=develop (#4049)

* [test] flatbuffers io, test=develop

* determine the basic interfaces and helper functions of flatbuffers, test=develop
Parent 4bdeabb8
lite/model_parser/flatbuffers/CMakeLists.txt:
@@ -10,5 +10,6 @@ lite_fbs_library(fbs_var_desc SRCS var_desc.cc FBS_DEPS fbs_headers)
 lite_fbs_library(fbs_block_desc SRCS block_desc.cc FBS_DEPS fbs_headers)
 lite_cc_library(fbs_program_desc SRCS program_desc.cc DEPS fbs_op_desc fbs_var_desc fbs_block_desc)
 lite_fbs_library(fbs_param_desc SRCS param_desc.cc FBS_DEPS fbs_headers)
-lite_cc_library(fbs_io SRCS io.cc DEPS fbs_program_desc fbs_param_desc)
+lite_cc_library(fbs_io SRCS io.cc DEPS fbs_program_desc fbs_param_desc scope)
 lite_cc_test(test_vector_view SRCS vector_view_test.cc DEPS fbs_program_desc)
+lite_cc_test(test_fbs_io SRCS io_test.cc DEPS fbs_io)
lite/model_parser/flatbuffers/io.cc:
@@ -23,16 +23,22 @@ namespace paddle {
 namespace lite {
 namespace fbs {

-void LoadModel(const std::string& path, ProgramDesc* prog) {
-  CHECK(prog);
+std::vector<char> LoadFile(const std::string& path) {
   FILE* file = fopen(path.c_str(), "rb");
   fseek(file, 0, SEEK_END);
   int64_t length = ftell(file);
   rewind(file);
   std::vector<char> buf(length);
-  CHECK(fread(buf.data(), 1, length, file));
+  CHECK(fread(buf.data(), 1, length, file) == length);
   fclose(file);
-  prog->Init(std::move(buf));
+  return buf;
+}
+
+void SaveFile(const std::string& path, const void* src, size_t byte_size) {
+  CHECK(src);
+  FILE* file = fopen(path.c_str(), "wb");
+  CHECK(fwrite(src, sizeof(char), byte_size, file) == byte_size);
+  fclose(file);
 }

 void SetParamWithTensor(const std::string& name,
@@ -72,6 +78,7 @@ void SetScopeWithCombinedParams(lite::Scope* scope,
     SetTensorWithParam(tensor, param);
   }
 }
+
 }  // namespace fbs
 }  // namespace lite
 }  // namespace paddle
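For orientation: the old LoadModel read a file and initialized a ProgramDesc in one call, while the new LoadFile/SaveFile pair only moves raw bytes and leaves deserialization to the caller. A minimal round-trip sketch under that reading; the path and payload are illustrative, not part of the commit:

#include "lite/model_parser/flatbuffers/io.h"
#include <vector>

void RoundTripSketch() {
  std::vector<char> payload = {'f', 'b', 's'};
  // Write the raw bytes out, then read them back and compare.
  paddle::lite::fbs::SaveFile("/tmp/demo.bin", payload.data(), payload.size());
  std::vector<char> restored = paddle::lite::fbs::LoadFile("/tmp/demo.bin");
  CHECK(restored == payload);  // CHECK as used throughout lite
}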
lite/model_parser/flatbuffers/io.h:
@@ -25,18 +25,14 @@ namespace paddle {
 namespace lite {
 namespace fbs {

-void LoadModel(const std::string& path, ProgramDesc* prog);
+std::vector<char> LoadFile(const std::string& path);

-void SetParamWithTensor(const std::string& name,
-                        const lite::Tensor& tensor,
-                        ParamDescWriteAPI* prog);
-
-void SetTensorWithParam(const lite::Tensor& tensor, ParamDescReadAPI* prog);
+void SaveFile(const std::string& path, const void* src, size_t byte_size);
+
+void SetScopeWithCombinedParams(lite::Scope* scope,
+                                const CombinedParamsDescReadAPI& params);

 void SetCombinedParamsWithScope(const lite::Scope& scope,
                                 const std::vector<std::string>& params_name,
                                 CombinedParamsDescWriteAPI* params);
-
-void SetScopeWithCombinedParams(lite::Scope* scope,
-                                const CombinedParamsDescReadAPI& params);

 }  // namespace fbs
 }  // namespace lite
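Read together with the param_desc.h changes below, these declarations suggest the intended end-to-end flow for combined parameters. A sketch, assuming data() and buf_size() behave as in the test further down; the variable names and path are illustrative:

// Scope -> flatbuffer -> disk -> flatbuffer view -> scope.
void SaveAndReloadSketch(const paddle::lite::Scope& scope,
                         const std::vector<std::string>& names) {
  namespace fbs = paddle::lite::fbs;
  fbs::CombinedParamsDesc desc;
  fbs::SetCombinedParamsWithScope(scope, names, &desc);
  fbs::SaveFile("/tmp/params.fbs", desc.data(), desc.buf_size());

  paddle::lite::Scope reloaded;
  fbs::SetScopeWithCombinedParams(
      &reloaded,
      fbs::CombinedParamsDescView(fbs::LoadFile("/tmp/params.fbs")));
}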
lite/model_parser/flatbuffers/io_test.cc (new file):
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/model_parser/flatbuffers/io.h"
#include <gtest/gtest.h>
#include <cstring>
#include <functional>
#include <numeric>
#include <string>
#include <utility>
#include <vector>

namespace paddle {
namespace lite {
namespace fbs {
namespace {

template <typename T>
void set_tensor(paddle::lite::Tensor* tensor,
                const std::vector<int64_t>& dims) {
  auto production =
      std::accumulate(begin(dims), end(dims), 1, std::multiplies<int64_t>());
  tensor->Resize(dims);
  std::vector<T> data;
  data.resize(production);
  for (size_t i = 0; i < production; ++i) {
    data[i] = i / 2.f;
  }
  std::memcpy(tensor->mutable_data<T>(), data.data(), sizeof(T) * data.size());
}

}  // namespace

TEST(CombinedParamsDesc, Scope) {
  /* --------- Save scope ---------- */
  Scope scope;
  std::vector<std::string> params_name({"var_0", "var_1"});
  // variable 0
  Variable* var_0 = scope.Var(params_name[0]);
  Tensor* tensor_0 = var_0->GetMutable<Tensor>();
  set_tensor<float>(tensor_0, std::vector<int64_t>({3, 2}));
  // variable 1
  Variable* var_1 = scope.Var(params_name[1]);
  Tensor* tensor_1 = var_1->GetMutable<Tensor>();
  set_tensor<int8_t>(tensor_1, std::vector<int64_t>({10, 1}));
  // Set combined parameters
  fbs::CombinedParamsDesc combined_param;
  SetCombinedParamsWithScope(scope, params_name, &combined_param);

  /* --------- Check scope ---------- */
  auto check_params = [&](const CombinedParamsDescReadAPI& desc) {
    Scope scope_l;
    SetScopeWithCombinedParams(&scope_l, desc);
    // variable 0
    Variable* var_l0 = scope_l.FindVar(params_name[0]);
    CHECK(var_l0);
    const Tensor& tensor_l0 = var_l0->Get<Tensor>();
    CHECK(TensorCompareWith(*tensor_0, tensor_l0));
    // variable 1
    Variable* var_l1 = scope_l.FindVar(params_name[1]);
    CHECK(var_l1);
    const Tensor& tensor_l1 = var_l1->Get<Tensor>();
    CHECK(TensorCompareWith(*tensor_1, tensor_l1));
  };
  check_params(combined_param);

  /* --------- Cache scope ---------- */
  std::vector<char> cache;
  cache.resize(combined_param.buf_size());
  std::memcpy(cache.data(), combined_param.data(), combined_param.buf_size());

  /* --------- View scope ---------- */
  check_params(CombinedParamsDescView(std::move(cache)));
}

}  // namespace fbs
}  // namespace lite
}  // namespace paddle
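Note the test round-trips through an in-memory cache only; persisting the same buffer to disk would presumably take two lines with the new helpers (path illustrative):

SaveFile("/tmp/cache.fbs", combined_param.data(), combined_param.buf_size());
check_params(CombinedParamsDescView(LoadFile("/tmp/cache.fbs")));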
lite/model_parser/flatbuffers/param_desc.h:
@@ -86,8 +86,9 @@ class CombinedParamsDescView : public CombinedParamsDescReadAPI {
   void InitParams() {
     desc_ = proto::GetCombinedParamsDesc(buf_.data());
-    params_.reserve(GetParamsSize());
-    for (size_t idx = 0; idx < GetParamsSize(); ++idx) {
+    size_t params_size = desc_->params()->size();
+    params_.reserve(params_size);
+    for (size_t idx = 0; idx < params_size; ++idx) {
       params_.push_back(ParamDescView(desc_->params()->Get(idx)));
     }
   }
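The point of this change: GetParamsSize() is a virtual accessor that, in the view class, is presumably backed by the very params_ vector being filled here, so calling it for reserve() and as the loop bound while params_ is still empty would reserve nothing and never enter the loop. Caching desc_->params()->size() reads the count straight from the flatbuffer table instead.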
@@ -114,6 +115,7 @@ class ParamDesc : public ParamDescAPI {
   }

   explicit ParamDesc(proto::ParamDescT* desc) : desc_(desc) {
+    desc_->variable.Set(proto::ParamDesc_::LoDTensorDescT());
     lod_tensor_ = desc_->variable.AsLoDTensorDesc();
     CHECK(lod_tensor_);
   }
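This matters because a flatbuffers object-API union defaults to NONE: without first seeding desc_->variable with a LoDTensorDescT, AsLoDTensorDesc() returns nullptr for a freshly created table and the CHECK(lod_tensor_) that follows it fires.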
@@ -165,6 +167,7 @@ class CombinedParamsDesc : public CombinedParamsDescAPI {
     raw_buf->UnPackTo(&desc_);
     SyncParams();
   }
+
   const ParamDescReadAPI* GetParamDesc(size_t idx) const override {
     return &params_[idx];
   }
@@ -172,7 +175,8 @@ class CombinedParamsDesc : public CombinedParamsDescAPI {
   size_t GetParamsSize() const override { return desc_.params.size(); }

   ParamDescWriteAPI* AddParamDesc() override {
-    desc_.params.push_back(std::unique_ptr<proto::ParamDescT>());
+    desc_.params.push_back(
+        std::unique_ptr<proto::ParamDescT>(new proto::ParamDescT));
     SyncParams();
     return &params_[params_.size() - 1];
   }
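The one-line fix deserves a note: std::unique_ptr<proto::ParamDescT>() default-constructs a null pointer, so the old code appended an empty slot whose later dereference was undefined behavior, while the new code allocates a real table. A standalone illustration of the pitfall with generic types (not lite code):

#include <memory>
#include <vector>

int main() {
  std::vector<std::unique_ptr<int>> v;
  v.push_back(std::unique_ptr<int>());            // appends nullptr
  v.push_back(std::unique_ptr<int>(new int(7)));  // appends a live object
  // *v[0] would be undefined behavior; *v[1] is fine.
  return *v[1] == 7 ? 0 : 1;
}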